手写数据集minist基于pytorch分类学习

1.Mnist数据集介绍
1.1 基本介绍
Mnist数据集可以算是学习深度学习最常用到的了。这个数据集包含70000张手写数字图片,分别是60000张训练图片和10000张测试图片,训练集由来自250个不同人手写的数字构成,一般来自高中生,一半来自工作人员,测试集(test set)也是同样比例的手写数字数据,并且保证了测试集和训练集的作者不同。每个图片都是2828个像素点,数据集会把一张图片的数据转成一个2828=784的一维向量存储起来。
里面的图片数据如下所示,每张图是0-9的手写数字黑底白字的图片,存储时,黑色用0表示,白色用0-1的浮点数表示。


1.2 数据集下载
1)官网下载
Mnist数据集的下载地址如下:http://yann.lecun.com/exdb/mnist/
打开后会有四个文件:


训练数据集:train-images-idx3-ubyte.gz
训练数据集标签:train-labels-idx1-ubyte.gz
测试数据集:t10k-images-idx3-ubyte.gz
测试数据集标签:t10k-labels-idx1-ubyte.gz
将这四个文件下载后放置到需要用的文件夹下即可不要解压!下载后是什么就怎么放!

2)代码导入
文件夹下运行下面的代码,即可自动检测数据集是否存在,若没有会自动进行下载,下载后在这一路径:

下载数据集:

# 下载数据集
from torchvision import datasets, transformstrain_set = datasets.MNIST("data",train=True,download=True, transform=transforms.ToTensor(),)
test_set = datasets.MNIST("data",train=False,download=True, transform=transforms.ToTensor(),)

参数解释:

datasets.MNIST:是Pytorch的内置函数torchvision.datasets.MNIST,可以导入数据集
train=True :读入的数据作为训练集
transform:读入我们自己定义的数据预处理操作
download=True:当我们的根目录(root)下没有数据集时,便自动下载
如果这时候我们通过联网自动下载方式download我们的数据后,它的文件路径是以下形式:原文件夹/data/MNIST/raw

14轮左右,模型识别准确率达到98%以上

 

 加载数据集

import os.path
import matplotlib.pyplot as plt
import torch
from torchvision.datasets import MNIST
from PIL import Image
from torch.utils.data import Dataset,DataLoader
from torchvision import datasets, transforms
# 下载数据集
from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor(), # 将灰度图片像素值(0~255)转为Tensor(0~1),方便后续处理transforms.Normalize((0.1307,),(0.3081,))# 归一化,均值0,方差1;mean:各通道的均值std:各通道的标准差inplace:是否原地操作
])train_data = MNIST(root='./minist_data',train=True,download=False,transform=transform)
train_loader = DataLoader(dataset=train_data,shuffle=True,batch_size=64)
test_data = MNIST(root='./minist_data',train=False,download=False,transform=transform)
test_loader = DataLoader(dataset=test_data,shuffle=True,batch_size=64)# train_data返回的是很多张图,每一张图是一个元组,包含图片和对应的数字
# print(test_data[0])
# print(train_data[0][0].show())train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度:{}".format(train_data_size))
print("测试数据集的长度:{}".format(test_data_size))

构建模型,模型主要由两个卷积层,两个池化层,以及一个全连接层构成,激活函数使用relu. 

 

class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()self.conv1 = torch.nn.Conv2d(in_channels=1,out_channels=10,stride=1,kernel_size=5,padding=0)self.maxpool1 = torch.nn.MaxPool2d(2)self.conv2 = torch.nn.Conv2d(in_channels=10,out_channels=20,kernel_size=5,stride=1,padding=0)self.maxpool2 = torch.nn.MaxPool2d(2)self.linear = torch.nn.Linear(320,10)def forward(self,x):x = torch.relu(self.conv1(x))x = self.maxpool1(x)x = torch.relu(self.conv2(x))x = self.maxpool2(x)x = x.view(x.size(0),-1)x = self.linear(x)return x
model = Model()criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.14)# 交叉熵损失,相当于Softmax+Log+NllLoss
# 线性多分类模型Softmax,给出最终预测值对于10个类别出现的概率,Log:将乘法转换为加法,减少计算量,保证函数的单调性
# NLLLoss:计算损失,此过程不需要手动one-hot编码,NLLLoss会自动完成
# SGD,优化器,梯度下降算法e

模型训练
每次训练完成后会自动保存参数到pkl模型中,如果路径中有Pkl文件,下次运行会自动加载上一次的模型参数,在这个基础上继续训练,第一次运行时没有模型参数,结束后会自动生成。

# 模型训练
def train():# index = 0for index, data in enumerate(train_loader):  # 获取训练数据以及对应标签# for data in train_loader:input, target = data  # input为输入数据,target为标签y_predict = model(input)  # 模型预测loss = criterion(y_predict, target)optimizer.zero_grad()  # 梯度清零loss.backward()  # loss值反向传播optimizer.step()  # 更新参数# index += 1if index % 100 == 0:  # 每一百次保存一次模型,打印损失torch.save(model.state_dict(), "model.pkl")  # 保存模型torch.save(optimizer.state_dict(), "optimizer.pkl")print("训练次数为:{},损失值为:{}".format(index, loss.item()))

加载模型
第一次运行这里需要一个空的model文件夹

if os.path.exists('model.pkl'):model.load_state_dict(torch.load("model.pkl"))

模型测试

def test():correct = 0total = 0with torch.no_grad():for index,data in enumerate(test_loader):inputs,target = dataoutput = model(inputs)probability,predict = torch.max(input=output.data, dim=1)total += target.size(0)  # target是形状为(batch_size,1)的矩阵,使用size(0)取出该批的大小correct += (predict == target).sum().item()  # predict 和target均为(batch_size,1)的矩阵,sum求出相等的个数print("测试准确率为:%.6f" % (correct / total))

自己手写数字图片识别函数(可选用)
这部分主要是加载训练好的pkl模型测试自己的数据,因此在进行自己手写图的测试时,需要有训练好的pkl文件,并且就不要调用train()函数和test()函数啦注意:这个图片像素也要说黑底白字,28*28像素,否则无法识别

def test_mydata():image = Image.open('5fd4e4c2c99a24e3e27eb9b2ee3b053c.jpg')  # 读取自定义手写图片image = image.resize((28, 28))  # 裁剪尺寸为28*28image = image.convert('L')  # 转换为灰度图像transform = transforms.ToTensor()image = transform(image)image = image.resize(1, 1, 28, 28)output = model(image)probability, predict = torch.max(output.data, dim=1)print("此手写图片值为:%d,其最大概率为:%.2f " % (predict[0], probability))plt.title("此手写图片值为:{}".format((int(predict))), fontname='SimHei')plt.imshow(image.squeeze())plt.show()

MNIST中的数据识别测试数据
训练过程中的打印信息我进行了修改,这里设置的训练轮数是15轮,每次训练生成的pkl模型参数也是会更新的,想要更多训练信息可以查看对应的教程哦~

if __name__ == '__main__':# 训练与测试for i in range(15):  # 训练和测试进行5轮print({"————————第{}轮测试开始——————".format(i + 1)})train()test()test_mydata()

完整代码:

import os.path
import matplotlib.pyplot as plt
import torch
from torchvision.datasets import MNIST
from PIL import Image
from torch.utils.data import Dataset,DataLoader
from torchvision import datasets, transforms
# 下载数据集
from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor(), # 将灰度图片像素值(0~255)转为Tensor(0~1),方便后续处理transforms.Normalize((0.1307,),(0.3081,))# 归一化,均值0,方差1;mean:各通道的均值std:各通道的标准差inplace:是否原地操作
])train_data = MNIST(root='./minist_data',train=True,download=False,transform=transform)
train_loader = DataLoader(dataset=train_data,shuffle=True,batch_size=64)
test_data = MNIST(root='./minist_data',train=False,download=False,transform=transform)
test_loader = DataLoader(dataset=test_data,shuffle=True,batch_size=64)# train_data返回的是很多张图,每一张图是一个元组,包含图片和对应的数字
# print(test_data[0])
# print(train_data[0][0].show())train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度:{}".format(train_data_size))
print("测试数据集的长度:{}".format(test_data_size))class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()self.conv1 = torch.nn.Conv2d(in_channels=1,out_channels=10,stride=1,kernel_size=5,padding=0)self.maxpool1 = torch.nn.MaxPool2d(2)self.conv2 = torch.nn.Conv2d(in_channels=10,out_channels=20,kernel_size=5,stride=1,padding=0)self.maxpool2 = torch.nn.MaxPool2d(2)self.linear = torch.nn.Linear(320,10)def forward(self,x):x = torch.relu(self.conv1(x))x = self.maxpool1(x)x = torch.relu(self.conv2(x))x = self.maxpool2(x)x = x.view(x.size(0),-1)x = self.linear(x)return x
model = Model()criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.14)# 交叉熵损失,相当于Softmax+Log+NllLoss
# 线性多分类模型Softmax,给出最终预测值对于10个类别出现的概率,Log:将乘法转换为加法,减少计算量,保证函数的单调性
# NLLLoss:计算损失,此过程不需要手动one-hot编码,NLLLoss会自动完成
# SGD,优化器,梯度下降算法e# 模型训练
def train():# index = 0for index, data in enumerate(train_loader):  # 获取训练数据以及对应标签# for data in train_loader:input, target = data  # input为输入数据,target为标签y_predict = model(input)  # 模型预测loss = criterion(y_predict, target)optimizer.zero_grad()  # 梯度清零loss.backward()  # loss值反向传播optimizer.step()  # 更新参数# index += 1if index % 100 == 0:  # 每一百次保存一次模型,打印损失torch.save(model.state_dict(), "model.pkl")  # 保存模型torch.save(optimizer.state_dict(), "optimizer.pkl")print("训练次数为:{},损失值为:{}".format(index, loss.item()))if os.path.exists('model.pkl'):model.load_state_dict(torch.load("model.pkl"))def test():correct = 0total = 0with torch.no_grad():for index,data in enumerate(test_loader):inputs,target = dataoutput = model(inputs)probability,predict = torch.max(input=output.data, dim=1)total += target.size(0)  # target是形状为(batch_size,1)的矩阵,使用size(0)取出该批的大小correct += (predict == target).sum().item()  # predict 和target均为(batch_size,1)的矩阵,sum求出相等的个数print("测试准确率为:%.6f" % (correct / total))def test_mydata():image = Image.open('5fd4e4c2c99a24e3e27eb9b2ee3b053c.jpg')  # 读取自定义手写图片image = image.resize((28, 28))  # 裁剪尺寸为28*28image = image.convert('L')  # 转换为灰度图像transform = transforms.ToTensor()image = transform(image)image = image.resize(1, 1, 28, 28)output = model(image)probability, predict = torch.max(output.data, dim=1)print("此手写图片值为:%d,其最大概率为:%.2f " % (predict[0], probability))plt.title("此手写图片值为:{}".format((int(predict))), fontname='SimHei')plt.imshow(image.squeeze())plt.show()if __name__ == '__main__':# 训练与测试for i in range(15):  # 训练和测试进行5轮print({"————————第{}轮测试开始——————".format(i + 1)})train()test()test_mydata()

 

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/21764.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MTU相关随笔

一、MTU的概念 MTU(最大传输单元):用来通知对方所能接受数据服务单元的最大尺寸,说明发送方能够接受的有效载荷大小。MTU是包或帧的最大长度,一般以字节记,如果过大在碰到路由器时会被拒绝转发&#xff0c…

SpringBoot项目本地运行正常,jar包运行时前端报错403:No mapping for......

SpringBoot项目本地运行正常,jar包运行时前端报错403:No mapping for… 提示:在部署jar包到云服务器上之前,一定要在本地运行jar包,查看前端代码是否运行正常,若报错的话可以节省很多时间 方式:…

友顺科技(UTC)分立器件与集成IC产品选型和应用

友顺科技股份有限公司成立于1990年,是全球领先的集成电路与功率半导体厂商 ,集团总部位于台北,生产基地位于福州、厦门。 友顺科技具有完整模拟组件产品线,其中类比IC涵盖各种稳压器、PWM控制IC, 放大器、比较器、逻辑IC、Voltage Translato…

基于飞腾 D2000 8 核+ 32G DDR+板载 6 千兆电口+ 4 千兆光口高性能网络安全主板

第一章、产品介绍 1.1 产品概述 XM-D2000GW是一款基于飞腾 D2000 8 核X100 桥片高性能网络安全主板,D2000 为飞腾首款支持 8 核桌面平 台处理器,支持双通道 DDR4-2666 内存,芯片内置国密 SM2/SM3/SM4/SM9 加速引擎,支持单精度、双…

gitee和github的协同

假设gitee上zhaodezan有一个开发库,但是从andeyeluguo上拉取最新的(从github上同步过来最新的) git remote add dbgpt_in_gitee https://gitee.com/andeyeluguo/DB-GPT.git remote -v git pull --rebase dbgpt_in_gitee main 有冲突可能需要…

【调试笔记-20240603-Linux-在 OpenWrt-23.05 上运行 ipkg-build 生成. ipk 安装包】

调试笔记-系列文章目录 调试笔记-20240603-Linux-在 OpenWrt-23.05 上运行 ipkg-build 生成. ipk 安装包 文章目录 调试笔记-系列文章目录调试笔记-20240603-Linux-在 OpenWrt-23.05 上运行 ipkg-build 生成. ipk 安装包 前言一、调试环境操作系统:Windows 10 专业…

Android11 AudioTrack和Track建立联系

应用程序创建AudioTrack时,导致AudioFlinger在播放线程中,创建Track和其对应。那它们之间是通过什么来建立联系传递数据的?答案是共享内存。 创建Track时,导致其父类TrackBase的构造函数被调用 //frameworks/av/services/audiofl…

数字化时代还需要传统智慧图书馆吗

尽管以电子阅览室代表的数字化时代带来了许多便利和创新,但传统智慧图书馆依然具有重要的价值和意义。以下是一些原因: 1. 保存历史文化:传统智慧图书馆是保存历史文化遗产的重要载体,收藏了许多珍贵的古籍、手稿和纸质图书&#…

基于 Amazon EC2 快速部署 Stable Diffusion WebUI + chilloutmax 模型

自2023年以来,AI绘图已经从兴趣娱乐逐渐步入实际应用,在众多的模型中,作为闪耀的一颗明星,Stable diffusion已经成为当前最多人使用且效果最好的开源AI绘图软件之一。Stable Diffusion Web UI 是由AUTOMATIC1111 开发的基于 Stabl…

vue-cl-service不同环境运行/build配置

概述 在项目开发过程中,同一个项目在开发、测试、灰度、生产可能需要不同的配置信息,所以如果能根据环境的不同来设置参数很重要。 vue项目的vue-cl-service插件也支持不同环境的不同参数配置和打包。 实现 新建不同环境配置文件 vue项目中的配置文件以…

面向对象程序设计之从C到C++的初步了解

1. C语言 1. C的发展 C是从C语言发展演变而来的,首先是一个更好的C引入了类的机制,最初的C被称为“带类的C”1983年正式取名为C 从1989年开始C语言的标准化工作 于1994年制定了ANSIC标准草案 于1998年11月被国际标准化组织(ISO)批准为国际标准&#xf…

Ubuntu系统安装

目录 安装准备 安装步骤 虚拟机配置 系统安装 安装准备 Ubuntu系统镜像,虚拟机环境 虚拟机环境 使用的虚拟机软件为VMware Workstation 系统镜像 阿里镜像站:阿里巴巴开源镜像站-OPSX镜像站-阿里云开发者社区 (aliyun.com)https://developer.aliyun.com…

记一次使用mysql存储过程时,游标取值为空问题

call modify_collation(num,count_num) > 1146 - Table test.table_name doesnt exist > 时间: 0.009s 我在使用mysql存储过程时,打印时游标取值为空,报错找不到表。我的过程语句是这样的: drop procedure if exists modify_collation…

Redis中大Key与热Key的解决方案

原文地址:https://mp.weixin.qq.com/s/13p2VCmqC4oc85h37YoBcg 在工作中Redis已经成为必备的一款高性能的缓存数据库,但是在实际的使用过程中,我们常常会遇到两个常见的问题,也就是文章标题所说的大 key与热 key。 一、定义 1.1…

THS6011启动控制台后无法使用https访问控制台(by yz+lqw)

原因: 6011相对于6010版本,多了一个ssl的开关,下图是6010版本的参考配置: 而6011版本下的conf目录下的http.yaml,里面的ssl开关,默认是关闭的,也就是enable:false. 所以需要把enable&#xf…

构建LangChain应用程序的示例代码:9、使用Anthropic API生成结构化输出的工具教程

使用Anthropic API生成结构化输出的工具 Anthropic API最近增加了工具使用功能。 这对于生成结构化输出非常有用。 ! pip install -U langchain-anthropic可选配置: import osos.environ[LANGCHAIN_TRACING_V2] true # 启用追踪 os.environ[LANGCHAIN_API_KEY…

echarts-series的x,y轴的规则

series的data与x,y轴的匹配规则 如果series的data为[1,2,3,4,5,6] 1.如果x,y轴都是类目轴,且data没有与x,y轴的值匹配上,则无效。 2.如果x,y轴都为类目,data中能够跟类目轴上的字符串对应上,轴,有效。 3.如果都为value.,则按数值…

【贪心算法·哈夫曼编码问题】从定长编码和不定长编码讲到最小化带权路径长度和

一、问题介绍 1.1:编码问题 首先,我们知道,数字字符等任何数据的底层,都是以二进制(0,1序列)的方式存储在计算机内的。 对于“编码”其实就是那些能显示在计算机屏幕上的:不同字母、汉字、字…

半导体光子电学期末笔记2: 光子晶体 Photonic crystals

光子晶体概述 光子晶体定义和分类 [P4-5] 光子晶体是一种在一维、二维或三维空间内周期性排列的多层介质。这些结构通过在光子尺度上排列的重复单元,可以对光进行调控和控制。具体来说,光子晶体是指那些在空间上具有周期性排列的介质结构,它…

【深度学习】温故而知新4-手写体识别-多层感知机+CNN网络-完整代码-可运行

多层感知机版本 import torch import torch.nn as nn import numpy as np import torch.utils from torch.utils.data import DataLoader, Dataset import torchvision from torchvision import transforms import matplotlib.pyplot as plt import matplotlib import os # 前…