基于YOLO的视频检测分析
分步骤实现视频处理、目标检测与追踪、动作分析和计数逻辑,然后整合成API。
以下是完整的解决方案,包含视频分析逻辑和API封装,使用Python、YOLOv8和FastAPI实现,代码如下:
Python代码实现
import json
import logging
import os
import shutil
import uuid
from collections import defaultdict, deque
from datetime import datetime
from typing import Optional, Tuple

import cv2
import numpy as np
from fastapi import FastAPI, UploadFile, BackgroundTasks
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from ultralytics import YOLO
# Initialize the FastAPI application.
app = FastAPI(title="Smart Retail Cabinet API")

# Configuration. One key per line: in a single-line layout a trailing "#"
# comment would turn every following key into comment text.
CONFIG = {
    "upload_folder": "uploads",
    "result_folder": "results",
    "model_path": "yolov8n_custom.pt",  # custom-trained product detection model
    "shelf_roi": [(100, 50), (600, 400)],  # shelf region corners (x1, y1), (x2, y2) in pixels
    "hand_roi": [(0, 400), (800, 600)],  # hand-operation region corners (x1, y1), (x2, y2)
    "frame_skip": 5,  # run detection/tracking on every Nth frame only
    "confidence_threshold": 0.7,  # minimum confidence passed to the YOLO tracker
}

# In-memory task stores (stand-ins for a real database: contents are lost on
# restart and are not shared between worker processes).
processing_queue = {}  # task_id -> {"status", "start_time"} while the task is unfinished
results_db = {}  # task_id -> final entry ({"status": "completed"/"failed", ...})
# Data models
class ProcessingRequest(BaseModel):
    """Task status payload: task id, processing status and optional result."""

    task_id: str
    status: str
    # None until the task has produced a result.
    result: Optional[dict] = None
# Helper functions
def create_folders():
    """Create the upload and result directories if they do not exist yet."""
    os.makedirs(CONFIG['upload_folder'], exist_ok=True)
    os.makedirs(CONFIG['result_folder'], exist_ok=True)


def save_upload_file(file: UploadFile) -> Tuple[str, str]:
    """Persist an uploaded video under a freshly generated task id.

    Args:
        file: the uploaded file; its underlying ``file.file`` object is read
            from its current position to the end.

    Returns:
        ``(task_id, file_path)`` where ``file_path`` is
        ``<upload_folder>/<task_id>.mp4``.
    """
    task_id = str(uuid.uuid4())
    file_path = os.path.join(CONFIG['upload_folder'], f"{task_id}.mp4")
    # Copy in chunks rather than reading the whole video into memory at once.
    with open(file_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    return task_id, file_path
# Core analysis logic
class CabinetAnalyzer:
    """Detect and track products in a video and count take/return events.

    Each tracked object is assigned a zone on every processed frame, based on
    the centre of its box (see ``is_in_region``). The zone is the shelf ROI,
    the hand ROI, or neither.

    A zone is *committed* for a track only after the object has been observed
    in it for ``confirm_frames`` consecutive observations. A committed
    shelf -> hand change is logged as ``TAKE`` (+1 for its class). A committed
    hand -> shelf change is logged as ``RETURN`` (-1). Each transition is
    counted once, however long the object then stays in its new zone.

    Optional CONFIG keys (defaults in parentheses):
        confirm_frames (3): consecutive observations needed to commit a zone.
        history_len (30): observations kept per track.
        track_ttl_frames (10 * frame_skip): a track not observed for more
            than this many raw frames is forgotten.
    """

    SHELF = "shelf"
    HAND = "hand"

    def __init__(self):
        self.model = YOLO(CONFIG['model_path'])
        self.shelf_roi = CONFIG['shelf_roi']
        self.hand_roi = CONFIG['hand_roi']
        self.confidence_threshold = CONFIG['confidence_threshold']
        # Clamp to >= 1: a value of 0 would make the frame-skip modulo divide by zero.
        self.frame_skip = max(1, int(CONFIG['frame_skip']))
        self.confirm_frames = max(1, int(CONFIG.get('confirm_frames', 3)))
        self.history_len = max(self.confirm_frames, int(CONFIG.get('history_len', 30)))
        self.track_ttl = int(CONFIG.get('track_ttl_frames', 10 * self.frame_skip))
        self._reset_state()

    def _reset_state(self):
        """Clear all per-video tracking and counting state."""
        history_len = self.history_len
        # track_id -> bounded deque of per-observation dicts
        self.track_history = defaultdict(lambda: deque(maxlen=history_len))
        # class_id -> net number of items taken (TAKEs minus RETURNs)
        self.inventory = defaultdict(int)
        # track ids whose box centre was inside the shelf ROI in the latest processed frame
        self.current_in_shelf = set()
        # track_id -> last committed zone (SHELF or HAND)
        self.track_zone = {}
        # track_id -> raw frame number of the most recent observation
        self.last_seen = {}

    def is_in_region(self, box, region):
        """Return True if the centre of ``box`` lies strictly inside ``region``.

        Args:
            box: ``(x1, y1, x2, y2)`` corner coordinates.
            region: ``[(rx1, ry1), (rx2, ry2)]`` min / max corners.
        """
        x1, y1, x2, y2 = box
        rx1, ry1 = region[0]
        rx2, ry2 = region[1]
        return (rx1 < (x1 + x2) / 2 < rx2) and (ry1 < (y1 + y2) / 2 < ry2)

    def _zone(self, in_shelf, in_hand):
        """Map the two ROI flags to SHELF, HAND or None (shelf wins if both are set)."""
        if in_shelf:
            return self.SHELF
        if in_hand:
            return self.HAND
        return None

    def _update_track(self, track_id, cls_id, position, in_shelf, in_hand, frame_count, event_log):
        """Record one observation of a track and log an event on a confirmed zone change."""
        history = self.track_history[track_id]
        history.append({
            "frame": frame_count,
            "position": position,
            "in_shelf": in_shelf,
            "in_hand": in_hand,
            "class_id": cls_id,
        })
        self.last_seen[track_id] = frame_count

        if len(history) < self.confirm_frames:
            return
        recent = list(history)[-self.confirm_frames:]
        zones = {self._zone(s["in_shelf"], s["in_hand"]) for s in recent}
        # Only act when every recent observation agrees on one real zone.
        if len(zones) != 1:
            return
        (zone,) = zones
        previous = self.track_zone.get(track_id)
        if zone is None or zone == previous:
            return
        self.track_zone[track_id] = zone

        if previous == self.SHELF and zone == self.HAND:
            event_type, delta = "TAKE", 1
        elif previous == self.HAND and zone == self.SHELF:
            event_type, delta = "RETURN", -1
        else:
            # First committed zone for this track: nothing has moved yet.
            return
        event_log.append({
            "type": event_type,
            "track_id": track_id,
            "class_id": cls_id,
            "frame": frame_count,
        })
        self.inventory[cls_id] += delta

    def _expire_tracks(self, frame_count):
        """Forget tracks not observed for more than ``track_ttl`` raw frames."""
        stale = [tid for tid, seen in self.last_seen.items() if frame_count - seen > self.track_ttl]
        for tid in stale:
            del self.last_seen[tid]
            self.track_history.pop(tid, None)
            self.track_zone.pop(tid, None)

    def _process_frame(self, frame, frame_count, event_log):
        """Run detection + tracking on one frame and update every tracked object."""
        results = self.model.track(
            frame,
            persist=True,  # keep tracker state between calls so IDs carry over frames
            conf=self.confidence_threshold,
            verbose=False,
        )
        detections = results[0].boxes
        in_shelf_now = set()
        # ``id`` is None when no track IDs were produced for this frame.
        if detections.id is not None:
            xyxy = detections.xyxy.cpu().numpy()
            track_ids = detections.id.int().cpu().tolist()
            class_ids = detections.cls.int().cpu().tolist()
            for box, track_id, cls_id in zip(xyxy, track_ids, class_ids):
                in_shelf = self.is_in_region(box, self.shelf_roi)
                in_hand = self.is_in_region(box, self.hand_roi)
                if in_shelf:
                    in_shelf_now.add(track_id)
                self._update_track(track_id, cls_id, box.tolist(), in_shelf, in_hand,
                                   frame_count, event_log)
        self.current_in_shelf = in_shelf_now

    def analyze_video(self, video_path: str):
        """Analyse a video file and return the counted take/return events.

        Returns:
            dict with ``inventory_changes`` (class_id -> net items taken),
            ``event_log`` (list of TAKE/RETURN event dicts) and
            ``total_frames`` (number of frames read from the video).

        Raises:
            ValueError: if OpenCV cannot open ``video_path``.
        """
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            cap.release()
            raise ValueError(f"Cannot open video file: {video_path}")
        # Start from a clean slate so reusing an analyzer does not mix videos.
        # NOTE(review): the YOLO tracker's own state (persist=True) is not reset here.
        self._reset_state()
        event_log = []
        frame_count = 0
        try:
            while True:
                success, frame = cap.read()
                if not success:
                    break
                frame_count += 1
                if frame_count % self.frame_skip != 0:
                    continue
                self._process_frame(frame, frame_count, event_log)
                self._expire_tracks(frame_count)
        finally:
            cap.release()
        return {
            "inventory_changes": dict(self.inventory),
            "event_log": event_log,
            "total_frames": frame_count,
        }
# API endpoints
logger = logging.getLogger(__name__)


@app.post("/analyze")
async def create_analysis_task(file: UploadFile, background_tasks: BackgroundTasks):
    """Save the uploaded video and schedule its analysis as a background task.

    Returns the new task id and the URL to poll for its status.
    """
    create_folders()
    # Writing the upload to disk is blocking I/O; keep it off the event loop.
    task_id, video_path = await run_in_threadpool(save_upload_file, file)
    processing_queue[task_id] = {
        "status": "queued",
        "start_time": datetime.now().isoformat(),
    }
    background_tasks.add_task(run_analysis, task_id, video_path)
    return JSONResponse({"task_id": task_id, "status_url": f"/status/{task_id}"})


@app.get("/status/{task_id}")
def get_task_status(task_id: str):
    """Return the final entry of a finished task, else its queue entry; 404 if unknown."""
    # Read the queue first: the worker stores the final entry in results_db
    # *before* removing the queue entry, so a task moving between the two
    # stores during this request is always found in one of them.
    queued = processing_queue.get(task_id)
    finished = results_db.get(task_id)
    if finished is not None:
        return finished
    if queued is not None:
        return queued
    return JSONResponse({"error": "Task not found"}, status_code=404)


def _run_analysis_sync(task_id: str, video_path: str):
    """Blocking worker: analyse the video and record the outcome for ``task_id``."""
    try:
        processing_queue[task_id]["status"] = "processing"
        analyzer = CabinetAnalyzer()
        result = analyzer.analyze_video(video_path)
        # Translate class ids into product names.
        class_names = analyzer.model.names
        final_result = {
            "items": [
                {"product": class_names[cls_id], "quantity": quantity}
                for cls_id, quantity in result["inventory_changes"].items()
                if quantity > 0
            ],
            "analysis_time": datetime.now().isoformat(),
            "metadata": {
                "duration_frames": result["total_frames"],
                "events_count": len(result["event_log"]),
            },
        }
        # Persist the result.
        result_file = os.path.join(CONFIG['result_folder'], f"{task_id}.json")
        with open(result_file, "w", encoding="utf-8") as f:
            json.dump(final_result, f)
        results_db[task_id] = {
            "status": "completed",
            "result": final_result,
            "result_url": f"/results/{task_id}",
        }
    except Exception as e:
        # Top-level boundary of a background job: record the failure for
        # clients polling /status, and keep the traceback in the logs.
        logger.exception("Analysis failed for task %s", task_id)
        results_db[task_id] = {"status": "failed", "error": str(e)}
    finally:
        processing_queue.pop(task_id, None)


async def run_analysis(task_id: str, video_path: str):
    """Background-task entry point for one analysis job.

    Model loading and frame-by-frame inference are synchronous and heavy.
    Running them directly in this coroutine would block the event loop, and
    so every other request, until the whole video had been processed. They
    therefore run in a worker thread.
    """
    await run_in_threadpool(_run_analysis_sync, task_id, video_path)


@app.get("/results/{task_id}")
def get_analysis_result(task_id: str):
    """Return the stored result of a completed task.

    Unknown tasks get 404. Tasks that have not completed successfully (e.g.
    failed) get their status entry with HTTP 409, because no result file
    exists for them.
    """
    entry = results_db.get(task_id)
    if entry is None:
        return JSONResponse({"error": "Result not found"}, status_code=404)
    if entry.get("status") != "completed":
        return JSONResponse(entry, status_code=409)
    result_file = os.path.join(CONFIG['result_folder'], f"{task_id}.json")
    try:
        with open(result_file, encoding="utf-8") as f:
            result_data = json.load(f)
    except FileNotFoundError:
        # File removed externally: fall back to the in-memory copy.
        result_data = entry["result"]
    return JSONResponse(result_data)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
系统架构说明:
- API接口层:
/analyze:接收视频上传并启动后台分析任务
/status/{task_id}:查询任务状态
/results/{task_id}:获取分析结果
- 视频分析核心模块:
使用YOLOv8对上传视频的帧进行目标检测与多目标追踪(按frame_skip跳帧处理)
根据物体中心点所在区域(货架区/手部区)的变化判断取放行为
采用多帧验证机制,减少误检、提高事件判断的准确性
- 数据流处理:
```mermaid
sequenceDiagram
    participant Client
    participant API
    participant Analyzer
    participant Storage
    Client->>API: POST /analyze (video.mp4)
    API->>Storage: 保存视频文件
    API->>Analyzer: 启动后台任务
    Analyzer->>Analyzer: 视频帧处理
    Analyzer->>Storage: 保存分析结果
    Client->>API: GET /status/123
    API->>Client: 返回处理状态
    Client->>API: GET /results/123
    API->>Client: 返回JSON结果
```
使用方法:
- 启动服务:
uvicorn main:app --reload  (假设代码保存为main.py;也可以直接运行 python main.py)
- 调用API:
# 上传视频
curl -X POST -F "file=@test.mp4" http://localhost:8000/analyze
# 响应示例
{"task_id": "550e8400-e29b-41d4-a716-446655440000", "status_url": "/status/550e8400-e29b-41d4-a716-446655440000"}
# 查询任务状态(status变为completed后再获取结果)
curl http://localhost:8000/status/550e8400-e29b-41d4-a716-446655440000
# 获取分析结果
curl http://localhost:8000/results/550e8400-e29b-41d4-a716-446655440000
# 响应示例
{
  "items": [
    {"product": "coca_cola", "quantity": 2},
    {"product": "pringles", "quantity": 1}
  ],
  "analysis_time": "2024-01-20T14:30:00",
  "metadata": {"duration_frames": 300, "events_count": 3}
}
优化建议:
- 模型优化:
使用TensorRT加速推理
量化模型到INT8精度
针对具体商品优化YOLO模型
- 系统扩展:
使用Redis存储任务队列
添加Celery分布式任务队列
实现视频分块处理
- 安全增强:
添加JWT身份验证
限制文件上传类型和大小
实现请求频率限制
- 功能扩展:
添加多摄像头支持
实现实时视频流分析
增加库存同步接口
该方案实现了从视频上传、取放行为分析到结果查询的完整流程,并通过API提供标准化接口,便于集成到各类零售管理系统。实际部署时需要根据硬件性能调整视频处理参数(如frame_skip、confidence_threshold),并根据摄像头的实际画面重新标定shelf_roi和hand_roi的区域坐标。