数据蒸馏:Dataset Distillation by Matching Training Trajectories 论文翻译和理解

一、TL;DR

  1. 数据集蒸馏的任务是合成一个较小的数据集,使得在该合成数据集上训练的模型能够达到在完整数据集上训练的模型相同的测试准确率,号称优于coreset的选择方法
  2. 本文中,对于给定的网络,我们在蒸馏数据上对其进行几次迭代训练,预先计算并存储在真实数据集上训练的专家网络的训练轨迹,并根据合成训练参数与在真实数据上训练的参数之间的距离来优化蒸馏数据。
  3. 有一个问题哈,这种蒸馏方法强依赖GT,如果新增数据优化模型,没有GT可能还是只能使用coreset的方法来做

二、方法介绍

数据蒸馏的目标是从大型训练数据集中提取知识,将其浓缩到一个非常小的合成训练图像集合中(每个类别低至一张图像),以便在蒸馏数据上训练模型能够获得与在原始数据集上训练相似的测试性能,如下图所示:

与经典的数据压缩不同,数据集蒸馏旨在保留足够的任务相关的信息,以便在小的合成数据集上训练的模型能够泛化到未见过的测试数据,如图2所示。因此,蒸馏算法必须在大量压缩信息的同时,保留区分性特征

之前的方法的问题:

  1. 大多数先前的方法都集中在小型数据集(如MNIST和CIFAR)上,而在真实、更高分辨率的图像上却难以取得进展
  2. 一些方法考虑了端到端的训练,但往往需要巨大的计算和内存资源,并且存在近似松弛或训练不稳定性的问题
  3. 另外一些方法专注于短期行为,强制在蒸馏数据上进行单次训练步骤以匹配真实数据上的训练步骤,在评估中可能会累积误差。

在本工作中:

  1. 提出了一种新的数据集蒸馏方法,不仅在性能上超越了以前的工作,而且适用于大规模数据集,如图1所示。
  2. 本文方法试图直接模仿在真实数据集上训练的网络的长期训练动态;
  3. 我们将合成数据上训练的参数轨迹段与从在真实数据上训练的模型记录的专家轨迹段进行匹配,从而避免了短视(即,专注于单个步骤)或难以优化(即,建模完整轨迹)的问题
  4. 将真实数据集视为引导网络训练动态的黄金标准,我们可以认为诱导的网络参数序列是一个专家轨迹。如果我们的蒸馏数据集能够诱导网络的训练动态遵循这些专家轨迹那么合成训练的网络将在参数空间中接近于在真实数据上训练的模型,并实现类似的测试性能

在我们的方法中,我们的损失函数直接鼓励蒸馏数据集引导网络优化沿着类似的轨迹进行(图3)。

训练流程:

  1. 从头开始在真实数据集上训练一组模型,并记录它们的专家训练轨迹。
  2. 从随机选择的专家轨迹中随机选择一个时间步来初始化一个新模型,并在合成数据集上进行几次迭代训练。
  3. 我们根据这个合成训练的网络与专家轨迹的偏离程度来惩罚蒸馏数据,并通过训练迭代进行反向传播。本质上,我们将许多专家训练轨迹的知识转移到了蒸馏图像上。

实验结果:

  1. 轻松超越了现有的数据集蒸馏方法以及核心集选择方法,在标准数据集上表现优异,包括CIFAR-10、CIFAR-100和Tiny ImageNet。
  2. CIFAR-10上,我们使用每个类别一张图像时达到了46.3%,每个类别50张图像时达到了71.5%的准确率
  3. 首次能够从ImageNet中蒸馏出高128×128分辨率的图像

三、近期工作(直接翻译)

3.1 数据集蒸馏

数据集蒸馏最早由Wang等人[44]提出,他们提出将模型权重表示为蒸馏图像的函数,并使用基于梯度的超参数优化方法对其进行优化[23],这种方法在元学习研究中也得到了广泛应用[8, 27]。随后,通过学习软标签[2, 38]、通过梯度匹配放大学习信号[47]、采用数据增强[45]以及针对无限宽度核极限进行优化[25, 26],一些工作显著提高了结果。数据集蒸馏已经实现了多种应用,包括持续学习[44, 45, 47]、高效的神经架构搜索[45, 47]、联邦学习[11, 37, 50]以及针对图像、文本和医学影像数据的隐私保护机器学习[22, 37]。正如引言中提到的,我们的方法不依赖于单步行为匹配[45, 47]、成本高昂的完整优化轨迹展开[38, 44]或大规模神经切线核计算[25, 26]。相反,我们的方法通过从预训练的专家中转移知识来实现长期轨迹匹配。

与我们的工作同时进行的,Zhao和Bilen[46]的方法完全忽略了优化步骤,而是专注于合成数据和真实数据之间的分布匹配。尽管这种方法由于降低了内存需求而适用于更高分辨率的数据集(例如Tiny ImageNet),但在大多数情况下,其性能表现不如以往的工作(例如,与之前的作品[45, 47]相比)。相比之下,我们的方法在标准基准测试和更高分辨率数据集上同时降低了内存成本,同时超越了现有作品[45, 47]和同时进行的方法[46]。

还有一条相关的研究路线是学习一个生成模型来合成训练数据[24, 36]。然而,这些方法并没有生成一个小尺寸的数据集,因此不能直接与数据集蒸馏方法进行比较。

3.2 模仿学习

模仿学习试图通过观察一系列专家演示来学习一个良好的策略[29, 30, 31]。行为克隆训练学习策略以与专家演示相同的方式行动。一些更复杂的形式涉及使用专家的标记进行在线策略学习[33],而其他方法则完全避免任何标记,例如通过分布匹配[16]。这些方法(特别是行为克隆)已被证明在离线环境中效果良好[9, 12]。我们的方法可以被视为模仿通过在真实数据集上训练获得的一系列专家网络训练轨迹。因此,它可以被视为在优化轨迹上进行模仿学习。

3.3 核心集和实例选择

与数据集蒸馏类似,核心集[1, 4, 13, 34, 41]和实例选择[28]旨在选择整个训练数据集的一个子集,其中在这个小子集上进行训练能够获得良好的性能。这些方法中的大多数并不适用于现代深度学习,但基于双层优化的新公式在持续学习等应用中已经显示出有希望的结果[3]。与核心集相关,其他研究路线旨在了解哪些训练样本对现代机器学习是“有价值的”,包括测量单个样本的准确性[20]和计算误分类率[39]。事实上,数据集蒸馏是这类想法的推广,因为蒸馏数据不需要是真实的,也不需要来自训练集。

四、方法详细介绍

数据集蒸馏指的是策划一个小的、合成的训练集 Dsyn​,使得在该合成数据上训练的模型在真实测试集上的表现与在大型真实训练集 Dreal​ 上训练的模型相似。本文方法直接模仿真实数据训练的长期行为,将蒸馏数据上的多个训练步骤与真实数据上的更多步骤进行匹配。

3.1 专家轨迹

如何获取在真实数据集上训练的网络的专家轨迹?

方法的核心:

  1. 利用专家轨迹 τ∗ 来指导我们合成数据集的蒸馏。专家轨迹是指在完整的真实数据集上训练神经网络时获得的参数时间序列 {θt∗​}0T​。

如何生成这些专家轨迹?

  1. 我们简单地在真实数据集上训练大量网络,每个模型不同epoch组成一条expert trajectory。作者称这些参数序列为“expert trajectory”,因为它们代表了数据集蒸馏任务的理论上限:在完整的真实数据集上训练的网络的性能。
  2. 同样,我们定义学生参数 θ^t​ 为在训练步骤 t 时在合成图像上训练的网络参数。我们的目标是蒸馏一个数据集,使其诱导出与真实训练集诱导的轨迹(给定相同的起始点)相似的轨迹,从而使我们最终得到一个类似的模型

由于这些专家轨迹仅使用真实数据计算,因此我们可以在蒸馏之前预先计算它们。对于给定数据集的所有实验,我们都使用相同的预先计算的专家轨迹集合,这使得蒸馏和实验能够快速进行。

3.2 长期参数匹配

本文方法通过鼓励蒸馏数据集诱导与真实数据集相似的长期网络参数轨迹,从而使得在合成数据上训练的网络表现类似于在真实数据上训练的网络。

我们的蒸馏过程从构成我们expert trajectories中的参数序列 {θt∗​}0T​ 中学习。与以往工作不同,我们的方法直接鼓励我们合成数据集诱导的长期训练动态与在真实数据上训练的网络的动态相匹配。

在每个蒸馏步骤中,我们首先从我们的专家轨迹之一中随机时间步采样参数 θt∗​,并用这些参数初始化我们的学生参数 θ^t​:=θt∗​。在初始化我们的学生网络后,我们接着对合成数据的分类损失进行 N 次梯度下降更新,更新学生参数:

其中A是可微分增强操作,α是个可学习的学习率。然后计算更新后的学生参数和expert trajectory的模型参数的匹配损失,根据权重匹配损失更新我们的蒸馏图像,即更新后学生参数 θ^t+N​ 与已知未来的专家参数 θt+M∗​ 之间的归一化平方 L2 误差:

通过将反向传播通过学生网络的所有 N 次更新来最小化这个目标,更新我们蒸馏数据集的像素,以及我们的可训练学习率 α。可训练学习率 α 的优化起到了自动调整学生和专家更新次数(超参数 M 和 N)的作用。我们使用带有动量的随机梯度下降(SGD)来优化 Dsyn​ 和 α,以达到上述目标。整体如下所示:

3.3 内存限制

本文如何减少内存消耗?

原式是这样进行梯度更新的,由于Dataset太大,因此可以将一式转化为三式

我们可以为学生网络的每次更新(即算法 1 第 10 行的内循环)采样一个新的小批量 b,这样在计算最终权重匹配损失(方程 2)时,所有的蒸馏图像都将被看到。小批量 b 仍然包含来自不同类别的图像,但每个类别的图像数量要少得多。在这种情况下,我们的学生网络更新变为

这种分批方法允许我们在确保同一类别蒸馏图像之间存在一定程度的异质性的同时,蒸馏出一个更大的合成数据集。

五、实验

对于 CIFAR-10,这些蒸馏图像可以在图 4 中看到。CIFAR-100 的图像在补充材料中进行了可视化。

如表 1 所示,我们的方法在每种设置中都显著优于所有基线。事实上,在每个类别一张图像的设置中,我们在两个数据集上都将次优方法(DSA [45])的测试准确率几乎提高了一倍。

在表 2 中,我们还与最近的方法 KIP [25, 26] 进行了比较:

正如之前的方法 [44] 所指出的,我们还发现在合成数据集中允许更多图像时,收益会显著减少

  1. 例如,在 CIFAR-10 上,当我们将每个类别的图像数量从 1 增加到 10 时,分类准确率从 46.3% 提高到 65.3%,
  2. 但当我们将每个类别的蒸馏图像数量从 10 增加到 50 时,仅从 65.3% 提高到 71.5%。

如果我们查看图 4(顶部)中每个类别一张图像的可视化,我们会看到每个类别的非常抽象但仍然可以识别的表示。当我们只允许每个类别有一张合成图像时,优化被迫将尽可能多的类别区分信息压缩到一个样本中。当我们允许更多图像来分散类别的信息时,优化有自由度将类别的区分特征分散到多个样本中,从而产生我们在图 4(底部)中看到的多样化的一组结构化图像(例如,不同类型的汽车和马,具有不同的姿势)。

跨架构泛化。我们还在 CIFAR-10、每个类别一张图像的任务上评估了我们的合成数据在与用于蒸馏它的架构不同的架构上的表现。在表 3 中,我们展示了我们的基线 ConvNet 性能,并在 ResNet 、VGG 和 AlexNet 上进行了评估。

表明我们的方法对架构的变化具有鲁棒性

4.2 短期匹配与长期匹配

非常短期的匹配(N=1 且 M 较小)通常比长期匹配表现更差,当 N 和 M 都相对较大时,达到最佳性能

对于这两种方法,我们测试它们使用蒸馏数据从相同的初始参数训练网络到目标参数的接近程度。DSA 仅针对短期行为进行优化,因此在更长时间的训练过程中可能会累积误差。实际上,随着 Δt 变得更大,DSA 在更长距离上无法模仿真实数据的行为。相比之下,我们的方法针对长期匹配进行了优化,因此表现更好。

六、总结

在这项工作中,我们介绍了一种数据集蒸馏算法,通过直接优化合成数据来诱导与真实数据相似的网络训练动态。我们的方法与以往方法的主要区别在于,我们既不受限于短期单步匹配,也不受优化整个训练过程的不稳定性以及计算强度的影响。我们的方法在这两个方面都取得了平衡,并且在这两方面都显示出改进。与以往方法不同,我们的方法首次扩展到128×128的ImageNet图像

局限性。我们使用预先计算的轨迹,虽然节省了大量内存,但以增加磁盘存储和专家模型训练的计算成本为代价。训练和存储专家轨迹的计算开销相当高。例如,CIFAR专家每个epoch大约需要3秒(所有200个CIFAR专家总共需要8个GPU小时),而每个ImageNet(子集)专家每个epoch大约需要11秒(所有100个ImageNet专家总共需要15个GPU小时)。在存储方面,每个CIFAR专家大约占用60MB的存储空间,而每个ImageNet专家大约占用120MB。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/74650.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【spring cloud Netflix】Ribbon组件

1.基本概念 SpringCloud Ribbon是基于Netflix Ribbon 实现的一套客户端负载均衡的工具。简单的说,Ribbon 是 Netflix 发布的开源项目,主要功能是提供客户端的软件负载均衡算法,将 Netflix 的中间层服务连接在一 起。Ribbon 的客户端组件提供…

P1036 [NOIP 2002 普及组] 选数(DFS)

题目描述 已知 n 个整数 x1​,x2​,⋯,xn​&#xff0c;以及 1 个整数 k&#xff08;k<n&#xff09;。从 n 个整数中任选 k 个整数相加&#xff0c;可分别得到一系列的和。例如当 n4&#xff0c;k3&#xff0c;4 个整数分别为 3,7,12,19 时&#xff0c;可得全部的组合与它…

在响应式网页的开发中使用固定布局、流式布局、弹性布局哪种更好

一、首先看下固定布局与流体布局的区别 &#xff08;一&#xff09;固定布局 固定布局的网页有一个固定宽度的容器&#xff0c;内部组件宽度可以是固定像素值或百分比。其容器元素不会移动&#xff0c;无论访客屏幕分辨率如何&#xff0c;看到的网页宽度都相同。现代网页设计…

二分查找与二叉树中序遍历——面试算法

目录 二分查找与分治 循环方式 递归方式 元素中有重复的二分查找 基于二分查找的拓展问题 山脉数组的顶峰索引——局部有序 旋转数字中的最小数字 找缺失数字 优化平方根 中序与搜索树 二叉搜索树中搜索特定值 验证二叉搜索树 有序数组转化为二叉搜索树 寻找两个…

字符串——面试考察高频算法题

目录 转换成小写字母 字符串转化为整数 反转相关的问题 反转字符串 k个一组反转 仅仅反转字母 反转字符串里的单词 验证回文串 判断是否互为字符重排 最长公共前缀 字符串压缩问题 转换成小写字母 给你一个字符串 s &#xff0c;将该字符串中的大写字母转换成相同的…

现代复古电影海报品牌徽标设计衬线英文字体安装包 Thick – Retro Vintage Cinematic Font

Thick 是一种大胆的复古字体&#xff0c;专为有影响力的标题和怀旧的视觉效果而设计。其厚实的字体、复古魅力和电影风格使其成为电影海报、产品标签、活动品牌和编辑设计的理想选择。无论您是在引导电影的黄金时代&#xff0c;还是在现代布局中注入复古活力&#xff0c;Thick …

[C++面试] new、delete相关面试点

一、入门 1、说说new与malloc的基本用途 int* p1 (int*)malloc(sizeof(int)); // C风格 int* p2 new int(10); // C风格&#xff0c;初始化为10 new 是 C 中的运算符&#xff0c;用于在堆上动态分配内存并调用对象的构造函数&#xff0c;会自动计算所需内存…

Unity URP管线与HDRP管线对比

1. 渲染架构与底层技术 URP 渲染路径&#xff1a; 前向渲染&#xff08;Forward&#xff09;&#xff1a;默认单Pass前向&#xff0c;支持少量实时光源&#xff08;通常4-8个逐物体&#xff09;。 延迟渲染&#xff08;Deferred&#xff09;&#xff1a;可选但功能简化&#…

JDK8卸载与安装教程(超详细)

JDK8卸载与安装教程&#xff08;超详细&#xff09; 最近学习一个项目&#xff0c;需要使用更高级的JDK&#xff0c;这里记录一下卸载旧版本与安装新版本JDK的过程。 JDK8卸载 以windows10操作系统为例&#xff0c;使用快捷键winR输入cmd&#xff0c;打开控制台窗口&#xf…

python爬虫:DrissionPage实战教程

如果本文章看不懂可以看看上一篇文章&#xff0c;加强自己的基础&#xff1a;爬虫自动化工具&#xff1a;DrissionPage-CSDN博客 案例解析&#xff1a; 前提&#xff1a;我们以ChromiumPage为主&#xff0c;写代码工具使用Pycharm&#xff08;python环境3.9-3.10&#xff09; …

07-01-自考数据结构(20331)- 排序-内部排序知识点

内部排序算法是数据结构核心内容,主要包括插入类(直接插入、希尔)、交换类(冒泡、快速)、选择类(简单选择、堆)、归并和基数五大类排序方法。 知识拓扑 知识点介绍 直接插入排序 定义:将每个待排序元素插入到已排序序列的适当位置 算法步骤: 从第二个元素开始遍历…

Go语言-初学者日记(八):构建、部署与 Docker 化

&#x1f9f1; 一、go build&#xff1a;最基础的构建方式 Go 的构建工具链是出了名的轻量、简洁&#xff0c;直接用 go build 就能把项目编译成二进制文件。 ✅ 构建当前项目 go build -o myapp-o myapp 指定输出文件名默认会构建当前目录下的 main.go 或 package main &a…

教程:如何使用 JSON 合并脚本

目录 1. 介绍 2. 使用方法 3. 注意事项 4. 示例 5.完整代码 1. 介绍 该脚本用于将多个 COCO 格式的 JSON 标注文件合并为一个 JSON 文件。COCO 格式常用于目标检测和图像分割任务&#xff0c;包含以下三个主要部分&#xff1a; "images"&#xff1a;图像信息&a…

Java学习总结-缓冲流性能分析

测试用例&#xff1a; 分别使用原始的字节流&#xff0c;以及字节缓冲流复制一个很大的视频。 测试步骤&#xff1a; 在这个分析性能需要一个记录时间的工具&#xff1a;这个是记录1970-1-1 00&#xff1a;00&#xff1a;00到现在的总毫秒值。 long start System.currentT…

流影---开源网络流量分析平台(五)(成果展示)

目录 前沿 攻击过程 前沿 前四章我们已经成功安装了流影的各个功能&#xff0c;那么接下来我们就看看这个开源工具的实力&#xff0c;本实验将进行多个攻击手段&#xff08;ip扫描&#xff0c;端口扫描&#xff0c;sql注入&#xff09;攻击靶机&#xff0c;来看看流影的态感效…

vs环境中编译osg以及osgQt

1、下载 OpenSceneGraph 获取源代码 您可以通过以下方式获取 OSG 源代码: 官网下载:https://github.com/openscenegraph/OpenSceneGraph/releases 使用 git 克隆: git clone https://github.com/openscenegraph/OpenSceneGraph.git 2、下载必要的第三方依赖库 依赖库 ht…

Unity:标签(tags)

为什么需要Tags&#xff1f; 在游戏开发中&#xff0c;游戏对象&#xff08;GameObject&#xff09;数量可能非常多&#xff0c;比如玩家、敌人、子弹等。开发者需要一种简单的方法来区分这些对象&#xff0c;并根据它们的类型执行不同的逻辑。 核心需求&#xff1a; 分类和管…

【C++11】lambda

lambda lambda表达式语法 lambda表达式本质是一个匿名函数对象&#xff0c;跟普通函数不同的是它可以定义在函数内部。lambda表达式语法使用层而言没有类型&#xff0c;所以一般是用auto或者模板参数定义的对象去接收lambda对象。 lambda表达式的格式&#xff1a;[capture-l…

fpga:分秒计时器

任务目标 分秒计数器核心功能&#xff1a;实现从00:00到59:59的循环计数&#xff0c;通过四个七段数码管显示分钟和秒。 复位功能&#xff1a;支持硬件复位&#xff0c;将计数器归零并显示00:00。 启动/暂停控制&#xff1a;通过按键控制计时的启动和暂停。 消抖处理&#…

《UNIX网络编程卷1:套接字联网API》第6章 IO复用:select和poll函数

《UNIX网络编程卷1&#xff1a;套接字联网API》第6章 I/O复用&#xff1a;select和poll函数 6.1 I/O复用的核心价值与适用场景 I/O复用是高并发网络编程的基石&#xff0c;允许单个进程/线程同时监控多个文件描述符&#xff08;套接字&#xff09;的状态变化&#xff0c;从而高…