Learning Disentangled Deep Latent Space Representations

Alexander Rakowski

In Research School since October 2019

Chair Digital Health & Machine Learning

Office: Campus III Building G2, Room G-2.1.32
Email: alexander.rakowski(at)hpi.de
Links: Home | Google Scholar

Supervisor: Prof. Dr. Christoph Lippert

Introduction

Representation learning algorithms aim to learn transformations mapping complex, usually high-dimensional data (e.g., images) into more compact (lower-dimensional) feature representations. These should ideally describe more abstract characteristics of the data (e.g., objects in an image). This can be illustrated with the following diagram:

 

 

where \( x_1, \ldots, x_m \) denote the dimensions of the observed data and \( r_1, \ldots, r_k \) the dimensions of the learned representation. Note that usually \( k \ll m \).
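As a purely illustrative sketch (the mapping below is a fixed random linear projection standing in for a learned encoder, and all names are hypothetical), such a transformation takes an observation in \( \mathbb{R}^m \) to a representation in \( \mathbb{R}^k \):

```python
import numpy as np

# Illustrative only: a fixed random linear map standing in for a learned (deep) encoder.
rng = np.random.default_rng(0)
m, k = 4096, 10                      # observed vs. representation dimensionality, k << m
W = rng.normal(size=(k, m))

def encode(x):
    """Map an observation x in R^m to a compact representation r in R^k."""
    return W @ x

x = rng.normal(size=m)               # a dummy high-dimensional observation
r = encode(x)                        # r_1 ... r_k
print(r.shape)                       # (10,)
```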

In the setting of learning disentangled representations we assume that the observed data were generated by a set of so-called ground-truth factors of variation, via a (possibly unknown) generative process. Typically, we further assume that these factors are mutually independent and come from a given prior distribution (e.g., normal).

 

A learned representation is then considered disentangled if each of these factors influences only a distinct set of the learned features. More formally, we want the following to hold:

\( \forall i : \left| \{ j : \frac{\partial r_i}{\partial g_j} \neq 0 \} \right| \leq 1 \),

where \( \frac{\partial r_i}{\partial g_j} \) denotes the derivative of the i-th learned feature w.r.t. the j-th ground-truth factor of variation. Conversely, a representation is fully entangled if a change in just one factor \( g_j \) induces a change in each of the learned features \( r_i(x) \). Ideally, one would want the learned representation to match the ground-truth factors (up to a permutation of the dimensions).
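For intuition, the criterion is easy to check when the representation is a linear function of the factors, \( r(g) = A g \): the Jacobian entry \( \frac{\partial r_i}{\partial g_j} \) is then simply \( A_{ij} \), so it suffices to count the non-zero entries per row. A small sketch with toy matrices (not taken from any experiment):

```python
import numpy as np

def is_disentangled(A, tol=1e-8):
    """Each learned feature r_i may depend on at most one ground-truth factor g_j."""
    nonzero_per_row = (np.abs(A) > tol).sum(axis=1)
    return bool(np.all(nonzero_per_row <= 1))

# A permutation-and-scaling of the factors satisfies the criterion ...
A_disentangled = np.array([[0.0, 2.0,  0.0],
                           [1.0, 0.0,  0.0],
                           [0.0, 0.0, -0.5]])

# ... while a rotation mixes two factors into the same features (entangled).
theta = np.pi / 4
A_entangled = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                        [np.sin(theta),  np.cos(theta), 0.0],
                        [0.0,            0.0,           1.0]])

print(is_disentangled(A_disentangled))  # True
print(is_disentangled(A_entangled))     # False
```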

Disentanglement is believed to be a crucial property of learned representations, allowing for better interpretability and transferability across tasks (Bengio et al. [1]).

Background (State of the Art)

Variational Autoencoders (VAEs)

The most popular approach to learning disentangled representations is to employ models from the Variational Autoencoder (VAE) framework (Kingma, Welling [2]), which combines Variational Inference with autoencoder models. A VAE is a stochastic generative model consisting of two (usually symmetric) networks - the encoder and the decoder. They are trained to jointly model the distribution of the observed data \( p(x) \) along with a conditional distribution of latent variables \( p(z|x) \). The latter is defined by the encoder network and serves as the aforementioned latent representation (encoding) of the data.

In standard autoencoder models the encoder yields a point estimate of the latent representation, \( z_i = E(x_i) \). In a VAE, on the other hand, we obtain the parameters of a distribution over possible encodings, \( \mu_i,\,\sigma^{2}_i = E(x_i) \), from which we then sample: \( z_i \sim \mathcal{N}(\mu_i,\,\sigma^{2}_i) \).
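A minimal sketch of this encoding step (a toy fully-connected encoder, not the architecture used in the experiments), with the sampling written via the usual reparameterization trick so that gradients can flow through the stochastic step:

```python
import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    """A toy encoder outputting the mean and log-variance of the latent Gaussian."""
    def __init__(self, in_dim=784, latent_dim=10):
        super().__init__()
        self.backbone = nn.Linear(in_dim, 128)
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_log_var = nn.Linear(128, latent_dim)

    def forward(self, x):
        h = torch.relu(self.backbone(x))
        return self.fc_mu(h), self.fc_log_var(h)

def sample_latent(mu, log_var):
    """Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)."""
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std)
    return mu + eps * std

encoder = ToyEncoder()
x = torch.randn(32, 784)          # a dummy batch of flattened images
mu, log_var = encoder(x)
z = sample_latent(mu, log_var)    # stochastic encodings of shape (32, 10)
```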

A key aspect of this framework is the regularization imposed on these latent encodings, forcing each of them to be close to a prior distribution (usually a standard normal distribution), as measured by the Kullback-Leibler divergence. The trade-off between regularization and reconstruction introduces an additional information bottleneck, which is believed to force the model to ignore less important factors of variation in the data.

The training objective is then:

\( L = \mathbb{E}_{x \sim X} \left[ \mathbb{E}_{z \sim E(x)} \left[ \| x - D(z) \|^2 \right] + D_{KL}\left( E(z|x) \,\|\, p(z) \right) \right] \),

where \( E(\cdot), D(\cdot) \) denote the functions defined by the encoder and decoder networks respectively, \( z \sim E(x) \) is shorthand for \( z \sim \mathcal{N}(\mu,\,\sigma^{2}) \) with \( \mu,\,\sigma^{2} = E(x) \), \( D_{KL} \) denotes the Kullback-Leibler divergence, and \( p(z) \) denotes the probability density function of the prior distribution. The first term in the objective corresponds to the mean squared error of reconstructing the inputs, while the second corresponds to the regularization term.
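Assuming a decoder mapping latent samples back to reconstructions and a standard normal prior, the objective can be sketched as follows (the KL term is written in its closed form for diagonal Gaussians; function names and the batch averaging are illustrative choices):

```python
import torch
import torch.nn.functional as F

def vae_loss(x, x_recon, mu, log_var):
    """Squared reconstruction error plus KL(N(mu, sigma^2) || N(0, I)), averaged over the batch."""
    recon = F.mse_loss(x_recon, x, reduction="sum") / x.shape[0]
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp()) / x.shape[0]
    return recon + kl

# Usage, continuing the previous snippet (with a hypothetical `decoder` mapping z back to inputs):
# loss = vae_loss(x, decoder(z), mu, log_var)
```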

Approaches to Disentanglement

Most approaches to obtaining disentanglement can be divided into two groups: those that modify the regularization term (e.g., by increasing its corresponding weight), and those that enforce statistical independence of the dimensions of the latent space representations, effectively imposing a factorising prior on the aggregated posterior of the encodings.
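As an illustration of the first group (the exact formulations of individual methods differ), scaling the regularization term with a weight \( \beta > 1 \), as in \( \beta \)-VAE-type models, yields an objective of the form:

\( L_{\beta} = \mathbb{E}_{x \sim X} \left[ \mathbb{E}_{z \sim E(x)} \left[ \| x - D(z) \|^2 \right] + \beta \, D_{KL}\left( E(z|x) \,\|\, p(z) \right) \right] \).

Larger values of \( \beta \) tighten the information bottleneck at the cost of reconstruction quality.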

However, a large-scale study (Locatello et al. [3]) showed that these methods do not yield consistent disentanglement performance, with high variance in the achieved scores and low correlation between the chosen method and the level of disentanglement obtained. Also, and perhaps more importantly, they introduced the so-called Impossibility Result, which states that in a purely unsupervised setting it is impossible to identify disentangled representations, even when successfully imposing the factorising prior.

Current Work

Disentanglement and Local Directions of Variance

How does the structure of variance in datasets affect the disentanglement performance of Variational Autoencoders?

The aforementioned impossibility result, highlighting a fundamental problem of unsupervised disentangling, seemed to be in contrast with previous work on VAE-based approaches, which reported improvements in performance. Locatello et al. [3] point to inductive biases as one possible means of overcoming this issue. One such bias is postulated to be the connection between VAEs and Principal Component Analysis (PCA). Rolinek et al. [6] argue that the PCA-like behavior leads to disentanglement, and Zietlow et al. [7] present this as an inductive bias exploiting a globally consistent structure of variance in the data - disentanglement should occur because the principal directions of variance are (globally) aligned with the ground-truth factors of variation. We follow this hypothesis by quantitatively analysing the properties of a large set of models trained on benchmark datasets, using measures designed to capture local and global structures of variance in both the data and the trained models (representations).

Drawing on the PCA analogy, we speculate on conditions for the correct alignment of the learned representations in the case of finite data. The plots below show the distribution of two normal random variables with differing variances. Dashed lines indicate the "true" principal directions, while solid lines indicate the principal directions found by PCA on a finite-size sample:

As the variances differ more (left plot), the PCA solution is relatively well-aligned with the ground truth; however, as the variances become closer (right plot), the PCA solution is more likely to become misaligned. We observed a similar effect with models trained on both toy and benchmark datasets: they seem to disentangle better when the underlying ground-truth factors of variation induce different amounts of variance in the data.
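This effect can be reproduced with a small simulation (illustrative code, not the evaluation protocol of the study): fit PCA on finite samples from an axis-aligned Gaussian and measure the angle between the first principal direction found and the true axis of largest variance:

```python
import numpy as np
from sklearn.decomposition import PCA

def misalignment_deg(sigmas, n_samples=200, seed=0):
    """Angle (degrees) between the first principal direction and the true axis of largest variance."""
    rng = np.random.default_rng(seed)
    x = rng.normal(scale=sigmas, size=(n_samples, 2))      # axis-aligned "ground-truth" directions
    first_pc = PCA(n_components=2).fit(x).components_[0]   # unit vector
    return np.degrees(np.arccos(np.clip(np.abs(first_pc[0]), 0.0, 1.0)))

print(misalignment_deg(sigmas=(3.0, 1.0)))    # clearly different variances: small angle
print(misalignment_deg(sigmas=(1.05, 1.0)))   # similar variances: alignment becomes unstable
```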

We also investigate the effect of global consistency of the principal directions of variance, i.e., whether a ground-truth factor of variation induces the same amount of variance in the data everywhere on the data manifold. An example of globally inconsistent data can be seen below - changing the color of the object influences a different number of pixels depending on the object's size (its position on the size-corresponding axis in the latent space):

 

Analysing results on toy and benchmark datasets, we observe that VAEs disentangle better on datasets with globally consistent structures of variance. Additionally, we note that such variance inconsistencies lead not only to entanglement (incorrect alignment of axes), but also to inconsistency of encodings (factors being encoded differently depending on the position on the data manifold).
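The color-vs-size example above can be made concrete with a toy sketch (an assumed rendering of a single square, not one of the benchmark datasets): the number of pixels affected by the color factor grows with the value of the size factor, so the variance induced by the color factor is not globally consistent:

```python
import numpy as np

def render(size, color, canvas=32):
    """Render a single square of the given size and color on an empty canvas."""
    img = np.zeros((canvas, canvas))
    img[:size, :size] = color
    return img

for size in (4, 16, 28):
    changed = int(np.sum(render(size, 1.0) != render(size, 0.5)))
    print(f"size={size:2d}: changing the color affects {changed} pixels")
```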

Past Projects

Effect of Sparsity on Disentanglement

Sparsity in machine learning models can refer either to a sparse output of the model (e.g., a feature vector with many zero entries) or to sparsity of the solution (many parameter weights equal or close to zero).

In this project I investigate the latter - that is, sparsity of neural networks - and its relation to disentanglement. Standard neural network models have dense channel-to-channel connectivity. However, it is not necessary for a feature (represented by a hidden unit) to be constructed from all available features of the preceding layer - instead, it should suffice for it to be a combination of only a subset of them, without a loss in performance. This has been verified empirically by works such as the Lottery Ticket Hypothesis, showing that in many cases a model with 90% fewer parameters can still achieve similar performance (Frankle, Carbin [4]).

A motivation for a possible relation between model sparsity and disentanglement can be established by revisiting the proof of the aforementioned impossibility result. Even if a model is powerful enough to recover all underlying factors of variation of the data and encode them in the latent space, it suffices to rotate this space to obtain an entangled representation. Because the employed prior distributions are rotationally invariant, it is impossible to identify the correct model. I argue that this entanglement due to rotations can already happen inside the model, in the hidden layers. This is due to the dense structure of modern neural networks, since there is no incentive for the model to use the connections sparsely, especially given a random initialization of weights. Explicitly enforcing sparsity might thus alleviate this potential issue, by favoring models which "rotate the features less". A preference for sparsity can also be found in the literature on causality. By the Independent Causal Mechanisms Principle, causal mechanisms (which can be modeled as hidden layers of a deep network) "should usually not affect all factors simultaneously", that is, they should propagate down the network in a sparse and local manner (Schölkopf [5]).

A naïve approach to obtaining sparsity would be to simply employ an \( L_1 \) weight penalty. While this is plausible for the 2-dimensional weight matrices used by fully-connected layers, it does not necessarily have the desired effect for convolutional layers, whose weight tensors are at least 3-dimensional. Instead, we are interested in limiting the number of channel-to-channel connections - that is, zeroing out all weights in a convolutional filter that are connected to a particular channel of the previous layer. We implement this by introducing an additional masking tensor in each layer. Before computing a layer's output, its weight tensor is first multiplied with this mask tensor. Each entry in the mask thus effectively serves as a gate, controlling the flow of information between a pair of channels from consecutive layers. This is illustrated in the figure below:

Example of a binary variant of the proposed masking technique applied to a fully-connected layer. An entry \( M_{i, j} \) equal to 0 removes the connection between the input node \( h^1_i \) and the output node \( h^2_j \), as indicated by the dotted lines.

More formally, for a convolutional layer with weight tensor \( K \in \mathbb{R}^{m \times n \times w \times h} \), we are interested in zeroing out entire slices \( K_{i, j, :, :} \) of shape \( w \times h \) for certain pairs of indices \( i, j \), indexing channels in the layer's input and output respectively. To achieve this, we extend the layers on which we want to impose sparsity with an additional masking tensor \( M \in \mathbb{R}^{m \times n} \). During the forward pass, a masked weight tensor \( K' = K \odot M \) is used to compute the output, where \( \odot \) denotes the Hadamard product (broadcast over the spatial dimensions). This method applies in exactly the same way to fully-connected layers, with the only difference being that the original weight tensor is 2-dimensional.
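A sketch of this masking for a convolutional layer in PyTorch (hypothetical class and variable names; the mask is kept as a fixed buffer here, whereas in practice it could be learned or sparsified during training). Note that PyTorch stores convolutional weights as (out_channels, in_channels, h, w), so the mask below is indexed as (output channel, input channel):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedConv2d(nn.Conv2d):
    """Conv2d whose channel-to-channel connections are gated by an (out, in) mask."""

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        # One gate per (output channel, input channel) pair; 1 keeps, 0 removes the connection.
        self.register_buffer("mask", torch.ones(out_channels, in_channels))

    def forward(self, x):
        # Broadcasting the mask over the spatial kernel dimensions zeroes out
        # whole w x h filter slices, i.e. K'[i, j, :, :] = K[i, j, :, :] * M[i, j].
        masked_weight = self.weight * self.mask[:, :, None, None]
        return F.conv2d(x, masked_weight, self.bias,
                        self.stride, self.padding, self.dilation, self.groups)

layer = MaskedConv2d(3, 8, kernel_size=3, padding=1)
layer.mask[5, 0] = 0.0                        # cut the connection: input channel 0 -> output channel 5
out = layer(torch.randn(1, 3, 32, 32))        # output shape: (1, 8, 32, 32)
```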

References

1. Bengio Y, Courville A, Vincent P. Representation Learning: A Review and New Perspectives. IEEE Transactions on Pattern Analysis and Machine Intelligence. 2013 Mar 7;35(8):1798-828.

2. Kingma DP, Welling M. Auto-Encoding Variational Bayes. arXiv preprint arXiv:1312.6114. 2013 Dec 20.

3. Locatello F, Bauer S, Lucic M, Raetsch G, Gelly S, Schölkopf B, Bachem O. Challenging Common Assumptions in the Unsupervised Learning of Disentangled Representations. In International Conference on Machine Learning 2019 May 24 (pp. 4114-4124).

4. Frankle J, Carbin M. The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks. arXiv preprint arXiv:1803.03635. 2018 Mar 9.

5. Schölkopf B. Causality for Machine Learning. arXiv preprint arXiv:1911.10500. 2019 Nov 24.

6. Rolinek M, Zietlow D, Martius G. Variational Autoencoders Pursue PCA Directions (by Accident). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition 2019 (pp. 12406-12415).

7. Zietlow D, Rolinek M, Martius G. Demystifying Inductive Biases for (Beta-)VAE Based Architectures. In Proceedings of the 38th International Conference on Machine Learning 2021, PMLR 139 (pp. 12945-12954).

Publications

  • Predicting the SARS-CoV-2 effective reproduction number using bulk contact data from mobile phones. Rüdiger, Sten; Konigorski, Stefan; Rakowski, Alexander; Edelman, Jonathan Antonio; Zernick, Detlef; Thieme, Alexander; Lippert, Christoph in Proceedings of the National Academy of Sciences (2021). 118(31)
     
  • Disentanglement and Local Directions of Variance. Rakowski, Alexander; Lippert, Christoph in Machine Learning and Knowledge Discovery in Databases. Research Track (2021). 19–34.
     

Teaching Activities

Co-tutored the Master Thesis of Dominika Matus: "Recognition of Epilepsy Seizures in EEG Data Using Deep Learning Methods"