pytorch实现分割模型TransUNet

TransUNet是一个非常经典的图像分割模型。该模型出现在Transformer引入图像领域的早期,所以结构比较简单,但是实际上效果却比很多后续花哨的模型更好。所以有必要捋一遍pytorch实现TransUNet的整体流程。

首先,按照惯例,先看一下TransUNet的结构图:

根据结构图,我们可以看出,整体结构就是基于UNet魔改的。

1,具体结构如下:

1. CNN-Transformer混合编码器:TransUNet使用卷积神经网络(CNN)作为特征提取器,生成特征图。然后,从CNN特征图中提取的1x1 patches通过patch embedding转换为序列,作为Transformer的输入。这种设计允许模型利用CNN的高分辨率特征图。

2. Transformer编码器:Transformer编码器由多头自注意力(Multihead Self-Attention, MSA)和多层感知器(MLP)块组成。这些层处理输入序列,捕获全局上下文信息。

3. 级联上采样器(Cascaded Upsampler, CUP):为了从Transformer编码器的输出中恢复空间分辨率,TransUNet引入了CUP。CUP由多个上采样步骤组成,每个步骤包括一个2x上采样操作、一个3x3卷积层和一个ReLU激活层。这些步骤将特征图从低分辨率逐步上采样到原始图像的分辨率。

4. skip connection:TransUNet采用了U-Net的u形架构设计,通过跳跃连接(skip-connections)将编码器中的高分辨率CNN特征图与Transformer编码的全局上下文特征结合起来,以实现精确的定位。

5. 解码器:解码器部分使用CUP来从Transformer编码器的输出中恢复出最终的分割掩码。这包括将Transformer的输出特征图与CNN特征图结合,并通过上采样步骤恢复到原始图像的分辨率。

我们只需要实现其每个模块,然后安装UNet拼装成整体就可以了。

2,首先实现的是编码器分支的卷积部分:

每个卷积模块可以使用resnet的一个块,或者自己实现一个

class EncoderBottleneck(nn.Module):def __init__(self, in_channels, out_channels, stride=1, base_width=64):super().__init__()  # 初始化父类self.downsample = nn.Sequential(  # 下采样层,用于降低特征图的维度nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels))width = int(out_channels * (base_width / 64))  # 计算中间通道数self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, stride=1, bias=False)  # 第一个卷积层self.norm1 = nn.BatchNorm2d(width)  # 第一个批量归一化层self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=2, groups=1, padding=1, dilation=1, bias=False)  # 第二个卷积层self.norm2 = nn.BatchNorm2d(width)  # 第二个批量归一化层self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, stride=1, bias=False)  # 第三个卷积层self.norm3 = nn.BatchNorm2d(out_channels)  # 第三个批量归一化层self.relu = nn.ReLU(inplace=True)  # ReLU激活函数def forward(self, x):x_down = self.downsample(x)  # 下采样操作x = self.conv1(x)  # 第一个卷积操作x = self.norm1(x)  # 第一个批量归一化x = self.relu(x)  # ReLU激活x = self.conv2(x)  # 第二个卷积操作x = self.norm2(x)  # 第二个批量归一化x = self.relu(x)  # ReLU激活x = self.conv3(x)  # 第三个卷积操作x = self.norm3(x)  # 第三个批量归一化x = x + x_down  # 残差连接x = self.relu(x)  # ReLU激活return x

3,实现ViT模块

多头注意力实现如下:

class MultiHeadAttention(nn.Module):def __init__(self, embedding_dim, head_num):super().__init__()  # 调用父类构造函数self.head_num = head_num  # 多头的数量self.dk = (embedding_dim // head_num) ** (1 / 2)  # 缩放因子,用于缩放点积注意力self.qkv_layer = nn.Linear(embedding_dim, embedding_dim * 3, bias=False)  # 线性层,用于生成查询(Q)、键(K)和值(V)self.out_attention = nn.Linear(embedding_dim, embedding_dim, bias=False)  # 输出线性层def forward(self, x, mask=None):qkv = self.qkv_layer(x)  # 通过线性层生成Q、K、Vquery, key, value = tuple(rearrange(qkv, 'b t (d k h) -> k b h t d', k=3, h=self.head_num))  # 将Q、K、V重塑为多头注意力的格式energy = torch.einsum("... i d , ... j d -> ... i j", query, key) * self.dk  # 计算点积注意力的能量if mask is not None:  # 如果提供了掩码,则在能量上应用掩码energy = energy.masked_fill(mask, -np.inf)attention = torch.softmax(energy, dim=-1)  # 应用softmax函数,得到注意力权重x = torch.einsum("... i j , ... j d -> ... i d", attention, value)  # 应用注意力权重到值上x = rearrange(x, "b h t d -> b t (h d)")  # 重塑x以准备输出x = self.out_attention(x)  # 通过输出线性层return x

MLP实现如下:

# 定义MLP模块
class MLP(nn.Module):def __init__(self, embedding_dim, mlp_dim):super().__init__()  # 调用父类构造函数self.mlp_layers = nn.Sequential(  # 定义MLP的层nn.Linear(embedding_dim, mlp_dim),nn.GELU(),  # GELU激活函数nn.Dropout(0.1),  # Dropout层,用于正则化nn.Linear(mlp_dim, embedding_dim),  # 线性层nn.Dropout(0.1)  # Dropout层)def forward(self, x):x = self.mlp_layers(x)  # 通过MLP层return x

一个Transformer编码器块由归一化层,多头注意力,MLP和残差连接组成,实现如下:

# 定义Transformer编码器块
class TransformerEncoderBlock(nn.Module):def __init__(self, embedding_dim, head_num, mlp_dim):super().__init__()  # 调用父类构造函数self.multi_head_attention = MultiHeadAttention(embedding_dim, head_num)  # 多头注意力模块self.mlp = MLP(embedding_dim, mlp_dim)  # MLP模块self.layer_norm1 = nn.LayerNorm(embedding_dim)  # 第一层归一化self.layer_norm2 = nn.LayerNorm(embedding_dim)  # 第二层归一化self.dropout = nn.Dropout(0.1)  # Dropout层def forward(self, x):_x = self.multi_head_attention(x)  # 通过多头注意力模块_x = self.dropout(_x)  # 应用dropoutx = x + _x  # 残差连接x = self.layer_norm1(x)  # 第一层归一化_x = self.mlp(x)  # 通过MLP模块x = x + _x  # 残差连接x = self.layer_norm2(x)  # 第二层归一化return x

Transformer 编码器由多层Transformer块堆叠而成,其中block_num代表的就是堆叠的层数

# 定义Transformer编码器
class TransformerEncoder(nn.Module):def __init__(self, embedding_dim, head_num, mlp_dim, block_num=12):super().__init__()  # 调用父类构造函数self.layer_blocks = nn.ModuleList([  # 创建一个模块列表,包含多个编码器块TransformerEncoderBlock(embedding_dim, head_num, mlp_dim) for _ in range(block_num)])def forward(self, x):for layer_block in self.layer_blocks:  # 遍历每个编码器块x = layer_block(x)  # 通过每个块return x

vit的全部模块已经实现,下面就vit整体结构了。

vit的整体结构就是先将输入图片划分patches,然后将patches做embedding。

vit的分类头是一组额外添加的cl-token,将这个class-token复制batches遍,之后就可以将复制后的class_Token拼接到之前的embedding上了。

之后需要把位置编码加到这个embedding上。

这样,输入的图像特征就被处理好了,转换成了输入给Transformer块的形式。

之后只要输入一个Transformer编码器和一个MLP头,就可以得到vit的输出结果。

如果是分类任务,则class_token就是分类结果,如果不是分类任务,比如分割或者vit作为一个模块,那么输出的就是patches形式的特征图。

# 定义ViT模型
class ViT(nn.Module):def __init__(self, img_dim, in_channels, embedding_dim, head_num, mlp_dim, block_num, patch_dim, classification=True, num_classes=1):super().__init__()  # 调用父类构造函数self.patch_dim = patch_dim  # 定义patch的维度self.classification = classification  # 是否进行分类self.num_tokens = (img_dim // patch_dim) ** 2  # 计算tokens的数量self.token_dim = in_channels * (patch_dim ** 2)  # 计算每个token的维度self.projection = nn.Linear(self.token_dim, embedding_dim)  # 线性层,用于将patches投影到embedding空间self.embedding = nn.Parameter(torch.rand(self.num_tokens + 1, embedding_dim))  # 可学习的embeddingself.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))  # 类别tokenself.dropout = nn.Dropout(0.1)  # Dropout层self.transformer = TransformerEncoder(embedding_dim, head_num, mlp_dim, block_num)  # Transformer编码器if self.classification:  # 如果是分类任务self.mlp_head = nn.Linear(embedding_dim, num_classes)  # 分类头def forward(self, x):img_patches = rearrange(x,  # 将输入图像重塑为patches序列'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',patch_x=self.patch_dim, patch_y=self.patch_dim)batch_size, tokens, _ = img_patches.shape  # 获取批次大小、tokens数量和通道数project = self.projection(img_patches)  # 将patches投影到embedding空间token = repeat(self.cls_token, 'b ... -> (b batch_size) ...', batch_size=batch_size)  # 重复cls_token以匹配批次大小patches = torch.cat((token, project), dim=1)  # 将cls_token和投影后的patches拼接patches += self.embedding[:tokens + 1, :]  # 将可学习的embedding添加到patchesx = self.dropout(patches)  # 应用dropoutx = self.transformer(x)  # 通过Transformer编码器x = self.mlp_head(x[:, 0, :]) if self.classification else x[:, 1:, :]  # 如果是分类任务,使用cls_token的输出;否则,使用patches的输出return x

4,实现解码器的模块

解码器的模块就是卷积模块,接受两个输入:上采样而来的特征图以及skip-connection来的特征图

# 定义解码器中的瓶颈层
class DecoderBottleneck(nn.Module):def __init__(self, in_channels, out_channels, scale_factor=2):super().__init__()  # 初始化父类self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=True)  # 上采样层self.layer = nn.Sequential(  # 解码器层nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))def forward(self, x, x_concat=None):x = self.upsample(x)  # 上采样操作if x_concat is not None:  # 如果有额外的特征图进行拼接x = torch.cat([x_concat, x], dim=1)  # 在通道维度上拼接x = self.layer(x)  # 通过解码器层return x

5,组装成模型

所有模块都已经定义完成,下面拿这些模块来组装成模型。

编码器分支由三个卷积模块和一个vit模块组成,输出的x为解码分支最终的特征图,x1,x2,x3分别是三个卷积模块的输出

# 定义编码器
class Encoder(nn.Module):def __init__(self, img_dim, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim):super().__init__()  # 初始化父类self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3, bias=False)  # 第一个卷积层self.norm1 = nn.BatchNorm2d(out_channels)  # 第一个批量归一化层self.relu = nn.ReLU(inplace=True)  # ReLU激活函数self.encoder1 = EncoderBottleneck(out_channels, out_channels * 2, stride=2)  # 第一个编码器瓶颈层self.encoder2 = EncoderBottleneck(out_channels * 2, out_channels * 4, stride=2)  # 第二个编码器瓶颈层self.encoder3 = EncoderBottleneck(out_channels * 4, out_channels * 8, stride=2)  # 第三个编码器瓶颈层self.vit_img_dim = img_dim // patch_dim  # ViT的图像维度self.vit = ViT(self.vit_img_dim, out_channels * 8, out_channels * 8,  # ViT模型head_num, mlp_dim, block_num, patch_dim=1, classification=False)self.conv2 = nn.Conv2d(out_channels * 8, 512, kernel_size=3, stride=1, padding=1)  # 第四个卷积层self.norm2 = nn.BatchNorm2d(512)  # 第四个批量归一化层def forward(self, x):x = self.conv1(x)  # 第一个卷积操作x = self.norm1(x)  # 第一个批量归一化x1 = self.relu(x)  # ReLU激活x2 = self.encoder1(x1)  # 第一个编码器瓶颈层x3 = self.encoder2(x2)  # 第二个编码器瓶颈层x = self.encoder3(x3)  # 第三个编码器瓶颈层x = self.vit(x)  # 通过ViT模型x = rearrange(x, "b (x y) c -> b c x y", x=self.vit_img_dim, y=self.vit_img_dim)  # 重塑特征图x = self.conv2(x)  # 第四个卷积操作x = self.norm2(x)  # 第四个批量归一化x = self.relu(x)  # ReLU激活return x, x1, x2, x3  # 返回多个特征图

解码分支接受编码分支的输出x,以及三个卷积模块的输出x1,x2,x3,

# 定义解码器
class Decoder(nn.Module):def __init__(self, out_channels, class_num):super().__init__()  # 初始化父类self.decoder1 = DecoderBottleneck(out_channels * 8, out_channels * 2)  # 第一个解码器瓶颈层self.decoder2 = DecoderBottleneck(out_channels * 4, out_channels)  # 第二个解码器瓶颈层self.decoder3 = DecoderBottleneck(out_channels * 2, int(out_channels * 1 / 2))  # 第三个解码器瓶颈层self.decoder4 = DecoderBottleneck(int(out_channels * 1 / 2), int(out_channels * 1 / 8))  # 第四个解码器瓶颈层self.conv1 = nn.Conv2d(int(out_channels * 1 / 8), class_num, kernel_size=1)  # 最后一个卷积层,用于输出def forward(self, x, x1, x2, x3):x = self.decoder1(x, x3)  # 第一个解码器瓶颈层x = self.decoder2(x, x2)  # 第二个解码器瓶颈层x = self.decoder3(x, x1)  # 第三个解码器瓶颈层x = self.decoder4(x)  # 第四个解码器瓶颈层x = self.conv1(x)  # 最后一个卷积层return x  # 返回解码器的输出

整个模型结构:

# 定义TransUNet模型
class TransUNet(nn.Module):def __init__(self, img_dim, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim, class_num):super().__init__()  # 初始化父类self.encoder = Encoder(img_dim, in_channels, out_channels,  # 初始化编码器head_num, mlp_dim, block_num, patch_dim)self.decoder = Decoder(out_channels, class_num)  # 初始化解码器def forward(self, x):x, x1, x2, x3 = self.encoder(x)  # 编码分支x = self.decoder(x, x1, x2, x3)  # 解码分支return x  # 返回最终输出

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/735293.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

逼疯快递员的送货上门,谁来背锅?

快递上门的问题近几年来一直争论不休。 最近,随着新修订的《快递市场管理办法》正式实施,这个话题又成为了焦点。 消费者希望快递能够送上门省去麻烦,快递员希望统一送到代收点提高效率。 是消费者要求太高?快递员太过怠慢&…

LightDB24.1插件oracle_fdw需要支持oracle.date和oracle.varchar2类型

背景介绍 oracle.date和oracle.varchar2是LightDB中新增的类型,对应于Oracle数据库的date和varchar2类型。oracle_fdw是第三方插件,所以不支持oracle.date和oracle.varchar2类型。从LightDB24.1版本开始,oracle_fdw插件开始支持oracle.date和…

理解自相关图AC和偏自相关图PAC Plots

when we talk about the time-series data, many factors affect the time series, but the only thing that affects the lagged version of the variable is the time series data itself. by Yugesh Verma 时序数据按照时间点的先后顺序进行排列,变化是在邻近的时间段之间发…

2.1基本算法之枚举1978:生理周期

人生来就有三个生理周期,分别为体力、感情和智力周期,它们的周期长度为23天、28天和33天。每一个周期中有一天是高峰。在高峰这天,人会在相应的方面表现出色。例如,智力周期的高峰,人会思维敏捷,精力容易高…

[金三银四] 系统调用相关

2.36 系统调用的详细流程 Linux 在x86上的系统调用通过 int 0x80 实现,用系统调用号来区分入口函数。操作系统实现系统调用的基本过程是: 应用程序调用库函数(API);API 将系统调用号存入寄存器(EAX&#…

CKA备考攻略:掌握Pod日志收集,事半功倍的秘诀!

往期精彩文章 : 提升CKA考试胜算:一文带你全面了解RBAC权限控制!揭秘高效运维:如何用kubectl top命令实时监控K8s资源使用情况?CKA认证必备:掌握k8s网络策略的关键要点提高CKA认证成功率,CKA真题中的节点维…

稳定性三——wachdog机制与分析发方法

文章目录 1. 介绍2 watchdog 机制2.1 初始化2.2 添加Watchdog监测对象2.3 监测机制 3 问题分析3.1 日志分类3.2 定位3.3 场景还原 4. 实例分析5. 总结 1. 介绍 最早引入Watchdog是在单片机系统中,由于单片机的工作环境容易受到外界磁场的干扰,导致程序“…

uniapp上拉加载、下拉刷新

我这个是自定义header、main、和footer的布局&#xff0c;是盒子中的上拉加载、下拉刷新&#xff0c;不是页面的&#xff0c;废话不说&#xff0c;直接上代码&#xff01; <template><view class"assembly"><u-navbar title"个人中心" lef…

2.JavaWebMySql基础

导语&#xff1a; 一、数据库基本概念 1.什么是数据库 2.关于MySql数据库 二、MySQL的安装与卸载 安装步骤&#xff1a; 卸载步骤&#xff1a; 三、MySQL服务操作 1.服务启动和关闭&#xff1a; 2.登录和退出MySQL&#xff1a; 3.服务自启动&#xff1a; 4.命令行登…

Python实现线性查找算法

Python实现线性查找算法 以下是使用 Python 实现线性查找算法的示例代码&#xff1a; def linear_search(arr, target):"""线性查找算法:param arr: 要搜索的数组:param target: 目标值:return: 如果找到目标值&#xff0c;返回其索引&#xff1b;否则返回 -1…

linux系统 QT 处理键盘Ctrl+C信号

linux系统 QT 处理键盘CtrlC信号 1 设置CtrlC信号处理函数 CtrlC运行 &#xff0c;serialPort不能用 .h public:explicit axisControl(axisInfo *axisinf,QWidget *parent nullptr);~axisControl();// 成员函数的CtrlC信号处理程序static void handleCtrlC(int signal);//…

【玩转Linux】有关Linux权限

目录 一.Linux权限的概念 1. 权限的本质 2.Linux中的用户 3.Linux中的权限管理 (1)文件访问者的分类 (2)文件类型和访问权限&#xff08;事物属性&#xff09; ①文件基本权限 ②文件权限值的表示方法 (3)文件访问权限的相关设置方法 ① 用 户 表 示 符 / - 权 …

EKF+PF的MATLAB例程

EKF+PF 扩展卡尔曼滤波与粒子滤波的MATLAB程序,有中文注释 程序源码 % EKF+PF效果对比 % author:Evand % 作者联系方式:evandjiang@qq.com(除前期达成一致外,咨询需付费) % date: 2024-1-10 % Ver2 clear;clc;close all; rng(0); %% 参数设置 N = 100; %粒子总数

c++之迭代器与反向迭代器

&#xff09; 正向迭代器迭代器的变量与typedef与模版operator()operator--()operator*()operator->() 反向迭代器模版与typedef与变量operator()operator--()operator*()operator->() 正向迭代器 以链表的迭代器为例 具体的代码以及可以看上一篇链表的文章:链表 迭代器的…

Vue3 快速上手从0到1,两小时学会【附源码】

小伙伴们好&#xff0c;欢迎关注&#xff0c;一起学习&#xff0c;无限进步 以下内容为vue3的学习笔记 项目需要使用到的依赖 npm install axios npm install nanoid vue-router npm install pinia npm install mitt 源码&#xff1a;Gitee 运行 npm install npm run dev需要运…

FastAPI静态文件映射到网页

安装了FastAPI 和 Uvicorn&#xff1a;pip install fastapi uvicorn 然后运行代码 from fastapi import FastAPI from fastapi.staticfiles import StaticFilesapp FastAPI()# 假设 dir_upload 为 "/Users/yourusername/yourprojectpath/files/" dir_upload &quo…

大唐杯学习笔记:Day10

1.1 5G网络基本架构-SA 基站 gNB可支持FDD模式,TDD模式或双模式操作&#xff1b; gNB可以通过Xn接口互联&#xff1b; gNB内部CU分为控制面和用户面分离架构&#xff1b; gNB可以由gNB-CU和一个或多个gNB-DU组成&#xff1b; gNB-CU和gNB-DU通过F1接口连接&#xff1b; …

每日OJ题_链表④_力扣23. 合并 K 个升序链表(小根堆_归并)

目录 力扣23. 合并 K 个升序链表 解析代码1&#xff08;小根堆优化&#xff09; 解析代码2&#xff08;递归_归并&#xff09; 力扣23. 合并 K 个升序链表 23. 合并 K 个升序链表 难度 困难 给你一个链表数组&#xff0c;每个链表都已经按升序排列。 请你将所有链表合并…

MacBook2024苹果免费mac电脑清理垃圾软件CleanMyMac X

CleanMyMac X是一款专业的Mac清理软件&#xff0c;具备多种强大功能。首先&#xff0c;它能够智能清理Mac磁盘上的垃圾文件和多余语言安装包&#xff0c;从而快速释放电脑内存。其次&#xff0c;CleanMyMac X可以轻松管理和升级Mac上的应用&#xff0c;同时强力卸载恶意软件并修…

windows使用pyenv

1、前言 虽然anaconda比pyenv相比有更好的python安装体验&#xff0c;但是有一个比较严重的问题的就是&#xff0c;他的python版本跨度不够大&#xff0c;一些老一些的项目的python版本找不到&#xff0c;比如py12306要求的python版本是3.6&#xff0c;在anaconda却找不到这个版…