变量名解释
logits:未经过normalize(未经过激活函数处理)的原始分数,例如一个mlp将特征映射到num_target_class维的输出tensor就是logits。
probs:probabilities的简写,logits经过sigmoid函数,就变成了分布在0-1之间的概率值probs。
Binary Cross-Entropy Loss
Binary Cross-Entropy Loss,简称为BCE loss,即二元交叉熵损失。
二元交叉熵损失是一种用于二分类问题的损失函数。它衡量的是模型预测的概率分布与真实标签的概率分布之间的差异。在二分类问题中,每个样本的标签只有两种可能的状态,通常表示为 0(负类)和 1(正类)。
其公式为:
BCE Loss = − 1 N ∑ i = 1 N [ y i log ( p i ) + ( 1 − y i ) log ( 1 − p i ) ] \text { BCE Loss }=-\frac{1}{N} \sum_{i=1}^N\left[y_i \log \left(p_i\right)+\left(1-y_i\right) \log \left(1-p_i\right)\right] BCE Loss =−N1i=1∑N[yilog(pi)+(1−yi)log(1−pi)]
其中:
- N N N是数据集的样本数量。
- y i y_i yi是第 i {i} i个样本的真实标签,取值为 0 或 1,即第 i {i} i个样本要么属于类别 0(负类),要么属于类别 1(正类)。
- p i p_i pi是第 i {i} i个样本属于类别 1(正类)的概率
- log \log log是是自然对数。
当真实标签 y i = 1 y_i=1 yi=1 时,损失函数的第一部分 y i log ( p i ) y_i \log \left(p_i\right) yilog(pi) 起作用,第二部分为 0 。此时, 如果预测概率 p i p_i pi 接近 1 (接近真实标签 y i = 1 y_i=1 yi=1), 那么 log ( p i ) \log \left(p_i\right) log(pi) 接近 0 , 损失较小;如果 p i p_i pi 接近 0 (即模型预测错误),那么 log ( p i ) \log \left(p_i\right) log(pi) 会变得成绝对值很大的负数,导致取反后loss很大。
当真实标签 y i = 0 y_i=0 yi=0 时,损失函数的第二部分 ( 1 − y i ) log ( 1 − p i ) \left(1-y_i\right) \log \left(1-p_i\right) (1−yi)log(1−pi)起作用,第一部分为 0。此时,预测概率 p i p_i pi越接近于 0,整体loss越小。
Pytorch手动实现
import torch
import torch.nn.functional as Fdef manual_binary_cross_entropy_with_logits(logits, targets):# 使用 Sigmoid 函数将 logits 转换为概率probs = torch.sigmoid(logits)# 计算二元交叉熵损失loss = - torch.mean(targets * torch.log(probs) + (1 - targets) * torch.log(1 - probs))return loss# logits和targets可以是任意shape的tensor,只要shape相同即可
logits = torch.tensor([0.2, -0.4, 1.2, 0.8])
targets = torch.tensor([0., 1., 1., 0.])
assert logits.shape == targets.shape# 使用 PyTorch 的 F.binary_cross_entropy_with_logits 函数计算损失
loss_pytorch = F.binary_cross_entropy_with_logits(logits, targets)# 使用手动实现的函数计算损失
loss_manual = manual_binary_cross_entropy_with_logits(logits, targets)print(f'Loss (PyTorch): {loss_pytorch.item()}')
print(f'Loss (Manual): {loss_manual.item()}')
F.binary_cross_entropy 与 F.binary_cross_entropy_with_logits的区别
F.binary_cross_entropy的输入是probs
F.binary_cross_entropy_with_logits的输入是logits