《昇思25天学习打卡营第8天|CarpeDiem》

《昇思25天学习打卡营第8天|CarpeDiem》

  • 模型训练
    • 构建数据集
    • 定义神经网络模型
    • 定义超参、损失函数和优化器
      • 超参
      • 损失函数
      • 优化器
    • 训练与评估

打卡

在这里插入图片描述

今天是昇思25天学习打卡营的第8天,终于迎来 模型训练 的部分了!!!

兴奋 发癫

模型训练

模型训练一般分为四个步骤:

  1. 构建数据集。
  2. 定义神经网络模型。
  3. 定义超参、损失函数及优化器。
  4. 输入数据集进行训练与评估。

现在我们有了数据集和模型后,可以进行模型的训练与评估。

评估的时候也可以采用第二天的方式 创建一个 loss_history 的数组 将每次的loss值计算出来存进去,然后借助 matplotlab 模块将loss数据可视化展示,这样可以更加直观的感受到模型训练的过程和结果的好坏

构建数据集

首先从数据集 Dataset加载代码,构建数据集。

老生常谈,没什么好说的了 不会就***给我去前面的篇章看

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset# Download data from open datasets
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)def datapipe(path, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = MnistDataset(path)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return datasettrain_dataset = datapipe('MNIST_Data/train', batch_size=64)
test_dataset = datapipe('MNIST_Data/test', batch_size=64)

在这里插入图片描述

定义神经网络模型

从网络构建中加载代码,构建一个神经网络模型。

不会就***给我去前面的篇章看

不好意思这个没单独讲过

class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsclass Network_new(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 1024),nn.ReLU(),nn.Dense(1024, 512),nn.ReLU(),nn.Dense(512, 128),nn.ReLU(),nn.Dense(128, 10),)def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()
model_new = Network_new()

这里的神经网络模型和第二天的神经网络模型是一模一样的

都是将 28*28 的图片先线性变换为 512 再线性变化成 512 再线性变化成 10 得到 10 个类别的特征值

Network_new 是创建的一个新的模型 测试一下多添加一些新的网络层后 其训练结果是怎么样变换的

采用 28*28 -> 512 -> 1024 -> 512 -> 128 -> 10 的形式

可以发现多添加了两层分别为 1024 和 128 个神经元

定义超参、损失函数和优化器

超参

超参(Hyperparameters)是可以调整的参数,可以控制模型训练优化的过程,不同的超参数值可能会影响模型训练和收敛速度。目前深度学习模型多采用批量随机梯度下降算法进行优化,随机梯度下降算法的原理如下:

w t + 1 = w t − η 1 n ∑ x ∈ B ∇ l ( x , w t ) w_{t+1}=w_{t}-\eta \frac{1}{n} \sum_{x \in \mathcal{B}} \nabla l\left(x, w_{t}\right) wt+1=wtηn1xBl(x,wt)

公式中, n n n是批量大小(batch size), η η η是学习率(learning rate)。另外, w t w_{t} wt为训练轮次 t t t中的权重参数, ∇ l \nabla l l为损失函数的导数。除了梯度本身,这两个因子直接决定了模型的权重更新,从优化本身来看,它们是影响模型性能收敛最重要的参数。一般会定义以下超参用于训练:

  • 训练轮次(epoch):训练时遍历数据集的次数。

  • 批次大小(batch size):数据集进行分批读取训练,设定每个批次数据的大小。batch size过小,花费时间多,同时梯度震荡严重,不利于收敛;batch size过大,不同batch的梯度方向没有任何变化,容易陷入局部极小值,因此需要选择合适的batch size,可以有效提高模型精度、全局收敛。

  • 学习率(learning rate):如果学习率偏小,会导致收敛的速度变慢,如果学习率偏大,则可能会导致训练不收敛等不可预测的结果。梯度下降法被广泛应用在最小化模型误差的参数优化算法上。梯度下降法通过多次迭代,并在每一步中最小化损失函数来预估模型的参数。学习率就是在迭代过程中,会控制模型的学习进度。

epochs = 3
batch_size = 64
learning_rate = 1e-2

损失函数

损失函数(loss function)用于评估模型的预测值(logits)和目标值(targets)之间的误差。训练模型时,随机初始化的神经网络模型开始时会预测出错误的结果。损失函数会评估预测结果与目标值的相异程度,模型训练的目标即为降低损失函数求得的误差。

常见的损失函数包括用于回归任务的nn.MSELoss(均方误差)和用于分类的nn.NLLLoss(负对数似然)等。 nn.CrossEntropyLoss 结合了nn.LogSoftmaxnn.NLLLoss,可以对logits 进行归一化并计算预测误差。

loss_fn = nn.CrossEntropyLoss()
loss_fn_new = nn.CrossEntropyLoss()

优化器

模型优化(Optimization)是在每个训练步骤中调整模型参数以减少模型误差的过程。MindSpore提供多种优化算法的实现,称之为优化器(Optimizer)。优化器内部定义了模型的参数优化过程(即梯度如何更新至模型参数),所有优化逻辑都封装在优化器对象中。在这里,我们使用SGD(Stochastic Gradient Descent)优化器。

我们通过model.trainable_params()方法获得模型的可训练参数,并传入学习率超参来初始化优化器。

optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)
optimizer_new = nn.SGD(model_new.trainable_params(), learning_rate=learning_rate)

在训练过程中,通过微分函数可计算获得参数对应的梯度,将其传入优化器中即可实现参数优化,具体形态如下:

grads = grad_fn(inputs)

optimizer(grads)

训练与评估

设置了超参、损失函数和优化器后,我们就可以循环输入数据来训练模型。一次数据集的完整迭代循环称为一轮(epoch)。每轮执行训练时包括两个步骤:

  1. 训练:迭代训练数据集,并尝试收敛到最佳参数。
  2. 验证/测试:迭代测试数据集,以检查模型性能是否提升。

接下来我们定义用于训练的train_loop函数和用于测试的test_loop函数。

使用函数式自动微分,需先定义正向函数forward_fn,使用value_and_grad获得微分函数grad_fn。然后,我们将微分函数和优化器的执行封装为train_step函数,接下来循环迭代数据集进行训练即可。

# Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logitsdef forward_fn_new(data, label):logits = model(data)loss = loss_fn_new(logits, label)return loss, logits# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
grad_fn_new = mindspore.value_and_grad(forward_fn_new, None, optimizer_new.parameters, has_aux=True)# Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train_step_new(data, label):(loss, _), grads = grad_fn_new(data, label)optimizer_new(grads)return loss# 相较第二天的新加入了一个参数 loss_history 使得可以记录历史 loss 值
def train_loop(model, dataset,loss_history):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)# 添加损失到 loss_historyloss_history.append(loss)if batch % 100 == 0:loss, current = loss.asnumpy(), batch# # 添加损失到 loss_history# loss_history.append(loss)print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

test_loop函数同样需循环遍历数据集,调用模型计算loss和Accuray并返回最终结果。

def test_loop(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

我们将实例化的损失函数和优化器传入train_looptest_loop中。训练3轮并输出loss和Accuracy,查看性能变化。

# use a loss_history array to save losses for visual display 
# 用一个 loss_history 数组去存储所有的损失值 loss 以便进行可视化展示
loss_history = []
loss_history_new = []loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)loss_fn_new = nn.CrossEntropyLoss()
optimizer_new = nn.SGD(model_new.trainable_params(), learning_rate=learning_rate)for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(model, train_dataset, loss_history)test_loop(model, test_dataset, loss_fn)
print("Done!")for t in range(epochs):print(f"Epoch_new {t+1}\n-------------------------------")train_loop(model_new, train_dataset, loss_history_new)test_loop(model_new, test_dataset, loss_fn_new)
print("Done_new!")

在这里插入图片描述

下面导入 matplotlib 模块来进行可视化展示

import matplotlib.pyplot as plt
plt.plot(loss_history)

在这里插入图片描述

plt.plot(loss_history_new)

在这里插入图片描述

就上面两个图来看,第一个训练的结果还是蛮不错的,第二个的loss抖动太多,虽然幅度不大(0.05–0.35)但是还是不如第一个,毕竟加了两层,白瞎了

所以,在训练模型的时候一定要记得化繁为简,多大规模的模型干多大规模的事情,要不然很容易出现譬如:欠拟合、过拟合等各种各样的问题 还得好好调教

这就是今天的全部内容了,别忘了点赞收藏加关注 别逼我求你!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/38316.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SSH远程命令执行漏洞(CVE-2024-6387)验证

0x01、漏洞名称 OpenSSH远程代码执行漏洞 (CVE-2024-6387) 0x02、漏洞简介 ​ OpenSSH是SSH(Secure SHell)协议的开源实现,它通过不安全的网络在两个不受信任的主机之间提供安全的加密通信。OpenSSH 广泛用于基于Un…

数据库。

数据库安全性 论述题5’ 编程题10’ sql语言实现权限控制 一、概述 1、不安全因素 (1)⾮授权对数据库的恶意存取和破坏 (2)数据库中重要的数据泄露 (3)安全环境的脆弱性 2、⾃主存取控制⽅法 gr…

【ajax实战06】进行文章发布

本文章目标:收集文章内容,并提交服务器保存 一:基于form-serialize插件收集表单数据 form-serialize插件仅能收集到表单数据,除此之外的数据无法收集到 二:基于axios提交到服务器保存 三:调用alert警告…

基于KMeans的航空公司客户数据聚类分析

💐大家好!我是码银~,欢迎关注💐: CSDN:码银 公众号:码银学编程 实验目的和要求 会用Python创建Kmeans聚类分析模型使用KMeans模型对航空公司客户价值进行聚类分析会对聚类结果进行分析评价 实…

Linux修炼之路之进程概念,fork函数,进程状态

目录 一:进程概念 二:Linux中的进程概念 三:用getpid(),getppid()获取该进程的PID,PPID 四:用fork()来创建子进程 五:操作系统学科的进程状态 六:Linux中的进程状态 接下来的日子会顺顺利利&#xf…

【区块链+基础设施】深证金融区块链平台 | FISCO BCOS应用案例

作为数据交换密集型行业,资本市场是区块链创新应用的重要领域,区块链技术可以有效解决诸多痛点问题。比 如,针对信息不对称的问题,区块链技术通过将整个企业的经营活动信息上链,有效降低尽调成本,为投融资决…

配置windows环境下独立浏览器爬虫方案【不依赖系统环境与chrome】

引言 由于部署浏览器爬虫的机器浏览器版本不同,同时也不想因为部署了爬虫导致影响系统浏览器数据,以及避免爬虫过程中遇到的chrome与webdriver版本冲突。我决定将特定版本的chrome浏览器与webdriver下载到项目目录内,同时chrome_driver在初始…

我使用 GPT-4o 帮我挑西瓜

在 5 月 15 日,OpenAI 旗下的大模型 GPT-4o 已经发布,那时网络上已经传开, 但很多小伙伴始终没有看到 GPT-4o 的体验选项。 在周五的时候,我组建的 ChatGPT 交流群的伙伴已经发现了 GPT-4o 这个选项了,是在没有充值升…

NSSCTF-Web题目21(文件上传-phar协议、RCE-空格绕过)

目录 [NISACTF 2022]bingdundun~ 1、题目 2、知识点 3、思路 [FSCTF 2023]细狗2.0 4、题目 5、知识点 6、思路 [NISACTF 2022]bingdundun~ 1、题目 2、知识点 文件上传,phar伪协议 3、思路 点击upload,看看 这里提示我们可以上传图片或压缩包&…

应对.Kastaneya勒索病毒:保护您的数据安全

导言: 随着科技的发展,网络安全问题也日益严峻。最近,一种名为.Kastaneya的勒索病毒开始在网络上出现,对用户的计算机和数据造成严重威胁。本文91数据恢复将介绍.Kastaneya勒索病毒的特点及其传播方式,并提供一些有效…

Unity 解包工具(AssetStudio/UtinyRipper)

文章目录 1.UtinyRipper2.AssetStudio 1.UtinyRipper 官方地址: https://github.com/mafaca/UtinyRipper/ 下载步骤: 2.AssetStudio 官方地址: https://github.com/Perfare/AssetStudio 下载步骤:

【HarmonyOS NEXT】鸿蒙多线程Sendable开发

非共享模块在同一线程内只加载一次,在不同线程间会加载多次,单例类也会创建多次,导致数据不共享,在不同的线程内都会产生新的模块对象。 基础概念 Sendable协议 Sendable协议定义了ArkTS的可共享对象体系及其规格约束。符合Sen…

STM32mp157aaa按键中断实验

效果图&#xff1a; 源码&#xff1a; #include "key.h" void hal_key1_rcc_gpio_init() {// 使能GPIOF组RCC->MP_AHB4ENSETR | (0x1 << 5);// 设置引脚位输入模式GPIOF->MODER & (~(0X3 << 18));GPIOF->MODER & (~(0X3 << 16))…

[C++11] 退出清理函数(quick_exit at_quick_exit)

说明&#xff1a;在C11中&#xff0c;quick_exit和at_quick_exit是新增的快速退出功能&#xff0c;用于在程序终止时提供一种快速清理资源的方式。 quick_exit std::quick_exit函数允许程序快速退出&#xff0c;并且可以传递一个退出状态码给操作系统。与std::exit相比&#…

[今日一水]论坛该如何选择

想要搭建一个论坛其实选择是很多的&#xff0c;就比如国内的dz&#xff0c;国外的xenforo和flarum&#xff0c;具体还是根据的面向的用户和需求来&#xff0c;就比如flarum它的界面肯定是三个论坛里最现代化的&#xff0c;但是xenforo社区生态很强&#xff0c;而dz对于国内用户…

VMware创建新虚拟机教程(保姆级别)

&#x1f4e2; 续上一篇 最新超详细VMware虚拟机安装完整教程-CSDN博客 &#xff0c;本章将详细讲解VMware创建虚拟机。 一、创建新的虚拟机 点击【创建新的虚拟机】&#xff01; 点击【自定义&#xff08;高级&#xff09;】> 下一步&#xff01; > 默认下一步&#x…

耐克:老大的烦恼

股价暴跌20%&#xff0c;老大最近比较烦。 今天说说全球&#xff08;最&#xff09;大运动品牌——耐克。 最近耐克发布2023-2024财年业绩&#xff08;截止于2024.5.31&#xff09;&#xff0c;还是爆赚几百亿美元&#xff0c;还是行业第一&#xff0c;但业绩不及预期&#xf…

Redis为什么设计多个数据库

​关于Redis的知识前面已经介绍过很多了,但有个点没有讲,那就是一个Redis的实例并不是只有一个数据库,一般情况下,默认是Databases 0。 一 内部结构 设计如下: Redis 的源码中定义了 redisDb 结构体来表示单个数据库。这个结构有若干重要字段,比如: dict:该字段存储了…

backbone是什么?

在深度学习中&#xff0c;特别是计算机视觉领域&#xff0c;"backbone"&#xff08;骨干网络&#xff09;是指用于提取特征的基础网络。它通常是卷积神经网络&#xff08;CNN&#xff09;&#xff0c;其任务是从输入图像中提取高层次特征&#xff0c;这些特征然后被用…

【第12章】MyBatis-Plus条件构造器(下)

文章目录 前言一、使用 TypeHandler二、使用提示三、Wrappers四、线程安全性五、使用 Wrapper 自定义 SQL1.注意事项2.示例3. 使用方法 总结 前言 本章继续上章条件构造器相关内容。 一、使用 TypeHandler 在 wrapper 中使用 typeHandler 需要特殊处理利用 formatSqlMaybeWit…