昇思25天学习打卡营第8天|模型训练

昇思25天学习打卡营第8天|模型训练

  • 前言
  • 模型训练
    • 构建数据集
    • 定义神经网络模型
    • 定义超参、损失函数和优化器
      • 超参
      • 损失函数
      • 优化器
    • 训练与评估
  • 个人任务打卡(读者请忽略)
  • 个人理解与总结

前言

  非常感谢华为昇思大模型平台和CSDN邀请体验昇思大模型!从今天起,笔者将以打卡的方式,将原文搬运和个人思考结合,分享25天的学习内容与成果。为了提升文章质量和阅读体验,笔者会将思考部分放在最后,供大家探索讨论。同时也欢迎各位领取算力,免费体验昇思大模型!

模型训练

模型训练一般分为四个步骤:

  1. 构建数据集。
  2. 定义神经网络模型。
  3. 定义超参、损失函数及优化器。
  4. 输入数据集进行训练与评估。

现在我们有了数据集和模型后,可以进行模型的训练与评估。

构建数据集

首先从数据集 Dataset加载代码,构建数据集。

%%capture captured_output
# 实验环境已经预装了mindspore==2.3.0rc1,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.3.0rc1
import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset# Download data from open datasets
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)#从URL下载MNIST数据集用于模型训练def datapipe(path, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),						#将图像中的矩阵单元转为float32类型vision.Normalize(mean=(0.1307,), std=(0.3081,)),	#设置均值和标准差将图像矩阵元素归一化vision.HWC2CHW()									#将图像的shape从[H, W, C]转为[C, H, W]]label_transform = transforms.TypeCast(mindspore.int32)	#图像的label转为int32类型dataset = MnistDataset(path)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return datasettrain_dataset = datapipe('MNIST_Data/train', batch_size=64)	#训练集设置路径,batch_size=64
test_dataset = datapipe('MNIST_Data/test', batch_size=64)	#测试集设置路径,batch_size=64

在这里插入图片描述

定义神经网络模型

从网络构建中加载代码,构建一个神经网络模型。

class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(		#网络依序搭建nn.Dense(28*28, 512),							#全连接层,输入通道数为28*28,输出通道数为512nn.ReLU(),										#ReLU激活函数nn.Dense(512, 512),								#全连接层,输入通道数为512,输出通道数为512nn.ReLU(),										#ReLU激活函数nn.Dense(512, 10)								#全连接层,输入通道数为512,输出通道数为10,对应0-9十种情况)def construct(self, x):x = self.flatten(x)									#将输入图像展平为一维序列logits = self.dense_relu_sequential(x)				#将一维序列输入神经网络获得结果return logitsmodel = Network()											#神经网络实例化为model

定义超参、损失函数和优化器

超参

超参(Hyperparameters)是可以调整的参数,可以控制模型训练优化的过程,不同的超参数值可能会影响模型训练和收敛速度。目前深度学习模型多采用批量随机梯度下降算法进行优化,随机梯度下降算法的原理如下:

w t + 1 = w t − η 1 n ∑ x ∈ B ∇ l ( x , w t ) w_{t+1}=w_{t}-\eta \frac{1}{n} \sum_{x \in \mathcal{B}} \nabla l\left(x, w_{t}\right) wt+1=wtηn1xBl(x,wt)

公式中, n n n是批量大小(batch size), η η η是学习率(learning rate)。另外, w t w_{t} wt为训练轮次 t t t中的权重参数, ∇ l \nabla l l为损失函数的导数。除了梯度本身,这两个因子直接决定了模型的权重更新,从优化本身来看,它们是影响模型性能收敛最重要的参数。一般会定义以下超参用于训练:

  • 训练轮次(epoch):训练时遍历数据集的次数。

  • 批次大小(batch size):数据集进行分批读取训练,设定每个批次数据的大小。batch size过小,花费时间多,同时梯度震荡严重,不利于收敛;batch size过大,不同batch的梯度方向没有任何变化,容易陷入局部极小值,因此需要选择合适的batch size,可以有效提高模型精度、全局收敛。

  • 学习率(learning rate):如果学习率偏小,会导致收敛的速度变慢,如果学习率偏大,则可能会导致训练不收敛等不可预测的结果。梯度下降法被广泛应用在最小化模型误差的参数优化算法上。梯度下降法通过多次迭代,并在每一步中最小化损失函数来预估模型的参数。学习率就是在迭代过程中,会控制模型的学习进度。

epochs = 3
batch_size = 64
learning_rate = 1e-2

损失函数

损失函数(loss function)用于评估模型的预测值(logits)和目标值(targets)之间的误差。训练模型时,随机初始化的神经网络模型开始时会预测出错误的结果。损失函数会评估预测结果与目标值的相异程度,模型训练的目标即为降低损失函数求得的误差。

常见的损失函数包括用于回归任务的nn.MSELoss(均方误差)和用于分类的nn.NLLLoss(负对数似然)等。 nn.CrossEntropyLoss 结合了nn.LogSoftmaxnn.NLLLoss,可以对logits 进行归一化并计算预测误差。

loss_fn = nn.CrossEntropyLoss()	#使用交叉熵计算损失

优化器

模型优化(Optimization)是在每个训练步骤中调整模型参数以减少模型误差的过程。MindSpore提供多种优化算法的实现,称之为优化器(Optimizer)。优化器内部定义了模型的参数优化过程(即梯度如何更新至模型参数),所有优化逻辑都封装在优化器对象中。在这里,我们使用SGD(Stochastic Gradient Descent)优化器。

我们通过model.trainable_params()方法获得模型的可训练参数,并传入学习率超参来初始化优化器。

optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)	#使用SGD优化器,对初始学习率和训练参数进行优化

在训练过程中,通过微分函数可计算获得参数对应的梯度,将其传入优化器中即可实现参数优化,具体形态如下:

grads = grad_fn(inputs)

optimizer(grads)

训练与评估

设置了超参、损失函数和优化器后,我们就可以循环输入数据来训练模型。一次数据集的完整迭代循环称为一轮(epoch)。每轮执行训练时包括两个步骤:

  1. 训练:迭代训练数据集,并尝试收敛到最佳参数。
  2. 验证/测试:迭代测试数据集,以检查模型性能是否提升。

接下来我们定义用于训练的train_loop函数和用于测试的test_loop函数。

使用函数式自动微分,需先定义正向函数forward_fn,使用value_and_grad获得微分函数grad_fn。然后,我们将微分函数和优化器的执行封装为train_step函数,接下来循环迭代数据集进行训练即可。

# Define forward function
def forward_fn(data, label):		#定义前向函数logits = model(data)			loss = loss_fn(logits, label)	#损失为训练结果与真实值(标签)的差距return loss, logits# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)#梯度函数# Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)						#优化梯度return lossdef train_loop(model, dataset):			#模型训练size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:							#当batch为100的整数倍时,输出lossloss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

test_loop函数同样需循环遍历数据集,调用模型计算loss和Accuray并返回最终结果。

def test_loop(model, dataset, loss_fn):				#模型测试num_batches = dataset.get_dataset_size()model.set_train(False)							#设置模型为非训练模式total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():	#计算准确率和测试损失pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

我们将实例化的损失函数和优化器传入train_looptest_loop中。训练3轮并输出loss和Accuracy,查看性能变化。

loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(model, train_dataset)test_loop(model, test_dataset, loss_fn)
print("Done!")

在这里插入图片描述

个人任务打卡(读者请忽略)

在这里插入图片描述

个人理解与总结

本章节主要描述了昇思大模型模型训练的主要功能,从训练模型的最基础的构建数据集开始,先后完成了定义神经网络模型、定义超参、损失函数和优化器以及最终的训练预评估四部分。通过搭建简单的全连接图像分类网络完成手写体识别任务。网络首先通过下载和设置数据集,完成数据集搭建,为后续的深度神经网络训练打下基础;然后定义神经网络,搭建输入-隐藏-输出三层全连接神经网络以完成图像分类任务;定义超参数、损失函数和优化器保证深度神经网络在训练过程中的正向优化;最后经train_loop和test_loop完成深度神经网络的训练与测试,并输出每100个iteration的损失及每代的测试准确率和测试损失。综上所述,昇思大模型为深度神经网络的搭建提供了系统性参照,为深度更深、结构更复杂的模型提供了搭建基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/38590.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux中如何启动python虚拟环境

找到python虚拟环境所在目录 执行下面的命令即可 source auth_python/bin/activate

【遇坑笔记】Node.js 开发环境与配置 Visual Studio Code

【遇坑笔记】Node.js 开发环境与配置 Visual Studio Code 前言node.js开发环境配置解决pnpm 不是内部或外部命令的问题(pnpm安装教程) 解决 pnpm : 无法加载文件 C:\Program Files\nodejs\pnpm.ps1,因为在此系统上禁止运行脚本。 前言 最近部…

【代码随想录】【算法训练营】【第49天】 [300]最长递增子序列 [674]最长连续递增序列 [718]最长重复子数组

前言 思路及算法思维,指路 代码随想录。 题目来自 LeetCode。 day 49,周二,坚持不了一点~ 题目详情 [300] 最长递增子序列 题目描述 300 最长递增子序列 解题思路 前提:最大递增子序列的长度 思路:动态规划 d…

基于X86+FPGA的精密加工检测设备解决方案

应用场景 随着我国高新技术的发展和国防现代化发展,航空、航天等领域需 要的大型光电子器件,微型电子机械、 光 电信息等领域需要的微型器件,还有一些复杂零件的加工需求日益增加,这些都需要借助精密甚至超精密的加工检测设备 客…

esp12实现的网络时钟校准

网络时间的获取是通过向第三方服务器发送GET请求获取并解析出来的。 在本篇博客中,网络时间的获取是一种自动的行为,当系统成功连接WiFi获取到网络天气后,系统将自动获取并解析得到时间和日期,为了减少误差每两分钟左右进行一次校…

web平台—apache

web平台—apache 1. 学apache前需要知道的知识点2. apache详解2.1 概述2.2 工作模式2.3 启动apache网站整体流程2.4 相关文件保存位置2.5 配置文件详解 3. apache配置实验实验1:设置apache的目录别名实验2:apache的用户认证实验3:虚拟主机 (重…

江门MES制造执行系统:助力工厂实现智能化管理

江门MES制造执行系统(MES)在工厂实现智能化管理方面发挥着重要作用,以下是它的一些助力方面: 实时监控与控制:江门MES系统可以实时监控生产过程中的各个环节,包括设备状态、生产进度、质量指标等,帮助工厂管理人员及时…

LW-DETR: A Transformer Replacement to YOLO for Real-Time Detection

LW-DETR: A Transformer Replacement to YOLO for Real-Time Detection 论文链接:http://arxiv.org/abs/2406.03459 代码链接:https://github.com/Atten4Vis/LW-DETR 一、摘要 介绍了一种轻量级检测变换器LWDETR,它在实时物体检测方面超越…

CF1981D Turtle and Multiplication 题解

Turtle and Multiplication 传送门 Turtle just learned how to multiply two integers in his math class, and he was very excited. Then Piggy gave him an integer n n n , and asked him to construct a sequence a 1 , a 2 , … , a n a_1, a_2, \ldots, a_n a1​,…

Java [ 基础 ] Stream流 ✨

✨探索Java基础Stream流✨ 在现代Java编程中,Stream是一个非常强大的工具,它提供了一种更高效和简洁的方式来处理集合数据。在这篇博客中,我们将深入探讨Java中的Stream流,介绍它的基础知识、常见操作和一些实用示例。 什么是Str…

10-错误-java.lang.IllegalStateException Stopwatch is not running

10-错误-java.lang.IllegalStateException Stopwatch is not running 更多内容欢迎关注我(持续更新中,欢迎Star✨) Github:CodeZeng1998/Java-Developer-Work-Note 技术公众号:CodeZeng1998(纯纯技术文&…

用易查分下发《致家长一封信》,支持在线手写签名,一键导出PDF!

暑假来临之际,学校通常需要下发致家长信,以正式、书面的形式向家长传达重要的通知或建议。传统的发放方式如家长签字后学生将回执单上交,容易存在丢失、遗忘的问题。 那么如何更高效、便捷、安全地将致家长一封信送达给每位家长呢&#xff1f…

Linux[高级管理]——Squid代理服务器的部署和应用(反向代理详解)

🏡作者主页:点击! 👨‍💻Linux高级管理专栏:点击! ⏰️创作时间:2024年6月24日11点11分 🀄️文章质量:95分 目录 ————前言———— Squid的几种模式…

游戏录制视频软件哪个好?这份攻略帮你搞定!

随着游戏行业的快速发展,越来越多的玩家开始录制游戏视频,以便分享自己的游戏体验或保存珍贵回忆。而选择一款合适的游戏录制视频软件显得尤为重要。可是游戏录制视频软件哪个好呢?本文将为大家介绍两款优秀的游戏录制视频软件,通…

Vatee万腾平台:科技驱动,智慧生活

随着科技的飞速发展,我们生活的方方面面正在经历前所未有的变革。Vatee万腾平台,作为这一变革的推动者之一,以其科技驱动的理念,正引领我们迈向更加智慧、便捷的生活。 Vatee万腾平台,是一个集科技研发、应用创新、服务…

Unity热更方案HybridCLR+YooAsset,纯c#开发热更,保姆级教程,从零开始

文章目录: 一、前言二、创建空工程三、接入HybridCLR四、接入YooAsset五、搭建本地资源服务器Nginx六、实战七、最后 一、前言 unity热更有很多方案,各种lua热更,ILRuntime等,这里介绍的是YooAssetHybridCLR的热更方案&#xff0…

jvm性能监控常用工具

在java的/bin目录下有许多java自带的工具。 我们常用的有 基础工具 jar:创建和管理jar文件 java:java运行工具,用于运行class文件或jar文件 javac:java的编译器 javadoc:java的API文档生成工具 性能监控和故障处理 jps jstat…

鸿蒙应用更新跳转到应用市场

鸿蒙没有应用下载安装,只支持跳转到应用市场更新 gotoMarket(){try {const request: Want {parameters: {// 此处填入要加载的应用包名,例如: bundleName: "com.huawei.hmsapp.appgallery"bundleName: com.huawei.hmos.maps.app}}…

浅谈定时器之常数吞吐量定时器

浅谈定时器之常数吞吐量定时器 常数吞吐量定时器的主要目的是在JMeter测试中维持一个恒定的吞吐量(通常是每分钟的请求数或事务数),从而确保测试能够以预期的负载水平运行。这对于模拟特定的用户访问模式、进行稳定性测试、负载测试以及压力…

量化交易 - 策略回测

策略回测 1、什么是策略回测?2、策略回测的作用3、策略回测系统概述3.1策略回测中相关的指标介绍3.2量化交易策略的资金容量3.3 完整的策略回测系统包含哪些内容 1、什么是策略回测? 策略回测,也称之为策略回溯测试,是指利用交易…