【pytorch23】MNIST测试实战

理解

训练完之后也需要做测试
为什么要做test?
在这里插入图片描述
上图蓝色代表train的accuracy
下图蓝色代表train的loss
基本上符合预期,随着epoch增大,train的accuracy也会上升,loss也会一直下降,下降到一个较小的程度

但是如果只看train的情况的话,就会被欺骗,虽然accuracy高并且loss很低,你会以为这个算法就很好了,但是做其他的事情就不是特别好,这就是过拟合(overfitting)

deep learning表达能力非常强表达效果很好,在模型在训练数据上表现非常好,但在新的、未见过的数据上表现不佳的情况。这是因为模型学习到了训练数据中的特定噪声和细节,而不是更通用的特征。

如何缓解这种情况?
在train的时候做一个test,这个test使用validation set 验证集做的,在刚开始的阶段,蓝色的线在上升的时候,验证集的accuracy也会上升loss也与train基本一致,只不过是在训练集上面train在验证集上测试不一定完全符合,所以波动会有点大,很明显train会的更好,validation的表现(包括accuracy和loss)也会变的更好

说明在刚开始的阶段确实学到了一些通用的特征,随着时间的推移,就开始over fitting了,开始去记住一些噪声和细节,这样的话泛化能力会变差,所以在训练集上训练后,在验证集上测试的时候,accuracy会保持不变或者可能下降同样的loss也会巨幅的波动

深度学习所以并不是越训练越好,数据量和架构是核心问题,有一个好的结构再加上足够的数据才能取得一个好的结果

在这里插入图片描述
logits是一个是十个节点的向量,经过cross entropy loss(包含softmax和log和nll_loss)训练,得到loss和accuracy(经过softmax之后就变成了Y=i,i代表第i号节点的概率,只需要argmax之后就能得到概率最大所在的位置),这里对softmax之前和之后都做了一下argmax,其实是一样的效果,因为softmax不会改变单调性,即原来大的数据在softmax之后也会大

这是计算accuracy的基本流程

在这里插入图片描述
什么时候计算test的accuracy和loss

不能够每做一个batch就训练一次,这样就会花大量的时间做测试,不合理,尤其是对于大型数据集

一般情况:

  • 训练若干个batch做一次测试
  • 训练一个epoch做一次测试

如何做测试
在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transformsdef load_data(batch_size):train_loader = torch.utils.data.DataLoader(datasets.MNIST('mnist_data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)test_loader = torch.utils.data.DataLoader(datasets.MNIST('mnist_data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)return train_loader, test_loaderclass MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.model = nn.Sequential(nn.Linear(784, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 10),nn.LeakyReLU(inplace=True),)def forward(self, x):x = self.model(x)return xdef training(train_loader, net, device):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28 * 28)data, target = data.to(device), target.cuda()logits = net(data)loss = criteon(logits, target)optimizer.zero_grad()loss.backward()# print(w1.grad.norm(), w2.grad.norm())optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader),loss.item()))def testing(test_loader, net, device):test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)data, target = data.to(device), target.cuda()logits = net(data)test_loss += criteon(logits, target).item()pred = logits.argmax(dim=1)correct += pred.eq(target).float().sum().item()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))global netif __name__ == '__main__':batch_size = 200learning_rate = 0.01epochs = 10train_loader, test_loader = load_data(batch_size)device = torch.device('cuda:0')net = MLP().to(device)optimizer = optim.SGD(net.parameters(), lr=learning_rate)criteon = nn.CrossEntropyLoss().to(device)for epoch in range(epochs):training(train_loader, net, device)testing(test_loader, net, device)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/42364.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java:使用synchronized和Redis实现并发控制的区别

在线程同步中,synchronized和Redis虽然都可以用来实现并发控制,但它们的作用范围、机制以及性能特点存在显著差异。 1. 作用范围 synchronized: 是Java语言内置的关键字,用于实现线程间的同步。它作用于对象或代码块,可以确保同一…

你手上有offer吗?

作者:猿java。 ​顺便吆喝一声,如果你计算机、软件工程、电子等相关专业本科及以上学历,欢迎来共事,有个offer注意查收。 前端/后端/测试等均可投→技术大厂机会。 都说面试是 7分靠技术,3分靠技巧,今天我…

9.2 栅格图层符号化单波段灰度渲染

文章目录 前言单波段灰度QGis设置为单波段灰度二次开发代码实现单波段灰度 总结 前言 介绍栅格图层数据渲染之单波段灰度显示说明:文章中的示例代码均来自开源项目qgis_cpp_api_apps 单波段灰度 以“3420C_2010_327_RGB_LATLNG.tif”数据为例,在QGis中…

easy-poi实现动态列(标题)、多sheet导出excel

一个sheet动态导出、伪代码&#xff0c;创建填充后的workbook对象 List<Map<String, Object>>list new ArrayList<Map<String, Object>>(); HashMap<String, Object> map new HashMap<>(); map.put("name", "小明")…

启动完 kubelet 日志显示 failed to get azure cloud in GetVolumeLimits, plugin.host: 1

查看 kubelet 日志组件命令 journalctl -xefu kubelet 文字描述问题 Jul 09 07:45:17 node01 kubelet[1344]: I0709 07:45:17.410786 1344 operation_generator.go:568] MountVolume.SetUp succeeded for volume "default-token-mfzqf" (UniqueName: "ku…

全光谱灯和普通led灯的区别?忠告行业三大隐患弊端!

随着社会的迅猛发展和生活步伐的加速&#xff0c;科技产品层出不穷&#xff0c;其中全光谱灯作为书房的新宠&#xff0c;备受瞩目。它是否真如其宣传的那样具有多重优势&#xff0c;尤其是对那些格外注重视力健康的人群而言&#xff0c;全光谱灯是否会带来潜在的健康风险&#…

【北京迅为】《i.MX8MM嵌入式Linux开发指南》-第一篇 嵌入式Linux入门篇-第十二章 Linux 权限管理

i.MX8MM处理器采用了先进的14LPCFinFET工艺&#xff0c;提供更快的速度和更高的电源效率;四核Cortex-A53&#xff0c;单核Cortex-M4&#xff0c;多达五个内核 &#xff0c;主频高达1.8GHz&#xff0c;2G DDR4内存、8G EMMC存储。千兆工业级以太网、MIPI-DSI、USB HOST、WIFI/BT…

Java基础(十六):String的常用API

目录 一、构造器方法二、String与字节数组的转换&#xff08;编码与解码&#xff09;1、字符串 --> 字节数组&#xff1a;&#xff08;编码&#xff09;2、字节数组 --> 字符串&#xff1a;&#xff08;解码&#xff09;3、iso-8859-1的特殊用法4、byte数组的数字表示 三…

Java版Flink使用指南——从RabbitMQ中队列中接入消息流

大纲 创建RabbitMQ队列新建工程新增依赖编码设置数据源配置读取、处理数据完整代码 打包、上传和运行任务测试 工程代码 在《Java版Flink使用指南——安装Flink和使用IntelliJ制作任务包》一文中&#xff0c;我们完成了第一个小型Demo的编写。例子中的数据是代码预先指定的。而…

判断对象能否回收的两种方法,以及JVM引用

判断对象能否回收的两种方法&#xff1a;引用计数算法&#xff0c;可达性分析算法 引用计数算法&#xff1a;给对象添加一个引用计数器&#xff0c;当该对象被其它对象引用时计数加一&#xff0c;引用失效时计数减一&#xff0c;计数为0时&#xff0c;可以回收。 特点&#xf…

自动驾驶SLAM又一开源巅峰之作!深挖时间一致性,精准构建超清地图

论文标题&#xff1a; DTCLMapper: Dual Temporal Consistent Learning for Vectorized HD Map Construction 论文作者&#xff1a; Siyu Li, Jiacheng Lin, Hao Shi, Jiaming Zhang, Song Wang, You Yao, Zhiyong Li, Kailun Yang 导读&#xff1a; 本文介绍了一种用于自动…

突发!马斯克3140亿参数Grok开源!Grok原理大公开!

BIG NEWS: 全球最大开源大模型&#xff01;马斯克Grok-1参数量3410亿&#xff0c;正式开源!!! 说到做到&#xff0c;马斯克xAI的Grok&#xff0c;果然如期开源了&#xff01; 就在刚刚&#xff0c;马斯克的AI创企xAI正式发布了此前备受期待大模型Grok-1&#xff0c;其参数量达…

硅纪元视角 | 虚拟神经科学的突破:AI「赛博老鼠」诞生

在数字化浪潮的推动下&#xff0c;人工智能&#xff08;AI&#xff09;正成为塑造未来的关键力量。硅纪元视角栏目紧跟AI科技的最新发展&#xff0c;捕捉行业动态&#xff1b;提供深入的新闻解读&#xff0c;助您洞悉技术背后的逻辑&#xff1b;汇聚行业专家的见解&#xff0c;…

企业需要什么样的MES?

MES&#xff08;英文全称&#xff1a;Manufacturing Execution System&#xff09;&#xff0c;即制造执行系统&#xff0c;是面向车间生产的管理系统。它位于上层计划管理系统&#xff08;如ERP&#xff09;与底层工业控制&#xff08;如PCS层&#xff09;之间&#xff0c;是制…

【Linux】:服务器用户的登陆、删除、密码修改

用Xshell登录云服务器。 1.登录云服务器 先打开Xshell。弹出的界面点。 在终端上输入命令ssh usernameip_address&#xff0c;其中username为要登录的用户名&#xff0c;ip_address为Linux系统的IP地址或主机名。 然后输入密码进行登录。 具体如下&#xff1a; 找到新建会话…

Windows与time.windows.com同步time出错(手把手操作)

今天我来针对Windows讲解Time同步 时间问题 计算机的时间不同&#xff0c;过快或者过慢。&#xff08;可以和自己的手机时间进行对比&#xff0c;手机的时间进行同步的频率会比计算机更快&#xff0c;因此更精准&#xff09;计算机time过快和过慢&#xff0c;会导致使用过程中…

想实现随时随地远程访问?解析可道云teamOS内网穿透功能

在数字化时代&#xff0c;无论是个人还是企业&#xff0c;都面临着数据共享与远程访问的迫切需求。 比如我有时会需要在家中加班&#xff0c;急需访问公司内网中的某个关键文件。 然而&#xff0c;由于公网与内网的天然隔阂&#xff0c;这些需求往往难以实现。这时&#xff0c…

代码随想录 链表章节总结

移除链表元素 && 设计链表 学会设置虚拟头结点 翻转链表 leetcode 206 https://leetcode.cn/problems/reverse-linked-list/description/ 方法一&#xff1a;非递归新开链表 头插法&#xff1a;创建一个新的链表&#xff0c;遍历旧链表&#xff0c;按顺序在新链表使…

AIGC | 在机器学习工作站安装NVIDIA CUDA® 并行计算平台和编程模型

[ 知识是人生的灯塔&#xff0c;只有不断学习&#xff0c;才能照亮前行的道路 ] 0x02.初识与安装 CUDA 并行计算平台和编程模型 什么是 CUDA? CUDA&#xff08;Compute Unified Device Architecture&#xff09;是英伟达&#xff08;NVIDIA&#xff09;推出的并行计算平台和编…

idea提交代码或更新代码一直提示token然后登陆失败无法提交或者更新代码

最近因为换了电脑需要对开发环境做配置&#xff0c; 遇到了这个问题&#xff0c; 应该是因为我们用到了gitlab&#xff0c;默认的最新的idea会有gitlab插件 强制录入gitlab的token&#xff0c;如果gitlab不支持token的验证那么问题就来了 &#xff0c; 不管怎么操作都无法提交或…