Lnton羚通关于Optimization在【PyTorch】中的基础知识

OPTIMIZING MODEL PARAMETERS (模型参数优化)
现在我们有了模型和数据,是时候通过优化数据上的参数来训练了,验证和测试我们的模型。训练一个模型是一个迭代的过程,在每次迭代中,模型会对输出进行猜测,计算猜测数据与真实数据的误差(损失),收集误差对其参数的导数(正如前一节我们看到的那样),并使用梯度下降优化这些参数。

Prerequisite Code ( 先决代码 )
We load the code from the previous sections on

import torch 
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transformstraining_data = datasets.FashionMNIST(root = "../../data/",train = True,download = True, transform = transforms.ToTensor()
)test_data = datasets.FashionMNIST(root = "../../data/",train = False,download = True, transform = transforms.ToTensor()
)train_dataloader = DataLoader(training_data, batch_size = 32, shuffle = True)
test_dataloader = DataLoader(test_data, batch_size = 32, shuffle = True)class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10)  )def forward(self, x):out = self.flatten(x)out = self.linear_relu_stack(out)return outmodel = NeuralNetwork()

Hyperparameters ( 超参数 )
超参数是可调节的参数,允许控制模型优化过程,不同的超参数会影响模型的训练和收敛速度。read more

我们定义如下的超参数进行训练:

Number of Epochs: 遍历数据集的次数
Batch Size: 每一次使用的数据集大小,即每一次用于训练的样本数量
Learning Rate: 每个 batch/epoch 更新模型参数的速度,较小的值会导致较慢的学习速度,而较大的值可能会导致训练过程中不可预测的行为,例如训练抖动频繁,有可能会发散等。

learning_rate = 1e-3
batch_size = 32
epochs = 5

Optimization Loop ( 优化循环 )
我们设置完超参数后,就可以利用优化循环训练和优化模型;优化循环的每次迭代称为一个 epoch, 每个 epoch 包含两个主要部分:

The Train Loop: 遍历训练数据集并尝试收敛到最优参数。
The Validation/Test Loop: 验证/测试循环—遍历测试数据集以检查模型性能是否得到改善。
让我们简单地熟悉一下训练循环中使用的一些概念。跳转到前面以查看优化循环的完整实现。

Loss Function ( 损失函数 )
当给出一些训练数据时,我们未经训练的网络可能不会给出正确的答案。 Loss function 衡量的是得到的结果与目标值的不相似程度,是我们在训练过程中想要最小化的 Loss function。为了计算 loss ,我们使用给定数据样本的输入进行预测,并将其与真实的数据标签值进行比较。

常见的损失函数包括nn.MSELoss (均方误差)用于回归任务,nn.NLLLoss(负对数似然)用于分类神经网络。nn.CrossEntropyLoss 结合 nn.LogSoftmax 和 nn.NLLLoss 。

我们将模型的输出 logits 传递给 nn.CrossEntropyLoss ,它将规范化 logits 并计算预测误差。

# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()

Optimizer ( 优化器 )
优化是在每个训练步骤中调整模型参数以减少模型误差的过程。优化算法定义了如何执行这个过程(在这个例子中,我们使用随机梯度下降)。所有优化逻辑都封装在优化器对象中。这里,我们使用 SGD 优化器; 此外,PyTorch 中还有许多不同的优化器,如 ADAM 和 RMSProp ,它们可以更好地用于不同类型的模型和数据。

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

在训练的循环中,优化分为3个步骤:

调用 optimizer.zero_grad() 重置模型参数的梯度,默认情况下,梯度是累加的。为了防止重复计算,我们在每次迭代中显式将他们归零。
通过调用 loss.backward() 反向传播预测损失, PyTorch 保存每个参数的损失梯度。
一旦我们有了梯度,我们调用 optimizer.step() 在向后传递中收集梯度调整参数。
Full Implementation (完整实现)
我们定义了遍历优化参数代码的 train loop, 以及根据测试数据定义了test loop。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms## 数据集
training_data = datasets.FashionMNIST(root="../../data/",train=True,download=True,transform=transforms.ToTensor()
)test_data = datasets.FashionMNIST(root="../../data/",train=False,download=True,transform=transforms.ToTensor()
)## dataloader
train_dataloader = DataLoader(training_data, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=32, shuffle=True)## 定义神经网络
class NeuralNetwork(nn.Module):def __init__(self):super(NeuralNetwork, self).__init__()self.flatten = nn.Flatten()self.linear_relu_stack = nn.Sequential(nn.Linear(28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10))def forward(self, x):out = self.flatten(x)out = self.linear_relu_stack(out)return out## 实例化模型
model = NeuralNetwork()## 损失函数
loss_fn = nn.CrossEntropyLoss()## 优化器
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)## 超参数
learning_rate = 1e-3
batch_size = 32
epochs = 5## 训练循环
def train_loop(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)for batch, (X, y) in enumerate(dataloader):# 计算预测和损失pred = model(X)loss = loss_fn(pred, y)## 反向传播optimizer.zero_grad()loss.backward()optimizer.step()if batch % 100 == 0:loss, current = loss.item(), batch * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")## 测试循环
def test_loop(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")## 训练网络
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train_loop(train_dataloader, model, loss_fn, optimizer)test_loop(test_dataloader, model, loss_fn)
print("Done!")

Lnton羚通专注于音视频算法、算力、云平台的高科技人工智能企业。 公司基于视频分析技术、视频智能传输技术、远程监测技术以及智能语音融合技术等, 拥有多款可支持ONVIF、RTSP、GB/T28181等多协议、多路数的音视频智能分析服务器/云平台。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/43027.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python3 0基础学习----数据结构(基础+练习)

python 0基础学习笔记之数据结构 📚 几种常见数据结构列表 (List)1. 定义2. 实例:3. 列表中常用方法.append(要添加内容) 向列表末尾添加数据.extend(列表) 将可迭代对象逐个添加到列表中.insert(索引,插入内容) 向指定…

国家一带一路和万众创业创新的方针政策指引下,Live Market探索跨境产业的创新发展

现代社会,全球经济互联互通,跨境产业也因此而崛起。为了推动跨境产业的创新发展,中国政府提出了“一带一路”和“万众创业、万众创新”的方针政策,旨在促进全球经济的互联互通和创新发展。在这个大环境下,Live Market积…

Mariadb高可用MHA

本节主要学习了Mariadb高可用MHA的概述,案例如何构建MHA 提示:以下是本篇文章正文内容,下面案例可供参考 一、概述 1、概念 MHA(MasterHigh Availability)是一套优秀的MySQL高可用环境下故障切换和主从复制的软件。…

合宙Air724UG LuatOS-Air LVGL API--简介

为何是 LVGL LVGL 是一个开源的图形库,它提供了创建嵌入式 GUI 所需的一切,具有易于使用的图形元素、漂亮的视觉效果和低内存占用的特点。 LVGL特点: 强大的 控件 :按钮、图表、列表、滑动条、图像等 高级图形引擎:动…

BIO、NIO和AIO

一.引言 何为IO 涉及计算机核心(CPU和内存)与其他设备间数据迁移的过程,就是I/O。数据输入到计算机内存的过程即输入,反之输出到外部存储(比如数据库,文件,远程主机)的过程即输出。 I/O 描述了计算机系统…

插入排序优化——超越归并排序的超级算法

插入排序及优化 插入排序算法算法讲解数据模拟代码 优化思路一、二分查找二、copy函数 优化后代码算法的用途题目:数星星(POJ2352 star)输入输出格式输入格式:输出格式 输入输出样例输入样例输出样例 题目讲解步骤如下AC 代码 插入…

Linux系统中基于NGINX的代理缓存配置指南

作为一名专业的爬虫程序员,你一定知道代理缓存在加速网站响应速度方面的重要性。而使用NGINX作为代理缓存服务器,能够极大地提高性能和效率。本文将为你分享Linux系统中基于NGINX的代理缓存配置指南,提供实用的解决方案,助你解决在…

C语言刷题训练DAY.8

1.计算单位阶跃函数 解题思路&#xff1a; 这个非常简单&#xff0c;只需要if else语句即可完成 解题代码&#xff1a; #include <stdio.h>int main() {int t 0;while(scanf("%d",&t)!EOF){if (t > 0)printf("1\n");else if (t < 0)pr…

大模型基础02:GPT家族与提示学习

大模型基础&#xff1a;GPT 家族与提示学习 从 GPT-1 到 GPT-3.5 GPT(Generative Pre-trained Transformer)是 Google 于2018年提出的一种基于 Transformer 的预训练语言模型。它标志着自然语言处理领域从 RNN 时代进入 Transformer 时代。GPT 的发展历史和技术特点如下: GP…

【校招VIP】java语言类和对象之map、set集合

考点介绍&#xff1a; map、set集合相关内容是校招面试的高频考点之一。 map和set是一种专门用来进行搜索的容器或者数据结构&#xff0c;其搜索效率与其具体的实例化子类有关系。 『java语言类和对象之map、set集合』相关题目及解析内容可点击文章末尾链接查看&#xff01; …

深入了解Maven(一)

目录 一.Maven介绍与功能 二.依赖管理 1.依赖的配置 2.依赖的传递性 3.排除依赖 4.依赖的作用范围 5.依赖的生命周期 一.Maven介绍与功能 maven是一个项目管理和构建工具&#xff0c;是基于对象模型POM实现。 Maven的作用&#xff1a; 便捷的依赖管理&#xff1a;使用…

【java安全】Log4j反序列化漏洞

文章目录 【java安全】Log4j反序列化漏洞关于Apache Log4j漏洞成因CVE-2017-5645漏洞版本复现环境漏洞复现漏洞分析 CVE-2019-17571漏洞版本漏洞复现漏洞分析 参考 【java安全】Log4j反序列化漏洞 关于Apache Log4j Log4j是Apache的开源项目&#xff0c;可以实现对System.out…

前端性能优化——包体积压缩插件,打包速度提升插件,提升浏览器响应的速率模式

前端代码优化 –其他的优化可以具体在网上搜索 压缩项目打包后的体积大小、提升打包速度&#xff0c;是前端性能优化中非常重要的环节&#xff0c;结合工作中的实践总结&#xff0c;梳理出一些 常规且有效 的性能优化建议 ue 项目可以通过添加–report命令&#xff1a; "…

innodb索引与算法

B树主键插入 B树在innodb的插入有三种模式page_last_insert, page_dirction, page_N_direction 而在bustub里面的B树就是page_N_direction,如果是自增主键的话&#xff0c;就是上面这样的插入法 FIC优化 (DDL) 选择性统计 覆盖索引 MMR ICP优化 自适应hash 全文索引 MySQL…

List和ObservableCollection和ListBinding在MVVM模式下的对比

List和ObservableCollection和ListBinding在MVVM模式下的对比 List 当对List进行增删操作后&#xff0c;并不会对View进行通知。 //Employee public class Employee : INotifyPropertyChanged {public event PropertyChangedEventHandler? PropertyChanged;public string N…

Vue-13.创建完整的Vue项目(vue+vue-cli+js)

前言 之前写了命令创建Vue项目&#xff0c;但是事实上我们可以直接用编译器直接创建项目&#xff0c;这里我使用webstorm&#xff08;因为我是前后端兼修的所以我习惯使用Idea家族的编译器&#xff09; 只写前端的推荐用VsCode前后端都写的推荐用webstorm 新建项目 项目初始…

vscode 安装勾选项解释

1、通过code 打开“操作添加到windows资源管理器文件上下文菜单 &#xff1a;把这个两个勾选上&#xff0c;可以对文件使用鼠标右键&#xff0c;选择VSCode 打开。 2、将code注册为受支持的文件类型的编辑器&#xff1a;不建议勾选&#xff0c;这样会默认使用VSCode打开支持的相…

《Linux从练气到飞升》No.15 Linux 环境变量

&#x1f57a;作者&#xff1a; 主页 我的专栏C语言从0到1探秘C数据结构从0到1探秘Linux菜鸟刷题集 &#x1f618;欢迎关注&#xff1a;&#x1f44d;点赞&#x1f64c;收藏✍️留言 &#x1f3c7;码字不易&#xff0c;你的&#x1f44d;点赞&#x1f64c;收藏❤️关注对我真的…

SASS 学习笔记 II

SASS 学习笔记 II 上篇笔记&#xff0c;SASS 学习笔记 中包含&#xff1a; 配置 变量 嵌套 这里加一个扩展&#xff0c;嵌套中有一个 & 的用法&#xff0c;使用 & 可以指代当前 block 中的 selector&#xff0c;后面可以追加其他的选择器。如当前的 scope 是 form&a…

GuLi商城-前端基础Vue-使用Vue脚手架进行模块化开发

自己亲自实践&#xff1a; mac安装webpack webpack 简介Webpack 是一个非常流行的前端构建工具&#xff0c;它可以将多个模块&#xff08;包括CSS、JavaScript、图片等&#xff09;打包成一个或多个静态资源文件&#xff08;bundle&#xff09;&#xff0c;以便用于部署到生产…