目录
CrossEntropyLoss交叉熵损失函数的使用:
一、官方说明:
二、两种使用情况:
1)情况一:target是一个类索引(Example of target with class indices)
2)情况二:target是一个类概率(Example of target with class probabilities)
CrossEntropyLoss交叉熵损失函数的使用:
一、官方说明:
CrossEntropyLoss — PyTorch 2.3 documentation
torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)
二、两种使用情况:
1)情况一:target是一个类索引(Example of target with class indices)
当target是一个类索引,使用CrossEntropyLoss函数,计算input(样本的预测概率分布)和target(样本的真实标签)的交叉熵损失:
import torch
import torch.nn as nn######### 交叉熵损失损失函数的两种情况(以下假设有3个样本,5个类别)# 情况一:Example of target with class indices(具有类索引的目标示例,target是一个类索引)
loss = nn.CrossEntropyLoss()input = torch.randn(3, 5, requires_grad=True) # input为3个样本的预测概率分布
target = torch.empty(3, dtype=torch.long).random_(5) # target为3个样本的真实标签
output = loss(input, target) # 计算损失print(input)
print(target)
print(output)#### 以下是输出结果:
# tensor([[ 1.9081, 0.0438, -0.8243, -0.3006, 0.2915],
# [ 0.3592, 1.4114, -0.3863, 0.5843, -0.7542],
# [ 1.0673, -0.9307, -0.8625, -2.3816, -0.5145]], requires_grad=True)
# tensor([3, 0, 4])
# tensor(2.1368, grad_fn=<NllLossBackward0>)
2)情况二:target是一个类概率(Example of target with class probabilities)
当target是一个类概率,使用CrossEntropyLoss函数,计算input(样本的预测概率分布)和target(样本的真实标签)的交叉熵损失:
# 情况二:Example of target with class probabilities(具有类概率的目标示例,target是一个类概率分布)
loss = nn.CrossEntropyLoss()input = torch.randn(3, 5, requires_grad=True) # input为3个样本的预测概率分布
target = torch.randn(3, 5).softmax(dim=1) # target为3个样本的真实标签
output = loss(input, target) # 计算损失print(input)
print(target)
print(output)#### 以下是输出结果:
# tensor([[ 0.5793, -0.3210, 0.1222, -0.7272, -0.8790],
# [-0.2824, 0.2521, 0.9788, -0.4009, -0.1519],
# [-0.5411, -0.1141, 0.6473, -0.1465, -1.0575]], requires_grad=True)
# tensor([[0.4346, 0.1689, 0.1089, 0.0145, 0.2731],
# [0.0651, 0.4137, 0.0194, 0.3543, 0.1476],
# [0.1100, 0.0912, 0.1600, 0.6082, 0.0306]])
# tensor(1.6848, grad_fn=<DivBackward1>)