GenerationMixin概述

类名简单说明
GenerateDecoderOnlyOutput继承自 ModelOutput,适用于非束搜索方法的解码器-only模型输出类。
GenerateEncoderDecoderOutput继承自 ModelOutput,适用于非束搜索方法的编码器-解码器模型输出类。
GenerateBeamDecoderOnlyOutput继承自 ModelOutput,适用于束搜索方法的解码器-only模型输出类。
GenerateBeamEncoderDecoderOutput继承自 ModelOutput,适用于束搜索方法的编码器-解码器模型输出类。
GreedySearchDecoderOnlyOutputGenerateDecoderOnlyOutput 相同,保留用于向后兼容的别名。
ContrastiveSearchDecoderOnlyOutputGenerateDecoderOnlyOutput 相同,保留用于向后兼容的别名。
SampleDecoderOnlyOutputGenerateDecoderOnlyOutput 相同,保留用于向后兼容的别名。
GreedySearchEncoderDecoderOutputGenerateEncoderDecoderOutput 相同,保留用于向后兼容的别名。
ContrastiveSearchEncoderDecoderOutputGenerateEncoderDecoderOutput 相同,保留用于向后兼容的别名。
SampleEncoderDecoderOutputGenerateEncoderDecoderOutput 相同,保留用于向后兼容的别名。
BeamSearchDecoderOnlyOutputGenerateBeamDecoderOnlyOutput 相同,保留用于向后兼容的别名。
BeamSampleDecoderOnlyOutputGenerateBeamDecoderOnlyOutput 相同,保留用于向后兼容的别名。
BeamSearchEncoderDecoderOutputGenerateBeamEncoderDecoderOutput 相同,保留用于向后兼容的别名。
BeamSampleEncoderDecoderOutputGenerateBeamEncoderDecoderOutput 相同,保留用于向后兼容的别名。
GreedySearchOutputGreedySearchEncoderDecoderOutputGreedySearchDecoderOnlyOutput 的联合类型。
SampleOutputSampleEncoderDecoderOutputSampleDecoderOnlyOutput 的联合类型。
BeamSearchOutputBeamSearchEncoderDecoderOutputBeamSearchDecoderOnlyOutput 的联合类型。
BeamSampleOutputBeamSampleEncoderDecoderOutputBeamSampleDecoderOnlyOutput 的联合类型。
ContrastiveSearchOutputContrastiveSearchEncoderDecoderOutputContrastiveSearchDecoderOnlyOutput 的联合类型。
GenerateNonBeamOutputGenerateDecoderOnlyOutputGenerateEncoderDecoderOutput 的联合类型。
GenerateBeamOutputGenerateBeamDecoderOnlyOutputGenerateBeamEncoderDecoderOutput 的联合类型。
GenerateOutputGenerateNonBeamOutputGenerateBeamOutput 的联合类型。
GenerationMixin包含自动回归文本生成所有功能的类,可作为 PreTrainedModel 的 mixin 使用。
  • 定义了多个数据类(@dataclass),这些类继承自 ModelOutput,用于表示生成模型在不同情况下的输出结果。
    Python:@dataclass装饰器

  • 定义了一些等价的类和类型简写(typing shortcuts),主要是为了兼容旧版本的代码,也方便在代码中进行类型提示。

重点解释以下三个类:

  1. GenerateDecoderOnlyOutput

  2. GenerateEncoderDecoderOutput

  3. GenerateNonBeamOutput


1. GenerateDecoderOnlyOutput

描述:

GenerateDecoderOnlyOutput 是一个数据类,用于表示 仅解码器模型(decoder-only models) 在使用 非束搜索方法(non-beam methods) 进行生成时的输出结果。

主要用途:

此类主要用于像 GPT-2、GPT-3 等仅包含解码器的模型,当它们使用贪婪搜索(Greedy Search)、随机采样(Sampling)、对比搜索(Contrastive Search)等非束搜索方法进行文本生成时,封装和返回生成的结果。

2. GenerateEncoderDecoderOutput

描述:

GenerateEncoderDecoderOutput 是一个数据类,用于表示 编码器-解码器模型(encoder-decoder models) 在使用 非束搜索方法(non-beam methods) 进行生成时的输出结果。

主要用途:

此类主要用于像 BART、T5 等包含编码器和解码器的模型,当它们使用贪婪搜索、随机采样、对比搜索等非束搜索方法进行文本生成时,封装和返回生成的结果。

GenerateDecoderOnlyOutputGenerateEncoderDecoderOutput

字段名GenerateDecoderOnlyOutputGenerateEncoderDecoderOutput
sequences必填
torch.LongTensor
形状:(batch_size, sequence_length)
生成的序列。
必填
torch.LongTensor
形状:(batch_size * num_return_sequences, sequence_length)
生成的序列。
scores可选
Optional[Tuple[torch.FloatTensor]]
output_scores=True 时返回。
处理后的预测分数(每步)。
可选
Optional[Tuple[torch.FloatTensor]]
同左。
logits可选
Optional[Tuple[torch.FloatTensor]]
output_logits=True 时返回。
未经处理的预测分数(每步)。
可选
Optional[Tuple[torch.FloatTensor]]
同左。
attentions可选
Optional[Tuple[Tuple[torch.FloatTensor]]]
output_attentions=True 时返回。
解码器每层的注意力权重。
名称不同
GenerateEncoderDecoderOutput 中,该字段为 decoder_attentions
hidden_states可选
Optional[Tuple[Tuple[torch.FloatTensor]]]
output_hidden_states=True 时返回。
解码器每层的隐藏状态。
名称不同
GenerateEncoderDecoderOutput 中,该字段为 decoder_hidden_states
past_key_values可选
Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]]
use_cache=True 时返回。
模型的缓存状态。
可选
Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]]
同左。
encoder_attentions无此字段可选
Optional[Tuple[torch.FloatTensor]]
output_attentions=True 时返回。
编码器每层的注意力权重。
encoder_hidden_states无此字段可选
Optional[Tuple[torch.FloatTensor]]
output_hidden_states=True 时返回。
编码器每层的隐藏状态。
decoder_attentions无此字段
对应于 attentions 字段。
可选
Optional[Tuple[Tuple[torch.FloatTensor]]]
解码器每层的注意力权重。
decoder_hidden_states无此字段
对应于 hidden_states 字段。
可选
Optional[Tuple[Tuple[torch.FloatTensor]]]
解码器每层的隐藏状态。
cross_attentions无此字段可选
Optional[Tuple[Tuple[torch.FloatTensor]]]
output_attentions=True 时返回。
解码器每层的跨注意力权重。

字段详解

共有字段
  • sequences

    • 描述:生成的序列。
    • GenerateDecoderOnlyOutput:形状为 (batch_size, sequence_length)
    • GenerateEncoderDecoderOutput:形状为 (batch_size * num_return_sequences, sequence_length)
  • scores

    • 描述:处理后的预测分数(即在 SoftMax 之前的 logits),每步生成一个。
    • 类型Optional[Tuple[torch.FloatTensor]]
    • 返回条件output_scores=True
  • logits

    • 描述:未经处理的预测分数(logits),每步生成一个。
    • 类型Optional[Tuple[torch.FloatTensor]]
    • 返回条件output_logits=True
  • past_key_values

    • 描述:模型的缓存状态,用于加速解码。
    • 类型Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]]
    • 返回条件use_cache=True
仅在 GenerateDecoderOnlyOutput
  • attentions

    • 描述:解码器的注意力权重。
    • 类型Optional[Tuple[Tuple[torch.FloatTensor]]]
    • 返回条件output_attentions=True
  • hidden_states

    • 描述:解码器的隐藏状态。
    • 类型Optional[Tuple[Tuple[torch.FloatTensor]]]
    • 返回条件output_hidden_states=True
仅在 GenerateEncoderDecoderOutput
  • encoder_attentions

    • 描述:编码器的注意力权重。
    • 类型Optional[Tuple[torch.FloatTensor]]
    • 返回条件output_attentions=True
  • encoder_hidden_states

    • 描述:编码器的隐藏状态。
    • 类型Optional[Tuple[torch.FloatTensor]]
    • 返回条件output_hidden_states=True
  • decoder_attentions

    • 描述:解码器的注意力权重(相当于 GenerateDecoderOnlyOutput 中的 attentions)。
    • 类型Optional[Tuple[Tuple[torch.FloatTensor]]]
    • 返回条件output_attentions=True
  • decoder_hidden_states

    • 描述:解码器的隐藏状态(相当于 GenerateDecoderOnlyOutput 中的 hidden_states)。
    • 类型Optional[Tuple[Tuple[torch.FloatTensor]]]
    • 返回条件output_hidden_states=True
  • cross_attentions

    • 描述:解码器的跨注意力权重(解码器与编码器之间的注意力)。
    • 类型Optional[Tuple[Tuple[torch.FloatTensor]]]
    • 返回条件output_attentions=True

3. GenerateNonBeamOutput

描述:

GenerateNonBeamOutput 是一个类型别名,用于表示在使用 非束搜索方法(non-beam methods) 进行生成时,模型的输出结果。

定义:
GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput]
含义:
  • 它可以是 GenerateDecoderOnlyOutput 类型,也可以是 GenerateEncoderDecoderOutput 类型。

  • 这个类型别名的存在,使得在处理非束搜索生成输出时,可以统一处理,不用区分模型是仅解码器模型还是编码器-解码器模型。


附加说明:

  • 非束搜索方法(Non-beam methods)

    指在生成文本时,不使用束搜索(Beam Search)算法的生成方法,例如贪婪搜索、随机采样、对比搜索等。这些方法通常速度更快,但可能生成的结果质量不如束搜索。

  • 缓存机制(Past Key Values)

    在生成长序列时,模型可以缓存之前计算的键和值,以避免重复计算,提高生成效率。缓存的内容和格式因模型而异。


GenerationMixin

以下是对 GenerationMixin 类中各个方法和属性的整理,包括分类和功能描述:

类别名称功能描述
静态方法_expand_inputs_for_generation扩展输入以用于生成,将张量从 [batch_size, ...] 扩展为 [batch_size * expand_size, ...]
实例方法prepare_inputs_for_generation准备生成所需的模型输入,包括计算注意力掩码或根据缓存裁剪输入等操作。
实例方法_prepare_model_inputs提取用于生成的模型特定输入。
实例方法_maybe_initialize_input_ids_for_generation在必要时初始化用于生成的 input_ids
实例方法_prepare_attention_mask_for_generation为生成准备注意力掩码。
实例方法_prepare_encoder_decoder_kwargs_for_generation在生成期间为编码器-解码器模型准备 kwargs
实例方法_prepare_decoder_input_ids_for_generation为编码器-解码器模型准备用于生成的 decoder_input_ids
实例方法_update_model_kwargs_for_generation更新下一步生成所需的 model_kwargs
实例方法_reorder_cache重新排序缓存,需要在子类中实现,以适应 beam search 等方法。
实例方法_get_candidate_generator返回在辅助生成中使用的候选生成器。
实例方法_get_logits_processor返回一个 LogitsProcessorList,其中包含所有用于修改分数的相关 LogitsProcessor 实例。
实例方法_get_stopping_criteria返回用于生成的 StoppingCriteriaList,包括各种停止条件。
实例方法_merge_criteria_processor_list合并默认和自定义的 criteria 或 processor 列表。
实例方法compute_transition_scores根据生成的得分计算序列的转移得分。
实例方法_validate_model_class验证模型类是否兼容生成操作。
实例方法_validate_assistant验证辅助模型(如果提供)是否兼容和正确配置。
实例方法_validate_model_kwargs验证用于生成的 model_kwargs 参数。
实例方法_validate_generated_length执行与生成长度相关的验证,确保参数设置正确。
实例方法_prepare_generated_length在生成配置中准备最大和最小长度,避免参数冲突。
实例方法_prepare_generation_config准备基础的生成配置,并应用来自 kwargs 的任何选项。
实例方法_get_initial_cache_position计算预填充阶段的 cache_position
实例方法_get_cache根据参数为生成设置缓存。
实例方法_supports_default_dynamic_cache返回模型是否支持将 DynamicCache 实例作为 past_key_values
实例方法_prepare_cache_for_generation为生成准备缓存,并将其写入 model_kwargs
实例方法_supports_logits_to_keep返回模型是否支持 logits_to_keep 参数,用于节省内存。
实例方法_prepare_special_tokens为生成准备特殊的 tokens,如 bos_token_ideos_token_id 等。
实例方法generate为具有语言模型头的模型生成 token 序列,是生成过程的主要入口方法。
实例方法_has_unfinished_sequences检查设备中是否仍然存在未完成的序列,用于确定是否继续生成循环。
实例方法heal_tokens生成 token 序列,其中每个序列的尾部 token 替换为适当的扩展,用于修复不完整的 token。
实例方法_dola_decoding使用 DoLa 解码生成序列,一种改进生成质量的解码策略。
实例方法_contrastive_search使用对比搜索生成序列,旨在改善生成文本的质量和多样性。
实例方法_sample使用多项式采样生成序列,可以实现随机性和多样性。
实例方法_temporary_reorder_cache临时函数,用于处理不同类型的缓存重新排序。
实例方法_beam_search使用 beam search 解码生成序列,支持高质量的序列生成。
实例方法_group_beam_search使用分组 beam search 解码生成序列,引入多样性。
实例方法_constrained_beam_search使用受限 beam search 解码生成序列,支持强制包含特定词语等约束。
实例方法_assisted_decoding使用辅助解码生成序列,利用辅助模型加速和改善生成过程。

GenerationMixin中最核心的方法是generate方法,其它方法都是generate方法的辅助方法:
GenerationMixin:generate

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/75760.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【备赛】蓝桥杯嵌入式实现led闪烁

原理 由于蓝桥杯的板子带有锁存器,并且与lcd屏幕有冲突,所以这个就成了考点。 主要就是用定时器来实现,同时也要兼顾lcd的冲突。 一、处理LCD函数 首先来解决与lcd屏幕冲突的问题,把我们所有用到的lcd函数改装一下。 以下是基…

C++ 并发性能优化实战:提升多线程应用的效率与稳定性

🧑 博主简介:CSDN博客专家、CSDN平台优质创作者,获得2024年博客之星荣誉证书,高级开发工程师,数学专业,拥有高级工程师证书;擅长C/C、C#等开发语言,熟悉Java常用开发技术&#xff0c…

Python----计算机视觉处理(Opencv:道路检测之车道线拟合)

完整版: Python----计算机视觉处理(Opencv:道路检测完整版:透视变换,提取车道线,车道线拟合,车道线显示) 一、获取左右车道线的原始位置 导入模块 import cv2 import numpy as np from matplot…

优选算法的妙思之流:分治——归并专题

专栏:算法的魔法世界 个人主页:手握风云 目录 一、归并排序 二、例题讲解 2.1. 排序数组 2.2. 交易逆序对的总数 2.3. 计算右侧小于当前元素的个数 2.4. 翻转对 一、归并排序 归并排序也是采用了分治的思想,将数组划分为多个长度为1的子…

C语言查漏补缺:基础篇

1.原理 C语言是一门编译型计算机语言,要编写C代码,C源代码文本文件本身无法直接执行,必须通过编译器翻译和链接器的链接,生成二进制的可执行文件,然后才能执行。这里的二进制的可执行文件就是我们最终要形成的可执行程…

TPS入门DAY02 服务器篇

1.创建空白插件 2.导入在线子系统以及在线steam子系统库 MultiplayerSessions.uplugin MultiplayerSessions.Build.cs 3.创建游戏实例以及初始化会话创建流程 创建会话需要的函数,委托,委托绑定的回调,在线子系统接口绑定某一个委托的控制其…

产品经理课程

原型工具 一、土耳其机器人 这个说法来源于 1770 年出现的一个骗局,一个叫沃尔夫冈冯肯佩伦(Wolfgang von Kempelen)的人为了取悦奥地利女皇玛丽娅特蕾莎(Maria Theresia),“制造”了一个会下国际象棋的机…

nginx中的limit_req 和 limit_conn

在 Nginx 中,limit_req 和 limit_conn 是两个用于限制客户端请求的指令,它们分别用于限制请求速率和并发连接数。 limit_req limit_req 用于限制请求速率,防止客户端发送过多请求影响服务器性能。它通过 limit_req_zone 指令定义一个共享内存…

基于winform的串口调试助手

目录 一、串口助手界面设计 1.1 串口配置 1.2 接收配置 1.3 发送配置 1.4 接收窗口和发送窗口 1.5 状态显示窗口 1.6 串口通讯控件 二、程序编写 2.1 端口号自动识别并显示在端口号下拉框 功能说明: 2.2 波特率下拉框显示 2.3 数据位下拉框显示 2.4 校…

Docker基础2

如需转载,标记出处 本次我们将下载一个 Docker 镜像,从镜像中启动容器 上一章,安装 Docker 时,获得两个主要组件: Docker 客户端 Docker 守护进程(有时称为“服务器”或“引擎”) 守护进程实…

Rocketmq2

一、生产者端防丢失 1. 发送方式选择 同步发送:使用 send() 方法,等待 Broker 确认响应(SendResult),确保消息已成功发送。异步发送:使用 sendAsync() 方法并设置回调函数,处理发送成功 / 失败…

RabbitMQ详解,RabbitMQ是什么?架构是怎样的?

目录 一,RabbitMQ是什么? 二,RabbitMQ架构 2.1 首先我们来看下RabbitMQ里面的心概念Queue是什么? 2.2 交换器Exchange 2.3 RabbitMQ是什么? 2.4 重点看下优先级队列是什么? 三,RabbitMQ集群 3.1 普通集群模式 3.2 镜像队列集群 一,RabbitMQ是什么? 假设我们程序…

【一步步开发AI运动APP】六、运动计时计数能调用

之前我们为您分享了【一步步开发AI运动小程序】开发系列博文,通过该系列博文,很多开发者开发出了很多精美的AI健身、线上运动赛事、AI学生体测、美体、康复锻炼等应用场景的AI运动小程序;为了帮助开发者继续深耕AI运动领域市场,今…

MySQL——DQL的多表查询

一、交叉连接 标准语法:select * from 表1 cross join 表2 where 表1.公共列 表2.公共列; 简单语法:select * from 表1 , 表2 where 表1.公共列 表2.公共列; 公共列:两张表具有相同含义的列,不是列名一样。 …

【Linux内核】如何更加优雅阅读Linux内核源码(vscode)

1. 前言 因为已经习惯在Ubuntu下进行嵌入式工作开发,但Linux源码在Source Insight下进行阅读,一直很苦恼Linux/Windows来回切换的开发方式,当前发现可以通过 vscode clangd(扩展组件) 方式进行更好的内核源码阅读。 2. 环境 操作系统&…

21.OpenCV获取图像轮廓信息

OpenCV获取图像轮廓信息 在计算机视觉领域,识别和分析图像中的对象形状是一项基本任务。OpenCV 库提供了一个强大的工具——轮廓检测(Contour Detection),它能够帮助我们精确地定位对象的边界。这篇博文将带你入门 OpenCV 的轮廓…

LETTERS(DFS)

【题目描述】 给出一个rowcolrowcol的大写字母矩阵,一开始的位置为左上角,你可以向上下左右四个方向移动,并且不能移向曾经经过的字母。问最多可以经过几个字母。 【输入】 第一行,输入字母矩阵行数RR和列数SS,1≤R,S≤…

Day2-2:前端项目uniapp壁纸实战

再在wallpaper新建一个目录components 在components下新建组件common-title 记得点击创建同名目录 在index加 <view class"select"><common-title></common-title></view> 图片换了下&#xff0c;原来的有点丑&#xff0c;图片可按自己喜欢…

其他 vector 操作详解(四十)

介绍 除去向 vector 添加元素&#xff08;如 push_back&#xff09;之外&#xff0c;vector 还提供了许多其他操作&#xff0c;这些操作大多与 string 的操作类似。通过掌握这些操作&#xff0c;我们可以方便地查询、修改和比较 vector 中的元素&#xff0c;从而构建灵活、高效…

【Leetcode 每日一题】368. 最大整除子集

问题背景 给你一个由 无重复 正整数组成的集合 n u m s nums nums&#xff0c;请你找出并返回其中最大的整除子集 a n s w e r answer answer&#xff0c;子集中每一元素对 ( a n s w e r [ i ] , a n s w e r [ j ] ) (answer[i], answer[j]) (answer[i],answer[j]) 都应当…