《昇思25天学习打卡营第7天 | mindspore 模型训练常见用法》

1. 背景:

使用 mindspore 学习神经网络,打卡第7天;

2. 训练的内容:

使用 mindspore 的模型训练的常见用法,基本上是将前几章节的功能串起来

3. 常见的用法小节:

模型训练的常见流程,如数据加载,神经网路网络定义,损失函数定义,优化器定义,训练,测试,验证等步骤

3.1 构建数据集:

首先从数据集 Dataset加载代码,构建数据集

# 首先从数据集 Dataset加载代码,构建数据集
import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset# Download data from open datasets
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)def datapipe(path, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = MnistDataset(path)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return datasettrain_dataset = datapipe('MNIST_Data/train', batch_size=64)
test_dataset = datapipe('MNIST_Data/test', batch_size=64)

3.2 构建神经网络模型

从网络构建中加载代码,构建一个神经网络模型

# 从网络构建中加载代码,构建一个神经网络模型
class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()

3.3 定义损失函数与优化器

定义超参(Hyperparameters),损失函数(LossFunction)和优化器(Optimizer)

# 定义超参(Hyperparameters),损失函数(LossFunction)和优化器(Optimizer)
epochs = 3
batch_size = 64
learning_rate = 1e-2# 损失函数(loss function)用于评估模型的预测值(logits)和目标值(targets)之间的误差。
# 结合了nn.LogSoftmax和负对数似然(nn.NLLLoss) 
loss_fn = nn.CrossEntropyLoss()# 模型优化(Optimization)
# 在每个训练中调整模型参数,减少模型误差的过程
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)

3.4 定义训练函数

定义训练函数

# 训练与评估
# Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train_loop(model, dataset):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

3.5 定义测试函数

定义测试函数

# 测试函数:
def test_loop(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

3.6 实例化训练与测试,并运行

实例化训练与测试,并运行

# 实例化的损失函数和优化器传入 train_loop 和 test_loop 中
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(model, train_dataset)test_loop(model, test_dataset, loss_fn)
print("Done!")

相关链接:

  • https://xihe.mindspore.cn/events/mindspore-training-camp
  • https://gitee.com/mindspore/docs/blob/r2.3.0rc2/tutorials/source_zh_cn/beginner/train.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/869616.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

不想成为失业大军,就要学习六西格玛?

最近,优思学院收到一封邮件,这封邮件的发送者是一位完成了我们六西格玛绿带课程的学生。 他的公司裡有20%的工程师被裁员,但值得注意的是,留下来的工程师中有70%人竟然都持有六西格玛绿带或黑带证书。 他的公司不仅希望利用这些…

el-table封装popver組件,点击列筛选行数据功能,支持筛选,搜索,排序功能

子组件&#xff1a; <template><div class"tableTool" ref"tableTool" click.stop><el-button click"shengFnc">升序</el-button><el-button click"jiangFnc">降序</el-button><el-input v-m…

安卓 APK 安装过程详解

&#x1f34e;个人博客&#xff1a;个人主页 &#x1f3c6;个人专栏&#xff1a;Android ⛳️ 功不唐捐&#xff0c;玉汝于成 目录 前言 正文 1. 开机后连上网线 2. 查看网线的IP地址 3. 检查ADB连接 4. 修改文件权限 步骤 结语 我的其他博客 前言 在安卓设备上安装…

容器之docker compose

Docker Compose 是一个用于定义和运行多容器 Docker 应用的工具。通过一个 YAML 文件&#xff0c;您可以配置应用程序需要的所有服务&#xff0c;并使用单个命令来创建和启动这些服务。以下是对 Docker Compose 的详细介绍&#xff1a; 核心概念 服务&#xff08;Services&am…

python批量读取Excel数据写入word

from docx import Document from docx.shared import Pt from docx.enum.table import WD_TABLE_ALIGNMENT, WD_ROW_HEIGHT_RULE import os import pandas as pd from docx import Document from docx.oxml.ns import qn from docx.shared import Pt # ... 其他代码 ... work…

加密市场重新定位:全球流动性困境下的转型之痛

随着ETF通道的打开&#xff0c;加密市场所期待的“泼天资金量”并未达预期&#xff0c;全球金融市场的流动性匮乏问题蔓延至加密市场。新通道的打开也意味着之前那个复杂且成熟市场的规则随之与加密市场的文化与投资逻辑相碰撞。于是&#xff0c;加密市场从一个近乎封闭的避风港…

python3 ftplib乱码怎么解决

其实很简单。ftplib.FTP里面有个参数叫encoding。 如上图最后一行。所以在使用FTP时&#xff0c;主动指定编码格式即可。 ftp ftplib.FTP() ftp.encoding "utf-8" 再使用就可以了。

!vue3中defineEmits接收父组件向子组件传递方法,以及方法所需传的参数及类型定义,避免踩坑!

使用说明 1、在子组件中调用defineEmits并定义要发射给父组件的方法 const emits defineEmits([‘foldchange’]) 2、使用defineEmits会返回一个方法&#xff0c;使用一个变量emits(变量名随意)去接收 3、在子组件要触发的方法中&#xff0c;调用emits并传入发射给父组件的方法…

Kimi携手思维链,点亮论文写作之路!

学境思源&#xff0c;一键生成论文初稿&#xff1a; AcademicIdeas - 学境思源AI论文写作 在学术的海洋中&#xff0c;思想的火花常常在静谧的图书馆角落或深夜的电脑屏幕前迸发。今天分享的内容是一种高阶的论文写作方法&#xff1a;Kimi思维链。 Kimi&#xff0c;一个由月之…

【数据结构和算法的概念等】

目录 一、数据结构1、数据结构的基本概念2、数据结构的三要素2.1 数据的逻辑结构2.2 数据的存储&#xff08;物理&#xff09;结构2.3 数据的运算 二、算法1、算法概念2、算法的特性及特点3、算法分析 一、数据结构 1、数据结构的基本概念 数据&#xff1a; 是所有能输入到计…

D634-341C电液伺服系统比例控制阀 R40KO2M0NSS2

D634-341C/R40KO2M0NSS2宁波秉圣现货供应 宁波秉圣工业技术有限公司是一家专门从事于欧洲,美国等多国家的进口备件进出口销售、技术咨询、技术服务、自动化设备服务为一体的贸易公司。公司的优势品牌如下&#xff1a;德国REXROTH&#xff08;力士乐&#xff09;、德国MOOG、美…

全球数字贸易中心解析_保税区保的是什么税_为什么保税区还要交税

保税区税收机制深度解析&#xff1a;保税免的是什么税&#xff1f;为何仍需缴税&#xff1f; 保税区概述 保税区&#xff0c;作为海关特殊监管区域的重要一环&#xff0c;享有国家高度开放的政策优惠与功能齐全的海关监管服务。它专为保税加工、保税物流和保税服务而设&#…

【Python】已解决:ModuleNotFoundError: No module named ‘pip‘(重新安装pip的两种方式)

文章目录 一、分析问题背景二、可能出错的原因三、错误代码示例 四、重新安装pip的两种方式方式一&#xff1a;使用get-pip.py脚本方式二&#xff1a;使用ensurepip模块五、注意事项 已解决&#xff1a;ModuleNotFoundError: No module named ‘pip’&#xff08;重新安装pip的…

GLM4大模型微调入门实战-命名实体识别(NER)任务

[GLM4]是清华智谱团队最近开源的大语言模型。 以GLM4作为基座大模型&#xff0c;通过指令微调的方式做高精度的命名实体识别&#xff08;NER&#xff09;&#xff0c;是学习入门LLM微调、建立大模型认知的非常好的任务。 显存要求相对较高&#xff0c;需要40GB左右。 知识点1&…

Unity之王牌飞行员申请出战

目录 &#x1f4da;一、准备工作 &#x1f4bb;二、飞机的两套控制器 &#x1f3ae;2.1 起飞前 &#x1f579;️2.2 起飞后 &#x1f680;三、实现射击功能 &#x1f4a5;3.1 射击脚本 &#x1f4a5;3.2 爆炸脚本 &#x1f4a5;3.3 爆炸特效脚本 &#x1f6e0;️四、组…

系统卡顿原因,jbd2 , cat /proc/interrupts, 网络ssh

https://zhuanlan.zhihu.com/p/582827171 cat /proc/interrupts, cat /proc/stat_cat proc stat-CSDN博客 Linux 系统运行速度太慢的关键原因全都在这了-腾讯云开发者社区-腾讯云 结论 虽然有很多因素可能导致系统缓慢&#xff0c;但CPU、内存和磁盘I/O是导致绝大多数性能问…

AGE 可比性、相等性、可排序性和等效性

AGE已经对原始类型&#xff08;布尔值、字符串、整数和浮点数&#xff09;和映射的相等性有了良好的语义。此外&#xff0c;Cypher对整数、浮点数和字符串的可比性和可排序性也有很好的语义。然而&#xff0c;处理不同类型的值与Postgres定义的逻辑和openCypher规范存在偏差&am…

Linux 期末速成(知识点+例题)

代码学习 &#xff08;1&#xff09;查看运行结果&#xff0c;思考为什么&#xff1f; [root red ~]#VAR11 ​ [root red ~]#(VAR12; echo $VAR1) ​ [root red ~]#echo $VAR1 ​ [root red ~]# { VAR12; echo $VAR1; } ​ [root red ~]# echo $VAR1VAR11&#xf…

千万级消息推送系统设计与实战

系统概述 功能: 支持千万级用户的消息推送&#xff0c;包括通知和透传消息&#xff0c;快速触达用户&#xff0c;提升用户留存和活跃度。接入方式: 通过API接入&#xff0c;实现消息的即时推送。管理后台: 提供网页端后台&#xff0c;便于应用授权、业务管理、消息管理及数据查…

Linux_网络编程_TCP

服务器客户端模型&#xff1a; client / server brow / ser b / s http p2p socket——tcp 1、模式 C/S 模式 》服务器/客户端模型 server :socket()-->bind()--->listen()-->accept()-->recv()-->close()client :socket()-->conn…