pytorch模型的保存与加载

1 pytorch保存和加载模型的三种方法

PyTorch提供了三种方式来保存和加载模型,在这三种方式中,加载模型的代码和保存模型的代码必须相匹配,才能保证模型的加载成功。通常情况下,使用第一种方式(保存和加载模型状态字典)更加常见,因为它更轻量且不依赖于特定的模型类。

1.1 仅保存和加载模型参数(推荐)

1.1.1 保存模型参数

import torch
import torch.nn as nnmodel = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))# 保存整个模型
torch.save(model.state_dict(), 'sample_model.pt')

1.1.2 加载模型参数

import torch
import torch.nn as nn# 下载模型参数 并放到模型中
loaded_model = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))
loaded_model.load_state_dict(torch.load('sample_model.pt'))
print(loaded_model)

显示如下:

Sequential((0): Linear(in_features=128, out_features=16, bias=True)(1): ReLU()(2): Linear(in_features=16, out_features=1, bias=True)
)

net.state_dict(),在PyTorch中,Module 的可学习参数 (即权重和偏差),模块模型包含在参数中 (通过 model.parameters() 访问)。state_dict 是一个从参数名称隐射到参数 Tesnor 的有序字典对象。只有具有可学习参数的层(卷积层、线性层等) 才有 state_dict 中的条目。

1.2 保存和加载整个模型

1.2.1 保存整个模型

import torch
import torch.nn as nnnet = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))# 保存整个模型,包含模型结构和参数
torch.save(net, 'sample_model.pt')

1.2.2  加载整个模型

import torch
import torch.nn as nn# 加载整个模型,包含模型结构和参数
loaded_model = torch.load('sample_model.pt')
print(loaded_model)

显示如下:

Sequential((0): Linear(in_features=128, out_features=16, bias=True)(1): ReLU()(2): Linear(in_features=16, out_features=1, bias=True)
)

1.3 导出和加载ONNX格式模型

1.3.1 保存模型

import torch
import torch.nn as nnmodel = nn.Sequential(nn.Linear(128, 16), nn.ReLU(), nn.Linear(16, 1))input_sample = torch.randn(16, 128)  # 提供一个输入样本作为示例
torch.onnx.export(model, input_sample, 'sample_model.onnx')

1.3.2 加载模型

import torch
import torch.nn as nn
import onnx
import onnxruntimeloaded_model = onnx.load('sample_model.onnx')
session = onnxruntime.InferenceSession('sample_model.onnx')
print(session)

2 模型保存与加载使用的函数

2.1 保存模型函数torch.save

将对象序列化保存到磁盘中,该方法原理是基于python中的pickle来序列化,各种Models,tensors,dictionaries 都可以使用该方法保存。保存的模型文件名可以是.pth, .pt, .pkl

def save(obj: object,f: FILE_LIKE,pickle_module: Any = pickle,pickle_protocol: int = DEFAULT_PROTOCOL,_use_new_zipfile_serialization: bool = True
) -> None:
  • obj:保存的对象
  • f:一个类似文件的对象(必须实现写入和刷新)或字符串或操作系统。包含文件名的类似路径对象
  • pickle_module:用于挑选元数据和对象的模块
  • pickle_protocol:可以指定以覆盖默认协议

2.2 加载模型函数torch.load

def load(f: FILE_LIKE,map_location: MAP_LOCATION = None,pickle_module: Any = None,*,weights_only: bool = False,**pickle_load_args: Any
) -> Any:
  • f:类文件对象 (返回文件描述符)或一个保存文件名的字符串
  • map_location:一个函数或字典规定如何映射存储设备,torch.device对象
  • pickle_module:用于 unpickling 元数据和对象的模块 (必须匹配序列化文件时的 pickle_module )

2.3 加载模型参数torch.nn.Module.load_state_dict

序列化 (Serialization)是将对象的状态信息转换为可以存储或传输的形式的过程。 在序列化期间,对象将其当前状态写入到临时或持久性存储区。以后,可以通过从存储区中读取或反序列化对象的状态,重新创建该对象。

def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',strict: bool = True):
  • state_dict:保存 parameters 和 persistent buffers 的字典
  • strict:可选,bool型。state_dict 中的 key 是否和 model.state_dict() 返回的 key 一致。

2.4 状态字典state_dict

函数作用是“获取优化器当前状态信息字典”,在神经网络中模型上训练出来的模型参数,也就是权重和偏置值。在Pytorch中,定义网络模型是通过继承torch.nn.Module来实现的。其网络模型中包含可学习的参数(weights, bias, 和一些登记的缓存如batchnorm’s running_mean 等)。模型内部的可学习参数可通过两种方式进行调用:

  • 通过model.parameters()这个生成器来访问所有参数。
  • 通过model.state_dict()来为每一层和它的参数建立一个映射关系并存储在字典中,其键值由每个网络层和其对应的参数张量构成。
def state_dict(self, destination=None, prefix='', keep_vars=False):

除模型外,优化器对象(torch.optim)同样也有一个状态字典,包含的优化器状态信息以及使用的超参数。由于状态字典属于Python 字典,因此对 PyTorch 模型和优化器的保存、更新、替换、恢复等操作都比较便捷。

3 指定map_location加载模型

采用仅加载模型参数的方式,指定设备类型进行模型加载,代码如下:

model_path = '/opt/sample_model.pth'device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
map_location = torch.device(device)model.load_state_dict(torch.load(self.model_path, map_location=self.map_location))

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/13248.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Linux下6818开发板(ARM)】硬件空间挂载

(꒪ꇴ꒪ ),hello我是祐言博客主页:C语言基础,Linux基础,软件配置领域博主🌍快上🚘,一起学习!送给读者的一句鸡汤🤔:集中起来的意志可以击穿顽石!作者水平很有限,如果发现错误&#x…

linux----源码安装如何加入到系统服务中(systemclt)

将自己源码安装的软件加入到系统服务中。例如nginx,mysql 就以nginx为例,源码安装,加入到系统服务中 使用yum安装nginx,自动会加入到系统服务 16-Linux系统服务 - 刘清政 - 博客园 (cnblogs.com) 第一步: 源码安装好nginx之后&#xff0…

【Maven】Maven配置国内镜像

文章目录 1. 配置maven的settings.xml文件1.1. 先把镜像mirror配置好1.2. 再把仓库配置好 2. 在idea中引用3. 参考资料 网上配置maven国内镜像的文章很多,为什么选择我,原因是:一次配置得永生、仓库覆盖广、仓库覆盖全面、作者自用的配置。 1…

JavaSE - Sting类

目录 一. 字符串的定义 二. String类中的常用方法 1. 比较两个字符串是否相等(返回值是boolean类型) 2. 比较两个字符串的大小(返回值是int类型) 3. 字符串查找 (1)s1.charAt(index) index:下标&…

Vue3中使用pinia

在Vue 3中使用Pinia,您需要按照以下步骤进行设置: 安装Pinia: npm install pinia创建和配置Pinia存储: // main.jsimport { createApp } from vue import { createPinia } from pinia import App from ./App.vueconst app create…

基于RK3588+AI的边缘计算算法方案:智慧园区、智慧社区、智慧物流

RK3588 AI 边缘计算主板规格书简介 关于本文档 本文档详细介绍了基于Rockchip RK3588芯片的AI边缘计算主板外形、尺寸、技术规格,以及详细的硬件接口设计参考说明,使客户可以快速将RK3588边缘计算主板应用于工业互联网、智慧城市、智慧安防、智慧交通&am…

Python 进阶(四):日期和时间(time、datetime、calendar 模块)

❤️ 博客主页:水滴技术 🌸 订阅专栏:Python 入门核心技术 🚀 支持水滴:点赞👍 收藏⭐ 留言💬 文章目录 1. time模块1.1 获取当前时间1.2 时间休眠1.3 格式化时间 2. datetime模块2.1 获取当前…

EXCEL数据自动web网页查询----高效工作,做个监工

目的 自动将excel将数据填充到web网页,将反馈的数据粘贴到excel表 准备 24KB的鼠标连点器软件(文末附链接)、Excel 宏模块 优势 不需要编程、web验证、爬虫等风险提示。轻量、稳定、安全。 缺点 效率没那么快 演示 宏环境 ht…

Go语法入门 + 项目实战

👂 Take me Hand Acoustic - Ccile Corbel - 单曲 - 网易云音乐 第3个小项目有问题,不能在Windows下跑,懒得去搜Linux上怎么跑了,已经落下进度了.... 目录 😳前言 🍉Go两小时 🔑小项目实战 …

《Kubernetes故障篇:unable to retrieve OCI runtime error》

一、背景信息 1、环境信息如下: 操作系统K8S版本containerd版本Centos7.6v1.24.12v1.6.12 2、报错信息如下: Warning FailedCreatePodSandBox 106s (x39 over 10m) kubelet (combined from similar events): Failed to create pod sandbox: rpc error: …

【SAP Abap】记录一次SAP长文本内容通过Web页面完整显示的应用

【SAP Abap】记录一次SAP长文本内容通过Web页面完整显示的应用 1、业务背景2、实现效果3、开发代码3.1、拼接html3.2、显示html3.3、ALV导出Excel 1、业务背景 业务在销售订单中,通过长文本描述,记录了一些生产备注信息,如生产标准、客户要求…

Jacobi雅克比算法计算特征向量-全网最简单

算法原理 算法涉及数据: 矩阵V:存储特征向量。 矩阵A:表示需要求特征向量的实对称矩阵。算法过程:(1)初始化V为对角矩阵,即主对角线的元素是1,其他元素都为0。(2&#x…

CentOS7安装jenkins

一、安装相关依赖 sudo yum install -y wget sudo yum install -y fontconfig java-11-openjdk二、安装Jenkins 可以查看官网的安装方式 安装官网步骤 先导入jenkins yum 源 sudo wget -O /etc/yum.repos.d/jenkins.repo https://pkg.jenkins.io/redhat-stable/jenkins.repo…

MySQL~DCL

三、DCL 1、SQL分类 DDL:操作数据库和表 DML:增删改表中数据 DQL:查询表中数据 DCL:管理用户,授权 DBA:数据库管理员 DCL:管理用户,授权 2、管理用户 2.1 添加用户 语法&a…

索引的数据结构

索引的数据结构 部分资料来自B站尚硅谷-宋红康老师 1. 为什么使用索引 使用索引是为了加快数据库的查询速度和提高数据库的性能。索引是数据库表中的一种数据结构,它可以帮助数据库快速定位并检索所需的数据。 当数据库表中的数据量较大时,如果没有索…

ELK + Fliebeat + Kafka日志系统

参考: ELKFilebeatKafka分布式日志管理平台搭建_51CTO博客_elk 搭建 ELK 日志分析系统概述及部署(上)-阿里云开发者社区 ELK是三个开源软件的缩写,分别表示:Elasticsearch , Logstash, Kibana , 它们都是开源软件。…

Python桥接模式介绍、使用

一、Python桥接模式介绍 概念: Python桥接模式(Bridge Pattern)是一种软件设计模式,用于将抽象部分与其实现部分分离,使它们可以独立地变化。 它可以通过使用桥接接口来创建一个桥接对象来连接抽象和实现部分。 功能…

HIS信息管理系统 HIS源码

HIS(Hospital Information System)是覆盖医院所有业务和业务全过程的信息管理系统。 HIS系统以财务信息、病人信息和物资信息为主线,通过对信息的收集、存储、传递、统计、分析、综合查询、报表输出和信息共享,及时为医院领导及各…

Verilog语法学习——LV6_多功能数据处理器

LV6_多功能数据处理器 题目来源于牛客网 [牛客网在线编程_Verilog篇_Verilog快速入门 (nowcoder.com)](https://www.nowcoder.com/exam/oj?page1&tabVerilog篇&topicId301) 题目 描述 根据指示信号select的不同,对输入信号a,b实现不同的运算。输入信号a…

解决使用@Field注解配置分词器失效问题(Spring Data Elasticsearch)

问题复现:插入数据时,实体类配置的Field注解没有生效 实体类: package cn.aopmin.pojo;import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; import org.springframework.data.annotation.Id; import…