强化学习------DQN算法

简介

DQN,即深度Q网络(Deep Q-network),是指基于深度学习的Q-Learing算法。Q-Learing算法维护一个Q-table,使用表格存储每个状态s下采取动作a获得的奖励,即状态-价值函数Q(s,a),这种算法存在很大的局限性。在现实中很多情况下,强化学习任务所面临的状态空间是连续的,存在无穷多个状态,这种情况就不能再使用表格的方式存储价值函数。
为了解决这个问题,我们可以用一个函数Q(s,a;w)来近似动作-价值Q(s,a),称为价值函数近似Value Function Approximation,我们用神经网络来生成这个函数Q(s,a;w),称为Q网络(Deep Q-network),w是神经网络训练的参数。

Q-Learning参考:https://blog.csdn.net/niulinbiao/article/details/133659036

DQN相较于传统的强化学习算法(Q-learning)有三大重要的改进:

  • 引入深度学习中的神经网络,利用神经网络去拟合Q-learning中的Q表,解决了Q-learning中,当状态维数过高时产生的“维数灾难”问题;

  • 固定Q目标网络,利用延后更新的目标网络计算目标Q值,极大的提高了网络训练的稳定性和收敛性;

  • 引入经验回放机制,使得在进行网络更新时输入的数据符合独立同分布,打破了数据间的相关性。

本文还增加了动态探索概率,也就是随着模型的训练,我们有必要减少探索的概率

DQN的算法流程如下:

在这里插入图片描述

  • 首先,算法开始前随机选择一个初始状态,然后基于这个状态选择执行动作,这里需要进行一个判断,即是通过Q-Network选择一个Q值最大对应的动作,还是在动作空间中随机选择一个动作。
  • 在程序编程中,由于刚开始时,Q-Network中的相关参数是随机的,所以在经验池存满之前,通常将设置的很小,即初期基本都是随机选择动作。
  • 在动作选择结束后,agent将会在环境(Environment)中执行这个动作,随后环境会返回下一状态(S_)和奖励(R),这时将四元组(S,A,R,S_)存入经验池。
  • 接下来将下一个状态(S_)视为当前状态(S),重复以上步骤,直至将经验池存满。
  • 当经验池存满之后,DQN中的网络开始更新。即开始从经验池中随机采样,将采样得到的奖励(R)和下一个状态(S_)送入目标网络计算下一Q值(y),并将y送入Q-Network计算loss值,开始更新Q-Network。往后就是agent与环境交互,产生经验(S,A,R,S_),并将经验放入经验池,然后从经验池中采样更新Q-Network,周而复始,直到Q-Network完成收敛。

在这里插入图片描述

  • DQN中目标网络的参数更新是硬更新,即主网络(Q-Network)参数更新一定步数后,将主网络更新后的参数全部复制给目标网络(Target
    Q-Network)。
  • 在程序编程中,通常将设置成随训练步数的增加而递增,即agent越来越信任Q-Network来指导动作。

代码实现

1、环境准备

我们选择openAIgym环境作为我们训练的环境

  env1 = gym.make("CartPole-v0")

在这里插入图片描述

2、编写经验池函数

经验池的主要内容就是,存数据和取数据

import random
import collections
from torch import FloatTensorclass ReplayBuffer(object):# 初始化def __init__(self, max_size, num_steps=1 ):""":param max_size: 经验吃大小:param num_steps: 每经过训练num_steps次后,函数就学习一次"""self.buffer = collections.deque(maxlen=max_size)self.num_steps  = num_stepsdef append(self, exp):"""想经验池添加数据:param exp: :return: """self.buffer.append(exp)def sample(self, batch_size):"""向经验池中获取batch_size个(obs_batch,action_batch,reward_batch,next_obs_batch,done_batch)这样的数据:param batch_size: :return: """mini_batch = random.sample(self.buffer, batch_size)obs_batch, action_batch, reward_batch, next_obs_batch, done_batch = zip(*mini_batch)obs_batch = FloatTensor(obs_batch)action_batch = FloatTensor(action_batch)reward_batch = FloatTensor(reward_batch)next_obs_batch = FloatTensor(next_obs_batch)done_batch = FloatTensor(done_batch)return obs_batch,action_batch,reward_batch,next_obs_batch,done_batchdef __len__(self):return len(self.buffer)

3、神经网络模型

我们简单地使用神经网络

import torchclass MLP(torch.nn.Module):def __init__(self, obs_size,n_act):super().__init__()self.mlp = self.__mlp(obs_size,n_act)def __mlp(self,obs_size,n_act):return torch.nn.Sequential(torch.nn.Linear(obs_size, 50),torch.nn.ReLU(),torch.nn.Linear(50, 50),torch.nn.ReLU(),torch.nn.Linear(50, n_act))def forward(self, x):return self.mlp(x)

4、探索率衰减函数

随着训练过程,我们动态地减小探索率,因为训练到后面,模型会越来越收敛,没必要继续探索

#!/usr/bin/env python 
# -*- coding:utf-8 -*-
import numpy as npclass EpsilonGreedy():def __init__(self,n_act,e_greed,decay_rate):self.n_act = n_actself.epsilon = e_greedself.decay_rate = decay_ratedef act(self,predict_func,obs):if np.random.uniform(0, 1) < self.epsilon:  # 探索action = np.random.choice(self.n_act)else:  # 利用action = predict_func(obs)self.epsilon = max(0.01,self.epsilon-self.decay_rate)   #是探索率最低为0.01return action

5、DQN算法

import copyimport numpy as np
import torch
from utils import torchUtils# 添加探索值递减的策略
class DQNAgent(object):def __init__( self, q_func, optimizer, replay_buffer, batch_size, replay_start_size,update_target_steps, n_act,explorer, gamma=0.9):''':param q_func: Q函数:param optimizer: 优化器:param replay_buffer: 经验回放器:param batch_size: 批次数量:param replay_start_size: 开始回放的次数:param update_target_steps: 经过多少步才会同步target网络:param n_act: 动作数量:param gamma: 收益衰减率:param e_greed: 探索与利用中的探索概率'''self.pred_func = q_funcself.target_func = copy.deepcopy(q_func)self.update_target_steps = update_target_stepsself.explorer = explorerself.global_step = 0  #全局self.rb = replay_bufferself.batch_size = batch_sizeself.replay_start_size = replay_start_sizeself.optimizer = optimizerself.criterion = torch.nn.MSELoss()self.n_act = n_act  # 动作数量self.gamma = gamma  # 收益衰减率# 根据经验得到actiondef predict(self, obs):obs = torch.FloatTensor(obs)Q_list = self.pred_func(obs)action = int(torch.argmax(Q_list).detach().numpy())return action# 根据探索与利用得到actiondef act(self, obs):return self.explorer.act(self.predict,obs)def learn_batch(self,batch_obs, batch_action, batch_reward, batch_next_obs, batch_done):# predict_Qpred_Vs = self.pred_func(batch_obs)action_onehot = torchUtils.one_hot(batch_action, self.n_act)predict_Q = (pred_Vs * action_onehot).sum(1)# target_Qnext_pred_Vs = self.target_func(batch_next_obs)best_V = next_pred_Vs.max(1)[0]target_Q = batch_reward + (1 - batch_done) * self.gamma * best_V# 更新参数self.optimizer.zero_grad()loss = self.criterion(predict_Q, target_Q)loss.backward()self.optimizer.step()def learn(self, obs, action, reward, next_obs, done):self.global_step+=1self.rb.append((obs, action, reward, next_obs, done))#当经验池中到的数据足够多时,并且满足每训练num_steps轮就更新一次参数if len(self.rb) > self.replay_start_size and self.global_step%self.rb.num_steps==0:self.learn_batch(*self.rb.sample(self.batch_size))#我们每训练update_target_steps轮就同步目标网络if self.global_step%self.update_target_steps==0:self.sync_target()# 同步target网络def sync_target(self):for target_param,param in zip(self.target_func.parameters(),self.pred_func.parameters()):target_param.data.copy_(param.data)

6、训练代码


import dqn,modules,replay_buffers
import gym
import torch
from explorers import  EpsilonGreedyclass TrainManager():def __init__(self,env,  #环境episodes=1000,  #轮次数量batch_size=32,  #每一批次的数量num_steps=4,  #进行学习的频次memory_size = 2000,  #经验回放池的容量replay_start_size = 200,  #开始回放的次数update_target_steps=200,  #经过训练update_target_steps次后将参数同步给target网络lr=0.001,  #学习率gamma=0.9,  #收益衰减率e_greed=0.1,  #探索与利用中的探索概率e_greed_decay=1e-6, #探索率衰减值):self.env = envself.episodes = episodesn_act = env.action_space.nn_obs = env.observation_space.shape[0]q_func = modules.MLP(n_obs, n_act)optimizer = torch.optim.AdamW(q_func.parameters(), lr=lr)rb = replay_buffers.ReplayBuffer(memory_size,num_steps)explorer = EpsilonGreedy(n_act,e_greed,e_greed_decay)self.agent = dqn.DQNAgent(q_func=q_func,optimizer=optimizer,replay_buffer = rb,batch_size=batch_size,update_target_steps=update_target_steps,replay_start_size = replay_start_size,n_act=n_act,explorer = explorer,gamma=gamma)# 训练一轮游戏def train_episode(self):total_reward = 0obs = self.env.reset()while True:action = self.agent.act(obs)next_obs, reward, done, _ = self.env.step(action)total_reward += rewardself.agent.learn(obs, action, reward, next_obs, done)obs = next_obsif done: breakprint('e_greed=',self.agent.explorer.epsilon)return total_reward# 测试一轮游戏def test_episode(self):total_reward = 0obs = self.env.reset()while True:action = self.agent.predict(obs)next_obs, reward, done, _ = self.env.step(action)total_reward += rewardobs = next_obsself.env.render()if done: breakreturn total_rewarddef train(self):for e in range(self.episodes):ep_reward = self.train_episode()print('Episode %s: reward = %.1f' % (e, ep_reward))#每训练100轮我们就测试一轮if e % 100 == 0:test_reward = self.test_episode()print('test reward = %.1f' % (test_reward))if __name__ == '__main__':env1 = gym.make("CartPole-v0")tm = TrainManager(env1)tm.train()

实现效果

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/98254.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

怎么用蜂邮EDM和Outlook批量发送邮件带附件

蜂邮EDM和Outlook批量发送邮件带附件的流程&#xff1f;有哪些邮件批量发送邮件附件的方法&#xff1f; 在现代社会中&#xff0c;电子邮件是一种广泛应用的沟通工具&#xff0c;而批量发送邮件带附件则是许多商业和个人用户的常见需求。本文将介绍如何使用蜂邮EDM和Outlook这…

高通camx开源部分简介

camera整体框架 ISP Pipeline diagram Simple Model Camx and chi_cdk 整体框架 CtsVerifier, Camra Formats Topology of Camera Formats. Topology (USECASE: UsecaseVideo) Nodes List Links between nodes Pipeline PreviewVideo Buffer manager Create Destro…

攻防世界-fakebook

打开题目链接 尝试弱口令登录 失败 随便注册 点击admin后跳转到下面这个页面 显示的是注册用户信息&#xff0c;观察url发现no1&#xff0c;猜测存在注入 用单引号测试一下&#xff0c;报错&#xff0c;确实存在SQL注入 使用order by 判断字段数 ?no1 order by 5 5的时候…

2.2.3 vim操作合集

1 vim VIM 是 Linux 系统上一款文本编辑器,学习 VIM 最好的文档,应该是阅读学习 VIM 的帮助文档,可以使用本地的帮助文件(vim--->:help),或者使用在线帮助文档。同时针对vim的使用,相应的相书籍也很多,如下 2 vim操作模式 命令模式:默认模式,该模式下可以移动光标…

【Java】微服务——Nacos配置管理(统一配置管理热更新配置共享Nacos集群搭建)

目录 1.统一配置管理1.1.在nacos中添加配置文件1.2.从微服务拉取配置1.3总结 2.配置热更新2.1.方式一2.2.方式二2.3总结 3.配置共享1&#xff09;添加一个环境共享配置2&#xff09;在user-service中读取共享配置3&#xff09;运行两个UserApplication&#xff0c;使用不同的pr…

【云备份项目】【Linux】:环境搭建(g++、json库、bundle库、httplib库)

文章目录 1. g 升级到 7.3 版本2. 安装 jsoncpp 库3. 下载 bundle 数据压缩库4. 下载 httplib 库从 Win 传输文件到 Linux解压缩 1. g 升级到 7.3 版本 &#x1f517;链接跳转 2. 安装 jsoncpp 库 &#x1f517;链接跳转 3. 下载 bundle 数据压缩库 安装 git 工具 sudo yum…

C++ 字符串

在本文中&#xff0c;您将学习如何在C中处理字符串。您将学习声明它们&#xff0c;对其进行初始化以及将它们用于各种输入/输出操作。 字符串是字符的集合。C 编程语言中通常使用两种类型的字符串&#xff1a; 作为字符串类对象的字符串&#xff08;标准C 库字符串类&#xff0…

小米、华为、iPhone、OPPO、vivo如何在手机让几张图拼成一张?

现在很多手机自带的相册APP已经有这个拼图功能了。 华为手机的拼图 打开图库&#xff0c;选定需要拼图的几张图片后&#xff0c;点击底部的【创作】&#xff0c;然后选择【拼图】就可以将多张图片按照自己想要的位置&#xff0c;组合在一起。 OPPO手机的拼图 打开相册&#…

Nginx配置文件的通用语法介绍

要是参考《Ubuntu 20.04使用源码安装nginx 1.14.0》安装nginx的话&#xff0c;nginx配置文件在/nginx/conf目录里边&#xff0c;/nginx/conf里边的配置文件结构如下图所示&#xff1a; nginx.conf是主配置文件&#xff0c;它是一个ascii文本文件。配置文件由指令&#xff08;…

【数据结构】二叉树--顺序结构及实现 (堆)

目录 一 二叉树的顺序结构 二 堆的概念及结构 三 堆的实现 1 包含所有接口 (Heap.h) 2 初始化,销毁和交换&#xff08;Heap.c) 3 向上调整&#xff08;Heap.c) 4 插入&#xff08;Heap.c) ​5 向下调整&#xff08;Heap.c) 6 删除&#xff08;Heap.c) ​7 打印&#…

数据统计--图形报表--ApacheEcharts技术 --苍穹外卖day10

Apache Echarts 营业额统计 重点:已完成订单金额要排除其他状态的金额 根据时间选择区间 设计vo用于后端向前端传输数据,dto用于后端接收前端发送的数据 GetMapping("/turnoverStatistics")ApiOperation("营业额统计")public Result<TurnoverReportVO…

叶工好容6-自定义与扩展

本篇主要介绍扩展的本质以及CRD与Operator之间的区别&#xff0c;帮助大家理解相关的概念以及知道要进行扩展需要做哪些工作。 CRD&#xff08;CustomerResourceDefinition&#xff09; 自定义资源定义,代表某种自定义的配置或者独立运行的服务。 用户只定义了CRD没有任何意…

课题学习(五)----阅读论文《抗差自适应滤波的导向钻具动态姿态测量方法》

一、简介 抗差自适应滤波&#xff1a;利用等价权函数和自适应因子合理的分配信息&#xff0c;有效地滤除钻具振动对动态姿态测量的影响。、   针对导向钻井工具动态测量受钻具振动的影响而导致测量不准确的问题&#xff0c;提出一种抗差自适应滤波的动态空间姿态测量方法。通…

MySQL:主从复制-基础复制(6)

环境 主服务器 192.168.254.1 从服务器&#xff08;1&#xff09;192.168.254.2 从服务器&#xff08;2&#xff09;192.168.253.3 我在主服务器上执行的操作会同步至从服务器 主服务器 yum -y install ntp 我们去配置ntp是需要让从服务器和我们主服务器时间同步 sed -i /…

WPS/word 表格跨行如何续表、和表的名称

1&#xff1a;具体操作&#xff1a; 将光标定位在跨页部分的第一行任意位置&#xff0c;按下快捷键ctrlshiftenter&#xff0c;就可以在跨页的表格上方插入空行&#xff08;在空行可以写&#xff0c;表1-3 xxxx&#xff08;续&#xff09;&#xff09; 在空行中输入…

好物周刊#19:开源指北

https://github.com/cunyu1943/JavaPark https://yuque.com/cunyu1943 村雨遥的好物周刊&#xff0c;记录每周看到的有价值的信息&#xff0c;主要针对计算机领域&#xff0c;每周五发布。 一、项目 1. Vditor 一款浏览器端的 Markdown 编辑器&#xff0c;支持所见即所得、…

韩语学习|韩语零基础|柯桥韩语学校,每日一词

今日一词:개방도 평지 韩语每日一词打卡:개방도[개방도]【名词】开放度,开放程度 原文&#xff1a;한 지역의 개방도는 경제 발전 수준에 달려 있습니다. 意思&#xff1a;一个地区的开放程度取决于经济发展水平。 【原文分解】 1、경제[경제]经济 2、지역[지역]地域 3、발전[발…

PHP 个人愿望众筹网站系统mysql数据库web结构apache计算机软件工程网页wamp

一、源码特点 PHP 个人愿望众筹网站系统是一套完善的web设计系统&#xff0c;对理解php编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。 php 个人愿望众筹网站 代码 https://download.csdn.net/download/qq_41221322/8…

在Android中实现动态应用图标

在Android中实现动态应用图标 你可能已经遇到过那些能够完成一个神奇的技巧的应用程序——在你的生日时改变他们的应用图标&#xff0c;然后无缝切换回常规图标。这是一种引发你好奇心的功能&#xff0c;让你想知道&#xff0c;“他们到底是如何做到的&#xff1f;”。嗯&…

Unity实现设计模式——模板方法模式

Unity实现设计模式——模板方法模式 模板模式(Template Pattern)&#xff0c; 指在一个抽象类公开定义了执行它的方法的模板。它的子类可以按需要重写方法实现&#xff0c;但调用将以抽象类中定义的方式进行。 简单说&#xff0c; 模板方法模式定义一个操作中的算法的骨架&…