【计算机视觉】siamfc论文复现实现目标追踪

什么是目标跟踪

使用视频序列第一帧的图像(包括bounding box的位置),来找出目标出现在后序帧位置的一种方法。

什么是孪生网络结构

孪生网络结构其思想是将一个训练样本(已知类别)和一个测试样本(未知类别)输入到两个CNN(这两个CNN往往是权值共享的)中,从而获得两个特征向量,然后通过计算这两个特征向量的的相似度,相似度越高表明其越可能是同一个类别。

在这里插入图片描述

给你一张我的正脸照(没有经过美颜处理的),你该如何在人群中找到我呢?一种最直观的方案就是:“谁长得最像就是谁”。但是对于计算机来说,如何衡量“长得像”,并不是个简单的问题。这就涉及一种基本的运算——互相关(cross-correlation)。互相关运算可以用来度量两个信号之间的相似性。互相关得到的响应图中每个像素的响应高低代表着每个位置相似度的高低。

在这里插入图片描述

在目标领域中,最早利用这种思想的是SiamFC,其网络结构如上图。图中的φ就是CNN编码器,上下两个分支使用的CNN不仅结构相同,参数也是完全共享的(说白了就是同一个网络,并不存在孪生兄弟那样的设定)。z和x分别是要跟踪的目标模版图像(尺寸为127x127)和新的一帧中的搜索范围(尺寸为255x255)。二者经过同样的编码器后得到各自的特征图,对二者进行互相关运算后则会同样得到一个响应图(尺寸为17x17),其每一个像素的值对应了x中与z等大的一个对应区域出现跟踪目标的概率。

互相关运算的步骤,像极了我们手里拿着一张目标的照片(模板图像),然后把这个照片按在需要寻找目标的图片上(搜索图像)进行移动,然后求重叠部分相似度,从而找到这个目标,只不过为了计算机计算的方便,使用AlexNet对图像数据进行了编码/特征提取

下面这个版本中有一些动图,还是会帮助理解的:https://github.com/rafellerc/Pytorch-SiamFC

SiamFC代码分析

我们对siamese的结构大致就讲完了,还有一些内容结合代码来讲,效果更好。

3.1 training

3.1.1图像预处理

小超up给出训练的框图如下。训练过程中,首先要获取训练数据集的所有视频序列(每个视频序列的所有帧),我采用的是GOT-10k数据集训练;获取数据集之后进行图像预处理,对每一个视频序列抽取两帧图像并作数据增强处理(包括裁剪、resize等过程),分别作为目标模板图像和搜索图像;把经过图像处理的所有图像对加载并以batch_size输入网络得到预测输出;建立标签和损失函数,损失函数的输入是预测输出,目标是标签;设置优化策略,梯度下降损失,最终得到网络模型。

在这里插入图片描述

先贴代码,再分析:

def train(data_dir, net_path=None,save_dir='pretrained'):#从文件中读取图像数据集seq_dataset = GOT10k(data_dir,subset='train',return_meta=False)#定义图像预处理方法transforms = SiamFCTransforms(  exemplar_sz=cfg.exemplar_sz, #127instance_sz=cfg.instance_sz, #255context=cfg.context) #0.5#从读取的数据集每个视频序列配对训练图像并进行预处理,裁剪等train_dataset = GOT10kDataset(seq_dataset,transforms)

data_dir是存放GOT-10k数据集的文件路径,GOT-10k一共有9335个训练视频序列,seq_dataset返回的是所有视频序列的图片路径列表seq_dirs及对应groundtruth列表anno_files及一些其他信息,如下:

img

接下来是定义好图像预处理方法,在GOT10kDataset方法中对每个视频序列配对两帧图像,并使用定义好的图像处理方法,接下来直接进入该方法分析代码,GOT10kDataset的代码如下:

class GOT10kDataset(Dataset): #继承了torch.utils.data的Dataset类def __init__(self, seqs, transforms=None,pairs_per_seq=1):def __getitem__(self, index): #通过_sample_pair方法得到索引返回item=(z,x,box_z,box_x),然后经过transforms处理def __len__(self): #返回9335*pairs_per_seq对def _sample_pair(self, indices): #随机挑选两个索引,这里取的间隔不超过T=100def _filter(self, img0, anno, vis_ratios=None): #通过该函数筛选符合条件的有效索引val_indices

这里最重要的方法就是__getitem__,该方法最终返回处理后的图像,在内部首先调用了_sample_pair方法,用于提取两帧有效图片(有效的定义是图片目标的面积和高宽等有约束条件)的索引,在得到这两帧图片和对应groundtruth之后通过定义好的transforms进行处理,transforms是SiamFCTransforms类的实例化对象,该类中主要继承了resize图片大小和各种裁剪方式等,如代码所示:

class SiamFCTransforms(object):def __init__(self, exemplar_sz=127, instance_sz=255, context=0.5):self.exemplar_sz = exemplar_szself.instance_sz = instance_szself.context = context#transforms_z/x是数据增强方法self.transforms_z = Compose([RandomStretch(),     #随机resize图片大小,变化再[1 1.05]之内CenterCrop(instance_sz - 8),  #中心裁剪 裁剪为255-8RandomCrop(instance_sz - 2 * 8),   #随机裁剪  255-8->255-8-8CenterCrop(exemplar_sz),   #中心裁剪 255-8-8->127ToTensor()])                        #图片的数据格式从numpy转换成torch张量形式self.transforms_x = Compose([RandomStretch(),                   #s随机resize图片CenterCrop(instance_sz - 8),      #中心裁剪 裁剪为255-8RandomCrop(instance_sz - 2 * 8),  #随机裁剪 255-8->255-8-8ToTensor()])                      #图片数据格式转化为torch张量def __call__(self, z, x, box_z, box_x): #z,x表示传进来的图像z = self._crop(z, box_z, self.instance_sz)       #对z(x类似)图像 1、box转换(l,t,w,h)->(y,x,h,w),并且数据格式转为float32,得到center[y,x],和target_sz[h,w]x = self._crop(x, box_x, self.instance_sz)       #2、得到size=((h+(h+w)/2)*(w+(h+2)/2))^0.5*255(instance_sz)/127z = self.transforms_z(z)                         #3、进入crop_and_resize:传入z作为图片img,center,size,outsize=255(instance_sz),随机选方式填充,均值填充x = self.transforms_x(x)                         #   以center为中心裁剪一块边长为size大小的正方形框(注意裁剪时的padd边框填充问题),再resize成out_size=255(instance_sz)return z, x

实例化对象后,直接从__call__开始运行代码,首先关注的应该是_crop函数,该函数将原始的两帧图片分别以目标为中心,裁剪一块包含上下文信息的patch,patch的边长定义如下:

在这里插入图片描述

式中,w、h分别表示目标的宽和高。下面具体讲里面的_crop函数:

    def _crop(self, img, box, out_size):# convert box to 0-indexed and center based [y, x, h, w]box = np.array([box[1] - 1 + (box[3] - 1) / 2,box[0] - 1 + (box[2] - 1) / 2,box[3], box[2]], dtype=np.float32)center, target_sz = box[:2], box[2:]context = self.context * np.sum(target_sz)size = np.sqrt(np.prod(target_sz + context))size *= out_size / self.exemplar_szavg_color = np.mean(img, axis=(0, 1), dtype=float)interp = np.random.choice([cv2.INTER_LINEAR,cv2.INTER_CUBIC,cv2.INTER_AREA,cv2.INTER_NEAREST,cv2.INTER_LANCZOS4])patch = ops.crop_and_resize(img, center, size, out_size,border_value=avg_color, interp=interp)return patch

因为GOT-10k里面对于目标的bbox是以ltwh(即left, top, weight, height)形式给出的,上述代码一开始就先把输入的box变成center based,坐标形式变为[y, x, h, w],结合下面这幅图就非常好理解在这里插入图片描述img

crop_and_resize:

def crop_and_resize(img, center, size, out_size,border_type=cv2.BORDER_CONSTANT,border_value=(0, 0, 0),interp=cv2.INTER_LINEAR):# convert box to corners (0-indexed)size = round(size)  # the size of square cropcorners = np.concatenate((np.round(center - (size - 1) / 2),np.round(center - (size - 1) / 2) + size))corners = np.round(corners).astype(int)# pad image if necessarypads = np.concatenate((-corners[:2], corners[2:] - img.shape[:2]))npad = max(0, int(pads.max()))if npad > 0:img = cv2.copyMakeBorder(img, npad, npad, npad, npad,border_type, value=border_value)# crop image patchcorners = (corners + npad).astype(int)patch = img[corners[0]:corners[2], corners[1]:corners[3]]# resize to out_sizepatch = cv2.resize(patch, (out_size, out_size),interpolation=interp)return patch

在裁剪过程中会出现越界的情况,需要对原始图像边缘填充,填充值固定为图像的RGB均值,填充大小根据图像边缘越界最大值作为填充值,具体实现过程由以下代码完成。

# padding操作#corners表示目标的[ymin,xmin,ymax,xmax]pads = np.concatenate((-corners[:2], corners[2:] - img.shape[:2]))npad = max(0, int(pads.max())) #得到上下左右4个越界值中最大的与0对比,<0代表无越界if npad > 0:img = cv2.copyMakeBorder(img, npad, npad, npad, npad,cv2.BORDER_CONSTANT, value=img_average)

实验结果:

img

3.1.2加载训练数据、标签及损失函数

图像预处理完成后,得到了用与训练的9335对图像,将图像加载批量加载输入网络得到输出结果作为损失函数的input,损失函数的target是制定好的labels。

#加载训练数据集loader_dataset = DataLoader( dataset = train_dataset,batch_size=cfg.batch_size,shuffle=True,num_workers=cfg.num_workers,pin_memory=True,drop_last=True, )#初始化训练网络cuda = torch.cuda.is_available()  #支持GPU为Truedevice = torch.device('cuda:0' if cuda else 'cpu')  #cuda设备号为0model = AlexNet(init_weight=True)corr = _corr()model = model.to(device)corr = corr.to(device)# 设置损失函数和标签logist_loss = BalancedLoss()labels = _create_labels(size=[cfg.batch_size, 1, cfg.response_sz - 2, cfg.response_sz - 2])labels = torch.from_numpy(labels).to(device).float()

本小节主要讲网络输出的labels和损失函数,接下来只是小超up个人的一些理解,代码与论文理论部分形式不一致,但效果一样。先上图,论文中labels以及损失函数如下图:

img

img

然而代码中的labels值却是1和0,损失函数使用的是二值交叉熵损失函数F.binary_cross_entropy_with_logits,如下图推导所示,解释了为什么代码实现部分真正使用的labels值是1和0,而理论部分使用的是1和-1。

img

利用下面代码的这个_creat_labels方法可以得到标签。

def _create_labels(size):def logistic_labels(x, y, r_pos):# x^2+y^2<4 的位置设为为1,其他为0dist = np.sqrt(x ** 2 + y ** 2)labels = np.where(dist <= r_pos,    #r_os=2np.ones_like(x),  #np.ones_like(x),用1填充xnp.zeros_like(x)) #np.zeros_like(x),用0填充xreturn labels#获取标签的参数n, c, h, w = size  # [8,1,15,15]x = np.arange(w) - (w - 1) / 2  #x=[-7 -6 ....0....6 7]y = np.arange(h) - (h - 1) / 2  #y=[-7 -6 ....0....6 7]x, y = np.meshgrid(x, y)       #建立标签r_pos = cfg.r_pos / cfg.total_stride  # 16/8labels = logistic_labels(x, y, r_pos)#重复batch_size个label,因为网络输出是batch_size张response maplabels = labels.reshape((1, 1, h, w))   #[1,1,15,15]labels = np.tile(labels, (n, c, 1, 1))  #将labels扩展[8,1,15,15]return labels

验证结果如下图,只截取了部分labels,得到的labels对应输入,大小都是[8,1,15,15]

if __name__ == '__main__':labels = _create_labels([8,1,15,15])  #返回的label.shape=(8,1,15,15)

其中关于np.tile、np.meshgrid、np.where函数的使用可以去看这篇博客,最后出来的一个batch下某一个通道下的label就是下面这样的

img

3.1.3 优化策略

这里主要说一下学习率lr,随着训练次数epoch增多而减小,具体值如下公式image-20240721105017162,式中,initial为初始学习率,gamma是定义的超参,epoch为训练次数。整个优化器及学习率调整实现代码如下:

#建立优化器,设置指数变化的学习率optimizer = optim.SGD(model.parameters(),lr=cfg.initial_lr,              #初始化的学习率,后续会不断更新weight_decay=cfg.weight_decay,  #λ=5e-4,正则化momentum=cfg.momentum)          #v(now)=dx∗lr+v(last)∗momemtumgamma = np.power(                   #np.power(a,b) 返回a^bcfg.ultimate_lr / cfg.initial_lr,1.0 / cfg.epoch_num)lr_scheduler = ExponentialLR(optimizer, gamma)  #指数形式衰减,lr=initial_lr*(gamma^epoch)=
3.1.4 模型的训练与保存

一切准备工作就绪后,就开始训练了。代码中设定epoch_num为50次,训练时密切加上model.train(),告诉网络处于训练状态,这样,网络运行时就会利用pytorch的自动求导机制求导;在测试时,改为model.eval(),关闭自动求导。模型训练的步骤如代码所示:

# loop over epochs
for epoch in range(self.cfg.epoch_num):# update lr at each epochself.lr_scheduler.step(epoch=epoch)# loop over dataloaderfor it, batch in enumerate(dataloader):loss = self.train_step(batch, backward=True)print('Epoch: {} [{}/{}] Loss: {:.5f}'.format(epoch + 1, it + 1, len(dataloader), loss))sys.stdout.flush()# save checkpointif not os.path.exists(save_dir):os.makedirs(save_dir)net_path = os.path.join(save_dir, 'siamfc_alexnet_e%d.pth' % (epoch + 1))torch.save(self.net.state_dict(), net_path)

至此此份repo的训练应该差不多结束了

参考文档

siameseFC论文和代码解析

SiamFC 学习(论文、总结与分析)

siamfc-pytorch代码讲解(一):backbone&head

siamfc-pytorch代码讲解(二):train&siamfc

SiamFC代码分析(architecture、training、test)

http://www.360doc.com/content/19/0801/10/32196507_852333196.shtml

视频推荐

目标跟踪零基础代码入门(一):SiamFC_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/48472.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

代码解读:Diffusion Models中的长宽桶技术(Aspect Ratio Bucketing)

Diffusion Models专栏文章汇总&#xff1a;入门与实战 前言&#xff1a;自从SDXL提出了长宽桶技术之后&#xff0c;彻底解决了不同长宽比的图像输入问题&#xff0c;现在已经成为训练扩散模型必选的方案。这篇博客从代码详细解读如何在模型训练的时候运用长宽桶技术(Aspect Rat…

【机器学习】-- SVM核函数(超详细解读)

支持向量机&#xff08;SVM&#xff09;中的核函数是支持向量机能够处理非线性问题并在高维空间中学习复杂决策边界的关键。核函数在SVM中扮演着将输入特征映射到更高维空间的角色&#xff0c;使得原始特征空间中的非线性问题在高维空间中变得线性可分。 一、SVM是什么&#x…

时间卷积网络(TCN):序列建模的强大工具(附Pytorch网络模型代码)

这里写目录标题 1. 引言2. TCN的核心特性2.1 序列建模任务描述2.2 因果卷积2.3 扩张卷积2.4 残差连接 3. TCN的网络结构4. TCN vs RNN5. TCN的应用TCN的实现 1. 引言 引用自&#xff1a;Bai S, Kolter J Z, Koltun V. An empirical evaluation of generic convolutional and re…

Linux系统之部署扫雷小游戏(三)

Linux系统之部署扫雷小游戏(三) 一、小游戏介绍1.1 小游戏简介1.2 项目预览二、本次实践介绍2.1 本地环境规划2.2 本次实践介绍三、检查本地环境3.1 检查系统版本3.2 检查系统内核版本3.3 检查软件源四、安装Apache24.1 安装Apache2软件4.2 启动apache2服务4.3 查看apache2服…

大厂生产解决方案:泳道隔离机制

更多大厂面试内容可见 -> http://11come.cn 大厂生产解决方案&#xff1a;泳道隔离机制 背景 在公司中&#xff0c;由于项目多、开发人员多&#xff0c;一般会有多套测试环境&#xff08;可以理解为多个服务器&#xff09;&#xff0c;同一套服务会在多套测试环境中都部署…

如何解决微服务下引起的 分布式事务问题

一、什么是分布式事务&#xff1f; 虽然叫分布式事务&#xff0c;但不是一定是分布式部署的服务之间才会产生分布式事务。不是在同一个服务或同一个数据库架构下&#xff0c;产生的事务&#xff0c;也就是分布式事务。 跨数据源的分布式事务 跨服务的分布式事务 二、解决方…

配置服务器

参考博客 1. https://blog.csdn.net/qq_31278903/article/details/83146031 2. https://blog.csdn.net/u014374826/article/details/134093409 3. https://blog.csdn.net/weixin_42728126/article/details/88887350 4. https://blog.csdn.net/Dreamhai/article/details/109…

javac详解 idea maven内部编译原理 自制编译器

起因 不知道大家在开发中&#xff0c;有没有过下面这些疑问。有的话&#xff0c;今天就一次解答清楚。 如何使用javac命令编译一个项目&#xff1f;java或者javac的一些参数到底有什么用&#xff1f;idea或者maven是如何编译java项目的&#xff1f;&#xff08;你可能猜测底层…

【一刷《剑指Offer》】面试题 47:不用加减乘除做加法

力扣对应题目链接&#xff1a;LCR 190. 加密运算 - 力扣&#xff08;LeetCode&#xff09; 牛客对应题目链接&#xff1a;不用加减乘除做加法_牛客题霸_牛客网 (nowcoder.com) 一、《剑指Offer》对应内容 二、分析题目 sumdataA⊕dataB 非进位和&#xff1a;异或运…

Unity UGUI 之 Graphic Raycaster

本文仅作学习笔记与交流&#xff0c;不作任何商业用途 本文包括但不限于unity官方手册&#xff0c;唐老狮&#xff0c;麦扣教程知识&#xff0c;引用会标记&#xff0c;如有不足还请斧正 首先手册连接如下&#xff1a; Unity - Manual: Graphic Raycaster 笔记来源于&#xff…

无人车技术浪潮真的挡不住了~

正文 无人驾驶汽车其实也不算是新鲜玩意了&#xff0c;早在十年前大家都开始纷纷投入研发&#xff0c;在那时就已经蠢蠢欲动&#xff0c;像目前大部分智驾系统和辅助驾驶系统都是无人驾驶系统的一个中间过度版本&#xff0c;就像手机进入智能机时代的中间版本。 然而前段时间突…

SpringBoot 介绍和使用(详细)

使用SpringBoot之前,我们需要了解Maven,并配置国内源(为什么要配置这些,下面会详细介绍),下面我们将创建一个SpringBoot项目"输出Hello World"介绍. 1.环境准备 ⾃检Idea版本: 社区版: 2021.1 -2022.1.4 专业版: ⽆要求 如果个⼈电脑安装的idea不在这个范围, 需要…

LeetCode 热题 HOT 100 (001/100)【宇宙最简单版】

【链表】 No. 0160 相交链表 【简单】&#x1f449;力扣对应题目指路 希望对你有帮助呀&#xff01;&#xff01;&#x1f49c;&#x1f49c; 如有更好理解的思路&#xff0c;欢迎大家留言补充 ~ 一起加油叭 &#x1f4a6; 欢迎关注、订阅专栏 【力扣详解】谢谢你的支持&#x…

搜维尔科技:【产品推荐】Euleria Health Riablo 运动功能训练与评估系统

Euleria Health Riablo 运动功能训练与评估系统 Riablo提供一种创新的康复解决方案&#xff0c;将康复和训练变得可激励、可衡量和可控制。Riablo通过激活本体感觉&#xff0c;并通过视听反馈促进神经肌肉的训练。 得益于其技术先进和易用性&#xff0c;Riablo是骨科、运动医…

jmeter部署

一、windows环境下部署 1、安装jdk并配置jdk的环境变量 (1) 安装jdk jdk下载完成后双击安装包&#xff1a;无限点击"下一步"直到完成&#xff0c;默认路径即可。 (2) jdk安装完成后配置jdk的环境变量 找到环境变量中的系统变量&#xff1a;此电脑 --> 右键属性 …

C语言:温度转换

1.题目&#xff1a;实现摄氏度&#xff08;Celsius&#xff09;和华氏度&#xff08;Fahrenheit&#xff09;之间的转换。 输入一个华氏温度&#xff0c;输出摄氏温度&#xff0c;结果保留两位小数。 2.思路&#xff1a;&#xff08;这是固定公式&#xff0c;其中 F 是华氏度&a…

【C语言】详解结构体(下)(位段)

文章目录 前言1. 位段的含义2. 位段的声明3. 位段的内存分配&#xff08;重点&#xff09;3.1 存储方向的问题3.2 剩余空间利用的问题 4. 位段的跨平台问题5. 位段的应用6. 总结 前言 相信大部分的读者在学校或者在自学时结构体的知识时&#xff0c;可能很少会听到甚至就根本没…

STM32实战篇:按键(外部输入信号)触发中断

功能要求 将两个按键分别与引脚PA0、PA1相连接&#xff0c;通过按键按下&#xff0c;能够触发中断响应程序&#xff08;不需明确功能&#xff09;。 代码流程如下&#xff1a; 实现代码 #include "stm32f10x.h" // Device headerint main() {//开…

JUC并发编程01-基础概念

概念 进程 进程可以视为程序的一个实例&#xff0c;进程就是用来加载指令、管理内存、管理I0 线程 一个进程内可以有多个线程&#xff0c;一个线程就是一个指令流。 在Java中&#xff0c;线程作为最小调度单位&#xff0c;进程作为资源分配的最小单位&#xff0c;可以说进程…

Mysql数据库第二次作业

(1)显示所有职工的基本信息。 mysql> select * from t_worker; (2)查询所有职工所属部门的部门号&#xff0c;不显示重复的部门号。 mysql> select distinct department_id from t_worker; (3)求出所有职工的人数。 mysql> select count(1) from t_worker; (4)列…