NLP transformers - 文本分类

在这里插入图片描述

Text classification

文章目录

  • Text classification
    • 加载 IMDb 数据集
    • Preprocess 预处理
    • Evaluate
    • Train
    • Inference


本文翻译自:Text classification
https://huggingface.co/docs/transformers/tasks/sequence_classification
notebook : https://colab.research.google.com/github/huggingface/notebooks/blob/main/transformers_doc/en/pytorch/sequence_classification.ipynb


文本分类是一种常见的 NLP 任务,它为文本分配标签或类别。一些大公司在生产中运行文本分类,以实现广泛的实际应用。最流行的文本分类形式之一是 情感分析,它为文本序列分配 🙂 积极、🙁 消极或 😐 中性等标签。

本指南将向您展示:

  1. 在IMDb数据集上微调DistilBERT,以确定电影评论是正面还是负面。
  2. 使用您的微调模型进行推理。

本教程中演示的任务由以下模型架构支持:

ALBERT, BART, BERT, BigBird, BigBird-Pegasus, BioGpt, BLOOM, CamemBERT, CANINE, CodeLlama, ConvBERT, CTRL, Data2VecText, DeBERTa, DeBERTa-v2, DistilBERT, ELECTRA, ERNIE, ErnieM, ESM, Falcon, FlauBERT, FNet, Funnel Transformer, Gemma, GPT-Sw3, OpenAI GPT-2, GPTBigCode, GPT Neo, GPT NeoX, GPT-J, I-BERT, Jamba, LayoutLM, LayoutLMv2, LayoutLMv3, LED, LiLT, LLaMA, Longformer, LUKE, MarkupLM, mBART, MEGA, Megatron-BERT, Mistral, Mixtral, MobileBERT, MPNet, MPT, MRA, MT5, MVP, Nezha, Nyströmformer, OpenLlama, OpenAI GPT, OPT, Perceiver, Persimmon, Phi, PLBart, QDQBert, Qwen2, Qwen2MoE, Reformer, RemBERT, RoBERTa, RoBERTa-PreLayerNorm, RoCBert, RoFormer, SqueezeBERT, StableLm, Starcoder2, T5, TAPAS, Transformer-XL, UMT5, XLM, XLM-RoBERTa, XLM-RoBERTa-XL, XLNet, X-MOD, YOSO


在开始之前,请确保已安装所有必需的库:

pip install transformers datasets evaluate accelerate

我们鼓励您登录 Hugging Face 帐户,以便您可以上传模型并与社区分享。出现提示时,输入您的令牌进行登录:

from huggingface_hub import notebook_loginnotebook_login()

加载 IMDb 数据集

首先从 🤗 数据集库加载 IMDb 数据集:

from  datasets import load_datasetimdb = load_dataset("imdb")

然后看一个数据样例:

IMDB[ “测试” ][ 0 ]
{"label" : 0 ,"text" : "我喜欢科幻小说,并且愿意忍受很多。... 一切又来了。” ,
}

该数据集中有两个字段:

  • text: 影评文字。
  • label: 0:表示负面评论或1正面评论的值。

Preprocess 预处理

下一步是加载 DistilBERT 分词器来预处理该text字段:

from transformers import AutoTokenizertokenizer = AutoTokenizer.from _pretrained( "distilbert/distilbert-base-uncased" )

创建一个预处理函数来对text序列进行标记和截断,使其长度不超过 DistilBERT 的最大输入长度:

def  preprocess_function ( Examples ):return tokenizer(examples[ "text" ], truncation= True )

要将预处理函数应用于整个数据集,请使用 🤗 数据集 map 函数。
您可以map通过设置 batched=True 一次处理数据集的多个元素来加快速度:

tokenized_imdb = imdb.map(preprocess_function, batched=True)

现在使用 DataCollatorWithPadding 创建一批示例。在整理过程中 动态地将句子填充 到批次中的最长长度,比将整个数据集填充到最大长度更有效。

from transformers import DataCollatorWithPaddingdata_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Evaluate

在训练期间包含指标通常有助于评估模型的性能。您可以使用 🤗 Evaluate库快速加载评估方法。对于此任务,加载准确性指标(请参阅 🤗 评估快速浏览以了解有关如何加载和计算指标的更多信息):

import evaluateaccuracy = evaluate.load("accuracy")

然后创建一个传递预测和标签的函数来compute计算准确性:

import numpy as npdef compute_metrics(eval_pred):predictions, labels = eval_predpredictions = np.argmax(predictions, axis=1)return accuracy.compute(predictions=predictions, references=labels) 

您的compute_metrics函数现在已准备就绪,您将在设置训练时返回该函数。


Train

在开始训练模型之前,请使用id2labellabel2id ,创建预期 id 到其标签的映射:

id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

如果您不熟悉使用 Trainer 微调模型,
请查看基本教程:<(https://huggingface.co/docs/transformers/training#train-with-pytorch-trainer>

您现在就可以开始训练您的模型了!使用 AutoModelForSequenceClassification 加载 DistilBERT以及预期标签的数量和标签映射:

from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainermodel = AutoModelForSequenceClassification.from_pretrained("distilbert/distilbert-base-uncased", num_labels=2, id2label=id2label, label2id=label2id
)

此时,只剩下三步:

  1. 在TrainingArguments中定义训练超参数。
    唯一必需的参数是output_dir指定保存模型的位置。您可以通过设置将此模型推送到 Hub push_to_hub=True(您需要登录 Hugging Face 才能上传模型)。
    在每个 epoch 结束时,Trainer 将评估准确性并保存训练检查点。
  2. 将训练参数以及模型、数据集、分词器、数据整理器和compute_metrics函数传递给Trainer 。
  3. 调用 train() 来微调您的模型。
training_args = TrainingArguments(output_dir="my_awesome_model",learning_rate=2e-5,per_device_train_batch_size=16,per_device_eval_batch_size=16,num_train_epochs=2,weight_decay=0.01,evaluation_strategy="epoch",save_strategy="epoch",load_best_model_at_end=True,push_to_hub=True,
)trainer = Trainer(model=model,args=training_args,train_dataset=tokenized_imdb["train"],eval_dataset=tokenized_imdb["test"],tokenizer=tokenizer,data_collator=data_collator,compute_metrics=compute_metrics,
)trainer.train()

当您传递 token 给Trainer时, 它默认应用动态填充tokenizer。在这种情况下,您不需要显式指定数据整理器。

训练完成后,使用 push_to_hub()方法将您的模型共享到 Hub,以便每个人都可以使用您的模型:

trainer.push_to_hub()

有关如何微调文本分类模型的更深入示例,请查看相应的 PyTorch 笔记本 或 TensorFlow 笔记本。


Inference

太好了,现在您已经微调了模型,您可以使用它进行推理!

获取一些您想要进行推理的文本:

text = “这是一部杰作。并不完全忠实于原著,但从头到尾都令人着迷。可能是三本书中我最喜欢的。”

尝试微调模型进行推理的最简单方法是在 pipeline() 中使用它。使用您的模型实例化pipeline情感分析,并将文本传递给它:

from transformers import pipelineclassifier = pipeline("sentiment-analysis", model="stevhliu/my_awesome_model")
classifier(text)

如果您愿意,您还可以手动复制 pipeline 的结果:


对文本进行分词并返回 PyTorch 张量:

from transformers import AutoTokenizertokenizer = AutoTokenizer.from_pretrained("stevhliu/my_awesome_model")
inputs = tokenizer(text, return_tensors="pt")

将您的输入传递给模型并返回logits

from transformers import AutoModelForSequenceClassificationmodel = AutoModelForSequenceClassification.from_pretrained("stevhliu/my_awesome_model")with torch.no_grad():logits = model(**inputs).logits

获取概率最高的类,并使用模型的id2label映射将其转换为文本标签:

predicted_class_id = logits.argmax().item()
model.config.id2label[predicted_class_id]
# -> 'POSITIVE'

2024-04-28(日)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/4741.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FPGA 以太网概念简单学习

1 MAC和PHY 从硬件的角度来说&#xff0c;以太网接口电路主要由 MAC &#xff08; Media Access Control &#xff09;控制器和物理层接口 PHY&#xff08;Physical Layer &#xff0c; PHY &#xff09;两大部分构成。 MAC 指媒体访问控制子层协议&#xff0c;它和 PHY 接…

创建获利段

事务代码&#xff1a;KE21N BAPI&#xff1a;BAPI_COPAACTUALS_POSTCOSTDATA 前台操作&#xff1a; 表是业务配置的 配置路径&#xff1a; 代码&#xff1a;BAPI不返回生成的凭证号和获利段&#xff0c;需要通过增强或者读表获取 ls_copa_data-record_id 000001.ls_co…

Agent AI智能体在未来,一定与你我密不可分

随着Agent AI智能体的逐渐成熟&#xff0c;人工智能应用的不断深入与拓展&#xff0c;相信在不久的将来&#xff0c;他与你我的生活一定是密不可分的。 目录 ​编辑 1 Agent AI智能体是什么&#xff1f; 2 Agent AI在语言处理方面的能力 2.1 情感分析示例 2.2 文本分类任…

Spring - 5 ( 8000 字 Spring 入门级教程 )

一&#xff1a;Spring IoC&DI 1.1 方法注解 Bean 类注解是添加到某个类上的&#xff0c; 但是存在两个问题: 使用外部包里的类, 没办法添加类注解⼀个类, 需要多个对象, ⽐如多个数据源 这种场景, 我们就需要使用方法注解 Bean 我们先来看方法注解如何使用: public c…

Unity 踩坑记录 Rigidbody 刚体重力失效

playerSetting > physics > Gravity > 设置 Y 的值为负数

前端canvas项目实战——在线图文编辑器(九):逻辑画布

目录 前言一、 效果展示二、 实现步骤1. 调整布局&#xff0c;最大化利用屏幕空间2. 添加逻辑画布3. 添加遮罩4. 居中显示逻辑画布5. 一个容易被忽视的bug点 三、Show u the code后记 前言 上一篇博文中&#xff0c;我们实现了一组通用的功能按钮&#xff1a;复制、删除、锁定…

FreeRTOS之列表

1.FreeRTOS的列表和列表项十分重要。列表类相当于链表&#xff0c;列表项则相当于链表中的节点。列表项的地址是非连续的&#xff0c;列表项的数量可随时修改。在OS中的任务状态和数量会发生改变&#xff0c;因此使用列表可以很好的满足需求。 列表和列表项的相关定义与操作函…

电商独立站||跨境电商独立站网站搭建|功能系统搭建||API接口接入

搭建多语言跨境电商独立站系统 前台主要功能模块 短信接口 第三方登陆 支付方式 会员中心 代购订单列表 - new 会员签到 -1000(1) new 支付密码 ---1000 国内流程 -----5000 new 订单运单多退少补 -1000 未付款运单取消功能 - 修改运单运输方式 -----1000 年费会员 -----3000 …

大型零售企业,适合什么样的企业邮箱大文件解决方案?

大型零售企业通常指的是在全球或特定地区内具有显著市场影响力和知名度的零售商。这些企业不仅在零售业务收入上达到了惊人的规模&#xff0c;而且在全球范围内拥有广泛的销售网络和实体店铺。它们在快速变化的零售行业中持续创新&#xff0c;通过实体店、电商平台等多种渠道吸…

C#队列(Queue)的基本使用

概述 在编程中&#xff0c;队列&#xff08;Queue&#xff09;是一种常见的数据结构&#xff0c;它遵循FIFO&#xff08;先进先出&#xff09;的原则。在C#中&#xff0c;.NET Framework提供了Queue<T>类&#xff0c;它位于System.Collections.Generic命名空间下&#x…

【深度学习实战(26)】标签处理之语义分割标签转换,数据集划分

一、标签转换 我们在使用labeme标签工具&#xff0c;标注完数据后会获得json文件。在标注结束过后&#xff0c;我们需要通过标签转换操作&#xff0c;生成jpg格式原始图片和png格式mask标签图。 1.1 使用img_b64_to_arr将json标签中二进制图像数据变成numpy格式数据&#xf…

selenium在Pycharm中结合python的基本使用、交互、无界面访问

下载 下载与浏览器匹配的浏览器驱动文件&#xff0c;这里一定注意的是&#xff0c;要选择和浏览器版本号相同的驱动程序&#xff0c;否则后面会有很多问题。 &#xff08;1&#xff09;浏览器&#xff08;以google为例&#xff09;版本号的查询&#xff1a; 我这里的版本号是1…

java实现模板填充word,word转pdf,pdf转图片

Java实现Word转PDF及PDF转图片 在日常开发中&#xff0c;我们经常需要将文件操作&#xff0c;比如&#xff1a; 根据模板填充wordword文档中插入图片Word文档转换为PDF格式将PDF文件转换为图片。 这些转换可以帮助我们在不同的场景下展示或处理文档内容。下面&#xff0c;我将…

Leetcode—1256. 加密数字【中等】Plus(bitset、find_first_not_of、erase)

2024每日刷题&#xff08;120&#xff09; Leetcode—1256. 加密数字 实现代码 class Solution { public:string encode(int num) {string ans;num 1;while(num ! 0) {ans to_string(num & 1);num num >> 1;}if(ans.empty()) {return "";} else {stri…

coreldraw2024精简版绿色版安装包免费下载

CorelDRAW 2024是一款矢量图形设计软件&#xff0c;于2024年3月5日正式在全球范围内发布。这款软件在多个方面进行了更新和改进&#xff0c;为用户提供了更多高效、灵活和便捷的设计工具。 首先&#xff0c;CorelDRAW 2024新增了绘画笔刷功能&#xff0c;这些笔刷不仅模拟了传…

Ubuntu20.04 [Ros Noetic]版本——在catkin_make编译时出现报错的解决方案

今天在新的笔记本电脑上进行catkin_make的编译过程中遇到了报错&#xff0c;这个报错在之前也遇到过&#xff0c;但是&#xff0c;我却忘了怎么解决。很是头痛&#xff01; 经过多篇博客的查询&#xff0c;特此解决了这个编译报错的问题&#xff0c;于此特地记录&#xff01;&…

[论文笔记]SEARCHING FOR ACTIVATION FUNCTIONS

引言 今天带来另一篇激活函数论文SEARCHING FOR ACTIVATION FUNCTIONS的笔记。 作者利用自动搜索技术来发现新的激活函数。通过结合详尽的搜索和基于强化学习的搜索&#xff0c;通过实验发现最佳的激活函数 f ( x ) x ⋅ sigmoid ( β x ) f(x) x \cdot \text{sigmoid}(βx…

瓦片编辑器成功移植到小熊猫C++ 2.25.1版本,解决_findnext移植问题

移植之后出现绿色屏幕闪退 查了版本回滚直到不闪退&#xff0c;发现是在读取自定义文件上出问题 然后在找读取自定义文件函数&#xff0c;发现是读取图片部分出问题 然后就卡住了 调试半天&#xff0c;不是数据溢出&#xff0c;于是就看 函数_findnext,网上搜 ———_findn…

4.Docker本地镜像发布至阿里云仓库、私有仓库、DockerHub

文章目录 0、镜像的生成方法1、本地镜像发布到阿里云仓库2、本地镜像发布到私有仓库3、本地镜像发布到Docker Hub仓库 Docker仓库是集中存放镜像的地方&#xff0c;分为公共仓库和私有仓库。 注册服务器是存放仓库的具体服务器&#xff0c;一个注册服务器上可以有多个仓库&…

项目开发规范

Restful REST&#xff0c;表述性状态转换&#xff0c;他是一种软件架构风格 使用URL定位资源&#xff0c;HTTP动词描述操作 根据发出请求类型来区分操作 GET&#xff1a; 查询id为1的用户POST&#xff1a;新增用户PUT&#xff1a;修改用户DELETE&#xff1a;删除id为1的用户 …