Graph U-Net Code【图分类】

1. main.py


# GNet是需要用到的model
net = GNet(G_data.feat_dim, G_data.num_class, args) # graph, 特征维度,类别数,参数
trainer = Trainer(args, net, G_data) #开始训练数据
# 正式开始训练数据
trainer.train()

2. network.py

class GNet(nn.Module):def __init__(self, in_dim, n_classes, args):super(GNet, self).__init__()self.n_act = getattr(nn, args.act_n)()# getattr() 是 Python 内置的一个函数,可以用来获取一个对象的属性值或方法self.c_act = getattr(nn, args.act_c)()# print('GNet1: in_dim=', in_dim, 'n_class=',n_classes)  # GNet1: in_dim= 82 n_class= 2"用的是GCN的框架,输入分别是feat dim、layer dim、network act、drop net(net表示GCN网络本身的参数)"self.s_gcn = GCN(in_dim, args.l_dim, self.n_act, args.drop_n)self.g_unet = GraphUnet(args.ks, args.l_dim, args.l_dim, args.l_dim, self.n_act, args.drop_n)"""nn.Linear定义一个神经网络的线性层,方法如下:torch.nn.Linear(in_features, # 输入的神经元个数out_features, # 输出神经元个数bias=True # 是否包含偏置)"""self.out_l_1 = nn.Linear(3*args.l_dim*(args.l_num+1), args.h_dim)self.out_l_2 = nn.Linear(args.h_dim, n_classes)"nn.Dropout(p = 0.3) # 表示每个神经元有0.3的可能性不被激活"self.out_drop = nn.Dropout(p=args.drop_c)Initializer.weights_init(self)def forward(self, gs, hs, labels):print('GNet2: gs=',type(gs), len(gs), 'hs=',type(hs), len(hs), 'labels:',type(labels),labels.shape)# GNet2: gs= <class 'list'> 32 hs= <class 'list'> 32 labels: <class 'torch.Tensor'> torch.Size([32])hs = self.embed(gs, hs)print('GNet2: hs=', type(hs), hs.shape)logits = self.classify(hs)return self.metric(logits, labels)

3. trainer.py

class Trainer:"init初始化,输入分别是arg参数、gcn net、graph Data,将这些装进self里面"def __init__(self, args, net, G_data):self.args = argsself.net = netself.feat_dim = G_data.feat_dimself.fold_idx = G_data.fold_idxself.init(args, G_data.train_gs, G_data.test_gs)# 若是有显卡,则用显卡跑if torch.cuda.is_available():self.net.cuda()"初始化——开始训练数据"def init(self, args, train_gs, test_gs):print('#train: %d, #test: %d' % (len(train_gs), len(test_gs)))# 分成训练集和测试集,记载数据train_data = GraphData(train_gs, self.feat_dim)test_data = GraphData(test_gs, self.feat_dim)# DataLoader 为pytorch 内部类,此时只需要指定trainset, batch_size, shuffle, num_workers, ...等self.train_d = train_data.loader(self.args.batch, True)self.test_d = test_data.loader(self.args.batch, False)self.optimizer = optim.Adam(self.net.parameters(), lr=self.args.lr, amsgrad=True,weight_decay=0.0008)
    def train(self):max_acc = 0.0train_str = 'Train epoch %d: loss %.5f acc %.5f'test_str = 'Test epoch %d: loss %.5f acc %.5f max %.5f'line_str = '%d:\t%.5f\n'for e_id in range(self.args.num_epochs):self.net.train()# 从每个epoch开始训练loss, acc = self.run_epoch(e_id, self.train_d, self.net, self.optimizer)print(train_str % (e_id, loss, acc))with torch.no_grad():self.net.eval()loss, acc = self.run_epoch(e_id, self.test_d, self.net, None)max_acc = max(max_acc, acc)print(test_str % (e_id, loss, acc, max_acc))with open(self.args.acc_file, 'a+') as f:f.write(line_str % (self.fold_idx, max_acc))
    def run_epoch(self, epoch, data, model, optimizer):#self.run_epoch(e_id, self.train_d, self.net, self.optimizer)losses, accs, n_samples = [], [], 0for batch in tqdm(data, desc=str(epoch), unit='b'):cur_len, gs, hs, ys = batchgs, hs, ys = map(self.to_cuda, [gs, hs, ys])loss, acc = model(gs, hs, ys)losses.append(loss*cur_len)accs.append(acc*cur_len)n_samples += cur_lenif optimizer is not None:optimizer.zero_grad()loss.backward()optimizer.step()avg_loss, avg_acc = sum(losses) / n_samples, sum(accs) / n_samplesreturn avg_loss.item(), avg_acc.item()

不懂

class GraphConvolution(Module):"""Simple GCN layer, similar to https://arxiv.org/abs/1609.02907"""def __init__(self, in_features, out_features, bias=True):super(GraphConvolution, self).__init__()self.in_features = in_featuresself.out_features = out_features"""为啥要这么做???5555555555555555555555555555"""self.weight = Parameter(torch.FloatTensor(in_features, out_features))if bias:self.bias = Parameter(torch.FloatTensor(out_features))else:self.register_parameter('bias', None)self.reset_parameters()def reset_parameters(self):stdv = 1. / math.sqrt(self.weight.size(1))self.weight.data.uniform_(-stdv, stdv)if self.bias is not None:self.bias.data.uniform_(-stdv, stdv)def forward(self, input, adj):support = torch.mm(input, self.weight)output = torch.spmm(adj, support)if self.bias is not None:return output + self.biaselse:return output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/125973.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

两个字符串的最小ASCII删除和

题目描述 给定两个字符串s1 和 s2&#xff0c;返回 使两个字符串相等所需删除字符的 ASCII 值的最小和 。 示例 思路 这个题的解法一和最长公共子序列的解法大致相同&#xff0c;我们可以在此代码基础上稍微更改即可。 代码如下 解法一 public int minimumDeleteSum1(Stri…

手机ip地址切换后有什么影响

随着互联网的普及和人们对网络连接的需求不断增加&#xff0c;手机已经成为我们日常生活中不可或缺的一部分。而在使用手机的过程中&#xff0c;手机ip地址的切换也成为了许多用户需要注意的问题。虎观代理小二二将探讨手机ip地址切换后可能产生的影响。 手机ip地址的含义及作…

遍历激光雷达数据

文章目录 初始化catkin工作空间(如果你还没有)创建功能包构建功能包配置环境CMakeLists.txttraverse_lidar_data.cctraverse_lidar_data.launchros::spin()初始化catkin工作空间(如果你还没有) mkdir -p catkin_ws/src cd catkin_ws/ catkin_make这会在你的目录下创建一个…

unittest与pytest的区别

Unittest vs Pytest 主要从用例编写规则、用例的前置和后置、参数化、断言、用例执行、失败重运行和报告这几个方面比较unittest和pytest的区别: 用例编写规则 用例前置与后置条件 断言 测试报告 失败重跑机制 参数化 用例分类执行 如果不好看&#xff0c;可以看下面表格&…

数据库开发软件Navicat Premium 15 mac中文软件介绍

Navicat Premium 15 mac是一款数据库开发工具&#xff0c;Navicat Premium 15 Mac版可以让你以单一程序同時连接到 MySQL、MariaDB、SQL Server、SQLite、Oracle 和 PostgreSQL 数据库。 Navicat Premium for Mac软件介绍 Navicat premium是一款数据库管理工具。将此工具连接数…

Android裁剪图片之后无法加载的问题

适配Android11之后更改了图片保存目录&#xff0c;导致裁剪之后图片一直无法加载&#xff08;fileNotfound&#xff09; 最主要的问题在于保存裁剪文件的目录不能为私有目录&#xff0c;因为裁剪工具是系统工具&#xff0c;无法直接访问项目本身的私有目录。 解决办法&#x…

京东数据分析:2023年9月京东洗地机行业品牌销售排行榜

鲸参谋监测的京东平台9月份洗地机市场销售数据已出炉&#xff01; 9月份&#xff0c;洗地机市场的销售额增长。根据鲸参谋电商数据分析平台的相关数据显示&#xff0c;9月京东平台上洗地机的销量为9.2万&#xff0c;销售额将近2.2亿&#xff0c;同比增长约9%。从价格上看&#…

一次cs上线服务器的练习

环境&#xff1a;利用vm搭建的环境 仅主机为65段 测试是否能与win10ping通 配置转发 配置好iis Kali访问测试 现在就用burp抓取winser的包 开启代理 使用默认的8080抓取成功 上线

Elasticsearch(一)---介绍

简介 Elasticsearch是一个基于Lucene的实际的分布式搜索和分析引擎。设计用于云计算中&#xff0c;能够达到近实时搜索&#xff0c;稳定&#xff0c;可靠&#xff0c;快速&#xff0c;安装使用方便。基于RESTful接口。 官网地址&#xff1a;Elasticsearch 平台 — 大规模查找…

国密SM算法及实现加密和解密

一 引入pom <dependency><groupId>com.antherd</groupId><artifactId>sm-crypto</artifactId><version>0.3.2</version></dependency> 二 代码实现 package com.example.ytyproject.component;import com.antherd.smcrypto.…

C++学习初探---‘C++面向对象‘-继承函数重载与运算符重载

文章目录 前言继承继承是什么&#xff1f;三种访问权限的继承&#xff1a; 函数重载与运算符重载函数重载运算符重载可重载运算符&不可重载运算符 前言 第三次学习记录&#xff0c;依旧是C面向对象的内容。 继承 继承是什么&#xff1f; C中的继承是一种面向对象编程&am…

企业 Tomcat 运维 部署tomcat反向代理集群

一、Tomcat 简介 Tomcat服务器是一个免费的开放源代码的Web应用服务器&#xff0c;属于轻量级应用服务器&#xff0c; Tomcat和Nginx、Apache(httpd)、Web服务器一样&#xff0c;具有处理HTML页面的功能不过Tomcat处理静态HTML的能力不如Nginx/Apache服务器 一个tomcat默认并…

vue中把弹出层.vue文件注册成组件供其他.vue文件调用的写法

背景&#xff1a;因弹出层多个页面的详情都是一样的&#xff0c;因此把弹出层定义成组件&#xff0c;多次调用 定义组件的过程中出现很多问题&#xff0c;因此再次记录最终成功的写法 一、 简单实现页面调用弹出层组件的打开弹出层方法&#xff1a; 1. 弹出层组件 (in…

Linux机器网络检查

查看DNS file: dianTestLRSSnapshot:~$ cat /etc/resolv.conf # This file is managed by man:systemd-resolved(8). Do not edit. # # This is a dynamic resolv.conf file for connecting local clients to the # internal DNS stub resolver of systemd-resolved. This file…

SpringBoot快速整合canal1.1.5(TCP模式)

SpringBoot快速整合canal1.1.5&#xff08;TCP模式&#xff09; 安装并配置MySQL主从⭐ 1&#xff1a;Docker安装MySQL8.0.28 docker pull mysql:8.0.282&#xff1a;创建目录&#xff1a; mkdir -p /usr/local/mysql8/data mkdir -p /usr/local/mysql8/log mkdir -p /usr/…

STL源码剖析系列:其一、list

一、基本用法 list的基本用法比较简单&#xff0c;可以参考站长严长生的教程&#xff1a; C list&#xff08;STL list&#xff09;容器完全攻略&#xff08;超级详细&#xff09; 下面重点介绍list源码。 二、

Web3 React项目Dapp获取智能合约对象

上文Web3 整理React项目 导入Web3 并获取区块链信息中&#xff0c;我们在react搭建的dapp中简单拿到了我们区块链中的账号授权信息 那 我们继续 先终端运行 ganache -d将ganache环境起起来 然后 我们运行 dapp 拿到授权列表 回到上文结束的一个状态 然后 我们发布一下自己的…

ArcGIS统计各种土地利用类型的总面积

如下图为研究区多个村的土地利用现状图,现在需统计每种类型的面积总和,以及每个行政村内各种土地利用类型的总面积。本文通过案例的形式,讲解ArcGIS中两种常用的分类统计面积的工具,建议收藏。 文章目录 1. 加载土地利用数据2. 常规属性汇总统计3. 汇总统计数据1. 加载土地…

软件安利——火绒安全

近年来&#xff0c;以优化、驱动、管理为目标所打造的软件屡见不鲜&#xff0c;大同小异的电脑管家相继走入了公众的视野。然而&#xff0c;在这日益急功近利的社会氛围驱动之下&#xff0c;真正坚持初心、优先考虑用户体验的电脑管家逐渐湮没在了浪潮之中。无论是鲁大师&#…

【C++】string类

STL STL(standard template libaray-标准模板库)&#xff1a;是C标准库的重要组成部分&#xff0c;不仅是一个可复用的组件库&#xff0c;而且是一个包罗数据结构与算法的软件框架。 为什么学习string类&#xff1f; 1、C语言中的字符串 C语言中&#xff0c;字符串是以\0结尾…