pytorch小记(七):pytorch中的保存/加载模型操作

pytorch小记(七):pytorch中的保存/加载模型操作

  • 1. 加载模型参数 (`state_dict`)
    • 1.1 保存模型参数
    • 1.2 加载模型参数
    • 1.3 常见变种
      • 1.3.1 指定加载设备
      • 1.3.2 非严格加载(跳过部分层)
      • 1.3.3 打印加载的参数
  • 2. 加载整个模型
    • 2.1 保存整个模型
    • 2.2 加载整个模型
    • 2.3 注意事项
  • 3. 总结
  • 4. 加载模型的完整代码示例
    • 4.1 保存和加载参数
    • 4.2 保存和加载整个模型
    • 4.3 加载到不同设备
    • 4.4 忽略部分参数(非严格加载)
    • 5. 检查模型是否加载成功


在 PyTorch 中,加载模型通常分为两种情况:加载模型参数(state_dict)加载整个模型。以下是加载模型的所有相关操作及其详细步骤:


1. 加载模型参数 (state_dict)

当仅保存了模型的参数时(使用 model.state_dict() 保存),加载模型的步骤如下:

1.1 保存模型参数

torch.save(model.state_dict(), 'model.pth')
  • 文件内容:只保存模型的参数(权重和偏置)。
  • 优点
    • 节省存储空间。
    • 灵活性更高,可以与不同的模型架构配合使用。
  • 缺点
    • 需要手动重新定义模型结构。

1.2 加载模型参数

  1. 重新定义模型架构:

    model = MyModel()  # 替换为你的模型类
    
  2. 加载参数:

    state_dict = torch.load('model.pth')  # 加载参数字典
    model.load_state_dict(state_dict)    # 加载参数到模型
    
  3. 选择运行设备:

    model.to('cuda')  # 如果需要运行在 GPU 上
    

1.3 常见变种

1.3.1 指定加载设备

  • 如果保存时模型在 GPU 上,而加载时在 CPU 环境中,可以使用 map_location
    state_dict = torch.load('model.pth', map_location='cpu')
    

1.3.2 非严格加载(跳过部分层)

  • 如果保存的参数与模型结构不完全匹配(例如额外的层或不同的顺序),可以使用 strict=False
    model.load_state_dict(state_dict, strict=False)
    

1.3.3 打印加载的参数

  • 可以检查参数字典的内容:
    print(state_dict.keys())
    

2. 加载整个模型

当模型是通过 torch.save(model) 保存时,文件包含了模型的结构和参数,加载更为简单。

2.1 保存整个模型

torch.save(model, 'model_full.pth')
  • 文件内容:包含模型的架构和参数。
  • 优点
    • 无需重新定义模型结构。
    • 直接加载并使用。
  • 缺点
    • 文件依赖于保存时的代码版本(如模型定义)。
    • 文件体积较大。

2.2 加载整个模型

model = torch.load('model_full.pth')
model.to('cuda')  # 如果需要在 GPU 上运行

2.3 注意事项

  • 动态定义的模型
    • 如果模型结构是动态定义的(如包含条件逻辑),保存和加载整个模型可能会依赖于代码的一致性。
    • 确保在加载时导入了与保存时相同的模型类。

3. 总结

操作使用场景优点缺点
保存参数 (state_dict)推荐大多数情况文件小、灵活性高需要手动定义模型架构
保存整个模型模型复杂且固定时不需要重新定义模型,直接加载文件大、依赖保存时的代码版本

4. 加载模型的完整代码示例

4.1 保存和加载参数

import torch
import torch.nn as nn# 定义模型
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.fc = nn.Linear(10, 1)def forward(self, x):return self.fc(x)# 保存参数
model = MyModel()
torch.save(model.state_dict(), 'model.pth')# 加载参数
model = MyModel()  # 重新定义模型
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict)
model.to('cuda')  # 运行在 GPU

4.2 保存和加载整个模型

# 保存整个模型
torch.save(model, 'model_full.pth')# 加载整个模型
model = torch.load('model_full.pth')
model.to('cuda')  # 运行在 GPU

4.3 加载到不同设备

# 保存参数
torch.save(model.state_dict(), 'model.pth')# 加载到 CPU
state_dict = torch.load('model.pth', map_location='cpu')
model.load_state_dict(state_dict)# 加载到 GPU
model.to('cuda')

4.4 忽略部分参数(非严格加载)

# 保存参数
torch.save(model.state_dict(), 'model.pth')# 加载参数(非严格模式)
model = MyModel()
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict, strict=False)

5. 检查模型是否加载成功

  1. 验证权重是否加载

    for name, param in model.named_parameters():print(f"{name}: {param.data}")
    
  2. 进行推理验证

    x = torch.randn(1, 10).to('cuda')  # 假设输入维度为 10
    output = model(x)
    print(output)
    

通过以上操作,你可以灵活加载 PyTorch 模型,无论是仅加载参数还是加载整个模型结构和权重。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/65829.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Mysql--运维篇--主从复制和集群(主从复制I/O线程,SQL线程,二进制日志,中继日志,集群NDB)

一、主从复制 MySQL的主从复制(Master-Slave Replication)是一种数据冗余和高可用性的解决方案,它通过将一个或多个从服务器(Slave)与主服务器(Master)同步来实现。主从复制的基本原理是&#…

【EI会议征稿通知】第十一届机械工程、材料和自动化技术国际会议(MMEAT 2025)

本次大会旨在汇聚全球机械工程、材料科学及自动化技术的创新学者和行业专家,为他们提供一个卓越的交流与合作平台。随着全球对可持续技术和智能制造需求的不断增加,MMEAT 2025将重点关注这些领域的最新发展趋势和未来前景。此次大会的主要目标是推动机械…

OpenCV基础:视频的采集、读取与录制

从摄像头采集视频 相关接口 - VideoCapture VideoCapture 用于从视频文件、摄像头或其他视频流设备中读取视频帧。它可以捕捉来自多种源的视频。 主要参数: cv2.VideoCapture(source): source: 这是一个整数或字符串,表示视频的来源。 如果是整数&a…

解读Linux Bridge中的东西流向与南北流向

解读Linux Bridge中的东西流向与南北流向 在现代云计算和虚拟化环境中,网络流量的管理和优化变得越来越重要。Linux Bridge作为Linux内核提供的一个强大的二层交换机工具,在虚拟化和容器化应用中扮演着至关重要的角色。本文将深入探讨Linux Bridge中的两…

车联网安全--TLS握手过程详解

目录 1. TLS协议概述 2. 为什么要握手 2.1 Hello 2.2 协商 2.3 同意 3.总共握了几次手? 1. TLS协议概述 车内各ECU间基于CAN的安全通讯--SecOC,想必现目前多数通信工程师们都已经搞的差不多了(不要再问FvM了);…

RuoYi Cloud项目解读【四、项目配置与启动】

四、项目配置与启动 当上面环境全部准备好之后,接下来就是项目配置。需要将项目相关配置修改成当前相关环境。 1 后端配置 1.1 数据库 创建数据库ry-cloud并导入数据脚本ry_2024xxxx.sql(必须),quartz.sql(可选&…

第432场周赛:跳过交替单元格的之字形遍历、机器人可以获得的最大金币数、图的最大边权的最小值、统计 K 次操作以内得到非递减子数组的数目

Q1、跳过交替单元格的之字形遍历 1、题目描述 给你一个 m x n 的二维数组 grid,数组由 正整数 组成。 你的任务是以 之字形 遍历 grid,同时跳过每个 交替 的单元格。 之字形遍历的定义如下: 从左上角的单元格 (0, 0) 开始。在当前行中向…

Harry技术添加存储(minio、aliyun oss)、短信sms(aliyun、模拟)、邮件发送等功能

Harry技术添加存储(minio、aliyun oss)、短信sms(aliyun、模拟)、邮件发送等功能 基于SpringBoot3Vue3前后端分离的Java快速开发框架 项目简介:基于 JDK 17、Spring Boot 3、Spring Security 6、JWT、Redis、Mybatis-P…

R数据分析:多分类问题预测模型的ROC做法及解释

有同学做了个多分类的预测模型,结局有三个类别,做的模型包括多分类逻辑回归、随机森林和决策树,多分类逻辑回归是用ROC曲线并报告AUC作为模型评估的,后面两种模型报告了混淆矩阵,审稿人就提出要统一模型评估指标。那么肯定是统一成ROC了,刚好借这个机会给大家讲讲ROC在多…

记一次学习skynet中的C/Lua接口编程解析protobuf过程

1.引言 最近在学习skynet过程中发现在网络收发数据的过程中数据都是裸奔,就想加入一种数据序列化方式,json、xml简单好用,但我就是不想用,于是就想到了protobuf,对于protobuf C/C的使用个人感觉有点重,正好…

SQLAlchemy

https://docs.sqlalchemy.org.cn/en/20/orm/quickstart.htmlhttps://docs.sqlalchemy.org.cn/en/20/orm/quickstart.html 声明模型 在这里,我们定义模块级构造,这些构造将构成我们从数据库中查询的结构。这种结构被称为 声明式映射,它同时定…

Trimble自动化激光监测支持历史遗产实现可持续发展【沪敖3D】

故事桥(Story Bridge)位于澳大利亚布里斯班,建造于1940年,全长777米,横跨布里斯班河,可载汽车、自行车和行人往返于布里斯班的北部和南部郊区。故事桥是澳大利亚最长的悬臂桥,是全世界两座手工建…

Playwright vs Selenium:全面对比分析

在现代软件开发中,自动化测试工具在保证应用质量和加快开发周期方面发挥着至关重要的作用。Selenium 作为自动化测试领域的老牌工具,长期以来被广泛使用。而近年来,Playwright 作为新兴工具迅速崛起,吸引了众多开发者的关注。那么…

Windows 程序设计3:宽窄字节的区别及重要性

文章目录 前言一、宽窄字节简介二、操作系统及VS编译器对宽窄字节的编码支持1. 操作系统2. 编译器 三、宽窄字符串的优缺点四、宽窄字节数据类型总结 前言 Windows 程序设计3:宽窄字节的区别及重要性。 一、宽窄字节简介 在C中,常用的字符串指针就是ch…

进阶——十六届蓝桥杯嵌入式熟练度练习(LED的全开,全闭,点亮指定灯,交替闪烁,PWM控制LED呼吸灯)

点亮灯的函数 void led_show(unsigned char upled) { HAL_GPIO_WritePin(GPIOC,GPIO_PIN_All,GPIO_PIN_SET); HAL_GPIO_WritePin(GPIOC,upled<<8,GPIO_PIN_RESET); HAL_GPIO_WritePin(GPIOD,GPIO_PIN_2,GPIO_PIN_SET); HAL_GPIO_WritePin(GPIOD,GPIO_PIN_2,GPIO_PIN_RE…

力扣 最大子数组和

动态规划&#xff0c;前缀和&#xff0c;维护状态更新。 题目 从题可以看出&#xff0c;找的是最大和的连续子数组&#xff0c;即一个数组中的其中一个连续部分。从前往后遍历&#xff0c;每遍历到一个数可以尝试做叠加&#xff0c;注意是尝试&#xff0c;因为有可能会遇到一个…

Homestyler 和 Tripo AI 如何利用人工智能驱动的 3D 建模改变定制室内设计

让设计梦想照进现实 在Homestyler,我们致力于为每一个梦想设计师提供灵感的源泉,而非挫折。无论是初学者打造第一套公寓,或是专业设计师展示作品集,我们的直观工具都能让您轻松以惊人的3D形式呈现空间。 挑战:实现定制设计的新纪元 我们知道,将个人物品如传家宝椅子、…

算法练习4——一个六位数

这道题特别妙 大家仔细做一做 我这里采用的是动态规划来解这道题 结合题目要求找出数与数之间的规律 抽象出状态转移方程 题目描述 有一个六位数&#xff0c;其个位数字 7 &#xff0c;现将个位数字移至首位&#xff08;十万位&#xff09;&#xff0c;而其余各位数字顺序不…

client-go 的 QPS 和 Burst 限速

1. 什么是 QPS 和 Burst &#xff1f; 在 kubernetes client-go 中&#xff0c;QPS 和 Burst 是用于控制客户端与 Kubernetes API 交互速率的两个关键参数&#xff1a; QPS (Queries Per Second) 定义&#xff1a;表示每秒允许发送的请求数量&#xff0c;即限速器的平滑速率…

太原理工大学软件设计与体系结构 --javaEE

这个是简答题的内容 选择题的一些老师会给你们题库&#xff0c;一些注意的点我会做出文档在这个网址 项目目录预览 - TYUT复习资料:复习资料 - GitCode 希望大家可以给我一些打赏 什么是Spring的IOC和DI IOC 是一种设计思想&#xff0c;它将对象的创建和对象之间的依赖关系…