TensorFlow2实战-系列教程6:猫狗识别3------迁移学习

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

猫狗识别1
数据增强
猫狗识别2------数据增强
猫狗识别3------迁移学习

1、迁移学习

  • 用已经训练好模型的权重参数当做自己任务的模型权重初始化
  • 一般全连接层需要自己训练,可以选择是否训练已经训练好的特征提取层

一般情况下根据自己的任务,选择对那些网络进行微调和重新训练:
如果预训练模型的任务和自己任务非常接近,那可能只需要把最后的全连接层重新训练即可
如果自己任务的数据量比较小,那么应该选择重新训练少数层
如果自己任务的数据量比较大,可以适当多选择几层进行训练

2、猫狗识别

import os
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers
from tensorflow.keras import Model
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')

前面的内容和TensorFlow2实战-系列教程3:猫狗识别1完全一样

3、加载预训练模型

from tf.keras.applications.resnet import ResNet50
from tensorflow.keras.applications.resnet import ResNet101
from tensorflow.keras.applications.inception_v3 import InceptionV3

从keras中导入预训练模型,在TensorFlow的keras模块,有很多可以直接导入的预训练权重。

pre_trained_model = ResNet101(input_shape = (75, 75, 3),  include_top = False, weights = 'imagenet')
  • 加载导入的模型
  • input_shape 为输入大小
  • include_top为False就是表示不要最后的全连接层
  • 这段代码执行后,会自动进行下载

downloading data from
https://storage.googleapis.com/tensorflow/kerasapplications/resnet/resnet101_weights_tf_dim_ordering_tf_kernels_notop.h5
171446536/171446536 [==============================] - 15s 0us/step

for layer in pre_trained_model.layers:layer.trainable = False

选择要进行重新训练的层

4、callback模块

在 TensorFlow 中,回调(Callbacks)是一个强大的工具,用于在训练的不同阶段(例如在每个时代的开始和结束、在每个批次的处理前后)自定义和控制模型的行为,相当于一个监视器:

4.1 callback示例

callbacks = [
# 如果连续两个epoch还没降低就停止:tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
# 可以动态改变学习率:tf.keras.callbacks.LearningRateScheduler
# 保存模型:tf.keras.callbacks.ModelCheckpoint
# 自定义方法:tf.keras.callbacks.Callback
]

上面是一个模板,继续我们的猫狗识别的迁移学习项目:

4.2 定义callback

class myCallback(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs={}):if(logs.get('acc')>0.95):print("\nReached 95% accuracy so cancelling training!")self.model.stop_training = True
  1. 定义一个类,继承Callback
  2. 定义一个函数,传入epoch值和日志
  3. 从当前epoch的日志中取出准确率,如果准确率大于95%
  4. 打印信息
  5. 停止训练
from tensorflow.keras.optimizers import Adam
x = layers.Flatten()(pre_trained_model.output)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.2)(x)                  
x = layers.Dense(1, activation='sigmoid')(x)           
model = Model(pre_trained_model.input, x) 
model.compile(optimizer = Adam(lr=0.001), loss = 'binary_crossentropy', metrics = ['acc'])
  1. 导入优化器
  2. 将预训练模型的输出展平为一维
  3. 定义一个1024的全连接层
  4. 在这层加入dropout
  5. 输出全连接层
  6. 构建模型
  7. 指定优化器、损失函数、验证方法等配置训练器

5、模型训练

定义需要重新训练的层

train_datagen = ImageDataGenerator(rescale = 1./255.,rotation_range = 40,width_shift_range = 0.2,height_shift_range = 0.2,shear_range = 0.2,zoom_range = 0.2,horizontal_flip = True)test_datagen = ImageDataGenerator( rescale = 1.0/255. )train_generator = train_datagen.flow_from_directory(train_dir,batch_size = 20,class_mode = 'binary', target_size = (75, 75))     validation_generator =  test_datagen.flow_from_directory( validation_dir,batch_size  = 20,class_mode  = 'binary', target_size = (75, 75))

前面的内容和TensorFlow2实战-系列教程3:猫狗识别1一样,制作数据

callbacks = myCallback()
history = model.fit_generator(train_generator,validation_data = validation_generator,steps_per_epoch = 100,epochs = 100,validation_steps = 50,verbose = 2,callbacks=[callbacks])

指定训练参数、数据、加入callback模块到模型中,执行训练,verbose = 2表示每次epoch记录一次日志

打印结果:

Epoch 99/100 100/100 - 76s - loss: 0.6138 - acc: 0.6655 - val_loss: 0.6570 - val_acc: 0.6900
Epoch 100/100 100/100 - 76s - loss: 0.5993 - acc: 0.6735 - val_loss: 0.7176 - val_acc: 0.6910

6、预测效果展示

import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(acc))plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()plt.figure()plt.plot(epochs, loss, 'b', label='Training Loss')
plt.plot(epochs, val_loss, 'r', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()

展示
在这里插入图片描述
猫狗识别1
数据增强
猫狗识别2------数据增强
猫狗识别3------迁移学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/657569.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解读4篇混合类型文件Polyglot相关的论文

0. 引入 Polyglot文件指的是混合类型文件,关于混合类型文件的基础,请参考文末给出的第一个链接(参考1)。 1. Toward the Detection of Polyglot Files 1.1 主题 这篇2022年的论文,提出了Polyglot文件的检测方法。虽…

openssl3.2 - .pod文件的查看方法

文章目录 .pod文件的查看方法概述笔记初步的解决方法备注 - pod2html.bat的详细用法好像Perl就自带这个BATEND .pod文件的查看方法 概述 看到openssl源码目录下有很多.pod文件, 软件发布的帮助内容都在里面. 当make install后, 大部分的.pod都会转成html文件, 但是有一部分…

【Java程序设计】【C00215】基于SSM的勤工助学管理系统(论文+PPT)

基于SSM的勤工助学管理系统(论文PPT) 项目简介项目获取开发环境项目技术运行截图 项目简介 这个一个基于SSM的勤工助学管理系统,本系统共分为三种权限:管理员、教师和学生 管理员:首页、个人中心、教师管理、学生管理…

gdp调试—Linux

目录 介绍 使用 介绍 代码分为debug模式和release模式 如果一份代码要被调试,这份代码必须是debug Linux下编译代码默认是是release模式 如果你想代码是debug模式 必须加上 - g 小提: vim默认:命令模式 gcc默认:releas…

操作系统--进程、线程基础知识

一、进程 我们编写的代码只是一个存储在硬盘的静态文件,通过编译后就会生成二进制可执行文件,当我们运行这个可执行文件后,它会被装载到内存中,接着 CPU 会执行程序中的每一条指令,那么这个运行中的程序,就…

ModelArts加速识别,助力新零售电商业务功能的实现

前言 如果说为客户提供最好的商品是产品眼中零售的本质,那么用户的思维是什么呢? 在用户眼中,极致的服务体验与优质的商品同等重要。 企业想要满足上面两项服务,关键在于提升效率,也就是需要有更高效率的零售&#…

C++ //练习 3.8 分别用while循环和传统的for循环重写第一题的程序,你觉得哪种形式更好呢?为什么?

C Primer(第5版) 练习 3.8 练习 3.8 分别用while循环和传统的for循环重写第一题的程序,你觉得哪种形式更好呢?为什么? 环境:Linux Ubuntu(云服务器) 工具:vim 代码块 /********…

【三】【C++】类与对象(二)

类的六个默认成员函数 在C中,有六个默认成员函数,它们是编译器在需要的情况下自动生成的成员函数,如果你不显式地定义它们,编译器会自动提供默认实现。这些默认成员函数包括: 默认构造函数 (Default Constructor)&…

C++ 数论相关题目 博弈论:拆分-Nim游戏

给定 n 堆石子,两位玩家轮流操作,每次操作可以取走其中的一堆石子,然后放入两堆规模更小的石子(新堆规模可以为 0 ,且两个新堆的石子总数可以大于取走的那堆石子数),最后无法进行操作的人视为失…

PMP中的数据收集工具:打开项目成功的钥匙

在项目管理中,数据收集是关键的一环。准确、及时的数据能够为项目决策提供可靠的依据,帮助项目经理更好地监控项目进展、识别潜在风险,并制定有效的应对措施。本文将深入探讨PMP(项目管理专业)中常用的数据收集工具&am…

力扣题目训练(6)

2024年1月30日力扣题目训练 2024年1月30日力扣题目训练367. 有效的完全平方数374. 猜数字大小383. 赎金信99. 恢复二叉搜索树105. 从前序与中序遍历序列构造二叉树51. N 皇后 2024年1月30日力扣题目训练 2024年1月30日第六天编程训练,今天主要是进行一些题训练&…

在ubuntu上在安装Squid代理服务器

Squid 是一个代理和缓存服务器,它将请求转发到所需的目的地,同时保存请求的内容,当你再次请求相同内容时,他可以向你提供缓冲内容,从而提高访问速度。Squid代理服务器目前支持的协议有:http、SSL、DNS、FTP…

App测试中ios和Android有哪些区别呢?

App测试中,大家最常问到的问题就是:ios和 Android有什么区别呢? 在Android端,我们经常会使用 JavaScript、 HTML、 CSS等技术来编写一些简单的 UI界面。而 iOS端,我们经常会使用到 UI设计、界面布局、代码结构、 API等…

C++——特殊类

特殊类 文章目录 特殊类一、请设计一个类,不能被拷贝二、请设计一个类,只能在堆上创建对象方案一:析构函数私有化方案二:构造函数私有化 三、请设计一个类,只能在栈上创建对象四、请设计一个类,不能被继承五…

微软Office Plus与WPS Office的较量:办公软件市场将迎来巨变?

微软Office Plus在功能表现上远超WPS Office? 微软出品的Office套件实力强劲,其不仅在办公场景中扮演着不可或缺的角色,为用户带来高效便捷的体验,而且在娱乐生活管理等多元领域中同样展现出了卓越的应用价值 作为中国本土办公软…

Leetcode 第 381 场周赛题解

Leetcode 第 381 场周赛题解 Leetcode 第 381 场周赛题解题目1:3014. 输入单词需要的最少按键次数 I思路代码复杂度分析 题目2:3015. 按距离统计房屋对数目 I思路代码复杂度分析 题目3:3016. 输入单词需要的最少按键次数 II思路代码复杂度分析…

HarmonyOS使用Web组件加载页面

1、加载网络页面 在Web组件创建时,指定默认加载的网络页面 。在默认页面加载完成后,如果开发者需要变更此Web组件显示的网络页面,可以通过调用loadUrl()接口加载指定的网页。 默认在Web组件加载完“www.baidu.com”页面后,点击按…

html+js+css静态故宫主题

登录代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, initial-scale1.0" /><title>登录 - 故宫博物院</title><…

C语言菜鸟入门·函数

目录 1. 函数的定义 2. 函数声明 3. 函数调用 4. 函数参数 4.1 传值调用 4.2 引用调用 函数是一组一起执行一个任务的语句。每个 C 程序都至少有一个函数&#xff0c;即主函数 main() &#xff0c;所有简单的程序都可以定义其他额外的函数。 您可以把代码划分到不同…

【C++】C++入门—— 引用

引用 1 前情提要2 概念剖析3 引用特性4 常引用5 使用场景5.1做参数5.2 做返回值 6 传值 传引用的效率比较7 引用与指针的差异Thanks♪(&#xff65;ω&#xff65;)&#xff89;谢谢阅读下一篇文章见 1 前情提要 在C语言中&#xff0c;我们往往会遇见复杂的指针&#xff08;如…