Explanation of Mutual Information Neural Estimation


Explanation of the paper Mutual Information Neural Estimation, ICML 2018.

1. Introduction

Mutual information has recently been applied successfully in deep learning. The main difficulty in using mutual information is that it is hard to compute exactly, so recent methods focus on deriving a tractable bound that can be optimized.

Let’s begin with the first big-hit paper, Mutual Information Neural Estimation, in ICML 2018.

2. Method Detail

We start from the definition of mutual information and work toward a tractable bound that can be optimized.

2.1. Mutual Information

Mutual information \(I\) quantifies the statistical dependence of two random variables \(X\) and \(Z\) (can be thought of as input variables and the corresponding latent variables). \(I(X;Z)\) is generally defined as:

\[\label{eq:mi} I(X;Z)=H(X)-H(X\mid Z)=H(Z)-H(Z\mid X)=I(Z;X),\]

where \(H\) is Shannon entropy of a random variable. \(H(X)=-\mathbb{E}_{p(x)}\log p(x)\) and \(H(X\mid Z)=-\mathbb{E}_{p(x,z)}\log p(x\mid z)\).

Note that \(H(X\mid Z)\) is derived as:

\[\label{eq:con-ent} \begin{aligned} H(X\mid Z)&=\int_{z}p(z)H(X\mid Z=z)dz \\ &=-\int_{z}p(z)\int_{x}p(x\mid z)\log p(x\mid z)dxdz \\ &=-\int_{x,z}p(x,z)\log p(x\mid z)dxdz. \end{aligned}\]

There is a more useful way, in terms of computation, to describe \(I(X;Z)\):

\[\begin{aligned} \label{eq:kl-mi} I(X; Z) &=H(X)-H(X \mid Z) \\ &=-\int_{x} p(x) \log p(x) d x+\int_{x, z} p(x, z) \log p(x \mid z) d x d z \\ &=\int_{x, z}(-p(x, z) \log p(x)+p(x, z) \log p(x \mid z)) d x d z \\ &=\int_{x, z}\left(p(x, z) \log \frac{p(x, z)}{p(x) p(z)}\right) d x d z \\ &=D_{K L}(P(X, Z) \| P(X) \otimes P(Z)) \end{aligned}\]

So the mutual information can be thought of as the KL divergence between the joint distribution and the product of both marginal distributions. And the mutual information is obviously symmetric.

According to the KL divergence form of the mutual information, if \(X\) and \(Z\) are independent, \(p(x, z) = p(x) \times p(z)\) and \(I(X;Z)=0\). If \(X\) and \(Z\) become more dependent, the mutual information will intuitively increase. Therefore, mutual information can be considered as a good metric for determining the dependence between random variables.
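To make the two forms concrete, here is a small sketch that computes \(I(X;Z)\) for a discrete joint table in both the entropy form and the KL form. The two joint tables (`dependent`, `independent`) and the function names are made up for illustration; integrals become sums in the discrete case.

```python
import math

def entropy(p):
    """Shannon entropy (in nats) of a discrete distribution given as a list."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

def marginals(joint):
    """Return (p(x), p(z)) for a joint table indexed as joint[x][z]."""
    px = [sum(row) for row in joint]
    pz = [sum(col) for col in zip(*joint)]
    return px, pz

def mi_entropy_form(joint):
    """I(X;Z) = H(X) - H(X|Z)."""
    px, pz = marginals(joint)
    # H(X|Z) = -sum_{x,z} p(x,z) log p(x|z), with p(x|z) = p(x,z) / p(z)
    h_x_given_z = -sum(
        pxz * math.log(pxz / pz[j])
        for row in joint
        for j, pxz in enumerate(row)
        if pxz > 0
    )
    return entropy(px) - h_x_given_z

def mi_kl_form(joint):
    """I(X;Z) = KL(p(x,z) || p(x) p(z))."""
    px, pz = marginals(joint)
    return sum(
        pxz * math.log(pxz / (px[i] * pz[j]))
        for i, row in enumerate(joint)
        for j, pxz in enumerate(row)
        if pxz > 0
    )

# Illustrative joint tables (not from the paper).
dependent = [[0.4, 0.1], [0.1, 0.4]]
independent = [[0.25, 0.25], [0.25, 0.25]]

print(mi_entropy_form(dependent), mi_kl_form(dependent))
print(mi_kl_form(independent))
```

Both forms agree on the dependent table, and the independent table gives zero, as the KL form predicts.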

2.2. The Donsker-Varadhan Representation of KL

Although we now have a measure of dependence, this KL form is still intractable. With a typical encoder model, we only have access to \(P(Z\mid X)\), while \(P(X,Z)\), \(P(X)\) and \(P(Z)\) are all intractable. Therefore, we need a tractable representation to optimize. Here comes the Donsker-Varadhan (DV) representation:

\[\label{eq:dv-dk} D_{K L}(\mathbb{P} \| \mathbb{Q})=\sup _{T: \Omega \rightarrow \mathbb{R}} \mathbb{E}_{\mathbb{P}}[T]-\log (\mathbb{E}_{\mathbb{Q}}[e^{T}]),\]

where \(\mathbb{P}\) and \(\mathbb{Q}\) are two arbitrary distributions and the supremum is taken over all functions \(T\) mapping the sample space \(\Omega\) to the real numbers \(\mathbb{R}\) for which both expectations are finite.

Proof. To prove Eq. (\ref{eq:dv-dk}), we actually want to prove the following statement:

\[\label{eq:dv-le} D_{K L}(\mathbb{P} \| \mathbb{Q}) \geq \mathbb{E}_{\mathbb{P}}[T]-\log (\mathbb{E}_{\mathbb{Q}}[e^{T}]),\]

and that equality in Eq. (\ref{eq:dv-le}) is attained for some \(T\).

First, we define an auxiliary Gibbs distribution \(d \mathbb{G}=\frac{1}{Z} e^{T} d \mathbb{Q}\), where \(Z=\mathbb{E}_{\mathbb{Q}}\left[e^{T}\right]\) and \(d \mathbb{Q}\) can be safely thought of as the density function. In an informal way, the Gibbs distribution can be re-written as \(g(x)=\frac{1}{Z} e^{T(x)} q(x)\). And to ensure \(g(x)\) is a probability density function (integral result equals 1), we let \(Z=\int_{x}e^{T(x)}q(x)dx\), which is exactly the expectation of \(e^{T(x)}\) under distribution \(\mathbb{Q}\).

With \(\mathbb{G}\), the right-hand side of Eq. (\ref{eq:dv-le}) can be rewritten as:

\[\label{eq:rhs} \begin{aligned} \mathbb{E}_{\mathbb{P}}[T]-\log (\mathbb{E}_{\mathbb{Q}}[e^{T}])&=\mathbb{E}_{\mathbb{P}}\left[\log (e^{T})-\log (\mathbb{E}_{\mathbb{Q}}[e^{T}])\right] \\ &=\mathbb{E}_{\mathbb{P}}\left[\log \frac{e^{T}}{\mathbb{E}_{\mathbb{Q}}[e^{T}]}\right], \text{ (group log terms together)} \\ &=\mathbb{E}_{\mathbb{P}}\left[\log \frac{e^{T}d\mathbb{Q}}{\mathbb{E}_{\mathbb{Q}}[e^{T}]}\frac{1}{d\mathbb{Q}}\right], \text{ (multiply } d\mathbb{Q}\text{)} \\ &=\mathbb{E}_{\mathbb{P}}\left[\log \frac{d\mathbb{G}}{d\mathbb{Q}}\right]. \end{aligned}\]

Let \(\Delta\) be the gap of terms in Eq. (\ref{eq:dv-le}):

\[\Delta:=D_{K L}(\mathbb{P} \| \mathbb{Q})-\left(\mathbb{E}_{\mathbb{P}}[T]-\log \left(\mathbb{E}_{\mathbb{Q}}\left[e^{T}\right]\right)\right).\]

And \(\Delta\) can be represented as a KL term:

\[\Delta=\mathbb{E}_{\mathbb{P}}\left[\log \frac{d \mathbb{P}}{d \mathbb{Q}}-\log \frac{d \mathbb{G}}{d \mathbb{Q}}\right]=\mathbb{E}_{\mathbb{P}} \log \frac{d \mathbb{P}}{d \mathbb{G}}=D_{K L}(\mathbb{P} \| \mathbb{G}) \geq 0.\]

Therefore, Eq. (\ref{eq:dv-le}) holds in any situation. Further, for \(\mathbb{G}=\mathbb{P}\), this bound is tight because the KL in between is 0.

But what does \(\mathbb{G}=\mathbb{P}\) mean? By the definition \(d \mathbb{G}=\frac{1}{Z} e^{T} d \mathbb{Q}\), it means that \(d \mathbb{P}=\frac{1}{Z} e^{T} d \mathbb{Q}\) as well. In other words, for an optimal \(T^{*}=\log \frac{d \mathbb{P}}{d \mathbb{Q}}+\text{const}\), the gap equals 0. Note that this does not require \(\mathbb{Q}=\mathbb{P}\): we are not minimizing the KL. Instead, we are computing a lower bound on the KL and finding the \(T\) at which this bound is tightest.

So formally, restricting \(T\) to a function class \(\mathcal{F}\) gives the bound:

\[\label{eq:bound-dv} D_{K L}(\mathbb{P} \| \mathbb{Q}) \geq \sup _{T \in \mathcal{F}} \mathbb{E}_{\mathbb{P}}[T]-\log (\mathbb{E}_{\mathbb{Q}}[e^{T}]),\]

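where \(\mathcal{F}\) is any class of functions \(T: \Omega \rightarrow \mathbb{R}\) for which the two expectations are finite (for example, the functions representable by a given neural network architecture).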
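The proof can be checked numerically on a small discrete example. The sketch below uses two made-up distributions `p` and `q` over three outcomes and represents a critic \(T\) simply as a list of three values; none of these names come from the paper.

```python
import math

def kl(p, q):
    """KL(P || Q) for discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def dv_objective(t, p, q):
    """E_P[T] - log E_Q[e^T] for a critic given as a list of values t[x]."""
    expect_p = sum(pi * ti for pi, ti in zip(p, t))
    log_expect_q = math.log(sum(qi * math.exp(ti) for qi, ti in zip(q, t)))
    return expect_p - log_expect_q

# Illustrative distributions (assumptions, chosen only for the demo).
p = [0.7, 0.2, 0.1]
q = [0.2, 0.3, 0.5]

# Optimal critic: log density ratio plus an arbitrary constant (here 3.0).
t_star = [math.log(pi / qi) + 3.0 for pi, qi in zip(p, q)]
# Some other critic that is not of that form.
t_other = [1.0, 0.0, -1.0]

print(kl(p, q))
print(dv_objective(t_star, p, q))   # matches the KL up to rounding
print(dv_objective(t_other, p, q))  # strictly below the KL
```

The additive constant in `t_star` cancels between the two terms of the objective, which is why \(T^{*}\) is only determined up to a constant.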

2.3. The \(f\)-Divergence Representation of KL

There is a weaker bound on the KL, derived from its \(f\)-divergence representation:

\[D_{K L}(\mathbb{P} \| \mathbb{Q}) \geq \sup _{T \in \mathcal{F}} \mathbb{E}_{\mathbb{P}}[T]-\mathbb{E}_{\mathbb{Q}}[e^{T-1}].\]

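It is weaker because \(\mathbb{E}_{\mathbb{Q}}[e^{T-1}]\geq\log (\mathbb{E}_{\mathbb{Q}}[e^{T}])\). This follows from the inequality \(\frac{x}{e}\geq\log x\) for \(x>0\) (with equality only at \(x=e\)), applied to \(x=\mathbb{E}_{\mathbb{Q}}[e^{T}]\). It is easy to verify that the inequality holds.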
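A quick numerical comparison of the two bounds, again as a sketch on made-up discrete distributions (`p`, `q` and the critics are assumptions for illustration). For the critic \(T=\log\frac{d\mathbb{P}}{d\mathbb{Q}}\) the DV bound is tight while the \(f\)-divergence bound falls short by exactly \(1/e\); the \(f\)-divergence bound only becomes tight at \(T=\log\frac{d\mathbb{P}}{d\mathbb{Q}}+1\).

```python
import math

def dv_bound(t, p, q):
    """E_P[T] - log E_Q[e^T]."""
    expect_p = sum(pi * ti for pi, ti in zip(p, t))
    return expect_p - math.log(sum(qi * math.exp(ti) for qi, ti in zip(q, t)))

def f_bound(t, p, q):
    """E_P[T] - E_Q[e^(T-1)]."""
    expect_p = sum(pi * ti for pi, ti in zip(p, t))
    return expect_p - sum(qi * math.exp(ti - 1.0) for qi, ti in zip(q, t))

# Illustrative distributions (assumptions, chosen only for the demo).
p = [0.7, 0.2, 0.1]
q = [0.2, 0.3, 0.5]
kl_pq = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

log_ratio = [math.log(pi / qi) for pi, qi in zip(p, q)]
shifted = [t + 1.0 for t in log_ratio]

# At T = log(p/q): DV equals the KL, the f-divergence bound is lower by 1/e.
gap_at_log_ratio = dv_bound(log_ratio, p, q) - f_bound(log_ratio, p, q)

print(kl_pq, dv_bound(log_ratio, p, q), f_bound(log_ratio, p, q))
print(f_bound(shifted, p, q))  # tight only after shifting the critic by 1
```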

2.4. Mutual Information with DV

In Eq. (\ref{eq:kl-mi}), the MI is represented as a KL term \(D_{K L}(P(X, Z) \| P(X) \otimes P(Z))\). With Bound (\ref{eq:bound-dv}), we have:

\[\label{eq:bound-mi} I(X ; Z) \geq I_{\Theta}(X, Z)=\sup _{\theta \in \Theta} \mathbb{E}_{\mathbb{P}_{X Z}}[T_{\theta}]-\log (\mathbb{E}_{\mathbb{P}_{X} \otimes \mathbb{P}_{Z}}[e^{T_{\theta}}]),\]

where \(T_{\theta}: \mathcal{X} \times \mathcal{Z} \rightarrow \mathbb{R}\) is a neural network that takes a pair \((x, z)\) as input and outputs a single scalar (called the statistics network in the original paper). Note that the expectations in Bound (\ref{eq:bound-mi}) can be estimated from samples of \(X\) and \(Z\), so the intractable probabilities are no longer needed.

2.5. Mutual Information Neural Estimator

Yea, it is MINE.

Using a neural network \(T_\theta\) and \(n\) samples of \((X, Z)\), we obtain the following estimator of MI:

\[I \widehat{(X ; Z)}_{n}=\sup _{\theta \in \Theta} \mathbb{E}_{\mathbb{P}_{X Z}^{(n)}}[T_{\theta}]-\log (\mathbb{E}_{\mathbb{P}_{X}^{(n)} \otimes \hat{\mathbb{P}}_{Z}^{(n)}}[e^{T_{\theta}}]),\]

where all distributions are empirical distributions. Pairs \((x, z)\) from \(\mathbb{P}_{X Z}^{(n)}\) are empirically related, e.g., input and label, input and latent code, or noise and the generated result. For the marginal distributions, if the data generation procedure is known, as in experiment 4.1 of the paper, samples can be drawn directly. If the procedure is unknown, as in the GAN experiment, shuffling the \(z\) samples of a joint batch along the batch dimension breaks the pairing and yields samples from the product of the marginals.
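The shuffling trick can be sketched without any training. The example below is an illustration under strong assumptions: the data are correlated standard Gaussians with a chosen correlation `rho`, and instead of a learned network it plugs in the closed-form optimal critic \(T^{*}=\log\frac{p(x,z)}{p(x)p(z)}\), which is only available because the toy distribution is known. All function names are hypothetical.

```python
import math
import random

def sample_correlated_pairs(n, rho, rng):
    """Draw n pairs (x, z) of standard Gaussians with correlation rho."""
    pairs = []
    for _ in range(n):
        x = rng.gauss(0.0, 1.0)
        z = rho * x + math.sqrt(1.0 - rho ** 2) * rng.gauss(0.0, 1.0)
        pairs.append((x, z))
    return pairs

def optimal_critic(x, z, rho):
    """Closed-form log p(x,z) / (p(x) p(z)) for the bivariate Gaussian above."""
    quad = (rho ** 2 * x ** 2 - 2.0 * rho * x * z + rho ** 2 * z ** 2) / (
        2.0 * (1.0 - rho ** 2)
    )
    return -0.5 * math.log(1.0 - rho ** 2) - quad

def dv_estimate(pairs, critic, rng):
    """Empirical DV objective; marginal samples come from shuffling z."""
    xs = [x for x, _ in pairs]
    shuffled_zs = [z for _, z in pairs]
    rng.shuffle(shuffled_zs)  # breaks the (x, z) pairing
    joint_term = sum(critic(x, z) for x, z in pairs) / len(pairs)
    mean_exp = sum(math.exp(critic(x, z)) for x, z in zip(xs, shuffled_zs)) / len(pairs)
    return joint_term - math.log(mean_exp)

rng = random.Random(0)
rho = 0.5
pairs = sample_correlated_pairs(20000, rho, rng)
estimate = dv_estimate(pairs, lambda x, z: optimal_critic(x, z, rho), rng)
true_mi = -0.5 * math.log(1.0 - rho ** 2)  # known MI of a bivariate Gaussian

print(estimate, true_mi)
```

In MINE the critic is of course not known in advance; the network \(T_\theta\) has to be trained to approach it by maximizing the same objective.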

Detailed optimization steps of MINE can be found in Algorithm 1 in the paper. It is very straightforward.

The minibatch estimate of the gradient with respect to the parameters \(\theta\) of the network \(T\) is:

\[\label{eq:grad} \widehat{G}_{B}=\mathbb{E}_{B}\left[\nabla_{\theta} T_{\theta}\right]-\frac{\mathbb{E}_{B}\left[\nabla_{\theta} T_{\theta} e^{T_{\theta}}\right]}{\mathbb{E}_{B}\left[e^{T_{\theta}}\right]},\]

where \(B\) is a batch of data.
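As a sanity check on this formula, the sketch below uses a deliberately tiny critic with a single scalar parameter, \(T_\theta(x,z)=\theta x z\) (an assumption made purely for illustration, so that \(\nabla_\theta T_\theta = xz\)), and compares the analytic gradient with a finite-difference derivative of the batch objective. The sample lists are made up.

```python
import math

def dv_value(theta, joint_pairs, marginal_pairs):
    """Batch DV objective for the toy critic T(x, z) = theta * x * z."""
    joint_term = sum(theta * x * z for x, z in joint_pairs) / len(joint_pairs)
    mean_exp = sum(math.exp(theta * x * z) for x, z in marginal_pairs) / len(marginal_pairs)
    return joint_term - math.log(mean_exp)

def dv_gradient(theta, joint_pairs, marginal_pairs):
    """E_B[grad T] - E_B[grad T * e^T] / E_B[e^T], with grad T = x * z."""
    grad_joint = sum(x * z for x, z in joint_pairs) / len(joint_pairs)
    weights = [math.exp(theta * x * z) for x, z in marginal_pairs]
    weighted_grad = sum(w * x * z for w, (x, z) in zip(weights, marginal_pairs))
    # The 1/len factors of numerator and denominator cancel.
    return grad_joint - weighted_grad / sum(weights)

# Made-up batch: joint pairs, and the same x values with z shuffled.
joint_pairs = [(1.0, 0.9), (-0.5, -0.4), (0.3, 0.5), (-1.2, -1.0)]
marginal_pairs = [(1.0, -0.4), (-0.5, 0.5), (0.3, -1.0), (-1.2, 0.9)]

theta = 0.7
h = 1e-6
numeric = (
    dv_value(theta + h, joint_pairs, marginal_pairs)
    - dv_value(theta - h, joint_pairs, marginal_pairs)
) / (2.0 * h)
analytic = dv_gradient(theta, joint_pairs, marginal_pairs)

print(numeric, analytic)
```

The second term is a softmax-weighted average of \(\nabla_\theta T_\theta\) over the marginal samples, which is where the ratio of two batch averages comes from.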

2.6. Bias in Gradient

Are we done? Yes and no.

We have found a way to update the parameters, but it is biased (in the technical sense): the gradient in Eq. (\ref{eq:grad}) is a biased estimate of the full-batch gradient.

Consider the extreme case where a batch contains only one sample. The term \(e^{T_{\theta}}\) then cancels between the numerator and the denominator, and the second term in the gradient becomes \(\nabla_{\theta} T_{\theta}\). However, what we really want is the ratio of two expectations, each taken separately over the product of the marginals, and that is not the same thing:

\[\frac{\mathbb{E}_{\mathbb{P}_{X} \otimes \mathbb{P}_{Z}}\left[\nabla_{\theta} T_{\theta} e^{T_{\theta}}\right]}{\mathbb{E}_{\mathbb{P}_{X} \otimes \mathbb{P}_{Z}}\left[e^{T_{\theta}}\right]}\neq \frac{\nabla_{\theta} T_{\theta} e^{T_{\theta}}}{e^{T_{\theta}}}= \nabla_{\theta} T_{\theta}.\]

Therefore, naively plugging \(\widehat{G}_{B}\) into a batch-based optimization algorithm (e.g., SGD) follows a biased gradient of this objective.

A common question is why \(\widehat{G}_{B}\) is not 0 in that case. The reason is that the first term in Eq. (\ref{eq:grad}) uses samples from the joint distribution while the second term uses samples from the product of the marginals.
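The bias is easy to see on a toy example. In the sketch below, each marginal sample is summarized by a single number \(u = xz\), and the critic is the hypothetical one-parameter function \(T_\theta = \theta u\), so \(\nabla_\theta T_\theta = u\). It compares the second gradient term computed on the full batch with the average of the same term computed on batches of size one. The three values in `products` are made up.

```python
import math

def second_term(products, theta):
    """E_B[grad T * e^T] / E_B[e^T] for T = theta * u and grad T = u."""
    weights = [math.exp(theta * u) for u in products]
    return sum(w * u for w, u in zip(weights, products)) / sum(weights)

products = [-1.0, 0.0, 2.0]  # made-up values of u = x * z from the marginals
theta = 1.0

# Ratio of averages over the whole batch.
full_batch = second_term(products, theta)

# Average of the ratio over batches of size one: e^T cancels, leaving mean(u).
single_sample_average = sum(second_term([u], theta) for u in products) / len(products)

print(full_batch, single_sample_average)
```

The full-batch value is pulled toward samples with large \(e^{T_\theta}\), while the size-one batches lose that weighting entirely, so averaging many small-batch gradients does not recover the full-batch gradient.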

2.7. Fix the Bias

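Assume \(T_\theta\) does not change drastically between updates. Then we can use a moving average across batches to estimate the denominator \(\mathbb{E}_{B}\left[e^{T_{\theta}}\right]\). The original paper uses an exponential moving average and notes that, for small learning rates, the bias of the resulting gradient estimate can be made arbitrarily small.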
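A minimal sketch of this correction for a scalar parameter is shown below. The function name, the `decay` value and the choice to initialize the running mean with the first batch are assumptions for illustration, not details taken from the paper.

```python
import math

def ema_corrected_gradient(joint_grads, marginal_grads, marginal_ts,
                           running_mean, decay=0.99):
    """One gradient estimate with an EMA in the denominator.

    joint_grads:    grad T evaluated on joint samples
    marginal_grads: grad T evaluated on marginal samples
    marginal_ts:    T evaluated on the same marginal samples
    running_mean:   EMA of E_B[e^T] from earlier batches, or None on the first call
    Returns (gradient estimate, updated running mean).
    """
    n = len(marginal_ts)
    batch_mean_exp = sum(math.exp(t) for t in marginal_ts) / n
    if running_mean is None:
        running_mean = batch_mean_exp  # assumed initialization
    else:
        running_mean = decay * running_mean + (1.0 - decay) * batch_mean_exp
    grad_joint = sum(joint_grads) / len(joint_grads)
    weighted = sum(g * math.exp(t) for g, t in zip(marginal_grads, marginal_ts)) / n
    # Denominator is the moving average rather than this batch's own mean.
    return grad_joint - weighted / running_mean, running_mean

# First batch: the EMA starts at the batch mean, so this equals the plain estimate.
grad1, mean1 = ema_corrected_gradient([0.5, 1.5], [1.0, -1.0], [0.0, 0.0], None)
# Second batch: the denominator now blends old and new estimates of E[e^T].
grad2, mean2 = ema_corrected_gradient(
    [1.0], [1.0, 1.0], [math.log(3.0), math.log(3.0)], mean1, decay=0.5
)
print(grad1, mean1, grad2, mean2)
```

Because the denominator no longer depends only on the current batch, a single-sample batch no longer collapses the second term to \(\nabla_{\theta} T_{\theta}\).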

3. Theoretical Analysis

Because MINE is an estimator, we need to see how good it is.

3.1. Unbiased Estimation

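First, the objective itself introduces no systematic gap: MINE is derived from the DV representation of KL, which is stated in terms of exact expectations, and we proved that the bound is tight at \(T^{*}\). Note that this is a statement about the population objective. With finite samples, the logarithm of a sample mean is a biased estimate of \(\log \mathbb{E}[e^{T}]\) (by Jensen's inequality), so the guarantee that matters in practice is consistency.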

3.2. Consistency

Beyond that, we want the estimator to converge to the true value as the number of samples grows.

Strong consistency means: for all \(\epsilon > 0\), there is a positive integer \(N\) and a choice of the statistics network (\(T_\theta\)) such that:

\[\forall n \geq N, \quad\left|I(X, Z)-I \widehat{(X ; Z)}_{n}\right| \leq \epsilon, \text { a.e. }.\]

The term \(I \widehat{(X ; Z)}_{n}\) raises two questions that need to be verified: (1) can \(T_\theta\) approximate the optimal function \(T^{*}\)? (2) does the sample-based estimate converge to the population value?

Lemma 1 (approximation). Let \(\epsilon > 0\). There exists a neural network parameterizing functions \(T_\theta\) with parameters \(\theta\) in some compact domain \(\Theta \subset \mathbb{R}^k\), such that

\[\left|I(X, Z)-I_{\Theta}(X, Z)\right| \leq \epsilon, \text { a.e. }.\]

Proof. The proof is not difficult but a little bit long. Please refer to the appendix in the original paper.

Lemma 2 (estimation). Let \(\epsilon > 0\). Given a family of neural network functions \(T_\theta\) with parameters \(\theta\) in some bounded domain \(\Theta \subset \mathbb{R}^k\), there exists an \(N\in\mathbb{N}\), such that

\[\forall n \geq N, \quad\left|I \widehat{(X ; Z)}_{n}-I_{\Theta}(X, Z)\right| \leq \epsilon, \text { a.e.}.\]

Theorem 1. MINE is strongly consistent.

Proof. With Lemma 1 and Lemma 2,

\[\begin{aligned} \left|I(X, Z)-I \widehat{(X ; Z)}_{n}\right| &=\left|I(X, Z)-I_{\Theta}(X, Z)+I_{\Theta}(X, Z)-I \widehat{(X ; Z)}_{n}\right|\\ &\leq \left|I(X, Z)-I_{\Theta}(X, Z)\right|+\left|I \widehat{(X ; Z)}_{n}-I_{\Theta}(X, Z)\right|\\ &\leq 2\epsilon. \end{aligned}\]

3.3. Sample Complexity

This is the sample complexity for Lemma 2. We need to see how many samples are required for \(I \widehat{(X ; Z)}_{n}\) to be within \(\epsilon\) of \(I_{\Theta}(X, Z)\).

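Theorem 2. Assume that the functions \(T_\theta\) are \(L\)-Lipschitz with respect to the parameters \(\theta\), that both \(T_\theta\) and \(e^{T_\theta}\) are bounded by \(M\), and that the domain \(\Theta \subset \mathbb{R}^d\) is bounded so that \(\|\theta\| \leq K\) for some constant \(K\). Then, given any values \(\epsilon,\delta\) of the desired accuracy and confidence parameters, we have,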

\[\operatorname{Pr}\left(\left|I \widehat{(X ; Z)}_{n}-I_{\Theta}(X, Z)\right| \leq \epsilon\right) \geq 1-\delta,\]

whenever the number \(n\) of samples satisfies

\[n \geq \frac{2 M^{2}(d \log (16 K L \sqrt{d} / \epsilon)+2 d M+\log (2 / \delta))}{\epsilon^{2}}.\]
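To get a feel for the scale of this bound, the sketch below simply evaluates the right-hand side. The function name is made up, and the argument meanings follow the usual reading of the symbols (bound \(M\), parameter dimension \(d\), parameter-norm bound \(K\), Lipschitz constant \(L\)); the numbers plugged in are arbitrary and only illustrate the arithmetic.

```python
import math

def mine_sample_bound(eps, delta, M, d, K, L):
    """Smallest integer n satisfying the sample complexity bound above."""
    inner = (
        d * math.log(16.0 * K * L * math.sqrt(d) / eps)
        + 2.0 * d * M
        + math.log(2.0 / delta)
    )
    return math.ceil(2.0 * M ** 2 * inner / eps ** 2)

# Arbitrary illustrative values: one parameter, all constants equal to 1.
n_loose = mine_sample_bound(eps=0.1, delta=0.05, M=1.0, d=1, K=1.0, L=1.0)
# Tightening the accuracy tenfold costs a bit more than a factor of 100.
n_tight = mine_sample_bound(eps=0.01, delta=0.05, M=1.0, d=1, K=1.0, L=1.0)

print(n_loose, n_tight)
```

The bound grows like \(1/\epsilon^{2}\) up to a logarithmic factor and roughly linearly in the number of parameters \(d\), so it is far from tight for realistic networks and is best read as a qualitative guarantee.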