简单实现Transformer的自注意力
关注{晓理紫|小李子},获取技术推送信息,如感兴趣,请转发给有需要的同学,谢谢支持!!
如果你感觉对你有所帮助,请关注我。
源码获取:VX关注并回复chatgpt-0获得
- 实现的功能
假如有八个令牌,现在想让每一个令牌至于其前面的通信,如第5个令牌不与6,7,8位置的令牌通信(这是未来的令牌),只与4,3,2,1位置的令牌通信。因此只能通过以前的上下文信息猜测后面的;一种弱的通信方式是取前面的平局值。如5位置==5,4,3,2,1位置上的平局值。
- 实现
- 循环的版本
import torch
from torch.nn import functional as F
import torch.nn as nn
torch.manual_seed(1337)B,T,C = 4,8,2 #batch,time,channels
x = torch.randn(B,T,C)
xbow = torch.zeros((B,T,C))
print(f'x: {x[0]}')
for b in range(B):for t in range(T):xprev = x[b,:t+1] #()t,Cxbow[b,t] = torch.mean(xprev,0)
print(f'xbow: {xbow[0]}')#结果
x: tensor([[ 0.1808, -0.0700],[-0.3596, -0.9152],[ 0.6258, 0.0255],[ 0.9545, 0.0643],[ 0.3612, 1.1679],[-1.3499, -0.5102],[ 0.2360, -0.2398],[-0.9211, 1.5433]])
xbow: tensor([[ 0.1808, -0.0700],[-0.0894, -0.4926],[ 0.1490, -0.3199],[ 0.3504, -0.2238],[ 0.3525, 0.0545],[ 0.0688, -0.0396],[ 0.0927, -0.0682],[-0.0341, 0.1332]])
# 每一行至于自己以及自己以前的数据进行通信
- 通过数据矩阵高效实现
a = torch.tril(torch.ones(3,3)) #下三角函数
a = a/torch.sum(a,1,keepdim=True) #对a求平均数
b = torch.randint(0,10,(3,2)).float()
c = a @ bprint(f'a:{a}')
print(f'b:{b}')
print(f'c:{c}')#结果a:tensor([[1.0000, 0.0000, 0.0000],[0.5000, 0.5000, 0.0000],[0.3333, 0.3333, 0.3333]])
b:tensor([[0., 4.],[1., 2.],[5., 5.]])
c:tensor([[0.0000, 4.0000],[0.5000, 3.0000],[2.0000, 3.6667]])
- 使用Softmax
tril = torch.tril(torch.ones(T,T)) #下三角函数
print(f'tril:{tril}')wei = torch.zeros((T,T))
wei = wei.masked_fill(tril==0,float('-inf'))# mask填充,对于tril为0的填充负无穷大
print(f'wei: {wei}')
wei = F.softmax(wei,dim=-1)# softmax对没一行的每个元素进行求幂,在求平均数
print(f'wei: {wei}')
xbow3 = wei @ xprint(f'xbow3: {xbow3}')
print(torch.allclose(xbow,xbow3))
-
单头自注意力
- 上面的自注意力是通过相同的方式获取以往的信息。但是实际上并不希望是统一的方式,因为不同的token标记会发现其他不同的标记。
- 例如:我是元音,那么也许我正在寻找过去的辅音,或与我想知道这些辅音是什么。希望这些信息流向我,所以我现在想以依赖数据的方式收集过去的信息。这就是自注意力解决的问题。
- 方式如下:每个节点或每个位置的每个令牌都会发出两个向量,一个发出查询query,一个发出键key。查询向量粗略的说就是我要找的东西,键向量粗略的讲就是我包含什么。
- 现在在序列中获取这些标记之间的亲和力的方式基本上只是在键和查询之间做一个点乘积。所以我的查询与所有的其他tokens令牌的所有键进行点乘积。并且点积方式变了。如果键和查询有点对齐,它们将交互到非常高的数量,然后我将了解有关特定标记的更多信息,而不是其他不再序列中的任何其他标记。
head_size = 16
key = nn.Linear(C,head_size,bias=False)
query = nn.Linear(C,head_size,bias=False)k = key(x) #(B,T,16)
q = key(x) #(B,T,16)
wei = q @ k.transpose(-2,-1) #转置时最后两个维度为负 (B,T,16) @ (B,16,T) ---> (B,T,T)tril = torch.tril(torch.ones(T,T)) #下三角函数
wei = wei.masked_fill(tril==0,float('-inf'))# mask填充,对于tril为0的填充负无穷大 主要是为了避免关注后面信息。如果想让所有节点进行交流删除词句。解码器中保留,编码器删除允许所有节点通信
wei = F.softmax(wei,dim=-1)# softmax对没一行的每个元素进行求幂,在求平均数 主要为了避免关注过小的信息主要是负数
print(f'wei: {wei[0]}')
out = wei @ x
print(f'out:{out.shape}')
- 但是在真是中并不聚合到x而是计算一个v.x看作为该令牌的私人信息,与不同头交流的信息存储在v中
head_size = 16
key = nn.Linear(C,head_size,bias=False)
query = nn.Linear(C,head_size,bias=False)k = key(x) #(B,T,16)
q = key(x) #(B,T,16)
wei = q @ k.transpose(-2,-1) #转置时最后两个维度为负 (B,T,16) @ (B,16,T) ---> (B,T,T)tril = torch.tril(torch.ones(T,T)) #下三角函数
wei = wei.masked_fill(tril==0,float('-inf'))# mask填充,对于tril为0的填充负无穷大 主要是为了避免关注后面信息。如果想让所有节点进行交流删除词句。解码器中保留,编码器删除允许所有节点通信
wei = F.softmax(wei,dim=-1)# softmax对没一行的每个元素进行求幂,在求平均数 主要为了避免关注过小的信息主要是负数
print(f'wei: {wei[0]}')
value = nn.Linear(C,head_size,bias=False)
v = value(x)
out = wei @ v
print(f'out:{out.shape}')
简单实现自注意力
关注{晓理紫|小李子},获取技术推送信息,如感兴趣,请转发给有需要的同学,谢谢支持!!
如果你感觉对你有所帮助,请关注我。