GNN+A2C 强化学习训练一个粒子避障决策模型

最近想尝试下使用GNN+ A2C 进行强化学习,GNN 可以充当一个特征提取器,这样可以增加强化学状态空间因为张量长度受限泛化能力不足的缺点,之前做强化学习的时候受限于需要在环境里提取每个对手的特征,在每个不同场景下因为对手的数量是变化的,对应的状态空间也得一一对应,每个场景的训练都是定制化的。 提高泛化能力一直是模型训练推理的一个课题。

恰好最近在看图神经网络相关的内容,这里贴点废话回头找工作面试官问起来可以喷一些

图神经网络(GNN)是一种专门用于处理图结构数据的深度学习方法,具有以下优点:

优点

  1. 处理非欧几里得数据:GNN 能够有效处理图结构数据,而传统的神经网络主要处理的是欧几里得数据(如图像和文本)。

  2. 捕捉节点间关系:GNN 可以利用节点之间的连接关系来捕捉复杂的结构信息,从而更好地理解数据的上下文。

  3. 灵活性:GNN 可以应用于各种类型的图,包括无向图、有向图和加权图,适应性强。

  4. 共享参数:GNN 的参数可以在图的不同部分共享,这减少了模型的复杂性和训练时间。

  5. 强大的特征学习能力:GNN 能够自动学习节点的特征表示,并通过聚合邻居节点的信息来更新自身的特征。

  6. 适用于多种任务:GNN 可以用于节点分类、边预测、图分类等多种任务。

图神经网络通过以下方式扩展了特征:

  1. 邻居信息聚合:GNN 通过聚合邻居节点的特征,将局部结构信息融入到节点的特征中,从而生成更丰富的特征表示。

  2. 多层堆叠:通过多层堆叠的网络结构,GNN 能够逐步捕捉更高阶的邻接信息,使得节点的特征不仅反映自身的信息,还能反映其邻居的特征和关系。

  3. 动态更新:节点特征在每一层中不断更新,使得特征能够随着图的结构变化而变化,从而增强了模型的表达能力。

图神经网络(GNN)与强化学习(RL)的结合,形成了图强化学习(Graph Reinforcement Learning),这种结合具有多种优点:

  • 结构化数据处理:GNN 能够有效处理图结构数据,使得 RL 能够在复杂的环境中(如社交网络、交通网络等)做出更好的决策。

  • 信息传递:GNN 通过节点间的信息传递,将邻居节点的状态和特征引入到决策过程中,提高了智能体对环境的理解。

  • 特征学习:GNN 可以自动学习图中节点的特征表示,帮助强化学习算法更好地估计状态值和动作值,提升策略的性能。

  • 上下文感知:结合 GNN 的强化学习能够更好地捕捉环境的动态变化,适应不同的上下文,从而提高决策的灵活性和准确性。

这么说可能很多人还不是清楚优点具体是什么,下面我用个试验的例子来说明

首先制作CS架构的粒子干扰避障的游戏,这里就不细讲了,直接上代码

服务端代码

# server.py
from flask import Flask, request, jsonify
import threading
import randomapp = Flask(__name__)# 存储主球的位置和粒子
clients = {}
particles = []
particles_number = 30# 初始化粒子
def generate_particles():while len(particles) < particles_number:  # 生成初始粒子particles.append({'x': random.randint(0, 500),'y': random.randint(0, 500),'vx': random.choice([-6, -3, -1, 1, 3, 6]),'vy': random.choice([-6, -3, -1, 1, 3, 6])})def update_particles():while True:for particle in particles:# 更新粒子位置particle['x'] += particle['vx']particle['y'] += particle['vy']# 碰撞边界处理if particle['x'] <= 0 or particle['x'] >= 500:particle['vx'] *= -1if particle['y'] <= 0 or particle['y'] >= 500:particle['vy'] *= -1threading.Event().wait(0.1)@app.route('/register', methods=['POST'])
def register_client():client_id = request.json.get('id')clients[client_id] = {'position': {'x': 250, 'y': 250}}  # 初始化主球位置return jsonify(success=True)width, height = 500, 500
ball_radius = 15@app.route('/move/<client_id>', methods=['POST'])
def move(client_id):direction = request.json.get('direction')if client_id in clients:# 获取当前球的位置position = clients[client_id]['position']if direction == 'up':new_y = position['y'] - 10if new_y >= 0:  # 确保不超出上边界position['y'] = new_yelif direction == 'down':new_y = position['y'] + 10if new_y <= height - ball_radius:  # 确保不超出下边界 (减去半径)position['y'] = new_yelif direction == 'left':new_x = position['x'] - 10if new_x >= 0:  # 确保不超出左边界position['x'] = new_xelif direction == 'right':new_x = position['x'] + 10if new_x <= width - ball_radius:  # 确保不超出右边界 (减去半径)position['x'] = new_xreturn jsonify(clients[client_id]['position'])@app.route('/position/<client_id>', methods=['GET'])
def get_position(client_id):if client_id in clients:return jsonify(clients[client_id]['position'])else:return jsonify({'error': 'Client not found'}), 404@app.route('/particles', methods=['GET'])
def get_particles():return jsonify(particles)def run_server():app.run(host='0.0.0.0', port=5000, threaded=True)if __name__ == '__main__':threading.Thread(target=generate_particles, daemon=True).start()threading.Thread(target=update_particles, daemon=True).start()run_server()

客户端代码

import pygame
import requests# 初始化pygame
pygame.init()# 设置窗口大小
width, height = 500, 500
window = pygame.display.set_mode((width, height))
pygame.display.set_caption("Particle Avoidance Game")# 颜色
BLACK = (0, 0, 0)  # 黑色
BLUE = (0, 0, 255)
RED = (255, 0, 0)# 主球初始位置
ball_radius = 15
client_id = 'client1'  # 确保每个客户端使用不同的 ID# 注册客户端
register_response = requests.post('http://127.0.0.1:5000/register', json={'id': client_id})
if register_response.status_code != 200:print("Failed to register client:", register_response.text)def get_ball_position():response = requests.get(f'http://127.0.0.1:5000/position/{client_id}')if response.status_code == 200:ball_pos = response.json()print("Ball position response:", ball_pos)  # 打印响应if 'x' in ball_pos and 'y' in ball_pos:  # 确保包含 x 和 yreturn ball_poselse:print("Ball position does not contain 'x' and 'y':", ball_pos)  # 额外调试信息else:print("Failed to get ball position:", response.text)return {'x': 250, 'y': 250}  # 默认值def move_ball(direction):requests.post(f'http://127.0.0.1:5000/move/{client_id}', json={'direction': direction})def get_particles():response = requests.get('http://127.0.0.1:5000/particles')if response.status_code == 200:particles = response.json()print("Particles response:", particles)  # 打印响应return particleselse:print("Failed to get particles:", response.text)return []  # 返回空列表def check_collision(ball_pos, particle_pos):distance = ((ball_pos['x'] - particle_pos['x']) ** 2 + (ball_pos['y'] - particle_pos['y']) ** 2) ** 0.5return distance < (ball_radius + 5)  # 粒子的半径为5running = True
while running:for event in pygame.event.get():if event.type == pygame.QUIT:running = Falsekeys = pygame.key.get_pressed()if keys[pygame.K_UP]:move_ball('up')if keys[pygame.K_DOWN]:move_ball('down')if keys[pygame.K_LEFT]:move_ball('left')if keys[pygame.K_RIGHT]:move_ball('right')# 更新球的位置ball_position = get_ball_position()particles = get_particles()# 检查碰撞if 'x' in ball_position and 'y' in ball_position:  # 确保球有有效位置for particle in particles:if 'x' in particle and 'y' in particle:  # 确保粒子有有效位置if check_collision(ball_position, particle):print("Game Over! You collided with a particle.")running = Falseelse:print("Invalid ball position:", ball_position)  # 调试信息# 渲染window.fill(BLACK)  # 将背景填充为黑色if 'x' in ball_position and 'y' in ball_position:  # 确保球有有效位置pygame.draw.circle(window, BLUE, (ball_position['x'], ball_position['y']), ball_radius)else:print("Ball position is invalid, not drawing.")  # 调试信息for particle in particles:if 'x' in particle and 'y' in particle:  # 确保粒子有有效位置pygame.draw.circle(window, RED, (particle['x'], particle['y']), 5)  # 画粒子pygame.display.flip()pygame.time.delay(100)pygame.quit()

启动程序调试了下,基本可以充当强化学习的环境

后续直接在客户端上添加强化模型的训练代码,把server端代码部署到k8s 上做为训练环境

要设计一个基于图神经网络(GNN)和优势演员-评论家(A2C)算法的强化学习模型,以训练 main_ball 在环境中左右移动,我们可以遵循以下步骤:

1. 确定问题

  • 目标:控制 main_ball 左右移动以避免与粒子碰撞。
  • 状态空间:包含 main_ball 和粒子的状态信息,包括位置和速度。
  • 动作空间:定义 main_ball 的动作(左右上下或保持不动)。

2. 数据结构和环境设计

首先,构建一个环境来模拟 main_ball 和粒子的动态行为。

class Environment:def __init__(self, main_ball, particles):self.main_ball = main_ballself.particles = particlesdef reset(self):# 重置环境,返回初始状态return self.get_state()def get_state(self):# 获取当前状态state = {'main_ball': self.main_ball,'particles': self.particles}return statedef step(self, action):# 根据动作更新环境状态,按照上右下左顺序if not action_space[action]:move_ball(action_space[action])# 定义奖励和终止标志reward = 0done = Falsenew_main_ball = get_ball_position()particles = get_particles()COLLISION_THRESHOLD = 20for particle in self.particles:distance = ((self.main_ball['x'] - particle['x']) ** 2 + (self.main_ball['y'] - particle['y']) ** 2) ** 0.5if distance < COLLISION_THRESHOLD:  # 定义一个阈值判断碰撞reward = -1  # 碰撞时给予负奖励break# 检查是否重新开始if 'x' in new_main_ball and 'y' in new_main_ball:  # 确保球有有效位置for particle in particles:if 'x' in particle and 'y' in particle:  # 确保粒子有有效位置if check_collision(new_main_ball, particle):print("Game Over! You collided with a particle.")done = True  # 结束游戏breakif not done:reward = 5return self.get_state(), reward, done, {}

3. GNN 构建

使用 DGL 构建图结构,以表示 main_ball 和粒子之间的关系。

import dgl
import torchdef create_graph(main_ball, particles):num_particles = len(particles)G = dgl.graph(([], []), num_nodes=num_particles + 1)# 添加主球节点G.ndata['pos'] = torch.zeros(num_particles + 1, 2)G.ndata['pos'][0] = torch.tensor([main_ball['x'], main_ball['y']])# 添加粒子节点for i, particle in enumerate(particles):G.ndata['pos'][i + 1] = torch.tensor([particle['x'], particle['y']])# 添加边和距离权重edges = []distances = []for i in range(num_particles):for j in range(i + 1, num_particles):edges.append((i + 1, j + 1))  # 粒子之间的边# 计算距离并存储distance = torch.norm(G.ndata['pos'][i + 1] - G.ndata['pos'][j + 1])distances.append(distance.item())edges.append((0, i + 1))  # 主球与粒子之间的边# 计算距离并存储distance = torch.norm(G.ndata['pos'][0] - G.ndata['pos'][i + 1])distances.append(distance.item())G.add_edges(*zip(*edges))# 将距离作为边的特征G.edata['distance'] = torch.tensor(distances)# 添加自环G = dgl.add_self_loop(G)return G

4. A2C 模型设计

使用 PyTorch 设计 A2C 模型,包括 Actor 和 Critic。

在强化学习中,"Actor"(演员)和"Critic"(评论家)是两个关键的角色或组件。

Actor(演员):
Actor 是强化学习中的一个组件,通常用于确定在给定状态下应该采取的动作。它负责根据当前状态选择动作,并将其发送给环境。
Actor 的目标是学习一个策略,即从状态到动作的映射函数,以便在与环境的交互中获得高回报。策略可以是确定性的(直接选择最优动作)或概率性的(选择动作的概率分布)。Actor 的训练目标是最大化预期回报,通常使用梯度上升方法(如策略梯度法)进行优化。

Critic(评论家):
Critic 是强化学习中的另一个组件,用于评估 Actor 的动作选择。它通过对当前状态和采取的动作进行评估,
提供一个值函数(或者动作值函数)来估计在给定策略下获得的长期回报。Critic 的目标是学习一个值函数,
用于评估不同状态-动作对的价值,并提供即时的反馈信号。Critic 的训练目标是最小化值函数的预测误差,
通常使用时序差分学习(如 Q-learning 或 TD-learning)或函数逼近方法(如神经网络)进行优化。

在某些强化学习算法中,Actor 和 Critic 可以是分离的组件,各自独立进行训练。
Actor 使用 Critic 提供的价值信息来指导动作选择;
而 Critic 使用 Actor 选择的动作进行评估和训练。它们通过交互和相互反馈来改善策略和值函数的性能。

class ActorCritic(nn.Module):def __init__(self, n_devices, action_space_dim):super(ActorCritic, self).__init__()self.conv1 = dgl.nn.GraphConv(2, 128)  # 输入特征为 2:位置self.conv2 = dgl.nn.GraphConv(128, (1 + action_space_dim))  # 隐藏层self.commonCov = nn.Linear((len(particles) +1 ) * (1 + action_space_dim), 128)self.actor = nn.Linear(128, action_space_dim)  # 行动空间self.critic = nn.Linear(128, 1)def forward(self, g):g = dgl.add_self_loop(g)x = g.ndata['pos']x = self.conv1(g, x)x = F.relu(x)x = self.conv2(g, x)x = F.relu(x)x = x.reshape(-1)x = torch.relu(self.commonCov(x))actor = self.actor(x)critic = self.critic(x)return actor, critic

5. 训练循环

设计训练循环,使用强化学习算法更新模型的参数。

# 初始化环境和模型
main_ball = get_ball_position()
particles = get_particles()
env = Environment(main_ball, particles)
device_id = 1
if torch.cuda.is_available():device_id = torch.cuda.current_device()
model = ActorCritic(n_devices=device_id, action_space_dim=len(action_space)).to(device)
# 检查模型文件是否存在
model_path = 'actor_critic_model.pth'
if os.path.exists(model_path):# 加载保存的模型状态model.load_state_dict(torch.load(model_path))print("Model loaded successfully.")model.train()  # 切换到训练模式
optimizer = optim.Adam(model.parameters(), lr=0.01)# 训练循环
for episode in range(num_episodes):state = env.reset()done = Falsewhile not done:# 创建图g = create_graph(state['main_ball'], state['particles'])# 前向传播actor_logits, critic_value = model(g)# 选择动作(使用 softmax)action_prob = F.softmax(actor_logits, dim=-1)# 使用 torch.multinomial 选择一个动作action = torch.multinomial(action_prob, num_samples=1).item()# 执行动作并获取下一个状态和奖励next_state, reward, done, _ = env.step(action)# 这里需要存储轨迹并计算损失# 更新模型参数optimizer.zero_grad()_, next_value = model(g)td_target = reward + 0.95 * next_valuedelta = td_target - critic_value'''这是actor网络的损失函数。目标是最大化选择当前动作的对数概率乘以TD误差。乘以delta.detach()是为了使actor网络的更新不会影响critic网络的预测。'''actor_loss = -action_prob[action] * delta.detach()critic_loss = delta.pow(2)loss = actor_loss + critic_lossprint("epoch = %d, device_id = %d, epoch_loss= %lf" % (episode, device_id, loss.item()))loss.backward()optimizer.step()if done:restart_game()state = next_state# 保存模型状态torch.save(model.state_dict(), model_path)print(f"Model saved after episode {episode}.")

通过以上步骤,你可以构建一个 GNN + A2C 的强化学习模型来训练 main_ball 的左右移动策略。

启动训练 num_episodes = 1000

强化学习在对接仿真有没有gpu加速速度并不明显,因为受限于仿真state-action-next_state 这套流程的处理速度,我使用在本机上跑该训练,训练结束后保存模型文件到actor_critic_model.pth

6. 模型推理

使用 actor_critic_model.pth接入到之前的客户端代码上,用模型决策来替换键盘操作

import os
import pygame
import requests
import random
import string
import torch
import torch.nn.functional as F
from g1_client_train import ActorCritic, create_graph, action_space# 初始化pygame
pygame.init()# 设置窗口大小
width, height = 500, 500
window = pygame.display.set_mode((width, height))
pygame.display.set_caption("Particle Avoidance Game")# 颜色
BLACK = (0, 0, 0)  # 黑色
BLUE = (0, 0, 255)
RED = (255, 0, 0)# 主球初始位置
ball_radius = 15
if torch.cuda.is_available():# 使用 CUDA 设备device = torch.device("cuda")
else:# 使用 CPU 设备device = torch.device("cpu")
def generate_random_string(length=5):# 定义可用字符,包括数字和大小写字母characters = string.ascii_letters + string.digits# 随机选择字符并生成字符串random_string = ''.join(random.choice(characters) for _ in range(length))return random_string# 生成并打印随机字符串
random_string = generate_random_string()
print(random_string)
client_id=random_stringhost = os.getenv("SERVER_HOST", 'http://192.168.110.126:31007')def register_client():register_response = requests.post(f'{host}/register', json={'id': client_id})if register_response.status_code != 200:print("Failed to register client:", register_response.text)return Falsereturn Truedef get_ball_position():response = requests.get(f'{host}/position/{client_id}')if response.status_code == 200:ball_pos = response.json()print("Ball position response:", ball_pos)  # 打印响应if 'x' in ball_pos and 'y' in ball_pos:  # 确保包含 x 和 yreturn ball_posprint("Failed to get ball position:", response.text)return None  # 返回 None 表示失败def move_ball(direction):requests.post(f'{host}/move/{client_id}', json={'direction': direction})def get_particles():response = requests.get(f'{host}/particles')if response.status_code == 200:particles = response.json()print("Particles response:", particles)  # 打印响应return particlesprint("Failed to get particles:", response.text)return []  # 返回空列表def check_collision(ball_pos, particle_pos):distance = ((ball_pos['x'] - particle_pos['x']) ** 2 + (ball_pos['y'] - particle_pos['y']) ** 2) ** 0.5return distance < (ball_radius + 5)  # 粒子的半径为5def restart_game():global runningprint("Restarting game...")running = Trueregister_client()  # 重新注册客户端if __name__ == "__main__":# 注册客户端if not register_client():exit()device_id = 1if torch.cuda.is_available():device_id = torch.cuda.current_device()# 更新球的位置main_ball = get_ball_position()particles = get_particles()model = ActorCritic(n_devices=device_id,particles_dim=len(particles),action_space_dim=len(action_space)).to(device)model_path = 'actor_critic_model.pth'model.load_state_dict(torch.load(model_path))running = Truewhile running:for event in pygame.event.get():if event.type == pygame.QUIT:running = Falsemain_ball = get_ball_position()particles = get_particles()g = create_graph(main_ball=main_ball, particles=particles)actor_logits, _ = model(g)action_prob = F.softmax(actor_logits, dim=-1)action = torch.multinomial(action_prob, num_samples=1).item()move_ball(action_space[action])# 检查碰撞if main_ball:  # 确保球有有效位置for particle in particles:if 'x' in particle and 'y' in particle:  # 确保粒子有有效位置if check_collision(main_ball, particle):print("Game Over! You collided with a particle.")restart_game()breakelse:print("Invalid ball position, restarting game.")restart_game()# 渲染window.fill(BLACK)  # 将背景填充为黑色if main_ball:  # 确保球有有效位置pygame.draw.circle(window, BLUE, (main_ball['x'], main_ball['y']), ball_radius)for particle in particles:if 'x' in particle and 'y' in particle:  # 确保粒子有有效位置pygame.draw.circle(window, RED, (particle['x'], particle['y']), 5)  # 画粒子pygame.display.flip()pygame.time.delay(100)pygame.quit()

录个动图看看效果

有必要再优化下奖励函数,每一步避险操作后的next_state应该是优于之前state

训练的好的模型奖惩函数规则都挺细的

附上完整的训练代码文件

import osimport requests
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import dgl
import dgl.nn as dglnn
import os
import random
import stringfrom g1_server import speed_choicesball_radius = 15
width, height = 500, 500
num_episodes = 1000min_vx = min(speed_choices)
max_vx = max(speed_choices)
min_vy = min(speed_choices)
max_vy = max(speed_choices)def register_client():register_response = requests.post(f'{host}/register', json={'id': client_id})if register_response.status_code != 200:print("Failed to register client:", register_response.text)return Falsereturn Truedef restart_game():global runningprint("Restarting game...")running = Trueregister_client()  # 重新注册客户端def generate_random_string(length=5):# 定义可用字符,包括数字和大小写字母characters = string.ascii_letters + string.digits# 随机选择字符并生成字符串random_string = ''.join(random.choice(characters) for _ in range(length))return random_string# 生成并打印随机字符串
random_string = generate_random_string()
print(random_string)
client_id = random_stringhost = os.getenv("SERVER_HOST", 'http://192.168.110.126:31007')
# 注册客户端
register_response = requests.post(f'{host}/register', json={'id': client_id})
if register_response.status_code != 200:print("Failed to register client:", register_response.text)if torch.cuda.is_available():# 使用 CUDA 设备device = torch.device("cuda")
else:# 使用 CPU 设备device = torch.device("cpu")def get_ball_position():response = requests.get(f'{host}/position/{client_id}')if response.status_code == 200:ball_pos = response.json()print("Ball position response:", ball_pos)  # 打印响应if 'x' in ball_pos and 'y' in ball_pos:  # 确保包含 x 和 yreturn ball_poselse:print("Ball position does not contain 'x' and 'y':", ball_pos)  # 额外调试信息else:print("Failed to get ball position:", response.text)return {'x': 250, 'y': 250}  # 默认值def move_ball(direction):requests.post(f'{host}/move/{client_id}', json={'direction': direction})def get_particles():response = requests.get(f'{host}/particles')if response.status_code == 200:particles = response.json()print("Particles response:", particles)  # 打印响应return particleselse:print("Failed to get particles:", response.text)return []  # 返回空列表def check_collision(ball_pos, particle_pos):distance = ((ball_pos['x'] - particle_pos['x']) ** 2 + (ball_pos['y'] - particle_pos['y']) ** 2) ** 0.5return distance < (ball_radius + 5)  # 粒子的半径为5def create_graph(main_ball, particles):num_particles = len(particles)G = dgl.graph(([], []), num_nodes=num_particles + 1)# 添加主球节点G.ndata['pos'] = torch.zeros(num_particles + 1, 2)G.ndata['pos'][0] = torch.tensor([main_ball['x'], main_ball['y']])# 添加粒子节点for i, particle in enumerate(particles):G.ndata['pos'][i + 1] = torch.tensor([particle['x'], particle['y']])# 添加边和距离权重edges = []distances = []for i in range(num_particles):for j in range(i + 1, num_particles):edges.append((i + 1, j + 1))  # 粒子之间的边# 计算距离并存储distance = torch.norm(G.ndata['pos'][i + 1] - G.ndata['pos'][j + 1])distances.append(distance.item())edges.append((0, i + 1))  # 主球与粒子之间的边# 计算距离并存储distance = torch.norm(G.ndata['pos'][0] - G.ndata['pos'][i + 1])distances.append(distance.item())G.add_edges(*zip(*edges))# 将距离作为边的特征G.edata['distance'] = torch.tensor(distances)# 添加自环G = dgl.add_self_loop(G)return Gdef compute_state(main_ball, particles):flow = []normalized_main_ball_x = main_ball['x'] / widthnormalized_main_ball_y = main_ball['y'] / heightflow.append([0, 0, normalized_main_ball_x, normalized_main_ball_y])for particle in particles:# 归一化粒子的速度和位置normalized_vx = (particle['vx'] - min_vx) / (max_vx - min_vx)  # 根据你的数据范围进行归一化normalized_vy = (particle['vy'] - min_vy) / (max_vy - min_vy)normalized_x = particle['x'] / widthnormalized_y = particle['y'] / heightflow.append([normalized_vx, normalized_vy, normalized_x, normalized_y])return torch.tensor(flow).float()'''
在强化学习中,"Actor"(演员)和"Critic"(评论家)是两个关键的角色或组件。Actor(演员):
Actor 是强化学习中的一个组件,通常用于确定在给定状态下应该采取的动作。它负责根据当前状态选择动作,并将其发送给环境。
Actor 的目标是学习一个策略,即从状态到动作的映射函数,以便在与环境的交互中获得高回报。策略可以是确定性的(直接选择最优动作)或概率性的(选择动作的概率分布)。Actor 的训练目标是最大化预期回报,通常使用梯度上升方法(如策略梯度法)进行优化。Critic(评论家):
Critic 是强化学习中的另一个组件,用于评估 Actor 的动作选择。它通过对当前状态和采取的动作进行评估,
提供一个值函数(或者动作值函数)来估计在给定策略下获得的长期回报。Critic 的目标是学习一个值函数,
用于评估不同状态-动作对的价值,并提供即时的反馈信号。Critic 的训练目标是最小化值函数的预测误差,
通常使用时序差分学习(如 Q-learning 或 TD-learning)或函数逼近方法(如神经网络)进行优化。在某些强化学习算法中,Actor 和 Critic 可以是分离的组件,各自独立进行训练。
Actor 使用 Critic 提供的价值信息来指导动作选择;
而 Critic 使用 Actor 选择的动作进行评估和训练。它们通过交互和相互反馈来改善策略和值函数的性能。
'''
class Environment:def __init__(self, main_ball, particles):self.main_ball = main_ballself.particles = particlesdef reset(self):# 重置环境,返回初始状态return self.get_state()def get_state(self):# 获取当前状态state = {'main_ball': self.main_ball,'particles': self.particles}return statedef step(self, action):# 根据动作更新环境状态,按照上右下左顺序if not action_space[action]:move_ball(action_space[action])# 定义奖励和终止标志reward = 0done = Falsenew_main_ball = get_ball_position()particles = get_particles()COLLISION_THRESHOLD = 20for particle in self.particles:distance = ((self.main_ball['x'] - particle['x']) ** 2 + (self.main_ball['y'] - particle['y']) ** 2) ** 0.5if distance < COLLISION_THRESHOLD:  # 定义一个阈值判断碰撞reward = -1  # 碰撞时给予负奖励break# 检查是否重新开始if 'x' in new_main_ball and 'y' in new_main_ball:  # 确保球有有效位置for particle in particles:if 'x' in particle and 'y' in particle:  # 确保粒子有有效位置if check_collision(new_main_ball, particle):print("Game Over! You collided with a particle.")done = True  # 结束游戏breakif not done:reward = 5return self.get_state(), reward, done, {}class ActorCritic(nn.Module):def __init__(self, n_devices, action_space_dim):super(ActorCritic, self).__init__()self.conv1 = dgl.nn.GraphConv(2, 128)  # 输入特征为 2:位置self.conv2 = dgl.nn.GraphConv(128, (1 + action_space_dim))  # 隐藏层self.commonCov = nn.Linear((len(particles) +1 ) * (1 + action_space_dim), 128)self.actor = nn.Linear(128, action_space_dim)  # 行动空间self.critic = nn.Linear(128, 1)def forward(self, g):g = dgl.add_self_loop(g)x = g.ndata['pos']x = self.conv1(g, x)x = F.relu(x)x = self.conv2(g, x)x = F.relu(x)x = x.reshape(-1)x = torch.relu(self.commonCov(x))actor = self.actor(x)critic = self.critic(x)return actor, criticaction_space = ['up', 'down', 'left', 'right', None]if not register_client():print('not register')exit()# 初始化环境和模型
main_ball = get_ball_position()
particles = get_particles()
env = Environment(main_ball, particles)
device_id = 1
if torch.cuda.is_available():device_id = torch.cuda.current_device()
model = ActorCritic(n_devices=device_id, action_space_dim=len(action_space)).to(device)
# 检查模型文件是否存在
model_path = 'actor_critic_model.pth'
if os.path.exists(model_path):# 加载保存的模型状态model.load_state_dict(torch.load(model_path))print("Model loaded successfully.")model.train()  # 切换到训练模式
optimizer = optim.Adam(model.parameters(), lr=0.01)# 训练循环
for episode in range(num_episodes):state = env.reset()done = Falsewhile not done:# 创建图g = create_graph(state['main_ball'], state['particles'])# 前向传播actor_logits, critic_value = model(g)# 选择动作(使用 softmax)action_prob = F.softmax(actor_logits, dim=-1)# 使用 torch.multinomial 选择一个动作action = torch.multinomial(action_prob, num_samples=1).item()# 执行动作并获取下一个状态和奖励next_state, reward, done, _ = env.step(action)# 这里需要存储轨迹并计算损失# 更新模型参数optimizer.zero_grad()_, next_value = model(g)td_target = reward + 0.95 * next_valuedelta = td_target - critic_value'''这是actor网络的损失函数。目标是最大化选择当前动作的对数概率乘以TD误差。乘以delta.detach()是为了使actor网络的更新不会影响critic网络的预测。'''actor_loss = -action_prob[action] * delta.detach()critic_loss = delta.pow(2)loss = actor_loss + critic_lossprint("epoch = %d, device_id = %d, epoch_loss= %lf" % (episode, device_id, loss.item()))loss.backward()optimizer.step()if done:restart_game()state = next_state# 保存模型状态torch.save(model.state_dict(), model_path)print(f"Model saved after episode {episode}.")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/883147.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用excel.js(layui-excel)进行layui多级表头导出,根据单元格内容设置背景颜色,并将导出函数添加到toolbar

本段是菜狗子的碎碎念&#xff0c;解决办法请直接从第二段开始看。layui多级表头的导出&#xff0c;弄了两天才搞定&#xff0c;中途一度想放弃&#xff0c;还好坚持下来了。一开始用的是layui的toolbar里自带的那个导出&#xff0c;但是多级表头没有正常导出&#xff0c;单元格…

【功能安全】技术安全概念TSC

目录 01 TSC定义 02 TSC注意事项 03 TSC案例 01 TSC定义 所处位置 TSC:Technical safety concept技术安全概念 TSR:Technical safety requirement技术安全需求 在系统开发阶段属于安全活动4-6 系统层产品开发示例 TSC目的

Codeforces Round 981 (Div. 3)

前言&#xff1a; 记录一下自己昨天晚上打的div3吧&#xff0c;感觉自己好久没写博客&#xff0c;以后可能会更新一些其他内容&#xff0c;在这里先买个关子&#xff0c;我要现在今年沉淀几个月&#xff0c;所以这几天可能不会更新博客&#xff0c;今天先出来冒个泡先。 正文&a…

数理统计(第3章:单侧假设检验)

目录 概念&#xff0c;步骤 单个正态母体 两个正态母体 概念&#xff0c;步骤 如果构造统计量是一个未知数&#xff0c;则构造不成统计量&#xff0c;所以拿来构造统计量&#xff0c;用保守估计作为假设&#xff1a;有无显著提高&#xff0c;减小&#xff0c;则假设没有显著…

【在Win11下安装ubuntu +图形化界面】

在win11下安装ubuntu 一、安装流程1. 前期准备&#xff1a;先配置好基础设置2. 安装 ubuntu3. ubuntu进行配置4. 下载图形化界面 并安装 二、遇到的问题问题1. win11安装wsl报错&#xff1a;无法解析服务器的名称或地址1. 方法一&#xff1a;更改DNS&#xff08;对本人无效&…

SpringBoot最佳实践之 - 项目中统一记录正常和异常日志

1. 前言 此篇博客是本人在实际项目开发工作中的一些总结和感悟。是在特定需求背景下&#xff0c;针对项目中统一记录日志(包括正常和错误日志)需求的实现方式之一&#xff0c;并不是普适的记录日志的解决方案。所以阅读本篇博客的朋友&#xff0c;可以参考此篇博客中记录日志的…

【问题解决】三维相关:​Unity Package Manager中没有Newtonsoft Json‌​

问题&#xff1a; 在Unity开发中&#xff0c;用到复杂的json的数据格式&#xff0c;需要将对象和json数据之间相互转换。Unity原生json支持不适用复杂json&#xff08;例如嵌套数组、动态键值对等&#xff09;。大部分人推荐直接在Package Manager中搜索导入(如怎么在unity3D工…

Jupyter Notebook 中使用render_notebook渲染pyecharts图像不显示的一种情况

一开始我发现自己的jupyter文件在渲染pyecharts图片时一开始可以显示&#xff0c;但后来不知道怎么的就不显示了&#xff0c;查找了很多方法&#xff0c;但是没有效果&#xff0c;都是改js渲染什么的&#xff0c;还有就是参数不对的&#xff0c;对于我来说都没什么用&#xff0…

excel中,将时间戳(ms或s)转换成yyyy-MM-dd hh:mm.ss或毫秒格式

问题 在一些输出为时间戳的文本中&#xff0c;按照某种格式显示更便于查看。 如下&#xff0c;第一列为时间戳(s)&#xff0c;第二列是转换后的格式。 解决方案&#xff1a; 在公式输入框中输入&#xff1a;yyyy/mm/dd hh:mm:ss TEXT((A18*3600)/8640070*36519, "yyy…

从传统到智能,从被动监控到主动预警,解锁视频安防平台EasyCVR视频监控智能化升级的关键密钥

视频监控技术从传统监控到智能化升级的过程是一个技术革新和应用场景拓展的过程。智能视频监控系统通过集成AI和机器学习算法&#xff0c;能够实现行为分析、人脸识别和异常事件检测等功能&#xff0c;提升了监控的准确性和响应速度。这些系统不仅用于传统的安全防护&#xff0…

Ribbon客户端负载均衡策略测试及其改进

文章目录 一、目的概述二、验证步骤1、源码下载2、导入IDE3、运行前修改配置4、策略说明5、修改策略 三、最终结论四、改进措施1. 思路分析2. 核心代码3. 测试页面 一、目的概述 为了验证Ribbon客户端负载均衡策略在负载节点失效的情况下&#xff0c;是否具有故障转移的功能&a…

一家生物技术企业终止,科创属性可能不足,报告期内专利数猛增

轩凯生物九成以上营业收入来源于植物营养领域&#xff0c;收入来源结构单一&#xff0c;产品下游应用领域较为集中。报告期内公司应收账款账面价值逐年上升&#xff0c;回款比例显著低于前两年&#xff0c;遭交易所问询是否存在较大的坏账风险。 轩凯生物核心技术是否成熟以及是…

【SDL】微软SDL建设指南

【SDL】微软SDL建设指南 1.建立安全标准、指标和治理2.要求使用经过验证的安全功能、语言和框架3.执行安全设计审查和威胁建模4.定义并使用密码学标准5.确保软件供应链安全6.确保工程环境安全7.执行安全测试8.确保运营平台安全9.实施安全监控和响应&#xff08;态势管理或漏洞管…

二十、Innodb底层原理与Mysql日志机制深入剖析

文章目录 一、MySQL的内部组件结构1、Server层1.1、连接器1.2、查询缓存1.3、分析器1.4、优化器1.5、执行器 2、存储引擎层 二、Innodb底层原理与Mysql日志机制1、redo log重做日志关键参数2、binlog二进制归档日志2.1、binlog日志文件恢复数据 3、undo log回滚日志4、错误日志…

群晖通过 Docker 安装 Firefox

1. 获取 firefox 镜像 在注册表搜索 jlesage/firefox&#xff0c;并且下载 2. 创建容器 运行映像 jlesage/firefox&#xff0c;开始创建容器 3. 配置容器 启用自动重新启动&#xff0c;重点配置存储空间和环境变量&#xff0c;其他默认。 创建文件夹&#xff0c;及子文件夹…

高效设备管理:中小企业的Spring Boot解决方案

1系统概述 1.1 研究背景 随着计算机技术的发展以及计算机网络的逐渐普及&#xff0c;互联网成为人们查找信息的重要场所&#xff0c;二十一世纪是信息的时代&#xff0c;所以信息的管理显得特别重要。因此&#xff0c;使用计算机来管理中小企业设备管理系统的相关信息成为必然。…

Lucas带你手撕机器学习——SVM支持向量机

#1024程序员节&#xff5c;征文# 支持向量机&#xff08;SVM&#xff09;的详细讲解 什么是SVM&#xff1f; 支持向量机&#xff08;Support Vector Machine&#xff0c;SVM&#xff09;是一种用于分类和回归的监督学习算法。它的主要任务是从给定的数据中找到一个最佳的决策…

原来“有符号数变成无符号数,并不是-1变成1,-15变成15”!!

不怕大家伙笑话&#xff0c;我以前一直以为在C语言中&#xff0c;有符号变无符号仅仅就是去掉数字前面的符号就行&#xff0c;如今做了一道题&#xff0c;细细研究&#xff0c;才发现&#xff0c;原来不是&#xff01; 如果你也感兴趣&#xff0c;那就学学今天这节吧~ 话不多说…

前端必知必会-JavaScript 简介

文章目录 JavaScript 简介JavaScript 可以更改 HTML 内容JavaScript 可以更改 HTML 属性值JavaScript 可以更改 HTML 样式 (CSS)JavaScript 可以隐藏 HTML 元素JavaScript 可以显示 HTML 元素 总结 JavaScript 简介 本页包含一些 JavaScript 功能的示例。 JavaScript 可以更改…

计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-20

计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-20 目录 文章目录 计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-20目录1. FLARE: Faithful Logic-Aided Reasoning and Exploration摘要研究背景问题与挑战如何解决创新点算法模型实验效果重要数…