关于如何得到Mindspore lite所需要的.ms模型

关于如何得到Mindspore lite所需要的.ms模型

  • 一、.ckpt模型文件转.mindir模型
  • 二、.mindir模型转.ms模型
  • 三、其它
    • 3.1 代码
    • 3.2 数据
  • 四、参考文档

一、.ckpt模型文件转.mindir模型

由于要得到ms模型,mindspore的所有模型里面,是必须要用mindir模型才可以进行转换的,因此我们是必须先拿到mindir模型~
在这里插入图片描述

此过程并不复杂,需要注意的是,要在昇腾910的npu环境下训练得到的ckpt模型文件才可以转换,其它如cpu、gpu下得到的模型均不可以,所以可以用启智AI平台来,按照昇思官方给的示例就可以转成。

这里可以用启智AI平台,有免费的npu提供,速度也很快!

在这里插入图片描述

input_np为训练/推理过程输入网络的数据(其中一个),可以先打印出来确定其内容和类型,我这个案例里面用的是(10,1),即一个二维数字,10列1行,这也是为什么数据是这个样子的原因;
在这里插入图片描述
其它调用模型、网络都用自己搭建的,简单调一下就可以一下子转成了;

二、.mindir模型转.ms模型

需要用官方所提供的转换工具,下载版本最好和mindspore版本对应,下载后设置环境变量时候,最好是用管理员模式powershell设置,设置指令如下
$env:PATH = "C:\Users\l\Desktop\ls\mindspore-lite-2.2.0-win-x64\tools\converter\lib;" + $env:PATH
路径需要替换为自己的mindspore lite地址,后面按照转换示例走一下就可以转换得到,主要容易出错的是环境变量的设置
在这里插入图片描述

三、其它

3.1 代码

此部分为模型训练和保存代码,注意模型训练所用的data数据列为’CRIM’, ‘ZN’, ‘INDUS’, “CHAS”, ‘NOX’, ‘RM’, ‘AGE’, ‘DIS’, ‘RAD’, 'LSTAT’几列,并不是13列全用


import numpy as np
import mindspore as ms
from mindspore import ops, nn
import mindspore.dataset as ds
import mindspore.common.initializer as init
import pandas as pd
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")def get_data(data_num, data_size, trian=True):df = pd.read_csv("boston.csv")df = df.dropna(axis=0)df.head()#     feature=df[['CRIM','RM','LSTAT']]feature = df[['CRIM', 'ZN', 'INDUS', "CHAS", 'NOX', 'RM', 'AGE', 'DIS', 'RAD', 'LSTAT']]feature.describe()target = df['MEDV']split_num = int(len(feature) * 0.7)if trian == True:for i in range(split_num):y = [target.iloc[i]]x = [feature.loc[i, field] for field in feature.columns]yield np.array(x[:]).astype(np.float32), np.array([y[0]]).astype(np.float32)else:for i in range(split_num, len(feature)):y = [target.iloc[i]]x = [feature.iloc[i].values for field in feature.columns]x = x[0]yield np.array(x[:]).astype(np.float32), np.array([y[0]]).astype(np.float32)def create_dataset(data_num, data_size, batch_size=1, repeat_size=1, train=True):"""定义数据集"""input_data = ds.GeneratorDataset(list(get_data(data_num, data_size, train)), column_names=['data', 'label'])input_data = input_data.batch(batch_size)input_data = input_data.repeat(repeat_size)return input_dataclass MyNet(nn.Cell):"""定义网络"""def __init__(self, input_size=32):super(MyNet, self).__init__()self.fc1 = nn.Dense(10, 1, weight_init=init.Normal(0.02))self.relu = nn.ReLU()def construct(self, x):x = self.relu(self.fc1(x))return xclass MyL1Loss(nn.LossBase):"""定义损失"""def __init__(self, reduction="mean"):super(MyL1Loss, self).__init__(reduction)self.abs = ops.Abs()def construct(self, base, target):x = self.abs(base - target)return self.get_loss(x)class MyMomentum(nn.Optimizer):"""使用ApplyMomentum算子定义优化器"""def __init__(self, params, learning_rate, momentum=0.9, use_nesterov=False):super(MyMomentum, self).__init__(learning_rate, params)self.moments = self.parameters.clone(prefix="moments", init="zeros")self.momentum = momentumself.opt = ops.ApplyMomentum(use_nesterov=use_nesterov)def construct(self, gradients):params = self.parameterssuccess = Nonefor param, mom, grad in zip(params, self.moments, gradients):success = self.opt(param, mom, self.learning_rate, grad, self.momentum)return successclass MyWithLossCell(nn.Cell):"""定义损失网络"""def __init__(self, backbone, loss_fn):super(MyWithLossCell, self).__init__(auto_prefix=False)self.backbone = backboneself.loss_fn = loss_fndef construct(self, data, label):out = self.backbone(data)return self.loss_fn(out, label)def backbone_network(self):return self.backboneclass MyTrainStep(nn.TrainOneStepCell):"""定义训练流程"""def __init__(self, network, optimizer):"""参数初始化"""super(MyTrainStep, self).__init__(network, optimizer)self.grad = ops.GradOperation(get_by_list=True)def construct(self, data, label):"""构建训练过程"""weights = self.weightsloss = self.network(data, label)grads = self.grad(self.network, weights)(data, label)return loss, self.optimizer(grads)# 生成多项式分布的数据
dataset_size = 64
ds_train = create_dataset(2048, dataset_size)
# 网络
net = MyNet()
# 损失函数
loss_func = MyL1Loss()
# 优化器
opt = MyMomentum(net.trainable_params(), 0.0001)
# 构建损失网络
net_with_criterion = MyWithLossCell(net, loss_func)
# 构建训练网络
train_net = MyTrainStep(net_with_criterion, opt)
# 执行训练,每个epoch打印一次损失值
epochs = 50
for epoch in range(epochs):for train_x, train_y in ds_train:train_net(train_x, train_y)# print(train_x.shape)# print(train_x.shape)#         print(train_x,train_y)loss_val = net_with_criterion(train_x, train_y)#     print(loss_val)class MyMAE(nn.Metric):"""定义metric"""def __init__(self):super(MyMAE, self).__init__()self.clear()def clear(self):self.abs_error_sum = 0self.samples_num = 0def update(self, *inputs):y_pred = inputs[0].asnumpy()y = inputs[1].asnumpy()error_abs = np.abs(y.reshape(y_pred.shape) - y_pred)self.abs_error_sum += error_abs.sum()self.samples_num += y.shape[0]def eval(self):return self.abs_error_sum / self.samples_numclass MyWithEvalCell(nn.Cell):"""定义验证流程"""def __init__(self, network):super(MyWithEvalCell, self).__init__(auto_prefix=False)self.network = networkdef construct(self, data, label):outputs = self.network(data)return outputs, label# 获取验证数据
ds_eval = create_dataset(128, dataset_size, 1, train=False)
# 定义评估网络
eval_net = MyWithEvalCell(net)
eval_net.set_train(False)
# 定义评估指标
mae = MyMAE()
# 执行推理过程
for eval_x, eval_y in ds_eval:output, eval_y = eval_net(eval_x, eval_y)mae.update(output, eval_y)print("output is {} label is {}".format(output, eval_y))
mae_result = mae.eval()
print("mae on val_set: ", mae_result)ms.save_checkpoint(net, "./MyNet.ckpt")

运行上述代码,可以得到ckpt模型,接下来是进行推理,验证数据形式

net = MyNet()
ms.load_checkpoint("MyNet.ckpt", net=net)
ls=[[0.00632,18,2.31,0,0.538,6.575,65.2,4.09,1,296]]
np_array = np.array(ls)
input_np = np_array.astype(np.float32)
output = net(ms.Tensor(input_np))
print(output)

运行可以发现,能够得到推理结果,代表数据形式正确
即为一个二维列表-》numpy形式-》tensor形式
然后可以按照示例,根据自己代码进行模型转换,得到mindir模型文件


import numpy as np
import mindspore as msnet = MyNet()ms.load_checkpoint("MyNet.ckpt", net=net)
ls=[[0.00632,18,2.31,0,0.538,6.575,65.2,4.09,1,296]]
np_array = np.array(ls)
input_np = np_array.astype(np.float32)
ms.export(net, ms.Tensor(input_np), file_name='mind', file_format='MINDIR')

3.2 数据

即为波斯顿房价预测案例数据,这里就不再放了,只放个简单示例,可以自己直接去搜寻并下载

CRIM,ZN,INDUS,CHAS,NOX,RM,AGE,DIS,RAD,TAX,PIRATIO,B,LSTAT,MEDV
0.00632,18,2.31,0,0.538,6.575,65.2,4.09,1,296,15.3,396.9,4.98,24
0.02731,0,7.07,0,0.469,6.421,78.9,4.9671,2,242,17.8,396.9,9.14,21.6
0.02729,0,7.07,0,0.469,7.185,61.1,4.9671,2,242,17.8,392.83,4.03,34.7
0.03237,0,2.18,0,0.458,6.998,45.8,6.0622,3,222,18.7,394.63,2.94,33.4
0.06905,0,2.18,0,0.458,7.147,54.2,6.0622,3,222,18.7,396.9,5.33,36.2
0.02985,0,2.18,0,0.458,6.43,58.7,6.0622,3,222,18.7,394.12,5.21,28.7
0.08829,12.5,7.87,0,0.524,6.012,66.6,5.5605,5,311,15.2,395.6,12.43,22.9
0.14455,12.5,7.87,0,0.524,6.172,96.1,5.9505,5,311,15.2,396.9,19.15,27.1
0.21124,12.5,7.87,0,0.524,5.631,100,6.0821,5,311,15.2,386.63,29.93,16.5
0.17004,12.5,7.87,0,0.524,6.004,85.9,6.5921,5,311,15.2,386.71,17.1,18.9
0.22489,12.5,7.87,0,0.524,6.377,94.3,6.3467,5,311,15.2,392.52,20.45,15
0.11747,12.5,7.87,0,0.524,6.009,82.9,6.2267,5,311,15.2,396.9,13.27,18.9
0.09378,12.5,7.87,0,0.524,5.889,39,5.4509,5,311,15.2,390.5,15.71,21.7
0.62976,0,8.14,0,0.538,5.949,61.8,4.7075,4,307,21,396.9,8.26,20.4
0.63796,0,8.14,0,0.538,6.096,84.5,4.4619,4,307,21,380.02,10.26,18.2

四、参考文档

1、mindspore教程:https://www.mindspore.cn/tutorials/zh-CN/r1.7/advanced/train/save.html
2、华为mindspore入门-波士顿房价回归:https://blog.csdn.net/weixin_47895059/article/details/123964083
3、mindspore推理模型转换:https://www.mindspore.cn/lite/docs/zh-CN/r1.3/use/converter_tool.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/31207.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

STM32C8T6与TB6612

好久没写博客了,今天水一篇 接线

【Unity设计模式】状态编程模式

前言 最近在学习Unity游戏设计模式,看到两本比较适合入门的书,一本是unity官方的 《Level up your programming with game programming patterns》 ,另一本是 《游戏编程模式》 这两本书介绍了大部分会使用到的设计模式,因此很值得学习 本…

豆瓣电影top250网页爬虫

设计思路 选择技术栈:确定使用Python及其相关库,如requests用于发送网络请求,获取网址,用re(正则表达式)或BeautifulSoup用于页面内容解析。设计流程:规划爬虫的基本流程,包括发起请求、接受响应、解析内容、存储数据等环节。模块…

小程序中用font-spider压缩字体后,字体没效果(解决办法)

因为项目中需要引入外部字体,有两种方案, 第一是把字体下载到本地, 第二种是cdn请求服务器放字体的地址 但是小程序是有大小限制的,所以必须要压缩字体大小,这时候有些人就说了,那把字体放在服务器上&a…

【人工智能】—基于K-Means算法商场顾客聚类实战教程

在这篇博文之前一直是给大家做机器学习有监督学习教程,今天来一篇无监督学习教程。 K-Means算法是一种基于中心的聚类方法,它试图找到数据点的K个簇,使得簇内的数据点尽可能相似,而簇间的数据点尽可能不同。下面是K-Means算法的详…

Spring Boot集成tablesaw插件快速入门

1 什么是tablesaw? Tablesaw是一款Java的数据可视化库,主要包括两部分: 数据解析库,主要用于加载数据,对数据进行操作(转化,过滤,汇总等),类比Python中的Pandas库; 数据…

苹果cms10影视网整站源码下载/苹果cms模板MXone Pro自适应影视电影网站模板

下载地址:苹果cms10影视网整站源码下载/苹果cms模板MXone Pro自适应影视电影网站模板 模板带有夜间模式、白天晚上自动切换,有观影记录、后台设置页。全新UI全新框架,加载响应速度更快,seo更好,去除多余页面优化代码。…

从零开始搭建创业公司全新技术栈解决方案

从零开始搭建创业公司全新技术栈解决方案 关于猫头虎 大家好,我是猫头虎,别名猫头虎博主,擅长的技术领域包括云原生、前端、后端、运维和AI。我的博客主要分享技术教程、bug解决思路、开发工具教程、前沿科技资讯、产品评测图文、产品使用体…

Ollma本地大模型沉浸式翻译【403报错解决】

最终效果 通过Chrome的 沉浸式翻译 插件,用OpenAI通用接口调用本地的Ollma上的模型,实现本地的大模型翻译文献。 官方文档指导的Ollama的配置:一定要配置环境变量,否则会出现【403报错】

GoogLeNet(InceptionV3)模型算法

GoogLeNet 团队在给出了一些通用的网络设计准则,以期望在不提高网络参数 量的前提下提升网络的表达能力: 避免特征图 (feature map) 表达瓶颈:从理论上讲,尺寸 (seize) 才包含了相关结构等重要因素,维度(channel) 仅仅…

torch.optim 之 Algorithms (Implementation: for-loop, foreach, fused)

torch.optim的官方文档 官方文档中文版 一、Implementation torch.optim的官方文档在介绍一些optimizer Algorithms时提及它们的implementation共有如下三个类别:for-loop, foreach (multi-tensor), and fused。 Chat-GPT对这三个implementation的解释是&#xf…

账号和权限的管理

文章目录 管理用户账号和组账号用户账号的分类超级用户普通用户程序用户 UID(用户id)和(组账号)GIDUID用户识别号GID组标识号 用户账号文件添加用户账号设置/更改用户口令 管理用户账号和组账号 用户账号的分类 超级用户 root 用户是 Linux 操作系统中默认的超级…

第N5周:调用Gensim库训练Word2Vec模型

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 | 接辅导、项目定制🚀 文章来源:K同学的学习圈子 目录 本周任务: 1.安装Gensim库 2.对原始语料分词 3.停用词 4.训练Woed2Vec模型 …

办展览如何盈利?论办展的商业模式

想要弄清楚办展览怎么赚钱这个问题,我可以来说说。 首先来说说展览收益的大头:门票收入。 这个其实是可以大致预测的。简单来说,就是用流量乘以到店率。 但别忘了,这背后得有合适的展览定位、方便的展览场地和合理的票价。 说…

小林图解系统-三、操作系统结构

Linux 内核 vs Windows 内核 内核 作为应用连接硬件设备的桥梁,保证应用程序只需要关心与内核交互,不需要关心硬件的细节 内核具备四个基本能力: 管理进程、线程,决定哪个进程、线程使用CPU,也就是进程调度的能力&a…

Linux——ansible关于“文件操作”的模块

修改文件并将其复制到主机 一、确保受管主机上存在文件 使用 file 模块处理受管主机上的文件。其工作方式与 touch 命令类似,如果不存在则创建一个空文件,如果存在,则更新其修改时间。在本例中,除了处理文件之外,Ansi…

华为设备SSH远程访问配置实验简述

一、实验需求: 1、AR1模拟电脑SSH 访问AR2路由器。 二、实验步骤: 1、AR1和AR2接口配置IP,实现链路通信。 2、AR2配置AAA模式 配置用户及密码 配置用户访问级别 配置用户SSH 访问服务 AR2配置远程服务数量 配置用户远程访问模式为AAA 配置允许登录接入用…

【问题记录】Ubuntu提示: “E: 软件包 gcc 没有可安装候选“

Ubuntu提示: "E: 软件包 gcc 没有可安装候选" 一,问题现象二,问题原因&解决方法 一,问题现象 在虚拟机Ubuntu中进行安装gcc命令时报错:“E: 软件包 gcc 没有可安装候选”: 二,问题原因&解决方法 …

性能测试-性能监控分析与调优(三)《实战》

性能监控 使用命令监控 cpu瓶颈分析 top命令 在进行性能测试时使用top命令,界面如下 上图可以看出 CPU 概况区: %Cpu(s): us(用户进程占用CPU的百分比), 和 sy(系统进程占用CPU的百分比) 的数值很高…

代码随想录-Day36

452. 用最少数量的箭引爆气球 有一些球形气球贴在一堵用 XY 平面表示的墙面上。墙面上的气球记录在整数数组 points ,其中points[i] [xstart, xend] 表示水平直径在 xstart 和 xend之间的气球。你不知道气球的确切 y 坐标。 一支弓箭可以沿着 x 轴从不同点 完全垂…