使用Bert模型优化Padding策略:加速文本分类训练

文章目录

    • 摘要
    • 介绍
    • 导包
    • 模型
    • dataset
    • BertCLS
    • 速度对比
    • 代码开源地址

摘要

本文探讨了如何通过优化Padding策略,提高基于Bert的文本分类模型的训练速度。我们比较了两种不同的Padding方式:一种是将输入数据统一填充到最大长度512,另一种是只将每个Batch中的数据填充到Batch中最长的样本长度。

通过实验结果证明,后者显著减少了训练时间,且不影响模型的性能。

介绍

在本节中,我们研究 🤗 Transformers 中标记器的功能。分析 batch 数据 的padding。
下述 tokenizer 代码,会把每一条数据,padding 到最大长度 512。

tokenized_inputs = tokenizer(item["text"],max_length=512,padding="max_length",truncation=True,
)

下述代码仅仅进行 tokenize 化,只进行截断,但暂时还 不进行padding。把 padding 操作留到 后续的 batch 数据处理上,这样只需要padding 到 batch 数据里最长的数据长度,而无需 padding 到512。由于减少了input_ids的数据长度,所以在一定程度上可以加快模型的训练和推理速度。

tokenized_inputs = tokenizer(item["text"],max_length=512,truncation=True,
)

导包

import os
import evaluate
import numpy as np
from datasets import load_dataset
from transformers import (Trainer,TrainingArguments,AutoModelForSequenceClassification,AutoTokenizer,DataCollatorWithPadding,
)
from dataclasses import dataclass

如果无法连接huggingface,利用c l a s h 的代理。

os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'
os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'

模型

使用 Bert 模型,进行实验。

model_name = "bert-base-chinese"bert = AutoModelForSequenceClassification.from_pretrained(model_name,trust_remote_code=True,num_labels=2,
)tokenizer = AutoTokenizer.from_pretrained(model_name)

dataset

使用 huggingface 平台的 中文二分类数据集。

ds = load_dataset("lansinuote/ChnSentiCorp")

使用 transformers 的 DataCollatorWithPadding。

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
def tokenize_func_pad(item):global tokenizertokenized_inputs = tokenizer(item["text"],max_length=512,truncation=True,padding="max_length")return tokenized_inputs
batch_dataset_pad = ds["train"].map(tokenize_func_pad, remove_columns=["text"]
)
batch_dataset_pad.set_format("torch")
batch_dataset_pad["input_ids"].shape

输出结果:

torch.Size([9600, 512])

batch_dataset_pad["input_ids"].shapetorch.Size([9600, 512]),每一条数据都padding 到512长度。这里面会包含大量的<pad>。

只需要在每一个 batch 数据中,把其中所有的样本 padding 到这一个batch最长的样本长度即可。无需把所有的样本都 padding 到 512。

def tokenize_func(item):global tokenizertokenized_inputs = tokenizer(item["text"],max_length=512,truncation=True,)return tokenized_inputs
batch_dataset = ds["train"].map(tokenize_func, remove_columns=["text"]
)data = [batch_dataset[i] for i in range(16)]
data_collator(data)["input_ids"].shape

输出结果:

torch.Size([16, 126])

上述结果展示了,把当前batch里的数据,padding到当前batch的最大长度126。

data_collator 可传递给 DataLoader 和 Trainer

BertCLS

下述是 Bert 封装好的一个工具类。方便读者使用,训练模型、评估、预测。

@dataclass
class BertCLS:def __init__(self,model,tokenizer,train_dataset=None,eval_dataset=None,output_dir="output",epoch=3,):self.model = modelself.tokenizer = tokenizerself.train_dataset = train_datasetself.eval_dataset = eval_datasetself.data_collator = DataCollatorWithPadding(tokenizer=tokenizer)self.args = self.get_args(output_dir, epoch)self.trainer = Trainer(model=self.model,args=self.args,train_dataset=self.train_dataset,eval_dataset=self.eval_dataset,data_collator=self.data_collator,# compute_metrics=compute_metrics,tokenizer=tokenizer,)def get_args(self, output_dir, epoch):if self.eval_dataset:args = TrainingArguments(output_dir=output_dir,evaluation_strategy="epoch",save_strategy="epoch",save_total_limit=3,learning_rate=2e-5,num_train_epochs=epoch,weight_decay=0.01,per_device_train_batch_size=32,per_device_eval_batch_size=16,# logging_steps=16,save_safetensors=True,overwrite_output_dir=True,load_best_model_at_end=True,)else:args = TrainingArguments(output_dir=output_dir,evaluation_strategy="no",save_strategy="epoch",save_total_limit=3,learning_rate=2e-5,num_train_epochs=epoch,weight_decay=0.01,per_device_train_batch_size=32,per_device_eval_batch_size=16,# logging_steps=16,save_safetensors=True,overwrite_output_dir=True,# load_best_model_at_end=True,)return argsdef set_args(self, args):"""从外部重新设置 TrainingArguments,args 更新后,trainer也进行更新"""self.args = argsself.trainer = Trainer(model=self.model,args=self.args,train_dataset=self.train_dataset,eval_dataset=self.eval_dataset,data_collator=self.data_collator,# compute_metrics=compute_metrics,tokenizer=self.tokenizer,)def train(self, epoch=None, over_write=False):if epoch:self.args.num_train_epochs = epochbest_model_path = os.path.join(self.args.output_dir, "best_model")if over_write or not os.path.exists(best_model_path):self.trainer.train()self.trainer.save_model(best_model_path)else:print(f"预训练权重 {best_model_path} 已存在,且over_write={over_write}。不启动模型训练!")def eval(self, eval_dataset):predictions = self.trainer.predict(eval_dataset)preds = np.argmax(predictions.predictions, axis=-1)metric = evaluate.load("glue", "mrpc")return metric.compute(predictions=preds, references=predictions.label_ids)def pred(self, pred_dataset):predictions = self.trainer.predict(pred_dataset)preds = np.argmax(predictions.predictions, axis=-1)return pred_dataset.add_column("pred", preds)

速度对比

只给 BertCLS 训练数据集,开始训练模型。如果要看模型训练的评估结果,输入评估数据集即可。

padding 到最大长度:

BertCLS(bert, tokenizer, batch_dataset_pad, epoch=1).train()
 [300/300 03:09, Epoch 1/1]

padding 到每个batch的最大长度:

BertCLS(bert, tokenizer, batch_dataset, epoch=1).train()
 [300/300 02:19, Epoch 1/1]

padding 到 512个长度,训练3分9秒结束。如果只padding到每个batch的最大长度,训练2分19秒结束。

代码开源地址

https://github.com/JieShenAI/csdn/blob/main/24/09/tokenizer_pad/pad_vs.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/52441.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

UE4_后期处理五—饱和度调整、隔离、扭曲、重影

一、色彩饱和度调整&#xff1a; 原图 后期处理材质节点&#xff1a; 效果图&#xff1a; 可以根据参数saturation调整饱和还是去饱和。 当saturation为1时&#xff1a;去饱和度&#xff0c;如下图&#xff1a; 当saturation为0时&#xff1a;原始的一个状态&#xff0c;如下…

Debian安装Nodejs与npm

仅做记录&#xff0c;apt install nodejs 只会安装 Node.js 本身&#xff0c;而不会自动安装 npm sudo apt install -y nodejssudo apt install -y npm

展会通过智慧客流统计分析优化运营策略-讯鹏科技

在当今数字化高速发展的时代&#xff0c;展会行业也在积极探索利用智慧科技进行转型与升级。其中&#xff0c;智慧客流统计分析成为了展会优化运营策略的关键要素。 智慧客流统计分析首先为展会提供了精准的数据支撑。通过先进的传感器、摄像头等设备&#xff0c;能够实时、准确…

常用的jdk下载地址

jdk下载地址 安装方式可以看之前的博客&#xff1a; mac安装jdk oracle 版本&#xff1a;https://www.oracle.com/java/technologies/downloads/ Eclipse Temurin版本&#xff1a;https://adoptium.net/zh-CN/temurin/releases/ 阿里版本&#xff1a; github&#xff1a;htt…

PyCharm修改背景颜色、修改字体大小+Python常用快捷键+Python常见的运算符

文章目录 PyCharm软件的使用1. 修改背景颜色和字体大小1.1 修改背景颜色1.2 修改字体大小 2. 常用的快捷键3. 常见的运算符3.1 算术运算符3.2 赋值运算符3.3 比较运算符3.4 逻辑运算符 PyCharm软件的使用 1. 修改背景颜色和字体大小 1.1 修改背景颜色 1.2 修改字体大小 2. 常…

图文并茂带你理解Java的SPI机制

目录 一、Java的SPI机制1、什么是Java的SPI &#xff1f;2、JavaSPI 代码示例 (使用Maven项目演示)3、 JavaSPI 机制的核心-ServiceLoader4、实现自己的ServiceLoader5、Java中还有哪些SPI实现&#xff1f; 一、Java的SPI机制 1、什么是Java的SPI &#xff1f; SPI全称 Servi…

字符串API

字符串作为经常使用的数据类型&#xff0c;它们的API种类繁多&#xff0c;为了提升效率&#xff0c;不少API的底层实现可能会用更接近机器优化的代码实现。 不同语言的实现 用于内核或驱动调用的copy/compare API&#xff0c;大部分有机器架构汇编代码实现以加速处理&#xff…

中秋快到了,要给哪些国外客户送祝福(附贺卡模板)

马上就要中秋节了&#xff0c;在这里提前祝小伙伴们中秋节快乐&#xff0c;身体健康&#xff0c;阖家团圆&#xff0c;业绩越来越好&#xff0c;公司越来越好&#xff0c;一切都越来越好&#xff01; 中秋节是我们非常重要的几个传统节日之一了&#xff0c;除了我们自己庆祝之…

计算机网络练级第一级————认识网络

目录 网络搁哪&#xff1f; 网络的发展史&#xff08;了解&#xff09; 独立模式&#xff1a; 网络互联&#xff1a; 局域网时期&#xff1a; 广域网时期&#xff1a; 什么是协议 TCP/IP五层/四层模型 用官话来说&#xff1a; 我自己的话来说 第一层应用层&#xff1…

java中JTS对空间数据Geometry进行坐标系投影转换

java中JTS对空间数据Geometry进行坐标系投影转换 代码&#xff1a; /*** Description: 用JTS对Geometry空间数据进行坐标系投影转换** Param: [params]* Return: Geometry* Author yanghaoxing* Date 2024/9/10 14:54*/public Geometry getGeometryForlong(Geometry geometr…

Python+selenium自动化元素定位防踩坑(建议收藏)

踩坑一&#xff1a;StaleElementReferenceException selenium.common.exceptions.StaleElementReferenceException: Message: stale element reference: element is not attached to the page document 异常原因&#xff1a; 意思是&#xff0c;引用的元素已过期。原因是页面…

学习记录之C语言学习笔记1

1. 数据类型 基本数据类型&#xff1a;整型&#xff08;int&#xff09;、浮点型&#xff08;float&#xff09;、字符型&#xff08;char&#xff09;和双精度浮点型&#xff08;double&#xff09;。 派生数据类型&#xff1a;数组、结构体、联合体和枚举。 void类型…

soup.find(‘div‘)获取的数据长度为3,为什么1和3都是空的?

用beautifulSoup中的find&#xff08;‘div’&#xff09;可以获取一个div数据&#xff0c;为什么用len&#xff08;&#xff09;计数是显示长度为3&#xff1f; 实际在打印输出时&#xff0c;1和3又没有内容输出&#xff1f;用print&#xff08;div【0】&#xff09;和print&…

Java小白一文讲清Java中集合相关的知识点(七)

LinkedHashSet LinkedHashSet是HashSet的子类 LinkedHashSet底层是一个LinkedHashMap,底层维护了一个数组双向链表 而在之前讲的HashSet中的链表是单向的哈&#xff0c;注意区分&#xff01; LinkedHashSet根据元素的hashcode值来决定元素的存储位置&#xff0c;同时使用链表…

application/x-www-form-urlencoded与multipart/form-data与application/json的区别

前端数据传递至后台时&#xff0c;需要对其进行编码&#xff0c;其中&#xff0c;编码格式可分为四种&#xff1a;application/x-www-form-urlencoded&#xff0c;multipart/form-data&#xff0c;application/json&#xff0c;text/plain。 text/plain是纯文本数据&#xff0…

极限编程XP例题

答案&#xff1a;D 解析&#xff1a; 结对编程&#xff0c;一个人写代码&#xff0c;一个人看&#xff0c;由于是两个或两个以上的人负责&#xff0c;因此选项A 支持共同代码拥有和共同对系统负责是正确的 选项B 由于是一个人写一个人看&#xff0c;变相实现了代码审查 选项…

深入了解 GROW with SAP:它究竟是什么?

GROW with SAP 是一套综合全面的产品组合&#xff0c;包含一系列解决方案、加速采用服务、社区支持和学习资源&#xff0c;能够确保各种规模的企业成功采用 ERP 云软件。部署 GROW with SAP 后&#xff0c;企业可以采用 SAP S/4HANA Cloud Public Edition [ERP 公有云版]。在 S…

4 路由模式

路由模式 逻辑图 如果我们将生产环境的日志进行处理&#xff0c;而日志是分等级的&#xff0c;我们就按照 error waring info三个等级来讲解 一个消费者是处理【所有】&#xff08;info&#xff0c;error&#xff0c;warning&#xff09;的日志&#xff0c;用于做数据仓库&am…

Redis搭建集成

图示 正常来讲配置一主两从需要三台服务器,博主内存告急,就使用一台进行操作了,使用多台跟一台操作没有区别,只是多台不需要新建太多配置文件 一. 准备配置文件 如果你跟我一样是在一台服务器里面进行配置主从服务的,跟我一起操作即可 找到redis目录 在bin目录同位置创建一…

Linux驱动.之驱动开发思维,设备,驱动,总线分析思想,驱动的分类(字符设备,块设备,网络设备)

在stm32&#xff0c;裸机开发时&#xff0c;偏底层&#xff0c;跟寄存器打交道&#xff0c;有些MCU提供了库&#xff0c;库也还是操作寄存器的&#xff0c;通过配置寄存器&#xff0c; 配置各种工作模式&#xff0c;时钟&#xff0c;等等&#xff0c;交换数据等等。 Linux下驱…