【自然语言处理】【大模型】RWKV:基于RNN的LLM

相关博客
【自然语言处理】【大模型】RWKV:基于RNN的LLM
【自然语言处理】【大模型】CodeGen:一个用于多轮程序合成的代码大语言模型
【自然语言处理】【大模型】CodeGeeX:用于代码生成的多语言预训练模型
【自然语言处理】【大模型】LaMDA:用于对话应用程序的语言模型
【自然语言处理】【大模型】DeepMind的大模型Gopher
【自然语言处理】【大模型】Chinchilla:训练计算利用率最优的大语言模型
【自然语言处理】【大模型】大语言模型BLOOM推理工具测试
【自然语言处理】【大模型】GLM-130B:一个开源双语预训练语言模型
【自然语言处理】【大模型】用于大型Transformer的8-bit矩阵乘法介绍
【自然语言处理】【大模型】BLOOM:一个176B参数且可开放获取的多语言模型
【自然语言处理】【大模型】PaLM:基于Pathways的大语言模型
【自然语言处理】【chatGPT系列】大语言模型可以自我改进
【自然语言处理】【ChatGPT系列】FLAN:微调语言模型是Zero-Shot学习器
【自然语言处理】【ChatGPT系列】ChatGPT的智能来自哪里?

RWKV:基于RNN的LLM

​ 基于Transformer的LLM已经取得了巨大的成功,但是其在显存消耗和计算复杂度上都很高。RWKV是一个基于RNN的LLM,其能够像Transformer那样高效的并行训练,也能够像RNN那样高效的推理。

一、背景知识

1. RNN

​ RNN是指一类神经网络模型结构,其中最具有代表性的是LSTM:
f t = σ g ( W f x t + U f h t − 1 + b f ) i t = σ g ( W i x t + U i h t − 1 + b i ) o t = σ g ( W o x t + U o h t − 1 + b o ) c ~ t = σ c ( W c x t + U c h t − 1 + b c ) c t = f t ⊙ c t − 1 + i t ⊙ c ~ t h t = o t ⊙ σ h ( c t ) \begin{align} f_t&=\sigma_g(W_fx_t+U_f h_{t-1}+b_f) \tag*{(1)} \\ i_t&=\sigma_g(W_ix_t+U_i h_{t-1}+b_i) \tag*{(2)} \\ o_t&=\sigma_g(W_ox_t+U_o h_{t-1}+b_o) \tag*{(3)} \\ \tilde{c}_t&=\sigma_c(W_cx_t+U_c h_{t-1}+b_c) \tag*{(4)} \\ c_t&=f_t\odot c_{t-1}+i_t\odot\tilde{c}_t \tag*{(5)} \\ h_t&=o_t\odot\sigma_h(c_t) \tag*{(6)} \end{align} \\ ftitotc~tctht=σg(Wfxt+Ufht1+bf)=σg(Wixt+Uiht1+bi)=σg(Woxt+Uoht1+bo)=σc(Wcxt+Ucht1+bc)=ftct1+itc~t=otσh(ct)(1)(2)(3)(4)(5)(6)
其中, x t x_t xt是当前时间步的输入, h t − 1 h_{t-1} ht1是上一个时间步的隐藏状态,所有的 W W W U U U b b b都是可学习参数, σ \sigma σ表示 sigmoid \text{sigmoid} sigmoid函数。 f t f_t ft是“遗忘门”,用来控制前一个时间步上传递信息的比例; i t i_t it是“输入门”,用于控制当前时间步保留的信息比例; o t o_t ot是"输出门",用于产生最终的输出。

2. Transformers和AFT

​ Transformer是NLP中主流的一种模型架构,其依赖于注意力机制来捕获所有输入和输出tokens的关系:
Attn ( Q , K , V ) = softmax ( Q K ⊤ ) V (7) \text{Attn}(Q,K,V)=\text{softmax}(QK^\top)V \tag{7} \\ Attn(Q,K,V)=softmax(QK)V(7)
为了简洁,这里忽略了多头和缩放因子 1 d k \frac{1}{\sqrt{d_k}} dk 1 Q K ⊤ QK^\top QK是序列中每个token之间的成对注意力分数,其能够被分解为向量表示:
Attn ( Q , K , V ) t = ∑ i = 1 T e q t ⊤ k i ∑ i = 1 T e q t ⊤ k i v i = ∑ i = 1 T e q t ⊤ k i v i ∑ i = 1 T e q t ⊤ k i (8) \text{Attn}(Q,K,V)_t=\sum_{i=1}^T\frac{e^{q_t^\top k_i}}{\sum_{i=1}^T e^{q_t^\top k_i}}v_i=\frac{\sum_{i=1}^T e^{q_t^\top k_i}v_i}{\sum_{i=1}^T e^{q_t^\top k_i}}\tag{8} \\ Attn(Q,K,V)t=i=1Ti=1Teqtkieqtkivi=i=1Teqtkii=1Teqtkivi(8)
在AFT中,设计了一种注意力变体:
Attn + ( W , K , V ) t = ∑ i = 1 t e w t , i + k i v i ∑ i = 1 t e w t , i + k i (9) \text{Attn}^+(W,K,V)_t=\frac{\sum_{i=1}^t e^{w_{t,i}+k_i}v_i}{\sum_{i=1}^t e^{w_{t,i}+k_i}} \tag{9} \\ Attn+(W,K,V)t=i=1tewt,i+kii=1tewt,i+kivi(9)
其中, { w t , i } ∈ R T × T \{w_{t,i}\}\in R^{T\times T} {wt,i}RT×T是可学习的位置偏差,每个 w t , i w_{t,i} wt,i是一个标量。

​ 受AFT启发,在RWKV中的 w t , i w_{t,i} wt,i是一个乘以相对位置的时间衰减向量:
w t , i = − ( t − i ) w (10) w_{t,i}=-(t-i)w \tag{10} \\ wt,i=(ti)w(10)
其中, w ∈ ( R ≥ 0 ) d w\in (R_{\geq 0})^d w(R0)d d d d是通道数。这里需要 w w w是非负来保证 e w t , i ≤ 1 e^{w_{t,i}}\leq 1 ewt,i1并且每个信道随时间衰减。

二、RWKV(Receptance Weighted Key Value)

在这里插入图片描述

​ RWKV由一系列的基本Block组成,每个Block则由time-mixing block和channel-mixing block组成的(如上图所示)。
在这里插入图片描述

​ RWKV递归的形式可以看做是当前输入和前一个时间不输入的线性插值,如上图所示。

1. Time-mixing block

​ Time-mixing block的作用同Self-Attention相同,就是提供全局token的交互。细节如下:
r t = W r ⋅ ( μ r x t + ( 1 − u r ) x t − 1 ) k t = W k ⋅ ( μ k x t + ( 1 − u k ) x t − 1 ) v t = W v ⋅ ( μ v x t + ( 1 − μ v ) x t − 1 ) w k v t = ∑ i = 1 t − 1 e − ( t − 1 − i ) w + k i v i + e u + k t v t ∑ i = 1 t − 1 e − ( t − 1 − i ) w + k i + e u + k t o t = W o ⋅ ( σ ( r t ) ⊙ w k v t ) \begin{align} r_t&=W_r\cdot(\mu_rx_t+(1-u_r)x_{t-1}) \tag*{(11)} \\ k_t&=W_k\cdot(\mu_kx_t+(1-u_k)x_{t-1}) \tag*{(12)} \\ v_t&=W_v\cdot(\mu_vx_t+(1-\mu_v)x_{t-1}) \tag*{(13)} \\ wkv_t&=\frac{\sum_{i=1}^{t-1}e^{-(t-1-i)w+k_i}v_i+e^{u+k_t}v_t}{\sum_{i=1}^{t-1}e^{-(t-1-i)w+k_i}+e^{u+k_t}} \tag*{(14)} \\ o_t&=W_o\cdot(\sigma(r_t)\odot wkv_t) \tag*{(15)} \end{align} \\ rtktvtwkvtot=Wr(μrxt+(1ur)xt1)=Wk(μkxt+(1uk)xt1)=Wv(μvxt+(1μv)xt1)=i=1t1e(t1i)w+ki+eu+kti=1t1e(t1i)w+kivi+eu+ktvt=Wo(σ(rt)wkvt)(11)(12)(13)(14)(15)
所有的 μ \mu μ W W W都是可训练参数, r t r_t rt k t k_t kt v t v_t vt是当前输入 x t x_t xt和上一个时间步输入 x t − 1 x_{t-1} xt1的加权投影。

公式(14)中, w w w u u u是可训练参数,分子的第一项 ∑ i = 1 t − 1 e − ( t − 1 − i ) w + k i v i \sum_{i=1}^{t-1}e^{-(t-1-i)w+k_i}v_i i=1t1e(t1i)w+kivi表示前 t − 1 t-1 t1步的加权结果, − ( t − 1 − i ) w + k i -(t-1-i)w+k_i (t1i)w+ki是随相对距离逐步衰减; e u + k t v t e^{u+k_t}v_t eu+ktvt则是当前时间步的结果。

公式(15)中,则通过 σ ( r t ) \sigma(r_t) σ(rt)控制最终输出的比例。

2. Channel-mixing block

​ Channel-mixing block类似于Transformer中的FFN部分,细节如下:
r t = W r ⋅ ( μ r x t − ( 1 − μ r ) x t − 1 ) k t = W k ⋅ ( μ k x t − ( 1 − μ k ) x t − 1 ) o t = σ ( r t ) ⊙ ( W v ⋅ max ⁡ ( k t , 0 ) 2 ) \begin{align} r_t&=W_r\cdot(\mu_rx_t-(1-\mu_r)x_{t-1}) \tag*{(16)} \\ k_t&=W_k\cdot(\mu_kx_t-(1-\mu_k)x_{t-1}) \tag*{(17)} \\ o_t&=\sigma(r_t)\odot(W_v\cdot\max(k_t,0)^2) \tag*{(18)} \\ \end{align} \\ rtktot=Wr(μrxt(1μr)xt1)=Wk(μkxt(1μk)xt1)=σ(rt)(Wvmax(kt,0)2)(16)(17)(18)

三、并行训练和序列解码

​ RWKV可以类似Transformer那样高效的并行。设batch size为B、seq_length为T、channels为d,计算量主要来自于矩阵乘法 W □ , □ ∈ { r , k , v , o } W_\square,\square\in \{r,k,v,o\} W,{r,k,v,o},单层的时间复杂度为 O ( B T d 2 ) O(BTd^2) O(BTd2)。此外,更新注意力分数 w k v t wkv_t wkvt需要顺序扫描,其时间复杂度为 O ( B T d ) O(BTd) O(BTd)。矩阵乘法可以像Transformer那样并行,但是WKV的计算是依赖时间步的,所以只能在其他维度上并行。

​ RWKV具有类似RNN的结构,解码时将 t t t步的输出作为 t + 1 t+1 t+1步的输入。相比于自注意力机制随着序列长度,计算复杂度呈平方次增长,RWKV则是与序列长度呈线性关系。因此,RWKV能够更高效的处理更长的序列。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/80948.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入网络底层,了解Linux系统收发网络数据包的过程、原理、流程,附图文说明

深入网络底层,了解Linux系统收发网络数据包的过程、原理、流程,附图文说明。 Linux 服务器收到网络数据包,需要经过哪些处理,一步步将数据传给应用进程的呢?应用进程发送数据包时,Linux 又是如何操作将数据…

android studio platform使用体验分享(as无法跳转c/c++等native源码的福音,强烈推荐)

hi,粉丝朋友们: 大家好!这些天粉丝朋友们分享了一下Android Studio for Platform 这个最新的google开发的阅读aosp源码的工具,特别适合做原生系统开发。具体官方介绍如下地址: 参考链接:https://developer.…

react的状态管理简单钩子方法

1.recoil useProvider文件: import { atom, useRecoilState } from recoil;const initState atom({key: initState,default: {state: [],}, })// 将业务逻辑拆分到一个单独文件中,方便进行状态管理 export interface StateProps {id: number;text: string;isFini…

异地远程访问本地SQL Server数据库【无公网IP内网穿透】

文章目录 1. 前言2. SeaFile云盘设置2.1 Owncould的安装环境设置2.2 SeaFile下载安装2.3 SeaFile的配置 3. cpolar内网穿透3.1 Cpolar下载安装3.2 Cpolar的注册3.3 Cpolar云端设置3.4 Cpolar本地设置 4. 公网访问测试5. 结语 1. 前言 现在我们身边的只能设备越来越多&#xff…

文心一言 VS 讯飞星火 VS chatgpt (94)-- 算法导论9.2 3题

三、用go语言,给出 RANDOMIZED-SELECT的一个基于循环的版本。 文心一言: 在循环中实现 RANDOMIZED-SELECT 的一个版本可以如下所示。这个版本使用 Go 语言编写,利用随机化来选择一个元素,并在循环中不断地调整选择的元素&#x…

【探索Linux】—— 强大的命令行工具 P.8(进程地址空间)

阅读导航 前言一、内存空间分布二、什么是进程地址空间1. 概念2. 进程地址空间的组成 三、进程地址空间的设计原理1. 基本原理2. 虚拟地址空间 概念 大小和范围 作用 虚拟地址空间的优点 3. 页表 四、为什么要有地址空间五、总结温馨提示 前言 前面我们讲了C语言的基础知识&am…

HTTP协议(超级详细)

HTTP协议介绍 基本介绍: HTTP:超文本传输协议,是从万维网服务器传输超文本到本地浏览器的传送协议HTTP是一种应用层协议,是基于TCP/IP通信协议来传送数据的,其中 HTTP1.0、HTTP1.1、HTTP2.0 均为 TCP 实现&#xff0…

vue组件库开发,webpack打包,发布npm

做一个像elment-ui一样的vue组件库 那多好啊!这是我前几年就想做的 但webpack真的太难用,也许是我功力不够 今天看到一个视频,早上6-13点,终于实现了,呜呜 感谢视频的分享-来龙去脉-大家可以看这个视频:htt…

【C语言】【数据存储】用%u打印char类型?用char存128?

1.题目一&#xff1a; #include <stdio.h> int main() {char a -128;printf("%u\n",a);return 0; }%u 是打印无符号整型 解题逻辑&#xff1a; 1. 原反补互换&#xff0c;截断 -128 原码&#xff1a;10000000…10000000 补码&#xff1a;11111111…10000000…

uniapp项目实践总结(十六)自定义下拉刷新组件

导语&#xff1a;在日常的开发过程中&#xff0c;我们经常遇到下拉刷新的场景&#xff0c;很方便的刷新游览的内容&#xff0c;在此我也实现了一个下拉刷新的自定义组件。 目录 准备工作原理分析组件实现实战演练内置刷新案例展示 准备工作 在components新建一个q-pull文件夹…

LVS负载均衡群集(NAT模式、IP隧道模式、DR模式)

目录 一、集群 1.1 含义即特点 1.2 群集的类型 1.3 LVS 的三种工作模式&#xff1a; 1.4 LVS 调度算法 1.5 负载均衡群集的结构 1.6 ipvsadm 工具 二、NAT模式 LVS-NAT模式配置步骤&#xff1a; 实例&#xff1a; 配置NFS服务器192.168.20.100 配置web1服务器192.168…

C++11线程库简介

前言 在c11之前涉及多线程的问题都是和平台相关的&#xff0c;比如windows和linux都有一套自己的接口&#xff0c;这使得代码的可移植性变差。C11中最重要的特性就是对线程进行了支持&#xff0c;使得C在编程时不再依赖第三方库&#xff0c;而且原子操作中还引入了原子类的概念…

Java计算机毕业设计 基于SpringBoot+Vue的毕业生信息招聘平台的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍&#xff1a;✌从事软件开发10年之余&#xff0c;专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精…

网络请求【小程序】

一、get 二、post 1.获取相应数据 Page({/*** 页面的初始数据*/data: { inptValue:, isArr:[]},/*** 生命周期函数--监听页面加载*/onLoad(options) {},onSubmit(){// console.log(this.data.inptValue)//2.后台请求数据wx.request({url: https://tea.qingnian8.com/demoArt/…

CentOS7安装MySQL

文章目录 前言一、MySQL5.71.1 安装wget1.2 下载&安装MySQL的rpm源1.3 修改MySQL安装版本1.4 下载并启动MySQL1.5 开启MySQL远程连接用户 二、MySQL8.0注意事项 前言 CentOS7的安装&#xff0c;采用的是yum的方式安装。 yum方式安装&#xff0c;就类似在Windows下不停的下…

flink时间处理语义

背景 在flink中有两种不同的时间处理语义&#xff0c;一种是基于算子处理时间的时间&#xff0c;也就是以flink的算子所在的机器的本地时间为准&#xff0c;一种是事件发生的实际时间&#xff0c;它只与事件发生时的时间有关&#xff0c;而与flink算子的所在的本地机器的本地时…

机器学习——决策树/随机森林

0、前言&#xff1a; 决策树可以做分类也可以做回归&#xff0c;决策树容易过拟合决策树算法的基本原理是依据信息学熵的概念设计的&#xff08;Logistic回归和贝叶斯是基于概率论&#xff09;&#xff0c;熵最早起源于物理学&#xff0c;在信息学当中表示不确定性的度量&…

WorkPlus | 好用、专业、安全的局域网即时通讯及协同办公平台

自国家于2022年发布的《关于加强数字政府建设的指导意见》以来&#xff0c;我国数字政府建设已经迈入了一个全新的里程碑&#xff0c;迎来了全面改革和深化升级的全新阶段。 WorkPlus作为自主可控、可信安全、专属定制的数字化平台&#xff0c;扮演着政务机关、政府单位以及各…

JDK19特性

文章目录 JAVA19概述1. 记录模式(预览版本)2.Linux/RISC-V 移植3.外部函数和内存 API &#xff08;预览版&#xff09;4.虚拟线程(预览版)5.Vector API &#xff08;第四次孵化&#xff09;6.Switch 模式匹配&#xff08;第三预览版&#xff09;7.结构化并发&#xff08;孵化阶…

【算法专题突破】滑动窗口 - 串联所有单词的子串(15)

目录 1. 题目解析 2. 算法原理 3. 代码编写 写在最后&#xff1a; 1. 题目解析 题目链接&#xff1a;30. 串联所有单词的子串 - 力扣&#xff08;LeetCode&#xff09; 这道题其实也很好理解&#xff0c;看一下示例就基本知道是什么意思了&#xff0c; 主要就是找 s 里面…