基于深度学习的图像分类或识别系统(含全套项目+PyQt5界面)

目录

一、项目界面

二、代码实现

1、数据集结构

2、设置需要模型的训练参数和指定数据集路径

3、网络代码

4、训练代码

5、评估代码

6、结果显示

三、项目代码


一、项目界面

二、代码实现

1、数据集结构

每一个文件夹对应一个类别的数据

2、设置需要模型的训练参数和指定数据集路径

# 数据名字标签
label_names = {0:"daisy",1:"dandelion",2:"rose",3:"sunflower",4:"tulip",}# 类别数量,根据label_names标签名自动得出
num_classes = len(label_names)# 重采样大小。如果无则填None
re_size = (28,28)# 训练集地址,默认即可
train_path = r"./data/train"
# 验证集地址,默认即可
val_path = r"./data/val"
# 测试集地址,默认即可
test_path = r"./data/test"# 图像后缀
img_ = "jpg"# 批量大小
batch_size = 64# 结果保存地址
save_results = r"./results"# 学习率
lr = 0.001# 迭代次数
epochs = 20# ----------划分数据集参数-----------
# 确定将数据集划分为训练集,验证集,测试集的比例
train_pct = 0.5
valid_pct = 0.1
test_pct = 0.4# 确定原图像数据集路径。默认即可
dataset_dir = r"./data/data"  # 原始数据集路径
# 确定数据集划分后保存的路径
split_dir = r"./data"         # 划分后保存路径
3、网络代码

该网络基于残差模型修改

import torch
import torch.nn as nn
import torchvision.models as modelsclass resnet18(nn.Module):def __init__(self, num_classes=5, pretrained=False):super(resnet18, self).__init__()# 加载ResNet-18模型self.model = models.resnet18(pretrained=pretrained)# print(self.model)# 更改全连接层以输出自定义类别数量self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)def forward(self, x):return self.model(x)if __name__ == '__main__':# 示例用法num_classes = 10model = resnet18(num_classes=num_classes)# 打印模型以确认更改print(model)
4、训练代码
import os
import torch
import torch.nn as nn
from models.resnet18 import resnet18
from utils.utils import train_and_val,plot_acc,plot_loss,plot_lr,MyDataset
import numpy as np
from torch.utils.data import DataLoader
import glob
import pandas as pd
import configdef main(epochs,model):device = torch.device("cuda" if torch.cuda.is_available() else "cpu")if not os.path.exists(config.save_results):os.makedirs(config.save_results)# ----------------------------模型加载-------------------------model = model.to(device)loss_function = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=5,gamma=0.9)  # 每经过5个epoch,学习率乘以0.9# ------------------------------------------------------------# ---------------------------加载数据--------------------------im_train_list = glob.glob(config.train_path + "/*/*." + config.img_)im_val_list = glob.glob(config.val_path + "/*/*." + config.img_)train_dataset = MyDataset(im_train_list, config.label_names)val_dataset = MyDataset(im_val_list, config.label_names)train_loader = DataLoader(train_dataset,batch_size=config.batch_size,shuffle=True)val_loader = DataLoader(val_dataset,batch_size=config.batch_size,shuffle=False)print("num of train", len(train_dataset))print("num of val", len(val_loader))# ------------------------------------------------------------# ---------------------------网络训练--------------------------history = train_and_val(epochs, model, train_loader,val_loader,loss_function, optimizer,scheduler,config.save_results,device)df = pd.DataFrame(history) # 转换为DataFramedf.to_excel(os.path.join(config.save_results,'history.xlsx'), index=False) # 保存为 Excel 文件plot_loss(np.arange(0,epochs),config.save_results, history)plot_acc(np.arange(0,epochs),config.save_results, history)plot_lr(np.arange(0,epochs),config.save_results, history)if __name__ == '__main__':model = resnet18(num_classes=config.num_classes)main(config.epochs,model)
5、评估代码
from sklearn.metrics import classification_report
import torch
import os
import torch.nn as nn
from tqdm import tqdm
import pandas as pd
from models.resnet18 import resnet18
import matplotlib.pyplot as plt
from utils.utils import MyDataset,reports
from torch.utils.data import DataLoader
import seaborn as sns
import glob
import configdef main(model):device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# ----------------------------模型加载-------------------------model = model.to(device)checkpoint = torch.load(os.path.join(config.save_results,"best.pth"))model.load_state_dict(checkpoint, strict=True)model.eval()# ------------------------------------------------------------# ---------------------------加载数据--------------------------im_test_list = glob.glob(config.test_path + "/*/*." + config.img_)test_dataset = MyDataset(im_test_list, config.label_names)test_loader = DataLoader(test_dataset,batch_size=config.batch_size,shuffle=False)print("num of test", len(test_loader))# ------------------------------------------------------------act = nn.Softmax(dim=-1)y_true, y_pred = [], []with torch.no_grad():with tqdm(total=len(test_loader)) as pbar:for images, labels in test_loader:outputs = act(model(images.to(device)))_, predicted = torch.max(outputs, 1)predicted = predicted.cpu()y_pred.extend(predicted.numpy())y_true.extend(labels.cpu().numpy())pbar.update(1)oa,aa,kappa,cls,cm = reports(y_true, y_pred)cr = classification_report(y_true, y_pred, target_names=config.label_names.values(), output_dict=True)df = pd.DataFrame(cr).transpose()df.to_csv(os.path.join(config.save_results,"classification_report.csv"), index=True)print("Accuracy is :", oa)with open(os.path.join(config.save_results,"results.txt"), "a") as file:file.write('OA:{:.4f} AA:{:.4f} kappa:{:.4f}\ncls:{}\n混淆矩阵:\n{}\n'.format(oa, aa, kappa,cls,cm))plt.figure(figsize=(10, 7))sns.heatmap(cm, annot=True, xticklabels=config.label_names.values(), yticklabels=config.label_names.values(), cmap='Blues', fmt="d")plt.xlabel('Predicted')plt.ylabel('True')plt.savefig(os.path.join(config.save_results,'test_confusion_matrix.png'))plt.clf()if __name__ == '__main__':model = resnet18()main(model)
6、结果显示

上述仅仅是简单演示,结果没有参考意义。

三、项目代码

本项目的代码通过以下链接下载:基于深度学习的图像分类或识别系统(含全套项目+PyQt5界面)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/53247.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

09.20 C++对C的扩充以及C++中的封装、SeqList

SeqList.h #ifndef SEQLIST_H #define SEQLIST_H#include <iostream> #include<memory.h> #include<stdlib.h> #include<string.h>using namespace std;//typedef int datatype; //类型重命名 using datatype int;//封装一个顺序表 class Seq…

app抓包 chrome://inspect/#devices

一、前言&#xff1a; 1.首先不支持flutter框架&#xff0c;可支持ionic、taro 2.初次需要翻墙 3.app为debug包&#xff0c;非release 二、具体步骤 1.谷歌浏览器地址&#xff1a;chrome://inspect/#devices qq浏览器地址&#xff1a;qqbrowser://inspect/#devi…

新媒体运营

一、新媒体运营的概念 1.新媒体 2.新媒体运营的五大方向 用户运营 产品运营 。。。 二、新媒体的岗位职责及要求 三、新媒体平台

快速开发与维护:探索 AndroidAnnotations

在移动应用开发的世界中&#xff0c;效率和可维护性是两个至关重要的要素。随着应用功能的不断增长和用户需求的不断变化&#xff0c;开发者们一直在寻找能够提高生产力的工具和框架。今天&#xff0c;我们将深入探讨一个能够帮助开发者实现快速开发和易于维护的框架——Androi…

dgl库安装

此篇文章继续上一篇pytorch已经安装成功的情况 &#xff08;python3.9&#xff0c;pytorch2.2.2&#xff0c;cuda11.8&#xff09; 上一篇pytorch安装教学链接 选择与之匹配的版本 输入下方代码进行测试 import dgl.data dataset dgl.data.CoraGraphDataset() print(‘Numb…

使用宝塔部署项目在win上

项目部署 注意&#xff1a; 前后端部署项目&#xff0c;需要两个域名&#xff08;二级域名&#xff0c;就是主域名结尾的域名&#xff0c;需要在主域名下添加就可以了&#xff09;&#xff0c;前端一个&#xff0c;后端一个 思路&#xff1a;访问域名就会浏览器会加载前端的代…

【Redis入门到精通二】Redis核心数据类型(String,Hash)详解

目录 Redis数据类型 1.String类型 &#xff08;1&#xff09;常见命令 &#xff08;2&#xff09;内部编码 2.Hash类型 &#xff08;1&#xff09;常见命令 &#xff08;2&#xff09;内部编码 Redis数据类型 查阅Redis官方文档可知&#xff0c;Redis提供给用户的核心数据…

【HTML5】html5开篇基础(1)

1.❤️❤️前言~&#x1f973;&#x1f389;&#x1f389;&#x1f389; Hello, Hello~ 亲爱的朋友们&#x1f44b;&#x1f44b;&#xff0c;这里是E绵绵呀✍️✍️。 如果你喜欢这篇文章&#xff0c;请别吝啬你的点赞❤️❤️和收藏&#x1f4d6;&#x1f4d6;。如果你对我的…

代码随想录冲冲冲 Day51 图论Part3

101. 孤岛的总面积 dfs 首先dfs的作用就是在遇到陆地的时候找到所有的周围陆地 对于这道题的dfs 会把所有的链接边缘的陆地变成海洋 这样在全部调整之后 剩下的就是孤岛了 这道题中的dfs的结束条件就是遇到海洋时 遇到每一个陆地就会把面积1&#xff0c;在每一次重新找到…

(Java企业 / 公司项目)点赞业务系统设计-批量查询点赞状态(二)

接着上一篇文章来搞,批量查询点赞状态。这个接口提供给其他的微服务调用所以这里会用到FeignClient 直接上接口 1. 接口信息 这里是查询多个业务的点赞状态,因此请求参数自然是业务id的集合。由于是查询当前用戶的点赞状态,因此无需传递用戶信息。当前用户指的是登录用户 …

ELF文件结构

ELF文件格式的最前部是 ELF文件头&#xff08;ELF Header&#xff09; &#xff0c;包含整个文件的基本属性。然后是各个节&#xff0c;ELF文件中与节有关的结构是 “节表&#xff08;Section Header Table&#xff09;”&#xff0c;节表描述ELF文件包含的所有节的信息。 文件…

layui时间选择器选择周 日月季度年

<!-- layui框架样式 --><link type"text/css" href"/static/plugins/layui/css/layui.css" rel"stylesheet" /><!-- layui框架js --><script type"text/javascript" src"/static/plugins/layui/layui.js&qu…

MYSQL面试知识点手册

第一部分&#xff1a;MySQL 基础知识 1.1 MySQL 简介 MySQL 是世界上最流行的开源关系型数据库管理系统之一&#xff0c;它以性能卓越、稳定可靠和易用性而闻名。MySQL 主要应用在 Web 开发、大型互联网公司、企业级应用等场景&#xff0c;且广泛用于构建高并发、高可用的数据…

Qt_多元素控件

目录 1、认识多元素控件 2、QListWidget 2.1 使用QListWidget 3、QTableWidget 3.1 使用QListWidget 4、QTreeWidget 4.1 使用QTreeWidget 5、QGroupBox 5.1 使用QGroupBox 6、QTabWidget 6.1 使用QTabWidget 结语 前言&#xff1a; 在Qt中&#xff0c;控件之间…

GAMES104:15 游戏引擎的玩法系统基础-学习笔记

文章目录 0&#xff0c;游戏性课程框架一&#xff0c;事件机制1.1 事件的定义1.2 callback的注册1.3 事件的分发系统 二&#xff0c;游戏逻辑与脚本系统2.1 特点和常见脚本语言2.2 脚本语言的GO管理2.3 脚本语言的架构2.4 可视化脚本 三&#xff0c;Gameplay 开发中的3C &#…

Zookeeper安装使用教程

# 安装 官网下载安装包 #配置文件 端口默认8080&#xff0c;可能需要更改一下 #启动 cd /Users/lisongsong/software/apache-zookeeper-3.7.2-bin/bin ./zkServer.sh start #查看运行状态 ./zkServer.sh status #停止 ./zkServer.sh stop #启动客户端 ./zkCli.sh ls /

深度学习之图像数据集增强(Data Augmentation)

文章目录 一、 数据增强概述二、python实现传统数据增强参考文献 一、 数据增强概述 数据增强&#xff08;Data Augmentation&#xff09;是一种技术&#xff0c;通过对现有数据进行各种变换和处理来生成新的训练样本&#xff0c;从而增加数据集的多样性和数量。这些变换可以是…

vue part 11

vuex的模块化与namespace 115_尚硅谷Vue技术_vuex模块化namespace_1_哔哩哔哩_bilibili 116_尚硅谷Vue技术_vuex模块化namespace_2_哔哩哔哩_bilibili vue-router路由 很常见的很重要的应用&#xff1a;Ajax请求&#xff0c;将响应的数据替换掉原先的代码从而实现不跳转页面…

网站SEO,该如何规范目标网站URL配置!

随着互联网技术的飞速发展&#xff0c;搜索引擎优化&#xff08;SEO&#xff09;在网站建设和运营中的重要性日益凸显。优化目标网站的URL配置&#xff0c;作为SEO策略中的关键环节&#xff0c;对于提升网站在搜索引擎中的排名和曝光度具有至关重要的作用。大连蝙蝠侠科技将从U…

滚珠花键与滚珠丝杆的区别与应用

在机械工业中&#xff0c;经常使用滚珠花键这种传动元件&#xff0c;人们经常拿它与滚珠丝杆相比较&#xff0c;甚至与之混淆。事实上&#xff0c;它们是不同的&#xff0c;滚珠花键和滚珠丝杆在机械传动领域中各有其独特的作用和特点。那么&#xff0c;两者之间的区别是什么呢…