[PyTorch][chapter 64][强化学习-DQN]

前言:

            DQN 就是结合了深度学习和强化学习的一种算法,最初是 DeepMind 在 NIPS 2013年提出,它的核心利润包括马尔科夫决策链以及贝尔曼公式。

            Q-learning的核心在于Q表格,通过建立Q表格来为行动提供指引,但这适用于状态和动作空间是离散且维数不高时,当状态和动作空间是高维连续时Q表格将变得十分巨大,对于维护Q表格和查找都是不现实的。


1: DQN 历史

2:  DQN 网络参数配置

3:DQN 网络模型搭建


一 DQN 历史

     DQN 跟机器学习的时序差分学习里面的Q-Learning 算法相似

    1.1 Q-Learning 算法

在Q Learning 中,我们有个Q table ,记录不同状态下,各个动作的Q 值

我们通过Q table 更新当前的策略

Q table 的作用: 是我们输入S,通过查表返回能够获得最大Q值的动作A.

但是很多场景状态S 并不是离散的,很难去定义

 1.2  DQN 发展史

     Deep network+Q-learning = DQN

     DQN 和 Q-tabel 没有本质区别:

     Q-table: 内部维护 Q Tabel

     DQN:   通过神经网络  a= NN(s), 替代了 Q Tabel

   


二 网络模型

    2.1 DQN 算法

  2.1 模型

模型参数


三  代码实现:

 5.1 main.py

   

# -*- coding: utf-8 -*-
"""
Created on Fri Nov 17 16:53:02 2023@author: chengxf2
"""import numpy as np
import torch
import gym
import random 
from Replaybuffer import Replay
from Agent import DQN
import rl_utils
import matplotlib.pyplot as plt
from tqdm import tqdm  #生成进度条lr = 5e-3
hidden_dim = 128
num_episodes = 500
minimal_size = 500
gamma = 0.98
epsilon =0.01
target_update = 10
buffer_size = 10000
mini_size = 500
batch_size = 64
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")if __name__ == "__main__":env_name = 'CartPole-v0'env = gym.make(env_name)random.seed(0)np.random.seed(0)env.seed(0)torch.manual_seed(0)replay_buffer = Replay(buffer_size)state_dim = env.observation_space.shape[0]action_dim = env.action_space.nagent = DQN(state_dim, hidden_dim, action_dim, lr, gamma, epsilon,target_update, device)return_list = []for i in range(10):with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:for i_episode in range(int(num_episodes / 10)):episode_return = 0state = env.reset()done = Falsewhile not done:action = agent.take_action(state)next_state, reward, done, _ = env.step(action)replay_buffer.add(state, action, reward, next_state, done)state = next_stateepisode_return += reward# 当buffer数据的数量超过一定值后,才进行Q网络训练if replay_buffer.size() > minimal_size:b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)transition_dict = {'states': b_s,'actions': b_a,'next_states': b_ns,'rewards': b_r,'dones': b_d}agent.update(transition_dict)return_list.append(episode_return)if (i_episode + 1) % 10 == 0:pbar.set_postfix({'episode':'%d' % (num_episodes / 10 * i + i_episode + 1),'return':'%.3f' % np.mean(return_list[-10:])})pbar.update(1)episodes_list = list(range(len(return_list)))plt.figure(1) plt.subplot(1, 2, 1)  # fig.1是一个一行两列布局的图,且现在画的是左图plt.plot(episodes_list, return_list,c='r')plt.xlabel('Episodes')plt.ylabel('Returns')plt.title('DQN on {}'.format(env_name))plt.figure(1)  # 当前要处理的图为fig.1,而且当前图是fig.1的左图plt.subplot(1, 2, 2)  # 当前图变为fig.1的右图mv_return = rl_utils.moving_average(return_list, 9)plt.plot(episodes_list, mv_return,c='g')plt.xlabel('Episodes')plt.ylabel('Returns')plt.title('DQN on {}'.format(env_name))plt.show()

5.2  Agent.py

# -*- coding: utf-8 -*-
"""
Created on Fri Nov 17 16:00:46 2023@author: chengxf2
"""import random 
import numpy as np
from   torch import nn
import torch
import torch.nn.functional as Fclass QNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(QNet, self).__init__()self.net = nn.Sequential(nn.Linear(state_dim, hidden_dim),nn.Linear(hidden_dim, action_dim))def forward(self, state):qvalue = self.net(state)return qvalueclass  DQN:def __init__(self,state_dim, hidden_dim, action_dim,learning_rate,discount, epsilon, target_update, device):self.action_dim = action_dimself.q_net = QNet(state_dim, hidden_dim, action_dim).to(device)self.target_q_net = QNet(state_dim, hidden_dim, action_dim).to(device)#Adam 优化器self.optimizer = torch.optim.Adam(self.q_net.parameters(),lr=learning_rate)self.gamma = discount #折扣因子self.epsilon = epsilon  # e-贪心算法self.target_update = target_update  # 目标网络更新频率self.device = deviceself.count = 0 #计数器def  take_action(self, state):rnd = np.random.random() #产生随机数if rnd <self.epsilon:action = np.random.randint(0, self.action_dim)else:state = torch.tensor([state], dtype=torch.float).to(self.device)qvalue = self.q_net(state)action = qvalue.argmax().item()return actiondef update(self, data):states = torch.tensor(data['states'],dtype=torch.float).to(self.device)actions = torch.tensor(data['actions']).view(-1, 1).to(self.device)rewards = torch.tensor(data['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(data['next_states'],dtype=torch.float).to(self.device)dones = torch.tensor(data['dones'],dtype=torch.float).view(-1, 1).to(self.device)#从完整数据中按索引取值[64]#print("\n actions ",actions,actions.shape)q_value = self.q_net(states).gather(1,actions) #Q值#下一个状态的Q值max_next_q_values = self.target_q_net(next_states).max(1)[0].view(-1,1)q_targets = rewards + self.gamma * max_next_q_values * (1 - dones)loss = F.mse_loss(q_value, q_targets)loss = torch.mean(loss)self.optimizer.zero_grad()loss.backward()self.optimizer.step()if self.count %self.target_update  ==0:#更新目标网络self.target_q_net.load_state_dict(self.q_net.state_dict())self.count +=1

 5.3 Replaybuffer.py

   

# -*- coding: utf-8 -*-
"""
Created on Fri Nov 17 15:50:07 2023@author: chengxf2
"""import collections 
import random 
import numpy as np
class Replay:def __init__(self, capacity):#双向队列,可以在队列的两端任意添加或删除元素。self.buffer = collections.deque(maxlen = capacity)def add(self, state, action ,reward, next_state, done):#数据加入bufferself.buffer.append((state,action,reward, next_state, done))def sample(self, batch_size):#采样数据data = random.sample(self.buffer, batch_size)state,action, reward, next_state,done = zip(*data)return np.array(state), action, reward, np.array(next_state), donedef size(self):return len(self.buffer)

 5.4 rl_utils.py

from tqdm import tqdm
import numpy as np
import torch
import collections
import randomclass ReplayBuffer:def __init__(self, capacity):self.buffer = collections.deque(maxlen=capacity) def add(self, state, action, reward, next_state, done): self.buffer.append((state, action, reward, next_state, done)) def sample(self, batch_size): transitions = random.sample(self.buffer, batch_size)state, action, reward, next_state, done = zip(*transitions)return np.array(state), action, reward, np.array(next_state), done def size(self): return len(self.buffer)def moving_average(a, window_size):cumulative_sum = np.cumsum(np.insert(a, 0, 0)) middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_sizer = np.arange(1, window_size-1, 2)begin = np.cumsum(a[:window_size-1])[::2] / rend = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]return np.concatenate((begin, middle, end))def train_on_policy_agent(env, agent, num_episodes):return_list = []for i in range(10):with tqdm(total=int(num_episodes/10), desc='Iteration %d' % i) as pbar:for i_episode in range(int(num_episodes/10)):episode_return = 0transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': [], 'dones': []}state = env.reset()done = Falsewhile not done:action = agent.take_action(state)next_state, reward, done, _ = env.step(action)transition_dict['states'].append(state)transition_dict['actions'].append(action)transition_dict['next_states'].append(next_state)transition_dict['rewards'].append(reward)transition_dict['dones'].append(done)state = next_stateepisode_return += rewardreturn_list.append(episode_return)agent.update(transition_dict)if (i_episode+1) % 10 == 0:pbar.set_postfix({'episode': '%d' % (num_episodes/10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)return return_listdef train_off_policy_agent(env, agent, num_episodes, replay_buffer, minimal_size, batch_size):return_list = []for i in range(10):with tqdm(total=int(num_episodes/10), desc='Iteration %d' % i) as pbar:for i_episode in range(int(num_episodes/10)):episode_return = 0state = env.reset()done = Falsewhile not done:action = agent.take_action(state)next_state, reward, done, _ = env.step(action)replay_buffer.add(state, action, reward, next_state, done)state = next_stateepisode_return += rewardif replay_buffer.size() > minimal_size:b_s, b_a, b_r, b_ns, b_d = replay_buffer.sample(batch_size)transition_dict = {'states': b_s, 'actions': b_a, 'next_states': b_ns, 'rewards': b_r, 'dones': b_d}agent.update(transition_dict)return_list.append(episode_return)if (i_episode+1) % 10 == 0:pbar.set_postfix({'episode': '%d' % (num_episodes/10 * i + i_episode+1), 'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)return return_listdef compute_advantage(gamma, lmbda, td_delta):td_delta = td_delta.detach().numpy()advantage_list = []advantage = 0.0for delta in td_delta[::-1]:advantage = gamma * lmbda * advantage + deltaadvantage_list.append(advantage)advantage_list.reverse()return torch.tensor(advantage_list, dtype=torch.float)

DQN 算法
遇强则强(八):从Q-table到DQN - 知乎使用Pytorch实现强化学习——DQN算法_dqn pytorch-CSDN博客

https://www.cnblogs.com/xiaohuiduan/p/12993691.html

https://www.cnblogs.com/xiaohuiduan/p/12945449.html

强化学习第五节(DQN)【个人知识分享】_哔哩哔哩_bilibili

CSDN

组会讲解强化学习的DQN算法_哔哩哔哩_bilibili

3-ε-greedy_ReplayBuffer_FixedQ-targets_哔哩哔哩_bilibili

4-代码实战DQN_Agent和Env整体交互_哔哩哔哩_bilibili

DQN基本概念和算法流程(附Pytorch代码) - 知乎

CSDN

DQN 算法

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/170101.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python web项目导包规范

python web项目导包规范 python 内置的模块通过第三方库安装的模块框架自身提供的模块用户自己定义的模块 如&#xff1a; from __future__ import absolute_import, unicode_literalsfrom debug_toolbar.panels import Panelfrom django.utils.translation import ugettext_…

YOLOv5改进 | 添加SE注意力机制 + 更换NMS之EIoU-NMS

前言&#xff1a;Hello大家好&#xff0c;我是小哥谈。为提高算法模型在不同环境下的目标识别准确率&#xff0c;提出一种基于改进 YOLOv5 深度学习的识别方法&#xff08;SE-NMS-YOLOv5&#xff09;&#xff0c;该方法融合SE&#xff08;Squeeze-and-Excitation&#xff09;注…

邮件与协议

TCP/IP 邮件 电子邮件是 TCP/IP 最重要的应用之一。 您不会用到… 当您写邮件时&#xff0c;您不会用到 TCP/IP。 当您写邮件时&#xff0c;您用到的是电子邮件程序&#xff0c;例如莲花软件的 Notes&#xff0c;微软公司出品的 Outlook&#xff0c;或者 Netscape Communic…

算法必刷系列之数字与数学

文章目录 数字与数学符号统计阶乘0的个数整数反转字符串转数字判断回文数字十进制转七进制进制转换数组实现整数加法字符串加法二进制求和求2的幂求3的幂求4的幂最大公约数最小公倍数判断质数质数计数判断丑数丑数计数 数字与数学 数字与数学的问题基础且庞大&#xff0c;算法…

c 语言线程的使用

在 C 语言中&#xff0c;可以使用 POSIX 线程&#xff08;pthread&#xff09;库来创建和管理线程。下面是一个简单的示例程序&#xff0c;演示如何在 C 语言中使用线程&#xff1a; #include <stdio.h> #include <stdlib.h> #include <pthread.h> voi…

【pytest】Hooks函数之统计测试结果(pytest_terminal_summary)

前言 用例执行完成后&#xff0c;我们希望能获取到执行的结果&#xff0c;这样方便我们快速统计用例的执行情况。 也可以把获取到的结果当成总结报告&#xff0c;发邮件的时候可以先统计测试结果&#xff0c;再加上html的报告。 pytest_terminal_summary 关于TerminalReporter…

Python pandas数据分析

Python pandas数据分析&#xff1a; 2022找工作是学历、能力和运气的超强结合体&#xff0c;遇到寒冬&#xff0c;大厂不招人&#xff0c;可能很多算法学生都得去找开发&#xff0c;测开 测开的话&#xff0c;你就得学数据库&#xff0c;sql&#xff0c;oracle&#xff0c;尤其…

枚举 B. Lorry

Problem - B - Codeforces 题目大意&#xff1a;给物品数量 n n n&#xff0c;体积为 v ( 0 ≤ v ≤ 1 e 9 ) v_{(0 \le v \le 1e9)} v(0≤v≤1e9)​&#xff0c;第一行读入 n , v n, v n,v&#xff0c;之后 n n n行&#xff0c;读入 n n n个物品&#xff0c;之后每行依次是体…

易错知识点(数学一)

一、反常积分判敛 1、构造使其极限等于一个大于0的常数 1&#xff09;前者通过&#xff1a;化等价无穷小 or 泰勒展开 2&#xff09;若存在p>1使得等式成立&#xff0c;则收敛 考察形式&#xff1a;1、已知收敛&#xff0c;求f(x)中的幂次取值范围 主要思想&#xff1a;比较…

⑧【HyperLoglog】Redis数据类型:HyperLoglog [使用手册]

个人简介&#xff1a;Java领域新星创作者&#xff1b;阿里云技术博主、星级博主、专家博主&#xff1b;正在Java学习的路上摸爬滚打&#xff0c;记录学习的过程~ 个人主页&#xff1a;.29.的博客 学习社区&#xff1a;进去逛一逛~ Redis HyperLoglog ⑧Redis HyperLoglog基本操…

基于SpringBoot+Redis的前后端分离外卖项目-苍穹外卖(七)

分页查询、删除和修改菜品 1. 菜品分页查询1.1 需求分析和设计1.1.1 产品原型1.1.2 接口设计 1.2 代码开发1.2.1 设计DTO类1.2.2 设计VO类1.2.3 Controller层1.2.4 Service层接口1.2.5 Service层实现类1.2.6 Mapper层 1.3 功能测试1.3.2 前后端联调测试 2. 删除菜品2.1 需求分析…

使用 HTML、CSS 和 JavaScript 创建图像滑块

使用 HTML、CSS 和 JavaScript 创建轮播图 在本文中&#xff0c;我们将讨论如何使用 HTML、CSS 和 JavaScript 构建轮播图。我们将演示两种不同的创建滑块的方法&#xff0c;一种是基于opacity的滑块&#xff0c;另一种是基于transform的。 创建 HTML 我们首先从 HTML 代码开…

Linux系统的研究

摘要&#xff1a; Linux是一个开源的Unix-like操作系统&#xff0c;由林纳斯托瓦兹在1991年首次发布。由于其开放性和灵活性&#xff0c;Linux已成为许多企业和个人用户的重要选择。本文将探讨Linux系统的历史、特点、应用领域以及优缺点。 一、引言 Linux系统的开发始于1980年…

yolo系列中的一些评价指标说明

文章目录 一. 混淆矩阵二. 准确度(Accuracy)三. 精确度(Precision)四. 召回率(Recall)五. F1-score六. P-R曲线七. AP八. mAP九. mAP0.5十. mAP[0.5:0.95] 一. 混淆矩阵 TP (True positives)&#xff1a;被正确地划分为正例的个数&#xff0c;即实际为正例且被分类器划分为正例…

Redis-主从与哨兵架构

Jedis使用 Jedis连接代码示例&#xff1a; 1、引入依赖 <dependency><groupId>redis.clients</groupId><artifactId>jedis</artifactId><version>2.9.0</version> </dependency> 2、访问代码 public class JedisSingleTe…

JVM虚拟机:JVM调优第一步,了解JVM常用命令行参数

本文重点 从本文课程开始&#xff0c;我们将用几篇文章来介绍JVM中常用的命令行的参数&#xff0c;这个非常重要&#xff0c;第一我们可以通过参数了解JVM的配置&#xff0c;第二我们可以通过参数完成对JVM的调参。以及后面的JVM的调优也需要用到这些参数&#xff0c;所以我们…

轻松入门自然语言处理系列 项目3 基于Linear-CRF的医疗实体识别

文章目录 前言一、项目概况1.项目描述2.数据描述3.项目框架二、核心技术1.实体识别数据标注2.文本特征工程3.CRF模型4.BiLSTM-CRF模型三、项目实施1.读取数据2.数据标注3.文本特征工程4.模型训练5.模型评估6.BiLSTM-CRF总结前言 本文主要介绍了以Linear-CRF为基础模型进行医疗…

App 设计工具

目录 说明 打开 App 设计工具 示例 创建 App 创建自定义 UI 组件 打开现有 App 文件 打包和共享 App 本文主要讲述以交互方式创建 App。 说明 App 设计工具是一个交互式开发环境&#xff0c;用于设计 App 布局并对其行为进行编程。 可以使用 App 设计工具&#xff1a…

【黑马甄选离线数仓day05_核销主题域开发】

1. 指标分类 ​ 通过沟通调研&#xff0c;把需求进行分析、抽象和总结&#xff0c;整理成指标列表。指标有原子指标、派生指标、 衍生指标三种类型。 ​ 原子指标基于某一业务过程的度量值&#xff0c;是业务定义中不可再拆解的指标&#xff0c;原子指标的核心功能就是对指标…

Python武器库开发-前端篇之CSS元素(三十二)

前端篇之CSS元素(三十二) CSS 元素是一个网页中的 HTML 元素&#xff0c;包括标签、类和 ID。它们可以通过 CSS 选择器选中并设置样式属性&#xff0c;以使网页呈现具有吸引力和良好的可读性。常见的 HTML 元素包括 div、p、h1、h2、span 等&#xff0c;它们可以使用 CSS 设置…