【pytorch】pytorch模型可复现设置

文章目录

        • 序言
        • 1. 可复现设置代码
        • 2. 可复现设置代码解析
          • 2.1 消除python与numpy的随机性
          • 2.2 消除torch的随机性
          • 2.3 消除DataLoader的随机性
          • 2.4 消除cuda的随机性
          • 2.5 避免pytorch使用不确定性算法
          • 2.6 使用pytorch-lightning
          • 2.7 特殊情况

序言
  • 为了让模型在同一设备每次训练的结果可复现,需进行可复现设置
1. 可复现设置代码
def set_seed(seed):random.seed(seed)np.random.seed(seed)os.environ['PYTHONHASHSEED'] = str(seed)torch.manual_seed(seed)torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Falsetorch.backends.cudnn.enabled = Falseos.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'torch.use_deterministic_algorithms(True)set_seed(21)
2. 可复现设置代码解析
2.1 消除python与numpy的随机性
import random
import numpy as np# 消除numpy和random的随机性
random.seed(SEED)
np.random.seed(SEED)# 固定python环境变量中的PYTHONHASHSEED,禁止hash随机化
os.environ['PYTHONHASHSEED'] = str(seed)
2.2 消除torch的随机性
import torch
import torch
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED) 		# 适用于显卡训练
torch.cuda.manual_seed_all(SEED) 	# 适用于多显卡训练
2.3 消除DataLoader的随机性
  • 使用torch时,一般都会使用DataLoader加载数据集,这个类使用了多线程的处理方式,因此会造成一定随机性
def seed_worker(worker_id):random.seed(SEED + worker_id)g = torch.Generator()
g.manual_seed(SEED)DataLoader(train_dataset,batch_size=batch_size,num_workers=num_workers,worker_init_fn=seed_workergenerator=g,
)
  • 设置shuffle=True并设置随机种子
# 可复现设置代码,可按上述来设置
def setup_seed(seed):torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)np.random.seed(seed)random.seed(seed)torch.backends.cudnn.deterministic = True# 设置随机数种子
setup_seed(21)# shufle=True,不同训练之间一样的乱序
train_loader2 = DataLoader(dataset=dealDataset, batch_size=32, shuffle=True)
2.4 消除cuda的随机性
  • 适用于GPU训练
# 确保每次返回的卷积算法等是固定的
torch.backends.cudnn.deterministic = True# 禁止cudnn使用非确定性算法
torch.backends.cudnn.enabled = False# 配合enabled命令使用
# True:自动寻找最适合当前配置的高效算法来优化运行效率
# False:保证实验结果可复现
torch.backends.cudnn.benchmark = False
  • 不设置torch.backends.cudnn.enabled = False的话,无法保证训练结果可复现性。但添加这行后会导致训练速度很慢

  • 如果cuda是10.2及以上版本,少数cuda操作是不确定的,需要做如下设置

os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'或os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
2.5 避免pytorch使用不确定性算法
  • 配置pytorch在可用的情况下,使用确定性算法而不是非确定性算法。如果已知某个操作是不确定的,并且没有确定的替代办法,则抛出RuntimeError错误
torch.use_deterministic_algorithms(True)
2.6 使用pytorch-lightning
  • 详见 pytorch-lightning设置

  • 该方法保证了可复现的同时,没有牺牲任何训练速度

2.7 特殊情况
  • 在某些版本的CUDA中,RNN和LSTM网络可能具有不确定性行为,如LSTM中的dropout,需要注意这一特性。需要关注如何消除不确定性

 


【参考文章】
[1]. pytorch模型可复现性设置
[2]. DataLoader类设置打乱的随机数种子
[3]. pytorch消除模型训练的随机性
[4]. pytorch模型训练可复现方案

created by shuaixio, 2024.02.24

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/700796.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

复旦大学MBA聚劲联合会:洞见智慧,拓宽思维格局及国际化视野

12月2日,“焕拥时代 俱创未来”聚劲联合会俱创会年度盛典暨俱乐部募新仪式圆满收官。16家复旦MBA俱乐部、200余名同学、校友、各界同仁齐聚复旦管院,一起在精彩纷呈的圆桌论坛里激荡思想,在活力四射的俱乐部风采展示中凝聚力量。      以…

08 Redis之集群的搭建和复制原理+哨兵机制+CAP定理+Raft算法

5 Redis 集群 2.8版本之前, Redis采用主从集群模式. 实现了数据备份和读写分离 2.8版本之后, Redis采用Sentinel哨兵集群模式 , 实现了集群的高可用 5.1 主从集群搭建 首先, 基本所有系统 , “读” 的压力都大于 “写” 的压力 Redis 的主从集群是一个“一主多从”的读写分…

海莲花APT组织样本跟踪分析

APT组织简介 OceanLotus(海莲花)APT组织是一个长期针对中国及其他东亚、东南亚国家(地区)政府、科研机构、海运企业等领域进行攻击的APT组织,该组织也是针对中国境内的最活跃的APT组织之一,该组织主要通过鱼叉攻击和水坑攻击等方法,配合多种…

计算机网络面经-HTTP的8种请求方式

简单介绍 HTTP是超文本传输协议,其定义了客户端与服务器端之间文本传输的规范。HTTP默认使用80端口,这个端口指的是服务端的端口,而客户端使用的端口是动态分配的。当我们没有指定端口访问时,浏览器会默认帮我们添加80端口。我们…

2.21日学习打卡----初学Nginx(一)

2.21日学习打卡 目录: 2.21日学习打卡一. Nginx是什么?概述Nginx 五大应用场景HTTP服务器正向代理反向代理正向代理与反向代理的区别:负载均衡动静分离 为啥使用Nginx? 二.下载Nginx(linux)环境准备下载Nginx和安装NginxNginx源码…

新手搭建服装小程序全攻略

随着互联网的快速发展,线上购物已经成为了人们日常生活中不可或缺的一部分。服装作为人们日常消费的重要品类,线上化趋势也日益明显。本文将详细介绍如何从零开始搭建一个服装小程序商城,从入门到精通的捷径,帮助你快速掌握小程序…

面试前端性能优化八股文十问十答第一期

面试前端性能优化八股文十问十答第一期 作者:程序员小白条,个人博客 相信看了本文后,对你的面试是有一定帮助的!关注专栏后就能收到持续更新! ⭐点赞⭐收藏⭐不迷路!⭐ 1)CDN的概念 CDN&…

专项:PID控制方法深究

1.前言 PID在工业界随处可见。其的原理是什么? 2.数学物理代表意义 PID全名为比例积分微分控制器。顾名思义,表明其由三个控制器组成。 一是P,其代表比例(Proportional); 二是I,其代表积分(I…

《TCP/IP详解 卷一》第2章 Internet地址结构

目录 2.1 引言 2.2 表示IP地址 2.3 基本的IP地址结构 单播地址 全球单播地址: 组播地址 任播地址 2.4 CIDR和聚合 2.5 特殊用途地址 2.6 分配机构 2.7 单播地址分配 2.8 与IP地址相关的攻击 2.9 总结 2.1 引言 2.2 表示IP地址 IPv4地址:3…

【数据分享】不同共享社会经济路径下中国未来280个城市土地数量数据集(免费获取)

了解未来城市土地数量对于城市规划、社会经济发展和气候变化研究具有重要意义。通过分析不同共享社会经济路径下中国未来城市土地数量的数据,可以为未来城市发展趋势和可持续规划提供科学依据。 本次我们给大家带来的是不同共享社会经济路径下中国未来城市土地数量…

【退役之重学前端】使用vite+vue3+vue-router,重构react+react-router前后端分离的商城后台管理系统

前言: 对前端各个技术板块,HTML、CSS、JavaScript、ES6、vue家族,整体上能“摸其大概”。笔者计划重构一个基于react的商城后台管理系统。 —— 2024年2月16日 技术选型 #语言和框架 vue3sassbootstrapES7 #架构 前后端分离分层架构模块化…

C# 实现网页内容保存为图片并生成压缩包

目录 应用场景 实现代码 扩展功能(生成压缩包) 小结 应用场景 我们在一个求职简历打印的项目功能里,需要根据一定的查询条件,得到结果并批量导出指定格式的文件。导出的格式可能有多种,比如WORD格式、EXCEL格式、PDF格式等,…

使用命令行创建文件夹和文件

创建文件夹 md 文件夹名字 创建文件 echo >文件名字.后缀然后回车即可 注意点:echo >文件名字.后缀 的 >后面不可以加空格,不然会报错\

深入理解Go语言中的Channel与Select

Go 语言中的 Channel 和 Select 是并发编程中的重要概念和机制,它们为协程之间的通信和同步提供了强大的支持。接下来将深入介绍 Channel 和 Select 的概念、使用方法、特性,并结合实际工作场景和示例代码进行详细讨论。 1. Channel 概述 1.1 什么是 C…

《Docker极简教程》--Docker卷和数据持久化--Docker卷的概念

在容器化环境中,数据持久性是一个重要挑战。传统上,容器是短暂的、易于销毁和重建的,这与数据的持久性需求相冲突。当容器被销毁时,容器内部的数据通常会丢失,因此需要一种方法来确保数据的持久性。这涉及到数据的存储…

Java基础之lambda表达式(五)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒…

《Python 语音转换简易速速上手小册》第9章 特定领域的语音处理(2024 最新版)

文章目录 9.1 语音处理在不同行业的应用9.1.1 基础知识9.1.2 主要案例:智能客服机器人案例介绍案例 Demo案例分析9.1.3 扩展案例 1:医疗语音助手案例介绍案例 Demo案例分析9.1.4 扩展案例 2:语言学习应用案例介绍案例 Demo

Python 实现Hash算法验证

目录 一、Hash 算法的原理及作用 二、python验证Hash算法代码实现 三、运行脚本验证如下 四、在线工具验证结果如下 五、总结 一、Hash 算法的原理及作用 Hash加密算法是一种将任意长度的消息压缩成固定长度散列值的算法。它的特点是快速、不可逆和安全。对于相同的消息&a…

Java整型字符串数组

整数类型 byte,字节 【1字节】表示范围:-128~127即: -2^7~2^7 -1 short,短整型 【2字节】表示范围: -32768~32767 int,整型 【4字节】表示范围: -2147483648~2147483647 long,长整型 【8字节】表示范围: -9223372036854775…

陪玩软件系统的开发-用PHP书写,uni开发的陪玩平台更有质量-线上线下功能齐全-APP小程序H5公众号都有,源码交付!

线上陪玩系统的功能 在线预订:用户可以在陪玩系统中在线预订陪玩服务,系统会根据用户的订单要求自动匹配陪玩人员。 指定搜索:用户可以通过搜索指定的ID来找到他们想要的陪玩人员。 在线交流:在陪玩系统中提供在线沟通功能&…