强化学习------Sarsa算法

简介

SARSA(State-Action-Reward-State-Action)是一个学习马尔可夫决策过程策略的算法,通常应用于机器学习和强化学习学习领域中。它由RummeryNiranjan在技术论文“Modified Connectionist Q-Learning(MCQL)” 中介绍了这个算法,并且由Rich Sutton在注脚处提到了SARSA这个别名。
State-Action-Reward-State-Action这个名称清楚地反应了其学习更新函数依赖的5个值,分别是当前状态S1,当前状态选中的动作A1,获得的奖励RewardS1状态下执行A1后取得的状态S2S2状态下将会执行的动作A2。我们取这5个值的首字母串起来可以得出一个词SARSA

算法的核心思想可以简化为:

Latex代码:
用伪代码可以表示为:
在这里插入图片描述

算法实战

我们使用openAI的gym中的CliffWalking-v0作为环境

#!/usr/bin/env python 
# -*- coding:utf-8 -*-
import numpy as np
import gym
import time
import gridworld#Sarsa算法
class Sarsa():def __init__(self,num_states,num_actions,e_greed=0.1,lr=0.9,gamma=0.8):#建立Q表格self.Q = np.zeros((num_states,num_actions))self.e_greed = e_greed   #探索概率self.num_states = num_statesself.num_actions = num_actionsself.lr = lr   #学习率self.gamma = gamma #折扣因子def predict(self,state):"""通过当前状态预测下一个动作:param state::return:"""#获取当前状态的所有动作的切片Q_list = self.Q[state,:]#随机选取其中最大值中的某一个(防止存在多个最大值时,总是选最前面的问题)action = np.random.choice(np.flatnonzero(Q_list == Q_list.max()))return  actiondef action(self,state):"""选取动作:param state::return:"""#探索,随机选择一个动作if np.random.uniform(0,1) < self.e_greed:action = np.random.choice(self.num_actions)else:   #直接选取最大Q值的动作action = self.predict(state)return actiondef learn(self,state,action,reward,next_state,next_action,done):cur_Q = self.Q[state,action]# 当游戏结束时,不存在next_action和next_stateif done:target_Q = rewardelse:target_Q = reward + self.gamma*self.Q[next_state,next_action]self.Q[state,action] += self.lr*(target_Q - cur_Q)#训练
def train_episode(env,agent,is_render):total_reward = 0#初始化环境state,_ = env.reset()action = agent.action(state)while True:#执行动作返回结果next_state,reward,done,_,_ = env.step(action)#根据状态获取动作next_action = agent.action(next_state)#更新参数agent.learn(state,action,reward,next_state,next_action,done)#循环执行action = next_actionstate = next_statetotal_reward += rewardif is_render:env.render()if done:breakreturn  total_reward
#测试
def test_episode(env,agent,is_render=False):total_reward = 0# 初始化环境state,_ = env.reset()while True:action = agent.predict(state)next_state, reward, done, _,_ = env.step(action)state = next_statetotal_reward += rewardenv.render()time.sleep(0.5)if done:breakreturn total_reward
#训练
def train(env,episodes=500,lr=0.1,gamma=0.9,e_greed=0.1):agent = Sarsa(num_states = env.observation_space.n,num_actions = env.action_space.n,lr = lr,gamma = gamma,e_greed = e_greed)is_render = False#先训练episodes次for e in range(episodes):ep_reward = train_episode(env,agent,is_render)print('Episode %s : reward= %.1f'%(e,ep_reward))#每执行50轮就显示一次if e%50 == 0:is_render = Trueelse:is_render = False#训练结束后,我i们测试模型test_reward = test_episode(env,agent)print('test_reward= %.1f' % (test_reward))if __name__ == '__main__':env = gym.make("CliffWalking-v0")env = gridworld.CliffWalkingWapper(env)train(env)

运行效果

在这里插入图片描述

另附工具类

#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.# -*- coding: utf-8 -*-import gym
import turtle
import numpy as np# turtle tutorial : https://docs.python.org/3.3/library/turtle.htmldef GridWorld(gridmap=None, is_slippery=False):if gridmap is None:gridmap = ['SFFF', 'FHFH', 'FFFH', 'HFFG']env = gym.make("FrozenLake-v0", desc=gridmap, is_slippery=False)env = FrozenLakeWapper(env)return envclass FrozenLakeWapper(gym.Wrapper):def __init__(self, env):gym.Wrapper.__init__(self, env)self.max_y = env.desc.shape[0]self.max_x = env.desc.shape[1]self.t = Noneself.unit = 50def draw_box(self, x, y, fillcolor='', line_color='gray'):self.t.up()self.t.goto(x * self.unit, y * self.unit)self.t.color(line_color)self.t.fillcolor(fillcolor)self.t.setheading(90)self.t.down()self.t.begin_fill()for _ in range(4):self.t.forward(self.unit)self.t.right(90)self.t.end_fill()def move_player(self, x, y):self.t.up()self.t.setheading(90)self.t.fillcolor('red')self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)def render(self):if self.t == None:self.t = turtle.Turtle()self.wn = turtle.Screen()self.wn.setup(self.unit * self.max_x + 100,self.unit * self.max_y + 100)self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,self.unit * self.max_y)self.t.shape('circle')self.t.width(2)self.t.speed(0)self.t.color('gray')for i in range(self.desc.shape[0]):for j in range(self.desc.shape[1]):x = jy = self.max_y - 1 - iif self.desc[i][j] == b'S':  # Startself.draw_box(x, y, 'white')elif self.desc[i][j] == b'F':  # Frozen iceself.draw_box(x, y, 'white')elif self.desc[i][j] == b'G':  # Goalself.draw_box(x, y, 'yellow')elif self.desc[i][j] == b'H':  # Holeself.draw_box(x, y, 'black')else:self.draw_box(x, y, 'white')self.t.shape('turtle')x_pos = self.s % self.max_xy_pos = self.max_y - 1 - int(self.s / self.max_x)self.move_player(x_pos, y_pos)class CliffWalkingWapper(gym.Wrapper):def __init__(self, env):gym.Wrapper.__init__(self, env)self.t = Noneself.unit = 50self.max_x = 12self.max_y = 4def draw_x_line(self, y, x0, x1, color='gray'):assert x1 > x0self.t.color(color)self.t.setheading(0)self.t.up()self.t.goto(x0, y)self.t.down()self.t.forward(x1 - x0)def draw_y_line(self, x, y0, y1, color='gray'):assert y1 > y0self.t.color(color)self.t.setheading(90)self.t.up()self.t.goto(x, y0)self.t.down()self.t.forward(y1 - y0)def draw_box(self, x, y, fillcolor='', line_color='gray'):self.t.up()self.t.goto(x * self.unit, y * self.unit)self.t.color(line_color)self.t.fillcolor(fillcolor)self.t.setheading(90)self.t.down()self.t.begin_fill()for i in range(4):self.t.forward(self.unit)self.t.right(90)self.t.end_fill()def move_player(self, x, y):self.t.up()self.t.setheading(90)self.t.fillcolor('red')self.t.goto((x + 0.5) * self.unit, (y + 0.5) * self.unit)def render(self):if self.t == None:self.t = turtle.Turtle()self.wn = turtle.Screen()self.wn.setup(self.unit * self.max_x + 100,self.unit * self.max_y + 100)self.wn.setworldcoordinates(0, 0, self.unit * self.max_x,self.unit * self.max_y)self.t.shape('circle')self.t.width(2)self.t.speed(0)self.t.color('gray')for _ in range(2):self.t.forward(self.max_x * self.unit)self.t.left(90)self.t.forward(self.max_y * self.unit)self.t.left(90)for i in range(1, self.max_y):self.draw_x_line(y=i * self.unit, x0=0, x1=self.max_x * self.unit)for i in range(1, self.max_x):self.draw_y_line(x=i * self.unit, y0=0, y1=self.max_y * self.unit)for i in range(1, self.max_x - 1):self.draw_box(i, 0, 'black')self.draw_box(self.max_x - 1, 0, 'yellow')self.t.shape('turtle')x_pos = self.s % self.max_xy_pos = self.max_y - 1 - int(self.s / self.max_x)self.move_player(x_pos, y_pos)if __name__ == '__main__':# 环境1:FrozenLake, 可以配置冰面是否是滑的# 0 left, 1 down, 2 right, 3 upenv = gym.make("FrozenLake-v0", is_slippery=False)env = FrozenLakeWapper(env)# 环境2:CliffWalking, 悬崖环境# env = gym.make("CliffWalking-v0")  # 0 up, 1 right, 2 down, 3 left# env = CliffWalkingWapper(env)# 环境3:自定义格子世界,可以配置地图, S为出发点Start, F为平地Floor, H为洞Hole, G为出口目标Goal# gridmap = [#         'SFFF',#         'FHFF',#         'FFFF',#         'HFGF' ]# env = GridWorld(gridmap)env.reset()for step in range(10):action = np.random.randint(0, 4)obs, reward, done, info = env.step(action)print('step {}: action {}, obs {}, reward {}, done {}, info {}'.format(\step, action, obs, reward, done, info))env.render()  # 渲染一帧图像

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/96771.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringTask ----定时任务框架 ----苍穹外卖day10

目录 SpringTask 需求分析 快速入门 使用步骤 ​编辑业务开发 SpringTask 定时任务场景特化的框架 需求分析 快速入门 使用cron表达式来使用该框架 使用步骤 添加注解 自定义定时任务类 重点在于以下cron表达式的书写,精确表达触发的间隔 业务开发 主task方法 time使用(-…

数据结构:二叉树(超详解析)

目录​​​​​​​ 1.树概念及结构 1.1树的概念 1.2树的相关概念 1.3树的表示 1.3.1孩子兄弟表示法&#xff1a; 1.3.2双亲表示法&#xff1a;只存储双亲的下标或指针 两节点不在同一树上&#xff1a; 2.二叉树概念及结构 2.1.概念 2.2.特殊的二叉树&#xff1a; 2…

【C++设计模式之组合模式:结构型】分析及示例

简介 组合模式是一种结构型设计模式&#xff0c;它能够将对象组合成树形结构以表示“整体-部分”的层次结构&#xff0c;并且能够使用相同的方式处理单个对象和组合对象。组合模式使得客户端可以一致地处理单个对象和组合对象&#xff0c;无需关心具体的对象类型。 组合模式将对…

企业想过等保,其中2FA双因素认证手段必不可少

随着信息技术的飞速发展&#xff0c;网络安全问题日益凸显。等保2.0时代的到来&#xff0c;意味着企业和组织需要更加严格地保护自身的信息安全。而在这个过程中&#xff0c;双因素认证的重要性逐渐得到广泛认可。本文将探讨 2FA 双因素认证的重要性。 在了解 2FA 双因素认证的…

2023-IDEA插件推荐

CamelCase 链接 https://plugins.jetbrains.com/plugin/7160-camelcase https://github.com/netnexus/camelcaseplugin 介绍 提供下划线、驼峰等代码风格的切换。快捷键是⇧ ⌥ U / Shift Alt U GsonFormatPlus 链接 https://plugins.jetbrains.com/plugin/14949-gs…

2023/10/7 -- ARM

【程序状态寄存器读写指令】 1.指令码以及格式 mrs:读取CPSR寄存器的值 mrs 目标寄存器 CPSR&#xff1a;读取CPSR的数值保存到目标寄存器中msr:修改CPSR寄存器的数值msr CPSR,第一操作数:将第一操作数的数值保存到CPSR寄存器中//修改CPSR寄存器&#xff0c;也就表示程序的状…

从哈希表到红黑树:探讨 epoll 是如何管理事件的?

一、引言 在计算机领域&#xff0c;事件通知是一种重要的机制&#xff0c;用于监视和响应各种事件&#xff0c;例如网络连接、文件IO、定时器等。随着计算机应用变得越来越复杂&#xff0c;对于高性能事件通知机制的需求也越来越迫切。传统的事件通知机制可能存在效率低下的问…

Excel·VBA使用ADO读取工作簿工作表数据

目录 查询遍历写入数组查询整体写入数组查询工作簿所有工作表名称查询工作簿所有工作表数据 不打开工作簿读取数据&#xff0c;以下举例都为《ExcelVBA合并工作簿》中 7&#xff0c;合并子文件夹同名工作簿中同名工作表&#xff0c;纵向汇总数据所举例的工作簿&#xff0c;使用…

Angular学习笔记:路由

本文是自己的学习笔记&#xff0c;主要参考资料如下。 - B站《Angular全套实战教程》&#xff0c;达内官方账号制作&#xff0c;https://www.bilibili.com/video/BV1i741157Fj?https://www.bilibili.com/video/BV1R54y1J75g/?p32&vd_sourceab2511a81f5c634b6416d4cc1067…

Vue.js3学习篇--Vue模板应用

目录 一,模板基础 1.模板插值 &#xff08;1&#xff09;基础插值 &#xff08;2&#xff09;HTML代码插值 &#xff08;3&#xff09;标签属性插值 2.模板指令 &#xff08;1&#xff09;定义 &#xff08;2&#xff09;指令参数 二.条件渲染 1.使用v-if指令渲染 2.使…

【网络安全 --- 工具安装】Centos 7 详细安装过程及xshell,FTP等工具的安装(提供资源)

VMware虚拟机的安装教程如下&#xff0c;如没有安装&#xff0c;可以参考这篇博客安装&#xff08;提供资源&#xff09; 【网络安全 --- 工具安装】VMware 16.0 详细安装过程&#xff08;提供资源&#xff09;-CSDN博客【网络安全 --- 工具安装】VMware 16.0 详细安装过程&am…

告警繁杂迷人眼,多源分析见月明

随着数字化浪潮的蓬勃兴起&#xff0c;网络安全问题日趋凸显&#xff0c;面对指数级增长的威胁和告警&#xff0c;传统的安全防御往往力不从心。网内业务逻辑不规范、安全设备技术不成熟都会导致安全设备触发告警。如何在海量众多安全告警中识别出真正的网络安全攻击事件成为安…

数据结构(2-5~2-8)

2-5编写算法&#xff0c;在单链表中查找第一值为x的结点&#xff0c;并输出其前驱和后继的存储位置 #include<stdio.h> #include<stdlib.h>typedef int DataType; struct Node {DataType data; struct Node* next; }; typedef struct Node *PNode; …

Pikachu靶场——远程命令执行漏洞(RCE)

文章目录 1. RCE1.1 exec "ping"1.1.1 源代码分析1.1.2 漏洞防御 1.2 exec "eval"1.2.1 源代码分析1.2.2 漏洞防御 1.3 RCE 漏洞防御 1. RCE RCE(remote command/code execute)概述&#xff1a; RCE漏洞&#xff0c;可以让攻击者直接向后台服务器远程注入…

接口测试总结

一、了解一下HTTP与RPC 1. HTTP&#xff08;HyperText Transfer Protocol) 说明&#xff1a;超文本传输协议&#xff0c;是互联网上应用最为广泛的一种网络协议。 优点&#xff1a;就是简单、直接、开发方便&#xff0c;利用现成的http协议进行传输。 流程图&#xff1a; 2. R…

【QT5-程序控制电源-RS232-SCPI协议-上位机-基础样例【1】】

【QT5-程序控制电源-RS232-SCPI协议-上位机-基础样例【1】】 1、前言2、实验环境3、自我总结1、基础了解仪器控制-熟悉仪器2、连接SCPI协议3、选择控制方式-程控方式-RS2324、代码编写 4、熟悉协议-SCPI协议5、测试实验-测试指令&#xff08;1&#xff09;硬件连接&#xff08;…

课题学习(三)----倾角和方位角的动态测量方法(基于陀螺仪的测量系统)

一、内容介绍 该测量系统基于三轴加速度和三轴陀螺仪&#xff0c;安装在钻柱内部&#xff0c;随钻柱一起旋转&#xff0c;形成捷联惯性导航系统&#xff0c;安装如下图所示&#xff1a;   假设三轴加速度和陀螺仪的输出为: f b [ f x f y f z ] T f^b\begin{bmatrix}f_{x} …

Docker 安装 MongoDB

一、什么是MongoDB MongoDB 是一个基于分布式文件存储的数据库。是一个介于关系数据库和非关系数据库之间的产品&#xff0c;是非关系数据库当中功能最丰富&#xff0c;最像关系数据库的。 二、MongoDB的安装 这里使用docker来安装MongoD 1.docker 拉取mysql镜像 docker pu…

论文笔记:Contrastive Trajectory Similarity Learning withDual-Feature Attention

ICDE 2023 1 intro 1.1 背景 轨迹相似性&#xff0c;可以分为两类 启发式度量 根据手工制定的规则&#xff0c;找到两条轨迹之间基于点的匹配学习式度量 通过计算轨迹嵌入之间的距离来预测相似性值上述两种度量的挑战&#xff1a; 无效性&#xff1a; 具有不同采样率或含有噪…

vue模版语法-{{}}/v-text/v-html/v-once

一、{{}}双括号&#xff1a;用于文本渲染 1、 {{变量名}}:data中返回对象的变量名 2、{{js表达式}}:可以直接进行js表达式处理 3、注意&#xff1a;双大括号中不要写等式书写 二、v-text 指令&#xff0c;用于文本渲染 1、为了解决双大括号渲染数据出现闪烁问题 三、v-cloak …