实现代码github仓库:RL-BaselineCode
代码库将持续更新,希望得到您的支持⭐,让我们一起进步!
文章目录
- 1. 原理讲解
- 1.1 Q值更新公式
- 1.2 ε-greedy随机方法
- 2. 算法实现
- 2.1 算法简要流程
- 2.2 游戏场景
- 2.3 算法实现
- 3. 参考文章
1. 原理讲解
Q-learning算法实际上相当简单,仅仅维护一个Q值表即可,表的维数为(所有状态S,所有动作A),表的内容称为Q值,体现该状态下采取当前动作的未来奖励期望。智能体每次选择动作时都会查询Q值表在当前状态下采取何种动作得到的未来奖励可能最多,当然也会添加一些随机性,使智能体可能选择别的可能当前认为未来奖励并不多的动作,以便跳出局部最优解,尽量得到全局最优解。
1.1 Q值更新公式
Q值更新公式为:
Q [ S , A ] = ( 1 − α ) × Q [ S , A ] + α × ( R + γ × m a x ( Q [ S n e x t , : ] ) ) Q[S,A]=(1-\alpha)\times Q[S,A]+\alpha\times (R+\gamma\times max(Q[S_{next}, :])) Q[S,A]=(1−α)×Q[S,A]+α×(R+γ×max(Q[Snext,:]))
其中,α为学习速率(learning rate),γ为折扣因子(discount factor)。根据公式可以看出,学习速率α越大,保留之前训练的效果就越少。折扣因子γ越大, m a x ( Q [ S n e x t , : ] ) max(Q[S_{next}, :]) max(Q[Snext,:])所起到的作用就越大。
其中, m a x ( Q [ S n e x t , : ] ) max(Q[S_{next}, :]) max(Q[Snext,:])指以前学习到的新状态下可能得到的最大奖励期望,也就是**记忆中的利益。**如果智能体在过去的游戏中于位置 S n e x t S_{next} Snext的某个动作上吃过甜头(例如选择了某个动作之后获得了100的奖赏),这个公式就可以让它提早地得知这个消息,以便使下回再通过位置S时选择正确的动作继续进入这个吃甜头的位置 S n e x t S_{next} Snext。
但如果下一步就是终点,那么Q值更新公式变为: Q [ S , A ] = ( 1 − α ) × Q [ S , A ] + α × R Q[S,A]=(1-\alpha)\times Q[S,A]+\alpha\times R Q[S,A]=(1−α)×Q[S,A]+α×R
其中,减少了最后一项 γ × m a x ( Q [ S n e x t , : ] ) \gamma\times max(Q[S_{next}, :]) γ×max(Q[Snext,:]),这是因为当下一个状态就是最终目标时我们不需要知道下个状态在未来可能的收益,因为下个状态就可以得到游戏结束的即时收益。一般Q值更新公式之所以多了这一步也正是因为对于下一步不是终点的状态,这一步的奖励R一般来说是0或-1,拿不到即时奖励,但是又需要记录该节点的该操作在未来的可能贡献。
1.2 ε-greedy随机方法
前面提到我们为了跳出局部最优解,尽量得到全局最优解,我们采用的方法为ε-greedy方法:每个状态以ε(epsilon 探索速率)的概率进行探索(Exploration),此时将随机选取动作,而剩下的1-ε的概率则进行利用(Exploitation),即选取当前状态下效用值较大的动作。
2. 算法实现
2.1 算法简要流程
算法流程:
初始化 Q = {};
while Q 未收敛:初始化智能体的位置S,开始新一轮游戏while S != 终结状态:使用策略π,获得动作a=π(S) 使用动作a进行游戏,获得智能体的新位置S',与奖励R(S,a)Q[S,A] ← (1-α)*Q[S,A] + α*(R(S,a) + γ* max Q[S',a]) // 更新QS ← S'
2.2 游戏场景
假设机器人必须越过迷宫并到达终点。有地雷,机器人一次只能移动一个地砖。如果机器人踏上矿井,机器人就死了。机器人必须在尽可能短的时间内到达终点。
得分/奖励系统如下:
-
机器人在每一步都失去1点。这样做是为了使机器人采用最短路径并尽可能快地到达目标。
-
如果机器人踩到地雷,则点损失为100并且游戏结束。
-
如果机器人获得动力⚡️,它会获得1点。
-
如果机器人达到最终目标,则机器人获得100分。
现在,显而易见的问题是:我们如何训练机器人以最短的路径到达最终目标而不踩矿井?
2.3 算法实现
- 超参数设置
np.random.seed(2) # 确保结果可复现
row = 5 # 游戏表格行数
col = 6 # 游戏表格列数
ACTIONS = ['up', 'right', 'down', 'left'] # 可采取的动作
EPSILON = 0.9 # ε-greedy随机方法中的ε
ALPHA = 0.1 # learning rate
GAMMA = 0.9 # discount factor
MAX_EPISODES = 5000 # 游戏共学多少轮
targetXY = [4, 4] # 游戏的最终目标位置
env_list = ['--+---', '-*--*-', '--+--+', '*--*--', '-+--T-'] # 游戏地图
- 初始化Q值表
def build_q_table(row, col, actions):table = pd.DataFrame(np.zeros((row * col, len(actions))), # q_table initial valuescolumns=actions, # actions' name)# print(table) # show tablereturn table
- 选择动作A
def choose_action(state, q_table): # ε-greedy随机方法# This is how to choose an actionstate_actions = q_table.iloc[state[0] * col + state[1], :]if (np.random.uniform() > EPSILON) or ((state_actions == 0).all()): # act non-greedy or state-action have no valueaction_name = np.random.choice(ACTIONS)else: # act greedyaction_name = state_actions.idxmax()# replace argmax to idxmax as argmax means a different function in newer version of pandasreturn action_name
- 状态S下采取动作A到达新位置A’,并得到奖励/惩罚
def getR(S):str = env_list[S[0]][S[1]]if str == '-':return -1elif str == '*':return -100elif str == '+':return 1else:return 100def get_env_feedback(S, A):# This is how agent will interact with the environmentif A == 'up': # move upif S[0] == targetXY[0]+1 and S[1] == targetXY[1]: # 到达终点S_ = 'terminal'R = 100elif S[0] == 0: # 向上碰壁S_ = SR = -1else: # 正常移动S_ = [S[0] - 1, S[1]]R = getR(S_)if R == -100: # 碰到炸弹直接结束S_ = 'terminal'elif A == 'right': # move rightif S[0] == targetXY[0] and S[1] == targetXY[1] - 1: # 到达终点S_ = 'terminal'R = 100elif S[1] == col - 1: # 向右碰壁S_ = SR = -1else: # 正常移动S_ = [S[0], S[1] + 1]R = getR(S_)if R == -100: # 碰到炸弹直接结束S_ = 'terminal'elif A == 'down': # move downif S[0] == row - 1: # 向下碰壁S_ = SR = -1elif S[0] == targetXY[0] - 1 and S[1] == targetXY[1]: # 到达终点S_ = 'terminal'R = 100else: # 正常移动S_ = [S[0] + 1, S[1]]R = getR(S_)if R == -100: # 碰到炸弹直接结束S_ = 'terminal'else: # move leftif S[0] == targetXY[0] and S[1] == targetXY[1] + 1: # 到达终点S_ = 'terminal'R = 100elif S[1] == 0: # 向左碰壁S_ = SR = -1else: # 正常移动S_ = [S[0], S[1] - 1]R = getR(S_)if R == -100: # 碰到炸弹直接结束S_ = 'terminal'return S_, R
- 更新Q值表
q_predict = q_table.loc[S[0] * col + S[1], A] # 当前位置的动作价值
if S_ != 'terminal': # next state is not terminalq_target = R + GAMMA * q_table.iloc[S_[0] * col + S_[1], :].max()
else:q_target = R # next state is terminalis_terminated = True # terminate this episode
# 当前位置的动作价值+新位置的状态价值
q_table.loc[S[0] * col + S[1], A] = (1 - ALPHA) * q_predict + ALPHA * q_target
S = S_ # move to next state
3. 参考文章
-
强化学习入门:基本思想和经典算法
-
Q-learning 的具体过程
-
【强化学习】Q-Learning算法详解以及Python实现【80行代码】