Deep Reinforcement Learning with TensorFlow 2.0

In this tutorial, I will showcase the upcoming TensorFlow 2.0 features through deep reinforcement learning (DRL), by implementing an Advantage Actor-Critic (A2C) agent that solves the classic CartPole-v0 environment. While the goal is to showcase TensorFlow 2.0, I will do my best to make the DRL aspects approachable as well, including a brief overview of the field.

In fact, since the focus of the 2.0 release is on making developers' lives easier, I think it's a great time to get into DRL with TensorFlow: the full source code of the example used in this post is under 150 lines! The code is available here or here.

Setup

Since TensorFlow 2.0 is still at an experimental stage, I recommend installing it in a separate virtual environment. I personally prefer Anaconda, so I'll use it to demonstrate the installation:

<span style="color:#f8f8f2"><code class="language-python"><span style="color:#f8f8f2">></span> conda create <span style="color:#f8f8f2">-</span>n tf2 python<span style="color:#f8f8f2">=</span><span style="color:#ae81ff"><span style="color:#ae81ff">3.6</span></span>
<span style="color:#f8f8f2">></span> source activate tf2
<span style="color:#f8f8f2">></span> pip install tf<span style="color:#f8f8f2">-</span>nightly<span style="color:#ae81ff"><span style="color:#ae81ff">-2.0</span></span><span style="color:#f8f8f2">-</span>preview <span style="color:slategray"><span style="color:#75715e"># tf-nightly-gpu-2.0-preview for GPU version</span></span></code></span>

Let's quickly verify that everything works as expected:

<span style="color:#f8f8f2"><code class="language-python"><span style="color:#f8f8f2"><span style="color:#75715e">>></span></span><span style="color:#f8f8f2"><span style="color:#75715e">></span></span> <span style="color:#66d9ef"><span style="color:#f92672">import</span></span> tensorflow <span style="color:#66d9ef"><span style="color:#f92672">as</span></span> tf
<span style="color:#f8f8f2"><span style="color:#75715e">>></span></span><span style="color:#f8f8f2"><span style="color:#75715e">></span></span> <span style="color:#66d9ef">print</span><span style="color:#f8f8f2">(</span>tf<span style="color:#f8f8f2">.</span>__version__<span style="color:#f8f8f2">)</span>
<span style="color:#ae81ff"><span style="color:#ae81ff">1.13</span></span><span style="color:#f8f8f2"><span style="color:#ae81ff">.</span></span><span style="color:#ae81ff"><span style="color:#ae81ff">0</span></span><span style="color:#f8f8f2">-</span>dev20190117
<span style="color:#f8f8f2"><span style="color:#75715e">>></span></span><span style="color:#f8f8f2"><span style="color:#75715e">></span></span> <span style="color:#66d9ef">print</span><span style="color:#f8f8f2">(</span>tf<span style="color:#f8f8f2">.</span>executing_eagerly<span style="color:#f8f8f2">(</span><span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">)</span>
<span style="color:#ae81ff"><span style="color:#f92672">True</span></span></code></span>

Don't worry about the 1.13.x version number; it just means that this is an early preview. What's important to note here is that we're in eager mode by default!

<span style="color:#f8f8f2"><code class="language-none">>>> print(tf.reduce_sum([1, 2, 3, 4, 5]))
tf.Tensor(15, shape=(), dtype=int32)</code></span>

If you're not yet familiar with eager mode, it essentially means that computation is executed at runtime, rather than through a pre-compiled graph. You can find a good overview in the TensorFlow documentation.
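For contrast, here is roughly what the same computation looked like under graph mode in TensorFlow 1.x. This is only an illustrative sketch of the old workflow; it won't run as-is on the 2.0 preview installed above:

```python
import tensorflow as tf  # TensorFlow 1.x semantics

# building the op only adds a node to the default graph; no value is computed yet
total = tf.reduce_sum([1, 2, 3, 4, 5])

# the value only materializes when the graph is run inside a session
with tf.Session() as sess:
    print(sess.run(total))  # 15
```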

Deep Reinforcement Learning

Generally speaking, reinforcement learning is a high-level framework for solving sequential decision-making problems. An RL agent navigates an environment by taking actions based on some observations, receiving rewards in return. Most RL algorithms work by maximizing the sum of rewards the agent collects during an episode.

The output of an RL-based algorithm is typically a policy: a function that maps states to actions. A valid policy can be as simple as a hard-coded no-op action. A stochastic policy is represented as a conditional probability distribution over actions, given some state.
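As a tiny illustration (a sketch with made-up probabilities, not part of the agent built below): for CartPole, a hard-coded policy could always pick the same action, while a stochastic policy samples from a distribution that depends on the observed state:

```python
import numpy as np
import gym

env = gym.make('CartPole-v0')
obs = env.reset()

# hard-coded policy: always push the cart to the left, regardless of state
hardcoded_action = 0

# stochastic policy: a conditional distribution over actions given the state;
# the probabilities here are made up purely for illustration (obs[2] is the pole angle)
action_probs = np.array([0.3, 0.7]) if obs[2] > 0 else np.array([0.7, 0.3])
sampled_action = np.random.choice(env.action_space.n, p=action_probs)
```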

Actor-Critic Methods

RL algorithms are often grouped by the objective function they optimize. Value-based methods, such as DQN, work by reducing the error of the expected state-action values.

Policy gradient methods directly optimize the policy itself by adjusting its parameters, typically via gradient descent. Calculating the gradients fully is usually intractable, so they are estimated via Monte Carlo methods instead.
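As a rough, hypothetical sketch of the idea (not the A2C code that follows later in this post), a Monte Carlo policy-gradient estimate weights the log-probabilities of the actions that were actually taken by the returns collected after them:

```python
import tensorflow as tf

def pg_loss(logits, actions, returns):
    # -log pi(a|s) for the actions that were actually taken
    neg_log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=actions, logits=logits)
    # weighting by sampled returns gives a Monte Carlo estimate of the policy gradient;
    # minimizing this loss corresponds to gradient ascent on expected return
    return tf.reduce_mean(neg_log_prob * returns)
```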

The most popular approach is a hybrid of the two: actor-critic methods, where the agent's policy is optimized through policy gradients, while a value-based method is used as a bootstrap for the expected value estimates.

Deep Actor-Critic Methods

While much of the fundamental RL theory was developed for the tabular case, modern RL is almost exclusively done with function approximators, such as artificial neural networks. Specifically, an RL algorithm is considered "deep" if the policy and value functions are approximated with deep neural networks.

Asynchronous Advantage Actor-Critic

Over the years, a number of improvements have been added to address sample efficiency and the stability of the learning process.

First, gradients are weighted with returns: discounted future rewards. This somewhat alleviates the credit assignment problem and resolves theoretical issues with infinite timesteps.
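As a toy example of the discounting (the numbers are made up for illustration): with a discount factor of 0.99, an episode with rewards [1, 1, 1] produces the following discounted returns:

```python
gamma = 0.99
rewards = [1.0, 1.0, 1.0]

# G_t = r_t + gamma * G_{t+1}, computed backwards through the episode
returns, g = [], 0.0
for r in reversed(rewards):
    g = r + gamma * g
    returns.append(g)
returns.reverse()

print(returns)  # [2.9701, 1.99, 1.0]
```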

Second, an advantage function is used instead of raw returns. The advantage is formed as the difference between the returns and some baseline, and can be viewed as a measure of how good a given action is compared to some average.

Third, an additional entropy maximization term is used in the objective function to ensure the agent sufficiently explores various policies. In essence, entropy measures how random a probability distribution is; it is maximized by the uniform distribution.
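As a quick standalone sketch (not part of the tutorial's code), here is how the entropy of a categorical distribution behaves: it is highest for a uniform distribution and drops toward zero as the distribution becomes deterministic:

```python
import numpy as np

def entropy(p):
    p = np.asarray(p)
    # small constant avoids log(0) for (near-)deterministic distributions
    return -np.sum(p * np.log(p + 1e-12))

print(entropy([0.5, 0.5]))    # ~0.693, the maximum for two actions
print(entropy([0.99, 0.01]))  # ~0.056, nearly deterministic
```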

Finally, multiple workers are used in parallel to speed up sample gathering, while also helping decorrelate the samples during training.

Incorporating all of these changes with deep neural networks, we arrive at two of the most popular modern algorithms: the (asynchronous) advantage actor-critic algorithm, or A3C/A2C for short. The difference between the two is more technical than theoretical: as the name suggests, it boils down to how the parallel workers estimate their gradients and propagate them to the model.

With this, I'll conclude our tour of DRL methods, since the focus of this blog post is more on the TensorFlow 2.0 features. Don't worry if you're still unsure about the subject; the code examples should make things clearer. If you want to learn more, a good resource to get started with is Spinning Up in Deep RL.

Advantage Actor-Critic with TensorFlow 2.0

Let's see what it takes to implement the basis of a modern DRL algorithm: an actor-critic agent. As mentioned in the previous section, we won't implement parallel workers for simplicity's sake, although most of the code would support it; interested readers can use this as an exercise opportunity.

As a testbed we will use the CartPole-v0 environment. It's somewhat simplistic, but it's still a great option to get started with. I always rely on it as a sanity check when implementing RL algorithms.

Policy & Value NNs via the Keras Model API

First, let's create the policy and value estimation NNs under a single model class:

<span style="color:#f8f8f2"><code class="language-python"><span style="color:#66d9ef"><span style="color:#f92672">import</span></span> numpy <span style="color:#66d9ef"><span style="color:#f92672">as</span></span> np
<span style="color:#66d9ef"><span style="color:#f92672">import</span></span> tensorflow <span style="color:#66d9ef"><span style="color:#f92672">as</span></span> tf
<span style="color:#66d9ef"><span style="color:#f92672">import</span></span> tensorflow<span style="color:#f8f8f2">.</span>keras<span style="color:#f8f8f2">.</span>layers <span style="color:#66d9ef"><span style="color:#f92672">as</span></span> kl<span style="color:#66d9ef"><span style="color:#f92672">class</span></span> <span style="color:#f8f8f2">ProbabilityDistribution</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">tf</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">.</span></span><span style="color:#f8f8f2">keras</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">.</span></span><span style="color:#f8f8f2">Model</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span><span style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">call</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">self</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> logits</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span><span style="color:slategray"><span style="color:#75715e"># sample a random categorical action from given logits</span></span><span style="color:#66d9ef"><span style="color:#f92672">return</span></span> tf<span style="color:#f8f8f2">.</span>squeeze<span style="color:#f8f8f2">(</span>tf<span style="color:#f8f8f2">.</span>random<span style="color:#f8f8f2">.</span>categorical<span style="color:#f8f8f2">(</span>logits<span style="color:#f8f8f2">,</span> <span style="color:#ae81ff"><span style="color:#ae81ff">1</span></span><span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">,</span> axis<span style="color:#f8f8f2">=</span><span style="color:#f8f8f2"><span style="color:#ae81ff">-</span></span><span style="color:#ae81ff"><span style="color:#ae81ff">1</span></span><span style="color:#f8f8f2">)</span><span style="color:#66d9ef"><span style="color:#f92672">class</span></span> <span style="color:#f8f8f2">Model</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">tf</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">.</span></span><span style="color:#f8f8f2">keras</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">.</span></span><span style="color:#f8f8f2">Model</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span><span style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">__init__</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">self</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> num_actions</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span>super<span style="color:#f8f8f2">(</span><span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">.</span>__init__<span style="color:#f8f8f2">(</span><span style="color:#a6e22e"><span style="color:#e6db74">'mlp_policy'</span></span><span style="color:#f8f8f2">)</span><span style="color:slategray"><span style="color:#75715e"># no tf.get_variable(), just simple Keras API</span></span>self<span 
style="color:#f8f8f2">.</span>hidden1 <span style="color:#f8f8f2">=</span> kl<span style="color:#f8f8f2">.</span>Dense<span style="color:#f8f8f2">(</span><span style="color:#ae81ff"><span style="color:#ae81ff">128</span></span><span style="color:#f8f8f2">,</span> activation<span style="color:#f8f8f2">=</span><span style="color:#a6e22e"><span style="color:#e6db74">'relu'</span></span><span style="color:#f8f8f2">)</span>self<span style="color:#f8f8f2">.</span>hidden2 <span style="color:#f8f8f2">=</span> kl<span style="color:#f8f8f2">.</span>Dense<span style="color:#f8f8f2">(</span><span style="color:#ae81ff"><span style="color:#ae81ff">128</span></span><span style="color:#f8f8f2">,</span> activation<span style="color:#f8f8f2">=</span><span style="color:#a6e22e"><span style="color:#e6db74">'relu'</span></span><span style="color:#f8f8f2">)</span>self<span style="color:#f8f8f2">.</span>value <span style="color:#f8f8f2">=</span> kl<span style="color:#f8f8f2">.</span>Dense<span style="color:#f8f8f2">(</span><span style="color:#ae81ff"><span style="color:#ae81ff">1</span></span><span style="color:#f8f8f2">,</span> name<span style="color:#f8f8f2">=</span><span style="color:#a6e22e"><span style="color:#e6db74">'value'</span></span><span style="color:#f8f8f2">)</span><span style="color:slategray"><span style="color:#75715e"># logits are unnormalized log probabilities</span></span>self<span style="color:#f8f8f2">.</span>logits <span style="color:#f8f8f2">=</span> kl<span style="color:#f8f8f2">.</span>Dense<span style="color:#f8f8f2">(</span>num_actions<span style="color:#f8f8f2">,</span> name<span style="color:#f8f8f2">=</span><span style="color:#a6e22e"><span style="color:#e6db74">'policy_logits'</span></span><span style="color:#f8f8f2">)</span>self<span style="color:#f8f8f2">.</span>dist <span style="color:#f8f8f2">=</span> ProbabilityDistribution<span style="color:#f8f8f2">(</span><span style="color:#f8f8f2">)</span><span style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">call</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">self</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> inputs</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span><span style="color:slategray"><span style="color:#75715e"># inputs is a numpy array, convert to Tensor</span></span>x <span style="color:#f8f8f2">=</span> tf<span style="color:#f8f8f2">.</span>convert_to_tensor<span style="color:#f8f8f2">(</span>inputs<span style="color:#f8f8f2">,</span> dtype<span style="color:#f8f8f2">=</span>tf<span style="color:#f8f8f2">.</span>float32<span style="color:#f8f8f2">)</span><span style="color:slategray"><span style="color:#75715e"># separate hidden layers from the same input tensor</span></span>hidden_logs <span style="color:#f8f8f2">=</span> self<span style="color:#f8f8f2">.</span>hidden1<span style="color:#f8f8f2">(</span>x<span style="color:#f8f8f2">)</span>hidden_vals <span style="color:#f8f8f2">=</span> self<span style="color:#f8f8f2">.</span>hidden2<span style="color:#f8f8f2">(</span>x<span style="color:#f8f8f2">)</span><span style="color:#66d9ef"><span style="color:#f92672">return</span></span> self<span style="color:#f8f8f2">.</span>logits<span style="color:#f8f8f2">(</span>hidden_logs<span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">,</span> self<span 
style="color:#f8f8f2">.</span>value<span style="color:#f8f8f2">(</span>hidden_vals<span style="color:#f8f8f2">)</span><span style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">action_value</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">self</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> obs</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span><span style="color:slategray"><span style="color:#75715e"># executes call() under the hood</span></span>logits<span style="color:#f8f8f2">,</span> value <span style="color:#f8f8f2">=</span> self<span style="color:#f8f8f2">.</span>predict<span style="color:#f8f8f2">(</span>obs<span style="color:#f8f8f2">)</span>action <span style="color:#f8f8f2">=</span> self<span style="color:#f8f8f2">.</span>dist<span style="color:#f8f8f2">.</span>predict<span style="color:#f8f8f2">(</span>logits<span style="color:#f8f8f2">)</span><span style="color:slategray"><span style="color:#75715e"># a simpler option, will become clear later why we don't use it</span></span><span style="color:slategray"><span style="color:#75715e"># action = tf.random.categorical(logits, 1)</span></span><span style="color:#66d9ef"><span style="color:#f92672">return</span></span> np<span style="color:#f8f8f2">.</span>squeeze<span style="color:#f8f8f2">(</span>action<span style="color:#f8f8f2">,</span> axis<span style="color:#f8f8f2">=</span><span style="color:#f8f8f2"><span style="color:#ae81ff">-</span></span><span style="color:#ae81ff"><span style="color:#ae81ff">1</span></span><span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">,</span> np<span style="color:#f8f8f2">.</span>squeeze<span style="color:#f8f8f2">(</span>value<span style="color:#f8f8f2">,</span> axis<span style="color:#f8f8f2">=</span><span style="color:#f8f8f2"><span style="color:#ae81ff">-</span></span><span style="color:#ae81ff"><span style="color:#ae81ff">1</span></span><span style="color:#f8f8f2">)</span></code></span>

Let's verify the model works as expected:

<span style="color:#f8f8f2"><code class="language-python"><span style="color:#66d9ef"><span style="color:#f92672">import</span></span> gym
env <span style="color:#f8f8f2">=</span> gym<span style="color:#f8f8f2">.</span>make<span style="color:#f8f8f2">(</span><span style="color:#a6e22e"><span style="color:#e6db74">'CartPole-v0'</span></span><span style="color:#f8f8f2">)</span>
model <span style="color:#f8f8f2">=</span> Model<span style="color:#f8f8f2">(</span>num_actions<span style="color:#f8f8f2">=</span>env<span style="color:#f8f8f2">.</span>action_space<span style="color:#f8f8f2">.</span>n<span style="color:#f8f8f2">)</span>
obs <span style="color:#f8f8f2">=</span> env<span style="color:#f8f8f2">.</span>reset<span style="color:#f8f8f2">(</span><span style="color:#f8f8f2">)</span>
<span style="color:slategray"><span style="color:#75715e"># no feed_dict or tf.Session() needed at all</span></span>
action<span style="color:#f8f8f2">,</span> value <span style="color:#f8f8f2">=</span> model<span style="color:#f8f8f2">.</span>action_value<span style="color:#f8f8f2">(</span>obs<span style="color:#f8f8f2">[</span><span style="color:#f92672">None</span><span style="color:#f8f8f2">,</span> <span style="color:#f8f8f2">:</span><span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">)</span>
<span style="color:#66d9ef">print</span><span style="color:#f8f8f2">(</span>action<span style="color:#f8f8f2">,</span> value<span style="color:#f8f8f2">)</span> <span style="color:slategray"><span style="color:#75715e"># [1] [-0.00145713]</span></span></code></span>

Things to note here:

  • model layers and the execution path are defined separately;
  • there is no "input" layer; the model accepts raw numpy arrays;
  • two computation paths can be defined in one model via the functional API;
  • a model can contain helper methods such as action sampling;
  • in eager mode everything works directly from raw numpy arrays.

Random Agent

Now we can move on to the fun stuff: the A2CAgent class. First, let's add a test method that runs through a full episode and returns the sum of rewards.

<span style="color:#f8f8f2"><code class="language-python"><span style="color:#66d9ef"><span style="color:#f92672">class</span></span> <span style="color:#f8f8f2">A2CAgent</span><span style="color:#f8f8f2">:</span><span style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">__init__</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">self</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> model</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span>self<span style="color:#f8f8f2">.</span>model <span style="color:#f8f8f2">=</span> model<span style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">test</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">self</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> env</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> render</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">=</span></span><span style="color:#ae81ff"><span style="color:#f8f8f2">True</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span>obs<span style="color:#f8f8f2">,</span> done<span style="color:#f8f8f2">,</span> ep_reward <span style="color:#f8f8f2">=</span> env<span style="color:#f8f8f2">.</span>reset<span style="color:#f8f8f2">(</span><span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">,</span> <span style="color:#ae81ff"><span style="color:#f92672">False</span></span><span style="color:#f8f8f2">,</span> <span style="color:#ae81ff"><span style="color:#ae81ff">0</span></span><span style="color:#66d9ef"><span style="color:#f92672">while</span></span> <span style="color:#f8f8f2"><span style="color:#f92672">not</span></span> done<span style="color:#f8f8f2">:</span>action<span style="color:#f8f8f2">,</span> _ <span style="color:#f8f8f2">=</span> self<span style="color:#f8f8f2">.</span>model<span style="color:#f8f8f2">.</span>action_value<span style="color:#f8f8f2">(</span>obs<span style="color:#f8f8f2">[</span><span style="color:#f92672">None</span><span style="color:#f8f8f2">,</span> <span style="color:#f8f8f2">:</span><span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">)</span>obs<span style="color:#f8f8f2">,</span> reward<span style="color:#f8f8f2">,</span> done<span style="color:#f8f8f2">,</span> _ <span style="color:#f8f8f2">=</span> env<span style="color:#f8f8f2">.</span>step<span style="color:#f8f8f2">(</span>action<span style="color:#f8f8f2">)</span>ep_reward <span style="color:#f8f8f2">+=</span> reward<span style="color:#66d9ef"><span style="color:#f92672">if</span></span> render<span style="color:#f8f8f2">:</span>env<span style="color:#f8f8f2">.</span>render<span style="color:#f8f8f2">(</span><span style="color:#f8f8f2">)</span><span style="color:#66d9ef"><span style="color:#f92672">return</span></span> ep_reward</code></span>

Let's see how much our model scores with randomly initialized weights:

<span style="color:#f8f8f2"><code class="language-python">agent <span style="color:#f8f8f2">=</span> A2CAgent<span style="color:#f8f8f2">(</span>model<span style="color:#f8f8f2">)</span>
rewards_sum <span style="color:#f8f8f2">=</span> agent<span style="color:#f8f8f2">.</span>test<span style="color:#f8f8f2">(</span>env<span style="color:#f8f8f2">)</span>
<span style="color:#66d9ef">print</span><span style="color:#f8f8f2">(</span><span style="color:#a6e22e"><span style="color:#e6db74">"%d out of 200"</span></span> <span style="color:#f8f8f2">%</span> rewards_sum<span style="color:#f8f8f2">)</span> <span style="color:slategray"><span style="color:#75715e"># 18 out of 200</span></span></code></span>

Quite far from optimal. Time for the training part!

Loss / Objective Functions

As I described in the DRL overview section, an agent improves its policy through gradient descent based on some loss (objective) function. In actor-critic we train on three objectives: improving the policy with advantage-weighted gradients plus entropy maximization, and minimizing the value estimation error.

<span style="color:#f8f8f2"><code class="language-python"><span style="color:#66d9ef"><span style="color:#f92672">import</span></span> tensorflow<span style="color:#f8f8f2">.</span>keras<span style="color:#f8f8f2">.</span>losses <span style="color:#66d9ef"><span style="color:#f92672">as</span></span> kls
<span style="color:#66d9ef"><span style="color:#f92672">import</span></span> tensorflow<span style="color:#f8f8f2">.</span>keras<span style="color:#f8f8f2">.</span>optimizers <span style="color:#66d9ef"><span style="color:#f92672">as</span></span> ko
<span style="color:#66d9ef"><span style="color:#f92672">class</span></span> <span style="color:#f8f8f2">A2CAgent</span><span style="color:#f8f8f2">:</span><span style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">__init__</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">self</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> model</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span><span style="color:slategray"><span style="color:#75715e"># hyperparameters for loss terms</span></span>self<span style="color:#f8f8f2">.</span>params <span style="color:#f8f8f2">=</span> <span style="color:#f8f8f2">{</span><span style="color:#a6e22e"><span style="color:#e6db74">'value'</span></span><span style="color:#f8f8f2">:</span> <span style="color:#ae81ff"><span style="color:#ae81ff">0.5</span></span><span style="color:#f8f8f2">,</span> <span style="color:#a6e22e"><span style="color:#e6db74">'entropy'</span></span><span style="color:#f8f8f2">:</span> <span style="color:#ae81ff"><span style="color:#ae81ff">0.0001</span></span><span style="color:#f8f8f2">}</span>self<span style="color:#f8f8f2">.</span>model <span style="color:#f8f8f2">=</span> modelself<span style="color:#f8f8f2">.</span>model<span style="color:#f8f8f2">.</span>compile<span style="color:#f8f8f2">(</span>optimizer<span style="color:#f8f8f2">=</span>ko<span style="color:#f8f8f2">.</span>RMSprop<span style="color:#f8f8f2">(</span>lr<span style="color:#f8f8f2">=</span><span style="color:#ae81ff"><span style="color:#ae81ff">0.0007</span></span><span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">,</span><span style="color:slategray"><span style="color:#75715e"># define separate losses for policy logits and value estimate</span></span>loss<span style="color:#f8f8f2">=</span><span style="color:#f8f8f2">[</span>self<span style="color:#f8f8f2">.</span>_logits_loss<span style="color:#f8f8f2">,</span> self<span style="color:#f8f8f2">.</span>_value_loss<span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">)</span><span style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">test</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">self</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> env</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> render</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">=</span></span><span style="color:#ae81ff"><span style="color:#f8f8f2">True</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span><span style="color:slategray"><span style="color:#75715e"># unchanged from previous section</span></span><span style="color:#f8f8f2">.</span><span style="color:#f8f8f2">.</span><span style="color:#f8f8f2">.</span><span style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">_value_loss</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">self</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span 
style="color:#f8f8f2"> returns</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> value</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span><span style="color:slategray"><span style="color:#75715e"># value loss is typically MSE between value estimates and returns</span></span><span style="color:#66d9ef"><span style="color:#f92672">return</span></span> self<span style="color:#f8f8f2">.</span>params<span style="color:#f8f8f2">[</span><span style="color:#a6e22e"><span style="color:#e6db74">'value'</span></span><span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">*</span>kls<span style="color:#f8f8f2">.</span>mean_squared_error<span style="color:#f8f8f2">(</span>returns<span style="color:#f8f8f2">,</span> value<span style="color:#f8f8f2">)</span><span style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">_logits_loss</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">self</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> acts_and_advs</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> logits</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span><span style="color:slategray"><span style="color:#75715e"># a trick to input actions and advantages through same API</span></span>actions<span style="color:#f8f8f2">,</span> advantages <span style="color:#f8f8f2">=</span> tf<span style="color:#f8f8f2">.</span>split<span style="color:#f8f8f2">(</span>acts_and_advs<span style="color:#f8f8f2">,</span> <span style="color:#ae81ff"><span style="color:#ae81ff">2</span></span><span style="color:#f8f8f2">,</span> axis<span style="color:#f8f8f2">=</span><span style="color:#f8f8f2"><span style="color:#ae81ff">-</span></span><span style="color:#ae81ff"><span style="color:#ae81ff">1</span></span><span style="color:#f8f8f2">)</span><span style="color:slategray"><span style="color:#75715e"># polymorphic CE loss function that supports sparse and weighted options</span></span><span style="color:slategray"><span style="color:#75715e"># from_logits argument ensures transformation into normalized probabilities</span></span>cross_entropy <span style="color:#f8f8f2">=</span> kls<span style="color:#f8f8f2">.</span>CategoricalCrossentropy<span style="color:#f8f8f2">(</span>from_logits<span style="color:#f8f8f2">=</span><span style="color:#ae81ff"><span style="color:#f92672">True</span></span><span style="color:#f8f8f2">)</span><span style="color:slategray"><span style="color:#75715e"># policy loss is defined by policy gradients, weighted by advantages</span></span><span style="color:slategray"><span style="color:#75715e"># note: we only calculate the loss on the actions we've actually taken</span></span><span style="color:slategray"><span style="color:#75715e"># thus under the hood a sparse version of CE loss will be executed</span></span>actions <span style="color:#f8f8f2">=</span> tf<span style="color:#f8f8f2">.</span>cast<span style="color:#f8f8f2">(</span>actions<span style="color:#f8f8f2">,</span> tf<span style="color:#f8f8f2">.</span>int32<span style="color:#f8f8f2">)</span>policy_loss <span style="color:#f8f8f2">=</span> cross_entropy<span style="color:#f8f8f2">(</span>actions<span 
style="color:#f8f8f2">,</span> logits<span style="color:#f8f8f2">,</span> sample_weight<span style="color:#f8f8f2">=</span>advantages<span style="color:#f8f8f2">)</span><span style="color:slategray"><span style="color:#75715e"># entropy loss can be calculated via CE over itself</span></span>entropy_loss <span style="color:#f8f8f2">=</span> cross_entropy<span style="color:#f8f8f2">(</span>logits<span style="color:#f8f8f2">,</span> logits<span style="color:#f8f8f2">)</span><span style="color:slategray"><span style="color:#75715e"># here signs are flipped because optimizer minimizes</span></span><span style="color:#66d9ef"><span style="color:#f92672">return</span></span> policy_loss <span style="color:#f8f8f2">-</span> self<span style="color:#f8f8f2">.</span>params<span style="color:#f8f8f2">[</span><span style="color:#a6e22e"><span style="color:#e6db74">'entropy'</span></span><span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">*</span>entropy_loss</code></span>

And we're done with the objective functions! Note how compact the code is: the comment lines almost outnumber the code itself.

Agent Training Loop

Finally, there's the training loop itself. It's relatively long, but fairly straightforward: collect samples, calculate returns and advantages, and train the model on them.

<span style="color:#f8f8f2"><code class="language-python"><span style="color:#66d9ef"><span style="color:#f92672">class</span></span> <span style="color:#f8f8f2">A2CAgent</span><span style="color:#f8f8f2">:</span><span style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">__init__</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">self</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> model</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span><span style="color:slategray"><span style="color:#75715e"># hyperparameters for loss terms</span></span>self<span style="color:#f8f8f2">.</span>params <span style="color:#f8f8f2">=</span> <span style="color:#f8f8f2">{</span><span style="color:#a6e22e"><span style="color:#e6db74">'value'</span></span><span style="color:#f8f8f2">:</span> <span style="color:#ae81ff"><span style="color:#ae81ff">0.5</span></span><span style="color:#f8f8f2">,</span> <span style="color:#a6e22e"><span style="color:#e6db74">'entropy'</span></span><span style="color:#f8f8f2">:</span> <span style="color:#ae81ff"><span style="color:#ae81ff">0.0001</span></span><span style="color:#f8f8f2">,</span> <span style="color:#a6e22e"><span style="color:#e6db74">'gamma'</span></span><span style="color:#f8f8f2">:</span> <span style="color:#ae81ff"><span style="color:#ae81ff">0.99</span></span><span style="color:#f8f8f2">}</span><span style="color:slategray"><span style="color:#75715e"># unchanged from previous section</span></span><span style="color:#f8f8f2">.</span><span style="color:#f8f8f2">.</span><span style="color:#f8f8f2">.</span><span style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">train</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">self</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> env</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> batch_sz</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">=</span></span><span style="color:#ae81ff"><span style="color:#f8f8f2"><span style="color:#ae81ff">32</span></span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> updates</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">=</span></span><span style="color:#ae81ff"><span style="color:#f8f8f2"><span style="color:#ae81ff">1000</span></span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span><span style="color:slategray"><span style="color:#75715e"># storage helpers for a single batch of data</span></span>actions <span style="color:#f8f8f2">=</span> np<span style="color:#f8f8f2">.</span>empty<span style="color:#f8f8f2">(</span><span style="color:#f8f8f2">(</span>batch_sz<span style="color:#f8f8f2">,</span><span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">,</span> dtype<span style="color:#f8f8f2">=</span>np<span style="color:#f8f8f2">.</span>int32<span style="color:#f8f8f2">)</span>rewards<span style="color:#f8f8f2">,</span> dones<span style="color:#f8f8f2">,</span> values <span style="color:#f8f8f2">=</span> np<span 
style="color:#f8f8f2">.</span>empty<span style="color:#f8f8f2">(</span><span style="color:#f8f8f2">(</span><span style="color:#ae81ff"><span style="color:#ae81ff">3</span></span><span style="color:#f8f8f2">,</span> batch_sz<span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">)</span>observations <span style="color:#f8f8f2">=</span> np<span style="color:#f8f8f2">.</span>empty<span style="color:#f8f8f2">(</span><span style="color:#f8f8f2">(</span>batch_sz<span style="color:#f8f8f2">,</span><span style="color:#f8f8f2">)</span> <span style="color:#f8f8f2">+</span> env<span style="color:#f8f8f2">.</span>observation_space<span style="color:#f8f8f2">.</span>shape<span style="color:#f8f8f2">)</span><span style="color:slategray"><span style="color:#75715e"># training loop: collect samples, send to optimizer, repeat updates times</span></span>ep_rews <span style="color:#f8f8f2">=</span> <span style="color:#f8f8f2">[</span><span style="color:#ae81ff"><span style="color:#ae81ff">0.0</span></span><span style="color:#f8f8f2">]</span>next_obs <span style="color:#f8f8f2">=</span> env<span style="color:#f8f8f2">.</span>reset<span style="color:#f8f8f2">(</span><span style="color:#f8f8f2">)</span><span style="color:#66d9ef"><span style="color:#f92672">for</span></span> update <span style="color:#66d9ef"><span style="color:#f92672">in</span></span> range<span style="color:#f8f8f2">(</span>updates<span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">:</span><span style="color:#66d9ef"><span style="color:#f92672">for</span></span> step <span style="color:#66d9ef"><span style="color:#f92672">in</span></span> range<span style="color:#f8f8f2">(</span>batch_sz<span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">:</span>observations<span style="color:#f8f8f2">[</span>step<span style="color:#f8f8f2">]</span> <span style="color:#f8f8f2">=</span> next_obs<span style="color:#f8f8f2">.</span>copy<span style="color:#f8f8f2">(</span><span style="color:#f8f8f2">)</span>actions<span style="color:#f8f8f2">[</span>step<span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">,</span> values<span style="color:#f8f8f2">[</span>step<span style="color:#f8f8f2">]</span> <span style="color:#f8f8f2">=</span> self<span style="color:#f8f8f2">.</span>model<span style="color:#f8f8f2">.</span>action_value<span style="color:#f8f8f2">(</span>next_obs<span style="color:#f8f8f2">[</span><span style="color:#f92672">None</span><span style="color:#f8f8f2">,</span> <span style="color:#f8f8f2">:</span><span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">)</span>next_obs<span style="color:#f8f8f2">,</span> rewards<span style="color:#f8f8f2">[</span>step<span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">,</span> dones<span style="color:#f8f8f2">[</span>step<span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">,</span> _ <span style="color:#f8f8f2">=</span> env<span style="color:#f8f8f2">.</span>step<span style="color:#f8f8f2">(</span>actions<span style="color:#f8f8f2">[</span>step<span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">)</span>ep_rews<span style="color:#f8f8f2">[</span><span style="color:#f8f8f2"><span style="color:#ae81ff">-</span></span><span style="color:#ae81ff"><span style="color:#ae81ff">1</span></span><span style="color:#f8f8f2">]</span> <span style="color:#f8f8f2">+=</span> rewards<span style="color:#f8f8f2">[</span>step<span style="color:#f8f8f2">]</span><span style="color:#66d9ef"><span style="color:#f92672">if</span></span> dones<span 
style="color:#f8f8f2">[</span>step<span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">:</span>ep_rews<span style="color:#f8f8f2">.</span>append<span style="color:#f8f8f2">(</span><span style="color:#ae81ff"><span style="color:#ae81ff">0.0</span></span><span style="color:#f8f8f2">)</span>next_obs <span style="color:#f8f8f2">=</span> env<span style="color:#f8f8f2">.</span>reset<span style="color:#f8f8f2">(</span><span style="color:#f8f8f2">)</span>_<span style="color:#f8f8f2">,</span> next_value <span style="color:#f8f8f2">=</span> self<span style="color:#f8f8f2">.</span>model<span style="color:#f8f8f2">.</span>action_value<span style="color:#f8f8f2">(</span>next_obs<span style="color:#f8f8f2">[</span><span style="color:#f92672">None</span><span style="color:#f8f8f2">,</span> <span style="color:#f8f8f2">:</span><span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">)</span>returns<span style="color:#f8f8f2">,</span> advs <span style="color:#f8f8f2">=</span> self<span style="color:#f8f8f2">.</span>_returns_advantages<span style="color:#f8f8f2">(</span>rewards<span style="color:#f8f8f2">,</span> dones<span style="color:#f8f8f2">,</span> values<span style="color:#f8f8f2">,</span> next_value<span style="color:#f8f8f2">)</span><span style="color:slategray"><span style="color:#75715e"># a trick to input actions and advantages through same API</span></span>acts_and_advs <span style="color:#f8f8f2">=</span> np<span style="color:#f8f8f2">.</span>concatenate<span style="color:#f8f8f2">(</span><span style="color:#f8f8f2">[</span>actions<span style="color:#f8f8f2">[</span><span style="color:#f8f8f2">:</span><span style="color:#f8f8f2">,</span> <span style="color:#f92672">None</span><span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">,</span> advs<span style="color:#f8f8f2">[</span><span style="color:#f8f8f2">:</span><span style="color:#f8f8f2">,</span> <span style="color:#f92672">None</span><span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">,</span> axis<span style="color:#f8f8f2">=</span><span style="color:#f8f8f2"><span style="color:#ae81ff">-</span></span><span style="color:#ae81ff"><span style="color:#ae81ff">1</span></span><span style="color:#f8f8f2">)</span><span style="color:slategray"><span style="color:#75715e"># performs a full training step on the collected batch</span></span><span style="color:slategray"><span style="color:#75715e"># note: no need to mess around with gradients, Keras API handles it</span></span>losses <span style="color:#f8f8f2">=</span> self<span style="color:#f8f8f2">.</span>model<span style="color:#f8f8f2">.</span>train_on_batch<span style="color:#f8f8f2">(</span>observations<span style="color:#f8f8f2">,</span> <span style="color:#f8f8f2">[</span>acts_and_advs<span style="color:#f8f8f2">,</span> returns<span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">)</span><span style="color:#66d9ef"><span style="color:#f92672">return</span></span> ep_rews<span style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">_returns_advantages</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">self</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> rewards</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> dones</span><span style="color:#f8f8f2"><span 
style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> values</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> next_value</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span><span style="color:slategray"><span style="color:#75715e"># next_value is the bootstrap value estimate of a future state (the critic)</span></span>returns <span style="color:#f8f8f2">=</span> np<span style="color:#f8f8f2">.</span>append<span style="color:#f8f8f2">(</span>np<span style="color:#f8f8f2">.</span>zeros_like<span style="color:#f8f8f2">(</span>rewards<span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">,</span> next_value<span style="color:#f8f8f2">,</span> axis<span style="color:#f8f8f2">=</span><span style="color:#f8f8f2"><span style="color:#ae81ff">-</span></span><span style="color:#ae81ff"><span style="color:#ae81ff">1</span></span><span style="color:#f8f8f2">)</span><span style="color:slategray"><span style="color:#75715e"># returns are calculated as discounted sum of future rewards</span></span><span style="color:#66d9ef"><span style="color:#f92672">for</span></span> t <span style="color:#66d9ef"><span style="color:#f92672">in</span></span> reversed<span style="color:#f8f8f2">(</span>range<span style="color:#f8f8f2">(</span>rewards<span style="color:#f8f8f2">.</span>shape<span style="color:#f8f8f2">[</span><span style="color:#ae81ff"><span style="color:#ae81ff">0</span></span><span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">:</span>returns<span style="color:#f8f8f2">[</span>t<span style="color:#f8f8f2">]</span> <span style="color:#f8f8f2">=</span> rewards<span style="color:#f8f8f2">[</span>t<span style="color:#f8f8f2">]</span> <span style="color:#f8f8f2">+</span> self<span style="color:#f8f8f2">.</span>params<span style="color:#f8f8f2">[</span><span style="color:#a6e22e"><span style="color:#e6db74">'gamma'</span></span><span style="color:#f8f8f2">]</span> <span style="color:#f8f8f2">*</span> returns<span style="color:#f8f8f2">[</span>t<span style="color:#f8f8f2">+</span><span style="color:#ae81ff"><span style="color:#ae81ff">1</span></span><span style="color:#f8f8f2">]</span> <span style="color:#f8f8f2">*</span> <span style="color:#f8f8f2">(</span><span style="color:#ae81ff"><span style="color:#ae81ff">1</span></span><span style="color:#f8f8f2">-</span>dones<span style="color:#f8f8f2">[</span>t<span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">)</span>returns <span style="color:#f8f8f2">=</span> returns<span style="color:#f8f8f2">[</span><span style="color:#f8f8f2">:</span><span style="color:#f8f8f2"><span style="color:#ae81ff">-</span></span><span style="color:#ae81ff"><span style="color:#ae81ff">1</span></span><span style="color:#f8f8f2">]</span><span style="color:slategray"><span style="color:#75715e"># advantages are returns - baseline, value estimates in our case</span></span>advantages <span style="color:#f8f8f2">=</span> returns <span style="color:#f8f8f2">-</span> values<span style="color:#66d9ef"><span style="color:#f92672">return</span></span> returns<span style="color:#f8f8f2">,</span> advantages<span style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">test</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">self</span><span 
style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> env</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> render</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">=</span></span><span style="color:#ae81ff"><span style="color:#f8f8f2">True</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span><span style="color:slategray"><span style="color:#75715e"># unchanged from previous section</span></span><span style="color:#f8f8f2">.</span><span style="color:#f8f8f2">.</span><span style="color:#f8f8f2">.</span><span style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">_value_loss</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">self</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> returns</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> value</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span><span style="color:slategray"><span style="color:#75715e"># unchanged from previous section</span></span><span style="color:#f8f8f2">.</span><span style="color:#f8f8f2">.</span><span style="color:#f8f8f2">.</span><span style="color:#66d9ef"><span style="color:#f92672">def</span></span> <span style="color:#e6db74"><span style="color:#a6e22e">_logits_loss</span></span><span style="color:#f8f8f2"><span style="color:#f8f8f2">(</span></span><span style="color:#f8f8f2">self</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> acts_and_advs</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">,</span></span><span style="color:#f8f8f2"> logits</span><span style="color:#f8f8f2"><span style="color:#f8f8f2">)</span></span><span style="color:#f8f8f2">:</span><span style="color:slategray"><span style="color:#75715e"># unchanged from previous section</span></span><span style="color:#f8f8f2">.</span><span style="color:#f8f8f2">.</span><span style="color:#f8f8f2">.</span></code></span>

Training & Results

We're now all set to train our single-worker A2C agent on CartPole-v0! The training process shouldn't take longer than a couple of minutes. After training is complete, you should see the agent successfully achieve the target score of 200 out of 200.

<span style="color:#f8f8f2"><code class="language-python">rewards_history <span style="color:#f8f8f2">=</span> agent<span style="color:#f8f8f2">.</span>train<span style="color:#f8f8f2">(</span>env<span style="color:#f8f8f2">)</span>
<span style="color:#66d9ef">print</span><span style="color:#f8f8f2">(</span><span style="color:#a6e22e"><span style="color:#e6db74">"Finished training, testing..."</span></span><span style="color:#f8f8f2">)</span>
<span style="color:#66d9ef">print</span><span style="color:#f8f8f2">(</span><span style="color:#a6e22e"><span style="color:#e6db74">"%d out of 200"</span></span> <span style="color:#f8f8f2">%</span> agent<span style="color:#f8f8f2">.</span>test<span style="color:#f8f8f2">(</span>env<span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">)</span> <span style="color:slategray"><span style="color:#75715e"># 200 out of 200</span></span></code></span>

 


In the source code, I include some additional helpers that print out running episode rewards and losses, along with a basic plotter for rewards_history.
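That plotter isn't shown in this post; a minimal version might look like the following (assuming matplotlib is available; the helper name plot_rewards is hypothetical):

```python
import matplotlib.pyplot as plt

def plot_rewards(rewards_history):
    # rewards_history is the list of per-episode reward sums returned by train()
    plt.plot(rewards_history)
    plt.xlabel('Episode')
    plt.ylabel('Total reward')
    plt.title('A2C on CartPole-v0')
    plt.show()

plot_rewards(rewards_history)
```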

Static Computational Graph

With all this eager-mode excitement, you might wonder whether static graph execution is still possible. Of course it is! Moreover, it takes just one additional line of code to enable it!

<span style="color:#f8f8f2"><code class="language-python"><span style="color:#66d9ef"><span style="color:#f92672">with</span></span> tf<span style="color:#f8f8f2">.</span>Graph<span style="color:#f8f8f2">(</span><span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">.</span>as_default<span style="color:#f8f8f2">(</span><span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">:</span><span style="color:#66d9ef">print</span><span style="color:#f8f8f2">(</span>tf<span style="color:#f8f8f2">.</span>executing_eagerly<span style="color:#f8f8f2">(</span><span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">)</span> <span style="color:slategray"><span style="color:#75715e"># False</span></span>model <span style="color:#f8f8f2">=</span> Model<span style="color:#f8f8f2">(</span>num_actions<span style="color:#f8f8f2">=</span>env<span style="color:#f8f8f2">.</span>action_space<span style="color:#f8f8f2">.</span>n<span style="color:#f8f8f2">)</span>agent <span style="color:#f8f8f2">=</span> A2CAgent<span style="color:#f8f8f2">(</span>model<span style="color:#f8f8f2">)</span>rewards_history <span style="color:#f8f8f2">=</span> agent<span style="color:#f8f8f2">.</span>train<span style="color:#f8f8f2">(</span>env<span style="color:#f8f8f2">)</span><span style="color:#66d9ef">print</span><span style="color:#f8f8f2">(</span><span style="color:#a6e22e"><span style="color:#e6db74">"Finished training, testing..."</span></span><span style="color:#f8f8f2">)</span><span style="color:#66d9ef">print</span><span style="color:#f8f8f2">(</span><span style="color:#a6e22e"><span style="color:#e6db74">"%d out of 200"</span></span> <span style="color:#f8f8f2">%</span> agent<span style="color:#f8f8f2">.</span>test<span style="color:#f8f8f2">(</span>env<span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">)</span> <span style="color:slategray"><span style="color:#75715e"># 200 out of 200</span></span></code></span>

One thing to note is that during static graph execution we can't just have Tensors lying around, which is why we needed the ProbabilityDistribution trick during the model definition. In fact, while I was looking for a way to execute in static mode, I discovered one interesting low-level detail about models built through the Keras API.

One More Thing

Remember when I said TensorFlow runs in eager mode by default, and even proved it with a code snippet? Well, I was wrong.

If you use the Keras API to build and manage your models, it will attempt to compile them to static graphs under the hood. So what you end up with is the performance of static computational graphs with the flexibility of eager execution.

You can check the status of your model via the model.run_eagerly flag. You can also force eager mode by setting this flag to True, although most of the time you probably won't need to: if Keras detects that there's no way around eager mode, it will back off automatically.

To illustrate that it really runs as a static graph, here is a simple benchmark:

<span style="color:#f8f8f2"><code class="language-python"><span style="color:slategray"><span style="color:#75715e"># create a 100000 samples batch</span></span>
env <span style="color:#f8f8f2">=</span> gym<span style="color:#f8f8f2">.</span>make<span style="color:#f8f8f2">(</span><span style="color:#a6e22e"><span style="color:#e6db74">'CartPole-v0'</span></span><span style="color:#f8f8f2">)</span>
obs <span style="color:#f8f8f2">=</span> np<span style="color:#f8f8f2">.</span>repeat<span style="color:#f8f8f2">(</span>env<span style="color:#f8f8f2">.</span>reset<span style="color:#f8f8f2">(</span><span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">[</span><span style="color:#f92672">None</span><span style="color:#f8f8f2">,</span> <span style="color:#f8f8f2">:</span><span style="color:#f8f8f2">]</span><span style="color:#f8f8f2">,</span> <span style="color:#ae81ff"><span style="color:#ae81ff">100000</span></span><span style="color:#f8f8f2">,</span> axis<span style="color:#f8f8f2">=</span><span style="color:#ae81ff"><span style="color:#ae81ff">0</span></span><span style="color:#f8f8f2">)</span></code></span>

Eager Benchmark

<span style="color:#f8f8f2"><code class="language-python"><span style="color:#f8f8f2">%</span><span style="color:#f8f8f2">%</span>time
model <span style="color:#f8f8f2">=</span> Model<span style="color:#f8f8f2">(</span>env<span style="color:#f8f8f2">.</span>action_space<span style="color:#f8f8f2">.</span>n<span style="color:#f8f8f2">)</span>
model<span style="color:#f8f8f2">.</span>run_eagerly <span style="color:#f8f8f2">=</span> <span style="color:#ae81ff"><span style="color:#f92672">True</span></span>
<span style="color:#66d9ef">print</span><span style="color:#f8f8f2">(</span><span style="color:#a6e22e"><span style="color:#e6db74">"Eager Execution:  "</span></span><span style="color:#f8f8f2">,</span> tf<span style="color:#f8f8f2">.</span>executing_eagerly<span style="color:#f8f8f2">(</span><span style="color:#f8f8f2">)</span><span style="color:#f8f8f2">)</span>
<span style="color:#66d9ef">print</span><span style="color:#f8f8f2">(</span><span style="color:#a6e22e"><span style="color:#e6db74">"Eager Keras Model:"</span></span><span style="color:#f8f8f2">,</span> model<span style="color:#f8f8f2">.</span>run_eagerly<span style="color:#f8f8f2">)</span>
_ <span style="color:#f8f8f2">=</span> model<span style="color:#f8f8f2">(</span>obs<span style="color:#f8f8f2">)</span>
<span style="color:slategray"><span style="color:#75715e">######## Results #######</span></span>
Eager Execution<span style="color:#f8f8f2">:</span>   <span style="color:#ae81ff"><span style="color:#f92672">True</span></span>
Eager Keras Model<span style="color:#f8f8f2">:</span> <span style="color:#ae81ff"><span style="color:#f92672">True</span></span>
CPU times<span style="color:#f8f8f2">:</span> user <span style="color:#ae81ff"><span style="color:#ae81ff">639</span></span> ms<span style="color:#f8f8f2">,</span> sys<span style="color:#f8f8f2">:</span> <span style="color:#ae81ff"><span style="color:#ae81ff">736</span></span> ms<span style="color:#f8f8f2">,</span> total<span style="color:#f8f8f2">:</span> <span style="color:#ae81ff"><span style="color:#ae81ff">1.38</span></span> s</code></span>

Static Benchmark

<span style="color:#f8f8f2"><code class="language-none">%%time
with tf.Graph().as_default():model = Model(env.action_space.n)print("Eager Execution:  ", tf.executing_eagerly())print("Eager Keras Model:", model.run_eagerly)_ = model.predict(obs)
######## Results #######
Eager Execution:   False
Eager Keras Model: False
CPU times: user 793 ms, sys: 79.7 ms, total: 873 ms</code></span>

Default Benchmark

<span style="color:#333333"><span style="color:#f8f8f2"><code class="language-none">%%time
model = Model(env.action_space.n)
print("Eager Execution:  ", tf.executing_eagerly())
print("Eager Keras Model:", model.run_eagerly)
_ = model.predict(obs)
######## Results #######
Eager Execution:   True
Eager Keras Model: False
CPU times: user 994 ms, sys: 23.1 ms, total: 1.02 s</code></span></span>

As you can see, eager mode lags behind static mode, and by default our model is indeed executed statically.

Conclusion

Hopefully this article has helped you get a feel for DRL and TensorFlow 2.0. Note that TensorFlow 2.0 is still only a preview build, not even a release candidate, so everything is subject to change. If there's something about TensorFlow you particularly dislike, let its developers know!

A lingering question people might have: is TensorFlow better than PyTorch? Maybe, maybe not. Both are great libraries, so it's hard to say which is better. If you're familiar with PyTorch, you've probably noticed that TensorFlow 2.0 not only caught up with it, but also avoided some of the pitfalls of the PyTorch API.

In either case, the competition has already brought positive results for developers on both sides, and I'm excited to see what the frameworks will look like in the future.

 

