图像语义分割 U-Net图像分割网络详解

图像语义分割 U-Net图像分割网络详解

  • 简介
  • 原始论文中的网络结构
  • 在医学方面的应用
  • pytorch官方实现
  • 以DRIVE眼底血管分割数据集训练U-Net语义分割网络模型
  • U-Net网络训练损失函数

简介

U-Net网络非常的简单,前半部分就是特征提取,后半部分是上采样。在一些文献中把这种结构叫做编码器-解码器结构,由于网络的整体结构是一个大写的英文字母U,所以叫做U-Net。

  • Encoder:左半部分,由两个3x3的卷积层(RELU)再加上一个2x2的maxpooling层组成一个下采样的模块
  • Decoder:有半部分,由一个上采样的卷积层(去卷积层)+特征拼接concat+两个3x3的卷积层(ReLU)反复构成,在进行上采样的过程中,原理论文中采用转置卷积,但在实际应用中通常采用双线性插值法实现。在原始论文论文对图像进行拼接过程中,由于尺寸不匹配,所以先进行中心裁剪,得到相同尺寸特征图,再在通道上进行拼接。

原始论文中的网络结构

在这里插入图片描述

在医学方面的应用

大多数医疗影像语义分割任务都会首先用Unet作为baseline,这里谈一谈医疗影像的特点:

  • 医疗影像语义较为简单、结构固定。因此语义信息相比自动驾驶等较为单一,因此并不需要去筛选过滤无用的信息。医疗影像的所有特征都很重要,因此低级特征和高级语义特征都很重要,所以U型结构的skip connection结构(特征拼接)更好派上用场。
  • 医学影像的数据较少,获取难度大,数据量可能只有几百甚至不到100,因此如果使用大型的网络例如DeepLabv3+等模型,很容易过拟合。大型网络的优点是更强的图像表述能力,而较为简单、数量少的医学影像并没有那么多的内容需要表述,因此也有人发现在小数量级中,分割的SOTA模型与轻量的Unet并没有什么优势。
  • 医学影像任务中,往往需要自己设计网络去提取不同的模态特征,因此轻量结构简单的Unet可以有更大的操作空间

pytorch官方实现

在进行卷积的过程中,进行padding的操作,不改变图像的尺寸,所以不需要进行中心裁剪的过程,最后得到的特征图与输入原始图像尺寸上保持一致
网络结构:
在这里插入图片描述
网络结构代码实现:

from typing import Dict
import torch
import torch.nn as nn
import torch.nn.functional as Fclass DoubleConv(nn.Sequential):def __init__(self, in_channels, out_channels, mid_channels=None):if mid_channels is None:mid_channels = out_channelssuper(DoubleConv, self).__init__(nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(mid_channels),nn.ReLU(inplace=True),nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))class Down(nn.Sequential):def __init__(self, in_channels, out_channels):super(Down, self).__init__(nn.MaxPool2d(2, stride=2),DoubleConv(in_channels, out_channels))class Up(nn.Module):def __init__(self, in_channels, out_channels, bilinear=True):super(Up, self).__init__()if bilinear:self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)else:self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)self.conv = DoubleConv(in_channels, out_channels)def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:x1 = self.up(x1)# [N, C, H, W]# gqr:以下padding操作的目的是为了防止输入的图像不是16的整数倍导致在进行拼接过程时尺寸不一致的问题diff_y = x2.size()[2] - x1.size()[2]diff_x = x2.size()[3] - x1.size()[3]# padding_left, padding_right, padding_top, padding_bottomx1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,diff_y // 2, diff_y - diff_y // 2])x = torch.cat([x2, x1], dim=1)x = self.conv(x)return xclass OutConv(nn.Sequential):def __init__(self, in_channels, num_classes):super(OutConv, self).__init__(nn.Conv2d(in_channels, num_classes, kernel_size=1))class UNet(nn.Module):def __init__(self,in_channels: int = 1,num_classes: int = 2,bilinear: bool = True,base_c: int = 64):super(UNet, self).__init__()self.in_channels = in_channelsself.num_classes = num_classesself.bilinear = bilinearself.in_conv = DoubleConv(in_channels, base_c)self.down1 = Down(base_c, base_c * 2)self.down2 = Down(base_c * 2, base_c * 4)self.down3 = Down(base_c * 4, base_c * 8)factor = 2 if bilinear else 1self.down4 = Down(base_c * 8, base_c * 16 // factor)self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear)self.up2 = Up(base_c * 8, base_c * 4 // factor, bilinear)self.up3 = Up(base_c * 4, base_c * 2 // factor, bilinear)self.up4 = Up(base_c * 2, base_c, bilinear)self.out_conv = OutConv(base_c, num_classes)def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:x1 = self.in_conv(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)logits = self.out_conv(x)return {"out": logits}

以DRIVE眼底血管分割数据集训练U-Net语义分割网络模型

数据集目录结构:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
数据预处理代码:
注意:在进行语义分割时,前景像素值要从1开始

import os
from PIL import Image
import numpy as np
from torch.utils.data import Datasetclass DriveDataset(Dataset):def __init__(self, root: str, train: bool, transforms=None):super(DriveDataset, self).__init__()self.flag = "training" if train else "test"data_root = os.path.join(root, "DRIVE", self.flag)assert os.path.exists(data_root), f"path '{data_root}' does not exists."self.transforms = transformsimg_names = [i for i in os.listdir(os.path.join(data_root, "images")) if i.endswith(".tif")]self.img_list = [os.path.join(data_root, "images", i) for i in img_names]self.manual = [os.path.join(data_root, "1st_manual", i.split("_")[0] + "_manual1.gif")for i in img_names]# check filesfor i in self.manual:if os.path.exists(i) is False:raise FileNotFoundError(f"file {i} does not exists.")self.roi_mask = [os.path.join(data_root, "mask", i.split("_")[0] + f"_{self.flag}_mask.gif")for i in img_names]# check filesfor i in self.roi_mask:if os.path.exists(i) is False:raise FileNotFoundError(f"file {i} does not exists.")def __getitem__(self, idx):img = Image.open(self.img_list[idx]).convert('RGB')manual = Image.open(self.manual[idx]).convert('L')   # gqr:转换得到灰度图后,前景的为255,背景的像素为0manual = np.array(manual) / 255  # gqr:将数据进行归一化,前景的为1;背景的像素为0;在进行语义分割时,前景像素值要从1开始roi_mask = Image.open(self.roi_mask[idx]).convert('L')  # gqr:转换成灰度图,感兴趣区域为255;不感兴趣区域是0roi_mask = 255 - np.array(roi_mask)   # gqr:将感兴趣的区域设置为0,不感兴趣的区域设置为255,这样在计算损失时可以排除掉像素为255的区域mask = np.clip(manual + roi_mask, a_min=0, a_max=255)   # gqr:想加后,需要分割的部分为1,背景为0,还有为255的不感兴趣区域"""print(np.unique(mask)):输出结果为:[  0.   1. 255.]"""# 这里转回PIL的原因是,transforms中是对PIL数据进行处理mask = Image.fromarray(mask)if self.transforms is not None:img, mask = self.transforms(img, mask)return img, maskdef __len__(self):return len(self.img_list)@staticmethoddef collate_fn(batch):images, targets = list(zip(*batch))batched_imgs = cat_list(images, fill_value=0)batched_targets = cat_list(targets, fill_value=255)return batched_imgs, batched_targetsdef cat_list(images, fill_value=0):max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))batch_shape = (len(images),) + max_sizebatched_imgs = images[0].new(*batch_shape).fill_(fill_value)for img, pad_img in zip(images, batched_imgs):pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)return batched_imgs

对__getitem__()函数的重点说明

def __getitem__(self, idx):img = Image.open(self.img_list[idx]).convert('RGB')manual = Image.open(self.manual[idx]).convert('L')   # gqr:转换得到灰度图后,前景的为255,背景的像素为0manual = np.array(manual) / 255  # gqr:将数据进行归一化,前景的为1;背景的像素为0;在进行语义分割时,前景像素值要从1开始roi_mask = Image.open(self.roi_mask[idx]).convert('L')  # gqr:转换成灰度图,感兴趣区域为255;不感兴趣区域是0roi_mask = 255 - np.array(roi_mask)   # gqr:将感兴趣的区域设置为0,不感兴趣的区域设置为255,这样在计算损失时可以排除掉像素为255的区域mask = np.clip(manual + roi_mask, a_min=0, a_max=255)   # gqr:想加后,需要分割的部分为1,背景为0,还有为255的不感兴趣区域"""print(np.unique(mask)):输出结果为:[  0.   1. 255.]"""# 这里转回PIL的原因是,transforms中是对PIL数据进行处理mask = Image.fromarray(mask)if self.transforms is not None:img, mask = self.transforms(img, mask)return img, mask

测试代码:

import os
import timeimport torch
from torchvision import transforms
import numpy as np
from PIL import Imagefrom src import UNetdef time_synchronized():torch.cuda.synchronize() if torch.cuda.is_available() else Nonereturn time.time()def main():classes = 1  # exclude backgroundweights_path = "./multi_train/best_model.pth"img_path = "./DRIVE/test/images/01_test.tif"roi_mask_path = "./DRIVE/test/mask/01_test_mask.gif"assert os.path.exists(weights_path), f"weights {weights_path} not found."assert os.path.exists(img_path), f"image {img_path} not found."assert os.path.exists(roi_mask_path), f"image {roi_mask_path} not found."mean = (0.709, 0.381, 0.224)std = (0.127, 0.079, 0.043)# get devicesdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))# create modelmodel = UNet(in_channels=3, num_classes=classes+1, base_c=32)# load weightsmodel.load_state_dict(torch.load(weights_path, map_location='cpu')['model'])model.to(device)# load roi maskroi_img = Image.open(roi_mask_path).convert('L')   # 将图像转成灰度图roi_img = np.array(roi_img)# load imageoriginal_img = Image.open(img_path).convert('RGB')# from pil image to tensor and normalizedata_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=mean, std=std)])img = data_transform(original_img)# expand batch dimensionimg = torch.unsqueeze(img, dim=0)model.eval()  # 进入验证模式with torch.no_grad():# init modelimg_height, img_width = img.shape[-2:]init_img = torch.zeros((1, 3, img_height, img_width), device=device)model(init_img)t_start = time_synchronized()output = model(img.to(device))t_end = time_synchronized()print("inference time: {}".format(t_end - t_start))prediction = output['out'].argmax(1).squeeze(0)   # gqr:在通道维度进行argmaxprediction = prediction.to("cpu").numpy().astype(np.uint8)# 将前景对应的像素值改成255(白色)prediction[prediction == 1] = 255# 将不敢兴趣的区域像素设置成0(黑色)prediction[roi_img == 0] = 0mask = Image.fromarray(prediction)mask.save("test_result.png")if __name__ == '__main__':main()

U-Net网络训练损失函数

采用Dice-Loss损失函数
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
看下图所示 ↓ ↓ ↓ ↓
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/91467.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

A. Sequence with Digits

题目:样例: 输入 8 1 4 487 1 487 2 487 3 487 4 487 5 487 6 487 7输出 42 487 519 528 544 564 588 628 思路: 暴力模拟题,看这数据范围,有些人可能会被唬住,以为是高精度或者容易超时,实际上…

漏斗分析模型

从业务流程起点开始到最后日标完成的每个环节都会有用户流失,因此需要一种分析方法来衡量业务流程每一步的转化效率,漏斗分析方法就是这样的分析方法。 例如,在淘宝上一款商品的浏览量是 300、点击量是 100、订单量是 20、支付量是 10&#…

自动混剪多段视频、合并音频、添加文案的技巧分享

在如今的社交媒体时代,视频的重要性越来越被人们所重视。许多人喜欢记录生活中的美好瞬间,并将其制作成视频分享给朋友和家人。然而,对于那些拍摄了大量视频的人来说,一个一个地进行剪辑和合并可能是一项令人头痛的任务。但是&…

ElasticSearch - 基于 JavaRestClient 查询文档(match、精确、复合查询,以及排序、分页、高亮)

目录 一、基于 JavaRestClient 查询文档 1.1、查询 API 演示 1.1.1、查询基本框架 DSL 请求的对应格式 响应的解析 1.1.2、全文检索查询 1.1.3、精确查询 1.1.4、复合查询 1.1.5、排序和分页 1.1.6、高亮 一、基于 JavaRestClient 查询文档 1.1、查询 API 演示 1.1.…

​中秋团圆季《乡村振兴战略下传统村落文化旅游设计》许少辉八月新著

​中秋团圆季《乡村振兴战略下传统村落文化旅游设计》许少辉八月新著 ​中秋团圆季《乡村振兴战略下传统村落文化旅游设计》许少辉八月新著

机器人中的数值优化|【五】BFGS算法非凸/非光滑处理

机器人中的数值优化|【五】BFGS算法的非凸/非光滑处理 往期内容回顾 机器人中的数值优化|【一】数值优化基础 机器人中的数值优化|【二】最速下降法,可行牛顿法的python实现,以Rosenbrock function为例 机器人中的数值优化|【三】无约束优化&#xff0…

【web前端特效源码】使用HTML5+CSS3制作一个会动的loading加载动画效果~~适合初学者~超简单~ |前端开发|IT软件

b站视频演示效果: 【web前端特效源码】使用HTML5+CSS3制作一个会动的loading加载动画效果~~适合初学者~超简单~ |前端开发|IT软件 效果图: 完整代码: <!DOCTYPE html> <html lang="en"> <head><meta charset="UTF-8"><titl…

贪心算法-

代码随想录 什么是贪心 贪心的本质是选择每一阶段的局部最优&#xff0c;从而达到全局最优。 这么说有点抽象&#xff0c;来举一个例子&#xff1a; 例如&#xff0c;有一堆钞票&#xff0c;你可以拿走十张&#xff0c;如果想达到最大的金额&#xff0c;你要怎么拿&#xff…

Win10自带输入法怎么删除-Win10卸载微软输入法的方法

Win10自带输入法怎么删除&#xff1f;Win10系统自带输入法就是微软输入法&#xff0c;这个输入法满足了很多用户的输入需求。但是&#xff0c;有些用户想要使用其它的输入法&#xff0c;这时候就想删除掉微软输入法。下面小编给大家介绍最简单方便的卸载方法吧。 Win10卸载微软…

赋能工业数字化转型|辽宁七彩赛通受邀出席辽宁省工业互联网+安全可控先进制造业数字服务产业峰会

2023年9月25日下午&#xff0c;由软通动力信息技术&#xff08;集团&#xff09;股份有限公司主办的“工业互联网安全可控先进制造业数字服务产业峰会”在辽宁沈阳顺利举办。省市区各级政府、科研院所领导、技术专家、企业高管以及生态合作伙伴代表等齐聚一堂&#xff0c;共同探…

Lua学习笔记:debug.sethook函数

前言 本篇在讲什么 使用Lua的debug.setHook函数 本篇需要什么 对Lua语法有简单认知 依赖Sublime Text工具 本篇的特色 具有全流程的图文教学 重实践&#xff0c;轻理论&#xff0c;快速上手 提供全流程的源码内容 ★提高阅读体验★ &#x1f449; ♠ 一级标题 &…

【Element-UI】Mock.js,案例首页导航、左侧菜单

一.Mock.js 1、什么是Mock.js Mock.js是一个用于生成模拟数据的JavaScript库。它能够模拟后端API接口&#xff0c;用于前端开发时进行接口调试和测试提高自动化测试效率。使用Mock.js可以快速创建虚拟的数据&#xff0c;并且可以设置数据的类型、格式和规则&#xff0c;从而模…

Backblaze发布2023中期SSD故障数据质量报告

作为一家在2021年在美国纳斯达克上市的云端备份公司&#xff0c;Backblaze一直保持着对外定期发布HDD和SSD的故障率稳定性质量报告&#xff0c;给大家提供了一份真实应用场景下的稳定性分析参考数据。 本文我们主要看下Backblaze最新发布的2023中期SSD相关故障稳定性数据报告。…

AWS-Lambda之导入自定义包-pip包

参考文档&#xff1a; https://repost.aws/zh-Hans/knowledge-center/lambda-import-module-error-python https://blog.csdn.net/fxtxz2/article/details/112035627 简单来说,以 " alibabacloud_dyvmsapi20170525 " 包为例 ## 创建临时目录 mkdir /tmp cd ./tmp …

在Vue中通过ElementUI构建前端页面【登录,注册】,在IEDA构建后端实现前后端分离

一.ElementUI组件入门 1.对于ElementUI的理解 是一套基于 Vue.js 的开源UI组件库&#xff0c;提供了丰富的可复用组件&#xff0c;可以帮助开发者快速构建美观、易用的前端界面 2.Element UI 的特点和优势 多样化的组件&#xff1a;Element UI 提供了众多常用的基础组件&#…

NLP 04(GRU)

一、GRU GRU (Gated Recurrent Unit)也称门控循环单元结构,它也是传统RNN的变体,同LSTM一样能够有效捕捉长序列之间的语义关联&#xff0c; 缓解梯度消失或爆炸现象&#xff0c;同时它的结构和计算要比LSTM更简单,它的核心结构可以分为两个部分去解析: 更新门、重置门 GRU的内…

两横一纵 | 寅家科技发布10年新征程战略

2023年9月22日&#xff0c;寅家科技“寅路向前”10年新征程战略发布会在上海举办&#xff0c;来自投资领域的东方富海、深创投、高新投等知名投资机构&#xff0c;一汽大众、一汽红旗、奇瑞汽车等主机厂&#xff0c;国家新能源汽车技术创新中心、梅克朗、芯驰科技、思特威等合作…

一朵华为云,如何做好百模千态?

点击关注 文丨刘雨琦、郝鑫 2005年华为提出网络时代的“All IP”&#xff0c;2011年提出数字化时代的“All Cloud”&#xff0c;2023年提出智能时代的“All Intelligence”。 截至目前&#xff0c;华为的战略升级经历了三个阶段。 步入智能化&#xff0c;需要迎接的困难依然…

mysql面试题6:MySQL索引的底层原理,是如何实现的?B+树和B树的区别?

该文章专注于面试,面试只要回答关键点即可,不需要对框架有非常深入的回答,如果你想应付面试,是足够了,抓住关键点 面试官:MySQL索引的底层原理,是如何实现的? MySQL索引的底层实现是通过B+树来实现的。B+树是一种多叉平衡查找树,它的特点是能够高效地支持数据的插入…

怎么修改jupyter lab 的工作路径而不是直接再桌面路径打开

要修改Jupyter Lab的工作路径&#xff0c;你可以按照以下步骤操作&#xff1a; 打开终端或命令提示符窗口。 输入 jupyter lab --generate-config 命令来生成Jupyter Lab的配置文件。 找到生成的配置文件&#xff0c;通常会位于 ~/.jupyter/jupyter_notebook_config.py。 使…