图像分割实战-系列教程17:deeplabV3+ VOC分割实战5-------main.py

在这里插入图片描述

🍁🍁🍁图像分割实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传

deeplab系列算法概述
deeplabV3+ VOC分割实战1
deeplabV3+ VOC分割实战2
deeplabV3+ VOC分割实战3
deeplabV3+ VOC分割实战4
deeplabV3+ VOC分割实战5

10、main.py的main()函数

def main():opts = get_argparser().parse_args()if opts.dataset.lower() == 'voc':opts.num_classes = 21elif opts.dataset.lower() == 'cityscapes':opts.num_classes = 19
  1. 定义main函数
  2. 调用参数函数,解析命令行参数
  3. 判断数据集名称是否为voc,是则设置类别数为21,21=20个对象+1背景
    # Setup visualizationvis = Visualizer(port=opts.vis_port, env=opts.vis_env) if opts.enable_vis else Noneif vis is not None:  # display optionsvis.vis_table("Options", vars(opts))os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_iddevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print("Device: %s" % device)

设置可视化工具和配置训练设备

  1. 根据 opts.enable_vis 的值决定是否创建一个 Visualizer 对象,其端口和环境设置由 opts.vis_portopts.vis_env 提供。Visualizer 是一个python可视化工具类,用于在训练过程中显示图像、图表等信息
  2. 检查是否启用了可视化
  3. 如果可视化被启用,这一行调用 vis 对象的 vis_table 方法来显示配置选项。vars(opts) 是将 opts 对象转换为字典,其中包含了所有的命令行参数
  4. 设置环境变量 CUDA_VISIBLE_DEVICES,其值为 opts.gpu_id
  5. 如果GPU可用,则使用 CUDA;否则,使用 CPU
  6. 打印出使用的设备信息
    torch.manual_seed(opts.random_seed)np.random.seed(opts.random_seed)random.seed(opts.random_seed)if opts.dataset=='voc' and not opts.crop_val:opts.val_batch_size = 1train_dst, val_dst = get_dataset(opts)train_loader = data.DataLoader( train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=0)val_loader = data.DataLoader( val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=0)print("Dataset: %s, Train set: %d, Val set: %d" % (opts.dataset, len(train_dst), len(val_dst)))
  1. 分别为pytorch
  2. numpy
  3. python设置全局随机种子
  4. 检查是否使用VOC数据集并且没有启用验证集的裁剪
  5. 如果条件满足,将验证批次大小设置为 1
  6. 调用 get_dataset 函数来获取训练和验证数据集,该函数在第3部分已经解析
  7. 训练集DataLoader,opts.batch_size训练批次大小
  8. 验证集DataLoader,opts.val_batch_size验证批次大小
  9. 打印出使用的数据集名称以及训练和验证集的大小
    model_map = {'deeplabv3_resnet50': network.deeplabv3_resnet50,'deeplabv3plus_resnet50': network.deeplabv3plus_resnet50,'deeplabv3_resnet101': network.deeplabv3_resnet101,'deeplabv3plus_resnet101': network.deeplabv3plus_resnet101,'deeplabv3_mobilenet': network.deeplabv3_mobilenet,'deeplabv3plus_mobilenet': network.deeplabv3plus_mobilenet}model = model_map[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)if opts.separable_conv and 'plus' in opts.model:network.convert_to_separable_conv(model.classifier)utils.set_bn_momentum(model.backbone, momentum=0.01)

这部分设置网络的参数

  1. 定义一个模型映射字典,字典包括本项目可选择的多个网络:

  2. deeplabv3的resnet50

  3. deeplabv3+的resnet50

  4. deeplabv3的resnet101

  5. deeplabv3+的resnet101

  6. deeplabv3的mobilenet

  7. deeplabv3+的mobilenet,这些网络在Network文件夹中使用一定方法构建,在这里直接导入

  8. 从预设要选择的网络名称、类别数、输出通道数加载网络

  9. 检查是否启用可分离卷积、模型名称包含 ‘plus’

  10. 如果条件满足,则对模型的分类器部分应用可分离卷积的转换

  11. 调用set_bn_momentum函数,设置批量归一化的动量,set_bn_momentum函数:

    def set_bn_momentum(model, momentum=0.1):for m in model.modules():if isinstance(m, nn.BatchNorm2d):m.momentum = momentum
    
    metrics = StreamSegMetrics(opts.num_classes)optimizer = torch.optim.SGD(params=[{'params': model.backbone.parameters(), 'lr': 0.1*opts.lr},{'params': model.classifier.parameters(), 'lr': opts.lr},], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)if opts.lr_policy=='poly':scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)elif opts.lr_policy=='step':scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)if opts.loss_type == 'focal_loss':criterion = utils.FocalLoss(ignore_index=255, size_average=True)elif opts.loss_type == 'cross_entropy':criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')

评价指标、优化器、学习率、损失函数

  1. StreamSegMetrics类是一个评价指标,用于在语义分割任务中跟踪模型的性能,计算的指标包括总体准确率、平均准确率、频率加权平均准确率、平均交并比(IoU)和每个类别的 IoU,通过计算混淆矩阵和从中派生出的多个指标。实例化出一个对象
  2. 设置SGD为优化器,并设置一些参数
  3. 设置backbone的学习率,预训练的backbone部分通常需要较小的学习率来微调
  4. 设置分类器的学习率
  5. 设置优化器的初始全局学习率,设置动量帮助加速SGD并抑制震荡,设置权重衰减用于防止过拟合
  6. 如果学习率策略设为 “poly”
  7. 则使用多项式衰减
  8. 如果学习率策略设为 “step”
  9. 则使用阶梯式衰减
  10. 根据 opts.loss_type 的值选择合适的损失函数 ,如果损失函数类型为 “focal_loss”
  11. 则使用焦点损失(Focal Loss),这对于处理类别不平衡问题很有效、
  12. 如果损失函数类型为 “cross_entropy”
  13. 则使用交叉熵损失
	utils.mkdir('checkpoints')best_score = 0.0cur_itrs = 0cur_epochs = 0if opts.ckpt is not None and os.path.isfile(opts.ckpt):checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))model.load_state_dict(checkpoint["model_state"])model = nn.DataParallel(model)model.to(device)if opts.continue_training:optimizer.load_state_dict(checkpoint["optimizer_state"])scheduler.load_state_dict(checkpoint["scheduler_state"])cur_itrs = checkpoint["cur_itrs"]best_score = checkpoint['best_score']print("Training state restored from %s" % opts.ckpt)print("Model restored from %s" % opts.ckpt)del checkpoint  # free memoryelse:print("[!] Retrain")model = nn.DataParallel(model)model.to(device)

checkpoints,检查点

  1. 创建 “checkpoints” 文件夹,存储训练过程中的模型检查点
  2. 初始化最佳分数为0,记录验证集上的最高分数
  3. 初始化当前迭代次数为0
  4. 初始化当前轮数为0
  5. 检查是否提供了检查点文件,并且该文件确实存在
  6. 加载检查点文件,意味着无论检查点是在CPU还是GPU上保存的,都会先被加载到CPU内存中
  7. 从检查点中恢复模型的状态
  8. 使用 DataParallel 来利用多个GPU(如果可用)
  9. 模型放入GPU
  10. 检查是否需要从检查点继续训练
  11. 从检查点恢复优化器的状态
  12. 从检查点恢复学习率调度器的状态
  13. 从检查点恢复当前迭代次数
  14. 从检查点恢复当前最佳分数
  15. 打印模型恢复信息
  16. 删除检查点变量以释放内存
  17. 如果没有提供有效的检查点文件
  18. 打印重新训练的信息
  19. 使用 DataParallel 来利用多个GPU(如果可用)
  20. 模型放入GPU
    vis_sample_id = np.random.randint(0, len(val_loader), opts.vis_num_samples, np.int32) if opts.enable_vis else Nonedenorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # denormalization for ori imagesif opts.test_only:model.eval()val_score, ret_samples = validate(opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)print(metrics.to_str(val_score))return

训练中的可视化设置、测试模式处理、初始化变量

  1. 选择可视化样本:如果启用了可视化则从验证数据加载器中随机选择一定数量的样本进行可视化,否则,vis_sample_id 设为 None
  2. 反标准化操作:创建一个反标准化对象,将用于将预处理(标准化)后的图像恢复到原始图像的颜色空间
  3. 检查是否仅进行测试,如果是:
  4. 将模型设置为评估模式
  5. 调用 validate 函数,返回评分、选定的样本
  6. 打印评估结果
  7. 结束函数执行。
    interval_loss = 0while True: #cur_itrs < opts.total_itrs:model.train()cur_epochs += 1for (images, labels) in train_loader:cur_itrs += 1images = images.to(device, dtype=torch.float32)labels = labels.to(device, dtype=torch.long)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()np_loss = loss.detach().cpu().numpy()interval_loss += np_loss
  1. 初始化间隔损失为0
  2. 开始无限循环(达到最大迭代次数停止)
  3. 开启模型训练模式
  4. 当前epoch+1
  5. 从训练集dataloader取出数据和标签
  6. 当前迭代次数+1
  7. 训练数据进入GPU
  8. 训练标签进入GPU
  9. 梯度清零
  10. 训练数据进入模型后得到输出
  11. 输出和标签通过损失函数计算损失
  12. 反向传播,计算损失相对于模型参数的梯度
  13. 根据计算的梯度更新模型的权重
  14. 将损失从GPU(如果使用)转移到CPU,并转换为NumPy格式
  15. 累加损失
            if vis is not None:vis.vis_scalar('Loss', cur_itrs, np_loss)if (cur_itrs) % 10 == 0:interval_loss = interval_loss/10print("Epoch %d, Itrs %d/%d, Loss=%f" % (cur_epochs, cur_itrs, opts.total_itrs, interval_loss))interval_loss = 0.0

可视化、日志记录部分:

  1. 条件可视化:检查是否创建了可视化对象
  2. 使用可视化工具来记录当前迭代的损失。 'Loss' 是要记录的数据的名称,cur_itrs 是当前迭代次数,np_loss 是当前迭代的损失值
  3. 每当当前迭代次数是10的倍数时执行以下操作:
  4. 计算过去10次迭代的平均损失
  5. 打印当前的epoch、迭代次数、总迭代次数、平均损失
  6. 重置间隔损失,为计算下一个间隔的平均损失做准备
            if (cur_itrs) % opts.val_interval == 0:save_ckpt('checkpoints/latest_%s_%s_os%d.pth' % (opts.model, opts.dataset, opts.output_stride))print("validation...")model.eval()val_score, ret_samples = validate(opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)print(metrics.to_str(val_score))if val_score['Mean IoU'] > best_score:  # save best modelbest_score = val_score['Mean IoU']save_ckpt('checkpoints/best_%s_%s_os%d.pth' % (opts.model, opts.dataset,opts.output_stride))if vis is not None:  # visualize validation score and samplesvis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])vis.vis_table("[Val] Class IoU", val_score['Class IoU'])for k, (img, target, lbl) in enumerate(ret_samples):img = (denorm(img) * 255).astype(np.uint8)target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)concat_img = np.concatenate((img, target, lbl), axis=2)  # concat along widthvis.vis_image('Sample %d' % k, concat_img)model.train()scheduler.step()  if cur_itrs >=  opts.total_itrs:return        

验证、模型保存、可视化反馈:

  1. 每隔一定的迭代次数,程序将执行一次验证过程
  2. 保存当前模型的状态到检查点。这里的文件名包含模型类型、数据集、输出步长
  3. 打印正在验证中
  4. 将模型设置为评估模式
  5. 执行验证函数,
  6. 返回性能得分、一些样本图片
  7. 打印验证得分
  8. 如果当前验证得分的平均交并比(Mean IoU)高于之前的最佳得分
  9. 更新最佳得分
  10. 保存当前模型
  11. 如果可视化被启用
  12. 可视化总体准确率
  13. 可视化平均 IoU
  14. 可视化各类别的 IoU
  15. for循环取出返回的样本,每个样本包含原始图像(img)、真实标签(target)、模型预测的标签(lbl),对返回的样本进行可视化:
  16. 对原始图像应用反标准化操作
  17. 将目标标签从模型输出的格式解码成可视化格式
  18. 将预测标签从模型输出的格式解码成可视化格式
  19. 将原始图像、真实标签和预测标签沿着宽度方向拼接在一起
  20. 使用可视化工具显示拼接后的图像
  21. 将模型设置回训练模式
  22. 更新学习率
  23. 如果当前迭代次数达到或超过预设的总迭代次数
  24. 则结束训练

deeplab系列算法概述
deeplabV3+ VOC分割实战1
deeplabV3+ VOC分割实战2
deeplabV3+ VOC分割实战3
deeplabV3+ VOC分割实战4
deeplabV3+ VOC分割实战5

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/639437.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

智能算法 | Matlab实现改进黑猩猩优化算法SLWCHOA与多个基准函数对比与秩和检验

智能算法 | Matlab实现改进黑猩猩优化算法SLWCHOA与多个基准函数对比与秩和检验 目录 智能算法 | Matlab实现改进黑猩猩优化算法SLWCHOA与多个基准函数对比与秩和检验预测效果基本描述程序设计参考资料 预测效果 基本描述 1.Matlab实现改进黑猩猩优化算法SLWCHOA与多个基准函数…

Flutter编译报错Connection timed out: connect

背景&#xff1a;用Android Studo 创建了Flutter项目&#xff0c;编译运行报错java.net.ConnectException: Connection timed out: connect 我自己的环境&#xff1a; windows11 Android Studio Flutter 截图如下&#xff1a; 将错误日志展开之后&#xff1a; Exception…

LLM面面观之LLM上下文扩展方案

1. 背景 本qiang~这段时间调研了LLM上下文扩展的问题&#xff0c;并且实打实的运行了几个开源的项目&#xff0c;所谓实践与理论相结合嘛&#xff01; 此文是本qiang~针对上下文扩展问题的总结&#xff0c;包括解决方案的整理概括&#xff0c;文中参考了多篇有意义的文章&…

【C++】类和对象(上篇)

文章目录 &#x1f6df;一、面向过程和面向对象初步认识&#x1f6df;二、类的引入&#x1f6df;三、类的定义&#x1f4dd;1、类的两种定义方式&#x1f4dd;2、成员变量命名规则的建议 &#x1f6df;四、类的访问限定符及封装&#x1f369;1、访问限定符&#x1f369;2、封装…

鼠害监测站设立的意义是什么

鼠害监测站对草原生态环境的影响主要体现在以下几个方面&#xff1a; 保护草原植被&#xff1a;鼠害监测站通过实时监测鼠害活动&#xff0c;及时采取控制措施&#xff0c;可以有效减少鼠类对草原植被的破坏&#xff0c;保护草原生态系统的稳定性。维持草原土壤健康&#xff1…

C++ 知识列表【图】

举例C的设计模式和智能指针 当谈到 C 的设计模式时&#xff0c;以下是一些常见的设计模式&#xff1a; 工厂模式&#xff08;Factory Pattern&#xff09;&#xff1a;用于创建对象的模式&#xff0c;隐藏了对象的具体实现细节&#xff0c;只暴露一个公共接口来创建对象。 单例…

Web04--Flex布局

1、flex布局 1.1 flex认识 1.2 flex组成 1.3 flex布局 1.3.1 主轴对齐方式 <!DOCTYPE html> <html lang"CN"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.…

解决Windows下Goland的Terminal设置为Git Bash失败

路径不要选错了&#xff1a; 如果还是不行&#xff1a; 把bash路径加进去试试 goland设置Terminal

uni-app小程序:文件下载打开文件方法苹果安卓都适用

api: const filetype e.substr(e.lastIndexOf(.)1)//获取文件地址的类型 console.log(文档,filetype) uni.downloadFile({url: e,//e是图片地址success(res) {console.log(res)if (res.statusCode 200) {console.log(下载成功,);var filePath encodeURI(res.tempFilePath);…

爬虫案例—抓取找歌词网站的按歌词找歌名数据

爬虫案例—抓取找歌词网站的按歌词找歌名数据 找个词网址&#xff1a;https://www.91ge.cn/lxyyplay/find/ 目标&#xff1a;抓取页面里的所有要查的歌词及歌名等信息&#xff0c;并存为txt文件 一共46页数据 网站截图如下&#xff1a; 抓取完整歌词数据&#xff0c;如下图…

DevOps系列文章之 GitLab CI/CD

CICD是什么? 由于目前公司使用的gitlab&#xff0c;大部分项目使用的CICD是gitlab的CICD&#xff0c;少部分用的是jenkins&#xff0c;使用了gitlab-ci一段时间后感觉还不错&#xff0c;因此总结一下 介绍gitlab的CICD之前&#xff0c;可以先了解CICD是什么 我们的开发模式…

司铭宇老师:房地产中介销售经理培训:如何激发房产中介销售人员的斗志与激情

房地产中介销售经理培训&#xff1a;如何激发房产中介销售人员的斗志与激情 在房产中介行业&#xff0c;销售人员的斗志与激情直接影响着业绩的高低。一个有动力的销售团队能够积极应对市场的变化&#xff0c;更好地服务客户&#xff0c;从而实现销售目标。本文将探讨如何通过有…

CGLIB动态代理(AOP原理)(面试重点)

推荐先看JDK 动态代理&#xff08;Spring AOP 的原理&#xff09;&#xff08;面试重点&#xff09; JDK 动态代理与 CGLIB 动态代理的区别 JDK 动态代理有⼀个最致命的问题是其只能代理实现了接⼝的类. 有些场景下,我们的业务代码是直接实现的,并没有接⼝定义.为了解决这个问…

【C++干货基地】namespace超越C语言的独特魅力(文末送书)

&#x1f3ac; 鸽芷咕&#xff1a;个人主页 &#x1f525; 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想&#xff0c;就是为了理想的生活! 引入 哈喽各位铁汁们好啊&#xff0c;我是博主鸽芷咕《C干货基地》是由我的襄阳家乡零食基地有感而发&#xff0c;不知道各位的…

HarmonyOS4.0系统性深入开发24启动DataAbility

DataAbility组件概述 DataAbility&#xff0c;即"使用Data模板的Ability"&#xff0c;主要用于对外部提供统一的数据访问抽象&#xff0c;不提供用户交互界面。DataAbility可由PageAbility、ServiceAbility或其他应用启动&#xff0c;即使用户切换到其他应用&#x…

Debian11下编译ADAravis和Motor模块的一条龙过程

Debian11编译EPICS ADAravis记录 一年前整理的上面文&#xff0c;这几天重新走了一遍&#xff0c;有些地方会碰到问题&#xff0c;需要补充些环节&#xff0c;motor模块以前和areaDetector一条龙编译时&#xff0c;总是有问题&#xff0c;当时就没尝试了&#xff0c;这几天尝试…

位运算的魅力:使用Redis Bitmap高效处理百万级布尔值

欢迎来到我的博客&#xff0c;代码的世界里&#xff0c;每一行都是一个故事 位运算的魅力&#xff1a;使用Redis Bitmap高效处理百万级布尔值 前言1. Bitmap的基本概念Bitmap的定义和原理为什么Bitmap特别适合处理大量布尔值 2. Redis中的Bitmap操作基础命令高级命令 实际应用场…

低压防雷箱综合选型应用方案

低压防雷箱是一种用于保护低压配电系统免受雷电过电压的影响的装置&#xff0c;它主要由防雷箱模块、浪涌保护器SPD、接地线等组成。本文将介绍低压防雷箱的作用原理和行业应用解决方案&#xff0c;以及低压防雷箱的选型方法。 低压防雷箱的作用原理 低压防雷箱的作用原理是利…

股东出资透明度提升:企业股东出资信息API的应用

前言 在当今商业环境中&#xff0c;股东出资信息的透明度对于投资者、监管机构以及企业自身的健康发展至关重要。随着企业信息公开化的推进&#xff0c;企业股东出资信息API应运而生&#xff0c;为各方提供了一个便捷、高效的信息获取渠道。本文将探讨企业股东出资信息API如何…

HCIA NAT练习

目录 实验拓扑 实验要求 实验步骤 1、IP分配 2、使用ACL使PC访问外网 3、缺省路由 4、边界路由器公网ip端口配置 测试 实验拓扑 实验要求 1、R2为ISP路由器&#xff0c;其上只能配置ip地址&#xff0c;不得再进行其他的任何配置 2、PC1-PC2可以ping通客户平板和DNS服…