【域适应】基于深度域适应MMD损失的典型四分类任务实现

关于

MMD (maximum mean discrepancy)是用来衡量两组数据分布之间相似度的度量。一般地,如果两组数据分布相似,那么MMD 损失就相对较小,说明两组数据/特征处于相似的特征空间中。基于这个想法,对于源域和目标域数据,在使用深度学习进行特征提取中,使用MMD损失,可以让模型提取两个域的共有特征/空间,从而实现源域到目标域的迁移。

参考论文:https://arxiv.org/abs/1409.6041

工具

Python

 

方法实现

定义mmd函数
#!/usr/bin/env python
# encoding: utf-8import torch# Consider linear time MMD with a linear kernel:
# K(f(x), f(y)) = f(x)^Tf(y)
# h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i)
#             = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)]
#
# f_of_X: batch_size * k
# f_of_Y: batch_size * k
def mmd_linear(f_of_X, f_of_Y):delta = f_of_X - f_of_Yloss = torch.mean(torch.mm(delta, torch.transpose(delta, 0, 1)))return lossdef guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):n_samples = int(source.size()[0])+int(target.size()[0])total = torch.cat([source, target], dim=0)total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))L2_distance = ((total0-total1)**2).sum(2)if fix_sigma:bandwidth = fix_sigmaelse:bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)bandwidth /= kernel_mul ** (kernel_num // 2)bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]return sum(kernel_val)#/len(kernel_val)def mmd_rbf_accelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):batch_size = int(source.size()[0])kernels = guassian_kernel(source, target,kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)loss = 0for i in range(batch_size):s1, s2 = i, (i+1)%batch_sizet1, t2 = s1+batch_size, s2+batch_sizeloss += kernels[s1, s2] + kernels[t1, t2]loss -= kernels[s1, t2] + kernels[s2, t1]return loss / float(batch_size)def mmd_rbf_noaccelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):batch_size = int(source.size()[0])kernels = guassian_kernel(source, target,kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)XX = kernels[:batch_size, :batch_size]YY = kernels[batch_size:, batch_size:]XY = kernels[:batch_size, batch_size:]YX = kernels[batch_size:, :batch_size]loss = torch.mean(XX + YY - XY -YX)return loss
定义基于mmd特征对齐CNN模型
# encoding=utf-8import torch.nn as nn
import torch.nn.functional as Fclass Network(nn.Module):def __init__(self):super(Network, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(1, 3)),nn.ReLU())self.conv2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(1, 3)),nn.ReLU(),nn.Dropout(0.4),nn.MaxPool2d(kernel_size=(1, 2), stride=2))self.fc1 = nn.Sequential(nn.Linear(in_features=64 * 98, out_features=100),nn.ReLU())self.fc2 = nn.Sequential(nn.Linear(in_features=100, out_features=2))def forward(self, src, tar):x_src = self.conv1(src)x_tar = self.conv1(tar)x_src = self.conv2(x_src)x_tar = self.conv2(x_tar)#print(x_src.shape)x_src = x_src.reshape(-1, 64 * 98)x_tar = x_tar.reshape(-1, 64 * 98)x_src_mmd = self.fc1(x_src)x_tar_mmd = self.fc1(x_tar)#x_src = self.fc1(x_src)#x_tar = self.fc1(x_tar)#x_src_mmd = self.fc2(x_src)#x_tar_mmd = self.fc2(x_tar)y_src = self.fc2(x_src_mmd)return y_src, x_src_mmd, x_tar_mmd

代码获取

后台私信;

其他相关域适应问题和代码开发,欢迎沟通和交流。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/817146.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

顶切,半顶切是什么意思?

齿轮加工及刀具中有一些特定名词或者叫法,不熟悉的小伙伴可能最开始会有一些困惑,这不,最近有小伙伴问了一个问题:顶切是说齿顶的倒角吗? 今天就给大家说说顶切和半顶切。 一、顶切 Topping 从字面上可以看到可以想到…

Swagger API 文档 | SpringCloudGateway 集成 SpringDoc

文章目录 工作原理方案 1:配置 swagger-ui.urls方案 2:通过路由定义动态配置具体案例第 1 步:导入代码第 2 步:配置 Swagger Url 列表第 3 步:启动程序第 4 步:查看注册中心第 5 步:访问网关 Swagger UI相关博文😎 本节目标: Spring Cloud Gateway 集成 SpringDoc,实…

MySQL的权限管理

MySQL的权限管理 在理解MySQL的权限管理之前,我们需要先了解其架构设计以及权限管理在该架构中的定位。 MySQL的架构设计 MySQL数据库系统采用了分层的架构设计,主要可以分为以下几个层级: 连接层:最外层,处理连接…

x264 8x8 水平预测汇编分析

C 语言代码 void s264_predict_8x8_h_c( pixel *src, pixel edge[36] ) { PREDICT_8x8_LOAD_LEFT #define ROW(y) MPIXEL_X4( srcy*FDEC_STRIDE0 ) \ MPIXEL_X4( srcy*FDEC_STRIDE4 ) PIXEL_SPLAT_X4( l##y ); ROW(0); ROW(1); ROW(2); ROW(3); ROW(4); ROW(5); ROW(6); ROW(7…

VUE3:自定义loading指令

一、效果描述 在一个div中使用该指令并绑定一个变量&#xff0c;通过修改变量实现Loading的显示与隐藏。 二、代码与使用方式 1.使用指令的vue组件 <template><div v-loading"showLoading"></div> </template> <script setup lang&q…

爬虫 selenium

爬虫 selenium 【一】介绍 【1】说明 Selenium是一款广泛应用于Web应用程序测试的自动化测试框架 它可以模拟用户再浏览器上的行为对Web应用进行自动化测试 主要作用&#xff1a; 浏览器控制&#xff1a;启动、切换、关闭不同浏览器元素定位于操作&#xff1a;通过CSS选择器…

vscode中运行js

vscode中运行js 目前vscode插件运行js都是基于node环境&#xff0c;vscode控制台打印有些数据不方便等缺点。 每次调试在浏览器中运行js&#xff0c;需要创建html模板、插入js。期望能够直接运行js可以打开浏览器运行js&#xff0c;在vscode插件市场找到一款插件可以做到。 插…

yolo系列(之一)

深度学习经典检测算法 two-stage (两阶段) : Faster-rcnn Mask-Rcnn系列 &#xff08;输入图像---》CNN特征---》预选框---》输出结果&#xff09; one-stage (单阶段): YOLO系列 &#xff08;输入图像---》CNN特征---》输出结果&#xff09; one-stage的特点&#xff1a;&…

深度学习学习日记4.15 (面向GPT学习)

精确学习时间&#xff08;09点35分开始&#xff09; 深度学习 torch.nntorch.utils.datanumpytorchvision中的模块有哪些os 模块PIL&#xff08;Python Imaging Library&#xff09;tqdmmatplotlibnn.ReLU inplace参数设为Truenn.relu 训练的迭代过程梯度清零loss指标计算为什…

2024 蓝桥打卡Day40

2021、2022年蓝桥杯真题练习 2021年蓝桥杯真题练习A ASCB 卡片C 查找D 货物摆放F 时间显示H 杨辉三角形 2022年蓝桥杯真题练习A 星期计算B 山C 字符统计D 最少刷题数E 求阶乘 2021年蓝桥杯真题练习 A ASC package THL_0412;public class A_2021 {public static void main(Str…

SQLite超详细的编译时选项(十六)

返回&#xff1a;SQLite—系列文章目录 上一篇&#xff1a;SQLite数据库文件格式&#xff08;十五&#xff09; 下一篇&#xff1a;SQLite 在Android安装与定制方案&#xff08;十七&#xff09; 1. 概述 对于大多数目的&#xff0c;SQLite可以使用默认的 编译选项。但是…

WinForms 零基础进阶教程:文件操作与 CSV 处理

文章目录 文件操作数据存储与文件操作文件存取的好处文件存取的方式文本文件的写入和读取文本文件的删除、复制和移动 目录的操作文件属性操作文件路径 对话框OpenFileDialog对话框SaveFileDialog对话框对话框中CheckPathExists属性的应用 CSV 文件读写与 DataGridView 进阶Dat…

Python基于Django的微博热搜、微博舆论可视化系统

博主介绍&#xff1a;✌IT徐师兄、7年大厂程序员经历。全网粉丝15W、csdn博客专家、掘金/华为云//InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&#x1f3…

Redis限流插件

Redis限流插件: 1:搭建层级结构 同时对 redis.log 授权 chmod 777 redis.log2:确认 redis 版本 3:下载redis配置文件 redis.conf https://redis.io/docs/management/config/ 4:上传/redis/conf作为原始 redis.conf 5:在/redis_6390/conf下编辑redis.conf docker run -it \ --…

美国有哪些银行可以开?美国华美银行开户

很多客户在注册香港公司后&#xff0c;都会选择美国的银行开户&#xff0c;因为目前较好开&#xff0c;简单&#xff0c;快速&#xff0c;便宜&#xff0c;性价比较高的银行 。 华美银行现为全美以华裔为主要市场规模较大的商业银行&#xff0c;总部位于加州&#xff0c;专注于…

营销自动化的未来趋势

​营销自动化在2024年的趋势将继续沿着数字化、个性化和智能化的方向发展。 根据小编全网搜索相关资料信息整理得出&#xff0c;营销自动化有以下可能趋势&#xff1a; 1、深度智能化&#xff1a;随着人工智能技术的进一步发展&#xff0c;营销自动化将更加智能化。通过机器学…

51单片机上面的IIC协议

1、什么是IIC协议 2、模拟IIC协议 51单片机上面是没有与IIC协议相关的寄存器的&#xff08;没有相关的硬件&#xff09;&#xff0c;不像串口可以配置对应的寄存器达到目的&#xff08;比如修改波特率9600 or 115200&#xff09;&#xff0c;要配置IIC只能够根据用户手册里面的…

​面试经典150题——LRU 缓存

​ 1. 题目描述 2. 题目分析与解析 首先讲解一下LRU LRU 是“Least Recently Used”的缩写&#xff0c;LRU 算法的基本思想是跟踪最近最少使用的数据&#xff0c;并在缓存已满且需要存储新数据时优先驱逐该数据。 LRU 算法通常的工作原理的简化解释&#xff1a; 当访问或使…

第十五届蓝桥杯python组

第十五届蓝桥杯Python组的比赛涉及了多个编程问题&#xff0c;这些问题覆盖了不同的算法和数据结构知识点。以下是一些题目的描述和解题思路&#xff1a; A题“穿越时空之门”要求计算在二进制和四进制表示下&#xff0c;数字的各数位之和相等的勇者数量。解题的关键是将数字转…

vue.config.js跨域问题解决

讲解视频 问题背景 目标地址&#xff1a; 而当前项目启动是http&#xff0c;协议名不同&#xff0c;所以跨域了 解决步骤和解答 1. 新建vue.config.js文件 2. 添加如下代码&#xff1a; 一般目标路径target写 域名 就可以了 但其实&#xff0c;写路径也可以&#xff0c;…