transformers训练(NLP)阅读理解(多项选择)

简介

在阅读理解任务中,有一种通过多项选择其中一个答案来训练机器的阅读理解。比如:给定一个或多个文档h,以及一个问题S和对应的多个答案候选,输出问题S的答案E,E是答案候选中的某一个选项。
这样的目的就是通过文档,问题,多个答案选其中一个,就是让机器更好文档中的含义,提高机器的阅读理解。
是不是感觉似陈相识,这不就是语文考试的必考题,阅读理解吗。。。。。。。。。。。。。。

给机器训练的数据集示例

如:

  • Context:老师把一个大玻璃瓶子带到学校,瓶子里装着满满的石头、玻璃碎片和沙子。之后,老师请学生把瓶子里的东西都倒出来,然后再装进去,先从沙子开始。每个学生都试了试,最后都发现没有足够的空间装所有的石头。老师指导学生重新装这个瓶子。这次,先从石头开始,最后再装沙子。石头装进去后,沙子就沉积在石头的周围,最后,所有东西都装进瓶子里了。老师说:“如果我们先从小的东西开始,把小东西装进去之后,大的石头就放不进去了。生活也是如此,如果你的生活先被不重要的事挤满了,那你就无法再装进更大、更重要的事了。”
  • Question:正确的装法是,先装?
  • Choices / Candidates:[“小东西”,“大东西”,“轻的东西”,“软的东西” ]
  • Answer:大东西

技术实现思路

多项选择任务,技术实现,这里难点涉及数据处理和训练与推理。其实就是将数据处理好,喂给模型进行训练与推理,让其理解文本。
这里采用格式是:

[CLS] 文本内容 [SEP] 问题 答案 [SEP]

这里涉及在多个候选答案中只取一个答案,让大模型理解文本。所以需要将文本、问题、答案,拆分为4条数据喂给大模型,在告知它正确答案,这样处理大模型才能读懂数据,也是它读取数据逻辑。

代码部分

# 导入包
import evaluate
import numpy as np
from datasets import DatasetDict
from transformers import AutoTokenizer, AutoModelForMultipleChoice, TrainingArguments, Trainer# 读取数据集
c3 = DatasetDict.load_from_disk("./c3/")# 打印训练的前条数据开下接口
examples=c3["train"][:5]print(examples)# 将数据处理为训练和测试
c3.pop("test")# 加载分词器,模型已下载到本地
tokenizer = AutoTokenizer.from_pretrained("chinese-macbert-large")# 对数据集进行处理,验证数据处理阶段
question_choice = []
labels = []
for idx in range(len(examples["context"])):ctx = "\n".join(examples["context"][idx])question = examples["question"][idx]choices = examples["choice"][idx]for choice in choices:context.append(ctx)question_choice.append(question + " " + choice)if len(choices) < 4:for _ in range(4 - len(choices)):context.append(ctx)question_choice.append(question + " " + "不知道")print("========:", choices.index(examples["answer"][idx]))labels.append(choices.index(examples["answer"][idx]))
tokenized_examples = tokenizer(context, question_choice, truncation="only_first", max_length=256, padding="max_length")     # input_ids: 4000 * 256,
for k, v in tokenized_examples.items():print("k:", k, "v:", v)
tokenized_examples = {k: [v[i: i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}     # 1000 * 4 *256
tokenized_examples["labels"] = labels# 处理数据函数
def process_function(examples):# examples, dict, keys: ["context", "quesiton", "choice", "answer"]# examples, 1000context = []question_choice = []labels = []for idx in range(len(examples["context"])):ctx = "\n".join(examples["context"][idx])question = examples["question"][idx]choices = examples["choice"][idx]for choice in choices:context.append(ctx)question_choice.append(question + " " + choice)if len(choices) < 4:for _ in range(4 - len(choices)):context.append(ctx)question_choice.append(question + " " + "不知道")labels.append(choices.index(examples["answer"][idx]))# 使用分词器,对数据进行分词处理    tokenized_examples = tokenizer(context, question_choice, truncation="only_first", max_length=256, padding="max_length")     # input_ids: 4000 * 256,tokenized_examples = {k: [v[i: i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}     # 1000 * 4 *256tokenized_examples["labels"] = labelsreturn tokenized_examples# 处理数据
tokenized_c3 = c3.map(process_function, batched=True)# 加载模型
model = AutoModelForMultipleChoice.from_pretrained("chinese-macbert-large")# 创建评估函数
accuracy = evaluate.load("accuracy")def compute_metric(pred):predictions, labels = predpredictions = np.argmax(predictions, axis=-1)return accuracy.compute(predictions=predictions, references=labels)# 配置训练参数
args = TrainingArguments(output_dir="./muliple_choice",per_device_train_batch_size=16,per_device_eval_batch_size=16,num_train_epochs=1,logging_steps=50,eval_strategy="epoch",save_strategy="epoch",load_best_model_at_end=True,fp16=True
) # 创建训练器 
trainer = Trainer(model=model,args=args,tokenizer=tokenizer,train_dataset=tokenized_c3["train"],eval_dataset=tokenized_c3["validation"],compute_metrics=compute_metric
)  # 模型训练
trainer.train()# 模型预测
from typing import Any
import torchclass MultipleChoicePipeline:def __init__(self, model, tokenizer) -> None:self.model = modelself.tokenizer = tokenizerself.device = model.devicedef preprocess(self, context, quesiton, choices):cs, qcs = [], []for choice in choices:cs.append(context)qcs.append(quesiton + " " + choice)return tokenizer(cs, qcs, truncation="only_first", max_length=256, return_tensors="pt")def predict(self, inputs):inputs = {k: v.unsqueeze(0).to(self.device) for k, v in inputs.items()}return self.model(**inputs).logitsdef postprocess(self, logits, choices):predition = torch.argmax(logits, dim=-1).cpu().item()return choices[predition]def __call__(self, context, question, choices) -> Any:inputs = self.preprocess(context, question, choices)logits = self.predict(inputs)result = self.postprocess(logits, choices)return resultpipe = MultipleChoicePipeline(model, tokenizer)
pipe("小明在北京上班", "小明在哪里上班?", ["北京", "上海", "河北", "海南", "河北", "海南"])

以上就是完整多项选择阅读理解的大模型训练代码

问题

1,这里的难点就是多项选择如何让大模型进行阅读理解训练思想,这里参考的就是语文里的阅读理解。
2,将数据处理成什么样子,大模型才能理解,才能去进行正确的训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/62310.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Java 学习】面向程序的三大特性:封装、继承、多态

引言 1. 封装1.1 什么是封装呢&#xff1f;1.2 访问限定符1.3 使用封装 2. 继承2.1 为什么要有继承&#xff1f;2.2 继承的概念2.3 继承的语法2.4 访问父类成员2.4.1 子类中访问父类成员的变量2.4.2 访问父类的成员方法 2.5 super关键字2.6 子类的构造方法 3. 多态3.1 多态的概…

impala入门与实践

1.impala基本介绍 impala是cloudera提供的一款高效率的sql查询工具&#xff0c;提供实时的查询效果&#xff0c;官方测试性能比hive快10到100倍&#xff0c;其sql查询比sparkSQL还要更加快速&#xff0c;号称是当前大数据领域最快的查询sql工具。impala是参照谷歌的新三篇论文…

shell查看服务器的内存和CPU,实时使用情况

要查看服务器的内存和 CPU 实时使用情况&#xff0c;可以使用以下方法和命令&#xff1a; 1. 使用 top 运行 top 命令以显示实时的系统性能信息&#xff0c;包括 CPU 和内存使用情况。 top按 q 退出。输出内容包括&#xff1a; CPU 使用率&#xff1a;位于顶部&#xff0c;标…

java中链表的数据结构的理解

在 Java 中&#xff0c;链表是一种常见的数据结构&#xff0c;可以通过类的方式实现自定义链表。以下是关于 Java 中链表的数据结构和实现方式的详细介绍。 1. 自定义链表结构 Java 中链表通常由一个节点类 (ListNode) 和可能的链表操作类构成。 节点类 (ListNode) 这是链表…

结构方程模型(SEM)入门到精通:lavaan VS piecewiseSEM、全局估计/局域估计;潜变量分析、复合变量分析、贝叶斯SEM在生态学领域应用

目录 第一章 夯实基础 R/Rstudio简介及入门 第二章 结构方程模型&#xff08;SEM&#xff09;介绍 第三章 R语言SEM分析入门&#xff1a;lavaan VS piecewiseSEM 第四章 SEM全局估计&#xff08;lavaan&#xff09;在生态学领域高阶应用 第五章 SEM潜变量分析在生态学领域…

2.mybatis整体配置

文章目录 mybatis-config.xml介绍SqlSessionFactoryBuilderXMLConfigBuilderpropertiessetting类型别名&#xff08;typeAliases&#xff09;扫描插件(plugins)解析objectFactory(对象工厂)解析objectWrapperFactory解析reflectorFactorysettingsElement()方法环境配置&#xf…

把本地新项目初始化传到github

在本地项目根目录下初始化Git仓库 git init将项目文件添加到Git仓库,接下来&#xff0c;你需要将项目中的文件添加到Git仓库中。可以使用git add命令来添加文件或目录。如果你想要添加所有文件&#xff0c;可以使用.来表示当前目录中的所有文件&#xff1a; git add .提交项目…

软件测试丨Pytest 第三方插件与 Hook 函数

Pytest不仅是一个用于编写简单和复杂测试的框架&#xff0c;还有大量的第三方插件以及灵活的Hook函数供我们使用&#xff0c;这些功能大大增强了其在软件测试中的应用。通过使用Pytest&#xff0c;测试开发变得简便、安全、高效&#xff0c;同时也能帮助我们更快地修复Bug&…

小米PC电脑手机互联互通,小米妙享,小米电脑管家,老款小米笔记本怎么使用,其他品牌笔记本怎么使用,一分钟教会你

说在前面 之前我们体验过妙享中心&#xff0c;里面就有互联互通的全部能力&#xff0c;现在有了小米电脑管家&#xff0c;老款的笔记本竟然用不了&#xff0c;也可以理解&#xff0c;毕竟老款笔记本做系统研发的时候没有预留适配的文件补丁&#xff0c;至于其他品牌的winPC小米…

python爬虫案例——猫眼电影数据抓取之字体解密,多套字体文件解密方法(20)

文章目录 1、任务目标2、网站分析3、代码编写1、任务目标 目标网站:猫眼电影(https://www.maoyan.com/films?showType=2) 要求:抓取该网站下,所有即将上映电影的预约人数,保证能够获取到实时更新的内容;如下: 2、网站分析 进入目标网站,打开开发者模式,经过分析,我…

一分钟食用前端测试框架Jest

安装 其实食用Jest是很简单的,我们只需要安装Jest即可 npm install --save-dev jestyarn add --dev jestpnpm add --save-dev jest ESmodule 本身来说,Jest是不支持Esmodule的,他支持CommonJS,我们需要Babel改一下 npm i --save-dev babel-jest babel/core babel/preset-env …

MySQL中的ROW_NUMBER窗口函数简单了解下

ROW_NUMBER() 是 MySQL8引入的窗口函数之一&#xff0c;它为查询结果集中的每一行分配一个唯一的顺序号&#xff08;行号&#xff09;。这个顺序号是基于窗口函数的 ORDER BY 子句进行排序的&#xff0c;可以根据指定的排序顺序生成连续的整数值。 ROW_NUMBER() 在分页、去重、…

从 App Search 到 Elasticsearch — 挖掘搜索的未来

作者&#xff1a;来自 Elastic Nick Chow App Search 将在 9.0 版本中停用&#xff0c;但 Elasticsearch 拥有你构建强大的 AI 搜索体验所需的一切。以下是你需要了解的内容。 生成式人工智能的最新进展正在改变用户行为&#xff0c;激励开发人员创造更具活力、更直观、更引人入…

CTF之密码学(费纳姆密码)

一、作为二进制替换密码的费纳姆密码 定义&#xff1a;费纳姆密码是一种由二进制产生的替换密码&#xff0c;也被称为弗纳姆密码&#xff08;Vernam cipher&#xff09;。它采用二进制表示法&#xff0c;将明文转化为二进制数字&#xff0c;并通过与密钥进行模2加法运算来产生密…

若依框架部署在网站一个子目录下(/admin)问题(

部署在子目录下首先修改vue.config.js文件&#xff1a; 问题一&#xff1a;登陆之后跳转到了404页面问题&#xff0c;解决办法如下&#xff1a; src/router/index.js 把404页面直接变成了首页&#xff08;大佬有啥优雅的解决办法求告知&#xff09; 问题二&#xff1a;退出登录…

【贪心算法第六弹——334.递增的三元子序列(easy)】

目录 1.题目解析 题目来源 测试用例 2.算法原理 3.实战代码 代码解析 本题属于最长递增子序列的简化版本&#xff0c;只需要判断能不能组成三位的递增子序列即可&#xff0c;建议先去看博主的另一篇博客可以更好的理解本篇博客&#xff1a;300.最长递增子序列 1.题目解析…

c++(斗罗大陆)

这次&#xff0c;作者编了斗罗大陆的武魂、魂力等级&#xff0c;目前只写到了11级 #include<iostream> #include<conio.h> #include<windows.h> #include<stdlib.h> #include<stdio.h> #include<time.h> #include<strin…

Perl编程语言简介

文章目录 前言基础语法一. 变量1. 标量2. 数组3. 哈希 二. 控制结构1. 条件语句2. 循环语句 三. 自定义函数与作用域四. 文件读写五. 正则表达式1. 模式匹配2. 替换 六. 模块和包七. 面向对象的编程 总结 前言 Perl&#xff08;Practical Extraction and Report Language&…

《TCP/IP网络编程》学习笔记 | Chapter 16:关于 I/O 流分离的其他内容

《TCP/IP网络编程》学习笔记 | Chapter 16&#xff1a;关于 I/O 流分离的其他内容 《TCP/IP网络编程》学习笔记 | Chapter 16&#xff1a;关于 I/O 流分离的其他内容分离 I/O 流2 次 I/O 流分离分离「流」的好处「流」分离带来的 EOF 问题 文件描述符的的复制和半关闭终止「流」…

LeetCode数组题

参考链接 代码随想录 讲解视频链接 数组题 1、(两数之和)给定一个整数数组 nums 和一个整数目标值 target&#xff0c;请你在该数组中找出 和为目标值 target 的那 两个 整数&#xff0c;并返回它们的数组下标。你可以假设每种输入只会对应一个答案&#xff0c;并且你不能使用…