【逆强化学习-1】学徒学习(Apprenticeship Learning)

文章目录

  • 0.引言
  • 1.算法原理
  • 2.仿真环境
  • 3.运行
  • 4.补充(学徒学习+深度Q网络)

本文为逆强化学习系列第1篇,没有看过逆强化学习介绍的那篇的朋友,可以看一下:

Inverse Reinforcement Learning-Introduction 传送门

0.引言

\qquad可以说学徒学习(简称App)是最早的一种逆强化学习的方法,如果看过最原始那篇论文的读者可能会觉得这个方法怎么这么晦涩难懂,但其实看完Code你会觉得这个Algorithm还挺粗糙的,这还是一件挺矛盾的事情。由于IRL-Introduction的论文介绍中没有该论文的下载链接,因此将其提供如下:

本文完整代码链接:CSDN下载

论文PDF链接:论文PDF(CSDN资源,永久免费)

1.算法原理

\qquad在IRL-Introduction介绍过了,该方法是2004年出现的,那时候DL还不是很受欢迎,因此作者采用的是线性函数的方法计算Reward的。APP方法从Observation中提取特征,计算特征期望,再将reward作为特征期望的线性函数。
\qquad提取特征的函数可以是非线性的,对算法没有任何影响,一般记为ϕ(s)\phi(s)ϕ(s),而轨迹π\piπ的特征期望μ(π)\mu(\pi)μ(π)在paper中的计算方法即为
μ(π)=∑t=0∞γtϕ(st)\mu(\pi)=\sum_{t=0}^{\infty}\gamma^t\phi(s_t)μ(π)=t=0γtϕ(st)
\qquad如果是环境在互动过程中具有随机性,则求得的应该是特征期望的数学期望,用均值近似数学期望可以得到(其中n为轨迹数目,st(i)s_t^{(i)}st(i)为第iii条轨迹的第ttt个状态):
μ(π)=1n∑i=1n∑t=0∞γtϕ(st(i))\mu(\pi)=\frac{1}{n}\sum_{i=1}^{n}\sum_{t=0}^{\infty}\gamma^t\phi(s_t^{(i)})μ(π)=n1i=1nt=0γtϕ(st(i))
在实际求解中,计算有限的时域ttt近似即可。
\qquad学徒学习的基本思想是寻找Reward使得所有的Agent产生的Reward都小于Expert的Reward的,再用这个reward去训练Agent。但为了防止Margin走向0和无穷大两个极端,在约束条件时还加入了一些限制。

该算法的步骤如下:

  1. 获得专家轨迹πE\pi_EπE,定于特征提取函数ϕ(st),t=1,2,...,N\phi(s_t),t=1,2,...,Nϕ(st),t=1,2,...,N,产生一个Random Agent,迭代次数g=0g=0g=0,随机设定一个www,奖赏函数rϕ=wTμ(π)r_\phi=w^T\mu(\pi)rϕ=wTμ(π)
  2. 利用ϕ\phiϕ提取专家轨迹的特征ϕ(s1(E)),ϕ(s2(E)),...,ϕ(sN(E))\phi(s_1^{(E)}),\phi(s_2^{(E)}),...,\phi(s_N^{(E)})ϕ(s1(E)),ϕ(s2(E)),...,ϕ(sN(E))
  3. 定义折扣因子γ\gammaγ并计算专家轨迹的特征期望μ(πE)=∑tγtϕ(st(E))\mu(\pi_E)=\sum_t\gamma^t\phi(s_t^{(E)})μ(πE)=tγtϕ(st(E))
  4. 奖赏函数设为rϕ=wTμ(πg)r_\phi=w^T\mu(\pi_g)rϕ=wTμ(πg),Agent与环境互动(可能会互动不止一次,因为model的更新需要时间),产生轨迹πg\pi_gπg,提取特征为ϕ(s1(E)),ϕ(s2(E)),...,ϕ(sM(E))\phi(s_1^{(E)}),\phi(s_2^{(E)}),...,\phi(s_M^{(E)})ϕ(s1(E)),ϕ(s2(E)),...,ϕ(sM(E)),也计算Agent的轨迹期望值μ(πg)=∑tγtϕ(st(g))\mu(\pi_g)=\sum_t\gamma^t\phi(s_t^{(g)})μ(πg)=tγtϕ(st(g))
  5. 求解最优w=w∗w=w^*w=w以更新线性Reward函数rϕ=wTμ(π)r_\phi=w^T\mu(\pi)rϕ=wTμ(π):
    maxt,wts.t.{wTμ(πE)≥wTμ(πg)+t,g=0,1,...,i−1∥w∥2≤1\begin{aligned}& max_{t,w}\quad t\\ & s.t.\begin{cases}w^T\mu(\pi_E)\geq w^T\mu(\pi_g)+t,g=0,1,...,i-1 \\[2ex] \lVert w\rVert_2\leq1 \end{cases} \end{aligned} maxt,wts.t.wTμ(πE)wTμ(πg)+t,g=0,1,...,i1w21
  6. g=g+1g=g+1g=g+1,若达到最大迭代次数,终止,否则转步骤4.

其他的步骤都没什么,主要是步骤(5),步骤(5)不仅不是线性规划,还带有一个令人讨厌的非线性约束(意味着只能采用拉格朗日乘子法),但如果大家仔细分析这个问题就会发现,损失函数只和t有关,而www向量虽然有模长的限制,但是方向是可以随意变化的,那么为了让Margin最大,肯定要取和μ(πE)−μ(πg)\mu(\pi_E)-\mu(\pi_g)μ(πE)μ(πg)平行的方向,然后取到最大模长1,由此步骤(5)就迎刃而解。对于g=0,1,2,...,ig=0,1,2,...,ig=0,1,2,...,i均成立的问题,只需要计算取Margin最小的w∗w^*w以满足ttt即可:
wg∗=μ(πE)−μ(πg)∥μ(πE)−μ(πg)∥2,(g=0,1,...,i)w∗=arg min⁡g=0,1,...,iwg∗[μ(πE)−μ(πg)]t∗=w∗(μ(πE)−μ(πg))\begin{array}{l} w^*_g=\frac{\mu(\pi_E)-\mu(\pi_g)}{\lVert\mu(\pi_E)-\mu(\pi_g)\rVert_2},(g=0,1,...,i)\\[2ex] w^*=\argmin_{g=0,1,...,i} {w^*_g[\mu(\pi_E)-\mu(\pi_g)]}\\[2ex] t^*=w^*(\mu(\pi_E)-\mu(\pi_g))\\ \end{array}wg=μ(πE)μ(πg)2μ(πE)μ(πg),(g=0,1,...,i)w=argming=0,1,...,iwg[μ(πE)μ(πg)]t=w(μ(πE)μ(πg))

当然了,其实真正有用的只是w∗w^*w罢了(再确切一些,只是w∗w^*w的方向有用)。原作者给出的Python代码中用的不是这个公式,但是它的解和上述公式解出来是一样的。

2.仿真环境

\qquad仿真环境使用的是gym,在IRL-Introduction也提到过这是需要Linux环境的,另外python的版本不能超过3.7(超过之后虽然这个代码可以运行,但是其他很多gym的环境都会出现bug,以强化学习为课题的读者们建议在Python3.7环境下运行)。

仿真环境说明-链接
MontainCar

\qquadMountainCar-v1是连续的状态空间,离散的动作空间的仿真环境,Action其实就是小车加速度(只能取-1, 0, 1),Observation是小车位置和速度,取值范围Boundary见上图。
\qquad真实的Reward是每隔一秒就会扣1分,直到小车到达终点。MountainCar-v1中走左侧坡道不扣分,而
v2版本则是将这条规则取消(以增加模型难度)。当然我们只是用它来做evaluation,在train中是不起作用的(因为逆强化学习用不到Reward函数)。
\qquad学徒学习参考了github项目的代码(稍微有点问题,已经在博客中修正并且在git中另申请了一个branch)
链接:

Github代码网址

\qquadpython环境依赖:numpy,gym,readchar (Python捕捉键盘操作的库)

3.运行

\qquadGithub上的代码是基于Q-Table的,没有显卡加速的,运行60000代大概需要20分钟左右(AMD R3处理器)。由于贴出了GitHub上的下载链接,这里就不粘贴代码了,如果发现不能下载的朋友也不用着急,我已经上传到资源了(下载是免费的):

下载链接

\qquad直接运行mountaincar/app/test.py即可查看效果,trian.py是训练(会把之前保存的Agent给覆盖掉),结果保存在results文件夹下。
\qquadIRL的专家轨迹保存在expert_demo文件夹下,有一个make_expert.py是专门用来产生专家轨迹的(人工游戏产生),里面指定了三个按键用来采取左,停,右的动作(加速度)。事实上这个代码有那么一点bug,我的电脑上是不能正常运行的,所以我重新写了一个,运行的时候需要在程序根目录下运行。但是它产生的轨迹包含的step数目并不是每次都相等的,只需要修改app.py中有关demonstration遍历的代码即可。

make_expert2.py \;\;\;操作:ASD按键

import gym
import readchar
import numpy as np
import pickle as pkl
# MACROS
Push_Left = 0
No_Push = 1
Push_Right = 2# Key mapping
arrow_keys = {'A': Push_Left,'S': No_Push,'D': Push_Right}
end_key = 'Q'
env = gym.make('MountainCar-v0')end_flag = False
trajectories = []
for episode in range(20): # n_trajectories : 20trajectory = []env.reset()print("episode:{}".format(episode))score = 0while True: env.render()key = readchar.readkey().upper()if key not in arrow_keys.keys():print('invalid key:{}'.format(key))if key == end_key:end_flag = Truebreakaction = arrow_keys[key]state, reward, done, _ = env.step(action)score += rewardif state[0] >= env.env.goal_position: trajectory.append((state[0], state[1], action))env.reset()print('mission accomplished! env is reset.')breaktrajectory.append((state[0], state[1], action))if end_flag:print('end!')breaktrajectory_numpy = np.array(trajectory, float)print("trajectory_numpy.shape", trajectory_numpy.shape)print("score:{}".format(score))trajectories.append(trajectory)  # don't need to seperate trajectories
env.close()
if not end_flag:with open('expert_demo.p',"wb")as f:pkl.dump(trajectories,f)

app.py修改的部分【expert_feature_expectation】

def expert_feature_expectation(feature_num, gamma, demonstrations, env):feature_estimate = FeatureEstimate(feature_num, env)feature_expectations = np.zeros(feature_num)for demo_num,traj in enumerate(demonstrations):for demo_length in range(len(traj)):state = demonstrations[demo_num][demo_length]features = feature_estimate.get_features(state)feature_expectations += (gamma**(demo_length)) * np.array(features)feature_expectations = feature_expectations / len(demonstrations)

学徒学习+Q-table的训练结果
在这里插入图片描述
在这里插入图片描述

4.补充(学徒学习+深度Q网络)

\qquad除此之外,我还尝试了将Q-Table换成DQN加以训练,由于DQN的输入state不像Q-Table一样是有限的,因此训练的时候非常不稳定,在多次调参之后,获得了一个差强人意的结果。
dqn_20000
下面是gym仿真的gif截图,与Q-Table不同,DQN的结果时好时坏,好在大多数情况下真实Reward均可以大于-200(即表示成功到达终点)

Reward:-88Reward:-90Reward: -141
在这里插入图片描述在这里插入图片描述在这里插入图片描述

\qquad需要增加和替换的代码主要为dqn,train_dpn和test_dqn三个文件,另外只需对app.py添加一个计算dqn输出的feature expectation即可运行。篇幅原因,这里仅给出dqn的代码,train和test的过程读者模仿原github项目中的train.py和test.py即可顺利完成。(全套代码链接,限时免费,本文点赞过一百将设为 永久免费,说实话研究逆强化学习的朋友不多,我个人也是没有导师指导自学的,还是期望与大家多交流)

dqn.py

import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as npdefault_dqn_paras=dict(gamma=0.99,epsilon=0.1,lr=5e-3,input_dims=2,\batch_size=128,n_actions=3,max_mem_size=int(4e3),\eps_end=0.01,eps_dec=1e-4,replace_target=50,weight_decay=5e-4)
class DeepQNetwork(nn.Module):def __init__(self, lr, input_dims, fc1_dims, fc2_dims, n_actions,weight_decay):super(DeepQNetwork, self).__init__()self.input_dims = input_dimsself.fc1_dims = fc1_dimsself.fc2_dims = fc2_dimsself.n_actions = n_actionsself.fc1 = nn.Linear(self.input_dims, self.fc1_dims)self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims)self.fc3 = nn.Linear(self.fc2_dims, self.n_actions)self.optimizer = optim.Adam(self.parameters(), lr=lr,weight_decay=weight_decay)self.loss = nn.MSELoss()self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')self.to(self.device)def forward(self, state):x = F.relu(self.fc1(state),inplace=True)x = F.relu(self.fc2(x),inplace=True)actions = self.fc3(x)return actionsclass Agent():def __init__(self, gamma, epsilon, lr, input_dims, batch_size, n_actions,max_mem_size=100000, eps_end=0.05, eps_dec=5e-4, replace_target=100, weight_decay=1e-4):self.gamma = gammaself.epsilon = epsilonself.eps_min = eps_endself.eps_dec = eps_decself.training = Trueself.lr = lrself.action_space = [i for i in range(n_actions)]self.mem_size = max_mem_sizeself.batch_size = batch_sizeself.mem_cntr = 0self.iter_cntr = 0self.replace_target = replace_targetself.Q_eval = DeepQNetwork(lr, n_actions=n_actions, input_dims=input_dims,fc1_dims=32, fc2_dims=32, weight_decay=weight_decay)self.Q_next = DeepQNetwork(lr, n_actions=n_actions, input_dims=input_dims,fc1_dims=32, fc2_dims=32, weight_decay=weight_decay)self.state_memory = np.zeros((self.mem_size, input_dims), dtype=np.float32)self.new_state_memory = np.zeros((self.mem_size, input_dims), dtype=np.float32)self.action_memory = np.zeros(self.mem_size, dtype=np.int32)self.reward_memory = np.zeros(self.mem_size, dtype=np.float32)self.terminal_memory = np.zeros(self.mem_size, dtype=np.bool)self.Q_eval.eval()self.Q_next.eval()def store_transition(self, state, action, reward, state_, terminal):index = self.mem_cntr % self.mem_sizeself.state_memory[index] = stateself.new_state_memory[index] = state_self.reward_memory[index] = rewardself.action_memory[index] = actionself.terminal_memory[index] = terminalself.mem_cntr += 1@T.no_grad()def choose_action(self, observation):"""Epsilon Greedy ExplorationArgs:observation ([iterable]): observation vectorReturns:action: element in env.action_space"""self.Q_eval.eval()if np.random.random() > self.epsilon or (not self.training):state = T.tensor(observation,dtype=T.float32).detach().to(self.Q_eval.device)actions = self.Q_eval.forward(state)action = T.argmax(actions).item()else:action = np.random.choice(self.action_space)return action@T.enable_grad()def learn(self):if self.mem_cntr < self.batch_size:returnself.Q_eval.train()self.Q_next.eval()self.Q_eval.optimizer.zero_grad()max_mem = min(self.mem_cntr, self.mem_size)batch = np.random.choice(max_mem, self.batch_size, replace=False)batch_index = np.arange(self.batch_size, dtype=np.int32)state_batch = T.tensor(self.state_memory[batch]).detach().requires_grad_(True).to(self.Q_eval.device)new_state_batch = T.tensor(self.new_state_memory[batch]).detach().requires_grad_(True).to(self.Q_eval.device)action_batch = self.action_memory[batch]reward_batch = T.tensor(self.reward_memory[batch]).detach().requires_grad_(True).to(self.Q_eval.device)terminal_batch = T.tensor(self.terminal_memory[batch]).to(self.Q_eval.device)q_eval = self.Q_eval.forward(state_batch)[batch_index, action_batch]q_next = self.Q_next.forward(new_state_batch).detach()q_next[terminal_batch] = 0.0  # when state is terminal state, value function is zeroq_target = reward_batch + self.gamma*T.max(q_next,dim=1)[0]loss = self.Q_eval.loss(q_target, q_eval).to(self.Q_eval.device)loss.backward()self.Q_eval.optimizer.step()self.iter_cntr += 1self.epsilon = max(self.epsilon - self.eps_dec, self.eps_min)if self.iter_cntr % self.replace_target == 0:self.Q_next.load_state_dict(self.Q_eval.state_dict())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/544426.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

面试官:HashMap有几种遍历方法?推荐使用哪种?

作者 | 磊哥来源 | Java面试真题解析&#xff08;ID&#xff1a;aimianshi666&#xff09;转载请联系授权&#xff08;微信ID&#xff1a;GG_Stone&#xff09;HashMap 的遍历方法有很多种&#xff0c;不同的 JDK 版本有不同的写法&#xff0c;其中 JDK 8 就提供了 3 种 HashMa…

【逆强化学习-2】最大熵学习(Maximum Entropy Learning)

文章目录0.引言1.算法原理2.仿真0.引言 \qquad本文是逆强化学习系列的第2篇&#xff0c;其余博客传送门如下&#xff1a; 逆强化学习0-Introduction 逆强化学习1-学徒学习 \qquad最大熵学习是2008年出现的方法&#xff0c;原论文&#xff08;链接见【逆强化学习0】的博客&#…

面试官又整新活,居然问我for循环用i++和++i哪个效率高?

前几天&#xff0c;一个小伙伴告诉我&#xff0c;他在面试的时候被面试官问了这么一个问题&#xff1a;在for循环中&#xff0c;到底应该用 i 还是 i &#xff1f;听到这&#xff0c;我感觉这面试官确实有点不按套路出牌了&#xff0c;放着好好的八股文不问&#xff0c;净整些幺…

面试官:如何实现 List 集合去重?

作者 | 磊哥来源 | Java面试真题解析&#xff08;ID&#xff1a;aimianshi666&#xff09;转载请联系授权&#xff08;微信ID&#xff1a;GG_Stone&#xff09;本文已收录《Java常见面试题》系列&#xff0c;开源地址&#xff1a;https://gitee.com/mydb/interviewList 去重指的…

Windows重装Anaconda3失败解决方案【重装失败10来次首次成功的案例!】

文章目录0.环境1.原因2.解决方案0.环境 Win10 Anaconda3 2018版 python 3.7.1 注意&#xff01;此种情况只会在windows上发生&#xff0c;因为在linux上你只需要删除anaconda3整个文件夹&#xff0c;重新安装一定会成功&#xff01; 1.原因 Anaconda肯定是没有成功安装的&am…

python读取pcd点云/转numpy(python2+python3,非ROS环境)

0.引言 \qquadROS的PCL库支持python读取点云&#xff0c;ROS1关联的是python2&#xff08;2.7&#xff09;&#xff0c;ROS2关联的是python3&#xff08;>3.5&#xff09;&#xff0c;但这对于windows的用户和没装ROS的ubuntu用户似乎不够友好。下面就介绍两种不需要ros的方…

Java中List排序的3种方法!

作者 | 王磊来源 | Java中文社群&#xff08;ID&#xff1a;javacn666&#xff09;转载请联系授权&#xff08;微信ID&#xff1a;GG_Stone&#xff09;在某些特殊的场景下&#xff0c;我们需要在 Java 程序中对 List 集合进行排序操作。比如从第三方接口中获取所有用户的列表&…

Spring 事务失效的 8 种场景!

在日常工作中&#xff0c;如果对Spring的事务管理功能使用不当&#xff0c;则会造成Spring事务不生效的问题。而针对Spring事务不生效的问题&#xff0c;也是在跳槽面试中被问的比较频繁的一个问题。点击上方卡片关注我今天&#xff0c;我们就一起梳理下有哪些场景会导致Spring…

vscode无法识别constexpr

问题 vscode 无法识别constexpr&#xff08;常指针类型&#xff09; 方法 打开工程路径下的.vscode文件夹&#xff08;一般是自动隐藏的&#xff0c;CtrlH显示隐藏&#xff09;设置c_cpp_properties.json文件如下&#xff1a; {"configurations": [{"name…

三流Java搞技术,二流Java搞框架,一流Java…

如何反驳“99&#xff05; 的 Java 程序员都是 Spring 程序员”这句话&#xff1f;答案是不能。互联网发展至今&#xff0c;站在巨人肩膀上编程像一日三餐一样寻常。Spring Boot 的确凭一己之力拉低了 Java 开发的门槛&#xff0c;可普通开发与高开之间&#xff0c;真就因为一个…

【Ubuntu】vscode配置PCL库/vscode无法导入PCL库

问题 PCL库是ROS框架自带的点云处理库&#xff0c;可以通过find_package(PCL REQUIRED)在CMakeLists.txt中导入&#xff0c;但是vscode却无法识别&#xff0c;出现问题如下&#xff1a; 注意&#xff0c;本文解决方案仅限Ubuntu&#xff01; 解决方案 打开工程路径下的.vsc…

面试官:HashSet是如何保证元素不重复的?

作者 | 磊哥来源 | Java面试真题解析&#xff08;ID&#xff1a;aimianshi666&#xff09;转载请联系授权&#xff08;微信ID&#xff1a;GG_Stone&#xff09;本文已收录《Java常见面试题》系列&#xff0c;开源地址&#xff1a;https://gitee.com/mydb/interviewHashSet 实现…

【Ubuntu】Ubuntu 20.04无法识别网口/以太网/有线网卡

这里写自定义目录标题0.症状1.查看网卡类型2.下载网卡驱动3.安装网卡驱动0.症状 \qquad表现为插入以太网网口后右上角没有显示网络&#xff0c;即没有下图的音量左侧的标志 打开设置的【网络】选项没有以太网接入&#xff0c;然而以太网口信号灯仍然正常闪烁。这种情况基本可以…

小心Lombok用法中的坑

刚才写完了代码&#xff0c;自测的时候&#xff0c;出现了NPE问题。排查的时候发现是Lombok的坑&#xff0c;以前也遇到过&#xff0c;所以觉得有必要过来记录一下。我先描述一下现象&#xff0c;我的代码里面订单服务A 需要调用缓存服务B&#xff0c;服务B就是一个Bean&#x…

【VSCode】VSCode使用conda环境时找不到python包/找不到Module

这里写自定义目录标题0.问题描述1.原因2.解决方法0.问题描述 \qquad首先需要排除是否是VSCode未配置conda环境的问题&#xff0c;当然&#xff0c;相信VSCode的老粉都不会犯这个低级错误&#xff0c;请CtrlP&#xff0c;在搜索框>select interpreter检查一下python环境。 …

PS如何对JPG文件直接抠图

如何JPG文件直接抠图 先转为智能对象&#xff1a; 再栅格化图层 此进即可直接进行抠图&#xff01;

更快的Maven来了,我的天,速度提升了8倍!

作者 | 王磊来源 | Java中文社群&#xff08;ID&#xff1a;javacn666&#xff09;转载请联系授权&#xff08;微信ID&#xff1a;GG_Stone&#xff09;周末被 maven-mvnd 刷屏了&#xff0c;于是我也下载了一个 mvnd 体验了一把。虽然测试的数据都是基于我本地项目&#xff0c…

Java 中接口和抽象类竟然有 7 点不同?

作者 | 磊哥来源 | Java面试真题解析&#xff08;ID&#xff1a;aimianshi666&#xff09;转载请联系授权&#xff08;微信ID&#xff1a;GG_Stone&#xff09;本文已收录《Java常见面试题》系列&#xff1a;https://gitee.com/mydb/interviewJava 是一门面向对象的编程语言&am…

粉丝不在于多,在于够残

李善友&#xff1a;所有可能被互联网取代的组织一定会被取代 2015-07-30 格局视野 格局视野 格局视野 微信号 geju365 功能介绍 格局生涯学院官方自媒体。面向互联网人的在线商学院。推送互联网行业知识&#xff0c;培养互联网实操人才。聚焦新行业、新模式、新公司、新人物。…

保姆级教学:缓存穿透、缓存击穿和缓存雪崩!

前言对于从事后端开发的同学来说&#xff0c;缓存已经变成的项目中必不可少的技术之一。没错&#xff0c;缓存能给我们系统显著的提升性能。但如果你使用不好&#xff0c;或者缺乏相关经验&#xff0c;它也会带来很多意想不到的问题。今天我们一起聊聊如果在项目中引入了缓存&a…