【目标跟踪网络训练 Market-1501 数据集】DeepSort 训练自己的跟踪网络模型

前言

Deepsort之所以可以大量避免IDSwitch,是因为Deepsort算法中特征提取网络可以将目标检测框中的特征提取出来并保存,在目标被遮挡后又从新出现后,利用前后的特征对比可以将遮挡的后又出现的目标和遮挡之前的追踪的目标重新找到,大大减少了目标在遮挡后,追踪失败的可能。

一、数据集简介

Market-1501 数据集在清华大学校园中采集,夏天拍摄,在 2015 年构建并公开。它包括由6个摄像头(其中5个高清摄像头和1个低清摄像头)拍摄到的 1501 个行人、32668 个检测到的行人矩形框。每个行人至少由2个摄像头捕获到,并且在一个摄像头中可能具有多张图像。训练集有 751 人,包含 12,936 张图像,平均每个人有 17.2 张训练数据;测试集有 750 人,包含 19,732 张图像,平均每个人有 26.3 张测试数据。3368 张查询图像的行人检测矩形框是人工绘制的,而 gallery 中的行人检测矩形框则是使用DPM检测器检测得到的。

该数据集提供的固定数量的训练集和测试集均可以在single-shot或multi-shot测试设置下使用。

目录结构

Market-1501-v15.09.15

  ├── bounding_box_test

       ├── 0000_c1s1_000151_01.jpg

       ├── 0000_c1s1_000376_03.jpg

       ├── 0000_c1s1_001051_02.jpg

  ├── bounding_box_train

       ├── 0002_c1s1_000451_03.jpg

       ├── 0002_c1s1_000551_01.jpg

       ├── 0002_c1s1_000801_01.jpg

  ├── gt_bbox

       ├── 0001_c1s1_001051_00.jpg

       ├── 0001_c1s1_009376_00.jpg

       ├── 0001_c2s1_001976_00.jpg

  ├── gt_query

       ├── 0001_c1s1_001051_00_good.mat

       ├── 0001_c1s1_001051_00_junk.mat

  ├── query

       ├── 0001_c1s1_001051_00.jpg

       ├── 0001_c2s1_000301_00.jpg

       ├── 0001_c3s1_000551_00.jpg

  └── readme.txt

目录介绍

(1) “bounding_box_test”——用于测试集的 750 人,包含 19,732 张图像,前缀为 0000 表示在提取这 750 人的过程中DPM检测错的图(可能与query是同一个人),-1 表示检测出来其他人的图(不在这 750 人中)

(2) “bounding_box_train”——用于训练集的 751 人,包含 12,936 张图像

(3) “query”——为 750 人在每个摄像头中随机选择一张图像作为query,因此一个人的query最多有 6 个,共有 3,368 张图像

(4) “gt_query”——matlab格式,用于判断一个query的哪些图片是好的匹配(同一个人不同摄像头的图像)和不好的匹配(同一个人同一个摄像头的图像或非同一个人的图像)

(5) “gt_bbox”——手工标注的bounding box,用于判断DPM检测的bounding box是不是一个好的box

命名规则

以 0001_c1s1_000151_01.jpg 为例

1) 0001 表示每个人的标签编号,从0001到1501;

2) c1 表示第一个摄像头(camera1),共有6个摄像头;

3) s1 表示第一个录像片段(sequece1),每个摄像机都有数个录像段;

4) 000151 表示 c1s1 的第000151帧图片,视频帧率25fps;

5) 01 表示 c1s1_001051 这一帧上的第1个检测框,由于采用DPM检测器,对于每一帧上的行人可能会框出好几个bbox。00 表示手工标注框

二、跟踪模型介绍

特征提取的模型有很多,可以替换特征提取模型网络。

下述给出的是 deep_sort/deep/model.py 里面的模型代码

import torch
import torch.nn as nn
import torch.nn.functional as Fclass BasicBlock(nn.Module):def __init__(self, c_in, c_out, is_downsample=False):super(BasicBlock, self).__init__()self.is_downsample = is_downsampleif is_downsample:self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False)else:self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(c_out)self.relu = nn.ReLU(True)self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1,padding=1, bias=False)self.bn2 = nn.BatchNorm2d(c_out)if is_downsample:self.downsample = nn.Sequential(nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),nn.BatchNorm2d(c_out))elif c_in != c_out:self.downsample = nn.Sequential(nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),nn.BatchNorm2d(c_out))self.is_downsample = Truedef forward(self, x):y = self.conv1(x)y = self.bn1(y)y = self.relu(y)y = self.conv2(y)y = self.bn2(y)if self.is_downsample:x = self.downsample(x)return F.relu(x.add(y), True)def make_layers(c_in, c_out, repeat_times, is_downsample=False):blocks = []for i in range(repeat_times):if i == 0:blocks += [BasicBlock(c_in, c_out, is_downsample=is_downsample), ]else:blocks += [BasicBlock(c_out, c_out), ]return nn.Sequential(*blocks)class Net(nn.Module):def __init__(self, num_classes=751, reid=False):super(Net, self).__init__()# 3 128 64self.conv = nn.Sequential(nn.Conv2d(3, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),# nn.Conv2d(32,32,3,stride=1,padding=1),# nn.BatchNorm2d(32),# nn.ReLU(inplace=True),nn.MaxPool2d(3, 2, padding=1),)# 32 64 32self.layer1 = make_layers(64, 64, 2, False)# 32 64 32self.layer2 = make_layers(64, 128, 2, True)# 64 32 16self.layer3 = make_layers(128, 256, 2, True)# 128 16 8self.layer4 = make_layers(256, 512, 2, True)# 256 8 4self.avgpool = nn.AvgPool2d((8, 4), 1)# 256 1 1self.reid = reidself.classifier = nn.Sequential(nn.Linear(512, 256),nn.BatchNorm1d(256),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(256, num_classes),)def forward(self, x):x = self.conv(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.view(x.size(0), -1)# B x 128if self.reid:x = x.div(x.norm(p=2, dim=1, keepdim=True))return x# classifierx = self.classifier(x)return xif __name__ == '__main__':net = Net()x = torch.randn(4, 3, 128, 64)y = net(x)import ipdbipdb.set_trace()

三、数据集处理

splitDataset.py

用于存放训练的图片

# -*- coding:utf-8 -*-
# @author: 牧锦程
# @微信公众号: AI算法与电子竞赛
# @Email: m21z50c71@163.com
# @VX:fylaicaiimport os
from shutil import copyfile# You only need to change this line to your dataset download path
download_path = 'Market-1501-v15.09.15'if not os.path.isdir(download_path):print('please change the download_path')save_path = 'pytorch'
if not os.path.isdir(save_path):os.mkdir(save_path)# ------------------- query ----------------------
query_path = download_path + '/query'
query_save_path = save_path + '/query'
print("process: ", query_path)
if not os.path.isdir(query_save_path):os.mkdir(query_save_path)for root, dirs, files in os.walk(query_path, topdown=True):for name in files:if not name[-3:] == 'jpg':continueID = name.split('_')src_path = query_path + '/' + namedst_path = query_save_path + '/' + ID[0]if not os.path.isdir(dst_path):os.mkdir(dst_path)copyfile(src_path, dst_path + '/' + name)# ----------------- multi-query ------------------------
query_path = download_path + '/gt_bbox'
print("process: ", query_path)
# for dukemtmc-reid, we do not need multi-query
if os.path.isdir(query_path):query_save_path = save_path + '/multi-query'if not os.path.isdir(query_save_path):os.mkdir(query_save_path)for root, dirs, files in os.walk(query_path, topdown=True):for name in files:if not name[-3:] == 'jpg':continueID = name.split('_')src_path = query_path + '/' + namedst_path = query_save_path + '/' + ID[0]if not os.path.isdir(dst_path):os.mkdir(dst_path)copyfile(src_path, dst_path + '/' + name)# ------------------- gallery ----------------------
gallery_path = download_path + '/bounding_box_test'
gallery_save_path = save_path + '/gallery'
print("process: ", gallery_path)
if not os.path.isdir(gallery_save_path):os.mkdir(gallery_save_path)for root, dirs, files in os.walk(gallery_path, topdown=True):for name in files:if not name[-3:] == 'jpg':continueID = name.split('_')src_path = gallery_path + '/' + namedst_path = gallery_save_path + '/' + ID[0]if not os.path.isdir(dst_path):os.mkdir(dst_path)copyfile(src_path, dst_path + '/' + name)# ------------------ train ---------------------
train_path = download_path + '/bounding_box_train'
train_save_path = save_path + '/train'
val_save_path = save_path + '/test'
if not os.path.isdir(train_save_path):os.mkdir(train_save_path)os.mkdir(val_save_path)print("process: ", train_path)
for root, dirs, files in os.walk(train_path, topdown=True):for name in files:if not name[-3:] == 'jpg':continueID = name.split('_')src_path = train_path + '/' + namedst_path = train_save_path + '/' + ID[0]if not os.path.isdir(dst_path):os.mkdir(dst_path)# first image is used as val imagedst_path = val_save_path + '/' + ID[0]os.mkdir(dst_path)copyfile(src_path, dst_path + '/' + name)

四、模型训练

修改数据集路径

修改 data-dir 参数为自己的数据集路径

修改数据增强

增加一个尺寸修改

修改类别数量

代码中通过dataloader来获取,因此可以不进行修改

num_classes = max(len(trainloader.dataset.classes), len(testloader.dataset.classes))

修改保存模型名称

这里的修改是为例与原始的模型进行区分,可以做对比

训练结果

查看精度

先运行test.py,生成 features.pth

在运行 evaluate.py,得到如下精度:

五、链接作者

欢迎关注我的公众号:@AI算法与电子竞赛

硬性的标准其实限制不了无限可能的我们,所以啊!少年们加油吧!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/25272.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

企业网页制作

随着互联网的普及,企业网站已成为企业展示自己形象、吸引潜在客户、开拓新市场的重要方式。而企业网页制作则是构建企业网站的基础工作,它的质量和效率对于企业网站的成败至关重要。 首先,企业网页制作需要根据企业的特点和需求进行规划。在网…

前端 移动端 手机调试 (超简单,超有效 !)

背景:webpack工具构建下的vue项目 1. 找出电脑的ipv4地址 2. 替换 host 3. 手机连接电脑热点或者同一个wifi 。浏览器打开链接即可。

Spring运维之业务层测试数据回滚以及设置测试的随机用例

业务层测试数据回滚 我们之前在写dao层 测试的时候 如果执行到这边的代码 会在数据库 里面留下数据 运行一次留一次数据 开发有开发数据库,运行有运行数据库 我们先连数据库 在pom文件里引入mysql的驱动和mybatis-plus的依赖 在数据层写接口 用mybatis-plus进…

浅谈JDBC

文章目录 一、什么是 JDBC?二、JDBC 操作流程三、JDBC代码例子 一、什么是 JDBC? JDBC是一种可用于执行SQL语句的JAVA API,是链接数据库和JAVA应用程序的纽带。JDBC一般需要进行3个步骤:与数据库建立一个链接、向数据库发送SQL语…

当AGI能够自我复制并传播到任意电脑时,会怎样?

当人工通用智能(AGI)具备自我复制能力,并能够传播到任何一台计算机作为宿主机时,这种情景可能带来一系列深远的影响和挑战。以下是一些潜在的影响和可能的结果: 1. 安全威胁与恶意利用 AGI的自我复制和传播能力可能被…

openh264 场景变化检测算法源码分析

文件位置 openh264/codec/processing/scenechangedetection/SceneChangeDetection.cppopenh264/codec/processing/scenechangedetection/SceneChangeDetection.h 代码流程 说明: 通过代码流程分析,当METHOD_SCENE_CHANGE_DETECTION_SCREEN场景类型为时…

Web前端工资调整:深入剖析与全面解读

Web前端工资调整:深入剖析与全面解读 在快速发展的互联网行业中,Web前端技术日新月异,而与之紧密相关的工资调整也成为了业内热议的话题。本文将从四个方面、五个方面、六个方面和七个方面,深入剖析Web前端工资调整的原因、趋势、…

Linux -- 了解 vim

目录 vim Linux 怎么编写代码? 了解 vim 的模式 什么是命令模式? 命令模式下 vim 的快捷键: 光标定位: 复制粘贴: 删除及撤销: 注释代码: 什么是底行模式? ​编辑 ​编辑…

Java:111-SpringMVC的底层原理(中篇)

这里续写上一章博客(110章博客): 现在我们来学习一下高级的技术,前面的mvc知识,我们基本可以在67章博客及其后面相关的博客可以学习到,现在开始学习精髓: Spring MVC 高级技术: …

输入与随机数

Java的两个类库 输入 如何实现键盘输入?我们需要了解到Scanner这个类,其作用于及键盘输入。 类库:java.util 如何使用?分为3步走: 导入包(一般idea会帮做) import java.util.Scanner;创建对…

Java学习日志26:Double.NEGATIVE_INFINITY与Double.MIN_VALUE的区别

Double.NEGATIVE_INFINITY 和 Double.MIN_VALUE 都是 Java 中 Double 类的静态成员,但它们表示的含义不同。 Double.NEGATIVE_INFINITY 表示双精度浮点数的负无穷大。它是一个特殊的数值,在数值计算中通常用于表示一些特殊情况,比如除以零的…

Large-Scale LiDAR Consistent Mapping usingHierarchical LiDAR Bundle Adjustment

1. 代码地址 GitHub - hku-mars/HBA: [RAL 2023] A globally consistent LiDAR map optimization module 2. 摘要 重建精确一致的大规模激光雷达点云地图对于机器人应用至关重要。现有的基于位姿图优化的解决方案,尽管它在时间方面是有效的,但不能直接…

ubuntu使用docker安装openwrt

系统:ubuntu24.04 架构:x86 1. 安装docker 1.1 离线安装 docker下载地址 根据系统版本,依次下载最新的三个关于docker的软件包 container.io(注意后缀版本顺序)docker-ce-clidocker-ce 然后再ubuntu系统中依次按顺…

【召回第一篇】召回方法综述

各个网站上找的各位大神的优秀回答,记录再此。 首先是石塔西大佬的回答:工业界推荐系统中有哪些召回策略? 万变不离其宗:用统一框架理解向量化召回前言常读我的文章的同学会注意到,我一直强调、推崇,不要…

多种策略提升线上 tensorflow 模型推理速度

前言 本文以最常见的模型 Bi-LSTM-CRF 为例,总结了在实际工作中能有效提升在 CPU/GPU 上的推理速度的若干方法,包括优化模型结构,优化超参数,使用 onnx 框架等。当然如果你有充足的 GPU ,结合以上方法提升推理速度的效…

真空衰变,真正的宇宙级灾难,它到底有多可怕?

真空衰变,真正的宇宙级灾难,它到底有多可怕? 真空衰变 真空衰变(Vacuum decay)是物理学家根据量子场论推测出的一种宇宙中可能会发生的现象,这种现象被称为真正的宇宙级灾难,它到底有多可怕呢…

前端 Vue 操作文件方法(导出下载、图片压缩、文件上传和转换)

一、前言 本文对前端 Vue 项目开发过程中,经常遇到要对文件做一些相关操作,比如:文件导出下载、文件上传、图片压缩、文件转换等一些处理方法进行归纳整理,方便后续查阅和复用。 二、具体内容 1、后端的文件导出接口,…

【报文数据流中的反压处理】

报文数据流中的反压处理 1 带存储体的反压1.1 原理图1.2 Demo 尤其是在NP芯片中,经常涉及到报文的数据流处理;为了防止数据丢失,和各模块的流水处理;因此需要到反压机制; 反压机制目前接触到的有两种:一是基…

【深度学习】目标检测,Faster-RCNN算法训练,使用mmdetection训练

文章目录 资料环境数据测试 资料 https://mmdetection.readthedocs.io/zh-cn/latest/user_guides/config.html 环境 Dockerfile ARG PYTORCH"1.9.0" ARG CUDA"11.1" ARG CUDNN"8"FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}…

1731. 每位经理的下属员工数量

1731. 每位经理的下属员工数量 题目链接:1731. 每位经理的下属员工数量 代码如下: # Write your MySQL query statement below select a.employee_id as employee_id,a.name as name,count(b.employee_id) as reports_count,round(avg(b.age),0) as av…