TensorFlow2实战-系列教程6:迁移学习实战

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

1、迁移学习

  • 用已经训练好模型的权重参数当做自己任务的模型权重初始化
  • 一般全连接层需要自己训练,可以选择是否训练已经训练好的特征提取层

一般情况下根据自己的任务,选择对那些网络进行微调和重新训练:
如果预训练模型的任务和自己任务非常接近,那可能只需要把最后的全连接层重新训练即可
如果自己任务的数据量比较小,那么应该选择重新训练少数层
如果自己任务的数据量比较大,可以适当多选择几层进行训练

2、猫狗识别

import os
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers
from tensorflow.keras import Model
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')

前面的内容和TensorFlow2实战-系列教程3:猫狗识别1完全一样

3、加载预训练模型

from tf.keras.applications.resnet import ResNet50
from tensorflow.keras.applications.resnet import ResNet101
from tensorflow.keras.applications.inception_v3 import InceptionV3

从keras中导入预训练模型,在TensorFlow的keras模块,有很多可以直接导入的预训练权重。

pre_trained_model = ResNet101(input_shape = (75, 75, 3),  include_top = False, weights = 'imagenet')
  • 加载导入的模型
  • input_shape 为输入大小
  • include_top为False就是表示不要最后的全连接层
  • 这段代码执行后,会自动进行下载

downloading data from
https://storage.googleapis.com/tensorflow/kerasapplications/resnet/resnet101_weights_tf_dim_ordering_tf_kernels_notop.h5
171446536/171446536 [==============================] - 15s 0us/step

for layer in pre_trained_model.layers:layer.trainable = False

选择要进行重新训练的层

4、callback模块

在 TensorFlow 中,回调(Callbacks)是一个强大的工具,用于在训练的不同阶段(例如在每个时代的开始和结束、在每个批次的处理前后)自定义和控制模型的行为,相当于一个监视器:

4.1 callback示例

callbacks = [
# 如果连续两个epoch还没降低就停止:tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
# 可以动态改变学习率:tf.keras.callbacks.LearningRateScheduler
# 保存模型:tf.keras.callbacks.ModelCheckpoint
# 自定义方法:tf.keras.callbacks.Callback
]

上面是一个模板,继续我们的猫狗识别的迁移学习项目:

4.2 定义callback

class myCallback(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs={}):if(logs.get('acc')>0.95):print("\nReached 95% accuracy so cancelling training!")self.model.stop_training = True
  1. 定义一个类,继承Callback
  2. 定义一个函数,传入epoch值和日志
  3. 从当前epoch的日志中取出准确率,如果准确率大于95%
  4. 打印信息
  5. 停止训练
from tensorflow.keras.optimizers import Adam
x = layers.Flatten()(pre_trained_model.output)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.2)(x)                  
x = layers.Dense(1, activation='sigmoid')(x)           
model = Model(pre_trained_model.input, x) 
model.compile(optimizer = Adam(lr=0.001), loss = 'binary_crossentropy', metrics = ['acc'])
  1. 导入优化器
  2. 将预训练模型的输出展平为一维
  3. 定义一个1024的全连接层
  4. 在这层加入dropout
  5. 输出全连接层
  6. 构建模型
  7. 指定优化器、损失函数、验证方法等配置训练器

5、模型训练

定义需要重新训练的层

train_datagen = ImageDataGenerator(rescale = 1./255.,rotation_range = 40,width_shift_range = 0.2,height_shift_range = 0.2,shear_range = 0.2,zoom_range = 0.2,horizontal_flip = True)test_datagen = ImageDataGenerator( rescale = 1.0/255. )train_generator = train_datagen.flow_from_directory(train_dir,batch_size = 20,class_mode = 'binary', target_size = (75, 75))     validation_generator =  test_datagen.flow_from_directory( validation_dir,batch_size  = 20,class_mode  = 'binary', target_size = (75, 75))

前面的内容和TensorFlow2实战-系列教程3:猫狗识别1一样,制作数据

callbacks = myCallback()
history = model.fit_generator(train_generator,validation_data = validation_generator,steps_per_epoch = 100,epochs = 100,validation_steps = 50,verbose = 2,callbacks=[callbacks])

指定训练参数、数据、加入callback模块到模型中,执行训练,verbose = 2表示每次epoch记录一次日志

打印结果:

Epoch 99/100 100/100 - 76s - loss: 0.6138 - acc: 0.6655 - val_loss: 0.6570 - val_acc: 0.6900
Epoch 100/100 100/100 - 76s - loss: 0.5993 - acc: 0.6735 - val_loss: 0.7176 - val_acc: 0.6910

6、预测效果展示

import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(acc))plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()plt.figure()plt.plot(epochs, loss, 'b', label='Training Loss')
plt.plot(epochs, val_loss, 'r', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()

展示
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/653327.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【机器学习】工程实践问题概述

机器学习实际应用时的工程问题与面临的挑战 一、实现细节问题 1.1 训练样本 训练样本与标注对各类机器学习算法和模型的精度影响 训练样本的选择对各类机器学习算法和模型的影响 训练样本的优化 如何进行数据增强? 如何进行数据清洗? 样本的标注对各类机…

数据结构(二)------单链表

制作不易,三连支持一下呗!!! 文章目录 前言一.什么是链表二.链表的分类三.单链表的实现总结 前言 上一节,我们介绍了顺序表的实现与一些经典算法。 但是顺序表这个数据结构依然有不少缺陷: 1.顺序表指定…

导航页配置服务Dashy本地部署并实现公网远程访问

文章目录 简介1. 安装Dashy2. 安装cpolar3.配置公网访问地址4. 固定域名访问 简介 Dashy 是一个开源的自托管的导航页配置服务,具有易于使用的可视化编辑器、状态检查、小工具和主题等功能。你可以将自己常用的一些网站聚合起来放在一起,形成自己的导航…

基于springboot宠物领养系统

摘要 随着社会的不断发展和人们生活水平的提高,宠物在家庭中的地位逐渐上升,宠物领养成为一种流行的社会现象。为了更好地管理和促进宠物领养的过程,本文基于Spring Boot框架设计和实现了一套宠物领养系统。该系统以用户友好的界面为特点&…

时序分析中的去趋势化方法

时序分析中的去趋势化方法 时序分析是研究随时间变化的数据模式的一门学科。在时序数据中,趋势是一种随着时间推移而呈现的长期变化趋势,去趋势化是为了消除或减弱这种趋势,使数据更具平稳性。本文将简单介绍时序分析中常用的去趋势化方法&a…

跟着cherno手搓游戏引擎【13】着色器(shader)

创建着色器类&#xff1a; shader.h:初始化、绑定和解绑方法&#xff1a; #pragma once #include <string> namespace YOTO {class Shader {public:Shader(const std::string& vertexSrc, const std::string& fragmentSrc);~Shader();void Bind()const;void Un…

怎样自行搭建幻兽帕鲁游戏联机服务器?

幻兽帕鲁是一款深受玩家喜爱的多人在线游戏&#xff0c;为了获取更好的游戏体验&#xff0c;许多玩家希望能够自行搭建幻兽帕鲁游戏联机服务器&#xff0c;本文将指导大家如何自行搭建幻兽帕鲁游戏联机服务器。 自行搭建幻兽帕鲁游戏联机服务器&#xff0c;阿里云是一个不错的选…

结构体的增删查改

结构体&#xff0c;是为了解决生活中的一些不方便利用c语言自带数据类型来表示的问题。例如表示一个学生&#xff0c;那么学生这个个体假如用c语言自带数据类型怎么表示呢。可以使用名字&#xff0c;也就是字符数组&#xff1b;也可以使用学号&#xff0c;也就是int类型。但是这…

iOS 面试 Swift基础题

一、Swift 存储属性和计算属性比较&#xff1a; 存储型属性:用于存储一个常量或者变量 计算型属性: 计算性属性不直接存储值,而是用 get / set 来取值 和 赋值,可以操作其他属性的变化. 计算属性可以用于类、结构体和枚举&#xff0c;存储属性只能用于类和结构体。存储属性可…

检测头篇 | 原创自研 | YOLOv8 更换 SEResNeXtBottleneck 头 | 附详细结构图

左图:ResNet 的一个模块。右图:复杂度大致相同的 ResNeXt 模块,基数(cardinality)为32。图中的一层表示为(输入通道数,滤波器大小,输出通道数)。 1. 思路 ResNeXt是微软研究院在2017年发表的成果。它的设计灵感来自于经典的ResNet模型,但ResNeXt有个特别之处:它采用…

MySQL-窗口函数 简单易懂

窗口函数 考查知识点&#xff1a; • 如何用窗口函数解决排名问题、Top N问题、前百分之N问题、累计问题、每组内比较问题、连续问题。 什么是窗口函数 窗口函数也叫作OLAP&#xff08;Online Analytical Processing&#xff0c;联机分析处理&#xff09;函数&#xff0c;可…

Android 基础技术——列表卡顿问题如何分析解决

笔者希望做一个系列&#xff0c;整理 Android 基础技术&#xff0c;本章是关于列表卡顿问题如何分析解决 onBindViewHolder 优化 是否有耗时操作、重复创建对象、设置监听器、findViewByID、局部的动画对象等操作 是否存在内存泄漏 发生内存泄露&#xff0c;会导致一些不再使用…

游戏开发丨基于Tkinter的扫雷小游戏

文章目录 写在前面扫雷小游戏需求分析程序设计程序分析运行结果系列文章写在后面 写在前面 本期内容 基于tkinter的扫雷小游戏 所需环境 pythonpycharm或anaconda 下载地址 https://download.csdn.net/download/m0_68111267/88790713 扫雷小游戏 扫雷是一款广为人知的单…

RabbitMQ“延时队列“

1.RabbitMQ"延时队列" 延迟队列存储的对象是对应的延迟消息&#xff0c;所谓“延迟消息”是指当消息被发送以后&#xff0c;并不想让消费者立刻拿到消息&#xff0c;而是等待特定时间后&#xff0c;消费者才能拿到这个消息进行消费 注意RabbitMQ并没有延时队列慨念,…

OpenCV-29 自适应阈值二值化

一、引入 在前面的部分我们使用的是全局阈值&#xff0c;整幅图像采用同一个数作为阈值。当时这种方法并不适应于所有情况。尤其是当同一幅图像上的不同部分具有不同的亮度时。这种情况下我们需要采用自适应阈值。此时的阈值时根据图像上的每一个小区域计算与其对应的阈值。因此…

【幻兽帕鲁】开服务器,高性能高带宽(100mbps),免费!!!【学生党强推】

【幻兽帕鲁】开服务器&#xff0c;高性能高带宽&#xff08;100mbps&#xff09;&#xff0c;免费&#xff01;&#xff01;&#xff01;【学生党强推】 教程相关视频地址&#xff1a;https://www.bilibili.com/video/BV16e411Y7Fd/ 目前幻兽帕鲁开服务器有以下几套比较性价比的…

研发日记,Matlab/Simulink避坑指南(九)——可变数组应用Bug

文章目录 前言 背景介绍 问题描述 分析排查 解决方案 总结归纳 前言 见《研发日记&#xff0c;Matlab/Simulink避坑指南(四)——transpose()转置函数Bug》 见《研发日记&#xff0c;Matlab/Simulink避坑指南(五)——CAN解包 DLC Bug》 见《研发日记&#xff0c;Matlab/Si…

qemu + vscode图形化调试linux kernel

一、背景 使用命令行连接gdb 在调试时&#xff0c;虽然可以通过tui enable 显示源码&#xff0c;但还是存在设置断点麻烦&#xff08;需要对着源码设置&#xff09;&#xff0c;terminal显示代码不方便&#xff0c;不利于我们学习&#xff1b;另外在gdb 下p命令显示结构体内容…

重构改善既有代码的设计-学习(六):处理继承关系

1、函数上移&#xff08;Pull Up Method&#xff09; 无论何时&#xff0c;只要系统内出现重复&#xff0c;你就会面临“修改其中一个却未能修改另一个”的风险。通常&#xff0c;找出重复也有一定的难度。 所以&#xff0c;某个函数在各个子类中的函数体都相同&#xff08;它们…

Pandas--数据结构 - Series(3)

Pandas Series 类似表格中的一个列&#xff08;column&#xff09;&#xff0c;类似于一维数组&#xff0c;可以保存任何数据类型。 Series 特点&#xff1a; 索引&#xff1a; 每个 Series 都有一个索引&#xff0c;它可以是整数、字符串、日期等类型。如果没有显式指定索引&…