深度学习-迁移学习

深度学习中的迁移学习是通过在大规模数据上训练的模型,将其知识迁移到数据相对较少的相关任务中,能显著提升目标任务的模型性能。


一、迁移学习的核心概念

  1. 源任务(Source Task)与目标任务(Target Task)

    (1)源任务:通常拥有大量标注数据以及预训练好的模型,模型可以从中提取到通用特征。(2)目标任务:数据量相对有限,与源任务有相似性,但需要迁移模型知识适应特定的需求。
  2. 特征迁移

    (1)深度学习模型的层级结构有“自下而上”的特征表示,底层(如边缘、形状特征)更通用高层特征(如复杂纹理、特定形状)更具体。(2)迁移学习通过保留底层特征,并微调高层特征以适应新任务。
  3. 微调与冻结

    (1)冻结:冻结模型底层权重,保留已学到的底层特征,适合用于不同数据但相似的任务。(2)微调:对高层权重进行少量训练,使其适应目标任务,适用于源、目标任务有一定关联的情况。
  4. 模型剪枝与特征选择

    (1)剪枝可以减少模型复杂度,提升推理速度,适合在特定硬件上优化迁移模型的性能。

二、迁移学习的策略及示意图

迁移学习主要有以下策略,每个策略适用于不同场景。

1. 特征提取策略(Feature Extraction)
  • 使用预训练模型的卷积层作为固定的特征提取器,只在输出部分添加新的全连接层或分类层。
  • 应用于源任务和目标任务相似度较高的情况(如图像分类任务)。

代码示例

from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Flatten# 加载预训练的 VGG16 模型,不包含顶层
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))# 将卷积层的权重冻结
for layer in base_model.layers:layer.trainable = False# 添加新的全连接层
x = Flatten()(base_model.output)
output = Dense(10, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=output)
2. 微调策略(Fine-tuning)
  • 在预训练模型的基础上保留底层特征,微调高层特征,适应新的目标任务。适合在源任务和目标任务高度相似时使用。

代码示例

# 微调部分卷积层
for layer in base_model.layers[:15]:layer.trainable = False
for layer in base_model.layers[15:]:layer.trainable = True

3. 跨领域迁移(Cross-domain Transfer)
  • 针对不同领域任务的特征迁移策略,如图像到文本、语音到文本的跨领域迁移。需要添加或替换特定的适应层以完成不同领域的转换。

三、迁移学习的代码实现示例

以下代码展示了在 ImageNet 预训练的 VGG16 模型上,通过冻结部分卷积层并添加自定义全连接层,用于一个新的分类任务(如猫狗分类)。

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 1. 加载预训练的 VGG16 模型
vgg16 = models.vgg16(pretrained=True)# 2. 冻结前面的卷积层
for param in vgg16.features.parameters():param.requires_grad = False# 3. 修改分类器部分,适应猫狗二分类任务
# 获取 VGG16 的输入特征数,并替换最后一层为适合二分类的线性层
num_features = vgg16.classifier[6].in_features
vgg16.classifier[6] = nn.Linear(num_features, 2)  # 2 classes for binary classification# 4. 定义训练参数
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg16 = vgg16.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(vgg16.classifier[6].parameters(), lr=0.001)  # 只更新最后一层参数# 5. 定义数据预处理和加载
data_transforms = {'train': transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}train_dataset = datasets.ImageFolder(root='data/train', transform=data_transforms['train'])
val_dataset = datasets.ImageFolder(root='data/val', transform=data_transforms['val'])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)# 6. 训练模型
def train_model(model, criterion, optimizer, num_epochs=10):for epoch in range(num_epochs):model.train()running_loss = 0.0correct = 0for inputs, labels in train_loader:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()# 前向传播outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 统计损失和准确率running_loss += loss.item() * inputs.size(0)_, preds = torch.max(outputs, 1)correct += torch.sum(preds == labels)epoch_loss = running_loss / len(train_loader.dataset)epoch_acc = correct.double() / len(train_loader.dataset)print(f'Epoch {epoch}/{num_epochs - 1} - Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}')# 7. 调用训练函数
train_model(vgg16, criterion, optimizer, num_epochs=10)

  • 冻结卷积层:使用 for param in vgg16.features.parameters(): param.requires_grad = False 冻结了 vgg16.features 中的参数,使其在训练中不更新。

  • 修改分类层:更改 vgg16.classifier[6] 中的最后一个线性层,使其适应二分类任务(猫狗分类)。

  • 数据预处理与加载:利用 transforms 进行图像的标准化和尺寸调整,确保模型输入一致,加载后的数据放入 DataLoader 中便于批量处理。

  • 训练循环:在 train_model 函数中进行批次训练,计算损失并更新模型参数。

四、迁移学习的实际应用场景

  1. 图像分类:用于医疗影像分析、卫星图像识别等。例如使用 ImageNet 预训练模型进行皮肤癌检测。
  2. 目标检测与分割:自动驾驶中的行人检测、视频监控中的异常事件检测等。
  3. 自然语言处理:在 BERT、GPT-3 等预训练模型基础上微调,以适应情感分析、文本分类等任务。
  4. 语音识别:预训练语音模型可用于语音情感识别、口音识别等任务。

五、迁移学习的优缺点

优点

  • 数据需求少:不需要大量标注数据,可以显著缩短模型开发时间。
  • 训练高效:利用已有模型权重,减少训练时间。
  • 泛化能力强:预训练模型在大数据上学到的特征更具普适性,提高目标任务的泛化能力。

缺点

  • 源任务与目标任务的相似性要求:源任务和目标任务若差异较大,迁移效果会明显下降。
  • 存在偏差风险:源任务的偏差可能会迁移到目标任务中,对任务结果产生负面影响。
  • 额外存储开销:需要存储源模型的权重,对计算和存储资源有额外要求。

六、迁移学习的注意事项

  1. 选择合适的源任务:尽量选择与目标任务具有相似特征的源任务模型。
  2. 调整学习率:微调时的学习率应小于源任务,避免过度改变预训练模型的特征。
  3. 慎重选择微调层数:微调的层数应考虑目标任务的复杂性,避免过拟合。
  4. 数据预处理保持一致:确保源任务和目标任务的数据预处理方式一致,否则会影响模型性能。

七、总结

迁移学习在深度学习应用中已成为提升模型训练效率和性能的关键技术,尤其在目标任务与源任务具有一定关联性、且标注数据有限的情况下效果尤为显著。迁移学习通过利用在大规模数据集(如 ImageNet)上预训练的模型知识,将其迁移到新任务中,减少了对大规模数据和计算资源的需求。不同的迁移学习策略(如特征提取、微调、参数冻结等)能够针对性地调整模型层级的学习参数,实现高效的模型适应性。深入理解和灵活应用这些策略是深度学习项目开发的重要技能,能够在分类、检测、分割、文本分析等领域中有效缩短训练周期,并在数据有限的情况下显著提升模型的泛化性能和准确性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/57929.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

H7-TOOL自制Flash读写保护算法系列,为兆易创新GD32E23X制作使能和解除算法,支持在线烧录和脱机烧录使用(2024-10-29)

说明: 很多IC厂家仅发布了内部Flash算法文件,并没有提供读写保护算法文件,也就是选项字节算法文件,需要我们制作。 实际上当前已经发布的TOOL版本,已经自制很多了。但是依然有些厂家还没自制,所以陆续开始…

flutter 写个简单的界面

起因, 目的: 来源: 客户需求。 着急要,我随便写的,应付一下。 过程: 略,直接看代码,看注释。 代码 1 xxx import package:flutter/material.dart;void main() {runApp(const MyApp()); }// # class MyApp extends…

.NET 8 中 Entity Framework Core 的使用

本文代码:https://download.csdn.net/download/hefeng_aspnet/89935738 概述 Entity Framework Core (EF Core) 已成为 .NET 开发中数据访问的基石工具,为开发人员提供了强大而多功能的解决方案。随着 .NET 8 和 C# 10 中引入的改进,开发人…

推荐一款可视化和检查原始数据的工具:RawDigger

RawDigger是一款强大的工具,旨在可视化和检查相机记录的原始数据。它被称为一种“显微镜”,使用户能够深入分析原始图像数据,而不对其进行任何更改。RawDigger并不是一个原始转换器,而是一个帮助用户查看将由转换器使用的数据的工…

第三十三章 Vue路由进阶路由模块封装

目录 一、引言 二、完整代码 main.js index.js App.vue Find.vue My.vue 一、引言 在上一个章节中,我们将所有的路由配置都堆在main.js中来实现路径组件的路由,这样做的话非常不利于我们后期对项目的维护。因此正确的做法是将路由模块抽离出来&a…

基于java+SpringBoot+Vue的新闻推荐系统设计与实现

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: Springboot mybatis Maven mysql5.7或8.0等等组成&#x…

指派问题的求解

实验类型:◆验证性实验 ◇综合性实验 ◇设计性实验 实验目的:学会使用Matlab求解指派问题。 实验内容:利用Matlab编程实现枚举法求解指派问题。 实验例题:有5人分别对应完成5项工作,其各自的耗费如下表所示&#…

商品满减、限时活动、折扣活动的计算最划算 golang

可以对商品的不同活动(如满减、限时价和折扣)进行分组,并在购物车中显示各个活动标签下的最优价格组合。以下代码将商品按活动类别进行分组计算,并输出在购物车中的显示信息。 package mainimport ("fmt""math&qu…

AWS RDS Oracle hit ORA-39405

报错信息: ORA-39405: Oracle Data Pump does not support importing from a source database with TSTZ version 42 into a target database with TSTZ version 35. 分析过程: 这个报错是由于timezone_file的版本,源端比目标端高&#xf…

显卡服务器和普通服务器之间的区别有哪些?

显卡服务器也被称之为GPU服务器,显卡服务器与普通的服务器之间有着很明显的区别,下面就让我们共同来了解一下吧! 普通服务器的主要处理器通常都是配备的中央处理器,可以用于执行大部分通用计算任务和操作系统的管理;而…

下载安装COPT+如何在jupyter中使用(安装心得,windows,最新7.2版本)

目录 1.到杉树科技官网申请下载COPT 2.安装COPT&配置许可文件 3.在jupyter中使用COPT的python接口 最近看到一本和数学建模有关的新书:《数学建模与数学规划:方法、案例及编程实战》,作为数学建模老手,肯定要学习一下&…

Python自动化运维:技能掌握与快速入门指南

#编程小白如何成为大神?大学生的最佳入门攻略# 在当今快速发展的IT行业中,Python自动化运维已经成为了一个不可或缺的技能。本文将为您详细介绍Python自动化运维所需的技能,并提供快速入门的资源,帮助您迅速掌握这一领域。 必备…

深入理解跨域资源共享(CORS)安全问题原理及解决思路

目录 引言 CORS 基础 CORS 安全问题原理 解决思路 结论 引言 跨域资源共享(CORS, Cross-Origin Resource Sharing)是现代Web应用中不可或缺的一部分,特别是在前后端分离的架构中。CORS允许一个域上的Web应用请求另一个域上的资源&#…

基于“互联网+”医养结合的智慧养老实训室建设方案

一、建设背景 根据国家统计局的数据,截至2023年末,我国60岁及以上的老年人口已达到29,697万人,占总人口的21.1%;其中,65岁及以上的人口为21,676万人,占总人口的15.4%。这一数据表明,我国正面临…

为什么需要MQ消息系统,mysql 不能满足需求吗?

大家好,我是锋哥。今天分享关于【为什么需要MQ消息系统,mysql 不能满足需求吗?】面试题?希望对大家有帮助; 为什么需要MQ消息系统,mysql 不能满足需求吗? 1000道 互联网大厂Java工程师 精选面试…

C++编程法则365天一天一条(303)异步编程之std::promise和std::future

文章目录 主要特点基本用法示例代码使用场景注意事项std::promise 是 C++11 引入的一个模板类,位于 <future> 头文件,用于实现异步操作中的值传递和异常传递。它与 std::future 一起使用,提供了一种机制,使得一个线程可以将结果或异常传递给另一个线程。 主要特点 异…

计算机网络-以太网小结

前导码与帧开始分界符有什么区别? 前导码--解决帧同步/时钟同步问题 帧开始分界符-解决帧对界问题 集线器 集线器通过双绞线连接终端, 学校机房的里面就有集线器 这种方式仍然属于共享式以太网, 传播方式依然是广播 网桥: 工作特点: 1.如果转发表中存在数据接收方的端口信息…

C/C++常用编译工具链:GCC,Clang

目录 GNU Compiler Collection GCC的优势 编译产生的中间文件 Clang Clang的特点 什么是LLVM&#xff1f; Clang编译过程中产生的中间表示文件 关于Clang的调试 C 编译工具链中有几个主要的编译工具&#xff0c;包括&#xff1a; GNU Compiler Collection (GCC…

NNLM——预测下一个单词

一、原理篇 NNLM&#xff08;Neural Network Language Model&#xff0c;神经网络语言模型&#xff09;是一种通过神经网络进行语言建模的技术&#xff0c;通常用于预测序列中的下一个词。 NNLM的核心思想是使用词嵌入&#xff08;word embedding&#xff09;将词转换为低维向…

【C++】类和对象(十二):实现日期类

大家好&#xff0c;我是苏貝&#xff0c;本篇博客带大家了解C的实现日期类&#xff0c;如果你觉得我写的还不错的话&#xff0c;可以给我一个赞&#x1f44d;吗&#xff0c;感谢❤️ 目录 1 /!/>/</>/<运算符重载2 /-//-运算符重载(A) 先写&#xff0c;再通过写(B…