Pytorch加载部分预训练模型的参数

问题背景

假设我有一个已训练好的Model1,并已保存它的参数为.pth格式,我有一个与Model1结构完全相同的模型Model2,我希望Model2加载Model1中与特征提取有关的模块的参数,其他模块的参数随机初始化。

应用场景为在K折交叉验证时,我希望从第二折开始的模型加载第一折训练模型的部分参数,并在此基础上微调,从而减少训练轮数。

解决方法

加载保存的第一折训练好的模型参数,因为我保存时是多GPU,加载时需要多GPU模型参数去掉module:

state_dict = torch.load("./Fold_1_checkpoint_4.pth",map_location=torch.device('cpu'))
# 多GPU 模型参数去掉 module
from collections import OrderedDict
state_dict_new = OrderedDict()
for k, v in state_dict.items():name = k[7:]  # 去掉 `module.`# name = k.replace(“module.", "")state_dict_new[name] = v

打印出来看看:

OrderedDict([('LD1.conv.weight',tensor([[[ 0.0047,  0.0013,  0.0026,  0.0025,  0.0060, -0.0055,  0.0024,-0.0022,  0.0052, -0.0039, -0.0015,  0.0049,  0.0861,  0.0095,0.0032,  0.0033,  0.0037,  0.0008,  0.0023,  0.0080,  0.0045,-0.0023,  0.0097,  0.0034,  0.0160]]])),('LD1.conv.bias', tensor([0.0101])),('position_embedder1.position_embedding',tensor([[-0.1612, -0.2091, -0.9842,  ...,  0.7624, -0.7286, -0.6549],[-0.2659,  0.0225, -1.1011,  ...,  1.3899, -0.6597, -0.7806],[-0.3426, -0.2478, -1.0329,  ...,  1.2922, -0.8009, -0.8948],...,[ 1.4002,  0.0854,  0.4595,  ...,  0.8604, -0.6667, -0.3106],[-0.2322,  0.1524,  0.0809,  ..., -0.3961,  1.2030, -0.3428],[-0.7272,  2.0944, -0.7098,  ...,  0.2923,  0.1994, -0.2035]])),······('encoder1.attn_layers.0.attention.gen1.weight',tensor([[-0.2019,  0.2232,  0.2444,  ..., -0.0791,  0.1458,  0.2732],[-0.3740,  0.0378, -0.1596,  ..., -0.2621,  0.2512,  0.1971],[ 0.0941, -0.2160, -0.1651,  ...,  0.1103,  0.2440, -0.8079],...,[-0.0320, -0.2835,  0.7415,  ...,  0.1435, -0.5525, -0.2343],[-0.2918, -0.4909,  0.0787,  ..., -0.2581, -0.2801,  0.5388],[-0.3744, -0.8789, -0.3549,  ..., -0.0468, -0.1703, -0.1652]])),('encoder1.attn_layers.0.attention.gen1.bias',tensor([ 0.1687, -0.8035, -0.0674, -0.4795, -0.2964, -0.5702, -0.9225, -0.0856,1.4210, -0.9099, -1.3605, -0.8256, -1.6564, -1.5394, -0.7184,  0.4084,-0.6281, -2.2652, -0.7902, -0.9841, -0.4321, -0.4909, -1.1394, -1.2694,-0.0219, -0.3581,  0.5682, -1.2015, -0.9408, -1.4014, -1.2602, -1.3793,-0.8697, -2.1542, -0.7590, -1.7175,  0.0472,  0.0683,  0.4771, -1.7388,-1.7942, -0.2653, -0.2880,  0.1545, -1.4922, -0.3756,  0.5862, -1.5961,-0.2772,  0.9524, -0.1537, -0.4896, -0.2866, -0.9232,  0.2729,  0.0779,-1.7396,  0.0744, -0.9369,  0.3411, -0.2821, -0.9431, -1.1710,  0.1081])),······('gate1.0.weight',tensor([[ 0.4562, -0.0911,  0.2505,  ...,  0.4371,  0.1663, -0.4728],[ 0.1472,  0.1287,  0.0289,  ...,  0.0881,  0.0862, -0.2207],[ 0.5243,  0.0716,  0.4541,  ..., -0.4321,  0.6488,  0.7909],...,[-0.0867,  0.2686,  0.1457,  ...,  0.0756,  0.4644, -0.4607],[-0.1575,  0.3019,  0.1621,  ..., -0.1950,  0.2406, -0.2090],[-0.0231, -0.2043, -0.0222,  ...,  0.1126,  0.2387, -0.7590]])),('gate1.0.bias',tensor([-0.5651, -0.6304, -0.6943, -0.5891, -0.2952, -0.3897, -0.6592, -0.3604,0.1627, -0.4173, -0.3456, -0.6569, -0.0337, -0.1433, -0.3954, -0.0141])),······('Linear_res2.bias',tensor([-0.0314, -0.0065,  0.0268, -0.0321, -0.0738, -0.0063, -0.0531, -0.0615,-0.0552, -0.0357, -0.0639, -0.0893, -0.0361, -0.0736, -0.0347, -0.0330,-0.0759, -0.0828, -0.0665, -0.0439, -0.0652, -0.0718, -0.0231, -0.0297,-0.0448, -0.0408, -0.0181, -0.0379, -0.0274, -0.0526, -0.0139, -0.0404,-0.0284, -0.0496, -0.0515, -0.0054, -0.0704, -0.0666, -0.0385, -0.0613,-0.0471, -0.0886, -0.0398, -0.0616, -0.0304, -0.0558, -0.0301, -0.0728,-0.0869, -0.0409, -0.0514, -0.0737, -0.0510, -0.1048, -0.0555, -0.0530,-0.0721, -0.0315, -0.0070, -0.0687, -0.0707, -0.0403, -0.0611, -0.0340,-0.0935, -0.0339, -0.0462, -0.0842, -0.0516, -0.0445, -0.0364, -0.0748]))])

在此案例中,模型特征提取部分为encoder开头和gate开头的模块,因此需要过滤掉其他不需要迁移的模块:

pretrained_dict = {k: v for k, v in state_dict_new.items() if k.split('.')[0][:-1] in ['encoder','gate']}

初始化Model2

model2 = Model(args).cuda()
model_dict = model2.state_dict()

更新model_dict中需要迁移的部分并导回模型:

model_dict.update(pretrained_dict)
model2.load_state_dict(model_dict)

至此,model2中就加载了model1中与特征提取相关的模块参数,即可在此基础上微调。

参考文献

Pytorch如何加载部分预训练模型的参数
pytorch如何使模型只更新一部分参数 pytorch加载模型部分参数 转载
pytorch加载多GPU模型和单GPU模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/50340.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

告别繁琐地推!Xinstall如何一键优化你的App地推方案

在这个移动应用遍地开花的时代,App地推活动早已成为各大厂商获取新用户、提升品牌曝光度的重要手段。然而,传统地推方案中的种种弊端,如填写地推码/邀请码的繁琐、渠道打包的工作量繁重、人工登记上报的不准确等,无一不在拖慢地推…

纯电SUV又一个卷王,比亚迪都没它狠

文 | AUTO芯球 作者 | 雷慢 太狠了,就在刚刚, 我劝阻了一个高中同学暂时不要买宋PLUS纯电版, 因为又一个新能源卷王出现了, 在卷价格上,宋PLUS都没它狠。 不信你们看,埃安V第二代刚发布, …

科技与梦想 | 任正非引领华为的品牌革新之旅

从一个小型的交换设备供应商到全球通信技术的领军企业,华为的发展历程就是一部激动人心的品牌传奇。 在这背后,有一位引领者——任正非,他的远见和决心塑造了华为今日的辉煌。 "把技术做尖,把产品做精,把服务做…

如何快速抓取小红书帖子评论?两大实战Python技巧揭秘

摘要: 本文将深入探讨两种高效的Python方法,助您迅速获取小红书文章下方的所有评论,提升市场分析与用户洞察力。通过实战示例与详细解析,让您轻松掌握数据抓取技巧,为您的内容营销策略提供有力支持。 如何快速抓取小…

可见性::

目录 定义: 解决方法: ①使用synchronized实现缓存和内存的同步 修改一: 加入语句: 代码: 修改2: 在代码块中加入: 代码: 执行结果: 原因: ②使用…

【面试题】测试工程师面试题汇总

1.测试基础 【测试基础】归纳整理2023年面试题-CSDN博客 2.性能测试 【性能测试】归纳整理2023年面试题 3.Python语言 【python】归纳整理2023年常见面试题 4.自动化 https://blog.csdn.net/weixin_46697247/article/details/133493163 5.测试用例 https://blog.csdn.…

java通过poi解析word入门

文章目录 介绍一、了解word docx文档的结构二、引入POI的依赖三、解析Word文档常用API加载Word文档获取文档整体结构获取文档中的段落获取文档中的表格获取文档中的脚注 四、解析Word中的段落示例五、读取Word文档并遍历图片六、解析Word中的图片示例 介绍 Apache POI 是一个处…

基于高光谱图像的压缩感知网络

压缩感知算法原理 压缩感知(Compressed Sensing, CS)是一种信号处理技术,它允许在远低于Nyquist采样率的情况下对信号进行有效采样和重建。压缩感知理论的核心思想是利用信号的稀疏性,通过少量的线性测量重建出原始信号。以下是压…

oncoPredict:根据细胞系筛选数据预测体内或癌症患者药物反应和生物标志物

在14年的时候,oncoPredict函数的开发团队在Genome Biology上发了一篇文章。 这篇文章的核心目的是阐释了使用治疗前基线肿瘤基因表达数据去预测患者化疗反应。开发团队发现使用细胞系去预测临床样本的药物反应是可行的。 鉴于之前的理论,该研究团队首先…

[pycharm]解决pycharm运行程序出现卡住scanning files to index索引的问题

有时候会出现索引问题,显示scanning files to index 解决方法: in pycharm, go to the "File" on the left top, then select "invalidate caches/restart...", and press "invalidate and restart". 然后等它自己重启…

LC 283.移动零

283. 移动零 给定一个数组 nums,编写一个函数将所有 0 移动到数组的末尾,同时保持非零元素的相对顺序。 请注意 ,必须在不复制数组的情况下原地对数组进行操作。 示例 1: 输入: nums [0,1,0,3,12] 输出: [1,3,12,0,0]示例 2: 输入: num…

@RestController和@Controller

RestController和Controller 在 Spring MVC 中,RestController 和 Controller 是用于定义控制器的注解,但它们有一些重要的区别。下面是对它们的详细解释和示例: Controller Controller 注解用于标记一个类是一个 Spring MVC 控制器&#…

Marin说PCB之----我的创作纪念日

今天早上打开手机无意间看到了CSDN给我发来的私信,不知不觉中已经是512天了,下面小编我就给诸位道友们分享我和CSDN的那些年。 机缘 有一天小编我正在回去的路上,突然从天上落下一本书,叫信号完整性与电源完整性分析: …

[Mysql-DML数据操作语句]

目录 数据增加:INSERT 全字段插入: 部分字段插入: 一次性添加多条: 数据修改:UPDATE 数据删除:DELECT delete truncate drop 区别 数据增加:INSERT 总体格式:insert into 表…

Vue的安装配置

1.安装node js Node.js — 在任何地方运行 JavaScript (nodejs.org) 2.测试nodejs是否安装成功 node -v npm -v3.通过npm 安装 vue npm install -g vue/cli4.测试vue是否安装成功 vue --version5.打开PyCharm,创建项目:flask-web vue create flask…

定制化爬虫管理:为企业量身打造的数据抓取方案

在数据驱动的时代,企业如何高效、安全地获取互联网上的宝贵信息?定制化爬虫管理服务应运而生,成为解锁专属数据宝藏的金钥匙。本文将深入探讨定制化爬虫管理如何为企业量身打造数据抓取方案,揭秘其在海量信息中精准捕获价值数据的…

音视频入门基础:WAV专题(1)——使用FFmpeg命令生成WAV音频文件

在文章《音视频入门基础:PCM专题(1)——使用FFmpeg命令生成PCM音频文件并播放》中讲述了生成PCM文件的方法。通过FFmpeg命令可以把该PCM文件转为WAV格式的音频文件: ./ffmpeg -ar 44100 -ac 2 -f s16le -acodec pcm_s16le -i aud…

C#知识|账号管理系统:实现修改管理员登录密码

哈喽,你好啊,我是雷工! 本节主要记录实现修改管理员登录密码的后端逻辑及相关功能,以下为学习笔记。 01 实现逻辑 ①:首先输入原密码,验证,验证通过然后可以输入新密码进行修改; ②:新密码修改为了避免输入失误导致输入的密码与自己以为修改的密码不符的情况,增加了…

创建基于 sysroot 的 linux arm64 交叉编译环境

背景 编译 arm64 架构的程序的方法有两种: 将代码上传到 arm64 架构的机器上编译。在 x64 架构上进行 arm64 交叉编译。 多数需要交叉编译的场景一般是夸平台多架构支持或是嵌入式开发。使用 sysroot 方法是一个更优的方案,不需要特定架构的编译服务器…

JavaScript(15)——操作表单元素属性和自定义属性

操作表单元素属性 表单很多情况,也需要修改属性,比如点击眼睛可以看到密码,本质是把表单类型转换为文本框正常的有属性有取值的,跟其他的标签属性没有任何区别 获取:DOM对象.属性名 设置:DOM对象.属性名…