torch.nn.BCEWithLogitsLoss is equivalent to applying a sigmoid followed by torch.nn.BCELoss. Here is a code example:
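```python
import torch
import torch.nn as nn

BCEWithLogitsLoss = nn.BCEWithLogitsLoss()  # takes raw logits, applies sigmoid internally
BCELoss = nn.BCELoss()                      # expects probabilities in [0, 1]

x = torch.randn((1,))      # a single random logit
y = torch.tensor([1.0])    # target label (float, same shape as x)

Loss_BCEWithLogits = BCEWithLogitsLoss(x, y)
Loss_BCE = BCELoss(torch.sigmoid(x), y)   # apply the sigmoid by hand
print("BCEWithLogitsLoss:", Loss_BCEWithLogits)
print("BCELoss:", Loss_BCE)
"""
Output from one run (x is random, so the values change between runs):
BCEWithLogitsLoss: tensor(0.2138)
BCELoss: tensor(0.2138)
"""
```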