【深度学习】序列生成模型(二):束搜索

文章目录

    • 序列生成
    • 束搜索
      • 理论基础
      • 算法步骤
      • python实现

序列生成

  在进行最大似然估计训练后的模型 p θ ( x ∣ x 1 : ( t − 1 ) ) p_\theta(x | \mathbf{x}_{1:(t-1)}) pθ(xx1:(t1)),我们可以使用该模型进行序列生成。生成的过程是按照时间顺序逐步生成序列样本。假设在第 t t t 步,我们已经生成了前 t − 1 t-1 t1 步的序列前缀 x 1 : ( t − 1 ) = x 1 , … , x t − 1 \mathbf{x}_{1:(t-1)} = x_1, \ldots, x_{t-1} x1:(t1)=x1,,xt1,我们希望在当前步生成下一个词 x t x_t xt。生成的过程可以用以下概率分布表示:

x t ∼ p θ ( x ∣ x 1 : ( t − 1 ) ) x_t \sim p_\theta(x | \mathbf{x}_{1:(t-1)}) xtpθ(xx1:(t1))

其中, x 1 : ( t − 1 ) \mathbf{x}_{1:(t-1)} x1:(t1) 是已经生成的前缀序列, x t x_t xt 是在给定前缀序列的条件下,由模型生成的当前时刻的词。

  这个过程可以迭代进行,直到生成完整的序列样本。在每一步,模型根据已经生成的前缀序列生成当前时刻的词,然后将当前时刻的词添加到前缀序列中,用于生成下一个时刻的词。

生成的序列样本可以用如下方式表示:

x ^ = x ^ 1 , x ^ 2 , … , x ^ T \mathbf{\hat{x}} = \hat{x}_1, \hat{x}_2, \ldots, \hat{x}_T x^=x^1,x^2,,x^T

其中, x ^ t \hat{x}_t x^t 是在第 t t t 步生成的词, x ^ \mathbf{\hat{x}} x^ 是完整的生成序列。这个过程是根据训练得到的模型对数据分布进行采样,从而生成新的符合训练数据分布的序列。

  自回归的方式可以生成一个无限长度的序列.为了避免这种情况,通常会设置一个特殊的符号⟨𝐸𝑂𝑆⟩(End-of-Sequence)来表示序列的结束.在训练时,每个序列样本的结尾都会加上结束符号 ⟨ EOS ⟩ \langle \text{EOS} \rangle EOS。训练模型时,这有助于模型学习何时停止生成。在测试时,一旦生成了结束符号 ⟨ EOS ⟩ \langle \text{EOS} \rangle EOS,模型就会中止生成过程。

束搜索

理论基础

  在每个时间步,自回归模型贪婪搜索选择当前条件概率分布中具有最高概率的词作为生成的词。具体而言,对于每个时间步 t t t,生成的词 x ^ t \hat{x}_t x^t是:

x ^ t = arg ⁡ max ⁡ x ∈ V p θ ( x ∣ x 1 : ( t − 1 ) ) \hat{x}_t = \arg\max_{x \in \mathcal{V}} p_\theta(x | \mathbf{x}_{1:(t-1)}) x^t=argxVmaxpθ(xx1:(t1))

其中, V \mathcal{V} V 是词表, x 1 : ( t − 1 ) = x ^ 1 , … , x ^ t − 1 \mathbf{x}_{1:(t-1)} = \hat{x}_1, \ldots, \hat{x}_{t-1} x1:(t1)=x^1,,x^t1 是前 t − 1 t-1 t1 步中已经生成的前缀序列。

  这种贪婪搜索策略是一种简单且直观的方法,但它有一个主要的缺点,即可能导致生成的序列不是全局最优的。由于在每个时间步都选择了局部最大概率的词,生成的序列并不保证是整个序列的全局最大概率。这种策略可能导致生成的序列缺乏一致性或流畅性。
  为了改善这种情况,束搜索(Beam Search)是一种常用的启发式方法,特别在序列生成任务中应用广泛。在束搜索中,每个时间步生成多个备选序列,而不仅仅是一个。这样可以在每个时间步维持一个集合,称为束(beam),其中包含多个备选序列。束的大小由超参数 K K K 决定,通常被称为束大小。
在这里插入图片描述
  在每个时间步,算法选择概率最高的 K K K 个序列作为备选,并将它们作为下一个时间步的输入。这样,算法在整个生成过程中维持了 K K K 条备选序列,允许更全面地探索可能的序列空间。
  束搜索有助于减少搜索空间,提高搜索的效率。然而,束大小 K K K 的选择是一个权衡,较小的 K K K 可能导致搜索空间不够广泛,而较大的 K K K 则会增加计算开销。因此,束大小的选择通常需要根据具体任务和性能需求进行调整。

算法步骤

  1. 初始化: 设置束大小 K K K,初始化一个束(beam)用于存储备选序列。初始时,束中包含一个空序列。

  2. 逐步生成: 对于每个时间步 t t t,执行以下步骤:

    a. 对于束中的每个备选序列,生成下一个词的备选集合。计算条件概率 p θ ( x t ∣ context ) p_\theta(x_t | \text{context}) pθ(xtcontext)

    b. 对于所有的备选序列和它们的备选词,计算在当前时间步的累积概率。

    c. 从所有的备选序列中选择累积概率最高的 K K K个序列作为新的束。

    d. 如果生成了结束符号或达到了最大生成长度,则停止生成。

  3. 输出: 选择束中最终累积概率最高的序列作为最终的生成结果。

python实现

def beam_search(model, initial_context, beam_size, max_length):# 初始化束,初始时包含一个空序列beam = [([], 1.0)]  # 初始序列和初始概率# 逐步生成for t in range(max_length):new_beam = []# 对于束中的每个备选序列for sequence, score in beam:# 生成备选词candidates = generate_candidates(model, sequence, initial_context)# 计算累积概率for candidate in candidates:new_sequence = sequence + [candidate]new_score = score * calculate_probability(model, new_sequence, initial_context)new_beam.append((new_sequence, new_score))# 选择累积概率最高的 K 个序列作为新的束beam = sorted(new_beam, key=lambda x: x[1], reverse=True)[:beam_size]# 判断是否生成了结束符号或达到最大生成长度if is_finished(beam):break# 选择最终累积概率最高的序列作为结果best_sequence = max(beam, key=lambda x: x[1])[0]return best_sequence

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/230824.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

实现el-table操作列点击弹出echarts

代码&#xff1a; <el-table-column :width"90"><template #default"scope"><el-popover placement"left-end" width"550" trigger"click"><div><div style"font-size: 18px; margin-left…

IDEA报错处理

问题1 IDEA 新建 Maven 项目没有文件结构 pom 文件为空 将JDK换成1.8后解决。 网络说法&#xff1a;别用 java18&#xff0c;换成 java17 或者 java1.8 都可以&#xff0c;因为 java18 不是 LTS 版本&#xff0c;有着各种各样的问题。。

numpy-learn

创建数组 import numpy as np import pandas as pd import mathvalue float(nan)# 使用 math.isnan() if math.isnan(value):print("Value is NaN")# 使用 numpy.isnan() if np.isnan(value):print("Value is NaN")np.array([1, 2, 3, 4, 5]) np.linspac…

Hadoop和Spark的区别

Hadoop 表达能力有限。磁盘IO开销大&#xff0c;延迟度高。任务和任务之间的衔接涉及IO开销。前一个任务完成之前其他任务无法完成&#xff0c;难以胜任复杂、多阶段的计算任务。 Spark Spark模型是对Mapreduce模型的改进&#xff0c;可以说没有HDFS、Mapreduce就没有Spark。…

Python 词法分析

Python 程序由 解析器 读取&#xff0c;输入解析器的是 词法分析器 生成的 形符 流。本章介绍词法分析器怎样把文件拆成形符。 Python 将读取的程序文本转为 Unicode 代码点&#xff1b;编码声明用于指定源文件的编码&#xff0c;默认为 UTF-8&#xff0c;详见 PEP 3120。源文…

Wireshark插件开发

第一章&#xff1a;Wireshark基础及捕获技巧 1.1 Wireshark基础知识回顾 1.2 高级捕获技巧&#xff1a;过滤器和捕获选项 1.3 Wireshark与其他抓包工具的比较 第二章&#xff1a;网络协议分析 2.1 网络协议分析&#xff1a;TCP、UDP、ICMP等 2.2 高级协议分析&#xff1a;HTTP…

2023年全球运维大会(GOPS深圳站)-核心PPT资料下载

一、峰会简介 1、大会背景与概述 全球运维大会&#xff08;GOPS&#xff09;是运维领域最具影响力的国际盛会&#xff0c;每年都会汇聚世界各地的运维专家、企业领袖、技术爱好者&#xff0c;共同探讨运维技术的最新发展、最佳实践以及面临的挑战。2023年GOPS深圳站作为该系列…

2023建筑行业薪资趋势?如何提高建筑设计效率呢?

12月6日&#xff0c;国外著名建筑可视化网站CGarchitect公布了其2023年建筑可视化薪资调查结果&#xff0c;详细描述了行业内的薪资趋势。 调查表明&#xff0c;占比较高的是有16.04%的年收入低于10000美元&#xff08;约71000人民币&#xff09;&#xff0c;其次是11.75%的受…

【MyBatis-Plus】多数据源分页配置(低版本暂时就支持一种(可选),高版本多支持)

【转载】一、Mybatis Plus 3.4 版本之后分页插件的变化 1、地址 Mybatis Plus 3.4版本之后分页插件的变化 2、内容 1、MybatisPlusInterceptor 从 Mybatis Plus 3.4.0 版本开始&#xff0c;不再使用旧版本的 PaginationInterceptor&#xff0c;而是使用 MybatisPlusInterce…

【C++】封装:练习案例-点和圆的关系

练习案例&#xff1a;点和圆的关系 设计一个圆形类&#xff08;Circle&#xff09;&#xff0c;和一个点类&#xff08;Point&#xff09;&#xff0c;计算点和圆的关系。 思路&#xff1a; 1&#xff09;创建点类point.h和point.cpp 2&#xff09;创建圆类circle.h和circle…

20、WEB攻防——PHP特性缺陷对比函数CTF考点CMS审计实例

文章目录 一、PHP常用过滤函数&#xff1a;1.1 与1.2 md51.3 intval1.4 strpos1.5 in_array1.6 preg_match1.7 str_replace CTFshow演示三、参考资料 一、PHP常用过滤函数&#xff1a; 1.1 与 &#xff1a;弱类型对比&#xff08;不考虑数据类型&#xff09;&#xff0c;甚至…

Java中的final关键字和static关键字

这两个关键字编写代码时会经常用&#xff0c;正确的使用这些关键字&#xff0c;可以形成良好的编程习惯&#xff0c;保护好代码的封装性。 1、final 关键字 在Java中&#xff0c;利用关键字final指示常量&#xff0c;习惯上&#xff0c;常量名使用全大写。 关键字final表示这个…

计算机网络:自顶向下第八版学习指南笔记和课后实验--运输层

记录一些学习计算机网络:自顶向下的学习笔记和心得 Github地址&#xff0c;欢迎star ⭐️⭐️⭐️⭐️⭐️ 运输层 TCP&#xff1a; 传输控制协议 报文段 UDP&#xff1a; 用户数据包协议 数据报 将主机间交付扩展到进程间交付被称为运输层的多路复用与多路分解 将运输层…

【Java】【Stream流】分组

Java实际开发中使用流会提升代码的质量&#xff0c;所以这里继续分享使用流 玩分组 import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; public class StreamGroupingExample { public static void main(String[] args) { List<…

深入解析C语言数组与指针:嵌套循环遍历数组

在这篇博客中&#xff0c;我们将深入探讨C语言中数组和指针的关系&#xff0c;通过一个简单的嵌套循环遍历数组的例子展示了它们的使用。 代码示例 #include <stdio.h>int main() {int arr1[] {1, 2, 3, 4, 5};int arr2[] {2, 3, 4, 5, 6};int arr3[] {3, 4, 5, 6, …

RocketMq查看消息轨迹

查看消息轨迹 1.修改配置文件 broker的启动文件加上消息轨迹相关配置 ##if msg tracing is open,the flag will be true traceTopicEnabletrue2.启动broker 使用broker-a.properties配置文件后台启动Broker。 nohup mqbroker -c /usr/local/rocketmq/rocketmq-all-4.9.1-bin…

智能五子棋1

*一、项目需求* 五子棋是一种简单的黑白棋&#xff0c;历史悠久&#xff0c;起源于中国&#xff0c;后传入日本&#xff0c;在日本被称为“连珠”&#xff0c;是一种老少皆宜的益智游戏。 人工智能五子棋系统的目标用户是一切想致力于研究人机对弈算法理论的相关研究者和一切…

使用C语言设计并实现一个成绩管理系统

使用C语言设计并实现一个成绩管理系统&#xff0c;该系统用于教师管理一门课程的成绩。 系统功能&#xff1a;成绩录入、打印成绩单、修改成绩、统计分数段、统计平均分、统计不及格学生&#xff0c;相关要求&#xff1a; 1&#xff09; 系统要有主菜单界面&#xff0c;让教师…

关于《企业数字化平台》

大家好&#xff0c;开始我们《企业数字化平台》系列栏目&#xff0c;首先做一个简短的自我介绍&#xff0c;Duster是本人现在的笔名&#xff0c;曾用笔名尘埃&#xff0c;写了本书《生活新视界》&#xff0c;如果有缘的话&#xff0c;希望您读到他&#xff0c;如果您能读懂&…