PyTorch读写模型(state_dict、torch.save、torch.load)

1. state_dict

在PyTorch中,state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系。(如model的每一层的weights及bias等)

首先,我们来定义一个MLP模型:

import torch.nn as nnclass MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.hidden = nn.Linear(3, 2)self.act = nn.ReLU()self.output = nn.Linear(2, 1)def forward(self, x):a = self.act(self.hidden(x))return self.output(a)net = MLP()
net.state_dict()

输出:

OrderedDict([('hidden.weight', tensor([[ 0.2448,  0.1856, -0.5678],[ 0.2030, -0.2073, -0.0104]])),('hidden.bias', tensor([-0.3117, -0.4232])),('output.weight', tensor([[-0.4556,  0.4084]])),('output.bias', tensor([-0.3573]))])

注意:只有具有可学习参数的层(卷积层、线性层等)才有state_dict中的条目。优化器(optim)也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。

optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer.state_dict()

输出:

{
'state': {}, 
'param_groups': [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1, 2, 3]}]
}

2. 保存和加载模型

PyTorch中保存和加载训练模型有两种常见的方法:

  • 仅保存和加载模型参数(state_dict);
  • 保存和加载整个模型。

1. 保存和加载state_dict(推荐方式)
保存:

torch.save(model.state_dict(), PATH) # 推荐的文件后缀名是pt或pth

加载:

model = ModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))

2. 保存和加载整个模型
保存:

torch.save(model, PATH)

加载:

model = torch.load(PATH)

我们采用推荐的方法一来实验一下:

X = torch.randn(2, 3)
Y = net(X)PATH = "./net.pt"
torch.save(net.state_dict(), PATH)net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)
Y2 == Y

输出:

tensor([[1],[1]], dtype=torch.uint8)

因为这net和net2都有同样的模型参数,那么对同一个输入X的计算结果将会是一样的。上面的输出也验证了这一点。

参考资料

  • https://github.com/ShusenTang/Dive-into-DL-PyTorch/blob/master/docs/chapter04_DL_computation/4.5_read-write.md

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/863653.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

494. 目标和 Medium

给你一个非负整数数组 nums 和一个整数 target 。 向数组中的每个整数前添加 或 - ,然后串联起所有整数,可以构造一个 表达式 : 例如,nums [2, 1] ,可以在 2 之前添加 ,在 1 之前添加 - ,然…

使用Calendar.add进行日期计算

使用Calendar.add进行日期计算 大家好,我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿!今天我们将深入探讨在Java中如何使用Calendar.add方法进行日期计算。Calendar类是…

如何在Ubuntu20上离线安装joern(包括sbt和scala)

在Ubuntu 20上离线安装Joern,由于Joern通常需要通过互联网从其官方源或GitHub等地方下载,但在离线环境中,我们需要通过一些额外的步骤来准备和安装。(本人水平有限,希望得到大家的指正) 我们首先要做的就是…

在QGIS中调用天地图

2019年 1月 1日起,天地图 API及服务接口调用需要获得开发授权,之前使用 QGIS等 GIS软件无法继续调用天地图,这就需要申请一个许可。 一、注册并申请 Key 具体申请可以登录如下地址:https://www.tianditu.gov.cn打开上述网址后点…

速盾:cdn加速哪个好?

在现代互联网时代,网站的速度和稳定性是非常重要的。为了提供最佳的用户体验,许多网站和应用程序都使用CDN(内容分发网络)来加速其内容的传输。CDN是由位于全球各地的分布式服务器组成的网络,其目的是将内容尽可能快地…

工厂方法模式:概念与应用

目录 工厂方法模式工厂方法模式结构工厂方法适合的应用场景工厂方法模式的优缺点练手题目题目描述输入描述输出描述**提示信息**解题: 工厂方法模式 工厂方法模式是一种创建型设计模式, 其在父类中提供一个创建对象的方法, 允许子类决定实例…

SQLite3的使用

14_SQLite3 SQLite3是一个嵌入式数据库系统,它的数据库就是一个文件。SQLite3不需要一个单独的服务器进程或操作系统,不需要配置,这意味着不需要安装或管理,所有的维护都来自于SQLite3软件本身。 安装步骤 在Linux上安装SQLite…

《概率论与数理统计》期末复习笔记_下

目录 第4章 随机变量的数字特征 4.1 数学期望 4.2 方差 4.3 常见分布的期望与方差 4.4 协方差与相关系教 第5章 大数定律和中心极限定理 5.1 大数定律 5.2 中心极限定理 第6章 样本与抽样分布 6.1 数理统汁的基本概念 6.2 抽样分布 6.2.1 卡方分布 6.2.2 t分布 6.…

Winform使用HttpClient调用WebApi的基本用法

Winform程序调用WebApi的方式有很多,本文学习并记录采用HttpClient调用基于GET、POST请求的WebApi的基本方式。WebApi使用之前编写的检索环境检测数据的接口,如下图所示。 调用基于GET请求的无参数WebApi 创建HttpClient实例后调用GetStringAsync函数获…

技术打包 催化剂浸渍制作方法设备

网盘 https://pan.baidu.com/s/1Bybbyy5qEA2uTUlaELmWwg?pwdepdk 改性加氢处理催化剂载体、催化剂及其制备方法和应用.pdf 水滑石基催化剂在高浓度糖转化到1,2-丙二醇中的应用.pdf 海泡石负载铁锰双金属催化剂及其制备方法和应用.pdf 甘油氢解催化剂及其制备方法和应用.pdf 用…

【原理】机器学习中的最小二乘法公式推导过程

本文来自《老饼讲解-BP神经网络》https://www.bbbdata.com/ 目录 一、什么是最小二乘法1.1. 什么是最小二乘法1.2. 最小二乘法的求解公式 二、最小二乘法求解公式的推导 最小二乘法是基本的线性求解问题之一,本文介绍最小二乘法的原理,和最小二法求解公式…

如何使用Spring Boot进行单元测试

如何使用Spring Boot进行单元测试 大家好,我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿!今天我们将探讨如何在Spring Boot项目中进行单元测试,确保代码质量…

Week 4-杨帆-学习总结

目录 28 批量归一化批量规范化的背景和必要性批量规范化的实现理论探讨与争议遇到的问题&解决办法 29 残差网络 ResNet残差网络(ResNet)的核心概念函数类与嵌套函数类残差块(Residual Blocks)的结构与功能深度学习框架的应用模…

【学习笔记】Redis学习笔记——第2章:简单动态字符串

第2章:简单动态字符串 Redis用作键值对或AOF缓冲区的字符串为SDS(简单动态字符串),而不是C语言传统字符串(只用作打印log等不会修改字符串值的地方)。 2.1 SDS的定义 {//SDS字符串长度(buf数组中已使用的空间)int len;//buf数组…

【Vue】Vue3基础

VUE3基础 1、简介2、创建工程2.1 基于vue-cli创建(脚手架webpack)2.2 基于vite创建(推荐)2.3 目录结构2.4 vscode插件推荐 3、核心语法3.1 选项式(options API)和组合式(composition API&#x…

Arduino - LED 矩阵

Arduino - LED 矩阵 Arduino - LED Matrix LED matrix display, also known as LED display, or dot matrix display, are wide-used. In this tutorial, we are going to learn: LED矩阵显示器,也称为LED显示器,或点阵显示器,应用广泛。在…

scatterlist的相关概念与实例分析

概念 scatterlist scatterlist用来描述一块内存,sg_table一般用于将物理不同大小的物理内存链接起来,一次性送给DMA控制器搬运 struct scatterlist {unsigned long page_link; //指示该内存块所在的页面unsigned int offset; //指示该内存块在页面中的…

纯硬件FOC驱动BLDC

1. 硬件FOC 图 1 为采用 FOC 的方式控制 BLDC 电机的过程,经由 FOC 变换( Clark 与 Park 变换) ,将三相电流转换为空间平 行电流 ID 与空间垂直电流 IQ。经过 FOC 逆变化逆( Clark 变换与逆 Park 变换) ,将两相电流转换为三相电流用于控 制电…

喜茶新品被迫更名,内容营销专家刘鑫炜谈品牌定位敏锐度和适应性

喜茶,作为茶饮界的知名品牌,一直以其独特的创意和优质的产品受到消费者的喜爱。然而,近期喜茶推出的一款新品“小奶栀”却因其名称发音问题引发了不小的争议。 事件回顾 “小奶栀”这款新品在上市之初,以其独特的口感和创新的命名…

【算法——快慢指针链表】

【如何判断单链表是否有环?链表中"快慢指针"的妙用】 判断环 快慢指针一开始都在开头,快指针2/s,慢指针1/s;如果链表有环,那么二者一定相遇 那么快慢指针的移动步数固定了吗? 链表中心结点 8…