7.3 详解NiN模型--首次使用多层感知机(1x1卷积核)替换掉全连接层的模型

一.前提知识

多层感知机:由一个输入层,一个或多个隐藏层和一个输出层组成。(至少有一个隐藏层,即至少3层)

全连接层:是MLP的一种特殊情况,每个节点都与前一层的所有节点连接,全连接层可以解决线性可分问题,无法学习到非线性特征。(只有输入和输出层)

二.NiN模型特点

NiN与过去模型的区别:AlexNet和VGG对LeNet的改进在于如何扩大加深这两个模块。他们都使用了全连接层,使用全连接层就可能完全放弃表征的空间结构。
NiN放弃了使用全连接层,而是使用两个1x1卷积层(将空间维度中的每个像素视为单个样本,将通道维度视为不同特征。),相当于在每个像素的通道上分别使用多层感知机

优点:NiN去除了全连接层,可以减少过拟合,同时显著减少NiN的参数数量

三.模型架构

在这里插入图片描述

四.代码

import torch
from torch import nn
from d2l import torch as d2l
import time
def nin_block(in_channels,out_channels,kernel_size,strides,padding):return nn.Sequential(# 卷积层nn.Conv2d(in_channels,out_channels,kernel_size,strides,padding),nn.ReLU(),# 两个带有ReLU激活函数的 1x1卷积层nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU(),nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU())
net = nn.Sequential(nin_block(1,96,kernel_size=11,strides=4,padding=0),nn.MaxPool2d(3,stride=2),nin_block(96,256,kernel_size=5,strides=1,padding=2),nn.MaxPool2d(3,stride=2),nin_block(256,384,kernel_size=3,strides=1,padding=1),nn.MaxPool2d(3,stride=2),nn.Dropout(0.5),# 标签类别是10nin_block(384,10,kernel_size=3,strides=1,padding=1),# 二维自适应平均池化,不用指定池化窗口大小nn.AdaptiveAvgPool2d((1,1)),# 将(样本,通道,w,h) = (批量,10,1,1),四维的输出转成2维的输出,其形状为(批量大小,10)nn.Flatten()
)
X = torch.rand(size=(1,1,224,224))
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape:\t',X.shape)
Sequential output shape:	 torch.Size([1, 96, 54, 54])
MaxPool2d output shape:	 torch.Size([1, 96, 26, 26])
Sequential output shape:	 torch.Size([1, 256, 26, 26])
MaxPool2d output shape:	 torch.Size([1, 256, 12, 12])
Sequential output shape:	 torch.Size([1, 384, 12, 12])
MaxPool2d output shape:	 torch.Size([1, 384, 5, 5])
Dropout output shape:	 torch.Size([1, 384, 5, 5])
Sequential output shape:	 torch.Size([1, 10, 5, 5])
AdaptiveAvgPool2d output shape:	 torch.Size([1, 10, 1, 1])
Flatten output shape:	 torch.Size([1, 10])

六.不同参数训练结果

学习率是0.1的情况

# 训练模型
lr,num_epochs,batch_size = 0.1,10,128
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size,resize=224)
d2l.train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())

在这里插入图片描述

学习率是0.05的情况(提升了6个点)

'''开始计时'''
start_time = time.time()
# 训练模型
lr,num_epochs,batch_size = 0.05,10,128
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size,resize=224)
d2l.train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())
'''计时结束'''
end_time = time.time()
run_time = end_time - start_time
# 将输出的秒数保留两位小数
print(f'{round(run_time,2)}s')

在这里插入图片描述

学习率为0.01,批次等于30的情况(反而下降了)

在这里插入图片描述

思考

为什么NiN块中有两个1x1卷积层?

从NiN替换掉全连接层,使用多层感知机角度来说:
因为1个1x1卷基层相当于全连接层,两个1x1卷积层使输入和输出层中间有了隐藏层,才相当于多层感知机。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/33828.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

设计模式(6)原型模式

一、介绍 Java中自带的原型模式是clone()方法。该方法是Object的方法,native类型。他的作用就是将对象的在内存的那一块内存数据一字不差地再复制一个。我们写简单类的时候只需要实现Cloneable接口,然后调用Object::clone方法就可实现克隆功能。这样实现…

emqx-5.1.4开源版使用记录

emqx-5.1.4开源版使用记录 windows系统安装eqmx 去官网下载 emqx-5.1.4-windows-amd64.zip,然后找个目录解压 进入bin目录,执行命令启动emqx 执行命令 emqx.cmd start使用emqx 访问内置的web管理页面 浏览器访问地址 http://localhost:18083/#/dashboard/overv…

猿人学刷题系列(第一届比赛)——第二题( js 混淆 - 动态cookie 1)

题目:提取全部5页发布日热度的值,计算所有值的加和 地址:https://match.yuanrenxue.cn/match/2 思路分析 本题我们会简单说一下两种不同的方式去处理,一种是不还原混淆代码直接从源代码硬扣生成逻辑,另一种则是还原…

【golang】库源码文件

库源码文件是不能被直接运行的源码文件,它仅用于存放程序实体,这些程序实体可以被其他代码使用(只要遵从Go语言规范的话)。 这里的“其他代码”可以与被使用的程序实体在同一个源码文件内,也可以在其他源码文件&#x…

Android JNI3--JNI基础

1,C预处理器 C 预处理器不是编译器的组成部分,但是它是编译过程中一个单独的步骤。简言之,C 预处理器只不过是一个文本替换工具而已,它们会指示编译器在实际编译之前完成所需的预处理。我们将把 C 预处理器(C Preproc…

数据通信——VRRP

引言 之前把实验做了,结果发现我好像没有写过VRRP的文章,连笔记都没记过。可能是因为对STP的记忆,导致现在都没忘太多。 一,什么是VRRP VRRP全名是虚拟路由冗余协议,虚拟路由,看名字就知道这是运行在三层接…

Linux 基础(三)常用命令-文件目录

常用命令 man 和 --help文件(夹)文件补充命令 man 和 --help 查看命令帮助信息 man 命令命令 --helphelp 命令(适用于绑定在shell里面的一些基础命令) 文件(夹) mkdir -p dirrmdir -p dirrm -rf …cp […

Solr的入门使用

Solr是Apache下的一个顶级开源项目,采用Java开发,它是基于Lucene的全文搜索服务器。Solr提供了比Lucene更为丰富的查询语言,同时实现了可配置、可扩展,并对索引、搜索性能进行了优化,被很多需要搜索的网站中广泛使用。…

人员数量统计和跟踪的实现原理详细介绍,并提供完整实现代码

人数统计和跟踪项目的方法 目标是构建一个具有以下功能的系统。 读取视频中的帧。在输入框架上绘制所需的参考线。使用对象检测模型检测人员。标记检测到的人的质心。跟踪该标记质心的运动。计算质心移动的方向(向上还是向下移动)。计算进入或离开参考线的人数。根据计数,递…

C语言三子棋小游戏--数组的应用

注:在最后面,完整源码会以两种形式展现。在讲解时,以三个源文件的形式。 前言:三子棋,顾名思义,就是三个子连在一起就可以胜出。在本节我们要介绍的三子棋模式是这样子的:在键盘输入坐标&#x…

HTML详解连载(3)

HTML详解连载(3) 专栏链接 [link](http://t.csdn.cn/xF0H3)下面进行专栏介绍 开始喽表单作用使用场景 input标签基本使用示例type属性值以及说明 input标签占位文本示例注意 单选框 radio代码示例 多选框-checkbox注意代码示例 文本域作用标签&#xff1…

【前端 | CSS】flex布局

基本概念 Flexible模型,通常被称为 flexbox,是一种一维的布局模型。它给 flexbox 的子元素之间提供了强大的空间分布和对齐能力 我们说 flexbox 是一种一维的布局,是因为一个 flexbox 一次只能处理一个维度上的元素布局,一行或者…

R语言4_安装BayesSpace

环境Ubuntu22/20, R4.1 你可能会报错说你的R语言版本没有这个库,但其实不然。这是一个在Bioconductor上的库。 同时我也碰到了这个问题,ERROR: configuration failed for package systemfonts’等诸多类似问题,下面的方法可以一并解决。 第…

Mr. Cappuccino的第61杯咖啡——Spring之BeanPostProcessor

Spring之BeanPostProcessor 概述基本使用项目结构项目代码运行结果源代码 常用处理器项目结构项目代码执行结果 概述 BeanPostProcessor:Bean对象的后置处理器,负责对已创建好的bean对象进行加工处理; BeanPostProcessor中的两个核心方法&am…

[内网渗透]CFS三层靶机渗透

文章目录 [内网渗透]CFS三层靶机渗透网络拓扑图靶机搭建Target10x01.nmap主机探活0x02.端口扫描0x03.ThinkPHP5 RCE漏洞拿shell0x04.上传msf后门(reverse_tcp)反向连接拿主机权限 内网渗透Target2(1)路由信息探测(2)msf代理配置&a…

LeetCode150道面试经典题--找出字符串中第一个匹配项的下标(简单)

1.题目 给你两个字符串 haystack 和 needle ,请你在 haystack 字符串中找出 needle 字符串的第一个匹配项的下标(下标从 0 开始)。如果 needle 不是 haystack 的一部分,则返回 -1 。 2.示例 3.思路 回溯算法:首先将…

LeetCode[1122]数组的相对排序

难度:Easy 题目: 给你两个数组,arr1 和 arr2,arr2 中的元素各不相同,arr2 中的每个元素都出现在 arr1 中。 对 arr1 中的元素进行排序,使 arr1 中项的相对顺序和 arr2 中的相对顺序相同。未在 arr2 中出现…

【mysql】MySQL CUP过高如何排查?

文章目录 一. 问题锁定二. QPS激增会导致CPU飘高三. 慢SQL会导致CPU飘高四. 大量空闲连接会导致CPU飘高五. MySQL问题排查常用命令 一. 问题锁定 通过top命令查看服务器CPU资源使用情况,明确CPU占用率较高的是否是mysqld进程,如果是则可以明确CUP飘高的原…

[ubuntu]创建root权限的用户 该用户登录后自动切换为root用户

一、创建新用户 1、创建新用户 sudo useradd -r -m -s /bin/bash 用户名 # -r:建立系统账号 -m:自动建立用户的登入目录 -s:指定用户登入后所使用的shell2、手动为用户设置密码 passwd 用户名 二、为用户增加root权限 1、添加写权限 ch…

【redis基础】

目录 一、概述 1.NoSQL 1.1 简述 1.2 类型 1.3 应用场景 1.3.1 缓存 1.3.2 分布式锁 1.3.3 计数器 1.3.4 会话管理 1.3.5 消息队列 2.Redis 2.1 简述 2.2 特性 2.3 监听端口号 2.4 数据类型 二、安装 1.编译安装 2.RPM安装 三、目录结构 1.查看 2.主配置文…