#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2021/4/22 17:06
# @Author : @linlianqin
# @Site :
# @File : fpn.py
# @Software: PyCharm
# @description: The construction flow is basically the same as ResNet's; the only difference is that the convolution output of every stage is kept.
import torch
import torch.nn as nn
import torch.nn.functional as F


# Basic Bottleneck block used by ResNet.
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    The block maps ``in_planes`` input channels to ``expansion * planes``
    output channels. Spatial down-sampling, if any, happens in the 3x3
    convolution via ``stride``.

    Args:
        in_planes: Number of channels of the input tensor.
        planes: Width of the two inner convolutions; the block outputs
            ``expansion * planes`` channels.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to build the
            shortcut branch. It must be supplied by the caller whenever the
            input and the main-branch output differ in shape, otherwise the
            residual addition in ``forward`` cannot be performed.
    """

    expansion = 4  # channel multiplier between `planes` and the block output

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        out_planes = self.expansion * planes
        # Main branch. There is no ReLU after the last BatchNorm on purpose:
        # the activation is applied after the residual addition in forward().
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_planes, planes, kernel_size=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                      padding=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU(inplace=True),
            nn.Conv2d(planes, out_planes, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_planes),
        )
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(bottleneck(x) + shortcut(x))``."""
        identity = x
        out = self.bottleneck(x)
        if self.downsample is not None:
            # Project the input so it matches the main branch's shape.
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out


# FPN class; its initializer takes a list giving the number of Bottleneck
# blocks in each ResNet stage.
class FPN(nn.Module):
    """Feature Pyramid Network on top of a Bottleneck-based ResNet backbone.

    The bottom-up path (C1..C5) is a ResNet built from ``Bottleneck`` blocks.
    The top-down path reduces C5 to 256 channels and repeatedly up-samples it,
    adding the 1x1-projected lateral feature of the next finer stage. P4, P3
    and P2 are then smoothed with a 3x3 convolution; P5 is returned as
    produced by ``toplayer`` without smoothing.

    Args:
        layers: Sequence of four ints, the number of Bottleneck blocks in
            stages C2, C3, C4 and C5 respectively.
    """

    def __init__(self, layers):
        super(FPN, self).__init__()
        self.inplanes = 64  # running input-channel count used by _make_layer

        # C1 stem: the first convolution and pooling layers of ResNet.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Bottom-up stages C2, C3, C4, C5.
        self.layer1 = self._make_layer(64, layers[0])
        self.layer2 = self._make_layer(128, layers[1], 2)
        self.layer3 = self._make_layer(256, layers[2], 2)
        self.layer4 = self._make_layer(512, layers[3], 2)

        # Stage output widths follow from the block's channel multiplier
        # (2048 / 1024 / 512 / 256 for expansion == 4).
        expansion = Bottleneck.expansion

        # Reduce the channel count of C5 to obtain P5.
        self.toplayer = nn.Conv2d(512 * expansion, 256, kernel_size=1,
                                  stride=1, padding=0)

        # 3x3 convolutions that fuse the merged features.
        self.smooth1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.smooth2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.smooth3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        # Lateral connections: bring C4, C3, C2 to the same channel count.
        self.latlayer1 = nn.Conv2d(256 * expansion, 256, kernel_size=1,
                                   stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(128 * expansion, 256, kernel_size=1,
                                   stride=1, padding=0)
        self.latlayer3 = nn.Conv2d(64 * expansion, 256, kernel_size=1,
                                   stride=1, padding=0)

    def _make_layer(self, planes, blocks, stride=1):
        """Build one backbone stage and advance ``self.inplanes``.

        Args:
            planes: Inner width of every Bottleneck in the stage.
            blocks: Number of Bottleneck blocks requested. The first block is
                always created, so values below 1 still yield one block.
            stride: Stride of the first block; later blocks use stride 1.

        Returns:
            An ``nn.Sequential`` holding the stage's blocks.
        """
        out_planes = Bottleneck.expansion * planes
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            # The shortcut needs a projection whenever the first block
            # changes the resolution or the channel count.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        stage = [Bottleneck(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(Bottleneck(self.inplanes, planes))
        return nn.Sequential(*stage)

    def _upsample_add(self, x, y):
        """Top-down merge: resize ``x`` to the spatial size of ``y`` and add.

        The target size is taken from ``y`` instead of using a fixed scale
        factor, so the sum is defined even when ``y`` is not exactly twice
        the size of ``x``.
        """
        _, _, height, width = y.shape
        upsampled = F.interpolate(x, size=(height, width), mode='bilinear',
                                  align_corners=False)
        return upsampled + y

    def forward(self, x):
        """Compute the pyramid features.

        Args:
            x: Input batch with 3 channels (``conv1`` expects 3 input
                channels).

        Returns:
            Tuple ``(p2, p3, p4, p5)`` of 256-channel feature maps ordered
            from the finest to the coarsest resolution.
        """
        # Bottom-up path.
        c1 = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)
        # Top-down path with lateral connections.
        p5 = self.toplayer(c5)
        p4 = self._upsample_add(p5, self.latlayer1(c4))
        p3 = self._upsample_add(p4, self.latlayer2(c3))
        p2 = self._upsample_add(p3, self.latlayer3(c2))
        # Smooth the merged maps with 3x3 convolutions.
        p4 = self.smooth1(p4)
        p3 = self.smooth2(p3)
        p2 = self.smooth3(p2)
        return p2, p3, p4, p5


if __name__ == '__main__':
    model = FPN([3, 4, 6, 3])  # FPN network model
    inputs = torch.randn(1, 3, 224, 224)
    out = model(inputs)
    print(out[1].shape)
# Theory reference: https://blog.csdn.net/qq_41251963/article/details/109398699?ops_request_misc=&request_id=&biz_id=102&utm_term=fpn%20pytorch%E6%BA%90%E7%A0%81%E8%AE%AD%E7%BB%83%E6%A8%A1%E5%9E%8B&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduweb~default-6-.first_rank_v2_pc_rank_v29&spm=1018.2226.3001.4187