设置随机种子,复现模型结果
- 1.Python本身的随机因素
- 2.numpy随机因素
- 3.pytorch随机因素
在很多情况下,我们希望能够复现实验的结果。为了消除程序中随机因素的影响,我们需要将随机数的种子固定下来。将所有带随机因素的种子全部固定下来后,多次执行同一代码将得到相同的结果。
在pytorch 模型运行时可能会涉及到三类随机因素:Python本身的随机因素,Numpy随机因素,pytorch随机因素。
没法清楚的知道代码涉及到那些随机因素的情况时有发生,为了以防万一,把所有可能的随机因素都禁止掉比较保险。
def set_rand_seed(seed=1):print("Random Seed: ", seed)random.seed(seed)np.random.seed(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)# torch.backends.cudnn.enabled = False torch.backends.cudnn.benchmark = Falsetorch.backends.cudnn.deterministic = True # 保证每次返回得的卷积算法是确定的
1.Python本身的随机因素
random.seed(seed)
2.numpy随机因素
np.random.seed(seed)
3.pytorch随机因素
1.cpu随机种子
torch.manual_seed(seed)
2.GPU随机种子(with the latest pytorch 0.3 version you only need to set torch.manual_seed which will seed all devices)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
3.cuDNN 是英伟达专门为深度神经网络所开发出来的 GPU 加速库,针对卷积、池化等等常见操作做了非常多的底层优化,比一般的 GPU 程序要快很多。在使用 GPU 的时候,PyTorch 会默认使用 cuDNN 加速。但是使用cuDNN 加速时,torch.backends.cudnn.benchmark 模式是为 False。
cuDNN 对网络进行优化通过torch.backends.cudnn.benchmark 模式选择不同版本的优化算法。但是这些优化算法有些是非确定性的,所以会导致结果的随机性。所以需要orch.backends.cudnn.deterministic = True,选择默认的优化方式,使得结果可以复现。
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
如果不是很清楚默认情况下的torch.backends.cudnn.benchmark值,还是显式设置为 False 比较保险
demo1: 结果随机
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
demo2:结果可复现
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
demo3: 直接禁用cudnn,结果可复现
torch.backends.cudnn.enabled = False
4.数据导入num_workers = 0
def _init_fn():np.random.seed(manualSeed)DataLoding = data.DataLoader(..., batch_size = ..., collate_fn = ..., num_workers =0shuffle = ..., pin_memory = ...,worker_init_fn=np.random.seed(1)
参考博文:
Random seed initialization:https://discuss.pytorch.org/t/random-seed-initialization/7854/17
torch.backends.cudnn.benchmark ?!:https://zhuanlan.zhihu.com/p/73711222
pytorch torch.backends.cudnn设置作用:
https://www.cnblogs.com/wanghui-garcia/p/11514502.html