TrustGeo代码理解(三)model.py

代码链接:https://github.com/ICDM-UESTC/TrustGeo

一、导入各种模块和神经网络类

from math import gamma
from re import L
from .layers import *
import torch
import torch.nn as nn
import torch.nn.functional as Func
import numpy as np

这段代码是一个 Python 模块,包含了一些导入语句和定义了一个神经网络模型的类。

1、from math import gamma:导入了 gamma 函数,这是 Python 标准库中 math 模块中的一个函数,用于计算伽玛函数。
2、from re import L:导入了 L,这看起来是一个导入错误。通常来说,应该是导入正则表达式相关的模块,比如 import re。不过,这行可能是一个错误,可能需要修改。(好像没什么用)
3、from .layers import *:导入了当前模块所在目录中的 layers 模块中的所有内容。* 表示导入所有的内容。
4、import torch:导入了 PyTorch 库中的相关模块。torch 是主要的 PyTorch 模块。
5、import torch.nn as nn:导入了 PyTorch 库中的相关模块。tnn 包含了神经网络的构建块。
6、import torch.nn.functional as Func:导入了 PyTorch 库中的相关模块。functional 模块包含了一些与神经网络相关的函数。
7、import numpy as np:导入了 NumPy 库,NumPy 是一个用于科学计算的 Python 库,提供了大量用于数组操作的函数。

二、TrustGeo类定义(NN模型)

class TrustGeo(nn.Module):def __init__(self, dim_in):super(TrustGeo, self).__init__()self.dim_in = dim_inself.dim_z = dim_in + 2# TrustGeoself.att_attribute = SimpleAttention(temperature=self.dim_z ** 0.5,d_q_in=self.dim_in,d_k_in=self.dim_in,d_v_in=self.dim_in + 2,d_q_out=self.dim_z,d_k_out=self.dim_z,d_v_out=self.dim_z)# calculate Aself.gamma_1 = nn.Parameter(torch.ones(1, 1))self.gamma_2 = nn.Parameter(torch.ones(1, 1))self.gamma_3 = nn.Parameter(torch.ones(1, 1))self.alpha = nn.Parameter(torch.ones(1, 1))self.beta = nn.Parameter(torch.zeros(1, 1))# transform in Graphself.w_1 = nn.Linear(self.dim_in + 2, self.dim_in + 2)self.w_2 = nn.Linear(self.dim_in + 2, self.dim_in + 2)# higher-order evidence# graph view self.out_layer_graph_view = nn.Linear(self.dim_z*2, 5)# attribute view self.out_layer_attri_view = nn.Linear(self.dim_in, 5)# for output mu, v, alpha, betadef evidence(self, x):return Func.softplus(x)def trans(self, gamma1, gamma2, logv, logalpha, logbeta):v = self.evidence(logv)alpha = self.evidence(logalpha) + 1beta = self.evidence(logbeta)return gamma1, gamma2, v, alpha, betadef forward(self, lm_X, lm_Y, tg_X, tg_Y, lm_delay, tg_delay, add_noise=0):""":param lm_X: feature of landmarks [..., 30]: 14 attribute + 16 measurement:param lm_Y: location of landmarks [..., 2]: longitude + latitude:param tg_X: feature of targets [..., 30]:param tg_Y: location of targets [..., 2]:param lm_delay: delay from landmark to the common router [..., 1]:param tg_delay: delay from target to the common router [..., 1]:return:"""N1 = lm_Y.size(0)N2 = tg_Y.size(0)ones = torch.ones(N1 + N2 + 1).cuda()lm_feature = torch.cat((lm_X, lm_Y), dim=1)tg_feature_0 = torch.cat((tg_X, torch.zeros(N2, 2).cuda()), dim=1)router_0 = torch.mean(lm_feature, dim=0, keepdim=True)all_feature_0 = torch.cat((lm_feature, tg_feature_0, router_0), dim=0)'''star-GNNproperties:1. single directed graph: feature of <landmarks> will never be updated.2. the target IP will receive from surrounding landmarks from two ways: (1) attribute similarity-based one-hop propagation;(2) delay measurement-based two-hop propagation via the common router;'''# GNN-step 1adj_matrix_0 = torch.diag(ones)# star connections (measurement)delay_score = torch.exp(-self.gamma_1 * (self.alpha * lm_delay + self.beta))rou2tar_score_0 = torch.exp(-self.gamma_2 * (self.alpha * tg_delay + self.beta)).reshape(N2)# satellite connections (feature)_, attribute_score = self.att_attribute(tg_X, lm_X, lm_feature)attribute_score = torch.exp(attribute_score)adj_matrix_0[N1:N1 + N2, :N1] = attribute_scoreadj_matrix_0[-1, :N1] = delay_scoreadj_matrix_0[N1:N1 + N2:, -1] = rou2tar_score_0degree_0 = torch.sum(adj_matrix_0, dim=1)degree_reverse_0 = 1.0 / degree_0degree_matrix_reverse_0 = torch.diag(degree_reverse_0)degree_mul_adj_0 = degree_matrix_reverse_0 @ adj_matrix_0step_1_all_feature = self.w_1(degree_mul_adj_0 @ all_feature_0)tg_feature_1 = step_1_all_feature[N1:N1 + N2, :]router_1 = step_1_all_feature[-1, :].reshape(1, -1)# GNN-step 2adj_matrix_1 = torch.diag(ones)rou2tar_score_1 = torch.exp(-self.gamma_3 * (self.alpha * tg_delay + self.beta)).reshape(N2)adj_matrix_1[N1:N1 + N2:, -1] = rou2tar_score_1all_feature_1 = torch.cat((lm_feature, tg_feature_1, router_1), dim=0)degree_1 = torch.sum(adj_matrix_1, dim=1)degree_reverse_1 = 1.0 / degree_1degree_matrix_reverse_1 = torch.diag(degree_reverse_1)degree_mul_adj_1 = degree_matrix_reverse_1 @ adj_matrix_1step_2_all_feature = self.w_2(degree_mul_adj_1 @ all_feature_1)tg_feature_2 = step_2_all_feature[N1:N1 + N2, :]# graph viewtg_feature_graph_view = torch.cat((tg_feature_1,tg_feature_2), dim=-1)# attribute view (for shanghai dim=51) tg_feature_attribute_view = tg_X'''predict'''output1 = self.out_layer_graph_view(tg_feature_graph_view)gamma1_g, gamma2_g, v_g, alpha_g, beta_g = torch.split(output1, 1, dim=-1)# attributeoutput2 = self.out_layer_attri_view(tg_feature_attribute_view)gamma1_a, gamma2_a, v_a, alpha_a, beta_a = torch.split(output2, 1, dim=-1)# transform, let v>0, aplha>1, beta>0 gamma1_g, gamma2_g, v_g, alpha_g, beta_g = self.trans(gamma1_g, gamma2_g, v_g, alpha_g, beta_g)gamma1_a, gamma2_a, v_a, alpha_a, beta_a = self.trans(gamma1_a, gamma2_a, v_a, alpha_a, beta_a)two_gamma_g = torch.cat((gamma1_g, gamma2_g), dim=1)two_gamma_a = torch.cat((gamma1_a, gamma2_a), dim=1)return two_gamma_g, v_g, alpha_g, beta_g, \two_gamma_a, v_a, alpha_a, beta_a

这是一个 PyTorch 中神经网络模型的类定义,它继承自 nn.Module 类,表明这个类是一个 PyTorch 模型。

分为几个部分展开描述:

(一)__init__()

    def __init__(self, dim_in):super(TrustGeo, self).__init__()self.dim_in = dim_inself.dim_z = dim_in + 2# TrustGeoself.att_attribute = SimpleAttention(temperature=self.dim_z ** 0.5,d_q_in=self.dim_in,d_k_in=self.dim_in,d_v_in=self.dim_in + 2,d_q_out=self.dim_z,d_k_out=self.dim_z,d_v_out=self.dim_z)# calculate A

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/226688.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python中的程序逻辑经典案例详解

我的博客 文章首发于公众号&#xff1a;小肖学数据分析 Python作为一种强大的编程语言&#xff0c;以其简洁明了的语法和强大的标准库&#xff0c;成为了理想的工具来构建这些解决方案。 本文将通过Python解析几个经典的编程问题。 经典案例 水仙花数 问题描述&#xff1a…

极坐标下的牛拉法潮流计算39节点MATLAB程序

微❤关注“电气仔推送”获得资料&#xff08;专享优惠&#xff09; 潮流计算&#xff1a; 潮流计算是根据给定的电网结构、参数和发电机、负荷等元件的运行条件&#xff0c;确定电力系统各部分稳态运行状态参数的计算。通常给定的运行条件有系统中各电源和负荷点的功率、枢纽…

设计模式之建造者模式(二)

目录 概述概念角色类图适用场景 详述画小人业务类的介绍代码解析 建造者基本代码类介绍代码解析 总结设计原则其他 概述 概念 建造者模式是一种创建型设计模式&#xff0c;它可以将复杂对象的构建过程与其表示分离&#xff0c;使得同样的构建过程可以创建不同的表示。 角色 …

Python函数和模块的使用

我的博客 文章首发于公众号&#xff1a;小肖学数据分析 在开发过程中&#xff0c;函数和模块帮助我们将复杂的代码逻辑分解为可管理的部分&#xff0c;提升代码的可读性、可维护性和重用性。 本文将介绍如何在Python中有效利用函数和模块&#xff0c;提供详细的示例。 函数的…

【C++干货铺】会搜索的二叉树(BSTree)

个人主页点击直达&#xff1a;小白不是程序媛 C系列专栏&#xff1a;C干货铺 代码仓库&#xff1a;Gitee 目录 前言&#xff1a; 二叉搜索树 二叉搜索树概念 二叉搜索树操作 二叉搜索树的查找 二叉搜索树的插入 二叉搜索树元素的删除 ​二叉搜索树的实现 BSTree结点 …

GraphicsProfiler 使用教程

GraphicsProfiler 使用教程 1.工具简介&#xff1a;2.Navigation介绍2.1.打开安装好的Graphics Profiler。2.2.将手机连接到计算机&#xff0c;软件会在手机中安装一个GraphicsProfiler应用(该应用是无界面的&#xff09;。2.3.Show files list2.4.Record new trace2.4.1.Appli…

TSINGSEE视频智能解决方案边缘AI智能与后端智能分析的区别与应用

视频监控与AI人工智能的结合是当今社会安全领域的重要发展趋势。随着科技的不断进步&#xff0c;视频监控系统已经不再局限于简单的录像和监视功能&#xff0c;而是开始融入人工智能技术&#xff0c;实现更加智能化的监控和安全管理。传统的监控系统往往需要人工操作来进行监控…

Windows11安装python模块transformers报错Long Path处理

Windows11安装python模块transformers报错&#xff0c;报错信息如下 ERROR: Could not install packages due to an OSError: [Errno 2] No such file or directory: C:\\Users\\27467\\AppData\\Local\\Packages\\PythonSoftwareFoundation.Python.3.11_qbz5n2kfra8p0\\Local…

xcode 修改 target 中设备朝向崩溃

修改xcode的target中的设备朝向导致崩溃。 从日志上看好像没有什么特别的信息。 之后想了想&#xff0c;感觉这个应该还是跟xcode的配置有关系&#xff0c;不过改动的地方好像也只有plist。 就又翻腾了半天plist中的各种配置项&#xff0c;再把所有的用户权限提示相关的东西之…

重要通知!中国电信警告:用户须关闭路由器“双频合一”功能

在网络的无尽时空里&#xff0c;一场电信官方的宣战正酝酿中&#xff0c;目标锁定在我们日常生活中不可或缺的WiFi身上~ 最新消息曝光&#xff0c;竟然是路由器内藏的一个名为“双频合一”的功能引发了这场轰轰烈烈的网络风暴。 我们时常觉得WiFi就像是隐身在我们生活中的超级英…

call 和 apply:改变对象行为的秘密武器(上)

&#x1f90d; 前端开发工程师&#xff08;主业&#xff09;、技术博主&#xff08;副业&#xff09;、已过CET6 &#x1f368; 阿珊和她的猫_CSDN个人主页 &#x1f560; 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 &#x1f35a; 蓝桥云课签约作者、已在蓝桥云…

IntelliJ IDEA 运行 若依分离版后端

一、本地运行 一、选择打开IntelliJ IDEA项目 二、选择若依项目 如&#xff1a;java123 三、等待右下角的准备工作&#xff08;有进度条的&#xff09;完成 四、修改MySQL 五、修改资源上传目录 六、修改redis 七、然后点击运行 八、成功图 九、测试访问 二、部署服务器运行 …

初级数据结构(五)——树和二叉树的概念

文中代码源文件已上传&#xff1a;数据结构源码 <-上一篇 初级数据结构&#xff08;四&#xff09;——队列 | NULL 下一篇-> 1、树结构&#xff08;Tree&#xff09; 1.1、树结构的特点 自然界中的树由根部开始向上生长&#xff0c;随机长出分支&…

对自己的博客网站进行DOS攻击

对自己的博客网站进行DOS攻击 先说明一点,别对别人的网站进行ddos/dos攻击(dos攻击一般短时间攻击不下来),这是违法的,很多都有自动报警机制,本篇博客仅用于学习,请勿用于非法用途 安装kaili Linux 进入KALI官网,下载iso镜像文件 vmware新建虚拟机,选择自定义 点击下一步 …

ROS-ROS运行管理-工作空间覆盖;节点、话题、参数名称重名

文章目录 一、工作空间覆盖二、节点名称重名2.1 rosrun设置命名空间与重映射2.2 launch文件设置命名空间与重映射2.3 编码设置命名空间与重映射 三、话题名称设置3.1 rosrun设置话题重映射3.2 launch文件设置话题重映射3.3 编码设置话题名称 四、参数名称设置4.1 rosrun设置参数…

Github与Gitlab

学习目标 能够使用GitHub创建远程仓库并使用能够安装部署GitLab服务器能够使用GitLab创建仓库并使用掌握CI/CD的概念掌握蓝绿部署, 滚动更新,灰度发布的概念 GitHub是目前最火的开源项目代码托管平台。它是基于web的Git仓库&#xff0c;提供公有仓库和私有仓库&#xff0c;但私…

使用Go实现一个百行聊天服务器

前段时间, redis作者不是整了个c语言版本的聊天服务器嘛, 地址, 代码量拢共不过百行. 于是, 心血来潮下, 我也整了个Go语言版本. 简单来说就是实现了一个聊天室的功能. 将所有注释空行都去掉, 刚好100行实现. 废话不多说, 先上代码: package mainimport ("fmt"&quo…

SoC中跨时钟域的信号同步设计(单比特同步设计)

一、 亚稳态 在数字电路中&#xff0c;触发器是一种很常用的器件。对于任意一个触发器&#xff0c;都由其参数库文件规定了能正常使用的“建立时间”&#xff08;Setup time&#xff09;和“保持时间”&#xff08;Hold time &#xff09;两个参数。“建立时间”是指在时钟…

【MySQL学习之基础篇】多表查询

文章目录 1. 多表关系1.1. 一对多1.2. 多对多1.3. 一对一 2. 多表查询概述2.1. 数据准备2.2. 概述 3. 查询的分类3.1. 内连接查询3.2. 外连接查询3.3. 自连接3.3.1. 自连接查询3.3.2. 联合查询 3.4. 子查询3.4.1. 概述3.4.2. 标量子查询3.4.3. 列子查询3.4.4. 行子查询3.4.5. 表…

python+requests+pytest 接口自动化实现

最近工作之余拿公司的项目写了一个接口测试框架&#xff0c;功能还不是很完善&#xff0c;算是抛砖引玉了&#xff0c;欢迎各位来吐槽。 主要思路&#xff1a; ①对 requests 进行二次封装&#xff0c;做到定制化效果 ②使用 excel 存放接口请求数据&#xff0c;作为数据驱动 ③…