[Pytorch]手写数字识别——真·手写!

Github网址:https://github.com/diaoquesang/pytorchTutorials/tree/main

本教程创建于2023/7/31,几乎所有代码都有对应的注释,帮助初学者理解dataset、dataloader、transform的封装,初步体验调参的过程,初步掌握opencv、pandas、os等库的使用,😋纯手撸手写数字识别项目(为减少代码量简化了部分数据集相关操作),全流程跑通Pytorch!❤️❤️❤️
This tutorial was created on 2023/7/31. Almost all the code has corresponding comments, to help beginners understand dataset, dataloader, transform packaging, preliminary experience of the process of tuning the parameters, the initial grasp of the use of libraries such as opencv, pandas, os, etc., 😋 and get involved in this handwritten digit recognition project (we simplified some dataset-related operations in order to reduce the amount of code). Enjoy the whole process of running Pytorch!❤️❤️❤️

如果喜欢本项目的话,留下你的⭐吧!
Give me a ⭐ if you like this project!

一、train.py

import torch
import torchvisionfrom torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transformsimport os
import cv2 as cv
import pandas as pdclass myDataset(Dataset):  # 定义数据集类def __init__(self, annotations_file, img_dir, transform=None,target_transform=None):  # 传入参数(标签路径,图像路径,图像预处理方式,标签预处理方式)self.img_labels = pd.read_csv(annotations_file, sep=" ", header=None)# 从标签路径中读取标签,sep为划分间隔符,header为列标题的行位置self.img_dir = img_dir  # 读取图像路径self.transform = transform  # 读取图像预处理方式self.target_transform = target_transform  # 读取标签预处理方式def __len__(self):return len(self.img_labels)  # 读取标签数量作为数据集长度def __getitem__(self, idx):  # 从数据集中取出数据img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])# 从标签对象中取出第idx行第0列(第0列为图像位置所在列)的值(numberImages\5.bmp),并与图像路径(numberImages)进行拼接image = cv.imread(img_path)  # 用openCV的imread函数读取图像label = self.img_labels.iloc[idx, 1]  # 从标签对象中取出第idx行第1列(第1列为图像标签所在列)的值(5)if self.transform:image = self.transform(image)  # 图像预处理if self.target_transform:label = self.target_transform(label)  # 标签预处理return image, label  # 返回图像和标签class myTransformMethod1():  # Python3默认继承object类def __call__(self, img):  # __call___,让类实例变成一个可以被调用的对象,像函数img = cv.resize(img, (28, 28))  # 改变图像大小img = cv.cvtColor(img, cv.COLOR_BGR2RGB)  # 将BGR(openCV默认读取为BGR)改为RGBreturn img  # 返回预处理后的图像# 测试函数
# print(pd.read_csv("annotations.txt", sep=" ", header=None))
# print(os.path.join("numberImages", pd.read_csv("annotations.txt", sep=" ", header=None).iloc[5, 0]))
# print(pd.read_csv("annotations.txt", sep=" ", header=None).iloc[5, 1])
# cv.imshow("1",cv.imread(os.path.join("numberImages", pd.read_csv("annotations.txt", sep=" ", header=None).iloc[5, 0])))
# cv.waitKey(0)class myNetwork(nn.Module):  # 定义神经网络def __init__(self):super().__init__()  # 继承nn.Module的构造器self.flatten = nn.Flatten(-3, -1)# 继承nn.Module的Flatten函数并改为flatten,考虑到推理时没有batch(CHW),若使用默认值(1,-1)会导致C没有被flatten,故使用(-3,-1)self.linear_relu_stack = nn.Sequential(  # 定义前向传播序列nn.Linear(3 * 28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),)def forward(self, x):  # 定义前向传播方法x = self.flatten(x)logits = self.linear_relu_stack(x)return logits# 设置运行环境,默认为cuda,若cuda不可用则改为mps,若mps也不可用则改为cpu
device = ("cuda"if torch.cuda.is_available()else "mps"if torch.backends.mps.is_available()else "cpu"
)
print(f"Using {device} device")  # 输出运行环境model = myNetwork().to(device)  # 创建神经网络模型实例# 设置超参数
learning_rate = 1e-5  # 学习率
batch_size = 8  # 每批数据数量
epochs = 3000  # 总轮数img_path = "./numberImages"  # 设置图像路径
label_path = "./annotations.txt"  # 设置标签路径myTransform = transforms.Compose([myTransformMethod1(), transforms.ToTensor()])
# 定义图像预处理组合,ToTensor()中Pytorch将HWC(openCV默认读取为height,width,channel)改为CHW,并将值[0,255]除以255进行归一化[0,1]myDataset = myDataset(label_path, img_path, myTransform)  # 创建数据集实例myDataLoader = DataLoader(myDataset, batch_size=batch_size,shuffle=True)
# 创建数据读取器(可对训练集和测试集分别创建),batch_size为每批数据数量(一般为2的n次幂以提高运行速度),shuffle为随机打乱数据def train():# 根据epochs(总轮数)训练for epoch in range(epochs):totalLoss = 0# 分批读取数据for batch, (images, labels) in enumerate(myDataLoader):# 数据转换到对应运行环境images = images.to(device)labels = labels.to(device)pred = model(images)  # 前向传播myLoss = nn.CrossEntropyLoss()  # 定义损失函数(交叉熵)optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  # 定义优化器loss = myLoss(pred, labels)  # 计算损失函数totalLoss += loss  # 计入总损失函数loss.backward()  # 反向传播optimizer.step()  # 更新权重optimizer.zero_grad()  # 清空梯度if batch % 1 == 0:  # 每隔1个batch输出1次lossloss, current = loss.item(), min((batch + 1) * batch_size,len(myDataset))print(f"epoch: {epoch:>5d} loss: {loss:>7f}  [{current:>5d}/{len(myDataset):>5d}]")if epoch == 0:minTotalLoss = totalLossif totalLoss < minTotalLoss:print("······························模型已保存······························")minTotalLoss = totalLosstorch.save(model, "./myModel.pth")  # 保存性能最好的模型if __name__ == "__main__":model.train()  # 设置训练模式train()

二、eval.py

import torch
import torchvisionfrom torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transformsimport os
import cv2 as cv
import pandas as pdclass myTransformMethod1():  # Python3默认继承object类def __call__(self, img):  # __call___,让类实例变成一个可以被调用的对象,像函数img = cv.resize(img, (28, 28))  # 改变图像大小img = cv.cvtColor(img, cv.COLOR_BGR2RGB)  # 将BGR(openCV默认读取为BGR)改为RGBreturn img  # 返回预处理后的图像class myNetwork(nn.Module):  # 定义神经网络def __init__(self):super().__init__()  # 继承nn.Module的构造器self.flatten = nn.Flatten(-3, -1)# 继承nn.Module的Flatten函数并改为flatten,考虑到推理时没有batch(CHW),若使用默认值(1,-1)会导致C没有被flatten,故使用(-3,-1)self.linear_relu_stack = nn.Sequential(  # 定义前向传播序列nn.Linear(3 * 28 * 28, 512),nn.ReLU(),nn.Linear(512, 512),nn.ReLU(),nn.Linear(512, 10),)def forward(self, x):  # 定义前向传播方法x = self.flatten(x)logits = self.linear_relu_stack(x)return logitsif __name__ == "__main__":model = torch.load("./myModel.pth").to("cuda")  # 载入模型model.eval()  # 设置推理模式myTransform = transforms.Compose([myTransformMethod1(), transforms.ToTensor()])# 定义图像预处理组合,ToTensor()中Pytorch将HWC(openCV默认读取为height,width,channel)改为CHW,并将值[0,255]除以255进行归一化[0,1]for i in range(10):img = cv.imread("./numberImages/"+str(i)+".bmp")  # 用openCV的imread函数读取图像img = myTransform(img).to("cuda")  # 图像预处理print(torch.argmax(model(img)))

三、其余资料详见Github

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/22445.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

jsqlparser 安装和使用

jsqlparser是sql语句解析工具&#xff0c;可以解析sql并分析语法。 安装 <dependency><groupId>com.github.jsqlparser</groupId><artifactId>jsqlparser</artifactId><version>4.3</version> </dependency>使用 String s …

【语音合成】微软 edge-tts

目录 1. edge-tts 介绍 2. 代码示例 1. edge-tts 介绍 https://github.com/rany2/edge-tts 在Python代码中使用Microsoft Edge的在线文本到语音服务 2. 代码示例 import asyncio # pip install edge_tts import edge_tts TEXT """给我放首我喜欢听的歌曲…

React ~ React Router 6

React Router 6 VS React Router 5.x 内置组件的变化; 移除<Switch /> , 新增<Routes />语法的变化; component { About } 变为 element { <About /> }新增多个hook官方明确推荐函数式组件了! 一级路由(变化) 安装路由 npm i react-router-dom (默认是最…

【数据结构】快速排序

快速排序是一种高效的排序算法&#xff0c;其基本思想是分治法。它将一个大问题分解成若干个小问题进行解决&#xff0c;最后将这些解合并得到最终结果。 快速排序的主要思路如下&#xff1a; 选择一个基准元素&#xff1a;从待排序的数组中选择一个元素作为基准&#xff08;…

Spring指定bean在哪个应用加载

1.背景 某项目,spring架构,有2个不同的WebAppApplication入口,大部分service类共用,小部分类有区别,只需要在一个应用中加载,不需要在另一个应用中加载. 2.实现代码 自定义限制注解 package mis.shared.annotation;import java.lang.annotation.ElementType; import java.lan…

springboot+maven插件调用mybatis generator自动生成对应的mybatis.xml文件和java类

mybatis最繁琐的事就是sql语句和实体类&#xff0c;sql语句写在java文件里很难看&#xff0c;字段多的表一开始写感觉阻力很大&#xff0c;没有耐心&#xff0c;自动生成便成了最称心的做法。自动生成xml文件&#xff0c;dao接口&#xff0c;实体类&#xff0c;虽一直感觉不太优…

skywalking全链路追踪

文章目录 一、介绍二、全链路追踪1. 测试1 - 正常请求2. 测试2 - 异常请求 三、过滤非业务请求链路1. 链路忽略插件2. 配置3. 测试 一、介绍 在上一篇文章skywalking安装教程中我们介绍了skywalking的作用以及如何将其集成到我们的微服务项目中。本篇文章我们介绍在微服务架构…

没有jodatime,rust怎么方便高效的操作时间呢?

关注我&#xff0c;学习Rust不迷路&#xff01;&#xff01; 当使用Rust进行日期操作时&#xff0c;可以使用 chrono 库。下面给出了二十个常见的日期操作的例子&#xff1a; 1. 获取当前日期和时间&#xff1a; use chrono::prelude::*;let current_datetime Local::now()…

router 跳转打开新窗口

let url router.resolve({name: screen, })?.hrefwindow.open(url, _black)注意&#xff1a;新窗口无法全屏 参考链接&#xff1a;https://stackoverflow.com/questions/29281986/run-a-website-in-fullscreen-mode/30970886#30970886

数据库索引失效的情况

1.对添加了索引的字段进行函数运算 2.如果是字符串类型的字段&#xff0c;如果不加单引号也会导致索引失效 3.如果最索引字段使用模糊查询&#xff0c;如果是头部模糊索引将失效&#xff0c;如果是尾部模糊索引则正常 4.如果使用or分割符&#xff0c;如果or前面的条件中的列有…

nano命令

nano 被编辑的文件.txt等 输入xxxx crtl键X 输入yes 按enter保存 使用nano命令编辑文件或创建新文件&#xff0c;可按以下快捷键&#xff1a; 用法 光标控制 移动光标&#xff1a;使用用方向键移动。 选择文字&#xff1a;按住鼠标左键拖到。 复制、剪贴和粘贴 复制一整行&…

基于Yolov2深度学习网络的车辆检测算法matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 4.1. 卷积神经网络&#xff08;CNN&#xff09; 4.2. YOLOv2 网络 4.3. 实现过程 4.4. 应用领域 5.算法完整程序工程 1.算法运行效果图预览 2.算法运行软件版本 MATLAB2022A 3.部分核心…

使用css和js给按钮添加微交互的几种方式

使用css和js给按钮添加微交互的几种方式 在现实世界中&#xff0c;当我们轻弹或按下某些东西时&#xff0c;它们会发出咔嗒声&#xff0c;例如电灯开关。有些东西会亮起或发出蜂鸣声&#xff0c;这些响应都是“微交互”&#xff0c;让我们知道我们何时成功完成了某件事。在本文…

【Winform学习笔记(五)】引用自定义控件库(dll文件)

引用自定义控件库dll文件 前言正文1、生成dll文件2、选择工具箱项3、选择需要导入的dll文件4、确定需要导入的控件5、导入及使用 前言 在本文中主要介绍 如何引用自定义控件库(dll文件)。 正文 1、生成dll文件 通过生成解决方案 或 重新生成解决方案 生成 dll 文件 生成的…

小程序页面传递布尔值不起作用的解决方法

问题 传参&#xff1a; wx.navigateTo({url: ../mymeet/mymeeting-detail?isprincipaltrue, })以下方式使用时不起作用或出现问题&#xff1a; onLoad(options) {if (options.isprincipal) {...}//或者if (options.isprincipal true) {...} }原因 这种方式传参后isprinci…

如何发布自己的npm包

发布一个简单的npm包 首先创建一个文件夹&#xff08;唯一的命名&#xff09;创建package.json包&#xff0c;输出npm init&#xff0c;一直回车就好。创建index.js文件&#xff0c;向外暴露方法。 将包上传或更新到 npm 执行登录命令&#xff1a;npm login 登录npm官网&…

React Hooks 中的属性详解

React Hooks 是 React 16.8 版本中新增的特性&#xff0c;允许我们在不编写 class 的情况下使用 state 和其他的 React 特性。Hooks 是一种可以让你在函数组件中“钩入” React 特性的函数。以下是一些常用的 React Hooks&#xff0c;并附有详细的用法和代码示例。 1. useStat…

OPENCV C++(二)直方图+分离颜色通道+画圆画线画矩形

分离RGB彩图颜色通道 也就是把每种分量的亮度图提出来 vector<Mat> channels;split(image1, channels);Mat R channels.at(0);Mat G channels.at(1);Mat B channels.at(2); 这样R,G,B每个图就是这个图的颜色分量图了 图片的克隆&#xff0c;深拷贝&#xff01; Mat…

代码随想录算法训练营day56 583.两个字符串的删除操作 72.编辑距离

题目链接583.两个字符串的删除操作 class Solution {public int minDistance(String word1, String word2) {int len1 word1.length();int len2 word2.length();int[][] dp new int[len11][len22];for(int i 0; i <len1; i){dp[i][0] i;}for(int j 0; j <len2; j){…

正则表达式学习记录(Python)

正则表达式学习记录&#xff08;Python&#xff09; 一、特殊符号和字符 多个正则表达式匹配 &#xff08; | ) 用来分隔不同的匹配模式&#xff0c;相当于逻辑或&#xff0c;可以符合其中任何一个正则表达式 at | home # 表示匹配at或者home bat | bet | bit # 表示匹配bat或…