强化学习嵌入Transformer(代码实践)

这里写目录标题

  • ChatGPT的答案
  • GPT4.0

ChatGPT的答案

# 定义Transformer模块
class Transformer(nn.Module):def __init__(self, input_dim, hidden_dim, num_heads, num_layers):super(Transformer, self).__init__()self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads)self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)self.linear = nn.Linear(input_dim, hidden_dim)def forward(self, x):x = self.transformer_encoder(x)x = self.linear(x)return x# 定义策略网络
class PolicyNetwork(nn.Module):def __init__(self, input_dim, output_dim, hidden_dim, num_heads, num_layers):super(PolicyNetwork, self).__init__()self.transformer = Transformer(input_dim, hidden_dim, num_heads, num_layers)self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x):x = self.transformer(x)x = F.softmax(self.fc(x), dim=-1)return x

问:为什么只有Encoder 没有decoder?
在强化学习任务中,特别是在策略梯度方法中,通常不需要一个完整的Transformer模型,包括Encoder和Decoder。这是因为在这些任务中,我们通常只需要一个从状态到动作的映射,而不需要生成序列。

因此,我们可以只使用Transformer的Encoder部分,将状态作为输入,经过编码后得到一个表示状态的向量,然后将这个向量传递给策略网络(或者Critic网络)来进行动作选择或值函数估计。

使用Transformer的Encoder部分可以有效地处理状态的变长输入,并且具有对序列建模的能力,这在处理许多强化学习任务中是非常有用的。

因此,在PPO算法中,我们只需使用Transformer的Encoder部分,而不需要Decoder部分。在强化学习中,尤其是在处理具有复杂状态和动作空间的任务时,能够捕获状态序列间的依赖关系对于预测未来奖励和决策至关重要。Transformer通过其自注意力机制,能够有效地处理这种序列依赖,从而提升策略性能。

GPT4.0

# time: 2024/3/1 21:31
# author: YanJP
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler# Transformer特征提取器
class TransformerFeatureExtractor(nn.Module):def __init__(self, input_dim, model_dim, num_heads, num_layers, dropout=0.1):super(TransformerFeatureExtractor, self).__init__()self.model_dim = model_dimencoder_layers = TransformerEncoderLayer(d_model=model_dim, nhead=num_heads, dropout=dropout)self.transformer_encoder = TransformerEncoder(encoder_layer=encoder_layers, num_layers=num_layers)self.input_projection = nn.Linear(input_dim, model_dim)def forward(self, x):x = self.input_projection(x)# 添加一个序列长度的维度,假设x的原始形状为(batch, model_dim)x = x.unsqueeze(0)  # 现在x的形状变为(1, batch, model_dim)output = self.transformer_encoder(x)return output.squeeze(0)  # 移除序列长度的维度,恢复到(batch, model_dim)# PPO网络定义
class PPONetwork(nn.Module):def __init__(self, state_dim, action_dim, model_dim=64, num_heads=4, num_layers=4):super(PPONetwork, self).__init__()self.feature_extractor = TransformerFeatureExtractor(input_dim=state_dim, model_dim=model_dim,num_heads=num_heads, num_layers=num_layers)self.policy_head = nn.Linear(model_dim, action_dim)self.value_head = nn.Linear(model_dim, 1)def forward(self, state):features = self.feature_extractor(state)# features = features[:, -1, :]  # 使用最后一个时间步的特征action_probs = torch.softmax(self.policy_head(features), dim=-1)state_values = self.value_head(features)return action_probs, state_values# PPO Agent
class PPOAgent:def __init__(self, env):self.env = envself.state_dim = env.observation_space.shape[0]self.action_dim = env.action_space.nself.network = PPONetwork(self.state_dim, self.action_dim)self.optimizer = optim.Adam(self.network.parameters(), lr=2.5e-4)self.gamma = 0.99self.lamda = 0.95self.eps_clip = 0.2self.K_epoch = 4self.buffer_capacity = 1000self.batch_size = 64self.buffer = {'states': [], 'actions': [], 'log_probs': [], 'rewards': [], 'is_terminals': []}def select_action(self, state):state = torch.FloatTensor(state).unsqueeze(0)with torch.no_grad():action_probs, _ = self.network(state)dist = Categorical(action_probs)action = dist.sample()return action.item(), dist.log_prob(action)def put_data(self, transition):self.buffer['states'].append(transition[0])self.buffer['actions'].append(transition[1])self.buffer['log_probs'].append(transition[2])self.buffer['rewards'].append(transition[3])self.buffer['is_terminals'].append(transition[4])def train_net(self):R = 0discounted_rewards = []for reward, is_terminal in zip(reversed(self.buffer['rewards']), reversed(self.buffer['is_terminals'])):if is_terminal:R = 0R = reward + (self.gamma * R)discounted_rewards.insert(0, R)discounted_rewards = torch.tensor(discounted_rewards, dtype=torch.float32)old_states = torch.tensor(np.array(self.buffer['states']), dtype=torch.float32)old_actions = torch.tensor(self.buffer['actions']).view(-1, 1)old_log_probs = torch.tensor(self.buffer['log_probs']).view(-1, 1)# Normalize the rewardsdiscounted_rewards = (discounted_rewards - discounted_rewards.mean()) / (discounted_rewards.std() + 1e-5)for _ in range(self.K_epoch):for index in BatchSampler(SubsetRandomSampler(range(len(self.buffer['states']))), self.batch_size, False):# Extract batchesstate_sample = old_states[index]action_sample = old_actions[index]old_log_probs_sample = old_log_probs[index]returns_sample = discounted_rewards[index].view(-1, 1)# Get current policiesaction_probs, state_values = self.network(state_sample)dist = Categorical(action_probs)entropy = dist.entropy().mean()new_log_probs = dist.log_prob(action_sample.squeeze(-1))# Calculating the ratio (pi_theta / pi_theta__old):ratios = torch.exp(new_log_probs - old_log_probs_sample.detach())# Calculating Surrogate Loss:advantages = returns_sample - state_values.detach()surr1 = ratios * advantagessurr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantagesloss = -torch.min(surr1, surr2) + 0.5 * (state_values - returns_sample).pow(2) - 0.01 * entropy# take gradient stepself.optimizer.zero_grad()loss.mean().backward()self.optimizer.step()self.buffer = {'states': [], 'actions': [], 'log_probs': [], 'rewards': [], 'is_terminals': []}def train(self, max_episodes):for episode in range(max_episodes):state = self.env.reset()done = Falserewards=0while not done:action, log_prob = self.select_action(state)next_state, reward, done, _ = self.env.step(action)rewards+=rewardself.put_data((state, action, log_prob, reward, done))state = next_stateif done:self.train_net()if episode % 5 == 0:print("eposide:", episode, '\t reward:', rewards)# 主函数
def main():env = gym.make('CartPole-v1')agent = PPOAgent(env)max_episodes = 300agent.train(max_episodes)if __name__ == "__main__":main()

注意:代码能跑,但是不能正常学习到策略!!!!!!!!!!!!!!!!!!!!!!!!!!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/712949.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue3中组件通讯的方式

Vue3中组件通讯的方式 1 🤖GPT🤖: (答案有点问题混淆了vue2的内容) 父组件向子组件传递数据 props 子组件通过 props 属性从父组件接收数据。emit事件子组件通过emit 事件 子组件通过 emit事件子组件通过emit 发射事件向父组件发送消息。provide / in…

Java SpringCloud gateway面试题

Java SpringCloud gateway面试题 前言1、什么是网关Zuul(gateway)?2、服务网关的作用?3、Zuul网关(Gateway)如何搭建集群?4、ZuulFilter常用有那些方法?5、如何实现动态zuul网关路由转发?6、在Z…

kubeadm安装部署

目录 1.要求 2.环境准备 3.所有节点安装docker 4.所有节点安装kubeadm,kubelet和kubectl 5.部署K8S集群 6.测试 7.扩展3个副本 8.部署Dashboard master(2C/4G,cpu核心数要求大于2)192.168.27.10docker、kubeadm、kubelet、…

LightDB - ecpg 支持dml 中使用 return into 【24.1】

在之前的版本中ecpg 中只能使用returning into 来给c 变量赋值,如下: exec sql update t1 set c aa where id 2 returning c into :c_val;为了兼容oracle pro*c 中return into 的用法,从24.1 开始, LightDB 也支持通过return in…

Chrome插件 | WEB 网页数据采集和爬虫程序

无边无形的互联网遍地是数据,品类丰富、格式繁多,包罗万象。数据采集,或说抓取,就是把分散各处的内容,通过各种方式汇聚一堂,是个有讲究要思考的体力活。君子爱数,取之有道,得注意遵…

mobile app 安全扫描工具MobSF了解下

可以干啥: static 静态分析 dynamic 动态分析 可以用来渗透了 如何docker安装 docker image 下载地址https://hub.docker.com/r/opensecurity/mobile-security-framework-mobsf/ setup 两行即可 1 docker pull opensecurity/mobile-security-framework-mobsf…

关于VScode远程编写linux SHELL的报错处理

使用vscode远程编写linux保存shell时,提示报错: 未能保存“shell”: 无法写入文件"vscode-remote:.../tmp/shell"(NoPermissions (FileSystemError): Error: EACCES: permission denied, open /tmp/shell) 大体意思是说:权限被拒…

Python | 从子目录文件导入父目录模块的方法

问题描述 我有两级目录,第一级称为parent_dir,第二级称为child_dir。现在在child_dir下,有一个py,称为child.py,在parent_dir下,也有一个py,称为parent.py。 我想从child.py中导入parent.py中…

Go Slice的底层实现原理深度解析

文章目录 切片的诞生:数组的延伸切片的结构初始化切片 切片的内存管理扩容机制 实例分析:切片的动态特性切片与性能性能对比 切片的并发安全并发场景下的切片操作 切片与接口切片与空接口 切片的遍历与操作遍历切片切片的切片操作 切片的垃圾回收切片的生…

年轻人怎么搞钱?

年轻人想要搞钱,可以考虑以下几个方面: 1. 创业:年轻人可以通过自己的创意,找到一个市场的空缺,开创自己的业务。可以从比较小的项目开始,逐渐扩大范围,积累经验和财富。 2. 投资:…

成为大佬之路--linux软件安装使用第000000021篇--linux安装docker

简介 Docker 是一个开源项目,诞生于 2013 年初,最初是 dotCloud 公司内部的一个业余项目。它基于 Google 公司推出的 Go 语言实现。 项目后来加入了 Linux 基金会,遵从了 Apache 2.0 协议,项目代码在 [GitHub](https://github.co…

Hadoop之HDFS——【模块二】数据管理

一、Namespace的概述 1.1.集群与命名空间的关系 类似于大集群与小集群之间的关系,彼此之间独立又相互依存。每个namespace彼此独立,Namespace工作时只负责维护本区域的数据,同时所有的namespace维护的文件都可以共用DataNode节点,为了区分数据属于哪些Namespace,DataNode…

强大而灵活的python装饰器

装饰器(Decorators) 一、概述 在Python中,装饰器是一种特殊类型的函数,它允许我们修改或增强其他函数的功能,而无需修改其源代码。装饰器在函数定义之后立即调用,并以函数对象作为参数。装饰器返回一个新…

力扣151--反转字符串中的单词(优)

清晰易懂,简单高效! 大体思路: 每次截取到想要的单词,拼接到新的sb中,过程中伴随双指针进行空格位置指向控制, 其中如果start指针如果0的情况要放在第一个判断条件防止边界条件失效,并且这种…

Linux系统运维脚本:shell脚本查看一定网段范围在线网络设备的ip地址和不在线的网络设备的数量(查看在线和不在线网络设备)

目 录 一、需求说明 二、解决方案 (一)解决思路 (二)方案 三、脚本程序实现 (一)脚本代码和解释 1、脚本代码 2、代码解释 (二)脚本验证 1、脚本编辑…

CrossOver 24下载-CrossOver 24 for Mac下载 v24.0.0中文永久版

CrossOver 24是一款可以让mac用户能够自由运行和游戏windows游戏软件的虚拟机类应用,虽然能够虚拟windows但是却并不是一款虚拟机,也不需要重启系统或者启动虚拟机,类似于一种能够让mac系统直接运行windows软件的插件。它以其出色的跨平台兼容…

NVMe开发——PCIe复位

简介 PCIe中有4种复位机制,早期的3种被称为传统复位(Conventional Reset)。传统复位中的前2种又称为基本复位(Fundamental Resets),分别为冷复位(Cold Reset),暖复位(Warm Reset)。第3种复位为热复位(Hot Reset)。第4种复位被称为功能级复位…

js 正则记录

正则表达式 正则表达式创建一个正则表达式修饰符常用的特殊字符使用正则表达式的方法replace指定字符串作为替换项使用场景:交换字符串中的两个单词将"-"链接的方式改为驼峰式(忽略开头的-)将华氏温度转换为响应的摄氏温度 常用正则示例判断输入是否是正确…

使用docker安装dolphinscheduler

1、前提是安装docker和docker-compose 2、#mkdir /data/dolphinscheduler 3、镜像 docker load -i dolphinscheduler-mysql-driver.tar docker pull zookeeper:3.6.2:3.6.2 docker tag a7 bitnami/zookeeper:3.6.2 理论上postgresql也可以在线pull,但是在线do…

179基于matlab的2D-VMD处理图像

基于matlab的2D-VMD处理图像,将图片进行VMD分解,得到K个子模态图,将每个模态图进行重构,得到近似的原图。可以利用这点进行图像去噪。程序已调通,可直接运行。 179 2D-VMD 图像分解重构 图像处理 (xiaohongshu.com)