这个问题是小虎在对张量去平均值的时候遇到。解决方法是先将改张量转成浮点数,然后再取平均值。
背景
pytorch + ubuntu 22.04
问题原文
Traceback (most recent call last): File "<string>", line 1, in <module> RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Long
简单地说就是pytroch不支持Long类型的tensor取平均,必须是浮点数或者复数。
解决方法
将:
# fails
images.mean(2)
# RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Byte
改成:
# works
images.float().mean(2)
参考
RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Byte