The Gaussian Error Linear Unit (GELU) used in GPT
flyfish
Gaussian Error Linear Unit (GELU)
$$\text{GELU}(x) = x\,P(X \leq x) = x\,\Phi(x) = x \cdot \frac{1}{2}\left[1 + \operatorname{erf}\left(x/\sqrt{2}\right)\right], \qquad X \sim \mathcal{N}(0, 1)$$
where $\Phi(x)$ is the standard Gaussian cumulative distribution function.
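As a quick numerical check of $x\,\Phi(x)$, a couple of sample values can be computed with SciPy's standard normal CDF (a minimal sketch; assumes scipy is installed):

from scipy.stats import norm

# GELU(x) = x * Phi(x): a positive input passes through almost unchanged,
# a negative one is shrunk toward zero instead of being clipped like ReLU
for x in (1.0, -1.0):
    print(x, x * norm.cdf(x))   # ≈ 0.8413 for x = 1.0, ≈ -0.1587 for x = -1.0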
Approximation 1:
$$0.5\,x\left(1+\tanh\left[\sqrt{2/\pi}\left(x + 0.044715\,x^{3}\right)\right]\right)$$
Approximation 2:
$$x\,\sigma\left(1.702\,x\right)$$
where $\sigma$ is the logistic sigmoid function.
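The sigmoid form needs no error function at all; it can be written in a few lines of NumPy (a minimal sketch, function names are illustrative):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gelu_sigmoid(x):
    # Coarser approximation: x * sigmoid(1.702 * x)
    return x * sigmoid(1.702 * x)

print(gelu_sigmoid(1.0))   # ≈ 0.846, close to the exact value ≈ 0.841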
$$\text{gelu}(x) = \frac{1}{2}\,x\left(1 + \operatorname{erf}\left(\frac{x}{\sqrt{2}}\right)\right)$$
NumPy implementation (tanh approximation)
import numpy as np
import matplotlib.pyplot as plt

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))

# Create an array of values from -10 to 10
x = np.linspace(-10, 10, 1000)

# Compute the GELU values
y_gelu = gelu(x)

# Plot the GELU function
plt.plot(x, y_gelu, label='GELU')
plt.title('GELU Activation Function')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.grid(True)
plt.show()
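To see how close the tanh approximation is to the exact erf-based definition, a short check can be appended to the script above (a sketch; assumes scipy is installed and reuses the gelu function and x array defined above):

from scipy.special import erf

def gelu_exact(x):
    # Exact GELU: x * Phi(x), with Phi written via the error function
    return 0.5 * x * (1.0 + erf(x / np.sqrt(2.0)))

# Maximum absolute deviation of the tanh approximation on [-10, 10]
print(np.max(np.abs(gelu(x) - gelu_exact(x))))   # small, well below 1e-2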
PyTorch implementation
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

def gelu1(x):
    # Built-in GELU module (exact, erf-based by default)
    m = nn.GELU()
    return m(x)

def gelu2(x):
    # The same exact GELU written out explicitly: x * Phi(x)
    cdf = 0.5 * (1.0 + torch.erf(x / np.sqrt(2.0)))
    return x * cdf

# Create an array of values from -10 to 10
x = torch.linspace(-10, 10, 1000)

# Compute the GELU values
y_gelu1 = gelu1(x)
y_gelu2 = gelu2(x)
# Plot both GELU curves
plt.plot(x, y_gelu1, label='GELU1')
plt.plot(x, y_gelu2, label='GELU2')
plt.title('GELU Activation Function')
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.grid(True)
plt.show()
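Recent PyTorch releases (1.12 and later) also expose the tanh approximation directly through the approximate argument of nn.GELU. A quick consistency check that could be appended to the script above (a sketch; assumes such a PyTorch version is installed and reuses x, y_gelu1, and y_gelu2 from above):

# The default nn.GELU() is the exact erf-based form, so it should match gelu2
print(torch.allclose(y_gelu1, y_gelu2))              # expected: True

# Built-in tanh approximation (PyTorch >= 1.12)
m_tanh = nn.GELU(approximate='tanh')
print(torch.max(torch.abs(y_gelu1 - m_tanh(x))))     # small approximation error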