pytorch-MNIST测试实战

目录

  • 1. 为什么test
  • 2. 如何做test
  • 3. 什么时候做test
  • 4. 完整代码

1. 为什么test

如下图:上下两幅图中蓝色分别表示train的accuracy和loss,黄色表示test的accuracy和loss,如果单纯看train的accuracy和loss曲线就会认为模型已经train的很好了,accuracy一直在上升接近于1了,loss一直在下降已经接近于0了,殊不知此时可能已经出现了over fitting(本数据集准确率很高,其他数据准确率很低),此时就需要test了,从图中可以看出test在红色划线右侧的accuracy已经不变甚至下降了,loss曲线波动也比较大,甚至已经上升了。
在这里插入图片描述

2. 如何做test

如下图所示:
argmax找出概率最大的数字的index
softmax在这里使用与不使用结果是一样的,因为softmax不改变单调性(大的依然大,小的依然小)
使用torch.eq计算预测值与目标值是否相当,相等返回1不等返回0
correct.sum().float().item() /4是用来计算accuracy的,其他sum()是计算正确的个数,item是tensor转bumpy; /4是除以总样本数
在这里插入图片描述

3. 什么时候做test

  • 每几个batch做一次
  • 一个epoch做一次
    注意:为什么不一个batch做一次test呢?因为test的数据可能也比较大,每个batch都test会影响train的速度

4. 完整代码

从一下代码可知,test是一个epoch做一次,首先像train一样load test数据,并搬到GPU中,然后数据输入到网络中,计算loss,最后计算准确了并打印输出

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transformsbatch_size=200
learning_rate=0.01
epochs=10train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.model = nn.Sequential(nn.Linear(784, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 10),nn.LeakyReLU(inplace=True),)def forward(self, x):x = self.model(x)return xdevice = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss().to(device)for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)data, target = data.to(device), target.cuda()logits = net(data)loss = criteon(logits, target)optimizer.zero_grad()loss.backward()# print(w1.grad.norm(), w2.grad.norm())optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)data, target = data.to(device), target.cuda()logits = net(data)test_loss += criteon(logits, target).item()pred = logits.argmax(dim=1)correct += pred.eq(target).float().sum().item()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/2619.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

23年新算法,SAO-SVM,基于SAO雪消融算法优化SVM支持向量机回归预测(多输入单输出)-附代码

SAO-SVM是一种基于SAO雪消融算法优化的支持向量机(SVM)回归预测方法,适用于多输入单输出的情况。下面是一个简要的概述,包括如何使用SAO-SVM进行回归预测的步骤: 步骤: 1. 数据准备: 收集并准…

勾八头歌之RNN

一、RNN快速入门 1.学习单步的RNN:RNNCell # -*- coding: utf-8 -*- import tensorflow as tf# 参数 a 是 BasicRNNCell所含的神经元数, 参数 b 是 batch_size, 参数 c 是单个 input 的维数,shape [ b , c ] def creatRNNCell(a,b,c):# 请在此添加代码…

python制作ppt

在Python中,你可以使用python-pptx库来创建和修改PowerPoint (.pptx) 文件。这个库允许你添加幻灯片、文本框、图片、形状、表格等元素,并可以调整它们的格式和布局。 下面是一个简单的例子,展示了如何使用python-pptx库来创建一个PPT文件&a…

sprinboot+人大金仓配置

1. .yml 配置 spring:datasource:type: com.alibaba.druid.pool.DruidDataSource#driverClassName: dm.jdbc.driver.DmDriver## todo 人大金仓driverClassName: com.kingbase8.Driverdruid:## todo 人大金仓master:url: jdbc:kingbase8://111.111.111.111:54321/dbname?cu…

粘合聚酰亚胺PI塑料材料使用UV胶的优势有哪些? (三十四)

聚酰亚胺PI难于粘接,在PI粘接方法中使用UV胶粘剂粘接PI的优势有哪些? 聚酰亚胺(PI)是一种具有耐高低温性能、高绝缘性、耐化性、低热膨胀系数的材料,广泛用于FPC基材和各种耐高温电机电器的绝缘材料。然而,…

MySQL常见的约束

什么是约束? 限制,限制我们表中的数据,保证添加到数据表中的数据准确和可靠性!凡是不符合约束的数据,插入时就会失败,插入不进去的! 比如:学生信息表中,学号就会约束不…

Java | Leetcode Java题解之第45题跳跃游戏II

题目&#xff1a; 题解&#xff1a; class Solution {public int jump(int[] nums) {int length nums.length;int end 0;int maxPosition 0; int steps 0;for (int i 0; i < length - 1; i) {maxPosition Math.max(maxPosition, i nums[i]); if (i end) {end maxP…

POP —— 简介

目录 Emitting Applying forces Reacting to surfaces Limiting particle speed Following a leader or leaders Swirling particles around vortex filaments Visualizing Forces Collisions Instancing and Rendering Sprite Particles Streams Writing particle…

深度学习-01

机器学习是一种人工智能的分支&#xff0c;主要研究如何让计算机系统通过从大量的数据中学习并自动改进其性能。机器学习算法通过统计和数学模型来分析和理解数据&#xff0c;从而能够自动发现数据中的模式和规律&#xff0c;并根据这些规律进行预测和决策。 机器学习可以应用于…

编程基础“四大件”

基础四大件包括&#xff1a;数据结构和算法,计算机网络,操作系统,设计模式 这跟学什么编程语言,后续从事什么编程方向均无关&#xff0c;只要做编程开发&#xff0c;这四个计算机基础就无法避开。可以这么说&#xff0c;这基础四大件真的比编程语言重要&#xff01;&#xff0…

色温的介绍

文章目录 色温的概念照明领域显示技术领域 色温的概念 色温是描述光源色彩特性的一个重要参数&#xff0c;通常用来表征光的暖冷程度。它以开尔文&#xff08;Kelvin&#xff0c;K&#xff09;为单位来表示&#xff0c;通常简写为K。色温越高&#xff0c;光线看起来就越接近于…

如何用PHP语言实现远程语音播报

如何用PHP语言实现远程语音播报呢&#xff1f; 本文描述了使用PHP语言调用HTTP接口&#xff0c;实现语音播报。通过发送文本信息&#xff0c;来实现远程语音播报、语音提醒、语音警报等。 可选用产品&#xff1a;可根据实际场景需求&#xff0c;选择对应的规格 序号设备名称1…

比特币之路:技术突破、创新思维与领军人物

比特币的兴起是一段充满技术突破、创新思维和领军人物的传奇之路。在这篇文章中&#xff0c;我们将探讨比特币发展的历程&#xff0c;以及那些在这一过程中发挥重要作用的关键人物。 技术突破与前奏 比特币的诞生并非凭空而来&#xff0c;而是建立在先前的技术储备之上。在密码…

机器学习中常见的数据分析,处理方式(以泰坦尼克号为例)

数据分析 读取数据查看数据各个参数信息查看有无空值如何填充空值一些特殊字段如何处理读取数据查看数据中的参数信息实操具体问题具体分析年龄问题 重新划分数据集如何删除含有空白值的行根据条件删除一些行查看特征和标签的相关性 读取数据 查看数据各个参数信息 查看有无空…

TCP三次握手详解

目录 什么是TCP TCP头格式组成 三次握手 第一次握手 第二次握手 第三次握手 三次握手的好处 为什么需要三次握手&#xff1f; 什么是TCP 传输控制协议(TCP)是Internet一个重要的传输层协议。TCP提供面向连接、可靠、有序、字节流传输服务。 面向连接&#xff1a; 应用…

百度糯米携手中山大学举办“开学流水宴”

热游圈消息&#xff1a; 百度糯米携手中山大学&#xff0c;于9月13日在“百团大战”游园会上举办了一场别开生面的“开学流水宴”&#xff0c;吸引了众多新生和百度糯米用户参与。这场长达20米的流水宴不仅为新生们带来了美味佳肴&#xff0c;更为他们提供了结交新朋友、增进同…

关于TC简单编程的AB爪爪的几点东西

最近在帮公司写一个SAP页面的自动录入数据的小工具。 端口是5000&#xff0c;SAP版本好像是7.2.很老的东西&#xff0c;老到页面只支持IE打开。其他浏览器打开就报IVEW不支持什么的一大堆错误。 没办法&#xff0c;拉出TC。但是更麻烦的又来了。TC自带的AB爪爪抓不到各种输入…

编写你的第一个java 程序

1.安装 jdk 网址&#xff1a; Java Downloads | Oracle 一般我们安装jdk 17 就行了 自己练习 自己学习 真正的开发中我们使用jdk 8 这个是最适合开发java 应用程序的 当然你也可以选择你的 系统 来安装这个java 在文件资源管理器打开JDK的安装目录的bin目录&#xff0c;会发…

pycharm远程连接server

1.工具–部署–配置 2.部署完成后&#xff0c;将现有的项目的解释器设置为ssh 解释器。实现在远端开发 解释器可以使用/usr/bin/python3

ROC和AUC

什么是ROC和AUC ROC曲线&#xff08;Receiver Operating Characteristic curve&#xff09;和AUC&#xff08;Area Under the Curve&#xff09;是用于评估二分类模型性能的重要工具。 ROC曲线以真正例率&#xff08;True Positive Rate&#xff0c;也称为召回率或灵敏度&…