浅谈一谈pytorch中模型的几种保存方式、以及如何从中止的地方继续开始训练;

一、本文总共介绍3中pytorch模型的保存方式:1.保存整个模型;2.只保存模型参数;3.保存模型参数、优化器、学习率、epoch和其它的所有命令行相关参数以方便从上次中止训练的地方重新启动训练过程。

1.保存整个模型。这种保存方式最简单,保存内容包括模型结构、模型参数以及其它相关信息。代码如下:

# 保存模型,PATH为模型的保存路径及模型命名
import torch
torch.save(model,PATH)# 加载模型
model = torch.load(PATH)

2. 只保存模型参数,不保存模型结构和其它相关信息。这种方式保存的模型,在加载模型前需要构建相同的模型结构,然后再将加载的模型参数赋值给对应的层。代码如下:

# 只保存模型参数
torch.save(model.state_dict(), PATH)# 创建相同结构的模型,然后加载模型参数
model = Model()   # 调用Model类实例化模型
model_dict = torch.load(PATH)
model.load_state_dict(model_dict) #加载模型参数

如果进行模型加载前,创建的模型结构发生了改变,和原来预训练的模型的结构不同,则需要遍历模型参数进行选择性赋值,例如下面的代码:

from collections import OrderedDictmodel = Unet()  # 实例化Unet模型
model_dict = torch.load(pretrained_pth, map_location="cpu")  # 加载模型时将参数映射到CPU上
new_state_dict = OrderedDict()  # 新建一个字典类型用来存储新的模型参数
# 改变模型结构名称,如果有,就去掉backbone.前缀
for k, v in model_dict["state_dict"].items():new_state_dict[k.replace("backbone.", "")] = vmodel.load_state_dict(new_state_dict)  # 加载模型参数

注意上述代码中,有一个参数 map_location="cpu",这个参数是指定将模型参数映射到CPU上,这个参数一般在一下情况下比较适用:1. 当你在CPU上训练了一个模型,并且想将其加载到CPU上进行推断或者继续训练时,使用map_location="cpu"可以确保模型参数被正确地映射到CPU上;2.如果你的预训练模型是在GPU上训练的,但是你在没有GPU的环境中加载模型时,使用这个参数可以避免找不到GPU而导致的错误。 而如果你的代码没有指定map_location参数,则默认情况下pytorch会尝试将模型加载到当前可用设备上(通常是GPU)

3. 保存模型必要参数,使下次训练可以从模型训练停止的地方继续训练,代码如下:

# 将需要保存的参数打包成字典类型
save_file = {"model": model.state_dict(),"optimizer": optimizer.state_dict(),"lr_scheduler": lr_scheduler.state_dict(),"epoch": epoch,"args": args}     # 保存模型和其它参数    
torch.save(save_file, "save_weights/model.pth")
# 加载模型和必要的参数
checkpoint = torch.load(args.resume, map_location='cpu')
model.load_state_dict(checkpoint['model'])  # 加载模型参数
optimizer.load_state_dict(checkpoint['optimizer'])  # 加载模型优化器
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])  # 加载模型学习策略
args.start_epoch = checkpoint['epoch'] + 1  # 加载模型训练epoch停止数

如果仅是进行模型推理,则只用加载模型参数即可,不用加载其它的东西。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/621225.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

xtu oj 1354 Digit String

题目描述 小明获得了一些密码的片段,包含0∼9,A∼F 这些字符,他猜这些是某个进制下的一个整数的数码串。 小明想知道从2到16进制中,哪些进制下,这个数码串的对应的十进制整数值,等于n? 输入 存在不超过1000个样例&…

C#,史密斯数(Smith Number)的计算方法与源代码

一、关于史密斯数的传说 1、关于理海大学Lehigh University 理海大学(Lehigh University),位于宾夕法尼亚州(Pennsylvania)伯利恒(Bethlehem),由富有爱国情怀与民族精神的实业家艾萨…

Java SE入门及基础(12)

do-while 循环 1. 语法 do { //循环操作 } while ( 循环条件 ); 2. 执行流程图 3. 案例 从控制台录入学生的成绩并计算总成绩,输入0 时退出 4. 代码实现 public static void main ( String [] args ) { Scanner sc new Scanner ( System . in )…

SqlAlchemy使用教程(三) CoreAPI访问与操作数据库详解

SqlAlchemy使用教程(一) 原理与环境搭建SqlAlchemy使用教程(二) 入门示例及编程步骤 三、使用Core API访问与操作数据库 Sqlalchemy 的Core部分集成了DB API, 事务管理,schema描述等功能,ORM构筑于其上。本章介绍创建 Engine对象,使用基本的…

使用lodash原地起飞,总结了几个常用的lodash方法

前言 📫 大家好,我是南木元元,热爱技术和分享,欢迎大家交流,一起学习进步! 🍅 个人主页:南木元元 目录 什么是lodash lodash的按需引入 数组操作 求交集 求合集 求差集 求总和…

如何使用C++编程使得在Windows和Linux输入密码的时候保密 linux:tcgetattr tcsetattr

在C编程中,在执行一些操作的时候,终端需要接收用户名和密码,那么在终端输入密码的时候,如何不让别人看见自己的密码,是一个较为关注的问题; 1、问题分析 定义一个登录函数Login //用户登录主循环bool Lo…

Android蓝牙协议栈fluoride(十一) - 音乐播放(4)

上一篇介绍了蓝牙音频的播放通路和编解码器,接下来介绍Source和Sink如何选择编解码器以及编解码流程。 编解码器选择 连接蓝牙后想要播放音乐,需要协商使用哪种编码器,还需要协商编码器使用什么配置,前面介绍了如何协商编码器的…

Redis分布式锁--java实现

文章目录 Redis分布式锁方案:SETNX EXPIRE基本原理比较好的实现会产生四个问题 几种解决原子性的方案方案:SETNX value值是(系统时间过期时间)方案:使用Lua脚本(包含SETNX EXPIRE两条指令)方案:SET的扩展…

设计模式之多线程版本的if------Balking模式

系列文章目录 设计模式之避免共享的设计模式Immutability(不变性)模式 设计模式之并发特定场景下的设计模式 Two-phase Termination(两阶段终止)模式 设计模式之避免共享的设计模式Copy-on-Write模式 设计模式之避免共享的设计模…

watchdog,一个无敌的 Python 库

大家好,今天为大家分享一个无敌的 Python 库 - watchdog。 在软件开发和系统管理领域,经常需要监控文件和目录的变化,以便在文件被创建、修改或删除时触发相应的操作。Python Watchdog是一个强大的Python库,它提供了简单而灵活的…

低代码开发平台

低代码开发平台(LCDP)本身也是一种软件,它为开发者提供了一个创建应用软件的开发环境。看到“开发环境”几个字是不是很亲切?对于程序员而言,低代码开发平台的性质与IDEA、VS等代码IDE(集成开发环境&#x…

蓝桥杯练习题(九)

📑前言 本文主要是【算法】——蓝桥杯练习题(九)的文章,如果有什么需要改进的地方还请大佬指出⛺️ 🎬作者简介:大家好,我是听风与他🥇 ☁️博客首页:CSDN主页听风与他 …

抖店需要怎么开通?抖店入驻全流程,一看就会!

我是电商珠珠 抖店的热度很高,所以有很多新手想要入驻。 但是对于入驻的流程,部分新手还不太了解,今天我就来给大家详细的讲一下。 入驻准备 在入驻之前需要准备一张个体的营业执照,再准备好个人的身份证、银行卡、抖音号。 …

Tomcat解压打包文件和并部署

一、文件压缩和上传解压 1.本地打包好dist.tar.gz文件 2.通过xftp拖拽上传到知道文件夹下,或者通过命令: cp dist.tar.gz /path/to/destination/folder注:将dist.tar.gz复制到 /path/to/destination/folder文件夹下,该文件夹只是举个例子怎么复制和解压! 3.进入/path/…

使用Python批量上传本地maven库到nexus

背景:外包类项目开发时是调用的公司maven仓库进行开发,交付后需要将maven仓库转移到客户环境。 原理:1、打开idea运行源代码,将maven包下载到本地仓库, 2、下载包所在目录中执行脚本将本地仓库的maven包上传到客户nex…

UE5 C++的TCP客户端示例

客户端.h 需要在Build.cs中加入模块:"Networking","Sockets","Json","JsonUtilities" // Fill out your copyright notice in the Description page of Project Settings.#pragma once#include "CoreMinimal.h" #include…

LeetCode第380场周赛个人题解

目录 100162.最大频率元素计数 原题链接 思路分析 AC代码 100165.找出数组中的美丽下标I 原题链接 思路分析 AC代码 100160. 价值和小于等于 K 的最大数字 原题链接 思路分析 位运算二分 AC代码 100207.找出数组中的美丽下标II 原题链接 思路分析 AC代码 10016…

51-13 多模态论文串讲—BEiT v3 论文精读

BEiT-3的核心思想是将图像建模为一种语言,这样我们就可以对图像、文本以及图像-文本对进行统一的mask modeling。Multi-way transformer模型可以有效地完成不同的视觉和视觉语言任务,使其成为通用建模的一个有效选择。 同时,本文也对多模态大…

K8s-Pod资源(一)Pod介绍、创建Pod、Pod简单资源配额

Pod概述 Kubernetes Pod | Kubernetes Pod是Kubernetes中的最小调度单元,k8s都是以pod的方式运行服务的 一个pod可以指定镜像,封装一个或多个容器 pod需要调度到工作节点运行,节点的选择由scheduler调度器实现 pod定义时,会…

Android Studio 如何设置中文

Android Studio 是一个为 Adndroid 平台开发程序的集成开发环境(IDE)。 如何安装中文插件 在 Jetbrains 家族的插件市场上,是能够搜到语言包插件的,正常情况下安装之后只需要重启即可享受中文界面,可AndroidStudio 中…