Pytorch从零开始实战16

Pytorch从零开始实战——ResNeXt-50算法的思考

本系列来源于365天深度学习训练营

原作者K同学

对于上次ResNeXt-50算法,我们同样有基于TensorFlow的实现。具体代码如下。

引入头文件

import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense, Dropout, Conv2D, MaxPool2D, Flatten, GlobalAvgPool2D, concatenate, \
BatchNormalization, Activation, Add, ZeroPadding2D, Lambda
from tensorflow.keras.layers import ReLU
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.models import Model

分组卷积模块

# 定义分组卷积
def grouped_convolution_block(init_x, strides, groups, g_channels):group_list = []# 分组进行卷积for c in range(groups):# 分组取出数据x = Lambda(lambda x: x[:, :, :, c * g_channels:(c + 1) * g_channels])(init_x)# 分组进行卷积x = Conv2D(filters=g_channels, kernel_size=(3, 3),strides=strides, padding='same', use_bias=False)(x)# 存入listgroup_list.append(x)# 合并list中的数据group_merage = concatenate(group_list, axis=3)x = BatchNormalization(epsilon=1.001e-5)(group_merage)x = ReLU()(x)return x

残差单元

# 定义残差单元
def block(x, filters, strides=1, groups=32, conv_shortcut=True):if conv_shortcut:shortcut = Conv2D(filters * 2, kernel_size=(1, 1), strides=strides, padding='same', use_bias=False)(x)# epsilon为BN公式中防止分母为零的值shortcut = BatchNormalization(epsilon=1.001e-5)(shortcut)else:# identity_shortcutshortcut = x# 三层卷积层x = Conv2D(filters=filters, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = ReLU()(x)# 计算每组的通道数g_channels = int(filters / groups)# 进行分组卷积x = grouped_convolution_block(x, strides, groups, g_channels)x = Conv2D(filters=filters * 2, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = Add()([x, shortcut])x = ReLU()(x)return x

堆叠残差单元

# 堆叠残差单元
def stack(x, filters, blocks, strides, groups=32):# 每个stack的第一个block的残差连接都需要使用1*1卷积升维x = block(x, filters, strides=strides, groups=groups)for i in range(blocks):x = block(x, filters, groups=groups, conv_shortcut=False)return x

网络搭建

# 定义ResNext50(32*4d)网络
def ResNext50(input_shape, num_classes):inputs = Input(shape=input_shape)# 填充3圈0,[224,224,3]->[230,230,3]x = ZeroPadding2D((3, 3))(inputs)x = Conv2D(filters=64, kernel_size=(7, 7), strides=2, padding='valid')(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = ReLU()(x)# 填充1圈0x = ZeroPadding2D((1, 1))(x)x = MaxPool2D(pool_size=(3, 3), strides=2, padding='valid')(x)# 堆叠残差结构x = stack(x, filters=128, blocks=2, strides=1)x = stack(x, filters=256, blocks=3, strides=2)x = stack(x, filters=512, blocks=5, strides=2)x = stack(x, filters=1024, blocks=2, strides=2)# 根据特征图大小进行全局平均池化x = GlobalAvgPool2D()(x)x = Dense(num_classes, activation='softmax')(x)# 定义模型model = Model(inputs=inputs, outputs=x)return model

对于残差单元中的代码,提出一个问题:当conv_shortcut=False的时候,在执行Add操作时,理论上通道数不一致,为什么代码不报错?
在这里插入图片描述
答:这主要是跟下面堆叠残差单元的代码有关系,每个stack第一轮总会令conv_shortcut为True,使得x通道数进行扩展,而后面循环的时候传入的filters还是这个函数的实参,没有发生变化,但由于conv_shortcut为False,此时shortcut的通道数是与上面的x一致,所以在Add的时候,代码不会报错。

def stack(x, filters, blocks, strides, groups=32):# 每个stack的第一个block的残差连接都需要使用1*1卷积升维x = block(x, filters, strides=strides, groups=groups)for i in range(blocks):x = block(x, filters, groups=groups, conv_shortcut=False)return x

本文只是对ResNeXt-50算法的部分代码进行思考,学习过程中需要积极思考与探索,以提高能力和解决问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/611420.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TensorRt(5)动态尺寸输入的分割模型测试

文章目录 1、固定输入尺寸逻辑2、动态输入尺寸2.1、模型导出2.2、推理测试2.3、显存分配问题2.4、完整代码 这里主要说明使用TensorRT进行加载编译优化后的模型engine进行推理测试,与前面进行目标识别、目标分类的模型的网络输入是固定大小不同,导致输入…

【现代密码学】笔记3.4-3.7--构造安全加密方案、CPA安全、CCA安全 《introduction to modern cryphtography》

【现代密码学】笔记3.4-3.7--构造安全加密方案、CPA安全、CCA安全 《introduction to modern cryphtography》 写在最前面私钥加密与伪随机性 第二部分流加密与CPA多重加密 CPA安全加密方案CPA安全实验、预言机访问(oracle access) 操作模式伪随机函数PR…

Java微服务系列之 ShardingSphere - ShardingSphere-JDBC

🌹作者主页:青花锁 🌹简介:Java领域优质创作者🏆、Java微服务架构公号作者😄 🌹简历模板、学习资料、面试题库、技术互助 🌹文末获取联系方式 📝 系列专栏目录 [Java项…

报错解决:No module named ‘pytorch_lightning‘ 安装pytorch_lightning

报错记录 执行如下代码: import pytorch_lightning报错: No module named ‘pytorch_lightning’ 解决方式 安装pytorch_lightning包即可。 一般情况下,缺失的包通过pip安装,即: pip install pytorch_lightning然…

1 快速前端开发

1 前端开发 目的:开发一个平台(网站)- 前端开发:HTML、CSS、JavaScript- Web框架:接收请求并处理- MySQL数据库:存储数据地方快速上手:基于Flask Web框架让你快速搭建一个网站出来。1.快速开发…

HarmonyOS应用开发学习笔记 应用上下文Context 获取文件夹路径

1、 HarmoryOS Ability页面的生命周期 2、 Component自定义组件 3、HarmonyOS 应用开发学习笔记 ets组件生命周期 4、HarmonyOS 应用开发学习笔记 ets组件样式定义 Styles装饰器:定义组件重用样式 Extend装饰器:定义扩展组件样式 5、HarmonyOS 应用开发…

14-股票K线图功能-个股日K线SQL分析__ev

需求:统计个股日K线数据,也就是把某只股票每天的最高价,开盘价,收盘价,最低价形成K线图。

山西电力市场日前价格预测【2024-01-11】

日前价格预测 预测说明: 如上图所示,预测明日(2024-01-11)山西电力市场全天平均日前电价为231.43元/MWh。其中,最高日前电价为422.21元/MWh,预计出现在18:00。最低日前电价为0.00元/MWh,预计出…

现代软件测试中的自动化测试工具

自动化测试的重要性和优势 引言:随着软件开发的不断发展,自动化测试工具在现代软件测试中扮演着重要角色。提高效率:自动化测试可以加快测试流程,减少人工测试所需的时间和资源。提升准确性:自动化测试工具可以减少人…

PACS医学影像报告管理系统源码带CT三维后处理技术

PACS从各种医学影像检查设备中获取、存储、处理影像数据,传输到体检信息系统中,生成图文并茂的体检报告,满足体检中心高水准、高效率影像处理的需要。 自主知识产权:拥有完整知识产权,能够同其他模块无缝对接 国际标准…

Linux CentOS 7.6安装JDK详细保姆级教程

一、检查系统是否自带jdk java --version 如果有的话,找到对应的文件删除 第一步:先查看Linux自带的JDK有几个,用命令: rpm -qa | grep -i java第二步:删除JDK,执行命令: rpm -qa | grep -i java | xarg…

企业的 Android 移动设备管理 (MDM) 解决方案

移动设备管理可帮助您在不影响最终用户体验的情况下,通过无线方式管理和保护组织的移动设备群,现代 MDM 解决方案还可以控制 App、内容和安全性,因此员工可以毫无顾虑地在托管设备上工作。移动设备管理软件可有效管理个人设备上的公司空间。M…

优化CentOS 7.6的HTTP隧道代理网络性能

在CentOS 7.6上,通过HTTP隧道代理优化网络性能是一项复杂且细致的任务。首先,我们要了解HTTP隧道代理的工作原理:通过建立一个安全的隧道,HTTP隧道代理允许用户绕过某些网络限制,提高数据传输的速度和安全性。然而&…

工业交换机在智慧水务和水处理中的应用

智慧水务是一种基于互联网和物联网技术的水务管理模式。它利用现代信息技术,将传统的水务管理模式升级,实现智慧化的水务管理方式。智慧水务的实现离不开各种先进的技术手段。物联网技术是智慧水务的重要组成部分。通过在水务系统中部署工业交换机、传感…

C/C++调用matlab

C/C调用matlab matlab虽然可以生成C/C的程序,但其能力很有限,很多操作无法生成C/C程序,比如函数求解、优化、拟合等。为了解决这个问题,可以采用matlab和C/C联合编程的方式进行。使用matlab将关键操作打包成dll环境,再…

MySQL 存储引擎全攻略:选择最适合你的数据库引擎

1. MySQL的支持的存储引擎有哪些 官方文档给出的有以下几种: 我们也可以通过SHOW ENGINES命令来查看: 还可以通过ENGINES表查看 2. 存储引擎比较 我们通过存储引擎表来看各自的优点: InnoDB 默认的存储引擎(SUPPORT字段为D…

广东做“人工心脏”可以报销啦

(人民日报健康客户端记者 杨林宋)1月5日,据南方医科大学珠江医院消息,医院为一位57岁患者处于心衰终末期的患者,植入一款国产“人工心脏”——左心室辅助装置。据了解,这是该款“人工心脏”纳入广东省医保准…

py的循环语句(for和while)

前言:本章节和友友们探讨一下py的循环语句,主播觉得稍微有点难主要是太浑了,但是会尽量描述清楚,OK上车!(本章节有节目效果) 目录 一.while循环的基本使用 1.1关于while循环 1.2举例 1.31-1…

[C#]使用winform部署PP-MattingV2人像分割onnx模型

【官方框架地址】 https://github.com/PaddlePaddle/PaddleSeg 【算法介绍】 PP-MattingV2是一种先进的图像和视频抠图算法,由百度公司基于PaddlePaddle深度学习框架开发。它旨在提供更精准和高效的图像分割功能,特别是在处理图像中的细微部分&#xf…

【Copilot使用】

Copilot是什么 copilot有多火,1月4日,科技巨头微软在官网上宣布将为Windows 11 PC推出Copilot键。 Copilot是微软在Windows 11中加入的AI助手,该AI助手是一个集成了在操作系统中的侧边栏工具,可以帮助用户完成各种任务。 Copilo…