径向基神经网络_谷歌开源Neural Tangents:5行代码打造无限宽神经网络模型,帮助“打开ML黑匣子”...

116c11aedf3cf04fb9760180d16f842e.png
鱼羊 假装发自 凹非寺
量子位 报道 | 公众号 QbitAI

只要网络足够宽,深度学习动态就能大大简化,并且更易于理解。

最近的许多研究结果表明,无限宽度的DNN会收敛成一类更为简单的模型,称为高斯过程(Gaussian processes)。

于是,复杂的现象可以被归结为简单的线性代数方程,以了解AI到底是怎样工作的。

acf481bf243d7a098b679f88c22da89c.gif

所谓的无限宽度(infinite width),指的是完全连接层中的隐藏单元数,或卷积层中的通道数量有无穷多。

但是,问题来了:推导有限网络的无限宽度限制需要大量的数学知识,并且必须针对不同研究的体系结构分别进行计算。对工程技术水平的要求也很高。

谷歌最新开源的 Neural Tangents,旨在解决这个问题,让研究人员能够轻松建立、训练无限宽神经网络。

甚至只需要5行代码,就能够打造一个无限宽神经网络模型。

这一研究成果已经中了ICLR 2020。戳进文末Colab链接,即可在线试玩。

开箱即用,5行代码打造无限宽神经网络模型

Neural Tangents 是一个高级神经网络 API,可用于指定复杂、分层的神经网络,在 CPU/GPU/TPU 上开箱即用。

该库用 JAX编写,既可以构建有限宽度神经网络,亦可轻松创建和训练无限宽度神经网络。

有什么用呢?举个例子,你需要训练一个完全连接神经网络。通常,神经网络是随机初始化的,然后采用梯度下降进行训练。

研究人员通过对一组神经网络中不同成员的预测取均值,来提升模型的性能。另外,每个成员预测中的方差可以用来估计不确定性。

如此一来,就需要大量的计算预算。

但当神经网络变得无限宽时,网络集合就可以用高斯过程来描述,其均值和方差可以在整个训练过程中进行计算。

而使用 Neural Tangents ,仅需5行代码,就能完成对无限宽网络集合的构造和训练。

from neural_tangents import predict, staxinit_fn, apply_fn, kernel_fn = stax.serial(stax.Dense(2048, W_std=1.5, b_std=0.05), stax.Erf(),stax.Dense(2048, W_std=1.5, b_std=0.05), stax.Erf(),stax.Dense(1, W_std=1.5, b_std=0.05))y_mean, y_var = predict.gp_inference(kernel_fn, x_train, y_train, x_test, ‘ntk’, diag_reg=1e-4, compute_cov=True)

840b93547f724650d98594c73c80045e.gif

上图中,左图为训练过程中输出(f)随输入数据(x)的变化;右图为训练过程中的不确定性训练、测试损失。

将有限神经网络的集合训练和相同体系结构的无限宽度神经网络集合进行比较,研究人员发现,使用无限宽模型的精确推理,与使用梯度下降训练整体模型的结果之间,具有良好的一致性。

这说明了无限宽神经网络捕捉训练动态的能力。

不仅如此,常规神经网络可以解决的问题,Neural Tangents 构建的网络亦不在话下。

研究人员在 CIFAR-10 数据集的图像识别任务上比较了 3 种不同架构的无限宽神经网络。

d4d169d17b824206a1e9f3879a52f034.png

可以看到,无限宽网络模拟有限神经网络,遵循相似的性能层次结构,其全连接网络的性能比卷积网络差,而卷积网络的性能又比宽残余网络差。

但是,与常规训练不同,这些模型的学习动力在封闭形式下是易于控制的,也就是说,可以用前所未有的视角去观察其行为。

对于深入理解机器学习机制来说,该研究也提供了一种新思路。谷歌表示,这将有助于“打开机器学习的黑匣子”。

传送门

论文地址:https://arxiv.org/abs/1912.02803

谷歌博客:https://ai.googleblog.com/2020/03/fast-and-easy-infinitely-wide-networks.html

GitHub地址:https://github.com/google/neural-tangents

Colab地址:https://colab.research.google.com/github/google/neural-tangents/blob/master/notebooks/neural_tangents_cookbook.ipynb

—完—

@量子位 · 追踪AI技术和产品新动态

深有感触的朋友,欢迎赞同、关注、分享三连վ'ᴗ' ի ❤

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/567725.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python eval简介

eval函数的简介和语法 eval()函数用来执行一个字符串表达式,并返回表达式的值。还可以把字符串转化为list、tuple、dict。 eval函数的语法: eval(expression[,globals[,locals]])参数: expression:表达式。 globals&#xff1…

佳能g2800清零软件天空_可能是史上最有趣的3D建模软件

今天咱们要讲的不是一款BIM软件,而是一款有趣的3D体素建模软件。体素英文名叫Voxel,是把像素风格中的小方块引申到三维空间里,让图像呈现一小块一小块的鲜明风格。比如《我的世界》和最近非常火的《纪念碑谷2》就是这样的风格。这款软件叫做M…

Python 打开文件注意事项

利用try except语句捕获打开文件异常 filename"student.txt"#利用try except语句捕获打开文件异常 try:student_txtopen(filename,a) #以追加模式打开文件 except Exception as e:student_txtopen(filename,w) #文件不存在,创建文件并打开#打开文件 i…

python删除数组元素_python:从数组列表中删除一系列数字

我在从数组列表中删除范围A到B的元素时遇到问题。我在网上搜索的解决方案似乎只适用于单个元素、相邻元素和或整数元素。我在处理浮点数。 1 2 3 4 5 6 7self.genx np.arange(0, 5, 0.1) temp_select self.genx[1:3] #I want to remove numbers from 1 - 3 from genx print(t…

python 列表中dict中key排序

#1列表排序:使用lamada表达式进行排序 student_new[{id:1,name:无语1,english:100,python:98},{id:2,name:无语2,english:87,python:96},{id: 3, name: 无语3, english: 95, python: 100}]student_new.sort(keylambda x:x[english],reverseTrue) print(student_new)…

springcloud feign 服务调用其他服务_微服务实战——SpringCloud与Feign集成

上一篇集成了ZuulGateway和Eureka并进行了测试。在实际场景中,我们肯定会有很多的微服务,而他们之间可能会存在相互调用的关系,那么,如何优雅的处理服务之间的调用问题呢?接下来就是我们要解决的。简单的说下FeignFeig…

Python部分知识点

1format方法 format中 数字表示所占宽度 符号^表示居中显示 \t表示添加制表符 format_title"{:^4}{:^12}\t{:^8}\t{:^10}\t{:^10}" print(format_title.format("ID","名字","英语成绩","Python成绩","C语言成绩"…

python数字形式转换_在Python中将字母转换为数字

在Python中将字母转换为数字 如何完成以下步骤? characters [abcdefghijklmnopqrtuvwxyz] numbers [123456789101112131415161718192021222324] text raw_input( Write text: ) 我已经尝试了许多方法来解决它,但无法做到。 我想做事。 如果我键入“ h…

Python中赋值,深拷贝和浅拷贝

1python变量 变量的存储,采用了引用语义的方式,存储的只是一个变量的值所在的内存地址,而不是这个变量的值本身。 2赋值 python变量赋值实际上是对象的引用。 如: list_a [1,2,3,"hello",["python",&qu…

excel进度条与百分比不符_Excel项目管理模板V2.0

Excel表哥公众号推送的第一篇文章 如何用Excel制作一个高逼格的项目管理模板 累积获得了超多的下载量。下面是和读者朋友的一些交流互动:在使用过程中大家陆续也反馈了一些问题和建议。因此我们推出了项目管理模板V2.0 升级版!算作是给读者朋友们的一个答…

python 格式化输出%和format

1 %用法 1.1整数的输出 %o —— oct 八进制 %d —— dec 十进制 %x —— hex 十六进制 print(%o % 20) #24 print(%d % 20) #20 print(%x % 20) #141.2浮点数输出 %f ——默认保留小数点后面六位有效数字   %.3f,保留3位小数位 %e ——默认保留小数点后面六…

linux系统中安装python_2. Linux 下安装python

Linux 各个版本的系统都自带python解释器,可以在shell界面输入 python 就能进入交互界面,并显示python版本信息; 现在最流行的版本是python2.7,Centos6 默认安装2.6.6的版本,Centos7 默认安装2.7.* 的版本。 如果Cento…

零基础学习Java-素数和

在编写素数和程序中,发现了以下的问题: 在编程的过程中: 关于比较范围的不牢固各数据需要给初始化的值对于使用isPrime来辅助程序运行的遗忘 在程序运行的过程中: 除数不能为0 出现Exception in thread “main” java.lang.Ar…

访问修饰符作用范围由大到小是_9个java基础小知识

一、面向对象和面向过程的区别1. 面向过程 : 面向过程性能比面向对象高。因为类调用时需要实例化,开销比较大,比较消耗资源,所以当性能是最重要的考虑因素时(例如单片机、嵌入式开发、Linux/Unix等一般采用面向过程开发…

System.out.println(i++); System.out.println(++i);的区别

之前一直对i和i很模糊,这次通过两个小demo来探究下。 例1: 1 public static void main(String[] args) { 2 int i2; 3 System.out.println(i); 4 System.out.println(i); 5 }run: 2 3 例2: 1 public static void…

python如何导入函数_Python导入(import)模块的方法

1、导入整个模块:模块 是扩展名为.py的文件,包含要导入到程序中的代码。import module_name 2、导入特定的函数from module_name import function_name 也可以导入多个from module_name import function_0, function_1, function_2 3、使用as 给函数指定…

白盒测试用例设计方法(语句覆盖、判定覆盖、条件覆盖、判定/条件覆盖、组合覆盖、路径覆盖、基本路径覆盖)

语句覆盖:每条语句至少执行一次。 判定覆盖:每个判定的所有可能结果至少出现一次。(又称“分支覆盖”) 条件覆盖:每个条件的所有可能结果至少执行一次。 判定/条件覆盖:一个判定中的每个条件的所有可能结果…

Python之列表和元组

01 序列 成员有序排列的,且可以通过下标偏移量访问到它的一个或者几个成员,这类类型统称为序列 序列数据类型包括:字符串,列表,和元组类型。 特点: 都支持下面的特性 索引与切片操作符 成员关系操作符(in , not in) 连接操作符() & 重复操作符(*) 0…

负数、取模与取余

总结: ‘%’ 在C/C,Java等语言中意为 取余 ,在python 中意为 取模 取余(rem)和取模(mod)在被除数、除数同号时,结果是等同的,异号时会有区别,所以要特别注意…

webgl编程指南源码_ThreeJS 源码剖析之 Renderer(一)

引子?最近,忽然想起曾在 WebGL 基础系列 文章中立下 flag:“后续还打算出 《ThreeJS 源码剖析》 系列”(特意翻出原话?),项目忙了一阵后,便决定开始写此系列,更新周期不固定,毕竟项目排期“天晓得”。此系…