文本分类场景下微调BERT

How to Fine-Tune BERT for Text Classification

论文《How to Fine-Tune BERT for Text Classification?》是2019年发表的一篇论文。这篇文章做了一些实验来分析了如何在文本分类场景下微调BERT,是后面网上讨论如何微调BERT时经常提到的论文。

结论与思路

先来看一下论文的实验结论:

  1. BERT模型上面的层对于文本分类任务更有用;
  2. 选取合适的逐层递减的学习率,Bert可以克服灾难性遗忘问题(catastrophic forgetting problem);
  3. 任务内(Within-task)和领域内(in-domain)继续预训练(further pre-training) 可以显著提高模型的性能;
  4. 在单任务微调之前先进行多任务微调(multi-task fine-tuning)对于单任务微调有帮助,但是其好处没有继续预训练大;
  5. BERT可以改进小数据量的任务。

接下来看论文是如何来微调BERT模型的,论文从如下三种方法中来找最合适的微调方法,因此将微调BERT的方式有三种如上图所示。

  • 微调策略(Fine-Tuning Strategies):在微调BERT时如何利用BERT,比如BERT的哪些层对目标任务更有效。如何选择优化算法和学习率?
  • 继续预训练(Further Pre-training):BERT是在通用领域的语料上训练的,通用领域的数据分布与目标领域很可能不同,所以很自然的想法是在目标领域语料上继续训练BERT。
  • 多任务微调(Multi-task Fine-Tuning):多任务学习已经表现出可以有效利用不同任务之间的共享知识,如果目标领域有多个不同任务,那么在这些任务上同时微调BERT可以带来好处吗?

论文实验设置

  • 使用的模型为base BERT模型:uncased BERT-base模型和Chineses BERT-base模型
  • 使用的数据集的统计信息如下图所示,一共有8个数据集。
    • 情感分析(Sentiment analysis):使用了二分类电影评论数据集IMDb,Yelp评论数据集的二分类和五分类版本。
    • 问题分类(Question classification):六分类版本的TREC数据集,Yahoo! Answers数据集。
    • 话题分类(Topic classification):AG’s News数据集,DBPedia数据集。从SogouCA和SogouCS新闻语料集构建了一个中文话题分类数据集:通过URL来决定话题类别,比如""http://sports.sohu.com"对应"sports"类别;一共选取了"sports",“house”,“business”,“entertainment”,“women”,"technology"共6个类别,每个类别的训练集样本为9000测试集为1000。
  • 数据预处理:遵循BERT论文中的词汇表和分词方式:30,000个token词汇表和用 ##来分割word的WordPiece embedding。数据集中文档长度的统计时基于word piece的。 对于BERT的继续训练,对英文数据集使用spaCy进行句子分割,对中文数据集使用“。”,“?”,“!”来进行句子分割。
  • 超参数
    • 继续预训练在1个TITAN Xp GPU上进行,batch size为32,最大训练长度为128,学习率时5e-5,训练步数为100,000,warm-up步数为10,000。
    • 微调在4个TITAN Xp GPU上进行,为确保显存被充分利用batch size为24,dropout概率为0.1。Adam优化器的 β 1 = 0.9 \beta_1=0.9 β1=0.9 β 2 = 0.999 \beta_2=0.999 β2=0.999。使用slanted triangular learning rates,基础学习率为2e-5,warm-up比例为0.1。根据经验将最大训练epoch设为4,将在验证集上效果最好的模型保存下来用于测试。

在这里插入图片描述

微调策略及实验

将BERT应用到目标任务时,需要考虑几个因素:

  • BERT的最大序列长度时512,所以在使用BERT时先要对长文本进行预处理。
  • BERT-base模型包括一个embedding层,12个encoder层,一个pooling层。在使用时需要选择对文本分类任务最有效的层。
  • 过拟合问题,如何选择合适的学习率防止BERT在目标任务上过拟合。

BERT模型里更低的层包含更通用的信息,所以论文作者考虑对不同的层使用不同的学习率。将BERT模型的参数 θ \theta θ表示成 { θ 1 , ⋯ , θ L } \{\theta^1, \cdots, \theta^L \} {θ1,,θL},其中 θ l \theta^l θl是BERT的第 l l l层的参数,则微调时每一层的参数更新可表示为如下:

θ t l = θ t − 1 l − η l ⋅ ∇ θ l J ( θ ) \theta^l_t = \theta^l_{t-1} - \eta^l \cdot \nabla_{\theta^l} J(\theta) θtl=θt1lηlθlJ(θ)

上式中 η l \eta^l ηl是BERT的第 l l l层的学习率。将基准学习率设置为 η L \eta^L ηL,并使用 η k − 1 = ξ ⋅ η k \eta^{k-1}=\xi \cdot \eta^k ηk1=ξηk表示各层学习率之间的关系; ξ \xi ξ是衰减因子,它小于等于1。当 ξ = 1 \xi = 1 ξ=1时,所有层的学习率都是一样的,也就相当于普通的SGD了。


BERT的最大序列长度时512,所以在使用BERT时先要对长文本进行预处理。考虑如下方法来处理长文本:

  • 裁剪方法(Truncation methods),因为一篇文章的主要信息通常在其开始和结束部位,所以使用了如下三种不同的方法的来裁剪文本。
    • head-only:保留文本前510个token(512-[CLS]-[SEP])
    • tail-only:保留文本最后510个token
    • head+tail:按经验选择前128个token以及最后382个token
  • 层次方法(Hierarchical methods): 设文本的长度为L,将文本划分为 k=L/510 个片段,将它们输入BERT得到k个文本片段的表征向量。每个片段的表征向量取的是最后一层的[CLS]token的隐状态向量。然后使用mean pooling,max pooling, self-attention来组合这些片段的表征向量。
    在IMDb和Sogou数据集上的实验表明 head+tail裁剪方法表现最好,所以在论文后面的实验中都使用这种方法来处理长文本。
    在这里插入图片描述

论文试验了使用BERT不同的层捕捉文本的特征,微调模型并记录模型的测试错误率如下图所示。BERT最后一层微调后的性能最好。
在这里插入图片描述

灾难性遗忘是指在迁移学习过程中,学习新知识时预训练的知识被消除掉了。论文作者使用不同的学习率来微调BERT,在IMDb上的错误率的学习曲线如下图所示。实验表明一个较低的学习率比如2e-5对于BERT克服灾难性遗忘是必要的。在比较激进的学习率如4e-4训练集难以收敛。
在这里插入图片描述

下图是不同的基准学习率和衰减因子在IMDb数据集上的表现,逐层递减的学习率比固定学习率在微调BERT时表现要好,一个合适的选择是 ξ = 0.95 \xi=0.95 ξ=0.95 l r = 2.0 e − 5 lr=2.0e-5 lr=2.0e5
在这里插入图片描述

继续预训练及实验

因为BERT模型是在通用领域的语料上训练的,对于特定领域的文本分类任务比如电影评论,其数据分布可能与BERT不一样。所以可以在领域相关的数据上继续预训练模型,论文进行了三种继续预训练的方法:

  • 任务内的继续预训练(Within-Task Further Pre-Training),在目标任务的训练数据上继续预训练BERT。
  • 领域内的继续预训练(In-Domain Further pre-training),训练数据是从目标任务相同领域来获取的。比如几个不同的情感分类任务,它们有类似的数据分布,在这些任务的组合训练数据上来继续预训练BERT。
  • 跨领域继续预训练(Cross-Domain Further pre-training),包括与目标任务相同领域以及其他领域的训练数据。

任务内的继续预训练:作者试验了不同的训练步数来继续预训练模型,再用之前得到的最好的微调策略来微调模型。如下图所示继续预训练有助于提高BERT的性能,再100K个训练步后得到最佳性能。
在这里插入图片描述

领域内与跨领域继续预训练:将7个英文数据集划分为3个领域:情感,话题,问题,这个划分不是严格正确的,所以作者也将每个数据集当作不同的领域进行了实验,结果如下图所示。

  • 领域内继续预训练总体而言比任务内继续预训练可以带来更好的效果。在句子级别的小数据集TREC上,任务内继续预训练有害于模型效果,而在Yah.A语料上的领域继续预训练后得到了更好的效果。
  • 跨领域继续预训练(下面图中的标记为"all"的行)总体而言没有带来明显的好处。因为BERT已经在通用领域训练过了。
  • IMDb和Yelp在情感领域内没有给互相带来性能提升。可能因为它们分别是关于电影和食物的,数据分布可能有明显差别。
    在这里插入图片描述

将微调后的模型与其他文本分类模型的比较如下图所示,BERT-Feat是指用BERT来进行特征提取之后,将特征作为biLSTM+self-attention的输入embedding。BERT-FiT是直接微调BERT得到的模型,BERT-ITPT-FiT是任务内继续预训练模型,BERT-IDPT-FiT是领域内继续预训练后微调的模型(对应于上图的’all sentiment’, ‘all question’,‘all topic’),BERT-CDPT-FiT对应跨领域继续预训练后微调的模型(对应于上图的"all"一行)

  • BERT-Feat 比除ULMFiT之外的模型效果都要好。
  • BERT-FiT只比BERT-Feat在数据集DBpedia上差一点点,其余数据集上效果都更好。
  • 三个继续预训练模型微调之后的效果都比BERT-Fit模型更好。
  • BERT-IDPT-FiT即领域内继续预训练再微调的效果是最好的。
    在这里插入图片描述

此外作者评估了BERT-FiT和BERT-ITPT-FiT在不同样本数量的训练集上微调训练的效果,在IMDb的训练数据里选了一个子集来微调模型,结果如下图,实验表明BERT在小数据集也可以带来显著效果提升,继续预训练BERT可以进一步提升效果。
在这里插入图片描述

多任务微调及实验

多任务学习可以从不同的监督学习任务共享知识,所有任务共享BERT层和embedding层,每个任务有自己的分类层。
论文在四个英文数据集(IMDb, Yelp P., AG, DBP)上进行多任务微调,先对四个任务一起微调训练,再使用一个更低的学习率在每个数据集上额外进行微调训练。实验结果如下图,结果表明多任务微调对结果有提升,但是跨领域继续微调模型的多任务微调在数据集Yelp P.和AG.上没有效果,作者推测跨领域继续微调和多任务学习微调可能是可互相替代的方法,因为跨领域继续微调模型已经学习到了丰富的领域相关的信息,多任务学习就不会提高文本分类子任务的泛化性了。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/53570.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Sqoop 数据迁移

Sqoop 数据迁移 一、Sqoop 概述二、Sqoop 优势三、Sqoop 的架构与工作机制四、Sqoop Import 流程五、Sqoop Export 流程六、Sqoop 安装部署6.1 下载解压6.2 修改 Sqoop 配置文件6.3 配置 Sqoop 环境变量6.4 添加 MySQL 驱动包6.5 测试运行 Sqoop6.5.1 查看Sqoop命令语法6.5.2 测…

【数学建模】2024数学建模国赛经验分享

文章目录 一、关于我二、我的数模历程三、经验总结: 一、关于我 我的CSDN主页:https://gxdxyl.blog.csdn.net/ 2020年7月(大二结束的暑假)开始在CSDN写作: 阿里云博客专家: 接触的领域挺多的&#xff…

【Linux】常用指令(中)(附带基础指令的详细讲解、Linux的一些附加知识)

文章目录 前言1. Linux基础常用指令1.1 通配符 "*"1.2 man指令(重要)1.2.1 man指令的语法 1.3 何为"指令"?(附带知识)1.4 echo指令1.5 cat指令1.6 Linux下一切皆文件!1.6.1 ">" 输出重定向1.6.2…

Qt篇——Qt使用C++获取Windows电脑上所有外接设备的名称、物理端口位置等信息

我之前有发过一篇文章《Qt篇——获取Windows系统上插入的串口设备的物理序号》,文章中主要获取的是插入的USB串口设备的物理序号;而本篇文章则进行拓展,可以获取所有外接设备的相关信息(比如USB摄像头、USB蓝牙、USB网卡、其它一些…

分布式技术概览

文章目录 分布式技术1. 分布式数据库(Distributed Databases)2. 分布式文件系统(Distributed File Systems)3. 分布式哈希表(Distributed Hash Tables, DHTs)4. 分布式缓存(Distributed Caching…

面试必问:Java 类加载过程

java 类加载过程主要分为加载、链接和初始化三个阶段,六个关键步骤:加载、验证、准备、解析、初始化。 加载阶段(Loading) 加载时类加载的第一个过程,在这个阶段,将完成以下三件事情: 1&#…

基于Springboot的鲜花销售网站的设计与实现

项目描述 这是一款基于Springboot的鲜花销售网站的系统 模块描述 鲜花销售系统 1、用户 登录 在线注册 浏览商品 鲜花搜索 订购商品 查询商品详情 水果分类查看 水果加购物车 下单结算 填写收货地址 2、管理员 登录 用户管理 商品管理 订单管理 账户管理 截图

项目经理应该学习pmp还是cspm?

职场竞争激烈,项目管理专业人才在各个行业中的作用越来越凸显出来。在23年之前,我国关于通用项目管理人才的培养更多依赖于国外的PMP认证,缺少自主的认证评价标准和体系。 为了弥补这一空缺,基于国内的项目管理发展需求&#xff…

西门子博途零基础学PLC必会的100个指令

#西门子##PLC##自动化##工业自动化##编程##电工##西门子PLC##工业##制造业##数字化##电气##工程师# 工控人加入PLC工业自动化精英社群 工控人加入PLC工业自动化精英社群

OpenMV——色块追踪

Python知识: 1.给Python的列表赋值: 定义一个元组时就是 元组a (1,2,…) 元组中可以只有一个元素,但是就必须要加一个 “ , ” 如 a (2,) 而列表的定义和元组类似,只是把()换成[]: #那么下面的colour_1 ~ 3属于元组&#xf…

(计算机网络)运输层

一.运输层的作用 运输层:负责将数据统一的交给网络层 实质:进程在通信 TCP(有反馈)UDP(无反馈) 二.复用和分用 三. TCP和UDP的特点和区别 进程号--不是固定的 端口号固定--mysql--3306 端口--通信的终点 …

苹果的“AI茅”之路只走了一半

今年苹果发布会最大的亮点,也许是和华为“撞档”,又或者是替腾讯“发布”新手游,但肯定不是iPhone 16。 9月10日,苹果秋季新品发布会与华为见非凡品牌盛典相继举行,iPhone 16系列也与HUAWEI Mate XT同日发布。 不过&…

传统CV算法——特征匹配算法

Brute-Force蛮力匹配 Brute-Force蛮力匹配是一种简单直接的模式识别方法,经常用于计算机视觉和数字图像处理领域中的特征匹配。该方法通过逐一比较目标图像中的所有特征点与源图像中的特征点来寻找最佳匹配。这种方法的主要步骤包括: 特征提取&#xff…

根据NVeloDocx Word模板引擎生成Word(三)

基于永久免费开放的《E6低代码开发平台》的Word模版引擎NVeloDocx,实现根据Word模版生成Word文件,前面2篇已经非常详细介绍了《主表单字段》,《子表记录循环输入到表格》。那这一篇我们就介绍插入单张图片、二维码,条形码等等&…

python-网页自动化(三)

如果遇到使用 ajax 加载的网页,页面元素可能不是同时加载出来的,这个时候尝试在 get 方法执行完 成时获取网页源代码可能并非浏览器完全加载完成的页面。所以,这种情况下需要设置延时等待一定时间,确保全部节点都加载出来。 那么&…

【Petri网导论学习笔记】Petri网导论入门学习(一)

Petri 网导论 如需学习转载请注明原作者并附本帖链接!!! 如需学习转载请注明原作者并附本帖链接!!! 如需学习转载请注明原作者并附本帖链接!!! 发现网上关于Petri网的学习…

【机器学习】从零开始理解深度学习——揭开神经网络的神秘面纱

1. 引言 随着技术的飞速发展,人工智能(AI)已从学术研究的实验室走向现实应用的舞台,成为推动现代社会变革的核心动力之一。而在这一进程中,深度学习(Deep Learning)因其在大规模数据处理和复杂问题求解中的卓越表现,迅速崛起为人工智能的最前沿技术。深度学习的核心是…

金智维K-RPA基本介绍

一、K-RPA基本组成 K-RPA软件机器人管理系统基于“RPAX”数字化技术打造,其核心系统由管理中心(Server)、设计器(Control)、机器人(Robot/Agent)三大子系统组成,各子系统协同工作,易于构建协同式环境。 管理中心(Server&#xff…

【Linux 运维知识】Linux 编译后的内核镜像大小

Linux 内核镜像的大小取决于多个因素,包括内核的版本、启用的功能、模块的数量以及特定的编译配置。 以下是常见情况下不同内核镜像的大小范围: 1. 标准内核镜像大小 压缩后的内核镜像 (vmlinuz): 压缩后的内核镜像文件,通常位于…

基于boost的共享内存通信demo

文章目录 前言一、共享内存管理二、图像算法服务中的IPC通信流程三、demo实验结果总结 前言 在一个系统比较复杂的时候,将模块独立成单独的进程有助于错误定位以及异常重启恢复,不至于某个模块发生崩溃导致整个系统崩溃。当通信数据量比较大时&#xff…