之前使用深度学习时一直对各种激活函数和权重初始化策略信手拈用,然而不能只知其表不知其里。若想深入理解为何选择某种激活函数和权重初始化方法卓有成效还是得回归本源,本文就从反向传播的计算过程来按图索骥。
为了更好地演示深度学习中的前向传播和反向传播,有必要图文结合,先按下面这个计算图造些数据。
这是一个输入只有单个样本、包含两个特征,两个隐藏层、分别带有2个神经元,以及一个输出的三层全连接神经网络。
输入和权重
输入 I n p u t Input Input (每行表示一个样本,每列表示一个特征)
X = [ x 1 , x 2 ] = [ 1 , − 1 ] X=[x_1,x_2]=[1,-1] X=[x1,x2]=[1,−1]
标签 y = [ 1 ] y=[1] y=[1]
权重 W W W (每列对应一个神经元,行数等于样本特征数)
W 1 = [ w 1 w 3 w 2 w 4 ] = [ 1 − 1 − 2 1 ] \begin{align} W_1 & = \begin{bmatrix} w_1 & w_3 \\ w_2 & w_4 \\ \end{bmatrix} \hspace{100cm} \\ & = \begin{bmatrix} 1 & -1 \\ -2 & 1 \\ \end{bmatrix} \end{align} W1=[w1w2w3w4]=[1−2−11]
W 2 = [ w 5 w 7 w 6 w 8 ] = [ 2 − 2 − 1 − 1 ] \begin{align} W_2 & = \begin{bmatrix} w_5 & w_7 \\ w_6 & w_8 \\ \end{bmatrix} \hspace{100cm} \\ & = \begin{bmatrix} 2 & -2 \\ -1 & -1 \\ \end{bmatrix} \end{align} W2=[w5w6w7w8]=[2−1−2−1]
W 3 = [ w 9 w 11 w 10 w 12 ] = [ 3 − 1 − 1 4 ] \begin{align} W_3 & = \begin{bmatrix} w_9 & w_{11} \\ w_{10} & w_{12} \\ \end{bmatrix} \hspace{100cm} \\ & = \begin{bmatrix} 3 & -1 \\ -1 & 4 \\ \end{bmatrix} \end{align} W3=[w9w10w11w12]=[3−1−14]
偏置 b b b (长度等于神经元数量)
b 1 = [ b 11 , b 12 ] = [ 1 , 0 ] b_1=[b_{11},b_{12}]=[1,0] b1=[b11,b12]=[1,0]
b 2 = [ b 21 , b 22 ] = [ 0 , 0 ] b_2=[b_{21},b_{22}]=[0,0] b2=[b21,b22]=[0,0]
b 3 = [ − 2 ] b_3=[-2] b3=[−2]
前向传播过程
前向传播就是从输入经隐藏层到输出层的计算过程。
从输入到第一个隐藏层的计算
z 1 = w 1 ⋅ x 1 + w 2 ⋅ x 2 + b 11 = 4 z_1=w_1 · x_1 + w_2 · x_2 + b_{11}=4 z1=w1⋅x1+w2⋅x2+b11=4
z 2 = w 3 ⋅ x 1 + w 4 ⋅ x 2 + b 12 = − 2 z_2=w_3 · x_1 + w_4 · x_2 + b_{12}=-2 z2=w3⋅x1+w4⋅x2+b12=−2
a 11 = σ ( z 1 ) = 0.9820 a_{11}=\sigma(z_1)=0.9820 a11=σ(z1)=0.9820
a 12 = σ ( z 2 ) = 0.1192 a_{12}=\sigma(z_2)=0.1192 a12=σ(z2)=0.1192
其中, σ = s i g m o i d = 1 1 + e − x \sigma=sigmoid={1 \over{1+e^{-x}}} σ=sigmoid=1+e−x1 ,其导数为 σ ′ = s i g m o i d ∗ ( 1 − s i g m o i d ) = 1 1 + e − x − 1 ( 1 + e − x ) 2 \sigma'=sigmoid * (1 - sigmoid)={1 \over{1+e^{-x}}}-{1 \over{(1+e^{-x}})^2} σ′=sigmoid∗(1−sigmoid)=1+e−x1−(1+e−x)21
隐藏层 H 1 = [ a 11 , a 12 ] H_1=[a_{11},a_{12}] H1=[a11,a12] ,作为第二个隐藏层的输入。
从第一个隐藏层到第二个隐藏层的计算
z 3 = w 5 ⋅ a 11 + w 6 ⋅ a 12 + b 21 = 1.8448 z_3=w_5 · a_{11} + w_6 · a_{12} + b_{21}=1.8448 z3=w5⋅a11+w6⋅a12+b21=1.8448
z 4 = w 7 ⋅ a 11 + w 8 ⋅ a 12 + b 22 = − 2.0832 z_4=w_7 · a_{11} + w_8 · a_{12} + b_{22}=-2.0832 z4=w7⋅a11+w8⋅a12+b22=−2.0832
a 21 = σ ( z 3 ) = 0.8635 a_{21}=\sigma(z_3)=0.8635 a21=σ(z3)=0.8635
a 22 = σ ( z 4 ) = 0.1107 a_{22}=\sigma(z_4)=0.1107 a22=σ(z4)=0.1107
隐藏层 H 2 = [ a 21 , a 22 ] H_2=[a_{21},a_{22}] H2=[a21,a22] ,作为输出层的输入。
从第二个隐藏层到输出层的计算
y ^ = w 9 ⋅ a 21 + w 10 ⋅ a 22 + b 3 = 0.4798 \hat{y}=w_9 · a_{21} + w_{10} · a_{22} + b_{3}=0.4798 y^=w9⋅a21+w10⋅a22+b3=0.4798
一个样本的损失: L = ( y ^ − y ) 2 = y ^ 2 + y 2 − 2 y ^ y = 0.2706 L=(\hat{y}-y)^2=\hat{y}^2+y^2-2\hat{y}y=0.2706 L=(y^−y)2=y^2+y2−2y^y=0.2706
计算结果如下:
反向传播过程
以求 w 1 w_1 w1 的偏导数为例,其他可仿照之,利用链式法则计算梯度。
∂ L ∂ w 1 = ∂ z 1 ∂ w 1 ∂ L ∂ z 1 = x 1 ∂ L ∂ z 1 ( 1 ) \begin{align} {\partial L \over \partial w_1} & = {\partial z_1 \over \partial w_1} {\partial L \over \partial z_1} \hspace{100cm} \\ &=x_1 {\partial L \over \partial z_1} \ \ \ \ \ (1) \end{align} ∂w1∂L=∂w1∂z1∂z1∂L=x1∂z1∂L (1)
∂ L ∂ w 1 = ∂ z 1 ∂ w 1 ∂ a 11 ∂ z 1 ∂ L ∂ a 11 = x 1 σ ′ ( z 1 ) ∂ L ∂ a 11 ( 2 ) \begin{align} {\partial L \over \partial w_1} & = {\partial z_1 \over \partial w_1} {\partial a_{11} \over \partial z_1} {\partial L \over \partial a_{11}} \hspace{100cm} \\ &=x_1 \sigma'(z_1) {\partial L \over \partial a_{11}} \ \ \ \ \ (2) \end{align} ∂w1∂L=∂w1∂z1∂z1∂a11∂a11∂L=x1σ′(z1)∂a11∂L (2)
∂ L ∂ w 1 = ∂ z 1 ∂ w 1 ∂ a 11 ∂ z 1 ( ∂ z 3 ∂ a 11 ∂ L ∂ z 3 + ∂ z 4 ∂ a 11 ∂ L ∂ z 4 ) = x 1 σ ′ ( z 1 ) [ w 5 ∂ L ∂ z 3 + w 7 ∂ L ∂ z 4 ] ( 3 ) \begin{align} {\partial L \over \partial w_1} & = {\partial z_1 \over \partial w_1} {\partial a_{11} \over \partial z_1} ({\partial z_3 \over \partial a_{11}} {\partial L \over \partial z_{3}} + {\partial z_4 \over \partial a_{11}} {\partial L \over \partial z_{4}}) \hspace{100cm} \\ &=x_1 \sigma'(z_1) [w_5 {\partial L \over \partial z_{3}} + w_7 {\partial L \over \partial z_{4}}] \ \ \ \ \ (3) \end{align} ∂w1∂L=∂w1∂z1∂z1∂a11(∂a11∂z3∂z3∂L+∂a11∂z4∂z4∂L)=x1σ′(z1)[w5∂z3∂L+w7∂z4∂L] (3)
∂ L ∂ w 1 = ∂ z 1 ∂ w 1 ∂ a 11 ∂ z 1 ( ∂ z 3 ∂ a 11 ∂ a 21 ∂ z 3 ∂ L ∂ a 21 + ∂ z 4 ∂ a 11 ∂ a 22 ∂ z 4 ∂ L ∂ a 22 ) = x 1 σ ′ ( z 1 ) [ w 5 σ ′ ( z 3 ) ∂ L ∂ a 21 + w 7 σ ′ ( z 4 ) ∂ L ∂ a 22 ] ( 4 ) \begin{align} {\partial L \over \partial w_1} & = {\partial z_1 \over \partial w_1} {\partial a_{11} \over \partial z_1} ({\partial z_3 \over \partial a_{11}} {\partial a_{21} \over \partial z_{3}} {\partial L \over \partial a_{21}} + {\partial z_4 \over \partial a_{11}} {\partial a_{22} \over \partial z_{4}} {\partial L \over \partial a_{22}}) \hspace{100cm} \\ &=x_1 \sigma'(z_1) [w_5 \sigma'(z_3) {\partial L \over \partial a_{21}} + w_7 \sigma'(z_4) {\partial L \over \partial a_{22}}] \ \ \ \ \ (4) \end{align} ∂w1∂L=∂w1∂z1∂z1∂a11(∂a11∂z3∂z3∂a21∂a21∂L+∂a11∂z4∂z4∂a22∂a22∂L)=x1σ′(z1)[w5σ′(z3)∂a21∂L+w7σ′(z4)∂a22∂L] (4)
∂ L ∂ w 1 = ∂ z 1 ∂ w 1 ∂ a 11 ∂ z 1 ( ∂ z 3 ∂ a 11 ∂ a 21 ∂ z 3 ∂ y ^ ∂ a 21 ∂ L ∂ y ^ + ∂ z 4 ∂ a 11 ∂ a 22 ∂ z 4 ∂ y ^ ∂ a 22 ∂ L ∂ y ^ ) = x 1 σ ′ ( z 1 ) [ w 5 σ ′ ( z 3 ) w 9 ∂ L ∂ y ^ + w 7 σ ′ ( z 4 ) w 10 ∂ L ∂ y ^ ] ( 5 ) \begin{align} {\partial L \over \partial w_1} & = {\partial z_1 \over \partial w_1} {\partial a_{11} \over \partial z_1} ({\partial z_3 \over \partial a_{11}} {\partial a_{21} \over \partial z_{3}} {\partial \hat{y} \over \partial a_{21}} {\partial L \over \partial \hat{y}} + {\partial z_4 \over \partial a_{11}} {\partial a_{22} \over \partial z_{4}} {\partial \hat{y} \over \partial a_{22}} {\partial L \over \partial \hat{y}}) \hspace{100cm} \\ &=x_1 \sigma'(z_1) [w_5 \sigma'(z_3) w_9 {\partial L \over \partial \hat{y}} + w_7 \sigma'(z_4) w_{10} {\partial L \over \partial \hat{y}}] \ \ \ \ \ (5) \end{align} ∂w1∂L=∂w1∂z1∂z1∂a11(∂a11∂z3∂z3∂a21∂a21∂y^∂y^∂L+∂a11∂z4∂z4∂a22∂a22∂y^∂y^∂L)=x1σ′(z1)[w5σ′(z3)w9∂y^∂L+w7σ′(z4)w10∂y^∂L] (5)
∂ L ∂ w 1 = x 1 σ ′ ( z 1 ) [ w 5 σ ′ ( z 3 ) w 9 ∂ L ∂ y ^ + w 7 σ ′ ( z 4 ) w 10 ∂ L ∂ y ^ ] = 1 ∗ 0.0177 ∗ [ 2 ∗ 0.1179 ∗ 3 ∗ ( 2 y ^ − 2 y ) + ( − 2 ∗ 0.0985 ∗ − 1 ∗ ( 2 y ^ − 2 y ) ) ] = − 0.0166 ( 6 ) \begin{align} {\partial L \over \partial w_1} & = x_1 \sigma'(z_1) [w_5 \sigma'(z_3) w_9 {\partial L \over \partial \hat{y}} + w_7 \sigma'(z_4) w_{10} {\partial L \over \partial \hat{y}}] \hspace{100cm} \\ &=1*0.0177*[2*0.1179*3*(2 \hat{y}-2y) + (-2*0.0985*-1*(2 \hat{y}-2y))] \\ &=-0.0166 \ \ \ \ \ (6) \end{align} ∂w1∂L=x1σ′(z1)[w5σ′(z3)w9∂y^∂L+w7σ′(z4)w10∂y^∂L]=1∗0.0177∗[2∗0.1179∗3∗(2y^−2y)+(−2∗0.0985∗−1∗(2y^−2y))]=−0.0166 (6)
与pytorch计算结果相同。
import torch
from torch import nn#输入与权重
X=torch.tensor([[1.0,-1.0]])
y=torch.tensor([1.0])
W1=torch.tensor([[1.0,-1.0],[-2.0,1.0]],requires_grad=True)
b1=torch.tensor([1.0,0.0],requires_grad=True)
W2=torch.tensor([[2.0,-2.0],[-1.0,-1.0]],requires_grad=True)
b2=torch.tensor([0.0,0.0],requires_grad=True)
W3=torch.tensor([[3.0],[-1.0]],requires_grad=True)
b3=torch.tensor([-2.0],requires_grad=True)#隐藏层1
z1=torch.matmul(X,W1)+b1
a1=torch.sigmoid(z1) #隐藏层2
z2=torch.matmul(a1,W2)+b2
a2=torch.sigmoid(z2) #输出层
y_hat=torch.matmul(a2,W3)+b3#损失函数
loss=nn.MSELoss(reduction='none')#计算损失
L=loss(y_hat,y).sum()
L.backward()
print(W1.grad)
要想求 ∂ L ∂ w 1 {\partial L \over \partial w_1} ∂w1∂L ,我们先看式 ( 1 ) (1) (1) , ∂ z 1 ∂ w 1 {\partial z_1 \over \partial w_1} ∂w1∂z1 是可以立刻得出的,因为它就是 w 1 w_1 w1 前面连接的输入的值。实际上对于任何权重,其偏导数 ∂ w {\partial w} ∂w 表达式的第一项都是可以通过其连接的输入立刻获得(即利用前向传播过程中存储的每个神经元的中间结果),比如对于靠后的 w 9 w_9 w9 ,其输入为 a 21 a_{21} a21 ,展开得:
a 21 = σ [ w 5 ⋅ σ ( w 1 ⋅ x 1 + w 2 ⋅ x 2 + b 11 ) + w 6 ⋅ σ ( w 3 ⋅ x 1 + w 4 ⋅ x 2 + b 12 ) + b 21 ] a_{21}=\sigma[w_5 · \sigma(w_1 · x_1 + w_2 · x_2 + b_{11}) + w_6 · \sigma(w_3 · x_1 + w_4 · x_2 + b_{12}) + b_{21}] a21=σ[w5⋅σ(w1⋅x1+w2⋅x2+b11)+w6⋅σ(w3⋅x1+w4⋅x2+b12)+b21]
a 21 a_{21} a21 是 σ ( W 2 σ ( W 1 X + b 1 ) + b 2 ) ( 7 ) \sigma(W_2 \sigma(W_1X+b1)+b2) \ \ \ \ \ (7) σ(W2σ(W1X+b1)+b2) (7) 结果其中之一。
可以看出,每一部分都会经激活函数,而对于 s i g m o i d sigmoid sigmoid 激活函数来说,第一项的计算可能会是无穷小,因此可能会引发梯度消失问题,而使用Relu则可以 减轻困扰以往神经网络的梯度消失问题。
继续回到对 ∂ L ∂ w 1 {\partial L \over \partial w_1} ∂w1∂L 的讨论上。现在还要求 ∂ L ∂ z 1 {\partial L \over \partial z_1} ∂z1∂L ,那么 ∂ L ∂ z 1 {\partial L \over \partial z_1} ∂z1∂L 如何求解呢?这就是反向传播要解决的问题了。
我们再回看一下式 ( 2 ) − ( 5 ) (2)-(5) (2)−(5) 中的 ∂ L ∂ z 1 {\partial L \over \partial z_1} ∂z1∂L ,列示如下:
∂ L ∂ z 1 = σ ′ ( z 1 ) ∂ L ∂ a 11 {\partial L \over \partial z_1} = \sigma'(z_1) {\partial L \over \partial a_{11}} ∂z1∂L=σ′(z1)∂a11∂L
∂ L ∂ z 1 = σ ′ ( z 1 ) [ w 5 ∂ L ∂ z 3 + w 7 ∂ L ∂ z 4 ] {\partial L \over \partial z_1} = \sigma'(z_1) [w_5 {\partial L \over \partial z_{3}} + w_7 {\partial L \over \partial z_{4}}] ∂z1∂L=σ′(z1)[w5∂z3∂L+w7∂z4∂L]
∂ L ∂ z 1 = σ ′ ( z 1 ) [ w 5 σ ′ ( z 3 ) ∂ L ∂ a 21 + w 7 σ ′ ( z 4 ) ∂ L ∂ a 22 ] {\partial L \over \partial z_1} = \sigma'(z_1) [w_5 \sigma'(z_3) {\partial L \over \partial a_{21}} + w_7 \sigma'(z_4) {\partial L \over \partial a_{22}}] ∂z1∂L=σ′(z1)[w5σ′(z3)∂a21∂L+w7σ′(z4)∂a22∂L]
∂ L ∂ z 1 = σ ′ ( z 1 ) [ w 5 σ ′ ( z 3 ) w 9 ∂ L ∂ y ^ + w 7 σ ′ ( z 4 ) w 10 ∂ L ∂ y ^ ] {\partial L \over \partial z_1} = \sigma'(z_1) [w_5 \sigma'(z_3) w_9 {\partial L \over \partial \hat{y}} + w_7 \sigma'(z_4) w_{10} {\partial L \over \partial \hat{y}}] ∂z1∂L=σ′(z1)[w5σ′(z3)w9∂y^∂L+w7σ′(z4)w10∂y^∂L]
可以看出,从前往后计算 ∂ L ∂ z 1 {\partial L \over \partial z_1} ∂z1∂L 会不太容易,因为前面项总会依赖后面项的计算结果,所以得先一直往后计算。
但反过来就简单多了,我们可以从最后一项出发,对于最初的计算图,最后一项是输出值关于损失的导数 ∂ L ∂ y ^ {\partial L \over \partial \hat{y}} ∂y^∂L ,这个可以由确定的损失函数求得。
有了 ∂ L ∂ y ^ {\partial L \over \partial \hat{y}} ∂y^∂L ,可以通过 w 9 、 w 10 w_9、w_{10} w9、w10 求得 ∂ L ∂ a 21 、 ∂ L ∂ a 22 {\partial L \over \partial a_{21}}、 {\partial L \over \partial a_{22}} ∂a21∂L、∂a22∂L
有了 ∂ L ∂ a 21 、 ∂ L ∂ a 22 {\partial L \over \partial a_{21}}、 {\partial L \over \partial a_{22}} ∂a21∂L、∂a22∂L ,可以通过 w 5 、 w 7 w_5、w_7 w5、w7 求得 ∂ L ∂ a 11 {\partial L \over \partial a_{11}} ∂a11∂L (别忘了中间还要乘以一个 $\sigma’(z) $ , z z z 只是一个常量,也可以从前向传播存储的中间结果获得) 。
再回味一下上面这个从后往前的计算过程,是不是跟前向传播很相似?这就是梯度的反向传播!与前向传播的图示比对如下:
其中:
∂ L ∂ a 21 = w 9 ∂ L ∂ y ^ ( 8 ) {\partial L \over \partial a_{21}}=w_9 {\partial L \over \partial \hat{y}} \ \ \ \ \ (8) ∂a21∂L=w9∂y^∂L (8)
∂ L ∂ a 22 = w 10 ∂ L ∂ y ^ ( 9 ) {\partial L \over \partial a_{22}}=w_{10} {\partial L \over \partial \hat{y}} \ \ \ \ \ (9) ∂a22∂L=w10∂y^∂L (9)
∂ L ∂ a 11 = w 5 ∂ L ∂ z 3 + w 7 ∂ L ∂ z 4 ( 10 ) {\partial L \over \partial a_{11}}=w_5 {\partial L \over \partial z_{3}} + w_7 {\partial L \over \partial z_{4}} \ \ \ \ \ (10) ∂a11∂L=w5∂z3∂L+w7∂z4∂L (10)
∂ L ∂ z 3 = σ ′ ( z 3 ) ∂ L ∂ a 21 {\partial L \over \partial z_{3}}=\sigma'(z_3) {\partial L \over \partial a_{21}} ∂z3∂L=σ′(z3)∂a21∂L
∂ L ∂ z 4 = σ ′ ( z 4 ) ∂ L ∂ a 22 {\partial L \over \partial z_{4}}=\sigma'(z_4) {\partial L \over \partial a_{22}} ∂z4∂L=σ′(z4)∂a22∂L
∂ L ∂ z 1 = σ ′ ( z 1 ) ∂ L ∂ a 11 {\partial L \over \partial z_{1}}=\sigma'(z_1) {\partial L \over \partial a_{11}} ∂z1∂L=σ′(z1)∂a11∂L
这个计算过程和前向传播很类似(尤其是式 ( 10 ) (10) (10) ),所以称之为反向传播。
从式 ( 5 ) 、 ( 7 ) (5)、(7) (5)、(7) 可以看出,每个权重的偏导数都会涉及到一连串 w w w 与激活函数导数的乘积以及权重与输入的乘积,试想,如果没有一个良好初始化的权重,这么多 w w w 相乘很可能会引起梯度爆炸或梯度消失等参数不稳定问题。
比如方差为1的正态随机矩阵和一个初始权重矩阵相乘,会引起梯度爆炸:
W = torch.normal(0, 1, size=(5,5))
print('初始权重矩阵 \n',W)
for i in range(100):W = torch.matmul(W,torch.normal(0, 1, size=(5, 5)))print('100个矩阵相乘后 \n', W)