【逆强化学习-1】学徒学习(Apprenticeship Learning)

文章目录

  • 0.引言
  • 1.算法原理
  • 2.仿真环境
  • 3.运行
  • 4.补充(学徒学习+深度Q网络)

本文为逆强化学习系列第1篇,没有看过逆强化学习介绍的那篇的朋友,可以看一下:

Inverse Reinforcement Learning-Introduction 传送门

0.引言

\qquad可以说学徒学习(简称App)是最早的一种逆强化学习的方法,如果看过最原始那篇论文的读者可能会觉得这个方法怎么这么晦涩难懂,但其实看完Code你会觉得这个Algorithm还挺粗糙的,这还是一件挺矛盾的事情。由于IRL-Introduction的论文介绍中没有该论文的下载链接,因此将其提供如下:

本文完整代码链接:CSDN下载

论文PDF链接:论文PDF(CSDN资源,永久免费)

1.算法原理

\qquad在IRL-Introduction介绍过了,该方法是2004年出现的,那时候DL还不是很受欢迎,因此作者采用的是线性函数的方法计算Reward的。APP方法从Observation中提取特征,计算特征期望,再将reward作为特征期望的线性函数。
\qquad提取特征的函数可以是非线性的,对算法没有任何影响,一般记为ϕ(s)\phi(s)ϕ(s),而轨迹π\piπ的特征期望μ(π)\mu(\pi)μ(π)在paper中的计算方法即为
μ(π)=∑t=0∞γtϕ(st)\mu(\pi)=\sum_{t=0}^{\infty}\gamma^t\phi(s_t)μ(π)=t=0γtϕ(st)
\qquad如果是环境在互动过程中具有随机性,则求得的应该是特征期望的数学期望,用均值近似数学期望可以得到(其中n为轨迹数目,st(i)s_t^{(i)}st(i)为第iii条轨迹的第ttt个状态):
μ(π)=1n∑i=1n∑t=0∞γtϕ(st(i))\mu(\pi)=\frac{1}{n}\sum_{i=1}^{n}\sum_{t=0}^{\infty}\gamma^t\phi(s_t^{(i)})μ(π)=n1i=1nt=0γtϕ(st(i))
在实际求解中,计算有限的时域ttt近似即可。
\qquad学徒学习的基本思想是寻找Reward使得所有的Agent产生的Reward都小于Expert的Reward的,再用这个reward去训练Agent。但为了防止Margin走向0和无穷大两个极端,在约束条件时还加入了一些限制。

该算法的步骤如下:

  1. 获得专家轨迹πE\pi_EπE,定于特征提取函数ϕ(st),t=1,2,...,N\phi(s_t),t=1,2,...,Nϕ(st),t=1,2,...,N,产生一个Random Agent,迭代次数g=0g=0g=0,随机设定一个www,奖赏函数rϕ=wTμ(π)r_\phi=w^T\mu(\pi)rϕ=wTμ(π)
  2. 利用ϕ\phiϕ提取专家轨迹的特征ϕ(s1(E)),ϕ(s2(E)),...,ϕ(sN(E))\phi(s_1^{(E)}),\phi(s_2^{(E)}),...,\phi(s_N^{(E)})ϕ(s1(E)),ϕ(s2(E)),...,ϕ(sN(E))
  3. 定义折扣因子γ\gammaγ并计算专家轨迹的特征期望μ(πE)=∑tγtϕ(st(E))\mu(\pi_E)=\sum_t\gamma^t\phi(s_t^{(E)})μ(πE)=tγtϕ(st(E))
  4. 奖赏函数设为rϕ=wTμ(πg)r_\phi=w^T\mu(\pi_g)rϕ=wTμ(πg),Agent与环境互动(可能会互动不止一次,因为model的更新需要时间),产生轨迹πg\pi_gπg,提取特征为ϕ(s1(E)),ϕ(s2(E)),...,ϕ(sM(E))\phi(s_1^{(E)}),\phi(s_2^{(E)}),...,\phi(s_M^{(E)})ϕ(s1(E)),ϕ(s2(E)),...,ϕ(sM(E)),也计算Agent的轨迹期望值μ(πg)=∑tγtϕ(st(g))\mu(\pi_g)=\sum_t\gamma^t\phi(s_t^{(g)})μ(πg)=tγtϕ(st(g))
  5. 求解最优w=w∗w=w^*w=w以更新线性Reward函数rϕ=wTμ(π)r_\phi=w^T\mu(\pi)rϕ=wTμ(π):
    maxt,wts.t.{wTμ(πE)≥wTμ(πg)+t,g=0,1,...,i−1∥w∥2≤1\begin{aligned}& max_{t,w}\quad t\\ & s.t.\begin{cases}w^T\mu(\pi_E)\geq w^T\mu(\pi_g)+t,g=0,1,...,i-1 \\[2ex] \lVert w\rVert_2\leq1 \end{cases} \end{aligned} maxt,wts.t.wTμ(πE)wTμ(πg)+t,g=0,1,...,i1w21
  6. g=g+1g=g+1g=g+1,若达到最大迭代次数,终止,否则转步骤4.

其他的步骤都没什么,主要是步骤(5),步骤(5)不仅不是线性规划,还带有一个令人讨厌的非线性约束(意味着只能采用拉格朗日乘子法),但如果大家仔细分析这个问题就会发现,损失函数只和t有关,而www向量虽然有模长的限制,但是方向是可以随意变化的,那么为了让Margin最大,肯定要取和μ(πE)−μ(πg)\mu(\pi_E)-\mu(\pi_g)μ(πE)μ(πg)平行的方向,然后取到最大模长1,由此步骤(5)就迎刃而解。对于g=0,1,2,...,ig=0,1,2,...,ig=0,1,2,...,i均成立的问题,只需要计算取Margin最小的w∗w^*w以满足ttt即可:
wg∗=μ(πE)−μ(πg)∥μ(πE)−μ(πg)∥2,(g=0,1,...,i)w∗=arg min⁡g=0,1,...,iwg∗[μ(πE)−μ(πg)]t∗=w∗(μ(πE)−μ(πg))\begin{array}{l} w^*_g=\frac{\mu(\pi_E)-\mu(\pi_g)}{\lVert\mu(\pi_E)-\mu(\pi_g)\rVert_2},(g=0,1,...,i)\\[2ex] w^*=\argmin_{g=0,1,...,i} {w^*_g[\mu(\pi_E)-\mu(\pi_g)]}\\[2ex] t^*=w^*(\mu(\pi_E)-\mu(\pi_g))\\ \end{array}wg=μ(πE)μ(πg)2μ(πE)μ(πg),(g=0,1,...,i)w=argming=0,1,...,iwg[μ(πE)μ(πg)]t=w(μ(πE)μ(πg))

当然了,其实真正有用的只是w∗w^*w罢了(再确切一些,只是w∗w^*w的方向有用)。原作者给出的Python代码中用的不是这个公式,但是它的解和上述公式解出来是一样的。

2.仿真环境

\qquad仿真环境使用的是gym,在IRL-Introduction也提到过这是需要Linux环境的,另外python的版本不能超过3.7(超过之后虽然这个代码可以运行,但是其他很多gym的环境都会出现bug,以强化学习为课题的读者们建议在Python3.7环境下运行)。

仿真环境说明-链接
MontainCar

\qquadMountainCar-v1是连续的状态空间,离散的动作空间的仿真环境,Action其实就是小车加速度(只能取-1, 0, 1),Observation是小车位置和速度,取值范围Boundary见上图。
\qquad真实的Reward是每隔一秒就会扣1分,直到小车到达终点。MountainCar-v1中走左侧坡道不扣分,而
v2版本则是将这条规则取消(以增加模型难度)。当然我们只是用它来做evaluation,在train中是不起作用的(因为逆强化学习用不到Reward函数)。
\qquad学徒学习参考了github项目的代码(稍微有点问题,已经在博客中修正并且在git中另申请了一个branch)
链接:

Github代码网址

\qquadpython环境依赖:numpy,gym,readchar (Python捕捉键盘操作的库)

3.运行

\qquadGithub上的代码是基于Q-Table的,没有显卡加速的,运行60000代大概需要20分钟左右(AMD R3处理器)。由于贴出了GitHub上的下载链接,这里就不粘贴代码了,如果发现不能下载的朋友也不用着急,我已经上传到资源了(下载是免费的):

下载链接

\qquad直接运行mountaincar/app/test.py即可查看效果,trian.py是训练(会把之前保存的Agent给覆盖掉),结果保存在results文件夹下。
\qquadIRL的专家轨迹保存在expert_demo文件夹下,有一个make_expert.py是专门用来产生专家轨迹的(人工游戏产生),里面指定了三个按键用来采取左,停,右的动作(加速度)。事实上这个代码有那么一点bug,我的电脑上是不能正常运行的,所以我重新写了一个,运行的时候需要在程序根目录下运行。但是它产生的轨迹包含的step数目并不是每次都相等的,只需要修改app.py中有关demonstration遍历的代码即可。

make_expert2.py \;\;\;操作:ASD按键

import gym
import readchar
import numpy as np
import pickle as pkl
# MACROS
Push_Left = 0
No_Push = 1
Push_Right = 2# Key mapping
arrow_keys = {'A': Push_Left,'S': No_Push,'D': Push_Right}
end_key = 'Q'
env = gym.make('MountainCar-v0')end_flag = False
trajectories = []
for episode in range(20): # n_trajectories : 20trajectory = []env.reset()print("episode:{}".format(episode))score = 0while True: env.render()key = readchar.readkey().upper()if key not in arrow_keys.keys():print('invalid key:{}'.format(key))if key == end_key:end_flag = Truebreakaction = arrow_keys[key]state, reward, done, _ = env.step(action)score += rewardif state[0] >= env.env.goal_position: trajectory.append((state[0], state[1], action))env.reset()print('mission accomplished! env is reset.')breaktrajectory.append((state[0], state[1], action))if end_flag:print('end!')breaktrajectory_numpy = np.array(trajectory, float)print("trajectory_numpy.shape", trajectory_numpy.shape)print("score:{}".format(score))trajectories.append(trajectory)  # don't need to seperate trajectories
env.close()
if not end_flag:with open('expert_demo.p',"wb")as f:pkl.dump(trajectories,f)

app.py修改的部分【expert_feature_expectation】

def expert_feature_expectation(feature_num, gamma, demonstrations, env):feature_estimate = FeatureEstimate(feature_num, env)feature_expectations = np.zeros(feature_num)for demo_num,traj in enumerate(demonstrations):for demo_length in range(len(traj)):state = demonstrations[demo_num][demo_length]features = feature_estimate.get_features(state)feature_expectations += (gamma**(demo_length)) * np.array(features)feature_expectations = feature_expectations / len(demonstrations)

学徒学习+Q-table的训练结果
在这里插入图片描述
在这里插入图片描述

4.补充(学徒学习+深度Q网络)

\qquad除此之外,我还尝试了将Q-Table换成DQN加以训练,由于DQN的输入state不像Q-Table一样是有限的,因此训练的时候非常不稳定,在多次调参之后,获得了一个差强人意的结果。
dqn_20000
下面是gym仿真的gif截图,与Q-Table不同,DQN的结果时好时坏,好在大多数情况下真实Reward均可以大于-200(即表示成功到达终点)

Reward:-88Reward:-90Reward: -141
在这里插入图片描述在这里插入图片描述在这里插入图片描述

\qquad需要增加和替换的代码主要为dqn,train_dpn和test_dqn三个文件,另外只需对app.py添加一个计算dqn输出的feature expectation即可运行。篇幅原因,这里仅给出dqn的代码,train和test的过程读者模仿原github项目中的train.py和test.py即可顺利完成。(全套代码链接,限时免费,本文点赞过一百将设为 永久免费,说实话研究逆强化学习的朋友不多,我个人也是没有导师指导自学的,还是期望与大家多交流)

dqn.py

import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as npdefault_dqn_paras=dict(gamma=0.99,epsilon=0.1,lr=5e-3,input_dims=2,\batch_size=128,n_actions=3,max_mem_size=int(4e3),\eps_end=0.01,eps_dec=1e-4,replace_target=50,weight_decay=5e-4)
class DeepQNetwork(nn.Module):def __init__(self, lr, input_dims, fc1_dims, fc2_dims, n_actions,weight_decay):super(DeepQNetwork, self).__init__()self.input_dims = input_dimsself.fc1_dims = fc1_dimsself.fc2_dims = fc2_dimsself.n_actions = n_actionsself.fc1 = nn.Linear(self.input_dims, self.fc1_dims)self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims)self.fc3 = nn.Linear(self.fc2_dims, self.n_actions)self.optimizer = optim.Adam(self.parameters(), lr=lr,weight_decay=weight_decay)self.loss = nn.MSELoss()self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')self.to(self.device)def forward(self, state):x = F.relu(self.fc1(state),inplace=True)x = F.relu(self.fc2(x),inplace=True)actions = self.fc3(x)return actionsclass Agent():def __init__(self, gamma, epsilon, lr, input_dims, batch_size, n_actions,max_mem_size=100000, eps_end=0.05, eps_dec=5e-4, replace_target=100, weight_decay=1e-4):self.gamma = gammaself.epsilon = epsilonself.eps_min = eps_endself.eps_dec = eps_decself.training = Trueself.lr = lrself.action_space = [i for i in range(n_actions)]self.mem_size = max_mem_sizeself.batch_size = batch_sizeself.mem_cntr = 0self.iter_cntr = 0self.replace_target = replace_targetself.Q_eval = DeepQNetwork(lr, n_actions=n_actions, input_dims=input_dims,fc1_dims=32, fc2_dims=32, weight_decay=weight_decay)self.Q_next = DeepQNetwork(lr, n_actions=n_actions, input_dims=input_dims,fc1_dims=32, fc2_dims=32, weight_decay=weight_decay)self.state_memory = np.zeros((self.mem_size, input_dims), dtype=np.float32)self.new_state_memory = np.zeros((self.mem_size, input_dims), dtype=np.float32)self.action_memory = np.zeros(self.mem_size, dtype=np.int32)self.reward_memory = np.zeros(self.mem_size, dtype=np.float32)self.terminal_memory = np.zeros(self.mem_size, dtype=np.bool)self.Q_eval.eval()self.Q_next.eval()def store_transition(self, state, action, reward, state_, terminal):index = self.mem_cntr % self.mem_sizeself.state_memory[index] = stateself.new_state_memory[index] = state_self.reward_memory[index] = rewardself.action_memory[index] = actionself.terminal_memory[index] = terminalself.mem_cntr += 1@T.no_grad()def choose_action(self, observation):"""Epsilon Greedy ExplorationArgs:observation ([iterable]): observation vectorReturns:action: element in env.action_space"""self.Q_eval.eval()if np.random.random() > self.epsilon or (not self.training):state = T.tensor(observation,dtype=T.float32).detach().to(self.Q_eval.device)actions = self.Q_eval.forward(state)action = T.argmax(actions).item()else:action = np.random.choice(self.action_space)return action@T.enable_grad()def learn(self):if self.mem_cntr < self.batch_size:returnself.Q_eval.train()self.Q_next.eval()self.Q_eval.optimizer.zero_grad()max_mem = min(self.mem_cntr, self.mem_size)batch = np.random.choice(max_mem, self.batch_size, replace=False)batch_index = np.arange(self.batch_size, dtype=np.int32)state_batch = T.tensor(self.state_memory[batch]).detach().requires_grad_(True).to(self.Q_eval.device)new_state_batch = T.tensor(self.new_state_memory[batch]).detach().requires_grad_(True).to(self.Q_eval.device)action_batch = self.action_memory[batch]reward_batch = T.tensor(self.reward_memory[batch]).detach().requires_grad_(True).to(self.Q_eval.device)terminal_batch = T.tensor(self.terminal_memory[batch]).to(self.Q_eval.device)q_eval = self.Q_eval.forward(state_batch)[batch_index, action_batch]q_next = self.Q_next.forward(new_state_batch).detach()q_next[terminal_batch] = 0.0  # when state is terminal state, value function is zeroq_target = reward_batch + self.gamma*T.max(q_next,dim=1)[0]loss = self.Q_eval.loss(q_target, q_eval).to(self.Q_eval.device)loss.backward()self.Q_eval.optimizer.step()self.iter_cntr += 1self.epsilon = max(self.epsilon - self.eps_dec, self.eps_min)if self.iter_cntr % self.replace_target == 0:self.Q_next.load_state_dict(self.Q_eval.state_dict())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/544426.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

面试官:HashMap有几种遍历方法?推荐使用哪种?

作者 | 磊哥来源 | Java面试真题解析&#xff08;ID&#xff1a;aimianshi666&#xff09;转载请联系授权&#xff08;微信ID&#xff1a;GG_Stone&#xff09;HashMap 的遍历方法有很多种&#xff0c;不同的 JDK 版本有不同的写法&#xff0c;其中 JDK 8 就提供了 3 种 HashMa…

HTML 5 input placeholder 属性

<input placeholder"请先选择组织" type"text" value"" </input>placeholder 属性提供可描述输入字段预期值的提示信息&#xff08;hint&#xff09;。 该提示会在输入字段为空时显示&#xff0c;并会在字段获得焦点时消失。 注释&…

【逆强化学习-2】最大熵学习(Maximum Entropy Learning)

文章目录0.引言1.算法原理2.仿真0.引言 \qquad本文是逆强化学习系列的第2篇&#xff0c;其余博客传送门如下&#xff1a; 逆强化学习0-Introduction 逆强化学习1-学徒学习 \qquad最大熵学习是2008年出现的方法&#xff0c;原论文&#xff08;链接见【逆强化学习0】的博客&#…

uselocale_Java扫描仪useLocale()方法与示例

uselocale扫描器类useLocale()方法 (Scanner Class useLocale() method) useLocale() method is available in java.util package. useLocale()方法在java.util包中可用。 useLocale() method is used to use this Scanner locale to the given locale (lo). useLocale()方法用…

面试官又整新活,居然问我for循环用i++和++i哪个效率高?

前几天&#xff0c;一个小伙伴告诉我&#xff0c;他在面试的时候被面试官问了这么一个问题&#xff1a;在for循环中&#xff0c;到底应该用 i 还是 i &#xff1f;听到这&#xff0c;我感觉这面试官确实有点不按套路出牌了&#xff0c;放着好好的八股文不问&#xff0c;净整些幺…

UVa 988 - Many Paths, One Destination

称号&#xff1a;生命是非常多的选择。现在给你一些选择&#xff08;0~n-1&#xff09;&#xff0c;和其他选项后&#xff0c;分支数每一次选择&#xff0c;选择共求。 分析&#xff1a;dp&#xff0c;图论。假设一个状态也许是选择的数量0一个是&#xff0c;代表死亡&#xff…

Java PrintWriter close()方法与示例

PrintWriter类close()方法 (PrintWriter Class close() method) close() method is available in java.io package. close()方法在java.io包中可用。 close() method is used to close this stream and free all system resources linked with the stream. close()方法用于关闭…

pipedreader_Java PipedReader ready()方法与示例

pipedreaderPipedReader类ready()方法 (PipedReader Class ready() method) ready() method is available in java.io package. ready()方法在java.io包中可用。 ready() method is used to check whether this PipedReader stream is ready to be read or not. ready()方法用…

面试官:如何实现 List 集合去重?

作者 | 磊哥来源 | Java面试真题解析&#xff08;ID&#xff1a;aimianshi666&#xff09;转载请联系授权&#xff08;微信ID&#xff1a;GG_Stone&#xff09;本文已收录《Java常见面试题》系列&#xff0c;开源地址&#xff1a;https://gitee.com/mydb/interviewList 去重指的…

Windows重装Anaconda3失败解决方案【重装失败10来次首次成功的案例!】

文章目录0.环境1.原因2.解决方案0.环境 Win10 Anaconda3 2018版 python 3.7.1 注意&#xff01;此种情况只会在windows上发生&#xff0c;因为在linux上你只需要删除anaconda3整个文件夹&#xff0c;重新安装一定会成功&#xff01; 1.原因 Anaconda肯定是没有成功安装的&am…

java写入文件的几种方法分享

转自&#xff1a;http://www.jb51.net/article/47062.htm 一&#xff0c;FileWritter写入文件 FileWritter, 字符流写入字符到文件。默认情况下&#xff0c;它会使用新的内容取代所有现有的内容&#xff0c;然而&#xff0c;当指定一个true &#xff08;布尔&#xff09;值作为…

python读取pcd点云/转numpy(python2+python3,非ROS环境)

0.引言 \qquadROS的PCL库支持python读取点云&#xff0c;ROS1关联的是python2&#xff08;2.7&#xff09;&#xff0c;ROS2关联的是python3&#xff08;>3.5&#xff09;&#xff0c;但这对于windows的用户和没装ROS的ubuntu用户似乎不够友好。下面就介绍两种不需要ros的方…

Java中List排序的3种方法!

作者 | 王磊来源 | Java中文社群&#xff08;ID&#xff1a;javacn666&#xff09;转载请联系授权&#xff08;微信ID&#xff1a;GG_Stone&#xff09;在某些特殊的场景下&#xff0c;我们需要在 Java 程序中对 List 集合进行排序操作。比如从第三方接口中获取所有用户的列表&…

setdefault_Java语言环境setDefault()方法及示例

setdefault语言环境类setDefault()方法 (Locale Class setDefault() method) setDefault() method is available in java.util package. setDefault()方法在java.util包中可用。 setDefault() method is used to assign the default locale for this Locale instance of the JV…

Spring 事务失效的 8 种场景!

在日常工作中&#xff0c;如果对Spring的事务管理功能使用不当&#xff0c;则会造成Spring事务不生效的问题。而针对Spring事务不生效的问题&#xff0c;也是在跳槽面试中被问的比较频繁的一个问题。点击上方卡片关注我今天&#xff0c;我们就一起梳理下有哪些场景会导致Spring…

xcode6 AsynchronousTesting 异步任务测试

xcode集成了非常方便的测试框架&#xff0c;XCTest 在xcode6之后&#xff0c;提供了 <XCTest/XCTestCaseAsynchronousTesting.h> 利用此我们可以直接在XCTest里面测试一些异步的任务&#xff0c;比如异步网络请求 如下示例 - (void)testExample {XCTestExpectation *exce…

vscode无法识别constexpr

问题 vscode 无法识别constexpr&#xff08;常指针类型&#xff09; 方法 打开工程路径下的.vscode文件夹&#xff08;一般是自动隐藏的&#xff0c;CtrlH显示隐藏&#xff09;设置c_cpp_properties.json文件如下&#xff1a; {"configurations": [{"name…

三流Java搞技术,二流Java搞框架,一流Java…

如何反驳“99&#xff05; 的 Java 程序员都是 Spring 程序员”这句话&#xff1f;答案是不能。互联网发展至今&#xff0c;站在巨人肩膀上编程像一日三餐一样寻常。Spring Boot 的确凭一己之力拉低了 Java 开发的门槛&#xff0c;可普通开发与高开之间&#xff0c;真就因为一个…

java 方法 示例_Java语言环境getVariant()方法与示例

java 方法 示例区域设置类getVariant()方法 (Locale Class getVariant() method) getVariant() method is available in java.util package. getVariant()方法在java.util包中可用。 getVariant() method is used to get the variant code for this Locale. getVariant()方法用…

2.7-源码编译安装

网上下载源码包 wget http://网址 如果没有wget yum install -y wget建议下载下来的源码包&#xff0c;统一放到/usr/local/scr/下&#xff0c;方便维护管理养成查看INSTALL和README文档的习惯&#xff0c;内有软件安装方法和详细信息。1. ./configure --prefix/usr/l…