PyTorch教程:如何读写张量与模型参数

本文演示了PyTorch中张量(Tensor)和模型参数的保存与加载方法,并提供完整的代码示例及输出结果,帮助读者快速掌握数据持久化的核心操作。


1. 保存和加载单个张量

通过torch.savetorch.load可以直接保存和读取张量。

import torch# 创建并保存张量
x = torch.arange(4)
torch.save(x, 'x-file')# 加载张量
x2 = torch.load('x-file')
print(x2)  # 输出:tensor([0, 1, 2, 3])

输出结果

tensor([0, 1, 2, 3])

2. 保存和加载张量列表

可以将多个张量存储为列表,并一次性加载。

# 创建两个张量并保存为列表
y = torch.zeros(4)
torch.save([x, y], 'x-files')# 加载列表
x2, y2 = torch.load('x-files')
print((x2, y2))

输出结果

(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))

3. 保存和加载字典

通过字典可以更灵活地管理多个张量。

# 创建字典并保存
mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')# 加载字典
mydict2 = torch.load('mydict')
print(mydict2)

输出结果

{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}

4. 定义神经网络模型

以下是一个简单的全连接神经网络示例:

from torch import nn
from torch.nn import functional as Fclass Model(nn.Module):def __init__(self):super().__init__()self.hidden = nn.Linear(20, 256)  # 隐藏层self.output = nn.Linear(256, 10)   # 输出层def forward(self, x):return self.output(F.relu(self.hidden(x)))# 实例化模型并进行前向传播
net = Model()
x = torch.rand(size=(2, 20))
y = net(x)
print(y)

输出结果(因随机初始化可能不同):

tensor([[-0.0711, 0.1161, -0.1113, ..., 0.0787],[-0.0151, 0.0275, -0.1652, ..., 0.0109]], grad_fn=<AddmmBackward0>)

5. 保存模型参数

使用state_dict保存模型参数:

torch.save(net.state_dict(), 'net.params')

6. 加载模型参数并验证

加载参数到新模型实例,并验证一致性:

# 创建新模型并加载参数
clone = Model()
clone.load_state_dict(torch.load('net.params'))
clone.eval()  # 设置为评估模式(关闭Dropout/BatchNorm等)# 比较输出结果
Y_clone = clone(x)
print(Y_clone == y)

输出结果

tensor([[True, True, ..., True],[True, True, ..., True]])

总结

  1. 张量读写:直接使用torch.savetorch.load,支持列表和字典。

  2. 模型参数保存:通过state_dict保存模型状态,加载时需重新实例化模型。

  3. 验证一致性:加载参数后,输出与原模型一致表明操作成功。

通过本文的代码示例,读者可以快速掌握PyTorch中数据和模型参数的持久化方法,为模型训练和部署提供便利。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/75061.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

持续集成:GitLab CI/CD 与 Jenkins CI/CD 的全面剖析

一、引言 在当今快速迭代的软件开发领域,持续集成(Continuous Integration,CI)已成为保障软件质量、加速开发流程的关键实践。通过频繁地将代码集成到共享仓库,并自动进行构建和测试,持续集成能够尽早发现并解决代码冲突和缺陷。而 GitLab CI/CD 和 Jenkins CI/CD 作为两…

Python 序列构成的数组(序列的增量赋值)

序列的增量赋值 增量赋值运算符 和 * 的表现取决于它们的第一个操作对象。简单起 见&#xff0c;我们把讨论集中在增量加法&#xff08;&#xff09;上&#xff0c;但是这些概念对 * 和其他 增量运算符来说都是一样的。 背后的特殊方法是 iadd &#xff08;用于“就地加法”&…

GEO, TCGA 等将被禁用?!这40个公开数据库可能要小心使用了

GEO, TCGA 等将被禁用&#xff1f;&#xff01;这40个公开数据库可能要小心使用了 最近NIH公共数据库开始对中国禁用的消息闹得风风火火&#xff1a; 你认为研究者上传到 GEO 数据库上的数据会被禁用吗&#xff1f; 单选 会&#xff0c;毕竟占用存储资源 不会&#xff0c;不…

【如何自建MCP服务器?从协议原理到实践的全流程指南】

文章目录 如何自建MCP服务器&#xff1f;从协议原理到实践的全流程指南一、MCP协议是什么&#xff1f;核心架构 二、为什么要自建MCP服务器&#xff1f;1. 突破LLM的固有局限2. 实现个性化功能扩展3. 确保数据隐私安全 三、手把手搭建MCP服务器&#xff08;Python示例&#xff…

鸿蒙开发_ARKTS快速入门_语法说明_渲染控制---纯血鸿蒙HarmonyOS5.0工作笔记012

然后我们再来看渲染控制 首先看条件渲染,其实就是根据不同的状态,渲染不同的UI界面 比如下面这个暂停 开启播放的 可以看到就是通过if 这种条件语句 修改状态变量的值 然后我们再来看这个, 下面点击哪个,上面横线就让让他显示哪个 去看一下代码 可以看到,有两个状态变量opt…

【Java设计模式】第3章 软件设计七大原则

3-1 本章导航 学习开辟原则(基础原则)依赖倒置原则单一职责原则接口隔离原则迪米特法则(最少知道原则)里氏替换原则合成复用原则(组合复用原则)核心思想: 设计原则需结合实际场景平衡,避免过度设计。设计模式中可能部分遵循原则,需灵活取舍。3-2 开闭原则讲解 定义 软…

JVM即时编译(JIT)

JVM基础回顾 Java 作为一门高级程序语言&#xff0c;由于它自身的语言特性&#xff0c;它并非直接在硬件上运行&#xff0c;而是通过编译器(前端编译器)将 Java 程序转换成该虚拟机所能识别的指令序列&#xff0c;也就是字节码&#xff0c;然后运行在虚拟机之上的&#xff1b;…

刚体碰撞检测与响应(C++实现)

本文实现一个经典的物理算法&#xff1a;刚体碰撞检测与响应。这个算法用于检测两个刚体&#xff08;如矩形或圆形&#xff09;是否发生碰撞&#xff0c;并在碰撞时更新它们的速度和位置。我们将使用C来实现这个算法&#xff0c;并结合**边界框&#xff08;Bounding Box&#x…

常用的国内镜像源

常见的 pip 镜像源 阿里云镜像&#xff1a;https://mirrors.aliyun.com/pypi/simple/ 清华大学镜像&#xff1a;https://pypi.tuna.tsinghua.edu.cn/simple 中国科学技术大学镜像&#xff1a;https://pypi.mirrors.ustc.edu.cn/simple/ 豆瓣镜像&#xff1a;https://pypi.doub…

鸿蒙小案例-京东登录

效果 代码实现 Entry Component struct Index {build() {Column() {Row() {Image($r(app.media.jd_cancel)).width(20)Text(帮助).fontSize(16).fontColor(#666)}.width(100%).justifyContent(FlexAlign.SpaceBetween)Image($r(app.media.jd_logo)).height(250).width(250)// …

《 Scikit-learn与MySQL的深度协同:构建智能数据生态系统的架构哲学》

在机器学习工程实践中&#xff0c;数据存储与模型训练的割裂始终是制约算法效能的关键瓶颈。Scikit-learn作为经典机器学习库&#xff0c;其与MySQL的深度协同并非简单的数据管道连接&#xff0c;而是构建了一个具备自组织能力的智能数据生态系统。这种集成突破了传统ETL流程的…

华为AI-agent新作:使用自然语言生成工作流

论文标题 WorkTeam: Constructing Workflows from Natural Language with Multi-Agents 论文地址 https://arxiv.org/pdf/2503.22473 作者背景 华为&#xff0c;北京大学 动机 当下AI-agent产品百花齐放&#xff0c;尽管有ReAct、MCP等框架帮助大模型调用工具&#xff0…

关于软件bug描述

软件缺陷&#xff08;Defect&#xff09;&#xff0c;常常又被叫做Bug。 所谓软件缺陷&#xff0c;即为计算机软件或程序中存在的某种破坏正常运行能力的问题、错误&#xff0c;或者隐藏的功能缺陷。缺陷的存在会导致软件产品在某种程度上不能满足用户的需要。IEEE729-1983对缺…

【元表 vs 元方法】

元表 vs 元方法 —— 就像“魔法书”和“咒语”的关系 1. 元表&#xff08;Metatable&#xff09;&#xff1a;魔法书 是什么&#xff1f; 元表是一本**“规则说明书”**&#xff0c;它本身是一个普通的 Lua 表&#xff0c;但可以绑定到其他表上&#xff0c;用来定义这个表应该…

Spring Boot 通过全局配置去除字符串类型参数的前后空格

1、问题 避免前端输入的字符串参数两端包含空格&#xff0c;通过统一处理的方式&#xff0c;trim掉空格 2、实现方式 /*** 去除字符串类型参数的前后空格* author yanlei* since 2022-06-14*/ Configuration AutoConfigureAfter(WebMvcAutoConfiguration.class) public clas…

C语言核心知识点整理:结构体对齐、预处理、文件操作与Makefile

目录 结构体的字节对齐预处理指令详解文件操作基础Makefile自动化构建总结 1. 结构体的字节对齐 字节对齐原理 内存对齐&#xff1a;CPU访问内存时&#xff0c;对齐的地址能提高效率。操作系统要求变量按类型大小对齐。对齐规则&#xff1a; 每个成员的起始地址必须是min(成…

VBA+BOS单据+插件,解决计划任务跟踪的问题之二:导入ERP

第二步&#xff0c;就是要将拆分好的任务导入ERP了 1、将建一个BOS单据叫“任务池”&#xff0c;大概是这样的 然后在拆分工具中进行导数据&#xff0c;点击“数据导出准备”&#xff0c;跳转到“导入ERP”界面&#xff0c;然后点“获取数据”&#xff0c;将拆分好的数据转过来…

使用uglifyjs对静态引入的js文件进行压缩

前言 因为有时候js文件没有npm包&#xff0c;或者需要修改&#xff0c;只能引入静态的js&#xff0c;那么这个时候就可以对js进行压缩了。我其实想通过vite、webpack等插件进行压缩的&#xff0c;可是他都不能定位到public目录下面的文件&#xff0c;所以我只能自己压缩了。编…

蓝桥杯 web 水果拼盘 (css3)

做题步骤&#xff1a; 看结构&#xff1a;html 、css 、f12 分析: f12 查看元素&#xff0c;你会发现水果的高度刚好和拼盘的高度一样&#xff0c;每一种水果的盘子刚好把页面填满了&#xff0c;所以咱们就只要让元素竖着排列&#xff0c;加上是竖着&#xff0c;排不下的换行…

差分音频转单端音频单电源方案

TI LMV321介绍 TI的LMV321是单通道的低压轨到轨输出运算放大器&#xff0c;适用于需要低工作压、节省空间和低成本的应用。 其中&#xff0c;芯片设计中的轨到轨输出&#xff08;Rail-to-Rail Output&#xff09; 是指通过特定的电路设计&#xff0c;使芯片&#xff08;如运算…