基于Pytorch框架的深度学习EfficientNetV2神经网络中草药识别分类系统源码

 第一步:准备数据

5种中草药数据:self.class_indict = ["百合", "党参", "山魈", "枸杞", "槐花", "金银花"]

,总共有900张图片,每个文件夹单独放一种数据

第二步:搭建模型

本文选择一个EfficientNetV2网络,其原理介绍如下:

        该网络主要使用训练感知神经结构搜索缩放的组合;在EfficientNetV1的基础上,引入了Fused-MBConv到搜索空间中;引入渐进式学习策略自适应正则强度调整机制使得训练更快;进一步关注模型的推理速度训练速度

与EfficientV1相比,主要有以下不同:

  1. V2中除了使用MBConv模块外,还使用了Fused-MBConv模块
  2. V2中会使用较小的expansion ratio,在V1中基本都是6。这样的好处是能够减少内存访问开销
  3. V2中更偏向使用更小的kernel_size(3 x 3),在V1中很多5 x 5。优于3 x 3的感受野是比5 x 5小的,所以需要堆叠更多的层结构以增加感受野
  4. 移除了V1中最优一个步距为1的stage

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import os
import math
import argparseimport torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_schedulerfrom model import efficientnetv2_s as create_model
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluatedef main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")print(args)print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')tb_writer = SummaryWriter()if os.path.exists("./weights") is False:os.makedirs("./weights")train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)img_size = {"s": [300, 384],  # train_size, val_size"m": [384, 480],"l": [384, 480]}num_model = "s"data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(img_size[num_model][0]),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),"val": transforms.Compose([transforms.Resize(img_size[num_model][1]),transforms.CenterCrop(img_size[num_model][1]),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}# 实例化训练数据集train_dataset = MyDataSet(images_path=train_images_path,images_class=train_images_label,transform=data_transform["train"])# 实例化验证数据集val_dataset = MyDataSet(images_path=val_images_path,images_class=val_images_label,transform=data_transform["val"])batch_size = args.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=val_dataset.collate_fn)# 如果存在预训练权重则载入model = create_model(num_classes=args.num_classes).to(device)if args.weights != "":if os.path.exists(args.weights):weights_dict = torch.load(args.weights, map_location=device)load_weights_dict = {k: v for k, v in weights_dict.items()if model.state_dict()[k].numel() == v.numel()}print(model.load_state_dict(load_weights_dict, strict=False))else:raise FileNotFoundError("not found weights file: {}".format(args.weights))# 是否冻结权重if args.freeze_layers:for name, para in model.named_parameters():# 除head外,其他权重全部冻结if "head" not in name:para.requires_grad_(False)else:print("training {}".format(name))pg = [p for p in model.parameters() if p.requires_grad]optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4)# Scheduler https://arxiv.org/pdf/1812.01187.pdflf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosinescheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)for epoch in range(args.epochs):# traintrain_loss, train_acc = train_one_epoch(model=model,optimizer=optimizer,data_loader=train_loader,device=device,epoch=epoch)scheduler.step()# validateval_loss, val_acc = evaluate(model=model,data_loader=val_loader,device=device,epoch=epoch)tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]tb_writer.add_scalar(tags[0], train_loss, epoch)tb_writer.add_scalar(tags[1], train_acc, epoch)tb_writer.add_scalar(tags[2], val_loss, epoch)tb_writer.add_scalar(tags[3], val_acc, epoch)tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--num_classes', type=int, default=5)parser.add_argument('--epochs', type=int, default=100)parser.add_argument('--batch-size', type=int, default=4)parser.add_argument('--lr', type=float, default=0.01)parser.add_argument('--lrf', type=float, default=0.01)# 数据集所在根目录# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzparser.add_argument('--data-path', type=str,default=r"G:\demo\data\ChineseMedicine")# download model weights# 链接: https://pan.baidu.com/s/1uZX36rvrfEss-JGj4yfzbQ  密码: 5gu1parser.add_argument('--weights', type=str, default='./pre_efficientnetv2-s.pth',help='initial weights path')parser.add_argument('--freeze-layers', type=bool, default=True)parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')opt = parser.parse_args()main(opt)

第四步:统计正确率

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码的下载路径(新窗口打开链接):基于Pytorch框架的深度学习EfficientNetV2神经网络中草药识别分类系统源码

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/844962.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

split拆分文件

在Linux系统中,split命令是一个非常实用的工具,用于将大文件拆分成多个小文件。以下是一些基本的使用方法: 基本用法: 使用split命令可以按照指定的行数来拆分文件。例如,将一个文件每1000行拆分成一个新文件&#xff…

Go 语言的基本构成、要素与编写规范

Go 语言,作为由 Google 开发的现代编程语言,以其简洁、高效和并发编程能力而著称。在构建高性能分布式系统和现代软件开发中,Go 语言正日益受到欢迎。本篇文章将详细探讨 Go 语言程序结构的各个要素,包括函数定义、注释规范、数据…

C语言练习题之——从简单到烧脑(11)(每日两道)

题目1:有两个矩阵a[3][2],b[2][2],元素值由键盘输入&#xff0c;计算a与b的矩阵之和&#xff08;两个矩阵循环中相加&#xff0c;结尾求和&#xff09; #include<stdio.h> int main() {int arr[3][2], brr[2][2],i,j,sum10,sum20;for (i 0; i < 3; i){for (j 0; j …

Docker搭建FRP内网穿透服务器

使用Docker搭建一个frp内网穿透 在现代网络环境中&#xff0c;由于防火墙和NAT等原因&#xff0c;内网设备无法直接被外网访问。FRP (Fast Reverse Proxy) 是一款非常流行的内网穿透工具&#xff0c;它能够帮助我们将内网服务暴露给外网。本文将介绍如何在Linux服务器上使用Do…

压测工具Jmeter的使用

一、安装 下载地址&#xff1a; 国外地址&#xff1a;jmeter.apache.org&#xff08;下载会很慢&#xff0c;建议使用国内地址&#xff09; 国内地址&#xff1a;apache-jmeter-binaries安装包下载_开源镜像站-阿里云 下载好进入bin文件下&#xff0c;双击jmeter.bat 打开…

哈希传递(PTH)

使用Mimikatz进行PTH Pass The Hash 哈希传递攻击简称 PTH&#xff0c;该方法通过找到与账户相关的密码散列&#xff08;NTLLHash&#xff09;来进 行攻击。由于在Windows系统中&#xff0c;通常会使用NTLM Hash对访问资源的用户进行身份认证&#xff0c;所以该攻 击可以在不需…

WPF WebBrowser控件解析 HTML

WPF WebBrowser控件解析 HTML Window里面的AllowsTransparency属性不要加<WebBrowser x:Name"webBrowser" />public void InitWeb() {string htmlString "<html><head><title>this is a test</title><script type text/jav…

算法学习笔记(7.2)-贪心算法(最大容量问题)

目录 ##问题描述 ##问题示例 ##释 ##贪心策略的确定 ##代码示例 ##正确性验证 ##问题描述 输入一个数组 ℎ&#x1d461; &#xff0c;其中的每个元素代表一个垂直隔板的高度。数组中的任意两个隔板&#xff0c;以及它们之间的空间可以组成一个容器。 容器的容量等于高度和宽…

泛型aaaaa

1、泛型的概述&#xff1a; 1.1 泛型的由来 根据《Java编程思想》中的描述&#xff0c;泛型出现的动机&#xff1a; 有很多原因促成了泛型的出现&#xff0c;而最引人注意的一个原因&#xff0c;就是为了创建容器类。泛型的思想很早就存在&#xff0c;如C中的模板&#xff0…

6个PPT素材模板网站,免费!

免费PPT素材模板下载&#xff0c;就上这6个网站&#xff0c;建议收藏&#xff01; 1、菜鸟图库 ppt模板免费下载|ppt背景图片 - 菜鸟图库 菜鸟图库是一个设计、办公、媒体等素材非常齐全的网站&#xff0c;站内有几百万的素材&#xff0c;其中PPT模板就有几十万个&#xff0c;…

[stm32]——定时器与PWM的LED控制

目录 一、stm32定时器 1、定时器简介 2、定时器分类 3、通用定时器介绍 二、PWM相关介绍 1、工作原理 2、PWM的一般步骤 三、定时器控制LED亮灭 1、工程创建 2、代码编写 3、实现效果 四、采用PWM模式&#xff0c;实现呼吸灯效果 1、工程创建 2、代码编写 3、实现效果 一、stm3…

STM32 IIC协议

本文代码使用 HAL 库。 文章目录 前言一、什么是IIC协议二、IIC信号三、IIC协议的通讯时序1. 写操作2. 读操作 四、上拉电阻作用总结 前言 从这篇文章开始为大家介绍一些通信协议&#xff0c;包括 UART&#xff0c;SPI&#xff0c;IIC等。 UART串口通讯协议 SPI通信协议 一、…

B端系统:角色与权限界面设计,一文读懂。

一、什么是角色与权限系统 角色与权限系统是一种用于管理和控制用户在系统中的访问和操作权限的机制。它通过将用户分配到不同的角色&#xff0c;并为每个角色分配相应的权限&#xff0c;来实现对系统资源的权限控制和管理。 在角色与权限系统中&#xff0c;通常会定义多个角色…

postgressql——四种进程间锁(4)

进程间锁 在PostgreSQL里有四种类型的进程间锁: Spinlocks:自旋锁,其保护的对象一般是数据库内部的一些数据结构,是一种轻量级的锁。 LWLocks:轻量锁,也是主要用于保护数据库内部的一些数据结构,支持独占和共享两种模式。 Regular locks:又叫heavyweight locks,也就是…

【深度揭秘GPT-4o】:全面解析新一代AI技术的突破与优势

目录 ​编辑 1.版本对比&#xff1a;从GPT-3到GPT-4&#xff0c;再到GPT-4o的飞跃 1.1 模型规模的扩展 1.2 训练数据的更新 1.3 算法优化与效率提升 1.4 案例分析 2.技术能力&#xff1a;GPT-4o的核心优势 2.1 卓越的自然语言理解 2.1.1 上下文理解能力 2.1.2 语义分…

el-table中的信息数据过长 :show-overflow-tooltip=‘true‘**

可以在 el-table-column中添加 :show-overflow-tooltip‘true’

Kotlin 2.0 重磅发布! 性能提升!新功能上线!开发者必看!

博主猫头虎的技术世界 &#x1f31f; 欢迎来到猫头虎的博客 — 探索技术的无限可能&#xff01; 专栏链接&#xff1a; &#x1f517; 精选专栏&#xff1a; 《面试题大全》 — 面试准备的宝典&#xff01;《IDEA开发秘籍》 — 提升你的IDEA技能&#xff01;《100天精通鸿蒙》 …

【Java面试】六、Spring框架相关

文章目录 1、单例Bean不是线程安全的2、AOP3、Spring中事务的实现4、Spring事务失效的场景4.1 情况一&#xff1a;异常被捕获4.2 情况二&#xff1a;抛出检查异常4.3 注解加在非public方法上 5、Bean的生命周期6、Bean的循环引用7、Bean循环引用的解决&#xff1a;Spring三级缓…

软考随记(二)

I/O系统的5种不同的工作方式&#xff1a; 程序控制方式&#xff1a; 无条件查询&#xff1a;I/O端口总是准备好接受主机的输出数据&#xff0c;或是总是准备好向主机输入数据&#xff0c;而CPU在需要时随时直接利用I/O指令访问相应的I/O端口&#xff0c;实现与外设的数据交换 …

python-求点积

【问题描述】&#xff1a;给出两个数组&#xff0c;并求它们的点积。 【问题描述】&#xff1a;输入A[1,1,1],B[2,2,2]&#xff0c;输出6,即1*21*21*26。输入A[3,2],B[2,3,3],输出-1&#xff0c;没有点积。 完整代码如下&#xff1a; alist(map(int,input().split())) blist(…