【pytorch23】MNIST测试实战

理解

训练完之后也需要做测试
为什么要做test?
在这里插入图片描述
上图蓝色代表train的accuracy
下图蓝色代表train的loss
基本上符合预期,随着epoch增大,train的accuracy也会上升,loss也会一直下降,下降到一个较小的程度

但是如果只看train的情况的话,就会被欺骗,虽然accuracy高并且loss很低,你会以为这个算法就很好了,但是做其他的事情就不是特别好,这就是过拟合(overfitting)

deep learning表达能力非常强表达效果很好,在模型在训练数据上表现非常好,但在新的、未见过的数据上表现不佳的情况。这是因为模型学习到了训练数据中的特定噪声和细节,而不是更通用的特征。

如何缓解这种情况?
在train的时候做一个test,这个test使用validation set 验证集做的,在刚开始的阶段,蓝色的线在上升的时候,验证集的accuracy也会上升loss也与train基本一致,只不过是在训练集上面train在验证集上测试不一定完全符合,所以波动会有点大,很明显train会的更好,validation的表现(包括accuracy和loss)也会变的更好

说明在刚开始的阶段确实学到了一些通用的特征,随着时间的推移,就开始over fitting了,开始去记住一些噪声和细节,这样的话泛化能力会变差,所以在训练集上训练后,在验证集上测试的时候,accuracy会保持不变或者可能下降同样的loss也会巨幅的波动

深度学习所以并不是越训练越好,数据量和架构是核心问题,有一个好的结构再加上足够的数据才能取得一个好的结果

在这里插入图片描述
logits是一个是十个节点的向量,经过cross entropy loss(包含softmax和log和nll_loss)训练,得到loss和accuracy(经过softmax之后就变成了Y=i,i代表第i号节点的概率,只需要argmax之后就能得到概率最大所在的位置),这里对softmax之前和之后都做了一下argmax,其实是一样的效果,因为softmax不会改变单调性,即原来大的数据在softmax之后也会大

这是计算accuracy的基本流程

在这里插入图片描述
什么时候计算test的accuracy和loss

不能够每做一个batch就训练一次,这样就会花大量的时间做测试,不合理,尤其是对于大型数据集

一般情况:

  • 训练若干个batch做一次测试
  • 训练一个epoch做一次测试

如何做测试
在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transformsdef load_data(batch_size):train_loader = torch.utils.data.DataLoader(datasets.MNIST('mnist_data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)test_loader = torch.utils.data.DataLoader(datasets.MNIST('mnist_data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)return train_loader, test_loaderclass MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.model = nn.Sequential(nn.Linear(784, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 10),nn.LeakyReLU(inplace=True),)def forward(self, x):x = self.model(x)return xdef training(train_loader, net, device):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28 * 28)data, target = data.to(device), target.cuda()logits = net(data)loss = criteon(logits, target)optimizer.zero_grad()loss.backward()# print(w1.grad.norm(), w2.grad.norm())optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader),loss.item()))def testing(test_loader, net, device):test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)data, target = data.to(device), target.cuda()logits = net(data)test_loss += criteon(logits, target).item()pred = logits.argmax(dim=1)correct += pred.eq(target).float().sum().item()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))global netif __name__ == '__main__':batch_size = 200learning_rate = 0.01epochs = 10train_loader, test_loader = load_data(batch_size)device = torch.device('cuda:0')net = MLP().to(device)optimizer = optim.SGD(net.parameters(), lr=learning_rate)criteon = nn.CrossEntropyLoss().to(device)for epoch in range(epochs):training(train_loader, net, device)testing(test_loader, net, device)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/42364.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

9.2 栅格图层符号化单波段灰度渲染

文章目录 前言单波段灰度QGis设置为单波段灰度二次开发代码实现单波段灰度 总结 前言 介绍栅格图层数据渲染之单波段灰度显示说明:文章中的示例代码均来自开源项目qgis_cpp_api_apps 单波段灰度 以“3420C_2010_327_RGB_LATLNG.tif”数据为例,在QGis中…

全光谱灯和普通led灯的区别?忠告行业三大隐患弊端!

随着社会的迅猛发展和生活步伐的加速,科技产品层出不穷,其中全光谱灯作为书房的新宠,备受瞩目。它是否真如其宣传的那样具有多重优势,尤其是对那些格外注重视力健康的人群而言,全光谱灯是否会带来潜在的健康风险&#…

【北京迅为】《i.MX8MM嵌入式Linux开发指南》-第一篇 嵌入式Linux入门篇-第十二章 Linux 权限管理

i.MX8MM处理器采用了先进的14LPCFinFET工艺,提供更快的速度和更高的电源效率;四核Cortex-A53,单核Cortex-M4,多达五个内核 ,主频高达1.8GHz,2G DDR4内存、8G EMMC存储。千兆工业级以太网、MIPI-DSI、USB HOST、WIFI/BT…

Java基础(十六):String的常用API

目录 一、构造器方法二、String与字节数组的转换(编码与解码)1、字符串 --> 字节数组:(编码)2、字节数组 --> 字符串:(解码)3、iso-8859-1的特殊用法4、byte数组的数字表示 三…

Java版Flink使用指南——从RabbitMQ中队列中接入消息流

大纲 创建RabbitMQ队列新建工程新增依赖编码设置数据源配置读取、处理数据完整代码 打包、上传和运行任务测试 工程代码 在《Java版Flink使用指南——安装Flink和使用IntelliJ制作任务包》一文中,我们完成了第一个小型Demo的编写。例子中的数据是代码预先指定的。而…

判断对象能否回收的两种方法,以及JVM引用

判断对象能否回收的两种方法:引用计数算法,可达性分析算法 引用计数算法:给对象添加一个引用计数器,当该对象被其它对象引用时计数加一,引用失效时计数减一,计数为0时,可以回收。 特点&#xf…

自动驾驶SLAM又一开源巅峰之作!深挖时间一致性,精准构建超清地图

论文标题: DTCLMapper: Dual Temporal Consistent Learning for Vectorized HD Map Construction 论文作者: Siyu Li, Jiacheng Lin, Hao Shi, Jiaming Zhang, Song Wang, You Yao, Zhiyong Li, Kailun Yang 导读: 本文介绍了一种用于自动…

突发!马斯克3140亿参数Grok开源!Grok原理大公开!

BIG NEWS: 全球最大开源大模型!马斯克Grok-1参数量3410亿,正式开源!!! 说到做到,马斯克xAI的Grok,果然如期开源了! 就在刚刚,马斯克的AI创企xAI正式发布了此前备受期待大模型Grok-1,其参数量达…

硅纪元视角 | 虚拟神经科学的突破:AI「赛博老鼠」诞生

在数字化浪潮的推动下,人工智能(AI)正成为塑造未来的关键力量。硅纪元视角栏目紧跟AI科技的最新发展,捕捉行业动态;提供深入的新闻解读,助您洞悉技术背后的逻辑;汇聚行业专家的见解,…

企业需要什么样的MES?

MES(英文全称:Manufacturing Execution System),即制造执行系统,是面向车间生产的管理系统。它位于上层计划管理系统(如ERP)与底层工业控制(如PCS层)之间,是制…

【Linux】:服务器用户的登陆、删除、密码修改

用Xshell登录云服务器。 1.登录云服务器 先打开Xshell。弹出的界面点。 在终端上输入命令ssh usernameip_address,其中username为要登录的用户名,ip_address为Linux系统的IP地址或主机名。 然后输入密码进行登录。 具体如下: 找到新建会话…

Windows与time.windows.com同步time出错(手把手操作)

今天我来针对Windows讲解Time同步 时间问题 计算机的时间不同,过快或者过慢。(可以和自己的手机时间进行对比,手机的时间进行同步的频率会比计算机更快,因此更精准)计算机time过快和过慢,会导致使用过程中…

想实现随时随地远程访问?解析可道云teamOS内网穿透功能

在数字化时代,无论是个人还是企业,都面临着数据共享与远程访问的迫切需求。 比如我有时会需要在家中加班,急需访问公司内网中的某个关键文件。 然而,由于公网与内网的天然隔阂,这些需求往往难以实现。这时&#xff0c…

代码随想录 链表章节总结

移除链表元素 && 设计链表 学会设置虚拟头结点 翻转链表 leetcode 206 https://leetcode.cn/problems/reverse-linked-list/description/ 方法一:非递归新开链表 头插法:创建一个新的链表,遍历旧链表,按顺序在新链表使…

AIGC | 在机器学习工作站安装NVIDIA CUDA® 并行计算平台和编程模型

[ 知识是人生的灯塔,只有不断学习,才能照亮前行的道路 ] 0x02.初识与安装 CUDA 并行计算平台和编程模型 什么是 CUDA? CUDA(Compute Unified Device Architecture)是英伟达(NVIDIA)推出的并行计算平台和编…

idea提交代码或更新代码一直提示token然后登陆失败无法提交或者更新代码

最近因为换了电脑需要对开发环境做配置, 遇到了这个问题, 应该是因为我们用到了gitlab,默认的最新的idea会有gitlab插件 强制录入gitlab的token,如果gitlab不支持token的验证那么问题就来了 , 不管怎么操作都无法提交或…

Spring MVC深入理解之源码实现

1、SpringMVC的理解 1)谈谈对Spring MVC的了解 MVC 是模型(Model)、视图(View)、控制器(Controller)的简写,其核心思想是通过将业务逻辑、数据、显示分离来组织代码。 Model:数据模型,JavaBean的类,用来进行数据封装…

单链表详解(2)

三、函数定义 查找节点 //查找结点 SLTNode* SLTNodeFind(SLTNode* phead, SLTDataType x) {assert(phead);SLTNode* pcur phead;while (pcur){if (pcur->data x){return pcur;}pcur pcur->next;}return NULL; } 查找节点我们是通过看数据域来查找的,查…

【MySQL05】【 undo 日志】

文章目录 一、前言二、undo 日志(回滚日志)1. 事务 id2. undo 日志格式2.1 INSERT 对应的 undo 日志2.2 DELETE 对应的 undo 日志2.3 UPDATE 对应的 undo 日志2.3.1 不更新主键2.3.2 更新主键 2.3 增删改操作对二级索引的影响2.4 roll_pointer 3. FIL_PA…

layui项目中的layui.define、layui.config以及layui.use的使用

第一步:创建一个layuiTest项目,结构如下 第二步:新建一个test.js,利用layui.define定义一个模块test,并向外暴露该模块,该模块里面有两个方法method1和method2. 第三步:新建一个test.html,在该页面引入layui.js&#x…