【学习笔记】深度学习实战 | LeNet

在这里插入图片描述

简要声明


  1. 学习相关网址
    1. [双语字幕]吴恩达深度学习deeplearning.ai
    2. Papers With Code
    3. Datasets
  2. 深度学习网络基于PyTorch学习架构,代码测试可跑。
  3. 本学习笔记单纯是为了能对学到的内容有更深入的理解,如果有错误的地方,恳请包容和指正。

参考文献


  1. PyTorch Tutorials [https://pytorch.org/tutorials/]
  2. PyTorch Docs [https://pytorch.org/docs/stable/index.html]
  3. LeNet (1998) [Gradient-based learning applied to document recognition]

简要介绍


LeNet

在这里插入图片描述

DatasetMNIST
Input (feature maps)32×32 (28×28)
CONV Layers2
FC Layers2
ActivationSigmoid
Output10

代码分析


函数库调用

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

处理数据

数据下载

# 从开放数据集中下载训练数据
train_data = datasets.MNIST(root="data",train=True,download=True,transform=ToTensor(),
)# 从开放数据集中下载测试数据
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)print(f'Number of training examples: {len(train_data)}')
print(f'Number of testing examples: {len(test_data)}')

Number of training examples: 60000
Number of testing examples: 10000

数据加载器(可选)

batch_size = 64# 创建数据加载器
train_dataloader = DataLoader(train_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)for X, y in test_dataloader:print(f"Shape of X [N, C, H, W]: {X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")break

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64

创建模型

# 选择训练设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device} device")

Using cuda device

class LeNet(nn.Module):def __init__(self, output_dim):super().__init__()self.conv_1 = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2, stride=2))self.conv_2 = nn.Sequential(nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2, stride=2))self.fc_1 = nn.Sequential(nn.Linear(16*5*5, 120),nn.Sigmoid())self.fc_2 = nn.Sequential(nn.Linear(120, 84),nn.Sigmoid())self.fc_3 = nn.Linear(84, output_dim)def forward(self, x):x = self.conv_1(x)x = self.conv_2(x)x = x.view(x.size(0), -1)x = self.fc_1(x)x = self.fc_2(x)x = self.fc_3(x)return xmodel = LeNet(10).to(device)
print(model)

LeNet(
(conv_1): Sequential(
(0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
(1): Sigmoid()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(conv_2): Sequential(
(0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(1): Sigmoid()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(fc_1): Sequential(
(0): Linear(in_features=400, out_features=120, bias=True)
(1): Sigmoid()
)
(fc_2): Sequential(
(0): Linear(in_features=120, out_features=84, bias=True)
(1): Sigmoid()
)
(fc_3): Linear(in_features=84, out_features=10, bias=True)
)

训练模型

选择损失函数和优化器

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

训练循环

def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)model.train()for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)# Compute prediction errorpred = model(X)loss = loss_fn(pred, y)# Backpropagationloss.backward()optimizer.step()optimizer.zero_grad()if batch % 100 == 0:loss, current = loss.item(), (batch + 1) * len(X)print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

测试循环

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

训练模型

epochs = 10.
for t in range(epochs):print(f"Epoch {t+1}\n-------------------------------")train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
print("Done!")

Epoch 10
loss: 0.015569 [ 64/60000]
loss: 0.029817 [ 6464/60000]
loss: 0.043169 [12864/60000]
loss: 0.027709 [19264/60000]
loss: 0.021492 [25664/60000]
loss: 0.011533 [32064/60000]
loss: 0.045418 [38464/60000]
loss: 0.042875 [44864/60000]
loss: 0.152001 [51264/60000]
loss: 0.040214 [57664/60000]
Test Error:
Accuracy: 98.6%, Avg loss: 0.044844

模型处理

保存模型

model_name = 'LeNet'
model_file = model_name + ".pth"
torch.save(model.state_dict(), model_file)
print("Saved PyTorch Model State to " + model_file)

Saved PyTorch Model State to LeNet.pth

Summary


安装torchsummary

pip install torchsummary

调用summary

from torchsummary import summarymodel = LeNet(10).to(device)
summary(model, (1, 28, 28))
----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Conv2d-1            [-1, 6, 28, 28]             156Sigmoid-2            [-1, 6, 28, 28]               0MaxPool2d-3            [-1, 6, 14, 14]               0Conv2d-4           [-1, 16, 10, 10]           2,416Sigmoid-5           [-1, 16, 10, 10]               0MaxPool2d-6             [-1, 16, 5, 5]               0Linear-7                  [-1, 120]          48,120Sigmoid-8                  [-1, 120]               0Linear-9                   [-1, 84]          10,164Sigmoid-10                   [-1, 84]               0Linear-11                   [-1, 10]             850
================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.11
Params size (MB): 0.24
Estimated Total Size (MB): 0.35
----------------------------------------------------------------

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/711428.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

KubeEdge 边缘计算

文章目录 1.KubeEdge2.KubeEdge 特点3.KubeEdge 组成4.KubeEdge 架构 KubeEdge # KubeEdgehttps://iothub.org.cn/docs/kubeedge/ https://iothub.org.cn/docs/kubeedge/kubeedge-summary/1.KubeEdge KubeEdge 是一个开源的系统,可将本机容器化应用编排和管理扩展…

蓝牙耳机和笔记本电脑配对连接上了,播放设备里没有显示蓝牙耳机这个设备,选不了输出设备

环境: WIN10 杂牌蓝牙耳机6s 问题描述: 蓝牙耳机和笔记本电脑配对连接上了,播放设备里没有显示蓝牙耳机这个设备,选不了输出设备 解决方案: 1.打开设备和打印机,找到这个设备 2.选中这个设备&#…

Linux下gcc编译常用命令详解

在Linux环境下,使用gcc编译器进行源代码的编译是程序员日常工作的一部分。本篇将介绍一些常用的gcc编译命令,帮助开发者更好地理解和使用这些命令。 1. 基本编译命令 gcc工作流程: 编译单个源文件 gcc source.c -o output这个命令将sour…

调试工具vue,react,redux

React Developer Tools Redux DevTools Vue devtools 使用浏览器官方组件扩展搜索安装

枚举和联合(共用体)

目录 枚举枚举类型的定义枚举的优点 联合(共用体)联合类型的定义联合的特点联合大小的计算 枚举 枚举顾名思义就是一一列举,把可能的取值一一列举 枚举类型的定义 enum Day , enum Sex ,enum Color 都是枚举类型{}中…

曾桂华:车载座舱音频体验探究与思考| 演讲嘉宾公布

智能车载音频 I 分论坛将于3月27日同期举办! 我们正站在一个前所未有的科技革新的交汇点上,重塑我们出行体验的变革正在悄然发生。当人工智能的磅礴力量与车载音频相交融,智慧、便捷与未来的探索之旅正式扬帆起航。 在驾驶的旅途中&#xff0…

通过css修改video标签的原生样式

通过css修改video标签的原生样式 描述实现结果 描述 修改video标签的原生样式 实现 在控制台中打开设置,勾选显示用户代理 shadow DOM,就可以审查video标签的内部样式了 箭头处标出来的就是shodow DOM的内容,这些内容正常不可见的&#x…

MySQL 用了哪种默认隔离级别,实现原理是什么?

MySQL 的默认隔离级别是 RR - 可重复读,可以通过命令来查看 MySQL 中的默认隔离级别。 RR - 可重复读是基于多版本并发控制(Multi-Version Concurrency Control,MVCC )实现的。MVCC,在读取数据时通过一种类似快照的方…

视觉三维重建colmap框架的现状与未来

注:该文章首发3D视觉工坊,链接如下3D视觉工坊 前言 众所周知,三维重建的发展已经进入了稳定期,尤其是离线方案的发展几乎处于停滞期,在各大论刊上也很少见到传统sfmmvs亮眼的文章。这也不难理解,传统的多视…

MYSQL 解释器小记

解释器的结果通常通过上述表格展示: 1. select_type 表示查询的类型 simple: 表示简单的选择查询,没有子查询或连接操作 primary:表示主查询,通常是最外层的查询 subquery :表示子查询,在主查询中嵌套的查询 derived: 表示派…

【王道数据结构】【chapter8排序】【P360t2】

试编写一个算法,使之能够在数组L[1……n]中找出第k小的元素(即从小到大排序后处于第k个位置的元素)(可以直接采用排序,但下面的排序的代码只是为了方便核对是不是第k小的元素,k从0开始计算) #in…

出海手游收入一路高歌,营销上如何成功?

出海手游收入一路高歌,营销上如何成功? 以RPG和SLG为代表的中重度游戏一直是国内厂商在海外市场的传统优势品类,因为它们具有较高的投资回报率,是国内厂商在国际市场上取得成功的“吸金”利器。 据伽马数据发布的《2023全球移动游…

SpringCloud搭建微服务之Consul服务配置

1. 概述 前面有介绍过Consul既可以用于服务注册和发现,也可以用于服务配置,本文主要介绍如何使用Consul实现微服务的配置中心,有需要了解如何安装Consul的小伙伴,请查阅SpringCloud搭建微服务之Consul服务注册与发现 &#xff0c…

steam怎么付款

信用卡支付 登录Steam账户,选择需要购买的游戏或其他物品,点击“加入购物车”。在购物车页面点击“去结账”按钮,进入付款页面。在付款页面选择信用卡付款方式,填写信用卡信息,输入验证码,点击确认付款。 …

Servlet 新手村引入-编写一个简单的servlet项目

Servlet 新手村引入-编写一个简单的servlet项目 文章目录 Servlet 新手村引入-编写一个简单的servlet项目一、编写一个 Hello world 项目1.创建项目2.引入依赖3.手动创建一些必要的目录/文件4.编写代码5.打包程序6.部署7.验证程序 二、更方便的处理方案(插件引入&am…

autocrlf和safecrlf

git远程拉取及提交代码,windows和linux平台换行符转换问题,用以下两行命令进行配置: git config --global core.autocrlf false git config --global core.safecrlf true CRLF是windows平台下的换行符,LF是linux平台下的换行符。…

基于springboot+vue的公交线路查询系统

博主主页:猫头鹰源码 博主简介:Java领域优质创作者、CSDN博客专家、阿里云专家博主、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战,欢迎高校老师\讲师\同行交流合作 ​主要内容:毕业设计(Javaweb项目|小程序|Pyt…

Find My运动相机|苹果Find My技术与相机结合,智能防丢,全球定位

运动相机设计用于在各种运动和极限环境中使用,如徒步、登山、攀岩、骑行、滑翔、滑雪、游泳和潜水等,它们通常具有防抖防震、深度防水和高清画质的特点,能够适应颠簸剧烈的环境,甚至可以承受一定程度的摔落,一些运动相…

基于systick实现获取系统运行时间

基于systick实现获取系统运行时间 文章目录 基于systick实现获取系统运行时间systick.c代码结构:代码功能:总结 systick.c #include <stdint.h> #include "gd32f30x.h"static volatile uint64_t g_sysRunTime 0;/** ***************************************…

数学建模【聚类模型】

一、聚类模型简介 “物以类聚&#xff0c; 人以群分”&#xff0c;所谓的聚类&#xff0c;就是将样本划分为由类似的对象组成的多个类的过程。聚类后&#xff0c;我们可以更加准确的在每个类中单独使用统计模型进行估计、分析或预测&#xff0c;也可以探究不同类之间的相关性和…