本节课可视为机器学习系列课程的一个前期攻略,这节课主要对Machine Learning 的框架进行了简单的介绍;并以training data上的loss大小为切入点,介绍了几种常见的在模型训练的过程中容易出现的情况。
课程视频:
Youtube: https://www.youtube.com/watch?v=WeHM2xpYQpw
课程PPT:
https://view.officeapps.live.com/op/view.aspx?src=https%3A%2F%2Fspeech.ee.ntu.edu.tw%2F~hylee%2Fml%2Fml2021-course-data%2Foverfit-v6.pptx&wdOrigin=BROWSELINK
以下是本节课的课程笔记。
一、Framework of ML
机器学习的数据集总体上分为训练集(training data)和测试集(testing data)。其中训练集由feature x和ground truth y组成,模型在训练集上学习x和y之间的隐含关系,再在测试集上对模型的好坏进行验证。
模型在训练集上的training大致可以分为以下三个steps:
Step1:初步划定一个model set:y = f(x),其中模型 f 由系列参数 𝜽 确定,如果𝜽的值不同,我们则说模型不同。
Step2:划定好model set后就需要定义一个loss function 来对模型的好坏进行评估,通常,loss function反映的是模型的预测值和ground truth之间的差距,差距越小(loss值越小),则模型越好。
Step3:定义好loss function后,就开始对模型进行优化,找到让loss指最小的参数集合𝜽*,𝜽* 所对应的model f* 即为我们最终想要学习到的模型。
二、General Guide
在训练模型的过程中,我们往往会根据training data上的loss值来初步判断模型的好坏。
1.training data上loss过大
导致training data上loss值过大的原因主要有以下两个:
(1)model bias
即模型模型太简单(大海捞针,但针不在海里),通常的解决措施是重新设计模型使其具有更大的弹性,例如,在输入中增加更多的feature,或者使用deep learning以增加模型的弹性(more neurons, layers)。
(2)optimization
optimization做得不好,没有找到最优的function(大海捞针,针在海里,但就是没捞到)。例如我们通常使用gradient decent的optimization方法,但这种方法可能会卡在local minimum的地方,从而导致我们没有找到全局最小解。如果是optimization做的不好,我们需要使用更powerful的optimization方法,这在后面的学习中会有介绍。
Q:如何判断训练集上的loss大时由model bias还是optimization引起的?
(参考文献:Deep Residual Learning for Image Recognition)
主要是通过对不同的模型进行比较来判断(判断模型是否足够大)。当我们看到一个从来没有做过的问题,可以先跑一些比较浅的network,甚至一些不属于DL的方法,因为这些方法不太会有optimization失败的问题。如果在训练集上deeper network反而没有得到更小的loss,则可能是optimization出了问题。(注意:过拟合是deeper network在训练集上loss小,在测试集上loss大)
例如,在下图右部分,56-layer的loss值较之20-layer的反而更大,则很可能是opyimization出了问题。
2.training data上loss值较小
如果在training data上的loss值比较小,则可以看看模型在测试集上的表现了。如果测试值上loss值很小,那这正是我们期待的结果。如果很不幸,模型在测试集上loss较大,此时又可大致分为两种情况:
(1)overfitting
overfitting即模型过度地对训练数据进行了拟合,把一些非common feature当做common data学习到了。此事的solution主要有:
A. more training data,即增加更多的训练数据;
B. data augmentation,如果训练数据有限,则可以在原有数据的基础上,通过一些特殊处理,创造一些资料。
C. make your model simpler,常见的举措有:
- less parameters/ sharing parameters (让一些model共用参数)
- early stopping
- regularization
- dropout
(2)mismatch
mismatch则是由于训练资料和测试资料的分布不一致导致的,这个时候增加训练资料也没用。在HW11中会具体讲解这类情况。
三、如何保证选择的model是合理的
如果一个模型只是在训练集上强行将输入x和ground truth y相关联,而没有学习到一些实质性的东西,那么到了测试集上模型的表现将会是很差的。通常的解决措施是引入交叉验证。
1.Cross Validation
在训练时,我们从测试集中划出部分数据作为validation set来衡量loss,根据validation set上的得分情况去挑选最优的模型,再在测试集上对模型的好坏进行验证。
2.N-fold Cross Validation
如果训练数据较少,可采用N折交叉验证的方法。即将训练数据分为N等份,依次以第一份、第二份……第 i 份作为验证集(其余作为测试集),这样重复N次对模型进行训练、验证。在将这N次训练中各个模型在验证集上的N次得分的平均进行比较,选择loss最小的模型作为我们的最优模型,并用它在测试集上对模型进行评分。