k均值的损失函数_一种基于均值不等式的Listwise损失函数

1 前言

1.1 Learning to Rank 简介

Learning to Rank (LTR) , 也被叫做排序学习, 是搜索中的重要技术, 其目的是根据候选文档和查询语句的相关性对候选文档进行排序, 或者选取topk文档. 比如在搜索引擎中, 需要根据用户问题选取最相关的搜索结果展示到首页. 下图是搜索引擎的搜索结果

d15e17ca0bdd701179c92463a8f8a563.png

1.2 LTR算法分类

根据损失函数可把LTR分为三种: 1. Pointwise, 该类型算法将LTR任务作为回归任务来训练, 即尝试训练一个为文档和查询语句的打分器, 然后根据打分进行排序. 2. Pairwise, 该类型算法的损失函数考虑了两个候选文档, 学习目标是把相关性高的文档排在前面, triplet loss 就属于Pairwise, 它的损失函数是$$ loss = max(0, score_{neg}-score_{pos}+margin)$$, 可以看出该损失函数一次考虑两个候选文档. 3. Listwise, 该类型算法的损失函数会考虑多个候选文档, 这是本文的重点, 下面会详细介绍.

1.3 本文主要内容

本文主要介绍了本人在学习研究过程中发明的一种新的Listwise损失函数, 以及该损失函数的使用效果. 如果读者对LTR任务及其算法还不够熟悉, 建议先去学习LTR相关知识, 同时本人博客自然语言处理中的负样本挖掘 (分类与排序任务中如何选择负样本) 也和本文关系较大, 可以先进行阅读.

2 预备知识

2.1 数学符号定义

$q$代表用户搜索问题, 比如"如何成为宇航员", $D$代表候选文档集合,$d^+$代表和$q$相关的文档,$d^-$代表和$q$不相关的文档, $d^+_i$代表第$i$个和$q$相关的文档, LTR的目标就是根据$q$找到最相关的文档$d$

2.2 学习目标

本次学习目标是训练一个打分器 scorer, 它可以衡量q和d的相关性, scorer(q, d)就是相关性分数,分值越大越相关. 当前主流方法下, scorer一般选用深度神经网络模型.

2.3训练数据分类

损失函数不同, 构造训练数据的方法也会不同:

-Pointwise, 可以构造回归数据集, 相关的数据设为1, 不相关设为0.
-Pairwise, 可构造triplet类型的数据集, 形如($q,d^+, d^-$) -Listwise, 可构造这种类型的训练集: ($q,d^+1,d^+_2..., d^+_n , d^-_1, d^-_2, ..., d^-{n+m}$), 一个正例还是多个正例也会影响到损失函数的构造, 本文提出的损失函数是针对多正例多负例的情况.

3 基于均值不等式的Listwise损失函数

3.1 损失函数推导过程

在上一小结我们可以知道,训练集是如下形式 ($q,d^+1,d^+_2..., d^+_n , d^-_1, d^-_2, ..., d^-{n+m}$), 对于一个q, 有m个相关的文档和n个不相关的文档, 那么我们一共可以获取m+n个分值:$(score_1,score_2,...,score_n,...,score_{n+m})$, 我们希望打分器对相关文档打分趋近于正无穷, 对不相关文档打分趋近于负无穷.

对m+n个分值做一个softmax得到$p_1,p_2,...,p_n,...,p_{n+m}$, 此时$p_i$可以看作是第i个候选文档与q相关的概率, 显然我们希望$p_1,p_2,...,p_m$越大越好, $p_{n+1},...,p_{m+n}$越小越好, 即趋近于0. 因此我们暂时的优化目标是$sum_{i=1}^{n}{p_i} rightarrow 1$.

但是这个优化目标是不合理的, 假设$p_1=1$, 其他值全为0, 虽然满足了上面的要求, 但这并不是我们想要的. 因为我们不仅希望$sum_{i=1}^{n}{p_i} rightarrow 1$, 还希望相关候选文档的每一个p值都要足够大, 即我们希望m个候选文档都与q相关的概率是最大的, 所以我们真正的优化目标是: $$max(prod_{i=1}^{n}{p_i} ) , sum_{i=1}^{n}{p_i} = 1$$

当前情况下, 损失函数已经可以通过代码实现了, 但是我们还可以做一些化简工作, $prod_{i=1}^{n}{p_i}$是存在最大值的, 根据均值不等式可得: $$prod_{i=1}^{n}{p_i} leq (frac{sum_{i=1}^{n}{p_i}}{n})^n$$

对两边取对数: $$sum_{i=1}^{n}{log(p_i)} leq -nlog(n)$$

这样是不是感觉清爽多了, 然后我们把它转换成损失函数的形式: $$ loss = -nlog(n) - sum_{i=1}^{n}{log(p_i)}$$

所以我们的训练目标就是$min{(loss)}$

3.2 使用pytorch实现该损失函数

在获取到最终的损失函数后, 我们还需要用代码来实现, 实现代码如下:

# A simple example for my listwise loss function
# Assuming that n=3, m=4
# In[1]
# scores
scores = torch.tensor([[3,4.3,5.3,0.5,0.25,0.25,1]])
print(scores)
print(scores.shape)
'''
tensor([[0.3000, 0.3000, 0.3000, 0.0250, 0.0250, 0.0250, 0.0250]])
torch.Size([1, 7])
'''
# In[2]
# log softmax
log_prob = torch.nn.functional.log_softmax(scores,dim=1)
print(log_prob)
'''
tensor([[-2.7073, -1.4073, -0.4073, -5.2073, -5.4573, -5.4573, -4.7073]])
'''
# In[3]
# compute loss
n = 3.
mask = torch.tensor([[1,1,1,0,0,0,0]]) # number of 1 is n
loss = -1*n*torch.log(torch.tensor([[n]])) - torch.sum(log_prob*mask,dim=1,keepdim=True)
print(loss)
loss = loss.mean()
print(loss)
'''
tensor([[1.2261]])
tensor(1.2261)
'''

该示例代码仅展现了batch_size为1的情况, 在batch_size大于1时, 每一条数据都有不同的m和n, 为了能一起送入模型计算分值, 需要灵活的使用mask. 本人在实际使用该损失函数时,一共使用了两种mask, 分别mask每条数据所有候选文档和每条数据的相关文档, 供大家参考使用.

3.3 效果评估和使用经验

由于评测数据使用的是内部数据, 代码和数据都无法公开, 因此只能对使用效果做简单总结: 1. 效果优于PointwisePairwise, 但差距不是特别大 2. 相比Pairwise收敛速度极快, 训练一轮基本就可以达到最佳效果

下面是个人使用经验: 1. 该损失函数比较占用显存, 实际的batch_size是batch_size*(m+n), 建议显存在12G以上 2. 负例数量越多,效果越好, 收敛也越快 3. 用pytorch实现log_softmax时, 不要自己实现, 直接使用torch中的log_softmax函数, 它的效率更高些. 4. 只有一个正例, 还可以考虑转为分类问题,使用交叉熵做优化, 效果同样较好

### 4 总结 该损失函数还是比较简单的, 只需要简单的数学知识就可以自行推导, 在实际使用中也取得了较好的效果, 希望也能够帮助到大家. 如果大家有更好的做法欢迎告诉我.

文章可以转载, 但请注明出处:

  • 本人简书社区主页
  • 本人博客园社区主页
  • 本人知乎主页
  • 本人Medium社区主页

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/433213.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HBase 集群搭建

文章目录 安装前准备兼容性官方网址 集群搭建搭建 Hadoop 集群搭建 Zookeeper 集群解压缩安装配置文件高可用配置分发 HBase 文件 服务的启停启动顺序停止顺序 验证进程查看 Web 端页面 安装前准备 兼容性 1)与 Zookeeper 的兼容性问题,越新越好&#…

在哪个Linux发行版上运行python,怎么在linux上运行python

Linux默认是已经安装好了Python程序目前来说,大多数的Linux发行版是安装了两个版本的Python程序一个是Python 2.x一个是Python 3.x一些系统自带的程序文件需要Python 2的支持,另外Python 3又是大势所趋所以,我们最好不要动系统的Python版本需…

职场上个人的核心技术_职场上,这3种人表面老实,实际却是个“高手”,要远离...

职场上,这3种人表面老实,实际却是个“高手”,要远离!在职场生活中,每一步都需要走好,因为你不慎走错了一步也就可能满盘皆输。而公司里面也有一种比较特殊的情况,也就是有这么3种类型的人&#…

linux编程参数列表,Linux编程 14 文件权限(用户列表passwd,用户控制shadow,useradd模板与useradd命令参数介绍)...

一. 概述linux安全系统的核心是用户账户。 创建用户时会分配用户ID(UID)。 UID是唯一的,但在登录系统时不是用UID,而是用登录名。在讲文件权限之之前,先了解下linux是怎样处理用户账户的。以及用户账户需要的文件和工具,这样处理文…

GitHub托管BootStrap资源汇总(持续更新中…)

Twitter BootStrap已经火过大江南北,对于无法依赖美工的程序员来说,这一成熟前卫的前端框架简直就一神器,轻轻松松地实现出专业的UI效果。GitHub上相关的的开源项目更是层出不穷,在此整理列举一些感觉不错的组件或增强实现&#x…

aix linux运维,运维老司机分享的八个AIX日常运维经验及案例

原文来自微信公众号:AIX专家俱乐部【经验分享】在AIX启动时,打开debug模式经常遇到aix无法启动,但又不知道pending在哪,因此打开启动过程的debug模式,对于诊断问题有很大的帮帮助。下面是打开debug的方法:打…

php 区块链算法_PoW/BFT等5种主流区块链共识算法的开源代码实现

共识算法是实现自主产权区块链的必不可少的关键环节,本文列出社区中相对成熟的区块链共识算法开源实现,包括BFT共识、Raft共识、Paxos共识、PoW共识等,可供希望开发自主产权区块链的团队参考学习。相关推荐:区块链开发系列教程1、…

[每日一题] 11gOCP 1z0-052 :2013-09-1 RMAN-- repair failure........................................A20...

转载请注明出处:http://blog.csdn.net/guoyjoe/article/details/10859315 正确答案:D 一、模拟上题的错误: 1、删除4号文件 [oraclemydb ~]$ cd /u01/app/oracle/oradata/ocm/ [oraclemydb ocm]$ rm -rf users01.dbf2、…

kafka集群 kubernetes_为什么 Kubernetes 如此受欢迎?

点击上方蓝色“火丁笔记”关注我们,设个星标,每天学习全栈知识在撰写本文时,Kubernetes 已有 6 年历史[1]了,在过去的两年中,它的流行度不断提高,一直是最受欢迎的平台之一[2]。今年,它成为最受…

android 动画 返回,Android“菜单图标变返回”动画

此例用到SVG动画,其中涉及三个XML文件,分别为:Vector矢量图,objectAnimator动画,以及一个animated-vector文件将前两个文件联合起来。1.在drawable文件夹下新建vector文件描述矢量图android:height"200dp"an…

全志A10 Bootload加载过程分析

A10的启动过程大概可分为5步:BootRom,SPL,Uboot,Kernel,RootFileSystem。本文只关注镜像的加载过程,分析RootRom->SPL->Uboot的启动流程。系统上电后,ARM处理器在复位时从地址0x000000开始…

android老 电池,为什么安卓手机不会因为电池的老化而降频呢?

前段时间,苹果手机的降频事件也是闹的沸沸扬扬,库克也为此进行了公开道歉,各位的吃瓜群众也是看的不亦乐乎,于是,也有不少的小伙伴会问:“为什么安卓手机不会因为电池的老化而降频?”今天&#…

android 5.0.1 libdvm.so,Android逆向进阶—— 脱壳的奥义(基ART模式下的dump)

本文作者:i春秋作家HAI_ZHU000 前言市面上的资料大多都是基于Dalvik模式的dump,所以这此准备搞一个ART模式下的dump。Dalvik模式是Android 4.4及其以下采用的模式,之后到了Android 5.0 之后就是ART模式,关于这两个模式的详细内容&…

android+3.0新加的动画,Android动画片

使用Android两年多了,工作中的动画也动能应付,自认为Android中的动画自己也能用个八九不离十,结果我在学习[Periscope点赞效果](http://www.jianshu.com/p/03fdcfd3ae9c)的时候发现动画的这些高级功能我从没用过、也没见过,静下来…

在线打开html文件,html是什么文件?html文件怎么打开?

html是什么?html即超文本标记语言,现在大多网页都是html的格式。而所谓的html文件是一种超文本文件,其中超文本可以是图片或音乐等非文字元素,使用很广泛。但是很多用户都不太明白html是什么文件?也不清楚html文件要如…

gsoap使用心得! (win32)

最近换了个工作环境,现在在大望路这边上班,呵,刚上班接到的任务就是熟悉gsoap!废话少说,现在开始gSoap学习!gSOAP是一个夸平台的,用于开发Web Service服务端和客户端的工具,在Window…

html怎么置顶导航栏,css怎么实现滚动页面导航栏固定在顶部

css怎么实现滚动页面导航栏固定在顶部(吸顶效果)功能:当网页向下滚动时,导航栏一直在固定在顶部一、css设置这里主要用到css中position中的relative与fixed;其中relative是生成相对定位的元素,相对于其正常位置进行定位。fixed是生…

numpy读取csv_Numpy——IO操作与数据处理

一、问题?大多数数据并不是我们自己构造的,存在文件当中。我们需要工具去获取,但是Numpy其实并不适合去读取处理数据,这里我们了解相关API,以及Numpy不方便的地方即可。二、Numpy读取genfromtxt(fname[, dtype, commen…

android 中radiogroup滑动切换,巧妙实现缺角radiogroup控制多个fragment切换和滑动

在android开发中,用一个radiogroup控制多个fragment切换是十分常见的需求。但是如果fragment是一个ListView,如何保证滑动的时候通过缺角可以看到下面的listview是一个难点。直接上图:(1)完美效果(2)较差效果另外,不妨假设缺角的高度是5dp&am…

荣耀智慧屏评测 鸿蒙OS加持,荣耀智慧屏评测:鸿蒙OS加持 面向未来的超智能电视...

原标题:荣耀智慧屏评测:鸿蒙OS加持 面向未来的超智能电视 来源:TechWeb.com.cn当华为选择在今年公布鸿蒙OS系统后,很多人都在期待它的庐山真面目。无论是以后非常时期不再受制于人,或是循序渐进将系统过渡给自家设备&a…