【多标签分类问题的样本挖掘】Pytorch中的TripletMarginLoss的样本挖掘

多数度量学习的代码都需要进行挖掘,样本挖掘过程就是把一个Batch中的所有样本,根据标签来划分成正样本和负样本
这里我们只讨论多标签分类问题,标签是onehot编码,如果是单标签分类任务可以去看pytorch_metric_learning这个库有实现好的挖掘方法
比如输入样本为[Batch,Embedding],对应的标签是[Batch,Class]
对这些样本进行挖掘后得到以下三部分:

  1. Anchor :锚点样本,其实就是和输入的Batch一模一样,
  2. Positive Sample : 挖掘的正正样本
  3. Negtive Sample : 挖掘的负样本
import torch
import torch.nn as nn 
import torchvision# 损失函数
class HibCriterion(nn.Module):def __init__(self):super().__init__()def forward(self, z_samples, alpha, beta, indices_tuple):n_samples = z_samples.shape[1]if len(indices_tuple) == 3:a, p, n = indices_tupleap = an = aelif len(indices_tuple) == 4:ap, p, an, n = indices_tuplealpha = torch.nn.functional.softplus(alpha)loss = 0for i in range(n_samples):z_i = z_samples[:, i, :]for j in range(n_samples):z_j = z_samples[:, j, :]prob_pos = torch.sigmoid(- alpha * torch.sum((z_i[ap] - z_j[p])**2, dim=1) + beta) + 1e-6prob_neg = torch.sigmoid(- alpha * torch.sum((z_i[an] - z_j[n])**2, dim=1) + beta) + 1e-6# maximize the probability of positive pairs and minimize the probability of negative pairsloss += -torch.log(prob_pos) - torch.log(1 - prob_neg)loss = loss / (n_samples ** 2)return loss.mean()def get_matches_and_diffs(labels):matches = (labels.float() @ labels.float().T).byte()diffs = matches ^ 1 # 异或运算得到负标签的矩阵return matches, diffsdef get_all_triplets_indices_vectorized_method(all_matches, all_diffs):"""Args:all_matches (torch.Tensor): 相同标签all_diffs (torch.Tensor): 不相同标签Processing : all_matches.unsqueeze(2) -> [Batch,Batch,1]all_diffs.unsqeeeze(1) -> [Batch,1,Batch] Returns:torch.Tensor: _description_"""triplets = all_matches.unsqueeze(2) * all_diffs.unsqueeze(1)return torch.where(triplets)class TripletMinner(nn.Module):def __init__(self, *args, **kwargs) -> None:super().__init__(*args, **kwargs)self.sim_mat = get_matches_and_diffsself.selctor = get_all_triplets_indices_vectorized_methoddef forward(self,labels):a , b = self.sim_mat(labels)c = self.selctor(a,b)return c

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/15721.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

学习Uni-app开发小程序Day18

昨天学习了使用轮播显示图片和文字,轮播方式纵向和横向。今天使用扩展组件和scroll-view显示图片,使用scroll-view的grid方式、插槽slot、自定义组件、磨砂背景定位布局做专题组件 这就是需要做成的效果,下面将一步一步的完成。 首先&#x…

如何高效创建与配置工程环境:零基础入门

新书上架~👇全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我👆,收藏下次不迷路┗|`O′|┛ 嗷~~ 目录 一、工程环境的搭建与准备 二、配置虚拟环境与选择解释器 三、编写代码与自动添加多行注释 …

git describe --tags报错 fatal: No names found, cannot describe anything.

文章目录 git describe --tags报错 fatal: No names found, cannot describe anything. git describe --tags报错 fatal: No names found, cannot describe anything. 问题描述: git describe --tags fatal: No names found, cannot describe anything.原因分析&a…

SpringMVC笔记

一、SpringMVC 简介 1.1 什么是 MVC MVC 是一种软件架构的思想,将软件按照模型、视图、控制器来划分 1.M:Model 模型层,指工程中的 JavaBean ,作用是处理数据 JavaBean 分为两类 实体类Bean:专门存储业务数据的…

C++vector的简单模拟实现

文章目录 目录 文章目录 前言 一、vector使用时的注意事项 1.typedef的类型 2.vector不是string 3.vector 4.算法sort 二、vector的实现 1.通过源码进行猜测vector的结构 2.初步vector的构建 2.1 成员变量 2.2成员函数 2.2.1尾插和扩容 2.2.2operator[] 2.2.3 迭代器 2…

云存储与云计算详解

1. 云存储与云计算概述 1.1 云存储 云存储(Cloud Storage)是指通过互联网将数据存储在远程服务器上,用户可以随时随地访问和管理这些数据。云存储的优点包括高可扩展性、灵活性和成本效益。 1.2 云计算 云计算(Cloud Computin…

前端 控制台提示invalid date

如果你遇到了 "Invalid Date" 的错误,这通常意味着传递给 Date 构造函数的字符串或数值无法被解析为一个有效的日期。对于时间戳来说,确保它是一个有效的数字(表示自1970年1月1日00:00:00 UTC以来的毫秒数)。 以下是一…

Java如何设计一个功能

流程说明:实现一组功能的步骤 1,充分了解需求,包括所有的细节,需要知道要做一个什么样的功能。 2,设计实体/表 正向工程:设计实体、映射文件 --> 建表 反向工程:设计表 --> 映射文件、实体 设计实体类型分析步骤: 1)功能模块有几个实体…

【Apache Doris】BE宕机问题排查指南

【Apache Doris】BE宕机问题排查指南 背景BE宕机分类如何判断是BE进程是Crash还是OOMBE Crash 后如何排查BE OOM 后如何分析Cache 没及时释放导致BE OOM(2.0.3-rc04) 关于社区 作者|李渊渊 背景 在实际线上生产环境中,大家可能遇…

校园网拨号上网环境下多开虚拟机,实现宿主机与虚拟机互通,并访问外部网络

校园网某些登录客户端只允许同一时间一台设备登录,因此必须使用NAT模式共享宿主机的真实IP,相当于访问外网时只使用宿主机IP,此方式通过虚拟网卡与物理网卡之间的数据转发实现访问外网及互通 经验证,将centos的物理地址与主机物理…

有什么好用的语音翻译软件推荐?亲测实用的语音翻译工具来了

嘿,大家好!你们有没有想过,现在世界这么“小”,我们跟不同国家的人打交道的机会越来越多了。 但是呢,语言不通真是个大问题。别担心,现在有个超棒的解决方案——语音翻译技术!这玩意儿能实时把…

Spring Cloud学习笔记(Nacos):配置中心基础和代码样例

这是本人学习的总结,主要学习资料如下 - 马士兵教育 1、Overview2、样例2.1、Dependency2.2、配置文件的定位2.3、bootstrap.yml2.4、配置中心新增配置2.5、验证 1、Overview 配置中心用于管理配置项和配置文件,比如平时写的application.yml就是配置文件…

Python 遍历字典的方法,你都掌握了吗

Python中的字典是一种非常灵活的数据结构,它允许通过键来存储和访问值。在处理字典时,经常需要遍历字典中的元素,以下是几种常见的遍历字典的方法。 1. 使用 for 循环直接遍历字典的键 字典的键是唯一的,可以直接通过 for 循环来…

【Spring Security + OAuth2】OAuth2

Spring Security OAuth2 第一章 Spring Security 快速入门 第二章 Spring Security 自定义配置 第三章 Spring Security 前后端分离配置 第四章 Spring Security 身份认证 第五章 Spring Security 授权 第六章 OAuth2 文章目录 Spring Security OAuth21、OAuth2简介1.1、OAu…

call、apply和bind

call、apply和bind都是JavaScript中函数对象的方法,用于改变函数的this值。 call:call方法接收一个对象和一系列参数,并立即调用函数,将this值设置为提供的对象。例如: function greet(greeting, punctuation) {cons…

Linux驱动开发笔记(二) 基于字符设备驱动的I/O操作

文章目录 前言一、设备驱动的作用与本质1. 驱动的作用2. 有无操作系统的区别 二、内存管理单元MMU三、相关函数1. ioremap( )2. iounmap( )3. class_create( )4. class_destroy( ) 四、GPIO的基本知识1. GPIO的寄存器进行读写操作流程2. 引脚复用2. 定义GPIO寄存器物理地址 五、…

【2024最新华为OD-C卷试题汇总】传递悄悄话的最长时间(100分) - 三语言AC题解(Python/Java/Cpp)

🍭 大家好这里是清隆学长 ,一枚热爱算法的程序员 ✨ 本系列打算持续跟新华为OD-C卷的三语言AC题解 💻 ACM银牌🥈| 多次AK大厂笔试 | 编程一对一辅导 👏 感谢大家的订阅➕ 和 喜欢💗 文章目录 前…

东哥一句兄弟,你还当真了?

关注卢松松,会经常给你分享一些我的经验和观点。 你还真把自己当刘强东兄弟了?谁跟你是兄弟了?你在国外的房子又不给我住,你出去旅游也不带上我!都成人年了,东哥一句客套话,别当真! 今天,东哥在高管会上直言&…

mysql内存结构

一:逻辑存储结构:表空间->段->区->页->行、 表空间:一个mysql实例对应多个表空间,用于存储记录,索引等数据。 段:分为数据段,索引段,回滚段。innoDB是索引组织表&…