本系列文章板块规划
提示:以下内容仅为个人学习感悟,无法保证完全的正确和权威,大家酌情食用谢谢。
第一部分 torchdiffeq背后的数理逻辑
第二部分 torchdiffeq的基本用法
第三部分 trochdiffeq的升级用法
第四部分 torchdifffeq的案例和代码解析
第五部分 总结
第二部分的参考网站:https://github.com/rtqichen/torchdiffeq
torchdiffeq的基本用法
代码解释
我们来看看官网文档中给出的解释,其翻译为:
“torchdiffeq是一个在PyTorch中实现的常微分方程(ODE)求解器,使用伴随方法支持通过ODE解进行反向传播,可以以恒定的内存成本进行。”
这里的ODE解进行反向传播是指,在神经网络中,ODE被用作网络的一部分时,可以通过ODE的解对于网络进行训练。比如,我们使用神经网络对连续动态系统进行建模,ODE会用来描述系统状态随时间的变化。网络的一部分输出因此将是ODE的解,我们通过这些揭解对于网络进行训练。
关于这里提到的伴随方法(adjoint method),上一部分提到,ODE的解是通过数值方法得到的(欧拉,龙格库塔等等,不是解方程解出来的,是通过算法逼近的),他会储存一些相关梯度,而不需要储存训练过程中的所有状态,因此可以减少内存使用。Pytorch中就写了能够执行ODE解协助的神经网络,和加入了伴随方法节省空间的算法。
那接下来我们来看看,如何使用代码。
torchdiffeq安装
因此,在运行相关的代码时,首先需要安装这个库:
pip install torchdiffeq
如果安装最新版,从Github上的代码仓库安装。这是 torchdiffeq 库的GitHub仓库URL,它指向了库的源代码所在的位置:
pip install git+https://github.com/rtqichen/torchdiffeq
torchdiffeq基本用法
这个库提供了一个主要的接口 odeint,它包含了用于解决初始值问题(Initial Value Problems, IVP)的通用算法,并且对所有主要参数都实现了梯度计算。初始值问题由一个常微分方程(ODE)和一个初始值组成:
dy/dt = f(t, y) y(t_0) = y_0.
这部分就是是我在第一部分提到的,ODE的基本表达式。我们需要一个方程表达变化,一个初值为迭代的开始。常微分方程求解器通过初始条件,找到满足ODE的连续轨迹。
使用默认求解器解决一个初始值问题的代码如下:
from torchdiffeq import odeintodeint(func, y0, t)
- 从torchdiffeq 导入 odeint 函数,作为求解接口。
- 定义ODE函数func,代表了我们要求解的常微分方程。
- 初始条件:y0,即在初试时间t_0时刻的函数值y(t_0)
- 时间向量t,记为想要计算解的时间点,
- 然后调用求解。
官方文档特别推荐使用伴随方法。
from torchdiffeq import odeint_adjoint as odeintodeint(func, y0, t)
伴随方法只是简单地围绕 odeint 进行了封装,但是在反向调用中解决伴随ODE时,它将只使用 O(1) 的内存,但是在使用时,我们需要注意func必须是一个nn.Moudle。
torchdiffeq高级用法
可以基于事件停止ODE的求解。
from torchdiffeq import odeint_event
odeint_event(func, y0, t0, *, event_fn, reverse_time=False, odeint_interface=odeint, **kwargs)
参数 | 类型/默认值 | 描述 |
---|---|---|
func | 必需 | 代表ODE系统的函数,形式为 func(t, y),它定义了如何根据时间 t 和当前状态 y 计算导数 dy/dt。 |
y0 | 必需 | 表示ODE系统在初始时间 t0 的初始状态的张量。 |
t0 | 必需 | 标量,表示初始时间的值。 |
event_fn | 必需 | 关键字参数。一个函数,形式为 event_fn(t, y),它返回一个张量。当此张量的任意元素为零时,求解将终止。可以返回多个值来定义多个事件。 |
reverse_time | bool / False | 指定是否在反向时间中求解ODE。如果为 True,求解器将从 t0 开始向过去求解。默认为 False。 |
odeint_interface | odeint 或 odeint_adjoint | 指定用于通过ODE求解进行微分的模式。odeint 表示直接模式,odeint_adjoint 表示伴随模式。默认是 odeint。 |
**kwargs | - | 传递给 odeint_interface 的任何剩余关键字参数,可以用来设置求解器选项,如容差、最大步数等。 |
atol | float | (通常作为 **kwargs 之一)绝对容差参数,控制事件检测的数值精度。 |
options | 算法名 | Adaptive-step,fixed-step等等相应的算法 |