Transformer step by step--Positional Embedding 和 Word Embedding

Transformer step by step往期文章:

Transformer step by step--层归一化和批量归一化 

要把Transformer中的Embedding说清楚,那就要说清楚Positional EmbeddingWord Embedding。至于为什么有这两个Embedding,我们不妨看一眼Transformer的结构图。

从上图可以看到,我们的输入需要在Input EmbeddingPositional Encoding的共同作用下才会分别输入给EncoderDecoder,所以我们就分别介绍一下怎么样进行Input EmbeddingPositional Encoding

同时为了帮助大家更好地理解这两种Embedding方式,我们这里生成一个自己的迷你数据集。

import tiktoken
import torch
import torch.nn as nn
encoding = tiktoken.get_encoding("cl100k_base") #导入openai的开源tokenizer库
context_length = 4 #选取4个token
batch = 4 # 批处理大小
example_text1 = "I am now writing an example to show the usage of word embedding."
example_text2 = "I am now writing an sentence to show the usage of another word embedding."
total_text = example_text1 + example_text2 #形成总数据集
tokenize_text = encoding.encode(total_text) #进行tokenize
tokenize_text = torch.tensor(tokenize_text, dtype=torch.long) #这里转换成tentor是因为后续我们要用pytorch的框架
idxs = torch.randint(low = 0,high = len(tokenize_text) - context_length,size = (batch,))#这里我们随机选batch个id
x_batch = torch.stack([tokenize_text[idx:idx + context_length] for idx in idxs]) #根据id和context_length抽取训练数据

一、Word Embedding

这里的Word Embedding就是论文中提到的Input Embedding。我们在之前的文章中已经介绍使用tokenizer将原始的文本信息转换为数字,便于输入进模型。 可是使用tokenizer有一个问题,这个问题就是我们虽然用不同的ID表示了不同的子词,但是这些ID所能蕴含的语义信息非常有限,比如dog和dogs这两个单词语义非常相近,但很有可能它们的ID相隔非常远,所以为了更好地体现不同单词之间的语义关系,我们将每个ID通过Word Embedding的方式变为一个向量。

这里的实现在Pytorch框架之下变得非常简单,只需要一行代码就可以搞定,但是我们这里还是详细对代码的参数进行一些讲解。

max_token_value = tokenize_text.max().item()
d_model = 16
input_embedding_lookup_table = nn.Embedding(max_token_value + 1, d_model)
x_batch_embedding = input_embedding_lookup_table(x_batch)
print(max_token_value)
#
40188

这里的第一、二行我们一起讲:

max_token_value = tokenize_text.max().item()是取出当前token中ID最大的那个数。

input_embedding_lookup_table = nn.Embedding(max_token_value + 1, d_model)是根据最大的ID去构建Embedding层。

首先第一个问题,nn.Embedding的两个参数是什么意思?

首先第一个参数我们这里使用的是 max_token_value + 1,这个是用来表示我们词汇库的最大长度。那这里可能又有一个新的问题,就是我们上面的句子,算上标点符号一共也就十几个词,为什么我们这里要用40188+1作为这个词汇库的最大长度呢?这是因为我们这里用的tokenizer是openai的第三方库,这个库在做tokenize的时候对应着大量的原始文本,而我们example_text中的文本在经过tokenize之后,最大的token ID对应的是这个库中的原始文本的40188。那么这里还有第二个问题,这里返回的是40188,我们为什么要加上1呢?因为tokenize之后的ID是从0开始算的,也就是说40188对应的词汇表的最大长度应该是40189。

第二个参数我们使用的是d_model,这个参数会好理解一些。我们之前说过,一个ID没有什么语义信息,但变成向量之后就可以通过余弦相似度计算两个向量之间的相关性。那么这里的一个问题就在于,我用多少维的向量去表示呢?d_model这个参数就是来解决这个事情,我们想让ID变成多少维的向量,就把d_model设置成多少。

那么第三行,就是我们讲原来形状为4 * 4的x_batch变为了4 * 4 * 16 的x_batch_embedding。后面多出来的16就是我们自己设置的嵌入的维度。

二、Positional Embedding

总体来说,word_embedding还是比较通俗易懂的,接下来我们根据论文当中的公式去写一下Positional Encoding,也就是Positional Embedding(同一个意思)。

这里我们解释一下这两行公式啥意思,Positional Embedding简单来说,就是给每个token分配一个位置信息,因为 𝑠𝑒𝑙𝑓 - 𝑎𝑡𝑡𝑒𝑛𝑡𝑖𝑜𝑛 本身无法判断不同token所在的位置。PE对应Positional Encoding的缩写,括号中的pos对应我们设置的context length的长度,2i对应嵌入维度中的偶数维度,2i+1对应嵌入维度中的奇数维度。

接下来我们就来实现一下相关的代码:

positional_encoding = torch.zeros(context_length, d_model) #首先初始化一个和token形状大小一样的positional encoding
positional = torch.arange(0, context_length).float().unsqueeze(1) #按照我们设置的context length去初始化position
_2i = torch.arange(0, d_model, 2) # 这里用生成d_model/2的步长,因为sin和cos两个加起来就变成了d_model
positional_encoding[:, 0::2] = torch.sin(torch.exp(positional/10000**(_2i/d_model))) #按照公式写一遍
positional_encoding[:, 1::2] = torch.cos(torch.exp(positional/10000**(_2i/d_model)))
positional_encoding = positional_encoding.squeeze(0).expand(batch, -1, -1) #最终根据batch的数量对维度进行扩充
input_x = x_batch_embedding + positional_encoding # 将word embedding和positional embedding相加得到模型的输入
print(positional_encoding.shape)
print(input_x.shape)
##
torch.Size([4, 4, 16]) 
torch.Size([4, 4, 16]) 

到这里,我们也就完成了两个embedding操作。

知乎原文链接:安全验证 - 知乎知乎,中文互联网高质量的问答社区和创作者聚集的原创内容平台,于 2011 年 1 月正式上线,以「让人们更好的分享知识、经验和见解,找到自己的解答」为品牌使命。知乎凭借认真、专业、友善的社区氛围、独特的产品机制以及结构化和易获得的优质内容,聚集了中文互联网科技、商业、影视、时尚、文化等领域最具创造力的人群,已成为综合性、全品类、在诸多领域具有关键影响力的知识分享社区和创作者聚集的原创内容平台,建立起了以社区驱动的内容变现商业模式。icon-default.png?t=N7T8https://zhuanlan.zhihu.com/p/691169616

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/3712.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

代码随想录算法训练营第三十八天| LeetCode509.斐波那契数、LeetCode70.爬楼梯、LeetCode746.使用最小花费爬楼梯

LeetCode 509 斐波那契数 题目链接:509. 斐波那契数 - 力扣(LeetCode) 【解题思路】 1.确定dp数组以及下标的含义 dp[i]的定义为:第i个数的斐波那契数值是dp[i] 2.确定递推公式 递推公式就是题目说的斐波那契值递推公式 3.确定…

【Java GUI】人机对弈五子棋

在学校的Java课程中,我们被分配了一项有趣的任务:开发一款能够实现人机对弈的五子棋游戏。为了更好地理解Java GUI的运用,并与大家分享学习心得,我将整个开发过程记录在这篇博客中。欢迎大家阅读并提供宝贵的意见和建议&#xff0…

微信小程序-------模板与配置

能够使用 WXML 模板语法渲染页面结构能够使用 WXSS 样式美化页面结构能够使用 app.json 对小程序进行全局性配置能够使用 page.json 对小程序页面进行个性化配置能够知道如何发起网络数据请求 一.WXML 模板语法 数据绑定 1. 数据绑定的基本原则 ① 在 data 中定义数据 ② 在…

文件分享新风尚,二维码生成器全功能解析

随着科技的飞速发展,二维码生成器已经成为我们日常生活中不可或缺的一部分。无论是支付、下载应用还是获取信息,二维码都以其便捷性和高效性赢得了广大用户的青睐。在这样的背景下,二维码生成器应运而生,不仅支持多文件生成二维码…

Android Studio修改“choose boot runtime for the IDE“后无法打开

在Android Studio中选择了"choose boot runtime for the IDE"的New后,会自动重启AS,然后就无法打开android studio了,打开直接报错,cause by如下 Unable to make field protected java.lang.Runnable java.awt.event.In…

[leetcode] 58. 最后一个单词的长度

文章目录 题目描述解题方法倒序遍历java代码复杂度分析 题目描述 给你一个字符串 s,由若干单词组成,单词前后用一些空格字符隔开。返回字符串中 最后一个 单词的长度。 单词 是指仅由字母组成、不包含任何空格字符的最大子字符串。 示例 1&#xff1a…

未来已来:解锁AGI的无限潜能与挑战

未来已来:解锁AGI的无限潜能与挑战 引言 假设你有一天醒来,发现你的智能手机不仅提醒你今天的日程,还把你昨晚做的那个奇怪的梦解释了一番,并建议你可能需要减少咖啡摄入量——这不是科幻电影的情节,而是人工通用智能…

maldev tricks在注册表中存储 payload

简介 注册表是 Windows 操作系统中一个重要的数据库,它包含 Windows 操作系统和应用程序的重要设置和选项。由于注册表的功能非常强大,因此注册表对于恶意程序来说是非常有利用价值的。 在 windows 注册表中存储二进制数据,这是一种常见的技…

每天一个数据分析题(二百九十六)

订单详情表是以每一笔订单的每一件商品为最小业务记录单位进行记录的,那么可能成为订单详情表的主键字段的是? A. 订单编号 B. 产品编号 C. 订单ID D. 订单编号产品编号 题目来源于CDA模拟题库 点击此处获取答案 cda数据分析考试:点击…

十七、Java网络编程(一)

1、Java网络编程的基本概念 1)网络编程的概念 Java作为一种与平台无关的语言,从一出现就与网络有关及其密切的关系,因为Java写的程序可以在网络上直接运行,使用Java,只需编写简单的代码就能实现强大的网络功能。下面将介绍几个与Java网络编程有关的概念。 2)TCP/IP协议概…

MVP+敏捷开发

MVP敏捷开发 1. 什么是敏捷开发? 敏捷开发是一种软件开发方法论,旨在通过迭代、自组织的团队和持续反馈,快速响应需求变化并交付高质量的软件。相较于传统的瀑布模型,敏捷开发强调灵活性、适应性和与客户的紧密合作。敏捷开发方…

qml和c++结合使用

目录 文章简介1. 创建qml工程2. 创建一个类和qml文件,修改main函数3. 函数说明:4. qml 文件间的调用5. 界面布局6. 代码举例 文章简介 初学qml用来记录qml的学习过程,方便后面归纳总结整理。 1. 创建qml工程 如下图,我使用的是…

【北京迅为】《iTOP龙芯2K1000开发指南》-第四部分 ubuntu开发环境搭建

龙芯2K1000处理器集成2个64位GS264处理器核,主频1GHz,以及各种系统IO接口,集高性能与高配置于一身。支持4G模块、GPS模块、千兆以太网、16GB固态硬盘、双路UART、四路USB、WIFI蓝牙二合一模块、MiniPCIE等接口、双路CAN总线、RS485总线&#…

光伏无人机:巡检无人机解决巡检难题

随着科技的飞速发展,无人机技术已经广泛应用于各个领域,其中光伏无人机在解决光伏电站巡检难题方面发挥了重要作用。光伏无人机以其高效、精准、安全的特点,为光伏电站的巡检工作带来了革命性的变革。 光伏电站通常位于广阔的户外场地&#x…

webpack热更新原理详解

文章目录 前言基础配置创建项目HMR配置 HMR交互概览HMR流程概述HMR实现细节初始化注册监听编译完成事件启动服务监听文件代码变化服务端发送消息客户端收到消息热更新文件请求热更新代码替换 问题思考 前言 刷新分为两种:一种是页面刷新,不保留页面状态…

二叉树的三种遍历方式非递归

## 二叉树的遍历需要把当前根节点和左右子树抽象成只有一个根节点和一个左子树,一个右子树的情况 ##套路 二叉树的遍历先序中序后序无非是根节点的处理顺序不同,三种遍历方式都有一个相同的逻辑流程,不管怎么样首先将左子树入栈,遍…

Uniapp 跨页面传复杂参、传对象

1、需要传递参数的页面 attended(e) {this.id e.baseIdlet obj {page: this.page,size: this.pageSize,baseId: e.baseId,str: this.struni.navigateTo({url: /pages/index/index?item encodeURIComponent(JSON.stringify(obj))})}}, encodeURIComponent 2、接收参数的页…

GPU深度学习环境搭建:Win10+CUDA 11.7+Pytorch1.13.1+Anaconda3+python3.10.9

1. 查看显卡驱动及对应cuda版本关系 1.1 显卡驱动和cuda版本信息查看方法 在命令行中输入【nvidia-smi】可以当前显卡驱动版本和cuda版本。 根据显示,显卡驱动版本为:Driver Version: 516.59,CUDA 的版本为:CUDA Version 11.7。 此处我们可以根据下面的表1 显卡驱动和c…

大模型咨询培训老师叶梓:利用知识图谱和Llama-Index增强大模型应用

大模型(LLMs)在自然语言处理领域取得了显著成就,但它们有时会产生不准确或不一致的信息,这种现象被称为“幻觉”。为了提高LLMs的准确性和可靠性,可以借助外部知识源,如知识图谱。那么我们如何通过Llama-In…

将阿里云中数据传输到其他超算服务器

目录 方法一:在阿里云中连接超算,然后使用rsync(速度慢) 方法2:rclone(速度很快,100G只花了大约20min) 方法一:在阿里云中连接超算,然后使用rsync/scp(速度慢&#xff0…