推荐一个快速定位深度学习代码bug的炼丹神器!

文 | McGL

源 | 知乎


写深度学习网络代码,最大的挑战之一,尤其对新手来说,就是把所有的张量维度正确对齐。如果以前就有TensorSensor这个工具,相信我的头发一定比现在更浓密茂盛!

TensorSensor,码痴教授 Terence Parr 出品,他也是著名 parser 工具 ANTLR 的作者。

在包含多个张量和张量运算的复杂表达式中,张量的维数很容易忘了。即使只是将数据输入到预定义的 TensorFlow 网络层,维度也要弄对。当你要求进行错误的计算时,通常会得到一些没啥用的异常消息。为了帮助自己和其他程序员调试张量代码,Terence Parr 写了一个名叫 TensorSensor 的库(pip install tensor-sensor 直接安装) 。TensorSensor 通过增加消息和可视化 Python 代码来展示张量变量的形状,让异常更清晰(见下图)。它可以兼容 TensorFlow、PyTorch 和 Numpy以及 Keras 和 fastai 等高级库。

在张量代码中定位问题令人抓狂!

即使是专家,执行张量操作的 Python 代码行中发生异常,也很难快速定位原因。调试过程通常是在有问题的行前面添加一个 print 语句,以打出每个张量的形状。这需要编辑代码添加调试语句并重新运行训练过程。或者,我们可以使用交互式调试器手动单击或键入命令来请求所有张量形状。(这在像 PyCharm 这样的 IDE 中不太实用,因为在调试模式很慢。)下面将详细对比展示看了让人贫血的缺省异常消息和 TensorSensor 提出的方法,而不用调试器或 print 大法。

调试一个简单的线性层

让我们来看一个简单的张量计算,来说明缺省异常消息提供的信息不太理想。下面是一个包含张量维度错误的硬编码单(线性)网络层的简单 NumPy 实现。

import numpy as npn = 200                          # number of instances
d = 764                          # number of instance features
n_neurons = 100                  # how many neurons in this layer?W = np.random.rand(d,n_neurons)  # Ooops! Should be (n_neurons,d)
b = np.random.rand(n_neurons,1)
X = np.random.rand(n,d)          # fake input matrix with n rows of d-dimensionsY = W @ X.T + b                  # pass all X instances through layer

10 Y = W @ X.T + b

ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 764 is different from 100)

执行该代码会触发一个异常,其重要元素如下:

...
---> 10 Y = W @ X.T + b
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 764 is different from 100)

异常显示了出错的行以及是哪个操作(matmul: 矩阵乘法),但是如果给出完整的张量维数会更有用。此外,这个异常也无法区分在 Python 的一行中的多个矩阵乘法。

接下来,让我们看看 TensorSensor 如何使调试语句更加容易的。如果我们使用 Python with 和tsensor 的 clarify()包装语句,我们将得到一个可视化和增强的错误消息。

import tsensor
with tsensor.clarify():Y = W @ X.T + b
...
ValueError: matmul: Input operand ...
Cause: @ on tensor operand W w/shape (764, 100) and operand X.T w/shape (764, 200)

从可视化中可以清楚地看到,W 的维度应该翻转为 n _ neurons x d; W 的列必须与 X.T 的行匹配。您还可以检查一个完整的带有和不带阐明()的并排图像,以查看它在笔记本中的样子。下面是带有和没有 clarify() 的例子在notebook 中的比较。

clarify() 功能在没有异常时不会增加正在执行的程序任何开销。有异常时, clarify():

  1. 增加由底层张量库创建的异常对象消息。

  2. 给出出错操作所涉及的张量大小的可视化表示; 只突出显示异常涉及的操作对象和运算符,而其他 Python 元素则不突出显示。

TensorSensor 还区分了 PyTorch 和 TensorFlow 引发的与张量相关的异常。下面是等效的代码片段和增强的异常错误消息(Cause: @ on tensor ...)以及 TensorSensor 的可视化:

PyTorch 消息没有标识是哪个操作触发了异常,但 TensorFlow 的消息指出了是矩阵乘法。两者都显示操作对象维度。

调试复杂的张量表达式

缺省消息缺乏具体细节,在包含大量操作符的更复杂的语句中,识别出有问题的子表达式很难。例如,下面是从一个门控循环单元(GRU)实现的内部提取的一个语句:

h_ = torch.tanh(Whh_ @ (r*h) + Uxh_ @ X.T + bh_)

这是什么计算或者变量代表什么不重要,它们只是张量变量。有两个矩阵乘法,两个向量加法,还有一个向量逐元素修改(r*h)。如果没有增强的错误消息或可视化,我们就无法知道是哪个操作符或操作对象导致了异常。为了演示 TensorSensor 在这种情况下是如何分清异常的,我们需要给语句中使用的变量(为 h _ 赋值)一些伪定义,以得到可执行代码:

nhidden = 256
Whh_ = torch.eye(nhidden, nhidden)  # Identity matrix
Uxh_ = torch.randn(d, nhidden)
bh_  = torch.zeros(nhidden, 1)
h = torch.randn(nhidden, 1)         # fake previous hidden state h
r = torch.randn(nhidden, 1)         # fake this computation
X = torch.rand(n,d)                 # fake inputwith tsensor.clarify():h_ = torch.tanh(Whh_ @ (r*h) + Uxh_ @ X.T + bh_)

同样,你可以忽略代码执行的实际计算,将重点放在张量变量的形状上。

对于我们大多数人来说,仅仅通过张量维数和张量代码是不可能识别问题的。当然,默认的异常消息是有帮助的,但是我们中的大多数人仍然难以定位问题。以下是默认异常消息的关键部分(注意对 C++ 代码的不太有用的引用) :

---> 10     h_ = torch.tanh(Whh_ @ (r*h) + Uxh_ @ X.T + bh_)
RuntimeError: size mismatch, m1: [764 x 256], m2: [764 x 200] at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorMath.cpp:41

我们需要知道的是哪个操作符和操作对象出错了,然后我们可以通过维数来确定问题。以下是 TensorSensor 的可视化和增强的异常消息:

---> 10 h_ = torch.tanh(Whh_ @ (r*h) + Uxh_ @ X.T + bh_)
RuntimeError: size mismatch, m1: [764 x 256], m2: [764 x 200] at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorMath.cpp:41
Cause: @ on tensor operand Uxh_ w/shape [764, 256] and operand X.T w/shape [764, 200]

人眼可以迅速锁定在指示的算子和矩阵相乘的维度上。哎呀, Uxh 的列必须与 X.T的行匹配,Uxh_的维度翻转了,应该为:

Uxh_ = torch.randn(nhidden, d)

现在,我们只在 with 代码块中使用我们自己直接指定的张量计算。那么在张量库的内置预建网络层中触发的异常又会如何呢?

理清预建层中触发的异常

TensorSensor 可视化进入你选择的张量库前的最后一段代码。例如,让我们使用标准的 PyTorch nn.Linear 线性层,但输入一个 X 矩阵维度是 n x n,而不是正确的 n x d:

L = torch.nn.Linear(d, n_neurons)
X = torch.rand(n,n) # oops! Should be n x d
with tsensor.clarify():Y = L(X)

增强的异常信息

RuntimeError: size mismatch, m1: [200 x 200], m2: [764 x 100] at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorMath.cpp:41
Cause: L(X) tensor arg X w/shape [200, 200]

TensorSensor 将张量库的调用视为操作符,无论是对网络层还是对 torch.dot(a,b) 之类的简单操作的调用。在库函数中触发的异常会产生消息,消息标示了函数和任何张量参数的维数。

后台回复关键词【入群

加入卖萌屋NLP/IR/Rec与求职讨论群

后台回复关键词【顶会

获取ACL、CIKM等各大顶会论文集!

 

[1] https://explained.ai/tensor-sensor/index.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/479517.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

研讨会 | 知识图谱引领认知智能+

本文转载自公众号:中国计算机学会。本论坛将于 CNCC2019 中国计算机大会第一天(10月17日)在苏州金鸡湖国际会议中心 A102 会议室举行,共邀微软、阿里巴巴、华为、小米、浙江大学、苏州大学等机构的专家与你探讨。知识图谱是一种用…

LeetCode 617. 合并二叉树

文章目录1. 题目2. 递归解题1. 题目 给定两个二叉树,想象当你将它们中的一个覆盖到另一个上时,两个二叉树的一些节点便会重叠。 你需要将他们合并为一个新的二叉树。合并的规则是如果两个节点重叠,那么将他们的值相加作为节点合并后的新值&…

BIO,NIO,AIO

BIO,NIO,AIO 介绍一、背景1.1 说明1.2 通信技术整体解决的问题二、Java的I/O演进之路2.1 I/O 模型基本说明2.2 I/O模型Java BIOJava NIOJava AIO2.3 BIO、NIO、AIO 适用场景分析三、BIO,NIO,AIO总结一、背景 1.1 说明 ​ 在Java的软件设计开发中,通信架构是不可避…

学PyTorch还是TensorFlow?

在机器学习领域,面对各类复杂多变的业务问题,构建灵活易调整的模型是高阶机器学习工程师必备的工作能力。然而,许多工程师还是有一个想法上的误区,以为只要掌握了一种深度学习的框架就能走遍天下了。事实上,在机器学习…

Bifrost微前端框架及其在美团闪购中的实践

Bifrost(英 [‘bi:frɔst])原意彩虹桥,北欧神话中是连通天地的一条通道。而在漫威电影《雷神》中,Bifrost是神域——阿斯加德(Asgard)的出入口,神域的人通过它自由穿梭于“九界”(指…

设计模式之观察者模式在Listview中的应用

有时候我们会有这么一个需求,在Listview的某个Item上有个按钮,点击这个按钮之后呢,需要对其它的item做一些操作,就像下面这个: 采纳按钮点击之前:采纳按钮点击之后: 简单介绍一下这两张图的意…

新书速递 | 《知识图谱:方法、实践与应用》

本文转载自公众号:博文视点Broadview 。互联网促成了大数据的集聚,大数据进而促进了人工智能算法的进步。近年来知识图谱作为AI领域底层技术被越来越多的人谈起。知识图谱的升温得益于新数据和新算法为规模化知识图谱构建提供了新的技术基础和发展条件&a…

Github Star过万的阿里学长独家干货分享

浅梦是我认识的一位浙大计算机系的学长,目前在阿里从事算法相关的工作。无论在学校还是工作中,他都保持着对新知识的学习和分享。他的github star 1w,世界排名700,参与开发的项目下载量接近30w次。主要涉及「推荐系统」&#xff0…

React Native在美团外卖客户端的实践

MRN简介 MRN(Meituan React Native) 是基于开源的React Native框架改造并完善而成的一套动态化方案,在开发体验上基本能与原生RN保持一致,同时从业务需求的角度满足从开发、构建、测试、部署、运维的工程化需要。解决了一系列痛点…

论文浅尝 | 使用预训练深度模型和迁移学习方法的端到端模糊实体匹配

论文笔记整理:高凤宁,南京大学硕士,研究方向为知识图谱、实体消解。链接:https://doi.org/10.1145/3308558.3313578动机目前实体匹配过程中实体之间的差异比较微妙,不同的情况下可能会有不同的决策结果,导致…

推荐几个Android开发非常有用的工具(for android studio)

原文地址: http://stormzhang.com/android/2015/05/26/android-tools/ 一晃好久没更新博客了,最近一个月真的很忙,因为公司在准备C轮融资,公司的发展到了一个关键的阶段,自己全部精力投入在公司产品上,这个状态可能还会…

分布式机器学习(下)-联邦学习

原文链接:https://zhuanlan.zhihu.com/p/114028503 本视频来源于Shusen Wang讲解的《分布式机器学习》,总共有三讲,内容和连接如下:并行计算与机器学习(上)并行计算与机器学习(下)联…

怎样将Embedding融入传统机器学习框架?

文 | 石塔西源 | 知乎LR本身是一个经典的CTR模型,广泛应用于推荐/广告系统。输入的特征大多数是离散型/组合型。那么对于Embedding技术,如何在不使用深度学习模型的情况下(假设就是不能用DNN),融入到LR框架中呢&#x…

推荐系统中的Embedding

推荐系统之Embedding一、什么是embedding?1. 让embedding空前流行的word2vec:2. 从word2vec到item2vec二、Graph Embedding1. 经典的Graph Embedding方法 — DeepWalk2. DeepWalk改进 — Node2vec3. 阿里的Graph Embedding方法EGES三、深度学习推荐系统中…

美团下一代服务治理系统 OCTO 2.0 的探索与实践

本文根据美团基础架构部服务治理团队工程师郭继东在2019 QCon(全球软件开发大会)上的演讲内容整理而成,主要阐述美团大规模治理体系结合 Service Mesh 演进的探索实践,希望对从事此领域的同学有所帮助。 一、OCTO 现状分析 OCTO 是…

技术动态 | 跨句多元关系抽取

本文转载自公众号&#xff1a;知识工场。第一部分 概述关系抽取简介关系抽取是从自由文本中获取实体间所具有的语义关系。这种语义关系常以三元组 <E1,R,E2> 的形式表达&#xff0c;其中&#xff0c;E1 和E2 表示实体&#xff0c;R 表示实体间所具有的语义关系。如图1所示…

网络解析(一):LeNet-5详解

原文链接&#xff1a;https://cuijiahua.com/blog/2018/01/dl_3.html 2018年1月9日21:03:313994,282 C摘要LeNet-5出自论文Gradient-Based Learning Applied to Document Recognition&#xff0c;是一种用于手写体字符识别的非常高效的卷积神经网络。一、前言LeNet-5出自论文Gr…

LeetCode 69. x 的平方根(二分查找)

文章目录1. 题目2.解题2.1 二分查找2.2 牛顿迭代1. 题目 实现 int sqrt(int x) 函数。 计算并返回 x 的平方根&#xff0c;其中 x 是非负整数。 由于返回类型是整数&#xff0c;结果只保留整数的部分&#xff0c;小数部分将被舍去。 示例 1:输入: 4 输出: 2 示例 2:输入: 8…

Google综述:细数Transformer模型的17大高效变种

文 | 黄浴来源 | 知乎在NLP领域transformer已经是成功地取代了RNN&#xff08;LSTM/GRU&#xff09;&#xff0c;在CV领域也出现了应用&#xff0c;比如目标检测和图像加注&#xff0c;还有RL领域。这是一篇谷歌2020年9月份在arXiv发表的综述论文 “Efficient Transformers: A …

从ReentrantLock的实现看AQS的原理及应用

前言 Java中的大部分同步类&#xff08;Lock、Semaphore、ReentrantLock等&#xff09;都是基于AbstractQueuedSynchronizer&#xff08;简称为AQS&#xff09;实现的。AQS是一种提供了原子式管理同步状态、阻塞和唤醒线程功能以及队列模型的简单框架。本文会从应用层逐渐深入到…