从 0 手撸一个 pytorch

背景介绍

最近抽空看了下 Andrej Karpathy 的视频教程 building micrograd,教程的质量很高。教程不需要任何前置机器学习基础,只需要有高中水平的数学基础即可。整个教程从 0 到 1 手撸了一个类 pytorch 的机器学习库 micrograd,核心代码不到 100 行。虽然为了简化没有实现复杂的矩阵运算,但是对于理解 pytorch 的设计思想有很大帮助。

动手实践

为了验证 micrograd 的可用性,先基于 micrograd 实现了简单的线性回归算法。

首先构造出数据集,我使用随机数作为 x,通过线性回归确定结果后增加必要的噪声,对应的构造方法如下所示:

import numpy as npdef get_train_dataset(num_samples, noise):x = np.random.rand(num_samples)y = 4 * x + 3 + np.random.normal(0, noise, num_samples)return x.tolist(), y.tolist()

可以看到最终期望的结果为 y = 4 * x + 3

接下来实现训练流程,线性回归的模型的初始值都使用随机值,持续跟踪训练过程中损失值与对应参数的变化,实现如下所示:

import numpy as np
from micrograd.engine import Valuedef zero_grad(w, b):w.grad = 0b.grad = 0def step(w, b, learning_rate):w.data -= learning_rate * w.gradb.data -=  learning_rate * b.graddef train_loop():dataset_x, dataset_y = get_train_dataset(10, 0.01)learning_rate = 0.1w = Value(np.random.rand())b = Value(np.random.rand())epoch = 40print(f"Init w {w.data}, b {b.data}")for idx in range(epoch):loss = 0for x, y in zip(dataset_x, dataset_y):x_value, y_value = Value(x), Value(y)y_pred = x_value * w + bcurrent_loss = (y_value - y_pred) ** 2loss += current_loss.datazero_grad(w, b)current_loss.backward()step(w, b, learning_rate)print(f"Epoch {idx} got loss: {loss}, w {w.data}, b {b.data}")

上面的实现中使用 zero_grad() 方法重置参数的梯度,使用 step() 方法实际更新模型参数,训练流程就实现在 train_loop() 中。最终结果如下所示:
在这里插入图片描述
可以看到经过 40 轮训练后,损失值从最初的 55.69 下降至 0.0016,而参数 w, b 也接近期望的目标。从实践结果来看,micrograd 确实能实现简单模型的训练。

通过上面的实践来看,micrograd 最核心的就是 Value,按照 Andrej Karpathy 的说法,不到 100 行实现的 Value 就已经完成的 pytorch 中的 Tensor 90% 的功能了,除了这部分核心功能之外,pytorch 更多的是做了效率上的优化。

流程梳理

在机器学习中,模型训练都是基于 梯度下降 来更新模型的。模型训练的过程一般分为前向传播和反向传播:

  • 前向传播会根据训练数据确定对应的损失值,对应于上面的实现如下:
x_value, y_value = Value(x), Value(y)
y_pred = x_value * w + b
current_loss = (y_value - y_pred) ** 2

前向传播就是根据模型确定预测值 y_pred, 基于 MSE 确定损失值 (y - y_pred)^2。前向传播相对容易理解。

  • 反向传播就是根据确定的损失值进行模型参数的调整,从而降低损失值,对应的实现就是:
zero_grad(w, b)
current_loss.backward()
step(w, b, learning_rate)

上面最核心的功能就是调用 current_loss.backward() 确定各个参数对应的梯度,然后在 step() 方法中对参数的值进行更新。

参数更新的方案是相对明确,就是减去梯度与学习率之积实现。因此主要关注如何确定参数的梯度。梯度的计算存在如下所示的关注点:

  1. 数学运算各个元素对应的梯度如何计算,这部分就是微积分中导数的计算;
  2. 链式法则;
  3. 复杂模型中包含上亿参数,如何确定参数各自的梯度;

实现细节

micrograd 最核心的实现位于 engine.py,主要关注 Value 类的实现。

初始化过程

关注初始化过程可以看到 Value 中包含的元素,实现如下:

def __init__(self, data, _children=(), _op=''):self.data = dataself.grad = 0self._backward = lambda: Noneself._prev = set(_children)self._op = _op # the op that produced this node, for graphviz / debugging / etc

初始化阶段可以看到 Value 中最重要的两个参数,data 保存的是元素中的原始数据,grad 保存的是当前元素对应的梯度。

_backward() 方法保存的是反向传播的方法,用于计算反向传播的梯度

_prev 保存的是当前节点前置的节点,比如 y = w * x 中节点 y 对应的 _prev 保存的是 wx。通过不断的获取 _prev 节点,即可还原完整的运算链路。

数学运算支持

Value 中支持了不同的数学运算,首先以加法为例,实现如下所示:

def __add__(self, other):other = other if isinstance(other, Value) else Value(other)# 加法运算得到结果,同样是 Value 元素out = Value(self.data + other.data, (self, other), '+')# 加法反向传播函数def _backward():self.grad += out.gradother.grad += out.gradout._backward = _backwardreturn out

前向传播计算的实现比较简单,直接基于 data 进行计算,通过加法运算生成了结果 out。同时将参与运算的元素 selfother 保存至 self._prev 中,方便还原运算链路。

out 对应的反向传播的方法 _backward() 是基于链式法则实现。举例如下:

c = a + b

那么 ∂l/∂a = ∂l/∂c * ∂c/∂a,而 ∂c/∂a = 1,因此 ∂l/∂a = ∂l/∂c,因此加法中元素的梯度就等于其结果的梯度。

那么为什么实现是 self.grad += out.grad 而不是 self.grad = out.grad 呢,因为单个元素涉及多个运算链路时,梯度是不同链路确定的梯度之和。

这个也带来一个隐患,每次重新计算梯度之前,需要将原有的梯度重置为 0。对应于上面的 zero_grad() 的实现。了解 pytorch 应该也会注意到 pytorch 训练过程中也存在类似情况。

同样来查看乘法运算,对应的实现如下:


def __mul__(self, other):other = other if isinstance(other, Value) else Value(other)out = Value(self.data * other.data, (self, other), '*')def _backward():self.grad += other.data * out.gradother.grad += self.data * out.gradout._backward = _backwardreturn out

主要关注反向传播的实现,可以看到同样是链路法则的推演,举例如下:

c = a * b

那么 ∂l/∂a = ∂l/∂c * ∂c/∂a,而 ∂c/∂a = b, 因此 ∂l/∂a = ∂l/∂c * b, 因此就可以理解上面的实现了。

反向传播

通过上面的运算过程可以看到,通过不断保存其前置元素至 self._prev 中,可以构建出完整的运算链路图。而在运算过程中,元素反向传播计算的梯度的方法 _backward() 也被确定。因此反向传播就是从后往前调用 _backward() 来实现的:


def backward(self):topo = []visited = set()# 根据前置元素的关系构建拓扑排序的元素列表,保证最终调用时是从后往前的def build_topo(v):if v not in visited:visited.add(v)for child in v._prev:build_topo(child)topo.append(v)build_topo(self)# 最后元素的梯度为 1, 依次计算前置元素的梯度self.grad = 1for v in reversed(topo):v._backward()

最终反向传播就是调用 _backward() 即可确定各个元素的梯度。

总结

通过上面的流程可以很容易理解机器学习模型训练框架的设计方案,这一套流程也完全适用于 pytorch,可以帮助更好地理解 pytorch 的训练流程。整体总结下实现思路:

  1. 前向传播过程中会逐层计算运行结果,并确定结果与运算元素梯度之前的关系,在结果元素梯度确定后就可以确定运算元素的梯度;
  2. 反向传播就是按照从后往前依次确认各个元素的梯度,方便后续根据梯度更新元素对应的值;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/17696.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue状态管理深度剖析:Vuex vs Pinia —— 从原理到实践的全面对比

🔥 个人主页:空白诗 文章目录 👋 引言📌 Vuex 基础知识核心构成要素示例代码 📌 Pinia 基础知识核心构成要素示例代码 📌 Vuex与Pinia的区别📌 使用示例与对比📌 总结 👋…

探索Solana链上DApp开发:高性能区块链生态的新机遇

Solana 是一个新兴的区块链平台,致力于为 DApp(去中心化应用程序)开发者提供高性能、低成本的解决方案。Solana 的独特之处在于其创新性的共识机制和高吞吐量的网络,使得开发者可以构建高度可扩展的 DApp,并为用户提供…

云服务器如何使用局域网服务器的磁盘空间

说明 云服务器中的磁盘空间不足时,想要开通更多的磁盘空间,但奈何价格太贵,开不起 刚好局域网中有闲置的服务器空间可以拿来用,这里我们直接使用Samba服务来共享文件夹,使用frp来进行内网穿透; 1、磁盘挂…

OSPF优化——OSPF减少LSA更新量2

二、特殊区域——优化非骨干区域的LSA数量 不是骨干区域、不能存在虚链路 1、不能存在 ASBR 1)末梢区域 该区域将拒绝 4、5LSA的进人,同时由该区域连接骨干0区域的ABR 向该区域,发布一条3类的缺省路由; 该区域内每台路由器均需配置&#xf…

Unity 实现心电图波形播放(需波形图图片)

实现 在Hierarchy 面板从2D Object 中新建一个Sprite,将波形图图片的赋给Sprite。 修改Sprite 的Sprite Renderer 组件中Draw Mode 为Tiled, 修改Sprite Renderer 的Size 即可实现波形图播放。 在Hierarchy 面板从2D Object 中新建一个Sprite Mask 并赋以遮罩图片…

【设计模式】JAVA Design Patterns——Curiously Recurring Template Pattern(奇异递归模板模式)

🔍目的 允许派生组件从与派生类型兼容的基本组件继承某些功能。 🔍解释 真实世界例子 对于正在策划赛事的综合格斗推广活动来说,确保在相同重量级的运动员之间组织比赛至关重要。这样可以防止体型明显不同的拳手之间的不匹配,例如…

生成模型 | 从 VAE 到 Diffusion Model (下)

生成模型 | 从 VAE 到 Diffusion Model (上)的链接请点击下方蓝色字体: 上部分主要介绍了,GAN, AE, VAE, VQ-VAE, DALL-E 生成模型 | 从 VAE 到 Diffusion Model (上) 文章目录 我们先来看一下生成模型现在的能力一&…

IT人的拖延——一放松就停不下来,耽误事?

拖延的表现 在我们的日常工作中,经常会面对这样一种情况:因为要做的Sprint ticket比较复杂或者长时间的集中注意力后,本来打算休息放松一下,刷刷剧,玩玩下游戏,但却一个不小心,没控制住时间&am…

IP 分片过程及偏移量计算

IP 报头中与分片相关的三个字段 1、 标识符( ldentifier ):16 bit 该字段与 Flags 和 Fragment Offest 字段联合使用, 对较大的上层数据包进行分段( fragment ) 操作。 路由器将一个包拆分后,所有拆分开的…

图解Java数组的内存分布

我们知道,访问数组元素要通过数组索引,如: arr[0]如果直接访问数组,比如: int[] arr1 {1}; System.out.println(arr1);会发生什么呢? 打印的是一串奇怪的字符串:[I16b98e56。 这个字符串是J…

强化训练:day11(游游的水果大礼包、 买卖股票的最好时机(二)、倒置字符串)

文章目录 前言1. 游游的水果大礼包1.1 题目描述1.2 解题思路1.3 代码实现 2. 买卖股票的最好时机(二)2.1 题目描述2.2 解题思路2.3 代码实现 3. 倒置字符串3.1 题目描述3.2 解题思路3.3 代码实现 总结 前言 1. 游游的水果大礼包   2. 买卖股票的最好时机(二)   3. 倒置字符…

数据结构初阶 栈

一. 栈的基本介绍 1. 基本概念 栈是一种线性表 是一种特殊的数据结构 栈顶:进行数据插入和删除操作的一端 另一端叫做栈底 压栈:插入数据叫做压栈 压栈的数据在栈顶 出栈: 栈的删除操作叫做出栈 出栈操作也是在栈顶 栈遵循一个原则 叫做…

JavaEE:Servlet创建和使用及生命周期介绍

目录 ▐ Servlet概述 ▐ Servlet的创建和使用 ▐ Servlet中方法介绍 ▐ Servlet的生命周期 ▐ Servlet概述 • Servlet是Server Applet的简称,意思是 用Java编写的服务器端的程序,Servlet被部署在服务器中,而服务器负责管理并调用Servle…

2024.5.21 作业 xyt

今日课堂内容&#xff1a;域套接字 TCP流式套接字 //服务器 #include <myhead.h> int main(int argc, const char *argv[]) {//1、为通信创建一个端点int sfd socket(AF_UNIX, SOCK_STREAM, 0);//参数1&#xff1a;说明使用的是ipv4通信域//参数2&#xff1a;说明使用…

HTML静态网页成品作业(HTML+CSS)——动漫海绵宝宝介绍网页(5个页面)

&#x1f389;不定期分享源码&#xff0c;关注不丢失哦 文章目录 一、作品介绍二、作品演示三、代码目录四、网站代码HTML部分代码 五、源码获取 一、作品介绍 &#x1f3f7;️本套采用HTMLCSS&#xff0c;未使用Javacsript代码&#xff0c;共有5个页面。 二、作品演示 三、代…

【前端笔记】记录一个能优化Echarts Geo JSON大小的网站

前端在使用Echarts等可视化图表库会不可避免遇到的问题&#xff0c;渲染地图的数据太大。 而有那么一个网站能给予这个问题一个解决方案&#xff1a;链接在此 使用方法很简单&#xff0c;首先先进入网站&#xff0c;如果进入了会是这个页面&#xff1a; 接着&#xff0c;选择一…

HCIP的学习(25)

VLAN间通讯技术 使用多臂路由的方式 ​ 路由器的物理接口默认是不识别802.1Q标签的&#xff0c;所以&#xff0c;交换机连接路由器的接口在发送数据帧时&#xff0c;应该将标签剥离。----一般常使用Access接口配置。 单臂路由 ​ 所谓的单臂路由&#xff0c;实际上试讲路由器…

【主流分布式算法总结】

文章目录 分布式常见的问题常见的分布式算法Raft算法概念Raft的实现 ZAB算法Paxos算法 分布式常见的问题 分布式场景下困扰我们的3个核心问题&#xff08;CAP&#xff09;&#xff1a;一致性、可用性、分区容错性。 1、一致性&#xff08;Consistency&#xff09;&#xff1a;…

Docker是什么?使用场景作用及Docker的安装和启动详解

目录 Docker是什么&#xff1f; Docker的发展 Docker的安装 Docker使用 Docker的运行机制 第一个Docker容器 进入Docker容器 客户机访问容器 Docker是什么&#xff1f; Docker 是一个开源的应用容器引擎&#xff0c;基于 Go 语言 并遵从 Apache2.0 协议开源。 Docker …