Imitation Learning学习记录(理论例程)

前言

最近还是衔接着之前的学习记录,这次打算开始学习模仿学习的相关原理,参考的开源资料为

TeaPearce/Counter-Strike_Behavioural_Cloning: IEEE CoG & NeurIPS workshop paper ‘Counter-Strike Deathmatch with Large-Scale Behavioural Cloning’ (github.com)
[2104.04258] Counter-Strike Deathmatch with Large-Scale Behavioural Cloning (arxiv.org)

简单来说,行为克隆就是利用已有的人为示范数据作为输入来训练出一个策略,策略就会输出指定的动作,然而,行为克隆只能学习到专家的行为,而无法进行探索和自主学习。这意味着行为克隆的性能受限于专家的行为水平,并且可能无法适应新的、未在专家演示中出现过的情况。通过引入奖励函数,可以在行为克隆中加入一定的探索和自主学习能力。奖励函数可以根据当前状态和采取的动作来评估行为的好坏,并为模型提供反馈信号。通过优化奖励函数,可以使模型学习到更好的策略,并且能够适应新的情况和环境。奖励函数在行为克隆中起到了指导和调整模型学习的作用。其中,文章用到的例程的网络结构如下

请添加图片描述

本文打算从数据获取、模型训练、效果展示三个部分展开介绍

数据获取

在这个过程中作者使用了Game State Integration(GSI)技术来获取在线数据。通过GSI,作者可以从游戏中获取实时的游戏状态信息,包括玩家、队伍、武器、位置等各种数据。具体来说,作者可能使用了Valve提供的GSI接口来获取游戏状态信息。这些信息可以用于后续的行为克隆和分析工作。

这个过程中的核心代码如下

# now find the requried process and where two modules (dll files) are in RAM  
hwin_csgo = win32gui.FindWindow(0, ('counter-Strike: Global Offensive'))  
if(hwin_csgo):  pid=win32process.GetWindowThreadProcessId(hwin_csgo)  handle = pymem.Pymem()  handle.open_process_from_id(pid[1])  csgo_entry = handle.process_base  
else:  print('CSGO wasnt found')  os.system('pause')  sys.exit()  # now find two dll files needed  
list_of_modules=handle.list_modules()  
while(list_of_modules!=None):  tmp=next(list_of_modules)  # used to be client_panorama.dll, moved to client.dll during 2020  if(tmp.name=="client.dll"):  print('found client.dll')  off_clientdll=tmp.lpBaseOfDll  break  
list_of_modules=handle.list_modules()  
while(list_of_modules!=None):  tmp=next(list_of_modules)  if(tmp.name=="engine.dll"):  print('found engine.dll')  off_enginedll=tmp.lpBaseOfDll  break

大致逻辑为:

  1. 查找CSGO进程:

    • 使用win32gui.FindWindow查找名为’counter-Strike: Global Offensive’的窗口句柄。
    • 如果找到窗口句柄,则通过win32process.GetWindowThreadProcessId获取与该窗口关联的进程ID。
    • 使用pymem.Pymem()创建一个进程内存访问对象,并通过open_process_from_id方法打开该进程。
    • 如果CSGO进程未找到,则打印消息并退出程序。
  2. 查找client.dll和engine.dll:

    • 使用handle.list_modules()获取进程中的所有模块列表。
    • 遍历模块列表,查找名为"client.dll"的模块,并获取动态链接库的基地址(lpBaseOfDll)。
    • 注意:这里使用了两次handle.list_modules()来分别查找两个DLL文件,但实际上你可以只调用一次并将结果存储在列表中,然后遍历这个列表来查找两个DLL。
    • 类似地,代码还查找名为"engine.dll"的模块,并获取其基地址。

找到窗口和动态链接库以后就可以开始录像并通过GSI或者RAM来访问键位等游戏信息,得到的数据类型大致为

  • frame_i_x: 图像信息
  • frame_i_xaux: 包含在前一个时间步骤中应用的动作,以及血量、弹药和团队。用于更好地帮助智能体寻找敌人以及适应当前情况
  • frame_i_y: 对应键盘以及鼠标的动作
  • frame_i_helperarr: 在格式kill_flag, death_flag中,每个变量都是二元变量,例如[[1,0]],意味着玩家击杀一次,但在该时间步内没有死亡

其中,具体的键位信息如下:

# how many slots were used for each action type?  
n_keys = 11 # number of keyboard outputs, w,s,a,d,space,ctrl,shift,1,2,3,r  
n_clicks = 2 # number of mouse buttons, left, right  
n_mouse_x = len(mouse_x_possibles) # number of outputs on mouse x axis  
n_mouse_y = len(mouse_y_possibles) # number of outputs on mouse y axis  
n_extras = 3 # number of extra aux inputs, eg health, ammo, team. others could be weapon, kills, deaths  
aux_input_length = n_keys+n_clicks+1+1+n_extras # aux uses continuous input for mouse this is multiplied by ACTIONS_PREV elsewhere

一个帧所包含的具体信息值如下:

请添加图片描述
请添加图片描述

模型训练

网络结构

输入先进入一个预训练好的EfficientNetB0模型,该模型在ImageNet数据集上进行了训练。并加上了时间序列信息,接下来将提取好的特征输入进一个带有时序信息的ConvLSTM网络

base_model = EfficientNetB0(weights='imagenet',input_shape=(input_shape[1:]),include_top=False,drop_connect_rate=0.2)
if 'drop' in model_name:  if 'big' in model_name:  x = ConvLSTM2D(filters=512,kernel_size=(3,3),stateful=False,return_sequences=True,dropout=0.5, recurrent_dropout=0.5)(x)  else:  x = ConvLSTM2D(filters=256,kernel_size=(3,3),stateful=False,return_sequences=True,dropout=0.5, recurrent_dropout=0.5)(x)

输出的信息为

# set up outputs, sepearate outputs will allow seperate losses to be applied  
output_1 = TimeDistributed(Dense(n_keys, activation='sigmoid'))(dense_5)  
output_2 = TimeDistributed(Dense(n_clicks, activation='sigmoid'))(dense_5)  
output_3 = TimeDistributed(Dense(n_mouse_x, activation='softmax'))(dense_5) # softmax since mouse is mutually exclusive  
output_4 = TimeDistributed(Dense(n_mouse_y, activation='softmax'))(dense_5)   
output_5 = TimeDistributed(Dense(1, activation='linear'))(dense_5)   
# output_all = concatenate([output_1,output_2,output_3,output_4], axis=-1)  
output_all = concatenate([output_1,output_2,output_3,output_4,output_5], axis=-1)

损失函数

  1. 键盘按键损失(loss1a, loss1b, loss1c, loss1d
    • loss1a:计算 WASD 键(通常用于游戏中的移动)的二进制交叉熵损失。
    • loss1b:计算空格键的二进制交叉熵损失。
    • loss1c:计算重新加载键(如游戏中的“R”键)的二进制交叉熵损失。
    • loss1d(注释掉的部分):原本可能用于计算其他键盘按键的损失,但在提供的代码中,它被重新定义为武器切换键(1, 2, 3)的损失。
  2. 鼠标点击损失(loss2a, loss2b
    • loss2a:计算鼠标左键点击的二进制交叉熵损失。
    • loss2b:计算鼠标右键点击的二进制交叉熵损失(如果n_clicks大于1的话)。
  3. 鼠标移动损失(loss3, loss4
    • loss3:计算鼠标在 X 轴上的移动损失。由于鼠标移动是互斥的(即鼠标不能同时处于多个位置),因此使用了分类交叉熵损失(categorical_crossentropy)。
    • loss4:计算鼠标在 Y 轴上的移动损失,同样使用了分类交叉熵损失。

除此之外,还有一个loss_crit损失函数,

loss_crit = 10*losses.MSE(y_true[:,:-1,n_keys+n_clicks+n_mouse_x+n_mouse_y:n_keys+n_clicks+n_mouse_x+n_mouse_y+1]  + GAMMA*y_pred[:,1:,n_keys+n_clicks+n_mouse_x+n_mouse_y:n_keys+n_clicks+n_mouse_x+n_mouse_y+1]  ,y_pred[:,:-1,n_keys+n_clicks+n_mouse_x+n_mouse_y:n_keys+n_clicks+n_mouse_x+n_mouse_y+1])

这是一个基于时序差分(Temporal Difference, TD)的均方误差(Mean Squared Error, MSE)损失函数,用于强化学习中的值函数逼近。它计算了当前时间步的奖励(或值)与下一个时间步的预测奖励(或值)之和(经过折扣因子 GAMMA 调整后)与当前时间步的预测奖励(或值)之间的均方误差。这种损失函数允许神经网络学习如何根据当前状态和环境信息来预测未来的奖励或值,从而优化策略或值函数。在这个特定的实现中,损失还乘以了一个系数(如10),可能是为了调整该损失在总损失中的相对权重。

奖励函数如下,奖励为 R(杀敌数,死亡数,子弹数)

reward_i = kill_i - 0.5*dead_i - 0.01*shoot_i # this is reward function  
y[i,j,-2:] = (reward_i,0.) # 0. is a placeholder for original advantage

效果展示

通过e2e.yml文件配置虚拟环境,更改了游戏内的窗口分辨率,设置了一些其他的参数,运行dm_run_agent.py以后在自己的电脑上成功复现

请添加图片描述

总结

本次基于Counter-Strike Deathmatch with Large-Scale Behavioural Cloning这个开源项目系统地学习了一下行为克隆的基本流程,从数据采集、模型训练以及损失函数的定义到最终复现,拓宽了我对RL的认知,在日后也能够更好地迁移到Robotic,逻辑如下:

  1. 数据收集:首先,需要收集人类专家在特定任务中的行为数据。这些数据通常包括机器人所处的状态(如位置、姿态、环境信息等)以及对应的人类专家在该状态下所采取的动作(如移动方向、操作指令等)。这些数据构成了行为克隆算法的训练集。
  2. 模型训练:使用收集到的数据训练一个模型,如神经网络模型。这个模型将学习从状态到动作的映射关系,即根据机器人当前的状态预测应该执行的动作。在训练过程中,模型会不断优化其参数,以最小化预测动作与真实动作之间的差异。
  3. 模型部署:训练好的模型可以部署到机器人上,用于指导机器人的行为。当机器人遇到新的状态时,它会将当前状态输入到模型中,模型会输出一个预测的动作。机器人将根据这个预测的动作来执行相应的操作。
  4. 反馈与调整:在机器人执行动作的过程中,可以通过收集反馈信息来进一步调整模型。例如,可以观察机器人执行动作后的效果,如果效果不理想,则可以收集新的数据并重新训练模型,以提高其性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/10723.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java项目之汽车资讯网站源码(springboot+mysql+vue)

风定落花生,歌声逐流水,大家好我是风歌,混迹在java圈的辛苦码农。今天要和大家聊的是一款基于springboot的汽车资讯网站。项目源码以及部署相关请联系风歌,文末附上联系信息 。 项目简介: 汽车资讯网站的主要使用者管…

Vue从入门到实战Day04

一、组件的三大组成部分(结构/样式/逻辑) 1. scoped样式冲突 默认情况:写在组件中的样式会全局生效 -> 因此很容易造成多个组件之间的样式冲突问题。 1. 全局样式:默认组件中的样式会作用到全局 2. 局部样式:可以…

LeetCode 138. 随机链表的复制

目录 1.原题链接: 2.结点拆分: 代码实现: 3.提交结果: 4.读书分享: 1.原题链接: 138. 随机链表的复制 2.结点拆分: ①.拷贝各个结点,连接在原结点后面; ②.处…

【MySQL】基本操作

欢迎来到Cefler的博客😁 🕌博客主页:折纸花满衣 🏠个人专栏:MySQL 目录 👉🏻创建和删除数据库👉🏻数据库编码集和数据库校验集校验规则对数据库的影响 👉&…

【1 bit 翻转+无任何保护】MidnightsunQuals 2021 BroHammer

前言 又是一道非常有意思的题目,其实笔者很喜欢这种跟页表、特权级等相关的题目(:虽然大多都无法独立做出来,但是通过这些题目可以学到很多的东西 题目分析 内核版本:v4.17.0smap/smep/kpti/kaslr 全关 题目给了源…

laravel8 导入 excel常见问题

上传xls 或 xlsx 文件后,文件解析为 zip 格式,输入正常情况,不影响解析 里面的内容 遇到解析内容,解析为空的情况,可能是 因为excel 存在多个 Sheet1 造成,服务器不能解析一个 Sheet1 的情况&#xff0…

智慧停车场管理系统主要组成

智慧泊车场办理体系,完成了泊车办理过程中的车辆类型分类、出场时的车牌辨认、行进路线的引导、空余车位诱导,以及准备离场前的反向寻车和方便缴费等全部环节。这六个流程中,泊车场对车辆的办理,进步了泊车场的运行效率&#xff0…

【网络】为什么TCP需要四次挥手?

在网络通信中,TCP(传输控制协议)是一种可靠的、面向连接的协议,它在数据传输过程中保证了数据的可靠性和顺序性。而TCP的连接建立过程只需要三次握手,但是TCP的挥手过程却需要四次挥手,这是为什么呢&#x…

数据分享—中国土壤有机质数据

土壤有机质数据是进行区域土地资源评价,开展自然地理研究常使用的数据,本期推文主要分享全国土壤有机质数据集。梧桐君会不定期分享地理信息数据,欢迎大家长期订阅。 数据来源 “万物土中生”,小编今天要分享的中国土壤有机质数…

Tomcat 内核详解 - Web服务器机制

详细介绍 Apache Tomcat 是一个开源的Web服务器和Servlet容器,它实现了Java Servlet、JavaServer Pages (JSP) 和WebSocket规范。Tomcat的核心设计围绕着几个关键组件,它们共同构成了处理HTTP请求、管理Web应用部署和执行Servlet逻辑的基础架构。 Apac…

牛客NC404 最接近的K个元素【中等 二分查找+双指针 Java/Go/PHP】

题目 题目链接: https://www.nowcoder.com/practice/b4d7edc45759453e9bc8ab71f0888e0f 知识点 二分查找;找到第一个大于等于x的数的位置idx;然后从idx开始往两边扩展Java代码 import java.util.*;public class Solution {/*** 代码中的类名、方法名、…

小程序组件间传值

1、属性绑定&#xff08;Props&#xff09;: 父组件通过在子组件标签上设置属性的方式向子组件传值。 子组件通过properties定义接收的属性 父组件&#xff1a; wxml <child-component title"{{parentData}}"></child-component>子组件&#xff1a; js p…

可观测性监控

1 目的 常见的监控&#xff0c;主要是以收集数据以识别异常系统效应为主&#xff0c;多是单个服务&#xff0c;相互独立的状态。 可观测性&#xff0c;希望调查异常系统效应的根本原因&#xff0c;能够把多个服务、中间件、容器等串联起来&#xff0c;同时柔和metrics、log、…

前端怎么用 EventSource? EventSource怎么配置请求头及加参数? EventSourcePolyfill使用方法

EventSource EventSource 接口是 web 内容与服务器发送事件通信的接口。 一个 EventSource 实例会对 HTTP 服务器开启一个持久化的连接&#xff0c;以 text/event-stream 格式发送事件&#xff0c;此连接会一直保持开启直到通过调用 EventSource.close() 关闭。 EventSource…

常见的推荐系统框架

1&#xff09;Microsoft Recommender&#xff1a; 该框架由微软开发&#xff0c;可以免费使用&#xff0c;主要提供了包括一般功能&#xff08;Common Utilities&#xff09;、大数据功能&#xff08;Dataset Utilities&#xff09;、评价功能&#xff08;Evaluation Utilitie…

将本地docker镜像以压缩包格式保存至其他路径、从本地的镜像压缩包中加载docker镜像

保存本地Docker镜像为压缩包至其他路径 你可以使用 docker save 命令结合输出重定向&#xff08; -o 选项&#xff09;来将本地Docker镜像保存为一个压缩包&#xff08;通常是tar格式&#xff09;并直接保存到指定的路径。以下是一个示例命令&#xff1a; docker save -o /pa…

c++ - 在循环中使用迭代器删除 unordered_set 中的元素

标签 c unordered-set 请考虑以下代码: Class MyClass 为自定义类:class MyClass { public:MyClass(int v) : Val(v) {}int Val; };然后下面的代码将在调用 it T.erase(it); 之后在循环中导致 Debug Assertion Failed: unordered_set<MyClass*> T; unordered_set<…

vue3.0(六) toRef,toValue,toRefs和toRow,markRaw

文章目录 toReftoValuetoRefstoRowmarkRawtoRef和toRefs的区别toRaw 和markRaw的用处 toRef toRef 函数可以将一个响应式对象的属性转换为一个独立的 ref 对象。返回的是一个指向源对象属性的 ref 引用&#xff0c;任何对该引用的修改都会同步到源对象属性上。使用 toRef 时需…

C#中的继承、接口和多态性

继承&#xff08;Inheritance&#xff09; 在C#中&#xff0c;继承允许我们创建一个新的类&#xff08;称为子类或派生类&#xff09;&#xff0c;该类从另一个已存在的类&#xff08;称为父类或基类&#xff09;中继承方法和属性。子类可以添加新的方法和属性&#xff0c;或者…

2024年最新趋势跨境电商平台开发需了解的新技术

随着数字化技术的不断演进和全球市场的日益融合&#xff0c;跨境电商平台开发将面临前所未有的挑战和机遇。为了更好地适应并引领这一发展&#xff0c;开发者需要密切关注2024年最新的技术趋势&#xff0c;以确保他们的平台能够在竞争激烈的市场中脱颖而出。本文将对跨境电商平…