With the same random seed, torch.nn.init.xavier_normal_ gives different results on CPU and GPU
- 1. Test code
- 2. Output
While training a PyTorch model, the loss differed across two servers even though the random seed was the same. Debugging showed that the initialized weights on the two platforms also differed, and testing torch.nn.init.xavier_normal_ in isolation reproduced the mismatch. If all initialization is done on the CPU, the two servers produce identical results. This is also why Megatron-DeepSpeed provides the `--use-cpu-initialization` flag, which initializes weights on the CPU.
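The CPU-initialization workaround can be sketched as a small helper (`make_weight` is a hypothetical name, not a Megatron-DeepSpeed API): seed and initialize the tensor on CPU so every machine draws from the same generator, then move it to the target device.

```python
import torch

def make_weight(shape, seed, device):
    # Hypothetical helper: seed and initialize on CPU for cross-device
    # reproducibility, then move the finished tensor to `device`.
    torch.manual_seed(seed)
    w = torch.empty(shape, dtype=torch.float16)  # allocated on CPU on purpose
    torch.nn.init.xavier_normal_(w)
    return w.to(device)

cpu_w = make_weight((1, 4), 42, "cpu")
gpu_w = make_weight((1, 4), 42, "cuda" if torch.cuda.is_available() else "cpu")
# The values agree regardless of the final device, because the random draw
# itself always happened on the CPU generator.
assert torch.equal(cpu_w, gpu_w.cpu())
```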
1. Test code
cat > test_torch_rand.py <<-'EOF'
import torch
import numpy as np
import random

def init_test(device):
    shape = (1, 4)
    RANDOM_SEED = 42
    random.seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)
    torch.manual_seed(RANDOM_SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(RANDOM_SEED)
    print(f"------------------------test torch init on {device}-------------------------------")
    # Generated on CPU, then moved: identical on both devices.
    weight = torch.rand(shape, dtype=torch.float16).to(device)
    print("torch.rand:", weight.detach().cpu().float().numpy())
    # Each init_ below draws in place from the target device's own generator.
    weight = torch.nn.Parameter(torch.empty(shape, device=device, dtype=torch.float16))
    torch.nn.init.xavier_normal_(weight)
    print("xavier_normal_:", weight.detach().cpu().float().numpy())
    weight = torch.nn.Parameter(torch.empty(shape, device=device, dtype=torch.float16))
    torch.nn.init.uniform_(weight, a=0.0, b=1.0)
    print("uniform_:", weight.detach().cpu().float().numpy())
    weight = torch.nn.Parameter(torch.empty(shape, device=device, dtype=torch.float16))
    torch.nn.init.normal_(weight)
    print("normal_:", weight.detach().cpu().float().numpy())
    weight = torch.nn.Parameter(torch.empty(shape, device=device, dtype=torch.float16))
    torch.nn.init.kaiming_uniform_(weight)
    print("kaiming_uniform_:", weight.detach().cpu().float().numpy())

init_test("cpu")
if torch.cuda.is_available():
    init_test("cuda")
EOF
python3 test_torch_rand.py
2. Output
------------------------test torch init on cpu-------------------------------
torch.rand: [[0.5498047 0.71240234 0.41992188 0.63183594]]
xavier_normal_: [[ 0.14831543 0.14562988 -0.70996094 -0.11785889]]
uniform_: [[0.16113281 0.7236328 0.04248047 0.6816406 ]]
normal_: [[0.46166992 0.26733398 0.53466797 0.8095703 ]]
kaiming_uniform_: [[ 0.58740234 0.49389648 -0.2619629 -0.76416016]]
------------------------test torch init on cuda-------------------------------
torch.rand: [[0.5498047 0.71240234 0.41992188 0.63183594]]
xavier_normal_: [[ 0.12268066 1.3671875 -0.10882568 0.5371094 ]]
uniform_: [[0.98779297 0.12890625 0.5620117 0.52197266]]
normal_: [[-0.5185547 1.2265625 0.6254883 -0.9116211]]
kaiming_uniform_: [[ 0.5625 -0.9321289 -1.0996094 -0.640625 ]]
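In the output above, only `torch.rand` matches across devices, because that tensor was drawn on the CPU and then moved. The in-place `init_` calls differ because each device samples from its own generator; to my understanding, PyTorch's CPU generator (Mersenne Twister) and CUDA generator (Philox) produce different streams even from the same seed, so same-device runs are reproducible while cross-device runs are not. A minimal demonstration of the same-device guarantee:

```python
import torch

# Same seed + same device: the stream is fully reproducible.
torch.manual_seed(42)
a = torch.rand(4)
torch.manual_seed(42)
b = torch.rand(4)
assert torch.equal(a, b)

# Cross-device reproducibility is NOT guaranteed: torch.rand(4, device="cuda")
# under the same seed generally yields different values, because the CUDA
# generator is a different algorithm from the CPU one.
```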