深度学习_微调_7

目标

  • 微调的原理
  • 利用微调模型来完成图像的分类任务

微调的原理

微调(Fine-tuning)是一种在深度学习中广泛应用的技术,特别是在预训练模型(Pretrained-Models)的基础上进行定制化训练的过程。微调的基本原理和步骤如下:

  1. 预训练阶段

    • 微调通常始于一个已经在大规模数据集上预训练过的模型,例如预训练的神经网络模型(如BERT、GPT系列、Vision Transformer等)。
    • 预训练模型在诸如自然语言处理(NLP)或计算机视觉(CV)等任务上进行了自我监督学习或有监督学习,从而习得了大量的通用特征表示。
  2. 适应新任务

    • 当需要解决与预训练任务相关的但更加具体的新任务时,就可以利用微调技术。
    • 对于NLP任务,可能是训练模型去回答问题、生成文本或者分类;对于CV任务,可能是对特定种类的物体进行识别或定位。
  3. 模型结构调整

    • 根据新任务的需求,可能需要对预训练模型的部分或全部层进行微调。
    • 通常在迁移学习中,预训练模型的底层可以捕获非常通用的特征,因此通常会被保留并微调,而顶层(尤其是分类层)会替换为新的、与新任务适配的输出层。
  4. 参数更新

    • 微调过程中,模型在新任务的数据集上重新开始训练,不过并非从随机初始化参数开始,而是从预训练模型的参数开始。
    • 学习率通常会设置得相对较低,以免破坏预训练模型学到的良好特征表示。
    • 其他优化参数(如权重衰减、批归一化层的状态等)也可能根据新任务的特点进行调整。
  5. 学习过程

    • 在目标数据集上训练模型时,不仅会更新新增的输出层参数,也会对预训练模型的某些层参数进行微调,使其更好地适应新任务的数据分布和特征。
    • 由于预训练模型已经具备良好的初始化,因此在相对较小的数据集上进行微调时,模型往往能够更快收敛到较好的解。

总结起来,微调的原理是利用预训练模型中的已学知识作为初始状态,通过对新任务数据的训练,对模型参数进行针对性的更新和优化,从而使模型能够适应新的应用场景。相较于从头训练,微调大大减少了所需的训练时间和数据量,提高了模型在特定任务上的性能和泛化能力。

1.微调

如何在只有6万张图像的MNIST训练数据集上训练模型。学术界当下使用最广泛的大规模图像数据集ImageNet,它有超过1,000万的图像和1,000类的物体。然而,我们平常接触到数据集的规模通常在这两者之间。假设我们想从图像中识别出不同种类的椅子,然后将购买链接推荐给用户。一种可能的方法是先找出100种常见的椅子,为每种椅子拍摄1,000张不同角度的图像,然后在收集到的图像数据集上训练一个分类模型。另外一种解决办法是应用迁移学习(transfer learning),将从源数据集学到的知识迁移到目标数据集上。例如,虽然ImageNet数据集的图像大多跟椅子无关,但在该数据集上训练的模型可以抽取较通用的图像特征,从而能够帮助识别边缘、纹理、形状和物体组成等。这些类似的特征对于识别椅子也可能同样有效。

微调由以下4步构成。

  1. 在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型。
  2. 创建一个新的神经网络模型,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在目标模型中不予采用。
  3. 为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层的模型参数。
  4. 在目标数据集(如椅子数据集)上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的。

当目标数据集远小于源数据集时,微调有助于提升模型的泛化能力。

2.热狗识别

接下来我们来实践一个具体的例子:热狗识别。将基于一个小数据集对在ImageNet数据集上训练好的ResNet模型进行微调。该小数据集含有数千张热狗或者其他事物的图像。我们将使用微调得到的模型来识别一张图像中是否包含热狗。

首先,导入实验所需的工具包。

import tensorflow as tf
import numpy as np

2.1 获取数据集

我们首先将数据集放在路径hotdog/data之下:

每个类别文件夹里面是图像文件。

上一节中我们介绍了ImageDataGenerator进行图像增强,我们可以通过以下方法读取图像文件,该方法以文件夹路径为参数,生成经过图像增强后的结果,并产生batch数据:

flow_from_directory(self, directory,target_size=(256, 256), color_mode='rgb',classes=None, class_mode='categorical',batch_size=32, shuffle=True, seed=None,save_to_dir=None)

主要参数:

  • directory: 目标文件夹路径,对于每一个类对应一个子文件夹,该子文件夹中任何JPG、PNG、BNP、PPM的图片都可以读取。
  • target_size: 默认为(256, 256),图像将被resize成该尺寸。
  • batch_size: batch数据的大小,默认32。
  • shuffle: 是否打乱数据,默认为True。

我们创建两个tf.keras.preprocessing.image.ImageDataGenerator实例来分别读取训练数据集和测试数据集中的所有图像文件。将训练集图片全部处理为高和宽均为224像素的输入。此外,我们对RGB(红、绿、蓝)三个颜色通道的数值做标准化。

# 获取数据集
import pathlib
train_dir = 'transferdata/train'
test_dir = 'transferdata/test'
# 获取训练集数据
train_dir = pathlib.Path(train_dir)
train_count = len(list(train_dir.glob('*/*.jpg')))
# 获取测试集数据
test_dir = pathlib.Path(test_dir)
test_count = len(list(test_dir.glob('*/*.jpg')))
# 创建imageDataGenerator进行图像处理
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
# 设置参数
BATCH_SIZE = 32
IMG_HEIGHT = 224
IMG_WIDTH = 224
# 获取训练数据
train_data_gen = image_generator.flow_from_directory(directory=str(train_dir),batch_size=BATCH_SIZE,target_size=(IMG_HEIGHT, IMG_WIDTH),shuffle=True)
# 获取测试数据
test_data_gen = image_generator.flow_from_directory(directory=str(test_dir),batch_size=BATCH_SIZE,target_size=(IMG_HEIGHT, IMG_WIDTH),shuffle=True)

下面我们随机取1个batch的图片然后绘制出来。

import matplotlib.pyplot as plt
# 显示图像
def show_batch(image_batch, label_batch):plt.figure(figsize=(10,10))for n in range(15):ax = plt.subplot(5,5,n+1)plt.imshow(image_batch[n])plt.axis('off')
# 随机选择一个batch的图像        
image_batch, label_batch = next(train_data_gen)
# 图像显示
show_batch(image_batch, label_batch)

2.2 模型构建与训练

我们使用在ImageNet数据集上预训练的ResNet-50作为源模型。这里指定weights='imagenet'来自动下载并加载预训练的模型参数。在第一次使用时需要联网下载模型参数。

Keras应用程序(keras.applications)是具有预先训练权值的固定架构,该类封装了很多重量级的网络架构,如下图所示:

实现时实例化模型架构:

tf.keras.applications.ResNet50(include_top=True, weights='imagenet', input_tensor=None, input_shape=None,pooling=None, classes=1000, **kwargs
)

主要参数:

  • include_top: 是否包括顶层的全连接层。
  • weights: None 代表随机初始化, 'imagenet' 代表加载在 ImageNet 上预训练的权值。
  • input_shape: 可选,输入尺寸元组,仅当 include_top=False 时有效,否则输入形状必须是 (224, 224, 3)(channels_last 格式)或 (3, 224, 224)(channels_first 格式)。它必须为 3 个输入通道,且宽高必须不小于 32,比如 (200, 200, 3) 是一个合法的输入尺寸。

在该案例中我们使用resNet50预训练模型构建模型:

# 加载预训练模型
ResNet50 = tf.keras.applications.ResNet50(weights='imagenet', input_shape=(224,224,3))
# 设置所有层不可训练
for layer in ResNet50.layers:layer.trainable = False
# 设置模型
net = tf.keras.models.Sequential()
# 预训练模型
net.add(ResNet50)
# 展开
net.add(tf.keras.layers.Flatten())
# 二分类的全连接层
net.add(tf.keras.layers.Dense(2, activation='softmax'))

接下来我们使用之前定义好的ImageGenerator将训练集图片送入ResNet50进行训练。

# 模型编译:指定优化器,损失函数和评价指标
net.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
# 模型训练:指定数据,每一个epoch中只运行10个迭代,指定验证数据集
history = net.fit(train_data_gen,steps_per_epoch=10,epochs=3,validation_data=test_data_gen,validation_steps=10)

Epoch 1/3
10/10 [==============================] - 28s 3s/step - loss: 0.6931 - accuracy: 0.5031 - val_loss: 0.6930 - val_accuracy: 0.5094
Epoch 2/3
10/10 [==============================] - 29s 3s/step - loss: 0.6932 - accuracy: 0.5094 - val_loss: 0.6935 - val_accuracy: 0.4812
Epoch 3/3
10/10 [==============================] - 31s 3s/step - loss: 0.6935 - accuracy: 0.4844 - val_loss: 0.6933 - val_accuracy: 0.4875

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/757724.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AutoSAR配置与实践(深入篇)10.5 CANTP 层对意外到达的N-PDU处理策略

AutoSAR配置与实践(深入篇)10.5 CANTP 层对意外到达的N-PDU处理策略 CANTP 层对意外到达的N-PDU处理策略一、规范说明二、具体流程图解析2.1 发送端对意外到达的PDU的处理图解2.2 接收端对意外到达的PDU的处理图解CANTP 层对意外到达的N-PDU处理策略 ->返回总目录<- …

【项目】YOLOv5+PaddleOCR实现艺术字验证码识别

YOLOv5PaddleOCR实现艺术字类验证码识别 一、引言1.1 实现目标1.2 人手动点选验证码逻辑1.3 计算机点选逻辑 二、计算机验证方法2.1 PaddleOCR下方文字识别方法2.2 YOLOv5目标检测方法2.3 艺术字分类方法2.4 返回结果 三、代码获取 一、引言 1.1 实现目标 要识别的验证码类型…

c语言综合练习题

1.编写程序实现键盘输入一个学生的学分绩点 score&#xff08;合法的范围为:1.0—5.0&#xff09;&#xff0c;根据学生的学分绩点判定该学 生的奖学金的等级&#xff0c;判定规则如下表所示。 #include <stdio.h>int main() {float score;printf("请输入学生的学分…

Harbor-私有镜像仓库

目录 一、Harbor 原理说明 1.软件资源介绍 2.Harbor 特性 3.Harbor 认证过程 4.Harbor 认证流程 二、私有镜像仓库实验 1.环境准备 2.安装docker 3.配置镜像加速和私有仓库地址 4.搭建harbor仓库 5.本地windows浏览器访问配置 一、Harbor 原理说明 1.软件资源介绍 …

突破编程_C++_设计模式(访问者模式)

1 访问者模式的基本概念 C中的访问者模式是一种行为设计模式&#xff0c;它允许你在不修改类层次结构的情况下增加新的操作。这种模式将数据结构与数据操作解耦&#xff0c;使得操作可以独立于对象的类来定义。 访问者模式的主要组成部分包括&#xff1a; &#xff08;1&…

面试算法-62-盛最多水的容器

题目 给定一个长度为 n 的整数数组 height 。有 n 条垂线&#xff0c;第 i 条线的两个端点是 (i, 0) 和 (i, height[i]) 。 找出其中的两条线&#xff0c;使得它们与 x 轴共同构成的容器可以容纳最多的水。 返回容器可以储存的最大水量。 说明&#xff1a;你不能倾斜容器。…

CycleGAN训练及测试过程细节记录

CycleGAN训练及测试过程细节记录 文章目录 关于训练关于测试 关于训练 1、训练前将数据配置好&#xff0c;并在Pycharm中写好配置信息 2、关于训练过程的参数配置在 options/train_options.py options/base_options.py batch_size&#xff1a;批大小 crop_size&#xff1a;…

python实现--拓扑排序

拓扑排序是对有向无环图&#xff08;DAG&#xff09;进行排序的一种算法&#xff0c;它可以将图中的顶点排成一个线性序列&#xff0c;使得图中的任意一条有向边都从序列中的较早顶点指向较晚顶点。换句话说&#xff0c;如果图中存在一条从顶点A到顶点B的有向边&#xff0c;那么…

Android分区存储到底该怎么做

文章目录 一、Android存储结构二、什么是分区存储&#xff1f;三、私有目录和公有目录三、存储权限和分区存储有什么关系&#xff1f;四、我们应该该怎么做适配&#xff1f;4.1、利用File进行操作4.2、使用MediaStore操作数据库 一、Android存储结构 Android存储分为内部存储和…

【逆向】使用 Frida 进行 Android 应用程序动态分析与加密算法逆向

不愿染是与非 怎料事与愿违 心中的花枯萎 时光它去不回 回忆辗转来回 痛不过这心扉 愿只愿余生无悔 随花香远飞 &#x1f3b5; 毛不易《不染》 在移动应用程序开发中&#xff0c;保护用户数据的安全至关重要。加密算法是保护数据安全的重要手段之一。然而…

【晴问算法】提高篇—动态规划专题—最长上升子序列

题目描述 现有一个整数序列a1,a2,...,an​​​​​​&#xff0c;求最长的子序列&#xff08;可以不连续&#xff09;&#xff0c;使得这个子序列中的元素是非递减的。输出该最大长度。 输入描述 第一行一个正整数n&#xff08;1≤n≤100​​​​&#xff09;&#xff0c;表示序…

【进阶版讲解深度学习如何入门?】

深度学习如何入门&#xff1f; 1. 前言2. 学习基础知识3. 了解机器学习4. 编程和工具5. 深度学习基础6. 实战项目7. 高级概念8. 持续学习9. 推荐资源 1. 前言 深度学习是机器学习的一个子领域&#xff0c;它受到了生物神经网络的启发&#xff0c;依赖于构建多层的神经网络来学…

Windows 11 安装 Scoop

[Windows 11 安装 Scoop](Windows 11 安装 Scoop) 0. 引言 Scoop 从命令行安装您熟悉和喜爱的程序&#xff0c;差异最小。 它的主要功能如下&#xff1a; 消除权限弹出窗口 隐藏 GUI 向导样式的安装程序 防止PATH污染安装大量程序 避免安装和卸载程序的意外副作用 自动查…

算法-背包问题

问题描述 假设我有一个背包&#xff0c;希望在装得下的情况下&#xff0c;尽量装进价值更多的物品。那么我该怎么做呢&#xff1f; 问题抽象 假设背包的容量是m&#xff0c;就假设是4吧 # 表示背包容量4KG m 4 可选装进背包的物品有n个&#xff0c;物品的价值存储在prices…

支付宝手机网站支付,微信扫描二维码支付

支付宝手机网站支付 支付宝文档 响应示例 <form name"punchout_form" method"post" action"https://openapi.alipay.com/gateway.do?charsetUTF-8&methodalipay.trade.wap.pay&formatjson&signERITJKEIJKJHKKKKKKKHJEREEEEEEEEEEE…

Maven打包时报错:Cannot allocate memory

使用Jenkins执行Maven打包任务时报错 Cannot allocate memory解决办法&#xff1a; 配置系统变量 MAVEN_OPTS-Xmx256m -XX:MaxPermSize512m或者 在项目目录下新建文件 .mvn/jvm.config -Xmx256m -Xms256m -XX:MaxPermSize512m -Djava.awt.headlesstrue参考 Jenkins Maven …

MySQL 数据库设计范式

第一范式&#xff08;1NF&#xff09; 每一列都是不可分割的原子数据项第二范式&#xff08;2NF&#xff09; 在1NF的基础上&#xff0c;非码属性必须完全依赖于候选码(在1NF基础上消除非主属性对主码的部分函数依赖) 1.函数依赖A->B&#xff0c;如果通过A属性(属性组)的值…

Transformer学习【从零理解】

Transformer 一、整体框架 二、Encoder 1.输入部分: &#xff08;1&#xff09;Embedding&#xff1a;将输入的词转换为对应的词向量。 &#xff08;2&#xff09;位置编码&#xff1a;因为保证输出时&#xff0c;顺序不会打乱&#xff0c;所以要加入时序信息即位置编码。 公…

如何避免AI网红经济泡沫?警惕细分行业的AI转型而不是转行

一、AI泡沫预防针 要避免AI相关新概念催生的网红经济泡沫&#xff0c;可以从多个角度采取措施&#xff1a; 1. **理性投资**&#xff1a; - 投资者应对AI项目和网红经济中的企业进行深入研究&#xff0c;了解其真实的技术实力、商业模式的可行性和盈利能力&#xff0c;而非…

代码随想录Day52:最长递增子序列、最长连续递增序列、最长重复子数组

最长递增子序列 class Solution { public:int lengthOfLIS(vector<int>& nums) {if(nums.size() < 1) return nums.size();vector<int> dp(nums.size(), 1);int res 0;for(int i 1; i < nums.size(); i){for(int j 0; j < i; j){if(nums[i] > …