机器学习中qa测试_如何对机器学习做单元测试

作者:Chase Roberts

编译:ronghuaiyang

导读

养成良好的单元测试的习惯,真的是受益终身的,特别是机器学习代码,有些bug真不是看看就能看出来的。

9ad1a64e1939284f73524ee5c01a7131.png

在过去的一年里,我把大部分的工作时间都花在了深度学习研究和实习上。那一年,我犯了很多大错误,这些错误不仅帮助我了解了ML,还帮助我了解了如何正确而稳健地设计这些系统。我在谷歌Brain学到的一个主要原则是,单元测试可以决定算法的成败,可以为你节省数周的调试和训练时间。

然而,在如何为神经网络代码编写单元测试方面,似乎没有一个可靠的在线教程。即使是像OpenAI这样的地方,也只是通过盯着他们代码的每一行,并试着思考为什么它会导致bug来发现bug的。显然,我们大多数人都没有这样的时间,所以希望本教程能够帮助你开始理智地测试你的系统!

让我们从一个简单的例子开始。试着找出这段代码中的错误。

def make_convnet(input_image):    net = slim.conv2d(input_image, 32, [11, 11], scope="conv1_11x11")    net = slim.conv2d(input_image, 64, [5, 5], scope="conv2_5x5")    net = slim.max_pool2d(net, [4, 4], stride=4, scope='pool1')    net = slim.conv2d(input_image, 64, [5, 5], scope="conv3_5x5")    net = slim.conv2d(input_image, 128, [3, 3], scope="conv4_3x3")    net = slim.max_pool2d(net, [2, 2], scope='pool2')    net = slim.conv2d(input_image, 128, [3, 3], scope="conv5_3x3")    net = slim.max_pool2d(net, [2, 2], scope='pool3')    net = slim.conv2d(input_image, 32, [1, 1], scope="conv6_1x1")    return net

你看到了吗?网络实际上并没有堆积起来。在编写这段代码时,我复制并粘贴了slim.conv2d(…)行,并且只修改了内核大小,而没有修改实际的输入。

我很不好意思地说,这件事在一周前就发生在我身上了……但这是很重要的一课!由于一些原因,这些bug很难捕获。

  1. 这段代码不会崩溃,不会产生错误,甚至不会变慢。
  2. 这个网络仍在运行,损失仍将下降。
  3. 几个小时后,这些值就会收敛,但结果却非常糟糕,让你摸不着头脑,不知道需要修复什么。

当你唯一的反馈是最终的验证错误时,你惟一需要搜索的地方就是你的整个网络体系结构。不用说,你需要一个更好的系统。

那么,在我们进行完整的多日训练之前,我们如何真正抓住这个机会呢?关于这个最容易注意到的是层的值实际上不会到达函数外的任何其他张量。假设我们有某种类型的损失和一个优化器,这些张量永远不会得到优化,所以它们总是有它们的默认值。

我们可以通过简单的训练步骤和前后对比来检测它。

def test_convnet():  image = tf.placeholder(tf.float32, (None, 100, 100, 3)  model = Model(image)  sess = tf.Session()  sess.run(tf.global_variables_initializer())  before = sess.run(tf.trainable_variables())  _ = sess.run(model.train, feed_dict={               image: np.ones((1, 100, 100, 3)),               })  after = sess.run(tf.trainable_variables())  for b, a, n in zip(before, after):      # Make sure something changed.      assert (b != a).any()

在不到15行代码中,我们现在验证了至少我们创建的所有变量都得到了训练。

这个测试超级简单,超级有用。假设我们修复了前面的问题,现在我们要开始添加一些批归一化。看看你能否发现这个bug。

  def make_convnet(image_input):        # Try to normalize the input before convoluting        net = slim.batch_norm(image_input)        net = slim.conv2d(net, 32, [11, 11], scope="conv1_11x11")        net = slim.conv2d(net, 64, [5, 5], scope="conv2_5x5")        net = slim.max_pool2d(net, [4, 4], stride=4, scope='pool1')        net = slim.conv2d(net, 64, [5, 5], scope="conv3_5x5")        net = slim.conv2d(net, 128, [3, 3], scope="conv4_3x3")        net = slim.max_pool2d(net, [2, 2], scope='pool2')        net = slim.conv2d(net, 128, [3, 3], scope="conv5_3x3")        net = slim.max_pool2d(net, [2, 2], scope='pool3')        net = slim.conv2d(net, 32, [1, 1], scope="conv6_1x1")        return net

你看到了吗?这个非常微妙。您可以看到,在tensorflow batch_norm中,is_training的默认值是False,所以添加这行代码并不能使你在训练期间的输入正常化!值得庆幸的是,我们编写的最后一个单元测试将立即发现这个问题!(我知道,因为这是三天前发生在我身上的事。)

再看一个例子。这实际上来自我一天看到的一篇文章(https://www.reddit.com/r/MachineLearning/comments/6qyvvg/p_tensorflow_response_is_making_no_sense/)。我不会讲太多细节,但是基本上这个人想要创建一个输出范围为(0,1)的分类器。

class Model:  def __init__(self, input, labels):    """Classifier model    Args:      input: Input tensor of size (None, input_dims)      label: Label tensor of size (None, 1).         Should be of type tf.int32.    """    prediction = self.make_network(input)    # Prediction size is (None, 1).    self.loss = tf.nn.softmax_cross_entropy_with_logits(        logits=prediction, labels=labels)    self.train_op = tf.train.AdamOptimizer().minimize(self.loss)

注意到这个错误吗?这是真的很难提前发现,并可能导致超级混乱的结果。基本上,这里发生的是预测只有一个输出,当你将softmax交叉熵应用到它上时,它的损失总是0。

一个简单的测试方法是确保损失不为0。

def test_loss():  in_tensor = tf.placeholder(tf.float32, (None, 3))  labels = tf.placeholder(tf.int32, None, 1))  model = Model(in_tensor, labels)  sess = tf.Session()  loss = sess.run(model.loss, feed_dict={    in_tensor:np.ones(1, 3),    labels:[[1]]  })  assert loss != 0

另一个很好的测试与我们的第一个测试类似,但是是反向的。你可以确保只有你想训练的变量得到了训练。以GAN为例。出现的一个常见错误是在进行优化时不小心忘记设置要训练的变量。这样的代码经常发生。

class GAN:  def __init__(self, z_vector, true_images):    # Pretend these are implemented.    with tf.variable_scope("gen"):      self.make_geneator(z_vector)    with tf.variable_scope("des"):      self.make_descriminator(true_images)    opt = tf.AdamOptimizer()    train_descrim = opt.minimize(self.descrim_loss)    train_gen = opt.minimize(self.gen_loss)

这里最大的问题是优化器有一个默认设置来优化所有变量。在像GANs这样的高级架构中,这是对你所有训练时间的死刑判决。但是,你可以通过编写这样的测试来轻松地发现这些错误:

def test_gen_training():  model = Model  sess = tf.Session()  gen_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='gen')  des_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='des')  before_gen = sess.run(gen_vars)  before_des = sess.run(des_vars)  # Train the generator.  sess.run(model.train_gen)  after_gen = sess.run(gen_vars)  after_des = sess.run(des_vars)  # Make sure the generator variables changed.  for b,a in zip(before_gen, after_gen):    assert (a != b).any()  # Make sure descriminator did NOT change.  for b,a in zip(before_des, after_des):    assert (a == b).all()

可以为鉴别器编写一个非常类似的测试。同样的测试也可以用于许多强化学习算法。许多行为-批评模型有单独的网络,需要根据不同的损失进行优化。

下面是一些我推荐你进行测试的模式。

  1. 让测试具有确定性。如果一个测试以一种奇怪的方式失败,却永远无法重现这个错误,那就太糟糕了。如果你真的想要随机输入,确保使用种子随机数,这样你就可以轻松地重新运行测试。
  2. 保持测试简短。不要使用单元测试来训练收敛性并检查验证集。这样做是在浪费自己的时间。
  3. 确保你在每个测试之间重置了计算图。

总之,这些黑箱算法仍然有很多方法需要测试!花一个小时写一个测试可以节省你几天的重新运行训练模型,并可以大大提高你的研究效率。因为我们的实现有缺陷而不得不放弃完美的想法,这不是很糟糕吗?

这个列表显然不全面,但它是一个坚实的开始!

英文原文:https://medium.com/@keeper6928/how-to-unit-test-machine-learning-code-57cf6fd81765

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/454365.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

项目宝提供的服务器,开源WebSocket服务器项目宝贝鱼CshBBrain V4.0.1 和 V2.0.2发布

开源WebSocket服务器项目宝贝鱼CshBBrain V4.0.1 和 V2.0.2发布更新的功能列表如下:1.解决开启广播消息开关时,不能同时接入2个客户端的重大缺陷。2.对广播消息做了重大优化,从以前一个线程发送广播消息进化到使用工作线程池中的线程并行的发…

c# 无损高质量压缩图片代码

/// 无损压缩图片 /// <param name"sFile">原图片</param> /// <param name"dFile">压缩后保存位置</param> /// <param name"dHeight">高度</param> /// <param name"dWidth"…

一个从文本文件里“查找并替换”的功能

12345678910111213141516171819202122232425# -*- coding: UTF-8 -*-file input("请输入文件路径:") word1 input("请输入要替换的词:") word2 input("请输入新的词&#xff1a;") fopen(file,"r") AAAf.read() count 0 def BBB()…

机器学习算法之 KNN

K近邻法(k-nearst neighbors,KNN)是一种很基本的机器学习方法了&#xff0c;在我们平常的生活中也会不自主的应用。比如&#xff0c;我们判断一个人的人品&#xff0c;只需要观察他来往最密切的几个人的人品好坏就可以得出了。这里就运用了KNN的思想。KNN方法既可以做分类&…

安装云端服务器操作系统,安装云端服务器操作系统

安装云端服务器操作系统 内容精选换一换SAP云服务器规格在申请SAP ECS之前&#xff0c;请参考SAP标准Sizing方法进行SAPS值评估&#xff0c;并根据Sizing结果申请云端ECS服务器资源&#xff0c;详细信息请参考SAP Quick Sizer。SAP 各组件最低硬盘空间、RAM&#xff0c;以及软件…

python 进度条_六种酷炫Python运行进度条

转自&#xff1a;一行数据阅读文本大概需要 3 分钟你的代码进度还剩多少&#xff1f;今天给大家介绍下目前6种比较常用的进度条&#xff0c;让大家都能直观地看到脚本运行最新的进展情况。1.普通进度条2.带时间进度条3.tpdm进度条4.progress进度条5.alive_progress进度条6.可视…

js 获取多少天前

getBeforeDate: function(day, str) { var now new Date().getTime(); //获取毫秒数 var before new Date(now - ((day > 0 && day ? day : 0) * 86400 * 1000)); var year before.getFullYear(); var month before.getMonth()1; var date before.getDate(); …

程序员的基本素质

给所有立志成为程序员的朋友 以及 自勉之&#xff01; 程序员基本素质&#xff1a; 作一个真正合格的程序员&#xff0c;或者说就是可以真正合格完成一些代码工作的程序员&#xff0c;应该具有的素质。 1&#xff1a;团队精神和协作能力 把它作为基本素质&#xff0c;并…

权限之浅理解

白马过隙&#xff0c;在感叹时光流逝的同时不得不承认在学习中随着知识面的不断扩展所接受的东西也越来越多&#xff0c;尤其是那些外形比较容易混淆的命令&#xff0c;着实让作为新手的吃了很多苦头&#xff0c;趁着学习紧张之时偷个懒整理这周易混淆的命令&#xff1a; chgrp…

机器学习算法之生成树

一、什么是决策树&#xff1f; 决策树&#xff08;Decision Tree&#xff09;是一种基本的分类和回归的方法。 分类决策树模型是一种描述对实例进行分类的树形结构。决策树由结点&#xff08;node&#xff09;和有向边&#xff08;directed edge&#xff09;组成。结点有两种…

强烈推荐给从事IT业的同行们 (转载)

作者&#xff1a;李学凌 文章来源&#xff1a;bbs.ustc.edu.cn 中国有很多小朋友&#xff0c;他们18,9岁或21,2岁&#xff0c;通过自学也写了不少代码&#xff0c;他们有的代码写的很漂亮&#xff0c;一些技术细节相当出众&#xff0c;也很有钻研精神&#xff0c;但是他…

微机原理控制转移类指令

1.无条件跳转指令 指令格式;JMP 目标地址 功能&#xff1a;JMP可以使程序无条件地跳转到程序存储器中某目标地址 注意点&#xff1a; 1&#xff09;指令目标地址若在JMP指令所在的代码段内&#xff0c;属段内跳转&#xff0c;指令只修改IP内容。指令目标地址若在JMP指令所在的代…

OPENNMS的后台并行管理任务

Concurrent management tasks: 1. . Action daemon - automated action (work flow)2. .数据采集Collection daemon - collects data3. .能力检查Capability daemon - capability check on nodes4. .动态主机配置协议DHCP daemon - DHCP clien…

机器学习算法之集成学习

集成学习的思想是将若干个学习器(分类器&回归器)组合之后产生一个新学习器。弱分类器(weak learner)指那些分类准确率只稍微好于随机猜测的分类器(errorrate < 0.5)。 集成算法的成功在于保证弱分类器的多样性(Diversity)。而且集成不稳定的算法也能够得到一个比较明显…

常用的方法论-NPS

转载于:https://www.cnblogs.com/qjm201000/p/7687510.html

controller调用controller的方法_SpringBoot 优雅停止服务的几种方法

转自&#xff1a;博客园&#xff0c;作者&#xff1a;黄青石www.cnblogs.com/huangqingshi/p/11370291.html 在使用 SpringBoot 的时候&#xff0c;都要涉及到服务的停止和启动&#xff0c;当我们停止服务的时候&#xff0c;很多时候大家都是kill -9 直接把程序进程杀掉&#x…

linux下安装Oracle10g时,安装rpm文件的技巧 (rpm -Uvh package名)

rpm -q package名 &#xff1a; 查询该package是否已经被安装了rpm -qa | grep package名 或是package 的关键字 &#xff1a; 查询该package是否已经被安装了rpm -Uvh package名 &#xff1a; 意思是update packagerpm -Uvh package名 --force &#xff1a; 意思是如果该…

机器学习之聚类概述

什么是聚类 聚类就是对大量未知标注的数据集&#xff0c;按照数据 内部存在的数据特征 将数据集划分为 多个不同的类别 &#xff0c;使 类别内的数据比较相似&#xff0c;类别之间的数据相似度比较小&#xff1b;属于 无监督学习。 聚类算法的重点是计算样本项之间的 相似度&…

程序员-建立你的商业意识 闫辉 著

1 程序员为什么需要商业意识 几 年前&#xff0c;当我刚刚认识Fishman的时候&#xff0c;听到他神奇的创业经历&#xff0c;觉得非常不可思议。甚至还专门写了一篇报道发到《电脑报》上&#xff0c;题目是《从程序员到 CEO》。不久&#xff0c;Fishman将创建的又一个新公司…

qt release打包发布_几种解决Qt程序打包后无法连接数据库问题的方法

Qt是一个跨平台C图形用户界面应用程序开发框架&#xff0c;使用它不仅可以方便地开发GUI程序&#xff0c;也可以开发非GUI程序&#xff0c;可以一次编写&#xff0c;处处编译。今天遇到的问题比较怪异&#xff0c;我开发的是一个桌面版订单管理系统&#xff0c;整体架构就是一个…