pytorch load state dict_PyTorch 学习笔记(五):Finetune和各层定制学习率

本文截取自《PyTorch 模型训练实用教程》,获取全文pdf请点击:https://github.com/tensor-yu/PyTorch_Tutorial

@[toc]

我们知道一个良好的权值初始化,可以使收敛速度加快,甚至可以获得更好的精度。而在实际应用中,我们通常采用一个已经训练模型的模型的权值参数作为我们模型的初始化参数,也称之为Finetune,更宽泛的称之为迁移学习。迁移学习中的Finetune技术,本质上就是让我们新构建的模型,拥有一个较好的权值初始值。

finetune权值初始化三步曲,finetune就相当于给模型进行初始化,其流程共用三步:

第一步:保存模型,拥有一个预训练模型; 第二步:加载模型,把预训练模型中的权值取出来; 第三步:初始化,将权值对应的“放”到新模型中

一、Finetune之权值初始化

在进行finetune之前我们需要拥有一个模型或者是模型参数,因此需要了解如何保存模型。官方文档中介绍了两种保存模型的方法,一种是保存整个模型,另外一种是仅保存模型参数(官方推荐用这种方法),这里采用官方推荐的方法。

第一步:保存模型参数

若拥有模型参数,可跳过这一步。
假设创建了一个net = Net(),并且经过训练,通过以下方式保存:
torch.save(net.state_dict(), 'net_params.pkl')

第二步:加载模型

进行三步曲中的第二步,加载模型,这里只是加载模型的参数:
pretrained_dict = torch.load('net_params.pkl')

第三步:初始化

进行三步曲中的第三步,将取到的权值,对应的放到新模型中:
首先我们创建新模型,并且获取新模型的参数字典net_state_dict:
net = Net() # 创建net
net_state_dict = net.state_dict() # 获取已创建net的state_dict接着将pretrained_dict里不属于net_state_dict的键剔除掉:
pretrained_dict_1 =  {k: v for k, v in pretrained_dict.items() if k in net_state_dict}然后,用预训练模型的参数字典 对 新模型的参数字典net_state_dict 进行更新:
net_state_dict.update(pretrained_dict_1)最后,将更新了参数的字典 “放”回到网络中:
net.load_state_dict(net_state_dict)

这样,利用预训练模型参数对新模型的权值进行初始化过程就做完了。

采用finetune的训练过程中,有时候希望前面层的学习率低一些,改变不要太大,而后面的全连接层的学习率相对大一些。这时就需要对不同的层设置不同的学习率,下面就介绍如何为不同层配置不同的学习率。

二、不同层设置不同的学习率

在利用pre-trained model的参数做初始化之后,我们可能想让fc层更新相对快一些,而希望前面的权值更新小一些,这就可以通过为不同的层设置不同的学习率来达到此目的。

为不同层设置不同的学习率,主要通过优化器对多个参数组进行设置不同的参数。所以,只需要将原始的参数组,划分成两个,甚至更多的参数组,然后分别进行设置学习率。 这里将原始参数“切分”成fc3层参数和其余参数,为fc3层设置更大的学习率。

请看代码:

ignored_params = list(map(id, net.fc3.parameters())) # 返回的是parameters的 内存地址
base_params = filter(lambda p: id(p) not in ignored_params, net.parameters()) 
optimizer = optim.SGD([
{'params': base_params},
{'params': net.fc3.parameters(), 'lr': 0.001*10}], 0.001, momentum=0.9, weight_decay=1e-4)

第一行+ 第二行的意思就是,将fc3层的参数net.fc3.parameters()从原始参数net.parameters()中剥离出来 base_params就是剥离了fc3层的参数的其余参数,然后在优化器中为fc3层的参数单独设定学习率。

optimizer = optim.SGD(......)这里的意思就是 base_params中的层,用 0.001, momentum=0.9, weight_decay=1e-4 fc3层设定学习率为: 0.001*10

完整代码位于 https://github.com/tensor-yu/PyTorch_Tutorial/blob/master/Code/2_model/2_finetune.py

补充:

挑选出特定的层的机制是利用内存地址作为过滤条件,将需要单独设定的那部分参数,从总的参数中剔除。 base_params 是一个list,每个元素是一个Parameter 类 net.fc3.parameters() 是一个

ignored_params = list(map(id, net.fc3.parameters())) net.fc3.parameters() 是一个 所以迭代的返回其中的parameter,这里有weight 和 bias 最终返回weight和bias所在内存的地址


转载请注明出处:https://blog.csdn.net/u011995719/article/details/85107310

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/336220.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

华为配备鸿蒙系统的手机,华为P50/新平板双双来袭!全球首发鸿蒙系统:配置都非常强悍...

【12月12日讯】相信大家都知道,华为方面已经正式官宣,将会在12月16日正式推出鸿蒙系统首个手机Bate版本,但也有很多网友们担忧,华为手机在脱离了Android系统以后,鸿蒙OS系统是否真的可以击败Android系统,第…

【WebRTC---入门篇】(十八)WebRTC非音视频数据传输

WebRTC传输非音视频重要API createDataChannel options ordered 在传输非音视频的时候是否是按序到达的。 maxPacketLifeTime/maxRetransmits 最大包存活时间;最大传输次数。两者二选一 negotiated ID 唯一标识 DataChannel事件

ios 静音模式_静音设计模式

ios 静音模式您最近是否遵循Mute-Design-Pattern™编写了大量代码? 例如 try {complex();logic();here(); } catch (Exception ignore) {// Will never happen heheSystem.exit(-1); }Java 8有一个更简单的方法! 只需将这个非常有用的工具添加到您的Ut…

datatable使用_使用Streamlit从简单的Python脚本创建交互式WebApp

如果有人告诉您可以使用150-200行代码创建交互式Web应用程序,该怎么办? 有趣的权利。 Streamlit为您提供了使用简单的python脚本和一些streamlit调用来创建漂亮的Web应用程序的相同机会。Streamlit是一个开放源代码框架,用于以最快的方式创建…

和谐 平等_平等还是认同?

和谐 平等将对象存储在集合中时,同一对象永远不能添加两次很重要。 这是集合的核心定义。 在Java中,使用两种方法来确定两个引用的对象是否相同,或者它们都可以存在于同一Set中。 equals()和hashCode(&…

html监控用户在线与离线,HTML5判断设备在线离线及监听网络状态变化例子

经测试android ipad默认的浏览器支持,用appcan封装的网页也支持html>网络在线与离线$$function(id){return document.getElementById(id);};if(navigator.onLine){$$("status").innerHTML"第一次加载时在线";}else{$$("status").i…

opengl如何画出一个球_OpenGL-Controlling and Monitoring the Pipeline

全球图形学领域教育的领先者、自研引擎的倡导者、底层技术研究领域的技术公开者,东汉书院在致力于使得更多人群具备内核级竞争力的道路上,将带给小伙伴们更多的公开技术教学和视频,感谢一路以来有你的支持。我们正在用实际行动来帮助小伙伴们…

【WebRTC---入门篇】(二十)WebRTC核心之SDP详解

SDK规范 会话层 媒体层 SDP规范相关参考 WebRTC中的SDP

junit5和junit4_JUnit 5 –条件

junit5和junit4最近,我们了解了JUnit的新扩展模型以及它如何使我们能够将自定义行为注入测试引擎。 我向你保证要看情况。 现在就开始吧! 条件允许我们在应该执行或不应该执行测试时定义灵活的标准。 它们的正式名称是“ 条件测试执行” 。 总览 本系列…

android 画圆教程,android shap画圆(空心圆、实心圆)

实心圆:android:shape"oval"android:useLevel"false">android:width"1dp"android:color"color/colorWhite" />android:width"10dp"android:height"10dp" />空心圆:android:shape&…

python opencv输出mp4_10分钟学会使用YOLO及Opencv实现目标检测

点击边框调出视频工具条 计算机视觉领域中,目标检测一直是工业应用上比较热门且成熟的应用领域,比如人脸识别、行人检测等,国内的旷视科技、商汤科技等公司在该领域占据行业领先地位。相对于图像分类任务而言,目标检测会更加复杂一…

【开源项目】向Nginx-RTMP服务器推流

Nginx-RTMP服务器搭建 Nginx下载 Nginx-RTMP模块 先使用root用户,首先安装GCC ; G ;make; libssl ;libpcre3-dev ;zlib1g-dev sudo apt-get install libssl-dev sudo apt-get install libpcre3 libpcre3-dev sudo apt-get install openssl libssl-dev sudo …

捍卫者usb管理控制系统_捍卫Java

捍卫者usb管理控制系统因此,我们不时发布了一本电子书,名为“十大Java性能问题” 。 毫无例外,一些人回答了一些“问题是您正在使用Java”。 显然,Java一直在受到批评,人们已经预测了它的消亡已有一段时间了。 当然&a…

html怎么上传qq空间,qq空间怎么上传照片

当我们想要把照片上传到qq空间里,应该怎么办呢?下面就让学习啦小编告诉你空间上传照片的方法,希望对大家有所帮助。空间上传照片的方法打开QQ主界面,在主界面头像的右则有个小星星,那就是进入空间的快捷方式,点一下小…

android gridview控件使用详解_Android开发实现自定义日历、日期选择控件

点击上方蓝字关注 ??来源: wenzhihao123https://www.jianshu.com/p/a2f102c728ce前言最近项目需要日历效果,考虑用第三方的反而不太适合设计需求,修改复杂,与其这样不入自己重新写一个干净的控件。虽不是什么牛逼控件&#xff0…

LeetCode 225. 用队列实现栈

算法 (队列,栈) O(n) 我们用一个队列来存储栈中元素。对于栈中的四种操作: push(x) – 直接入队; pop() – 即需要弹出队尾元素。我们先将队首元素弹出并插入队尾,循环 n−1次,n是队列长度。此时队尾元素已经在队首…

java jinq_将JINQ与JPA和H2一起使用

java jinq几天前,我读了Iu Ming-Yee对JINQ的有趣采访 。 顾名思义,JINQ是一种尝试提供类似于LINQ for Java的尝试。 基本思想是缩小在关系数据模型上执行查询的面向对象代码之间的语义鸿沟。 关系数据库模型的查询应轻松集成到代码中,以使其感…

HTML设置字体颜色1008无标题,如何在HTML中设置字体颜色,你知道这几种方式吗?...

color设置字体颜色在color设置字体颜色之前,我们首先了解color在css中有几种取值方式,一共有4种方式,若有不全还请在评论区告知谢谢,4种方式如下:十六进制、十进制、 英文单词、十六进制的缩写。现在让我们进入字体颜色…