基于Pytorch框架的深度学习RegNet神经网络二十五种宝石识别分类系统源码

 第一步:准备数据

25种宝石数据,总共800张:

{ "0": "Alexandrite","1": "Almandine","2": "Benitoite","3": "Beryl Golden","4": "Carnelian", "5": "Cats Eye","6": "Danburite", "7": "Diamond","8": "Emerald","9": "Fluorite","10": "Garnet Red","11": "Hessonite","12": "Iolite","13": "Jade","14": "Kunzite","15": "Labradorite","16": "Malachite","17": "Onyx Black","18": "Pearl","19": "Quartz Beer","20": "Rhodochrosite","21": "Sapphire Blue","22": "Tanzanite","23": "Variscite","24": "Zircon"}

第二步:搭建模型

本文选择一个RegNet网络,其原理介绍如下:

该论文提出了一个新的网络设计范式,并不是专注于设计单个网络实例,而是设计了一个网络设计空间network design space。整个过程类似于经典的手工网络设计,但被提升到了设计空间的水平。使用本文的方法,作者探索了网络设计的结构方面,并得到了一个由简单、规则的网络构成了低维设计空间并称之为RegNet。RegNet设计空间提供了各个范围flop下简单、快速的网络。在类似的训练设置和flops下,RegNet的效果超过了EfficientNet同时在GPU上快了5倍

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import os
import math
import argparseimport torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_schedulerfrom model import create_regnet
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluatedef main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")print(args)print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')tb_writer = SummaryWriter()if os.path.exists("./weights") is False:os.makedirs("./weights")train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}# 实例化训练数据集train_dataset = MyDataSet(images_path=train_images_path,images_class=train_images_label,transform=data_transform["train"])# 实例化验证数据集val_dataset = MyDataSet(images_path=val_images_path,images_class=val_images_label,transform=data_transform["val"])batch_size = args.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=val_dataset.collate_fn)# 如果存在预训练权重则载入model = create_regnet(model_name=args.model_name,num_classes=args.num_classes).to(device)# print(model)if args.weights != "":if os.path.exists(args.weights):weights_dict = torch.load(args.weights, map_location=device)load_weights_dict = {k: v for k, v in weights_dict.items()if model.state_dict()[k].numel() == v.numel()}print(model.load_state_dict(load_weights_dict, strict=False))else:raise FileNotFoundError("not found weights file: {}".format(args.weights))# 是否冻结权重if args.freeze_layers:for name, para in model.named_parameters():# 除最后的全连接层外,其他权重全部冻结if "head" not in name:para.requires_grad_(False)else:print("train {}".format(name))pg = [p for p in model.parameters() if p.requires_grad]optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5)# Scheduler https://arxiv.org/pdf/1812.01187.pdflf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosinescheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)for epoch in range(args.epochs):# trainmean_loss = train_one_epoch(model=model,optimizer=optimizer,data_loader=train_loader,device=device,epoch=epoch)scheduler.step()# validateacc = evaluate(model=model,data_loader=val_loader,device=device)print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))tags = ["loss", "accuracy", "learning_rate"]tb_writer.add_scalar(tags[0], mean_loss, epoch)tb_writer.add_scalar(tags[1], acc, epoch)tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--num_classes', type=int, default=25)parser.add_argument('--epochs', type=int, default=100)parser.add_argument('--batch-size', type=int, default=4)parser.add_argument('--lr', type=float, default=0.001)parser.add_argument('--lrf', type=float, default=0.01)# 数据集所在根目录# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzparser.add_argument('--data-path', type=str,default=r"G:\demo\data\gemstone\archive_train")parser.add_argument('--model-name', default='RegNetY_400MF', help='create model name')# 预训练权重下载地址# 链接: https://pan.baidu.com/s/1XTo3walj9ai7ZhWz7jh-YA  密码: 8lmuparser.add_argument('--weights', type=str, default='regnety_400mf.pth',help='initial weights path')parser.add_argument('--freeze-layers', type=bool, default=False)parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')opt = parser.parse_args()main(opt)

第四步:统计正确率

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码的下载路径(新窗口打开链接):基于Pytorch框架的深度学习RegNet神经网络二十五种宝石识别分类系统源码

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/18245.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数字化农业新时代:图扑农林牧综合监控平台

利用图扑自研 HT for Web GIS 产品,结合遥感技术,构建可交互式的农林牧数据分析平台。该平台围绕地块总览、播种分析、牛只管理、设备查询四个维度,对地区的全貌、农场、村集体分布以及相应的环境进行多样化的可视化展示和进行数据支持&#…

爱岗敬业短视频:成都科成博通文化传媒公司

爱岗敬业短视频:传递正能量,塑造职场新风尚 在当今社会,短视频以其独特的传播方式和广泛的受众群体,成为了信息传播的重要渠道。在众多短视频内容中,以“爱岗敬业”为主题的短视频尤为引人注目,成都科成博…

js Ajax函数封装及使用

直接上代码 一、ajax函数封装 /*** ajax函数* param {Object} options 请求传入的对象参数*/ function ajax(options {}) {// 1. 参数校验// 校验请求地址必传,而只能是字符串类型if (!options.url || typeof (options.url) ! string) throw Error(url必传,只能是字符串);//…

每天发布1000个视频SOP之账号管理:抖音可用的上传管理账号的浏览器安装包23个

企业抖音运营: 全流程整套操作SOP, 每天发布1000个视频工作管理体系: (因为上传限制,分成3个压缩包资源上传) 这是其中的:SOP矩阵划管理登录抖音平台账号,上传管理运营账号&#…

FreeRtos进阶——队列的特殊用途

信号量与互斥量都一样,都是特殊的队列。但是只有互斥量实现了优先级继承机制。 信号量与互斥量与队列一样,在操作增加或者减少时,必须先关中断在进行操作! 信号量创建揭秘 图中信号量的创建过程,在代码中的体现本质就是…

设计模式 16 解释器模式 Interpreter Design Pattern

设计模式 16 解释器模式 Interpreter Design Pattern 1.定义 解释器模式 (Interpreter Design Pattern) 是一种行为型设计模式,它定义了一种语法表示,并提供了一种解释器来解释该语法表示的句子。 核心概念: 语法表示 (Grammar): 定义了…

如何使用 jQuery 库来删除 HTML 页面中指定的元素下的所有子元素,但是保留其中一个特定的子元素

如何使用 jQuery 库来删除 HTML 页面中指定的元素下的所有子元素&#xff0c;但是保留其中一个特定的子元素 示例如下&#xff1a; 假设我们有以下的 HTML 代码&#xff1a; <div id"container"><div id"header">Header</div><div…

现在股票交易佣金标准最低是万0.854,低佣金炒股开户方式和流程!

股票交易佣金的最低标准是万分之0.854&#xff1b; 证券公司股票交易佣金默认是万分之3&#xff1b; 无门槛的股票交易佣金是万分之1&#xff1b; 万分之0.854的佣金要求投资者资产达到一定规模&#xff0c;不同的证券公司规定不一样。 如果没有经过证券公司客户经理协商开…

【SQL学习进阶】从入门到高级应用(一)

文章目录 MySQL命令行基本命令数据库表的概述初始化测试数据熟悉测试数据 &#x1f308;你好呀&#xff01;我是 山顶风景独好 &#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01; &#x1f49d;希望您在这里可以感受到一份轻松愉快的氛围&#x…

SolidityFoundry 安全审计测试 tx.origin 漏洞

名称&#xff1a;tx.origin 漏洞 solidityproject/vulnerable-defi at master XuHugo/solidityproject GitHub 说明&#xff1a; tx.origin是Solidity中的一个全局变量&#xff1b;智能合约中使用该变量进行身份验证&#xff0c;会使合约容易受到网络钓鱼攻击。 msg.sende…

C++牛客周赛43题目分享(3)小红平分糖果,小红的完全平方数,小苯的字符串变化,小红的子数组排列判断

目录 ​编辑 1.前言 2.四道题目 2.1小红平分糖果 2.1.1题目描述 2.1.2输入描述 2.1.3输出描述 2.1.4示例 2.1.5代码 2.2小红的完全平方数 2.1.1题目描述 2.1.2输入描述 2.1.3输出描述 2.1.4示例 2.1.5代码 2.3小苯的字符串变化 2.1.1题目描述 2.1.2输入描述 …

Java 原子变量 一次通关

前言 Java中的原子变量是用于实现无锁的线程安全编程的一种机制。它们是java.util.concurrent.atomic包中的一部分&#xff0c;这个包提供了一系列原子类&#xff0c;用于执行原子操作。 主要类型 Java的原子包提供了多种原子类&#xff0c;包括&#xff1a; 基本类型&…

想自学编程,看编程书有些看不懂,下一步应该怎么办?

不管你从事什么工作&#xff0c;编程都有助于你的职业发展。学习编程将给你自己赋能。我喜欢尝试新想法&#xff0c;时刻都有希望启动的新项目。学会编程后&#xff0c;我就可以坐下来自己实现&#xff0c;而不需要依赖他人。 编程也会提升你在其他方面的技能。因为你熟练掌握…

对SpringBoot配置文件配置项加密原理

参考认识BeanFactoryPostProcessor和BeanDefinitionRegistryPostProcessor - 知乎 SpringBoot 之 Jasypt 实现yml配置文件加密_-djasypt.encryptor.password-CSDN博客 【springboot】jasypt加密_jasypt.encryptor.password-CSDN博客 实现&#xff1a; 导包&#xff1a; 使…

Gitlab不允许使用ssh拉取代码的解决方案

一、起因 之前一直是用ssh进行代码拉取&#xff0c;后来公司搞网安行动&#xff0c;不允许ssh进行连接拉取代码了 因为我是用shell写了个小型的CI/CD,部署前端项目用于后端联调的&#xff0c;因此在自动部署时&#xff0c;不方便人机交互&#xff0c;所以需要自动填充账密。 …

ZLMediaKit cmake 编译 要点

# 加载自定义模块 list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") project(ZLMediaKit LANGUAGES C CXX) # 使能 C11 set(CMAKE_CXX_STANDARD 11) ############################### SET(CMAKE_SYSTEM_NAME Linux) SET(CMAKE_SYSTEM_PROCESS…

护网2024-攻防对抗解决方案思路

一、护网行动简介 近年来&#xff0c;网络安全已被国家上升为国家安全的战略层面&#xff0c;网络安全同样也被视为维护企业业务持续性的关键。国家在网络安全治理方面不断出台法规与制度&#xff0c;并实施了一些大型项目和计划&#xff0c;如网络安全法、等级保护、网络安全…

【UE C++】 虚幻引擎C++开发需要掌握的C++和U++的基础知识有哪些?

目录 0 引言1 关键的 C 知识2 Unreal Engine 相关知识3 学习建议 &#x1f64b;‍♂️ 作者&#xff1a;海码007&#x1f4dc; 专栏&#xff1a;UE虚幻引擎专栏&#x1f4a5; 标题&#xff1a;【UE C】 虚幻引擎C开发需要掌握的C和U的基础知识有哪些&#xff1f;❣️ 寄语&…

【MySQL精通之路】安全(1)-安全指南

任何在连接到Internet的计算机上使用MySQL的人都应该阅读本节&#xff0c;以避免最常见的安全错误。 在讨论安全性时&#xff0c;有必要考虑充分保护整个服务器主机&#xff08;而不仅仅是MySQL服务器&#xff09;免受所有类型的适用攻击&#xff1a;窃听、更改、播放和拒绝服…

kafkastream

kafkastream的介绍&#xff1a; Kafka Streams是一个开源的流处理库&#xff0c;用于构建实时数据流应用程序和微服务。它是Apache Kafka项目的一部分&#xff0c;是一种基于事件驱动的流处理解决方案。 Kafka Streams提供了高级别的API&#xff0c;使开发人员能够以简单和声…