pyTorch框架使用CNN进行手写数字识别

目录

1.导包

2.torchvision数据处理的方法 

3.下载加载手写数字的训练数据集 

4.下载加载手写数字的测试数据集  

5. 将训练数据与测试数据 转换成dataloader

6.转成迭代器取数据 

7.创建模型 

8. 把model拷到GPU上面去

9. 定义损失函数

10. 定义优化器

11. 定义训练过程

12.最终运行测试 


 

1.导包

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
torch.__version__
'2.4.1+cu121'
# 检查GPU是否可用
torch.cuda.is_available()
True
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')device
device(type='cuda', index=0)

# pytorch中使用GPU进行训练的主要流程注意点:
#1. 把模型转移到GPU上. 
#2. 将每一批次的训练数据转移到GPU上.  

# torchvision 内置了常用的数据集和常见的模型. 
#使用pyTorch框架 整迁移学习时,可以从torchvision中加载出来 

2.torchvision数据处理的方法 

import torchvisionfrom torchvision import datasets, transforms

 # transforms.ToTensor  的作用如下:
# 1. 把数据转化为tensor
# 2. 数据的值转化为0到1之间. 
# 3. 会把channel放到第一个维度上.

# transforms用来做数据增强, 数据预处理等功能的. 
#Compose()可以将很多数据处理的方法组合起来, 用列表来组合
transformation = transforms.Compose([transforms.ToTensor(),])

3.下载加载手写数字的训练数据集 

#下载获取手写数字数据集MNIST  ,获取其中的训练数据集
#train=True表示获取训练数据集
train_ds = datasets.MNIST('./', train=True, transform=transformation, download=True)

datasets
<module 'torchvision.datasets' from 'D:\\anaconda3\\lib\\site-packages\\torchvision\\datasets\\__init__.py'>

4.下载加载手写数字的测试数据集  

# 下载获取 手写数字的 测试数据集
#train=False表示获取测试数据集
test_ds = datasets.MNIST('./', train=False, transform=transformation, download=True)

5. 将训练数据与测试数据 转换成dataloader

# 将训练数据与测试数据 转换成dataloader
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)   #shuffle=True表示打乱数据
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=256)  #测试数据时,不需要进行反向传播,可以将batch_size的值给大一些

6.转成迭代器取数据 

#可迭代对象 与 迭代器 不同
#train_dl是可迭代对象
#iter()将train_dl可迭代对象的数据 变成 迭代器
#next()从迭代器中取出一批数据
images, labels = next(iter(train_dl))
#tensorflow中图片的表现形式[batch, hight, width, channel]
# pytorch中图片的表现形式[batch, channel, hight, width]
images.shape
torch.Size([64, 1, 28, 28])
labels  #还没有one_hot编码的
tensor([1, 1, 2, 3, 2, 1, 8, 4, 5, 8, 4, 3, 0, 0, 4, 8, 2, 3, 3, 7, 3, 0, 5, 5,5, 6, 7, 2, 9, 4, 7, 9, 6, 7, 1, 4, 3, 9, 2, 4, 6, 4, 1, 1, 9, 2, 4, 7,7, 6, 2, 6, 8, 1, 3, 5, 4, 7, 5, 0, 6, 0, 9, 1])
img = images[0]   #取一张图的数据
img.shape  #一张图数据的形状   #三维数据 不方便可视化
img = img.numpy()  #可以先将数据转成numpy数据类型, 再进行数据降维,  再进行可视化
img = np.squeeze(img)  #数据降维,降一个维度,把只有1的维度降掉, 将形状变成(28, 28)
img.shape
(28, 28)
plt.imshow(img, cmap='gray')   #图片数据可视化

 

7.创建模型 

class Model(nn.Module):    #继承nn.Moduledef __init__(self):    #重写方法super().__init__() #继承父类的方法#无论经过什么层,batch_size一直保持不变(第一个数)#第一层卷积层#nn.Conv2d(输入的通道数, 自定义输出的通道数=这一层使用的卷积核的个数, 卷积核的大小)#输出的通道数=卷积核的个数(神经元个数)#nn.Conv2d()参数dilation=1(默认值),表示 不膨胀卷积#padding 默认为0#卷积核的大小为奇数时,padding=valid, 图片大小= ((原图片大小w - 卷积核的大小F)+ 1) /   步长s (此时步长默认steps = 1)self.conv1 = nn.Conv2d(1, 32, 3)# in: 64, 1, 28 , 28 -> out: 64, 32, 26, 26#池化层nn.MaxPool2d((卷积核的大小))   ,卷积核的大小可以用元组(2,2)表示,或者直接用一个数2表示   #strip步长默认为2#池化层的数值设置一般都一样,可重复使用,创建一次就行#卷积核的大小为偶数时,padding=same, 图片大小= 原图片大小w  / 步长s(此时的步长steps默认为2)#上一句化 等同于 经过一次池化层,原图片大小减半#经过池化层,输入通道数=上一层的输出通道数=上一层使用的卷积核个数self.pool = nn.MaxPool2d((2, 2)) # out: 64, 32, 13, 13
#         self.pool = nn.MaxPool2d(2)  #等同于上一行代码#第二层卷积核nn.Conv2d(上一层卷积的输出通道数, 自定义在这一层使用的卷积核的个数, 3)self.conv2 = nn.Conv2d(32, 64, 3)# in: 64, 32, 13, 13 -> out: 64, 64, 11, 11# 再加一层池化操作, in: 64, 64, 11, 11  --> out: 64, 64, 5, 5#第一层全连接层(输入通道数=上一层维度形状相乘,自定义输出通道数量=这一层使用的神经元数量)self.linear_1 = nn.Linear(64 * 5 * 5, 256)#第二层全连接层(输入通道数=上一层输出通道数,自定义输出通道数=问题分类数量(数字识别0~9,共10个数字))self.linear_2 = nn.Linear(256, 10)#定义前向传播def forward(self, input):#链式调用     #relu激活函数(第一层卷积层)  #调用自定义方法,需要加上self.x = F.relu(self.conv1(input))x = self.pool(x)           #第二层:池化层x = F.relu(self.conv2(x))  #第三层:卷积层,然后+ relu激活函数x = self.pool(x)           #第四层:池化层# flatten 展平, 进行维度变形x = x.view(-1, 64 * 5 * 5)   # (每批次具有64个样本, 特征数量)x = F.relu(self.linear_1(x)) #第五层:全连接层+激活函数relux = self.linear_2(x)       #第六层:输出层     #一般会在这一层之后添加一个sigmoid函数,这里没有加,后面需要处理一下return x

8. 把model拷到GPU上面去

model = Model()# 把model拷到GPU上面去
model.to(device)
Model((conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))(pool): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))(linear_1): Linear(in_features=1600, out_features=256, bias=True)(linear_2): Linear(in_features=256, out_features=10, bias=True)
)

9. 定义损失函数

#定义损失函数
loss_fn = torch.nn.CrossEntropyLoss()

10. 定义优化器

#定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)

11. 定义训练过程

#定义训练过程
def fit(epoch, model, train_loader, test_loader):#声明变量correct = 0      #预测准确的数量total = 0        #总共的样本数量running_loss = 0 #每运行一次的累计损失for x, y in train_loader:# 把CPU上的数据放到GPU上去. x, y = x.to(device), y.to(device)y_pred = model(x)loss = loss_fn(y_pred, y)optimizer.zero_grad()loss.backward()optimizer.step()with torch.no_grad():  #不计算梯度求导时#计算预测值   #argmax()是取最大值的索引, y_pred是10个预测数字的概率(包含10个概率数值),是二维数据类型   # dim=1表示取第二个维度上的索引:列上y_pred = torch.argmax(y_pred, dim=1)#计算预测准确的数量,  #.item()表示把sum()求和的聚合运算(bool值会直接用1或0表示)之后的 标量(一个具体的数)取出来correct += (y_pred == y).sum().item()total += y.size(0)  #总共样本数量running_loss += loss.item()  #每运行一次的累计损失epoch_loss = running_loss / len(train_loader.dataset)  #计算平均损失epoch_acc = correct / total  #计算准确率# 测试过程,不需要计算梯度求导test_correct = 0test_total = 0test_running_loss = 0with torch.no_grad():for x, y in test_loader:x, y = x.to(device), y.to(device)y_pred = model(x)loss = loss_fn(y_pred, y)y_pred = torch.argmax(y_pred, dim=1)test_correct += (y_pred == y).sum().item()test_total += y.size(0)test_running_loss += loss.item()test_epoch_loss = test_running_loss / len(test_loader.dataset)test_epoch_acc = test_correct / test_totalprint('epoch: ', epoch,'loss: ', round(epoch_loss, 3),    #3表示三位小数'accuracy: ', round(epoch_acc, 3),'test_loss: ', round(test_epoch_loss, 3),'test_accuracy: ', round(test_epoch_acc))return epoch_loss, epoch_acc, test_epoch_loss, test_epoch_acc

12.最终运行测试 

epochs = 20     #指定运行的次数
train_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(epochs):epoch_loss, epoch_acc, test_epoch_loss, test_epoch_acc = fit(epoch, model, train_dl, test_dl)train_loss.append(epoch_loss)train_acc.append(epoch_acc)test_loss.append(epoch_loss)test_acc.append(epoch_acc)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/74660.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

强化学习课程:stanford_cs234 学习笔记(3)introduction to RL

文章目录 前言7 markov 实践7.1 markov 过程再叙7.2 markov 奖励过程 MRP&#xff08;markov reward process&#xff09;7.3 markov 价值函数与贝尔曼方程7.4 markov 决策过程MDP&#xff08;markov decision process&#xff09;的 状态价值函数7.4.1 状态价值函数7.4.2 状态…

操作系统 4.5-文件使用磁盘的实现

通过文件进行磁盘操作入口 // 在fs/read_write.c中 int sys_write(int fd, const char* buf, int count) {struct file *file current->filp[fd];struct m_inode *inode file->inode;if (S_ISREG(inode->i_mode))return file_write(inode, file, buf, count); } 进程…

libreoffice-help-common` 的版本(`24.8.5`)与官方源要求的版本(`24.2.7`)不一致

出现此错误的原因主要是软件包依赖冲突&#xff0c;具体分析如下&#xff1a; ### 主要原因 1. **软件源版本不匹配&#xff08;国内和官方服务器版本有差距&#xff09; 系统中可能启用了第三方软件源&#xff08;如 PPA 或 backports 源&#xff09;&#xff0c;导致 lib…

使用Geotools中的原始方法来操作PostGIS空间数据库

目录 前言 一、原生PostGIS连接介绍 1、连接参数说明 2、创建DataStore 二、工程实战 1、Maven Pom.xml定义 2、空间数据库表 3、读取空间表的数据 三、总结 前言 在当今数字化与信息化飞速发展的时代&#xff0c;空间数据的处理与分析已成为众多领域不可或缺的一环。从…

讯飞语音合成(流式版)语音专业版高质量的分析

一、引言 在现代的 Web 应用开发中&#xff0c;语音合成技术为用户提供了更加便捷和人性化的交互体验。讯飞语音合成&#xff08;流式版&#xff09;以其高效、稳定的性能&#xff0c;成为了众多开发者的首选。本文将详细介绍在 Home.vue 文件中实现讯飞语音合成&#xff08;流…

走进未来的交互世界:下一代HMI设计趋势解析

在科技日新月异的今天&#xff0c;人机交互界面&#xff08;HMI&#xff09;设计正以前所未有的速度发展&#xff0c;不断引领着未来的交互世界。从简单的按钮和图标&#xff0c;到如今的智能助手和虚拟现实&#xff0c;HMI设计不仅改变了我们的生活方式&#xff0c;还深刻影响…

洛谷题单3-P1217 [USACO1.5] 回文质数 Prime Palindromes-python-流程图重构

题目描述 因为 151 151 151 既是一个质数又是一个回文数&#xff08;从左到右和从右到左是看一样的&#xff09;&#xff0c;所以 151 151 151 是回文质数。 写一个程序来找出范围 [ a , b ] ( 5 ≤ a < b ≤ 100 , 000 , 000 ) [a,b] (5 \le a < b \le 100,000,000…

学习笔记,DbContext context 对象是保存了所有用户对象吗

DbContext 并不会将所有用户对象保存在内存中&#xff1a; DbContext 是 Entity Framework Core (EF Core) 的数据库上下文&#xff0c;它是一个数据库访问的抽象层它实际上是与数据库的一个连接会话&#xff0c;而不是数据的内存缓存当您通过 _context.Users 查询数据时&…

本地命令行启动服务并连接MySQL8

启动服务命令 net start mysql8 关闭服务命令 net stop mysql8 本地连接MySQL数据库mysql -u [用户名] -p[密码] 这里&#xff0c;我遇到了个问题 —— 启动、关闭服务时&#xff0c;显示 “发生系统错误 5。拒绝访问。 ” 解法1&#xff1a;在 Windows 上以管理员身份打开…

数据蒸馏:Dataset Distillation by Matching Training Trajectories 论文翻译和理解

一、TL&#xff1b;DR 数据集蒸馏的任务是合成一个较小的数据集&#xff0c;使得在该合成数据集上训练的模型能够达到在完整数据集上训练的模型相同的测试准确率&#xff0c;号称优于coreset的选择方法本文中&#xff0c;对于给定的网络&#xff0c;我们在蒸馏数据上对其进行几…

【spring cloud Netflix】Ribbon组件

1.基本概念 SpringCloud Ribbon是基于Netflix Ribbon 实现的一套客户端负载均衡的工具。简单的说&#xff0c;Ribbon 是 Netflix 发布的开源项目&#xff0c;主要功能是提供客户端的软件负载均衡算法&#xff0c;将 Netflix 的中间层服务连接在一 起。Ribbon 的客户端组件提供…

P1036 [NOIP 2002 普及组] 选数(DFS)

题目描述 已知 n 个整数 x1​,x2​,⋯,xn​&#xff0c;以及 1 个整数 k&#xff08;k<n&#xff09;。从 n 个整数中任选 k 个整数相加&#xff0c;可分别得到一系列的和。例如当 n4&#xff0c;k3&#xff0c;4 个整数分别为 3,7,12,19 时&#xff0c;可得全部的组合与它…

在响应式网页的开发中使用固定布局、流式布局、弹性布局哪种更好

一、首先看下固定布局与流体布局的区别 &#xff08;一&#xff09;固定布局 固定布局的网页有一个固定宽度的容器&#xff0c;内部组件宽度可以是固定像素值或百分比。其容器元素不会移动&#xff0c;无论访客屏幕分辨率如何&#xff0c;看到的网页宽度都相同。现代网页设计…

二分查找与二叉树中序遍历——面试算法

目录 二分查找与分治 循环方式 递归方式 元素中有重复的二分查找 基于二分查找的拓展问题 山脉数组的顶峰索引——局部有序 旋转数字中的最小数字 找缺失数字 优化平方根 中序与搜索树 二叉搜索树中搜索特定值 验证二叉搜索树 有序数组转化为二叉搜索树 寻找两个…

字符串——面试考察高频算法题

目录 转换成小写字母 字符串转化为整数 反转相关的问题 反转字符串 k个一组反转 仅仅反转字母 反转字符串里的单词 验证回文串 判断是否互为字符重排 最长公共前缀 字符串压缩问题 转换成小写字母 给你一个字符串 s &#xff0c;将该字符串中的大写字母转换成相同的…

现代复古电影海报品牌徽标设计衬线英文字体安装包 Thick – Retro Vintage Cinematic Font

Thick 是一种大胆的复古字体&#xff0c;专为有影响力的标题和怀旧的视觉效果而设计。其厚实的字体、复古魅力和电影风格使其成为电影海报、产品标签、活动品牌和编辑设计的理想选择。无论您是在引导电影的黄金时代&#xff0c;还是在现代布局中注入复古活力&#xff0c;Thick …

[C++面试] new、delete相关面试点

一、入门 1、说说new与malloc的基本用途 int* p1 (int*)malloc(sizeof(int)); // C风格 int* p2 new int(10); // C风格&#xff0c;初始化为10 new 是 C 中的运算符&#xff0c;用于在堆上动态分配内存并调用对象的构造函数&#xff0c;会自动计算所需内存…

Unity URP管线与HDRP管线对比

1. 渲染架构与底层技术 URP 渲染路径&#xff1a; 前向渲染&#xff08;Forward&#xff09;&#xff1a;默认单Pass前向&#xff0c;支持少量实时光源&#xff08;通常4-8个逐物体&#xff09;。 延迟渲染&#xff08;Deferred&#xff09;&#xff1a;可选但功能简化&#…

JDK8卸载与安装教程(超详细)

JDK8卸载与安装教程&#xff08;超详细&#xff09; 最近学习一个项目&#xff0c;需要使用更高级的JDK&#xff0c;这里记录一下卸载旧版本与安装新版本JDK的过程。 JDK8卸载 以windows10操作系统为例&#xff0c;使用快捷键winR输入cmd&#xff0c;打开控制台窗口&#xf…

python爬虫:DrissionPage实战教程

如果本文章看不懂可以看看上一篇文章&#xff0c;加强自己的基础&#xff1a;爬虫自动化工具&#xff1a;DrissionPage-CSDN博客 案例解析&#xff1a; 前提&#xff1a;我们以ChromiumPage为主&#xff0c;写代码工具使用Pycharm&#xff08;python环境3.9-3.10&#xff09; …