深度学习笔记——循环神经网络RNN

大家好,这里是好评笔记,公主号:Goodnote,专栏文章私信限时Free。本文详细介绍面试过程中可能遇到的循环神经网络RNN知识点。

在这里插入图片描述


文章目录

  • 文本特征提取的方法
    • 1. 基础方法
      • 1.1 词袋模型(Bag of Words, BOW)
        • 工作原理
        • 举例
        • 优点
        • 缺点
      • 1.2 TF-IDF(Term Frequency-Inverse Document Frequency)
        • 工作原理
        • 举例
        • 优点
        • 缺点
      • 1.3 TF-IDF的改进——BM25
        • 优化
      • 1.4 N-Gram 模型
        • 工作原理
        • 举例
        • 优点
        • 缺点
    • 2. 词向量(Word Embeddings)
      • 2.1 Word2Vec
        • 工作原理
        • 举例
        • 优点
        • 缺点
      • 2.2 FastText
        • 工作原理
        • 优点
        • 缺点
    • 3. 预训练模型:BERT(Bidirectional Encoder Representations from Transformers)
        • 工作原理
        • 优点
        • 缺点
    • 总结
  • RNN
    • RNN 参数
    • RNN 的特点
    • RNN 的局限性
    • 前向传播的核心计算
      • 隐状态更新
      • 输出更新
    • RNN 的训练流程
      • 1. 输入准备
      • 2. 前向传播(Forward Pass)
      • 3. 计算损失(Loss Calculation)
      • 4. 反向传播(Backward Pass)
      • 5. 参数更新
    • RNN 的推理流程
      • 1. 输入准备
      • 2. 前向传播
      • 3. 生成输出
      • 4. 推理结束
    • RNN参数初始化
  • 全连接层在各神经网络模型中的作用
  • 历史文章
    • 机器学习
    • 深度学习

大家好,这里是Goodnote(好评笔记)。本文详细介绍面试过程中可能遇到的循环神经网络RNN、LSTM、GRU、Bi-RNN知识点。

文本特征提取的方法

1. 基础方法

1.1 词袋模型(Bag of Words, BOW)

词袋模型最简单的方法。它将文本表示为一个词频向量不考虑词语的顺序或上下文关系,只统计每个词在文本中出现的频率。

工作原理
  1. 构建词汇表:对整个语料库中的所有词汇建立一个词汇表(也称为词典)。每个文档中的每个词都与词汇表中的一个位置对应。
  2. 生成词频向量:对于每个文本(文档),生成一个与词汇表长度相同的向量。向量中每个元素表示该词在文档中出现的次数(或者是否出现,用二进制表示)。
举例

假设有两个句子:

  • 句子 1:猫 喜欢 鱼
  • 句子 2:狗 不 喜欢 鱼

词汇表 = [“猫”, “狗”, “喜欢”, “不”, “鱼”]

  • 句子 1 的词袋向量表示为:[1, 0, 1, 0, 1]
  • 句子 2 的词袋向量表示为:[0, 1, 1, 1, 1]
优点
  • 简单直观,易于实现,有效地表示词频信息。
缺点
  • 忽略词序:词袋模型无法捕捉词语的顺序,因此在语义表达上有局限。
  • 高维稀疏:对于大词汇表,词袋模型会生成非常长的特征向量,大多数元素为 0,容易导致稀疏矩阵,影响计算效率
  • 受到常见词的影响:常见词(如 “the”、“and” 等)可能在各类文档中频繁出现,但对语义贡献较少,词袋模型会受到这些高频词的影响,降低模型的效果。

1.2 TF-IDF(Term Frequency-Inverse Document Frequency)

TF-IDF 是对词袋模型的改进,它为词语赋予不同的权重,来衡量每个词在文档中的重要性。与词袋模型相比,TF-IDF 不仅考虑词频,还考虑词的普遍性,以避免常见词(如"the"、“and”)的影响

工作原理
  1. TF(词频):计算某个词在文档中出现的频率。
    T F ( t , d ) = 词 t 在文档 d 中的出现次数 文档 d 的总词数 TF(t,d)=\frac{词t在文档d中的出现次数}{文档d的总词数} TF(t,d)=文档d的总词数t在文档d中的出现次数
  2. IDF(逆文档频率):衡量词在整个语料库中的普遍性,出现频率越低的词权重越高。
    I D F ( t ) = log ⁡ ( N 1 + D F ( t ) ) IDF(t)=\log\left(\frac{N}{1 + DF(t)}\right) IDF(t)=log(1+DF(t)N)
    • 其中 N N N是文档总数, D F ( t ) DF(t) DF(t)是包含词 t t t的文档数。
  3. TF - IDF:将 T F TF TF I D F IDF IDF相乘,得到词在特定文档中的权重:
    T F − I D F ( t , d ) = T F ( t , d ) × I D F ( t ) TF - IDF(t,d)=TF(t,d)\times IDF(t) TFIDF(t,d)=TF(t,d)×IDF(t)
举例

对于句子“猫 喜欢 鱼”和“狗 不 喜欢 鱼”,假设 “喜欢” 出现在所有文档中,IDF 会给它较低的权重,而像 “猫”、“狗” 这样的词会有较高的 IDF 权重,因为它们只出现在一部分文档中。

优点
  • 更准确地反映词的重要性,避免了词袋模型中常见词占主导地位的情况。尤其适用于文本分类任务。
缺点
  • 稀疏矩阵:虽然词频的权重经过调整,但词汇表的大小仍然很大,容易产生稀疏矩阵问题
  • 忽略词序:仍然无法捕捉词语之间的顺序和上下文关系

1.3 TF-IDF的改进——BM25

BM25对TF和IDF进行加权,同时考虑文档长度对相关性的影响,使得对较短和较长文档的评分更加合理

BM25 的计算公式如下:

B M 25 ( q , d ) = ∑ t ∈ q I D F ( t ) ⋅ T F ( t , d ) ⋅ ( k 1 + 1 ) T F ( t , d ) + k 1 ⋅ ( 1 − b + b ⋅ ∣ d ∣ a v g d l ) BM25(q,d)=\sum_{t\in q}IDF(t)\cdot\frac{TF(t,d)\cdot(k_1 + 1)}{TF(t,d)+k_1\cdot(1 - b + b\cdot\frac{|d|}{avgdl})} BM25(q,d)=tqIDF(t)TF(t,d)+k1(1b+bavgdld)TF(t,d)(k1+1)
其中:

  • q q q 是查询, d d d 是文档, t t t 是查询中的词。
  • I D F IDF IDF是与 T F − I D F TF - IDF TFIDF相似的逆文档频率。
  • T F TF TF是词频。
  • k 1 k_1 k1 是调节词频饱和度的参数,通常取值范围为 [ 1.2 , 2.0 ] [1.2,2.0] [1.2,2.0]
  • b b b 是调节文档长度的参数,通常取值范围为 [ 0.0 , 1.0 ] [0.0,1.0] [0.0,1.0] b = 0.75 b = 0.75 b=0.75是一个常用的设置。
  • ∣ d ∣ |d| d是文档的长度(词数), a v g d l avgdl avgdl是语料库中文档的平均长度。

TF-IDF 中的 IDF(逆文档频率)使用 log ⁡ N d f ( t ) \log\frac{N}{df(t)} logdf(t)N来衡量词的普遍性。然而这种计算方式可能会导致在某些极端情况下(如 df(t) = 0 )出现不合理的结果。
BM25 对 IDF 进行了小改进,以提高在极端情况下的稳定性:

I D F ( t ) = log ⁡ N − d f ( t ) + 0.5 d f ( t ) + 0.5 IDF(t)=\log\frac{N - df(t)+ 0.5}{df(t)+ 0.5} IDF(t)=logdf(t)+0.5Ndf(t)+0.5

这种改进的 IDF 计算在文档数量较少或者某个词的出现频率极高时,能提供更合理的 IDF 值,增加了 BM25 的稳定性。

优化

相比于 TF-IDF,BM25 主要做了以下改进:

  • 非线性词频缩放:通过 k 1 k_1 k1 控制词频TF饱和 ,避免 TF 值无限增大导致的偏差。
  • 文档长度归一化:使用参数 b b b调整文档长度对评分的影响,防止长文档得分偏高。
  • 改进的 IDF 计算使用平滑后的 IDF 计算,保证在极端情况下的稳定性
  • 查询词频考虑:在评分中更合理地衡量查询中词频的影响,提高了对复杂查询的检索效果。

1.4 N-Gram 模型

N-Gram 模型是一种基于词袋模型扩展方法,它通过将词组作为特征,来捕捉词语的顺序信息

工作原理
  • N-Gram 是指在文本中提取连续的 n 个词组成的词组作为特征。当 n=1 时,即为 unigram(单词级别特征);当 n=2 时,即为 bigram(双词组特征);当 n=3 时,即为 trigram(词三元组特征)。

  • 在提取 N-Gram 时,模型不仅关注单个词,还捕捉到词与词之间的顺序和依赖关系。例如,2-Gram 模型会将句子分解为相邻的两词组合。

举例

对于句子“猫 喜欢 吃 鱼”,2-Gram 模型会提取出以下特征:

  • [“猫 喜欢”, “喜欢 吃”, “吃 鱼”]
优点
  • 捕捉到顺序和依赖关系,比单词级别的特征表达更丰富。n 越大,模型捕捉的上下文信息越多。
缺点
  • 维度膨胀:n 值越大,特征向量的维度会急剧增加,容易导致稀疏矩阵和计算复杂度升高。
  • 对长文本,N-Gram 模型可能会生成非常多的组合,计算资源消耗较大

2. 词向量(Word Embeddings)

词向量是现代 NLP 中的关键特征提取方法,能够捕捉词语的语义信息。常见的词向量方法包括 Word2Vec、GloVe、和 FastText。词向量的核心思想是将每个词表示为一个低维的、密集的向量,词向量之间的相似性能够反映词语的语义相似性

2.1 Word2Vec

Word2Vec 是一种使用浅层神经网络学习词向量的模型,由 Google 在 2013 年提出。它有两种模型架构:CBOW 和 Skip-gram。

工作原理
  • CBOW(Continuous Bag of Words)根据上下文中的词语来预测中心词。模型输入是上下文词语,输出是预测的中心词。
  • Skip-gram:与 CBOW 相反,它是根据中心词来预测上下文中的词语
举例

对于一个句子 “猫 喜欢 吃 鱼”,CBOW 会使用上下文 [“猫”, “吃”, “鱼”] 来预测 “喜欢”,而 Skip-gram 则会使用 “喜欢” 来预测上下文。

优点
  • 语义相似性:Word2Vec 生成的词向量能够捕捉词语之间的语义相似性。例如,“king” 和 “queen” 的词向量会非常相近。
  • 稠密向量:与词袋模型和 TF-IDF 生成的高维稀疏向量不同,Word2Vec 生成的词向量是低维的密集向量(如 100 维或 300 维),更加高效
缺点
  • 无法处理 OOV(未登录词):如果测试集中出现了训练集中未见过的词,Word2Vec 无法为其生成词向量。
  • 上下文无关:Word2Vec 生成的词向量是固定的,无法根据上下文变化来调整词向量

2.2 FastText

FastText 是 Facebook 提出的词向量方法,它是 Word2Vec 的改进版。FastText 通过将词分解为n-gram字符级别的子词,捕捉词的形态信息

工作原理
  • FastText 将词分解为多个字符 n-gram,然后对每个 n-gram 生成词向量。通过这种方式,FastText 可以捕捉到词语内部的形态信息,尤其对拼写错误或未登录词有较好的处理能力。
优点
  • 处理 OOV(未登录词):因为 FastText 基于子词生成词向量,它能够为未见过的词生成向量表示
  • 考虑词形信息:能够捕捉词的形态变化,例如词根、前缀、后缀等。
缺点
  • 计算复杂度较高:相比 Word2Vec,FastText 需要对每个词生成多个 n-gram,因此计算量更大。

3. 预训练模型:BERT(Bidirectional Encoder Representations from Transformers)

BERT 是一种基于 Transformer 架构的预训练语言模型,由 Google 于 2018 年提出。与传统的词向量方法不同,BERT 通过双向的 Transformer 网络,能够生成上下文相关的动态词向量。

工作原理
  • 双向Transformer:BERT 同时从词语的前后上下文学习词的表示,而不像传统的模型只从前向或后向学习。这样,BERT 能够捕捉到更丰富的语义信息。
  • 预训练任务
    1. 遮蔽语言模型(Masked Language Model, MLM):在训练时,BERT 会随机遮蔽部分词语,并要求模型预测这些词,从而让模型学到上下文的双向依赖关系。
    2. 下一句预测(Next Sentence Prediction, NSP):训练时,BERT 要预测两句话是否是连续的句子对,这让模型能够学习句子级别的关系。
优点
  • 上下文相关词向量:BERT 生成的词向量是上下文相关的。例如,“bank” 在句子 “I went to the bank” 和 “The river bank” 中会有不同的向量表示。
  • 强大的语义理解能力:BERT 在问答、阅读理解、文本分类等任务中表现非常好,能够捕捉到复杂的语义关系。
缺点
  • 计算资源需求大:BERT 是一个深层的 Transformer 模型,预训练和微调都需要大量的计算资源,训练时间较长。
  • 较慢的推理速度:由于模型较大,在实际应用中推理速度较慢,尤其在实时任务中。

BERT详细参考历史/后续文章:[深度学习笔记——GPT、BERT、T5]

总结

方法工作原理优点缺点适用场景
词袋模型(BOW)将文本表示为词频向量不考虑词序和上下文简单直观,易实现,能够有效表示词频信息。忽略词序,生成高维稀疏向量。文本分类、信息检索
TF-IDF基于词袋模型考虑词在文档中的频率以及整个语料库中的普遍性,赋予不同词权重反映词的重要性,避免常见词主导影响,适用于文本分类。生成稀疏矩阵,无法捕捉词序和上下文关系文本分类、关键词提取
BM25基于 TF-IDF 的改进,考虑词频、文档长度、词重要性等因素,以计算每个词对文档匹配的相关性得分。 非线性词频缩放、 文档长度归一化改进的 IDF 计算更好地反映词在文档中的相关性,更适合信息检索,适用于长文档,计算匹配更准确。对参数敏感,适用性依赖于超参数调优,不能捕捉上下文关系。信息检索、文档排名
N-Gram捕捉连续 n 个词作为特征考虑词序信息能捕捉词语的顺序和依赖关系,n 越大捕捉的上下文信息越多。维度膨胀,计算资源消耗大。语言模型、短文本分类
Word2Vec使用浅层神经网络学习词向量,有 CBOWSkip-gram 两种架构。词向量能捕捉语义相似性,生成低维稠密向量,效率高。无法处理未登录词(OOV),词向量上下文无关。词嵌入、相似度计算、文本分类
FastText词分解为字符 n-gram,生成词向量,捕捉词的形态信息。能处理未登录词,捕捉词形信息,适合拼写错误和变形词。计算复杂度高于 Word2Vec。词嵌入、拼写纠错、文本分类
BERT基于双向 Transformer,通过预训练生成上下文相关的词向量,支持 Masked Language Model 和 Next Sentence Prediction生成上下文相关词向量,语义理解强,适用于复杂 NLP 任务。需要大量计算资源,训练和推理时间长。问答系统、文本分类、阅读理解
  • 传统方法:如词袋模型、TF-IDF 和 N-Gram 易于实现,但无法捕捉语义和上下文信息。
  • 词向量方法:如 Word2Vec 和 FastText 通过词嵌入表示词语的语义关系,适合语义相似度计算、文本分类等任务。FastText 能够处理未登录词。
  • 预训练模型:如 BERT,能够生成上下文相关的动态词向量,适用于更复杂的自然语言处理任务,但对计算资源的要求更高。

RNN

循环神经网络(RNN,Recurrent Neural Network)是一种用于处理序列数据等具有顺序关系的数据的神经网络。与传统的前馈神经网络不同,RNN 具有循环连接 ,允许信息通过隐藏状态在序列的不同时间步之间传播。这种结构使得RNN非常适合处理时间序列、文本数据、语音信号等具有顺序关系的数据

RNN 参数

参数维度作用
输入权重矩阵 W x h W_{x h} Wxh ( d h i d d e n × d i n p u t ) (d_{hidden}\times d_{input}) (dhidden×dinput)将输入 x t x_t xt映射到隐藏状态,确定当前输入对隐藏层状态的影响。
隐藏状态权重矩阵 W h h W_{h h} Whh ( d h i d d e n × d h i d d e n ) (d_{hidden}\times d_{hidden}) (dhidden×dhidden)将前一时间步的隐藏状态 h t − 1 h_{t - 1} ht1传递到当前时间步 h t h_t ht,捕捉时间依赖关系。
输出权重矩阵 W h y W_{h y} Why ( d o u t p u t × d h i d d e n ) (d_{output}\times d_{hidden}) (doutput×dhidden)将隐藏状态 h t h_t ht映射为输出 y t y_t yt
隐藏层偏置向量 b h b_{h} bh ( d h i d d e n ) (d_{hidden}) (dhidden)增强隐藏层的灵活性,通过加偏置调整隐藏层的激活函数输出。
输出层偏置向量 b y b_{y} by ( d o u t p u t ) (d_{output}) (doutput)增强输出层的灵活性,通过加偏置调整输出层的结果。
激活函数用于隐藏层状态的非线性变换,常用 t a n h tanh tanh R e L U ReLU ReLU

其他相关参数:

  • 时间步(Time Steps):决定模型的循环次数,非可学习参数。
  • 损失函数(Loss Function):指导模型参数更新的依据,非可学习参数。
  • 学习率(Learning Rate):控制优化过程的步幅大小,超参数。

RNN 的特点

  1. 顺序处理:RNN 可以处理不同长度的输入序列,这是由于其内部结构允许将前一步的信息作为当前步的输入之一。
  2. 隐藏状态:RNN 具有隐藏状态,隐藏状态是前一个时间步的信息的压缩,并与当前输入一起决定下一时间步的输出。
  3. 权重共享:RNN 中每个时间步之间共享相同的网络权重( W h h W_{hh} Whh, W x h W_{xh} Wxh, W h y W_{hy} Why),减少了模型参数的数量,适合处理序列长度不同的问题。

RNN 的局限性

  1. 梯度消失与爆炸问题:在长序列处理中,由于反向传播算法在计算梯度时,RNN 容易出现梯度消失(gradient vanishing)或梯度爆炸(gradient exploding)的现象,导致模型难以学习长期依赖关系。
  2. 长时间依赖问题:RNN 处理长序列时,无法有效捕捉到前后相距较远的依赖关系,导致模型的性能下降。
  3. 并行化困难:由于 RNN 是逐时间步处理序列的,因此不容易并行化处理,这使得其训练时间较长。

前向传播的核心计算

在 RNN 中,当前时间步的输出不仅依赖于当前输入,还依赖于之前时间步的隐状态(hidden state)。隐状态是 RNN 中的一个内部存储器,它能够保存之前的时间步的信息,使得网络具备记忆能力。RNN 的计算公式如下:

隐状态更新

h t = f ( W h h h t − 1 + W x h x t + b h ) h_t=f(W_{h h}h_{t - 1}+W_{x h}x_t + b_{h}) ht=f(Whhht1+Wxhxt+bh)

  • h t h_t ht:时间步 t t t的隐藏状态,是通过上一时间步 t − 1 t - 1 t1的隐藏状态 h t − 1 h_{t - 1} ht1和当前的输入 x t x_t xt计算得到的。
  • W h h W_{h h} Whh:隐藏状态到隐藏状态的权重矩阵,用于表示时间步之间的状态传递。
  • W x h W_{x h} Wxh:输入到隐藏状态的权重矩阵,负责将当前输入 x t x_t xt映射到隐藏层。
  • b h b_{h} bh:隐藏层的偏置向量。
  • f f f:激活函数,通常使用 t a n h tanh tanh R e L U ReLU ReLU

输出更新

y t = g ( W h y h t + b y ) y_t = g(W_{hy}h_t + b_y) yt=g(Whyht+by)

  • y t y_t yt:时间步 t t t的输出。
  • W h y W_{hy} Why:隐藏状态到输出的权重矩阵。
  • b y b_y by:输出层的偏置。
  • g g g:输出层的激活函数,取决于具体的任务,如分类任务常用Softmax。

RNN 的训练流程

RNN 的训练主要包括以下步骤:

1. 输入准备

  • 输入数据:RNN处理的是序列数据,输入可以是时间序列、文本、语音等。输入通常表示为 X = [ x 1 , x 2 , … , x T ] X = [x_1, x_2, \ldots, x_T] X=[x1,x2,,xT],其中 x t x_t xt代表时间步 t t t时的输入。
  • 标签数据(监督学习):如果是监督学习任务,训练数据通常带有标签 Y = [ y 1 , y 2 , … , y T ] Y = [y_1, y_2, \ldots, y_T] Y=[y1,y2,,yT],表示每个时间步 t t t的目标输出。

2. 前向传播(Forward Pass)

逐个时间步执行前向传播,将输入数据逐步传递到隐藏层,计算每个时间步的隐藏状态和输出(上面的核心计算)。

  • 初始化隐藏状态:在时间步 t = 0 时,隐藏状态 h 0 h_0 h0通常初始化为 0 或随机值
  • 逐时间步的状态更新:对于每一个时间步 t ,计算当前时间步的隐藏状态和输出:

(1). 隐藏状态更新:
h t = f ( W h h h t − 1 + W x h x t + b h ) h_t = f(W_{hh}h_{t - 1}+W_{xh}x_t + b_h) ht=f(Whhht1+Wxhxt+bh)

  • h t h_t ht:时间步 t t t的隐藏状态,是通过上一时间步 t − 1 t - 1 t1的隐藏状态 h t − 1 h_{t - 1} ht1和当前的输入 x t x_t xt计算得到的。
  • W h h W_{hh} Whh:隐藏状态到隐藏状态的权重矩阵,用于表示时间步之间的状态传递。
  • W x h W_{xh} Wxh:输入到隐藏状态的权重矩阵,负责将当前输入 x t x_t xt映射到隐藏层。
  • b h b_h bh:隐藏层的偏置向量。
  • f f f:激活函数,通常使用tanh或ReLU。

(2). 输出更新:
y t = g ( W h y h t + b y ) y_t = g(W_{hy}h_t + b_y) yt=g(Whyht+by)

  • y t y_t yt:时间步 t t t的输出。
  • W h y W_{hy} Why:隐藏状态到输出的权重矩阵。
  • b y b_y by:输出层的偏置。
  • g g g:输出层的激活函数,取决于具体的任务,如分类任务常用Softmax。

3. 计算损失(Loss Calculation)

  • 损失函数:RNN 根据每个时间步的预测输出 y t y_t yt 和真实标签 y ^ t \hat{y}_t y^t计算损失。常见的损失函数包括:
    • 分类任务使用 交叉熵损失
    • 回归任务使用 均方误差(MSE)

总的损失是各个时间步损失的累加:
L = ∑ t = 1 T L o s s ( y t , y ^ t ) L=\sum_{t = 1}^{T}L_{oss}(y_t, \hat{y}_t) L=t=1TLoss(yt,y^t)

4. 反向传播(Backward Pass)

RNN 的反向传播主要通过**反向传播通过时间(Backpropagation Through Time, BPTT)**来更新权重。BPTT 是对标准反向传播算法的扩展,沿着时间轴进行梯度传递。

反向传播的核心步骤

  • 从时间步 T T T开始,逐步向前计算各个时间步的梯度,直至第一个时间步。
  • 在每个时间步,计算损失相对于输出 y t y_t yt、隐藏状态 h t h_t ht以及参数(权重和偏置)的梯度。
    • 对损失函数 L L L进行求导,得到每个时间步的梯度,并依次更新参数:
      W x h ← W x h − α ∂ L ∂ W x h W_{xh}\leftarrow W_{xh}-\alpha\frac{\partial L}{\partial W_{xh}} WxhWxhαWxhL
      W h h ← W h h − α ∂ L ∂ W h h W_{hh}\leftarrow W_{hh}-\alpha\frac{\partial L}{\partial W_{hh}} WhhWhhαWhhL
      W h y ← W h y − α ∂ L ∂ W h y W_{hy}\leftarrow W_{hy}-\alpha\frac{\partial L}{\partial W_{hy}} WhyWhyαWhyL
    • 其中 α \alpha α是学习率。
  • 梯度消失问题:BPTT 在长时间序列上可能出现梯度消失或爆炸问题,尤其是当梯度逐步传递时,这限制了 RNN 捕捉长时依赖的能力。

5. 参数更新

  • 根据反向传播得到的梯度更新网络中的权重参数,如 W x h W_{xh} Wxh, W h h W_{hh} Whh, W h y W_{hy} Why,完成当前批次的训练。
  • 继续执行下一个批次的训练,直至完成所有训练数据的迭代。

RNN 的推理流程

RNN 的推理(inference)过程与训练过程中的前向传播类似,但没有反向传播和参数更新,主要是用于生成输出或进行预测。

1. 输入准备

  • 输入序列数据 X = [ x 1 , x 2 , . . . , x T ] X = [x_1, x_2, ..., x_T] X=[x1,x2,...,xT]
  • 不需要提供标签数据,因为推理阶段是无监督的,RNN 根据输入数据生成输出。

2. 前向传播

推理阶段执行与训练相同的前向传播过程:

  1. 初始化隐藏状态
    • 与训练时一样,隐藏状态 h 0 h_0 h0初始化为 0 或随机值。
  2. 逐时间步的状态更新
  • 对每个时间步 t t t执行前向传播,计算隐藏状态 h t h_t ht和输出 y t y_t yt
    h t = f ( W x h x t + W h h h t − 1 + b h ) h_t = f(W_{xh}x_t + W_{hh}h_{t - 1}+b_h) ht=f(Wxhxt+Whhht1+bh)
    y t = g ( W h y h t + b y ) y_t = g(W_{hy}h_t + b_y) yt=g(Whyht+by)

3. 生成输出

  • 推理过程中,RNN 根据输入数据在每个时间步生成相应的输出 y t y_t yt
  • 在序列生成任务(如文本生成)中,RNN 的输出 y t y_t yt 可以作为下一个时间步的输入 x t + 1 x_{t+1} xt+1,从而生成一个完整的输出序列。

4. 推理结束

  • 推理结束时,RNN 已经生成了完整的输出序列或预测结果。

RNN参数初始化

RNN 的参数初始化策略会影响训练过程的稳定性和收敛速度,以下是不同参数的初始化方法概述:

参数初始化方法适用场景
输入权重 ( W x h W_{xh} Wxh )Xavier 初始化、He 初始化、标准正态分布用于将输入映射到隐藏状态,适合不同激活函数场景
隐藏状态权重 ( W h h W_{hh} Whh )正交初始化、Xavier 初始化、He 初始化隐藏状态循环权重,正交初始化适合处理梯度问题
输出权重 ( W h y W_{hy} Why )Xavier 初始化、He 初始化用于将隐藏状态映射到输出层,取决于任务类型
偏置项 ( b h , b y b_h, b_y bh,by )零初始化、小随机数、遗忘门的偏置可初始化为较大正值(LSTM/GRU)偏置项通常为零初始化,特殊情况下设定固定值

通过使用合适的初始化方法,可以显著提高 RNN 的收敛性和模型的训练效果,尤其在处理长序列时,正交初始化Xavier 初始化能帮助缓解梯度消失和梯度爆炸问题。

全连接层在各神经网络模型中的作用

全连接层(Fully Connected Layer, FC Layer)广泛应用于分类任务或回归任务的最后阶段。全连接层的主要作用是将上层提取的特征转换为具体的决策结果或输出。它在不同类型的神经网络模型中具有不同的作用。以下是全连接层在各类神经网络模型中的作用详细解释:

  1. MLP中,全连接层是主要计算单元,完成输入到输出的映射
  2. CNN中,全连接层主要用于整合局部特征并生成分类结果
  3. RNNLSTM中,全连接层将时间序列特征映射为输出结果
  4. Transformer模型中,全连接层参与注意力机制和输出映射
  5. GAN中,全连接层用于潜在空间和图像特征之间的映射
  6. 自编码器中,全连接层用于特征压缩和重构
  7. 注意力机制中,全连接层用于计算注意力得分并变换上下文向量

无论在哪种神经网络中,全连接层的核心作用都是将前一层的特征进一步映射到目标空间,形成最后的输出或决策。

历史文章

机器学习

机器学习笔记——损失函数、代价函数和KL散度
机器学习笔记——特征工程、正则化、强化学习
机器学习笔记——30种常见机器学习算法简要汇总
机器学习笔记——感知机、多层感知机(MLP)、支持向量机(SVM)
机器学习笔记——KNN(K-Nearest Neighbors,K 近邻算法)
机器学习笔记——朴素贝叶斯算法
机器学习笔记——决策树
机器学习笔记——集成学习、Bagging(随机森林)、Boosting(AdaBoost、GBDT、XGBoost、LightGBM)、Stacking
机器学习笔记——Boosting中常用算法(GBDT、XGBoost、LightGBM)迭代路径
机器学习笔记——聚类算法(Kmeans、GMM-使用EM优化)
机器学习笔记——降维

深度学习

深度学习笔记——优化算法、激活函数
深度学习——归一化、正则化
深度学习笔记——前向传播与反向传播、神经网络(前馈神经网络与反馈神经网络)、常见算法概要汇总
深度学习笔记——卷积神经网络CNN

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/67474.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Selenium工具使用Python 语言实现下拉框定位操作

🍅 点击文末小卡片,免费获取软件测试全套资料,资料在手,涨薪更快 我们通常遇到的下拉框有显性的下拉框和隐性的下拉框;有的下拉框还可以进行单选或多选操作,在selenium中如何实现下拉框的定位通常使用selec…

使用 Continue 插件时,发现调用外部地址

https://us.i.posthog.com/e/?ip1&_1737025525924&ver1.163.0&compressiongzip-js 看是一个帮助改善产品的网址。估计类似某推广流量监控的插件工具吧。网上没用查到其他说明,可能国内使用不多的原因。 但是发送的数据看不出来是个什么内容。 我用来搜…

【PyQt】图像处理系统

[toc]pyqt实现图像处理系统 图像处理系统 1.创建阴影去除ui文件 2.阴影去除代码 1.创建阴影去除ui文件 UI文件效果图: 1.1QT Desiger设置组件 1.两个Pushbutton按钮 2.两个label来显示图像 3.Text Browser来显示输出信息 1.2布局的设置 1.先不使用任何La…

【Idea】编译Spring源码 read timeout 问题

Idea现在是大家工作中用的比较多的开发工具,尤其是做java开发的,那么做java开发,了解spring框架源码是提高自己技能水平的一个方式,所以会从spring 官网下载源码,导入到 Idea 工具并编译,但是发现build的时…

Linux 音视频入门到实战专栏(视频篇)视频编解码 MPP

文章目录 一、MPP 介绍二、获取和编译RKMPP库三、视频解码四、视频编码 沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇将介绍如何调用alsa api来进行音频数据的播放和录制。 一、MPP 介绍 瑞芯微提供的媒体处理软件平台…

爬虫后的数据处理与使用(使用篇--实现分类预测)

()紧接上文,在完成基本的数据处理后,接下来就是正常的使用了。当然怎么用,确实需要好好思考一下~ 上文:爬虫后的数据处理与使用(处理篇) 前言: 一般来说,我…

RabbitMQ--延迟队列

(一)延迟队列 1.概念 延迟队列是一种特殊的队列,消息被发送后,消费者并不会立刻拿到消息,而是等待一段时间后,消费者才可以从这个队列中拿到消息进行消费 2.应用场景 延迟队列的应用场景很多,…

flutter开发-figma交互设计图可以转换为flutter源代码-如何将设计图转换为flutter源代码-优雅草央千澈

flutter开发-figma交互设计图可以转换为flutter源代码-如何将设计图转换为flutter源代码-优雅草央千澈 开发背景 可能大家听过过蓝湖可以转ui设计图为vue.js,react native代码,那么请问听说过将figma的设计图转换为flutter源代码吗?本文优雅草央千澈带…

当设置dialog中有el-table时,并设置el-table区域的滚动,看到el-table中多了一条横线

问题:当设置dialog中有el-table时,并设置el-table区域的滚动,看到el-table中多了一条横线; 原因:el-table有一个before的伪元素作为表格的下边框下,初始的时候已设置,在滚动的时候并没有重新设置…

代理模式实现

一、概念:代理模式属于结构型设计模式。客户端不能直接访问一个对象,可以通过代理的第三者来间接访问该对象,代理对象控制着对于原对象的访问,并允许在客户端访问对象的前后进行一些扩展和处理;这种设置模式称为代理模…

windows 搭建flutter环境,开发windows程序

环境安装配置: 下载flutter sdk https://docs.flutter.dev/get-started/install/windows 下载到本地后,随便找个地方解压,然后配置下系统环境变量 编译windows程序本地需要安装vs2019或更新的开发环境 主要就这2步安装后就可以了&#xff0…

Redis系列之底层数据结构字典Dict

Redis系列之底层数据结构字典Dict Dict数据结构 Dict是Redis数据结构中使用最为频繁的复合型数据结构,本质上是一个哈希表 查看redis6.0版本的源码,链接:https://github.com/redis/redis/blob/6.0/src/dict.h 哈希表的结构定义&#xff1…

【Azure 架构师学习笔记】- Azure Function (2) --实操1

本文属于【Azure 架构师学习笔记】系列。 本文属于【Azure Function 】系列。 接上文【Azure 架构师学习笔记】- Azure Function (1) --环境搭建和背景介绍 前言 上一文介绍了环境搭建,接下来就在本地环境下使用一下。 环境准备 这里我下载了最新的VS studio&…

【NextJS】PostgreSQL 遇上 Prisma ORM

NextJS 数据库 之 遇上Prisma ORM 前言一、环境要求二、概念介绍1、Prisma Schema Language(PSL) 结构描述语言1.1 概念1.2 组成1.2.1 Data Source 数据源1.2.2 Generators 生成器1.2.3 Data Model Definition 数据模型定义字段(数据)类型和约束关系&…

左神算法基础提升--3

文章目录 Manacher 算法经典算法Manacher算法原理 单调栈或单调队列 Manacher 算法 经典算法 在每学习Manacher算法之前我们可能会使用一种比较经典暴力的算法:遍历str字符串,将字符串中的每个字符作为对称点,向两边扩散找到回文字段&#x…

浅谈操作系统与初识Linux

一、Linux操作系统的出现 1.1操作系统的出现以及相关的四个要素 1.2最早出现的操作系统及其创始人 起初,IBM为了让计算机可以以更低技术成本进行使用,以此来售卖计算机; 为计算机搭载上了Unix操作系统,Unix由肯汤普森用汇编语…

ElasticSearch下

DSL查询 叶子查询:在特定字段里查询特定值,属于简单查询,很少单独使用复合查询:以逻辑方式组合多个叶子查询或更改叶子查询的行为方式 在查询后还可以对查询结果做处理: 排序:按照1个或多个字段做排序分页…

java根据模板导出word,并在word中插入echarts相关统计图片以及表格

引入依赖创建word模板创建ftl模板文件保存的ftl可能会出现占位符分割的问题,需要处理将ftl文件中的图片的Base64删除,并使用占位符代替插入表格,并指定表格的位置在图片下方 Echarts转图片根据模板生成word文档DocUtil导出word文档 生成的wor…

链式前向星的写法

【图论02】动画说图的三种保存方式 降低理解门槛 邻接表 链式前向星 邻接矩阵_哔哩哔哩_bilibili 杭电ACM刘老师-算法入门培训-第12讲-拓扑排序及链式前向星_哔哩哔哩_bilibili 图论003链式前向星_哔哩哔哩_bilibili(链式前向星的遍历) head数组的下标…

想品客老师的第一天:值类型使用

前面两章的摘要 ECMAscript(也就是ES)是JavaScript的一个标准,就像c的c11和c99一样,几把的一年出一套标准 freeze()是一个对象方法,表示锁定、固定一个对象不可改变(因为const对于标量不可变,…