Transformers实战——多项选择

文章目录

  • 一、导入相关包
  • 二、加载数据集
  • 三、数据集预处理
  • 四、创建模型
  • 五、创建评估函数
  • 六、配置训练参数
  • 七、创建训练器
  • 八、模型训练
  • 九、模型预测

!pip install transformers datasets evaluate accelerate 

一、导入相关包

import evaluate
from datasets import DatasetDict, load_dataset
from transformers import AutoTokenizer, AutoModelForMultipleChoice, TrainingArguments, Trainer

二、加载数据集

# c3 = DatasetDict.load_from_disk("./c3/") 从本地加载
# c3 = load_from_disk("./c3/") 同上
c3 = load_dataset("clue",'c3')
c3
'''
DatasetDict({test: Dataset({features: ['id', 'context', 'question', 'choice', 'answer'],num_rows: 1625})train: Dataset({features: ['id', 'context', 'question', 'choice', 'answer'],num_rows: 11869})validation: Dataset({features: ['id', 'context', 'question', 'choice', 'answer'],num_rows: 3816})
})
'''
c3["train"][:10]
'''
{'id': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],'context': [['男:你今天晚上有时间吗?我们一起去看电影吧?', '女:你喜欢恐怖片和爱情片,但是我喜欢喜剧片,科幻片一般。所以……'],['男:足球比赛是明天上午八点开始吧?', '女:因为天气不好,比赛改到后天下午三点了。'],['女:今天下午的讨论会开得怎么样?', '男:我觉得发言的人太少了。'],['男:我记得你以前很爱吃巧克力,最近怎么不吃了,是在减肥吗?', '女:是啊,我希望自己能瘦一点儿。'],['女:过几天刘明就要从英国回来了。我还真有点儿想他了,记得那年他是刚过完中秋节走的。','男:可不是嘛!自从我去日本留学,就再也没见过他,算一算都五年了。','女:从2000年我们在学校第一次见面到现在已经快十年了。我还真想看看刘明变成什么样了!','男:你还别说,刘明肯定跟英国绅士一样,也许还能带回来一个英国女朋友呢。'],['男:好久不见了,最近忙什么呢?','女:最近我们单位要搞一个现代艺术展览,正忙着准备呢。','男:你们不是出版公司吗?为什么搞艺术展览?','女:对啊,这次展览是我们出版的一套艺术丛书的重要宣传活动。'],['男:会议结束后,你记得把空调和灯都关了。', '女:好的,我知道了,明天见。'],['男:你出国读书的事定了吗?', '女:思前想后,还拿不定主意呢。'],['男:这件衣服我要了,在哪儿交钱?', '女:前边右拐就有一个收银台,可以交现金,也可以刷卡。'],['男:小李啊,你是我见过的最爱干净的学生。','女:谢谢教授夸奖。不过,您是怎么看出来的?','男:不管我叫你做什么,你总是推得干干净净。','女:教授,我……']],'question': ['女的最喜欢哪种电影?','根据对话,可以知道什么?','关于这次讨论会,我们可以知道什么?','女的为什么不吃巧克力了?','现在大概是哪一年?','女的的公司为什么要做现代艺术展览?','他们最可能是什么关系?','女的是什么意思?','他们最可能在什么地方?','教授认为小李怎么样?'],'choice': [['恐怖片', '爱情片', '喜剧片', '科幻片'],['今天天气不好', '比赛时间变了', '校长忘了时间'],['会是昨天开的', '男的没有参加', '讨论得不热烈', '参加的人很少'],['刷牙了', '要减肥', '口渴了', '吃饱了'],['2005年', '2010年', '2008年', '2009年'],['传播文化', '宣传新书', '推广现代艺术', '体现企业文化'],['同事', '司机和客人', '医生和病人'],['不想出国', '出国太难', '还在犹豫', '不想决定'],['医院', '迪厅', '商场', '饭馆'],['卫生习惯非常好', '做事的能力不够', '找借口拒绝做事', '记不住该做的事']],'answer': ['喜剧片','比赛时间变了','讨论得不热烈','要减肥','2010年','宣传新书','同事','还在犹豫','商场','找借口拒绝做事']}
'''
# dataset本质上是一个字典,删除test键
c3.pop("test") # 删除test数据集
'''
Dataset({features: ['id', 'context', 'question', 'choice', 'answer'],num_rows: 1625
})
...# 因为是字典,下列操作也支持
'''
c3.keys()
c3.values()
c3.items()
'''
c3
'''
DatasetDict({train: Dataset({features: ['id', 'context', 'question', 'choice', 'answer'],num_rows: 11869})validation: Dataset({features: ['id', 'context', 'question', 'choice', 'answer'],num_rows: 3816})
})
'''

三、数据集预处理

tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-macbert-base")
tokenizer
'''
BertTokenizerFast(name_or_path='hfl/chinese-macbert-base', vocab_size=21128, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
'''
def process_function(examples):# examples, dict, keys: ["context", "quesiton", "choice", "answer"]# 假设examples有1000个context = []question_choice = []labels = []for idx in range(len(examples["context"])):ctx = "\n".join(examples["context"][idx])question = examples["question"][idx]choices = examples["choice"][idx]for choice in choices:context.append(ctx)question_choice.append(question + " " + choice)# 不足四个选项,补全四个选项if len(choices) < 4:for _ in range(4 - len(choices)):context.append(ctx)question_choice.append(question + " " + "不知道")labels.append(choices.index(examples["answer"][idx]))tokenized_examples = tokenizer(context, question_choice, truncation="only_first", max_length=256, padding="max_length")     # input_ids: 4000 * 256,tokenized_examples = {k: [v[i: i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}     # 1000 * 4 *256tokenized_examples["labels"] = labelsreturn tokenized_examples
res = c3["train"].select(range(10)).map(process_function, batched=True)
res
'''
Dataset({features: ['id', 'context', 'question', 'choice', 'answer', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],num_rows: 10
})
'''
import numpy as np
np.array(res["input_ids"]).shape
'''
(10, 4, 256)
'''
tokenized_c3 = c3.map(process_function, batched=True)
tokenized_c3
'''
DatasetDict({train: Dataset({features: ['id', 'context', 'question', 'choice', 'answer', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],num_rows: 11869})validation: Dataset({features: ['id', 'context', 'question', 'choice', 'answer', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],num_rows: 3816})
})
'''

四、创建模型

model = AutoModelForMultipleChoice.from_pretrained("hfl/chinese-macbert-base")

五、创建评估函数

import numpy as np # 切记这里predictions是np数组
accuracy = evaluate.load("accuracy")def compute_metric(pred):predictions, labels = predpredictions = np.argmax(predictions, axis=-1)return accuracy.compute(predictions=predictions, references=labels)

六、配置训练参数

  • fp16=True用混合精度训练
    • 混合精度训练需要 GPU 支持,特别是 NVIDIA 的 Volta 和 Turing 架构以及更高版本的 GPU。如果您在没有这些硬件的环境中启用了混合精度训练,可能会遇到错误。
    • 好处:更少的显存、更快的训练速度
    • 坏处:损失精度
args = TrainingArguments(output_dir="./muliple_choice",per_device_train_batch_size=16,per_device_eval_batch_size=16,num_train_epochs=3,logging_steps=50,evaluation_strategy="epoch",save_strategy="epoch",load_best_model_at_end=True,fp16=True # 用混合精度训练,可以加速训练
)

七、创建训练器

trainer = Trainer(model=model,args=args,train_dataset=tokenized_c3["train"],eval_dataset=tokenized_c3["validation"],compute_metrics=compute_metric
)

八、模型训练

trainer.train()

九、模型预测

  • 多项选择任务 pipeline并没有现成的封装,需要自己写推理
from typing import Any
import torchclass MultipleChoicePipeline:def __init__(self, model, tokenizer) -> None:self.model = modelself.tokenizer = tokenizerself.device = model.devicedef preprocess(self, context, quesiton, choices):cs, qcs = [], []for choice in choices:cs.append(context)qcs.append(quesiton + " " + choice)return tokenizer(cs, qcs, truncation="only_first", max_length=256, return_tensors="pt")def predict(self, inputs):# inputs,扩充一个batch维度inputs = {k: v.unsqueeze(0).to(self.device) for k, v in inputs.items()}return self.model(**inputs).logitsdef postprocess(self, logits, choices):predition = torch.argmax(logits, dim=-1).cpu().item()return choices[predition]def __call__(self, context, question, choices) -> Any:inputs = self.preprocess(context, question, choices)logits = self.predict(inputs)result = self.postprocess(logits, choices)return result
  • 单条预测
pipe = MultipleChoicePipeline(model, tokenizer)
  • 注意:这里不限于选项的个数,训练的时候限制了 4 个,推理的时候可以任意个数
pipe("小明在北京上班", "小明在哪里上班?", ["北京", "上海", "河北", "海南", "河北", "海南"])
'''
北京
'''
pipe("小明在北京上班", "小明在哪里上班?", ["北京", "上海"])
'''
北京
'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/157810.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

用ScheduledExecutorService接口,Quartz框架等创建定时任务

【点我-这里送书】 本人详解 作者:王文峰,参加过 CSDN 2020年度博客之星,《Java王大师王天师》 公众号:JAVA开发王大师,专注于天道酬勤的 Java 开发问题中国国学、传统文化和代码爱好者的程序人生,期待你的关注和支持!本人外号:神秘小峯 山峯 转载说明:务必注明来源(…

小程序常见操作

测试时访问本地http服务器调用报错 微信开发者工具&#xff08;右上角&#xff09;-> 详情->本地设置->不校验合法域名、web-view(业务域名)... -> 去除勾选使用npm包 1) 工程目录下创建package.jsonnpm init(手动完成设定) / npm init -y (默认设定) 2) 安装 np…

SAP权限相关的表及如何使用FM获取用户权限

一、SAP权限相关的表 AGR_1016 活动组参数文件名称 AGR_1016B 活动组参数文件名称 AGR_1250 活动组的权限数据(通过权限对象 查 角色) AGR_1251 活动组的权限数据 AGR_1252 …

AVR单片机在机器人视觉导航中的应用研究

AVR单片机在机器人视觉导航中的应用是一项前沿的研究领域&#xff0c;旨在实现机器人在未知环境中的自主导航和避障功能。本文将介绍AVR单片机在机器人视觉导航中的应用原理和实现步骤&#xff0c;并提供相应的代码示例。 1. 导航概述 机器人视觉导航是基于计算机视觉和控制理…

SpringBoot 整合 JdbcTemplate(配置多数据源)

数据持久化有几个常见的方案&#xff0c;有 Spring 自带的 JdbcTemplate 、有 MyBatis&#xff0c;还有 JPA&#xff0c;在这些方案中&#xff0c;最简单的就是 Spring 自带的 JdbcTemplate 了&#xff0c;这个东西虽然没有 MyBatis 那么方便&#xff0c;但是比起最开始的 Jdbc…

c语言回文数

以下是用C语言编写的回文数代码&#xff1a; #include <stdio.h>int main() { int num, reversedNum 0, remainder, originalNum; printf("请输入一个正整数&#xff1a;"); scanf("%d", &num); originalNum num; while (num …

SCAUoj实验11 链表操作

SCAU链表oj题目 文章目录 前言一、堂前习题1099 [填空题]链表的合并 二、堂上练习1098 [填空]链表结点的插入1104 [填空题]链表的倒序1101 [填空题]链表的排序 前言 刚开始学习链表可能会看得比较头晕&#xff0c;关键在于先理解链表的逻辑结构和物理结构&#xff0c;尤其是逻辑…

CMAK Kafka可视化管理工具

CMAK简介 为了简化开发者和服务工程师维护Kafka集群的工作,yahoo构建了一个叫做Kafka管理器的基于Web工具,叫做 CMAK(原名Kafka Manager)。 这个管理工具可以很容易地发现分布在集群中的哪些topic分布不均匀,或者是分区在整个集群分布不均匀的的情况。 它支持管理多个集…

文本分析:NLP 魔法!

一、说明 这是一个关于 NLP 和分类项目的博客。NLP 是自然语言处理&#xff0c;目前需求量很大。让我们了解如何利用 NLP。我们将通过编码来理解流程和概念。我将在本博客中介绍 BagOfWords 和 n-gram 以及朴素贝叶斯分类模型。这个博客的独特之处&#xff08;这使得它很长&…

2023年度中国开源研究报告

截止为2023年11月的中国开源项目数字报告&#xff0c;计算了中国的开源项目的活动指标进行排名&#xff0c;可以看到排名第一的是百度的飞桨PaddlePaddle&#xff0c;前50的排名中人工智能相关的开源项目&#xff0c;占比越来越高&#xff0c;其中使用的编程语言主要有&#xf…

数据在金融行业的应用有哪些

在当今的数字化时代&#xff0c;数据已经成为金融行业不可或缺的一部分。从风险管理、投资决策、客户关系管理到监管合规&#xff0c;数据在金融领域的各个方面都发挥着重要作用。 ​那么&#xff0c;大数据在金融行业有哪些应用呢&#xff1f; 一、数据在金融行业中的应用 1…

单元测试实战(五)普通类的测试

为鼓励单元测试&#xff0c;特分门别类示例各种组件的测试代码并进行解说&#xff0c;供开发人员参考。 本文中的测试均基于JUnit5。 单元测试实战&#xff08;一&#xff09;Controller 的测试 单元测试实战&#xff08;二&#xff09;Service 的测试 单元测试实战&am…

Pod详解

Pod详解 1 .Pod介绍 1.1 Pod结构 每个Pod中都可以包含一个或者多个容器&#xff0c;这些容器可以分为两类&#xff1a; 用户程序所在的容器&#xff0c;数量可多可少 Pause容器&#xff0c;这是每个Pod都会有的一个根容器&#xff0c;它的作用有两个&#xff1a; 可以以它为…

小米集团收入增长失速已久:穿越寒冬,雷军的路走对了吗?

撰稿|行星 来源|贝多财经 11月20日&#xff0c;小米集团&#xff08;HK:01810&#xff0c;下称“小米”&#xff09;发布了截至2023年9月30日的第三季度业绩公告。 财报显示&#xff0c;在智能手机出货量下行、平均售价下跌的背景下&#xff0c;小米逆势而上&#xff0c;实现…

创建用户报错:ORA-65096: 公用用户名或角色名无效

题主的Oracle版本是最新的Oracle 21 描述&#xff1a; 1、在命令行工具 给Oracle创建用户&#xff0c;create user c##用户名identifed by 密码&#xff0c;报错&#xff1a;【ORA-65096: 公用用户名或角色名无效】 2、在navicat创建用户&#xff0c;提示如下&#xff1a; 解…

Windows系统如何安装与使用TortoiseSVN客户端,并实现在公网访问本地SVN服务器

文章目录 前言1. TortoiseSVN 客户端下载安装2. 创建检出文件夹3. 创建与提交文件4. 公网访问测试 前言 TortoiseSVN是一个开源的版本控制系统&#xff0c;它与Apache Subversion&#xff08;SVN&#xff09;集成在一起&#xff0c;提供了一个用户友好的界面&#xff0c;方便用…

并行与分布式计算 第8章 并行计算模型

文章目录 并行与分布式计算 第8章 并行计算模型8.1 并行算法基础8.1.1 并行算法的定义8.1.2并行算法的分类8.1.3算法的复杂度 8.2 并行计算模型8.2.1 PRAM (SIMD-SM)模型8.2.3 BSP (MIMD-DM)模型8.2.4LogP&#xff08;MIMD-DM&#xff09;模型 并行与分布式计算 第8章 并行计算…

java疫情期间社区出入管理系统-计算机毕业设计源码21295

摘 要 信息化社会内需要与之针对性的信息获取途径&#xff0c;但是途径的扩展基本上为人们所努力的方向&#xff0c;由于站在的角度存在偏差&#xff0c;人们经常能够获得不同类型信息&#xff0c;这也是技术最为难以攻克的课题。针对疫情期间社区出入管理等问题&#xff0c;对…

【算法挨揍日记】day21——64. 最小路径和、174. 地下城游戏

64. 最小路径和 64. 最小路径和 题目描述&#xff1a; 给定一个包含非负整数的 m x n 网格 grid &#xff0c;请找出一条从左上角到右下角的路径&#xff0c;使得路径上的数字总和为最小。 说明&#xff1a;每次只能向下或者向右移动一步。 解题思路&#xff1a; 状态表示&…

硬技能之上的软技巧(二)

在硬技能的基础上&#xff0c;如何运用其他软技巧来提升个人能力和职业发展。在之前的讨论中&#xff0c;我们提到了硬技能和软技巧的基本概念&#xff0c;以及如何运用一些软技巧来提升个人能力和职业发展。 本篇文章将进一步探讨其他软技巧&#xff0c;包括批判性思维、自我…