[pytorch、学习] - 4.5 读取和存储

参考

4.5 读取和存储

到目前为止,我们介绍了如何处理数据以及如何构建、训练和测试深度学习模型。然而在实际中,我们有时需要把训练好的模型部署到很多不同的设备。在这种情况下,我们可以把内存中训练好的模型参数存储在硬盘上供后续读取使用。

4.5.1 读写tensor

我们可以直接使用save函数和load函数分别存储和读取Tensor

下面的例子创建了Tensor变量x,并将其存储在文件名为x.pt的文件里.

import torch
import torch.nn as nnx = torch.ones(3)
torch.save(x, 'x.pt')

在这里插入图片描述

然后我们将数据从存储的文件读回内存

x2 = torch.load('x.pt')
x2

在这里插入图片描述
存储一个Tensor列表并返回

y =torch.zeros(4)
torch.save([x, y], 'xy.pt')
xy_list = torch.load('xy.pt')
xy_list

在这里插入图片描述
在这里插入图片描述
存储并读取一个从字符串映射到Tensor的字典

torch.save({'x': x,'y': y
}, 'xy_dict.pt')
xy = torch.load('xy_dict.pt')
xy

在这里插入图片描述
在这里插入图片描述

4.5.2 读写模型

4.5.2.1 state_dict

static_dict是一个从参数名称映射到参数Tensor的字典对象

class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.hidden = nn.Linear(3, 2)self.act = nn.ReLU()self.output = nn.Linear(2, 1)def forward(self, x):a = self.act(self.hidden(x))return self.output(a)net = MLP()
net.state_dict()

在这里插入图片描述

注意,只有具有可学习参数的层(卷积层、线性层)才有 state_dict中的条目

optimizer = torch.optim.SGD(net.parameters(), lr= 0.001, momentum=0.9)
optimizer.state_dict()

在这里插入图片描述

4.5.2.2 保存和加载模型

PyTorch中保存和加载训练模型有两种常见的方法:

  1. 仅保存和加载模型参数(state_dict)
  2. 保存和加载整个模型。

1. 保存加载static_dict(推荐方式)
torch.save(model.state_dict(), PATH)

# 保存
torch.save(model.state_dict(), PATH)# 加载
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.laod(PATH))

2. 保存和加载整个模型

# 保存
torch.save(model, PATH)# 加载
model = torch.load(PATH)

采用第一种方法来试验一下:

X = torch.randn(2, 3)
Y = net(X)PATH = "./net.pt"
torch.save(net.state_dict(), PATH)net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)Y2 ==Y 

在这里插入图片描述在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/250154.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JAVA排序的方法

//冒泡排序法: package fuxi;public class Bubble { public static void main(String[] args) { int a[] { 10,23,11,56,45,26,59,28,84,79 }; int i,temp; System.out.println("输出原始数组数据:"); for (i…

spring-boot注解详解(五)

AutoWired 首先要知道另一个东西,default-autowire,它是在xml文件中进行配置的,可以设置为byName、byType、constructor和autodetect;比如byName,不用显式的在bean中写出依赖的对象,它会自动的匹配其它bea…

什么是p12证书?ios p12证书怎么获取?

.cer是苹果的默认证书,在xcode开发打包可以使用,如果在lbuilder、phonegap、HBuilder、AppCan、APICloud这些跨平台开发工具打包,就需要用到p12文件。 .cer证书仅包含公钥,.p12证书可能既包含公钥也包含私钥,这就是他们…

[pytorch、学习] - 4.6 GPU计算

参考 4.6 GPU计算 到目前为止,我们一直使用CPU进行计算。对复杂的神经网络和大规模数据来说,使用CPU来计算可能不够高效。 在本节中,将要介绍如何使用单块NIVIDA GPU进行计算 4.6.1 计算设备 PyTorch可以指定用来存储和计算的设备,如果用内存的CPU或者显存的GPU。默认情况下…

adb connect 192.168.1.10 failed to connect to 192.168.1.10:5555

adb connect 192.168.1.10 输出 failed to connect to 192.168.1.10:5555 关闭安卓端Wi-Fi,重新打开连接即可 转载于:https://www.cnblogs.com/sea-stream/p/10020995.html

创建oracle数据库表空间并分配用户

我们在本地的oracle上或者virtualbox的oracle上 创建新的数据库表空间操作:通过system账号来创建并授权/*--创建表空间create tablespace YUJKDATAdatafile c:\yujkdata200.dbf --指定表空间对应的datafile文件的具体的路径size 100mautoextend onnext 10m*/ /*--创…

spring-boot注解详解(六)

Target Target说明了Annotation所修饰的对象范围:Annotation可被用于 packages、types(类、接口、枚举、Annotation类型)、类型成员(方法、构造方法、成员变量、枚举值)、方法参数和本地变量(如循环变量、…

[pytorch、学习] - 5.1 二维卷积层

参考 5.1 二维卷积层 卷积神经网络(convolutional neural network)是含有卷积层(convolutional layer)的神经网络。本章介绍的卷积神经网络均使用最常见的二维卷积层。它有高和宽两个空间维度,常用来处理图像数据。本节中,我们将介绍简单形式的二维卷积层的工作原理。 5.1.1…

[51CTO]给您介绍Windows10各大版本之间区别

给您介绍Windows10各大版本之间区别 随着win10的不断普及和推广,越来越多的朋友想安装win10系统了,但是很多朋友不知道win10哪个版本好用,为了让大家能够更好的选择win10系统版本,下面小编就来告诉你 http://os.51cto.com/art/201…

iOS中NSString转换成HEX(十六进制)-NSData转换成int

NSString *str "0xff055008"; //先以16为参数告诉strtoul字符串参数表示16进制数字,然后使用0x%X转为数字类型 unsigned long red strtoul([str UTF8String],0,16); //strtoul如果传入的字符开头是“0x”,那么第三个参数是0,也是会转为十…

spring-boot注解详解(七)

Configuration 从Spring3.0,Configuration用于定义配置类,可替换xml配置文件,被注解的类内部包含有一个或多个被Bean注解的方法,这些方法将会被AnnotationConfigApplicationContext或AnnotationConfigWebApplicationContext类进行…

[pytorch、学习] - 5.2 填充和步幅

参考 5.2 填充和步幅 5.2.1 填充 填充(padding)是指在输入高和宽的两侧填充元素(通常是0元素)。图5.2里我们在原输入高和宽的两侧分别添加了值为0的元素,使得输入高和宽从3变成了5,并导致输出高和宽由2增加到4。图5.2中的阴影部分为第一个输出元素及其计算所使用的输入和核数…

java实现Comparable接口和Comparator接口,并重写compareTo方法和compare方法

原文地址https://segmentfault.com/a/1190000005738975 实体类:java.lang.Comparable(接口) comareTo(重写方法),业务排序类 java.util.Comparator(接口) compare(重写方法). 这两个接口我们非常的熟悉,但是 在用的时候会有一些不知道怎么下手的感觉&a…

hdu 4714 树+DFS

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid4714 本来想直接求树的直径,再得出答案,后来发现是错的。 思路:任选一个点进行DFS,对于一棵以点u为根节点的子树来说,如果它的分支数大于1&#xff0c…

springboot----shiro集成

springboot中集成shiro相对简单,只需要两个类:一个是shiroConfig类,一个是CustonRealm类。 ShiroConfig类: 顾名思义就是对shiro的一些配置,相对于之前的xml配置。包括:过滤的文件和权限,密码加…

[pytorch、学习] - 5.3 多输入通道和多输出通道

参考 5.3 多输入通道和多输出通道 前面两节里我们用到的输入和输出都是二维数组,但真实数据的维度经常更高。例如,彩色图像在高和宽2个维度外还有RGB(红、绿、蓝)3个颜色通道。假设彩色图像的高和宽分别是h和w(像素),那么它可以表示为一个3 * h * w的多维数组。我们将大小为3…

非阻塞算法简介

在不只一个线程访问一个互斥的变量时,所有线程都必须使用同步,否则就可能会发生一些非常糟糕的事情。Java 语言中主要的同步手段就是 synchronized 关键字(也称为内在锁),它强制实行互斥,确保执行 synchron…

springboot---成员初始化顺序

如果我们的类有如下成员变量: Component public class A {Autowiredpublic B b; // B is a beanpublic static C c; // C is also a beanpublic static int count;public float version;public A() {System.out.println("This is A constructor.");}Au…

[pytorch、学习] - 5.4 池化层

参考 5.4 池化层 在本节中我们介绍池化(pooling)层,它的提出是为了缓解卷积层对位置的过度敏感性。 5.4.1 二维最大池化层和平均池化层 池化层直接计算池化窗口内元素的最大值或者平均值。该运算也叫做最大池化层或平均池化层。 下面把池化层的前向计算实现在pool2d函数里…

mac上安装Chromedriver注意事宜

mac上安装Chromedriver注意事宜: 1.网上下载chromedriver文件或在百度网盘找chromedirver文件 2.将 chromedriver 放置到:/usr/local/bin/,操作如下: 打开Mac终端terminal : 进入 chromedirve文件所在目录,输入命令: s…