基于Pytorch框架的深度学习EfficientNet神经网络香蕉水果成熟度识别分类系统源码

 第一步:准备数据

4种香蕉水果成熟度数据:overripe,ripe,rotten,unripe(过熟、熟、烂、未成熟),总共有13474张图片,每个文件夹单独放一种成熟度数据

第二步:搭建模型

本文选择一个EfficientNet网络,其原理介绍如下:

        为了弄清楚神经网络缩放之后的效果,谷歌团队系统地研究了改变不同维度对模型的影响,维度参数包括网络深度、宽度和图像分辨率。首先他们进行了栅格搜索(Grid Search)。这是一种穷举搜索方法,可以在固定资源的限定下,列出所有参数之间的关系,显示出改变某一种维度时,基线网络模型会受到什么样的影响。换句话说,如果只改变了宽度、深度或分辨率,模型的表现会发生什么变化。

        综合考虑所有情况之后,他们确定了每个维度最合适的调整系数,然后将它们一同应用到基线网络中,对每个维度都进行适当的缩放,并且确保其符合目标模型的大小和计算预算。

        简单来说,就是分别找到宽度、深度和分辨率的最佳系数,然后将它们组合起来一起放入原本的网络模型中,对每一个维度都有所调整。从整体的角度缩放模型。与传统方法相比,这种复合缩放法可以持续提高模型的准确性和效率。在现有模型 MobileNet 和 ResNet 上的测试结果显示,它分别提高了 1.4% 和 0.7% 的准确率。

         因为,为了进一步提高性能,谷歌 AI 团队还使用了 AutoML MNAS 框架进行神经架构搜索,优化准确性和效率。AutoML 是一种可以自动设计神经网络的技术,由谷歌团队在 2017 年提出,而且经过了多次优化更新。使用这种技术可以更简便地创造神经网络。由此产生的架构使用了移动倒置瓶颈卷积(MBConv),类似于 MobileNetV2 和 MnasNet 模型,但由于计算力(FLOPS)预算增加,MBConv 模型体积略大。随后他们多次缩放了基线网络,组成了一系列模型,统称为 EfficientNets。

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import os
import math
import argparseimport torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_schedulerfrom model import efficientnet_b0 as create_model
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluatedef main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")print(args)print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')tb_writer = SummaryWriter()if os.path.exists("./weights") is False:os.makedirs("./weights")train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)img_size = {"B0": 224,"B1": 240,"B2": 260,"B3": 300,"B4": 380,"B5": 456,"B6": 528,"B7": 600}num_model = "B0"data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(img_size[num_model]),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(img_size[num_model]),transforms.CenterCrop(img_size[num_model]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}# 实例化训练数据集train_dataset = MyDataSet(images_path=train_images_path,images_class=train_images_label,transform=data_transform["train"])# 实例化验证数据集val_dataset = MyDataSet(images_path=val_images_path,images_class=val_images_label,transform=data_transform["val"])batch_size = args.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=val_dataset.collate_fn)# 如果存在预训练权重则载入model = create_model(num_classes=args.num_classes).to(device)if args.weights != "":if os.path.exists(args.weights):weights_dict = torch.load(args.weights, map_location=device)load_weights_dict = {k: v for k, v in weights_dict.items()if model.state_dict()[k].numel() == v.numel()}print(model.load_state_dict(load_weights_dict, strict=False))else:raise FileNotFoundError("not found weights file: {}".format(args.weights))# 是否冻结权重if args.freeze_layers:for name, para in model.named_parameters():# 除最后一个卷积层和全连接层外,其他权重全部冻结if ("features.top" not in name) and ("classifier" not in name):para.requires_grad_(False)else:print("training {}".format(name))pg = [p for p in model.parameters() if p.requires_grad]optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4)# Scheduler https://arxiv.org/pdf/1812.01187.pdflf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosinescheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)for epoch in range(args.epochs):# trainmean_loss = train_one_epoch(model=model,optimizer=optimizer,data_loader=train_loader,device=device,epoch=epoch)scheduler.step()# validateacc = evaluate(model=model,data_loader=val_loader,device=device)print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))tags = ["loss", "accuracy", "learning_rate"]tb_writer.add_scalar(tags[0], mean_loss, epoch)tb_writer.add_scalar(tags[1], acc, epoch)tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--num_classes', type=int, default=4)parser.add_argument('--epochs', type=int, default=100)parser.add_argument('--batch-size', type=int, default=4)parser.add_argument('--lr', type=float, default=0.01)parser.add_argument('--lrf', type=float, default=0.01)# 数据集所在根目录# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzparser.add_argument('--data-path', type=str,default=r"G:\demo\data\classifier\classifier\train")# download model weights# 链接: https://pan.baidu.com/s/1ouX0UmjCsmSx3ZrqXbowjw  密码: 090iparser.add_argument('--weights', type=str, default='./efficientnetb0.pth',help='initial weights path')parser.add_argument('--freeze-layers', type=bool, default=False)parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')opt = parser.parse_args()main(opt)

第四步:统计正确率

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码的下载路径(新窗口打开链接):基于Pytorch框架的深度学习CNN神经网络香蕉水果成熟度识别分类系统源码

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/16286.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ZEDmini使用完全指南

ZEDmini使用 ZED stereolabs 开箱测评 使用说明 ubuntu18.04nvidiacuda10 ubuntu18.04ZED SDK安装和使用 Ubuntu16.04安装NVIDIA显卡驱动 查看显卡信息 redwallredwall-G3-3500:~/catkin_ws$ lspci | grep VGA 00:02.0 VGA compatible controller: Intel Corporation Device …

sourcetree推送到git上面

官网:Sourcetree | Free Git GUI for Mac and Windows 下载到1次提交 下载后打开 点击跳过 下一步 名字邮箱 点击clone 把自己要上传的代码粘贴到里面去 返回点击远程->点击暂存所有 加载完毕后,输入提交内容提交 提交完成了 2次提交 把文件夹内的…

【java程序设计期末复习】chapter4 类和对象

类和对象 编程语言的几个发展阶段 (1)面向机器语言 计算机处理信息的早期语言是所谓的机器语言,使用机器语言进行程序设计需要面向机器来编写代码,即需要针对不同的机器编写诸如0101 1100这样的指令序列。 (2&#x…

【JavaScript】文件下载

文件下载的消息格式 服务器只要在响应头中加入 Content-Disposition: attachment; filename"kxx" 即可触发浏览器的下载功能其中: attachment 表示附件,浏览器看到此字段,触发下载行为(不同的浏览器下载行为有所区别&…

【二叉树】力扣OJ题

文章目录 前言1. 翻转二叉树1.1 题目1.2 解题思路1.3 代码实现1.4 时空复杂度 2. 对称二叉树2.1 题目2.2 解题思路2.3 代码实现2.4 时空复杂度 3. 平衡二叉树3.1 题目3.2 解题思路3.3 代码实现3.4 时空复杂度 结语 前言 本篇博客主要介绍二叉树的经典 OJ 题,题目主…

MyBatis详细教程!!(入门版)

目录 什么是MyBatis? MyBatis入门 1)创建工程 2)数据准备 3)配置数据库连接字符串 4)写持久层代码 5)生成测试类 MyBatis打印日志 传递参数 MyBatis的增、删、改 增(Insert&#xff0…

有什么普通人可以做的赚钱软件?盘点9个适合普通人长期做的软件

在这个互联网高速发展的时代,智能手机已经成为我们生活中不可分割的一部分。众多APP的涌现,使得许多朋友都在寻求通过手机赚钱的方法。 然而,面对市面上琳琅满目的网上赚钱APP,我们该如何挑选呢?别担心,今…

功率电感设计方法2:实例

文章目录 1:美磁的选项手册截图2:设计步骤2.1:设计需求2.2:选择磁芯材料2.3:选择磁芯2.4 查询 A L A_{L} AL​自感系数2.5 初算匝数2.6重新校准验算感量 3:后续 绕线因子4:日常壁纸分享 参考手册链接 1&…

普通人转行程序员,最大的困难是找不到就业方向

来百度APP畅享高清图片 大家好,这里是程序员晚枫,小破站也叫这个名。 我自己是法学院毕业后,通过2年的努力才转行程序员成功的。[吃瓜R] 我发现对于一个外行来说,找不到一个适合自己的方向,光靠努力在一个新的行业里…

使用Java 将字节数组转成16进制的形式

概述 在很多场景下,需要进行分析字节数据,但是我们存起来的字节数据一般都是二进制的,这时候就需要我们将其转成16进制的方式方便分析。比如在做音视频的时候,需要看下我们传输的视频h264数据中是否有对应的I帧或者B帧等数据&…

07、SpringBoot 源码分析 - SpringApplication启动流程七

SpringBoot 源码分析 - SpringApplication启动流程七 初始化基本流程SpringApplication的prepareContext准备上下文postProcessApplicationContext处理applyInitializers初始化器初始化load SpringApplication的refreshContext刷新上下文refreshServletWebServerApplicationCon…

8.什么是HOOK

程序编译的本质是,首先计算机它只能看得懂机器码也就是只能看得懂数字,机器码学起来很费劲然后就创造了编译器这个东西,编译器它懂机器语言所以它可以跟机器沟通,而我们人可以跟编译器沟通,人跟编译器的语言就是各种各…

[Vulnhub]Vulnix 通过NFS挂载+SSH公钥免密登录权限提升

端口扫描 Server IP AddressPorts Open192.168.8.103TCP:22/tcp, 25/tcp, 79/tcp, 110/tcp, 111/tcp, 143/tcp, 512/tcp, 513/tcp, 514/tcp, 993/tcp, 995/tcp, 2049/tcp, 37522/tcp, 42172/tcp, 43219/tcp, 47279/tcp, 54227/tcp $ nmap -p- 192.168.8.103 -sV -sC --min-ra…

MyBatis系统学习 - 使用Mybatis完成查询单条,多条数据,模糊查询,动态设置表名,获取自增主键

上篇博客我们围绕Mybatis链接数据库进行了相关概述,并对Mybatis的配置文件进行详细的描述,本篇博客也是建立在上篇博客之上进行的,在上面博客搭建的框架基础上,我们对MyBatis实现简单的增删改查操作进行重点概述,在MyB…

P459 包装类Wrapper

包装类的分类 1)针对八种基本数据类型相应的引用类型——包装类。 2)有了类的特点,就可以调用类中的方法。 Boolean包装类 Character包装类 其余六种Number类型的包装类 包装类和基本数据类型的相互转换 public class Integer01 {publi…

解决文件夹打开出错问题:原因、数据恢复与预防措施

在我们日常使用电脑或移动设备时,有时会遇到一个非常棘手的问题——文件夹打开出错。这种错误可能会让您无法访问重要的文件和数据,给工作和生活带来极大的不便。本文将带您深入了解文件夹打开出错的原因,并提供有效的数据恢复方案&#xff0…

【网络协议】应用层协议--HTTP

文章目录 一、HTTP是什么?二、HTTP协议工作过程三、HTTP协议1. fiddler2. Fiddler抓包的原理3. 代理服务器是什么?4. HTTP协议格式1.1 请求1.2 响应 四、认识HTTP的请求1.认识HTTP请求的方法2.认识请求头(header)3.认识URL3.1 URL是什么&…

SparkSQL入门

1、SparkSQL是什么? 结论:SparkSQL 是一个即支持 SQL 又支持命令式数据处理的工具 2、SparkSQL 的适用场景? 结论:SparkSQL 适用于处理结构化数据的场景,而Spark 的 RDD 主要用于处理 非结构化数据 和 半结构化数据 …

掌握ASPICE标准:汽车软件测试工程师的专业发展路径

掌握ASPICE标准:汽车软件测试工程师的专业发展路径 文:领测老贺 随着新能源汽车在中国的蓬勃发展,智能驾驶技术的兴起,汽车测试工程师的角色变得愈发关键。这一变革带来了前所未有的挑战和机遇,要求测试工程师不仅要具…

解决git克隆项目出现fatal无法访问git clone https://github.com/lvgl/lvgl.git

Windows 11系统 报错 $ git clone https://github.com/lvgl/lvgl.git Cloning into lvgl... fatal: unable to access https://github.com/lvgl/lvgl.git/: Failed to connect to github.com port 443 after 21141 ms: Couldnt connect to server 解决方法 git运行这两段代码…