【教程】从0开始搭建大语言模型:实现Attention机制

从0开始搭建大语言模型

  • 从0开始搭建大语言模型:实现Attention机制
    • 建模长序列存在的问题
    • 使用attention机制获得数据间的依赖
    • Self-attention
      • 介绍
      • 带有可训练权重的self-attention
        • 1.生成Q,K,V变量
        • 2.计算attention score
        • 3.attention weight的获得
        • 4.计算context vector
        • 5.对于query, key, value的理解
        • 6.masked attention的应用
        • 7.Dropout的使用
        • 8.多头注意力的应用

从0开始搭建大语言模型:实现Attention机制

接上文:【教程】从0开始搭建大语言模型:Word和位置Embedding

建模长序列存在的问题

思考一个问题,之前语言模型没有注意力机制,它有什么问题?

当将文本从一种语言翻译成另一种语言时,例如从德语翻译成英语,仅仅是逐字翻译是不可能的。相反,翻译过程需要上下文理解和语法对齐。

为了解决这个问题,通常使用具有两个子模块的深度神经网络,即所谓的编码器和解码器。编码器的工作首先是读入并处理整个文本,然后解码器生成翻译后的文本。

在transformer出现之前,循环神经网络(RNN)是语言翻译中最流行的编码器-解码器架构。RNN是一种神经网络,前一步的输出作为当前步骤的输入,使其非常适合于文本等顺序数据。

下图是RNN进行语言翻译的例子:
在这里插入图片描述
编码器将来自源语言的token序列作为输入,其中编码器的隐藏状态(中间神经网络层)对整个输入序列的压缩表示进行编码。然后,解码器使用其当前的隐藏状态开始一个词一个词的翻译。

上面这种结果存在不足:RNN不能在解码阶段直接从编码器获取早期的隐藏状态。因此,它仅依赖于当前隐藏状态,其中保存了所有相关信息。这可能会导致上下文的丢失,特别是在依赖关系可能跨越长距离的复杂句子中。

使用attention机制获得数据间的依赖

前面提到,RNN可以很好地翻译短句,但不能很好地翻译较长的文本,因为它们不能直接访问输入中之前的单词。

为了解决这个问题,研究人员为RNN开发了所谓的Bahdanau注意力机制,如下:
在这里插入图片描述
上图的网络的文本生成解码器部分可以有选择地访问所有输入token,因此在生成给定输出token时,某些输入token比其他标记更重要,类似于注意力权重。

受Bahdanau注意力的启发,自注意力被提出,它允许在计算序列表示时,输入序列中的每个位置都关注同一序列中的所有位置。

Self-attention

介绍

在自注意力中,“self”是指机制通过关联单个输入序列中的不同位置来计算注意力权重的能力。它评估和学习输入本身不同部分之间的关系和依赖关系,例如句子中的单词或图像中的像素。

自注意力的目标是为每个输入元素计算一个上下文向量,该向量结合了来自所有其他输入元素的信息,如下图所示:
在这里插入图片描述
在self-attention中,我们的目标是为输入序列中的每个元素x (i)计算上下文向量z(i)。上下文向量可以被解释为丰富的embedding向量。

在self-attention中,上下文向量起着至关重要的作用。它们的目的是通过合并来自序列中所有其他元素的信息,创建输入序列中每个元素(如句子)的丰富表示。

x(2)和其他元素的attention score的计算如下图所示,通过计算点积得到。
在这里插入图片描述
点积是一种相似性度量,因为它量化了两个向量之间的对齐程度:

  • 点积越高,表示向量之间的对齐程度或相似性越高。
  • 在自注意力机制的背景下,点积决定了序列中元素相互关注的程度
    • 点积越高,两个元素之间的相似性和注意力分数越高

得到attention score后,还需要归一化得到最终的attention weights,如下图:
在这里插入图片描述
归一化背后的主要目标是获得总和为1的注意力权重。这种归一化是一种约定,对于解释和保持LLM的训练稳定性很有用。

在实践中,使用softmax函数进行归一化更常见,也更可取。这种方法可以更好地管理极值,并在训练过程中提供更有利的梯度属性。此外,softmax函数确保注意力权重始终为正。这使得输出可以解释为概率或相对重要性,其中权重越大,说明重要性越高。

最后,将各个元素的信息通过attention weights结合其他,得到x(2)的上下文信息z(2):
在这里插入图片描述
如果要对所有token计算它与其他token的信息,可以使用矩阵乘法,而不是for循环,因为for循环计算效率低。

带有可训练权重的self-attention

在LLM中,self-attention中权重矩阵会在模型训练过程中更新。这些可训练的权重矩阵至关重要,以便模型(特别是模型中的注意力模块)可以学习产生“好的”上下文向量。

1.生成Q,K,V变量

引入三个可训练的权重矩阵Wq、Wk和Wv,这三个矩阵用于将embedding的输入标记x(i)投影为query、key和value,如下:
在这里插入图片描述
需要注意,权重参数(权重矩阵的值)是定义网络连接的基本的学习系数,而注意力权重是动态的、特定于上下文的值。

2.计算attention score

在这里插入图片描述
注意力分数计算是一种点积计算,需要注意的是:我们不直接计算输入元素之间的点积,而是使用通过各自的权重矩阵转换输入得到的query和key。

3.attention weight的获得

如下图,在计算出注意力分数ω之后,下一步是使用softmax函数对这些分数进行归一化,以获得注意力权重:
在这里插入图片描述
需要注意的是,我们通过除以key的embedding维度的平方根来缩放注意力分数。缩放注意力的原因:

  • 通过避免小梯度来提高训练性能。例如,当扩大嵌入维度时(对于类似gpt的LLM通常大于1000),由于对其应用了softmax函数,在反向传播过程中,较大的点积可能会导致非常小的梯度
  • 随着点积的增加,softmax函数的行为更像一个阶梯函数,导致梯度接近于零。这些小的梯度可能会大大减慢学习速度或导致训练停滞、

对embedding维度平方根的缩放是这种自注意力机制也被称为scaled-dot product attention的原因。

4.计算context vector

在这里插入图片描述
在自注意力计算的最后一步,我们通过注意力权重组合所有value向量来计算上下文向量。注意力权重作为一个权重因子,对每个值向量的重要性进行加权。

5.对于query, key, value的理解

注意力机制上下文中的术语“query”、“key”和“value”是从信息检索和数据库领域借用的。

“query”类似于数据库中的搜索查询。它表示模型关注或试图理解的当前项(例如,句子中的一个单词或标记)。query用于探测输入序列的其他部分,以确定对它们的关注程度。

“key”类似于数据库中用于索引和搜索的key。在注意力机制中,输入序列中的每个项目(例如,句子中的每个单词)都有一个关联的关键字。这些键用于匹配查询。

这里的“value”类似于数据库中键值对中的值。它表示输入项的实际内容或表示。一旦模型确定了哪些键(以及输入的哪些部分)与查询(当前焦点项)最相关,它就检索相应的值。

上面过程的总体流程为:
在这里插入图片描述
在self-attention中,我们用三个权重矩阵Wq、Wk和Wv对输入矩阵X中的输入向量进行变换。然后,我们根据结果查询(Q)和键(K)计算注意力权重矩阵。使用注意力权重和值(V),然后计算上下文向量(Z)。

使用权重矩阵的代码为:

import torch.nn as nn
class SelfAttention_v1(nn.Module):def __init__(self, d_in, d_out):super().__init__()self.d_out = d_outself.W_query = nn.Parameter(torch.rand(d_in, d_out))self.W_key = nn.Parameter(torch.rand(d_in, d_out))self.W_value = nn.Parameter(torch.rand(d_in, d_out))def forward(self, x):keys = x @ self.W_keyqueries = x @ self.W_queryvalues = x @ self.W_valueattn_scores = queries @ keys.T # omegaattn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)context_vec = attn_weights @ valuesreturn context_vecimport torch
inputs = torch.tensor([[0.43, 0.15, 0.89], # Your (x^1)[0.55, 0.87, 0.66], # journey (x^2)[0.57, 0.85, 0.64], # starts (x^3)[0.22, 0.58, 0.33], # with (x^4)[0.77, 0.25, 0.10], # one (x^5)[0.05, 0.80, 0.55]] # step (x^6)
)
torch.manual_seed(123)
d_in = inputs.shape[1] #B
d_out = 2 #C
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

我们也可以使用nn.Linear实现上述过程,代码为:

class SelfAttention_v2(nn.Module):def __init__(self, d_in, d_out, qkv_bias=False):super().__init__()self.d_out = d_outself.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)def forward(self, x):keys = self.W_key(x)queries = self.W_query(x)values = self.W_value(x)attn_scores = queries @ keys.Tattn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)context_vec = attn_weights @ valuesreturn context_vec
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

需要注意的是,Linear将权重矩阵存储为转置形式。

6.masked attention的应用

masked attention限制模型在处理任何给定token时只考虑序列中以前和当前的输入。这与标准的自注意力机制不一样,后者允许一次访问整个输入序列。在计算注意力分数时,masked attention机制确保模型只考虑序列中出现在当前token之前或之前的token。

如下图,在masked attention中,我们屏蔽了对角线上的注意力权重,使得对于给定的输入,LLM在使用注意力权重计算上下文向量时无法访问未来的token。
在这里插入图片描述
在上图中,我们mask掉对角线上的注意力权重,并对未mask的注意力权重进行归一化,使每行的注意力权重之和为1。

获得masked attention weights的一种方式如下:
在这里插入图片描述
需要注意:

  • 当我们应用掩码,然后重新规范化注意力权重时,最初可能看起来来自未来token(打算mask)的信息仍然可能影响当前token,因为它们的值是softmax计算的一部分。
  • 当我们重新规范化掩码后的注意力权重时,我们本质上所做的是在较小的子集上重新计算softmax(因为掩码位置对softmax值没有贡献)

softmax的数学优雅之处在于,尽管最初在分母中包括所有位置,但在掩码和重归一化之后,掩码位置的影响被消除——它们不会以任何有意义的方式对softmax分数做出贡献。

另外一种更高效的实现方法如下:
在这里插入图片描述

7.Dropout的使用

深度学习中的Dropout是一种在训练过程中忽略随机选择的隐藏层单元的技术,有效地“丢弃”它们。这种方法通过确保模型不过度依赖任何特定的隐藏层单元集,有助于防止过拟合。需要强调的是,dropout只在训练过程中使用,并在训练结束后停用。

在transformer架构中,包括像GPT这样的模型,注意力机制中的dropout通常应用于两个特定领域:

  • 计算注意力分数后或
  • 注意力权重应用于value向量后

随机去掉一些attention weights的示例如下:
在这里插入图片描述
当将dropout应用于一个dropout率为50%的注意力权重矩阵时,矩阵中有一半元素随机设置为0。为了补偿活跃元素的减少,矩阵中剩余元素的值按比例放大1/0.5 =2。这种缩放对于保持注意力权重的整体平衡至关重要,确保注意力机制的平均影响在训练和推理阶段保持一致。

结合了mask attention和dropout的注意力代码为:

class CausalAttention(nn.Module):def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):super().__init__()self.d_out = d_outself.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)self.dropout = nn.Dropout(dropout) #A# 下三角,context_length为文本长度self.register_buffer('mask',torch.triu(torch.ones(context_length, context_length), diagonal=1)) #Bdef forward(self, x):b, num_tokens, d_in = x.shape #C# New batch dimension bkeys = self.W_key(x)queries = self.W_query(x)values = self.W_value(x)attn_scores = queries @ keys.transpose(1, 2) #C# mask操作attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)attn_weights = self.dropout(attn_weights)context_vec = attn_weights @ valuesreturn context_vec
torch.manual_seed(123)
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)
8.多头注意力的应用

“多头”是指将注意力机制划分为多个“头”,每个“头”独立操作。在这种情况下,单个注意力模块可以被认为是单头注意力,其中只有一组注意力权重按顺序处理输入,而多头是有多组权重处理输入。

实现多头可以通过堆叠多个attention模块来实现,如下图:

在这里插入图片描述
如上图,在一个有两个头的多头注意力模块中,query, key和value各有两个权重矩阵,该部分用代码实现为:

class MultiHeadAttentionWrapper(nn.Module):def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):super().__init__()self.heads = nn.ModuleList([CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)])def forward(self, x):return torch.cat([head(x) for head in self.heads], dim=-1)

需要注意,最后需要将每个头得到的context vector拼接,如下图:
在这里插入图片描述
上面实现多头的方法不能并行处理每一个头,事实上,可以通过矩阵乘法同时计算所有注意力头的输出。

具体来说,可以通过reshape投影的query、key和value张量,将输入分成多个头,然后在计算注意力后结合这些头的结果。

代码为:

class MultiHeadAttention(nn.Module):def __init__(self, d_in, d_out,context_length, dropout, num_heads, qkv_bias=False):super().__init__()assert d_out % num_heads == 0, "d_out must be divisible by num_heads"self.d_out = d_outself.num_heads = num_heads# 每个头中的输出维度self.head_dim = d_out // num_heads #Aself.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)self.out_proj = nn.Linear(d_out, d_out) #Bself.dropout = nn.Dropout(dropout)self.register_buffer( 'mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))def forward(self, x):b, num_tokens, d_in = x.shapekeys = self.W_key(x) #Cqueries = self.W_query(x) #Cvalues = self.W_value(x) #C# 将映射的变量reshape成多个头keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) #Dvalues = values.view(b, num_tokens, self.num_heads, self.head_dim) #Dqueries = queries.view(b, num_tokens, self.num_heads, self.head_dim)#D# 交换矩阵的维度keys = keys.transpose(1, 2) #Equeries = queries.transpose(1, 2) #Evalues = values.transpose(1, 2) #Eattn_scores = queries @ keys.transpose(2, 3) #Fmask_bool = self.mask.bool()[:num_tokens, :num_tokens] #Gattn_scores.masked_fill_(mask_bool, -torch.inf) #Hattn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)attn_weights = self.dropout(attn_weights)context_vec = (attn_weights @ values).transpose(1, 2) #I#Jcontext_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)context_vec = self.out_proj(context_vec) #Kreturn context_vec

将query,key,value变成(b, num_heads,num_tokens, head_dim)是很重要的,矩阵乘法在最后2个维度(num_tokens, head_dim)之间进行,然后对每个头重复。

除此之外,代码中还加入了output projection layer,这在LLM的代码中经常能看到。

该方法比第一种方法更加高效,因为只需要一次矩阵乘法就可以计算出key,而在第一种方法中,我们需要为每个注意力头重复这个矩阵乘法,这是计算成本最高的步骤之一。它们两个的区别如下:
在这里插入图片描述
需要注意的是,最小的GPT-2模型(1.17亿个参数)具有12个注意力头和768个上下文向量embedding大小。最大的GPT-2模型(15亿个参数)有25个注意力头和1600个上下文向量embedding大小。token输入的embedding大小和上下文embedding在GPT模型中是相同的(d_in = d_out)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/851607.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

支付交易——在线支付系统基本概念

摘要 本文聚集于实战,只讲解最实用的知识点,至于支付起源、在线支付发展历程等科普知识,感兴趣的读者可参考其它优秀的支付类书籍或网络上其它优秀的文章。本章内容对大部分专业概念进行了极致简化,以便更好地帮助读者入门。实际…

鸿蒙轻内核M核源码分析系列二十 Newlib C

LiteOS-M内核LibC实现有2种,可以根据需求进行二选一,分别是musl libC和newlibc。本文先学习下Newlib C的实现代码。文中所涉及的源码,均可以在开源站点https://gitee.com/openharmony/kernel_liteos_m 获取。 使用Musl C库的时候&#xff0c…

具有可编程电流限制的1.5A电源开关LPW5210用于5V或USB供电输出过流保护只要3毛

前言 适合要求反应时间较快的保护电路,保险丝或自恢复保险丝也能起到保护作用,但断开电流是额定电流的一倍,过流较小时,甚至需要数秒或更长的时间才能保护,因此半导体的过流保护开关更合适,相对成本要高一…

replaceAll is not a function 详解

先说说原因: 在chrome 浏览器中使用 replaceAll 报这个错误,是因为chrome 版本过低, 在chrome 85 以上版本才支持 用法 replaceAll(pattern, replacement)const paragraph "I think Ruths dog is cuter than your dog!"; console…

数据可视化后起之秀——pyecharts

题目一:绘制折线图,展示商家A与商家B各类饮品的销售额 题目描述: 编写程序。根据第9.3.1,绘制折线图,展示商家A与商家B各类饮品的销售额。 运行代码: #绘制折线图,展示商家A与商家B各类饮品的…

淘宝数据抓取的全景解析

——分析淘宝数据抓取的常见方法及其适用场景,探讨不同技术如何影响数据的质量和可用性 在当今数据驱动的电子商务时代,能够有效地抓取和利用数据成为商家获得竞争优势的关键。淘宝作为中国最大的在线零售平台,其海量数据具有极高的价值&…

Spring--Bean的作用域,生命周期

Bean的作用域 Bean的作用域有很多种,在Spring Framework中支持6种(其中有四种只有在web环境中才能生效),同时Spring还支持自定义Bean的范围。 Spring Framework中支持的6种范围: 作用域解释singleton每个Spring IoC…

从零到一建设数据中台(番外篇)- 数据中台UI欣赏

番外篇 - 数据中台 UI 欣赏 话不多说,直接上图。

想让谷歌独立站关键词排名一飞冲天?这个秘密技巧必须知道!

在激烈的谷歌排名竞争中,我们要确保自己优化的独立站在谷歌搜索结果中占据首页位置至关重要。排名首页不仅能显著提高展现和流量,还能带来更多潜在客户和业务机会。本文将从谷歌SEO技术的角度,深入探讨在谷歌独立站关键词排名首页最重要的几个…

ShardingSphere跨表查询报错

目录 一、场景简介二、报错信息三、SQL四、原因五、解决方法一、调整SQL,不使用子查询方法二、将子查询的SQL独立出来,后续连接逻辑由代码处理 一、场景简介 1、使用ShardingSphere按月份进行分表 2、单月查询正常(单表) 3、跨…

CA证书及PKI

文章目录 概述非对称加密User Case: 数据加密User Case: 签名验证潜在问题 CACA证书的组成CA签发证书流程CA验证签名流程CA吊销证书流程 PKI信任链证书链 概述 首先我们需要简单对证书有一个基本的概念,以几个问题进入了解 ❓ Question1: 什么是证书? 证…

福利|免费申请长期单域名、通配符、多域名SSL证书,不限量

一、什么是单域名、通配符、多域名SSL证书 单域名证书:仅保护一个特定的域名。 通配符证书:保护一个主域名及其所有二级子域名。 多域名证书:在同一张证书中保护多个不同的域名,可以是主域名也可以是子域名,域名之间…

目前比较好用的LabVIEW架构及其选择

LabVIEW提供了多种架构供开发者选择,以满足不同类型项目的需求。选择合适的架构不仅可以提高开发效率,还能确保项目的稳定性和可维护性。本文将介绍几种常用的LabVIEW架构,并根据不同项目需求和个人习惯提供选择建议。 常用LabVIEW架构 1. …

Invalid keystore format,获取安全码SHA1值出错

AndroidStudio版本:Android Studio Electric Eel | 2022.1.1 项目运行JDK版本:11.0.15,查看方法如下: 在Terminal 窗口中,获取的Java版本是:1.8.0,修改Java系统环境变量,改成&#…

如何在MySQL中创建不同的索引和用途?

目录 1 基本的 CREATE INDEX 语法 2 创建单列索引 3 创建多列索引 4 创建唯一索引 5 创建全文索引 6 在表创建时添加索引 7 使用 ALTER TABLE 添加索引 8 删除索引 9 索引管理的最佳实践 10 示例 在 MySQL 中,索引(index)是一种用于…

Git保姆级教程

目录 Git是什么,为什么要学这个工具? 码云注册并创建仓库 Git安装 查看本地仓库状态 添加到暂存区 提交到本地库 修改文件 版本回退 创建、切换和删除分支 合并分支 克隆远端库到本地 将本地库推送到远端库 命令设置别名 Git是什么&#xf…

远程咨询的好处都有哪些呢?

随着科技的飞速发展,远程咨询正逐渐成为人们获取医疗服务的一种新方式。那么什么是远程咨询呢?其又有哪些好处呢?下面就给大家详细地说说。 远程咨询的概念 远程咨询,顾名思义,是指通过互联网技术,实现患…

使用try-catch捕获异常到底会不会影响性能?尤其是try-catch还比较多的情况下?

从字节码层面来看,没抛错两者的执行效率其实没啥差别。 “那为什么网上流传着try-catch会有性能问题的说法啊? 这个说法确实有,在《Effective Java》这本书里就提到了 try-catch 性能问题: 总结: 1、try-catch 相比较…

汇编:数组数据传送

要在32位汇编中实现数组数据的传送,可以使用字符串操作指令 MOVS 以及其前缀 REP,可以高效地复制数组数据。 MOVS 指令是一种字符串操作指令,用于将数据从源地址移动到目标地址。MOVS 指令有不同的变种,可以处理不同大小的数据&a…

水印怎么去除?Windows 上的最佳水印软件

我们都知道,任何水印软件都可以防止您的数字财产被盗。此外,水印是一种虚拟营销元素,可以帮助您推广您的作品。 奇客水印管家是 Internet 上适用于 Windows 7、8 、10 和 11 的最高效的水印软件。此外,它还允许用户通过添加或删除…