深度学习_微调_7

目标

  • 微调的原理
  • 利用微调模型来完成图像的分类任务

微调的原理

微调(Fine-tuning)是一种在深度学习中广泛应用的技术,特别是在预训练模型(Pretrained-Models)的基础上进行定制化训练的过程。微调的基本原理和步骤如下:

  1. 预训练阶段

    • 微调通常始于一个已经在大规模数据集上预训练过的模型,例如预训练的神经网络模型(如BERT、GPT系列、Vision Transformer等)。
    • 预训练模型在诸如自然语言处理(NLP)或计算机视觉(CV)等任务上进行了自我监督学习或有监督学习,从而习得了大量的通用特征表示。
  2. 适应新任务

    • 当需要解决与预训练任务相关的但更加具体的新任务时,就可以利用微调技术。
    • 对于NLP任务,可能是训练模型去回答问题、生成文本或者分类;对于CV任务,可能是对特定种类的物体进行识别或定位。
  3. 模型结构调整

    • 根据新任务的需求,可能需要对预训练模型的部分或全部层进行微调。
    • 通常在迁移学习中,预训练模型的底层可以捕获非常通用的特征,因此通常会被保留并微调,而顶层(尤其是分类层)会替换为新的、与新任务适配的输出层。
  4. 参数更新

    • 微调过程中,模型在新任务的数据集上重新开始训练,不过并非从随机初始化参数开始,而是从预训练模型的参数开始。
    • 学习率通常会设置得相对较低,以免破坏预训练模型学到的良好特征表示。
    • 其他优化参数(如权重衰减、批归一化层的状态等)也可能根据新任务的特点进行调整。
  5. 学习过程

    • 在目标数据集上训练模型时,不仅会更新新增的输出层参数,也会对预训练模型的某些层参数进行微调,使其更好地适应新任务的数据分布和特征。
    • 由于预训练模型已经具备良好的初始化,因此在相对较小的数据集上进行微调时,模型往往能够更快收敛到较好的解。

总结起来,微调的原理是利用预训练模型中的已学知识作为初始状态,通过对新任务数据的训练,对模型参数进行针对性的更新和优化,从而使模型能够适应新的应用场景。相较于从头训练,微调大大减少了所需的训练时间和数据量,提高了模型在特定任务上的性能和泛化能力。

1.微调

如何在只有6万张图像的MNIST训练数据集上训练模型。学术界当下使用最广泛的大规模图像数据集ImageNet,它有超过1,000万的图像和1,000类的物体。然而,我们平常接触到数据集的规模通常在这两者之间。假设我们想从图像中识别出不同种类的椅子,然后将购买链接推荐给用户。一种可能的方法是先找出100种常见的椅子,为每种椅子拍摄1,000张不同角度的图像,然后在收集到的图像数据集上训练一个分类模型。另外一种解决办法是应用迁移学习(transfer learning),将从源数据集学到的知识迁移到目标数据集上。例如,虽然ImageNet数据集的图像大多跟椅子无关,但在该数据集上训练的模型可以抽取较通用的图像特征,从而能够帮助识别边缘、纹理、形状和物体组成等。这些类似的特征对于识别椅子也可能同样有效。

微调由以下4步构成。

  1. 在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型。
  2. 创建一个新的神经网络模型,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在目标模型中不予采用。
  3. 为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层的模型参数。
  4. 在目标数据集(如椅子数据集)上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的。

当目标数据集远小于源数据集时,微调有助于提升模型的泛化能力。

2.热狗识别

接下来我们来实践一个具体的例子:热狗识别。将基于一个小数据集对在ImageNet数据集上训练好的ResNet模型进行微调。该小数据集含有数千张热狗或者其他事物的图像。我们将使用微调得到的模型来识别一张图像中是否包含热狗。

首先,导入实验所需的工具包。

import tensorflow as tf
import numpy as np

2.1 获取数据集

我们首先将数据集放在路径hotdog/data之下:

每个类别文件夹里面是图像文件。

上一节中我们介绍了ImageDataGenerator进行图像增强,我们可以通过以下方法读取图像文件,该方法以文件夹路径为参数,生成经过图像增强后的结果,并产生batch数据:

flow_from_directory(self, directory,target_size=(256, 256), color_mode='rgb',classes=None, class_mode='categorical',batch_size=32, shuffle=True, seed=None,save_to_dir=None)

主要参数:

  • directory: 目标文件夹路径,对于每一个类对应一个子文件夹,该子文件夹中任何JPG、PNG、BNP、PPM的图片都可以读取。
  • target_size: 默认为(256, 256),图像将被resize成该尺寸。
  • batch_size: batch数据的大小,默认32。
  • shuffle: 是否打乱数据,默认为True。

我们创建两个tf.keras.preprocessing.image.ImageDataGenerator实例来分别读取训练数据集和测试数据集中的所有图像文件。将训练集图片全部处理为高和宽均为224像素的输入。此外,我们对RGB(红、绿、蓝)三个颜色通道的数值做标准化。

# 获取数据集
import pathlib
train_dir = 'transferdata/train'
test_dir = 'transferdata/test'
# 获取训练集数据
train_dir = pathlib.Path(train_dir)
train_count = len(list(train_dir.glob('*/*.jpg')))
# 获取测试集数据
test_dir = pathlib.Path(test_dir)
test_count = len(list(test_dir.glob('*/*.jpg')))
# 创建imageDataGenerator进行图像处理
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
# 设置参数
BATCH_SIZE = 32
IMG_HEIGHT = 224
IMG_WIDTH = 224
# 获取训练数据
train_data_gen = image_generator.flow_from_directory(directory=str(train_dir),batch_size=BATCH_SIZE,target_size=(IMG_HEIGHT, IMG_WIDTH),shuffle=True)
# 获取测试数据
test_data_gen = image_generator.flow_from_directory(directory=str(test_dir),batch_size=BATCH_SIZE,target_size=(IMG_HEIGHT, IMG_WIDTH),shuffle=True)

下面我们随机取1个batch的图片然后绘制出来。

import matplotlib.pyplot as plt
# 显示图像
def show_batch(image_batch, label_batch):plt.figure(figsize=(10,10))for n in range(15):ax = plt.subplot(5,5,n+1)plt.imshow(image_batch[n])plt.axis('off')
# 随机选择一个batch的图像        
image_batch, label_batch = next(train_data_gen)
# 图像显示
show_batch(image_batch, label_batch)

2.2 模型构建与训练

我们使用在ImageNet数据集上预训练的ResNet-50作为源模型。这里指定weights='imagenet'来自动下载并加载预训练的模型参数。在第一次使用时需要联网下载模型参数。

Keras应用程序(keras.applications)是具有预先训练权值的固定架构,该类封装了很多重量级的网络架构,如下图所示:

实现时实例化模型架构:

tf.keras.applications.ResNet50(include_top=True, weights='imagenet', input_tensor=None, input_shape=None,pooling=None, classes=1000, **kwargs
)

主要参数:

  • include_top: 是否包括顶层的全连接层。
  • weights: None 代表随机初始化, 'imagenet' 代表加载在 ImageNet 上预训练的权值。
  • input_shape: 可选,输入尺寸元组,仅当 include_top=False 时有效,否则输入形状必须是 (224, 224, 3)(channels_last 格式)或 (3, 224, 224)(channels_first 格式)。它必须为 3 个输入通道,且宽高必须不小于 32,比如 (200, 200, 3) 是一个合法的输入尺寸。

在该案例中我们使用resNet50预训练模型构建模型:

# 加载预训练模型
ResNet50 = tf.keras.applications.ResNet50(weights='imagenet', input_shape=(224,224,3))
# 设置所有层不可训练
for layer in ResNet50.layers:layer.trainable = False
# 设置模型
net = tf.keras.models.Sequential()
# 预训练模型
net.add(ResNet50)
# 展开
net.add(tf.keras.layers.Flatten())
# 二分类的全连接层
net.add(tf.keras.layers.Dense(2, activation='softmax'))

接下来我们使用之前定义好的ImageGenerator将训练集图片送入ResNet50进行训练。

# 模型编译:指定优化器,损失函数和评价指标
net.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
# 模型训练:指定数据,每一个epoch中只运行10个迭代,指定验证数据集
history = net.fit(train_data_gen,steps_per_epoch=10,epochs=3,validation_data=test_data_gen,validation_steps=10)

Epoch 1/3
10/10 [==============================] - 28s 3s/step - loss: 0.6931 - accuracy: 0.5031 - val_loss: 0.6930 - val_accuracy: 0.5094
Epoch 2/3
10/10 [==============================] - 29s 3s/step - loss: 0.6932 - accuracy: 0.5094 - val_loss: 0.6935 - val_accuracy: 0.4812
Epoch 3/3
10/10 [==============================] - 31s 3s/step - loss: 0.6935 - accuracy: 0.4844 - val_loss: 0.6933 - val_accuracy: 0.4875

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/757724.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【项目】YOLOv5+PaddleOCR实现艺术字验证码识别

YOLOv5PaddleOCR实现艺术字类验证码识别 一、引言1.1 实现目标1.2 人手动点选验证码逻辑1.3 计算机点选逻辑 二、计算机验证方法2.1 PaddleOCR下方文字识别方法2.2 YOLOv5目标检测方法2.3 艺术字分类方法2.4 返回结果 三、代码获取 一、引言 1.1 实现目标 要识别的验证码类型…

c语言综合练习题

1.编写程序实现键盘输入一个学生的学分绩点 score&#xff08;合法的范围为:1.0—5.0&#xff09;&#xff0c;根据学生的学分绩点判定该学 生的奖学金的等级&#xff0c;判定规则如下表所示。 #include <stdio.h>int main() {float score;printf("请输入学生的学分…

Harbor-私有镜像仓库

目录 一、Harbor 原理说明 1.软件资源介绍 2.Harbor 特性 3.Harbor 认证过程 4.Harbor 认证流程 二、私有镜像仓库实验 1.环境准备 2.安装docker 3.配置镜像加速和私有仓库地址 4.搭建harbor仓库 5.本地windows浏览器访问配置 一、Harbor 原理说明 1.软件资源介绍 …

面试算法-62-盛最多水的容器

题目 给定一个长度为 n 的整数数组 height 。有 n 条垂线&#xff0c;第 i 条线的两个端点是 (i, 0) 和 (i, height[i]) 。 找出其中的两条线&#xff0c;使得它们与 x 轴共同构成的容器可以容纳最多的水。 返回容器可以储存的最大水量。 说明&#xff1a;你不能倾斜容器。…

CycleGAN训练及测试过程细节记录

CycleGAN训练及测试过程细节记录 文章目录 关于训练关于测试 关于训练 1、训练前将数据配置好&#xff0c;并在Pycharm中写好配置信息 2、关于训练过程的参数配置在 options/train_options.py options/base_options.py batch_size&#xff1a;批大小 crop_size&#xff1a;…

Android分区存储到底该怎么做

文章目录 一、Android存储结构二、什么是分区存储&#xff1f;三、私有目录和公有目录三、存储权限和分区存储有什么关系&#xff1f;四、我们应该该怎么做适配&#xff1f;4.1、利用File进行操作4.2、使用MediaStore操作数据库 一、Android存储结构 Android存储分为内部存储和…

支付宝手机网站支付,微信扫描二维码支付

支付宝手机网站支付 支付宝文档 响应示例 <form name"punchout_form" method"post" action"https://openapi.alipay.com/gateway.do?charsetUTF-8&methodalipay.trade.wap.pay&formatjson&signERITJKEIJKJHKKKKKKKHJEREEEEEEEEEEE…

MySQL 数据库设计范式

第一范式&#xff08;1NF&#xff09; 每一列都是不可分割的原子数据项第二范式&#xff08;2NF&#xff09; 在1NF的基础上&#xff0c;非码属性必须完全依赖于候选码(在1NF基础上消除非主属性对主码的部分函数依赖) 1.函数依赖A->B&#xff0c;如果通过A属性(属性组)的值…

Transformer学习【从零理解】

Transformer 一、整体框架 二、Encoder 1.输入部分: &#xff08;1&#xff09;Embedding&#xff1a;将输入的词转换为对应的词向量。 &#xff08;2&#xff09;位置编码&#xff1a;因为保证输出时&#xff0c;顺序不会打乱&#xff0c;所以要加入时序信息即位置编码。 公…

如何避免AI网红经济泡沫?警惕细分行业的AI转型而不是转行

一、AI泡沫预防针 要避免AI相关新概念催生的网红经济泡沫&#xff0c;可以从多个角度采取措施&#xff1a; 1. **理性投资**&#xff1a; - 投资者应对AI项目和网红经济中的企业进行深入研究&#xff0c;了解其真实的技术实力、商业模式的可行性和盈利能力&#xff0c;而非…

初识GO语言

是由google公司推出的一门编程语言&#xff0c;12年推出的第一个版本 Go的特点 Go为什么能在最近的IT领域炙手可热 集python简洁&C语言的性能于一身 21世纪的C语言 顺应容器化时代的到来 区块链的崛起 学习一门编程语言可以划分为下面这三个步骤 安装 编译器 or 解…

JAVA多线程之synchronized锁

文章目录 1. 临界区2. synchronized使用2.1 不加锁实现2.2 synchronized加锁2.3 面向对象的改进2.4 方法上加synchronized2.5 线程安全 3. Monitor3.1 Java对象头3.2 Monitor工作流程3.3 字节码角度 4. synchronized原理4.1 轻量级锁4.2 锁膨胀4.3 偏向锁4.3.1 偏向锁过程4.3.2…

【链表】Leetcode 2. 两数相加【中等】

两数相加 给你两个 非空 的链表&#xff0c;表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的&#xff0c; 并且每个节点只能存储 一位 数字。请你将两个数相加&#xff0c;并以相同形式返回一个表示和的链表。你可以假设除了数字 0 之外&#xff0c;这两个数都不…

Redis数据结构对象中的对象共享、对象的空转时长

对象共享 概述 除了用于实现引用计数内存回收机制之外&#xff0c;对象的引用计数属性还带有对象共享的作用。 在Redis中&#xff0c;让多个键共享同一个值对象需要执行以下两个步骤: 1.将数据库键的值指针指向一个现有的值对象2.将被共享的值对象的引用计数增一 目前来说…

pytorch 实现线性回归(Pytorch 03)

一 从零实现线性回归 1.1 生成训练数据 原始 计算公式&#xff0c; 我们先使用该公式生成一批数据&#xff0c;然后使用 结果数据去计算 计算 w1, w2 和 b。 %matplotlib inline import random import torch from d2l import torch as d2ldef synthetic_data(w, b, num_ex…

基于springboot+vue的餐饮管理系统

博主主页&#xff1a;猫头鹰源码 博主简介&#xff1a;Java领域优质创作者、CSDN博客专家、阿里云专家博主、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战&#xff0c;欢迎高校老师\讲师\同行交流合作 ​主要内容&#xff1a;毕业设计(Javaweb项目|小程序|Pyt…

Java学习笔记21——使用JDBC访问MySQL数据库

JDBC&#xff08;Java Database Connectivity&#xff0c;Java数据库连接&#xff09;是应用程序编程借口&#xff08;API&#xff09;&#xff0c;描述了一套访问关系数据库的标准Java类库。可以在程序中使用这些API&#xff0c;连接到关系数据库&#xff0c;执行SQL语句&…

IDEA Git恢复DropCommit删除的提交

刚刚Dorp commit了&#xff0c;本地代码也被删除了&#xff0c;如何恢复呢&#xff0c; 从项目中登录git&#xff0c;找到刚刚的commit代码&#xff0c;如下所示&#xff1a;输入命令git reflog 复制代码&#xff0c;到idea中&#xff0c;打开GIt&#xff0c;找到RESET HEAD, …

初始 Navicat BI 工具

早前&#xff0c;海外 LearnBI online 博主 Adam Finer 对 Navicat Charts Creator 这款 BI&#xff08;商业智能&#xff09;工具进行了真实的测评。今天&#xff0c;我们来看下他对 Navicat BI 工具的初始之感&#xff0c;希望这能给用户一些启发与建议。LearnBI online 作为…

《计算机考研精炼1000题》为你考研之路保驾护航

创作背景 在这个充满挑战与竞争的时代&#xff0c;每一位考生在备战研究生考试的过程中&#xff0c;都希望通过更多符合考纲要求的练习题来提高自己的知识和技能。为了满足这一需求&#xff0c;我们精心策划和编辑了这本《计算机考研精炼1000题》。在考研政治和考研数学领域&a…