pytorch:加载预训练模型(多卡加载单卡预训练模型,多GPU,单GPU)

在pytorch加载预训练模型时,可能遇到以下几种情况。

分为以下几种

  • 在pytorch加载预训练模型时,可能遇到以下几种情况。
    • 1.多卡训练模型加载单卡预训练模型
    • 2. 多卡训练模型加载多卡预训练模型
    • 3. 单卡训练模型加载单卡预训练模型
    • 4. 单卡训练模型加载多卡预训练模型
    • 5.直接删除预训练模型中不匹配的键
    • 6. 新版torch的模型加载torch<0.4 版本模型
    • 7.在加载的参数模型中增加缺失的键,然后赋予随机参数

问题分为几种情况:

1.多卡训练模型加载单卡预训练模型

if isinstance(self.netG, torch.nn.DataParallel):self.netG = self.netG.module
self.netG.load_state_dict(torch.load(path))

在这里插入图片描述
这是多卡训练的模型加载单卡训练的模型出现的问题。

2. 多卡训练模型加载多卡预训练模型

self.netG.load_state_dict(torch.load(path))

3. 单卡训练模型加载单卡预训练模型

self.netG.load_state_dict(torch.load(path))

4. 单卡训练模型加载多卡预训练模型

对预训练模型创建新的字典,去掉key值前面的’module.’

state_dict = torch.load('checkpoint.pt’)
from collections import OrderedDict
new_state_dict = OrderedDict()
for k,v in state_dict.items():name = k[7:]new_state_dict[name]  =v 
self.netG.load_state_dict(new_state_dict)

5.直接删除预训练模型中不匹配的键

 model = DPN(num_init_features=64, k_R=96, G=32, k_sec=(3,4,20,3), inc_sec=(16,32,24,128), num_classes=1,decoder=args.decoder)http = {'url': 'http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'}pretrained_dict=model_zoo.load_url(http['url'])model_dict = model.state_dict()pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#filter out unnecessary keys model_dict.update(pretrained_dict)model.load_state_dict(model_dict)model = torch.nn.DataParallel(model).cuda()

6. 新版torch的模型加载torch<0.4 版本模型

baol

7.在加载的参数模型中增加缺失的键,然后赋予随机参数

在state_dict 参数模型中增加开头是conv1一些键

state_dict = torch.load(path, map_location=self.device)
model_dict = self.netG_A.state_dict()for k,v in model_dict.items():if k.startswith('conv11') or k.startswith('conv21') or k.startswith('conv31'):state_dict[k] = vself.netG_A.load_state_dict(state_dict)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/535199.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

知识蒸馏 knowledge distill 相关论文理解

Knowledge Distil 相关文章1.FitNets : Hints For Thin Deep Nets &#xff08;ICLR2015&#xff09;2.A Gift from Knowledge Distillation&#xff1a;Fast Optimization, Network Minimization and Transfer Learning (CVPR 2017)3.Matching Guided Distillation&#xff08…

模型压缩 相关文章解读

模型压缩相关文章Learning both Weights and Connections for Efficient Neural Networks (NIPS2015)Deep Compression: Compressing Deep Neural Networks with Pruning, Trained Quantization and Huffman Coding&#xff08;ICLR2016&#xff09;Learning both Weights and …

Latex 生成的PDF增加行号 左右两边

增加行号 \usepackage[switch]{lineno}\linenumbers \begin{document} \end{document}

Linux 修改用户名的主目录 家目录

首先root 登陆 sudo -i 输入密码然后 vim /etc/passwd 找到用户名 然后修改后面的路径即可

ubunt16.04 安装3090显卡驱动 cuda cudnn pytorch

安装驱动 需要的安装包 30系列显卡是新一代架构&#xff0c;新驱动不支持cuda9以及cuda10&#xff0c;所以必须安装cuda11、而pytorch现在稳定版为1.6&#xff0c;最高仅支持到cud10.2。所以唯一的办法就是使用上处于beta测试的1.7或1.8。这也是为啥一开始就强调本文的写作时…

测试项目:车牌检测,行人检测,红绿灯检测,人流检测,目标识别

本项目为2020年中国软件杯&#xff21;组第一批赛题&#xff02;基于计算机视觉的交通场景智能应用&#xff02;&#xff0e;项目用python实现&#xff0c;主要使用YOLO模型实现道路目标如人、车、交通灯等物体的识别&#xff0c;使用开源的&#xff02;中文车牌识别HyperLPR&a…

linux 安装python3.8的几种方法

1.命令行搞定 git clone https://github.com/waketzheng/carstino cd carstino python3 upgrade_py.py2.离线安装 自己在官网下载安装包 https://www.python.org/ftp/python/3.8.0/ 解压&#xff1a; tar -zvf Python-3.8.0.tgz安装 cd Python-3.8.0 ./configure --prefix/u…

面试题目:欠拟合、过拟合及如何防止过拟合

对于深度学习或机器学习模型而言&#xff0c;我们不仅要求它对训练数据集有很好的拟合&#xff08;训练误差&#xff09;&#xff0c;同时也希望它可以对未知数据集&#xff08;测试集&#xff09;有很好的拟合结果&#xff08;泛化能力&#xff09;&#xff0c;所产生的测试误…

LaTeX:equation, aligned 书写公式换行,顶部对齐

使用aligined 函数&#xff0c;其中aligned就是用来公式对齐的&#xff0c;在中间公式中&#xff0c;\ 表示换行&#xff0c; & 表示对齐。在公式中等号之前加&&#xff0c;等号介绍要换行的地方加\就可以了。 \begin{equation} \begin{aligned} L_{task} &\lamb…

Latex: 表格中 自动换行居中

1、在导言区添加宏包&#xff1a; \usepackage{makecell}2、环境&#xff1a;tabular 命令&#xff1a; \makecell[居中情况]{第1行内容 \\ 第2行内容 \\ 第3行内容 ...} \makecell [c]{ResNet101\\ (11.7M)}参数说明&#xff1a; [c]是水平居中&#xff0c;[l]水平左居中&am…

在服务器上远程使用tensorboard查看训练loss和准确率

本人使用的是vscode 很简单 from torch.utils.tensorboard import SummaryWriter writer SummaryWriter(./logs)writer.add_scalar(train_loss,loss.val(),iteration) # 名字&#xff0c;数据&#xff0c;迭代次数训练的过程中会产生一个./logs的文件夹&#xff0c;里面存放的…

pytorch:固定部分层参数,固定单个模型

文章目录固定部分层参数固定指定层的参数不同层设置不同的学习率固定部分层参数 class RESNET_attention(nn.Module):def __init__(self, model, pretrained):super(RESNET_attetnion, self).__init__()self.resnet model(pretrained) # 这个model被固定for p in self.parame…

各种损失损失函数的使用场景和使用方法:KL散度

KL 散度的使用场景 KL散度( Kullback–Leibler divergence)&#xff0c;又称相对熵&#xff0c;是描述两个概率分布 P 和 Q 差异的一种方法 torch.nn.functional.kl_div(input, target, size_averageNone, reduceNone, reductionmean) torch.nn.KLDivLoss(input, target, si…

RNN,LSTM,GRU的理解

RNN x 为当前状态下数据的输入&#xff0c; h 表示接收到的上一个节点的输入。 y为当前节点状态下的输出&#xff0c;而h′h^\primeh′为传递到下一个节点的输出. LSTM #定义网络 lstm nn.LSTM(input_size20,hidden_size50,num_layers2) #输入变量 input_data Variable(tor…

常用的loss函数,以及在训练中的使用

文章目录KL 散度L2 loss做标准化处理CElossCTCLossAdaptiveAvgPool2dKL 散度 算KL散度的时候要注意前后顺序以及加log import torhch.nn as nn d_loss nn.KLDivLoss(reductionreduction_kd)(F.log_softmax(y / T, dim1),F.softmax(teacher_scores / T, dim1)) * T * T蒸馏lo…

windows, 放方向键设置为vim格式,autohotkey-windows

安装 Autohotkey https://www.autohotkey.com/download/ 设置快捷键 随便找个目录,鼠标右键新建一个autohotkey的脚本。 映射一个键——上左下右 经常打字的人都知道,我们编辑文本时要上下左右移动光标,难免要将手移到方向键再移回来打字。对我这样的懒癌后期患者,这简直不能…

Hbase数据模型及Hbase Shell

目录 1 数据模型 1.1 相关名词概念 1.2 模型分析 2 Hbase Shell操作 2.1 命名空间 2.2 表操作 2.2.1 创建表 2.2.2 更改表结构 2.2.3 表的其他操作 2.3 数据操作 2.3.1 添加数据(put) 2.3.2 删除数据(delete) 2.3.3 获取数据(get|scan) 3 过滤器 3.1 比较运算符…

非关型数据库之Hbase

目录 1 Hbase简介 1.1 初识Hbase 1.2 Hbase的特性 2 HDFS专项模块 2.1 HDFS的基本架构 2.1.1 HDFS各组件的功能&#xff1a; 2.2 HFDFS多种机制 2.2.1 分块机制 2.2.2 副本机制 2.2.3 容错机制 2.2.4 读写机制 3 Hbase组件及其功能 3.1 客户端 3.2 Zookeeper 3.3 …

MongoDB Shell操作

目录 1 数据库操作 2 集合操作 3 文档操作 3.1 插入文档(insert|insertOne|insertMany) 3.2插入、删除的循环操作 3.2 删除文档(remove|deleteOne|deleteMany) 3.3 更新文档(update|save) 3.4 查询文档(find) 4 游标 5 索引 6 聚合 1 数据库操作 当新创建的数据库里…

MongoDB副本集、分片集的伪分布式部署(保姆级教程)

目录 1 集群架构(概念篇) 1.1 MongoDB核心组件 1.2 主从复制 1.3 副本集 1.4 分片集 2 集群搭建 2.1 部署副本集(伪分布式) 2.2 分片集部署(伪分布式) 2.3 副本集与分片集区别 1 集群架构(概念篇) MongoDB有三种集群部署模式&#xff0c;分别是主从复制(Master-Slave)…