Pytorch-08 实战:手写数字识别

手写数字识别项目在机器学习中经常被用作入门练习,因为它相对简单,但又涵盖了许多基本的概念。这个项目可以视为机器学习中的 “Hello World”,因为它涉及到数据收集、特征提取、模型选择、训练和评估等机器学习中的基本步骤,所以手写数字识别项目是一个很好的起点。

我们的要做的是,训练出一个人工神经网络,使它能够识别手写数字(如下图所示):

以下是一个简单的示例代码,展示如何使用PyTorch创建一个手写数字识别的模型,包括数据集加载、训练和测试过程。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader# 检查GPU是否可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
print(f"训练集第1张图像形状 = {train_dataset.__getitem__(0)[0].shape}")
print(f"训练集第1张图像标签 = {train_dataset.__getitem__(0)[1]}")
print(f"测试集第1张图像形状 = {test_dataset.__getitem__(0)[0].shape}")
print(f"测试集第1张图像标签 = {test_dataset.__getitem__(0)[1]}")# 使用数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# 定义神经网络模型并将其移至GPU
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc = nn.Sequential(nn.Linear(28*28, 128),nn.ReLU(),nn.Linear(128, 64),nn.ReLU(),nn.Linear(64, 10))def forward(self, x):x = x.view(x.size(0), -1)x = self.fc(x)return xmodel = Net().to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型,训练过程输出损失值
num_epochs = 5
for epoch in range(num_epochs):model.train()for images, labels in train_loader:images, labels = images.to(device), labels.to(device)  # 将数据移至GPUoptimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')# 测试模型,输出数字识别准确率
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images, labels = images.to(device), labels.to(device)  # 将数据移至GPUoutputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy on the test set: {100 * correct / total}%')

程序运行后,输出如下:

训练集第1张图像形状 = torch.Size([1, 28, 28])
训练集第1张图像标签 = 5
测试集第1张图像形状 = torch.Size([1, 28, 28])
测试集第1张图像标签 = 7
Epoch [1/5], Loss: 0.3935443162918091
Epoch [2/5], Loss: 0.1757822483778
Epoch [3/5], Loss: 0.1337398886680603
Epoch [4/5], Loss: 0.03868262842297554
Epoch [5/5], Loss: 0.025882571935653687
Accuracy on the test set: 96.85%进程已结束,退出代码为 0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/15120.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue 打印、自定义打印、页面打印、隐藏页眉页脚

花了一天时间搞了个打印功能,现则将整体实现过程进行整理分享。先来看看效果图: 1、页面展示为: 2、重组页面打印格式为:这里重组页面的原因是客户要求为一行两列打印 !内容过于多的行则独占一行显示完整。 整体实现&…

区块链论文总结速读--CCF A会议 USENIX Security 2024 共7篇 附pdf下载

Conference:33rd USENIX Security Symposium CCF level:CCF A Categories:网络与信息安全 Year:2024 Num:7 1 Title: Practical Security Analysis of Zero-Knowledge Proof Circuits 零知识证明电路的实用安全…

hbase版本从1.2升级到2.1 spark读取hive数据写入hbase 批量写入类不存在问题

在hbase1.2版本中&#xff0c;pom.xml中引入hbase-server1.2…0和hbase-client1.2.0就已经可以有如下图的类。但是在hbase2.1.0版本中增加这两个不行。hbase-server2.1.0中没有mapred包&#xff0c;同时mapreduce下就2个类。版本已经不支持。 <dependency><groupId>…

安全访问python字典:避免空键错误的艺术

新书上架~&#x1f447;全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我&#x1f446;&#xff0c;收藏下次不迷路┗|&#xff40;O′|┛ 嗷~~ 目录 一、引言 二、直接访问字典键的问题 三、使用get方法安全访问字典键 四、get方法的实际应…

Could not create connection to database server的错误原因

1、使用MyBatis 连接数据库报错 org.apache.ibatis.exceptions.PersistenceException: ### Error updating database. Cause: com.mysql.jdbc.exceptions.jdbc4.MySQLNonTransientConnectionException: Could not create connection to database server. ### The error may …

用队列实现栈,用栈实现队列

有两个地方会讨论到栈&#xff0c;一个是程序运行的栈空间&#xff0c;一个是数据结构中的栈&#xff0c;本文中讨论的是后者。 栈是一个先入后出&#xff0c;后入先出的数据结构&#xff0c;只能操作栈顶。栈有两个操作&#xff0c;push 和 pop&#xff0c;push 是向将数据压…

电脑如何远程监控?如何远程监控电脑屏幕?

远程监控是指通过网络技术和远程视频传输技术&#xff0c;实现对某一特定区域、设备或场景进行远程实时监测、管理、控制的一种技术手段。 它将视频传输、图像采集、数据存储和远程操作等多种技术相结合&#xff0c;能够在任意时间、任意地点实现对被监测对象的远程监控。 远程…

IO模型:同步阻塞、同步非阻塞、同步多路复用、异步非阻塞

目录 stream和channel对比 同步、异步、阻塞、非阻塞 线程读取数据的过程 同步阻塞IO 同步非阻塞IO 同步IO多路复用 异步IO 优缺点对比 stream和channel对比 stream不会自动缓冲数据&#xff0c;channel会利用系统提供的发送缓冲区、接收缓冲区。stream仅支持阻塞API&am…

轻松拿捏C语言——【字符函数】字符分类函数、字符转换函数

&#x1f970;欢迎关注 轻松拿捏C语言系列&#xff0c;来和 小哇 一起进步&#xff01;✊ &#x1f308;感谢大家的阅读、点赞、收藏和关注&#x1f495; &#x1f339;如有问题&#xff0c;欢迎指正 感谢 目录&#x1f451; 一、字符分类函数&#x1f319; 二、字符转换函数…

hive3从入门到精通(二)

第15章:Hive SQL Join连接操作 15-1.Hive Join语法规则 join分类 在Hive中&#xff0c;当下版本3.1.2总共支持6种join语法。分别是&#xff1a; inner join&#xff08;内连接&#xff09;left join&#xff08;左连接&#xff09;right join&#xff08;右连接&#xff09;…

力扣HOT100 - 136. 只出现一次的数字

解题思路&#xff1a; class Solution {public int singleNumber(int[] nums) {int single 0;for (int num : nums) {single ^ num;}return single;} }

基于卷积神经网络的交通标志识别(pytorch,opencv,yolov5)

文章目录 数据集介绍&#xff1a;resnet18模型代码加载数据集&#xff08;Dataset与Dataloader&#xff09;模型训练训练准确率及损失函数&#xff1a;resnet18交通标志分类源码yolov5检测与识别&#xff08;交通标志&#xff09; 本文共包含两部分&#xff0c; 第一部分是用re…

C++学习笔记(21)——继承

目录 1. 继承的概念及定义1.1 继承的概念1.2 继承定义1.2.1 定义格式1.2.2 继承关系和访问限定符1.2.3 继承基类成员访问方式的变化 继承的概念总结&#xff1a; 2. 基类和派生类对象赋值转换3.继承中的作用域4.派生类的默认成员函数知识点&#xff1a;派生类中6个默认成员函数…

win11 wsl ubuntu24.04

win11 wsl ubuntu24.04 一&#xff1a;开启Hyper-V二&#xff1a;安装wsl三&#xff1a;安装ubuntu24.04三&#xff1a;桥接模式&#xff0c;固定IP四&#xff1a;U盘使用五&#xff1a;wsl 从c盘迁移到其它盘参考资料 一&#xff1a;开启Hyper-V win11家庭版开启hyper-v 桌面…

Pytorch-01 框架简介

智能框架概述 人工智能框架是一种软件工具&#xff0c;用于帮助开发人员构建和训练人工智能模型。这些框架提供了各种功能&#xff0c;如定义神经网络结构、优化算法、自动求导等&#xff0c;使得开发人员可以更轻松地实现各种人工智能任务。通过使用人工智能框架&#xff0c;…

虚拟机使用的是此版本 VMware Workstation 不支持的硬件版本。 模块“Upgrade”启动失败。 未能启动虚拟机。

问题&#xff1a; 虚拟机使用的是此版本 VMware Workstation 不支持的硬件版本。 模块“Upgrade”启动失败。 未能启动虚拟机。 分析&#xff1a; 该虚拟机环境之前使用的VMware版本与你所使用的VMware版本不一致。大概率你使用的是刚从别人电脑里拷过来的虚拟机环境。 解决&…

游戏后台开发技术全面解析

在这个数字时代&#xff0c;游戏产业已经成为全球最受欢迎的娱乐方式之一。从简单的手机游戏到复杂的大型多人在线角色扮演游戏&#xff08;MMORPG&#xff09;&#xff0c;游戏的世界正变得越来越丰富和多样化。而这一切的背后&#xff0c;都离不开强大的游戏后台技术支持。在…

Java重写

方法重写的意义 在java中&#xff0c;子类可以继承父类中的方法&#xff0c;而不需要重新编写相同的方法&#xff0c;但是有时子类并不想原封不动的继承父类方法&#xff0c;需要做一定的修改&#xff0c;这时候就需要使用方法重写 方法重写的定义 在继承的前提下 子类可以根据…

Python使用连接池操作MySQL

测试环境说明&#xff1a;Python版本是 3.8.10 &#xff0c;DBUtils版本是3.1.0 &#xff0c;pymysql版本是1.0.3 首先安装指定版本的连接池库DBUtils 、还有pymysql pip install DBUtils3.1.0 pip install pymysql1.0.3创建文件 sqlConfig.py # sqlConfig.pyimport pymysql…

YOLOv10论文解读:实时端到端的目标检测模型

《博主简介》 小伙伴们好&#xff0c;我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源&#xff0c;可关注公-仲-hao:【阿旭算法与机器学习】&#xff0c;共同学习交流~ &#x1f44d;感谢小伙伴们点赞、关注&#xff01; 《------往期经典推…