基于CNN卷积神经网络迁移学习的图像识别实现

基于CNN卷积神经网络迁移学习的图像识别实现

  • 基于CNN卷积神经网络迁移学习的图像识别实现
    • 写在前面
    • 一,原理介绍
      • 迁移学习的基本方法
        • 1.样本迁移(Instance based TL)
        • 2.特征迁移(Feature based TL)
        • 3.模型迁移(Parameter based TL)
      • 4.关系迁移(Relation based TL)
    • 二. 准备工作
      • 1.依赖库安装
      • 2.IDE设置
      • 3.检查GPU-cuda核心是否可用
    • 三. 具体实现
      • 1.导入所需软件包
      • 2.加载数据
      • 3.可视化数据
      • 4.训练模型
      • 5.可视化模型预测
      • 6.卷积神经网络微调
      • 7.评估和训练
      • 8.神经网络固定特征提取
      • 9.自定义测试集测试
      • 10.结论
        • 1. 微调(Fine-tuning)
        • 2. 使用固定特征提取器(Fixed Feature Extractor)
      • 总结
    • 四.全部代码
    • 写在最后

基于CNN卷积神经网络迁移学习的图像识别实现

写在前面

笔者是一名ADAS底层软件工程师。在繁忙的嵌入式软件开发工作之余,我对新技术的保持浓厚兴趣。近年来,深度学习特别是卷积神经网络(CNN)的迅猛发展。尽管我的主要工作集中在车载系统和嵌入式应用,但我深知新技术对未来的巨大潜力。因此,我自学CNN及其在迁移学习中的应用,并希望将自己的学习经验整理成文。这篇博客不仅是我个人学习的总结,也是希望为那些对CNN迁移学习感兴趣的同学提供实用的参考。通过这篇笔记,我将分享一些关键概念和部署经验,包括如何进行模型微调、固定特征提取器的使用方法,以及如何利用训练好的模型进行实际预测。我真诚希望这篇博客能对大家有所帮助,欢迎大家在评论区留言交流,共同探讨和学习!

一,原理介绍

们常常将迁移学习和神经网络的训练上存在误区将其混为一谈。实际上,这两个概念最初是独立的。迁移学习是机器学习的一个分支,其中有许多方法并不依赖于神经网络。然而,随着神经网络的快速发展、强大能力和广泛应用,迁移学习的研究逐渐与神经网络紧密联系起来。

迁移学习(transfer learning)通俗来讲,就是运用已有的知识来学习新的知识,核心是找到已有知识和新知识之间的相似性,用成语来说就是举一反三。由于直接对目标域从头开始学习成本太高,我们故而转向运用已有的相关知识来辅助尽快地学习新知识。比如,已经会下中国象棋,就可以类比着来学习国际象棋:已经会编写Java程序,就可以类比着来学习C#;已经学会英语,就可以类比着来学习法语;等等。世间万事万物皆有共性,如何合理地找寻它们之间的相似性,进而利用这个桥梁来帮助学习新知识,是迁移学习的核心问题。

迁移学习的基本方法

1.样本迁移(Instance based TL)

在源域中找到与目标域相似的数据,把这个数据的权值进行调整,使得新的数据与目标域的数据进行匹配。下图的例子就是找到源域的例子3,然后加重该样本的权值,使得在预测目标域时的比重加大。优点是方法简单,实现容易。缺点在于权重的选择与相似度的度量依赖经验,且源域与目标域的数据分布往往不同。
在这里插入图片描述

2.特征迁移(Feature based TL)

假设源域和目标域含有一些共同的交叉特征,通过特征变换,将源域和目标域的特征变换到相同空间,使得该空间中源域数据与目标域数据具有相同分布的数据分布,然后进行传统的机器学习。优点是对大多数方法适用,效果较好。缺点在于难于求解,容易发生过适配。

在这里插入图片描述

3.模型迁移(Parameter based TL)

假设源域和目标域共享模型参数,是指将之前在源域中通过大量数据训练好的模型应用到目标域上进行预测,比如利用上千万的图象来训练好一个图象识别的系统,当我们遇到一个新的图象领域问题的时候,就不用再去找几千万个图象来训练了,只需把原来训练好的模型迁移到新的领域,在新的领域往往只需几万张图片就够,同样可以得到很高的精度。优点是可以充分利用模型之间存在的相似性。缺点在于模型参数不易收敛。

在这里插入图片描述

4.关系迁移(Relation based TL)

假设两个域是相似的,那么它们之间会共享某种相似关系,将源域中逻辑网络关系应用到目标域上来进行迁移,比方说生物病毒传播到计算机病毒传播的迁移。

在这里插入图片描述

对于CNN的迁移学习网上有很多大神的讲解都非常精彩,笔者只是简单介绍基本的概念,只要呢让大家明白为什么要进行迁移学习足以,我们还是着手实践,格物致知。

文章推荐:
链接: 微软亚洲研究院对迁移学习问题的回答

二. 准备工作

关于开发环境笔者是用Anaconda+PyCharm,个人认为这样包管理和开发都比较方便,当然因人而异,适合自己就好

1.依赖库安装

我将CondaList打出来,各位对照着版本安装就可以conda list

# Name                    Version                   Build  Channel
ca-certificates           2024.7.2             haa95532_0    defaults
contourpy                 1.1.1                    pypi_0    pypi
cycler                    0.12.1                   pypi_0    pypi
filelock                  3.15.4                   pypi_0    pypi
fonttools                 4.53.1                   pypi_0    pypi
fsspec                    2024.6.1                 pypi_0    pypi
importlib-resources       6.4.4                    pypi_0    pypi
kiwisolver                1.4.5                    pypi_0    pypi
libffi                    3.4.4                hd77b12b_1    defaults
matplotlib                3.7.5                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
networkx                  3.1                      pypi_0    pypi
numpy                     1.24.4                   pypi_0    pypi
openssl                   3.0.14               h827c3e9_0    defaults
pillow                    10.4.0                   pypi_0    pypi
pip                       24.2             py38haa95532_0    defaults
pyparsing                 3.1.4                    pypi_0    pypi
python                    3.8.19               h1aa4202_0    defaults
python-dateutil           2.9.0.post0              pypi_0    pypi
setuptools                72.1.0           py38haa95532_0    defaults
six                       1.16.0                   pypi_0    pypi
sqlite                    3.45.3               h2bbff1b_0    defaults
sympy                     1.13.2                   pypi_0    pypi
torchvision               0.19.0                   pypi_0    pypi
typing-extensions         4.12.2                   pypi_0    pypi
vc                        14.40                h2eaa2aa_0    defaults
vs2015_runtime            14.40.33807          h98bb1dd_0    defaults
wheel                     0.43.0           py38haa95532_0    defaults

2.IDE设置

其实这就是Anaconda+PyCharm开发的方便之处了,直接新建项目并选择刚刚创建的conda环境就可以愉快的编写代码了。
在这里插入图片描述

下面是文件结构,所有的代码都写在main.py里了,所以创建python工程时直接生成一个mian文件就可以
这里是数据集下载链接🔗:数据集下载
模型文件夹需要手动创建一下,以保存原始,微调与固定特征提取器的模型。

    D:\PYTHON_CODE_WORKSPACE\BEESORANTS_DL-------------------主文件名│  main.py-----------------------------------------------全部代码│  readme.md├─.idea├─hymenoptera_data---------------------------------------数据集│  ├─train-----------------------------------------------训练数据集│  │  ├─ants│  │  │      0013035.jpg│  │  │      ...│  │  │      VietnameseAntMimicSpider.jpg│  │  ││  │  └─bees│  │          1092977343_cb42b38d62.jpg│  │          ...│  │          969455125_58c797ef17.jpg│  │          98391118_bdb1e80cce.jpg│  ││  └─val------------------------------------------------测试数据集│      ├─ants│      │      10308379_1b6c72e180.jpg│      │      ...│      │      Hormiga.jpg│      ││      └─bees│              1032546534_06907fe3b3.jpg│              ...│              936182217_c4caa5222d.jpg│              abeja.jpg│├─model-------------------------------------------------模型│      best_model_params.pt│      finetuned_model_params.pt│      initial_model_params.pt│├─redme_img│└─test_img----------------------------------------------自定义测试集220px-Acrobat.ant1web.jpg40708249_1415445497609.jpg

3.检查GPU-cuda核心是否可用

在开始编译模型开始预测之前,先看一下cuda核心版本和是否可用

输入nvidia-smi我的CUDA核心版本为12.5,从网上找到对应的pyrorch版本下载即可(但其实我更推荐conda下载)

Tue Sep  3 20:21:43 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.99                 Driver Version: 555.99         CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3060 ...  WDDM  |   00000000:01:00.0  On |                  N/A |
| N/A   49C    P8             14W /  130W |     252MiB /   6144MiB |      3%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      1060    C+G   ...1.28.6010\updated_web\WXWorkWeb.exe      N/A      |
|    0   N/A  N/A      6532    C+G   ...5n1h2txyewy\ShellExperienceHost.exe      N/A      |
+-----------------------------------------------------------------------------------------+

打开Anaconda.Nvigator,选择open with python当然你也可以在环境文件夹打开python解释器,依次输入下面指令:

import torch导入torch包

print(torch.__version__)打印torch版本验证是否安装成功

torch.cuda.is_available()验证GPU是否可用,虽然CPU也完全可以完成模型的计算(15-20m),但是GPU则更快(8m)

在这里插入图片描述

如果是FALSE的话那大概率是版本不符,如果你遇到了这种情况请从网上搜索一些相关教程,还是很多的,我比较幸运版本都完全匹配,和我一样的配置可以直接照抄我的conda list。

三. 具体实现

下面我会分块讲解每块代码是做什么的,以便你能完全理解代码,全部代码最后附上

1.导入所需软件包

# License: BSD
# Author: Sasank Chilamkurthyimport torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from tempfile import TemporaryDirectorycudnn.benchmark = True
plt.ion()   # interactive mode

2.加载数据

我们将使用 torchvision 和 torch.utils.data 包来加载 数据。

我们今天要解决的问题是训练一个模型来对蚂蚁和蜜蜂进行分类。我们大约有 120 张蚂蚁和蜜蜂的训练图像。 每个类有 75 个验证图像。通常,这是如果从头开始训练,则要推广的小型数据集。由于我们 在使用迁移学习,我们应该能够合理地进行概括

此数据集是 imagenet 的一个非常小的子集

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classesdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

3.可视化数据

def imshow(inp, title=None):"""Display image for Tensor."""inp = inp.numpy().transpose((1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])inp = std * inp + meaninp = np.clip(inp, 0, 1)plt.imshow(inp)if title is not None:plt.title(title)plt.pause(0.001)  # pause a bit so that plots are updated# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))# Make a grid from batch
out = torchvision.utils.make_grid(inputs)imshow(out, title=[class_names[x] for x in classes])

运行代码后输出四张图片和标签
在这里插入图片描述

4.训练模型

这段代码定义用于训练模型的函数 train_model
使用通用模型ResNet-18,它首先记录训练开始的时间,然后在每个训练周期(epoch)中,通过迭代训练集和验证集的数据来调整模型的参数。训练过程中,模型在训练阶段被设置为训练模式,在验证阶段被设置为评估模式。每次经过验证阶段后,如果模型在验证集上的准确率比之前最好的一次更好,它会保存当前的模型参数。训练完成后,函数将加载并返回在验证集上表现最好的模型权重。

    通用模型 ResNet-18 通常是为了在图像分类任务中利用其强大的特征提取能力。ResNet-18 是一个深度残差网络(Residual Network)
包含 18 层卷积和全连接层。它在 ImageNet 数据集上进行了预训练,因此能够识别和提取图像中的复杂特征。这些预训练的特征可以用于各种
计算机视觉任务,如分类、检测等。下面是如何使用 ResNet-18 的步骤:1-加载预训练模型:我们可以使用 PyTorch 提供的 torchvision.models 模块来加载一个预训练的 ResNet-18 模型,该模型已经在ImageNet 数据集上训练好了。2-冻结早期层:根据任务的需要,我们可以选择冻结模型的早期层,只训练最后一层。这意味着前几层的权重保持不变,我们只调整最后一层的权重。3-修改输出层:ResNet-18 的原始输出层用于 1000 类分类,但我们可能只需要区分更少的类别(例如蜜蜂和蚂蚁)。因此,我们需要替换掉模型的最后一层,使其输出我们需要的类别数量。4-训练模型:在我们特定的数据集上(例如蜜蜂和蚂蚁的图像数据集)进行训练,通过多轮次的训练调整最后一层的权重,使得模型能够准确地分类新图像。5-评估模型:在验证集上评估模型的性能,如果表现良好,我们可以保存最优的模型权重。    

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):since = time.time()# Create a temporary directory to save training checkpointswith TemporaryDirectory() as tempdir:best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')torch.save(model.state_dict(), best_model_params_path)best_acc = 0.0for epoch in range(num_epochs):print(f'Epoch {epoch}/{num_epochs - 1}')print('-' * 10)# Each epoch has a training and validation phasefor phase in ['train', 'val']:if phase == 'train':model.train()  # Set model to training modeelse:model.eval()   # Set model to evaluate moderunning_loss = 0.0running_corrects = 0# Iterate over data.for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# zero the parameter gradientsoptimizer.zero_grad()# forward# track history if only in trainwith torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# backward + optimize only if in training phaseif phase == 'train':loss.backward()optimizer.step()# statisticsrunning_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)if phase == 'train':scheduler.step()epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# deep copy the modelif phase == 'val' and epoch_acc > best_acc:best_acc = epoch_acctorch.save(model.state_dict(), best_model_params_path)print()time_elapsed = time.time() - sinceprint(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')print(f'Best val Acc: {best_acc:4f}')# load best model weightsmodel.load_state_dict(torch.load(best_model_params_path, weights_only=True))return model

使用通用模型 ResNet-18 通常是为了在图像分类任务中利用其强大的特征提取能力。ResNet-18 是一个深度残差网络(Residual Network),包含 18 层卷积和全连接层。它在 ImageNet 数据集上进行了预训练,因此能够识别和提取图像中的复杂特征。这些预训练的特征可以用于各种计算机视觉任务,如分类、检测等。

下面是如何使用 ResNet-18 的步骤:

  1. 加载预训练模型:我们可以使用 PyTorch 提供的 torchvision.models 模块来加载一个预训练的 ResNet-18 模型,该模型已经在 ImageNet 数据集上训练好了。

  2. 冻结早期层:根据任务的需要,我们可以选择冻结模型的早期层,只训练最后一层。这意味着前几层的权重保持不变,我们只调整最后一层的权重。

  3. 修改输出层:ResNet-18 的原始输出层用于 1000 类分类,但我们可能只需要区分更少的类别(例如蜜蜂和蚂蚁)。因此,我们需要替换掉模型的最后一层,使其输出我们需要的类别数量。

  4. 训练模型:在我们特定的数据集上(例如蜜蜂和蚂蚁的图像数据集)进行训练,通过多轮次的训练调整最后一层的权重,使得模型能够准确地分类新图像。

  5. 评估模型:在验证集上评估模型的性能,如果表现良好,我们可以保存最优的模型权重。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms# 设置数据转换
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}# 数据集路径
data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classesdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 加载 ResNet-18 模型,使用预训练权重
model_ft = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features# 修改最后一层以匹配我们的分类任务(蜜蜂和蚂蚁,2 类)
model_ft.fc = nn.Linear(num_ftrs, 2)model_ft = model_ft.to(device)criterion = nn.CrossEntropyLoss()# 优化器设置
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)# 学习率调度器
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)# 训练模型
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25)# 训练完成后可以使用训练好的模型进行预测或评估

通过以上步骤,你可以利用预训练的 ResNet-18 模型进行迁移学习,将其应用于不同的图像分类任务中。这样可以节省训练时间,并提升小数据集任务上的性能。

5.可视化模型预测

用于显示一些图像预测的通用函数

def visualize_model(model, num_images=6):was_training = model.trainingmodel.eval()images_so_far = 0fig = plt.figure()with torch.no_grad():for i, (inputs, labels) in enumerate(dataloaders['val']):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)for j in range(inputs.size()[0]):images_so_far += 1ax = plt.subplot(num_images//2, 2, images_so_far)ax.axis('off')ax.set_title(f'predicted: {class_names[preds[j]]}')imshow(inputs.cpu().data[j])if images_so_far == num_images:model.train(mode=was_training)returnmodel.train(mode=was_training)

6.卷积神经网络微调

加载预训练模型并重新配置最终的全连接层。

model_ft = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to ``nn.Linear(num_ftrs, len(class_names))``.
model_ft.fc = nn.Linear(num_ftrs, 2)model_ft = model_ft.to(device)criterion = nn.CrossEntropyLoss()# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

运行后会下载resnet18-f37072fd。pth模型用于预训练,输出如下:

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth0%|          | 0.00/44.7M [00:00<?, ?B/s]47%|####7     | 21.0M/44.7M [00:00<00:00, 219MB/s]95%|#########5| 42.6M/44.7M [00:00<00:00, 223MB/s]
100%|##########| 44.7M/44.7M [00:00<00:00, 221MB/s]

7.评估和训练

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,num_epochs=25)

在这段代码中,model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25) 调用了 train_model 函数,并传递了几个参数。每个参数在这里都有其特定的作用和含义:

  1. model_ft:

    • 含义: 这是要训练的模型(在这里是经过预训练的 ResNet-18 模型)。model_ft 包含了模型的架构和当前的权重。
    • 作用: 该模型会在训练过程中被更新(优化),以更好地适应特定任务(如蜜蜂和蚂蚁的分类)。
  2. criterion:

    • 含义: 损失函数(在这里是交叉熵损失函数 CrossEntropyLoss)。
    • 作用: 损失函数用于衡量模型的预测结果与实际标签之间的差异。模型的目标是最小化这个损失,以提高预测的准确性。
  3. optimizer_ft:

    • 含义: 优化器(在这里是随机梯度下降优化器 SGD)。
    • 作用: 优化器负责更新模型的权重,以最小化损失函数的输出。通过调整学习率和动量,优化器能够帮助模型更快地收敛到最优解。
  4. exp_lr_scheduler:

    • 含义: 学习率调度器(在这里是 StepLR 调度器)。
    • 作用: 学习率调度器用于在训练过程中逐步降低学习率。这样可以帮助模型在训练的后期以较小的步伐调整权重,从而更精细地调整模型的参数,提高模型的最终性能。在这里,学习率每过 7 个 epoch 会按照 gamma=0.1 的因子进行衰减。
  5. num_epochs=25:

    • 含义: 训练的轮次数量。
    • 作用: 这个参数决定了训练的总轮次。在这里,模型将会被训练 25 个 epoch(每个 epoch 包含一次完整的训练集和验证集的前向传播和反向传播)。

运行代码后开始训练,输出如下:

在这里插入图片描述

对模型结果进行评估,得到下面结果:

visualize_model(model_ft)

在这里插入图片描述

visualize_model 函数的主要目的是展示模型在验证集(或测试集)上对图像的预测结果。通过查看模型对几个样本图像的预测,你可以直观地理解模型的表现,判断它是否能够正确识别图像中的对象。

8.神经网络固定特征提取

下面这段代码,我们需要冻结除最后一层之外的所有网络,设置冻结参数,以便不再计算梯度。requires_grad = Falsebackward()

model_conv = torchvision.models.resnet18(weights='IMAGENET1K_V1')
for param in model_conv.parameters():param.requires_grad = False# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)model_conv = model_conv.to(device)criterion = nn.CrossEntropyLoss()# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

这段代码加载了一个使用 IMAGENET1K_V1 权重预训练的 ResNet-18 模型,冻结了除最后一层全连接层以外的所有层的参数,并将最后一层替换为一个适应二分类任务(蜜蜂和蚂蚁)的全连接层,然后设置优化器只更新最后一层的参数。

Q:为什么要使这些参数在训练中保持不变?

A:在迁移学习中,通常会冻结预训练模型的大部分参数,只训练最后几层(如最后一层全连接层)。这是因为预训练模型(例如使用 IMAGENET1K_V1 权重的 ResNet-18)已经在一个大规模的数据集(ImageNet-1000)上学习到了非常通用的特征,这些特征能够很好地表示图像中的低级和中级信息(如边缘、纹理、形状等)。
通过冻结这些层的参数,可以:

  1. 减少训练时间和计算资源:冻结大部分层减少了需要更新的参数数量,因此训练速度更快,资源消耗更低。
  2. 避免过拟合:预训练的特征已经被证明是有效的,通过只训练最后一层,可以防止模型过度拟合到新的、小规模的数据集上。
  3. 利用通用特征:前几层学习到的特征是通用的,适用于多种任务,通过保留这些特征,可以提高模型在新任务上的表现。

这种方法有效地将预训练模型的强大特征提取能力与新任务的特定需求结合起来,从而实现更好的模型性能。
在这里插入图片描述

训练并评估模型

model_conv = train_model(model_conv, criterion, optimizer_conv,exp_lr_scheduler, num_epochs=25)visualize_model(model_conv)plt.ioff()
plt.show()

运行结果如下:
在这里插入图片描述

9.自定义测试集测试

将需要识别的图片放在test_img文件夹中,并将测试图片路径赋值给img_path,模型路径赋值于model调用visualize_model_predictions函数实现对自定义图片的识别

def visualize_model_predictions(model,img_path):was_training = model.trainingmodel.eval()img = Image.open(img_path)img = data_transforms['val'](img)img = img.unsqueeze(0)img = img.to(device)with torch.no_grad():outputs = model(img)_, preds = torch.max(outputs, 1)ax = plt.subplot(2,2,1)ax.axis('off')ax.set_title(f'Predicted: {class_names[preds[0]]}')imshow(img.cpu().data[0])model.train(mode=was_training)visualize_model_predictions(model_conv,img_path='data/hymenoptera_data/val/bees/72100438_73de9f17af.jpg')plt.ioff()
plt.show()

运行代码后可以看到图像被正确的识别,结果如下:

在这里插入图片描述

10.结论

微调(fine-tuning)和使用卷积神经网络(ConvNet)作为固定特征提取器是两个不同的方法,尽管它们都用于迁移学习。它们的关系可以理解为:

1. 微调(Fine-tuning)

微调是迁移学习的一种策略,其中一个预训练的模型在新的数据集上进行进一步训练。微调的主要步骤包括:

  • 加载预训练模型:通常从一个大数据集(如ImageNet)上训练的模型开始。
  • 替换输出层:将模型的最后一层(通常是全连接层)替换为适合新任务的层。例如,若原始模型是为1000类分类任务设计的,而你需要进行2类分类任务,则需要替换为一个具有2个输出单元的全连接层。
  • 训练:在新任务的数据集上训练模型时,可以选择是否训练整个网络的所有层,或者只训练新的全连接层。微调通常会解冻一些原本被冻结的层,并对这些层进行训练。
2. 使用固定特征提取器(Fixed Feature Extractor)

固定特征提取器是一种简单的迁移学习方法,其中预训练模型的特征提取部分被冻结,只有新添加的分类层(全连接层)会被训练。具体步骤包括:

  • 加载预训练模型:从一个大数据集上训练的模型开始。
  • 冻结特征提取层:将模型中除最后一层外的所有卷积层设置为不可训练(requires_grad=False)。
  • 添加新的分类层:在冻结的特征提取器之后添加一个新的全连接层,用于进行新的分类任务。
  • 训练:仅训练新的全连接层,固定的卷积层部分不进行训练。

总结

  • 微调固定特征提取器不是相互独立的,而是两种迁移学习的策略。微调是更灵活的方法,能够调整整个网络(或者大部分网络),适应新任务。固定特征提取器则是一种较为简单的方法,仅训练新的分类层,同时保持特征提取部分不变。

  • 顺序结构:在实际应用中,可以先使用固定特征提取器策略,然后逐步转向微调。如果固定特征提取器的结果不尽如人意,可以尝试微调以进一步提高性能。

  • 相互独立:这两个方法在实现上是相互独立的,但可以根据具体需求选择其一或者组合使用。

四.全部代码

说明:

  • flag == 1:进行微调,训练整个模型。
  • flag == 2:使用固定特征提取器,仅训练新加的分类层。
  • flag == 3:使用训练好的模型进行预测
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from tempfile import TemporaryDirectory#选择操作
flag = 1def main():cudnn.benchmark = Trueplt.ion()  # interactive mode# Data augmentation and normalization for training# Just normalization for validationdata_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),}data_dir = 'hymenoptera_data'image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'val']}dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'val']}dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}class_names = image_datasets['train'].classesdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")def imshow(inp, title=None):"""Display image for Tensor."""inp = inp.numpy().transpose((1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])inp = std * inp + meaninp = np.clip(inp, 0, 1)plt.imshow(inp)if title is not None:plt.title(title)plt.pause(0.001)  # pause a bit so that plots are updated# Get a batch of training datainputs, classes = next(iter(dataloaders['train']))# Make a grid from batchout = torchvision.utils.make_grid(inputs)# 模型训练def train_model(model, criterion, optimizer, scheduler, num_epochs=25, model_name="best_model_params.pt"):since = time.time()# 创建目录以保存模型if not os.path.exists('model'):os.makedirs('model')best_model_params_path = os.path.join('model', model_name)best_acc = 0.0for epoch in range(num_epochs):print(f'Epoch {epoch}/{num_epochs - 1}')print('-' * 10)# Each epoch has a training and validation phasefor phase in ['train', 'val']:if phase == 'train':model.train()  # Set model to training modeelse:model.eval()  # Set model to evaluate moderunning_loss = 0.0running_corrects = 0# Iterate over data.for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# zero the parameter gradientsoptimizer.zero_grad()# forward# track history if only in trainwith torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# backward + optimize only if in training phaseif phase == 'train':loss.backward()optimizer.step()# statisticsrunning_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)if phase == 'train':scheduler.step()epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# deep copy the modelif phase == 'val' and epoch_acc > best_acc:best_acc = epoch_acctorch.save(model.state_dict(), best_model_params_path)print()time_elapsed = time.time() - sinceprint(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')print(f'Best val Acc: {best_acc:4f}')# load best model weightsmodel.load_state_dict(torch.load(best_model_params_path))return model# 可视化模型预测def visualize_model(model, num_images=6):was_training = model.trainingmodel.eval()images_so_far = 0fig = plt.figure()with torch.no_grad():for i, (inputs, labels) in enumerate(dataloaders['val']):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)for j in range(inputs.size()[0]):images_so_far += 1ax = plt.subplot(num_images // 2, 2, images_so_far)ax.axis('off')ax.set_title(f'predicted: {class_names[preds[j]]}')imshow(inputs.cpu().data[j])if images_so_far == num_images:model.train(mode=was_training)returnmodel.train(mode=was_training)imshow(out, title=[class_names[x] for x in classes])# 微调特征提取器if flag == 1:model_ft = models.resnet18(weights='IMAGENET1K_V1')num_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Linear(num_ftrs, 2)model_ft = model_ft.to(device)criterion = nn.CrossEntropyLoss()optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25,model_name="finetuned_model_params.pt")visualize_model(model_ft)# 固定特征提取器elif flag == 2:model_conv = torchvision.models.resnet18(weights='IMAGENET1K_V1')for param in model_conv.parameters():param.requires_grad = Falsenum_ftrs = model_conv.fc.in_featuresmodel_conv.fc = nn.Linear(num_ftrs, 2)model_conv = model_conv.to(device)criterion = nn.CrossEntropyLoss()optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)model_conv = train_model(model_conv, criterion, optimizer_conv, exp_lr_scheduler, num_epochs=25,model_name="fixed_model_params.pt")visualize_model(model_conv)plt.ioff()plt.show()# 使用训练好的模型识别elif flag == 3:model_conv = models.resnet18()num_ftrs = model_conv.fc.in_featuresmodel_conv.fc = nn.Linear(num_ftrs, 2)model_conv = model_conv.to(device)model_conv.load_state_dict(torch.load('model/finetuned_model_params.pt'))def visualize_model_predictions(model, img_path):was_training = model.trainingmodel.eval()img = Image.open(img_path).convert("RGB")img = data_transforms['val'](img)img = img.unsqueeze(0)img = img.to(device)with torch.no_grad():outputs = model(img)_, preds = torch.max(outputs, 1)ax = plt.subplot(2, 2, 1)ax.axis('off')ax.set_title(f'Predicted: {class_names[preds[0]]}')imshow(img.cpu().data[0])model.train(mode=was_training)visualize_model_predictions(model_conv,img_path='D:/Python_Code_WorkSpace/BeesOrAnts_DL/test_img/40708249_1415445497609.jpg')plt.ioff()plt.show()if __name__ == '__main__':main()

写在最后

在这篇文章中,我们深入探讨了卷积神经网络(CNN)迁移学习的基本概念和实用技巧,这些知识是理解和实现端到端智能驾驶系统的基础。端到端智能驾驶系统旨在通过一个统一的深度学习框架,直接将传感器数据映射到驾驶决策,从而简化传统的多阶段处理流程。在未来的研究中,结合BEV(鸟瞰视角)图像和Transformer模型的先进方法正在成为热门趋势,它们可以有效提升对复杂驾驶场景的理解和处理能力。掌握CNN迁移学习将为进一步深入这些前沿技术打下坚实的基础,

加油,汽车人!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/52381.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++(一)----C++基础

1.C的发展史 C语言诞生后&#xff0c;很快普及使用&#xff0c;但是随着编程规模增大且越来越复杂&#xff0c;并且需要高度的抽象和建模时&#xff0c;C语言的诸多短板便表现了出来&#xff0c;为了解决软件危机&#xff0c;上世纪八十年代&#xff0c;计算机界提出了oop&…

如何理解有效值电流?电流的均方根值

电流的有效值就是电流的均方根。 有效值电流定义&#xff1a;将一直流电与一交流电分别通过相同阻值的电阻&#xff0c;如果相同时间内两电流通过电阻产生的热量相同&#xff0c;就说这一直流电的电流值是这一交流电的有效值。 如果说电流就是直流电&#xff0c;那么电流的有效…

Flutter MacOS 去掉窗口导航栏

操作步骤 用xcode打开Flutter项目&#xff0c;点击Runner——>Runner——>Resources——>MainMenu 点击APP_NAME&#xff0c;在右侧勾选窗口选项来控制是否有窗口或者关闭缩小按钮。我这里并没有取消勾选Show Title Bar&#xff0c;因为当我取消勾选后&#xff0c;窗…

已经存在的项目如何变成git的一个repository

已经存在的项目如何被git管理 背景&#xff1a; 有一套代码很敏感&#xff0c;可能动不动就要不能正常工作(硬件开发常事)&#xff0c;那改动一下下就要有个记录&#xff0c;就决定用git管理 已经有了服务里里docker里运行的gitbucket,已经有了开发用的电脑上的git客户端&…

【Python基础】Python函数

本文收录于 《Python编程入门》专栏&#xff0c;从零基础开始&#xff0c;分享一些Python编程基础知识&#xff0c;欢迎关注&#xff0c;谢谢&#xff01; 文章目录 一、前言二、函数的定义与调用三、函数参数3.1 位置参数3.2 默认参数3.3 可变数量参数&#xff08;或不定长参数…

【项目】云备份

云备份 云备份概述框架 功能演示服务端客户端 公共模块文件操作模块目录操作模块 服务端模块功能划分功能细分模块数据管理热点管理 客户端模块功能划分功能细分模块数据管理目录检查文件备份 云备份 概述 自动将本地计算机上指定文件夹中需要备份的文件上传备份到服务器中。…

【Visual Studio 报错】vs 在使用二进制写入文件时弹窗报错:使用简体中文 gb2312 编码加载文件

如以下报错 解决办法 解决方法&#xff1a;文件->高级保存选项->将文件编码形式改为“UTF-8带签名” 若找不到高级保存选项&#xff0c;可以跟着下面路径把该选项调出来 &#xff1a;工具->自定义->命令->菜单栏中改成文件->预览右边点添加命令->类别中…

【C++ Primer Plus习题】14.1

大家好,这里是国中之林! ❥前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站。有兴趣的可以点点进去看看← 问题: 解答: main.cpp #include <iostream> #include "wine.h" …

传知代码-机器情绪及抑郁症算法(四)!(论文复现)

代码以及视频讲解 本文所涉及所有资源均在传知代码平台可获取 计算机来理解你的情绪&#xff1a;情感计算的发展 近年来&#xff0c;多模态情感分析&#xff08;MSA&#xff09;受到越来越多的关注&#xff0c;多模态情感分析是一个综合了视觉、听觉等语言和非语言信息的重要…

Parsec问题解决方案

Parsec目前就是被墙了&#xff0c;有解决方案但治标不治本&#xff0c;如果想稳定串流建议是更换稳定的串流软件&#xff0c;以下是一些解决方案 方案一&#xff1a;在%appdata%/Parsec/config.txt中&#xff0c;添加代理 app_proxy_address 127.0.0.1 app_proxy_scheme http…

Qt篇——Qt在msvc编译下提示“C2001:常量中有换行符“的错误

在pro文件中添加以下配置即可&#xff1a; msvc{QMAKE_CFLAGS /utf-8QMAKE_CXXFLAGS /utf-8 }

双指针(7)_单调性_三数之和

个人主页&#xff1a;C忠实粉丝 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 C忠实粉丝 原创 双指针(7)_单调性_三数之和 收录于专栏【经典算法练习】 本专栏旨在分享学习C的一点学习笔记&#xff0c;欢迎大家在评论区交流讨论&#x1f48c; 目录 1. 题目…

Redis 常用命令总结

文章目录 目录 文章目录 1 . 前置内容 1.1 基本全局命令 KEYS EXISTS ​编辑 DEL EXPIRE TTL TYPE 1.2 数据结构和内部编码 2. String类型 SET GET MGET MSET SETNX INCR INCRBY DECR DECYBY INCRBYFLOAT 命令小结 内部编码 3 . Hash 哈希类型 HSET …

gpt4最新保姆级教程

如何使用 WildCard 服务注册 Claude3 随着 Claude3 的震撼发布&#xff0c;最强 AI 模型的桂冠已不再由 GPT-4 独揽。Claude3 推出了三个备受瞩目的模型&#xff1a;Claude 3 Haiku、Claude 3 Sonnet 以及 Claude 3 Opus&#xff0c;每个模型都展现了卓越的性能与特色。其中&a…

数据结构基本知识

一、什么是数据结构 1.1、组织存储数据 ---------》内存&#xff08;存储&#xff09; 1.2、研究目的 如何存储数据&#xff08;变量&#xff0c;数组....)程序数据结构算法 1.3、常见保存数据的方法 数组&#xff1a;保存自己的数据指针&#xff1a;是间接访问已经存在的…

【Vue】pnpm创建Vue3+Vite项目

初始化项目 &#xff08;1&#xff09;cmd切换到指定工作目录&#xff0c;运行pnpm create vue命令&#xff0c;输入项目名称后按需安装组件 &#xff08;2&#xff09;使用vs code打开所创建的项目目录&#xff0c;Ctrl~快捷键打开终端&#xff0c;输入pnpm install下载项目…

[概率论] 随机变量的分布函数 (一)

文章目录 1.随机变量的分布函数2.离散型随机变量的分布函数3.连续性随机变量的分布函数 1.随机变量的分布函数 设X XX是一个随机变量&#xff0c;x xx是任意实数&#xff0c;则函数 几何表示 性质&#xff08;一个函数是分布函数的充要条件&#xff09; 2.离散型随机变量的分布…

数据结构-图-存储-邻接矩阵-邻接表

数据结构-图-存储 邻接矩阵 存储如下图1,图2 图1 对应邻接矩阵 图2 #include<bits/stdc.h> #define MAXN 1005 using namespace std; int n; int v[MAXN][MAXN]; int main(){cin>>n;for(int i1;i<n;i){for(int j1;j<n;j){cin>>v[i][j];}}for(int…

深度解析Unix系统的基本概念及优缺点和原理

介绍 Unix系统是一种多用户、多任务、分时操作系统&#xff0c;起源于20世纪70年代初&#xff0c;由贝尔实验室开发。它具有强大的命令行接口和层次结构的文件系统&#xff0c;支持多种处理器架构&#xff0c;广泛应用于工程应用和科学计算等领域。 基本概念 一、Unix系统的起…

数学建模强化宝典(13)M-K检验法

前言 M-K检验法&#xff0c;全称为Mann-Kendall检验法&#xff0c;是一种非参数的假设检验方法&#xff0c;广泛应用于时间序列数据的趋势性变化检验&#xff0c;特别是气候序列中的趋势分析和突变点检测。以下是对M-K检验法的详细介绍&#xff1a; 一、定义与背景 M-K检验法由…