从代码学习深度学习 - Bahdanau注意力 PyTorch版

文章目录

    • 1. 前言
      • 为什么选择Bahdanau注意力
      • 本文目标与预备知识
    • 2. Bahdanau注意力机制概述
      • 注意力机制简述
      • 加性注意力与乘性注意力对比
      • Bahdanau注意力的数学原理与流程图
        • 数学原理
        • 流程图
        • 可视化与直观理解
    • 3. 数据准备与预处理
      • 数据集简介
      • 数据加载与预处理
        • 1. 读取数据集
        • 2. 预处理文本
        • 3. 词元化
      • 词表构建
      • 序列截断与填充
      • 构建张量与有效长度
      • 创建数据迭代器
      • 数据准备的关键点
      • 与Bahdanau注意力的关联
      • 总结
    • 4. 模型组件搭建
      • 4.1 总体架构概述
      • 4.2 编码器(Encoder)
      • 4.3 解码器(Decoder)
      • 4.4 Bahdanau注意力机制(AdditiveAttention)
      • 4.5 屏蔽机制(sequence_mask 和 masked_softmax)
        • sequence_mask
        • masked_softmax
      • 4.6 数据加载与模型整合
      • 4.7 关键点与优势
      • 4.8 可视化与验证
      • 4.9 总结
  • 5. 训练流程实现
    • 5.1 数据加载
    • 5.2 模型定义
    • 5.3 训练过程
      • 5.3.1 权重初始化
      • 5.3.2 优化器和损失函数
      • 5.3.3 训练循环
      • 5.3.4 训练结果输出
    • 5.4 预测与评估
      • 5.4.1 预测实现
      • 5.4.2 BLEU 分数评估
      • 5.4.3 注意力权重可视化
    • 5.5 实现亮点
    • 5.6 总结
  • 6. 模型推理与预测
    • 6.1 序列翻译预测函数详解
      • 6.1.1 函数定义与参数
      • 6.1.2 预处理阶段
      • 6.1.3 编码器前向传播
      • 6.1.4 解码器逐时间步预测
      • 6.1.5 输出处理
      • 6.1.6 实现亮点
      • 6.1.7 潜在改进方向
    • 6.2 BLEU 评估指标解释与实现
      • 6.2.1 BLEU 指标概述
      • 6.2.2 函数定义与参数
      • 6.2.3 计算逻辑与实现
        • 6.2.3.1 预处理
        • 6.2.3.2 长度惩罚
        • 6.2.3.3 n-gram 精确度
        • 6.2.3.4 返回结果
      • 6.2.4 BLEU 的意义与局限性
      • 6.2.5 实现亮点
      • 6.2.6 潜在改进方向
    • 6.3 总结
  • 7. 可视化注意力权重
    • 7.1 注意力热图绘制与分析
      • 7.1.1 代码实现
      • 7.1.2 热图分析
      • 7.1.3 可视化效果
    • 7.2 模型关注词元的可解释性展示
      • 7.2.1 可解释性意义
      • 7.2.2 可视化案例
      • 7.2.3 提升可解释性的方法
    • 7.3 实现亮点
  • 8. 总结
    • 8.1 Bahdanau 注意力的实现经验分享
    • 8.2 PyTorch 中模块化建模的优势
    • 8.3 下一步可以探索的方向
    • 8.4 总结


完整代码:下载连接

1. 前言

为什么选择Bahdanau注意力

在深度学习领域,尤其是自然语言处理(NLP)任务中,序列到序列(Seq2Seq)模型是许多应用的核心,如机器翻译、文本摘要和对话系统等。传统的Seq2Seq模型依赖于编码器-解码器架构,通过编码器将输入序列压缩为固定长度的上下文向量,再由解码器生成输出序列。然而,这种方法在处理长序列时往往面临信息丢失的问题,上下文向量难以捕捉输入序列的全部细节。

Bahdanau注意力机制(Bahdanau et al., 2014)通过引入动态的上下文选择机制,显著提升了模型对输入序列的利用效率。它允许解码器在生成每个输出时,动态地关注输入序列的不同部分,而非依赖单一的上下文向量。这种机制不仅提高了翻译质量,还为后续的注意力机制(如Transformer)奠定了基础。选择Bahdanau注意力作为学习对象,是因为它直观地展示了注意力机制的核心思想,同时在实现上具有足够的复杂度,能够帮助我们深入理解深度学习的建模过程。

此外,PyTorch作为一个灵活且直观的深度学习框架,非常适合实现和调试复杂的模型结构。通过本文的代码分析,我们将以Bahdanau注意力为核心,结合PyTorch的模块化编程,探索Seq2Seq模型的完整实现流程,为进一步学习Transformer等高级模型打下坚实基础。

本文目标与预备知识

本文的目标是通过剖析一个基于PyTorch实现的Bahdanau注意力Seq2Seq模型,帮助读者从代码层面理解深度学习模型的设计与实现。我们将从数据预处理、模型组件搭建、训练流程到推理与可视化,逐步拆解每个环节的核心代码,揭示Bahdanau注意力机制的运作原理,并提供直观的解释和可视化结果。同时,通过模块化代码的分析,我们将展示如何在PyTorch中高效地组织复杂项目。

为了更好地理解本文内容,建议读者具备以下预备知识:

  • Python编程基础:熟悉Python语法、面向对象编程以及PyTorch的基本操作(如张量操作、模块定义和自动求导)。
  • 深度学习基础:了解神经网络的基本概念(如前向传播、反向传播、损失函数和优化器),以及循环神经网络(RNN)或门控循环单元(GRU)的工作原理。
  • NLP基础:对词嵌入(Word Embedding)、序列建模和机器翻译任务有初步了解。
  • 数学基础:熟悉线性代数(如矩阵运算)、概率论(softmax函数)以及基本的优化理论。

如果你对上述内容有所欠缺,不必担心!本文将尽量通过代码注释和直观的解释,降低学习门槛,让你能够通过实践逐步掌握Bahdanau注意力的精髓。

接下来,我们将进入Bahdanau注意力机制的详细分析,从理论到代码实现,带你一步步走进深度学习的精彩世界!

2. Bahdanau注意力机制概述

注意力机制简述

在深度学习领域,特别是在序列到序列(Seq2Seq)任务如机器翻译中,注意力机制(Attention Mechanism)是一种革命性的技术,用于解决传统Seq2Seq模型在处理长序列时的瓶颈问题。传统Seq2Seq模型通过编码器将输入序列压缩为一个固定长度的上下文向量,再由解码器基于此向量生成输出序列。然而,当输入序列较长时,固定上下文向量难以充分捕捉所有输入信息,导致信息丢失和翻译质量下降。

注意力机制的提出,允许模型在生成输出时动态地关注输入序列的不同部分,而不是依赖单一的上下文向量。具体来说,注意力机制通过计算输入序列每个位置与当前解码步骤的相关性(注意力权重),为解码器提供一个加权的上下文向量。这种动态聚焦的方式极大地提高了模型对长序列的建模能力,并增强了生成结果的可解释性。

Bahdanau注意力(也称为加性注意力,Additive Attention)是注意力机制的早期代表之一,首次提出于2014年的论文《Neural Machine Translation by Jointly Learning to Align and Translate》。它通过引入一个可学习的对齐模型,动态计算输入序列与输出序列之间的关联,被广泛应用于机器翻译等任务。

加性注意力与乘性注意力对比

注意力机制根据计算注意力得分(Attention Score)的方式不同,可以分为加性注意力和乘性注意力(Dot-Product Attention)两大类:

  • 加性注意力(Additive Attention)

    • 计算方式:Bahdanau注意力属于加性注意力,其核心是通过将查询(Query)和键(Key)映射到相同的隐藏维度后,相加并通过非线性激活函数(如tanh)处理,最后通过线性变换得到注意力得分。

    • 数学表达式
      score ( q , k i ) = w v ⊤ ⋅ tanh ⁡ ( W q q + W k k i ) \text{score}(q, k_i) = w_v^\top \cdot \tanh(W_q q + W_k k_i) score(q,ki)=wvtanh(Wqq+Wkki)
      其中,(q)是查询向量,(k_i)是键向量,(W_q)和(W_k)是可学习的权重矩阵,(w_v)是用于计算最终得分的权重向量。

    • 特点

      • 计算复杂度较高,因为需要对查询和键进行线性变换并相加。
      • 适合查询和键维度不同的场景,因为它通过映射统一了维度。
      • 在Bahdanau注意力中,注意力得分经过softmax归一化,生成权重,用于加权求和值(Value)向量,形成上下文向量。
    • 代码体现
      在提供的代码中,AdditiveAttention类实现了这一过程:

      queries, keys = self.W_q(queries), self.W_k(keys)
      features = queries.unsqueeze(2) + keys.unsqueeze(1)
      features = torch.tanh(features)
      scores = self.w_v(features).squeeze(-1)
      self.attention_weights = masked_softmax(scores, valid_lens)
      
  • 乘性注意力(Dot-Product Attention)

    • 计算方式:乘性注意力通过查询和键的点积直接计算得分,通常在查询和键维度相同时使用。
    • 数学表达式
      score ( q , k i ) = q ⊤ k i \text{score}(q, k_i) = q^\top k_i score(q,ki)=qki
      或其缩放版本(Scaled Dot-Product Attention):
      score ( q , k i ) = q ⊤ k i d k \text{score}(q, k_i) = \frac{q^\top k_i}{\sqrt{d_k}} score(q,ki)=dk qki
      其中, d k d_k dk是键的维度,用于防止点积过大。
    • 特点
      • 计算效率较高,适合大规模并行计算,广泛用于Transformer模型。
      • 假设查询和键具有相同的维度,否则需要额外的映射。
      • 对于高维输入,可能需要缩放以稳定训练。
    • 适用场景
      乘性注意力在Transformer等现代模型中更为常见,但在Bahdanau注意力提出时,RNN-based的Seq2Seq模型更倾向于使用加性注意力,因为它能更好地处理变长序列和不同维度的输入。

对比总结

  • 加性注意力(Bahdanau)通过显式的非线性变换,灵活性更高,适合早期RNN模型,但计算开销较大。
  • 乘性注意力(Luong或Transformer)计算简单,效率高,适合现代GPU加速的场景,但在维度不匹配时需要额外处理。
  • Bahdanau注意力作为加性注意力的代表,为后续的乘性注意力机制奠定了理论基础。

Bahdanau注意力的数学原理与流程图

数学原理

Bahdanau注意力的核心目标是为解码器的每个时间步生成一个上下文向量,该向量是输入序列隐藏状态的加权和,权重由注意力得分决定。其工作流程可以分解为以下步骤:

  1. 输入

    • 编码器输出:编码器(通常为GRU或LSTM)处理输入序列,生成隐藏状态序列 ( h 1 , h 2 , … , h T h_1, h_2, \dots, h_T h1,h2,,hT ),其中 $T $ 是输入序列长度,每个 h i h_i hi是键(Key)和值(Value)。
    • 解码器状态:解码器在时间步 t t t的隐藏状态 s t s_t st,作为查询(Query)。
  2. 注意力得分计算

    • 对于解码器状态 s t s_t st 和每个编码器隐藏状态 h i h_i hi,计算注意力得分:
      e t , i = w v ⊤ ⋅ tanh ⁡ ( W s s t + W h h i ) e_{t,i} = w_v^\top \cdot \tanh(W_s s_t + W_h h_i) et,i=wvtanh(Wsst+Whhi)
      其中, W s W_s Ws W h W_h Wh是将查询和键映射到隐藏维度的权重矩阵, w v w_v wv是用于生成标量得分的权重向量。
  3. 注意力权重归一化

    • 将得分通过softmax函数归一化为权重:
      $\alpha_{t,i} = \frac{\exp(e_{t,i})}{\sum_{j=1}^T \exp(e_{t,j})}
      $
      其中, α t , i \alpha_{t,i} αt,i表示时间步 t t t 对输入位置 i i i的关注程度,满足 ∑ i α t , i = 1 \sum_i \alpha_{t,i} = 1 iαt,i=1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/76619.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

19【动手学深度学习】卷积层

1. 从全连接到卷积 2. 图像卷积 3. 图形卷积代码 互相关操作 import torch from torch import nn from d2l import torch as d2ldef corr2d(X, K):"""计算2维互相关运算"""h, w K.shapeY torch.zeros((X.shape[0]-h1, X.shape[1]-w 1))for …

Linux xorg-server 解析(一)- 编译安装Debug版本的xorg-server

一:下载代码 1. 配置源,以Ubuntu24.04 为例( /etc/apt/sources.list.d/ubuntu.sources): 2. apt source xserver-xorg-core 二:编译代码 1. sudo apt build-dep ./ 2. DEB_BUILD_OPTIONS="nostrip" DEB_CFLAGS_SET="-g -O0" dpkg-buildpac…

大模型SFT用chat版还是base版 SFT后灾难性遗忘怎么办

大模型SFT用chat版还是base版 进行 SFT 时,基座模型选用 Chat 还是 Base 模型? 选 Base 还是 Chat 模型,首先先熟悉 Base 和 Chat 是两种不同的大模型,它们在训练数据、应用场景和模型特性上有所区别。 在训练数据方面&#xf…

【图像生成之21】融合了Transformer与Diffusion,Meta新作Transfusion实现图像与语言大一统

论文:Transfusion: Predict the Next Token and Diffuse Images with One Multi-Modal Model 地址:https://arxiv.org/abs/2408.11039 类型:理解与生成 Transfusion模型‌是一种将Transformer和Diffusion模型融合的多模态模型,旨…

动态多目标进化算法:基于知识转移和维护功能的动态多目标进化算法(KTM-DMOEA)求解CEC2018(DF1-DF14)

一、KTM-DMOEA介绍 在实际工程和现实生活中,许多优化问题具有动态性和多目标性,即目标函数会随着环境的变化而改变,并且存在多个相互冲突的目标。传统的多目标进化算法在处理这类动态问题时面临着一些挑战,如收敛速度慢、难以跟踪…

部署NFS版StorageClass(存储类)

部署NFS版StorageClass存储类 NFS版PV动态供给StorageClass(存储类)基于NFS实现动态供应下载NFS存储类资源清单部署NFS服务器为StorageClass(存储类)创建所需的RBAC部署nfs-client-provisioner的deployment创建StorageClass使用存储类创建PVC NFS版PV动态供给StorageClass(存储…

Vue使用el-table给每一行数据上面增加一行自定义合并行

// template <template><el-table:data"flattenedData":span-method"objectSpanMethod"borderclass"custom-header-table"style"width: 100%"ref"myTable":height"60vh"><!-- 订单详情列 -->&l…

vue项目使用html2canvas和jspdf将页面导出成PDF文件

一、需求&#xff1a; 页面上某一部分内容需要生成pdf并下载 二、技术方案&#xff1a; 使用html2canvas和jsPDF插件 三、js代码 // 页面导出为pdf格式 import html2Canvas from "html2canvas"; import jsPDF from "jspdf"; import { uploadImg } f…

大模型LLM表格报表分析:markitdown文件转markdown,大模型markdown统计分析

整体流程&#xff1a;用markitdown工具文件转markdown&#xff0c;然后大模型markdown统计分析 markitdown https://github.com/microsoft/markitdown 在线体验&#xff1a;https://huggingface.co/spaces/AlirezaF138/Markitdown 安装&#xff1a; pip install markitdown…

Linux 第二讲 --- 基础指令(二)

前言 这是基础指令的第二部分&#xff0c;但是该部分的讲解会大量使用到基础指令&#xff08;一&#xff09;的内容&#xff0c;为了大家的观感&#xff0c;如果对Linux的一些基本指令不了解的话&#xff0c;可以先看基础指令&#xff08;一&#xff09;&#xff0c;同样的本文…

python格式化字符串漏洞

什么是python格式化字符串漏洞 python中&#xff0c;存在几种格式化字符串的方式&#xff0c;然而当我们使用的方式不正确的时候&#xff0c;即格式化的字符串能够被我们控制时&#xff0c;就会导致一些严重的问题&#xff0c;比如获取敏感信息 python常见的格式化字符串 百…

LLaMA-Factory双卡4090微调DeepSeek-R1-Distill-Qwen-14B医学领域

unsloth单卡4090微调DeepSeek-R1-Distill-Qwen-14B医学领域后&#xff0c;跑通一下多卡微调。 1&#xff0c;准备2卡RTX 4090 2&#xff0c;准备数据集 医学领域 pip install -U huggingface_hub export HF_ENDPOINThttps://hf-mirror.com huggingface-cli download --resum…

React Hooks: useRef,useCallback,useMemo用法详解

1. useRef&#xff08;保存引用值&#xff09; useRef 通常用于保存“不会参与 UI 渲染&#xff0c;但生命周期要长”的对象引用&#xff0c;比如获取 DOM、保存定时器 ID、WebSocket等。 新建useRef.js组件&#xff0c;写入代码&#xff1a; import React, { useRef, useSt…

Spring AI 结构化输出详解

一、Spring AI 结构化输出的定义与核心概念 Spring AI 提供了一种强大的功能&#xff0c;允许开发者将大型语言模型&#xff08;LLM&#xff09;的输出从字符串转换为结构化格式&#xff0c;如 JSON、XML 或 Java 对象。这种结构化输出能力对于依赖可靠解析输出值的下游应用程…

THM Billing

1. 信息收集 (1) Nmap 扫描 bashnmap -T4 -sC -sV -p- 10.10.189.216 输出关键信息&#xff1a; PORT STATE SERVICE VERSION22/tcp open ssh OpenSSH 8.4p1 Debian 5deb11u380/tcp open http Apache 2.4.56 (Debian) # MagnusBilling 应用3306/tcp open …

布局决定终局:基于开源AI大模型、AI智能名片与S2B2C商城小程序的战略反推思维

摘要&#xff1a;在商业竞争日益激烈的当下&#xff0c;布局与终局预判成为企业成功的关键要素。本文探讨了布局与终局预判的智慧性&#xff0c;强调其虽无法做到百分之百准确&#xff0c;但能显著提升思考能力。终局思维作为重要战略工具&#xff0c;并非一步到位的战略部署&a…

贪心算法 day08(加油站+单调递增的数字+坏了的计算机)

目录 1.加油站 2.单调递增的数字 3.坏了的计算器 1.加油站 链接&#xff1a;. - 力扣&#xff08;LeetCode&#xff09; 思路&#xff1a; gas[index] - cost[index]&#xff0c;ret 表示的是在i位置开始循环时剩余的油量 a到达的最大路径假设是f那么我们可以得出 a b …

【技术派部署篇】云服务器部署技术派

1 环境搭建 1.1 JDK安装 # ubuntu sudo apt update # 更新apt apt install openjdk-8-jdk # 安装JDK安装完毕之后&#xff0c;执行 java -version 命令进行验证&#xff1a; 1.2 Maven安装 cd ~ mkdir soft cd soft wget https://dlcdn.apache.org/maven/maven-3/3.8.8/bina…

Linux:35.其他IPC和IPC原理+信号量入门

通过命名管道队共享内存的数据发送进行保护的bug&#xff1a; 命名管道挂掉后&#xff0c;进程也挂掉了。 6.systemV消息队列 原理:进程间IPC:原理->看到同一份资源->维护成为一个队列。 过程&#xff1a; 进程A,进程B进行通信。 让操作系统提供一个队列结构&#xff0c;…

【数据结构】红黑树超详解 ---一篇通关红黑树原理(含源码解析+动态构建红黑树)

一.什么是红黑树 红黑树是一种自平衡的二叉查找树&#xff0c;是计算机科学中用到的一种数据结构。1972年出现&#xff0c;最初被称为平衡二叉B树。1978年更名为“红黑树”。是一种特殊的二叉查找树&#xff0c;红黑树的每一个节点上都有存储表示节点的颜色。每一个节点可以是…