【域适应】深度域适应常用的距离度量函数实现

关于

深度域适应中,有一类方法是实现目标域和源域的特征对齐,特征对齐的衡量函数主要包括MMD,MK-MMD,A-distance,CORAL loss, Wasserstein distance等等。本文总结了常用的特征变换对齐的函数定义。

工具

Python

方法实现

MMD 多核函数定义
# Compute MMD (maximum mean discrepancy) using numpy and scikit-learn.import numpy as np
from sklearn import metricsdef mmd_linear(X, Y):"""MMD using linear kernel (i.e., k(x,y) = <x,y>)Note that this is not the original linear MMD, only the reformulated and faster version.The original version is:def mmd_linear(X, Y):XX = np.dot(X, X.T)YY = np.dot(Y, Y.T)XY = np.dot(X, Y.T)return XX.mean() + YY.mean() - 2 * XY.mean()Arguments:X {[n_sample1, dim]} -- [X matrix]Y {[n_sample2, dim]} -- [Y matrix]Returns:[scalar] -- [MMD value]"""delta = X.mean(0) - Y.mean(0)return delta.dot(delta.T)def mmd_rbf(X, Y, gamma=1.0):"""MMD using rbf (gaussian) kernel (i.e., k(x,y) = exp(-gamma * ||x-y||^2 / 2))Arguments:X {[n_sample1, dim]} -- [X matrix]Y {[n_sample2, dim]} -- [Y matrix]Keyword Arguments:gamma {float} -- [kernel parameter] (default: {1.0})Returns:[scalar] -- [MMD value]"""XX = metrics.pairwise.rbf_kernel(X, X, gamma)YY = metrics.pairwise.rbf_kernel(Y, Y, gamma)XY = metrics.pairwise.rbf_kernel(X, Y, gamma)return XX.mean() + YY.mean() - 2 * XY.mean()def mmd_poly(X, Y, degree=2, gamma=1, coef0=0):"""MMD using polynomial kernel (i.e., k(x,y) = (gamma <X, Y> + coef0)^degree)Arguments:X {[n_sample1, dim]} -- [X matrix]Y {[n_sample2, dim]} -- [Y matrix]Keyword Arguments:degree {int} -- [degree] (default: {2})gamma {int} -- [gamma] (default: {1})coef0 {int} -- [constant item] (default: {0})Returns:[scalar] -- [MMD value]"""XX = metrics.pairwise.polynomial_kernel(X, X, degree, gamma, coef0)YY = metrics.pairwise.polynomial_kernel(Y, Y, degree, gamma, coef0)XY = metrics.pairwise.polynomial_kernel(X, Y, degree, gamma, coef0)return XX.mean() + YY.mean() - 2 * XY.mean()if __name__ == '__main__':a = np.arange(1, 10).reshape(3, 3)b = [[7, 6, 5], [4, 3, 2], [1, 1, 8], [0, 2, 5]]b = np.array(b)print(a)print(b)print(mmd_linear(a, b))  # 6.0print(mmd_rbf(a, b))  # 0.5822print(mmd_poly(a, b))  # 2436.5
 A-distance 函数定义
# Compute A-distance using numpy and sklearn
# Reference: Analysis of representations in domain adaptation, NIPS-07.import numpy as np
from sklearn import svmdef proxy_a_distance(source_X, target_X, verbose=False):"""Compute the Proxy-A-Distance of a source/target representation"""nb_source = np.shape(source_X)[0]nb_target = np.shape(target_X)[0]if verbose:print('PAD on', (nb_source, nb_target), 'examples')C_list = np.logspace(-5, 4, 10)half_source, half_target = int(nb_source/2), int(nb_target/2)train_X = np.vstack((source_X[0:half_source, :], target_X[0:half_target, :]))train_Y = np.hstack((np.zeros(half_source, dtype=int), np.ones(half_target, dtype=int)))test_X = np.vstack((source_X[half_source:, :], target_X[half_target:, :]))test_Y = np.hstack((np.zeros(nb_source - half_source, dtype=int), np.ones(nb_target - half_target, dtype=int)))best_risk = 1.0for C in C_list:clf = svm.SVC(C=C, kernel='linear', verbose=False)clf.fit(train_X, train_Y)train_risk = np.mean(clf.predict(train_X) != train_Y)test_risk = np.mean(clf.predict(test_X) != test_Y)if verbose:print('[ PAD C = %f ] train risk: %f  test risk: %f' % (C, train_risk, test_risk))if test_risk > .5:test_risk = 1. - test_riskbest_risk = min(best_risk, test_risk)return 2 * (1. - 2 * best_risk)
 CORAL loss函数定义
# Compute CORAL loss using pytorch
# Reference: DCORAL: Correlation Alignment for Deep Domain Adaptation, ECCV-16.import torchdef CORAL_loss(source, target):d = source.data.shape[1]ns, nt = source.data.shape[0], target.data.shape[0]# source covariancexm = torch.mean(source, 0, keepdim=True) - sourcexc = xm.t() @ xm / (ns - 1)# target covariancexmt = torch.mean(target, 0, keepdim=True) - targetxct = xmt.t() @ xmt / (nt - 1)# frobenius norm between source and targetloss = torch.mul((xc - xct), (xc - xct))loss = torch.sum(loss) / (4*d*d)return loss# Another implementation:
# Two implementations are the same. Just different formulation format.
# def CORAL(source, target):
#     d = source.size(1)
#     ns, nt = source.size(0), target.size(0)
#     # source covariance
#     tmp_s = torch.ones((1, ns)).to(DEVICE) @ source
#     cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)#     # target covariance
#     tmp_t = torch.ones((1, nt)).to(DEVICE) @ target
#     ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)#     # frobenius norm
#     loss = torch.norm(cs - ct, p='fro').pow(2)
#     loss = loss / (4 * d * d)
#     return loss
Wasserstein loss函数定义
import math
import torch
import torch.linalg as linalgdef calculate_2_wasserstein_dist(X, Y):'''Calulates the two components of the 2-Wasserstein metric:The general formula is given by: d(P_X, P_Y) = min_{X, Y} E[|X-Y|^2]For multivariate gaussian distributed inputs z_X ~ MN(mu_X, cov_X) and z_Y ~ MN(mu_Y, cov_Y),this reduces to: d = |mu_X - mu_Y|^2 - Tr(cov_X + cov_Y - 2(cov_X * cov_Y)^(1/2))Fast method implemented according to following paper: https://arxiv.org/pdf/2009.14075.pdfInput shape: [b, n] (e.g. batch_size x num_features)Output shape: scalar'''if X.shape != Y.shape:raise ValueError("Expecting equal shapes for X and Y!")# the linear algebra ops will need some extra precision -> convert to doubleX, Y = X.transpose(0, 1).double(), Y.transpose(0, 1).double()  # [n, b]mu_X, mu_Y = torch.mean(X, dim=1, keepdim=True), torch.mean(Y, dim=1, keepdim=True)  # [n, 1]n, b = X.shapefact = 1.0 if b < 2 else 1.0 / (b - 1)# Cov. MatrixE_X = X - mu_XE_Y = Y - mu_Ycov_X = torch.matmul(E_X, E_X.t()) * fact  # [n, n]cov_Y = torch.matmul(E_Y, E_Y.t()) * fact# calculate Tr((cov_X * cov_Y)^(1/2)). with the method proposed in https://arxiv.org/pdf/2009.14075.pdf# The eigenvalues for M are real-valued.C_X = E_X * math.sqrt(fact)  # [n, n], "root" of covarianceC_Y = E_Y * math.sqrt(fact)M_l = torch.matmul(C_X.t(), C_Y)M_r = torch.matmul(C_Y.t(), C_X)M = torch.matmul(M_l, M_r)S = linalg.eigvals(M) + 1e-15  # add small constant to avoid infinite gradients from sqrt(0)sq_tr_cov = S.sqrt().abs().sum()# plug the sqrt_trace_component into Tr(cov_X + cov_Y - 2(cov_X * cov_Y)^(1/2))trace_term = torch.trace(cov_X + cov_Y) - 2.0 * sq_tr_cov  # scalar# |mu_X - mu_Y|^2diff = mu_X - mu_Y  # [n, 1]mean_term = torch.sum(torch.mul(diff, diff))  # scalar# put it together

代码获取

相关项目和问题,欢迎沟通交流。

参考文献

Yan H, Ding Y, Li P, Wang Q, Xu Y, Zuo W. Mind the class weight bias: Weighted maximum mean discrepancy for unsupervised domain adaptation. InProceedings of the IEEE conference on computer vision and pattern recognition 2017 (pp. 2272-2281).

Sun B, Saenko K. Deep coral: Correlation alignment for deep domain adaptation. InComputer Vision–ECCV 2016 Workshops: Amsterdam, The Netherlands, October 8-10 and 15-16, 2016, Proceedings, Part III 14 2016 (pp. 443-450). Springer International Publishing.

Shen J, Qu Y, Zhang W, Yu Y. Wasserstein distance guided representation learning for domain adaptation. InProceedings of the AAAI conference on artificial intelligence 2018 Apr 29 (Vol. 32, No. 1).

Ben-David S, Blitzer J, Crammer K, Pereira F. Analysis of representations for domain adaptation. Advances in neural information processing systems. 2006;19.

Kramer O, Kramer O. Scikit-learn. Machine learning for evolution strategies. 2016:45-53.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/808716.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

初始C++之缺省参数 函数重载 引用

初始C之缺省参数 函数重载 引用& 文章目录 初始C之缺省参数 函数重载 引用&一、缺省参数1.1 缺省参数的定义1.2 缺省参数的分类1.3 注意事项 二、 函数重载2.1 函数重载的定义2.2 参数个数不同2.3 参数类型不同2.4 类型顺序不同2.5 为什么C语言不支持函数重载 三、引用…

OpenHarmony南向开发案例:【智能保险柜】

样例简介 智能保险柜实时监测保险柜中振动传感器&#xff0c;当有振动产生时及时向用户发出警报。在连接网络后&#xff0c;配合数字管家应用&#xff0c;用户可以远程接收智能保险柜的报警信息。后续可扩展摄像头等设备&#xff0c;实现对危险及时报警&#xff0c;及时处理&a…

探究 ChatGPT 的心脏--Transformer(基础知识第一篇)

Transformer 是 ChatGPT 的核心部分&#xff0c;如果将 AI 看做一辆高速运转的汽车&#xff0c;那么 Transformer 就是最重要的引擎。它是谷歌于 2017 年发表的《Attention is All You Need》中提出的 Sequence-to-sequence 的模型&#xff0c;诞生之后便一统江湖&#xff0c;在…

项目存放在git上,在jenkins使用docker打包并推送到Ubuntu上运行

项目添加dockerfile 在需要打包的工程的根目录添加Dockerfile文件&#xff0c;文件内容&#xff1a; # 设置JAVA版本 FROM openjdk:8 # 指定存储卷&#xff0c;任何向/tmp写入的信息都不会记录到容器存储层 VOLUME /tmp# 拷贝运行JAR包 ARG JAR_FILE COPY ${JAR_FILE} app.jar…

蓝桥杯练习系统(算法训练)ALGO-958 P0704回文数和质数

资源限制 内存限制&#xff1a;256.0MB C/C时间限制&#xff1a;1.0s Java时间限制&#xff1a;3.0s Python时间限制&#xff1a;5.0s 一个数如果从左往右读和从右往左读数字是完全相同的&#xff0c;则称这个数为回文数&#xff0c;比如898,1221,15651都是回文数。编写…

内核驱动更新

1.声明我们是开源的 .c 文件末尾加上 2.在Kconfig里面修改设备&#xff0c;bool&#xff08;双态&#xff09;-----》tristate&#xff08;三态&#xff09; 3.进入menuconfig修改为M 4.编译内核 make modules 也许你会看到一个 .ko 文件 5.复制到根目录文件下 在板子…

4.11作业

服务器端 #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include<QTcpServer> //服务器端类 #include<QMessageBox> //消息对话框 #include<QTcpSocket> //客户端类 #include<QList> //链表容器QT_BEGIN_NAMESPACE namespace Ui { cla…

Pycharm远程连接服务器配置详解

背景&#xff1a; 相信很多人都遇到了这种情况&#xff0c;日常的开发和程序的验证都需要在linux环境下验证&#xff0c;而我们都是使用本地windows来进行开发或者脚本的编写&#xff0c;然后再push到远程仓库&#xff0c;再到linux环境下pull下来代码验证&#xff0c;这样每次…

CorelDRAW21.2.4中文最新官方和谐版下载

CorelDRAW是一款由加拿大Corel公司出品的平面设计软件&#xff0c;也被称为CDR。它是一款功能强大的矢量图形制作和排版软件&#xff0c;主要面向绘图设计师和印刷输出人员。该软件提供了矢量插图、页面布局、图片编辑和设计工具&#xff0c;广泛应用于排版印刷、矢量图形编辑及…

HWOD:密码强度等级

一、知识点 回车键的ASCII码是10 如果使用EOF&#xff0c;有些用例不通过 二、题目 1、描述 密码按如下规则进行计分&#xff0c;并根据不同的得分为密码进行安全等级划分。 一、密码长度: 5 分: 小于等于4 个字符 10 分: 5 到7 字符 25 分: 大于等于8 个字符 二、字母: 0…

NotePad++ 快速生成SQL IN (‘’,‘’)

sql In(‘’&#xff0c;‘’)这种形式 第一步&#xff1a;AltC 鼠标放在第一行最左边 第二步 CtrlH $代表行末 第三步 去掉每行换行符 换行可能是"\n" 或者"\r"或者"\r\n" 结果&#xff1a;

容错组合导航

在初始值正确的情况下&#xff0c;惯性导航短期精度较高&#xff0c;但是其误差随着时间是累计的。如果要提高惯性导航的长期精度&#xff0c;就必须提高惯性器件的精度和初始读准精度&#xff0c;这必将大大提高成本。 如果将惯性导航与其他导航系统适当地组合起来&#xff0c…

Java泛型中 T 和 ? 傻傻分不清楚

1.定义&#xff1a; JDK5.0后&#xff0c;Java提供了泛型。 泛型是一种在编译时提供类型安全的方式&#xff0c;允许程序员在定义类、接口和方法时使用类型参数。这样&#xff0c;可以在不损失类型安全的情况下&#xff0c;创建可重用的代码。 泛型有两种主要的使用形式&#x…

linux学习:栈

目录 顺序栈 结构 初始化一个空顺序栈 压栈 出栈 例子 十进制转八进制 链式栈 管理结构体的定义 初始化 压栈 出栈 顺序栈 顺序栈的实现&#xff0c;主要就是定义一块连续的内存来存放这些栈元素&#xff0c;同时为了方便管理&#xff0c; 再定义一个整数变量来代表…

2024中国(宁波)国际宠物用品博览会

2024中国(宁波)国际宠物用品博览会 People&Pet Fair 2024 专注2B交易&#xff0c;关注人宠发展&#xff0c;它经济&#xff0c;势不可挡! 时间&#xff1a;2024年11月14-16日 地点&#xff1a;宁波国际会展中心 详询主办方陆先生 I38&#xff08;前三位&#xff09; …

水离子雾化壁炉与酒店大厅的氛围搭配

将水离子雾化壁炉与酒店大厅的氛围搭配是一个很好的主意&#xff0c;可以为大厅增添舒适、温馨的氛围&#xff0c;以下是一些建议&#xff1a; 迎宾区域&#xff1a;在酒店大厅的迎宾区域设置水离子雾化壁炉&#xff0c;作为客人抵达时的第一印象。壁炉的温馨效果可以让客人感到…

Java+BS +saas云HIS系统源码SpringBoot+itext + POI + ureport2数字化医院系统源码

JavaBS saas云HIS系统源码SpringBootitext POI ureport2数字化医院系统源码 医院云HIS系统是一种运用云计算、大数据、物联网等新兴信息技术的业务和技术平台。它按照现代医疗卫生管理要求&#xff0c;在特定区域内以数字化形式收集、存储、传递和处理医疗卫生行业的数据。通…

【应用】SpringBoot-自动配置原理

前言 本文简要介绍SpringBoot的自动配置原理。 本文讲述的SpringBoot版本为&#xff1a;3.1.2。 前置知识 在看原理介绍之前&#xff0c;需要知道Import注解的作用&#xff1a; 可以导入Configuration注解的配置类、声明Bean注解的bean方法&#xff1b;可以导入ImportSele…

异构超图嵌入的图分类 笔记

1 Title Heterogeneous Hypergraph Embedding for Graph Classification&#xff08;Xiangguo Sun , PictureHongzhi Yin , PictureBo Liu , PictureHongxu Chen , PictureJiuxin Cao , PictureYingxia Shao , PictureNguyen Quoc Viet Hung&#xff09;【WSDM 2021】 2 Co…

模拟移动端美团案例(react版)

文章目录 目录 概述 项目搭建 1.启动项目&#xff08;mock服务前端服务&#xff09; 2.使用Redux ToolTik(RTK)编写store(异步action) 3.组件触发action并渲染数据 一、渲染列表 ​编辑 二、tab切换类交互 三、添加购物车 四、统计区域功能实现 五、购物车列表功能实现 六、控制…