强化学习 - Twin Delayed DDPG (TD3)

什么是机器学习

Twin Delayed DDPG (TD3) 是一种用于解决连续动作空间的强化学习问题的算法,是 Deep Deterministic Policy Gradient (DDPG) 的改进版本。TD3引入了一些技巧,例如双Q网络Twin Q-networks)和延迟更新,以提高算法的性能和稳定性。

以下是一个使用 Python 和 TensorFlow/Keras 实现简单的 TD3 的示例。在这个例子中,我们将使用 OpenAI GymPendulum 环境。

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam
import gym# 定义TD3 Agent
class TD3Agent:def __init__(self, state_size, action_size):self.state_size = state_sizeself.action_size = action_sizeself.gamma = 0.99  # 折扣因子self.tau = 0.005  # 软更新参数self.actor_lr = 0.001self.critic_lr = 0.001self.policy_noise = 0.2self.policy_noise_clip = 0.5self.exploration_noise = 0.1self.buffer_size = 1000000self.batch_size = 100self.buffer = []# 构建演员(Actor)网络和目标演员网络self.actor = self.build_actor()self.target_actor = self.build_actor()self.target_actor.set_weights(self.actor.get_weights())# 构建两个评论家(Critic)网络和目标评论家网络self.critic_1 = self.build_critic()self.target_critic_1 = self.build_critic()self.target_critic_1.set_weights(self.critic_1.get_weights())self.critic_2 = self.build_critic()self.target_critic_2 = self.build_critic()self.target_critic_2.set_weights(self.critic_2.get_weights())def build_actor(self):state_input = Input(shape=(self.state_size,))dense1 = Dense(400, activation='relu')(state_input)dense2 = Dense(300, activation='relu')(dense1)output = Dense(self.action_size, activation='tanh')(dense2)model = Model(inputs=state_input, outputs=output)model.compile(loss='mse', optimizer=Adam(lr=self.actor_lr))return modeldef build_critic(self):state_input = Input(shape=(self.state_size,))action_input = Input(shape=(self.action_size,))concat = tf.keras.layers.concatenate([state_input, action_input])dense1 = Dense(400, activation='relu')(concat)dense2 = Dense(300, activation='relu')(dense1)output = Dense(1, activation='linear')(dense2)model = Model(inputs=[state_input, action_input], outputs=output)model.compile(loss='mse', optimizer=Adam(lr=self.critic_lr))return modeldef get_action(self, state):state = np.reshape(state, [1, self.state_size])action = self.actor.predict(state)[0]action = np.clip(action + np.random.normal(0, self.exploration_noise, self.action_size), -1, 1)return actiondef train(self):if len(self.buffer) < self.batch_size:returnbatch = np.random.choice(self.buffer, self.batch_size, replace=False)states, actions, rewards, next_states, dones = zip(*batch)states = np.vstack(states)actions = np.vstack(actions)rewards = np.vstack(rewards)next_states = np.vstack(next_states)dones = np.vstack(dones)next_actions = self.target_actor.predict(next_states) + np.clip(np.random.normal(0, self.policy_noise, self.action_size), -self.policy_noise_clip, self.policy_noise_clip)next_actions = np.clip(next_actions, -1, 1)target_q_values = np.minimum(self.target_critic_1.predict([next_states, next_actions]),self.target_critic_2.predict([next_states, next_actions]))target_values = rewards + self.gamma * (1 - dones) * target_q_valuesself.critic_1.train_on_batch([states, actions], target_values)self.critic_2.train_on_batch([states, actions], target_values)actor_gradients = np.reshape(self.critic_1.gradient(states + self.actor.predict(states)), [-1, self.action_size])actor_gradients = actor_gradients / self.batch_sizeself.actor.train_on_batch(states, actor_gradients)self.soft_update_target_networks()def soft_update_target_networks(self):actor_weights = np.array(self.actor.get_weights())target_actor_weights = np.array(self.target_actor.get_weights())self.target_actor.set_weights(self.tau * actor_weights + (1 - self.tau) * target_actor_weights)critic_1_weights = np.array(self.critic_1.get_weights())target_critic_1_weights = np.array(self.target_critic_1.get_weights())self.target_critic_1.set_weights(self.tau * critic_1_weights + (1 - self.tau) * target_critic_1_weights)critic_2_weights = np.array(self.critic_2.get_weights())target_critic_2_weights = np.array(self.target_critic_2.get_weights())self.target_critic_2.set_weights(self.tau * critic_2_weights + (1 - self.tau) * target_critic_2_weights)def store_experience(self, state, action, reward, next_state, done):self.buffer.append((state, action, reward, next_state, done))if len(self.buffer) > self.buffer_size:self.buffer.pop(0)# 初始化环境和Agent
env = gym.make('Pendulum-v0')
state_size = env.observation_space.shape[0]
action_size = env.action_space.shape[0]
agent = TD3Agent(state_size, action_size)# 训练TD3 Agent
num_episodes = 500
for episode in range(num_episodes):state = env.reset()total_reward = 0for time in range(500):  # 限制每个episode的步数,防止无限循环# env.render()  # 如果想可视化训练过程,可以取消注释此行action = agent.get_action(state)next_state, reward, done, _ = env.step(action)total_reward += rewardagent.store_experience(state, action, reward, next_state, done)agent.train()state = next_stateif done:print("Episode: {}, Total Reward: {}".format(episode + 1, total_reward))break# 关闭环境
env.close()

在这个例子中,我们定义了一个简单的TD3 Agent,包括演员(Actor)和两个评论家(Critic)神经网络。在训练过程中,我们使用了两个评论家网络和一些技巧来提高稳定性,并进行了软更新。请注意,TD3算法的实现可能因问题的复杂性而有所不同,可能需要更多的技术和调整,如归一化奖励、使用更复杂的神经网络结构等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/652408.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uniapp scroll-view用法[下拉刷新,触底事件等等...](4)

前言:可滚动视图区域。用于区域滚动 话不多说 直接上官网属性 官网示例 讲一下常用的几个 scroll 滚动时触发 scrolltoupper 滚动到顶部或左边&#xff0c;会触发 scrolltoupper 事件 scrolltolower 滚动到底部或右边&#xff0c;会触发 scrolltolower 事件 1.纵向滚动…

【HTML教程】跟着菜鸟学语言—HTML5个人笔记经验(四)

HTML学习第三天&#xff01; PS&#xff1a;牛牛只是每天花了1.5-2小时左右来学习HTML。 书接上回 HTML<div>和<span> HTML 可以通过<div> 和 <span>将元素组合起来。 HTML 区块元素 大多数 HTML 元素被定义为块级元素或内联元素。 块级元素在浏…

【Git配置代理】Failed to connect to github.com port 443 问题解决方法

前言&#xff1a; 在学习代码审计时&#xff0c;有时会需要使用git去拉取代码&#xff0c;然后就出现了如下错误 看过网上很多解决方法&#xff0c;觉得问题的关键还是因为命令行在拉取/推送代码时并没有使用VPN进行代理。 解决办法 &#xff1a; 配置http代理&#xff1a;…

MySQL-round()四舍五入取整函数

定义和用法 ROUND() 函数将数字四舍五入到指定的小数位数。 语法 ROUND(number, decimals) 参数值 参数描述number必需。要四舍五入的数字decimals可选。number 要四舍五入的小数位数。 如果省略&#xff0c;则返回整数&#xff08;无小数&#xff09; 1、ROUND(X)函数 …

huanju一台dell机器Ubuntu wifi 故障留档

非常神奇的一台机器&#xff0c;先说结果&#xff0c;放弃该机器了&#xff0c;让另一台机器顶上去。 故障表现&#xff1a; 1 在经过一路颠簸后&#xff0c;wifi一开&#xff0c;不管连没连上&#xff0c;屏幕都疯狂输出报错信息 [11376.275959] pcieport 0000:00:1c.7: AER…

【Linux】第三十七站:信号保存

文章目录 一、信号发送二、信号保存1.为什么要进行信号保存&#xff1f; 三、阻塞信号1.信号的一些相关概念2.在内核中的表示3.sigset_t4.信号集操作函数5.sigprocmask6.sigpending7. 总结 一、信号发送 如下所示&#xff0c;对于普通信号&#xff0c;它的编号是从1~31。这个是…

指针的深入了解2

1.const修饰指针 在这之前我们还学过static修饰变量&#xff0c;那我们用const来修饰一下变量会有什么样的效果呢&#xff1f; 我们来看看&#xff1a; 我们可以看到编译器报错告诉我们a变成了一个不可修改的值&#xff0c;我们在变量前加上了const进行限制&#xff0c;但是我…

使用py-spy对python程序进行性能诊断学习

py-spy简介 py-spy是一个用Rust编写的轻量级Python分析工具&#xff0c;它能够监视正在运行的Python程序&#xff0c;而不需要修改代码或者重新启动程序。Py-spy可以在不影响程序运行的情况下&#xff0c;采集程序运行时的信息&#xff0c;生成火焰图&#xff08;flame graph&…

php数组算法(1)判断一维数组和多元数组中的元素是否相等并输出键值key

在php中&#xff0c;如何判断[1,0,1]和[ [0, 0, 0],//体质正常 [1, 0, 0],//气虚体质 [0, 1, 0],//血瘀体质 [0, 0, 1],//阴虚体质 [1, 1, 0],//气虚兼血瘀体质 [1, 0, 1],//气虚兼阴虚体质 [0, 1, 1],//血瘀兼阴虚体质 [1, 1, 1],//气虚兼血瘀兼阴虚体质 ];中的第n项相等&…

SpringBoot集成MyBatis操作MySql8的JSON类型

SpringBoot集成MyBatis操作MySql8的JSON类型 1.定义Json类型转换器&#xff1a;JsonTypeHandler 一个包有一个类型转换器就够了开箱即用&#xff0c;复制即可 package com.ins.iot.sync.server.handle;import com.fasterxml.jackson.annotation.JsonInclude; import com.fas…

linux 基于科大讯飞的文字转语音使用

官方文档地址&#xff1a;离线语音合成 Linux SDK 文档 | 讯飞开放平台文档中心 一、SDK下载 1、点击上面官方文档地址的链接&#xff0c;可以跳转到以下界面。 2、点击“普通版”&#xff0c;跳转到以下界面。 3、点击“下载”跳转到以下界面 4、最后&#xff0c;点击“SDK下…

Qt6入门教程 12:QAbstractButton

目录 一.状态 二.信号 三.使用 1.自定义按钮 2.多选 3.互斥 QAbstractButton类实现了一个抽象按钮&#xff0c;并且让它的子类来指定如何处理用户的动作&#xff0c;并指定如何绘制按钮。QAbstractButton类是所有按钮控件的基类。 QAbstractButton提供…

【阿里云服务器数据迁移】 同一个账号 不同区域服务器

前言 假如说一台云服务器要过期了,现在新买了一台,有的人会烦恼又要将重新在新的服务器上装环境,部署上线旧服务器上的网站项目, 但是不必烦恼,本文将介绍如何快速将就旧的服务器上的数据迁移到新的服务器上. 包括所有的环境和网站项目噢 ! 步骤 (1) 创建旧服务器自定义镜像…

Linux命令 - 统计log日志某接口用户访问频次并排序

​ 背景 某天发现内部人员使用的app服务器访问突增&#xff0c;但不影响服务正常运行&#xff0c;想通过log统计接口的人员访问频次。 从监控平台可以看到访问激增的接口&#xff0c;因Nginx不缓存用户信息只有访问IP&#xff0c;日志清洗的Hive表只能访问前一天&#xff0c;…

行测-资料:2. 一般增长率、增长量

1、一般增长率 1.1 百分数和百分点 50%&#xff0c;20% 1.2 增长率和倍数 1.5&#xff1b;50 1.3 成数和翻番 1.4 增幅&#xff0c;降幅&#xff0c;变化幅度 A&#xff0c;A&#xff0c;D B&#xff0c;高于全国增速 2.3 个百分点&#xff0c;21.8 - 2.3 19.5。 5%&#xff0…

Oracle PL/SQL Programming 第1章:Introduction to PL/SQL 读书笔记

总的目录和进度&#xff0c;请参见开始读 Oracle PL/SQL Programming 第6版 PL/SQL 是 “Procedural Language extensions to the Structured Query Language.”的缩写。 什么是PL/SQL? Oracle 公司推出 PL/SQL 是为了克服 SQL 中的一些限制&#xff0c;并为那些寻求构建针…

RTPS协议

文章目录 RTPS(Real-time Publish-Subscribe Protocol)RTPS(Real-time Publish-Subscribe Protocol) RTPS(Real-time Publish-Subscribe Protocol)是一种基于发布/订阅模型的协议,用于实时数据分发。它是 OMG 组织定义的标准协议,用于支持分布式实时系统中的数据分发和…

打开 IOS开发者模式

前言 需要 1、辅助设备&#xff1a;苹果电脑&#xff1b; 2、辅助应用&#xff1a;Xcode&#xff1b; 3、准备工作&#xff1a;苹果手机 使用数据线连接 苹果电脑&#xff1b; 当前系统版本 IOS 17.3 通过Xcode激活 两指同时点击 Xcode 显示选择&#xff0c;Open Develop…

重生奇迹MU平民玩家推荐的职业

女魔法师 女魔法师是一个非常适合平民玩家的职业选择。她拥有着强大的魔法攻击能力&#xff0c;可以轻松地击败敌人。而且女魔法师的装备价格相对较低&#xff0c;适合玩家们的经济实力。 精灵射手 精灵射手是一个非常灵活的职业选择。他们可以远程攻击&#xff0c;可以在战…

go-carbon v2.3.7 发布,轻量级、语义化、对开发者友好的 golang 时间处理库

carbon 是一个轻量级、语义化、对开发者友好的 golang 时间处理库&#xff0c;支持链式调用。 目前已被 awesome-go 收录&#xff0c;如果您觉得不错&#xff0c;请给个 star 吧 github.com/golang-module/carbon gitee.com/golang-module/carbon 安装使用 Golang 版本大于…