学习视频:火炉课堂 | 元学习(meta-learning)到底是什么鬼?_哔哩哔哩_bilibili
一、从传统机器学习到元学习
我们传统的机器学习,是手工设计一个模型,然后将训练数据投进模型中进行训练,得到一个最优的模型参数,再将测试的数据放进这个最优的模型参数中,得出预测的结果。如下图所示,我们想得到的是一个最优的模型参数F。
但是这个初始化的模型,是经过手工设计的,比如说CNN里面卷积有多少层或者使用什么样的激活函数,这些想要得到最优的状态,就需要一遍一遍的做消融实验来得到一个最优值,那这些能不能也使它自主学习呢?这就是元学习想要达到的效果。通俗来讲,元学习就是希望能够自动学习应该使用什么样的算法。
二、损失优化
深度学习的损失计算和优化是在训练阶段。比如分类网络,在训练的时候,经过模型之后会得到一个预测类别,这个时候与标签对比,判断分类是否正确,如果不正确的话,惩罚的力度就更大,强迫模型去往正确的方向优化。但是元学习首先学习的是算法,那要评判一个算法好不好,就需要看这个算法在验证集上的效果,所以损失是在验证过程中计算,相当于深度学习外面又套了一层训练。整体来说是因为:元学习要得到一个模型,就必须要有训练数据,那么这个时候要想评估这个模型好不好,是不能用训练数据来做评估的,因此只能换一批数据来做评估,这也就是为什么要在验证数据上计算损失。
三、数据集构成
传统机器学习的数据集由两部分构成:训练集和测试集,并且要求训练集和测试集里面的样本类别是保持一致的。
元学习的数据集也包含训练集和测试集两部分,但是要求训练集和测试集里面的类别不同,不存在交叉情况,也就是说测试集里面的数据对于训练集来说都是新的数据。这里我的理解是,元学习是为了学习一个算法适用于多个任务的算法,这个时候就需要算法的泛化能力很强,所以需要新的任务来评估效果。元数据的数据集构成如下:
像前面损失优化部分所说的,首先,一个初始给定的算法,需要一部分训练数据来训练一个模型参数,然后需要一部分验证数据使用这个模型参数对这个算法做评估,继而对算法进行优化,这个时候肯定是需要训练数据和验证数据类别是相同的。但是相对于优化算法这一层来说,这部分还是在训练的过程,我们需要在验证数据上来评估这个算法的效果怎么样,然后对算法进行优化。所以这个时候,相对于优化算法这一层的训练过程,就需要两批数据,一批用来训练一个模型参数,一批用来评估,这两批数据就叫做支撑数据(support set) 和 查询数据(query set)。整体的训练过程如下:
train_supprot是为了先训练一个模型出来,然后使用train_query来做评估,从而更新算法的相关参数,在train里面训练好了一个算法之后,再用test_supprot来训练一个适用于当前任务的模型参数(即针对test里面的数据任务训练一个适用于test数据的模型),test_supprot训练完了之后,再使用test_query里面的数据来进行评估。整体来说,四个部分的数据,train_supprt、train_query、test_supprot里面的数据都是参与训练的,唯独只有test_query里面的数据是做最后的评估,不参与训练。
四、训练部分
元学习的训练成为元训练(mate-training),分为外层训练(训练算法)和内层训练(训练当前算法的模型参数)。元学习的测试,称之为元测试(meta-testing)。
1、元训练
这里ω指的是算法的相关参数,θ指的是当前算法的模型参数。外层训练是为了训练算法参数ω,内层训练是为了训练模型参数θ。主要的步骤是:首先给定一个算法ω,然后使用train-support的数据训练出模型参数θ,拿着这个模型参数θ在train-query数据上做评估,此时经过内层的训练,对于此时的ω肯定是模型θ最优的状态,如果此时在train-query上的评估结果不好,说明此时的ω算法不合适,就需要对ω进行更新。
2、元测试
外层训练结束之后,学习到的ω可以看作是最优的算法,那么此时,我们想要在测试任务中得到一个最优的模型参数,就需要在测试集的支撑集(test-support)数据上对模型参数θ进行更新训练,从而得到最优算法下的最优模型参数θ。
五、元知识
前面一直在说在train_query评估之后,如果效果不好,就更新算法,那么是更新算法的什么东西呢,又有哪些东西是可以更新的呢?这些东西就成为元知识。
如:超参数、初始化参数、特征空间嵌入,模型结构,损失函数等等
六、元学习方法的分类
这部分直接简单的归纳一下,细节可以看原视频。
1、基于优化(optimization-base method)
可优化的部分如下:
方法举例:MAML
2、基于模型(model-based method)
可以直接生成一个模型,不需要先训练一个算法再训练模型
方法举例:
Memory-augmented neural network
Meta networks
CNAPs
3、基于度量(metric-based method)
先提取每类样本的特征,然后拿测试样本的特征与每类样本计算距离,并将其预测为距离最近的那一类,这种方法分类器是没有模型的,但是特征提取的过程可以用cnn来做。
方法举例:
仅简单的记录一下元学习的相关知识,若有不对,欢迎指正!