文章目录
- 介绍
- 代码实现
介绍
ASPP(Atrous Spatial Pyramid Pooling),空洞空间卷积池化金字塔。简单理解就是个至尊版池化层,其目的与普通的池化层一致,尽可能地去提取特征。ASPP 的结构如下:
如图所示,ASPP 本质上由一个1×1的卷积(最上) + 池化金字塔(中间三个) + ASPP Pooling(最下面三层)组成。而池化金字塔各层的膨胀因子可自定义,从而实现自由的多尺度特征提取。
代码实现
'''
pytorch实现ASPP
'''
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.uniform import Uniform
from torch.nn import functional as F
from typing import Dict, Listclass ASPPConv(nn.Sequential):def __init__(self, in_channels: int, out_channels: int, dilation: int) -> None:super(ASPPConv, self).__init__(nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU())class ASPPPooling(nn.Sequential):def __init__(self, in_channels: int, out_channels: int) -> None:super(ASPPPooling, self).__init__(nn.AdaptiveAvgPool2d(1),nn.Conv2d(in_channels, out_channels, 1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU())def forward(self, x: torch.Tensor) -> torch.Tensor:size = x.shape[-2:]for mod in self:x = mod(x)return F.interpolate(x, size=size, mode='bilinear', align_corners=False)class ASPP(nn.Module):def __init__(self, in_channels: int, atrous_rates: List[int], out_channels: int = 256) -> None:super(ASPP, self).__init__()modules = [nn.Sequential(nn.Conv2d(in_channels, out_channels, 1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU())]rates = tuple(atrous_rates)for rate in rates:modules.append(ASPPConv(in_channels, out_channels, rate))modules.append(ASPPPooling(in_channels, out_channels))self.convs = nn.ModuleList(modules)self.project = nn.Sequential(nn.Conv2d(len(self.convs) * out_channels, out_channels, 1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(),nn.Dropout(0.5))def forward(self, x: torch.Tensor) -> torch.Tensor:_res = []for conv in self.convs:_res.append(conv(x))res = torch.cat(_res, dim=1)return self.project(res)class EncoderAspp(nn.Module):def __init__(self, params):super(EncoderAspp, self).__init__()self.aspp = ASPP(in_channels = 64 , [4, 6, 8], out_channels = 64)def forward(self, x):x = self.aspp(x)return x