PyTorch中加载模型权重 A匹配B|A不匹配B

在做深度学习项目时,从头训练一个模型是需要大量时间和算力的,我们通常采用加载预训练权重的方法,而我们往往面临以下几种情况:
在这里插入图片描述

未修改网络,A与B一致

很简单,直接.load_state_dict()

net = ANet(num_classses = 5,init_weights=True)
net.to(device)
net.load_state_dict(torch.load('weight/B_weight.pth'))

修改了网络,A与B不一致

[pytorch官方文档](Search — PyTorch master documentation):

load_state_dict(state_dict, strict=True)

将 state_dict 中的参数和缓冲区复制到此模块及其后代中。如果 strict 为 True,则 state_dict 的键必须与该模块的 state_dict() 函数返回的键完全匹配。

state_dict是包含参数和持久缓冲区的字典,可以看出 strict默认为True,所以默认状态下是严格要求state_dict中的key与torch.nn.Module.state_dict返回的key完全一致的

load_state_dict()函数有两个返回值:

missing_keys 是包含缺失键的 str 列表
unexpected_keys 是包含意外键的 str 列表

方法一:

将strict改为false,加载键值相同的部分。

model = NET2()
state_dict = model.state_dict()
weights = torch.load(weights_path)['model_state_dict']	#读取预训练模型权重
model.load_state_dict(weights, strict=False)	#strict

但是此时还存在一种情况:键值相同但shape不同,故应进行if…in…的判断:

ANet = torch.load('ANet.pt')  # 加载预训练权重模型(.pt文件)参数
#现成的模型的话,如resnet50 = models.resnet50(pretrained=True)
#采用:pretrained_dict = resnet50().state_dict()  
model = Model() # 创建模型
model_dict = model.state_dict() # 得到模型的参数字典# 判断预训练模型中网络的模块是否修改后的网络中也存在,并且shape相同,如果相同则取出
pretrained_dict = {k: v for k, v in ANet.items() if k in model_dict and (v.shape == model_dict[k].shape)}# 更新修改之后的 model_dict
model_dict.update(pretrained_dict)# 加载我们真正需要的 state_dict
model.load_state_dict(model_dict, strict=False)

方法二:

1.将权重导入原模型,之后在加载后的原模型基础上进行修改。
2.修改权重文件参数,再进行导入
适用于改动不大的模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/24504.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java课设--学生信息管理系统(例1)

文章目录 前提一、运行效果二、Text实现类三、Manage选择类四、StudentWay学生方法类五、StudnetSql数据库类 前题 例1为无使用GUI图形界面,例2使用GUI图形界面! 首先自己的JDBC驱动已经接好了,连接自己的数据库没有问题。连接数据库可以看…

《吐血整理》高级系列教程-吃透Fiddler抓包教程(33)-Fiddler如何抓取WebSocket数据包

1.简介 本来打算再写一篇这个系列的文章也要和小伙伴或者童鞋们说再见了,可是有人留言问WebSocket包和小程序的包不会抓,那就关于这两个知识点宏哥就再水两篇文章。 2.什么是Socket? 在计算机通信领域,socket 被翻译为“套接字…

物联网||不一样的点灯实验(2)|通过使用CMSIS库函数实现点灯实验-学习笔记(12)

文章目录 通过使用CMSIS库函数实现点灯实验1 如何使用CMIS库2 如何利用CMSIS库操作IO 两种实现方法的比较课后作业:完整代码:LED.C:test.c:led.h:systick.h:systick.c: 通过使用CMSIS库函数实现点灯实验 1 如何使用CMIS库 #####如何使用此驱动#####[. .](#)启用GPI…

langchain-ChatGLM源码阅读:参数设置

文章目录 上下文关联对话轮数向量匹配 top k控制生成质量的参数参数设置心得 上下文关联 上下文关联相关参数: 知识相关度阈值score_threshold内容条数k是否启用上下文关联chunk_conent上下文最大长度chunk_size 其主要作用是在所在文档中扩展与当前query相似度较高…

RichTextBox基本用法

作用:富文本编辑器,支持多种文本展示 常用属性: 允许显示多行 自动换行 展示滚动条 ScrollBars属性值: 1、Both:只有当文本超过RichTextBox的宽度或长度时,才显示水平滚动条或垂直滚动条,或两…

基于粒子群优化算法的配电网光伏储能双层优化配置模型[IEEE33节点](选址定容)(Matlab代码实现)

目录 💥1 概述 📚2 运行结果 🎉3 参考文献 🌈4 Matlab代码、数据、讲解 💥1 概述 由于能源的日益匮乏,电力需求的不断增长等,配电网中分布式能源渗透率不断提高,且逐渐向主动配电网方…

【雕爷学编程】Arduino动手做(184)---快餐盒盖,极低成本搭建机器人实验平台2

吃完快餐粥,除了粥的味道不错之外,我对个快餐盒的圆盖子产生了兴趣,能否做个极低成本的简易机器人呢?也许只需要二十元左右 知识点:轮子(wheel) 中国词语。是用不同材料制成的圆形滚动物体。简…

Mac端口扫描工具

端口扫描工具 Mac内置了一个网络工具 网络使用工具 按住 Command 空格 然后搜索 “网络实用工具” 或 “Network Utility” 即可 域名/ip转换Lookup ping功能 端口扫描 https://zhhll.icu/2022/Mac/端口扫描工具/ 本文由 mdnice 多平台发布

【具生智能】前沿思考与总结(DALL-E-Bot TinyBot)

1. DALL-E-Bot DALL-E-Bot: Introducing Web-Scale Diffusion Models to Robotics (robot-learning.uk) **(2023-05-04)**DALL-E-Bot: Introducing Web-Scale Diffusion Models to Robotics DALL-E-Bot:将网络规模的扩散模型引入机器人 第…

C语言----动态内存分配(malloc calloc relloc free)超全知识点

目录 一.动态内存函数 1.malloc 2.free 3.calloc 4.malloc和calloc的区别 5.realloc 二.动态内存分配的常见错误 1.对null进行解引用操作 2.对动态开辟空间的越界访问 3.对非动态开辟内存使用free释放 4.使用free释放动态开辟内存的一部分 5.对同一块动态内存多次…

物联网|按键实验---学习I/O的输入及中断的编程|读取I/O的输入信号|中断的编程方法|轮询实现按键捕获实验-学习笔记(13)

文章目录 实验目的了解擒键的工作原理及电原理图 STM32F407中如何读取I/O的输入信号STM32F407对中断的编程方法通过轮询实现按键捕获实验如何利用已有内工程创建新工程通过轮询实现按键捕获代码实现及分析1 代码的流程分析2 代码的实现 Tips:下载错误的解决 实验目的 了解擒键…

Leetcode-每日一题【剑指 Offer 09. 用两个栈实现队列】

题目 用两个栈实现一个队列。队列的声明如下,请实现它的两个函数 appendTail 和 deleteHead ,分别完成在队列尾部插入整数和在队列头部删除整数的功能。(若队列中没有元素,deleteHead 操作返回 -1 ) 示例 1: 输入: [&…

springboot第34集:ES 搜索,nginx

#用search after解决深分页性能问题 #第一页 GET /bank/_search {"size": 10,"sort": [{"account_number": {"order": "asc"}}] }#第二页 GET /bank/_search {"size": 10,"sort": [{"account_numb…

【WEB逆向】前端全报文加密的分析技巧

由于前端全报文加密,无法从变量的全文搜索来快速定位加密函数对加密参数的定位(全局搜索还有个弊病是编码混淆的js也不能全局搜到,需要进一步分析判定混淆的编码形式后再全局搜编码后的变量名),因此可利用xhr断点全局拦…

LLM reasoners 入门实验 24点游戏

LLM reasoners Ber666/llm-reasoners 实验过程 实验样例24games,examples/tot_game24,在inference.py中配置使用代理和open ai的api key。 首先安装依赖 git clone https://github.com/Ber666/llm-reasoners cd llm-reasoners pip install -e .然后…

【雕爷学编程】Arduino动手做(187)---1.3寸OLED液晶屏模块2

37款传感器与模块的提法,在网络上广泛流传,其实Arduino能够兼容的传感器模块肯定是不止37种的。鉴于本人手头积累了一些传感器和执行器模块,依照实践出真知(一定要动手做)的理念,以学习和交流为目的&#x…

Spring Security OAuth2.0(7):自定义认证连接数据库

自定义认证连接数据库 首先创建数据库和用户表 CREATE TABLE t_user (id bigint(20) NOT NULL AUTO_INCREMENT,username varchar(64) DEFAULT NULL,password varchar(64) DEFAULT NULL,fullname varchar(255) DEFAULT NULL,mobile varchar(20) DEFAULT NULL,PRIMARY KEY (id)…

MacOS使用brew如何下载Nginx

首先,第一步切换源: 切换 brew.git 仓库地址: cd "$(brew --repo)" git remote set-url origin https://mirrors.aliyun.com/homebrew/brew.git 替换 homebrew-core.git 仓库地址: cd "$(brew --repo)/Library/Taps/home…

LabVIEW 开发在不确定路况下自动速度辅助系统

LabVIEW 开发在不确定路况下自动速度辅助系统 智能驾驶辅助系统是汽车行业最先进的升级和尖端技术,智能交通系统依靠智能驾驶辅助系统在公共交通部门工作。该智能驾驶辅助系统技术包括自适应巡航控制,防抱死制动系统,安全气囊展开&#xff0…

【机器学习】编码、创造和筛选特征

在机器学习和数据科学领域中,特征工程是提取、转换和选择原始数据以创建更具信息价值的特征的过程。假设拿到一份数据集之后,如何逐步完成特征工程呢? 文章目录 一、特性类型分析1.1 数值型特征1.2 类别型特征1.3 时间型特征1.4 文本型特征1.…