用例子和代码了解词嵌入和位置编码

1.嵌入(Input Embedding)

让我用一个更具体的例子来解释输入嵌入(Input Embedding)。

背景

假设我们有一个非常小的词汇表,其中包含以下 5 个词:

  • "I"
  • "love"
  • "machine"
  • "learning"
  • "!"

假设我们想把这句话 "I love machine learning !" 作为输入。

步骤 1:创建词汇表(Vocabulary)

我们给每个词分配一个唯一的索引号:

  • "I" -> 0
  • "love" -> 1
  • "machine" -> 2
  • "learning" -> 3
  • "!" -> 4
步骤 2:创建嵌入矩阵(Embedding Matrix)

假设我们选择每个词的向量维度为 3(实际应用中维度会更高)。我们初始化一个大小为 5x3 的嵌入矩阵,如下所示:

嵌入矩阵(Embedding Matrix):
[[0.1, 0.2, 0.3],  // "I" 的向量表示[0.4, 0.5, 0.6],  // "love" 的向量表示[0.7, 0.8, 0.9],  // "machine" 的向量表示[1.0, 1.1, 1.2],  // "learning" 的向量表示[1.3, 1.4, 1.5]   // "!" 的向量表示
]
步骤 3:查找表操作(Lookup Table Operation)

当我们输入句子 "I love machine learning !" 时,我们首先将每个词转换为其对应的索引:

  • "I" -> 0
  • "love" -> 1
  • "machine" -> 2
  • "learning" -> 3
  • "!" -> 4

然后,我们使用这些索引在嵌入矩阵中查找相应的向量表示:

输入句子嵌入表示:
[[0.1, 0.2, 0.3],  // "I" 的向量表示[0.4, 0.5, 0.6],  // "love" 的向量表示[0.7, 0.8, 0.9],  // "machine" 的向量表示[1.0, 1.1, 1.2],  // "learning" 的向量表示[1.3, 1.4, 1.5]   // "!" 的向量表示
]
步骤 4:输入嵌入过程

        通过查找表操作,我们把原本的句子 "I love machine learning !" 转换成了一个二维数组,每一行是一个词的嵌入向量。

码示例代

让我们用 Python 和 PyTorch 来实现这个过程:

import torch
import torch.nn as nn# 假设词汇表大小为 5,嵌入维度为 3
vocab_size = 5
embedding_dim = 3# 创建一个嵌入层
embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)# 初始化嵌入矩阵(为了便于理解,这里手动设置嵌入矩阵的值)
embedding_layer.weight = nn.Parameter(torch.tensor([[0.1, 0.2, 0.3],  # "I"[0.4, 0.5, 0.6],  # "love"[0.7, 0.8, 0.9],  # "machine"[1.0, 1.1, 1.2],  # "learning"[1.3, 1.4, 1.5]   # "!"
]))# 输入句子对应的索引
input_indices = torch.tensor([0, 1, 2, 3, 4])# 获取输入词的嵌入表示
embedded = embedding_layer(input_indices)print(embedded)

输出: 

tensor([[0.1000, 0.2000, 0.3000],[0.4000, 0.5000, 0.6000],[0.7000, 0.8000, 0.9000],[1.0000, 1.1000, 1.2000],[1.3000, 1.4000, 1.5000]], grad_fn=<EmbeddingBackward>)

        这样我们就完成了输入嵌入的过程,把离散的词转换为了连续的向量表示。

        当你完成了词嵌入,将离散的词转换为连续的向量表示后,位置编码步骤如下:

2. 理解位置编码

        位置编码(Positional Encoding)通过生成一组特殊的向量,表示词在序列中的位置,并将这些向量添加到词嵌入上,使模型能够识别词序。

2.1 位置编码公式

        位置编码使用正弦和余弦函数生成。具体公式如下:

其中:

  •  是词在序列中的位置。
  •  i是词嵌入向量的维度索引。
  •  d是词嵌入向量的总维度。

2.2 生成位置编码向量

        以下是 Python 代码示例,展示如何生成位置编码向量,并将其添加到词嵌入上:

        生成位置编码向量

import numpy as np
import torchdef get_positional_encoding(max_len, d_model):"""生成位置编码向量:param max_len: 序列的最大长度:param d_model: 词嵌入向量的维度:return: 形状为 (max_len, d_model) 的位置编码矩阵"""pos = np.arange(max_len)[:, np.newaxis]i = np.arange(d_model)[np.newaxis, :]angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))angle_rads = pos * angle_rates# 采用正弦函数应用于偶数索引 (2i)angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])# 采用余弦函数应用于奇数索引 (2i+1)angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])return torch.tensor(angle_rads, dtype=torch.float32)# 示例参数
max_len = 100  # 假设最大序列长度为 100
d_model = 512  # 假设词嵌入维度为 512# 生成位置编码矩阵
positional_encoding = get_positional_encoding(max_len, d_model)
print(positional_encoding.shape)  # 输出: torch.Size([100, 512])

2.3 添加位置编码到词嵌入

        假设你已经有一个词嵌入张量 embedded,它的形状为 (batch_size, seq_len, d_model),可以将位置编码添加到词嵌入中:

class TransformerEmbedding(nn.Module):def __init__(self, vocab_size, d_model, max_len):super(TransformerEmbedding, self).__init__()self.token_embedding = nn.Embedding(vocab_size, d_model)self.positional_encoding = get_positional_encoding(max_len, d_model)self.dropout = nn.Dropout(p=0.1)def forward(self, x):# 获取词嵌入token_embeddings = self.token_embedding(x)# 添加位置编码seq_len = x.size(1)position_embeddings = self.positional_encoding[:seq_len, :]# 词嵌入和位置编码相加embeddings = token_embeddings + position_embeddings.unsqueeze(0)return self.dropout(embeddings)# 示例参数
vocab_size = 10000  # 假设词汇表大小为 10000
d_model = 512       # 词嵌入维度
max_len = 100       # 最大序列长度# 实例化嵌入层
embedding_layer = TransformerEmbedding(vocab_size, d_model, max_len)# 假设输入序列为一批大小为 2,序列长度为 10 的张量
input_tensor = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],[10, 9, 8, 7, 6, 5, 4, 3, 2, 1]], dtype=torch.long)# 获取嵌入表示
output_embeddings = embedding_layer(input_tensor)
print(output_embeddings.shape)  # 输出: torch.Size([2, 10, 512])

2.4. 继续进行 Transformer 模型的前向传播

        有了词嵌入和位置编码之后,接下来的步骤就是将这些嵌入输入到 Transformer 模型的编码器和解码器中,进行进一步处理。Transformer 模型的编码器和解码器由多层注意力机制和前馈神经网络组成。

        位置编码步骤通过生成一组正弦和余弦函数的向量,并将这些向量添加到词嵌入上,使 Transformer 模型能够捕捉序列中的位置信息。

import torch
import torch.nn as nnclass MultiHeadSelfAttention(nn.Module):def __init__(self, d_model, nhead):super(MultiHeadSelfAttention, self).__init__()assert d_model % nhead == 0, "d_model 必须能被 nhead 整除"self.d_model = d_modelself.d_k = d_model // nheadself.nhead = nheadself.W_q = nn.Linear(d_model, d_model)self.W_k = nn.Linear(d_model, d_model)self.W_v = nn.Linear(d_model, d_model)self.fc = nn.Linear(d_model, d_model)self.dropout = nn.Dropout(0.1)self.scale = torch.sqrt(torch.FloatTensor([self.d_k]))def forward(self, x):batch_size = x.size(0)seq_len = x.size(1)# 线性变换得到 Q, K, VQ = self.W_q(x)K = self.W_k(x)V = self.W_v(x)# 分成多头Q = Q.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)K = K.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)V = V.view(batch_size, seq_len, self.nhead, self.d_k).transpose(1, 2)# 计算注意力权重attn_weights = torch.matmul(Q, K.transpose(-2, -1)) / self.scaleattn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)attn_weights = self.dropout(attn_weights)# 加权求和attn_output = torch.matmul(attn_weights, V)# 拼接多头输出attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)# 最后的线性变换output = self.fc(attn_output)return output# 示例参数
d_model = 8
nhead = 2# 输入张量
x = torch.rand(2, 5, d_model)# 实例化多头自注意力层
multi_head_attn = MultiHeadSelfAttention(d_model, nhead)# 前向传播
output = multi_head_attn(x)
print("多头自注意力输出:\n", output)

解释

  • 线性变换:使用 nn.Linear 实现线性变换,将输入张量  通过三个不同的线性层得到查询、键和值向量。
  • 分成多头:使用 view 和 transpose 方法将查询、键和值向量分成多头,形状变为 。
  • 计算注意力权重:通过点积计算查询和键的相似度,并通过 softmax 归一化得到注意力权重。
  • 加权求和:使用注意力权重对值向量进行加权求和,得到每个头的输出。
  • 拼接多头输出:将多头的输出拼接起来,并通过一个线性层进行变换,得到最终的输出。

        查询、键和值向量的生成是多头自注意力机制的关键步骤,通过线性变换将输入向量转换为查询、键和值向量,然后使用这些向量计算注意力权重,捕捉输入序列中不同位置的相关性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/39107.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

10 Posix API与网络协议栈

POSIX概念 POSIX是由IEEE指定的一系列标准,用于澄清和统一Unix-y操作系统提供的应用程序编程接口(以及辅助问题,如命令行shell实用程序),当您编写程序以依赖POSIX标准时,您可以非常肯定能够轻松地将它们移植到大量的Unix衍生产品系列中(包括Linux,但不限于此!)。 如…

DeepFaceLive----AI换脸简单使用

非常强大的软件,官方github https://github.com/iperov/DeepFaceLive 百度云链接: 链接&#xff1a;https://pan.baidu.com/s/1VHY-wxqJXSh5lCn1c4whZg 提取码&#xff1a;nhev 1下载解压软件 下载完成后双击.exe文件进行解压.完成后双击.bat文件打开软件 2 视频使用图片换…

k8s部署单机版mysql8

一、创建命名空间 # cat mysql8-namespace.yaml apiVersion: v1 kind: Namespace metadata:name: mysql8labels:name: mysql8# kubectl apply -f mysql8-namespace.yaml namespace/mysql8 created# kubectl get ns|grep mysql8 mysql8 Active 8s二、创建mysql配…

SSM学习4:spring整合mybatis、spring整合Junit

spring整合mybatis 之前的内容是有service层&#xff08;业务实现层&#xff09;、dao层&#xff08;操作数据库&#xff09;&#xff0c;现在新添加一个domain&#xff08;与业务相关的实体类&#xff09; 依赖配置 pom.xml <?xml version"1.0" encoding&quo…

2.2.3 C#中显示控件BDPictureBox 的实现----控件实现

2.2.3 C#中显示控件BDPictureBox 的实现----控件实现 1 界面控件布局 2图片内存Mat类说明 原始图片&#xff1a;m_raw_mat ,Display_Mat()调用时更新或者InitDisplay_Mat时更新局部放大显示图片&#xff1a;m_extract_zoom_mat&#xff0c;更新scale和scroll信息后更新overla…

2024年精选100道软件测试面试题(内含文档)

测试技术面试题 1、我现在有个程序&#xff0c;发现在 Windows 上运行得很慢&#xff0c;怎么判别是程序存在问题还是软硬件系统存在问题&#xff1f; 2、什么是兼容性测试&#xff1f;兼容性测试侧重哪些方面&#xff1f; 3、测试的策略有哪些&#xff1f; 4、正交表测试用…

市场规模5万亿,护理员缺口550万,商业护理企业如何解决服务供给难题?

干货抢先看 1. 据统计&#xff0c;我国失能、半失能老人数量约4400万&#xff0c;商业护理服务市场规模达5万亿。然而&#xff0c;当前养老护理员缺口巨大&#xff0c;人员的供需不匹配是很多养老服务企业需要克服的难题。 2. 当前居家护理服务的主要市场参与者分为两类&…

利用GPT 将 matlab 内置 bwlookup 函数转C

最近业务需要将 matlab中bwlookup 的转C 这个函数没有现成的m文件参考&#xff0c;内置已经打成库了&#xff0c;所以没有参考源代码 但是它的解释还是很清楚的&#xff0c;可以根据这个来写 Nonlinear filtering using lookup tables - MATLAB bwlookup - MathWorks 中国 A…

python请求报错::requests.exceptions.ProxyError: HTTPSConnectionPool

在发送网页请求时&#xff0c;发现很久未响应&#xff0c;最后报错&#xff1a; requests.exceptions.ProxyError: HTTPSConnectionPool(hostsvr-6-9009.share.51env.net, port443): Max retries exceeded with url: /prod-api/getInfo (Caused by ProxyError(Unable to conne…

秒懂设计模式--学习笔记(5)【创建篇-抽象工厂】

目录 4、抽象工厂4.1 介绍4.2 品牌与系列&#xff08;针对工厂泛滥&#xff09;(**分类**)4.3 产品规划&#xff08;**数据模型**&#xff09;4.4 生产线规划&#xff08;**工厂类**&#xff09;4.5 分而治之4.6 抽象工厂模式的各角色定义如下4.7 基于此抽象工厂模式以品牌与系…

vue启动时的错误

解决办法一&#xff1a;在vue.config.js中直接添加一行代码 lintOnSave:false 关闭该项目重新运行就可启动 解决办法二&#xff1a; 修改组件名称

配音软件有哪些?分享五款超级好用的配音软件

随着嫦娥六号的壮丽回归&#xff0c;举国上下都沉浸在这份自豪与激动之中。 在这样一个历史性的时刻&#xff0c;我们何不用声音记录下这份情感&#xff0c;让这份记忆以声音的形式流传&#xff1f; 无论是制作视频分享这份喜悦&#xff0c;还是创作音频讲述探月故事&#xff…

Oracle数据库中RETURNING子句

RETURNING子句允许您检索插入、删除或更新所修改的列&#xff08;以及基于列的表达式&#xff09;的值。如果不使用RETURNING&#xff0c;则必须在DML语句完成后运行SELECT语句&#xff0c;才能获得更改列的值。因此&#xff0c;RETURNING有助于避免再次往返数据库&#xff0c;…

CXL-GPU: 全球首款实现百ns以内的低延迟CXL解决方案

数据中心在追求更高性能和更低总拥有成本&#xff08;TCO&#xff09;的过程中面临三大主要内存挑战。首先&#xff0c;当前服务器内存层次结构存在局限性。直接连接的DRAM与固态硬盘&#xff08;SSD&#xff09;存储之间存在三个数量级的延迟差异。当处理器直接连接的内存容量…

VideoPrism——探索视频分析领域模型的算法与应用

概述 论文地址:https://arxiv.org/pdf/2402.13217.pdf 视频是我们观察世界的生动窗口&#xff0c;记录了从日常瞬间到科学探索的各种体验。在这个数字时代&#xff0c;视频基础模型&#xff08;ViFM&#xff09;有可能分析如此海量的信息并提取新的见解。迄今为止&#xff0c;…

采煤机作业3D虚拟仿真教学线上展示增强应急培训效果

在化工行业的生产现场&#xff0c;安全永远是首要之务。为了加强从业人员的应急响应能力和危机管理能力&#xff0c;纷纷引入化工行业工艺VR模拟培训&#xff0c;让应急演练更加生动、高效。 化工行业工艺VR模拟培训软件基于真实的厂区环境&#xff0c;精确还原了各类事件场景和…

医疗器械FDA | 医疗器械软件如何做源代码审计?

医疗器械网络安全测试https://link.zhihu.com/?targethttps%3A//www.wanyun.cn/Support%3Fshare%3D24315_ea8a0e47-b38d-4cd6-8ed1-9e7711a8ad5e 医疗器械源代码审计是一个确保医疗器械软件安全性和可靠性的重要过程。以下是医疗器械源代码审计的主要步骤和要点&#xff0c;以…

Vue3 sortablejs 表格拖拽后,表格无法更新的问题处理

实用sortablejs在vue项目中实现表格行拖拽排序 你可能会发现&#xff0c;表格排序是可以实现&#xff0c;但是我们基于数据驱动的vue中关联的数据并没有发生变化&#xff0c; 如果你的表格带有列固定(固定列实际上在dom中有两个表格&#xff0c;其中固定的列在一个表格中&…

游泳哪个牌子好?6大游泳耳机选购技巧总结分享

游泳耳机作为水上运动爱好者和游泳专业人士的必备装备&#xff0c;不仅要能够抵御水的侵入&#xff0c;还要提供清晰的音质和舒适的佩戴体验。在市面上&#xff0c;不同品牌的游泳耳机琳琅满目&#xff0c;选择起来可能会令人头疼。本文旨在为您提供一份详尽的游泳耳机选购指南…

Gemma轻量级开放模型在个人PC上释放强大性能,让每个桌面秒变AI工作站

Google DeepMind团队最近推出了Gemma&#xff0c;这是一个基于其先前Gemini模型研究和技术的开放模型家族。这些模型专为语言理解、推理和安全性而设计&#xff0c;具有轻量级和高性能的特点。 Gemma 7B模型在不同能力领域的语言理解和生成性能&#xff0c;与同样规模的开放模型…