I. Converting the Caffe model's weights to a dict
For building the Caffe library, see my earlier post: "ImportError: dynamic module does not define module export function (PyInit__caffe)" troubleshooting notes (chen_zn95's blog, CSDN).
Once Caffe is installed, the following script saves the Caffe model's parameter names and parameter values as a dict:
```python
import pickle as pkl

import caffe

MODEL_FILE = 'xxx.prototxt'
PRETRAIN_FILE = 'xxx.caffemodel'

if __name__ == '__main__':
    # Load the network definition and the trained weights in inference mode
    net = caffe.Net(MODEL_FILE, PRETRAIN_FILE, caffe.TEST)
    name_weights = {}
    # net.params maps each layer name to its list of parameter blobs
    for param_name in net.params.keys():
        name_weights[param_name] = {}
        layer_params = net.params[param_name]
        if len(layer_params) == 1:
            # One blob, e.g. a convolution without bias: weight only
            weight = layer_params[0].data
            name_weights[param_name]['weight'] = weight
            print('%s:\n\t%s (weight)' % (param_name, weight.shape))
        elif len(layer_params) == 2:
            # Two blobs, e.g. Convolution / InnerProduct / Scale: weight + bias
            weight = layer_params[0].data
            name_weights[param_name]['weight'] = weight
            bias = layer_params[1].data
            name_weights[param_name]['bias'] = bias
            print('%s:\n\t%s (weight)' % (param_name, weight.shape))
            print('\t%s (bias)' % str(bias.shape))
        elif len(layer_params) == 3:
            # BatchNorm: running_mean, running_var, moving-average scale factor.
            # Caffe stores the statistics multiplied by the scale factor,
            # so divide by it to recover the actual mean and variance.
            running_mean = layer_params[0].data
            name_weights[param_name]['running_mean'] = running_mean / layer_params[2].data
            running_var = layer_params[1].data
            name_weights[param_name]['running_var'] = running_var / layer_params[2].data
            print('%s:\n\t%s (running_var)' % (param_name, running_var.shape))
            print('\t%s (running_mean)' % str(running_mean.shape))
        else:
            raise RuntimeError('unexpected number of parameter blobs in layer %s' % param_name)

    # Save the weights; protocol=2 keeps the file readable from Python 2 and Python 3
    with open('weights.pkl', 'wb') as f:
        pkl.dump(name_weights, f, protocol=2)
```
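The BatchNorm branch above divides the first two blobs directly by the third one, which yields inf/nan if that scale factor happens to be 0. Caffe's own BatchNorm inference code treats a zero factor as a multiplier of 0 instead. The sketch below mirrors that convention; `caffe_bn_to_stats` is a hypothetical helper name, and the blob values in the example are made up for illustration.

```python
import numpy as np


def caffe_bn_to_stats(mean_blob, var_blob, scale_blob):
    """Convert Caffe BatchNorm blobs to (running_mean, running_var).

    Caffe keeps the running statistics multiplied by an accumulated
    scale factor (third blob); the usable values are blob / factor.
    As in Caffe's inference code, a factor of 0 yields all zeros.
    """
    factor = float(np.asarray(scale_blob).ravel()[0])
    inv = 0.0 if factor == 0 else 1.0 / factor
    running_mean = np.asarray(mean_blob, dtype=np.float32) * inv
    running_var = np.asarray(var_blob, dtype=np.float32) * inv
    return running_mean, running_var


# Hypothetical blob values: accumulated statistics with a scale factor of 2
mean, var = caffe_bn_to_stats(np.array([2.0, 4.0]), np.array([8.0, 6.0]), np.array([2.0]))
print(mean, var)  # [1. 2.] [4. 3.]
```

In the conversion script, the BatchNorm branch could then call `caffe_bn_to_stats(layer_params[0].data, layer_params[1].data, layer_params[2].data)` instead of dividing inline.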
II. Loading the dict weights into a PyTorch model
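There are two approaches: match the weights by name, or match them by shape. The second approach has a problem: if the network contains two or more layers whose weights have the same shape, shape matching can assign weights to the wrong layer. Both approaches are described below.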
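Whether shape matching is safe for a given network can be checked up front by counting how often each shape occurs: any shape that appears more than once has several candidate layers. A minimal sketch, where `ambiguous_shapes` is a hypothetical helper and the layer names and shapes are invented for illustration:

```python
from collections import Counter


def ambiguous_shapes(shapes):
    """Return the shapes that occur more than once in a name -> shape mapping.

    A parameter whose shape is in the result cannot be matched by shape
    alone, because several layers are candidates for it.
    """
    counts = Counter(tuple(s) for s in shapes.values())
    return {shape for shape, n in counts.items() if n > 1}


# Hypothetical layer names and shapes
caffe_shapes = {
    "conv1.weight": (32, 3, 3, 3),
    "conv2.weight": (32, 32, 3, 3),
    "conv3.weight": (32, 32, 3, 3),
    "fc.weight": (10, 32),
}
print(ambiguous_shapes(caffe_shapes))  # {(32, 32, 3, 3)}
```

Here conv2 and conv3 share a shape, so a pure shape match could swap their weights.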
1. Matching by weight name
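This approach is more tedious: the PyTorch model's parameter names must be consistent with the Caffe model's layer names, and if they are not, you need to write a dict that maps one set of names to the other. The steps are as follows: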
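```python
import copy
import pickle as pkl

import torch

model = xxx  # your PyTorch model instance
model1 = copy.deepcopy(model)

state_dict = {}
# weights.pkl: the dict generated in Part I
with open("weights.pkl", "rb") as wp:
    # encoding='iso-8859-1' (latin1) is needed when the pickle containing
    # NumPy arrays was written by Python 2; it is harmless otherwise
    name_weights = pkl.load(wp, encoding='iso-8859-1')

# Flatten {layer: {kind: array}} into PyTorch-style keys such as "conv1.weight"
for key, value in name_weights.items():
    for k, v in value.items():
        state_dict[key + "." + k] = torch.from_numpy(v)

# strict=True raises an error if any key is missing or unexpected
model1.load_state_dict(state_dict, strict=True)
```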
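The mapping dict mentioned above, for when the PyTorch names differ from the Caffe ones, can be folded into the flattening step. The sketch below is one way to do it; `to_state_dict` and all layer and module names are hypothetical, and small lists stand in for the NumPy arrays. It also shows a common Caffe pattern: a BatchNorm layer (running statistics) and the Scale layer after it (gamma/beta) both map onto the same PyTorch BatchNorm module.

```python
def to_state_dict(name_weights, name_map=None):
    """Flatten {caffe_layer: {kind: array}} into {"torch_module.kind": array}.

    name_map maps Caffe layer names to PyTorch module names; layers not
    in the map keep their Caffe name. A duplicate target key raises
    KeyError so that two Caffe blobs never silently overwrite each other.
    """
    name_map = name_map or {}
    state_dict = {}
    for layer, blobs in name_weights.items():
        module = name_map.get(layer, layer)
        for kind, value in blobs.items():
            key = module + "." + kind
            if key in state_dict:
                raise KeyError("duplicate target key: " + key)
            state_dict[key] = value
    return state_dict


# Hypothetical Caffe layers (lists stand in for NumPy arrays)
name_weights = {
    "conv1": {"weight": [0.1], "bias": [0.2]},
    "bn1": {"running_mean": [0.3], "running_var": [0.4]},
    "scale1": {"weight": [0.5], "bias": [0.6]},
}
# Hypothetical mapping: Caffe BatchNorm + Scale both feed PyTorch "features.1"
name_map = {"conv1": "features.0", "bn1": "features.1", "scale1": "features.1"}
state_dict = to_state_dict(name_weights, name_map)
print(sorted(state_dict))
```

With real weights, the result would be wrapped as `{k: torch.from_numpy(v) for k, v in to_state_dict(name_weights, name_map).items()}` before `load_state_dict`. Note that PyTorch BatchNorm modules also carry a `num_batches_tracked` buffer with no Caffe counterpart, so `strict=True` will report it missing; loading with `strict=False` and inspecting the returned `missing_keys` is one way to check that nothing else was missed.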
Another implementation assigns the values directly to the PyTorch model's parameters:
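```python
import copy
import pickle as pkl

import torch

model = xxx  # your PyTorch model instance
model2 = copy.deepcopy(model)

with open("weights.pkl", "rb") as wp:
    name_weights = pkl.load(wp, encoding='iso-8859-1')

# A parameter name such as "conv1.weight" is split into the layer name
# ("conv1") and the blob kind ("weight"), which must match the dict keys
for name, param in model2.named_parameters():
    for key, value in name_weights.items():
        if name.split(".")[0] == key:
            for k, v in value.items():
                if name.split(".")[1] == k:
                    param.data = torch.from_numpy(v)
```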
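The loop above splits the name at the first ".", so it only works when the layers are top-level attributes of the model (e.g. `conv1.weight`). For nested modules such as `features.0.weight` inside an `nn.Sequential`, the module path is everything before the last ".". A minimal sketch of a more general lookup; `lookup_caffe_blob` is a hypothetical helper, and its optional `name_map` (PyTorch module path to Caffe layer name) is likewise an assumption:

```python
def lookup_caffe_blob(param_name, name_weights, name_map=None):
    """Return the Caffe blob for a PyTorch parameter/buffer name, or None.

    rpartition(".") keeps the full module path ("features.0") apart from
    the blob kind ("weight"), so nested modules work too. A name with no
    dot (a parameter registered on the root module) gets module path "".
    """
    module_path, _, kind = param_name.rpartition(".")
    layer = (name_map or {}).get(module_path, module_path)
    return name_weights.get(layer, {}).get(kind)


# Hypothetical Caffe dict (lists stand in for NumPy arrays)
name_weights = {"conv1": {"weight": [1], "bias": [2]}, "bn1": {"running_mean": [3]}}
print(lookup_caffe_blob("conv1.weight", name_weights))
print(lookup_caffe_blob("features.1.running_mean", name_weights, {"features.1": "bn1"}))
```

Also note that `named_parameters()` does not yield buffers such as BatchNorm's `running_mean` and `running_var`, so a loop over parameters alone leaves them at their initial values; iterating over `model2.named_buffers()` as well, and assigning whatever `lookup_caffe_blob(name, name_weights)` returns when it is not None, covers them.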
2. Matching by weight shape
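```python
import copy
import pickle as pkl

import torch

model = LightCNN_ir_eye()  # the PyTorch model class used in this post
model3 = copy.deepcopy(model)

with open("weights.pkl", "rb") as wp:
    name_weights = pkl.load(wp, encoding='iso-8859-1')

for name, param in model3.named_parameters():
    for key, value in name_weights.items():
        for k, v in value.items():
            v = torch.from_numpy(v)
            # Candidate: a Caffe blob with the same shape as this parameter
            if param.data.shape == v.shape:
                # Also require the names to agree, to avoid errors when
                # several weights share the same shape
                if name == key + "." + k:
                    param.data = v
```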
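Because the code above also compares names, it effectively matches by name with a shape check. A purely shape-based variant (a different technique from the one above, shown only as a sketch) avoids the duplicate-shape problem by pairing the i-th Caffe blob with the i-th PyTorch parameter and only checking that the shapes agree. This assumes both lists are in the same network order: `net.params` is ordered, and a dict keeps insertion order in Python 3.7+, but a plain dict pickled under Python 2 does not preserve it. `match_by_order` and the names and shapes below are hypothetical.

```python
def match_by_order(caffe_items, torch_items):
    """Pair Caffe and PyTorch tensors positionally, checking shapes.

    caffe_items / torch_items: lists of (name, shape) in network order.
    Returns a list of (caffe_name, torch_name) pairs; raises ValueError
    if the counts differ or on the first shape mismatch.
    """
    if len(caffe_items) != len(torch_items):
        raise ValueError("count mismatch: %d Caffe vs %d PyTorch"
                         % (len(caffe_items), len(torch_items)))
    pairs = []
    for (c_name, c_shape), (t_name, t_shape) in zip(caffe_items, torch_items):
        if tuple(c_shape) != tuple(t_shape):
            raise ValueError("shape mismatch: %s %s vs %s %s"
                             % (c_name, tuple(c_shape), t_name, tuple(t_shape)))
        pairs.append((c_name, t_name))
    return pairs


# Hypothetical names and shapes: conv2 and conv3 share a shape,
# but their positions keep them apart
caffe_items = [("conv1.weight", (16, 3, 3, 3)),
               ("conv2.weight", (16, 16, 3, 3)),
               ("conv3.weight", (16, 16, 3, 3))]
torch_items = [("features.0.weight", (16, 3, 3, 3)),
               ("features.2.weight", (16, 16, 3, 3)),
               ("features.4.weight", (16, 16, 3, 3))]
for pair in match_by_order(caffe_items, torch_items):
    print(pair)
```

In practice the PyTorch list would come from `[(n, tuple(p.shape)) for n, p in model3.named_parameters()]`; Caffe BatchNorm statistics would have to be matched against `named_buffers()` separately, since they are buffers rather than parameters in PyTorch.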
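3. Checking that the initialization methods above are correct
Feed the same image through model1, model2 and model3; if all three loading methods assigned the same weights, the outputs should be identical.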
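```python
import cv2
import numpy as np
import torch

img = cv2.imread("xxx.jpg")             # HWC, uint8, BGR channel order
img = cv2.resize(img, (width, height))  # width, height: the network's input size
img = np.transpose(img, (2, 0, 1))      # HWC -> CHW
img = np.expand_dims(img, axis=0)       # add batch dimension -> NCHW
x = torch.from_numpy(img).float()

# eval() disables dropout and makes BatchNorm use the loaded running statistics
for m in (model1, model2, model3):
    m.eval()

with torch.no_grad():
    out1 = model1(x)
    out2 = model2(x)
    out3 = model3(x)

print(out1)
print(out2)
print(out3)
# torch.equal gives one True/False per pair instead of an elementwise tensor
for i in range(len(out1)):
    print(torch.equal(out1[i], out2[i]))
    print(torch.equal(out1[i], out3[i]))
```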
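The check above only compares the three PyTorch models with each other. To confirm the conversion itself, the PyTorch output can also be compared with Caffe's output for the same preprocessed input. Exact equality is too strict there, since the two frameworks may accumulate floating-point sums in a different order, so a tolerance is used. A minimal NumPy sketch; `compare_outputs` and its default tolerances are assumptions, not values from this post:

```python
import numpy as np


def compare_outputs(a, b, atol=1e-5, rtol=1e-4):
    """Return (max_abs_diff, allclose) for two array-likes of equal shape."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError("shape mismatch: %s vs %s" % (a.shape, b.shape))
    max_diff = float(np.max(np.abs(a - b))) if a.size else 0.0
    return max_diff, bool(np.allclose(a, b, rtol=rtol, atol=atol))


print(compare_outputs([1.0, 2.0], [1.0, 2.0 + 1e-7]))
print(compare_outputs([1.0], [1.1]))
```

On the Caffe side, `net.forward()` returns a dict of output blobs; assuming the input blob is named `data`, something like `caffe_out = net.forward(data=img.astype(np.float32))` followed by `compare_outputs(out1.numpy(), caffe_out[name])` for the relevant output blob `name` would do the comparison.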