NLP(六十六)使用HuggingFace中的Trainer进行BERT模型微调

  以往,我们在使用HuggingFace在训练BERT模型时,代码写得比较复杂,涉及到数据处理、token编码、模型编码、模型训练等步骤,从事NLP领域的人都有这种切身感受。事实上,HugggingFace中提供了datasets模块(数据处理)和Trainer函数,使得我们的模型训练较为方便。关于datasets模块,可参考文章NLP(六十二)HuggingFace中的Datasets使用。
  本文将会介绍如何使用HuggingFace中的Trainer对BERT模型微调。

Trainer

  Trainer是HuggingFace中的模型训练函数,其网址为:https://huggingface.co/docs/transformers/main_classes/trainer 。
  Trainer的传入参数如下:

model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module] = None
args: TrainingArguments = None
data_collator: typing.Optional[DataCollator] = None
train_dataset: typing.Optional[torch.utils.data.dataset.Dataset] = None
eval_dataset: typing.Union[torch.utils.data.dataset.Dataset, typing.Dict[str, torch.utils.data.dataset.Dataset], NoneType] = None
tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None
model_init: typing.Union[typing.Callable[[], transformers.modeling_utils.PreTrainedModel], NoneType] = None
compute_metrics: typing.Union[typing.Callable[[transformers.trainer_utils.EvalPrediction], typing.Dict], NoneType] = None
callbacks: typing.Optional[typing.List[transformers.trainer_callback.TrainerCallback]] = None
optimizers: typing.Tuple[torch.optim.optimizer.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None)
preprocess_logits_for_metrics: typing.Union[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor], NoneType] = None )

参数解释:

  • model为预训练模型
  • args为TrainingArguments(训练参数)类
  • data_collator会将数据集中的元素组成一个batch,默认使用default_data_collator(),如果tokenizer没有提供,则使用DataCollatorWithPadding
  • train_dataset, eval_dataset为训练集,验证集
  • tokenizer为模型训练使用的tokenizer
  • model_init为模型初始化
  • compute_metrics为验证集的评估指标计算函数
  • callbacks为训练过程中的callback列表
  • optimizers为模型训练中的优化器
  • preprocess_logits_for_metrics为模型评估阶段前对logits的预处理

  TrainingArguments为训练参数类,其网址为:https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments,传入参数非常多(transformers版本4.32.1中有98个参数!),我们在这里只介绍几个常见的:

output_dir: stroverwrite_output_dir: bool = False
evaluation_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no'
per_gpu_train_batch_size: typing.Optional[int] = None
per_gpu_eval_batch_size: typing.Optional[int] = None
learning_rate: float = 5e-05
num_train_epochs: float = 3.0
logging_dir: typing.Optional[str] = None
logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps'
save_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps'save_steps: float = 500
report_to: typing.Optional[typing.List[str]] = None

参数解释:

  • output_dir为模型输出目录
  • evaluation_strategy为模型评估策略
  1. “no": 不做模型评估
  2. “steps”: 按训练步数(steps)进行评估,需指定步数
  3. “epoch”: 每个epoch训练完后进行评估
  • per_gpu_train_batch_size, per_gpu_eval_batch_size为每个GPU上训练集和测试集的batch size,也有CPU上的对应参数
  • learning_rate为学习率
  • logging_dir为日志输出目录
  • logging_strategy为日志输出策略,同样有no, steps, epoch三种,意义同上
  • save_strategy为模型保存策略,同样有no, steps, epoch三种,意义同上
  • report_to为模型训练、评估中的重要指标(如loss, accurace)输出之处,可选择azure_ml, clearml, codecarbon, comet_ml, dagshub, flyte, mlflow, neptune, tensorboard, wandb,使用all会输出到所有的地方,使用no则不会输出。

  下面我们使用Trainer进行BERT模型微调,给出英语、中文数据集上文本分类的示例代码。

BERT微调

  使用datasets模块导入imdb数据集(英语影评数据集,常用于文本分类),加载预训练模型bert-base-cased的tokenizer。

import numpy as np
from transformers import AutoTokenizer, DataCollatorWithPadding
import datasetscheckpoint = 'bert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
raw_datasets = datasets.load_dataset('imdb')

  查看数据集,有train(训练集)、test(测试集)、unsupervised(非监督)三部分,我们这里使用训练集和测试集,各自有25000个样本。

raw_datasets
DatasetDict({train: Dataset({features: ['text', 'label'],num_rows: 25000})test: Dataset({features: ['text', 'label'],num_rows: 25000})unsupervised: Dataset({features: ['text', 'label'],num_rows: 50000})
})

  创建数据tokenize函数,对文本进行tokenize,最大长度设置为300,同时使用data_collector为DataCollatorWithPadding。

def tokenize_function(sample):return tokenizer(sample['text'], max_length=300, truncation=True)
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

  加载分类模型,输出类别为2.

from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

  设置compute_metrics函数,在评估过程中输出accuracy, f1, precision, recall四个指标。设置训练参数TrainingArguments类,设置Trainer。

from transformers import Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, precision_recall_fscore_supportdef compute_metrics(pred):labels = pred.label_idspreds = pred.predictions.argmax(-1)precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')acc = accuracy_score(labels, preds)return {'accuracy': acc,'f1': f1,'precision': precision,'recall': recall}training_args = TrainingArguments(output_dir='imdb_test_trainer', # 指定输出文件夹,没有会自动创建evaluation_strategy="epoch",per_device_train_batch_size=32,per_device_eval_batch_size=32,learning_rate=5e-5,num_train_epochs=3,warmup_ratio=0.2,logging_dir='./imdb_train_logs',logging_strategy="epoch",save_strategy="epoch",report_to="tensorboard") trainer = Trainer(model,training_args,train_dataset=tokenized_datasets["train"],eval_dataset=tokenized_datasets["test"],data_collator=data_collator,  # 在定义了tokenizer之后,其实这里的data_collator就不用再写了,会自动根据tokenizer创建tokenizer=tokenizer,compute_metrics=compute_metrics
)

  开启模型训练。

trainer.train()
EpochTraining LossValidation LossAccuracyF1PrecisionRecall
10.3643000.2232230.9106000.9105090.9122760.910600
20.1648000.2044200.9239600.9239410.9243750.923960
30.0710000.2413500.9255200.9255100.9257590.925520
TrainOutput(global_step=588, training_loss=0.20003824169132986, metrics={'train_runtime': 1539.8692, 'train_samples_per_second': 48.705, 'train_steps_per_second': 0.382, 'total_flos': 1.156249755e+16, 'train_loss': 0.20003824169132986, 'epoch': 3.0})

  以上为英语数据集的文本分类模型微调。
  中文数据集使用sougou-mini数据集(训练集4000个样本,测试集495个样本,共5个输出类别),预训练模型采用bert-base-chinese。代码基本与英语数据集差不多,只要修改 预训练模型,数据集加载 和 最大长度为128,输出类别。以下是不同的代码之处:

import numpy as np
from transformers import AutoTokenizer, DataCollatorWithPadding
import datasetscheckpoint = 'bert-base-chinese'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)data_files = {"train": "./data/sougou/train.csv", "test": "./data/sougou/test.csv"}
raw_datasets = datasets.load_dataset("csv", data_files=data_files, delimiter=",")
...
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=5)
...

输出结果如下:

EpochTraining LossValidation LossAccuracyF1PrecisionRecall
10.8492000.1151890.9696970.9694490.9700730.969697
20.1069000.0939870.9737370.9737700.9753720.973737
30.0478000.0788610.9737370.9737400.9741170.973737

模型评估

  在上述模型评估过程中,已经有了模型评估的各项指标。
  本文也给出单独做模型评估的代码,方便后续对模型做量化时(后续介绍BERT模型的动态量化)获取量化前后模型推理的各项指标。
  中文数据集文本分类模型评估代码如下:

import torch
from transformers import AutoModelForSequenceClassificationMAX_LENGTH = 128
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = f"./sougou_test_trainer_{MAX_LENGTH}/checkpoint-96"
model = AutoModelForSequenceClassification.from_pretrained(checkpoint).to(device)from transformers import AutoTokenizer, DataCollatorWithPaddingtokenizer = AutoTokenizer.from_pretrained(checkpoint)import pandas as pdtest_df = pd.read_csv("./data/sougou/test.csv")
test_df.head()
textlabel
0届数比赛时间比赛地点参加国家和地区冠军亚军决赛成绩第一届1956-1957英国11美国丹麦6...0
1商品属性材质软橡胶带加浮雕工艺+合金彩色队徽吊牌规格162mm数量这一系列产品不限量发行图案...0
2今天下午,沈阳金德和长春亚泰队将在五里河相遇。在这两支球队中沈阳籍球员居多,因此这场比赛实际...0
3本报讯中国足协准备好了与特鲁西埃谈判的合同文本,也在北京给他预订好了房间,但特鲁西埃爽约了!...0
4网友点击发表评论祝贺中国队夺得五连冠搜狐体育讯北京时间5月6日,2006年尤伯杯羽毛球赛在日...0
import numpy as np
import times_time = time.time()
true_labels, pred_labels = [], [] 
for i, row in test_df.iterrows():row_s_time = time.time()true_labels.append(row["label"])encoded_text = tokenizer(row['text'], max_length=MAX_LENGTH, truncation=True, padding=True, return_tensors='pt').to(device)# print(encoded_text)logits = model(**encoded_text)label_id = np.argmax(logits[0].detach().cpu().numpy(), axis=1)[0]pred_labels.append(label_id)if i % 100 == 0:print(i, (time.time() - row_s_time)*1000, label_id)print("avg time: ", (time.time() - s_time) * 1000 / test_df.shape[0])
0 229.3872833251953 0
100 362.0314598083496 1
200 311.16747856140137 2
300 324.13792610168457 3
400 406.9099426269531 4
avg time:  352.44047810332944
true_labels[:10]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
pred_labels[:10]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
from sklearn.metrics import classification_reportprint(classification_report(true_labels, pred_labels, digits=4))
              precision    recall  f1-score   support0     0.9900    1.0000    0.9950        991     0.9691    0.9495    0.9592        992     0.9900    1.0000    0.9950        993     0.9320    0.9697    0.9505        994     0.9895    0.9495    0.9691        99accuracy                         0.9737       495macro avg     0.9741    0.9737    0.9737       495
weighted avg     0.9741    0.9737    0.9737       495

总结

  本文介绍了如何使用HuggingFace中的Trainer对BERT模型微调。可以看到,使用Trainer进行模型微调,代码较为简洁,且支持功能丰富,是理想的模型训练方式。
  本文项目代码已开源至Github,网址为:https://github.com/percent4/PyTorch_Learning/tree/master/huggingface_learning 。
  本人已开通个人博客网站,网址为:https://percent4.github.io/ ,欢迎大家访问~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/64254.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vuex详细用法

Vuex是一个专门为Vue.js应用程序开发的状态管理模式。它可以帮助我们在Vue组件之间共享和管理数据,以及实现更好的代码组织和调试。 在Vue.js中,组件之间的数据通信可以通过props和事件来实现。然而,随着应用程序规模的增长,组件…

Modbus通信协议

Modbus通信协议 一、概述 Modbus通信协议是一种工业现场总线协议标准,常用的Modbus协议有以下三种类型:Modbus TCP、Modbus RTU、Modbus ASCll。 Modbus通信协议解决了通过串行线路在电子设备之间发送信息的问题。该协议在遵循该协议的体系结构中实现主…

本地虚机Jumpserver使用域名访问报错 使用IP+端口没有错误

背景: 我在本地Windows VMware 15的环境中部署了CentOS7.5,下载jumpserver-offline-installer-v2.28.1-amd64-138.tar.gz并安装部署。 需求: 1、能使用http:ip访问堡垒机。达成; 2、能使用http:域名访问堡垒机。达成&#xff…

基于大语言模型知识问答应用落地实践 – 知识库构建(下)

上篇介绍了构建知识库的大体流程和一些优化经验细节,但并没有结合一个具体的场景给出更细节的实战经验以及相关的一些 benchmark 等,所以本文将会切入到一个具体场景进行讨论。 目标场景:对于 PubMed 医疗学术数据中的 1w 篇文章进行知识库构…

c++(8.29)auto关键字,lambda表达式,数据类型转换,标准模板库,list,文件操作+Xmind

作业: 封装一个学生的类,定义一个学生这样类的vector容器, 里面存放学生对象(至少3个) 再把该容器中的对象,保存到文件中。 再把这些学生从文件中读取出来,放入另一个容器中并且遍历输出该容器里的学生。…

docker 是什么

目录 docker是一个软件 Docker 是一种运行于 Linux 和 Windows 上的软件,用于创建、管理和编排容器。 为什么要使用 Docker? 1、 更快速的交付和部署 2、 更高效的虚拟化 3、 更轻松的迁移和扩展 4 、更简单的管理 docker是一个软件,是…

CSS transition 过渡

1 前言 CSS过渡(transition)可以在一个元素切换到另一种状态时为其定义平滑的过渡效果。 例如,用户鼠标悬停在按钮上时,按钮颜色平滑的从一个颜色过渡到另一个颜色。 .btn:hover{background-color: red;color: black; }默认悬停效果 添加过渡效果 .b…

网工内推 | 上市公司,IT工程师、服务器工程师,IP以上优先

01 烟台睿创微纳技术股份有限公司 招聘岗位:IT工程师 职责描述: 1、负责网络及安全架构的规划、设计、性能优化; 2、负责网络设备的安装、配置、管理、排错、维护,提供网络设备维护方案; 3、负责防火墙、上网行为管理…

Python数据分析:从导入数据到生成报告的全面指南

随着数据科学和人工智能的迅速发展,Python 已经成为了最受欢迎的数据分析语言之一。Python 具有简单易学、灵活性强、可扩展性高等优点,使其在数据分析领域具有广泛的应用。本文将介绍 Python 数据分析的基本步骤,帮助你了解如何使用 Python …

滑动窗口和双指针

滑动窗口和双指针 一、循环不变量1.1 定义1.2 总结 二、使用循环不变量写对代码2.1 注意2.2 总结 三、滑动窗口3.1 固定长度的滑动窗口(同向交替移动的两个变量)3.2 不定长度的滑动窗口3.2.1 定义3.2.2 总结 3.3 计数问题3.3.1 标准3.3.2 总结 3.4 使用数…

react app教程

react app教程 环境准备 下载node 下载npx npm install npx创建app npx create-react-app automedia cd automedia npm start构建发布版本 npm run build安装调试工具 # .vscode/launch.json {// 使用 IntelliSense 了解相关属性。 // 悬停以查看现有属性的描述。// 欲了…

前端基础4——jQuery

文章目录 一、基本了解1.1 导入jQuery库1.2 基本语法1.3 选择器 二、操作HTML2.1 隐藏和显示元素2.2 获取与设置内容2.3 获取、设置和删除属性2.4 添加元素2.5 删除元素2.6 设置CSS样式 三、jQuery Ajax3.1 基本语法3.2 回调函数3.3 常用HTTP方法3.4 案例一3.4.1 准备工作3.4.2…

Java,Linux,Mysql小白入门

Java入门 java后端__阿伟_的博客-CSDN博客 Linux与Git入门 Linux与Git入门教程__阿伟_的博客-CSDN博客 Mysql入门 Linux与Git入门教程__阿伟_的博客-CSDN博客

go语言配置

1、Go语言的环境变量 与Java等编程语言一样,安装Go语言开发环境需要设置全局的操作系统环境变量(除非是用包管理工具直接安装) 主要的系统级别的环境变量有两个: (1)GOROOT:表示Go语言环境在计算机上的安…

springboot docker

在Spring Boot中使用Docker可以帮助你将应用程序与其依赖的容器化,并简化部署和管理过程。 当你在Spring Boot中使用Docker时,你的代码不需要特殊的更改。你可以按照通常的方式编写Spring Boot应用程序。 java示例代码,展示了如何编写一个基…

lv3 嵌入式开发-3 linux shell命令(权限、输入输出)

1 Shell概述 随着各式Linux系统的图形化程度的不断提高,用户在桌面环境下,通过点击、拖拽等操作就可以完成大部分的工作。 然而,许多Ubuntu Linux功能使用shell命令来实现,要比使用图形界面交互,完成的更快、更直接。…

AcWing 785:快速排序 ← vector

【题目来源】https://www.acwing.com/problem/content/787/【题目描述】 给定你一个长度为 n 的整数数列。 请你使用快速排序对这个数列按照从小到大进行排序。 并将排好序的数列按顺序输出。【输入格式】 输入共两行,第一行包含整数 n。 第二行包含 n 个整数&#…

MOS场效应管

导体三极管中参与导电的有两种极性的载流子,所以也称为双极型三极管。本文将介绍另一种三极管,这种三极管只有一种载流子参与导电,所以也称为单极型三极管,因为这种管子是利用电场效应控制电流的,所以也叫场效应三极管…

JVM 垃圾收集器

重点:CMS,G1,ZGC 主要垃圾收集器如下,图中标出了它们的工作区域、垃圾收集算法,以及配合关系。 Serial 收集器 Serial 收集器是最基础、历史最悠久的收集器。 如同它的名字(串行)&#xff0c…

javaee spring 用注解的方式实现ioc

spring 用注解的方式实现ioc spring核心依赖 <?xml version"1.0" encoding"UTF-8"?><project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"…