Recently I wanted to try reinforcement learning with GNN + A2C. The GNN acts as a feature extractor, which helps with a weakness of fixed-length state tensors: they generalize poorly. In my earlier RL work I had to extract features for every opponent in the environment, and because the number of opponents changes from scenario to scenario, the state space had to be redefined to match each one, so training for every scenario was a custom job. Improving generalization has always been an open problem in model training and inference.
I also happen to be reading about graph neural networks lately, so here is some background material I can ramble about if an interviewer brings it up when I'm job hunting.
A graph neural network (GNN) is a deep learning method designed specifically for graph-structured data. Its advantages include:
Advantages
- Handles non-Euclidean data: GNNs work directly on graph-structured data, whereas traditional neural networks mainly handle Euclidean data such as images and text.
- Captures relationships between nodes: GNNs exploit the connections between nodes to capture complex structural information and better understand the context of the data.
- Flexibility: GNNs apply to many kinds of graphs, including undirected, directed and weighted graphs.
- Parameter sharing: GNN parameters are shared across different parts of the graph, which reduces model complexity and training time.
- Strong feature learning: GNNs automatically learn node representations, updating each node's features by aggregating information from its neighbors.
- Suitable for many tasks: GNNs can be used for node classification, link prediction, graph classification and more.
A GNN extends node features in the following ways (a minimal sketch follows this list):
- Neighbor aggregation: by aggregating the features of neighboring nodes, a GNN folds local structural information into each node's feature, producing a richer representation.
- Multi-layer stacking: stacking several layers lets the GNN capture progressively higher-order neighborhood information, so a node's feature reflects not only the node itself but also its neighbors and their relationships.
- Dynamic updates: node features are updated at every layer, so the representation adapts as the graph structure changes, which strengthens the model's expressive power.
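To make the neighbor-aggregation idea concrete, here is a minimal, self-contained DGL sketch on a hypothetical toy graph (unrelated to the game below): stacking two GraphConv layers means each node's embedding mixes in information from a progressively larger neighborhood.

# Toy example: neighbor aggregation with two stacked GraphConv layers
import dgl
import dgl.nn as dglnn
import torch
import torch.nn.functional as F

# toy graph: 4 nodes in a cycle, plus self-loops so each node keeps its own feature
g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
g = dgl.add_self_loop(g)
g.ndata['feat'] = torch.randn(4, 2)      # a 2-dim input feature per node

conv1 = dglnn.GraphConv(2, 16)           # layer 1: aggregates 1-hop neighbors
conv2 = dglnn.GraphConv(16, 16)          # layer 2: embeddings now cover 2 hops

h = F.relu(conv1(g, g.ndata['feat']))
h = conv2(g, h)                          # h[i] mixes node i with its multi-hop neighborhood
print(h.shape)                           # torch.Size([4, 16])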
Combining graph neural networks (GNN) with reinforcement learning (RL) yields graph reinforcement learning, which brings several benefits:
- Structured data handling: GNNs process graph-structured data effectively, letting RL make better decisions in complex environments such as social networks or traffic networks.
- Message passing: by passing information between nodes, a GNN brings the states and features of neighboring nodes into the decision process, improving the agent's understanding of the environment.
- Feature learning: the GNN automatically learns node representations, helping the RL algorithm estimate state and action values and improving policy performance.
- Context awareness: GNN-based RL captures dynamic changes in the environment and adapts to different contexts, making decisions more flexible and accurate.
Put this way, many readers may still not see what the advantages concretely are, so below I illustrate with an experiment.
First, build a client-server particle-dodging game. I won't go into the details here; straight to the code.
Server code
# g1_server.py
from flask import Flask, request, jsonify
import threading
import random

app = Flask(__name__)

# Main-ball positions per client, plus the shared particles
clients = {}
particles = []
particles_number = 30
speed_choices = [-6, -3, -1, 1, 3, 6]  # exported; the training script imports this as g1_server.speed_choices


# Initialize the particles
def generate_particles():
    while len(particles) < particles_number:  # create the initial particles
        particles.append({
            'x': random.randint(0, 500),
            'y': random.randint(0, 500),
            'vx': random.choice(speed_choices),
            'vy': random.choice(speed_choices)
        })


def update_particles():
    while True:
        for particle in particles:
            # Update particle position
            particle['x'] += particle['vx']
            particle['y'] += particle['vy']
            # Bounce off the boundaries
            if particle['x'] <= 0 or particle['x'] >= 500:
                particle['vx'] *= -1
            if particle['y'] <= 0 or particle['y'] >= 500:
                particle['vy'] *= -1
        threading.Event().wait(0.1)


@app.route('/register', methods=['POST'])
def register_client():
    client_id = request.json.get('id')
    clients[client_id] = {'position': {'x': 250, 'y': 250}}  # initial main-ball position
    return jsonify(success=True)


width, height = 500, 500
ball_radius = 15


@app.route('/move/<client_id>', methods=['POST'])
def move(client_id):
    direction = request.json.get('direction')
    if client_id in clients:
        # Current ball position
        position = clients[client_id]['position']
        if direction == 'up':
            new_y = position['y'] - 10
            if new_y >= 0:  # stay inside the top boundary
                position['y'] = new_y
        elif direction == 'down':
            new_y = position['y'] + 10
            if new_y <= height - ball_radius:  # stay inside the bottom boundary (minus the radius)
                position['y'] = new_y
        elif direction == 'left':
            new_x = position['x'] - 10
            if new_x >= 0:  # stay inside the left boundary
                position['x'] = new_x
        elif direction == 'right':
            new_x = position['x'] + 10
            if new_x <= width - ball_radius:  # stay inside the right boundary (minus the radius)
                position['x'] = new_x
        return jsonify(clients[client_id]['position'])
    return jsonify({'error': 'Client not found'}), 404


@app.route('/position/<client_id>', methods=['GET'])
def get_position(client_id):
    if client_id in clients:
        return jsonify(clients[client_id]['position'])
    else:
        return jsonify({'error': 'Client not found'}), 404


@app.route('/particles', methods=['GET'])
def get_particles():
    return jsonify(particles)


def run_server():
    app.run(host='0.0.0.0', port=5000, threaded=True)


if __name__ == '__main__':
    threading.Thread(target=generate_particles, daemon=True).start()
    threading.Thread(target=update_particles, daemon=True).start()
    run_server()
Client code
import pygame
import requests

# Initialize pygame
pygame.init()

# Window size
width, height = 500, 500
window = pygame.display.set_mode((width, height))
pygame.display.set_caption("Particle Avoidance Game")

# Colors
BLACK = (0, 0, 0)
BLUE = (0, 0, 255)
RED = (255, 0, 0)

# Main ball settings
ball_radius = 15
client_id = 'client1'  # make sure every client uses a different ID

# Register the client
register_response = requests.post('http://127.0.0.1:5000/register', json={'id': client_id})
if register_response.status_code != 200:
    print("Failed to register client:", register_response.text)


def get_ball_position():
    response = requests.get(f'http://127.0.0.1:5000/position/{client_id}')
    if response.status_code == 200:
        ball_pos = response.json()
        print("Ball position response:", ball_pos)  # debug print
        if 'x' in ball_pos and 'y' in ball_pos:  # make sure the response contains x and y
            return ball_pos
        else:
            print("Ball position does not contain 'x' and 'y':", ball_pos)  # extra debug info
    else:
        print("Failed to get ball position:", response.text)
    return {'x': 250, 'y': 250}  # fallback default


def move_ball(direction):
    requests.post(f'http://127.0.0.1:5000/move/{client_id}', json={'direction': direction})


def get_particles():
    response = requests.get('http://127.0.0.1:5000/particles')
    if response.status_code == 200:
        particles = response.json()
        print("Particles response:", particles)  # debug print
        return particles
    else:
        print("Failed to get particles:", response.text)
        return []  # empty list on failure


def check_collision(ball_pos, particle_pos):
    distance = ((ball_pos['x'] - particle_pos['x']) ** 2 + (ball_pos['y'] - particle_pos['y']) ** 2) ** 0.5
    return distance < (ball_radius + 5)  # particle radius is 5


running = True
while running:
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            running = False

    keys = pygame.key.get_pressed()
    if keys[pygame.K_UP]:
        move_ball('up')
    if keys[pygame.K_DOWN]:
        move_ball('down')
    if keys[pygame.K_LEFT]:
        move_ball('left')
    if keys[pygame.K_RIGHT]:
        move_ball('right')

    # Refresh the ball position and the particles
    ball_position = get_ball_position()
    particles = get_particles()

    # Collision check
    if 'x' in ball_position and 'y' in ball_position:  # make sure the ball position is valid
        for particle in particles:
            if 'x' in particle and 'y' in particle:  # make sure the particle is valid
                if check_collision(ball_position, particle):
                    print("Game Over! You collided with a particle.")
                    running = False
    else:
        print("Invalid ball position:", ball_position)  # debug info

    # Render
    window.fill(BLACK)  # black background
    if 'x' in ball_position and 'y' in ball_position:
        pygame.draw.circle(window, BLUE, (ball_position['x'], ball_position['y']), ball_radius)
    else:
        print("Ball position is invalid, not drawing.")  # debug info
    for particle in particles:
        if 'x' in particle and 'y' in particle:
            pygame.draw.circle(window, RED, (particle['x'], particle['y']), 5)  # draw the particle

    pygame.display.flip()
    pygame.time.delay(100)

pygame.quit()
After launching the programs and debugging a bit, this basically works as a reinforcement learning environment.
Next, add the RL training code directly on the client side and deploy the server code to Kubernetes as the training environment.
To design a reinforcement learning model based on a graph neural network (GNN) and the advantage actor-critic (A2C) algorithm that trains main_ball to move around the environment, we can follow these steps:
1. Define the problem
- Goal: control main_ball so that it avoids colliding with the particles.
- State space: the state of main_ball and of the particles, including positions and velocities.
- Action space: the actions available to main_ball (up, down, left, right, or stay put); see the sketch after this list.
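As a rough sketch of what these spaces look like in this post (the authoritative definitions appear in the training code later): the action space is a small discrete list, and a state is simply the main ball plus the current particle list.

# Sketch of the state / action representation used below (values are made up)
action_space = ['up', 'down', 'left', 'right', None]   # None = stay put

state = {
    'main_ball': {'x': 250, 'y': 250},                           # main ball position
    'particles': [{'x': 40, 'y': 80, 'vx': 3, 'vy': -1},         # hypothetical particles
                  {'x': 310, 'y': 120, 'vx': -6, 'vy': 3}],
}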
2. Data structures and environment design
First, build an environment that simulates the dynamics of main_ball and the particles.
class Environment:
    def __init__(self, main_ball, particles):
        self.main_ball = main_ball
        self.particles = particles

    def reset(self):
        # Reset the environment and return the initial state
        return self.get_state()

    def get_state(self):
        # Current state: the main ball plus the particle list
        state = {
            'main_ball': self.main_ball,
            'particles': self.particles
        }
        return state

    def step(self, action):
        # Apply the action (up / down / left / right / stay)
        if action_space[action]:  # only send a move for real directions, not for "stay" (None)
            move_ball(action_space[action])

        # Reward and termination flag
        reward = 0
        done = False
        new_main_ball = get_ball_position()
        particles = get_particles()

        COLLISION_THRESHOLD = 20
        for particle in self.particles:
            distance = ((self.main_ball['x'] - particle['x']) ** 2 +
                        (self.main_ball['y'] - particle['y']) ** 2) ** 0.5
            if distance < COLLISION_THRESHOLD:  # near-miss threshold
                reward = -1  # negative reward for getting too close
                break

        # Check whether the episode is over
        if 'x' in new_main_ball and 'y' in new_main_ball:  # valid ball position
            for particle in particles:
                if 'x' in particle and 'y' in particle:  # valid particle
                    if check_collision(new_main_ball, particle):
                        print("Game Over! You collided with a particle.")
                        done = True  # end the episode
                        break
        if not done:
            reward = 5  # note: this overwrites the -1 near-miss penalty above

        # Cache the new observation so get_state() returns the next state
        self.main_ball = new_main_ball
        self.particles = particles
        return self.get_state(), reward, done, {}
3. Building the GNN
Use DGL to build a graph structure that represents the relationship between main_ball and the particles.
import dgl
import torch


def create_graph(main_ball, particles):
    num_particles = len(particles)
    G = dgl.graph(([], []), num_nodes=num_particles + 1)

    # Node 0 is the main ball
    G.ndata['pos'] = torch.zeros(num_particles + 1, 2)
    G.ndata['pos'][0] = torch.tensor([main_ball['x'], main_ball['y']])

    # Nodes 1..N are the particles
    for i, particle in enumerate(particles):
        G.ndata['pos'][i + 1] = torch.tensor([particle['x'], particle['y']])

    # Add edges and distance weights
    edges = []
    distances = []
    for i in range(num_particles):
        for j in range(i + 1, num_particles):
            edges.append((i + 1, j + 1))  # particle-particle edge
            # compute and store the distance
            distance = torch.norm(G.ndata['pos'][i + 1] - G.ndata['pos'][j + 1])
            distances.append(distance.item())
        edges.append((0, i + 1))  # main-ball-particle edge
        # compute and store the distance
        distance = torch.norm(G.ndata['pos'][0] - G.ndata['pos'][i + 1])
        distances.append(distance.item())
    G.add_edges(*zip(*edges))

    # Store the distances as an edge feature
    G.edata['distance'] = torch.tensor(distances)

    # Add self-loops
    G = dgl.add_self_loop(G)
    return G
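A quick smoke test of create_graph with two hand-made particles (the values are made up) illustrates the layout: node 0 is the main ball, nodes 1..N are the particles, and the pairwise distances live on the edges.

# Hypothetical smoke test: 2 particles -> 3 nodes, 3 edges plus 3 self-loops
main_ball = {'x': 250, 'y': 250}
particles = [{'x': 40, 'y': 80}, {'x': 310, 'y': 120}]

g = create_graph(main_ball, particles)
print(g.num_nodes())       # expected: 3
print(g.num_edges())       # expected: 6 (ball-particle x2 + particle-particle + 3 self-loops)
print(g.ndata['pos'])      # row 0 is the main ball, rows 1..N the particles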
4. A2C model design
Use PyTorch to design the A2C model, consisting of an Actor and a Critic.
In reinforcement learning, the "Actor" and the "Critic" are two key roles or components.
Actor:
The Actor decides which action to take in a given state. It selects an action based on the current state and sends it to the environment. Its goal is to learn a policy, i.e. a mapping from states to actions, that earns high returns while interacting with the environment. The policy can be deterministic (directly pick the best action) or stochastic (a probability distribution over actions). The Actor is trained to maximize expected return, typically with gradient ascent (policy gradient methods).
Critic:
The Critic evaluates the Actor's action choices. By assessing the current state and the action taken, it estimates a value function (or action-value function) for the long-term return expected under the current policy and provides an immediate feedback signal. The Critic is trained to minimize the prediction error of its value function, typically with temporal-difference learning (e.g. Q-learning or TD learning) combined with function approximation (e.g. a neural network).
In some algorithms the Actor and Critic are separate components trained independently: the Actor uses the Critic's value estimates to guide action selection, while the Critic evaluates the actions the Actor takes. Through this interaction and mutual feedback they improve the policy and the value function together.
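To pin down how the two heads are trained here, this is the one-step A2C update the training loop below uses, written out on generic tensors. The helper name and signature are mine; the formulas match the loop.

# One-step advantage actor-critic losses (sketch; names are assumptions)
import torch

def a2c_losses(reward, value, next_value, action_prob_chosen, gamma=0.95):
    # value / next_value: critic outputs V(s) and V(s')
    # action_prob_chosen: pi(a|s) for the action that was actually taken
    td_target = reward + gamma * next_value.detach()              # bootstrap target, no grad through V(s')
    delta = td_target - value                                     # TD error, used as the advantage estimate
    actor_loss = -torch.log(action_prob_chosen) * delta.detach()  # policy-gradient term
    critic_loss = delta.pow(2)                                    # squared TD error
    return actor_loss, critic_loss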
class ActorCritic(nn.Module):
    def __init__(self, n_devices, action_space_dim, particles_dim=None):
        super(ActorCritic, self).__init__()
        if particles_dim is None:
            particles_dim = len(particles)  # fall back to the global particle list
        self.conv1 = dgl.nn.GraphConv(2, 128)  # input feature is 2-D: the position
        self.conv2 = dgl.nn.GraphConv(128, 1 + action_space_dim)  # hidden layer
        # Flatten all node embeddings into one vector for the shared layer
        self.commonCov = nn.Linear((particles_dim + 1) * (1 + action_space_dim), 128)
        self.actor = nn.Linear(128, action_space_dim)  # action head
        self.critic = nn.Linear(128, 1)  # value head

    def forward(self, g):
        g = dgl.add_self_loop(g)
        x = g.ndata['pos']
        x = self.conv1(g, x)
        x = F.relu(x)
        x = self.conv2(g, x)
        x = F.relu(x)
        x = x.reshape(-1)  # flatten node embeddings into one vector
        x = torch.relu(self.commonCov(x))
        actor = self.actor(x)
        critic = self.critic(x)
        return actor, critic
5. Training loop
Design the training loop that updates the model parameters with the reinforcement learning algorithm.
# Initialize the environment and the model
main_ball = get_ball_position()
particles = get_particles()
env = Environment(main_ball, particles)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_id = 1
if torch.cuda.is_available():
    device_id = torch.cuda.current_device()

model = ActorCritic(n_devices=device_id, action_space_dim=len(action_space)).to(device)

# Load a previously saved model if it exists
model_path = 'actor_critic_model.pth'
if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path))
    print("Model loaded successfully.")
    model.train()  # switch to training mode

optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop
for episode in range(num_episodes):
    state = env.reset()
    done = False
    while not done:
        # Build the graph for the current state
        g = create_graph(state['main_ball'], state['particles'])

        # Forward pass
        actor_logits, critic_value = model(g)

        # Sample an action from the softmax policy
        action_prob = F.softmax(actor_logits, dim=-1)
        action = torch.multinomial(action_prob, num_samples=1).item()

        # Execute the action, observe the next state and the reward
        next_state, reward, done, _ = env.step(action)

        # One-step A2C update (the trajectory could also be stored for a batched update)
        optimizer.zero_grad()
        next_g = create_graph(next_state['main_ball'], next_state['particles'])
        _, next_value = model(next_g)
        td_target = reward + 0.95 * next_value.detach()
        delta = td_target - critic_value
        '''
        Actor loss: maximize the log-probability of the chosen action weighted by the
        TD error. delta.detach() keeps the actor update from influencing the critic's
        prediction.
        '''
        actor_loss = -torch.log(action_prob[action]) * delta.detach()
        critic_loss = delta.pow(2)
        loss = actor_loss + critic_loss
        print("episode = %d, device_id = %d, loss = %f" % (episode, device_id, loss.item()))
        loss.backward()
        optimizer.step()

        if done:
            restart_game()
        state = next_state

    # Save the model state after each episode
    torch.save(model.state_dict(), model_path)
    print(f"Model saved after episode {episode}.")
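The loop above updates after every single step; the in-code note about storing the trajectory points at a batched alternative. A rough, untested sketch of an episodic variant (variable names are mine, not part of the code above) could look like this:

# Hypothetical episodic variant: collect the whole trajectory, then compute
# discounted returns and do one update per episode instead of per step.
gamma = 0.95
trajectory = []   # list of (log_prob, value, reward), appended once per step:
# trajectory.append((torch.log(action_prob[action]), critic_value, reward))

def episode_update(trajectory, optimizer):
    returns, R = [], 0.0
    for _, _, r in reversed(trajectory):       # discounted return G_t, computed backwards
        R = r + gamma * R
        returns.insert(0, R)
    loss = 0
    for (log_prob, value, _), G in zip(trajectory, returns):
        advantage = G - value.detach()          # advantage against the critic's estimate
        loss = loss - log_prob * advantage + (G - value).pow(2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()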
With the steps above you can build a GNN + A2C reinforcement learning model that trains main_ball's movement policy.
Start training with num_episodes = 1000.
When RL is hooked up to a simulator, GPU acceleration makes little difference, because throughput is bounded by how fast the simulator can run the state-action-next_state loop. I ran the training on my local machine; when it finishes, the model is saved to actor_critic_model.pth.
6. Model inference
Load actor_critic_model.pth into the earlier client code and replace the keyboard controls with the model's decisions.
import os
import pygame
import requests
import random
import string
import torch
import torch.nn.functional as F

from g1_client_train import ActorCritic, create_graph, action_space

# Initialize pygame
pygame.init()

# Window size
width, height = 500, 500
window = pygame.display.set_mode((width, height))
pygame.display.set_caption("Particle Avoidance Game")

# Colors
BLACK = (0, 0, 0)
BLUE = (0, 0, 255)
RED = (255, 0, 0)

# Main ball radius
ball_radius = 15

if torch.cuda.is_available():
    device = torch.device("cuda")  # use the CUDA device
else:
    device = torch.device("cpu")   # fall back to the CPU


def generate_random_string(length=5):
    # Characters: digits plus upper- and lower-case letters
    characters = string.ascii_letters + string.digits
    # Randomly pick characters to build the string
    random_string = ''.join(random.choice(characters) for _ in range(length))
    return random_string


# Generate and print a random client ID
random_string = generate_random_string()
print(random_string)
client_id = random_string

host = os.getenv("SERVER_HOST", 'http://192.168.110.126:31007')


def register_client():
    register_response = requests.post(f'{host}/register', json={'id': client_id})
    if register_response.status_code != 200:
        print("Failed to register client:", register_response.text)
        return False
    return True


def get_ball_position():
    response = requests.get(f'{host}/position/{client_id}')
    if response.status_code == 200:
        ball_pos = response.json()
        print("Ball position response:", ball_pos)  # debug print
        if 'x' in ball_pos and 'y' in ball_pos:  # make sure it contains x and y
            return ball_pos
    print("Failed to get ball position:", response.text)
    return None  # None signals failure


def move_ball(direction):
    requests.post(f'{host}/move/{client_id}', json={'direction': direction})


def get_particles():
    response = requests.get(f'{host}/particles')
    if response.status_code == 200:
        particles = response.json()
        print("Particles response:", particles)  # debug print
        return particles
    print("Failed to get particles:", response.text)
    return []  # empty list on failure


def check_collision(ball_pos, particle_pos):
    distance = ((ball_pos['x'] - particle_pos['x']) ** 2 + (ball_pos['y'] - particle_pos['y']) ** 2) ** 0.5
    return distance < (ball_radius + 5)  # particle radius is 5


def restart_game():
    global running
    print("Restarting game...")
    running = True
    register_client()  # re-register the client


if __name__ == "__main__":
    # Register the client
    if not register_client():
        exit()

    device_id = 1
    if torch.cuda.is_available():
        device_id = torch.cuda.current_device()

    # Initial ball position and particles
    main_ball = get_ball_position()
    particles = get_particles()

    model = ActorCritic(n_devices=device_id,
                        particles_dim=len(particles),
                        action_space_dim=len(action_space)).to(device)
    model_path = 'actor_critic_model.pth'
    model.load_state_dict(torch.load(model_path))
    model.eval()  # inference only

    running = True
    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False

        main_ball = get_ball_position()
        particles = get_particles()

        # Let the model pick the action instead of reading the keyboard
        g = create_graph(main_ball=main_ball, particles=particles)
        actor_logits, _ = model(g)
        action_prob = F.softmax(actor_logits, dim=-1)
        action = torch.multinomial(action_prob, num_samples=1).item()
        move_ball(action_space[action])

        # Collision check
        if main_ball:  # make sure the ball position is valid
            for particle in particles:
                if 'x' in particle and 'y' in particle:  # make sure the particle is valid
                    if check_collision(main_ball, particle):
                        print("Game Over! You collided with a particle.")
                        restart_game()
                        break
        else:
            print("Invalid ball position, restarting game.")
            restart_game()

        # Render
        window.fill(BLACK)  # black background
        if main_ball:
            pygame.draw.circle(window, BLUE, (main_ball['x'], main_ball['y']), ball_radius)
        for particle in particles:
            if 'x' in particle and 'y' in particle:
                pygame.draw.circle(window, RED, (particle['x'], particle['y']), 5)  # draw the particle

        pygame.display.flip()
        pygame.time.delay(100)

    pygame.quit()
Record a GIF to see the result.
The reward function still needs work: after each evasive move, next_state should score better than the previous state, as sketched below.
Well-trained models tend to have quite fine-grained reward and penalty rules.
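One way to encode "next_state should be better than the previous state" is a shaping term on the distance to the nearest particle. A sketch, with untuned placeholder constants and helper names of my own:

# Hypothetical reward shaping based on the distance to the nearest particle
def nearest_particle_distance(ball, particles):
    # Euclidean distance from the main ball to its closest particle
    if not particles:
        return float('inf')
    return min(((ball['x'] - p['x']) ** 2 + (ball['y'] - p['y']) ** 2) ** 0.5
               for p in particles)

def shaped_reward(prev_ball, new_ball, particles, collided, step_reward=1.0, scale=0.05):
    if collided:
        return -10.0  # large penalty on collision
    prev_d = nearest_particle_distance(prev_ball, particles)
    new_d = nearest_particle_distance(new_ball, particles)
    # reward moving away from the nearest particle, penalize moving toward it
    return step_reward + scale * (new_d - prev_d)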
The complete training code file is attached below.
import os
import random
import string

import requests
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import dgl
import dgl.nn as dglnn

from g1_server import speed_choices

ball_radius = 15
width, height = 500, 500
num_episodes = 1000

min_vx = min(speed_choices)
max_vx = max(speed_choices)
min_vy = min(speed_choices)
max_vy = max(speed_choices)


def register_client():
    register_response = requests.post(f'{host}/register', json={'id': client_id})
    if register_response.status_code != 200:
        print("Failed to register client:", register_response.text)
        return False
    return True


def restart_game():
    global running
    print("Restarting game...")
    running = True
    register_client()  # re-register the client


def generate_random_string(length=5):
    # Characters: digits plus upper- and lower-case letters
    characters = string.ascii_letters + string.digits
    # Randomly pick characters to build the string
    random_string = ''.join(random.choice(characters) for _ in range(length))
    return random_string


# Generate and print a random client ID
random_string = generate_random_string()
print(random_string)
client_id = random_string

host = os.getenv("SERVER_HOST", 'http://192.168.110.126:31007')

# Register the client
register_response = requests.post(f'{host}/register', json={'id': client_id})
if register_response.status_code != 200:
    print("Failed to register client:", register_response.text)

if torch.cuda.is_available():
    device = torch.device("cuda")  # use the CUDA device
else:
    device = torch.device("cpu")   # fall back to the CPU


def get_ball_position():
    response = requests.get(f'{host}/position/{client_id}')
    if response.status_code == 200:
        ball_pos = response.json()
        print("Ball position response:", ball_pos)  # debug print
        if 'x' in ball_pos and 'y' in ball_pos:  # make sure it contains x and y
            return ball_pos
        else:
            print("Ball position does not contain 'x' and 'y':", ball_pos)  # extra debug info
    else:
        print("Failed to get ball position:", response.text)
    return {'x': 250, 'y': 250}  # fallback default


def move_ball(direction):
    requests.post(f'{host}/move/{client_id}', json={'direction': direction})


def get_particles():
    response = requests.get(f'{host}/particles')
    if response.status_code == 200:
        particles = response.json()
        print("Particles response:", particles)  # debug print
        return particles
    else:
        print("Failed to get particles:", response.text)
        return []  # empty list on failure


def check_collision(ball_pos, particle_pos):
    distance = ((ball_pos['x'] - particle_pos['x']) ** 2 + (ball_pos['y'] - particle_pos['y']) ** 2) ** 0.5
    return distance < (ball_radius + 5)  # particle radius is 5


def create_graph(main_ball, particles):
    num_particles = len(particles)
    G = dgl.graph(([], []), num_nodes=num_particles + 1)

    # Node 0 is the main ball
    G.ndata['pos'] = torch.zeros(num_particles + 1, 2)
    G.ndata['pos'][0] = torch.tensor([main_ball['x'], main_ball['y']])

    # Nodes 1..N are the particles
    for i, particle in enumerate(particles):
        G.ndata['pos'][i + 1] = torch.tensor([particle['x'], particle['y']])

    # Add edges and distance weights
    edges = []
    distances = []
    for i in range(num_particles):
        for j in range(i + 1, num_particles):
            edges.append((i + 1, j + 1))  # particle-particle edge
            # compute and store the distance
            distance = torch.norm(G.ndata['pos'][i + 1] - G.ndata['pos'][j + 1])
            distances.append(distance.item())
        edges.append((0, i + 1))  # main-ball-particle edge
        # compute and store the distance
        distance = torch.norm(G.ndata['pos'][0] - G.ndata['pos'][i + 1])
        distances.append(distance.item())
    G.add_edges(*zip(*edges))

    # Store the distances as an edge feature
    G.edata['distance'] = torch.tensor(distances)

    # Add self-loops
    G = dgl.add_self_loop(G)
    return G


def compute_state(main_ball, particles):
    # Normalized flat state vector (kept for reference; the GNN path uses create_graph instead)
    flow = []
    normalized_main_ball_x = main_ball['x'] / width
    normalized_main_ball_y = main_ball['y'] / height
    flow.append([0, 0, normalized_main_ball_x, normalized_main_ball_y])
    for particle in particles:
        # Normalize the particle's velocity and position to your data ranges
        normalized_vx = (particle['vx'] - min_vx) / (max_vx - min_vx)
        normalized_vy = (particle['vy'] - min_vy) / (max_vy - min_vy)
        normalized_x = particle['x'] / width
        normalized_y = particle['y'] / height
        flow.append([normalized_vx, normalized_vy, normalized_x, normalized_y])
    return torch.tensor(flow).float()


'''
In reinforcement learning, the "Actor" and the "Critic" are two key components.

Actor:
The Actor decides which action to take in a given state and sends it to the environment.
Its goal is to learn a policy, i.e. a mapping from states to actions, that earns high
returns; the policy can be deterministic or stochastic. It is trained to maximize
expected return, typically via policy gradient methods.

Critic:
The Critic evaluates the Actor's choices. It estimates a value function (or
action-value function) for the long-term return under the current policy and provides
an immediate feedback signal. It is trained to minimize the value-prediction error,
typically via temporal-difference learning with function approximation.

In some algorithms the Actor and Critic are separate components trained independently:
the Actor uses the Critic's value estimates to guide action selection, while the Critic
evaluates the actions the Actor takes; through this interaction both improve.
'''
class Environment:
    def __init__(self, main_ball, particles):
        self.main_ball = main_ball
        self.particles = particles

    def reset(self):
        # Reset the environment and return the initial state
        return self.get_state()

    def get_state(self):
        # Current state: the main ball plus the particle list
        state = {
            'main_ball': self.main_ball,
            'particles': self.particles
        }
        return state

    def step(self, action):
        # Apply the action (up / down / left / right / stay)
        if action_space[action]:  # only send a move for real directions, not for "stay" (None)
            move_ball(action_space[action])

        # Reward and termination flag
        reward = 0
        done = False
        new_main_ball = get_ball_position()
        particles = get_particles()

        COLLISION_THRESHOLD = 20
        for particle in self.particles:
            distance = ((self.main_ball['x'] - particle['x']) ** 2 +
                        (self.main_ball['y'] - particle['y']) ** 2) ** 0.5
            if distance < COLLISION_THRESHOLD:  # near-miss threshold
                reward = -1  # negative reward for getting too close
                break

        # Check whether the episode is over
        if 'x' in new_main_ball and 'y' in new_main_ball:  # valid ball position
            for particle in particles:
                if 'x' in particle and 'y' in particle:  # valid particle
                    if check_collision(new_main_ball, particle):
                        print("Game Over! You collided with a particle.")
                        done = True  # end the episode
                        break
        if not done:
            reward = 5  # note: this overwrites the -1 near-miss penalty above

        # Cache the new observation so get_state() returns the next state
        self.main_ball = new_main_ball
        self.particles = particles
        return self.get_state(), reward, done, {}


class ActorCritic(nn.Module):
    def __init__(self, n_devices, action_space_dim, particles_dim=None):
        super(ActorCritic, self).__init__()
        if particles_dim is None:
            particles_dim = len(particles)  # fall back to the global particle list
        self.conv1 = dgl.nn.GraphConv(2, 128)  # input feature is 2-D: the position
        self.conv2 = dgl.nn.GraphConv(128, 1 + action_space_dim)  # hidden layer
        # Flatten all node embeddings into one vector for the shared layer
        self.commonCov = nn.Linear((particles_dim + 1) * (1 + action_space_dim), 128)
        self.actor = nn.Linear(128, action_space_dim)  # action head
        self.critic = nn.Linear(128, 1)  # value head

    def forward(self, g):
        g = dgl.add_self_loop(g)
        x = g.ndata['pos']
        x = self.conv1(g, x)
        x = F.relu(x)
        x = self.conv2(g, x)
        x = F.relu(x)
        x = x.reshape(-1)  # flatten node embeddings into one vector
        x = torch.relu(self.commonCov(x))
        actor = self.actor(x)
        critic = self.critic(x)
        return actor, critic


action_space = ['up', 'down', 'left', 'right', None]

if not register_client():
    print('not register')
    exit()

# Initialize the environment and the model
main_ball = get_ball_position()
particles = get_particles()
env = Environment(main_ball, particles)

device_id = 1
if torch.cuda.is_available():
    device_id = torch.cuda.current_device()

model = ActorCritic(n_devices=device_id, action_space_dim=len(action_space)).to(device)

# Load a previously saved model if it exists
model_path = 'actor_critic_model.pth'
if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path))
    print("Model loaded successfully.")
    model.train()  # switch to training mode

optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop
for episode in range(num_episodes):
    state = env.reset()
    done = False
    while not done:
        # Build the graph for the current state
        g = create_graph(state['main_ball'], state['particles'])

        # Forward pass
        actor_logits, critic_value = model(g)

        # Sample an action from the softmax policy
        action_prob = F.softmax(actor_logits, dim=-1)
        action = torch.multinomial(action_prob, num_samples=1).item()

        # Execute the action, observe the next state and the reward
        next_state, reward, done, _ = env.step(action)

        # One-step A2C update (the trajectory could also be stored for a batched update)
        optimizer.zero_grad()
        next_g = create_graph(next_state['main_ball'], next_state['particles'])
        _, next_value = model(next_g)
        td_target = reward + 0.95 * next_value.detach()
        delta = td_target - critic_value
        '''
        Actor loss: maximize the log-probability of the chosen action weighted by the
        TD error. delta.detach() keeps the actor update from influencing the critic's
        prediction.
        '''
        actor_loss = -torch.log(action_prob[action]) * delta.detach()
        critic_loss = delta.pow(2)
        loss = actor_loss + critic_loss
        print("episode = %d, device_id = %d, loss = %f" % (episode, device_id, loss.item()))
        loss.backward()
        optimizer.step()

        if done:
            restart_game()
        state = next_state

    # Save the model state after each episode
    torch.save(model.state_dict(), model_path)
    print(f"Model saved after episode {episode}.")