PyTorch|保存及加载模型、nn.Sequential、ModuleList和ModuleDict

系列文章目录

PyTorch|Dataset与DataLoader使用、构建自定义数据集
PyTorch|搭建分类网络实例、nn.Module源码学习
pytorch|autograd使用、训练模型

文章目录

  • 系列文章目录
  • 一、保存及加载模型
    • (一)保存及加载模型的权重
    • (二)保存及加载优化器的权重
    • (三)保存及加载整个模型
    • (四)保存及加载更具一般性的checkpoint
    • (五)保存多个模型
  • 二、nn.Sequential源码分析
    • (一)init函数
    • (二)forward函数
  • 三、ModuleList和ModuleDict
    • (一)ModuleList
    • (二)ModuleDict


一、保存及加载模型

通过torch.save可以将该模型的参数、优化器状态、batch normalization、dropout、buffer变量等信息。

import torch
import torchvision.models as models

(一)保存及加载模型的权重

模型取自torchvision.models里的vgg16,权重为IMAGENET1K_V1。

model.state_dict()是模型的权重。state_dict状态字典:一般包含当前model的参数及buffer变量

model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')

推理时可以实现模型的加载:

  • 创建模型实例
  • 将实现保存的模型信息通过torch.load导入进来
  • 采用load_state_dict函数将模型信息载入模型实例
  • model.eval()使得模型进入推理模式
model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

(二)保存及加载优化器的权重

保存优化器权重:
在这里插入图片描述

加载优化器权重:
在这里插入图片描述

(三)保存及加载整个模型

保存整个模型:

torch.save(model, 'model.pth')

加载整个模型:

model = torch.load('model.pth')

(四)保存及加载更具一般性的checkpoint

保存并加载用于推理或恢复训练的一般性checkpoint有助于从上次中断的地方重新开始。在保存一般检查点时,不仅仅是保存模型的state_dict,还包括保存优化器的state_dict、停止使用的时间,最近记录的训练损失,外部的torch.nn.Embedding层等等。

# Additional information
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4torch.save({'epoch': EPOCH,'model_state_dict': net.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': LOSS,}, PATH)

加载:

model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']model.eval()
# - or -
model.train()

(五)保存多个模型

保存多个模型时可以将其直接合并到一个大字典中保存。

# Specify a path to save to
PATH = "model.pt"torch.save({'modelA_state_dict': netA.state_dict(),'modelB_state_dict': netB.state_dict(),'optimizerA_state_dict': optimizerA.state_dict(),'optimizerB_state_dict': optimizerB.state_dict(),}, PATH)

二、nn.Sequential源码分析

nn.Sequential是有序的,当实例化nn.Sequential时,传入的模块顺序就是神经网络前向传播的顺序

在使用nn.Sequential时,可以按顺序传入模块,也可以输入一个字典。
在这里插入图片描述

(一)init函数

如果输入的是一个字典,init函数会采用遍历字典的方式,如果是一个一个的模块,init函数也会针对性的采取其他遍历方法。
在这里插入图片描述

(二)forward函数

对于一个模型的输入,nn.Sequential会依次的过其中的子模块。
在这里插入图片描述

nn.Sequential相比于ModuleList和ModuleDict来说,优势在于具有forward的功能。

三、ModuleList和ModuleDict

(一)ModuleList

pytorch允许我们把很多子模块放到一个列表中。ModuleList就是用于存放多个子模块的一个列表,在使用时可以对其进行遍历。ModuleList不单纯是一个列表,它本身就是一个module。
在这里插入图片描述

(二)ModuleDict

ModuleDict是用于存放多个子模块的一个字典,在使用时可以根据索引获得对应的子模块。ModuleDict不单纯是一个字典,它本身也是一个module。
在这里插入图片描述

除此之外,还有ParameterList、ParameterDict等,这些与ModuleList和ModuleDict的作用及使用方式类似。

参考:
8、深入剖析PyTorch的state_dict、parameters、modules源码
9、深入剖析PyTorch的nn.Sequential及ModuleList源码

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/796.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Rust 语言中的跨平台 GUI 库

在 Rust 社区中,Iced 是值得关注的跨平台 GUI (图形用户界面) 库之一。由 iced-rs 团队开发,Iced的设计灵感来源于 Elm 语言,它以简洁性和类型安全性为特色,旨在提供一个简单易用且功能丰富的 GUI 开发体验。本文将深入探讨 Iced&…

探究欧拉恒等式的美学与数学威力

正如老子所述,“道生一,一生二,二生三,三生万物”,数学作为人类认知自然法则的语言,其数系的不断发展象征着对世界理解的深化。从自然数经由分数、无理数至复数,复数虽看似反直觉,却…

MATLAB实现蚁群算法优化柔性车间调度(ACO-fjsp)

蚁群算法优化车间调度的步骤可以分为以下几个主要阶段: 1.初始化阶段: 设置算法参数,如信息素浓度、启发式因子等。这些参数将影响蚂蚁在选择路径时的决策过程。 确定车间调度的具体问题规模,包括工件数量、机器数量以及每个工件…

前端监控系统建设:错误收集、性能监控与用户体验优化

在前端开发过程中,建立一个监控系统是非常重要的,它可以帮助我们实时捕获错误、监控性能,并优化用户体验。下面是一些建设前端监控系统的关键部分。 错误收集:一个好的错误收集系统可以帮助我们迅速发现并修复代码中的错误。我们可…

AI:162-如何使用Python进行图像识别与处理深度学习与卷积神经网络的应用

本文收录于专栏:精通AI实战千例专栏合集 从基础到实践,深入学习。无论你是初学者还是经验丰富的老手,对于本专栏案例和项目实践都有参考学习意义。 每一个案例都附带关键代码,详细讲解供大家学习,希望可以帮到大家。正…

Beckhoff倍福工业电脑C6240-1037-0030主板维修CB1051-0003 CPU深圳捷达工控维修

Installation and Operating instructions for Control Cabinet PC C6240 from -0060 PS/2 连接 PS/2 上部 PS/2 连接器 (X104) 允许使用 PS/2 鼠标,而 PC 键盘可连接至下部 PS/2 连接器 (X103)。 USB接口 USB1 – USB4 四个 USB 接口 (X108 – X111) 用于通过 US…

OpenHarmony GIF图像渲染库—ohos-gif-drawable

简介 本项目是OpenHarmony系统的一款GIF图像渲染库,基于Canvas进行绘制,主要能力如下: 支持播放GIF图片。支持控制GIF播放/暂停。支持重置GIF播放动画。支持调节GIF播放速率。支持监听GIF所有帧显示完成后的回调。支持设置显示大小。支持7种不同的展示…

面试题:Redis如何防止缓存穿透 + 布隆过滤器原理

题目来源 招银网络-技术-1面 题目描述 缓存穿透是什么?如何防止缓存穿透布隆过滤器的原理是什么? 我的回答 缓存穿透是什么? 攻击者大量请求缓存和数据库中都不存在的key。如何防止缓存穿透 可以使用布隆过滤器布隆过滤器的原理是什么&a…

AI容器化部署开发尝试 (一)(Pycharm连接docker,并部署django测试)

注意:从 Docker 19.03 开始,Docker 引入了对 NVIDIA GPU 的原生支持,因此若AI要调用GPU算力的话docker版本也是有要求的,后面博客测试。 当然本篇博客还没设计到GPU的调用,主要Pycharm加Anaconda的方案用习惯了&#…

缓存的使用及常见问题的解决方案

用户通过浏览器向我们发送请求,这个时候浏览器就会建立一个缓存,主要缓存一些静态资源(js、css、图片),这样做可以降低之后访问的网络延迟。然后我们可以在Tomcat里面添加一些应用缓存,将一些从数据库查询到…

wsl中ollama不能使用gpu加速

之前还能有gpu加速的, 突然一次发现不能加速了, 启动之后发现只能用cpu了 log time2024-04-19T00:05:08.21308:00 levelINFO sourceimages.go:806 msg"total blobs: 80" time2024-04-19T00:05:08.24808:00 levelINFO sourceimages.go:813 msg"tota…

Flask:URL与视图的映射

默认端口号80、443 blog_id 限制数据类型的话(int) 除此之外别的数据类型也可以,或者多个(用any) /book/list?page6

骑砍2霸主MOD开发(5)-游戏事件

一.MissionBehavior Mission任务中发生的事件,AgentSpawn,AgentRemove,BeforeMissionStart等统称为MissionBehavior. 通过在Mission中添加属于自己的MissionBehavior实现对游戏任务事件的捕捉 <1.在MBSubModuleBase中重写OnBeforeMissionBehaviorInitialize(Mission mission…

【笔记】ASP.NET Core Web API之Token验证

在实际开发中经常需要对外提供接口以便客户获取数据&#xff0c;由于数据属于私密信息&#xff0c;并不能随意供其他人访问&#xff0c;所以就需要验证客户身份。那么如何才能验证客户的身份呢&#xff1f;一个简单的小例子&#xff0c;简述ASP.NET Core Web API开发过程中&…

Git学习笔记(三)Git分支

Git分支是Git中非常重要的一个概念&#xff0c;无论是个人开发还是多人协作中&#xff0c;分支都起着至关重要的作用。几乎所有的版本控制系统都以某种形式支持分支。 使用分支意味着你可以把你的工作从开发主线上分离 开来进行重大的Bug修改、开发新的功能&#xff0c;以免影响…

笨蛋学C++【C++基础第二弹】

C基础第二弹 2.C运算2.1运算符2.1.1算术运算符2.1.2关系运算符2.1.3逻辑运算符2.1.4位运算符2.1.5赋值运算符2.1.6杂项运算符2.1.7运算符优先级2.1.8注意 3.C循环3.1Cwhile循环3.1.1语法 3.2Cfor循环3.2.1基于范围的for循环方式13.2.2基于范围的for循环方式23.2.3基于范围的for…

Linux驱动开发笔记(零)驱动基础知识及准备

文章目录 前言一、Liunx、MCU和FPGA编程的区别二、Linux内核模块1. 什么是内核模块2. 内核模块的代码架构3. 头文件4. 模块参数5. makefile说明 三、 驱动程序设计思路1. 基本步骤2. 设备号3. 数据结构3.1 file_operations3.2 file3.3 inode3.4 哈希表3.5 cdev结构体3.6 kobj_m…

[Linux][进程信号][一][信号基础][如何产生信号]详细解读

目录 0.前言预备1.系统定义的信号列表2.核心转储 -- Core Dump 1.信号基础1.信号概念2.信号处理方式概览3.理解信号如何被保存4.信号发送的本质 2.如何产生信号&#xff1f;1.终端按键产生信号2.系统调用接口1.kill()2.raise()3.abort()4.如何理解&#xff1f; 3.由软件条件产生…

将闲置的windows硬盘通过smb共享的方式提供给mac作为时间机器备份

1.windows端&#xff0c;开启smb共享 自行解决 2.mac端 磁盘工具-文件-新建映像-空白映像 假设你的名字为&#xff1a;backup 大小&#xff1a;350GB&#xff08;自己修改&#xff09; 格式&#xff1a;MacOS扩展&#xff08;日志式&#xff09; 分区&#xff1a;单个分区-A…

C# 图像旋转一定角度后,对应坐标怎么计算?

原理分析 要计算图像内坐标在旋转一定角度后的新坐标&#xff0c;可以使用二维空间中的点旋转公式。假设图像的中心点&#xff08;即旋转中心&#xff09;为 (Cx, Cy)&#xff0c;通常对于正方形图像而言&#xff0c;中心点坐标为 (Width / 2, Height / 2)。给定原坐标点 (X, …