Experiment Environment
python=3.10
torch=2.1.1
gym=0.26.2
gym[classic_control]
matplotlib=3.8.0
numpy=1.26.2
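To confirm that the installed versions match, a quick check along these lines can be run (just a sketch; the printed values should match the list above):

import sys
import torch, gym, matplotlib, numpy

print(sys.version)             # expect 3.10.x
print(torch.__version__)       # expect 2.1.1
print(gym.__version__)         # expect 0.26.2
print(matplotlib.__version__)  # expect 3.8.0
print(numpy.__version__)       # expect 1.26.2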
DQN Code
First is the module.py code, which defines the network model and the DQN agent:
import torch
import torch.nn as nn
import numpy as np


class Net(nn.Module):
    # A network with a single hidden layer
    def __init__(self, n_states, n_hidden, n_actions):
        super(Net, self).__init__()
        # [b, n_states] --> [b, n_hidden]
        self.network = nn.Sequential(
            nn.Linear(n_states, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_actions)
        )

    # Forward pass
    def forward(self, x):  # [b, n_states]
        return self.network(x)


class DQN:
    def __init__(self, n_states, n_hidden, n_actions, lr, gamma, epsilon):
        # Store the attributes
        self.n_states = n_states    # number of state features
        self.n_hidden = n_hidden    # number of hidden units
        self.n_actions = n_actions  # number of actions
        self.lr = lr                # learning rate
        self.gamma = gamma          # discount factor applied to the next-state return
        self.epsilon = epsilon      # greedy coefficient: explore with probability 1 - epsilon
        # Counter for the number of updates
        self.count = 0
        # Instantiate the Q-network
        self.q_net = Net(self.n_states, self.n_hidden, self.n_actions)
        # Optimizer that updates the Q-network's parameters
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)
        self.criterion = torch.nn.MSELoss()  # loss function

    def choose_action(self, gym_state):
        state = torch.Tensor(gym_state)
        if np.random.random() < self.epsilon:
            action_values = self.q_net(state)  # predicted Q-value of each action in the current state
            action = action_values.argmax().item()
        else:
            # Pick a random action
            action = np.random.randint(self.n_actions)
        return action

    def update(self, gym_state, action, reward, next_gym_state, done):
        state, next_state = torch.tensor(gym_state), torch.tensor(next_gym_state)
        q_value = self.q_net(state)[action]
        # done must not be omitted: if the episode ends at the next step,
        # the next-state Q-value should be 0 rather than the network's output.
        # detach() stops gradients from flowing through the bootstrap target.
        q_target = reward + self.gamma * self.q_net(next_state).max().detach() * (1 - float(done))
        self.optimizer.zero_grad()
        dqn_loss = self.criterion(q_value, q_target)
        dqn_loss.backward()
        self.optimizer.step()
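For a quick sanity check that the class wires together before running the full training script, a single transition can be pushed through it like this (a sketch; the hyperparameters are the ones used later in train.py):

import gym
from module import DQN

env = gym.make("CartPole-v1")
agent = DQN(n_states=4, n_hidden=200, n_actions=2, lr=1e-3, gamma=0.95, epsilon=0.8)

state, _ = env.reset()
action = agent.choose_action(state)                           # epsilon-greedy action
next_state, reward, terminated, truncated, _ = env.step(action)
agent.update(state, action, reward, next_state, terminated)   # one gradient step on one transition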
Next is the train.py code, which uses the DQN agent and the gym environment to run training:
import gym
import torch
from module import DQN
import matplotlib.pyplot as plt

lr = 1e-3        # learning rate
gamma = 0.95     # discount factor
epsilon = 0.8    # greedy coefficient
n_hidden = 200   # number of hidden-layer neurons

env = gym.make("CartPole-v1")
n_states = env.observation_space.shape[0]  # 4
n_actions = env.action_space.n             # 2, number of actions

dqn = DQN(n_states, n_hidden, n_actions, lr, gamma, epsilon)

if __name__ == '__main__':
    reward_list = []
    for i in range(500):
        state = env.reset()[0]  # len = 4
        total_reward = 0
        while True:
            # Choose the action to take in the current state
            action = dqn.choose_action(state)
            # Step the environment; gym 0.26 returns (obs, reward, terminated, truncated, info)
            next_state, reward, terminated, truncated, _ = env.step(action)
            # Only true termination should zero out the bootstrap target,
            # not the 500-step time-limit truncation
            dqn.update(state, action, reward, next_state, terminated)
            state = next_state
            total_reward += reward
            if terminated or truncated:
                break
        print("Episode %d, total_reward = %f" % (i, total_reward))
        reward_list.append(total_reward)

    # Plot the return of each episode
    episodes_list = list(range(len(reward_list)))
    plt.plot(episodes_list, reward_list)
    plt.xlabel('Episodes')
    plt.ylabel('Returns')
    plt.title('DQN Returns')
    plt.show()
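As an optional follow-up (not part of the original script), the trained policy can be watched by appending a greedy rollout like this to the end of train.py, where torch, gym and dqn are already in scope; render_mode="human" requires a display:

eval_env = gym.make("CartPole-v1", render_mode="human")
state, _ = eval_env.reset()
total = 0
while True:
    # Act fully greedily during evaluation instead of epsilon-greedily
    action = dqn.q_net(torch.Tensor(state)).argmax().item()
    state, reward, terminated, truncated, _ = eval_env.step(action)
    total += reward
    if terminated or truncated:
        break
print("evaluation return:", total)
eval_env.close()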
SARSA Code
First is the module.py code, which defines the network model and the SARSA agent. SARSA is almost identical to DQN; the only difference is in how the Q-network is updated, and the relevant spot is annotated in the code.
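Concretely, the two bootstrap targets are

$$y_{\mathrm{DQN}} = r + \gamma \max_{a'} Q(s', a'), \qquad y_{\mathrm{SARSA}} = r + \gamma \, Q(s', a')$$

where SARSA samples the next action $a'$ from the same ε-greedy policy that acts in the environment (on-policy), while DQN takes the greedy maximum (off-policy).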
import torch
import torch.nn as nn
import numpy as np
class Net(nn.Module):
    # A network with a single hidden layer
    def __init__(self, n_states, n_hidden, n_actions):
        super(Net, self).__init__()
        # [b, n_states] --> [b, n_hidden]
        self.network = nn.Sequential(
            nn.Linear(n_states, n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_actions)
        )

    # Forward pass
    def forward(self, x):  # [b, n_states]
        return self.network(x)


class SARSA:
    def __init__(self, n_states, n_hidden, n_actions, lr, gamma, epsilon):
        # Store the attributes
        self.n_states = n_states    # number of state features
        self.n_hidden = n_hidden    # number of hidden units
        self.n_actions = n_actions  # number of actions
        self.lr = lr                # learning rate
        self.gamma = gamma          # discount factor applied to the next-state return
        self.epsilon = epsilon      # greedy coefficient: explore with probability 1 - epsilon
        # Counter for the number of updates
        self.count = 0
        # Instantiate the Q-network
        self.q_net = Net(self.n_states, self.n_hidden, self.n_actions)
        # Optimizer that updates the Q-network's parameters
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)
        self.criterion = torch.nn.MSELoss()  # loss function

    def choose_action(self, gym_state):
        state = torch.Tensor(gym_state)
        # Based on the greedy coefficient, act randomly with a certain probability
        if np.random.random() < self.epsilon:
            action_values = self.q_net(state)  # predicted Q-value of each action in the current state
            action = action_values.argmax().item()
        else:
            # Pick a random action
            action = np.random.randint(self.n_actions)
        return action

    def update(self, gym_state, action, reward, next_gym_state, done):
        state, next_state = torch.tensor(gym_state), torch.tensor(next_gym_state)
        q_value = self.q_net(state)[action]
        # SARSA bootstraps with q_net(next_state)[next_action], where next_action is chosen
        # by the current epsilon-greedy policy. This is the only difference from DQN,
        # which uses max(q_net(next_state)).
        next_action = self.choose_action(next_state)
        # done must not be omitted: if the episode ends at the next step,
        # the next-state Q-value should be 0 rather than the network's output.
        # detach() stops gradients from flowing through the bootstrap target.
        q_target = reward + self.gamma * self.q_net(next_state)[next_action].detach() * (1 - float(done))
        self.optimizer.zero_grad()
        sarsa_loss = self.criterion(q_value, q_target)
        sarsa_loss.backward()
        self.optimizer.step()
SARSA also has a train.py file. It works the same way as the DQN one above and its content is almost identical, only with DQN renamed to SARSA, so it is not repeated here.
Results
The DQN training results are as follows:
The SARSA training results are as follows: