Llama改进之——SwiGLU激活函数

引言

今天介绍LLAMA模型引入的关于激活函数的改进——SwiGLU1,该激活函数取得了不错的效果,得到了广泛地应用。

SwiGLU是GLU的一种变体,其中包含了GLU和Swish激活函数。

GLU

GLU(Gated Linear Units,门控线性单元)2引入了两个不同的线性层,其中一个首先经过sigmoid函数,其结果将和另一个线性层的输出进行逐元素相乘作为最终的输出:
GLU ( x , W , V , b , c ) = σ ( x W + b ) ⊗ ( x V + c ) (1) \text{GLU}(x,W,V,b,c) = \sigma(xW+b) \otimes (xV+c) \tag 1 GLU(x,W,V,b,c)=σ(xW+b)(xV+c)(1)
这里 W , V W,V W,V以及 b , c b,c b,c分别是这两个线性层的参数; σ ( x W + b ) \sigma(xW+b) σ(xW+b)作为门控,控制 x V + c xV+c xV+c的输出。

这里使用 σ \sigma σ作为激活函数,修改改激活函数得到的变体通常能带来更好的性能表现,比如SwiGLU修改激活函数为Swish。我们来看下Swish激活函数。

Swish

Swish3激活函数的形式为:
Swish β ( x ) = x σ ( β x ) (2) \text{Swish}_\beta(x) = x \sigma(\beta x) \tag 2 Swishβ(x)=xσ(βx)(2)
其中 σ ( x ) \sigma(x) σ(x)是Sigmoid函数; β \beta β是一个可学习的参数。

可以通过下面的代码画出Swish激活函数在不同参数 β \beta β下的图像:

import numpy as np
import matplotlib.pyplot as pltdef swish(x, beta):return x / (1 + np.exp(-beta*x))x = np.linspace(-10, 10, 100)
betas = [0.1, 1.0, 10.0]plt.figure(figsize=(10, 6))for beta in betas:y = swish(x, beta)plt.plot(x, y, label=f'beta={beta}')plt.legend()
plt.title('Swish Activation Function')
plt.xlabel('x')
plt.ylabel('f(x)')
plt.grid(True)
plt.show()

image-20240428224729925

可以看到3,当 β \beta β趋近于 0 0 0时,Swish函数趋近于线性函数 y = x 2 y=x^2 y=x2;当 β \beta β趋近于无穷大时,Swish函数趋近于ReLU函数;当 β \beta β取值为 1 1 1时,Swish函数是光滑且非单调的,等价于参考4中介绍的SiLU。

Swish与ReLU之间最显著的区别是当 x < 0 x < 0 x<0时Swish的非单调“凸起”3

SwiGLU

如前文所述,将公式(1)中GLU的激活函数改为Swish即变成了所谓的SwiGLU激活函数1
SwiGLU ( x , W , V ) = Swish β ( x W ) ⊗ ( x V ) (3) \text{SwiGLU}(x,W,V) = \text{Swish}_\beta(xW) \otimes (xV) \tag{3} SwiGLU(x,W,V)=Swishβ(xW)(xV)(3)
这里省略了偏置项。

代码实现

参考LLaMA,全连接层使用带有SwiGLU激活函数的FFN(Position-wise Feed-Forward Network)的公式如下1
FFN SwiGLU ( x , W , V , W 2 ) = ( Swish 1 ( x W ) ⊗ x V ) W 2 (4) \text{FFN}_{\text{SwiGLU}}(\pmb x,W,V,W_2) = (\text{Swish}_1(\pmb xW) \otimes \pmb xV)W_2 \tag 4 FFNSwiGLU(x,W,V,W2)=(Swish1(xW)xV)W2(4)
这里的Swish函数可以被SiLU函数替代:
SiLU ( x ) = x σ ( x ) \text{SiLU}(\pmb x) = \pmb x \sigma(\pmb x) SiLU(x)=xσ(x)
即:
FFN SwiGLU ( x , W , V , W 2 ) = ( SiLU ( x W ) ⊗ x V ) W 2 (5) \text{FFN}_{\text{SwiGLU}}(\pmb x,W,V,W_2) = (\text{SiLU}(\pmb xW) \otimes \pmb xV)W_2 \tag 5 FFNSwiGLU(x,W,V,W2)=(SiLU(xW)xV)W2(5)

import torch
from torch import nn
import torch.nn.functional as Fclass FeedForward(nn.Module):def __init__(self, hidden_size: int, intermediate_size: int) -> None:super().__init__()self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)def forward(self, x: torch.Tensor) -> torch.Tensor:# x: (batch_size, seq_len, hidden_size)# w1(x) -> (batch_size, seq_len, intermediate_size)# w1(x) -> (batch_size, seq_len, intermediate_size)# w2(*) -> (batch_size, seq_len, hidden_size)return self.w2(F.silu(self.w1(x)) * self.w3(x))

这里w1,w2,w3分别对应公式(5)中的 W , W 2 , V W,W_2,V W,W2,V

注意维度,其中w1,w3x转换到维度intermediate_size,然后w2转换回hidden_size

参考


  1. [论文翻译]GLU Variants Improve Transformer ↩︎ ↩︎ ↩︎

  2. [论文笔记]Language Modeling with Gated Convolutional Networks ↩︎

  3. [论文笔记]SEARCHING FOR ACTIVATION FUNCTIONS ↩︎ ↩︎ ↩︎

  4. [论文笔记]GAUSSIAN ERROR LINEAR UNITS (GELUS) ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/6522.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

83、动态规划-打家劫舍

思路&#xff1a; 首先使用递归方式求出最优解。从每个房屋开始&#xff0c;分别考虑偷与不偷两种情况&#xff0c;然后递归地对后续的房屋做同样的决策。这种方法确保了可以找到在不触发警报的情况下可能的最高金额。 代码如下&#xff1a; public static int rob(int[] nu…

【C++】深入剖析C++中的lambda表达式包装器bind

目录 一、lambda表达式 1、引入 2、lambda表达式 3、lambda表达式语法 ​4、lambda 的底层逻辑 二、包装器 1、包装器的表达式 ​ 2、实例化多份 3、可调用对象类型 4、实操例题 三、bind 1、bind 的表达式 2、调整参数的位置 3、绑定参数 一、lambda表达式 1、引…

wpf线程中更新UI的4种方式

在wpf中&#xff0c;更新UI上面的数据&#xff0c;那是必经之路&#xff0c;搞不好&#xff0c;就是死锁&#xff0c;或者没反应&#xff0c;很多时候&#xff0c;都是嵌套的非常深导致的。但是更新UI的方式&#xff0c;有很多的种&#xff0c;不同的方式&#xff0c;表示的意思…

hadoop学习---基于Hive的教育平台数据仓库分析案例(一)

案例背景&#xff1a; 大数据技术的应用可以从海量的用户行为数据中进行挖掘分析&#xff0c;根据分析结果优化平台的服务质量&#xff0c;最终满足用户的需求。教育大数据分析平台项目就是将大数据技术应用于教育培训领域&#xff0c;为企业经营提供数据支撑。 案例数据产生流…

现代循环神经网络(GRU、LSTM)(Pytorch 14)

一 简介 前一章中我们介绍了循环神经网络的基础知识&#xff0c;这种网络 可以更好地处理序列数据。我们在文本数据上实现 了基于循环神经网络的语言模型&#xff0c;但是对于当今各种各样的序列学习问题&#xff0c;这些技术可能并不够用。 例如&#xff0c;循环神经网络在…

使用OpenCV实现图像平移

使用OpenCV实现图像平移 程序流程效果代码 程序流程 读取图像并获取其高度、宽度和通道数。定义平移量tx和ty&#xff0c;并创建平移矩阵M。使用cv2.warpAffine函数对图像进行仿射变换&#xff08;平移&#xff09;&#xff0c;得到平移后的图像。显示平移后的图像。等待用户按…

【副本向】Lua副本逻辑

副本生命周期 OnCopySceneTick() 子线程每次心跳调用 --副本心跳 function x3323_OnCopySceneTick(elapse)if x3323_g_IsPlayerEnter 0 thenreturn; -- 如果没人进入&#xff0c;则函数直接返回endif x3323_g_GameOver 1 thenif x3323_g_EndTick > 0 thenx3323_CountDown…

【SRC-Python】在数字与字母 / 中文与英文之间插入空格的自动化解决方案

文章目录 Part.I IntroductionPart.II 使用方法Chap.I 直接处理字符串Chap.II 处理文件 Part.III Source CodeReference Part.I Introduction 在编辑文本的过程中&#xff0c;尤其是在 COPY 的过程中&#xff0c;经常会遇到如下问题&#xff1a; 源文本数字与英文字母之间没有…

循环神经网络完整实现(Pytorch 13)

一 循环神经网络的从零开始实现 从头开始基于循环神经网络实现字符级语言模型。 %matplotlib inline import math import torch from torch import nn from torch.nn import functional as F from d2l import torch as d2lbatch_size, num_steps 32, 35 train_iter, vocab …

【AI】ONNX

长期更新&#xff0c;建议收藏关注&#xff01; 友情链接 Netron 开放神经网络交换&#xff08;Open Neural Network Exchange&#xff09;简称ONNX,是微软和Facebook提出用来表示深度学习模型的开放格式。所谓开放就是ONNX定义了一组和环境&#xff0c;平台均无关的标准格式…

ASP.NET IIS Express一定vs停止调试,就退出了,如何不退出

》》》 在项目右击属性&#xff0c;找到Web&#xff0c;把启用”编辑并继续“ 复选框 去掉

asp.net结课作业中遇到的问题解决2

目录 1、如何实现评论交流的界面 2、如果想要将文字添加到数据库中&#xff0c;而不是乱码&#xff0c;该怎么修改 3、如果想要添加的数据已经存在于数据库&#xff0c;就不允许添加了&#xff0c;该如何实现 4、想要实现某个模块下有好几个小的功能该如何实现 5、想要实现…

Altium Designer入门基础操作

软件下载环境搭建&#xff1a;pan.baidu.com/s/1HshgKTmkkBpbIRa-9Wq9cQ 密码&#xff1a;ckck 工程建立&#xff1a; 创建 库绘制 为什么管脚要100mil 元素10mil 原理图库得正确性报告 原理图页设置大小&#xff0c;标准自定义&#xff0c;格点为100mil 使用库画原理图&a…

08 IRF技术 华三交换机实现

IRF 详细介绍 我知道 AI IRF 技术是指集成路由功能(Integrated Routing and Bridging)技术,是惠普(Hewlett Packard)公司开发的一种基于硬件的虚拟化技术。IRF 技术可以将多台物理设备组合成一个逻辑设备,实现设备的高可用性和灵活性。 IRF 技术主要有以下特点: 1. …

MySQL-集群1

一、为什么要用mysql集群&#xff1f;&#xff1a; mysql单体架构在企业中很少用&#xff0c;原因&#xff1a;①会形成单点故障&#xff0c;没有高可用的效果&#xff1b;②mysql本身是一个I/O能力比较差&#xff0c;并发能力比较差的应用服务&#xff0c;在较高规模的网络I/…

【计算机网络】循环冗余校验:Cyclic Redundancy Check

1. 任务目标 利用循环冗余校验&#xff08;CRC&#xff09;检测错误。 循环冗余校验&#xff08;英语&#xff1a;Cyclic redundancy check&#xff0c;通称 CRC&#xff09;是一种根据网上数据包或计算机文件等数据产生简短固定位数校验码的一种散列函数&#xff0c;主要用来…

谈谈Tcpserver开启多线程并发处理遇到的问题!

最近在学习最基础的socket网络编程&#xff0c;在Tcpserver开启多线程并发处理时遇到了一些问题&#xff01; 说明 在linux以及Windows的共享文件夹进行编写的&#xff0c;所以代码中有的部分使用 #ifdef WIN64 ... #else ... #endif 进入正题&#xff01;&#xff01;&…

OSPF优化

OSPF的优化主要目的是为了减少LSA的更新量 路由汇总-----可以减少骨干区域的LSA数量 特殊区域-----可以减少非骨干区域的LSA数量 OSPF路由汇总 域间路由汇总 域间路由汇总在ABR设备上进行操作 [GS-R2-ospf-1-area-0.0.0.1]abr-summary 192.168.0.0 255.255.224.0 [GS-R3-o…

NEO 学习之session7

文章目录 选项 A&#xff1a;它涉及学习标记数据。 选项 B&#xff1a;它需要预定义的输出标签进行训练。 选项 C&#xff1a;它涉及在未标记的数据中寻找模式和关系。 选项 D&#xff1a;它专注于根据输入-输出对进行预测。 答案&#xff1a;选项 C 描述了无监督学习的本质&am…

服务器被攻击,为什么后台任务管理器无法打开?

在服务器遭受DDoS攻击后&#xff0c;当后台任务管理器由于系统资源耗尽无法打开时&#xff0c;管理员需要依赖间接手段来进行攻击类型的判断和解决措施的实施。由于涉及真实代码可能涉及到敏感操作&#xff0c;这里将以概念性伪代码和示例指令的方式来说明。 判断攻击类型 步…