微调chatglm 报错RuntimeError: expected scalar type Half but found Float
1. 背景
博主显卡:3090
最初的设置:bfloat16
开始训练后,线性层报错
2. 解决: 统一代码中所有精度
1)将模型和数据精度都设置为torch.float32/torch.float16
xxx = torch.tensor(xxx, dtype=torch.float32)
model.config.torch_dtype = torch.float32
2)将模型参数都设置为torch.float32/torch.float16
for param in model.parameters():# Check if parameter dtype is Float (float32)if param.dtype == torch.float16:param.data = param.data.to(torch.float32)