Pointnet++改进:更换不同的激活函数,打造更优性能

简介:
1.该教程提供大量的首发改进的方式,降低上手难度,多种结构改进,助力寻找创新点!
2.本篇文章对Pointnet++进行激活函数的改进,助力解决RELU激活函数缺陷。
3.专栏持续更新,紧随最新的研究内容。


文章目录

  • 步骤一
  • 步骤二
  • 步骤三


代码地址

步骤一

新建activate.py文件,我存放在新建的block目录下,加入以下代码:

# Activation functionsimport torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np# SiLU https://arxiv.org/pdf/1606.08415.pdf ----------------------------------------------------------------------------
class SiLU(nn.Module):  # export-friendly version of nn.SiLU()@staticmethoddef forward(x):return x * torch.sigmoid(x)class Hardswish(nn.Module):  # export-friendly version of nn.Hardswish()@staticmethoddef forward(x):# return x * F.hardsigmoid(x)  # for torchscript and CoreMLreturn x * F.hardtanh(x + 3, 0., 6.) / 6.  # for torchscript, CoreML and ONNXclass MemoryEfficientSwish(nn.Module):class F(torch.autograd.Function):@staticmethoddef forward(ctx, x):ctx.save_for_backward(x)return x * torch.sigmoid(x)@staticmethoddef backward(ctx, grad_output):x = ctx.saved_tensors[0]sx = torch.sigmoid(x)return grad_output * (sx * (1 + x * (1 - sx)))def forward(self, x):return self.F.apply(x)# Mish https://github.com/digantamisra98/Mish --------------------------------------------------------------------------
class Mish(nn.Module):@staticmethoddef forward(x):return x * F.softplus(x).tanh()class MemoryEfficientMish(nn.Module):class F(torch.autograd.Function):@staticmethoddef forward(ctx, x):ctx.save_for_backward(x)return x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + aconcxunlian(x)))@staticmethoddef backward(ctx, grad_output):x = ctx.saved_tensors[0]sx = torch.sigmoid(x)fx = F.softplus(x).tanh()return grad_output * (fx + x * sx * (1 - fx * fx))def forward(self, x):return self.F.apply(x)# FReLU https://arxiv.org/abs/2007.11824 -------------------------------------------------------------------------------
class FReLU(nn.Module):def __init__(self, c1, k=3):  # ch_in, kernelsuper().__init__()self.conv = nn.Conv2d(c1, c1, k, 1, 1, groups=c1, bias=False)self.bn = nn.BatchNorm2d(c1)def forward(self, x):return torch.max(x, self.bn(self.conv(x)))class GELU(nn.Module):def __init__(self):super(GELU, self).__init__()def forward(self, x):return 0.5 * x * (1 + torch.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))#
class MetaAconC(nn.Module):r""" ACON activation (activate or not).MetaAconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is generated by a small networkaccording to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>."""def __init__(self, c1, k=1, s=1, r=16):  # ch_in, kernel, stride, rsuper().__init__()c2 = max(r, c1 // r)self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))self.fc1 = nn.Conv2d(c1, c2, k, s, bias=True)self.fc2 = nn.Conv2d(c2, c1, k, s, bias=True)# self.bn1 = nn.BatchNorm2d(c2)# self.bn2 = nn.BatchNorm2d(c1)def forward(self, x):y = x.mean(dim=2, keepdims=True).mean(dim=3, keepdims=True)# batch-size 1 bug/instabilities https://github.com/ultralytics/yolov5/issues/2891# beta = torch.sigmoid(self.bn2(self.fc2(self.bn1(self.fc1(y)))))  # bug/unstablebeta = torch.sigmoid(self.fc2(self.fc1(y)))  # bug patch BN layers removeddpx = (self.p1 - self.p2) * xreturn dpx * torch.sigmoid(beta * dpx) + self.p2 * x
###
class AconC(nn.Module):"""ACON https://arxiv.org/pdf/2009.04759.pdfACON activation (activate or not).AconC: (p1*x-p2*x) * sigmoid(beta*(p1*x-p2*x)) + p2*x, beta is a learnable parameteraccording to "Activate or Not: Learning Customized Activation" <https://arxiv.org/pdf/2009.04759.pdf>."""def __init__(self, c1):super().__init__()self.p1 = nn.Parameter(torch.randn(1, c1, 1, 1))self.p2 = nn.Parameter(torch.randn(1, c1, 1, 1))self.beta = nn.Parameter(torch.ones(1, c1, 1, 1))def forward(self, x):dpx = (self.p1 - self.p2) * xreturn dpx * torch.sigmoid(self.beta * dpx) + self.p2 * x

步骤二

在models/pointnet2_utils.py中加入以下代码,该代码将PointNetSetAbstraction中的mlp三层感知机重新封装成一个class Conv模块,便于直接在Conv模块中修改激活函数,修改后的代码和源码结构是一致的。修改不同的激活函数直接在Conv类中修改即可。
PointNetSetAbstraction结构图如下,PointNetSetAbstractionMSG比PointNetSetAbstraction多一个不同尺度的三层mlp,其他结构是一样的。
在这里插入图片描述

class Conv(nn.Module):# Standard convolutiondef __init__(self, c1, c2, k=1):  # ch_in, ch_out, kernel, stride, padding, groupssuper(Conv, self).__init__()self.conv = nn.Conv2d(c1, c2, k)self.bn = nn.BatchNorm2d(c2)#self.act = nn.SiLU()#self.act = nn.LeakyReLU(0.1)self.act = nn.ReLU()#self.act = MetaAconC(c2)#self.act = AconC(c2)#self.act = Mish()#self.act = Hardswish()#self.act = FReLU(c2)def forward(self, x):return self.act(self.bn(self.conv(x)))def fuseforward(self, x):return self.act(self.conv(x))class PointNetSetAbstractionAttention(nn.Module):def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):super(PointNetSetAbstractionAttention, self).__init__()self.npoint = npointself.radius = radiusself.nsample = nsample#self.mlp_convs = nn.ModuleList()self.mlp_conv1 = Conv(in_channel,mlp[0],1)self.mlp_attention = CBAM(mlp[0])self.mlp_conv2 = Conv(mlp[0],mlp[1],1)self.mlp_conv3 = Conv(mlp[1],mlp[2],1)self.group_all = group_alldef forward(self, xyz, points):"""Input:xyz: input points position data, [B, C, N]points: input points data, [B, D, N]Return:new_xyz: sampled points position data, [B, C, S]new_points_concat: sample points feature data, [B, D', S]"""xyz = xyz.permute(0, 2, 1)if points is not None:points = points.permute(0, 2, 1)if self.group_all:new_xyz, new_points = sample_and_group_all(xyz, points)else:new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)# new_xyz: sampled points position data, [B, npoint, C]# new_points: sampled points data, [B, npoint, nsample, C+D]new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, nsample,npoint]new_points=self.mlp_conv1(new_points)new_points = self.mlp_attention(new_points)new_points = self.mlp_conv2(new_points)new_points = self.mlp_conv3(new_points)new_points = torch.max(new_points, 2)[0]new_xyz = new_xyz.permute(0, 2, 1)return new_xyz, new_pointsclass PointNetSetAbstractionMsgAttention(nn.Module):def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):super(PointNetSetAbstractionMsgAttention, self).__init__()self.npoint = npointself.radius_list = radius_listself.nsample_list = nsample_listself.mlp_conv00 = Conv(in_channel+3,mlp_list[0][0],1)self.mlp_conv01 = Conv(mlp_list[0][0],mlp_list[0][1],1)self.mlp_conv02 = Conv(mlp_list[0][1],mlp_list[0][2],1)self.mlp_conv10 = Conv(in_channel+3,mlp_list[1][0],1)self.mlp_conv11 = Conv(mlp_list[1][0],mlp_list[1][1],1)self.mlp_conv12 = Conv(mlp_list[1][1],mlp_list[1][2],1)# self.conv_blocks = nn.ModuleList()# self.bn_blocks = nn.ModuleList()# for i in range(len(mlp_list)):#     convs = nn.ModuleList()#     bns = nn.ModuleList()#     last_channel = in_channel + 3#     for out_channel in mlp_list[i]:#         convs.append(nn.Conv2d(last_channel, out_channel, 1))#         bns.append(nn.BatchNorm2d(out_channel))#         last_channel = out_channel#     self.conv_blocks.append(convs)#     self.bn_blocks.append(bns)def forward(self, xyz, points):"""Input:xyz: input points position data, [B, C, N]points: input points data, [B, D, N]Return:new_xyz: sampled points position data, [B, C, S]new_points_concat: sample points feature data, [B, D', S]"""xyz = xyz.permute(0, 2, 1)if points is not None:points = points.permute(0, 2, 1)B, N, C = xyz.shapeS = self.npointnew_xyz = index_points(xyz, farthest_point_sample(xyz, S))new_points_list = []for i, radius in enumerate(self.radius_list):K = self.nsample_list[i]group_idx = query_ball_point(radius, K, xyz, new_xyz)grouped_xyz = index_points(xyz, group_idx)grouped_xyz -= new_xyz.view(B, S, 1, C)if points is not None:grouped_points = index_points(points, group_idx)grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)else:grouped_points = grouped_xyzgrouped_points = grouped_points.permute(0, 3, 2, 1)  # [B, D, K, S]if i==0:grouped_points =self.mlp_conv00(grouped_points)grouped_points = self.mlp_conv01(grouped_points)grouped_points = self.mlp_conv02(grouped_points)else:grouped_points = self.mlp_conv10(grouped_points)grouped_points = self.mlp_conv11(grouped_points)grouped_points = self.mlp_conv12(grouped_points)# for j in range(len(self.conv_blocks[i])):#     conv = self.conv_blocks[i][j]#     bn = self.bn_blocks[i][j]#     grouped_points =  F.relu(bn(conv(grouped_points)))new_points = torch.max(grouped_points, 2)[0]  # [B, D', S]new_points_list.append(new_points)new_xyz = new_xyz.permute(0, 2, 1)new_points_concat = torch.cat(new_points_list, dim=1)return new_xyz, new_points_concat

步骤三

在不同的模型中修改调用即可,如在models/pointnet2_sem_seg.py文件中修改,训练即可

import torch.nn as nn
import torch.nn.functional as F
# from models.pointnet2_utils import PointNetSetAbstraction, PointNetFeaturePropagation, PointNetSetAbstractionKPconv, \
#     PointNetSetAbstractionAttention
from models.pointnet2_utils import *class get_model(nn.Module):def __init__(self, num_classes):super(get_model, self).__init__()self.sa1 = PointNetSetAbstractionAttention(1024, 0.1, 32, 9 + 3, [32, 32, 64], False)self.sa2 = PointNetSetAbstraction(256, 0.2, 32, 64 + 3, [64, 64, 128], False)self.sa3 = PointNetSetAbstraction(64, 0.4, 32, 128 + 3, [128, 128, 256], False)self.sa4 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 256, 512], False)self.fp4 = PointNetFeaturePropagation(768, [256, 256])self.fp3 = PointNetFeaturePropagation(384, [256, 256])self.fp2 = PointNetFeaturePropagation(320, [256, 128])self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])self.conv1 = nn.Conv1d(128, 128, 1)self.bn1 = nn.BatchNorm1d(128)self.drop1 = nn.Dropout(0.5)self.conv2 = nn.Conv1d(128, num_classes, 1)def forward(self, xyz):l0_points = xyzl0_xyz = xyz[:,:3,:]l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)l4_xyz, l4_points = self.sa4(l3_xyz, l3_points)l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points)l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)l0_points = self.fp1(l0_xyz, l1_xyz, None, l1_points)x = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))x = self.conv2(x)x = F.log_softmax(x, dim=1)x = x.permute(0, 2, 1)return x, l4_pointsclass get_loss(nn.Module):def __init__(self):super(get_loss, self).__init__()self.gamma=2def forward(self, pred, target, trans_feat, weight):#pred: 模型预测的输出   target: 真实的标签或数据,用于计算损失total_loss = F.nll_loss(pred, target, weight=weight)return total_loss
if __name__ == '__main__':import  torchmodel = get_model(13)xyz = torch.rand(6, 9, 2048)(model(xyz))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/592405.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端Web系统架构设计

文章目录 1.目录结构定义2. 路由封装2.1 API路由定义2.2 组件路由定义 3. Axios请求开发4. 环境变量封装5. storage模块封装(sessionStorage, localStorage)6. 公共函数封装(日期,金额,权限..)7. 通用交互定义(删除二次确认,类别,面包屑...)8. 接口全貌概览 1.目录结构定义 2. …

LeetCode刷题--- 三步问题

个人主页&#xff1a;元清加油_【C】,【C语言】,【数据结构与算法】-CSDN博客 个人专栏 力扣递归算法题 http://t.csdnimg.cn/yUl2I 【C】 ​​​​​​http://t.csdnimg.cn/6AbpV 数据结构与算法 ​​​http://t.csdnimg.cn/hKh2l 前言&#xff1a;这个专栏主要讲述动…

【Matlab】PSO-BP 基于粒子群算法优化BP神经网络的数据时序预测(附代码)

资源下载&#xff1a; https://download.csdn.net/download/vvoennvv/88689096 一&#xff0c;概述 PSO-BP算法是一种结合了粒子群算法&#xff08;PSO&#xff09;和BP神经网络的方法&#xff0c;用于数据时序预测。下面是PSO-BP算法的原理和过程&#xff1a; 1. 数据准备&…

继承和多态

全局变量&#xff0c;int monster 10000:定义英雄类hero&#xff0c;受保护的属性string name&#xff0c;int hp,int attck;公有的无参构造&#xff0c;有参构造&#xff0c;虚成员函数 void Ak(blood-0)&#xff0c;法师类继承自英雄类&#xff0c;私有属性 int p_atk50;重写…

Github 2024-01-03 开源项目日报 Top10

根据Github Trendings的统计&#xff0c;今日(2024-01-03统计)共有10个项目上榜。根据开发语言中项目的数量&#xff0c;汇总情况如下&#xff1a; 开发语言项目数量Python项目3TypeScript项目3Jupyter Notebook项目1Dart项目1C项目1Rust项目1 系统设计指南 创建周期&#x…

计算机毕业设计 SpringBoot的停车场管理系统 Javaweb项目 Java实战项目 前后端分离 文档报告 代码讲解 安装调试

&#x1f34a;作者&#xff1a;计算机编程-吉哥 &#x1f34a;简介&#xff1a;专业从事JavaWeb程序开发&#xff0c;微信小程序开发&#xff0c;定制化项目、 源码、代码讲解、文档撰写、ppt制作。做自己喜欢的事&#xff0c;生活就是快乐的。 &#x1f34a;心愿&#xff1a;点…

堆排序 Heapsort

堆排序&#xff08;Heapsort&#xff09;是指利用堆这种数据结构所设计的一种排序算法。堆积是一个近似完全二叉树的结构&#xff0c;并同时满足堆积的性质&#xff1a;即子结点的键值或索引总是小于&#xff08;或者大于&#xff09;它的父节点。堆排序可以说是一种利用堆的概…

简易机器学习笔记(四)初识卷积神经网络

前言 第一次写卷积神经网络&#xff0c;也是照着paddlepaddle的官方文档抄&#xff0c;这里简单讲解一下心得。 首先我们要知道之前写的那些东西都是什么&#xff0c;之前写的我们称之为简单神经网络&#xff0c;也就是简单一层连接输出和输出&#xff0c;通过前向计算和逆向…

Simple Facebook Sign-In

简单的Facebook登录为Android、iOS、Windows、Mac、通用Windows平台(UWP)和Unity制作的WebGL应用程序提供了基于OAuth 2.0的Facebook登录。 优点: ● 跨平台游戏和应用程序的跨平台用户身份验证 ● 无插件,无第三方库,无依赖● 对建筑规模没有影响 ● 客户端-服务器应…

solidity显示以太坊美元价格

看过以太坊白皮书的都知道&#xff0c;以太坊比较比特币而言所提升的地方中&#xff0c;我认为最重要的一点就是能够访问外部的数据&#xff0c;这一点在赌博、金融领域应用会很广泛&#xff0c;但是区块链是一个确定的系统&#xff0c;包括里面的所有数值包括交易ID等都是确定…

OS 7--DNS配置+Apache发布网站

环境准备 centOS 7 1.配置DNS 1.1 域名为lianxi.com 1.2 为WWW服务器、FTP服务器、NEWS服务器做域名解析 1)安装DNS yum -y install bind bind-utils (如果安装不上&#xff0c;就把磁盘在重洗挂载一下&#xff09; 2&#xff09;修改DNS配置文件 vim /etc/resolv.conf…

车载 Android之 核心服务 - CarPropertyService 解析

重要类的源码文件名及位置&#xff1a; CarPropertyManager.java packages/services/Car/car-lib/src/android/car/hardware/property/ CarPropertyService.java packages/services/Car/service/src/com/android/car/ 类的介绍&#xff1a; CarPropertyManager&#xff1a…

航芯ACM32G103开发板评测 02-GPIO输入输出

航芯ACM32G103开发板评测 02-GPIO输入输出 航芯ACM32G103开发板评测 GPIO输入输出应用 软硬件平台 ACM32G103 Board开发板 MDK-ARM Keil GPIO输出典型应用——点灯 GPIO输入典型应用——按键 GPIO 功能概述 GPIO 是通用输入/输出&#xff08;General Purpose I/O&#x…

[Flutter]WindowsOS中相关配置

Flutter项目在Windows平台上如何配置 目录 Flutter项目在Windows平台上如何配置 写在开头 正文 1、OS准备 2、编译环境准备 ① 下载AndroidStudio ② 下载dart ③ 下载flutter ④ 下载并安装VS ⑤ 在AS中配置dart和flutter 3、配置中遇到的问题 写在结尾 写在开头…

C++ stack使用、模拟实现、OJ题

目录 一、介绍 二、常用函数 三、模拟实现 四、OJ练习题 1、最小栈 2、栈的压入、弹出序列 3、逆波兰表达式(后缀转中缀) 4、中缀转后缀思路 5、用栈实现队列 一、介绍 stack是一种容器适配器&#xff0c;专门用在具有后进先出操作的上下文环境中&#xff0c;其删除…

自动驾驶论文

文章目录 一、Convolutional Social Pooling for Vehicle Trajectory Prediction二、QCNet&#xff1a;Query-Centric Trajectory Prediction三、VectorNet: Encoding HD Maps and Agent Dynamics from Vectorized Representation 一、Convolutional Social Pooling for Vehicl…

iOS 小组件开发

iOS14之后Apple引入了新的WidgetKit&#xff0c;舍弃了原有额TodayExtension。 开发准备&#xff1a; 新的WidgetExtension只能通过SwiftUI进行开发&#xff1b; Widget有三种尺寸&#xff1a;systemSmall、 systemMedium、systemLarge&#xff0c;三种尺寸对应固定的UI类型布…

BIND-DNS配置介绍

一、主要配置文件 /etc/named.conf options { //Option 段全部配置 listen-on port 53 { 127.0.0.1; };//表示BIND将在53端口监听&#xff0c;若需要对所有IP进行监听&#xff0c;则修改为// listen-on port 53 { any; }; directory "/var/named"…

(六)数码管动态刷新

文章目录 如何实现利用人眼的余晖效应&#xff08;100hz无闪烁&#xff09;1ms刷一个数码管 8个看起来就是一块亮的 结合前面内容进行操作前面内容传送门&#xff1a;如何段选原理图代码写法这里借助isp复制共阴数码管码值 如何位选原理图代码写法 如何消隐在每次 段选 赋值之前…

K8S集群部署MySql

挂载MySQL数据卷 在k8s集群中挂载MySQL数据卷 需要安装一个NFS。 在主节点安装NFS yum install -y nfs-utils rpcbind 在主节点创建目录 mkdir -p /nfs chmod 777 /nfs 更改归属组与用户 chown -R nfsnobody:nfsnobody /nfs 配置共享目录 echo "/nfs *(insecure,rw,s…