RNN-T Training,RNN-T模型训练详解——语音信号处理学习(三)(选修三)

参考文献:

Speech Recognition (option) - RNN-T Training哔哩哔哩bilibili

2020 年 3月 新番 李宏毅 人类语言处理 独家笔记 Alignment Train - 8 - 知乎 (zhihu.com)

本次省略所有引用论文

目录

一、如何将 Alignment 概率加和

对齐方式概率如何计算

概率加和计算原理

概率加和计算方式

二、RNN-T 的模型训练

模型训练思路

偏微分计算-1-展开变形

偏微分计算-2-第一个偏微分求解

偏微分计算-3-第二个偏微分求解

三、RNN-T 的模型测试(推理/解码)

目标函数的近似

实际操作

四、总结——LAS、CTC、RNN-T 模型比较


 

一、如何将 Alignment 概率加和

对齐方式概率如何计算
  • 想要知道如何将所有的对齐方式的概率相加,我们就需要知道一条对齐方式的概率是怎么计算的。由于 HMM、CTC 和 RNN-T 的概率计算方式在本质上是一样的,因此我们下面的实验与计算全都基于 RNN-T。

  • 我们将一个 alignment 通过状态图的方式表现出,实际上,需要计算这个 alignment 的概率,只需要将所有位置的概率进行连乘就行。比如h = ∅c∅∅a∅t∅∅。P(h|X) 就等于每个位置的发射概率和转移概率的连乘,也就是第一次输出 ∅ 的概率,乘以给定 ∅,输出 c 的概率,乘以给定 ∅c,输出 ∅ 的概率……

  • 我们将整个过程落实到实际操作当中去看看。首先,我们需要回顾 RNN-T 的架构。RNN-T 的一大神奇之处在于,它单独训练了一个 RNN,将已输出的 token 当作输入,去影响 RNN-T 接下来的输出。

  • 我们使用 h 表示经过 encoder 的声学特征向量,图中上半部分蓝色的方块表示单独训练的 RNN。在起始阶段,没有产生任何的 token,我们就输入一个 <BOS (Begin of Sentence)>,让它产生 l0。我们把编码产生的h1,与l0一起输入给解码器,让它产生一个概率 p_{1,0}。这里的下标表示的意思为:输入第一个声学特征向量(1),没产生任何 token 时(0),RNN-T 产生出的概率分布。

  • 那么 ∅ 落在句首的概率就可以计算了,也就是从 p_{1,0} 中采样出 ∅ 的概率。

  • 接下来我们需要计算有了 ∅ 以后产生 c 的概率。值得一提的是,刚刚产生的 ∅ 对我们的 RNN 并没有什么影响。因为 RNN 只吃产生的 token。不过,产生的 ∅ 会对 Encoder 产生影响,这代表当前的隐藏层向量已经被读完了,没啥价值了,需要切换下一个向量。

  • 因此在下一步计算过程中,我们将 h2 和之前的 l0 一起输入编码器,输出得到概率 p_{2,0}。那么产生 c 的概率也就好算了,就是从概率分布 p_{2,0} 中采样得到 c 的概率。

  • 接下来该计算有 ∅c 后产生 ∅ 的概率。由于我们刚刚输出了 token c,RNN 就会受到影响,输入 token c 以后产生 l1,而 Encoder 不变,因为它没有看到 ∅,所以不需要更换向量。因此我们最终将 l1 和 h2 丢给解码器,得到新的概率分布 p_{2,1},从中我们可以得到我们需要的概率。

  • 按照上面的过程,我们一直反复下去,最终我们就可以算出所有需要的概率,我们将所有的概率相乘,就是我们最终想得到的这一个 alignment 的概率。

概率加和计算原理
  • 那么我们是怎么计算所有对齐方式的概率加和的呢?这就要归功于我们刚刚所说的 RNN-T 的神奇之处:使用单独的 RNN 来表示 token 之间的关系,而忽略 ∅ 的影响。这在后续的训练中大有帮助。

  • 我们看下图,实际上下图中的每一个格子都可以对应到一个概率分布,由于刚刚定义的概率分布的下标分别表示读到的声学特征向量以及已输出的token数量,那么格子对应的概率分布就显而易见。比如图中给出了 p_{4,2} 的概率分布的格子,这就表示我们已经读到 x4,并且前边已经输出了两个 token ca。

  • 而对于 p_{4,2},需要计算之后产生 ∅ 或者 t 的概率都可以从中得到。

  • 神奇之处在于,每一个格子代表的概率分布实际上都是固定的,它们不会受到如何走到当前格子的走法的影响,因为就其输入来说,无论怎么走,输入的都是 h4 和 l2。

概率加和计算方式
  • HMM 采用的是 forward 和 backward 算法来计算所有对齐方式的概率分数。而实际上,RNN-T 和 HMM 所用的方法也是一模一样的。

  • 我们新定义一个变量 α_{i,j},其表示已经读取了 i 个的声学特征向量,输出 j 个 token 的所有对齐方式的概率分数之和。比如 α_{4,2},就是由读取 4 个声学特征向量,输出 2 个 token 的所有 Alignment 的分数相加之和。

  • 那么 α_{4,2} 有没有方法计算呢?有。我们可以通过 α_{4,1} 和 α_{3,2} 进行计算。事实上,在变成 α_{4,2} 之前,有两种可能,一种是读了 4 个声学特征向量,输出一个 token 了,准备输出下一个;还有一种可能是已经读了 3 个声学特征向量,产生了两个 token,准备读取下一个声学特征向量(产生 ∅ )。假设我们的 α_{4,1} 和 α_{3,2} 已经计算出来了,则结合之前定义的 p_{i, j},我们可以有:


    \alpha_{4,2} = \alpha_{4,1}p_{4,1}(a) + \alpha_{3,2}p_{3,2}(\varnothing)
     

  • 也就是 α_{4,1} 代表的所有 alignment 乘上之后产生 token a 的概率,加上 α_{3,2} 代表的所有 alignment 乘上之后产生 ∅ 的概率。

  • 根据上面的式子,我们就可以得到一个基于动态规划的递推式,这样就能从左上角开始,一直算出最后一个格子的分数总和了。

二、RNN-T 的模型训练

我们刚刚讲述了如何去穷举所有的对齐方式进行概率总分计算,不过这一切都需要基于我们已经有了训练好的 RNN-T 的基础上。所以本节我们来了解一下如何训练 RNN-T。

模型训练思路
  • 首先我们要明确我们的训练目标。假设 Y_hat 是我们的 Ground Truth,也就是正确的识别文本,那么也就是说我们希望学习到一组参数 θ,使得 Y_hat 的概率越大越好:


    \theta^* = \arg\max_\theta \log P_\theta(\widehat{Y}|X)
     

  • 那么我们如何 optimize 这个函数呢?当然是使用梯度下降法进行。所以我们下一个要解决的问题就是如何求取函数对参数求偏微分。


    \frac{\partial P(\widehat{Y}|X)}{\partial \theta} = ?
     

偏微分计算-1-展开变形
  • 我们将概率求解函数展开,它就像我们上面所说,是由一堆对齐概率加和而成的。而每一个对齐概率又是由某些概率相乘而得到的。

  • 哪些概率?是由从起点到终点的某条路径上的每一个箭头所代表的概率,也就是在某个状态下产生某一个 token 的概率相乘得到的。因此,由这一系列所有的箭头相乘,然后相加,就最终得到了我们的概率。

  • 而这些产生某一个 token 的概率又受到模型参数 θ 的影响,目标概率又受到这些小概率的影响,所以我们可以先计算某个小概率对 θ 的偏微分,然后再计算目标概率对这些小概率的偏微分,和之前的相乘,然后再计算下一个小概率对 θ 的偏微分,乘上目标概率对小概率的偏微分……以此类推,最终将所有结果加和,就可以得到我们的目标式子,即:

偏微分计算-2-第一个偏微分求解
  • 好的,经过上面的变形,现在压力给到了如何计算小概率,即每个箭头代表的概率,对参数 θ 的偏微分。

  • 我们以 p_{4,1}(a) 对 θ 的偏微分的计算作为例子。


    \frac{\partial p_{4,1}(a)}{\partial \theta} = ?
     

  • 其计算方式,或者说训练方式其实和普通模型一样,还是采用经典的 BPTT(Backpropagation Through Time,反向传播通过时间)时序的反向传播。一开始最右边的结果计算和标签的损失,反向传播传到编码器,再传到上面的解码器 RNN。

偏微分计算-3-第二个偏微分求解
  • 第一个偏微分式子可以解了,下面压力来到了第二个偏微分式子上,也就是目标概率对每个箭头概率的偏微分。我们以计算对 p_{4,1} 的偏微分为例,公式如下:


    \frac{\partial P(\widehat{Y}|X)}{\partial p_{4,1}(a)} = ?
     

  • 首先,我们要把包含 p_{4,1} 的对齐方式和不包含 p_{4,1} 的对齐方式分开算:


    P(\widehat{Y}|X) = \sum_{h\space with\space p_{4,1}(a)}P(h|x) + \sum_{h\space without\space p_{4,1}(a)}P(h|x)
     

  • 由于第二项是没有 p_{4,1} 的,因此当做偏微分的时候,第二项就消失了。而第一项我们知道,是由很多箭头概率相乘相加得到的。既然有 p_{4,1},我们就可以将它提取出来,如下图

  • 这样偏微分后就只剩提取出 p_{4,1} 之后的 other 了。并且我们还可以把 other 写成 P/p,然后再把这个 1/p 提出来,就可以了。

  • 所以,问题就被转化成了计算带有 p_{4,1} 的对齐方式的概率之和。我们应该如何计算呢?此时,我们可以再引入另一个辅助变量 β_{i,j},它与α_{i,j}很像,它表示从第 i 个声学特征开始且输出到第 j 个 token,在当前位置到结束的所有对齐方式分数之和。

  • β_{4,2} 如图所示,它表示已经产生了4个声学特征和输出两个 token 的情况下,在当前位置走到结尾为止的所有路径的分数总和。β_{i,j} 刚好是 α_{i,j} 的反过来。前面 α_{i,j} 对应着 HMM 的正向传播算法,这里 β_{i,j} 对应着 HMM 的反向传播算法。通过动态规划算法,于是我们有递推式,β_{i,j} = β_{i+1,j}p_{i,j} + β_{i,j+1}p_{i,j}。

  • 有了递推式以后,我们就可以将所有点的 β 值全部计算出来。而有了 α 和 β 的值以后,我们就可以计算带有 p_{4,1} 的对齐方式的概率之和了。 我们看下图:所有从起始位置到 (4,1) 的候选对齐路径的分数和 α_{4,1} 乘上 p_{4,1}(a) 后,再乘上所有从位置 (4,2) 到终点的候选对齐路径的分数和 β_{4,2},这就是所有包含 p_{4,1}(a) 的分数总和。

  • 我们将式子带入,并乘上系数,p_{4,1}(a) 得到约分,最终的偏微分结果就是 α_{4,1}β_{4,2}。

  • 因此带入最终的式子后,就能计算全部候选对齐的得分对模型参数的梯度。然后反向传播更新模型参数进行训练。我们就可以进行正常训练了。

三、RNN-T 的模型测试(推理/解码)

目标函数的近似
  • 训练好模型了以后,我们就可以进行模型的使用了。我们的目标函数如下,也就是找到一个 Y,使得 P of Y given X 达到最大值,这个 Y 就是模型语音辨识的结果。


    Y^* = \arg\max_Y \log{P(Y|X)}
     

  • 这实际上不是一个简单的问题。理想状态下我们需要穷举所有的 Y,来计算概率,然而别说穷举不容易实现,就连计算概率都是大量的对齐方式概率相加之和,就更不容易了。

  • 所以我们采用一些近似估计的方法,首先就是对 “将所有对齐方式概率加和作为分数” 这一条进行近似。我们不把所有的候选对齐分数加起来,而是选取每一个Y中,分数最高的那个对齐方式的概率作为分数。不过,这个近似需要基于这样一个事实:概率最大的对齐方式要比其他的对齐方式要大很多。那事实真的是这样吗?(老师:反正我信了)

  • 我们将概率最大的对齐方式记作 h*,然后用 h* 进行 inverse,找到其对应的 Y*,就是最终解码的结果啦。计算 P of h given X 的方式我们在之前都有讲过,这里在图中呈现回顾一下,不再用文字赘述。

实际操作
  • 实际中要怎么找一个概率最高的对齐方式呢?RNN-T 每一个时间步都会跑出一个概率分布。我们把每个概率分布中,概率最大的那个 token 取出来,就是 h* 的一个近似。不过,每次都取概率分布中概率最大的,不见得会使得整个对齐方式的概率是最大的(原因距离可以看束搜索 Beam Search 讲解)。不过没有关系,我们照样可以采用 Beam Search 的方法来得到更准确的结果。

四、总结——LAS、CTC、RNN-T 模型比较

  • 我们看下面这张表。在解码部分,LAS 和 RNN-T 会考虑前面的时序对当前时序的影响。而 CTC 并不会考虑之前的时间步已经生成出来的token。所以 LAS 和 RNN-T 在解码部分是相对比较强的。

  • 在对齐部分,CTC 和 RNN-T 都是需要考虑对齐的。而因为中间的注意力层,LAS不用显式地考虑对齐,而是采用 soft alignment,使用注意力机制来找出语音和文字之间的关系。

  • 在训练部分,LAS 只需要直接训练就行,而 CTC 和 RNN-T 则需要将所有的对齐方式概率相加,比较麻烦。

  • 对于语音识别模型,在线识别(实时识别)也是一个很重要的功能,使用者一边说一边就能跑出语音辨识的结果。对于 LAS,由于注意力一次要看全部,也就是需要等语者说完才能进行推理,因此 LAS 不能在线识别。而 CTC 和 RNN-T 都是可以的,之前有说过,Pixel 的语音助手就是使用 RNN-T 进行语音识别的。

 

 

课程也告一段落啦,我之后会将所有的语音学习内容整合成一个pdf,欢迎大家下载~如果觉得csdn上下载不方便,也可以找我私聊联系~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/181782.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PyQt6第一个程序HelloWorld实现

锋哥原创的PyQt6视频教程&#xff1a; 2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili2024版 PyQt6 Python桌面开发 视频教程(无废话版) 玩命更新中~共计12条视频&#xff0c;包括&#xff1a;2024版 PyQt6 Python桌面开发 视频教程(无废话版…

three.js球体实现

作者&#xff1a;baekpcyyy&#x1f41f; 使用three.js渲染出可以调节大小的立方体 1.搭建开发环境 1.首先新建文件夹用vsc打开项目终端 2.执行npm init -y 创建配置文件夹 3.执行npm i three0.152 安装three.js依赖 4.执行npm I vite -D 安装 Vite 作为开发依赖 5.根…

网络协议系列:TCP三次握手,四次挥手的全过程,为什么需要三次握手,四次挥手

TCP三次握手&#xff0c;四次挥手的全过程&#xff0c;为什么需要三次握手&#xff0c;四次挥手 一. TCP三次握手&#xff0c;四次挥手的全过程&#xff0c;为什么需要三次握手&#xff0c;四次挥手前言TCP协议的介绍三次握手三次握手流程&#xff1a;1. A 的 TCP 向 B 发送 连…

【嵌入式Linux开发一路清障-连载04】虚拟机VirtualBox7.0安装Ubuntu22.04后挂载Windows平台共享文件夹

虚拟机安装Ubuntu22.04后挂载Windows平台共享文件夹 障碍07-虚拟机VirtualBox7.0完装完Ubuntu22.04后&#xff0c;无法成功挂载Windows平台中共享文件夹&#xff0c;无法访问电脑中的各类重要文件&#xff0c;我该怎么办&#xff1f;一、问题的模样&#xff1a;VirtualBox7.0设…

【算法训练营】算法分析实验(递归实现斐波那契+插入排序、分治思想实现归并排序+快排)附代码+解析

![0 &#x1f308;欢迎来到算法专栏 &#x1f64b;&#x1f3fe;‍♀️作者介绍&#xff1a;前PLA队员 目前是一名普通本科大三的软件工程专业学生 &#x1f30f;IP坐标&#xff1a;湖北武汉 &#x1f349; 目前技术栈&#xff1a;C/C、Linux系统编程、计算机网络、数据结构、M…

SpringBoot : ch08 自动配置原理

前言 在现代的Java开发中&#xff0c;Spring Boot已经成为了一个备受欢迎的框架。它以其简化开发流程、提高效率和强大的功能而闻名&#xff0c;使得开发人员能够更加专注于业务逻辑的实现而不必过多地关注配置问题。 然而&#xff0c;你是否曾经好奇过Spring Boot是如何做到…

白盒测试 接口测试 自动化测试

一、什么是白盒测试 白盒测试是一种测试策略&#xff0c;这种策略允许我们检查程序的内部结构&#xff0c;对程序的逻辑结构进行检查&#xff0c;从中获取测试数据。白盒测试的对象基本是源程序&#xff0c;所以它又称为结构测试或逻辑驱动测试&#xff0c;白盒测试方法一般分为…

Python编程基础:数据类型和运算符解析

想要学习Python编程语言&#xff1f;本文将为您介绍Python中常见的数据类型和运算符&#xff0c;为您打下坚实的编程基础。了解不同的数据类型和运算符&#xff0c;掌握它们之间的配合方式&#xff0c;让您能够更轻松地进行数据处理和计算任务。无论您是初学者还是有一定经验的…

电能量数据采集终端是电表采集器吗?

随着科技的发展和能源管理的日益精细化&#xff0c;电能量数据采集终端——电表采集器在保障电力系统稳定运行、实现节能减排等方面发挥着越来越重要的作用。下面&#xff0c;小编来为大家全面介绍电表采集器的功能、应用场景及其在我国能源领域的价值。 一、电表采集器的定义与…

第二十章Java博客

如果一次只完成一件事情&#xff0c;很容易实现。但现实生活中&#xff0c;很多事情都是同时进行的。Java中为了模拟这种状态&#xff0c;引入了线程机制。简单地说&#xff0c;当程序同时完成多件事情时&#xff0c;就是所谓的多线程。多线程应用相当广泛&#xff0c;使用多线…

【Java学习笔记】 74 - 本章作业

1.验证电子邮件格式是否合法 规定电子邮件规则为 1.只能有一个 2. 前面是用户名,可以是a-z A-Z 0-9 _ - 字符 3. 后面是域名&#xff0c;并且域名只能是英文字母&#xff0c;比如sohu.com或者tsinghua.org.cn 4.写出对应的正则表达式&#xff0c;验证输入的字符串是否为满…

浏览器触发下载Excel文件-Java实现

目录 1:引入maven 2:代码实现 3.导出通讯录信息到Excel文件 4.生成并下载Excel文件部分解释 1:引入maven 添加依赖:首先,在你的项目中添加EasyExcel库的依赖。你可以在项目的构建文件(如Maven的pom.xml)中添加以下依赖项:<dependency><groupId>com.alib…

Python基础语法之学习input()函数

Python基础语法之学习input函数 前言一、代码二、效果 前言 一、代码 # 默认是字符串类型 number input("请输入一个数字&#xff1a;") print("输入的数字是",number)二、效果 没有人可以阻止你成为自己想成为的人&#xff0c;只有你自己才能放弃梦想。…

【LeetCode刷题笔记】160.相交链表

创作不易&#xff0c;本篇文章如果帮助到了你&#xff0c;还请点赞 关注支持一下♡>&#x16966;<)!! 主页专栏有更多知识&#xff0c;如有疑问欢迎大家指正讨论&#xff0c;共同进步&#xff01; 更多算法知识专栏&#xff1a;算法分析&#x1f525; 给大家跳段街舞感谢…

Spring(2):Spring事务管理机制

Spring事务管理高层抽象主要包括3个接口&#xff0c;Spring的事务主要是由他们共同完成的&#xff1a; PlatformTransactionManager&#xff1a;事务管理器—主要用于平台相关事务的管理。TransactionDefinition&#xff1a; 事务定义信息(隔离、传播、超时、只读)—通过配置如…

LeetCode算法题解(动态规划)|LeetCode198. 打家劫舍、LeetCode213. 打家劫舍 II、LeetCode337. 打家劫舍 III

一、LeetCode198. 打家劫舍 题目链接&#xff1a;198. 打家劫舍 题目描述&#xff1a; 你是一个专业的小偷&#xff0c;计划偷窃沿街的房屋。每间房内都藏有一定的现金&#xff0c;影响你偷窃的唯一制约因素就是相邻的房屋装有相互连通的防盗系统&#xff0c;如果两间相邻的…

哪里可了解低代码数据可视化开发平台?

如果想要提升办公协作效率&#xff0c;可以用什么样的平台助力实现这一目标&#xff1f;其实&#xff0c;随着市场竞争的日益加剧&#xff0c;低代码技术平台的应用价值也逐渐凸显出来&#xff0c;其可视化、易操作、灵活便利等优势特点&#xff0c;是很多中大型企业倾向于使用…

Nature子刊最新研究:Hi-C宏基因组揭示土壤-噬菌体-宿主相互作用

土壤中有大量的噬菌体。然而&#xff0c;大多数宿主未知&#xff0c;无法获得其基因组特征。2023年11月23日&#xff0c;最新发表于《Nature communications》期刊题为“Hi-C metagenome sequencing reveals soil phage–host interactions”的文章&#xff0c;通过高通量染色体…

2023 最新版navicat 下载与安装 步骤及演示 (图示版)

2023 最新版navicat 下载与安装 步骤演示 -图示版 1. 下载Navicat2 .安装navicat 160 博主 默语带您 Go to New World. ✍ 个人主页—— 默语 的博客&#x1f466;&#x1f3fb; 《java 面试题大全》 &#x1f369;惟余辈才疏学浅&#xff0c;临摹之作或有不妥之处&#xff0c…

时钟控制模块

时钟控制模块 锁相环电路简单的理解 https://www.bilibili.com/video/BV1yS4y1n7vV/?spm_id_from333.337.search-card.all.click&vd_source712cdb762d6632543eeeadb56271617a一 时钟是从哪里来的 时钟晶振&#xff08;32.768KHz&#xff09;供给RTC使用在IMX6ULL的T16和…