Reading

Data Packing

简介

在深度学习模型(尤其是 Transformer 架构)的训练中,输入数据的长度通常需要保持一致。如果直接输入大量短文本,就需要用大量无意义的占位符(Padding)来补齐长度,这会极大地浪费 GPU 的计算资源。为了最大化计算效率,目前基本主流的训练框架里都会加入数据打包(Data Packing)的逻辑,本文以lmms-engine中的操作为例,具体查看实际训练时对数据packing操作以及use_rmpad消除所有padding计算的逻辑

Packing

Dataset

如下Dataset代码所示,这段代码核心目标是:在不超过预设最大长度(packing_length)的前提下,尽可能多地将短样本塞进同一个批次(Batch)中。

这种做法带来了两个显著的好处:

  1. 提升计算效率:减少了 Padding Token 的数量,让 GPU 的每一次矩阵乘法都作用在真实有效的数据上。
  2. 稳定训练过程:每个 Batch 的有效 Token 数量更加一致,有助于梯度的稳定
if self.config.packing:
    # Reset index at the start of each iteration pass
    self.cur_idx = 0
    buffer = []
    buffer_length = 0
    packing_length = self.config.packing_length

    # Iterate through the dataset once per epoch
    while self.cur_idx < len(curr_data_list):
        try:
            data_dict = self.get_one_sample(self.cur_idx, curr_data_folder[self.cur_idx], curr_data_list)
        except Exception as e:
            traceback.print_exc()
            logger.error(f"Error getting one sample: {e}, skip this sample")
            self.cur_idx += 1
            continue
        input_ids = data_dict["input_ids"]
        data_length = input_ids.shape[0]
        self.cur_idx += 1

        # Drop overlong sample if filtering is enabled
        if data_length > packing_length and self.config.filter_overlong:
            continue

        # If current sample cannot fit into current buffer, yield the buffer first
        if buffer_length > 0 and buffer_length + data_length > packing_length:
            yield buffer
            buffer = []
            buffer_length = 0

        # If the sample is still longer than packing_length (and not filtered),
        # yield it as its own batch to avoid stalling
        if data_length > packing_length:
            yield [data_dict]
            continue

        # Append to buffer
        buffer.append(data_dict)
        buffer_length += data_length

    # Flush remaining buffer
    if len(buffer) > 0:
        yield buffer

总结数据生成器(Generator)具体步骤

代码通过一个 while 循环遍历数据集,具体的执行流程可以清晰地拆解为以下六个步骤:

  1. 初始化缓冲区
    在每次遍历(Epoch)开始前,初始化一个空列表 buffer 作为“购物车”,用 buffer_length 记录当前购物车内数据的总长度,并设定容量上限 packing_length
  2. 安全读取单个样本
    使用 try-except 结构调用 get_one_sample 获取当前索引的数据。如果某条数据损坏或读取失败,代码会捕获异常、打印错误日志,并优雅地跳过(continue),防止整场训练任务因为一条脏数据而崩溃。
  3. 超长样本过滤(按需拦截)
    获取当前样本的长度 data_length。如果该长度已经超过了 packing_length,并且系统配置中开启了 filter_overlong,则直接丢弃该样本,不让其参与后续的打包。
  4. 缓冲区溢出判断(核心打包机制)
    在将当前样本放入缓冲区之前,先进行“容量预判”:
    • 如果 buffer 中已经有数据了,且 现有长度 + 当前样本长度 > 容量上限,说明购物车装不下了。此时,代码会通过 yield buffer 将当前收集到的所有样本作为一个整体产出,随后清空购物车,重置长度。
  5. 处理单件超大样本(兜底策略)
    如果当前样本本身的长度就大于 packing_length(且在第 3 步中没有被过滤掉),为了防止它卡死后续正常的打包流程,代码会直接将其作为一个独立的列表 [data_dict] 产出(yield),然后直接进入下一次循环。
  6. 装入缓冲区与最终收尾
    如果样本顺利通过了上述所有检查,它就会被追加到 buffer 中,并累加 buffer_length。当整个数据集遍历完毕跳出 while 循环后,如果 buffer 中还有未产出的“尾货”(len(buffer) > 0),代码会执行最后一次 yield,确保没有任何数据被遗漏。

Collator

在 Collator 组合成batch的形式传入到模型的输入, 这里还是将数据padding, 如下面代码所示,这段代码定义了一个 VisionCollator 类,它的核心目标是作为 DataLoader 中的 collate_fn(数据整理函数)使用:将多个独立的、长度不一的样本字典,汇总并填充(Padding)成一个统一的、可以直接输入给模型进行前向传播的批次(Batch)张量字典。

在多模态大语言模型(VLM)的训练或推理中,由于每个样本的文本长度不同,无法直接堆叠成矩阵。VisionCollator 的作用就是像一个“打包员”,它不仅能动态补齐文本序列的长度(支持左侧或右侧填充),还能自动生成正确的注意力掩码(Attention Mask),并妥善拼接图像特征或其他全局配置参数,确保数据完美符合模型的输入要求

@ dataclass
class VisionCollator:
    processor: Processable

    def pad_sequence(self, input_ids, batch_first, padding_value):
        if self.processor.tokenizer.padding_side == "left":
            input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
        if self.processor.tokenizer.padding_side == "left":
            input_ids = torch.flip(input_ids, [1])
        return input_ids

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        if isinstance(instances[0], list):
            instances = [inst for instance in instances for inst in instance]
        inputs = collections.defaultdict(list)
        for instance in instances:
            for key, values in instance.items():
                inputs[key].append(values)

        batched_inputs = {}
        if "input_ids" in inputs.keys():
            input_ids = inputs.pop("input_ids")
            input_ids = self.pad_sequence(
                input_ids,
                batch_first=True,
                padding_value=self.processor.tokenizer.pad_token_id,
            )
            batched_inputs["input_ids"] = input_ids
        if "labels" in inputs.keys():
            labels = inputs.pop("labels")
            labels = self.pad_sequence(
                labels,
                batch_first=True,
                padding_value=-100,
            )
            batched_inputs["labels"] = labels

        if "attention_mask" in inputs.keys():
            inputs.pop("attention_mask")

        attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id).long()
        batched_inputs["attention_mask"] = attention_mask

        # for the other keys
        for key, values in inputs.items():
            # Handle scalar/boolean values ( use_audio_in_video)
            if isinstance(values[0], bool) or (
                isinstance(values[0], (int, float)) and not isinstance(values[0], torch.Tensor)
            ):
                batched_inputs[key] = values[0]
            else:
                batched_inputs[key] = torch.concatenate(values, dim=0)
        return batched_inputs

    @property
    def image_token_id(self):
        return self.processor.tokenizer.convert_tokens_to_ids(self.processor.image_token)

代码的执行流程主要集中在 __call__ 方法中,配合辅助的 pad_sequence 函数,可以清晰地拆解为以下六个关键步骤:

步骤 1:展平嵌套列表(无缝衔接打包逻辑)

if isinstance(instances[0], list):
    instances = [inst for instance in instances for inst in instance]

当 DataLoader 接收到上一段代码 yield buffer 产出的数据时,instances 会是一个“列表的列表”。这里通过列表推导式将其展平为一维的样本列表,方便后续统一处理。

步骤 2:按特征键值聚合数据

inputs = collections.defaultdict(list)
for instance in instances:
    for key, values in instance.items():
        inputs[key].append(values)

遍历所有样本,将每个样本字典中相同的键(如 input_ids, labels, pixel_values 等)提取出来,分别放入对应的列表中。这一步完成了数据从“按样本组织”到“按特征组织”的转换。

步骤 3:序列填充机制(pad_sequence

def pad_sequence(self, input_ids, batch_first, padding_value):
    # ...

PyTorch 原生的 pad_sequence 只能在序列右侧填充。但很多大模型(如 LLaMA 系)在特定场景下需要左侧填充。如果配置为左填充,先将序列翻转(torch.flip),调用原生函数在右侧填充后,再翻转回来。

步骤 4:处理核心输入与标签
提取出 input_idslabels 后,调用上一步的填充函数:

  • input_ids:使用分词器标准的 pad_token_id 补齐,对齐文本输入长度。
  • labels:使用 -100 补齐。在 PyTorch 的交叉熵损失函数中,默认会忽略值为 -100 的标签,从而防止 Padding 部分参与 Loss 计算。

步骤 5:动态生成注意力掩码(Attention Mask)

attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id).long()

为了确保绝对的准确性,代码直接丢弃了原始数据中可能存在的 attention_mask,而是通过判断填充后的 input_ids 中哪些位置不等于(ne Padding Token,动态生成了一个全新的 attention_mask(真实数据为 1,Padding 为 0)。

步骤 6:处理多模态特征与其他键值
对于剩下的键(例如图像张量或全局布尔开关):

  • 如果是全局标量或布尔值(如 use_audio_in_video),直接取第一个值作为整个 Batch 的配置。
  • 如果是张量(如图像特征),则在第 0 维度(Batch 维度)上将它们直接拼接(torch.concatenate)在一起。

rmpad

项目中,是以 monkey patch的形式(也就是打热补丁) 在运行时动态替换 Qwen3-VL 模型的核心前向传播(Forward)函数,从而注入极致的底层加速算子(通常是为了实现 Remove Padding 即 rmpad 优化)。在处理变长序列或经过“打包(Packing)”的数据时,传统的注意力机制(Attention)仍然需要对齐长度,这会带来冗余计算。rmpad(移除填充)是一种高级优化技术(常与 FlashAttention 的 varlen 变长序列处理结合),它能在底层 CUDA 算子层面直接跳过 Padding 部分的计算。这段代码就是在系统开启 use_rmpad 时,用这套高性能的自定义算子替换掉 PyTorch 原生的标准实现,以大幅提升吞吐量并降低显存占用。

rmpad 是如何工作的?

rmpad 的出现就是为了从根本上消灭这种浪费。它的核心思想是:打破传统的二维矩阵束缚,只计算真正有效的数据。

具体步骤如下:

  1. 展平与压缩(Flattening):在数据进入 Attention 层之前,底层算子会把原本形状为 [batch_size, max_seq_len] 的二维矩阵,直接“拍扁”成一个一维的长条序列 [total_valid_tokens]。所有的 Padding 都在这一步被物理移除了。
  2. 记录边界(cu_seqlen):既然数据变成了一维长条,GPU 怎么知道哪几个 Token 属于文本 A,哪几个属于文本 B 呢?系统会额外传递一个极小的索引数组(通常叫 cu_seqlen,即累加序列长度),用来标记每条数据的起止位置。
  3. 变长注意力计算(Varlen Attention):替换后的底层 CUDA 算子(比如 FlashAttention 的 varlen 接口)会读取这个一维长条和边界索引,只在每条数据的有效边界内进行 Attention 计算,绝不跨界,也绝不计算任何 Padding。

所以为了提高GPU 的利用率(MFU),一般是基于这三个操作来复合实现的

  1. 数据打包(Packing):在最上层,尽量把多个短文本拼成一个长文本,减少整体的 Padding 比例。
  2. 数据整理(Collator):在中间层,把实在拼不齐的尾部数据用 Padding 补齐,生成 Mask,保证数据能顺利变成 Tensor。
  3. 底层算子(rmpad):在最底层计算时,直接把 Collator 补上的那一点点 Padding 再次剔除,调用专属的变长算子(Varlen Ops)进行绝对纯净的有效计算。
if use_rmpad:
	from .qwen3_vl_ops import attn_forward as qwen3_ops_attn_forward
	from .qwen3_vl_ops import (
	    decoder_layer_forward as qwen3_ops_decoder_layer_forward,
	)
	from .qwen3_vl_ops import model_forward as qwen3_ops_model_forward
	from .qwen3_vl_ops import text_model_forward as qwen3_ops_text_model_forward

	modeling_qwen3_vl.Qwen3VLModel.forward = qwen3_ops_model_forward
	modeling_qwen3_vl.Qwen3VLTextModel.forward = qwen3_ops_text_model_forward
	modeling_qwen3_vl.Qwen3VLTextDecoderLayer.forward = qwen3_ops_decoder_layer_forward
	modeling_qwen3_vl.Qwen3VLTextAttention.forward = qwen3_ops_attn_forward

Qwen3VLModel.forward

这里就是 rmpad(移除填充)概念的具体演示,大模型在进入核心 Attention 计算前,是如何在底层将二维矩阵“拍扁”并精准剔除无用 Padding 的。

下面代码显式调用了 _unpad_input。它计算了非 padding 元素的索引 (indices) 和累积序列长度 (cu_seq_lens), 去掉了input中的padding token

if input_ids is not None:
    original_input_ids = input_ids
    input_ids, indices, cu_seq_lens, _ = _unpad_input(input_ids, attention_mask=attention_mask)
    batch_size, seq_length = original_input_ids.shape
elif inputs_embeds is not None:
    original_inputs_embeds = inputs_embeds
    inputs_embeds, indices, cu_seq_lens, _ = _unpad_input(inputs_embeds, attention_mask=attention_mask)
    batch_size, seq_length, _ = original_inputs_embeds.shape
def _unpad_input(input_ids, attention_mask):
    valid_mask = attention_mask.squeeze(1).squeeze(1).eq(1)
    seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
    input_ids = rearrange(input_ids, "b s ... -> (b s) ...")[indices]

    unpad_seq_len = input_ids.shape[0]

    return input_ids, indices, cu_seqlens, max_seqlen_in_batch

下面来具体分析_unpad_input的步骤:

第一步:提取有效掩码与计算真实长度

valid_mask = attention_mask.squeeze(1).squeeze(1).eq(1)
seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)

先将 attention_mask 降维,找到所有值为 1(即真实 Token)的位置,得到 valid_mask。然后沿着序列维度求和,算出这个 Batch 中每一行真实的有效长度(例如 [1000, 10])。

第二步:定位所有有效 Token 的一维坐标

indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()

把二维的 valid_mask 彻底展平(Flatten),然后用 torch.nonzero 找出所有真实 Token 在这个一维长条中的绝对索引位置。这个 indices 既用于现在的“提取”,也用于未来的“还原”。

第三步:计算 FlashAttention 必备的 cu_seqlens

cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))

这是极其关键的一步!底层变长算子需要知道每条序列的起止位置。
假设 seqlens_in_batch[3, 2, 4]

  • cumsum(累加)后变成 [3, 5, 9]
  • F.pad 在最前面补个 0,最终得到 [0, 3, 5, 9]。 这就是大名鼎鼎的 cu_seqlens(Cumulative Sequence Lengths)。底层 CUDA 算子拿到它,就知道第 1 条序列在索引 0~3,第 2 条在 3~5,绝不会算串位。

第四步:物理剔除 Padding(张量重塑与索引提取)

input_ids = rearrange(input_ids, "b s ... -> (b s) ...")[indices]

这里使用了 einops.rearrange 将形状为 (batch, seq_len, dim) 的张量直接拍扁成 (batch * seq_len, dim)。随后,利用第二步算出的 indices只把有效的 Token 抽出来
至此,所有的 Padding 被彻底抛弃,数据变成了绝对紧凑的形态。

由于输入形状变了,位置编码的施加方式也必须改变。

计算出的 position_ids 需要根据 indices 进行重排和筛选,以匹配扁平化后的 input_ids

# 将 (batch, seq) 的 pos_ids 展平,并只取非 padding 部分
position_ids = (
        index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices).transpose(0, 1).unsqueeze(1)
    )

最终传递给 Language Model 的参数最后调用底层 LLM 时,参数列表也不同

outputs = self.language_model(
    ...,
    indices=indices,          # 关键:用于恢复原始形状或索引
    cu_seq_lens=cu_seq_lens,  # 关键:FlashAttention Varlen 需要知道每个句子的边界
    ...
)

Qwen3VLTextAttention.forward

这里其实就是显示的使用 flash_attn_varlen_func来做attention计算, 从而达到packing的最终目的。

attn_output = flash_attn_varlen_func(
    q=query_states,
    k=key_states,
    v=value_states,
    cu_seqlens_q=cu_seq_lens,
    cu_seqlens_k=cu_seq_lens,
    max_seqlen_q=max_seqlen,
    max_seqlen_k=max_seqlen,
    causal=True,
    window_size=window_size,
    softmax_scale=self.head_dim**-0.5,
    dropout_p=0.0,
)

简单来说,flash_attn_varlen_func 的核心任务是:在一堆已经混在一起、分不清谁是谁的 Token 长条中,利用“索引导航”精准地还原出原本的句子结构,并进行 Attention 计算。

下面通过一个具体的例子来详细拆解这个过程:

  1. 数据的变形:从“方阵”到“长条”
    在传统的 Attention(如 forward 函数)中,数据是规整的方阵,短句子后面补了 Padding(0):
    而在 Rmpad (Remove Padding) 模式下,我们在进入 flash_attn_varlen_func 之前(即 attn_forward 的前半部分),已经把所有 Padding 扔掉了,并把剩下的 Token 首尾相连拼成了一个超长的 1D 序列:
    • 传统 Batch (Batch=3):
      • 句子 A (长度 2): [A1, A2, 0, 0]
      • 句子 B (长度 4): [B1, B2, B3, B4]
      • 句子 C (长度 3): [C1, C2, C3, 0]
      • 形状: (3, 4, Hidden_Dim)。其中 0 是浪费的计算。
    • Varlen 输入 (q, k, v):
      • [A1, A2, B1, B2, B3, B4, C1, C2, C3]
      • 形状: (Total_Tokens, Num_Heads, Head_Dim)。这里 Total = 2+4+3 = 9。
      • 注意: 这里完全没有 Batch 维度了。
  2. cu_seqlens (Cumulative Sequence Lengths)

既然数据变成了一长条,FlashAttention 怎么知道 A2 后面是 B1(属于不同句子,不能做 Attention),而不是 A3 呢?

这就靠代码中的 cu_seqlens(累积序列长度)。这是一个 “切分点”索引数组

对于上面的例子,cu_seqlens 的值会是: [0, 2, 6, 9]

在代码中的体现:

flash_attn_varlen_func(
    q=query_states,           # 形状: (Total_Tokens, nheads, dim)
    cu_seqlens_q=cu_seq_lens, # 形状: (Batch_Size + 1,) -> [0, 2, 6, 9]
    max_seqlen_q=max_seqlen,  # 告诉内核最长的句子有多长 (这里是 4)
    ...
)

CUDA 内核拿到这个 cu_seqlens 后,就能在显存中快速定位每个句子的起始和结束位置,从而只在 A1A2 之间计算 Attention,绝不会跨越边界去和 B1 计算。

    • 0 -> 2: 第 1 个句子(A)的范围(索引 0 到 2,不含 2)。
    • 2 -> 6: 第 2 个句子(B)的范围(索引 2 到 6,即 2,3,4,5)。
    • 6 -> 9: 第 3 个句子(C)的范围。

Loss计算

在已经剔除 Padding 且被展平成一维长条的数据上,正确地执行自回归语言模型的“错位(Shift)”操作,确保模型只在单条序列内部预测下一个 Token,而绝对不会跨序列预测。

在标准的二维矩阵 [batch, seq_len] 中,我们通常只需要简单地切片 logits[:, :-1]labels[:, 1:] 就能让输入和标签错开一位。但是,当启用了 use_rmpad,所有序列首尾相连变成了一条线时,如果直接整体错位,序列 A 的最后一个 Token 就会去预测序列 B 的第一个 Token,这会导致模型学到完全错误的因果关系。这段代码就是为了精准规避这个灾难。

if use_rmpad:
    # We need to shift the tokens according to seq lens
    # Otherwise, the first labels of the next seq will be the last labels of the current seq
    shift_hidden_states = []
    shift_labels = []
    for i in range(len(seq_lens) - 1):
        cur_hidden_states = hidden_states[seq_lens[i] : seq_lens[i + 1], :]
        cur_shift_hidden_states = cur_hidden_states[:-1, :].contiguous()
        cur_labels = labels[seq_lens[i] : seq_lens[i + 1]]
        cur_shift_labels = cur_labels[1:].contiguous()
        shift_hidden_states.append(cur_shift_hidden_states)
        shift_labels.append(cur_shift_labels)
    shift_hidden_states = torch.cat(shift_hidden_states, dim=0)
    shift_labels = torch.cat(shift_labels, dim=0)

代码通过遍历累加的长度索引(seq_lens,在这里它的作用等同于上一问的 cu_seqlens),逐个序列进行安全错位。具体步骤如下:

  1. 准备容器与遍历边界
shift_hidden_states = []
shift_labels = []
for i in range(len(seq_lens) - 1):

首先初始化两个空列表,用来存放处理好的各个序列片段。接着,利用 seq_lens 数组进行循环。因为 seq_lens 记录的是累加边界(例如 [0, 3, 5, 9]),所以循环 len(seq_lens) - 1 次刚好能遍历完所有的独立序列。

  1. 精准切分单条序列
cur_hidden_states = hidden_states[seq_lens[i] : seq_lens[i + 1], :]
cur_labels = labels[seq_lens[i] : seq_lens[i + 1]]

在循环内部,代码利用边界索引 seq_lens[i]seq_lens[i + 1],从首尾相连的一维长条 hidden_states(通常是模型最后一层的输出)和 labels 中,安全地把当前这条独立的序列“抠”出来

  1. 序列内部的安全错位(Shift)
cur_shift_hidden_states = cur_hidden_states[:-1, :].contiguous()
cur_shift_labels = cur_labels[1:].contiguous()

这是最核心的自回归逻辑:

    • hidden_states[:-1]:丢弃当前序列的最后一个状态(因为它没有下一个 Token 可以预测了)。
    • labels[1:]:丢弃当前序列的第一个标签(因为它是作为起始输入,不需要被预测)。 这样一错位,输入状态就和它需要预测的下一个标签完美对齐了。
  1. 重新拼接回一维长条
shift_hidden_states.append(cur_shift_hidden_states)
shift_labels.append(cur_shift_labels)
# 循环结束后
shift_hidden_states = torch.cat(shift_hidden_states, dim=0)
shift_labels = torch.cat(shift_labels, dim=0)

将切片并错位好的小片段塞入列表,最后使用 torch.cat 沿着第 0 维度(序列维度)重新拼接起来。此时,得到的数据依然是紧凑的、没有 Padding 的一维张量,但内部的因果关系已经完全正确,可以直接送入交叉熵损失函数(CrossEntropyLoss)进行计算。

总结

这套优化方案可以清晰地划分为四个阶段:

  1. 阶段一:数据整理与对齐(VisionCollator)
    • 动作:将零散的样本收集起来,展平嵌套结构。
    • 核心:使用智能的 pad_sequence(支持左侧/右侧填充)将不同长度的文本和图像特征对齐成标准的二维张量 [batch_size, max_seq_len],并动态生成精准的 attention_mask
    • 目的:满足 PyTorch DataLoader 对标准张量格式的基本要求,为后续处理提供统一的入口。
  1. 阶段二:算子劫持与替换(Monkey Patching)
    • 动作:如果系统开启了 use_rmpad 性能开关,在模型初始化阶段,动态导入底层 C++/CUDA 编写的高性能变长算子(如 FlashAttention 的 varlen 版本)。
    • 核心:强行替换掉模型(如 Qwen3VLModel)原生的、低效的 Attention 和 Decoder Layer 的 forward 方法。
    • 目的:实现非侵入式的底层加速,让模型在不知不觉中执行最高效的计算逻辑。
  1. 阶段三:入场解包与物理去填(_unpad_input
    • 动作:在模型前向传播的最开始,拦截带有 Padding 的标准批次数据。
    • 核心:利用 attention_mask 找到所有真实的 Token,将二维矩阵彻底“拍扁”成一维长条 [total_valid_tokens, dim],物理丢弃所有 Padding。同时,计算出底层算子必需的累加边界索引(cu_seqlens)。
    • 目的:为替换后的高性能算子准备绝对紧凑、零浪费的输入数据。
  1. 阶段四:出场错位与防串味(Shift Logic)
    • 动作:在模型计算完毕、准备计算 Loss 之前,对一维长条数据进行自回归错位(Shift)。
    • 核心:利用之前保存的边界索引(seq_lens),将一维长条重新切分成独立的序列片段,在每个片段内部进行 [:-1] 和 [1:] 的错位操作,然后再重新拼接成一维长条。
    • 目的:完美规避序列展平后“上一条序列的结尾预测下一条序列的开头”的灾难性逻辑错误,确保 Loss 计算的绝对正确。