损失函数-分类-cross-entropy(交叉熵)¶
机器学习和信息理论中常用的一种损失函数或评价指标。它主要用于分类问题,尤其是多分类任务,广泛应用于深度学习模型(如神经网络)的训练中。
备注
从本质上讲,交叉熵损失是机器学习和深度学习中的一种流行测量方法,它测量两个概率分布之间的差异——通常是标签的真实分布(此处为数据集中的标记)和模型的预测分布(例如,由 LLM 生成的标记概率)之间的差异。
定义¶
备注
cross_entropy 本质上是 Softmax 和负对数似然(NLLLoss)的组合。PyTorch 的交叉熵计算相当于直接求负平均对数概率,这是因为分类任务的目标分布 𝑝(𝑥) 通常是 one-hot 编码(每次只有一个类别为 1,其余为 0)。
单样本交叉熵的公式¶
- 说明:
真实分布 𝑃(目标分布)
预测分布 𝑄(模型输出)
P(𝑖) 表示类别 𝑖 的真实分布(通常为 1 或 0,在分类问题中由 one-hot 编码表示)
𝑄(𝑖) 表示模型预测的类别 𝑖 的概率(通常是通过 softmax 输出的概率值)
二分类单样本交叉熵损失的公式¶
- 其中
y是真实标签(0或1)
\(\hat{y}\) 是模型预测的概率(介于0和1之间)
批量交叉熵损失函数的公式¶
- 说明
𝑁:样本数
𝐶:类别数
\(𝑦_{𝑗𝑖}\) :第 𝑗 个样本的真实类别的 one-hot 编码值(1 表示正确类别,0 表示其他类别)
\(\hat{y}_{𝑗𝑖}\) :模型对第 𝑗 个样本预测类别 𝑖 的概率(softmax 输出)
单个样本的交叉熵损失 H(P,Q) 是针对一个样本的,衡量模型预测概率分布与真实标签分布之间的差异。
批量交叉熵损失 L 是针对一批样本的,是对所有样本的单个样本交叉熵损失的平均值。
作用¶
交叉熵衡量的是预测分布 𝑄 与真实分布 𝑃 的相似性:
当 𝑄 和 𝑃 完全一致时,交叉熵最小。 当 𝑄 和 𝑃 差距越大时,交叉熵越大。
使用场景¶
多分类问题:
在分类任务中,交叉熵损失函数能够很好地处理多个类别的概率分布问题。
搭配 softmax 函数使用效果最佳。
二分类问题:
对于二分类问题,交叉熵公式可以简化为对数损失(log loss),结合 sigmoid 函数即可。
示例¶
猜一个单词¶
目标单词是 “cat”。
你给出猜测的概率分布:{cat: 0.7, dog: 0.2, bird: 0.1}。
如果目标是 “cat”,则交叉熵只会关注你给 “cat” 的概率(0.7)。
通过计算 −log(0.7),得出损失。
备注
换个角度看,这损失就是你猜中目标的负平均对数概率,强调你需要最大化这个概率。
代码讲解¶
定义模型输出和真实标签:
import torch
import torch.nn.functional as F
input = torch.tensor([[2.0, 0.5, 0.1],
[0.1, 0.2, 0.9]]) # Logits
target = torch.tensor([0, 2]) # 类别索引
torch.nn.functional.cross_entropy
函数计算过程拆解:
1. 将输入转换为概率分布,即softmax操作: softmax(input)
2. 计算log_softmax: log(softmax(input))
3. 根据类别取出 target 对应位置的log_softmax值: log_softmax[0], log_softmax[2]
4. 求平均:(log_softmax[0] + log_softmax[1])/2
# Step 1: 计算 softmax
softmax = F.softmax(input, dim=1)
# 输出
# tensor([[0.7285, 0.1625, 0.1090],
# [0.2309, 0.2552, 0.5139]])
# Step 2: 计算log_softmax
log_softmax=torch.log(softmax)
# 输出
# tensor([[-0.3168, -1.8168, -2.2168],
# [-1.4657, -1.3657, -0.6657]])
# Step 3: 提取真实类别的 log-softmax 值,并取反
nll_loss = -log_softmax[range(target.shape[0]), target] # 取每一行的,第目标类别对应的值
# 输出
# tensor([0.3168, 0.6657])
# Step 4: 计算平均损失
loss = nll_loss.mean()
print("===> Cross-Entropy Loss=", loss)
# 输出: tensor(0.4913)
备注
单样本交叉熵本质其实只是步骤2和步骤3,N样本交叉熵本质其实是步骤2、步骤3和步骤4,但交叉熵一般和softmax一起使用(也就是步骤1)。