Explanation of Contrastive Predictive Coding


Explanation of the paper Representation Learning with Contrastive Predictive Coding, arXiv 2019.

1. Introduction

This paper runs parallel to the previously discussed MINE method: it also provides an objective for optimizing the mutual information (MI) between random variables. Its objective, InfoNCE, was later discussed and claimed to be tighter than MINE.

Let’s begin with the paper, Representation Learning with Contrastive Predictive Coding, on arXiv.

2. Method Detail

We go from the definition of mutual information to the final objective function. For the notation and the model structure, please refer to Figure 1 of the original paper.

2.1. Mutual Information

In the MINE paper, the mutual information is written as a KL divergence, which is then transformed into the Donsker-Varadhan (DV) representation. In the InfoNCE method, the MI between the original signal \(x\) and the encoding \(c\) is instead first defined as:

\[\label{eq:mi} I(x ; c)=\sum_{x, c} p(x, c) \log \frac{p(x \mid c)}{p(x)}.\]
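To make the definition concrete, here is a minimal sketch that evaluates Eq. (\ref{eq:mi}) for a small discrete joint distribution. The function name `mutual_information` and the two toy tables are assumptions made for illustration; they do not come from the paper.

```python
import math

def mutual_information(joint):
    """I(x; c) in nats for a table joint[x][c] = p(x, c)."""
    p_x = [sum(row) for row in joint]          # marginal p(x)
    p_c = [sum(col) for col in zip(*joint)]    # marginal p(c)
    mi = 0.0
    for x, row in enumerate(joint):
        for c, p_xc in enumerate(row):
            if p_xc > 0.0:
                # p(x | c) / p(x) equals p(x, c) / (p(x) p(c)).
                mi += p_xc * math.log(p_xc / (p_x[x] * p_c[c]))
    return mi

independent = [[0.25, 0.25], [0.25, 0.25]]  # toy table: x and c independent
correlated = [[0.4, 0.1], [0.1, 0.4]]       # toy table: x tends to equal c

print(mutual_information(independent))  # 0.0
print(mutual_information(correlated))   # about 0.193 nats
```

The independent table gives zero MI because the density ratio is 1 everywhere, which is why the ratio \(p(x \mid c)/p(x)\) is the quantity of interest in the rest of the derivation.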

2.2. Contrastive Predictive Coding

The original task is sequence prediction. Given an observed input \(x_t\), an encoder maps it to a latent representation \(z_t=g_\text{enc}(x_t)\). An autoregressive model then summarizes all \(z_{\leq t}\) into a context latent representation \(c_{t}=g_{\text{ar}}\left(z_{\leq t}\right)\).

To predict the future signal \(x_{t+k}\), we do not use a generative model \(p_k(x_{t+k}\mid c_t)\). Instead, we want to maximize the MI between \(x_{t+k}\) and \(x_{\leq t}\) (summarized by \(c_t\)). In Eq. (\ref{eq:mi}), the ground-truth joint distribution \(p(x,c)\) is not something we can control, so we model the density ratio inside the logarithm:

\[\label{eq:ratio} f_{k}\left(x_{t+k}, c_{t}\right) \propto \frac{p\left(x_{t+k} \mid c_{t}\right)}{p\left(x_{t+k}\right)},\]

where \(f_k\) only needs to be proportional to the density ratio: it can be left unnormalized, i.e., it does not have to integrate to 1. In the paper, this density ratio is approximated by the following function:

\[\label{eq:nn} f_{k}\left(x_{t+k}, c_{t}\right)=\exp \left(z_{t+k}^{T} W_{k} c_{t}\right),\]

where a separate \(W_k\) is used for each step \(k\). This log-bilinear form acts like a dot product: it measures the similarity between the encoding of the actual future signal, \(z_{t+k}\), and the prediction made from the context, \(W_{k} c_{t}\).
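The score in Eq. (\ref{eq:nn}) can be sketched in a few lines. This is only an illustration with plain Python lists: the function name `score`, the 2-dimensional vectors, and the identity matrix used for \(W_k\) are assumptions, not values from the paper.

```python
import math

def score(z_future, W_k, c_t):
    """f_k = exp(z^T W_k c) for list inputs; W_k has shape (dim_z, dim_c)."""
    # Prediction of the future latent from the context: W_k c_t.
    prediction = [sum(w * c for w, c in zip(row, c_t)) for row in W_k]
    # Dot product between the actual future latent and the prediction.
    logit = sum(z * p for z, p in zip(z_future, prediction))
    return math.exp(logit)

W_k = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical weights for step k
c_t = [1.0, 0.0]                # hypothetical context vector

aligned = score([1.0, 0.0], W_k, c_t)     # z_{t+k} matches the prediction
orthogonal = score([0.0, 1.0], W_k, c_t)  # z_{t+k} is orthogonal to it
print(aligned, orthogonal)
```

A latent that agrees with the prediction gets a larger score than one that does not, which is the behavior the loss in the next section relies on.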

2.3. InfoNCE Loss

With a function that stands in for the density ratio inside the MI, we can plug it into a loss function. Here, a loss based on noise-contrastive estimation (NCE) is applied:

\[\label{eq:loss} \mathcal{L}_{\mathrm{N}}=-\underset{X}{\mathbb{E}}\left[\log \frac{f_{k}\left(x_{t+k}, c_{t}\right)}{\sum_{x_{j} \in X} f_{k}\left(x_{j}, c_{t}\right)}\right],\]

where \(X=\left\{x_{1}, \ldots, x_{N}\right\}\) is a set of \(N\) samples. One of them, \(x_{t+k}\), is drawn from \(p(x_{t+k} \mid c_t)\) and forms a positive pair with \(c_t\). The other \(N-1\) samples \(x_j\in X\) are drawn from the marginal \(p(x_{t+k})\) and, paired with the same \(c_t\), serve as negative samples. Minimizing this loss intuitively makes the numerator as large as possible and the denominator as small as possible. This is exactly what we want: a large score for the positive pair and small scores for the negative pairs.

Looking at Eq. (\ref{eq:loss}), it is actually a categorical cross-entropy loss for identifying the positive sample among the \(N\) candidates, with \(f_k / \sum_{X} f_k\) as the predicted probability (note that cross-entropy itself does not require a softmax, a point that is often misunderstood in the literature).
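As a minimal sketch of this cross-entropy view, the loss for a single context \(c_t\) can be computed from the log-scores \(\log f_k(x_j, c_t)\) of all \(N\) candidates. The function name `info_nce_loss` and the example scores are assumptions for illustration; a real implementation would average this over a batch.

```python
import math

def info_nce_loss(log_scores, positive_index):
    """-log( f_pos / sum_j f_j ), where log_scores[j] = log f_k(x_j, c_t)."""
    # Log-sum-exp with the maximum subtracted for numerical stability.
    m = max(log_scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in log_scores))
    return log_sum - log_scores[positive_index]

uninformative = info_nce_loss([0.0, 0.0, 0.0, 0.0], 0)  # all candidates tie
informative = info_nce_loss([2.0, 0.0, 0.0, 0.0], 0)    # positive scores highest

print(uninformative)  # log(4), about 1.386
print(informative)
```

When all scores tie, the loss equals \(\log N\); it drops toward zero as the positive sample's score dominates the negatives.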

So what is the optimal solution of Eq. (\ref{eq:loss})? Assume for now that \(x_i\) is the positive sample for \(c_t\), and write this event as \(d=i\). Then the probability that \(x_i\) is the positive sample, while every other \(x_l\) is a negative sample, is:

\[\begin{aligned} p\left(d=i \mid X, c_{t}\right) &=\frac{p\left(x_{i} \mid c_{t}\right) \prod_{l \neq i} p\left(x_{l}\right)}{\sum_{j=1}^{N} p\left(x_{j} \mid c_{t}\right) \prod_{l \neq j} p\left(x_{l}\right)} \\ &=\frac{\frac{p\left(x_{i} \mid c_{t}\right)}{p\left(x_{i}\right)}}{\sum_{j=1}^{N} \frac{p\left(x_{j} \mid c_{t}\right)}{p\left(x_{j}\right)}}, \end{aligned}\]

where \(p\left(x_{i} \mid c_{t}\right) \prod_{l \neq i} p\left(x_{l}\right)\) is the probability that \(x_i\) is drawn from \(p(x \mid c_t)\) while every other \(x_l\) is drawn independently from the prior (marginal) distribution \(p(x)\).

Therefore, comparing with Eq. (\ref{eq:loss}), the optimal \(f_k(x_{t+k},c_t)\) is indeed proportional to \(\frac{p\left(x_{t+k} \mid c_{t}\right)}{p\left(x_{t+k}\right)}\), as stated in Eq. (\ref{eq:ratio}).
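The second equality in the posterior above comes from dividing the numerator and the denominator by \(\prod_{l=1}^{N} p\left(x_{l}\right)\). Written out for a single term, this step is:

\[\frac{p\left(x_{i} \mid c_{t}\right) \prod_{l \neq i} p\left(x_{l}\right)}{\prod_{l=1}^{N} p\left(x_{l}\right)}=\frac{p\left(x_{i} \mid c_{t}\right)}{p\left(x_{i}\right)},\]

because the product over \(l \neq i\) cancels every factor except \(p\left(x_{i}\right)\). The same cancellation applies to each term of the sum in the denominator.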

2.4. Estimating MI with InfoNCE

Although we are already able to train the network, we still want to see how well InfoNCE approximates the MI.

Inserting the optimal \(f_k(x_{t+k},c_t)\) into the loss function in Eq. (\ref{eq:loss}) gives:

\[\begin{aligned} \mathcal{L}_{\mathrm{N}}^{\text {opt }} &=-\underset{X}{\mathbb{E}} \log \left[\frac{\frac{p\left(x_{t+k} \mid c_{t}\right)}{p\left(x_{t+k}\right)}}{\frac{p\left(x_{t+k} \mid c_{t}\right)}{p\left(x_{t+k}\right)}+\sum_{x_{j} \in X_{\text {neg }}} \frac{p\left(x_{j} \mid c_{t}\right)}{p\left(x_{j}\right)}}\right] \quad \text{(split into positive and negative)} \\ &=\underset{X}{\mathbb{E}} \log \left[1+\frac{p\left(x_{t+k}\right)}{p\left(x_{t+k} \mid c_{t}\right)} \sum_{x_{j} \in X_{\text {neg }}} \frac{p\left(x_{j} \mid c_{t}\right)}{p\left(x_{j}\right)}\right] \quad \text{(move the minus sign inside)} \\ & \approx \underset{X}{\mathbb{E}} \log \left[1+\frac{p\left(x_{t+k}\right)}{p\left(x_{t+k} \mid c_{t}\right)}(N-1) \underset{x_{j}}{\mathbb{E}} \frac{p\left(x_{j} \mid c_{t}\right)}{p\left(x_{j}\right)}\right] \quad \text{(replace the sum by an expectation)} \\ &=\underset{X}{\mathbb{E}} \log \left[1+\frac{p\left(x_{t+k}\right)}{p\left(x_{t+k} \mid c_{t}\right)}(N-1)\right] \quad \text{(negatives } x_j \sim p\left(x_{j}\right) \text{ are independent of } c_t \text{)} \\ & \geq \underset{X}{\mathbb{E}} \log \left[\frac{p\left(x_{t+k}\right)}{p\left(x_{t+k} \mid c_{t}\right)} N\right] \quad \text{(using } p\left(x_{t+k} \mid c_{t}\right)>p\left(x_{t+k}\right) \text{)} \\ &=-I\left(x_{t+k} ; c_{t}\right)+\log (N). \end{aligned}\]

Rearranging gives \(I\left(x_{t+k} ; c_{t}\right) \geq \log (N)-\mathcal{L}_{\mathrm{N}}^{\text {opt }}\), so minimizing the InfoNCE loss maximizes a lower bound on the MI, and the bound becomes tighter as \(N\) grows.
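The relation between the loss and the MI can be checked numerically on a toy problem. The sketch below assumes a small discrete joint distribution (the table `joint` is made up for illustration), uses the optimal density ratio as the scoring function, and estimates \(\log (N)-\mathcal{L}_{\mathrm{N}}\) by Monte Carlo sampling. The function names and the sampling setup are assumptions, not the paper's experimental procedure.

```python
import math
import random

def mutual_information(joint):
    """Exact I(x; c) in nats for a table joint[x][c] = p(x, c)."""
    p_x = [sum(row) for row in joint]
    p_c = [sum(col) for col in zip(*joint)]
    return sum(
        p * math.log(p / (p_x[x] * p_c[c]))
        for x, row in enumerate(joint)
        for c, p in enumerate(row)
        if p > 0.0
    )

def info_nce_bound(joint, n, trials, seed=0):
    """Monte Carlo estimate of log(N) - L_N with the optimal density ratio."""
    rng = random.Random(seed)
    p_x = [sum(row) for row in joint]
    p_c = [sum(col) for col in zip(*joint)]
    xs = list(range(len(p_x)))
    cs = list(range(len(p_c)))

    def ratio(x, c):
        # Optimal scoring function: p(x | c) / p(x) = p(x, c) / (p(x) p(c)).
        return joint[x][c] / (p_x[x] * p_c[c])

    total_loss = 0.0
    for _ in range(trials):
        c = rng.choices(cs, weights=p_c)[0]
        # Positive sample from p(x | c); weights need not be normalized.
        x_pos = rng.choices(xs, weights=[joint[x][c] for x in xs])[0]
        # N - 1 negative samples from the marginal p(x).
        negatives = rng.choices(xs, weights=p_x, k=n - 1)
        f_pos = ratio(x_pos, c)
        f_all = f_pos + sum(ratio(x, c) for x in negatives)
        total_loss -= math.log(f_pos / f_all)
    return math.log(n) - total_loss / trials

joint = [[0.4, 0.1], [0.1, 0.4]]  # toy joint: rows index x, columns index c
true_mi = mutual_information(joint)
for n in (2, 8, 64):
    print(n, info_nce_bound(joint, n, trials=10000), true_mi)
```

Up to sampling noise, each estimate should stay below the exact MI and can never exceed \(\log (N)\), since the loss is non-negative. Increasing \(N\) should move the estimate closer to the exact value.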

2.5. Relation with MINE

We now derive the relation between InfoNCE and MINE.

For simplicity and without loss of generality, let \(f(x,c)=e^{F(x,c)}\). Then:

\[\begin{aligned} &\mathbb{E}_{X}\left[\log \frac{f(x, c)}{\sum_{x_{j} \in X} f\left(x_{j}, c\right)}\right] \\ &=\underset{(x, c)}{\mathbb{E}}[F(x, c)]-\underset{(x, c)}{\mathbb{E}}\left[\log \sum_{x_{j} \in X} e^{F\left(x_{j}, c\right)}\right] \\ &=\underset{(x, c)}{\mathbb{E}}[F(x, c)]-\underset{(x, c)}{\mathbb{E}}\left[\log \left(e^{F(x, c)}+\sum_{x_{j} \in X_{\text {neg }}} e^{F\left(x_{j}, c\right)}\right)\right] \\ &\leq \underset{(x, c)}{\mathbb{E}}[F(x, c)]-\underset{c}{\mathbb{E}}\left[\log \sum_{x_{j} \in X_{\mathrm{neg}}} e^{F\left(x_{j}, c\right)}\right] \\ &=\underset{(x, c)}{\mathbb{E}}[F(x, c)]-\underset{c}{\mathbb{E}}\left[\log \frac{1}{N-1} \sum_{x_{j} \in X_{\text {neg }}} e^{F\left(x_{j}, c\right)}+\log (N-1)\right]. \end{aligned}\]

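The inequality holds because dropping the positive term \(e^{F(x, c)}\) from the sum inside the logarithm can only make the subtracted term smaller.

MINE, by definition, is: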

\[I(\widehat{X ; Z})_{n}=\sup _{\theta \in \Theta} \mathbb{E}_{\mathbb{P}_{X Z}^{(n)}}\left[T_{\theta}\right]-\log \left(\mathbb{E}_{\mathbb{P}_{X}^{(n)} \otimes \hat{\mathbb{P}}_{Z}^{(n)}}\left[e^{T_{\theta}}\right]\right).\]

Comparing the two expressions, we can see that they are almost the same, except that InfoNCE has an extra constant term \(\underset{c}{\mathbb{E}}[\log (N-1)]\). The authors report that for difficult tasks both losses perform almost equally well, while for easy tasks MINE is unstable.
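For comparison, the MINE objective above can be evaluated exactly on a small discrete distribution, without sampling. The sketch below is only an illustration: the function names, the toy table `joint`, and the hand-picked critics are assumptions, and the critic here is a fixed function rather than a trained network \(T_\theta\).

```python
import math

def dv_bound(joint, critic):
    """E_joint[T] - log E_marginals[exp(T)] for a table joint[x][c]."""
    p_x = [sum(row) for row in joint]
    p_c = [sum(col) for col in zip(*joint)]
    e_joint = 0.0
    e_marginals = 0.0
    for x, row in enumerate(joint):
        for c, p_xc in enumerate(row):
            t = critic(x, c)
            e_joint += p_xc * t                           # under p(x, c)
            e_marginals += p_x[x] * p_c[c] * math.exp(t)  # under p(x) p(c)
    return e_joint - math.log(e_marginals)

def log_ratio_critic(joint):
    """Critic T(x, c) = log p(x, c) / (p(x) p(c)), the log density ratio."""
    p_x = [sum(row) for row in joint]
    p_c = [sum(col) for col in zip(*joint)]
    return lambda x, c: math.log(joint[x][c] / (p_x[x] * p_c[c]))

joint = [[0.4, 0.1], [0.1, 0.4]]  # toy joint with strictly positive entries

best = dv_bound(joint, log_ratio_critic(joint))
crude = dv_bound(joint, lambda x, c: 1.0 if x == c else 0.0)
flat = dv_bound(joint, lambda x, c: 0.0)
print(best, crude, flat)
```

With the log density ratio as the critic, the bound equals the exact MI of the table (about 0.193 nats), a cruder critic gives a smaller value, and a constant critic gives zero. This is the same density ratio that the optimal \(f_k\) in InfoNCE is proportional to.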