2024年1月8日学习总结

目录

  • 学习计划
  • 学习内容
    • how to save and load models in pytorch
      • (1)什么是state_dict
        • 定义一个模型
        • 实例化模型
        • 初始化优化器
        • 查看模型的state_dict
        • 查看优化器的state_dict
      • (2)保存模型
        • A、save/load state_dict(推荐)
        • B、save/load entire model
      • (3)使用检查点(checkpoint)来保存和加载模型
      • (4)在一个文件中保存多个模型

学习计划

  • 对代码中的保存模型进行更深入的了解
  • 对预测部分进行编写(使用从未见过的数据)
  • 对模型进行对比(一个基站的情况)

学习内容

how to save and load models in pytorch

保存和加载模型用到的核心函数:

  • torch.save:保存序列化对象到磁盘中。可以使用这个函数保存models, tensors, dictionaries of all kinds of objects
  • torch.load:数据加载
  • torch.nn.Module.load_state_dict:使用state_dict加载模型参数

(1)什么是state_dict

在pytorch中,一个torch.nn.Model模型的可学习参数(比如权重或者偏置)保存在model.parameters()中。state_dict是一个Python字典,它将每一层映射到对应的参数张量。
例子

定义一个模型
# Define model
class TheModelClass(nn.Module):def __init__(self):super(TheModelClass, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x
实例化模型
# Initialize model
model = TheModelClass()
初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
查看模型的state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():print(param_tensor, "\t", model.state_dict()[param_tensor].size())

结果:

Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias   torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias   torch.Size([16])
fc1.weight   torch.Size([120, 400])
fc1.bias     torch.Size([120])
fc2.weight   torch.Size([84, 120])
fc2.bias     torch.Size([84])
fc3.weight   torch.Size([10, 84])
fc3.bias     torch.Size([10])
查看优化器的state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():print(var_name, "\t", optimizer.state_dict()[var_name])

结果:

Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]

(2)保存模型

A、save/load state_dict(推荐)

save:

torch.save(model.state_dict(), 'save/to/path/model.pth')

load:

model = MyModelDefinition(args)
model.load_state_dict(torch.load('load/from/path/model.pth'))
model.eval()

在pytorch里,保存模型使用的文件后缀要么是“.pt”要么是“.pth”
在运行评估(inference)之前要使用model.eval()将dropout和批处理规范化层设置为评估模式。如果不这样做,将产生不一致的推理结果

Attention:
1️⃣load_state_dict函数接收的是字典,不能直接接收路径。
所以这样写就是错的:model.load_state_dict(PATH)
应该是:model.load_state_dict(torch.load(PATH))✔️
2️⃣如果想要保存best model,不能仅仅写成:best_model_state=model.state_dict(),因为这只是返回一个引用而不是复制下来了,需要写成:best_model_state=deepcopy(model.state_dict())。否则,最佳best_model_state将在随后的训练迭代中不断更新,进而导致最终的模型状态将是过拟合模型的状态。

B、save/load entire model

save:

torch.save(model, PATH)

load:

# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

(3)使用检查点(checkpoint)来保存和加载模型

通常ML pipeline需要定期或在满足条件时保存模型检查点。保存检查点是为了防止训练的时候因为一些莫名其妙的原因中断,这样就可以从最后或最佳检查点恢复训练。
但是仅仅保存模型的state_dict是不够的,还需要保存优化器的state_dict(因为这包含随着模型训练而更新的缓冲区和参数)以及最后的epoch number,loss, external torch.nn.Embedding layers等等。基本上需要存储所有的大小来使用检查点恢复训练。这样的检查点通常比单独的模型大2~3倍。
save:

torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss,...}, PATH)

load:

model = MyModelDefinition(args)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']model.eval()
# - or -
model.train()

(4)在一个文件中保存多个模型

当要保存多个复杂的模型时,和创建检查点相同,需要保存每个模型的state_dict和相对于的优化器在一个字典中。如前所述,您可以保存任何其他可能有助于您恢复训练的项目,只需将它们添加到字典中。
save:

torch.save({'modelA_state_dict': modelA.state_dict(),'modelB_state_dict': modelB.state_dict(),'optimizerA_state_dict': optimizerA.state_dict(),'optimizerB_state_dict': optimizerB.state_dict(),...}, PATH)

加载模型,首先初始化模型和优化器,然后使用torch.load()在本地加载字典。
load:

modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/606974.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

高压放大器设计要求是什么

高压放大器在科学研究和工程应用中扮演着至关重要的角色,特别是在需要处理高电压信号的实验和应用中。高压放大器设计要求的充分考虑至关重要,以确保其在各种环境中稳定、可靠地工作。下面将介绍设计高压放大器时需要考虑的关键要求和因素。 1.电压范围 …

赋能软件开发:生成式AI在优化编程工作流中的应用与前景

随着人工智能(AI)技术的快速发展,特别是生成式AI模型如GPT-3/4的出现,软件开发行业正经历一场变革,这些模型通过提供代码生成、自动化测试和错误检测等功能,极大地提高了开发效率和软件质量。 本文旨在深入…

AnnexB封装格式介绍(主要用于H.264和H.265视频编码标准,是一种常见的视频流NALU封装格式,常用于RTSP、RTP传输)

参考文章:解码中的AnnexB和avcC两种分割数据方式 文章目录 AnnexB 格式介绍1. NALU单元与开始代码1.1 NALU单元1.2 开始代码 2. AnnexB格式详述2.1 基本结构2.2 长度前缀 3. 从AnnexB格式到AVCC格式4. AnnexB格式的优缺点4.1 优点4.2 缺点 5. 疑难问题解析如何确定开…

Android readelf 工具查找函数符号

ELF(Executable and Linkable Format)是一种执行文件和可链接文件的格式。它是一种通用的二进制文件格式,用于在各种操作系统中存储可执行程序、共享库和内核模块。 Android 开发当中的 so 库本质上就是一种特殊类型的 ELF 文件,…

文章解读与仿真程序复现思路——电工技术学报EI\CSCD\北大核心《考虑灵活性补偿的高比例风电与多元灵活性资源博弈优化调度》

本专栏栏目提供文章与程序复现思路,具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 这个标题涉及到高比例风电与多元灵活性资源在博弈优化调度中考虑灵活性补偿的问题。以下是对标题各个部分的解读: 高比例风电: …

【模拟IC学习笔记】Cascode OTA 设计

辅助定理 增益Gm*输出阻抗 输出短路求Gm 输入置0求输出阻抗 求源极负反馈的增益 随着Vin的增加,Id也在增加,Rs上压降增加,所以,Vin的一部分电压体现在Rs上,而不是全部作为Vgs,因此导致Id变得平滑。 Rs足…

【数据结构篇】数据结构中的 R 树和 B 树

数据结构中的 R 树和 B 树 ✔️关于R树(RTree)✔️什么是B树(B-tree)✔️B树和B树的区别✔️B树和B树在数据存储方面的具体差异 ✔️拓展知识仓✔️R树和B树的区别✔️ 那在内存消耗上有什么区别?✔️ R树有哪些优点和…

【算法与数据结构】509、LeetCode斐波那契数

文章目录 一、题目二、递归,动态规划解法2.1 递归解法2.2 动态规划解法 三、完整代码 所有的LeetCode题解索引,可以看这篇文章——【算法和数据结构】LeetCode题解。 一、题目 二、递归,动态规划解法 2.1 递归解法 思路分析:斐波…

go 语言中的 iota

我们经常会在我们的代码中定义类似以下这些常量: const (ColorRed "Red"ColorGreen "Green"ColorBlue "Blue" )在其他时候,我们仅仅关注能把一个东西与其他的做区分。 有些时候,有些时候一件事没有本质…

【Leetcode】240. 搜索二维矩阵 II

一、题目 1、题目描述 编写一个高效的算法来搜索 m x n 矩阵 matrix 中的一个目标值 target 。该矩阵具有以下特性: 每行的元素从左到右升序排列。每列的元素从上到下升序排列。示例1: 输入:matrix = [[1,4,7,11,15],[2,5,8,12,19],[3,6,9,16,22],[10,13,14,17,24],[18,21…

Pytorch:torch.nn.Module.apply用法详解

torch.nn.Module.apply 是 PyTorch 中用于递归地应用函数到模型的所有子模块的方法。它允许对模型中的每个子模块进行操作,比如初始化权重、改变参数类型等。 以下是关于 torch.nn.Module.apply 的示例: 1. 语法 Module.apply(fn)Module:P…

【REST2SQL】05 GO 操作 达梦 数据库

【REST2SQL】01RDB关系型数据库REST初设计 【REST2SQL】02 GO连接Oracle数据库 【REST2SQL】03 GO读取JSON文件 【REST2SQL】04 REST2SQL第一版Oracle版实现 信创要求用国产数据库,刚好有项目用的达梦,研究一下go如何操作达梦数据库 1 准备工作 1.1 安…

ros2 基础学习 15- URDF:机器人建模方法

URDF:机器人建模方法 ROS是机器人操作系统,当然要给机器人使用啦,不过在使用之前,还得让ROS认识下我们使用的机器人,如何把一个机器人介绍给ROS呢? 为此,ROS专门提供了一种机器人建模方法——…

2024华为OD机试:最多几个直角三角形

题目描述 有N条线段&#xff0c;长度分别为a[1]-a[n]。 现要求你计算这N条线段最多可以组合成几个直角三角形。每条线段只能使用一次&#xff0c;每个三角形包含三条线段。 输入描述 第一行输入一个正整数T(1<T<100),表示有T组测试数据.对于每组测试数据&#xff0c;…

软件测试|SQL中的UNION和UNION ALL详解

简介 在SQL&#xff08;结构化查询语言&#xff09;中&#xff0c;UNION和UNION ALL是用于合并查询结果集的两个关键字。它们在数据库查询中非常常用&#xff0c;但它们之间有一些重要的区别。在本文中&#xff0c;我们将深入探讨UNION和UNION ALL的含义、用法以及它们之间的区…

Ubuntu 22.04 编译安装 Qt mysql驱动

参考自 Ubuntu20.04.3 QT5.15.2 MySQL驱动编译 Ubuntu 18.04 编译安装 Qt mysql驱动 下边这篇博客不是主要参考的, 但是似乎解决了我的难题(找不到 libmysqlclient.so) ubuntu18.04.2 LTS 系统关于Qt5.12.3 无法加载mysql驱动&#xff0c;需要重新编译MYSQL数据库驱动的问题以…

【代码随想录】刷题笔记Day45

前言 早上又赖了会床......早睡早起是奢望了现在&#xff0c;新一年不能这样&#xff01;支棱起来&#xff01; 377. 组合总和 Ⅳ - 力扣&#xff08;LeetCode&#xff09; 这一题用的就是完全背包排列数的遍历顺序&#xff1a;先背包再物品&#xff0c;从前往后求的也是有几…

IO类day01

File类 File类的每一个实例可以表示硬盘(文件系统)中的一个文件或目录(实际上表示的是一个抽象路径) 使用File可以做到: 1:访问其表示的文件或目录的属性信息,例如:名字,大小,修改时间等等 2:创建和删除文件或目录 3:访问一个目录中的子项 但是File不能访问文件数据. pu…

mac电脑php命令如何设置默认的php版本

前提条件&#xff1a;如果mac电脑还没安装多个PHP版本&#xff0c;可以先看这篇安装一下 mac电脑运行多个php版本_mac 同时运行两个php-CSDN博客 第一部分&#xff1a;简单总结 #先解除现在默认的php版本 brew unlink php7.4#再设置的想要设置的php版本 brew link php8.1第二部…

AWS Simple Email Service (SES) 实战指南

Amazon Simple Email Service (SES) 是一项强大的电子邮件发送服务&#xff0c;适用于数字营销、应用程序通知以及事务性邮件。在这个实战指南中&#xff0c;我们将演示如何设置 AWS SES 并通过几个示例展示其用法。 设置 AWS SES 1. 创建 AWS 账户 首先&#xff0c;您需要创…