jax可微分编程的笔记(4)
第四章 JAX的微分运算
我们从最小二乘法说起,构建出深度学习的轮廓,并最终基于
MINST手写数据集训练了一个简单的全连接的神经网络.
4.1 微分操作的语法
本节给出复杂积分的运算和隐函数求导这两个例子.
4.1.1 JAX中的梯度操作
from jax imprt grad
def f(x,y):
return (x+y)**2
df1=grad(f,argnums=0)
x,y=1.0,2.0
print(f(x,y))
print(df1(x,y))
4.1.2 JAX中的雅可比矩阵
在JAX中,雅可比—矢量乘法对应的函数是jvp,它接受一个函数
fun,输入变量的主值primals以及输入变量的切值tangents,
返回函数作用后的输出变量的主值和切值。
4.1.3 JAX中的黑塞矩阵
例如,在截断牛顿CG算法中,我们可以利用黑塞矩阵寻找
凸函数的极小值,抑或是探索神经网络在某点处的曲率。
黑塞矩阵的获得的示例代码如下:
from jax import jacfwd,jacrev
def hessian1(f): return jacfwd(jacfwd(f))
def hessian2(f): return jacfwd(jacrev(f))
def hessian3(f): return jacrev(jacfwd(f))
def hessian4(f): return jacrev(jacrev(f))
上述代码中4个函数的结果相同,但是性能上第2个最好。
4.1.4 自定义算符及隐函数求导
例如 f(x)=ln(1+exp(x)) 它的导函数是f'(x)=exp(x)/(1+exp(x)),
在x的值比较大时,我们期待函数的导函数趋于1,但是由于这里存在
大数相除的情况,程序容易由于数值上溢或者是数值不稳定而产生
错误,为此,我们把导函数变形为 f'(x)=1-1/(1+exp(x))
4.2 梯度下降
梯度下降算法是自动微分在实际问题中最为成功的应用之一。
它在深度学习中的地位是本质性的。
4.2.1 从最小二乘法说开去
最小二乘法的本质,是通过待定函数中的参量以撑开一个函数的空
间,并在这个计算机可以表示的狭窄空间中,寻找并确定一组最优
的参数。全连接神经网络是最小二乘法在高维问题上的推广。
最小二乘法通常用于数据的线性拟合。
4.2.2 寻找极小值
假设有一个标量的函数f(x),函数在某点增长最快的方向由
delta f给出。我们找极小值,只要按公式 更新x即可。
公式为 第n+1 个 x = 第 n 个 x + a delta f(第n 个 x)
上述公式是梯度下降算法最简单的版本,参数a也称为学习率,
也称为步长,一个公式的重要性往往与它的长度成反比,
这个公式道出了几乎是一切优化算法的核心所在。
在实际的程序设计中,像学习率a这种参数的选取,往往更多
依赖程序调试者的经验。区别模型参数x,像学习率a这样在优化
过程开始前预先设定的参数,也被称为模型的超参数。从本质上
说,模型参数决定了模型本身的状态,而模型的超参数则确定了
选取的模型。
除了依赖经验,人们也针对超参数的优化问题发展了一系列的算法
例如,网络搜索,随机搜索,贝叶斯优化等。
4.2.3 训练及误差
由于待定的参数过少,无法准确地描述原本数据的规律,这种情况
称为欠拟合。当拟合出的多项式只是准确地经过了每一个数据点,
却没有理解数据背后的规律,这样得到的模型,在训练数据集表现
很好,在测试数据集上表现很差,这种情况被称为过拟合。
参数数目小于数据数目被称为参数化不足,参数数目大于数据数目
的情况被称为过参数化。
参数化程度仅仅描述了模型中参数的数目,而拟合与过拟合则描述了
模型的泛化能力。
过拟合函数,在数据点的边缘处出现了猛烈的震荡行为,在数值分析中
这也称为龙格现象。
4.2.4 全连接神经网络
我们通过softmax函数将矢量的分量值映射到[0,1]之间。
在损失函数中,使用交叉熵作为距离函数的定义,来加快
梯度更新速度。每一次更新选取的样本数目称为批大小,也是超参数。