[原理] 可变性卷积(deformable convolution)原理及代码解释

文章目录

  • 前言
  • 提出问题
  • 核心思想
  • 代码理解
    • 模块初始化
    • forward过程
      • self.p_conv
      • self._get_p
      • self._get_x_q
      • self._reshape_x_offset
    • 参考文献

前言

代码见:https://github.com/4uiiurz1/pytorch-deform-conv-v2/blob/master/deform_conv_v2.py
论文:https://arxiv.org/abs/1703.06211

提出问题

为什么需要可变形卷积,他和普通卷积有什么差异,有什么优势?

核心思想

原始图像通过卷积操作可以变成多通道的特征图,通过特征提取和分析可以完成不同的任务,传统卷积的基本流程如下图,卷积核在原特征图上遍历,加权平均后得到输出特征图相应位置的输出。如公式所示,如果是传统卷积,针对输出图的每个位置,原图上的采样位置是固定的,以3x3卷积核为例,相对采样位置就是公式中的R。

图片2 图片1
传统卷积

作者认为这种采样方式太规则了,不利于一些不规则特征的提取。例如下图所示,规则卷积vs 可变形卷积提取到的特征有较大区别。针对这个情况,作者提出可变形卷积,也就是说,采样的位置发生了一些变化,可以增加学习采样偏移量,如公式所示。xp代表着新的位置的值,通过bilinear插值得到。

图片2 图片1
可变形卷积
经过可变形卷积,整体的特征图尺寸不会发生变化,跟普通卷积一样,如下图所示。

在这里插入图片描述

代码理解

整个原始代码难以理解的地方就是这个offset的计算,插值的计算,也就是最终用来卷积的这些数是怎么得到的。

模块初始化

模块初始化设置三个卷积,self.conv 用来执行最后的卷积运算,self.p_conv 用来学习偏移量,self.m_conv用来给不同位置增加学习权重,代码及注意点和注释如下所示

class DeformConv2d(nn.Module):def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False):"""Args:modulation (bool, optional): If True, Modulated Defomable Convolution (Deformable ConvNets v2)."""super(DeformConv2d, self).__init__()self.kernel_size = kernel_sizeself.padding = paddingself.stride = strideself.zero_padding = nn.ZeroPad2d(padding)# 最终使用的卷积操作,注意stride=kernel,# 原因是最终采样点不是规则的点,需要结合通过偏移量取值,因此需要构建新的特征图# 新的特征图的尺寸是原来特征图的hw 是原来hw x kernel_size 的大小self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)# 用来学习偏移量的卷积,其中通道数为2xksxks ,如果k=3,也就是学习9个位置的2方向(x、y)偏移量self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)# 初始化偏移卷积为0nn.init.constant_(self.p_conv.weight, 0)# 学习率设置为整个网络0.1倍,避免影响整体网络性能self.p_conv.register_backward_hook(self._set_lr)# 为每个位置增加学习权重,初始化和偏移卷积一样self.modulation = modulationif modulation:self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)nn.init.constant_(self.m_conv.weight, 0)self.m_conv.register_backward_hook(self._set_lr)

forward过程

整体主要步骤如下
1、self.p_conv卷积计算offset ,# 维度 (b,2ksks,h),w # 2xksxks 若k=3,也就是用来卷积的9的位置的x、y方向偏移

2、self._get_p 函数获取offset的位置 ( 绝对位置+相对位置)# 维度 (b, 2N, h, w) ,N=ksxks。

3、self._get_x_q 之前的函数用来计算双线性插值采样点,因为位置是浮点数,需要映射回具体坐标位置 # (b, c, h, w, N),不同的通道c其实对应相同的位置。# (b, c, h, w, N)

4、self._get_x_q 函数用来得到每个位置的插值权重 # (b, c, h, w, N)

5、self._reshape_x_offset 将b, c, h, w, N重新排布为b, c, hxks, wxks 用来进行最终的卷积
代码及解释如下。

    def forward(self, x): # b,c,h,w# 计算偏移量,维度 b,2*ks*ks,h,w  # N=kxkoffset = self.p_conv(x)if self.modulation: # 为偏移量增加权重m = torch.sigmoid(self.m_conv(x))dtype = offset.data.type()ks = self.kernel_sizeN = offset.size(1) // 2 # N=ks*ks# 填充:k=3的卷积,填充p=1,尺度才不会发生改变if self.padding:x = self.zero_padding(x)# (b, 2N, h, w) ,得到p的位置p = self._get_p(offset, dtype)# (b, h, w, 2N) ,位置放在最后一个维度,方便处理p = p.contiguous().permute(0, 2, 3, 1)q_lt = p.detach().floor() #left top 左上角坐标,也就是最小值,如果是0-1之间就是0q_rb = q_lt + 1 # right bottom右下角坐标,也就是最大值,如果是0-1之间就是1# 确定四个角点坐标,设置在0 到 h-1 或 w-1 之间q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)# clip p ,采样点也需要clamp一下p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)# bilinear kernel (b, h, w, N)g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))# (b, c, h, w, N),计算四个领域的权重x_q_lt = self._get_x_q(x, q_lt, N)x_q_rb = self._get_x_q(x, q_rb, N)x_q_lb = self._get_x_q(x, q_lb, N)x_q_rt = self._get_x_q(x, q_rt, N)# (b, c, h, w, N),计算插值结果x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \g_rb.unsqueeze(dim=1) * x_q_rb + \g_lb.unsqueeze(dim=1) * x_q_lb + \g_rt.unsqueeze(dim=1) * x_q_rt# modulation,如果存在这个模块,就让偏移量*mif self.modulation:m = m.contiguous().permute(0, 2, 3, 1)m = m.unsqueeze(dim=1)m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)x_offset *= m# 重新排列,变成 h,c,h*ks,w*ks 特征图,用来最后卷积x_offset = self._reshape_x_offset(x_offset, ks)out = self.conv(x_offset)return out

self.p_conv

就是普通卷积操作,不进行解释

self._get_p

包括绝对位置和相对位置,绝对位置就是卷积中心在原图中的位置 0-(h-1) ,0-(w-1) ,相对位置就是0-(ks-1) ,卷积操作中每个点与中心位置的相对关系。

       def _get_p_n(self, N, dtype): # 相对位置p_n_x, p_n_y = torch.meshgrid(torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))# (2N, 1)p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)p_n = p_n.view(1, 2*N, 1, 1).type(dtype)return p_ndef _get_p_0(self, h, w, N, dtype): #绝对位置p_0_x, p_0_y = torch.meshgrid(torch.arange(1, h*self.stride+1, self.stride),torch.arange(1, w*self.stride+1, self.stride))p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)return p_0def _get_p(self, offset, dtype):N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)# (1, 2N, 1, 1),相对位置只有2N个,因为就是卷积核大小p_n = self._get_p_n(N, dtype)# (1, 2N, h, w),绝对位置有2NxHxW,因为每个位置都有偏移量p_0 = self._get_p_0(h, w, N, dtype) p = p_0 + p_n + offsetreturn p

self._get_x_q

将原始输入的hw变成一个维度的向量,相应的位置索引也需要变成一维,所以需要乘以w,然后最后在重新变成hxw格式

def _get_x_q(self, x, q, N):b, h, w, _ = q.size()padded_w = x.size(3)c = x.size(1)# (b, c, h*w)x = x.contiguous().view(b, c, -1)# (b, h, w, N)index = q[..., :N]*padded_w + q[..., N:]  # offset_x*w + offset_y# (b, c, h*w*N)index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)return x_offset

self._reshape_x_offset

这个的理解可以根据这个博客@链接来,也就是将整体数据重新排布成卷积的类型

图片2
    @staticmethoddef _reshape_x_offset(x_offset, ks):b, c, h, w, N = x_offset.size()x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)return x_offset

参考文献

https://blog.csdn.net/panghuzhenbang/article/details/129816869
https://zhuanlan.zhihu.com/p/335147713
https://zhuanlan.zhihu.com/p/102707081
https://blog.csdn.net/panghuzhenbang/article/details/129816869

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/51079.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【数据结构】使用栈实现综合计算器

首先初始化两个栈,数栈(numStack)用于存放数据,符号栈(operStack)用于存放运算符 计算思路 1.通过一个index值(索引)来遍历表达式 2.如果发现扫描到一个数字,就直接入数栈…

Python | TypeError: ‘function’ object is not subscriptable

Python | TypeError: ‘function’ object is not subscriptable 在Python编程中,遇到“TypeError: ‘function’ object is not subscriptable”这一错误通常意味着你尝试像访问列表、元组、字典或字符串等可订阅(subscriptable)对象那样去…

Javascript面试基础6(下)

获取页面所有checkbox 怎样添加、移除、移动、复制、创建和查找节点 在JavaScript中,操作DOM(文档对象模型)是常见的任务,包括添加、移除、移动、复制、创建和查找节点。以下是一些基本的示例,说明如何执行这些操作&a…

Java语言程序设计——篇九(2)

🌿🌿🌿跟随博主脚步,从这里开始→博主主页🌿🌿🌿 枚举类型 枚举类型的定义枚举类型的方法实战演练 枚举在switch中的应用实战演练 枚举类的构造方法实战演练 枚举类型的定义 [修饰符] enum 枚举…

医院影像平台源码,C/S体系结构的C#语言PACS系统全套商业源代码

医学学影像临床信息系统具有图像采集、显示、存储、传输和管理等功能,支持DICOM影像设备和非DICOM影像设备,可以识别CT、MR、CR/DR、X光、DSA、B超、NM、SC等设备的图像类型,可对数字影像进行无损压缩和有损压缩处理。C/S体系结构的多媒体数据…

湖仓一体架构解析:数仓架构选择(第48天)

系列文章目录 1、Lambda 架构 2、Kappa 架构 3、混合架构 4、架构选择 5、实时数仓现状 6、湖仓一体架构 7、流批一体架构 文章目录 系列文章目录前言1、Lambda 架构2、Kappa 架构3、混合架构4、架构选择5、实时数仓现状6、湖仓一体架构7、流批一体架构 前言 本文解析了Lambd…

Verilog语言和C语言的本质区别是什么?

在开始前刚好我有一些资料,是我根据网友给的问题精心整理了一份「C语言的资料从专业入门到高级教程」, 点个关注在评论区回复“888”之后私信回复“888”,全部无偿共享给大家!!! 用老石的一句话其实很好说…

ssm框架整合,异常处理器和拦截器(纯注解开发)

目录 ssm框架整合 第一步:指定打包方式和导入所需要的依赖 打包方法:war springMVC所需依赖 解析json依赖 mybatis依赖 数据库驱动依赖 druid数据源依赖 junit依赖 第二步:导入tomcat插件 第三步:编写配置类 SpringCon…

【AI绘画】Midjourney V6初学者完全指南 参数篇

本文我们将详细介绍对图像生成结果产生重大影响的"参数"。 1. 什么是参数? 参数是一种添加到提示末尾以调整图像生成输出设置的方法。 它们用两个连字符"–“和特定字符串表示,如”–ar"、“–chaos”、"–r"等。 您也可以同时使用多个参数…

java项目中添加SDK项目作为依赖使用(无需上传Maven)

需求: 当需要多次调用某个函数或算法时,不想每次调用接口都自己编写,可以将该项目打包,以添加依赖的方式实现调用 适用于: 无需上线的项目,仅公司或团队内部使用的项目 操作步骤: 以下面这…

菜鸟从0学微服务——MyBatis-Plus

关于“菜鸟从0学微服务” 针对有编程基础,开始学习微服务的同学,我们陆续推出从0学微服务的笔记分享。力求从各个中间件的使用来反思这些中间件的作用和优势。 会分享的比较快,会记录demo演算和中间件的使用过程,至于细节的理论…

【数学建模】——【python】实现【最短路径】【最小生成树】【复杂网络分析】

目录 1. 最短路径问题 - 绘制城市间旅行最短路径图 题目描述: 要求: 示例数据: python 代码实现 实现思想: 要点: 2. 最小生成树问题 - Kruskal算法绘制MST 题目描述: 要求: 示例数据…

PostgreSQL入门与进阶学习,体系化的SQL知识,完成终极目标高可用与容灾,性能优化与架构设计,以及安全策略

​专栏内容: postgresql使用入门基础手写数据库toadb并发编程 个人主页:我的主页 管理社区:开源数据库 座右铭:天行健,君子以自强不息;地势坤,君子以厚德载物. 文章目录 概述基础篇初级篇进阶篇…

事务、函数和索引

目录 什么是事务? 事务的ACID原则: 事务的操作 事务的原子性、一致性、持久性 事务的隔离性 什么是事务的隔离性? 用什么方法实现事务的隔离性? MySQL中的锁 锁分类: 事务的隔离级别 事务并发问题 InnoDB的MVCC MVCC…

【C++】红黑树的应用(封装map和set)

✨ 青山一道同云雨,明月何曾是两乡 🌏 📃个人主页:island1314 🔥个人专栏:C学习 🚀 欢迎关注:👍点赞 &…

Unity UGUI 实战学习笔记(3)

仅作学习,不做任何商业用途 不是源码,不是源码! 是我通过"照虎画猫"写的,可能有些小修改 不提供素材,所以应该不算是盗版资源,侵权删 拼UI 提示面板的逻辑 using System.Collections; using System.Col…

大数据——Hive原理

摘要 Apache Hive 是一个基于 Hadoop 分布式文件系统 (HDFS) 的数据仓库软件项目,专为存储和处理大规模数据集而设计。它提供类似 SQL 的查询语言 HiveQL,使用户能够轻松编写复杂的查询和分析任务,而无需深入了解 Hadoop 的底层实现。 Hive…

Firefox扩展程序和Java程序通信

实现Firefox扩展程序,和Java RMI Client端进行通信。 在Firefox工具栏注册按钮,点击按钮后弹出Popup.html页面,引用Popup.js脚本,通过脚本向Java RMI client发送消息,Java RMI Client接收消息后转发到Java RMI Server…

MyBatis的入门操作--打印日志和增删改查(单表静态)

下面介绍注解和xml实现crud的操作 目录 一、日志打印和参数传递 1.1.使用mybatis打印日志 1.2.参数传递细节 二、crud(注解实现) 2.1.增(insert) 2.2.删(delete) 和 (update) 2.3.查(select) 三、crud(xml实现) 3.1.准备…

中国居民膳食指南书籍知识点汇总

人如果吃不好,就不能好好思考,好好爱,好好休息。——维吉尼亚伍儿夫 文章目录 书籍简介饮食准则推荐膳食图示 准则一:食物多样,合理搭配合理搭配的方法平衡膳食的科学原理均衡饮食的作用食物功效(有科学实验…