用Transformers实现简单的大模型文本生成

根据输入的prompt,生成一段指定长度的文字。Llama跑起来太慢了,这里用GPT-2作为列子。

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torchtokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)prompt_text = "This is a nice story that makes me"
max_gen_len = 9
input_ids = tokenizer.encode(prompt_text, return_tensors="pt")
prompt_len = input_ids.shape[-1]
print(f'length of prompt: {prompt_len}, length of generation: {max_gen_len}')print('>>> Way 1: Use `model.generate()` to generate tokens with KV cache')
generated_ids = model.generate(input_ids, max_length=prompt_len+max_gen_len, pad_token_id=tokenizer.eos_token_id)
print('generated_ids:', generated_ids)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print('generated_text:', generated_text)print('>>> Way 2: Use `for loop` to generate tokens with KV cache')
past_key_values = None
print('Prefill Stage..')
outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
past_key_values = outputs.past_key_values
pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
generated_ids = [pred_token_idx.item()]
print('Decoding/Generating Stage..')
for _ in range(max_gen_len - 1):outputs = model(input_ids=pred_token_idx, past_key_values=past_key_values, use_cache=True)past_key_values = outputs.past_key_values  # if use_cache=False, past_key_values=Nonepred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)generated_ids.append(pred_token_idx.item())
print('generated_ids:', generated_ids)
generated_text = tokenizer.decode(torch.Tensor(generated_ids), skip_special_tokens=True)
print('generated_text:', prompt_text + generated_text)

这里提供了两种方法实现文本生成:

  • model.generate():给模型输入prompt,一次性得到所有输出的token,最方便的写法
  • for loop:这是StreamingLLM中给的代码例子,也揭示了自回归生成的原理。首先是prefill阶段,输入prompt,得到KV cache和生成的第一个token;然后是decoding/generating阶段,开始自回归生成token,每次生成的模型输入是当前新token和KV cache,每生成一个token都会自动更新KV cache

最终,可以看到两种方法生成的文本是一模一样的:

在这里插入图片描述

进一步探究自回归过程中维度的变化:

在这里插入图片描述

这个就是标准的自回归生成任务了,不管是GPT还是Llama,都是如此(至少PyTorch版本都是这样的,Flax版本的KV cache有点奇怪,用的lax.dynamic_update_slice(cached_key.value, key, indices),KV cache的维度并没有随着token的生成而增加…不太明白)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/11697.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

打造清洁宜居家园保护自然生态环境,基于YOLOv7【tiny/l/x】参数系列模型开发构建自然生态场景下违规违法垃圾倾倒检测识别系统

自然生态环境,作为我们人类赖以生存的家园,其健康与否直接关系到我们的生活质量。然而,近年来,一些不法分子为了个人私利,在河边、路边等公共区域肆意倾倒垃圾,严重破坏了环境的健康与平衡。这种行为不仅损…

计算机视觉的应用30-基于深度卷积神经网络CNN模型实现物体表面缺陷检测技术的项目

大家好,我是微学AI,今天给大家介绍一下计算机视觉的应用30-基于深度卷积神经网络CNN模型实现物体表面缺陷检测技术的项目主要包括:物体表面缺陷检测技术项目介绍,数据构造,模型介绍。 物体表面缺陷检测技术是工业自动化…

[附源码]剑灵三系可乐6.1_Win服务端_联网+单机搭建

本教程仅限学习使用,禁止商用,一切后果与本人无关,此声明具有法律效应!!!! 教程是本人亲自搭建成功的,绝对是完整可运行的,踩过的坑都给你们填上了。 如果你是小白也没…

YOLOv9-20240507周更说明|更新MobileNetv4等多种轻量化主干

专栏地址:目前售价售价69.9,改进点70 专栏介绍:YOLOv9改进系列 | 包含深度学习最新创新,助力高效涨点!!! 本周已更新说明: ### ⭐⭐更新时间:2024/5/12⭐⭐ 1. YOLOv9…

SQL Server “provider: Named Pipes Provider, error: 40 -无法打开到SQL Server的连接“错误处理

目录 错误提醒解决办法 错误提醒 连接SQL Server时显示如下错误: 解决办法 (1)首先,打开SQL Server Configuration Manager配置管理器 (2) 停止SQL Server服务 右键点击后,选择【停止】 (3) 启动TCP/IP &…

Co-Driver:基于 VLM 的自动驾驶助手,具有类人行为并能理解复杂的道路场景

24年5月来自俄罗斯莫斯科研究机构的论文“Co-driver: VLM-based Autonomous Driving Assistant with Human-like Behavior and Understanding for Complex Road Scenes”。 关于基于大语言模型的自动驾驶解决方案的最新研究,显示了规划和控制领域的前景。 然而&…

Bittensor怎么挖?手把手教你,使用bitget钱包

4月 Binance 上新 TheBittensorHub (TAO), 这个项目究竟做了什么可以令其在上大舞台前就已经在所有通证中排名前 30? 本文将深度解析。 该项目既不直接贡献数据,也不直接贡献算力。 而是通过区块链网络和激励机制,来对不同的算法进行调度和…

【HarmonyOS】综合应用-《校园通》

概念 本文结合之前的笔记文章知识点,做一个综合性的小应用。 创建一个ArkTS语言的鸿蒙项目,搭建首页面 其界面代码如下,该界面使用了垂直布局,相对布局,轮播布局,以及图片,文本等组件的综合运…

具身智能论文(一)

目录 1. PoSE: Suppressing Perceptual Noise in Embodied Agents for Enhanced Semantic Navigation2. Embodied Intelligence: Bionic Robot Controller Integrating Environment Perception, Autonomous Planning, and Motion Control3. Can an Embodied Agent Find Your “…

免费的国内版 GPT 推荐,5个国产ai工具

提起AI,大家第一个想到的就是GPT。 虽然它确实很厉害,但奈何于我们水土不服,使用门槛有些高。 不过随着GPT的爆火,现在AI智能工具已经遍布到各行各业了,随着时间的推移,国内的AI工具也已经“百花盛放”了…

Pencils Protocol 提供层次化的 Staking,品牌升级不断

Pencils Protocol 是一个 Scroll 生态中的一个综合应用平台,在全新的品牌升级后(原为 Penpad),其在原有的 LaunchPad 的基础上,进一步向收益聚合器、RWA 等板块进行全新的拓展。目前,Pencils Protocol 生态的整体功能板块包括 Lau…

人脸识别技术在访客管理中的应用

访客办理体系,能够使用于政府、戎行、企业、医院、写字楼等众多场所。在办理时,需求对来访人员身份进行精确认证,才能保证来访人员的进入对被访单位不被外来风险入侵。在核实身份时,比较好的方法就是选用人脸辨认技能,…

bat xcopy 解析

echo off set source_folder"C:\path\to\source" set destination_folder"C:\path\to\destination" set exclude_file"C:\path\to\excluded_folders.txt"REM 创建目标文件夹(如果不存在) mkdir %destination_folder% 2>…

JDK的串行收集器介绍与优化指南-01

JDK串行收集器概述 定义与背景 串行收集器(Serial Collector)是Java虚拟机(JVM)中的一种单线程垃圾收集器,它在垃圾收集过程中会暂停所有工作线程,直至收集完成。它适用于内存资源受限、对吞吐量要求不高…

【玄机平台】应急响应

前言: 感谢玄机平台靶机的提供,让我学到了不少东西 平台题解 : 第一章 应急响应-webshell查杀 1.黑客webshell里面的flag flag{xxxxx-xxxx-xxxx-xxxx-xxxx} ssh连接 下载/var/www/html源码(finsehll连直接下)压缩丢…

JavaWeb--15 tlias-web-management 黑马程序员 部门管理(修改部门信息)

tlias 1 需求分析和开发规范2 部门管理2.1 查询部门2.2 删除部门2.3 添加部门2.4 更新部门 1 需求分析和开发规范 需求说明–接口文档–思路分析–开发–测试–前后端联调 查看页面原型明确需求 根据页面原型和需求,进行表结构设计、编写接口文档(已提供) 阅读接口…

vue3专栏项目 -- 四、前后端结合(下)

一、async 和 await 1、使用async 和 await 改造异步请求 在接触后端API以后就遇到了越来越多的异步请求,现在我们就使用async 和 await 改造异步请求。 async function是把返回内容包裹成个Promise返回Promise await 它在async function里面才起作用&#xff0…

ini配置文件怎么存取False

1、ini文件介绍 INI文件(全称为Initialization File,初始化文件)是一种简单的文本文件格式,用于存储配置数据。它广泛应用于操作系统和各种应用程序中,用来保存设置、参数或初始化信息。INI文件的基本结构包括节&…

Office之Word应用(二)

一、页眉添加文件名称和页码 1、双击页眉,点击“页眉-空白(三栏)” 2、删掉第一处(鼠标放在上面就会选中,Enter即可),第二处输入文档名称,第三处插入页码。 注:插入页码时…

Jmeter 性能-阶梯负载最终请求数

1、设置阶梯加压线程组请求参数 说明: 每隔2秒钟,会在1秒内启动5个线程 每次线程加载之后都会运行2s然后开始下一次线程加载 最终会加载50个线程并持续运行30s 50个线程持续运行30s后,会每隔2秒钟停止5个线程,剩余的线程继续负…