TensorFlow2实战-系列教程13:Resnet实战1

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

Resnet实战1
Resnet实战2
Resnet实战3

1、残差连接

深度学习中出现了随着网络的堆叠效果下降的现象,Resnet使用残差连接的方法解决了这个问题,让深度学习从此变得深了起来。
残差连接的做法可以表示为:
y = f ( x ) + x ( s h o r t c u t ) y = f(x)+x(shortcut) y=f(x)+x(shortcut)
其中y表示网络最后的输出,x为输入,f(x)则表示输入经过几层网络的输出结果,比如三次卷积+三次批归一化+三次relu,一般情况下f(x)就是网络的最终输出,这里再加上x就是一个残差连接的操作。

残差连接的操作保证了,x经过网络后得到的y一定比f(x)更好的结果,最差是同等效果,也就是保证了不会出现效果降低的情况。

这个shortcut是什么意思呢?因为x经过几次卷积后,可能会出现多个特征图,也就是f(x)和x的通道数不一样了,这个时候就需要对x的通道数进行调整再与f(x)相加得到y,如果通道数一样就不需要调整了
在这里插入图片描述
上图就是通道数没有发生变化的情况, y = f ( x ) + x y = f(x)+x y=f(x)+x,x经过两次(卷积+批归一化+ReLU)和一次(卷积+批归一化)后得到f(x),再加上x后经过ReLU就得到了最终的y
在这里插入图片描述
上图就是通道数发生变化的情况, y = f ( x ) + C o n v 2 d ( x ) y = f(x)+Conv2d(x) y=f(x)+Conv2d(x),x经过两次(卷积+批归一化+ReLU)和一次(卷积+批归一化)后得到f(x),x再经过一次(二维卷积+批归一化),这个二维卷积的卷积核是1x1的,经过这个二维卷积的x再加上f(x)后经过ReLU就得到了最终的y

2、项目介绍

在这里插入图片描述

  1. dataset文件夹,将原始数据分割成训练、验证、测试三个数据集
  2. models构建模型的代码,包含resnet31、resnet50、resnet101、resnet152的构建代码,以及残差模块实现的代码
  3. original_dataset,原始数据,包含猫、狗、熊猫3个类别的数据,每个类别1000张图像
  4. save_model,训练模型保存的路径
  5. config.py 设置配置参数的代码
  6. evaluate.py 使用测试集对模型进行测试的代码
  7. prepare_data.py 数据预处理的辅助函数代码
  8. split_dataset.py 将原始数据集分割成训练集、验证集、测试集的代码
  9. train.py 训练验证的代码

3、训练脚本train.py解读------数据预处理

from __future__ import absolute_import, division, print_function
import tensorflow as tf
from models import resnet50, resnet101, resnet152, resnet34
import config
from prepare_data import generate_datasets
import mathif __name__ == '__main__':# GPU settingsgpus = tf.config.experimental.list_physical_devices('GPU')if gpus:for gpu in gpus:tf.config.experimental.set_memory_growth(gpu, True)

导入项目工具包和辅助函数
配置 TensorFlow 中的 GPU 内存,:

  1. tf.config.experimental.list_physical_devices('GPU'):这个函数调用列出了 TensorFlow 在你的机器上可用的所有 GPU
  2. if gpus: 这个检查用来确认是否有可用的 GPU。如果有,它将继续对每一个 GPU 进行配置
  3. 在循环内部,对每一个 GPU 调用 tf.config.experimental.set_memory_growth(gpu, True),这使得 GPU 上的内存增长被启用
# get the original_datasettrain_dataset, valid_dataset, test_dataset, train_count, valid_count, test_count = generate_datasets()
def generate_datasets():train_dataset, train_count = get_dataset(dataset_root_dir=config.train_dir)valid_dataset, valid_count = get_dataset(dataset_root_dir=config.valid_dir)test_dataset, test_count = get_dataset(dataset_root_dir=config.test_dir)# read the original_dataset in the form of batchtrain_dataset = train_dataset.shuffle(buffer_size=train_count).batch(batch_size=config.BATCH_SIZE)valid_dataset = valid_dataset.batch(batch_size=config.BATCH_SIZE)test_dataset = test_dataset.batch(batch_size=config.BATCH_SIZE)return train_dataset, valid_dataset, test_dataset, train_count, valid_count, test_count
def get_dataset(dataset_root_dir):all_image_path, all_image_label = get_images_and_labels(data_root_dir=dataset_root_dir)# print("image_path: {}".format(all_image_path[:]))# print("image_label: {}".format(all_image_label[:]))# load the dataset and preprocess imagesimage_dataset = tf.data.Dataset.from_tensor_slices(all_image_path).map(load_and_preprocess_image)label_dataset = tf.data.Dataset.from_tensor_slices(all_image_label)dataset = tf.data.Dataset.zip((image_dataset, label_dataset))image_count = len(all_image_path)return dataset, image_count
def get_images_and_labels(data_root_dir):# 得到所有图像路径data_root = pathlib.Path(data_root_dir)all_image_path = [str(path) for path in list(data_root.glob('*/*'))]# 得到标签名字label_names = sorted(item.name for item in data_root.glob('*/'))# 例如:{'cats': 0, 'dogs': 1, 'panda': 2}label_to_index = dict((index, label) for label, index in enumerate(label_names))# 每一个图像对应的标签all_image_label = [label_to_index[pathlib.Path(single_image_path).parent.name] for single_image_path in all_image_path]return all_image_path, all_image_labeldef load_and_preprocess_image(img_path):# read picturesimg_raw = tf.io.read_file(img_path)# decode picturesimg_tensor = tf.image.decode_jpeg(img_raw, channels=channels)# resizeimg_tensor = tf.image.resize(img_tensor, [image_height, image_width])img_tensor = tf.cast(img_tensor, tf.float32)# normalizationimg = img_tensor / 255.0return img

load_and_preprocess_image()函数:

  1. 通过读取一个图像的路径
  2. 返回Tensor
  3. 进去进行归一化

get_images_and_labels()函数:

  1. 通过数据集的地址,获取当前目录下的所有图像的名称
  2. 在加上前缀路径和文件后缀,得到当前所有图像的对应的地址
  3. 返回地址和标签

get_dataset()函数:

  1. 通过调用get_images_and_labels()函数,得到当前目录下的图像的对应的地址和标签
  2. 使用from_tensor_slices方法和load_and_preprocess_image()函数读取地址和标签转换为Tensor
  3. 返回标签和数据组成的Tensor以及数据量

generate_datasets()函数:

  1. 训练、验证、测试数据路径分别通过调用get_dataset()函数得到训练、验证、测试数据Tensor和数据量
  2. 对训练、验证、测试数据加上batch_size和shuffle参数
  3. 返回训练、验证、测试数据Tensor和数据量

Resnet实战1
Resnet实战2
Resnet实战3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/659698.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

scienceplots绘图浅尝

前言 科研写作中,黑压压的文字里面如果能有一些优美的图片无疑会给论文增色不少,绘图的工具有很多,常用的有Excel、Python、Matlab等,Matlab在绘图方面相较于Python有一种更加原生的科研风,而且可视化编辑图例、坐标轴…

ManticoreSearch-(安装配置,集群搭建)-学习总结

ManticoreSearch-(安装配置)-学习总结 基础概念安装搭建集群搭建(基于K8S) 原文地址 https://blog.csdn.net/liuyij3430448/article/details/135955025 基础概念 Manticore Search是一个专门为搜索设计的多存储数据库,具有强大的全文搜索功能,适用于…

【SparkML系列3】特征提取器TF-IDF、Word2Vec和CountVectorizer

本节介绍了用于处理特征的算法,大致可以分为以下几组: 提取(Extraction):从“原始”数据中提取特征。转换(Transformation):缩放、转换或修改特征。选择(Selection&…

15.Golang中的反射机制及应用

目录 概述实践基本应用复杂应用 结束 概述 Golang中的反射用法还是比较简单的 reflect.TypeOf(arg)reflect.ValueOf(arg) 实践 基本应用 package mainimport ("fmt""reflect" )func reflectNum(arg interface{}) {fmt.Println("type ", re…

网络异常案例三_RST

问题现象 在做功能测试的时候,经常看到设备离线的消息(MQTT遗嘱)。 在终端连接的网络设备上抓包分析,看到终端设备发起大量的RST请求。 151这个设备,7min,重置断开了8个TCP连接(mqtt连接&#…

理解部署描述符的元素

理解部署描述符的元素 部署描述符是文件名为web.xml的XML文件,其包含了Web应用程序的配置信息。每个Web应用程序都有一个web.xml文件。web.xml文件的元素可用于指定servlet的初始化参数、不同文件的MIME类型、侦听器类,以及将URL模式映射到servlet上。一…

2024年,AI 掀起数据与分析市场的新风暴

2024 年伊始,Kyligence 联合创始人兼 CEO 韩卿在其公司内部的飞书订阅号发表了多篇 Rethink Data & Analytics 的内部信,分享了对数据与分析行业的一些战略思考,尤其是 AI 带来的各种变化和革命,是如何深刻地影响这个行业乃至…

防御挂马攻击:从防御到清除的最佳实践

挂马攻击,也称为马式攻击(Horse Attack),是一种常见的网络攻击手段。攻击者通过在目标服务器或网站中植入恶意程序,以获取系统权限或窃取敏感信息。为了应对这种威胁,本文将重点介绍防御挂马攻击的最佳实践…

AI项目落地成功因素:数据和机器学习模型的选择

构建机器学习模型时,需要考虑几个关键要素:计算能力、算法和数据。公司往往会将大部分资源集中于开发正确的、无偏见的算法,并加大对计算能力的投入,而在运行模型前,数据通常靠边站或完全被抛诸脑后。 如果数据被遗忘&…

C语言——动态内存管理(经典例题)

题1、 为什么会崩溃呢&#xff1f;&#x1f914;&#x1f914;&#x1f914; #include <stdio.h> #include <stdlib.h> #include <string.h>void GetMemory(char** p) {*p (char*)malloc(100); } void Test(void) {char* str NULL;GetMemory(&str);str…

腾讯云幻兽帕鲁Palworld服务器价格表,2024年2月最新

腾讯云幻兽帕鲁服务器价格32元起&#xff0c;4核16G12M配置32元1个月、96元3个月、156元6个月、312元一年&#xff0c;支持4-8个玩家&#xff1b;8核32G22M幻兽帕鲁服务器115元1个月、345元3个月&#xff0c;支持10到20人在线开黑。腾讯云百科txybk.com分享更多4核8G12M、16核6…

力扣hot100 不同路径 多维DP 滚动数组 数论

Problem: 62. 不同路径 文章目录 思路解题方法复杂度朴素DP 思路 讲述看到这一题的思路 解题方法 &#x1f468;‍&#x1f3eb; 卡尔一题三解 复杂度 时间复杂度: &#xff1a; O ( n m ) O(nm) O(nm) 空间复杂度: O ( n m ) O(nm) O(nm) 朴素DP class Solution {p…

查看 npm的一些命令,以及npm config set registry x x x 不生效 解决方案

在 Mac 上查看自己的 npm 源&#xff0c;可以使用以下命令&#xff1a; 打开终端应用程序&#xff08;Terminal&#xff09;。 运行以下命令来查看当前的 npm 配置&#xff1a; npm config list这会显示 npm 的配置信息&#xff0c;包括当前使用的源&#xff08;registry&am…

操作系统基础:死锁

&#x1f308;个人主页&#xff1a;godspeed_lucip &#x1f525; 系列专栏&#xff1a;OS从基础到进阶 &#x1f426;1 死锁的概念&#x1f9a2;1.1 总览&#x1f9a2;1.2 什么是死锁&#x1f9a2;1.3 死锁、饥饿、死循环的区别&#x1f427;1.3.1 概念&#x1f427;1.3.2 区别…

快速排序|超详细讲解|入门深入学习排序算法

快速排序介绍 快速排序(Quick Sort)使用分治法策略。 它的基本思想是&#xff1a;选择一个基准数&#xff0c;通过一趟排序将要排序的数据分割成独立的两部分&#xff1b;其中一部分的所有数据都比另外一部分的所有数据都要小。然后&#xff0c;再按此方法对这两部分数据分别进…

vue3-深入组件-插槽

插槽 Slots 组件用来接收模板内容 插槽内容与出口 <slot> 元素是一个插槽出口 (slot outlet),&#xff0c;标示了父元素提供的插槽内容 (slot content) 将在哪里被渲染。 插槽内容可以是任意合法的模板内容&#xff0c;不局限于文本。例如我们可以传入多个元素&#xff0…

HTML+CSS:导航栏组件

效果演示 实现了一个导航栏的动画效果&#xff0c;当用户点击导航栏中的某个选项时&#xff0c;对应的选项卡会向左平移&#xff0c;同时一个小圆圈会出现在选项卡的中心&#xff0c;表示当前选项卡的位置。这个效果可以让用户更加清晰地了解当前页面的位置和内容。 Code <…

关于source批量处理sql命令建立数据库后发现中文乱码问题解决方案(Mysql)

今天在使用souce建表的时候发现自己表结构中的中文出现了乱码问题&#xff0c;那么具体的解决方案如下&#xff1a; 首先我们先使用命令行连接自己的数据库 mysql -u root -p 12345 然后使用show variables like "char%"; 如果说你的这个里面不是utf-8那么就是出现了…

第九篇【传奇开心果系列】Python的OpenCV技术点案例示例:目标跟踪

传奇开心果短博文系列 系列短博文目录Python的OpenCV技术点案例示例系列 短博文目录前言二、常用的目标跟踪功能、高级功能和增强跟踪技术介绍三、常用的目标跟踪功能示例代码四、OpenCV高级功能示例代码五、OpenCV跟踪目标增强技术示例代码六、归纳总结 系列短博文目录 Pytho…

maven--将jar包上传到maven中央仓库(公库)

原文网址&#xff1a;maven--将jar包上传到maven中央仓库(公库)-CSDN博客 简介 本文介绍怎样将jar包上传到maven中央仓库(公库)。 当自己有一些公共组件时&#xff0c;上传到maven公库是最好的&#xff0c;这样项目里直接引用即可&#xff0c;不需要在多处修改&#xff0c;而…