简易机器学习笔记(八)关于经典的图像分类问题-常见经典神经网络LeNet

前言

图像分类是根据图像的语义信息对不同类别图像进行区分,是计算机视觉的核心,是物体检测、图像分割、物体跟踪、行为分析、人脸识别等其他高层次视觉任务的基础。图像分类在许多领域都有着广泛的应用,如:安防领域的人脸识别和智能视频分析等,交通领域的交通场景识别,互联网领域基于内容的图像检索和相册自动归类,医学领域的图像识别等。

这里简单讲讲LeNet

我的推荐是可以看看这个视频,可视化的查看卷积神经网络是如何一层一层地抽稀获得特征,最后将所有的图像展开成一个一维的轴,再通过全连接神经网络预测得到一个最后的预测值。

手写数字识别 1.4 LeNet-5-哔哩哔哩

在这里插入图片描述

计算过程

前置知识:

  1. 步长 Stride & 加边 Padding

卷积后尺寸=(输入尺寸-卷积核大小+加边像素数)/步长 + 1

默认Padding = ‘valid’ (丢弃),strides = 1
在这里插入图片描述

正式计算

  1. 卷积层1:

第一层我们给定的图像时32 * 32,使用六个5 x 5的卷积核,步长为1

第一层中没有加边,那么卷积后的尺寸就是(32 - 5 + 0 )/1 + 1 =28,那么输出的图像就是 28*28的边长

在第一层中,由于我们使用了六个卷积核,我们得到的输出为:62828,可以理解为一个六层厚的图像

  1. 池化层1:

我们在池化层内在2x2的图像内选取了一个最大值或者平均值,也就是图片整体缩水到原先的二分之一,所以我们得到池化层的输出为 6 x 14 x 14

  1. 卷积层2:

还是按照公式,卷积后尺寸=(输入-卷积核+加边像素数)/步长 + 1,这个时候输入为6 x 14 x 14,这一次我们给定了16个卷积核,得到输出后的尺寸为(14 - 5 + 0)/1 + 1 = 10,得到输出为161010

关于这个16个卷积核是怎么来的,可以见图:

问了下组里的大佬,大佬说这个卷积核数目和层数很多是经验值,即你寻求更多或者更少的卷积核数目或者层数,实际效果不一定有经验值更好,反正都是离散值,就随便试试就行了。

其中:卷积输出尺寸nout:nin为输入原图尺寸大小;s是步长(一次移动几个像素);p补零圈数,

我们这里输入的值

  1. 池化层2

得到 输出后尺寸为16 * 5 * 5

  1. 全连接层1:

输入为16 * 5 * 5 ,有120个5*5卷积核,步长为1,输出尺寸为(5 - 5 + 0)/1 + 1 =1,这时候输出的就是一条直线的一维输出了

  1. 全连接层2:

输入为120,使用了84个神经元,

  1. 输出层

输入84,输出为10

比如我们如图所示,在代码中是这样的:

# 导入需要的包
import paddle
import numpy as np
from paddle.nn import Conv2D, MaxPool2D, Linear## 组网
import paddle.nn.functional as F
from paddle.vision.transforms import ToTensor
from paddle.vision.datasets import MNIST
#定义LeNet网络结构# 定义 LeNet 网络结构
class LeNet(paddle.nn.Layer):def __init__(self, num_classes=1):super(LeNet,self).__init__()#创建卷积层和池化层#创建第一个卷积层self.conv1 = Conv2D(in_channels=1,out_channels=6,kernel_size=5)self.max_pool1 = MaxPool2D(kernel_size=2,stride=2)#尺寸的逻辑:池化层未改变通道数,当前通道为6#创建第二个卷积层self.conv2 = Conv2D(in_channels=6,out_channels=16,kernel_size=5)self.max_pool2 = MaxPool2D(kernel_size=2,stride=2)#创建第三个卷积层self.conv3 = Conv2D(in_channels=16,out_channels=120,kernel_size=4)# 尺寸的逻辑:输入层将数据拉平[B,C,H,W] -> [B,C*H*W]# 输入size是[28,28],经过三次卷积和两次池化之后,C*H*W等于120self.fc1 = Linear(in_features=120, out_features=64)# 创建全连接层,第一个全连接层的输出神经元个数为64, 第二个全连接层输出神经元个数为分类标签的类别数self.fc2 = Linear(in_features=64, out_features=num_classes)# 网络的前向计算过程def forward(self, x):x = self.conv1(x)# 每个卷积层使用Sigmoid激活函数,后面跟着一个2x2的池化x = F.sigmoid(x)x = self.max_pool1(x)x = F.sigmoid(x)x = self.conv2(x)x = self.max_pool2(x)x = self.conv3(x)# 尺寸的逻辑:输入层将数据拉平[B,C,H,W] -> [B,C*H*W]x = paddle.reshape(x, [x.shape[0], -1])x = self.fc1(x)x = F.sigmoid(x)x = self.fc2(x)return x
# 飞桨会根据实际图像数据的尺寸和卷积核参数自动推断中间层数据的W和H等,只需要用户表达通道数即可。
# 下面的程序使用随机数作为输入,查看经过LeNet-5的每一层作用之后,输出数据的形状。# 输入数据形状是 [N, 1, H, W]
# 这里用np.random创建一个随机数组作为输入数据
x = np.random.randn(*[3,1,28,28])
x = x.astype('float32')# 创建LeNet类的实例,指定模型名称和分类的类别数目
model = LeNet(num_classes=10)# 通过调用LeNet从基类继承的sublayers()函数,
# 查看LeNet中所包含的子层
print(model.sublayers())
x = paddle.to_tensor(x)for item in model.sublayers():#item是LeNet类中的一个子层#查看经过子层之后的输出数据形状try:x = item(x)except:x = paddle.reshape(x, [x.shape[0], -1])x = item(x)if len(item.parameters())==2:# 查看卷积和全连接层的数据和参数的形状,# 其中item.parameters()[0]是权重参数w,item.parameters()[1]是偏置参数bprint(item.full_name(), x.shape, item.parameters()[0].shape, item.parameters()[1].shape)else:# 池化层没有参数print(item.full_name(), x.shape)# 设置迭代轮数
EPOCH_NUM = 5
#定义训练过程 
def train(model,opt,train_loader,valid_loader):print("start training ... ")model.train()for epoch in range(EPOCH_NUM):for batch_id, data in enumerate(train_loader()):img = data[0]label = data[1] #计算模型输出# 计算模型输出logits = model(img)# 计算损失函数loss_func = paddle.nn.CrossEntropyLoss(reduction='none')loss = loss_func(logits, label)avg_loss = paddle.mean(loss)if batch_id % 2000 == 0:print("epoch: {}, batch_id: {}, loss is: {:.4f}".format(epoch, batch_id, float(avg_loss.numpy())))#反向传播avg_loss.backward()opt.step()opt.clear_grad()model.eval()accuracies = []losses = []for batch_id, data in enumerate(valid_loader()):img = data[0]label = data[1]# 计算模型输出logits = model(img)pred = F.softmax(logits)# 计算损失函数loss_func = paddle.nn.CrossEntropyLoss(reduction='none')loss = loss_func(logits, label)acc = paddle.metric.accuracy(pred, label)accuracies.append(acc.numpy())losses.append(loss.numpy())print("[validation] accuracy/loss: {:.4f}/{:.4f}".format(np.mean(accuracies), np.mean(losses)))model.train()# 保存模型参数paddle.save(model.state_dict(), 'mnist.pdparams')    # 创建模型
model = LeNet(num_classes=10)
# 设置迭代轮数
EPOCH_NUM = 5
# 设置优化器为Momentum,学习率为0.001
opt = paddle.optimizer.Momentum(learning_rate=0.001, momentum=0.9, parameters=model.parameters())
# 定义数据读取器
train_loader = paddle.io.DataLoader(MNIST(mode='train', transform=ToTensor()), batch_size=10, shuffle=True)
valid_loader = paddle.io.DataLoader(MNIST(mode='test', transform=ToTensor()), batch_size=10)
# 启动训练过程
train(model, opt, train_loader, valid_loader)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/595323.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

视频号频繁显眼!是资本的运作?还是互联网新风口到来?

视频号这个平台出现了,特别是在最近存在感越来越强,而且已经有些人开始在视频号当中购物了,这也就意味着,视频号电商出现了,腾讯也开始搞电商了。 很多人可能对视频号做电商这个事情呢,抱有一定的迟疑态度&…

【算法】数论---约数

约数里面的一个重要性质&#xff1a;一个数的约数都是成对存在的(以sqrt(x)为分界线) 一、求一个数的所有约数---试除法 int x; cin>>x; int yue[10000]{0},idx0; for(int i1;i<x/i;i) {if(x%i0){yue[idx]i;cout<<i<<" ";} }for(int iidx-1;i&…

深度学习:大规模模型分布式训练框架DeepSpeed

深度学习&#xff1a;大规模模型分布式训练框架DeepSpeed DeepSpeed简介DeepSpeed核心特点DeepSpeed如何工作&#xff1f;DeepSpeed如何使用&#xff1f;参考文献 DeepSpeed简介 随着机器学习模型变得越来越复杂和庞大&#xff0c;训练这些模型所需的计算资源也在不断增加。特别…

九州金榜|家庭教育小妙招如何培养孩子学习习惯

做小学老师的时候&#xff0c;很多家长都问过我同一个问题&#xff0c;孩子成绩差&#xff0c;如何提高孩子的成绩&#xff1f; 好像成绩是我们的家长判断孩子是否优秀的唯一标准&#xff0c;一切都是围绕着成绩说话&#xff0c;考好了表扬、鼓励&#xff0c;考不好就会被批评…

【UE5.1】给森林添加天气效果

在上一篇博客&#xff08;【UE5.1】程序化生成Nanite植被&#xff09;基础上给森林添加天气交互效果&#xff0c;角色和雪地、水坑的交互效果。 目录 效果 步骤 一、准备工作 二、添加超动态天空 2.1 修改时间 2.2 昼夜交替 三、添加超动态天气 3.1 改变天气 3.2 …

uniCloud 云数据库(新建表、增、删、改、查)

新建表结构描述文件 todo 为自定义的表名 表结构描述文件的默认后缀为 .schema.json 设置表的操作权限 uniCloud-aliyun/database/todo.schema.json 默认的操作权限都是 false "permission": {"read": false,"create": false,"update&quo…

html中的form表单以及相关控件input、文本域、下拉select等等的详细解释 ,点赞加关注持续更新~

文章目录 表单创建表单forminput 标签input标签的value属性设置input标签格式单选框多选框上传文件下拉菜单文本域设置文本域格式label 标签按钮 表单 作用&#xff1a;收集用户信息。 使用场景&#xff1a; 登录页面注册页面搜索区域 创建表单form <form action".…

DataGear 4.7.0 发布,数据可视化分析平台

DataGear专业版 1.0.0 正式发布&#xff0c;欢迎试用&#xff01; http://datagear.tech/pro/ DataGear 4.7.0 发布&#xff0c;严重漏洞和BUG修复&#xff0c;具体更新内容如下&#xff1a; 新增&#xff1a;HTTP数据集新增【编码请求地址】支持&#xff0c;可用于解决请求…

希亦、觉飞、小吉三款婴儿洗衣机大比拼!全方位对比测评

由于年龄幼小的婴儿的皮肤都非常的幼嫩&#xff0c;因此婴儿衣物材质的类型大部分都是采用为纯棉&#xff0c;并且婴儿的衣物不能够与大人的衣物一起进行混洗&#xff0c;容易把细菌感染到宝宝的衣物上&#xff0c;因此很多家庭为了保证宝宝衣服的有效清洁&#xff0c;避免交叉…

TXT文本删除第一行文本变成空要如何解决呢

首先大家一起来看下这个TXT文本里面有多行内容&#xff0c;想把开头第一行批量删除不要掉。 1..如果是一两个本可以手动删除也很方便哦&#xff0c;如果文本量比较大如几十几、几百个文本大家一直都选用《首助编辑高手》工具去批量操作哦。批量操作可以大大提高工作效率。接来看…

AI实景无人直播创业项目:开启自动直播新时代,一部手机即可实现财富增长

在当今社会&#xff0c;直播已经成为了人们日常生活中不可或缺的一部分。无论是商家推广产品、明星互动粉丝还是普通人分享生活&#xff0c;直播已经渗透到了各行各业。然而&#xff0c;传统直播方式存在着一些不足之处&#xff0c;如需现场主持人操作、高昂的费用等。近年来&a…

Minitab 各版本安装指南

Minitab下载链接 https://pan.baidu.com/s/1PLqocknkoRGGI9lbV3e45A?pwd0531 1.鼠标右击【Minitab 21(64bit)】压缩包&#xff08;win11及以上系统需先点击“显示更多选项”&#xff09;选择【解压到 Minitab 21(64bit)】。 2.打开解压后的文件夹&#xff0c;鼠标右击【setu…

MacOS - 苹果电脑程序还能正常启动,但图标消失不见了~

问题描述 网上有一些解决方案说是 killall Finder 命令&#xff0c;重置 Docker 等等&#xff0c;但是发现还是不行&#xff0c;于是必杀技…… 解决方案 方案一、删除该 App&#xff0c;重装即可方案二、如果懒得重装&#xff0c;可以在 Finder 中找到对应的应用程序&#xf…

如何把照片多余的地方擦除?一键消除图片上的瑕疵,简单又轻松,太方便了

在数字繁荣的时代&#xff0c;图片处理已然成为我们生活乐章中不可或缺的一部分&#xff0c;就如画师手中的画笔般灵动&#xff0c;摄影师镜头下的世界般多彩。然而&#xff0c;在捕捉或获取这些美丽的图片时&#xff0c;可能会不小心闯入一些不速之客&#xff0c;给画面带来瑕…

听GPT 讲Rust源代码--compiler(3)

File: rust/compiler/rustc_codegen_cranelift/src/value_and_place.rs 在Rust的编译器源代码中&#xff0c;rust/compiler/rustc_codegen_cranelift/src/value_and_place.rs文件扮演着重要的角色。它包含了与值和位置&#xff08;Place&#xff09;相关的实现和结构体定义&…

非常不错的SSH工具

Tabby 官网地址&#xff1a; Tabby - a terminal for a more modern age GitHub地址&#xff1a; GitHub - Eugeny/tabby: A terminal for a more modern age 使用说明&#xff1a; Xterminal 使用说明地址&#xff1a; 一款颜值、功能都很能打的 SSH 工具 官方地址&…

SpingBoot的项目实战--模拟电商【4.订单及订单详情的生成】

&#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 接下来看看由辉辉所写的关于SpringBoot电商项目的相关操作吧 目录 &#x1f973;&#x1f973;Welcome Huihuis Code World ! !&#x1f973;&#x1f973; 一.功能需求 二.代码编写 …

求职开挂-用Chatgpt回答面试官的提问和帮助写代码可行不?

如果戴个谷歌眼镜去面试&#xff0c;直接AI扫问题得到答案&#xff0c;算不算作弊&#xff1f; 如果未来公司面试题目都连网&#xff0c;大家用力扣刷题之类学会做了&#xff0c;找工作工资多拿个20%&#xff0c;多个3-5000一个月不是很美&#xff1f; 最近的电视剧《鸣龙少年…

数据结构—环形缓冲区

写在前面&#xff0c;2023年11月开始进入岗位&#xff0c;工作岗位是嵌入式软件工程师。2024年是上班的第一年的&#xff0c;希望今年收获满满&#xff0c;增长见闻。 数据结构—环形缓冲区 为什么要使用环形数组&#xff0c;环形数组比起原来的常规数组的优势是什么&#xf…

mysql的读写分离

MySQL 读写分离原理 读写分离就是只在主服务器上写&#xff0c;只在从服务器上读。 主数据库处理事务性操作&#xff0c;而从数据库处理 select 查询。 数据库复制被用来把主数据库上事务性操作导致的变更同步到集群中的从数据库。 常见的mysql读写分离分为以下两种 1&…