【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割1(综述篇)

在上一个关于3D 目标的任务,是基于普通CNN网络的3D分类任务。在这个任务中,分类数据采用的是CT结节的LIDC-IDRI数据集,其中对结节的良恶性、毛刺、分叶征等等特征进行了各自的等级分类。感兴趣的可以直接点击下方的链接,直达学习:

  1. 【3D图像分类】基于Pytorch的3D立体图像分类1(基础篇)
  2. 【3D图像分类】基于Pytorch的3D立体图像分类2(数据增强篇)

在开始本次关于3D 目标的分割任务前呢,我还是建议先去看看上述较为简单的分类任务,毕竟大多数是相似的,有很高的借鉴意义。

一、导言

准备一个训练,需要下面这些内容组成:

  1. 准备数据
  2. 准备网络
  3. 搭建训练主模型
    • train one epoch
    • valid one epoch
    • 存储模型
    • 存储指标
  4. loss 函数
  5. dice coeff 评估指标
  6. optimizer优化方式

其中,在本项目中:

  1. 网络采用vnet 3d模型
  2. 数据采用patch裁剪大小
  3. loss函数未dice loss
  4. 评价指标是dice coeff
  5. optimizer优化方式是SGD

二、搭建主结构

训练的主体结构(骨架),总数包括几个部分:

  1. config:可调参数定义,包括数据路径、图像大小、类别数量、学习率、batch size等等;
  2. main:主函数,包括:
    • 构建模型
    • 构建数据
    • 优化器
    • 学习率变化方式
    • 损失函数
    • 评估指标
    • 训练batch循环
    • 验证batch循环
  3. 后处理:包括模型参数存储,指标走势绘图等等。

上面这些个内容,基本上是囊括了深度学习模型训练的整体结构了,后面的工作就是对每一部分进行补充。就犹如已经有了骨架,后续就是补充肉身了。

后面给出的这个pytorch骨架案例,也是后面再构建训练任务,一个可以参考的依据,可收藏。

2.1、导入库和配置参数

import os
import matplotlib.pyplot as plt
import torch.utils.data
import torch.optim as optimfrom datasets.datasets import myDatasetos.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3"  # 使用gpu0
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 没gpu就用cpu
print(DEVICE)############################################################
# Configuration
############################################################
class Configuration(object):train_path = r"./database/sk_output/train"valid_path = r"./database/sk_output/valid"model_path = r'./checkpoints'Crop_Size = (48, 96, 96)num_outs = 2Batch_Train = 32Batch_Test = 16Max_epoch = 220Num_Workers = 8Dice_Best = 0LR = 0.0003momentum = 0.99weight_decay = 1e-8def display(self):"""Display Configuration values."""print("\nConfigurations:")print("")for a in dir(self):if not a.startswith("__") and not callable(getattr(self, a)):print("{:30} {}".format(a, getattr(self, a)))print("\n")

2.2、构建main主函数

def main():Config = Configuration()Config.display()train_loader, valid_loader = get_Dataloader(Config)model = get_model(Config).to(DEVICE)# ---- OPTIMIZER ----optimizer = optim.SGD(model.parameters(), lr=Config.LR, momentum=Config.momentum, weight_decay=Config.weight_decay)train_loss_list = []  # 用来记录训练损失valid_loss_list = []  # 用来记录验证损失valid_dice_list = []epoch_list = []for epoch in range(1, Config.Max_epoch + 1):epoch_list.append(epoch)train_loss = train_model(model, DEVICE, train_loader, optimizer, epoch)  # 训练valid_loss, valid_dice = valid_model(model, DEVICE, valid_loader, epoch)  # 验证train_loss_list.append(train_loss)valid_loss_list.append(valid_loss)valid_dice_list.append(valid_dice)draw_plot(epoch_list, valid_dice_list, 'valid_dice')draw_plot(epoch_list, valid_loss_list, 'valid_loss')draw_plot(epoch_list, train_loss_list, 'train_loss')if valid_dice > Config.Dice_Best:path_ckpt = os.path.join(Config.model_path, 'best_model.pth')save_model(path_ckpt, model)Config.Dice_Best = valid_diceelse:path_ckpt = os.path.join(Config.model_path, 'last_model.pth')save_model(path_ckpt, model)print('best val Dice is ', Config.Dice_Best)if __name__ == '__main__':main()

2.3、构建获取模型和数据的函数

def get_model(config):from models.vnet3d import VNet3Dmodel = VNet3D(num_outs=config.num_outs, channels=16)model = model.to(DEVICE)  # 模型部署到gpu或cpu里model = torch.nn.DataParallel(model).to(DEVICE)return modeldef get_Dataloader(config):# get train datadataset_train = myDataset(config.train_path, config.Crop_Size, isTrain=True)print(len(dataset_train))train_loader = torch.utils.data.DataLoader(dataset_train,batch_size=config.Batch_Train, shuffle=True,num_workers=config.Num_Workers, drop_last=False)# get valid datadataset_valid = myDataset(config.valid_path, config.Crop_Size, isTrain=False)valid_loader = torch.utils.data.DataLoader(dataset_valid,batch_size=config.Batch_Test, shuffle=False,num_workers=config.Num_Workers, drop_last=False)return train_loader, valid_loader

2.4、构建训练循环和验证循环

def train_model(model, device, train_loader, optimizer, epoch):config = Configuration()model.train()for batch_index, (data, target) in enumerate(train_loader):  # 取batch索引,(data,target),也就是图和标签data, target = data.to(device), target.to(device)output = model(data)loss = Loss(output, target)optimizer.zero_grad()  # 梯度归零loss.backward()  # 反向传播optimizer.step()  # 优化器走一步return losses.avg  # 返回平均损失,损失列表def valid_model(model, device, test_loader, epoch):config = Configuration()model.eval()with torch.no_grad():  # 不进行 梯度计算(反向传播)for batch_index, (data, target) in enumerate(test_loader):  # 枚举batch索引,(图,标签)data, target = data.to(device), target.to(device)output = model(data)loss = Loss(output, target)  # 计算损失return losses.avg, multi_dices.avg

2.5、后处理

保存模型的参数,和绘制训练过程中train loss、valid loss,以及valid dice走势图,如下:

def draw_plot(x_list, y_list, title_name):plt.plot(x_list, y_list, label=title_name)plt.xlabel('x', fontsize=15)plt.ylabel('y', fontsize=15)plt.title(title_name, fontsize=15)plt.savefig('./logs/cure.png')def save_model(path, model):if isinstance(model, torch.nn.DataParallel):state_dict = model.module.state_dict()else:state_dict = model.state_dict()torch.save(state_dict, path)

至此,每一个模块都有了对应的归宿,后面就是如何将缺漏的地方,补全过程了。反倒是这部分的代码相对较少,两大需要单独验证的数据和模型是大头,其他就好办了。

三、总结

本文是关于PytorchVNet 3D 图像分割的第一篇,也就是一个综述篇,主要是对这个项目的任务目的,以及其中的一个流程进行了梳理。

上述的骨干代码还不能够作为训练使用,还需要补充进去骨肉,才能够适应不同的任务,这一块的内容将会在后面的几个篇章中,一一陈述。

如果你也在做类似的事情,欢迎点赞、收藏,mark住。对于这部分的内容可以一起交流,欢迎多多评论。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/114593.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

在 Android 10 中访问/proc/net/route权限被拒绝

在 Android 10 中访问/proc/net/route权限被拒绝 问题分析完整代码问题 FileReader fr = new FileReader(“/proc/net/route”);在 Android 10 中访问/proc/net/route权限被拒绝 分析 运行/proc/net/route命令并处理其输出: val runtime = Runtime.getRuntime() val proc …

Mysql数据库 2.SQL语言 数据类型与字段约束

Mysql数据类型 数据类型:指的是数据表中的列文件支持存放的数据类型 1.数值类型 Mysql当中有多种数据类型可以存放数值,不同的类型存放的数值的范围或者形式是不同的 注:前三种数字类型我们在实际研发中用的很少,一般整数类型…

[论文笔记]NEZHA

引言 今天带来华为诺亚方舟实验室提出的论文NEZHA,题目是 针对中文中文语言理解神经网络上下文表示(NEural contextualiZed representation for CHinese lAnguage understanding),为了拼出哪吒。 预训练语言模型由于具有通过对大型语料库进行预训练来捕获文本中深层上下文信…

【每日一题Day352】LC1726同积元组 | 哈希表+排列组合

同积元组【LC1726】 给你一个由 不同 正整数组成的数组 nums ,请你返回满足 a * b c * d 的元组 (a, b, c, d) 的数量。其中 a、b、c 和 d 都是 nums 中的元素,且 a ! b ! c ! d 。 思路 求出所有二元组的积及其出现次数,假设某个积出现的次…

空中计算(Over-the-Air Computation)学习笔记

文章目录 写在前面 写在前面 本文是论文A Survey on Over-the-Air Computation的阅读笔记: 通信和计算通常被视为独立的任务。 从工程的角度来看,这种方法是非常有效的,因为可以执行孤立的优化。 然而,对于许多面向计算的应用程序…

Docker镜像制作

目录 Dockfile是什么 构建镜像的三个步骤 dockerfile内容基础知识 docker执行一个Dockerfile脚本的大致流程 Dockerfile指令 FROM MAINTAINER RUN EXPOSE WORKDIR ENV ADD COPY VOLUME USER ONBUILD CMD ENTRYPOINT CMD和ENTRYPOINT区别 构建dockerfile Do…

shell之常见网络命令介绍

shell之常见网络命令介绍 1)ifconfig 用于配置网络接口。可以用于开启、关闭和设置网络接口的参数,如IP地址、子网掩码、MAC地址等。 ifconfig eth0 192.168.1.1 netmask 255.255.255.0 up上述命令将设置eth0网络接口的IP地址为192.168.1.1,子…

leetcode(2)栈

leetcode 155 最小栈 stack相当于栈,先进后出 存储全部栈元素 [-3,2,-1] min_stack,存储栈当前位置最小的元素 [-3,-3,-3] class MinStack:def __init__(self):self.stack []self.min_stack [math.inf]def push(self, x: int) :self.stack.append(x)self.min_sta…

游戏反虚拟框架检测方案

游戏风险环境,是指独立于原有设备或破坏设备原有系统的环境。常见的游戏风险环境有:iOS越狱、安卓设备root、虚拟机、虚拟框架、云手机等。 因为这类风险环境可以为游戏外挂、破解提供所需的高级别设备权限,所以当游戏处于这些设备环境下&am…

ARM可用的可信固件项目简介

安全之安全(security)博客目录导读 目录 一、TrustedFirmware-A (TF-A) 二、MCUboot 三、TrustedFirmware-M (TF-M) 四、TF-RMM 五、OP-TEE 六、Mbed TLS 七、Hafnium 八、Trusted Services 九、Open CI 可信固件为Armv8-A、Armv9-A和Armv8-M提供了安全软件的参考实现…

【UE5】 ListView使用DataTable数据的蓝图方法

【UE5】 ListView使用DataTable数据的蓝图方法 ListView 是虚幻引擎中的一种用户界面控件,用于显示可滚动的列表。它可以用于显示大量的数据,并提供了各种功能和自定义选项来满足不同的需求。 DataTable是虚幻引擎中的一种数据表格结构,用于存…

Vue Router - 路由的使用、两种切换方式、两种传参方式、嵌套方式

目录 一、Vue Router 1.1、下载 1.2、基本使用 a)引入 vue-router.js(注意:要在 Vue.js 之后引入). b)创建好路由规则 c)注册到 Vue 实例中 d)展示路由组件 1.3、切换路由的两种方式 1.…

ubuntu20.04 nvidia显卡驱动掉了,变成开源驱动,在软件与更新里选择专有驱动,下载出错,调整ubuntu镜像源之后成功修复

驱动配置好,环境隔了一段时间,打开Ubuntu发现装好的驱动又掉了,软件与更新 那里,附加驱动,显示开源驱动,命令行输入 nvidia-smi 命令查找不到驱动。 点击上面的 nvidia-driver-470(专有&#x…

wps excel js编程

定义全局变量 const a "dota" function test() {Debug.Print(a) }获取表格中单元格内容 function test() {Debug.Print("第一行第二列",Cells(1,2).Text)Debug.Print("A1:",Range("A1").Text) }写单元格 Range("C1").Val…

【前端设计模式】之状态模式

引言 在前端开发中,我们经常需要处理复杂的应用状态。这时候,状态模式就能派上用场了。状态模式允许我们根据不同的状态来改变对象的行为,从而实现优雅地管理应用状态。 状态模式的特性 状态模式具有以下特性: 状态&#xff0…

【面试经典150 | 栈】有效的括号

文章目录 Tag题目来源题目解读解题思路方法一:栈哈希表 其他语言cpython3 写在最后 Tag 【栈】 题目来源 20. 有效的括号 题目解读 括号有三种类型,分别是小括号、中括号和大括号,每种括号的左右两半括号必须一一对应才是有效的括号&#…

JetBrains系列IDE全家桶激活

jetbrains全家桶 正版授权,这里有账号授权的渠道: https://www.mano100.cn/thread-1942-1-1.html 附加授权后的一张图片

Spark---数据输出

1. 输出为Python对象 collect算子:将RDD各个分区内的数据,统一收集到Driver中,形成一个List对象 reduce算子:对RDD数据集按照传入的逻辑进行聚合 take算子:取RDD的前N个元素,组合成list返回给你 count…

HCIA -- 动态路由协议之RIP

一、静态协议的优缺点: 缺点: 1、中大型网络配置量过大 2、不能基于拓扑的变化而实时的变化 优点: 1、不会额外暂用物理资源 2、安全问题 3、计算路径问题 简单、小型网络建议使用静态路由;中大型较复杂网络,建议使用…

【MySQL】8.0新特性、窗口函数和公用表表达式

文章目录 1. 新增特性2. 移除旧特性2.1 优点2.2 缺点 3. 新特性1:窗口函数3.1 使用窗口函数前后对比3.2 窗口函数分类3.3 语法结构3.4 分类讲解3.4.1 序号函数3.4.1.1 ROW_NUMBER()函数3.4.1.2 RANK()函数3.4.1.3 DENSE_RANK()函数 3.4.2 分布函数3.4.2.1 PERCENT_R…