公式介绍
在【CS285】中提到了高斯策略对数概率公式的公式如下:
log π θ ( a t ∣ s t ) = − 1 2 ∥ f ( s t ) − a t ∥ Σ 2 + const \log \pi_{\theta}(\mathbf{a}_t | \mathbf{s}_t) = -\frac{1}{2} \left\| f(\mathbf{s}_t) - \mathbf{a}_t \right\|_{\Sigma}^2 + \text{const} logπθ(at∣st)=−21∥f(st)−at∥Σ2+const
符号说明
PDF:Probability Density Function,概率密度函数
推导说明(from DeepSeek-R1-web)
1. 多元高斯分布的概率密度函数
设策略 π θ ( a t ∣ s t ) \pi_{\theta}(\mathbf{a}_t | \mathbf{s}_t) πθ(at∣st)服从均值为 μ = f ( s t ) \mu= f(\mathbf{s}_t) μ=f(st)、协方差矩阵为 Σ \Sigma Σ的多元高斯分布,其PDF为:
π θ ( a t ∣ s t ) = 1 ( 2 π ) d ∣ Σ ∣ exp ( − 1 2 ( a t − f ( s t ) ) T Σ − 1 ( a t − f ( s t ) ) ) \pi_{\theta}(\mathbf{a}_t | \mathbf{s}_t) = \frac{1}{\sqrt{(2\pi)^{d} |\Sigma|}} \exp\left( -\frac{1}{2} (\mathbf{a}_t - f(\mathbf{s}_t))^T \Sigma^{-1} (\mathbf{a}_t - f(\mathbf{s}_t)) \right) πθ(at∣st)=(2π)d∣Σ∣1exp(−21(at−f(st))TΣ−1(at−f(st)))
其中 d d d是动作 a t \mathbf{a}_t at的维度。
2. 对PDF取对数
对上述公式取自然对数,得到对数概率:
log π θ ( a t ∣ s t ) = − 1 2 ( a t − f ( s t ) ) T Σ − 1 ( a t − f ( s t ) ) − d 2 log ( 2 π ) − 1 2 log ∣ Σ ∣ \log \pi_{\theta}(\mathbf{a}_t | \mathbf{s}_t) = -\frac{1}{2} (\mathbf{a}_t - f(\mathbf{s}_t))^T \Sigma^{-1} (\mathbf{a}_t - f(\mathbf{s}_t)) - \frac{d}{2} \log(2\pi) - \frac{1}{2} \log|\Sigma| logπθ(at∣st)=−21(at−f(st))TΣ−1(at−f(st))−2dlog(2π)−21log∣Σ∣
3. 简化与假设
在强化学习中,通常假设:
- 协方差矩阵 Σ \Sigma Σ 是固定的(例如,设为对角矩阵或常数矩阵),或者与参数 θ \theta θ 无关。
- 常数项对梯度更新无影响:在对策略梯度进行优化时,与 θ \theta θ 无关的项在求导后会消失,因此可以合并为常数。
基于上述假设,将对数概率中的常数项合并:
const = − d 2 log ( 2 π ) − 1 2 log ∣ Σ ∣ \text{const} = -\frac{d}{2} \log(2\pi) - \frac{1}{2} \log|\Sigma| const=−2dlog(2π)−21log∣Σ∣
4. 引入马氏距离符号
定义马氏距离(Mahalanobis distance)为:
∥ a t − f ( s t ) ∥ Σ 2 = ( a t − f ( s t ) ) T Σ − 1 ( a t − f ( s t ) ) \left\| \mathbf{a}_t - f(\mathbf{s}_t) \right\|_{\Sigma}^2 = (\mathbf{a}_t - f(\mathbf{s}_t))^T \Sigma^{-1} (\mathbf{a}_t - f(\mathbf{s}_t)) ∥at−f(st)∥Σ2=(at−f(st))TΣ−1(at−f(st))
代入对数概率公式,得到:
log π θ ( a t ∣ s t ) = − 1 2 ∥ f ( s t ) − a t ∥ Σ 2 + const \log \pi_{\theta}(\mathbf{a}_t | \mathbf{s}_t) = -\frac{1}{2} \left\| f(\mathbf{s}_t) - \mathbf{a}_t \right\|_{\Sigma}^2 + \text{const} logπθ(at∣st)=−21∥f(st)−at∥Σ2+const
最终公式
课程中给出的公式正是上述结果,其中:
- f ( s t ) f(\mathbf{s}_t) f(st) 是由参数 θ \theta θ 确定的均值函数(例如神经网络)。
- Σ \Sigma Σ 是固定的协方差矩阵,与 θ \theta θ 无关。
- const \text{const} const 包含所有与 θ \theta θ 无关的常数项。
补充说明
- 协方差矩阵的简化:如果 Σ \Sigma Σ 是对角矩阵或各向同性( Σ = σ 2 I \Sigma = \sigma^2 I Σ=σ2I),则计算马氏距离时只需对每个维度单独计算平方误差。
- 策略梯度的应用:在计算策略梯度 ∇ θ log π θ \nabla_\theta \log \pi_{\theta} ∇θlogπθ 时,常数项 const \text{const} const 的导数为零,因此可以安全忽略。
结论:通过假设协方差矩阵 Σ \Sigma Σ 固定且与参数 θ \theta θ 无关,课程中的公式从多元高斯分布的对数概率密度函数中合理推导得出。