PyTorch深度学习实战(46)——深度Q学习

PyTorch深度学习实战(46)——深度Q学习

    • 0. 前言
    • 1. 深度 Q 学习
    • 2. 网络架构
    • 3. 实现深度 Q 学习模型进行 CartPole 游戏
    • 小结
    • 系列链接

0. 前言

我们已经学习了如何构建一个 Q 表,通过在多个 episode 中重复进行游戏获取与给定状态-动作组合相对应的值。然而,当状态空间是连续时,可能的状态空间数会变得非常巨大。在本节中,我们将学习如何使用神经网络在没有 Q 表的情况下估计状态-动作组合的 Q 值,因此称为深度 Q 学习 (deep Q-learning)。

1. 深度 Q 学习

与 Q 表相比,深度 Q 学习利用神经网络将任意给定的状态-动作(其中状态可以是连续或离散的)组合映射到相应 Q 值。
在本节中,将使用 Gym 中的 CartPole 环境,智能体的任务是尽可能长时间地平衡 CartPoleCartPole 环境如下图所示:

CartPole-v0

当小车向右移动时,杆向左移动,反之亦然,CartPole 环境中的每个状态都由四个观测值定义,其名称及其最小值和最大值如下:

状态最小值最大值
Cart position-2.42.4
Cart velocity-infinf
Pole angle-41.8°41.8°
Pole velocity at the tip-infinf

需要注意的是,表示状态的所有观测值都具有连续值,用于 CartPole 平衡游戏的深度 Q 学习的工作原理如下:

  1. 获取输入值(游戏图像/游戏元数据)
  2. 通过网络传递输入值,网络的输出与可能的动作数相同
  3. 输出层预测在给定状态下采取某个动作对应的 Q 值

2. 网络架构

网络架构使用状态(四个观测值)作为输入,在当前状态下采取左/右动作的 Q 值作为输出。神经网络训练策略如下:

  1. 在探索阶段,执行输出层中具有最高值的随机动作
  2. 将动作、下一个状态、奖励和指示游戏是否完成的标志存储在内存中
  3. 如果游戏没有完成,计算在给定状态下采取行动的 Q 值,即奖励 + 折扣因子 x 下一个状态中所有动作的最大可能 Q 值
  4. 修改采取动作的Q值,而其他状态-动作组合的 Q 值保持不变
  5. 多次执行步骤 14 并存储经验
  6. 拟合模型,将状态作为输入,动作值作为预期输出(来自内存和回放经验),并最小化 MSE 损失
  7. 在降低探索率的同时在多个 episode 上重复上述步骤

3. 实现深度 Q 学习模型进行 CartPole 游戏

根据以上策略,使用 PyTorch 编写深度 Q 学习模型,进行 CartPole 游戏。

(1) 导入相关库:

import gym
import numpy as np
import cv2
from collections import deque
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
from collections import namedtuple, deque
import torch
import torch.nn.functional as F
import torch.optim as optim
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

(2) 定义环境:

env = gym.make('CartPole-v1')

(3) 定义网络架构:

class DQNetwork(nn.Module):def __init__(self, state_size, action_size):super(DQNetwork, self).__init__()self.fc1 = nn.Linear(state_size, 24)self.fc2 = nn.Linear(24, 24)self.fc3 = nn.Linear(24, action_size)def forward(self, state):       x = F.relu(self.fc1(state))x = F.relu(self.fc2(x))x = self.fc3(x)return x

该架构在两个隐藏层中仅包含 24 个单元,输出层包含与可能动作数相同的单元。

(4) 定义 Agent 类。

定义 __init__ 方法,其中包含各种参数、网络的定义:

class Agent():def __init__(self, state_size, action_size):self.state_size = state_sizeself.action_size = action_sizeself.seed = random.seed(0)## hyperparametersself.buffer_size = 2000self.batch_size = 64self.gamma = 0.99self.lr = 0.0025self.update_every = 4 # Q-Networkself.local = DQNetwork(state_size, action_size).to(device)self.optimizer = optim.Adam(self.local.parameters(), lr=self.lr)# Replay memoryself.memory = deque(maxlen=self.buffer_size) self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])self.t_step = 0

定义 step 函数,该函数从内存中获取数据并通过调用 learn 函数将其拟合到模型中:

    def step(self, state, action, reward, next_state, done):# Save experience in replay memoryself.memory.append(self.experience(state, action, reward, next_state, done)) # Learn every update_every time steps.self.t_step = (self.t_step + 1) % self.update_everyif self.t_step == 0:# If enough samples are available in memory, get random subset and learnif len(self.memory) > self.batch_size:experiences = self.sample_experiences()self.learn(experiences, self.gamma)

定义 act 函数,该函数在给定状态的情况下预测动作:

    def act(self, state, eps=0.):# Epsilon-greedy action selectionif random.random() > eps:state = torch.from_numpy(state).float().unsqueeze(0).to(device)self.local.eval()with torch.no_grad():action_values = self.local(state)self.local.train()return np.argmax(action_values.cpu().data.numpy())else:return random.choice(np.arange(self.action_size))

在以上代码中,我们在确定要采取的行动时使用探索-利用策略。

定义 learn 函数用于拟合模型,使其在给定状态时预测动作值:

    def learn(self, experiences, gamma): states, actions, rewards, next_states, dones = experiences# Get expected Q values from local modelQ_expected = self.local(states).gather(1, actions)# Get max predicted Q values (for next states) from local modelQ_targets_next = self.local(next_states).detach().max(1)[0].unsqueeze(1)# Compute Q targets for current states Q_targets = rewards + (gamma * Q_targets_next * (1 - dones))# Compute lossloss = F.mse_loss(Q_expected, Q_targets)# Minimize the lossself.optimizer.zero_grad()loss.backward()self.optimizer.step()

在以上代码中,获取采样经验并预测我们执行的动作的 Q 值。此外,由于我们已经知道下一个状态,可以预测下一个状态下动作的最佳 Q 值。因此,我们可以得到与在给定状态下采取的动作相对应的目标值。最后,计算在当前状态下采取的动作的 Q 值的期望值 (Q_targets) 和预测值 (Q_expected) 之间的误差。

定义 sample_experiences 函数以便从内存中采样经验:

    def sample_experiences(self):experiences = random.sample(self.memory, k=self.batch_size)        states = torch.from_numpy(np.vstack([e.state for e in experiences if e is not None])).float().to(device)actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(device)rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(device)next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences if e is not None])).float().to(device)dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)        return (states, actions, rewards, next_states, dones)

(5) 定义智能体对象:

agent = Agent(env.observation_space.shape[0], env.action_space.n)

(6) 训练模型。

初始化列表:

scores = [] # list containing scores from each episode
scores_window = deque(maxlen=100) # last 100 scores
n_episodes=5000
max_t=5000
eps_start=1.0
eps_end=0.001
eps_decay=0.9995
eps = eps_start

在每个 episode 中重置环境并获取状态的形状,此外,整形状态维度形状,以便可以将其传递给网络:

for i_episode in range(1, n_episodes+1):state = env.reset()state_size = env.observation_space.shape[0]state = np.reshape(state, [1, state_size])score = 0

循环通过 max_t 个时间步,确定要执行的动作,并使用 step 方法执行,使用 np.reshape 整形状态张量,并将整形后的状态传递给神经网络:

    for i in range(max_t):action = agent.act(state, eps)next_state, reward, done, _ = env.step(action)next_state = np.reshape(next_state, [1, state_size])

通过指定 agent.step 在当前状态之上拟合模型,并将状态重置为下一个状态,以便在下一次迭代中使用。

如果前 10 步的得分平均值大于 450,则存储相关数据并停止训练:

        reward = reward if not done or score == 499 else -10agent.step(state, action, reward, next_state, done)state = next_statescore += rewardif done:break scores_window.append(score) # save most recent score scores.append(score) # save most recent scoreeps = max(eps_end, eps_decay*eps) # decrease epsilonprint('\rEpisode {}\tReward {} \tAverage Score: {:.2f} \tEpsilon: {}'.format(i_episode,score,np.mean(scores_window), eps), end="")if i_episode % 100 == 0:print('\rEpisode {}\tAverage Score: {:.2f} \tEpsilon: {}'.format(i_episode, np.mean(scores_window), eps))if i_episode>10 and np.mean(scores[-10:])>450:break
"""
Episode 100     Average Score: 12.65 ge Epsilon: 0.951217530242334.9512175302423344
...
Episode 2700    Average Score: 116.56 e Epsilon: 0.259152752655221145915275265522114
Episode 2712    Reward 500.0    Average Score: 159.01   Epsilon: 0.2576021050410192
"""

(7) 绘制随着 episode 的增加的分数变化情况如下:

import matplotlib.pyplot as plt
plt.plot(scores)
plt.title('Scores over increasing episodes')
plt.show()

请添加图片描述

从上图中可以看出,在第 2000episode 之后,该模型在进行 CartPole 游戏时能够获得较高分。

小结

深度 Q 学习是一种结合了深度学习和强化学习的方法,通过深度神经网络逼近 Q 值函数,在解决大规模、连续状态空间问题方面具有优势,并在多个领域展示了强大的学习和决策能力。在本节中,介绍了深度 Q 学习的基本概念,并学习了如何使用 PyTorch 实现深度 Q 学习进行 CartPole 游戏。

系列链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
PyTorch深度学习实战(16)——面部关键点检测
PyTorch深度学习实战(17)——多任务学习
PyTorch深度学习实战(18)——目标检测基础
PyTorch深度学习实战(19)——从零开始实现R-CNN目标检测
PyTorch深度学习实战(20)——从零开始实现Fast R-CNN目标检测
PyTorch深度学习实战(21)——从零开始实现Faster R-CNN目标检测
PyTorch深度学习实战(22)——从零开始实现YOLO目标检测
PyTorch深度学习实战(23)——从零开始实现SSD目标检测
PyTorch深度学习实战(24)——使用U-Net架构进行图像分割
PyTorch深度学习实战(25)——从零开始实现Mask R-CNN实例分割
PyTorch深度学习实战(26)——多对象实例分割
PyTorch深度学习实战(27)——自编码器(Autoencoder)
PyTorch深度学习实战(28)——卷积自编码器(Convolutional Autoencoder)
PyTorch深度学习实战(29)——变分自编码器(Variational Autoencoder, VAE)
PyTorch深度学习实战(30)——对抗攻击(Adversarial Attack)
PyTorch深度学习实战(31)——神经风格迁移
PyTorch深度学习实战(32)——Deepfakes
PyTorch深度学习实战(33)——生成对抗网络(Generative Adversarial Network, GAN)
PyTorch深度学习实战(34)——DCGAN详解与实现
PyTorch深度学习实战(35)——条件生成对抗网络(Conditional Generative Adversarial Network, CGAN)
PyTorch深度学习实战(36)——Pix2Pix详解与实现
PyTorch深度学习实战(37)——CycleGAN详解与实现
PyTorch深度学习实战(38)——StyleGAN详解与实现
PyTorch深度学习实战(39)——小样本学习(Few-shot Learning)
PyTorch深度学习实战(40)——零样本学习(Zero-Shot Learning)
PyTorch深度学习实战(41)——循环神经网络与长短期记忆网络
PyTorch深度学习实战(42)——图像字幕生成
PyTorch深度学习实战(43)——手写文本识别
PyTorch深度学习实战(44)——基于 DETR 实现目标检测
PyTorch深度学习实战(45)——强化学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/45231.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【学习笔记】无人机(UAV)在3GPP系统中的增强支持(十四)-无人机操控关键绩效指标(KPI)框架

引言 本文是3GPP TR 22.829 V17.1.0技术报告,专注于无人机(UAV)在3GPP系统中的增强支持。文章提出了多个无人机应用场景,分析了相应的能力要求,并建议了新的服务级别要求和关键性能指标(KPIs)。…

第二证券:转融通是什么意思?什么是转融通?

转融通,包含转融资和转融券,实质是借钱和借券。转融通是指证券金融公司借入证券、筹得资金后,再转借给证券公司,是一假贷联络,具体是指证券公司从符合要求的基金处理公司、保险公司、社保基金等组织出资者融券&#xf…

Python应用开发——30天学习Streamlit Python包进行APP的构建(15):优化性能并为应用程序添加状态

Caching and state 优化性能并为应用程序添加状态! Caching 缓存 Streamlit 为数据和全局资源提供了强大的缓存原语。即使从网络加载数据、处理大型数据集或执行昂贵的计算,它们也能让您的应用程序保持高性能。 本页仅包含有关 st.cache_data API 的信息。如需深入了解缓…

技术成神之路:设计模式(六)策略模式

1.介绍 策略模式(Strategy Pattern)是一种行为型设计模式,它定义了一系列算法,封装每一个算法,并使它们可以相互替换。策略模式使得算法的变化独立于使用算法的客户端。 2.主要作用 策略模式的主要作用是将算法或行为…

什么叫图像的双边滤波,并附利用OpenCV和MATLB实现双边滤波的代码

双边滤波(Bilateral Filtering)是一种在图像处理中常用的非线性滤波技术,主要用于去噪和保边。它在空间域和像素值域上同时进行加权,既考虑了像素之间的空间距离,也考虑了像素值之间的相似度,从而能够有效地…

手机怎么看WiFi的IP地址

在如今数字化快速发展的时代,无线网络已成为我们日常生活中不可或缺的一部分。无论是工作、学习还是娱乐,我们可能都离不开WiFi的陪伴。然而,在使用WiFi的过程中,有时我们可能需要查看其IP地址,以便更好地管理我们的网…

【动态规划】背包问题 {01背包问题;完全背包问题;二维费用背包问题}

一、背包问题概述 背包问题(Knapsackproblem)是⼀种组合优化的NP完全问题。 问题可以描述为:给定一组物品,每种物品都有自己的重量和价格,在限定的总重量内,我们如何选择,才能使得物品的总价格最⾼。 根据物品的个数…

链接追踪系列-07.logstash安装json_lines插件

进入docker中的logstash 容器内: jelexbogon ~ % docker exec -it 7ee8960c99a31e607f346b2802419b8b819cc860863bc283cb7483bc03ba1420 /bin/sh $ pwd /usr/share/logstash $ ls bin CONTRIBUTORS Gemfile jdk logstash-core modules tools x-pack …

语音识别概述

语音识别概述 一.什么是语音? 语音是语言的声学表现形式,是人类自然的交流工具。 图片来源:https://www.shenlanxueyuan.com/course/381 二.语音识别的定义 语音识别(Automatic Speech Recognition, ASR 或 Speech to Text, ST…

基于RAG大模型的变电站智慧运维-第十届Nvidia Sky Hackathon参赛作品

第十届Nvidia Sky Hackathon参赛作品 1. 项目说明 变电站是用于变电的设施,主要的作用是将电压转化,使电能在输电线路中能够长距离传输。在电力系统中,变电站起到了极为重要的作用,它可以完成电能的负荷分配、电压的稳定、容错保…

电影购票小程序论文(设计)开题报告

一、课题的背景和意义 随着互联网技术的不断发展,人们对于购票的需求也越来越高。传统的购票方式存在着排队时间长、购票流程繁琐等问题,而网上购票则能够有效地解决这些问题。电影购票小程序是网上购票的一种新型应用,它能够让用户随时随地…

06.截断文本 选择任何链接 :root 和 html 有什么区别

截断文本 对超过一行的文本进行截断,在末尾添加省略号(…)。 使用 overflow: hidden 防止文本超出其尺寸。使用 white-space: nowrap 防止文本超过一行高度。使用 text-overflow: ellipsis 使得如果文本超出其尺寸,将以省略号结尾。为元素指定固定的 width,以确定何时显示省略号…

笔记 4 :linux 0.11 中继续分析 0 号进程创建一号进程的 fork () 函数

(27)本条目开始, 开始分析 copy_process () 函数,其又会调用别的函数,故先分析别的函数。 get_free_page () ; 先 介绍汇编指令 scasb : 以及 指令 sstosd :…

什么是架构设计师?定义、职责和任务,全方位解析需要具备的专业素质

目录 1. 架构设计师的定义 2. 架构设计师的职责和任务 2.1 系统架构设计 2.1.1 模块划分 2.1.2 接口设计 2.1.3 通信方式 2.2 技术选型与决策 2.2.1 技术评估 2.2.2 技术选型 2.2.3 技术决策 2.3 性能优化与调优 2.3.1 性能分析 2.3.2 性能优化 2.3.3 性能调优 …

视图库对接系列(GA-T 1400)十七、视图库对接系列(本级)采集设备获取

背景 这一章的话,我们写写如何获取采集设备获取,之前其实也有说过类似的 就我们订阅的时候如果subscribeDetail=3的话,下级就会主动给我们推送采集设备。但这里的话,是下级主动推,如果下级平台不支持,或者说可能因为某个原因推的不全,怎么办? 我们能否主动获取采集设备…

WPF学习(4) -- 数据模板

一、DataTemplate 在WPF(Windows Presentation Foundation)中,DataTemplate 用于定义数据的可视化呈现方式。它允许你自定义如何展示数据对象,从而实现更灵活和丰富的用户界面。DataTemplate 通常用于控件(如ListBox、…

知识图谱和 LLM:利用 Neo4j 实现大型语言模型

这是关于 Neo4j 的 NaLLM 项目的一篇博客文章。这个项目是为了探索、开发和展示这些 LLM 与 Neo4j 结合的实际用途。 2023 年,ChatGPT 等大型语言模型 (LLM) 因其理解和生成类似人类的文本的能力而风靡全球。它们能够适应不同的对话环境、回答各种主题的问题,甚至模拟创意写…

NSSCTF中24网安培训day1中web的题目

我flag呢 直接查看源代码即可CtrlU [SWPUCTF 2021 新生赛]Do_you_know_http 用Burpsuite抓包,之后在User-agent下面添加XFF头,即X-Forwarded-For:127.0.0.1 [SWPUCTF 2022 新生赛]funny_php 首先是php的弱比较,对于num参数,我们…

hot100 | 十一、二分搜索

1-leetcode35. 搜索插入位置 注意&#xff1a; 看Labuladong的书&#xff0c;知道while的判断符号跟left right的关系 public int searchInsert(int[] nums, int target) {int left 0;int right nums.length - 1;while (left < right) {int mid left (right - left) /…

PostgreSQL日志文件配置,记录所有操作记录

为了更详细的记录PostgreSQL 的运行日志&#xff0c;我们一般需要修改PostgreSQL 默认的配置文件&#xff0c;这里整理了一些常用的配置 修改配置文件 打开 PostgreSQL 配置文件 postgresql.conf。该文件通常位于 PostgreSQL 安装目录下的 data 文件夹中。 找到并修改以下配…