【TensorFlow2 之012】TF2.0 中的 TF 迁移学习

#012 TensorFlow 2.0 中的 TF 迁移学习

一、说明

        在这篇文章中,我们将展示如何在不从头开始构建计算机视觉模型的情况下构建它。迁移学习背后的想法是,在大型数据集上训练的神经网络可以将其知识应用于以前从未见过的数据集。也就是说,为什么它被称为迁移学习;我们将现有模型的学习转移到新的数据集中。

教程概述:

  1. 介绍
  2. 使用内置的 TensorFlow 模型进行迁移学习
  3. 使用 TensorFlow Hub 进行迁移学习

二、为什么迁移学习简介

之        我们已经探讨了如何使用数据增强来提高模型性能。现在的问题是,“如果我们没有足够的数据来从头开始训练我们的网络怎么办?

对        此的解决方案是使用迁移学习方法。一篇更具理论意义的帖子已经发布在我们的博客上。如果需要,请查看它以刷新一些想法。我们可以使用迁移学习将知识从一些预先训练好的开源网络转移到我们自己的简历问题中。

计        算机视觉研究社区在互联网上发布了许多数据集,如Imagenet或MS Coco或Pascal数据集。许多计算机视觉研究人员已经在这些数据集上训练了他们的算法。有时,此培训需要数周时间,并且可能需要许多 GPU。事实上,其他人已经完成了这项任务并经历了痛苦的高性能研究过程,这意味着我们经常可以下载开源权重。

有        很多网络已经过训练。例如,Imagenet 数据集,它由 1000 个不同的类和超过 14 万张图像组成。因此,网络可能有一个 softmax 单元,它输出一千个可能的类之一。我们可以做的是摆脱softmax层并创建我们自己的输出单元来表示例如猫或

        由于我们使用下载的权重,我们将只训练与我们的输出层关联的参数,在我们的例子中,这将是一个 sigmoid 输出层。

三、使用内置的 TensorFlow 模型进行迁移学习

        让我们首先为训练准备数据集。我们将使用 wget.download 命令下载数据集。之后,我们需要解压缩它并合并训练和测试部分的路径。

import os
import wget
import zipfile
wget.download("https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip")with zipfile.ZipFile("cats_and_dogs_filtered.zip","r") as zip_ref:zip_ref.extractall()base_dir = 'cats_and_dogs_filtered'train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

        现在,让我们导入所有必需的库并构建模型。我们将使用一个名为MobileNetV2的预训练网络,该网络在ImageNet数据集上进行训练。在这里,我们希望使用除顶部分类图层之外的所有图层,因此我们不会将它们包含在我们的网络中。

        obileNetV2 体系结构概述:https://ai.googleblog.com/2018/04/mobilenetv2-next-generation-of-on.html

这        个模型和其他预训练模型已经在TensorFlow中可用。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimgfrom tensorflow.keras import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGeneratorbase_model = MobileNetV2(input_shape=(224, 224, 3),include_top=False,weights='imagenet')base_model.summary()


输出:
Model: "mobilenetv2_1.00_224"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
Conv1_pad (ZeroPadding2D)       (None, 225, 225, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 112, 112, 32) 864         Conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 112, 112, 32) 128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 112, 112, 32) 0           bn_Conv1[0][0]                   
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 112, 112, 32) 288         Conv1_relu[0][0]  

        果最后一层输出不同数量的类,那么我们需要有自己的输出单元来输出以下类:。有几种方法可以做到这一点:

  • 取最后几层的权重,将它们用作初始化并进行梯度下降。这样,我们将重新训练网络的一部分。
  • 删除最后几层的权重,使用我们自己的新隐藏单元和我们自己的最终 sigmoid(或 softmax)输出。通过这种方式,我们可以更改输出的数量。

        因此,这两种方法中的任何一种都值得尝试。

        现在让我们冻结预训练层并添加一个新层,称为 GlobalAveragePooling2D,之后是具有 sigmoid 激活函数的 Dense 层。

现        

base_model.trainable = Falsemodel = Sequential([base_model,GlobalAveragePooling2D(),Dense(1, activation='sigmoid')])
model.summary()

    

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
mobilenetv2_1.00_224 (Model) (None, 7, 7, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________

          现在是训练步骤的时候了。我们将使用图像数据生成器来迭代图像。现在没有必要为大量时期训练网络,因为我们已经预先训练了一部分网络。

model.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy'])train_datagen = ImageDataGenerator(rescale=1./255)
val_datagen = ImageDataGenerator(rescale=1./255)train_generator = train_datagen.flow_from_directory(train_dir,target_size=(224, 224),batch_size=32,class_mode='binary')validation_generator = val_datagen.flow_from_directory(validation_dir,target_size=(224, 224),batch_size=32,class_mode='binary')history = model.fit(train_generator,epochs=6,validation_data=validation_generator,verbose=2)

让        我们看看结果。

accuracy = history.history['accuracy']
val_accuracy = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(accuracy))plt.plot(epochs, accuracy, label="Training")
plt.plot(epochs, val_accuracy, label="Validation")
plt.legend()
plt.title('Training and validation accuracy')
plt.figure()plt.plot(epochs, loss, label="Training")
plt.plot(epochs, val_loss, label="Validation")
plt.legend()
plt.title('Training and validation loss')

3.Text(0.5, 1.0, 'Training and validation loss')

四、 使用 TensorFlow Hub 进行迁移学习

        访问预训练模型的另一种方法是TensorFlow Hub。TensorFlow Hub是一个库,用于发布,发现和使用机器学习模型的可重用部分。您可以在此处找到更多预训练模型。

        我们将冻结图层并添加一个用于分类的新图层。全连接网络的输入称为瓶颈要素。它们表示网络中最后一个卷积层的激活图。

我们为给定数量的 epoch 训练模型。

        在这里,我们也可以使用 TensorBoard,但这次让我们保持简单。

        最后,让我们可视化一些预测。在此数据集中,猫标记为 0,狗标记为 1。我们将用蓝色显示正确的预测,用红色显示假。

预测

五、总结

        正如我们在上面看到的,使用迁移学习可以帮助我们在短时间内取得非常好的结果。使用数据增强,可以进一步增强结果。

         在下一篇文章中,我们将展示如何创建网络并将其转换以在移动设备上使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/101945.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux 安装mysql

1、下载mysql安装包 2、创建mysql文件夹 mkdir /usr/local/mysql 3、解压mysql安装包,并将解压出来的文件夹下面的内容全部移动到/usr/local/mysql下 解压 tar zxvf mysql-5.7.39-linux-glibc2.12-x86_64.tar.gz 移动 mv /usr/local/src/mysql-5.7.39-linux-gl…

postgres之pg_dump导出和导入

postgres使用有一段时间了,现记录一下一些常用的命令-导出导入,备以后查询: 1.指定表结构导出 pg_dump --host127.0.0.1 --port5432 --username[用户名] -t[表名1] -t [表名1] --schema-only postgres > F:\db.sql 2.指定表数据的导出…

Vue项目为页面添加水印效果

最近在做项目,有这样要求,需要在指定容器中添加水印,也可不设置容器,如果没有容器,则添加在整个页面中,即body,当接到这个需求的时候我第一想的方法就是用canvas来实现,话不多说搞起…

Unity设计模式——装饰模式

装饰模式(Decorator),动态地给一个对象添加一些额外的职责,就增加功能来说,装饰模式比生成子类更为灵活。 Component类: abstract class Component : MonoBehaviour {public abstract void Operation(); …

华为认证 | HCIP-Datacom,这门认证正式发布新版本!

华为认证数通高级工程师HCIP-Datacom-Campus Network Planning and Deployment V1.5(中文版)自2023年9月28日起,正式在中国区发布。 01 发布概述 基于“平台生态”战略,围绕“云-管-端”协同的新ICT技术架构,华为公司…

互联网性能和可用性优化CDN和DNS

当涉及到互联网性能和可用性优化时,DNS(Domain Name System)和CDN(Content Delivery Network)是两个至关重要的元素。它们各自发挥着关键作用,以确保用户能够快速、可靠地访问网站和应用程序。在本文中&…

C++ 指向数组的指针

如果您对 C 指针的概念有所了解,那么就可以开始本章的学习。数组名是指向数组中第一个元素的常量指针。因此,在下面的声明中: double runoobAarray[50];runoobAarray 是一个指向 &runoobAarray[0] 的指针,即数组 runoobAarra…

Java 多线程编程

Java 多线程编程 目录 Java 多线程编程 一个线程的生命周 线程的优先级 创建一个线程 通过实现Runnable接口来创建线程 实例 通过继承Thread来创建线程 实例 Thread 方法 实例 线程的几个主要概念: 多线程的使用 Java给多线程编程提供了内置的支持。一个多线程程序包…

微信小程序项目如何用Hbuild启动,先让对方在微信开发平台将你的微信号添加成开发成员。

微信小程序项目如何用Hbuild启动,先让对方在微信开发平台将你的微信号添加成开发成员。然后在Hbuild官网下载,下载好后运行,点击文件导入项目,然后点击运行,模拟微信小程序,选择微信开发工具的地址&#xf…

【2023美团后端-8】删除字符串的方案,限制不能连续删

小美定义一个字符申是“美丽串”,当且仅当该字符串包含”mei”连续子串。例如”meimei”、“xiaomeichan"都是美丽串,现在小美拿到了一个字符串,她准备删除一些字符,但不能删除两个连续字符。小美希望最终字符串变成美丽串&a…

技巧:训练时Loss剧烈震荡原因汇总

输入数据有误 数据集太少 输入数据为经过标准化:没有经过正则化处理的数据可能存在异常点,或者数据的量纲不一致。可以采用min_max归一化或者z-score标准化 没有选择合理的数据增强没有经过合理的数据增强可能导致训练的时候网络学习的数据不是想要的最…

pyqt---子线程进行gui操作导致界面崩溃

在 PyQt(或 Qt 通常)中,您不能直接在子线程中执行与 GUI 相关的操作。这可能会导致应用程序崩溃或不可预测的行为。所有与 GUI 相关的操作都应该在主线程中执行。 如果您需要在子线程完成某些操作后显示一个消息框,可以使用 PyQt…

软件工程与计算总结(六)需求分析方法

本贴介绍需求分析方法,涉及到诸多实践性的东西,掌握各种图表的绘制是重中之重~ 一.需求分析基础 1.原因 需求获取中得到的信息仅仅解释了用户对软件系统的理解与期待,使用的是实际业务的表达方式,还不是开发者能够立即加以实现…

【算法|动态规划No.16】leetcode931. 下降路径最小和

个人主页:兜里有颗棉花糖 欢迎 点赞👍 收藏✨ 留言✉ 加关注💓本文由 兜里有颗棉花糖 原创 收录于专栏【手撕算法系列专栏】【LeetCode】 🍔本专栏旨在提高自己算法能力的同时,记录一下自己的学习过程,希望…

VSCODE+PHP8.2配置踩坑记录

VSCODEPHP8.2配置踩坑记录 – WhiteNights Site 我配置过的最恶心的环境之一:windows上的php。另一个是我centos服务器上的php。 进不了断点 端口配置和xdebug的安装 这个应该是最常见的问题了。从网上下载完php并解压到本地,打开vscode,安装…

基于SpringBoot的网上摄影工作室

目录 前言 一、技术栈 二、系统功能介绍 用户信息管理 作品分类管理 轮播图管理 摄影作品管理 摄影作品收藏 摄影圈 摄影作品发布 三、核心代码 1、登录模块 2、文件上传模块 3、代码封装 前言 随着信息技术在管理上越来越深入而广泛的应用,管理信息系统…

MySql基础:utf8_unicode_ci、utf8_general_ci

区别 utf8_general_ci 校对速度快,但准确度稍差。 utf8_unicode_ci 准确度高,但校对速度稍慢。 德语、法语或者俄语,请一定使用utf8_unicode_ci。一般用utf8_general_ci就够了。 utf8_unicode_ci和utf8_general_ci对中、英文来说没有差别。 …

华为云云耀云服务器L实例评测|华为云上的CentOS性能监测与调优指南

目录 引言 ​编辑1 性能调优的基本要素 2 性能监控功能 2.1 监控数据指标 2.2 数据历史记录 2.3 多种统计指标 3 性能优化策略 3.1 资源分配 3.2 磁盘性能优化 3.3 网络性能优化 3.4 操作系统参数和内核优化 结论 引言 在云计算时代,性能优化和调优对于…

Nginx 网站服务

Nginx 稳定性高 (但是没有apache稳定) 版本号:1.12 1.20 1.22 系统资源消耗低 (处理http请求的并发能力很高,单台物理服务器可以处理30000-50000个并发请求) 稳定:一般在企业中,为了保持服务器稳定,并发量…

记录 K8S 挂了的解决经过

背景:早上到公司,有同事反馈部署K8S在集群上的 Redis 和 禅道 都不可用 排查循序: 登录 kubesphere 的 web 界面 (界面打开失败)ssh 登录主服务器 (正常)在主服务器上运行 kubectl get node 命…