昇思25天学习打卡营第7天 | 模型训练

内容介绍:

模型训练一般分为四个步骤:

1. 构建数据集。
2. 定义神经网络模型。
3. 定义超参、损失函数及优化器。
4. 输入数据集进行训练与评估。

具体内容:

1. 导包

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset
from download import download

2. 构建数据集

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)def datapipe(path, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = MnistDataset(path)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return datasettrain_dataset = datapipe('MNIST_Data/train', batch_size=64)
test_dataset = datapipe('MNIST_Data/test', batch_size=64)

3. 定义神经网络模型

class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()

4. 定义超参、损失函数和优化器

超参(Hyperparameters)是可以调整的参数,可以控制模型训练优化的过程,不同的超参数值可能会影响模型训练和收敛速度。目前深度学习模型多采用批量随机梯度下降算法进行优化。

训练轮次(epoch):训练时遍历数据集的次数。

批次大小(batch size):数据集进行分批读取训练,设定每个批次数据的大小。batch size过小,花费时间多,同时梯度震荡严重,不利于收敛;batch size过大,不同batch的梯度方向没有任何变化,容易陷入局部极小值,因此需要选择合适的batch size,可以有效提高模型精度、全局收敛。

学习率(learning rate):如果学习率偏小,会导致收敛的速度变慢,如果学习率偏大,则可能会导致训练不收敛等不可预测的结果。梯度下降法被广泛应用在最小化模型误差的参数优化算法上。梯度下降法通过多次迭代,并在每一步中最小化损失函数来预估模型的参数。学习率就是在迭代过程中,会控制模型的学习进度。

epochs = 3
batch_size = 64
learning_rate = 1e-2

损失函数(loss function)用于评估模型的预测值(logits)和目标值(targets)之间的误差。训练模型时,随机初始化的神经网络模型开始时会预测出错误的结果。损失函数会评估预测结果与目标值的相异程度,模型训练的目标即为降低损失函数求得的误差。

常见的损失函数包括用于回归任务的`nn.MSELoss`(均方误差)和用于分类的`nn.NLLLoss`(负对数似然)等。 `nn.CrossEntropyLoss` 结合了`nn.LogSoftmax`和`nn.NLLLoss`,可以对logits 进行归一化并计算预测误差。

loss_fn = nn.CrossEntropyLoss()

模型优化(Optimization)是在每个训练步骤中调整模型参数以减少模型误差的过程。MindSpore提供多种优化算法的实现,称之为优化器(Optimizer)。优化器内部定义了模型的参数优化过程(即梯度如何更新至模型参数),所有优化逻辑都封装在优化器对象中。在这里,我们使用SGD(Stochastic Gradient Descent)优化器。

我们通过`model.trainable_params()`方法获得模型的可训练参数,并传入学习率超参来初始化优化器。

optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)

在训练过程中,通过微分函数可计算获得参数对应的梯度,将其传入优化器中即可实现参数优化,具体形态如下:

grads = grad_fn(inputs)

optimizer(grads)

5. 训练与评估

设置了超参、损失函数和优化器后,我们就可以循环输入数据来训练模型。一次数据集的完整迭代循环称为一轮(epoch)。每轮执行训练时包括两个步骤:

1. 训练:迭代训练数据集,并尝试收敛到最佳参数。
2. 验证/测试:迭代测试数据集,以检查模型性能是否提升。
 

def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train_loop(model, dataset):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")
def test_loop(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(model, train_dataset)test_loop(model, test_dataset, loss_fn)
print("Done!")

MindSpore的易用性也给我带来了很大的便利。通过简洁明了的API和丰富的文档支持,我能够快速地掌握MindSpore的使用方法,并轻松地构建自己的深度学习模型。同时,MindSpore还提供了丰富的预训练模型和示例代码,让我能够更快地入门并深入理解深度学习的应用。

在模型训练的过程中,我深刻体会到了深度学习模型的复杂性和挑战性。通过不断地调整网络结构、优化参数设置以及尝试不同的训练策略,我逐渐掌握了如何构建和训练一个性能优异的深度学习模型。这个过程让我更加明白了深度学习模型训练需要耐心、细致和持续的努力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/35458.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

手把手教你使用kimi创建流程图【实践篇】

学境思源,一键生成论文初稿: AcademicIdeas - 学境思源AI论文写作 引言 在昨日的文章中,我们介绍了如何使用Kimi生成论文中的流程图。今天,我们将更进一步,通过实践案例来展示Kimi在生成流程图方面的应用。这不仅将加…

【大数据技术原理与应用(概念、存储、处理、分析与应用)】第1章-大数据概述习题与知识点回顾

文章目录 单选题多选题知识点回顾几次信息化浪潮主要解决什么问题?信息科技为大数据时代提供哪些技术支撑?数据产生方式有哪些变革?大数据的发展历程大数据的四个特点(4V)大数据对思维方式的影响大数据有哪些关键技术&…

burpsuite 抓https的方法(CA证书操作)

https://cloud.tencent.com/developer/article/1391501

软考《信息系统运行管理员》-1.2信息系统运维

1.2信息系统运维 传统运维模式(软件) 泛化:软件交付后围绕其所做的任何工作纠错:软件运行中错误的发现和改正适应:为适应环境做出的改变用户支持:为软件用户提供的支持 新的不同视角下的运维 “管理”的…

Java 面试指南合集

线程篇 springBoot篇 待更新 黑夜无论怎样悠长,白昼总会到来。 此文会一直更新哈 如果你希望成功,当以恒心为良友,以经验为参谋,以当心为兄弟,以希望为哨兵。

拉普拉斯变换与卷积

前面描述 卷积,本文由卷积引入拉普拉斯变换。 拉普拉斯变换就是给傅里叶变换的 iωt 加了个实部,也可以反着理解,原函数乘以 e − β t e^{-\beta t} e−βt 再做傅里叶变换,本质上都是傅里叶变换的扩展。 加入实部的拉普拉斯变…

【建设方案】智慧园区大数据云平台建设方案(DOC原件)

大数据云平台建设技术要点主要包括以下几个方面: 云计算平台选择:选择安全性高、效率性强、成本可控的云计算平台,如阿里云、腾讯云等,确保大数据处理的基础环境稳定可靠。 数据存储与管理:利用Hadoop、HBase等分布式…

一年Java转GO|19K|腾讯 CSIG 一二面经

面经哥只做互联网社招面试经历分享,关注我,每日推送精选面经,面试前,先找面经哥 背景 学历:本科工作经验:一年(不算实习)当前语言:Javabase:武汉部门\岗位:腾讯云‍ 一…

5000天后的世界:科技引领的未来之路

**你是否想过,5000天后的世界会是什么样子?** 科技日新月异,改变着我们的生活方式,也引领着人类文明的进程。著名科技思想家凯文凯利在他的著作《5000天后的世界》中,对未来进行了大胆的预测。 **这本书中&#xff0c…

基于微信小程序的在线点餐系统【前后台+附源码+LW】

摘 要 随着社会的发展,社会的各行各业都在利用信息化时代的优势。计算机的优势和普及使得各种信息系统的开发成为必需。 点餐小程序,主要的模块包括实现管理员;管理员用户,可以对整个系统进行基本的增删改查,系统的日…

什么是<meta> 标签

<meta> 标签是 HTML 文档头部 (<head>) 中的一种元数据标签&#xff0c;用于提供关于 HTML 文档的信息。虽然它不会直接影响文档的呈现&#xff0c;但它在搜索引擎优化 (SEO)、浏览器行为和文档元信息方面起着重要作用。以下是一些常见的 <meta> 标签及其用途…

Opencv+python模板匹配

我们经常玩匹配图像或者找相似&#xff0c;opencv可以很好实现这个简单的小功能。 模板是被查找目标的图像&#xff0c;查找模板在原始图像中的哪个位置的过程就叫模板匹配。OpenCV提供的matchTemplate()方法就是模板匹配方法&#xff0c;其语法如下&#xff1a; result cv2.…

使用go语言来完成复杂excel表的导出导入

使用go语言来完成复杂excel表的导出导入&#xff08;一&#xff09; 1.复杂表的导入 开发需求是需要在功能页面上开发一个excel文件的导入导出功能&#xff0c;这里的复杂指定是表内数据夹杂着一对多&#xff0c;多对一的形式&#xff0c;如下图所示。数据杂乱而且对应不统一。…

中国90米分辨率可蚀性因子K数据

土壤可蚀性因子&#xff08;K&#xff09;数据&#xff0c;基于多种土壤属性数据计算&#xff0c;所用数据包括土壤黏粒含量&#xff08;%&#xff09;、粉粒含量&#xff08;%&#xff09;、砂粒含量&#xff08;%&#xff09;、土壤有机碳含量&#xff08;g/kg&#xff09;、…

[DALL·E 2] Hierarchical Text-Conditional Image Generation with CLIP Latents

1、目的 CLIP DDPM进行text-to-image生成 2、数据 (x, y)&#xff0c;x为图像&#xff0c;y为相应的captions&#xff1b;设定和为CLIP的image和text embeddings 3、方法 1&#xff09;CLIP 学习图像和文本的embedding&#xff1b;在训练prior和decoder时固定该部分参数 2&a…

开放式耳机什么牌子好一点?亲检的几款开放式蓝牙耳机推荐

不入耳的开放式耳机更好一些&#xff0c;不入耳式耳机佩戴更舒适&#xff0c;适合长时间佩戴&#xff0c;不会引起强烈的压迫感或耳部不适。不入耳式的设计不需要接触耳朵&#xff0c;比入耳式耳机更加卫生且不挑耳型&#xff0c;因此备受运动爱好者和音乐爱好者的喜爱。这里给…

MySQL中ALTER LOGFILE GROUP 语句详解

在 MySQL 的 InnoDB 存储引擎中&#xff0c;ALTER LOGFILE GROUP 语句用于修改重做日志组&#xff08;redo log group&#xff09;的配置。重做日志是 InnoDB 用来保证事务持久性的一个关键组件&#xff0c;它们用于在系统崩溃后恢复数据。 InnoDB 支持多个重做日志组&#xf…

周转车配料拣货方案

根据周转车安装的电子标签&#xff0c;被悬挂的扫码器扫到墨水屏显示的二维码&#xff0c;投屏发送配料拣货的数据。 方便快捷分拣物料

20240625(周二)欧美股市总结:标普纳指止步三日连跌,英伟达反弹6.8%,谷歌微软新高,油价跌1%

美联储理事鲍曼鹰派发声&#xff0c;若通胀没有持续改善将支持加息&#xff0c;加拿大5月CPI重新加速&#xff0c;对加拿大央行7月降息构成阻碍。美股走势分化&#xff0c;道指收跌近300点且六日里首跌&#xff0c;英伟达市值重上3.10万亿美元&#xff0c;芯片股指显著反弹1.8%…

想要用tween实现相机的移动,three.js渲染的canvas画布上相机位置一点没动,如何解决??

&#x1f3c6;本文收录于「Bug调优」专栏&#xff0c;主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案&#xff0c;希望能够助你一臂之力&#xff0c;帮你早日登顶实现财富自由&#x1f680;&#xff1b;同时&#xff0c;欢迎大家关注&&收藏&&…