JAX: 快如 PyTorch,简单如 NumPy - 深度学习与数据科学

alt

JAX 是 TensorFlow 和 PyTorch 的新竞争对手。 JAX 强调简单性而不牺牲速度和可扩展性。由于 JAX 需要更少的样板代码,因此程序更短、更接近数学,因此更容易理解。

长话短说:

  • 使用 import jax.numpy 访问 NumPy 函数,使用 import jax.scipy 访问 SciPy 函数。
  • 通过使用 @jax.jit 进行装饰,可以加快即时编译速度。
  • 使用 jax.grad 求导。
  • 使用 jax.vmap 进行矢量化,并使用 jax.pmap 进行跨设备并行化。

函数式编程

alt

JAX 遵循函数式编程哲学。这意味着您的函数必须是独立的或纯粹的:不允许有副作用。本质上,纯函数看起来像数学函数(图 1)。有输入进来,有东西出来,但与外界没有沟通。

例子#1

以下代码片段是一个非功能纯的示例。

import jax.numpy as jnp

bias = jnp.array(0)
def impure_example(x):
   total = x + bias
   return total

注意 impure_example 之外的偏差。在编译期间(见下文),偏差可能会被缓存,因此不再反映偏差的变化。

例子#2

这是一个pure的例子。

def pure_example(x, weights, bias):
   activation = weights @ x + bias
   return activation

在这里,pure_example 是独立的:所有参数都作为参数传递。

确定性采样器

alt

在计算机中,不存在真正的随机性。相反,NumPy 和 TensorFlow 等库会跟踪伪随机数状态来生成“随机”样本。

函数式编程的直接后果是随机函数的工作方式不同。由于不再允许全局状态,因此每次采样随机数时都需要显式传入伪随机数生成器 (PRNG) 密钥

import jax

key = jax.random.PRNGKey(42)
u = jax.random.uniform(key)

此外,您有责任为任何后续调用推进“随机状态”。

key = jax.random.PRNGKey(43)

# Split off and consume subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)

# Split off and consume second subkey.
key, subkey = jax.random.split(key)
u = jax.random.uniform(subkey)

..

jit

您可以通过即时编译 JAX 指令来加快代码速度。例如,要编译缩放指数线性单位 (SELU) 函数,请使用 jax.numpy 中的 NumPy 函数并将 jax.jit 装饰器添加到该函数,如下所示:

from jax import jit

@jit
def selu(x, α=1.67, λ=1.05):
 return λ * jnp.where(x > 0, x, α * jnp.exp(x) - α)

JAX 会跟踪您的指令并将其转换为 jaxpr。这使得加速线性代数 (XLA) 编译器能够为您的加速器生成非常高效的优化代码。

gard

JAX 最强大的功能之一是您可以轻松获取 gard。使用 jax.grad,您可以定义一个新函数,即符号导数。

from jax import grad

def f(x):
   return x + 0.5 * x**2

df_dx = grad(f)
d2f_dx2 = grad(grad(f))

正如您在示例中看到的,您不仅限于一阶导数。您可以通过简单地按顺序链接 grad 函数 n 次来获取 n 阶导数。

vmap 和 pmap

矩阵乘法使所有批次尺寸正确需要非常细心。 JAX 的矢量化映射函数 vmap 通过对函数进行矢量化来减轻这种负担。基本上,每个按元素应用函数 f 的代码块都是由 vmap 替换的候选者。让我们看一个例子。

计算线性函数:

def linear(x):
 return weights @ x

在一批示例 [x₁, x2,..] 中,我们可以天真地(没有 vmap)实现它,如下所示:

def naively_batched_linear(X_batched):
 return jnp.stack([linear(x) for x in X_batched])

相反,通过使用 vmap 对线性进行向量化,我们可以一次性计算整个批次:

def vmap_batched_linear(X_batched):
 return vmap(linear)(X_batched)

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/582943.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

华为云CCE-集群内访问-根据ip访问同个pod

华为云CCE-集群内访问-根据ip访问同个pod 问题描述:架构如下:解决方法: 问题描述: 使用service集群内访问时,由于启用了两个pod,导致请求轮询在两个pod之间,无法返回正确的结果。 架构如下&am…

Javascript 可迭代对象与yeild

一、可迭代对象(Iterable object) Javascript 可迭代对象是指实现了Symbol.iterator方法的对象,该方法返回一个迭代器对象,可以通过迭代器对象来遍历对象中的元素。常见的可迭代对象包括数组、字符串、Map、Set等。可以使用for..…

Redis布隆过滤器BloomFilter

👏作者简介:大家好,我是爱吃芝士的土豆倪,24届校招生Java选手,很高兴认识大家📕系列专栏:Spring源码、JUC源码、Kafka原理、分布式技术原理、数据库技术🔥如果感觉博主的文章还不错的…

PS绘图,切图方法

在PS绘图并切图时,以下是一种常见的方法: 绘图 打开 Photoshop,并创建一个新的文件以开始您的绘图工作。使用绘图工具(例如画笔、铅笔、形状工具等)进行您的绘图操作。您可以使用不同的图层来组织和管理不同的元素。…

Kafka优异的性能是如何实现的?

Apache Kafka是一个分布式流处理平台,设计用来处理高吞吐量的数据。它被广泛用于构建实时数据管道和流式应用程序。Kafka之所以能够提供优秀的性能和高吞吐量,主要得益于以下几个方面的设计和实现: 1. 分布式系统设计 Kafka是一个分布式系统…

Unreal Engine游戏引擎的优势

在现在这个繁荣的游戏开发行业中,选择合适的游戏引擎是非常重要的。其中,Unreal Engine作为一款功能强大的游戏引擎,在业界广受赞誉。那Unreal Engine游戏引擎究竟有哪些优势,带大家简单的了解一下。 图形渲染技术 Unreal Engin…

遇见sql语句拼装报错 sql injection violation, syntax error: syntax error, expect RPAREN

在使用PostgreSql瀚高数据库时,相同的语句 select * from public.files_info fi where fi.file_size notnull 在DBever能执行,但是在spring中报错 在spring中JPA版本问题导致,不支持这种写法,会识别为sql注入风险,应…

硅像素传感器文献调研(四)

写在前面: 好喜欢这种短论文哈哈哈哈哈 感觉这篇文献已经提到了保护环的概念啊,只不过叫的是:场限制环。 1986——高压功率器件场终端横向掺杂的变化 0.摘要 对于高压平面结提出了一个简单的新概念。通过在氧化物掩模中的小开口和随后的驱…

python如何读取被压缩的图像

读取压缩的图像数据: PackBits 压缩介绍: CCITT T.3 压缩介绍: 读取压缩的图像数据: 在做图像处理的时候,平时都是使用 函数io.imread() 或者是 函数cv2.imread( ) 函数来读取图像数据,很少用PIL.Image…

Grafana Loki 组件介绍

Loki 日志系统由以下3个部分组成: Loki是主服务器,负责存储日志和处理查询。Promtail是专为loki定制的代理,负责收集日志并将其发送给 loki 。Grafana用于 UI展示。 Distributor Distributor 是客户端连接的组件,用于收集日志…

2022年全国职业院校技能大赛(高职组)“云计算”赛项赛卷①第一场次:私有云

2022年全国职业院校技能大赛(高职组) “云计算”赛项赛卷1 第一场次:私有云(30分) 目录 2022年全国职业院校技能大赛(高职组) “云计算”赛项赛卷1 第一场次:私有云&#xff0…

python获取异常信息exc_info和print_exc

1 python获取异常信息exc_info和print_exc python通过sys.exc_info获取异常信息,通过traceback.print_exc打印堆栈信息,包括错误类型和错误位置等信息。 1.1 异常不一定是错误 所有错误都是异常,但并非所有异常都是错误。比如,…

【Linux学习笔记】解析Linux系统内核:架构、功能、工作原理和发展趋势

操作系统是一个用来和硬件打交道并为用户程序提供一个有限服务集的低级支撑软件。一个计算机系统是一个硬件和软件的共生体,它们互相依赖,不可分割。计算机的硬件,含有外围设备、处理器、内存、硬盘和其他的电子设备组成计算机的发动机。但是…

磁盘相关知识

一、硬盘数据结构 1.扇区: 盘片被分为多个扇形区域,每个扇区存放512字节的数据(扇区越多容量越大) 存放数据的最小单位 512字节 (硬盘最小的存储单位是扇区,512 个字节,八个扇区组成一块&…

FPGA - 231227 - 5CSEMA5F31C6 - 电子万年历

TAG - F P G A 、 5 C S E M A 5 F 31 C 6 、电子万年历、 V e r i l o g FPGA、5CSEMA5F31C6、电子万年历、Verilog FPGA、5CSEMA5F31C6、电子万年历、Verilog 顶层模块 module TOP(input CLK,RST,inA,inB,inC,switch_alarm,output led,beep_led,output [41:0] dp );// 按键…

听GPT 讲Rust源代码--src/tools(29)

File: rust/src/tools/clippy/clippy_lints/src/unused_peekable.rs 在Rust源代码中,rust/src/tools/clippy/clippy_lints/src/unused_peekable.rs这个文件是Clippy工具中一个特定的Lint规则的实现文件,用于检测未使用的Peekable迭代器。 Peekable迭代器…

JavaScript中alert、confrim、prompt的使用及区别【精选】

Hi i,m JinXiang ⭐ 前言 ⭐ 本篇文章主要介绍JavaScript中alert、confrim、prompt的区别及使用以及部分理论知识 🍉欢迎点赞 👍 收藏 ⭐留言评论 📝私信必回哟😁 🍉博主收将持续更新学习记录获,友友们有任…

每日一题:LeetCode-LCR 179. 查找总价格为目标值的两个商品

每日一题系列(day 16) 前言: 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 🌈 &#x1f50e…

大厂前端面试题总结(百度、字节跳动、腾讯、小米.....),附上热乎面试经验!

先简单介绍下自己,我“平平无奇小天才”一枚,毕业于南方普通985普通学生,有幸去百度、字节面试,感觉大公司就是不一样,印象最深的是字节,所以有必要总结一下面试经验,以及面试中遇到的一些问题&…

八股文打卡day13——计算机网络(13)

面试题:DNS是什么?DNS的查询过程是什么? 我的回答: 我来讲一下我对DNS的理解 DNS是域名系统,它是一个域名和IP地址相互映射的数据库。通过DNS,可以将我们浏览器中输入的域名,例如:…