基于transformers框架实践Bert系列3-单选题

本系列用于Bert模型实践实际场景,分别包括分类器、命名实体识别、选择题、文本摘要等等。(关于Bert的结构和详细这里就不做讲解,但了解Bert的基本结构是做实践的基础,因此看本系列之前,最好了解一下transformers和Bert等)
本篇主要讲解单选题应用场景。本系列代码和数据集都上传到GitHub上:https://github.com/forever1986/bert_task

目录

  • 1 环境说明
  • 2 前期准备
    • 2.1 了解Bert的输入输出
    • 2.2 数据集与模型
    • 2.3 任务说明
    • 2.4 实现关键
  • 3 关键代码
    • 3.1 数据集处理
    • 3.2 模型加载
    • 3.3 评估函数
  • 4 整体代码
  • 5 运行效果

1 环境说明

1)本次实践的框架采用torch-2.1+transformer-4.37
2)另外还采用或依赖其它一些库,如:evaluate、pandas、datasets、accelerate等

2 前期准备

Bert模型是一个只包含transformer的encoder部分,并采用双向上下文和预测下一句训练而成的预训练模型。可以基于该模型做很多下游任务。

2.1 了解Bert的输入输出

Bert的输入:input_ids(使用tokenizer将句子向量化),attention_mask,token_type_ids(句子序号)、labels(结果)
Bert的输出:
last_hidden_state:最后一层encoder的输出;大小是(batch_size, sequence_length, hidden_size)
pooler_output:这是序列的第一个token(classification token)的最后一层的隐藏状态,输出的大小是(batch_size, hidden_size),它是由线性层和Tanh激活函数进一步处理的。(通常用于句子分类,至于是使用这个表示,还是使用整个输入序列的隐藏状态序列的平均化或池化,视情况而定)。(注意:这是关键输出,本次选择任务就需要获取该值,并进行一次线性层处理
hidden_states: 这是输出的一个可选项,如果输出,需要指定config.output_hidden_states=True,它也是一个元组,它的第一个元素是embedding,其余元素是各层的输出,每个元素的形状是(batch_size, sequence_length, hidden_size)
attentions:这是输出的一个可选项,如果输出,需要指定config.output_attentions=True,它也是一个元组,它的元素是每一层的注意力权重,用于计算self-attention heads的加权平均值。

2.2 数据集与模型

1)数据集来自:c3
2)模型权重使用:bert-base-chinese

2.3 任务说明

1)单选题其实就是根据给定的文本内容、题目、选项,最终返回一个得分最高的选项。可以当做一个文本匹配,也可以当做一个分类问题,这里使用分类方式来解决
2)将内容、题目、选项(假设是4个选项)都组装成4条句子,如下:
在这里插入图片描述
然后通过每条句子做一个classifier分类,得到的结果在通过softmax取出最佳的结果。

2.4 实现关键

1)每道题都将其组装成4个句子,如果不够4个选项,补充到4个选项,可以使用“不知道”选项补充
2)将所有句子都放入bert选项,增加一个线性层,将其结果映射为41,将通过转换为14矩阵,通过softmax得到最佳答案

3 关键代码

3.1 数据集处理

# 1. 将每条问题的contex、question分别和4个选项组成4条句子,也就是4倍的datas数量*max_length
# 2. 将输入内容变成 datas数量*4*max_length
def process_function(datas):context = []question_choice = []labels = []for idx in range(len(datas["context"])):ctx = "\n".join(datas["context"][idx])question = datas["question"][idx]choices = datas["choice"][idx]for choice in choices:context.append(ctx)question_choice.append(question + " " + choice)if len(choices) < 4:for _ in range(4 - len(choices)):context.append(ctx)question_choice.append(question + " " + "不知道")labels.append(choices.index(datas["answer"][idx]))tokenized_datas = tokenizer(context, question_choice, truncation="only_first", max_length=256, padding="max_length")# 将原先的4个选项都是自身组成一条句子,token之后是一个2维的向量:4倍的datas数量*max_length,这里需要将其变成3维:datas数量*4*max_lengthtokenized_datas = {key: [data[i: i + 4] for i in range(0, len(data), 4)] for key, data in tokenized_datas.items()}tokenized_datas["labels"] = labelsreturn tokenized_datas

3.2 模型加载

model = BertForMultipleChoice.from_pretrained(model_path)

注意:这里使用的是transformers中的BertForMultipleChoice,该类对bert模型进行封装。如果我们不使用该类,需要自己定义一个model,继承bert,增加分类线性层。另外使用AutoModelForMultipleChoice也可以,其实AutoModel最终返回的也是BertForMultipleChoice,它是根据你config中的model_type去匹配的。
这里列一下BertForMultipleChoice的关键源代码说明一下transformers帮我们做了哪些关键事情

# 在__init__方法中增加增加了线性层,其映射为1
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, 1)
# 1. 获取到num_choices的数量,因为数据处理后,input_ids的shape是datas数量*4*max_length
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
# 2.将input_ids、attention_mask、token_type_ids、position_ids 都转换为2维,因为bert接受的是2为向量,因此又变成4倍的datas数量*max_length
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
inputs_embeds = (inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))if inputs_embeds is not Noneelse None
)outputs = self.bert(input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,
)
# 3. 获取pooler_output,这时它的shape是4倍的datas数量*hidden_size
pooled_output = outputs[1]pooled_output = self.dropout(pooled_output)
# 4. 通过一个线性层,将结果变成(4倍的datas数量*1)的矩阵
logits = self.classifier(pooled_output)
# 5. 通过view,将其变成(datas数量*4)
reshaped_logits = logits.view(-1, num_choices)

3.3 评估函数

这里采用evaluate库加载accuracy准确度计算方式来做评估,本次实验将accuracy的计算py文件下载下来,因此也是本地加载

# 评估函数:此处的评估函数可以从https://github.com/huggingface/evaluate下载到本地
accuracy = evaluate.load("./evaluate/metric_accuracy.py")def evaluate_function(prepredictions):predictions, labels = prepredictionspredictions = numpy.argmax(predictions, axis=-1)return accuracy.compute(predictions=predictions, references=labels)

4 整体代码

"""
基于BERT做单选题
1)数据集来自:c3
2)模型权重使用:bert-base-chinese
"""
# step 1 引入数据库
import numpy
import torch
import evaluate
from typing import Any
from datasets import DatasetDict
from transformers import BertForMultipleChoice, TrainingArguments, Trainer, BertTokenizerFastmodel_path = "./model/tiansz/bert-base-chinese"
data_path = "data/c3"# step 2 数据集处理
datasets = DatasetDict.load_from_disk(data_path)
# test数据集没有答案answer,因此去除,也不做模型评估
datasets.pop("test")
tokenizer = BertTokenizerFast.from_pretrained(model_path)def process_function(datas):context = []question_choice = []labels = []for idx in range(len(datas["context"])):ctx = "\n".join(datas["context"][idx])question = datas["question"][idx]choices = datas["choice"][idx]for choice in choices:context.append(ctx)question_choice.append(question + " " + choice)if len(choices) < 4:for _ in range(4 - len(choices)):context.append(ctx)question_choice.append(question + " " + "不知道")labels.append(choices.index(datas["answer"][idx]))tokenized_datas = tokenizer(context, question_choice, truncation="only_first", max_length=256, padding="max_length")# 将原先的4个选项都是自身组成一条句子,token之后是一个2维的向量:4倍的datas数量*max_length,这里需要将其变成3维:datas数量*4*max_lengthtokenized_datas = {key: [data[i: i + 4] for i in range(0, len(data), 4)] for key, data in tokenized_datas.items()}tokenized_datas["labels"] = labelsreturn tokenized_datasnew_datasets = datasets.map(process_function, batched=True)# step 3 加载模型
model = BertForMultipleChoice.from_pretrained(model_path)# step 4 评估函数:此处的评估函数可以从https://github.com/huggingface/evaluate下载到本地
accuracy = evaluate.load("./evaluate/metric_accuracy.py")def evaluate_function(prepredictions):predictions, labels = prepredictionspredictions = numpy.argmax(predictions, axis=-1)return accuracy.compute(predictions=predictions, references=labels)# step 5 创建TrainingArguments
# train是11869条数据,batch_size=16,因此每个epoch的step=742,总step=2226
train_args = TrainingArguments(output_dir="./checkpoints",      # 输出文件夹per_device_train_batch_size=16,  # 训练时的batch_sizeper_device_eval_batch_size=16,    # 验证时的batch_sizenum_train_epochs=3,              # 训练轮数logging_steps=100,                # log 打印的频率evaluation_strategy="epoch",     # 评估策略save_strategy="epoch",           # 保存策略save_total_limit=3,              # 最大保存数load_best_model_at_end=True      # 训练完成后加载最优模型)# step 6 创建Trainer
trainer = Trainer(model=model,args=train_args,train_dataset=new_datasets["train"],eval_dataset=new_datasets["validation"],compute_metrics=evaluate_function,)# step 7 训练
trainer.train()# step 8 模型预测
class MultipleChoicePipeline:def __init__(self, model, tokenizer) -> None:self.model = modelself.tokenizer = tokenizerself.device = model.devicedef preprocess(self, context, quesiton, choices):cs, qcs = [], []for choice in choices:cs.append(context)qcs.append(quesiton + " " + choice)return tokenizer(cs, qcs, truncation="only_first", max_length=256, return_tensors="pt")def predict(self, inputs):inputs = {k: v.unsqueeze(0).to(self.device) for k, v in inputs.items()}return self.model(**inputs).logitsdef postprocess(self, logits, choices):predition = torch.argmax(logits, dim=-1).cpu().item()return choices[predition]def __call__(self, context, question, choices) -> Any:inputs = self.preprocess(context, question, choices)logits = self.predict(inputs)result = self.postprocess(logits, choices)return resultpipe = MultipleChoicePipeline(model, tokenizer)
res = pipe("男:还是古典音乐好听,我就受不了摇滚乐,实在太吵了。女:那是因为你没听过现场,流行乐和摇滚乐才有感觉呢。","男的喜欢什么样的音乐?", ["古典音乐", "摇滚音乐", "流行音乐", "乡村音乐"])
print(res)

5 运行效果

在这里插入图片描述

注:本文参考来自大神:https://github.com/zyds/transformers-code

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/14815.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【JavaEE】加法计算器与用户登录实战演练

目录 综合练习加法计算器1. 准备工作2. 约定前后端交互接口3. 服务器代码 用户登录1. 准备工作2. 约定前后端交互接口3. 服务器代码4. 调整前端页面代码 综合练习 理解前后端交互过程接⼝传参, 数据返回, 以及⻚⾯展⽰ 加法计算器 需求: 输⼊两个整数, 点击"点击相加&q…

56. UE5 RPG 给敌人添加AI实现跟随玩家

在这一篇里&#xff0c;我们要实现一下敌人的AI&#xff0c;敌人也需要一系列的行为&#xff0c;比如朝向英雄攻击&#xff0c;移动&#xff0c;在满足条件时施放技能。这些敌人的行为可以通过使用UE的内置的AI系统去实现。 在UE里&#xff0c;只要是基于Character类创建的蓝图…

安卓绕过限制直接使用Android/data无需授权,支持安卓14(部分)

大家都知道&#xff0c;安卓每次更新都会给权限划分的更细、收的更紧。   早在安卓11的时候还可以直接通过授权Android/data来实现操作其他软件的目录&#xff0c;没有之前安卓11授权的图了&#xff0c;反正都长一个样&#xff0c;就直接贴新图了。   后面到了安卓12~13的…

信息系统项目管理师0128:输出(8项目整合管理—8.6管理项目知识—8.6.3输出)

点击查看专栏目录 文章目录 8.6.3 输出 8.6.3 输出 经验教训登记册 经验教训登记册可以包含执行情况的类别和详细的描述&#xff0c;还可包括与执行情况相关的影响、建议和行动方案。经验教训登记册可以记录遇到的挑战、问题、意识到的风险和机会以及其他适用的内容。经验教训…

Debezium+Kafka:Oracle 11g 数据实时同步至 DolphinDB 解决方案

随着越来越多用户使用 DolphinDB&#xff0c;各式各样的应用场景对 DolphinDB 的数据接入提出了不同的要求。部分用户需要将 Oracle 11g 的数据实时同步到 DolphinDB 中来&#xff0c;以满足在 DolphinDB 中实时使用数据的需求。本篇教程将介绍使用 Debezium 来实时捕获和发布 …

npm介绍、常用命令详解以及什么是全局目录

目录 npm介绍、常用命令详解以及什么是全局目录一、介绍npm的主要功能npm仓库npm的配置npm的版本控制 二、命令1. npm init: 初始化一个新的Node.js项目&#xff0c;创建package.json文件。package.json是一个描述项目信息和依赖关系的文件。2. npm install <package_name&g…

LeetCode算法题:42. 接雨水(Java)

题目描述 给定 n 个非负整数表示每个宽度为 1 的柱子的高度图&#xff0c;计算按此排列的柱子&#xff0c;下雨之后能接多少雨水。 示例 1&#xff1a; 输入&#xff1a;height [0,1,0,2,1,0,1,3,2,1,2,1] 输出&#xff1a;6 解释&#xff1a;上面是由数组 [0,1,0,2,1,0,1,3…

基于vue3速学angular

因为工作原因&#xff0c;需要接手新的项目&#xff0c;新的项目是angular框架的&#xff0c;自学下和vue3的区别&#xff0c;写篇博客记录下&#xff1a; 参考&#xff1a;https://zhuanlan.zhihu.com/p/546843290?utm_id0 1.结构上&#xff1a; vue3:一个vue文件&#xff…

python:pycharm虚拟解释器报错环境位置目录为空

目录 解释器分控制台解释器 和 pycharm解释器 控制台解释器切换&#xff1a; pycharm解释器 解释器分控制台解释器 和 pycharm解释器 控制台解释器切换&#xff1a; 切换到解释器下 激活解释器 查看解释器 where python 激活成功 这时在控制台使用python xxx.py 可以…

​​​【收录 Hello 算法】10.1 二分查找

目录 10.1 二分查找 10.1.1 区间表示方法 10.1.2 优点与局限性 10.1 二分查找 二分查找&#xff08;binary search&#xff09;是一种基于分治策略的高效搜索算法。它利用数据的有序性&#xff0c;每轮缩小一半搜索范围&#xff0c;直至找到目标元素或搜索区间为空…

​​​【收录 Hello 算法】第 10 章 搜索

目录 第 10 章 搜索 本章内容 第 10 章 搜索 搜索是一场未知的冒险&#xff0c;我们或许需要走遍神秘空间的每个角落&#xff0c;又或许可以快速锁定目标。 在这场寻觅之旅中&#xff0c;每一次探索都可能得到一个未曾料想的答案。 本章内容 10.1 二分查找10.2 二…

恶劣天候激光雷达点云模拟方法论文整理

恶劣天候点云模拟方法论文整理 模拟雨天点云&#xff1a;【AAAI2024】模拟雪天点云&#xff1a;【CVPR 2022 oral】模拟雾天点云&#xff1a;【ICCV2021】模拟点云恶劣天候的散射现象&#xff1a;【Arxiv 2021】模拟积水地面的水花飞溅点云&#xff1a;【RAL2022】 模拟雨天点云…

蓝桥杯Web开发【模拟题三】15届

1.创意广告牌 在"绮幻山谷"的历史和"梦幻海湾"的繁华交汇之处&#xff0c;一块创意广告牌傲然矗立。它以木质纹理的背景勾勒出古朴氛围&#xff0c;上方倾斜的牌子写着"绮幻山谷的风吹到了梦幻海湾"&#xff0c;瞬间串联了过去与现在&#xff0…

EPIC免费领取《骑士精神2》 IGN9分神作骑士精神2限时免费领

EPIC免费领取《骑士精神2》 IGN9分神作骑士精神2限时免费领 最近Epic一直为玩家们送出各种游戏&#xff0c;从《龙腾世纪审判》到《模拟农场22》&#xff0c;而就在今天&#xff0c;epic又为玩家们送出了IGN评分9分高分的骑士精神2.这款游戏&#xff0c;该游戏是一款由Tripwir…

vue连接mqtt实现收发消息组件超级详细

基本概念&#xff1a; MQTT&#xff08;Message Queuing Telemetry Transport&#xff09;是一种基于发布/订阅模式的轻量级消息传输协议&#xff0c;专为低带宽、高延迟或不稳定的网络环境设计。以下是MQTT实现收发消息的基本原理&#xff1a; 客户端-服务器模型&#xff1a…

数据量较小的表是否有必要添加索引问题分析

目录 前言一、分析前准备1.1、准备测试表和数据1.2、插入测试数据1.3、测试环境说明 二、具体业务分析2.1、单次查询耗时分析2.2、无索引并发查询服务器CPU占用率分析2.3、添加索引并发查询服务器CPU占用率分析 三、总结 前言 在一次节日活动我们系统访问量到达了平时的两倍&am…

【小沐学GIS】GDAL库安装和使用(C++、Python)

文章目录 1、简介2、下载和编译&#xff08;C&#xff09;2.1 二进制构建2.1.1 Conda2.1.2 Vcpkg 2.2 源代码构建2.2.1 nmake.opt方式构建2.2.2 generate_vcxproj.bat方式构建 2.3 命令行测试2.3.1 获取S57海图数据 2.4 代码测试2.4.1 读取tiff信息 3、Python3.1 安装3.2 测试3…

零基础入门篇④ 初识Python(注释、编码规范、关键字...)

Python从入门到精通系列专栏面向零基础以及需要进阶的读者倾心打造,9.9元订阅即可享受付费专栏权益,一个专栏带你吃透Python,专栏分为零基础入门篇、模块篇、网络爬虫篇、Web开发篇、办公自动化篇、数据分析篇…学习不断,持续更新,火热订阅中🔥专栏订阅地址 👉Python从…

C语言 | Leetcode C语言题解之第110题平衡二叉树

题目&#xff1a; 题解&#xff1a; int height(struct TreeNode* root) {if (root NULL) {return 0;}int leftHeight height(root->left);int rightHeight height(root->right);if (leftHeight -1 || rightHeight -1 || fabs(leftHeight - rightHeight) > 1) {…

Android硬件渲染环境初始化

Android硬件渲染环境初始化 一.硬件加速渲染的开启1.ThreadedRenderer的初始化2.RenderProxy的创建 二.RenderProxy中组件的初始化1.RenderThread的创建2.CanvasContext的创建3.DrawFrameTask的初始化 三.RenderThread的启动1.RenderThread中组件的初始化2.RenderThread中任务的…