【论文笔记】Run, Don’t Walk: Chasing Higher FLOPS for Faster Neural Networks

论文地址:Run, Don't Walk: Chasing Higher FLOPS for Faster Neural Networks

代码地址:https://github.com/jierunchen/fasternet

该论文主要提出了PConv,通过优化FLOPS提出了快速推理模型FasterNet。

在设计神经网络结构的时候,大部分注意力都会放在降低FLOPs( floating-point opera-
tions)上,有的时候FLOPs降低了,并不意味了推理速度加快了,这主要是因为没考虑到FLOPS(floating-point operations per second)。针对该问题,作者提出了PConv( partial convolution),通过提高FLOPS来加快推理速度。

一、引言

      非常多的实时推理模型都将重点放在降低FLOPs上,比如:MobileNet,ShuffleNet,GhostNet等等。虽然这些网络都降低了FLOPs,但是他们没有考虑到FLOPS,所以推理速度仍有优化空间,推理的延时计算公式如下:

由上式可以看出,要想加快推理速度,不仅可以从FLOPs入手,也可以优化FLOPS。作者在多个模型上做了实验,发现很多模型的FLOPS低于ResNet50。于是作者提出了PConv,通过提高FLOPS来加快推理速度。

二、PConv

为了提高FLOPS,作者提出了PConv,其结构如下图:

部分通道数经过卷积运算,其他通道不进行运算。再看了几眼。。。。这个和GhostConv好像呀。。。。

网络整体结构如下:

三、模型性能

FasterNet在ImageNet-1K上的表现如下:

在coco数据集上的表现如下:

四、代码

给出PConv的代码,也是非常简单:

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from functools import partial
from typing import List
from torch import Tensor
import copy
import ostry:from mmdet.models.builder import BACKBONES as det_BACKBONESfrom mmdet.utils import get_root_loggerfrom mmcv.runner import _load_checkpointhas_mmdet = True
except ImportError:print("If for detection, please install mmdetection first")has_mmdet = Falseclass Partial_conv3(nn.Module):def __init__(self, dim, n_div, forward):super().__init__()self.dim_conv3 = dim // n_divself.dim_untouched = dim - self.dim_conv3self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, 3, 1, 1, bias=False)if forward == 'slicing':self.forward = self.forward_slicingelif forward == 'split_cat':self.forward = self.forward_split_catelse:raise NotImplementedErrordef forward_slicing(self, x: Tensor) -> Tensor:# only for inferencex = x.clone()   # !!! Keep the original input intact for the residual connection laterx[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])return xdef forward_split_cat(self, x: Tensor) -> Tensor:# for training/inferencex1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)x1 = self.partial_conv3(x1)x = torch.cat((x1, x2), 1)return xclass MLPBlock(nn.Module):def __init__(self,dim,n_div,mlp_ratio,drop_path,layer_scale_init_value,act_layer,norm_layer,pconv_fw_type):super().__init__()self.dim = dimself.mlp_ratio = mlp_ratioself.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()self.n_div = n_divmlp_hidden_dim = int(dim * mlp_ratio)mlp_layer: List[nn.Module] = [nn.Conv2d(dim, mlp_hidden_dim, 1, bias=False),norm_layer(mlp_hidden_dim),act_layer(),nn.Conv2d(mlp_hidden_dim, dim, 1, bias=False)]self.mlp = nn.Sequential(*mlp_layer)self.spatial_mixing = Partial_conv3(dim,n_div,pconv_fw_type)if layer_scale_init_value > 0:self.layer_scale = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)self.forward = self.forward_layer_scaleelse:self.forward = self.forwarddef forward(self, x: Tensor) -> Tensor:shortcut = xx = self.spatial_mixing(x)x = shortcut + self.drop_path(self.mlp(x))return xdef forward_layer_scale(self, x: Tensor) -> Tensor:shortcut = xx = self.spatial_mixing(x)x = shortcut + self.drop_path(self.layer_scale.unsqueeze(-1).unsqueeze(-1) * self.mlp(x))return xclass BasicStage(nn.Module):def __init__(self,dim,depth,n_div,mlp_ratio,drop_path,layer_scale_init_value,norm_layer,act_layer,pconv_fw_type):super().__init__()blocks_list = [MLPBlock(dim=dim,n_div=n_div,mlp_ratio=mlp_ratio,drop_path=drop_path[i],layer_scale_init_value=layer_scale_init_value,norm_layer=norm_layer,act_layer=act_layer,pconv_fw_type=pconv_fw_type)for i in range(depth)]self.blocks = nn.Sequential(*blocks_list)def forward(self, x: Tensor) -> Tensor:x = self.blocks(x)return xclass PatchEmbed(nn.Module):def __init__(self, patch_size, patch_stride, in_chans, embed_dim, norm_layer):super().__init__()self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_stride, bias=False)if norm_layer is not None:self.norm = norm_layer(embed_dim)else:self.norm = nn.Identity()def forward(self, x: Tensor) -> Tensor:x = self.norm(self.proj(x))return xclass PatchMerging(nn.Module):def __init__(self, patch_size2, patch_stride2, dim, norm_layer):super().__init__()self.reduction = nn.Conv2d(dim, 2 * dim, kernel_size=patch_size2, stride=patch_stride2, bias=False)if norm_layer is not None:self.norm = norm_layer(2 * dim)else:self.norm = nn.Identity()def forward(self, x: Tensor) -> Tensor:x = self.norm(self.reduction(x))return xclass FasterNet(nn.Module):def __init__(self,in_chans=3,num_classes=1000,embed_dim=96,depths=(1, 2, 8, 2),mlp_ratio=2.,n_div=4,patch_size=4,patch_stride=4,patch_size2=2,  # for subsequent layerspatch_stride2=2,patch_norm=True,feature_dim=1280,drop_path_rate=0.1,layer_scale_init_value=0,norm_layer='BN',act_layer='RELU',fork_feat=False,init_cfg=None,pretrained=None,pconv_fw_type='split_cat',**kwargs):super().__init__()if norm_layer == 'BN':norm_layer = nn.BatchNorm2delse:raise NotImplementedErrorif act_layer == 'GELU':act_layer = nn.GELUelif act_layer == 'RELU':act_layer = partial(nn.ReLU, inplace=True)else:raise NotImplementedErrorif not fork_feat:self.num_classes = num_classesself.num_stages = len(depths)self.embed_dim = embed_dimself.patch_norm = patch_normself.num_features = int(embed_dim * 2 ** (self.num_stages - 1))self.mlp_ratio = mlp_ratioself.depths = depths# split image into non-overlapping patchesself.patch_embed = PatchEmbed(patch_size=patch_size,patch_stride=patch_stride,in_chans=in_chans,embed_dim=embed_dim,norm_layer=norm_layer if self.patch_norm else None)# stochastic depth decay ruledpr = [x.item()for x in torch.linspace(0, drop_path_rate, sum(depths))]# build layersstages_list = []for i_stage in range(self.num_stages):stage = BasicStage(dim=int(embed_dim * 2 ** i_stage),n_div=n_div,depth=depths[i_stage],mlp_ratio=self.mlp_ratio,drop_path=dpr[sum(depths[:i_stage]):sum(depths[:i_stage + 1])],layer_scale_init_value=layer_scale_init_value,norm_layer=norm_layer,act_layer=act_layer,pconv_fw_type=pconv_fw_type)stages_list.append(stage)# patch merging layerif i_stage < self.num_stages - 1:stages_list.append(PatchMerging(patch_size2=patch_size2,patch_stride2=patch_stride2,dim=int(embed_dim * 2 ** i_stage),norm_layer=norm_layer))self.stages = nn.Sequential(*stages_list)self.fork_feat = fork_featif self.fork_feat:self.forward = self.forward_det# add a norm layer for each outputself.out_indices = [0, 2, 4, 6]for i_emb, i_layer in enumerate(self.out_indices):if i_emb == 0 and os.environ.get('FORK_LAST3', None):raise NotImplementedErrorelse:layer = norm_layer(int(embed_dim * 2 ** i_emb))layer_name = f'norm{i_layer}'self.add_module(layer_name, layer)else:self.forward = self.forward_cls# Classifier headself.avgpool_pre_head = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Conv2d(self.num_features, feature_dim, 1, bias=False),act_layer())self.head = nn.Linear(feature_dim, num_classes) \if num_classes > 0 else nn.Identity()self.apply(self.cls_init_weights)self.init_cfg = copy.deepcopy(init_cfg)if self.fork_feat and (self.init_cfg is not None or pretrained is not None):self.init_weights()def cls_init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, (nn.Conv1d, nn.Conv2d)):trunc_normal_(m.weight, std=.02)if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):nn.init.constant_(m.bias, 0)nn.init.constant_(m.weight, 1.0)# init for mmdetection by loading imagenet pre-trained weightsdef init_weights(self, pretrained=None):logger = get_root_logger()if self.init_cfg is None and pretrained is None:logger.warn(f'No pre-trained weights for 'f'{self.__class__.__name__}, 'f'training start from scratch')passelse:assert 'checkpoint' in self.init_cfg, f'Only support ' \f'specify `Pretrained` in ' \f'`init_cfg` in ' \f'{self.__class__.__name__} 'if self.init_cfg is not None:ckpt_path = self.init_cfg['checkpoint']elif pretrained is not None:ckpt_path = pretrainedckpt = _load_checkpoint(ckpt_path, logger=logger, map_location='cpu')if 'state_dict' in ckpt:_state_dict = ckpt['state_dict']elif 'model' in ckpt:_state_dict = ckpt['model']else:_state_dict = ckptstate_dict = _state_dictmissing_keys, unexpected_keys = \self.load_state_dict(state_dict, False)# show for debugprint('missing_keys: ', missing_keys)print('unexpected_keys: ', unexpected_keys)def forward_cls(self, x):# output only the features of last layer for image classificationx = self.patch_embed(x)x = self.stages(x)x = self.avgpool_pre_head(x)  # B C 1 1x = torch.flatten(x, 1)x = self.head(x)return xdef forward_det(self, x: Tensor) -> Tensor:# output the features of four stages for dense predictionx = self.patch_embed(x)outs = []for idx, stage in enumerate(self.stages):x = stage(x)if self.fork_feat and idx in self.out_indices:norm_layer = getattr(self, f'norm{idx}')x_out = norm_layer(x)outs.append(x_out)return outs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/579583.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

性能优化,让用户体验更加完美(渲染层面)

前言 上一篇我们已经围绕“网络层面”探索页面性能优化的方案&#xff0c;接下来本篇围绕“浏览器渲染层面”继续开展探索。正文开始前&#xff0c;我们思考如下问题&#xff1a; 浏览器渲染页面会经过哪几个关键环节&#xff1f;“渲染层面”的优化从哪几方面着手&#xff1f…

【Redis】一文掌握Redis原理及常见问题

Redis是基于内存数据库&#xff0c;操作效率高&#xff0c;提供丰富的数据结构&#xff08;Redis底层对数据结构还做了优化&#xff09;&#xff0c;可用作数据库&#xff0c;缓存&#xff0c;消息中间件等。如今广泛用于互联网大厂&#xff0c;面试必考点之一&#xff0c;本文…

.NET Conf 2023 回顾 – 庆祝社区、创新和 .NET 8 的发布

作者&#xff1a; Jon Galloway - Principal Program Manager, .NET Community Team Mehul Harry - Product Marketing Manager, .NET, Azure Marketing 排版&#xff1a;Alan Wang .NET Conf 2023 是有史以来规模最大的 .NET 会议&#xff0c;来自全球各地的演讲者进行了 100 …

设计模式-注册模式

设计模式专栏 模式介绍模式特点应用场景注册模式和单例模式的区别代码示例Java实现注册模式Python实现注册模式 注册模式在spring中的应用 模式介绍 注册模式是一种设计模式&#xff0c;也称为注册树或注册器模式。这种模式将类的实例化和创建分离开来&#xff0c;避免在应用程…

【广州华锐互动】VR科技科普展厅平台:快速、便捷地创建出属于自己的虚拟展馆

随着科技的不断进步&#xff0c;虚拟现实(VR)技术已经在许多领域取得了显著的成果。尤其是在展馆设计领域&#xff0c;VR科技科普展厅平台已经实现了许多令人瞩目的新突破。 VR科技科普展厅平台是广州华锐互动专门为企业和机构提供虚拟展馆设计和制作的在线平台。通过这个平台&…

Git基础学习_p1

文章目录 一、前言二、Git手册学习2.1 Git介绍&前置知识2.2 Git教程2.2.1 导入新项目2.2.2 做更改2.2.3 Git追踪内容而非文件2.2.4 查看项目历史2.2.5 管理分支&#x1f53a;2.2.6 用Git来协同工作2.2.7 查看历史 三、结尾 一、前言 Git相信大部分从事软件工作的人都听说过…

ASP.NET MVC的5种AuthorizationFilter

一、IAuthorizationFilter 所有的AuthorizationFilter实现了接口IAuthorizationFilter。如下面的代码片断所示&#xff0c;IAuthorizationFilter定义了一个OnAuthorization方法用于实现授权的操作。作为该方法的参数filterContext是一个表示授权上下文的AuthorizationContext对…

从计算机内存结构到iOS

一、冯.诺伊曼结构 当前计算机都是冯.诺伊曼结构&#xff08;Von Neumann architecture&#xff09;&#xff0c;是指存储器存放程序的指令以及数据&#xff0c;在程序运行时根据需要提供给CPU使用。 冯.诺伊曼瓶颈 在目前的科技水平之下&#xff0c;CPU与存储器之间的读写速…

挑战与应对:迅软科技探讨IT企业应对数据泄密危机的智慧之路

随着信息技术的快速发展&#xff0c;软件IT行业面临着前所未有的数据安全挑战。黑客攻击、病毒传播、内部泄密等安全威胁层出不穷&#xff0c;给企业的核心资产和运营带来严重威胁。同时&#xff0c;国家对于数据安全的法律法规也日益严格&#xff0c;要求企业必须采取更加有效…

https密钥认证、上传镜像实验

一、第一台主机通过https密钥对认证 1、安装docker服务 &#xff08;1&#xff09;安装环境依赖包 yum -y install yum-utils device-mapper-persistent-data lvm2 &#xff08;2&#xff09;设置阿里云镜像源 yum-config-manager --add-repo http://mirrors.aliyun.com/do…

VLAN简介

在配置交换机或者传输设备时&#xff0c;经常会提到vlan&#xff0c;这个vlan具体是啥呢&#xff1f; VLAN&#xff08;Virtual Local Area Network&#xff09;中文名为“虚拟局域网”。它是一种在物理网络上划分出逻辑网络的方法&#xff0c;将物理上的局域网在逻辑上划分为多…

设计模式——适配器模式(Adapter Pattern)

概述 适配器模式可以将一个类的接口和另一个类的接口匹配起来&#xff0c;而无须修改原来的适配者接口和抽象目标类接口。适配器模式(Adapter Pattern)&#xff1a;将一个接口转换成客户希望的另一个接口&#xff0c;使接口不兼容的那些类可以一起工作&#xff0c;其别名为包装…

分布式下有哪些好用的监控组件?

在之前的内容中&#xff0c;分析了分布式系统下的线上服务监控的常用指标&#xff0c;那么在实际开发中&#xff0c;如何收集各个监控指标呢&#xff1f;线上出现告警之后&#xff0c;又如何快速处理呢&#xff1f;本文我们就来看下这两个问题。 常用监控组件 目前分布式系统…

Node.js版本对比

目录 1. node版本与Npm版本对照表 2. node版本与node-sass版本对照表 3. node-sass与sass-loader版本对照表 1. node版本与Npm版本对照表 以往的版本 | Node.js 下面显示最新的对应内容&#xff0c;如果需要查找历史版本&#xff0c;可以进入上面的页面查询 VersionLTSDateV8np…

鸿蒙实战-库的调用(ArkTS)

整体框架搭建 主页面、本地库组件页面、社区库组件页面三个页面组成&#xff0c;主页面由Navigation作为根组件实现全局标题&#xff0c;由Tabs组件实现本地库和社区库页面的切换。 // MainPage.ets import { Outer } from ../view/OuterComponent; import { Inner } from ..…

【微服务核心】Spring Boot

Spring Boot 文章目录 Spring Boot1. 简介2. 开发步骤3. 配置文件4. 整合 Spring MVC 功能5. 整合 Druid 和 Mybatis6. 使用声明式事务7. AOP整合配置8. SpringBoot项目打包和运行 1. 简介 SpringBoot&#xff0c;开箱即用&#xff0c;设置合理的默认值&#xff0c;同时也可以…

如何让机器人具备实时、多模态的触觉感知能力?

人类能够直观地感知和理解复杂的触觉信息&#xff0c;是因为分布在指尖皮肤的皮肤感受器同时接收到不同的触觉刺激&#xff0c;并将触觉信号立即传输到大脑。尽管许多研究小组试图模仿人类皮肤的结构和功能&#xff0c;但在一个系统内实现类似人类的触觉感知过程仍然是一个挑战…

【go语言】CSP并发机制与Actor模型

一、多线程共享内存 1. 概念 多线程共享内存模型是一种并发编程模型&#xff0c;其中多个线程在同一个进程的地址空间中共享相同的内存区域。这种模型允许多个线程并发地读取和写入相同的数据结构&#xff0c;但也引入了一些潜在的问题&#xff0c;其中最常见的问题之一就是…

【WordPress插件】热门关键词推荐v1.3.0 Pro开心版

介绍&#xff1a; WordPress插件-WBOLT热门关键词推荐插件&#xff08;Smart Keywords Tool&#xff09;是一款集即时关键词推荐、关键词选词工具及文章智能标签功能于一体的WordPress网站SEO优化插件。 智能推荐&#xff1a; 热门关键词推荐引擎-支持360搜索、Bing、谷歌&a…

【已解决】c++qt如何制作翻译供程序调用

本博文源于笔者正在编写的工具需要创建翻译文件&#xff0c;恰好将qt如何进行翻译&#xff0c;从零到结果进行读者查阅&#xff0c;并非常推荐读者进行收藏点赞&#xff0c;因为步步都很清晰&#xff0c;堪称胎教式c制作&#xff0c;而且内容还包括如何部署在windows下。堪称值…