基于Pytorch框架的深度学习RegNet神经网络二十五种宝石识别分类系统源码

 第一步:准备数据

25种宝石数据,总共800张:

{ "0": "Alexandrite","1": "Almandine","2": "Benitoite","3": "Beryl Golden","4": "Carnelian", "5": "Cats Eye","6": "Danburite", "7": "Diamond","8": "Emerald","9": "Fluorite","10": "Garnet Red","11": "Hessonite","12": "Iolite","13": "Jade","14": "Kunzite","15": "Labradorite","16": "Malachite","17": "Onyx Black","18": "Pearl","19": "Quartz Beer","20": "Rhodochrosite","21": "Sapphire Blue","22": "Tanzanite","23": "Variscite","24": "Zircon"}

第二步:搭建模型

本文选择一个RegNet网络,其原理介绍如下:

该论文提出了一个新的网络设计范式,并不是专注于设计单个网络实例,而是设计了一个网络设计空间network design space。整个过程类似于经典的手工网络设计,但被提升到了设计空间的水平。使用本文的方法,作者探索了网络设计的结构方面,并得到了一个由简单、规则的网络构成了低维设计空间并称之为RegNet。RegNet设计空间提供了各个范围flop下简单、快速的网络。在类似的训练设置和flops下,RegNet的效果超过了EfficientNet同时在GPU上快了5倍

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import os
import math
import argparseimport torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_schedulerfrom model import create_regnet
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluatedef main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")print(args)print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')tb_writer = SummaryWriter()if os.path.exists("./weights") is False:os.makedirs("./weights")train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}# 实例化训练数据集train_dataset = MyDataSet(images_path=train_images_path,images_class=train_images_label,transform=data_transform["train"])# 实例化验证数据集val_dataset = MyDataSet(images_path=val_images_path,images_class=val_images_label,transform=data_transform["val"])batch_size = args.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=val_dataset.collate_fn)# 如果存在预训练权重则载入model = create_regnet(model_name=args.model_name,num_classes=args.num_classes).to(device)# print(model)if args.weights != "":if os.path.exists(args.weights):weights_dict = torch.load(args.weights, map_location=device)load_weights_dict = {k: v for k, v in weights_dict.items()if model.state_dict()[k].numel() == v.numel()}print(model.load_state_dict(load_weights_dict, strict=False))else:raise FileNotFoundError("not found weights file: {}".format(args.weights))# 是否冻结权重if args.freeze_layers:for name, para in model.named_parameters():# 除最后的全连接层外,其他权重全部冻结if "head" not in name:para.requires_grad_(False)else:print("train {}".format(name))pg = [p for p in model.parameters() if p.requires_grad]optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5)# Scheduler https://arxiv.org/pdf/1812.01187.pdflf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosinescheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)for epoch in range(args.epochs):# trainmean_loss = train_one_epoch(model=model,optimizer=optimizer,data_loader=train_loader,device=device,epoch=epoch)scheduler.step()# validateacc = evaluate(model=model,data_loader=val_loader,device=device)print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))tags = ["loss", "accuracy", "learning_rate"]tb_writer.add_scalar(tags[0], mean_loss, epoch)tb_writer.add_scalar(tags[1], acc, epoch)tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--num_classes', type=int, default=25)parser.add_argument('--epochs', type=int, default=100)parser.add_argument('--batch-size', type=int, default=4)parser.add_argument('--lr', type=float, default=0.001)parser.add_argument('--lrf', type=float, default=0.01)# 数据集所在根目录# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzparser.add_argument('--data-path', type=str,default=r"G:\demo\data\gemstone\archive_train")parser.add_argument('--model-name', default='RegNetY_400MF', help='create model name')# 预训练权重下载地址# 链接: https://pan.baidu.com/s/1XTo3walj9ai7ZhWz7jh-YA  密码: 8lmuparser.add_argument('--weights', type=str, default='regnety_400mf.pth',help='initial weights path')parser.add_argument('--freeze-layers', type=bool, default=False)parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')opt = parser.parse_args()main(opt)

第四步:统计正确率

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码的下载路径(新窗口打开链接):基于Pytorch框架的深度学习RegNet神经网络二十五种宝石识别分类系统源码

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/18245.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数字化农业新时代:图扑农林牧综合监控平台

利用图扑自研 HT for Web GIS 产品,结合遥感技术,构建可交互式的农林牧数据分析平台。该平台围绕地块总览、播种分析、牛只管理、设备查询四个维度,对地区的全貌、农场、村集体分布以及相应的环境进行多样化的可视化展示和进行数据支持&#…

爱岗敬业短视频:成都科成博通文化传媒公司

爱岗敬业短视频:传递正能量,塑造职场新风尚 在当今社会,短视频以其独特的传播方式和广泛的受众群体,成为了信息传播的重要渠道。在众多短视频内容中,以“爱岗敬业”为主题的短视频尤为引人注目,成都科成博…

FreeRtos进阶——队列的特殊用途

信号量与互斥量都一样,都是特殊的队列。但是只有互斥量实现了优先级继承机制。 信号量与互斥量与队列一样,在操作增加或者减少时,必须先关中断在进行操作! 信号量创建揭秘 图中信号量的创建过程,在代码中的体现本质就是…

现在股票交易佣金标准最低是万0.854,低佣金炒股开户方式和流程!

股票交易佣金的最低标准是万分之0.854; 证券公司股票交易佣金默认是万分之3; 无门槛的股票交易佣金是万分之1; 万分之0.854的佣金要求投资者资产达到一定规模,不同的证券公司规定不一样。 如果没有经过证券公司客户经理协商开…

【SQL学习进阶】从入门到高级应用(一)

文章目录 MySQL命令行基本命令数据库表的概述初始化测试数据熟悉测试数据 🌈你好呀!我是 山顶风景独好 💝欢迎来到我的博客,很高兴能够在这里和您见面! 💝希望您在这里可以感受到一份轻松愉快的氛围&#x…

C++牛客周赛43题目分享(3)小红平分糖果,小红的完全平方数,小苯的字符串变化,小红的子数组排列判断

目录 ​编辑 1.前言 2.四道题目 2.1小红平分糖果 2.1.1题目描述 2.1.2输入描述 2.1.3输出描述 2.1.4示例 2.1.5代码 2.2小红的完全平方数 2.1.1题目描述 2.1.2输入描述 2.1.3输出描述 2.1.4示例 2.1.5代码 2.3小苯的字符串变化 2.1.1题目描述 2.1.2输入描述 …

想自学编程,看编程书有些看不懂,下一步应该怎么办?

不管你从事什么工作,编程都有助于你的职业发展。学习编程将给你自己赋能。我喜欢尝试新想法,时刻都有希望启动的新项目。学会编程后,我就可以坐下来自己实现,而不需要依赖他人。 编程也会提升你在其他方面的技能。因为你熟练掌握…

Gitlab不允许使用ssh拉取代码的解决方案

一、起因 之前一直是用ssh进行代码拉取,后来公司搞网安行动,不允许ssh进行连接拉取代码了 因为我是用shell写了个小型的CI/CD,部署前端项目用于后端联调的,因此在自动部署时,不方便人机交互,所以需要自动填充账密。 …

护网2024-攻防对抗解决方案思路

一、护网行动简介 近年来,网络安全已被国家上升为国家安全的战略层面,网络安全同样也被视为维护企业业务持续性的关键。国家在网络安全治理方面不断出台法规与制度,并实施了一些大型项目和计划,如网络安全法、等级保护、网络安全…

【UE C++】 虚幻引擎C++开发需要掌握的C++和U++的基础知识有哪些?

目录 0 引言1 关键的 C 知识2 Unreal Engine 相关知识3 学习建议 🙋‍♂️ 作者:海码007📜 专栏:UE虚幻引擎专栏💥 标题:【UE C】 虚幻引擎C开发需要掌握的C和U的基础知识有哪些?❣️ 寄语&…

什么情况下JVM内存中的一个对象会被垃圾回收?

什么情况下JVM内存中的一个对象会被垃圾回收? 1、什么时候会触发垃圾回收?2、被哪些变量引用的对象是不能回收的?3、Java中对象不同的引用类型4、finalize()方法的作用1、什么时候会触发垃圾回收? 平时我们系统运行创建的对象都是优先分配在新生代里的,如图: 然后如果…

【Oracle】PL SQL 怎么重新编译无效的对象

1.打开PL SQL ,点击图中有红色的 2.点击齿轮按钮即可 from:【Oracle】PL SQL 怎么重新编译无效的对象_plsql编译无效对象的按钮在哪里-CSDN博客

最新php项目加密源码

压缩包里有多少个php就会被加密多少个PHP、php无需安装任何插件。源码全开源 如果上传的压缩包里有子文件夹(子文件夹里的php文件也会被加密),加密后的压缩包需要先修复一下,步骤:打开压缩包 》 工具 》 修复压缩文件…

AIGC 010-CLIP第一个文本和图像对齐的大模型!

AIGC 010-CLIP第一个文本和图像对齐的大模型! 文章目录 0 论文工作1 论文方法2 效果 0 论文工作 不客气的说CLIP和扩散模型的成功让计算式视觉领域几乎所有工作都重新做了一遍。 CLIP(对比语言-图像预训练)论文提出了一种新的对比学习方法&a…

28-ESP32-S3 lwIP 轻量级 TCP/IP 协议栈

ESP32-S3 lwIP 介绍 ESP32-S3 是一款集成了Wi-Fi 和蓝牙功能的微控制器。它的设计初衷是为了方便嵌入式系统的开发。不过你可能会好奇,ESP32-S3 怎么实现与外部网络的通信呢?这里就要提到一个开源的 TCP/IP 协议栈,它叫做lwIP(轻…

博客系统多模块开发

创建工程 创建父工程 删除src目录&#xff0c;在pom.xml添加依赖&#xff1a; <!--统一版本 字符编码--><properties><maven.compiler.source>8</maven.compiler.source><maven.compiler.target>8</maven.compiler.target><project.b…

使用 Flask 和 Vue.js 构建 Web 应用

文章目录 入门1. 设置 Flask 后端2. 设置 Vue.js 前端 将 Flask 与 Vue.js 集成1. 配置 Flask 来提供 Vue.js 文件2. 构建 Vue.js 组件3. 运行应用程序 结论 在现代 Web 开发中&#xff0c;创建动态和响应式的应用通常涉及将后端框架如 Flask 与前端库如 Vue.js 结合起来。这种…

职责链设计模式

职责链设计模式&#xff08;Chain of Responsibility Design Pattern&#xff09;是一种行为设计模式&#xff0c;使多个对象都有机会处理请求&#xff0c;从而避免请求的发送者和接收者之间的耦合。这些对象被链接成一条链&#xff0c;沿着这条链传递请求&#xff0c;直到有一…

2024年5月20日 (周一) 叶子游戏新闻

报告老板&#xff0c;现在就加班&#xff01;《职场浮生记》抢先体验版现已上线今天由LeiYun Games开发&#xff0c;2P Games发行的《职场浮生记》正式在Steam平台推出抢先体验版。玩家将跟随主角的步伐踏入一个最为真实的职场环境之中&#xff0c;在生活与工作之间找寻平衡&am…

数据库多表查询

多表查询&#xff1a; SELECT *FROM stu_table,class WHERE stu_table.c_idclass.c_id; 多表查询——内连接 查询两张表交集部分。 隐式内连接&#xff1a; #查询学生姓名&#xff0c;和班级名称&#xff0c;隐式调用 SELECT stu_table.s_name,class.c_name FROM stu_table…