卷积神经网络迁移学习:原理与实践指南

引言

        在深度学习领域,卷积神经网络(CNN)已经在计算机视觉任务中取得了巨大成功。然而,从头开始训练一个高性能的CNN模型需要大量标注数据和计算资源。迁移学习(Transfer Learning)技术为我们提供了一种高效解决方案,它能够将预训练模型的知识迁移到新任务中,显著减少训练时间和数据需求。本文将全面介绍CNN迁移学习的原理、优势及实践方法。

1、内容

      迁移学习是指利用已经训练好的模型,在新的任务上进行微调。迁移学习可以加快模型训练速度,提高模型性能,并且在数据稀缺的情况下也能很好地工作

2、步骤

1、选择预训练的模型和适当的层:通常,我们会选择在大规模图像数据集(如ImageNet)上预训练的模型,如VGG、ResNet等。然后,根据新数据集的特点,选择需要微调的模型层。对于低级特征的任务(如边缘检测),最好使用浅层模型的层,而对于高级特征的任务(如分类),则应选择更深层次的模型。

2、冻结预训练模型的参数:保持预训练模型的权重不变,只训练新增加的层或者微调一些层,避免因为在数据集中过拟合导致预训练模型过度拟合。

3、在新数据集上训练新增加的层:在冻结预训练模型的参数情况下,训练新增加的层。这样,可以使新模型适应新的任务,从而获得更高的性能。

4、微调预训练模型的层:在新层上进行训练后,可以解冻一些已经训练过的层,并且将它们作为微调的目标。这样做可以提高模型在新数据集上的性能。

5、评估和测试:在训练完成之后,使用测试集对模型进行评估。如果模型的性能仍然不够好,可以尝试调整超参数或者更改微调层。

Resnet网络:

原理:

       卷积神经网络都是通过卷积层和池化层的叠加组成的。 在实际的试验中发现,随着卷积层和池化层的叠加,学习效果不会逐渐变好,反而出现2个问题:

1、梯度消失和梯度爆炸

梯度消失:若每一层的误差梯度小于1,反向传播时,网络越深,梯度越趋近于0

梯度爆炸:若每一层的误差梯度大于1,反向传播时,网络越深,梯度越来越大

2、退化问题

       为了解决梯度消失或梯度爆炸问题,论文提出通过数据的预处理以及在网络中使用 BN(Batch Normalization)层来解决。

      为了解决深层网络中的退化问题,可以人为地让神经网络某些层跳过下一层神经元的连接,隔层相连,弱化每层之间的强联系。这种神经网络被称为 残差网络 (ResNets)

1、18层resnet结构:

2、BN(Batch Normalization)

实例

1、导入相关的库

import torch
from torch.utils.data import DataLoader,Dataset  #数据包管理工具,打包数据,
from torchvision import transforms
from torch import nn
import torchvision.models as models
from PIL import Image
import numpy as np

2、调取模型并冻结参数

#不再需要自己来搭建模型了。预训练的文件也加载进去了。
# 将resnet18模型迁移到食物分类项目中.#残差网络是固定的网络结构,不需要你自己来类定义了。
resnet_model=models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
#weights=models.ResNet18_Weights.DEFAULT表示使用在 ImageNet 数据集上预先训练好的权重
for param in resnet_model.parameters():print(param)param.requires_grad=False      #冻结
#模型所有参数(即权重和偏差)的requires_grad属性设置为False,从而冻结所有模型参数,

详细说明:

  1. models.resnet18()加载ResNet18架构

  2. weights=models.ResNet18_Weights.DEFAULT指定使用官方预训练权重

  3. 遍历所有参数并冻结

3、对网络模型进行微调

in_features=resnet_model.fc.in_features    #获取模型原输入的特征个数
resnet_model.fc=nn.Linear(in_features,20)  #创建一个全连接层

4、保存需要训练的参数

params_to_update=[]    #保存需要训练的参数,仅仅包含全连接层的参数
for param in resnet_model.parameters():if param.requires_grad==True:params_to_update.append(param)

5、数据预处理

data_transforms={
'train':
transforms.Compose([transforms.Resize([300, 300]),transforms.RandomRotation(45),  # 随机旋转,-45到45度之间随机选transforms.CenterCrop(224),  # 从中心开始裁剪[256,256]transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转 选择一个概率概率transforms.RandomVerticalFlip(p=0.5),  # 随机垂直翻转# transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),transforms.RandomGrayscale(p=0.1),  # 概率转换成灰度率,3通道就是R=G=Btransforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'valid':
transforms.Compose([transforms.Resize([224,224]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
}

数据预处理包括:

  • 训练集使用多种数据增强(随机旋转、水平翻转等)

  • 验证集只进行简单的resize和归一化

  • 归一化参数使用ImageNet的均值和标准差

6、自定义数据集的类

class food_dataset(Dataset):def __init__(self,file_path,transform=None):self.file_path=file_pathself.imgs=[]self.labels=[]self.transform=transformwith open(self.file_path) as f:samples=[x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image=Image.open(self.imgs[idx])if self.transform:image=self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label,dtype=np.int64))return image,label

这个自定义Dataset类:

  1. 从文本文件读取图像路径和标签

  2. 实现__len____getitem__方法,供DataLoader使用

  3. 应用指定的transform处理图像

7、数据加载器准备

# 创建训练集和测试集实例
training_data = food_dataset(file_path='trainda.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='testda.txt', transform=data_transforms['valid'])# 创建数据加载器
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

数据加载器提供:

  • 批量加载功能(batch_size=64)

  • 训练数据随机打乱(shuffle=True)

  • 多线程数据预读取

8、训练和测试流程

def train(dataloader,model,loss_fn,optimizer):model.train()   #告诉模型,我要开始训练,模型中w进行随机化操作,已经更新w。在训练过程中,w会被修改的
#pytorch提供2种方式来切换训练和测试的模式,分别是:model.train()和 model.eval()。
#一般用法是:在训练开始之前写上model.trian(),在测试时写上 model.eval()for X,y in dataloader:       #其中batch为每一个数据的编号X,y=X.to(device),y.to(device)    #把训练数据集和标签传入cpu或GPUpred=model.forward(X)    #.forward可以被省略,父类中已经对次功能进行了设置。自动初始化loss=loss_fn(pred,y)     #通过交叉熵损失函数计算损失值loss# Backpropagation 进来一个batch的数据,计算一次梯度,更新一次网络optimizer.zero_grad()    #梯度值清零loss.backward()          #反向传播计算得到每个参数的梯度值woptimizer.step()         #根据梯度更新网络w参数best_acc=0
def test(dataloader, model, loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= size

9、循环训练

# 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()  # 交叉熵损失
optimizer = torch.optim.Adam(params_to_update, lr=0.001)  # Adam优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)  # 学习率调度# 训练10个epoch
epoch = 10
for i in range(epoch):print(f"Epoch {i + 1}")train(train_dataloader, model, loss_fn, optimizer)  # 训练scheduler.step()  # 更新学习率test(test_dataloader, model, loss_fn)  # 测试print('Best accuracy:', best_acc)  # 打印最佳准确率

主循环流程:

  1. 定义损失函数和优化器

  2. 设置学习率调度器(每5个epoch学习率减半)

  3. 进行10轮训练和测试

  4. 打印最终最佳准确率

结果展示:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/78467.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

图论---朴素Prim(稠密图)

O( n ^2 ) 题目通常会提示数据范围&#xff1a; 若 V ≤ 500&#xff0c;两种方法均可&#xff08;朴素Prim更稳&#xff09;。 若 V ≤ 1e5&#xff0c;必须用优先队列Prim vector 存图。 // 最小生成树 —朴素Prim #include<cstring> #include<iostream> #i…

Spring-Cache替换Keys为Scan—负优化?

背景 使用ORM工具是往往会配合缓存框架实现三级缓存提高查询效率&#xff0c;spring-cache配合redis是非常常规的实现方案&#xff0c;如未做特殊配置&#xff0c;CacheEvict(allEntries true) 的批量驱逐方式&#xff0c;默认使用keys的方式查询历史缓存列表而后delete&…

【N8N】Docker Desktop + WSL 安装过程(Docker Desktop - WSL update Failed解决方法)

背景说明&#xff1a; 因为要用n8n&#xff0c;官网推荐这个就下载了&#xff0c;然后又是一堆卡的安装问题记录过程。 1. 下载安装包 直接去官网Get Docker | Docker Docs下载 下载的是第一个windows - x86_64. &#xff08;*下面那个beta的感觉是测试版&#xff09; PS&am…

RT Thread 发生异常时打印输出cpu寄存器信息和栈数据

打印输出发生hardfault时,当前栈十六进制数据和cpu寄存器信息 在发生 HardFault 时,打印当前栈的十六进制数据和 CPU 寄存器信息是非常重要的调试手段。以下是如何实现这一功能的具体步骤和示例代码。 1. 实现 HardFault 处理函数 我们需要在 HardFault 中捕获异常上下文,…

【安装neo4j-5.26.5社区版 完整过程】

1. 安装java 下载 JDK21-windows官网地址 配置环境变量 在底下的系统变量中新建系统变量&#xff0c;变量名为JAVA_HOME21&#xff0c;变量值为JDK文件夹路径&#xff0c;默认为&#xff1a; C:\Program Files\Java\jdk-21然后在用户变量的Path中&#xff0c;添加下面两个&am…

android jatpack Compose 多数据源依赖处理:从状态管理到精准更新的架构设计

Android Compose 多接口数据依赖管理&#xff1a;ViewModel 状态共享最佳实践 &#x1f4cc; 问题背景 在 Jetpack Compose 开发中&#xff0c;经常遇到以下场景&#xff1a; 页面由多个独立接口数据组成&#xff08;如 Part1、Part2&#xff09;Part2 的某些 UI 需要依赖 P…

面试之消息队列

消息队列场景 什么是消息队列&#xff1f; 消息队列是一个使用队列来通信的组件&#xff0c;它的本质就是个转发器&#xff0c;包含发消息、存消息、消费消息。 消息队列怎么选型&#xff1f; 特性ActiveMQRabbitMQRocketMQKafka单机吞吐量万级万级10万级10万级时效性毫秒级…

GStreamer 简明教程(十一):插件开发,以一个音频生成(Audio Source)插件为例

系列文章目录 GStreamer 简明教程&#xff08;一&#xff09;&#xff1a;环境搭建&#xff0c;运行 Basic Tutorial 1 Hello world! GStreamer 简明教程&#xff08;二&#xff09;&#xff1a;基本概念介绍&#xff0c;Element 和 Pipeline GStreamer 简明教程&#xff08;三…

Linux kernel signal原理(下)- aarch64架构sigreturn流程

一、前言 在上篇中写到了linux中signal的处理流程&#xff0c;在do_signal信号处理的流程最后&#xff0c;会通过sigreturn再次回到线程现场&#xff0c;上篇文章中介绍了在X86_64架构下的实现&#xff0c;本篇中介绍下在aarch64架构下的实现原理。 二、sigaction系统调用 #i…

华为OD机试真题——简易内存池(2025A卷:200分)Java/python/JavaScript/C++/C/GO最佳实现

2025 A卷 200分 题型 本文涵盖详细的问题分析、解题思路、代码实现、代码详解、测试用例以及综合分析&#xff1b; 并提供Java、python、JavaScript、C、C语言、GO六种语言的最佳实现方式&#xff01; 本文收录于专栏&#xff1a;《2025华为OD真题目录全流程解析/备考攻略/经验…

腾讯一面面经:总结一下

1. Java 中的 和 equals 有什么区别&#xff1f;比较对象时使用哪一个 1. 操作符&#xff1a; 用于比较对象的内存地址&#xff08;引用是否相同&#xff09;。 对于基本数据类型、 比较的是值。&#xff08;8种基本数据类型&#xff09;对于引用数据类型、 比较的是两个引…

计算机网络中的DHCP是什么呀? 详情解答

目录 DHCP 是什么&#xff1f; DHCP 的工作原理 主要功能 DHCP 与网络安全的关系 1. 正面作用 2. 潜在安全风险 DHCP 的已知漏洞 1. 协议设计缺陷 2. 软件实现漏洞 3. 配置错误导致的漏洞 4. 已知漏洞总结 举例说明 DHCP 与网络安全 如何提升 DHCP 安全性 总结 D…

2025 年导游证报考条件新政策解读与应对策略

2025 年导游证报考政策有了不少新变化&#xff0c;这些变化会对报考者产生哪些影响&#xff1f;我们又该如何应对&#xff1f;下面就为大家详细解读新政策&#xff0c;并提供实用的应对策略。 最引人注目的变化当属中职旅游类专业学生的报考政策。以往&#xff0c;中专学历报考…

【物联网】基于LORA组网的远程环境监测系统设计(ThingsCloud云平台版)

演示视频: 基于LORA组网的远程环境监测系统设计(ThingsCloud云平台版) 前言:本设计是基于ThingsCloud云平台版,还有另外一个版本是基于机智云平台版本,两个设计只是云平台和手机APP的区别,其他功能都一样。如下链接: 【物联网】基于LORA组网的远程环境监测系统设计(机…

SQL 函数进行左边自动补位fnPadLeft和FORMAT

目录 1.问题 2.解决 方式1 方式2 3.结果 1.问题 例如在SQL存储过程中&#xff0c;将1 或10 或 100 长度不足的时候&#xff0c;自动补足长度。 例如 1 → 001 10→ 010 100→100 2.解决 方式1 SELECT FORMAT (1, 000) AS FormattedNum; SELECT FORMAT(12, 000) AS Form…

Nacos简介—2.Nacos的原理简介

大纲 1.Nacos集群模式的数据写入存储与读取问题 2.基于Distro协议在启动后的运行规则 3.基于Distro协议在处理服务实例注册时的写路由 4.由于写路由造成的数据分片以及随机读问题 5.写路由 数据分区 读路由的CP方案分析 6.基于Distro协议的定时同步机制 7.基于Distro协…

中电金信联合阿里云推出智能陪练Agent

在金融业加速数智化转型的今天&#xff0c;提升服务效率与改善用户体验已成为行业升级的核心方向。面对这一趋势&#xff0c;智能体与智能陪练的结合应用&#xff0c;正帮助金融机构突破传统业务模式&#xff0c;开拓更具竞争力的创新机遇。 在近日召开的阿里云AI势能大会期间&…

十分钟恢复服务器攻击——群联AI云防护系统实战

场景描述 服务器遭遇大规模DDoS攻击&#xff0c;导致服务不可用。通过群联AI云防护系统的分布式节点和智能调度功能&#xff0c;快速切换流量至安全节点&#xff0c;清洗恶意流量&#xff0c;10分钟内恢复业务。 技术实现步骤 1. 启用智能调度API触发节点切换 群联系统提供RE…

LLM量化技术全景:GPTQ、QAT、AWQ、GGUF与GGML

01 引言 本文介绍的是在 LLM 讨论中经常听到的各种量化技术。本文的目的是提供一步一步的解释和代码&#xff0c;让大家可以自己使用这些技术来压缩模型。 闲话少说&#xff0c;我们来研究一下吧&#xff01; 02 Quantization 量化是指将高精度数字转换为低精度数字。低精…

IP的基础知识以及相关机制

IP地址 1.IP地址的概念 IP地址是分配给连接到互联网或局域网中的每一个设备的唯一标识符 也就是说IP地址是你设备在网络中的定位~ 2.IP版本~ IP版本分为IPv4和IPv6&#xff0c;目前我们最常用的还是IPv4~~但是IPv4有个缺点就是地址到现在为止&#xff0c;已经接近枯竭~~&…