精准检测花生豆:基于EfficientNet的深度学习分类项目
在现代农业生产中,作物的质量检测和分类是确保产品质量的重要环节。针对花生豆的检测与分类需求,我们开发了一套基于深度学习的解决方案,利用EfficientNetB0模型实现高效、准确的花生豆分类。本博客将详细介绍该项目的背景、数据处理、模型架构、训练过程、评估方法及预测应用。
目录
- 项目背景
- 项目概述
- 数据处理
- 数据集结构
- 数据增强与规范化
- 模型架构
- 训练过程
- 训练脚本 (
train.py
)
- 训练脚本 (
- 模型评估
- 评估脚本 (
evaluate.py
)
- 评估脚本 (
- 预测与应用
- 预测脚本 (
predict.py
)
- 预测脚本 (
- 项目成果
- 结论与未来工作
项目背景
花生豆作为一种重要的经济作物,其品质直接影响到市场价值和消费者满意度。传统的人工检测方法不仅耗时耗力,而且易受主观因素影响,难以实现大规模、精准的分类。因此,开发一种高效、准确的自动化检测系统显得尤为重要。
项目概述
本项目旨在利用深度学习技术,构建一个能够自动检测和分类花生豆的系统。通过收集和处理大量花生豆图像数据,训练一个高性能的卷积神经网络模型,实现对不同类别花生豆的精准分类。项目主要包括以下几个部分:
- 数据处理:图像数据的加载、预处理与增强。
- 模型架构:基于EfficientNetB0的分类模型设计。
- 训练过程:模型的训练与优化,包括断点续训与学习率调度。
- 模型评估:在测试集上的性能评估。
- 预测应用:对新图像进行花生豆分类与标注。
数据处理
数据集结构
项目使用的数据集分为训练集、验证集和测试集,具体结构如下:
./data/dataset/
├── train/
│ ├── baiban/
│ ├── bandian/
│ ├── famei/
│ ├── faya/
│ ├── hongpi/
│ ├── qipao/
│ ├── youwu/
│ └── zhengchang/
├── validation/
│ ├── baiban/
│ ├── bandian/
│ ├── famei/
│ ├── faya/
│ ├── hongpi/
│ ├── qipao/
│ ├── youwu/
│ └── zhengchang/
└── test/├── baiban/├── bandian/├── famei/├── faya/├── hongpi/├── qipao/├── youwu/└── zhengchang/
每个子文件夹对应一种花生豆类别,包含相应的图像数据。
数据增强与规范化
为了提高模型的泛化能力,训练过程中对图像数据进行了多种数据增强操作,如随机裁剪、水平翻转、旋转和颜色抖动。同时,使用ImageNet的均值和标准差对图像进行了归一化处理,与预训练模型的输入要求保持一致。
# utils/dataLoader.pytrain_transform = transforms.Compose([transforms.Resize((image_size, image_size)),transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),transforms.RandomHorizontalFlip(),transforms.RandomRotation(15),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.ToTensor(),transforms.Normalize(*stats)
])validation_transform = transforms.Compose([transforms.Resize((image_size, image_size)),transforms.CenterCrop(image_size),transforms.ToTensor(),transforms.Normalize(*stats)
])test_transform = transforms.Compose([transforms.Resize((image_size, image_size)),transforms.CenterCrop(image_size),transforms.ToTensor(),transforms.Normalize(*stats)
])
模型架构
本项目采用了EfficientNetB0作为基础模型。EfficientNet系列通过系统性地平衡网络的宽度、深度和分辨率,在模型性能和计算效率之间取得了优异的平衡。具体来说:
- 预训练权重:使用在ImageNet上预训练的权重,帮助模型在较小的数据集上快速收敛。
- 冻结特征提取部分:根据需要,可以选择冻结模型的特征提取层,仅训练最后的分类器,适用于数据量较小的情况。
- 分类器设计:在原有分类器前添加了Dropout层,减少过拟合风险。
# utils/model.pyclass EfficientNetB0(nn.Module):def __init__(self, num_classes, pretrained=True, freeze_features=False):super(EfficientNetB0, self).__init__()if pretrained:self.model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)else:self.model = models.efficientnet_b0(weights=None)if freeze_features:for param in self.model.features.parameters():param.requires_grad = Falsein_features = self.model.classifier[1].in_featuresself.model.classifier = nn.Sequential(nn.Dropout(p=0.4, inplace=True),nn.Linear(in_features, num_classes))def forward(self, x):return self.model(x)
训练过程
训练脚本 (train.py
)
训练脚本负责模型的训练与验证,包括数据加载、模型初始化、训练循环、学习率调度、模型保存和训练曲线绘制等功能。
关键功能包括:
- 训练与验证循环:每个epoch包括训练阶段和验证阶段,记录损失与准确率。
- 优化与调度:使用Adam优化器和
ReduceLROnPlateau
学习率调度器,根据验证损失动态调整学习率。 - 模型保存:保存验证集准确率最高的模型,并定期自动保存模型检查点。
- 断点续训:支持从保存的检查点继续训练,避免重复计算。
- 训练曲线绘制:训练结束后,生成并保存训练与验证的准确率和损失曲线。
# train.pyimport torch
import torch.nn as nn
from utils.dataLoader import load_data
from utils.model import EfficientNetB0
from tqdm import tqdm
import time
import matplotlib.pyplot as plt
import osdef accuracy(predictions, labels):pred = torch.argmax(predictions, dim=1)correct = (pred == labels).sum().item()return correctdef train(net, start_epoch, epochs, train_loader, validation_loader, device, criterion, optimizer, scheduler, model_path, auto_save):# 初始化train_acc_list, validation_acc_list = [], []train_loss_list, validation_loss_list = [], []best_validation_acc = 0net = net.to(device)if start_epoch > 0:print(f"从 epoch {start_epoch} 开始训练。")for epoch in range(start_epoch, epochs):# 训练阶段net.train()train_correct, train_loss, total = 0, 0, 0with tqdm(train_loader, ncols=100, colour='green', desc=f"Train Epoch {epoch+1}/{epochs}") as pbar:for images, labels in pbar:images, labels = images.to(device), labels.to(device)optimizer.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()train_loss += loss.item() * images.size(0)train_correct += accuracy(outputs, labels)total += labels.size(0)pbar.set_postfix({'loss': f"{train_loss / total:.4f}", 'acc': f"{train_correct / total:.4f}"})train_acc = train_correct / totaltrain_loss = train_loss / totaltrain_acc_list.append(train_acc)train_loss_list.append(train_loss)# 验证阶段net.eval()validation_correct, validation_loss, total_validation = 0, 0, 0with torch.no_grad():with tqdm(validation_loader, ncols=100, colour='blue', desc=f"Validation Epoch {epoch+1}/{epochs}") as pbar:for images, labels in pbar:images, labels = images.to(device), labels.to(device)outputs = net(images)loss = criterion(outputs, labels)validation_loss += loss.item() * images.size(0)validation_correct += accuracy(outputs, labels)total_validation += labels.size(0)pbar.set_postfix({'loss': f"{validation_loss / total_validation:.4f}", 'acc': f"{validation_correct / total_validation:.4f}"})validation_acc = validation_correct / total_validationvalidation_loss = validation_loss / total_validationvalidation_acc_list.append(validation_acc)validation_loss_list.append(validation_loss)# 更新学习率scheduler.step(validation_loss)# 保存最佳模型if validation_acc > best_validation_acc:best_validation_acc = validation_acccheckpoint = {'epoch': epoch,'model_state_dict': net.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'scheduler_state_dict': scheduler.state_dict(),'validation_acc': best_validation_acc}torch.save(checkpoint, model_path)print(f"保存最佳模型,验证准确率: {best_validation_acc:.4f}")# 自动保存模型if (epoch + 1) % auto_save == 0:save_path = model_path.replace('.pth', f'_epoch{epoch+1}.pth')checkpoint = {'epoch': epoch,'model_state_dict': net.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'scheduler_state_dict': scheduler.state_dict(),'validation_acc': best_validation_acc}torch.save(checkpoint, save_path)print(f"自动保存模型到 {save_path}")# 绘制训练曲线def plot_training_curves(train_acc_list, validation_acc_list, train_loss_list, validation_loss_list, epochs):plt.figure(figsize=(12, 5))plt.subplot(1, 2, 1)plt.plot(range(1, epochs+1), train_acc_list, 'bo-', label="训练准确率")plt.plot(range(1, epochs+1), validation_acc_list, 'ro-', label="验证准确率")plt.title("训练准确率 vs 验证准确率")plt.xlabel("轮次")plt.ylabel("准确率")plt.legend()plt.subplot(1, 2, 2)plt.plot(range(1, epochs+1), train_loss_list, 'bo-', label="训练损失")plt.plot(range(1, epochs+1), validation_loss_list, 'ro-', label="验证损失")plt.title("训练损失 vs 验证损失")plt.xlabel("轮次")plt.ylabel("损失")plt.legend()os.makedirs('logs', exist_ok=True)plt.savefig('logs/training_curve.png')plt.show()plot_training_curves(train_acc_list, validation_acc_list, train_loss_list, validation_loss_list, epochs)if __name__ == '__main__':batch_size = 64image_size = 224classes_num = 8num_epochs = 100auto_save = 10lr = 1e-4weight_decay = 1e-4device = 'cuda' if torch.cuda.is_available() else 'cpu'classify = {'baiban': 0, 'bandian': 1, 'famei': 2, 'faya': 3, 'hongpi': 4, 'qipao': 5, 'youwu': 6, 'zhengchang': 7}train_loader, validation_loader, test_loader = load_data(batch_size, image_size, classify)net = EfficientNetB0(classes_num, pretrained=True, freeze_features=False)model_path = 'model_weights/EfficientNetB0.pth'os.makedirs('model_weights', exist_ok=True)criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)# 检查点续训start_epoch = 0best_validation_acc = 0if os.path.exists(model_path):try:checkpoint = torch.load(model_path, map_location=device)required_keys = ['model_state_dict', 'optimizer_state_dict', 'scheduler_state_dict', 'epoch', 'validation_acc']if all(key in checkpoint for key in required_keys):net.load_state_dict(checkpoint['model_state_dict'])optimizer.load_state_dict(checkpoint['optimizer_state_dict'])scheduler.load_state_dict(checkpoint['scheduler_state_dict'])start_epoch = checkpoint['epoch'] + 1best_validation_acc = checkpoint['validation_acc']print(f"从 epoch {checkpoint['epoch']} 继续训练,最佳验证准确率: {best_validation_acc:.4f}")else:print(f"检查点文件缺少必要的键,开始从头训练。")except Exception as e:print(f"加载检查点时发生错误: {e}")print("开始从头训练。")print("训练开始")time_start = time.time()train(net, start_epoch, num_epochs, train_loader, validation_loader, device=device, criterion=criterion, optimizer=optimizer, scheduler=scheduler, model_path=model_path, auto_save=auto_save)time_end = time.time()seconds = time_end - time_startm, s = divmod(seconds, 60)h, m = divmod(m, 60)print("训练结束")print("本次训练时长为:%02d:%02d:%02d" % (h, m, s))
主要特点:
- 进度条可视化:使用
tqdm
库实时展示训练和验证进度。 - 断点续训:支持从上一次中断的epoch继续训练,确保训练过程的连续性。
- 自动保存:定期保存模型检查点,防止意外中断导致的训练损失。
- 训练曲线:生成并保存训练与验证的准确率和损失曲线,便于后续分析与调优。
模型评估
评估脚本 (evaluate.py
)
评估脚本用于在测试集上评估训练好的模型性能,计算准确率和损失,并将结果保存到文件中。
# evaluate.pyimport torch
import torch.nn as nn
from utils.dataLoader import load_data
from utils.model import EfficientNetB0
from tqdm import tqdm
import osdef accuracy(predictions, labels):pred = torch.argmax(predictions, dim=1)correct = (pred == labels).sum().item()return correctdef evaluate(net, test_loader, device, criterion, output_path):net.eval()test_correct, test_loss, total_test = 0, 0, 0with torch.no_grad():with tqdm(test_loader, ncols=100, colour='blue', desc=f"Evaluating on Test Set") as pbar:for images, labels in pbar:images, labels = images.to(device), labels.to(device)outputs = net(images)loss = criterion(outputs, labels)test_loss += loss.item() * images.size(0)test_correct += accuracy(outputs, labels)total_test += labels.size(0)pbar.set_postfix({'loss': f"{test_loss / total_test:.4f}", 'acc': f"{test_correct / total_test:.4f}"})test_acc = test_correct / total_testtest_loss = test_loss / total_testresult = f"测试集准确率: {test_acc:.4f}, 测试集损失: {test_loss:.4f}"print(result)# 保存结果到文件os.makedirs(os.path.dirname(output_path), exist_ok=True)with open(output_path, 'a') as f:f.write(result + '\n')if __name__ == '__main__':batch_size = 64image_size = 224classes_num = 8device = 'cuda' if torch.cuda.is_available() else 'cpu'classify = {'baiban': 0, 'bandian': 1, 'famei': 2, 'faya': 3, 'hongpi': 4, 'qipao': 5, 'youwu': 6, 'zhengchang': 7}_, _, test_loader = load_data(batch_size, image_size, classify)net = EfficientNetB0(classes_num, pretrained=False)model_path = 'model_weights/EfficientNetB0.pth'if not os.path.exists(model_path):print(f"模型权重文件 {model_path} 不存在,请先训练模型。")exit()net.load_state_dict(torch.load(model_path, map_location=device))net.to(device)criterion = nn.CrossEntropyLoss()evaluation_output_path = 'outputs/evaluation_results.txt'# 清空之前的评估结果if os.path.exists(evaluation_output_path):os.remove(evaluation_output_path)print("评估开始")evaluate(net, test_loader, device=device, criterion=criterion, output_path=evaluation_output_path)print("评估结束")
评估流程:
- 加载模型:从保存的权重文件中加载训练好的模型。
- 模型评估:在测试集上计算模型的准确率和损失。
- 结果保存:将评估结果保存到指定的输出文件中,便于后续查看与分析。
预测与应用
预测脚本 (predict.py
)
预测脚本用于对新图像进行花生豆分类,并在图像上标注分类结果和边框。
# predict.pyimport os
import cv2
import numpy as np
import torch
from PIL import Image
from utils.model import EfficientNetB0
from torchvision import transformsdef delet_contours(contours, delete_list):delta = 0for i in range(len(delete_list)):del contours[delete_list[i] - delta]delta += 1return contoursdef main():input_path = 'data/pic'output_dir = 'outputs/predicted_images'os.makedirs(output_dir, exist_ok=True)image_files = os.listdir(input_path)classify = {0: 'baiban', 1: 'bandian', 2: 'famei', 3: 'faya', 4: 'hongpi', 5: 'qipao', 6: 'youwu', 7: 'zhengchang'}# 与训练时相同的预处理transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=(0.485, 0.456, 0.406), # ImageNet均值std=(0.229, 0.224, 0.225)) # ImageNet标准差])device = 'cuda' if torch.cuda.is_available() else 'cpu'net = EfficientNetB0(8, pretrained=False)model_path = 'model_weights/EfficientNetB0.pth'if not os.path.exists(model_path):print(f"模型权重文件 {model_path} 不存在,请先训练模型。")returnnet.load_state_dict(torch.load(model_path, map_location=device))net.to(device)net.eval()min_size = 30max_size = 400for img_name in image_files:img_path = os.path.join(input_path, img_name)img = cv2.imread(img_path)if img is None:print(f"无法读取图像: {img_path}")continuehsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) # 转换到HSV颜色空间# 根据HSV颜色范围进行掩膜操作(根据实际情况调整颜色范围)lower_blue = np.array([100, 100, 8])upper_blue = np.array([255, 255, 255])mask = cv2.inRange(hsv, lower_blue, upper_blue) # 创建掩膜result = cv2.bitwise_and(img, img, mask=mask) # 应用掩膜result = result.astype(np.uint8)# 转换为灰度图并二值化gray = cv2.cvtColor(result, cv2.COLOR_BGR2GRAY)_, binary_image = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)# 查找轮廓contours, _ = cv2.findContours(binary_image, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)contours = list(contours)# 过滤轮廓delete_list = []for idx, contour in enumerate(contours):perimeter = cv2.arcLength(contour, True)if perimeter < min_size or perimeter > max_size:delete_list.append(idx)contours = delet_contours(contours, delete_list)# 对每个轮廓进行分类for contour in contours:x, y, w, h = cv2.boundingRect(contour)crop = img[y:y+h, x:x+w]if crop.size == 0:continuecrop_pil = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))crop_tensor = transform(crop_pil).unsqueeze(0).to(device)with torch.no_grad():output = net(crop_tensor)pred = torch.argmax(output, dim=1).item()label = classify[pred]# 标注图像cv2.putText(img, label, (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (36,255,12), 2)cv2.rectangle(img, (x, y), (x + w, y + h), (0, 0, 255), 2)# 保存结果图像到outputs/predicted_images/output_image_path = os.path.join(output_dir, f"{os.path.splitext(img_name)[0]}_predicted.jpg")cv2.imwrite(output_image_path, img)print(f"保存预测结果到 {output_image_path}")print("所有图像的预测和标注已完成并保存到 ./outputs/predicted_images/")if __name__ == '__main__':main()
预测流程:
- 图像预处理:将输入图像转换到HSV颜色空间,应用颜色掩膜提取花生豆区域。
- 轮廓检测与过滤:查找并过滤不符合大小要求的轮廓,确保只处理有效的花生豆区域。
- 分类与标注:对每个有效轮廓进行裁剪、预处理,并使用训练好的模型进行分类。在图像上标注分类结果和边框。
- 结果保存:将标注后的图像保存到指定的输出目录。
预测结果
项目成果
通过本项目,我们成功构建了一个能够高效、准确地检测和分类花生豆的深度学习模型。主要成果包括:
- 高准确率:模型在测试集上达到了令人满意的分类准确率。
- 自动化检测:实现了对新图像的自动检测与分类,大大提高了检测效率。
- 可视化结果:通过图像标注,直观展示了分类结果,便于用户理解和应用。
训练与验证的准确率和损失曲线示例
开源的程序及数据集
git clone https://gitee.com/songaoxiangsoar/peanut-bean-testing.git
结论与未来工作
本项目展示了基于深度学习的花生豆检测与分类的可行性与有效性。通过采用预训练的EfficientNetB0模型,并结合数据增强与优化策略,模型在花生豆分类任务中表现出色。
未来的工作方向包括:
- 模型优化:尝试更深更复杂的模型,如EfficientNetB7,以进一步提升分类性能。
- 数据扩展:收集更多多样化的花生豆图像,增强模型的泛化能力。
- 实时检测:优化模型推理速度,实现实时花生豆检测与分类。
- 部署应用:将模型集成到移动设备或嵌入式系统,便于现场检测与应用。
通过持续的优化与扩展,我们相信这一系统将在农业生产中发挥更大的价值,助力智能农业的发展。
感谢阅读本博客!如果您对本项目有任何疑问或建议,欢迎在下方留言交流。