Domain Adaptation(李宏毅)机器学习 2023 Spring HW11 (Boss Baseline)

1. 领域适配简介

领域适配是一种迁移学习方法,适用于源领域和目标领域数据分布不同但学习任务相同的情况。具体而言,我们在源领域(通常有大量标注数据)训练一个模型,并希望将其应用于目标领域(通常只有少量或没有标注数据)。然而,由于这两个领域的数据分布不同,模型在目标领域上的性能可能会显著下降。领域适配技术的目标是通过对模型进行适配,缩小源领域与目标领域之间的差距,从而提升模型在目标领域的表现。

Domain Shift (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

以数字识别为例,如果我们的源数据是灰度图像,并且在这些数据上训练模型,我们可以预期模型会取得相当不错的效果。然而,如果我们将这个在灰度图像上训练的模型用于分类彩色图像,模型的表现可能会较差。这是因为这两个数据集之间存在领域转移。

领域适配方法可以根据目标领域中标签的可用性进行分类:

  1. 有监督领域适配:源领域和目标领域都有标注数据。这种情况较为少见,因为领域适配的主要动机是目标领域标签的稀缺性。

  2. 无监督领域适配:源领域有标注数据,而目标领域没有标注数据。这是最常见且最具挑战性的情况。

  3. 半监督领域适配:源领域有标注数据,目标领域则只有少量标注数据。

Different Domain Adaptation Scenarios (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

我们的博客和作业主要关注目标领域缺乏标注数据的场景。

解决这个问题的基本概念如下:我们旨在找到一个特征提取器,它能够接收输入数据并输出特征空间。这个特征提取器应该能够滤除领域特定的变化,同时保留不同领域之间共享的特征。例如,在以下的示例中,特征提取器应该能够忽略图像的颜色,对于相同的数字,不论其颜色如何,都能生成具有相同分布的特征。

Basic Idea of Domain Adaptation (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

研究人员提出了许多方法,其中对抗学习方法是最常见且最有效的技术之一。

Domain Adversarial Training - 1 (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

我们将一个标准网络分为两部分:特征提取器和标签预测器。在训练过程中,我们以标准的有监督方式在源领域数据上训练整个网络。对于目标领域数据,我们只使用特征提取器提取特征,并采用技术手段将目标领域的特征与源领域的特征对齐。

Domain Adversarial Training - 2 (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

具体来说,我们设计了一个新的领域分类器,它是一个二分类器,输入特征向量并判断输入数据是来自源领域还是目标领域。另一方面,特征生成器的设计目的是“欺骗”领域分类器,使其无法正确区分来源领域。

Domain Adversarial Training - 3 (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

如果我们仔细思考上述方法,我们可以直观地理解,尽管对抗训练可以使源领域和目标领域的整体分布更加相似,如下图左侧所示,但这种分布可能并不适合或不适用于机器学习任务。理想情况下,我们期望获得右侧图像所示的分布。

Limitation of Domain Adversarial Training (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2021-course-data/da_v6.pdf)

当然,已有大量论文提出了针对这一问题的解决方法。为了在这次作业中通过strong 和 boss baseline,我们需要深入相关文献,并采用合适的方法。在作业中,我将介绍更多相关的论文和技术。

2. Homework Results and Analysis

作业 11 聚焦于领域适配。给定真实图像(带标签)和涂鸦(无标签),任务是利用领域适配技术训练一个网络,能够准确预测绘制图像的标签。

task description (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2023-course-data/HW11.pdf)

数据集设置:

  • 标签:10个类别(编号从0到9),如以下图片所示。

  • 训练集:5000张 (32, 32) RGB 真实图像(带标签)。

  • 测试集:100000张 (28, 28) 灰度绘制图像。

source and target data (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2023-course-data/HW11.pdf)

baseline 的门槛 在 Kaggle 上的数值为:

Baseline

Public

Private

Simple

Score >= 0.44280

Score >= 0.44012

Medium

Score >= 0.65994

Score >= 0.65928

Strong

Score >= 0.75342

Score >= 0.75518

Boss

Score >= 0.81072

Score >= 0.80794

像往常一样,助教会提供关于如何超越各种基准模型的指导。

Hints for Simple, Medium and Strong Baseline (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2023-course-data/HW11.pdf)

Hints for Boss Baseline (source: https://speech.ee.ntu.edu.tw/~hylee/ml/ml2023-course-data/HW11.pdf)

2.1 Simple Baseline

使用助教提供的默认代码足以通过 simple baseline。

2.2 Medium Baseline

通过增加训练轮数并调整超参数 lambda,可以通过 medium baseline。

num_epochs = 800
# train 800 epochswith Progress(TextColumn("[progress.description]{task.description}"),BarColumn(),TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),TimeRemainingColumn(),TimeElapsedColumn()) as progress:epoch_tqdm = progress.add_task(description="epoch progress", total=num_epochs)for epoch in range(num_epochs):train_D_loss, train_F_loss, train_acc = train_epoch(source_dataloader, target_dataloader, progress, lamb=0.6)progress.advance(epoch_tqdm, advance=1)if epoch == 10:torch.save(feature_extractor.state_dict(), f'extractor_model_early.bin')torch.save(label_predictor.state_dict(), f'predictor_model_early.bin')elif epoch == 100:torch.save(feature_extractor.state_dict(), f'extractor_model_mid.bin')torch.save(label_predictor.state_dict(), f'predictor_model_mid.bin')torch.save(feature_extractor.state_dict(), f'extractor_model.bin')torch.save(label_predictor.state_dict(), f'predictor_model.bin')print('epoch {:>3d}: train D loss: {:6.4f}, train F loss: {:6.4f}, acc {:6.4f}'.format(epoch, train_D_loss, train_F_loss, train_acc))

2.3 Strong Baseline

助教建议了几篇论文来提升性能并通过strong baseline。其中,我发现以下这篇论文特别有趣:《Minimum Class Confusion for Versatile Domain Adaptation》(Jin, Ying, et al.)(链接)。

他们“提出了一种新颖的损失函数:Minimum Class Confusion(MCC)。它可以被描述为一种新颖且多功能的领域适配方法,无需显式进行领域对齐,且具有较快的收敛速度。此外,它还可以作为一种通用正则化器,与现有的领域适配方法正交且互补,从而进一步加速和改善这些已有的竞争性方法。”(Jin, Ying, et al.,p. 3)

The schematic of the Minimum Class Confusion (MCC) loss function (source: https://arxiv.org/abs/1912.03699)

MCC 的计算过程如下:

给定以下变量:

  • \mathbf{f}_t:网络输出的目标领域数据的logits(即网络分类器的输出)。

  • T :一个温度参数,用于缩放logits,使其更加平滑并增大类别分布之间的差异。

  • \mathbf{p}_t:目标领域经温度平滑后的预测结果,表示通过softmax得到的概率分布。

  • H(\cdot):熵函数,用于衡量每个样本的预测不确定性。

MCC步骤1:目标领域logits的温度缩放:

目标领域的logits ​ \mathbf{f}_t 通过温度进行缩放,以平滑分类概率:

\\ \mathbf{f}_t' = \frac{\mathbf{f}_t}{T} \\

其中, T > 1 用于拉伸预测的概率分布,防止模型过于自信。

MCC步骤2:计算Softmax输出:

将经过温度缩放的logits通过softmax函数得到目标领域预测的概率分布 \mathbf{p}_t ​:

\mathbf{p}_t = \text{Softmax}(\mathbf{f}_t') \\

此处, \mathbf{p}_t ​是一个 N \times C 的矩阵,其中 N 是目标领域样本的数量,C 是分类的类别数。

MCC步骤3:计算样本熵权重:

每个样本的熵 H(\mathbf{p}_t) 使用以下公式计算:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/66969.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SQL从入门到实战-1

目录 学前须知 sqlzoo数据介绍 world nobel covid ge game、goal、eteam teacher、dept movie、casting、actor 基础语句 select&from 基础查询select单列&多列&所有列&别名应用 例题一 例题二 例题三 select使用distinct去重 例题四 例题五…

Python在Excel工作表中创建数据透视表

在数据处理和分析工作中,Excel作为一个广泛使用的工具,提供了强大的功能来管理和解析数据。当面对大量复杂的数据集时,为了更高效地总结、分析和展示数据,创建数据透视表成为一种不可或缺的方法。通过使用Python这样的编程语言与E…

springboot整合h2

在 Spring Boot 中整合 H2 数据库非常简单。H2 是一个轻量级的嵌入式数据库&#xff0c;非常适合开发和测试环境。以下是整合 H2 数据库的步骤&#xff1a; 1. 添加依赖 首先&#xff0c;在你的 pom.xml 文件中添加 H2 数据库的依赖&#xff1a; <dependency><grou…

Web前端界面开发

前沿&#xff1a;介绍自适应和响应式布局 自适应布局&#xff1a;-----针对页面1个像素的变换而变化 就是我们上一个练习的效果 我们的页面效果&#xff0c;随着我们的屏幕大小而发生适配的效果&#xff08;类似等比例&#xff09; 如&#xff1a;rem适配 和 vw/vh适配 …

【01】AE特效开发制作特技-Adobe After Effects-AE特效制作快速入门-制作飞机,子弹,爆炸特效以及导出png序列图-优雅草央千澈

【01】AE特效开发制作特技-Adobe After Effects-AE特效制作快速入门-制作飞机&#xff0c;子弹&#xff0c;爆炸特效以及导出png序列图-优雅草央千澈 开发背景 优雅草央千澈所有的合集&#xff0c;系列文章可能是不太适合完全初学者的&#xff0c;因为课程不会非常细致的系统…

java项目之在线文档管理系统源码(springboot+mysql+vue+文档)

大家好我是风歌&#xff0c;曾担任某大厂java架构师&#xff0c;如今专注java毕设领域。今天要和大家聊的是一款基于springboot的在线文档管理系统。项目源码以及部署相关请联系风歌&#xff0c;文末附上联系信息 。 项目简介&#xff1a; 在线文档管理系统的主要使用者分为管…

可靠的人形探测,未完待续(III)

一不小心&#xff0c;此去经年啊。问大家新年快乐&#xff01; 那&#xff0c;最近在研究毫米波雷达模块嘛&#xff0c;期望用在后续的产品中&#xff0c;正好看到瑞萨的活动送板子&#xff0c;手一下没忍住。 拿了板子就得干活咯&#xff0c;我一路火花带闪电&#xff0c;开整…

【灵码助力安全3】——利用通义灵码辅助智能合约漏洞检测的尝试

前言 随着区块链技术的快速发展&#xff0c;智能合约作为去中心化应用&#xff08;DApps&#xff09;的核心组件&#xff0c;其重要性日益凸显。然而&#xff0c;智能合约的安全问题一直是制约区块链技术广泛应用的关键因素之一。由于智能合约代码一旦部署就难以更改&#xf…

腾讯云下架印度云服务器节点,印度云服务器租用何去何从

近日&#xff0c;腾讯云下架印度云服务器节点的消息引起了业界的广泛关注。这一变动让许多依赖印度云服务器的用户开始担忧&#xff0c;印度云服务器租用的未来究竟在何方&#xff1f; 从印度市场本身来看&#xff0c;其云服务市场的潜力不容小觑。据 IDC 报告&#xff0c;到 2…

【RTSP】使用webrtc播放rtsp视频流

一、简介 rtsp流一般是监控、摄像机的实时视频流,现在的主流浏览器是不支持播放rtsp流文件的,所以需要借助其他方案来播放实时视频,下面介绍下我采用的webrtc方案,实测可行。 二、webrtc-streamer是什么? webrtc-streamer是一个使用简单机制通过 WebRTC 流式传输视频捕获…

多并发发短信处理(头条项目-07)

1 pipeline操作 Redis数据库 Redis 的 C/S 架构&#xff1a; 基于客户端-服务端模型以及请求/响应协议的 TCP服务。客户端向服务端发送⼀个查询请求&#xff0c;并监听Socket返回。通常是以 阻塞模式&#xff0c;等待服务端响应。服务端处理命令&#xff0c;并将结果返回给客…

【网络协议】动态路由协议

前言 本文将概述动态路由协议&#xff0c;定义其概念&#xff0c;并了解其与静态路由的区别。同时将讨论动态路由协议相较于静态路由的优势&#xff0c;学习动态路由协议的不同类别以及无类别&#xff08;classless&#xff09;和有类别&#xff08;classful&#xff09;的特性…

c#集成npoi根据excel模板导出excel

NuGet中安装npoi 创建excel模板&#xff0c;替换其中的内容生成新的excel文件。 例子中主要写了这四种情况&#xff1a; 1、替换单个单元格内容&#xff1b; 2、替换横向多个单元格&#xff1b; 3、替换表格&#xff1b; 4、单元格中插入图片&#xff1b; using System.IO; …

人工智能知识分享第十天-机器学习_聚类算法

聚类算法 1 聚类算法简介 1.1 聚类算法介绍 一种典型的无监督学习算法&#xff0c;主要用于将相似的样本自动归到一个类别中。 目的是将数据集中的对象分成多个簇&#xff08;Cluster&#xff09;&#xff0c;使得同一簇内的对象相似度较高&#xff0c;而不同簇之间的对象相…

B树及其Java实现详解

文章目录 B树及其Java实现详解一、引言二、B树的结构与性质1、节点结构2、性质 三、B树的操作1、插入操作1.1、插入过程 2、删除操作2.1、删除过程 3、搜索操作 四、B树的Java实现1、节点类实现2、B树类实现 五、使用示例六、总结 B树及其Java实现详解 一、引言 B树是一种多路…

本地缓存:Guava Cache

这里写目录标题 一、范例二、应用场景三、加载1、CacheLoader2、Callable3、显式插入 四、过期策略1、基于容量的过期策略2、基于时间的过期策略3、基于引用的过期策略 五、显示清除六、移除监听器六、清理什么时候发生七、刷新八、支持更新锁定能力 一、范例 LoadingCache<…

【高录用 | 快见刊 | 快检索】第十届社会科学与经济发展国际学术会议 (ICSSED 2025)

第十届社会科学与经济发展国际学术会议(ICSSED 2025)定于2025年2月28日-3月2日在中国上海隆重举行。会议主要围绕社会科学与经济发展等研究领域展开讨论。会议旨在为从事社会科学与经济发展研究的专家学者提供一个共享科研成果和前沿技术&#xff0c;了解学术发展趋势&#xff…

[ComfyUI]接入Google的Whisk,巨物融合玩法介绍

一、介紹​ 前段时间&#xff0c;谷歌推出了一个图像生成工具whisk&#xff0c;有一个很好玩的图片融合玩法&#xff0c;分别提供三张图片,就可以任何组合来生成图片。​ ​ 最近我发现有人开发了对应的ComfyUI插件&#xff0c;对whisk做了支持&#xff0c;就来体验了下&#…

模式识别与机器学习

文章目录 考试题型零、简介1.自学内容(1)机器学习(2)机器学习和统计学中常见的流程(3)导数 vs 梯度(4)KL散度(5)凸优化问题 2.基本概念3.典型的机器学习系统4.前沿研究方向举例 一、逻辑回归1.线性回归2.逻辑回归3.随堂练习 二、贝叶斯学习基础1.贝叶斯公式2.贝叶斯决策3.分类器…

nginx负载均衡-基于端口的负载均衡(一)

注意&#xff1a; (1) 做负载均衡技术至少需要三台服务器&#xff1a;一台独立的负载均衡器&#xff0c;两台web服务器做集群 一、nginx分别代理后端web1 和 web2的三台虚拟主机 1、web1&#xff08;nginx-10.0.0.7&#xff09;配置基于端口的虚拟主机 [rootOldboy extra]# …