
Learning Disentangled Deep Latent Space Representations

Alexander Rakowski

In the Research School since October 2019

Chair Digital Health & Machine Learning

Office: Campus III Building G2, Room G-2.1.32
Email: alexander.rakowski(at)hpi.de
Links: Home | Google Scholar

Supervisor: Prof. Dr. Christoph Lippert

Introduction

I am investigating methods for learning disentangled representations of data with deep neural networks, and their applications in the medical domain. Deep representation learning is an ever-growing field, but as the computational capabilities of modern neural networks continue to grow, the problem of their interpretability remains open. This is especially crucial in medical tasks, where working with purely "black-box" models might simply not be acceptable. By learning disentangled representations of data we aim for greater interpretability, transferability, and even the ability to answer counterfactual questions.

In my recent work I investigated the relations between dataset properties and disentanglement, uncovering the role of structures of variance caused by the underlying factors. 

Problem Setting and Definition

Representation learning algorithms aim to learn transformations mapping complex, usually high-dimensional data (e.g., images) into more compact (lower-dimensional) feature representations. These should ideally describe more abstract characteristics of the data (e.g., objects in an image). This can be illustrated with the following diagram:

 

where \( x_1 \ldots x_l \) denote the dimensions of the observed data and \( r_1 \ldots r_m \) the dimensions of the learned representation. Note that usually \( m \ll l \).
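As a purely illustrative sketch (not the architecture used in this work), such a mapping could be a small convolutional encoder, here written with PyTorch:

    import torch
    import torch.nn as nn

    class Encoder(nn.Module):
        """Maps images x (high-dimensional) to compact feature vectors r."""
        def __init__(self, in_channels=1, latent_dim=10):
            super().__init__()
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels, 32, kernel_size=4, stride=2, padding=1),  # 64x64 -> 32x32
                nn.ReLU(),
                nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),           # 32x32 -> 16x16
                nn.ReLU(),
            )
            self.fc = nn.Linear(64 * 16 * 16, latent_dim)

        def forward(self, x):
            h = self.conv(x).flatten(start_dim=1)
            return self.fc(h)          # r has far fewer dimensions than x

    # a batch of 8 grayscale 64x64 images -> 8 ten-dimensional representations
    r = Encoder()(torch.randn(8, 1, 64, 64))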

In the setting of learning disentangled representations we assume that the observed data were generated from a set of so-called ground-truth factors of variation, via a (possibly unknown) generative process. Typically we further assume that these factors are mutually independent and come from a given prior distribution (e.g., normal):

 
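In a common formalisation (a sketch of the usual notation, which may differ from the figure above), with \( k \) independent factors \( z = (z_1, \ldots, z_k) \) and a generative mapping \( g \):

\( p(z) = \prod_{j=1}^{k} p(z_j), \qquad z_j \sim \mathcal{N}(0, 1), \qquad x = g(z). \)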

A learned representation is then considered disentangled if each of these factors influences only a distinct subset of the learned features. More formally, we want the following to hold:

\( \forall i : | \{ j \in \{1 \ldots k\}: \frac{\partial r_i}{\partial z_j} \neq 0 \} | \leq 1 \),

where \( \frac{\partial r_i}{\partial z_j} \) denotes the derivative of the i-th learned feature with respect to the j-th ground-truth factor of variation. Conversely, a representation is fully entangled if a change in just one factor \( z_j \) induces a change in every learned feature \( r_i(x) \). Ideally, one would want the learned representation to match the ground-truth factors (up to a permutation of the dimensions).
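For differentiable models this criterion can be checked numerically at a given point via the Jacobian of the composite map \( z \mapsto r \). A toy sketch using PyTorch autograd, where the two linear layers merely stand in for a generative process and an encoder:

    import torch
    from torch.autograd.functional import jacobian

    k = 2                                        # number of ground-truth factors
    decoder = torch.nn.Linear(k, 16)             # stands in for the generative process z -> x
    encoder = torch.nn.Linear(16, 3)             # stands in for the learned representation x -> r

    def representation(z):
        return encoder(decoder(z))               # composite map z -> x -> r

    z0 = torch.zeros(k)                          # point at which the Jacobian is evaluated
    J = jacobian(representation, z0)             # J[i, j] = d r_i / d z_j
    influences = (J.abs() > 1e-6).sum(dim=1)     # how many factors influence each feature r_i
    print(bool((influences <= 1).all()))         # the criterion above; random maps are entangled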

Disentanglement is believed to be a crucial property when learning representations, allowing for better interpretability and transferability across tasks (Bengio et al. [1]).

Variational Autoencoders (VAEs)

The most popular approach to learning disentangled representations is to employ models from the Variational Autoencoder (VAE) framework (Kingma, Welling [2]), which combines Variational Inference with autoencoder models. A VAE is a stochastic generative model consisting of two (usually symmetric) networks - the encoder and the decoder. They are trained to jointly model the distribution of the observed data \( p(x) \) with a generative model, along with a conditional distribution of latent variables \( p(z|x) \) with a predictive model. The latter is defined by the encoder network and serves as the aforementioned latent representation (encodings) of the data.

In standard autoencoder models the encoder yields a point estimate of the latent representation \( z_i = E(x_i) \). In a VAE, on the other hand, we obtain parameters of a distribution over possible encodings \( \mu_i,\,\sigma^{2}_i = E(x_i) \), from which we then sample: \( z_i \sim \mathcal{N}(\mu_i,\,\sigma^{2}_i ) \). 
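This sampling step is usually implemented with the reparameterization trick, so that gradients can flow through it; a minimal sketch, assuming a PyTorch encoder that outputs the mean and log-variance:

    import torch

    def sample_latent(mu, logvar):
        """Draw z ~ N(mu, sigma^2) in a differentiable way (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)    # sigma
        eps = torch.randn_like(std)      # noise from N(0, I)
        return mu + eps * std            # z depends on (mu, sigma) only through arithmetic

    # e.g., for a batch of 8 samples with a 10-dimensional latent space:
    mu, logvar = torch.zeros(8, 10), torch.zeros(8, 10)
    z = sample_latent(mu, logvar)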

A key aspect of this framework is the regularization imposed on these latent encodings, forcing each of them to stay close to a prior distribution (usually a standard normal distribution), as measured by the Kullback-Leibler divergence. The trade-off between regularization and reconstruction introduces an additional information bottleneck, which is believed to force the model to ignore less important factors of variation in the data.
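Formally, these two terms together make up the standard evidence lower bound (ELBO) maximised by a VAE:

\( \mathcal{L}(x) = \mathbb{E}_{q(z|x)}\left[ \log p(x|z) \right] - D_{\mathrm{KL}}\left( q(z|x) \,\|\, p(z) \right), \)

where the first term rewards accurate reconstructions and the second keeps the encodings \( q(z|x) \) close to the prior \( p(z) \).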

Approaches to Disentanglement

Most approaches to obtaining disentanglement can be divided into two groups: ones that modify the regularization term (e.g., by increasing its corresponding weight), and ones that enforce statistical independence of the dimensions of the latent space representations, effectively imposing a factorising prior on the aggregated posterior of the encodings.
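Two representative examples of these families (not necessarily the exact variants evaluated here) are the \( \beta \)-VAE objective, which up-weights the regularization term, and total-correlation penalties, which measure the dependence between dimensions of the aggregated posterior \( q(z) \):

\( \mathcal{L}_{\beta}(x) = \mathbb{E}_{q(z|x)}\left[ \log p(x|z) \right] - \beta \, D_{\mathrm{KL}}\left( q(z|x) \,\|\, p(z) \right), \qquad \mathrm{TC}(z) = D_{\mathrm{KL}}\Big( q(z) \,\Big\|\, \prod_{j} q(z_j) \Big). \)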

However, a large-scale study (Locatello et al. [3]) showed that these methods do not yield consistent performance (in terms of disentangling), with high variance of the achieved scores and low correlation between the chosen method and the level of disentanglement obtained. Also, and perhaps more importantly, they introduced the so-called Impossibility Result, which states that in a purely unsupervised setting it is impossible to identify disentangled representations, even when successfully imposing the factorising prior.

Weakly-Supervised Disentanglement for Longitudinal Brain Imaging Studies

Identifying a disentangled representation in a purely unsupervised setting has been proven impossible (Locatello et al. [3]), so some form of supervision is needed. However, labeling a dataset can be labor-intensive, especially when expert knowledge is required. In such cases, weakly-supervised algorithms can be employed, which leverage high-level auxiliary information about the data samples, for example labels indicating which class a sample belongs to (Bouchacourt et al. [8], Shu et al. [9]) or which samples share the same (although unknown) values for a subset of traits (Bengio et al. [10], Locatello et al. [11]).

In this work we evaluate a weakly-supervised approach for learning disentangled representations of brain Magnetic Resonance Imaging (MRI) data. We leverage the longitudinal nature of brain imaging studies, such as the Alzheimer's Disease Neuroimaging Initiative (ADNI), using repeated measurements of the same subjects as the signal for weak supervision. We compare the proposed approach across a range of settings, measuring its disentanglement as well as its performance in real-life downstream tasks, such as dementia score prediction or genome-wide association tests.

In a longitudinal study the measurements of interest (e.g., MRI scans) are repeated over several points in time for each participant. We assume that scans of the same participant obtained at different time-points will be similar to each other, that is, they will share certain underlying attributes, such as volumetric measures or disease state, even though we do not have direct access to these attributes. For example, a scan repeated after a few years on the same person will show roughly the same head size, while the white matter volume might decrease due to aging. We leverage this assumption to construct pairs of samples from the same participants and use them to train a state-of-the-art weakly-supervised disentangled representation learning model - the Adaptive-Group-Variational-Autoencoder (Ada-GVAE, Locatello et al. [11]). During training, the model compares the encodings of the two samples in a pair and adaptively selects a group of the most similar dimensions. Values in these dimensions are then averaged within each pair to create modified, partially averaged encodings, which are then used to reconstruct the original inputs. This forces the model to select only the features shared between the two samples for averaging, effectively improving the disentanglement of the representations. A graphical illustration of the approach is shown below:

 
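A simplified sketch of the adaptive averaging step described above, assuming PyTorch and diagonal Gaussian encodings; the actual Ada-GVAE compares the two posteriors per dimension via a KL-based criterion, which is only approximated here:

    import torch

    def kl_per_dim(mu1, logvar1, mu2, logvar2):
        """Per-dimension KL( N(mu1, var1) || N(mu2, var2) ) for diagonal Gaussians."""
        var1, var2 = logvar1.exp(), logvar2.exp()
        return 0.5 * (var1 / var2 + (mu2 - mu1) ** 2 / var2 - 1 + logvar2 - logvar1)

    def adaptive_average(mu1, logvar1, mu2, logvar2):
        """Average the posterior parameters in the dimensions deemed shared within a pair."""
        kl = kl_per_dim(mu1, logvar1, mu2, logvar2)               # shape: (batch, latent_dim)
        tau = 0.5 * (kl.max(dim=1, keepdim=True).values
                     + kl.min(dim=1, keepdim=True).values)        # adaptive threshold per pair
        shared = kl < tau                                         # dimensions assumed shared
        mu_avg, logvar_avg = 0.5 * (mu1 + mu2), 0.5 * (logvar1 + logvar2)
        new_mu1 = torch.where(shared, mu_avg, mu1)
        new_mu2 = torch.where(shared, mu_avg, mu2)
        new_logvar1 = torch.where(shared, logvar_avg, logvar1)
        new_logvar2 = torch.where(shared, logvar_avg, logvar2)
        return (new_mu1, new_logvar1), (new_mu2, new_logvar2)

The two modified encodings are then decoded to reconstruct their respective inputs, exactly as in the description above.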

Below is a visual comparison of the representations of two trained models - the baseline (5 columns on the left) and the proposed Ada-GVAE model (5 columns on the right). Odd rows show images generated by interpolating along a single latent dimension, while even rows show difference images with respect to the mean (middle) image. The Ada-GVAE model (right) was able to better disentangle intensity (top rows) from features related to aging or neurodegeneration, such as changes in ventricle or fissure sizes (bottom rows). In the baseline model (left) these features are more entangled with each other - both ventricle size and intensity change in each row.

Interpolations of the baseline model
Interpolations of the Ada-GVAE model

Past Projects

Disentanglement and Local Directions of Variance

In this project we investigate how the structure of variance in datasets affects the disentanglement performance of Variational Autoencoders (VAEs). We propose and compute measures quantifying variance properties of datasets and relate them to the performance of a large number of models trained on these data. This allows us to reason about the mechanisms governing disentanglement in VAEs and about how structures of variance in the data influence it.

The aforementioned impossibility result, showing the crucial problem of unsupervised disentangling, seemed to be in contrast with previous works on VAE-based approaches, which reported improvements in performance. Locatello et al. [3] point to inductive biases as one possible means of overcoming this issue. One such bias is postulated to be the connection between VAEs and Principal Component Analysis (PCA). Rolinek et al. [6] argue that the PCA-like behavior leads to disentanglement, and Zietlow et al. [7] present this as an inductive bias exploiting a globally consistent structure of variance in the data - disentanglement should occur because the principal directions of variance are (globally) aligned with the ground-truth factors of variation. We follow this hypothesis by quantitatively analysing properties of a large set of models trained on benchmark datasets, using measures designed to capture local and global structures of variance in both the data and the trained models (representations).

Drawing on the PCA analogy, we speculate on conditions for the correct alignment of the learned representations in the case of finite data. The plots show samples from two-dimensional normal distributions whose variances differ along the axes. Dashed lines indicate the "true" principal directions, while solid lines indicate the principal directions found by PCA on a finite-size sample:

As the variances differ more (left plot), the PCA solution is relatively well-aligned with the ground-truth - however, as the variances become closer (right plot), the PCA solution is more likely to become misaligned. We observed a similar effect with models trained on both toy and benchmark datasets - they seem to disentangle better when the underlying ground-truth factors of variation induce different amounts of variance in the data.
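This effect is easy to reproduce numerically; a small sketch using NumPy, measuring the average angle between the leading principal direction estimated from a finite sample and the true axis, for different variance ratios:

    import numpy as np

    rng = np.random.default_rng(0)

    def mean_misalignment(var_ratio, n=200, trials=200):
        """Average angle (degrees) between the estimated first PC and the true first axis."""
        angles = []
        for _ in range(trials):
            x = rng.normal(size=(n, 2)) * np.array([1.0, 1.0 / np.sqrt(var_ratio)])  # stds 1, 1/sqrt(ratio)
            eigvals, eigvecs = np.linalg.eigh(np.cov(x, rowvar=False))
            pc1 = eigvecs[:, np.argmax(eigvals)]                             # leading direction
            angles.append(np.degrees(np.arccos(min(abs(pc1[0]), 1.0))))      # angle to the x-axis
        return float(np.mean(angles))

    for ratio in (4.0, 2.0, 1.1):
        print(ratio, mean_misalignment(ratio))   # misalignment grows as the variances become similar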

We also investigate the effect of global consistency of the principal directions of variance, i.e., whether a ground-truth factor of variation induces the same amount of variance in the data everywhere on the data manifold. An example of globally inconsistent data can be seen below - changing the color of the object influences a different number of pixels depending on the object's size (i.e., on the position along the size-related axis of the latent space):

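A toy numerical illustration of such an inconsistency, assuming (hypothetical) images containing a single centred square whose size and intensity vary: the same intensity change affects far fewer pixels for a small square than for a large one.

    import numpy as np

    def square_image(size, intensity, canvas=32):
        """Render a centred square of the given size and intensity on a blank canvas."""
        img = np.zeros((canvas, canvas))
        lo = (canvas - size) // 2
        img[lo:lo + size, lo:lo + size] = intensity
        return img

    for size in (4, 16):
        changed = int(np.sum(square_image(size, 1.0) != square_image(size, 0.5)))
        print(size, changed)   # 16 pixels change for the small square, 256 for the large one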
Analysing results on toy and benchmark datasets, we observe that VAEs disentangle better on datasets with globally consistent structures of variance. Additionally, we note that such variance inconsistencies lead not only to entanglement (incorrect alignment of axes), but also to inconsistency of the encodings (factors being encoded in different dimensions of the representation, depending on the position on the data manifold).

Effect of Sparsity on Disentanglement

Sparsity in machine learning models can refer either to a sparse output of the model (e.g., a feature vector with many zero entries) or to sparsity of the solution (many parameter weights equal or close to zero).

In this project I investigate the latter - that is, sparsity of neural networks - and its relation to disentanglement. Standard neural network models have dense channel-to-channel connectivity. However, it is not necessary for a feature (represented by a hidden unit) to be constructed using all available features from the preceding layer - instead, it should suffice for it to be a combination of only a subset of them, without a loss in performance. This has been verified empirically by work such as the Lottery Ticket Hypothesis, showing that in many cases a model with 90% fewer parameters can still achieve similar performance (Frankle, Carbin [4]).

A motivation for a possible relation between model sparsity and disentanglement can be established by revisiting the proof of the aforementioned impossibility result. Even if a model is powerful enough to recover all underlying factors of variation of the data and encode them in the latent space, it suffices to rotate this space to obtain an entangled representation. Because the employed prior distributions are rotationally invariant, it is then impossible to identify a correct model. I argue that this entanglement can happen already in the hidden layers of a model. This is due to the dense structure of modern neural networks, since there is no incentive for the model to use the connections sparsely, especially given a random initialization of weights. Explicitly enforcing sparsity might thus alleviate this potential issue, by favoring models which "rotate the features less". A preference for sparsity can also be found in the literature on causality. By the Independent Causal Mechanisms Principle, the causal mechanisms (which can be modeled as hidden layers of a deep network) "should usually not affect all factors simultaneously", that is, they should propagate down the network in a sparse and local manner (Schölkopf [5]).

A naïve approach to obtaining sparsity would be to simply employ the \( L_1 \) weight penalty. While this is plausible for the 2-dimensional weight matrices used by fully-connected layers, it does not necessarily have the desired effect for convolutional layers, whose weight tensors are at least 3-dimensional. Instead, we are interested in limiting the number of channel-to-channel connections - that is, zeroing out all weights in a convolutional filter that are connected to a particular channel of the previous layer. We implement this by introducing an additional masking tensor to each layer. Before computing a layer's output, its weight tensor is first multiplied with this mask tensor. Thus each entry in the mask effectively serves as a gate, controlling the flow of information between a pair of channels in consecutive layers. This is illustrated in the figure below:

Example of a binary variant of the proposed masking technique applied on a fully-connected layer. An entry of \( M_{i, j} \) equal to 0 will remove the connection between the input node \( h^1_i \) and output node \( h^2_j \) , as indicated by the dotted lines.

More formally, for a convolution layer with a corresponding weight tensor \( K \in \mathbb{R}^{m \times n \times w \times h} \) we are interested in zeroing out all entries of slices \( K_{i, j, :, :} \) of shape \( w \times h \), for certain pairs of indices \( i, j \) indexing channels in the layer's input and output respectively. To achieve this, we extend the layers that we want to impose sparsity on with an additional masking tensor \( M \in \mathbb{R}^{m \times n} \). During the forward pass a masked weight tensor \( K' = K \odot M \) is used to compute the output, where \( \odot \) denotes the Hadamard product. This method also applies in the exact same way to fully-connected layers, with the only difference being that the original weight tensor is 2-dimensional.
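A sketch of how such a masked layer can be implemented, assuming PyTorch (note that PyTorch stores convolution weights as output-channels × input-channels × \( h \) × \( w \), i.e., with the channel indices in the opposite order to the notation above); the mask is kept as a fixed binary tensor here, although it could also be learned or annealed:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MaskedConv2d(nn.Conv2d):
        """Conv2d whose channel-to-channel connections can be switched off via a mask M."""
        def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
            super().__init__(in_channels, out_channels, kernel_size, **kwargs)
            # one mask entry per (output channel, input channel) pair
            self.register_buffer("mask", torch.ones(out_channels, in_channels))

        def forward(self, x):
            # broadcast M over the spatial kernel dimensions: K' = K * M
            masked_weight = self.weight * self.mask[:, :, None, None]
            return F.conv2d(x, masked_weight, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)

    layer = MaskedConv2d(3, 8, kernel_size=3, padding=1)
    layer.mask[0, 1] = 0.0        # cut the connection from input channel 1 to output channel 0
    y = layer(torch.randn(2, 3, 16, 16))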

Preliminary Results

The plot below shows results obtained on the dSprites and Cars3D datasets (rows), for three disentanglement metrics - Mutual Information Gap (MIG), DCI Disentanglement and Modularity (columns):

For each metric/dataset combination we plot the mean results over a range of 6 regularization strengths, along with the baseline model (dashed line). For certain regularization strengths there is an improvement over the baseline. However, extremely high values seem to hinder performance for the MIG and DCI metrics, while Modularity seems to exhibit an opposite trend. Overall, there seems to be no "sweet spot" for the hyperparameter setting that would transfer well across datasets and metrics.
 

References

1. Bengio Y, Courville A, Vincent P. Representation Learning: A Review and New Perspectives. IEEE Transactions on Pattern Analysis and Machine Intelligence. 2013 Mar 7;35(8):1798-828.

2. Kingma DP, Welling M. Auto-Encoding Variational Bayes. arXiv preprint arXiv:1312.6114. 2013 Dec 20.

3. Locatello F, Bauer S, Lucic M, Raetsch G, Gelly S, Schölkopf B, Bachem O. Challenging Common Assumptions in the Unsupervised Learning of Disentangled Representations. In International Conference on Machine Learning 2019 May 24 (pp. 4114-4124).

4. Frankle J, Carbin M. The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks. arXiv preprint arXiv:1803.03635. 2018 Mar 9.

5. Schölkopf B. Causality for Machine Learning. arXiv preprint arXiv:1911.10500. 2019 Nov 24.

6. Rolinek M, Zietlow D, Martius G. Variational Autoencoders Pursue PCA Directions (by Accident). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition 2019 (pp. 12406-12415).

7. Zietlow D, Rolinek M, Martius G. Demystifying Inductive Biases for (Beta-)VAE Based Architectures. In International Conference on Machine Learning 2021 (pp. 12945-12954). PMLR 139.

8. Bouchacourt D, Tomioka R, Nowozin S. Multi-Level Variational Autoencoder: Learning Disentangled Representations from Grouped Observations. In Proceedings of the AAAI Conference on Artificial Intelligence 2018 (Vol. 32).

9. Shu R, Chen Y, Kumar A, Ermon S, Poole B. Weakly Supervised Disentanglement with Guarantees. arXiv preprint arXiv:1910.09772. 2019.

10. Bengio Y, Deleu T, Rahaman N, Ke R, Lachapelle S, Bilaniuk O, Goyal A, Pal C. A Meta-Transfer Objective for Learning to Disentangle Causal Mechanisms. arXiv preprint arXiv:1901.10912. 2019.

11. Locatello F, Poole B, Rätsch G, Schölkopf B, Bachem O, Tschannen M. Weakly-Supervised Disentanglement Without Compromises. In International Conference on Machine Learning 2020 (pp. 6348-6359). PMLR.

Publications

  • Predicting the SARS-CoV-2 effective reproduction number using bulk contact data from mobile phones. Rüdiger, Sten; Konigorski, Stefan; Rakowski, Alexander; Edelman, Jonathan Antonio; Zernick, Detlef; Thieme, Alexander; Lippert, Christoph in Proceedings of the National Academy of Sciences (2021). 118(31)
     
  • Disentanglement and Local Directions of Variance. Rakowski, Alexander; Lippert, Christoph in Machine Learning and Knowledge Discovery in Databases. Research Track (2021). 19–34.
     

Teaching Activities

Co-supervised the Master's thesis of Dominika Matus: "Recognition of Epilepsy Seizures in EEG Data Using Deep Learning Methods"