深度学习基础之参数量(3)

一般的CNN网络的参数量估计代码

class ResidualBlock(nn.Module):def __init__(self, in_planes, planes, norm_fn='group', stride=1):super(ResidualBlock, self).__init__()print(in_planes, planes, norm_fn, stride)self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)self.relu = nn.ReLU(inplace=True)num_groups = planes // 8if norm_fn == 'group':self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)if not stride == 1:self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)elif norm_fn == 'batch':self.norm1 = nn.BatchNorm2d(planes)self.norm2 = nn.BatchNorm2d(planes)if not stride == 1:self.norm3 = nn.BatchNorm2d(planes)elif norm_fn == 'instance':self.norm1 = nn.InstanceNorm2d(planes)self.norm2 = nn.InstanceNorm2d(planes)if not stride == 1:self.norm3 = nn.InstanceNorm2d(planes)elif norm_fn == 'none':self.norm1 = nn.Sequential()self.norm2 = nn.Sequential()if not stride == 1:self.norm3 = nn.Sequential()if stride == 1:self.downsample = Noneelse:self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)def forward(self, x):print(x.shape)#exit()y = xy = self.relu(self.norm1(self.conv1(y)))y = self.relu(self.norm2(self.conv2(y)))if self.downsample is not None:x = self.downsample(x)return self.relu(x + y)R=ResidualBlock(384, 384, norm_fn='instance', stride=1)
summary(R.to("cuda" if torch.cuda.is_available() else "cpu"), (384, 32, 32))

transformer结构的参数量的估计结果

import torch
import torch.nn as nn
from thop import profile
from torchsummary import summary# 定义一个简单的Transformer模型
class Transformer(nn.Module):def __init__(self, input_dim, hidden_dim, num_heads, num_layers):super(Transformer, self).__init__()self.embedding = nn.Embedding(input_dim, hidden_dim)self.transformer_layers = nn.Transformer(d_model=hidden_dim,nhead=num_heads,num_encoder_layers=num_layers,num_decoder_layers=num_layers)self.fc = nn.Linear(hidden_dim, input_dim)def forward(self, src, tgt):src = self.embedding(src)tgt = self.embedding(tgt)output = self.transformer_layers(src, tgt)output = self.fc(output)return output# 创建Transformer模型实例
model2 = Transformer(input_dim=512, hidden_dim=512, num_heads=8, num_layers=6)# 使用thop进行FLOPS估算
flops, params = profile(model2, inputs=(torch.randint(0, 512, (128,)), torch.randint(0, 512, (64,))))
print(f"FLOPS: {flops / 1e9} G FLOPS")  # 打印FLOPS,以十亿FLOPS(GFLOPS)为单位# 计算参数量并打印
num_params = sum(p.numel() for p in model2.parameters() if p.requires_grad)
print(f"Total number of trainable parameters: {num_params}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/96942.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

后端解决跨域(极速版)

header(Access-Control-Allow-Origin: *); header(Access-Control-Allow-Methods:*); 代表接收全部的请求,"POST,GET"//允许访问的方式 指定域,如http://172.20.0.206//宝塔的域名,注意不是:http://wang.jingyi.icu等…

网络和系统操作命令

目录 ping:用于检测网络是否通畅,以及网络时延情况。ipconfig:查看计算机的IP参数配置信息,如IP地址、默认网关、子网掩码等信息。netstat:显示协议统计信息和当前TCP/IP网络连接。tasklist:显示当前运行的…

正点原子嵌入式linux驱动开发——U-boot图形化配置及其原理

经过之前对uboot的学习可以知道:uboot可以通过stm32mp15_trusted_defconfig来配置,或者通过文件stm32mp1.h来配置uboot。还有另外一种配置uboot的方法,就是图形化配置,以前的uboot是不支持图形化配置,只有Linux内核才支…

JMeter工具的介绍,安装

一、本文学习目标 1、能知道JMeter的优缺点 2、能掌握JMeter的安装流程 3、能掌握JMeter线程组的设置 4、能掌握JMeter参数化的使用 5、能掌握JMeter直连数据库操作 6、能掌握JMeter的断言. 二、JMeter简介 (1)Jmeter详细介绍 **JMeter(A…

C++递归函数

在本文中,您将学习创建递归函数。调用自身的函数。 调用自身的函数称为递归函数。并且,这种技术称为递归。 递归在C 中如何工作? void recurse() {... .. ...recurse();... .. ... }int main() {... .. ...recurse();... .. ... } 下图显…

基于SSM+Vue的物流管理系统的设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:VueHTML 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目:是 …

[每日算法 - 阿里机试] leetcode19. 删除链表的倒数第 N 个结点

入口 力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台备战技术面试?力扣提供海量技术面试资源,帮助你高效提升编程技能,轻松拿下世界 IT 名企 Dream Offer。https://leetcode.cn/problems/remove-nth-node-from-end…

(面试)谈谈我对C++面向对象特性的理解

💯 博客内容:C读取一行内个数不定的整数的方式 😀 作  者:陈大大陈 🚀 个人简介:一个正在努力学技术的准前端,专注基础和实战分享 ,欢迎私信! 💖 欢迎大家&…

yolov5 web端部署进行图片和视频检测

目录 1、思路 2、代码结构 3、代码运行 4、api接口代码 5、web ui界面 6、参考资料 7、代码分享 1、思路 通过搭建flask微型服务器后端,以后通过vue搭建网页前端。flask是第一个第三方库。与其他模块一样,安装时可以直接使用python的pip命令实现…

字符串常量池位于JVM哪里

Java6 和6之前,常量池是存放在方法区(永久代)中的。Java7,将常量池是存放到了堆中。Java8 之后,取消了整个永久代区域,取而代之的是元空间。运行时常量池和静态常量池存放在元空间中,而字符串常…

c语言:通讯录管理系统(增删查改)

前言:在大多数高校内,都是通过设计一个通讯录管理系统来作为c语言课程设计,通过一个具体的系统设计将我们学习过的结构体和函数等知识糅合起来,可以很好的锻炼学生的编程思维,本文旨在为通讯录管理系统的设计提供思路和…

雷达散射截面(RCS)相关概念

一、雷达散射截面(RCS) RCS被指定为直径为1.128 m的完美导电球体的倍数。该球体的可见表面为1 m,但仅具有较小的反向散射有效面积。因此,更好的反射表面可以具有比其几何尺寸大得多的RCS。 雷达截面积 二、简单目标的RCS 简单目标的RCS如下表所示: 三、瑞利、米氏和光学…

基于SSM的家庭财务管理系统设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用JSP技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

SpringBoot banner 样式 自动生成

目录 SpringBoot banner 样式 自动生成 图案网站: 1.第一步创建banner.txt文件 2.访问网站Ascii艺术字实现个性化Spring Boot启动banner图案,轻松修改更换banner.txt文件内容,收集了丰富的banner艺术字和图,并且支持中文banner下…

echarts

1 type值汇总 不同的type的值对应的图表类型如下: type: ‘bar’:柱状/条形图 type: ‘line’:折线/面积图 type: ‘pie’:饼图 type: ‘scatter’:散点(气泡)图 type: ‘effectScatter’&…

ansible - Role

1、简介: Ansible 中的角色(Role)是一种组织和封装Playbook的方法,用于管理和组织 Ansible代码。它可以将任务和配置逻辑模块化,以便在不同的Playbook中共享和重用。 2、通过 role 远程部署并配置 nginx (1) 准备目…

数组(数据结构)

优质博文:IT-BLOG-CN 一、简介 数组Array是一种线性表数据结构,它用一组连续的内存空间,存储一组具有相同类型的数据。 数组因具有连续的内存空间的特点,数据拥有非常高效率的“随机访问”,时间复杂度为O(1)。但因要保…

ubuntu使用whisper和funASR-语者分离-二值化

文章目录 一、选择系统1.1 更新环境 二、安装使用whisper2.1 创建环境2.1 安装2.1.1安装基础包2.1.2安装依赖 3测试13测试2 语着分离创建代码报错ModuleNotFoundError: No module named pyannote报错No module named pyannote_whisper 三、安装使用funASR1 安装1.1 安装 Conda&…

黑豹程序员-架构师学习路线图-百科:Database数据库

文章目录 1、什么是Database2、发展历史3、数据库排行网4、总结 1、什么是Database 当今世界是一个充满着数据的互联网世界,各处都充斥着大量的数据。即这个互联网世界就是数据世界。 支撑这个数据世界的基石就是数据库,数据库也可以称为数据的仓库。 …

typescript开发环境搭建

typescript是基于javascript的强类型标记性语言,使用typescript语言可开发出不同规模的、易于扩展的web前端页面应用,本文主要描述typescript的开发环境搭建。 npm install -g typescript 如上所示,在本地开发环境中,使用nodejs…