【论文笔记】MCANet: Medical Image Segmentation withMulti-Scale Cross-Axis Attention

        医疗图像分割任务中,捕获多尺度信息、构建长期依赖对分割结果有非常大的影响。该论文提出了 Multi-scale Cross-axis Attention(MCA)模块,融合了多尺度特征,并使用Attention提取全局上下文信息。

论文地址:MCANet: Medical Image Segmentation with Multi-Scale Cross-Axis Attention

代码地址:https://github.com/haoshao-nku/medical_seg

一、MCA(Multi-scale Cross-axis Attention)

MCA的结构如下,将E2/3/4通过concat连接起来(concat前先插值到同样分辨率),经过1x1的卷积后(压缩通道数来降低计算量),得到了包含多尺度信息的特征图F,然后在X和Y方向使用不同大小的卷积核进行卷积运算(比如1x11的卷积是x方向,11x1的是y方向,这里可以对着代码看,容易理解),将Q在X和Y方向交换后(这就是Cross-Axis),经过注意力模块后,将多个特征图相加,并融合E1,经过卷积后得到输出。该模块有以下特点:

1、注意力机制作用在多个不同尺度的特征图;

2、Multi-Scale x-Axis Convolution和Multi-Scale y-Axis Convolution分别关注不同轴的特征,在计算注意力时交叉计算,使得不同方向的特征都能被关注到。

MCA细节如下图,输入特征图进入x和y方向的路径,经过不同大小的卷积后进行融合,然后跨轴(x和y轴的Q交换)计算Attention,最后得到输出特征图。

二、代码

MCA的代码如下所示,总体来说比较简单:

from audioop import bias
from pip import main
import torch
import torch.nn as nn
import torch.nn.functional as F
import numbers
from mmseg.registry import MODELS
from einops import rearrange
from ..utils import resize
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmseg.models.decode_heads.decode_head import BaseDecodeHeaddef to_3d(x):return rearrange(x, 'b c h w -> b (h w) c')def to_4d(x,h,w):return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w)class BiasFree_LayerNorm(nn.Module):def __init__(self, normalized_shape):super(BiasFree_LayerNorm, self).__init__()if isinstance(normalized_shape, numbers.Integral):normalized_shape = (normalized_shape,)normalized_shape = torch.Size(normalized_shape)assert len(normalized_shape) == 1self.weight = nn.Parameter(torch.ones(normalized_shape))self.normalized_shape = normalized_shapedef forward(self, x):sigma = x.var(-1, keepdim=True, unbiased=False)return x / torch.sqrt(sigma+1e-5) * self.weightclass WithBias_LayerNorm(nn.Module):def __init__(self, normalized_shape):super(WithBias_LayerNorm, self).__init__()if isinstance(normalized_shape, numbers.Integral):normalized_shape = (normalized_shape,)normalized_shape = torch.Size(normalized_shape)assert len(normalized_shape) == 1self.weight = nn.Parameter(torch.ones(normalized_shape))self.bias = nn.Parameter(torch.zeros(normalized_shape))self.normalized_shape = normalized_shapedef forward(self, x):mu = x.mean(-1, keepdim=True)sigma = x.var(-1, keepdim=True, unbiased=False)return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.biasclass LayerNorm(nn.Module):def __init__(self, dim, LayerNorm_type):super(LayerNorm, self).__init__()if LayerNorm_type =='BiasFree':self.body = BiasFree_LayerNorm(dim)else:self.body = WithBias_LayerNorm(dim)def forward(self, x):h, w = x.shape[-2:]return to_4d(self.body(to_3d(x)), h, w)class Attention(nn.Module):def __init__(self, dim, num_heads,LayerNorm_type,):super(Attention, self).__init__()self.num_heads = num_heads   self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))   self.norm1 = LayerNorm(dim, LayerNorm_type)self.project_out = nn.Conv2d(dim, dim, kernel_size=1)      self.conv0_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)self.conv0_2 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)def forward(self, x):b,c,h,w = x.shape   x1 = self.norm1(x)attn_00 = self.conv0_1(x1)attn_01= self.conv0_2(x1)  attn_10 = self.conv1_1(x1)attn_11 = self.conv1_2(x1)attn_20 = self.conv2_1(x1)attn_21 = self.conv2_2(x1)   out1 = attn_00+attn_10+attn_20out2 = attn_01+attn_11+attn_21   out1 = self.project_out(out1)out2 = self.project_out(out2)  k1 = rearrange(out1, 'b (head c) h w -> b head h (w c)', head=self.num_heads)v1 = rearrange(out1, 'b (head c) h w -> b head h (w c)', head=self.num_heads)k2 = rearrange(out2, 'b (head c) h w -> b head w (h c)', head=self.num_heads)v2 = rearrange(out2, 'b (head c) h w -> b head w (h c)', head=self.num_heads)   q2 = rearrange(out1, 'b (head c) h w -> b head w (h c)', head=self.num_heads) q1 = rearrange(out2, 'b (head c) h w -> b head h (w c)', head=self.num_heads)       q1 = torch.nn.functional.normalize(q1, dim=-1)q2 = torch.nn.functional.normalize(q2, dim=-1)k1 = torch.nn.functional.normalize(k1, dim=-1)k2 = torch.nn.functional.normalize(k2, dim=-1)          attn1 = (q1 @ k1.transpose(-2, -1))attn1 = attn1.softmax(dim=-1)   out3 = (attn1 @ v1) + q1      attn2 = (q2 @ k2.transpose(-2, -1))attn2 = attn2.softmax(dim=-1)   out4 = (attn2 @ v2) + q2                         out3 = rearrange(out3, 'b head h (w c) -> b (head c) h w', head=self.num_heads, h=h, w=w)out4 = rearrange(out4, 'b head w (h c) -> b (head c) h w', head=self.num_heads, h=h, w=w)       out =  self.project_out(out3)  + self.project_out(out4) + xreturn out@MODELS.register_module()
class MCAHead(BaseDecodeHead):def __init__(self,in_channels,image_size,heads,c1_channels,**kwargs):super(MCAHead, self).__init__(in_channels,input_transform = 'multiple_select',**kwargs)self.image_size = image_sizeself.decoder_level = Attention(in_channels[1],heads,LayerNorm_type = 'WithBias')self.align = ConvModule(in_channels[3],in_channels[0],1,conv_cfg=self.conv_cfg,norm_cfg=self.norm_cfg,act_cfg=self.act_cfg)self.squeeze = ConvModule(sum((in_channels[1],in_channels[2],in_channels[3])),in_channels[1],1,conv_cfg=self.conv_cfg,norm_cfg=self.norm_cfg,act_cfg=self.act_cfg)self.sep_bottleneck = nn.Sequential(DepthwiseSeparableConvModule(in_channels[1] + in_channels[0],in_channels[3],3,padding=1,norm_cfg=self.norm_cfg,act_cfg=self.act_cfg),DepthwiseSeparableConvModule(in_channels[3],in_channels[3],3,padding=1,norm_cfg=self.norm_cfg,act_cfg=self.act_cfg))             def forward(self, inputs):"""Forward function."""inputs = self._transform_inputs(inputs)inputs = [resize(level,size=self.image_size,mode='bilinear',align_corners=self.align_corners) for level in inputs]y1 = torch.cat([inputs[1],inputs[2],inputs[3]], dim=1)x = self.squeeze(y1)  x = self.decoder_level(x)x = torch.cat([x,inputs[0]], dim=1) x = self.sep_bottleneck(x)output = self.align(x)  output = self.cls_seg(output)return output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/234490.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

蔚来打败“蔚来”

作者 | 魏启扬 来源 | 洞见新研社 继2019年后,又一次深陷倒闭传闻的蔚来汽车,“在关键时刻找到钱了”。 12月18日,蔚来汽车宣布,与阿布扎比投资机构CYVN Holdings签订新一轮股份认购协议,CYVN Holdings将通过其附属公…

四色问题(图论)python

四色问题是一种著名的图论问题,它要求在给定的地图上给每个区域着一种颜色,使得相邻的区域颜色不同,而只使用四种颜色。这个问题可以通过图的着色来解决,其中图的节点表示区域,边表示相邻的关系。 在 Python 中&#…

excel导出,post还是get请求?

1,前提 今天在解决excel导出的bug时,因为导出接口查询参数较多,所以把原来的get请求接口修改为post请求 原代码: 修改后: 2,修改后 postman请求正常,然后让前端对接口进行同步修改&#xff0…

Clion自定义管理和配置软件构建过程的工具(代替CMake)构建程序

在公司由于需要x86环境和其他arm环境,同时需要使用公司自定义的mine_x86或者mine_orin对代码进行编译。 编译命令如下mine_x86 build -Dlocal1 -j8,为使用Clion对程序进行调试,需要对程序进行设置。方便调试代码时能够断点查看变量。尝试了很多次&#…

@WebMethod 这个注解的作用

WebMethod 注解是 Java 中 JAX-WS(Java API for XML Web Services)的一部分,用于将一个特定的方法标记为 Web 服务操作。当你在类方法上使用 WebMethod 注解时,这表明该方法是一个对外暴露的 Web 服务方法,即这个方法可…

pytorch实现DCP暗通道先验去雾算法及其onnx导出

pytorch实现DCP暗通道先验去雾算法及其onnx导出 简介实现ONNX导出导出测试 简介 最近在做图像去雾,于是在Pytorch上复现了一下dcp算法。暗通道先验去雾算法是大神何恺明2009年发表在CVPR上的一篇论文,还获得了当年的CVPR最佳论文。 实现 具体原理就不…

计算机组成原理综合1

1、完整的计算机系统应包括______。D A. 运算器、存储器和控制器 B. 外部设备和主机 C. 主机和实用程序 D. 配套的硬件设备和软件系统 2、计算机系统中的存储器系统是指______。D A. RAM存储器 B. ROM存储器 C. 主存储器 …

安捷伦Agilent 34970A数据采集

易学易用 从34972A简化的配置到内置的图形Web界面,我们都投入了非常多的时间和精力,以帮助您节约宝贵的时间。一些非常简单的东西,例如模块上螺旋型端子连接器内置热电偶参考结、包括众多实例和提示的完整用户文档,以及使您能够在开机数分钟后…

接口测试和测试用例分析

只要有软件产品的公司百分之九十以上都会做接口测试,要做接口测试的公司那是少不了接口测试工程师的,接口测试工程师相对于其他的职位又比较轻松并且容易胜任。如果你想从事接口测试的工作那就少不了对接口进行分析,同时也会对测试用例进行研…

node.js mongoose middleware

目录 官方文档 简介 定义模型 注册中间件 创建doc实例,并进行增删改查 方法名和注册的中间件名相匹配 执行结果 分析 错误处理中间件 手动抛出错误 注意点 官方文档 Mongoose v8.0.3: Middleware 简介 在mongoose中,中间件是一种允许在执…

DDD领域驱动设计(二)

软件系统复杂性的应对 解决复杂和大规模软件的武器可以粗略的归位三种:抽象 分治和知识 抽象: 使用抽象能够精简问题空间,而且问题越小越容易理解。比如你去一个地方 一开始的时候并不需要确定用什么方式到达。分治: 类似算法里面的dp用的就是分治的想法。分割后的…

破局新渠道|2023年热度全域达人分销生态大会回顾

12月7日,由热度电商、热度云、集脉新电商联合举办的「破局新渠道」热度全域达人分销生态大会暨热度云3.0发布会在杭州国际博览中心圆满收官。大会邀请了平台官方、电商协会、品牌方、业务操盘手、数据专家、团长机构、达人等达人分销生态中的多个角色,从…

Python办公—pandas读取Excel表格增加列、两列保持一致、依条件修改单元格内容(附代码)

目录 专栏导读背景插入一列插入多列依条件修改单元格内容(2个条件以内)依条件修改单元格内容(3个条件以上)两列保持一致结尾 专栏导读 🌸 欢迎来到Python办公自动化专栏—Python处理办公问题,解放您的双手 🏳️‍🌈 博客主页&am…

2023年中国数据智能管理峰会(DAMS上海站2023)-核心PPT资料下载

一、峰会简介 数据已经成为企业的核心竞争力!谁掌控数据、更好的利用数据、实现资产化,谁就会真正率先进入大数据时代。 1、数据智能管理趋势和挑战 在峰会上,与会者讨论了数据智能管理的最新趋势和挑战。随着数据量的不断增加&#xff0c…

JNI逆向

IDA:JNI类型转换 1.IDA高版本(IDA 高版本内置了定义的JNI结构体; 如果没有的话,在Views->Open subviews -> Type Libraries 中添加Android ARM的lib即可) 解决方法: 只需要对JNIEnv 指针(JNIEnv * &#xff09…

使用postman时,报错SSL Error: Unable to verify the first certificate

开发中使用postman调用接口,出现以下问题,在确认路径、参数、请求方式均为正确的情况下 解决方法 File - Settings -> SSL certification verification 关闭 找到图中配置,这里默认是打开状态,把它关闭即可:ON …

虾皮测评选品:如何在虾皮平台上进行有效的产品测评和选品

在如今的电商市场中,虾皮(Shopee)平台已经成为了卖家们最为重要的销售渠道之一。而在虾皮平台上进行产品测评和选品对于卖家来说至关重要,它直接影响到店铺的销售额和利润。本文将为您提供一些关于如何在虾皮平台上进行有效的产品…

如何通过ETLCloud的API对接功能实现各种SaaS平台数据对接

前言 当前使用SaaS系统的企业越来越多,当我们需要对SaaS系统中产生的数据进行分析和对接时就需要与SaaS系统提供的API进行对接,因为SaaS一般是不会提供数据库表给企业,这时就应该使用ETL(Extract, Transform, Load)的…

复杂 SQL 实现分组分情况分页查询

其他系列文章导航 Java基础合集数据结构与算法合集 设计模式合集 多线程合集 分布式合集 ES合集 文章目录 其他系列文章导航 文章目录 前言 一、根据 camp_status 字段分为 6 种情况 1.1 SQL语句 1.2 SQL解释 二、分页 SQL 实现 2.1 SQL语句 2.2 根据 camp_type 区分返…

Unity中Shader测试常用的UGUI功能简介

文章目录 前言一、锚点1、锚点快捷修改位置2、使用Anchor Presets快捷修改3、Anchor Presets界面按下 Shift 可以快捷修改锚点和中心点位置4、Anchor Presets界面按下 Alt 可以快捷修改锚点位置、UI对象位置 和 长宽大小 二、Canvas画布1、UGUI中 Transform 变成了 Rect Transf…