图像语义分割 pytorch复现U2Net图像分割网络详解

图像语义分割 pytorch复现U2Net图像分割网络详解

  • 1、U2Net网络模型结构
  • 2、block模块结构解析
    • RSU-7模块
    • RSU-4F
    • saliency map fusion module
  • U2Net网络结构详细参数配置
  • RSU模块代码实现
  • RSU4F模块代码实现
  • u2net_full与u2net_lite模型配置函数
  • U2Net网络整体定义类
  • 损失函数计算
  • 评价指标
  • 数据集
  • pytorch训练U2Net图像分割模型

在这里插入图片描述
U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection

1、U2Net网络模型结构

在这里插入图片描述
网络的主体类似于U-Net的网络结构,在大的U-Net中,每一个小的block都是一个小型的类似于U-Net的结构,因此作者取名U2Net
仔细观察,可以将网络中的block分成两类:
第一类:En_1 ~ En_4 与 De_1 ~ De_4这8个block采用的block其实是一样的,只不过模块的深度不同。

第二类:En_5、En_6、De_5

  • 在整个U2Net网络中,在Encoder阶段,每通过一个block都会进行一次下采样操作(下采样2倍,maxpool)
  • 在Decoder阶段,在每个block之间,都会进行一次上采样(2倍,bilinear)

2、block模块结构解析

在 En_1 与 De_1 模块中,采用的 block 是RSU-7;
En_2 与 De_2采用的 block 是RSU-6(RSU-6相对于RSU-7 就是少了一个下采样卷积以及上采样卷积的部分,RSU-6 block只会下采样16倍,RSU-7 block下采样的32倍);
En_3 与 De_3采用的 block 是RSU-5
En_4 与 De_4采用的 block 是RSU-4
En_5、En_6、De_5采用的block是RSU-4F
(使用RSU-4F的原因:因为数据经过En_1 ~ En4 下采样处理后对应特征图的高与宽就已经相对比较小了,如果再继续下采样就会丢失很多上下文信息,作者为了保留上下文信息,就对En_5、En_6、De_5不再进行下采样了而是在RSU-4F的模块中,将下采样、上采样结构换成了膨胀卷积)

RSU-7模块

在这里插入图片描述详细结构图解
在这里插入图片描述

RSU-4F

在这里插入图片描述

saliency map fusion module

saliency map fusion module模块是将每个阶段的特征图进行融合,得到最终的预测概率图,即下图中,红色框标注的模块
在这里插入图片描述
其会收集De_1、De_2、De_3、De_4、De_5、En_6模块的输出,将这些输出分别通过一个3x3的卷积层(这些卷积层的kerner的个数都是为1)输出的featuremap的channel是为1的,在经过双线性插值算法将得到的特征图还原回输入图像的大小;再将得到的6个特征图进行concant拼接;在经过一个1x1的卷积层以及sigmoid激活函数,最终得到融合之后的预测概率图。

U2Net网络结构详细参数配置

在这里插入图片描述
u2net_full大小为176.3M、u2net_lite大小为4.7M

RSU模块代码实现

在这里插入图片描述

class RSU(nn.Module):def __init__(self, height: int, in_ch: int, mid_ch: int, out_ch: int):super().__init__()assert height >= 2self.conv_in = ConvBNReLU(in_ch, out_ch)encode_list = [DownConvBNReLU(out_ch, mid_ch, flag=False)]decode_list = [UpConvBNReLU(mid_ch * 2, mid_ch, flag=False)]for i in range(height - 2):encode_list.append(DownConvBNReLU(mid_ch, mid_ch))decode_list.append(UpConvBNReLU(mid_ch * 2, mid_ch if i < height - 3 else out_ch))encode_list.append(ConvBNReLU(mid_ch, mid_ch, dilation=2))self.encode_modules = nn.ModuleList(encode_list)self.decode_modules = nn.ModuleList(decode_list)def forward(self, x: torch.Tensor) -> torch.Tensor:x_in = self.conv_in(x)x = x_inencode_outputs = []for m in self.encode_modules:x = m(x)encode_outputs.append(x)x = encode_outputs.pop()for m in self.decode_modules:x2 = encode_outputs.pop()x = m(x, x2)return x + x_in

RSU4F模块代码实现

在这里插入图片描述

class RSU4F(nn.Module):def __init__(self, in_ch: int, mid_ch: int, out_ch: int):super().__init__()self.conv_in = ConvBNReLU(in_ch, out_ch)self.encode_modules = nn.ModuleList([ConvBNReLU(out_ch, mid_ch),ConvBNReLU(mid_ch, mid_ch, dilation=2),ConvBNReLU(mid_ch, mid_ch, dilation=4),ConvBNReLU(mid_ch, mid_ch, dilation=8)])self.decode_modules = nn.ModuleList([ConvBNReLU(mid_ch * 2, mid_ch, dilation=4),ConvBNReLU(mid_ch * 2, mid_ch, dilation=2),ConvBNReLU(mid_ch * 2, out_ch)])def forward(self, x: torch.Tensor) -> torch.Tensor:x_in = self.conv_in(x)x = x_inencode_outputs = []for m in self.encode_modules:x = m(x)encode_outputs.append(x)x = encode_outputs.pop()for m in self.decode_modules:x2 = encode_outputs.pop()x = m(torch.cat([x, x2], dim=1))return x + x_in

u2net_full与u2net_lite模型配置函数

def u2net_full(out_ch: int = 1):cfg = {# height, in_ch, mid_ch, out_ch, RSU4F, side     side:表示是否要收集当前block的输出"encode": [[7, 3, 32, 64, False, False],      # En1[6, 64, 32, 128, False, False],    # En2[5, 128, 64, 256, False, False],   # En3[4, 256, 128, 512, False, False],  # En4[4, 512, 256, 512, True, False],   # En5[4, 512, 256, 512, True, True]],   # En6# height, in_ch, mid_ch, out_ch, RSU4F, side"decode": [[4, 1024, 256, 512, True, True],   # De5[4, 1024, 128, 256, False, True],  # De4[5, 512, 64, 128, False, True],    # De3[6, 256, 32, 64, False, True],     # De2[7, 128, 16, 64, False, True]]     # De1}return U2Net(cfg, out_ch)def u2net_lite(out_ch: int = 1):cfg = {# height, in_ch, mid_ch, out_ch, RSU4F, side"encode": [[7, 3, 16, 64, False, False],  # En1[6, 64, 16, 64, False, False],  # En2[5, 64, 16, 64, False, False],  # En3[4, 64, 16, 64, False, False],  # En4[4, 64, 16, 64, True, False],  # En5[4, 64, 16, 64, True, True]],  # En6# height, in_ch, mid_ch, out_ch, RSU4F, side"decode": [[4, 128, 16, 64, True, True],  # De5[4, 128, 16, 64, False, True],  # De4[5, 128, 16, 64, False, True],  # De3[6, 128, 16, 64, False, True],  # De2[7, 128, 16, 64, False, True]]  # De1}

U2Net网络整体定义类

class U2Net(nn.Module):def __init__(self, cfg: dict, out_ch: int = 1):super().__init__()assert "encode" in cfgassert "decode" in cfgself.encode_num = len(cfg["encode"])encode_list = []side_list = []for c in cfg["encode"]:# c: [height, in_ch, mid_ch, out_ch, RSU4F, side]assert len(c) == 6encode_list.append(RSU(*c[:4]) if c[4] is False else RSU4F(*c[1:4]))     # 判断当前是构建RSU模块,还是构建RSU4F模块if c[5] is True:side_list.append(nn.Conv2d(c[3], out_ch, kernel_size=3, padding=1))self.encode_modules = nn.ModuleList(encode_list)decode_list = []for c in cfg["decode"]:# c: [height, in_ch, mid_ch, out_ch, RSU4F, side]assert len(c) == 6decode_list.append(RSU(*c[:4]) if c[4] is False else RSU4F(*c[1:4]))if c[5] is True:side_list.append(nn.Conv2d(c[3], out_ch, kernel_size=3, padding=1))    # 收集当前block的输出self.decode_modules = nn.ModuleList(decode_list)self.side_modules = nn.ModuleList(side_list)self.out_conv = nn.Conv2d(self.encode_num * out_ch, out_ch, kernel_size=1)   # 构建一个1x1的卷积层,去融合来自不同尺度的信息def forward(self, x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:_, _, h, w = x.shape# collect encode outputsencode_outputs = []for i, m in enumerate(self.encode_modules):x = m(x)encode_outputs.append(x)if i != self.encode_num - 1:  # 此处需要进行判断,因为在没通过一个encoder模块后,都需要进行下采样的,但最后一个模块后,是不需要下采样的x = F.max_pool2d(x, kernel_size=2, stride=2, ceil_mode=True)# collect decode outputsx = encode_outputs.pop()decode_outputs = [x]for m in self.decode_modules:x2 = encode_outputs.pop()x = F.interpolate(x, size=x2.shape[2:], mode='bilinear', align_corners=False)x = m(torch.concat([x, x2], dim=1))decode_outputs.insert(0, x)# collect side outputsside_outputs = []for m in self.side_modules:x = decode_outputs.pop()x = F.interpolate(m(x), size=[h, w], mode='bilinear', align_corners=False)side_outputs.insert(0, x)x = self.out_conv(torch.concat(side_outputs, dim=1))if self.training:# do not use torch.sigmoid for amp safereturn [x] + side_outputs     # 用于计算损失else:return torch.sigmoid(x)

损失函数计算

在这里插入图片描述
如上图所示,红色框部分为每个分量与真实标签的交叉熵损失函数求和;黄色框标部分为将各个分量经双线性插值恢复至原始尺寸、进行concant处理、经过1x1的卷积核与sigmoid处理后的结果与真实标签的交叉熵损失函数。
损失函数代码实现:

import math
import torch
from torch.nn import functional as F
import train_utils.distributed_utils as utilsdef criterion(inputs, target):losses = [F.binary_cross_entropy_with_logits(inputs[i], target) for i in range(len(inputs))]total_loss = sum(losses)return total_loss

评价指标

在这里插入图片描述
其中F-measure是在0~1之间的,数值越大,代表的网络分割效果越好;
MAE是Mean Absolute Error的缩写,其值是在0~1之间的,越趋近于0,代表网络性能越好。

数据集

在这里插入图片描述
在这里插入图片描述

pytorch训练U2Net图像分割模型

项目目录结构:

├── src: 搭建网络相关代码
├── train_utils: 训练以及验证相关代码
├── my_dataset.py: 自定义数据集读取相关代码
├── predict.py: 简易的预测代码
├── train.py: 单GPU或CPU训练代码
├── train_multi_GPU.py: 多GPU并行训练代码
├── validation.py: 单独验证模型相关代码
├── transforms.py: 数据预处理相关代码
└── requirements.txt: 项目依赖

项目目录:
在这里插入图片描述
项目中u2net_full大小为176.3M、u2net_lite大小为4.7M,演示过程中,训练的为u2net_lite版本
多GPU训练指令:
pytorch版本为1.7

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --use_env train_multi_GPU.py --data-path ./data_root

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/107759.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Unity之ShaderGraph如何实现上下溶解

前言 我们经常在电影中见到的一个物体或者人物&#xff0c;从头上到脚下&#xff0c;慢慢消失的效果&#xff0c;我么今天就来体验一下这个上下溶解。 主要节点 Position节点&#xff1a;提供对网格顶点或片段的Position 的访问 Step节点&#xff1a;如果输入In的值大于或等…

福昕阅读器打开pdf文档时显示的标题不是文件名

0 Preface/Foreword 1 现象 文件名为&#xff1a;Demo-20231017 打开效果&#xff1a;显示名字为 word template 2 解决方法 2.1 利用打印方式将word生产pdf 在word生产pdf文件时&#xff0c;使用打印方式生成pdf文档。 2.2 删除word文档设置的标题 文件---》信息---》标…

.NET Core/.NET6 使用DbContext 连接数据库,SqlServer

安装以下NuGet包 Microsoft.EntityFrameworkCore.SqlServer&#xff1a;SQL server 需要添加包 Microsoft.EntityFrameworkCore.Tools Newtonsoft.Json&#xff1a;用于Json格式转换 创建一个实体类来表示数据库表。在项目中创建一个名为Customer.cs的文件&#xff0c;并添加以…

微信小程序--小程序框架

目录 前言&#xff1a; 一.框架基本介绍 1.整体结构&#xff1a; 2.页面结构&#xff1a; 3.生命周期&#xff1a; 4.事件系统&#xff1a; 5.数据绑定&#xff1a; 6.组件系统&#xff1a; 7.API&#xff1a; 8.路由&#xff1a; 9.模块化&#xff1a; 10.全局配置&…

运维 | 如何在 Linux 系统中删除软链接 | Linux

运维 | 如何在 Linux 系统中删除软链接 | Linux 介绍 在 Linux 中&#xff0c;符号链接&#xff08;symbolic link&#xff0c;或者symlink&#xff09;也称为软链接&#xff0c;是一种特殊类型的文件&#xff0c;用作指向另一个文件的快捷方式。 使用方法 我们可以使用 ln…

[C国演义] 第十五章

第十五章 最长湍流子数组环绕字符串中唯⼀的⼦字符串 最长湍流子数组 力扣链接 子数组 ⇒ dp[i]的含义: 以arr[i] 结尾的所有子数组中的最长湍流子数组的长度 子数组 ⇒ 状态转移方程根据 最后一个位置来划分&#x1f447;&#x1f447;&#x1f447; 初始化: 都初始化为…

电力物联网关智能通讯管理机-安科瑞黄安南

众所周知&#xff0c;网关应用于各种行业的终端设备的数据采集与数据分析&#xff0c;然后去实现设备的监测、控制、计算&#xff0c;为系统与设备之间建立通讯联系&#xff0c;达到双向的数据通讯。 网关可以实时监测并及时发现异常数据&#xff0c;同时自身根据用户规则进行…

乡村新业态 | 直播电商引领经济发展,拓世法宝AI智能直播一体机助推乡村振兴

党的二十大报告作出加快建设数字中国、全面推进乡村振兴的战略部署&#xff0c;为进一步加强数字乡村建设、全面推进乡村振兴指明了方向。近年来&#xff0c;随着乡村新业态新模式的不断涌现&#xff0c;以直播电商为代表的数字经济为各地的农村产业升级带来了新契机。各地政府…

【Android】adjustViewBounds 的理解和使用

理解 adjustViewBounds 是一个 ImageView 的属性&#xff0c;用于调整 ImageView 的边界以适应图像的尺寸。当设置为 true 时&#xff0c;ImageView 的边界将根据图像的宽高比例进行调整&#xff0c;以确保图像完全显示在 ImageView 内部。 理解和使用 adjustViewBounds 的步…

Leetcode—136.只出现一次的数字【简单】

2023每日刷题&#xff08;二&#xff09; Leetcode—136.只出现一次的数字 位运算法 实现代码 int singleNumber(int* nums, int numsSize){int i 0;int res 0;for(; i < numsSize; i) {res ^ nums[i];}return res; }运行结果 之后我会持续更新&#xff0c;如果喜欢我的…

启动速度提升 10 倍:Apache Dubbo 静态化方案深入解析

作者&#xff1a;华钟明 文章摘要&#xff1a; 本文整理自有赞中间件技术专家、Apache Dubbo PMC 华钟明的分享。本篇内容主要分为五个部分&#xff1a; -GraalVM 直面 Java 应用在云时代的挑战 -Dubbo 享受 AOT 带来的技术红利 -Dubbo Native Image 的实践和示例 -Dubbo…

中国人口文化促进会社区文化推广工作委员会成立 暨2024社区春晚文艺活动新闻发布会在京成功举办

2023年10.13日&#xff0c;下午1点&#xff0c;在北京大红门国际会展中心召开了中国人口文化促进会社区文化推广工作委员会成立暨2024社区春晚文艺活动新闻发布会。来自政府相关部门、社会组织、新闻媒体和公益企业界的相关领导与代表齐聚一堂&#xff0c;共襄盛举。 本次大会由…

POI报表的入门

POI报表的入门 理解员工管理的的业务逻辑 能够说出Eureka和Feign的作用 理解报表的两种形式和POI的基本操作熟练使用POI完成Excel的导入导出操作 员工管理 需求分析 企业员工管理是人事资源管理系统中最重要的一个环节&#xff0c;分为对员工入职&#xff0c;转正&#x…

自动驾驶:控制算法概述

自动驾驶&#xff1a;控制算法概述 常见控制算法PID算法LQR算法MPC算法 自动驾驶控制算法横向控制纵向控制 参考文献 常见控制算法 PID算法 PID&#xff08;Proportional-Integral-Derivative&#xff09;控制是一种经典的反馈控制算法&#xff0c;通常用于稳定性和响应速度要…

ue5蓝图请求接口

安装与使用 1、在虚幻商城搜索 VaRest 插件 2、选择自己项目的对应版本安装 3、查看是否安装成功 4、进入项目后&#xff0c;分别启动VaRest、JSON Blueprint Utilities两个插件&#xff08;勾选后会提示重启项目&#xff09; 5、基本用法&#xff1a;打开关卡蓝图使用&#xf…

Android Studio SDK manager加载packages不全

打开Android Studio里的SDK manager&#xff0c;发现除了已安装的&#xff0c;其他的都不显示。 解决方法&#xff1a; 设置代理&#xff1a; 方便复制> http://mirrors.neusoft.edu.cn/ 重启Android Studio

【Java学习之道】TCPIP套接字编程实例

引言 网络编程是Java学习中不可或缺的一部分&#xff0c;而TCP/IP套接字编程又是网络编程的基础。那么&#xff0c;初学者如何才能快速掌握TCP/IP套接字编程呢&#xff1f;今天我们就来通过一个简单的实例&#xff0c;为你揭示TCP/IP套接字编程的奥秘&#xff01; 一、什么是…

Sql Server 数据库中的所有已定义的唯一约束 (列名称 合并过了)

查询Sql Server Database中的唯一约束 with UniqueBasic as (SELECTtab.name AS TableName, -- 表名称idx.name AS UniqueName, -- 唯一约束的名称col.name AS UniqueFieldName -- 唯一约束的表字段FROMsys.indexes idxJOIN sys.index_columns idxColON (idx.object_id idxCo…

PyTorch 深度学习之循环神经网络(基础篇)Basic RNN(十一)

0.Revision: DNN dense 重义层 全连接 RNN处理带有序列的数据 1. What is RNNs? linear layer 1.1 What is RNN? tanh (-1, 1) 1.2 RNN Cell in PyTorch 1.3 How to use RNNCell *先把维度搞清楚 多了一个序列的维度 2. How to use RNN 2.1 How to use RNN - numLayers…

PC电脑 VMware安装的linux CentOs7如何扩容磁盘?

一、VM中进行扩容设置 必须要关闭当前CentOS&#xff0c;不然扩展按钮是灰色的。 输入值必须大于当前磁盘容量。然后点击扩展&#xff0c;等待扩展完成会提示一个弹框&#xff0c;点击确定&#xff0c;继续确定。 二、操作CentOS扩容——磁盘分区 第一步设置完成。那就启动 …