【Pytorch神经网络实战案例】34 使用GPT-2模型实现句子补全功能(手动加载)

 

1 GPT-2 模型结构

GPT-2的整体结构如下图,GPT-2是以Transformer为基础构建的,使用字节对编码的方法进行数据预处理,通过预测下一个词任务进行预训练的语言模型。

1.1 GPT-2 功能简介

GPT-2 就是一个语言模型,能够根据上文预测下一个单词,所以它就可以利用预训练已经学到的知识来生成文本,如生成新闻。也可以使用另一些数据进行微调,生成有特定格式或者主题的文本,如诗歌、戏剧。

2 手动加载GPT-2模型并实现语句与完整句子预测

使用GPT-2模型配套的PreTrainedTokenizer类,所需要加载的词表文件比BERT模型多了一个merges文件。

2.1 代码实现:手动加载GPT-2模型并实现下一个单词预测---GPT2_make.py(第1部分)

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel# 案例描述:Transformers库中的GPT-2模型,并用它实现下一词预测功能,即预测一个未完成句子的下一个可能出现的单词。
# 下一词预测任务是一个常见的任务,在Transformers库中有很多模型都可以实现该任务。也可以使用BERT模型来实现。选用GPT-2模型,主要在于介绍手动加载多词表文件的特殊方式。# 1.1 加载词表文件# 自动加载预训练模型(权重)
# tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# 手动加载词表文件:gpt2-merges.txt gpt2-vocab.json。
# from_pretrained方法是支持从本地载入多个词表文件的,但对载入的词表文件名称有特殊的要求:该文件名称必须按照源码文件tokenization_gpt2.py的VOCAB_FILES_NAMES字典对象中定义的名字来命名。
# 故使用from_pretrained方法,必须对已经下载好的词表文件进行改名将/gpt2/gpt2-vocab.json和/gpt2/gpt2-merges.txt这两个文件,分别改名为“gpt2/vocab.json和/gpt2/merges.txt
tokenizer = GPT2Tokenizer.from_pretrained(r'./models/gpt2') # 自动加载改名后的文件# 编码输入
indexed_tokens = tokenizer.encode("Who is Li BiGor ? Li BiGor is a")
print("输入语句为:",tokenizer.decode(indexed_tokens))
tokens_tensor = torch.tensor([indexed_tokens])  # 将输入语句转换为张量# 自动加载预训练模型(权重)
# model = GPT2LMHeadModel.from_pretrained('gpt2')
# 手动加载:配置文件gpt2-config.json 与 权重文件pt2-pytorch_model.bin
model = GPT2LMHeadModel.from_pretrained('./models/gpt2/gpt2-pytorch_model.bin',config='./models/gpt2/gpt2-config.json')# 将模型设置为评估模式
model.eval()
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokens_tensor = tokens_tensor.to(DEVICE)
model.to(DEVICE)# 预测所有标记
with torch.no_grad():outputs = model(tokens_tensor)predictions = outputs[0]# 得到预测的下一词
predicted_index = torch.argmax(predictions[0, -1, :]).item()
predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
print("输出语句为:",predicted_text) # GPT-2模型没有为输入文本添加特殊词。
# 输出:Who is Li BiGor? Li BiGor is a Chinese

2.2 代码实现:手动加载GPT-2模型并实现完整句子预测---GPT2_make.py(第2部分)

# 案例描述:Transformers库中的GPT-2模型,通过循环生成下一词,实现将一句话补充完整。
# 1.2 生成一段完整的话 这里有BUg 暂不会改
stopids = tokenizer.convert_tokens_to_ids(["."])[0] # 定义结束符# 在循环调用模型预测功能时,使用了模型的past功能。该功能以使模型进入连续预测状态,即在前面预测结果的基础之上进行下一词预测,而不需要在每预测时,对所有句子进行重新处理。
# past功能是使用预训练模型时很常用的功能,在Transformers库中,凡是带有下一词预测功能的预训练模型(如GPT,XLNet,CTRL等)都有这个功能。
# 但并不是所有模型的past功能都是通过past参数进行设置的,有的模型虽然使用的参数名称是mems,但作用与pat参数一样。
past = None # 定义模型参数for i in range(100):    # 循环100次with torch.no_grad():output, past = model(tokens_tensor, past=past)  # 预测下一次token = torch.argmax(output[..., -1, :])indexed_tokens += [token.tolist()]  # 将预测结果收集if stopids == token.tolist():   # 当预测出句号时,终止预测。breaktokens_tensor = token.unsqueeze(0)  # 定义下一次预测的输入张量sequence = tokenizer.decode(indexed_tokens) # 进行字符串编码
print(sequence)

3 GPT2_make.py(汇总)

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel# 案例描述:Transformers库中的GPT-2模型,并用它实现下一词预测功能,即预测一个未完成句子的下一个可能出现的单词。
# 下一词预测任务是一个常见的任务,在Transformers库中有很多模型都可以实现该任务。也可以使用BERT模型来实现。选用GPT-2模型,主要在于介绍手动加载多词表文件的特殊方式。# 1.1 加载词表文件# 自动加载预训练模型(权重)
# tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# 手动加载词表文件:gpt2-merges.txt gpt2-vocab.json。
# from_pretrained方法是支持从本地载入多个词表文件的,但对载入的词表文件名称有特殊的要求:该文件名称必须按照源码文件tokenization_gpt2.py的VOCAB_FILES_NAMES字典对象中定义的名字来命名。
# 故使用from_pretrained方法,必须对已经下载好的词表文件进行改名将/gpt2/gpt2-vocab.json和/gpt2/gpt2-merges.txt这两个文件,分别改名为“gpt2/vocab.json和/gpt2/merges.txt
tokenizer = GPT2Tokenizer.from_pretrained(r'./models/gpt2') # 自动加载改名后的文件# 编码输入
indexed_tokens = tokenizer.encode("Who is Li BiGor ? Li BiGor is a")
print("输入语句为:",tokenizer.decode(indexed_tokens))
tokens_tensor = torch.tensor([indexed_tokens])  # 将输入语句转换为张量# 自动加载预训练模型(权重)
# model = GPT2LMHeadModel.from_pretrained('gpt2')
# 手动加载:配置文件gpt2-config.json 与 权重文件pt2-pytorch_model.bin
model = GPT2LMHeadModel.from_pretrained('./models/gpt2/gpt2-pytorch_model.bin',config='./models/gpt2/gpt2-config.json')# 将模型设置为评估模式
model.eval()
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokens_tensor = tokens_tensor.to(DEVICE)
model.to(DEVICE)# 预测所有标记
with torch.no_grad():outputs = model(tokens_tensor)predictions = outputs[0]# 得到预测的下一词
predicted_index = torch.argmax(predictions[0, -1, :]).item()
predicted_text = tokenizer.decode(indexed_tokens + [predicted_index])
print("输出语句为:",predicted_text) # GPT-2模型没有为输入文本添加特殊词。
# 输出:Who is Li BiGor? Li BiGor is a Chinese# 案例描述:Transformers库中的GPT-2模型,通过循环生成下一词,实现将一句话补充完整。
# 1.2 生成一段完整的话 这里有BUg 暂不会改
stopids = tokenizer.convert_tokens_to_ids(["."])[0] # 定义结束符# 在循环调用模型预测功能时,使用了模型的past功能。该功能以使模型进入连续预测状态,即在前面预测结果的基础之上进行下一词预测,而不需要在每预测时,对所有句子进行重新处理。
# past功能是使用预训练模型时很常用的功能,在Transformers库中,凡是带有下一词预测功能的预训练模型(如GPT,XLNet,CTRL等)都有这个功能。
# 但并不是所有模型的past功能都是通过past参数进行设置的,有的模型虽然使用的参数名称是mems,但作用与pat参数一样。
past = None # 定义模型参数for i in range(100):    # 循环100次with torch.no_grad():output, past = model(tokens_tensor, past=past)  # 预测下一次token = torch.argmax(output[..., -1, :])indexed_tokens += [token.tolist()]  # 将预测结果收集if stopids == token.tolist():   # 当预测出句号时,终止预测。breaktokens_tensor = token.unsqueeze(0)  # 定义下一次预测的输入张量sequence = tokenizer.decode(indexed_tokens) # 进行字符串编码
print(sequence)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/469141.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

电容式传感器位移性能试验报告_一文读懂什么是接近传感器?

点击上方蓝字 记得关注我们哦!接近传感器是一种非接触式传感器,当目标进入传感器的视野时,它会检测到物体(通常称为“目标”)的存在。取决于接近传感器的类型,传感器可以利用声音,光,红外辐射(IR)或电磁场来…

[dts]DTS实例分析

2. 通常会碰到的实际问题 到此,问题出现了: 1. 当写一个按键驱动,应该如何在*.dts或者*.dtsi中操作? 2. 当在串口driver中需要使用到某个pin脚作为普通输出IO,该如何操作? 3. 当在串口driver中需要使用某个muxpin脚作…

【解决】YOLOv6.1安装requirements.txt报错UnicodeDecodeError: ‘gbk‘ codec can‘t decode byte 0x84

案例描述: 使用YOLOV5时,报错解决UnicodeDecodeError: ‘gbk‘ codec can‘t decode byte 0x84 in position 285: illegal multibyte sequence 解决方案: 在C:\ProgramData\Anaconda3\lib\distutils\dist.py"文件搜索read 将parser.read(filenam…

bp 神经网络 优点 不足_深度学习之BP神经网络--Stata和R同步实现(附Stata数据和代码)

说明:本文原发于“计量经济圈”公众号,在此仅展示Stata的部分。R部分请移步至本人主页的“R语言与机器学习--经济学视角”专栏,或点击下方链接卡跳转。盲区行者:深度学习之BP神经网络--Stata和R同步实现(附R数据和代码…

Win10系统下使用anaconda在虚拟环境下安装CUDA及CUDNN

前排预警:不要挂梯子!!!!!使用清华源就行不然报错!!!! 解决check_hostname requires server_hostname_orange_の_呜的博客-CSDN博客错误描述在GitHub下载代码文件后使用pip install -r requirement.txt下载依赖包时出…

将XML格式转化为YOLO需要的txt格式(代码)

1、XML的格式 <annotation><folder>cr</folder><filename>crazing_2.jpg</filename><source><database>NEU-DET</database></source><size><width>200</width><height>200</height><…

js 点击button切换颜色_ThingJS 和three.js开发示例对比,让开发早点下班回家!3D 可视化...

ThingJS 3D框架简化了开发工作&#xff0c;面向对象和模块化的特点使得网页代码更加易于管理和维护&#xff0c;并且提供近200个官方示例&#xff0c;直接获取API能力&#xff0c;不需要基于3D概念进行开发&#xff0c;适合3D商业项目快速生成&#xff01;距离您的业务仅一层之…

VSCode使用技巧——Ctrl+鼠标滚轮键使字体进行缩放

点击VSCode左下角的齿轮&#xff0c;进入设置 进入Extensions——》JSON——》Edit in settings.json 在json当中添加如下&#xff1a; "editor.mouseWheelZoom": true,

OpenCV各版本差异与演化,从1.x到4.0

最近因项目需要&#xff0c;得把OpenCV捡起来&#xff0c;登录OpenCV官网&#xff0c;竟然发现release了4.0.0-beata版本&#xff0c;所以借此机会&#xff0c;查阅资料&#xff0c;了解下OpenCV各版本的差异及其演化过程&#xff0c;形成了以下几点认识&#xff1a; 新版本的…

西门子s7-200解密软件下载_西门子S7-200/300/400通讯方式汇总,超级全面

1西门子 200 plc 使用 MPI 协议与组态王进行通讯时需要哪些设置?1)在运行组态王的机器上需要安装西门子公司提供的 STEP7 Microwin 3.2 的编程软件&#xff0c;我们的驱动需要调用编程软件提供的 MPI 接口库函数;2)需要将 MPI 通讯卡 CP5611 卡安装在计算机的插槽中&#xff0…

如何监控NVIDIA GPU 的运行状态和使用情况

设备跟踪和管理正成为机器学习工程的中心焦点。这个任务的核心是在模型训练过程中跟踪和报告gpu的使用效率。 有效的GPU监控可以帮助我们配置一些非常重要的超参数&#xff0c;例如批大小&#xff0c;还可以有效的识别训练中的瓶颈&#xff0c;比如CPU活动(通常是预处理图像)占…

HBase学习笔记——概念及原理

1.什么是HBase HBase – Hadoop Database&#xff0c;是一个高可靠性、高性能、面向列、可伸缩的分布式存储系统&#xff0c;利用HBase技术可在廉价PC Server上搭建起大规模结构化存储集群。HBase利用Hadoop HDFS作为其文件存储系统&#xff0c;利用Hadoop MapReduce来处理HBas…

.bat是什么语言_简单说说当我们打开网页时,浏览器到底做了什么?

前言&#xff1a;为什么我们需要掌握浏览器的原理作为一名前端研发&#xff0c;平日里打交道最多的&#xff0c;就是各式各样的客户端。不论你是针对pc端还是移动端&#xff0c;甚至是专门在微信端做前端研发&#xff0c;都需要跟一样东西接触——浏览器。不知道你有没有留意过…

花书《深度学习》代码实现:02 概率部分:概率密度函数+期望+常见概率分布代码实现

1 概率 1.1 概率与随机变量 频率学派概率 (Frequentist Probability)&#xff1a;认为概率和事件发⽣的频率相关。贝叶斯学派概率 (Bayesian Probability)&#xff1a;认为概率是对某件事发⽣的确定程度&#xff0c;可以理解成是确信的程度。随机变量 (Random Variable)&…

内存泄露Lowmemorykiller分析

01 前言 最近疫苗事情非常火热,这件事情让我对刘强东有点刮目相看,我们需要更多的人关注曝光此类问题 02 正文 Android Kernel 会定时执行一次检查,杀死一些进程,释放掉内存。Low memory killer 是定时进行检查。Low memory killer 主要是通过进程的oom_adj 来判定进程的…

TabError: inconsistent use of tabs and spaces in indentation

本文使用PyCharm的格式化代码功能解决TabError: inconsistent use of tabs and spaces in indentation。 1、提出问题&#xff1a; 当把代码从别处复制进来PyCharm&#xff0c;然后运行报错&#xff1a;TabError: inconsistent use of tabs and spaces in indentation 2、 分…

python 默认参数_有趣的 Python 特性 3 | 当心默认可变参数这个大猪蹄子。

本文字数&#xff1a;1575 字阅读本文大概需要&#xff1a;4 分钟写在之前Python 提供了很多让使用者觉得舒服至极的功能特性&#xff0c;但是随着不断的深入学习和使用 Python&#xff0c;我发现其中存在着许多玄学的输出与之前预想的结果大相径庭&#xff0c;这个对于初学者来…

AI-无损检测方向速读:基于深度学习的表面缺陷检测方法综述

1 表面缺陷检测的概念 表面缺陷检测是机器视觉领域中非常重要的一项研究内容, 也称为 AOI (Automated optical inspection) 或 ASI (Automated surface inspection)&#xff0c;它是利用机器视觉设备获取图像来判断采集图像中是否存在缺陷的技术。 1.1 传统检测的缺陷(非CNN)…

【完美解决】RuntimeError: one of the variables needed for gradient computation has been modified by an inp

正文在后面&#xff0c;往下拉即可~~~~~~~~~~~~ 欢迎各位深度学习的小伙伴订阅的我的专栏 Pytorch深度学习理论篇实战篇(2023版)专栏地址&#xff1a; &#x1f49b;Pytorch深度学习理论篇(2023版)https://blog.csdn.net/qq_39237205/category_12077968.html &#x1f49a;Pyt…

python正则表达式入门_Python中的正则表达式教程

本文http://www.cnblogs.com/huxi/archive/2010/07/04/1771073.html 正则表达式经常被用到&#xff0c;而自己总是记不全&#xff0c;转载一份完整的以备不时之需。 1. 正则表达式基础 1.1. 简单介绍 正则表达式并不是Python的一部分。正则表达式是用于处理字符串的强大工具&am…