强化学习算法DQN实现

DQN的基本思想

  1. Q学习:Q学习是一种基于值函数的强化学习方法,目的是通过学习状态-动作值函数Q(s, a)来指导智能体的动作选择。Q函数表示在状态s采取动作a后能够获得的期望总回报。

  2. 深度神经网络:使用深度神经网络来近似Q函数。输入是状态s,输出是每个动作的Q值。神经网络的参数通过与目标Q值的均方误差(MSE)损失函数进行反向传播来更新。

  3. 经验回放:经验回放机制用于解决样本之间的相关性问题。通过存储智能体的经验(状态,动作,奖励,下一个状态,是否终止)到回放池中,并从中随机抽取小批量样本进行训练,打破了样本之间的相关性,提高了样本利用效率。

  4. 目标网络:为了增强训练的稳定性,DQN引入了目标网络。目标网络的结构和参数与Q网络相同,但参数更新频率较低。目标Q值使用目标网络来计算,避免了训练过程中参数震荡的问题。

详细的训练过程

  1. 初始化:初始化Q网络和目标网络,设置超参数和经验回放池。

  2. 交互环境:在每一回合中,智能体根据当前策略与环境进行交互,选择动作并获得奖励,存储经验到回放池中。

  3. 经验采样:当回放池中的经验数量足够时,从中随机抽取一个小批量样本用于训练。

  4. 计算目标Q值:使用目标网络计算目标Q值,对于每个样本,目标Q值等于即时奖励加上下一状态的最大Q值乘以折扣因子。

  5. 更新Q网络:通过最小化预测Q值和目标Q值之间的均方误差来更新Q网络的参数。

  6. 更新目标网络:每隔一段时间,将Q网络的参数复制到目标网络中。

  7. 探索与利用:采用ε-greedy策略选择动作,即以ε的概率随机选择动作,以1-ε的概率选择当前Q网络认为最优的动作。随着训练的进行,ε逐渐减小,以增加利用率。

  8. 训练结束:在达到设定的回合数后,结束训练过程。

import gym
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers# 环境设置
env = gym.make('CartPole-v1')# 超参数设置
gamma = 0.99  # 折扣因子
epsilon = 1.0  # 探索率
epsilon_min = 0.01  # 最小探索率
epsilon_decay = 0.995  # 探索率衰减
learning_rate = 0.001  # 学习率
batch_size = 64  # 批量大小
memory_size = 2000  # 经验回放池大小# 构建Q网络
def build_model(state_shape, action_size):model = tf.keras.Sequential()model.add(layers.Dense(24, input_shape=state_shape, activation='relu'))model.add(layers.Dense(24, activation='relu'))model.add(layers.Dense(action_size, activation='linear'))model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(lr=learning_rate))return model# 经验回放池
class ReplayMemory:def __init__(self, max_size):self.buffer = []self.max_size = max_sizedef add(self, experience):if len(self.buffer) >= self.max_size:self.buffer.pop(0)self.buffer.append(experience)def sample(self, batch_size):idx = np.random.choice(len(self.buffer), size=batch_size, replace=False)return [self.buffer[i] for i in idx]# 训练DQN
def train_dqn(episodes):state_size = env.observation_space.shape[0]action_size = env.action_space.nmodel = build_model((state_size,), action_size)target_model = build_model((state_size,), action_size)target_model.set_weights(model.get_weights())memory = ReplayMemory(memory_size)for episode in range(episodes):state = env.reset()state = np.reshape(state, [1, state_size])total_reward = 0while True:if np.random.rand() <= epsilon:action = np.random.choice(action_size)else:q_values = model.predict(state)action = np.argmax(q_values[0])next_state, reward, done, _ = env.step(action)next_state = np.reshape(next_state, [1, state_size])total_reward += rewardmemory.add((state, action, reward, next_state, done))state = next_stateif done:print(f"Episode: {episode + 1}, Reward: {total_reward}, Epsilon: {epsilon:.2f}")breakif len(memory.buffer) >= batch_size:experiences = memory.sample(batch_size)states, actions, rewards, next_states, dones = zip(*experiences)states = np.vstack(states)next_states = np.vstack(next_states)q_values = model.predict_on_batch(states)q_values_next = target_model.predict_on_batch(next_states)for i in range(batch_size):q_values[i][actions[i]] = rewards[i] + (1 - dones[i]) * gamma * np.amax(q_values_next[i])model.train_on_batch(states, q_values)if epsilon > epsilon_min:epsilon *= epsilon_decayif (episode + 1) % 10 == 0:target_model.set_weights(model.get_weights())train_dqn(500)
env.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/46739.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Ngix】快速上手,由浅入深

内容概述 1、nginx 简介 &#xff08;1&#xff09;介绍 nginx 的应用场景和具体可以做什么事情 &#xff08;2&#xff09;介绍什么是反向代理 &#xff08;3&#xff09;介绍什么是负载均衡 &#xff08;4&#xff09;介绍什么是动静分离 2、nginx 安装 &#xff08;1…

Leetcode - 周赛406

目录 一&#xff0c;3216. 交换后字典序最小的字符串 二&#xff0c;3217. 从链表中移除在数组中存在的节点 三&#xff0c;3218. 切蛋糕的最小总开销 I 四&#xff0c;3219. 切蛋糕的最小总开销 II 一&#xff0c;3216. 交换后字典序最小的字符串 本题要求交换一次相邻字符…

C++编程逻辑讲解step by step:静态数组长度后确定还能编译成功

定义 定义一维数组的一般格式为 类型标识符 数组名&#xff3b;常量表达式&#xff3d;&#xff1b; 例如&#xff1a; int a&#xff3b;10&#xff3d;; 问题 很多人写成这样&#xff1a; int n; cin>>n; int a[n]; 这个写法已经明确&#xff0c;是错的&…

新建vue项目和安装第三方库

安装vue 打开vscode编辑器&#xff0c;按Ctrl组合键打开终端&#xff0c;在命令行中运行以下命令 npm create vuelatest项目初始化完成&#xff0c;可执行以下命令&#xff1a; cd vue-project --切换到项目目录 npm install -- 安装依赖包 npm run dev -- 运行项目安装 …

大数据架构对比记录

Lambda架构 -维护两套项目&#xff0c;开发和维护成本高 -两套链路&#xff0c;数据容易不一致 -数据计算成本大&#xff08;例如原定每小时计算一次&#xff0c;但有额外新需求需要计算两点半-三点半之间数据&#xff0c;则需要重新计算&#xff09; Kappa -过于依赖kafka消…

FPGA:基于复旦微FMQL10S400 /FMQL20S400 国产化核心板

复旦微电子是国内集成电路设计行业的领军企业之一&#xff0c;早在2000年就在香港创业板上市&#xff0c;成为行业内首家上市公司。公司的RFID芯片、智能卡芯片、EEPROM、智能电表MCU等多种产品在市场上的占有率位居行业前列。 今天介绍的是搭载复旦微 FMQL10S400/FMQL20S400的…

嵌入式Linux应用开发基础-现有动态库so的使用

前言 最近做嵌入式Linux项目&#xff0c;需要调用客户提供的现成的动态库(so文件&#xff0c;包含对应头文件)&#xff0c;我这边用的是cmake来构建。 此篇文章主要是记录一下嵌入式Linux的动态库的使用&#xff0c;与君共勉&#xff01; 一、通过cmake使用so库和对应的头文件…

01数据结构 - 顺序表

这里是只讲干货不讲废话的炽念&#xff0c;这个系列的文章是为了我自己以后复习数据结构而写&#xff0c;所以可能会用一种我自己能够听懂的方式来描述&#xff0c;不会像书本上那么枯燥和无聊&#xff0c;且全系列的代码均是可运行的代码&#xff0c;关键地方会给出注释^_^ 全…

C++客户端Qt开发——常用控件(容器类控件)

6.容器类控件 ①GroupBox 带标题分组框 属性 说明 title 分组框的标题 alignment 分组框内部内容的对齐方式 flat 是否是"扁平"模式 checkable 是否可选择 设为true,则在title前方会多出一个可勾选的部分. check 描述分组框的选择状态&#xff08;前提…

数据结构(5.1)——树的性质

结点数总度数1 结点的度——结点有几个孩子(分支) 度为m的树、m叉树的区别 度为m的树第i层至多有 个结点(i>1) 高度为h的m叉树至多有 个结点 高度为h的m叉树至少有h个结点 、高度为h&#xff0c;度为m叉树至多有hm-1个结点 具有n个结点的m叉树的最小高度为 总结

通过角点进行水果的果梗检测一种新方法

一、前言 在前面的《数字图像处理与机器视觉》案例一&#xff08;库尔勒香梨果梗提取和测量&#xff09;中主要使用数学形态学的方法进行果梗提取&#xff0c;下面给出一种提取果梗的新思路。 众所周知&#xff0c;一般果梗和果实在边缘处角度有较大突变&#xff0c;可以通过合…

探索WebKit的CSS列表与标记:美化列表的艺术

探索WebKit的CSS列表与标记&#xff1a;美化列表的艺术 CSS列表和标记是网页设计中用于增强列表展示效果的重要工具。WebKit&#xff0c;作为多种现代浏览器的内核&#xff0c;包括Safari、QQ浏览器等&#xff0c;提供了对CSS列表和标记的广泛支持。本文将深入探讨WebKit对CSS…

spring security源码追踪理解(一)

一、前言 近期看了spring security相关的介绍&#xff0c;再加上项目所用若依框架的底层安全模块也是spring security&#xff0c;所以想从源码的角度加深下对该安全模块的理解&#xff08;看源码之前&#xff0c;我们要先有个意识&#xff0c;那就是spring security安全模块主…

Solus Linux简介

以下是学习笔记&#xff0c;具体详实的内容请参考官网&#xff1a;Home | Solus Solus Linux 是一个独立的 Linux 发行版&#xff0c;它以其现代的设计、优化的性能和友好的用户体验而著称。以下是一些关于 Solus Linux 的最新动向和特点&#xff1a; 1. **最新版本发布**&a…

第122天:内网安全-域信息收集应用网络凭据CS 插件AdfindBloodHound

目录 前置知识 背景和思路 判断是否在域内 案例一&#xff1a;架构信息类收集-网络&用户&域控等 案例二&#xff1a;自动化工具探针-插件&Adfind&BloodHound Adfind(域信息收集工具) ​BloodHound&#xff08;自动化域渗透工具&#xff09; 前置知识 本…

计算机视觉10 总结

全卷积网络&#xff08;FCN&#xff09;是计算机视觉中用于处理图像任务的重要网络架构。 核心要点&#xff1a; 与传统 CNN 不同&#xff0c;FCN 将最后的全连接层替换为卷积层&#xff0c;从而能够处理任意尺寸的输入图像&#xff0c;并保留了空间信息。优点包括可处理不同大…

java基础万字笔记

前言 此篇文章为本人在初学java时所记录的java基础的笔记&#xff0c;其中全面记录了java的基础知识点以及自己的一些理解和要注意的点。由于该笔记是边学边记录而成&#xff0c;所以基本很多模块内都会有一些我本人后期记录的知识穿插进去&#xff0c;导致一些模块内的内容并…

搭建个人智能家居 7 - 空气颗粒物检测

搭建个人智能家居 7 - 空气颗粒物检测 前言说明PMS5003ESPHomeHomeAssistant结束 前言 到目前为止&#xff0c;我们这个智能家居系统添加了4个外设&#xff0c;分别是&#xff1a;LED灯、RGB灯、DHT11温度传感器和SGP30。今天继续添加环境测量类传感器“PMS5003空气颗粒物检测…

Django获取request请求中的参数

支持 post put json_str request.body # 属性获取最原始的请求体数据 json_dict json.loads(json_str)# 将原始数据转成字典格式 json_dict.get("key", "默认值") # 获取数据参考 https://blog.csdn.net/user_san/article/details/109654028

Windows FFmpeg 开发环境搭建

FFmpeg 开发环境搭建 FFmpeg命令行环境搭建使用FFmpeg官方编译的库Windows编译FFmpeg1. 下载[msys2](https://www.msys2.org/#installation)2. 安装完成之后,将安装⽬录下的msys2_shell.cmd中注释掉的 rem set3. 修改pacman 镜像源并安装依赖4. 下载并编译源码 FFmpeg命令行环境…