pytorch模型加载caffe模型的权重

一、将caffe模型的权重转成dict格式

caffe库的编译可以参考我之前写的一篇博客:ImportError: dynamic module does not define module export function (PyInit__caffe)问题解决记录_chen_zn95的博客-CSDN博客

安装好后使用以下脚本便可将caffe模型的参数名和参数保存成dict, 

import pickle as pkl
import caffeMODEL_FILE = 'xxx.prototxt'
PRETRAIN_FILE = 'xxx.caffemodel'if __name__ == '__main__':net = caffe.Net(MODEL_FILE, PRETRAIN_FILE, caffe.TEST)name_weights = {}for param_name in net.params.keys():name_weights[param_name] = {}layer_params = net.params[param_name]if len(layer_params) == 1:weight = layer_params[0].dataname_weights[param_name]['weight'] = weightprint('%s:\n\t%s (weight)' % (param_name, weight.shape))elif len(layer_params) == 2:# weightweight = layer_params[0].dataname_weights[param_name]['weight'] = weight# biasbias = layer_params[1].dataname_weights[param_name]['bias'] = biasprint('%s:\n\t%s (weight)' % (param_name, weight.shape))print('\t%s (bias)' % str(bias.shape))elif len(layer_params) == 3:# BN: running_mean, running_var, scale factorrunning_mean = layer_params[0].data  # running_meanname_weights[param_name]['running_mean'] = running_mean / layer_params[2].datarunning_var = layer_params[1].data  # running_varname_weights[param_name]['running_var'] = running_var/layer_params[2].dataprint('%s:\n\t%s (running_var)' % (param_name, running_var.shape),)print('\t%s (running_mean)' % str(running_mean.shape))else:raise RuntimeError("error\n")# save weightwith open('weights.pkl', 'wb') as f:pkl.dump(name_weights, f, protocol=2)

二、pytorch模型加载dict格式的权重

这里有两个思路,一是根据权重名来匹配,二是根据权重的shape来匹配,但第二个方法有个问题,就是如果网络中有两个以上shape一样的层的话,那么根据权重的shape来匹配就会出错。下面分别介绍一下以上两个思路,

1、根据权重名匹配

这个方法比较繁琐,要求pytorch模型的参数名要与caffe模型的保持一致,如果不一致,则需要自己写个dict进行映射。具体操作如下,

import pickle as pkl
import torch
import copymodel = xxx
model1 = copy.deepcopy(model)state_dict = {}
with open("weights.pkl", "rb") as wp:  # weights.pkl: 步骤一中生成的dictname_weights = pkl.load(wp, encoding='iso-8859-1')for key, value in name_weights.items():for k, v in value.items():state_dict[key + "." + k] = torch.from_numpy(v)
model1.load_state_dict(state_dict, strict=True)

另一种实现是直接对pytorch模型的参数赋值,代码如下,

import pickle as pkl
import torch
import copymodel = xxx
model2 = copy.deepcopy(model)with open("weights.pkl", "rb") as wp:name_weights = pkl.load(wp, encoding='iso-8859-1')for name, param in model2.named_parameters():for key, value in name_weights.items():if name.split(".")[0] == key:for k, v in value.items():if name.split(".")[1] == k:param.data = torch.from_numpy(v)

2、根据权重shape匹配

import pickle as pkl
import torch
import copymodel = LightCNN_ir_eye()
model3 = copy.deepcopy(model)with open("weights.pkl", "rb") as wp:name_weights = pkl.load(wp, encoding='iso-8859-1')for name, param in model3.named_parameters():for key, value in name_weights.items():for k, v in value.items():v = torch.from_numpy(v)if param.data.shape == v.shape:if name == key + "." + k:  # 防止多个权重shape一致导致的错误param.data = v

3、检查以上模型初始化方法是否正确

import cv2
import numpy as np
import torchimg = cv2.imread("xxx.jpg")
img = cv2.resize(img, (width, height))
img = np.tranpose(img, (2,0,1))
img = np.expand_dims(img, axis=0)out1 = model1(torch.from_numpy(img).float())
out2 = model2(torch.from_numpy(img).float())
out3 = model3(torch.from_numpy(img).float())print(out1)
print(out2)
print(out3)
for i in range(len(out1)):print(out1[i] == out2[i])print(out1[i] == out3[i])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/33094.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分布式测试插件 pytest-xdist 使用详解

目录 使用背景: 使用前提: 使用快速入门: 使用小结: 使用背景: 大型测试套件:当你的测试套件非常庞大,包含了大量的测试用例时,pytest-xdist可以通过并行执行来加速整体的测试过…

js中的break和continue中的区别

js中break和continue有着一些差别。 首先&#xff0c;虽然break和continue都有跳出循环的作用&#xff0c;但break是完全跳出循环&#xff0c;而continue则是跳出一次循环&#xff0c;然后开启下一次的循环。 下面我就来举几个例子吧。 var num 0;for(var i 1;i < 10;i){i…

如何使用ChatGPT设计LOGO,只需知道品牌名字就能完成傻瓜式操作

​独特且引人注目的LOGO对于引导用户/消费者快速识别并与你建立联系至关重要。然而&#xff0c;聘请专业的设计师来创建个性化LOGO可能非常昂贵。这里可以使用使用ChatGPT。[1] 你只需要&#xff1a; 准备好公司名称&#xff1b; 能用ChatGPT&#xff0c;用来给BingChat喂log…

学习总结(TAT)

好久都没交总结了&#xff0c;今天把之前的思路和错误整理了一下&#xff1a; 在服务器和客户端两侧&#xff0c;不可以同时先初始化获取输入流&#xff0c;否则会造成堵塞&#xff0c;同时为这位作者大大打call&#xff1a; (3条消息) 关于Java Socket和创建输入输出流的几点…

一、安全世界观

文章目录 1、 Web安全简史1.1 中国黑客简史1.2 黑客技术的发展历程1.3 web安全的兴起 2、黑帽子、白帽子3、安全的本质4、安全三要素5、如何实施安全评估5.1 资产等级划分5.2 威胁分析5.3 风险分析5.4 设计安全方案 6、白帽子兵法6.1 Secure By Default6.2 纵深防御原则6.3 数据…

java的junit之异常测试、参数化测试、超时测试

1.对可能抛出的异常进行测试 异常本身是方法签名的一部分测试错误的输入是否导致特定的异常 summary 测试异常可以使用Test(expectedExceptio.class)对可能发生的每种类型的异常进行测试 2.参数化测试 如果待测试的输入和输出是一组数据&#xff1a; 可以把测试数据组织起…

Oracle时间查询使用笔记:sysdate用法

Oracle的sysdate用法 通常会有 sysdate - 1 / 12这种&#xff0c;或者sysdate - 1 / 24/3 这两种用法,表示从当前时间往前推若干时间 下面就用sysdate - A/B,sysdate - A/B/C代替 第一种 sysdate - A/B型&#xff0c;这种结果是小时&#xff0c;A代表天数&#xff0c;B代表小时…

学习51单片机怎么开始?

学习的过程不总是先打好基础&#xff0c;然后再盖上层建筑&#xff0c;尤其是实践性的、工程性很强的东西。如果你一定要先全面打好基础&#xff0c;再学习单片机&#xff0c;我觉得你一定学不好&#xff0c;因为你的基础永远打不好&#xff0c;因为基础太庞大了&#xff0c;基…

Spring AOP 切点表达式

参考博客&#xff1a; 参考博客

Oracle 知识篇+会话级全局临时表在不同连接模式中的表现

标签&#xff1a;会话级临时表、全局临时表、幻读释义&#xff1a;Oracle 全局临时表又叫GTT ★ 结论 ✔ 专用服务器模式&#xff1a;不同应用会话只能访问自己的数据 ✔ 共享服务器模式&#xff1a;不同应用会话只能访问自己的数据 ✔ 数据库驻留连接池模式&#xff1a;不同应…

探索数据之美:初步学习 Python 柱状图绘制

文章目录 一 基础柱状图1.1 创建简单柱状图1.2 反转x和y轴1.3 数值标签在右侧1.4 演示结果 二 基础时间线柱状图2.1 创建时间线2.2 时间线主题设置取值表2.3 演示结果 三 GDP动态柱状图绘制3.1 需求分析3.2 数据文件内容3.3 列表排序方法3.4 参考代码3.5 运行结果 一 基础柱状图…

谷粒商城第十二天-基本属性销售属性管理功能的实现

目录 一、总述 二、前端部分 三、后端部分 四、总结 一、总述 前端的话&#xff0c;依旧是直接使用老师给的。 前端的话还是那些增删改查&#xff0c;业务复杂一点的话&#xff0c;无非就是设计到多个字段多个表的操作&#xff0c;当然这是后端的事了&#xff0c;前端这里…

Nodejs安装及环境变量配置(修改全局安装依赖工具包和缓存文件夹及npm镜像源)

本机环境&#xff1a;win11家庭中文版 一、官网下载 二、安装 三、查看nodejs及npm版本号 1、查看node版本号 node -v 2、查看NPM版本号&#xff08;安装nodejs时已自动安装npm&#xff09; npm -v 四、配置npm全局下载工具包和缓存目录 1、查看安装目录 在本目录下创建no…

瓴羊发布All in One 产品,零售SaaS的尽头是DaaS?

“打破烟囱、化繁为简&#xff0c;让丰富的能力、数据和智能All in One”&#xff0c;这是瓴羊新发布的产品瓴羊One承担的使命&#xff0c;也意味着瓴羊DaaS事业迈入了一个新阶段。 成立伊始&#xff0c;瓴羊就打出了“Not SaaS&#xff0c;But DaaS”旗号&#xff0c;将自己的…

小程序裂变怎么做?小程序裂变机制有哪些?

做了小程序就等于“生意上门”&#xff1f;其实并不是这样。小程序跟流量平台较为明显的区别就在于小程序并非“自带流量”&#xff0c;而是需要企业利用自己的营销推广能力来建立引流渠道&#xff0c;从而完成用户的拉新和留存、转化。因此&#xff0c;想要用小程序来增加自己…

[虚幻引擎] UE DTBase64 插件说明 使用蓝图对字符串或文件进行Base64加密解密

本插件可以在虚幻引擎中使用蓝图对字符串&#xff0c;字节数组&#xff0c;文件进行Base64的加密和解密。 目录 1. 节点说明 String To Base64 Base64 To String Binary To Base64 Base64 To Binary File To Base64 Base64 To File 2. 案例演示 3. 插件下载 1. 节点说…

centos如何配置IP地址?

CentOS如何查看和临时配置IP地址 CentOS系统中&#xff0c;可以通过使用ifconfig命令来查看当前本机的IP地址信息。输入ifconfig即可显示当前网络接口的IP地址、网络掩码和网关信息。如果需要设置临时IP地址&#xff0c;可以使用ifconfig命令后接网卡名称和需要设置的IP地址、网…

自定义element-plus的弹框样式

项目中弹框使用频繁,需要统一样式风格,此组件可以自定义弹框的头部样式和内容 一、文件结构如下: 二、自定义myDialog组件 需求&#xff1a; 1.自定义弹框头部背景样式和文字 2.自定义弹框内容 3.基本业务流程框架 components/myDialog/index.vue完整代码&#xff1a; &…

采用pycharm在虚拟环境使用pyinstaller打包python程序

一年多以前&#xff0c;我写过一篇博客描述了如何虚拟环境打包&#xff0c;这一次有所不同&#xff0c;直接用IDE pycharm构成虚拟环境并运行pyinstaller打包 之前的博文&#xff1a; 虚拟环境venu使用pyinstaller打包python程序_伊玛目的门徒的博客-CSDN博客 第一步&#xf…

山西电力市场日前价格预测【2023-08-12】

日前价格预测 预测明日&#xff08;2023-08-12&#xff09;山西电力市场全天平均日前电价为330.52元/MWh。其中&#xff0c;最高日前电价为387.00元/MWh&#xff0c;预计出现在19: 45。最低日前电价为278.05元/MWh&#xff0c;预计出现在13: 00。 价差方向预测 1&#xff1a; 实…