机器学习之模型训练

前言

模型训练一般分为四个步骤:

  1. 构建数据集。
  2. 定义神经网络模型。
  3. 定义超参、损失函数及优化器。
  4. 输入数据集进行训练与评估。

有了数据集和模型后,可以进行模型的训练与评估。

构建数据集

定义神经网络模型

class Network(nn.Cell):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.dense_relu_sequential = nn.SequentialCell(nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))def construct(self, x):x = self.flatten(x)logits = self.dense_relu_sequential(x)return logitsmodel = Network()

从网络构建中加载代码,构建一个神经网络模型。

定义超参、损失函数和优化器

超参

超参数是可以调整的参数,可以控制深度学习模型训练优化的过程,包括训练轮次、批次大小和学习率等。这些超参数的取值会影响模型的训练和收敛速度,其中学习率在迭代过程中控制模型的学习进度。

损失函数

损失函数用于评估模型预测值和目标值之间的误差,帮助模型降低误差并提高预测准确性。常见的损失函数包括均方误差和负对数似然,用于回归和分类任务。nn.CrossEntropyLoss结合了多种损失函数的功能,对模型的预测结果进行归一化并计算误差。

优化器

模型优化是通过调整模型参数来减少模型误差的过程,MindSpore提供了多种优化算法的实现,称之为优化器。优化器内部定义了模型参数优化过程,所有优化逻辑都封装在优化器对象中。在这里,使用了SGD(随机梯度下降)优化器。

训练与评估

设置了超参、损失函数和优化器后,我们就可以循环输入数据来训练模型。一次数据集的完整迭代循环称为一轮(epoch)。每轮执行训练时包括两个步骤:训练和验证/测试。在训练阶段,模型通过迭代训练数据集来调整参数,以尝试收敛到最佳参数。而在验证/测试阶段,模型通过迭代测试数据集来评估模型的性能是否提升。这种流程的循环迭代可以帮助模型不断学习和优化,以达到更好的性能和准确度。

# Define forward function
def forward_fn(data, label):logits = model(data)loss = loss_fn(logits, label)return loss, logits# Get gradient function
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)# Define function of one-step training
def train_step(data, label):(loss, _), grads = grad_fn(data, label)optimizer(grads)return lossdef train_loop(model, dataset):size = dataset.get_dataset_size()model.set_train()for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):loss = train_step(data, label)if batch % 100 == 0:loss, current = loss.asnumpy(), batchprint(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")
def test_loop(model, dataset, loss_fn):num_batches = dataset.get_dataset_size()model.set_train(False)total, test_loss, correct = 0, 0, 0for data, label in dataset.create_tuple_iterator():pred = model(data)total += len(data)test_loss += loss_fn(pred, label).asnumpy()correct += (pred.argmax(1) == label).asnumpy().sum()test_loss /= num_batchescorrect /= totalprint(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
loss_fn = nn.CrossEntropyLoss()
optimizer = nn.SGD(model.trainable_params(), learning_rate=learning_rate)for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(model, train_dataset)test_loop(model, test_dataset, loss_fn)
print("Done!")

总结

模型训练一般包括构建数据集、定义神经网络模型、定义超参数、损失函数和优化器,以及输入数据集进行训练和评估。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/43173.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AtCoder Beginner Contest 360

A - A Healthy Breakfast 枚举一下&#xff0c;只要R在M之前就行了 #include <iostream>using namespace std;int main() {char a,b,c;cin >> a >> b >> c;if((a R && (b M || c M)) || (b R && c M)){cout << "Yes…

【trition-server】运行一个pytorch的ngc镜像

ngc 提供了pytorch容器 号称是做了gpu加速的 我装的系统版本是3.8的python,但是pytorch似乎是用conda安装的3.5的: torch的python库是ls支持gpu加速是真的 英伟达的pytorch的说明书 root@a79bc3874b9d:/opt/pytorch# cat NVREADME.md PyTorch ======= PyTorch is a python …

顶会FAST24最佳论文|阿里云块存储架构演进的得与失-1.引言

今年早些时候&#xff0c;2月份举办的全球计算机存储顶会USENIX FAST 2024&#xff0c;最佳论文来自阿里云&#xff0c;论文名称《What’s the Story in EBS Glory: Evolutions and Lessons in Building Cloud Block Store》 &#xff0c;论文详尽地探讨了阿里云在过去十年中开…

HTML(28)——空间转换

空间&#xff1a;是从坐标轴角度定义的XYZ三条坐标轴构成了一个立体空间 Z轴位置与视线方向相同 空间转换 平移 属性&#xff1a; transform: translate3d(x,y,z);transform: translateX();transform: translateY();transform: translateZ(); 取值&#xff1a;像素单位数值…

国内教育科技公司自研大语言模型

好未来的数学大模型九章大模型&#xff08;MathGPT&#xff09; 2023年8月下旬&#xff0c;在好未来20周年直播活动中&#xff0c;好未来公司CTO田密宣布好未来自研的数学领域千亿级大模型MathGPT正式上线并开启公测。根据九章大模型的官网介绍&#xff0c;九章大模型&#xff…

强化学习编程实战-2马尔可夫决策过程

2.1 从多臂赌博机到马尔可夫决策过程 如图2-1&#xff0c;图中A为多臂赌博机&#xff0c;B为一堆鸳鸯&#xff0c;其中左上角为雄性鸳鸯&#xff0c;右上角为雌性鸳鸯&#xff0c;B展示的任务是雄性鸳鸯绕过障碍物找到词性鸳鸯。跟多臂赌博机不同的是&#xff0c;雄性鸳鸯经过一…

019-GeoGebra中级篇-GeoGebra的坐标系

GeoGebra作为一款强大的数学软件&#xff0c;支持多种坐标系的使用&#xff0c;包括但不限于&#xff1a;笛卡尔坐标系&#xff08;Cartesian Coordinate System&#xff09;、极坐标系&#xff08;Polar Coordinate System&#xff09;、参数坐标系&#xff08;Parametric Coo…

虚拟机使用

1、安装 如何安装虚拟机&#xff1f;保姆级安装教程&#xff01; - 知乎 (zhihu.com) 2、使用 2.1 快照 作用&#xff1a;保留当前系统信息为快照&#xff0c;随时可以恢复&#xff0c;以防未来系统被你玩坏&#xff0c;就好比游戏中的归档&#xff01;每配置好一个就可以保…

Linux dig命令常见用法

Linux dig命令常见用法 一、dig安装二、dig用法 DIG命令(Domain Information Groper命令)是常用的域名查询工具&#xff0c;通过此命令&#xff0c;你可以实现域名查询和域名问题的定位&#xff0c;对于网络管理员和在域名系统(DNS)领域工作的小伙伴来说&#xff0c;它是一个非…

昇思MindSpore学习笔记6-01LLM原理和实践--FCN图像语义分割

摘要&#xff1a; 记录MindSpore AI框架使用FCN全卷积网络理解图像进行图像语议分割的过程、步骤和方法。包括环境准备、下载数据集、数据集加载和预处理、构建网络、训练准备、模型训练、模型评估、模型推理等。 一、概念 1.语义分割 图像语义分割 semantic segmentation …

【计算机毕业设计】018基于weixin小程序实习记录

&#x1f64a;作者简介&#xff1a;拥有多年开发工作经验&#xff0c;分享技术代码帮助学生学习&#xff0c;独立完成自己的项目或者毕业设计。 代码可以私聊博主获取。&#x1f339;赠送计算机毕业设计600个选题excel文件&#xff0c;帮助大学选题。赠送开题报告模板&#xff…

【Java系列】深入解析 Lambda表达式

简化这个代码 这个就是Lambda表达式,可以简化匿名内部类的写法 package lambda;public class demo2 {public static void main(String[] args) {//第二个参数是一个接口,所以我们在调用方法的时候,需要传递这个接口的实现类对象--接口多态// 但是这个实现类,我只要用一次,所以我…

@Builder注解详解:巧妙避开常见的陷阱

欢迎来到我的博客&#xff0c;代码的世界里&#xff0c;每一行都是一个故事 &#x1f38f;&#xff1a;你只管努力&#xff0c;剩下的交给时间 &#x1f3e0; &#xff1a;小破站 Builder注解详解&#xff1a;巧妙避开常见的陷阱 前言1. Builder的基本使用使用示例示例类创建对…

极客时间:使用Autogen Builder和本地LLM(Microsoft Phi3模型)在Mac上创建本地AI代理

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

运维系列.Nginx:自定义错误页面

运维系列 Nginx&#xff1a;自定义错误页面 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite&#xff1a;http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress of this article:https://blog.csdn.net/…

本地部署秘塔开源搜索引擎

秘塔AI搜索是由秘塔科技于2024年初推出的一款新型搜索引擎&#xff0c;被业界誉为“中国版的Perplexity”。秘塔科技成立于2018年4月&#xff0c;其核心团队包括CEO闵可锐、技术专家唐悦和首席运营官王益为等。秘塔AI搜索以其高效简洁的特点受到关注&#xff0c;其搜索结果直接…

LeetCode——第 405 场周赛

题目 找出加密后的字符串 给你一个字符串 s 和一个整数 k。请你使用以下算法加密字符串&#xff1a; 对于字符串 s 中的每个字符 c&#xff0c;用字符串中 c 后面的第 k 个字符替换 c&#xff08;以循环方式&#xff09;。 返回加密后的字符串。 示例 1&#xff1a; 输入&…

谷粒商城学习笔记-16-人人开源搭建后台管理系统

文章目录 一&#xff0c;克隆前/后端代码1&#xff0c;克隆前端工程renren-fast-value2&#xff0c;克隆后端工程renren-fast 二&#xff0c;集成后台管理系统的后端代码三&#xff0c;启动后台管理系统四&#xff0c;前端系统的安装和运行1&#xff0c;下载安装VSCode2&#x…

为什么KV Cache只需缓存K矩阵和V矩阵,无需缓存Q矩阵?

大家都知道大模型是通过语言序列预测下一个词的概率。假定{ x 1 x_1 x1​&#xff0c; x 2 x_2 x2​&#xff0c; x 3 x_3 x3​&#xff0c;…&#xff0c; x n − 1 x_{n-1} xn−1​}为已知序列&#xff0c;其中 x 1 x_1 x1​&#xff0c; x 2 x_2 x2​&#xff0c; x 3 x_3 x…

STM32对数码管显示的控制

1、在项目开发过程中会遇到STM32控制的数码管显示应用&#xff0c;这里以四位共阴极数码管显示控制为例讲解&#xff1b;这里采用的控制芯片为STM32F103RCT6。 2、首先要确定数码管的段选的8个引脚连接的单片机的引脚是哪8个&#xff0c;然后确认位选的4个引脚连接的单片机的4…