【TensorFlow2 之012】TF2.0 中的 TF 迁移学习

#012 TensorFlow 2.0 中的 TF 迁移学习

一、说明

        在这篇文章中,我们将展示如何在不从头开始构建计算机视觉模型的情况下构建它。迁移学习背后的想法是,在大型数据集上训练的神经网络可以将其知识应用于以前从未见过的数据集。也就是说,为什么它被称为迁移学习;我们将现有模型的学习转移到新的数据集中。

教程概述:

  1. 介绍
  2. 使用内置的 TensorFlow 模型进行迁移学习
  3. 使用 TensorFlow Hub 进行迁移学习

二、为什么迁移学习简介

之        我们已经探讨了如何使用数据增强来提高模型性能。现在的问题是,“如果我们没有足够的数据来从头开始训练我们的网络怎么办?

对        此的解决方案是使用迁移学习方法。一篇更具理论意义的帖子已经发布在我们的博客上。如果需要,请查看它以刷新一些想法。我们可以使用迁移学习将知识从一些预先训练好的开源网络转移到我们自己的简历问题中。

计        算机视觉研究社区在互联网上发布了许多数据集,如Imagenet或MS Coco或Pascal数据集。许多计算机视觉研究人员已经在这些数据集上训练了他们的算法。有时,此培训需要数周时间,并且可能需要许多 GPU。事实上,其他人已经完成了这项任务并经历了痛苦的高性能研究过程,这意味着我们经常可以下载开源权重。

有        很多网络已经过训练。例如,Imagenet 数据集,它由 1000 个不同的类和超过 14 万张图像组成。因此,网络可能有一个 softmax 单元,它输出一千个可能的类之一。我们可以做的是摆脱softmax层并创建我们自己的输出单元来表示例如猫或

        由于我们使用下载的权重,我们将只训练与我们的输出层关联的参数,在我们的例子中,这将是一个 sigmoid 输出层。

三、使用内置的 TensorFlow 模型进行迁移学习

        让我们首先为训练准备数据集。我们将使用 wget.download 命令下载数据集。之后,我们需要解压缩它并合并训练和测试部分的路径。

import os
import wget
import zipfile
wget.download("https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip")with zipfile.ZipFile("cats_and_dogs_filtered.zip","r") as zip_ref:zip_ref.extractall()base_dir = 'cats_and_dogs_filtered'train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')

        现在,让我们导入所有必需的库并构建模型。我们将使用一个名为MobileNetV2的预训练网络,该网络在ImageNet数据集上进行训练。在这里,我们希望使用除顶部分类图层之外的所有图层,因此我们不会将它们包含在我们的网络中。

        obileNetV2 体系结构概述:https://ai.googleblog.com/2018/04/mobilenetv2-next-generation-of-on.html

这        个模型和其他预训练模型已经在TensorFlow中可用。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.image as mpimgfrom tensorflow.keras import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.preprocessing.image import ImageDataGeneratorbase_model = MobileNetV2(input_shape=(224, 224, 3),include_top=False,weights='imagenet')base_model.summary()


输出:
Model: "mobilenetv2_1.00_224"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
Conv1_pad (ZeroPadding2D)       (None, 225, 225, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 112, 112, 32) 864         Conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 112, 112, 32) 128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 112, 112, 32) 0           bn_Conv1[0][0]                   
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 112, 112, 32) 288         Conv1_relu[0][0]  

        果最后一层输出不同数量的类,那么我们需要有自己的输出单元来输出以下类:。有几种方法可以做到这一点:

  • 取最后几层的权重,将它们用作初始化并进行梯度下降。这样,我们将重新训练网络的一部分。
  • 删除最后几层的权重,使用我们自己的新隐藏单元和我们自己的最终 sigmoid(或 softmax)输出。通过这种方式,我们可以更改输出的数量。

        因此,这两种方法中的任何一种都值得尝试。

        现在让我们冻结预训练层并添加一个新层,称为 GlobalAveragePooling2D,之后是具有 sigmoid 激活函数的 Dense 层。

现        

base_model.trainable = Falsemodel = Sequential([base_model,GlobalAveragePooling2D(),Dense(1, activation='sigmoid')])
model.summary()

    

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
mobilenetv2_1.00_224 (Model) (None, 7, 7, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________

          现在是训练步骤的时候了。我们将使用图像数据生成器来迭代图像。现在没有必要为大量时期训练网络,因为我们已经预先训练了一部分网络。

model.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy'])train_datagen = ImageDataGenerator(rescale=1./255)
val_datagen = ImageDataGenerator(rescale=1./255)train_generator = train_datagen.flow_from_directory(train_dir,target_size=(224, 224),batch_size=32,class_mode='binary')validation_generator = val_datagen.flow_from_directory(validation_dir,target_size=(224, 224),batch_size=32,class_mode='binary')history = model.fit(train_generator,epochs=6,validation_data=validation_generator,verbose=2)

让        我们看看结果。

accuracy = history.history['accuracy']
val_accuracy = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(accuracy))plt.plot(epochs, accuracy, label="Training")
plt.plot(epochs, val_accuracy, label="Validation")
plt.legend()
plt.title('Training and validation accuracy')
plt.figure()plt.plot(epochs, loss, label="Training")
plt.plot(epochs, val_loss, label="Validation")
plt.legend()
plt.title('Training and validation loss')

3.Text(0.5, 1.0, 'Training and validation loss')

四、 使用 TensorFlow Hub 进行迁移学习

        访问预训练模型的另一种方法是TensorFlow Hub。TensorFlow Hub是一个库,用于发布,发现和使用机器学习模型的可重用部分。您可以在此处找到更多预训练模型。

        我们将冻结图层并添加一个用于分类的新图层。全连接网络的输入称为瓶颈要素。它们表示网络中最后一个卷积层的激活图。

我们为给定数量的 epoch 训练模型。

        在这里,我们也可以使用 TensorBoard,但这次让我们保持简单。

        最后,让我们可视化一些预测。在此数据集中,猫标记为 0,狗标记为 1。我们将用蓝色显示正确的预测,用红色显示假。

预测

五、总结

        正如我们在上面看到的,使用迁移学习可以帮助我们在短时间内取得非常好的结果。使用数据增强,可以进一步增强结果。

         在下一篇文章中,我们将展示如何创建网络并将其转换以在移动设备上使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/101945.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux 安装mysql

1、下载mysql安装包 2、创建mysql文件夹 mkdir /usr/local/mysql 3、解压mysql安装包,并将解压出来的文件夹下面的内容全部移动到/usr/local/mysql下 解压 tar zxvf mysql-5.7.39-linux-glibc2.12-x86_64.tar.gz 移动 mv /usr/local/src/mysql-5.7.39-linux-gl…

Unity设计模式——装饰模式

装饰模式(Decorator),动态地给一个对象添加一些额外的职责,就增加功能来说,装饰模式比生成子类更为灵活。 Component类: abstract class Component : MonoBehaviour {public abstract void Operation(); …

华为认证 | HCIP-Datacom,这门认证正式发布新版本!

华为认证数通高级工程师HCIP-Datacom-Campus Network Planning and Deployment V1.5(中文版)自2023年9月28日起,正式在中国区发布。 01 发布概述 基于“平台生态”战略,围绕“云-管-端”协同的新ICT技术架构,华为公司…

互联网性能和可用性优化CDN和DNS

当涉及到互联网性能和可用性优化时,DNS(Domain Name System)和CDN(Content Delivery Network)是两个至关重要的元素。它们各自发挥着关键作用,以确保用户能够快速、可靠地访问网站和应用程序。在本文中&…

软件工程与计算总结(六)需求分析方法

本贴介绍需求分析方法,涉及到诸多实践性的东西,掌握各种图表的绘制是重中之重~ 一.需求分析基础 1.原因 需求获取中得到的信息仅仅解释了用户对软件系统的理解与期待,使用的是实际业务的表达方式,还不是开发者能够立即加以实现…

【算法|动态规划No.16】leetcode931. 下降路径最小和

个人主页:兜里有颗棉花糖 欢迎 点赞👍 收藏✨ 留言✉ 加关注💓本文由 兜里有颗棉花糖 原创 收录于专栏【手撕算法系列专栏】【LeetCode】 🍔本专栏旨在提高自己算法能力的同时,记录一下自己的学习过程,希望…

VSCODE+PHP8.2配置踩坑记录

VSCODEPHP8.2配置踩坑记录 – WhiteNights Site 我配置过的最恶心的环境之一:windows上的php。另一个是我centos服务器上的php。 进不了断点 端口配置和xdebug的安装 这个应该是最常见的问题了。从网上下载完php并解压到本地,打开vscode,安装…

基于SpringBoot的网上摄影工作室

目录 前言 一、技术栈 二、系统功能介绍 用户信息管理 作品分类管理 轮播图管理 摄影作品管理 摄影作品收藏 摄影圈 摄影作品发布 三、核心代码 1、登录模块 2、文件上传模块 3、代码封装 前言 随着信息技术在管理上越来越深入而广泛的应用,管理信息系统…

华为云云耀云服务器L实例评测|华为云上的CentOS性能监测与调优指南

目录 引言 ​编辑1 性能调优的基本要素 2 性能监控功能 2.1 监控数据指标 2.2 数据历史记录 2.3 多种统计指标 3 性能优化策略 3.1 资源分配 3.2 磁盘性能优化 3.3 网络性能优化 3.4 操作系统参数和内核优化 结论 引言 在云计算时代,性能优化和调优对于…

Qt QPair

QPair 文章目录 QPair 摘要QPairQPair 特点代码示例QPair 与 QMap 区别 关键字: Qt、 QPair、 QMap、 键值、 容器 摘要 今天在观摩小伙伴撸代码的时候,突然听到了QPair自己使用Qt开发这么就,竟然都不知道,所以趁没有被人发…

3、在 CentOS 8 系统上安装 PostgreSQL 15.4

PostgreSQL,作为一款备受欢迎的开源关系数据库管理系统(RDBMS),已经存在了三十多年的历史。它提供了SQL语言支持,用于管理数据库和执行CRUD操作(创建、读取、更新、删除)。 由于其卓越的健壮性…

k8s 集群部署 kubesphere

一、最小化部署 kubesphere 1、在已有的 Kubernetes 集群上部署 KubeSphere,下载 YAML 文件: wget https://github.com/kubesphere/ks-installer/releases/download/v3.4.0/kubesphere-installer.yaml wget https://github.com/kubesphere/ks-installer/releases/…

iPhone手机记笔记工具选择用哪个

iPhone手机大家应该都比较熟悉,其使用性能是比较流畅的,在iPhone手机上记录笔记可以帮助大家快速地进行总结工作、记录工作内容等,在iPhone手机上记笔记工具选择用哪个呢? 可以在iPhone手机上使用的笔记工具是比较多的&#xff0…

Andriod学习笔记(一)

写在前面的话 App开发的编程语言Java和KotlinXML App连接的数据库App工程目录结构模块级别的编译配置文件清单文件 界面显示与逻辑处理 安卓是一种基于Linux内核的自由及开放源代码的操作系统,主要使用于移动设备。 Mininum SDK表示安卓该版本以上的设备都可以运行该…

在Openresty中使用lua语言向请求浏览器返回请求头User-Agent里边的值

可以参考《Linux学习之Ubuntu 20.04在https://openresty.org下载源码安装Openresty 1.19.3.1,使用systemd管理OpenResty服务》安装Openresty。 然后把下边的内容写入到openresty配置文件/usr/local/openresty/nginx/conf/nginx.conf(根据实际情况进行选…

基于SSM的网络安全宣传网站设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用JSP技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

医疗机构又进化了,一招搞定UPS设备太爽了!

在现代社会中,能源供应的可靠性至关重要,不间断电源(UPS)系统是维持关键设备运行的关键组成部分。UPS监控是一种重要的技术,用于实时监测UPS的性能、电池状态和电力质量。 客户案例 四川某医院是一家大型医疗机构&…

应对广告虚假流量,app广告变现该如何风控?

移动广告市场中的虚假流量一直是困扰各移动应用厂商的难题,广告作为app商业化变现最为直接快捷的途径,也引申出了流量作弊与反作弊的纷争。 根据《2021中国异常流量报告》,2021年中国品牌广告市场因异常流量造成的损失约为326亿人民币&#…

适合学生写作业的台灯有哪些?高品质学生读写台灯推荐

不得不说如今我国青少年儿童的近视率还是非常高的,据国家卫健委疾控局数据,我国儿童青少年总体近视率为52.7%,其中6岁儿童为14.3%,小学生为35.6%,初中生为71.1%,高中生为80.5%,造成近视的原因不…

2. redis常见数据类型

一、Redis 数据类型 Redis支持五种数据类型:string(字符串),hash(哈希),list(列表),set(集合)及zset(sorted set:有序集合…