pytorch深度Q网络

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

DQN 引入了深度神经网络来近似Q函数,解决了传统Q-learning在处理高维状态空间时的瓶颈,尤其是在像 Atari 游戏这样的复杂环境中。DQN的核心思想是使用神经网络 Q(s,a;θ)Q(s, a; \theta)Q(s,a;θ) 来近似 Q 值函数,其中 θ\thetaθ 是神经网络的参数。

DQN 的关键创新包括:

  1. 经验回放(Experience Replay):在强化学习中,当前的学习可能会依赖于最近的经验,容易导致学习过程的不稳定。经验回放通过将智能体的经历存储到一个回放池中,然后随机抽取批量数据进行训练,这样可以打破数据之间的相关性,使得训练更加稳定。

  2. 目标网络(Target Network):在Q-learning中,Q值的更新依赖于下一个状态的最大Q值。为了避免Q值更新时过度依赖当前网络的输出(导致不稳定),DQN引入了目标网络。目标网络的结构与行为网络相同,但它的参数更新频率较低,这使得Q值更新更加稳定。

DQN算法流程

  1. 初始化Q网络:初始化Q网络的参数 θ\thetaθ,以及目标网络的参数 θ−\theta^-θ−(通常与Q网络相同)。
  2. 行为选择:基于当前的Q网络来选择动作(通常使用ε-greedy策略,即以ε的概率选择随机动作,否则选择当前Q值最大的动作)。
  3. 执行动作并存储经验:执行所选动作,观察奖励,并记录状态转移 (st,at,rt+1,st+1)(s_t, a_t, r_{t+1}, s_{t+1})(st​,at​,rt+1​,st+1​)。
  4. 经验回放:从回放池中随机抽取一个小批量的经验数据。
  5. 计算Q值目标:对于每个样本,计算目标值 y=rt+1+γmax⁡a′Q(st+1,a′;θ−)y = r_{t+1} + \gamma \max_{a'} Q(s_{t+1}, a'; \theta^-)y=rt+1​+γmaxa′​Q(st+1​,a′;θ−)。
  6. 更新Q网络:通过最小化损失函数 L(θ)=1N∑(y−Q(st,at;θ))2L(\theta) = \frac{1}{N} \sum (y - Q(s_t, a_t; \theta))^2L(θ)=N1​∑(y−Q(st​,at​;θ))2 来更新Q网络的参数。
  7. 周期性更新目标网络:每隔一段时间,将Q网络的参数复制到目标网络。

DQN的应用

DQN在多个领域取得了重要应用,尤其是在强化学习任务中:

  • Atari 游戏:DQN 在多个经典的 Atari 游戏上成功展示了其能力,比如《Breakout》和《Pong》等。
  • 机器人控制:利用DQN,机器人可以在复杂的环境中自主学习如何执行任务。
  • 自动驾驶:在自动驾驶领域,DQN可以用来训练智能体通过道路、避开障碍物等。

例子:

这里我们手动实现一个非常简单的环境:一个1D平衡问题,类似于一个可以左右移动的棒球,目标是让它保持在某个位置上。

import torch
import torch.nn as nn
import torch.optim as optim
import random
import matplotlib.pyplot as plt# 自定义环境
class SimpleEnv:def __init__(self):self.state = 0.0  # 初始状态self.goal = 10.0  # 目标位置self.done = Falsedef reset(self):self.state = 0.0self.done = Falsereturn self.statedef step(self, action):if self.done:return self.state, 0, self.done  # 游戏结束,不再变化# 通过动作修改状态self.state += action  # 动作是 -1、0、1,控制移动方向reward = -abs(self.state - self.goal)  # 奖励是距离目标位置的负值# 如果距离目标很近,就结束if abs(self.state - self.goal) < 0.1:self.done = Truereward = 10  # 达到目标时奖励较高return self.state, reward, self.done# Q网络定义
class QNetwork(nn.Module):def __init__(self, input_dim, output_dim):super(QNetwork, self).__init__()self.fc = nn.Linear(input_dim, 24)self.fc2 = nn.Linear(24, output_dim)def forward(self, x):x = torch.relu(self.fc(x))x = self.fc2(x)return x# DQN智能体
class DQN:def __init__(self, env, gamma=0.99, epsilon=0.1, batch_size=32, learning_rate=1e-3):self.env = envself.gamma = gammaself.epsilon = epsilonself.batch_size = batch_sizeself.learning_rate = learning_rateself.input_dim = 1  # 因为环境状态是一个单一的数值self.output_dim = 3  # 动作空间大小:-1, 0, 1self.q_network = QNetwork(self.input_dim, self.output_dim)self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.learning_rate)self.criterion = nn.MSELoss()def select_action(self, state):if random.random() < self.epsilon:return random.choice([-1, 0, 1])  # 随机选择动作state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)with torch.no_grad():q_values = self.q_network(state)# 将动作值 -1, 0, 1 转换为索引 0, 1, 2action_idx = torch.argmax(q_values, dim=1).item()action_map = [-1, 0, 1]  # -1 -> 0, 0 -> 1, 1 -> 2return action_map[action_idx]def update(self, state, action, reward, next_state, done):state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)next_state = torch.tensor(next_state, dtype=torch.float32).unsqueeze(0)# 将动作 -1, 0, 1 转换为索引 0, 1, 2action_map = [-1, 0, 1]action_idx = action_map.index(action)action = torch.tensor(action_idx, dtype=torch.long).unsqueeze(0)reward = torch.tensor(reward, dtype=torch.float32).unsqueeze(0)# 确保done是Python标准bool类型done = torch.tensor(done, dtype=torch.float32).unsqueeze(0)# 计算目标Q值with torch.no_grad():next_q_values = self.q_network(next_state)next_q_value = next_q_values.max(1)[0]target_q_value = reward + self.gamma * next_q_value * (1 - done)# 获取当前Q值q_values = self.q_network(state)action_q_values = q_values.gather(1, action.unsqueeze(1)).squeeze(1)# 计算损失并更新Q网络loss = self.criterion(action_q_values, target_q_value)self.optimizer.zero_grad()loss.backward()self.optimizer.step()def train(self, num_episodes=200):rewards = []best_reward = -float('inf')  # 初始最好的奖励设为负无穷best_episode = 0for episode in range(num_episodes):state = self.env.reset()  # 获取初始状态total_reward = 0done = Falsewhile not done:action = self.select_action([state])next_state, reward, done = self.env.step(action)total_reward += reward# 更新Q网络self.update([state], action, reward, [next_state], done)state = next_staterewards.append(total_reward)# 记录最佳奖励和对应的episodeif total_reward > best_reward:best_reward = total_rewardbest_episode = episodeprint(f"Episode {episode}, Total Reward: {total_reward}")# 打印最佳结果print(f"Best Reward: {best_reward} at Episode {best_episode}")# 绘制奖励图plt.plot(rewards)plt.title('Total Rewards per Episode')plt.xlabel('Episode')plt.ylabel('Total Reward')# 在最佳位置添加标记plt.scatter(best_episode, best_reward, color='red', label=f"Best Reward at Episode {best_episode}")plt.legend()plt.show()# 初始化环境和DQN智能体
env = SimpleEnv()
dqn = DQN(env)# 训练智能体
dqn.train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/67789.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Baklib构建高效协同的基于云的内容中台解决方案

内容概要 随着云计算技术的飞速发展&#xff0c;内容管理的方式也在不断演变。企业面临着如何在数字化转型过程中高效管理和协同处理内容的新挑战。为应对这些挑战&#xff0c;引入基于云的内容中台解决方案显得尤为重要。 Baklib作为创新型解决方案提供商&#xff0c;致力于…

deepseek+vscode自动化测试脚本生成

近几日Deepseek大火,我这里也尝试了一下,确实很强。而目前vscode的AI toolkit插件也已经集成了deepseek R1,这里就介绍下在vscode中利用deepseek帮助我们完成自动化测试脚本的实践分享 安装AI ToolKit并启用Deepseek 微软官方提供了一个针对AI辅助的插件,也就是 AI Toolk…

电介质超表面中指定涡旋的非线性生成

涡旋光束在众多领域具有重要应用&#xff0c;但传统光学器件产生涡旋光束的方式限制了其在集成系统中的应用。超表面的出现为涡旋光束的产生带来了新的可能性&#xff0c;尤其是在非线性领域&#xff0c;尽管近些年来已经有一些研究&#xff0c;但仍存在诸多问题&#xff0c;如…

基于Springboot+mybatis+mysql+html图书管理系统2

基于Springbootmybatismysqlhtml图书管理系统2 一、系统介绍二、功能展示1.用户登陆2.用户主页3.图书查询4.还书5.个人信息修改6.图书管理&#xff08;管理员&#xff09;7.学生管理&#xff08;管理员&#xff09;8.废除记录&#xff08;管理员&#xff09; 三、数据库四、其它…

本地部署DeepSeek方法

本地部署完成后的效果如下图&#xff0c;整体与chatgpt类似&#xff0c;只是模型在本地推理。 我们在本地部署主要使用两个工具&#xff1a; ollamaopen-webui ollama是在本地管理和运行大模型的工具&#xff0c;可以直接在terminal里和大模型对话。open-webui是提供一个类…

游戏引擎 Unity - Unity 启动(下载 Unity Editor、生成 Unity Personal Edition 许可证)

Unity Unity 首次发布于 2005 年&#xff0c;属于 Unity Technologies Unity 使用的开发技术有&#xff1a;C# Unity 的适用平台&#xff1a;PC、主机、移动设备、VR / AR、Web 等 Unity 的适用领域&#xff1a;开发中等画质中小型项目 Unity 适合初学者或需要快速上手的开…

【开源免费】基于Vue和SpringBoot的公寓报修管理系统(附论文)

本文项目编号 T 186 &#xff0c;文末自助获取源码 \color{red}{T186&#xff0c;文末自助获取源码} T186&#xff0c;文末自助获取源码 目录 一、系统介绍二、数据库设计三、配套教程3.1 启动教程3.2 讲解视频3.3 二次开发教程 四、功能截图五、文案资料5.1 选题背景5.2 国内…

《苍穹外卖》项目学习记录-Day11订单统计

根据起始时间和结束时间&#xff0c;先把begin放入集合中用while循环当begin不等于end的时候&#xff0c;让begin加一天&#xff0c;这样就把这个区间内的时间放到List集合。 查询每天的订单总数也就是查询的时间段是大于当天的开始时间&#xff08;0点0分0秒&#xff09;小于…

【python】python油田数据分析与可视化(源码+数据集)【独一无二】

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;专__注&#x1f448;&#xff1a;专注主流机器人、人工智能等相关领域的开发、测试技术。 【python】python油田数据分析与可视化&#xff08…

FBX SDK的使用:基础知识

Windows环境配置 FBX SDK安装后&#xff0c;目录下有三个文件夹&#xff1a; include 头文件lib 编译的二进制库&#xff0c;根据你项目的配置去包含相应的库samples 官方使用案列 动态链接 libfbxsdk.dll, libfbxsdk.lib是动态库&#xff0c;需要在配置属性->C/C->预…

一文讲解HashMap线程安全相关问题(上)

HashMap不是线程安全的&#xff0c;主要有以下几个问题&#xff1a; ①、多线程下扩容会死循环。JDK1.7 中的 HashMap 使用的是头插法插入元素&#xff0c;在多线程的环境下&#xff0c;扩容的时候就有可能导致出现环形链表&#xff0c;造成死循环。 JDK 8 时已经修复了这个问…

python学习——常用的内置函数汇总

文章目录 类型转换函数数学函数常用的迭代器操作函数常用的其他内置函数 类型转换函数 数学函数 常用的迭代器操作函数 实操&#xff1a; from cv2.gapi import descr_oflst [55, 42, 37, 2, 66, 23, 18, 99]# (1) 排序操作 asc_lst sorted(lst) # 升序 desc_lst sorted(l…

MySQL数据库环境搭建

下载MySQL 官网&#xff1a;https://downloads.mysql.com/archives/installer/ 下载社区版就行了。 安装流程 看b站大佬的视频吧&#xff1a;https://www.bilibili.com/video/BV12q4y1477i/?spm_id_from333.337.search-card.all.click&vd_source37dfd298d2133f3e1f3e3c…

如何用微信小程序写春联

​ 生活没有模板,只需心灯一盏。 如果笑能让你释然,那就开怀一笑;如果哭能让你减压,那就让泪水流下来。如果沉默是金,那就不用解释;如果放下能更好地前行,就别再扛着。 一、引入 Vant UI 1、通过 npm 安装 npm i @vant/weapp -S --production​​ 2、修改 app.json …

[SAP ABAP] 静态断点的使用

在 ABAP 编程环境中&#xff0c;静态断点通过关键字BREAK-POINT实现&#xff0c;当程序执行到这一语句时&#xff0c;会触发调试器中断程序的运行&#xff0c;允许开发人员检查当前状态并逐步跟踪后续代码逻辑 通常情况下&#xff0c;在代码的关键位置插入静态断点可以帮助开发…

96,【4】 buuctf web [BJDCTF2020]EzPHP

进入靶场 查看源代码 GFXEIM3YFZYGQ4A 一看就是编码后的 1nD3x.php 访问 得到源代码 <?php // 高亮显示当前 PHP 文件的源代码&#xff0c;用于调试或展示代码结构 highlight_file(__FILE__); // 关闭所有 PHP 错误报告&#xff0c;防止错误信息泄露可能的安全漏洞 erro…

基于深度学习的输电线路缺陷检测算法研究(论文+源码)

输电线路关键部件的缺陷检测对于电网安全运行至关重要&#xff0c;传统方法存在效率低、准确性不高等问题。本研究探讨了利用深度学习技术进行输电线路关键组件的缺陷检测&#xff0c;目的是提升检测的效率与准确度。选用了YOLOv8模型作为基础&#xff0c;并通过加入CA注意力机…

3、从langchain到rag

文章目录 本文介绍向量和向量数据库向量向量数据库 索引开始动手实现rag加载文档数据并建立索引将向量存放到向量数据库中检索生成构成一条链 本文介绍 从本节开始&#xff0c;有了上一节的langchain基础学习&#xff0c;接下来使用langchain实现一个rag应用&#xff0c;并稍微…

DeepSeek-R1大模型本地化部署

前言 Ollama作为一个轻量级、易上手的工具&#xff0c;可以帮助你在自己的电脑上快速部署和运行大型语言模型&#xff0c;无需依赖云端服务。通过加载各种开源模型&#xff0c;比如LLaMA、GPT-J等&#xff0c;并通过简单的命令行操作进行模型推理和测试。 此小结主要介绍使用…

【高级篇 / IPv6】(7.6) ❀ 03. 宽带IPv6 - ADSL拨号宽带上网配置 ❀ FortiGate 防火墙

【简介】大部分ADSL拨号宽带都支持IPv6&#xff0c;这里以ADSL拨号宽带为例&#xff0c;演示在FortiGate防火墙上的配置方法。 准备工作 同上篇文章一样&#xff0c;为了兼顾不熟悉FortiGate防火墙的朋友&#xff0c;我们从基础操作进行演示&#xff0c;熟练的朋友可以跳过这一…