【连续学习之VCL算法】2017年论文:Variational continual learning

1 介绍

年份:2017

期刊: arXiv preprint

Nguyen C V, Li Y, Bui T D, et al. Variational continual learning[J]. arXiv preprint arXiv:1710.10628, 2017.

本文提出的算法是变分连续学习(Variational Continual Learning, VCL),它是一种基于变分推断的在线学习方法,结合了在线变分推断(VI)和蒙特卡洛VI的最新进展,用于训练深度判别模型和生成模型,以实现在连续学习设置中避免灾难性遗忘并适应新任务的能力。关键步骤包括使用变分推断来近似后验分布,并通过核心集(coreset)数据摘要方法增强模型的记忆能力。本文算法属于基于变分推断的算法,它通过在线更新模型参数的后验分布来实现连续学习,这可以归类为基于正则化的算法,因为它利用KL散度最小化来正则化模型参数,以平衡对新数据的适应性和对旧数据的保留。

2 创新点

  1. 变分连续学习框架(VCL)
    • 提出了一种新的连续学习框架,即变分连续学习(VCL),它结合了在线变分推断(VI)和蒙特卡洛VI,适用于复杂的连续学习环境。
  2. 深度模型的连续学习
    • 将VCL框架应用于深度判别模型和深度生成模型,展示了该框架在这些复杂神经网络模型中的有效性。
  3. 核心集(coreset)数据摘要
    • 引入了核心集的概念,这是一种小型的代表性数据集,用于保留先前任务的关键信息,帮助算法在新任务学习中避免遗忘旧任务。
  4. 自动和无参数的连续学习
    • VCL框架避免了传统方法中需要手动调整的超参数,实现了完全自动化的学习过程,且无需额外的验证集来调整参数。
  5. 实验结果的优越性
    • 在多个任务上的实验结果显示,VCL在避免灾难性遗忘方面优于现有的连续学习方法,且不需要调整任何超参数。
  6. 理论基础和扩展性
    • 基于贝叶斯推断的理论基础,VCL提供了一种原则性强、可扩展的解决方案,可以应用于多种不同的模型和学习场景。
  7. 适用于复杂任务演化
    • VCL能够处理任务随时间演变以及全新任务出现的情况,这对于现实世界中任务不断变化的场景具有重要意义。

3 算法

3.1 算法原理

  1. 贝叶斯推断框架
    • 贝叶斯推断提供了一个自然框架来处理连续学习问题。它通过保留模型参数的分布来表示参数的不确定性,这有助于在新数据到来时更新知识,同时保留旧知识。
  2. 在线变分推断(Online VI)
    • 在线VI是一种近似贝叶斯推断的方法,它通过迭代更新近似后验分布来处理新数据。VCL利用在线VI来递归地更新模型参数的后验分布。
  3. 变分连续学习(VCL)
    • VCL通过最小化KL散度(Kullback-Leibler divergence)来找到最佳近似后验分布。具体来说,对于每一步新数据的到来,VCL通过结合之前的后验分布和新数据的似然函数,然后通过变分推断找到新的近似后验分布。
  4. 核心集(Coreset)
    • 为了缓解连续学习中累积的近似误差,VCL引入了核心集的概念。核心集是从先前任务中提取的代表性数据点集合,用于在训练过程中刷新模型对旧任务的记忆。
  5. 递归更新
    • VCL递归地更新模型参数的近似后验分布。给定前一步的后验分布和新数据,VCL通过乘以似然函数并重新归一化来获得新的后验分布。
  6. 预测和参数更新
    • 在测试时,VCL使用最终的变分分布来进行预测。在训练时,VCL通过最大化变分下界(variational lower bound)来更新变分参数,这涉及到计算期望对数似然和KL散度。
  7. 蒙特卡洛方法
    • 为了处理期望对数似然的计算,VCL采用蒙特卡洛方法来近似这些期望值,这通常涉及到使用重参数化技巧(reparameterization trick)来计算梯度。

3.2 算法步骤

  1. 初始化:选择一个先验分布 p ( θ ) p(\theta) p(θ)并初始化变分近似 q 0 ( θ ) = p ( θ ) q_0(\theta) = p(\theta) q0(θ)=p(θ)
  2. 核心集初始化:初始化核心集 C 0 = ∅ C_0 = \emptyset C0=
  3. 对于每一个新任务 t = 1 , 2 , … , T t = 1, 2, \ldots, T t=1,2,,T执行以下步骤:a. 观察新数据集 D t D_t Dt。b. 更新核心集 C t C_t Ct,使用 C t − 1 C_{t-1} Ct1 D t D_t Dt来选择新的代表性数据点。c. 更新非核心集数据点的变分分布:

q ~ t ( θ ) = arg ⁡ min ⁡ q ∈ Q K L ( q ( θ ) ∥ q ~ t − 1 ( θ ) p ( D t ∪ C t − 1 ∖ C t ∣ θ ) Z ) \tilde{q}_t(\theta) = \arg\min_{q \in Q} KL \left( q(\theta) \parallel \frac{\tilde{q}_{t-1}(\theta) p(D_t \cup C_{t-1} \setminus C_t | \theta)}{Z} \right) q~t(θ)=argqQminKL(q(θ)Zq~t1(θ)p(DtCt1Ctθ))

其中, Z Z Z是归一化常数。

d. 计算最终的变分分布(仅用于预测):

q t ( θ ) = arg ⁡ min ⁡ q ∈ Q K L ( q ( θ ) ∥ q ~ t ( θ ) p ( C t ∣ θ ) Z ) q_t(\theta) = \arg\min_{q \in Q} KL \left( q(\theta) \parallel \frac{\tilde{q}_t(\theta) p(C_t | \theta)}{Z} \right) qt(θ)=argqQminKL(q(θ)Zq~t(θ)p(Ctθ))

e. 进行预测:在测试输入 x ∗ x^* x上,使用 q t ( θ ) q_t(\theta) qt(θ)来计算预测分布:

p ( y ∗ ∣ x ∗ , D 1 : t ) = ∫ q t ( θ ) p ( y ∗ ∣ θ , x ∗ ) d θ p(y^* | x^*, D_{1:t}) = \int q_t(\theta) p(y^* | \theta, x^*) d\theta p(yx,D1:t)=qt(θ)p(yθ,x)dθ

4 实验分析

图1展示了论文中测试的多头网络架构,包括判别模型(a)和生成模型(b),其中判别模型中低层网络参数θS在多个任务中共享,每个任务t有自己的“头部网络”θtH,映射到共同隐藏层的输出;生成模型中头部网络生成来自潜在变量z的中间层表示。

图6展示了在训练后各个任务生成器生成的图像,其中每列代表特定任务生成器的输出,每行显示所有训练任务生成器的结果,明显地,简单直接的在线学习方法遭受了灾难性遗忘,而其他方法(如VCL)成功地记住了之前的任务。实验结论是,与简单在线学习相比,VCL等方法在连续学习环境中能更好地保留对先前任务的记忆,避免了灾难性遗忘,展现出更好的长期记忆性能。

5 思考

(1)代码举例理解本文算法

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
from torch.nn.functional import softmax# 假设我们有一个简单的神经网络模型
class SimpleNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.fc2 = nn.Linear(hidden_size, output_size)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 变分连续学习算法的实现
def variational_continual_learning(model, prior_mu, prior_sigma, tasks_num, lr=0.001):optimizer = optim.Adam(model.parameters(), lr=lr)for t in range(tasks_num):# 加载当前任务的数据datasets, labels = data_loader(t)# 遍历当前任务的数据进行训练for data, label in zip(datasets, labels):# 前向传播output = model(data)log_likelihood = softmax(output, dim=1).gather(1, label.unsqueeze(1)).squeeze(1).log()# 计算损失函数,包括负对数似然和KL散度loss = -log_likelihood + kl_divergence(model.fc2.weight, model.fc2.bias, prior_mu, prior_sigma)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()return modeldef kl_divergence(weights, biases, prior_mu, prior_sigma):# 计算权重和偏置的KL散度posterior_mu = weightsposterior_sigma = torch.nn.functional.softplus(biases) + 1e-6  # 防止sigma为0# KL散度计算公式kl_w = 0.5 * (torch.log(prior_sigma) - torch.log(posterior_sigma) + posterior_sigma**2 + (posterior_mu - prior_mu)**2 / posterior_sigma**2 - 1)kl_b = 0.5 * (torch.log(prior_sigma) - torch.log(posterior_sigma) + posterior_sigma**2 - 1)return kl_w.sum() + kl_b.sum()# 假设我们有一个数据加载器,用于加载连续的任务
def data_loader(task_id):# 这里只是一个示例,实际中需要根据task_id加载不同的数据# 返回当前任务的数据和标签pass# 初始化模型
input_size = 784  # 例如MNIST数据集
hidden_size = 100
output_size = 10  # 假设有10个类别
model = SimpleNN(input_size, hidden_size, output_size)# 设置先验分布的均值和标准差
prior_mu = torch.zeros(output_size)
prior_sigma = torch.ones(output_size)# 执行变分连续学习算法
tasks_num = 5  # 假设有5个连续的任务
trained_model = variational_continual_learning(model, prior_mu, prior_sigma, tasks_num)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/64204.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

多视图 (Multi-view) 与多模态 (Multi-modal)

多视图 (Multi-view) 与多模态 (Multi-modal) 是两种不同的数据处理方式,它们在机器学习和数据分析中有着重要的应用。尽管这两者有一些相似之处,但它们关注的角度和处理方法有所不同。 多视图 (Multi-view) 定义:多视图指的是同一数据对象…

【Transformer】深入浅出自注意力机制

写在前面:博主本人也是刚接触计算机视觉领域不久,本篇文章是为了记录自己的学习,大家一起学习,有问题欢迎大家指出。(博主本人的习惯是看文章看到不懂的有立马去看不懂的那块,所以博文可能内容比较杂&#…

HarmonyOS NEXT 实战之元服务:静态案例效果---教育培训服务

背景: 前几篇学习了元服务,后面几期就让我们开发简单的元服务吧,里面丰富的内容大家自己加,本期案例 仅供参考 先上本期效果图 ,里面图片自行替换 效果图1完整代码案例如下: import { authentication } …

互联网视频云平台EasyDSS无人机推流直播技术如何助力野生动植物保护工作?

在当今社会,随着科技的飞速发展,无人机技术已经广泛应用于各个领域,为我们的生活带来了诸多便利。而在动植物保护工作中,无人机的应用更是为这一领域注入了新的活力。EasyDSS,作为一款集视频处理、分发、存储于一体的综…

51c视觉~YOLO~合集8

我自己的原文哦~ https://blog.51cto.com/whaosoft/12897680 1、Yolo9 1.1、YOLOv9SAM实现动态目标检测和分割 主要介绍基于YOLOv9SAM实现动态目标检测和分割 背景介绍 在本文中,我们使用YOLOv9SAM在RF100 Construction-Safety-2 数据集上实现自定义对象检测模…

Docker Container 可观测性最佳实践

Docker Container 介绍 Docker Container( Docker 容器)是一种轻量级、可移植的、自给自足的软件运行环境,它在 Docker 引擎的宿主机上运行。容器在许多方面类似于虚拟机,但它们更轻量,因为它们不需要模拟整个操作系统…

气相色谱-质谱联用分析方法中的常用部件,分流平板更换

分流平板,是气相色谱-质谱联用分析方法中的一个常用部件,它可以实现气相色谱柱流与MS检测器流的分离和分流。常见的气质联用仪分流平板有很多种,如单层T型分流平板、双层T型分流平板、螺旋分流平板等等。 操作视频http://www.spcctech.com/v…

易基因: BS+ChIP-seq揭示DNA甲基化调控非编码RNA(VIM-AS1)抑制肿瘤侵袭性|Exp Mol Med

大家好,这里是专注表观组学十余年,领跑多组学科研服务的易基因。 肝细胞癌(hepatocellular carcinoma,HCC)早期复发仍然是一个具有挑战性的领域,其中涉及的机制尚未完全被理解。尽管微血管侵犯&#xff08…

鸿蒙系统文件管理基础服务的设计背景和设计目标

有一定经验的开发者通常对文件管理相关的api应用或者底层逻辑都比较熟悉,但是关于文件管理服务的设计背景和设计目标可能了解得不那么清楚,本文旨在分享文件管理服务的设计背景及目标,方便广大开发者更好地理解鸿蒙系统文件管理服务。 1 鸿蒙…

Doris 数据库外部表-JDBC 外表,Oracle to Doris

简介 提供了 Doris 通过数据库访问的标准接口 (JDBC) 来访问外部表,外部表省去了繁琐的数据导入工作,让 Doris 可以具有了访问各式数据库的能力,并借助 Doris 本身的 OLAP 的能力来解决外部表的数据分析问题: 支持各种数据源接入…

分布式 IO 模块助力冲压机械臂产线实现智能控制

在当今制造业蓬勃发展的浪潮中,冲压机械臂产线的智能化控制已然成为提升生产效率、保障产品质量以及增强企业竞争力的关键所在。而分布式 IO 模块的应用,正如同为这条产线注入了一股强大的智能动力,开启了全新的高效生产篇章。 传统挑战 冲压…

深度学习中的并行策略概述:4 Tensor Parallelism

深度学习中的并行策略概述:4 Tensor Parallelism 使用 PyTorch 实现 Tensor Parallelism 。首先定义了一个简单的模型 SimpleModel,它包含两个全连接层。然后,本文使用 torch.distributed.device_mesh 初始化了一个设备网格,这代…

企业销售人员培训系统|Java|SSM|VUE| 前后端分离

【技术栈】 1⃣️:架构: B/S、MVC 2⃣️:系统环境:Windowsh/Mac 3⃣️:开发环境:IDEA、JDK1.8、Maven、Mysql5.7 4⃣️:技术栈:Java、Mysql、SSM、Mybatis-Plus、VUE、jquery,html 5⃣️数据库…

自然语言处理与知识图谱的融合与应用

目录 前言1. 知识图谱与自然语言处理的关系1.1 知识图谱的定义与特点1.2 自然语言处理的核心任务1.3 二者的互补性 2. NLP在知识图谱构建中的应用2.1 信息抽取2.1.1 实体识别2.1.2 关系抽取2.1.3 属性抽取 2.2 知识融合2.3 知识推理 3. NLP与知识图谱融合的实际应用3.1 智能问答…

CSS(三)盒子模型

目录 Content Padding Border Margin 盒子模型计算方式 使用 box-sizing 属性控制盒子模型的计算 所有的HTML元素都可以看作像下图这样一个矩形盒子: 这个模型包括了四个区域:content(内容区域)、padding(内边距…

基于NodeMCU的物联网窗帘控制系统设计

最终效果 基于NodeMCU的物联网窗帘控制系统设计 项目介绍 该项目是“物联网实验室监测控制系统设计(仿智能家居)”项目中的“家电控制设计”中的“窗帘控制”子项目,最前者还包括“物联网设计”、“环境监测设计”、“门禁系统设计计”和“小…

有没有免费提取音频的软件?音频编辑软件介绍!

出于工作和生活娱乐等原因,有时候我们需要把音频单独提取出来(比如歌曲伴奏、人声清唱等、乐器独奏等)。要提取音频必须借助音频处理软件,那么有没有免费提取音频的软件呢?下面我们将为大家介绍几款免费软件&#xff0…

【保姆式】python调用api通过机器人发送文件到飞书指定群聊

当前飞书webhook机器人还不支持发送文件类型的群消息,它目前仅支持文本,富文本,卡片等文字类型的数据。 我们可以申请创建一个机器人应用来实现群发送文件消息。 创建飞书应用 创建飞书应用、配置权限、添加机器人 来到飞书开发者后台 创建…

GitLab 服务变更提醒:中国大陆、澳门和香港用户停止提供服务(GitLab 服务停止)

目录 前言 一. 变更详情 1. 停止服务区域 2. 邮件通知 3. 新的服务提供商 4. 关键日期 5. 行动建议 二. 迁移指南 三. 注意事项 四. 相关推荐 前言 近期,许多位于中国大陆、澳门和香港的 GitLab 用户收到了一封来自 GitLab 官方的重要通知。根据这封邮件…

【Agent】AutoGen Studio2.0开源框架-UI层环境安装+详细操作教程(从0到1带跑通智能体AutoGen Studio)

💥 欢迎来到我的博客!很高兴能在这里与您相遇! 首页:GPT-千鑫 – 热爱AI、热爱Python的天选打工人,活到老学到老!!!导航 - 人工智能系列:包含 OpenAI API Key教程, 50个…