残差网络 ResNet

目录

1.1 ResNet

2.代码实现


1.1 ResNet

如上图函数的大小代表函数的复杂程度,星星代表最优解,可见加了更多层之后的预测比小模型的预测离真实最优解更远了, ResNet做的事情就是使得模型加深一定会使效果变好而不是变差。

2.代码实现

import torch 
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lclass Residual(nn.Module):def __init__(self,input_channels,num_channels,use_1x1conv=False,strides=1):super().__init__()self.conv1=nn.Conv2d(input_channels,num_channels,kernel_size=3,padding=1,stride=strides)self.conv2=nn.Conv2d(input_channels,num_channels,kernel_size=3,padding=1)#以上两个卷积都保证了输入输出得大小不变if use_1x1conv:self.conv3=nnn.Conv2d(input_channels,num_channels,kernel_size=1,stride=strides)else:self.conv=Noneself.bn1=nn.BatchNorm2d(num_channels)self.bn2=nn.BatchNorm2d(num_channels)self.relu=nn.ReLU(inplace=True)#inplace=True表示原地操作def forward(self,X):Y=F.relu(self.bn1(self.conv1(X)))Y=self.bn2(self.conv2(Y))if self.conv3:X=self.conv3(X)Y+=Xreturn F.relu(Y)#查看输入和输出形状一致的情况。
blk=Residual(3)
blk.initialize()
X = np.random.uniform(size=(4, 3, 6, 6))
Y=blk(X)
Y.shape
"""结果输出:
(4, 3, 6, 6)""""""在增加输出通道数的同时,减半输出的高和宽。"""
blk=Residul(3,6,use_1x1conv=True,strides=2)
blk.initialize()
blk(X).shape
"""结果输出:
(4, 6, 3, 3)""""""ResNet模型"""
#ResNet的前两层跟之前介绍的GoogLeNet中的一样: 在输出通道数为64、步幅为2的7*7卷积层后,
#接步幅为2的3*3的最大汇聚层。 不同之处在于ResNet每个卷积层后增加了批量规范化层。
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))#ResNet则使用4个由残差块组成的模块,每个模块使用若干个同样输出通道数的残差块。 
#第一个模块的通道数同输入通道数一致。 由于之前已经使用了步幅为2的最大汇聚层,所以无须
#减小高和宽。之后的每个模块在第一个残差块里将上一个模块的通道数翻倍,并将高和宽减半。#注意,我们对第一个模块做了特别处理。
def resnet_block(input_channels,num_channels,num_residuals,first_block=False):blk=[]for i in range(num_residuals):#num_residuals等于2if i==0 and not first_block:#first_block此时等于False,说明不是第一个模块,第一个模块的输入已经减半了blk.append(Residual(input_channels,num_channels,use_1x1conv=Truestrides=2))#除开第一个模块,其余每个模块的第一个残差块都strides=2高宽减半#还有输出和输入通道数的变化else:blk.append(Residual(num_channels,num_channels))#其余的所有模块的第二个残差块和第一个模块输入和输出通道数不变return blk#接着在ResNet加入所有残差块,这里每个模块使用2个残差块。
b2=nn.Sequential(*resnet_block(64,64,2,first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))#在ResNet中加入全局平均汇聚层,以及全连接层输出。
net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(), nn.Linear(512, 10))#在训练ResNet之前,让我们观察一下ResNet中不同模块的输入形状是如何变化的。在之前所有架构中,
#分辨率降低,通道数量增加,直到全局平均汇聚层聚集所有特征。
X = torch.rand(size=(1, 1, 224, 224))
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape:\t', X.shape)
"""结果输出:
Sequential output shape:     torch.Size([1, 64, 56, 56])
Sequential output shape:     torch.Size([1, 64, 56, 56])
Sequential output shape:     torch.Size([1, 128, 28, 28])
Sequential output shape:     torch.Size([1, 256, 14, 14])
Sequential output shape:     torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d output shape:      torch.Size([1, 512, 1, 1])
Flatten output shape:        torch.Size([1, 512])
Linear output shape:         torch.Size([1, 10])""""""训练模型"""
lr, num_epochs, batch_size = 0.05, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
"""结果输出:
loss 0.012, train acc 0.997, test acc 0.893
5032.7 examples/sec on cuda:0"""

参考:

inplace=True (原地操作)-CSDN博客

Python中initialize的全面讲解_笔记大全_设计学院 (python100.com)

python 中类的初始化方法_python initialize(self)-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/629285.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

网页设计(九)JavaScript基础应用

一、网页中文字的字号选择性改变 单击前初始状态页面 单击“中”链接后页面 文字素材:   JavaScript是一种能让你的网页更加生动活泼的程式语言,也是目前网页中设计中最容易学又最方便的语言。你可以利用JavaScript轻易的做出亲切的欢迎讯息、漂亮的…

web前端第二次

第一题&#xff1a; <!DOCTYPE html> <html> <head><title>计算奇数和</title> </head> <body><label for"input">请输入一个正整数&#xff1a;</label><input type"number" id"input&qu…

影响CSGO搬砖饰品价格上涨和下跌的原因有哪些

到底哪些情况下CSGO饰品价格会涨&#xff0c;哪些情况会跌&#xff0c;下面是一个混迹steam平台多年的老油条&#xff0c;一点个人见解&#xff0c;不喜吻喷。 首先&#xff0c;CSGO饰品的交易是从市场进行的&#xff0c;市场终究是市场&#xff0c;是自由买卖的&#xff0c;必…

VMware Vsphere 日志:用户 dcui@127.0.01已以vMware-client/6.5.0 的身份登录

一、事件截图&#xff1a; 二、解决办法 原因&#xff1a; 三、解决办法 1.开启锁定模式 2.操作 1、从清单中选择您的 ESXi 主机&#xff0c;然后转至管理 > 设置 > 安全配置文件&#xff0c;然后单击锁定模式的编辑按钮 2、在打开的锁定模式窗口中&#xff0c;选中启…

云服务器 云服务器概述-产品简介-文档中心-腾讯云

腾讯云服务器入门教程包括云服务器CPU内存带宽配置选择&#xff0c;选择云服务器CVM或轻量应用服务器&#xff0c;云服务器创建后重置密码、远程连接、搭建程序环境等&#xff0c;腾讯云服务器网txyfwq.com分享从0到1腾讯云服务器入门教程&#xff1a; 腾讯云服务器入门教程 …

CSDN 年度总结|知识改变命运,学习成就未来

欢迎来到英杰社区&#xff1a; https://bbs.csdn.net/topics/617804998 欢迎来到阿Q社区&#xff1a; https://bbs.csdn.net/topics/617897397 &#x1f4d5;作者简介&#xff1a;热爱跑步的恒川&#xff0c;致力于C/C、Java、Python等多编程语言&#xff0c;热爱跑步&#xff…

「许战海矩阵战略洞察」吉香居给调味品企业带来的战略启示

引言&#xff1a;吉香居通过实施份额化战略和打造形象产品&#xff0c;在调味品行业中取得了成功。但品牌结构需要调整&#xff0c;需要将子品牌整合到吉香居主品牌下&#xff0c;共同提升品牌势能。此外&#xff0c;企业需保持主品牌竞争战略&#xff0c;以实现长期稳定的高速…

transfomer中正余弦位置编码的源码实现

简介 Transformer模型抛弃了RNN、CNN作为序列学习的基本模型。循环神经网络本身就是一种顺序结构&#xff0c;天生就包含了词在序列中的位置信息。当抛弃循环神经网络结构&#xff0c;完全采用Attention取而代之&#xff0c;这些词序信息就会丢失&#xff0c;模型就没有办法知…

进阶Docker4:网桥模式、主机模式与自定义网络

目录 网络相关 子网掩码 网关 规则 docke网络配置 bridge模式 host模式 创建自定义网络(自定义IP) 网络相关 IP 子网掩码 网关 DNS 端口号 子网掩码 互联网是由许多小型网络构成的&#xff0c;每个网络上都有许多主机&#xff0c;这样便构成了一个有层次的结构。 IP 地…

SpringAOP-说说 Spring AOP 和 AspectJ AOP 区别

Spring AOP Spring AOP 属于运行时增强&#xff0c;主要具有如下特点&#xff1a; 基于动态代理来实现&#xff0c;默认如果使用接口的&#xff0c;用 JDK 提供的动态代理实现&#xff0c;如果是方法则使用 CGLIB 实现Spring AOP 需要依赖 IOC 容器来管理&#xff0c;并且只能…

浅谈安科瑞铁塔/基站电力监控解决方案

I.背景信息&#xff1a; 2020年5G元年&#xff0c;通信行业承蓬勃发展之态&#xff0c;各大运营商和铁塔集团在布局新一代通讯基站。基站用电量不断上升&#xff0c;通信基站智能化电力监控及节能管理已成为各运营商企业的研究方向。 而同时&#xff0c;目前铁塔基站电力使用…

靶机-basic_pentesting_2

basic_pentesting_2 arp-scan -l查找靶机IP masscan 192.168.253.154 --ports 0-65535 --rate10000 端口扫描 nmap扫描nmap -T5 -A -p- 192.168.253.154 目录扫描80端口 http://192.168.253.154/development/dev.txt 2018-04-23: I’ve been messing with that struts stu…

mipi协议

完成mipi信号通道分配后&#xff0c;需要生成与物理层对接的时序、同步信号&#xff1a; MIPI规定&#xff0c;传输过程中&#xff0c;包内是200mV、包间以及包启动和包结束时是1.2V&#xff0c;两种不同的电压摆幅&#xff0c;需要两组不同的LVDS驱动电路在轮流切换工作&#…

数据集成时表模型同步方法解析

01 背景介绍 数据治理的第一步&#xff0c;也是数据中台的一个基础功能 — 即将来自各类业务数据源的数据&#xff0c;同步集成至中台 ODS 层。业务数据源多种多样&#xff0c;单单可能涉及到的主流关系型数据库就有近十种。功能更加全面的数据中台通常还具有对接非关系型数据…

mac查看maven版本报错:The JAVA_HOME environment variable is not defined correctly

终端输入mvn -version报错: The JAVA_HOME environment variable is not defined correctly, this environment variable is needed to run this program. Java环境变量的问题,打开bash_profile查看 open ~/.bash_profile export JAVA_8_HOME/Library/Java/JavaVirtualMachine…

Python图像处理【18】边缘检测详解

边缘检测详解 0. 前言1. 图像导数2. LoG/zero-crossing2.1 Marr-Hildteth 算法 3. Canny 与 holistically-nested 算法3.1 Canny 边缘检测3.2 holistically-nested 边缘检测 小结系列链接 0. 前言 边缘是图像中两个区域之间具有相对不同灰级特性的边界&#xff0c;或者说是亮度…

应用案例 | Softing工业物联网连接解决方案助力汽车零部件供应商实现智能制造升级

随着业务的扩展和技术的进步&#xff0c;某国际先进汽车零部件供应商在其工业物联网的升级方案中使用了Softing的dataFEED OPC Suite——通过MQTT协议将现场控制器和数控系统的数据上传到其物联网云平台&#xff0c;从而实现了设备状态的远程监控&#xff0c;不仅能够提前发现设…

【机器学习300问】9、梯度下降是用来干嘛的?

当你和我一样对自己问出这个问题后&#xff0c;分析一下&#xff01;其实我首先得知道梯度下降是什么&#xff0c;也就它的定义。其次我得了解它具体用在什么地方&#xff0c;也就是使用场景。最后才是这个问题&#xff0c;梯度下降有什么用&#xff1f;怎么用&#xff1f; 所以…

C语言第一弹---C语言基本概念(上)

✨个人主页&#xff1a; 熬夜学编程的小林 &#x1f497;系列专栏&#xff1a; 【C语言详解】 【数据结构详解】 C语言基本概念 1、C语言是什么&#xff1f;2、C语言的历史和辉煌3、编译器的选择VS20223.1、编译和链接3.2、编译器对比3.3、VS2022优缺点 4、VS项目和源文件、头…

test0117测试1

欢迎关注博主 Mindtechnist 或加入【Linux C/C/Python社区】一起探讨和分享Linux C/C/Python/Shell编程、机器人技术、机器学习、机器视觉、嵌入式AI相关领域的知识和技术。 磁盘满的本质分析 专栏&#xff1a;《Linux从小白到大神》 | 系统学习Linux开发、VIM/GCC/GDB/Make工具…