transformer学习笔记-自注意力机制(2)

经过上一篇transformer学习笔记-自注意力机制(1)原理学习,这一篇对其中的几个关键知识点代码演示:

1、整体qkv注意力计算

先来个最简单未经变换的QKV处理:

import torch  
Q = torch.tensor([[3.0, 3.0,0.0],[0.5, 4.0,0.0]])
K = Q.T
V = Qscores = Q @ K #计算内积
weights = torch.softmax(scores, dim=0)
print(f"概率分布:{weights}")
newQ = weights @ V
print(f"输出:{newQ}")

再来个输入经过Wq/Wk/Wv变换的:

import torch  
Q = torch.tensor([[3.0, 3.0,0.0],[0.5, 4.0,0.0]])
torch.manual_seed(123)  
d_q, d_k, d_v = 4, 4, 5 # W_query, W_key, W_value 的维度  
d = Q.shape[1] #  W_query, W_key, W_value 的行数等于输入token的维度
# 获取W_query, W_key, W_value(随机生成)
W_query = torch.nn.Parameter(torch.rand(d, d_q))  
W_key = torch.nn.Parameter(torch.rand(d, d_k))  
W_value = torch.nn.Parameter(torch.rand(d, d_v))print("W_query:", W_query)
print("W_key:", W_key)
print("W_value:", W_value)#先只计算苹果对整个句子的注意力,看看效果
apple = Q[0]
query_apple = apple @ W_query  
keys = Q @ W_key  
values = Q @ W_value  
print(f"query_apple:{query_apple}")
print(f"keys:{keys}")
print(f"values:{values}")
scores = query_apple @ keys.T
print(f"scores:{scores}")
weights = torch.softmax(scores, dim=0)
print(f"weights:{weights}")
newQ = weights @ values
print(f"newQ:{newQ}")#再看下整体的
querys = Q @ W_query
all_scores = querys @ keys.T
print(f"all_scores:{all_scores}")
all_weights = torch.softmax(all_scores, dim=-1)
print(f"all_weights:{all_weights}")
output = all_weights @ values
print(f"output:{output}")

在这里插入图片描述
最终生成的output的维度与W_value 的维度一致。

2、调换顺序结果不变

import torchdef simple_attention(Q):K = Q.TV = Qscores = Q @ K #计算内积weights = torch.softmax(scores, dim=-1)print(f"概率分布:{weights}")newQ = weights @ Vprint(f"输出:{newQ}")Q = torch.tensor([[3.0, 3.0,0.0],[0.5, 4.0,0.0]])
Q1 = torch.tensor([[0.5, 4.0,0.0],[3.0, 3.0,0.0]])
print("模拟‘苹果梨’:")
simple_attention(Q)
print("模拟‘梨苹果’:")
simple_attention(Q1)

在这里插入图片描述
可以看到“苹果梨”、“梨苹果”即便换了词token的顺序,并不会影响新的梨和新的苹果的向量数值。这里我们用了softmax函数求概率分布,因此跟上一篇文章的示例数值不一样,不要在意这个细节。

3、softmax:

import numpy as npdef softmax(x):e_x = np.exp(x)return e_x / e_x.sum(axis=0)def softmax_with_temperature(x,T):e_x = np.exp(x/T)return e_x / e_x.sum(axis=0)# 示例使用
if __name__ == "__main__":input_vector = np.array([2.0, 1.0, 0.1])output = softmax(input_vector)print("Softmax Output:", output)print("Softmax with Temperature 0.5 Output:", softmax_with_temperature(input_vector,0.5))print("Softmax with Temperature 1 Output:", softmax_with_temperature(input_vector,1))print("Softmax with Temperature 5 Output:", softmax_with_temperature(input_vector,5))

在这里插入图片描述
可以看到随着T的不断加大,概率分布不断趋于均匀分布。

4、softmax除以 d k \sqrt{d_k} dk

还是用上面的softmax函数,演示下除以 d k \sqrt{d_k} dk 的效果:

        # 高维输入向量input_vector_high_dim = np.random.randn(100) * 10  # 生成一个100维的高斯分布随机向量,乘以10增加内积output_high_dim = softmax(input_vector_high_dim)print("High Dimension Softmax Output:", output_high_dim)# 打印高维输出的概率分布print("Max Probability in High Dimension:", np.max(output_high_dim))print("Min Probability in High Dimension:", np.min(output_high_dim))# 高维输入向量除以10input_vector_high_dim_div10 = input_vector_high_dim / 10output_high_dim_div10 = softmax(input_vector_high_dim_div10)print("High Dimension Softmax Output (Divided by 10):", output_high_dim_div10)# 打印高维输出的概率分布print("Max Probability in High Dimension (Divided by 10):", np.max(output_high_dim_div10))print("Min Probability in High Dimension (Divided by 10):", np.min(output_high_dim_div10))# 绘制高维概率分布曲线plt.figure(figsize=(10, 6))# 绘制图形plt.plot(output_high_dim, label='High Dim')plt.plot(output_high_dim_div10, label='High Dim Divided by 10')plt.legend()plt.title('High Dimension Softmax Output Comparison')plt.xlabel('Index')plt.ylabel('Probability')plt.show()

在这里插入图片描述
在除以 d k \sqrt{d_k} dk 之前,由于内积变大,导致概率分布变得尖锐,趋近0的位置梯度基本消失,softmax 函数的损失函数的导数在输出接近 0 时接近零,在反向传播过程中,无法有效地更新权重。有兴趣的话可以试试对softmax 函数的损失函数求导。

继续上面的代码,来看下softmax的输出的损失函数求梯度:

        def test_grad( dim_vertor):import numpy as npimport torchimport torch.nn.functional as F# 假设的输入z = torch.tensor(dim_vertor, requires_grad=True)print(z)# 计算 softmax 输出p = F.softmax(z, dim=0)true_label = np.zeros(100)true_label[3] = 1# 模拟损失函数(例如交叉熵)y = torch.tensor(true_label)  # one-hot 编码的真实标签loss = -torch.sum(y * torch.log(p))# 反向传播并获取梯度loss.backward()# print(z.grad)  # 输出梯度return z.gradgrad_div10 = test_grad(input_vector_high_dim_div10)grad = test_grad(input_vector_high_dim)print(f"grad_div10:{grad_div10}")print(f"grad:{grad}")

在这里插入图片描述
明显看出,没有除以 d k \sqrt{d_k} dk 求出的梯度,基本为0;上面的代码是torch已经实现的。当然也可以根据损失函数自己求导,这里我们只为演示效果,点到即止:

5、多头注意力:

import torch
import torch.nn as nntorch.manual_seed(123)# 输入矩阵 Q
Q = torch.tensor([[3.0, 3.0, 0.0],[0.5, 4.0, 0.0]])# 维度设置
d_q, d_k, d_v = 4, 4, 5  # 每个头的 query, key, value 的维度
d_model = Q.shape[1]     # 输入 token 的维度
num_heads = 2            # 头的数量# 初始化每个头的权重矩阵
W_query = nn.ParameterList([nn.Parameter(torch.rand(d_model, d_q)) for _ in range(num_heads)])
W_key = nn.ParameterList([nn.Parameter(torch.rand(d_model, d_k)) for _ in range(num_heads)])
W_value = nn.ParameterList([nn.Parameter(torch.rand(d_model, d_v)) for _ in range(num_heads)])# 输出权重矩阵
W_output = nn.Parameter(torch.rand(num_heads * d_v, d_model))# 打印权重矩阵
for i in range(num_heads):print(f"W_query_{i+1}:\n{W_query[i]}")print(f"W_key_{i+1}:\n{W_key[i]}")print(f"W_value_{i+1}:\n{W_value[i]}")# 计算每个头的 Q, K, V
queries = [Q @ W_query[i] for i in range(num_heads)]
keys = [Q @ W_key[i] for i in range(num_heads)]
values = [Q @ W_value[i] for i in range(num_heads)]# 计算每个头的注意力分数和权重
outputs = []
for i in range(num_heads):scores = queries[i] @ keys[i].T / (d_k ** 0.5)weights = torch.softmax(scores, dim=-1)output = weights @ values[i]outputs.append(output)# 拼接所有头的输出
concat_output = torch.cat(outputs, dim=-1)
print(f"concat_output:\n{concat_output}")
# 最终线性变换
final_output = concat_output @ W_output# 打印结果
print(f"Final Output:\n{final_output}")

6、掩码注意力:

import torch# 原始 Q 矩阵
Q = torch.tensor([[3.0, 3.0, 0.0],[0.5, 4.0, 0.0],[1.0, 2.0, 0.0],[2.0, 1.0, 0.0]])torch.manual_seed(123)
d_q, d_k, d_v = 4, 4, 5  # query, key, value 的维度
d = Q.shape[1]           # query, key, value 的行数等于输入 token 的维度# 初始化权重矩阵
W_query = torch.nn.Parameter(torch.rand(d, d_q))
W_key = torch.nn.Parameter(torch.rand(d, d_k))
W_value = torch.nn.Parameter(torch.rand(d, d_v))print("W_query:", W_query)
print("W_key:", W_key)
print("W_value:", W_value)# 计算 Q, K, V
querys = Q @ W_query
keys = Q @ W_key
values = Q @ W_valueprint(f"querys:\n{querys}")
print(f"keys:\n{keys}")
print(f"values:\n{values}")# 计算注意力分数
all_scores = querys @ keys.T / (d_k ** 0.5)
print(f"all_scores:\n{all_scores}")# 生成掩码
seq_len = Q.shape[0]
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
masked_scores = all_scores.masked_fill(mask, float('-inf'))print(f"Mask:\n{mask}")
print(f"Masked Scores:\n{masked_scores}")# 计算权重
all_weights = torch.softmax(masked_scores, dim=-1)
print(f"all_weights:\n{all_weights}")# 计算输出
output = all_weights @ values
print(f"output:\n{output}")

主要看下生成的掩码矩阵,和通过掩码矩阵处理的权重分布:在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/64295.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Leecode刷题C++之形成目标字符串需要的最少字符串数①

执行结果:通过 执行用时和内存消耗如下&#xff1a; 代码如下&#xff1a; class Solution { public:int minValidStrings(vector<string>& words, string target) {auto prefix_function [](const string& word, const string& target) -> vector<…

CompletableFuture异步业务 默认ForkJoinPool 导致类加载器加载类失败

目录 1、Bug案发现场 2、捉虫过程过程 3、解决方案与代码 4、成果展现与总结 ​编辑 5、参考文章 1、Bug案发现场 最近参与帮助以前同事实际业务开发中业务&#xff0c;在一个业务场景之中&#xff1b;使用H5页面通过二维码收集小微企业/个体工商户贷款业务需求。其中在获…

mybatis-plus超详细讲解

mybatis-plus &#xff08;简化代码神器&#xff09; 地址&#xff1a;https://mp.baomidou.com/ 目录 mybatis-plus 简介 特性 支持数据库 参与贡献 快速指南 1、创建数据库 mybatis_plus 2、导入相关的依赖 3、创建对应的文件夹 4、编写配置文件 5、编写代码 …

Houdini abc 导入 maya uv无法识别

参考&#xff1a;Houdini导出abc 至maya UV 无法识别_houdini导出abc没有uv-CSDN博客 从maya导入到houdini的uv默认是vertex层级的&#xff0c;而在maya中&#xff0c;uv是在point层级的&#xff1b;因此在houdini中导出abc时应将uv转为点层级&#xff0c;使用vertexsplit节点&…

2025erp系统开源免费进销存系统搭建教程/功能介绍/上线即可运营软件平台源码

系统介绍 基于ThinkPHP与LayUI构建的全方位进销存解决方案 本系统集成了采购、销售、零售、多仓库管理、财务管理等核心功能模块&#xff0c;旨在为企业提供一站式进销存管理体验。借助详尽的报表分析和灵活的设置选项&#xff0c;企业可实现精细化管理&#xff0c;提升运营效…

半导体数据分析(二):徒手玩转STDF格式文件 -- 码农切入半导体系列

一、概述 在上一篇文章中&#xff0c;我们一起学习了STDF格式的文件&#xff0c;知道了这是半导体测试数据的标准格式文件。也解释了为什么码农掌握了STDF文件之后&#xff0c;好比掌握了切入半导体行业的金钥匙。 从今天开始&#xff0c;我们一起来一步步地学习如何解构、熟…

OCR:文字识别

使用场景: 远程身份认证 自动识别录入用户身份/企业资质信息&#xff0c;应用于金融、政务、保险、电商、直播等场景&#xff0c;对用户、商家、主播进行实名身份认证&#xff0c;有效降低用户输入成本&#xff0c;控制业务风险 文档电子化 识别提取各类办公文档、合同文件、企…

深入C语言文件操作:从库函数到系统调用

引言 文件操作是编程中不可或缺的一部分&#xff0c;尤其在C语言中&#xff0c;文件操作不仅是处理数据的基本手段&#xff0c;也是连接程序与外部世界的重要桥梁。C语言提供了丰富的库函数来处理文件&#xff0c;如 fopen、fclose、fread、fwrite 等。然而&#xff0c;这些库…

linux上qt打包(二)

sudo apt install git 新建一个文件夹 名为xiazai&#xff0c; chmod -R 777 xiazai cd xiazai 并进入这个文件夹&#xff0c;然后clone git clone https://github.com/probonopd/linuxdeployqt.git 此处可能要fanQiang才能下 cd linuxdeployqt文件夹 下载平台需要的…

Windos中解决redis-server.exe闪退问题

一、闪退原因 &#xff08;一&#xff09;数据状态异常 数据不一致 在 Redis 运行过程中&#xff0c;如果发生意外情况&#xff0c;如突然断电、系统崩溃或者不正确的操作&#xff0c;可能会导致数据在内存中的存储状态不一致。例如&#xff0c;Redis 使用多种数据结构&#x…

【数据分享】2013-2023年我国省市县三级的逐年CO数据(免费获取\excel\shp格式)

空气质量数据是在我们日常研究中经常使用的数据&#xff01;之前我们给大家分享了2000-2023年的省市县三级的逐年PM2.5数据、2000-2023年的省市县三级的逐年PM10数据、2013-2023年的省市县三级的逐年SO2数据、2000-2023年省市县三级的逐年O3数据和2008-2023年我国省市县三级的逐…

华为WLAN基础配置(AC6005模拟配置)

AC6005基础配置 本次实验模拟华为AC6005的基本配置 Tip display interface GigabitEthernet 0/0/0 查看ap接口mac 前提条件&#xff1a;Vlan10为业务网段&#xff0c;vlan100为管理网段&#xff0c;5700作为dhcp。 5700配置如下 <Huawei>sy [Huawei]sys 5700 //设…

shell编程2 永久环境变量和字符串显位

声明 学习视频来自B站UP主 泷羽sec 常见变量 echo $HOME &#xff08;家目录 root用户&#xff09; /root cd /root windows的环境变量可以去设置里去新建 为什么输入ls dir的命令的时候就会输出相应的内容呢 因为这些命令都有相应的变量 which ls 通过这个命令查看ls命令脚本…

WebRTC服务质量(05)- 重传机制(02) NACK判断丢包

WebRTC服务质量&#xff08;01&#xff09;- Qos概述 WebRTC服务质量&#xff08;02&#xff09;- RTP协议 WebRTC服务质量&#xff08;03&#xff09;- RTCP协议 WebRTC服务质量&#xff08;04&#xff09;- 重传机制&#xff08;01) RTX NACK概述 WebRTC服务质量&#xff08;…

AI工具如何深刻改变我们的工作与生活

在当今这个科技日新月异的时代&#xff0c;人工智能&#xff08;AI&#xff09;已经从科幻小说中的概念变成了我们日常生活中不可或缺的一部分。从智能家居到自动驾驶汽车&#xff0c;从医疗诊断到金融服务&#xff0c;AI正以惊人的速度重塑着我们的世界。 一、工作方式的革新…

基于matlab的单目相机标定

链接&#xff1a; 单目相机标定&#xff08;使用Matlab&#xff09; 用Matlab对单目相机参数的标定步骤&#xff08;保姆级教程&#xff09; 1.准备代码 调用摄像头代码&#xff08;用于测试摄像头是否可用&#xff09;&#xff1a; #https://blog.csdn.net/qq_37759113/art…

[maven]使用spring

为了更好理解springboot&#xff0c;我们先通过学习spring了解其底层。 这里讲一下简单的maven使用spring框架入门使用。因为这一块的东西很多都需要联合起来后才好去细讲&#xff0c;本篇通过spring-context大致地介绍相关内容。 注意&#xff1a;spring只是一个框架&#xff…

eBay如何养号?新手养号宝典

​ebay是热门的跨境电商平台之一&#xff0c;然而与其他跨境电商平台不同&#xff0c;不同等级的ebay账户可刊登的数量是不同的。对于新手来说&#xff0c;想要提升ebay账户的等级就需要养号。那ebay如何养号&#xff1f;本文将带来一些实用的养号策略&#xff0c;帮助新手快速…

学习日志024--opencv中处理轮廓的函数

目录 前言​​​​​​​ 一、 梯度处理的sobel算子函数 功能 参数 返回值 代码演示 二、梯度处理拉普拉斯算子 功能 参数 返回值 代码演示 三、Canny算子 功能 参数 返回值 代码演示 四、findContours函数与drawContours函数 功能 参数 返回值 代码演示 …

梳理你的思路(从OOP到架构设计)_UML应用:业务内涵的分析抽象表达03

目录 1、举例(四)&#xff1a;五子棋 【五子棋】 的分析步骤 2、讨论&#xff1a; 模型与代码 1、举例(四)&#xff1a;五子棋 【五子棋】 的分析步骤 Step-1: 找到主角— 棋手&#xff0c;很容易发现核心的概念了&#xff0c;例如&#xff1a;五子棋游戏的主角是棋手(玩家…