如何在我们的模型中使用Beam search

        在上一篇文章中我们具体探讨了Beam search的思想以及Beam search的大致工作流程。根据对Beam search的大致流程我们已经清楚了,在这我们来具体实现一下Beam search并应用在我们的seq2seq任务中。

1. python中的堆(heapq)

        堆是一种特殊的树形数据结构。堆分为大根堆和小根堆两种类型,其中:

  • 小根堆: 父节点的值小于或等于其子节点的值。
  • 大根堆: 父节点的值大于或等于其子节点的值。

堆的应用场景主要是以下两个:

        1. 堆排序,完成升序或降序排列;

        2. 优先级队列,其中元素按照优先级顺序排列,优先级越低越先出队。在每次插入元素时,堆会自动调整以确保最高(或最低)优先级的元素位于堆的根部。

2. Beam search的实现

2.1 Beam search的流程

我们通过构建堆来实现Beam search,主要流程:

        1. 构造 <SOS> 做为第一次输入信息保存在堆中;

        2. 取出堆中的数据,开始forward操作,获取当前时间步的输出output、hidden;

        3. 从output中选择top k个数据输出,做为下一个时间步的输入(其中Beam width = k);

        4. 把下一个时间步需要的输入数据保存在一个新的堆中;

        5. 获取新的堆中概率最大的数据,判断数据是否为 <EOS> 或者序列是否达到输出最大长度,如果符合则停止输出,若不符合则继续循环2~5。

2.2 构建beam

class Beam:def __init__(self):self.heap = list()self.beam_width = 3def add(self, probability, complete, seq, decoder_input, decoder_hidden):"""入队:param probability: 概率乘积:param complete: 句子是否输出完成:param seq: 句子 包含token的list:param decoder_input: 下一个时间步进行解码的输入:param decoder_hidden: 下一个时间步进行解码的hidden:return: """heapq.heappush(self.heap, [probability, complete, seq, decoder_input, decoder_hidden])# 如果数据的个数大于beam_width则弹出if len(self.heap) > self.beam_width:# heappop会根据优先级从小到大弹出,所以优先级最大的beam_widt会被保存在堆中# 当两个元素的probability的优先级相同时,则根据complete优先级弹出heapq.heappop(self.heap)def __iter__(self):return iter(self.heap)

现在我们完成了保存数据的数据结构。

使用Beam search进行评估

在decoder中我们先定义一个函数处理序列

    def _prepar_seq(self, seq):"""去除seq中的<SOS>和<EOS>的token"""if seq[0].item() == ws.SOS:seq = seq[1:]if seq[-1].item() == ws.EOS:seq = seq[:-1]seq = [i.item() for i in seq]return seq

接下来在decoder中使用beam search

    def beam_search(self, encoder_outputs, encoder_hidden):"""使用堆来完成beam search:param encoder_outputs: [batch_size, seq_len, encoder_hidden_size]:param encoder_hidden: [1, batch_size, encoder_hidden_size]"""batch_size = encoder_hidden.size(1)# 1. 构造第一次需要的输入数据,保存在堆中decoder_input = torch.LongTensor([[ws.SOS]*batch_size]).to(device)  # [batch_size, 1]# 要输入的hiddendecoder_hidden = encoder_hiddenprev_beam = Beam()prev_beam.add(1, False, [decoder_input], decoder_input, decoder_hidden)while True:cur_beam = Beam()# 2. 取出堆中的数据,进行forward_step操作,获得当前时间步的output, hiddenfor _probability, _complete, _seq, _decoder_input, _decoder_hidden in prev_beam:# 判断前一次的 _complete是否为True,如果是则不需要forward# 有可能为True,但是概率并不是最大if _complete == True:cur_beam.add(_probability, _complete, _seq, _decoder_input, _decoder_hidden)else:# 需要进行forward操作decoder_output_t, decoder_hidden = self.forward_step(_decoder_input, _decoder_hidden, encoder_outputs)# 3. 从output中选择最大的beam width个输出,作为下一次的inputvalue, index = torch.topk(decoder_output_t, config.beam_width)  # [batch_size, beam_width]for m, n in zip(value[0], index[0]):decoder_input = torch.LongTensor([[n]]).to(config.device)seq = _seq + [n]  # 更新句子序列probability = _probability * m  # 更新概率乘积if n.item() == config.chatbot_ws_by_word_target.SOS:complete = Trueelse:complete = False# 4. 把下个时间步需要的输入等数据保存在一个新的堆中cur_beam.add(probability, complete, seq, decoder_input, decoder_hidden)# 5. 获取新的堆中的优先级最高(概率最大)的数据,判断数据是否以EOS结尾或者是达到最大长度# 若是则停止迭代# 若不是则继续best_prob, best_complete, best_seq, _, _ = max(cur_beam)if best_complete == True or len(best_seq) - 1 == config.chatbot_target_max_seq_len + 1:return self._perpar_seq(best_seq)else:prev_beam = cur_beam

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/682261.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

速盾:2024年cdn在5g时代重要吗

在2024年&#xff0c;随着5G技术的普及与应用&#xff0c;内容分发网络&#xff08;Content Delivery Network&#xff0c;CDN&#xff09;在数字化时代中的重要性将进一步巩固和扩大。CDN是一种用于快速、高效地分发网络内容的基础设施&#xff0c;它通过将内容部署在全球各地…

幻兽帕鲁Palworld服务器设置参数(汉化)

创建幻兽帕鲁服务器配置参数说明&#xff0c;Palworld服务器配置参数与解释&#xff0c;阿腾云atengyun.com分享&#xff1a; 自建幻兽帕鲁服务器教程&#xff1a; 阿里云教程 https://t.aliyun.com/U/bLynLC腾讯云教程 https://curl.qcloud.com/oRMoSucP 幻兽帕鲁服务器 幻…

with 用法

with 已弃用: 不再推荐使用该特性。虽然一些浏览器仍然支持它&#xff0c;但也许已从相关的 web 标准中移除&#xff0c;也许正准备移除或出于兼容性而保留。请尽量不要使用该特性&#xff0c;并更新现有的代码&#xff1b;参见本页面底部的兼容性表格以指导你作出决定。请注意…

寒假学习记录16:Express框架(Node)

后续会补充 1.引入express 1.先下载express框架 创建一个package.json格式的文件&#xff0c;里面写入 {"dependencies": {"express": "~4.16.1" //express版本号} } 然后打开终端输入 npm i 2.引入express模块 const express require(&quo…

如何使用idea连通服务器上的Redis(详细版本)

这里我使用的是阿里云的服务器 打开阿里云的安全组&#xff0c;设置端口为6379 在redis.conf文件中&#xff0c;注释bind 127.0.0.1 将protected-mode设置为no&#xff0c;即关闭保护模式 更改服务器中的防火墙&#xff0c;放行6379端口 # 放行端口 firewall-cmd --zo…

【python】元组

是python中内置的不可变序列 在python中使用()定义元组&#xff0c;元素与元素之间使用英文的逗号分隔 元组中只有一个元素的时候&#xff0c;逗号也不能省略 y(10,) print(y,type(y))元组的创建方式 使用()直接创建元组 元组名(elem1,elem2,...,elemN)使用内置函数tuple()创…

Nacos 的配置管理和配置热更新

一、配置管理的必要性 1. 存在问题 微服务重复配置过多维护成本高&#xff1a;将各个微服务的配置都写到配置管理服务中&#xff0c;单个微服务不去编写配置&#xff0c;而是到配置管理服务中读取配置&#xff0c;实现配置共享&#xff0c;便于修改和维护 业务配置经常变动&a…

【AI视野·今日CV 计算机视觉论文速览 第299期】Mon, 29 Jan 2024

AI视野今日CS.CV 计算机视觉论文速览 Mon, 29 Jan 2024 Totally 55 papers &#x1f449;上期速览✈更多精彩请移步主页 Daily Computer Vision Papers Annotated Hands for Generative Models Authors Yue Yang, Atith N Gandhi, Greg TurkGAN 和扩散模型等生成模型已经展示了…

C++:priority_queue模拟实现

C&#xff1a;priority_queue模拟实现 什么是priority_queue模拟实现向上调整算法向下调整算法插入与删除 仿函数 什么是priority_queue priority_queue称为优先级队列。优先级队列是一种特殊的队列&#xff0c;其中每个元素都有一个相关的优先级。元素的优先级决定了它们在队…

【FTP讲解】

FTP讲解 1. 介绍2. 工作原理3. 传输模式4. 安全5. 设置FTP服务器6. FTP命令 1. 介绍 FTP&#xff08;File Transfer Protocol&#xff09;是“文件传输协议”的英文缩写&#xff0c;它是用于在网络上进行数据传输的一种协议。FTP是因特网上使用最广泛的协议之一&#xff0c;它…

Python数学建模之回归分析

1.基本概念及应用场景 回归分析是一种预测性的建模技术&#xff0c;数学建模中常用回归分析技术寻找存在相关关系的变量间的数学表达式&#xff0c;并进行统计推断。例如&#xff0c;司机的鲁莽驾驶与交通事故的数量之间的关系就可以用回归分析研究。回归分析根据变量的…

论文阅读:GamutMLP A Lightweight MLP for Color Loss Recovery

这篇文章是关于色彩恢复的一项工作&#xff0c;发表在 CVPR2023&#xff0c;其中之一的作者是 Michael S. Brown&#xff0c;这个老师是加拿大 York 大学的&#xff0c;也是 ISP 领域的大牛&#xff0c;现在好像也在三星研究院担任兼职&#xff0c;这个老师做了很多这种类似的工…

系统架构25 - 软件架构设计(4)

软件架构复用 软件产品线定义分类原因复用对象及形式基本过程 软件产品线 软件产品线是指一组软件密集型系统&#xff0c;它们共享一个公共的、可管理的特性集&#xff0c;满足某个特定市场或任务的具体需要&#xff0c;是以规定的方式用公共的核心资产集成开发出来的。即围绕…

九、OpenCV自带colormap

项目功能实现&#xff1a;每隔1500ms轮流自动播放不同风格图像显示&#xff0c;按下Esc键退出 按照之前的博文结构来&#xff0c;这里就不在赘述了 一、头文件 colormap.h #pragma once #include<opencv2/opencv.hpp> using namespace cv;class ColorMap { public:vo…

Mybatis开发辅助神器p6spy

Mybatis什么都好&#xff0c;就是不能打印完整的SQL语句&#xff0c;虽然可以根据数据来判断一二&#xff0c;但始终不能直观的看到实际语句。这对我们想用完整语句去数据库里执行&#xff0c;带来了不便。 怎么说呢不管用其他什么方式来实现完整语句&#xff0c;都始终不是Myb…

C++ 11新特性之并发

概述 随着计算机硬件的发展&#xff0c;多核处理器已经成为主流&#xff0c;对程序并发执行能力的需求日益增长。C 11标准引入了一套全面且强大的并发编程支持库&#xff0c;为开发者提供了一个安全、高效地利用多核CPU资源进行并行计算的新框架&#xff0c;极大地简化了多线程…

C#面:Static Nested Class 和 Inner Class 有什么不同

这是两种不同的类嵌套方式。 Static Nested Class &#xff1a; 是一个静态嵌套类&#xff0c;它是在外部类中定义的一个静态类。它可以访问外部类的静态成员和方法&#xff0c;但不能直接访问外部类的非静态成员和方法。静态嵌套类可以独立于外部类实例化&#xff0c;即可以…

《Linux 简易速速上手小册》第6章: 磁盘管理与文件系统(2024 最新版)

文章目录 6.1 磁盘分区与格式化6.1.1 重点基础知识6.1.2 重点案例&#xff1a;为新硬盘配置分区和文件系统6.1.3 拓展案例 1&#xff1a;创建交换分区6.1.4 拓展案例 2&#xff1a;使用 LVM 管理分区 6.2 挂载与卸载文件系统6.2.1 重点基础知识6.2.2 重点案例&#xff1a;挂载新…

近十年金融资产收益率

通过掌握大类资产的历年收益率数据&#xff0c;做基于数据的投资&#xff0c;提高胜率和收益率。 下面是同花顺梳理的2014至2023大类金融资产收益率&#xff1a; 基于这个数据&#xff0c;我们再统计两项指标&#xff1a; 1. 每种资产在近十年的投资胜率&#xff08;收益率为…

牛客2024年情人节比赛 娱乐报告

前言 挺欢乐的比赛&#xff0c;有趣 欢迎关注 珂朵莉 牛客周赛专栏 珂朵莉 牛客小白月赛专栏 A. 第二杯半价 思路: 模拟 分奇偶进行讨论 t int(input())for _ in range(t):n, x list(map(int, input().split()))if n % 2 1:print (n//2 * (x (x 1) // 2) x)else:pr…