大模型基础——从零实现一个Transformer(2)

大模型基础——从零实现一个Transformer(1)

一、引言

上一章主要实现了一下Transformer里面的BPE算法和 Embedding模块定义
本章主要讲一下 Transformer里面的位置编码以及多头注意力

二、位置编码

2.1正弦位置编码(Sinusoidal Position Encoding)

其中:

pos:表示token在文本中的位置
: i代表词向量具体的某一维度,即位置编码的每个维度对应一个波长不同的正弦或余弦波
d : d表示位置编码的最大维度,和词嵌入的维度相同,假设是512

对于位置0的编码为:

对于位置1的编码为:

2.2 正弦位置编码特性

  • 相对位置关系:pos + k的位置编码可以被位置pos的位置编码线性表示
    三角函数公式如下:

对于pos + k的位置编码:

根据式( 3 )和( 4 )整理上式有:

  • 位置之间的相对距离

𝑃𝐸𝑝𝑜𝑠+𝑘∙𝑃𝐸𝑝𝑜𝑠 的内积:

位置之间内积的关系大小如下:

可以看到内积会随着相对位置的递增而减少,从而可以表示位置的相对距离。内积的结果是对称的,所以没有方向信息。

2.3 代码实现

import torch
from torch import nn,Tensor
import mathclass PositionalEmbedding(nn.Module):def __init__(self,d_model:int=512,dropout:float=0.1,max_positions:int=1024) -> None:''':param d_model: embedding向量的维度:param dropout::param max_positions: 最大长度'''super().__init__()self.dropout = nn.Dropout(p=dropout)# Position Embedding  (max_positions,d_model)pe = torch.zeros(max_positions,d_model)# 创建position index列表 ,形状为:(max_positions, 1)position = torch.arange(0,max_positions).unsqueeze(1)# d_model 维度 偶数位是sin ,奇数位是cos# 计算除数,这里的除数将用于计算正弦和余弦的频率div_term = torch.exp(torch.arange(0,d_model,2) * -(math.log(10000.0) /d_model))# 对矩阵的偶数列(0,2,4...)进行正弦函数编码pe[:, 0::2] = torch.sin(position * div_term)# 对矩阵的奇数列(1,3,5...)进行余弦函数编码pe[:, 1::2] = torch.cos(position * div_term)# 扩展维度,增加batch_size: pe (1, max_positions, d_model)pe = pe.unsqueeze(0)# buffers will not be trainedself.register_buffer("pe", pe)def forward(self,x:Tensor) ->Tensor:"""Args:x (Tensor): (batch_size, seq_len, d_model) embeddingsReturns:Tensor: (batch_size, seq_len, d_model)"""# x.size(1)是指当前x的最大长度x = x + self.pe[:,:x.size(1)]return self.dropout(x)if __name__ == '__main__':seq_len = 128d_model = 512pe = PositionalEmbedding(d_model)x = torch.rand((1,100,d_model))print(pe(x).shape)

三、多头注意力

3.1 自注意力

公式如下:

  • 假设一个矩阵X,分别乘上权重矩阵,,就得到了Q , K , V向量矩阵

  • 然后除以 𝑑𝑘 进行缩放,再经过Softmax,得到注意力权重矩阵,接着乘以value向量矩阵V,就一次得到了所有单词的输出矩阵Z

3.2 多头注意力

将原来n_head分割乘Nx n_sub_head.对于每个头i,都有它自己不同的key,query和value矩阵: 𝑊𝑖𝐾,𝑊𝑖𝑄,𝑊𝑖𝑉 。在多头注意力中,key和query的维度是 𝑑𝑘 ,value嵌入的维度是 𝑑𝑣 (其中key,query和value的维度可以不同,Transformer里面一般设置的是相同的),这样每个头i,权重 𝑊𝑖𝑄∈𝑅𝑑×𝑑𝑘,𝑊𝑖𝐾∈𝑅𝑑×𝑑𝑘,𝑊𝑖𝑉∈𝑅𝑑×𝑑𝑣 ,然后与压缩到X中的输入相乘,得到 𝑄∈𝑅𝑁×𝑑𝑘,𝐾∈𝑅𝑁×𝑑𝑘,𝑉∈𝑅𝑁×𝑑𝑣 .

3.3 代码实现

import mathimport torch
from torch import nn,Tensor
from typing import *class MultiHeadAttention(nn.Module):def __init__(self,d_model: int = 512,n_heads: int=8,dropout: float = 0.1):''':param d_model: embedding大小:param n_heads: 多头个数:param dropout:'''super().__init__()assert d_model % n_heads == 0self.d_model = d_modelself.n_heads = n_headsself.d_key = d_model // n_headsself.q = nn.Linear(d_model,d_model)self.k = nn.Linear(d_model,d_model)self.k = nn.Linear(d_model,d_model)self.concat = nn.Linear(d_model,d_model)self.dropout = nn.Dropout(dropout)def split_heads(self,x:Tensor,is_key : bool = False) -> Tensor:'''分割向量为N个头,如果是key的话,softmax时候,key需要转置一下:param x::param is_key::return:'''batch_size = x.size(0)# x (batch_size,seq_len,n_heads,d_key)x = x.view(batch_size,-1,self.n_heads,self.d_key)if is_key:# (batch_size,n_heads,d_key,seq_len)return x.permute(0,2,3,1)# (batch_size,n_heads,seq_len,d_keyreturn x.transpose(1,2)def merge_heads(self,x: Tensor) -> Tensor:x = x.transpose(1,2).contigouse().view(x.size(0),-1,self.d_model)return xdef attention(self,query:Tensor,key:Tensor,value:Tensor,mask:Tensor = None,keep_attentions:bool = False):scores = torch.matmul(query,key) / math.sqrt(self.d_key)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)# weights (batch_size,n_heads,q_length,k_length)weights = self.dropout(torch.softmax(scores,dim=-1))# (batch_size,n_heads,q_length,k_length) x (batch_size,n_heads,v_length,d_key)# -> (batch_size,n_heads,q_length,d_key)# assert k_length == v_length# attn_output (batch_size, n_heads, q_length, d_key)atten_output = torch.matmul(weights,value)if keep_attentions:self.weights = weightselse:del weightsreturn atten_outputdef forward(self,query: Tensor,key: Tensor,value: Tensor,mask: Tensor = None,keep_attentions: bool = False)-> Tuple[Tensor,Tensor]:''':param query:(batch_size, q_length, d_model):param key:(batch_size, k_length, d_model):param value:(batch_size, v_length, d_model):param mask: mask for padding or decoder. Defaults to None.:param keep_attentions: whether keep attention weigths or not. Defaults to False.:return: (batch_size, q_length, d_model) attention output'''query = self.q(query)key = self.k(key)value = self.v(value)query,key,value = (self.split_heads(query),self.split_heads(key,is_key=True),self.split_heads(value))atten_output = self.attention(query,key,value,mask,keep_attentions)del querydel keydel value# concatconcat_output = self.merge_heads(atten_output)# the final liear# output (batch_size, q_length, d_model)output = self.concat(concat_output)return output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/25954.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

持续总结中!2024年面试必问 20 道分布式、微服务面试题(七)

上一篇地址:持续总结中!2024年面试必问 20 道分布式、微服务面试题(六)-CSDN博客 十三、请解释什么是服务网格(Service Mesh)? 服务网格(Service Mesh)是一种用于处理服…

线程知识点总结

Java线程是Java并发编程中的核心概念之一,它允许程序同时执行多个任务。以下是关于Java线程的一些关键知识点总结: 1. 线程的创建与启动 继承Thread类:创建一个新的类继承Thread类,并重写其run()方法。通过创建该类的实例并调用st…

TypeScript基础教程学习

菜鸟教程 TypeScript基础类型 数字类型 number 双精度 64 位浮点值。它可以用来表示整数和分数。 let binaryLiteral: number 0b1010; // 二进制 let octalLiteral: number 0o744; // 八进制 let decLiteral: number 6; // 十进制 let hexLiteral: number 0xf00d…

从信号灯到泊车位,ARMxy如何重塑城市交通智能化

城市智能交通系统的高效运行对于缓解交通拥堵、提高出行安全及优化城市管理至关重要。ARMxy工业计算机,作为这一领域内的技术先锋,正以其强大的性能和灵活性,悄然推动着交通管理的智能化升级。 智能信号控制的精细化管理 想象一下&#xff0…

【C语言】11.字符函数和字符串函数

文章目录 1.字符分类函数2.字符转换函数3.strlen的使用和模拟实现4.strcpy的使用和模拟实现5.strcat的使用和模拟实现6.strcmp的使用和模拟实现7.strncpy函数的使用8.strncat函数的使用9.strncmp函数的使用10.strstr的使用和模拟实现11.strtok函数的使用12.strerror函数的使用 …

视频修复工具,模糊视频变清晰!

老旧视频画面效果差,视频效果模糊。我们经常找不到一个好的工具来让视频更清晰,并把它变成高清画质。相信很多网友都会有这个需求,尤其是视频剪辑行业的网友,经常会遇到这个问题。今天给大家分享一个可以把模糊视频修复清晰的工具…

cnvd_2015_07557-redis未授权访问rce漏洞复现-vulfocus复现

1.复现环境与工具 环境是在vulfocus上面 工具:GitHub - vulhub/redis-rogue-getshell: redis 4.x/5.x master/slave getshell module 参考攻击使用方式与原理:https://vulhub.org/#/environments/redis/4-unacc/ 2.复现 需要一个外网的服务器做&…

《TCP/IP网络编程》(第十四章)多播与广播

当需要向多个用户发送多媒体信息时,如果使用TCP套接字,则需要维护与用户数量相等的套接字;如果使用之前学习的UDP,传输次数也需要和用户数量相同。 所以为了解决这些问题,可以采用多播和广播技术,这样只需要…

Python学习打卡:day02

day2 笔记来源于:黑马程序员python教程,8天python从入门到精通,学python看这套就够了 8、字符串的三种定义方式 字符串在Python中有多种定义形式 单引号定义法: name 黑马程序员双引号定义法: name "黑马程序…

网安面试题总结_1

#创作灵感# 助力网安人员顺利面试 等保测评 等保测评一般分成五个阶段,定级、备案、测评、整改、监督检查。 外网 外网打点的基本流程主要分为:靶标确认、信息收集、漏洞探测、漏洞利用、权限获取,其最终目的是为了获取靶标的系统权限/关…

Spring Boot中Excel的导入导出的实现之Apache POI框架使用教程

文章目录 前言一、Apache POI 是什么?二、使用 Apache POI 实现 Excel 的导入和导出① 导入 Excel1. 添加依赖2. 编写导入逻辑3. 在 Controller 中处理上传请求 ② 导出 Excel1. 添加依赖2. 编写导出逻辑3. 在 Controller 中处理导出请求 总结 前言 在 Spring Boot …

代码随想录算法训练营第四十四天 | 01背包问题理论基础、01背包问题滚动数组、416. 分割等和子集

背包问题其实有很多种,01背包是最基础也是最经典的,软工计科学生一定要掌握的。 01背包问题 代码随想录 视频讲解:带你学透0-1背包问题!| 关于背包问题,你不清楚的地方,这里都讲了!| 动态规划经…

C++11:列表初始化 初始化列表initializer_list decltype关键字

目录 前言 列表初始化 初始化列表initializer_list decltype关键字 左值和右值 move 前言 2003年C标准委员会曾经提交了一份技术勘误表(简称TC1),使得C03这个名字取代了C98成为了C11前最新的C标准名称。不过由于C03主要是对C98标准中的…

网络安全在个人生活中具体有哪些常见的应用场景?

网络安全在个人生活中的应用场景非常广泛,以下是一些常见的例子: 1. 个人隐私保护:网络安全可以帮助保护个人的隐私信息,如银行账户、身份证号、联系方式等,防止被黑客窃取或滥用。 2. 电子商务:在进行在…

认识和使用 Vite 环境变量配置,优化定制化开发体验

Vite 官方中文文档:https://cn.vitejs.dev/ 环境变量 Vite 内置的环境变量如下: {"MODE": "development", // 应用的运行环境"BASE_URL": "/", // 部署应用时使用的 URL 前缀"PROD": false, //应用…

国外媒体软文发稿-引时代潮流-助力跨国企业蓬勃发展

大舍传媒:开疆拓土,引领传媒新潮流 随着全球经济的一体化和信息技术的高速发展,跨国企业在国际市场上的竞争越来越激烈。这也给跨国企业带来了巨大的机遇和挑战。在这个时代背景下,大舍传媒凭借其独特的优势和创新的服务模式&…

分布式数据库中,如何正确的将数据分片?

前面我们了解了分布式数据库的架构,知道各类分布式数据库都离不开计算层、存储层、元数据层这三层关系。另外,很重要的一点是,了解了分布式数据库是把数据打散存储在一个个分片中。在基于MySQL 的分布式数据库架构中,分片就存在于 MySQL 实例中。 本篇文章,我们就来了解一…

市值超越苹果,英伟达的AI崛起与天润融通的数智化转型

Agent,开启客户服务新时代。 世界商业格局又迎来一个历史性时刻。 北京时间6月6日,人工智能芯片巨头英伟达(NVDA)收涨5.16%,总市值达到3.01万亿美元,正式超越苹果公司,成为仅次于微软&#xf…

IDEA启动项目报java.lang.OutOfMemoryError: GC overhead limit exceeded

idea编译项目时报j ava.lang.OutOfMemoryError: GC overhead limit exceeded错误,教你两步搞定! 第一步:打开help -> Edit Custom VM Options ,修改xms和xmx的大小,如下图: 第二步:File -> Settings…

[力扣题解] 501. 二叉搜索树中的众数

题目:501. 二叉搜索树中的众数 思路 代码 Method 1 把二叉搜索树的结果拉直,排序,再从前往后统计; 其中,使用unordered_map来记录元素->次数对,用vector来排序; /*** Definition for a …