Center Loss 和 ArcFace Loss 笔记

一、Center Loss

1. 定义

Center Loss 旨在最小化类内特征的离散程度,通过约束样本特征与其类别中心之间的距离,提高类内特征的聚合性。

2. 公式

对于样本 xi​ 和其类别yi​,Center Loss 的公式为:

  • xi​: 当前样本的特征向量(通常来自网络的最后一层)。
  • Cyi: 类别 yi​ 的特征中心。
  • m: 样本数量。

3. 作用

  • 减小类内样本的特征分布范围。
  • 提高分类模型对相似类别样本的区分能力。

4. 实现

import torch
import torch.nn as nnclass CenterLoss(nn.Module):def __init__(self, num_classes, feat_dim, weight=1.0):""":param num_classes: 类别数量:param feat_dim: 特征向量维度:param weight: 损失的权重"""super(CenterLoss, self).__init__()self.weight = weightself.centers = nn.Parameter(torch.randn(num_classes, feat_dim))  # 初始化类别中心def forward(self, features, labels):""":param features: 网络输出的特征向量 (batch_size, feat_dim):param labels: 样本对应的类别标签 (batch_size,)"""centers = self.centers[labels]  # 获取对应标签的中心loss = torch.sum((features - centers) ** 2, dim=1).mean()  # 欧几里得距离平方和return self.weight * loss

5. 结合 Cross-Entropy Loss

Center Loss 与交叉熵损失结合,联合优化网络:

center_loss = CenterLoss(num_classes=10, feat_dim=512)
cross_entropy_loss = nn.CrossEntropyLoss()# 训练时
features, logits = model(input_data)
loss_ce = cross_entropy_loss(logits, labels)
loss_center = center_loss(features, labels)total_loss = loss_ce + 0.1 * loss_center  # 合并损失

二、ArcFace Loss

1. 定义

ArcFace Loss 是基于角度的损失函数,用于增强特征的判别性。通过在角度空间引入额外的边际约束,强迫同类样本之间更加接近,而不同类样本之间更加远离。

2. 公式

ArcFace Loss 的公式为:

  • θ: 特征和分类权重之间的角度。
  • m: 边际(margin)。

最终损失使用交叉熵计算:

  • s: 缩放因子,用于平衡模型的学习难度。

3. 作用

  • 强化特征的角度判别能力,使得分类更加鲁棒。
  • 在人脸识别任务中,显著提高模型的性能。

4. 实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import mathclass ArcFaceLoss(nn.Module):def __init__(self, in_features, out_features, s=30.0, m=0.50):""":param in_features: 特征向量维度:param out_features: 类别数量:param s: 缩放因子:param m: 边际约束"""super(ArcFaceLoss, self).__init__()self.s = sself.m = mself.weight = nn.Parameter(torch.randn(out_features, in_features))  # 分类权重def forward(self, embeddings, labels):# Normalize embeddings and weightembeddings = F.normalize(embeddings, p=2, dim=1)weight = F.normalize(self.weight, p=2, dim=1)# Cosine similaritycosine = F.linear(embeddings, weight)# Add marginphi = cosine - self.mone_hot = torch.zeros_like(cosine)one_hot.scatter_(1, labels.view(-1, 1), 1)cosine_with_margin = one_hot * phi + (1 - one_hot) * cosine# Scalelogits = self.s * cosine_with_marginloss = F.cross_entropy(logits, labels)return loss

解释:

        ArcFaceLoss在最后一层网络,输入是上一层的输出特征值x,初始化当前层的w权重。

cos(角度)=w×x/|w|×|x|,由于ArcLoss会对w和x进行归一化到和为1的概率值。所以|w|×|x|=1。则推导出cos(角度)=w×x,那么真实标签位置给角度+m则让角度变大了,cos值变小。w×x变小,输出的预测为真实标签的概率变低。让模型更难训练,那么在一遍又一遍的模型读取图片提取特征的过程中,会让模型逐渐的将真实标签位置的w×x值变大==cos(角度+m)变大,那么角度就会变的更小。只有角度更小的时候,cos余弦相似度才会大,从而让模型认为这个类别是真实的类别。

所以arcloss主要加入了一个m,增大角度,让模型更难训练,让模型把角度变的更小,从而让w的值调整的更加让类间距增大。

简而言之:加入m的值,让真实类和其他类相似度更高,让模型更难训练。迫使模型为了让真实和其他类相似度更低,而让w权重的值更合理。

三、对比分析

四、如何选择

  • 如果任务需要提升类内特征的聚合性(如样本分布紧密性),优先考虑 Center Loss
  • 如果任务需要增强类间特征的判别能力(如人脸识别),优先选择 ArcFace Loss
  • 可以同时使用两者,将特征聚合和判别性结合,提高模型的鲁棒性。

五、推荐学习资源

  1. ArcFace: Additive Angular Margin Loss for Deep Face Recognition (论文)
  2. Center Loss: A Discriminative Feature Learning Approach for Deep Face Recognition (论文)
  3. PyTorch 官方文档

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/67060.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AI在软件工程教育中的应用与前景展望

引言 随着科技的快速发展,软件工程教育面临着前所未有的挑战与机遇。传统的教学模式逐渐无法满足快速变化的行业需求,学生们需要更多的实践经验和个性化的学习方式。而在这样的背景下,人工智能(AI)作为一项创新技术&a…

【微服务】面试 7、幂等性

幂等性概念及场景 概念:多次调用方法或接口不改变业务状态,重复调用结果与单次调用一致。例如在京东下单,多次点击提交订单只能成功一次。场景:包括用户重复点击、网络波动导致多次请求、mq 消息重复消费、代码中设置失败或超时重…

Redis 为什么要引入 Pipeline机制?

在 Redis 中有一种 Pipeline(管道)机制,其目的是提高数据传输效率和吞吐量。那么,Pipeline是如何工作的?它又是如何提高性能的?Pipeline有什么优缺点?我们该如何使用 Pipeline? 1、…

游戏引擎学习第78天

Blackboard: Position ! Collision “网格” 昨天想到的一个点,可能本来就应该想到,但有时反而不立即思考这些问题也能带来一些好处。节目是周期性的,每天不需要全程关注,通常只是在晚上思考,因此有时我们可能不能那么…

使用 C# 制作图像的特写窗口

许多网站都会显示一个特写窗口,其中显示放大的图像部分,以便您可以看到更多细节。您在主图像上移动鼠标,它会在单独的图片中显示特写。此示例执行的操作类似。(示例使用的一些数学运算非常棘手,因此您可能需要仔细查看…

Python学习(三)基础入门(数据类型、变量、条件判断、模式匹配、循环)

目录 一、第一个 Python 程序1.1 命令行模式、Python 交互模式1.2 Python的执行方式1.3 SyntaxError 语法错误1.4 输入和输出 二、Python 基础2.1 Python 语法2.2 数据类型1)Number 数字2)String 字符串3)List 列表4)Tuple 元组5&…

【MySQL】SQL菜鸟教程(一)

1.常见命令 1.1 总览 命令作用SELECT从数据库中提取数据UPDATE更新数据库中的数据DELETE从数据库中删除数据INSERT INTO向数据库中插入新数据CREATE DATABASE创建新数据库ALTER DATABASE修改数据库CREATE TABLE创建新表ALTER TABLE变更数据表DROP TABLE删除表CREATE INDEX创建…

力扣257(关于回溯算法)二叉树的所有路径

257. 二叉树的所有路径 一.问题描述 已解答 简单 相关标签 相关企业 给你一个二叉树的根节点 root ,按 任意顺序 ,返回所有从根节点到叶子节点的路径。 叶子节点 是指没有子节点的节点。 示例 1: 输入:root [1,2,3,null,5…

Redis有哪些常用应用场景?

大家好,我是锋哥。今天分享关于【Redis有哪些常用应用场景?】面试题。希望对大家有帮助; Redis有哪些常用应用场景? 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 Redis 是一个高性能的开源键值对(Key-Va…

【2024年华为OD机试】(A卷,100分)- 处理器问题(Java JS PythonC/C++)

一、问题描述 题目描述 某公司研发了一款高性能AI处理器。每台物理设备具备8颗AI处理器,编号分别为0、1、2、3、4、5、6、7。 编号0-3的处理器处于同一个链路中,编号4-7的处理器处于另外一个链路中,不通链路中的处理器不能通信。 如下图所…

设计模式-结构型-组合模式

1. 什么是组合模式? 组合模式(Composite Pattern) 是一种结构型设计模式,它允许将对象组合成树形结构来表示“部分-整体”的层次结构。组合模式使得客户端对单个对象和组合对象的使用具有一致性。换句话说,组合模式允…

HQChart使用教程30-K线图如何对接第3方数据44-DRAWPIE数据结构

HQChart使用教程30-K线图如何对接第3方数据44-DRAWPIE数据结构 效果图DRAWPIEHQChart代码地址后台数据对接说明示例数据数据结构说明效果图 DRAWPIE DRAWPIE是hqchart插件独有的绘制饼图函数,可以通过麦语法脚本来绘制一个简单的饼图数据。 饼图显示的位置固定在右上角。 下…

Proser:升级为简易的通讯调试助手软件

我本来打算将Proser定位为一个直观的协议编辑、发送端模拟软件,像下面这样。 但是按耐不住升级的心理,硬生生的把即时收发整合了进去,就像这样! 不过,目前针对即时收发还没有发送历史、批量发送等功能,…

PyTorch环境配置常见报错的解决办法

目标 小白在最基础的环境配置里一般都会出现许多问题。 这里把一些常见的问题分享出来。希望可以节省大家一些时间。 最终目标是可以在cmd虚拟环境里进入jupyter notebook,new的时候有对应的环境,并且可以跑通所有的import code。 第一步:…

【Linux系列】Curl 参数详解与实践应用

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

Web基础-分层解耦

思考:什么是耦合?什么是内聚?软件设计原则是什么? 耦合:衡量软件中各个层 / 各个模块的依赖关联程度。 内聚:软件中各个功能模块内部的功能联系。 软件设计原则:高内聚低耦合。 那我们该如何实现…

算法题(33):长度最小的子数组

审题: 需要我们找到满足元素之和大于等于target的最小子数组的元素个数,并返回 思路: 核心:子数组共有n种起点,nums数组的每个元素都可以充当子数组的首元素,我们只需要先确定子数组的首元素,然后往后查找满…

网络数据通信基本流程

1.基本概念 网络通信就是发送数据、接收数据、处理数据的过程,发送数据时要读数据进行处理(封装),接收数据时也要对数据进行处理(分用), 1)封装 对数据进行加工处理,如…

科创驱动 | 华望系统科技荣膺西湖区年度前沿创新新锐企业

2025年1月3日,由中共西湖区党委、西湖区人民政府主办的“新年第一会”—西湖区科技创新大会在杭州隆重举行。大会现场揭晓了西湖区年度科技创新团队与项目,并发布了“2024西湖区科技十大事件”与“西湖区五大年度科技榜单”。杭州华望系统科技有限公司榜…

Java Web开发基础:HTML的深度解析与应用

文章目录 前言🌍一.B/S 软件开发架构简述🌍二.HTML 介绍❄️2.1 官方文档❄️2.2 网页的组成❄️2.3 HTML 是什么❄️2.4html基本结构 🌍三.HTML标签1.html 的标签/元素-说明2. html 标签注意事项和细节3.font 字体标签4.标题标签5.超链接标签…