深度学习笔记_5 经典卷积神经网络LeNet-5 解决MNIST数据集

1、定义LeNet-5模型,包括卷积层和全连接层。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms# 导入必要的库# 定义 LeNet-5 模型
class LeNet5(nn.Module):def __init__(self):super(LeNet5, self).__init__()# 定义卷积层和全连接层self.conv1 = nn.Conv2d(1, 6, kernel_size=5)  # 输入通道1,输出通道6,卷积核大小5x5self.conv2 = nn.Conv2d(6, 16, kernel_size=5)  # 输入通道6,输出通道16,卷积核大小5x5self.fc1 = nn.Linear(16 * 4 * 4, 120)  # 全连接层,输入维度为16*4*4,输出维度为120self.fc2 = nn.Linear(120, 84)  # 全连接层,输入维度为120,输出维度为84self.fc3 = nn.Linear(84, 64)  # 全连接层,输入维度为84,输出维度为64self.fc4 = nn.Linear(64, 10)  # 全连接层,输入维度为64,输出维度为10def forward(self, x):x = torch.relu(self.conv1(x))  # 第一个卷积层后接ReLU激活函数x = torch.max_pool2d(x, 2)  # 池化层,执行2x2的最大池化x = torch.relu(self.conv2(x))  # 第二个卷积层后接ReLU激活函数x = torch.max_pool2d(x, 2)  # 池化层,执行2x2的最大池化x = x.view(-1, 16 * 4 * 4)  # 数据展平,以便输入全连接层x = torch.relu(self.fc1(x))  # 第一个全连接层后接ReLU激活函数x = torch.relu(self.fc2(x))  # 第二个全连接层后接ReLU激活函数x = self.fc3(x)  # 第三个全连接层return x

2、对MNIST数据集进行加载和预处理,包括将图像转换为张量和标准化。

# 加载 MNIST 训练集和测试集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# 定义数据预处理,包括转换为张量和标准化train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 下载MNIST数据集,并应用数据预处理train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
# 创建训练和测试数据加载器

3、初始化模型和优化器,使用随机梯度下降(SGD)优化器。

# 初始化模型和优化器
model = LeNet5()  # 创建LeNet-5模型
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # 使用随机梯度下降作为优化器# 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 检查是否支持GPU,如果支持则使用GPU
model.to(device)  # 将模型移动到GPU或CPU
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数

4、在每个训练周期内,进行前向传播、反向传播和参数更新。在训练集上进行精度测试,以评估模型性能。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms# 导入必要的库# 定义 LeNet-5 模型
class LeNet5(nn.Module):def __init__(self):super(LeNet5, self).__init__()# 定义卷积层和全连接层self.conv1 = nn.Conv2d(1, 6, kernel_size=5)  # 输入通道1,输出通道6,卷积核大小5x5self.conv2 = nn.Conv2d(6, 16, kernel_size=5)  # 输入通道6,输出通道16,卷积核大小5x5self.fc1 = nn.Linear(16 * 4 * 4, 120)  # 全连接层,输入维度为16*4*4,输出维度为120self.fc2 = nn.Linear(120, 84)  # 全连接层,输入维度为120,输出维度为84self.fc3 = nn.Linear(84, 64)  # 全连接层,输入维度为84,输出维度为64self.fc4 = nn.Linear(64, 10)  # 全连接层,输入维度为64,输出维度为10def forward(self, x):x = torch.relu(self.conv1(x))  # 第一个卷积层后接ReLU激活函数x = torch.max_pool2d(x, 2)  # 池化层,执行2x2的最大池化x = torch.relu(self.conv2(x))  # 第二个卷积层后接ReLU激活函数x = torch.max_pool2d(x, 2)  # 池化层,执行2x2的最大池化x = x.view(-1, 16 * 4 * 4)  # 数据展平,以便输入全连接层x = torch.relu(self.fc1(x))  # 第一个全连接层后接ReLU激活函数x = torch.relu(self.fc2(x))  # 第二个全连接层后接ReLU激活函数x = self.fc3(x)  # 第三个全连接层return x# 加载 MNIST 训练集和测试集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# 定义数据预处理,包括转换为张量和标准化train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 下载MNIST数据集,并应用数据预处理train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
# 创建训练和测试数据加载器# 初始化模型和优化器
model = LeNet5()  # 创建LeNet-5模型
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)  # 使用随机梯度下降作为优化器# 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 检查是否支持GPU,如果支持则使用GPU
model.to(device)  # 将模型移动到GPU或CPU
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失函数for epoch in range(10):model.train()  # 设置模型为训练模式running_loss = 0.0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)  # 将数据移动到GPU或CPUoptimizer.zero_grad()  # 梯度清零outputs = model(images)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新模型参数running_loss += loss.item()# 在训练集上进行精度测试model.eval()  # 设置模型为评估模式correct = 0total = 0with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)  # 将数据移动到GPU或CPUoutputs = model(images)  # 前向传播_, predicted = torch.max(outputs.data, 1)  # 获取预测类别total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalprint('Epoch: %d, Loss: %.3f, Accuracy: %.2f%%' % (epoch+1, running_loss, accuracy))

5、实验步骤和关键要点:

  1. 模型定义:LeNet-5模型被定义为一个经典的卷积神经网络,包括卷积层和全连接层。模型结构包括两个卷积层、两个池化层和三个全连接层。

  2. 数据加载和预处理:MNIST数据集被加载并进行预处理。预处理包括将图像转换为张量,并进行标准化,以确保输入数据在训练期间具有相似的尺度和分布。

  3. 初始化模型和优化器:LeNet-5模型被初始化,并使用随机梯度下降(SGD)作为优化器。SGD用于在训练期间调整模型参数以最小化损失函数。

  4. 模型训练:模型被训练在训练集上进行了多个周期的训练。对于每个训练周期,执行以下步骤:

    • 设置模型为训练模式。
    • 对每个批次进行前向传播,计算损失,执行反向传播,更新模型参数。
    • 记录每个训练周期的损失值。
  5. 模型测试:在每个训练周期结束后,模型在测试集上进行了精度测试。在测试期间,执行以下步骤:

    • 设置模型为评估模式。
    • 通过模型进行前向传播,计算测试集上的预测结果。
    • 检查模型的准确性,计算正确分类的样本数量,并计算总样本数量。
    • 记录每个训练周期的测试精度。
  6. 打印结果:在每个训练周期结束后,打印出训练周期的损失和测试精度,以便监控模型的性能。

这个实验展示了如何使用PyTorch框架构建、训练和测试深度学习模型。LeNet-5模型在MNIST数据集上取得了不错的手写数字识别性能。通过多个训练周期,模型的损失逐渐减小,测试精度逐渐增加,表明模型在训练过程中逐渐学习到了有效的特征表示,从而提高了在新样本上的分类准确性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/114790.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

xxl-job学习

学习链接 xxl-job官方文档 【分布式任务调度】三、XXL-JOB详细介绍 xxljob从入门到精通-全网段最全解说 XXL-JOB分布式任务调度框架(一)-基础入门 XXL-JOB分布式任务调度框架(二)-策略详解 XXL-JOB分布式任务调度框架(三)-集群部署 XXL-JOB分布式任务调度框架(四)-源码分析…

C语言每日一题(16) 消失的数字

题目链接 一.题目描述 数组nums包含从0到n的所有整数,但其中缺了一个。请编写代码找出那个缺失的整数。你有办法在O(n)时间内完成吗? 二.题目分析 方法1 异或法 基于异或的思想,将0与数组中的数一一进行异或后得到的值,再与0…

行为型模式-备忘录模式

备忘录模式保存一个对象的某个状态,以便在适当的时候恢复对象。备忘录模式属于行为型模式。 意图:在不破坏封装性的前提下,捕获一个对象的内部状态,并在该对象之外保存这个状态。 主要解决:所谓备忘录模式就是在不破坏…

ARouter - 组件化通信方案

官网 https://github.com/alibaba/ARouter/blob/master/README_CN.md 项目简介 一个用于帮助 Android App 进行组件化改造的框架 —— 支持模块间的路由、通信、解耦 功能介绍 支持直接解析标准URL进行跳转,并自动注入参数到目标页面中支持多模块工程使用支持添…

【c++】跟webrtc学std array 1: 混音的多维数组

对于固定大小的数组,非常适合用std的array 实现。静态赋初值 static constexpr std::array<int, 5> kInputValues = {0, 1, 2, 1, 0}

html表格标签

<!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title> </head> <body><!--表格table 行 tr 列 td --> <table border"1px"><tr> <!--colsp…

逻辑漏洞详解

原理&#xff1a; 没有固定的概念&#xff0c;一般都是不符合常识的情况。比如任意用户注册&#xff0c;短信炸弹&#xff0c;占用资源&#xff0c;交易支付、密码修改、密码找回、越权修改、越权查询、突破限制。 根据实际业务逻辑进行比对&#xff0c;购物的可以根据数量&a…

gulp打包vue3+jsx+less插件

最终转换结果如下 在根目录下添加gulpfile.js文件&#xff0c;package.json添加命令npm run gulp var gulp require(gulp) var babel require(gulp-babel) var less require(gulp-less) var del require(del); var spawn require(child_process).spawn;const outDir &…

【FPGA零基础学习之旅#16】嵌入式块RAM-双口ram的使用

&#x1f389;欢迎来到FPGA专栏~双口ram的使用 ☆* o(≧▽≦)o *☆嗨~我是小夏与酒&#x1f379; ✨博客主页&#xff1a;小夏与酒的博客 &#x1f388;该系列文章专栏&#xff1a;FPGA学习之旅 文章作者技术和水平有限&#xff0c;如果文中出现错误&#xff0c;希望大家能指正…

X32位汇编和X64位区别无参函数分析(一)

前言 一、X32汇编函数无参无返回分析 二、X64汇编函数无参无返回分析 总结 前言 提示&#xff1a;以下是个人学习总结&#xff1a;如有错误请大神指出来&#xff0c;只供学习参考&#xff0c;本内容使用使用VS2017开发工具&#xff1a;语言是C&#xff0c;需要一些常见的汇编指…

手机知识:安卓内存都卷到24GB了,为何iPhone还在固守8GB

目录 一、系统机制 二、生态差异 三、总结 在刚刚过去的9月&#xff0c;年货iPhone 15系列正式发布&#xff0c;标准版不出意外还是挤药膏&#xff0c;除了镜头、屏幕有些升级&#xff0c;芯片用iPhone 14 Pro系列的&#xff0c;内存只有6GB&#xff1b;即使是集钛合金机身、…

【大数据】Kafka 实战教程(一)

Kafka 实战教程&#xff08;一&#xff09; 1.Kafka 介绍1.1. 主要功能1.2. 使用场景1.3 详细介绍1.3.1 消息传输流程1.3.2 Kafka 服务器消息存储策略1.3.3 与生产者的交互1.3.4 与消费者的交互 2.Kafka 生产者3.Kafka 消费者3.1 Kafka 消费模式3.1.1 At-most-once&#xff08;…

NVIDIA NCCL 源码学习(十一)- ring allreduce

之前的章节里我们看到了nccl send/recv通信的过程&#xff0c;本节我们以ring allreduce为例看下集合通信的过程。整体执行流程和send/recv很像&#xff0c;所以对于相似的流程只做简单介绍&#xff0c;主要介绍ring allreduce自己特有内容。 单机 搜索ring 在nccl初始化的过…

51单片机仿真软件 Proteus 8 Pro 安装步骤

51单片机仿真软件 Proteus 8 Pro 安装步骤 学习 51 单片机的时候&#xff0c;如果手头没有开发板&#xff0c;可以使用仿真软件 Proteus。Proteus 可以仿真 51 单片机及周边元器件&#xff08;例&#xff1a; LED&#xff09; 的运行情况。 可以简单认为&#xff1a;Proteus …

经典链表问题:解析链表中的关键挑战

这里写目录标题 公共子节点采用集合或者哈希采用栈拼接两个字符串差和双指针 旋转链表 公共子节点 例如这样一道题&#xff1a;给定两个链表&#xff0c;找出它们的第一个公共节点。 具体的题目描述我们来看看牛客的一道题&#xff1a; 这里我们有四种解决办法&#xff1a; …

pandas写入MySQL

安装好pandas、mysql pip install pandas pip install pymysql 导入pandas、mysql import pymysql as mysql import pandas as pd 建立连接 conmysql.connect(host10.10.0.221,userroot,passwordroot,databasepandas,port3306,charsetutf8) 创建游标 curcon.cursor() 读…

文档的重要性及接口文档模板

随着工作年限的增长&#xff0c;我们逐渐意识到工作中文档的重要性不可忽视。优质的文档不仅能提高工作效率&#xff0c;还能有效降低沟通成本&#xff0c;因此我们必须注重文档的撰写和格式。最近&#xff0c;由于未能及时更新文档&#xff0c;导致在项目开发中出现了信息冲突…

Vue解决 npm -v 报错(一)

报错内容&#xff1a; npm WARN config global --global, --local are deprecated. Use --locationglobal instead. 解决方案&#xff1a; 代码&#xff1a; prefix -g 替换为&#xff1a; prefix --locationglobal 原创作者&#xff1a;吴小糖 创作时间&#xff1a;2023.1…

Android之AMS原理分析

在学习android框架原理过程中&#xff0c;ams的原理非常重要&#xff0c;无论是在面试中还是在自己开发类库过程中都会接触到。 1 简述 ActivityManagerService是Android最核心的服务&#xff0c;负责管理四大组件的启动、切换、调度等工作。由于AMS的功能和重要性&#xff0c…

字符串输入(注意:cin遇到空白字符停止读入)

1.输入多个字符串时&#xff0c;又无法开二维数组&#xff1b; 可动态分配数组&#xff08;直接声明数组&#xff0c;指向的地址的不变的&#xff09; while (num--){char* arr (char*)malloc(10000 * sizeof(char));char ch 0;int k 0;while ((ch getchar()) ! \n){arr[k…