比起两年前,NLG任务已经得到了非常有效的发展,transformers模块的使用广泛程度也达到前所未有的程度。在模型推理预测时,一个核心的语句就是model.generate(),本文就来详细介绍一下generate方法是如何运作的。在生成的过程中,包含了诸多生成策略,本文将以最常用的beam search为例,尽可能详细地展开介绍。
随着各种LLM的出现,transformers中与generate相关的代码发生了一些变化,主要区别在于:
generate的源码位置发生了改变;
generate方法中,采用一个generation_config参数来管理生成相关的各种配置,并优化了逻辑,使得逻辑更加清晰。
generate的代码位置
在之前版本的transformers中(transformers~=4.9),generate方法位于transformers.generation_utils.py,这个方法是GenerationMixin类的一个方法。
而在新版本的transformers中(transformers~=4.42),generate方法被转移到了transformers.generation.utils.py,仍然是GenerationMixin的一个类方法。
而对于一个hf形式的预训练模型,都是继承了PreTrainedModel类的,而顺着这个PreTrainedModel类,可以看到更上一级的继承逻辑,GenerationMixin就在其中:
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
这就是为什么通过AutoModel.from_pretrained()实例化的一个model为什么可以直接调用generate方法去做推理。
GenerationMixin概览
这一部分作为一个速查表写在这里,不建议直接阅读,而是在读后面代码的过程中,返回来查看这部分内容。
GenerationMixin类所有方法概览如下:
方法名 | 作用 |
|---|---|
_validate_model_class | 检修该模型是否可以做生成,并抛出相应的异常 |
_validate_model_kwargs | 检查generation config中的参数是否与生成策略相匹配 |
_prepare_model_inputs | 为生成过程准备输入 |
_maybe_initialize_input_ids_for_generation | 当生成过程的inputs为空时,使用bos token做初始化 |
_prepare_attention_mask_for_generation | 为生成过程准备attention_mask |
_prepare_encoder_decoder_kwargs_for_generation | 为生成过程准备encoder相关的参数 |
_prepare_decoder_input_ids_for_generation | 为自回归模型额外处理input_ids |
_get_decoder_start_token_id | 获取decoder的开始位置的token id,这个id可能与bos不同 |
_get_logits_processor | 创建logits处理器 |
_get_stopping_criteria | 创建停止规则 |
_get_logits_warper | 创建logits warper |
_expand_inputs_for_generation | 根据num_beams对input_ids进行扩展 |
prepare_inputs_for_generation | 对模型的输入进行预处理 |
adjust_logits_during_generation | 在生成过程中对计算的logits进行调整 |
_update_model_kwargs_for_generation | 根据一个step的生成结果,更新生成参数 |
_reorder_cache | 根据step更新的beam_idx,对缓存的past_k_v进行重排 |
generate签名
在介绍流程之前先看一下generate方法的签名,在4.42.4版本中,其签名简化如下:
@torch.no_grad()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
相比之前的版本,这样写的直接优点就是,与原版的超长签名相比,减少了传入的参数,将诸如top_k, top_p, num_beams等参数全部都整合到了generation_config中,使得函数看起来更加简化,并且该参数可以直接从模型路径下的generation_config.json文件中读取,一定程度上为用户提供了便捷。
相应的缺点就是很多参数没有显性地暴露出来,在查看注释和自定义生成配置的时候就不是很方便了。
需要在GenerationConfig中查看可选的参数:
from transformers.generation.configuration_utils import GenerationConfig
help(GenerationConfig)
generate方法的参数含义与作用介绍如下:
参数名 | 类型 | 含义与作用 |
|---|---|---|
inputs | torch.Tensor | tokenize之后的序列id,模型将基于这个序列计算logits并进行生成。如果为空,则默认为bos token对应的id |
generation_config | GenerationConfig | 各种生成策略对应的参数,如果为空,将会从模型路径下的generation_config.json文件中读取,或从model config获取 |
logits_processor | LogitsProcessorList | 对模型计算出的logits进行进一步处理,例如对“复读机现象”相应的概率进行惩罚,以避免模型生成结果不断重复 |
stopping_criteria | StoppingCriteriaList | 对生成过程做停止控制的工具,例如达到一定长度时强行停止,达到一定生成时间时停止等 |
prefix_allowed_tokens_fn | [int, torch.Tensor], List[int] | beam search过程中,每个step允许生成的token id范围 |
synced_gpus | bool | 采用DeepSpeed ZeRO时使用 |
assistant_model | PreTrainedModel | 可用于加速生成的助理模型。助理模型必须具有完全相同的tokenizer。助理模型应该更小。 |
streamer | BaseStreamer | stream generate时使用(也就是一个字一个字的往外蹦的效果) |
negative_prompt_ids | torch.LongTensor | 某些处理器(如 CFG)所需的 negative prompt |
negative_prompt_attention_mask | torch.LongTensor | 负面提示 IDs 的注意力掩码 |
在这些输入中,logits_processor和stopping_criteria,将是用户手动干预生成过程的主要手段。
generate过程
在4.42版本的transformers代码中,generate过程的注释写的比较条理清晰,所以本文也沿用代码注释中的序号进行划分。
读取并更新generation config
这一部分的大概逻辑就是处理generation config为None的情况,以及检查是否存在与生成策略不一致的错误参数。
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
self._validate_model_kwargs(model_kwargs.copy())
self._validate_assistant(assistant_model)
其中_validate_model_class,_validate_model_kwargs, _validate_assistant方法都不是重点,这里不展开介绍。
补充没有传入的参数
这部分需要补充的参数包括logits_processor, stopping_criteria, 以及generation_config中的pad_token_id。前两项是设置为默认的空list;查看self.forward以及model_args有没有attention_mask的传入
# 2. Set generation parameters if not already defined
if synced_gpus is None:
if is_deepspeed_zero3_enabled() and dist.get_world_size() > 1:
synced_gpus = True
else:
synced_gpus = False
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
定义模型输入
# 3. Define model inputs
inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
inputs, generation_config.bos_token_id, model_kwargs
)
batch_size = inputs_tensor.shape[0]
device = inputs_tensor.device
self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
# decoder-only models must use left-padding for batched generation.
if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
# If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
# Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
if (
generation_config.pad_token_id is not None
and batch_size > 1
and len(inputs_tensor.shape) == 2
and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id) > 0
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
"generation results, please set `padding_side='left'` when initializing the tokenizer."
)
这里主要需要关注_prepare_model_inputs这个方法,这个方法的核心,一句话概括就是模型输入的序列input_ids,必须非空,如果空的话,就用bos_token去初始化。其余部分都是用来应对个别模型的特殊情况。并检查decoder-only的模型输入,检查input_ids中任何序列中的最后一个 id 是否为“pad_token_id”
def _prepare_model_inputs(
self,
inputs: Optional[torch.Tensor] = None,
bos_token_id: Optional[torch.Tensor] = None,
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
"""
This function extracts the model-specific `inputs` for generation.
"""
# 这一步似乎是起到一个校准的作用,防止某些encoder-decoder模型的主模型和encoder的输入名称不一致
# 1. retrieve all kwargs that are non-None or non-model input related.
# some encoder-decoder models have different names for model and encoder
if (
self.config.is_encoder_decoder
and hasattr(self, "encoder")
and self.encoder.main_input_name != self.main_input_name
):
input_name = self.encoder.main_input_name
else:
input_name = self.main_input_name
model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
# 确保inputs没有重复传入
# 2. check whether model_input_name is passed as kwarg
# if yes and `inputs` is None use kwarg inputs
inputs_kwarg = model_kwargs.pop(input_name, None)
if inputs_kwarg is not None and inputs is not None:
raise ValueError(
f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. "
f"Make sure to either pass {inputs} or {input_name}=..."
)
elif inputs_kwarg is not None:
inputs = inputs_kwarg
# 如果 input_name 是 input_ids 且 model_kwargs 中存在 inputs_embeds这一输入参数:
# 如果是decoder-only模型,如果支持 inputs_embeds,则将 input_ids 转移到 model_kwargs 中,
# 这样后续的一些自动化步骤(如创建 attention_mask)可以依赖实际的模型输入。需要把'input_ids'这一参数放在inputs_kwarg中传入
# 如果是encoder-decoder模型,input_ids与inputs_embeds只能传入其一
# 3. In the presence of `inputs_embeds` for text models:
# - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
# doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
# input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
# - encoder-decoder models should complain if the user attempts to pass `inputs_embeds` and `input_ids`, and
# pull the former to inputs. It will be used in place of `input_ids` to get the encoder hidden states.
if input_name == "input_ids" and "inputs_embeds" in model_kwargs:
if not self.config.is_encoder_decoder:
has_inputs_embeds_forwarding = "inputs_embeds" in set(
inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
)
if not has_inputs_embeds_forwarding:
raise ValueError(
f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
"doesn't have its forwarding implemented. See the GPT2 implementation for an example "
"(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
)
# In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
# the attention mask) can rely on the actual model input.
model_kwargs["input_ids"] = self._maybe_initialize_input_ids_for_generation(
inputs, bos_token_id, model_kwargs=model_kwargs
)
else:
if inputs is not None:
raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
# 如果最后还是没有input_ids, 采用bos创建input_ids,可以简化理解为:
# torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
return inputs, input_name, model_kwargs
这里稍微解释下为什么会存在这样的差异检查:
decoder-only模型与encoder-decoder模型的不同
decoder-only模型(如 GPT 系列):
- 这种模型只包含解码器部分,通常用于纯文本生成任务。
- 它们通常需要处理来自上一步生成的文本,因此主要使用
input_ids来进行自回归(autoregressive)生成。- 对于这类模型,如果用户直接提供
inputs_embeds而不是input_ids,这可能会改变生成的方式,所以需要确认模型是否实现了对应的处理逻辑。encoder-decoder模型(如 BART、T5):
- 这种模型包含一个编码器和一个解码器,通常用于翻译、摘要等任务。
- 编码器接收输入并生成隐藏状态(hidden states),然后解码器使用这些隐藏状态来生成输出。
- 如果用户同时提供
inputs_embeds和input_ids,这会导致模型无法确定该使用哪种输入,因此需要报错提示用户选择其一。
定义模型的其他参数
这一部分没有需要特别注意的地方,主要就是一些config设置,补齐模型的其他参数,如创建attention_mask,确保encoder-decoder模型能够返回’ModelOutput’类等等。
# 4. Define other model kwargs
# decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
# generating the first new token or not, and we only want to use the embeddings for the first new token)
if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
model_kwargs["use_cache"] = True
else:
model_kwargs["use_cache"] = generation_config.use_cache
if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
)
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
# if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
inputs_tensor, model_kwargs, model_input_name, generation_config
)
对自回归模型准备input_ids
这一步与4.3的主要区别在于,针对AR模型额外进行了处理。如果是encoder-decoder模型,确保模型的解码器输入是正确格式的、如果是decoder-only的模型则直接采用4.3创建的input_tensor作为input_ids。
# 5. Prepare `input_ids` which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
batch_size=batch_size,
model_input_name=model_input_name,
model_kwargs=model_kwargs,
decoder_start_token_id=generation_config.decoder_start_token_id,
device=inputs_tensor.device,
)
else:
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
if generation_config.token_healing:
input_ids = self.heal_tokens(input_ids, tokenizer)
if streamer is not None:
streamer.put(input_ids.cpu())
另外, 这里还有个函数heal_tokens, 主要功能为模型生成新的序列, 扩展或者替换序列中的尾部token,以便更好地匹配可能的扩展或替代
- 遍历每个批次中的尾部 token ID,并尝试找到这个 token 的可能扩展(即可能的替代 token)
- 如果找到了扩展 token,为每个扩展 token 应用一个生成偏置值 (
sequence_bias),使得这些 token 在生成时更有可能被选择。 - 在生成时,会轻微地偏向原始 token 以避免过于激进的修复(例如 'http' -> 'https' 这样的替代可能是不期望的)。
- 使用
self.generate方法重新生成新的序列,替换掉原始序列中的尾部 token。 - 最终返回处理后的
input_ids,这些序列可能在尾部 token 被替换或扩展后得到了改进。
假设你有一个输入序列为input_ids = torch.tensor([[203, 204, 205]]),这些 IDs 对应的 tokens 是'hello world'。如果205是一个可以扩展的 token(例如,它可能表示'world'或者'worldwide'),那么这个函数会尝试查找和替换它。如果找到了更好的扩展(例如'worldwide'),函数将会生成一个新序列替换掉原来的序列。
最终生成的新序列可能是torch.tensor([[203, 204, 305]]),其中305是新的 token ID,对应'worldwide'。这样,原始输入序列就被“修复”成了更合适的输出。
准备最大长度
这一部分就是根据config中的相关配置,判断input_id的长度有没有超长, 被封装进了_prepare_generated_length和_validate_generated_length函数中, 后面都是cache相关的参数设置, 暂时不去考虑。
# 6. Prepare `max_length` depending on other stopping criteria.
input_ids_length = input_ids.shape[-1]
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
generation_config = self._prepare_generated_length(
generation_config=generation_config,
has_default_max_length=has_default_max_length,
has_default_min_length=has_default_min_length,
model_input_name=model_input_name,
inputs_tensor=inputs_tensor,
input_ids_length=input_ids_length,
)
use_dynamic_cache_by_default = False
if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
raise ValueError(
"Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
"Cache object) is unsupported. Please use only one of the two."
)
elif generation_config.cache_implementation is not None:
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
raise ValueError(
"This model does not support `cache_implementation='static'`. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981"
)
model_kwargs["past_key_values"] = self._get_cache(
generation_config.cache_implementation,
getattr(generation_config, "num_beams", 1) * batch_size,
generation_config.max_length,
)
elif generation_config.cache_implementation == "quantized":
if not self._supports_quantized_cache:
raise ValueError(
"This model does not support the quantized cache. If you want your model to support quantized "
"cache, please open an issue."
)
cache_config = (
generation_config.cache_config
if generation_config.cache_config is not None
else QuantizedCacheConfig()
)
cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
if cache_config.backend == "quanto" and not is_quanto_available():
raise ImportError(
"You need to install `quanto` in order to use KV cache quantization with quanto backend. "
"Please install it via with `pip install quanto`"
)
elif cache_config.backend == "HQQ" and not is_hqq_available():
raise ImportError(
"You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
"Please install it via with `pip install hqq`"
)
model_kwargs["past_key_values"] = cache_class(cache_config)
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
# keeps copying the cache thus using much more memory
elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
past = model_kwargs.get("past_key_values", None)
if past is None:
model_kwargs["past_key_values"] = DynamicCache()
use_dynamic_cache_by_default = True
elif isinstance(past, tuple):
model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past)
use_dynamic_cache_by_default = True
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
确认生成模式
这里直接选择beam search分支了,其他模式不做展开介绍,下同。
beam search分为两种,beam_search以及进阶款的后者对应后续的生成方法为beam_sample。
如果do_sample为True, 会选择beam_sample
二者的区别主要在于,进阶款的beam_sample_gen_mode可以设置temperature、top_k、top_p等参数进一步控制生成,设置的方法在logits warper中介绍。对于基础款的beam_search,就没有创建logits warper这一环节。
# 7. determine generation mode
generation_mode = generation_config.get_generation_mode(assistant_model)
if streamer is not None and (generation_config.num_beams > 1):
raise ValueError(
"`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
)
创建logits processor
# 8. prepare distribution pre_processing samplers
prepared_logits_processor = self._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_length,
encoder_input_ids=inputs_tensor,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
device=inputs_tensor.device,
model_kwargs=model_kwargs,
negative_prompt_ids=negative_prompt_ids,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
这一个环节比较重要,因为涉及到了logits processor。这些processor是在生成的过程中,在每一个step,对计算出来的得分进行修正处理的。在transformers中,预设了若干processor,用户也可以定义自己的processor(需要继承抽象类transformers.generation.logit_process.LogitsProcessor),自己设计逻辑,来对生成的过程进行人工干预。
在beam search中,logits process的使用方法是:
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
其中,input_ids是当前step传给模型的序列token id对应Tensor(batch_size, sequence_length),next_token_scores是经过模型计算之后的分数(即在vocab上的概率分布)取log_softmax。
在这里简单介绍一下在transformers中预设的processor。限于篇幅,不贴出全部源码,只对其功能进行总结。
processor | 作用 | 参考连接 |
|---|---|---|
MinLengthLogitsProcessor | 通过将EOS的概率强行设置为0,保证生成结果的长度大于等于一个最小值 | / |
MinNewTokensLengthLogitsProcessor | 与上一个类似,但是prompt的部分不计入生成长度 | / |
RepetitionPenaltyLogitsProcessor | 防止“复读机”现象,给重复出现token添加惩罚,由预训练模型CTRL提出 | |
EncoderRepetitionPenaltyLogitsProcessor | 与上一个区别在于,生成的结果不能与encoder输入input id重复,而非与当前给定的全部input id | / |
NoRepeatNGramLogitsProcessor | 防止生成的文本中出现重复的n-gram(n个连续的词或字符),区别在于限制连续n个 | |
EncoderNoRepeatNGramLogitsProcessor | n-gram可以在encoder的input ids中重复,不可以在decoder重复 | |
NoBadWordsLogitsProcessor | 确保某些词永远不会被生成 | / |
PrefixConstrainedLogitsProcessor | 给定一个prefix_allow_func来限制符合哪些条件的token可以被生成 | |
HammingDiversityLogitsProcessor | 以Hamming距离为标准,确保生成的各个beam之前的区别足够大 | |
ForcedBOSTokenLogitsProcessor | 确保生成的第一个token是某个特定的token | / |
ForcedEOSTokenLogitsProcessor | 达到最大长度时,确保以某个特定的token作为结束 | / |
InfNanRemoveLogitsProcessor | 将计算出的得分中,nan替换为0,inf替换为可计算的最大值 | / |
SuppressTokensAtBeginLogitsProcessor | 在达到某个长度之后,将不再生成某些特定的词 | / |
SuppressTokensLogitsProcessor | 将某些特定词的概率设置为-inf,不生成这些词 | / |
ForceTokensLogitsProcessor | 建立一个映射表,把某个token强行映射成另一个token | / |
WhisperTimeStampLogitsProcessor | 强制模型生成时间戳(时间戳是一个特殊token,例如对话中,query=今天是周几?,answer=今天是[timestamp],这个[timestamp]后期会替换成对应的时间) | / |
创建停止规则
stopping_criteria与logits_processor是用户对生成过程进行干预的主要手段,相比logits_processor强行改变概率空间,stopping_criteria则是直接设定了终止生成的策略,理解起来也会相对容易一些。
# 9. prepare stopping criteria
prepared_stopping_criteria = self._get_stopping_criteria(
generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs
)
预设的criteria总结如下:
criteria | 作用 |
|---|---|
MaxLengthCriteria | 生成的序列达到设置的最大长度时,停止生成 |
MaxNewTokensCriteria | 生成的序列中,除去prompt的部分达到设置的最大长度时,停止生成 |
MaxTimeCriteria | 生成的耗时超过一定时间限制时,停止生成 |
如果是自定义criteria,应当继承抽象类transformers.generation.stopping_criteria.StoppingCriteria。
进入相应的分支
这里直接选择进入beam search的分支。如前文所述,如果要控制temperature等超参数,则应该进入is_beam_sample_gen_mode这个分支。
这些是生成模型中不同的生成方法,每种方法在文本生成的过程中有不同的策略:
- CONTRASTIVE_SEARCH:是一种平衡探索和利用的生成方法,适用于需要生成高质量且有一定多样性的文本任务。
- GREEDY_SEARCH: 每一步都选择最有可能的词,生成单一且通常连贯的文本,但缺乏多样性。
- SAMPLE: 根据概率分布随机选择下一个词,生成更具多样性的文本,但可能影响连贯性。
- ASSISTED_GENERATION: 结合人类输入或外部辅助信息来生成文本,通常用于对生成内容有特定需求的任务。
- BEAM_SEARCH: 在生成过程中维护多个候选序列(通常称为“束”),最终选择概率最高的完整序列。这种方法在质量和多样性之间取得平衡。
- BEAM_SAMPLE: 结合了束搜索和采样的策略,在保持候选序列的同时对每一步进行采样,以提高生成文本的多样性。
- CONSTRAINED_BEAM_SEARCH: 在束搜索的基础上增加了某些约束条件,确保生成的文本满足特定要求,如包含某些关键字或遵循特定的语法结构。
- GROUP_BEAM_SEARCH: 将候选序列分组进行束搜索,通常用于减少生成文本的重复性,并提高生成结果的多样性。这在需要生成多个不重复答案的场景中非常有用。
后面主要介绍Beam_search
创建logits warper
elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
# 11. prepare logits warper
prepared_logits_warper = (
self._get_logits_warper(generation_config, device=input_ids.device)
if generation_config.do_sample
else None
)
logits warper的使用方法与logits processor一样,都是用来修改概率的输出。关于他们的区别,暂时没有找到很好的解释,可以理解为warper控制着temperature、topk等与生成策略相关的参数。并且是在logits processor处理之后再进行处理的。
普通的beam search不会涉及这一部分,只有选择sample模式的beam search时,才会使用到logits warper。
需要记住的是,它的输入与processor一样,都是当前的序列(token_ids)与之前计算出的得分(scores),返回的结果是处理之后的得分,形状是(batch_size, config.vocab_size)。
预设的warper包括:
warper | 作用(仅供参考) | 参考链接 |
|---|---|---|
TemperatureLogitsWarper | 对score整体除以temperature做缩放 | / |
TopPLogitsWarper | 概率小于top_p的得分置为0 | / |
TopKLogitsWarper | 只取topk的概率对应的词汇,其余的概率置为-inf | / |
TypicalLogitsWarper | typical decoding | |
EpsilonLogitsWarper | 将概率小于epsilon的token移除 | |
EtaLogitsWarper | eta-sampling | |
LogitNormalization | 在beam search进行的过程中做layernorm | / |
beam search
这一部分是beam search的核心流程,其具体的执行生成过程将在第5节中进行详细的介绍。
在这一部分中,首先创建了用于打分的BeamSearchScorer(具体作用将在第5节中进行介绍),然后根据num_beams对input_ids进行了扩展,最后执行beam search的核心方法beam_search,或beam sample对应的beam_sample方法。
# 12. prepare beam search scorer
beam_scorer = BeamSearchScorer(
batch_size=batch_size,
num_beams=generation_config.num_beams,
device=inputs_tensor.device,
length_penalty=generation_config.length_penalty,
do_early_stopping=generation_config.early_stopping,
num_beam_hyps_to_keep=generation_config.num_return_sequences,
max_length=generation_config.max_length,
)
# 13. interleave input_ids with `num_beams` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=generation_config.num_beams,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 14. run beam sample
result = self._beam_search(
input_ids,
beam_scorer,
logits_processor=prepared_logits_processor,
logits_warper=prepared_logits_warper,
stopping_criteria=prepared_stopping_criteria,
generation_config=generation_config,
synced_gpus=synced_gpus,
**model_kwargs,
)
Beam Search
简单介绍一下在文本生成任务中常用的解码策略Beam Search(集束搜索)。
生成式任务相比普通的分类、tagging等NLP任务会复杂不少。在生成的时候,模型的输出是一个时间步一个时间步依次获得的,而且前面时间步的结果还会影响后面时间步的结果。也就是说,每一个时间步,模型给出的都是基于历史生成结果的条件概率。为了生成完整的句子,需要一个称为解码(decode)的额外动作来融合模型多个时间步的输出,而且使得最终得到的序列的每一步条件概率连乘起来最大。
在文本生成任务中,每一个时间步可能的输出种类称为字典大小(vocabulary size,我们用V表示),进行T步随机的生成可能获得的结果总共有\(V^T\)种。拿中文文本生成来说,V 的值大约是5000-6000,即常用汉字的个数。在如此大的基数下,遍历整个生成空间是不现实的。
贪心搜索
每一个时间步都取出一个条件概率最大的输出,如图:

Beam Search原理
思路也很简单,就是稍微放宽一些考察的范围。在每一个时间步,不再只保留当前分数最高的1个输出,而是保留num_beams个。当num_beams=1时集束搜索就退化成了贪心搜索。

- 在第一个时间步,A和C是最优的两个,因此得到了两个结果[A],[C],其他三个就被抛弃了;
- 第二步会基于这两个结果继续进行生成,在A这个分支可以得到5个候选人,[AA],[AB],[AC],[AD],[AE],C也同理得到5个,此时会对这10个进行统一排名,再保留最优的两个,即图中的[AB]和[CE];
- 第三步同理,也会从新的10个候选人里再保留最好的两个,最后得到了[ABD],[CED]两个结果。 可以发现,beam search在每一步需要考察的候选人数量是贪心搜索的num_beams倍,因此是一种牺牲时间换性能的方法。
代码流程概览
为了帮助大家阅读代码,这里把这部分代码的整体逻辑进行一下梳理,如下图所示

总的来说,生成过程中不断重复调用模型的forward()计算出logits,以及调用BeamSearchScorer的process()来计算下一个位置每个token出现的得分,来生成下一个token及其概率分布,直到满足终止条件,结束生成。
BeamSearch Scorer
BeamSearchScorer是在生成过程进行状态维护的类,它的作用是用来更新Beam得分,以及判断生成过程是否结束等。在这一节中,简单了解一下这个类的构造,具体的使用方法会在本篇的第5节中,结合beam search的整个流程的推进,进行更加详细的介绍。
先简单解释一下其参数:
参数名 | 类型 | 含义 |
|---|---|---|
batch_size | int | 批量生成时一次处理多少条数据 |
num_beams | int | 每一条数据在生成时保留几个beam |
device | torch.device | cpu or cuda |
length_penalty | Optional[float] | 控制倾向于生成更长的句子还是更短的句子 |
do_early_stopping | Optional[Union[bool, str]] | 早停机制,是否生成达到num_beam后立即停止 |
num_beam_hyps_to_keep | int | 最终返回多少个beam |
num_beam_groups | int | 把所有的beam按照差异度分成多少组 |
max_length | int | 生成的最大长度 |
这个类除了构造方法之外,只有两个方法和一个属性:
@property
def is_done(self) -> bool:
def process(
self,
input_ids: torch.LongTensor,
next_scores: torch.FloatTensor,
next_tokens: torch.LongTensor,
next_indices: torch.LongTensor,
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
group_index: Optional[int] = 0,
decoder_prompt_len: Optional[int] = 0,
) -> Dict[str, torch.Tensor]:
def finalize(
self,
input_ids: torch.LongTensor,
final_beam_scores: torch.FloatTensor,
final_beam_tokens: torch.LongTensor,
final_beam_indices: torch.LongTensor,
max_length: int,
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
eos_token_id: Optional[Union[int, List[int], torch.Tensor]] = None,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
) -> Tuple[torch.LongTensor]:
- 其中
is_done用来记录是否batch中所有数据都已经生成结束; - process是生成的每一个step都需要执行的状态更新过程,属于生成中的主干部分;
- finalize是整个生成过程所有step都已经结束之后(出现EOS或达到stopping_criteria的终止条件),最终的后处理加工。
除此之外,这个类还有两个成员需要注意:
self.group_size是按照差异性对beam分组时,每一组的beam数量:
self.group_size = self.num_beams // self.num_beam_groups
- self._beam_hyps是一组容器,用来容纳得分最高的 \(n\) 个beam:
self._beam_hyps = [
BeamHypotheses(
num_beams=self.group_size,
length_penalty=self.length_penalty,
early_stopping=self.do_early_stopping,
max_length=max_length,
)
for _ in range(batch_size * self.num_beam_groups)
]
接下来在下节中,简单介绍一下这个BeamHypotheses类。
BeamHypotheses
BeamHypotheses,直接翻译过来就是“假说”,这个名称很容易引起迷惑,但其实把它看做是一个容器就好了,其容纳的内容就是 \(n\) 个得分最高的beam。batch中的每个样本,对应一个BeamHypotheses。
从构造方法可以看出,其自身除了一个self.beams用来容纳得分最高的beam之外,还有若干固有的属性:
class BeamHypotheses:
def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None):
"""
Initialize n-best list of hypotheses.
"""
self.length_penalty = length_penalty # 与BeamScorer的length_penalty是同一个东西,用来控制倾向于生成长序列还是短序列
self.early_stopping = early_stopping # 与BeamScorer的early_stopping是同一个,控制是否采用早停机制
self.max_length = max_length # 与BeamScorer的max_length是同一个,控制生成序列的最大长度
self.num_beams = num_beams # 与BeamScorer的num_beams是同一个,生成过程中保留多少个beam
self.beams = [] # 在生成过程中,用来容纳至多num_beams个beam
self.worst_score = 1e9 # 当前状态下最差一个beam的得分
if not isinstance(self.early_stopping, bool) and self.max_length is None:
raise ValueError(
"When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
" BeamScorer class instance at initialization time."
)
def __len__(self):
"""
Number of hypotheses in the list.
"""
return len(self.beams)
然后看一下BeamHypotheses的两个核心方法,add和is_done:
add方法用来将一个beam(对应的容器)添加到整个列表中:
def add(
self,
hyp: torch.LongTensor,
sum_logprobs: float,
beam_indices: Optional[torch.LongTensor] = None,
generated_len: Optional[int] = None,
):
"""
Add a new hypothesis to the list.
"""
if generated_len is not None:
score = sum_logprobs / (generated_len**self.length_penalty)
# This 'else' case exists for retrocompatibility
else:
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp, beam_indices))
if len(self) > self.num_beams:
# 如果超了设置的beam数量,则按照分数从小到大对beam进行排序
# 删除分数最小的对应的beam,然后把最小的分数更新
sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
del self.beams[sorted_next_scores[0][1]]
self.worst_score = sorted_next_scores[1][0]
else:
self.worst_score = min(score, self.worst_score)
is_done方法用来判断是否所有beam都已经完成了生成:
def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
"""
If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
one in the heap, then we are done with this sentence.
"""
if len(self) < self.num_beams:
return False
# `True`: stop as soon as at least `num_beams` hypotheses are finished
if self.early_stopping is True:
return True
# `False`: heuristic -- compute best possible score from `cur_len`, even though it is not entirely accurate
# when `length_penalty` is positive. See the discussion below for more details.
# <https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565>
elif self.early_stopping is False:
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
ret = self.worst_score >= highest_attainable_score
return ret
# `"never"`: compute the best possible score, depending on the signal of `length_penalty`
else:
# `length_penalty` > 0.0 -> max denominator is obtaned from `max_length`, not from `cur_len` -> min
# abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
# its max this way
if self.length_penalty > 0.0:
if self.max_length <= decoder_prompt_len:
raise ValueError("max_length is not larger than decoder prompt length")
highest_attainable_score = (
best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
)
# the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
else:
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
ret = self.worst_score >= highest_attainable_score
return ret
beam_search过程
beam_search与beam_sample两个模式的主流程,二者的区别不是很大,先来看beam_search。
def _beam_search(
self,
input_ids: torch.LongTensor,
beam_scorer: BeamScorer,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
logits_warper: Optional[LogitsProcessorList] = None,
**model_kwargs,
) -> Union[GenerateBeamOutput, torch.LongTensor]:
其中这些输入参数,多数在前面已经介绍过,这里需要注意的是BeamScorer,它是一个用来在生成过程中,对每一个step的概率得分进行计算,并且判断生成过程是否结束。
在这个方法中,有一个
while self._has_unfinished_sequences的循环,是其主体部分,也是beam search核心逻辑的体现。在这个while之前的部分基本都是些实例化初始化的内容,理解起来没有什么困难。唯一需要额外注意的,应该是beam score的初始化问题。
beam score初始化
beam score的初始化是一个比较细节的问题,并且是新版的代码对其进行了改进。
理论上,对于beam search的过程,需要维护一个beam score来记录生成过程中每个beam的得分即可,也就是维护一个(batch_size, num_beams)的tensor,然而在代码的实现中,却有这样一个细节:
# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
# of the first beam are considered to avoid sampling the exact same tokens across all beams.
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view((batch_size * num_beams,))
即batch中的每一条,对应的要生成的所有beam中,只有第1个beam的得分初始化为0,其余beam全部都初始化为-inf。代码的注释也对这样的用意进行了解释:防止在生成过程中,所有的beam产生的结果都是一样的。
这里举一个例子来对此进行说明。
假如有这样的场景,有这样的一个句子作为开头:“我的家在”,需要模型生成接下来的内容。
那么在下一个step,需要根据现有的序列“我的家在”,来计算词表中所有词的得分。
假如beam_size为2,那么就会保留了得分最高的两个,此时我们期望得到的两个beam可能分别为:
beam 1:“我的家在东” beam 2:“我的家在北”
然后再一个step,这两个beam分别变成了:
beam 1:“我的家在东北” beam 2:“我的家在北京”
然而实际情况却并非如此,实际上,每个beam是一个容器类
BeamHypotheses, 在计算时,第一个beam的序列为“我的家在”时,第二个beam的序列也是“我的家在”, 这样一来,两个序列的tokens完全一致,对应的scores完全一致,后续生成的结果,也就会一直重复下去了。把第一个beam的分初始化为0,其余beam初始化为负无穷,就可以确保生成出来的只能在第一个beam对应的序列里, 就不会出现一直重复的情况了。
下面直接通过while循环来看beam search的主体逻辑。
准备输入
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
prepare_inputs_for_generation这个方法在GenerationMixin中没有定义,需要在具体的模型中定义,举一个最简单的例子,在InternVL中,该方法仅仅是在生成过程中根据当前的 input_ids 和之前生成的 past_key_values 动态调整模型的输入,以便高效生成后续的 token。同时,它还处理了输入的多种形式,如 inputs_embeds 和 input_ids,并动态生成 position_ids 以支持批量生成。
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
# Some generation methods already pass only the last input ID
if input_ids.shape[1] > past_length:
remove_prefix_length = past_length
else:
# Default to old behavior: keep only final ID
remove_prefix_length = input_ids.shape[1] - 1
input_ids = input_ids[:, remove_prefix_length:]
position_ids = kwargs.get('position_ids', None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1]:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
model_inputs = {'input_ids': input_ids}
model_inputs.update(
{
'position_ids': position_ids,
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache'),
'attention_mask': attention_mask,
}
)
return model_inputs
前向forward
在新版transformers中, 还加入了对低内存消耗的支持,也就是判断sequential=True, 会将输入数据拆分成更小的子批次,然后在顺序执行模式,这种方式适用于内存受限的环境,比如在进行低内存的 beam_search 生成时,减少一次性处理的数据量以避免内存溢出。
有了输入之后,自然要将输入传输给模型进行计算,也就是网络的前向传播阶段,这里的self是调用自身,也就是GenerationMixin这个类,而我们在之前的分析中知道,其实这个类是被实际调用的模型所继承的,所以实际上这里是使用了生成模型的forward方法。
outputs = self(
**model_inputs,
return_dict=True,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
所以这个outputs,就是包含了loss,logits,以及可能包含attention与past_v_k等各种信息的计算结果。
以InternVL为例,在InternLM2ForCausalLM可以看到,它主要就是先经过了InternLM2Model的transformer网络,得到一个形状为(seq_len, bsz, hidden)的hidden_states,然后将其映射到词表上,就得到了在词表空间上的概率分布,形状为(bsz, seq_len, vocab),也就是常说的logits。多数LLM模型都是这样的一个套路。
output = CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
计算下一个step每个token的得分
在上一小节中,前向计算的结果,有很多项,其中在生成过程中,最关键的就是logits,它直接关系到下一个step生成的token是什么。
- logits的形状为
(bsz, seq_len, vocab),所以下面代码中,第一行取的outputs.logits[:, -1, :].clone(),也就是取最后一个位置的概率分布,即用来生成下一个step的token。 - 取
log_softmax(dim=-1)将logits变成vocab空间上的“概率”。 - 使用之前实例化的
logits_processor对计算出的概率进行进一步的处理 - 如果采用beam_sample,还需要用
logits_warper对计算出的概率进行进一步的处理 - 将processor处理之后的得分,与beam本身的得分相加算总分,即新的beam总分=原来的beam总分+即将生成的新token的分。这里可以回顾一下之前beam score初始化的细节,在while循环中的第一个step,只有第一个beam的分不是-inf,而之后的step中就不存在这个问题了。
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
next_token_logits = outputs.logits[:, -1, :].clone()
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
if do_sample:
next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(
next_token_scores_processed
)
这里还有一个细节,就是在最后一步中,为什么next_token_scores_processed与beam_scores可以直接相加,我的理解是,计算最大概率的beam,其基本的概率公式应该是每一个step的概率相乘:
而由于在之前的得分计算中,已经取了对数,也就把原本乘性的问题变成了加性,二者自然可以直接相加了。
选择next token
在sample之前,有一个reshape的过程,将next_token_scores的形状从(batch_size * num_beams, vocab_size)变成了(batch_size, num_beams * vocab_size),也就是说,将num_beam展平在了vocab的维度上:
# reshape for beam search
vocab_size = next_token_scores.shape[-1]
next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)

经过了这样的reshape,就把batch中的每一条样本,其包含的所有beam,放在一起进行对比了。更具体一点来讲,就是在 [ 选择第1个beam的情况下,再选词表中第1个词, 选择第1个beam的情况下,再选词表中第2个词, …, 选择第2个beam的情况下,再选词表中第1个词, …, 选择第2个beam的情况下,再选词表中第6个词, ] 之中,选取概率最高的。由于这里介绍的是最基础的beam_search,所以还没涉及到top_k等超参数,这些部分在下文中会继续介绍。
在实际操作中,多采了一倍的token作为备选,以确保后续不会出问题。
接下来的torch.div的操作,是因为在topk之前,将beam展平在了vocab上,所以算出来的indices是在所有beam上的一个“绝对位置”,需要将它变成在每一个beam上的“相对位置”。
# Beam token selection: pick 1 + eos_token_id.shape[0] next tokens for each beam so we have at least 1
# non eos token per beam.
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
n_tokens_to_keep = max(2, 1 + n_eos_tokens) * num_beams
if do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, _indices)
else:
next_token_scores, next_tokens = torch.topk(
next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
)
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
next_tokens = next_tokens % vocab_size
更新beam状态
在这一步中,对beam的状态进行了更新,依赖BeamScorer的process方法。
# stateless
beam_outputs = beam_scorer.process(
input_ids,
next_token_scores,
next_tokens,
next_indices,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
beam_indices=beam_indices,
decoder_prompt_len=decoder_prompt_len,
)
这里就涉及到了scorer的process部分,对此进行详细的说明:
注意自此开始,代码跳转到transformers.generation.beam_search.BeamSearchScorer.process
初始化相关参数:
cur_len:当前生成序列的长度。batch_size:计算批次大小,表示每个组的束搜索的数量。next_beam_scores、next_beam_tokens、next_beam_indices:这些变量用于存储下一个步骤中选择的token的得分、token值和对应的索引, 注意这三项既是process方法的输入,也是process最终的输出,作为下一次process的输入
并且,在上一节的代码中可以看到,beam_scores在传入给process之前,已经在dim=1上做了排序,也就是在vocab_size的维度。
cur_len = input_ids.shape[-1] + 1
batch_size = len(self._beam_hyps) // self.num_beam_groups
if not (batch_size == (input_ids.shape[0] // self.group_size)):
if self.num_beam_groups > 1:
raise ValueError(
f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
f"size of {self.group_size} is expected by the beam scorer."
)
else:
raise ValueError(
f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
f"{self.group_size} is expected by the beam scorer."
)
device = input_ids.device
next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
接下来是process的主体部分,对每一个beam_hyp(即每一个进行生成中的束)进行遍历,过程的细节以注释的形式写在了代码里,这一部分的逻辑不算复杂,但是其中也涉及到了一些由分组运算而引发的细节问题:
for batch_idx, beam_hyp in enumerate(self._beam_hyps):
# 如果当前这一束已经被标记为完成了生成,则将三项输出结果进行padding
if self._done[batch_idx]:
if self.num_beams < len(beam_hyp):
raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
if eos_token_id is None or pad_token_id is None:
raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
# pad the batch
next_beam_scores[batch_idx, :] = 0
next_beam_tokens[batch_idx, :] = pad_token_id
next_beam_indices[batch_idx, :] = 0
continue
# next tokens for this sentence
# 如果当前这一束还没有完成,则计算这一束的下一个token
beam_idx = 0
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
):
# 由于在一开始输入到process时,next_tokens等,就是在vocab维度上排好序的
# 所以这里只需要按顺序添加即可
# 这里是将某个beam中的相对位置恢复为整个tensor中的绝对位置,注意看第5.6.5节中的图
batch_beam_idx = batch_idx * self.group_size + next_index
# add to generated hypotheses if end of sentence
# 最高得分是结束符的情况
if (eos_token_id is not None) and (next_token.item() in eos_token_id):
# if beam_token does not belong to top num_beams tokens, it should not be added
# 如果当前得分最高的是结束符则需要进行额外的一步判断
# 因为在计算得分的时候是将一组中所有beam放在一起计算的,所以即便是预测到了eos,
# 如果它不再前第num_beams个token范围内的话,那这个eos就不能算数
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
if is_beam_token_worse_than_top_num_beams:
continue
if beam_indices is not None:
beam_index = beam_indices[batch_beam_idx]
beam_index = beam_index + (batch_beam_idx,)
else:
beam_index = None
beam_hyp.add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
beam_indices=beam_index,
)
else:
# add next predicted token since it is not eos_token
# 如果不是eos token的话,则直接添加即可
next_beam_scores[batch_idx, beam_idx] = next_score
next_beam_tokens[batch_idx, beam_idx] = next_token
next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
beam_idx += 1
# once the beam for next step is full, don't add more tokens to it.
if beam_idx == self.group_size:
break
if beam_idx < self.group_size:
raise ValueError(
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
)
# Check if we are done so that we can save a pad step if all(done)
# 更新beam的完成状态
cur_len += 1 # add up to the length which the next_scores is calculated on
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
next_scores[batch_idx].max().item(), cur_len
)
return UserDict(
{
"next_beam_scores": next_beam_scores.view(-1),
"next_beam_tokens": next_beam_tokens.view(-1),
"next_beam_indices": next_beam_indices.view(-1),
}
)
以上就是BeamScorer的process过程,在计算出新的beam_scores等三项结果之后,还需要进行进一步的处理:
注意从这里开始,代码回到transformers.generation.utils.GenerationMixin._beam_search。
这一部分代码用来更新生成参数,保存past_key_values,以及判断是否满足停止条件。
beam_scores = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]
input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
# This is needed to properly delete outputs.logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
# IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
# (that way the memory peak does not include outputs.logits)
del outputs
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], beam_idx
)
if return_dict_in_generate and output_scores:
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
# increase cur_len
cur_len = cur_len + 1
if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
this_peer_finished = True
后处理finalize
当生成终止后,还需要进行一个统一的后处理流程,以选择最佳的序列作为最终结果返回。
代码位于transformers.generation.beam_search.BeamSearchScorer.finalize。
在这个环节中,首先需要把没有完成的beam对应的token和score添加到容器中。回顾process部分的代码,可以看到,只有当预测出eos token,并且满足一定条件时,token和score才会被添加到beam_hyp容器中,而根据beam search的整体逻辑,每个step的状态更新完成时,不管是否添加到了容器中,都需要对结束状态进行判断,而判断时,stopping_criteria就会发挥作用了。这也就会造成存在这样一种情况,还没有结束生成的beam,由于满足了stopping_criteria的中止条件,而提前中止,此时的token和score并没有被添加到beam_hyp中,所以需要这样一个后处理的动作,来确保最终得到的beam数量,等于预先设置的num_beams。
batch_size = len(self._beam_hyps) // self.num_beam_groups
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id = torch.tensor(eos_token_id)
# finalize all open beam hypotheses and add to generated hypotheses
for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
if self._done[batch_group_idx]:
continue
# all open beam hypotheses are added to the beam hypothesis
# beam hypothesis class automatically keeps the best beams
for index_per_group in range(self.group_size):
batch_beam_idx = batch_group_idx * self.group_size + index_per_group
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
generated_len = final_tokens.shape[-1] - decoder_prompt_len
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
然后根据score从高到低对所有的束进行排序,保留得分最高的num_beam_hyps_to_keep个束。
# select the best hypotheses
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
best = []
best_indices = []
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
# retrieve best hypotheses
for i in range(batch_size):
beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])
for j in range(self.num_beam_hyps_to_keep):
best_hyp_tuple = sorted_hyps.pop()
best_score = best_hyp_tuple[0]
best_hyp = best_hyp_tuple[1]
best_index = best_hyp_tuple[2]
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
# append hyp to lists
best.append(best_hyp)
# append indices to list
best_indices.append(best_index)
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
最后,对保留的所有束进行padding,已经添加eos结束符:
# prepare for adding eos
sent_lengths_max = sent_lengths.max().item() + 1
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
if len(best_indices) > 0 and best_indices[0] is not None:
indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
else:
indices = None
# shorter batches are padded if needed
if sent_lengths.min().item() != sent_lengths.max().item():
if pad_token_id is None:
raise ValueError("`pad_token_id` has to be defined")
decoded.fill_(pad_token_id)
if indices is not None:
indices.fill_(-1)
# fill with hypotheses and eos_token_id if the latter fits in
for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
decoded[i, : sent_lengths[i]] = hypo
if indices is not None:
indices[i, : len(best_idx)] = torch.tensor(best_idx)
if sent_lengths[i] < sent_max_len:
# inserting only the first eos_token_id
decoded[i, sent_lengths[i]] = eos_token_id[0]
return UserDict(
{
"sequences": decoded,
"sequence_scores": best_scores,
"beam_indices": indices,
}
)
以上就是beam search的完整流程了。在实际应用中,使用更多的方法一般是beam search的升级版,beam sample,在下节中,将简单介绍一下beam sample模式与一般beam search的主要区别。
Beam Sample
Beam sample与一般的beam search相比,主要区别体现在其需要根据GenerationConfig的配置,创建若干logits warper,对计算出的next_token_scores进行进一步的加工。
从代码中来看,beam_sample方法与beam_search方法相比,区别主要在于while self._has_unfinished_sequences的循环中,增加了logits_warper:
if do_sample:
next_token_scores_processed = logits_warper(input_ids, next_token_scores_processed)
这里就以其中一种logits wrapper为例进行介绍。
Temperature是生成过程中一项重要超参数,它控制着生成结果是否具有“创造性”。这个数值一般介于[ 0.1 , 1 ] [0.1, 1][0.1,1],该值越大,越倾向于生成概率不那么高的token,结果更具有“创造性”,等于1时,相当于原始的softmax得到的分布;而该值越小,则倾向于生成更加保守的结果,当接近于0时,则趋向于greedy search。
对应的wrapper实现如下:
class TemperatureLogitsWarper(LogitsWarper):
def __init__(self, temperature: float):
if not isinstance(temperature, float) or not (temperature > 0):
except_msg = (
f"`temperature` (={temperature}) has to be a strictly positive float, otherwise your next token "
"scores will be invalid."
)
if isinstance(temperature, float) and temperature == 0.0:
except_msg += " If you're looking for greedy decoding strategies, set `do_sample=False`."
raise ValueError(except_msg)
self.temperature = temperature
@add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
scores_processed = scores / self.temperature
return scores_processed
从中可以看到,它只是将原来的得分除以temperature的数值。结合logits_warper在整体流程中的位置(warper的调用位于softmax之后),可以看出这一计算并没有在当前step生效,而是在下一个step时才会生效,这也符合带temperature的softmax的公式。 原始的softmax:
增加temperature之后的softmax:
其他的warper也是类似的使用方法,是作用在softmax计算完当前step的得分之后。
beam_sample 相比beam_search还有一处不同处
if do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=n_tokens_to_keep)
next_token_scores = torch.gather(next_token_scores, -1, next_tokens)
next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)
next_tokens = torch.gather(next_tokens, -1, _indices)
else:
next_token_scores, next_tokens = torch.topk(
next_token_scores, n_tokens_to_keep, dim=1, largest=True, sorted=True
)
- Beam Sample (
do_sample=True): 基于概率分布进行采样,生成较为多样化的文本。通过softmax将next_token_scores转换为概率,再进行torch.multinomial采样。 - Beam Search (
do_sample=False): 直接选择分数最高的几个 token,生成更为确定性的输出,适用于希望生成高概率、较为确定的文本场景。
Reference
以beam search为例,详解transformers中generate方法(上)