残差网络 ResNet

目录

1.1 ResNet

2.代码实现


1.1 ResNet

如上图函数的大小代表函数的复杂程度,星星代表最优解,可见加了更多层之后的预测比小模型的预测离真实最优解更远了, ResNet做的事情就是使得模型加深一定会使效果变好而不是变差。

2.代码实现

import torch 
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lclass Residual(nn.Module):def __init__(self,input_channels,num_channels,use_1x1conv=False,strides=1):super().__init__()self.conv1=nn.Conv2d(input_channels,num_channels,kernel_size=3,padding=1,stride=strides)self.conv2=nn.Conv2d(input_channels,num_channels,kernel_size=3,padding=1)#以上两个卷积都保证了输入输出得大小不变if use_1x1conv:self.conv3=nnn.Conv2d(input_channels,num_channels,kernel_size=1,stride=strides)else:self.conv=Noneself.bn1=nn.BatchNorm2d(num_channels)self.bn2=nn.BatchNorm2d(num_channels)self.relu=nn.ReLU(inplace=True)#inplace=True表示原地操作def forward(self,X):Y=F.relu(self.bn1(self.conv1(X)))Y=self.bn2(self.conv2(Y))if self.conv3:X=self.conv3(X)Y+=Xreturn F.relu(Y)#查看输入和输出形状一致的情况。
blk=Residual(3)
blk.initialize()
X = np.random.uniform(size=(4, 3, 6, 6))
Y=blk(X)
Y.shape
"""结果输出:
(4, 3, 6, 6)""""""在增加输出通道数的同时,减半输出的高和宽。"""
blk=Residul(3,6,use_1x1conv=True,strides=2)
blk.initialize()
blk(X).shape
"""结果输出:
(4, 6, 3, 3)""""""ResNet模型"""
#ResNet的前两层跟之前介绍的GoogLeNet中的一样: 在输出通道数为64、步幅为2的7*7卷积层后,
#接步幅为2的3*3的最大汇聚层。 不同之处在于ResNet每个卷积层后增加了批量规范化层。
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))#ResNet则使用4个由残差块组成的模块,每个模块使用若干个同样输出通道数的残差块。 
#第一个模块的通道数同输入通道数一致。 由于之前已经使用了步幅为2的最大汇聚层,所以无须
#减小高和宽。之后的每个模块在第一个残差块里将上一个模块的通道数翻倍,并将高和宽减半。#注意,我们对第一个模块做了特别处理。
def resnet_block(input_channels,num_channels,num_residuals,first_block=False):blk=[]for i in range(num_residuals):#num_residuals等于2if i==0 and not first_block:#first_block此时等于False,说明不是第一个模块,第一个模块的输入已经减半了blk.append(Residual(input_channels,num_channels,use_1x1conv=Truestrides=2))#除开第一个模块,其余每个模块的第一个残差块都strides=2高宽减半#还有输出和输入通道数的变化else:blk.append(Residual(num_channels,num_channels))#其余的所有模块的第二个残差块和第一个模块输入和输出通道数不变return blk#接着在ResNet加入所有残差块,这里每个模块使用2个残差块。
b2=nn.Sequential(*resnet_block(64,64,2,first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))#在ResNet中加入全局平均汇聚层,以及全连接层输出。
net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(), nn.Linear(512, 10))#在训练ResNet之前,让我们观察一下ResNet中不同模块的输入形状是如何变化的。在之前所有架构中,
#分辨率降低,通道数量增加,直到全局平均汇聚层聚集所有特征。
X = torch.rand(size=(1, 1, 224, 224))
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape:\t', X.shape)
"""结果输出:
Sequential output shape:     torch.Size([1, 64, 56, 56])
Sequential output shape:     torch.Size([1, 64, 56, 56])
Sequential output shape:     torch.Size([1, 128, 28, 28])
Sequential output shape:     torch.Size([1, 256, 14, 14])
Sequential output shape:     torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d output shape:      torch.Size([1, 512, 1, 1])
Flatten output shape:        torch.Size([1, 512])
Linear output shape:         torch.Size([1, 10])""""""训练模型"""
lr, num_epochs, batch_size = 0.05, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
"""结果输出:
loss 0.012, train acc 0.997, test acc 0.893
5032.7 examples/sec on cuda:0"""

参考:

inplace=True (原地操作)-CSDN博客

Python中initialize的全面讲解_笔记大全_设计学院 (python100.com)

python 中类的初始化方法_python initialize(self)-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/629285.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringMVC-05

Spring MVC拦截器是在请求到达处理器前或返回客户端前执行的组件,它可以用于拦截和处理请求,实现一些通用的功能。 Spring MVC拦截器可以实现的功能包括: 登录验证:拦截所有请求,检查用户是否已登录,未登录…

网页设计(九)JavaScript基础应用

一、网页中文字的字号选择性改变 单击前初始状态页面 单击“中”链接后页面 文字素材:   JavaScript是一种能让你的网页更加生动活泼的程式语言,也是目前网页中设计中最容易学又最方便的语言。你可以利用JavaScript轻易的做出亲切的欢迎讯息、漂亮的…

web前端第二次

第一题&#xff1a; <!DOCTYPE html> <html> <head><title>计算奇数和</title> </head> <body><label for"input">请输入一个正整数&#xff1a;</label><input type"number" id"input&qu…

影响CSGO搬砖饰品价格上涨和下跌的原因有哪些

到底哪些情况下CSGO饰品价格会涨&#xff0c;哪些情况会跌&#xff0c;下面是一个混迹steam平台多年的老油条&#xff0c;一点个人见解&#xff0c;不喜吻喷。 首先&#xff0c;CSGO饰品的交易是从市场进行的&#xff0c;市场终究是市场&#xff0c;是自由买卖的&#xff0c;必…

VMware Vsphere 日志:用户 dcui@127.0.01已以vMware-client/6.5.0 的身份登录

一、事件截图&#xff1a; 二、解决办法 原因&#xff1a; 三、解决办法 1.开启锁定模式 2.操作 1、从清单中选择您的 ESXi 主机&#xff0c;然后转至管理 > 设置 > 安全配置文件&#xff0c;然后单击锁定模式的编辑按钮 2、在打开的锁定模式窗口中&#xff0c;选中启…

【Python】P4 异常处理

Python 异常处理 Python 中对于异常的处理主要通过 try-except、finally 和 raise 语句实现。 try-except 语句&#xff1a; 尝试执行一段代码&#xff0c;如果该代码块引发了异常&#xff0c;那么将跳过 try 代码块中剩余的代码&#xff0c;转而执行相应的 except 子句。 …

云服务器 云服务器概述-产品简介-文档中心-腾讯云

腾讯云服务器入门教程包括云服务器CPU内存带宽配置选择&#xff0c;选择云服务器CVM或轻量应用服务器&#xff0c;云服务器创建后重置密码、远程连接、搭建程序环境等&#xff0c;腾讯云服务器网txyfwq.com分享从0到1腾讯云服务器入门教程&#xff1a; 腾讯云服务器入门教程 …

C++中的引用及指针变量

目录 1.1 C中的引用 1.2 C中的指针变量&#xff08;pointer&#xff09; 1.1 C中的引用 C中的引用&#xff08;reference&#xff09;是一种特殊的变量&#xff0c;它是某个已存在变量的另一个名字。引用变量与指针变量类似&#xff0c;但引用变量必须在声明时进行初始化&…

CSDN 年度总结|知识改变命运,学习成就未来

欢迎来到英杰社区&#xff1a; https://bbs.csdn.net/topics/617804998 欢迎来到阿Q社区&#xff1a; https://bbs.csdn.net/topics/617897397 &#x1f4d5;作者简介&#xff1a;热爱跑步的恒川&#xff0c;致力于C/C、Java、Python等多编程语言&#xff0c;热爱跑步&#xff…

Python操作PDF的全面指南

引言&#xff1a; 在现代数字化时代&#xff0c;PDF&#xff08;Portable Document Format&#xff09;已成为一种常见的文档格式。无论是在工作中还是在学习中&#xff0c;我们经常需要处理和操作PDF文件。幸运的是&#xff0c;Python提供了许多强大的库和工具&#xff0c;可以…

「许战海矩阵战略洞察」吉香居给调味品企业带来的战略启示

引言&#xff1a;吉香居通过实施份额化战略和打造形象产品&#xff0c;在调味品行业中取得了成功。但品牌结构需要调整&#xff0c;需要将子品牌整合到吉香居主品牌下&#xff0c;共同提升品牌势能。此外&#xff0c;企业需保持主品牌竞争战略&#xff0c;以实现长期稳定的高速…

transfomer中正余弦位置编码的源码实现

简介 Transformer模型抛弃了RNN、CNN作为序列学习的基本模型。循环神经网络本身就是一种顺序结构&#xff0c;天生就包含了词在序列中的位置信息。当抛弃循环神经网络结构&#xff0c;完全采用Attention取而代之&#xff0c;这些词序信息就会丢失&#xff0c;模型就没有办法知…

进阶Docker4:网桥模式、主机模式与自定义网络

目录 网络相关 子网掩码 网关 规则 docke网络配置 bridge模式 host模式 创建自定义网络(自定义IP) 网络相关 IP 子网掩码 网关 DNS 端口号 子网掩码 互联网是由许多小型网络构成的&#xff0c;每个网络上都有许多主机&#xff0c;这样便构成了一个有层次的结构。 IP 地…

python实现屏幕颜色获取

为了实时监听鼠标移动并输出鼠标当前位置的颜色值&#xff0c;你可以结合使用pyautogui和pynput库。pynput库可以用来监听鼠标事件&#xff0c;而pyautogui则可以用来获取鼠标当前位置的屏幕颜色。 首先&#xff0c;你需要安装这两个库&#xff08;如果尚未安装&#xff09;&a…

FreeBSD上安装mysql数据库

安装前提 1、使用pkg安装mysql有个前提FreeBSD版本12.2及以上。 2、内存最好是8GB及以上 安装 $ pkg search mysql …… mysql80-client-8.0.35 Multithreaded SQL database (client) mysql80-server-8.0.35 Multithreaded SQL database (server) mysql81…

SpringAOP-说说 Spring AOP 和 AspectJ AOP 区别

Spring AOP Spring AOP 属于运行时增强&#xff0c;主要具有如下特点&#xff1a; 基于动态代理来实现&#xff0c;默认如果使用接口的&#xff0c;用 JDK 提供的动态代理实现&#xff0c;如果是方法则使用 CGLIB 实现Spring AOP 需要依赖 IOC 容器来管理&#xff0c;并且只能…

浅谈安科瑞铁塔/基站电力监控解决方案

I.背景信息&#xff1a; 2020年5G元年&#xff0c;通信行业承蓬勃发展之态&#xff0c;各大运营商和铁塔集团在布局新一代通讯基站。基站用电量不断上升&#xff0c;通信基站智能化电力监控及节能管理已成为各运营商企业的研究方向。 而同时&#xff0c;目前铁塔基站电力使用…

靶机-basic_pentesting_2

basic_pentesting_2 arp-scan -l查找靶机IP masscan 192.168.253.154 --ports 0-65535 --rate10000 端口扫描 nmap扫描nmap -T5 -A -p- 192.168.253.154 目录扫描80端口 http://192.168.253.154/development/dev.txt 2018-04-23: I’ve been messing with that struts stu…

面向Java开发者的ChatGPT提示词工程(10)

在ChatGPT的众多应用中&#xff0c;拼写检查和语法检查犹如璀璨的明珠&#xff0c;受到广大用户的热烈追捧。我对此深信不疑&#xff0c;且一直在实践中坚定不移。特别是在使用非母语的情况下&#xff0c;它的作用更为显著。接下来&#xff0c;让我们通过一些常见的拼写和语法问…

mipi协议

完成mipi信号通道分配后&#xff0c;需要生成与物理层对接的时序、同步信号&#xff1a; MIPI规定&#xff0c;传输过程中&#xff0c;包内是200mV、包间以及包启动和包结束时是1.2V&#xff0c;两种不同的电压摆幅&#xff0c;需要两组不同的LVDS驱动电路在轮流切换工作&#…