文章目录
- 0. 数据获取
- 1. 介绍
- 1.1 简介
- 1.2 任务介绍
- 1.3 数据集介绍
- 1.3.1 介绍
- 1.3.2 数据预处理建议
- 1.4 代码实现参考
- 1.5 训练过程
- 1.5.1 参数设置
- 1.5.2 可视化
- 1.5.3 结果分析
0. 数据获取
关注公众号:『AI学习星球
』
回复:肾脏CT分割
即可获取数据下载。
算法学习
、4对1辅导
、论文讲解
或核心期刊
可以通过公众号
或CSDN
滴滴我
1. 介绍
1.1 简介
每年有超过 400,000 例新发肾癌病例,手术是其最常见的治疗方法。由于肾脏和肾脏肿瘤形态的多样性,目前人们对肿瘤形态如何与手术结果相关 ,以及开发先进的手术计划技术 非常感兴趣。自动语义分割是这些工作的一个很有前途的工具,但形态异质性使其成为一个难题。
这一挑战的目标是加速可靠的肾脏和肾脏肿瘤语义分割方法的发展。我们已经为 300 名在我们机构接受部分或根治性肾切除术的独特肾癌患者的动脉期腹部 CT 扫描生成了真实语义分割。其中 210 个已发布用于模型训练和验证,其余 90 个将保留用于客观模型评估。
1.2 任务介绍
该项目描述了,使用深度学习中图像语义分割网络的U-Net,很多医学图像处理的网络结构都由U-Net改进而来。U-Net可以被看作是基于FCN和SegNet的一种改进方法,采用了FCN的全卷积、反卷积上采样、越级连接的方法,采用了SegNet的Encoder-Decoder结构。
1.3 数据集介绍
1.3.1 介绍
KiTS2019是MICCAI19的一个竞赛项目,项目的任务是对3D-CT数据进行肾脏和肾脏肿瘤的分割,官方的数据集提供了210个case作为训练集,90个case作为测试集。共有800多人报名参加了这一竞赛,最终提交的结果的team有126支,其中被认定有效的为100个记录入leaderboard。目前这一竞赛状态为开放性质的,有兴趣的可以参与一下。
1.3.2 数据预处理建议
KiTS19提供的数据是3D CT图像,我们要训练的是最简单的2D U-Net,因此要从3D CT体数据中读取2D切片。数据集的提供方在其Github上很贴心的提供了可视化的代码(就在我们的数据集中),是用python调用了nibabel库处理.nii格式的体数据得到2D的.png格式的切片。可视化的结果如下图所示,需要对切片进行筛选。另外需要补充的是在KiTS的数据集中分割的标签有三类:背景、肾脏、肾脏肿瘤,我们想进行的是简单的背景与肾脏二分类问题而不是多分类问题,因此在可视化过程中比较简单粗暴的将肿瘤视为肾脏的一部分。
肾肿瘤训练数据一共有210例,选择0-199例来训练,200-209例来测试。
- 分析肾肿瘤数据金标准的类别信息,一共有三个类别值:0是背景,1是肾区域,2是肾肿瘤区域。
- 分析肾肿瘤数据的大小和Spacing信息,大多数图像大小都是512x512xthickness,只有第160例数据是796x512xthickness,thickness数值从几十到几百,z方向上Spacing值是从1mm到5mm。
- 分析窗宽窗位信息去除噪声和不相关区域信息。
- 窗宽窗位设置成-200-300,将图像x和y都缩放到512,通过插值将z方向上Spacing值从原始变成1mm。
- 为了准备3D分割肾区域,需要对图像取Patch操作,Patch大小选择128x128x32,如果你的GPU显存够大可以设置成48或64。
1.4 代码实现参考
import argparse
import logging
import os
import sysimport numpy as np
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdmfrom eval import eval_net
from unet import UNetfrom visdom import Visdom
from utils.dataset import BasicDataset
from torch.utils.data import DataLoader, random_splitdir_img = 'D:\Dataset\CT-KiTS19\KiTS19\kits19-master\png_datasize\\train_choose\slice_png'
dir_mask = 'D:\Dataset\CT-KiTS19\KiTS19\kits19-master\png_datasize\\train_choose\mask_png'
dir_checkpoint = 'checkpoints/'def train_net(net,device,epochs=5,batch_size=1,lr=0.1,val_percent=0.2,save_cp=True,img_scale=1):dataset = BasicDataset(dir_img, dir_mask, img_scale)n_val = int(len(dataset) * val_percent)n_train = len(dataset) - n_valtrain, val = random_split(dataset, [n_train, n_val])train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)#writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')viz=Visdom()viz.line([0.], [0.], win='train_loss', opts=dict(title='train_loss'))viz.line([0.], [0.], win='learning_rate', opts=dict(title='learning_rate'))viz.line([0.], [0.], win='Dice/test', opts=dict(title='Dice/test'))global_step = 0logging.info(f'''Starting training:Epochs: {epochs}Batch size: {batch_size}Learning rate: {lr}Training size: {n_train}Validation size: {n_val}Checkpoints: {save_cp}Device: {device.type}Images scaling: {img_scale}''')optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)if net.n_classes > 1:criterion = nn.CrossEntropyLoss()else:criterion = nn.BCEWithLogitsLoss()for epoch in range(epochs):net.train()epoch_loss = 0with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:for batch in train_loader:imgs = batch['image']true_masks = batch['mask']assert imgs.shape[1] == net.n_channels, \f'Network has been defined with {net.n_channels} input channels, ' \f'but loaded images have {imgs.shape[1]} channels. Please check that ' \'the images are loaded correctly.'imgs = imgs.to(device=device, dtype=torch.float32)mask_type = torch.float32 if net.n_classes == 1 else torch.longtrue_masks = true_masks.to(device=device, dtype=mask_type)masks_pred = net(imgs)#print('mask_pred',masks_pred.shape)#print('masks_pred',masks_pred.shape)#print('true_masks', true_masks.shape)viz.image(imgs, win='imgs/train')viz.image(true_masks, win='masks/true/train')viz.image(masks_pred, win='masks/pred/train')loss = criterion(masks_pred, true_masks)epoch_loss += loss.item()#writer.add_scalar('Loss/train', loss.item(), global_step)viz.line([loss.item()],[global_step],win='train_loss',update='append')pbar.set_postfix(**{'loss (batch)': loss.item()})optimizer.zero_grad()loss.backward()#nn.utils.clip_grad_value_(net.parameters(), 0.1)optimizer.step()pbar.update(imgs.shape[0])global_step += 1if global_step % (n_train // (10 * batch_size)) == 0:# for tag, value in net.named_parameters():# tag = tag.replace('.', '/')# writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)# writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)val_score = eval_net(net, val_loader, device)scheduler.step(val_score)#writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)viz.line([optimizer.param_groups[0]['lr']], [global_step], win='learning_rate', update='append')if net.n_classes > 1:logging.info('Validation cross entropy: {}'.format(val_score))#writer.add_scalar('Loss/test', val_score, global_step)else:logging.info('Validation Dice Coeff: {}'.format(val_score))#writer.add_scalar('Dice/test', val_score, global_step)viz.line([val_score], [global_step], win='Dice/test', update='append')viz.image(imgs, win='images')if net.n_classes == 1:print('true_mask',true_masks.shape,true_masks.type)viz.image( true_masks, win='masks/true')print('pred',(torch.sigmoid(masks_pred) > 0.5).squeeze(0).shape)viz.images((torch.sigmoid(masks_pred) > 0.5),win='masks/pred')if save_cp:try:os.mkdir(dir_checkpoint)logging.info('Created checkpoint directory')except OSError:passtorch.save(net.state_dict(),dir_checkpoint + f'CP_epoch{epoch + 1}.pth')logging.info(f'Checkpoint {epoch + 1} saved !')#writer.close()def eval_net(net, loader, device):"""Evaluation without the densecrf with the dice coefficient"""net.eval()mask_type = torch.float32 #if net.n_classes == 1 else torch.longn_val = len(loader) # the number of batchtot = 0with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:for batch in loader:imgs, true_masks = batch['image'], batch['mask']imgs = imgs.to(device=device, dtype=torch.float32)true_masks = true_masks.to(device=device, dtype=mask_type)with torch.no_grad():mask_pred = net(imgs)#['out']# if net.n_classes > 1:# tot += F.cross_entropy(mask_pred, true_masks).item()# else:pred = torch.sigmoid(mask_pred)pred = (pred > 0.5).float()tot += dice_coeff(pred, true_masks).item()pbar.update()net.train()return tot / n_valdef get_args():parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',formatter_class=argparse.ArgumentDefaultsHelpFormatter)parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,help='Number of epochs', dest='epochs')parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1,help='Batch size', dest='batchsize')parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.1,help='Learning rate', dest='lr')parser.add_argument('-f', '--load', dest='load', type=str, default=False,help='Load model from a .pth file')parser.add_argument('-s', '--scale', dest='scale', type=float, default=1,help='Downscaling factor of the images')parser.add_argument('-v', '--validation', dest='val', type=float, default=20.0,help='Percent of the data that is used as validation (0-100)')return parser.parse_args()if __name__ == '__main__':logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')args = get_args()device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')logging.info(f'Using device {device}')# Change here to adapt to your data# n_channels=3 for RGB images# n_classes is the number of probabilities you want to get per pixel# - For 1 class and background, use n_classes=1# - For 2 classes, use n_classes=1# - For N > 2 classes, use n_classes=Nnet = UNet(n_channels=1, n_classes=1, bilinear=True)logging.info(f'Network:\n'f'\t{net.n_channels} input channels\n'f'\t{net.n_classes} output channels (classes)\n'f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')if args.load:net.load_state_dict(torch.load(args.load, map_location=device))logging.info(f'Model loaded from {args.load}')net.to(device=device)# faster convolutions, but more memory# cudnn.benchmark = Truetry:train_net(net=net,epochs=args.epochs,batch_size=args.batchsize,lr=args.lr,device=device,img_scale=args.scale,val_percent=args.val / 100)except KeyboardInterrupt:torch.save(net.state_dict(), 'INTERRUPTED.pth')logging.info('Saved interrupt')try:sys.exit(0)except SystemExit:os._exit(0)
1.5 训练过程
1.5.1 参数设置
训练与验证比例: 8:2 (1680:420)
batch_size: 2
学习率:torch.optim.lr_scheduler.ReduceLROnPlateau,当网络的评价指标不在提升的时候,可以通过降低网络的学习率来提高网络性能损失函数:BCEWithLogitsLoss 衡量目标和输出之间的二进制交叉熵
1.5.2 可视化
使用visdom进行可视化一开始的训练状态,左边为真实的mask,右边为网络的输出,可以看到一开始网络的输出还是不太行的。
当进行完第一轮训练之后训练的结果如图所示,红色所框的为训练过程,蓝色所框为验证过程,包括了原图、真实的mask T、预测的mask P。训练和验证过程中预测mask的差异来自于是否进行了二值化处理。
第四轮训练之后的结果,预测的mask与真实的mask已经很接近了。
1.5.3 结果分析
实验结果:Dice系数:0.832
结果分析:
- 原数据为三维,本次实验只使用的二维切片
- 原数据的mask肿瘤和肾脏是分开的,在数据处理过程中统一化为了肾脏。
- 没有做数据增强、参数调整,训练不够充分。
关注公众号:『AI学习星球
』
回复:肾脏CT分割
即可获取数据下载。
算法学习
、4对1辅导
、论文讲解
或核心期刊
可以通过公众号
或CSDN
滴滴我