(3) Gymnasium -- Testing CartPole with DQN

1. A DQN implementation with PyTorch

1.1 Main references

(1) The official PyTorch tutorial (recommended):

Reinforcement Learning (DQN) Tutorial — PyTorch Tutorials 2.0.1+cu117 documentation

(2)

PyTorch Deep Reinforcement Learning – the CartPole problem | 极客笔记

1.2 How the official PyTorch tutorial works

To be continued -- I have been busy these past few days; I will tidy this section up in a couple of days.
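In the meantime, here is a minimal summary (my own sketch, read directly off the comments in the code of section 1.3, not an official statement of the tutorial) of the two update rules the implementation uses. For a sampled transition (s, a, r, s'), the policy network is trained toward a TD target computed with the target network:

y = r + GAMMA * max_a' Q_target(s', a')        (y = r if s' is a terminal state)
loss = SmoothL1(Q_policy(s, a), y)             (Huber loss)

and after every optimization step the target network is soft-updated toward the policy network:

θ′ ← TAU·θ + (1 − TAU)·θ′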

1.3 Code implementation

import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

env = gym.make("CartPole-v1")

# set up matplotlib
# is_ipython = 'inline' in matplotlib.get_backend()
# if is_ipython:
#     from IPython import display

plt.ion()

# if GPU is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)


class DQN(nn.Module):

    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)


# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the ``AdamW`` optimizer
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)

policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)

steps_done = 0


def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # t.max(1) will return the largest column value of each row.
            # second column on max result is index of where max element was
            # found, so we pick action with the larger expected reward.
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)


episode_durations = []


def plot_durations(show_result=False):
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    # if is_ipython:
    #     if not show_result:
    #         display.display(plt.gcf())
    #         display.clear_output(wait=True)
    #     else:
    #         display.display(plt.gcf())


def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
    # detailed explanation). This converts batch-array of Transitions
    # to Transition of batch-arrays.
    batch = Transition(*zip(*transitions))

    # Compute a mask of non-final states and concatenate the batch elements
    # (a final state would've been the one after which simulation ended)
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
    # Compute the expected Q values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    # In-place gradient clipping
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()


if torch.cuda.is_available():
    num_episodes = 600
else:
    # num_episodes = 50
    num_episodes = 600

for i_episode in range(num_episodes):
    # Initialize the environment and get its state
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated

        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()
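
Once training has finished, one simple way to check the learned behavior is to run the trained policy_net greedily in a freshly created, human-rendered environment. The rollout below is not part of the official tutorial; it is a minimal sketch that assumes the training loop above has already run and reuses policy_net, device, and the same Gymnasium reset/step API.

# Evaluation sketch (not in the original tutorial): greedy rollout of the trained policy_net.
eval_env = gym.make("CartPole-v1", render_mode="human")
state, info = eval_env.reset()
state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
total_reward = 0.0
while True:
    with torch.no_grad():
        # pick the action with the highest predicted Q-value (no epsilon-greedy exploration)
        action = policy_net(state).max(1)[1].view(1, 1)
    observation, reward, terminated, truncated, _ = eval_env.step(action.item())
    total_reward += reward
    if terminated or truncated:
        break
    state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
print('evaluation return:', total_reward)
eval_env.close()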
