图像分割实战-系列教程4:unet医学细胞分割实战2(医学数据集、图像分割、语义分割、unet网络、代码逐行解读)

🍁🍁🍁图像分割实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传

上篇内容:
unet医学细胞分割实战1
下篇内容:
unet医学细胞分割实战3

3、指定训练参数

"""
指定参数:
--dataset dsb2018_96 
--arch NestedUNet
"""

指定数据集和网络架构的参数后,执行train.py

4、train.py主函数解析

4.1 读取配置文件

def main():config = vars(parse_args())if config['name'] is None:if config['deep_supervision']:config['name'] = '%s_%s_wDS' % (config['dataset'], config['arch'])else:config['name'] = '%s_%s_woDS' % (config['dataset'], config['arch'])os.makedirs('models/%s' % config['name'], exist_ok=True)print('-' * 20)for key in config:print('%s: %s' % (key, config[key]))print('-' * 20)with open('models/%s/config.yml' % config['name'], 'w') as f:yaml.dump(config, f)
  1. main函数
  2. 解析命令行参数为字典
  3. 检查 config[‘name’] 是否为 None,如果是
  4. 它根据 config[‘deep_supervision’] 的布尔值来设置 config[‘name’], 如果config[‘deep_supervision’] 的值为True
  5. config[‘dataset’] 和 config[‘arch’] 的值,并在末尾添加 ‘_wDS’(表示“with Deep Supervision”)
  6. 如果为False,末尾则添加 ‘_woDS’(表示“without Deep Supervision”)
  7. 使用 config[‘name’] 来创建一个新目录。这个目录位于 ‘models/’ 目录下,目录名是 config[‘name’] 的值,exist_ok=True 参数的意思是如果目录已经存在,则不会抛出错误
  8. 打印符号
  9. 打印所有配置参数的名字和默认值
  10. 打印符号
  11. 根据模型名称创建一个.yaml文件
  12. 把所有配置信息全部写入文件中

4.2 定义模型参数

if config['loss'] == 'BCEWithLogitsLoss':criterion = nn.BCEWithLogitsLoss().cuda()#WithLogits 就是先将输出结果经过sigmoid再交叉熵
else:criterion = losses.__dict__[config['loss']]().cuda()
cudnn.benchmark = True
print("=> creating model %s" % config['arch'])
model = archs.__dict__[config['arch']](config['num_classes'], config['input_channels'], config['deep_supervision'])
model = model.cuda()
params = filter(lambda p: p.requires_grad, model.parameters())
  1. 定义损失函数,如果损失函数的配置的默认字符参数为BCEWithLogitsLoss
  2. 那么使用 PyTorch 中的 nn.BCEWithLogitsLoss 作为损失函数,并且将损失函数的计算移入到GPU中计算,加快速度
  3. 如果不是
  4. 则从 losses.__dict__ 中查找对应的损失函数,同样使用 .cuda() 方法将损失函数移动到 GPU。(losses.__dict__ 应该是一个包含了多种损失函数的字典,其中键是损失函数的名称,值是相应的损失函数类,这个类是我们自己写的,在后面会解析)
  5. 启用 CUDA 深度神经网络(cuDNN)的自动调优器,当设置为 True 时,cuDNN 会自动寻找最适合当前配置的算法来优化运行效率,这在使用固定尺寸的输入数据时往往可以加快训练速度
  6. 打印当前创建的模型的名字
  7. 动态实例化一个模型, archs 是一个包含多个网络架构的模块, archs.__dict__[config['arch']] 这部分代码通过查找 archs 对象的 __dict__ 属性来动态地选择一个网络架构, __dict__ 是一个包含对象所有属性的字典。在这里,它被用来获取名为 config['arch'] 的网络架构类,config['arch'] 是一个字符串,表示所选用的架构名称
  8. 模型放入GPU中

4.3 定义优化器、调度器等参数

if config['optimizer'] == 'Adam':optimizer = optim.Adam(params, lr=config['lr'], weight_decay=config['weight_decay'])elif config['optimizer'] == 'SGD':optimizer = optim.SGD(params, lr=config['lr'], momentum=config['momentum'],nesterov=config['nesterov'], weight_decay=config['weight_decay'])else:raise NotImplementedErrorif config['scheduler'] == 'CosineAnnealingLR':scheduler = lr_scheduler.CosineAnnealingLR( optimizer, T_max=config['epochs'], eta_min=config['min_lr'])elif config['scheduler'] == 'ReduceLROnPlateau':scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'], patience=config['patience'],verbose=1, min_lr=config['min_lr'])elif config['scheduler'] == 'MultiStepLR':scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[int(e) for e in config['milestones'].split(',')], gamma=config['gamma'])elif config['scheduler'] == 'ConstantLR':scheduler = Noneelse:raise NotImplementedError
  1. 创建一个过滤器,它筛选出神经网络模型中所有需要梯度(即可训练的)参数, model.parameters(),返回模型的权重和偏置,lambda p: p.requires_grad: 这是一个匿名函数(lambda 函数),用于检查每个参数 p 是否需要梯度
  2. 如果优化器是 Adam
  3. 则指定参数、学习率、学习率衰减的参数给Adam
  4. 如果是SGD
  5. 则指定参数、学习率、学习率衰减的参数给SGD,此外还有momentum动量加速,此外还使用了一个自定义的类型转换函数 str2bool 来处理输入值
  6. 如果两者都不是
  7. 返回错误
  8. 如果学习率调度器为CosineAnnealingLR
  9. 给该调度器,指定优化器、epochs、最小学习率
  10. 如果是ReduceLROnPlateau
  11. 给该调度器,指定优化器、指定调整学习率时的乘法因子、指定在性能不再提升时减少学习率要等待多少周期、verbose=1: 这个设置意味着调度器会在每次更新学习率时打印一条信息、最小学习率
  12. 如果是MultiStepLR
  13. 给该调度器,指定优化器、何时降低学习率的周期数、gamma值
  14. 如果是ConstantLR
  15. 调度器为None
  16. 如果都不是
  17. 返回错误

4.4 数据增强

    img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)#数据增强:train_transform = Compose([transforms.RandomRotate90(),transforms.Flip(),OneOf([ transforms.HueSaturationValue(), transforms.RandomBrightness(), transforms.RandomContrast(), ], p=1),#按照归一化的概率选择执行哪一个transforms.Resize(config['input_h'], config['input_w']),transforms.Normalize(),])val_transform = Compose([transforms.Resize(config['input_h'], config['input_w']),transforms.Normalize(),])
  1. 从本地文件夹inputs,根据config[‘dataset’]的值选择一个数据集,然后images文件,*代表后面所有的文件名称,加上config[‘img_ext’]对应的后缀,返回一个列表,列表的每个元素都是每条数据的路径加文件名和后缀名组成的字符串,类似这种形式:[‘inputs/dataset_name/images/image1.png’, ‘inputs/dataset_name/images/image2.png’, ‘inputs/dataset_name/images/image3.png’]

  2. for p in img_ids按照每个字符串包含的信息,进行遍历,os.path.basename(p)从每个路径 p 中提取文件名,os.path.splitext(...)[0] 则从文件名中去除扩展名,留下文件的基本名称(即 ID),最后是一个只包含文件名的list,即:[‘image1’, ‘image2’, ‘image3’]

  3. 使用sklearn包的train_test_split函数,按照80%和20%的比例分为训练集和验证集,并且打乱数据集,41是随机种子

  4. 训练集数据增强

  5. 随机以 90 度的倍数旋转图像进行数据增强

  6. 水平或垂直翻转图像进行数据增强

  7. 从调整色调和饱和度和值(HSV)、随机调整图像的亮度、随机调整图像的对比度这个方式中随机选择一个进行数据增强

  8. 将图像调整到指定的高度和宽度

  9. 对图像进行标准化(比如减去均值,除以标准差)

  10. 验证集同样进行调整,是为了和训练集的尺寸、标准化等保存一致

  11. 调整和训练集一样的长宽

  12. 调整和训练一样的 标准化处理

4.5 数据集制作

 train_dataset = Dataset(img_ids=train_img_ids,img_dir=os.path.join('inputs', config['dataset'], 'images'),mask_dir=os.path.join('inputs', config['dataset'], 'masks'),img_ext=config['img_ext'],mask_ext=config['mask_ext'],num_classes=config['num_classes'],transform=train_transform)
val_dataset = Dataset(img_ids=val_img_ids,img_dir=os.path.join('inputs', config['dataset'], 'images'),mask_dir=os.path.join('inputs', config['dataset'], 'masks'),img_ext=config['img_ext'],mask_ext=config['mask_ext'],num_classes=config['num_classes'],transform=val_transform)
  1. 使用自己写的数据集类制作训练数据集
  2. 返回图像数据id
  3. 返回图像数据路径
  4. 返回掩码数据路径
  5. 返回后缀
  6. 返回掩码后缀
  7. 分类的种类
  8. 数据增强(这里制定为None)
  9. 同样的给验证集也来一遍
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=config['batch_size'],shuffle=True,num_workers=config['num_workers'],drop_last=True)
val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=config['batch_size'],shuffle=False,num_workers=config['num_workers'],drop_last=False)log = OrderedDict([ ('epoch', []), ('lr', []), ('loss', []), ('iou', []), ('val_loss', []), ('val_iou', []), ])
  1. 制作训练集Dataloader
  2. 指定训练数据集
  3. batch_size
  4. 洗牌操作
  5. 进程数
  6. 不能整除的batch是否就不要了
  7. 同样的给验证集也来一遍
  8. 最后一行日志记录:创建OrderedDict 对象 log,将’epoch’、‘lr’、‘loss’、‘iou’、‘val_loss’、'val_iou’按照类似字典的形式进行存储(与字典不同的是它会记住插入元素的顺序)

4.6 迭代训练

    best_iou = 0trigger = 0for epoch in range(config['epochs']):print('Epoch [%d/%d]' % (epoch, config['epochs']))train_log = train(config, train_loader, model, criterion, optimizer)val_log = validate(config, val_loader, model, criterion)if config['scheduler'] == 'CosineAnnealingLR':scheduler.step()elif config['scheduler'] == 'ReduceLROnPlateau':scheduler.step(val_log['loss'])print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f'% (train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou']))log['epoch'].append(epoch)log['lr'].append(config['lr'])log['loss'].append(train_log['loss'])log['iou'].append(train_log['iou'])log['val_loss'].append(val_log['loss'])log['val_iou'].append(val_log['iou'])pd.DataFrame(log).to_csv('models/%s/log.csv' % config['name'], index=False)trigger += 1if val_log['iou'] > best_iou:torch.save(model.state_dict(), 'models/%s/model.pth' % config['name'])best_iou = val_log['iou']print("=> saved best model")trigger = 0if config['early_stopping'] >= 0 and trigger >= config['early_stopping']:print("=> early stopping")breaktorch.cuda.empty_cache()
  1. 记录最好的IOU的值
  2. trigger 是一个计数器,用于追踪自从模型上次改进(即达到更好的验证 IoU)以来经过了多少个训练周期(epochs),这种技术通常用于实现早停(early stopping)机制,以避免过度拟合
  3. 按照epochs进行迭代训练
  4. 打印当前epochs数,即训练进度
  5. 使用训练函数进行单个epoch的训练
  6. 使用验证函数进行单个epoch的验证
  7. 当使用 CosineAnnealingLR 调度器时
  8. scheduler.step() 被直接调用,无需任何参数
  9. 当使用 ReduceLROnPlateau 调度器时
  10. scheduler.step(val_log['loss']) 调用时传入了验证集的损失 val_log[‘loss’] 作为参数
  11. 打印当前epoch训练损失、训练iou
  12. 打印当前epoch验证损失、验证iou
  13. 当前epoch索引加入日志字典中
  14. 当前学习率值加入日志字典中
  15. 当前训练损失加入日志字典中
  16. 当前训练iou加入日志字典中
  17. 当前验证损失加入日志字典中
  18. 当前验证iou加入日志字典中
  19. 当前日志信息保存为csv文件
  20. trigger +1
  21. 如果当前验证iou的值比当前记录最佳iou的值要好
  22. 保存当前模型文件
  23. 更新最佳iou的值
  24. 打印保存了当前的最好模型
  25. 把trigger 置0
  26. 如果当前记录的trigger的值大于提前设置的trigger阈值
  27. 打印提前停止
  28. 停止训练
  29. 清除GPU缓存

自此,train.py的main函数部分全部解读完毕,其中有多个子函数或者类,在下一篇文章中继续解读
上篇内容:
unet医学细胞分割实战1
下篇内容:
unet医学细胞分割实战3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/588819.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

再薅!Pika全球开放使用;字节版GPTs免费不限量;大模型应用知识地图;MoE深度好文;2024年AIGC发展轨迹;李飞飞最新自传 | ShowMeAI日报

👀日报&周刊合集 | 🎡生产力工具与行业应用大全 | 🧡 点赞关注评论拜托啦! 👀 终于!AI视频生成平台 Pika 面向所有用户开放网页端 https://twitter.com/pika_labs Pika 营销很猛,讲述的「使…

qt中信号槽第五个参数

文章目录 connent函数第五个参数的作用自动连接(Qt::AutoConnection)直接连接(Qt::DirectConnection - 同步)同线程不同线程 队列连接(Qt::QueuedConnection - 异步)同一线程不同线程 锁定队列连接(Qt::BlockingQueuedConnection) connent函数第五个参数的作用 connect(const …

LSTM Siamese neural network

本文中的代码在Github仓库或Gitee仓库中可找到。 Hi, 你好。我是茶桁。 大家是否还记得,在「核心基础」课程中,我们讲过CNN以及LSTM。 卷积神经网络(CNN)已经在计算机视觉处理中得到广泛应用,不过,2017年…

Shell脚本自动化部署LAMP环境

[rootlocalhost ~]# vim liang.sh #!/bin/bash# LAMP终极部署cat <<-EOF-------------------------------------------------------------------------| LAMP终极部署 V1.0 |-------------------------------------------------------------------------| a. 部署Apache服…

Go 泛型之明确使用时机与泛型实现原理

Go 泛型之明确使用时机与泛型实现原理 文章目录 Go 泛型之明确使用时机与泛型实现原理一、引入二、何时适合使用泛型&#xff1f;场景一&#xff1a;编写通用数据结构时场景二&#xff1a;函数操作的是 Go 原生的容器类型时场景三&#xff1a;不同类型实现一些方法的逻辑相同时…

pycharm python环境安装

目录 1.Python安装 2.PyQt5介绍 3.安装pyuic 4.启动designer.exe 5.pyinstaller(打包发布程序) 6.指定源安装 7.PyQt5-tools安装失败处理 8.控件介绍 9.错误记录 1.NameError: name reload is not defined 10.开发记录 重写报文输出和文件 ​编辑 1.Python安装 点…

docker里面不能使用vim的解决办法

docker里面不能使用vim的解决办法 目录 docker里面不能使用vim的解决办法 1.在使用时会出现 2.在使用这些都不能解决的时候考虑 3.测试是否可用 1.在使用时会出现 bash: vim: command not found 出现这种错误时首先考虑使用 apt-get update 然后在用 apt-get install …

Oracle中decode函数详解

Oracle中decode函数详解 大家好&#xff0c;我是免费搭建查券返利机器人赚佣金就用微赚淘客系统3.0的小编&#xff0c;也是冬天不穿秋裤&#xff0c;天冷也要风度的程序猿&#xff01;今天&#xff0c;我们将深入探讨Oracle数据库中的DECODE函数&#xff0c;这是一种强大的条件…

大模型入门0: 基础知识

transformerscaling law分布式训练 自然语言处理包括几大任务 NLP: 文本分类&#xff0c;词性标注&#xff0c;信息检索NLG&#xff1a;机器翻译&#xff0c;自动摘要&#xff0c;问答QA、对话机器ChatBot Transformer T5 Bert GPT in context learning: (zero-shot tra…

vue3中pinia的使用及持久化(详细解释)

解释一下pinia&#xff1a; Pinia是一个基于Vue3的状态管理库&#xff0c;它提供了类似Vuex的功能&#xff0c;但是更加轻量化和简单易用。Pinia的核心思想是将所有状态存储在单个store中&#xff0c;并且将store的行为和数据暴露为可响应的API&#xff0c;从而实现数据&#…

cnn lstm结合网络

目录 特征处理例子&#xff1a; cnn 5张图片一组&#xff0c;提取特征后&#xff0c;再给lstm&#xff0c;进时间序列分类。 特征处理例子&#xff1a; import torch# 假设 tensor 是形状为 15x64 的张量 tensor torch.arange(15 * 2).reshape(15, 2) # 生成顺序编号的张量&…

中国历史长河图

历史是一种传承和记忆&#xff0c;不管你是否承认&#xff0c;他就在那里。你也身处其中&#xff0c;就像一条小鱼身处波澜壮阔的大河中&#xff0c;没留下一点痕迹。 了解历史&#xff0c;不是只为了多知道些古代人物、历史事件&#xff0c;或者为了应付考试。而是应该想到&am…

2024年元旦,祝福所有的人和事物

愿风调雨顺&#xff0c;国泰民安。 愿人生平安健康&#xff0c;安居乐业&#xff0c;福慧增长&#xff0c;丰足富饶。 愿我们能一起进步。

今年努力输出的嵌入式Linux视频

今年努力了一波&#xff0c;几个月周六日无休&#xff0c;自己在嵌入式linux工作有些年头&#xff0c;结合自己也是一直和SLAM工程师对接&#xff0c;所以输出了一波面向SLAM算法工程师Linux课程&#xff0c;当然嵌入式入门的同学也可以学习。下面是合作的官方前面发的宣传文章…

【c++】使用vector存放键值对时,明明给vector的不同键赋了不同的值,但为什么前面键的值会被后面键的值给覆盖掉?

错误描述 运行程序得到结果如下图所示&#xff08;左边是原始数据&#xff0c;xxml文件中真实数据的样子&#xff0c;右图是程序运行得到的结果结果&#xff09;&#xff1a; 对比以上两图可以发现&#xff0c;右图中两个实例的三个属性值都来自左图中的第二个User实例&#x…

【模拟电路】软件Circuit JS

一、模拟电路软件Circuit JS 二、Circuit JS软件配置 三、Circuit JS 软件 常见的快捷键 四、Circuit JS软件基础使用 五、Circuit JS软件使用讲解 欧姆定律电阻的串联和并联电容器的充放电过程电感器和实现理想超导的概念电容阻止电压的突变&#xff0c;电感阻止电流的突变LR…

一二三应用开发平台文件处理设计与实现系列之3——后端统一封装设计与实现

背景 前面介绍了前端通过集成vue-simple-uploader实现了文件的上传&#xff0c;今天重点说一下后端的设计与实现。 功能需求梳理 从功能角度而言&#xff0c;实际主要就两项&#xff0c;一是上传&#xff0c;二是下载。其中上传在文件体积较大的情况下&#xff0c;为了加快上…

vue3 element plus el-table封装(二)

上文是对el-table的基本封装&#xff0c;只能满足最简单的应用&#xff0c;本文主要是在上文的基础上增加slot插槽&#xff0c;并且对col插槽进行拓展&#xff0c;增加通用性 // BaseTable.vue <template><el-table><template v-for"name in tableSlots&…

Hadoop安装笔记1单机/伪分布式配置_Hadoop3.1.3——备赛笔记——2024全国职业院校技能大赛“大数据应用开发”赛项——任务2:离线数据处理

将下发的ds_db01.sql数据库文件放置mysql中 12、编写Scala代码&#xff0c;使用Spark将MySQL的ds_db01库中表user_info的全量数据抽取到Hive的ods库中表user_info。字段名称、类型不变&#xff0c;同时添加静态分区&#xff0c;分区字段为etl_date&#xff0c;类型为String&am…

年度总结 | 回味2023不平凡的一年

目录 前言1. 平台成就2. 自我提升3. Bug连连4. 个人展望 前言 每年CSDN的总结都不能落下&#xff0c;回顾去年&#xff1a;年度总结 | 回味2022不平凡的一年&#xff0c;在回忆今年&#xff0c;展望下年 1. 平台成就 平台造就我&#xff08;我也造就平台哈哈&#xff09; 每…