PyTorch中加载模型权重 A匹配B|A不匹配B

在做深度学习项目时,从头训练一个模型是需要大量时间和算力的,我们通常采用加载预训练权重的方法,而我们往往面临以下几种情况:
在这里插入图片描述

未修改网络,A与B一致

很简单,直接.load_state_dict()

net = ANet(num_classses = 5,init_weights=True)
net.to(device)
net.load_state_dict(torch.load('weight/B_weight.pth'))

修改了网络,A与B不一致

[pytorch官方文档](Search — PyTorch master documentation):

load_state_dict(state_dict, strict=True)

将 state_dict 中的参数和缓冲区复制到此模块及其后代中。如果 strict 为 True,则 state_dict 的键必须与该模块的 state_dict() 函数返回的键完全匹配。

state_dict是包含参数和持久缓冲区的字典,可以看出 strict默认为True,所以默认状态下是严格要求state_dict中的key与torch.nn.Module.state_dict返回的key完全一致的

load_state_dict()函数有两个返回值:

missing_keys 是包含缺失键的 str 列表
unexpected_keys 是包含意外键的 str 列表

方法一:

将strict改为false,加载键值相同的部分。

model = NET2()
state_dict = model.state_dict()
weights = torch.load(weights_path)['model_state_dict']	#读取预训练模型权重
model.load_state_dict(weights, strict=False)	#strict

但是此时还存在一种情况:键值相同但shape不同,故应进行if…in…的判断:

ANet = torch.load('ANet.pt')  # 加载预训练权重模型(.pt文件)参数
#现成的模型的话,如resnet50 = models.resnet50(pretrained=True)
#采用:pretrained_dict = resnet50().state_dict()  
model = Model() # 创建模型
model_dict = model.state_dict() # 得到模型的参数字典# 判断预训练模型中网络的模块是否修改后的网络中也存在,并且shape相同,如果相同则取出
pretrained_dict = {k: v for k, v in ANet.items() if k in model_dict and (v.shape == model_dict[k].shape)}# 更新修改之后的 model_dict
model_dict.update(pretrained_dict)# 加载我们真正需要的 state_dict
model.load_state_dict(model_dict, strict=False)

方法二:

1.将权重导入原模型,之后在加载后的原模型基础上进行修改。
2.修改权重文件参数,再进行导入
适用于改动不大的模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/24504.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TypeScript技能总结(三)

typescript是js的超集,目前很多前端框架都开始使用它来作为项目的维护管理的工具,还在不断地更新,添加新功能中,我们学习它,才能更好的在的项目中运用它,发挥它的最大功效 //泛型 > 参数和返回值类型相…

Java课设--学生信息管理系统(例1)

文章目录 前提一、运行效果二、Text实现类三、Manage选择类四、StudentWay学生方法类五、StudnetSql数据库类 前题 例1为无使用GUI图形界面,例2使用GUI图形界面! 首先自己的JDBC驱动已经接好了,连接自己的数据库没有问题。连接数据库可以看…

《吐血整理》高级系列教程-吃透Fiddler抓包教程(33)-Fiddler如何抓取WebSocket数据包

1.简介 本来打算再写一篇这个系列的文章也要和小伙伴或者童鞋们说再见了,可是有人留言问WebSocket包和小程序的包不会抓,那就关于这两个知识点宏哥就再水两篇文章。 2.什么是Socket? 在计算机通信领域,socket 被翻译为“套接字…

物联网||不一样的点灯实验(2)|通过使用CMSIS库函数实现点灯实验-学习笔记(12)

文章目录 通过使用CMSIS库函数实现点灯实验1 如何使用CMIS库2 如何利用CMSIS库操作IO 两种实现方法的比较课后作业:完整代码:LED.C:test.c:led.h:systick.h:systick.c: 通过使用CMSIS库函数实现点灯实验 1 如何使用CMIS库 #####如何使用此驱动#####[. .](#)启用GPI…

langchain-ChatGLM源码阅读:参数设置

文章目录 上下文关联对话轮数向量匹配 top k控制生成质量的参数参数设置心得 上下文关联 上下文关联相关参数: 知识相关度阈值score_threshold内容条数k是否启用上下文关联chunk_conent上下文最大长度chunk_size 其主要作用是在所在文档中扩展与当前query相似度较高…

RichTextBox基本用法

作用:富文本编辑器,支持多种文本展示 常用属性: 允许显示多行 自动换行 展示滚动条 ScrollBars属性值: 1、Both:只有当文本超过RichTextBox的宽度或长度时,才显示水平滚动条或垂直滚动条,或两…

基于粒子群优化算法的配电网光伏储能双层优化配置模型[IEEE33节点](选址定容)(Matlab代码实现)

目录 💥1 概述 📚2 运行结果 🎉3 参考文献 🌈4 Matlab代码、数据、讲解 💥1 概述 由于能源的日益匮乏,电力需求的不断增长等,配电网中分布式能源渗透率不断提高,且逐渐向主动配电网方…

掌握Java排序算法:实现主流排序方法与性能对比

一,C语言,主流的排序方法介绍 当谈论主流的排序方法时,通常指的是在实际应用中表现优秀且被广泛采用的排序算法。以下是常见的主流排序方法及其介绍、时间复杂度、空间复杂度和简单的C语言代码实现: 冒泡排序(Bubble S…

nginx使用

1 安装 yum -y install gcc pcre-devel zlib-devel openssl openssl-devel yum install -y wget wget https://nginx.org/download/nginx-1.16.1.tar.gz tar -zxvf nginx-1.16.1.tar.gz cd nginx-1.16.1 ./configure --prefix/usr/local/nginx make make install2 目录 目录说…

【雕爷学编程】Arduino动手做(184)---快餐盒盖,极低成本搭建机器人实验平台2

吃完快餐粥,除了粥的味道不错之外,我对个快餐盒的圆盖子产生了兴趣,能否做个极低成本的简易机器人呢?也许只需要二十元左右 知识点:轮子(wheel) 中国词语。是用不同材料制成的圆形滚动物体。简…

Mac端口扫描工具

端口扫描工具 Mac内置了一个网络工具 网络使用工具 按住 Command 空格 然后搜索 “网络实用工具” 或 “Network Utility” 即可 域名/ip转换Lookup ping功能 端口扫描 https://zhhll.icu/2022/Mac/端口扫描工具/ 本文由 mdnice 多平台发布

C++中的常量介绍

C中的常量介绍 在C中,常量是一个固定的值,它在程序执行期间不会发生改变。常量可以分为: 1.字面常量 2.符号常量:符号常量是通过标识符来表示的常量值,在程序中使用时要先进行定义。使用符号常量的好处是可以给常量起…

JUnit教程_编程入门自学教程_菜鸟教程-免费教程分享

教程简介 JUnit是一个Java语言的单元测试框架。它由Kent Beck和Erich Gamma建立,逐渐成为源于Kent Beck的sUnit的xUnit家族中最为成功的一个。 JUnit有它自己的JUnit扩展生态圈。多数Java的开发环境都已经集成了JUnit作为单元测试的工具。JUnit是由 Erich Gamma 和…

【具生智能】前沿思考与总结(DALL-E-Bot TinyBot)

1. DALL-E-Bot DALL-E-Bot: Introducing Web-Scale Diffusion Models to Robotics (robot-learning.uk) **(2023-05-04)**DALL-E-Bot: Introducing Web-Scale Diffusion Models to Robotics DALL-E-Bot:将网络规模的扩散模型引入机器人 第…

C语言----动态内存分配(malloc calloc relloc free)超全知识点

目录 一.动态内存函数 1.malloc 2.free 3.calloc 4.malloc和calloc的区别 5.realloc 二.动态内存分配的常见错误 1.对null进行解引用操作 2.对动态开辟空间的越界访问 3.对非动态开辟内存使用free释放 4.使用free释放动态开辟内存的一部分 5.对同一块动态内存多次…

物联网|按键实验---学习I/O的输入及中断的编程|读取I/O的输入信号|中断的编程方法|轮询实现按键捕获实验-学习笔记(13)

文章目录 实验目的了解擒键的工作原理及电原理图 STM32F407中如何读取I/O的输入信号STM32F407对中断的编程方法通过轮询实现按键捕获实验如何利用已有内工程创建新工程通过轮询实现按键捕获代码实现及分析1 代码的流程分析2 代码的实现 Tips:下载错误的解决 实验目的 了解擒键…

Leetcode-每日一题【剑指 Offer 09. 用两个栈实现队列】

题目 用两个栈实现一个队列。队列的声明如下,请实现它的两个函数 appendTail 和 deleteHead ,分别完成在队列尾部插入整数和在队列头部删除整数的功能。(若队列中没有元素,deleteHead 操作返回 -1 ) 示例 1: 输入: [&…

如何快速做单元测试?

首先写unit test之前,要确认自己的测试遵循两个原则: 1、尽量不要干涉原来的代码。从阅读代码的体验来说,不要让你的测试(哪怕是一小段if..else...的代码)出现在你准备测试的代码中。 2、代码要只是测试某个class里面…

springboot第34集:ES 搜索,nginx

#用search after解决深分页性能问题 #第一页 GET /bank/_search {"size": 10,"sort": [{"account_number": {"order": "asc"}}] }#第二页 GET /bank/_search {"size": 10,"sort": [{"account_numb…

【WEB逆向】前端全报文加密的分析技巧

由于前端全报文加密,无法从变量的全文搜索来快速定位加密函数对加密参数的定位(全局搜索还有个弊病是编码混淆的js也不能全局搜到,需要进一步分析判定混淆的编码形式后再全局搜编码后的变量名),因此可利用xhr断点全局拦…