pytorch 模型可视化_【深度学习】高效使用Pytorch的6个技巧:为你的训练Pipeline提供强大动力...

作者:Eugene Khvedchenya   编译:ronghuaiyang

导读

只报告模型的Top-1准确率往往是不够的。

07a21aaef16d212bd5f6150ded676007.png

将train.py脚本转换为具有一些附加特性的强大pipeline

每一个深度学习项目的最终目标都是为产品带来价值。当然,我们想要最好的模型。什么是“最好的” —— 取决于特定的用例,我将把这个讨论放到这篇文章之外。我想谈谈如何从你的train.py脚本中得到最好的模型。

在这篇文章中,我们将介绍以下技巧:

  1. 用高级框架代替自己写的循环
  2. 使用另外的度量标准监控训练的进展
  3. 使用TensorBoard
  4. 使模型的预测可视化
  5. 使用Dict作为数据集和模型的返回值
  6. 检测异常和解决数值的不稳定性

免责声明:在下一节中,我将引用一些源代码。大多数都是为[Catalyst](https://github.com/catalysts -team/catalyst)框架(20.08版)定制的,可以在pytorch-toolbelt中使用。

不要重复造轮子

131892274d4ee6a73e09d599c0da1ae0.png

建议1 — 利用PyTorch生态系统的高级训练框架

PyTorch在从头开始编写训练循环时提供了极佳的灵活性和自由度。理论上,这为编写任何训练逻辑提供了无限可能。在实践中,你很少会为训练CycleGAN、distilling BERT或3D物体检测从头开始实现编写训练循环。

从头编写一个完整的训练循环是学习PyTorch基本原理的一个很好的方法。不过,我强烈建议你在掌握了一些知识之后,转向高级框架。有很多选择:Catalyst, PyTorch-Lightning, Fast.AI, Ignite,以及其他。高级框架通过以下方式节省你的时间:

  • 提供经过良好测试的训练循环
  • 支持配置文件
  • 支持多gpu和分布式训练
  • 管理检查点/实验
  • 自动记录训练进度

从这些高级库中获得最大效果需要一些时间。然而,这种一次性的投资从长期来看是有回报的。

优点

  • 训练pipeline变得更小 —— 代码越少 —— 出错的机会就越少。
  • 易于进行实验管理。
  • 简化分布式和混合精度训练。

缺点

  • 通常,当使用一个高级框架时,我们必须在框架特定的设计原则和范例中编写代码。
  • 时间投资,学习额外的框架需要时间。

给我看指标

b923278dd3d660047768b0f45a526e59.png

建议2 —— 在训练期间查看其他指标

几乎每一个用于在MNIST或CIFAR甚至ImageNet中对图像进行分类的快速启动示例项目都有一个共同点 —— 它们在训练期间和训练之后都报告了一组最精简的度量标准。通常情况下,包括Top-1和Top-5准确度、错误率、训练/验证损失,仅此而已。虽然这些指标是必要的,但它只是冰山一角!

现代图像分类模型有数千万个参数。你想只使用一个标量值来计算它吗?

Top-1准确率最好的CNN分类模型在泛化方面可能不是最好的。根据你的领域和需求,你可能希望保存具有最 false-positive/false-negative的模型,或者具有最高平均精度的模型。

让我给你一些建议,在训练过程中你可以记录哪些数据:

  • Grad-CAM heat-map —— 看看图像的哪个部分对某一特定类的贡献最大。

f5e08049ea1850eb6bb3c562de6f7e50.png

可视化Grad-CAM heat-maps有助于识别模型是否基于真实病理或图像伪影做出预测
  • Confusion Matrix — 显示了对你的模型来说哪两个类最具挑战性。

11d7469323a49d3fa7dfe4d03a5082fd.png

混淆矩阵揭示了一个模型对特定类型进行不正确分类的频率
  • Distribution of predictions — 让你了解最优决策边界。

2453986881939f7cab37b5ee29137aab.png

该模型的negative和positive 预测的分布表明,有很大一部分数据模型无法确定地分类
  • Minimum/Average/Maximum 跨所有层的梯度值,允许识别是否在模型中存在消失/爆炸的梯度或初始化不好的层。

使用面板工具来监控训练

建议3 — 使用TensorBoard或任何其他解决方案来监控训练进度

在训练模型时,你可能最不愿意做的事情就是查看控制台输出。通过一个功能强大的仪表板,你可以在其中一次看到所有的度量标准,这是检查训练结果的更有效的方法。

34e732ea05ba6b46cbd648b48498e3d1.png

Tensorboard可以快速的检查和比较你运行的训练

对于少量实验和非分布式环境,TensorBoard是一个黄金标准。自版本1.3以来,PyTorch就完全支持它,并提供了一组丰富的特性来管理试用版。还有一些更先进的基于云的解决方案,比如Weights&Biases、[Alchemy](https://github.com/catalyst team/alchemy)和TensorBoard.dev,这些解决方案使得在多台机器上监控和比较训练变得更容易。

当使用Tensorboard时,我通常记录这样一组指标:

  • 学习率和其他可能改变的优化参数(动量,重量衰减,等等)
  • 用于数据预处理和模型内部的时间
  • 贯穿训练和验证的损失(每个batch和每个epoch的平均值)
  • 跨训练和验证的度量
  • 训练session的超参数最终值
  • 混淆矩阵,Precision-Recall曲线,AUC(如果适用)
  • 模型预测的可视化(如适用)

一图胜千言

直观地观察模型的预测是非常重要的。有时训练数据是有噪声的;有时,模型会过拟合图像的伪影。通过可视化最好的和最差的batch(基于损失或你感兴趣的度量),你可以对模型执行良好和糟糕的情况进行有价值的洞察。

建议5 — 可视化每个epoch中最好和最坏的batch。它可能会给你宝贵的见解。

Catalyst用户提示:这里是使用可视化回调的示例:https://github.com/BloodAxe/Catalyst-Inria-Segmentation-Example/blob/master/fit_predict.py#L258

例如,在全球小麦检测挑战中,我们需要在图像上检测小麦头。通过可视化最佳batch的图片(基于mAP度量),我们看到模型在寻找小物体方面做得近乎完美。

0749e7c758d9658a610b26db5e906289.png

最佳模型预测的可视化显示了模型在小物体上的良好表现

相反,当我们查看最差一批的第一个样本时,我们看到模型很难对大物体做出准确的预测。可视化分析为任何数据科学家都提供了宝贵的见解。

6fa200fcb2ec3800e98e246077b230a7.png

最差模型预测的可视化揭示了模型在大物体上的性能很差

查看最差的batch也有助于发现数据标记中的错误。通常情况下,贴错标签的样本损失更大,因此会成为最差的batch。通过在每个epoch对最糟糕的batch做一个视觉检查,你可以消除这些错误:

57b6ec65a6a28d821a688b3da8298d94.png

标记错误的例子。绿色像素表示true positives,红色像素表示false negative。在这个示例中,ground-truth掩模标在了它实际上不存在的位置上。

使用Dict作为Dataset和Model的返回值

建议4 — 如果你的模型返回一个以上的值,使用Dict来返回结果,不要使用tuple

在复杂的模型中,返回多个输出并不少见。例如,目标检测模型通常返回边界框及其标签,在图像分割CNN-s中,我们经常返回中间层的mask进行深度监督,多任务学习最近也很常用。

在许多开源实现中,我经常看到这样的东西:

# Bad practice, don't return tupleclass RetinaNet(nn.Module):
  ...def forward(self, image):
    x = self.encoder(image)
    x = self.decoder(x)
    bboxes, scores = self.head(x)return bboxes, scores
  ...

对于作者来说,我认为这是一种非常糟糕的从模型返回结果的方法。下面是我推荐的替代方法:

class RetinaNet(nn.Module):
  RETINA_NET_OUTPUT_BBOXES = "bboxes"
  RETINA_NET_OUTPUT_SCORES = "scores"
  ...def forward(self, image):
    x = self.encoder(image)
    x = self.decoder(x)
    bboxes, scores = self.head(x)return { RETINA_NET_OUTPUT_BBOXES: bboxes, 
             RETINA_NET_OUTPUT_SCORES: scores }
  ...

这个建议在某种程度上与“The Zen of Python”的设定产生了共鸣 —— “明确的比含蓄的更好”。遵循这一规则将使你的代码更清晰、更容易维护。

那么为什么我认为第二种选择更好呢?有几个原因:

  • 返回值有一个显式的名称与它关联。你不需要记住元组中元素的确切顺序。
  • 如果你需要访问返回的字典的一个特定元素,你可以通过它的名字来访问。
  • 从模型中添加新的输出不会破坏代码。

使用Dict,你甚至可以更改模型的行为,以按需返回额外的输出。例如,这里有一个简短的片段,演示了如何返回多个“主”输出和两个“辅助”输出来进行度量学习:

# https://github.com/BloodAxe/Kaggle-2020-Alaska2/blob/master/alaska2/models/timm.py#L104def forward(self, **kwargs):
  x = kwargs[self.input_key]
  x = self.rgb_bn(x)
  x = self.encoder.forward_features(x)
  embedding = self.pool(x)
  result = {
    OUTPUT_PRED_MODIFICATION_FLAG: self.flag_classifier(self.drop(embedding)),
    OUTPUT_PRED_MODIFICATION_TYPE: self.type_classifier(self.drop(embedding)),
  }if self.need_embedding:
    result[OUTPUT_PRED_EMBEDDING] = embeddingif self.arc_margin is not None:
    result[OUTPUT_PRED_EMBEDDING_ARC_MARGIN] = self.arc_margin(embedding)return result

同样的建议也适用于Dataset类。对于Cifar-10玩具示例,可以将图像及其对应的标签作为元组返回。但当处理多任务或多输入模型,你想从数据集返回Dict类型的样本:

# https://github.com/BloodAxe/Kaggle-2020-Alaska2/blob/master/alaska2/dataset.py#L373class TrainingValidationDataset(Dataset):def __init__(
        self,
        images: Union[List, np.ndarray],
        targets: Optional[Union[List, np.ndarray]],
        quality: Union[List, np.ndarray],
        bits: Optional[Union[List, np.ndarray]],
        transform: Union[A.Compose, A.BasicTransform],
        features: List[str],
    ):"""
        :param obliterate - Augmentation that destroys embedding.
        """if targets is not None:if len(images) != len(targets):raise ValueError(f"Size of images and targets does not match: {len(images)} {len(targets)}")
        self.images = images
        self.targets = targets
        self.transform = transform
        self.features = features
        self.quality = quality
        self.bits = bitsdef __len__(self):return len(self.images)def __repr__(self):return f"TrainingValidationDataset(len={len(self)}, targets_hist={np.bincount(self.targets)}, qf={np.bincount(self.quality)}, features={self.features})"def __getitem__(self, index):
        image_fname = self.images[index]try:
            image = cv2.imread(image_fname)if image is None:raise FileNotFoundError(image_fname)except Exception as e:
            print("Cannot read image ", image_fname, "at index", index)
            print(e)
        qf = self.quality[index]
        data = {}
        data["image"] = image
        data.update(compute_features(image, image_fname, self.features))
        data = self.transform(**data)
        sample = {INPUT_IMAGE_ID_KEY: os.path.basename(self.images[index]), INPUT_IMAGE_QF_KEY: int(qf)}if self.bits is not None:# OK
            sample[INPUT_TRUE_PAYLOAD_BITS] = torch.tensor(self.bits[index], dtype=torch.float32)if self.targets is not None:
            target = int(self.targets[index])
            sample[INPUT_TRUE_MODIFICATION_TYPE] = target
            sample[INPUT_TRUE_MODIFICATION_FLAG] = torch.tensor([target > 0]).float()for key, value in data.items():if key in self.features:
                sample[key] = tensor_from_rgb_image(value)return sample

当你的代码中有Dictionaries时,你可以在任何地方使用名称常量引用输入/输出。遵循这条规则将使你的训练管道非常清晰和容易遵循:

# https://github.com/BloodAxe/Kaggle-2020-Alaska2
callbacks += [
  CriterionCallback(
    input_key=INPUT_TRUE_MODIFICATION_FLAG,
    output_key=OUTPUT_PRED_MODIFICATION_FLAG,
    criterion_key="bce"
  ),
  CriterionCallback(
    input_key=INPUT_TRUE_MODIFICATION_TYPE,
    output_key=OUTPUT_PRED_MODIFICATION_TYPE,
    criterion_key="ce"
  ),
  CompetitionMetricCallback(
    input_key=INPUT_TRUE_MODIFICATION_FLAG,
    output_key=OUTPUT_PRED_MODIFICATION_FLAG,
    prefix="auc",
    output_activation=binary_logits_to_probas,
    class_names=class_names,
  ),
  OutputDistributionCallback(
      input_key=INPUT_TRUE_MODIFICATION_FLAG,
      output_key=OUTPUT_PRED_MODIFICATION_FLAG,
      output_activation=binary_logits_to_probas,
      prefix="distribution/binary",
  ),
  BestMetricCheckpointCallback(
    target_metric="auc", 
    target_metric_minimize=False, 
    save_n_best=3),
]

在训练中检测异常

4b8eb58c8d9c22c509b14dc651d1442d.png

就像人类可以阅读含有许多错误的文本一样,深度学习模型也可以在训练过程中出现错误时学习“一些合理的东西”。作为一名开发人员,你要负责搜索异常并对其表现进行推理。

建议5 — 在训练期间使用 torch.autograd.detect_anomaly()查找算术异常

如果你在训练过程中在损失/度量中看到NaNs或Inf,你的脑海中就会响起一个警报。它是你的管道中有问题的指示器。通常情况下,它可能由以下原因引起:

  • 模型或特定层的初始化不好(你可以通过观察梯度大小来检查哪些层)
  • 数学上不正确的运算(负数的torch.sqrt(),非正数的torch.log(),等等)
  • 不当使用torch.mean()torch.sum() 的reduction(zero-sized张量上的均值会得到nan,大张量上的sum容易导致溢出)
  • 在loss中使用x.sigmoid()(如果你需要在loss函数中使用概率,更好的方法是x.sigmoid().clamp(eps,1-eps )以防止梯度消失)
  • 在Adam-like的优化器中的低epsilon值
  • 在使用fp16的训练的时候没有使用动态损失缩放

为了找到你代码中第一次出现Nan/Inf的确切位置,PyTorch提供了一个简单易用的方法torch. autograde .detect_anomaly()

import torchdef main():
    torch.autograd.detect_anomaly()
    ...# Rest of the training code# ORclass MyNumericallyUnstableLoss(nn.Module):def forward(self, input, target):with torch.autograd.set_detect_anomaly(True):
       loss = input * targetreturn loss

将其用于调试目的,否则就禁用它,异常检测会带来计算开销,并将训练速度降低10-15% 。

b9e9ceb60ffdcd57c2f0f3b0c48e7d5d.png—END—

英文原文:https://towardsdatascience.com/efficient-pytorch-supercharging-training-pipeline-19a26265adae

aeef93a124c56168f7025e807800243c.png

往期精彩回顾

适合初学者入门人工智能的路线及资料下载

机器学习及深度学习笔记等资料打印

机器学习在线手册

深度学习笔记专辑

《统计学习方法》的代码复现专辑

AI基础下载

机器学习的数学基础专辑

获取一折本站知识星球优惠券,复制链接直接打开:

https://t.zsxq.com/662nyZF

本站qq群1003271085。

加入微信群请扫码进群(如果是博士或者准备读博士请说明):

b89ab42ecfb7e185253733e667781b76.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/573957.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

tstringlist怎么查看是否存在该数据_注意!研究生招生信息只公开1个月!应该怎么用?...

请注意!全国硕士研究生招生信息公开平台(以下简称“研招信息公开平台”)已于2019年7月1日开放-2019年7月30日结束。招生信息怎么看?老师在线教你看懂研究生招生信息!本篇目录:1.全国硕士研究生招生信息公开…

delphi读取xml中的内容property name传递参数_Python 进阶知识全篇-XML 解析

什么是 XML?XML 指可扩展标记语言(eXtensible Markup Language),标准通用标记语言的子集,是一种用于标记电子文件使其具有结构性的标记语言。 你可以通过本站学习 XML 教程XML 被设计用来传输和存储数据。XML 是一套定…

c语言编写网页图形界面代码,「分享」C语言如何编写图形界面

该楼层疑似违规已被系统折叠 隐藏此楼查看此楼贴吧内经常有人问C语言是不是只能用于字符终端界面开发,不能用于图形界面。大家也都有回答,需要其他的库。MFC,GTK,QT。本人近期刚用GTK库加上纯C写成了第一个LINUX实用程序。现在与大…

python 读取word_教你怎么使用 Python 对 word文档 进行操作

使用Python对word文档进行操作一、安装Python-docxPython-docx是专门针对于word文档的一个模块,只能读取docx 不能读取doc文件。说白了,python就相当于windows操作系统,QQ就是跑在windows操作系统上的软件,QQ最大的作用是可以去聊…

stm32cubemx adc_STM32CubeMX__Exp5_ADC1_2CH_DMA_TIM3_Trig__简明指导文件__jyb

用定时器TIM3触发DMA方式的双通道ADC定时采样:拷贝STM32CubeMX工程文件LED_Flash_PC12.ioc,修改为:Exp5_ADC1_2CH_DMA_TIM3_Trig.ioc(1)配置ADC1的通道和参数配置ADC通道参数(2)配置ADC1的DMA①通过点"Add"按钮,添加ADC…

JS 实现 jQuery的$(function(){});

1、浏览器渲染引擎的HTML解析流程 何谓“渲染”,其实就是浏览器把请求到的HTML内容显示出来的过程。渲染引擎首先通过网络获得所请求文档的内容,通常以8K分块的方式完成。下面是渲染引擎在取得内容之后的基本流程: 1,解析html以构…

html 分页_MySQL——优化嵌套查询和分页查询

Java识堂,一个高原创,高收藏,有干货的微信公众号,欢迎关注优化嵌套查询嵌套查询(子查询)可以使用SELECT语句来创建一个单列的查询结果,然后把这个结果作为过滤条件用在另一个查询中。嵌套查询写起来简单,也…

从原理上搞定编码-- Base64编码

开发者对Base64编码肯定很熟悉,是否对它有很清晰的认识就不一定了。实际 上Base64已经简单到不能再简单了,如果对它的理解还是模棱两可实在不应该。大概介绍一下Base64的相关内容,花几分钟时间就可以彻底理解它。文 章下边贴了一个Base64的编…

docker mysql总是退出_Docker提升测试效率之路

现如今,Docker已经成为了很多公司部署应用、服务的首选方案。依靠容器技术,我们能在不同的体系结构之上轻松部署几乎任何种类的应用。作为测试一方,我们应与时俱进,将Docker容器技术应用到测试工作中。为了让小伙伴们可以快速上手…

32位mysql安装包_软件测试基础——Linux系统搭建MySQL数据库

一、mysql下载1. 下载:官方网址:https://dev.mysql.com/downloads/mysql/2. 选择相应的版本,由于cenos是基于红帽的,所以Select Operating System选择Red Hat...。我所用的镜像为cenos7所以Red Hat....linux7,一定要选相应的版本&…

python gevent模块 下载_Python中的多任务,并行,并发,多线程,多进程,协程区别...

多任务CPU承担了所有的计算任务。一个CPU在一个时间切片里只能运行一个程序。当我们想同时运行多于一个程序的时候,就是多任务,例如同时运行微信,QQ,浏览器等等。多任务的目的是提升程序的执行效率,更充分利用CPU的资源…

vue-router 路由嵌套显示不出来_网络协议|OSI模型第三层网络层中的路由

的IP协议OSI第二层中用以太网协议定义了信息传输单元,简称为帧,它长这个样子。同样的在OSI第三层中,会用 IP 协议去定义信息传输单元,简称为数据包,它长这个样子。实际上,最终在网络上传输的是第二层的帧&a…

asp.net怎么实现按条件查询_【33期】分别谈谈联合索引生效和失效的条件

点击上方“Java面试题精选”,关注公众号面试刷图,查缺补漏>>号外:往期面试题,10篇为一个单位归置到本公众号菜单栏->面试题,有需要的欢迎翻阅。这道题考查索引生效条件、失效条件。像这类问题才其实很有意义&…

java 二分搜索获得大于目标数的第一位_程序员数据结构算法编程,二分查找搜索算法的原理与应用介绍!...

本文来讲一种搜索算法,即二分搜索算法,通常在面试时也会被问到。我们先来看一个例子,在图书馆通常是根据查到的编号去找书,可以在书架上按顺序一本本地查找,也可以找到一本书不符合预期时,再跳过一大部分书…

2020idea插件怎么同步_VScode 插件整理

1、auto rename tag :HTML 标签自动闭合;避免了在整个页面中费劲查找。你想将一个H2标签更改为H3标签,或者你想将一个div标签更改为span标签,不管要做什么,你都要浪费时间来查找结束标签,这时候就该用这个插…

python 将两幅图拼接_清华王教授典藏的python电子书,整整10个G拿去不谢

终于拿到!清华王教授典藏的电子书,整整10个G!兄弟,毫无套路!无偿获取方式:1.点赞评论2.关注小编,私信“Python”(点开头像就能看到私信按钮啦).Python指南——五行代码实现批量抠图你是否曾经想将某张照片中…

地磅称重软件源码_【漯河衡器】导致地磅称重不准原因及处理措施

地磅是一种新型的大型电子衡器,能够迅速、直观、高准确度地展现工商业、仓储、货站贸易计量的重要工具。做为贸易结算的工具,地磅的可靠性、准确性、科学性有着极为重要的影响。而在货物来往中,地磅是等价交换的桥梁,一旦地磅显现…

寻宝机器人电路板焊接_专业维修淮安市KUKA库卡KRC2机器人回收{机器人调试}

FANUC机器人伺服-023故障排除:FANUCR-2000六轴焊接机器人点焊进程中,J4机械臂显现自动滑动故障,机器人发出伺服故障报警,报警故障码为伺服-023,依据FANUC机器人维修手册,故障代码解释以下:伺服误…

android uber启动动画,仿 Uber 视频背景登录界面以及登录动画

现在有越来越多的 app 的登录/注册界面的背景是播放视频或者 gif,我主要看了 Uber 和 keep 的登录界面再配合拉勾的登录界面仿作了一个登录界面。1.首先,查资料我在 github 上找到了这两个库:-STLBGVideo 这个库是 oc 写的,但你的…

遍历列表python_Python 遍历List的三种方法

转载至https://www.cnblogs.com/pizitai/p/6398276.html #!/usr/bin/env python # -*- coding: utf-8 -*- if __name__ __main__: list [html, js, css, python] # 方法1 print 遍历列表方法1: for i in list: print ("序号:%s 值:%s&…