代码链接:RIPGeo代码实现
├── lib # 包含模型(model)实现文件
│ |── layers.py # 注意力机制的代码。
│ |── model.py # TrustGeo的核心源代码。
│ |── sublayers.py # layer.py的支持文件。
│ |── utils.py # 辅助函数。
一、导入常用库和模块
from __future__ import print_function
import numpy as np
import torch
import warnings
import torch.nn as nn
import random
import matplotlib.pyplot as plt
import copy
这段代码首先包含一些导入语句,接着进行一些版本和警告的处理,最后导入了一些常用的库(numpy
、torch
、matplotlib
),并定义了一些常用的模块(nn
,plt
)。
1、from __future__ import print_function:这是为了确保代码同时在Python 2和Python 3中都能正常运行。在Python 2中,print
是一个语句,而在Python 3中,print()
是一个函数。通过这个导入语句,可以在Python 2中使用Python 3风格的print
函数。
2、import numpy as np:导入NumPy库,并用np
作为别名。NumPy是一个用于科学计算的库,提供了数组等高性能数学运算工具。
3、import torch::导入PyTorch库。PyTorch是一个深度学习框架,提供了张量计算和神经网络搭建等功能。
4、import warnings:导入warnings
模块,用于处理警告。
5、import torch.nn as nn:导入PyTorch中的神经网络模块。
6、import random:导入Python的random
模块,用于生成伪随机数。
7、import matplotlib.pyplot as plt:导入matplotlib
库的pyplot
模块,用于绘制图表。
8、import copy:导入Python的copy
模块,用于复制对象。
二、warnings.filterwarnings(action='once')
warnings.filterwarnings(action='once')
设置了在使用warnings.filterwarnings
时的参数。filterwarnings
函数用于配置警告过滤器,以控制哪些警告会被触发,以及如何处理这些警告。
具体来说,action='once'
表示警告信息只会被显示一次。这对于一些可能会频繁触发的警告而言是一种控制方式,以避免在控制台或日志中大量重复的警告信息。在第一次触发警告时,它会被显示,但在后续的同类警告中,将不再显示。
请注意,这个配置仅适用于在warnings
模块中配置的警告,它并不会影响其他类型的警告或错误。
三、DataPerturb() 数据扰动
class DataPerturb:def __init__(self, eta=1):self.eta = etaself.loss = torch.nn.MSELoss(reduction='sum')def perturb(self, model, data):# originallm_X, lm_Y, tg_X, tg_Y, lm_delay, tg_delay = data# obtain new graph representation_, ori_graph_feature = model(lm_X, lm_Y, tg_X,tg_Y, lm_delay,tg_delay)# add Gaussian data perturbnew_lm_X, new_lm_Y, new_tg_X, new_tg_Y, new_lm_delay, new_tg_delay = lm_X.clone(), lm_Y.clone(), \tg_X.clone(), tg_Y.clone(), \lm_delay.clone(), tg_delay.clone()new_lm_X[:, -16:] += self.eta * torch.normal(0, torch.ones_like(new_lm_X[:, -16:]) * new_lm_X[:, -16:]).cuda()new_tg_X[:, -16:] += self.eta * torch.normal(0, torch.ones_like(new_tg_X[:, -16:]) * new_tg_X[:, -16:]).cuda()new_lm_delay += self.eta * torch.normal(0, torch.ones_like(new_lm_delay) * new_lm_delay).cuda()new_tg_delay += self.eta * torch.normal(0, torch.ones_like(new_tg_delay) * new_tg_delay).cuda()# obtain new graph representation_, new_graph_feature = model(new_lm_X, new_lm_Y, new_tg_X,new_tg_Y, new_lm_delay,new_tg_delay)data_loss = self.loss(ori_graph_feature, new_graph_feature)return data_loss
这段代码定义了一个名为 DataPerturb
的类,其目的是对给定的数据进行扰动,并计算扰动后的损失。
(一)__init__()
def __init__(self, eta=1):self.eta = etaself.loss = torch.nn.MSELoss(reduction='sum')
在 __init__
方法中,类初始化时可以指定一个参数 eta
,默认为1。该参数用于控制扰动的强度。
损失函数使用MSELoss。
(二)perturb()
def perturb(self, model, data):# originallm_X, lm_Y, tg_X, tg_Y, lm_delay, tg_delay = data# obtain new graph representation_, ori_graph_feature = model(lm_X, lm_Y, tg_X,tg_Y, lm_delay,