CVPR2022人脸识别Partial FC论文及代码学习笔记

论文链接:https://openaccess.thecvf.com/content/CVPR2022/papers/An_Killing_Two_Birds_With_One_Stone_Efficient_and_Robust_Training_CVPR_2022_paper.pdf

代码链接:insightface/recognition/arcface_torch at master · deepinsight/insightface · GitHub

背景

使用基于百万规模的数据集和基于margin的softmax损失函数来学习区分性的embeddings是当前人脸识别的SOTA方法。然而,全连接层的内存和计算成本随着训练集中ID数量的增加而线性增加。此外,大规模训练数据存在类间冲突(同一个人被分成不同ID)和长尾分布的问题。

传统FC

将传统的FC层应用在大规模的数据集上时,存在以下缺陷:

1、gradient confusion under interclass conflict

WebFace42M里有很多不同类别对之间的余弦相似度大于0.4,这表明类间冲突仍然存在于这些清洗过的数据集中。直接优化的话会导致gradient confusion(同一个人的特征非常相似却要掰成两个ID)

2、centers of tail classes undergo too many passive updates

每个iteration都优化图片数量很少的id,可能会导致负优化

3、the storage and calculation of the FC layer can easily exceed current GPU capabilities

PartialFC

在训练期间仍然维护所有类别中心,但只随机采样一小部分负类别中心来计算基于margin的softmax损失,而不是在每次迭代中使用所有负类别中心。更具体地说,首先从每个GPU收集embeddings和标签,然后将组合的特征和标签分布到所有GPU。为了平衡每个GPU的内存使用和计算成本,为每个GPU设置了一个内存缓冲区(下面代码中的perm)。内存缓冲区的大小由类别总数和负类别中心的采样率决定。在每个GPU上,首先通过标签选择正类中心并放入缓冲区,然后随机选择一小部分负类中心(负类中心的数量为self.sample_rate * self.num_local)填充缓冲区的其余部分,

def sample(self, labels, index_positive):"""This functions will change the value of labelsParameters:-----------labels: torch.Tensorpassindex_positive: torch.Tensorpassoptimizer: torch.optim.Optimizerpass"""with torch.no_grad():positive = torch.unique(labels[index_positive], sorted=True).cuda()if self.num_sample - positive.size(0) >= 0:perm = torch.rand(size=[self.num_local]).cuda()perm[positive] = 2.0index = torch.topk(perm, k=self.num_sample)[1].cuda()index = index.sort()[0].cuda()else:index = positiveself.weight_index = indexlabels[index_positive] = torch.searchsorted(index, labels[index_positive])return self.weight[self.weight_index]

随后,使用选出的样本中心去与特征相乘并计算基于margin的softmax损失。

PFC在DDP框架下的流程图如下图所示,

整体代码如下,

class PartialFC_V2(torch.nn.Module):"""https://arxiv.org/abs/2203.15565A distributed sparsely updating variant of the FC layer, named Partial FC (PFC).When sample rate less than 1, in each iteration, positive class centers and a random subset ofnegative class centers are selected to compute the margin-based softmax loss, all classcenters are still maintained throughout the whole training process, but only a subset isselected and updated in each iteration... note::When sample rate equal to 1, Partial FC is equal to model parallelism(default sample rate is 1).Example:-------->>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2)>>> for img, labels in data_loader:>>>     embeddings = net(img)>>>     loss = module_pfc(embeddings, labels)>>>     loss.backward()>>>     optimizer.step()"""_version = 2def __init__(self,margin_loss: Callable,embedding_size: int,num_classes: int,sample_rate: float = 1.0,fp16: bool = False,):"""Paramenters:-----------embedding_size: intThe dimension of embedding, requirednum_classes: intTotal number of classes, requiredsample_rate: floatThe rate of negative centers participating in the calculation, default is 1.0."""super(PartialFC_V2, self).__init__()assert (distributed.is_initialized()), "must initialize distributed before create this"self.rank = distributed.get_rank()self.world_size = distributed.get_world_size()self.dist_cross_entropy = DistCrossEntropy()self.embedding_size = embedding_sizeself.sample_rate: float = sample_rateself.fp16 = fp16self.num_local: int = num_classes // self.world_size + int(self.rank < num_classes % self.world_size)self.class_start: int = num_classes // self.world_size * self.rank + min(self.rank, num_classes % self.world_size)self.num_sample: int = int(self.sample_rate * self.num_local)self.last_batch_size: int = 0self.is_updated: bool = Trueself.init_weight_update: bool = Trueself.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))# margin_lossif isinstance(margin_loss, Callable):self.margin_softmax = margin_losselse:raisedef sample(self, labels, index_positive):"""This functions will change the value of labelsParameters:-----------labels: torch.Tensorpassindex_positive: torch.Tensorpassoptimizer: torch.optim.Optimizerpass"""with torch.no_grad():positive = torch.unique(labels[index_positive], sorted=True).cuda()if self.num_sample - positive.size(0) >= 0:perm = torch.rand(size=[self.num_local]).cuda()perm[positive] = 2.0index = torch.topk(perm, k=self.num_sample)[1].cuda()index = index.sort()[0].cuda()else:index = positiveself.weight_index = indexlabels[index_positive] = torch.searchsorted(index, labels[index_positive])return self.weight[self.weight_index]def forward(self,local_embeddings: torch.Tensor,local_labels: torch.Tensor,):"""Parameters:----------local_embeddings: torch.Tensorfeature embeddings on each GPU(Rank).local_labels: torch.Tensorlabels on each GPU(Rank).Returns:-------loss: torch.Tensorpass"""local_labels.squeeze_()local_labels = local_labels.long()batch_size = local_embeddings.size(0)if self.last_batch_size == 0:self.last_batch_size = batch_sizeassert self.last_batch_size == batch_size, (f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}")_gather_embeddings = [torch.zeros((batch_size, self.embedding_size)).cuda()for _ in range(self.world_size)]_gather_labels = [torch.zeros(batch_size).long().cuda() for _ in range(self.world_size)]_list_embeddings = AllGather(local_embeddings, *_gather_embeddings)distributed.all_gather(_gather_labels, local_labels)embeddings = torch.cat(_list_embeddings)labels = torch.cat(_gather_labels)## 选出落在本进程对应的类别范围内的数据labels = labels.view(-1, 1)index_positive = (self.class_start <= labels) & (labels < self.class_start + self.num_local)## 标签不在本类别段的, 将其类别标签设为-1labels[~index_positive] = -1## 将类别ID平移到原点(因为不同进程都会初始化对应的self.weight, 若不平移回去, 则label与self.weight中的index会对应不上)labels[index_positive] -= self.class_startif self.sample_rate < 1:weight = self.sample(labels, index_positive)else:weight = self.weightwith torch.cuda.amp.autocast(self.fp16):norm_embeddings = normalize(embeddings)norm_weight_activated = normalize(weight)logits = linear(norm_embeddings, norm_weight_activated)if self.fp16:logits = logits.float()logits = logits.clamp(-1, 1)logits = self.margin_softmax(logits, labels)loss = self.dist_cross_entropy(logits, labels)return loss

实验结果

将PFC替换掉传统FC后,模型在WebFace(包括4m、12m、42m)上的性能会有所提升,

 消融实验的结果如下,

与SOTA方法的性能对比如下, 

结论与讨论

结论

作者提出了一种用于在大规模数据集上训练人脸识别模型的方法——Partial FC (PFC)。在PFC的每次迭代中,仅选择一小部分类别中心来计算基于边际的softmax损失,这样可以显著减少类间冲突的概率、尾类中心的被动更新频率以及计算需求。通过广泛的实验,作者验证了所提出的PFC的有效性、鲁棒性和高效性。

局限性

尽管在WebFace上训练的PFC模型在高质量测试集上取得了不错的结果,但在人脸分辨率较低或低光照条件下拍摄的人脸上,PFC模型的表现可能较差。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/12201.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于lidar的多目标跟踪

文章目录 基本流程编译过程注意事项基本流程 基于雷达点云的目标追踪主要包括以下几个步骤: 点云预处理: 滤除噪点和无效点(如NaN值)进行平面分割,提取地面点云对剩余的点云进行聚类,得到可能的目标点云目标检测 对聚类后的点云进行分析,判断是否为有效目标可以利用目标的尺寸…

怎么转换音频?看这3款音频转换器

随着数字媒体的发展&#xff0c;音频文件在我们的日常生活中占据了越来越重要的地位。有时候在不同的应用场景里&#xff0c;无论是音乐、语音还是其他类型的音频内容&#xff0c;我们都需要对其进行转换以满足不同的需求。 本文将为您介绍3款常用的音频转换器&#xff0c;帮助…

Springboot+mybatis-plus+dynamic-datasource+继承DynamicRoutingDataSource切换数据源

Springbootmybatis-plusdynamic-datasource继承DynamicRoutingDataSource切换数据源 背景 最近公司要求支持saas&#xff0c;实现动态切换库的操作&#xff0c;默认会加载主租户的数据源&#xff0c;其他租户数据源在使用过程中自动创建加入。 解决问题 1.通过请求中设置租…

数据可视化训练第7天(json文件读取国家人口数据,找出前10和后10)

数据 https://restcountries.com/v3.1/all&#xff1b;建议下载下来&#xff0c;并不是很大 import numpy as np import matplotlib.pyplot as plt import requests import json #由于访问url过于慢&#xff1b;将数据下载到本地是json数据 #urlhttps://restcountries.com/v3…

MATLAB蚁群算法求解带时间窗的旅行商TSPTW问题代码实例

MATLAB蚁群算法求解带时间窗的旅行商TSPTW问题代码实例 蚁群算法编程求解TSPTW问题实例&#xff1a; 在经纬度范围为(121, 43)到(123, 45)的矩形区域内&#xff0c;散布着1个商家&#xff08;编号1&#xff09;和25个顾客点&#xff08;编号为226&#xff09;&#xff0c;各个…

web入门练手案例(二)

下面是一下web入门案例和实现的代码&#xff0c;带有部分注释&#xff0c;倘若代码中有任何问题或疑问&#xff0c;欢迎留言交流~ 数字变色Logo 案例描述 “Logo”是“商标”的英文说法&#xff0c;是企业最基本的视觉识别形象&#xff0c;通过商标的推广可以让消费者了解企…

C语言(指针)2

Hi~&#xff01;这里是奋斗的小羊&#xff0c;很荣幸各位能阅读我的文章&#xff0c;诚请评论指点&#xff0c;关注收藏&#xff0c;欢迎欢迎~~ &#x1f4a5;个人主页&#xff1a;小羊在奋斗 &#x1f4a5;所属专栏&#xff1a;C语言 本系列文章为个人学习笔记&#x…

听说SOLIDWORKS科研版可以节约研发成本?

近几年来&#xff0c;政府越来越重视科研带动产业&#xff0c;绩效优良的产业技术研究院对于国家和地区的学术成果转化、技术创新、产业发展等具有不可忽视的促进和带动作用。研究院会承担众多新产业的基础研究工作&#xff0c;而常规的基础研究需要长期的积累&#xff0c;每个…

JAVA毕业设计141—基于Java+Springboot+Vue的物业管理系统(源代码+数据库)

毕设所有选题&#xff1a; https://blog.csdn.net/2303_76227485/article/details/131104075 基于JavaSpringbootVue的物业管理系统(源代码数据库)141 一、系统介绍 本项目前后端分离&#xff0c;分为管理员、员工、用户三种角色(角色权限可自行分配) 1、用户&#xff1a; …

高清模拟视频采集卡CVBS四合一信号采集设备解析

介绍一款新产品——LCC261高清视频采集与编解码一体化采集卡。这款高品质的产品拥有卓越的性能表现和丰富多样的功能特性&#xff0c;能够满足广大用户对于高清视频采集、处理以及传输的需求。 首先&#xff0c;让我们来了解一下LCC261的基本信息。它是一款基于灵卡技术研发的高…

LeetCode2095删除链表的中间节点

题目描述 给你一个链表的头节点 head 。删除 链表的 中间节点 &#xff0c;并返回修改后的链表的头节点 head 。长度为 n 链表的中间节点是从头数起第 ⌊n / 2⌋ 个节点&#xff08;下标从 0 开始&#xff09;&#xff0c;其中 ⌊x⌋ 表示小于或等于 x 的最大整数。对于 n 1、…

深入探索Android签名机制:从v1到v3的演进之旅

引言 在Android开发的世界中&#xff0c;APK的签名机制是确保应用安全性的关键环节。随着技术的不断进步&#xff0c;Android签名机制也经历了从v1到v3的演进。本文将带你深入了解Android签名机制的演变过程&#xff0c;揭示每个版本背后的技术细节&#xff0c;并探讨它们对开…

浅谈下MYSQL表设计的几条规则

作为后端开发人员&#xff0c;避免不了和数据库打交道&#xff0c;可是我们怎么能够设计出高效&#xff0c;可维护&#xff0c;可扩展的数据库设计呢&#xff0c;在这里我总结了几个点&#xff0c;供大家参考。 在写之前&#xff0c;可能需要重复下数据库设计的范式原则&#…

水雨情监测系统—实时监测水位信息

TH-SW3水雨情监测系统是一种专门用于实时监测和收集水文气象数据的自动化系统。它能够实时获取区域内降雨和水情数据&#xff0c;并将其存储到数据库中进行分析处理&#xff0c;从而为防汛指挥人员提供及时准确的信息服务。 水雨情监测系统的主要功能包括实时监测水位、流速、流…

C++类与对象基础探秘系列(二)

目录 类的6个默认成员函数 构造函数 构造函数的概念 构造函数的特性 析构函数 析构函数的概念 析构函数的特性 拷贝构造函数 拷贝构造函数的概念 拷贝构造函数的特性 赋值运算符重载 运算符重载 赋值运算符重载 const成员 const修饰类的成员函数 取地址及const取地址操作…

MySQL文档_下载

可能需要&#xff1a;MySQL下载–》更新版本–》迁移数据库到MySQL 以下都不重要【只要确定好需要安装版本&#xff0c;找到对应的版本下载&#xff0c;安装&#xff0c;设置即可】 下载、安装&#xff1a; Determine whether MySQL runs and is supported on your platform…

Debian12安装后更换为国内镜像源,切换root用户,解决用户名不在sudoers文件中此事将被报告

选择Debian作为编程开发最佳Linux的理由&#xff1a; Debian是面向程序员的最古老&#xff0c;最出色的Linux发行版之一。Debian提供了具有.deb软件包管理兼容性的超稳定发行版。Debian为程序员提供了许多最新功能。因此&#xff0c;它具有一个特殊的编程空间。Debian是开发人员…

弥合孤岛:克服构建 DevOps 文化的挑战

持续变革正在发生软件开发行业。DevOps 因其对自动化、协作和持续改进的关注而成为优化软件交付并弥合开发和运营团队之间鸿沟的重要方法。然而&#xff0c;过渡到真正的 DevOps 文化并非没有挑战。本文探讨了您在追求 DevOps 时可能面临的障碍并提供了解决方案。 01 了解 Dev…

数据结构 顺序表1

1. 何为顺序表&#xff1a; 顺序表是一种线性数据结构&#xff0c;是由一组地址连续的存储单元依次存储数据元素的结构&#xff0c;通常采用数组来实现。顺序表的特点是可以随机存取其中的任何一个元素&#xff0c;并且支持在任意位置上进行插入和删除操作。在顺序表中&#xf…

算法-卡尔曼滤波之基本数学的概念

1.均值 定义&#xff1a;均值是一组数据中所有数值的总和除以数据的数量。均值是数据的中心趋势的一种度量&#xff0c;通常用符号 xˉ 表示。 &#xff1a;对于包含 n 个数据的数据集 {&#x1d465;1,&#x1d465;2,...,&#x1d465;&#x1d45b;}&#xff0c;均值 xˉ 计…