参考博客:CSDN | 【理论推导】互信息与 InfoNCE 损失:从公式推导理解对比学习的本质 ,感觉是讲的最清楚的一个博客。
- 1 InfoNCE loss 和互信息的数学形式
- 1.1 互信息的数学形式
- 1.2 InfoNCE loss 的数学形式
- 1.3 为什么我们希望最大化\((X,Y)\)的互信息
- 2 InfoNCE loss 与互信息的数学关联
- 3 证明过程
- 3.1 第一步:证明使 InfoNCE loss 取值最小的\(f(x,y)\),满足\(f(x, y) = \log [p(y|x) / p(y)]\)
- 3.2 第二步:将以上\(f(x,y)\)代入,推导互信息下界
1 InfoNCE loss 和互信息的数学形式
1.1 互信息的数学形式
互信息\(I(X,Y)\)是信息论中的核心概念,用于衡量两个随机变量\(X,Y\)之间的依赖程度。
从直观上理解,互信息回答了这样一个问题:知道一个变量 Y 后,我们对另一个变量 X 的不确定性减少了多少?如果 X 的不确定性减少较多,则代表 XY 之间的互信息较大(为正);如果 X 的不确定性没有减少,则 XY 是相互独立的,即\(P(X)P(Y) = P(X,Y)\),XY 之间的互信息为 0。
数学上,互信息有三种等价的定义方式:
① 基于联合分布\(p(x,y)\)和边缘分布\(p(x)p(y)\)的 KL 散度的形式
这个形式直接体现了互信息的本质:它衡量的是联合分布\(p(x,y)\)与假设 X 和 Y 独立时的分布\(p(x)p(y)\)之间的差异。如果 X 和 Y 独立,这个差异为 0,否则为正数,差异越大说明两个变量关联越强。
② 基于熵的形式
这里\(H(X)=\int p(x)\log p(x)\)是 X 的熵(不确定性),\(H(X|Y)\)是已知 Y 时 X 的条件熵,互信息则是不确定性的减少量。
③ 基于条件概率的形式
这个形式在对比学习中特别有用,因为它直接表达了“在给定 x 的情况下,y 的概率相对于其先验概率的变化”。
1.2 InfoNCE loss 的数学形式
InfoNCE loss 是现代对比学习(Contrastive Learning)的核心。它的设计灵感来自一个简单的直觉:从一堆样本中,找出与给定样本 x 匹配的正样本 y。
具体的,假设我们有一个正样本对\((x, y)\),比如同一张图片的两种不同数据增强结果,同时从数据集中随机采样\(N-1\)个负样本\(y_2, y_3, ..., y_N\)。我们定义一个评分函数\(f(x, y)\)(通常是神经网络)来衡量 x 和 y 的相似度。InfoNCE loss 的形式为:
其中,分子\(e^{f(x,y)}\)是正样本的得分,而分母\(\sum_{j=1}^{N} e^{f(x,y_j)}\)是所有样本(1 个正样本 + N-1 个负样本)得分的总和。整个分式表示:给定 x 和 N 个候选 y,我们正确选出正样本 y 的概率。
也可将其视为交叉熵损失(cross-entropy loss)的一个变种。交叉熵损失的形式如下:
其中,\(p(a)\)为真是概率,而\(\hat p(a)\)是我们估计的概率。在 InfoNCE loss 的 setting 中,真概率\(p(x,y) = 1\),而\(p(x,y_j) = 0\)。
1.3 为什么我们希望最大化\((X,Y)\)的互信息
在对比学习中,我们希望最大化正样本对的互信息,同时最小化正负样本之间的互信息。这迫使编码器提取出两个不同视图(view)的共享信息(比如同一张图片的不同数据增强版本、语言 / 视觉等不同的模态),这些信息通常对应于数据的内在语义,例如物体的类别、场景等,而忽略无关的噪声或增强引入的变化。
在 skill discovery(强化学习的一个子领域)中,我们希望最大化 skill z 和 state s 之间的互信息。从信息理论的角度,最大化\(I(S;Z)\)意味着,我们希望从状态\(s\)中尽可能多地获取关于技能\(z\)的信息。这确保了技能是“有区分度的”:看到智能体的行为,我们就能推断出它使用了哪个技能。
2 InfoNCE loss 与互信息的数学关联
核心结论:最小化 InfoNCE loss,等价于最大化互信息的一个下界。
(互信息下界的含义是,互信息的取值将会大于这个值。从这个角度来说,下界的值越大,互信息的值就随之变大,所以,我们最小化 InfoNCE loss,相当于在推动互信息最大化。)
具体来说,对于任意评分函数\(f\),有以下不等式成立:
其中,\(L\)是我们模型的 InfoNCE loss,这个差值就是互信息的下界。
3 证明过程
证明过程可以分为两步:
3.1 第一步:证明使 InfoNCE loss 取值最小的\(f(x,y)\),满足\(f(x, y) = \log [p(y|x) / p(y)]\)
我们要证明:使 InfoNCE loss 最小的\(f(x,y)\)满足:
我们考虑 InfoNCE loss 的期望形式:
我们可以将这个损失看作一个分类问题:给定 x 和 N 个样本\({y_1, y_2, \cdots, y_N}\),其中只有\(y_1=y\)是正样本,其余是负样本。模型的任务是选出正样本。
对于固定的 x,最优的分类器应该给出真实的后验概率,即给定 x 后,y 为这个 x 的正样本的概率。那么,真实的后验概率是多少呢?
根据贝叶斯定理,在给定 x 和 y 样本集合的情况下,第 k 个样本是正样本的概率为(这个没完全看懂):
化简后得到:
关键观察:如果我们取\(f(x,y) = \log\frac{p(y|x)}{p(y)} + c(x)\),其中\(c(x)\)是只依赖于 x 的任意函数,那么:
这正是真实的后验分布。因此,这个\(f(x,y)\)取值使得模型的输出分布与真实分布完全一致,从而最小化 InfoNCE loss。
为简便起见,我们通常取\(c(x)=0\),得到最优\(f(x,y)\):
3.2 第二步:将以上\(f(x,y)\)代入,推导互信息下界
现在我们将最优\(f^*(x,y)\)代入 InfoNCE loss:
考虑互信息\(I(X;Y)\)的以下形式:
现在,我们想建立\(I(X;Y)\)和\(L_{\text{min}}\)的关系。通过巧妙的代数变换,把互信息拆开:
第一项就是\(-L_{\text{min}}\)吗?不完全是。实际上,\(-L_{\text{min}}\)= 第一项 - log N:
所以:
现在看最后一项:\(\mathbb{E}\left[\log\left(\frac{1}{N}\sum_{j=1}^{N} \frac{p(y_j|x)}{p(y_j)}\right)\right]\)
由于对数函数是凹函数,根据琴生不等式(Jensen's Inequality):
因此:
我们计算这个 log 里面的期望:
期望 = 1。所以:
代入上式 = 0,使用琴生不等式,因此:
将上式 ≤ 0 代回原式:
由于\(f^*\)是最优的,所以对于任意评分函数\(f\),有\(L(f) \geq L_{\text{min}}\),得到:
证毕:对于任意评分函数\(f\),\(\log N - L(f)\)是互信息\(I(X;Y)\)的一个下界。