基于CNN+ViT的蔬果图像分类实验

本文只是做一个简单融合的实验,没有任何新颖,大家看看就行了。

1.数据集

本文所采用的数据集为Fruit-360 果蔬图像数据集,该数据集由 Horea Mureșan 等人整理并发布于 GitHub(项目地址:Horea94/Fruit-Images-Dataset),广泛应用于图像分类和目标识别等计算机视觉任务。该数据集共包含141 类水果和蔬菜图像,总计 94,110 张图像,每张图像的尺寸统一为 100×100 像素,且背景已统一处理为白色背景,以减少背景噪声对模型训练的影响。

数据集中涵盖了大量常见和不常见的果蔬品类,主要包括:

  1. 苹果(多个品种:如深雪、金苹果、金红、青奶奶、粉红女士、红苹果、红美味等)
  2. 香蕉(黄色、红色、淑女手指等)
  3. 葡萄(蓝色、粉红色、白色多个品种)
  4. 柑橘类(橙子、柠檬、酸橙、葡萄柚、柑橘等)
  5. 热带水果(芒果、木瓜、红毛丹、百香果、番石榴、荔枝、菠萝、火龙果等)
  6. 浆果类(蓝莓、覆盆子、草莓、黑加仑、红醋栗、桑葚等)
  7. 核果类与坚果类(桃子、李子、杏、椰子、榛子、核桃、栗子、山核桃等)
  8. 蔬菜类(黄瓜、茄子、胡椒、番茄、洋葱、花椰菜、甜菜根、玉米、土豆等)
  9. 其他类如:仙人掌果实、杨布拉、姜根、格兰纳迪拉、Physalis(灯笼果)、油桃、佩皮诺、罗望子、大头菜等。

在数据划分方面,本研究按照如下比例进行数据集划分:

(1)训练集:70,491 张图像
          其中按照 8:2 的比例划分出验证集,得到最终:

                        训练子集:56,432 张

                        验证集:14,059 张

(2)测试集:23,619 张图像

2.模型简述

在图像分类任务中,深度学习方法已经取得了显著的进展,如残差神经网络(ResNet),Vision Transformer展现了较强的性能。ResNet作为CNN下的网络架构,在局部特征提取方面具有优势,能够有效地捕捉图像中的空间结构信息。而Vision Transformer作为Transformer的变种,在捕捉全局依赖关系和建模长程依赖性方面的具有更好的优势。

由于CNN的卷积操作本质上能够生成具有空间局部关联性的特征图,实际上可以视为一种变相的patch操作。因此,在将CNN与Transformer相结合时,可以避免传统ViT中对输入图像进行切分patch的操作,只需对图像进行位置编码,从而使得Transformer能够有效处理这些具有空间结构的特征图。这种设计不仅减少了计算开销,还使得整个模型在处理图像时更具效率与准确性。

同时,与原始ViT框架中描述的技术不同,原始框架通常会将一个可学习的位置嵌入向量预先添加到编码后的patch序列中,作为图像的位置信息进行表示。然而,为了简化模型的实现并提高计算效率,本文在架构设计上有所调整,省略了额外的位置编码步骤。具体来说,本文的模型通过直接输入编码后的patch序列到Transformer块中,跳过了对每个patch进行独立位置编码的操作。

基于这一思路,结合了残差神经网络(ResNet)和Vision Transformer(ViT)两种网络架构,将它们以串行连接的方式进行融合。具体模型架构图如下图所示

3.实验

模型代码(基于tensorflow2.X)

import glob
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers,models
import warnings
warnings.filterwarnings('ignore')
import osTrain = r"D:\archive (1)\fruits-360_dataset_100x100\fruits-360\Training"
Test = r"D:\archive (1)\fruits-360_dataset_100x100\fruits-360\Test"IMAGE_SIZE = 100
NUM_CLASSES = 141
BATCH_SIZE = 32imagegenerator = ImageDataGenerator(rescale=1.0 / 255.0, validation_split=0.2, rotation_range=10, horizontal_flip=True)# Training and validation data generators
Train_Data = imagegenerator.flow_from_directory(Train,target_size=(IMAGE_SIZE, IMAGE_SIZE),batch_size=BATCH_SIZE,class_mode='categorical',subset='training'
)
Validation_Data = imagegenerator.flow_from_directory(Train,target_size=(IMAGE_SIZE, IMAGE_SIZE),batch_size=BATCH_SIZE,class_mode='categorical',subset='validation'
)# Test data generator (no augmentation)
test_imagegenerator = ImageDataGenerator(rescale=1.0 / 255.0)
Test_Data = test_imagegenerator.flow_from_directory(Test,target_size=(IMAGE_SIZE,IMAGE_SIZE),batch_size=BATCH_SIZE,class_mode='categorical',# subset='test'
)
class ResidualBlock(layers.Layer):def __init__(self, filters, kernel_size=(3, 3), strides=1):super(ResidualBlock, self).__init__()self.conv1 = layers.Conv2D(filters, kernel_size, strides=strides, padding="same", activation='relu')self.conv2 = layers.Conv2D(filters, kernel_size, strides=1, padding='same', activation='relu')self.shortcut = layers.Conv2D(filters, (1, 1), strides=strides, padding='same', activation='relu')self.bn1 = layers.BatchNormalization()self.bn2 = layers.BatchNormalization()self.relu = layers.ReLU()def call(self, inputs):x = self.conv1(inputs)x = self.bn1(x)x = self.relu(x)x = self.conv2(x)x = self.bn2(x)shortcut = self.shortcut(inputs)x = layers.add([x, shortcut])x = self.relu(x)return x# ResNet Model definition
class ResNetModel(layers.Layer):def __init__(self):super(ResNetModel, self).__init__()self.conv1 = layers.Conv2D(32, (5, 5), activation='relu', input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),padding='same')self.maxpool1 = layers.MaxPooling2D((2, 2))# Residual Blocksself.resblock1 = ResidualBlock(32,strides=1)self.resblock2 = ResidualBlock(64,strides=2)self.resblock3 = ResidualBlock(128,strides=2)self.resblock4 = ResidualBlock(256, strides=2)# self.global_avg_pool = layers.GlobalAveragePooling2D()def call(self, inputs):print(inputs.shape)x = self.conv1(inputs)print(x.shape)x = self.maxpool1(x)print(x.shape)# Apply Residual Blocksx = self.resblock1(x)print(x.shape)x = self.resblock2(x)print(x.shape)# x = self.resblock3(x)# print(x.shape)# x = self.resblock4(x)# x = self.global_avg_pool(x)# print(x.shape)return x
class TransformerEncoder(layers.Layer):def __init__(self, num_heads=8, key_dim=64, ff_dim=256, dropout_rate=0.1):super(TransformerEncoder, self).__init__()self.attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)self.dropout1 = layers.Dropout(dropout_rate)self.norm1 = layers.LayerNormalization()self.ff = layers.Dense(ff_dim, activation='relu')self.ff_output = layers.Dense(key_dim*num_heads)self.dropout2 = layers.Dropout(dropout_rate)self.norm2 = layers.LayerNormalization()def call(self, x):# Multi-head self-attentionattention_output = self.attention(x, x)attention_output = self.dropout1(attention_output)x = self.norm1(attention_output + x)  # Residual connection# Feed Forward Networkff_output = self.ff(x)ff_output = self.ff_output(ff_output)ff_output = self.dropout2(ff_output)x = self.norm2(ff_output + x)  # Residual connectionreturn x# Vision Transformer (ViT) 模型
class VisionTransformer(models.Model):def __init__(self, input_shape=(100, 100, 3), num_classes=141, num_encoders=3, patch_size=8, num_heads=16,key_dim=4, ff_dim=256, dropout_rate=0.2):super(VisionTransformer, self).__init__()self.patch_size = patch_size#Resnetself.resnet=ResNetModel()# Patch Embeddingself.conv = layers.Conv2D(64, (patch_size, patch_size), strides=(patch_size, patch_size), padding='valid')self.reshape = layers.Reshape((-1, 64))self.norm = layers.LayerNormalization()# 位置编码层self.position_encoding = self.add_weight("position_encoding", shape=(1, 625, 64))# Stack multiple Transformer Encoder layersself.encoders = [TransformerEncoder(num_heads=num_heads, key_dim=key_dim, ff_dim=ff_dim, dropout_rate=dropout_rate) for _ inrange(num_encoders)]# Global Average Poolingself.global_avg_pooling = layers.GlobalAveragePooling1D()# Fully connected layerself.fc1 = layers.Dense(256, activation='relu')self.dropout = layers.Dropout(0.2)self.fc2 = layers.Dense(num_classes, activation='softmax')def call(self, inputs):#resnetx = self.resnet(inputs)# print("===========================")# print(x.shape)# Patch Embeddingx = self.reshape(x)# 添加位置编码x = x + self.position_encoding  # 将位置编码加到Patch嵌入向量中# print(x.shape)# x = self.norm(x)# Apply multiple Transformer encodersfor encoder in self.encoders:x = encoder(x)# Global Average Poolingx = self.global_avg_pooling(x)# Fully connected layersx = self.fc1(x)x = self.dropout(x)x = self.fc2(x)return x
# 构建 Vision Transformer 模型vit_model = VisionTransformer(input_shape=(100, 100, 3), num_classes=141, num_encoders=3)
vit_model.build(input_shape=(None, IMAGE_SIZE, IMAGE_SIZE, 3))  # 手动构建模型
# 打印模型摘要
vit_model.summary()
# 编译模型
vit_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),loss='categorical_crossentropy',metrics=['accuracy']
)
checkpoint_path = "training_checkpoints_1/vit_model_checkpoint_epoch_{epoch:02d}.h5"# 创建ModelCheckpoint回调
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,monitor='val_accuracy',  # 你可以选择监控验证集的损失或准确度save_best_only=True,  # 只保存验证集损失最小的模型save_weights_only=True,  # 只保存权重(而不是整个模型)verbose=1  # 打印日志
)
# 检查是否有保存的模型权重文件
checkpoint_dir = "training_checkpoints_1/"
# 查找所有的 .h5 文件
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "vit_model_checkpoint_epoch_*.h5"))
# print(latest_checkpoint)
if checkpoint_files:# 使用 os.path.getctime() 获取文件创建时间(或者使用 getmtime() 获取修改时间)latest_checkpoint = max(checkpoint_files, key=os.path.getctime)print(f"Loading model from checkpoint: {latest_checkpoint}")# 加载模型权重vit_model.load_weights(latest_checkpoint)
else:print("No checkpoint found, starting from scratch.")# 训练模型
history = vit_model.fit(Train_Data,epochs=20,validation_data=Validation_Data,shuffle=True,callbacks=[checkpoint_callback]
)# 评估模型
test_loss, test_acc = vit_model.evaluate(Test_Data)
print(f"Test Loss: {test_loss}")
print(f"Test Accuracy: {test_acc}")
# 训练和验证的准确率和损失历史记录
def plot_training_history(history):# 创建子图plt.figure(figsize=(14, 6))# 准备训练准确率和验证准确率的图plt.subplot(1, 2, 1)plt.title('Accuracy History')plt.xlabel('Epochs')plt.ylabel('Accuracy')plt.plot(history.history['accuracy'], label='Training Accuracy', marker='o')plt.plot(history.history['val_accuracy'], label='Validation Accuracy', color='green', marker='o')plt.legend()# 准备训练损失和验证损失的图plt.subplot(1, 2, 2)plt.title('Loss History')plt.xlabel('Epochs')plt.ylabel('Loss')plt.plot(history.history['loss'], label='Training Loss', marker='o')plt.plot(history.history['val_loss'], label='Validation Loss', color='green', marker='o')plt.legend()# 显示图形plt.tight_layout()plt.show()# 绘制训练过程
plot_training_history(history)
for i in range(16):# 获取测试数据的下一个批次img_batch, labels_batch = Test_Data.next()img = img_batch[0]  # 获取当前批次的第一张图像true_label_idx = np.argmax(labels_batch[0])  # 获取真实标签的索引# 获取真实标签的名称true_label = [key for key, value in Train_Data.class_indices.items() if value == true_label_idx]# 扩展维度以匹配模型输入EachImage = np.expand_dims(img, axis=0)# 进行预测prediction = vit_model.predict(EachImage)# 获取预测标签predicted_label = [key for key, value in Train_Data.class_indices.items() ifvalue == np.argmax(prediction, axis=1)[0]]# 获取预测的概率predicted_prob = np.max(prediction, axis=1)[0]# 绘制图像plt.subplot(4, 4, i + 1)plt.imshow(img)plt.title(f"True: {true_label[0]} \nPred: {predicted_label[0]} \nProb: {predicted_prob:.2f}")plt.axis('off')plt.tight_layout()
plt.show()

做了如下参数实验

ResNet层数

Encoder层数

num_heads

test_accuracy

2(32,64)

3

4

92.14%

3(32,64,128)

3

4

94.53%

2(32,64)

3

8

96.19%

3(32,64,128)

3

8

97.46%

2(32,64)

3

16

93.32%

3(32,64,128)

3

16

93.17%

 分类效果图

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/76982.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu24.04安装libgl1-mesa-glx 报错,软件包缺失

在 Ubuntu 24.04 系统中,您遇到的 libgl1-mesa-glx 软件包缺失问题可能是由于该包在最新的 Ubuntu 版本中被重命名为 libglx-mesa0。以下是针对该问题的详细解决方案: 1. 问题原因分析 包名称变更:在 Ubuntu 24.04 中,libgl1-me…

webpack vite

​ 1、webpack webpack打包工具(重点在于配置和使用,原理并不高优。只在开发环境应用,不在线上环境运行),压缩整合代码,让网页加载更快。 前端代码为什么要进行构建和打包? 体积更好&#x…

如何在爬虫中合理使用海外代理?在爬虫中合理使用海外ip

我们都知道,爬虫工作就是在各类网页中游走,快速而高效地采集数据。然而如果目标网站分布在多个国家或者存在区域性限制,那靠普通的网络访问可能会带来诸多阻碍。而这时,“海外代理”俨然成了爬虫工程师们的得力帮手! …

数据仓库分层存储设计:平衡存储成本与查询效率

数据仓库分层存储不仅是一个技术问题,更是一种艺术:如何在有限的资源下,让数据既能快速响应查询,又能以最低的成本存储? 目录 一、什么是数据仓库分层存储? 二、分层存储的体系架构 1. 数据源层(ODS,Operational Data Store) 2. 数据仓库层(DW,Data Warehouse)…

YOLO学习笔记 | 基于YOLOv8的植物病害检测系统

以下是基于YOLOv8的植物病害检测系统完整技术文档,包含原理分析、数学公式推导及代码实现框架。 基于YOLOv8的智能植物病害检测系统研究 摘要 针对传统植物病害检测方法存在的效率低、泛化性差等问题,本研究提出一种基于改进YOLOv8算法的智能检测系统。通过设计轻量化特征提…

高级语言调用C接口(二)回调函数(4)Python

前面2篇分别说了java和c#调用C接口,参数为回调函数,回调函数中参数是结构体指针。 接下来说下python的调用方法。 from ctypes import * import sysclass stPayResult(Structure):_pack_ 4 # 根据实际C结构体的对齐方式设置(常见值为1,4,…

springboot启动动态定时任务

1.自定义定时任务线程池 package com.x.devicetcpserver.global.tcp.tcpscheduler;import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.annotatio…

pytorch框架认识--手写数字识别

手写数字是机器学习中非常经典的案例,本文将通过pytorch框架,利用神经网络来实现手写数字识别 pytorch中提供了手写数字的数据集,我们可以直接从pytorch中下载 MNIST中包含70000张手写数字图像:60000张用于训练,10000…

WPF 使用依赖注入后关闭窗口程序不结束

原因是在ViewModel中在构造函数中注入了Window 对象,即使没有使用,主窗口关闭程序不会退出,即使 ViewModel 是 AddTransient 注入的。 解决方法:不使用构造函数注入Window,通过GetService获取Window 通过注入对象调用…

用户管理(添加和删除,查询信息,切换用户,查看登录用户,用户组,配置文件)

目录 添加和删除用户 查询用户信息 切换用户 查看当前的操作用户是谁 查看首次登录的用户是谁 用户组(对属于同个角色的用户统一管理) 新增组 删除组 添加用户的同时,指定组 修改用户的组 组的配置文件(/etc/group&…

PyTorch学习-小土堆教程

网络搭建torch.nn.Module 卷积操作 torch.nn.functional.conv2d(input, weight, biasNone, stride1, padding0, dilation1, groups1) 神经网络-卷积层

MVCC详细介绍及面试题

目录 1.什么是mvcc? 2.问题引入 3. MVCC实现原理? 3.1 隐藏字段 3.2 undo log 日志 3.2.1 undo log版本链 3.3 readview 3.3.1 当前读 ​编辑 3.3.2 快照读 3.3.3 ReadView中4个核心字段 3.3.4 版本数据链访问的规则(了解&#x…

企业级Active Directory架构设计与运维管理白皮书

企业级Active Directory架构设计与运维管理白皮书 第一章 多域架构设计与信任管理 1.1 企业域架构拓扑设计 1.1.1 林架构设计规范 林根域规划原则: 采用三段式域名结构(如corp.enterprise.com),避免使用不相关的顶级域名架构主…

android11 DevicePolicyManager浅析

目录 📘 简单定义 📘应用启用设备管理者 📂 文件位置 🧠 DevicePolicyManager 功能分类举例 🛡️ 1. 安全策略控制 📷 2. 控制硬件功能 🧰 3. 应用管理 🔒 4. 用户管理 &am…

Java学习手册:Java线程安全与同步机制

在Java并发编程中,线程安全和同步机制是确保程序正确性和数据一致性的关键。当多个线程同时访问共享资源时,如果不加以控制,可能会导致数据不一致、竞态条件等问题。本文将深入探讨Java中的线程安全问题以及解决这些问题的同步机制。 线程安…

PyTorch核心函数详解:gather与where的实战指南

PyTorch中的torch.gather和torch.where是处理张量数据的关键工具,前者实现基于索引的灵活数据提取,后者完成条件筛选与动态生成。本文通过典型应用场景和代码演示,深入解析两者的工作原理及使用技巧,帮助开发者提升数据处理的灵活…

声学测温度原理解释

已知声速,就可以得到温度。 不同温度下的胜诉不同。 25度的声速大约346m/s 绝对温度-273度 不同温度下的声速。 FPGA 通过测距雷达测温度,固定测量距离,或者可以测出当前距离。已知距离,然后雷达发出声波到接收到回波的时间&a…

【网络篇】UDP协议的封装分用全过程

大家好呀 我是浪前 今天讲解的是网络篇的第二章:UDP协议的封装分用 我们的协议最开始是OSI七层网络协议 这个OSI 七层网络协议 是计算机的大佬写的,但是这个协议一共有七层,太多了太麻烦了,于是我们就把这个七层网络协议就简化为…

spring-ai-alibaba使用Agent实现智能机票助手

示例目标是使用 Spring AI Alibaba 框架开发一个智能机票助手,它可以帮助消费者完成机票预定、问题解答、机票改签、取消等动作,具体要求为: 基于 AI 大模型与用户对话,理解用户自然语言表达的需求支持多轮连续对话,能…

嵌入式C语言高级编程:OOP封装、TDD测试与防御性编程实践

一、面向对象编程(OOP) 尽管 C 语言并非面向对象编程语言,但借助一些编程技巧,也能实现面向对象编程(OOP)的核心特性,如封装、继承和多态。 1.1 封装 封装是把数据和操作数据的函数捆绑在一起,对外部隐藏…