Pytorch深度学习快速入门—LeNet简单介绍(附代码)

一、网络模型结构

        LeNet是具有代表性的CNN,在1998年被提出,是进行手写数字识别的网络,是其他深度学习网络模型的基础。如下图所示,它具有连狙的卷积层和池化层,最后经全连接层输出结果。

二、各层参数详解

2.1 INPUT层-输入层

        数据input层,输入图像的尺寸为:32*32大小的一维一通道图片。

        注意:①灰度图像是单通道图像,其中每个像素只携带有关光强度的信息;

                   ②RGB图像是彩色图像,为三通道图像;

                   ③传统上输入层不被视为网络层次结构之一,因此输入层不算LeNet的网络结构。

2.2 C1层-卷积层

       输入数据(输入特征图input feature map):32*32

       卷积核大小:5*5

计算公式:

height_{out}=\frac{height_{in}-height_{kernel}+2*padding}{stride}+1width_{out}=\frac{width_{out}-widtht_{kernel}+2*padding}{stride}+1

其中,height_{in}是指输入图片的高度;width_{in}是指输入图片的宽度;height_{kernel}是指卷积核的大小;padding是指向图片外面补边,默认为0;S是指步长,卷积核遍历图片的步长,默认为1。

       卷积核种类(通道数):6

       输出数据(输出特征图output feature map):28*28

2.3 S2层-池化层(下采样层)

       池化是缩小高、长方向上的空间的运算。

       输入数据:28*28

       采样区域:2*2

       采样种类(通道数):6

       输出数据:14*14

注意:①经过池化运算,输入数据和输出数据的通道数不会发生变化。

②此时,S2中每个特征图的大小是C1中每个特征图大小的1/4.

2.4 C3层-卷积层

       输入数据:S2中所有6个或者几个特征map组合

       卷积核大小:5*5

       卷积核种类(通道数):16

       输出数据(输出特征图output feature map):10*10

注意:C3中的每个特征map是连接到S2中的所有6个或者几个特征map的,表示本层的特征map是上一层提取到的特征map的不同组合。

2.5 S4层-池化层(下采样层)

       输入数据:10*10

       采样区域:2*2

       采样种类(通道数):16

       输出数据:5*5

2.6 C5层-卷积层

       输入数据:S4层的全部16个单元特征map(与s4全相连)

       卷积核大小:5*5

       卷积核种类(通道数):120

       输出数据(输出特征图output feature map):1*1

2.7 F6层-全连接层

       输入数据:120维向量

       输出数据:84维向量

2.2 Output层-全连接层

       输入数据:84维向量

       输出数据:10维向量

三、代码实现(采用的激活函数为relu函数)

3.1 搭建网络框架

(1)导包:

import torch
import torch.nn as nn
import torch.nn.functional as F

 (2)定义卷积神经网络:由于训练数据采用的是彩色图片(三通道),因此与上面介绍的通道数有出入。

class Net(nn.Module):def __init__(self):super(Net,self).__init__()self.conv1 = nn.Conv2d(3,6,5)self.conv2 = nn.Conv2d(6,16,5)self.fc1 = nn.Linear(16*5*5,120)self.fc2 = nn.Linear(120,84)self.fc3 = nn.Linear(84,10)def forward(self,x):x = self.conv1(x)x = F.relu(x)x = F.max_pool2d(x,(2,2))x = F.max_pool2d(F.relu(self.conv2(x)),2)x = x.view(-1,x.size()[1:].numel())x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x

(3)测试网络效果:相当于打印初始化部分

net = Net()
print(net)

3.2 定义数据集

(1)导包:

import torchvision
import torchvision.transforms as transforms

(2)下载数据集:

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])trainset = torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=0)
testloader = torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False,num_workers=0)

(3)定义元组:进行类别名的中文转换

classes = ('airplane','automobile','bird','car','deer','dog','frog','horse','ship','truck')

 (4)运行数据加载器:使用绘图函数查看数据加载效果

import matplotlib.pyplot as plt
import numpy as npdef imshow(img):img = img / 2 + 0.5npimg = img.numpy()plt.imshow(np.transpose(npimg,(1,2,0)))plt.show()dataiter = iter(trainloader)
images,labels = dataiter.next()imshow(torchvision.utils.make_grid(images))print(labels)
print(labels[0],classes[labels[0]])
print(' '.join(classes[labels[j]] for j in range(4)))

3.3 定义损失函数与优化器

(1)定义损失函数:交叉熵损失函数

criterion = nn.CrossEntropyLoss()

(2)定义优化器:让网络进行更新,不断更新好的参数,达到更好的效果

import torch.optim as optim
optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9)

3.4 训练网络

for epoch in range(2):running_loss = 0.0for i,data in enumerate(trainloader,0):inputs,labels = dataoptimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs,labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 2000 == 1999:print('[%d,%5d] loss:%.3f' % (epoch + 1,i+1,running_loss/2000))running_loss = 0.0print("Finish")

3.5 测试网络

(1)保存学习好的网络参数:将权重文件保存到本地

PATH='./cifar_net.pth'
torch.save(net.state_dict(),PATH)

(2) 测试一组图片的训练效果

dataiter = iter(testloader)
images,labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))
print('GroundTruth:',' '.join('%5s'% classes[labels[j]] for j in range(4)))

(3)观察整个训练集的测试效果

correct = 0
total = 0
with torch.no_grad():for data in testloader:images,labels = dataoutputs = net(images)_,predicted = torch.max(outputs,1)total += labels.size(0)correct += (predicted == labels).sum().item()correctGailv = 100*(correct / total)
print(correctGailv)

四、小结

        与“目前的CNN”相比,LeNet有以下几个不同点:

        ①激活函数不同:LeNet使用sigmoid函数,而目前的CNN中主要使用ReLU函数;

        ②原始的LeNet中使用子采样(subsampling)缩小中间数据的大小,而目前的CNN中Max池化是主流。

参考:LeNet详解-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/115984.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++之函数重载【详解】

C之函数重载【详解】 1. 函数重载的概念2. C支持函数重载的原理(名字修饰)2.1 前言2.2 函数名修饰规则2.3 VS下的命名修饰规则 重载函数是函数的一种特殊情况,为方便使用,C允许在同一中声明几个功能类似的同名函数,但是这些同名函数的形式参数…

HarmonyOS 音频开发指导:使用 AudioRenderer 开发音频播放功能

AudioRenderer 是音频渲染器,用于播放 PCM(Pulse Code Modulation)音频数据,相比 AVPlayer 而言,可以在输入前添加数据预处理,更适合有音频开发经验的开发者,以实现更灵活的播放功能。 开发指导…

Redis --- 安装教程

Redis--- 特性,使用场景,安装 安装教程在Ubuntu下安装在Centos7.6下安装Redis5 特性在内存中存储数据可编程的扩展能力持久化集群高可用快速 应用场景实时数据存储作为缓存或者Session存储消息队列 安装教程 🚀安装之前切换到root用户。 在…

Amazon图片下载器:利用Scrapy库完成图像下载任务

概述 本文介绍了如何使用Python的Scrapy库编写一个简单的爬虫程序,实现从Amazon网站下载商品图片的功能。Scrapy是一个强大的爬虫框架,提供了许多方便的特性,如选择器、管道、中间件、代理等。本文将重点介绍如何使用Scrapy的图片管道和代理…

mysql下载和安装,使用

先下载安装 官方下载 已下载备份软件 安装,一路下一步设置环境变量 4. 打开一个cmd,输入mysql -u root -p

dns服务

安装 apt install bind9 bind9-utils 监听53端口 udp53做解析用的 tcp53端口 创建配置文件 [rootrocky8 ~]# cd /var/named/ 注意权限,不然不生效 [rootrocky8 named]# touch luohw.org.zone [rootrocky8 named]# chmod 640 luohw.org.zone [rootrocky8 named]# c…

MySql数据库实现注册登录及个人信息查询的数据库设计

前言: 数据库使用的是mysql 以下创建的表,实现以下功能: 用户1,账号admin,年龄20,关联3件商品 用户2,账号admin2,年龄30,关联2件商品(没有商品和用户1重复) 用户3,账号admin3,年龄50,关联2件商品(这两件商品均是用户1的其中两种) 登录查询对应数据的实现 1.创建用户表Users,并…

【Java集合类面试十二】、HashMap为什么线程不安全?

文章底部有个人公众号:热爱技术的小郑。主要分享开发知识、学习资料、毕业设计指导等。有兴趣的可以关注一下。为何分享? 踩过的坑没必要让别人在再踩,自己复盘也能加深记忆。利己利人、所谓双赢。 面试官:HashMap为什么线程不安全…

【ALO-BP预测】基于蚁狮算法优化BP神经网络回归预测研究(Matlab代码实现)

💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…

vue3 element-plus 组件table表格 勾选框回显(初始化默认回显)完整静态代码

<template><el-table ref"multipleTableRef" :data"tableData" style"width: 100%"><el-table-column type"selection" width"55" /><el-table-column label"时间" width"120">…

Linux 中监控磁盘分区使用情况的 10 个工具

在本文[1]中&#xff0c;我们将回顾一些可用于检查 Linux 中磁盘分区的 Linux 命令行实用程序。 监控存储设备的空间使用情况是系统管理员最重要的任务之一&#xff0c;它可以确保存储设备上有足够的可用空间&#xff0c;以维持 Linux 系统的高效运行。 1. fdisk fdisk 是一个强…

Mysql事务+redo日志+锁分类+隔离机制+mvcc

事务&#xff1a; 是数据库操作的最小工作单元&#xff0c;是作为单个逻辑工作单元执行的一系列操作&#xff1b;这些操作作为一个整体一起向系统提交&#xff0c;要么都执行、要么都不执行&#xff1b;事务是一组不可再分割的操作集合&#xff08;工作逻辑单元&#xff09;&a…

前端导出数据到Excel(Excel.js导出数据)

库&#xff1a;Excel.js&#xff08;版本4.3.0&#xff09; 和 FileSaver&#xff08;版本2.0.5&#xff09; CDN地址&#xff1a; <script src"https://cdn.bootcdn.net/ajax/libs/exceljs/4.3.0/exceljs.min.js"></script> <script src"http…

volatile-可见性案例详解

6.3 volatile特性 6.3.1 保证可见性 保证不同线程对某个变量完成操作后结果及时可见&#xff0c;即该共享变量一旦改变所有线程立即可见 不加volatile&#xff0c;没有可见性&#xff0c;程序无法停止 加了volatile&#xff0c;保证可见性&#xff0c;程序可以停止 public…

031-第三代软件开发-屏幕保护

第三代软件开发-屏幕保护 文章目录 第三代软件开发-屏幕保护项目介绍屏幕保护 关键字&#xff1a; Qt、 Qml、 MediaPlayer、 VideoOutput、 function 项目介绍 欢迎来到我们的 QML & C 项目&#xff01;这个项目结合了 QML&#xff08;Qt Meta-Object Language&#…

Rowset Class

Rowset类在PeopleCode中非常常见&#xff0c;以下将Rowset翻译成行集&#xff0c;顾名思义&#xff0c;行的集合 目录 Understanding Rowset Class Shortcut Considerations Data Type of a Rowset Object Scope of a Rowset Object Rowset Class Built-In Functions Row…

SysTick—系统定时器

SysTick 简介 SysTick—系统定时器是属于CM3内核中的一个外设&#xff0c;内嵌在NVIC中。系统定时器是一个24bit 的向下递减的计数器&#xff0c;计数器每计数一次的时间为1/SYSCLK&#xff0c;一般我们设置系统时钟SYSCLK 等于72M。当重装载数值寄存器的值递减到0的时候&#…

SpringBoot+Mybatis 配置多数据源及事务管理

目录 1.多数据源 2.事务配置 项目搭建参考: 从零开始搭建SpringBoot项目_从0搭建springboot项目-CSDN博客 SpringBoot学习笔记(二) 整合redismybatisDubbo-CSDN博客 1.多数据源 添加依赖 <dependencies><dependency><groupId>org.springframework.boot&…

TCP--拥塞控制

大家好&#xff0c;我叫徐锦桐&#xff0c;个人博客地址为www.xujintong.com。平时记录一下学习计算机过程中获取的知识&#xff0c;还有日常折腾的经验&#xff0c;欢迎大家来访。 TCP中另一个重要的点就是拥塞控制&#xff0c;TCP是无私的当它感受到网络拥堵了&#xff0c;就…

字节码进阶之javassist字节码操作类库详解

字节码进阶之javassist字节码操作类库详解 文章目录 前言使用教程添加Javassist依赖库创建和修改类方法拦截创建新的方法 进阶用法创建新的注解创建新的接口创建新的构造器生成动态代理修改方法示例2 前言 Javassist&#xff08;Java programming assistant&#xff09;是一个…