【Pytorch神经网络实战案例】34 使用GPT-2模型实现句子补全功能(手动加载)

 

1 GPT-2 模型结构

GPT-2的整体结构如下图,GPT-2是以Transformer为基础构建的,使用字节对编码的方法进行数据预处理,通过预测下一个词任务进行预训练的语言模型。

1.1 GPT-2 功能简介

GPT-2 就是一个语言模型,能够根据上文预测下一个单词,所以它就可以利用预训练已经学到的知识来生成文本,如生成新闻。也可以使用另一些数据进行微调,生成有特定格式或者主题的文本,如诗歌、戏剧。

2 手动加载GPT-2模型并实现语句与完整句子预测

使用GPT-2模型配套的PreTrainedTokenizer类,所需要加载的词表文件比BERT模型多了一个merges文件。

2.1 代码实现:手动加载GPT-2模型并实现下一个单词预测---GPT2_make.py(第1部分)

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel# 案例描述:Transformers库中的GPT-2模型,并用它实现下一词预测功能,即预测一个未完成句子的下一个可能出现的单词。
# 下一词预测任务是一个常见的任务,在Transformers库中有很多模型都可以实现该任务。也可以使用BERT模型来实现。选用GPT-2模型,主要在于介绍手动加载多词表文件的特殊方式。# 1.1 加载词表文件# 自动加载预训练模型(权重)
# tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# 手动加载词表文件:gpt2-merges.txt gpt2-vocab.json。
# from_pretrained方法是支持从本地载入多个词表文件的,但对载入的词表文件名称有特殊的要求:该文件名称必须按照源码文件tokenization_gpt2.py的VOCAB_FILES_NAMES字典对象中定义的名字来命名。
# 故使用from_pretrained方法,必须对已经下载好的词表文件进行改名将/gpt2/gpt2-vocab.json和/gpt2/gpt2-merges.txt这两个文件,分别改名为“gpt2/vocab.json和/gpt2/merges.txt
tokenizer = GPT2Tokenizer.from_pretrained(r'./models/gpt2') # 自动加载改名后的文件# 编码输入
indexed_tokens = tokenizer.encode("Who is Li BiGor ? Li BiGor is a")
print("输入语句为:",tokenizer.decode(indexed_tokens))
tokens_tensor = torch.tensor([indexed_tokens])  # 将输入语句转换为张量# 自动加载预训练模型(权重)
# model = GPT2LMHeadModel.from_pretrained('gpt2')
# 手动加载:配置文件gpt2-config.json 与 权重文件pt2-pytorch_model.bin
model = GPT2LMHeadModel.from_pretrained('./models/gpt2/gpt2-pytorch_model.bin',config='./models/gpt2/gpt2-config.json')# 将模型设置为评估模式
model.eval()
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokens_tensor = tokens_tensor.to(DEVICE)
model.to(DEVICE)# 预测所有标记
with torch.no_grad():outputs = model(tokens_tensor)predictions = outputs[0]# 得到预测的下一词
predicted_index = torch.argmax(predictions[0, -1, :]).item()
predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
print("输出语句为:",predicted_text) # GPT-2模型没有为输入文本添加特殊词。
# 输出:Who is Li BiGor? Li BiGor is a Chinese

2.2 代码实现:手动加载GPT-2模型并实现完整句子预测---GPT2_make.py(第2部分)

# 案例描述:Transformers库中的GPT-2模型,通过循环生成下一词,实现将一句话补充完整。
# 1.2 生成一段完整的话 这里有BUg 暂不会改
stopids = tokenizer.convert_tokens_to_ids(["."])[0] # 定义结束符# 在循环调用模型预测功能时,使用了模型的past功能。该功能以使模型进入连续预测状态,即在前面预测结果的基础之上进行下一词预测,而不需要在每预测时,对所有句子进行重新处理。
# past功能是使用预训练模型时很常用的功能,在Transformers库中,凡是带有下一词预测功能的预训练模型(如GPT,XLNet,CTRL等)都有这个功能。
# 但并不是所有模型的past功能都是通过past参数进行设置的,有的模型虽然使用的参数名称是mems,但作用与pat参数一样。
past = None # 定义模型参数for i in range(100):    # 循环100次with torch.no_grad():output, past = model(tokens_tensor, past=past)  # 预测下一次token = torch.argmax(output[..., -1, :])indexed_tokens += [token.tolist()]  # 将预测结果收集if stopids == token.tolist():   # 当预测出句号时,终止预测。breaktokens_tensor = token.unsqueeze(0)  # 定义下一次预测的输入张量sequence = tokenizer.decode(indexed_tokens) # 进行字符串编码
print(sequence)

3 GPT2_make.py(汇总)

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel# 案例描述:Transformers库中的GPT-2模型,并用它实现下一词预测功能,即预测一个未完成句子的下一个可能出现的单词。
# 下一词预测任务是一个常见的任务,在Transformers库中有很多模型都可以实现该任务。也可以使用BERT模型来实现。选用GPT-2模型,主要在于介绍手动加载多词表文件的特殊方式。# 1.1 加载词表文件# 自动加载预训练模型(权重)
# tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# 手动加载词表文件:gpt2-merges.txt gpt2-vocab.json。
# from_pretrained方法是支持从本地载入多个词表文件的,但对载入的词表文件名称有特殊的要求:该文件名称必须按照源码文件tokenization_gpt2.py的VOCAB_FILES_NAMES字典对象中定义的名字来命名。
# 故使用from_pretrained方法,必须对已经下载好的词表文件进行改名将/gpt2/gpt2-vocab.json和/gpt2/gpt2-merges.txt这两个文件,分别改名为“gpt2/vocab.json和/gpt2/merges.txt
tokenizer = GPT2Tokenizer.from_pretrained(r'./models/gpt2') # 自动加载改名后的文件# 编码输入
indexed_tokens = tokenizer.encode("Who is Li BiGor ? Li BiGor is a")
print("输入语句为:",tokenizer.decode(indexed_tokens))
tokens_tensor = torch.tensor([indexed_tokens])  # 将输入语句转换为张量# 自动加载预训练模型(权重)
# model = GPT2LMHeadModel.from_pretrained('gpt2')
# 手动加载:配置文件gpt2-config.json 与 权重文件pt2-pytorch_model.bin
model = GPT2LMHeadModel.from_pretrained('./models/gpt2/gpt2-pytorch_model.bin',config='./models/gpt2/gpt2-config.json')# 将模型设置为评估模式
model.eval()
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokens_tensor = tokens_tensor.to(DEVICE)
model.to(DEVICE)# 预测所有标记
with torch.no_grad():outputs = model(tokens_tensor)predictions = outputs[0]# 得到预测的下一词
predicted_index = torch.argmax(predictions[0, -1, :]).item()
predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
print("输出语句为:",predicted_text) # GPT-2模型没有为输入文本添加特殊词。
# 输出:Who is Li BiGor? Li BiGor is a Chinese# 案例描述:Transformers库中的GPT-2模型,通过循环生成下一词,实现将一句话补充完整。
# 1.2 生成一段完整的话 这里有BUg 暂不会改
stopids = tokenizer.convert_tokens_to_ids(["."])[0] # 定义结束符# 在循环调用模型预测功能时,使用了模型的past功能。该功能以使模型进入连续预测状态,即在前面预测结果的基础之上进行下一词预测,而不需要在每预测时,对所有句子进行重新处理。
# past功能是使用预训练模型时很常用的功能,在Transformers库中,凡是带有下一词预测功能的预训练模型(如GPT,XLNet,CTRL等)都有这个功能。
# 但并不是所有模型的past功能都是通过past参数进行设置的,有的模型虽然使用的参数名称是mems,但作用与pat参数一样。
past = None # 定义模型参数for i in range(100):    # 循环100次with torch.no_grad():output, past = model(tokens_tensor, past=past)  # 预测下一次token = torch.argmax(output[..., -1, :])indexed_tokens += [token.tolist()]  # 将预测结果收集if stopids == token.tolist():   # 当预测出句号时,终止预测。breaktokens_tensor = token.unsqueeze(0)  # 定义下一次预测的输入张量sequence = tokenizer.decode(indexed_tokens) # 进行字符串编码
print(sequence)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/469141.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

电容式传感器位移性能试验报告_一文读懂什么是接近传感器?

点击上方蓝字 记得关注我们哦!接近传感器是一种非接触式传感器,当目标进入传感器的视野时,它会检测到物体(通常称为“目标”)的存在。取决于接近传感器的类型,传感器可以利用声音,光,红外辐射(IR)或电磁场来…

[dts]DTS实例分析

2. 通常会碰到的实际问题 到此,问题出现了: 1. 当写一个按键驱动,应该如何在*.dts或者*.dtsi中操作? 2. 当在串口driver中需要使用到某个pin脚作为普通输出IO,该如何操作? 3. 当在串口driver中需要使用某个muxpin脚作…

【解决】YOLOv6.1安装requirements.txt报错UnicodeDecodeError: ‘gbk‘ codec can‘t decode byte 0x84

案例描述: 使用YOLOV5时,报错解决UnicodeDecodeError: ‘gbk‘ codec can‘t decode byte 0x84 in position 285: illegal multibyte sequence 解决方案: 在C:\ProgramData\Anaconda3\lib\distutils\dist.py"文件搜索read 将parser.read(filenam…

bp 神经网络 优点 不足_深度学习之BP神经网络--Stata和R同步实现(附Stata数据和代码)

说明:本文原发于“计量经济圈”公众号,在此仅展示Stata的部分。R部分请移步至本人主页的“R语言与机器学习--经济学视角”专栏,或点击下方链接卡跳转。盲区行者:深度学习之BP神经网络--Stata和R同步实现(附R数据和代码…

高通平台中gpio简单操作和调试

做底层驱动免不了gpio打交道,所以对其操作和调试进行了一下简单的梳理 一、gpio的调试方法 在Linux下,通过sysfs,获取gpio状态,也可以操作gpio。 1、获取gpio状态 cd /sys/kernel/debug/ cat gpio 2、操作gpio(以gpi…

Win10系统下使用anaconda在虚拟环境下安装CUDA及CUDNN

前排预警:不要挂梯子!!!!!使用清华源就行不然报错!!!! 解决check_hostname requires server_hostname_orange_の_呜的博客-CSDN博客错误描述在GitHub下载代码文件后使用pip install -r requirement.txt下载依赖包时出…

jlink问题

现在淘宝上买到的JLINK都是盗版的,用着用着的时候就会遇到各种异常问题, 这里有一个方法来修改SN,修改SN后就会变得正常了,亲测有效 两种固件: V*_ID-自定义.* 是出厂设置,烧入后用jlink.exe连接上S/N会显示-1. 此时可以根据自己的…

将XML格式转化为YOLO需要的txt格式(代码)

1、XML的格式 <annotation><folder>cr</folder><filename>crazing_2.jpg</filename><source><database>NEU-DET</database></source><size><width>200</width><height>200</height><…

js 点击button切换颜色_ThingJS 和three.js开发示例对比,让开发早点下班回家!3D 可视化...

ThingJS 3D框架简化了开发工作&#xff0c;面向对象和模块化的特点使得网页代码更加易于管理和维护&#xff0c;并且提供近200个官方示例&#xff0c;直接获取API能力&#xff0c;不需要基于3D概念进行开发&#xff0c;适合3D商业项目快速生成&#xff01;距离您的业务仅一层之…

变量命名

列举一下我自己的一些写法 local_int_loop_count global_int_data_count local_bool_plug_insert_flag global_bool_ble_connect_flag函数命名 get_tick_number set_tick_number为了代码清晰易懂&#xff0c;通常变量名采用一些著名的命名规则&#xff0c;主要有Camel标记法&am…

VSCode使用技巧——Ctrl+鼠标滚轮键使字体进行缩放

点击VSCode左下角的齿轮&#xff0c;进入设置 进入Extensions——》JSON——》Edit in settings.json 在json当中添加如下&#xff1a; "editor.mouseWheelZoom": true,

python 交互式可视化库_Python 交互式可视化库

Python 交互式可视化库 所属分类&#xff1a;中间件编程 开发工具&#xff1a;Python 文件大小&#xff1a;12843KB 下载次数&#xff1a;1 上传日期&#xff1a;2018-12-06 18:40:56 上 传 者&#xff1a;孤独的老张 说明&#xff1a; 一个 Python 交互式可视化库&#xff0c;…

OpenCV各版本差异与演化,从1.x到4.0

最近因项目需要&#xff0c;得把OpenCV捡起来&#xff0c;登录OpenCV官网&#xff0c;竟然发现release了4.0.0-beata版本&#xff0c;所以借此机会&#xff0c;查阅资料&#xff0c;了解下OpenCV各版本的差异及其演化过程&#xff0c;形成了以下几点认识&#xff1a; 新版本的…

python题库刷题训练软件_Python基础练习100题 ( 11~ 20)

刷题继续 上一期和大家分享了前10道题&#xff0c;今天继续来刷11~20 Question 11: Write a program which accepts a sequence of comma separated 4 digit binary numbers as its input and then check whether they are divisible by 5 or not. The numbers that are divisi…

如何学习计算机图形学

http://blog.csdn.net/szchtx/article/details/6916675转载于:https://www.cnblogs.com/ArcherHuang/p/6574560.html

shell for循环

weiqifaubuntu:~/qcom$ for i in $(seq 1 1 10) > do > echo "hello World" > done hello World hello World hello World hello World hello World hello World hello World hello World hello World hello World weiqifaubuntu:~/qcom$ 输入for i in $(s…

西门子s7-200解密软件下载_西门子S7-200/300/400通讯方式汇总,超级全面

1西门子 200 plc 使用 MPI 协议与组态王进行通讯时需要哪些设置?1)在运行组态王的机器上需要安装西门子公司提供的 STEP7 Microwin 3.2 的编程软件&#xff0c;我们的驱动需要调用编程软件提供的 MPI 接口库函数;2)需要将 MPI 通讯卡 CP5611 卡安装在计算机的插槽中&#xff0…

如何监控NVIDIA GPU 的运行状态和使用情况

设备跟踪和管理正成为机器学习工程的中心焦点。这个任务的核心是在模型训练过程中跟踪和报告gpu的使用效率。 有效的GPU监控可以帮助我们配置一些非常重要的超参数&#xff0c;例如批大小&#xff0c;还可以有效的识别训练中的瓶颈&#xff0c;比如CPU活动(通常是预处理图像)占…

进程和线程的本质和区别

进程是什么&#xff1f; 程序并不能单独运行&#xff0c;只有将程序装载到内存中&#xff0c;系统为它分配资源才能运行&#xff0c;而这种执行的程序就称之为进程。程序和进程的区别就在于&#xff1a;程序是指令的集合&#xff0c;它是进程运行的静态描述文本&#xff1b;进程…

HBase学习笔记——概念及原理

1.什么是HBase HBase – Hadoop Database&#xff0c;是一个高可靠性、高性能、面向列、可伸缩的分布式存储系统&#xff0c;利用HBase技术可在廉价PC Server上搭建起大规模结构化存储集群。HBase利用Hadoop HDFS作为其文件存储系统&#xff0c;利用Hadoop MapReduce来处理HBas…