ResNet神经网络搭建

一、定义残差结构

BasicBlock   

18层、34层网络对应的残差结构

浅层网络主线由两个3x3的卷积层链接,相加后通过relu激活函数输出。还有一个shortcut捷径

参数解释

        expansion = 1  : 判断对应主分支的残差结构有无变化

        downsample=None : 下采样参数,默认为none

        stride步距为1,对应实线残差结构 ; 步距为2,对应虚线残差结构

        self.conv2 = nn.Conv2d(in_channels=out_channel :卷积层1的输出即为输入

class BasicBlock(nn.Module):expansion = 1def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channel)self.relu = nn.ReLU()self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channel)self.downsample = downsample

定义正向传播

 identity = x : shortcut捷径上的输出值

identity = self.downsample(x) : 将输出特征矩阵x输入到下采样函数中得到捷径分支的输出 

 def forward(self, x):   #定义正向传播identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out += identity  # 相加后通过relu激活函数输出out = self.relu(out)return out

BottleBlock  

50层、101层、152层神经网路对应的残差结构

深层网络主线由一个1x1的降维卷积层,3x3卷积层、1x1升维卷积层和一个shortcut捷径组成。

 按照残差结构进行定义,大致与BasicBlock参数一样,不同的是expansion=4,卷积核个数是之前的4倍。

class Bottleneck(nn.Module):expansion = 4def __init__(self, in_channel, out_channel, stride=1, downsample=None,groups=1, width_per_group=64):super(Bottleneck, self).__init__()width = int(out_channel * (width_per_group / 64.)) * groupsself.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,kernel_size=1, stride=1, bias=False)  # squeeze channelsself.bn1 = nn.BatchNorm2d(width)# -----------------------------------------self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,kernel_size=3, stride=stride, bias=False, padding=1)self.bn2 = nn.BatchNorm2d(width)# -----------------------------------------self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,kernel_size=1, stride=1, bias=False)  # unsqueeze channelsself.bn3 = nn.BatchNorm2d(out_channel*self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsample

 定义正向传播

 if self.downsample is not None :  is None是实线  is not None 是虚线

    def forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += identityout = self.relu(out)return out

二、 定义网络结构

block,  根据定义不同的层结构传入不同的block
blocks_num,  所使用残差结构的数目、参数列表
num_classes=1000, 分类个数
include_top=True, 在ResNet基础上搭建其他的网络

self.layer1 对应conv2_x
self.layer2 对应conv3_x
self.layer3 对应conv4_x
self.layer4 对应conv5_x 这一系列的残差结构都通过_make_layer函数实线

self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)  不管输入是多杀,通过自适应平均池化下采样都会输出(1,1)

self.fc = nn.Linear(512 * block.expansion, num_classes)  通过全连接输出节点层,输入的节点个数是通过平均池化下采样层后的特征矩阵展平后所得到的节点个数,但是由于节点的高和宽都是1,所以节点的个数=深度

class ResNet(nn.Module):def __init__(self,block,blocks_num,num_classes=1000,include_top=True,groups=1,width_per_group=64):super(ResNet, self).__init__()self.include_top = include_top   # 将参数参入变为类变量self.in_channel = 64  # 表格中通过maxpooling后得到的深度self.groups = groupsself.width_per_group = width_per_groupself.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,padding=3, bias=False)self.bn1 = nn.BatchNorm2d(self.in_channel)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, blocks_num[0])self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)if self.include_top:self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

三、定义_make_layer函数

block,  定义的残差结构

channel, 残差结构中卷积层使用卷积核的个数

block_num, 该层包含了几个残差结构

block(传入第一层残差结构

        self.in_channel,   
        channel,   主分支第一个卷积层卷积核的个数
        downsample=downsample,  下采样函数

 for _ in range(1, block_num): 通过循环,将剩下一系列的实线残差结构压入进去

 return nn.Sequential(*layers)  非关键字传入,将定义的一系列层结构组合并返回。

    def _make_layer(self, block, channel, block_num, stride=1):downsample = Noneif stride != 1 or self.in_channel != channel * block.expansion:   # 18、34层会跳过这一部分 ;50、101、152层不会downsample = nn.Sequential(  # 生成下采样函数nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(channel * block.expansion))layers = []  # 定义空的列表layers.append(block(self.in_channel,channel,downsample=downsample,stride=stride,groups=self.groups,width_per_group=self.width_per_group))self.in_channel = channel * block.expansionfor _ in range(1, block_num):layers.append(block(self.in_channel,channel,groups=self.groups,width_per_group=self.width_per_group))return nn.Sequential(*layers)

 定义正向传播

    def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)if self.include_top:x = self.avgpool(x)x = torch.flatten(x, 1)  # 展平后链接x = self.fc(x)return x

四、建立不同的层结构 

传入的参数分别按照定义的顺序传入。

def resnet34(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet34-333f7ec4.pthreturn ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet50(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet50-19c8e357.pthreturn ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)def resnet101(num_classes=1000, include_top=True):# https://download.pytorch.org/models/resnet101-5d3b4d8f.pthreturn ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)

后续继续补充。。。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/7946.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Minio(官方docker版)容器部署时区问题研究记录

文章目录 感慨&概述补充:MINIO_REGION和容器时间的关系 问题一:minio容器和本地容器时间不一致问题说明原因探究解决方法结果验证 问题二:minio修改时间和本地查询结果不一致具体问题原因探究解决办法时间转化工具类调用测试和验证上传文…

Unit6

Unit6 1. val 强壮 valid invalid validate invalidate prevail prevailing prevalent 2. pri 主要的 prime Prime Minister premier primary primary school prior prior to sth prioritize priority principle principal prince princess 3. nov 新 news newspap…

thinkphp5 中路由常见的使用方法

在ThinkPHP 5中,路由的常见使用方法主要包括以下几个方面: 基本路由配置: 你可以通过修改config目录下面的route.php文件来配置路由规则。例如,使用Route::get或Route::post等方法定义不同的HTTP请求类型的路由。 use think\Route…

stm32芯片外设

STM32 F1系列微控制器是ST公司推出的一系列基于ARM Cortex-M3内核的微控制器。这一系列微控制器拥有丰富的外设资源,包括但不限于: ADC(模数转换器):用于将模拟信号转换为数字信号,通常用于传感器数据的读取…

高级数据结构与算法习题(9)

一、判断题 1、Let S be the set of activities in Activity Selection Problem. Then the earliest finish activity am​ must be included in all the maximum-size subset of mutually compatible activities of S. T F 解析:F。设S是活动选择问题中的一…

flutter开发实战-webview_flutter 4.x版本使用

flutter开发实战-webview_flutter 4.x版本使用 在之前使用的webview_flutter版本是3.x的,升级到4.x后,使用方式有所变化。 一、webview_flutter 在工程的pubspec.yaml中引入插件 webview_flutter: ^4.4.2二、使用webview_flutter 在4.x版本中&#…

JAVA语言开发的(智慧校园系统源码)智慧校园的痛点、智慧校园的安全应用、智慧校园解决方案

一、智慧校园的痛点 1、信息孤岛问题:由于校园内各部门或系统独立开发,缺乏统一规划和标准,导致数据无法有效整合和共享,形成了信息孤岛。 2、技术更新与运维挑战:智慧校园的建设依赖于前沿的信息技术,如云…

15【PS作图】像素画地图绘制

绘制视角 绘制地图的时候,有的人会习惯把要绘制的 房子、车子、围栏 小物件先画好,然后安放在地图上 但这样绘制出的各种物件之间,会缺乏凝聚力 既然物品都是人构造出的,不如以人的视角去一步步丰富地图; 比如下图…

关于c++ 中 string s { ‘a‘ , ‘b‘ , ‘c‘ , ‘d‘ } 的方式的构造过程

(1)这样的构造方式不常见,但也确实 STL 库提供了这样的构造函数 (2)以反汇编分析这行代码 (3)谢谢阅读

前端深度扩展

1 为什么要有webpack 模块化管理:构建工具支持Common JS、ES6模块等规范;依赖管理:在大型项目中,手动管理文件依赖关系。webpack可以自动分析项目中的依赖关系,将其打包成1个或多个优化过的文件,减少页面加…

【网络】深入了解端口,一个端口能否被多个进程绑定

引言 在计算机网络中,端口是一项关键概念,它在网络通信中扮演着重要的角色。本文将深入介绍端口的作用、分类,并分析一个端口能否被多个进程绑定的问题。 1. 端口的作用 端口是计算机与网络通信的入口或出口,用于标识进程和应用…

Prefiquence(双指针,动态规划)

文章目录 题目描述输入格式输出格式样例样例输入 #1样例输出 #1样例输入 #2样例输出 #2样例输入 #3样例输出 #3 提示提交链接解析参考代码 题目描述 给您两个二进制字符串 a a a 和 b b b 。二进制字符串是由字符 0 0 0 和 1 1 1 组成的字符串。 您的任务是确定最大可能的…

【数据结构与算法】力扣 239. 滑动窗口最大值

题干描述 给你一个整数数组 nums,有一个大小为 k **的滑动窗口从数组的最左侧移动到数组的最右侧。你只可以看到在滑动窗口内的 k 个数字。滑动窗口每次只向右移动一位。 返回 滑动窗口中的最大值 。 示例 1: 输入: nums [1,3,-1,-3,5,3…

代码随想录day56 | 动态规划P16 | ● 583. ● 72. ● 编辑距离总结篇

583. 两个字符串的删除操作 给定两个单词 word1 和 word2 ,返回使得 word1 和 word2 相同所需的最小步数。 每步 可以删除任意一个字符串中的一个字符。 示例 1: 输入: word1 "sea", word2 "eat" 输出: 2 解释: 第一步将 &quo…

嵌入式开发常见概念简介

目录 0. 《STM32单片机自学教程》专栏总纲 API Handle(句柄) 0. 《STM32单片机自学教程》专栏总纲 本文作为专栏《STM32单片机自学教程》专栏其中的一部分,返回专栏总纲,阅读所有文章,点击Link: STM32单片机自学教程-[目录总纲]_stm32 学习-CSD…

每日OJ题_贪心算法三⑤_力扣134. 加油站

目录 力扣134. 加油站 解析代码 力扣134. 加油站 134. 加油站 难度 中等 在一条环路上有 n 个加油站,其中第 i 个加油站有汽油 gas[i] 升。 你有一辆油箱容量无限的的汽车,从第 i 个加油站开往第 i1 个加油站需要消耗汽油 cost[i] 升。你从其中的一…

java设计模式三

工厂模式是一种创建型设计模式,它提供了一个创建对象的接口,但允许子类决定实例化哪一个类。工厂模式有几种变体,包括简单工厂模式、工厂方法模式和抽象工厂模式。下面通过一个简化的案例和对Java标准库中使用工厂模式的源码分析来说明这一模…

SpringBoot3项目打包和运行

六、SpringBoot3项目打包和运行 6.1 添加打包插件 在Spring Boot项目中添加spring-boot-maven-plugin插件是为了支持将项目打包成可执行的可运行jar包。如果不添加spring-boot-maven-plugin插件配置,使用常规的java -jar命令来运行打包后的Spring Boot项目是无法找…

linux笔记--tmux的使用

目录 1--安装tmux 2--创建新会话 3--离开会话 4--查看所有会话 5--进入会话 6--结束会话 1--安装tmux # Ubuntu 或 Debian sudo apt-get install tmux# CentOS 或 Fedora sudo yum install tmux# Mac brew install tmux 2--创建新会话 tmux new -s <session-name&g…

scrapy常用命令总结

1.创建scrapy项目的命令&#xff1a;     scrapy startproject <项目名字> 示例&#xff1a;     scrapy startproject myspider 2.通过命令创建出爬虫文件&#xff0c;爬虫文件为主要的代码文件&#xff0c;通常一个网站的爬取动作都会在爬虫文件中进行编写。 …