在pytorch加载预训练模型时,可能遇到以下几种情况。
分为以下几种
- 在pytorch加载预训练模型时,可能遇到以下几种情况。
- 1.多卡训练模型加载单卡预训练模型
- 2. 多卡训练模型加载多卡预训练模型
- 3. 单卡训练模型加载单卡预训练模型
- 4. 单卡训练模型加载多卡预训练模型
- 5.直接删除预训练模型中不匹配的键
- 6. 新版torch的模型加载torch<0.4 版本模型
- 7.在加载的参数模型中增加缺失的键,然后赋予随机参数
问题分为几种情况:
1.多卡训练模型加载单卡预训练模型
if isinstance(self.netG, torch.nn.DataParallel):self.netG = self.netG.module
self.netG.load_state_dict(torch.load(path))
这是多卡训练的模型加载单卡训练的模型出现的问题。
2. 多卡训练模型加载多卡预训练模型
self.netG.load_state_dict(torch.load(path))
3. 单卡训练模型加载单卡预训练模型
self.netG.load_state_dict(torch.load(path))
4. 单卡训练模型加载多卡预训练模型
对预训练模型创建新的字典,去掉key值前面的’module.’
state_dict = torch.load('checkpoint.pt’)
from collections import OrderedDict
new_state_dict = OrderedDict()
for k,v in state_dict.items():name = k[7:]new_state_dict[name] =v
self.netG.load_state_dict(new_state_dict)
5.直接删除预训练模型中不匹配的键
model = DPN(num_init_features=64, k_R=96, G=32, k_sec=(3,4,20,3), inc_sec=(16,32,24,128), num_classes=1,decoder=args.decoder)http = {'url': 'http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'}pretrained_dict=model_zoo.load_url(http['url'])model_dict = model.state_dict()pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#filter out unnecessary keys model_dict.update(pretrained_dict)model.load_state_dict(model_dict)model = torch.nn.DataParallel(model).cuda()
6. 新版torch的模型加载torch<0.4 版本模型
baol
7.在加载的参数模型中增加缺失的键,然后赋予随机参数
在state_dict 参数模型中增加开头是conv1一些键
state_dict = torch.load(path, map_location=self.device)
model_dict = self.netG_A.state_dict()for k,v in model_dict.items():if k.startswith('conv11') or k.startswith('conv21') or k.startswith('conv31'):state_dict[k] = vself.netG_A.load_state_dict(state_dict)