[原理] 可变性卷积(deformable convolution)原理及代码解释

文章目录

  • 前言
  • 提出问题
  • 核心思想
  • 代码理解
    • 模块初始化
    • forward过程
      • self.p_conv
      • self._get_p
      • self._get_x_q
      • self._reshape_x_offset
    • 参考文献

前言

代码见:https://github.com/4uiiurz1/pytorch-deform-conv-v2/blob/master/deform_conv_v2.py
论文:https://arxiv.org/abs/1703.06211

提出问题

为什么需要可变形卷积,他和普通卷积有什么差异,有什么优势?

核心思想

原始图像通过卷积操作可以变成多通道的特征图,通过特征提取和分析可以完成不同的任务,传统卷积的基本流程如下图,卷积核在原特征图上遍历,加权平均后得到输出特征图相应位置的输出。如公式所示,如果是传统卷积,针对输出图的每个位置,原图上的采样位置是固定的,以3x3卷积核为例,相对采样位置就是公式中的R。

图片2 图片1
传统卷积

作者认为这种采样方式太规则了,不利于一些不规则特征的提取。例如下图所示,规则卷积vs 可变形卷积提取到的特征有较大区别。针对这个情况,作者提出可变形卷积,也就是说,采样的位置发生了一些变化,可以增加学习采样偏移量,如公式所示。xp代表着新的位置的值,通过bilinear插值得到。

图片2 图片1
可变形卷积
经过可变形卷积,整体的特征图尺寸不会发生变化,跟普通卷积一样,如下图所示。

在这里插入图片描述

代码理解

整个原始代码难以理解的地方就是这个offset的计算,插值的计算,也就是最终用来卷积的这些数是怎么得到的。

模块初始化

模块初始化设置三个卷积,self.conv 用来执行最后的卷积运算,self.p_conv 用来学习偏移量,self.m_conv用来给不同位置增加学习权重,代码及注意点和注释如下所示

class DeformConv2d(nn.Module):def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False):"""Args:modulation (bool, optional): If True, Modulated Defomable Convolution (Deformable ConvNets v2)."""super(DeformConv2d, self).__init__()self.kernel_size = kernel_sizeself.padding = paddingself.stride = strideself.zero_padding = nn.ZeroPad2d(padding)# 最终使用的卷积操作,注意stride=kernel,# 原因是最终采样点不是规则的点,需要结合通过偏移量取值,因此需要构建新的特征图# 新的特征图的尺寸是原来特征图的hw 是原来hw x kernel_size 的大小self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)# 用来学习偏移量的卷积,其中通道数为2xksxks ,如果k=3,也就是学习9个位置的2方向(x、y)偏移量self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)# 初始化偏移卷积为0nn.init.constant_(self.p_conv.weight, 0)# 学习率设置为整个网络0.1倍,避免影响整体网络性能self.p_conv.register_backward_hook(self._set_lr)# 为每个位置增加学习权重,初始化和偏移卷积一样self.modulation = modulationif modulation:self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)nn.init.constant_(self.m_conv.weight, 0)self.m_conv.register_backward_hook(self._set_lr)

forward过程

整体主要步骤如下
1、self.p_conv卷积计算offset ,# 维度 (b,2ksks,h),w # 2xksxks 若k=3,也就是用来卷积的9的位置的x、y方向偏移

2、self._get_p 函数获取offset的位置 ( 绝对位置+相对位置)# 维度 (b, 2N, h, w) ,N=ksxks。

3、self._get_x_q 之前的函数用来计算双线性插值采样点,因为位置是浮点数,需要映射回具体坐标位置 # (b, c, h, w, N),不同的通道c其实对应相同的位置。# (b, c, h, w, N)

4、self._get_x_q 函数用来得到每个位置的插值权重 # (b, c, h, w, N)

5、self._reshape_x_offset 将b, c, h, w, N重新排布为b, c, hxks, wxks 用来进行最终的卷积
代码及解释如下。

    def forward(self, x): # b,c,h,w# 计算偏移量,维度 b,2*ks*ks,h,w  # N=kxkoffset = self.p_conv(x)if self.modulation: # 为偏移量增加权重m = torch.sigmoid(self.m_conv(x))dtype = offset.data.type()ks = self.kernel_sizeN = offset.size(1) // 2 # N=ks*ks# 填充:k=3的卷积,填充p=1,尺度才不会发生改变if self.padding:x = self.zero_padding(x)# (b, 2N, h, w) ,得到p的位置p = self._get_p(offset, dtype)# (b, h, w, 2N) ,位置放在最后一个维度,方便处理p = p.contiguous().permute(0, 2, 3, 1)q_lt = p.detach().floor() #left top 左上角坐标,也就是最小值,如果是0-1之间就是0q_rb = q_lt + 1 # right bottom右下角坐标,也就是最大值,如果是0-1之间就是1# 确定四个角点坐标,设置在0 到 h-1 或 w-1 之间q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)# clip p ,采样点也需要clamp一下p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)# bilinear kernel (b, h, w, N)g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))# (b, c, h, w, N),计算四个领域的权重x_q_lt = self._get_x_q(x, q_lt, N)x_q_rb = self._get_x_q(x, q_rb, N)x_q_lb = self._get_x_q(x, q_lb, N)x_q_rt = self._get_x_q(x, q_rt, N)# (b, c, h, w, N),计算插值结果x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \g_rb.unsqueeze(dim=1) * x_q_rb + \g_lb.unsqueeze(dim=1) * x_q_lb + \g_rt.unsqueeze(dim=1) * x_q_rt# modulation,如果存在这个模块,就让偏移量*mif self.modulation:m = m.contiguous().permute(0, 2, 3, 1)m = m.unsqueeze(dim=1)m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)x_offset *= m# 重新排列,变成 h,c,h*ks,w*ks 特征图,用来最后卷积x_offset = self._reshape_x_offset(x_offset, ks)out = self.conv(x_offset)return out

self.p_conv

就是普通卷积操作,不进行解释

self._get_p

包括绝对位置和相对位置,绝对位置就是卷积中心在原图中的位置 0-(h-1) ,0-(w-1) ,相对位置就是0-(ks-1) ,卷积操作中每个点与中心位置的相对关系。

       def _get_p_n(self, N, dtype): # 相对位置p_n_x, p_n_y = torch.meshgrid(torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))# (2N, 1)p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)p_n = p_n.view(1, 2*N, 1, 1).type(dtype)return p_ndef _get_p_0(self, h, w, N, dtype): #绝对位置p_0_x, p_0_y = torch.meshgrid(torch.arange(1, h*self.stride+1, self.stride),torch.arange(1, w*self.stride+1, self.stride))p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)return p_0def _get_p(self, offset, dtype):N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)# (1, 2N, 1, 1),相对位置只有2N个,因为就是卷积核大小p_n = self._get_p_n(N, dtype)# (1, 2N, h, w),绝对位置有2NxHxW,因为每个位置都有偏移量p_0 = self._get_p_0(h, w, N, dtype) p = p_0 + p_n + offsetreturn p

self._get_x_q

将原始输入的hw变成一个维度的向量,相应的位置索引也需要变成一维,所以需要乘以w,然后最后在重新变成hxw格式

def _get_x_q(self, x, q, N):b, h, w, _ = q.size()padded_w = x.size(3)c = x.size(1)# (b, c, h*w)x = x.contiguous().view(b, c, -1)# (b, h, w, N)index = q[..., :N]*padded_w + q[..., N:]  # offset_x*w + offset_y# (b, c, h*w*N)index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)return x_offset

self._reshape_x_offset

这个的理解可以根据这个博客@链接来,也就是将整体数据重新排布成卷积的类型

图片2
    @staticmethoddef _reshape_x_offset(x_offset, ks):b, c, h, w, N = x_offset.size()x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)return x_offset

参考文献

https://blog.csdn.net/panghuzhenbang/article/details/129816869
https://zhuanlan.zhihu.com/p/335147713
https://zhuanlan.zhihu.com/p/102707081
https://blog.csdn.net/panghuzhenbang/article/details/129816869

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/51079.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【数据结构】使用栈实现综合计算器

首先初始化两个栈,数栈(numStack)用于存放数据,符号栈(operStack)用于存放运算符 计算思路 1.通过一个index值(索引)来遍历表达式 2.如果发现扫描到一个数字,就直接入数栈…

Python | TypeError: ‘function’ object is not subscriptable

Python | TypeError: ‘function’ object is not subscriptable 在Python编程中,遇到“TypeError: ‘function’ object is not subscriptable”这一错误通常意味着你尝试像访问列表、元组、字典或字符串等可订阅(subscriptable)对象那样去…

Javascript面试基础6(下)

获取页面所有checkbox 怎样添加、移除、移动、复制、创建和查找节点 在JavaScript中,操作DOM(文档对象模型)是常见的任务,包括添加、移除、移动、复制、创建和查找节点。以下是一些基本的示例,说明如何执行这些操作&a…

【Python】如何修改元组的值?

一、题目 We have seen that lists are mutable (they can be changed), and tuples are immutable (they cannot be changed). Lets try to understand this with an example. You are given an immutable string, and you want to make chaneges to it. Example >>…

Java语言程序设计——篇九(2)

🌿🌿🌿跟随博主脚步,从这里开始→博主主页🌿🌿🌿 枚举类型 枚举类型的定义枚举类型的方法实战演练 枚举在switch中的应用实战演练 枚举类的构造方法实战演练 枚举类型的定义 [修饰符] enum 枚举…

内部文档:如何创建、提示和示例

每个组织都拥有一定的内部知识储备,可以让组织不断运转、持续发展。这些知识资产对于每个企业来说都是独一无二的,并且高度依赖于为您工作的人员。问题是,这些知识常常储存在员工的头脑中。如果知识被记录下来,它通常隐藏在 Slack…

医院影像平台源码,C/S体系结构的C#语言PACS系统全套商业源代码

医学学影像临床信息系统具有图像采集、显示、存储、传输和管理等功能,支持DICOM影像设备和非DICOM影像设备,可以识别CT、MR、CR/DR、X光、DSA、B超、NM、SC等设备的图像类型,可对数字影像进行无损压缩和有损压缩处理。C/S体系结构的多媒体数据…

湖仓一体架构解析:数仓架构选择(第48天)

系列文章目录 1、Lambda 架构 2、Kappa 架构 3、混合架构 4、架构选择 5、实时数仓现状 6、湖仓一体架构 7、流批一体架构 文章目录 系列文章目录前言1、Lambda 架构2、Kappa 架构3、混合架构4、架构选择5、实时数仓现状6、湖仓一体架构7、流批一体架构 前言 本文解析了Lambd…

Verilog语言和C语言的本质区别是什么?

在开始前刚好我有一些资料,是我根据网友给的问题精心整理了一份「C语言的资料从专业入门到高级教程」, 点个关注在评论区回复“888”之后私信回复“888”,全部无偿共享给大家!!! 用老石的一句话其实很好说…

QtQuick-QML类型系统-对象特性 (方法特性)

概念 对象类型的方法就是一个函数,可以执行,也可以和信号关联,这样发射信号会自动调用。 在C中,可以使用Q_INVOKABLE宏或者Q_SLOT宏进行注册的方式定义方法; 另外,也可以在QML文档的对象声明里使用下面的…

二分类损失 - BCELoss详解

BCELoss (Binary Cross-Entropy Loss) 是用于二分类问题的损失函数。它用于评估预测值和实际标签之间的差异。在 PyTorch 中,BCELoss 是一个常用的损失函数。以下是 BCELoss 的详细计算过程和代码实现。 BCELoss 的计算过程 给定一组预测值 y ^ \hat{y} y^​ 和实…

redis的使用场景-分布式锁

使用redis的setnx命令放入数据并用此数据当锁完成业务(但是如果用户操作途中出现异常导致超出指定时间会出现问题) Service public class StockService {Autowiredprivate StockDao stockDao; //mapper注入Autowiredprivate StringRedisTemplate redisT…

ssm框架整合,异常处理器和拦截器(纯注解开发)

目录 ssm框架整合 第一步:指定打包方式和导入所需要的依赖 打包方法:war springMVC所需依赖 解析json依赖 mybatis依赖 数据库驱动依赖 druid数据源依赖 junit依赖 第二步:导入tomcat插件 第三步:编写配置类 SpringCon…

【AI绘画】Midjourney V6初学者完全指南 参数篇

本文我们将详细介绍对图像生成结果产生重大影响的"参数"。 1. 什么是参数? 参数是一种添加到提示末尾以调整图像生成输出设置的方法。 它们用两个连字符"–“和特定字符串表示,如”–ar"、“–chaos”、"–r"等。 您也可以同时使用多个参数…

分布式控制算法——第一部分:基础概念与原理

分布式控制算法 文章目录 分布式控制算法第一部分:基础概念与原理1. 引言分布式控制的定义分布式控制系统的特点与优势分布式控制的应用场景 2. 分布式系统基础分布式系统的定义和特性分布式计算模型常见的分布式系统架构 3. 分布式控制基础分布式控制的基本原理中央…

面试题003:面向对象的特征 之 封装性

Java规定了4种权限修饰,分别是:private、缺省、protected、public。我们可以使用4种权限修饰来修饰类及类的内部成员。当这些成员被调用时,体现可见性的大小。 封装性在程序中的体现: 场景1:私有化(private)类的属性,提供公共(pub…

java项目中添加SDK项目作为依赖使用(无需上传Maven)

需求: 当需要多次调用某个函数或算法时,不想每次调用接口都自己编写,可以将该项目打包,以添加依赖的方式实现调用 适用于: 无需上线的项目,仅公司或团队内部使用的项目 操作步骤: 以下面这…

菜鸟从0学微服务——MyBatis-Plus

关于“菜鸟从0学微服务” 针对有编程基础,开始学习微服务的同学,我们陆续推出从0学微服务的笔记分享。力求从各个中间件的使用来反思这些中间件的作用和优势。 会分享的比较快,会记录demo演算和中间件的使用过程,至于细节的理论…

OPENMV脱机调阈值

用到了7个按钮,第一个用来选择是否进入调阈值模式。 后6个用来调整OPENMV阈值编辑器的6个滑动条 OPENMV代码 import sensor, image, time, pyb,math, display from pyb import UARTsensor.reset() sensor.set_framesize(sensor.QQVGA) sensor.set_pixformat(sens…

【数学建模】——【python】实现【最短路径】【最小生成树】【复杂网络分析】

目录 1. 最短路径问题 - 绘制城市间旅行最短路径图 题目描述: 要求: 示例数据: python 代码实现 实现思想: 要点: 2. 最小生成树问题 - Kruskal算法绘制MST 题目描述: 要求: 示例数据…