政安晨:【Keras机器学习实践要点】(十二)—— 迁移学习和微调

目录

设置

介绍

冻结层:了解可训练属性

可训练属性的递归设置

典型的迁移学习工作流程

微调

关于compile()和trainable的重要说明

BatchNormalization层的重要注意事项


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

本文是Keras 迁移学习和微调的完全指南文章。

设置

import numpy as np
import keras
from keras import layers
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

Keras是一种用于构建深度学习模型的高级神经网络API。迁移学习是指在一个任务上训练好的模型或特征提取器用于另一个相关任务上。

Keras迁移学习是利用预训练模型的特征提取能力来加速模型训练。预训练模型通常是在大规模数据集上训练的,并且已经提取出了一些有用的特征。迁移学习可以通过利用这些特征来降低新任务的数据需求和训练时间。

Keras提供了一些流行的预训练模型,如VGG16、ResNet50和InceptionV3等。这些模型可以直接在Keras中加载,并且可以通过设置参数来冻结一部分或全部层,以便在新任务上进行微调。

迁移学习的步骤包括加载预训练模型、修改模型结构(根据新任务的要求)并选择要冻结的层、添加新的全连接层或输出层、在新数据集上进行训练和微调模型。在训练过程中,可以选择冻结一些层,只训练新添加的层,以避免破坏原始特征提取能力。

Keras迁移学习的好处包括:

  1. 加速模型训练因为可以利用预训练模型的特征提取能力。
  2. 避免从头开始训练模型减少数据需求和计算资源。
  3. 可以应用在小数据集上而不需要大规模数据集。

总而言之,Keras迁移学习是一种利用预训练模型的特征提取能力来加速深度学习模型训练的方法。它可以帮助我们在新任务上更快、更有效地构建和训练模型。

介绍

迁移学习包括利用在一个问题上学习到的特征,在一个新的、类似的问题上加以利用。

例如,从一个识别浣熊的模型中学习到的特征可能有助于启动一个用于识别塔努基鱼的模型。

迁移学习通常用于数据集数据太少,无法从头开始训练完整模型的任务。

在深度学习中,迁移学习最常见的表现形式是以下工作流程:
 

× 从先前训练好的模型中提取图层。
× 冻结它们,以避免在未来的训练中破坏它们所包含的任何信息。
× 在冻结层上添加一些新的、可训练的层。它们将学会在新数据集上把旧特征转化为预测结果。
× 在数据集上训练新层。

最后一个可选步骤是微调,包括解冻上述获得的整个模型(或部分模型),并以极低的学习率在新数据上对其进行重新训练。通过逐步调整预训练特征以适应新数据,这有可能实现有意义的改进。

首先,我们将详细介绍 Keras 可训练 API,它是大多数迁移学习和微调工作流程的基础。

然后,我们将通过在 ImageNet 数据集上预训练模型,并在 Kaggle "猫与狗 "分类数据集上重新训练模型来演示典型的工作流程。

冻结层:了解可训练属性

层和模型有三种权重属性:

weights 是层的所有权重变量的列表。
trainable_weights(可训练权重)是指在训练过程中为了最小化损失而更新(通过梯度下降)的权重列表。
non_trainable_weights(不可训练权重)是不需要训练的权重列表。通常情况下,模型会在前向传递过程中更新这些权重。

例如密集层有 2 个可训练的权重(内核和偏置)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weightsprint("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))

结果如下:

weights: 2
trainable_weights: 2
non_trainable_weights: 0

一般来说,所有权重都是可训练权重。唯一具有不可训练权重的内置层是批归一化层。它在训练过程中使用不可训练权重来跟踪输入的均值和方差。以后咱们也会了解如何在自己的自定义层中使用不可训练权重。

示例:将可训练设置为假

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layerprint("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))

结果如下:

weights: 2
trainable_weights: 0
non_trainable_weights: 2

当可训练权重变为不可训练权重时,其值在训练过程中不再更新。

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])# Freeze the first layer
layer1.trainable = False# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(initial_layer1_weights_values[1], final_layer1_weights_values[1]
)

演绎如下:

 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 766ms/step - loss: 0.0615

请勿将 layer.trainable 属性与 layer.__call__() 中的参数 training 混淆(后者控制层是以推理模式还是训练模式运行前向传递)。

可训练属性的递归设置

如果在模型上或任何有子图层的图层上设置 trainable = False,所有子图层也将变得不可训练。

示例:

inner_model = keras.Sequential([keras.Input(shape=(3,)),keras.layers.Dense(3, activation="relu"),keras.layers.Dense(3, activation="relu"),]
)model = keras.Sequential([keras.Input(shape=(3,)),inner_model,keras.layers.Dense(3, activation="sigmoid"),]
)model.trainable = False  # Freeze the outer modelassert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

典型的迁移学习工作流程

这就引出了如何在 Keras 中实现典型的迁移学习工作流程:

实例化基础模型并加载预训练权重。
通过设置可训练 = 假,冻结基础模型中的所有层。
在基础模型中一个(或多个)层的输出基础上创建一个新模型。
在新数据集上训练新模型。

请注意,也可以采用另一种更轻便的工作流程:

实例化一个基础模型,并将预先训练好的权重载入其中。
通过它运行新数据集,并记录基础模型中一个(或多个)层的输出。这就是所谓的特征提取。
将该输出作为一个新的、更小的模型的输入数据。

第二个工作流程的主要优势在于,你只需在数据上运行一次基础模型,而不是每个历元训练一次。因此速度更快,成本更低。
 

不过,第二种工作流程的一个问题是,它无法在训练过程中动态修改新模型的输入数据,而这在进行数据扩增时是必需的。当新数据集的数据太少,无法从头开始训练一个完整的模型时,迁移学习通常会被用于这种任务,在这种情况下,数据扩增就显得非常重要。

因此,在下文中,我们将重点介绍第一种工作流程。

下面是 Keras 中的第一个工作流程:


首先,实例化一个带有预训练权重的基础模型。

base_model = keras.applications.Xception(weights='imagenet',  # Load weights pre-trained on ImageNet.input_shape=(150, 150, 3),include_top=False)  # Do not include the ImageNet classifier at the top.

然后,冻结基本模型。

base_model.trainable = False

在最开始处创建一个新模型。

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

在新数据上训练新数据。

model.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.BinaryCrossentropy(from_logits=True),metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

微调

一旦模型在新数据上收敛,你可以尝试解冻所有或部分基础模型,并使用非常低的学习率对整个模型进行端到端的重新训练。

这是一个可选的最后一步,可能会给你带来渐进的改进。但也有可能很快出现过拟合的情况,要记住这一点。

在模型的冻结层收敛之后才进行此步骤非常关键。如果将随机初始化的可训练层与保存预训练特征的可训练层混合,随机初始化的层将在训练过程中导致非常大的梯度更新,破坏预训练特征。

此阶段使用非常低的学习率也非常关键,因为你正在训练一个比第一轮训练中要大得多的模型,而且通常使用的数据集非常小。因此,如果应用大的权重更新,很容易出现过拟合的情况。在这里,你只想以增量的方式重新适应预训练的权重。

这就是如何实施整个基础模型的微调过程:

# Unfreeze the base model
base_model.trainable = True# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rateloss=keras.losses.BinaryCrossentropy(from_logits=True),metrics=[keras.metrics.BinaryAccuracy()])# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

关于compile()和trainable的重要说明

在模型上调用compile()旨在“冻结”该模型的行为。这意味着在模型编译时可训练属性的值应保持不变,直到再次调用compile()。因此,如果您更改任何可训练值,请确保再次调用compile()以使您的更改生效。

BatchNormalization层的重要注意事项

许多图像模型都包含BatchNormalization层。在各个方面,该层都是一个特例。以下是一些需要记住的事情。

1. BatchNormalization包含2个不可训练的权重,在训练过程中这些权重会更新。这些变量用于跟踪输入的均值和方差。

2. 当你设置bn_layer.trainable = False时,BatchNormalization层将以推理模式运行,并且不会更新其均值和方差统计数据。这与一般情况下的其他层不同,因为权重的可训练性和推理/训练模式是两个不相关的概念。但是,在BatchNormalization层的情况下,这两者是相互关联的。

3. 当你解冻包含BatchNormalization层的模型以进行微调时,应该在调用基本模型时将BatchNormalization层保持在推理模式中,即通过传递training=False来实现。否则,对不可训练权重的更新将突然破坏模型所学到的内容。

后面的文章中,咱们将开展一个端到端示例,在那里面,您将看到这种模式的实际应用。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/784404.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端对数据进行分组和计数处理

js对数组数据的处理,添加属性,合并表格数据。 let data[{id:1,group_id:111},{id:2,group_id:111},{id:3,group_id:111},{id:4,group_id:222},{id:5,group_id:222} ]let tempDatadata; tempDatatempData.reduce((arr,item)>{let findarr.find(i>i…

this.$route.back()时的组件缓存

1.this.$route.back()回到上一个路径会重新加载 跳转时,前一个路由的内容会被销毁,当回来时,重新创建树,组件内有保存了距离,没有一开始是0. 2.keep-alive写在router-view上面,这个地方所代表的路由会被保存,因此可以写在上面,保存,当返回时,如果是这个路由,里面的内容是一样…

CommunityToolkit.Mvvm----配置

一、介绍: CommunityToolkit.Mvvm 包(又名 MVVM 工具包,以前称为 Microsoft.Toolkit.Mvvm)是一个现代、快速和模块化的 MVVM 库。 它是 .NET 社区工具包的一部分,围绕以下原则生成: 独立于平台和运行时 - …

Android进阶学习:移动端开发重点学习的十点,不能再得过且过的写业务代码了

最近有朋友问我:“安卓开发是不是没人要了,除了画 UI 别的都不会怎么办?” 考虑到这可能是很多人共同的疑问,决定简单写一下。 说了很多遍了,**不是安卓开发没人要了,是初级安卓没人要了。**现在还在大量…

Karrier One在Sui上构建无线电话服务

Karrier One计划实现无线连接的民主化,为长期以来一直缺乏稳定服务或根本没有服务的地区提供服务,并为没有传统银行账户的人提供现代支付能力。 但是,将以行动迟缓著称的电信行业引入Web3世界是一项艰巨的任务。Karrier One团队决定利用Sui技…

保研线性代数机器学习基础复习2

1.什么是群(Group)? 对于一个集合 G 以及集合上的操作 ,如果G G-> G,那么称(G,)为一个群,并且满足如下性质: 封闭性:结合性:中性…

超强命令行解析工具Apache Commons CLI

概述 为什么要写这篇文章呢?因为在读flink cdc3.0源码的时候发现了这个工具包,感觉很牛,之前写过shell命令,shell是用getopts来处理命令行参数的,但是其实写起来很麻烦,长时间不写已经完全忘记了,现在才发现原来java也有这种工具类,所以先学习一下这个的使用,也许之后自己在写…

速通汇编(三)寄存器及汇编mul、div指令

一,寄存器及标志 AH&ALAX(accumulator):累加寄存器BH&BLBX(base):基址寄存器CH&CLCX(count):计数寄存器DH&DLDX(data):数据寄存器SP(Stack Pointer):堆栈指针寄存器BP(Base Pointer)&#…

红黑树是什么,为什么HashMap使用红黑树代替数组+链表?

前言 我们都知道在HashMap中,当数组长度大于64并且链表长度大于8时,HashMap会从数组链表的结构转换成红黑树,那为什么要转换成红黑树呢,或者为什么不一开始就使用红黑树呢?接下来我们将去具体的去剖析一下!…

java计算机网络(一)-- url,tcp,udp,socket

网络编程: 计算机网络 计算机网络指将地理位置不同的具有独立功能的多台计算机及其外部设备,通过通信线路连接起来,在网络操作系统、网络管理软件及网络通信协议的管理和协调下,实现资源共享和信息传递的计算机系统。 网络协议…

0.5米多光谱卫星影像在农业中进行地物非粮化、非农化监测

一、引言 随着科技的发展,卫星遥感技术已经成为了农业领域中重要的数据来源。其中,多光谱卫星影像以其独特的优势,在农业应用中发挥着越来越重要的作用。本文将重点探讨0.5米加2米多光谱卫星影像在农业中的应用。 二、多光谱卫星影像概述 多…

机器学习全攻略:概念、流程、分类与行业应用案例集锦

目录 1.引言 2.从零开始认识机器学习:基本概念与重要术语 3.五步走:掌握机器学习项目执行的完整流程 3.1.问题定义与数据收集 3.2.数据预处理与特征工程 3.3.模型选择与训练 3.4.模型评估与优化 3.5.模型部署与监控 4.深入了解各类机器学习方法…

Python爬虫-懂车帝城市销量榜单

前言 本文是该专栏的第23篇,后面会持续分享python爬虫干货知识,记得关注。 最近粉丝留言咨询某汽车平台的汽车销量榜单数据,本文笔者以懂车帝平台为例,采集对应的城市汽车销量榜单数据。 具体的详细思路以及代码实现逻辑,跟着笔者直接往下看正文详细内容。(附带完整代码…

pnpm比npm、yarn好在哪里?

前言 pnpm对比npm/yarn的优点: 更快速的依赖下载更高效的利用磁盘空间更优秀的依赖管理 我们按照包管理工具的发展历史,从 npm2 开始讲起: npm2 使用早期的npm1/2安装依赖,node_modules文件会以递归的形式呈现,严格…

统计子矩阵(前缀和+双指针)

题目描述 给定一个 N M 的矩阵 A,请你统计有多少个子矩阵 (最小 1 1,最大 N M) 满足子矩阵中所有数的和不超过给定的整数 K? 输入格式 第一行包含三个整数 N, M 和 K. 之后 N 行每行包含 M 个整数,代表矩阵 A. 输出格式 一个整数…

Django DRF视图

文章目录 一、DRF类视图介绍APIViewGenericAPIView类ViewSet类ModelViewSet类重写方法 二、Request与ResponseRequestResponse 参考 一、DRF类视图介绍 在DRF框架中提供了众多的通用视图基类与扩展类,以简化视图的编写。 • View:Django默认的视图基类&…

ES的RestClient相关操作

ES的RestClient相关操作 Elasticsearch使用Java操作。 本文仅介绍CURD索引库和文档!!! Elasticsearch基础:https://blog.csdn.net/weixin_46533577/article/details/137207222 Elasticsearch Clients官网:https://ww…

(文章复现)考虑分布式电源不确定性的配电网鲁棒动态重构

参考文献: [1]徐俊俊,吴在军,周力,等.考虑分布式电源不确定性的配电网鲁棒动态重构[J].中国电机工程学报,2018,38(16):4715-47254976. 1.摘要 间歇性分布式电源并网使得配电网网络重构过程需要考虑更多的不确定因素。在利用仿射数对分布式电源出力的不确定性进行合…

博客页面---前端

目录 主页 HTML CSS 文章详细页面 HTML CSS 登录页面 HTML CSS 文章编辑页 HTML CSS 这只是前端的页面组成&#xff0c;还没有接入后端&#xff0c;并不是完全体 主页 HTML <!DOCTYPE html> <!-- <html lang"en"> --> <head>&…

区间预测 | Matlab实现带有置信区间的BP神经网络时间序列未来趋势预测

区间预测 | Matlab实现带有置信区间的BP神经网络时间序列未来趋势预测 目录 区间预测 | Matlab实现带有置信区间的BP神经网络时间序列未来趋势预测预测效果基本介绍研究回顾程序设计参考资料预测效果 基本介绍 BP神经网络(Backpropagation neural network)是一种常用的人工神…