Transformer 代码补充

本文是对Transformer - Attention is all you need 论文阅读-CSDN博客以及【李宏毅机器学习】Transformer 内容补充-CSDN博客的补充,是对相关代码的理解。

先说个题外话,在之前李宏毅老师的课程中提到multi-head attention是把得到的qkv分别乘上不同的矩阵,得到更多的qkv。

实际上,这里采用的方法是直接截取,比如这里有两个头,那么q^i就被分成两部分q^{i,1}和q^{i,2}。在BERT Intro-CSDN博客中有解释,也推荐手推transformer_哔哩哔哩_bilibili

self-attention

本节内容是self-attention这个模块的实现,会先从某一个句子开始,先不在乎怎么组装在一起批量的处理,只是单个拆开看看每一个部件是怎么work的。

现在需要解决的是:

  1. 输入怎么embedding? 
  2. 位置信息怎么保留?
  3. 三个矩阵怎么初始化?

单个句子的attention

输入embedding

sentence = 'Life is short, eat dessert first'dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
print("dictionary: {}".format(dc))sentence_int = torch.tensor([dc[s] for s in sentence.replace(',', '').split()])
print("sentence, but words have been replaced by index in dictionary: \n{}".format(sentence_int))torch.manual_seed(123)
# len(sentence.replace(',', '').split()) == 6
# embedded length == 16
embed = torch.nn.Embedding(6, 16)
embedded_sentence = embed(sentence_int).detach()print("sentence, but word embedded: \n{}".format(embedded_sentence))
dictionary: {'Life': 0, 'dessert': 1, 'eat': 2, 'first': 3, 'is': 4, 'short': 5}
sentence, but words have been replaced by index in dictionary: 
tensor([0, 4, 5, 2, 1, 3])
sentence, but word embedded: 
tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],[ 0.5146,  0.9938, -0.2587, -1.0826, -0.0444,  1.6236, -2.3229,  1.0878,0.6716,  0.6933, -0.9487, -0.0765, -0.1526,  0.1167,  0.4403, -1.4465],[ 0.2553, -0.5496,  1.0042,  0.8272, -0.3948,  0.4892, -0.2168, -1.7472,-1.6025, -1.0764,  0.9031, -0.7218, -0.5951, -0.7112,  0.6230, -1.3729],[-1.3250,  0.1784, -2.1338,  1.0524, -0.3885, -0.9343, -0.4991, -1.0867,0.8805,  1.5542,  0.6266, -0.1755,  0.0983, -0.0935,  0.2662, -0.5850],[-0.0770, -1.0205, -0.1690,  0.9178,  1.5810,  1.3010,  1.2753, -0.2010,0.4965, -1.5723,  0.9666, -1.1481, -1.1589,  0.3255, -0.6315, -2.8400],[ 0.8768,  1.6221, -1.4779,  1.1331, -1.2203,  1.3139,  1.0533,  0.1388,2.2473, -0.8036, -0.2808,  0.7697, -0.6596, -0.7979,  0.1838,  0.2293]])

Embedding — PyTorch 2.1 documentation

位置embedding

我发现这里似乎没有一个固定的名字,有叫position embedding的,有叫position encoding的,还有positional embedding和positional encoding,排列组合orz

### position embedding
def sinusoid_positional_encoding(length, dimensions):# odd position# cos(position/100000^{2i/d_model})# even position# sin(position/100000^{2i/d_model})def get_position_angle_vec(position):return [position / np.power(10000, 2*(i//2)/dimensions) for i in range(dimensions)]PE = np.array([get_position_angle_vec(i) for i in range(length)])PE[:, 0::2] = np.sin(PE[:, 0::2])PE[:, 1::2] = np.sin(PE[:, 1::2])return PE
embedded_position = torch.tensor(sinusoid_positional_encoding(6, 16))
print("position embedding: \n{}".format(embedded_position))
position embedding: 
tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,0.0000e+00],[ 8.4147e-01,  8.4147e-01,  3.1098e-01,  3.1098e-01,  9.9833e-02,9.9833e-02,  3.1618e-02,  3.1618e-02,  9.9998e-03,  9.9998e-03,3.1623e-03,  3.1623e-03,  1.0000e-03,  1.0000e-03,  3.1623e-04,3.1623e-04],[ 9.0930e-01,  9.0930e-01,  5.9113e-01,  5.9113e-01,  1.9867e-01,1.9867e-01,  6.3203e-02,  6.3203e-02,  1.9999e-02,  1.9999e-02,6.3245e-03,  6.3245e-03,  2.0000e-03,  2.0000e-03,  6.3246e-04,6.3246e-04],[ 1.4112e-01,  1.4112e-01,  8.1265e-01,  8.1265e-01,  2.9552e-01,2.9552e-01,  9.4726e-02,  9.4726e-02,  2.9996e-02,  2.9996e-02,9.4867e-03,  9.4867e-03,  3.0000e-03,  3.0000e-03,  9.4868e-04,9.4868e-04],[-7.5680e-01, -7.5680e-01,  9.5358e-01,  9.5358e-01,  3.8942e-01,3.8942e-01,  1.2615e-01,  1.2615e-01,  3.9989e-02,  3.9989e-02,1.2649e-02,  1.2649e-02,  4.0000e-03,  4.0000e-03,  1.2649e-03,1.2649e-03],[-9.5892e-01, -9.5892e-01,  9.9995e-01,  9.9995e-01,  4.7943e-01,4.7943e-01,  1.5746e-01,  1.5746e-01,  4.9979e-02,  4.9979e-02,1.5811e-02,  1.5811e-02,  5.0000e-03,  5.0000e-03,  1.5811e-03,1.5811e-03]], dtype=torch.float64)

在手推transformer_哔哩哔哩_bilibili中提到这一方法与傅里叶变换相关(这个细节是我在其他地方没有看到的,记录一下)

初始化权重矩阵

new_embedding =(embedded_position+embedded_sentence).to(torch.float32)
print('add embedding together:\n{}'.format(new_embedding))
torch.manual_seed(123)
​
d = new_embedding.shape[1]
print('embedding dimension:\n{}'.format(d))
​
d_q, d_k, d_v = 24, 24, 28
​
# torch.rand 均匀分布 torch.nn.Parameter 普通的tensor不可训练,转换成可以训练的类型
W_query = torch.nn.Parameter(torch.rand(d_q, d))
W_key = torch.nn.Parameter(torch.rand(d_k, d))
W_value = torch.nn.Parameter(torch.rand(d_v, d))
print('size of query matrix: {}'.format(W_query.shape))
print('size of key matrix: {}'.format(W_key.shape))
print('size of value matrix: {}'.format(W_value.shape))
add embedding together:
tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  0.3486,  0.6603, -0.2196, -0.3792,0.7671, -1.1925,  0.6984, -1.4097,  0.1794,  1.8951,  0.4954,  0.2692],[ 1.3561,  1.8352,  0.0523, -0.7716,  0.0555,  1.7234, -2.2913,  1.1194,0.6816,  0.7033, -0.9456, -0.0733, -0.1516,  0.1177,  0.4406, -1.4462],[ 1.1646,  0.3597,  1.5954,  1.4184, -0.1961,  0.6879, -0.1536, -1.6840,-1.5825, -1.0564,  0.9095, -0.7155, -0.5931, -0.7092,  0.6236, -1.3722],[-1.1838,  0.3195, -1.3211,  1.8650, -0.0930, -0.6388, -0.4044, -0.9919,0.9105,  1.5842,  0.6361, -0.1660,  0.1013, -0.0905,  0.2672, -0.5841],[-0.8338, -1.7773,  0.7846,  1.8713,  1.9704,  1.6905,  1.4015, -0.0748,0.5365, -1.5323,  0.9792, -1.1355, -1.1549,  0.3295, -0.6302, -2.8387],[-0.0821,  0.6632, -0.4780,  2.1331, -0.7409,  1.7933,  1.2108,  0.2963,2.2973, -0.7537, -0.2650,  0.7855, -0.6546, -0.7929,  0.1854,  0.2309]],dtype=torch.float64)
embedding dimension:
16
size of query matrix: torch.Size([24, 16])
size of key matrix: torch.Size([24, 16])
size of value matrix: torch.Size([28, 16])

Parameter — PyTorch 2.1 documentation

OK,我们在这里先断一下,整理一下:

此时sequence:new_embedding(来源:word embedding+position embedding)

word embedding:6×16(有6个token,每个token用16维向量表示)

position embedding:6×16(和word embedding大小相同,因为要相加)

q:24×16

k:24×16

v:28×16

这样后面再计算query的时候就是每个token(1×16)×q(24×16),反正两个得转置一个

计算qkv

x_1 = embedded_sentence[0]
query_1 = W_query.matmul(x_1)
key_1 = W_key.matmul(x_1)
value_1 = W_value.matmul(x_1)x_2 = embedded_sentence[1]
query_2 = W_query.matmul(x_2)
key_2 = W_key.matmul(x_2)
value_2 = W_value.matmul(x_2)

torch.matmul — PyTorch 2.1 documentation

querys = W_key.matmul(new_embedding.T).T
keys = W_key.matmul(new_embedding.T).T
values = W_value.matmul(new_embedding.T).Tprint("querys.shape:", querys.shape)
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)
querys.shape: torch.Size([6, 24])
keys.shape: torch.Size([6, 24])
values.shape: torch.Size([6, 28])

计算attention score

alpha_24 = query_2.dot(keys[4])
print(alpha_24)

比如这里,就是第2个query对第5个key的attention

import torch.nn.functional as Fattention_score = F.softmax(keys.matmul(querys.T) / d_k**0.5, dim=0)
print(attention_score)
tensor([[3.8184e-01, 3.7217e-08, 2.8697e-08, 2.3739e-03, 8.8205e-04, 2.0233e-18],[5.1460e-03, 3.1125e-02, 3.3185e-09, 1.9323e-03, 1.3870e-07, 4.5397e-13],[1.8988e-04, 1.5880e-10, 9.9998e-01, 6.8005e-04, 8.3661e-03, 2.5704e-26],[2.1968e-04, 1.2932e-09, 9.5111e-09, 7.5759e-01, 6.8821e-05, 7.0347e-20],[1.5536e-02, 1.7667e-11, 2.2270e-05, 1.3099e-02, 9.9068e-01, 3.1426e-23],[5.9707e-01, 9.6887e-01, 1.1463e-12, 2.2433e-01, 5.2653e-07, 1.0000e+00]],grad_fn=<SoftmaxBackward0>)

获得context value

context_vector_2 = attention_score[2].matmul(values)
print(context_vector_2)
tensor([-2.8135, -0.2665, -0.1881,  0.4058,  0.8079, -3.1120,  0.5449, -1.2232,-0.1618,  0.3803,  0.6926, -0.4669,  0.2446, -0.3647, -0.0034, -2.2524,-2.7228, -1.5109, -0.7725, -1.0958, -2.1254,  0.3064,  0.5129, -0.1340,0.7020, -2.2086, -1.9595,  0.4520], grad_fn=<SqueezeBackward4>)
context_vector = attention_score.matmul(values)
print(context_vector)
tensor([[ 2.8488e-01,  6.4077e-01,  1.0665e+00,  5.5947e-01, -3.2868e-01,4.2391e-01, -3.2123e-01,  1.0594e-01,  6.5982e-01,  6.1927e-01,8.2067e-01,  4.3722e-01,  6.4925e-01,  5.9935e-01,  6.7425e-01,3.6706e-01,  5.0318e-01,  9.9682e-02,  1.1377e-01,  1.2804e-01,9.1880e-01,  7.6178e-01, -4.2619e-01,  2.5550e-01, -8.1348e-02,3.1145e-01,  1.9705e-01,  3.8195e-01],[ 3.6250e-02,  3.7593e-02,  8.9476e-02,  9.9750e-02,  9.1430e-02,6.2556e-02,  5.8136e-02,  5.5746e-02,  3.5098e-02,  4.1406e-02,4.1621e-02,  1.9771e-02,  4.0799e-02, -4.7170e-03,  4.1176e-02,4.3792e-02,  6.2029e-02,  5.2132e-02,  7.6929e-03,  5.4507e-02,1.4537e-02,  6.9540e-02,  4.1809e-02,  5.8921e-02,  1.2542e-02,1.4625e-01,  3.0627e-02,  1.0624e-01],[-2.8135e+00, -2.6652e-01, -1.8809e-01,  4.0583e-01,  8.0793e-01,-3.1120e+00,  5.4491e-01, -1.2232e+00, -1.6184e-01,  3.8030e-01,6.9257e-01, -4.6693e-01,  2.4462e-01, -3.6468e-01, -3.3741e-03,-2.2524e+00, -2.7228e+00, -1.5109e+00, -7.7255e-01, -1.0958e+00,-2.1254e+00,  3.0638e-01,  5.1293e-01, -1.3400e-01,  7.0203e-01,-2.2086e+00, -1.9595e+00,  4.5198e-01],[ 1.3995e+00, -5.1583e-02, -7.6128e-01,  6.2276e-01,  1.4197e+00,-1.1195e+00,  2.6502e-01,  9.7265e-02, -1.3257e+00,  5.2765e-01,-9.0406e-01,  1.0977e+00,  1.0775e+00, -1.1202e+00, -5.3005e-01,1.1657e+00,  5.2906e-01, -3.4296e-01, -1.0341e+00, -9.9314e-02,2.4160e-01,  1.0506e+00, -2.5196e-01, -1.2585e+00,  7.7441e-01,-3.8052e-02,  1.4004e+00,  4.0364e-01],[-1.9422e+00, -1.1669e-01,  2.4155e+00, -6.0575e-01,  1.1378e-01,-8.1691e-01,  2.8678e-01, -2.6922e+00,  1.9804e+00,  2.7446e+00,1.9828e-01, -1.5773e+00, -5.2589e-01,  2.2252e+00, -2.9130e-01,-4.2694e+00,  2.4834e+00, -3.3346e+00, -2.5167e-01, -2.8141e+00,1.3780e+00, -1.5563e-01, -1.4588e+00,  5.3617e-01, -5.3745e-01,-7.6528e-01,  1.2408e+00,  3.5827e+00],[ 5.3134e+00,  3.5967e+00,  7.1373e+00,  5.9613e+00,  6.1520e+00,5.0065e+00,  4.2107e+00,  5.2589e+00,  9.2143e-01,  6.5614e+00,2.7412e+00,  4.6712e+00,  4.9725e+00,  2.2118e+00,  5.2451e+00,4.4219e+00,  4.5800e+00,  2.9179e+00,  2.2116e+00,  5.3678e+00,5.7133e+00,  7.1016e+00,  3.7317e+00,  5.1325e+00,  4.1306e+00,9.4941e+00,  5.6733e+00,  9.7489e+00]], grad_fn=<MmBackward0>)

参考链接

Positional Encoding: Everything You Need to Know - inovex GmbH
Build your own Transformer from scratch using Pytorch | by Arjun Sarkar | Towards Data Science
Understanding and Coding the Self-Attention Mechanism of Large Language Models From Scratch (sebastianraschka.com)
2021-03-18-Transformers - Multihead Self Attention Explanation & Implementation in Pytorch.ipynb - Colaboratory (google.com)
通俗易懂的理解傅里叶变换(一)[收藏] - 知乎 (zhihu.com)
Linear Relationships in the Transformer’s Positional Encoding - Timo Denk's Blog
Transformer 中的 positional embedding - 知乎 (zhihu.com)
transformer中使用的position embedding为什么是加法? - 知乎 (zhihu.com)

multi-head self-attention

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copyclass PositionalEncoding(nn.Module):def __init__(self, d_model, max_seq_length):super(PositionalEncoding, self).__init__()pe = torch.zeros(max_seq_length, d_model)position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)self.register_buffer('pe', pe.unsqueeze(0))def forward(self, x):return x + self.pe[:, :x.size(1)]
class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()assert d_model % num_heads == 0, "d_model must be divisible by num_heads"self.d_model = d_modelself.num_heads = num_headsself.d_k = d_model // num_heads# 这里d_k是每个key和query的size,同时在后面归一化也需要使用self.W_q = nn.Linear(d_model, d_model)self.W_k = nn.Linear(d_model, d_model)self.W_v = nn.Linear(d_model, d_model)self.W_o = nn.Linear(d_model, d_model)def scaled_dot_product_attention(self, Q, K, V, mask=None):# 计算attention score,Q和K反正得转置一个,看怎么定义# 比如现在的attn_scores的第(i,j)位置表示:# 第i个query对第k个key的attention(相关性高低)attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)if mask is not None:attn_scores = attn_scores.masked_fill(mask == 0, -1e9)attn_probs = torch.softmax(attn_scores, dim=-1)output = torch.matmul(attn_probs, V)return outputdef split_heads(self, x):batch_size, seq_length, d_model = x.size()return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)def combine_heads(self, x):batch_size, _, seq_length, d_k = x.size()return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)def forward(self, Q, K, V, mask=None):Q = self.split_heads(self.W_q(Q))K = self.split_heads(self.W_k(K))V = self.split_heads(self.W_v(V))attn_output = self.scaled_dot_product_attention(Q, K, V, mask)output = self.W_o(self.combine_heads(attn_output))return output

这样就大概看懂了orz。

view

维度变化

randmat = torch.rand((3, 2, 5))
print("view(2, 3, 5): \n{}".format(randmat.view(2,3,5)))
print("view(2, 5, 3): \n{}".format(randmat.view(2,3,5)))
view(2, 3, 5): 
tensor([[[0.8058, 0.3869, 0.7523, 0.1501, 0.1501],[0.3409, 0.5355, 0.3474, 0.8371, 0.6785],[0.6564, 0.8204, 0.0539, 0.7422, 0.2216]],[[0.9450, 0.7839, 0.7118, 0.8868, 0.4249],[0.1633, 0.5220, 0.7583, 0.7841, 0.0838],[0.4304, 0.5082, 0.3141, 0.1689, 0.0869]]])
view(2, 5, 3): 
tensor([[[0.8058, 0.3869, 0.7523, 0.1501, 0.1501],[0.3409, 0.5355, 0.3474, 0.8371, 0.6785],[0.6564, 0.8204, 0.0539, 0.7422, 0.2216]],[[0.9450, 0.7839, 0.7118, 0.8868, 0.4249],[0.1633, 0.5220, 0.7583, 0.7841, 0.0838],[0.4304, 0.5082, 0.3141, 0.1689, 0.0869]]])
view(5, 2, 3): 
tensor([[[0.8058, 0.3869, 0.7523],[0.1501, 0.1501, 0.3409],[0.5355, 0.3474, 0.8371],[0.6785, 0.6564, 0.8204],[0.0539, 0.7422, 0.2216]],[[0.9450, 0.7839, 0.7118],[0.8868, 0.4249, 0.1633],[0.5220, 0.7583, 0.7841],[0.0838, 0.4304, 0.5082],[0.3141, 0.1689, 0.0869]]])

transpose

randmat = torch.rand((3, 2, 5))
print(randmat)
print("tanspose(-2,-1): \n{}".format(randmat.transpose(-2,-1)))
print("transpose(1,2): \n{}".format(randmat.transpose(1,2)))
tensor([[[0.3440, 0.9779, 0.9154, 0.6843, 0.9358],[0.5081, 0.7446, 0.0274, 0.6329, 0.6427]],[[0.6770, 0.6826, 0.2888, 0.8483, 0.9896],[0.1457, 0.3154, 0.6381, 0.6555, 0.2204]],[[0.4549, 0.0385, 0.1135, 0.8426, 0.8534],[0.7915, 0.4030, 0.8209, 0.3390, 0.6290]]])
tanspose(-2,-1): 
tensor([[[0.3440, 0.5081],[0.9779, 0.7446],[0.9154, 0.0274],[0.6843, 0.6329],[0.9358, 0.6427]],[[0.6770, 0.1457],[0.6826, 0.3154],[0.2888, 0.6381],[0.8483, 0.6555],[0.9896, 0.2204]],[[0.4549, 0.7915],[0.0385, 0.4030],[0.1135, 0.8209],[0.8426, 0.3390],[0.8534, 0.6290]]])
transpose(1,2): 
tensor([[[0.3440, 0.5081],[0.9779, 0.7446],[0.9154, 0.0274],[0.6843, 0.6329],[0.9358, 0.6427]],[[0.6770, 0.1457],[0.6826, 0.3154],[0.2888, 0.6381],[0.8483, 0.6555],[0.9896, 0.2204]],[[0.4549, 0.7915],[0.0385, 0.4030],[0.1135, 0.8209],[0.8426, 0.3390],[0.8534, 0.6290]]])

参考链接

Build your own Transformer from scratch using Pytorch | by Arjun Sarkar | Towards Data Science
Python numpy.transpose 详解 - 我的明天不是梦 - 博客园 (cnblogs.com)

附录

"""Self-attention module1. Read the code and explain the following:- The nature of the dataset- The data flow- The shapes of the tensors- Why can the attention module be used for this dataset?
2. Create a training loop and evaluate the model according to the instructions
"""
import copy
import mathimport torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm.auto import tqdmclass SampleDataset(Dataset):def __init__(self,size: int = 1024,emb_dim: int = 32,sequence_length: int = 8,n_classes: int = 3,):self.embeddings = torch.randn(size, emb_dim)self.sequence_length = sequence_lengthself.n_classes = n_classesdef __len__(self):return len(self.embeddings) - self.sequence_length + 1def __getitem__(self, idx):indices = np.random.choice(np.arange(0, len(self.embeddings)), self.sequence_length)# np.random.shuffle(indices)return (self.embeddings[indices],  # sequence_length x emb_dimtorch.tensor(np.max(indices) % self.n_classes),)def attention(query, key, value, mask=None, dropout=None):"Compute 'Scaled Dot Product Attention'"# The length of the key and value sequences need to be the samed_k = query.size(-1)# N *scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)p_attn = F.softmax(scores, dim=-1)if dropout is not None:p_attn = dropout(p_attn)return torch.matmul(p_attn, value), p_attndef clones(module, N):"Produce N identical layers."return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])class MultiHeadAttention(nn.Module):def __init__(self, heads, d_model, dropout=0.1):"Take in model size and number of heads."super().__init__()assert d_model % heads == 0# We assume d_v always equals d_kself.d = d_model // heads # d_model: 32 heads: 4self.h = heads # h: 4self.linears = clones(nn.Linear(d_model, d_model), 4) # 4 identical layers (input: 32, output: 32)self.attn = Noneself.dropout = nn.Dropout(p=dropout)def forward(self, query, key, value, mask=None):"Implements Figure 2"if mask is not None:# Same mask applied to all h heads.mask = mask.unsqueeze(1)nbatches = query.size(0)# 1) Do all the linear projections in batch from d_model => h x d_kquery, key, value = [l(x).view(nbatches, -1, self.h, self.d).transpose(1, 2)for l, x in zip(self.linears, (query, key, value))] # 4 x # 2) Apply attention on all the projected vectors in batch.x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)# 3) "Concat" using a view and apply a final linear.x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d)return self.linears[-1](x)class SequenceClassifier(nn.Module):def __init__(self, heads: int = 4, d_model: int = 32, n_classes: int = 3):super().__init__()self.attention = MultiHeadAttention(heads, d_model)self.linear = nn.Linear(d_model, n_classes)def forward(self, x):# x: N x sequence_length x emb_dimx = self.attention(x, x, x)x = self.linear(x[:, 0])return xdef main(n_epochs: int = 1000,size: int = 256,emb_dim: int = 128,sequence_length: int = 8,n_classes: int = 3,
):dataset = SampleDataset(size=size, emb_dim=emb_dim, sequence_length=sequence_length, n_classes=n_classes)# TODO: create a training loop# TODO: Evaluate with the same dataset# TODO: Evaluate with a different sequence length (12)if __name__ == "__main__":main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/658316.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Java】Lombok的使用

一、Lombok是什么&#xff1f; Lombok是一个Java库&#xff0c;能自动插入编辑器并构建工具&#xff0c;简化Java开发。通过添加注解的方式&#xff0c;不需要为类编写getter或eques方法&#xff0c;同时可以自动化日志变量&#x1f680; 在我们封装一个类时&#xff0c;最常用…

CUDA编程- - GPU线程的理解 thread,block,grid - 学习记录

GPU线程的理解 thread,block,grid 一、从 cpu 多线程角度理解 gpu 多线程1、cpu 多线程并行加速2、gpu多线程并行加速2.1、cpu 线程与 gpu 线程的理解&#xff08;核函数&#xff09;2.1.1 、第一步&#xff1a;编写核函数2.1.2、第二步&#xff1a;调用核函数&#xff08;使用…

网络代理用途

网络代理的用途广泛&#xff0c;常用于代理爬虫&#xff0c;代理VPN &#xff0c;代理注入等。使用网络 代理能够将入侵痕迹进一步减少&#xff0c;能够突破自身IP的访问限制&#xff0c;提高访问速度&#xff0c; 以及隐藏真实IP &#xff0c;还能起到一定的防止攻击的作用。下…

ROS方向第二次汇报(2)

文章目录 1.本方向内学习内容&#xff1a;1.1.动作&#xff1a;1.1.1.案例接口定义:1.1.2.案例通信模型&#xff1a;1.1.3.服务器端代码&#xff1a;1.1.4.客户端源代码&#xff1a;1.1.5.动作命令行操作&#xff1a; 1.2.参数&#xff1a;1.2.1.查看参数列表&#xff1a;1.2.2…

dvwa靶场xss储存型

xss储存型 xxs储存型lowmessage框插入恶意代码name栏插入恶意代码 medium绕过方法 high xxs储存型 攻击者事先将恶意代码上传或储存到漏洞服务器中&#xff0c;只要受害者浏览包含此恶意代码的页面就会执行恶意代码。产生层面:后端漏洞特征:持久性的、前端执行、储存在后端数据…

4JS表达式和运算符expression and operator

表达式&#xff08;expression&#xff09;JavaScript中的一个短语&#xff0c;JavaScript解释器会将其计算&#xff08;evaluate&#xff09;出一个结果。程序中的常量是最简单的一类表达式。变量名也是一种简单的表达式&#xff0c;它的值就是赋值给变量的值。复杂表达式是由…

类与对象(中篇)

1、类的6个默认成员函数 如果一个类中什么成员都没有,简称为空类。 空类中真的什么都没有吗?并不是,任何类在什么都不写时,编译器会自动生成以下6个默认成员函数。 默认成员函数:用户没有显式实现,编译器会生成的成员函数称为默认成员函数。 2、构造函数---初始…

Java技术栈 —— Spring MVC 与 Spring Boot

参考文章或视频链接[1] Spring vs. Spring Boot vs. Spring MVC[2] Key Differences Between Spring vs Spring Boot vs Spring MVC

CRM系统主要干什么的

阅读本文&#xff0c;你将了解&#xff1a;一、CRM系统是什么&#xff1b;二、CRM系统主要干什么的&#xff1b;三、CRM系统在企业管理中的重要作用&#xff1b;四、企业落地案例分享——大吉包装。 本文所提及的功能演示和图片内容均来自于我们公司正在使用的简道云CRM系统&a…

【JVM】类加载流程

目录 1.加载 2.链接 &#xff08;1&#xff09;校验 &#xff08;2&#xff09;准备 &#xff08;3&#xff09;解析 3.初始化 4.使用 5.卸载 1.加载 加载阶段&#xff0c;简言之&#xff0c;查找并加载类的二进制数据&#xff0c;生成 Class 的实例 在加载类时&#x…

【C++航海王:追寻罗杰的编程之路】引用、内联、auto关键字、基于范围的for、指针空值nullptr

目录 1 -> 引用 1.1 -> 引用概念 1.2 -> 引用特性 1.3 -> 常引用 1.4 -> 使用场景 1.5 -> 传值、传引用效率比较 1.6 -> 值和引用作为返回值类型的性能比较 1.7 -> 引用和指针的区别 2 -> 内联函数 2.1 -> 概念 2.2 -> 特性 3 -…

Linux ---- Shell编程三剑客之AWK

一、awk处理文本工具 1、awk概述 awk 是一种处理文本文件的语言&#xff0c;是一个强大的文本分析工具。AWK是专门为文本处理设计的编程语言&#xff0c;也是行处理软件&#xff0c;通常用于扫描、过滤、统计汇总工作。用来处理列。数据可以来自标准输入也可以是管道或文件。…

共享的IP隔一段时间就变?用这种方法可以不需要知道电脑IP

前言 一般来说,电脑接入路由器之后,IP是由路由器自动分配的(DHCP),但如果隔一段时间不开机连接路由器,或者更换了别的网卡进行连接,自动分配的IP就会更改。 比如你手机连接着电脑的共享IP:192.168.1.10,但过段时间之后,电脑的IP突然变成了192.168.1.11,那么你的所有…

Qt+css绘制标题

之前学过html和小程序&#xff0c;帮老师做项目的时候也用过vue&#xff0c;在想qt绘制界面是不是也可以使用css,然后查了一些资料&#xff0c;绘制了一个标题&#xff0c;准备用到智能家居的上位机上面。 成果 源码 重写了paintEvent函数和TimeEvent函数&#xff0c;一个用于绘…

幻兽帕鲁服务器多少钱一个?26元,阿里云腾讯云华为云

2024年幻兽帕鲁服务器价格表更新&#xff0c;阿里云、腾讯云和华为云Palworld服务器报价大全&#xff0c;4核16G幻兽帕鲁专用服务器阿里云26元、腾讯云32元、华为云26元&#xff0c;阿腾云atengyun.com分享幻兽帕鲁服务器优惠价格表&#xff0c;多配置报价&#xff1a; 幻兽帕鲁…

透明拼接屏造型:多样拼接与影响因素

透明拼接屏&#xff0c;以其独特的透明显示效果和灵活的拼接方式&#xff0c;在现代显示领域中独树一帜。其造型多样&#xff0c;包括横屏拼接、竖屏拼接、异形拼接以及定制拼接等多种方式&#xff0c;满足了不同场景和应用的需求。尼伽小编将详细介绍这些拼接方式&#xff0c;…

PR转场模板|超级炫酷故障特效电影游戏视频转场PR模板剪辑素材

premiere转场&#xff0c;包含200个带有Sound FX的独特视频转场效果。加强剪辑视频视觉效果&#xff0c;在镜头之间的剪辑和添加文字动画&#xff01; MYFX Extension可帮助您一键浏览和应用预设&#xff01;可以喜爱预设&#xff0c;并拥有自己亲手挑选的库。如果您有任何问题…

字符数组的学习

前言&#xff1a; 在前面我们介绍过字符型数据是以字符的ASCII码储存在存储单元中&#xff0c;一般占一个字节&#xff0c;由于 ASCII码也属于整数类型&#xff0c;因此在C99标准中把字符类型归纳为整数类型中的一种&#xff0c;由于字符数据 的应用比较广泛&#xff0c;尤其…

抽象类(Java)、模板方法设计模式

一、概念 在Java中有abstract关键字&#xff0c;就是抽象的意思&#xff0c;可用来修饰类和成员方法。 用abstract来修饰类&#xff0c;那这个类就是抽象类&#xff1b;修饰方法&#xff0c;那这个方法就是抽象方法。 修饰符 abstract class 类名{修饰符 abstract 返回值类型…

【数据结构 02】队列

一、原理 队列通常是链表结构&#xff0c;只允许在一端进行数据插入&#xff0c;在另一端进行数据删除。 队列的特性是链式存储&#xff08;随机增删&#xff09;和先进先出&#xff08;FIFO&#xff1a;First In First Out&#xff09;。 队列的缺陷&#xff1a; 不支持随机…