基于yolov5和desnet的猫咪识别模型

前言

前段时间给学校的猫咪小程序搭建了识猫模型,可以通过猫咪的照片辨别出是那只猫猫,这里分享下具体的方案,先看效果图:

源代码在文末

模型训练

在训练服务器(或你的个人PC)上拉取本仓库代码。

图片数据准备


进入`data`目录,执行`npm install`安装依赖。(需要 Node.js 环境,不确定老版本 Node.js 兼容性,建议使用最新版本。)


复制`config.demo.ts`文件并改名为`config.ts`,填写Laf云环境的`LAF_APPID`;


执行`npm start`,脚本将根据小程序数据库记录拉取小程序云存储中的图片。

如果不打算从laf拉取数据,也可以自己制作数据集,只要保证文件格式如下就可以

catface文件下面的data文件中的photos中有若干个文件夹,每个文件夹名称为id,文件夹下为图片。

环境搭建


返回仓库根目录,执行`python -m pip install -r requirements.txt`安装依赖。(需要Python>=3.8。不建议使用特别新版本的 Python,可能有兼容性问题。)


如果是linux系统,可以直接执行`bash prepare_yolov5.sh`拉取YOLOv5目标检测模型所需的代码,然后下载并预处理模型数据。如果是windows系统可以自己手动从gihub上拉取yolov5的模型。


执行`python3 data_preprocess.py`,脚本将使用YOLOv5从`data/photos`的图片中识别出猫猫并截取到`data/crop_photos`目录。

开始训练

执行`python3 main.py`,使用默认参数训练一个识别猫猫图片的模型。(你可以通过`python3 main.py --help`查看帮助来自定义一些训练参数。)程序运行结束时,你应当看到目录的export文件夹下存在`cat.onnx`和`cat.json`两个文件。(训练数据使用TensorBoard记录在`lightning_logs`文件夹下。若要查看准确率等信息,请自行运行TensorBoard。)


执行`python3 main.py --data data/photos --size 224 --name fallback`,使用修改后的参数训练一个在YOLOv5无法找到猫猫时使用的全图识别模型。程序运行结束时,你应当看到目录的export文件夹下存在`fallback.onnx`和`fallback.json`两个文件。

这里介绍下模型类的代码,我们定义了学习率,网络指定为densenet21

import torch
import torch.nn as nn
from torchvision import models
import torch.optim as optim
from pytorch_lightning import LightningModule
import torchmetrics
from typing import Tupleclass CatFaceModule(LightningModule):def __init__(self, num_classes: int, lr: float):super(CatFaceModule, self).__init__()self.save_hyperparameters()self.net = models.densenet121(num_classes=num_classes)self.loss_func = nn.CrossEntropyLoss()def forward(self, x: torch.Tensor) -> torch.Tensor:return self.net(x)def training_step(self, batch: Tuple[torch.Tensor, torch.LongTensor], batch_idx: int) -> torch.Tensor:loss, acc = self.do_step(batch)self.log('train/loss', loss, on_step=True, on_epoch=True)self.log('train/acc', acc, on_step=True, on_epoch=True)return lossdef validation_step(self, batch, batch_idx: int):loss, acc = self.do_step(batch)self.log('val/loss', loss, on_step=False, on_epoch=True)self.log('val/acc', acc, on_step=False, on_epoch=True)def do_step(self, batch: Tuple[torch.Tensor, torch.LongTensor]) -> Tuple[torch.Tensor, torch.Tensor]:# shape: x (B, C, H, W), y (B), w (B)x, y = batch# shape: out (B, num_classes)out = self.net(x)loss = self.loss_func(out, y)with torch.no_grad():# 每个类别分别计算准确率,以平衡地综合考虑每只猫的准确率accuracy_per_class = torchmetrics.functional.accuracy(out, y, task="multiclass", num_classes=self.hparams['num_classes'], average=None)# 去掉batch中没有出现的类别,这些位置为nannan_mask = accuracy_per_class.isnan()accuracy_per_class = accuracy_per_class.masked_fill(nan_mask, 0)# 剩下的位置取均值acc = accuracy_per_class.sum() / (~nan_mask).sum()return loss, accdef configure_optimizers(self) -> optim.Optimizer:return optim.Adam(self.parameters(), lr=self.hparams['lr'])

在模型训练完毕后可以运行我编写的modelTest,在这个文件中替换图片为自己的图片,观察输出是否正常,正常输出是这样的:

在这个输出中,通过yolo检测了图片中是否含有猫咪,通过densenet对图片所属于的类进行概率计算,概率和id按照概率从大到小排序返回。

接口实现

我们训练了两个densenet模型,一个是全图像的输入为228的模型a,一个是输入图像为128的模型b,当请求打到服务器时,应用程序会先通过yolo检测是否有猫,有的话就截取猫咪图像,使用模型b;否则不截取,使用模型a。

以下是代码:

from typing import Any
from werkzeug.datastructures import FileStorageimport torch
from PIL import Image
import numpy as np
import onnxruntime
from flask import Flask, request
from dotenv import load_dotenv
import os
import json
import time
from base64 import b64encode
from hashlib import sha256load_dotenv("./env", override=True)HOST_NAME = os.environ['HOST_NAME']
PORT = int(os.environ['PORT'])SECRET_KEY = os.environ['SECRET_KEY']
TOLERANT_TIME_ERROR = int(os.environ['TOLERANT_TIME_ERROR']) # 可以容忍的时间戳误差(s)IMG_SIZE = int(os.environ['IMG_SIZE'])
FALLBACK_IMG_SIZE = int(os.environ['FALLBACK_IMG_SIZE'])CAT_BOX_MAX_RET_NUM = int(os.environ['CAT_BOX_MAX_RET_NUM']) # 最多可以返回的猫猫框个数
RECOGNIZE_MAX_RET_NUM = int(os.environ['RECOGNIZE_MAX_RET_NUM']) # 最多可以返回的猫猫识别结果个数print("==> loading models...")
assert os.path.isdir("export"), "*** export directory not found! you should export the training checkpoint to ONNX model."crop_model = torch.hub.load('yolov5', 'custom', 'yolov5/yolov5m.onnx', source='local')with open("export/cat.json", "r") as fp:cat_ids = json.load(fp)
cat_model = onnxruntime.InferenceSession("export/cat.onnx", providers=["CPUExecutionProvider"])with open("export/cat.json", "r") as fp:fallback_ids = json.load(fp)
fallback_model = onnxruntime.InferenceSession("export/cat.onnx", providers=["CPUExecutionProvider"])print("==> models are loaded.")app = Flask(__name__)
# 限制post大小为10MB
app.config['MAX_CONTENT_LENGTH'] = 10 * 1024 * 1024def wrap_ok_return_value(data: Any) -> str:return json.dumps({'ok': True,'message': 'OK','data': data})def wrap_error_return_value(message: str) -> str:return json.dumps({'ok': False,'message': message,'data': None})def check_signature(photo: FileStorage, timestamp: int, signature: str) -> bool:if abs(timestamp - time.time()) > TOLERANT_TIME_ERROR:return FalsephotoBase64 = b64encode(photo.read()).decode()photo.seek(0) # 重置读取位置,避免影响后续操作signatureData = (photoBase64 + str(timestamp) + SECRET_KEY).encode()return signature == sha256(signatureData).hexdigest()@app.route("/recognizeCatPhoto", methods=["POST"])
@app.route("/recognizeCatPhoto/", methods=["POST"])
def recognize_cat_photo():try:photo = request.files['photo']timestamp = int(request.form['timestamp'])signature = request.form['signature']if not check_signature(photo, timestamp=timestamp, signature=signature):return wrap_error_return_value("fail signature check.")src_img = Image.open(photo).convert("RGB")# 使用 YOLOv5 进行目标检测,结果为[{xmin, ymin, xmax, ymax, confidence, class, name}]格式results = crop_model(src_img).pandas().xyxy[0].to_dict('records')# 过滤非cat目标cat_results = list(filter(lambda target: target['name'] == 'cat', results))if len(cat_results) >= 1:cat_idx = int(request.form['catIdx']) if 'catIdx' in request.form and int(request.form['catIdx']) < len(cat_results) else 0# 裁剪出(指定的)catcat_result = cat_results[cat_idx]crop_box = cat_result['xmin'], cat_result['ymin'], cat_result['xmax'], cat_result['ymax']# 裁剪后直接resize到正方形src_img = src_img.crop(crop_box).resize((IMG_SIZE, IMG_SIZE))# 输入到cat模型img = np.array(src_img, dtype=np.float32).transpose((2, 0, 1)) / 255scores = cat_model.run([node.name for node in cat_model.get_outputs()], {cat_model.get_inputs()[0].name: img[np.newaxis, :]})[0][0].tolist()# 按概率排序cat_id_with_score = sorted([dict(catID=cat_ids[i], score=scores[i]) for i in range(len(cat_ids))], key=lambda item: item['score'], reverse=True)else:# 没有检测到cat# 整张图片直接resize到正方形src_img = src_img.resize((FALLBACK_IMG_SIZE, FALLBACK_IMG_SIZE))img = np.array(src_img, dtype=np.float32).transpose((2, 0, 1)) / 255scores = fallback_model.run([node.name for node in fallback_model.get_outputs()], {fallback_model.get_inputs()[0].name: img[np.newaxis, :]})[0][0].tolist()# 按概率排序cat_id_with_score = sorted([dict(catID=fallback_ids[i], score=scores[i]) for i in range(len(fallback_ids))], key=lambda item: item['score'], reverse=True)return wrap_ok_return_value({'catBoxes': [{'xmin': item['xmin'],'ymin': item['ymin'],'xmax': item['xmax'],'ymax': item['ymax']} for item in cat_results][:CAT_BOX_MAX_RET_NUM],'recognizeResults': cat_id_with_score[:RECOGNIZE_MAX_RET_NUM]})except BaseException as err:return wrap_error_return_value(str(err))if __name__ == "__main__":app.run(host=HOST_NAME, port=PORT, debug=False)

我们可以在本地运行,如果想测试的小伙伴可以把接口中密钥校验的代码删除,然后直接发送post请求即可。

源码链接

cat-face: 猫脸识别程序,使用yolov5和densenet分类

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/15834.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

10款免费黑科技软件,强烈推荐!

1.AI视频生成——巨日禄 网页版https://aitools.jurilu.com/ "巨日禄 "是一款功能强大的文本视频生成器&#xff0c;可以快速将文本内容转换成极具吸引力的视频。操作简单&#xff0c;用户只需输入文字&#xff0c;选择喜欢的样式和模板&#xff0c; “巨日禄”就会…

Day39贪心算法part06

LC738单调递增的数字&#xff08;未掌握&#xff09; 思路分析&#xff1a;一旦出现strNum[i - 1] > strNum[i]的情况&#xff08;非单调递增&#xff09;&#xff0c;首先想让strNum[i - 1]–&#xff0c;然后strNum[i]给为9字符串是不可变的&#xff0c;不可以使用s.char…

树莓派学习笔记——树莓派的三种GPIO编码方式

1、板载编码&#xff08;Board pin numbering&#xff09;: 板载编码是树莓派上的一种GPIO引脚编号方式&#xff0c;它指的是按照引脚在树莓派主板上的物理位置来编号。这种方式对于初学者来说可能比较直观&#xff0c;因为它允许你直接根据引脚在板上的位置来编程。 2、BCM编…

Linux gurb2简介

文章目录 前言一、GRUB 2简介二、GRUB 2相关文件/文件夹2.1 /etc/default/grub文件2.2 /etc/grub.d/文件夹2.3 /boot/grub/grub.cfg文件 三、grubx64.efi参考资料 前言 简单来说&#xff0c;引导加载程序&#xff08;boot loader&#xff09;是计算机启动时运行的第一个软件程…

文章解读与仿真程序复现思路——电力系统保护与控制EI\CSCD\北大核心《计及温控厌氧发酵和阶梯碳交易的农村综合能源低碳经济调度》

本专栏栏目提供文章与程序复现思路&#xff0c;具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 论文与完整源程序_电网论文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 电网论文源程序-CSDN博客电网论文源…

网络域名是什么意思

网络域名&#xff0c;顾名思义&#xff0c;就是网络上的名字&#xff0c;类似于现实中的地址或姓名一样&#xff0c;用来标识网络上的一个或一组计算机或服务器的位置&#xff0c;以及它们的相应服务资源。网络域名是互联网上最基础的基础设施之一&#xff0c;是网络通信的“标…

【mysql】更新操作是如何执行的

现有一张表&#xff0c;建表语句如下&#xff1a; mysql> create table T(ID int primary key, c int);如果要将 ID2 这一行的a字段值加 1&#xff0c;SQL语句会这么写&#xff1a; mysql> update T set c c 1 where ID 2;上面这条sql执行时&#xff0c;分析器会通过词…

Nacos 微服务管理

Nacos 本教程将为您提供Nacos的基本介绍&#xff0c;并带您完成Nacos的安装、服务注册与发现、配置管理等功能。在这个过程中&#xff0c;您将学到如何使用Nacos进行微服务管理。下方是官方文档&#xff1a; Nacos官方文档 1. Nacos 简介 Nacos&#xff08;Naming and Confi…

操作符详解(上)(新手向)

操作符详解&#xff08;上&#xff09; 一&#xff0c;算术操作符&#xff08;双目操作符&#xff09;1:‘’,‘-’,‘*’2&#xff1a;‘/’&#xff0c;‘%’ 一&#xff0c;单目操作符1:‘’,‘-’2&#xff1a;‘!’3&#xff1a;‘&’4&#xff1a;‘*’5&#xff1a;…

linux 排查java内存溢出(持续更新中)

场景 tone.jar 启动后内存溢出,假设pid 为48044 排查 1.确定java程序的pid(进程id) ps 或 jps 都可以 ps -ef | grep tone jps -l 2.查看堆栈信息 jmap -heap 48044 3.查看对象的实例数量显示前30 jmap -histo:live 48044 | head -n 30 4.查看线程状态 jstack 48044

Spring 事件监听

参考&#xff1a;Spring事件监听流程分析【源码浅析】_private void processbean(final string beanname, fi-CSDN博客 一、简介 Spring早期通过实现ApplicationListener接口定义监听事件&#xff0c;Spring 4.2开始通过EventListener注解实现监听事件 FunctionalInterface p…

Rustdesk客户端源码编译

1.安装VCPKG windows平台vcpkg安装-CSDN博客 2.使用VCPKG安装: windows平台vcpkg安装-CSDN博客 配置VCPKG_ROOT环境变量: 安装静态库: ./vcpkg install libvpx:x64-windows-static libyuv:x64-windows-static opus:x64-windows-static aom:x64-windows-static 静态库安装成…

【C语言深度解剖】(15):动态内存管理和柔性数组

&#x1f921;博客主页&#xff1a;醉竺 &#x1f970;本文专栏&#xff1a;《C语言深度解剖》 &#x1f63b;欢迎关注&#xff1a;感谢大家的点赞评论关注&#xff0c;祝您学有所成&#xff01; ✨✨&#x1f49c;&#x1f49b;想要学习更多C语言深度解剖点击专栏链接查看&…

I.MX6ULL的官方 SDK 移植实验

系列文章目录 I.MX6ULL的官方 SDK 移植实验 I.MX6ULL的官方 SDK 移植实验 系列文章目录一、前言二、I.MX6ULL 官方 SDK 包简介三、硬件原理图四、试验程序编写4.1 SDK 文件移植4.2 创建 cc.h 文件4.3 编写实验代码 五、编译下载验证5.1编写 Makefile 和链接脚本5.2编译下载 一、…

列表元素添加的艺术:从单一到批量

新书上架~&#x1f447;全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我&#x1f446;&#xff0c;收藏下次不迷路┗|&#xff40;O′|┛ 嗷~~ 目录 一、引言 二、向列表中添加单一元素 1. append方法 2. insert方法 三、向列表中添加批量…

MySQL 存储过程(实验报告)

一、实验名称&#xff1a; 存储过程 二、实验日期&#xff1a; 2024 年5 月 25 日 三、实验目的&#xff1a; 掌握MySQL存储过程的创建及调用&#xff1b; 四、实验用的仪器和材料&#xff1a; 硬件&#xff1a;PC电脑一台&#xff1b; 配置&#xff1a;内存&#xff0…

Android 配置本地解决下载 Gradle 慢的问题

步骤1 打开项目下 gradle/wrapper/gradle-wrapper.properties 文件。 步骤2 文件内容如下。 #Sat May 25 16:24:00 CST 2024 distributionBaseGRADLE_USER_HOME distributionPathwrapper/dists distributionUrlhttps\://services.gradle.org/distributions/gradle-8.7-bin…

【Docker学习】深入研究命令docker exec

使用docker的过程中&#xff0c;我们会有多重情况需要访问容器。比如希望直接进入MySql容器执行命令&#xff0c;或是希望查看容器环境&#xff0c;进行某些操作或访问。这时就会用到这个命令&#xff1a;docker exec。 命令&#xff1a; docker container exec 描述&#x…

Jmeter预习第1天

Jmeter参数化&#xff08;重点&#xff09; 本质&#xff1a;使用参数的方式来替代脚本中的固定为测试数据 实现方式&#xff1a; 定义变量&#xff08;最基础&#xff09; 文件定义的方式&#xff08;所有测试数据都是固定的情况下[死数据]&#xff0c;eg:注册登录&#xff0…

Linux -- 进程间通信的五种方式

IPC&#xff08;InterProcess Communication&#xff09;的方式通常有管道&#xff08;包括无名管道和命名管道&#xff09;、消息队列、信号量、共享存储、Socket、Streams等。其中Socket和Stream支持不同主机上的两个进程IPC。 管道&#xff08;Pipes&#xff09;&#xff1a…