【域适应】基于深度域适应MMD损失的典型四分类任务实现

关于

MMD (maximum mean discrepancy)是用来衡量两组数据分布之间相似度的度量。一般地,如果两组数据分布相似,那么MMD 损失就相对较小,说明两组数据/特征处于相似的特征空间中。基于这个想法,对于源域和目标域数据,在使用深度学习进行特征提取中,使用MMD损失,可以让模型提取两个域的共有特征/空间,从而实现源域到目标域的迁移。

参考论文:https://arxiv.org/abs/1409.6041

工具

Python

 

方法实现

定义mmd函数
#!/usr/bin/env python
# encoding: utf-8import torch# Consider linear time MMD with a linear kernel:
# K(f(x), f(y)) = f(x)^Tf(y)
# h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i)
#             = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)]
#
# f_of_X: batch_size * k
# f_of_Y: batch_size * k
def mmd_linear(f_of_X, f_of_Y):delta = f_of_X - f_of_Yloss = torch.mean(torch.mm(delta, torch.transpose(delta, 0, 1)))return lossdef guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):n_samples = int(source.size()[0])+int(target.size()[0])total = torch.cat([source, target], dim=0)total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))L2_distance = ((total0-total1)**2).sum(2)if fix_sigma:bandwidth = fix_sigmaelse:bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)bandwidth /= kernel_mul ** (kernel_num // 2)bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]return sum(kernel_val)#/len(kernel_val)def mmd_rbf_accelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):batch_size = int(source.size()[0])kernels = guassian_kernel(source, target,kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)loss = 0for i in range(batch_size):s1, s2 = i, (i+1)%batch_sizet1, t2 = s1+batch_size, s2+batch_sizeloss += kernels[s1, s2] + kernels[t1, t2]loss -= kernels[s1, t2] + kernels[s2, t1]return loss / float(batch_size)def mmd_rbf_noaccelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):batch_size = int(source.size()[0])kernels = guassian_kernel(source, target,kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)XX = kernels[:batch_size, :batch_size]YY = kernels[batch_size:, batch_size:]XY = kernels[:batch_size, batch_size:]YX = kernels[batch_size:, :batch_size]loss = torch.mean(XX + YY - XY -YX)return loss
定义基于mmd特征对齐CNN模型
# encoding=utf-8import torch.nn as nn
import torch.nn.functional as Fclass Network(nn.Module):def __init__(self):super(Network, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(1, 3)),nn.ReLU())self.conv2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(1, 3)),nn.ReLU(),nn.Dropout(0.4),nn.MaxPool2d(kernel_size=(1, 2), stride=2))self.fc1 = nn.Sequential(nn.Linear(in_features=64 * 98, out_features=100),nn.ReLU())self.fc2 = nn.Sequential(nn.Linear(in_features=100, out_features=2))def forward(self, src, tar):x_src = self.conv1(src)x_tar = self.conv1(tar)x_src = self.conv2(x_src)x_tar = self.conv2(x_tar)#print(x_src.shape)x_src = x_src.reshape(-1, 64 * 98)x_tar = x_tar.reshape(-1, 64 * 98)x_src_mmd = self.fc1(x_src)x_tar_mmd = self.fc1(x_tar)#x_src = self.fc1(x_src)#x_tar = self.fc1(x_tar)#x_src_mmd = self.fc2(x_src)#x_tar_mmd = self.fc2(x_tar)y_src = self.fc2(x_src_mmd)return y_src, x_src_mmd, x_tar_mmd

代码获取

后台私信;

其他相关域适应问题和代码开发,欢迎沟通和交流。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/817146.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

顶切,半顶切是什么意思?

齿轮加工及刀具中有一些特定名词或者叫法,不熟悉的小伙伴可能最开始会有一些困惑,这不,最近有小伙伴问了一个问题:顶切是说齿顶的倒角吗? 今天就给大家说说顶切和半顶切。 一、顶切 Topping 从字面上可以看到可以想到…

MySQL的权限管理

MySQL的权限管理 在理解MySQL的权限管理之前,我们需要先了解其架构设计以及权限管理在该架构中的定位。 MySQL的架构设计 MySQL数据库系统采用了分层的架构设计,主要可以分为以下几个层级: 连接层:最外层,处理连接…

爬虫 selenium

爬虫 selenium 【一】介绍 【1】说明 Selenium是一款广泛应用于Web应用程序测试的自动化测试框架 它可以模拟用户再浏览器上的行为对Web应用进行自动化测试 主要作用: 浏览器控制:启动、切换、关闭不同浏览器元素定位于操作:通过CSS选择器…

vscode中运行js

vscode中运行js 目前vscode插件运行js都是基于node环境,vscode控制台打印有些数据不方便等缺点。 每次调试在浏览器中运行js,需要创建html模板、插入js。期望能够直接运行js可以打开浏览器运行js,在vscode插件市场找到一款插件可以做到。 插…

yolo系列(之一)

深度学习经典检测算法 two-stage (两阶段) : Faster-rcnn Mask-Rcnn系列 (输入图像---》CNN特征---》预选框---》输出结果) one-stage (单阶段): YOLO系列 (输入图像---》CNN特征---》输出结果) one-stage的特点:&…

深度学习学习日记4.15 (面向GPT学习)

精确学习时间(09点35分开始) 深度学习 torch.nntorch.utils.datanumpytorchvision中的模块有哪些os 模块PIL(Python Imaging Library)tqdmmatplotlibnn.ReLU inplace参数设为Truenn.relu 训练的迭代过程梯度清零loss指标计算为什…

SQLite超详细的编译时选项(十六)

返回:SQLite—系列文章目录 上一篇:SQLite数据库文件格式(十五) 下一篇:SQLite 在Android安装与定制方案(十七) 1. 概述 对于大多数目的,SQLite可以使用默认的 编译选项。但是…

WinForms 零基础进阶教程:文件操作与 CSV 处理

文章目录 文件操作数据存储与文件操作文件存取的好处文件存取的方式文本文件的写入和读取文本文件的删除、复制和移动 目录的操作文件属性操作文件路径 对话框OpenFileDialog对话框SaveFileDialog对话框对话框中CheckPathExists属性的应用 CSV 文件读写与 DataGridView 进阶Dat…

Python基于Django的微博热搜、微博舆论可视化系统

博主介绍:✌IT徐师兄、7年大厂程序员经历。全网粉丝15W、csdn博客专家、掘金/华为云//InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇&#x1f3…

Redis限流插件

Redis限流插件: 1:搭建层级结构 同时对 redis.log 授权 chmod 777 redis.log2:确认 redis 版本 3:下载redis配置文件 redis.conf https://redis.io/docs/management/config/ 4:上传/redis/conf作为原始 redis.conf 5:在/redis_6390/conf下编辑redis.conf docker run -it \ --…

51单片机上面的IIC协议

1、什么是IIC协议 2、模拟IIC协议 51单片机上面是没有与IIC协议相关的寄存器的(没有相关的硬件),不像串口可以配置对应的寄存器达到目的(比如修改波特率9600 or 115200),要配置IIC只能够根据用户手册里面的…

​面试经典150题——LRU 缓存

​ 1. 题目描述 2. 题目分析与解析 首先讲解一下LRU LRU 是“Least Recently Used”的缩写,LRU 算法的基本思想是跟踪最近最少使用的数据,并在缓存已满且需要存储新数据时优先驱逐该数据。 LRU 算法通常的工作原理的简化解释: 当访问或使…

vue.config.js跨域问题解决

讲解视频 问题背景 目标地址: 而当前项目启动是http,协议名不同,所以跨域了 解决步骤和解答 1. 新建vue.config.js文件 2. 添加如下代码: 一般目标路径target写 域名 就可以了 但其实,写路径也可以,…

查看 Linux 接入的 USB 设备速率是 USB2 还是 USB3

查看接入 usb 设备的速率 使用以下命令查看接入的 USB 设备速率(每一行最后的 xxM 字样)。插入设备前查看一次,插入设备后查看一次,对比即可定位到刚插入的设备是哪一条。 lsusb -t命令输出如下图 对照 USB 速率表 对照 USB 速…

EasyRecovery数据恢复软件2024试用版下载安装包

EasyRecovery支持的文件格式非常广泛,几乎涵盖了用户日常所需的所有文件类型。具体来说,它支持恢复的办公文档类型包括Microsoft Word、Excel、PPT、MS office、Adobe PDF、Access等。此外,对于音频文件,EasyRecovery同样支持丰富…

初识three.js创建第一个threejs3D页面

说到3D&#xff0c;想必大家都能想到three.js&#xff0c;它是由WebGL封装出来的&#xff0c;接下来&#xff0c;我手把手教大家创建一个简单的3D页面 话尽在代码中&#xff0c;哈哈 大家可以复制代码玩一下 <!DOCTYPE html> <html lang"en"><head&…

【嵌入式 - 输出驱动电路Open Drain (开漏)和Push-Pull (推挽)】

定义 Open drain 和 push-pull 是两种常见的输出驱动电路。它们在数字电子电路中用于控制信号的输出。让我逐一解释它们&#xff1a; 1. Open Drain (开漏): Open drain 输出端通常连接到地 (GND) 或者一个高电阻 (pull-up) 上。当输出信号为逻辑高电平时&#xff0c;输出端…

【位运算】Leetcode 丢失的数字

题目解析 268. 丢失的数字 本题的意思就是数组的长度为n&#xff0c;在[0,n]区间中寻找缺失的一个数字 算法讲解 直观思路&#xff1a;排序 Hash&#xff0c;顺序查找缺失的数字 优化&#xff1a;使用异或&#xff0c;首先将[0,n]之间所有数字异或在一起&#xff0c;然后将…

链表创建的陷阱与细节

链表是线性表的一种&#xff0c;它在逻辑结构上是连续的&#xff0c;在物理结构上是非连续的。 也就是说链表在物理空间上是独立的&#xff0c;可能是东一块西一块的。如下顺序表和链表在内存空间上的对比&#xff1a; 而链表的每一块空间是如何产生联系实现在逻辑结构上是连续…

移动应用安全合规动态:网信办、金管局发文强调数据安全;3月个人信息违规抽查结果出炉!(第五期)

一、监管部门动向&#xff1a;国家互联网信息办公室公布《促进和规范数据跨境流动规定》; 工信部发布《关于网络安全保险典型服务方案目录的公示》 二、安全新闻&#xff1a;恶意软件警报&#xff01;黑客利用软件即服务攻击印度安卓用户&#xff1b;Cerberus银行恶意软件的虚…