2.1.21. 损失函数-分类-KL 散度(KL Loss)

定义:KL 散度

  • 是什么:衡量两个概率分布 P 和 Q 之间差异的非对称性度量

  • 别名:相对熵

  • 核心思想:用分布 Q 来近似真实分布 P 时所造成的信息损失或额外成本

  • 关键性质:非负性(≥0,相等时为0)和不对称性(P||Q ≠ Q||P)

  • 主要用途:机器学习中的损失函数、信息论、变分推断、模型评估等

一、核心概念:衡量两个概率分布的“差异”

KL 散度,全称Kullback-Leibler Divergence,在中文里也常被称为相对熵。它的核心目的是衡量同一个随机变量有两个单独的概率分布 P(x) 和 Q(x) 之间的差异。

你可以把它想象成一个“距离”度量,但它不是真正的距离(因为它不对称,不满足三角不等式)。更准确地说,它衡量的是当你使用一个近似分布 Q 来编码来自真实分布 P 的样本时,所额外产生的“信息损失”或“额外开销”。


二、一个生动的例子:新闻编码

假设你是一名记者,每天要从一个事件库中挑选新闻进行报道。事件库中不同事件发生的真实概率是分布 P(比如,政治事件 50%,体育事件 30%,娱乐事件 20%)。

为了高效地发回电报,你需要设计一套编码方案(比如,用短的代码表示高频事件,长的代码表示低频事件)。最优的编码方案一定是根据真实分布 P 来设计的。

  1. 理想情况(完美编码)

    • 你使用根据真实分布 P 设计的最优编码。

    • 平均下来,每条新闻的电报长度最短。这个最短的平均长度就是 P 的熵

  2. 实际情况(使用错误模型)

    • 但你错误地估计了形势,认为体育新闻最流行。你根据一个错误的分布 Q 来设计编码(比如,体育 60%,政治 30%,娱乐 10%)。

    • 你现在用为 Q 设计的编码,去传输实际上符合分布 P 的新闻。

会发生什么?

  • 高频的政治事件(在 P 中占 50%)被你分配了一个较长的代码(因为在你的错误模型 Q 中,它只占 30%)。

  • 低频的体育事件(在 P 中占 30%)却被你分配了一个很短的代码(因为在 Q 中它占 60%)。

结果就是: 你的平均电报长度变长了!这个额外增加的平均长度就是 KL 散度

公式表示: D_KL(P || Q) = (使用Q编码传输P样本的平均长度) - (使用P编码传输P样本的平均长度) = (交叉熵) - (熵)

所以,KL 散度衡量的是因使用错误分布 Q 而不是真实分布 P 所带来的额外成本

三、重要性质

  1. 非负性:KL 散度永远 ≥ 0。

    • D_KL(P || Q) = 0 当且仅当 P 和 Q 是完全相同的分布

    • 只要 P 和 Q 有差异,KL 散度就会大于零。

  2. 不对称性:这是它不能被称为“距离”的根本原因。

    • D_KL(P || Q) D_KL(Q || P)

    • 意义不同

      • D_KL(P || Q):基于“真实分布是 P,我们用 Q 去近似它”的假设。

      • D_KL(Q || P):基于“真实分布是 Q,我们用 P 去近似它”的假设。

    • 这两种假设下的信息损失是不同的。

四、公式

基础公式: DKL(PQ)=iP(i)log(P(i)Q(i))

对于离散随机变量: DKL(PQ)=iP(i)log(P(i)Q(i))

对于连续随机变量: DKL(PQ)=p(x)log(p(x)q(x))dx

其中:

  • P 是真实分布(或参考分布)。

  • Q 是近似分布(或比较分布)。

从公式可以看出:

  • log(P(i) / Q(i)) 可以理解为在点 i 处,P 和 Q 的差异程度。

  • 这个差异程度用概率 P(i) 进行了加权平均。因此,KL 散度更关注在 P 出现概率高的地方,Q 与 P 的拟合程度。如果某个区域 P(x) 的概率很大,但 Q(x) 的概率很小(即比值 P(x)/Q(x) 很大),那么它对 KL 散度的贡献就会非常大,会导致一个很大的惩罚(Penalty)。

五、主要应用场景

  1. 机器学习与深度学习

    • 变分自编码器:衡量学到的潜在变量分布与标准正态分布之间的差异。

    • GANs:虽然早期 GAN 用 JS 散度,但其思想与 KL 散度密切相关,都是衡量生成分布与真实分布的差异。

    • 更多时候是作为损失函数的一部分,用来约束模型的输出分布尽可能接近真实的标签分布。

  2. 信息论

    • 就像上面的例子,用于衡量编码方案相对于最优方案的效率损失。

  3. 贝叶斯推理

    • 当我们用一个简单的分布 Q(称为变分分布)去近似一个复杂的后验分布 P 时,可以用 KL 散度来衡量这个近似的好坏。

  4. 交叉熵损失

    • 在分类任务中,我们常用的交叉熵损失 其实就是 KL 散度 + 真实分布 P 的熵

    • 由于训练时真实标签 P 是 one-hot 编码(熵为0),所以最小化交叉熵损失就等价于最小化 KL 散度

六、变体或替代方法

  • 反向KL散度: 使用 DKL(Q||P) ,适用于某些需要强调 Q(x) 为 0 的场景。

  • Jensen-Shannon Divergence (JS散度): 是 KL散度的对称变体,定义为 DJS(P||Q)=12DKL(P||M)+12DKL(Q||M)

    • 其中 M 为 P 和 Q 的均值分布:M=12(P+Q)

  • Wasserstein距离: 用于生成对抗网络(GAN)中的分布对比,解决了 KL散度在某些场景下的数值问题。

KL 散度作为损失函数

工作原理

核心思想:

  • 最小化 D_KL(P || Q),即惩罚模型分布 Q 与真实分布 P 之间的差异。

在训练过程中:

  1. P目标分布(真实数据的分布或我们想要的分布)。

  2. Q模型预测的分布(由神经网络或其他模型参数化)。

  3. 我们通过梯度下降等优化算法,调整模型参数,使得 D_KL(P || Q) 的值不断减小。

  4. D_KL(P || Q) 接近 0 时,我们认为模型 Q 已经很好地拟合了真实分布 P。

具体应用场景

1. 分类任务(最经典的例子)

在分类任务中,我们通常使用交叉熵损失。而交叉熵损失本质上就是 KL 散度

  • 真实分布 P:通常是 one-hot 编码 的标签。例如,对于一张猫的图片,真实标签是 [1, 0, 0](假设类别为 [猫,狗,鸟])。

  • 模型分布 Q:是模型通过 Softmax 层输出的概率分布,例如 [0.7, 0.2, 0.1]

让我们看看它们的关系: 交叉熵 (H(P, Q)) = KL 散度 (D_KL(P||Q)) + 真实分布的熵 (H(P))

用公式表示: H(P,Q)=P(x)logQ(x)=DKL(PQ)+H(P)

在分类任务中,P 是 one-hot 分布,其熵 H(P) = 0(因为只有一个事件概率为1,其他为0,计算熵为0)。 所以: H(P, Q) = D_KL(P || Q) + 0

结论:在分类任务中,最小化交叉熵损失 H(P, Q) 完全等价于最小化 KL 散度 D_KL(P || Q) 这就是 KL 散度作为损失函数最广泛的应用。

2. 生成模型(如VAE, GANs)

变分自编码器(VAE) 中,KL 散度扮演着核心角色,它作为损失函数的一部分(正则项)。

  • 目标:VAE 不仅想精确地重建输入数据(重构损失),还希望其编码器学到的潜在变量(Latent Variables)的分布 Q(z|X) 接近一个简单的先验分布(通常是标准正态分布 P(z) = N(0, I))。

  • 作用D_KL(Q(z|X) || N(0, I)) 这项损失迫使潜在空间变得规整、连续和有结构,使得我们在潜在空间中插值时,解码器能生成有意义的输出。

3. 贝叶斯推理与变分推断

在贝叶斯模型中,后验分布 P(Z|X) 往往非常复杂,难以直接计算。我们会用一个简单的分布 Q(Z)(称为变分分布)去近似它。

  • 目标:找到一组参数,使得 Q(Z) 尽可能接近 P(Z|X)

  • 方法:通过最小化它们之间的 KL 散度 D_KL(Q(Z) || P(Z|X)) 来实现。最小化这个 KL 散度等价于最大化证据下界,是变分推断的核心。

4. 强化学习

强化学习(如PPO算法)中,KL 散度被用作约束,以确保策略模型的更新幅度不会太大太剧烈。

  • 目标:在更新策略网络参数时,要求新策略 π_new 和旧策略 π_old 之间的 KL 散度不能超过一个阈值。

  • 作用:这样可以防止策略崩溃,保证训练过程的稳定性。