《昇思25天学习打卡营第8天|CarpeDiem》

《昇思25天学习打卡营第8天|CarpeDiem》

  • 模型训练
    • 构建数据集
    • 定义神经网络模型
    • 定义超参、损失函数和优化器
      • 超参
      • 损失函数
      • 优化器
    • 训练与评估

打卡

在这里插入图片描述

今天是昇思25天学习打卡营的第8天,终于迎来 模型训练 的部分了!!!

兴奋 发癫

模型训练

模型训练一般分为四个步骤:

  1. 构建数据集。
  2. 定义神经网络模型。
  3. 定义超参、损失函数及优化器。
  4. 输入数据集进行训练与评估。

现在我们有了数据集和模型后,可以进行模型的训练与评估。

评估的时候也可以采用第二天的方式 创建一个 loss_history 的数组 将每次的loss值计算出来存进去,然后借助 matplotlab 模块将loss数据可视化展示,这样可以更加直观的感受到模型训练的过程和结果的好坏

构建数据集

首先从数据集 Dataset加载代码,构建数据集。

老生常谈,没什么好说的了 不会就***给我去前面的篇章看

import mindspore
from mindspore import nn
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset# Download data from open datasets
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)def datapipe(path, batch_size):image_transforms = [vision.Rescale(1.0 / 255.0, 0),vision.Normalize(mean=(0.1307,), std=(0.3081,)),vision.HWC2CHW()]label_transform = transforms.TypeCast(mindspore.int32)dataset = MnistDataset(path)dataset = dataset.map(image_transforms, 'image')dataset = dataset.map(label_transform, 'label')dataset = dataset.batch(batch_size)return datasettrain_dataset = datapipe('MNIST_Data/train', batch_size=64)
test_dataset = datapipe('MNIST_Data/test', batch_size=64)

在这里插入图片描述

定义神经网络模型

从网络构建中加载代码,构建一个神经网络模型。

不会就***给我去前面的篇章看

不好意思这个没单独讲过

class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsclass Network_new(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 1024),nn.ReLU(),nn.Dense(1024, 512),nn.ReLU(),nn.Dense(512, 128),nn.ReLU(),nn.Dense(128, 10),)def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()
model_new = Network_new()

这里的神经网络模型和第二天的神经网络模型是一模一样的

都是将 28*28 的图片先线性变换为 512 再线性变化成 512 再线性变化成 10 得到 10 个类别的特征值

Network_new 是创建的一个新的模型 测试一下多添加一些新的网络层后 其训练结果是怎么样变换的

采用 28*28 -> 512 -> 1024 -> 512 -> 128 -> 10 的形式

可以发现多添加了两层分别为 1024 和 128 个神经元

定义超参、损失函数和优化器

超参

超参(Hyperparameters)是可以调整的参数,可以控制模型训练优化的过程,不同的超参数值可能会影响模型训练和收敛速度。目前深度学习模型多采用批量随机梯度下降算法进行优化,随机梯度下降算法的原理如下:

w t + 1 = w t − η 1 n ∑ x ∈ B ∇ l ( x , w t ) w_{t+1}=w_{t}-\eta \frac{1}{n} \sum_{x \in \mathcal{B}} \nabla l\left(x, w_{t}\right) wt+1=wtηn1xBl(x,wt)

公式中, n n n是批量大小(batch size), η η η是学习率(learning rate)。另外, w t w_{t} wt为训练轮次 t t t中的权重参数, ∇ l \nabla l l为损失函数的导数。除了梯度本身,这两个因子直接决定了模型的权重更新,从优化本身来看,它们是影响模型性能收敛最重要的参数。一般会定义以下超参用于训练:

  • 训练轮次(epoch):训练时遍历数据集的次数。

  • 批次大小(batch size):数据集进行分批读取训练,设定每个批次数据的大小。batch size过小,花费时间多,同时梯度震荡严重,不利于收敛;batch size过大,不同batch的梯度方向没有任何变化,容易陷入局部极小值,因此需要选择合适的batch size,可以有效提高模型精度、全局收敛。

  • 学习率(learning rate):如果学习率偏小,会导致收敛的速度变慢,如果学习率偏大,则可能会导致训练不收敛等不可预测的结果。梯度下降法被广泛应用在最小化模型误差的参数优化算法上。梯度下降法通过多次迭代,并在每一步中最小化损失函数来预估模型的参数。学习率就是在迭代过程中,会控制模型的学习进度。

epochs = 3
batch_size = 64
learning_rate = 1e-2

损失函数

损失函数(loss function)用于评估模型的预测值(logits)和目标值(targets)之间的误差。训练模型时,随机初始化的神经网络模型开始时会预测出错误的结果。损失函数会评估预测结果与目标值的相异程度,模型训练的目标即为降低损失函数求得的误差。

常见的损失函数包括用于回归任务的nn.MSELoss(均方误差)和用于分类的nn.NLLLoss(负对数似然)等。 nn.CrossEntropyLoss 结合了nn.LogSoftmaxnn.NLLLoss,可以对logits 进行归一化并计算预测误差。

loss_fn = nn.CrossEntropyLoss()
loss_fn_new = nn.CrossEntropyLoss()

优化器

模型优化(Optimization)是在每个训练步骤中调整模型参数以减少模型误差的过程。MindSpore提供多种优化算法的实现,称之为优化器(Optimizer)。优化器内部定义了模型的参数优化过程(即梯度如何更新至模型参数),所有优化逻辑都封装在优化器对象中。在这里,我们使用SGD(Stochastic Gradient Descent)优化器。

我们通过model.trainable_params()方法获得模型的可训练参数,并传入学习率超参来初始化优化器。

optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)
optimizer_new = nn.SGD(model_new.trainable_params(), learning_rate=learning_rate)

在训练过程中,通过微分函数可计算获得参数对应的梯度,将其传入优化器中即可实现参数优化,具体形态如下:

grads = grad_fn(inputs)

optimizer(grads)

训练与评估

设置了超参、损失函数和优化器后,我们就可以循环输入数据来训练模型。一次数据集的完整迭代循环称为一轮(epoch)。每轮执行训练时包括两个步骤:

  1. 训练:迭代训练数据集,并尝试收敛到最佳参数。
  2. 验证/测试:迭代测试数据集,以检查模型性能是否提升。

接下来我们定义用于训练的train_loop函数和用于测试的test_loop函数。

使用函数式自动微分,需先定义正向函数forward_fn,使用value_and_grad获得微分函数grad_fn。然后,我们将微分函数和优化器的执行封装为train_step函数,接下来循环迭代数据集进行训练即可。

# Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logitsdef forward_fn_new(data, label):logits = model(data)loss = loss_fn_new(logits, label)return loss, logits# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
grad_fn_new = mindspore.value_and_grad(forward_fn_new, None, optimizer_new.parameters, has_aux=True)# Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train_step_new(data, label):(loss, _), grads = grad_fn_new(data, label)optimizer_new(grads)return loss# 相较第二天的新加入了一个参数 loss_history 使得可以记录历史 loss 值
def train_loop(model, dataset,loss_history):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)# 添加损失到 loss_historyloss_history.append(loss)if batch % 100 == 0:loss, current = loss.asnumpy(), batch# # 添加损失到 loss_history# loss_history.append(loss)print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

test_loop函数同样需循环遍历数据集,调用模型计算loss和Accuray并返回最终结果。

def test_loop(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

我们将实例化的损失函数和优化器传入train_looptest_loop中。训练3轮并输出loss和Accuracy,查看性能变化。

# use a loss_history array to save losses for visual display 
# 用一个 loss_history 数组去存储所有的损失值 loss 以便进行可视化展示
loss_history = []
loss_history_new = []loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)loss_fn_new = nn.CrossEntropyLoss()
optimizer_new = nn.SGD(model_new.trainable_params(), learning_rate=learning_rate)for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(model, train_dataset, loss_history)test_loop(model, test_dataset, loss_fn)
print("Done!")for t in range(epochs):print(f"Epoch_new {t+1}\n-------------------------------")train_loop(model_new, train_dataset, loss_history_new)test_loop(model_new, test_dataset, loss_fn_new)
print("Done_new!")

在这里插入图片描述

下面导入 matplotlib 模块来进行可视化展示

import matplotlib.pyplot as plt
plt.plot(loss_history)

在这里插入图片描述

plt.plot(loss_history_new)

在这里插入图片描述

就上面两个图来看,第一个训练的结果还是蛮不错的,第二个的loss抖动太多,虽然幅度不大(0.05–0.35)但是还是不如第一个,毕竟加了两层,白瞎了

所以,在训练模型的时候一定要记得化繁为简,多大规模的模型干多大规模的事情,要不然很容易出现譬如:欠拟合、过拟合等各种各样的问题 还得好好调教

这就是今天的全部内容了,别忘了点赞收藏加关注 别逼我求你!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/38316.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数据库。

数据库安全性 论述题5’ 编程题10’ sql语言实现权限控制 一、概述 1、不安全因素 (1)⾮授权对数据库的恶意存取和破坏 (2)数据库中重要的数据泄露 (3)安全环境的脆弱性 2、⾃主存取控制⽅法 gr…

基于KMeans的航空公司客户数据聚类分析

💐大家好!我是码银~,欢迎关注💐: CSDN:码银 公众号:码银学编程 实验目的和要求 会用Python创建Kmeans聚类分析模型使用KMeans模型对航空公司客户价值进行聚类分析会对聚类结果进行分析评价 实…

Linux修炼之路之进程概念,fork函数,进程状态

目录 一:进程概念 二:Linux中的进程概念 三:用getpid(),getppid()获取该进程的PID,PPID 四:用fork()来创建子进程 五:操作系统学科的进程状态 六:Linux中的进程状态 接下来的日子会顺顺利利&#xf…

配置windows环境下独立浏览器爬虫方案【不依赖系统环境与chrome】

引言 由于部署浏览器爬虫的机器浏览器版本不同,同时也不想因为部署了爬虫导致影响系统浏览器数据,以及避免爬虫过程中遇到的chrome与webdriver版本冲突。我决定将特定版本的chrome浏览器与webdriver下载到项目目录内,同时chrome_driver在初始…

我使用 GPT-4o 帮我挑西瓜

在 5 月 15 日,OpenAI 旗下的大模型 GPT-4o 已经发布,那时网络上已经传开, 但很多小伙伴始终没有看到 GPT-4o 的体验选项。 在周五的时候,我组建的 ChatGPT 交流群的伙伴已经发现了 GPT-4o 这个选项了,是在没有充值升…

NSSCTF-Web题目21(文件上传-phar协议、RCE-空格绕过)

目录 [NISACTF 2022]bingdundun~ 1、题目 2、知识点 3、思路 [FSCTF 2023]细狗2.0 4、题目 5、知识点 6、思路 [NISACTF 2022]bingdundun~ 1、题目 2、知识点 文件上传,phar伪协议 3、思路 点击upload,看看 这里提示我们可以上传图片或压缩包&…

Unity 解包工具(AssetStudio/UtinyRipper)

文章目录 1.UtinyRipper2.AssetStudio 1.UtinyRipper 官方地址: https://github.com/mafaca/UtinyRipper/ 下载步骤: 2.AssetStudio 官方地址: https://github.com/Perfare/AssetStudio 下载步骤:

STM32mp157aaa按键中断实验

效果图&#xff1a; 源码&#xff1a; #include "key.h" void hal_key1_rcc_gpio_init() {// 使能GPIOF组RCC->MP_AHB4ENSETR | (0x1 << 5);// 设置引脚位输入模式GPIOF->MODER & (~(0X3 << 18));GPIOF->MODER & (~(0X3 << 16))…

VMware创建新虚拟机教程(保姆级别)

&#x1f4e2; 续上一篇 最新超详细VMware虚拟机安装完整教程-CSDN博客 &#xff0c;本章将详细讲解VMware创建虚拟机。 一、创建新的虚拟机 点击【创建新的虚拟机】&#xff01; 点击【自定义&#xff08;高级&#xff09;】> 下一步&#xff01; > 默认下一步&#x…

耐克:老大的烦恼

股价暴跌20%&#xff0c;老大最近比较烦。 今天说说全球&#xff08;最&#xff09;大运动品牌——耐克。 最近耐克发布2023-2024财年业绩&#xff08;截止于2024.5.31&#xff09;&#xff0c;还是爆赚几百亿美元&#xff0c;还是行业第一&#xff0c;但业绩不及预期&#xf…

Redis为什么设计多个数据库

​关于Redis的知识前面已经介绍过很多了,但有个点没有讲,那就是一个Redis的实例并不是只有一个数据库,一般情况下,默认是Databases 0。 一 内部结构 设计如下: Redis 的源码中定义了 redisDb 结构体来表示单个数据库。这个结构有若干重要字段,比如: dict:该字段存储了…

scikit-learn教程

scikit-learn&#xff08;通常简称为sklearn&#xff09;是Python中最受欢迎的机器学习库之一&#xff0c;它提供了各种监督和非监督学习算法的实现。下面是一个基本的教程&#xff0c;涵盖如何使用sklearn进行数据预处理、模型训练和评估。 1. 安装和导入包 首先确保安装了…

【漏洞复现】D-Link NAS 未授权RCE漏洞(CVE-2024-3273)

0x01 产品简介 D-Link 网络存储 (NAS)是中国友讯&#xff08;D-link&#xff09;公司的一款统一服务路由器。 0x02 漏洞概述 D-Link NAS nas_sharing.cgi接口存在命令执行漏洞&#xff0c;该漏洞存在于“/cgi-bin/nas_sharing.cgi”脚本中&#xff0c;影响其 HTTP GET 请求处…

Java 汉诺塔问题 详细分析

汉诺塔 汉诺塔&#xff08;Tower of Hanoi&#xff09;&#xff0c;又称河内塔&#xff0c;是一个源于印度古老传说的益智玩具。大梵天创造世界的时候做了三根金刚石柱子&#xff0c;在一根柱子上从下往上按照大小顺序摞着64片黄金圆盘。大梵天命令婆罗门把圆盘从下面开始按大小…

vulnhub靶场ai-web 2.0

1 信息收集 1.1 主机发现 arp-scan -l 主机地址为192.168.1.4 1.2 服务端口扫描 nmap -sS -sV -A -T5 -p- 192.168.1.4 开放22&#xff0c;80端口 2 访问服务 2.1 80端口访问 http://192.168.1.4:80/ 先尝试admin等其他常见用户名登录无果 然后点击signup发现这是一个注…

prescan软件中导入路径文件txt/lpx

由于博主收到的是lpx格式的路径文件&#xff0c;因此&#xff0c;第一步 1.记事本打开 ctrla 全选 ctrlc 复制 2.新建一个excel 鼠标定位到第一行第一列的格子 ctrlv 复制 3.数据栏“分列”功能 4. (0.1递增的数列&#xff0c;纬度&#xff0c;经度&#xff0c;高程) 导入…

JDBC操作流程

目录 简介 具体操作 1. 引入驱动包 1&#xff09;下载驱动包 2&#xff09;引入驱动包到项目中 2. 编写代码 1&#xff09;创建数据源 2&#xff09;建立连接 3&#xff09;构造 SQL 语句 4&#xff09;执行 SQL 语句 5&#xff09;释放资源 总结 简介 JDBC 就是使…

某网页gpt的JS逆向

原网页网址 (base64) 在线解码 aHR0cHM6Ly9jbGF1ZGUzLmZyZWUyZ3B0Lnh5ei8 逆向效果图 调用代码&#xff08;复制即用&#xff09; 把倒数第三行换成下面的base64解码 aHR0cHM6Ly9jbGF1ZGUzLmZyZWUyZ3B0Lnh5ei9hcGkvZ2VuZXJhdGU import hashlib import time import reques…

C语言+ MSSQL技术开发的 PACS系统源码:CT后处理技术之仿真内镜CTVE

C语言 MSSQL技术开发的 PACS系统源码&#xff1a;CT后处理技术之仿真内镜CTVE 仿真内窥镜VE VE是利用医学影像作为原始数据&#xff0c;融合图像处理、计算机图形学、科学计算可视化、虚拟现实技术&#xff0c;模拟传统光学内镜的一种技术。 又叫做腔内重建技术&#xff0c;是…

试用笔记之-汇通来电显示软件

首先汇通来电显示软件下载 http://www.htsoft.com.cn/download/httelephone.rar