LLM - Make Causal Mask 构造因果关系掩码

目录

一.引言

二.make_causal_mask

1.完整代码

2.Torch.full

3.torch.view

4.torch.masked_fill_

5.past_key_values_length

6.Test Main

三.总结


一.引言

Causal Mask 主要用于限定模型的可视范围,防止模型看到未来的数据。在具体应用中,Causal Mask 可将所有未来的 token 设置为零,从注意力机制中屏蔽掉这些令牌,使得模型在进行预测时只能关注过去和当前的 token,并确保模型仅基于每个时间步骤可用的信息进行预测。

在 Transformer 模型中,Multihead Attention 中的 Causal Mask 就是用于解决这个问题,以实现模型对于输入序列的正确处理。下面是 Causal 的可视化示例,在实践中其呈现倒三角形状:

全文为 'I love eating lunch.' ,对于 'love' 而言其只能看到 'I',不能看到未来的 'eating'、'lunch'。

二.make_causal_mask

为了方便后续示例的展示,这里选择较小的参数,batch_size = 2,target_length = 4。

1.完整代码

# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):"""Make causal mask used for bi-directional self-attention."""bsz, tgt_len = input_ids_shapemask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)mask_cond = torch.arange(mask.size(-1), device=device)mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)mask = mask.to(dtype)if past_key_values_length > 0:mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

代码的 Input 主要就是一个二维的 input_ids_shape,分别为 batch_size 和 target_length,dtype 和 device 在这里比较好理解,还有就是最后的 past_key_values_length,用于补齐,这个也比较简单。Output 则是 (batch_size, 1, target_length, target_length) 的 Causal Mask,其中 Msak 的矩阵 target_length x target_length 就是上面所示的倒三角形状。

2.Torch.full

函数介绍

该函数用于创建一个具有指定填充值的新张量。该函数的语法如下:

torch.full(size, fill_value, *, dtype=None, device=None, requires_grad=False)

参数说明:

  • size:张量的形状,可以是一个整数或者一个元组,例如:(3, 3) 或 3。
  • fill_value:张量的填充值。
  • dtype:张量的数据类型,默认为None,即根据输入的数据类型推断。
  • device:张量所在的设备,默认为None,即根据输入的设备推断。
  • requires_grad:是否需要计算梯度,默认为False。

该函数返回一个与指定形状相同且所有元素都被设置为指定填充值的新张量。

函数使用

mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))

这里 torch.finfo(dtype).min 为对应 torch.dtype 类型的最小值,以 bfloat16 为例:

print(torch.tensor(torch.finfo(dtype).min))
=> tensor(-3.3895e+38)

而这一步 mask 的操作就是生成一个 tgt_len x tgt_len 的充满 min 元素的方阵:

3.torch.view

函数介绍

在 PyTorch 库中,view 函数用于改变一个张量(Tensor)的形状(shape)。它返回一个新的张量,其元素与原始张量相同,但形状(shape)已被改变。view 函数的行为非常类似于 NumPy的 reshape 函数。它会返回一个与原始张量共享数据但具有不同形状的新的张量。如果给定的形状与原始张量的元素总数不匹配,则会引发错误。

import torch  x = torch.randn(4, 5)  # 创建一个4x5的随机张量  
y = x.view(20)  # 改变形状为20的一维张量  
z = x.view(-1, 10)  # 改变形状为10的一维张量,第一维度由其他维度决定

函数使用

mask_cond = torch.arange(mask.size(-1))

mask_cond 是一个1维向量:

(mask_cond + 1).view(mask.size(-1), 1)

这一步相当于在 mask_cond 基础上先加常量再 reshape:

4.torch.masked_fill_

函数介绍

在 PyTorch 库中,masked_fill_() 函数是一个张量(Tensor)方法,用于将张量中的指定区域填充为特定值。此函数需要一个掩码(mask)作为输入,该掩码应与原张量具有相同的形状。掩码中的 True 值表示需要填充的区域,False 值表示需要保留的原始值。

torch.Tensor.masked_fill_(mask, value)

参数说明: 

  • mask (Bool tensor) - 掩码张量,用于指定需要填充的区域。
  • value (float) - 填充的值。

示例

假设我们有一个 3x3 的张量,我们想要将所有大于 5 的元素替换为 -1。我们可以使用该函数来实现这个目标。

import torch  # 创建一个3x3的张量  
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])  # 创建一个掩码,其中大于5的元素为True,其余为False  
mask = x > 5  # 使用masked_fill_函数将大于5的元素替换为-1  
x.masked_fill_(mask, -1)  print(x)

输出:

tensor([[ 1,  2,  3],  [ 4,  5,  6],  [-1, -1, -1]])

函数使用

mask_cond < (mask_cond + 1).view(mask.size(-1), 1)

传入函数的 mask 如下,呈倒三角形态,其中 True 的部分填充新值,False 部分保持不变: 

mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)

 根据 mask 对 target x target 的方阵进行填充 0 得到我们上面提到的倒三角:

5.past_key_values_length

在 PyTorch 中,past_key_values_length 是一个参数,用于指定在使用 Transformer 模型时,过去键值缓存(past key-value cache)的长度。该参数通常与 Transformer 模型中的自注意力机制(self-attention mechanism)一起使用。在过去键值缓存中,模型保存了过去的键和值向量,以便在生成序列时重复使用它们。这些过去的键和值向量可以用于计算自注意力分数,从而提高生成序列的效率。较大的past_key_values_length可以增加模型的表现力,但也会增加计算量和内存消耗。因此,需要根据具体任务和资源限制来选择合适的值。

if past_key_values_length > 0:mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)

这里定义 past_key_values_length = 1,代码逻辑就是在原有的 tgt x tgt 方阵前补 past_key_values_length 个 0:

6.Test Main

mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
if __name__ == '__main__':batch_size = 2target_length = 4input_shape = (batch_size, target_length)data_type = torch.bfloat16causal_mask = _make_causal_mask(input_shape, data_type, 1)print(causal_mask)print(causal_mask.shape)

=>        pask_key_length = 1填充后,tgt x tgt 变为 tgt x (1 + tgt)

=>        通过 None + expand 的组合,tgt x (1 + tgt) 变为 bsz x 1 x tgt x (1 + tgt)

三.总结

新系列博文一方面是阅读 HF 上 LLM 模型实现的源码,了解对应知识的实现过程。另一方面是之前很多同学主要接触 TF 1.x、TF 2.x 以及 Estimator 和 Keras 这一类深度学习工具,趁此机会也能熟悉 Torch 的使用方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/87808.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Unity之Hololens开发如何实现UI交互

一.前言 什么是Hololens? Hololens是由微软开发的一款混合现实头戴式设备,它将虚拟内容与现实世界相结合,为用户提供了沉浸式的AR体验。Hololens通过内置的传感器和摄像头,能够感知用户的环境,并在用户的视野中显示虚拟对象。这使得用户可以与虚拟内容进行互动,将数字信…

内存对齐--面试常问问题和笔试常考问题

1.内存对齐的意义 C 内存对齐的主要意义可以简练概括为以下几点&#xff1a; 提高访问效率&#xff1a;内存对齐可以使数据在内存中以更加紧凑的方式存储&#xff0c;从而提高了数据的访问效率。处理器通常能够更快地访问内存中对齐的数据&#xff0c;而不需要额外的字节偏移计…

Learn Prompt- Midjourney 图片生成:Image Prompts

Prompt 自动生成 前不久&#xff0c;Midjourney 宣布支持图片转 prompt 功能。 原始图片​ blueprint holographic design of futuristic Midlibrary --v 5Prompt 生成​ 直接输入 /describe 指令通过弹出窗口上传图像并发送&#xff0c;Midjourney 会根据该图像生成四种可…

完成“重大项目”引进签约,美创科技正式落户中国(南京)软件谷

近日&#xff0c;美创科技正式入驻中国&#xff08;南京&#xff09;软件谷&#xff0c;并受邀出席中国南京“金洽会"之“雨花台区数字经济创新发展大会”。美创科技副总裁罗亮亮作为代表&#xff0c;在活动现场完成“重大项目”引进签约。 作为国家重要的软件产业与信息服…

关于ubuntu设置sh文件开机自启动python3和sudo python3问题

关于ubuntu设置sh文件开机自启动python3和sudo python3问题 说明系统为 ubuntu22.04python是python3.10.12ros系统为ros2 humble 背景解决方法补充 说明 系统为 ubuntu22.04 python是python3.10.12 ros系统为ros2 humble 背景 将一个py文件设置为开机自启动&#xff0c;服…

贪心算法总结归类(图文解析)

贪心算法实际上并没有什么套路可言&#xff0c;贪心的关键就在于它的思想&#xff1a; 如何求出局部最优解&#xff0c;通过局部最优解从而推导出全局最优解 常见的贪心算法题目 455. 分发饼干 这题的解法很符合“贪心”二字 如果使用暴力的解法&#xff0c;那么本题是通过…

Windows--Python永久换下载源

1.新建pip文件夹&#xff0c;注意路径 2.在上述文件中&#xff0c;新建文件pip.ini 3.pip.ini记事本打开&#xff0c;输入内容&#xff0c;保存完事。 [global] index-url https://pypi.douban.com/simple

【记录文】Android自定义Dialog实现圆角对话框

圆角的dialog还是蛮常用的&#xff0c;demo中正好用上了 自定义Dialog&#xff0c;代码中可以设置指定大小与位置 /*** author : jiangxue* date : 2023/9/25 13:21* description :圆角的矩形*/internal class RoundCornerView(context: Context,view: Int, StyleRes theme…

开机自启动Linux and windows

1、背景 服务器由于更新等原因重启&#xff0c;部署到该服务上的响应的应用需要自启动 2、Linux 2.1 方式一 编写启动应用的sh脚本授权该脚本权限 chmod 777 xxx.sh 修改rc.loacl 位置&#xff1a;/etc/rc.local 脚本&#xff1a;sh /home/xxxx.sh & 授权rc.local …

Elasticsearch—(MacOs)

1⃣️环境准备 准备 Java 环境&#xff1a;终端输入 java -version 命令来确认版本是否符合 Elasticsearch 要求下载并解压 Elasticsearch&#xff1a;前往&#xff08;https://www.elastic.co/downloads/elasticsearch&#xff09;选择适合你的 Mac 系统的 Elasticsearch 版本…

python使用蓝牙库选择

蓝牙库选择 pybluez 项目地址&#xff1a;https://github.com/pybluez/pybluez 文档地址&#xff1a;https://pybluez.readthedocs.io/en/latest/index.html 蓝牙支持&#xff1a;经典蓝牙 / BLE蓝牙【仅Linux】 平台支持&#xff1a; LinuxRaspberry PimacOSWindows✔️✔️…

分享40个Python源代码总有一个是你想要的

分享40个Python源代码总有一个是你想要的 源码下载链接&#xff1a;https://pan.baidu.com/s/1PNR3_RqVWLPzSBUVAo2rnA?pwd8888 提取码&#xff1a;8888 下面是文件的名字。 dailyfresh-天天生鲜 Django-Quick-Start freenom-自动续期域名的脚本 Full Stack Python简体中…

ADC数模转化器

简介 • ADC &#xff08; Analog-Digital Converter &#xff09;模拟 - 数字转换器 • ADC 可以将引脚上连续变化的模拟电压转换为内存中存储的数字变量&#xff0c;建立模拟电路到数字电路的桥梁 • 12 位逐次逼近型 ADC &#xff0c; 1us 转换时间 &#xff08;12位:分辨率…

02 MIT线性代数-矩阵消元 Elimination with matrices

一, 消元法 Method of Elimination 消元法是计算机软件求解线形方程组所用的最常见的方法。任何情况下&#xff0c;只要是矩阵A可逆&#xff0c;均可以通过消元法求得Axb的解 eg: 我们将矩阵左上角的1称之为“主元一”&#xff08;the first pivot&#xff09;&#xff0c;第…

MySQL库表操作

开始之前分享一个数据库远程连接工具&#xff1a; Navicat Premium 15 可以远程连接数据库&#xff0c;并且查看表的结构等&#xff0c;非常好用。 SQL语句基础 SQL&#xff1a;结构化查询语言(Structured Query Language)&#xff0c;在关系型数据库上执行数据操作、数据检索…

景联文数据标注:AI大模型产生幻觉该如何应对?

大语言模型在诸多下游任务中展现出令人瞩目的能力&#xff0c;然而在运用过程中仍然存在一些问题。幻觉现象是目前阻碍大模型成功应用的关键问题之一。 什么是大模型幻觉问题&#xff1f; 大模型幻觉问题是指一些人工智能模型在面对某些输入时&#xff0c;会生成不准确、不完整…

Visual Studio Cpp CLR C# 替换

1、首先将文件中所有都替换 你需要的名字 替换为整个解决方案 2、新建工程取名 Laserbeam_upper 3、把原工程下的cpp放进来&#xff0c;并改名Laserbeam_upper 4、在这里逐步添加 属性表配置opencv 5、cpp需要修改的两个地方 6、CLR新建和添加 选类库新建、然后直接粘贴进来…

阿里云七代云服务器实例、倚天云服务器及通用算力型和经济型实例规格介绍

在目前阿里云的云服务器产品中&#xff0c;既有五代六代实例规格&#xff0c;也有七代和八代倚天云服务器&#xff0c;同时还有通用算力型及经济型这些刚推出不久的新品云服务器实例&#xff0c;其中第五代实例规格目前不在是主推的实例规格了&#xff0c;现在主售的实例规格是…

php实现分页功能跳转和ajax方式实现

实现效果 准备工作 创建数据表和导入测试数据 CREATE TABLE users ( id int(10) unsigned NOT NULL AUTO_INCREMENT, username varchar(30) DEFAULT NULL COMMENT 账号, email varchar(30) DEFAULT NULL COMMENT 密码, PRIMARY KEY (id) ) ENGINEMyISAM AUTO_INCREM…

【OpenSSL】OpenSSL实现Base64

Base 64概述和应用场景 概述 Base64就是将二进制数据转换为字符串的一种算法。 应用场景 邮件编码xml或则json存储二进制内容网页传递数据URL数据库中以文本形式存放二进制数据可打印的比特币钱包地址base58Check(hash校验)网页上可以将图片直接使用Base64表达公私密钥的文…