以beam search为例,详解transformers中generate方法(上)
- 1. generate的代码位置
- 2. GenerationMixin概览
- 3. generate签名
- 4. generate过程
- 4.1 读取并更新generation config
- 4.2 补充没有传入的参数
- 4.3 定义模型输入
- 4.4 定义模型的其他参数
- 4.5 对自回归模型准备input_ids
- 4.6 准备最大长度
- 4.7 确认生成模式
- 4.8 创建logits处理器
- 4.9 创建停止规则
- 4.10 进入相应的分支
- 4.11 创建logits warper
- 4.12 beam search
比起两年前,NLG任务已经得到了非常有效的发展,transformers模块的使用广泛程度也达到前所未有的程度。在模型推理预测时,一个核心的语句就是model.generate()
,本文就来详细介绍一下generate方法是如何运作的。在生成的过程中,包含了诸多生成策略,本文将以最常用的beam search为例,在本人能力范围内,尽可能详细地展开介绍。
考虑到篇幅可能会比较长,本文将分为上下两篇,上篇主要介绍generate方法的整体结构,下篇将对beam search部分的代码进行进一步的介绍。
随着各种LLM的出现,transformers中与generate相关的代码发生了一些变化,主要区别在于:
-
- generate的源码位置发生了改变;
-
- generate方法中,采用一个generation_config参数来管理生成相关的各种配置,并优化了逻辑,使得逻辑更加清晰。
1. generate的代码位置
在之前版本的transformers中(transformers~=4.9),generate方法位于transformers.generation_utils.py
,这个方法是GenerationMixin
类的一个方法。
而在新版本的transformers中(transformers~=4.28),generate方法被转移到了transformers.generation.utils.py
,仍然是GenerationMixin
的一个类方法。
而对于一个hf形式的预训练模型,都是继承了PreTrainedModel
类的,而顺着这个PreTrainedModel
类,可以看到更上一级的继承逻辑,GenerationMixin
就在其中:
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
这就是为什么通过AutoModel.from_pretrained()
实例化的一个model为什么可以直接调用generate
方法去做推理。
2. GenerationMixin概览
这一部分作为一个速查表写在这里,不建议直接阅读,而是在读第4节代码的过程中,返回来查看这部分内容。
GenerationMixin
类所有方法概览如下:
方法名 | 作用 | 在本文中出现的位置 |
---|---|---|
_validate_model_class | 检修该模型是否可以做生成,并抛出相应的异常 | 4.1 |
_validate_model_kwargs | 检查generation config中的参数是否与生成策略相匹配 | 4.1 |
_prepare_model_inputs | 为生成过程准备输入 | 4.3 |
_maybe_initialize_input_ids_for_generation | 当生成过程的inputs为空时,使用bos token做初始化 | 4.3 |
_prepare_attention_mask_for_generation | 为生成过程准备attention_mask | 4.4 |
_prepare_encoder_decoder_kwargs_for_generation | 为生成过程准备encoder相关的参数 | 4.4 |
_prepare_decoder_input_ids_for_generation | 为自回归模型额外处理input_ids | 4.5 |
_get_decoder_start_token_id | 获取decoder的开始位置的token id,这个id可能与bos不同 | 4.5 |
_get_logits_processor | 创建logits处理器 | 4.8 |
_get_stopping_criteria | 创建停止规则 | 4.9 |
_get_logits_warper | 创建logits warper | 4.11 |
_expand_inputs_for_generation | 根据num_beams对input_ids进行扩展 | 4.12 |
prepare_inputs_for_generation | 对模型的输入进行预处理 | 下篇3.1 |
adjust_logits_during_generation | 在生成过程中对计算的logits进行调整 | 下篇3.1 |
_update_model_kwargs_for_generation | 根据一个step的生成结果,更新生成参数 | 下篇5.6 |
_reorder_cache | 根据step更新的beam_idx,对缓存的past_k_v进行重排 | 下篇5.6 |
3. generate签名
在介绍流程之前先看一下generate方法的签名,在4.28版本中,其签名简化如下:
@torch.no_grad()def generate(self,inputs: Optional[torch.Tensor] = None,generation_config: Optional[GenerationConfig] = None,logits_processor: Optional[LogitsProcessorList] = None,stopping_criteria: Optional[StoppingCriteriaList] = None,prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,synced_gpus: Optional[bool] = None,streamer: Optional["BaseStreamer"] = None,**kwargs,) -> Union[GenerateOutput, torch.LongTensor]:
相比之前的版本,这样写的直接优点就是,与原版的超长签名相比,减少了传入的参数,将诸如top_k
, top_p
, num_beams
等参数全部都整合到了generation_config
中,使得函数看起来更加简化,并且该参数可以直接从模型路径下的generation_config.json文件中读取,一定程度上为用户提供了便捷。
相应的缺点就是很多参数没有显性地暴露出来,在查看注释和自定义生成配置的时候就不是很方便了。
需要在GenerationConfig
中查看可选的参数:
from transformers.generation.configuration_utils import GenerationConfighelp(GenerationConfig)
(GenerationConfig
中各类生成策略对应的参数各有不同,这里不展开介绍,在本文的下篇中,将对beam search策略下的参数进行简介。)
generate方法的参数含义与作用介绍如下:
参数名 | 类型 | 含义与作用 |
---|---|---|
inputs | torch.Tensor | tokenize之后的序列id,模型将基于这个序列计算logits并进行生成。如果为空,则默认为bos token对应的id |
generation_config | GenerationConfig | 各种生成策略对应的参数,如果为空,将会从模型路径下的generation_config.json文件中读取,或从model config获取 |
logits_processor | LogitsProcessorList | 对模型计算出的logits进行进一步处理,例如对“复读机现象”相应的概率进行惩罚,以避免模型生成结果不断重复 |
stopping_criteria | StoppingCriteriaList | 对生成过程做停止控制的工具,例如达到一定长度时强行停止,达到一定生成时间时停止等 |
prefix_allowed_tokens_fn | [int, torch.Tensor], List[int] | beam search过程中,每个step允许生成的token id范围 |
synced_gpus | bool | 采用DeepSpeed ZeRO时使用 |
streamer | BaseStreamer | stream generate时使用(也就是一个字一个字的往外蹦的效果) |
在这些输入中,logits_processor和stopping_criteria,将是用户手动干预生成过程的主要手段。
4. generate过程
在4.28版本的transformers代码中,generate过程的注释写的比较条理清晰,所以本文也沿用代码注释中的序号进行划分。
4.1 读取并更新generation config
这一部分的大概逻辑就是处理generation config为None的情况,以及检查是否存在与生成策略不一致的错误参数。
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` callself._validate_model_class()# priority: `generation_config` argument > `model.generation_config` (the default generation config)if generation_config is None:# legacy: users may modify the model configuration to control generation -- update the generation config# model attribute accordingly, if it was created from the model configif self.generation_config._from_model_config:new_generation_config = GenerationConfig.from_model_config(self.config)if new_generation_config != self.generation_config:warnings.warn("You have modified the pretrained model configuration to control generation. This is a"" deprecated strategy to control generation and will be removed soon, in a future version."" Please use a generation configuration file (see"" https://huggingface.co/docs/transformers/main_classes/text_generation)")self.generation_config = new_generation_configgeneration_config = self.generation_configgeneration_config = copy.deepcopy(generation_config)model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargsgeneration_config.validate()self._validate_model_kwargs(model_kwargs.copy())
其中_validate_model_class
和_validate_model_kwargs
两个方法都不是重点,这里不展开介绍。
4.2 补充没有传入的参数
这部分需要补充的参数包括logits_processor
, stopping_criteria
, 以及generation_config
中的pad_token_id
。前两项是设置为默认的空list;pad_token_id没有给定,而eos给定的话,用eos来做padding。
# 2. Set generation parameters if not already definedlogits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:if model_kwargs.get("attention_mask", None) is None:logger.warning("The attention mask and the pad token id were not set. As a consequence, you may observe ""unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.")eos_token_id = generation_config.eos_token_idif isinstance(eos_token_id, list):eos_token_id = eos_token_id[0]logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")generation_config.pad_token_id = eos_token_id
4.3 定义模型输入
# 3. Define model inputs# inputs_tensor has to be defined# model_input_name is defined if model-specific keyword input is passed# otherwise model_input_name is None# all model-specific keyword inputs are removed from `model_kwargs`inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, generation_config.bos_token_id, model_kwargs)batch_size = inputs_tensor.shape[0]
这里主要需要关注_prepare_model_inputs
这个方法,这个方法的核心,一句话概括就是模型输入的序列input_ids,必须非空,如果空的话,就用bos_token去初始化。其余部分都是用来应对个别模型的特殊情况:
def _prepare_model_inputs(self,inputs: Optional[torch.Tensor] = None,bos_token_id: Optional[int] = None,model_kwargs: Optional[Dict[str, torch.Tensor]] = None,) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:"""This function extracts the model-specific `inputs` for generation."""# 这一步似乎是起到一个校准的作用,防止某些encoder-decoder模型的主模型和encoder的输入名称不一致# 1. retrieve all kwargs that are non-None or non-model input related.# some encoder-decoder models have different names for model and encoderif (self.config.is_encoder_decoderand hasattr(self, "encoder")and self.encoder.main_input_name != self.main_input_name):input_name = self.encoder.main_input_nameelse:input_name = self.main_input_namemodel_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}# 确保inputs没有重复传入# 2. check whether model_input_name is passed as kwarg# if yes and `inputs` is None use kwarg inputsinputs_kwarg = model_kwargs.pop(input_name, None)if inputs_kwarg is not None and inputs is not None:raise ValueError(f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed."f"Make sure to either pass {inputs} or {input_name}=...")elif inputs_kwarg is not None:inputs = inputs_kwarg# 对于inputs_embeds这一输入参数:# 如果是decoder-only模型,需要把'input_ids'这一参数放在inputs_kwarg中传入# 如果是encoder-decoder模型,input_ids与inputs_embeds只能传入其一# 3. In the presence of `inputs_embeds` for text models:# - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model# doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with# input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)# - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and# pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.if input_name == "input_ids" and "inputs_embeds" in model_kwargs:if not self.config.is_encoder_decoder:has_inputs_embeds_forwarding = "inputs_embeds" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())if not has_inputs_embeds_forwarding:raise ValueError(f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} ""doesn't have its forwarding implemented. See the GPT2 implementation for an example ""(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!")# In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of# the attention mask) can rely on the actual model input.model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs=model_kwargs)else:if inputs is not None:raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"# 4. if `inputs` is still None, try to create `input_ids` from BOS token# 如果最后还是没有input_ids, 采用bos创建input_ids,可以简化理解为:# torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_idinputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)return inputs, input_name, model_kwargs
4.4 定义模型的其他参数
这一部分没有需要特别注意的地方,主要就是一些config设置,补齐模型的其他参数,如创建attention_mask,确保encoder-decoder模型能够返回’ModelOutput’类等等。
# 4. Define other model kwargsmodel_kwargs["output_attentions"] = generation_config.output_attentionsmodel_kwargs["output_hidden_states"] = generation_config.output_hidden_statesmodel_kwargs["use_cache"] = generation_config.use_cacheaccepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())requires_attention_mask = "encoder_outputs" not in model_kwargsif model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id)# decoder-only models should use left-padding for generationif not self.config.is_encoder_decoder:if (generation_config.pad_token_id is not Noneand torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0):logger.warning("A decoder-only architecture is being used, but right-padding was detected! For correct ""generation results, please set `padding_side='left'` when initializing the tokenizer.")if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:# if model is encoder decoder encoder_outputs are created# and added to `model_kwargs`model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs_tensor, model_kwargs, model_input_name)
4.5 对自回归模型准备input_ids
这一步与4.3的主要区别在于,针对AR模型额外进行了处理。如果是encoder-decoder模型,则直接采用4.3创建的input_tensor作为input_ids。
# 5. Prepare `input_ids` which will be used for auto-regressive generationif self.config.is_encoder_decoder:# 这里主要是针对decoder的开始位置id与bos id不同的情况# 在这种情况下,decoder-only模型应当以配置中规定的decoder start id开始进行生成# 此处可简单理解为:torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_idinput_ids = self._prepare_decoder_input_ids_for_generation(batch_size,decoder_start_token_id=generation_config.decoder_start_token_id,bos_token_id=generation_config.bos_token_id,model_kwargs=model_kwargs,device=inputs_tensor.device,)# conditional generation for multi-modal models.if "input_ids" in model_kwargs and model_input_name == "pixel_values":input_ids = torch.cat([input_ids, model_kwargs.pop("input_ids")], dim=-1)else:input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
4.6 准备最大长度
这一部分就是根据config中的相关配置,判断input_id的长度有没有超长。
# 6. Prepare `max_length` depending on other stopping criteria.input_ids_seq_length = input_ids.shape[-1]has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not Noneif has_default_max_length and generation_config.max_new_tokens is None:warnings.warn(f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. ""This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"" recommend using `max_new_tokens` to control the maximum length of the generation.",UserWarning,)elif generation_config.max_new_tokens is not None:generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_lengthif not has_default_max_length:logger.warn(f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. ""Please refer to the documentation for more information. ""(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",UserWarning,)if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:raise ValueError(f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"f" the maximum length ({generation_config.max_length})")if input_ids_seq_length >= generation_config.max_length:input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"logger.warning(f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"" increasing `max_new_tokens`.")
4.7 确认生成模式
这里直接选择beam search分支了,其他模式不做展开介绍,下同。
beam search分为两种,基础款的beam_gen_mode
,以及进阶款的beam_sample_gen_mode
,其中,前者对应后续的生成方法为beam_search
,后者对应后续的生成方法为beam_sample
。
二者的区别主要在于,进阶款的beam_sample_gen_mode
可以设置temperature、top_k、top_p等参数进一步控制生成,设置的方法在4.11节:logits warper中介绍。对于基础款的beam_gen_mode
,就没有创建logits warper这一环节。
# 7. determine generation modeis_beam_sample_gen_mode = ((generation_config.num_beams > 1)and (generation_config.num_beam_groups == 1)and generation_config.do_sample is Trueand not is_constraint_gen_modeand not is_contrastive_search_gen_mode)
4.8 创建logits处理器
# 8. prepare distribution pre_processing samplerslogits_processor = self._get_logits_processor(generation_config=generation_config,input_ids_seq_length=input_ids_seq_length,encoder_input_ids=inputs_tensor,prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,logits_processor=logits_processor,)
这一个环节比较重要,因为涉及到了logits processor。这些processor是在生成的过程中,在每一个step,对计算出来的得分进行修正处理的。在transformers
中,预设了若干processor,用户也可以定义自己的processor(需要继承抽象类transformers.generation.logit_process.LogitsProcessor),自己设计逻辑,来对生成的过程进行人工干预。
在beam search中,logits process的使用方法是:
# 在def beam_sample中使用
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
其中,input_ids是当前step传给模型的序列token id对应Tensor(batch_size, sequence_length),next_token_scores是经过模型计算之后的分数(即在vocab上的概率分布)取log_softmax。
在这里简单介绍一下在transformers
中预设的processor。限于篇幅,不贴出全部源码,只对其功能进行总结。
processor | 作用 | 参考连接 |
---|---|---|
MinLengthLogitsProcessor | 通过将EOS的概率强行设置为0,保证生成结果的长度大于等于一个最小值 | / |
MinNewTokensLengthLogitsProcessor | 与上一个类似,但是prompt的部分不计入生成长度 | / |
RepetitionPenaltyLogitsProcessor | 防止“复读机”现象,给重复出现token添加惩罚,由预训练模型CTRL提出 | arxiv |
EncoderRepetitionPenaltyLogitsProcessor | 与上一个区别在于,生成的结果不能与encoder输入input id重复,而非与当前给定的全部input id | / |
NoRepeatNGramLogitsProcessor | 防止生成的文本中出现重复的n-gram(n个连续的词或字符),区别在于限制连续n个 | github |
EncoderNoRepeatNGramLogitsProcessor | n-gram可以在encoder的input ids中重复,不可以在decoder重复 | github |
NoBadWordsLogitsProcessor | 确保某些词永远不会被生成 | / |
PrefixConstrainedLogitsProcessor | 给定一个prefix_allow_func来限制符合哪些条件的token可以被生成 | arxiv |
HammingDiversityLogitsProcessor | 以Hamming距离为标准,确保生成的各个beam之前的区别足够大 | arxiv |
ForcedBOSTokenLogitsProcessor | 确保生成的第一个token是某个特定的token | / |
ForcedEOSTokenLogitsProcessor | 达到最大长度时,确保以某个特定的token作为结束 | / |
InfNanRemoveLogitsProcessor | 将计算出的得分中,nan替换为0,inf替换为可计算的最大值 | / |
SuppressTokensAtBeginLogitsProcessor | 在达到某个长度之后,将不再生成某些特定的词 | / |
SuppressTokensLogitsProcessor | 将某些特定词的概率设置为-inf,不生成这些词 | / |
ForceTokensLogitsProcessor | 建立一个映射表,把某个token强行映射成另一个token | / |
WhisperTimeStampLogitsProcessor | 强制模型生成时间戳(时间戳是一个特殊token,例如对话中,query=今天是周几?,answer=今天是[timestamp],这个[timestamp]后期会替换成对应的时间) | / |
4.9 创建停止规则
stopping_criteria与logits_processor是用户对生成过程进行干预的主要手段,相比logits_processor强行改变概率空间,stopping_criteria则是直接设定了终止生成的策略,理解起来也会相对容易一些。
# 9. prepare stopping criteriastopping_criteria = self._get_stopping_criteria(generation_config=generation_config, stopping_criteria=stopping_criteria)
预设的criteria总结如下:
criteria | 作用 |
---|---|
MaxLengthCriteria | 生成的序列达到设置的最大长度时,停止生成 |
MaxNewTokensCriteria | 生成的序列中,除去prompt的部分达到设置的最大长度时,停止生成 |
MaxTimeCriteria | 生成的耗时超过一定时间限制时,停止生成 |
如果是自定义criteria,应当继承抽象类transformers.generation.stopping_criteria.StoppingCriteria
。
4.10 进入相应的分支
这里直接选择进入beam search的分支。如前文所述,如果要控制temperature等超参数,则应该进入is_beam_sample_gen_mode这个分支。
4.11 创建logits warper
# 11. prepare logits warperlogits_warper = self._get_logits_warper(generation_config)
logits warper的使用方法与logits processor一样,都是用来修改概率的输出。关于他们的区别,暂时没有找到很好的解释,可以理解为warper控制着temperature、topk等与生成策略相关的参数。并且是在logits processor处理之后再进行处理的。
普通的beam search不会涉及这一部分,只有选择sample模式的beam search时,才会使用到logits warper。
需要记住的是,它的输入与processor一样,都是当前的序列(token_ids)与之前计算出的得分(scores),返回的结果是处理之后的得分,形状是(batch_size, config.vocab_size)
。
预设的warper包括:
warper | 作用(仅供参考) | 参考链接 |
---|---|---|
TemperatureLogitsWarper | 对score整体除以temperature做缩放 | / |
TopPLogitsWarper | 概率小于topp的得分置为0 | / |
TopKLogitsWarper | 只取topk的概率对应的词汇,其余的概率置为-inf | / |
TypicalLogitsWarper | typical decoding | arxiv |
EpsilonLogitsWarper | 将概率小于epsilon的token移除 | arxiv |
EtaLogitsWarper | eta-sampling | arxiv |
LogitNormalization | 在beam search进行的过程中做layernorm | / |
4.12 beam search
这一部分是beam search的核心流程,限于篇幅,其具体的执行生成过程将在本文的下篇中进行详细的介绍。
在这一部分中,首先创建了用于打分的BeamSearchScorer(具体作用将在下篇中进行介绍),然后根据num_beams对input_ids进行了扩展,最后执行beam search的核心方法beam_search
,或beam sample对应的beam_sample
方法。
beam_scorer = BeamSearchScorer(batch_size=batch_size,num_beams=generation_config.num_beams,device=inputs_tensor.device,length_penalty=generation_config.length_penalty,do_early_stopping=generation_config.early_stopping,num_beam_hyps_to_keep=generation_config.num_return_sequences,max_length=generation_config.max_length,)# 12. interleave input_ids with `num_beams` additional sequences per batchinput_ids, model_kwargs = self._expand_inputs_for_generation(input_ids=input_ids,expand_size=generation_config.num_beams,is_encoder_decoder=self.config.is_encoder_decoder,**model_kwargs,)# 13. run beam searchreturn self.beam_search(input_ids,beam_scorer,logits_processor=logits_processor,stopping_criteria=stopping_criteria,pad_token_id=generation_config.pad_token_id,eos_token_id=generation_config.eos_token_id,output_scores=generation_config.output_scores,return_dict_in_generate=generation_config.return_dict_in_generate,synced_gpus=synced_gpus,**model_kwargs,)
在本文的下篇中,将介绍beam search的基本原理,transformers模块对于beam search的实现方法中,主要涉及的几个工具组件,beam search的生成与更新过程,以及beam sample对beam search的改进实现,感兴趣的朋友可以继续阅读。