DI-engine强化学习入门(十又二分之一)如何使用RNN——数据处理、隐藏状态、Burn-in

一、数据处理
用于训练 RNN 的 mini-batch 数据不同于通常的数据。 这些数据通常应按时间序列排列。 对于 DI-engine, 这个处理是在 collector 阶段完成的。 用户需要在配置文件中指定 learn_unroll_len 以确保序列数据的长度与算法匹配。 对于大多数情况, learn_unroll_len 应该等于 RNN 的历史长度(a.k.a 时间序列长度),但在某些情况下并非如此。比如,在 r2d2 中, 我们使用burn-in操作, 序列长度等于 learn_unroll_len + burnin_step 。 这里将在下一节中具体解释。

什么是数据处理?
数据处理指的是为循环神经网络(RNN)训练准备时间序列数据的过程。这个过程包括将收集到的数据组织成适当格式的小批量(mini-batches),这些批量数据将用于网络的训练。这一步骤通常发生在DI-engine的collector阶段,也就是数据收集和预处理发生的地方。用户需要在配置文件中指定 learn_unroll_len 以确保序列数据的长度与算法匹配。 对于大多数情况, learn_unroll_len 应该等于 RNN 的历史长度(a.k.a 时间序列长度),但在某些情况下并非如此。比如,在 r2d2 中, 我们使用burn-in操作, 序列长度等于 learn_unroll_len + burnin_step 。例如,如果你设置 learn_unroll_len = 10 和 burnin_step = 5,那么 RNN 实际接收的输入序列长度将是 15:前 5 步为 burn-in(用于预热隐藏状态),接下来的 10 步作为学习的一部分。这样设置可以帮助 RNN 在计算梯度和进行权重更新时,有一个更加准确的隐藏状态作为起点。
部分名词解释

  • mini-batches:在机器学习中,特别是在训练神经网络时,数据一般被分成小的批次进行处理,这些批次被称为 “mini-batch”。一个 mini-batch 包含了一组样本,这组样本用于执行单次迭代的前向传播和反向传播,以更新网络的权重。使用 mini-batches 而不是单个样本或整个数据集(后者称为 “batch” 或 “full-batch”)可以平衡计算效率和内存限制,有助于提高学习的稳定性和收敛速度。
  • collector阶段:在 DI-engine中,collector 阶段是指环境与智能体交互并收集经验数据的过程。在这个阶段,智能体根据其当前的策略执行操作,环境则返回新的状态、奖励和其他可能的信息,如是否达到终止状态。收集到的数据(经常被称为经验或转换)随后被用于训练智能体的模型,例如对策略或价值函数进行更新。

为什么要进行数据处理:

  1. 保持时间依赖性:RNN的核心优势是处理具有时间序列依赖性的数据,比如语言、视频帧、股票价格等。正确的数据处理确保了这些时间依赖性在训练数据中得以保留,使得模型能够学习到数据中的序列特征。
  2. 提高学习效率:通过将数据划分为与模型期望的序列长度匹配的批次,可以提高模型学习的效率。这样做可以确保网络在每次更新时都接收到足够的上下文信息。
  3. 适配算法要求:不同的RNN算法可能需要不同形式的输入数据。例如,标准的RNN只需要过去的信息,而一些变体如LSTM或GRU可能会处理更长的序列。特定的算法,如R2D2,还可能需要额外的步骤(如burn-in),以便更好地初始化网络状态。
  4. 处理不规则长度:在现实世界的数据集中,序列长度往往是不规则的。数据处理确保了每个mini-batch都有统一的序列长度,这通常通过截断过长的序列或填充过短的序列来实现。
  5. 优化内存和计算资源:通过将数据组织成具有固定时间步长的批次,可以更有效地利用GPU等计算资源,因为这些资源在处理固定大小的数据时通常更高效。
  6. 稳定学习过程:特别是在强化学习中,使用如n-step返回或经验回放的技术,可以帮助模型从环境反馈中学习,并减少方差,从而稳定学习过程。

如何进行数据处理

def _get_train_sample(self, data: list) -> Union[None, List[Any]]:    data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)    return get_train_sample(data, self._sequence_len)

 代码段 def _get_train_sample(self, data: list) 是一个方法,它的作用是从收集到的数据中提取用于训练 RNN 的样本。这个方法会在两个步骤中处理数据:

  • N步返回计算(get_nstep_return_data): 这个函数接受原始的经验数据,然后计算所谓的 N 步返回值。N 步返回是一个在强化学习中用于临时差分(Temporal Difference, TD)学习的概念,它考虑了从当前状态开始的未来 N 步的累积奖励。计算这个值需要使用折现因子 gamma。这个步骤的目的是为了让智能体学习如何根据当前的行动预测未来的奖励,这是强化学习中价值函数估计的重要部分。
  • 训练样本获取(get_train_sample): 在得到 N 步返回值之后,这个函数进一步处理数据以生成训练样本。具体地,它会根据 self._sequence_len(即时间序列长度或者 RNN 的历史长度)来选择数据序列。这意味着每个训练样本将是一个具有 self._sequence_len 长度的数据序列,这对于训练 RNN 来说是必要的,因为 RNN 需要一定长度的历史来维护其内部状态(或记忆)。

有关这两个数据处理功能的工作流程见下图:

二、初始化隐藏状态 (Hidden State)
RNN用于处理具有时间依赖性的信息。RNN的隐藏状态(Hidden State)是其记忆的一部分,它能够捕捉到前一时间步长的信息。这些信息对于预测下一个动作或状态非常关键。在此上下文中,初始化RNN的隐藏状态是一个重要的步骤,它确保了RNN在开始新的数据批次处理时具有正确的起始状态。
策略的 _learn_model 需要初始化 RNN。这些隐藏状态来自 _collect_model 保存的 prev_state。 用户需要通过 _process_transition 函数将这些状态添加到 _learn_model 输入数据字典中。 

def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:    transition = {        'obs': obs,        'action': model_output['action'],        'prev_state': model_output['prev_state'], # add ``prev_state`` key here        'reward': timestep.reward,        'done': timestep.done,    }    return transition

点击DI-engine强化学习入门(十又二分之一)如何使用RNN——数据处理、隐藏状态、Burn-in - 古月居 可查看全文

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/600877.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

UDP多播

1 、多播的概念 多播,也被称为组播,是一种网络通信模式,其中数据的传输和接收仅在同一组内进行。多播具有以下特点: 多播地址标识一组接口:多播使用特定的多播地址,该地址标识一组接收数据的接口。发送到多…

【知识点随笔分享 | 第十篇】快速介绍一致性Hash算法

前言: 在分布式系统中,数据的分布和负载均衡是至关重要的问题。一致性哈希算法是一种解决这些挑战的有效工具,它在分布式存储、负载均衡和缓存系统等领域得到了广泛应用。 随着互联网规模的不断扩大,传统的哈希算法在面对大规模…

大历史下的 tcp:一个松弛的传输协议

如果 tcp 是一个相对松弛的协议,会发生什么。 所谓松弛感,意思是它允许 “漏洞”,允许可靠传输的不封闭,大致就是:“不求 100% 可靠,只要 90%(或多或少) 可靠,另外 10% 的错误可检测到” or “…

STM32:配置EXTI—对射式红外传感器计次

文章目录 1、中断1.2 中断系统1.3 中断执行流程 2、STM32中断2.2EXTI(外部中断)2.3 EXTI 的基本结构2.4 AFIO复用IO口 3、NVIC基本结构3.2 NVIC优先级分组 4、配置EXTI4.2 AFIO 库函数4.3 EXTI 库函数4.4 NVIC 库函数4.5 配置EXTI的步骤4.6 初始化EXTI 1…

渗透测试流程

一、攻击流程 信息收集阶段→漏洞分析阶段→攻击阶段→后渗透阶段 二、信息收集 1、收集内容: IP资源:真实IP获取、旁站信息收集、C段主机信息收集域名发现:子域名信息收集、子域名枚举发现子域名、搜索引擎发现子域名、第三方聚合服务器发…

PyQt 入门

Qt hello - 专注于Qt的技术分享平台 Python体系下GUI框架也多了去了,PyQt算是比较受欢迎的一个。如果对Qt框架熟悉,那掌握这套框架是很简单的。 一,安装 1.PyQt5 pip3 install PyQt5 2.Designer UI工具 pip3 install PyQt5-tools 3.UI…

MFC DLL注入失败一些错误总结

使用cheat Engine为MFC窗口程序注入DLL时一定要注意,被注入的exe程序和注入的DLL 的绝对路径中一定不要带有中文字符,否则会遇到各种各样的奇怪错误,如下所示: 以下是dll绝对路径中均含有中文字符,会报错误&#xff…

【BUUCTF】Crypto_RSA(铜锁/openssl使用系列)

【BUUCTF】Crypto_RSA(铜锁/openssl使用系列) 1、题目 在一次RSA密钥对生成中,假设p473398607161,q4511491,e17 求解出d作为flga提交 2、解析 RSA加密过程: 1)选择素数:选择两个不…

python中一些莫名其妙的异常

目录 一、字符串中空格\xa0二、文件写入为空问题三、Counter对NAN空值的统计问题 一、字符串中空格\xa0 对于文本中的一些空格,原始状态时显示为普通“空格”(其实是latin1编码字符),但是经过split()操作后,这些latin…

Linux cmake 初窥【2】

1.开发背景 基于上一篇的基础上,再次升级 2.开发需求 基于 cmake 指定源文件目录可以是多个文件夹,多层目录 3.开发环境 ubuntu 20.04 cmake-3.23.1 4.实现步骤 4.1 准备源码文件 工程目录如下 顶层脚本 compile.sh 负责执行 cmake 操作&#xff0…

类加载器aa

一,关系图及各自管辖范围 (不赘述) 二,查看关系 package com.jiazai;public class Main {public static void main(String[] args) {ClassLoader appClassLoader ClassLoader.getSystemClassLoader();//默认System.out.println…

赋能企业数字化转型 - 易点易动固定资产系统与飞书实现协同管理

在当前瞬息万变的商业环境下,企业如何借助信息化手段提升管理效率,已经成为摆在各行各业面前的紧迫课题。作为企业数字化转型的重要一环,固定资产管理的信息化建设更是不容忽视。 易点易动作为国内领先的企业资产管理服务商,凭借其全方位的固定资产管理解决方案,助力众多企业实…

SQL注入实例(sqli-labs/less-1)

初始网页 从网页可知传递的参数名为 id,并且为数字类型 1、得知数据表有多少列 1.1 使用联合查询查找列数(效率低) http://localhost/sqli-labs-master/Less-1/?id1 union select 1,2 -- 1.2 使用order by查找列数(效率高&…

重学java 30.API 1.String字符串

于是,虚度的光阴换来了模糊 —— 24.5.8 一、String基础知识以及创建 1.String介绍 1.概述 String类代表字符串 2.特点 a.Java程序中的所有字符串字面值(如“abc”)都作为此类的实例(对象)实现 凡是带双引号的,都是String的对象 String s "abc&q…

【JVM】类加载机制及双亲委派模型

目录 一、类加载过程 1. 加载 2. 连接 a. 验证 b. 准备 c. 解析 3. 初始化 二、双亲委派模型 类加载器 双亲委派模型的工作过程 双亲委派模型的优点 一、类加载过程 JVM的类加载机制是JVM在运行时,将 .class 文件加载到内存中并转换为Java类的过程。它…

第8篇:创建Nios II工程之读取Switch的值<一>

Q:本期我们再添加一个PIO组件设为输入,创建Nios II工程读取输入值显示在LED上。 A:在前2期创建的控制LED工程的Platform Designer系统基础上再添加一个PIO核,参数设置为18位和单向输入模式,表示DE2-115开发板上的18个…

rmallox勒索病毒肆虐,如何保护网络安全?

rmallox勒索病毒与网络安全的关系可以从以下几个方面来阐述: 一、rmallox勒索病毒的特性 rmallox勒索病毒是一种极具破坏性的计算机病毒,它具有多个显著特性,这些特性使得该病毒对网络安全构成了严重威胁。具体来说,rmallox病毒具…

六西格玛项目的核心要素:理论学习、实践应用与项目经验

许多朋友担心,没有项目经验是否就意味着无法考取六西格玛证书。针对这一疑问,张驰咨询为大家详细解答。 首先,需要明确的是,六西格玛项目不仅仅是一种管理工具或方法,更是一种追求卓越、持续改进的思维方式。它强调通…

Java反序列化-CC11链

前言 这条链子的主要作用是为了可以在 Commons-Collections 3.2.1 版本中使用,而且还是无数组的方法。这条链子适用于 Shiro550漏洞 CC11链子流程 CC2 CC6的结合体 CC2 这是CC2的流程图,我们取的是后面那三个链子,但是由于CC2 只能在 c…

2024年第九届数维杯数学建模A题思路分享

文章目录 1 赛题思路2 比赛日期和时间3 竞赛信息4 建模常见问题类型4.1 分类问题4.2 优化问题4.3 预测问题4.4 评价问题 5 建模资料 1 赛题思路 (赛题出来以后第一时间在CSDN分享) https://blog.csdn.net/dc_sinor?typeblog 2 比赛日期和时间 报名截止时间:2024…
最新文章