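This post supplements my earlier CSDN posts "Transformer - Attention is all you need 论文阅读" (paper notes) and "【李宏毅机器学习】Transformer 内容补充" (supplementary notes on Hung-yi Lee's machine learning course); here the focus is on understanding the corresponding code.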
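A side note first: Prof. Hung-yi Lee's course describes multi-head attention as multiplying the obtained q, k, v by further, different matrices to get more sets of q, k, v.
In the code here, however, the heads are obtained by simply slicing: with two heads, for example, q^i is split into two parts, q^{i,1} and q^{i,2}. This is explained in BERT Intro (CSDN blog); I also recommend the video 手推transformer (Bilibili).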
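The two descriptions are in fact equivalent. Block h of rows of one big projection matrix W_q produces exactly the slice q^{i,h}, so "project once, then slice" gives the same numbers as "one smaller matrix per head". A minimal check with arbitrary sizes (d_model = 8, two heads); all names here are illustrative and not taken from any library:

```python
import torch

torch.manual_seed(0)
d_model, num_heads = 8, 2
d_head = d_model // num_heads

W_q = torch.randn(d_model, d_model)  # one big projection, rows = output dimensions
x = torch.randn(d_model)             # one token's input vector

# Way 1: project once, then slice q into per-head chunks (row h is q^{i,h})
q = W_q @ x
q_sliced = q.view(num_heads, d_head)

# Way 2: a separate (d_head x d_model) matrix per head,
# which is just block h of rows of W_q
per_head = torch.stack([W_q[h * d_head:(h + 1) * d_head] @ x for h in range(num_heads)])

print(torch.allclose(q_sliced, per_head, atol=1e-6))  # True
```

This is why implementations can use a single nn.Linear(d_model, d_model) for all heads and only reshape afterwards.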
self-attention
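This section implements the self-attention module. It starts from a single sentence and, without worrying yet about how to batch things together, takes the pieces apart one at a time to see how each component works.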
The questions to answer are:
- How is the input embedded?
- How is position information preserved?
- How are the three matrices initialized?
Attention for a single sentence
Input embedding
import torch

sentence = 'Life is short, eat dessert first'
# vocabulary: each distinct word (sorted) gets an integer index
dc = {s: i for i, s in enumerate(sorted(sentence.replace(',', '').split()))}
print("dictionary: {}".format(dc))
sentence_int = torch.tensor([dc[s] for s in sentence.replace(',', '').split()])
print("sentence, but words have been replaced by index in dictionary: \n{}".format(sentence_int))

torch.manual_seed(123)
# len(sentence.replace(',', '').split()) == 6
# embedded length == 16
embed = torch.nn.Embedding(6, 16)
embedded_sentence = embed(sentence_int).detach()
print("sentence, but word embedded: \n{}".format(embedded_sentence))
dictionary: {'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}
sentence, but words have been replaced by index in dictionary: 
tensor([0, 4, 5, 2, 1, 3])
sentence, but word embedded: 
tensor([[ 0.3374, -0.1778, -0.3035, -0.5880, 0.3486, 0.6603, -0.2196, -0.3792, 0.7671, -1.1925, 0.6984, -1.4097, 0.1794, 1.8951, 0.4954, 0.2692],
        [ 0.5146, 0.9938, -0.2587, -1.0826, -0.0444, 1.6236, -2.3229, 1.0878, 0.6716, 0.6933, -0.9487, -0.0765, -0.1526, 0.1167, 0.4403, -1.4465],
        [ 0.2553, -0.5496, 1.0042, 0.8272, -0.3948, 0.4892, -0.2168, -1.7472, -1.6025, -1.0764, 0.9031, -0.7218, -0.5951, -0.7112, 0.6230, -1.3729],
        [-1.3250, 0.1784, -2.1338, 1.0524, -0.3885, -0.9343, -0.4991, -1.0867, 0.8805, 1.5542, 0.6266, -0.1755, 0.0983, -0.0935, 0.2662, -0.5850],
        [-0.0770, -1.0205, -0.1690, 0.9178, 1.5810, 1.3010, 1.2753, -0.2010, 0.4965, -1.5723, 0.9666, -1.1481, -1.1589, 0.3255, -0.6315, -2.8400],
        [ 0.8768, 1.6221, -1.4779, 1.1331, -1.2203, 1.3139, 1.0533, 0.1388, 2.2473, -0.8036, -0.2808, 0.7697, -0.6596, -0.7979, 0.1838, 0.2293]])
Embedding — PyTorch 2.1 documentation
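nn.Embedding is essentially a trainable lookup table: embed(idx) returns rows of embed.weight, which is the same as multiplying one-hot vectors by that weight matrix. A small sketch with the same sizes as above (6 words, 16 dimensions); the variable names are illustrative:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)
embed = torch.nn.Embedding(6, 16)        # 6-word vocabulary, 16-dim vectors
idx = torch.tensor([0, 4, 5, 2, 1, 3])  # token indices, like sentence_int above

looked_up = embed(idx)                   # shape (6, 16)
by_index = embed.weight[idx]             # plain row indexing into the weight matrix
one_hot = F.one_hot(idx, num_classes=6).float()
by_matmul = one_hot @ embed.weight       # one-hot rows times the weight matrix

print(looked_up.shape)                       # torch.Size([6, 16])
print(torch.equal(looked_up, by_index))      # True
print(torch.allclose(looked_up, by_matmul))  # True
```

The lookup form is what is used in practice, since it avoids building the large one-hot matrix for a real vocabulary.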
Positional embedding
There does not seem to be one fixed name for this: it is called position embedding, position encoding, positional embedding, or positional encoding, every possible combination orz
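import numpy as np
import torch

### position embedding
def sinusoid_positional_encoding(length, dimensions):
    # even dimensions 2i:  sin(pos / 10000^(2i/d_model))
    # odd dimensions 2i+1: cos(pos / 10000^(2i/d_model))
    def get_position_angle_vec(position):
        # columns 2i and 2i+1 share the same angle
        return [position / np.power(10000, 2 * (i // 2) / dimensions) for i in range(dimensions)]

    PE = np.array([get_position_angle_vec(i) for i in range(length)])
    PE[:, 0::2] = np.sin(PE[:, 0::2])
    PE[:, 1::2] = np.cos(PE[:, 1::2])
    return PE

embedded_position = torch.tensor(sinusoid_positional_encoding(6, 16))
print("position embedding: \n{}".format(embedded_position))
# Note: the output below was produced with np.sin on the odd columns as well, which is why
# the columns come in identical pairs and row 0 is all zeros (with cos, row 0 alternates 0 and 1).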
position embedding: 
tensor([[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [ 8.4147e-01, 8.4147e-01, 3.1098e-01, 3.1098e-01, 9.9833e-02, 9.9833e-02, 3.1618e-02, 3.1618e-02, 9.9998e-03, 9.9998e-03, 3.1623e-03, 3.1623e-03, 1.0000e-03, 1.0000e-03, 3.1623e-04, 3.1623e-04],
        [ 9.0930e-01, 9.0930e-01, 5.9113e-01, 5.9113e-01, 1.9867e-01, 1.9867e-01, 6.3203e-02, 6.3203e-02, 1.9999e-02, 1.9999e-02, 6.3245e-03, 6.3245e-03, 2.0000e-03, 2.0000e-03, 6.3246e-04, 6.3246e-04],
        [ 1.4112e-01, 1.4112e-01, 8.1265e-01, 8.1265e-01, 2.9552e-01, 2.9552e-01, 9.4726e-02, 9.4726e-02, 2.9996e-02, 2.9996e-02, 9.4867e-03, 9.4867e-03, 3.0000e-03, 3.0000e-03, 9.4868e-04, 9.4868e-04],
        [-7.5680e-01, -7.5680e-01, 9.5358e-01, 9.5358e-01, 3.8942e-01, 3.8942e-01, 1.2615e-01, 1.2615e-01, 3.9989e-02, 3.9989e-02, 1.2649e-02, 1.2649e-02, 4.0000e-03, 4.0000e-03, 1.2649e-03, 1.2649e-03],
        [-9.5892e-01, -9.5892e-01, 9.9995e-01, 9.9995e-01, 4.7943e-01, 4.7943e-01, 1.5746e-01, 1.5746e-01, 4.9979e-02, 4.9979e-02, 1.5811e-02, 1.5811e-02, 5.0000e-03, 5.0000e-03, 1.5811e-03, 1.5811e-03]],
       dtype=torch.float64)
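The video 手推transformer (Bilibili) points out that this method is related to the Fourier transform (a detail I have not seen mentioned elsewhere, so I am noting it here).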
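One concrete way to see the frequency view: each (sin, cos) pair of dimensions is a sinusoid at its own frequency w_i = 1/10000^(2i/d_model), somewhat like the basis functions of a Fourier series. A consequence (the topic of the Timo Denk post in the references) is that moving from position p to p + k is a fixed linear map, a 2×2 rotation per frequency pair, that does not depend on p. A NumPy sketch, assuming d_model = 16 and the sin-even / cos-odd layout of the original paper; sinusoid_pe and shift_matrix are helper names made up for this example:

```python
import numpy as np


def sinusoid_pe(length, d_model):
    # even dims: sin(pos / 10000^(2i/d)), odd dims: cos(pos / 10000^(2i/d))
    pos = np.arange(length)[:, None]
    two_i = np.arange(0, d_model, 2)[None, :]
    angle = pos / np.power(10000, two_i / d_model)
    pe = np.zeros((length, d_model))
    pe[:, 0::2] = np.sin(angle)
    pe[:, 1::2] = np.cos(angle)
    return pe


d_model, k = 16, 3
pe = sinusoid_pe(20, d_model)
freqs = 1 / np.power(10000, np.arange(0, d_model, 2) / d_model)  # one frequency per (sin, cos) pair


def shift_matrix(k):
    # block-diagonal rotation: [sin(w(p+k)), cos(w(p+k))] = R_w(k) @ [sin(wp), cos(wp)]
    M = np.zeros((d_model, d_model))
    for j, w in enumerate(freqs):
        c, s = np.cos(w * k), np.sin(w * k)
        M[2 * j:2 * j + 2, 2 * j:2 * j + 2] = [[c, s], [-s, c]]
    return M


M = shift_matrix(k)
print(np.allclose(pe[5 + k], M @ pe[5]))    # True
print(np.allclose(pe[10 + k], M @ pe[10]))  # True: the same M works at every position
```

The rotation follows from sin(a + b) = sin a cos b + cos a sin b and cos(a + b) = cos a cos b - sin a sin b; this is one argument for why the model can pick up relative positions from these encodings.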
Initializing the weight matrices
new_embedding = (embedded_position + embedded_sentence).to(torch.float32)  # the positional part is float64 (from NumPy); cast the sum to float32
print('add embedding together:\n{}'.format(new_embedding))
torch.manual_seed(123)
d = new_embedding.shape[1]
print('embedding dimension:\n{}'.format(d))
d_q, d_k, d_v = 24, 24, 28
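# torch.rand: samples from a uniform distribution on [0, 1)
# torch.nn.Parameter: a plain tensor is not trained; wrapping it in Parameter makes it a trainable parameter (requires_grad=True)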
W_query = torch.nn.Parameter(torch.rand(d_q, d))
W_key = torch.nn.Parameter(torch.rand(d_k, d))
W_value = torch.nn.Parameter(torch.rand(d_v, d))
print('size of query matrix: {}'.format(W_query.shape))
print('size of key matrix: {}'.format(W_key.shape))
print('size of value matrix: {}'.format(W_value.shape))
add embedding together:
tensor([[ 0.3374, -0.1778, -0.3035, -0.5880, 0.3486, 0.6603, -0.2196, -0.3792, 0.7671, -1.1925, 0.6984, -1.4097, 0.1794, 1.8951, 0.4954, 0.2692],
        [ 1.3561, 1.8352, 0.0523, -0.7716, 0.0555, 1.7234, -2.2913, 1.1194, 0.6816, 0.7033, -0.9456, -0.0733, -0.1516, 0.1177, 0.4406, -1.4462],
        [ 1.1646, 0.3597, 1.5954, 1.4184, -0.1961, 0.6879, -0.1536, -1.6840, -1.5825, -1.0564, 0.9095, -0.7155, -0.5931, -0.7092, 0.6236, -1.3722],
        [-1.1838, 0.3195, -1.3211, 1.8650, -0.0930, -0.6388, -0.4044, -0.9919, 0.9105, 1.5842, 0.6361, -0.1660, 0.1013, -0.0905, 0.2672, -0.5841],
        [-0.8338, -1.7773, 0.7846, 1.8713, 1.9704, 1.6905, 1.4015, -0.0748, 0.5365, -1.5323, 0.9792, -1.1355, -1.1549, 0.3295, -0.6302, -2.8387],
        [-0.0821, 0.6632, -0.4780, 2.1331, -0.7409, 1.7933, 1.2108, 0.2963, 2.2973, -0.7537, -0.2650, 0.7855, -0.6546, -0.7929, 0.1854, 0.2309]],
       dtype=torch.float64)
embedding dimension:
16
size of query matrix: torch.Size([24, 16])
size of key matrix: torch.Size([24, 16])
size of value matrix: torch.Size([28, 16])
Parameter — PyTorch 2.1 documentation
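To see what torch.nn.Parameter changes, compare it with a plain tensor: only the Parameter has requires_grad=True, so gradients flow into it and an optimizer can update it (inside an nn.Module it is also registered automatically in .parameters()). A minimal sketch with the same 24×16 shape as W_query; the scalar loss is arbitrary and only there to produce a gradient:

```python
import torch

torch.manual_seed(123)
d, d_q = 16, 24
plain = torch.rand(d_q, d)                        # values uniform on [0, 1), requires_grad=False
W_query = torch.nn.Parameter(torch.rand(d_q, d))  # same kind of values, but trainable

print(plain.requires_grad, W_query.requires_grad)  # False True

x = torch.randn(d)
loss = (W_query @ x).sum()  # an arbitrary scalar that depends on W_query
loss.backward()
print(W_query.grad.shape)   # torch.Size([24, 16])
```

Since this loss is the sum over rows r of W_query[r] · x, every row of W_query.grad equals x.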
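OK, let's pause here and take stock:
- sequence: new_embedding (word embedding + position embedding)
- word embedding: 6×16 (6 tokens, each represented by a 16-dim vector)
- position embedding: 6×16 (same size as the word embedding, because the two are added)
- W_query (q): 24×16
- W_key (k): 24×16
- W_value (v): 28×16
So when a query is computed later, each token (1×16) meets W_query (24×16), and one of the two has to be transposed either way: W_query times the token as a column vector gives a 24-dim query.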
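The "one of the two has to be transposed" point, in code: applying W_query to a single token as a column vector, and applying it to all tokens at once in row layout (as the querys line in the next step does), give the same numbers. Random data stands in for new_embedding here:

```python
import torch

torch.manual_seed(0)
X = torch.randn(6, 16)         # 6 tokens x 16 dims, standing in for new_embedding
W_query = torch.randn(24, 16)  # (d_q, d), same shape as W_query above

q_one = W_query @ X[1]         # one token: (24, 16) @ (16,) -> (24,)
Q_a = W_query.matmul(X.T).T    # all tokens, column-vector style: (6, 24)
Q_b = X @ W_query.T            # all tokens, row-vector style: (6, 24)

print(Q_a.shape)                                # torch.Size([6, 24])
print(torch.allclose(Q_a, Q_b, atol=1e-5))      # True
print(torch.allclose(Q_a[1], q_one, atol=1e-5)) # True: row i is the query of token i
```

nn.Linear follows the same convention: its weight is stored as (out_features, in_features) and it computes x @ weight.T + bias.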
Computing q, k, v
# use the sequence with positional information added, the same input as querys/keys/values below
x_1 = new_embedding[0]
query_1 = W_query.matmul(x_1)
key_1 = W_key.matmul(x_1)
value_1 = W_value.matmul(x_1)

x_2 = new_embedding[1]
query_2 = W_query.matmul(x_2)
key_2 = W_key.matmul(x_2)
value_2 = W_value.matmul(x_2)
torch.matmul — PyTorch 2.1 documentation
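# row i of each result is the query / key / value of token i
# (the outputs printed further below were generated with W_key in the querys line, so their values differ)
querys = W_query.matmul(new_embedding.T).T
keys = W_key.matmul(new_embedding.T).T
values = W_value.matmul(new_embedding.T).T
print("querys.shape:", querys.shape)
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)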
querys.shape: torch.Size([6, 24])
keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])
Computing the attention scores
alpha_24 = query_2.dot(keys[4])
print(alpha_24)
For example, this is the (unscaled, not yet normalized) attention score of the 2nd query on the 5th key.
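alpha_24 is only a raw score. Before the softmax, the scores are divided by sqrt(d_k); the argument in the paper is that if the entries of q and k are independent with mean 0 and variance 1, their dot product has variance d_k, and large scores push the softmax into regions with tiny gradients. A quick numerical check of that variance claim, using standard normal vectors as the assumption:

```python
import torch

torch.manual_seed(0)
d_k, n = 24, 10_000
q = torch.randn(n, d_k)  # entries ~ N(0, 1), as assumed in the argument
k = torch.randn(n, d_k)

raw = (q * k).sum(dim=-1)  # n independent dot products q . k
scaled = raw / d_k ** 0.5

print(round(raw.var().item(), 1))     # close to d_k = 24
print(round(scaled.var().item(), 2))  # close to 1
```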
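import torch.nn.functional as F

# attention_score[i, j]: weight of key j for query i; the softmax runs over the keys (last dim),
# so every row sums to 1 and row i can be multiplied with values directly.
# The matrix printed below was computed in the transposed layout,
# F.softmax(keys.matmul(querys.T) / d_k**0.5, dim=0), where each column sums to 1 instead.
attention_score = F.softmax(querys.matmul(keys.T) / d_k**0.5, dim=-1)
print(attention_score)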
tensor([[3.8184e-01, 3.7217e-08, 2.8697e-08, 2.3739e-03, 8.8205e-04, 2.0233e-18],
        [5.1460e-03, 3.1125e-02, 3.3185e-09, 1.9323e-03, 1.3870e-07, 4.5397e-13],
        [1.8988e-04, 1.5880e-10, 9.9998e-01, 6.8005e-04, 8.3661e-03, 2.5704e-26],
        [2.1968e-04, 1.2932e-09, 9.5111e-09, 7.5759e-01, 6.8821e-05, 7.0347e-20],
        [1.5536e-02, 1.7667e-11, 2.2270e-05, 1.3099e-02, 9.9068e-01, 3.1426e-23],
        [5.9707e-01, 9.6887e-01, 1.1463e-12, 2.2433e-01, 5.2653e-07, 1.0000e+00]],
       grad_fn=<SoftmaxBackward0>)
Computing the context vectors
context_vector_2 = attention_score[2].matmul(values)
print(context_vector_2)
tensor([-2.8135, -0.2665, -0.1881, 0.4058, 0.8079, -3.1120, 0.5449, -1.2232,-0.1618, 0.3803, 0.6926, -0.4669, 0.2446, -0.3647, -0.0034, -2.2524,-2.7228, -1.5109, -0.7725, -1.0958, -2.1254, 0.3064, 0.5129, -0.1340,0.7020, -2.2086, -1.9595, 0.4520], grad_fn=<SqueezeBackward4>)
context_vector = attention_score.matmul(values)
print(context_vector)
tensor([[ 2.8488e-01, 6.4077e-01, 1.0665e+00, 5.5947e-01, -3.2868e-01, 4.2391e-01, -3.2123e-01, 1.0594e-01, 6.5982e-01, 6.1927e-01, 8.2067e-01, 4.3722e-01, 6.4925e-01, 5.9935e-01, 6.7425e-01, 3.6706e-01, 5.0318e-01, 9.9682e-02, 1.1377e-01, 1.2804e-01, 9.1880e-01, 7.6178e-01, -4.2619e-01, 2.5550e-01, -8.1348e-02, 3.1145e-01, 1.9705e-01, 3.8195e-01],
        [ 3.6250e-02, 3.7593e-02, 8.9476e-02, 9.9750e-02, 9.1430e-02, 6.2556e-02, 5.8136e-02, 5.5746e-02, 3.5098e-02, 4.1406e-02, 4.1621e-02, 1.9771e-02, 4.0799e-02, -4.7170e-03, 4.1176e-02, 4.3792e-02, 6.2029e-02, 5.2132e-02, 7.6929e-03, 5.4507e-02, 1.4537e-02, 6.9540e-02, 4.1809e-02, 5.8921e-02, 1.2542e-02, 1.4625e-01, 3.0627e-02, 1.0624e-01],
        [-2.8135e+00, -2.6652e-01, -1.8809e-01, 4.0583e-01, 8.0793e-01, -3.1120e+00, 5.4491e-01, -1.2232e+00, -1.6184e-01, 3.8030e-01, 6.9257e-01, -4.6693e-01, 2.4462e-01, -3.6468e-01, -3.3741e-03, -2.2524e+00, -2.7228e+00, -1.5109e+00, -7.7255e-01, -1.0958e+00, -2.1254e+00, 3.0638e-01, 5.1293e-01, -1.3400e-01, 7.0203e-01, -2.2086e+00, -1.9595e+00, 4.5198e-01],
        [ 1.3995e+00, -5.1583e-02, -7.6128e-01, 6.2276e-01, 1.4197e+00, -1.1195e+00, 2.6502e-01, 9.7265e-02, -1.3257e+00, 5.2765e-01, -9.0406e-01, 1.0977e+00, 1.0775e+00, -1.1202e+00, -5.3005e-01, 1.1657e+00, 5.2906e-01, -3.4296e-01, -1.0341e+00, -9.9314e-02, 2.4160e-01, 1.0506e+00, -2.5196e-01, -1.2585e+00, 7.7441e-01, -3.8052e-02, 1.4004e+00, 4.0364e-01],
        [-1.9422e+00, -1.1669e-01, 2.4155e+00, -6.0575e-01, 1.1378e-01, -8.1691e-01, 2.8678e-01, -2.6922e+00, 1.9804e+00, 2.7446e+00, 1.9828e-01, -1.5773e+00, -5.2589e-01, 2.2252e+00, -2.9130e-01, -4.2694e+00, 2.4834e+00, -3.3346e+00, -2.5167e-01, -2.8141e+00, 1.3780e+00, -1.5563e-01, -1.4588e+00, 5.3617e-01, -5.3745e-01, -7.6528e-01, 1.2408e+00, 3.5827e+00],
        [ 5.3134e+00, 3.5967e+00, 7.1373e+00, 5.9613e+00, 6.1520e+00, 5.0065e+00, 4.2107e+00, 5.2589e+00, 9.2143e-01, 6.5614e+00, 2.7412e+00, 4.6712e+00, 4.9725e+00, 2.2118e+00, 5.2451e+00, 4.4219e+00, 4.5800e+00, 2.9179e+00, 2.2116e+00, 5.3678e+00, 5.7133e+00, 7.1016e+00, 3.7317e+00, 5.1325e+00, 4.1306e+00, 9.4941e+00, 5.6733e+00, 9.7489e+00]],
       grad_fn=<MmBackward0>)
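Putting the steps together: below is a compact sketch of single-head self-attention for one sentence, using the row convention (scores[i, j] = query i against key j, softmax over the last dimension). It is cross-checked against torch.nn.functional.scaled_dot_product_attention, which is available from PyTorch 2.0 on and scales by 1/sqrt(d) by default. The function name self_attention and the random inputs are illustrative; double precision is used only so the comparison can be tight.

```python
import torch
import torch.nn.functional as F


def self_attention(x, W_query, W_key, W_value):
    """Single-head self-attention for one sentence.

    x: (seq_len, d); W_*: (d_out, d), the same layout as W_query / W_key / W_value above.
    """
    Q = x @ W_query.T                        # (seq_len, d_q)
    K = x @ W_key.T                          # (seq_len, d_k)
    V = x @ W_value.T                        # (seq_len, d_v)
    d_k = K.shape[-1]
    scores = Q @ K.T / d_k ** 0.5            # scores[i, j]: query i against key j
    weights = torch.softmax(scores, dim=-1)  # normalize over keys: each row sums to 1
    return weights @ V, weights              # (seq_len, d_v), (seq_len, seq_len)


torch.manual_seed(123)
dtype = torch.float64  # double precision only to keep the cross-check tight
x = torch.randn(6, 16, dtype=dtype)
W_query = torch.rand(24, 16, dtype=dtype)
W_key = torch.rand(24, 16, dtype=dtype)
W_value = torch.rand(28, 16, dtype=dtype)

context, weights = self_attention(x, W_query, W_key, W_value)
print(context.shape)  # torch.Size([6, 28])
print(torch.allclose(weights.sum(dim=-1), torch.ones(6, dtype=dtype)))  # True

# Cross-check with PyTorch's built-in (PyTorch >= 2.0); inputs get a leading batch dim of 1
ref = F.scaled_dot_product_attention(
    (x @ W_query.T).unsqueeze(0),
    (x @ W_key.T).unsqueeze(0),
    (x @ W_value.T).unsqueeze(0),
)
print(torch.allclose(context, ref.squeeze(0)))  # True
```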
References
Positional Encoding: Everything You Need to Know - inovex GmbH
Build your own Transformer from scratch using Pytorch | by Arjun Sarkar | Towards Data Science
Understanding and Coding the Self-Attention Mechanism of Large Language Models From Scratch (sebastianraschka.com)
2021-03-18-Transformers - Multihead Self Attention Explanation & Implementation in Pytorch.ipynb - Colaboratory (google.com)
通俗易懂的理解傅里叶变换(一)[收藏] (An intuitive introduction to the Fourier transform, Part 1) - Zhihu (zhihu.com)
Linear Relationships in the Transformer’s Positional Encoding - Timo Denk's Blog
Transformer 中的 positional embedding (Positional embedding in the Transformer) - Zhihu (zhihu.com)
transformer中使用的position embedding为什么是加法? (Why is the position embedding in the Transformer added to the input?) - Zhihu (zhihu.com)
multi-head self-attention
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i/d_model), computed in log space
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        # buffer: saved with the module and moved by .to(device), but not trained
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        # d_k is the size of each head's keys and queries; it is also used for the scaling below
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # Attention scores: one of Q and K has to be transposed, depending on the layout.
        # Here attn_scores[..., i, j] is the attention of the i-th query on the j-th key
        # (how relevant key j is to query i).
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output

    def split_heads(self, x):
        # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    def combine_heads(self, x):
        # (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output
With that, I roughly understand it orz.
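To make split_heads and combine_heads concrete: split_heads is a view followed by transpose(1, 2), and head h ends up holding feature columns h·d_k to (h+1)·d_k - 1 of every token, which is exactly the "slicing" mentioned at the start of the post. combine_heads undoes it. A standalone sketch with small, made-up sizes that reproduces the two reshapes without the surrounding class:

```python
import torch

batch, seq_len, d_model, num_heads = 2, 5, 8, 2
d_k = d_model // num_heads
x = torch.randn(batch, seq_len, d_model)

# same reshaping as split_heads: (B, L, d_model) -> (B, h, L, d_k)
heads = x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 2, 5, 4])

# head 1 holds feature columns d_k ... 2*d_k - 1 of every token
print(torch.equal(heads[:, 1], x[..., d_k:2 * d_k]))  # True

# same as combine_heads: transpose back, make memory contiguous, merge the last two dims
merged = heads.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
print(torch.equal(merged, x))  # True: the round trip is lossless
```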
view
Changing the shape
randmat = torch.rand((3, 2, 5))
print("view(2, 3, 5): \n{}".format(randmat.view(2, 3, 5)))
print("view(2, 5, 3): \n{}".format(randmat.view(2, 5, 3)))
view(2, 3, 5): 
tensor([[[0.8058, 0.3869, 0.7523, 0.1501, 0.1501],
         [0.3409, 0.5355, 0.3474, 0.8371, 0.6785],
         [0.6564, 0.8204, 0.0539, 0.7422, 0.2216]],
        [[0.9450, 0.7839, 0.7118, 0.8868, 0.4249],
         [0.1633, 0.5220, 0.7583, 0.7841, 0.0838],
         [0.4304, 0.5082, 0.3141, 0.1689, 0.0869]]])
view(2, 5, 3): 
tensor([[[0.8058, 0.3869, 0.7523],
         [0.1501, 0.1501, 0.3409],
         [0.5355, 0.3474, 0.8371],
         [0.6785, 0.6564, 0.8204],
         [0.0539, 0.7422, 0.2216]],
        [[0.9450, 0.7839, 0.7118],
         [0.8868, 0.4249, 0.1633],
         [0.5220, 0.7583, 0.7841],
         [0.0838, 0.4304, 0.5082],
         [0.3141, 0.1689, 0.0869]]])
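view only regroups the elements in their existing memory order and never moves data, so view(2, 5, 3) is not a transpose of view(2, 3, 5). This also explains the .contiguous() in combine_heads: after a transpose the memory order no longer matches the shape, and view refuses to work. A small illustration:

```python
import torch

x = torch.arange(6).view(2, 3)  # values [[0, 1, 2], [3, 4, 5]], stored in this order

print(x.view(3, 2))       # values [[0, 1], [2, 3], [4, 5]]: same memory order, regrouped
print(x.transpose(0, 1))  # values [[0, 3], [1, 4], [2, 5]]: rows and columns swapped

t = x.transpose(0, 1)     # no copy: same storage, different strides
print(t.is_contiguous())  # False

try:
    t.view(6)
    view_ok = True
except RuntimeError:      # t's strides cannot be flattened without copying
    view_ok = False
print("view on the transposed tensor works:", view_ok)  # False

print(t.contiguous().view(6))  # values [0, 3, 1, 4, 2, 5]: copied into the new order first
```

t.reshape(6) gives the same values as t.contiguous().view(6); reshape returns a view when possible and copies otherwise.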
transpose
randmat = torch.rand((3, 2, 5))
print(randmat)
print("transpose(-2,-1): \n{}".format(randmat.transpose(-2,-1)))
print("transpose(1,2): \n{}".format(randmat.transpose(1,2)))
tensor([[[0.3440, 0.9779, 0.9154, 0.6843, 0.9358],
         [0.5081, 0.7446, 0.0274, 0.6329, 0.6427]],
        [[0.6770, 0.6826, 0.2888, 0.8483, 0.9896],
         [0.1457, 0.3154, 0.6381, 0.6555, 0.2204]],
        [[0.4549, 0.0385, 0.1135, 0.8426, 0.8534],
         [0.7915, 0.4030, 0.8209, 0.3390, 0.6290]]])
transpose(-2,-1): 
tensor([[[0.3440, 0.5081],
         [0.9779, 0.7446],
         [0.9154, 0.0274],
         [0.6843, 0.6329],
         [0.9358, 0.6427]],
        [[0.6770, 0.1457],
         [0.6826, 0.3154],
         [0.2888, 0.6381],
         [0.8483, 0.6555],
         [0.9896, 0.2204]],
        [[0.4549, 0.7915],
         [0.0385, 0.4030],
         [0.1135, 0.8209],
         [0.8426, 0.3390],
         [0.8534, 0.6290]]])
transpose(1,2): 
tensor([[[0.3440, 0.5081],
         [0.9779, 0.7446],
         [0.9154, 0.0274],
         [0.6843, 0.6329],
         [0.9358, 0.6427]],
        [[0.6770, 0.1457],
         [0.6826, 0.3154],
         [0.2888, 0.6381],
         [0.8483, 0.6555],
         [0.9896, 0.2204]],
        [[0.4549, 0.7915],
         [0.0385, 0.4030],
         [0.1135, 0.8209],
         [0.8426, 0.3390],
         [0.8534, 0.6290]]])
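For a 3-D tensor, transpose(-2, -1) and transpose(1, 2) swap the same pair of dimensions, which is why the two outputs above are identical. With 4-D tensors such as the (batch, heads, seq_len, d_k) layout in MultiHeadAttention they differ: K.transpose(-2, -1) transposes each head's key matrix, while transpose(1, 2) swaps the heads and sequence dimensions as in split_heads. The sizes in this sketch are arbitrary:

```python
import torch

a3 = torch.rand(3, 2, 5)
print(torch.equal(a3.transpose(-2, -1), a3.transpose(1, 2)))  # True: for 3-D, dims 1 and 2 are the last two

a4 = torch.rand(2, 4, 6, 8)        # e.g. (batch, heads, seq_len, d_k)
print(a4.transpose(-2, -1).shape)  # torch.Size([2, 4, 8, 6]): K^T per head, as in the attention scores
print(a4.transpose(1, 2).shape)    # torch.Size([2, 6, 4, 8]): heads <-> seq_len, as in split/combine_heads
```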
References
Build your own Transformer from scratch using Pytorch | by Arjun Sarkar | Towards Data Science
Python numpy.transpose 详解 (numpy.transpose explained) - 我的明天不是梦 - cnblogs (cnblogs.com)
Appendix
"""Self-attention module

1. Read the code and explain the following:
    - The nature of the dataset
    - The data flow
    - The shapes of the tensors
    - Why can the attention module be used for this dataset?

2. Create a training loop and evaluate the model according to the instructions
"""
import copy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm.auto import tqdm


class SampleDataset(Dataset):
    def __init__(
        self,
        size: int = 1024,
        emb_dim: int = 32,
        sequence_length: int = 8,
        n_classes: int = 3,
    ):
        self.embeddings = torch.randn(size, emb_dim)
        self.sequence_length = sequence_length
        self.n_classes = n_classes

    def __len__(self):
        return len(self.embeddings) - self.sequence_length + 1

    def __getitem__(self, idx):
        indices = np.random.choice(np.arange(0, len(self.embeddings)), self.sequence_length)
        # np.random.shuffle(indices)
        return (
            self.embeddings[indices],  # sequence_length x emb_dim
            torch.tensor(np.max(indices) % self.n_classes),
        )


def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    # The length of the key and value sequences need to be the same
    d_k = query.size(-1)
    # N *
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn


def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super().__init__()
        assert d_model % heads == 0
        # We assume d_v always equals d_k
        self.d = d_model // heads  # d_model: 32 heads: 4
        self.h = heads  # h: 4
        self.linears = clones(nn.Linear(d_model, d_model), 4)  # 4 identical layers (input: 32, output: 32)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        "Implements Figure 2"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [
            l(x).view(nbatches, -1, self.h, self.d).transpose(1, 2)
            for l, x in zip(self.linears, (query, key, value))
        ]  # 4 x

        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d)
        return self.linears[-1](x)


class SequenceClassifier(nn.Module):
    def __init__(self, heads: int = 4, d_model: int = 32, n_classes: int = 3):
        super().__init__()
        self.attention = MultiHeadAttention(heads, d_model)
        self.linear = nn.Linear(d_model, n_classes)

    def forward(self, x):
        # x: N x sequence_length x emb_dim
        x = self.attention(x, x, x)
        x = self.linear(x[:, 0])
        return x


def main(
    n_epochs: int = 1000,
    size: int = 256,
    emb_dim: int = 128,
    sequence_length: int = 8,
    n_classes: int = 3,
):
    dataset = SampleDataset(size=size, emb_dim=emb_dim, sequence_length=sequence_length, n_classes=n_classes)
    # TODO: create a training loop
    # TODO: Evaluate with the same dataset
    # TODO: Evaluate with a different sequence length (12)


if __name__ == "__main__":
    main()