[pytorch、学习] - 4.5 读取和存储

参考

4.5 读取和存储

到目前为止,我们介绍了如何处理数据以及如何构建、训练和测试深度学习模型。然而在实际中,我们有时需要把训练好的模型部署到很多不同的设备。在这种情况下,我们可以把内存中训练好的模型参数存储在硬盘上供后续读取使用。

4.5.1 读写tensor

我们可以直接使用save函数和load函数分别存储和读取Tensor

下面的例子创建了Tensor变量x,并将其存储在文件名为x.pt的文件里.

import torch
import torch.nn as nnx = torch.ones(3)
torch.save(x, 'x.pt')

在这里插入图片描述

然后我们将数据从存储的文件读回内存

x2 = torch.load('x.pt')
x2

在这里插入图片描述
存储一个Tensor列表并返回

y =torch.zeros(4)
torch.save([x, y], 'xy.pt')
xy_list = torch.load('xy.pt')
xy_list

在这里插入图片描述
在这里插入图片描述
存储并读取一个从字符串映射到Tensor的字典

torch.save({'x': x,'y': y
}, 'xy_dict.pt')
xy = torch.load('xy_dict.pt')
xy

在这里插入图片描述
在这里插入图片描述

4.5.2 读写模型

4.5.2.1 state_dict

static_dict是一个从参数名称映射到参数Tensor的字典对象

class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.hidden = nn.Linear(3, 2)self.act = nn.ReLU()self.output = nn.Linear(2, 1)def forward(self, x):a = self.act(self.hidden(x))return self.output(a)net = MLP()
net.state_dict()

在这里插入图片描述

注意,只有具有可学习参数的层(卷积层、线性层)才有 state_dict中的条目

optimizer = torch.optim.SGD(net.parameters(), lr= 0.001, momentum=0.9)
optimizer.state_dict()

在这里插入图片描述

4.5.2.2 保存和加载模型

PyTorch中保存和加载训练模型有两种常见的方法:

  1. 仅保存和加载模型参数(state_dict)
  2. 保存和加载整个模型。

1. 保存加载static_dict(推荐方式)
torch.save(model.state_dict(), PATH)

# 保存
torch.save(model.state_dict(), PATH)# 加载
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.laod(PATH))

2. 保存和加载整个模型

# 保存
torch.save(model, PATH)# 加载
model = torch.load(PATH)

采用第一种方法来试验一下:

X = torch.randn(2, 3)
Y = net(X)PATH = "./net.pt"
torch.save(net.state_dict(), PATH)net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)Y2 ==Y 

在这里插入图片描述在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/250154.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

spring-boot注解详解(五)

AutoWired 首先要知道另一个东西,default-autowire,它是在xml文件中进行配置的,可以设置为byName、byType、constructor和autodetect;比如byName,不用显式的在bean中写出依赖的对象,它会自动的匹配其它bea…

什么是p12证书?ios p12证书怎么获取?

.cer是苹果的默认证书,在xcode开发打包可以使用,如果在lbuilder、phonegap、HBuilder、AppCan、APICloud这些跨平台开发工具打包,就需要用到p12文件。 .cer证书仅包含公钥,.p12证书可能既包含公钥也包含私钥,这就是他们…

[pytorch、学习] - 4.6 GPU计算

参考 4.6 GPU计算 到目前为止,我们一直使用CPU进行计算。对复杂的神经网络和大规模数据来说,使用CPU来计算可能不够高效。 在本节中,将要介绍如何使用单块NIVIDA GPU进行计算 4.6.1 计算设备 PyTorch可以指定用来存储和计算的设备,如果用内存的CPU或者显存的GPU。默认情况下…

spring-boot注解详解(六)

Target Target说明了Annotation所修饰的对象范围:Annotation可被用于 packages、types(类、接口、枚举、Annotation类型)、类型成员(方法、构造方法、成员变量、枚举值)、方法参数和本地变量(如循环变量、…

[pytorch、学习] - 5.1 二维卷积层

参考 5.1 二维卷积层 卷积神经网络(convolutional neural network)是含有卷积层(convolutional layer)的神经网络。本章介绍的卷积神经网络均使用最常见的二维卷积层。它有高和宽两个空间维度,常用来处理图像数据。本节中,我们将介绍简单形式的二维卷积层的工作原理。 5.1.1…

[51CTO]给您介绍Windows10各大版本之间区别

给您介绍Windows10各大版本之间区别 随着win10的不断普及和推广,越来越多的朋友想安装win10系统了,但是很多朋友不知道win10哪个版本好用,为了让大家能够更好的选择win10系统版本,下面小编就来告诉你 http://os.51cto.com/art/201…

spring-boot注解详解(七)

Configuration 从Spring3.0,Configuration用于定义配置类,可替换xml配置文件,被注解的类内部包含有一个或多个被Bean注解的方法,这些方法将会被AnnotationConfigApplicationContext或AnnotationConfigWebApplicationContext类进行…

[pytorch、学习] - 5.2 填充和步幅

参考 5.2 填充和步幅 5.2.1 填充 填充(padding)是指在输入高和宽的两侧填充元素(通常是0元素)。图5.2里我们在原输入高和宽的两侧分别添加了值为0的元素,使得输入高和宽从3变成了5,并导致输出高和宽由2增加到4。图5.2中的阴影部分为第一个输出元素及其计算所使用的输入和核数…

springboot----shiro集成

springboot中集成shiro相对简单,只需要两个类:一个是shiroConfig类,一个是CustonRealm类。 ShiroConfig类: 顾名思义就是对shiro的一些配置,相对于之前的xml配置。包括:过滤的文件和权限,密码加…

[pytorch、学习] - 5.3 多输入通道和多输出通道

参考 5.3 多输入通道和多输出通道 前面两节里我们用到的输入和输出都是二维数组,但真实数据的维度经常更高。例如,彩色图像在高和宽2个维度外还有RGB(红、绿、蓝)3个颜色通道。假设彩色图像的高和宽分别是h和w(像素),那么它可以表示为一个3 * h * w的多维数组。我们将大小为3…

非阻塞算法简介

在不只一个线程访问一个互斥的变量时,所有线程都必须使用同步,否则就可能会发生一些非常糟糕的事情。Java 语言中主要的同步手段就是 synchronized 关键字(也称为内在锁),它强制实行互斥,确保执行 synchron…

[pytorch、学习] - 5.4 池化层

参考 5.4 池化层 在本节中我们介绍池化(pooling)层,它的提出是为了缓解卷积层对位置的过度敏感性。 5.4.1 二维最大池化层和平均池化层 池化层直接计算池化窗口内元素的最大值或者平均值。该运算也叫做最大池化层或平均池化层。 下面把池化层的前向计算实现在pool2d函数里…

[pytorch、学习] - 5.5 卷积神经网络(LeNet)

参考 5.5 卷积神经网络(LeNet) 卷积层尝试解决两个问题: 卷积层保留输入形状,使图像的像素在高和宽两个方向上的相关性均可能被有效识别;卷积层通过滑动窗口将同一卷积核和不同位置的输入重复计算,从而避免参数尺寸过大。 5.5.1 LeNet模型 LeNet分为…

[pytorch、学习] - 5.6 深度卷积神经网络(AlexNet)

参考 5.6 深度卷积神经网络(AlexNet) 在LeNet提出后的将近20年里,神经网络一度被其他机器学习方法超越,如支持向量机。虽然LeNet可以在早期的小数据集上取得好的成绩,但是在更大的真实数据集上的表现并不尽如人意。一方面,神经网络计算复杂。虽然20世纪…

Springboot---Model,ModelMap,ModelAndView

Model(org.springframework.ui.Model) Model是一个接口,包含addAttribute方法,其实现类是ExtendedModelMap。 ExtendedModelMap继承了ModelMap类,ModelMap类实现了Map接口。 public class ExtendedModelMap extends M…

[pytorch、学习] - 5.7 使用重复元素的网络(VGG)

参考 5.7 使用重复元素的网络(VGG) AlexNet在LeNet的基础上增加了3个卷积层。但AlexNet作者对它们的卷积窗口、输出通道数和构造顺序均做了大量的调整。虽然AlexNet指明了深度卷积神经网络可以取得出色的结果,但并没有提供简单的规则以指导…

[pytorch、学习] - 5.8 网络中的网络(NiN)

参考 5.8 网络中的网络(NiN) 前几节介绍的LeNet、AlexNet和VGG在设计上的共同之处是:先以由卷积层构成的模块充分抽取空间特征,再以由全连接层构成的模块来输出分类结果。其中,AlexNet和VGG对LeNet的改进主要在于如何…

[pytorch、学习] - 5.9 含并行连结的网络(GoogLeNet)

参考 5.9 含并行连结的网络(GoogLeNet) 在2014年的ImageNet图像识别挑战赛中,一个名叫GoogLeNet的网络结构大放异彩。它虽然在名字上向LeNet致敬,但在网络结构上已经很难看到LeNet的影子。GoogLeNet吸收了NiN中网络串联网络的思…

mybits注解详解

一、mybatis 简单注解 关键注解词 : Insert : 插入sql , 和xml insert sql语法完全一样 Select : 查询sql, 和xml select sql语法完全一样 Update : 更新sql, 和xml update sql语法完全一样 Delete : 删除sql, 和xml d…

使用python装饰器计算函数运行时间的实例

使用python装饰器计算函数运行时间的实例 装饰器在python里面有很重要的作用, 如果能够熟练使用,将会大大的提高工作效率 今天就来见识一下 python 装饰器,到底是怎么工作的。 本文主要是利用python装饰器计算函数运行时间 一些需要精确的计算…