一文实践强化学习训练游戏ai--doom枪战游戏实践

一文实践强化学习训练游戏ai–doom枪战游戏实践
上次文章写道下载doom的环境并尝试了简单的操作,这次让我们来进行对象化和训练、验证,如果你有基础,可以直接阅读本文,不然请你先阅读Doom基础知识,其中包含了下载、动作等等的基础知识。
本次与之前的马里奥训练不同,马里奥是已经有做好的step等函数的,而这个doom没有,但也因此我们可以更好的一窥训练的过程。
完整代码在最后,可以复制执行。

文章目录

  • 一、训练模型
    • 1、vizdoom_train类
      • 1)__init__
      • 2)step
      • 3)togray
      • 4)其他函数
    • 2、保存模型函数
    • 3、训练模型
  • 二、成果验收
  • 完整代码
    • 训练代码
    • 测试代码

一、训练模型

1、vizdoom_train类

这是我们的训练基类,由于我们想用openai gym环境,因此我们需要手写此环境必须的__init__、step等函数

1)init

在这个函数中,我们要定义训练的观察空间(即游戏图像)、动作空间(即ai可以执行的操作)和一些基础设置。

        VizDoom_basic_cfg = r"C:/Users/tttiger/Desktop/ViZDoom-master/ViZDoom-master/scenarios/basic.cfg"self.game = vizdoom.DoomGame()self.game.load_config(VizDoom_basic_cfg)if render == False: self.game.set_window_visible(False)else:self.game.set_window_visible(True)self.game.init()

此处,我们先指定游戏文件的位置,然后加一个判断,决定是否允许游戏窗口显示出来,最后初始化游戏。
初始化后,我们就要规定观察空间和动作空间了

        self.observation_space = Box(low=0, high=255, shape=(100,160,1), dtype=np.uint8) self.action_space = Discrete(3)

观察空间是我们游戏的界面,这是一个灰化后的图像,所以维度为1,大小为100*160
动作空间为离散空间3,即可以选取动作0、1、2,我们只需要在后续的代码中指定数字指代的动作就可以了。

2)step

step函数非常关键,这是ai执行动作的时候会调用的函数。
首先,我们定义我们的动作

        actions = np.identity(3,dtype=np.uint8)reward = self.game.make_action(actions[action], 4) 

这部分的详细解释在Doom基础知识中,事实上就是定义一个矩阵,调用游戏文件中的左移、右移和射击。
接下来,我们要获取当前的状态,比如得分,游戏图像等,这样我们才可以训练。

        try:state = self.game.get_state()img = state.screen_bufferimg = self.togray(img)info = state.game_variables[0]except:img = np.zeros(self.observation_space.shape)info = 0 finally:info = {"info":info}done = self.game.is_episode_finished()#img_show(img)return img,reward,done,info

使用try,是因为gameover时有些内容获取不到,为了防止程序因此暂停,用try。最后,我们要把info变成字典形式,这是因为openai gym环境时这么要求的。为了方便理解,这里可以调用imgshow,查看目前的图像,其实现如下。

def img_show(img):plt.imshow(img)plt.show()time.sleep(5)

3)togray

灰度化图像,我们知道,彩色图像时由rgb三个颜色矩阵组成,但这么大的数据量给我们的训练增添的很多负担,于是我们采用灰度图。同时,我们缩小图像,这样可以训练的更快。

    def togray(self,observation):gray = cv2.cvtColor(np.moveaxis(observation, 0, -1), cv2.COLOR_BGR2GRAY)resize = cv2.resize(gray, (160,100), interpolation=cv2.INTER_CUBIC)state = np.reshape(resize, (100,160,1))return state

observation 是一个 NumPy 数组,通常表示图像数据。
np.moveaxis 是 NumPy 库中的一个函数,用于重新排列数组的轴。
参数 0 表示将第0轴(通常是颜色通道)移动到新位置的最后一个轴位置(即 -1)。
例如,如果 observation 的形状是 (C, H, W)(即颜色通道在第一个维度),经过 np.moveaxis 后,形状将变为 (H, W, C),这对于 OpenCV 处理图像更为常见,因为 OpenCV 期望颜色通道是图像的最后一个维度。

cv2.cvtColor 是 OpenCV 中用于颜色空间转换的函数。
第一个参数是输入图像(在这里是经过 np.moveaxis 处理后的图像)。
第二个参数 cv2.COLOR_BGR2GRAY 指定了将图像从 BGR 颜色空间转换为灰度图像。

4)其他函数

我们要定义一个关闭游戏的函数close

   def close(self):self.game.close()

以及一个reset函数,用于在结束一个游戏后,重置状态,继续下一轮训练

    def reset(self):state = self.game.new_episode()state = self.game.get_state()return self.togray(state.screen_buffer)

至此,我们已经成功地把这个独立游戏包装成了可以使用openai gym的环境的游戏。

2、保存模型函数

class TrainAndLoggingCallback(BaseCallback):def __init__(self, check_freq, save_path, verbose=1):super(TrainAndLoggingCallback, self).__init__(verbose)self.check_freq = check_freqself.save_path = save_pathdef _init_callback(self):if self.save_path is not None:os.makedirs(self.save_path, exist_ok=True)def _on_step(self):if self.n_calls % self.check_freq == 0:model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))self.model.save(model_path)return True

我们使用这段代码来存档数据,这部分复制即可

3、训练模型

首先,我们指定训练结果保存的路径

    CHECKPOINT_DIR = './train/train_basic'LOG_DIR = './logs/log_basic'callback = TrainAndLoggingCallback(check_freq=10000, save_path=CHECKPOINT_DIR)    ```

然后我们调用训练函数进行训练

    env = vizdoom_train(render=False)model = PPO('CnnPolicy', env, tensorboard_log=LOG_DIR, verbose=1, learning_rate=0.0001, n_steps=2048)model.learn(total_timesteps=100000, callback=callback)

这里我们使用PPO这个强化学习算法。

经过几十分钟的等待,就得到训练好的模型了

二、成果验收

训练好模型后,我们要使用模型,看看效果。
首先,我们载入训练好的模型

model = PPO.load('./train/train_basic/best_model_100000')

然后测试以下模型的平均得分

mean_reward, _ = evaluate_policy(model, env, n_eval_episodes=5)

这样还不够直观,我们让ai玩给我们看


for episode in range(100): obs = env.reset()done = Falsetotal_reward = 0while not done: action, _ = model.predict(obs)obs, reward, done, info = env.step(action)time.sleep(0.20)total_reward += rewardprint('Total Reward for episode {} is {}'.format(total_reward, episode))time.sleep(2)

我们首先让ai模型预测,然后将预测的动作输入step函数,然后展示页面。

如图所示,平均得分非常高,基本能快速索敌,然后一枪秒杀。
至此,我们完成了ai的训练。

完整代码

训练代码

from gym import Env
from gym.spaces import Discrete,Box
import cv2
from vizdoom import *
import vizdoom
import random
import time
import numpy as np
#DIS离散空间 作用类似于random
#box用来装游戏图像
from matplotlib import pyplot as plt
class vizdoom_train(Env):def __init__(self, render=True):super().__init__()VizDoom_basic_cfg = r"C:/Users/tttiger/Desktop/ViZDoom-master/ViZDoom-master/scenarios/basic.cfg"self.game = vizdoom.DoomGame()self.game.load_config(VizDoom_basic_cfg)if render == False: self.game.set_window_visible(False)else:self.game.set_window_visible(True)self.game.init()self.observation_space = Box(low=0, high=255, shape=(100,160,1), dtype=np.uint8) self.action_space = Discrete(3)def step(self,action):actions = np.identity(3,dtype=np.uint8)reward = self.game.make_action(actions[action], 4) try:state = self.game.get_state()img = state.screen_bufferimg = self.togray(img)info = state.game_variables[0]except:img = np.zeros(self.observation_space.shape)info = 0 finally:info = {"info":info}done = self.game.is_episode_finished()#img_show(img)return img,reward,done,infodef close(self):self.game.close()def reset(self):state = self.game.new_episode()state = self.game.get_state()return self.togray(state.screen_buffer)def togray(self,observation):gray = cv2.cvtColor(np.moveaxis(observation, 0, -1), cv2.COLOR_BGR2GRAY)resize = cv2.resize(gray, (160,100), interpolation=cv2.INTER_CUBIC)state = np.reshape(resize, (100,160,1))return statedef img_show(img):plt.imshow(img)plt.show()time.sleep(5)# Import os for file nav
import os 
# Import callback class from sb3
from stable_baselines3.common.callbacks import BaseCallback
class TrainAndLoggingCallback(BaseCallback):def __init__(self, check_freq, save_path, verbose=1):super(TrainAndLoggingCallback, self).__init__(verbose)self.check_freq = check_freqself.save_path = save_pathdef _init_callback(self):if self.save_path is not None:os.makedirs(self.save_path, exist_ok=True)def _on_step(self):if self.n_calls % self.check_freq == 0:model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))self.model.save(model_path)return Trueif __name__ == "__main__":CHECKPOINT_DIR = './train/train_basic'LOG_DIR = './logs/log_basic'callback = TrainAndLoggingCallback(check_freq=10000, save_path=CHECKPOINT_DIR)    train = vizdoom_train()from stable_baselines3 import PPOenv = vizdoom_train(render=False)model = PPO('CnnPolicy', env, tensorboard_log=LOG_DIR, verbose=1, learning_rate=0.0001, n_steps=2048)model.learn(total_timesteps=100000, callback=callback)

测试代码

from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3 import PPO
import time
model = PPO.load('./train/train_basic/best_model_100000')from second_gym import vizdoom_train
env = vizdoom_train(render=True)mean_reward, _ ,_= evaluate_policy(model, env, n_eval_episodes=5)print(mean_reward)for episode in range(100): obs = env.reset()done = Falsetotal_reward = 0while not done: action, _ = model.predict(obs)obs, reward, done, info = env.step(action)time.sleep(0.20)total_reward += rewardprint('Total Reward for episode {} is {}'.format(total_reward, episode))time.sleep(2)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/43874.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

打开ps提示dll文件丢失如何解决?教你几种靠谱的方法

在日常使用电脑过程中,由于不当操作,dll文件丢失是一种常见现象。当dll文件丢失时,程序将无法正常运行,比如ps,pr等待软件。此时,我们需要对其进行修复以恢复其功能,下面我们一起来了解一下出现…

【堆 (优先队列) 扫描线】218. 天际线问题

本文涉及知识点 堆 (优先队列) 扫描线 LeetCode218. 天际线问题 城市的 天际线 是从远处观看该城市中所有建筑物形成的轮廓的外部轮廓。给你所有建筑物的位置和高度,请返回 由这些建筑物形成的 天际线 。 每个建筑物的几何信息由数组 buildings 表示&…

景芯SoC训练营DFT debug

景芯训练营VIP学员在实践课上遇到个DFT C1 violation,导致check_design_rule无法通过,具体报错如下: 遇到这个问题第一反映一定是确认时钟,于是小编让学员去排查add_clock是否指定了时钟,指定的时钟位置是否正确。 景芯…

2024年信息系统项目管理师1批次上午客观题参考答案及解析(3)

51、探索各种选项,权衡包括时间与成本、质量与成本、风险与进度、进度与质量等多种因素,在整个过程中,舍弃无效或次优的替代方案,这种不确定性应对方法是()。 A.集合设计 B.坚韧性 C.多种结果…

离线运行Llama3:本地部署终极指南_liama2 本地部署

4月18日,Meta在官方博客官宣了Llama3,标志着人工智能领域迈向了一个重要的飞跃。经过笔者的个人体验,Llama3 8B效果已经超越GPT-3.5,最为重要的是,Llama3是开源的,我们可以自己部署! 本文和大家…

大数据------JavaWeb------FilterListenerAJAXAxiosJSON

Filter Filter简介 定义:Filter表示过滤器,是JavaWeb三大组件(Servlet、Filter、Listener)之一。 作用:它可把对资源(Servlet、JSP、Html)的请求拦截下来从而实现一些特殊功能 过滤器一般完成…

【QT中实现摄像头播放、以及视频录制】

学习分享 1、效果图2、camerathread.h3、camerathread.cpp4、mainwindow.h5、mainwindow.cpp6、main.cpp 1、效果图 2、camerathread.h #ifndef CAMERATHREAD_H #define CAMERATHREAD_H#include <QObject> #include <QThread> #include <QDebug> #include &…

选择排序(C语言版)

选择排序是一种简单直观的排序算法 算法实现 首先在未排序序列中找到最小&#xff08;大&#xff09;元素&#xff0c;存放到排序序列的起始位置。 再从剩余未排序元素中继续寻找最小&#xff08;大&#xff09;元素&#xff0c;然后放到已排序序列的末尾。 重复第二步&…

020-GeoGebra中级篇-几何对象之点与向量

本文概述了在GeoGebra中如何使用笛卡尔或极坐标系输入点和向量。用户可以通过指令栏输入数字和角度&#xff0c;使用工具或指令创建点和向量。在笛卡尔坐标系中&#xff0c;示例如“P(1,0)”&#xff1b;在极坐标系中&#xff0c;示例如“P(1;0)”或“v(5;90)”。文章还介绍了点…

深入理解循环神经网络(RNN)

深入理解循环神经网络&#xff08;RNN&#xff09; 循环神经网络&#xff08;Recurrent Neural Network, RNN&#xff09;是一类专门处理序列数据的神经网络&#xff0c;广泛应用于自然语言处理、时间序列预测、语音识别等领域。本文将详细解释RNN的基本结构、工作原理以及其优…

uniapp本地打包到Android Studio生成APK文件

&#xff08;1&#xff09;安装 Android Studio 软件&#xff1b; 下载地址&#xff1a;官方下载地址&#xff0c;英文环境 安装&#xff1a;如下之外&#xff0c;其他一键 next &#xff08;2&#xff09;配置java环境&#xff1b; 下载&#xff1a;j…

基于SpringBoot构造超简易QQ邮件服务发送 第二版

目录 追加 邮箱附件 添加依赖 编码 测试 第二版的更新点是追加了 邮箱附件功能 ( 后期追加定时任务 ) 基于SpringBoot构造超简易QQ邮件服务发送(分离-图解-新手) 第一版 追加 邮箱附件 添加依赖 <!-- 电子邮件 --><dependency><groupId>org.spri…

如何评价Flutter?

哈喽&#xff0c;我是老刘 我们团队使用Flutter已经快6年了。 有很多人问过我们对Flutter的评价。 今天在这里回顾一下6年前选择Flutter时的原因&#xff0c;以及Flutter在这几年中的实际表现如何。 选择Flutter时的判断 1、性能 最开始吸引我们的就是其优秀的性能。 特别是…

【vue3|第16期】初探Vue-Router与现代网页路由

日期:2024年7月6日 作者:Commas 签名:(ง •_•)ง 积跬步以致千里,积小流以成江海…… 注释:如果您觉得有所帮助,帮忙点个赞,也可以关注我,我们一起成长;如果有不对的地方,还望各位大佬不吝赐教,谢谢^ - ^ 1.01365 = 37.7834;0.99365 = 0.0255 1.02365 = 1377.4083…

深入探索联邦学习框架 Flower

联邦学习框架 本文主要期望介绍一个设计良好的联邦学习框架 Flower&#xff0c;在开始介绍 Flower 框架的细节前&#xff0c;先了解下联邦学习框架的基础知识。 作为一个联邦学习框架&#xff0c;必然会包含对横向联邦学习的支持。横向联邦是指拥有类似数据的多方可以在不泄露…

【CVPR 2024】GART: Gaussian Articulated Template Models

【CVPR 2024】GART: Gaussian Articulated Template Models 一、前言Abstract1. Introduction2. Related Work3. Method3.1. Template Prior3.2. Shape Appearance Representation with GMM3.3. Motion Representation with Forward Skinning3.4. Reconstruct GART from Monocu…

Java--instanceof和类型转换

1.如图&#xff0c;Object&#xff0c;Person&#xff0c;Teacher&#xff0c;Student四类的关系已经写出来了&#xff0c;由于实例化的是Student类&#xff0c;因此&#xff0c;与Student类存在关系的类在使用instanceof时都会输出True&#xff0c;而无关的都会输出False&…

数据结构 —— Dijkstra算法

数据结构 —— Dijkstra算法 Dijkstra算法划分集合模拟过程打印路径 在上次的博客中&#xff0c;我们解决了使用最小的边让各个顶点连通&#xff08;最小生成树&#xff09; 这次我们要解决的问题是现在有一个图&#xff0c;我们要找到一条路&#xff0c;使得从一个顶点到另一个…

对比学习和多模态任务

1. 对比学习 对比学习&#xff08;Contrastive Learning&#xff09;是一种自监督学习的方法&#xff0c;旨在通过比较数据表示空间中的不同样本来学习有用的特征表示。其核心思想是通过最大化同类样本之间的相似性&#xff08;或降低它们之间的距离&#xff09;&#xff0c;同…

【Linux】网络新兵连

欢迎来到 破晓的历程的 博客 ⛺️不负时光&#xff0c;不负己✈️ 引言 在上一篇博客中&#xff0c;我们简单的介绍了一些Linux网络一些比较基本的概念。本篇博客我们将开始正式学习Linux网络套接字的内容&#xff0c;那么我们开始吧&#xff01; 1.网络中的地址管理 大家一…