
Learning Disentangled Deep Latent Space Representations

Alexander Rakowski

Chair Digital Health & Machine Learning

Office: Campus III Building G2, Room G-2.1.32
Email: alexander.rakowski(at)hpi.de
Links: Home, Google Scholar

Supervisor: Prof. Dr. Christoph Lippert
 

Introduction

Representation learning algorithms aim to learn transformations mapping complex, usually high-dimensional data (e.g., images) into more compact (lower-dimensional) feature representations. These should ideally describe more abstract characteristics of the data (e.g., objects in an image). This can be illustrated with the following diagram:

[Diagram: the observed dimensions \( x_1, \dots, x_m \) are mapped to the learned representation dimensions \( r_1, \dots, r_k \).]

where \( x_1, \dots, x_m \) denote the dimensions of the observed data and \( r_1, \dots, r_k \) the dimensions of the learned representation. Note that usually \( k \ll m \).

In the setting of learning disentangled representations we assume that the observed data were generated from a set of so-called ground-truth factors of variation via a (possibly unknown) generative process. Typically we further assume that these factors are mutually independent and come from a given prior distribution (e.g., a normal distribution).
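
Using \( g_1, \dots, g_n \) to denote the ground-truth factors and \( f \) to denote the generative process (symbols introduced here only for notation), these assumptions can be written as:

\( (g_1, \dots, g_n) \sim p(g) = \prod_{j=1}^{n} p(g_j), \qquad x = f(g_1, \dots, g_n) \).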

 

A learned representation is then said to be disentangled if each of these factors influences only a distinct (disjoint) subset of the learned features. More formally, we want the following to hold:

\( \forall i : \left| \left\{ j : \frac{\partial r_i}{\partial g_j} \neq 0 \right\} \right| \leq 1 \),

where \( \frac{\partial r_i}{\partial g_j} \) denotes the derivative of the i-th learned feature with respect to the j-th ground-truth factor of variation. Conversely, a representation is fully entangled if a change in just one factor \( g_j \) induces a change in every learned feature \( r_i(x) \). Ideally, one would want the learned representation to match the ground-truth factors (up to a permutation of the dimensions).
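
As a minimal illustration with two hypothetical factors \( g_1, g_2 \), consider the representations

\( r^{(a)} = \left( g_1,\; g_2 \right), \qquad r^{(b)} = \left( \tfrac{1}{\sqrt{2}}(g_1 - g_2),\; \tfrac{1}{\sqrt{2}}(g_1 + g_2) \right) \).

The first satisfies the criterion above, since each feature depends on exactly one factor, while the second (a 45° rotation of the first) is fully entangled: every feature has a non-zero derivative with respect to both factors.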

Disentanglement is believed to be a crucial property when learning representations, allowing for better interpretability and transferability across tasks (Bengio et al. [1]).

Background (State of the Art)

Variational Autoencoders (VAEs)

The most popular approach to learning disentangled representations is to employ models from the Variational Autoencoder (VAE) framework (Kingma, Welling [2]), which combines variational inference with autoencoder models. A VAE is a stochastic generative model consisting of two (usually symmetric) networks - the encoder and the decoder. They are trained to jointly model the distribution of the observed data \( p(x) \) along with a conditional distribution of latent variables \( p(z|x) \). The latter is defined by the encoder network and serves as the aforementioned latent representation (encoding) of the data.

In standard autoencoder models the encoder yields a point estimate of the latent representation, \( z_i = E(x_i) \). In a VAE, on the other hand, we obtain the parameters of a distribution over possible encodings, \( \mu_i,\,\sigma^{2}_i = E(x_i) \), from which we then sample: \( z_i \sim \mathcal{N}(\mu_i,\,\sigma^{2}_i) \).
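
As a concrete sketch, this sampling step (the reparameterization trick) can be written in PyTorch as follows; the function and variable names are purely illustrative and assume an encoder that outputs \( \mu \) and \( \log \sigma^{2} \) for each input:

import torch

def sample_latent(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # mu, logvar: (batch_size, latent_dim) tensors produced by the encoder,
    # parameterizing N(mu, sigma^2) with sigma^2 = exp(logvar).
    std = torch.exp(0.5 * logvar)   # sigma
    eps = torch.randn_like(std)     # eps ~ N(0, I)
    # z = mu + sigma * eps is a differentiable sample from N(mu, sigma^2).
    return mu + eps * std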

A key aspect of this framework is the regularization imposed on these latent encodings, forcing each of them to be close to a prior distribution (usually a standard normal), as measured by the Kullback-Leibler divergence. The tradeoff between regularization and reconstruction introduces an additional information bottleneck, which is believed to force the model to ignore less important factors of variation in the data.

The training objective is then:

\( L = \mathbb{E}_{x \sim X} \Big[ \mathbb{E}_{z \sim E(x)} \big[ \| x - D(z) \|^{2} \big] + D_{KL}\big( E(z|x) \,\|\, p(z) \big) \Big] \),

where \( E(\cdot), D(\cdot) \) denote the functions defined by the encoder and decoder networks respectively, \( z \sim E(x) \) is a shorthand for \( z \sim \mathcal{N}(\mu,\,\sigma^{2}) \) with \( \mu,\,\sigma^{2} = E(x) \), \( E(z|x) \) denotes this encoding distribution, \( D_{KL} \) denotes the Kullback-Leibler divergence, and \( p(z) \) denotes the probability density function of the prior distribution. The first term of the objective corresponds to the mean squared error of reconstructing the inputs, while the second corresponds to the regularization term.
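
The objective can be sketched in code as below, assuming a Gaussian encoder with diagonal covariance and a standard normal prior (so that the KL term has a closed form); this is an illustrative sketch rather than the exact training code used here:

import torch
import torch.nn.functional as F

def vae_loss(x, x_recon, mu, logvar):
    # Reconstruction term: squared error summed over the data dimensions.
    recon = F.mse_loss(x_recon, x, reduction="none").flatten(1).sum(dim=1)
    # Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) ) per sample.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    # Average over the batch, corresponding to the outer expectation over x.
    return (recon + kl).mean()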

Approaches to Disentanglement

Most approaches to obtaining disentanglement can be divided into two groups - those that modify the regularization term (e.g., by increasing its corresponding weight) and those that enforce statistical independence between the dimensions of the latent space representations, effectively imposing a factorising prior on the aggregated posterior of the encodings.

However, a large-scale study (Locatello et al. [3]) showed that none of these methods yields consistent performance (in terms of disentanglement), with high variance in the achieved scores and low correlation between the chosen method and the level of disentanglement obtained. Also, and perhaps more importantly, the authors introduced the so-called Impossibility Result. It states that in a purely unsupervised setting it is impossible to identify disentangled representations, even when successfully imposing the factorising prior.

Current Work

My current research therefore concentrates on alternative approaches that go beyond regularization of the learned latent spaces.

Effect of Sparsity on Disentanglement

Sparsity in machine learning models can refer either to a sparse output of the model (e.g., a feature vector with many zero entries) or to sparsity of the solution (many parameter weights equal or close to zero).

In my project I decided to investigate the latter - that is, sparsity of neural networks - and its relation to disentanglement. Standard neural network models have dense channel-to-channel connectivity. However, it is not necessary for a feature (represented by a hidden unit) to be constructed using all available features from the preceding layer - instead, it should suffice for it to be a combination of only a subset of them, without a loss in performance. This has been verified empirically by work on the Lottery Ticket Hypothesis, which shows that in many cases a model with 90% fewer parameters can still achieve similar performance (Frankle, Carbin [4]).

A motivation for a possible relation between model sparsity and disentanglement can be established by revisiting the proof of the aforementioned Impossibility Result. Even if a model is powerful enough to recover all underlying factors of variation of the data and encode them in the latent space, it suffices to rotate this space to obtain an entangled representation. Because the employed prior distributions are rotationally invariant, it is then impossible to identify a correct model. I argue that this entanglement due to rotations can happen already inside the model, in the hidden layers. This is due to the dense structure of modern neural networks, since there is no incentive for the model to use the connections sparsely, especially given a random initialization of weights. Explicitly enforcing sparsity might thus alleviate this potential issue, by favoring models which "rotate the features less". A preference for sparsity can also be found in the literature on causality. By the Independent Causal Mechanisms Principle, the causal mechanisms (which can be modeled as hidden layers of a deep network) “should usually not affect all factors simultaneously“, that is, they should propagate down the network in a sparse and local manner (Schölkopf [5]).
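
This rotation argument can be made concrete with a small NumPy sketch on synthetic factors (purely illustrative): rotating a perfectly disentangled 2-dimensional representation leaves its isotropic Gaussian distribution unchanged, yet makes every feature depend on both factors:

import numpy as np

rng = np.random.default_rng(0)

# Two independent ground-truth factors drawn from a standard normal prior.
g = rng.standard_normal((10_000, 2))

# A perfectly disentangled representation: the identity mapping r = g.
r = g.copy()

# Rotate the latent space by 45 degrees.
theta = np.pi / 4
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
r_rot = r @ R.T

# Both representations have (approximately) the same isotropic covariance,
# so they are indistinguishable under the rotationally invariant prior ...
print(np.cov(r.T).round(2))
print(np.cov(r_rot.T).round(2))
# ... yet each rotated feature now mixes both factors, i.e. it is entangled.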

A naïve approach to obtaining sparsity would be to simply employ the \( L_1 \) weight penalty. While this is plausible for the 2-dimensional weight matrices used by fully-connected layers, it does not necessarily have the desired effect for convolutional layers, whose weight tensors are at least 3-dimensional. Instead, we are interested in limiting the number of channel-to-channel connections - that is, zeroing out all weights in a convolutional filter that are connected to a particular channel of the previous layer. We implement this by introducing an additional masking tensor to each layer. Before computing a layer's output, its weight tensor is first multiplied element-wise with this mask tensor. Thus each entry in the mask effectively serves as a gate, controlling the flow of information between a pair of channels in consecutive layers. This is illustrated in the figure below:

Example of a binary variant of the proposed masking technique applied to a fully-connected layer. An entry \( M_{i, j} = 0 \) removes the connection between the input node \( h^1_i \) and the output node \( h^2_j \), as indicated by the dotted lines.

More formally, for a convolutional layer with a corresponding weight tensor \( K \in \mathbb{R}^{m \times n \times w \times h} \) we are interested in zeroing out all entries of the slices \( K_{i, j, :, :} \) of shape \( w \times h \) for certain pairs of indices \( i, j \), indexing channels in the layer's input and output respectively. To achieve this, we extend the layers on which we want to impose sparsity with an additional masking tensor \( M \in \mathbb{R}^{m \times n} \). During the forward pass a masked weight tensor \( K' = K \odot M \) is used to compute the output, where \( \odot \) denotes the Hadamard product (with \( M \) broadcast over the spatial dimensions of \( K \)). This method applies in exactly the same way to fully-connected layers, the only difference being that the original weight tensor is 2-dimensional.
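
A minimal PyTorch sketch of such a masked convolutional layer is given below; note that PyTorch stores convolutional kernels as (out_channels, in_channels, h, w), so the mask indices are transposed relative to \( M \in \mathbb{R}^{m \times n} \) above. The class name and the choice of a fixed (non-learned) mask buffer are illustrative assumptions, not necessarily the exact implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedConv2d(nn.Conv2d):
    # Conv2d whose channel-to-channel connections are gated by a mask:
    # mask[o, i] = 0 zeroes the whole (h x w) kernel slice connecting
    # input channel i to output channel o, removing that connection.

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        # Registered as a buffer: moves with the module, but is not trained.
        self.register_buffer("mask", torch.ones(self.weight.shape[:2]))

    def forward(self, x):
        # Broadcast the (out, in) mask over the spatial kernel dimensions
        # and convolve with the masked weights K' = K * M.
        masked_weight = self.weight * self.mask[:, :, None, None]
        return F.conv2d(x, masked_weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

Applying the same gating to a fully-connected layer reduces to an element-wise product of the 2-dimensional weight matrix with the mask, matching the figure above.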

References

1. Bengio Y, Courville A, Vincent P. Representation Learning: A Review and New Perspectives. IEEE Transactions on Pattern Analysis and Machine Intelligence. 2013 Mar 7;35(8):1798-828.

2. Kingma DP, Welling M. Auto-Encoding Variational Bayes. arXiv preprint arXiv:1312.6114. 2013 Dec 20.

3. Locatello F, Bauer S, Lucic M, Raetsch G, Gelly S, Schölkopf B, Bachem O. Challenging Common Assumptions in the Unsupervised Learning of Disentangled Representations. In International Conference on Machine Learning 2019 May 24 (pp. 4114-4124).

4. Frankle J, Carbin M. The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks. arXiv preprint arXiv:1803.03635. 2018 Mar 9.

5. Schölkopf B. Causality for Machine Learning. arXiv preprint arXiv:1911.10500. 2019 Nov 24.