作用
ignore_index用于忽略ground-truth中某些不需要参与计算的类。假设有两类{0:背景,1:前景},若想在计算交叉熵时忽略背景(0)类,则可令ignore_index=0(同理忽略前景计算可设ignore_index=1)。
代码示例
import torch
import torch.nn.functional as F
pred = torch.Tensor([[0.9, 0.1],[0.8, 0.2],[0.7, 0.3]]
) # shape=(N,C)=(3,2),N为样本数,C为类数
label = torch.LongTensor([1, 0, 1]) # shape=(N)=(3),3个样本的label分别为1,0,1
out = F.cross_entropy(pred, label, ignore_index=0) # 忽略0类
print(out)
输出
tensor(1.0421)
验证
pytorch的CrossEntropy使用公式:
计算:
ignore_index表示计算交叉熵时,自动忽略的标签值,example:
import torch
import torch.nn.functional as F
pred = []
pred.append([0.9, 0.1])
pred.append([0.8, 0.2])
pred = torch.Tensor(pred).view(-1, 2)label = torch.LongTensor([[1], [-1]]) # 这里输出类别为0或1,-1表示不参与计算loss。且计算平均loss的时候,reduction只计算实际参与计算的个数,这里相当于batchsize=2,但其中第index=1行为-1不参与计算loss。# out = F.cross_entropy(pred.view(-1, 2), label.view(-1, ))
out = F.cross_entropy(pred.view(-1, 2), label.view(-1, ), ignore_index=-1)
print(out)
输出结果:
tensor(1.1711)
再比如:
例如我的pred是(b,2,w,h),而label索引是(b,1,w,h)的矩阵,其中只有0,1值,0值代表从pred的第0个通道选择像素值,1值代表从pred的第1个通道选择像素值。
而此时我发现因为程序的错误,label矩阵中混入了一些-1值,这样正常的话是会报错的,因为pred矩阵没有-1通道。此时最简单的一个方法就是
loss = nn.CrossEntropyLoss(ignore_index=-1)
上述操作就是相当于忽略-1标签值为-1的位置的对应像素值就不参与计算梯度了
torch.nn.CrossEntropyLoss 同理。