机器学习 - save和load训练好的模型

如果已经训练好了一个模型,你就可以save和load这模型。

For saving and loading models in PyTorch, there are three main methods you should be aware of.

PyTorch methodWhat does it do?
torch.saveSaves a serialized object to disk using Python’s pickle utility. Models, tensors and various other Python objects like dictionaries can be saved using torch.save
torch.loadUses pickle’s unpickling features to deserialize and load pickled Python object files (like models, tensors or dictionaries) into memory. You can also set which device to load the object to (CPU, GPU etc)
torch.nn.Module.load_state_dictLoads a model’s parameter dictionary (model.state_dict()) using a saved state_dict() object

在 PyTorch 中,pickle 是一个用于序列化和反序列化Python对象的标准库模块。它可以将Python对象转换为字节流 (即序列化),并将字节流转换回Python对象 (即反序列化)。pickle模块在很多情况下都非常有用,特别是在保存和加载模型,保存训练中间状态等方面。
在深度学习中,经常需要保存训练好的模型或者训练过程中的中间结果,以便后续的使用或分析。PyTorch提高了方便的API来保存和加载模型,其中就包括了使用pickle模块进行对象的序列化和反序列化。


save model

import torch
from pathlib import Path # 1. Create models directory
MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents = True, exist_ok = True)# 2. Create model save path
MODEL_NAME = "trained_model.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME# 3. Save the model state dict 
print(f"Saving model to: {MODEL_SAVE_PATH}")
torch.save(obj = model_0.state_dict(),f = MODEL_SAVE_PATH)

就能看到 trained_model.pth 文件下载到所属的文件夹位置。


Load the saved PyTorch model
You can load it in using torch.nn.Module.load_state_dict(torch.load(f)) where f is the filepath of the saved model state_dict().

Why call torch.load() inside torch.nn.Module.load_state_dict()?
Because you only saved the model’s state_dict() which is a dictionary of learned parameters and not the entire model, you first have to load the state_dict() with torch.load() and then pass that state_dict() to a new instance of the model (which is a subclass of nn.Module).

# Instantiate a new instance of the model 
loaded_model_0 = LinearRegressionModel()# Load the state_dict of the saved model
loaded_model_0.load_state_dict(torch.load(f=MODEL_SAVE_PATH))# 结果如下
<All keys matched successfully>

测试 loaded model。

# Put the loaded model into evaluation model 
loaded_model_0.eval() # 2. Use the inference mode context manager to make predictions
with torch.inference_mode():loaded_model_preds = loaded_model_0(X_test)# Compare previous model predictions with loaded model predictions
print(y_preds == loaded_model_preds) # 结果如下
tensor([[True],[True],[True],[True],[True],[True],[True],[True],[True],[True]])

看到这了,点个赞呗~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/772325.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI大模型学习的伦理与社会影响

AI大模型学习 随着人工智能技术的快速发展&#xff0c;AI大模型学习成为当前热门研究领域之一。AI大模型学习是指基于大规模数据集和深度学习模型进行训练&#xff0c;以实现更高的准确性和复杂性。这些大模型已经在几乎所有领域都取得了显著的成就&#xff0c;包括自然语言处…

通讯录管理系统实现(C++版本)

1.菜单栏的设置 &#xff08;1&#xff09;我么自定义了一个showmenu函数&#xff0c;用来打印输出我们的菜单栏&#xff1b; &#xff08;2&#xff09;菜单栏里面设置一些我们的通讯录里面需要用到的功能&#xff0c;例如增加联系人&#xff0c;删除联系人等等 2.退出功能…

ocrclass.h:117:18: error: field ‘end_time‘ has incomplete type ‘timeval‘

Alpine Linux v3.5上安装 tesseract-4.1.1 报错&#xff1a; 缺少timeval函数 ocrclass.h:117:18: error: field end_time has incomplete type timeval Current Behavior: In file included from control.cpp:37:0: ../../src/ccutil/ocrclass.h:117:18: error: field end…

javaWeb私人牙科诊所管理系统

一、摘要 随着科技的飞速发展&#xff0c;计算机已经广泛的应用于各个领域之中。在医学领域中&#xff0c;计算机主要应用于两个方面&#xff1a;一是医疗设备智能化&#xff0c;以硬件为主。另一种是病例信息管理系统&#xff08;HIS&#xff09;以软件建设为主&#xff0c;以…

1978-2022年全国31省社会消费品零售总额数据

1978-2022年全国31省社会消费品零售总额数据 1、时间&#xff1a;1978-2022年 2、指标&#xff1a;社会消费品零售总额 3、范围&#xff1a;31省市 4、来源&#xff1a;整理自国家统计J和各省年鉴 5、缺失情况说明&#xff1a;1997-2022年31省市均无缺失&#xff0c; 199…

GB 16807-2009 防火膨胀密封件检测

防火膨胀密封件是指火灾时遇火或高温作用能够膨胀&#xff0c;且能辅助建筑构配件使之具有隔火、隔烟、隔热等防火密封性能的产品。 GB 16807-2009防火膨胀密封件检测项目&#xff1a; 测试项目 测试方法 外观 GB 16807 尺寸允许偏差 GB 16807 膨胀性能 GB 16807 产烟…

随机链表的深拷贝

目录 一、何为深拷贝&#xff1f; 二、题目 三、思路 1.拷贝节点插入到原节点后面 2.控制拷贝节点的random 3.脱离原链表 : 尾插的思想 四、代码 五、附加 一、何为深拷贝&#xff1f; 一个引用对象一般来说由两个部分组成&#xff1a;一个具名的Handle&#xff0c;也就…

spring boot3 解决跨域几种方式

在Spring Boot 3中&#xff0c;解决跨域请求&#xff08;CORS&#xff0c;Cross-Origin Resource Sharing&#xff09;的问题主要有以下几种方式&#xff1a; 1. 使用CrossOrigin注解 你可以直接在Controller类或者具体的请求处理方法上使用CrossOrigin注解来允许跨域请求。 …

Java面试题:请解释Java中的输入输出(I/O)流?详细说明应用场景

Java中的输入输出&#xff08;I/O&#xff09;流是用于读取和写入数据的机制。在Java中&#xff0c;I/O流被设计为按照流的方向和数据源/目标类型进行分类。流的方向分为输入流和输出流&#xff0c;而数据源/目标类型则分为字节流和字符流。 流的方向&#xff1a; 输入流&…

面试官问我 ,try catch 应该在 for 循环里面还是外面?

首先 &#xff0c; 话说在前头&#xff0c; 没有什么 在里面 好 和在外面好 或者 不好的 一说。 本篇文章内容&#xff1a; 使用场景 性能分析 个人看法 1. 使用场景 为什么要把 使用场景 摆在第一个 &#xff1f; 因为本身try catch 放在 for循环 外面 和里面 &#…

图片标注编辑平台搭建系列教程(2)——fabric.js简介

文章目录 综述数据管理图形渲染图形编辑事件监听预告 综述 fabric提供了二维图形编辑需要的所有基础能力&#xff0c;包括&#xff1a;数据管理、图形渲染、图形编辑和事件监听。其中&#xff0c;图形编辑可以通过事件监听和图形渲染来实现&#xff0c;所以可以弃用。数据管理…

2024年NOC大赛软件创意编程(python初中组初赛)真题

题型和分值&#xff1a;单选题(20题,40分)、判断题(5题,10分)、多选题(5题,20分)、填空题(10题,30分) 一、单选题&#xff08;每题2分&#xff0c;共20题&#xff0c;满分40分&#xff09; 1、下面的程序&#xff0c;会无限循环下去的是&#xff08; &#xff09; A&#x…

【数据结构】双向奔赴的爱恋 --- 双向链表

关注小庄 顿顿解馋๑ᵒᯅᵒ๑ 引言&#xff1a;上回我们讲解了单链表(单向不循环不带头链表)&#xff0c;我们可以发现他是存在一定缺陷的&#xff0c;比如尾删的时候需要遍历一遍链表&#xff0c;这会大大降低我们的性能&#xff0c;再比如对于链表中的一个结点我们是无法直接…

OJ :1092 :素数表(函数专题)

题目描述 输入两个正整数m和n&#xff0c;输出m和n之间的所有素数。 要求程序定义一个prime()函数和一个main()函数&#xff0c;prime()函数判断一个整数n是否是素数&#xff0c;其余功能在main()函数中实现。 int prime(int n) { //判断n是否为素数&#xff0c; 若n为素数…

DNS协议 是什么?说说DNS 完整的查询过程?

一、是什么 DNS&#xff08;Domain Names System&#xff09;&#xff0c;域名系统&#xff0c;是互联网一项服务&#xff0c;是进行域名和与之相对应的 IP 地址进行转换的服务器 简单来讲&#xff0c;DNS相当于一个翻译官&#xff0c;负责将域名翻译成ip地址 IP 地址&#…

linux - rm命令删除文件到垃圾箱

修改原来的rm指令到垃圾箱&#xff0c;对于误操作的删除可以直接从垃圾箱里拉回来&#xff0c;同时提高网络安全意识。 创建remove.sh 脚本 PARA_CNT$# TRASH_DIR"/home/pass/.trash" # 指定垃圾箱目录 for i in $*; doSTAMPdate %Y%m%d # 删除时间fileName…

js实现拖放效果

dataTransfer对象 说明&#xff1a;dataTransfer对象用于从被拖动元素向放置目标传递字符串数据。因为这个对象是 event 的属性&#xff0c;所以在拖放事件的事件处理程序外部无法访问 dataTransfer。在事件处理程序内部&#xff0c;可以使用这个对象的属性和方法实现拖放功能…

【鸿蒙HarmonyOS开发笔记】使用@Preview装饰器预览组件

概述 ArkTS应用/服务支持组件预览&#xff0c;要求compileSdkVersion为8或以上。组件预览支持实时预览&#xff0c;不支持动态图和动态预览。组件预览通过在组件前添加注解Preview实现&#xff0c;在单个源文件中&#xff0c;最多可以使用10个Preview装饰自定义组件。 Preview…

算法---矩阵的乘法及其运用

相信我们都做过一个题叫斐波那契数列&#xff0c;对于一般的题&#xff0c;n的取值范围通常在1000以内&#xff0c;但是如果你遇到的是下面这题呢&#xff1f; 斐波那契数列 - 洛谷 发现了吗&#xff1f;我的n取值范围连long long都会爆出&#xff0c;所以下面我们通过矩阵乘法…

张驰咨询:光伏产业新质生产力提升咨询方案

光伏产业新质生产力提升咨询方案 一、光伏行业目前发展现状及特点 1、高度竞争 2、技术驱动 3、绿色发展 二、光伏发展新质生产力面临的痛点 1、成本压缩与效率提升并存挑战 2、新技术应用与推广难度 3、国际贸易摩擦影响 4、市场需求波动大 5、政策与补贴依赖性 三、…