深度学习手写字符识别:训练模型

说明

本篇博客主要是跟着B站中国计量大学杨老师的视频实战深度学习手写字符识别。
第一个深度学习实例手写字符识别

深度学习环境配置

可以参考下篇博客,网上也有很多教程,很容易搭建好深度学习的环境。
Windows11搭建GPU版本PyTorch环境详细过程

数据集

手写字符识别用到的数据集是MNIST数据集(Mixed National Institute of Standards and Technology database);MNIST是一个用来训练各种图像处理系统二进制图像数据集,广泛应用到机器学习中的训练和测试。
作为一个入门级的计算机视觉数据集,发布20多年来,它已经被无数机器学习入门者应用无数遍,是最受欢迎的深度学习数据集之一。

序号说明
发布方National Institute of Standards and Technology(美国国家标准技术研究所,简称NIST)
发布时间1998
背景该数据集的论文想要证明在模式识别问题上,基于CNN的方法可以取代之前的基于手工特征的方法,所以作者创建了一个手写数字的数据集,以手写数字识别作为例子证明CNN在模式识别问题上的优越性。
简介MNIST数据集是从NIST的两个手写数字数据集:Special Database 3 和Special Database 1中分别取出部分图像,并经过一些图像处理后得到的。MNIST数据集共有70000张图像,其中训练集60000张,测试集10000张。所有图像都是28×28的灰度图像,每张图像包含一个手写数字。

跟着视频跑源码

  1. 下载源码:mivlab/AI_course (github.com)
  2. 下载数据集:https://opendatalab.com/MNIST;网上下载的地址比较多,也可以直接下载B站中国计量大学杨老师的百度网盘位置里的MNIST。

运行源码

  1. 在Pycharm中打开AI_course项目,运行classify_pytorch文件目录里train_mnist.py的Python文件。
    在这里插入图片描述
    train_mnist.py具体的源码如下:
import torch
import math
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms, models
import argparse
import os
from torch.utils.data import DataLoaderfrom dataloader import mnist_loader as ml
from models.cnn import Net
from toonnx import to_onnxparser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--datapath', required=True, help='data path')
parser.add_argument('--batch_size', type=int, default=256, help='training batch size')
parser.add_argument('--epochs', type=int, default=300, help='number of epochs to train')
parser.add_argument('--use_cuda', default=False, help='using CUDA for training')args = parser.parse_args()
args.cuda = args.use_cuda and torch.cuda.is_available()
if args.cuda:torch.backends.cudnn.benchmark = Truedef train():os.makedirs('./output', exist_ok=True)if True: #not os.path.exists('output/total.txt'):ml.image_list(args.datapath, 'output/total.txt')ml.shuffle_split('output/total.txt', 'output/train.txt', 'output/val.txt')train_data = ml.MyDataset(txt='output/train.txt', transform=transforms.ToTensor())val_data = ml.MyDataset(txt='output/val.txt', transform=transforms.ToTensor())train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True)val_loader = DataLoader(dataset=val_data, batch_size=args.batch_size)model = Net(10)#model = models.vgg16(num_classes=10)#model = models.resnet18(num_classes=10)  # 调用内置模型#model.load_state_dict(torch.load('./output/params_10.pth'))#from torchsummary import summary#summary(model, (3, 28, 28))if args.cuda:print('training with cuda')model.cuda()optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-3)scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [20, 30], 0.1)loss_func = nn.CrossEntropyLoss()for epoch in range(args.epochs):# training-----------------------------------model.train()train_loss = 0train_acc = 0for batch, (batch_x, batch_y) in enumerate(train_loader):if args.cuda:batch_x, batch_y = Variable(batch_x.cuda()), Variable(batch_y.cuda())else:batch_x, batch_y = Variable(batch_x), Variable(batch_y)out = model(batch_x)  # 256x3x28x28  out 256x10loss = loss_func(out, batch_y)train_loss += loss.item()pred = torch.max(out, 1)[1]train_correct = (pred == batch_y).sum()train_acc += train_correct.item()print('epoch: %2d/%d batch %3d/%d  Train Loss: %.3f, Acc: %.3f'% (epoch + 1, args.epochs, batch, math.ceil(len(train_data) / args.batch_size),loss.item(), train_correct.item() / len(batch_x)))optimizer.zero_grad()loss.backward()optimizer.step()scheduler.step()  # 更新learning rateprint('Train Loss: %.6f, Acc: %.3f' % (train_loss / (math.ceil(len(train_data)/args.batch_size)),train_acc / (len(train_data))))# evaluation--------------------------------model.eval()eval_loss = 0eval_acc = 0for batch_x, batch_y in val_loader:if args.cuda:batch_x, batch_y = Variable(batch_x.cuda()), Variable(batch_y.cuda())else:batch_x, batch_y = Variable(batch_x), Variable(batch_y)out = model(batch_x)loss = loss_func(out, batch_y)eval_loss += loss.item()pred = torch.max(out, 1)[1]num_correct = (pred == batch_y).sum()eval_acc += num_correct.item()print('Val Loss: %.6f, Acc: %.3f' % (eval_loss / (math.ceil(len(val_data)/args.batch_size)),eval_acc / (len(val_data))))# 保存模型。每隔多少帧存模型,此处可修改------------if (epoch + 1) % 1 == 0:# torch.save(model, 'output/model_' + str(epoch+1) + '.pth')torch.save(model.state_dict(), 'output/params_' + str(epoch + 1) + '.pth')#to_onnx(model, 3, 28, 28, 'params.onnx')if __name__ == '__main__':train()
  1. 报错:没有cv2,即没有安装OpenCV库。
    在这里插入图片描述
  2. 安装OpenCV库,可以命令行安装,也可以Pycharm中安装。
  • 命令行激活虚拟环境:conda activate deeplearning
  • 命令行安装: pip install opencv-python(也可以Pycharm中下载,可能上梯子安装更快)
    在这里插入图片描述
  1. 再次运行,出现如下图提示,表明需要将下载好的数据集配置到configure中。
    在这里插入图片描述
  2. 加载下载好的数据集,即--datapath=数据集的路径
    在这里插入图片描述
  3. 点击“Run”,开始训练,损失和准确率在一直更新,持续训练,直到模型完成,未改动源码的情况下,训练时间可能需要较长。
    在这里插入图片描述
  4. 在小编的拯救者笔记本电脑上持续训练了10小时才完成最终的模型训练,可以看到训练损失已经很低了,准确度很高水平。
    在这里插入图片描述
  5. 在项目中output文件夹中可以看到已经训练好了很多模型;后面可以利用模型进行推理了。
    在这里插入图片描述

参考

https://zhuanlan.zhihu.com/p/681236488

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/667383.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vcruntime140.dll最新的修复方法,一键修复vcruntime140.dll的手段

在这篇文章中,我们将深入探讨并详细介绍各种修复vcruntime140.dll文件缺失或损坏问题的方法。鉴于此类问题广泛存在并影响了众多用户,本文目的是向大家展示不同的修复策略,希望能够帮助每个人解决这些棘手的技术难题。下面一起来看看vcruntim…

【RT-DETR有效改进】UNetv2提出的一种SDI多层次特征融合模块(细节高效涨点)

👑欢迎大家订阅本专栏,一起学习RT-DETR👑 一、本文介绍 本问给大家带来的改进机制是UNetv2提出的一种多层次特征融合模块(SDI)其是一种用于替换Concat操作的模块,SDI模块的主要思想是通过整合编码器生成的层级特征图来增强图像中的语义信息和细节信息。包括皮肤…

黑豹程序员-ElementPlus选择图标器

ElementPlus组件提供了很多图标svg 如何在你的系统中&#xff0c;用户可以使用呢&#xff1f; 这就是图标器&#xff0c;去调用ElementPlus的icon组件库&#xff0c;展示到页面&#xff0c;用户选择&#xff0c;返回选择的组件名称。 效果 代码 <template><el-inpu…

机器学习 - 梯度下降

场景 上一章学习了代价函数&#xff0c;在机器学习中&#xff0c;代价模型是用于衡量模型预测值与真实值之间的差异的函数。它是优化算法的核心&#xff0c;目标是通过调整模型的参数来最小化代价模型的值&#xff0c;从而使模型的预测结果更接近真实值。常见的代价模型是均方…

【Boost】:searcher的建立(四)

searcher的建立 一.初始化二.搜索功能三.完整源代码 sercher主要分为两部分&#xff1a;初始化和查找。 一.初始化 初始化分为两步&#xff1a;1.创建Index对象&#xff1b;2.建立索引 二.搜索功能 搜索分为四个步骤 分词&#xff1b;触发&#xff1a;根据分词找到对应的文档…

架构设计特训

一、考点分布 软件架构风格&#xff08;※※※※&#xff09;层次型软件架构风格&#xff08;※※※※&#xff09;面向服务的软件架构风格&#xff08;※※※※&#xff09;云原生架构风格&#xff08;※※※※&#xff09;质量属性与架构评估&#xff08;※※※※※&#xff…

Transformer实战-系列教程1:Transformer算法解读1

&#x1f6a9;&#x1f6a9;&#x1f6a9;Transformer实战-系列教程总目录 有任何问题欢迎在下面留言 Transformer实战-系列教程1&#xff1a;Transformer算法解读1 Transformer实战-系列教程2&#xff1a;Transformer算法解读2 现在最火的AI内容&#xff0c;chatGPT、视觉大模…

Golang切片与数组

在Go语言中&#xff0c;切片&#xff08;Slice&#xff09;和数组&#xff08;Array&#xff09;是两个核心的数据结构&#xff0c;它们在内存管理、灵活性以及性能方面有着显著的区别。接下来将解析Golang中的切片与数组&#xff0c;通过清晰的概念解释、案例代码和实际应用场…

小林Coding_操作系统_读书笔记

一、硬件结构 1. CPU是如何执行的 冯诺依曼模型&#xff1a;中央处理器&#xff08;CPU&#xff09;、内存、输入设备、输出设备、总线 CPU中&#xff1a;寄存器&#xff08;程序计数器、通用暂存器、指令暂存器&#xff09;&#xff0c;控制单元&#xff08;控制CPU工作&am…

[word] word页面视图放大后,影响打印吗? #笔记#学习方法

word页面视图放大后&#xff0c;影响打印吗&#xff1f; word文档的页面视图又叫普通视图&#xff0c;又叫打印视图&#xff0c;是系统默认的视图&#xff0c;是用户用的最多最常见的视图。 问&#xff1a;怎样打开页面视图&#xff1f; 答&#xff1a;两种方法 方法一、点…

JS 基本语句

函数调用&#xff0c;分支&#xff0c;循环&#xff0c;语句示例。 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"&g…

【Funny guys】龙年专属测试鼠标寿命小游戏...... 用Python给大家半年了......

目录 【Funny guys】龙年专属测试鼠标寿命小游戏...... 用Python给大家半年了...... 龙年专属测试鼠标寿命小游戏用Python给大家半年了贪吃龙游戏 文章所属专区 码农新闻 欢迎各位编程大佬&#xff0c;技术达人&#xff0c;以及对编程充满热情的朋友们&#xff0c;来到我们的程…

【项目实践03】【布隆过滤器】

文章目录 一、前言二、项目背景三、实现方案1. 谷歌 布隆过滤器2. Redis 布隆过滤器 四、思路延伸1. 布隆过滤器的实现原理2. 布隆过滤器的一些扩展3. 布谷鸟过滤器 五、参考内容 一、前言 本系列用来记录一些在实际项目中的小东西&#xff0c;并记录在过程中想到一些小东西&a…

基于python+控制台的车辆信息管理系统

基于python控制台的车辆信息管理系统 一、系统介绍二、效果展示三、其他系统实现四、获取源码 一、系统介绍 打印功能菜单、添加车辆信息、删除车辆信息、修改车辆信息、显示车辆信息、退出系统&#xff0c;并且需要接收用户的输入&#xff0c;在根据输入内容调用相应函数实现…

Docker部署Grafana+Promethus监控Mysql和服务器

一、Grafana部署所需资源 Grafana 需要最少的系统资源&#xff1a; 建议的最小内存&#xff1a;512 MB建议的最低 CPU&#xff1a;1 官方文档&#xff1a;https://grafana.com/docs/grafana/latest/getting-started/build-first-dashboard/ 可以看到&#xff0c;我的这台服务…

有了Future为什么还要CompletableFuture?

文章目录 Future 接口理论知识复习Future 接口概述场景描述小结 Future 接口常用实现类 FutureTask 异步任务Future 的作用Futrue 编码测试优缺点分析优点缺点小结 面对一些复杂的任务对于简单的业务场景使用 Future 接口完全 OK回调通知创建异步任务多个任务前后依赖可以组合对…

DFS——连通性和搜索顺序

dfs的搜索是基于栈&#xff0c;但一般可以用用递归实现&#xff0c;实际上用的是系统栈。有内部搜索和外部搜索两种&#xff0c;内部搜索是在图的内部&#xff0c;内部搜索一般基于连通性&#xff0c;从一个点转移到另一个点&#xff0c;或者判断是否连通之类的问题&#xff0c…

[Python] opencv - 什么是直方图?如何绘制图像的直方图?

什么是直方图&#xff1f; 直方图是一种统计图&#xff0c;用于展示数据的分布情况。它将数据按照一定的区间或者组进行划分&#xff0c;然后计算在每个区间或组内的数据频数或频率&#xff08;即数据出现的次数或占比&#xff09;&#xff0c;然后用矩形或者柱形图的形式将这…

C++学习Day03之构造函数和析构函数

目录 一、程序及输出1.1 构造函数1.2 析构函数1.3 构造和析构必须要声明在全局作用域 二、分析与总结 一、程序及输出 1.1 构造函数 构造函数 没有返回值 不用写void 函数名 与 类名相同 可以有参数 &#xff0c;可以发生重载 构造函数 由编译器自动调用一次 无须手动调用 创建…

C语言——Q/编译和链接

目录 一、翻译环境和运⾏环境 二、翻译环境 1、预处理&#xff08;预编译&#xff09; 2、编译 2.2.1 词法分析&#xff1a; 2.2.2 语法分析 2.2.3 语义分析 3、汇编 4、链接 三、运行环境 一、翻译环境和运行环境 在ANSI C 的任何⼀种实现中&#xff0c;存在两个不…