DI-engine强化学习入门(十又二分之一)如何使用RNN——数据处理、隐藏状态、Burn-in

一、数据处理
用于训练 RNN 的 mini-batch 数据不同于通常的数据。 这些数据通常应按时间序列排列。 对于 DI-engine, 这个处理是在 collector 阶段完成的。 用户需要在配置文件中指定 learn_unroll_len 以确保序列数据的长度与算法匹配。 对于大多数情况, learn_unroll_len 应该等于 RNN 的历史长度(a.k.a 时间序列长度),但在某些情况下并非如此。比如,在 r2d2 中, 我们使用burn-in操作, 序列长度等于 learn_unroll_len + burnin_step 。 这里将在下一节中具体解释。

什么是数据处理?
数据处理指的是为循环神经网络(RNN)训练准备时间序列数据的过程。这个过程包括将收集到的数据组织成适当格式的小批量(mini-batches),这些批量数据将用于网络的训练。这一步骤通常发生在DI-engine的collector阶段,也就是数据收集和预处理发生的地方。用户需要在配置文件中指定 learn_unroll_len 以确保序列数据的长度与算法匹配。 对于大多数情况, learn_unroll_len 应该等于 RNN 的历史长度(a.k.a 时间序列长度),但在某些情况下并非如此。比如,在 r2d2 中, 我们使用burn-in操作, 序列长度等于 learn_unroll_len + burnin_step 。例如,如果你设置 learn_unroll_len = 10 和 burnin_step = 5,那么 RNN 实际接收的输入序列长度将是 15:前 5 步为 burn-in(用于预热隐藏状态),接下来的 10 步作为学习的一部分。这样设置可以帮助 RNN 在计算梯度和进行权重更新时,有一个更加准确的隐藏状态作为起点。
部分名词解释

  • mini-batches:在机器学习中,特别是在训练神经网络时,数据一般被分成小的批次进行处理,这些批次被称为 “mini-batch”。一个 mini-batch 包含了一组样本,这组样本用于执行单次迭代的前向传播和反向传播,以更新网络的权重。使用 mini-batches 而不是单个样本或整个数据集(后者称为 “batch” 或 “full-batch”)可以平衡计算效率和内存限制,有助于提高学习的稳定性和收敛速度。
  • collector阶段:在 DI-engine中,collector 阶段是指环境与智能体交互并收集经验数据的过程。在这个阶段,智能体根据其当前的策略执行操作,环境则返回新的状态、奖励和其他可能的信息,如是否达到终止状态。收集到的数据(经常被称为经验或转换)随后被用于训练智能体的模型,例如对策略或价值函数进行更新。

为什么要进行数据处理:

  1. 保持时间依赖性:RNN的核心优势是处理具有时间序列依赖性的数据,比如语言、视频帧、股票价格等。正确的数据处理确保了这些时间依赖性在训练数据中得以保留,使得模型能够学习到数据中的序列特征。
  2. 提高学习效率:通过将数据划分为与模型期望的序列长度匹配的批次,可以提高模型学习的效率。这样做可以确保网络在每次更新时都接收到足够的上下文信息。
  3. 适配算法要求:不同的RNN算法可能需要不同形式的输入数据。例如,标准的RNN只需要过去的信息,而一些变体如LSTM或GRU可能会处理更长的序列。特定的算法,如R2D2,还可能需要额外的步骤(如burn-in),以便更好地初始化网络状态。
  4. 处理不规则长度:在现实世界的数据集中,序列长度往往是不规则的。数据处理确保了每个mini-batch都有统一的序列长度,这通常通过截断过长的序列或填充过短的序列来实现。
  5. 优化内存和计算资源:通过将数据组织成具有固定时间步长的批次,可以更有效地利用GPU等计算资源,因为这些资源在处理固定大小的数据时通常更高效。
  6. 稳定学习过程:特别是在强化学习中,使用如n-step返回或经验回放的技术,可以帮助模型从环境反馈中学习,并减少方差,从而稳定学习过程。

如何进行数据处理

def _get_train_sample(self, data: list) -> Union[None, List[Any]]:    data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)    return get_train_sample(data, self._sequence_len)

 代码段 def _get_train_sample(self, data: list) 是一个方法,它的作用是从收集到的数据中提取用于训练 RNN 的样本。这个方法会在两个步骤中处理数据:

  • N步返回计算(get_nstep_return_data): 这个函数接受原始的经验数据,然后计算所谓的 N 步返回值。N 步返回是一个在强化学习中用于临时差分(Temporal Difference, TD)学习的概念,它考虑了从当前状态开始的未来 N 步的累积奖励。计算这个值需要使用折现因子 gamma。这个步骤的目的是为了让智能体学习如何根据当前的行动预测未来的奖励,这是强化学习中价值函数估计的重要部分。
  • 训练样本获取(get_train_sample): 在得到 N 步返回值之后,这个函数进一步处理数据以生成训练样本。具体地,它会根据 self._sequence_len(即时间序列长度或者 RNN 的历史长度)来选择数据序列。这意味着每个训练样本将是一个具有 self._sequence_len 长度的数据序列,这对于训练 RNN 来说是必要的,因为 RNN 需要一定长度的历史来维护其内部状态(或记忆)。

有关这两个数据处理功能的工作流程见下图:

二、初始化隐藏状态 (Hidden State)
RNN用于处理具有时间依赖性的信息。RNN的隐藏状态(Hidden State)是其记忆的一部分,它能够捕捉到前一时间步长的信息。这些信息对于预测下一个动作或状态非常关键。在此上下文中,初始化RNN的隐藏状态是一个重要的步骤,它确保了RNN在开始新的数据批次处理时具有正确的起始状态。
策略的 _learn_model 需要初始化 RNN。这些隐藏状态来自 _collect_model 保存的 prev_state。 用户需要通过 _process_transition 函数将这些状态添加到 _learn_model 输入数据字典中。 

def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:    transition = {        'obs': obs,        'action': model_output['action'],        'prev_state': model_output['prev_state'], # add ``prev_state`` key here        'reward': timestep.reward,        'done': timestep.done,    }    return transition

点击DI-engine强化学习入门(十又二分之一)如何使用RNN——数据处理、隐藏状态、Burn-in - 古月居 可查看全文

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/8033.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

神经网络极简入门

神经网络是深度学习的基础,正是深度学习的兴起,让停滞不前的人工智能再一次的取得飞速的发展。 其实神经网络的理论由来已久,灵感来自仿生智能计算,只是以前限于硬件的计算能力,没有突出的表现,直至谷歌的A…

mysql workbench如何导出insert语句?

进行导出设置 导出的sql文件 CREATE DATABASE IF NOT EXISTS jeesite /*!40100 DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci */ /*!80016 DEFAULT ENCRYPTIONN */; USE jeesite; -- MySQL dump 10.13 Distrib 8.0.28, for Win64 (x86_64) -- -- Host: 127.0…

如何使用dockerfile文件将项目打包成镜像

要根据Dockerfile文件来打包一个Docker镜像,你需要遵循以下步骤。这里假设你已经安装了Docker环境。 1. 准备Dockerfile 确保你的Dockerfile文件已经准备就绪,并且位于你希望构建上下文的目录中。Dockerfile是一个文本文件,包含了用户可以调…

顺序表的实现(迈入数据结构的大门)(1)

什么是数据结构 数据结构是由:“数据”与“结构”两部分组成 数据与结构 数据:如我们所看见的广告、图片、视频等,常见的数值,教务系统里的(姓名、性别、学号、学历等等); 结构:当…

线性表--数据结构设计与操作

单链表 1.单链表的定义&#xff1a; typedef struct LNode{Elemtype data;struct Lnode *next; }LNode ,*LinkList;//单链表的数据结构&#xff08;手写&#xff09; #include<iostream> #include<vector> #include<algorithm>typedef int TypeElem; //单链表…

OpenAI API搭建的智能家居助手;私密大型语言模型(LLM)聊天机器人;视频和音频文件的自动化识别和翻译工具

✨ 1: GPT Home 基于Raspberry Pi和OpenAI API搭建的智能家居助手 GPT Home是一个基于Raspberry Pi和OpenAI API搭建的智能家居助手&#xff0c;功能上类似于Google Nest Hub或Amazon Alexa。通过详细的设置指南和配件列表&#xff0c;用户可以自行组装和配置这个设备&#x…

Ansible自动运维工具之playbook

一.inventory主机清单 1.定义 Inventory支持对主机进行分组&#xff0c;每个组内可以定义多个主机&#xff0c;每个主机都可以定义在任何一个或多个主机组内。 2.变量 &#xff08;1&#xff09;主机变量 [webservers] 192.168.10.14 ansible_port22 ansible_userroot ans…

使用sqlmodel实现唯一性校验

代码&#xff1a; from sqlmodel import Field, Session, SQLModel, create_engine# 声明模型 class User(SQLModel, tableTrue):id: int | None Field(defaultNone, primary_keyTrue)# 不能为空&#xff0c;必须唯一name: str Field(nullableFalse, uniqueTrue)age: int | …

Flutter弹窗链-顺序弹出对话框

效果 前言 弹窗的顺序执行在App中是一个比较常见的应用场景。比如进入App首页&#xff0c;一系列的弹窗就会弹出。如果不做处理就会导致弹窗堆积的全部弹出&#xff0c;严重影响用户体验。 如果多个弹窗中又有判断逻辑&#xff0c;根据点击后需要弹出另一个弹窗&#xff0c;这…

大数据Scala教程从入门到精通第五篇:Scala环境搭建

一&#xff1a;安装步骤 1&#xff1a;scala安装 1&#xff1a;首先确保 JDK1.8 安装成功: 2&#xff1a;下载对应的 Scala 安装文件 scala-2.12.11.zip 3&#xff1a;解压 scala-2.12.11.zip 4&#xff1a;配置 Scala 的环境变量 在Windows上安装Scala_windows安装scala…

docker搭建代码审计平台sonarqube

docker搭建代码审计平台sonarqube 一、代码审计关注的质量指标二、静态分析技术分类三、sonarqube流程四、快速搭建sonarqube五、sonarqube scanner的安装和使用 一、代码审计关注的质量指标 代码坏味道 代码规范技术债评估 bug和漏洞代码重复度单测与集成 测试用例数量覆盖率…

node报错——解决Error: error:0308010C:digital envelope routines::unsupported——亲测可用

今天在打包vue2项目时&#xff0c;遇到一个报错&#xff1a; 最关键的代码如下&#xff1a; Error: error:0308010C:digital envelope routines::unsupportedat new Hash (node:internal/crypto/hash:80:19)百度后发现是node版本的问题。 在昨天我确实操作了一下node&…

Ansible——Playbook剧本

目录 一、Playbook概述 1.Playbook定义 2.Playbook组成 3.Playbook配置文件详解 4.运行Playbook 4.1Ansible-Playbook相关命令 4.2运行Playbook启动httpd服务 4.3变量的定义和引用 4.4指定远程主机sudo切换用户 4.5When——条件判断 4.6迭代 4.6.1创建文件夹 4.6.2…

[Linux][网络][TCP][四][流量控制][拥塞控制]详细讲解

目录 1.流量控制2.拥塞控制0.为什么要有拥塞控制&#xff0c;不是有流量控制么&#xff1f;1.什么是拥塞窗口&#xff1f;和发送窗口有什么关系呢&#xff1f;2.怎么知道当前网络是否出现了拥塞呢&#xff1f;3.拥塞控制有哪些算法&#xff1f;4.慢启动5.拥塞避免6.拥塞发生7.快…

劝退计算机?CS再过几年会没落!?

事实上&#xff0c;未来计算机不仅不会没落&#xff0c;国家还会大力发展 只不过大家认为的计算机就是什么Java web&#xff0c;真正的计算机行业是老美那样的&#xff0c;涉及到方方面面&#xff0c;比如&#xff1a; web&#xff0c;图形学&#xff0c;Linux系统开发&#…

2024DCIC海上风电出力预测Top方案 + 光伏发电出力高分方案学习记录

海上风电出力预测 赛题数据 海上风电出力预测的用电数据分为训练组和测试组两大类&#xff0c;主要包括风电场基本信息、气象变量数据和实际功率数据三个部分。风电场基本信息主要是各风电场的装机容量等信息&#xff1b;气象变量数据是从2022年1月到2024年1月份&#xff0c;…

Skywalking数据持久化与自定义链路追踪

学习本篇文章之前首先要了解一下Sky walking的基础知识 分布式链路追踪工具Skywalking详解 一&#xff0c;Sky walking数据持久化 Sky walking提供了es&#xff0c;MySQL等数据持久化方案&#xff0c;默认使用h2基于内存的数据库&#xff0c;重启之后数据即会丢失。 在实际工…

【Git】Git学习-16:git merge,且解决合并冲突

学习视频链接&#xff1a; 【GeekHour】一小时Git教程_哔哩哔哩_bilibili​编辑https://www.bilibili.com/video/BV1HM411377j/?vd_source95dda35ac10d1ae6785cc7006f365780 1 创建分支dev&#xff0c;并用merge合并master分支&#xff0c;使dev分支合并上master分支中内容为…

【学习笔记】HarmonyOS 4.0 鸿蒙Next 应用开发--安装开发环境

开发前的准备 首先先到官网去下载Devco Studio 这个开发工具&#xff0c;https://developer.harmonyos.com/cn/develop/deveco-studio/#download 提供了WIndows和Mac的开发环境&#xff0c;我自己是Windows的开发环境。 所以下载之后直接点击exe进行安装即可。 如果之前安装过…

Eplan带你做项目——如何实现项目的交付

前言 Eplan作为一款专业的电气工程设计软件&#xff0c;不仅在设计阶段为电气工程师提供了强大的绘图、计算、仿真等功能&#xff0c;还具备丰富的数据管理与交换能力&#xff0c;能够便捷、准确地导出软件设计、生产制造所需的数据&#xff0c;实现电气设计与软件设计、生产制…