Van Den Oord 2017 - VQ-VAE

Neural Discrete Representation Learning

This paper proposes a Variational Autoencoder (VAE) that uses discrete (categorical) latent variables instead of continuous ones. Discrete latents are more desirable from an interpretability perspective, but historically have not performed as well as continuous ones. This paper is apparently the first to bridge this performance gap.

We are in the domain of computer vision, where the VAE is trained on existing images and used to generate new ones.

Note that although VQ-VAE starts with the VAE framework, it is not strictly variational because of the deterministic operations within. It is more precise to call it a (deterministic) autoencoder with vector quantization regularization.

VAEs

A quick recap on regular VAEs:

  • We start with some input $x$, which is an image
  • The encoder encodes the input into latent space using $q(z \mid x)$
  • where $z$ in this paper is a discrete latent variable
  • $p(z)$ is the prior distribution (a uniform categorical in this paper)
  • $p(x \mid z)$ is the decoder that decodes a given latent back into input space (i.e. generates an image)

VQ-VAE

In VQ-VAE, the prior and posterior distributions are categorical. Drawing from the categorical distribution gives us an index $k$, which is used to index into an embedding table comprising $K$ $D$-dimensional embeddings. This extracted embedding is then used to represent the sample and fed into the decoder model.

More specifically, let us define a latent embedding space $e \in \mathbb{R}^{K \times D}$. That is, we have $K$ discrete latents and $D$ is the embedding dimension of each latent. Starting with an input $x$, we pass it into the encoder to produce $z_e(x)$, which is a grid of $D$-dimensional vectors. We then find the nearest embedding to $z_e(x)$ in the embedding table to get a categorical index $k$ and a corresponding embedding $e_k$ for each encoded vector.

Note that, using our running example of image generation, $x$ is encoded into a 2D grid of latents (say 32 x 32 x D). We find the nearest embedding at each position, such that we end up with a grid of 32 x 32 codebook indices. Since each position is discretized independently of the others, in the exposition we refer to $z_e(x)$ and so on as though it were one vector.
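
Below is a minimal PyTorch sketch of this nearest-neighbour lookup. The codebook size, embedding dimension and grid shape are assumptions for illustration, not the paper's exact configuration.

```python
import torch

# Assumed toy sizes: a codebook of K vectors of dimension D, and a 32 x 32
# grid of D-dimensional encoder outputs for a single image.
K, D = 512, 64
codebook = torch.randn(K, D)        # embedding table e, one row per discrete latent
z_e = torch.randn(32, 32, D)        # encoder output z_e(x)

# Squared l2 distance from every grid position to every codebook vector.
flat = z_e.reshape(-1, D)                      # (1024, D)
dists = torch.cdist(flat, codebook) ** 2       # (1024, K)

# Nearest-neighbour lookup: one discrete codebook index per grid position,
# plus the corresponding quantised embeddings z_q(x).
indices = dists.argmin(dim=1)                  # (1024,) codebook indices
z_q = codebook[indices].reshape(32, 32, D)     # quantised grid, same shape as z_e
```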

One way to think about this operation is that we are applying a particular non-linear operation that maps $z_e(x) \mapsto e_k$. Notably, this non-linear operation is non-differentiable, which we will need to tackle later on.

We can thus define the posterior probability distribution for $z$ as:

$$
q(z = k \mid x) =
\begin{cases}
1 & \text{for } k = \operatorname{argmin}_j \lVert z_e(x) - e_j \rVert_2 \\
0 & \text{otherwise}
\end{cases}
$$

Note that this posterior distribution is deterministic. If we define a simple uniform prior over $z$, we get that the KL divergence is constant: $D_{KL}\big(q(z \mid x) \,\Vert\, p(z)\big) = \log K$.

Recall that:

$$
D_{KL}\big(q(z \mid x) \,\Vert\, p(z)\big) = \sum_{k=1}^{K} q(z = k \mid x) \log \frac{q(z = k \mid x)}{p(z = k)}
$$

Since $q(z = k \mid x) = 1$ for the nearest embedding and $0$ otherwise, only one term in the summation is non-zero:

$$
D_{KL}\big(q(z \mid x) \,\Vert\, p(z)\big) = 1 \cdot \log \frac{1}{1/K} = \log K
$$

Hence the KL divergence term is constant and ignored during training.

Learning

At this point, we have fully specified the forward pass:

  • Start with input $x$
  • Encode $x$ into $z_e(x)$
  • Use an embedding table lookup to find the nearest neighbour $z_q(x) = e_k$
  • Decode $z_q(x)$ into a reconstruction $\hat{x}$ (a minimal sketch of this pipeline follows the list)
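
A minimal sketch of this forward pass as a PyTorch module is shown below. The toy convolutional encoder/decoder, layer sizes and shapes are assumptions, not the paper's architecture.

```python
import torch
import torch.nn as nn

class ToyVQVAE(nn.Module):
    """Minimal VQ-VAE forward-pass sketch (toy architecture, illustrative only)."""

    def __init__(self, K=512, D=64):
        super().__init__()
        self.encoder = nn.Sequential(                       # x -> z_e(x), 4x downsampling
            nn.Conv2d(3, D, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(D, D, 4, stride=2, padding=1),
        )
        self.codebook = nn.Embedding(K, D)                  # embedding table e
        self.decoder = nn.Sequential(                       # z_q(x) -> reconstruction
            nn.ConvTranspose2d(D, D, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(D, 3, 4, stride=2, padding=1),
        )

    def forward(self, x):
        z_e = self.encoder(x)                               # (B, D, H, W)
        B, D, H, W = z_e.shape
        flat = z_e.permute(0, 2, 3, 1).reshape(-1, D)       # (B*H*W, D)
        idx = torch.cdist(flat, self.codebook.weight).argmin(dim=1)
        z_q = self.codebook(idx).reshape(B, H, W, D).permute(0, 3, 1, 2)
        # NOTE: argmin is non-differentiable, so as written no gradient reaches
        # the encoder; the straight-through estimator discussed below addresses this.
        x_hat = self.decoder(z_q)
        return x_hat, z_e, z_q, idx.reshape(B, H, W)
```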

But it is not yet clear what the optimization objective and gradient flow should be. Recall that the standard VAE objective is to maximize the evidence lower bound (ELBO):

$$
\log p(x) \geq \mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big] - D_{KL}\big(q(z \mid x) \,\Vert\, p(z)\big)
$$

The first term is the reconstruction loss, where we draw $z$ from the distribution $q(z \mid x)$ that maps an input into latent space (using the encoder plus some random noise). Then we decode the $z$ back into input space and try to make the result similar to the input $x$. The second term is the KL divergence, which tries to regularize the distribution $q(z \mid x)$ to be as simple as possible.

However, we cannot simply use this equation for the VQ-VAE:

  • While the reconstruction loss is no issue (we can use the standard Gaussian formulation), the second KL divergence term is a constant, as we saw above. Hence doing this naively just reduces to a standard deterministic auto-encoder.
  • Another problem is that computing $z_q(x)$ requires the non-differentiable embedding table lookup, which does not allow the gradient to flow back into the encoder.

Hence the authors need to re-design the optimization objective.

The first decision is to use the straight-through estimator to circumvent the gradient issue. For the non-differentiable operation $z_e(x) \mapsto z_q(x)$, we compute the forward pass normally (i.e. the embedding table lookup) but simply copy the gradients through unchanged during backpropagation. This means that we approximate:

$$
\nabla_{z_e(x)} L \approx \nabla_{z_q(x)} L
$$

This allows the encoder to still receive gradient updates despite the non-differentiable operation. The theoretical justification for this estimator is given in an earlier Bengio 2013 paper. Intuitively, if $z_e(x)$ is close to $z_q(x)$, the gradients should still be meaningful.
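
A minimal PyTorch sketch of the straight-through trick, with random tensors standing in for the actual encoder outputs and codebook lookups:

```python
import torch

# Stand-ins for z_e(x) (encoder output) and z_q(x) (nearest codebook vectors).
z_e = torch.randn(4, 64, requires_grad=True)
z_q = torch.randn(4, 64)

# Straight-through estimator: the forward value equals z_q, but the backward
# pass copies gradients from z_q straight onto z_e.
z_q_st = z_e + (z_q - z_e).detach()

loss = z_q_st.pow(2).sum()          # any downstream (e.g. reconstruction) loss
loss.backward()
print(torch.allclose(z_q_st, z_q))  # True: the forward pass sees z_q
print(z_e.grad is not None)         # True: the encoder still receives gradients
```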

The second decision is to use the $\ell_2$ distance to learn the embedding table. This is a form of dictionary learning. Specifically, we add a term to the loss:

$$
\lVert \operatorname{sg}[z_e(x)] - e \rVert_2^2
$$

Note that:

  • $e$ here refers to the closest embedding to a given $z_e(x)$. We want embeddings in the codebook to move toward the average of the encoded representations assigned to them
  • $\operatorname{sg}[\cdot]$ is the stop-gradient operation (e.g. .detach() in PyTorch). It uses the value of $z_e(x)$ but does not pass gradients back to the encoder. Since the objective of this loss term is to learn the codebook, we do not wish to pass gradients back to the encoder

The third decision is to add a commitment loss to bound the encoder outputs. This part feels a bit more arbitrary. The authors say that with just the first two terms, there is nothing that tethers the encoder output, which can grow arbitrarily and perpetually stay far away from the codebook embeddings. The solution is to include the reverse direction of the dictionary learning loss:

$$
\beta \, \lVert z_e(x) - \operatorname{sg}[e] \rVert_2^2
$$

Notice that this is identical to the second term except that the stop-gradient operator is applied to the codebook embedding. Thus this gradient pushes the encoder output to be closer to its nearest codebook embedding. $\beta$ is a hyperparameter, but the authors observed that results were robust to a wide range of values (0.1 to 2.0).

A natural question is to wonder why we need both the second and third terms, which are identical except for where the stop gradient is placed. Why can't we just use a single term $\lVert z_e(x) - e \rVert_2^2$?

It appears that this results in unstable training, because both sides (the encoder outputs and the codebook embeddings) are moving simultaneously. This is a common issue when training models like GANs. Separating the terms results in more stable training.

There is also a close connection between VQ-VAE and k-means clustering (as a special case of expectation maximization). The step of assigning each encoder output to the nearest codebook embedding is akin to assigning each data point to a cluster in k-means. The step of updating the codebook embeddings is akin to updating the cluster centroids in k-means. This idea is explored in subsequent papers like Roy 2018.
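
To make the k-means analogy concrete, here is a small sketch of the assignment and centroid-update steps. This is the clustering view only, not the gradient-based codebook update described above.

```python
import torch

K, D = 8, 4
codebook = torch.randn(K, D)                       # "centroids"
z_e = torch.randn(100, D)                          # a batch of encoder outputs

# Assignment step: each encoder output picks its nearest codebook vector.
assign = torch.cdist(z_e, codebook).argmin(dim=1)

# Centroid-update step: move each codebook vector toward the mean of the
# encoder outputs assigned to it.
for k in range(K):
    members = z_e[assign == k]
    if len(members) > 0:
        codebook[k] = members.mean(dim=0)
```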

Hence, the final training objective becomes:

$$
L = \log p\big(x \mid z_q(x)\big) + \lVert \operatorname{sg}[z_e(x)] - e \rVert_2^2 + \beta \, \lVert z_e(x) - \operatorname{sg}[e] \rVert_2^2
$$
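
A minimal PyTorch sketch of these three terms, using mean-squared error for the (Gaussian) reconstruction term and .detach() as the stop gradient. The function name and the default $\beta$ are assumptions, and it assumes x_hat was decoded from the straight-through version of $z_q(x)$ so that the reconstruction term reaches the encoder.

```python
import torch.nn.functional as F

def vq_vae_loss(x, x_hat, z_e, z_q, beta=0.25):
    """Sketch of the three-term VQ-VAE objective (all terms minimised)."""
    # 1. Reconstruction loss (Gaussian log-likelihood up to a constant).
    recon = F.mse_loss(x_hat, x)
    # 2. Codebook / dictionary-learning loss: moves the embeddings toward the
    #    stop-gradient'ed encoder outputs assigned to them.
    codebook = F.mse_loss(z_q, z_e.detach())
    # 3. Commitment loss: keeps the encoder outputs close to their chosen embedding.
    commit = F.mse_loss(z_e, z_q.detach())
    return recon + codebook + beta * commit
```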

Evaluation

The log likelihood of the complete model can be evaluated as the total probability over all latents $z$. It is common practice to compute the log likelihood of a held-out test set to evaluate how well our model has learned the data distribution:

$$
\log p(x) = \log \sum_{k} p\big(x \mid z_k\big)\, p\big(z_k\big)
$$

Note that $\log p(x)$ is also used to report bits per dimension, which is a common way to evaluate such generative models on test data. This is literally the number of bits required to represent the data under the model.
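
For reference, the conversion from a negative log likelihood (in nats) to bits per dimension, assuming an image with $H \times W$ pixels and $C$ channels:

$$
\text{bits/dim} = \frac{-\log_2 p(x)}{H \cdot W \cdot C} = \frac{-\ln p(x)}{H \cdot W \cdot C \cdot \ln 2}
$$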

Because the decoder is trained with $z = z_q(x)$ from MAP-inference, the decoder should end up placing all probability mass on $z_q(x)$ after full convergence and no probability mass on any $z \neq z_q(x)$. So we can write:

$$
\log p(x) \approx \log p\big(x \mid z_q(x)\big)\, p\big(z_q(x)\big)
$$

Learning the Prior Distribution

Up to this point, we assumed that the prior is a uniform categorical for training the encoder, decoder and codebook embeddings. This may be viewed as a training trick or mathematical convenience to make our framework work. As you may recall, using the uniform prior resulted in a constant term for the KL-divergence, meaning that the term is ignored during training.

At inference time, when generating a new image, sampling latents from a uniform prior will result in incoherent images. Instead, we need to train a separate autoregressive prior model, such as a PixelCNN or a transformer, over the latent space to generate a coherent latent grid for decoding.

Specifically, we encode and quantize our training dataset into grids of 32 x 32 codebook indices. Then, an autoregressive model is trained to predict the codebook indices in a causal, autoregressive way. A PixelCNN works directly on the 2D grid of 32 x 32 indices, but for a transformer we need to linearize the grid in row-major order before training.

Now at inference time, we can start with a uniform latent index for the first position, then generate the subsequent positions using the learned prior. After we have generated a grid of 32 x 32 latent indices, we look up their embeddings to get a 32 x 32 x D grid. This is passed to the decoder to get the generated image.
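
A minimal sketch of this sampling loop is given below. Here prior_model, codebook and decoder are hypothetical stand-ins for the trained components, and prior_model(prefix) is assumed to return logits over the K codebook entries for the next position.

```python
import torch

@torch.no_grad()
def sample_image(prior_model, codebook, decoder, H=32, W=32, K=512, D=64):
    """Generate one image by sampling latent indices autoregressively, then decoding."""
    idx = torch.zeros(H * W, dtype=torch.long)
    idx[0] = torch.randint(0, K, (1,))            # uniform index for the first position
    for t in range(1, H * W):
        logits = prior_model(idx[:t])             # condition on previously generated indices
        idx[t] = torch.distributions.Categorical(logits=logits).sample()
    z_q = codebook[idx].reshape(H, W, D)          # look up embeddings: 32 x 32 x D
    z_q = z_q.permute(2, 0, 1).unsqueeze(0)       # (1, D, 32, 32) for a conv decoder
    return decoder(z_q)                           # the generated image
```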