pytorch实现分割模型TransUNet

TransUNet是一个非常经典的图像分割模型。该模型出现在Transformer引入图像领域的早期,所以结构比较简单,但是实际上效果却比很多后续花哨的模型更好。所以有必要捋一遍pytorch实现TransUNet的整体流程。

首先,按照惯例,先看一下TransUNet的结构图:

根据结构图,我们可以看出,整体结构就是基于UNet魔改的。

1,具体结构如下:

1. CNN-Transformer混合编码器:TransUNet使用卷积神经网络(CNN)作为特征提取器,生成特征图。然后,从CNN特征图中提取的1x1 patches通过patch embedding转换为序列,作为Transformer的输入。这种设计允许模型利用CNN的高分辨率特征图。

2. Transformer编码器:Transformer编码器由多头自注意力(Multihead Self-Attention, MSA)和多层感知器(MLP)块组成。这些层处理输入序列,捕获全局上下文信息。

3. 级联上采样器(Cascaded Upsampler, CUP):为了从Transformer编码器的输出中恢复空间分辨率,TransUNet引入了CUP。CUP由多个上采样步骤组成,每个步骤包括一个2x上采样操作、一个3x3卷积层和一个ReLU激活层。这些步骤将特征图从低分辨率逐步上采样到原始图像的分辨率。

4. skip connection:TransUNet采用了U-Net的u形架构设计,通过跳跃连接(skip-connections)将编码器中的高分辨率CNN特征图与Transformer编码的全局上下文特征结合起来,以实现精确的定位。

5. 解码器:解码器部分使用CUP来从Transformer编码器的输出中恢复出最终的分割掩码。这包括将Transformer的输出特征图与CNN特征图结合,并通过上采样步骤恢复到原始图像的分辨率。

我们只需要实现其每个模块,然后安装UNet拼装成整体就可以了。

2,首先实现的是编码器分支的卷积部分:

每个卷积模块可以使用resnet的一个块,或者自己实现一个

class EncoderBottleneck(nn.Module):def __init__(self, in_channels, out_channels, stride=1, base_width=64):super().__init__()  # 初始化父类self.downsample = nn.Sequential(  # 下采样层,用于降低特征图的维度nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels))width = int(out_channels * (base_width / 64))  # 计算中间通道数self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, stride=1, bias=False)  # 第一个卷积层self.norm1 = nn.BatchNorm2d(width)  # 第一个批量归一化层self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=2, groups=1, padding=1, dilation=1, bias=False)  # 第二个卷积层self.norm2 = nn.BatchNorm2d(width)  # 第二个批量归一化层self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, stride=1, bias=False)  # 第三个卷积层self.norm3 = nn.BatchNorm2d(out_channels)  # 第三个批量归一化层self.relu = nn.ReLU(inplace=True)  # ReLU激活函数def forward(self, x):x_down = self.downsample(x)  # 下采样操作x = self.conv1(x)  # 第一个卷积操作x = self.norm1(x)  # 第一个批量归一化x = self.relu(x)  # ReLU激活x = self.conv2(x)  # 第二个卷积操作x = self.norm2(x)  # 第二个批量归一化x = self.relu(x)  # ReLU激活x = self.conv3(x)  # 第三个卷积操作x = self.norm3(x)  # 第三个批量归一化x = x + x_down  # 残差连接x = self.relu(x)  # ReLU激活return x

3,实现ViT模块

多头注意力实现如下:

class MultiHeadAttention(nn.Module):def __init__(self, embedding_dim, head_num):super().__init__()  # 调用父类构造函数self.head_num = head_num  # 多头的数量self.dk = (embedding_dim // head_num) ** (1 / 2)  # 缩放因子,用于缩放点积注意力self.qkv_layer = nn.Linear(embedding_dim, embedding_dim * 3, bias=False)  # 线性层,用于生成查询(Q)、键(K)和值(V)self.out_attention = nn.Linear(embedding_dim, embedding_dim, bias=False)  # 输出线性层def forward(self, x, mask=None):qkv = self.qkv_layer(x)  # 通过线性层生成Q、K、Vquery, key, value = tuple(rearrange(qkv, 'b t (d k h) -> k b h t d', k=3, h=self.head_num))  # 将Q、K、V重塑为多头注意力的格式energy = torch.einsum("... i d , ... j d -> ... i j", query, key) * self.dk  # 计算点积注意力的能量if mask is not None:  # 如果提供了掩码,则在能量上应用掩码energy = energy.masked_fill(mask, -np.inf)attention = torch.softmax(energy, dim=-1)  # 应用softmax函数,得到注意力权重x = torch.einsum("... i j , ... j d -> ... i d", attention, value)  # 应用注意力权重到值上x = rearrange(x, "b h t d -> b t (h d)")  # 重塑x以准备输出x = self.out_attention(x)  # 通过输出线性层return x

MLP实现如下:

# 定义MLP模块
class MLP(nn.Module):def __init__(self, embedding_dim, mlp_dim):super().__init__()  # 调用父类构造函数self.mlp_layers = nn.Sequential(  # 定义MLP的层nn.Linear(embedding_dim, mlp_dim),nn.GELU(),  # GELU激活函数nn.Dropout(0.1),  # Dropout层,用于正则化nn.Linear(mlp_dim, embedding_dim),  # 线性层nn.Dropout(0.1)  # Dropout层)def forward(self, x):x = self.mlp_layers(x)  # 通过MLP层return x

一个Transformer编码器块由归一化层,多头注意力,MLP和残差连接组成,实现如下:

# 定义Transformer编码器块
class TransformerEncoderBlock(nn.Module):def __init__(self, embedding_dim, head_num, mlp_dim):super().__init__()  # 调用父类构造函数self.multi_head_attention = MultiHeadAttention(embedding_dim, head_num)  # 多头注意力模块self.mlp = MLP(embedding_dim, mlp_dim)  # MLP模块self.layer_norm1 = nn.LayerNorm(embedding_dim)  # 第一层归一化self.layer_norm2 = nn.LayerNorm(embedding_dim)  # 第二层归一化self.dropout = nn.Dropout(0.1)  # Dropout层def forward(self, x):_x = self.multi_head_attention(x)  # 通过多头注意力模块_x = self.dropout(_x)  # 应用dropoutx = x + _x  # 残差连接x = self.layer_norm1(x)  # 第一层归一化_x = self.mlp(x)  # 通过MLP模块x = x + _x  # 残差连接x = self.layer_norm2(x)  # 第二层归一化return x

Transformer 编码器由多层Transformer块堆叠而成,其中block_num代表的就是堆叠的层数

# 定义Transformer编码器
class TransformerEncoder(nn.Module):def __init__(self, embedding_dim, head_num, mlp_dim, block_num=12):super().__init__()  # 调用父类构造函数self.layer_blocks = nn.ModuleList([  # 创建一个模块列表,包含多个编码器块TransformerEncoderBlock(embedding_dim, head_num, mlp_dim) for _ in range(block_num)])def forward(self, x):for layer_block in self.layer_blocks:  # 遍历每个编码器块x = layer_block(x)  # 通过每个块return x

vit的全部模块已经实现,下面就vit整体结构了。

vit的整体结构就是先将输入图片划分patches,然后将patches做embedding。

vit的分类头是一组额外添加的cl-token,将这个class-token复制batches遍,之后就可以将复制后的class_Token拼接到之前的embedding上了。

之后需要把位置编码加到这个embedding上。

这样,输入的图像特征就被处理好了,转换成了输入给Transformer块的形式。

之后只要输入一个Transformer编码器和一个MLP头,就可以得到vit的输出结果。

如果是分类任务,则class_token就是分类结果,如果不是分类任务,比如分割或者vit作为一个模块,那么输出的就是patches形式的特征图。

# 定义ViT模型
class ViT(nn.Module):def __init__(self, img_dim, in_channels, embedding_dim, head_num, mlp_dim, block_num, patch_dim, classification=True, num_classes=1):super().__init__()  # 调用父类构造函数self.patch_dim = patch_dim  # 定义patch的维度self.classification = classification  # 是否进行分类self.num_tokens = (img_dim // patch_dim) ** 2  # 计算tokens的数量self.token_dim = in_channels * (patch_dim ** 2)  # 计算每个token的维度self.projection = nn.Linear(self.token_dim, embedding_dim)  # 线性层,用于将patches投影到embedding空间self.embedding = nn.Parameter(torch.rand(self.num_tokens + 1, embedding_dim))  # 可学习的embeddingself.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))  # 类别tokenself.dropout = nn.Dropout(0.1)  # Dropout层self.transformer = TransformerEncoder(embedding_dim, head_num, mlp_dim, block_num)  # Transformer编码器if self.classification:  # 如果是分类任务self.mlp_head = nn.Linear(embedding_dim, num_classes)  # 分类头def forward(self, x):img_patches = rearrange(x,  # 将输入图像重塑为patches序列'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',patch_x=self.patch_dim, patch_y=self.patch_dim)batch_size, tokens, _ = img_patches.shape  # 获取批次大小、tokens数量和通道数project = self.projection(img_patches)  # 将patches投影到embedding空间token = repeat(self.cls_token, 'b ... -> (b batch_size) ...', batch_size=batch_size)  # 重复cls_token以匹配批次大小patches = torch.cat((token, project), dim=1)  # 将cls_token和投影后的patches拼接patches += self.embedding[:tokens + 1, :]  # 将可学习的embedding添加到patchesx = self.dropout(patches)  # 应用dropoutx = self.transformer(x)  # 通过Transformer编码器x = self.mlp_head(x[:, 0, :]) if self.classification else x[:, 1:, :]  # 如果是分类任务,使用cls_token的输出;否则,使用patches的输出return x

4,实现解码器的模块

解码器的模块就是卷积模块,接受两个输入:上采样而来的特征图以及skip-connection来的特征图

# 定义解码器中的瓶颈层
class DecoderBottleneck(nn.Module):def __init__(self, in_channels, out_channels, scale_factor=2):super().__init__()  # 初始化父类self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=True)  # 上采样层self.layer = nn.Sequential(  # 解码器层nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))def forward(self, x, x_concat=None):x = self.upsample(x)  # 上采样操作if x_concat is not None:  # 如果有额外的特征图进行拼接x = torch.cat([x_concat, x], dim=1)  # 在通道维度上拼接x = self.layer(x)  # 通过解码器层return x

5,组装成模型

所有模块都已经定义完成,下面拿这些模块来组装成模型。

编码器分支由三个卷积模块和一个vit模块组成,输出的x为解码分支最终的特征图,x1,x2,x3分别是三个卷积模块的输出

# 定义编码器
class Encoder(nn.Module):def __init__(self, img_dim, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim):super().__init__()  # 初始化父类self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3, bias=False)  # 第一个卷积层self.norm1 = nn.BatchNorm2d(out_channels)  # 第一个批量归一化层self.relu = nn.ReLU(inplace=True)  # ReLU激活函数self.encoder1 = EncoderBottleneck(out_channels, out_channels * 2, stride=2)  # 第一个编码器瓶颈层self.encoder2 = EncoderBottleneck(out_channels * 2, out_channels * 4, stride=2)  # 第二个编码器瓶颈层self.encoder3 = EncoderBottleneck(out_channels * 4, out_channels * 8, stride=2)  # 第三个编码器瓶颈层self.vit_img_dim = img_dim // patch_dim  # ViT的图像维度self.vit = ViT(self.vit_img_dim, out_channels * 8, out_channels * 8,  # ViT模型head_num, mlp_dim, block_num, patch_dim=1, classification=False)self.conv2 = nn.Conv2d(out_channels * 8, 512, kernel_size=3, stride=1, padding=1)  # 第四个卷积层self.norm2 = nn.BatchNorm2d(512)  # 第四个批量归一化层def forward(self, x):x = self.conv1(x)  # 第一个卷积操作x = self.norm1(x)  # 第一个批量归一化x1 = self.relu(x)  # ReLU激活x2 = self.encoder1(x1)  # 第一个编码器瓶颈层x3 = self.encoder2(x2)  # 第二个编码器瓶颈层x = self.encoder3(x3)  # 第三个编码器瓶颈层x = self.vit(x)  # 通过ViT模型x = rearrange(x, "b (x y) c -> b c x y", x=self.vit_img_dim, y=self.vit_img_dim)  # 重塑特征图x = self.conv2(x)  # 第四个卷积操作x = self.norm2(x)  # 第四个批量归一化x = self.relu(x)  # ReLU激活return x, x1, x2, x3  # 返回多个特征图

解码分支接受编码分支的输出x,以及三个卷积模块的输出x1,x2,x3,

# 定义解码器
class Decoder(nn.Module):def __init__(self, out_channels, class_num):super().__init__()  # 初始化父类self.decoder1 = DecoderBottleneck(out_channels * 8, out_channels * 2)  # 第一个解码器瓶颈层self.decoder2 = DecoderBottleneck(out_channels * 4, out_channels)  # 第二个解码器瓶颈层self.decoder3 = DecoderBottleneck(out_channels * 2, int(out_channels * 1 / 2))  # 第三个解码器瓶颈层self.decoder4 = DecoderBottleneck(int(out_channels * 1 / 2), int(out_channels * 1 / 8))  # 第四个解码器瓶颈层self.conv1 = nn.Conv2d(int(out_channels * 1 / 8), class_num, kernel_size=1)  # 最后一个卷积层,用于输出def forward(self, x, x1, x2, x3):x = self.decoder1(x, x3)  # 第一个解码器瓶颈层x = self.decoder2(x, x2)  # 第二个解码器瓶颈层x = self.decoder3(x, x1)  # 第三个解码器瓶颈层x = self.decoder4(x)  # 第四个解码器瓶颈层x = self.conv1(x)  # 最后一个卷积层return x  # 返回解码器的输出

整个模型结构:

# 定义TransUNet模型
class TransUNet(nn.Module):def __init__(self, img_dim, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim, class_num):super().__init__()  # 初始化父类self.encoder = Encoder(img_dim, in_channels, out_channels,  # 初始化编码器head_num, mlp_dim, block_num, patch_dim)self.decoder = Decoder(out_channels, class_num)  # 初始化解码器def forward(self, x):x, x1, x2, x3 = self.encoder(x)  # 编码分支x = self.decoder(x, x1, x2, x3)  # 解码分支return x  # 返回最终输出

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/735293.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

逼疯快递员的送货上门,谁来背锅?

快递上门的问题近几年来一直争论不休。 最近,随着新修订的《快递市场管理办法》正式实施,这个话题又成为了焦点。 消费者希望快递能够送上门省去麻烦,快递员希望统一送到代收点提高效率。 是消费者要求太高?快递员太过怠慢&…

[金三银四] 系统调用相关

2.36 系统调用的详细流程 Linux 在x86上的系统调用通过 int 0x80 实现,用系统调用号来区分入口函数。操作系统实现系统调用的基本过程是: 应用程序调用库函数(API);API 将系统调用号存入寄存器(EAX&#…

CKA备考攻略:掌握Pod日志收集,事半功倍的秘诀!

往期精彩文章 : 提升CKA考试胜算:一文带你全面了解RBAC权限控制!揭秘高效运维:如何用kubectl top命令实时监控K8s资源使用情况?CKA认证必备:掌握k8s网络策略的关键要点提高CKA认证成功率,CKA真题中的节点维…

2.JavaWebMySql基础

导语: 一、数据库基本概念 1.什么是数据库 2.关于MySql数据库 二、MySQL的安装与卸载 安装步骤: 卸载步骤: 三、MySQL服务操作 1.服务启动和关闭: 2.登录和退出MySQL: 3.服务自启动: 4.命令行登…

Python实现线性查找算法

Python实现线性查找算法 以下是使用 Python 实现线性查找算法的示例代码: def linear_search(arr, target):"""线性查找算法:param arr: 要搜索的数组:param target: 目标值:return: 如果找到目标值,返回其索引;否则返回 -1…

【玩转Linux】有关Linux权限

目录 一.Linux权限的概念 1. 权限的本质 2.Linux中的用户 3.Linux中的权限管理 (1)文件访问者的分类 (2)文件类型和访问权限(事物属性) ①文件基本权限 ②文件权限值的表示方法 (3)文件访问权限的相关设置方法 ① 用 户 表 示 符 / - 权 …

Vue3 快速上手从0到1,两小时学会【附源码】

小伙伴们好,欢迎关注,一起学习,无限进步 以下内容为vue3的学习笔记 项目需要使用到的依赖 npm install axios npm install nanoid vue-router npm install pinia npm install mitt 源码:Gitee 运行 npm install npm run dev需要运…

MacBook2024苹果免费mac电脑清理垃圾软件CleanMyMac X

CleanMyMac X是一款专业的Mac清理软件,具备多种强大功能。首先,它能够智能清理Mac磁盘上的垃圾文件和多余语言安装包,从而快速释放电脑内存。其次,CleanMyMac X可以轻松管理和升级Mac上的应用,同时强力卸载恶意软件并修…

windows使用pyenv

1、前言 虽然anaconda比pyenv相比有更好的python安装体验,但是有一个比较严重的问题的就是,他的python版本跨度不够大,一些老一些的项目的python版本找不到,比如py12306要求的python版本是3.6,在anaconda却找不到这个版…

查看pip当前关联python版本及位置

好久没用python了,把各种pip指向的环境忘光光啦,这里记录一下查看pip当前关联的python版本及位置的方法: pip -V结果: 我一般不用这个版本的python,去环境变量看了一下,原来是anaconda的Scripts自带pip&a…

gprof安装使用(CMake)说明

一、安装 1、gprof默认已安装,可安装相关图形处理 sudo apt-get install python graphviz sudo pip install gprof2dot 注意:在Debian中没有安装成功,报Python的版本不匹配 二、使用说明 1、使用CMake管理的工程: 重新配置CMa…

Elastic Stack--05--聚合、映射mapping

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 1.聚合(aggregations)基本概念桶(bucket)度量(metrics) 案例 11. 接下来按price字段进行分组:2. 若想对所…

LVS集群 ----------------(直接路由 )DR模式部署 (二)

一、LVS集群的三种工作模式 lvs-nat:修改请求报文的目标IP,多目标IP的DNAT lvs-dr:操纵封装新的MAC地址(直接路由) lvs-tun:隧道模式 lvs-dr 是 LVS集群的 默认工作模式 NAT通过网络地址转换实现的虚拟服务器&…

2024年【电工(初级)】考试内容及电工(初级)考试报名

题库来源:安全生产模拟考试一点通公众号小程序 电工(初级)考试内容根据新电工(初级)考试大纲要求,安全生产模拟考试一点通将电工(初级)模拟考试试题进行汇编,组成一套电…

Gitlab修改仓库权限为public、Internal、Private

Public(公开):所有人都可以访问该仓库; Internal(内部):同一个GitLab群组或实例内的所有用户都可以访问该仓库; Private(私人):仅包括指定成员的用…

2024 年广东省职业院校技能大赛(高职组) “云计算应用”赛项样题②

2024 年广东省职业院校技能大赛(高职组) “云计算应用”赛项样题② 模块一 私有云(50 分)任务 1 私有云服务搭建(10 分)任务 2 私有云服务运维(25 分)任务 3 私有云运维开发&#xf…

突破编程_前端_JS编程实例(目录导航)

1 开发目标 目录导航组件旨在提供一个滚动目录导航功能,使得用户可以方便地通过点击目录条目快速定位到对应的内容标题位置,同时也能够随着滚动条的移动动态显示当前位置在目录中的位置: 2 详细需求 2.1 标题提取与目录生成 组件需要能够自…

虚拟机实验环境配置与使用(计算机系统2)

一、 实验目标: 熟悉Linux上C程序的编译和调试工具,包括以下内容: 1. 了解Linux操作系统及其常用命令 2. 掌握编译工具gcc的基本用法 3. 掌握使用gdb进行程序调试 二、实验环境与工件 1.个人电脑 2. Fedora 13 Linux 操作系统 3. gcc…

【Python】牛客网—软件开发-Python专项练习(day1)

1.(单选)下面哪个是Python中不可变的数据结构? A.set B.list C.tuple D.dict 可变数据类型:列表list[ ]、字典dict{ }、集合set{ }(能查询,也可更改)数据发生改变,但内存地址不变 不…

OPCUA 学习笔记:程序模型

无论是边缘控制器,还是PLC 中,除了信息模型之外,还有应用程序,这些程序可能是IEC61131-3 编写的程序,也可能是其它程序开发的可执行程序。 尽管OPCUA 描述模型能力很强,但是它缺乏算法的描述方式。但是OPCU…