昇思25天学习打卡营第7天 | 模型训练

内容介绍:

模型训练一般分为四个步骤:

1. 构建数据集。
2. 定义神经网络模型。
3. 定义超参、损失函数及优化器。
4. 输入数据集进行训练与评估。

具体内容:

1. 导包

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset
from download import download

2. 构建数据集

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)def datapipe(path, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = MnistDataset(path)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return datasettrain_dataset = datapipe('MNIST_Data/train', batch_size=64)
test_dataset = datapipe('MNIST_Data/test', batch_size=64)

3. 定义神经网络模型

class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()

4. 定义超参、损失函数和优化器

超参(Hyperparameters)是可以调整的参数,可以控制模型训练优化的过程,不同的超参数值可能会影响模型训练和收敛速度。目前深度学习模型多采用批量随机梯度下降算法进行优化。

训练轮次(epoch):训练时遍历数据集的次数。

批次大小(batch size):数据集进行分批读取训练,设定每个批次数据的大小。batch size过小,花费时间多,同时梯度震荡严重,不利于收敛;batch size过大,不同batch的梯度方向没有任何变化,容易陷入局部极小值,因此需要选择合适的batch size,可以有效提高模型精度、全局收敛。

学习率(learning rate):如果学习率偏小,会导致收敛的速度变慢,如果学习率偏大,则可能会导致训练不收敛等不可预测的结果。梯度下降法被广泛应用在最小化模型误差的参数优化算法上。梯度下降法通过多次迭代,并在每一步中最小化损失函数来预估模型的参数。学习率就是在迭代过程中,会控制模型的学习进度。

epochs = 3
batch_size = 64
learning_rate = 1e-2

损失函数(loss function)用于评估模型的预测值(logits)和目标值(targets)之间的误差。训练模型时,随机初始化的神经网络模型开始时会预测出错误的结果。损失函数会评估预测结果与目标值的相异程度,模型训练的目标即为降低损失函数求得的误差。

常见的损失函数包括用于回归任务的`nn.MSELoss`(均方误差)和用于分类的`nn.NLLLoss`(负对数似然)等。 `nn.CrossEntropyLoss` 结合了`nn.LogSoftmax`和`nn.NLLLoss`,可以对logits 进行归一化并计算预测误差。

loss_fn = nn.CrossEntropyLoss()

模型优化(Optimization)是在每个训练步骤中调整模型参数以减少模型误差的过程。MindSpore提供多种优化算法的实现,称之为优化器(Optimizer)。优化器内部定义了模型的参数优化过程(即梯度如何更新至模型参数),所有优化逻辑都封装在优化器对象中。在这里,我们使用SGD(Stochastic Gradient Descent)优化器。

我们通过`model.trainable_params()`方法获得模型的可训练参数,并传入学习率超参来初始化优化器。

optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)

在训练过程中,通过微分函数可计算获得参数对应的梯度,将其传入优化器中即可实现参数优化,具体形态如下:

grads = grad_fn(inputs)

optimizer(grads)

5. 训练与评估

设置了超参、损失函数和优化器后,我们就可以循环输入数据来训练模型。一次数据集的完整迭代循环称为一轮(epoch)。每轮执行训练时包括两个步骤:

1. 训练:迭代训练数据集,并尝试收敛到最佳参数。
2. 验证/测试:迭代测试数据集,以检查模型性能是否提升。
 

def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train_loop(model, dataset):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")
def test_loop(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(model, train_dataset)test_loop(model, test_dataset, loss_fn)
print("Done!")

MindSpore的易用性也给我带来了很大的便利。通过简洁明了的API和丰富的文档支持,我能够快速地掌握MindSpore的使用方法,并轻松地构建自己的深度学习模型。同时,MindSpore还提供了丰富的预训练模型和示例代码,让我能够更快地入门并深入理解深度学习的应用。

在模型训练的过程中,我深刻体会到了深度学习模型的复杂性和挑战性。通过不断地调整网络结构、优化参数设置以及尝试不同的训练策略,我逐渐掌握了如何构建和训练一个性能优异的深度学习模型。这个过程让我更加明白了深度学习模型训练需要耐心、细致和持续的努力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/35458.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

手把手教你使用kimi创建流程图【实践篇】

学境思源,一键生成论文初稿: AcademicIdeas - 学境思源AI论文写作 引言 在昨日的文章中,我们介绍了如何使用Kimi生成论文中的流程图。今天,我们将更进一步,通过实践案例来展示Kimi在生成流程图方面的应用。这不仅将加…

【大数据技术原理与应用(概念、存储、处理、分析与应用)】第1章-大数据概述习题与知识点回顾

文章目录 单选题多选题知识点回顾几次信息化浪潮主要解决什么问题?信息科技为大数据时代提供哪些技术支撑?数据产生方式有哪些变革?大数据的发展历程大数据的四个特点(4V)大数据对思维方式的影响大数据有哪些关键技术&…

软考《信息系统运行管理员》-1.2信息系统运维

1.2信息系统运维 传统运维模式(软件) 泛化:软件交付后围绕其所做的任何工作纠错:软件运行中错误的发现和改正适应:为适应环境做出的改变用户支持:为软件用户提供的支持 新的不同视角下的运维 “管理”的…

Java 面试指南合集

线程篇 springBoot篇 待更新 黑夜无论怎样悠长,白昼总会到来。 此文会一直更新哈 如果你希望成功,当以恒心为良友,以经验为参谋,以当心为兄弟,以希望为哨兵。

拉普拉斯变换与卷积

前面描述 卷积,本文由卷积引入拉普拉斯变换。 拉普拉斯变换就是给傅里叶变换的 iωt 加了个实部,也可以反着理解,原函数乘以 e − β t e^{-\beta t} e−βt 再做傅里叶变换,本质上都是傅里叶变换的扩展。 加入实部的拉普拉斯变…

【建设方案】智慧园区大数据云平台建设方案(DOC原件)

大数据云平台建设技术要点主要包括以下几个方面: 云计算平台选择:选择安全性高、效率性强、成本可控的云计算平台,如阿里云、腾讯云等,确保大数据处理的基础环境稳定可靠。 数据存储与管理:利用Hadoop、HBase等分布式…

一年Java转GO|19K|腾讯 CSIG 一二面经

面经哥只做互联网社招面试经历分享,关注我,每日推送精选面经,面试前,先找面经哥 背景 学历:本科工作经验:一年(不算实习)当前语言:Javabase:武汉部门\岗位:腾讯云‍ 一…

5000天后的世界:科技引领的未来之路

**你是否想过,5000天后的世界会是什么样子?** 科技日新月异,改变着我们的生活方式,也引领着人类文明的进程。著名科技思想家凯文凯利在他的著作《5000天后的世界》中,对未来进行了大胆的预测。 **这本书中&#xff0c…

基于微信小程序的在线点餐系统【前后台+附源码+LW】

摘 要 随着社会的发展,社会的各行各业都在利用信息化时代的优势。计算机的优势和普及使得各种信息系统的开发成为必需。 点餐小程序,主要的模块包括实现管理员;管理员用户,可以对整个系统进行基本的增删改查,系统的日…

Opencv+python模板匹配

我们经常玩匹配图像或者找相似,opencv可以很好实现这个简单的小功能。 模板是被查找目标的图像,查找模板在原始图像中的哪个位置的过程就叫模板匹配。OpenCV提供的matchTemplate()方法就是模板匹配方法,其语法如下: result cv2.…

使用go语言来完成复杂excel表的导出导入

使用go语言来完成复杂excel表的导出导入(一) 1.复杂表的导入 开发需求是需要在功能页面上开发一个excel文件的导入导出功能,这里的复杂指定是表内数据夹杂着一对多,多对一的形式,如下图所示。数据杂乱而且对应不统一。…

中国90米分辨率可蚀性因子K数据

土壤可蚀性因子(K)数据,基于多种土壤属性数据计算,所用数据包括土壤黏粒含量(%)、粉粒含量(%)、砂粒含量(%)、土壤有机碳含量(g/kg)、…

[DALL·E 2] Hierarchical Text-Conditional Image Generation with CLIP Latents

1、目的 CLIP DDPM进行text-to-image生成 2、数据 (x, y),x为图像,y为相应的captions;设定和为CLIP的image和text embeddings 3、方法 1)CLIP 学习图像和文本的embedding;在训练prior和decoder时固定该部分参数 2&a…

开放式耳机什么牌子好一点?亲检的几款开放式蓝牙耳机推荐

不入耳的开放式耳机更好一些,不入耳式耳机佩戴更舒适,适合长时间佩戴,不会引起强烈的压迫感或耳部不适。不入耳式的设计不需要接触耳朵,比入耳式耳机更加卫生且不挑耳型,因此备受运动爱好者和音乐爱好者的喜爱。这里给…

周转车配料拣货方案

根据周转车安装的电子标签,被悬挂的扫码器扫到墨水屏显示的二维码,投屏发送配料拣货的数据。 方便快捷分拣物料

20240625(周二)欧美股市总结:标普纳指止步三日连跌,英伟达反弹6.8%,谷歌微软新高,油价跌1%

美联储理事鲍曼鹰派发声,若通胀没有持续改善将支持加息,加拿大5月CPI重新加速,对加拿大央行7月降息构成阻碍。美股走势分化,道指收跌近300点且六日里首跌,英伟达市值重上3.10万亿美元,芯片股指显著反弹1.8%…

想要用tween实现相机的移动,three.js渲染的canvas画布上相机位置一点没动,如何解决??

🏆本文收录于「Bug调优」专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收藏&&…

第1章 物联网模式简介---独特要求和体系结构原则

物联网用例的独特要求 物联网用例往往在功耗、带宽、分析等方面具有非常独特的要求。此外,物联网实施的固有复杂性(一端的现场设备在计算上受到挑战,另一端的云容量几乎无限)迫使架构师做出艰难的架构决策和实施选择。可用实现技…

【自动调参】年化29.3%,最大回撤18.5%​:lightGBM的参数优化

原创文章第570篇,专注“AI量化投资、世界运行的规律、个人成长与财富自由"。 研报复现继续:【研报复现】年化27.1%,人工智能多因子大类资产配置策略之benchmark 昨天调了一版参数,主要是lambda_l1, lambda_l2,防…

Vmvare12安装CentOS7.6

Vmvare12安装 注意事项 安装完成以后有这两个虚拟网卡。 CentOS官网镜像地址 https://www.centos.org/download/mirrors/Vmvare安装CentOS7.6 创建虚拟机 安装CentOS7.6 选择桌面版 磁盘分区 上述是确认使用自动分区。 设置密码 设置license information 欢迎页面 CentOS7…