TextBrewer:融合并改进了NLP和CV中的多种知识蒸馏技术、提供便捷快速的知识蒸馏框架、提升模型的推理速度,减少内存占用

TextBrewer:融合并改进了NLP和CV中的多种知识蒸馏技术、提供便捷快速的知识蒸馏框架、提升模型的推理速度,减少内存占用

TextBrewer是一个基于PyTorch的、为实现NLP中的知识蒸馏任务而设计的工具包,
融合并改进了NLP和CV中的多种知识蒸馏技术,提供便捷快速的知识蒸馏框架,用于以较低的性能损失压缩神经网络模型的大小,提升模型的推理速度,减少内存占用。

1.简介

TextBrewer 为NLP中的知识蒸馏任务设计,融合了多种知识蒸馏技术,提供方便快捷的知识蒸馏框架。

主要特点:

  • 模型无关:适用于多种模型结构(主要面向Transfomer结构)
  • 方便灵活:可自由组合多种蒸馏方法;可方便增加自定义损失等模块
  • 非侵入式:无需对教师与学生模型本身结构进行修改
  • 支持典型的NLP任务:文本分类、阅读理解、序列标注等

TextBrewer目前支持的知识蒸馏技术有:

  • 软标签与硬标签混合训练
  • 动态损失权重调整与蒸馏温度调整
  • 多种蒸馏损失函数: hidden states MSE, attention-based loss, neuron selectivity transfer, …
  • 任意构建中间层特征匹配方案
  • 多教师知识蒸馏

TextBrewer的主要功能与模块分为3块:

  1. Distillers:进行蒸馏的核心部件,不同的distiller提供不同的蒸馏模式。目前包含GeneralDistiller, MultiTeacherDistiller, MultiTaskDistiller等
  2. Configurations and Presets:训练与蒸馏方法的配置,并提供预定义的蒸馏策略以及多种知识蒸馏损失函数
  3. Utilities:模型参数分析显示等辅助工具

用户需要准备:

  1. 已训练好的教师模型, 待蒸馏的学生模型
  2. 训练数据与必要的实验配置, 即可开始蒸馏

在多个典型NLP任务上,TextBrewer都能取得较好的压缩效果。相关实验见蒸馏效果。

2.TextBrewer结构

2.1 安装要求

  • Python >= 3.6

  • PyTorch >= 1.1.0

  • TensorboardX or Tensorboard

  • NumPy

  • tqdm

  • Transformers >= 2.0 (可选, Transformer相关示例需要用到)

  • Apex == 0.1.0 (可选,用于混合精度训练)

  • 从PyPI自动下载安装包安装:

pip install textbrewer
  • 从源码文件夹安装:
git clone https://github.com/airaria/TextBrewer.git
pip install ./textbrewer

2.2工作流程

  • Stage 1 : 蒸馏之前的准备工作:

    1. 训练教师模型
    2. 定义与初始化学生模型(随机初始化,或载入预训练权重)
    3. 构造蒸馏用数据集的dataloader,训练学生模型用的optimizer和learning rate scheduler
  • Stage 2 : 使用TextBrewer蒸馏:

    1. 构造训练配置(TrainingConfig)和蒸馏配置(DistillationConfig),初始化distiller
    2. 定义adaptorcallback ,分别用于适配模型输入输出和训练过程中的回调
    3. 调用distillertrain方法开始蒸馏

2.3 以蒸馏BERT-base到3层BERT为例展示TextBrewer用法

在开始蒸馏之前准备:

  • 训练好的教师模型teacher_model (BERT-base),待训练学生模型student_model (3-layer BERT)
  • 数据集dataloader,优化器optimizer,学习率调节器类或者构造函数scheduler_class 和构造用的参数字典 scheduler_args

使用TextBrewer蒸馏:

import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig#展示模型参数量的统计
print("\nteacher_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(teacher_model,max_level=3)
print (result)print("student_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(student_model,max_level=3)
print (result)#定义adaptor用于解释模型的输出
def simple_adaptor(batch, model_outputs):# model输出的第二、三个元素分别是logits和hidden statesreturn {'logits': model_outputs[1], 'hidden': model_outputs[2]}#蒸馏与训练配置
# 匹配教师和学生的embedding层;同时匹配教师的第8层和学生的第2层
distill_config = DistillationConfig(intermediate_matches=[    {'layer_T':0, 'layer_S':0, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1},{'layer_T':8, 'layer_S':2, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1}])
train_config = TrainingConfig()#初始化distiller
distiller = GeneralDistiller(train_config=train_config, distill_config = distill_config,model_T = teacher_model, model_S = student_model, adaptor_T = simple_adaptor, adaptor_S = simple_adaptor)#开始蒸馏
with distiller:distiller.train(optimizer, dataloader, num_epochs=1, scheduler_class=scheduler_class, scheduler_args = scheduler_args, callback=None)

2.4蒸馏任务示例

  • Transformers 4示例

    • examples/notebook_examples/sst2.ipynb (英文): SST-2文本分类任务上的BERT模型训练与蒸馏。
    • examples/notebook_examples/msra_ner.ipynb (中文): MSRA NER中文命名实体识别任务上的BERT模型训练与蒸馏。
    • examples/notebook_examples/sqaudv1.1.ipynb (英文): SQuAD 1.1英文阅读理解任务上的BERT模型训练与蒸馏。
  • examples/random_token_example: 一个可运行的简单示例,在文本分类任务上以随机文本为输入,演示TextBrewer用法。

  • examples/cmrc2018_example (中文): CMRC 2018上的中文阅读理解任务蒸馏,并使用DRCD数据集做数据增强。

  • examples/mnli_example (英文): MNLI任务上的英文句对分类任务蒸馏,并展示如何使用多教师蒸馏。

  • examples/conll2003_example (英文): CoNLL-2003英文实体识别任务上的序列标注任务蒸馏。

  • examples/msra_ner_example (中文): MSRA NER(中文命名实体识别)任务上,使用分布式数据并行训练的Chinese-ELECTRA-base模型蒸馏。

2.4.1蒸馏效果

我们在多个中英文文本分类、阅读理解、序列标注数据集上进行了蒸馏实验。实验的配置和效果如下。

  • 模型

  • 对于英文任务,教师模型为BERT-base-cased

  • 对于中文任务,教师模型为HFL发布的RoBERTa-wwm-extElectra-base

我们测试了不同的学生模型,为了与已有公开结果相比较,除了BiGRU都是和BERT一样的多层Transformer结构。模型的参数如下表所示。需要注意的是,参数量的统计包括了embedding层,但不包括最终适配各个任务的输出层。

  • 英文模型
Model#LayersHidden sizeFeed-forward size#ParamsRelative size
BERT-base-cased (教师)127683072108M100%
T6 (学生)6768307265M60%
T3 (学生)3768307244M41%
T3-small (学生)3384153617M16%
T4-Tiny (学生)4312120014M13%
T12-nano (学生)12256102417M16%
BiGRU (学生)-768-31M29%
  • 中文模型
Model#LayersHidden sizeFeed-forward size#ParamsRelative size
RoBERTa-wwm-ext (教师)127683072102M100%
Electra-base (教师)127683072102M100%
T3 (学生)3768307238M37%
T3-small (学生)3384153614M14%
T4-Tiny (学生)4312120011M11%
Electra-small (学生)12256102412M12%
  • T6的结构与DistilBERT[1], BERT6-PKD[2], BERT-of-Theseus[3] 相同。
  • T4-tiny的结构与 TinyBERT[4] 相同。
  • T3的结构与BERT3-PKD[2] 相同。

2.4.2 蒸馏配置

distill_config = DistillationConfig(temperature = 8, intermediate_matches = matches)
#其他参数为默认值

不同的模型用的matches我们采用了以下配置:

Modelmatches
BiGRUNone
T6L6_hidden_mse + L6_hidden_smmd
T3L3_hidden_mse + L3_hidden_smmd
T3-smallL3n_hidden_mse + L3_hidden_smmd
T4-TinyL4t_hidden_mse + L4_hidden_smmd
T12-nanosmall_hidden_mse + small_hidden_smmd
Electra-smallsmall_hidden_mse + small_hidden_smmd

各种matches的定义在examples/matches/matches.py中。均使用GeneralDistiller进行蒸馏。

2.4.3训练配置

蒸馏用的学习率 lr=1e-4(除非特殊说明)。训练30~60轮。

2.4.4英文实验结果

在英文实验中,我们使用了如下三个典型数据集。

DatasetTask typeMetrics#Train#DevNote
MNLI文本分类m/mm Acc393K20K句对三分类任务
SQuAD 1.1阅读理解EM/F188K11K篇章片段抽取型阅读理解
CoNLL-2003序列标注F123K6K命名实体识别任务

我们在下面两表中列出了DistilBERT, BERT-PKD, BERT-of-Theseus, TinyBERT 等公开的蒸馏结果,并与我们的结果做对比。

Public results:

Model (public)MNLISQuADCoNLL-2003
DistilBERT (T6)81.6 / 81.178.1 / 86.2-
BERT6-PKD (T6)81.5 / 81.077.1 / 85.3-
BERT-of-Theseus (T6)82.4/ 82.1--
BERT3-PKD (T3)76.7 / 76.3--
TinyBERT (T4-tiny)82.8 / 82.972.7 / 82.1-

Our results:

Model (ours)MNLISQuADCoNLL-2003
BERT-base-cased (教师)83.7 / 84.081.5 / 88.691.1
BiGRU--85.3
T683.5 / 84.080.8 / 88.190.7
T381.8 / 82.776.4 / 84.987.5
T3-small81.3 / 81.772.3 / 81.478.6
T4-tiny82.0 / 82.675.2 / 84.089.1
T12-nano83.2 / 83.979.0 / 86.689.6

说明:

  1. 公开模型的名称后括号内是其等价的模型结构
  2. 蒸馏到T4-tiny的实验中,SQuAD任务上使用了NewsQA作为增强数据;CoNLL-2003上使用了HotpotQA的篇章作为增强数据
  3. 蒸馏到T12-nano的实验中,CoNLL-2003上使用了HotpotQA的篇章作为增强数据

2.4.5中文实验结果

在中文实验中,我们使用了如下典型数据集。

DatasetTask typeMetrics#Train#DevNote
XNLI文本分类Acc393K2.5KMNLI的中文翻译版本,3分类任务
LCQMC文本分类Acc239K8.8K句对二分类任务,判断两个句子的语义是否相同
CMRC 2018阅读理解EM/F110K3.4K篇章片段抽取型阅读理解
DRCD阅读理解EM/F127K3.5K繁体中文篇章片段抽取型阅读理解
MSRA NER序列标注F145K3.4K (测试集)中文命名实体识别

实验结果如下表所示。

ModelXNLILCQMCCMRC 2018DRCD
RoBERTa-wwm-ext (教师)79.989.468.8 / 86.486.5 / 92.5
T378.489.066.4 / 84.278.2 / 86.4
T3-small76.088.158.0 / 79.375.8 / 84.8
T4-tiny76.288.461.8 / 81.877.3 / 86.1
ModelXNLILCQMCCMRC 2018DRCDMSRA NER
Electra-base (教师)77.889.865.6 / 84.786.9 / 92.395.14
Electra-small77.789.366.5 / 84.985.5 / 91.393.48

说明:

  1. 以RoBERTa-wwm-ext为教师模型蒸馏CMRC 2018和DRCD时,不采用学习率衰减
  2. CMRC 2018和DRCD两个任务上蒸馏时他们互作为增强数据
  3. Electra-base的教师模型训练设置参考自Chinese-ELECTRA
  4. Electra-small学生模型采用预训练权重初始化

3.核心概念

3.1Configurations

  • TrainingConfigDistillationConfig:训练和蒸馏相关的配置。

3.2Distillers

Distiller负责执行实际的蒸馏过程。目前实现了以下的distillers:

  • BasicDistiller: 提供单模型单任务蒸馏方式。可用作测试或简单实验。
  • GeneralDistiller (常用): 提供单模型单任务蒸馏方式,并且支持中间层特征匹配,一般情况下推荐使用
  • MultiTeacherDistiller: 多教师蒸馏。将多个(同任务)教师模型蒸馏到一个学生模型上。暂不支持中间层特征匹配
  • MultiTaskDistiller:多任务蒸馏。将多个(不同任务)单任务教师模型蒸馏到一个多任务学生模型。
  • BasicTrainer:用于单个模型的有监督训练,而非蒸馏。可用于训练教师模型

3.3用户定义函数

蒸馏实验中,有两个组件需要由用户提供,分别是callbackadaptor :

3.3.1Callback

回调函数。在每个checkpoint,保存模型后会被distiller调用,并传入当前模型。可以借由回调函数在每个checkpoint评测模型效果。

3.3.2Adaptor

将模型的输入和输出转换为指定的格式,向distiller解释模型的输入和输出,以便distiller根据不同的策略进行不同的计算。在每个训练步,batch和模型的输出model_outputs会作为参数传递给adaptoradaptor负责重新组织这些数据,返回一个字典。

更多细节可参见完整文档中的说明。

4.FAQ

Q: 学生模型该如何初始化?

A: 知识蒸馏本质上是“老师教学生”的过程。在初始化学生模型时,可以采用随机初始化的形式(即完全不包含任何先验知识),也可以载入已训练好的模型权重。例如,从BERT-base模型蒸馏到3层BERT时,可以预先载入RBT3模型权重(中文任务)或BERT的前三层权重(英文任务),然后进一步进行蒸馏,避免了蒸馏过程的“冷启动”问题。我们建议用户在使用时尽量采用已预训练过的学生模型,以充分利用大规模数据预训练所带来的优势。

Q: 如何设置蒸馏的训练参数以达到一个较好的效果?

A: 知识蒸馏的比有标签数据上的训练需要更多的训练轮数与更大的学习率。比如,BERT-base上训练SQuAD一般以lr=3e-5训练3轮左右即可达到较好的效果;而蒸馏时需要以lr=1e-4训练30~50轮。当然具体到各个任务上肯定还有区别,我们的建议仅是基于我们的经验得出的,仅供参考

Q: 我的教师模型和学生模型的输入不同(比如词表不同导致input_ids不兼容),该如何进行蒸馏?

A: 需要分别为教师模型和学生模型提供不同的batch,参见完整文档中的 Feed Different batches to Student and Teacher, Feed Cached Values 章节。

Q: 我缓存了教师模型的输出,它们可以用于加速蒸馏吗?

A: 可以, 参见完整文档中的 Feed Different batches to Student and Teacher, Feed Cached Values 章节。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/31807.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Spring Boot】拦截器与统一功能处理

博主简介:想进大厂的打工人博主主页:xyk:所属专栏: JavaEE进阶 上一篇文章我们讲解了Spring AOP是一个基于面向切面编程的框架,用于将某方面具体问题集中处理,通过代理对象来进行传递,但使用原生Spring AOP实现统一的…

平板选择什么电容笔比较好?ipad手写笔推荐品牌

在现在的生活上,有了iPad平板,一切都变得简单了许多,也让我们的学习以及工作都更加的便利。这其中,电容笔就起到了很大的作用,很多人都不知道,到底要买什么牌子的电容笔?哪些电容笔的性价比比较…

CentOS7 启动谷歌浏览器 java+Selenium+chrome+chromedriver

前言:自己想使用该技术实现自动化抓取音乐,目前在window上运行成功,需要在Linux Centos服务上跑,配置上出现了许多问题,特此记录。 参考文档:CentOS7 安装Seleniumchromechromedriverjava_远方丿的博客-CSD…

后端开发7.轮播图模块【mongdb开发】

概述 轮播图模块数据库采用mongdb开发 效果图 数据库设计 创建数据库 use sc; 添加数据 db.banner.insertMany([ {bannerId:"1",bannerName:"商城轮播图1",bannerUrl:"http://xx:8020/img/轮播图/shop1.png"}, {bannerId:"2"…

深入学习 Redis - 主从结构配置、流程、底层原理(全网最详细)

目录 前言 一、主从模式 1.1、概述 1.2、配置 redis 主从结构 1.2.1、复制配置文件,修改 1.2.2、配置主从结构 1.2.3、启动 redis 服务 1.2.4、查看复制状态 1.3、slaveof 命令 1.3.1、断开主从复制关系 1.3.2、切换主从复制关系 1.3.3、只读 1.3.4、网…

C语言——自定义类型详解[结构体][枚举][联合体]

自定义类型详解 前言:一、结构体1.1结构体的声明1.2结构体内存对齐1.3位段(位域) 二、枚举2.1枚举类型的定义2.2枚举类型的优点2.3枚举的使用 三、联合体3.1联合体类型的定义3.2联合体的特点3.3联合体大小的计算 前言: 我打算把结…

solr迁移到另一个solr中(docker单机)

背景介绍 solr数据迁移,或者版本升级,需要用到迁移,此处记录一下迁移方法以及过程中遇到的问题。我这边使用的是docker环境,非docker部署的应该也是一样的。 solr部署教程 准备工作 ● solrA 版本: 8.11.2 (已有so…

Kafka基本概念

文章目录 概要整体架构broker和集群ProducerConsumer和消费者组小结 概要 Kafka是最初由Linkedin公司开发,是一个分布式、分区的、多副本的、多生产者、多订阅者,基于 zookeeper协调的分布式日志系统(也可以当做MQ系统)&#xff…

最佳实践:Swagger 自动生成 Api 文档

目录 Tapir 介绍 为什么使用 Tapir 快速使用 Tapir 添加依赖 定义一个端点(Endpoint) 生成 Swagger ui 根据 yaml 生成 endpoint 自动生成 API 文档的好处不言而喻,它可以提供给你的团队或者外部协作者,方便 API 使用者准确地调用到你的 API。为了…

list的使用和模拟实现

目录 1.list的介绍及使用 1.1 list的介绍 1.2 list的使用 1.2.1 list的构造 1.2.2 list iterator的使用 1.2.3 list capacity 1.2.4 list element access 1.2.5 list modifiers 2.为什么使用迭代器? 3.list的模拟实现 3.1完整代码 3.2代码解析 4.list与…

YOLOv5-7.0实例分割+TensorRT部署

一:介绍 将YOLOv5结合分割任务并进行TensorRT部署,是一项既具有挑战性又令人兴奋的任务。分割(Segmentation)任务要求模型不仅能够检测出目标的存在,还要精确地理解目标的边界和轮廓,为每个像素分配相应的…

Spring Boot配置文件与日志文件

1. Spring Boot 配置文件 我们知道, 当我们创建一个Spring Boot项目之后, 就已经有了配置文件存在于目录结构中. 1. 配置文件作用 整个项目中所有重要的数据都是在配置文件中配置的,比如: 数据库的连接信息 (包含用户名和密码的设置) ;项目的启动端口;第三方系统的调…

KMP字符串 (简单清晰/Java)

Kmp算法 解决问题: 字符串匹配问题 怎么解决? 前缀表next[]数组 #分析 先看暴力做法: 两层for循环,一层遍历文本串,一层遍历模式串(子串)对应的每个字符进行匹配,匹配成功就 i &a…

如何将jar包部署到宝塔

尝试多种方式上传,但启动一直失败,这种方式亲测是好使的 项目内修改位置 在pom.xml文件中将mysql的scope改成provided,如果是固定的版本号会出现问题 之后就可以打包啦,直接点击maven中的package 找到打包文件的位置&#xff…

免费插件-illustrator-Ai插件-印刷功能-二维码生成

文章目录 1.介绍2.安装3.通过窗口>扩展>知了插件4.功能解释5.示例5.1.QR常用二维码5.2.PDF4175.3.EAN13 6.总结 1.介绍 本文介绍一款免费插件,加强illustrator使用人员工作效率,进行二维码生成。首先从下载网址下载这款插件 https://download.csd…

MySQL之深入InnoDB存储引擎——redo日志

文章目录 一、为什么需要redo日志二、redo日志的类型1)简单的redo日志类型2)复杂的redo日志类型 三、Mini-Transaction四、redo日志的写入过程五、redo日志文件1、刷盘时机2、redo日志文件组 六、log sequence number1、lsn的引入2、flushed_to_disk_lsn…

java 文件/文件夹复制,添加压缩zip

复制文件夹,并压缩成zip 需求:创建A文件夹,把B文件夹复制到A文件夹。然后把A文件夹压缩成zip包 public static void main(String[] args) throws Exception {try {String A "D:\\dev\\program";String B "D:\\program";// 创建临…

Vue 插槽 slot

solt 插槽需要分为 2.6.0 版本以上和 2.6.0版本以下。 2.6.0 版本以下的 slot 插槽在,2.x版本将继续支持,但是在 Vue 3 中已被废弃,且不会出现在官方文档中。 作用 插槽 prop 允许我们将插槽转换为可复用的模板,这些模板可以基于…

Qt应用开发(基础篇)——LCD数值类 QLCDNumber

一、前言 QLCDNumber类继承于QFrame,QFrame继承于QWidget,是Qt的一个基础小部件。 框架类QFrame介绍 QLCDNumber用来显示一个带有类似lcd数字的数字,适用于信号灯、跑步机、体温计、时钟、电表、水表、血压计等仪器类产品的数值显示。 QLCDNu…

【CSS】文本效果

文本溢出、整字换行、换行规则以及书写模式 代码&#xff1a; <style> p.test1 {white-space: nowrap; width: 200px; border: 1px solid #000000;overflow: hidden;text-overflow: clip; }p.test2 {white-space: nowrap; width: 200px; border: 1px solid #000000;ove…