T10 数据增强

文章目录

  • 一、准备环境和数据
    • 1.环境
    • 2. 数据
  • 二、数据增强(增加数据集中样本的多样性)
  • 三、将增强后的数据添加到模型中
  • 四、开始训练
  • 五、自定义增强函数
  • 六、一些增强函数

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第10周:数据增强(训练营内部成员可读)
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

本文说明了两种数据增强方式,以及如何自定义数据增强方式并将其放到我们代码当中,两种数据增强方式如下:
● 将数据增强模块嵌入model中
● 在Dataset数据集中进行数据增强

常用的tf增强函数在文末有说明

一、准备环境和数据

1.环境

import matplotlib.pyplot as plt
import numpy as np
import sys
from datetime import datetime
#隐藏警告
import warnings
warnings.filterwarnings('ignore')from tensorflow.keras import layers
import tensorflow as tfprint("--------# 使用环境说明---------")
print("Today: ", datetime.today())
print("Python: " + sys.version)
print("Tensorflow: ", tf.__version__)gpus = tf.config.list_physical_devices("GPU")
if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")# 打印显卡信息,确认GPU可用print(gpus)
else:print("Use CPU")

在这里插入图片描述

2. 数据

使用上一课的数据集,即猫狗识别2的数据集。其次,原数据集中不包括测试集,所以使用tf.data.experimental.cardinality确定验证集中有多少批次的数据,然后将其中的 20% 移至测试集。

# 从本地路径读入图像数据
print("--------# 从本地路径读入图像数据---------")
data_dir   = "D:/jupyter notebook/DL-100-days/datasets/Cats&Dogs Data2/"
img_height = 224
img_width  = 224
batch_size = 32# 划分训练集
print("--------# 划分训练集---------")
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.3,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)# 划分验证集
print("--------# 划分验证集---------")
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.3,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)# 从验证集中划20%的数据用作测试集
print("--------# 从验证集中划20%的数据用作测试集---------")
val_batches = tf.data.experimental.cardinality(val_ds)
test_ds     = val_ds.take(val_batches // 5)
val_ds      = val_ds.skip(val_batches // 5)print('验证集的批次数: %d' % tf.data.experimental.cardinality(val_ds))
print('测试集的批次数: %d' % tf.data.experimental.cardinality(test_ds))# 显示数据类别
print("--------# 显示数据类别---------")
class_names = train_ds.class_names
print(class_names)print("--------# 归一化处理---------")
AUTOTUNE = tf.data.AUTOTUNEdef preprocess_image(image,label):return (image/255.0,label)# 归一化处理
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds   = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
test_ds  = test_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)# 数据可视化
print("--------# 数据可视化---------")
plt.figure(figsize=(15, 10))  # 图形的宽为15高为10for images, labels in train_ds.take(1):for i in range(8):ax = plt.subplot(5, 8, i + 1) plt.imshow(images[i])plt.title(class_names[labels[i]])plt.axis("off")

在这里插入图片描述

二、数据增强(增加数据集中样本的多样性)

数据增强的常用方法包括(但不限于):随机平移、随机翻转、随机旋转、随机亮度、随机对比度,可以在Tf中文网的experimental/preprocessing类目下查看,也可以在Tf中文网的layers/类目下查看。

本文使用随机翻转随机旋转来进行增强:

tf.keras.layers.experimental.preprocessing.RandomFlip:水平和垂直随机翻转每个图像

tf.keras.layers.experimental.preprocessing.RandomRotation:随机旋转每个图像

# 第一个层表示进行随机的水平和垂直翻转,而第二个层表示按照 0.2 的弧度值进行随机旋转。
print("--------# 数据增强:随机翻转+随机旋转---------")
data_augmentation = tf.keras.Sequential([tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])# Add the image to a batch.
print("--------# 添加图像到batch中---------")
# Q:这个i从哪来的??????
image = tf.expand_dims(images[i], 0)print("--------# 显示增强后的图像---------")
plt.figure(figsize=(8, 8))
for i in range(9):augmented_image = data_augmentation(image)ax = plt.subplot(3, 3, i + 1)plt.imshow(augmented_image[0])plt.axis("off")
--------# 数据增强:随机翻转+随机旋转---------
--------# 添加图像到batch中---------
--------# 显示增强后的图像---------
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.

在这里插入图片描述

三、将增强后的数据添加到模型中

两种方式:

  • (1)将其嵌入model中

优点是:

● 数据增强这块的工作可以得到GPU的加速(如果使用了GPU训练的话)

注意:只有在模型训练时(Model.fit)才会进行增强,在模型评估(Model.evaluate)以及预测(Model.predict)时并不会进行增强操作。

'''
model = tf.keras.Sequential([data_augmentation,layers.Conv2D(16, 3, padding='same', activation='relu'),layers.MaxPooling2D(),
])
'''
"\nmodel = tf.keras.Sequential([\n  data_augmentation,\n  layers.Conv2D(16, 3, padding='same', activation='relu'),\n  layers.MaxPooling2D(),\n])\n"
  • (2)在Dataset数据集中进行数据增强
batch_size = 32
AUTOTUNE = tf.data.AUTOTUNEdef prepare(ds):ds = ds.map(lambda x, y: (data_augmentation(x, training=True), y), num_parallel_calls=AUTOTUNE)return dsprint("--------# 增强后的图像加到模型中---------")
train_ds = prepare(train_ds)

在这里插入图片描述

四、开始训练

# 设置模型
print("--------# 设置模型---------")
model = tf.keras.Sequential([layers.Conv2D(16, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(32, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(64, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Flatten(),layers.Dense(128, activation='relu'),layers.Dense(len(class_names))
])# 设置编译参数
# ● 损失函数(loss):用于衡量模型在训练期间的准确率。
# ● 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
# ● 评价函数(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
print("--------# 设置编译器参数---------")
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])print("--------# 开始训练---------")
epochs=20
history = model.fit(train_ds,validation_data=val_ds,epochs=epochs
)print("--------# 查看训练结果---------")
loss, acc = model.evaluate(test_ds)
print("Accuracy", acc)

在这里插入图片描述

五、自定义增强函数

print("--------# 自定义增强函数---------")
import random
# 这是大家可以自由发挥的一个地方
def aug_img(image):seed = (random.randint(0,9), 0)# 随机改变图像对比度stateless_random_brightness = tf.image.stateless_random_contrast(image, lower=0.1, upper=1.0, seed=seed)return stateless_random_brightnessimage = tf.expand_dims(images[3]*255, 0)
print("Min and max pixel values:", image.numpy().min(), image.numpy().max())plt.figure(figsize=(8, 8))
for i in range(9):augmented_image = aug_img(image)ax = plt.subplot(3, 3, i + 1)plt.imshow(augmented_image[0].numpy().astype("uint8"))plt.axis("off")# Q: 将自定义增强函数应用到我们数据上呢?
# 请参考上文的 preprocess_image 函数,将 aug_img 函数嵌入到 preprocess_image 函数中,在数据预处理时完成数据增强就OK啦。

在这里插入图片描述
在这里插入图片描述

# 从本地路径读入图像数据
print("--------# 从本地路径读入图像数据---------")
data_dir   = "D:/jupyter notebook/DL-100-days/datasets/Cats&Dogs Data2/"
img_height = 224
img_width  = 224
batch_size = 32# 划分训练集
print("--------# 划分训练集---------")
train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.3,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)# 划分验证集
print("--------# 划分验证集---------")
val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.3,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)# 从验证集中划20%的数据用作测试集
print("--------# 从验证集中划20%的数据用作测试集---------")
val_batches = tf.data.experimental.cardinality(val_ds)
test_ds     = val_ds.take(val_batches // 5)
val_ds      = val_ds.skip(val_batches // 5)print('验证集的批次数: %d' % tf.data.experimental.cardinality(val_ds))
print('测试集的批次数: %d' % tf.data.experimental.cardinality(test_ds))# 显示数据类别
print("--------# 显示数据类别---------")
class_names = train_ds.class_names
print(class_names)print("--------# 归一化处理---------")
AUTOTUNE = tf.data.AUTOTUNEprint("--------# 将自定义增强函数应用到数据上---------")
def preprocess_image(aug_img,label):return (aug_img/255.0,label)# 归一化处理
train_ds = train_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
val_ds   = val_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)
test_ds  = test_ds.map(preprocess_image, num_parallel_calls=AUTOTUNE)train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds   = val_ds.cache().prefetch(buffer_size=AUTOTUNE)# 数据可视化
print("--------# 数据可视化---------")
plt.figure(figsize=(15, 10))  # 图形的宽为15高为10for images, labels in train_ds.take(1):for i in range(8):ax = plt.subplot(5, 8, i + 1) plt.imshow(images[i])plt.title(class_names[labels[i]])plt.axis("off")# 设置模型
print("--------# 设置模型---------")
model = tf.keras.Sequential([layers.Conv2D(16, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(32, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Conv2D(64, 3, padding='same', activation='relu'),layers.MaxPooling2D(),layers.Flatten(),layers.Dense(128, activation='relu'),layers.Dense(len(class_names))
])# 设置编译参数
# ● 损失函数(loss):用于衡量模型在训练期间的准确率。
# ● 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
# ● 评价函数(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
print("--------# 设置编译器参数---------")
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])print("--------# 开始训练---------")
epochs=20
history = model.fit(train_ds,validation_data=val_ds,epochs=epochs
)print("--------# 查看训练结果---------")
loss, acc = model.evaluate(test_ds)
print("Accuracy", acc)

使用自定义增强函数增强后的数据重新训练的结果:
在这里插入图片描述

六、一些增强函数

在这里插入图片描述
(1)随机亮度(RandomBrightness)

tf.keras.layers.RandomBrightness( factor, value_range=(0, 255), seed=None, **kwargs )

(2)随机对比度(RandomContrast)

tf.keras.layers.RandomContrast( factor, seed=None, **kwargs )

(3)随机裁剪(RandomCrop)

tf.keras.layers.RandomCrop( height, width, seed=None, **kwargs )

(4)随机翻转(RandomFlip)

tf.keras.layers.RandomFlip( mode=HORIZONTAL_AND_VERTICAL, seed=None, **kwargs )
(5)随机高度(RandomHeight)和随机宽度(RandomWidth)

tf.keras.layers.RandomHeight( factor, interpolation='bilinear', seed=None, **kwargs )

tf.keras.layers.RandomWidth( factor, interpolation='bilinear', seed=None, **kwargs )

(6)随机平移(RandomTranslation)

tf.keras.layers.RandomTranslation( height_factor, width_factor, fill_mode='reflect', interpolation='bilinear', seed=None, fill_value=0.0, **kwargs )

(7)随机旋转(RandonRotation)

tf.keras.layers.RandomRotation( factor, fill_mode='reflect', interpolation='bilinear', seed=None, fill_value=0.0, **kwargs )

(8)随机缩放(RandonZoom)

tf.keras.layers.RandomZoom( height_factor, width_factor=None, fill_mode='reflect', interpolation='bilinear', seed=None, fill_value=0.0, **kwargs )

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/149601.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

docker数据卷详细讲解及数据卷常用命令

docker数据卷详细讲解及数据卷常用命令 Docker 数据卷是一种将宿主机的目录或文件直接映射到容器中的特殊目录,用于实现数据的持久化和共享。Docker 数据卷有以下特点: 数据卷可以在一个或多个容器之间共享和重用,不受容器的生命周期影响。…

CSS-表格独有属性

属性名:able-layout功能:设置列宽度属性值: auto(默认值):自动,列宽根据内容计算 table-layout: auto; fixed:固定列宽,平均分 table-layout: fixed; 属性名:…

C语言中文网 - Shell脚本 - 8

第1章 Shell基础(开胃菜) 8. Linux Shell命令提示符 启动 Linux 桌面环境自带的终端模拟包,或者从 Linux 控制台登录后,便可以看到 Shell 命令提示符。看见命令提示符就意味着可以输入命令了。命令提示符不是命令的一部分&#x…

Linux QT交叉编译环境安装

参考链接 linux交叉编译Qt_linux qt 交叉编译-CSDN博客 关键点:编译脚本,放在qt源代码根目录的.sh文件 #!/bin/shcd ./qt-everywhere-src-5.12.9./configure -prefix /home/qsqya/compile/qt5.12.9/build \ -opensource \ -release \ -confirm-license…

​LeetCode解法汇总2342. 数位和相等数对的最大和

目录链接: 力扣编程题-解法汇总_分享记录-CSDN博客 GitHub同步刷题项目: https://github.com/September26/java-algorithms 原题链接:力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台 描述: 给你一个下…

充电桩负载测试需要检测哪些项目

充电桩负载测试在进行充电桩负载测试时,需要检测以下几个项目: 充电速度:测试充电桩的充电速度,包括直流充电桩的最大输出功率和交流充电桩的充电功率,以确定其是否符合标准要求。充电效率:测试充电桩的充电…

vue项目动态配置网站图标

1、在.env中配置图标地址 # 网站图标 VUE_APP_ICON_URL ./民政.png2、在main.js中将配置的图标地址存入缓存 if(process.env.VUE_APP_ICON_URL){sessionStorage.setItem(VUE_APP_ICON_URL, process.env.VUE_APP_ICON_URL); }else{sessionStorage.setItem(VUE_APP_ICON_URL, …

Go语言中ipv4与Uint32转换

简介 ip对于我们都不陌生,但是如果有一道题目要你判断某个ip在不在一个ip段的范围内,该怎么做呢,要是能把它弄成可比较的数字就好了 例如 127.0.0.1如何转数字呢,我们可以把它分成四段 127 0 0 1 每一段转为二进制拼起来 011111…

特征缩放和转换以及自定义Transformers(Machine Learning 研习之九)

特征缩放和转换 您需要应用于数据的最重要的转换之一是功能扩展。除了少数例外,机器学习算法在输入数值属性具有非常不同的尺度时表现不佳。住房数据就是这种情况:房间总数约为6至39320间,而收入中位数仅为0至15间。如果没有任何缩放,大多数…

初识分布式键值对存储etcd

欢迎大家到我的博客浏览。胤凯 (oyto.github.io)大家好,今天我带大家来学习一下 etcd。 一、什么是 etcd etcd 是一个开源的分布式键值存储系统,主要用于构建分布式系统中那点服务发现、配置管理、分布式锁等场景。它采用 Raft 一致性算法来确保所有节…

JavaScript 字符处理

1.删除前几个字符 使用 slice console.log(12345.slice(1))// 23452.首字母大写 var word abcconsole.log(word.charAt(0).toUpperCase() word.slice(1))// Abc3.字符为数字时可直接相乘 console.log(2*3) 4.字符串中是否包含某个子字符串 子串既可以为数字也可为字符串 /…

react中设置activeClassName的笔记

React是一种流行的JavaScript库,用于构建动态用户界面。它具有许多有用的组件,其中之一是NavLink组件。NavLink组件用于在React应用程序中创建链接,并且它具有许多有用的属性,例如选中的样式设置。 react-router-dom": “^6…

Pyside6/PyQt6如何添加右键菜单,源码示例

文章目录 📖 介绍 📖🏡 环境 🏡📒 源码分享 📒🎈 添加图标📖 介绍 📖 在UI开发中经常会使用到右键菜单,本文记录了一个添加右键菜单的示例,可以举一反三,仅供参考! 🏡 环境 🏡 本文演示环境如下 Windows11Python3.11.5PySide6📒 源码分享 📒 下面…

左支座零件的机械加工工艺规程及工艺装备设计【计算机辅助设计与制造CAD】

wx供重浩:创享日记 对话框发送:左支座 获取完整CAD工程源文件论文报告说明书等 一、论文目录 二、论文部分内容 设计任务 1.完成左支座零件—毛坯合图及左支座零件图 2.完成左支座零件工艺规程设计 3.完成左支座零件加工工艺卡 4.机床专用夹具装备总图 …

ACWSpring1.3

git使用git status看我们仓库有多少什么东西 首先,前端写ajax写上我们的访问路径(就在我们前端的源代码里面),我们建了两个包pkController用于前端页面url映射过来一层一层找到我们的RestController返回bot1里面有键值,返回的这就是一个session对象bot1这个map.前端拿到我们bot…

汇编-在VisualStudio调试器中显示数组

1.调试运行程序 2.菜单-->调试--> 3.在地址栏上输入 &数组名 4.其它选项 右击窗口 根据实际情况自己选择

石油石化物资采购杂志社石油石化物资采购编辑部2023年第18期部分目录

物资采购与管理 依法规范招标采购融合现代智慧供应链的路径探索 黄缵烨1-3 海上油气建设项目工程物资价格数据库的建立和应用研究 韩萍萍4-6《石油石化物资采购》投稿:cnqikantg126.com 招投标采办管理风险研究 王威7-9 海外风电总包项目风机等主要设备采购管理要点…

51.Sentinel微服务保护

目录 (1)初识Sentinel。 (1.1)雪崩问题及解决方案。 (1.1.1)雪崩问题。 (1.1.2)解决雪崩问题的四种方式。 (1.1.3)总结。 (1.2)…

在Ubuntu系统中安装VNC并结合内网穿透实现公网远程访问

🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🦄 博客首页——🐅🐾猫头虎的博客🎐 🐳 《面试题大全专栏》 🦕 文章图文…