Pytorch-08 实战:手写数字识别

手写数字识别项目在机器学习中经常被用作入门练习,因为它相对简单,但又涵盖了许多基本的概念。这个项目可以视为机器学习中的 “Hello World”,因为它涉及到数据收集、特征提取、模型选择、训练和评估等机器学习中的基本步骤,所以手写数字识别项目是一个很好的起点。

我们的要做的是,训练出一个人工神经网络,使它能够识别手写数字(如下图所示):

以下是一个简单的示例代码,展示如何使用PyTorch创建一个手写数字识别的模型,包括数据集加载、训练和测试过程。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader# 检查GPU是否可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 加载MNIST数据集
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
print(f"训练集第1张图像形状 = {train_dataset.__getitem__(0)[0].shape}")
print(f"训练集第1张图像标签 = {train_dataset.__getitem__(0)[1]}")
print(f"测试集第1张图像形状 = {test_dataset.__getitem__(0)[0].shape}")
print(f"测试集第1张图像标签 = {test_dataset.__getitem__(0)[1]}")# 使用数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# 定义神经网络模型并将其移至GPU
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc = nn.Sequential(nn.Linear(28*28, 128),nn.ReLU(),nn.Linear(128, 64),nn.ReLU(),nn.Linear(64, 10))def forward(self, x):x = x.view(x.size(0), -1)x = self.fc(x)return xmodel = Net().to(device)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型,训练过程输出损失值
num_epochs = 5
for epoch in range(num_epochs):model.train()for images, labels in train_loader:images, labels = images.to(device), labels.to(device)  # 将数据移至GPUoptimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')# 测试模型,输出数字识别准确率
model.eval()
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images, labels = images.to(device), labels.to(device)  # 将数据移至GPUoutputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy on the test set: {100 * correct / total}%')

程序运行后,输出如下:

训练集第1张图像形状 = torch.Size([1, 28, 28])
训练集第1张图像标签 = 5
测试集第1张图像形状 = torch.Size([1, 28, 28])
测试集第1张图像标签 = 7
Epoch [1/5], Loss: 0.3935443162918091
Epoch [2/5], Loss: 0.1757822483778
Epoch [3/5], Loss: 0.1337398886680603
Epoch [4/5], Loss: 0.03868262842297554
Epoch [5/5], Loss: 0.025882571935653687
Accuracy on the test set: 96.85%进程已结束,退出代码为 0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/15120.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue 打印、自定义打印、页面打印、隐藏页眉页脚

花了一天时间搞了个打印功能,现则将整体实现过程进行整理分享。先来看看效果图: 1、页面展示为: 2、重组页面打印格式为:这里重组页面的原因是客户要求为一行两列打印 !内容过于多的行则独占一行显示完整。 整体实现&…

区块链论文总结速读--CCF A会议 USENIX Security 2024 共7篇 附pdf下载

Conference:33rd USENIX Security Symposium CCF level:CCF A Categories:网络与信息安全 Year:2024 Num:7 1 Title: Practical Security Analysis of Zero-Knowledge Proof Circuits 零知识证明电路的实用安全…

hbase版本从1.2升级到2.1 spark读取hive数据写入hbase 批量写入类不存在问题

在hbase1.2版本中&#xff0c;pom.xml中引入hbase-server1.2…0和hbase-client1.2.0就已经可以有如下图的类。但是在hbase2.1.0版本中增加这两个不行。hbase-server2.1.0中没有mapred包&#xff0c;同时mapreduce下就2个类。版本已经不支持。 <dependency><groupId>…

安全访问python字典:避免空键错误的艺术

新书上架~&#x1f447;全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我&#x1f446;&#xff0c;收藏下次不迷路┗|&#xff40;O′|┛ 嗷~~ 目录 一、引言 二、直接访问字典键的问题 三、使用get方法安全访问字典键 四、get方法的实际应…

Could not create connection to database server的错误原因

1、使用MyBatis 连接数据库报错 org.apache.ibatis.exceptions.PersistenceException: ### Error updating database. Cause: com.mysql.jdbc.exceptions.jdbc4.MySQLNonTransientConnectionException: Could not create connection to database server. ### The error may …

流量卡激活先交100块钱,这个确定不是套路吗?

很多第一次申请流量卡的时候&#xff0c;在激活的时候都会问的一个问题&#xff0c;为什么激活时都要先交100块钱&#xff0c;这个确定不是套路吗 ​  先说一下&#xff0c;首充其实并不是商家的要求&#xff0c;准确来说是运营商的要求&#xff0c;目前运营商推出的线上流量…

用队列实现栈,用栈实现队列

有两个地方会讨论到栈&#xff0c;一个是程序运行的栈空间&#xff0c;一个是数据结构中的栈&#xff0c;本文中讨论的是后者。 栈是一个先入后出&#xff0c;后入先出的数据结构&#xff0c;只能操作栈顶。栈有两个操作&#xff0c;push 和 pop&#xff0c;push 是向将数据压…

电脑如何远程监控?如何远程监控电脑屏幕?

远程监控是指通过网络技术和远程视频传输技术&#xff0c;实现对某一特定区域、设备或场景进行远程实时监测、管理、控制的一种技术手段。 它将视频传输、图像采集、数据存储和远程操作等多种技术相结合&#xff0c;能够在任意时间、任意地点实现对被监测对象的远程监控。 远程…

在Windows系统服务器上安装Node.js的步骤

在windows操作系统中&#xff0c;可以使用命令行&#xff08;CMD&#xff09;安装Node.js。 【第一步】下载Node.js安装包 在官网https://nodejs.cn/download下载安装包&#xff0c;以 64位 Windows 安装包为例。 【第二步】将Node.js安装包上传到服务器 将安装包上传到指定…

基于Docker Compose部署One-API的详细指南

部署One-API的详细指南 前言 one-api是一个开源项目(https://github.com/songquanpeng/one-api)&#xff0c;旨在简化API的开发与管理过程。这个项目提供了一个全面的解决方案&#xff0c;特别适用于需要高效管理API接口的开发者和团队。以下是该项目的一些核心特点和功能&am…

IO模型:同步阻塞、同步非阻塞、同步多路复用、异步非阻塞

目录 stream和channel对比 同步、异步、阻塞、非阻塞 线程读取数据的过程 同步阻塞IO 同步非阻塞IO 同步IO多路复用 异步IO 优缺点对比 stream和channel对比 stream不会自动缓冲数据&#xff0c;channel会利用系统提供的发送缓冲区、接收缓冲区。stream仅支持阻塞API&am…

轻松拿捏C语言——【字符函数】字符分类函数、字符转换函数

&#x1f970;欢迎关注 轻松拿捏C语言系列&#xff0c;来和 小哇 一起进步&#xff01;✊ &#x1f308;感谢大家的阅读、点赞、收藏和关注&#x1f495; &#x1f339;如有问题&#xff0c;欢迎指正 感谢 目录&#x1f451; 一、字符分类函数&#x1f319; 二、字符转换函数…

hive3从入门到精通(二)

第15章:Hive SQL Join连接操作 15-1.Hive Join语法规则 join分类 在Hive中&#xff0c;当下版本3.1.2总共支持6种join语法。分别是&#xff1a; inner join&#xff08;内连接&#xff09;left join&#xff08;左连接&#xff09;right join&#xff08;右连接&#xff09;…

Python学习——— tupple

Python 中的数据结构是通过某种方式组织在一起的数据元素的集合&#xff0c;这些数据元素可以是数字、字符、甚至可以是其他数据结构 在 Python 中&#xff0c;最基本的数据结构是序列&#xff08;列表和元组&#xff09;&#xff0c;序列中的每个元素都有一个序号&#xff08;…

力扣HOT100 - 136. 只出现一次的数字

解题思路&#xff1a; class Solution {public int singleNumber(int[] nums) {int single 0;for (int num : nums) {single ^ num;}return single;} }

基于卷积神经网络的交通标志识别(pytorch,opencv,yolov5)

文章目录 数据集介绍&#xff1a;resnet18模型代码加载数据集&#xff08;Dataset与Dataloader&#xff09;模型训练训练准确率及损失函数&#xff1a;resnet18交通标志分类源码yolov5检测与识别&#xff08;交通标志&#xff09; 本文共包含两部分&#xff0c; 第一部分是用re…

回溯算法06(总结+leetcode332,51,37)

参考资料&#xff1a; https://programmercarl.com/%E5%9B%9E%E6%BA%AF%E6%80%BB%E7%BB%93.html 力扣这三题暂时不在本篇笔记中贴代码了&#xff0c;有兴趣的可参考332.重新安排形成、N皇后、解数独 总结&#xff1a; 画树形图分析题目 用途&#xff1a;回溯算法是用 递归实现…

C++学习笔记(21)——继承

目录 1. 继承的概念及定义1.1 继承的概念1.2 继承定义1.2.1 定义格式1.2.2 继承关系和访问限定符1.2.3 继承基类成员访问方式的变化 继承的概念总结&#xff1a; 2. 基类和派生类对象赋值转换3.继承中的作用域4.派生类的默认成员函数知识点&#xff1a;派生类中6个默认成员函数…

win11 wsl ubuntu24.04

win11 wsl ubuntu24.04 一&#xff1a;开启Hyper-V二&#xff1a;安装wsl三&#xff1a;安装ubuntu24.04三&#xff1a;桥接模式&#xff0c;固定IP四&#xff1a;U盘使用五&#xff1a;wsl 从c盘迁移到其它盘参考资料 一&#xff1a;开启Hyper-V win11家庭版开启hyper-v 桌面…