RIPGeo代码理解(三)layers.py(注意力机制的代码)

代码链接:RIPGeo代码实现

├── lib # 包含模型(model)实现文件
    │        |── layers.py # 注意力机制的代码。
    │        |── model.py # TrustGeo的核心源代码。
    │        |── sublayers.py # layer.py的支持文件。
    │        |── utils.py # 辅助函数。

一、导入模块

import torch.nn as nn
import torch
from .sublayers import MultiHeadAttention, PositionwiseFeedForward
import numpy as np
import torch.functional as F

这段代码是一个自定义的神经网络模块,其中使用了 PyTorch 库。

1、import torch.nn as nn:导入 PyTorch 中的神经网络模块,这是定义神经网络层和模型的基本库。

2、import torch:导入 PyTorch 库,用于张量(tensor)操作和其他深度学习功能。

3、from .sublayers import MultiHeadAttention, PositionwiseFeedForward:从当前目录下的 sublayers 模块导入 MultiHeadAttentionPositionwiseFeedForward。这表明在当前文件所在目录下存在一个名为 sublayers.py 的文件,其中定义了这两个子层(sublayers)。

4、import numpy as np:导入 NumPy 库,通常用于数值计算和数组操作。在这里可能是为了使用一些与数组相关的功能。

5、import torch.functional as F:导入 PyTorch 的 functional 模块,并将其命名为 F。这个模块包含了一些与神经网络操作相关的函数,如激活函数等。

二、SimpleAttention1类定义(NN模型)

class SimpleAttention1(nn.Module):''' Just follow GraphGeo '''def __init__(self, temperature, attn_dropout=0.1, d_q_in=32, d_q_out=32, d_k_in=32, d_k_out=32, d_v_in=32,d_v_out=32, dropout=0.1, drop_last_layer=False):super().__init__()self.temperature = temperatureself.dropout = nn.Dropout(attn_dropout)self.q_w = nn.Linear(d_q_in, d_q_out)self.k_w = nn.Linear(d_k_in, d_k_out)self.v_w = nn.Linear(d_v_in, d_v_out)self.drop_last_layer = drop_last_layerdef forward(self, q, k, v):q = self.q_w(q)k = self.k_w(k)v = self.v_w(v)att_score = (q / self.temperature) @ k.transpose(0, 1)att_weight = torch.softmax(att_score, dim=-1)output = att_weight @ vreturn output, att_weight

这是一个简单的注意力机制模块,用于计算注意力分数并将其应用于值(value)向量

分为几个部分展开描述:

(一)__init__()

def __init__(self, temperature, attn_dropout=0.1, d_q_in=32, d_q_out=32, d_k_in=32, d_k_out=32, d_v_in=32,d_v_out=32, dropout=0.1, drop_last_layer=False):super().__init__()self.temperature = temperatureself.dropout = nn.Dropout(attn_dropout)self.q_w = nn.Linear(d_q_in, d_q_out)self.k_w = nn.Linear(d_k_in, d_k_out)self.v_w = nn.Linear(d_v_in, d_v_out)self.drop_last_layer = drop_last_layer

初始化函数,用于定义模块的结构。

1、def __init__(self, temperature, attn_dropout=0.1, d_q_in=32, d_q_out=32, d_k_in=32, d_k_out=32, d_v_in=32,d_v_out=32, dropout=0.1, drop_last_layer=False):这是初始化方法的签名,它接受一系列参数,用于配置自注意力模块。这些参数包括温度(temperature)、注意力层的 dropout(attn_dropout)、输入和输出维度的设置(d_q_ind_q_outd_k_ind_k_outd_v_ind_v_out),以及是否在最后一层使用 dropout(dropout)等。

2、super().__init__():调用父类(nn.Module)的初始化方法,确保正确地初始化该模块。

3、self.temperature = temperature:将传入的温度参数保存为类成员变量,用于调整注意力分布的尖锐度。

4、self.dropout = nn.Dropout(attn_dropout):创建一个 dropout 层,用于在自注意力中进行随机失活。

5、self.q_w = nn.Linear(d_q_in, d_q_out):创建一个线性层,用于将输入的查询(query)向量进行线性变换,从 d_q_in 维映射到 d_q_out 维。

6、self.k_w = nn.Linear(d_k_in, d_k_out):创建一个线性层,用于将输入的键(key)向量进行线性变换,从 d_k_in 维映射到 d_k_out 维。

7、self.k_v = nn.Linear(d_v_in, d_v_out):创建一个线性层,用于将输入的值(value)向量进行线性变换,从 d_v_in 维映射到 d_v_out 维。

(二)forward()

def forward(self, q, k, v):q = self.q_w(q)k = self.k_w(k)v = self.v_w(v)att_score = (q / self.temperature) @ k.transpose(0, 1)att_weight = torch.softmax(att_score, dim=-1)output = att_weight @ vreturn output, att_weight

这段代码实现了一个简单的自注意力机制,其中q(查询)、k(键)和v(值)分别通过线性变换(矩阵乘法)得到新的表示。这是自注意力机制的基本组成部分,其整体功能是通过计算注意力分数(att_score),生成加权和的输出(output)。整体上,这个前向函数实现了自注意力机制的计算,用于将输入序列中的每个元素与其他元素进行交互,并生成加权和的输出。这是 Transformer 网络中的关键组件。

1、q = self.q_w(q):通过线性变换(全连接层)对查询 q 进行变换。self.q_w 是查询的权重矩阵。

2、k = self.k_w(k):通过线性变换对键 k 进行变换。self.k_w 是键的权重矩阵。

3、v = self.v_w(v):通过线性变换对值 v 进行变换。self.v_w 是值的权重矩阵。

4、att_score = (q / self.temperature) @ k.transpose(0, 1):计算注意力分数。首先,将查询 q 除以 self.temperature(温度参数,用于

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/760794.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

STM32CubeMX学习笔记23---FreeRTOS(任务的挂起与恢复)

1、硬件设置 本实验通过freertos创建两个任务来分别控制LED2和LED3的亮灭,需要用到的硬件资源 LED2和LED3指示灯串口 2、STM32CubeMX设置 根据上一章的步骤创建两个任务:STM32CubeMX学习笔记22---FreeRTOS(任务创建和删除)-CS…

FPGA - SPI总线介绍以及通用接口模块设计

一,SPI总线 1,SPI总线概述 SPI,是英语Serial Peripheral interface的缩写,顾名思义就是串行外围设备接口。串行外设接口总线(SPI),是一种高速的,全双工,同步的通信总线,并且在芯片的…

【C++】为什么vector的地址与首元素地址不同?

文章目录 一、问题发现:二、结果分析三、问题解析 一、问题发现: &vector和&vector[0]得到的两个地址居然不相同,对数组array取变量名地址和取首元素地址的结果是相同的。这是为啥呢? 使用下面代码进行验证:…

html5cssjs代码 035 课程表

html5&css&js代码 035 课程表 一、代码二、解释基本结构示例代码常用属性样式和装饰响应式表格辅助技术 一个具有亮蓝色背景的网页,其中包含一个样式化的表格用于展示一周课程安排。表格设计了交替行颜色、鼠标悬停效果以及亮色表头,并对单元格设…

MoonBit 首场 MeetUp 活动火热报名中!更多活动惊喜等你来探索!

首场线下MeetUp来啦! 在数字化浪潮中,基础软件构筑了信息产业发展的根基,不仅是技术进步的支柱,也是推动经济增长的重要力量。基础软件的发展不仅关乎硬件的完善与应用软件的创新,更是连接过去与未来的桥梁。 尽管中国…

Docker容器化技术(docker-compose安装部署案例)

docker-compose编排工具 安装docker-compose [rootservice ~]# systemctl stop firewalld [rootservice ~]# setenforce 0 [rootservice ~]# systemctl start docker[rootservice ~]# wget https://github.com/docker/compose/releases/download/v2.5.0/docker-compose-linux-…

anaconda迁移深度学习虚拟环境 在云服务器上配置

1 anaconda 虚拟环境操作 1、 查看虚拟环境 conda info -e2、 创建新的虚拟环境 conda create -n deeplearning_all pip python3.63、 激活新建的虚拟环境 Conda activate deeplearning_all2 环境中相关库的版本即安装说明(这些库都是对应匹配的) …

Lenze伦茨8400变频器E84A L-force Drives 操作使用说明

Lenze伦茨8400变频器E84A L-force Drives 操作使用说明

跟selenium并肩的自动化神器 Playwright 的 Web 自动化测试解决方案

1. 主流框架的认识 总结: 由于Selenium在3.x和4.x两个版本的迭代中并没有发生多大的变化,因此Selenium一统天下的地位可能因新框架的出现而变得不那么稳固。后续的Cypress、TestCafe、Puppeteer被誉为后Selenium时代Web UI自动化的三驾马车。但是由于这三…

JavaEE 初阶篇-深入了解操作系统中的进程与 PCB

🔥博客主页: 【小扳_-CSDN博客】 ❤感谢大家点赞👍收藏⭐评论✍ 文章目录 1.0 关于计算机是如何进行工作的 “常识” 1.1 关于寄存器、缓存与内存是如何配合 CPU “工作” 2.0 操作系统概述 2.1 操作系统内核 2.2 进程 2.3 PCB 2.3.1 PCB 属性…

Centos7没有可用软件包 ifconfig问题解决

问题描述 在Centos7中查看ip没有ifconfig,使用yum安装ifconfig报错没有可用软件包 ifconfig问题解决 [rootlocalhost etc]# yum -y install ifconfig 已加载插件:fastestmirror base …

动手做简易版俄罗斯方块

导读:让我们了解如何处理形状的旋转、行的消除以及游戏结束条件等控制因素。 目录 准备工作 游戏设计概述 构建游戏窗口 游戏方块设计 游戏板面设计 游戏控制与逻辑 行消除和计分 判断游戏结束 界面美化和增强体验 看看游戏效果 准备工作 在开始编码之前…

火灾自动报警及消防联动控制系统主机的九个主要组成部分

关于火灾报警联动系统的主机组成,一般有两种不同的概括,下面分别讨论。 一: 火灾报警主机的组成部分较多,主要包括以下消防设备:主电源、联动电源、打印机、驱动器、直接控制板、总线控制板、消防广播、消防电话主机…

粒子群算法 - 目标函数最优解计算

粒子群算法概念 粒子群算法 (particle swarm optimization,PSO) 由 Kennedy 和 Eberhart 在 1995 年提出,该算法模拟鸟群觅食的方法进行寻找最优解。基本思想:人们发现,鸟群觅食的方向由两个因素决定。第一个是自己当初飞过离食物…

FPGA工程正确的设计流程

1 正确的设计流程 分析项目的具体需求来设计系统的结构,划分系统的层次,确定各个子模块的结构关系和信号之间的相互关系,然后确定模块的端口信号等根据每隔模块的功能和自己的理解,结合芯片手册接口的时序,使用visio画…

基于QGraphicsView的图像显示控件,支持放大、缩小、鼠标拖动

原链接 前言 这是一个Qt平台的基于QGraphicsView类的图像显示控件,支持输入QPixmap、QImage、opencv的从cv::Mat类。 实现平台:Windows 10 x64 Qt 6.2.3 MSVC 2019 opencv 4.5 先来看演示视频 控件类实现 ImageViewer.h文件 #ifndef IMAGEVIEWER…

Docker 笔记(八)--Dockerfile

目录 1. 背景2. 参考3. 原文3.1 Dockerfile 支持的指令3.2 Dockerfile格式3.3 Parser指令syntaxescape 3.4 环境变量替换3.5 docker构建忽略文件3.6 Shell 和 exec 格式Exec 格式Shell 格式使用不同的 shell 3.7 FROM指令了解ARG和FROM如何交互 3.8 RUN指令RUN指令缓存失效RUN …

4 CUDA 环境搭建

4.1 简介 本章面向从未接触过CUDA的初学者。我们将依次介绍如何在不同操作系统上安装CUDA、有哪些可用的CUDA 工具以及CUDA如何编译代码,最后介绍应用程序接口提供的错误处理手段,并帮助读者识别CUDA代码和开发过程中必然碰到的应用程序接口报错。Windo…

java框架 2 springboot 过滤器 拦截器 异常处理 事务管理 AOP

Filter 过滤器 对所有请求都可以过滤。 实现Filter接口,重写几个方法,加上WebFilter注解,表示拦截哪些路由,如上是所有请求都会拦截。 然后还需要在入口处加上SvlterComponentScan注解,因为Filter是javaweb三大组件之…

Leetcode刷题【每日n题】(8)

题目一 思路分析 1.循环遍历直到这个数小于102.获取每个位数之合3.将合赋值给目标数,直到小于10 代码实现 class MyTest{public int addDigits(int num) {//直到目标数小于10while(num>10){//定义各个位数合int sum0;//num不能为0while(num>0){//获取每个位上…