【Pytorch神经网络实战案例】33 使用BERT模型实现完形填空任务

1 案例描述

案例:加载Transformers库中的BERT模型,并用它实现完形填空任务,即预测一个句子中缺失的单词。

2 代码实现:使用BERT模型实现完形填空任务

2.1 代码实现:载入词表,并对输入的文本进行分词转化---BERT_MASK.py(第1部分)

import torch
from transformers import BertTokenizer, BertForMaskedLM# 1.1 载入词表,并对输入的文本进行分词转化
# 加载预训练模型
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')# 输入文本,BERT模型需要特殊词来标定句子:
# [CLS]:标记一个段落的开始。一个段落可以有一个或多个句子,但是只能有一个[CLS]。[CLS]在BERT模型中还会被用作分类任务的输出特征。
# [SEP]:标记一个句子的结束。在一个段落中,可以有多个[SEP]。
text = "[CLS] Who is Li BiGor ? [SEP] Li BiGor is a programmer [SEP]"
tokenized_text = tokenizer.tokenize(text)
# 使用词表对输入文本进行转换。与中文分词有点类似。由于词表中不可能覆盖所有的单词,因此当输入文本中的单词不存在时,系统会使用带有通配符的单间(以“#”开头的单词)将其拆开。
print("词表转化结果:",tokenized_text)
# 词表转化结果:['[CLS]','who','is','li','big','##or','?','[SEP]','li','big','##or','is','a','programmer','[SEP]']

2.2 代码实现:遮蔽单词,并将其转换为索引值---BERT_MASK.py(第2部分)

# 1.2 遮蔽单词,并将其转换为索引值,使用标记字符[MAS]代替输入文本中索引值为8的单词,对“Li”进行遮蔽,并将整个句子中的单词转换为词表中的索引值。
masked_index = 8  # 掩码一个标记,再使用'BertForMaskedLM'预测回来
tokenized_text[masked_index] = '[MASK]' # 标记字符[MASK],是BERT模型中的特殊标识符。在BERT模型的训练过程中,会对输入文本的随机位置用[MASK]字符进行替换,并训练模型预测出[MASK]字符对应的值。
print("句子中的索引:",tokenized_text)
# 句子中的索引:['[CLS]','who','is','li','big','##or','?','[SEP]','[MASK]','big','##or','is','a','programmer','[SEP]']
# 将标记转换为词汇表索引
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# 将输入转换为PyTorch张量
tokens_tensor = torch.tensor([indexed_tokens])
print("句子中的向量:",tokens_tensor)
# 句子中的向量:tensor([[101,2040,2003,5622,2502,2953,1029,102,103,2502,2953,2003,1037,20273,102]])

2.3 代码实现:加载预训练模型,并对遮蔽单词进行预测---BERT_MASK.py(第3部分)

# 1.3 加载预训练模型,并对遮蔽单词进行预测
# 指定设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# 加载预训练模型 (weights)
model = BertForMaskedLM.from_pretrained('bert-base-uncased') # 用BertForMaskedLM类加载模型,该类可以对句子中的标记字符[MASK]进行预测。
model.eval()
model.to(device)
# 段标记索引:定义输入的BertForMaskedLM类句子指示参数,用于指示输入文本中的单词是属于第一句还是属于第二句。属于第一句的单词用0来表示(一共8个),属于第二句的单词用1来表示(一共7个)。
segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
segments_tensors = torch.tensor([segments_ids]).to(device)tokens_tensor = tokens_tensor.to(device)# 预测所有的tokens
with torch.no_grad():# 将文本和句子指示参数输入模型进行预测。# 输出结果是一个形状为[1,15,30522]的张量。其中,1代表批次个数,15代表输入句子中的15个单词,30522是词表中单词的个数。# 模型的结果表示词表中每个单词在句子中可能出现的概率。outputs = model(tokens_tensor, token_type_ids=segments_tensors)
predictions = outputs[0]  # [1, 15, 30522]
# 预测结果:从输出结果中取出[MASK]字符对应的预测索引值。
predicted_index = torch.argmax(predictions[0, masked_index]).item()
# 将预测索引值转换为单词。
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
print('预测词为:', predicted_token)
# 预测词为: li

3 代码总览---BERT_MASK.py

import torch
from transformers import BertTokenizer, BertForMaskedLM# 1.1 载入词表,并对输入的文本进行分词转化
# 加载预训练模型
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')# 输入文本,BERT模型需要特殊词来标定句子:
# [CLS]:标记一个段落的开始。一个段落可以有一个或多个句子,但是只能有一个[CLS]。[CLS]在BERT模型中还会被用作分类任务的输出特征。
# [SEP]:标记一个句子的结束。在一个段落中,可以有多个[SEP]。
text = "[CLS] Who is Li BiGor ? [SEP] Li BiGor is a programmer [SEP]"
tokenized_text = tokenizer.tokenize(text)
# 使用词表对输入文本进行转换。与中文分词有点类似。由于词表中不可能覆盖所有的单词,因此当输入文本中的单词不存在时,系统会使用带有通配符的单间(以“#”开头的单词)将其拆开。
print("词表转化结果:",tokenized_text)
# 词表转化结果:['[CLS]','who','is','li','big','##or','?','[SEP]','li','big','##or','is','a','programmer','[SEP]']# 1.2 遮蔽单词,并将其转换为索引值,使用标记字符[MAS]代替输入文本中索引值为8的单词,对“Li”进行遮蔽,并将整个句子中的单词转换为词表中的索引值。
masked_index = 8  # 掩码一个标记,再使用'BertForMaskedLM'预测回来
tokenized_text[masked_index] = '[MASK]' # 标记字符[MASK],是BERT模型中的特殊标识符。在BERT模型的训练过程中,会对输入文本的随机位置用[MASK]字符进行替换,并训练模型预测出[MASK]字符对应的值。
print("句子中的索引:",tokenized_text)
# 句子中的索引:['[CLS]','who','is','li','big','##or','?','[SEP]','[MASK]','big','##or','is','a','programmer','[SEP]']
# 将标记转换为词汇表索引
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# 将输入转换为PyTorch张量
tokens_tensor = torch.tensor([indexed_tokens])
print("句子中的向量:",tokens_tensor)
# 句子中的向量:tensor([[101,2040,2003,5622,2502,2953,1029,102,103,2502,2953,2003,1037,20273,102]])# 1.3 加载预训练模型,并对遮蔽单词进行预测
# 指定设备
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# 加载预训练模型 (weights)
model = BertForMaskedLM.from_pretrained('bert-base-uncased') # 用BertForMaskedLM类加载模型,该类可以对句子中的标记字符[MASK]进行预测。
model.eval()
model.to(device)
# 段标记索引:定义输入的BertForMaskedLM类句子指示参数,用于指示输入文本中的单词是属于第一句还是属于第二句。属于第一句的单词用0来表示(一共8个),属于第二句的单词用1来表示(一共7个)。
segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
segments_tensors = torch.tensor([segments_ids]).to(device)tokens_tensor = tokens_tensor.to(device)# 预测所有的tokens
with torch.no_grad():# 将文本和句子指示参数输入模型进行预测。# 输出结果是一个形状为[1,15,30522]的张量。其中,1代表批次个数,15代表输入句子中的15个单词,30522是词表中单词的个数。# 模型的结果表示词表中每个单词在句子中可能出现的概率。outputs = model(tokens_tensor, token_type_ids=segments_tensors)
predictions = outputs[0]  # [1, 15, 30522]
# 预测结果:从输出结果中取出[MASK]字符对应的预测索引值。
predicted_index = torch.argmax(predictions[0, masked_index]).item()
# 将预测索引值转换为单词。
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
print('预测词为:', predicted_token)
# 预测词为: li

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/469146.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

container_of宏

1.container_of宏 1> Container_of在Linux内核中是一个常用的宏,用于从包含在某个结构中的指针获得结构本身的指针,通俗地讲就是通过结构体变量中某个成员的首地址进而获得整个结构体变量的首地址。 2>接口: container_of(ptr, type, …

c++ string 删除字符_字符串操作的全面总结

来自公众号:C语言与cpp编程字符串操作看似简单,其实非常重要,不注意的话,经常出现代码运行结果和自己想要的不一致,甚至崩溃。本文总结了一些构建string对象方法、修改string对象的方法、string类型的操作函数、string…

【Pytorch神经网络理论篇】 40 Transformers中的词表工具Tokenizer

同学你好!本文章于2021年末编写,获得广泛的好评! 故在2022年末对本系列进行填充与更新,欢迎大家订阅最新的专栏,获取基于Pytorch1.10版本的理论代码(2023版)实现, Pytorch深度学习理论篇(2023版)目录地址…

warning: function declaration isn’t a prototype(函数声明不是原型)的解决办法

linux驱动中定义一个无参的函数int probe_num(){....}警告:函数声明不是一个原型 [-Wstrict-prototypes]应对方法:改成int probe_num( void){....}警告消失 http://blog.csdn.net/dumgeewang/article/details/7410477

【Pytorch神经网络实战案例】34 使用GPT-2模型实现句子补全功能(手动加载)

1 GPT-2 模型结构 GPT-2的整体结构如下图,GPT-2是以Transformer为基础构建的,使用字节对编码的方法进行数据预处理,通过预测下一个词任务进行预训练的语言模型。 1.1 GPT-2 功能简介 GPT-2 就是一个语言模型,能够根据上文预测下…

电容式传感器位移性能试验报告_一文读懂什么是接近传感器?

点击上方蓝字 记得关注我们哦!接近传感器是一种非接触式传感器,当目标进入传感器的视野时,它会检测到物体(通常称为“目标”)的存在。取决于接近传感器的类型,传感器可以利用声音,光,红外辐射(IR)或电磁场来…

[dts]DTS实例分析

2. 通常会碰到的实际问题 到此,问题出现了: 1. 当写一个按键驱动,应该如何在*.dts或者*.dtsi中操作? 2. 当在串口driver中需要使用到某个pin脚作为普通输出IO,该如何操作? 3. 当在串口driver中需要使用某个muxpin脚作…

【解决】YOLOv6.1安装requirements.txt报错UnicodeDecodeError: ‘gbk‘ codec can‘t decode byte 0x84

案例描述: 使用YOLOV5时,报错解决UnicodeDecodeError: ‘gbk‘ codec can‘t decode byte 0x84 in position 285: illegal multibyte sequence 解决方案: 在C:\ProgramData\Anaconda3\lib\distutils\dist.py"文件搜索read 将parser.read(filenam…

bp 神经网络 优点 不足_深度学习之BP神经网络--Stata和R同步实现(附Stata数据和代码)

说明:本文原发于“计量经济圈”公众号,在此仅展示Stata的部分。R部分请移步至本人主页的“R语言与机器学习--经济学视角”专栏,或点击下方链接卡跳转。盲区行者:深度学习之BP神经网络--Stata和R同步实现(附R数据和代码…

高通平台中gpio简单操作和调试

做底层驱动免不了gpio打交道,所以对其操作和调试进行了一下简单的梳理 一、gpio的调试方法 在Linux下,通过sysfs,获取gpio状态,也可以操作gpio。 1、获取gpio状态 cd /sys/kernel/debug/ cat gpio 2、操作gpio(以gpi…

Win10系统下使用anaconda在虚拟环境下安装CUDA及CUDNN

前排预警:不要挂梯子!!!!!使用清华源就行不然报错!!!! 解决check_hostname requires server_hostname_orange_の_呜的博客-CSDN博客错误描述在GitHub下载代码文件后使用pip install -r requirement.txt下载依赖包时出…

jlink问题

现在淘宝上买到的JLINK都是盗版的,用着用着的时候就会遇到各种异常问题, 这里有一个方法来修改SN,修改SN后就会变得正常了,亲测有效 两种固件: V*_ID-自定义.* 是出厂设置,烧入后用jlink.exe连接上S/N会显示-1. 此时可以根据自己的…

将XML格式转化为YOLO需要的txt格式(代码)

1、XML的格式 <annotation><folder>cr</folder><filename>crazing_2.jpg</filename><source><database>NEU-DET</database></source><size><width>200</width><height>200</height><…

js 点击button切换颜色_ThingJS 和three.js开发示例对比,让开发早点下班回家!3D 可视化...

ThingJS 3D框架简化了开发工作&#xff0c;面向对象和模块化的特点使得网页代码更加易于管理和维护&#xff0c;并且提供近200个官方示例&#xff0c;直接获取API能力&#xff0c;不需要基于3D概念进行开发&#xff0c;适合3D商业项目快速生成&#xff01;距离您的业务仅一层之…

变量命名

列举一下我自己的一些写法 local_int_loop_count global_int_data_count local_bool_plug_insert_flag global_bool_ble_connect_flag函数命名 get_tick_number set_tick_number为了代码清晰易懂&#xff0c;通常变量名采用一些著名的命名规则&#xff0c;主要有Camel标记法&am…

VSCode使用技巧——Ctrl+鼠标滚轮键使字体进行缩放

点击VSCode左下角的齿轮&#xff0c;进入设置 进入Extensions——》JSON——》Edit in settings.json 在json当中添加如下&#xff1a; "editor.mouseWheelZoom": true,

python 交互式可视化库_Python 交互式可视化库

Python 交互式可视化库 所属分类&#xff1a;中间件编程 开发工具&#xff1a;Python 文件大小&#xff1a;12843KB 下载次数&#xff1a;1 上传日期&#xff1a;2018-12-06 18:40:56 上 传 者&#xff1a;孤独的老张 说明&#xff1a; 一个 Python 交互式可视化库&#xff0c;…

OpenCV各版本差异与演化,从1.x到4.0

最近因项目需要&#xff0c;得把OpenCV捡起来&#xff0c;登录OpenCV官网&#xff0c;竟然发现release了4.0.0-beata版本&#xff0c;所以借此机会&#xff0c;查阅资料&#xff0c;了解下OpenCV各版本的差异及其演化过程&#xff0c;形成了以下几点认识&#xff1a; 新版本的…

python题库刷题训练软件_Python基础练习100题 ( 11~ 20)

刷题继续 上一期和大家分享了前10道题&#xff0c;今天继续来刷11~20 Question 11: Write a program which accepts a sequence of comma separated 4 digit binary numbers as its input and then check whether they are divisible by 5 or not. The numbers that are divisi…

如何学习计算机图形学

http://blog.csdn.net/szchtx/article/details/6916675转载于:https://www.cnblogs.com/ArcherHuang/p/6574560.html