文章目录
- 一、错误解释RuntimeError: expected scalar type Long but found Float
- 二、错误分析
- 三、解决办法
- 总结
一、错误解释RuntimeError: expected scalar type Long but found Float
RuntimeError:应为标量类型Long,但找到了Float
二、错误分析
我之前的代码:
loss_function = torch.nn.CrossEntropyLoss()
loss1 = loss_function(predict.unsqueeze(0), c_all_train_y)
根据错误信息,可以看出目标标签 c_all_train_y
的数据类型应为 Long
(整型),而不是 Float
(浮点型)。
三、解决办法
为了解决这个问题,将目标标签 c_all_train_y
的数据类型转换为 Long
,例如通过使用 c_all_train_y.long()
,然后再进行损失计算。
修正后的代码如下所示:
loss_function = torch.nn.CrossEntropyLoss()
loss1 = loss_function(predict.unsqueeze(0), c_all_train_y.long())
这样能够正确计算损失函数,同时确保数据类型匹配。请小伙伴们注意,在进行类别预测时,确保预测结果 predict
是未经 Softmax 处理的原始分数。这一点真的很重要!!!!
总结
RuntimeError: expected scalar type Long but found Float,表明代码在某个地方需要一个Long类型的标量(即整数),但是提供的却是Float类型的数据。