使用 PyTorch 数据读取,JAX 框架来训练一个简单的神经网络

使用 PyTorch 数据读取,JAX 框架来训练一个简单的神经网络

本文例程部分主要参考官方文档。

JAX简介

JAX 的前身是 Autograd ,也就是说 JAX 是 Autograd 升级版本,JAX 可以对 Python 和 NumPy 程序进行自动微分。可以通过 Python的大量特征子集进行区分,包括循环、分支、递归和闭包语句进行自动求导,也可以求三阶导数(三阶导数是由原函数导数的导数的导数。 所谓三阶导数,即原函数导数的导数的导数,将原函数进行三次求导)。通过 grad ,JAX 支持反向模式和正向模式的求导,而且这两种模式可以任意组合成任何顺序,具有一定灵活性。

另一个特点是基于 XLA 的 JIT 即时编译,大大提高速度。

需要注意的是,JAX 仅提供计算时的优化,相当于是一个支持自动微分和 JIT 编译的 NumPy。也就是说,数据处理 Dataloader 等其他框架都会提供的 utils 功能这里是没有的。所幸 JAX 可以比较好的支持 PyTorch、 TensorFlow 等主流框架的数据读取。本文就将基于 PyTorch 的数据读取工具和 JAX 框架来训练一个简单的神经网络

以下是国内优秀的机器学习框架 OneFlow 同名公司的创始人袁进辉老师在知乎上的一个评价:

如果说tensorflow 是主打lazy, 偏functional 的思想,但实现的臃肿面目可憎;pytorch 则主打eager, 偏imperative 编程,但内核简单,可视为支持gpu的numpy, 加上一个autograd。JAX 像是这俩框架的混合体,取了tensorflow的functional和PyTorch的精简,即支持gpu的 numpy, 具有autograd功能,非常追求函数式编程的思想,强调无状态,immutable,加上JIT修饰符后就是lazy,可以使用xla对计算流程进行静态分析和优化。当然JAX不带jit也可像pytorch那种命令式编程和eager执行。

JAX有可能和PyTorch竞争。

安装

安装可以通过源码编译,也可以直接 pip。源码编译详见[官方文档: Building from source][2],对于官方没有提供预编译包的 cuda-cudnn 版本组合,只能通过自己源码构建。pip的方式比较简单,在 github 仓库的 README 文档中就有介绍。要注意,不同于 PyTorch 等框架,JAX 不会再 pip 安装中绑定 CUDA 或 cuDNN 进行安装,若未安装,需要自己先手动安装。仅使用 CPU 的版本也有支持。

笔者是 CUDA11.1,CUDNN 8.2,安装如下:

pip install --upgrade pip
pip install jax[cuda11_cudnn82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

前面已经提到过,本文会借用 PyTorch 的数据处理工具,因此 torch 和 torchvision 也是必不可少的(已经安装的可跳过):

pip install torch torchvision

构建简单的神经网络训练

框架安装完毕,我们正式开始。接下来我们使用 JAX 在 MNIST 上指定和训练一个简单的 MLP 进行计算,用 PyTorch 的数据加载 API 来加载图像和标签。

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

超参数

# 本函数用来随机初始化网络权重
def random_layer_params(m, n, key, scale=1e-2):w_key, b_key = random.split(key)return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n, ))# 初始化各个全连接层
def init_network_params(sizes, key):keys = random.split(key, len(sizes))return [random_layer_params(m, n, k) for m, n, k in zip(sizes[: -1], sizes[1: ], keys)]layer_sizes = [784, 512, 512, 10]
step_size = 0.01
num_epochs = 8
batch_size = 128
n_targets = 10
params = init_network_params(layer_sizes, random.PRNGKey(0))

自动分批次预测

对于小批量,我们稍后将使用 JAX 的 vmap 函数来自动处理,而不会降低性能。我们现在先准备一个单张图像推理预测函数:

from jax.scipy.special import logsumexpdef relu(x):return jnp.maximum(0, x)# 对单张图像进行推理的函数
def predict(params, image):activations = imagefor w, b in params[: -1]:outputs = jnp.dot(w, activations) + bactivations = relu(outputs)final_w, final_b = params[-1]logits = jnp.dot(final_w, activations) + final_breturn logits - logsumexp(logits)

这个函数应该只能用来处理单张图像推理预测,而不能批量处理,我们简单测试一下,对于单张:

random_flattened_images = random.normal(random.PRNGKey(1), (28 * 28,))
preds = predict(params, random_flattened_images)
print(preds.shape)

输出:

(10,)

对于批次:

random_flattened_images = random.normal(random.PRNGKey(1), (10, 28 * 28))
try:preds = predict(params, random_flattened_images)
except TypeError:print('Invalid shapes!')

输出:

Invalid shapes!

现在我们使用 vmap 来使它能够处理批量数据:

# 用 vmap 来实现一个批量版本
batched_predict = vmap(predict, in_axes=(None, 0))# batched_predict 的调用与 predict 相同
batched_preds = batched_predict(params, random_flattened_images)
print(batched_preds.shape)

输出:

(10, 10)

现在,我们已经做好了准备工作,接下来就是要定义一个神经网络并且进行训练了,我们已经构建了的自动批处理版本的 predict 函数,并且将在损失函数中也使用它。我们将使用 grad 来得到损失关于神经网络参数的导数。而且,这一切都可以用 jit 进行加速。

实用工具函数和损失函数

def one_hot(x, k, dtype=jnp.float32):"""构建一个 x 的 k 维 one-hot 编码."""return jnp.array(x[:, None] == jnp.arange(k), dtype)def accuracy(params, images, targets):target_class = jnp.argmax(targets, axis=1)predicted_class =  jnp.argmax(batched_predict(params, images), axis=1)return jnp.mean(predicted_class == target_class)def loss(params, images, targets):preds = batched_predict(params, images)return -jnp.mean(preds * targets)@jit
def update(params, x, y):grads = grad(loss)(params, x, y)return [(w - step_size * dw, b - step_size * db) for (w, b), (dw, db) in zip(params, grads)]

使用 PyTorch 进行数据读取

JAX 是一个专注于程序转换和支持加速的 NumPy,对于数据的读取,已经有很多优秀的工具了,这里我们就直接用 PyTorch 的 API。我们会做一个小的 shim 来使得它能够支持 NumPy 数组。

import numpy as np
from torch.utils import data
from torchvision.datasets import MNISTdef numpy_collate(batch):if isinstance(batch[0], np.ndarray):return np.stack(batch)elif isinstance(batch[0], (tuple, list)):transposed = zip(*batch)return [numpy_collate(samples) for samples in transposed]else:return np.array(batch)class NumpyLoader(data.DataLoader):def __init__(self, dataset, batch_size=1,shuffle=False, sampler=None,batch_sampler=None, num_workers=0,pin_memory=False, drop_last=False,timeout=0, worker_init_fn=None):super(self.__class__, self).__init__(dataset,batch_size=batch_size,shuffle=shuffle,sampler=sampler,batch_sampler=batch_sampler,collate_fn=numpy_collate,num_workers=num_workers,pin_memory=pin_memory,drop_last=drop_last,timeout=timeout,worker_init_fn=worker_init_fn)class FlattenAndCast(object):def __call__(self, pic):return np.ravel(np.array(pic, dtype=jnp.float32))

接下来借助 PyTorch 的 datasets,定义我们自己的 dataset:

mnist_dataset = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
training_generator = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)

此处应该输出一堆下载 MNIST 数据集的信息,就不贴了。

接下来分别拿到整个训练集和整个测试集,下面会用于测准确率:

train_images = np.array(mnist_dataset.train_data).reshape(len(mnist_dataset.train_data), -1)
train_labels = one_hot(np.array(mnist_dataset.train_labels), n_targets)mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False)
test_images = jnp.array(mnist_dataset_test.test_data.numpy().reshape(len(mnist_dataset_test.test_data), -1), dtype=jnp.float32)
test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)

开始训练

import time
for epoch in range(num_epochs):start_time = time.time()for x, y in training_generator:y = one_hot(y, n_targets)params = update(params, x, y)epoch_time = time.time() - start_timetrain_acc = accuracy(params, train_images, train_labels)test_acc = accuracy(params, test_images, test_labels)print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))print("Training set accuracy {}".format(train_acc))print("Test set accuracy {}".format(test_acc))

输出:

Epoch 0 in 3.29 sec
Training set accuracy 0.9156666994094849
Test set accuracy 0.9196999669075012
...
Epoch 7 in 1.78 sec
Training set accuracy 0.9736666679382324
Test set accuracy 0.9670999646186829

在本文的过程中,我们已经使用了整个 JAX API:grad 用于自动微分、jit 用于加速、vmap 用于自动矢量化。我们使用 NumPy 来进行我们所有的计算,并从 PyTorch 借用了出色的数据加载器,并在 GPU 上运行了整个过程。

Ref:

https://juejin.cn/post/6994695537316331556

https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html

https://jax.readthedocs.io/en/latest/developer.html#building-from-source

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532574.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Yapi Mock 远程代码执行漏洞

跟风一波复现Yapi 漏洞描述: YApi接口管理平台远程代码执行0day漏洞,攻击者可通过平台注册用户添加接口,设置mock脚本从而执行任意代码。鉴于该漏洞目前处于0day漏洞利用状态,强烈建议客户尽快采取缓解措施以避免受此漏洞影响 …

C++ ACM模式输入输出

C ACM模式输入输出 以下我们都以求和作为题目要求,来看一下各种输入输出应该怎么写。 1 只有一个或几个输入 输入样例: 3 5 7输入输出模板: int main() {int a, b, c;// 接收有限个输入cin >> a >> b >> c;// 输出结果…

CVE-2017-10271 WebLogic XMLDecoder反序列化漏洞

漏洞产生原因: CVE-2017-10271漏洞产生的原因大致是Weblogic的WLS Security组件对外提供webservice服务,其中使用了XMLDecoder来解析用户传入的XML数据,在解析的过程中出现反序列化漏洞,导致可执行任意命令。攻击者发送精心构造的…

树莓派摄像头 C++ OpenCV YoloV3 实现实时目标检测

树莓派摄像头 C OpenCV YoloV3 实现实时目标检测 本文将实现树莓派摄像头 C OpenCV YoloV3 实现实时目标检测,我们会先实现树莓派对视频文件的逐帧检测来验证算法流程,成功后,再接入摄像头进行实时目标检测。 先声明一下笔者的主要软硬件配…

【实战】记录一次服务器挖矿病毒处理

信息收集及kill: 查看监控显示长期CPU利用率超高,怀疑中了病毒 top 命令查看进程资源占用: netstat -lntupa 命令查看有无ip进行发包 netstat -antp 然而并没有找到对应的进程名 查看java进程和solr进程 ps aux :查看所有进程…

ag 搜索工具参数详解

ag 搜索工具参数详解 Ag 是类似ack, grep的工具,它来在文件中搜索相应关键字。 官方列出了几点选择它的理由: 它比ack还要快 (和grep不在一个数量级上)它会忽略.gitignore和.hgignore中的匹配文件如果有你想忽略的文…

CVE-2013-4547 文件名逻辑漏洞

搭建环境,访问 8080 端口 漏洞说明: Nginx: Nginx是一款轻量级的Web 服务器/反向代理服务器及电子邮件(IMAP/POP3)代理服务器,在BSD-like 协议下发行。其特点是占有内存少,并发能力强&#xf…

CMake指令入门 ——以构建OpenCV项目为例

CMake指令入门 ——以构建OpenCV项目为例 转自:https://blog.csdn.net/sandalphon4869/article/details/100589747 一、安装 sudo apt-get install cmake安装好后,输入 cmake -version如果出现了cmake的版本显示,那么说明安装成功 二、c…

CVE-2017-7529Nginx越界读取缓存漏洞POC

漏洞影响 低危,造成信息泄露,暴露真实ip等 实验内容 漏洞原理 通过查看patch确定问题是由于对http header中range域处理不当造成,焦点在ngx_http_range_parse 函数中的循环: HTTP头部range域的内容大约为Range: bytes4096-81…

Linux命令行性能监控工具大全

Linux命令行性能监控工具大全 作者:Arnold Lu 原文:https://www.cnblogs.com/arnoldlu/p/9462221.html 关键词:top、perf、sar、ksar、mpstat、uptime、vmstat、pidstat、time、cpustat、munin、htop、glances、atop、nmon、pcp-gui、collect…

Weblogic12c T3 协议安全漏洞分析【CVE-2020-14645 CVE-2020-2883 CVE-2020-14645】

给个关注?宝儿! 给个关注?宝儿! 给个关注?宝儿! 关注公众号:b1gpig信息安全,文章推送不错过 ## 前言 WebLogic是美国Oracle公司出品的一个application server,确切的说是一个基于JAV…

Getshell总结

按方式分类: 0x01注入getshell: 0x02 上传 getwebshell 0x03 RCE getshell 0x04 包含getwebshell 0x05 漏洞组合拳getshell 0x06 系统层getcmdshell 0x07 钓鱼 getcmdshell 0x08 cms后台getshell 0x09 红队shell竞争分析 0x01注入getshell:…

编写可靠bash脚本的一些技巧

编写可靠bash脚本的一些技巧 原作者:腾讯技术工程 原文链接:https://zhuanlan.zhihu.com/p/123989641 写过很多 bash 脚本的人都知道,bash 的坑不是一般的多。 其实 bash 本身并不是一个很严谨的语言,但是很多时候也不得不用。以下…

python 到 poc

0x01 特殊函数 0x02 模块 0x03 小工具开发记录 特殊函数 # -*- coding:utf-8 -*- #内容见POC.demo; POC.demo2 ;def add(x,y):axyprint(a)add(3,5) print(------------引入lambad版本:) add lambda x,y : xy print(add(3,5)) #lambda函数,在lambda函数后面直接…

protobuf版本常见问题

protobuf版本常见问题 许多软件都依赖 google 的 protobuf,我们很有可能在安装多个软件时重复安装了多个版本的 protobuf,它们之间很可能出现冲突并导致在后续的工作中出现版本不匹配之类的错误。本文将讨论笔者在使用 protobuf 中遇到的一些问题&#…

CMake常用命令整理

CMake常用命令整理 转自:https://zhuanlan.zhihu.com/p/315768216 CMake 是什么我就不用再多说什么了,相信大家都有接触才会看一篇文章。对于不太熟悉的开发人员可以把这篇文章当个查找手册。 1.CMake语法 1.1 指定cmake的最小版本 cmake_minimum_r…

CVE-2021-41773 CVE-2021-42013 Apache HTTPd最新RCE漏洞复现 目录穿越漏洞

给个关注?宝儿! 给个关注?宝儿! 给个关注?宝儿! CVE-2021-41773漏洞描述: Apache HTTPd是Apache基金会开源的一款流行的HTTP服务器。2021年10月8日Apache HTTPd官方发布安全更新,披…

SSRF,以weblogic为案例

给个关注?宝儿! 给个关注?宝儿! 给个关注?宝儿! 复习一下ssrf的原理及危害,并且以weblog的ssrf漏洞为案例 漏洞原理 SSRF(Server-side Request Forge, 服务端请求伪造) 通常用于控制web进而…

C++11 右值引用、移动语义、完美转发、万能引用

C11 右值引用、移动语义、完美转发、引用折叠、万能引用 转自:http://c.biancheng.net/ C中的左值和右值 右值引用可以从字面意思上理解,指的是以引用传递(而非值传递)的方式使用 C 右值。关于 C 引用,已经在《C引用…

C++11 std::function, std::bind, std::ref, std::cref

C11 std::function, std::bind, std::ref, std::cref 转自&#xff1a;http://www.jellythink.com/ std::function 看看这段代码 先来看看下面这两行代码&#xff1a; std::function<void(EventKeyboard::KeyCode, Event*)> onKeyPressed; std::function<void(Ev…