强化学习(一)——基本概念及DQN

1 基本概念

  • 智能体 agent ,做动作的主体,(大模型中的AI agent)

  • 环境 environment:与智能体交互的对象

  • 状态 state ;当前所处状态,如围棋棋局

  • 动作 action:执行的动作,如围棋可落子点

  • 奖励 reward:执行当前动作得到的奖励,(大模型中的奖励模型)

  • 策略 policy: π ( a ∣ s ) \pi(a|s) π(as) 当前状态如何选择action,如当前棋局,落子每个点的策略

  • 回报(累计奖励) return : 是从当前时刻开始到本回合结束的所有奖励的总和, U t = R t + γ R t + 1 + γ 2 R t + 2 + γ 3 R t + 3 . . . . U_t=R_t+\gamma R_{t+1}+\gamma^2R{t+2}+\gamma^3R{t+3} .... Ut=Rt+γRt+1+γ2Rt+2+γ3Rt+3....

  • 折扣回报 𝛾:

  • 动作价值函数: Q π ( s t , a t ) = E [ U t ∣ S t = s t , A t = a t ] Q_\pi (s_t,a_t)=E[U_t|S_t=s_t,A_t=a_t] Qπ(st,at)=E[UtSt=st,At=at]

  • 最优动作价值函数: Q ∗ ( s t , a t ) = m a x π Q π ( s t , a t ) Q^*(s_t,a_t)=max_\pi Q_\pi(s_t,a_t) Q(st,at)=maxπQπ(st,at)

  • 状态价值函数: V π ( s t ) = E A [ Q π ( s t , A ) ] V_\pi (s_t)=E_A[Q_\pi(s_t,A)] Vπ(st)=EA[Qπ(st,A)]

2 DQN

折扣回报: U t = R t + γ R t + 1 + γ 2 R t + 2 + γ 3 R t + 3 . . . . U_t=R_t+\gamma R_{t+1}+\gamma^2R{t+2}+\gamma^3R{t+3} .... Ut=Rt+γRt+1+γ2Rt+2+γ3Rt+3....
动作价值函数: Q π ( s t , a t ) = E [ U t ∣ S t = s t , A t = a t ] Q_\pi (s_t,a_t)=E[U_t|S_t=s_t,A_t=a_t] Qπ(st,at)=E[UtSt=st,At=at]
最优动作价值函数: Q ∗ ( s t , a t ) = m a x π Q π ( s t , a t ) Q^*(s_t,a_t)=max_\pi Q_\pi(s_t,a_t) Q(st,at)=maxπQπ(st,at)

核心公式:时间差分算法

Q ( s t , a t ; w ) = r t + γ max ⁡ a ∈ A Q ( s t + 1 , a ; w ) Q(s_t,a_t;w)=r_t+\gamma \max _{a\in A}Q(s_{t+1},a;w) Q(st,at;w)=rt+γmaxaAQ(st+1,a;w)
证明:略

公式解读及注意事项:
输入:( s t , a t , r t , s t + 1 s_t,a_t,r_t,s_{t+1} st,at,rt,st+1
左边项 Q ( s t , a t ; w ) Q(s_t,a_t;w) Q(st,at;w) : 是神经网络在t时刻的预测
右边 r t r_t rt是当前奖励值, max ⁡ a ∈ A Q ( s t + 1 , a ; w ) \max _{a\in A}Q(s_{t+1},a;w) maxaAQ(st+1,a;w)
目标:使左右两边误差最小。

DQN 是对最优动作价值函数 Q⋆ 的近似。DQN 的输入是当前状态 st,输出是每个动作的 Q 值。DQN 要求动作空间 A 是离散集合

DQN高估问题:

1 最大化导致高估, 上式中总是取最大值,会导致高估
2 自举导致高估 上式中目标函数也用自己,使用自己估计自己,会导致高估
因此可以对目标函数进行以下改进。

目标函数分析:

Q ( s t , a t ; w ) = r t + γ max ⁡ a ∈ A Q ( s t + 1 , a ; w ) Q(s_t,a_t;w)=r_t+\gamma \max _{a\in A}Q(s_{t+1},a;w) Q(st,at;w)=rt+γmaxaAQ(st+1,a;w)

  • a .左右两边可以使用统一个Q函数
    b. 左右两边使用不同Q函数
    在这里插入图片描述

c. 左右两边使用不同Q函数,且target 的 Q t a r g e t ( s t + 1 , a ; w ) Q_{target}(s_{t+1},a;w) Qtarget(st+1,a;w) 的a 来自第一个函数 max ⁡ a ∈ A Q 1 ( s t + 1 , a ; w ) \max _{a\in A}Q_1(s_{t+1},a;w) maxaAQ1(st+1,a;w)
在这里插入图片描述

  • 高估解决办法:
    b 策略可以减少自举带来的高估
    c 策略一定程度上能减少最大化带来的高估,因为用第一个Q函数中的a,在 Q t a r g e t Q_{target} Qtarget中总是小于等于最大值的 max ⁡ a ∈ A Q t a r g e t ( s t + 1 , a ; w ) \max _{a\in A}Q_{target}(s_{t+1},a;w) maxaAQtarget(st+1,a;w) (DDQN方法)

3 核心代码实现DQN,DDQN

DQN 如下代码,

self.model为Q函数
self.model_target为目标Q函数,
s_batch :当前状态
a_batch:当前执行动作
r_batch: 奖励
d_batch ; 是否游戏结束
next_s_batch; 执行动作a_batch后,到下一个状态

self.model在当前状态s_batch下得到每个状态的Q值,选择a_batch对应的Q值,即为当前Q值
self.target_model 在下一步状态next_s_batch下,取self.target_model 最大值对应到a的值(DDQN,是在self.target_model中取self.model最大值对应a的值)。

def compute_loss(self, s_batch, a_batch, r_batch, d_batch, next_s_batch):# Compute current Q value based on current states and actions.qvals = self.model(s_batch).gather(1, a_batch.unsqueeze(1)).squeeze()# next state的value不参与导数计算,避免不收敛。next_qvals, _ = self.target_model(next_s_batch).detach().max(dim=1)loss = F.mse_loss(r_batch + self.discount * next_qvals * (1 - d_batch), qvals)return loss

DDQN

与上面唯一区别是:使用Q1函数中的a
在这里插入图片描述

    def compute_loss(self, s_batch, a_batch, r_batch, d_batch, next_s_batch):# Compute current Q value based on current states and actions.Q1=self.model(s_batch)qvals =Q1 .gather(1, a_batch.unsqueeze(1)).squeeze()a_target =Q1argmax()# next state的value不参与导数计算,避免不收敛。next_qvals = self.target_model(next_s_batch).detach().gather(1, a_target).squeeze()loss = F.mse_loss(r_batch + self.discount * next_qvals * (1 - d_batch), qvals)return lossdef get_action(self, obs):qvals = self.model(obs)return 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/192690.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C#——Delegate(委托)与Event(事件)

C#——Delegate(委托)与Event(事件) 前言一、Delegate(委托)1.是什么?2.怎么用?Example 1:无输入无返回值Example 2:有输入Example 3:有返回值Exa…

【C#】接口定义和使用知多少

给自己一个目标,然后坚持一段时间,总会有收获和感悟! 最近在封装和参考sdk时,看到一个不错的写法,并且打破自己对接口和实现类固定的观念,这也充分说明自己理解掌握的知识点还不够深。 目录 前言一、什么是…

Kubernetes(K8s)_16_CSI

Kubernetes(K8s)_16_CSI CSICSI实现CSI接口CSI插件 CSI CSI(Container Storage Interface): 实现容器存储的规范 本质: Dynamic Provisioning、Attach/Detach、Mount/Unmount等功能的抽象CSI功能通过3个gRPC暴露服务: IdentityServer、ControllerServe…

C++二维数组名到底代表个啥

题目先导 int a[3][4]; 则对数组元素a[i][j]正确的引用是*(*(ai)j)先翻译一下这个*(*(ai)j),即a后移i解引用,再后移j再解引用,这么看来a就应该是个二维数组,第一层存储行向量,一次解引用获得行向量的地址,…

LLM推理部署(三):一个强大的LLM生态系统GPT4All

GPT4All,这是一个开放源代码的软件生态系,它让每一个人都可以在常规硬件上训练并运行强大且个性化的大型语言模型(LLM)。Nomic AI是此开源生态系的守护者,他们致力于监控所有贡献,以确保质量、安全和可持续…

听GPT 讲Rust源代码--src/tools(6)

File: rust/src/tools/rust-analyzer/crates/ide/src/references.rs 在Rust源代码中,references.rs文件位于rust-analyzer工具的ide模块中,其作用是实现了用于搜索引用的功能。 该文件包含了多个重要的结构体、特质和枚举类型,我将逐一介绍它…

node.js-连接SQLserver数据库

1.在自己的项目JS文件夹中建文件:config.js、mssql.js和server.js以及api文件夹下的user.js 2.在config.js中封装数据库信息 let app {user: sa, //这里写你的数据库的用户名password: ,//这里写数据库的密码server: localhost,database: medicineSystem, // 数据…

OpenSSH 漏洞修复升级最新版本

Centos7系统ssh默认版本一般是OpenSSH7.4左右,低版本是有漏洞的而且是高危漏洞,在软件交付和安全扫描上是过不了关的,一般情况需要升级OpenSSH的最新版本 今天详细说下升级最新版本的处理过程(认真看会发现操作很简单&#xff0c…

Best Rational Approximation ——二分

许多微控制器没有浮点单元,但确实有一个(合理)快速整数除法单元。在这些情况下,使用有理值来近似浮点常数可能是值得的. 例如,355/113 3.1415929203539823008849557522124 是 π 3.14159265358979323846 一个很好的近…

【教学类-06-12】20231202 0-9数字分合-房屋样式(一)-下右空-升序-抽7题

作品展示-屋顶分合(0-9之间随机抽取7个不重复分合) 背景需求: 大班幼儿学分合题,通常区角里会设计一个“房屋分合”的样式 根据这种房屋样式,设计0-9内的升序分合题模板 素材准备 WORD样式 代码展示: 2-9…

PlantUML语法(全)及使用教程-用例图

目录 1. 用例图1.1、什么是用例图1.2、用例图的构成1.3、参与者1.4、用例1.4.1、用例基本概念1.4.2、用例的识别1.4.3、用例的要点1.4.3、用例的命名1.4.4、用例的粒度 1.5、应用示例1.5.1、用例1.5.2、角色1.5.3、改变角色的样式1.5.4、用例描述1.5.5、改变箭头方向1.5.6、使用…

AI创作ChatGPT源码+AI绘画(Midjourney绘画)+DALL-E3文生图+思维导图生成

一、AI创作系统 SparkAi创作系统是基于ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统,支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如何搭建部署AI…

C语言——指针(四)

📝前言: 上篇文章C语言——指针(三)对指针和数组进行了讲解,今天主要更深入的讲解一下不同类型指针变量的特点: 1,字符指针变量 2,数组指针变量 3,函数指针变量 &#x1…

Spring boot命令执行 (CVE-2022-22947)漏洞复现和相关利用工具

Spring boot命令执行 (CVE-2022-22947)漏洞复现和相关利用工具 名称: spring 命令执行 (CVE-2022-22947) 描述: Spring Cloud Gateway是Spring中的一个API网关。其3.1.0及3.0.6版本(包含)以前存在一处SpEL表达式注入漏洞,当攻击者可以访问A…

2022年8月2日 Go生态洞察:Go 1.19版本发布深度解析

🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🦄 博客首页——🐅🐾猫头虎的博客🎐 🐳 《面试题大全专栏》 🦕 文章图文…

6-63.圆类的定义与使用(拷贝构造函数)

本题要求完成一个圆类的定义,设计适当的函数:包括构造函数、拷贝构造函数以及析构函数,从而可以通过测试程序输出样例 在这里给出一组输入。例如: 5 输出样例: 在这里给出相应的输出。例如: Constructo…

本项目基于Spring boot的AMQP模块,整合流行的开源消息队列中间件rabbitMQ,实现一个向rabbitMQ

在业务逻辑的异步处理,系统解耦,分布式通信以及控制高并发的场景下,消息队列有着广泛的应用。本项目基于Spring的AMQP模块,整合流行的开源消息队列中间件rabbitMQ,实现一个向rabbitMQ添加和读取消息的功能。并比较了两种模式&…

osg LOD节点动态调度

1、LOD节点 LOD(level of detail):是指根据物体模型的结点在显示环境中所处的位置和重要度,决定物体渲染的资源分配,降低非重要物体的面数和细节度,从而获得高效率的渲染运算。在OSG的场景结点组织结构中&…

mongoose学习记录

mongoose安装和连接数据库 npm i mongoose导入mongoose const mongoose require(mongoose) mongoose.set("strictQuery",true)连接数据库 mongoose.connect(mongodb:127.0.0.1:27017/test)设置回调 mongoose.connection.on(open,()>{console.log("连接成…

规则引擎专题---3、Drools组成和入门

Drools概述 drools是一款由JBoss组织提供的基于Java语言开发的开源规则引擎,可以将复杂且多变的业务规则从硬编码中解放出来,以规则脚本的形式存放在文件或特定的存储介质中(例如存放在数据库中),使得业务规则的变更不需要修改项目代码、重启…