optimizer.load_state_dict()报错parameter group不匹配的问题的原因

在加载预训练权重时可能会遇到类似下面的错误:

optimizer.load_state_dict(checkpoint['optimizer_state'])
  File "/opt/conda/lib/python3.8/site-packages/torch/optim/optimizer.py", line 145, in load_state_dict
    raise ValueError("loaded state dict contains a parameter group "
ValueError: loaded state dict contains a parameter group that doesn't match the size of optimizer's group

遇到这个问题时你去网上看,一般都是泛泛的说原因是因为模型的参数和优化器的参数不匹配,看完保证你还是一头雾水,一般遇到这样情况我干脆直接去翻看torch/optim/optimizer.py里的源码比看网上七嘴八舌甚至胡说八道的瞎说好,例如我使用的pytorch的optimizer.py的源码是这样的:

class Optimizer:r"""Base class for all optimizers... warning::Parameters need to be specified as collections that have a deterministicordering that is consistent between runs. Examples of objects that don'tsatisfy those properties are sets and iterators over values of dictionaries.Args:params (iterable): an iterable of :class:`torch.Tensor` s or:class:`dict` s. Specifies what Tensors should be optimized.defaults: (dict): a dict containing default values of optimizationoptions (used when a parameter group doesn't specify them)."""def __init__(self, params, defaults):torch._C._log_api_usage_once("python.optimizer")self.defaults = defaultsself._optimizer_step_pre_hooks: Dict[int, Callable] = OrderedDict()self._optimizer_step_post_hooks: Dict[int, Callable] = OrderedDict()self._patch_step_function()if isinstance(params, torch.Tensor):raise TypeError("params argument given to the optimizer should be ""an iterable of Tensors or dicts, but got " +torch.typename(params))self.state = defaultdict(dict)self.param_groups = []param_groups = list(params)if len(param_groups) == 0:raise ValueError("optimizer got an empty parameter list")if not isinstance(param_groups[0], dict):param_groups = [{'params': param_groups}]for param_group in param_groups:self.add_param_group(param_group)# Allows _cuda_graph_capture_health_check to rig a poor man's TORCH_WARN_ONCE in python,# which I don't think exists# https://github.com/pytorch/pytorch/issues/72948self._warned_capturable_if_run_uncaptured = True...def load_state_dict(self, state_dict):r"""Loads the optimizer state.Args:state_dict (dict): optimizer state. Should be an object returnedfrom a call to :meth:`state_dict`."""# deepcopy, to be consistent with module APIstate_dict = deepcopy(state_dict)# Validate the state_dictgroups = self.param_groupssaved_groups = state_dict['param_groups']if len(groups) != len(saved_groups):raise ValueError("loaded state dict has a different number of ""parameter groups")param_lens = (len(g['params']) for g in groups)saved_lens = (len(g['params']) for g in saved_groups)if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):raise ValueError("loaded state dict contains a parameter group ""that doesn't match the size of optimizer's group")# Update the stateid_map = {old_id: p for old_id, p inzip(chain.from_iterable((g['params'] for g in saved_groups)),chain.from_iterable((g['params'] for g in groups)))}...

可以看到Opimizer类的param_groups是list类型,里面的每个元素是个dict,dict里面至少有params这个key,load_state_dict()里检查目前模型的optimizer的param_groups里的元素个数和预训练权重里读取到的optimizer的param_groups里的元数个数必须一致,并且两个param_groups的对应dict类型的元素里的params key对应的参数tensor(这些参数一般都是模型网络层次里的参数,也就是torch.optim.Adam(model.parameters(), lr=0.1)这样的语句创建optimizer实例时传入的model.parameters(),至于dict里保存的其他参数,例如key是lr时对应的值是学习率超参数,以及和optimizer相关的可学习参数,例如SGD的momentum、Adam的betas等参数)的长度也必须一致!

一般来说,如果你模型训练使用的Optimizer和你要加载的预训练权重保存时的Optimizer一致的话,跟Optimizer本身相关的超参数和可学习参数的个数不会有不同,如果还报上面的错误,那说明你的模型的网络结构和导出预训练权重的模型的网络结构不一致,例如保存预训练权重时的网络结构里有检测头也有分割头,而你们目前要加载权重的模型网络里只有检测头,就会触发上面的错误,要么保持网络结构的一致,要么采用类似下面的办法把预训练参数权重里目前网络结构和Optimizer需要的读取出来保存为一个新的文件,然后再调用load_state_dict()加载即可

net = new_model()
pretrained_weights = torch.load('pretrained_weights.pth')
new_model_dict = net.state_dict()
state_dict = {k:v for k,v in pretrained_weights.items() if k in new_model_dict.keys()}
new_model_dict.update(state_dict)
net.load_state_dict(new_model_dict)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/586607.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Visual Studio Markdown Editor 插件导出 HTML

Visual Studio Markdown Editor 可通过右键弹出菜单选择“另存为”,轻松导出HTML文件或是单一的mhtml文件。然而,这个插件在导出HTML文件时有一个小问题,就是md文件中的一些内部链接无法在导出的HTML文件中正常工作。 其中的原因是&#xff…

Docker单点部署Seata(2.0.0) + Nacos(v2.3.0) + Mysql(5.7)

文章目录 一、部署Nacos二、部署Mysql三、Seata准备工作1. 记住nacos、mysql、宿主机的ip2. 建立数据库3. Nacos远程配置文件 四、部署Seata五、初步检验Seata部署情况六、微服务使用Seata1.引入依赖2. application.yml配置 七、遇到的坑1. Nacos显示Seata服务的ip为容器内网ip…

使用SpringBoot AOP记录操作日志和异常日志

使用SpringBoot AOP记录操作日志和异常日志 平时我们在做项目时经常需要对一些重要功能操作记录日志,方便以后跟踪是谁在操作此功能;我们在操作某些功 能时也有可能会发生异常,但是每次发生异常要定位原因我们都要到服务器去查询日志才能找…

第3课 获取并播放音频流

本课对应源文件下载链接: https://download.csdn.net/download/XiBuQiuChong/88680079 FFmpeg作为一套庞大的音视频处理开源工具,其源码有太多值得研究的地方。但对于大多数初学者而言,如何快速利用相关的API写出自己想要的东西才是迫切需要…

【机器学习】卷积神经网络(一)

一、网络结构 典型CNN结构 卷积神经网络是一种能够从图像、声音或其他类型的数据中学习特征的人工智能模型。你可以把它想象成一个有很多层的过滤器,每一层都能够提取出数据中的一些有用的信息,比如边缘、形状、颜色、纹理等。这些信息可以帮助卷积神经网…

ES高级用法:DeleteByQueryRequest

背景 在Elasticsearch中,delete_by_query API 允许你基于查询条件删除文档。在Java中,你可以使用Elasticsearch的Rest High Level Client或者Transport Client来执行这个操作。 示例代码 下面是使用Rest High Level Client进行delete_by_query操作的一…

【Matlab】ELM极限学习机时序预测算法

资源下载: https://download.csdn.net/download/vvoennvv/88681649 一,概述 ELM(Extreme Learning Machine)是一种单层前馈神经网络结构,与传统神经网络不同的是,ELM的隐层神经元权重以及偏置都是随机产生的…

【Android12】Android Framework系列---tombstone墓碑生成机制

tombstone墓碑生成机制 Android中程序在运行时会遇到各种各样的问题,相应的就会产生各种异常信号,比如常见的异常信号 Singal 11:Segmentation fault表示无效的地址进行了操作,比如内存越界、空指针调用等。 Android中在进程(主要…

Apache-ActiveMQ 反序列化漏洞(CVE-2015-5254)复现

CVE-2016-3088 一、环境搭建 Java:jdk8 影响版本 Apache ActiveMQ < 5.13.0 二、用docker搭建漏洞环境 访问一下web界面 然后进入admin目录登录 账号:admin 密码:admin 三、工具准备 cd /opt wget https://github.com/matthiaskaiser/jmet/releases/download/0.1.0/jmet-0…

QT上位机开发(第一个应用)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 不管是软件&#xff0c;还是硬件&#xff0c;如果我们能够顺利启动第一个应用&#xff0c;点亮第一个电路的话&#xff0c;这对我们的信心来说会有…

如何恢复 iPhone 上永久删除的照片?

2007年&#xff0c;苹果公司推出了一款惊天动地的智能手机&#xff0c;也就是后来的iPhone。你会惊讶地发现&#xff0c;迄今为止&#xff0c;苹果公司已经售出了 7 亿部 iPhone 设备。根据最新一项调查数据&#xff0c;智能手机利润的 95% 都进了苹果公司的腰包。 如此受欢迎…

多态的底层实现原理和泛型的底层实现原理

Java 多态的底层原理 - 知乎 (zhihu.com) 使用的是动态绑定&#xff0c;在调用这个方法的时候先去找实例的类&#xff0c;看是否有权限访问&#xff0c;并且看是否实现了该方法&#xff0c;没有的话就去父类找&#xff0c;为了提升效率&#xff0c;虚拟机不会每次都一层一层的…

【Vue2+3入门到实战】(16)VUEVue路由的重定向、404、编程式导航、path路径跳转传参 详细代码示例

目录 一、Vue路由-重定向1.问题2.解决方案3.语法4.代码演示 二、Vue路由-4041.作用2.位置3.语法4.代码示例 三、Vue路由-模式设置1.问题2.语法 四、编程式导航-两种路由跳转方式1.问题2.方案3.语法4.path路径跳转语法5.代码演示 path跳转方式6.name命名路由跳转7.代码演示通过n…

2023十大编程语言及未来展望

2023十大编程语言及未来展望 1. 2023年十大编程语言排行榜2. 十大编程语言未来展望PythonCCJavaC#JavaScriptPHPVisual BasicSQLAssembly language 1. 2023年十大编程语言排行榜 TIOBE排行榜是根据互联网上有经验的程序员、课程和第三方厂商的数量&#xff0c;并使用搜索引擎&a…

阿里云PolarDB数据库优惠价格表11元一天起

阿里云数据库PolarDB租用价格表&#xff0c;云数据库PolarDB MySQL版2核4GB&#xff08;通用&#xff09;、2个节点、60 GB存储空间55元5天&#xff0c;云数据库 PolarDB 分布式版标准版2核16G&#xff08;通用&#xff09;57.6元3天&#xff0c;阿里云百科aliyunbaike.com分享…

数据挖掘 聚类度量

格式化之前的代码&#xff1a; import numpy as np#计算 import pandas as pd#处理结构化表格 import matplotlib.pyplot as plt#绘制图表和可视化数据的函数&#xff0c;通常与numpy和pandas一起使用。 from sklearn import metrics#聚类算法的评估指标。 from sklearn.clust…

ansible管理windows测试

一、环境介绍 Ansible管理主机&#xff1a; 系统: redhat7.6 Linux管理服务器需安装pywinrm插件 Windows客户端主机&#xff1a; 系统: Server2012R2 Windows机器需要安装或升级powershell4.0以上版本&#xff0c;Server2008R2默认的版本是2.0&#xff0c;因此必须升…

k8s学习 — (DevOps实践)第十四章 微服务 DevOps 实战

k8s学习 — &#xff08;DevOps实践&#xff09;第十四章 微服务 DevOps 实战 ※ 各章节重要知识点1 项目构建1.1 项目环境1.2 服务 2 Jenkins CICD2.1 创建流水线项目2.2 Extended Choice Parameter 3 Kubesphere DevOps ※ 各章节重要知识点 k8s学习 — 各章节重要知识点 1…

使用flutter开发windows桌面软件读取ACR22U设备的nfc卡片id,5分钟搞定demo

最近有个需求&#xff0c;要使用acr122u读卡器插入电脑usb口&#xff0c;然后读取nfc卡片的id&#xff0c;并和用户账号绑定&#xff0c;调研了很多方式&#xff0c;之前使用rust实现过一次&#xff0c;还有go实现过一次&#xff0c;然后使用electron的时候遇到安装pcsc-lite失…

MacBook查看本机IP

嘚吧嘚 其实这也不是什么困难的问题&#xff0c;但是今年刚刚入坑Mac&#xff0c;外加用的频率不是很高&#xff0c;每次使用的时候都查&#xff0c;用完就忘&#xff0c;下次用的时候再查&#x1f92e;。真的把自己恶心坏了&#x1f648;。 所以写篇文章记录一下&#x1f92…