【LM、LLM】浅尝二叉树在前馈神经网络上的应用

前言

随着大模型的发展,模型参数量暴涨,以Transformer的为组成成分的隐藏神经元数量增长的越来越多。因此,降低前馈层的推理成本逐渐进入视野。前段时间看到本文介绍的相关工作还是MNIST数据集上的实验,现在这个工作推进到BERT上面来了,再次引起兴趣记录一下。该工作将前馈神经基于二叉树结构进行改装,加速前向传播的速度,称为:快速前馈网络(FFF),然后应用FFF,取代BERT中的前馈网络(FF),实现12个神经元加速推理。

快速前馈网络算法概述

快速前馈网络(Fast Feedforward Network,FFF)是由两部分组成的:节点网络集合 N \mathcal{N} N 和叶子网络集合 L \mathcal{L} L

  • 节点网络集合 N \mathcal{N} N 包含了一组节点网络,每个节点网络都是一个 < dim ⁡ I , n , 1 > \left<\dim_I,n,1\right> dimI,n,1-前馈网络,并在输出上增加了一个 sigmoid 激活函数。这些节点网络按照平衡的可微分二叉树的形式排列,其中 N m + 1 , 2 n N_{m+1,2n} Nm+1,2n N m + 1 , 2 n + 1 N_{m+1,2n+1} Nm+1,2n+1 N m , n N_{m,n} Nm,n 的子节点。

  • 叶子网络集合 L \mathcal{L} L 包含了一组叶子网络,每个叶子网络都是一个 < dim ⁡ I , ℓ , dim ⁡ O > \left<\dim_I,\ell,\dim_O\right> dimI,,dimO-前馈网络。叶子网络没有子节点,它们的输出直接作为 FFF 的输出。

前向传播过程由下面算法控制。

算法的输入包括一个输入样本 ι \iota ι 和根节点 N 0 , 0 N_{0,0} N0,0,输出为该样本在 FFF 中的输出。

算法定义了两个函数: F o r w a r d T Forward_T ForwardT F o r w a r d I {Forward}_I ForwardI。其中, F o r w a r d T {Forward}_T ForwardT 函数用于计算节点的输出,而 F o r w a r d I {Forward}_I ForwardI 函数用于计算节点的指示值(indicator value)。

  • F o r w a r d T {Forward}_T ForwardT 函数中,如果当前节点是叶子节点,则直接调用该节点的前馈传播函数 N m , n ( ι ) N_{m,n}(\iota) Nm,n(ι) 来计算输出。否则,首先计算当前节点的输出 c m , n = N m , n ( ι ) c_{m,n}=N_{m,n}(\iota) cm,n=Nm,n(ι),然后递归地调用 F o r w a r d T {Forward}_T ForwardT 函数来计算当前节点的两个子节点的输出,并将它们加权相加作为当前节点的输出。
  • F o r w a r d I {Forward}_I ForwardI 函数中,如果当前节点是叶子节点,则直接调用该节点的前馈传播函数 N m , n ( ι ) N_{m,n}(\iota) Nm,n(ι) 来计算输出。否则,首先计算当前节点的输出 c m , n = N m , n ( ι ) c_{m,n}=N_{m,n}(\iota) cm,n=Nm,n(ι),然后根据输出值的大小决定选择哪个子节点进行递归计算。


传统前馈神经网络

快速前馈神经网络

与传统的前馈神经网络算法相比,该算法的主要区别在于引入了一个计算节点的指示值。指示值表示了当前节点的输出是否大于等于阈值(这里的阈值为0.5),根据指示值的大小来确定选择哪个子节点进行计算。这种方式可以大大减少计算量,提高前向传播的效率。同时,FFF 是一种具有平衡二叉树结构的前馈神经网络,其中节点网络和叶子网络分别用于处理中间层和输出层的计算。通过利用二叉树结构和递归计算,FFF 可以实现快速的前向传播。

UltraFastBERT

UltraFastBERT,一种BERT变体,在推理过程中使用0.3%的神经元,同时表现 与类似的BERT模型相当。UltraFastBERT选择性地使用4095个神经元中的12个(有选择的执行矩阵乘法(CMM))进行每层推理。这是通过用快速前馈网络(FFFs)取代前馈网络来实现的。

FFF_BMM代码

import torch
from torch import nn
import mathclass FFF(nn.Module):def __init__(self, input_width: int, depth: int, output_width: int, *args, **kwargs):super().__init__(*args, **kwargs)self.input_width = input_widthself.depth = depthself.output_width = output_widthself.n_nodes = 2 ** (depth + 1) - 1self.initialise_weights()def initialise_weights(self):init_factor_l1 = 1.0 / math.sqrt(self.input_width)init_factor_l2 = 1.0 / math.sqrt(self.depth + 1)self.w1s = nn.Parameter(torch.empty(self.n_nodes, self.input_width).uniform_(-init_factor_l1, +init_factor_l1), requires_grad=True)self.w2s = nn.Parameter(torch.empty(self.n_nodes, self.output_width).uniform_(-init_factor_l2, +init_factor_l2), requires_grad=True)def forward(self, x):# the shape of x is (batch_size, input_width)# retrieve the indices of the relevant elementsbatch_size = x.shape[0]current_nodes = torch.zeros((batch_size,), dtype=torch.long, device=x.device)all_nodes = torch.zeros(batch_size, self.depth+1, dtype=torch.long, device=x.device)all_logits = torch.empty((batch_size, self.depth+1), dtype=torch.float, device=x.device)for i in range(self.depth+1):all_nodes[:, i] = current_nodesplane_coeffs = self.w1s.index_select(dim=0, index=current_nodes)			# (batch_size, input_width)plane_coeff_score = torch.bmm(x.unsqueeze(1), plane_coeffs.unsqueeze(-1))	# (batch_size, 1, 1)plane_score = plane_coeff_score.squeeze(-1).squeeze(-1) 					# (batch_size,)all_logits[:, i] = plane_scoreplane_choices = (plane_score >= 0).long()									# (batch_size,)current_nodes = current_nodes * 2 + plane_choices + 1						# (batch_size,)# get the weightsselected_w2s = self.w2s.index_select(0, index=all_nodes.flatten()) \.view(batch_size, self.depth+1, self.output_width)	# (batch_size, depth+1, output_width)# forward passmlp1 = torch.nn.functional.gelu(all_logits)				# (batch_size, depth+1)mlp2 = torch.bmm(mlp1.unsqueeze(1), selected_w2s) 		# (batch_size, output_width)# donereturn mlp2

从代码可以看出,与传统的批矩阵乘法(BMM)不同的是,在forward中,基于二叉树的结构,通过迭代计算节点的索引和权重,使用激活函数(GeLU)对结果进行处理,并最终得到输出。

结果

在推理过程中仅使用0.3%的神经元,同时表现与类似的BERT模型相当(下游任务没有降很多点);实现78倍CPU加速,实现40倍PyTorch加速。

总结

该工作很有趣,将传统前馈神经网络定义成一棵二叉树,其叶子是小型神经网络,在每个非叶子节点处都有一个微小的神经网络(单个神经元也可以工作)来决定走哪条路径取决于在输入上。在训练期间,它们对所选路径进行加权平均值,从而得出树的所有叶子(在输入上评估为神经网络)的总加权平均值,但在推理过程中,它们可以只遵循投票最高的分支,从而得出建议的结果指数加速。并且,基于FFF的思想,将工作推到BERT这种语言模型上,初步证明了大模型的前馈层的神经元并不是都需要参与推理。

文章及公开的代码还介绍了条件矩阵乘法的详细细节,因此感兴趣可以深入研究一下。

参考文献

【1】paper:Exponentially Faster Language Modelling,https://arxiv.org/abs/2311.10770
【2】code:https://github.com/pbelcak/fastfeedforward
【3】paper:Fast Feedforward Networks,https://arxiv.org/abs/2308.14711

【4】code:https://github.com/pbelcak/UltraFastBERT
【5】model:https://huggingface.co/pbelcak/UltraFastBERT-1x11-long

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/169511.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[极客大挑战 2019]Secret File1

[极客大挑战 2019]Secret File1 在bp里面发现secr3t.php 将secr3t.php 直接加在网站后面&#xff0c;发现了有关flag的信息&#xff0c;一个flag.php文件 在遇到flag.php时候&#xff0c;联想到php伪协议&#xff0c;构造伪协议方式 secr3t.php?filephp://filter/readconver…

0002Java程序设计-springboot在线考试系统小程序

文章目录 **摘 要****目录**系统实现开发环境 编程技术交流、源码分享、模板分享、网课分享 企鹅&#x1f427;裙&#xff1a;776871563 摘 要 本毕业设计的内容是设计并且实现一个基于springboot的在线考试系统小程序。它是在Windows下&#xff0c;以MYSQL为数据库开发平台&…

FFmpeg零基础学习(一)——初步介绍与环境搭建

目录 前言正文一、开发环境二、搭建环境三、测试代码四、调用库的介绍End、遇到的问题2、Qt 在线安装容易报错&#xff0c;断开问题1、在线安装QMaintainTool很慢2、Qt5.15 无法调试FFmpeg 参考 前言 FFmpeg是一个开源的跨平台多媒体处理框架&#xff0c;它包含了一组用于处理…

【图解系列】一张图带你了解 DevOps 生态工具

一张图带你了解 DevOps 生态工具 ✅ 协作&#xff08;Collaborate&#xff09;&#xff1a;JIRA、Confluence 大家肯定不陌生了&#xff0c;我之前也写过利用 Jekyll 搭建个人博客的帖子。✅ 构建&#xff08;Build&#xff09;&#xff1a;常用的 SCM&#xff08;Software Con…

掌握未来技术趋势,成为领先者——深度解析2023年技术热点

掌握未来技术趋势&#xff0c;成为领先者——深度解析2023年技术热点 摘要&#xff1a;本文探讨当前最热门的技术趋势。我们将介绍人工智能、大数据、区块链、5G等前沿技术&#xff0c;并阐述它们如何改变我们的生活。最后&#xff0c;我们将总结如何利用这些技术趋势&#xf…

2024年天津天狮学院专升本计算机科学与技术《数据结构》考试大纲

2024年天津天狮学院计算机科学与技术专业高职升本入学考试《数据结构》考试大纲 一、考试性质 《数据结构》专业课程考试是天津天狮学院计算机科学与技术专业高职升本入学考 试的必考科目之一&#xff0c;其性质是考核学生是否达到了升入本科继续学习的要求而进行的选拔性考试…

Word打印模板,打印效果更出众丨三叠云

Word打印模板 路径 表单设置 >> 打印设置 功能简介 新增「Word打印模板」(beta版)。 Word 打印模板是指&#xff0c;在 Word 文档的基础上插入表单中的字段代码&#xff0c;打印时即可根据 Word 文档的格式&#xff0c;对表单数据进行个性化打印。 Word 打印模板能…

matlab不用sawtooth,自己写代码实现锯齿波/三角波

matlab自己写代码实现锯齿波/三角波 为什么要自己写代码&#xff0c;不用现成的函数sawtooth&#xff1f; 函数sawtooth的采样频率是固定的&#xff0c;也就是给定一个时间段&#xff0c;只能按照固定的频率取点。比如10s内&#xff0c;每1s取一个点。这样就得到了1s 2s 3s……

激活函数与其导数:神经网络中的关键元素

激活函数是神经网络中的重要组成部分&#xff0c;有力地推动了深度学习的发展。然而&#xff0c;仅仅了解和选择激活函数是不够的&#xff0c;我们还需要理解激活函数的导数。本文将详细介绍激活函数的概念、作用及其导数的重要性&#xff0c;并探究导数对神经网络训练的影响。…

【室内定位系统源码】UWB超宽带定位技术的特点和应用前景

uwb人员、物品定位系统源码&#xff0c;智慧工厂人员安全管理定位&#xff0c;高精度定位系统源码 UWB超宽带定位技术概念&#xff1a; 超宽带无线通信技术&#xff08;UWB&#xff09;是一种无载波通信技术&#xff0c;UWB不使用载波&#xff0c;而是使用短的能量脉冲序…

Presto+Alluxio数据平台实战

数新网络&#xff0c;让每个人享受数据的价值https://xie.infoq.cn/link?targethttps%3A%2F%2Fwww.datacyber.com%2F 一、Presto & Alluxio简介 Presto Presto是由Facebook开发的开源大数据分布式高性能 SQL查询引擎。 起初&#xff0c;Facebook使用Hive来进行交互式查询…

AI创作工具:Claude2注册保姆级教程

最近软件打算多接入几个AI写作平台&#xff0c;包括讯飞星火&#xff0c;百度文心&#xff0c;Claude2&#xff0c;这样就能给用户提供更多的写作选择 经过半天的调研&#xff0c;讯飞星火&#xff0c;百度文心一言&#xff0c;接入都比较简单&#xff0c;毕竟是国内的。 在调…

关于pytorch以及相关包的安装教程

一.查看自己电脑的配置 首先查看自己电脑的cuda的版本&#xff0c;WinR,敲入cmd打开终端 输入nvidia-smi&#xff0c;查看自己电脑的显卡等配置 这里要说明一下关于这个CUDA,它具有向后兼容性&#xff0c;这意味着支持较低版本的 CUDA 的应用程序通常也可以在较高版本的 CUD…

Jmeter接口自动化测试操作流程

在企业使用jmeter开展实际的接口自动化测试工具&#xff0c;建议按如下操作流程&#xff0c; 可以使整个接口测试过程更规范&#xff0c;更有效。 接口自动化的流程&#xff1a; 1、获取到接口文档&#xff1a;swagger、word、excel ... 2、熟悉接口文档然后设计测试用例&am…

前端大厂(腾讯、字节跳动、阿里......)校招面试真题解析,让你面试轻松无压力!

前言 校招很重要&#xff0c;应届生的身份很珍贵&#xff01;在校招的时候与我们竞争的大部分都是没有工作经验的学生&#xff0c;而且校招企业对学生的包容度高&#xff0c;一般对企业来说&#xff0c;社招更看重实际工作经验&#xff0c;而校招更愿意“培养人”&#xff0c;校…

node fs模板及蓝桥案例实战

文章目录 介绍文件写入writeFile 异步写入writeFileSync 同步写入appendFile / appendFileSync 追加写入createWriteStream 流式写入 文件读取readFile 异步读取readFileSync 同步读取createReadStream 流式读取 文件移动与重命名文件删除文件夹操作mkdir / mkdirSync 创建文件…

python操作redis

操作单redis 需要安装redis模块&#xff1a;pip install redis demo&#xff1a; #!/usr/bin/env python3 # coding utf-8import redis import threadingdef a():conn redis.Redis(host"192.168.1.66", port6379, password"123456", db6,# decode_res…

数据库表结构导出成Excel或Word格式

前言 该工具主要用于导出excel、word&#xff0c;方便快速编写《数据库设计文档》&#xff0c;同时可以快速查看表的结构和相关信息。 本博客仅作记录&#xff0c;最新源码已经支持多种数据库多种格式导出&#xff0c;有兴趣的可移步源码作者地址&#xff1a;https://gitee.co…

RK3568驱动指南|第八篇 设备树插件-第73章 设备树插件使用实验

瑞芯微RK3568芯片是一款定位中高端的通用型SOC&#xff0c;采用22nm制程工艺&#xff0c;搭载一颗四核Cortex-A55处理器和Mali G52 2EE 图形处理器。RK3568 支持4K 解码和 1080P 编码&#xff0c;支持SATA/PCIE/USB3.0 外围接口。RK3568内置独立NPU&#xff0c;可用于轻量级人工…

【Redis】前言--redis产生的背景以及过程

一.介绍 为什么会出现Redis这个中间件&#xff0c;从原始的磁盘存储到Redis中间又发生了哪些事&#xff0c;下面进入正题 二.发展史 2.1 磁盘存储 最早的时候都是以磁盘进行数据存储&#xff0c;每个磁盘都有一个磁道。每个磁道有很多扇区&#xff0c;一个扇区接近512Byte。…