pytorch代码实现之空间通道重组卷积SCConv

空间通道重组卷积SCConv

空间通道重组卷积SCConv,全称Spatial and Channel Reconstruction Convolution,CPR2023年提出,可以即插即用,能够在减少参数的同时提升性能的模块。其核心思想是希望能够实现减少特征冗余从而提高算法的效率。一般压缩模型的方法分为三种,分别是network pruning, weight quantization, low-rank factorization以及knowledge distillation,虽然这些方法能够达到减少参数的效果,但是往往都会导致模型性能的衰减。另一种方法就是在构建模型时利用特殊的模块或操作减少模型参数,获得轻量级的网络模型,这种方法能够在保证性能的同时达到参数减少的效果。

原文地址:SCConv: Spatial and Channel Reconstruction Convolution for Feature Redundancy

作者提出的SCConv包含两部分,分别是Spatial Reconstruction Unit (SRU)和Channel Reconstruction Unit (CRU),下面是SSConv的总体结构。
SCConv结构原理图
可以看出,SCConv模块设计,对于输入的特征图先利用1x1的卷积改变为适合的通道数,之后便分别是SRU和CRU两个模块对于特征图进行处理,最后在通过1x1的卷积将特征通道数恢复并进行残差操作。
SRU模块结构
CRU模块结构

代码实现如下:

import torch
import torch.nn.functional as F
import torch.nn as nn class GroupBatchnorm2d(nn.Module):def __init__(self, c_num:int, group_num:int = 16, eps:float = 1e-10):super(GroupBatchnorm2d,self).__init__()assert c_num    >= group_numself.group_num  = group_numself.gamma      = nn.Parameter( torch.randn(c_num, 1, 1)    )self.beta       = nn.Parameter( torch.zeros(c_num, 1, 1)    )self.eps        = epsdef forward(self, x):N, C, H, W  = x.size()x           = x.view(   N, self.group_num, -1   )mean        = x.mean(   dim = 2, keepdim = True )std         = x.std (   dim = 2, keepdim = True )x           = (x - mean) / (std+self.eps)x           = x.view(N, C, H, W)return x * self.gamma + self.betaclass SRU(nn.Module):def __init__(self,oup_channels:int, group_num:int = 16,gate_treshold:float = 0.5 ):super().__init__()self.gn             = GroupBatchnorm2d( oup_channels, group_num = group_num )self.gate_treshold  = gate_tresholdself.sigomid        = nn.Sigmoid()def forward(self,x):gn_x        = self.gn(x)w_gamma     = self.gn.gamma/sum(self.gn.gamma)reweigts    = self.sigomid( gn_x * w_gamma )# Gateinfo_mask   = reweigts>=self.gate_tresholdnoninfo_mask= reweigts<self.gate_tresholdx_1         = info_mask * xx_2         = noninfo_mask * xx           = self.reconstruct(x_1,x_2)return xdef reconstruct(self,x_1,x_2):x_11,x_12 = torch.split(x_1, x_1.size(1)//2, dim=1)x_21,x_22 = torch.split(x_2, x_2.size(1)//2, dim=1)return torch.cat([ x_11+x_22, x_12+x_21 ],dim=1)class CRU(nn.Module):'''alpha: 0<alpha<1'''def __init__(self, op_channel:int,alpha:float = 1/2,squeeze_radio:int = 2 ,group_size:int = 2,group_kernel_size:int = 3,):super().__init__()self.up_channel     = up_channel   =   int(alpha*op_channel)self.low_channel    = low_channel  =   op_channel-up_channelself.squeeze1       = nn.Conv2d(up_channel,up_channel//squeeze_radio,kernel_size=1,bias=False)self.squeeze2       = nn.Conv2d(low_channel,low_channel//squeeze_radio,kernel_size=1,bias=False)#upself.GWC            = nn.Conv2d(up_channel//squeeze_radio, op_channel,kernel_size=group_kernel_size, stride=1,padding=group_kernel_size//2, groups = group_size)self.PWC1           = nn.Conv2d(up_channel//squeeze_radio, op_channel,kernel_size=1, bias=False)#lowself.PWC2           = nn.Conv2d(low_channel//squeeze_radio, op_channel-low_channel//squeeze_radio,kernel_size=1, bias=False)self.advavg         = nn.AdaptiveAvgPool2d(1)def forward(self,x):# Splitup,low  = torch.split(x,[self.up_channel,self.low_channel],dim=1)up,low  = self.squeeze1(up),self.squeeze2(low)# TransformY1      = self.GWC(up) + self.PWC1(up)Y2      = torch.cat( [self.PWC2(low), low], dim= 1 )# Fuseout     = torch.cat( [Y1,Y2], dim= 1 )out     = F.softmax( self.advavg(out), dim=1 ) * outout1,out2 = torch.split(out,out.size(1)//2,dim=1)return out1+out2class ScConv(nn.Module):def __init__(self,op_channel:int,group_num:int = 16,gate_treshold:float = 0.5,alpha:float = 1/2,squeeze_radio:int = 2 ,group_size:int = 2,group_kernel_size:int = 3,):super().__init__()self.SRU = SRU( op_channel, group_num            = group_num,  gate_treshold        = gate_treshold )self.CRU = CRU( op_channel, alpha                = alpha, squeeze_radio        = squeeze_radio ,group_size           = group_size ,group_kernel_size    = group_kernel_size )def forward(self,x):x = self.SRU(x)x = self.CRU(x)return xif __name__ == '__main__':x       = torch.randn(1,32,16,16)model   = ScConv(32)print(model(x).shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/70205.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【探索Linux】—— 强大的命令行工具 P.8(进程优先级、环境变量)

阅读导航 前言一、进程优先级1. 优先级概念2. Linux查看系统进程3. PRI&#xff08;Priority&#xff09;和NI&#xff08;Nice&#xff09; 二、环境变量1. 概念2. 查看环境变量方法3. 环境变量的组织方式4.通过代码获取环境变量5. 环境变量的特点 总结温馨提示 前言 前面我们…

C++ - 多态的实现原理

前言 本博客主要介绍C 当中 多态语法的实现原理&#xff0c;如果有对 多态语法 有疑问的&#xff0c;请看下面这篇博客&#xff1a; 探究&#xff0c;为什么多态的条件是那样的&#xff08;虚函数表&#xff09; 首先&#xff0c;调用虚函数必须是 父类的 指针或 引用&#xf…

KT142C-sop16语音芯片ic的功能介绍 支持pwm和dac输出 usb直接更新内置空间

1.1 简介 KT142C是一个提供串口的SOP16语音芯片&#xff0c;完美的集成了MP3的硬解码。内置330KByte的空间&#xff0c;最大支持330秒的语音长度&#xff0c;支持多段语音&#xff0c;支持直驱0.5W的扬声器无需外置功放 软件支持串口通信协议&#xff0c;默认波特率9600.同时…

opencv旋转图像

0 、使用旋转矩阵旋转 import cv2img cv2.imread(img.jpg, 1) (h, w) img.shape[:2] # 获取图像的宽和高# 定义旋转中心坐标 center (w / 2, h / 2)# 定义旋转角度 angle 90# 定义缩放比例 scale 1# 获得旋转矩阵 M cv2.getRotationMatrix2D(center, angle, scale)# 进行…

比亚迪海豹:特斯拉强劲对手,瑞银拆解成本比同级车型低15%~35%

瑞银证券日前对中国电动车产品比亚迪海豹进行了拆解&#xff0c;发现海豹具有强大的成本优势&#xff0c;而这个优势主要来自于中国本土生产和国内完善的电动车供应链以及比亚迪的垂直整合体系和零部件高度集成性。比亚迪的整车成本比同级别竞争车型分别低15%至35%。 瑞银预测&…

【100天精通Python】Day55:Python 数据分析_Pandas数据选取和常用操作

目录 Pandas数据选择和操作 1 选择列和行 2 过滤数据 3 添加、删除和修改数据 4 数据排序 Pandas数据选择和操作 Pandas是一个Python库&#xff0c;用于数据分析和操作&#xff0c;提供了丰富的功能来选择、过滤、添加、删除和修改数据。 1 选择列和行 Pandas 提供了多种…

学习Bootstrap 5的第六天

目录 信息警告框 警告框 实例 警告框链接 实例 关闭警告框 实例 警告框动画 实例 按钮 按钮样式 实例 按钮轮廓 实例 ​编辑按钮尺寸 实例 块级按钮 实例 实例 活动/禁用按钮 实例 加载器按钮 实例 扩展小知识 信息警告框 警告框 警告框是使用 .aler…

ETCD详解

一、etcd概念 ETCD 是一个高可用的分布式键值key-value数据库&#xff0c;可用于服务发现。 ETCD 采用raft 一致性算法&#xff0c;基于 Go语言实现。 etcd作为一个高可用键值存储系统&#xff0c;天生就是为集群化而设计的。由于Raft算法在做决策时需要多数节点的投票&…

【算法】归并排序 详解

归并排序 详解 归并排序代码实现1. 递归版本2. 非递归版本 排序&#xff1a; 排序&#xff0c;就是使一串记录&#xff0c;按照其中的某个或某些关键字的大小&#xff0c;递增或递减的排列起来的操作。 稳定性&#xff1a; 假定在待排序的记录序列中&#xff0c;存在多个具有相…

eclipse进入断点之后,一直卡死,线程一直在运行【记录一种情况】

问题描述: 一直卡死在某个断点处&#xff0c;取消断点也是卡死在这边的进程处。 解决方式&#xff1a; 将JDK的使用内存进行了修改 ① 打开eclipse&#xff0c;window->preference->Java->Installed JREs&#xff0c;选中使用的jdk然后点击右侧的edit&#xff0c;在…

【算法】插入排序

插入排序 插入排序代码实现代码优化 排序&#xff1a; 排序&#xff0c;就是使一串记录&#xff0c;按照其中的某个或某些关键字的大小&#xff0c;递增或递减的排列起来的操作。 稳定性&#xff1a; 假定在待排序的记录序列中&#xff0c;存在多个具有相同的关键字的记录&…

npm/yarn link 测试包时报错 Warning: Invalid hook call. Hooks can only be called ...

使用 dumi 开发 React 组件库时&#xff0c;为避免每次修改都发布到 npm&#xff0c;需要在本地的测试项目中使用 npm link 为组件库建立软连接&#xff0c;方便本地调试。 结果在本地测试项目使用 $ npm link 组件库 后&#xff0c;使用内部组件确报错&#xff1a; react.dev…

“安全即服务”为网络安全推开一道门

8月30日&#xff0c;三六零&#xff08;下称“360”&#xff09;集团发布了2023年半年报&#xff0c;其中安全业务第二季度收入6.54亿元&#xff0c;同比增长98.76%&#xff0c;环比增长157.16%&#xff0c;安全第二增长曲线已完全成型&#xff01;特别值得一提的是&#xff0c…

高速路自动驾驶功能HWP功能定义

一、功能定义 高速路自动驾驶功能HWP是指在一般畅通高速公路或城市快速路上驾驶员可以放开双手双脚&#xff0c;同时注意力可在较长时间内从驾驶环境中转移&#xff0c;做一些诸如看手机、接电话、看风景等活动&#xff0c;该系统最低工作速度为60kph。 如上两种不同环境和速度…

Vue+NodeJS+MongoDB实现邮箱验证注册、登录

一.主要内容 邮件发送用户注册用户信息存储到数据库用户登录密码加密JWT生成tokenCookie实现快速登录 在用户注册时,先发送邮件得到验证码.后端将验证进行缓存比对,如果验证码到期,比对不正确,拒绝登录;如果比对正确,将用户的信息进行加密存储到数据库. 用户登录时,先通过用…

LRTimelapse 6 for Mac(延时摄影视频制作软件)

LRTimelapse 是一款适用于macOS 系统的延时摄影视频制作软件&#xff0c;可以帮助用户创建高质量的延时摄影视频。该软件提供了直观的界面和丰富的功能&#xff0c;支持多种时间轴摄影工具和文件格式&#xff0c;并具有高度的可定制性和扩展性。 LRTimelapse 的主要特点如下&am…

Leetcode刷题笔记--Hot41-50

1--二叉树的层序遍历&#xff08;102&#xff09; 主要思路&#xff1a; 经典广度优先搜索&#xff0c;基于队列&#xff1b; 对于本题需要将同一层的节点放在一个数组中&#xff0c;因此遍历的时候需要用一个变量 nums 来记录当前层的节点数&#xff0c;即 nums 等于队列元素的…

全网独家:编译CentOS6.10系统的openssl-1.1.1多版本并存的rpm安装包

CentOS6.10系统原生的openssl版本太老&#xff0c;1.0.1e&#xff0c;不能满足一些新版本应用软件的要求&#xff0c;但是它又被wget、mysql-libs、python-2.6.6、yum等一众系统包所依赖&#xff0c;不能再做升级。故需考虑在不影响系统原生openssl的情况下&#xff0c;安装较新…

HarmonyOS/OpenHarmony(Stage模型)应用开发单一手势(三)

五、旋转手势&#xff08;RotationGesture&#xff09; RotationGesture(value?:{fingers?:number; angle?:number}) 旋转手势用于触发旋转手势事件&#xff0c;触发旋转手势的最少手指数量为2指&#xff0c;最大为5指&#xff0c;最小改变度数为1度&#xff0c;拥有两个可…

mac安装adobe需要注意的tips(含win+mac all安装包)

M2芯片只能安装2022年以后的&#xff08;包含2022年的&#xff09; 1、必须操作的开启“任何来源” “任何来源“设置&#xff0c;这是为了系统安全性&#xff0c;苹果希望所有的软件都从商店或是能验证的官方下载&#xff0c;导致默认不允许从第三方下载应用程序。macOS sie…