基于transformers框架实践Bert系列6-完形填空

本系列用于Bert模型实践实际场景,分别包括分类器、命名实体识别、选择题、文本摘要等等。(关于Bert的结构和详细这里就不做讲解,但了解Bert的基本结构是做实践的基础,因此看本系列之前,最好了解一下transformers和Bert等)
本篇主要讲解完形填空应用场景。本系列代码和数据集都上传到GitHub上:https://github.com/forever1986/bert_task

1 环境说明

1)本次实践的框架采用torch-2.1+transformer-4.37
2)另外还采用或依赖其它一些库,如:evaluate、pandas、datasets、accelerate等

2 前期准备

Bert模型是一个只包含transformer的encoder部分,并采用双向上下文和预测下一句训练而成的预训练模型。可以基于该模型做很多下游任务。

2.1 了解Bert的输入输出

Bert的输入:input_ids(使用tokenizer将句子向量化),attention_mask,token_type_ids(句子序号)、labels(结果)
Bert的输出:
last_hidden_state:最后一层encoder的输出;大小是(batch_size, sequence_length, hidden_size)(注意:这是关键输出,本次任务就需要获取该值,可以取出那个被mask掉的token,获取其前几个,取score最高的(当然也可以使用top_k或者top_p方式获取一定随机性)
pooler_output:这是序列的第一个token(classification token)的最后一层的隐藏状态,输出的大小是(batch_size, hidden_size),它是由线性层和Tanh激活函数进一步处理的。(通常用于句子分类,至于是使用这个表示,还是使用整个输入序列的隐藏状态序列的平均化或池化,视情况而定)。
hidden_states: 这是输出的一个可选项,如果输出,需要指定config.output_hidden_states=True,它也是一个元组,它的第一个元素是embedding,其余元素是各层的输出,每个元素的形状是(batch_size, sequence_length, hidden_size)
attentions:这是输出的一个可选项,如果输出,需要指定config.output_attentions=True,它也是一个元组,它的元素是每一层的注意力权重,用于计算self-attention heads的加权平均值。

2.2 数据集与模型

1)数据集来自:ChnSentiCorp(该数据集本身是做情感分类,但是我们只需要取其text部分即可)
2)模型权重使用:bert-base-chinese

2.3 任务说明

完形填空其实就是在一段文字中mask掉几个字,让模型能够自动填充字。这里本身就是bert模型做预训练是所做的事情之一,因此就是让数据给模型做训练的过程。

2.4 实现关键

1)数据集结构是一个带有text和label两列的数据,我们只需要获取到text部分即可。
在这里插入图片描述
2)随机mask掉部分数据,这个本身也是bert的训练过程,因此在transforms框架中DataCollatorForLanguageModeling已经实现了,你也可以自己实现随机mask掉你的数据进行训练

3 关键代码

3.1 数据集处理

数据集不需要做过多处理,只需要将text部分进行tokenizer,并制定max_length和truncation即可

def process_function(datas):tokenized_datas = tokenizer(datas["text"], max_length=256, truncation=True)return tokenized_datas
new_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)

3.2 模型加载

model = BertForMaskedLM.from_pretrained(model_path)

注意:这里使用的是transformers中的BertForMaskedLM,该类对bert模型进行封装。如果我们不使用该类,需要自己定义一个model,继承bert,增加分类线性层。另外使用AutoModelForMaskedLM也可以,其实AutoModel最终返回的也是BertForMaskedLM,它是根据你config中的model_type去匹配的。
这里列一下BertForMaskedLM的关键源代码说明一下transformers帮我们做了哪些关键事情

# 在__init__方法中增加增加了BertOnlyMLMHead,BertOnlyMLMHead其实就是一个二层神经网络,一层是BertPredictionHeadTransform(包括linear+geluAct+ln),一层是decoder(hidden_size*vocab_size大小的linear)。
self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config)
# 将输出结果outputs取第一个返回值,也就是last_hidden_state
sequence_output = outputs[0]
# 将last_hidden_state输入到cls层中,获得最终结果(预测的score和词)
prediction_scores = self.cls(sequence_output)

3.3 自动并随机mask数据

关键代码在于DataCollatorForLanguageModeling,该类会实现自动mask。参考torch_mask_tokens方法。

trainer = Trainer(model=model,args=train_args,train_dataset=new_datasets["train"],data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15),)

4 整体代码

"""
基于BERT做完形填空
1)数据集来自:ChnSentiCorp
2)模型权重使用:bert-base-chinese
"""
# step 1 引入数据库
from datasets import DatasetDict
from transformers import TrainingArguments, Trainer, BertTokenizerFast, BertForMaskedLM, DataCollatorForLanguageModeling, pipelinemodel_path = "./model/tiansz/bert-base-chinese"
data_path = "./data/ChnSentiCorp"# step 2 数据集处理
datasets = DatasetDict.load_from_disk(data_path)
tokenizer = BertTokenizerFast.from_pretrained(model_path)def process_function(datas):tokenized_datas = tokenizer(datas["text"], max_length=256, truncation=True)return tokenized_datasnew_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)# step 3 加载模型
model = BertForMaskedLM.from_pretrained(model_path)# step 4 创建TrainingArguments
# 原先train是9600条数据,batch_size=32,因此每个epoch的step=300
train_args = TrainingArguments(output_dir="./checkpoints",      # 输出文件夹per_device_train_batch_size=32,  # 训练时的batch_sizenum_train_epochs=1,              # 训练轮数logging_steps=30,                # log 打印的频率)# step 5 创建Trainer
trainer = Trainer(model=model,args=train_args,train_dataset=new_datasets["train"],# 自动MASK关键所在,通过DataCollatorForLanguageModeling实现自动MASK数据data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15),)# Step 6 模型训练
trainer.train()# step 7 模型评估
pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, device=0)
str = datasets["test"][3]["text"]
str = str.replace("方便","[MASK][MASK]")
results = pipe(str)
# results[0][0]["token_str"]
print(results[0][0]["token_str"]+results[1][0]["token_str"])

5 运行效果

在这里插入图片描述

注:本文参考来自大神:https://github.com/zyds/transformers-code

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/13512.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

自己动手写docker——Namespace

Linux Namespace linux Namespace用于隔离一系列的系统资源,例如pid,userid,netword等,借助于Linux Namespace,可以实现容器的基本隔离。 Namespce介绍 Namespace类型系统调用参数作用Mount NamespaceCLONE_NEWNS隔离…

Python筑基之旅-MySQL数据库(一)

目录 一、MySQL数据库 1、简介 2、优点 2-1、开源和免费 2-2、高性能 2-3、可扩展性 2-4、易用性 2-5、灵活性 2-6、安全性和稳定性 2-7、丰富的功能 2-8、结合其他工具和服务 2-9、良好的兼容性和移植性 3、缺点 3-1、对大数据的支持有限 3-2、缺乏全文…

微服务如何做好监控

大家好,我是苍何。 在脉脉上看到这条帖子,说阿里 P8 因为上面 P9 斗争失败走人,以超龄 35 被裁,Boss 上找工作半年,到现在还处于失业中。 看了下沟通记录, 沟通了 1000 多次,但没有一个邀请投递…

uniapp中使用 iconfont字体

下载 iconfont 字体文件 打开 iconfont.css 文件,修改一下 把文件 复制到 static/iconfont/… 目录下 在App.vue中引入iconfont 5. 使用iconfont 使用 iconfont 有两种方式, 一种是 class 方式, 一种是使用 unicode 的方式 5.1 使用 class 的…

【Mac】Dreamweaver 2021 for mac v21.3 Rid中文版安装教程

软件介绍 Dreamweaver是Adobe公司开发的一款专业网页设计与前端开发软件。它集成了所见即所得(WYSIWYG)编辑器和代码编辑器,可以帮助开发者快速创建和编辑网页。Dreamweaver提供了丰富的功能和工具,包括代码提示、语法高亮、代码…

51单片机学习(1)2-1点亮一个LED

#include <REGX52.H> void() { p20xFE;//1111 1110 while(1) { //让程序停了下来了。 } }

教你一分钟搭建适合IT人员的在线开发工具箱

文章目录 1. 使用Docker本地部署it-tools2. 本地访问it-tools3. 安装cpolar内网穿透4. 固定it-tools公网地址 本篇文章将介绍如何在Windows上使用Docker本地部署IT- Tools&#xff0c;并且同样可以结合cpolar实现公网访问。 在前一篇文章中我们讲解了如何在Linux中使用Docker搭…

Anaconda Jupyter 报错及解决方法记录

一、AttributeError: module lib has no attribute X509_V_FLAG_CB_ISSUER_CHECK 背景&#xff1a;Anaconda更新版本后&#xff0c;运行import oss2时报错 ~/anaconda3/lib/python3.8/site-packages/OpenSSL/crypto.py in X509StoreFlags() 1535 NOTIFY_POLICY _lib…

【Java基础】集合(1) —— Collection

存储不同类型的对象: Object[] arrnew object[5];数组的长度是固定的, 添加或删除数据比较耗时 集合: Object[] toArray可以存储不同类型的对象随着存储的对象的增加&#xff0c;会自动的扩容集合提供了非常丰富的方法&#xff0c;便于操纵集合相当于容器&#xff0c;可以存储多…

探索Git之旅:仓库代码版本控制艺术

探索Git之旅&#xff1a;仓库代码版本控制艺术 引言Git基础与核心概念什么是版本控制&#xff1f;Git的工作流程分布式特性 Git实战操作指南安装与配置克隆仓库日常操作分支管理解决冲突 高级技巧与最佳实践Git FlowGit钩子Git别名 安全与性能考量结语与引发讨论 引言 在软件开…

冯喜运:5.16黄金是否突破阻力?黄金原油趋势分析

【黄金消息面分析】&#xff1a;周四(5月16日)亚市盘中&#xff0c;现货黄金延续昨日升势&#xff0c;金价目前最高触及2397.44美元/盎司&#xff0c;为4月19日以来新高。FXStreet首席分析师Valeria Bednarik撰文&#xff0c;对黄金技术前景进行分析。Bednarik指出&#xff0c;…

「51媒体」北京财经媒体有哪些?媒体邀约宣传

传媒如春雨&#xff0c;润物细无声&#xff0c;大家好&#xff0c;我是51媒体网胡老师。 北京作为中国的首都&#xff0c;拥有众多的财经媒体&#xff0c;这些媒体在财经新闻报道、经济分析、市场研究等方面发挥着重要作用。根据搜索结果&#xff0c;以下是一些北京地区的财经…

富格林:曝光虚假套路规避亏损

富格林指出&#xff0c;在现货黄金市场中&#xff0c;交易时间很充足投资机会也多的是&#xff0c;但为什么还是有人亏损甚至爆仓呢&#xff1f;其实导致这种情况&#xff0c;是因为有一些投资者不知道其中的虚假套路&#xff0c;很容易就一头栽进去了。要规避虚假套路带来的亏…

CV每日论文--2024.5.15

1、Can Better Text Semantics in Prompt Tuning Improve VLM Generalization? 中文标题&#xff1a;更好的文本语义在提示微调中能否提高视觉语言模型的泛化能力? 简介&#xff1a;这篇论文介绍了一种新的可学习提示调整方法,该方法超越了仅对视觉语言模型进行微调的传统方…

Lazyboy品牌发布会“球幕气膜”

Lazyboy品牌发布会“球幕气膜”为品牌活动提供了一个独特、现代化、环保的展示空间。这座球幕气膜不仅为发布会提供了一个视觉震撼的场地&#xff0c;也为与会嘉宾带来了全新的体验。作为轻空间&#xff08;江苏&#xff09;膜科技有限公司&#xff08;以下简称“轻空间”&…

使用Docker在阿里云ECS上部署Gitlab,提供代码托管、CICD 和 docker镜像服务

文章目录 使用Docker在阿里云ECS上部署Gitlab1.购买一个数据&#xff0c;挂载到/data用于存储gitlab相关数据2. 部署docker引擎3. 调整ssh的默认端口&#xff0c;将22端口留给gitlab4. 部署gitlab5. 进入docker容器获取gitlab的默认密码6. 登录gitlab&#xff0c;完成gitlab-ru…

linux ndk编译搭建测试

一、ndk下载 NDK 下载 | Android NDK | Android Developers 二、ndk环境变量配置 ndk解压&#xff1a; unzip android-ndk-r26d-linux.zip 环境变量配置&#xff1a; export NDK_HOME/rd/own/test/android-ndk-r26d/ export PATH$PATH:$NDK_HOME 三、编译测试验证 …

虚函数应用和原理

虚函数的表现形式 用子类初始化父类指针, 调用虚函数时, 仍然调用的是子类的虚函数 测试代码如下 #include <iostream> #include <string.h>using namespace std;class A { public:void test() { cout << a << endl; };virtual void test2 (){ cout …

LeetCode-2589. 完成所有任务的最少时间【栈 贪心 数组 二分查找 排序】

LeetCode-2589. 完成所有任务的最少时间【栈 贪心 数组 二分查找 排序】 题目描述&#xff1a;解题思路一&#xff1a;贪心暴力解题思路二&#xff1a;栈二分查找解题思路三&#xff1a;简化版 题目描述&#xff1a; 你有一台电脑&#xff0c;它可以 同时 运行无数个任务。给你…

解锁电商数据之门:京东商品详情API接口的深度解析与应用指南

一、京东商品详情API简介 京东商品详情API是京东开放平台提供的一项服务&#xff0c;允许第三方应用通过调用接口获取京东商城中商品的详细信息。这些信息包括但不限于商品名称、价格、库存、详情描述、用户评价等。 二、功能特点 数据全面&#xff1a;提供商品的全方位数据…