机器学习中qa测试_如何对机器学习做单元测试

作者:Chase Roberts

编译:ronghuaiyang

导读

养成良好的单元测试的习惯,真的是受益终身的,特别是机器学习代码,有些bug真不是看看就能看出来的。

9ad1a64e1939284f73524ee5c01a7131.png

在过去的一年里,我把大部分的工作时间都花在了深度学习研究和实习上。那一年,我犯了很多大错误,这些错误不仅帮助我了解了ML,还帮助我了解了如何正确而稳健地设计这些系统。我在谷歌Brain学到的一个主要原则是,单元测试可以决定算法的成败,可以为你节省数周的调试和训练时间。

然而,在如何为神经网络代码编写单元测试方面,似乎没有一个可靠的在线教程。即使是像OpenAI这样的地方,也只是通过盯着他们代码的每一行,并试着思考为什么它会导致bug来发现bug的。显然,我们大多数人都没有这样的时间,所以希望本教程能够帮助你开始理智地测试你的系统!

让我们从一个简单的例子开始。试着找出这段代码中的错误。

def make_convnet(input_image):    net = slim.conv2d(input_image, 32, [11, 11], scope="conv1_11x11")    net = slim.conv2d(input_image, 64, [5, 5], scope="conv2_5x5")    net = slim.max_pool2d(net, [4, 4], stride=4, scope='pool1')    net = slim.conv2d(input_image, 64, [5, 5], scope="conv3_5x5")    net = slim.conv2d(input_image, 128, [3, 3], scope="conv4_3x3")    net = slim.max_pool2d(net, [2, 2], scope='pool2')    net = slim.conv2d(input_image, 128, [3, 3], scope="conv5_3x3")    net = slim.max_pool2d(net, [2, 2], scope='pool3')    net = slim.conv2d(input_image, 32, [1, 1], scope="conv6_1x1")    return net

你看到了吗?网络实际上并没有堆积起来。在编写这段代码时,我复制并粘贴了slim.conv2d(…)行,并且只修改了内核大小,而没有修改实际的输入。

我很不好意思地说,这件事在一周前就发生在我身上了……但这是很重要的一课!由于一些原因,这些bug很难捕获。

  1. 这段代码不会崩溃,不会产生错误,甚至不会变慢。
  2. 这个网络仍在运行,损失仍将下降。
  3. 几个小时后,这些值就会收敛,但结果却非常糟糕,让你摸不着头脑,不知道需要修复什么。

当你唯一的反馈是最终的验证错误时,你惟一需要搜索的地方就是你的整个网络体系结构。不用说,你需要一个更好的系统。

那么,在我们进行完整的多日训练之前,我们如何真正抓住这个机会呢?关于这个最容易注意到的是层的值实际上不会到达函数外的任何其他张量。假设我们有某种类型的损失和一个优化器,这些张量永远不会得到优化,所以它们总是有它们的默认值。

我们可以通过简单的训练步骤和前后对比来检测它。

def test_convnet():  image = tf.placeholder(tf.float32, (None, 100, 100, 3)  model = Model(image)  sess = tf.Session()  sess.run(tf.global_variables_initializer())  before = sess.run(tf.trainable_variables())  _ = sess.run(model.train, feed_dict={               image: np.ones((1, 100, 100, 3)),               })  after = sess.run(tf.trainable_variables())  for b, a, n in zip(before, after):      # Make sure something changed.      assert (b != a).any()

在不到15行代码中,我们现在验证了至少我们创建的所有变量都得到了训练。

这个测试超级简单,超级有用。假设我们修复了前面的问题,现在我们要开始添加一些批归一化。看看你能否发现这个bug。

  def make_convnet(image_input):        # Try to normalize the input before convoluting        net = slim.batch_norm(image_input)        net = slim.conv2d(net, 32, [11, 11], scope="conv1_11x11")        net = slim.conv2d(net, 64, [5, 5], scope="conv2_5x5")        net = slim.max_pool2d(net, [4, 4], stride=4, scope='pool1')        net = slim.conv2d(net, 64, [5, 5], scope="conv3_5x5")        net = slim.conv2d(net, 128, [3, 3], scope="conv4_3x3")        net = slim.max_pool2d(net, [2, 2], scope='pool2')        net = slim.conv2d(net, 128, [3, 3], scope="conv5_3x3")        net = slim.max_pool2d(net, [2, 2], scope='pool3')        net = slim.conv2d(net, 32, [1, 1], scope="conv6_1x1")        return net

你看到了吗?这个非常微妙。您可以看到,在tensorflow batch_norm中,is_training的默认值是False,所以添加这行代码并不能使你在训练期间的输入正常化!值得庆幸的是,我们编写的最后一个单元测试将立即发现这个问题!(我知道,因为这是三天前发生在我身上的事。)

再看一个例子。这实际上来自我一天看到的一篇文章(https://www.reddit.com/r/MachineLearning/comments/6qyvvg/p_tensorflow_response_is_making_no_sense/)。我不会讲太多细节,但是基本上这个人想要创建一个输出范围为(0,1)的分类器。

class Model:  def __init__(self, input, labels):    """Classifier model    Args:      input: Input tensor of size (None, input_dims)      label: Label tensor of size (None, 1).         Should be of type tf.int32.    """    prediction = self.make_network(input)    # Prediction size is (None, 1).    self.loss = tf.nn.softmax_cross_entropy_with_logits(        logits=prediction, labels=labels)    self.train_op = tf.train.AdamOptimizer().minimize(self.loss)

注意到这个错误吗?这是真的很难提前发现,并可能导致超级混乱的结果。基本上,这里发生的是预测只有一个输出,当你将softmax交叉熵应用到它上时,它的损失总是0。

一个简单的测试方法是确保损失不为0。

def test_loss():  in_tensor = tf.placeholder(tf.float32, (None, 3))  labels = tf.placeholder(tf.int32, None, 1))  model = Model(in_tensor, labels)  sess = tf.Session()  loss = sess.run(model.loss, feed_dict={    in_tensor:np.ones(1, 3),    labels:[[1]]  })  assert loss != 0

另一个很好的测试与我们的第一个测试类似,但是是反向的。你可以确保只有你想训练的变量得到了训练。以GAN为例。出现的一个常见错误是在进行优化时不小心忘记设置要训练的变量。这样的代码经常发生。

class GAN:  def __init__(self, z_vector, true_images):    # Pretend these are implemented.    with tf.variable_scope("gen"):      self.make_geneator(z_vector)    with tf.variable_scope("des"):      self.make_descriminator(true_images)    opt = tf.AdamOptimizer()    train_descrim = opt.minimize(self.descrim_loss)    train_gen = opt.minimize(self.gen_loss)

这里最大的问题是优化器有一个默认设置来优化所有变量。在像GANs这样的高级架构中,这是对你所有训练时间的死刑判决。但是,你可以通过编写这样的测试来轻松地发现这些错误:

def test_gen_training():  model = Model  sess = tf.Session()  gen_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='gen')  des_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='des')  before_gen = sess.run(gen_vars)  before_des = sess.run(des_vars)  # Train the generator.  sess.run(model.train_gen)  after_gen = sess.run(gen_vars)  after_des = sess.run(des_vars)  # Make sure the generator variables changed.  for b,a in zip(before_gen, after_gen):    assert (a != b).any()  # Make sure descriminator did NOT change.  for b,a in zip(before_des, after_des):    assert (a == b).all()

可以为鉴别器编写一个非常类似的测试。同样的测试也可以用于许多强化学习算法。许多行为-批评模型有单独的网络,需要根据不同的损失进行优化。

下面是一些我推荐你进行测试的模式。

  1. 让测试具有确定性。如果一个测试以一种奇怪的方式失败,却永远无法重现这个错误,那就太糟糕了。如果你真的想要随机输入,确保使用种子随机数,这样你就可以轻松地重新运行测试。
  2. 保持测试简短。不要使用单元测试来训练收敛性并检查验证集。这样做是在浪费自己的时间。
  3. 确保你在每个测试之间重置了计算图。

总之,这些黑箱算法仍然有很多方法需要测试!花一个小时写一个测试可以节省你几天的重新运行训练模型,并可以大大提高你的研究效率。因为我们的实现有缺陷而不得不放弃完美的想法,这不是很糟糕吗?

这个列表显然不全面,但它是一个坚实的开始!

英文原文:https://medium.com/@keeper6928/how-to-unit-test-machine-learning-code-57cf6fd81765

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/454365.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一个从文本文件里“查找并替换”的功能

12345678910111213141516171819202122232425# -*- coding: UTF-8 -*-file input("请输入文件路径:") word1 input("请输入要替换的词:") word2 input("请输入新的词:") fopen(file,"r") AAAf.read() count 0 def BBB()…

机器学习算法之 KNN

K近邻法(k-nearst neighbors,KNN)是一种很基本的机器学习方法了,在我们平常的生活中也会不自主的应用。比如,我们判断一个人的人品,只需要观察他来往最密切的几个人的人品好坏就可以得出了。这里就运用了KNN的思想。KNN方法既可以做分类&…

安装云端服务器操作系统,安装云端服务器操作系统

安装云端服务器操作系统 内容精选换一换SAP云服务器规格在申请SAP ECS之前,请参考SAP标准Sizing方法进行SAPS值评估,并根据Sizing结果申请云端ECS服务器资源,详细信息请参考SAP Quick Sizer。SAP 各组件最低硬盘空间、RAM,以及软件…

python 进度条_六种酷炫Python运行进度条

转自:一行数据阅读文本大概需要 3 分钟你的代码进度还剩多少?今天给大家介绍下目前6种比较常用的进度条,让大家都能直观地看到脚本运行最新的进展情况。1.普通进度条2.带时间进度条3.tpdm进度条4.progress进度条5.alive_progress进度条6.可视…

权限之浅理解

白马过隙,在感叹时光流逝的同时不得不承认在学习中随着知识面的不断扩展所接受的东西也越来越多,尤其是那些外形比较容易混淆的命令,着实让作为新手的吃了很多苦头,趁着学习紧张之时偷个懒整理这周易混淆的命令: chgrp…

机器学习算法之生成树

一、什么是决策树? 决策树(Decision Tree)是一种基本的分类和回归的方法。 分类决策树模型是一种描述对实例进行分类的树形结构。决策树由结点(node)和有向边(directed edge)组成。结点有两种…

机器学习算法之集成学习

集成学习的思想是将若干个学习器(分类器&回归器)组合之后产生一个新学习器。弱分类器(weak learner)指那些分类准确率只稍微好于随机猜测的分类器(errorrate < 0.5)。 集成算法的成功在于保证弱分类器的多样性(Diversity)。而且集成不稳定的算法也能够得到一个比较明显…

常用的方法论-NPS

转载于:https://www.cnblogs.com/qjm201000/p/7687510.html

controller调用controller的方法_SpringBoot 优雅停止服务的几种方法

转自&#xff1a;博客园&#xff0c;作者&#xff1a;黄青石www.cnblogs.com/huangqingshi/p/11370291.html 在使用 SpringBoot 的时候&#xff0c;都要涉及到服务的停止和启动&#xff0c;当我们停止服务的时候&#xff0c;很多时候大家都是kill -9 直接把程序进程杀掉&#x…

机器学习之聚类概述

什么是聚类 聚类就是对大量未知标注的数据集&#xff0c;按照数据 内部存在的数据特征 将数据集划分为 多个不同的类别 &#xff0c;使 类别内的数据比较相似&#xff0c;类别之间的数据相似度比较小&#xff1b;属于 无监督学习。 聚类算法的重点是计算样本项之间的 相似度&…

qt release打包发布_几种解决Qt程序打包后无法连接数据库问题的方法

Qt是一个跨平台C图形用户界面应用程序开发框架&#xff0c;使用它不仅可以方便地开发GUI程序&#xff0c;也可以开发非GUI程序&#xff0c;可以一次编写&#xff0c;处处编译。今天遇到的问题比较怪异&#xff0c;我开发的是一个桌面版订单管理系统&#xff0c;整体架构就是一个…

机器学习之拉格朗日乘子法和 KKT

有约束的最优化问题 最优化问题一般是指对于某一个函数而言&#xff0c;求解在其指定作用域上的全局最小值问题&#xff0c;一般分为以下三种情况(备注&#xff1a;以下几种方式求出来的解都有可能是局部极小值&#xff0c;只有当函数是凸函数的时候&#xff0c;才可以得到全局…

pmp思维导图 第六版_PMP考试技巧攻略(上)

PMP考试需要有保证足够的时间投入&#xff1a;获得PMP 考试并拿到5A 成绩&#xff0c;并且还需要理解性记忆&#xff1a;PMP 指定教材PMBOK第六版&#xff08;教材为必看三遍以上&#xff09;&#xff0c;学习起来是有趣的&#xff0c;同时也是痛苦的。因为看书时字面的字我们认…

浅谈MVC MVP MVVM

复杂的软件必须有清晰合理的架构&#xff0c;否则无法开发和维护。 MVC&#xff08;Model-View-Controller&#xff09;是最常见的软件架构之一&#xff0c;业界有着广泛应用。 它本身很容易理解&#xff0c;但是要讲清楚&#xff0c;它与衍生的 MVP 和 MVVM 架构的区别就不容易…

商务搜索引擎_外贸研修 | 世界各国常用搜索引擎,开发客户必备!

我们平时生活中也好&#xff0c;开发客户也好&#xff0c;搜索引擎是我们离不开的工具。最佳没有之一的当属谷歌了。谷歌网址&#xff1a;www.google.com谷歌高级搜索&#xff1a;https://www.google.com/advanced_search (通过设置/排除一些字词缩小精确搜索范围)作为普通使用…

HaProxy+Keepalived+Mycat高可用群集配置

概述 本章节主要介绍配置HaProxyKeepalived高可用群集&#xff0c;Mycat的配置就不在这里做介绍&#xff0c;可以参考我前面写的几篇关于Mycat的文章。 部署图&#xff1a; 配置 HaProxy安装 181和179两台服务器安装haproxy的步骤一致 --创建haproxy用户 useradd haproxy--…

奇怪的bug,不懂Atom在添加markdown-themeable-pdf,在配置好phantomjs的情况下报错

本来打算用一下atom但是导出pdf报错&#xff0c;可是在预览的情况下就没有问题&#xff0c;顺便吐槽一下谷歌浏览器自己的markdown在线预览插件无法适配&#xff0c;用搜狗搭载谷歌的插件才能导出pdf&#xff0c;一下感觉逼格少了很多&#xff0c;等忙完这阵再来看一下。先贴出…

Python 面试题

Python面试315道题第一部 Python面试题基础篇&#xff08;80道&#xff09;1、为什么学习Python&#xff1f;2、通过什么途径学习的Python&#xff1f;3、Python和Java、PHP、C、C#、C等其他语言的对比&#xff1f;PHPjavacc#c4、简述解释型和编译型编程语言&#xff1f;编译型…

bzoj1038500AC!

序列dp 先开始想了一个类似区间dp的东西...少了一维 然后发现似乎不太对&#xff0c;因为女生的最大差和男生的最大差并不相等 dp[i][j][x][y]表示当前有i个人&#xff0c;j个男生&#xff0c;男生和女生的后缀最大差是x&#xff0c;女生和男生最大差是y&#xff0c;x,y>0,转…

android生命周期_Android开发 View的生命周期结合代码详解

咱们以TextView控件为例&#xff1a;/*** Created by SunshineBoy on 2020/9/23.*/public class TestTextView extends android.support.v7.widget.AppCompatTextView {public TestTextView(Context context) {super(context);Log.e("TestTextView","TestTextVi…