【目标跟踪网络训练 Market-1501 数据集】DeepSort 训练自己的跟踪网络模型

前言

Deepsort之所以可以大量避免IDSwitch,是因为Deepsort算法中特征提取网络可以将目标检测框中的特征提取出来并保存,在目标被遮挡后又从新出现后,利用前后的特征对比可以将遮挡的后又出现的目标和遮挡之前的追踪的目标重新找到,大大减少了目标在遮挡后,追踪失败的可能。

一、数据集简介

Market-1501 数据集在清华大学校园中采集,夏天拍摄,在 2015 年构建并公开。它包括由6个摄像头(其中5个高清摄像头和1个低清摄像头)拍摄到的 1501 个行人、32668 个检测到的行人矩形框。每个行人至少由2个摄像头捕获到,并且在一个摄像头中可能具有多张图像。训练集有 751 人,包含 12,936 张图像,平均每个人有 17.2 张训练数据;测试集有 750 人,包含 19,732 张图像,平均每个人有 26.3 张测试数据。3368 张查询图像的行人检测矩形框是人工绘制的,而 gallery 中的行人检测矩形框则是使用DPM检测器检测得到的。

该数据集提供的固定数量的训练集和测试集均可以在single-shot或multi-shot测试设置下使用。

目录结构

Market-1501-v15.09.15

  ├── bounding_box_test

       ├── 0000_c1s1_000151_01.jpg

       ├── 0000_c1s1_000376_03.jpg

       ├── 0000_c1s1_001051_02.jpg

  ├── bounding_box_train

       ├── 0002_c1s1_000451_03.jpg

       ├── 0002_c1s1_000551_01.jpg

       ├── 0002_c1s1_000801_01.jpg

  ├── gt_bbox

       ├── 0001_c1s1_001051_00.jpg

       ├── 0001_c1s1_009376_00.jpg

       ├── 0001_c2s1_001976_00.jpg

  ├── gt_query

       ├── 0001_c1s1_001051_00_good.mat

       ├── 0001_c1s1_001051_00_junk.mat

  ├── query

       ├── 0001_c1s1_001051_00.jpg

       ├── 0001_c2s1_000301_00.jpg

       ├── 0001_c3s1_000551_00.jpg

  └── readme.txt

目录介绍

(1) “bounding_box_test”——用于测试集的 750 人,包含 19,732 张图像,前缀为 0000 表示在提取这 750 人的过程中DPM检测错的图(可能与query是同一个人),-1 表示检测出来其他人的图(不在这 750 人中)

(2) “bounding_box_train”——用于训练集的 751 人,包含 12,936 张图像

(3) “query”——为 750 人在每个摄像头中随机选择一张图像作为query,因此一个人的query最多有 6 个,共有 3,368 张图像

(4) “gt_query”——matlab格式,用于判断一个query的哪些图片是好的匹配(同一个人不同摄像头的图像)和不好的匹配(同一个人同一个摄像头的图像或非同一个人的图像)

(5) “gt_bbox”——手工标注的bounding box,用于判断DPM检测的bounding box是不是一个好的box

命名规则

以 0001_c1s1_000151_01.jpg 为例

1) 0001 表示每个人的标签编号,从0001到1501;

2) c1 表示第一个摄像头(camera1),共有6个摄像头;

3) s1 表示第一个录像片段(sequece1),每个摄像机都有数个录像段;

4) 000151 表示 c1s1 的第000151帧图片,视频帧率25fps;

5) 01 表示 c1s1_001051 这一帧上的第1个检测框,由于采用DPM检测器,对于每一帧上的行人可能会框出好几个bbox。00 表示手工标注框

二、跟踪模型介绍

特征提取的模型有很多,可以替换特征提取模型网络。

下述给出的是 deep_sort/deep/model.py 里面的模型代码

import torch
import torch.nn as nn
import torch.nn.functional as Fclass BasicBlock(nn.Module):def __init__(self, c_in, c_out, is_downsample=False):super(BasicBlock, self).__init__()self.is_downsample = is_downsampleif is_downsample:self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False)else:self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(c_out)self.relu = nn.ReLU(True)self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1,padding=1, bias=False)self.bn2 = nn.BatchNorm2d(c_out)if is_downsample:self.downsample = nn.Sequential(nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),nn.BatchNorm2d(c_out))elif c_in != c_out:self.downsample = nn.Sequential(nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),nn.BatchNorm2d(c_out))self.is_downsample = Truedef forward(self, x):y = self.conv1(x)y = self.bn1(y)y = self.relu(y)y = self.conv2(y)y = self.bn2(y)if self.is_downsample:x = self.downsample(x)return F.relu(x.add(y), True)def make_layers(c_in, c_out, repeat_times, is_downsample=False):blocks = []for i in range(repeat_times):if i == 0:blocks += [BasicBlock(c_in, c_out, is_downsample=is_downsample), ]else:blocks += [BasicBlock(c_out, c_out), ]return nn.Sequential(*blocks)class Net(nn.Module):def __init__(self, num_classes=751, reid=False):super(Net, self).__init__()# 3 128 64self.conv = nn.Sequential(nn.Conv2d(3, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),# nn.Conv2d(32,32,3,stride=1,padding=1),# nn.BatchNorm2d(32),# nn.ReLU(inplace=True),nn.MaxPool2d(3, 2, padding=1),)# 32 64 32self.layer1 = make_layers(64, 64, 2, False)# 32 64 32self.layer2 = make_layers(64, 128, 2, True)# 64 32 16self.layer3 = make_layers(128, 256, 2, True)# 128 16 8self.layer4 = make_layers(256, 512, 2, True)# 256 8 4self.avgpool = nn.AvgPool2d((8, 4), 1)# 256 1 1self.reid = reidself.classifier = nn.Sequential(nn.Linear(512, 256),nn.BatchNorm1d(256),nn.ReLU(inplace=True),nn.Dropout(),nn.Linear(256, num_classes),)def forward(self, x):x = self.conv(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.view(x.size(0), -1)# B x 128if self.reid:x = x.div(x.norm(p=2, dim=1, keepdim=True))return x# classifierx = self.classifier(x)return xif __name__ == '__main__':net = Net()x = torch.randn(4, 3, 128, 64)y = net(x)import ipdbipdb.set_trace()

三、数据集处理

splitDataset.py

用于存放训练的图片

# -*- coding:utf-8 -*-
# @author: 牧锦程
# @微信公众号: AI算法与电子竞赛
# @Email: m21z50c71@163.com
# @VX:fylaicaiimport os
from shutil import copyfile# You only need to change this line to your dataset download path
download_path = 'Market-1501-v15.09.15'if not os.path.isdir(download_path):print('please change the download_path')save_path = 'pytorch'
if not os.path.isdir(save_path):os.mkdir(save_path)# ------------------- query ----------------------
query_path = download_path + '/query'
query_save_path = save_path + '/query'
print("process: ", query_path)
if not os.path.isdir(query_save_path):os.mkdir(query_save_path)for root, dirs, files in os.walk(query_path, topdown=True):for name in files:if not name[-3:] == 'jpg':continueID = name.split('_')src_path = query_path + '/' + namedst_path = query_save_path + '/' + ID[0]if not os.path.isdir(dst_path):os.mkdir(dst_path)copyfile(src_path, dst_path + '/' + name)# ----------------- multi-query ------------------------
query_path = download_path + '/gt_bbox'
print("process: ", query_path)
# for dukemtmc-reid, we do not need multi-query
if os.path.isdir(query_path):query_save_path = save_path + '/multi-query'if not os.path.isdir(query_save_path):os.mkdir(query_save_path)for root, dirs, files in os.walk(query_path, topdown=True):for name in files:if not name[-3:] == 'jpg':continueID = name.split('_')src_path = query_path + '/' + namedst_path = query_save_path + '/' + ID[0]if not os.path.isdir(dst_path):os.mkdir(dst_path)copyfile(src_path, dst_path + '/' + name)# ------------------- gallery ----------------------
gallery_path = download_path + '/bounding_box_test'
gallery_save_path = save_path + '/gallery'
print("process: ", gallery_path)
if not os.path.isdir(gallery_save_path):os.mkdir(gallery_save_path)for root, dirs, files in os.walk(gallery_path, topdown=True):for name in files:if not name[-3:] == 'jpg':continueID = name.split('_')src_path = gallery_path + '/' + namedst_path = gallery_save_path + '/' + ID[0]if not os.path.isdir(dst_path):os.mkdir(dst_path)copyfile(src_path, dst_path + '/' + name)# ------------------ train ---------------------
train_path = download_path + '/bounding_box_train'
train_save_path = save_path + '/train'
val_save_path = save_path + '/test'
if not os.path.isdir(train_save_path):os.mkdir(train_save_path)os.mkdir(val_save_path)print("process: ", train_path)
for root, dirs, files in os.walk(train_path, topdown=True):for name in files:if not name[-3:] == 'jpg':continueID = name.split('_')src_path = train_path + '/' + namedst_path = train_save_path + '/' + ID[0]if not os.path.isdir(dst_path):os.mkdir(dst_path)# first image is used as val imagedst_path = val_save_path + '/' + ID[0]os.mkdir(dst_path)copyfile(src_path, dst_path + '/' + name)

四、模型训练

修改数据集路径

修改 data-dir 参数为自己的数据集路径

修改数据增强

增加一个尺寸修改

修改类别数量

代码中通过dataloader来获取,因此可以不进行修改

num_classes = max(len(trainloader.dataset.classes), len(testloader.dataset.classes))

修改保存模型名称

这里的修改是为例与原始的模型进行区分,可以做对比

训练结果

查看精度

先运行test.py,生成 features.pth

在运行 evaluate.py,得到如下精度:

五、链接作者

欢迎关注我的公众号:@AI算法与电子竞赛

硬性的标准其实限制不了无限可能的我们,所以啊!少年们加油吧!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/25272.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

企业网页制作

随着互联网的普及,企业网站已成为企业展示自己形象、吸引潜在客户、开拓新市场的重要方式。而企业网页制作则是构建企业网站的基础工作,它的质量和效率对于企业网站的成败至关重要。 首先,企业网页制作需要根据企业的特点和需求进行规划。在网…

前端 移动端 手机调试 (超简单,超有效 !)

背景:webpack工具构建下的vue项目 1. 找出电脑的ipv4地址 2. 替换 host 3. 手机连接电脑热点或者同一个wifi 。浏览器打开链接即可。

Spring运维之业务层测试数据回滚以及设置测试的随机用例

业务层测试数据回滚 我们之前在写dao层 测试的时候 如果执行到这边的代码 会在数据库 里面留下数据 运行一次留一次数据 开发有开发数据库,运行有运行数据库 我们先连数据库 在pom文件里引入mysql的驱动和mybatis-plus的依赖 在数据层写接口 用mybatis-plus进…

openh264 场景变化检测算法源码分析

文件位置 openh264/codec/processing/scenechangedetection/SceneChangeDetection.cppopenh264/codec/processing/scenechangedetection/SceneChangeDetection.h 代码流程 说明: 通过代码流程分析,当METHOD_SCENE_CHANGE_DETECTION_SCREEN场景类型为时…

Linux -- 了解 vim

目录 vim Linux 怎么编写代码? 了解 vim 的模式 什么是命令模式? 命令模式下 vim 的快捷键: 光标定位: 复制粘贴: 删除及撤销: 注释代码: 什么是底行模式? ​编辑 ​编辑…

Java:111-SpringMVC的底层原理(中篇)

这里续写上一章博客(110章博客): 现在我们来学习一下高级的技术,前面的mvc知识,我们基本可以在67章博客及其后面相关的博客可以学习到,现在开始学习精髓: Spring MVC 高级技术: …

Large-Scale LiDAR Consistent Mapping usingHierarchical LiDAR Bundle Adjustment

1. 代码地址 GitHub - hku-mars/HBA: [RAL 2023] A globally consistent LiDAR map optimization module 2. 摘要 重建精确一致的大规模激光雷达点云地图对于机器人应用至关重要。现有的基于位姿图优化的解决方案,尽管它在时间方面是有效的,但不能直接…

ubuntu使用docker安装openwrt

系统:ubuntu24.04 架构:x86 1. 安装docker 1.1 离线安装 docker下载地址 根据系统版本,依次下载最新的三个关于docker的软件包 container.io(注意后缀版本顺序)docker-ce-clidocker-ce 然后再ubuntu系统中依次按顺…

【召回第一篇】召回方法综述

各个网站上找的各位大神的优秀回答,记录再此。 首先是石塔西大佬的回答:工业界推荐系统中有哪些召回策略? 万变不离其宗:用统一框架理解向量化召回前言常读我的文章的同学会注意到,我一直强调、推崇,不要…

多种策略提升线上 tensorflow 模型推理速度

前言 本文以最常见的模型 Bi-LSTM-CRF 为例,总结了在实际工作中能有效提升在 CPU/GPU 上的推理速度的若干方法,包括优化模型结构,优化超参数,使用 onnx 框架等。当然如果你有充足的 GPU ,结合以上方法提升推理速度的效…

真空衰变,真正的宇宙级灾难,它到底有多可怕?

真空衰变,真正的宇宙级灾难,它到底有多可怕? 真空衰变 真空衰变(Vacuum decay)是物理学家根据量子场论推测出的一种宇宙中可能会发生的现象,这种现象被称为真正的宇宙级灾难,它到底有多可怕呢…

前端 Vue 操作文件方法(导出下载、图片压缩、文件上传和转换)

一、前言 本文对前端 Vue 项目开发过程中,经常遇到要对文件做一些相关操作,比如:文件导出下载、文件上传、图片压缩、文件转换等一些处理方法进行归纳整理,方便后续查阅和复用。 二、具体内容 1、后端的文件导出接口,…

【报文数据流中的反压处理】

报文数据流中的反压处理 1 带存储体的反压1.1 原理图1.2 Demo 尤其是在NP芯片中,经常涉及到报文的数据流处理;为了防止数据丢失,和各模块的流水处理;因此需要到反压机制; 反压机制目前接触到的有两种:一是基…

【深度学习】目标检测,Faster-RCNN算法训练,使用mmdetection训练

文章目录 资料环境数据测试 资料 https://mmdetection.readthedocs.io/zh-cn/latest/user_guides/config.html 环境 Dockerfile ARG PYTORCH"1.9.0" ARG CUDA"11.1" ARG CUDNN"8"FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}…

使用 Scapy 库编写 TCP 劫持攻击脚本

一、介绍 TCP劫持攻击(TCP Hijacking),也称为会话劫持,是一种攻击方式,攻击者在合法用户与服务器之间的通信过程中插入或劫持数据包,从而控制通信会话。通过TCP劫持,攻击者可以获取敏感信息、执…

mysql 更改数据存储目录

先停止 mysql :sudo systemctl start/stop mysql 新建新的目录, 比如 /mnt/data/systemdata/mysql/mysql_data sudo chown -R mysql:mysql /mnt/data/sysdata/mysql/mysql_data sudo chmod -R 750 /mnt/data/sysdata/mysql/mysql_data 更改mysql.cnf…

2024高考作文-ChatGPT完成答卷,邀请大家来打分

高考,愿你脑洞大开,知识点全都扎根脑海;考试时手感倍儿棒,答题如行云流水;成绩公布时,笑容如春风拂面,心情如阳光普照!高考加油,你一定行! 新课标I卷 试题内…

“深入探讨Java中的对象拷贝:浅拷贝与深拷贝的差异与应用“

前言:在Java编程中,深拷贝(Deep Copy)与浅拷贝(Shallow Copy)是两个非常重要的概念。它们涉及到对象在内存中的复制方式,对于理解对象的引用、内存管理以及数据安全都至关重要。 ✨✨✨这里是秋…

多粒度特征融合(细粒度图像分类)

多粒度特征融合(细粒度图像分类) 摘要Abstract1. 多粒度特征融合1.1 文献摘要1.2 研究背景1.3 创新点1.4 模型方法1.4.1 Swin-Transformer1.4.2 多粒度特征融合模块1.4.3 自注意力1.4.4 通道注意力1.4.5 图卷积网络1.4.6 基于Vision-Transformer的两阶段…

Rust 实战丨SSE(Server-Sent Events)

📌 SSE(Server-Sent Events)是一种允许服务器向客户端浏览器推送信息的技术。它是 HTML5 的一部分,专门用于建立一个单向的从服务器到客户端的通信连接。SSE的使用场景非常广泛,包括实时消息推送、实时通知更新等。 S…