动手实现一个带自动微分的深度学习框架

动手实现一个带自动微分的深度学习框架

转自:Automatic Differentiation Tutorial

参考代码:https://github.com/borgwang/tinynn-autograd (主要看 core/tensor.py 和 core/ops.py)

在这里插入图片描述

目录

  • 简介
  • 自动求导设计
  • 自动求导实现
  • 一个例子
  • 总结
  • 参考资料

简介

梯度下降(Gradient Descent)及其衍生算法是神经网络训练的基础,梯度下降本质上就是求解损失关于网络参数的梯度,不断计算这个梯度对网络参数进行更新。现代的神经网络框架都实现了自动求导的功能,只需要要定义好网络前向计算的逻辑,在运算时自动求导模块就会自动把梯度算好,不用自己手写求导梯度。

笔者在之前的 一篇文章 中讲解和实现了一个迷你的神经网络框架 tinynn,在 tinynn 中我们定义了网络层 layer 的概念,整个网络是由一层层的 layer 叠起来的(全连接层、卷积层、激活函数层、Pooling 层等等),如下图所示

在这里插入图片描述

在实现的时候需要显示为每层定义好前向 forward 和反向 backward(梯度计算)的计算逻辑。从本质上看 这些 layer 其实是一组基础算子的组合,而这些基础算子(加减乘除、矩阵变换等等)的导函数本身都比较简单,如果能够将这些基础算子的导函数写好,同时把不同算子之间连接逻辑记录(计算依赖图)下来,那么这个时候就不再需要自己写反向了,只需要计算损失,然后从损失函数开始,让梯度自己用预先定义好的导函数,沿着计算图反向流动即可以得到参数的梯度,这个就是自动求导的核心思想。tinynn 中之所有 layer 这个概念,一方面是符合我们直觉上的理解,另一方面是为了在没有自动求导的情况下方便实现。有了自动求导,我们可以抛开 layer 这个概念,神经网络的训练可以抽象为定义好一个网络的计算图,然后让数据前向流动,让梯度自动反向流动( TensorFlow 这个名字起得相当有水准)。

我们可以看看 PyTorch 的一小段核心的训练代码(来源官方文档 MNIST 例子)

for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()  # 初始化梯度output = model(data)  # 从 data 到 output 的计算图loss = F.nll_loss(output, target) # 从 output 到 loss 的计算图loss.backward()  # 梯度从 loss 开始反向流动optimizer.step()  # 使用梯度对参数更新

可以看到 PyTorch 的基本思路和我们上面描述的是一致的,定义好计算图 -> forward 得到损失 -> 梯度反向流动。

自动求导设计

知道了自动求导的基本流程之后,我们考虑如何来实现。先考虑没有自动求导,为每个运算手动写 backward 的情况,在这种情况下我们实际上定义了两个计算图,一个前向一个反向,考虑最简单的线性回归的运算 WX+BWX+BWX+B,其计如下所示。

在这里插入图片描述

可以看到这两个计算图的结构实际上是一样的,只是在前向流动的是计算的中间结果,反向流动的是梯度,以及中间的运算反向的时候是导数运算。实际上我们可以把两者结合到一起,只定义一次前向计算图,让反向计算图自动生成

在这里插入图片描述

从实现的角度看,如果我们不需要自动求导,那么网络框架中的 Tensor 类只需要对 Tensor 运算符有定义,能够进行数值运算(tinynn 中就简单的使用 ndarray 作为 Tensor 的实现)。但如果要实现自动求导,那么 Tensor 类需要额外做几件事:

  1. 增加一个梯度的变量保存当前 tensor 的梯度
  2. 保存当前 tensor 依赖的 tensor(如上图中 O1O1 依赖于 X,WX,W)
  3. 保存下对各个依赖 tensor 的导函数(这个导函数的作用是将当前 tensor 的梯度传到依赖的 tensor 上)

自动求导实现

我们按照上面的分析开始实现 Tensor 类如下,初始化方法中首先把 tensor 的值保存下来,然后有一个 requires_grad 的 bool 变量表明这个 tensor 是不是需要求梯度,还有一个 dependency 的列表用于保存该 tensor 依赖的 tensor 以及对于他们的导函数。

zero_grad() 方法比较简单,将当前 tensor 的梯度设置为 0,防止梯度的累加。自动求导从调用计算图的最后一个节点 tensor 的 backward() 方法开始(在神经网络中这个节点一般是 loss)。backward() 方法主要流程为

  • 确保改 tensor 确实需要求导 self.requires_grad == True
  • 将从上个 tensor 传进来的梯度加到自身梯度上,如果没有(反向求导的起点 tensor),则将梯度初始化为 1.0
  • 对每一个依赖的 tensor 运行保存下来的导函数,计算传播到依赖 tensor 的梯度,然后调用依赖 tensor 的 backward() 方法。可以看到这其实就是 Depth-First Search 计算图的节点
def as_tensor(obj):if not isinstance(obj, Tensor):obj = Tensor(obj)return objclass Tensor:def __init__(self, values, requires_grad=False, dependency=None):self._values = np.array(values)self.shape = self.values.shapeself.grad = Noneif requires_grad:self.zero_grad()self.requires_grad = requires_gradif dependency is None:dependency = []self.dependency = dependency@propertydef values(self):return self._values@values.setterdef values(self, new_values):self._values = np.array(new_values)self.grad = Nonedef zero_grad(self):self.grad = np.zeros(self.shape)def backward(self, grad=None):assert self.requires_grad, "Call backward() on a non-requires-grad tensor."grad = 1.0 if grad is None else gradgrad = np.array(grad)# accumulate gradientself.grad += grad# propagate the gradient to its dependenciesfor dep in self.dependency:grad_for_dep = dep["grad_fn"](grad)dep["tensor"].backward(grad_for_dep)

可能看到这里读者可能会疑问,一个 tensor 依赖的 tensor 和对他们的导函数(也就是 dependency 里面的东西)从哪里来?似乎没有哪一个方法在做保存依赖这件事。

假设我们可能会这样使用我们的 Tensor 类

W = Tensor([[1], [3]], requires_grad=True)  # 2x1 tensor
X = Tensor([[1, 2], [3, 4], [5, 6], [7, 8]], requires_grad=True)  # 4x2 tensor
O = X @ W  # suppose to be a 4x1 tensor

如何让 XW 完成矩阵乘法输出正确的 O 的同时,让 O 能记下他依赖于 WX 呢?答案是重载运算符

class Tensor:# ...def __matmul__(self, other):# 1. calculate forward valuesvalues = self.values @ other.values# 2. if output tensor requires_gradrequires_grad = ts1.requires_grad or ts2.requires_grad# 3. build dependency listdependency = []if self.requires_grad:# O = X @ W# D_O / D_X = grad @ W.Tdef grad_fn1(grad):return grad @ other.values.Tdependency.append(dict(tensor=self, grad_fn=grad_fn1))if other.requires_grad:# O = X @ W# D_O / D_W = X.T @ graddef grad_fn2(grad):return self.values.T @ graddependency.append(dict(tensor=other, grad_fn=grad_fn2))return Tensor(values, requires_grad, dependency)# ...

关于 Python 中如何重载运算符这里不展开,读者有兴趣可以参考官方文档或者这篇文章。基本上在 Tensor 类内定义了 __matmul__ 这个方法后,实际上是重载了矩阵乘法运算符 @ (Python 3.5 以上支持) 。当运行 X @ W 时会自动调用 X__matmul__ 方法。

这个方法里面做了三件事:

  1. 计算矩阵乘法结果(这个是必须的)

  2. 确定是否需要新生成的 tensor 是否需要梯度,这个由两个操作数决定。比如在这个例子中,如果 W 或者 X 需要梯度,那么生成的 O也是需要计算梯度的(这样才能够计算 W 或者 X 的梯度)

  3. 建立 tensor 的依赖列表

    自动求导中最关键的部分就是在这里了,还是以 O = X @ W 为例子,这里我们会先检查是否 X需要计算梯度,如果需要,我们需要把导函数 D_O / D_X 定义好,保存下来;同样的如果 W 需要梯度,我们将 D_O / D_W 定义好保存下来。最后生成一个 dependency 列表保存着在新生成的 tensor O 中。

然后我们再回顾前面讲的 backward()方法,backward() 方法会遍历 tensor 的 dependency ,将用保存的 grad_fn 计算要传给依赖 tensor 的梯度,然后调用依赖 tensor 的 backward() 方法将梯度传递下去,从而实现了梯度在整个计算图的流动。

grad_for_dep = dep["grad_fn"](grad)
dep["tensor"].backward(grad_for_dep)

自动求导讲到这里其实已经基本没有什么新东西,剩下的工作就是以类似的方法大量地重载各种各样的运算符,使其能够 cover 住大部分所需要的操作(基本上照着 NumPy 的接口都给重载一次就差不多了 🤨)。无论你定义了多复杂的运算,只要重载了相关的运算符,就都能够自动求导了,再也不用自己写梯度了。

在这里插入图片描述

一个例子

大量的重载运算符的工作在文章里就不贴上来了(过程不怎么有趣),我写在了一个 notebook 上,大家有兴趣可以去看看 borgwang/toys/ml-autograd。在这个 notebook 里面重载了实现一个简单的线性回归需要的几种运算符,以及一个线性回归的例子。这里把例子和结果贴上来

# training data
x = Tensor(np.random.normal(0, 1.0, (100, 3)))
coef = Tensor(np.random.randint(0, 10, (3,)))
y = x * coef - 3params = {"w": Tensor(np.random.normal(0, 1.0, (3, 3)), requires_grad=True),"b": Tensor(np.random.normal(0, 1.0, 3), requires_grad=True)
}learng_rate = 3e-4
loss_list = []
for e in range(101):# set gradient to zerofor param in params.values():param.zero_grad()# forwardpredicted = x @ params["w"] + params["b"]err = predicted - yloss = (err * err).sum()# backward automaticallyloss.backward()# updata parameters (gradient descent)for param in params.values():param -= learng_rate * param.gradloss_list.append(loss.values)if e % 10 == 0:print("epoch-%i \tloss: %.4f" % (e, loss.values))
epoch-0 	loss: 8976.9821
epoch-10 	loss: 2747.4262
epoch-20 	loss: 871.4415
epoch-30 	loss: 284.9750
epoch-40 	loss: 95.7080
epoch-50 	loss: 32.9175
epoch-60 	loss: 11.5687
epoch-70 	loss: 4.1467
epoch-80 	loss: 1.5132
epoch-90 	loss: 0.5611
epoch-100 	loss: 0.2111

在这里插入图片描述

接口和 PyTorch 相似,在每个循环里面首先将参数梯度设为 0 ,然后定义计算图,然后从 loss 开始反向传播,最后更新参数。从结果可以看到 loss 随着训练进行非常漂亮地下降,说明我们的自动求导按照我们的设想 work 了。

总结

本文实现了讨论了自动求导的设计思路和整个过程是怎么运作的。总结起来:自动求导就是在定义了一个有状态的计算图,该计算图上的节点不仅保存了节点的前向运算,还保存了反向计算所需的上下文信息。利用上下文信息,通过图遍历让梯度在图中流动,实现自动求节点梯度。

我们通过重载运算符实现了一个支持自动求导的 Tensor 类,用一个简单的线性回归 demo 测试了自动求导。当然这只是最基本的能实现自动求导功能的 demo,从实现的角度上看还有很多需要优化的地方(内存开销、运算速度等),笔者有空会继续深入研究,读者如果有兴趣也可以自行查阅相关资料。Peace out. 🤘

参考资料

  • PyTorch Doc
  • PyTorch Autograd Explained - In-depth Tutorial
  • joelgrus/autograd
  • Automatic Differentiation in Machine Learning: a Survey

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532393.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

http 错误 404.0 - not found_电脑Regsvr32 用法和错误消息的说明

​ 对于那些可以自行注册的对象链接和嵌入 (OLE) 控件,例如动态链接库 (DLL) 文件或 ActiveX 控件 (OCX) 文件,您可以使用 Regsvr32 工具 (Regsvr32.exe) 来将它们注册和取消注册。Regsvr32.exe 的用法RegSvr32.exe 具有以下命令行选项: Regs…

MobileNet 系列:从V1到V3

MobileNet 系列:从V1到V3 转自:轻量级神经网络“巡礼”(二)—— MobileNet,从V1到V3 自从2017年由谷歌公司提出,MobileNet可谓是轻量级网络中的Inception,经历了一代又一代的更新。成为了学习轻…

mysql 高级知识点_这是我见过最全的《MySQL笔记》,涵盖MySQL所有高级知识点!...

作为运维和编程人员,对MySQL一定不会陌生,尤其是互联网行业,对MySQL的使用是比较多的。MySQL 作为主流的数据库,是各大厂面试官百问不厌的知识点,但是需要了解到什么程度呢?仅仅停留在 建库、创表、增删查改…

teechart mysql_TeeChart 的应用

TeeChart 是一个很棒的绘图控件,不过由于里面没有注释,网上相关的资料也很少,所以在应用的时候只能是一点点的试。为了防止以后用到的时候忘记,我就把自己用到的东西都记录下来,以便以后使用的时候查询。1、进制缩放图…

NLP新宠——浅谈Prompt的前世今生

NLP新宠——浅谈Prompt的前世今生 转自:NLP新宠——浅谈Prompt的前世今生 作者:闵映乾,中国人民大学信息学院硕士,目前研究方向为自然语言处理。 《Pre-train, Prompt, and Predict: A Systematic Survey of Prompting Methods in…

requestfacade 这个是什么类?_Java 的大 Class 到底是什么?

作者在之前工作中,面试过很多求职者,发现有很多面试者对Java的 Class 搞不明白,理解的不到位,一知半解,一到用的时候,就不太会用。想写一篇关于Java Class 的文章,没有那么多专业名词&#xff0…

初学机器学习:直观解读KL散度的数学概念

初学机器学习:直观解读KL散度的数学概念 转自:初学机器学习:直观解读KL散度的数学概念 译自:https://towardsdatascience.com/light-on-math-machine-learning-intuitive-guide-to-understanding-kl-divergence-2b382ca2b2a8 解读…

MySQL应用安装_mysql安装和应用

1.下载mysql安装包2.安装mysql,自定义->修改路径3.配置mysql,选择自定义->server模式->500访问量->勾选控制台->设置gbk->设置密码和允许root用户远程登录等等。以管理员权限,在控制台输入:net start MySQL, 启…

mysql 商品规格表_商品规格分析

产品表每次更新商品都会变动的,ID不能用,可是购物车还是用了,这就导致每次保存商品,哪怕什么都没有改动,也会导致用户的购物车失效。~~~其实可以考虑不是每次更新商品就除所有的SKU,毕竟有时什么都没修改呢…

huggingface NLP工具包教程1:Transformers模型

huggingface NLP工具包教程1:Transformers模型 原文:TRANSFORMER MODELS 本课程会通过 Hugging Face 生态系统中的一些工具包,包括 Transformers, Datasets, Tokenizers, Accelerate 和 Hugging Face Hub。…

隐马尔可夫模型HMM推导

隐马尔可夫模型HMM推导 机器学习-白板推导系列(十四)-隐马尔可夫模型HMM(Hidden Markov Model) 课程笔记 背景介绍 介绍一下频率派和贝叶斯派两大流派发展出的建模方式。 频率派 频率派逐渐发展成了统计机器学习,该流派通常将任务建模为一…

使用randomaccessfile类将一个文本文件中的内容逆序输出_Java 中比较常用的知识点:I/O 总结...

Java中I/O操作主要是指使用Java进行输入,输出操作. Java所有的I/O机制都是基于数据流进行输入输出,这些数据流表示了字符或者字节数据的流动序列。数据流是一串连续不断的数据的集合,就象水管里的水流,在水管的一端一点一点地供水…

huggingface NLP工具包教程2:使用Transformers

huggingface NLP工具包教程2:使用Transformers 引言 Transformer 模型通常非常大,由于有数百万到数百亿个参数,训练和部署这些模型是一项复杂的任务。此外,由于几乎每天都有新模型发布,而且每个模型都有自己的实现&a…

mysql精讲_Mysql 索引精讲

开门见山,直接上图,下面的思维导图即是现在要讲的内容,可以先有个印象~常见索引类型(实现层面)索引种类(应用层面)聚簇索引与非聚簇索引覆盖索引最佳索引使用策略1.常见索引类型(实现层面)首先不谈Mysql怎么实现索引的,先马后炮一…

RT-Smart 官方 ARM 32 平台 musl gcc 工具链下载

前言 RT-Smart 的开发离不开 musl gcc 工具链,用于编译 RT-Smart 内核与用户态应用程序 RT-Smart musl gcc 工具链代码当前未开源,但可以下载到 RT-Thread 官方编译好的最新的 musl gcc 工具链 ARM 32位 平台 比如 RT-Smart 最好用的 ARM32 位 qemu 平…

OpenAI Whisper论文笔记

OpenAI Whisper论文笔记 OpenAI 收集了 68 万小时的有标签的语音数据,通过多任务、多语言的方式训练了一个 seq2seq (语音到文本)的 Transformer 模型,自动语音识别(ASR)能力达到商用水准。本文为李沐老师…

【经典简读】知识蒸馏(Knowledge Distillation) 经典之作

【经典简读】知识蒸馏(Knowledge Distillation) 经典之作 转自:【经典简读】知识蒸馏(Knowledge Distillation) 经典之作 作者:潘小小 知识蒸馏是一种模型压缩方法,是一种基于“教师-学生网络思想”的训练方法,由于其简单&#xf…

深度学习三大谜团:集成、知识蒸馏和自蒸馏

深度学习三大谜团:集成、知识蒸馏和自蒸馏 转自:https://mp.weixin.qq.com/s/DdgjJ-j6jHHleGtq8DlNSA 原文(英):https://www.microsoft.com/en-us/research/blog/three-mysteries-in-deep-learning-ensemble-knowledge…

在墙上找垂直线_墙上如何快速找水平线

在装修房子的时候,墙面的面积一般都很大,所以在施工的时候要找准水平线很重要,那么一般施工人员是如何在墙上快速找水平线的呢?今天小编就来告诉大家几种找水平线的方法。一、如何快速找水平线1、用一根透明的软管,长度…

Vision Transformer(ViT)PyTorch代码全解析(附图解)

Vision Transformer(ViT)PyTorch代码全解析 最近CV领域的Vision Transformer将在NLP领域的Transormer结果借鉴过来,屠杀了各大CV榜单。本文将根据最原始的Vision Transformer论文,及其PyTorch实现,将整个ViT的代码做一…