Double DQN算法
问题
DQN 算法通过贪婪法直接获得目标 Q 值,贪婪法通过最大化方式使 Q 值快速向可能的优化目标收敛,但易导致过估计Q 值的问题,使模型具有较大的偏差。
即:
对于DQN模型, 损失函数使用的
Q(state) = reward + Q(nextState)max
Q(state)
由训练网络生成, Q(nextState)max
由目标网络生成
这种损失函数会存在问题,即当Q(nextState)max
总是大于0时,那么Q(state)
总是在不停的增大,同时Q(nextState)max
也在不断的增大, 即Q(state)
存在被高估的情况。
作者采用 Double DQN 算法解耦动作的选择和目标 Q 值的计算,以解决过估计 Q 值的问题。
Double DQN 原理
Double DQN 算法结构如下。在 Double DQN 框架中存在两个神经网络模型,分别是训练网络与目标网络。这两个神经网络模型的结构完全相同,但是权重参数不同;每训练一段之间后,训练网络的权重参数才会复制给目标网络。训练时,训练网络用于估计当前的 ,而目标网络用于估计 ,这样就能保证真实值 的估计不会随着训练网络的不断自更新而变化过快。此外,DQN 还是一种支持离线学习的框架,即通过构建经验池的方式离线学习过去的经验。将均方误差 MSE(Q_{train}, Q_{target}) 作为训练模型的损失函数,通过梯度下降法进行反向传播,对训练模型进行更新;若干轮经验池采样后,再将训练模型的权重赋给目标模型,以此进行 Double DQN 框架下的模型自学习。
目标 Q 值的计算公式如下所示:
y j = r j + γ max a ′ Q ( s j + 1 , a ′ ; θ ′ ) y_j=r_j+\gamma \max _{a^{\prime}} Q\left(s_{j+1}, a^{\prime} ; \theta^{\prime}\right) yj=rj+γa′maxQ(sj+1,a′;θ′)
Double DQN 算法不直接通过最大化的方式选取目标网络计算的所有可能 Q Q Q 值,而是首先通过估计网络选取最大 Q Q Q 值对应的动作,公式表示如下:
a max = argmax a Q ( s t + 1 , a ; θ ) a_{\max }=\operatorname{argmax}_a Q\left(s_{t+1}, a ; \theta\right) amax=argmaxaQ(st+1,a;θ)
然后目标网络根据 a max a_{\max } amax 计算目标 Q 值,公式表示如下:
y j = r j + γ Q ( s j + 1 , a max ; θ ′ ) y_j=r_j+\gamma Q\left(s_{j+1}, a_{\max } ; \theta^{\prime}\right) yj=rj+γQ(sj+1,amax;θ′)
最后将上面两个公式结合,目标 Q Q Q 值的最终表示形式如下:
y j = r j + γ Q ( s j + 1 , argmax a Q ( s t + 1 , a ; θ ) ; θ ′ ) y_j=r_j+\gamma Q\left(s_{j+1}, \operatorname{argmax}_a Q\left(s_{t+1, a ; \theta}\right) ; \theta^{\prime}\right) yj=rj+γQ(sj+1,argmaxaQ(st+1,a;θ);θ′)
目标是最小化目标函数,即最小化估计 Q Q Q 值和目标 Q Q Q 值的差值,公式如下:
δ = ∣ Q ( s t , a t ) − y t ∣ = ∣ Q ( s t , a t ; θ ) − ( r t + γ Q ( S t + 1 , argmax a Q ( s t + 1 , a ; θ ) ; θ ′ ) ) ∣ \begin{aligned} & \delta=\left|Q\left(s_t, a_t\right)-y_t\right|=\mid Q\left(s_t, a_t ; \theta\right)-\left(r_t+\right. \\ & \left.\gamma Q\left(S_{t+1}, \operatorname{argmax}_a Q\left(s_{t+1}, a ; \theta\right) ; \theta^{\prime}\right)\right) \mid \end{aligned} δ=∣Q(st,at)−yt∣=∣Q(st,at;θ)−(rt+γQ(St+1,argmaxaQ(st+1,a;θ);θ′))∣
结合目标函数,损失函数定义如下:
loss = { 1 2 δ 2 for ∣ δ ∣ ⩽ 1 ∣ δ ∣ − 1 2 otherwize } \text { loss }=\left\{\begin{array}{cl} \frac{1}{2} \delta^2 & \text { for }|\delta| \leqslant 1 \\ |\delta|-\frac{1}{2} & \text { otherwize } \end{array}\right\} loss ={21δ2∣δ∣−21 for ∣δ∣⩽1 otherwize }
代码
- 游戏环境
import gym#定义环境
class MyWrapper(gym.Wrapper):def __init__(self):env = gym.make('CartPole-v1', render_mode='rgb_array')super().__init__(env)self.env = envself.step_n = 0def reset(self):state, _ = self.env.reset()self.step_n = 0return statedef step(self, action):state, reward, terminated, truncated, info = self.env.step(action)over = terminated or truncated#限制最大步数self.step_n += 1if self.step_n >= 200:over = True#没坚持到最后,扣分if over and self.step_n < 200:reward = -1000return state, reward, over#打印游戏图像def show(self):from matplotlib import pyplot as pltplt.figure(figsize=(3, 3))plt.imshow(self.env.render())plt.show()env = MyWrapper()env.reset()env.show()
- Q价值函数
import torch#定义模型,评估状态下每个动作的价值
model = torch.nn.Sequential(torch.nn.Linear(4, 64),torch.nn.ReLU(),torch.nn.Linear(64, 64),torch.nn.ReLU(),torch.nn.Linear(64, 2),
)#延迟更新的模型,用于计算target
model_delay = torch.nn.Sequential(torch.nn.Linear(4, 64),torch.nn.ReLU(),torch.nn.Linear(64, 64),torch.nn.ReLU(),torch.nn.Linear(64, 2),
)#复制参数
model_delay.load_state_dict(model.state_dict())model, model_delay
- 单条轨迹
from IPython import display
import random#玩一局游戏并记录数据
def play(show=False):data = []reward_sum = 0state = env.reset()over = Falsewhile not over:action = model(torch.FloatTensor(state).reshape(1, 4)).argmax().item()if random.random() < 0.1:action = env.action_space.sample()next_state, reward, over = env.step(action)data.append((state, action, reward, next_state, over))reward_sum += rewardstate = next_stateif show:display.clear_output(wait=True)env.show()return data, reward_sumplay()[-1]
- 经验池
#数据池
class Pool:def __init__(self):self.pool = []def __len__(self):return len(self.pool)def __getitem__(self, i):return self.pool[i]#更新动作池def update(self):#每次更新不少于N条新数据old_len = len(self.pool)while len(pool) - old_len < 200:self.pool.extend(play()[0])#只保留最新的N条数据self.pool = self.pool[-2_0000:]#获取一批数据样本def sample(self):data = random.sample(self.pool, 64)state = torch.FloatTensor([i[0] for i in data]).reshape(-1, 4)action = torch.LongTensor([i[1] for i in data]).reshape(-1, 1)reward = torch.FloatTensor([i[2] for i in data]).reshape(-1, 1)next_state = torch.FloatTensor([i[3] for i in data]).reshape(-1, 4)over = torch.LongTensor([i[4] for i in data]).reshape(-1, 1)return state, action, reward, next_state, overpool = Pool()
pool.update()
pool.sample()len(pool), pool[0]
- 训练
#训练
def train():model.train()optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)loss_fn = torch.nn.MSELoss()#共更新N轮数据for epoch in range(1000):pool.update()#每次更新数据后,训练N次for i in range(200):#采样N条数据state, action, reward, next_state, over = pool.sample()#计算valuevalue = model(state).gather(dim=1, index=action)#计算targetwith torch.no_grad():target = model_delay(next_state)target = target.max(dim=1)[0].reshape(-1, 1)target = target * 0.99 * (1 - over) + rewardloss = loss_fn(value, target)loss.backward()optimizer.step()optimizer.zero_grad()#复制参数if (epoch + 1) % 5 == 0:model_delay.load_state_dict(model.state_dict())if epoch % 100 == 0:test_result = sum([play()[-1] for _ in range(20)]) / 20print(epoch, len(pool), test_result)train()