pytorch-MNIST测试实战

目录

  • 1. 为什么test
  • 2. 如何做test
  • 3. 什么时候做test
  • 4. 完整代码

1. 为什么test

如下图:上下两幅图中蓝色分别表示train的accuracy和loss,黄色表示test的accuracy和loss,如果单纯看train的accuracy和loss曲线就会认为模型已经train的很好了,accuracy一直在上升接近于1了,loss一直在下降已经接近于0了,殊不知此时可能已经出现了over fitting(本数据集准确率很高,其他数据准确率很低),此时就需要test了,从图中可以看出test在红色划线右侧的accuracy已经不变甚至下降了,loss曲线波动也比较大,甚至已经上升了。
在这里插入图片描述

2. 如何做test

如下图所示:
argmax找出概率最大的数字的index
softmax在这里使用与不使用结果是一样的,因为softmax不改变单调性(大的依然大,小的依然小)
使用torch.eq计算预测值与目标值是否相当,相等返回1不等返回0
correct.sum().float().item() /4是用来计算accuracy的,其他sum()是计算正确的个数,item是tensor转bumpy; /4是除以总样本数
在这里插入图片描述

3. 什么时候做test

  • 每几个batch做一次
  • 一个epoch做一次
    注意:为什么不一个batch做一次test呢?因为test的数据可能也比较大,每个batch都test会影响train的速度

4. 完整代码

从一下代码可知,test是一个epoch做一次,首先像train一样load test数据,并搬到GPU中,然后数据输入到网络中,计算loss,最后计算准确了并打印输出

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transformsbatch_size=200
learning_rate=0.01
epochs=10train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)class MLP(nn.Module):def __init__(self):super(MLP, self).__init__()self.model = nn.Sequential(nn.Linear(784, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 200),nn.LeakyReLU(inplace=True),nn.Linear(200, 10),nn.LeakyReLU(inplace=True),)def forward(self, x):x = self.model(x)return xdevice = torch.device('cuda:0')
net = MLP().to(device)
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
criteon = nn.CrossEntropyLoss().to(device)for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28*28)data, target = data.to(device), target.cuda()logits = net(data)loss = criteon(logits, target)optimizer.zero_grad()loss.backward()# print(w1.grad.norm(), w2.grad.norm())optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)data, target = data.to(device), target.cuda()logits = net(data)test_loss += criteon(logits, target).item()pred = logits.argmax(dim=1)correct += pred.eq(target).float().sum().item()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/2619.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

23年新算法,SAO-SVM,基于SAO雪消融算法优化SVM支持向量机回归预测(多输入单输出)-附代码

SAO-SVM是一种基于SAO雪消融算法优化的支持向量机(SVM)回归预测方法,适用于多输入单输出的情况。下面是一个简要的概述,包括如何使用SAO-SVM进行回归预测的步骤: 步骤: 1. 数据准备: 收集并准…

勾八头歌之RNN

一、RNN快速入门 1.学习单步的RNN:RNNCell # -*- coding: utf-8 -*- import tensorflow as tf# 参数 a 是 BasicRNNCell所含的神经元数, 参数 b 是 batch_size, 参数 c 是单个 input 的维数,shape [ b , c ] def creatRNNCell(a,b,c):# 请在此添加代码…

sprinboot+人大金仓配置

1. .yml 配置 spring:datasource:type: com.alibaba.druid.pool.DruidDataSource#driverClassName: dm.jdbc.driver.DmDriver## todo 人大金仓driverClassName: com.kingbase8.Driverdruid:## todo 人大金仓master:url: jdbc:kingbase8://111.111.111.111:54321/dbname?cu…

粘合聚酰亚胺PI塑料材料使用UV胶的优势有哪些? (三十四)

聚酰亚胺PI难于粘接,在PI粘接方法中使用UV胶粘剂粘接PI的优势有哪些? 聚酰亚胺(PI)是一种具有耐高低温性能、高绝缘性、耐化性、低热膨胀系数的材料,广泛用于FPC基材和各种耐高温电机电器的绝缘材料。然而,…

MySQL常见的约束

什么是约束? 限制,限制我们表中的数据,保证添加到数据表中的数据准确和可靠性!凡是不符合约束的数据,插入时就会失败,插入不进去的! 比如:学生信息表中,学号就会约束不…

Java | Leetcode Java题解之第45题跳跃游戏II

题目&#xff1a; 题解&#xff1a; class Solution {public int jump(int[] nums) {int length nums.length;int end 0;int maxPosition 0; int steps 0;for (int i 0; i < length - 1; i) {maxPosition Math.max(maxPosition, i nums[i]); if (i end) {end maxP…

POP —— 简介

目录 Emitting Applying forces Reacting to surfaces Limiting particle speed Following a leader or leaders Swirling particles around vortex filaments Visualizing Forces Collisions Instancing and Rendering Sprite Particles Streams Writing particle…

编程基础“四大件”

基础四大件包括&#xff1a;数据结构和算法,计算机网络,操作系统,设计模式 这跟学什么编程语言,后续从事什么编程方向均无关&#xff0c;只要做编程开发&#xff0c;这四个计算机基础就无法避开。可以这么说&#xff0c;这基础四大件真的比编程语言重要&#xff01;&#xff0…

色温的介绍

文章目录 色温的概念照明领域显示技术领域 色温的概念 色温是描述光源色彩特性的一个重要参数&#xff0c;通常用来表征光的暖冷程度。它以开尔文&#xff08;Kelvin&#xff0c;K&#xff09;为单位来表示&#xff0c;通常简写为K。色温越高&#xff0c;光线看起来就越接近于…

如何用PHP语言实现远程语音播报

如何用PHP语言实现远程语音播报呢&#xff1f; 本文描述了使用PHP语言调用HTTP接口&#xff0c;实现语音播报。通过发送文本信息&#xff0c;来实现远程语音播报、语音提醒、语音警报等。 可选用产品&#xff1a;可根据实际场景需求&#xff0c;选择对应的规格 序号设备名称1…

比特币之路:技术突破、创新思维与领军人物

比特币的兴起是一段充满技术突破、创新思维和领军人物的传奇之路。在这篇文章中&#xff0c;我们将探讨比特币发展的历程&#xff0c;以及那些在这一过程中发挥重要作用的关键人物。 技术突破与前奏 比特币的诞生并非凭空而来&#xff0c;而是建立在先前的技术储备之上。在密码…

机器学习中常见的数据分析,处理方式(以泰坦尼克号为例)

数据分析 读取数据查看数据各个参数信息查看有无空值如何填充空值一些特殊字段如何处理读取数据查看数据中的参数信息实操具体问题具体分析年龄问题 重新划分数据集如何删除含有空白值的行根据条件删除一些行查看特征和标签的相关性 读取数据 查看数据各个参数信息 查看有无空…

TCP三次握手详解

目录 什么是TCP TCP头格式组成 三次握手 第一次握手 第二次握手 第三次握手 三次握手的好处 为什么需要三次握手&#xff1f; 什么是TCP 传输控制协议(TCP)是Internet一个重要的传输层协议。TCP提供面向连接、可靠、有序、字节流传输服务。 面向连接&#xff1a; 应用…

百度糯米携手中山大学举办“开学流水宴”

热游圈消息&#xff1a; 百度糯米携手中山大学&#xff0c;于9月13日在“百团大战”游园会上举办了一场别开生面的“开学流水宴”&#xff0c;吸引了众多新生和百度糯米用户参与。这场长达20米的流水宴不仅为新生们带来了美味佳肴&#xff0c;更为他们提供了结交新朋友、增进同…

编写你的第一个java 程序

1.安装 jdk 网址&#xff1a; Java Downloads | Oracle 一般我们安装jdk 17 就行了 自己练习 自己学习 真正的开发中我们使用jdk 8 这个是最适合开发java 应用程序的 当然你也可以选择你的 系统 来安装这个java 在文件资源管理器打开JDK的安装目录的bin目录&#xff0c;会发…

pycharm远程连接server

1.工具–部署–配置 2.部署完成后&#xff0c;将现有的项目的解释器设置为ssh 解释器。实现在远端开发 解释器可以使用/usr/bin/python3

ROC和AUC

什么是ROC和AUC ROC曲线&#xff08;Receiver Operating Characteristic curve&#xff09;和AUC&#xff08;Area Under the Curve&#xff09;是用于评估二分类模型性能的重要工具。 ROC曲线以真正例率&#xff08;True Positive Rate&#xff0c;也称为召回率或灵敏度&…

Scala的函数至简原则

对于scala语言来说&#xff0c;函数的至简原则是它的一大特色。下面让我们一起来看看分别有什么吧&#xff01; 函数至简原则&#xff1a;能省则省&#xff01; 初始函数 def test(name:String):String{return name }1、return可以省略&#xff0c;Scala会使用函数体的最后一…

【Ubuntu20.04+Noetic】UR5e+Gazebo+Moveit

环境准备 创建工作空间 mkdir -p ur5e_ws/src cd ur5e_ws/srcUR机械臂软件包 UR官方没更新最新的noetic的分支,因此安装melodic,并需要改动相关文件。 安装UR的模型配置包,包里面有UR模型文件,moveit配置等: cd ~/ur5e_ws/src git clone -b melodic-devel https://git…

探索未来的区块链DApp应用,畅享数字世界的无限可能

随着区块链技术的飞速发展&#xff0c;分布式应用&#xff08;DApp&#xff09;正成为数字经济中的一股强劲力量。DApp以其去中心化、透明公正的特点&#xff0c;为用户带来了全新的数字体验&#xff0c;开创了数字经济的新潮流。作为一家专业的区块链DApp应用开发公司&#xf…