手写数据集minist基于pytorch分类学习

1.Mnist数据集介绍
1.1 基本介绍
Mnist数据集可以算是学习深度学习最常用到的了。这个数据集包含70000张手写数字图片,分别是60000张训练图片和10000张测试图片,训练集由来自250个不同人手写的数字构成,一般来自高中生,一半来自工作人员,测试集(test set)也是同样比例的手写数字数据,并且保证了测试集和训练集的作者不同。每个图片都是2828个像素点,数据集会把一张图片的数据转成一个2828=784的一维向量存储起来。
里面的图片数据如下所示,每张图是0-9的手写数字黑底白字的图片,存储时,黑色用0表示,白色用0-1的浮点数表示。


1.2 数据集下载
1)官网下载
Mnist数据集的下载地址如下:http://yann.lecun.com/exdb/mnist/
打开后会有四个文件:


训练数据集:train-images-idx3-ubyte.gz
训练数据集标签:train-labels-idx1-ubyte.gz
测试数据集:t10k-images-idx3-ubyte.gz
测试数据集标签:t10k-labels-idx1-ubyte.gz
将这四个文件下载后放置到需要用的文件夹下即可不要解压!下载后是什么就怎么放!

2)代码导入
文件夹下运行下面的代码,即可自动检测数据集是否存在,若没有会自动进行下载,下载后在这一路径:

下载数据集:

# 下载数据集
from torchvision import datasets, transformstrain_set = datasets.MNIST("data",train=True,download=True, transform=transforms.ToTensor(),)
test_set = datasets.MNIST("data",train=False,download=True, transform=transforms.ToTensor(),)

参数解释:

datasets.MNIST:是Pytorch的内置函数torchvision.datasets.MNIST,可以导入数据集
train=True :读入的数据作为训练集
transform:读入我们自己定义的数据预处理操作
download=True:当我们的根目录(root)下没有数据集时,便自动下载
如果这时候我们通过联网自动下载方式download我们的数据后,它的文件路径是以下形式:原文件夹/data/MNIST/raw

14轮左右,模型识别准确率达到98%以上

 

 加载数据集

import os.path
import matplotlib.pyplot as plt
import torch
from torchvision.datasets import MNIST
from PIL import Image
from torch.utils.data import Dataset,DataLoader
from torchvision import datasets, transforms
# 下载数据集
from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor(), # 将灰度图片像素值(0~255)转为Tensor(0~1),方便后续处理transforms.Normalize((0.1307,),(0.3081,))# 归一化,均值0,方差1;mean:各通道的均值std:各通道的标准差inplace:是否原地操作
])train_data = MNIST(root='./minist_data',train=True,download=False,transform=transform)
train_loader = DataLoader(dataset=train_data,shuffle=True,batch_size=64)
test_data = MNIST(root='./minist_data',train=False,download=False,transform=transform)
test_loader = DataLoader(dataset=test_data,shuffle=True,batch_size=64)# train_data返回的是很多张图,每一张图是一个元组,包含图片和对应的数字
# print(test_data[0])
# print(train_data[0][0].show())train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度:{}".format(train_data_size))
print("测试数据集的长度:{}".format(test_data_size))

构建模型,模型主要由两个卷积层,两个池化层,以及一个全连接层构成,激活函数使用relu. 

 

class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()self.conv1 = torch.nn.Conv2d(in_channels=1,out_channels=10,stride=1,kernel_size=5,padding=0)self.maxpool1 = torch.nn.MaxPool2d(2)self.conv2 = torch.nn.Conv2d(in_channels=10,out_channels=20,kernel_size=5,stride=1,padding=0)self.maxpool2 = torch.nn.MaxPool2d(2)self.linear = torch.nn.Linear(320,10)def forward(self,x):x = torch.relu(self.conv1(x))x = self.maxpool1(x)x = torch.relu(self.conv2(x))x = self.maxpool2(x)x = x.view(x.size(0),-1)x = self.linear(x)return x
model = Model()criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.14)# 交叉熵损失,相当于Softmax+Log+NllLoss
# 线性多分类模型Softmax,给出最终预测值对于10个类别出现的概率,Log:将乘法转换为加法,减少计算量,保证函数的单调性
# NLLLoss:计算损失,此过程不需要手动one-hot编码,NLLLoss会自动完成
# SGD,优化器,梯度下降算法e

模型训练
每次训练完成后会自动保存参数到pkl模型中,如果路径中有Pkl文件,下次运行会自动加载上一次的模型参数,在这个基础上继续训练,第一次运行时没有模型参数,结束后会自动生成。

# 模型训练
def train():# index = 0for index, data in enumerate(train_loader):  # 获取训练数据以及对应标签# for data in train_loader:input, target = data  # input为输入数据,target为标签y_predict = model(input)  # 模型预测loss = criterion(y_predict, target)optimizer.zero_grad()  # 梯度清零loss.backward()  # loss值反向传播optimizer.step()  # 更新参数# index += 1if index % 100 == 0:  # 每一百次保存一次模型,打印损失torch.save(model.state_dict(), "model.pkl")  # 保存模型torch.save(optimizer.state_dict(), "optimizer.pkl")print("训练次数为:{},损失值为:{}".format(index, loss.item()))

加载模型
第一次运行这里需要一个空的model文件夹

if os.path.exists('model.pkl'):model.load_state_dict(torch.load("model.pkl"))

模型测试

def test():correct = 0total = 0with torch.no_grad():for index,data in enumerate(test_loader):inputs,target = dataoutput = model(inputs)probability,predict = torch.max(input=output.data, dim=1)total += target.size(0)  # target是形状为(batch_size,1)的矩阵,使用size(0)取出该批的大小correct += (predict == target).sum().item()  # predict 和target均为(batch_size,1)的矩阵,sum求出相等的个数print("测试准确率为:%.6f" % (correct / total))

自己手写数字图片识别函数(可选用)
这部分主要是加载训练好的pkl模型测试自己的数据,因此在进行自己手写图的测试时,需要有训练好的pkl文件,并且就不要调用train()函数和test()函数啦注意:这个图片像素也要说黑底白字,28*28像素,否则无法识别

def test_mydata():image = Image.open('5fd4e4c2c99a24e3e27eb9b2ee3b053c.jpg')  # 读取自定义手写图片image = image.resize((28, 28))  # 裁剪尺寸为28*28image = image.convert('L')  # 转换为灰度图像transform = transforms.ToTensor()image = transform(image)image = image.resize(1, 1, 28, 28)output = model(image)probability, predict = torch.max(output.data, dim=1)print("此手写图片值为:%d,其最大概率为:%.2f " % (predict[0], probability))plt.title("此手写图片值为:{}".format((int(predict))), fontname='SimHei')plt.imshow(image.squeeze())plt.show()

MNIST中的数据识别测试数据
训练过程中的打印信息我进行了修改,这里设置的训练轮数是15轮,每次训练生成的pkl模型参数也是会更新的,想要更多训练信息可以查看对应的教程哦~

if __name__ == '__main__':# 训练与测试for i in range(15):  # 训练和测试进行5轮print({"————————第{}轮测试开始——————".format(i + 1)})train()test()test_mydata()

完整代码:

import os.path
import matplotlib.pyplot as plt
import torch
from torchvision.datasets import MNIST
from PIL import Image
from torch.utils.data import Dataset,DataLoader
from torchvision import datasets, transforms
# 下载数据集
from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor(), # 将灰度图片像素值(0~255)转为Tensor(0~1),方便后续处理transforms.Normalize((0.1307,),(0.3081,))# 归一化,均值0,方差1;mean:各通道的均值std:各通道的标准差inplace:是否原地操作
])train_data = MNIST(root='./minist_data',train=True,download=False,transform=transform)
train_loader = DataLoader(dataset=train_data,shuffle=True,batch_size=64)
test_data = MNIST(root='./minist_data',train=False,download=False,transform=transform)
test_loader = DataLoader(dataset=test_data,shuffle=True,batch_size=64)# train_data返回的是很多张图,每一张图是一个元组,包含图片和对应的数字
# print(test_data[0])
# print(train_data[0][0].show())train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度:{}".format(train_data_size))
print("测试数据集的长度:{}".format(test_data_size))class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()self.conv1 = torch.nn.Conv2d(in_channels=1,out_channels=10,stride=1,kernel_size=5,padding=0)self.maxpool1 = torch.nn.MaxPool2d(2)self.conv2 = torch.nn.Conv2d(in_channels=10,out_channels=20,kernel_size=5,stride=1,padding=0)self.maxpool2 = torch.nn.MaxPool2d(2)self.linear = torch.nn.Linear(320,10)def forward(self,x):x = torch.relu(self.conv1(x))x = self.maxpool1(x)x = torch.relu(self.conv2(x))x = self.maxpool2(x)x = x.view(x.size(0),-1)x = self.linear(x)return x
model = Model()criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.14)# 交叉熵损失,相当于Softmax+Log+NllLoss
# 线性多分类模型Softmax,给出最终预测值对于10个类别出现的概率,Log:将乘法转换为加法,减少计算量,保证函数的单调性
# NLLLoss:计算损失,此过程不需要手动one-hot编码,NLLLoss会自动完成
# SGD,优化器,梯度下降算法e# 模型训练
def train():# index = 0for index, data in enumerate(train_loader):  # 获取训练数据以及对应标签# for data in train_loader:input, target = data  # input为输入数据,target为标签y_predict = model(input)  # 模型预测loss = criterion(y_predict, target)optimizer.zero_grad()  # 梯度清零loss.backward()  # loss值反向传播optimizer.step()  # 更新参数# index += 1if index % 100 == 0:  # 每一百次保存一次模型,打印损失torch.save(model.state_dict(), "model.pkl")  # 保存模型torch.save(optimizer.state_dict(), "optimizer.pkl")print("训练次数为:{},损失值为:{}".format(index, loss.item()))if os.path.exists('model.pkl'):model.load_state_dict(torch.load("model.pkl"))def test():correct = 0total = 0with torch.no_grad():for index,data in enumerate(test_loader):inputs,target = dataoutput = model(inputs)probability,predict = torch.max(input=output.data, dim=1)total += target.size(0)  # target是形状为(batch_size,1)的矩阵,使用size(0)取出该批的大小correct += (predict == target).sum().item()  # predict 和target均为(batch_size,1)的矩阵,sum求出相等的个数print("测试准确率为:%.6f" % (correct / total))def test_mydata():image = Image.open('5fd4e4c2c99a24e3e27eb9b2ee3b053c.jpg')  # 读取自定义手写图片image = image.resize((28, 28))  # 裁剪尺寸为28*28image = image.convert('L')  # 转换为灰度图像transform = transforms.ToTensor()image = transform(image)image = image.resize(1, 1, 28, 28)output = model(image)probability, predict = torch.max(output.data, dim=1)print("此手写图片值为:%d,其最大概率为:%.2f " % (predict[0], probability))plt.title("此手写图片值为:{}".format((int(predict))), fontname='SimHei')plt.imshow(image.squeeze())plt.show()if __name__ == '__main__':# 训练与测试for i in range(15):  # 训练和测试进行5轮print({"————————第{}轮测试开始——————".format(i + 1)})train()test()test_mydata()

 

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/21764.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MTU相关随笔

一、MTU的概念 MTU(最大传输单元):用来通知对方所能接受数据服务单元的最大尺寸,说明发送方能够接受的有效载荷大小。MTU是包或帧的最大长度,一般以字节记,如果过大在碰到路由器时会被拒绝转发&#xff0c…

SpringBoot项目本地运行正常,jar包运行时前端报错403:No mapping for......

SpringBoot项目本地运行正常,jar包运行时前端报错403:No mapping for… 提示:在部署jar包到云服务器上之前,一定要在本地运行jar包,查看前端代码是否运行正常,若报错的话可以节省很多时间 方式:…

SQL 如何获取A列相同但是B列不同的数据项

用户表里有两个字段:部门和职位。一个部门可能对应多个职位,多个部门也可能都有同一职位。比如: 部门 职位 财务 部长 财务 副部长 财务 会计 财务 职员 编辑 部长 编辑 副部长 编辑 主编 编辑 副主编 现在想通过筛选,获取职位名称…

友顺科技(UTC)分立器件与集成IC产品选型和应用

友顺科技股份有限公司成立于1990年,是全球领先的集成电路与功率半导体厂商 ,集团总部位于台北,生产基地位于福州、厦门。 友顺科技具有完整模拟组件产品线,其中类比IC涵盖各种稳压器、PWM控制IC, 放大器、比较器、逻辑IC、Voltage Translato…

基于飞腾 D2000 8 核+ 32G DDR+板载 6 千兆电口+ 4 千兆光口高性能网络安全主板

第一章、产品介绍 1.1 产品概述 XM-D2000GW是一款基于飞腾 D2000 8 核X100 桥片高性能网络安全主板,D2000 为飞腾首款支持 8 核桌面平 台处理器,支持双通道 DDR4-2666 内存,芯片内置国密 SM2/SM3/SM4/SM9 加速引擎,支持单精度、双…

gitee和github的协同

假设gitee上zhaodezan有一个开发库,但是从andeyeluguo上拉取最新的(从github上同步过来最新的) git remote add dbgpt_in_gitee https://gitee.com/andeyeluguo/DB-GPT.git remote -v git pull --rebase dbgpt_in_gitee main 有冲突可能需要…

使用Scapy框架分析HTTP流量

网络流量分析是网络安全和管理中的一个重要部分。通过分析网络流量,我们可以检测异常行为、诊断网络问题以及提升网络性能。本文将介绍如何使用Scapy框架分析HTTP流量。我们将从tcpdump导出的PCAP文件中提取HTTP流量,并进行简单的分析。 PCAP文件格式 …

【调试笔记-20240603-Linux-在 OpenWrt-23.05 上运行 ipkg-build 生成. ipk 安装包】

调试笔记-系列文章目录 调试笔记-20240603-Linux-在 OpenWrt-23.05 上运行 ipkg-build 生成. ipk 安装包 文章目录 调试笔记-系列文章目录调试笔记-20240603-Linux-在 OpenWrt-23.05 上运行 ipkg-build 生成. ipk 安装包 前言一、调试环境操作系统:Windows 10 专业…

Android11 AudioTrack和Track建立联系

应用程序创建AudioTrack时,导致AudioFlinger在播放线程中,创建Track和其对应。那它们之间是通过什么来建立联系传递数据的?答案是共享内存。 创建Track时,导致其父类TrackBase的构造函数被调用 //frameworks/av/services/audiofl…

数字化时代还需要传统智慧图书馆吗

尽管以电子阅览室代表的数字化时代带来了许多便利和创新,但传统智慧图书馆依然具有重要的价值和意义。以下是一些原因: 1. 保存历史文化:传统智慧图书馆是保存历史文化遗产的重要载体,收藏了许多珍贵的古籍、手稿和纸质图书&#…

Prop 和 State 有什么区别?

Prop (属性) 和 State (状态) 是 React 中两个非常重要的概念,它们之间有以下几个主要区别: 来源:Prop 是父组件传递给子组件的数据。State 是组件内部维护的数据。可变性:Prop 是不可变的(immutable)。一旦父组件传递给子组件,子组件就无法直接修改 prop。State 是可变的(mut…

基于 Amazon EC2 快速部署 Stable Diffusion WebUI + chilloutmax 模型

自2023年以来,AI绘图已经从兴趣娱乐逐渐步入实际应用,在众多的模型中,作为闪耀的一颗明星,Stable diffusion已经成为当前最多人使用且效果最好的开源AI绘图软件之一。Stable Diffusion Web UI 是由AUTOMATIC1111 开发的基于 Stabl…

力扣2090.半径为k的子数组平均值

力扣2090.半径为k的子数组平均值 accumulate函数&#xff1a;求一段和(起始迭代器&#xff0c;终止迭代器&#xff0c;初始值) class Solution {public:vector<int> getAverages(vector<int>& nums, int k) {int n nums.size();vector<int> res(n,-1…

Java密码复杂度实现

在Java中实现密码复杂度验证&#xff0c;通常需要考虑以下几个因素&#xff1a; 密码长度&#xff1b; 包含大写字母&#xff1b; 包含小写字母&#xff1b; 包含数字&#xff1b; 包含特殊字符&#xff08;可选&#xff09;。 以下是一个简单的Java类&#xff0c;用于验…

江苏服务器租用的优势有哪些?

随着互联网科技的快速发展&#xff0c;网络行业也逐渐开始兴起&#xff0c;而网络服务则离不开服务器的使用&#xff0c;那么江苏服务器租用对于其它地区来说都哪些优势呢&#xff1f; 江苏省是经济发展比较迅速的地区&#xff0c;所以江苏的企业对于网络方面的发展也是十分快速…

vue-cl-service不同环境运行/build配置

概述 在项目开发过程中&#xff0c;同一个项目在开发、测试、灰度、生产可能需要不同的配置信息&#xff0c;所以如果能根据环境的不同来设置参数很重要。 vue项目的vue-cl-service插件也支持不同环境的不同参数配置和打包。 实现 新建不同环境配置文件 vue项目中的配置文件以…

面向对象程序设计之从C到C++的初步了解

1. C语言 1. C的发展 C是从C语言发展演变而来的&#xff0c;首先是一个更好的C引入了类的机制&#xff0c;最初的C被称为“带类的C”1983年正式取名为C 从1989年开始C语言的标准化工作 于1994年制定了ANSIC标准草案 于1998年11月被国际标准化组织(ISO)批准为国际标准&#xf…

QT的窗口坐标和全局坐标

1、定义解释 窗口坐标&#xff1a;创建的窗口的坐标&#xff0c;以窗口左上角点为原点&#xff0c;横向往右为x轴正向&#xff0c;竖向往下为y轴正向。 全局坐标系&#xff1a;电脑屏幕的坐标系&#xff0c;以电脑屏幕左上角点为原点&#xff0c;横向往右为x轴正向&#xff0…

Ubuntu系统安装

目录 安装准备 安装步骤 虚拟机配置 系统安装 安装准备 Ubuntu系统镜像&#xff0c;虚拟机环境 虚拟机环境 使用的虚拟机软件为VMware Workstation 系统镜像 阿里镜像站&#xff1a;阿里巴巴开源镜像站-OPSX镜像站-阿里云开发者社区 (aliyun.com)https://developer.aliyun.com…

记一次使用mysql存储过程时,游标取值为空问题

call modify_collation(num,count_num) > 1146 - Table test.table_name doesnt exist > 时间: 0.009s 我在使用mysql存储过程时&#xff0c;打印时游标取值为空&#xff0c;报错找不到表。我的过程语句是这样的&#xff1a; drop procedure if exists modify_collation…