基于transformers框架实践Bert系列6-完形填空

本系列用于Bert模型实践实际场景,分别包括分类器、命名实体识别、选择题、文本摘要等等。(关于Bert的结构和详细这里就不做讲解,但了解Bert的基本结构是做实践的基础,因此看本系列之前,最好了解一下transformers和Bert等)
本篇主要讲解完形填空应用场景。本系列代码和数据集都上传到GitHub上:https://github.com/forever1986/bert_task

1 环境说明

1)本次实践的框架采用torch-2.1+transformer-4.37
2)另外还采用或依赖其它一些库,如:evaluate、pandas、datasets、accelerate等

2 前期准备

Bert模型是一个只包含transformer的encoder部分,并采用双向上下文和预测下一句训练而成的预训练模型。可以基于该模型做很多下游任务。

2.1 了解Bert的输入输出

Bert的输入:input_ids(使用tokenizer将句子向量化),attention_mask,token_type_ids(句子序号)、labels(结果)
Bert的输出:
last_hidden_state:最后一层encoder的输出;大小是(batch_size, sequence_length, hidden_size)(注意:这是关键输出,本次任务就需要获取该值,可以取出那个被mask掉的token,获取其前几个,取score最高的(当然也可以使用top_k或者top_p方式获取一定随机性)
pooler_output:这是序列的第一个token(classification token)的最后一层的隐藏状态,输出的大小是(batch_size, hidden_size),它是由线性层和Tanh激活函数进一步处理的。(通常用于句子分类,至于是使用这个表示,还是使用整个输入序列的隐藏状态序列的平均化或池化,视情况而定)。
hidden_states: 这是输出的一个可选项,如果输出,需要指定config.output_hidden_states=True,它也是一个元组,它的第一个元素是embedding,其余元素是各层的输出,每个元素的形状是(batch_size, sequence_length, hidden_size)
attentions:这是输出的一个可选项,如果输出,需要指定config.output_attentions=True,它也是一个元组,它的元素是每一层的注意力权重,用于计算self-attention heads的加权平均值。

2.2 数据集与模型

1)数据集来自:ChnSentiCorp(该数据集本身是做情感分类,但是我们只需要取其text部分即可)
2)模型权重使用:bert-base-chinese

2.3 任务说明

完形填空其实就是在一段文字中mask掉几个字,让模型能够自动填充字。这里本身就是bert模型做预训练是所做的事情之一,因此就是让数据给模型做训练的过程。

2.4 实现关键

1)数据集结构是一个带有text和label两列的数据,我们只需要获取到text部分即可。
在这里插入图片描述
2)随机mask掉部分数据,这个本身也是bert的训练过程,因此在transforms框架中DataCollatorForLanguageModeling已经实现了,你也可以自己实现随机mask掉你的数据进行训练

3 关键代码

3.1 数据集处理

数据集不需要做过多处理,只需要将text部分进行tokenizer,并制定max_length和truncation即可

def process_function(datas):tokenized_datas = tokenizer(datas["text"], max_length=256, truncation=True)return tokenized_datas
new_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)

3.2 模型加载

model = BertForMaskedLM.from_pretrained(model_path)

注意:这里使用的是transformers中的BertForMaskedLM,该类对bert模型进行封装。如果我们不使用该类,需要自己定义一个model,继承bert,增加分类线性层。另外使用AutoModelForMaskedLM也可以,其实AutoModel最终返回的也是BertForMaskedLM,它是根据你config中的model_type去匹配的。
这里列一下BertForMaskedLM的关键源代码说明一下transformers帮我们做了哪些关键事情

# 在__init__方法中增加增加了BertOnlyMLMHead,BertOnlyMLMHead其实就是一个二层神经网络,一层是BertPredictionHeadTransform(包括linear+geluAct+ln),一层是decoder(hidden_size*vocab_size大小的linear)。
self.bert = BertModel(config, add_pooling_layer=False)
self.cls = BertOnlyMLMHead(config)
# 将输出结果outputs取第一个返回值,也就是last_hidden_state
sequence_output = outputs[0]
# 将last_hidden_state输入到cls层中,获得最终结果(预测的score和词)
prediction_scores = self.cls(sequence_output)

3.3 自动并随机mask数据

关键代码在于DataCollatorForLanguageModeling,该类会实现自动mask。参考torch_mask_tokens方法。

trainer = Trainer(model=model,args=train_args,train_dataset=new_datasets["train"],data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15),)

4 整体代码

"""
基于BERT做完形填空
1)数据集来自:ChnSentiCorp
2)模型权重使用:bert-base-chinese
"""
# step 1 引入数据库
from datasets import DatasetDict
from transformers import TrainingArguments, Trainer, BertTokenizerFast, BertForMaskedLM, DataCollatorForLanguageModeling, pipelinemodel_path = "./model/tiansz/bert-base-chinese"
data_path = "./data/ChnSentiCorp"# step 2 数据集处理
datasets = DatasetDict.load_from_disk(data_path)
tokenizer = BertTokenizerFast.from_pretrained(model_path)def process_function(datas):tokenized_datas = tokenizer(datas["text"], max_length=256, truncation=True)return tokenized_datasnew_datasets = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)# step 3 加载模型
model = BertForMaskedLM.from_pretrained(model_path)# step 4 创建TrainingArguments
# 原先train是9600条数据,batch_size=32,因此每个epoch的step=300
train_args = TrainingArguments(output_dir="./checkpoints",      # 输出文件夹per_device_train_batch_size=32,  # 训练时的batch_sizenum_train_epochs=1,              # 训练轮数logging_steps=30,                # log 打印的频率)# step 5 创建Trainer
trainer = Trainer(model=model,args=train_args,train_dataset=new_datasets["train"],# 自动MASK关键所在,通过DataCollatorForLanguageModeling实现自动MASK数据data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15),)# Step 6 模型训练
trainer.train()# step 7 模型评估
pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer, device=0)
str = datasets["test"][3]["text"]
str = str.replace("方便","[MASK][MASK]")
results = pipe(str)
# results[0][0]["token_str"]
print(results[0][0]["token_str"]+results[1][0]["token_str"])

5 运行效果

在这里插入图片描述

注:本文参考来自大神:https://github.com/zyds/transformers-code

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/13512.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python筑基之旅-MySQL数据库(一)

目录 一、MySQL数据库 1、简介 2、优点 2-1、开源和免费 2-2、高性能 2-3、可扩展性 2-4、易用性 2-5、灵活性 2-6、安全性和稳定性 2-7、丰富的功能 2-8、结合其他工具和服务 2-9、良好的兼容性和移植性 3、缺点 3-1、对大数据的支持有限 3-2、缺乏全文…

微服务如何做好监控

大家好,我是苍何。 在脉脉上看到这条帖子,说阿里 P8 因为上面 P9 斗争失败走人,以超龄 35 被裁,Boss 上找工作半年,到现在还处于失业中。 看了下沟通记录, 沟通了 1000 多次,但没有一个邀请投递…

uniapp中使用 iconfont字体

下载 iconfont 字体文件 打开 iconfont.css 文件,修改一下 把文件 复制到 static/iconfont/… 目录下 在App.vue中引入iconfont 5. 使用iconfont 使用 iconfont 有两种方式, 一种是 class 方式, 一种是使用 unicode 的方式 5.1 使用 class 的…

【Mac】Dreamweaver 2021 for mac v21.3 Rid中文版安装教程

软件介绍 Dreamweaver是Adobe公司开发的一款专业网页设计与前端开发软件。它集成了所见即所得(WYSIWYG)编辑器和代码编辑器,可以帮助开发者快速创建和编辑网页。Dreamweaver提供了丰富的功能和工具,包括代码提示、语法高亮、代码…

教你一分钟搭建适合IT人员的在线开发工具箱

文章目录 1. 使用Docker本地部署it-tools2. 本地访问it-tools3. 安装cpolar内网穿透4. 固定it-tools公网地址 本篇文章将介绍如何在Windows上使用Docker本地部署IT- Tools,并且同样可以结合cpolar实现公网访问。 在前一篇文章中我们讲解了如何在Linux中使用Docker搭…

Anaconda Jupyter 报错及解决方法记录

一、AttributeError: module lib has no attribute X509_V_FLAG_CB_ISSUER_CHECK 背景:Anaconda更新版本后,运行import oss2时报错 ~/anaconda3/lib/python3.8/site-packages/OpenSSL/crypto.py in X509StoreFlags() 1535 NOTIFY_POLICY _lib…

【Java基础】集合(1) —— Collection

存储不同类型的对象: Object[] arrnew object[5];数组的长度是固定的, 添加或删除数据比较耗时 集合: Object[] toArray可以存储不同类型的对象随着存储的对象的增加,会自动的扩容集合提供了非常丰富的方法,便于操纵集合相当于容器,可以存储多…

冯喜运:5.16黄金是否突破阻力?黄金原油趋势分析

【黄金消息面分析】:周四(5月16日)亚市盘中,现货黄金延续昨日升势,金价目前最高触及2397.44美元/盎司,为4月19日以来新高。FXStreet首席分析师Valeria Bednarik撰文,对黄金技术前景进行分析。Bednarik指出,…

「51媒体」北京财经媒体有哪些?媒体邀约宣传

传媒如春雨,润物细无声,大家好,我是51媒体网胡老师。 北京作为中国的首都,拥有众多的财经媒体,这些媒体在财经新闻报道、经济分析、市场研究等方面发挥着重要作用。根据搜索结果,以下是一些北京地区的财经…

CV每日论文--2024.5.15

1、Can Better Text Semantics in Prompt Tuning Improve VLM Generalization? 中文标题:更好的文本语义在提示微调中能否提高视觉语言模型的泛化能力? 简介:这篇论文介绍了一种新的可学习提示调整方法,该方法超越了仅对视觉语言模型进行微调的传统方…

Lazyboy品牌发布会“球幕气膜”

Lazyboy品牌发布会“球幕气膜”为品牌活动提供了一个独特、现代化、环保的展示空间。这座球幕气膜不仅为发布会提供了一个视觉震撼的场地,也为与会嘉宾带来了全新的体验。作为轻空间(江苏)膜科技有限公司(以下简称“轻空间”&…

使用Docker在阿里云ECS上部署Gitlab,提供代码托管、CICD 和 docker镜像服务

文章目录 使用Docker在阿里云ECS上部署Gitlab1.购买一个数据,挂载到/data用于存储gitlab相关数据2. 部署docker引擎3. 调整ssh的默认端口,将22端口留给gitlab4. 部署gitlab5. 进入docker容器获取gitlab的默认密码6. 登录gitlab,完成gitlab-ru…

linux ndk编译搭建测试

一、ndk下载 NDK 下载 | Android NDK | Android Developers 二、ndk环境变量配置 ndk解压: unzip android-ndk-r26d-linux.zip 环境变量配置: export NDK_HOME/rd/own/test/android-ndk-r26d/ export PATH$PATH:$NDK_HOME 三、编译测试验证 …

LeetCode-2589. 完成所有任务的最少时间【栈 贪心 数组 二分查找 排序】

LeetCode-2589. 完成所有任务的最少时间【栈 贪心 数组 二分查找 排序】 题目描述:解题思路一:贪心暴力解题思路二:栈二分查找解题思路三:简化版 题目描述: 你有一台电脑,它可以 同时 运行无数个任务。给你…

mac安装两个版本谷歌浏览器;在mac运行不同版本的chrome浏览器

场景 正常情况下,mac上只能安装一个版本的chrome浏览器,即使你安装了两个版本的,打开老旧版本时候也会自动切换成最新版的浏览器 故本文主要解决如何下载和在mac运行不同版本的chrome浏览器 文章目录 场景一、下载1.mac本身就有一个最新版ch…

Java语言saas模式云HIS系统源码 前端Angular+后台SpringBoot云HIS系统源码 HIS系统适合哪些类型的医院?

Java语言saas模式云HIS系统源码 前端Angular后台SpringBoot云HIS系统源码 HIS系统适合哪些类型的医院? 云HIS系统(医院信息系统)是对医院及其所属各部门的人、财、物进行综合管理,对在医疗活动各阶段产生的数据进行采集、储存、处…

CCF20181201——小明上学

CCF20181201——小明上学 代码如下&#xff1a; #include<bits/stdc.h> using namespace std; int main() {int r,y,g,n,k[101],t[101],sum0;cin>>r>>y>>g;cin>>n; for(int i0;i<n;i){cin>>k[i]>>t[i];if(k[i]0||k[i]1)sumt[i];…

ITSM的服务台如何让工作更流畅

在现代企业的信息技术管理框架内&#xff0c;IT服务管理&#xff08;IT Service Management, ITSM&#xff09;体系扮演着至关重要的角色&#xff0c;而其中的服务台则是这一复杂体系的心脏地带。服务台不仅仅是解答技术疑问的一线窗口&#xff0c;更是企业IT运维效率与用户满意…

FENDI CLUB啤酒,为何女生喜欢?

精酿啤酒已经成了女生喜欢的饮品&#xff0c;在日剧《无法成为野兽的我们》里&#xff0c;主人公小晶永远保持标准笑容&#xff0c;完美完成所有的工作。只有一个人的时候&#xff0c;她才会放下习惯性的微笑&#xff0c;显露自己的疲惫。小晶缓解疲惫&#xff0c;就是下班后去…

尽微好物:从0到10亿+的抖音电商的TOP1“联盟团长”,如何使用NineData实现上云下云

杭州尽微供应链是抖⾳平台⽉均带货10E的TOP1“联盟团⻓”&#xff0c;是字节跳动⼀级代理商&#xff0c;巨量千川指定服务商&#xff0c;拥有商品库9万&#xff0c;是⾏业领先的电商供应链平台&#xff0c;达⼈陪跑机构。 杭州尽微供应链以天猫、京东抖音电商业务为依托&#x…