基于卷积神经网络(CNN)和ResNet50的水果与蔬菜图像分类系统

前言

在现代智能生活中,计算机视觉技术已经成为不可或缺的工具,特别是在食物识别领域。想象一下,您只需拍摄一张水果或蔬菜的照片,系统就能自动识别其种类并为您提供丰富的食谱建议。这项技术不仅在日常生活中极具实用性,在农业、食品配送及健康监测等多个行业中也有着广泛的应用。

本文展示了一个基于深度学习的水果与蔬菜分类系统,采用了强大的卷积神经网络(CNN)和先进的数据增强技术,能够在各种复杂环境下准确识别出不同的水果和蔬菜种类。通过使用预训练的ResNet50模型和混合精度训练,系统优化了训练过程的效率和准确度,并且引入了OneCycleLR学习率调度策略,以确保最佳的学习速度。

无论是在个人项目、商业应用,还是在未来的食品识别系统中,本项目都能为您提供强有力的技术支持。通过本代码,您将能够实现从数据加载、模型训练到最终预测的完整流程,轻松将深度学习应用到食品识别的各个方面。

让我们一起探索这个强大的工具,如何帮助我们实现更智能的生活!

概述

本项目实现了一个基于深度学习的水果和蔬菜识别系统,旨在通过计算机视觉技术对图像中的食品进行分类。系统的核心基于卷积神经网络(CNN)架构,结合了数据增强技术、预训练模型、混合精度训练和学习率调度等先进策略,以提高训练效率和分类准确度。

主要功能:

  1. 数据预处理与增强:使用图像预处理技术(如调整大小、随机旋转、颜色调整等)对输入数据进行增强,提高模型的鲁棒性和泛化能力。
  2. 自定义数据集:通过FruitVegDataset类构建自定义数据集,支持从指定路径加载和标记图像,并能够方便地应用图像转换。
  3. 深度学习模型:利用卷积神经网络(CNN)进行特征提取,并通过ResNet50预训练模型提升识别能力。该模型经过优化,具有较强的表现力,能够识别多达36类水果和蔬菜。
  4. 训练与验证:通过使用AdamW优化器、交叉熵损失函数以及OneCycleLR学习率调度器,优化了训练过程。采用了混合精度训练(Mixed Precision Training)以加速训练过程,同时减少显存使用。
  5. 预测与应用:训练好的模型可用于实时图像预测,用户只需上传一张水果或蔬菜的图片,系统即可返回预测结果,并展示分类的概率信息。

系统特点:

  • 高效训练:通过学习率调度和优化器调整,训练过程不仅更加高效,还能提升模型在验证集上的准确度。
  • 增强现实应用:该模型能够应用于餐厅菜单识别、农业监测、食品配送、健康管理等实际场景,具有较高的商业和应用价值。
  • 简易部署:训练后的模型可以轻松部署到各类应用中,包括移动端应用或web端服务,使得实时食品识别变得更加便捷。

本项目展示了如何通过深度学习技术实现水果和蔬菜的自动分类,推动了食品识别领域的进一步发展,同时为智能农业、健康饮食等领域提供了有力的技术支持。

ResNet50模型介绍

ResNet50 是一种深度残差网络(Residual Network),由微软研究院的何恺明等人于2015年提出。它是ResNet系列中的一个重要变体,具有50层深度,广泛用于计算机视觉任务,如图像分类、目标检测和语义分割。ResNet50的核心思想是引入残差连接(Residual Connections),即通过跳跃连接(skip connections)直接将输入添加到输出,从而解决深层网络中的梯度消失和梯度爆炸问题,促进更深层次网络的训练。

ResNet50的特点
  1. 残差连接(Residual Connections)

    • 传统的深层网络容易出现梯度消失或梯度爆炸的问题,使得训练变得困难。ResNet通过引入残差连接,将输入数据直接跳跃到输出端,形成“捷径”(shortcut)。这使得网络能够更容易地学习到残差(输入和输出的差值),而非直接学习整个映射函数。
    • 这种设计可以有效避免深层网络中的退化问题,提升网络的训练效率和性能。
  2. 深度网络结构

    • ResNet50的深度为50层,采用了多个卷积层(Convolutional Layers)批量归一化层(Batch Normalization),通过堆叠的方式构成深层的神经网络。每一层的输出与输入之间通过跳跃连接直接相加,简化了网络的训练过程。
    • ResNet50相比于其它较浅的网络(如ResNet18、ResNet34)提供了更多的学习能力,能够学习到更复杂的特征。
  3. 残差模块(Residual Block)

    • 在ResNet50中,残差模块是由多个卷积层和残差连接组成的。通常,一个残差模块包括两到三层卷积,每层后跟一个批量归一化层和ReLU激活函数。
    • 每个模块通过1x1卷积(通常用于减少或恢复通道数)与输入建立直接的跳跃连接,最终将输入和输出相加。
    • 通过残差模块,ResNet能够在避免过拟合的情况下训练非常深的网络,并保持较高的准确率。
  4. 瓶颈结构(Bottleneck Architecture)

    • ResNet50采用了瓶颈结构,即每个残差块包含三个卷积层:一个1x1卷积层(用于降低维度),一个3x3卷积层(用于特征提取),以及一个1x1卷积层(用于恢复维度)。
    • 这种结构有效减少了计算量,并且提高了网络的效率。相比于普通的卷积层,瓶颈结构大大减少了参数数量和计算量,使得网络能够在有限的硬件资源上运行得更加高效。
  5. 跳跃连接的应用

    • ResNet50的最大创新之一就是其跳跃连接,它允许信号在网络中传递得更远。每个跳跃连接将前一层的输出与当前层的输出相加,生成最终的输出,这样有助于更容易地训练更深的网络,减少了网络中的退化问题。
    • 通过这种方式,网络不仅可以学习到更复杂的特征,还能够避免梯度在反向传播中的衰减。
  6. 预训练和迁移学习

    • ResNet50常常用作预训练模型,尤其在迁移学习中非常流行。通过在大规模数据集(如ImageNet)上进行预训练,ResNet50能够学习到通用的图像特征,这些特征可以迁移到其他特定的任务上,从而提高目标任务的性能。
    • 由于其出色的特征提取能力,ResNet50作为特征提取器在许多计算机视觉任务中表现出色,并且能够显著减少训练时间。
  7. 较低的计算成本

    • ResNet50相较于更深的网络(如ResNet101、ResNet152)在保持高性能的同时,计算成本相对较低。50层深度的网络结构相较于更深的变体,参数和计算量适中,适合于资源受限的环境。
ResNet50的应用
  • 图像分类:ResNet50被广泛用于图像分类任务,特别是在ImageNet等大规模数据集上训练后,能够为图像提供强大的特征表示。它在ImageNet挑战赛中表现出色,取得了很高的准确率。
  • 目标检测与语义分割:通过结合其它架构(如Faster R-CNN、Mask R-CNN),ResNet50也常用于目标检测和语义分割任务,提取高质量的特征来帮助检测和分割任务。
  • 迁移学习:由于其优异的特征提取能力,ResNet50常作为迁移学习模型的基础,能够应用于医疗图像分析、面部识别、视频分析等领域。
    在这里插入图片描述

模型的核心逻辑

本项目采用了基于深度学习的卷积神经网络(CNN)来进行水果与蔬菜分类任务。具体的核心逻辑包括以下几个部分:

1. 使用预训练模型作为特征提取器

核心的模型结构基于ResNet50,该模型在ImageNet上预训练过,已经学到了有效的图像特征。因此,在我们的任务中,ResNet50能够有效地提取水果和蔬菜图像中的低层次和高层次特征。

  • 冻结部分层:为了减少计算量,并且避免在较少的数据集上过拟合,我们选择冻结ResNet50模型的前30层(即不更新这些层的权重)。这使得模型能够专注于学习更高层次的特征,而不需要重新学习基础的图像特征。
self.backbone = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)  
for param in list(self.backbone.parameters())[:-30]:  param.requires_grad = False  
  • 替换全连接层:ResNet50的原始全连接层被替换成自定义的全连接层,这一层是针对水果和蔬菜分类任务进行设计的。通过新的全连接层将提取到的特征映射到目标类别(水果与蔬菜类别)。
self.backbone.fc = nn.Sequential(  nn.Linear(num_features, 1024),  nn.BatchNorm1d(1024),  nn.ReLU(inplace=True),  nn.Dropout(0.3),  nn.Linear(1024, 512),  nn.BatchNorm1d(512),  nn.ReLU(inplace=True),  nn.Dropout(0.3),  nn.Linear(512, num_classes)  
)  
2. 数据增强与预处理

为了增加训练数据的多样性,减少模型的过拟合,输入图像经过了一系列的数据增强操作。这些操作包括:

  • 缩放、裁剪:通过随机缩放、随机裁剪等操作确保模型能够应对不同尺度的图像。
  • 旋转与翻转:通过随机旋转、水平和垂直翻转等,增强模型的鲁棒性。
  • 颜色抖动:对图像的亮度、对比度、饱和度等进行随机变化,以增加模型对颜色变化的适应性。

这些数据增强方法提高了模型在未见数据上的泛化能力。

train_transform = transforms.Compose([  transforms.Resize((256, 256)),  transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),  transforms.RandomHorizontalFlip(),  transforms.RandomVerticalFlip(),  transforms.RandomRotation(20),  transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  transforms.ToTensor(),  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  
])  
3. 模型训练与优化

训练过程中,使用了以下几个重要的技术:

  • OneCycleLR学习率调度器:为了加速训练过程并避免过拟合,使用了OneCycleLR学习率调度器,它帮助在训练初期增加学习率,然后逐渐减小,以使模型收敛得更快并且避免在训练结束时陷入局部最优解。
scheduler = OneCycleLR(  optimizer,  max_lr=config.learning_rate,  epochs=config.epochs,  steps_per_epoch=len(train_loader),  pct_start=0.1,  anneal_strategy='cos'  
)  
  • 优化器:使用了AdamW优化器,它是一种基于自适应估计的优化方法,适合深度学习任务。通过AdamW优化器,我们能够有效地更新模型参数。

  • 混合精度训练:为了提高训练效率和减少显存占用,使用了PyTorch的混合精度训练(autocastGradScaler)。这使得在计算过程中部分操作使用半精度浮点数(FP16),以提高速度和节省内存,同时保持较高的精度。

with autocast():  outputs = model(inputs)  loss = criterion(outputs, labels)  
4. 损失函数与评估
  • 损失函数:使用了交叉熵损失(Cross-Entropy Loss)作为训练的目标函数,因为它适用于多类别分类任务。模型通过最小化交叉熵损失来优化其分类精度。
criterion = nn.CrossEntropyLoss()  
  • 评估指标:除了损失函数,训练过程中还监控了准确率(Accuracy),即模型在给定的测试集上的分类正确率。通过准确率来评估模型的性能,并在训练过程中选择最优的模型。
5. 模型预测与推断

训练完成后,模型可以用于对新的图像进行预测。输入图像首先经过相同的数据预处理和增强(例如调整大小、规范化等),然后输入到训练好的模型中,得到模型的预测输出。

模型输出的结果通过softmax函数转化为每个类别的概率值,最终返回最可能的类别及其对应的概率。

def predict_image(url, model):  response = requests.get(url)  image = Image.open(BytesIO(response.content)).convert('RGB')  input_tensor = transform(image).unsqueeze(0)  with torch.no_grad():  output = model(input_tensor)  probabilities = torch.nn.functional.softmax(output[0], dim=0)  predicted_class = torch.argmax(probabilities).item()  return predicted_class, probabilities[predicted_class].item()  

代码实现

1. 设置随机种子和设备

为了保证结果的可重复性,我们设置了随机种子。然后确定是否使用GPU,如果GPU可用,则使用GPU,否则使用CPU。

!pip install ultralytics -i  https://mirrors.aliyun.com/pypi/simple/ numpy
!pip install albumentations -i  https://mirrors.aliyun.com/pypi/simple/ numpy
!pip install timm -i  https://mirrors.aliyun.com/pypi/simple/ numpy
!pip install wandb -i  https://mirrors.aliyun.com/pypi/simple/ numpy
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import random# Set seeds for reproducibility
def set_seed(seed=42):torch.manual_seed(seed)torch.cuda.manual_seed_all(seed)np.random.seed(seed)random.seed(seed)torch.backends.cudnn.deterministic = Truetorch.backends.cudnn.benchmark = Falseset_seed()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")# Create folder for saving results
os.makedirs('results', exist_ok=True)
2. 数据集展示

这一部分的代码用来展示数据集的结构,打印数据集的类和图像数量,并随机展示一些训练集的图像。

def explore_data(data_path):"""Explore and visualize the dataset"""print("\nExploring Dataset Structure:")print("-" * 50)splits = ['train', 'validation', 'test']for split in splits:split_path = os.path.join(data_path, split)if os.path.exists(split_path):classes = sorted(os.listdir(split_path))total_images = sum(len(os.listdir(os.path.join(split_path, cls))) for cls in classes)print(f"\n{split.capitalize()} Set:")print(f"Number of classes: {len(classes)}")print(f"Total images: {total_images}")print(f"Example classes: {', '.join(classes[:5])}...")# Visualize sample imagesprint("\nVisualizing Sample Images...")train_path = os.path.join(data_path, 'train')classes = sorted(os.listdir(train_path))plt.figure(figsize=(15, 10))for i in range(9):class_name = random.choice(classes)class_path = os.path.join(train_path, class_name)img_name = random.choice(os.listdir(class_path))img_path = os.path.join(class_path, img_name)img = Image.open(img_path)plt.subplot(3, 3, i+1)plt.imshow(img)plt.title(f'Class: {class_name}')plt.axis('off')plt.tight_layout()plt.savefig('results/sample_images.png')plt.show()# Explore dataset
data_path = "/home/mw/input/Fruit1112533/Fruits and Vegetables Image Recognition Dataset"
explore_data(data_path)

在这里插入图片描述
在这里插入图片描述

3. 自定义数据集类

这部分代码定义了一个自定义的PyTorch Dataset 类,FruitVegDataset,用于加载数据集,并支持图像的转换(如缩放、裁剪等)。

class FruitVegDataset(Dataset):def __init__(self, root_dir, split='train', transform=None):self.root_dir = os.path.join(root_dir, split)self.transform = transformself.classes = sorted(os.listdir(self.root_dir))self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}self.images = []self.labels = []for class_name in self.classes:class_path = os.path.join(self.root_dir, class_name)for img_name in os.listdir(class_path):if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):self.images.append(os.path.join(class_path, img_name))self.labels.append(self.class_to_idx[class_name])def __len__(self):return len(self.images)def __getitem__(self, idx):img_path = self.images[idx]label = self.labels[idx]image = Image.open(img_path).convert('RGB')if self.transform:image = self.transform(image)return image, label
4. 数据增强和预处理

这里定义了数据增强和预处理流程。使用了常见的数据增强方法,如随机水平翻转、随机旋转、颜色抖动等。并且对图像进行标准化处理。

# Define transforms
train_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.RandomHorizontalFlip(),transforms.RandomRotation(15),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])val_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])# Visualize augmentations
def show_augmentations(dataset, num_augments=5):"""Show original image and its augmented versions"""idx = random.randint(0, len(dataset)-1)img_path = dataset.images[idx]original_img = Image.open(img_path).convert('RGB')plt.figure(figsize=(15, 5))# Show originalplt.subplot(1, num_augments+1, 1)plt.imshow(original_img)plt.title('Original')plt.axis('off')# Show augmented versionsfor i in range(num_augments):augmented = train_transform(original_img)augmented = augmented.permute(1, 2, 0).numpy()augmented = (augmented * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406])augmented = np.clip(augmented, 0, 1)plt.subplot(1, num_augments+1, i+2)plt.imshow(augmented)plt.title(f'Augmented {i+1}')plt.axis('off')plt.tight_layout()plt.savefig('results/augmentations.png')plt.show()# Create datasets and show augmentations
train_dataset = FruitVegDataset(data_path, 'train', train_transform)
show_augmentations(train_dataset)

在这里插入图片描述

5. 卷积块和网络结构

这一部分代码定义了一个卷积块(ConvBlock)和一个自定义的卷积神经网络(FruitVegCNN)用于图像分类。

class ConvBlock(nn.Module):def __init__(self, in_channels, out_channels):super().__init__()self.conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, 3, padding=1),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.MaxPool2d(2))def forward(self, x):return self.conv(x)class FruitVegCNN(nn.Module):def __init__(self, num_classes):super().__init__()self.features = nn.Sequential(ConvBlock(3, 64),ConvBlock(64, 128),ConvBlock(128, 256),ConvBlock(256, 512),ConvBlock(512, 512))self.classifier = nn.Sequential(nn.AdaptiveAvgPool2d(1),nn.Flatten(),nn.Dropout(0.5),nn.Linear(512, 256),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(256, num_classes))def forward(self, x):x = self.features(x)x = self.classifier(x)return x# Function to visualize feature maps
def visualize_feature_maps(model, sample_image):"""Visualize feature maps after each conv block"""model.eval()# Get feature maps after each conv blockfeature_maps = []x = sample_image.unsqueeze(0).to(device)for block in model.features:x = block(x)feature_maps.append(x.detach().cpu())# Plot feature mapsplt.figure(figsize=(15, 10))for i, fmap in enumerate(feature_maps):# Plot first 6 channels of each blockfmap = fmap[0][:6].permute(1, 2, 0)fmap = (fmap - fmap.min()) / (fmap.max() - fmap.min())for j in range(min(6, fmap.shape[-1])):plt.subplot(5, 6, i*6 + j + 1)plt.imshow(fmap[:, :, j], cmap='viridis')plt.title(f'Block {i+1}, Ch {j+1}')plt.axis('off')plt.tight_layout()plt.savefig('results/feature_maps.png')plt.show()# Initialize model and visualize feature maps
model = FruitVegCNN(num_classes=len(train_dataset.classes)).to(device)
sample_image, _ = train_dataset[0]
visualize_feature_maps(model, sample_image)

在这里插入图片描述

6. 训练和验证函数

定义了训练(train_one_epoch)和验证(validate)函数。这些函数在每个epoch中更新模型权重,并计算损失和准确率。

def train_one_epoch(model, train_loader, criterion, optimizer, device):model.train()running_loss = 0.0correct = 0total = 0pbar = tqdm(train_loader, desc='Training')for inputs, labels in pbar:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()pbar.set_postfix({'loss': f'{loss.item():.4f}','acc': f'{100.*correct/total:.2f}%'})return running_loss / len(train_loader), 100. * correct / totaldef validate(model, val_loader, criterion, device):model.eval()running_loss = 0.0correct = 0total = 0with torch.no_grad():for inputs, labels in tqdm(val_loader, desc='Validation'):inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)loss = criterion(outputs, labels)running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()return running_loss / len(val_loader), 100. * correct / totaldef plot_training_progress(history):"""Plot and save training progress"""plt.figure(figsize=(12, 4))plt.subplot(1, 2, 1)plt.plot(history['train_loss'], label='Train Loss')plt.plot(history['val_loss'], label='Val Loss')plt.title('Loss History')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.subplot(1, 2, 2)plt.plot(history['train_acc'], label='Train Acc')plt.plot(history['val_acc'], label='Val Acc')plt.title('Accuracy History')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.legend()plt.tight_layout()plt.savefig('results/training_progress.png')plt.show()
7. 训练与验证过程

在此部分代码中,我们定义了训练和验证的数据加载器,并设置了模型训练的相关配置。使用CrossEntropyLoss作为损失函数,AdamW优化器来优化模型,同时设置了学习率调度器ReduceLROnPlateau以自动调整学习率。训练过程包括多轮的训练与验证,并在每个周期结束时记录和打印训练与验证的损失和准确率。此外,还会保存每个周期的模型权重并在验证准确率提高时保存最佳模型。

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_dataset = FruitVegDataset(data_path, 'validation', val_transform)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)# Training loop
num_epochs = 30
best_val_acc = 0
history = {'train_loss': [], 'train_acc': [],'val_loss': [], 'val_acc': []
}print("\nStarting training...")
for epoch in range(num_epochs):print(f'\nEpoch {epoch+1}/{num_epochs}')train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)val_loss, val_acc = validate(model, val_loader, criterion, device)# Update schedulerscheduler.step(val_loss)# Save historyhistory['train_loss'].append(train_loss)history['train_acc'].append(train_acc)history['val_loss'].append(val_loss)history['val_acc'].append(val_acc)print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')# Plot progressif (epoch + 1) % 5 == 0:plot_training_progress(history)# Save best modelif val_acc > best_val_acc:best_val_acc = val_accprint(f'New best validation accuracy: {best_val_acc:.2f}%')torch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'best_acc': best_val_acc,}, 'results/best_model.pth')# Final training visualization
plot_training_progress(history)

在这里插入图片描述
在这里插入图片描述

8. 绘制训练与验证的准确率与损失曲线

此部分代码用于可视化训练过程中模型的准确率和损失变化情况。通过绘制训练和验证集上的准确率与损失曲线,帮助我们直观地观察模型在不同训练周期中的表现。同时,代码会输出训练和验证过程中达到的最佳准确率,以便进一步分析模型的性能。

import matplotlib.pyplot as pltdef plot_accuracy_loss(history):"""Plot training and validation accuracy/loss curves"""plt.figure(figsize=(12, 4))# Plot Accuracyplt.subplot(1, 2, 1)plt.plot(history['train_acc'], label='Training', marker='o')plt.plot(history['val_acc'], label='Validation', marker='o')plt.title('Model Accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.legend()plt.grid(True)# Plot Lossplt.subplot(1, 2, 2)plt.plot(history['train_loss'], label='Training', marker='o')plt.plot(history['val_loss'], label='Validation', marker='o')plt.title('Model Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.grid(True)plt.tight_layout()plt.savefig('results/accuracy_loss_curves.png')plt.show()# Print best accuracy valuesbest_train_acc = max(history['train_acc'])best_val_acc = max(history['val_acc'])print(f"\nBest Training Accuracy: {best_train_acc:.2f}%")print(f"Best Validation Accuracy: {best_val_acc:.2f}%")# Plot the curves
plot_accuracy_loss(history)

在这里插入图片描述

9. 优化的训练配置与增强数据增强

此部分代码实现了一个优化的训练流程,主要包括改进的超参数配置、增强的数据预处理以及混合精度训练技术。通过使用 ResNet50 作为骨干网络,添加了逐层冻结策略、增强的分类器结构(带有Dropout和Batch Normalization)以及One Cycle Learning Rate调度器等技术,可以提升模型的训练效果和泛化能力。此外,训练过程中应用了混合精度训练来加速计算并减少显存占用,进一步优化了训练过程。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
from torch.cuda.amp import autocast, GradScaler# Improved training configurations
class OptimizedConfig:def __init__(self):self.image_size = 256  # Increased from 224self.batch_size = 16   # Smaller batch size for better generalizationself.learning_rate = 3e-4self.weight_decay = 0.01self.epochs = 50self.dropout = 0.3# Enhanced data augmentation
train_transform = transforms.Compose([transforms.Resize((256, 256)),transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.RandomRotation(20),transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomApply([transforms.GaussianBlur(3)], p=0.3),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])# Optimized model architecture
class OptimizedCNN(nn.Module):def __init__(self, num_classes):super().__init__()# Use pretrained ResNet50 as backboneself.backbone = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)# Freeze early layersfor param in list(self.backbone.parameters())[:-30]:param.requires_grad = False# Modified classifiernum_features = self.backbone.fc.in_featuresself.backbone.fc = nn.Sequential(nn.Linear(num_features, 1024),nn.BatchNorm1d(1024),nn.ReLU(inplace=True),nn.Dropout(0.3),nn.Linear(1024, 512),nn.BatchNorm1d(512),nn.ReLU(inplace=True),nn.Dropout(0.3),nn.Linear(512, num_classes))def forward(self, x):return self.backbone(x)# Optimized training function
def train_with_optimization(model, train_loader, val_loader, config):criterion = nn.CrossEntropyLoss(label_smoothing=0.1)optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)# One Cycle Learning Rate Schedulerscheduler = OneCycleLR(optimizer,max_lr=config.learning_rate,epochs=config.epochs,steps_per_epoch=len(train_loader),pct_start=0.1,anneal_strategy='cos')# Gradient Scaler for mixed precision trainingscaler = GradScaler()history = {'train_loss': [], 'train_acc': [],'val_loss': [], 'val_acc': []}best_val_acc = 0for epoch in range(config.epochs):# Trainingmodel.train()train_loss = 0correct = 0total = 0pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config.epochs}')for inputs, labels in pbar:inputs, labels = inputs.to(device), labels.to(device)optimizer.zero_grad()# Mixed precision trainingwith autocast():outputs = model(inputs)loss = criterion(outputs, labels)scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()scheduler.step()train_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()pbar.set_postfix({'loss': f'{loss.item():.4f}','acc': f'{100.*correct/total:.2f}%','lr': f'{scheduler.get_last_lr()[0]:.6f}'})train_acc = 100. * correct / totaltrain_loss = train_loss / len(train_loader)# Validationmodel.eval()val_loss = 0correct = 0total = 0with torch.no_grad():for inputs, labels in tqdm(val_loader, desc='Validation'):inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)loss = criterion(outputs, labels)val_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()val_acc = 100. * correct / totalval_loss = val_loss / len(val_loader)# Save historyhistory['train_loss'].append(train_loss)history['train_acc'].append(train_acc)history['val_loss'].append(val_loss)history['val_acc'].append(val_acc)print(f'\nEpoch {epoch+1}/{config.epochs}:')print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')# Save best modelif val_acc > best_val_acc:best_val_acc = val_acctorch.save({'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'best_acc': best_val_acc,}, 'optimized_model.pth')print(f'New best validation accuracy: {best_val_acc:.2f}%')return history# Create dataloaders with optimized configuration
config = OptimizedConfig()
train_dataset = FruitVegDataset(data_path, 'train', train_transform)
val_dataset = FruitVegDataset(data_path, 'validation', val_transform)train_loader = DataLoader(train_dataset, batch_size=config.batch_size,shuffle=True,num_workers=4,pin_memory=True)
val_loader = DataLoader(val_dataset,batch_size=config.batch_size,shuffle=False,num_workers=4,pin_memory=True)# Initialize and train optimized model
model = OptimizedCNN(num_classes=len(train_dataset.classes)).to(device)
history = train_with_optimization(model, train_loader, val_loader, config)
10. 优化结果的可视化

此部分代码负责可视化优化后的训练和验证过程中的准确率与损失值。通过图表展示模型在训练和验证集上的表现,帮助评估优化策略的有效性。代码还输出了最佳的训练和验证准确率,便于进一步分析模型的性能。

def plot_optimized_results(history):plt.style.use('seaborn-v0_8')plt.figure(figsize=(15, 5))# Plot Accuracyplt.subplot(1, 2, 1)plt.plot(history['train_acc'], label='Training', marker='o')plt.plot(history['val_acc'], label='Validation', marker='o')plt.title('Model Accuracy with Optimizations')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.legend()plt.grid(True)# Plot Lossplt.subplot(1, 2, 2)plt.plot(history['train_loss'], label='Training', marker='o')plt.plot(history['val_loss'], label='Validation', marker='o')plt.title('Model Loss with Optimizations')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.grid(True)plt.tight_layout()plt.savefig('optimized_results.png', dpi=300, bbox_inches='tight')plt.show()# Print best metricsbest_train_acc = max(history['train_acc'])best_val_acc = max(history['val_acc'])print(f"\nBest Training Accuracy: {best_train_acc:.2f}%")print(f"Best Validation Accuracy: {best_val_acc:.2f}%")# Plot results
plot_optimized_results(history)

在这里插入图片描述

11. 模型加载与图像预测

这段代码提供了一个从URL加载图像并用训练好的模型进行预测的流程。首先,加载已保存的模型,并通过预处理步骤对图像进行转换,然后进行推理并展示前5个预测结果。

import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import requests
from io import BytesIO# Load the saved model
def load_model():# Check if model file existstry:# Load model checkpointcheckpoint = torch.load('optimized_model.pth')model = OptimizedCNN(num_classes=36)  # Same as trainingmodel.load_state_dict(checkpoint['model_state_dict'])model.eval()print("Model loaded successfully!")return modelexcept FileNotFoundError:print("Model file 'optimized_model.pth' not found!")return None# Prediction function
def predict_image(url, model):# Image preprocessingtransform = transforms.Compose([transforms.Resize((256, 256)),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# Load image from URLresponse = requests.get(url)image = Image.open(BytesIO(response.content)).convert('RGB')# Transform imageinput_tensor = transform(image).unsqueeze(0)# Make predictionwith torch.no_grad():output = model(input_tensor)probabilities = torch.nn.functional.softmax(output[0], dim=0)# Get top 5 predictionstop_probs, top_indices = torch.topk(probabilities, 5)# Show resultsplt.figure(figsize=(12, 4))# Show imageplt.subplot(1, 2, 1)plt.imshow(image)plt.title('Input Image')plt.axis('off')# Show predictionsplt.subplot(1, 2, 2)classes = sorted(os.listdir("/home/mw/input/Fruit1112533/Fruits and Vegetables Image Recognition Dataset/train"))y_pos = range(5)plt.barh(y_pos, [prob.item() * 100 for prob in top_probs])plt.yticks(y_pos, [classes[idx] for idx in top_indices])plt.xlabel('Probability (%)')plt.title('Top 5 Predictions')plt.tight_layout()plt.show()# Print predictionsprint("\nPredictions:")print("-" * 30)for i in range(5):print(f"{classes[top_indices[i]]:20s}: {top_probs[i]*100:.2f}%")# Load model
model = load_model()# Now you can use it like this:
predict_image('https://pngimg.com/uploads/watermelon/watermelon_PNG2640.png', model)

在这里插入图片描述

注意

# 需要完整代码以及数据集请点击以下链接:
https://mbd.pub/o/bread/mbd-Z5yclpZu

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/64655.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Tomcat部署war包项目解决404问题

问题出在了Tomcat的版本上了,应该先去看这个项目使用的springboot版本,然后去仓库里找到对应Tomcat版本。 Maven Repository: org.springframework.boot spring-boot-starter-tomcat 因此我们应该选择Tomcat9版本。 当我把Tomcat11换成Tomcat9时&…

Redis篇--常见问题篇1--缓存穿透(缓存空值,布隆过滤器,接口限流)

1、概述 缓存穿透是指客户端请求的数据既不在Redis缓存中,也不在数据库中。换句话说,缓存和数据库中都不存在该数据,但客户端仍然发起了查询请求。这种情况下,缓存无法命中,请求会直接穿透到数据库,而数据…

前端使用 Konva 实现可视化设计器(20)- 性能优化、UI 美化

这一章主要分享一下使用 Konva 遇到的性能优化问题,并且介绍一下 UI 美化的思路。 至少有 2 位小伙伴积极反馈,发现本示例有明显的性能问题,一是内存溢出问题,二是卡顿的问题,在这里感谢大家的提醒。 请大家动动小手&a…

BlueLM:以2.6万亿token铸就7B参数超大规模语言模型

一、介绍 BlueLM 是由 vivo AI 全球研究院自主研发的大规模预训练语言模型,本次发布包含 7B 基础 (base) 模型和 7B 对话 (chat) 模型,同时我们开源了支持 32K 的长文本基础 (base) 模型和对话 (chat) 模型。 更大量的优质数据 :高质量语料…

C语言基础16(文件IO)

文章目录 构造类型枚举类型typedef 文件操作(文件IO)概述文件的操作文件的打开与关闭打开文件关闭文件文件打开与关闭案例 文件的顺序读写单字符读取多字符读取单字符写入多字符写入 综合案例:文件拷贝判别文件结束 数据块的读写(二进制)数据块的读取数据块的写入 文…

冯诺依曼架构与哈佛架构的对比与应用

冯诺依曼架构(Von Neumann Architecture),也称为 冯诺依曼模型,是由著名数学家和计算机科学家约翰冯诺依曼(John von Neumann)在1945年提出的。冯诺依曼架构为现代计算机奠定了基础,几乎所有现代…

3D造型软件solvespace在windows下的编译

3D造型软件solvespace在windows下的编译 在逛开源社区的时候发现了几款开源CAD建模软件,一直囿于没有合适的建模软件,虽然了解了很多的模拟分析软件,却不能使之成为整体的解决方案,从而无法产生价值。opencascad之流虽然可行&…

机器学习04-为什么Relu函数

机器学习0-为什么Relu函数 文章目录 机器学习0-为什么Relu函数 [toc]1-手搓神经网络步骤总结2-为什么要用Relu函数3-进行L1正则化修改后的代码解释 4-进行L2正则化解释注意事项 5-Relu激活函数多有夸张1-细数Relu函数的5宗罪2-Relu函数5宗罪详述 6-那为什么要用这个Relu函数7-文…

QScreen在Qt5.15与Qt6.8版本下的区别

简述 QScreen主要用于提供与屏幕相关的信息。它可以获取有关显示设备的分辨率、尺寸、DPI(每英寸点数)等信息。本文主要是介绍Qt5.15与Qt6环境下,QScreen的差异,以及如何判断高DPI设备。 属性说明 logicalDotsPerInch&#xff1…

[HNCTF 2022 Week1]你想学密码吗?

下载附件用记事本打开 把这些代码放在pytho中 # encode utf-8 # python3 # pycryptodemo 3.12.0import Crypto.PublicKey as pk from hashlib import md5 from functools import reducea sum([len(str(i)) for i in pk.__dict__]) funcs list(pk.__dict__.keys()) b reduc…

shell8

until循环(条件为假的时候一直循环和while相反) i0 until [ ! $i -lt 10 ] doecho $i((i)) done分析 初始化变量: i0:将变量i初始化为0。 条件判断 (until 循环): until [ ! $i -lt 10 ]:这里的逻辑有些复杂。它使用了until循环…

【游戏中orika完成一个Entity的复制及其Entity异步落地的实现】 1.ctrl+shift+a是飞书下的截图 2.落地实现

一、orika工具使用 1)工具类 package com.xinyue.game.utils;import ma.glasnost.orika.MapperFactory; import ma.glasnost.orika.impl.DefaultMapperFactory;/*** author 王广帅* since 2022/2/8 22:37*/ public class XyBeanCopyUtil {private static MapperFactory mappe…

【十进制整数转换为其他进制数——短除形式的贪心算法】

之前写过一篇用贪心算法计算十进制转换二进制的方法,详见:用贪心算法计算十进制数转二进制数(整数部分)_短除法求二进制-CSDN博客 经过一段时间的研究,本人又发现两个规律: 1、不仅仅十进制整数转二进制可…

【Harmony Next】多个图文配合解释DevEco Studio工程中,如何配置App相关内容,一次解决多个问题?

解决App配置相关问题列表 1、Harmony Next如何配置图标? 2、Harmony Next如何配置App名称? 3、Harmony Next如何配置版本号? 4、Harmony Next如何配置Bundle ID? 5、Harmony Next如何配置build号? 6、Harmony Next多语言配置在哪…

Mybatis分页插件的使用问题记录

项目中配置的分页插件依赖为 <dependency><groupId>com.github.pagehelper</groupId><artifactId>pagehelper</artifactId><version>5.1.7</version></dependency>之前的项目代码编写分页的方式为&#xff0c;通过传入的条件…

【技术干货】移动SDK安全风险及应对策略

移动SDK&#xff08;软件开发工具包&#xff09;已经成为应用开发中不可或缺的一部分。通过SDK&#xff0c;开发者能够快速集成分析、广告调度、音视频处理、社交功能和用户身份验证等常见功能&#xff0c;而无需从零开始构建。这不仅能节省时间和资源&#xff0c;还能提高开发…

易语言OCR银行卡文字识别

一.引言 文字识别&#xff0c;也称为光学字符识别&#xff08;Optical Character Recognition, OCR&#xff09;&#xff0c;是一种将不同形式的文档&#xff08;如扫描的纸质文档、PDF文件或数字相机拍摄的图片&#xff09;中的文字转换成可编辑和可搜索的数据的技术。随着技…

新能源汽车充电需求攀升,智慧移动充电服务有哪些实际应用场景?

在新能源汽车行业迅猛发展的今天&#xff0c;智慧充电桩作为支持这一变革的关键基础设施&#xff0c;正在多个实际应用场景中发挥着重要作用。从公共停车场到高速公路服务区&#xff0c;从企业园区到住宅小区&#xff0c;智慧充电桩不仅提供了便捷的充电服务&#xff0c;还通过…

QT多媒体开发(一):概述

Qt Multimedia 模块为多媒体编程提供支持。多媒体编程实现的功能主要包括播放音频和视频文件&#xff0c;通过麦克风录制音频&#xff0c;通过摄像头拍照和录像等。 QT6 中多媒体模块相比QT5变化较大&#xff0c;所以用QT6编译 QT5写的多媒体 程序基本无法通过。 Qt 5 多媒体模…

人才画像系统如何支撑企业的人才战略落地

在当今竞争激烈的商业环境中&#xff0c;企业的人才战略对于其长期发展至关重要。为了有效实施人才战略&#xff0c;企业需要一套精准、高效的人才管理工具&#xff0c;而人才画像系统正是满足这一需求的关键解决方案。本文将探讨人才画像系统如何支撑企业的人才战略落地&#…