FLOPs、FLOPS、Params的含义及PyTorch中的计算方法

FLOPs、FLOPS、Params的含义及PyTorch中的计算方法

含义解释

  1. FLOPS:注意全大写,是floating point operations per second的缩写(这里的大S表示second秒),表示每秒浮点运算次数,理解为计算速度。是一个衡量硬件性能的指标。

  2. FLOPs:注意s小写,是floating point operations的缩写(这里的小s则表示复数),表示浮点运算数,理解为计算量。可以用来衡量算法/模型的复杂度。

  3. Params:没有固定的名称,大小写均可,表示模型的参数量,也是用来衡量算法/模型的复杂度。通常我们在论文中见到的是这样:# Params,那个井号是表示 number of 的意思,因此 # Params 的意思就是:参数的数量。

在这里插入图片描述

FLOPs与模型时间复杂度、GPU利用率有关,Params与模型空间复杂度、显存占用有关。即我们常见的nvidia-smi命令中的GPU利用率(红框)和显存占用(篮框)。

MAC

MAC:Multiply Accumulate,乘加运算。乘积累加运算(英语:Multiply Accumulate, MAC)是在数字信号处理器或一些微处理器中的特殊运算。实现此运算操作的硬件电路单元,被称为“乘数累加器”。这种运算的操作,是将乘法的乘积结果和累加器的值相加,再存入累加器:
a←a+b×ca\leftarrow a+b\times c aa+b×c
使用MAC可以将原本需要的两个指令操作减少到一个指令操作,从而提高运算效率。

FLOPs的计算

以下不考虑激活函数的计算量。

卷积层

(2×Ci×K2−1)×H×W×C0(2\times C_i\times K^2-1)\times H\times W\times C_0(2×Ci×K21)×H×W×C0

CiC_iCi=输入通道数, KKK=卷积核尺寸,H,WH,WH,W=输出特征图空间尺寸,CoC_oCo=输出通道数。

一个MAC算两个个浮点运算,所以在最前面×2\times 2×2。不考虑bias时有−1-11,有bias时没有−1-11。由于考虑的一般是模型推理时的计算量,所以上述公式是针对一个输入样本的情况,即batch size=1。

理解上面这个公式分两步,括号内是第一步,计算出输出特征图的一个pixel的计算量,然后再乘以 H×W×CoH\times W\times C_oH×W×Co 拓展到整个输出特征图。

括号内的部分又可以分为两步,(2⋅Ci⋅K2−1)=(Ci⋅K2)+(Ci⋅K2−1)(2\cdot C_i\cdot K^2-1)=(C_i\cdot K^2)+(C_i\cdot K^2-1)(2CiK21)=(CiK2)+(CiK21)。第一项是乘法运算数,第二项是加法运算数,因为 nnn 个数相加,要加 n−1n-1n1 次,所以不考虑bias,会有一个−1-11,如果考虑bias,刚好中和掉,括号内变为 2⋅Ci⋅K22\cdot C_i\cdot K^22CiK2

全连接层

全连接层: (2×I−1)×O(2\times I-1)\times O(2×I1)×O

III=输入层神经元个数 ,OOO=输出层神经元个数。

还是因为一个MAC算两个个浮点运算,所以在最前面×2\times 2×2。同样不考虑bias时有−1-11,有bias时没有−1-11。分析同理,括号内是一个输出神经元的计算量,拓展到OOO了输出神经元。

NVIDIA Paper [2017-ICLR]

笔者在这里放上 NVIDIA 在 【2017-ICLR】的论文:PRUNING CONVOLUTIONAL NEURAL NETWORKS FOR RESOURCE EFFICIENT INFERENCE 的附录部分FLOPs计算方法截图放在下面供读者参考。
在这里插入图片描述

使用PyTorch直接输出模型的Params(参数量)

完整统计参数量

import torch 
from torchvision.models import resnet50
import numpy as npTotal_params = 0
Trainable_params = 0
NonTrainable_params = 0model = resnet50()
for param in model.parameters():mulValue = np.prod(param.size())  # 使用numpy prod接口计算参数数组所有元素之积Total_params += mulValue  # 总参数量if param.requires_grad:Trainable_params += mulValue  # 可训练参数量else:NonTrainable_params += mulValue  # 非可训练参数量print(f'Total params: {Total_params / 1e6}M')
print(f'Trainable params: {Trainable_params/ 1e6}M')
print(f'Non-trainable params: {NonTrainable_params/ 1e6}M')

输出:

Total params: 25.557032M
Trainable params: 25.557032M
Non-trainable params: 0.0M

简单统计可训练的参数量

通常,我们想知道的只是可训练的参数量,我们也可以简单地直接一行统计出可训练的参数量:

import torchvision.models as modelsmodel = models.resnet50(pretrained=False)Trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Trainable params: {Trainable_params/ 1e6}M')

输出:

Trainable params: 25.557032M

统计每一层的参数量

倘若想要统计每一层的参数量,参考代码如下:

model = vgg16()
for name, parameters in model.named_parameters():print(name, ':', np.prod(parameters.size()))

会打印出每一层的名称及参数量:

features.0.weight : 1728
features.0.bias : 64
features.2.weight : 36864
features.2.bias : 64
features.5.weight : 73728
...

使用thop库来获取模型的FLOPs(计算量)和Params(参数量)

安装

直接pypi安装即可

pip install thop

使用

我们使用thop库来计算vgg16模型的计算量和参数量。

import torch
from thop import profile
from archs.ViT_model import get_vit, ViT_Aes
from torchvision.models import resnet50model = resnet50()
input1 = torch.randn(4, 3, 224, 224) 
flops, params = profile(model, inputs=(input1, ))
print('FLOPs = ' + str(flops/1000**3) + 'G')
print('Params = ' + str(params/1000**2) + 'M')

输出:

FLOPs = 16.446058496G
Params = 25.557032M

Ref:

https://openreview.net/forum?id=SJGCiw5gl

https://www.zhihu.com/question/65305385/answer/451060549

https://www.cnblogs.com/chuqianyu/p/14254702.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532797.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

科普 | 单精度、双精度、多精度和混合精度计算的区别是什么?

科普 | 单精度、双精度、多精度和混合精度计算的区别是什么? 转自:https://zhuanlan.zhihu.com/p/93812784 我们提到圆周率 π 的时候,它有很多种表达方式,既可以用数学常数3.14159表示,也可以用一长串1和0的二进制长串表示。 …

linux设备驱动之串口移植,Linux设备驱动之UART驱动结构

一、对于串口驱动Linux系统中UART驱动属于终端设备驱动,应该说是实现串口驱动和终端驱动来实现串口终端设备的驱动。要了解串口终端的驱动在Linux系统的结构就先要了解终端设备驱动在Linux系统中的结构体系,一方面自己了解的不够,另一发面关于…

NVIDIA英伟达的Multi-GPU多卡通信框架NCCL

NVIDIA英伟达的Multi-GPU多卡通信框架NCCL 笔者注:NCCL 开源项目地址:https://github.com/NVIDIA/nccl 转自:https://www.zhihu.com/question/63219175/answer/206697974 NCCL是Nvidia Collective multi-GPU Communication Library的简称&…

C语言n个坐标点间的最大距离,c语言已知两点坐标,求另一点到穿过这两点的直线最短距离。...

c语言已知两点坐标,求另一点到穿过这两点的直线最短距离。以下文字资料是由(历史新知网www.lishixinzhi.com)小编为大家搜集整理后发布的内容,让我们赶快一起来看一下吧!c语言已知两点坐标,求另一点到穿过这两点的直线最短距离。#…

[分布式训练] 单机多卡的正确打开方式:理论基础

[分布式训练] 单机多卡的正确打开方式:理论基础 转自:https://fyubang.com/2019/07/08/distributed-training/ 瓦砾由于最近bert-large用的比较多,踩了很多分布式训练的坑,加上在TensorFlow和PyTorch之间更换,算是熟…

s3c2416开发板 linux,S3C2416移植内核Linux3.1的wm9713声卡过程

移植内核的声卡驱动。原因没有声卡驱动,WM9713声卡驱动移植(原来的内核有UDA1341声卡驱动,我们再次基础上直接修改)1、直接复制内核得到三个文件:s3c2416_wm9713.c , wm9713.c , s3c2416_ac97.c.linux-3.1\sound\soc\codecs\Wm9713.c---->wm9713.c;li…

c语言六位抢答器课程设计,51单片机八路抢答器课程设计

;说明:本人的这个设计改进后解决了前一个版本中1号抢答优先的问题,并增加了锦囊的设置,当参赛选手在回答问题时要求使用锦囊,则主持人按下抢答开始键,计时重新开始。;八路抢答器电路请看下图是用ps仿真的,已…

ELF文件详解—初步认识

ELF文件详解—初步认识 转自:https://blog.csdn.net/daide2012/article/details/73065204 一、 引言 在讲解ELF文件格式之前,我们来回顾一下,一个用C语言编写的高级语言程序是从编写到打包、再到编译执行的基本过程,我们知道在C…

linux下ora 01110,ORA-01003ORA-01110

Oracle 9i数据库登录时,提示ORA-01003&ORA-01110,大概意思是数据文件存储介质损坏。startup nomount,正常;alter database mount,也正常;alter database open,提示如下:alter database open*ERROR 位于第 1 行:ORA…

x11转发:通过ssh远程使用GUI程序

x11转发:通过ssh远程使用GUI程序 我们常常使用ssh服务远程操控服务器,大多数操作我们都可以通过命令行命令来实现。 ssh远程无法查看GUI程序 现在,笔者在x11-test目录下放入一张图片test.jpg,并通过opnencv-python写一个简单的…

操作系统引导详细过程

操作系统引导详细过程 转自:https://blog.csdn.net/lijie45655/article/details/89366372 就直观而言,我们所见到计算机启动的过程是:按下电脑开机键,系统在黑色的屏幕下打印出一些英文语句、然后进入进度条状态,最后…

android 自定义透明 等待 dialog,Android自定义Dialog内部透明、外部遮罩效果

Android自定义Dialog内部透明、外部遮罩效果发布时间:2020-09-09 03:01:41来源:脚本之家阅读:117作者:zst1303939801本文实例为大家分享了Android自定义Dialog遮罩效果的具体代码,供大家参考,具体内容如下图…

对比损失的PyTorch实现详解

对比损失的PyTorch实现详解 本文以SiT代码中对比损失的实现为例作介绍。 论文:https://arxiv.org/abs/2104.03602 代码:https://github.com/Sara-Ahmed/SiT 对比损失简介 作为一种经典的自监督损失,对比损失就是对一张原图像做不同的图像…

android 融云浏览大图,融云 Android sdk kit 头像昵称更新机制

先申明笔者的实现方式不是唯一 也不一定是最优化的方案 如果您看到此篇博文 有不同看法 或者 更好的优化 更高的效率 欢迎在评论发表意见 融云官网点我融云头像机制相关视频详解首先跟大家说一下 kit 跟 lib 的头像机制 kit 是已经包含融云已经给开发者定制好的界面 诸如 会话界…

Linux中的awk、sed、grep及正则表达式详解

Linux中的awk、sed、grep及正则表达式详解 简介 awk、sed和grep是Linux中文本操作的三大利器。 其中awk适用于取列,sed适用于取行,grep适用于过滤。 正则表达式 首先我们来介绍一下正则表达式,正则表达式(regular expression)描述了一种…

android聚焦时如何给控件加边框,edittext设置获得焦点时的边框颜色

第一步:为了更好的比较,准备两个一模一样的EditText(当Activity启动时,焦点会在第一个EditText上,如果你不希望这样只需要写一个高度和宽带为0的EditText即可避免,这里就不这么做了),代码如下:a…

xargs 命令教程

xargs 命令教程 转自:http://www.ruanyifeng.com/blog/2019/08/xargs-tutorial.html 作者: 阮一峰 日期: 2019年8月 8日 xargs是 Unix 系统的一个很有用的命令,但是常常被忽视,很多人不了解它的用法。 本文介绍如…

android strictmode有什么作用,Android 性能优化 之 StrictMode

8种机械键盘轴体对比本人程序员,要买一个写代码的键盘,请问红轴和茶轴怎么选?StrictMode概述StrictMode 是用来检测程序中违例情况的开发者工具。使用StrictMode,系统检测出主线程违例的情况会做出相应的反应,如日志打…

curl 的用法指南

curl 的用法指南 转自:http://www.ruanyifeng.com/blog/2019/09/curl-reference.html 作者: 阮一峰 日期: 2019年9月 5日 简介 curl 是常用的命令行工具,用来请求 Web 服务器。它的名字就是客户端(client&#xf…

怎么在html显示已登录状态,jQuery Ajax 实现在html页面实时显示用户登录状态

当网站是全静态的html页面时,而又希望网站会员在登录之后并在所有页面头部显示登录状态,如用户名等,如果未登录就是未登录状态,下面给大家来分享实现的方法。一、在html静态页面中加入div,并指定ID如:二、新…