RIPGeo代码理解(三)layers.py(注意力机制的代码)

代码链接:RIPGeo代码实现

├── lib # 包含模型(model)实现文件
    │        |── layers.py # 注意力机制的代码。
    │        |── model.py # TrustGeo的核心源代码。
    │        |── sublayers.py # layer.py的支持文件。
    │        |── utils.py # 辅助函数。

一、导入模块

import torch.nn as nn
import torch
from .sublayers import MultiHeadAttention, PositionwiseFeedForward
import numpy as np
import torch.functional as F

这段代码是一个自定义的神经网络模块,其中使用了 PyTorch 库。

1、import torch.nn as nn:导入 PyTorch 中的神经网络模块,这是定义神经网络层和模型的基本库。

2、import torch:导入 PyTorch 库,用于张量(tensor)操作和其他深度学习功能。

3、from .sublayers import MultiHeadAttention, PositionwiseFeedForward:从当前目录下的 sublayers 模块导入 MultiHeadAttentionPositionwiseFeedForward。这表明在当前文件所在目录下存在一个名为 sublayers.py 的文件,其中定义了这两个子层(sublayers)。

4、import numpy as np:导入 NumPy 库,通常用于数值计算和数组操作。在这里可能是为了使用一些与数组相关的功能。

5、import torch.functional as F:导入 PyTorch 的 functional 模块,并将其命名为 F。这个模块包含了一些与神经网络操作相关的函数,如激活函数等。

二、SimpleAttention1类定义(NN模型)

class SimpleAttention1(nn.Module):''' Just follow GraphGeo '''def __init__(self, temperature, attn_dropout=0.1, d_q_in=32, d_q_out=32, d_k_in=32, d_k_out=32, d_v_in=32,d_v_out=32, dropout=0.1, drop_last_layer=False):super().__init__()self.temperature = temperatureself.dropout = nn.Dropout(attn_dropout)self.q_w = nn.Linear(d_q_in, d_q_out)self.k_w = nn.Linear(d_k_in, d_k_out)self.v_w = nn.Linear(d_v_in, d_v_out)self.drop_last_layer = drop_last_layerdef forward(self, q, k, v):q = self.q_w(q)k = self.k_w(k)v = self.v_w(v)att_score = (q / self.temperature) @ k.transpose(0, 1)att_weight = torch.softmax(att_score, dim=-1)output = att_weight @ vreturn output, att_weight

这是一个简单的注意力机制模块,用于计算注意力分数并将其应用于值(value)向量

分为几个部分展开描述:

(一)__init__()

def __init__(self, temperature, attn_dropout=0.1, d_q_in=32, d_q_out=32, d_k_in=32, d_k_out=32, d_v_in=32,d_v_out=32, dropout=0.1, drop_last_layer=False):super().__init__()self.temperature = temperatureself.dropout = nn.Dropout(attn_dropout)self.q_w = nn.Linear(d_q_in, d_q_out)self.k_w = nn.Linear(d_k_in, d_k_out)self.v_w = nn.Linear(d_v_in, d_v_out)self.drop_last_layer = drop_last_layer

初始化函数,用于定义模块的结构。

1、def __init__(self, temperature, attn_dropout=0.1, d_q_in=32, d_q_out=32, d_k_in=32, d_k_out=32, d_v_in=32,d_v_out=32, dropout=0.1, drop_last_layer=False):这是初始化方法的签名,它接受一系列参数,用于配置自注意力模块。这些参数包括温度(temperature)、注意力层的 dropout(attn_dropout)、输入和输出维度的设置(d_q_ind_q_outd_k_ind_k_outd_v_ind_v_out),以及是否在最后一层使用 dropout(dropout)等。

2、super().__init__():调用父类(nn.Module)的初始化方法,确保正确地初始化该模块。

3、self.temperature = temperature:将传入的温度参数保存为类成员变量,用于调整注意力分布的尖锐度。

4、self.dropout = nn.Dropout(attn_dropout):创建一个 dropout 层,用于在自注意力中进行随机失活。

5、self.q_w = nn.Linear(d_q_in, d_q_out):创建一个线性层,用于将输入的查询(query)向量进行线性变换,从 d_q_in 维映射到 d_q_out 维。

6、self.k_w = nn.Linear(d_k_in, d_k_out):创建一个线性层,用于将输入的键(key)向量进行线性变换,从 d_k_in 维映射到 d_k_out 维。

7、self.k_v = nn.Linear(d_v_in, d_v_out):创建一个线性层,用于将输入的值(value)向量进行线性变换,从 d_v_in 维映射到 d_v_out 维。

(二)forward()

def forward(self, q, k, v):q = self.q_w(q)k = self.k_w(k)v = self.v_w(v)att_score = (q / self.temperature) @ k.transpose(0, 1)att_weight = torch.softmax(att_score, dim=-1)output = att_weight @ vreturn output, att_weight

这段代码实现了一个简单的自注意力机制,其中q(查询)、k(键)和v(值)分别通过线性变换(矩阵乘法)得到新的表示。这是自注意力机制的基本组成部分,其整体功能是通过计算注意力分数(att_score),生成加权和的输出(output)。整体上,这个前向函数实现了自注意力机制的计算,用于将输入序列中的每个元素与其他元素进行交互,并生成加权和的输出。这是 Transformer 网络中的关键组件。

1、q = self.q_w(q):通过线性变换(全连接层)对查询 q 进行变换。self.q_w 是查询的权重矩阵。

2、k = self.k_w(k):通过线性变换对键 k 进行变换。self.k_w 是键的权重矩阵。

3、v = self.v_w(v):通过线性变换对值 v 进行变换。self.v_w 是值的权重矩阵。

4、att_score = (q / self.temperature) @ k.transpose(0, 1):计算注意力分数。首先,将查询 q 除以 self.temperature(温度参数,用于

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/760794.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uView Badge 徽标数

该组件一般用于图标右上角显示未读的消息数量,提示用户点击,有圆点和圆包含文字两种形式。 #平台差异说明 App(vue)App(nvue)H5小程序√√√√ #基本使用 通过value参数定义徽标内容通过type设置主题。重…

【教你如何制作一个简单的HTML个人网页】

制作一个简单的HTML个人网页 创建一个简单个人的HTML网页很容易,下面是一个基本的示例,其中包含一些常见的元素,比如标题、段落、一张图片和一些链接,请记住,您将需要一个地方来存储您的HTML文件和任何相关资源&#…

电子商务类网站搭建需要注意的几点。

随着电子商务的迅猛发展,越来越多的企业和创业者选择在互联网上开设自己的电商网站。为了确保电商网站能够高效运行,给用户提供良好的体验,选择合适的服务器配置至关重要。 一、硬件配置 1、 CPU(中央处理器) 电商网…

STM32CubeMX学习笔记23---FreeRTOS(任务的挂起与恢复)

1、硬件设置 本实验通过freertos创建两个任务来分别控制LED2和LED3的亮灭,需要用到的硬件资源 LED2和LED3指示灯串口 2、STM32CubeMX设置 根据上一章的步骤创建两个任务:STM32CubeMX学习笔记22---FreeRTOS(任务创建和删除)-CS…

vue01

一、什么是vue.js(单页面应用程序) 用于构建用户界面的渐进式框架,采用自底向上增量开发的设计。核心理念:数据驱动视图,组件化开发前端三大主流框架:Vue.js Angular.js React.js 二、为什么学习流行框架…

1060:均值

【题目描述】 给出一组样本数据,包含n个浮点数,计算其均值,精确到小数点后4位。 【输入】 输入有两行,第一行包含一个整数n(n小于100),代表样本容量;第二行包含n个绝对值不超过10…

FPGA - SPI总线介绍以及通用接口模块设计

一,SPI总线 1,SPI总线概述 SPI,是英语Serial Peripheral interface的缩写,顾名思义就是串行外围设备接口。串行外设接口总线(SPI),是一种高速的,全双工,同步的通信总线,并且在芯片的…

Debian时间和时区配置

1. 时区 1.1. 查看时区 timedatectl输出 Local time: Thu 2024-03-07 13:46:06 CSTUniversal time: Thu 2024-03-07 05:46:06 UTCRTC time: Thu 2024-03-07 05:46:06Time zone: Asia/Shanghai (CST, 0800) System clock synchronized: yesNTP service: activeRTC in local TZ…

面试十一、代理模式

代理模式是一种结构型设计模式,旨在为其他对象提供一种代理或替代方法,以控制对这些对象的访问。在代理模式中,代理对象充当了客户端和目标对象之间的中间人,客户端通过代理访问目标对象,而不直接访问目标对象。 代理模…

【C++】为什么vector的地址与首元素地址不同?

文章目录 一、问题发现:二、结果分析三、问题解析 一、问题发现: &vector和&vector[0]得到的两个地址居然不相同,对数组array取变量名地址和取首元素地址的结果是相同的。这是为啥呢? 使用下面代码进行验证:…

Oracle中全表扫描优化方法

在Oracle数据库中,全表扫描(Full Table Scan, FTS)是指查询执行时扫描表的所有数据块来获取结果集。虽然在某些场景下全表扫描可能是最优选择(例如:当需要访问大部分或全部数据、表很小或者索引访问成本高于全表扫描时…

MKdocs博客中文教程 - 已经整理到知乎专栏

MKdocs博客中文教程 - 知乎 Mkdocs-Wcowin中文主题 通过主题和目录以打开文章 基于Material for MkDocs美化简洁美观,功能多元化简单易上手,小白配置教程详细,清晰易懂

html5cssjs代码 035 课程表

html5&css&js代码 035 课程表 一、代码二、解释基本结构示例代码常用属性样式和装饰响应式表格辅助技术 一个具有亮蓝色背景的网页,其中包含一个样式化的表格用于展示一周课程安排。表格设计了交替行颜色、鼠标悬停效果以及亮色表头,并对单元格设…

C++基础之运算符重载续(十三)

一.函数调用运算符 我们知道,普通函数执行时,有一个特点就是无记忆性,一个普通函数执行完毕,它所在的函数栈空间就会被销毁,所以普通函数执行时的状态信息,是无法保存下来的,这就让它无法应用在…

python实现 linux 执行命令./test启动进程,进程运行中,输入参数s, 再输入参数1, 再输入参数exit, 获取进程运行结果重定向写入到文件

要在 Python 中实现执行 ./test 启动进程,并在进程运行中依次输入参数 s、1,最后输入参数 exit,并将进程的输出结果重定向写入到文件,你可以使用 subprocess 模块。以下是一个示例代码: import subprocess# 启动 test…

MoonBit 首场 MeetUp 活动火热报名中!更多活动惊喜等你来探索!

首场线下MeetUp来啦! 在数字化浪潮中,基础软件构筑了信息产业发展的根基,不仅是技术进步的支柱,也是推动经济增长的重要力量。基础软件的发展不仅关乎硬件的完善与应用软件的创新,更是连接过去与未来的桥梁。 尽管中国…

Docker容器化技术(docker-compose安装部署案例)

docker-compose编排工具 安装docker-compose [rootservice ~]# systemctl stop firewalld [rootservice ~]# setenforce 0 [rootservice ~]# systemctl start docker[rootservice ~]# wget https://github.com/docker/compose/releases/download/v2.5.0/docker-compose-linux-…

anaconda迁移深度学习虚拟环境 在云服务器上配置

1 anaconda 虚拟环境操作 1、 查看虚拟环境 conda info -e2、 创建新的虚拟环境 conda create -n deeplearning_all pip python3.63、 激活新建的虚拟环境 Conda activate deeplearning_all2 环境中相关库的版本即安装说明(这些库都是对应匹配的) …

Lenze伦茨8400变频器E84A L-force Drives 操作使用说明

Lenze伦茨8400变频器E84A L-force Drives 操作使用说明

跟selenium并肩的自动化神器 Playwright 的 Web 自动化测试解决方案

1. 主流框架的认识 总结: 由于Selenium在3.x和4.x两个版本的迭代中并没有发生多大的变化,因此Selenium一统天下的地位可能因新框架的出现而变得不那么稳固。后续的Cypress、TestCafe、Puppeteer被誉为后Selenium时代Web UI自动化的三驾马车。但是由于这三…