文章目录
- 前言
- 提出问题
- 核心思想
- 代码理解
- 模块初始化
- forward过程
- self.p_conv
- self._get_p
- self._get_x_q
- self._reshape_x_offset
- 参考文献
前言
代码见:https://github.com/4uiiurz1/pytorch-deform-conv-v2/blob/master/deform_conv_v2.py
论文:https://arxiv.org/abs/1703.06211
提出问题
为什么需要可变形卷积,他和普通卷积有什么差异,有什么优势?
核心思想
原始图像通过卷积操作可以变成多通道的特征图,通过特征提取和分析可以完成不同的任务,传统卷积的基本流程如下图,卷积核在原特征图上遍历,加权平均后得到输出特征图相应位置的输出。如公式所示,如果是传统卷积,针对输出图的每个位置,原图上的采样位置是固定的,以3x3卷积核为例,相对采样位置就是公式中的R。
作者认为这种采样方式太规则了,不利于一些不规则特征的提取。例如下图所示,规则卷积vs 可变形卷积提取到的特征有较大区别。针对这个情况,作者提出可变形卷积,也就是说,采样的位置发生了一些变化,可以增加学习采样偏移量,如公式所示。xp代表着新的位置的值,通过bilinear插值得到。
代码理解
整个原始代码难以理解的地方就是这个offset的计算,插值的计算,也就是最终用来卷积的这些数是怎么得到的。
模块初始化
模块初始化设置三个卷积,self.conv
用来执行最后的卷积运算,self.p_conv
用来学习偏移量,self.m_conv
用来给不同位置增加学习权重,代码及注意点和注释如下所示
class DeformConv2d(nn.Module):def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False):"""Args:modulation (bool, optional): If True, Modulated Defomable Convolution (Deformable ConvNets v2)."""super(DeformConv2d, self).__init__()self.kernel_size = kernel_sizeself.padding = paddingself.stride = strideself.zero_padding = nn.ZeroPad2d(padding)# 最终使用的卷积操作,注意stride=kernel,# 原因是最终采样点不是规则的点,需要结合通过偏移量取值,因此需要构建新的特征图# 新的特征图的尺寸是原来特征图的hw 是原来hw x kernel_size 的大小self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)# 用来学习偏移量的卷积,其中通道数为2xksxks ,如果k=3,也就是学习9个位置的2方向(x、y)偏移量self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)# 初始化偏移卷积为0nn.init.constant_(self.p_conv.weight, 0)# 学习率设置为整个网络0.1倍,避免影响整体网络性能self.p_conv.register_backward_hook(self._set_lr)# 为每个位置增加学习权重,初始化和偏移卷积一样self.modulation = modulationif modulation:self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)nn.init.constant_(self.m_conv.weight, 0)self.m_conv.register_backward_hook(self._set_lr)
forward过程
整体主要步骤如下
1、self.p_conv卷积计算offset ,# 维度 (b,2ksks,h),w # 2xksxks 若k=3,也就是用来卷积的9的位置的x、y方向偏移
2、self._get_p 函数获取offset的位置 ( 绝对位置+相对位置)# 维度 (b, 2N, h, w) ,N=ksxks。
3、self._get_x_q 之前的函数用来计算双线性插值采样点,因为位置是浮点数,需要映射回具体坐标位置 # (b, c, h, w, N),不同的通道c其实对应相同的位置。# (b, c, h, w, N)
4、self._get_x_q 函数用来得到每个位置的插值权重 # (b, c, h, w, N)
5、self._reshape_x_offset 将b, c, h, w, N重新排布为b, c, hxks, wxks 用来进行最终的卷积
代码及解释如下。
def forward(self, x): # b,c,h,w# 计算偏移量,维度 b,2*ks*ks,h,w # N=kxkoffset = self.p_conv(x)if self.modulation: # 为偏移量增加权重m = torch.sigmoid(self.m_conv(x))dtype = offset.data.type()ks = self.kernel_sizeN = offset.size(1) // 2 # N=ks*ks# 填充:k=3的卷积,填充p=1,尺度才不会发生改变if self.padding:x = self.zero_padding(x)# (b, 2N, h, w) ,得到p的位置p = self._get_p(offset, dtype)# (b, h, w, 2N) ,位置放在最后一个维度,方便处理p = p.contiguous().permute(0, 2, 3, 1)q_lt = p.detach().floor() #left top 左上角坐标,也就是最小值,如果是0-1之间就是0q_rb = q_lt + 1 # right bottom右下角坐标,也就是最大值,如果是0-1之间就是1# 确定四个角点坐标,设置在0 到 h-1 或 w-1 之间q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)# clip p ,采样点也需要clamp一下p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)# bilinear kernel (b, h, w, N)g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))# (b, c, h, w, N),计算四个领域的权重x_q_lt = self._get_x_q(x, q_lt, N)x_q_rb = self._get_x_q(x, q_rb, N)x_q_lb = self._get_x_q(x, q_lb, N)x_q_rt = self._get_x_q(x, q_rt, N)# (b, c, h, w, N),计算插值结果x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \g_rb.unsqueeze(dim=1) * x_q_rb + \g_lb.unsqueeze(dim=1) * x_q_lb + \g_rt.unsqueeze(dim=1) * x_q_rt# modulation,如果存在这个模块,就让偏移量*mif self.modulation:m = m.contiguous().permute(0, 2, 3, 1)m = m.unsqueeze(dim=1)m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)x_offset *= m# 重新排列,变成 h,c,h*ks,w*ks 特征图,用来最后卷积x_offset = self._reshape_x_offset(x_offset, ks)out = self.conv(x_offset)return out
self.p_conv
就是普通卷积操作,不进行解释
self._get_p
包括绝对位置和相对位置,绝对位置就是卷积中心在原图中的位置 0-(h-1) ,0-(w-1) ,相对位置就是0-(ks-1) ,卷积操作中每个点与中心位置的相对关系。
def _get_p_n(self, N, dtype): # 相对位置p_n_x, p_n_y = torch.meshgrid(torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))# (2N, 1)p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)p_n = p_n.view(1, 2*N, 1, 1).type(dtype)return p_ndef _get_p_0(self, h, w, N, dtype): #绝对位置p_0_x, p_0_y = torch.meshgrid(torch.arange(1, h*self.stride+1, self.stride),torch.arange(1, w*self.stride+1, self.stride))p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)return p_0def _get_p(self, offset, dtype):N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)# (1, 2N, 1, 1),相对位置只有2N个,因为就是卷积核大小p_n = self._get_p_n(N, dtype)# (1, 2N, h, w),绝对位置有2NxHxW,因为每个位置都有偏移量p_0 = self._get_p_0(h, w, N, dtype) p = p_0 + p_n + offsetreturn p
self._get_x_q
将原始输入的hw变成一个维度的向量,相应的位置索引也需要变成一维,所以需要乘以w,然后最后在重新变成hxw格式
def _get_x_q(self, x, q, N):b, h, w, _ = q.size()padded_w = x.size(3)c = x.size(1)# (b, c, h*w)x = x.contiguous().view(b, c, -1)# (b, h, w, N)index = q[..., :N]*padded_w + q[..., N:] # offset_x*w + offset_y# (b, c, h*w*N)index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)return x_offset
self._reshape_x_offset
这个的理解可以根据这个博客@链接来,也就是将整体数据重新排布成卷积的类型
@staticmethoddef _reshape_x_offset(x_offset, ks):b, c, h, w, N = x_offset.size()x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)return x_offset
参考文献
https://blog.csdn.net/panghuzhenbang/article/details/129816869
https://zhuanlan.zhihu.com/p/335147713
https://zhuanlan.zhihu.com/p/102707081
https://blog.csdn.net/panghuzhenbang/article/details/129816869