强化学习_06_pytorch-TD3实践(CarRacing-v2)

0、TD3算法原理简介

详见笔者前一篇实践强化学习_06_pytorch-TD3实践(BipedalWalkerHardcore-v3)

1、CarRacing环境观察及调整

Action SpaceBox([-1. 0. 0.], 1.0, (3,), float32)
Observation SpaceBox(0, 255, (96, 96, 3), uint8)

动作空间是[-1~1, 0~1, 0~1], 状态空间是 96 × 96 × 3 96\times96\times3 96×96×3 的图片。

1.1 图片裁剪及跳帧

环境初始的时候有40-50帧是没有意义的,可能还会影响模型训练。同时图片下面黑色部分也是没有太多意义,所以可以直接对图片截取s = s[:84, 6:90]
在这里插入图片描述

对环境进行简单观察会发现,一个step是一帧,一帧很难捕捉动作产生的影响(移动量,奖励等)。所以我们进行跳帧观察(1个action进行n个step,期间累计奖励),从红线看,每隔5帧已经可以看出小车在移动。
在这里插入图片描述

1.2 车驶离赛道判断 & reward调整

我们可以看出在gymnasiumCarRacing-V2连续的环境中没有驶出赛道终止的设定,所以我们可以基于像素进行判断是否驶离赛道。观察三个channel,我们可以看出在第二个channel中可以基于大约75行左右的一行像素进行是否行驶出去的判断
经过试验我们可以直接用s[75, 35:50, 1] 前2个和后2个像素点来判断是否行驶到赛道外。
在这里插入图片描述

    def judge_out_of_route(self, obs):s = obs[:84, 6:90, :]out_sum = (s[75, 35:48, 1][:2] > 200).sum() + (s[75, 35:48, 1][-2:] > 200).sum()return out_sum == 4

在加入了是否行驶到赛道外的判断后,如果判断出了赛道则reward=-10

1.4 对多个输出进行通道叠加FrameStack

进行跳帧可以看出车辆的移动,但是只有多张的连续输入,CNN才能感知连续的动作。所以我们这两将4次跳帧组成一个observe,即最终20个step返回一个observe和叠加reward
在这里插入图片描述

1.5 最终环境构建python code

import gymnasium as gym
import torch
import numpy as np
from torchvision import transforms
from gymnasium.spaces import Box
from gymnasium.wrappers import FrameStackclass CarV2SkipFrame(gym.Wrapper):def __init__(self, env, skip: int):"""skip frameArgs:env (_type_): _description_skip (int): skip frames"""super().__init__(env)self._skip = skipdef step(self, action):tt_reward_list = []done = Falsetotal_reward = 0for i in range(self._skip):obs, reward, done, info, _ = self.env.step(action)out_done = self.judge_out_of_route(obs)done_f = done or out_donereward = -10 if out_done else reward# reward = -100 if out_done else reward# reward = reward * 10 if reward > 0 else rewardtotal_reward += rewardtt_reward_list.append(reward)if done_f:breakreturn obs[:84, 6:90, :], total_reward, done_f, info, _def judge_out_of_route(self, obs):s = obs[:84, 6:90, :]out_sum = (s[75, 35:48, 1][:2] > 200).sum() + (s[75, 35:48, 1][-2:] > 200).sum()return out_sum == 4def reset(self, seed=0, options=None):s, info = self.env.reset(seed=seed, options=options)# steering  gas  breakinga = np.array([0.0, 0.0, 0.0])for i in range(45):obs, reward, done, info, _ = self.env.step(a)return obs[:84, 6:90, :], infoclass SkipFrame(gym.Wrapper):def __init__(self, env, skip: int):"""skip frameArgs:env (_type_): _description_skip (int): skip frames"""super().__init__(env)self._skip = skipdef step(self, action):total_reward = 0.0done = Falsefor _ in range(self._skip):obs, reward, done, info, _ = self.env.step(action)total_reward += rewardif done:breakreturn obs, total_reward, done, info, _class GrayScaleObservation(gym.ObservationWrapper):def __init__(self, env):"""RGP -> Gray(high, width, channel) -> (1, high, width) """super().__init__(env)self.observation_space = Box(low=0, high=255, shape=self.observation_space.shape[:2], dtype=np.uint8)def observation(self, observation):tf = transforms.Grayscale()# channel firstreturn tf(torch.tensor(np.transpose(observation, (2, 0, 1)).copy(), dtype=torch.float))class ResizeObservation(gym.ObservationWrapper):def __init__(self, env, shape: int):"""reshape observeArgs:env (_type_): _description_shape (int): reshape size"""super().__init__(env)self.shape = (shape, shape)obs_shape = self.shape + self.observation_space.shape[2:]self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)def observation(self, observation):#  Normalize -> input[channel] - mean[channel]) / std[channel]transformations = transforms.Compose([transforms.Resize(self.shape), transforms.Normalize(0, 255)])return transformations(observation).squeeze(0)env_name = 'CarRacing-v2'
env = gym.make(env_name)
SKIP_N = 5
STACK_N = 4
env_ = FrameStack(ResizeObservation(GrayScaleObservation(CarV2SkipFrame(env, skip=SKIP_N)), shape=84), num_stack=STACK_N
)

二、智能体构建

因为是用的CNN,所以需要注意梯度消失的问题。

2.1 actor

主要架构就是CNN + MLP + maxMinScale

  • CNN: 因为环境比较简单第一层用MaxPool2d采样,第二层进行AvgPool2d平滑
    nn.Sequential(nn.Conv2d(in_channels=4, out_channels=16, kernel_size=4, stride=2),nn.ReLU(),nn.MaxPool2d(2, 2, 0),nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2),nn.ReLU(),nn.AvgPool2d(2, 2, 0),nn.Flatten()
    )
    
  • MLP
    • 对cnn提取的特征进行 LayerNorm (一定程度干预梯度消失)
    • 对最后层全连接层的输出进行 LayerNorm (一定程度干预梯度消失)
  • maxMinScale
    • 最后通过tanh激活层action全部归一化到[-1,1]之间
    • 基于环境的动作上线限,用maxMinScale方式将最终的输出映射到[动作下限,动作上限]

actor 网络

class TD3CNNPolicyNet(nn.Module):"""输入state, 输出action"""def __init__(self, state_dim: int, hidden_layers_dim: typ.List, action_dim: int, action_bound: typ.Union[float, gym.Env]=1.0, state_feature_share: bool=False):super(TD3CNNPolicyNet, self).__init__()self.state_feature_share = state_feature_shareself.low_high_flag = hasattr(action_bound, "action_space")print('action_bound=',action_bound)self.action_bound = action_boundif self.low_high_flag:self.action_high = torch.FloatTensor(action_bound.action_space.low)self.action_low = torch.FloatTensor(action_bound.action_space.high)self.cnn_feature = nn.Sequential(nn.Conv2d(in_channels=4, out_channels=16, kernel_size=4, stride=2),nn.ReLU(),nn.MaxPool2d(2, 2, 0),nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2),nn.ReLU(),nn.AvgPool2d(2, 2, 0),nn.Flatten())self.cnn_out_ln = nn.LayerNorm([512])self.features = nn.ModuleList()for idx, h in enumerate(hidden_layers_dim):self.features.append(nn.ModuleDict({'linear': nn.Linear(hidden_layers_dim[idx-1] if idx else 512, h),'linear_action': nn.ReLU()}))self.fc_out = nn.Linear(hidden_layers_dim[-1], action_dim)self.final_ln = nn.LayerNorm([action_dim])def max_min_scale(self, act):"""X_std = (X - X.min(axis=0)) / (X.max(axis=0) - X.min(axis=0))X_scaled = X_std * (max - min) + min"""# print("max_min_scale(", act, ")")device_ = act.deviceaction_range = self.action_high.to(device_) - self.action_low.to(device_)act_std = (act - -1.0) / 2.0return act_std * action_range.to(device_) + self.action_low.to(device_)def forward(self, state):if len(state.shape) == 3:state = state.unsqueeze(0)try:x = self.cnn_feature(state)except Exception as e:print(state.shape)state = state.permute(0, 3, 1, 2)x = self.cnn_feature(state)x = self.cnn_out_ln(x)for layer in self.features:x = layer['linear_action'](layer['linear'](x))device_ = x.deviceif self.low_high_flag:return self.max_min_scale(torch.tanh(self.final_ln(self.fc_out(x))))return torch.tanh(self.final_ln(self.fc_out(x)).clip(-6.0, 6.0)) * self.action_bound

2.2 critic

  • CNN: 设计同Actor
  • concat状态和action
    • 进行observe和action concat 之前对action进行线性变换(一定程度解决梯度消失 及 原地转圈)
class TD3CNNValueNet(nn.Module):"""输入[state, cation], 输出value"""def __init__(self, state_dim: int, action_dim: int, hidden_layers_dim: typ.List, state_feature_share=False):super(TD3CNNValueNet, self).__init__()self.state_feature_share = state_feature_shareself.q1_cnn_feature = nn.Sequential(nn.Conv2d(in_channels=4, out_channels=16, kernel_size=4, stride=2),nn.ReLU(),nn.MaxPool2d(2, 2, 0),nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2),nn.ReLU(),nn.AvgPool2d(2, 2, 0),nn.Flatten())self.q2_cnn_feature = nn.Sequential(nn.Conv2d(in_channels=4, out_channels=16, kernel_size=4, stride=2),nn.ReLU(),nn.MaxPool2d(2, 2, 0),nn.Conv2d(in_channels=16, out_channels=32, kernel_size=4, stride=2),nn.ReLU(),nn.AvgPool2d(2, 2, 0),nn.Flatten())self.features_q1 = nn.ModuleList()self.features_q2 = nn.ModuleList()for idx, h in enumerate(hidden_layers_dim + [action_dim]):self.features_q1.append(nn.ModuleDict({'linear': nn.Linear(hidden_layers_dim[idx-1] if idx else 512, h),'linear_activation': nn.ReLU()}))self.features_q2.append(nn.ModuleDict({'linear': nn.Linear(hidden_layers_dim[idx-1] if idx else 512, h),'linear_activation': nn.ReLU()}))self.act_q1_fc = nn.Linear(action_dim, action_dim)self.act_q2_fc = nn.Linear(action_dim, action_dim)self.head_q1_bf = nn.Linear(action_dim * 2, action_dim)self.head_q2_bf = nn.Linear(action_dim * 2, action_dim)self.head_q1 = nn.Linear(action_dim, 1)self.head_q2 = nn.Linear(action_dim, 1)def forward(self, state, action):if len(state.shape) == 3:state = state.unsqueeze(0)try:x1 = self.q1_cnn_feature(state)x2 = self.q2_cnn_feature(state)except Exception as e:state = state.permute(0, 3, 1, 2)x1 = self.q1_cnn_feature(state)x2 = self.q2_cnn_feature(state)for layer1, layer2 in zip(self.features_q1, self.features_q2):x1 = layer1['linear_activation'](layer1['linear'](x1))x2 = layer2['linear_activation'](layer2['linear'](x2))# 拼接状态和动作act1 = torch.relu(self.act_q1_fc(action.float()))act2 = torch.relu(self.act_q2_fc(action.float()))x1 = torch.relu( self.head_q1_bf(torch.cat([x1, act1], dim=-1).float()))# print("torch.cat([x1, action], dim=-1)=", torch.cat([x1, act1], dim=-1)[:5, :])x2 = torch.relu( self.head_q2_bf(torch.cat([x2, act2], dim=-1).float()))return self.head_q1(x1), self.head_q2(x2)def Q1(self, state, action):if len(state.shape) == 3:state = state.unsqueeze(0)try:x = self.q1_cnn_feature(state)except Exception as e:state = state.permute(0, 3, 1, 2)x = self.q1_cnn_feature(state)for layer in self.features_q1:x = layer['linear_activation'](layer['linear'](x))# 拼接状态和动作act1 = torch.relu(self.act_q1_fc(action.float()))x = torch.relu( self.head_q1_bf(torch.cat([x, act1], dim=-1).float()))return self.head_q1(x) 

2.3 TD3算法简单调整

  1. policy_noise: 分布调整为(mean=0, std=每个维度动作范围) * self.policy_noise
  2. expl_noise: 分布调整为(mean=0, std=每个维度动作范围) * self.train_noise

3、训练

整体训练脚本可以看笔者的github test_TD3.py : CarRacing_TD3_test()

  1. 对训练做了一些调整: 在训练的过程中增加测试阶段:每隔test_ep_freq进行测试
  2. 基于多次测试的奖励均值进行最佳模型参数保存
def CarRacing_TD3_test():env_name = 'CarRacing-v2'gym_env_desc(env_name)env = gym.make(env_name)env = FrameStack(ResizeObservation(GrayScaleObservation(CarV2SkipFrame(env, skip=5)), shape=84), num_stack=4)print("gym.__version__ = ", gym.__version__ )path_ = os.path.dirname(__file__)cfg = Config(env, # 环境参数save_path=os.path.join(path_, "test_models" ,'TD3_CarRacing-v2_test2-3'), seed=42,# 网络参数actor_hidden_layers_dim=[128], # 256critic_hidden_layers_dim=[128],# agent参数actor_lr=2.5e-4, #5.5e-5,critic_lr=1e-3, #7.5e-4,  gamma=0.99,# 训练参数num_episode=15000,sample_size=128,# 环境复杂多变,需要保存多一些bufferoff_buffer_size=1024*100,  off_minimal_size=256,max_episode_rewards=50000,max_episode_steps=1200, # 200# agent 其他参数TD3_kwargs={'CNN_env_flag': 1,'pic_shape': env.observation_space.shape,"env": env,'action_low': env.action_space.low,'action_high': env.action_space.high,# soft update parameters'tau': 0.05, # trick2: Delayed Policy Update'delay_freq': 1,# trick3: Target Policy Smoothing'policy_noise': 0.2,'policy_noise_clip': 0.5,# exploration noise'expl_noise': 0.5,# 探索的 noise 指数系数率减少 noise = expl_noise * expl_noise_exp_reduce_factor^t'expl_noise_exp_reduce_factor':  1 - 1e-4})agent = TD3(state_dim=cfg.state_dim,actor_hidden_layers_dim=cfg.actor_hidden_layers_dim,critic_hidden_layers_dim=cfg.critic_hidden_layers_dim,action_dim=cfg.action_dim,actor_lr=cfg.actor_lr,critic_lr=cfg.critic_lr,gamma=cfg.gamma,TD3_kwargs=cfg.TD3_kwargs,device=cfg.device)agent.train()train_off_policy(env, agent, cfg, done_add=False, train_without_seed=True, wandb_flag=False, test_ep_freq=100)agent.load_model(cfg.save_path)agent.eval()env = gym.make(env_name, render_mode='human') # env = FrameStack(ResizeObservation(GrayScaleObservation(CarV2SkipFrame(env, skip=5)), shape=84), num_stack=4)play(env, agent, cfg, episode_count=2)

4、训练结果观察及后续工作

由于上传大小限制5MB, 所以对较多直线部分进行了裁剪

最终训练的时候发现会突然陷入低分状态,可以考虑间隔n(可以设置较大比如2000)个episode和最佳的reward比较,分数低于x%个百分点,就重新载入最佳参数,以继续训练。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/578316.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解决国内大模型痛点的最佳实践方案

1.前言 自AI热潮掀起以来,国内互联网大厂躬身入局,各类机构奋起追赶,创业型企业纷至沓来。业内戏称,一场大模型的“百模大战”已经扩展到“千模大战”。 根据近期中国科学技术信息研究所发布的《中国人工智能大模型地图研究报告…

【2023下算法课设】Gray码的分治构造算法

Gray码是一个长度为2ⁿ的序列,序列中无相同元素,且每个元素都是长度为n位的二进制位串,相邻元素恰好只有1位不同。例如长度为2的格雷码为(000,001,011,010,110,111,101,100),设计分治算法对任意的n值构造相…

iOS设备信息详解

文章目录 ID 体系iOS设备信息详解IDFA介绍特点IDFA新政前世今生获取方式 IDFV介绍获取方式 UUID介绍特点获取方式 UDID介绍获取方式 OpenUDID介绍 Bundle ID介绍分类其他 IP地址介绍获取方式 MAC地址介绍获取方式正常获取MAC地址获取对应Wi-Fi的MAC地址 系统版本获取方式 设备型…

云HIS源码 云HIS解决方案 支持医保功能

云HIS系统重建统一的信息架构体系,重构管理服务流程,重造病人服务环境,向不同类型的医疗机构提供SaaS化HIS服务解决方案。 云HIS作为基于云计算的B/S构架的HIS系统,为基层医疗机构(包括诊所、社区卫生服务中心、乡镇卫…

【贪心算法】专题练习一

欢迎来到Cefler的博客😁 🕌博客主页:那个传说中的man的主页 🏠个人专栏:题目解析 🌎推荐文章:题目大解析(3) 前言 1.什么是贪心算法?——贪婪鼠目寸光 贪心策…

在pyqt5界面中直接设置图标icon,不需要python程序代码!!一步搞定!!

小白轻松玩转pyqt5 1. 第一步:点击mainwindow,然后在windowicon中上传图片即可2. 设置成功总结(对于小白入门pyqt5的一些忠告) 1. 第一步:点击mainwindow,然后在windowicon中上传图片即可 2. 设置成功 总结(对于小白入…

【Java 进阶篇】Jedis 操作 List:Redis中的列表类型

Redis中的列表(List)是一种有序的、可重复的数据类型,支持在列表的两端进行元素的插入和删除操作。Jedis作为Java开发者与Redis交互的工具,提供了丰富的API来操作List类型。本文将深入介绍Jedis如何操作Redis中的List类型数据&…

嵌入式-stm32-用PWM点亮LED实现呼吸灯

一:知识前置 1.1、LED灯怎么才能亮? 答:LED需要低电平才能亮,高电平是灯灭。 1.2、LED灯为什么可以越来越亮,越来越暗? 答:这是用到不同占空比来实现的,控制LED实现呼吸灯&…

陈可之油画|《远古河谷》,古老的三峡

《远古河谷》 尺寸:90x66cm 陈可之2002年绘 《远古河谷》是陈可之先生“白垩纪组画七千万年三峡原生映象”系列作品之一,通过细腻的笔触所呈现的神秘,去体会自然的历史、生命的历史以及人文的历史! 三峡,沉淀了7000多…

AI赋能金融创新:ChatGPT引领量化交易新时代

文章目录 一、引言二、ChatGPT与量化交易的融合三、实践应用:ChatGPT在量化交易中的成功案例四、挑战与前景五、结论《AI时代Python量化交易实战:ChatGPT让量化交易插上翅膀》📚→ [当当](http://product.dangdang.com/29658180.html) | [京东…

web前端 JQuery下拉菜单的案例

浏览器运行结果&#xff1a; JQuery下载&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/17LXZigLQ8yau0toTGj4P_Q?pwd4332 提取码&#xff1a;4332 代码&#xff1a; <!doctype html> <html> <head> <meta charset"UTF-8"><…

WPS复选框里打对号,显示小太阳或粗黑圆圈的问题解决方法

问题描述 WPS是时下最流行的字处理软件之一&#xff0c;是目前唯一可以和微软office办公套件相抗衡的国产软件。然而&#xff0c;在使用WPS的过程中也会出现一些莫名其妙的错误&#xff0c;如利用WPS打开docx文件时&#xff0c;如果文件包含复选框&#xff0c;经常会出…

自定义注解结合Hutool对SpringBoot接口返回数据进行脱敏

首先说到脱敏问题,我相信在座的很多人都需要处理这样的场景,比如前端页面显示的身份证号、地址等敏感信息都需要脱敏处理,而hutool就有这样的一个工具来辅助我们完成对某些字段属性信息的脱敏,hutool没有现成的实现方式,只是借助这个工具帮助我们来具体实现 前言 我们在…

【Vue2+3入门到实战】(4)Vue基础之指令修饰符 、v-bind对样式增强的操作、v-model应用于其他表单元素 详细示例

目录 一、今日学习目标1.指令补充 二、指令修饰符1.什么是指令修饰符&#xff1f;2.按键修饰符3.v-model修饰符4.事件修饰符 三、v-bind对样式控制的增强-操作class1.语法&#xff1a;2.对象语法3.数组语法4.代码练习 四、京东秒杀-tab栏切换导航高亮1.需求&#xff1a;2.准备代…

RHCE9学习指南 第7章 服务管理

刚装好Windows系统时&#xff0c;需要进行一些优化&#xff0c;如下图所示。 右键单击所得菜单&#xff0c;可以看到一些按钮包括重启、停止、启动该服务。这些管理的是这个服务的当前状态。 双击服务名&#xff0c;在启动类型中设置的是系统启动时&#xff0c;这个服务要不要…

git之UGit可视化工具使用

一、下载安装UGit 链接&#xff1a;https://pan.baidu.com/s/1KGJvWkFL91neI6vAxjGAag?pwdsyq1 提取码&#xff1a;syq1 二 、使用SSH进行远程仓库连接 1.生成SSH密钥 由于我们的本地 git仓库和 gitee仓库之间的传输是通过SSH加密的&#xff0c;所以我们需要配置SSH公钥。才…

​ iOS技术博客:App备案指南

&#x1f4dd; 摘要 本文介绍了移动应用程序&#xff08;App&#xff09;备案的重要性和流程。备案是规范App开发和运营的必要手段&#xff0c;有助于保护用户权益、维护网络安全和社会秩序。为了帮助开发者更好地了解备案流程&#xff0c;本文提供了一份最新、最全、最详的备…

蓝牙物联网通信网络设计方案

随着当前经济的快速发展&#xff0c;社会运行节奏加快&#xff0c;人们更倾向于选择高效的出行方式&#xff0c;而飞机就是其中之一。近年来&#xff0c;全国各地机场的吞吐量不断增长&#xff0c;导致航站楼面积过大&#xff0c;而 GPS全球定位系统在室内感测不到卫星信号无法…

RPC(6):RMI实现RPC

1RMI简介 RMI(Remote Method Invocation) 远程方法调用。 RMI是从JDK1.2推出的功能&#xff0c;它可以实现在一个Java应用中可以像调用本地方法一样调用另一个服务器中Java应用&#xff08;JVM&#xff09;中的内容。 RMI 是Java语言的远程调用&#xff0c;无法实现跨语言。…

基于java+控件台+mysql的学生信息管理系统(含演示视频)

基于java控件台mysql的学生信息管理系统_含演示视频 一、系统介绍二、功能展示1.项目内容2.项目骨架3.数据库4.登录系统5.新增学生6.查询学生7.修改学生8.删除学生9.退出系统 四、其它1.其他系统实现五.获取源码 一、系统介绍 项目类型&#xff1a;Java SE项目&#xff08;控制…