目录
- 1. 作者介绍
- 2. 理论知识介绍
- 2.1 Unet++模型介绍
- 3. 实验过程
- 3.1 数据集介绍
- 3.2 代码实现
- 3.3 结果
- 4. 参考链接
1. 作者介绍
郭冠群,男,西安工程大学电子信息学院,2023级研究生
研究方向:机器视觉与人工智能
电子邮件:1347418097@qq.com
路治东,男,西安工程大学电子信息学院,2022级研究生,张宏伟人工智能课题组
研究方向:机器视觉与人工智能
电子邮件:2063079527@qq.com
2. 理论知识介绍
2.1 Unet++模型介绍
- Unet
语义分割是将图像划分为有意义的区域,并标注每个区域所属的类别。语义分割网络是实现这一任务的工具,其中Unet模型通过跨阶段融合不同尺寸的特征图来实现这一目标。
- 特征图融合
特征图融合的目的是结合浅层和深层特征,提升分割效果。浅层特征主要包含边界、颜色等简单信息,深层特征则包含更抽象的语义信息。融合多个层次的特征图,能够弥补单一层次特征信息的不足。
- Unet++
Unet++通过嵌套的密集跳跃连接(nested dense skip pathways)连接编码器和解码器子网络,缩小了编码器与解码器特征图之间的语义差距,从而提高了分割效果。此外,配合深度监督,Unet++可以进行剪枝:在测试阶段,输入图像只进行前向传播,被剪掉的分支不会影响它前面的输出;而在训练阶段,这些分支会参与反向传播,帮助其他部分更新权重。
3. 实验过程
3.1 数据集介绍
- 数据集来源
Kaggle—2018dsb数据集来自于2018年数据科学碗,其任务是从显微镜图像中分割细胞核。这对于推动医学发现具有重要意义,特别是在病理学、癌症研究和其他生命科学领域。
- 下载途径
百度网盘 链接:https://pan.baidu.com/s/1GXtZ0clE12oZKooF61siKQ
提取码:tsh7
- 数据集内容
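数据集包含显微镜下的细胞图像及其对应的分割掩码。代码默认从 inputs/dsb2018_96/images 读取图像、从 inputs/dsb2018_96/masks 读取掩码,并用 train_test_split 按 8:2 随机划分训练集和验证集(random_state=41):训练集用于训练模型,验证集用于评估模型性能。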
3.2 代码实现
- train.py
import os
import argparse
from glob import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from tqdm import tqdm
import albumentations as albu
from albumentations.core.composition import Compose, OneOf
from sklearn.model_selection import train_test_split
import archs
import losses
from dataset import CustomDataset
from metrics import iou_score
from utils import AverageMeter, str2bool


class Config:
    """Command-line configuration for the training script."""

    @staticmethod
    def from_cmdline():
        """Parse training options from the command line.

        Returns:
            dict: option name -> value. When ``--name`` is not given, ``name`` is
            set to ``<dataset>_<arch>_wDS`` (deep supervision) or
            ``<dataset>_<arch>_woDS``, which is the folder name val.py expects
            under ``models/``.
        """
        # ModelManager.create_criterion also supports torch's built-in
        # BCEWithLogitsLoss, so make it selectable from the command line.
        loss_choices = list(losses.__all__)
        if 'BCEWithLogitsLoss' not in loss_choices:
            loss_choices.append('BCEWithLogitsLoss')

        parser = argparse.ArgumentParser(description='Training configuration')
        parser.add_argument('--name', default=None,
                            help='Model name (default: <dataset>_<arch>_wDS or _woDS)')
        parser.add_argument('--epochs', type=int, default=100,
                            help='Number of total epochs to run')
        parser.add_argument('-b', '--batch_size', type=int, default=8,
                            help='Mini-batch size (default: 8)')
        parser.add_argument('--arch', default='NestedUNet', choices=archs.__all__,
                            help='Model architecture')
        parser.add_argument('--deep_supervision', type=str2bool, default=False,
                            help='Use deep supervision if True')
        parser.add_argument('--input_channels', type=int, default=3,
                            help='Number of input channels')
        parser.add_argument('--num_classes', type=int, default=1,
                            help='Number of classes')
        parser.add_argument('--input_w', type=int, default=96, help='Input image width')
        parser.add_argument('--input_h', type=int, default=96, help='Input image height')
        parser.add_argument('--loss', default='BCEDiceLoss', choices=loss_choices,
                            help='Loss function')
        parser.add_argument('--dataset', default='dsb2018_96', help='Dataset name')
        parser.add_argument('--img_ext', default='.png', help='Image file extension')
        parser.add_argument('--mask_ext', default='.png', help='Mask file extension')
        parser.add_argument('--optimizer', default='SGD', choices=['Adam', 'SGD'],
                            help='Optimizer type')
        parser.add_argument('--lr', '--learning_rate', type=float, default=1e-3,
                            help='Initial learning rate')
        parser.add_argument('--momentum', type=float, default=0.9,
                            help='Optimizer momentum')
        parser.add_argument('--weight_decay', type=float, default=1e-4,
                            help='Weight decay rate')
        parser.add_argument('--nesterov', type=str2bool, default=False,
                            help='Nesterov momentum')
        parser.add_argument('--scheduler', default='CosineAnnealingLR',
                            choices=['CosineAnnealingLR', 'ReduceLROnPlateau',
                                     'MultiStepLR', 'ConstantLR'],
                            help='Learning rate scheduler')
        parser.add_argument('--min_lr', type=float, default=1e-5,
                            help='Minimum learning rate')
        parser.add_argument('--factor', type=float, default=0.1,
                            help='Factor for ReduceLROnPlateau')
        parser.add_argument('--patience', type=int, default=2,
                            help='Patience for ReduceLROnPlateau')
        parser.add_argument('--milestones', type=str, default='1,2',
                            help='Milestones for MultiStepLR')
        parser.add_argument('--gamma', type=float, default=2 / 3,
                            help='Gamma for MultiStepLR')
        parser.add_argument('--early_stopping', type=int, default=-1,
                            help='Early stopping threshold')
        parser.add_argument('--num_workers', type=int, default=0,
                            help='Number of data loading workers')

        config = vars(parser.parse_args())
        if config['name'] is None:
            suffix = 'wDS' if config['deep_supervision'] else 'woDS'
            config['name'] = f"{config['dataset']}_{config['arch']}_{suffix}"
        return config


class ModelManager:
    """Build the model, loss, optimizer and LR scheduler described by a config dict."""

    def __init__(self, config):
        self.config = config
        self.model = self.create_model().cuda()
        self.criterion = self.create_criterion().cuda()
        self.optimizer = self.create_optimizer()
        # None when the learning rate is kept constant ('ConstantLR').
        self.scheduler = self.create_scheduler()

    def create_model(self):
        """Instantiate ``archs.<arch>(num_classes, input_channels, deep_supervision)``."""
        return archs.__dict__[self.config['arch']](
            self.config['num_classes'],
            self.config['input_channels'],
            self.config['deep_supervision'])

    def create_criterion(self):
        """Return the configured loss module."""
        if self.config['loss'] == 'BCEWithLogitsLoss':
            return nn.BCEWithLogitsLoss()
        return losses.__dict__[self.config['loss']]()

    def create_optimizer(self):
        """Return an Adam or SGD optimizer over the trainable parameters.

        Raises:
            ValueError: if the configured optimizer name is not supported.
        """
        params = filter(lambda p: p.requires_grad, self.model.parameters())
        name = self.config['optimizer']
        if name == 'Adam':
            return optim.Adam(params, lr=self.config['lr'],
                              weight_decay=self.config['weight_decay'])
        if name == 'SGD':
            return optim.SGD(params, lr=self.config['lr'],
                             momentum=self.config['momentum'],
                             nesterov=self.config['nesterov'],
                             weight_decay=self.config['weight_decay'])
        raise ValueError(f'Unsupported optimizer: {name}')

    def create_scheduler(self):
        """Return the configured LR scheduler, or None for a constant learning rate.

        Raises:
            ValueError: if the configured scheduler name is not supported.
        """
        name = self.config['scheduler']
        if name == 'CosineAnnealingLR':
            return lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=self.config['epochs'], eta_min=self.config['min_lr'])
        if name == 'ReduceLROnPlateau':
            return lr_scheduler.ReduceLROnPlateau(
                self.optimizer, factor=self.config['factor'],
                patience=self.config['patience'], min_lr=self.config['min_lr'])
        if name == 'MultiStepLR':
            milestones = [int(m) for m in self.config['milestones'].split(',')]
            return lr_scheduler.MultiStepLR(
                self.optimizer, milestones=milestones, gamma=self.config['gamma'])
        if name == 'ConstantLR':
            # Keep the optimizer's learning rate fixed; main() skips stepping.
            return None
        raise ValueError(f'Unsupported scheduler: {name}')


class DataManager:
    """Split ``inputs/<dataset>/images`` 80/20 and build train/validation loaders."""

    def __init__(self, config):
        self.config = config
        self.train_loader, self.val_loader = self.setup_loaders()

    def setup_loaders(self):
        """Return ``(train_loader, val_loader)``.

        Raises:
            FileNotFoundError: if no image with the configured extension exists.
        """
        cfg = self.config
        img_dir = os.path.join('inputs', cfg['dataset'], 'images')
        mask_dir = os.path.join('inputs', cfg['dataset'], 'masks')

        img_paths = glob(os.path.join(img_dir, '*' + cfg['img_ext']))
        img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_paths]
        if not img_ids:
            raise FileNotFoundError(f"No '*{cfg['img_ext']}' images found in {img_dir}")

        # random_state=41 must match val.py so both scripts use the same validation split.
        train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)

        train_transform = Compose([
            albu.RandomRotate90(),
            albu.Flip(),
            OneOf([albu.HueSaturationValue(), albu.RandomBrightnessContrast()], p=1),
            albu.Resize(cfg['input_h'], cfg['input_w']),
            albu.Normalize(),
        ])
        val_transform = Compose([
            albu.Resize(cfg['input_h'], cfg['input_w']),
            albu.Normalize(),
        ])

        def make_dataset(ids, transform):
            # Both splits share directories/extensions; only ids and transforms differ.
            return CustomDataset(img_ids=ids,
                                 img_dir=img_dir,
                                 mask_dir=mask_dir,
                                 img_ext=cfg['img_ext'],
                                 mask_ext=cfg['mask_ext'],
                                 num_classes=cfg['num_classes'],
                                 transform=transform)

        train_loader = torch.utils.data.DataLoader(
            make_dataset(train_img_ids, train_transform),
            batch_size=cfg['batch_size'], shuffle=True,
            num_workers=cfg['num_workers'], drop_last=True)
        val_loader = torch.utils.data.DataLoader(
            make_dataset(val_img_ids, val_transform),
            batch_size=cfg['batch_size'], shuffle=False,
            num_workers=cfg['num_workers'], drop_last=False)
        return train_loader, val_loader


def _save_config(config, model_dir):
    """Write the run configuration to ``<model_dir>/config.yml`` (read back by val.py)."""
    import yaml  # PyYAML, the same library val.py uses to read this file

    with open(os.path.join(model_dir, 'config.yml'), 'w') as f:
        yaml.safe_dump(config, f)


def main():
    """Train the configured model and keep the checkpoint with the best validation IoU.

    Writes ``models/<name>/config.yml`` once and ``models/<name>/model.pth``
    whenever validation IoU improves. If ``early_stopping`` is >= 0, training
    stops after that many consecutive epochs without improvement.
    """
    config = Config.from_cmdline()
    model_dir = os.path.join('models', config['name'])
    os.makedirs(model_dir, exist_ok=True)
    _save_config(config, model_dir)

    manager = ModelManager(config)
    data_manager = DataManager(config)

    best_iou = float('-inf')  # guarantees the first epoch's model is saved
    epochs_without_improvement = 0
    for epoch in range(config['epochs']):
        train_loss, train_iou = train_epoch(data_manager.train_loader, manager.model,
                                            manager.criterion, manager.optimizer, config)
        val_loss, val_iou = validate_epoch(data_manager.val_loader, manager.model,
                                           manager.criterion, config)
        print(f'Epoch: {epoch}, Train Loss: {train_loss}, Train IOU: {train_iou}, Val Loss: {val_loss}, Val IOU: {val_iou}')

        if manager.scheduler is not None:
            if config['scheduler'] == 'ReduceLROnPlateau':
                # ReduceLROnPlateau needs the monitored metric.
                manager.scheduler.step(val_loss)
            else:
                manager.scheduler.step()

        epochs_without_improvement += 1
        if val_iou > best_iou:
            best_iou = val_iou
            epochs_without_improvement = 0
            torch.save(manager.model.state_dict(), os.path.join(model_dir, 'model.pth'))
            print(f'=> saved best model (val IoU {best_iou:.4f})')

        if 0 <= config['early_stopping'] <= epochs_without_improvement:
            print('=> early stopping')
            break

        torch.cuda.empty_cache()


def _unpack_batch(data):
    """Move ``(image, mask)`` from a batch to the GPU; extra items (e.g. metadata) are ignored."""
    inputs, target = data[0], data[1]
    return inputs.cuda(), target.cuda()


def _loss_and_iou(model, criterion, inputs, target, deep_supervision):
    """Run the model and return ``(loss, iou)``.

    With deep supervision the model returns a list of outputs: the loss is
    averaged over all of them and IoU is measured on the last one.
    """
    if deep_supervision:
        outputs = model(inputs)
        loss = sum(criterion(output, target) for output in outputs) / len(outputs)
        iou = iou_score(outputs[-1], target)
    else:
        output = model(inputs)
        loss = criterion(output, target)
        iou = iou_score(output, target)
    return loss, iou


def train_epoch(train_loader, model, criterion, optimizer, config):
    """Train for one epoch and return the sample-weighted mean ``(loss, iou)``."""
    model.train()
    avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter()}
    with tqdm(total=len(train_loader), desc='Train') as pbar:
        for data in train_loader:
            inputs, target = _unpack_batch(data)
            loss, iou = _loss_and_iou(model, criterion, inputs, target,
                                      config['deep_supervision'])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            avg_meters['loss'].update(loss.item(), inputs.size(0))
            avg_meters['iou'].update(iou, inputs.size(0))
            pbar.update(1)
            pbar.set_postfix({'Loss': avg_meters['loss'].avg, 'IoU': avg_meters['iou'].avg})
    return avg_meters['loss'].avg, avg_meters['iou'].avg


def validate_epoch(val_loader, model, criterion, config):
    """Evaluate without gradients and return the sample-weighted mean ``(loss, iou)``."""
    model.eval()
    avg_meters = {'loss': AverageMeter(), 'iou': AverageMeter()}
    with torch.no_grad(), tqdm(total=len(val_loader), desc='Validate') as pbar:
        for data in val_loader:
            inputs, target = _unpack_batch(data)
            loss, iou = _loss_and_iou(model, criterion, inputs, target,
                                      config['deep_supervision'])

            avg_meters['loss'].update(loss.item(), inputs.size(0))
            avg_meters['iou'].update(iou, inputs.size(0))
            pbar.update(1)
            pbar.set_postfix({'Loss': avg_meters['loss'].avg, 'IoU': avg_meters['iou'].avg})
    return avg_meters['loss'].avg, avg_meters['iou'].avg


if __name__ == '__main__':
    main()
- val.py
import argparse
import os
from glob import glob
import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch
import torch.backends.cudnn as cudnn
import yaml
import albumentations as albu
from albumentations.core.composition import Compose
from sklearn.model_selection import train_test_split
from tqdm import tqdm

import archs
from dataset import CustomDataset
from metrics import iou_score
from utils import AverageMeter

# Usage: python val.py --name dsb2018_96_NestedUNet_woDS
# --name is required: it selects the folder models/<name>/ holding config.yml and model.pth.


def parse_args():
    """Parse the command line; ``--name`` selects the trained model folder."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', default=None, help='model name')
    return parser.parse_args()


def _batch_image_ids(meta):
    """Return the image ids of a batch.

    Handles both metadata forms a dataset may yield: a dict collated into
    ``{'img_id': [...]}`` or a plain sequence of ids.
    """
    if isinstance(meta, dict):
        return meta['img_id']
    return meta


def main():
    """Evaluate a trained model on the validation split and save its probability maps.

    Reads ``models/<name>/config.yml`` and ``models/<name>/model.pth``, prints the
    mean IoU, writes one PNG per image and class to ``outputs/<name>/<class>/``,
    and plots a few predictions from the last batch.
    """
    args = parse_args()
    if args.name is None:
        print("Error: You must specify the model name using the --name argument.")
        return

    model_dir = os.path.join('models', args.name)
    with open(os.path.join(model_dir, 'config.yml'), 'r') as f:
        # safe_load: the config is expected to hold only plain scalars (str/int/float/bool/None).
        config = yaml.safe_load(f)

    print('-' * 20)
    for key in config.keys():
        print('%s: %s' % (key, str(config[key])))
    print('-' * 20)

    cudnn.benchmark = True

    print("=> creating model %s" % config['arch'])
    model = archs.__dict__[config['arch']](config['num_classes'],
                                           config['input_channels'],
                                           config['deep_supervision'])
    model = model.cuda()

    # Same split as train.py (test_size=0.2, random_state=41) -> same validation images.
    img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))
    img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]
    _, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)

    # Load from the folder the config came from, so a renamed folder still works.
    model.load_state_dict(torch.load(os.path.join(model_dir, 'model.pth')))
    model.eval()

    val_transform = Compose([
        albu.Resize(config['input_h'], config['input_w']),
        albu.Normalize(),
    ])
    val_dataset = CustomDataset(
        img_ids=val_img_ids,
        img_dir=os.path.join('inputs', config['dataset'], 'images'),
        mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
        img_ext=config['img_ext'],
        mask_ext=config['mask_ext'],
        num_classes=config['num_classes'],
        transform=val_transform)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        drop_last=False)

    avg_meter = AverageMeter()
    output_dir = os.path.join('outputs', args.name)
    for c in range(config['num_classes']):
        os.makedirs(os.path.join(output_dir, str(c)), exist_ok=True)

    inputs = target = None  # stays None if the validation loader is empty
    with torch.no_grad():
        for inputs, target, meta in tqdm(val_loader, total=len(val_loader)):
            inputs = inputs.cuda()
            target = target.cuda()
            img_ids_in_batch = _batch_image_ids(meta)

            if config['deep_supervision']:
                output = model(inputs)[-1]
            else:
                output = model(inputs)

            iou = iou_score(output, target)
            avg_meter.update(iou, inputs.size(0))

            # The network emits logits; store sigmoid probabilities as 8-bit images.
            probs = torch.sigmoid(output).cpu().numpy()
            for prob, img_id in zip(probs, img_ids_in_batch):
                for c in range(config['num_classes']):
                    cv2.imwrite(os.path.join(output_dir, str(c), str(img_id) + '.png'),
                                (prob[c] * 255).astype('uint8'))

    print('IoU: %.4f' % avg_meter.avg)
    if inputs is not None:
        plot_examples(inputs, target, model, num_examples=3,
                      deep_supervision=config['deep_supervision'])

    torch.cuda.empty_cache()


def plot_examples(datax, datay, model, num_examples=6, deep_supervision=False, threshold=0.40):
    """Plot randomly chosen (image, thresholded prediction, target) triples.

    Args:
        datax: batch of input images, indexed as ``datax[i]`` -> (C, H, W).
        datay: batch of target masks, indexed the same way.
        model: segmentation network returning logits (a list of them when
            ``deep_supervision`` is True; the last one is used).
        num_examples: number of rows to draw (samples are picked with replacement).
        deep_supervision: whether ``model`` returns a list of outputs.
        threshold: probability threshold applied after the sigmoid; only the
            first channel is shown.
    """
    # squeeze=False keeps ax 2-D even when num_examples == 1.
    _, ax = plt.subplots(nrows=num_examples, ncols=3, figsize=(18, 4 * num_examples),
                         squeeze=False)
    m = datax.shape[0]
    with torch.no_grad():
        for row_num in range(num_examples):
            image_indx = np.random.randint(m)
            output = model(datax[image_indx:image_indx + 1])
            if deep_supervision:
                output = output[-1]
            probs = torch.sigmoid(output).squeeze(0).cpu().numpy()

            ax[row_num][0].imshow(np.transpose(datax[image_indx].cpu().numpy(), (1, 2, 0))[:, :, 0])
            ax[row_num][0].set_title("Original Image")
            ax[row_num][1].imshow((probs[0] > threshold).astype(int))
            ax[row_num][1].set_title("Segmented Image localization")
            ax[row_num][2].imshow(np.transpose(datay[image_indx].cpu().numpy(), (1, 2, 0))[:, :, 0])
            ax[row_num][2].set_title("Target image")
    plt.show()


if __name__ == '__main__':
    main()
- archs.py
import torch
from torch import nn

__all__ = ['UNet', 'NestedUNet']


class VGGBlock(nn.Module):
    """Basic computing unit: two stacked ``3x3 conv -> BatchNorm -> ReLU`` stages.

    ``padding=1`` keeps the spatial size; only the channel count changes.

    Args:
        in_channels: channels of the input feature map.
        middle_channels: channels produced by the first convolution.
        out_channels: channels produced by the second convolution.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.bn2(self.conv2(out)))


class UNet(nn.Module):
    """Plain U-Net with four 2x down/up-sampling stages.

    Input height and width must be divisible by 16 so the upsampled decoder
    features line up with the encoder skips in ``torch.cat``.

    Args:
        num_classes: channels of the output logit map.
        input_channels: channels of the input image.
        deep_supervision: accepted so every arch shares the
            ``(num_classes, input_channels, deep_supervision)`` constructor used by
            train.py/val.py; U-Net has a single output, so True is rejected.

    Raises:
        ValueError: if ``deep_supervision`` is True.
    """

    def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
        super().__init__()
        if deep_supervision:
            raise ValueError('UNet does not support deep supervision; use NestedUNet instead.')

        nb_filter = [32, 64, 128, 256, 512]

        self.pool = nn.MaxPool2d(2, 2)
        # Bilinear interpolation that doubles H and W.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Encoder (column 0)
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        # Decoder: each block consumes [skip, upsampled deeper feature].
        self.conv3_1 = VGGBlock(nb_filter[3] + nb_filter[4], nb_filter[3], nb_filter[3])
        self.conv2_2 = VGGBlock(nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv1_3 = VGGBlock(nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv0_4 = VGGBlock(nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0])

        self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def forward(self, input):
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(self.pool(x0_0))
        x2_0 = self.conv2_0(self.pool(x1_0))
        x3_0 = self.conv3_0(self.pool(x2_0))
        x4_0 = self.conv4_0(self.pool(x3_0))

        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        x2_2 = self.conv2_2(torch.cat([x2_0, self.up(x3_1)], 1))
        x1_3 = self.conv1_3(torch.cat([x1_0, self.up(x2_2)], 1))
        x0_4 = self.conv0_4(torch.cat([x0_0, self.up(x1_3)], 1))

        # Raw logits (no activation applied here).
        return self.final(x0_4)


class NestedUNet(nn.Module):
    """UNet++: U-Net whose skip paths are nested, densely connected conv blocks.

    Node ``x{i}_{j}`` sits at depth ``i`` (``H / 2**i`` resolution,
    ``nb_filter[i]`` channels) and is the ``j``-th block on that level. For a
    96x96 input with batch 8, ``x0_0`` is (8, 32, 96, 96) and ``x1_0`` is
    (8, 64, 48, 48). Input height and width must be divisible by 16.

    Args:
        num_classes: channels of each output logit map.
        input_channels: channels of the input image.
        deep_supervision: if True, ``forward`` returns the four logit maps
            computed from ``x0_1 .. x0_4``; otherwise a single map from ``x0_4``.
        verbose: if True, print every intermediate tensor shape on each forward
            pass (a debugging aid; off by default so training output stays readable).
    """

    def __init__(self, num_classes, input_channels=3, deep_supervision=False,
                 verbose=False, **kwargs):
        super().__init__()

        nb_filter = [32, 64, 128, 256, 512]

        self.deep_supervision = deep_supervision
        self.verbose = verbose

        self.pool = nn.MaxPool2d(2, 2)
        # Bilinear interpolation that doubles H and W.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Backbone (column 0)
        self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
        self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
        self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
        self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
        self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])

        # Nested nodes: x{i}_{j} concatenates x{i}_0 .. x{i}_{j-1} with up(x{i+1}_{j-1}),
        # hence j * nb_filter[i] + nb_filter[i+1] input channels.
        self.conv0_1 = VGGBlock(nb_filter[0] + nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_1 = VGGBlock(nb_filter[1] + nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_1 = VGGBlock(nb_filter[2] + nb_filter[3], nb_filter[2], nb_filter[2])
        self.conv3_1 = VGGBlock(nb_filter[3] + nb_filter[4], nb_filter[3], nb_filter[3])

        self.conv0_2 = VGGBlock(nb_filter[0] * 2 + nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_2 = VGGBlock(nb_filter[1] * 2 + nb_filter[2], nb_filter[1], nb_filter[1])
        self.conv2_2 = VGGBlock(nb_filter[2] * 2 + nb_filter[3], nb_filter[2], nb_filter[2])

        self.conv0_3 = VGGBlock(nb_filter[0] * 3 + nb_filter[1], nb_filter[0], nb_filter[0])
        self.conv1_3 = VGGBlock(nb_filter[1] * 3 + nb_filter[2], nb_filter[1], nb_filter[1])

        self.conv0_4 = VGGBlock(nb_filter[0] * 4 + nb_filter[1], nb_filter[0], nb_filter[0])

        if self.deep_supervision:
            # One 1x1 head per top-row node x0_1 .. x0_4.
            self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)

    def _trace(self, name, tensor):
        """Print ``name: shape`` when ``verbose`` is enabled."""
        if self.verbose:
            print(f'{name}:', tensor.shape)

    def forward(self, input):
        self._trace('input', input)

        # The graph is evaluated diagonal by diagonal: each new backbone node
        # x{k}_0 unlocks the nested nodes x{k-1}_1, x{k-2}_2, ... above it.
        x0_0 = self.conv0_0(input)
        self._trace('x0_0', x0_0)
        x1_0 = self.conv1_0(self.pool(x0_0))
        self._trace('x1_0', x1_0)
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1))
        self._trace('x0_1', x0_1)

        x2_0 = self.conv2_0(self.pool(x1_0))
        self._trace('x2_0', x2_0)
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
        self._trace('x1_1', x1_1)
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
        self._trace('x0_2', x0_2)

        x3_0 = self.conv3_0(self.pool(x2_0))
        self._trace('x3_0', x3_0)
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
        self._trace('x2_1', x2_1)
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
        self._trace('x1_2', x1_2)
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))
        self._trace('x0_3', x0_3)

        x4_0 = self.conv4_0(self.pool(x3_0))
        self._trace('x4_0', x4_0)
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
        self._trace('x3_1', x3_1)
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
        self._trace('x2_2', x2_2)
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
        self._trace('x1_3', x1_3)
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))
        self._trace('x0_4', x0_4)

        # Outputs are raw logits (no sigmoid is applied inside the network).
        if self.deep_supervision:
            return [self.final1(x0_1), self.final2(x0_2),
                    self.final3(x0_3), self.final4(x0_4)]
        return self.final(x0_4)
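archs解读
conv0_0 代表图中的 X^{0,0},conv2_0 代表图中的 X^{2,0},以此类推:conv{i}_{j} 中的 i 表示下采样的层级,j 表示该层级上卷积块的序号。
每一个 VGGBlock 由两组“3×3 卷积(padding=1)+ BatchNorm + ReLU”组成,只改变通道数,不改变特征图的空间尺寸;下采样由 2×2 最大池化完成,上采样由放大 2 倍的双线性插值完成。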
3.3 结果
4. 参考链接
深度学习分割任务——Unet++分割网络代码详细解读