昇思25天学习打卡营第7天之二 | 模型保存与加载

1. 保存与加载

在训练网络模型的过程中,实际上我们希望保存中间和最后的结果,用于微调(fine-tune)和后续的模型推理与部署,本章节我们将介绍如何保存与加载模型。

1.1 导入依赖

# 导入numpy库,并将其重命名为np,以便在代码中引用
import numpy as np# 导入MindSpore库,这是华为推出的一个开源深度学习框架,用于构建和训练神经网络
import mindspore# 从MindSpore库中导入nn模块,这个模块包含了构建神经网络所需的各种层和函数
from mindspore import nn# 从MindSpore库中导入Tensor模块,Tensor是MindSpore中用于表示张量的类
from mindspore import Tensor

1.1定义神经网络模型

# 定义一个函数,该函数创建一个简单的全连接神经网络模型
def network():# 使用nn.SequentialCell创建一个层序列,这是一个容器类,可以包含多个层model = nn.SequentialCell(# 第一个层是一个Flatten层,用于将输入的二维图像数据展平为一维向量nn.Flatten(),# 第二个层是一个全连接层,将28x28的输入节点映射到512个节点nn.Dense(28*28, 512),# 第三个层是一个ReLU激活函数,用于非线性变换nn.ReLU(),# 第四个层是一个全连接层,将512个节点映射到512个节点nn.Dense(512, 512),# 第五个层是一个ReLU激活函数,用于非线性变换nn.ReLU(),# 第六个层是一个全连接层,将512个节点映射到10个节点,对应于10个类别的输出nn.Dense(512, 10))# 返回创建好的模型return model

1.2 保存和加载模型权重

1.2.1 保存模型

保存模型使用save_checkpoint接口,传入网络和指定的保存路径:

# 创建一个神经网络模型实例
model = network()# 使用MindSpore的save_checkpoint函数将模型的检查点保存到文件
# 第一个参数是模型对象
# 第二个参数是文件名,这里保存为"model.ckpt"
mindspore.save_checkpoint(model, "model.ckpt")
# 打印模型结构
print(model)

输出:

SequentialCell<(0): Flatten<>(1): Dense<input_channels=784, output_channels=512, has_bias=True>(2): ReLU<>(3): Dense<input_channels=512, output_channels=512, has_bias=True>(4): ReLU<>(5): Dense<input_channels=512, output_channels=10, has_bias=True>>

模型大小估算:
model_capacity ≈ 模型参数 * 数据精度(默认是int32类型)大小 = [(784512+512) + (512512+512) + (512*10 +10)] *32bit/8(bit/Byte)= 669704 *4 = 2678824 Byte
可以看到,模型参数量约为67W,占用空间大小应约为2678824字节
实际该模型文件大小为2679017。可以说非常接近了,剩下的字节应该就是文件类型描述符加模型结构描述符之类的内容了。
所以当我们已知一个模型的参数量和参数精度后,实际就可以估算出模型占用的磁盘空间大小了。

1.2.2 加载模型

要加载模型权重,需要先创建相同模型的实例,然后使用load_checkpointload_param_into_net方法加载参数。

# 创建一个神经网络模型
model = network()# 使用MindSpore的load_checkpoint函数从文件中加载模型的参数和优化器状态
# 参数是检查点的文件名,这里加载的文件名为"model.ckpt"
param_dict = mindspore.load_checkpoint("model.ckpt")# 使用MindSpore的load_param_into_net函数将加载的参数字典加载到模型中
# 第一个参数是模型对象
# 第二个参数是参数字典
# 返回值是一个元组,第一个元素是未加载的参数列表,第二个元素是加载的参数列表
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)# 打印未加载的参数列表,如果加载成功,这个列表应该是空的
print(param_not_load)

输出:

[]

param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。

1.3 保存和加载MindIR

除Checkpoint外,MindSpore提供了云侧(训练)和端侧(推理)统一的中间表示(Intermediate Representation,IR)。可使用export接口直接将模型保存为MindIR。

# 创建网络模型
model = network()
# 创建一个Tensor对象,它包含一个大小为[1, 1, 28, 28]的矩阵,所有元素都是1,数据类型为float32
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
# 使用mindspore.export函数将模型导出为MINDIR格式
# 第一个参数是模型对象
# 第二个参数是输入数据,这里使用了一个Tensor对象作为示例
# 第三个参数是文件名,这里导出的文件名为"model"
# 第四个参数是文件格式,这里设置为"MINDIR",表示导出的模型格式
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")

MindIR同时保存了Checkpoint和模型结构,因此需要定义输入Tensor来获取输入shape。

已有的MindIR模型可以方便地通过load接口加载,传入nn.GraphCell即可进行推理。

nn.GraphCell仅支持图模式。

# 设置MindSpore的执行模式为GRAPH_MODE
mindspore.set_context(mode=mindspore.GRAPH_MODE)
# 加载之前导出的MINDIR模型
graph = mindspore.load("model.mindir")
# 创建一个GraphCell对象,它将graph作为其成员
model = nn.GraphCell(graph)
# 使用模型对输入数据进行前向计算,得到输出
outputs = model(inputs)
# 打印输出的形状
print(outputs.shape)

输出:
模型加载f

2. 小结

本文主要介绍了模型的保存和加载,都包括检查点checkpoint和统一中间表示MindIR(Intermediate Representation)两种方法,还介绍了模型大小的估算方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/37162.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

六月,允许自己做自己,别人做别人

今天结束后&#xff0c;2024 就过去一半了。 年初的规划完成一半了吗&#xff1f;如果没有也没关系&#xff0c;做你自己继续前进。 家人来北京旅游&#xff0c;我累趴了 六月初&#xff0c;我搬家了&#xff0c;这次租了一整套房&#xff0c;是一个小俩居、还带一个小阁楼。…

数学学习与研究杂志社《数学学习与研究》杂志社2024年第6期目录

课改前沿 基于核心素养的高中数学课堂教学研究——以“直线与圆、圆与圆的位置关系”为例 张亚红; 2-4 核心素养视角下初中生数学阅读能力的培养策略探究 贾象虎; 5-7 初中数学大单元教学实践策略探索 耿忠义; 8-10《数学学习与研究》投稿&#xff1a;cn7kantougao…

使用Python绘制极坐标图

使用Python绘制极坐标图 极坐标图极坐标图的优点使用场景 效果代码 极坐标图 极坐标图&#xff08;Polar Chart&#xff09;是一种图表类型&#xff0c;用于显示在极坐标系中的数据。极坐标图使用圆形坐标系&#xff0c;角度表示一个变量的值&#xff0c;半径表示另一个变量的…

线程安全问题(二)——死锁

死锁 前言可重入锁逻辑 两个线程两把锁&#xff08;死锁&#xff09;死锁的特点多个线程多把锁&#xff08;哲学家就餐问题&#xff09;总结 前言 在前面的文章中&#xff0c;介绍了锁的基本使用方式——锁 在上一篇文章中&#xff0c;通过synchronized关键字进行加锁操作&am…

XML简介XML 使用教程XML的基本结构XML的使用场景

学习总结 1、掌握 JAVA入门到进阶知识(持续写作中……&#xff09; 2、学会Oracle数据库入门到入土用法(创作中……&#xff09; 3、手把手教你开发炫酷的vbs脚本制作(完善中……&#xff09; 4、牛逼哄哄的 IDEA编程利器技巧(编写中……&#xff09; 5、面经吐血整理的 面试技…

VMware每次打开网络设置都出现需要运行NetworkManager问题

每次打开都出现这个情况&#xff0c;是因为之前把NetworkManager服务服务关闭&#xff0c;重新输入命令&#xff1a; sudo systemctl start NetworkManager.service或者 sudo service network-manager restart 即可解决&#xff0c;但是每次开机重启都要打开就很麻烦&#xf…

【Chapter4】汇编语言及其程序设计,《微机系统》第一版,赵宏伟

一、汇编语言概述 **指令&#xff1a;**指使计算机完成某种操作的命令。 **程序&#xff1a;**完成某种功能的指令序列。 **软件&#xff1a;**各种程序总称。 **机器语言&#xff1a;**计算机能直接识别的语言。用机器语言写出的程序称为机器代码。 **汇编语言&#xff1…

Forecasting from LiDAR via Future Object Detection

Forecasting from LiDAR via Future Object Detection 基础信息 论文&#xff1a;cvpr2022paper https://openaccess.thecvf.com/content/CVPR2022/papers/Peri_Forecasting_From_LiDAR_via_Future_Object_Detection_CVPR_2022_paper.pdfgithub&#xff1a;https://github.co…

SyncUnsafeCell替换Mutex提高性能

1. 背景 在Rust开发过程中&#xff0c;很多情况下需要在不可变的情况下获取可变性或者在多线程的情况下可以安全的贡献可变数据。这种情况下我们一般使用**Mutex来实现通过加锁来实现。现在我们可以通过使用SyncUnsafeCell来替代Mutex**。 2. SyncUnsafeCell SyncUnsafeCell…

K8S之网络深度剖析(一)(持续更新ing)

K8S之网络深度剖析 一 、关于K8S的网络模型 在K8s的世界上,IP是以Pod为单位进行分配的。一个Pod内部的所有容器共享一个网络堆栈(相当于一个网络命名空间,它们的IP地址、网络设备、配置等都是共享的)。按照这个网络原则抽象出来的为每个Pod都设置一个IP地址的模型也被称作为I…

SpringBoot(一)创建一个简单的SpringBoot工程

Spring框架常用注解简单介绍 SpringMVC常用注解简单介绍 SpringBoot&#xff08;一&#xff09;创建一个简单的SpringBoot工程 SpringBoot&#xff08;二&#xff09;SpringBoot多环境配置 SpringBoot&#xff08;三&#xff09;SpringBoot整合MyBatis SpringBoot&#xff08;四…

3.ROS串口实例

#include <iostream> #include <ros/ros.h> #include <serial/serial.h> #include<geometry_msgs/Twist.h> using namespace std;//运行打开速度控制插件&#xff1a; rosrun rqt_robot_steering rqt_robot_steering //若串口访问权限不够&#xff1a…

详解PEFT库中LoRA源码

前言 GitHub项目地址Some-Paper-CN。本项目是译者在学习长时间序列预测、CV、NLP和机器学习过程中精读的一些论文&#xff0c;并对其进行了中文翻译。还有部分最佳示例教程。如果有帮助到大家&#xff0c;请帮忙点亮Star&#xff0c;也是对译者莫大的鼓励&#xff0c;谢谢啦~本…

读书笔记-《Spring技术内幕》(三)MVC与Web环境

前面我们学习了 Spring 最核心的 IoC 与 AOP 模块&#xff08;读书笔记-《Spring技术内幕》&#xff08;一&#xff09;IoC容器的实现、读书笔记-《Spring技术内幕》&#xff08;二&#xff09;AOP的实现&#xff09;&#xff0c;接下来继续学习 MVC&#xff0c;其同样也是经典…

Spring底层原理之bean的加载方式八 BeanDefinitionRegistryPostProcessor注解

BeanDefinitionRegistryPostProcessor注解 这种方式和第七种比较像 要实现两个方法 第一个方法是实现工厂 第二个方法叫后处理bean注册 package com.bigdata1421.bean;import org.springframework.beans.BeansException; import org.springframework.beans.factory.config.…

解决idea中git无法管理项目中所有需要管理的文件

点击文件->设置 选择版本控制—>目录映射 点击加号 设置整个项目被Git管理

MySQL高级-事务-并发事务演示及隔离级别

文章目录 0、四种隔离级别1、创建表 account2、修改当前会话隔离级别为 read uncommitted2.1、会出现脏读 3、修改当前会话隔离级别为 read committed3.1、可以解决脏读3.2、会出现不可重复读 4、修改当前会话隔离级别为 repeatable read&#xff08;默认&#xff09;4.1、解决…

【论文阅读】transformer及其变体

写在前面&#xff1a; transformer模型已经是老生常谈的一个东西&#xff0c;以transformer为基础出现了很多变体和文章&#xff0c;Informer、autoformer、itransformer等等都是顶刊顶会。一提到transformer自然就是注意力机制&#xff0c;变体更是数不胜数&#xff0c;一提到…

【目标检测】DN-DETR

一、引言 论文&#xff1a; DN-DETR: Accelerate DETR Training by Introducing Query DeNoising 作者&#xff1a; IDEA 代码&#xff1a; DN-DETR 注意&#xff1a; 该算法是在DAB-DETR基础上的改进&#xff0c;在学习该算法前&#xff0c;建议掌握DETR、DAB-DETR等相关知识…

VMamba: Visual State Space Model论文笔记

文章目录 VMamba: Visual State Space Model摘要引言相关工作Preliminaries方法网络结构2D-Selective-Scan for Vision Data(SS2D) VMamba: Visual State Space Model 论文地址: https://arxiv.org/abs/2401.10166 代码地址: https://github.com/MzeroMiko/VMamba 摘要 卷积神…