【TensorFlow2 之011】TF 如何使用数据增强提高模型性能?

一、说明

        亮点:在这篇文章中,我们将展示数据增强技术作为提高模型性能的一种方式的好处。当我们没有足够的数据可供使用时,这种方法将非常有益。

教程概述:

  1. 无需数据增强的训练
  2. 什么是数据增强?
  3. 使用数据增强进行训练
  4. 可视化

二、没有数据增强的训练

        一个熟悉的问题是“我们为什么要使用数据增强?所以,让我们看看答案。

        为了证明这一点,我们将在TensorFlow中创建一个卷积神经网络,并在Cats-vs-Dog数据集上对其进行训练。

        首先,我们将准备用于训练的数据集。我们将首先从在线存储库下载数据集。完成此操作后,我们将继续解压缩并为训练和验证集创建路径位置。

import os
import wget
import zipfilewget.download("https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip")
100% [........................................................................] 68606236 / 68606236

Out[2]:

'cats_and_dogs_filtered.zip'

 

with zipfile.ZipFile("cats_and_dogs_filtered.zip","r") as zip_ref:zip_ref.extractall()base_dir = 'cats_and_dogs_filtered'train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

       让我们继续加载本教程所需的必要库。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimgfrom tensorflow.keras import Model
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense, Dropout, Activation
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.preprocessing.image import img_to_array, load_img
from tensorflow.keras.preprocessing.image import ImageDataGenerator

        我们将使用“模型子类化技术”来构建模型。这样,我们应该在__init__中定义我们的层,并在调用中实现模型的前向传递。模型的输入将是大小为 \([150, 150, 3]\) 的图像。在卷积层之后,我们将利用两个完全连接的层来进行预测。这是一个二元分类问题,所以我们在输出层中只有一个神经元。

class Create_model(Model):def __init__(self, chanDim=-1):super(Create_model, self).__init__()self.conv1A = Conv2D(16, 3, input_shape = (150, 150, 3))self.act1A  = Activation("relu")self.pool1A = MaxPooling2D(2)self.conv1B = Conv2D(32, 3)self.act1B  = Activation("relu")self.pool1B = MaxPooling2D(pool_size=(2, 2))self.conv1C = Conv2D(64, 3)self.act1C  = Activation("relu")self.pool1C = MaxPooling2D(2)self.flatten = Flatten()self.dense2A = Dense(512)self.act2A  = Activation("relu")self.dense2B = Dense(1)self.sigmoid  = Activation("sigmoid")def call(self, inputs):x = self.conv1A(inputs)x = self.act1A(x)x = self.pool1A(x)x = self.conv1B(x)x = self.act1B(x)x = self.pool1B(x)x = self.conv1C(x)x = self.act1C(x)x = self.pool1C(x)x = self.flatten(x)x = self.dense2A(x)x = self.act2A(x)x = self.dense2B(x)x = self.sigmoid(x)return xmodel = Create_model()model.compile(loss='binary_crossentropy',optimizer=RMSprop(lr=0.001),metrics=['accuracy'])

        我们的图像不在一个文件中,而是在多个文件夹中。为了在这样的数据集上训练网络,我们需要使用图像数据生成器。在创建两个生成器(用于训练和验证)后,我们可以使用 fit 方法训练网络。唯一的区别是,我们不是将输入和输出分别传递给我们的网络,而是将数据生成器传递给网络。

        最好规范化像素值,以便每个像素值的值介于 0 和 1 之间,以免中断或减慢学习过程。因此,这将是传递给图像数据生成器的唯一参数。然后,我们将使用这些数据生成器遍历目录、调整图像大小和创建批处理。

train_datagen = ImageDataGenerator(rescale=1./255)
val_datagen = ImageDataGenerator(rescale=1./255)train_generator = train_datagen.flow_from_directory(train_dir,target_size=(150, 150),batch_size=20,class_mode='binary')validation_generator = val_datagen.flow_from_directory(validation_dir,target_size=(150, 150),batch_size=20,class_mode='binary')history = model.fit(train_generator,steps_per_epoch=100,epochs=15,validation_data=validation_generator,validation_steps=50,verbose=0)

        让我们检查一下模型的分类准确性和损失。

        在这里,训练集的准确性和损失都以蓝色显示,而验证集的准确度和损失都以橙色显示。

accuracy = history.history['accuracy']
val_accuracy = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(accuracy))plt.plot(epochs, accuracy)
plt.plot(epochs, val_accuracy)
plt.title('Training and validation accuracy')plt.figure()plt.plot(epochs, loss)
plt.plot(epochs, val_loss)
plt.title('Training and validation loss')

  

       从这些图中,我们可以清楚地看到,模型在训练中的表现比在验证集上的表现要好得多。那么我们能做什么呢?使用数据增强。

三、什么是数据增强?

        大多数计算机视觉任务需要大量数据,而数据增强是用于提高计算机视觉系统性能的技术之一。计算机视觉是一项相当复杂的任务。对于输入图像,算法必须找到一种模式来理解图片中的内容。

        在实践中,拥有更多数据将有助于几乎所有的计算机视觉任务。今天,计算机视觉的状态需要更多的数据来解决大多数计算机视觉问题。对于卷积神经网络的所有应用来说,这可能不是真的,但对于计算机视觉领域来说确实如此。

        当我们训练计算机视觉模型时,数据增强通常会有所帮助。无论我们使用迁移学习还是从头开始训练模型,都是如此。

        因此,数据增强是一种技术,可以在不收集新数据的情况下显着增加可用于训练的数据的多样性。您可以在此处找到有关数据增强理论方面的更多信息。

四、 使用数据增强进行训练

        乍一看,数据增强可能听起来很复杂,但幸运的是,TensorFlow 允许我们有效地实现它。

        因此,我们将像以前一样使用图像数据生成器,但我们将添加重新缩放、旋转、移动、缩放、缩放和翻转。再次重要的是要说,此过程仅适用于训练集,而不应用于验证。

        在这里,我们不仅要调整大小,还要添加旋转(以度为单位的范围)、高度和宽度偏移(以像素为单位的范围)、剪切范围(以度为单位的逆时针方向的角度)、缩放范围和翻转。

train_datagen = ImageDataGenerator(rescale=1./255,rotation_range=40,width_shift_range=0.2,height_shift_range=0.2,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,)val_datagen = ImageDataGenerator(rescale=1./255)train_generator = train_datagen.flow_from_directory(train_dir,target_size=(150, 150),batch_size=20,class_mode='binary')validation_generator = val_datagen.flow_from_directory(validation_dir,target_size=(150, 150),batch_size=20,class_mode='binary')

Found 2000 images belonging to 2 classes.
Found 1000 images belonging to 2 classes.

        此外,Dropout 层将被添加到我们的神经网络中。我们将丢弃\(50%\)个起始神经元。添加 Dropout 图层将有助于防止过度拟合。

class Create_model(Model):def __init__(self, chanDim=-1):super(Create_model, self).__init__()self.conv1A = Conv2D(16, 3, input_shape = (150, 150, 3))self.act1A  = Activation("relu")self.pool1A = MaxPooling2D(2)self.conv1B = Conv2D(32, 3)self.act1B  = Activation("relu")self.pool1B = MaxPooling2D(pool_size=(2, 2))self.conv1C = Conv2D(64, 3)self.act1C  = Activation("relu")self.pool1C = MaxPooling2D(2)self.flatten = Flatten()self.dense2A = Dense(512)self.act2A  = Activation("relu")self.dropout = Dropout(0.5)self.dense2B = Dense(1)self.sigmoid  = Activation("sigmoid")def call(self, inputs):x = self.conv1A(inputs)x = self.act1A(x)x = self.pool1A(x)x = self.conv1B(x)x = self.act1B(x)x = self.pool1B(x)x = self.conv1C(x)x = self.act1C(x)x = self.pool1C(x)x = self.flatten(x)x = self.dense2A(x)x = self.act2A(x)x = self.dropout(x)x = self.dense2B(x)x = self.sigmoid(x)return xmodel = Create_model()model.compile(loss='binary_crossentropy',optimizer=RMSprop(lr=0.001),metrics=['accuracy'])

现在我们可以训练网络了。

history = model.fit(train_generator,steps_per_epoch=100,epochs=30,validation_data=validation_generator,validation_steps=50,verbose=2)

Train for 100 steps, validate for 50 steps
Epoch 1/30
100/100 - 92s - loss: 0.9063 - accuracy: 0.5125 - val_loss: 0.7271 - val_accuracy: 0.5000
Epoch 2/30
100/100 - 56s - loss: 0.7020 - accuracy: 0.5625 - val_loss: 0.6551 - val_accuracy: 0.5480
Epoch 3/30
100/100 - 56s - loss: 0.6815 - accuracy: 0.5950 - val_loss: 0.6253 - val_accuracy: 0.6600
Epoch 4/30
100/100 - 57s - loss: 0.6594 - accuracy: 0.6220 - val_loss: 0.6262 - val_accuracy: 0.6350
Epoch 5/30
100/100 - 56s - loss: 0.6352 - accuracy: 0.6485 - val_loss: 0.5916 - val_accuracy: 0.6890
Epoch 6/30
100/100 - 56s - loss: 0.6336 - accuracy: 0.6675 - val_loss: 0.5774 - val_accuracy: 0.6790
Epoch 7/30
100/100 - 57s - loss: 0.6383 - accuracy: 0.6570 - val_loss: 0.5830 - val_accuracy: 0.6980


让我们看看结果。在这里,训练集的准确性和损失都以蓝色显示,而验证集的准确度和损失都以橙色显示。

accuracy = history.history['accuracy']
val_accuracy = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(accuracy))plt.plot(epochs, accuracy)
plt.plot(epochs, val_accuracy)
plt.title('Training and validation accuracy')plt.figure()plt.plot(epochs, loss)
plt.plot(epochs, val_loss)
plt.title('Training and validation loss')
Text(0.5, 1.0, 'Training and validation loss')

现在的结果好多了。但是,我们模型的准确性还不完美。

我们将在下一篇文章中使用迁移学习来解决这个问题。

五、 可视化

到目前为止,我们只是在讨论如何创建增强图像,但让我们看看它们的外观。

为此,我们需要使用来自生成器的一个图像并“循环它”。这将遍历生成器并执行增强。下面我们展示了这些图像样本的可视化。

augmented_images = [train_generator[0][0][0] for i in range(12)]
plt.figure(figsize=(8,6))for i in range(12):plt.subplot(3, 4, i+1)image = augmented_images[i]image = image.reshape(150, 150, 3)plt.imshow(image)
pyplot.show()
狗增强图像
增强图像

六、总结

        总而言之,我们已经学会了如何使用数据增强技术来提高模型的性能。在数据稀缺或数据收集成本高昂的情况下,我们可以使用这种方法。但是,请注意,我们不能将数据集扩充到非常大的比例。此方法有其局限性。在下一篇文章中,我们将展示如何应用迁移学习的过程。

有关该主题的更多资源:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/104606.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

sts搭建springboot

sts搭建springboot 需要的软件 Apache-maven-3.3.9.rar 链接:百度网盘 请输入提取码 提取码:1xgj Spring-tool-suite-4-4.20.0.RELEASE-e4.29.0-win32.win32.x86_64.self-extracting.jar 链接:百度网盘 请输入提取码 提取码:p…

初级问题 程序中的变量是指什么?中级问题 把若干个数据沿直线排列起来的数据结构叫作什么?高级问题 栈和队列的区别是什么?

目录 1.深刻主题 2.描写复杂人物 初级问题 程序中的变量是指什么? 中级问题 把若干个数据沿直线排列起来的数据结构叫作什么? 高级问题 栈和队列的区别是什么? 计算机图形学(有效边表算法) 介绍一下计算机图形学…

焦炭反应性及反应后强度试验方法

声明 本文是学习GB-T 4000-2017 焦炭反应性及反应后强度试验方法. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 7— 进气口; 8— 测温热电偶。 图 A.1 单点测温加热炉体结构示意图 A.3 温度控制装置 控制精度:(11003)℃。…

C++-Mongoose(3)-http-server-https-restful

1.url 结构 2.http和 http-restful区别在于对于mg_tls_opts的赋值 2.1 http和https 区分 a) port地址 static const char *s_http_addr "http://0.0.0.0:8000"; // HTTP port static const char *s_https_addr "https://0.0.0.0:8443"; // HTTP…

Android笔记(六):JetPack Compose常见的UI组件

一、文本组件 1.1Text Column(modifier Modifier.fillMaxSize().background(Color.Green).padding(10.dp)){Text(text stringResource(id R.string.title_content),modifier Modifier.fillMaxWidth().border(BorderStroke(1.dp, Color.White)),fontSize 20.sp,textAlign …

分布式文件服务器——Windows环境MinIO的三种部署模式

上节简单聊到MinIO:分布式文件存储服务——初识MinIO-CSDN博客,但没具化,本节开始展开在Windows环境下 MinIO的三种部署模式:单机单节点、单机纠删码、集群模式。 部署的几种模式简要概括 所谓单机单节点模式:即MinI…

在Windows下自己从源码编译Python3.10.13成安装包

文章目录 (一)Python 3.10 的生命周期(一)下载源码(二)准备环境(三)编译(3.1)解压源码到目录(3.2)下载依赖(PCBuild&#…

消息称苹果或在明年推出搭载M3芯片的MacBook产品

近日据 DigiTimes 发布的博文,苹果公司计划在 2024 年推出搭载 M3 芯片的 MacBook 产品。然而,关于这款新产品的发布日期仍存在争议。虽然一些爆料认为苹果可能会在今年发布这款产品,但也有一些爆料认为发布时间会推迟到 2024 年。根据各项报…

02Maven核心程序的下载与settings.xml文件的配置,环境变量的配置

Maven核心程序的解压与配置 Maven的下载与解压 Maven官网下载安装包 将下载的Maven核心程序压缩包apache-maven-3.8.4-bin.zip解压到一个非中文且没有空格的目录 Maven的核心配置文件 在Maven的解压目录conf中我们需要配置Maven的核心配置文件settings.xml 配置本地仓库位置…

多机器人三角形编队的实现

文章目录 前言一、机器人编队前的准备二、配置仿真环境2.编写机器人编队.cpp文件 三、三角形编队测试 前言 前阵子一直想要实现多机器人编队,找到了很多开源的编队代码,经过好几天的思索,终于实现了在gazebo环境中的TB3三角形机器人编队。 一…

【数据分享】2022年我国30米分辨率的坡向数据(免费获取)

地形数据,也叫DEM数据,是我们在各项研究中最常使用的数据之一。之前我们分享过2022年哥白尼30米分辨率的DEM高程数据,该数据被公认为是全球最佳的开源DEM数据之一,甚至没有之一(可查看之前的文章获悉详情)&…

macbook电脑删除app怎么才能彻底清理?

macBook是苹果公司推出的一款笔记本电脑,它的操作系统是macOS。在macBook上安装的app可能会占用大量的存储空间,因此,当我们不再需要某个app时,需要将其彻底删除。macbook删除app,怎么才能彻底呢?本文将给大…

京东数据平台:2023年京东营养保健品市场销售数据分析

随着十一长假结束,市场端也开始了一系列的消费数据回顾和复盘。从现有数据表现来看,营养保健品市场的增长备受关注。 近日,京东消费及产业发展研究院与《经济日报》联合整合了相关数据。数据显示,2023年中秋福利采购季期间&#…

一文理清JVM结构

JVM结构介绍 JVM一共分为三个组成部分: 1 类加载子系统 主要是将class文件加载到内存中的一个系统,其核心组件是类加载器 2 运行时数据区子系统 1 JVM私有部分 1 虚拟机栈 描述的是Java方法执行的内存模型:每个方法在执行的同时都会创建一个栈帧&…

微信小程序/vue3/uview-plus form兜底校验

效果图 代码 <template><u-form :model"form" ref"formRole" :rules"rules"><u-form-item prop"nickname"><u-input v-model"form.nickname" placeholder"姓名" border"none" /&…

京东数据平台:2023年服饰行业销售数据分析

最近看到有些消费机构分析&#xff0c;不少知名的运动品牌都把“主战场”放到了冲锋衣&#xff0c;那么羽绒服市场就比较危险了。但其实羽绒服市场也有机会点可寻。 先来说冲锋衣。的确&#xff0c;从今年的销售数据以及增长情况&#xff0c;冲锋衣的确会是今年冬天的大热门品…

大数据flink篇之三-flink运行环境安装(一)单机Standalone安装

一、安装包下载地址 https://archive.apache.org/dist/flink/flink-1.15.0/ 二、安装配置流程 前提基础&#xff1a;Centos环境&#xff08;建议7以上&#xff09; 安装命令&#xff1a; 解压&#xff1a;tar -zxvf flink-xxxx.tar.gz 修改配置conf/flink-conf.yaml&#xff1…

ubuntu中查看进程并结束进程以查看资源占用命令

ps命令&#xff1a;可以列出正在运行的进程。ps -e ps -aux 查看所有进程&#xff0c;每行一个程序&#xff08;常用&#xff09;ps -A 查看当前系统所有的进程。&#xff08;常用&#xff09;ps -A | grep chrome 命令去搜索某个指定进程。&#xff08;常用&#xff09;ps -A…

LiveMedia视频中间件如何与第三方系统实现事件录像关联

一、平台简介 LiveMedia视频中间件是支持部署到本地服务器或者云服务器的纯软件服务&#xff0c;也提供服务器、GPU一体机全包服务&#xff0c;提供视频设备管理、无插件、跨平台的实时视频、历史回放、语音对讲、设备控制等基础功能&#xff0c;支持视频协议有海康、大华私有协…

Tomcat的安装和配置

一.Tomcat下载&#xff1a;去Tomcat官网地址 在左侧Download中选择你需要下载的版本&#xff0c;这里我选择Tomcat9 根据电脑系统是32位还是64位选择&#xff0c;这里我选择64-bit Windows zip&#xff0c;点击即可下载 下载后直接解压&#xff0c;这里我解压在E盘的computer…