探索大语言模型(LLM):Transformer 与 BERT从原理到实践

Transformer 与 BERT:从原理到实践

  • 前言
  • 一、背景介绍
  • 二、核心公式推导
    • 1. 注意力机制(Attention Mechanism)
    • 2. 多头注意力机制(Multi-Head Attention)
    • 3. Transformer 编码器(Transformer Encoder)
    • 4. BERT 的预训练任务
  • 三、代码实现
    • 1. 注意力机制
    • 2. 多头注意力机制
    • 3. Transformer 编码器层
    • 4. Transformer 编码器
    • 5. BERT 模型
  • 四、总结


前言

在自然语言处理(NLP)的发展历程中,Transformer 和 BERT 无疑是具有里程碑意义的技术。它们的出现,彻底改变了 NLP 领域的研究和应用格局。本文将深入探讨 Transformer 和 BERT 的背景、核心公式推导,并提供代码实现,帮助大家更好地理解和应用这两项技术。

一、背景介绍

在 Transformer 出现之前,循环神经网络(RNN)及其变体长短时记忆网络(LSTM)、门控循环单元(GRU)等在 NLP 任务中占据主导地位。RNN 能够处理序列数据,通过隐状态传递信息,从而捕捉上下文依赖关系。然而,RNN 存在严重的梯度消失和梯度爆炸问题,使得训练深层网络变得困难。此外,RNN 的顺序计算特性导致其难以并行化,处理长序列时效率低下。

为了解决这些问题,2017 年谷歌团队在论文《Attention Is All You Need》中提出了 Transformer 架构。Transformer 完全摒弃了循环结构,采用多头注意力机制(Multi-Head Attention)替代 RNN,实现了并行计算,大幅提高了训练效率。同时,多头注意力机制能够更好地捕捉序列中的长距离依赖关系,在机器翻译、文本生成等多个 NLP 任务中取得了优异的性能。

BERT(Bidirectional Encoder Representations from Transformers)则是基于 Transformer 的预训练语言模型,由谷歌在 2018 年提出。与传统的语言模型(如 Word2Vec、GPT)不同,BERT 采用双向 Transformer 编码器,能够同时利用上下文信息,学习到更丰富的语义表示。通过在大规模文本数据上进行预训练,并在特定任务上进行微调,BERT 在问答系统、文本分类、命名实体识别等众多 NLP 任务中刷新了当时的最优成绩,开启了预训练模型在 NLP 领域的新时代。

二、核心公式推导

1. 注意力机制(Attention Mechanism)

注意力机制的核心思想是根据输入序列的不同部分对当前任务的重要程度,分配不同的权重,从而聚焦于关键信息。其计算过程如下:
给定查询向量 (Q),键向量 (K) 和值向量 (V),注意力分数 (scores) 计算为:
s c o r e s = Q K T d k scores = \frac{QK^T}{\sqrt{d_k}} scores=dk QKT
其中, d k d_k dk 是键向量 K K K 的维度, d k \sqrt{d_k} dk 用于缩放,防止分数过大导致 softmax 函数梯度消失。
通过 softmax 函数对注意力分数进行归一化,得到注意力权重 (attention weights):
a t t e n t i o n _ w e i g h t s = s o f t m a x ( s c o r e s ) attention\_weights = softmax(scores) attention_weights=softmax(scores)
最后,加权求和得到注意力输出 A t t e n t i o n ( Q , K , V ) Attention(Q, K, V) Attention(Q,K,V)
A t t e n t i o n ( Q , K , V ) = a t t e n t i o n _ w e i g h t s ⋅ V Attention(Q, K, V) = attention\_weights \cdot V Attention(Q,K,V)=attention_weightsV

2. 多头注意力机制(Multi-Head Attention)

多头注意力机制通过多个独立的注意力头并行计算,从不同角度捕捉输入序列的特征,然后将各个头的输出拼接并线性变换得到最终输出。具体计算过程如下:
首先,将输入 X X X分别通过三个线性变换得到 Q Q Q K K K V V V
Q = X W Q K = X W K V = X W V Q = XW^Q\\ K = XW^K\\ V = XW^V Q=XWQK=XWKV=XWV
其中, W Q W^Q WQ W K W^K WK W V W^V WV 是可学习的权重矩阵。
然后,将 Q Q Q K K K V V V 分割成 h h h 个头部(head),每个头部的维度为 d k / h d_{k/h} dk/h d v / h d_{v/h} dv/h
Q i = Q ( i − 1 ) d k / h : i d k / h K i = K ( i − 1 ) d k / h : i d k / h V i = V ( i − 1 ) d v / h : i d v / h Q_i = Q_{(i-1)d_{k/h}:id_{k/h}} \\ K_i = K_{(i-1)d_{k/h}:id_{k/h}} \\ V_i = V_{(i-1)d_{v/h}:id_{v/h}} Qi=Q(i1)dk/h:idk/hKi=K(i1)dk/h:idk/hVi=V(i1)dv/h:idv/h
对每个头部分别计算注意力输出:
h e a d i = A t t e n t i o n ( Q i , K i , V i ) head_i = Attention(Q_i, K_i, V_i) headi=Attention(Qi,Ki,Vi)
将所有头部的输出拼接起来:
c o n c a t ( h e a d 1 , . . . , h e a d h ) concat(head_1, ..., head_h) concat(head1,...,headh)
最后,通过一个线性变换得到多头注意力机制的最终输出:
M u l t i H e a d A t t e n t i o n ( X ) = W O ⋅ c o n c a t ( h e a d 1 , . . . , h e a d h ) MultiHeadAttention(X) = W^O \cdot concat(head_1, ..., head_h) MultiHeadAttention(X)=WOconcat(head1,...,headh)
其中, W O W^O WO是可学习的权重矩阵。

3. Transformer 编码器(Transformer Encoder)

Transformer 编码器由多个相同的层堆叠而成,每个层包含两个子层:多头注意力机制子层和前馈神经网络子层。每个子层都使用了残差连接(Residual Connection)和层归一化(Layer Normalization)。
输入 (X) 首先经过多头注意力机制子层:
X 1 = L a y e r N o r m ( X + M u l t i H e a d A t t e n t i o n ( X ) ) X_1 = LayerNorm(X + MultiHeadAttention(X)) X1=LayerNorm(X+MultiHeadAttention(X))
然后, X 1 X_1 X1 经过前馈神经网络子层:
X 2 = L a y e r N o r m ( X 1 + F F N ( X 1 ) ) X_2 = LayerNorm(X_1 + FFN(X_1)) X2=LayerNorm(X1+FFN(X1))
其中, F F N ( X ) FFN(X) FFN(X)是前馈神经网络,通常由两个线性层和一个激活函数组成:
F F N ( X ) = m a x ( 0 , X W 1 + b 1 ) W 2 + b 2 FFN(X) = max(0, XW_1 + b_1)W_2 + b_2 FFN(X)=max(0,XW1+b1)W2+b2

4. BERT 的预训练任务

BERT 采用了两个预训练任务:掩码语言模型(Masked Language Model,MLM)和下一句预测(Next Sentence Prediction,NSP)。
在掩码语言模型任务中,随机将输入文本中的一些单词替换为 [MASK] 标记,然后让模型预测这些被掩码的单词。例如,对于句子 “I love natural language processing”,可能会将 “love” 替换为 [MASK],模型需要根据上下文预测出 “love”。
下一句预测任务用于学习句子之间的关系。给定一对句子,判断第二个句子是否是第一个句子的下一句。例如,句子对 “今天天气很好。我们去公园散步吧。” 标签为正例,而 “今天天气很好。我喜欢吃苹果。” 标签为负例。
BERT 通过最小化这两个任务的损失函数进行预训练,损失函数为:
L = L M L M + L N S P L = L_{MLM} + L_{NSP} L=LMLM+LNSP

三、代码实现

下面我们使用 PyTorch 实现一个简单的 Transformer 编码器和 BERT 预训练模型。

1. 注意力机制

import torch
import torch.nn as nndef scaled_dot_product_attention(Q, K, V):d_k = K.size(-1)scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))attention_weights = nn.functional.softmax(scores, dim=-1)return torch.matmul(attention_weights, V)

2. 多头注意力机制

class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()self.num_heads = num_headsself.d_model = d_modelself.depth = d_model // num_headsself.WQ = nn.Linear(d_model, d_model)self.WK = nn.Linear(d_model, d_model)self.WV = nn.Linear(d_model, d_model)self.WO = nn.Linear(d_model, d_model)def split_heads(self, x):batch_size = x.size(0)return x.view(batch_size, -1, self.num_heads, self.depth).transpose(1, 2)def forward(self, X):Q = self.split_heads(self.WQ(X))K = self.split_heads(self.WK(X))V = self.split_heads(self.WV(X))attention = scaled_dot_product_attention(Q, K, V)concatenated_attention = attention.transpose(1, 2).contiguous().view(-1, self.d_model)return self.WO(concatenated_attention)

3. Transformer 编码器层

class TransformerEncoderLayer(nn.Module):def __init__(self, d_model, num_heads):super(TransformerEncoderLayer, self).__init__()self.attention = MultiHeadAttention(d_model, num_heads)self.ffn = nn.Sequential(nn.Linear(d_model, d_model * 4),nn.ReLU(),nn.Linear(d_model * 4, d_model))self.layernorm1 = nn.LayerNorm(d_model)self.layernorm2 = nn.LayerNorm(d_model)def forward(self, X):attn_output = self.attention(X)X = self.layernorm1(X + attn_output)ffn_output = self.ffn(X)return self.layernorm2(X + ffn_output)

4. Transformer 编码器

class TransformerEncoder(nn.Module):def __init__(self, num_layers, d_model, num_heads):super(TransformerEncoder, self).__init__()self.layers = nn.ModuleList([TransformerEncoderLayer(d_model, num_heads) for _ in range(num_layers)])def forward(self, X):for layer in self.layers:X = layer(X)return X

5. BERT 模型

class BERT(nn.Module):def __init__(self, vocab_size, num_layers, d_model, num_heads):super(BERT, self).__init__()self.embedding = nn.Embedding(vocab_size, d_model)self.transformer = TransformerEncoder(num_layers, d_model, num_heads)self.mlm_head = nn.Linear(d_model, vocab_size)self.nsp_head = nn.Linear(d_model, 2)def forward(self, X, masked_indices=None):X = self.embedding(X)X = self.transformer(X)if masked_indices is not None:masked_X = torch.gather(X, 1, masked_indices.unsqueeze(-1).repeat(1, 1, X.size(-1)))mlm_logits = self.mlm_head(masked_X)else:mlm_logits = Nonensp_logits = self.nsp_head(X[:, 0])return mlm_logits, nsp_logits

以上代码实现了 Transformer 编码器和 BERT 模型的基本结构。在实际应用中,还需要进行数据预处理、模型训练和评估等步骤。

四、总结

Transformer 和 BERT 作为 NLP 领域的重要技术,以其独特的架构和强大的性能,推动了 NLP 技术的快速发展。通过本文对 Transformer 和 BERT 的背景介绍、公式推导和代码实现,相信大家对它们有了更深入的理解。随着研究的不断深入,Transformer 和 BERT 的应用场景也在不断拓展,未来它们将在更多领域发挥重要作用。希望本文能为大家的学习和研究提供帮助,欢迎大家在评论区交流讨论。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/902073.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机网络八股——HTTP协议与HTTPS协议

目录 HTTP1.1简述与特性 1. 报文清晰易读 2. 灵活和易于扩展 3. ⽆状态 Cookie和Session 4. 明⽂传输、不安全 HTTP协议发展过程 HTTP/1.1的不足 HTTP/2.0 HTTP/3.0 HTTPS协议 HTTP协议和HTTPS协议的区别 HTTPS中的加密方式 HTTPS中建立连接的方式 前言&#xff…

QML中的3D功能--入门开发

Qt Quick 提供了强大的 3D 功能支持,主要通过 Qt 3D 模块实现。以下是 QML 中开发 3D 应用的全面指南。 1. 基本配置 环境要求 Qt 5.10 或更高版本(推荐 Qt 6.x) 启用 Qt 3D 模块 支持 OpenGL 的硬件 项目配置 在 .pro 文件中添加: QT += 3dcore 3drender 3dinput 3dex…

Git合并分支的两种常用方式`git merge`和`git cherry-pick`

Git合并分支的两种常用方式git merge和git cherry-pick 写在前面1. git merge用途工作方式使用git命令方式合并使用idea工具方式合并 2. git cherry-pick用途工作方式使用git命令方式合并使用idea工具方式合并 3. 区别总结 写在前面 一般我们使用git合并分支常用的就是git mer…

Web三漏洞学习(其三:rce漏洞)

靶场:NSSCTF 三、RCE漏洞 1、概述 在Web应用开发中会让应用调用代码执行函数或系统命令执行函数处理,若应用对用户的输入过滤不严,容易产生远程代码执行漏洞或系统命令执行漏洞 所以常见的RCE漏洞函数又分为代码执行函数和系统命令执行函数…

从零开始:Python运行环境之VSCode与Anaconda安装配置全攻略 (1)

从零开始:Python 运行环境之 VSCode 与 Anaconda 安装配置全攻略 在当今数字化时代,Python 作为一种功能强大且易于学习的编程语言,被广泛应用于数据科学、人工智能、Web 开发等众多领域。为了顺利开启 Python 编程之旅,搭建一个稳…

从FPGA实现角度介绍DP_Main_link主通道原理

DisplayPort(简称DP)是一个标准化的数字式视频接口标准,具有三大基本架构包含影音传输的主要通道(Main Link)、辅助通道(AUX)、与热插拔(HPD)。 Main Link:用…

嵌入式软件--stm32 DAY 2

大家学习嵌入式的时候,多多学习用KEIL写代码,虽然作为编译器,大家常用vscode等常用工具关联编码,但目前keil仍然是主流工具之一,学习掌握十分必要。 1.再次创建项目 1.1编译器自动生成文件 1.2初始文件 这样下次创建新…

游戏引擎学习第234天:实现基数排序

回顾并为今天的内容设定背景 我们今天继续进行排序的相关,虽然基本已经完成了,但还是想收尾一下,让整个流程更完整。其实这次排序只是个借口,主要是想顺便聊一聊一些计算机科学的知识点,这些内容在我们项目中平时不会…

计算机网络——常见的网络攻击手段

什么是XSS攻击,如何避免? XSS 攻击,全称跨站脚本攻击(Cross-Site Scripting),这会与层叠样式表(Cascading Style Sheets, CSS)的缩写混淆,因此有人将跨站脚本攻击缩写为XSS。它指的是恶意攻击者往Web页面…

Agent的九种设计模式 介绍

Agent的九种设计模式 介绍 一、ReAct模式 原理:将推理(Reasoning)和行动(Acting)相结合,使Agent能够在推理的指导下采取行动,并根据行动的结果进一步推理,形成一个循环。Agent通过生成一系列的思维链(Thought Chains)来明确推理步骤,并根据推理结果执行相应的动作,…

LeetCode 热题 100:回溯

46. 全排列 给定一个不含重复数字的数组 nums ,返回其 所有可能的全排列 。你可以 按任意顺序 返回答案。 示例 1: 输入:nums [1,2,3] 输出:[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]]示例 2: 输入&#xff…

cJSON_Print 和 cJSON_PrintUnformatted的区别

cJSON_Print 和 cJSON_PrintUnformatted 是 cJSON 库中用于将 cJSON 对象转换为 JSON 字符串的两个函数,它们的区别主要在于输出的格式: 1. cJSON_Print 功能:将 cJSON 对象转换为格式化的 JSON 字符串。 特点: 输出的 JSON 字符…

A股周度复盘与下周策略 的deepseek提示词模板

以下是反向整理的股票大盘分析提示词模板,采用结构化框架数据占位符设计,可直接套用每周市场数据: 请根据一下markdown格式的模板,帮我检索整理并输出本周股市复盘和下周投资策略 【A股周度复盘与下周策略提示词模板】 一、市场…

Linux下使用C++获取硬件信息

目录 方法获取CPU信息:读取"/proc/cpuinfo"文件获取磁盘信息:读取"/proc/diskstats"文件获取BIOS信息有两种方法:1、读取文件;2、使用dmidecode命令获取主板信息有两种方法:1、读取文件&#xff1…

BootStrap:进阶使用(其二)

今天我要讲述的是在BootStrap中第二篇关于进一步使用的方法与代码举例; 分页: 对于一些大型网站而言,分页是一个很有必要的存在,如果当数据内容过大时,则需要分页来分担一些,这可以使得大量内容能整合并全面地展示&a…

【技术派后端篇】技术派中的白名单机制:基于Redis的Set实现

在技术派社区中,为了保证文章的质量和社区的良性发展,所有发布的文章都需要经过审核。然而,并非所有作者的文章都需要审核,我们通过白名单机制来优化这一流程。本文将详细介绍技术派中白名单的实现方式,以及如何利用Re…

TRAE.AI 国际版本

国际版下载地址: https://www.trae.ai/https://www.trae.ai/ 国际版本优势:提供更多高校的AI助手模型 Claude-3.5-Sonnet Claude-3.7-Sonnet Gemini-2.5-Pro GPT-4.1 GPT-40 DeepSeek-V3-0324DeepSeek-V3DeepSeek-Reasoner(R1)

关于支付宝网页提示非官方网页

关于支付宝网站提示 非官方网站 需要找官方添加白名单 下面可以直接用自己的邮箱去发送申请 支付宝提示“非支付宝官方网页,请确认是否继续访问”通常是因为支付宝的安全机制检测到您访问的页面不是支付宝官方页面,这可能是由于域名或页面内容不符合支…

【今日三题】打怪(模拟) / 字符串分类(字符串哈希) / 城市群数量(dfs)

⭐️个人主页&#xff1a;小羊 ⭐️所属专栏&#xff1a;每日两三题 很荣幸您能阅读我的文章&#xff0c;诚请评论指点&#xff0c;欢迎欢迎 ~ 目录 打怪(模拟)字符串分类(字符串哈希)城市群数量(dfs) 打怪(模拟) 打怪 #include <iostream> using namespace std;int …

npm install 版本过高引发错误,请添加 --legacy-peer-deps

起因&#xff1a;由于使用"react": "^19.0.0", 第三方包要低版本react&#xff0c;错解决方法&#xff01; npm install --save emoji-mart emoji-mart/data emoji-mart/react npm install --save emoji-mart emoji-mart/data emoji-mart/react npm err…