pytorch小记(七):pytorch中的保存/加载模型操作

pytorch小记(七):pytorch中的保存/加载模型操作

  • 1. 加载模型参数 (`state_dict`)
    • 1.1 保存模型参数
    • 1.2 加载模型参数
    • 1.3 常见变种
      • 1.3.1 指定加载设备
      • 1.3.2 非严格加载(跳过部分层)
      • 1.3.3 打印加载的参数
  • 2. 加载整个模型
    • 2.1 保存整个模型
    • 2.2 加载整个模型
    • 2.3 注意事项
  • 3. 总结
  • 4. 加载模型的完整代码示例
    • 4.1 保存和加载参数
    • 4.2 保存和加载整个模型
    • 4.3 加载到不同设备
    • 4.4 忽略部分参数(非严格加载)
    • 5. 检查模型是否加载成功


在 PyTorch 中,加载模型通常分为两种情况:加载模型参数(state_dict)加载整个模型。以下是加载模型的所有相关操作及其详细步骤:


1. 加载模型参数 (state_dict)

当仅保存了模型的参数时(使用 model.state_dict() 保存),加载模型的步骤如下:

1.1 保存模型参数

torch.save(model.state_dict(), 'model.pth')
  • 文件内容:只保存模型的参数(权重和偏置)。
  • 优点
    • 节省存储空间。
    • 灵活性更高,可以与不同的模型架构配合使用。
  • 缺点
    • 需要手动重新定义模型结构。

1.2 加载模型参数

  1. 重新定义模型架构:

    model = MyModel()  # 替换为你的模型类
    
  2. 加载参数:

    state_dict = torch.load('model.pth')  # 加载参数字典
    model.load_state_dict(state_dict)    # 加载参数到模型
    
  3. 选择运行设备:

    model.to('cuda')  # 如果需要运行在 GPU 上
    

1.3 常见变种

1.3.1 指定加载设备

  • 如果保存时模型在 GPU 上,而加载时在 CPU 环境中,可以使用 map_location
    state_dict = torch.load('model.pth', map_location='cpu')
    

1.3.2 非严格加载(跳过部分层)

  • 如果保存的参数与模型结构不完全匹配(例如额外的层或不同的顺序),可以使用 strict=False
    model.load_state_dict(state_dict, strict=False)
    

1.3.3 打印加载的参数

  • 可以检查参数字典的内容:
    print(state_dict.keys())
    

2. 加载整个模型

当模型是通过 torch.save(model) 保存时,文件包含了模型的结构和参数,加载更为简单。

2.1 保存整个模型

torch.save(model, 'model_full.pth')
  • 文件内容:包含模型的架构和参数。
  • 优点
    • 无需重新定义模型结构。
    • 直接加载并使用。
  • 缺点
    • 文件依赖于保存时的代码版本(如模型定义)。
    • 文件体积较大。

2.2 加载整个模型

model = torch.load('model_full.pth')
model.to('cuda')  # 如果需要在 GPU 上运行

2.3 注意事项

  • 动态定义的模型
    • 如果模型结构是动态定义的(如包含条件逻辑),保存和加载整个模型可能会依赖于代码的一致性。
    • 确保在加载时导入了与保存时相同的模型类。

3. 总结

操作使用场景优点缺点
保存参数 (state_dict)推荐大多数情况文件小、灵活性高需要手动定义模型架构
保存整个模型模型复杂且固定时不需要重新定义模型,直接加载文件大、依赖保存时的代码版本

4. 加载模型的完整代码示例

4.1 保存和加载参数

import torch
import torch.nn as nn# 定义模型
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.fc = nn.Linear(10, 1)def forward(self, x):return self.fc(x)# 保存参数
model = MyModel()
torch.save(model.state_dict(), 'model.pth')# 加载参数
model = MyModel()  # 重新定义模型
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict)
model.to('cuda')  # 运行在 GPU

4.2 保存和加载整个模型

# 保存整个模型
torch.save(model, 'model_full.pth')# 加载整个模型
model = torch.load('model_full.pth')
model.to('cuda')  # 运行在 GPU

4.3 加载到不同设备

# 保存参数
torch.save(model.state_dict(), 'model.pth')# 加载到 CPU
state_dict = torch.load('model.pth', map_location='cpu')
model.load_state_dict(state_dict)# 加载到 GPU
model.to('cuda')

4.4 忽略部分参数(非严格加载)

# 保存参数
torch.save(model.state_dict(), 'model.pth')# 加载参数(非严格模式)
model = MyModel()
state_dict = torch.load('model.pth')
model.load_state_dict(state_dict, strict=False)

5. 检查模型是否加载成功

  1. 验证权重是否加载

    for name, param in model.named_parameters():print(f"{name}: {param.data}")
    
  2. 进行推理验证

    x = torch.randn(1, 10).to('cuda')  # 假设输入维度为 10
    output = model(x)
    print(output)
    

通过以上操作,你可以灵活加载 PyTorch 模型,无论是仅加载参数还是加载整个模型结构和权重。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/65829.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Mysql--运维篇--主从复制和集群(主从复制I/O线程,SQL线程,二进制日志,中继日志,集群NDB)

一、主从复制 MySQL的主从复制(Master-Slave Replication)是一种数据冗余和高可用性的解决方案,它通过将一个或多个从服务器(Slave)与主服务器(Master)同步来实现。主从复制的基本原理是&#…

【EI会议征稿通知】第十一届机械工程、材料和自动化技术国际会议(MMEAT 2025)

本次大会旨在汇聚全球机械工程、材料科学及自动化技术的创新学者和行业专家,为他们提供一个卓越的交流与合作平台。随着全球对可持续技术和智能制造需求的不断增加,MMEAT 2025将重点关注这些领域的最新发展趋势和未来前景。此次大会的主要目标是推动机械…

OpenCV基础:视频的采集、读取与录制

从摄像头采集视频 相关接口 - VideoCapture VideoCapture 用于从视频文件、摄像头或其他视频流设备中读取视频帧。它可以捕捉来自多种源的视频。 主要参数: cv2.VideoCapture(source): source: 这是一个整数或字符串,表示视频的来源。 如果是整数&a…

解读Linux Bridge中的东西流向与南北流向

解读Linux Bridge中的东西流向与南北流向 在现代云计算和虚拟化环境中,网络流量的管理和优化变得越来越重要。Linux Bridge作为Linux内核提供的一个强大的二层交换机工具,在虚拟化和容器化应用中扮演着至关重要的角色。本文将深入探讨Linux Bridge中的两…

在线实用工具 json格式化,base64转码,正则表达式测试工具

1、在线json格式化工具: https://json.openai2025.com/ 2、在线base64转码工具 https://base64.openai2025.com/ 3、在线正则表达式测试工具 https://reg.openai2025.com/ 4、在线去水印工具 https://watermark.openai2025.com

java 中 main 方法使用 KafkaConsumer 拉取 kafka 消息如何禁止输出 debug 日志

pom 依赖&#xff1a; <dependency><groupId>org.springframework.kafka</groupId><artifactId>spring-kafka</artifactId><version>2.5.14.RELEASE</version> </dependency> 或者 <dependency><groupId>org.ap…

车联网安全--TLS握手过程详解

目录 1. TLS协议概述 2. 为什么要握手 2.1 Hello 2.2 协商 2.3 同意 3.总共握了几次手&#xff1f; 1. TLS协议概述 车内各ECU间基于CAN的安全通讯--SecOC&#xff0c;想必现目前多数通信工程师们都已经搞的差不多了&#xff08;不要再问FvM了&#xff09;&#xff1b;…

RuoYi Cloud项目解读【四、项目配置与启动】

四、项目配置与启动 当上面环境全部准备好之后&#xff0c;接下来就是项目配置。需要将项目相关配置修改成当前相关环境。 1 后端配置 1.1 数据库 创建数据库ry-cloud并导入数据脚本ry_2024xxxx.sql&#xff08;必须&#xff09;&#xff0c;quartz.sql&#xff08;可选&…

C#对象池

一、资源管理的困境与破局 在软件开发的征程中&#xff0c;我们时常陷入资源管理的泥沼。以一个繁忙餐厅为例&#xff0c;每个顾客都急需一个盘子盛美食&#xff0c;可盘子数量有限&#xff0c;如果每次顾客用完盘子后&#xff0c;都不假思索地去清洗一个全新的盘子来供下一位…

Vue.js组件开发-如何使用moment.js

在Vue.js组件开发中&#xff0c;需要处理日期和时间&#xff0c;moment.js 是一个非常有用的库。moment.js 提供了丰富的API来解析、验证、操作和显示日期和时间。 步骤&#xff1a; 1. 安装moment.js 首先&#xff0c;需要通过npm或yarn安装moment.js。在项目根目录下运行以…

微信小程序mp3音频播放组件,仅需传入url即可

// index.js // packageChat/components/audio-player/index.js Component({/*** 组件的属性列表*/properties: {/*** MP3 文件的 URL*/src: {type: String,value: ,observer(newVal, oldVal) {if (newVal ! oldVal && newVal) {// 如果 InnerAudioContext 已存在&…

要避免除数绝对值远远小于被除数绝对值的除法

要避免除数绝对值远远小于被除数绝对值的除法 用绝对值小的数作除数&#xff0c;舍人误差会增大&#xff0c;如计算 x y \frac xy yx​,若 0 < ∣ y ∣ < ∣ x ∣ 0<|y|<|x| 0<∣y∣<∣x∣&#xff0c;则可能对计算结果带来严重影响&#xff0c;应尽量避免…

深入了解OpenStack中的隧道网络

在OpenStack环境中&#xff0c;隧道网络是一项关键技术&#xff0c;它确保了虚拟机之间以及虚拟机与外部网络之间的安全通信。通过隧道机制&#xff0c;我们可以有效地隔离不同租户的流量&#xff0c;并支持多租户环境下的复杂网络需求。之前我们介绍了隧道网络&#xff0c;下面…

4. scala高阶之隐式转换与泛型

背景 上一节&#xff0c;我介绍了scala中的面向对象相关概念&#xff0c;还有一个特色功能&#xff1a;模式匹配。本文&#xff0c;我会介绍另外一个特别强大的功能隐式转换&#xff0c;并在最后介绍scala中泛型的使用 1. 隐式转换 Scala提供的隐式转换和隐式参数功能&#…

pandas与sql对应关系【帮助sql使用者快速上手pandas】

本页旨在提供一些如何使用pandas执行各种SQL操作的示例&#xff0c;来帮助SQL使用者快速上手使用pandas。 目录 SQL语法一、选择SELECT1、选择2、添加计算列 二、连接JOIN ON1、内连接2、左外连接3、右外连接4、全外连接 三、过滤WHERE1、AND2、OR3、IS NULL4、IS NOT NULL5、B…

第432场周赛:跳过交替单元格的之字形遍历、机器人可以获得的最大金币数、图的最大边权的最小值、统计 K 次操作以内得到非递减子数组的数目

Q1、跳过交替单元格的之字形遍历 1、题目描述 给你一个 m x n 的二维数组 grid&#xff0c;数组由 正整数 组成。 你的任务是以 之字形 遍历 grid&#xff0c;同时跳过每个 交替 的单元格。 之字形遍历的定义如下&#xff1a; 从左上角的单元格 (0, 0) 开始。在当前行中向…

《探索鸿蒙Next上开发人工智能游戏应用的技术难点》

在科技飞速发展的当下&#xff0c;鸿蒙Next系统为应用开发带来了新的机遇与挑战&#xff0c;开发一款运行在鸿蒙Next上的人工智能游戏应用更是备受关注。以下是在开发过程中可能会遇到的一些技术难点&#xff1a; 鸿蒙Next系统适配性 多设备协同&#xff1a;鸿蒙Next的一大特色…

Harry技术添加存储(minio、aliyun oss)、短信sms(aliyun、模拟)、邮件发送等功能

Harry技术添加存储&#xff08;minio、aliyun oss&#xff09;、短信sms&#xff08;aliyun、模拟&#xff09;、邮件发送等功能 基于SpringBoot3Vue3前后端分离的Java快速开发框架 项目简介&#xff1a;基于 JDK 17、Spring Boot 3、Spring Security 6、JWT、Redis、Mybatis-P…

Vue2: el-table为每一行添加超链接,并实现光标移至文字上时改变形状

为表格中的某一列添加超链接 一个表格通常有许多列,网上许多教程都可以实现为某一列添加超链接,如下,实现了当光标悬浮在“姓名”上时,改变为手形,点击可实现跳转。 <el-table :data="tableData"><el-table-column label="姓名" prop=&quo…

R数据分析:多分类问题预测模型的ROC做法及解释

有同学做了个多分类的预测模型,结局有三个类别,做的模型包括多分类逻辑回归、随机森林和决策树,多分类逻辑回归是用ROC曲线并报告AUC作为模型评估的,后面两种模型报告了混淆矩阵,审稿人就提出要统一模型评估指标。那么肯定是统一成ROC了,刚好借这个机会给大家讲讲ROC在多…