1 介绍
年份:2017
期刊: arXiv preprint
Nguyen C V, Li Y, Bui T D, et al. Variational continual learning[J]. arXiv preprint arXiv:1710.10628, 2017.
本文提出的算法是变分连续学习(Variational Continual Learning, VCL),它是一种基于变分推断的在线学习方法,结合了在线变分推断(VI)和蒙特卡洛VI的最新进展,用于训练深度判别模型和生成模型,以实现在连续学习设置中避免灾难性遗忘并适应新任务的能力。关键步骤包括使用变分推断来近似后验分布,并通过核心集(coreset)数据摘要方法增强模型的记忆能力。本文算法属于基于变分推断的算法,它通过在线更新模型参数的后验分布来实现连续学习,这可以归类为基于正则化的算法,因为它利用KL散度最小化来正则化模型参数,以平衡对新数据的适应性和对旧数据的保留。
2 创新点
- 变分连续学习框架(VCL):
- 提出了一种新的连续学习框架,即变分连续学习(VCL),它结合了在线变分推断(VI)和蒙特卡洛VI,适用于复杂的连续学习环境。
- 深度模型的连续学习:
- 将VCL框架应用于深度判别模型和深度生成模型,展示了该框架在这些复杂神经网络模型中的有效性。
- 核心集(coreset)数据摘要:
- 引入了核心集的概念,这是一种小型的代表性数据集,用于保留先前任务的关键信息,帮助算法在新任务学习中避免遗忘旧任务。
- 自动和无参数的连续学习:
- VCL框架避免了传统方法中需要手动调整的超参数,实现了完全自动化的学习过程,且无需额外的验证集来调整参数。
- 实验结果的优越性:
- 在多个任务上的实验结果显示,VCL在避免灾难性遗忘方面优于现有的连续学习方法,且不需要调整任何超参数。
- 理论基础和扩展性:
- 基于贝叶斯推断的理论基础,VCL提供了一种原则性强、可扩展的解决方案,可以应用于多种不同的模型和学习场景。
- 适用于复杂任务演化:
- VCL能够处理任务随时间演变以及全新任务出现的情况,这对于现实世界中任务不断变化的场景具有重要意义。
3 算法
3.1 算法原理
- 贝叶斯推断框架:
- 贝叶斯推断提供了一个自然框架来处理连续学习问题。它通过保留模型参数的分布来表示参数的不确定性,这有助于在新数据到来时更新知识,同时保留旧知识。
- 在线变分推断(Online VI):
- 在线VI是一种近似贝叶斯推断的方法,它通过迭代更新近似后验分布来处理新数据。VCL利用在线VI来递归地更新模型参数的后验分布。
- 变分连续学习(VCL):
- VCL通过最小化KL散度(Kullback-Leibler divergence)来找到最佳近似后验分布。具体来说,对于每一步新数据的到来,VCL通过结合之前的后验分布和新数据的似然函数,然后通过变分推断找到新的近似后验分布。
- 核心集(Coreset):
- 为了缓解连续学习中累积的近似误差,VCL引入了核心集的概念。核心集是从先前任务中提取的代表性数据点集合,用于在训练过程中刷新模型对旧任务的记忆。
- 递归更新:
- VCL递归地更新模型参数的近似后验分布。给定前一步的后验分布和新数据,VCL通过乘以似然函数并重新归一化来获得新的后验分布。
- 预测和参数更新:
- 在测试时,VCL使用最终的变分分布来进行预测。在训练时,VCL通过最大化变分下界(variational lower bound)来更新变分参数,这涉及到计算期望对数似然和KL散度。
- 蒙特卡洛方法:
- 为了处理期望对数似然的计算,VCL采用蒙特卡洛方法来近似这些期望值,这通常涉及到使用重参数化技巧(reparameterization trick)来计算梯度。
3.2 算法步骤
- 初始化:选择一个先验分布 p ( θ ) p(\theta) p(θ)并初始化变分近似 q 0 ( θ ) = p ( θ ) q_0(\theta) = p(\theta) q0(θ)=p(θ)。
- 核心集初始化:初始化核心集 C 0 = ∅ C_0 = \emptyset C0=∅。
- 对于每一个新任务 t = 1 , 2 , … , T t = 1, 2, \ldots, T t=1,2,…,T执行以下步骤:a. 观察新数据集 D t D_t Dt。b. 更新核心集 C t C_t Ct,使用 C t − 1 C_{t-1} Ct−1和 D t D_t Dt来选择新的代表性数据点。c. 更新非核心集数据点的变分分布:
q ~ t ( θ ) = arg min q ∈ Q K L ( q ( θ ) ∥ q ~ t − 1 ( θ ) p ( D t ∪ C t − 1 ∖ C t ∣ θ ) Z ) \tilde{q}_t(\theta) = \arg\min_{q \in Q} KL \left( q(\theta) \parallel \frac{\tilde{q}_{t-1}(\theta) p(D_t \cup C_{t-1} \setminus C_t | \theta)}{Z} \right) q~t(θ)=argq∈QminKL(q(θ)∥Zq~t−1(θ)p(Dt∪Ct−1∖Ct∣θ))
其中, Z Z Z是归一化常数。
d. 计算最终的变分分布(仅用于预测):
q t ( θ ) = arg min q ∈ Q K L ( q ( θ ) ∥ q ~ t ( θ ) p ( C t ∣ θ ) Z ) q_t(\theta) = \arg\min_{q \in Q} KL \left( q(\theta) \parallel \frac{\tilde{q}_t(\theta) p(C_t | \theta)}{Z} \right) qt(θ)=argq∈QminKL(q(θ)∥Zq~t(θ)p(Ct∣θ))
e. 进行预测:在测试输入 x ∗ x^* x∗上,使用 q t ( θ ) q_t(\theta) qt(θ)来计算预测分布:
p ( y ∗ ∣ x ∗ , D 1 : t ) = ∫ q t ( θ ) p ( y ∗ ∣ θ , x ∗ ) d θ p(y^* | x^*, D_{1:t}) = \int q_t(\theta) p(y^* | \theta, x^*) d\theta p(y∗∣x∗,D1:t)=∫qt(θ)p(y∗∣θ,x∗)dθ
4 实验分析
图1展示了论文中测试的多头网络架构,包括判别模型(a)和生成模型(b),其中判别模型中低层网络参数θS在多个任务中共享,每个任务t有自己的“头部网络”θtH,映射到共同隐藏层的输出;生成模型中头部网络生成来自潜在变量z的中间层表示。
图6展示了在训练后各个任务生成器生成的图像,其中每列代表特定任务生成器的输出,每行显示所有训练任务生成器的结果,明显地,简单直接的在线学习方法遭受了灾难性遗忘,而其他方法(如VCL)成功地记住了之前的任务。实验结论是,与简单在线学习相比,VCL等方法在连续学习环境中能更好地保留对先前任务的记忆,避免了灾难性遗忘,展现出更好的长期记忆性能。
5 思考
(1)代码举例理解本文算法
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
from torch.nn.functional import softmax# 假设我们有一个简单的神经网络模型
class SimpleNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size)self.fc2 = nn.Linear(hidden_size, output_size)def forward(self, x):x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 变分连续学习算法的实现
def variational_continual_learning(model, prior_mu, prior_sigma, tasks_num, lr=0.001):optimizer = optim.Adam(model.parameters(), lr=lr)for t in range(tasks_num):# 加载当前任务的数据datasets, labels = data_loader(t)# 遍历当前任务的数据进行训练for data, label in zip(datasets, labels):# 前向传播output = model(data)log_likelihood = softmax(output, dim=1).gather(1, label.unsqueeze(1)).squeeze(1).log()# 计算损失函数,包括负对数似然和KL散度loss = -log_likelihood + kl_divergence(model.fc2.weight, model.fc2.bias, prior_mu, prior_sigma)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()return modeldef kl_divergence(weights, biases, prior_mu, prior_sigma):# 计算权重和偏置的KL散度posterior_mu = weightsposterior_sigma = torch.nn.functional.softplus(biases) + 1e-6 # 防止sigma为0# KL散度计算公式kl_w = 0.5 * (torch.log(prior_sigma) - torch.log(posterior_sigma) + posterior_sigma**2 + (posterior_mu - prior_mu)**2 / posterior_sigma**2 - 1)kl_b = 0.5 * (torch.log(prior_sigma) - torch.log(posterior_sigma) + posterior_sigma**2 - 1)return kl_w.sum() + kl_b.sum()# 假设我们有一个数据加载器,用于加载连续的任务
def data_loader(task_id):# 这里只是一个示例,实际中需要根据task_id加载不同的数据# 返回当前任务的数据和标签pass# 初始化模型
input_size = 784 # 例如MNIST数据集
hidden_size = 100
output_size = 10 # 假设有10个类别
model = SimpleNN(input_size, hidden_size, output_size)# 设置先验分布的均值和标准差
prior_mu = torch.zeros(output_size)
prior_sigma = torch.ones(output_size)# 执行变分连续学习算法
tasks_num = 5 # 假设有5个连续的任务
trained_model = variational_continual_learning(model, prior_mu, prior_sigma, tasks_num)