TensorFlow2实战-系列教程14:Resnet实战2

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

Resnet实战1
Resnet实战2
Resnet实战3

4、训练脚本train.py解读------创建模型

def get_model():model = resnet50.ResNet50()if config.model == "resnet34":model = resnet34.ResNet34()if config.model == "resnet101":model = resnet101.ResNet101()if config.model == "resnet152":model = resnet152.ResNet152()model.build(input_shape=(None, config.image_height, config.image_width, config.channels))model.summary()tf.keras.utils.plot_model(model, to_file='model.png')return model# create model
model = get_model()

调用get_model()函数构建模型

get_model()函数:

  1. 通过resnet50.py调用ResNet50类,构建ResNet50模型
  2. 如果在配置参数中设置的是"resnet34"、“resnet101”、“resnet152”,则会对应使用(resnet34.py调用ResNet34类,构建ResNet34模型)、(resnet101.py调用ResNet101类,构建ResNet101模型)、(resnet152.py调用ResNet152类,构建ResNet152模型)
  3. 准备模型以供训练或评估,
  4. 输出模型的概览
  5. 创建了模型的结构图,plot_model 函数从 Keras 工具包中生成模型的可视化表示,指定了保存路径

5、模型构建解析------models/resnet50.py

import tensorflow as tf
from models.residual_block import build_res_block_2
from config import NUM_CLASSESclass ResNet50(tf.keras.Model):def __init__(self, num_classes=NUM_CLASSES):super(ResNet50, self).__init__()self.pre1 = tf.keras.layers.Conv2D(filters=64, kernel_size=(7, 7), strides=2, padding='same')self.pre2 = tf.keras.layers.BatchNormalization()self.pre3 = tf.keras.layers.Activation(tf.keras.activations.relu)self.pre4 = tf.keras.layers.MaxPool2D(pool_size=(3, 3), strides=2)self.layer1 = build_res_block_2(filter_num=64, blocks=3)self.layer2 = build_res_block_2(filter_num=128, blocks=4, stride=2)self.layer3 = build_res_block_2(filter_num=256, blocks=6, stride=2)self.layer4 = build_res_block_2(filter_num=512, blocks=3, stride=2)self.avgpool = tf.keras.layers.GlobalAveragePooling2D()self.fc1 = tf.keras.layers.Dense(units=1000, activation=tf.keras.activations.relu)self.drop_out = tf.keras.layers.Dropout(rate=0.5)self.fc2 = tf.keras.layers.Dense(units=num_classes, activation=tf.keras.activations.softmax)def call(self, inputs, training=None, mask=None):pre1 = self.pre1(inputs)pre2 = self.pre2(pre1, training=training)pre3 = self.pre3(pre2)pre4 = self.pre4(pre3)l1 = self.layer1(pre4, training=training)l2 = self.layer2(l1, training=training)l3 = self.layer3(l2, training=training)l4 = self.layer4(l3, training=training)avgpool = self.avgpool(l4)fc1 = self.fc1(avgpool)drop = self.drop_out(fc1)out = self.fc2(drop)return out

class ResNet50(tf.keras.Model),这个类定义了ResNet50模型的结构,以及前向传播的方式、顺序

ResNet50类解析:

  1. 构造函数,传入了预测的类别数
  2. 初始化
  3. pre1 ,定义一个二维卷积,输出64个特征图,7x7的卷积,步长为2
  4. pre2 ,定义一个批归一化
  5. pre3,定义一个ReLU激活函数
  6. pre4,一个二维的最大池化
  7. 依次通过build_res_block_2()函数定义4个残差块
  8. 定义一个全局平均池化
  9. 定义一个全连接层,输出维度为1000
  10. 定义一个dropout
  11. 定义一个输出层的全连接层
  12. 前向传播函数,传入输入值
  13. 依次经过pre1、pre2、pre3、pre4,即卷积、批归一化、ReLU、最大池化
  14. 依次经过layer1 、layer2 、layer3 、layer4 等四个残差块
  15. 将layer4 的输出经过平局池化
  16. 依次经过两个全连接层

6、模型构建解析------models/residual_block.py

  • BottleNeck类
  • build_res_block_2()函数
  • build_res_block_2()函数通过调用BottleNeck类构建残差块
class BottleNeck(tf.keras.layers.Layer):def __init__(self, filter_num, stride=1,with_downsample=True):super(BottleNeck, self).__init__()self.with_downsample = with_downsampleself.conv1 = tf.keras.layers.Conv2D(filters=filter_num, kernel_size=(1, 1), strides=1, padding='same')self.bn1 = tf.keras.layers.BatchNormalization()self.conv2 = tf.keras.layers.Conv2D(filters=filter_num, kernel_size=(3, 3), strides=stride, padding='same')self.bn2 = tf.keras.layers.BatchNormalization()self.conv3 = tf.keras.layers.Conv2D(filters=filter_num * 4, kernel_size=(1, 1), strides=1, padding='same')self.bn3 = tf.keras.layers.BatchNormalization()self.downsample = tf.keras.Sequential()self.downsample.add(tf.keras.layers.Conv2D(filters=filter_num * 4, kernel_size=(1, 1), strides=stride))self.downsample.add(tf.keras.layers.BatchNormalization())def call(self, inputs, training=None):identity = self.downsample(inputs)conv1 = self.conv1(inputs)bn1 = self.bn1(conv1, training=training)relu1 = tf.nn.relu(bn1)conv2 = self.conv2(relu1)bn2 = self.bn2(conv2, training=training)relu2 = tf.nn.relu(bn2)conv3 = self.conv3(relu2)bn3 = self.bn3(conv3, training=training)if self.with_downsample == True:output = tf.nn.relu(tf.keras.layers.add([identity, bn3]))else:output = tf.nn.relu(tf.keras.layers.add([inputs, bn3]))return output

BottleNeck类解析:

  1. 继承tf.keras.layers.Layer
  2. 构造函数,传入 特征图个数、步长、是否下采样等参数
  3. 初始化
  4. 是否进行下采样参数
  5. 定义一个1x1,步长为1的二维卷积conv1
  6. conv1 对应的批归一化
  7. 定义一个3x3,步长为1的二维卷积conv2
  8. conv2 对应的批归一化
  9. 定义一个3x3,步长为1的二维卷积conv2
  10. conv3 对应的批归一化
  11. 定义一个下采样层(self.downsample),这个层是一个包含卷积层和批量归一化的 Sequential 模型,用于匹配输入和残差的维度
  12. call()函数为前向传播
  13. 应用下采样
  14. 应用三层卷积和批量归一化以及对应的ReLU
  15. with_downsample == True:
  16. 启用下采样,将下采样后的输入(identity)与最后一个卷积层的输出(bn3)相加
  17. 没有启用下采样,将原始输入(inputs)与最后一个卷积层的输出(bn3)相加
def build_res_block_2(filter_num, blocks, stride=1):res_block = tf.keras.Sequential()res_block.add(BottleNeck(filter_num, stride=stride))for _ in range(1, blocks):res_block.add(BottleNeck(filter_num, stride=1,with_downsample=False))    return res_block

build_res_block_2函数解析:

  1. 这个函数构建了一个包含多个BottleNeck层的残差块
  2. filter_num 是每个瓶颈层内卷积层的过滤器数量
  3. blocks 是要添加到顺序模型中的瓶颈层的数量
  4. stride 是卷积的步长,默认为 1
  5. 该函数初始化一个 Sequential 模型,并添加一个 BottleNeck 层作为第一层
  6. 然后,它迭代地添加额外的 BottleNeck 层,每个层的 stride=1 且
    with_downsample=False(除第一个之外)
  7. 此函数返回组装好的顺序模型,代表一个残差块

Resnet实战1
Resnet实战2
Resnet实战3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/666811.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

27. 云原生流量治理之kubesphere灰度发布

云原生专栏大纲 文章目录 灰度发布介绍灰度发布策略KubeSphere中恢复发布策略蓝绿部署金丝雀发布流量镜像 灰度发布实战部署自制应用金丝雀发布创建金丝雀发布任务测试金丝雀发布情况 蓝绿部署创建蓝绿部署测试蓝绿部署情况 流量镜像创建流量进行任务测试流量镜像情况 灰度发布…

【iOS ARKit】3D 人体姿态估计

与基于屏幕空间的 2D人体姿态估计不同,3D人体姿态估计是尝试还原人体在三维世界中的形状与姿态,包括深度信息。绝大多数的现有3D人体姿态估计方法依赖2D人体姿态估计,通过获取 2D人体姿态后再构建神经网络算法,实现从 2D 到 3D人体…

常见的web前端开发框架介绍

Web前端开发框架是为了简化网页设计和开发的流程而创建的工具集。它们提供了预定义的组件、工具和库,帮助开发者快速构建交互式的用户界面。以下是一些常见的Web前端开发框架,以及它们的原理、基础技术和应用场景的介绍: 1. React **…

APK签名 v1、 v2、v3、v3.1、v4 解析

在 Android 应用签名中,V1 V2 V3 V4签名是不同的签名方案,具体描述如下: V1 签名(JAR 签名):早期 Android 应用签名的基本形式,基于 Java 签名(JAR 签名)规范。它将应用…

<Linux> 进程信号

目录 一、信号概念 二、信号的作用 三、信号的特性 四、信号捕捉初识 五、信号产生 (一)通过终端按键产生信号 (二)硬件中断 (三)系统调用产生信号 1. kill 函数 2. raise 函数 3. abort 函数 …

Redis核心技术与实战【学习笔记】 - 22.浅谈Redis的ACID相关知识

概述 事务是数据库的一个重要功能。所谓的事务,就是指对数据进行读写的一系列操作。事务在执行时,会提供专门的属性保证,包括原子性(Atomicity)、一致性(Consistency)、隔离性(Isol…

2024.02.04

写出三种进程间通信的代码示例 有名管道 创建两个有名管道文件 #include <myhead.h> int main(int argc, const char *argv[]) {if(mkfifo("pipe1",0664)-1){perror("mkfifo pipe1 error");return -1;}if(mkfifo("pipe2",0664)-1){perr…

sentinel的Context创建流程分析

sentinel入门 功能 限流&#xff1a;通过限制请求速率、并发数或者用户数量来控制系统的流量&#xff0c;防止系统因为流量过大而崩溃或无响应的情况发生。 熔断&#xff1a;在系统出现故障或异常时将故障节点从系统中断开&#xff0c;从而保证系统的可用性。 降级&#xf…

性能测试工具架构

背景 性能测试工具&#xff08;LoadRunner为例&#xff09; 性能测试工具通常是指那些用来支持压力、负载测试&#xff0c;能够录制和生成脚本、设置和部署场景、产生并发用户和向系统施加持续压力的工具。 性能测试工具录制的是服务端与应用之间的通信数据&#xff0c;而不是…

【Spring】自定义注解 + AOP 记录用户的使用日志

目录 ​编辑 自定义注解 AOP 记录用户的使用日志 使用背景 落地实践 一&#xff1a;自定义注解 二&#xff1a;切面配置 三&#xff1a;Api层使用 使用效果 自定义注解 AOP 记录用户的使用日志 使用背景 &#xff08;1&#xff09;在学校项目中&#xff0c;安防平台…

【RT-DETR有效改进】 DySample一种超级轻量的动态上采样算子(上采样中的No.1)

👑欢迎大家订阅本专栏,一起学习RT-DETR👑 一、 本文介绍 本文给大家带来的改进机制是一种号称超轻量级且有效的动态上采样器——DySample。与传统的基于内核的动态上采样器相比,DySample采用了一种基于点采样的方法,相比于以前的基于内核的动态上采样器,DySample具…

整理:汉诺塔简析

大体上&#xff0c;要解决一个汉诺塔问题&#xff0c;就需要解决两个更简单的汉诺塔问题 以盘子数量 3 的汉诺塔问题为例 要将 3 个盘子从 A 移动到 C&#xff0c;就要&#xff1a; 将两个盘子从 A 移动到 B&#xff08;子问题 1&#xff09; 为了解决子问题 1&#xff0c;就…

clickhouse在MES中的应用-跟踪扫描

开发的MES&#xff0c;往往都要做生产执行跟踪扫描&#xff0c;这样会产生大量的扫描数据&#xff0c;用关系型数据库&#xff0c;很容易造成查询冲突的问题。 生产跟踪扫描就发生的密度是非常高的&#xff0c;每个零部件的加工过程&#xff0c;都要被记录下来&#xff0c;特别…

Opencv(C++)学习 TBB与OPENMP的加速效果实验与ARM上的实践(三)

接上文&#xff0c;本章尝试在RV1106上使用TBB。依然是一言难尽&#xff0c;此文依然只是记录实践过程。 源码下载&#xff0c;编译TBB 下载地址: https://github.com/oneapi-src/oneTBB 版本使用 oneTBB-2021.11.0&#xff0c;这个版本可以使用cmake编译。 cmake配置完后&am…

动手学深度学习v2-线性回归-笔记

简化核心模型 假设1: 影响房价的关键因素是卧室个数&#xff0c;卫生间个数和居住面积&#xff0c;记为 x 1 x_{1} x1​&#xff0c; x 2 x_{2} x2​&#xff0c; x 3 x_{3} x3​假设2: 成交价是关键因素的加权和 y w 1 x 1 w 2 x 2 w 3 x 3 b yw_{1}x_{1}w_{2}x_{2}w_{3…

鲲鹏--垂直生态领导者

这是ren_dong的第24篇原创 1、概述 华为鲲鹏920&#xff1a;性能最高的ARM架构服务器芯片 鲲鹏是华为在芯片领域布局的重要一环 &#xff0c;是垂直生态的领导者&#xff0c;鲲鹏芯片诞生于2020年 &#xff0c;已获得ARMv8架构的永久授权&#xff0c;主要聚焦通用计算领域华为针…

typecho 在文章中添加 bilibili 视频

一、获取视频来源&#xff1a; 可以有2种方式来定位一个 bilibili 视频&#xff1a; 第一种是使用 bvid 参数定位第二种是使用 aid 参数定位 如何获取这两个参数&#xff1f; 首先我们可以看看 bilibili 网站中的视频页面链接其实可以分为两种&#xff1a; 第一种是类似&a…

建材智能工厂数字孪生可视化管控平台,推进建材行业数字化转型

建材智能工厂数字孪生可视化管控平台&#xff0c;推进建材行业数字化转型。随着科技的不断发展&#xff0c;数字化转型已经成为各行各业发展的重要趋势。在建材行业中&#xff0c;智能工厂和数字孪生技术的应用正在改变传统的生产模式和管理方式&#xff0c;为行业的数字化转型…

C语言常见面试题:C语言中如何进行加密解密编程?

在C语言中进行加密和解密编程通常涉及到使用加密算法和相关的库。下面是一些常用的加密和解密算法以及如何使用它们的基本说明&#xff1a; AES加密算法&#xff1a; AES&#xff08;Advanced Encryption Standard&#xff09;是一种常用的对称加密算法。在C语言中&#xff0c;…

clr的执行模型-笔记

学习来源&#xff1a;《CLR via C by Jeffrey Richter 》第四版&#xff0c;第1章 clr的执行模型 1.C#编译生成执行程序集文件 编译文件的组成&#xff1a;pe32/pe32头&#xff0c;clr头&#xff0c;元数据&#xff0c;IL pe32/pe32头&#xff1a;windows标准执行文件头 cl…