计算机视觉-03-使用U-Net实现肾脏CT分割(包含数据和代码)

文章目录

    • 0. 数据获取
    • 1. 介绍
      • 1.1 简介
      • 1.2 任务介绍
      • 1.3 数据集介绍
        • 1.3.1 介绍
        • 1.3.2 数据预处理建议
      • 1.4 代码实现参考
      • 1.5 训练过程
        • 1.5.1 参数设置
        • 1.5.2 可视化
        • 1.5.3 结果分析

0. 数据获取

关注公众号:『AI学习星球
回复:肾脏CT分割 即可获取数据下载。
算法学习4对1辅导论文讲解核心期刊可以通过公众号CSDN滴滴我
在这里插入图片描述

1. 介绍

1.1 简介

每年有超过 400,000 例新发肾癌病例,手术是其最常见的治疗方法。由于肾脏和肾脏肿瘤形态的多样性,目前人们对肿瘤形态如何与手术结果相关 ,以及开发先进的手术计划技术 非常感兴趣。自动语义分割是这些工作的一个很有前途的工具,但形态异质性使其成为一个难题。
这一挑战的目标是加速可靠的肾脏和肾脏肿瘤语义分割方法的发展。我们已经为 300 名在我们机构接受部分或根治性肾切除术的独特肾癌患者的动脉期腹部 CT 扫描生成了真实语义分割。其中 210 个已发布用于模型训练和验证,其余 90 个将保​​留用于客观模型评估。

1.2 任务介绍

该项目描述了,使用深度学习中图像语义分割网络的U-Net,很多医学图像处理的网络结构都由U-Net改进而来。U-Net可以被看作是基于FCN和SegNet的一种改进方法,采用了FCN的全卷积、反卷积上采样、越级连接的方法,采用了SegNet的Encoder-Decoder结构。

1.3 数据集介绍

1.3.1 介绍

KiTS2019是MICCAI19的一个竞赛项目,项目的任务是对3D-CT数据进行肾脏和肾脏肿瘤的分割,官方的数据集提供了210个case作为训练集,90个case作为测试集。共有800多人报名参加了这一竞赛,最终提交的结果的team有126支,其中被认定有效的为100个记录入leaderboard。目前这一竞赛状态为开放性质的,有兴趣的可以参与一下。

1.3.2 数据预处理建议

KiTS19提供的数据是3D CT图像,我们要训练的是最简单的2D U-Net,因此要从3D CT体数据中读取2D切片。数据集的提供方在其Github上很贴心的提供了可视化的代码(就在我们的数据集中),是用python调用了nibabel库处理.nii格式的体数据得到2D的.png格式的切片。可视化的结果如下图所示,需要对切片进行筛选。另外需要补充的是在KiTS的数据集中分割的标签有三类:背景、肾脏、肾脏肿瘤,我们想进行的是简单的背景与肾脏二分类问题而不是多分类问题,因此在可视化过程中比较简单粗暴的将肿瘤视为肾脏的一部分。

肾肿瘤训练数据一共有210例,选择0-199例来训练,200-209例来测试。

  1. 分析肾肿瘤数据金标准的类别信息,一共有三个类别值:0是背景,1是肾区域,2是肾肿瘤区域。
  2. 分析肾肿瘤数据的大小和Spacing信息,大多数图像大小都是512x512xthickness,只有第160例数据是796x512xthickness,thickness数值从几十到几百,z方向上Spacing值是从1mm到5mm。
  3. 分析窗宽窗位信息去除噪声和不相关区域信息。
  4. 窗宽窗位设置成-200-300,将图像x和y都缩放到512,通过插值将z方向上Spacing值从原始变成1mm。
  5. 为了准备3D分割肾区域,需要对图像取Patch操作,Patch大小选择128x128x32,如果你的GPU显存够大可以设置成48或64。

1.4 代码实现参考

import argparse
import logging
import os
import sysimport numpy as np
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdmfrom eval import eval_net
from unet import UNetfrom visdom import Visdom
from utils.dataset import BasicDataset
from torch.utils.data import DataLoader, random_splitdir_img = 'D:\Dataset\CT-KiTS19\KiTS19\kits19-master\png_datasize\\train_choose\slice_png'
dir_mask = 'D:\Dataset\CT-KiTS19\KiTS19\kits19-master\png_datasize\\train_choose\mask_png'
dir_checkpoint = 'checkpoints/'def train_net(net,device,epochs=5,batch_size=1,lr=0.1,val_percent=0.2,save_cp=True,img_scale=1):dataset = BasicDataset(dir_img, dir_mask, img_scale)n_val = int(len(dataset) * val_percent)n_train = len(dataset) - n_valtrain, val = random_split(dataset, [n_train, n_val])train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)#writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')viz=Visdom()viz.line([0.], [0.], win='train_loss', opts=dict(title='train_loss'))viz.line([0.], [0.], win='learning_rate', opts=dict(title='learning_rate'))viz.line([0.], [0.], win='Dice/test', opts=dict(title='Dice/test'))global_step = 0logging.info(f'''Starting training:Epochs:          {epochs}Batch size:      {batch_size}Learning rate:   {lr}Training size:   {n_train}Validation size: {n_val}Checkpoints:     {save_cp}Device:          {device.type}Images scaling:  {img_scale}''')optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)if net.n_classes > 1:criterion = nn.CrossEntropyLoss()else:criterion = nn.BCEWithLogitsLoss()for epoch in range(epochs):net.train()epoch_loss = 0with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:for batch in train_loader:imgs = batch['image']true_masks = batch['mask']assert imgs.shape[1] == net.n_channels, \f'Network has been defined with {net.n_channels} input channels, ' \f'but loaded images have {imgs.shape[1]} channels. Please check that ' \'the images are loaded correctly.'imgs = imgs.to(device=device, dtype=torch.float32)mask_type = torch.float32 if net.n_classes == 1 else torch.longtrue_masks = true_masks.to(device=device, dtype=mask_type)masks_pred = net(imgs)#print('mask_pred',masks_pred.shape)#print('masks_pred',masks_pred.shape)#print('true_masks', true_masks.shape)viz.image(imgs, win='imgs/train')viz.image(true_masks, win='masks/true/train')viz.image(masks_pred, win='masks/pred/train')loss = criterion(masks_pred, true_masks)epoch_loss += loss.item()#writer.add_scalar('Loss/train', loss.item(), global_step)viz.line([loss.item()],[global_step],win='train_loss',update='append')pbar.set_postfix(**{'loss (batch)': loss.item()})optimizer.zero_grad()loss.backward()#nn.utils.clip_grad_value_(net.parameters(), 0.1)optimizer.step()pbar.update(imgs.shape[0])global_step += 1if global_step % (n_train // (10 * batch_size)) == 0:# for tag, value in net.named_parameters():#     tag = tag.replace('.', '/')#     writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)#     writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)val_score = eval_net(net, val_loader, device)scheduler.step(val_score)#writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)viz.line([optimizer.param_groups[0]['lr']], [global_step], win='learning_rate', update='append')if net.n_classes > 1:logging.info('Validation cross entropy: {}'.format(val_score))#writer.add_scalar('Loss/test', val_score, global_step)else:logging.info('Validation Dice Coeff: {}'.format(val_score))#writer.add_scalar('Dice/test', val_score, global_step)viz.line([val_score], [global_step], win='Dice/test', update='append')viz.image(imgs, win='images')if net.n_classes == 1:print('true_mask',true_masks.shape,true_masks.type)viz.image( true_masks, win='masks/true')print('pred',(torch.sigmoid(masks_pred) > 0.5).squeeze(0).shape)viz.images((torch.sigmoid(masks_pred) > 0.5),win='masks/pred')if save_cp:try:os.mkdir(dir_checkpoint)logging.info('Created checkpoint directory')except OSError:passtorch.save(net.state_dict(),dir_checkpoint + f'CP_epoch{epoch + 1}.pth')logging.info(f'Checkpoint {epoch + 1} saved !')#writer.close()def eval_net(net, loader, device):"""Evaluation without the densecrf with the dice coefficient"""net.eval()mask_type = torch.float32 #if net.n_classes == 1 else torch.longn_val = len(loader)  # the number of batchtot = 0with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:for batch in loader:imgs, true_masks = batch['image'], batch['mask']imgs = imgs.to(device=device, dtype=torch.float32)true_masks = true_masks.to(device=device, dtype=mask_type)with torch.no_grad():mask_pred = net(imgs)#['out']# if net.n_classes > 1:#     tot += F.cross_entropy(mask_pred, true_masks).item()# else:pred = torch.sigmoid(mask_pred)pred = (pred > 0.5).float()tot += dice_coeff(pred, true_masks).item()pbar.update()net.train()return tot / n_valdef get_args():parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',formatter_class=argparse.ArgumentDefaultsHelpFormatter)parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,help='Number of epochs', dest='epochs')parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1,help='Batch size', dest='batchsize')parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.1,help='Learning rate', dest='lr')parser.add_argument('-f', '--load', dest='load', type=str, default=False,help='Load model from a .pth file')parser.add_argument('-s', '--scale', dest='scale', type=float, default=1,help='Downscaling factor of the images')parser.add_argument('-v', '--validation', dest='val', type=float, default=20.0,help='Percent of the data that is used as validation (0-100)')return parser.parse_args()if __name__ == '__main__':logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')args = get_args()device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')logging.info(f'Using device {device}')# Change here to adapt to your data# n_channels=3 for RGB images# n_classes is the number of probabilities you want to get per pixel#   - For 1 class and background, use n_classes=1#   - For 2 classes, use n_classes=1#   - For N > 2 classes, use n_classes=Nnet = UNet(n_channels=1, n_classes=1, bilinear=True)logging.info(f'Network:\n'f'\t{net.n_channels} input channels\n'f'\t{net.n_classes} output channels (classes)\n'f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')if args.load:net.load_state_dict(torch.load(args.load, map_location=device))logging.info(f'Model loaded from {args.load}')net.to(device=device)# faster convolutions, but more memory# cudnn.benchmark = Truetry:train_net(net=net,epochs=args.epochs,batch_size=args.batchsize,lr=args.lr,device=device,img_scale=args.scale,val_percent=args.val / 100)except KeyboardInterrupt:torch.save(net.state_dict(), 'INTERRUPTED.pth')logging.info('Saved interrupt')try:sys.exit(0)except SystemExit:os._exit(0)

1.5 训练过程

1.5.1 参数设置

训练与验证比例: 8:2 (1680:420)
batch_size: 2
学习率:torch.optim.lr_scheduler.ReduceLROnPlateau,当网络的评价指标不在提升的时候,可以通过降低网络的学习率来提高网络性能损失函数:BCEWithLogitsLoss 衡量目标和输出之间的二进制交叉熵

1.5.2 可视化

使用visdom进行可视化一开始的训练状态,左边为真实的mask,右边为网络的输出,可以看到一开始网络的输出还是不太行的。

在这里插入图片描述

当进行完第一轮训练之后训练的结果如图所示,红色所框的为训练过程,蓝色所框为验证过程,包括了原图、真实的mask T、预测的mask P。训练和验证过程中预测mask的差异来自于是否进行了二值化处理。

在这里插入图片描述
第四轮训练之后的结果,预测的mask与真实的mask已经很接近了。
在这里插入图片描述

1.5.3 结果分析

在这里插入图片描述
实验结果:Dice系数:0.832
结果分析:

  1. 原数据为三维,本次实验只使用的二维切片
  2. 原数据的mask肿瘤和肾脏是分开的,在数据处理过程中统一化为了肾脏。
  3. 没有做数据增强、参数调整,训练不够充分。

关注公众号:『AI学习星球
回复:肾脏CT分割 即可获取数据下载。
算法学习4对1辅导论文讲解核心期刊可以通过公众号CSDN滴滴我
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/206522.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

高精度时钟芯片SD2405

概要 SD2405是一款非常优秀的RTC解决方案,为了能让用户在Arduino上有一款方便易用的时钟模块。该模块是一款内置晶振,支持IIC串行接口的高精度时钟模块;内置一次性工业级电池,可保证外部掉电的情况下,可以继续工作5~8…

实例分割 Mask-RCNN

参考文章 使用LabelMe标注目标检测数据集并转换为COCO2017格式_labelme转coco-CSDN博客 数据集选择 voc 这次不选择voc,因为文件组织太难了 voc2012文件夹组织 COCO COCO介绍 MC COCO2017年主要包含以下四个任务:目标检测与分割、图像描述、人体关…

KP 2sv Authenticator一款免费处理亚马逊两步验证码的软件

KP 2sv Authenticator 被誉为一款免费而强大的亚马逊两步验证软件,操作简便轻松。 软件使用方法极为简单,用户只需直接输入身份验证应用程序生成的代码,即可迅速生成随机验证码,帮助用户顺利完成亚马逊的两步验证流程。这款小软件…

有了安卓模拟器,就能在Windows 10或11上像使用安卓操作系统一样使用安卓

你可以使用Android模拟器在Windows 11或Windows 10中运行Android应用程序。如果你喜欢的应用程序只在手机上运行,但你想在电脑上使用,这些模拟器会很有用。 BlueStacks 与整个操作系统模拟器不同,BlueStacks只在Windows上模拟Android应用程序。它真的很容易使用,所以你不需…

香港云服务器:全面介绍与使用场景分析

这几年基于国内互联网技术的发展,各类海外贸易的兴起,很多网站都启用了海外云服务。这其中,香港的 IDC 市场异常火爆。也不奇怪,就目前来看,国内大多数网站的访问用户在国内外均有涉及,而香港云服务器恰好满…

Java第二十一章总结

网络编程三要素 ip地址:计算机在网络中的唯一标识 端口:应用程序在计算机中唯一标识 协议:通信协议,常见有UDP和TCP协议 InetAddress类 表示Internet协议地址 //返回InetAddress对象 InetAddress byName InetAddress.…

全国公共汽车、出租车拥有情况及客运量、货运量数据,shp、excel数据均有,多指标可查询

基本信息. 数据名称: 全国公共汽车、出租车拥有情况及客运量、货运量数据 数据格式: Shp、Excel 数据时间: 2020-2022年 数据几何类型: 面 数据坐标系: WGS84 数据来源:中国城市统计年鉴 数据字段: 序号字段名称字段说明1xzqhdm行政区划代码…

机器学习基础知识分享:深度学习

深度学习(Deep Learning)是近年来发展十分迅速的研究领域,并且在人工智能的很多子领域都取得了巨大的成功.从根源来讲,深度学习是机器学习的一个分支,是指一类问题以及解决这类问题的方法。 深度学习 为了…

vue中的内置指令v-model的作用和常见使用方法以及自定义组件上的用法

一、v-model是什么 v-model是Vue框架的一种内置的API指令,本质是一种语法糖写法,它负责监听用户的输入事件以更新数据,并对一些极端场景进行一些特殊处理。在Vue中,v-model是用于在表单元素和组件之间创建双向数据绑定的指令。它…

【AIGC】大语言模型的采样策略--temperature、top-k、top-p等

总结如下: 图片链接 参考 LLM解码-采样策略串讲 LLM大模型解码生成方式总结 LLM探索:GPT类模型的几个常用参数 Top-k, Top-p, Temperature

【动态规划系列】环形子数组的和-918

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

linux下的进程程序替换

进程程序替换 替换概念替换函数execl()execv()execvp()/execlp()execle()/execvpe() 如何在C/C程序里面执行别的语言写的程序。小tips 替换概念 当进程调用一种exec函数时,该进程的用户空间代码和数据完全被新程序替换,从新程序的代码部分开始运行。调用…

爬虫 selenium语法 (八)

目录 一、为什么使用selenium 二、selenium语法——元素定位 1.根据 id 找到对象 2.根据标签属性的属性值找到对象 3.根据Xpath语句获取对象 4.根据标签名获取对象 5.使用bs语法获取对象 6.通过链接文本获取对象 三、selenium语法——访问元素信息 1.获取属性的属性值…

有爱的冬天不再冷——壹基金儿童温暖包抵达富平

12月6日,富平县帮帮乐公益协会组织志愿者在协会楼下分装了由爱心企业、个人捐赠的144个壹基金儿童温暖包,争取在下周寒流来临前送到困境儿童手中,温暖他们的整个冬天。 壹基金温暖包项目是针对6—12岁困境儿童、留守儿童设计的暖冬应急生活物…

MySQL数据库sql语句操作

一、数据库模型 关系型数据库是一种以表格形式组织和存储数据的数据库。它使用关系模型,其中数据被组织为多个表格,每个表格包含了多个行和列。每个表格的列描述了数据的属性,而行包含了实际的数据记录。 非关系型数据库,也称为…

使用命令行创建vue3项目等待时间长解决方案

问题描述 今天在使用命令行创建vue3项目的时候,发现命令行窗口卡了很久,明明已经更换了安装包的源,并且检查环境变量配置正确的情况下,为什么还要等待那么久呢? 解决方案 使用命令再次检查更换淘宝的源是否配置成功…

玩转系统|利用HestiaCP自建NS解析及邮局并利用MailGun进行发信

前述 HestiaCP是一个VestaCP分叉来的产物,而同样作为VestaCP分叉来的myVesta也具有类似的功能。VestaCP本身作为一个社区的产区,其仅仅有一个商业插件需要每月付费5USD进行使用,因此为了达到完全开放使用的目的,这里选择使用Hest…

小电流MOSFET 选型分析数据,可应用于电子烟,电动工具,智能穿戴等产品上

小电流双N,D-N通道MOSFET,电压60V-100V左右 电流300mA-500MA,采用封装形式多样。具有低导通电阻,可快速切换速度,易于设计的驱动电路也易于并联,ESD保护,低电压驱动使该器件非常适合便携式设备…

通俗易懂的案例+代码解释AOP 切面编程

目录 1. 理解AOP2 Before2.1 controller层2.2 service层2.3 自定义注解2.4 切面 advice 3 After4 Around spring的三大核心:IOC控制反转、DI依赖注入、AOP面向切面编程 刚开始接触springboot项目,前两个使用的多,亲自使用AOP的机会并不多&…

【学习笔记】python仅拷贝list的值,引出浅拷贝与深拷贝

一、python 仅拷贝list的值(来源于gpt) 在 Python 中,可以使用切片或 copy() 方法来仅拷贝列表的值。 1、使用切片 a [1, 2, 3, 4, 5] b a[:] # 通过切片来拷贝 a 的值 在上面的代码中,我们使用切片来拷贝列表 a 的值&#xff…