分割模型TransNetR的pytorch代码学习笔记

这个模型在U-net的基础上融合了Transformer模块和残差网络的原理。

论文地址:https://arxiv.org/pdf/2303.07428.pdf

具体的网络结构如下:

网络的原理还是比较简单的,

编码分支用的是预训练的resnet模块,解码分支则重新设计了。

解码器分支的模块结构示意图如下:

可以看出来,就是Transformer模块和残差连接相加,然后再经过一个residual模块处理。

1,用pytorch实现时,首先要把这个解码器模块实现出来:

class residual_transformer_block(nn.Module):def __init__(self, in_c, out_c, patch_size=4, num_heads=4, num_layers=2, dim=None):super().__init__()self.ps = patch_sizeself.c1 = Conv2D(in_c, out_c)encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads)self.te = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)self.c2 = Conv2D(out_c, out_c, kernel_size=1, padding=0, act=False)self.c3 = Conv2D(in_c, out_c, kernel_size=1, padding=0, act=False)self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)self.r1 = residual_block(out_c, out_c)def forward(self, inputs):x = self.c1(inputs)b, c, h, w = x.shapenum_patches = (h*w)//(self.ps**2)x = torch.reshape(x, (b, (self.ps**2)*c, num_patches))x = self.te(x)x = torch.reshape(x, (b, c, h, w))x = self.c2(x)s = self.c3(inputs)x = self.relu(x + s)x = self.r1(x)return x

其中我们需要注意的是这里构建Transformer块的方法,也就是下面两句:

encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads)
self.te = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

首先,第一句是用nn.TransformerEncoderLayer定义了一个Transformer层,并存储在encoder_layer变量中。

nn.TransformerEncoderLayer的参数包括:d_model(输入特征的维度大小),nhead(自注意力机制中注意力头数量),dim_feedforward(前馈网络的隐藏层维度大小),dropout(dropout比例),apply(用于在编码器层及其子层上应用函数,例如初始化或者权重共享等功能)。

第二句则是将多个Transformer层堆叠在一起,构建一个Transformer编码器块。

nn.TransformerEncoder的参数包括:encoder_layer(用于构建模块的每个Transformer层),num_layer(堆叠的层数),norm(执行的标准化方法),apply(同上)。

2,在上面的解码器模块中,还有一个residual block需要额外实现,如下:

class residual_block(nn.Module):def __init__(self, in_c, out_c):super().__init__()self.conv = nn.Sequential(nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),nn.BatchNorm2d(out_c),nn.LeakyReLU(negative_slope=0.1, inplace=True),nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),nn.BatchNorm2d(out_c))self.shortcut = nn.Sequential(nn.Conv2d(in_c, out_c, kernel_size=1, padding=0),nn.BatchNorm2d(out_c))self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)def forward(self, inputs):x = self.conv(inputs)s = self.shortcut(inputs)return self.relu(x + s)

这个代码就是简单的残差卷积模块,不赘述。

3,重要的模块实现完了,接下来就是按照UNet的形状拼装网络,代码如下:

class Model(nn.Module):def __init__(self):super().__init__()""" Encoder """backbone = resnet50()self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)self.layer1 = nn.Sequential(backbone.maxpool, backbone.layer1)self.layer2 = backbone.layer2self.layer3 = backbone.layer3self.layer4 = backbone.layer4self.e1 = Conv2D(64, 64, kernel_size=1, padding=0)self.e2 = Conv2D(256, 64, kernel_size=1, padding=0)self.e3 = Conv2D(512, 64, kernel_size=1, padding=0)self.e4 = Conv2D(1024, 64, kernel_size=1, padding=0)""" Decoder """self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)self.r1 = residual_transformer_block(64+64, 64, dim=64)self.r2 = residual_transformer_block(64+64, 64, dim=256)self.r3 = residual_block(64+64, 64)""" Classifier """self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)def forward(self, inputs):""" Encoder """x0 = inputsx1 = self.layer0(x0)    ## [-1, 64, h/2, w/2]x2 = self.layer1(x1)    ## [-1, 256, h/4, w/4]x3 = self.layer2(x2)    ## [-1, 512, h/8, w/8]x4 = self.layer3(x3)    ## [-1, 1024, h/16, w/16]e1 = self.e1(x1)e2 = self.e2(x2)e3 = self.e3(x3)e4 = self.e4(x4)""" Decoder """x = self.up(e4)x = torch.cat([x, e3], axis=1)x = self.r1(x)x = self.up(x)x = torch.cat([x, e2], axis=1)x = self.r2(x)x = self.up(x)x = torch.cat([x, e1], axis=1)x = self.r3(x)x = self.up(x)""" Classifier """outputs = self.outputs(x)return outputs

其中,x1,x2,x3,x4就是编码器模块,用的都是resnet50的预训练模块。

其中r1,r2,r3,r4则是解码器的模块,就是上面实现的模块。

而e1,e2,e3,e4则是在skip connection前给编码器的输出做1x1卷积,作用大体上就是减少计算量。

完整代码:https://github.com/DebeshJha/TransNetR/blob/main/model.py#L45

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/734825.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PyTorch搭建LeNet训练集详细实现

一、下载训练集 导包 import torch import torchvision import torch.nn as nn from model import LeNet import torch.optim as optim import torchvision.transforms as transforms import matplotlib.pyplot as plt import numpy as npToTensor()函数: 把图像…

git学习(创建项目提交代码)

操作步骤如下 git init //初始化git remote add origin https://gitee.com/aydvvs.git //建立连接git remote -v //查看git add . //添加到暂存区git push 返送到暂存区git status // 查看提交代码git commit -m初次提交git push -u origin "master"//提交远程分支 …

微信小程序(五十二)开屏页面效果

注释很详细&#xff0c;直接上代码 上一篇 新增内容&#xff1a; 1.使用控件模拟开屏界面 2.倒计时逻辑 3.布局方法 4.TabBar隐藏复现 源码&#xff1a; components/openPage/openPage.wxml <view class"openPage-box"><image src"{{imagePath}}"…

三维不同坐标系下点位姿态旋转平移变换

文章目录 前言正文计算方法思路Python实现总结前言 本文主要说明以下几种场景3D变换的应用: 3D相机坐标系下长方体物体,有本身坐标系,沿该物体长边方向移动一段距离,并绕长边轴正旋转方向转90度,求解当前物体中心点在相机坐标系下的位置和姿态多关节机器人末端沿工具坐标…

STM32 利用FlashDB库实现在线扇区数据管理不丢失

STM32 利用FlashDB库实现在线扇区数据管理不丢失 &#x1f4cd;FalshDB地址:https://gitee.com/Armink/FlashDB ✨STM32没有片内EEPROM这样的存储区&#xff0c;虽然有备份寄存器&#xff0c;仅可以实现对少量数据的频繁存储&#xff0c;但是依赖备份电源&#xff08;BAT引脚&a…

美国签证|附面签相关事项√

小伙伴最近都忙着办签证吧&#xff01;但是需要注意的是&#xff0c;美国的签证跟其他任何国家的签证不同&#xff0c;并不是办理了就一定拿得到&#xff0c;据说概率是50%左右。所以办理美国签证&#xff0c;不要太着急啦&#xff01;先来了解一下美国签证的相片该怎么拍叭 ✅…

RocketMQ的事务消息流程

什么是事务消息&#xff1f; 事务消息是一种在发送方和接收方之间保证消息传递的一致性和可靠性的消息传递机制。在消息发送过程中&#xff0c;生产者可以将消息发送到消息队列&#xff0c;但不会立即被消费者接收和处理。相反&#xff0c;消息会先进入一种“准备”状态&#x…

用chatgpt写insar地质灾害的论文,重复率只有1.8%,chatgpt4.0写论文不是梦

突发奇想&#xff0c;想用chatgpt写一篇论文&#xff0c;并看看查重率&#xff0c;结果很惊艳&#xff0c;说明是确实可行的&#xff0c;请看下图。 下面是完整的文字内容。 InSAR (Interferometric Synthetic Aperture Radar) 地质灾害监测技术是一种基于合成孔径雷达…

【JavaScript】JavaScript 变量 ① ( JavaScript 变量概念 | 变量声明 | 变量类型 | 变量初始化 | ES6 简介 )

文章目录 一、JavaScript 变量1、变量概念2、变量声明3、ES6 简介4、变量类型5、变量初始化 二、JavaScript 变量示例1、代码示例2、展示效果 一、JavaScript 变量 1、变量概念 JavaScript 变量 是用于 存储数据 的 容器 , 通过 变量名称 , 可以 获取 / 修改 变量 中的数据 ; …

第十五届蓝桥杯模拟赛(第三期)

大家好&#xff0c;我是晴天学长&#xff0c;本次分享&#xff0c;制作不易&#xff0c;本次题解只用于学习用途&#xff0c;如果有考试需要的小伙伴请考完试再来看题解进行学习&#xff0c;需要的小伙伴可以点赞关注评论一波哦&#xff01;蓝桥杯省赛就要开始了&#xff0c;祝…

【DimPlot】【FeaturePlot】使用小tips

目录 DimPlot函数参数解析 栅格化点图 放大 ggplot2 图例的点&#xff0c;修改图例的标题 FeaturePlot函数参数解析 调整FeaturePlot颜色 分组绘制featureplot 随手笔记&#xff0c;持续更新中。。。 Reference DimPlot函数参数解析 object: 一个Seurat对象&#xff0c;…

工作纪实46-关于微服务的上线发布姿势

蓝绿部署 在部署时&#xff0c;不需要将旧版本的服务停掉&#xff0c;而是将新版本与旧版本同时运行&#xff0c;新版本测试无误之后再将旧版本停掉。这样可以避免再升级的过程中如果失败服务不可用的问题&#xff0c;因为同时部署了两个版本的程序&#xff0c;使得硬件资源是…

【项目笔记】java微服务:黑马头条(day01)

文章目录 环境搭建、SpringCloud微服务(注册发现、服务调用、网关)1)课程对比2)项目概述2.1)能让你收获什么2.2)项目课程大纲2.3)项目概述2.4)项目术语2.5)业务说明 3)技术栈4)nacos环境搭建4.1)虚拟机镜像准备4.2)nacos安装 5)初始工程搭建5.1)环境准备5.2)主体结构 6)登录6.1…

JavaScript中的Set和Map:理解与使用

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

C++:类和对象(三)——拷贝构造函数和运算符重载

目录 一、拷贝构造函数 1.概念 2.特性 二、赋值运算符重载 1.运算符重载 2.赋值运算符重载 &#xff08;1&#xff09;注意的点&#xff1a; &#xff08;2&#xff09;赋值运算符不允许被重载为全局函数&#xff0c;只能重载为类的成员函数 &#xff08;3&#xff09;…

C++ 字符串OJ

目录 1、14. 最长公共前缀 2、 5. 最长回文子串 3、 67. 二进制求和 4、43. 字符串相乘 1、14. 最长公共前缀 思路一&#xff1a;两两字符串进行比较&#xff0c;每次比较过程相同&#xff0c;可以添加一个函数辅助比较&#xff0c;查找最长公共前缀。 class Solution { pu…

【C++】函数模板和类模板

目录 1.泛型编程 2.函数模板 2.1函数模板的定义格式 2.2函数模板的实例化 2.3函数模板参数的匹配原则 3.类模板 3.1类模板的定义格式 3.2类模板的实例化 3.3模板的分离编译 1.泛型编程 泛型编程&#xff1a;编写与类型无关的通用代码&#xff0c;是代码复用的一种手段…

【前端CSS】CSS的3种基本选择器和5种高级选择器使用方式

目录 前言 基本选择器 1.1 标签选择器 1.2 ID选择器 1.3 类选择器 高级选择器 2.1 并集选择器 2.2 交集选择器 2.3 后代选择器 2.4 子元素选择器 2.5 属性选择器 前言 1W&#xff1a;什么是CSS选择器&#xff1f; CSS选择器由HTML元素的id、class属性或元素名本身以及…

SpringBoot中定时任务、corn表达式

SpringBoot中定时任务、corn表达式 corn表达式网站&#xff1a;https://cron.qqe2.com/ 方法上加上Scheduled(cron表达式) 启动类上加上EnableScheduling 示例 启动类上 启动类加上EnableScheduling开启定时任务。 SpringBootApplication EnableScheduling public class…

vue 在什么情况下在数据发生改变的时候不会触发视图更新

在 Vue 中&#xff0c;通常数据发生变化时&#xff0c;视图会自动更新。但是&#xff0c;有几种情况可能导致数据变化不会触发视图更新&#xff1a; 1.对象属性的添加或删除&#xff1a; Vue 无法检测到对象属性的添加或删除。因为 Vue 在初始化实例时对属性执行了 getter/se…