How to Fine-Tune BERT for Text Classification
论文《How to Fine-Tune BERT for Text Classification?》是2019年发表的一篇论文。这篇文章做了一些实验来分析了如何在文本分类场景下微调BERT,是后面网上讨论如何微调BERT时经常提到的论文。
结论与思路
先来看一下论文的实验结论:
- BERT模型上面的层对于文本分类任务更有用;
- 选取合适的逐层递减的学习率,Bert可以克服灾难性遗忘问题(catastrophic forgetting problem);
- 任务内(Within-task)和领域内(in-domain)继续预训练(further pre-training) 可以显著提高模型的性能;
- 在单任务微调之前先进行多任务微调(multi-task fine-tuning)对于单任务微调有帮助,但是其好处没有继续预训练大;
- BERT可以改进小数据量的任务。
接下来看论文是如何来微调BERT模型的,论文从如下三种方法中来找最合适的微调方法,因此将微调BERT的方式有三种如上图所示。
- 微调策略(Fine-Tuning Strategies):在微调BERT时如何利用BERT,比如BERT的哪些层对目标任务更有效。如何选择优化算法和学习率?
- 继续预训练(Further Pre-training):BERT是在通用领域的语料上训练的,通用领域的数据分布与目标领域很可能不同,所以很自然的想法是在目标领域语料上继续训练BERT。
- 多任务微调(Multi-task Fine-Tuning):多任务学习已经表现出可以有效利用不同任务之间的共享知识,如果目标领域有多个不同任务,那么在这些任务上同时微调BERT可以带来好处吗?
论文实验设置
- 使用的模型为base BERT模型:uncased BERT-base模型和Chineses BERT-base模型
- 使用的数据集的统计信息如下图所示,一共有8个数据集。
- 情感分析(Sentiment analysis):使用了二分类电影评论数据集IMDb,Yelp评论数据集的二分类和五分类版本。
- 问题分类(Question classification):六分类版本的TREC数据集,Yahoo! Answers数据集。
- 话题分类(Topic classification):AG’s News数据集,DBPedia数据集。从SogouCA和SogouCS新闻语料集构建了一个中文话题分类数据集:通过URL来决定话题类别,比如""http://sports.sohu.com"对应"sports"类别;一共选取了"sports",“house”,“business”,“entertainment”,“women”,"technology"共6个类别,每个类别的训练集样本为9000测试集为1000。
- 数据预处理:遵循BERT论文中的词汇表和分词方式:30,000个token词汇表和用 ##来分割word的WordPiece embedding。数据集中文档长度的统计时基于word piece的。 对于BERT的继续训练,对英文数据集使用spaCy进行句子分割,对中文数据集使用“。”,“?”,“!”来进行句子分割。
- 超参数
- 继续预训练在1个TITAN Xp GPU上进行,batch size为32,最大训练长度为128,学习率时5e-5,训练步数为100,000,warm-up步数为10,000。
- 微调在4个TITAN Xp GPU上进行,为确保显存被充分利用batch size为24,dropout概率为0.1。Adam优化器的 β 1 = 0.9 \beta_1=0.9 β1=0.9且 β 2 = 0.999 \beta_2=0.999 β2=0.999。使用slanted triangular learning rates,基础学习率为2e-5,warm-up比例为0.1。根据经验将最大训练epoch设为4,将在验证集上效果最好的模型保存下来用于测试。
微调策略及实验
将BERT应用到目标任务时,需要考虑几个因素:
- BERT的最大序列长度时512,所以在使用BERT时先要对长文本进行预处理。
- BERT-base模型包括一个embedding层,12个encoder层,一个pooling层。在使用时需要选择对文本分类任务最有效的层。
- 过拟合问题,如何选择合适的学习率防止BERT在目标任务上过拟合。
BERT模型里更低的层包含更通用的信息,所以论文作者考虑对不同的层使用不同的学习率。将BERT模型的参数 θ \theta θ表示成 { θ 1 , ⋯ , θ L } \{\theta^1, \cdots, \theta^L \} {θ1,⋯,θL},其中 θ l \theta^l θl是BERT的第 l l l层的参数,则微调时每一层的参数更新可表示为如下:
θ t l = θ t − 1 l − η l ⋅ ∇ θ l J ( θ ) \theta^l_t = \theta^l_{t-1} - \eta^l \cdot \nabla_{\theta^l} J(\theta) θtl=θt−1l−ηl⋅∇θlJ(θ)
上式中 η l \eta^l ηl是BERT的第 l l l层的学习率。将基准学习率设置为 η L \eta^L ηL,并使用 η k − 1 = ξ ⋅ η k \eta^{k-1}=\xi \cdot \eta^k ηk−1=ξ⋅ηk表示各层学习率之间的关系; ξ \xi ξ是衰减因子,它小于等于1。当 ξ = 1 \xi = 1 ξ=1时,所有层的学习率都是一样的,也就相当于普通的SGD了。
BERT的最大序列长度时512,所以在使用BERT时先要对长文本进行预处理。考虑如下方法来处理长文本:
- 裁剪方法(Truncation methods),因为一篇文章的主要信息通常在其开始和结束部位,所以使用了如下三种不同的方法的来裁剪文本。
- head-only:保留文本前510个token(512-[CLS]-[SEP])
- tail-only:保留文本最后510个token
- head+tail:按经验选择前128个token以及最后382个token
- 层次方法(Hierarchical methods): 设文本的长度为L,将文本划分为 k=L/510 个片段,将它们输入BERT得到k个文本片段的表征向量。每个片段的表征向量取的是最后一层的[CLS]token的隐状态向量。然后使用mean pooling,max pooling, self-attention来组合这些片段的表征向量。
在IMDb和Sogou数据集上的实验表明 head+tail裁剪方法表现最好,所以在论文后面的实验中都使用这种方法来处理长文本。
论文试验了使用BERT不同的层捕捉文本的特征,微调模型并记录模型的测试错误率如下图所示。BERT最后一层微调后的性能最好。
灾难性遗忘是指在迁移学习过程中,学习新知识时预训练的知识被消除掉了。论文作者使用不同的学习率来微调BERT,在IMDb上的错误率的学习曲线如下图所示。实验表明一个较低的学习率比如2e-5对于BERT克服灾难性遗忘是必要的。在比较激进的学习率如4e-4训练集难以收敛。
下图是不同的基准学习率和衰减因子在IMDb数据集上的表现,逐层递减的学习率比固定学习率在微调BERT时表现要好,一个合适的选择是 ξ = 0.95 \xi=0.95 ξ=0.95和 l r = 2.0 e − 5 lr=2.0e-5 lr=2.0e−5。
继续预训练及实验
因为BERT模型是在通用领域的语料上训练的,对于特定领域的文本分类任务比如电影评论,其数据分布可能与BERT不一样。所以可以在领域相关的数据上继续预训练模型,论文进行了三种继续预训练的方法:
- 任务内的继续预训练(Within-Task Further Pre-Training),在目标任务的训练数据上继续预训练BERT。
- 领域内的继续预训练(In-Domain Further pre-training),训练数据是从目标任务相同领域来获取的。比如几个不同的情感分类任务,它们有类似的数据分布,在这些任务的组合训练数据上来继续预训练BERT。
- 跨领域继续预训练(Cross-Domain Further pre-training),包括与目标任务相同领域以及其他领域的训练数据。
任务内的继续预训练:作者试验了不同的训练步数来继续预训练模型,再用之前得到的最好的微调策略来微调模型。如下图所示继续预训练有助于提高BERT的性能,再100K个训练步后得到最佳性能。
领域内与跨领域继续预训练:将7个英文数据集划分为3个领域:情感,话题,问题,这个划分不是严格正确的,所以作者也将每个数据集当作不同的领域进行了实验,结果如下图所示。
- 领域内继续预训练总体而言比任务内继续预训练可以带来更好的效果。在句子级别的小数据集TREC上,任务内继续预训练有害于模型效果,而在Yah.A语料上的领域继续预训练后得到了更好的效果。
- 跨领域继续预训练(下面图中的标记为"all"的行)总体而言没有带来明显的好处。因为BERT已经在通用领域训练过了。
- IMDb和Yelp在情感领域内没有给互相带来性能提升。可能因为它们分别是关于电影和食物的,数据分布可能有明显差别。
将微调后的模型与其他文本分类模型的比较如下图所示,BERT-Feat是指用BERT来进行特征提取之后,将特征作为biLSTM+self-attention的输入embedding。BERT-FiT是直接微调BERT得到的模型,BERT-ITPT-FiT是任务内继续预训练模型,BERT-IDPT-FiT是领域内继续预训练后微调的模型(对应于上图的’all sentiment’, ‘all question’,‘all topic’),BERT-CDPT-FiT对应跨领域继续预训练后微调的模型(对应于上图的"all"一行)
- BERT-Feat 比除ULMFiT之外的模型效果都要好。
- BERT-FiT只比BERT-Feat在数据集DBpedia上差一点点,其余数据集上效果都更好。
- 三个继续预训练模型微调之后的效果都比BERT-Fit模型更好。
- BERT-IDPT-FiT即领域内继续预训练再微调的效果是最好的。
此外作者评估了BERT-FiT和BERT-ITPT-FiT在不同样本数量的训练集上微调训练的效果,在IMDb的训练数据里选了一个子集来微调模型,结果如下图,实验表明BERT在小数据集也可以带来显著效果提升,继续预训练BERT可以进一步提升效果。
多任务微调及实验
多任务学习可以从不同的监督学习任务共享知识,所有任务共享BERT层和embedding层,每个任务有自己的分类层。
论文在四个英文数据集(IMDb, Yelp P., AG, DBP)上进行多任务微调,先对四个任务一起微调训练,再使用一个更低的学习率在每个数据集上额外进行微调训练。实验结果如下图,结果表明多任务微调对结果有提升,但是跨领域继续微调模型的多任务微调在数据集Yelp P.和AG.上没有效果,作者推测跨领域继续微调和多任务学习微调可能是可互相替代的方法,因为跨领域继续微调模型已经学习到了丰富的领域相关的信息,多任务学习就不会提高文本分类子任务的泛化性了。