简介
在深度学习模型(尤其是 Transformer 架构)的训练中,输入数据的长度通常需要保持一致。如果直接输入大量短文本,就需要用大量无意义的占位符(Padding)来补齐长度,这会极大地浪费 GPU 的计算资源。为了最大化计算效率,目前基本主流的训练框架里都会加入数据打包(Data Packing)的逻辑,本文以lmms-engine中的操作为例,具体查看实际训练时对数据packing操作以及use_rmpad消除所有padding计算的逻辑
Packing
Dataset
如下Dataset代码所示,这段代码核心目标是:在不超过预设最大长度(packing_length)的前提下,尽可能多地将短样本塞进同一个批次(Batch)中。
这种做法带来了两个显著的好处:
- 提升计算效率:减少了 Padding Token 的数量,让 GPU 的每一次矩阵乘法都作用在真实有效的数据上。
- 稳定训练过程:每个 Batch 的有效 Token 数量更加一致,有助于梯度的稳定
if self.config.packing:
# Reset index at the start of each iteration pass
self.cur_idx = 0
buffer = []
buffer_length = 0
packing_length = self.config.packing_length
# Iterate through the dataset once per epoch
while self.cur_idx < len(curr_data_list):
try:
data_dict = self.get_one_sample(self.cur_idx, curr_data_folder[self.cur_idx], curr_data_list)
except Exception as e:
traceback.print_exc()
logger.error(f"Error getting one sample: {e}, skip this sample")
self.cur_idx += 1
continue
input_ids = data_dict["input_ids"]
data_length = input_ids.shape[0]
self.cur_idx += 1
# Drop overlong sample if filtering is enabled
if data_length > packing_length and self.config.filter_overlong:
continue
# If current sample cannot fit into current buffer, yield the buffer first
if buffer_length > 0 and buffer_length + data_length > packing_length:
yield buffer
buffer = []
buffer_length = 0
# If the sample is still longer than packing_length (and not filtered),
# yield it as its own batch to avoid stalling
if data_length > packing_length:
yield [data_dict]
continue
# Append to buffer
buffer.append(data_dict)
buffer_length += data_length
# Flush remaining buffer
if len(buffer) > 0:
yield buffer
总结数据生成器(Generator)具体步骤
代码通过一个 while 循环遍历数据集,具体的执行流程可以清晰地拆解为以下六个步骤:
- 初始化缓冲区
在每次遍历(Epoch)开始前,初始化一个空列表buffer作为“购物车”,用buffer_length记录当前购物车内数据的总长度,并设定容量上限packing_length。 - 安全读取单个样本
使用try-except结构调用get_one_sample获取当前索引的数据。如果某条数据损坏或读取失败,代码会捕获异常、打印错误日志,并优雅地跳过(continue),防止整场训练任务因为一条脏数据而崩溃。 - 超长样本过滤(按需拦截)
获取当前样本的长度data_length。如果该长度已经超过了packing_length,并且系统配置中开启了filter_overlong,则直接丢弃该样本,不让其参与后续的打包。 - 缓冲区溢出判断(核心打包机制)
在将当前样本放入缓冲区之前,先进行“容量预判”:- 如果
buffer中已经有数据了,且 现有长度 + 当前样本长度 > 容量上限,说明购物车装不下了。此时,代码会通过yield buffer将当前收集到的所有样本作为一个整体产出,随后清空购物车,重置长度。
- 如果
- 处理单件超大样本(兜底策略)
如果当前样本本身的长度就大于packing_length(且在第 3 步中没有被过滤掉),为了防止它卡死后续正常的打包流程,代码会直接将其作为一个独立的列表[data_dict]产出(yield),然后直接进入下一次循环。 - 装入缓冲区与最终收尾
如果样本顺利通过了上述所有检查,它就会被追加到buffer中,并累加buffer_length。当整个数据集遍历完毕跳出while循环后,如果buffer中还有未产出的“尾货”(len(buffer) > 0),代码会执行最后一次yield,确保没有任何数据被遗漏。
Collator
在 Collator 组合成batch的形式传入到模型的输入, 这里还是将数据padding, 如下面代码所示,这段代码定义了一个 VisionCollator 类,它的核心目标是作为 DataLoader 中的 collate_fn(数据整理函数)使用:将多个独立的、长度不一的样本字典,汇总并填充(Padding)成一个统一的、可以直接输入给模型进行前向传播的批次(Batch)张量字典。
在多模态大语言模型(VLM)的训练或推理中,由于每个样本的文本长度不同,无法直接堆叠成矩阵。VisionCollator 的作用就是像一个“打包员”,它不仅能动态补齐文本序列的长度(支持左侧或右侧填充),还能自动生成正确的注意力掩码(Attention Mask),并妥善拼接图像特征或其他全局配置参数,确保数据完美符合模型的输入要求
@ dataclass
class VisionCollator:
processor: Processable
def pad_sequence(self, input_ids, batch_first, padding_value):
if self.processor.tokenizer.padding_side == "left":
input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
if self.processor.tokenizer.padding_side == "left":
input_ids = torch.flip(input_ids, [1])
return input_ids
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
if isinstance(instances[0], list):
instances = [inst for instance in instances for inst in instance]
inputs = collections.defaultdict(list)
for instance in instances:
for key, values in instance.items():
inputs[key].append(values)
batched_inputs = {}
if "input_ids" in inputs.keys():
input_ids = inputs.pop("input_ids")
input_ids = self.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.processor.tokenizer.pad_token_id,
)
batched_inputs["input_ids"] = input_ids
if "labels" in inputs.keys():
labels = inputs.pop("labels")
labels = self.pad_sequence(
labels,
batch_first=True,
padding_value=-100,
)
batched_inputs["labels"] = labels
if "attention_mask" in inputs.keys():
inputs.pop("attention_mask")
attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id).long()
batched_inputs["attention_mask"] = attention_mask
# for the other keys
for key, values in inputs.items():
# Handle scalar/boolean values ( use_audio_in_video)
if isinstance(values[0], bool) or (
isinstance(values[0], (int, float)) and not isinstance(values[0], torch.Tensor)
):
batched_inputs[key] = values[0]
else:
batched_inputs[key] = torch.concatenate(values, dim=0)
return batched_inputs
@property
def image_token_id(self):
return self.processor.tokenizer.convert_tokens_to_ids(self.processor.image_token)
代码的执行流程主要集中在 __call__ 方法中,配合辅助的 pad_sequence 函数,可以清晰地拆解为以下六个关键步骤:
步骤 1:展平嵌套列表(无缝衔接打包逻辑)
if isinstance(instances[0], list):
instances = [inst for instance in instances for inst in instance]
当 DataLoader 接收到上一段代码 yield buffer 产出的数据时,instances 会是一个“列表的列表”。这里通过列表推导式将其展平为一维的样本列表,方便后续统一处理。
步骤 2:按特征键值聚合数据
inputs = collections.defaultdict(list)
for instance in instances:
for key, values in instance.items():
inputs[key].append(values)
遍历所有样本,将每个样本字典中相同的键(如 input_ids, labels, pixel_values 等)提取出来,分别放入对应的列表中。这一步完成了数据从“按样本组织”到“按特征组织”的转换。
步骤 3:序列填充机制(pad_sequence)
def pad_sequence(self, input_ids, batch_first, padding_value):
# ...
PyTorch 原生的 pad_sequence 只能在序列右侧填充。但很多大模型(如 LLaMA 系)在特定场景下需要左侧填充。如果配置为左填充,先将序列翻转(torch.flip),调用原生函数在右侧填充后,再翻转回来。
步骤 4:处理核心输入与标签
提取出 input_ids 和 labels 后,调用上一步的填充函数:
input_ids:使用分词器标准的pad_token_id补齐,对齐文本输入长度。labels:使用-100补齐。在 PyTorch 的交叉熵损失函数中,默认会忽略值为-100的标签,从而防止 Padding 部分参与 Loss 计算。
步骤 5:动态生成注意力掩码(Attention Mask)
attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id).long()
为了确保绝对的准确性,代码直接丢弃了原始数据中可能存在的 attention_mask,而是通过判断填充后的 input_ids 中哪些位置不等于(ne) Padding Token,动态生成了一个全新的 attention_mask(真实数据为 1,Padding 为 0)。
步骤 6:处理多模态特征与其他键值
对于剩下的键(例如图像张量或全局布尔开关):
- 如果是全局标量或布尔值(如
use_audio_in_video),直接取第一个值作为整个 Batch 的配置。 - 如果是张量(如图像特征),则在第 0 维度(Batch 维度)上将它们直接拼接(
torch.concatenate)在一起。
rmpad
项目中,是以 monkey patch的形式(也就是打热补丁) 在运行时动态替换 Qwen3-VL 模型的核心前向传播(Forward)函数,从而注入极致的底层加速算子(通常是为了实现 Remove Padding 即 rmpad 优化)。在处理变长序列或经过“打包(Packing)”的数据时,传统的注意力机制(Attention)仍然需要对齐长度,这会带来冗余计算。rmpad(移除填充)是一种高级优化技术(常与 FlashAttention 的 varlen 变长序列处理结合),它能在底层 CUDA 算子层面直接跳过 Padding 部分的计算。这段代码就是在系统开启 use_rmpad 时,用这套高性能的自定义算子替换掉 PyTorch 原生的标准实现,以大幅提升吞吐量并降低显存占用。
rmpad 是如何工作的?
rmpad 的出现就是为了从根本上消灭这种浪费。它的核心思想是:打破传统的二维矩阵束缚,只计算真正有效的数据。
具体步骤如下:
- 展平与压缩(Flattening):在数据进入 Attention 层之前,底层算子会把原本形状为
[batch_size, max_seq_len]的二维矩阵,直接“拍扁”成一个一维的长条序列[total_valid_tokens]。所有的 Padding 都在这一步被物理移除了。 - 记录边界(cu_seqlen):既然数据变成了一维长条,GPU 怎么知道哪几个 Token 属于文本 A,哪几个属于文本 B 呢?系统会额外传递一个极小的索引数组(通常叫
cu_seqlen,即累加序列长度),用来标记每条数据的起止位置。 - 变长注意力计算(Varlen Attention):替换后的底层 CUDA 算子(比如 FlashAttention 的
varlen接口)会读取这个一维长条和边界索引,只在每条数据的有效边界内进行 Attention 计算,绝不跨界,也绝不计算任何 Padding。
所以为了提高GPU 的利用率(MFU),一般是基于这三个操作来复合实现的
- 数据打包(Packing):在最上层,尽量把多个短文本拼成一个长文本,减少整体的 Padding 比例。
- 数据整理(Collator):在中间层,把实在拼不齐的尾部数据用 Padding 补齐,生成 Mask,保证数据能顺利变成 Tensor。
- 底层算子(rmpad):在最底层计算时,直接把 Collator 补上的那一点点 Padding 再次剔除,调用专属的变长算子(Varlen Ops)进行绝对纯净的有效计算。
if use_rmpad:
from .qwen3_vl_ops import attn_forward as qwen3_ops_attn_forward
from .qwen3_vl_ops import (
decoder_layer_forward as qwen3_ops_decoder_layer_forward,
)
from .qwen3_vl_ops import model_forward as qwen3_ops_model_forward
from .qwen3_vl_ops import text_model_forward as qwen3_ops_text_model_forward
modeling_qwen3_vl.Qwen3VLModel.forward = qwen3_ops_model_forward
modeling_qwen3_vl.Qwen3VLTextModel.forward = qwen3_ops_text_model_forward
modeling_qwen3_vl.Qwen3VLTextDecoderLayer.forward = qwen3_ops_decoder_layer_forward
modeling_qwen3_vl.Qwen3VLTextAttention.forward = qwen3_ops_attn_forward
Qwen3VLModel.forward
这里就是 rmpad(移除填充)概念的具体演示,大模型在进入核心 Attention 计算前,是如何在底层将二维矩阵“拍扁”并精准剔除无用 Padding 的。
下面代码显式调用了 _unpad_input。它计算了非 padding 元素的索引 (indices) 和累积序列长度 (cu_seq_lens), 去掉了input中的padding token
if input_ids is not None:
original_input_ids = input_ids
input_ids, indices, cu_seq_lens, _ = _unpad_input(input_ids, attention_mask=attention_mask)
batch_size, seq_length = original_input_ids.shape
elif inputs_embeds is not None:
original_inputs_embeds = inputs_embeds
inputs_embeds, indices, cu_seq_lens, _ = _unpad_input(inputs_embeds, attention_mask=attention_mask)
batch_size, seq_length, _ = original_inputs_embeds.shape
def _unpad_input(input_ids, attention_mask):
valid_mask = attention_mask.squeeze(1).squeeze(1).eq(1)
seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
input_ids = rearrange(input_ids, "b s ... -> (b s) ...")[indices]
unpad_seq_len = input_ids.shape[0]
return input_ids, indices, cu_seqlens, max_seqlen_in_batch
下面来具体分析_unpad_input的步骤:
第一步:提取有效掩码与计算真实长度
valid_mask = attention_mask.squeeze(1).squeeze(1).eq(1)
seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
先将 attention_mask 降维,找到所有值为 1(即真实 Token)的位置,得到 valid_mask。然后沿着序列维度求和,算出这个 Batch 中每一行真实的有效长度(例如 [1000, 10])。
第二步:定位所有有效 Token 的一维坐标
indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
把二维的 valid_mask 彻底展平(Flatten),然后用 torch.nonzero 找出所有真实 Token 在这个一维长条中的绝对索引位置。这个 indices 既用于现在的“提取”,也用于未来的“还原”。
第三步:计算 FlashAttention 必备的 cu_seqlens
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
这是极其关键的一步!底层变长算子需要知道每条序列的起止位置。
假设 seqlens_in_batch 是 [3, 2, 4]:
cumsum(累加)后变成[3, 5, 9]。F.pad在最前面补个 0,最终得到[0, 3, 5, 9]。 这就是大名鼎鼎的cu_seqlens(Cumulative Sequence Lengths)。底层 CUDA 算子拿到它,就知道第 1 条序列在索引 0~3,第 2 条在 3~5,绝不会算串位。
第四步:物理剔除 Padding(张量重塑与索引提取)
input_ids = rearrange(input_ids, "b s ... -> (b s) ...")[indices]
这里使用了 einops.rearrange 将形状为 (batch, seq_len, dim) 的张量直接拍扁成 (batch * seq_len, dim)。随后,利用第二步算出的 indices,只把有效的 Token 抽出来。
至此,所有的 Padding 被彻底抛弃,数据变成了绝对紧凑的形态。
由于输入形状变了,位置编码的施加方式也必须改变。
计算出的 position_ids 需要根据 indices 进行重排和筛选,以匹配扁平化后的 input_ids。
# 将 (batch, seq) 的 pos_ids 展平,并只取非 padding 部分
position_ids = (
index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices).transpose(0, 1).unsqueeze(1)
)
最终传递给 Language Model 的参数最后调用底层 LLM 时,参数列表也不同
outputs = self.language_model(
...,
indices=indices, # 关键:用于恢复原始形状或索引
cu_seq_lens=cu_seq_lens, # 关键:FlashAttention Varlen 需要知道每个句子的边界
...
)
Qwen3VLTextAttention.forward
这里其实就是显示的使用 flash_attn_varlen_func来做attention计算, 从而达到packing的最终目的。
attn_output = flash_attn_varlen_func(
q=query_states,
k=key_states,
v=value_states,
cu_seqlens_q=cu_seq_lens,
cu_seqlens_k=cu_seq_lens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
causal=True,
window_size=window_size,
softmax_scale=self.head_dim**-0.5,
dropout_p=0.0,
)
简单来说,flash_attn_varlen_func 的核心任务是:在一堆已经混在一起、分不清谁是谁的 Token 长条中,利用“索引导航”精准地还原出原本的句子结构,并进行 Attention 计算。
下面通过一个具体的例子来详细拆解这个过程:
- 数据的变形:从“方阵”到“长条”
在传统的 Attention(如forward函数)中,数据是规整的方阵,短句子后面补了 Padding(0):
而在 Rmpad (Remove Padding) 模式下,我们在进入flash_attn_varlen_func之前(即attn_forward的前半部分),已经把所有 Padding 扔掉了,并把剩下的 Token 首尾相连拼成了一个超长的 1D 序列:- 传统 Batch (Batch=3):
- 句子 A (长度 2):
[A1, A2, 0, 0] - 句子 B (长度 4):
[B1, B2, B3, B4] - 句子 C (长度 3):
[C1, C2, C3, 0] - 形状:
(3, 4, Hidden_Dim)。其中0是浪费的计算。
- 句子 A (长度 2):
- Varlen 输入 (q, k, v):
[A1, A2, B1, B2, B3, B4, C1, C2, C3]- 形状:
(Total_Tokens, Num_Heads, Head_Dim)。这里 Total = 2+4+3 = 9。 - 注意: 这里完全没有 Batch 维度了。
- 传统 Batch (Batch=3):
cu_seqlens(Cumulative Sequence Lengths)
既然数据变成了一长条,FlashAttention 怎么知道 A2 后面是 B1(属于不同句子,不能做 Attention),而不是 A3 呢?
这就靠代码中的 cu_seqlens(累积序列长度)。这是一个 “切分点”索引数组。
对于上面的例子,cu_seqlens 的值会是: [0, 2, 6, 9]
在代码中的体现:
flash_attn_varlen_func(
q=query_states, # 形状: (Total_Tokens, nheads, dim)
cu_seqlens_q=cu_seq_lens, # 形状: (Batch_Size + 1,) -> [0, 2, 6, 9]
max_seqlen_q=max_seqlen, # 告诉内核最长的句子有多长 (这里是 4)
...
)
CUDA 内核拿到这个 cu_seqlens 后,就能在显存中快速定位每个句子的起始和结束位置,从而只在 A1 和 A2 之间计算 Attention,绝不会跨越边界去和 B1 计算。
- 0 -> 2: 第 1 个句子(A)的范围(索引 0 到 2,不含 2)。
- 2 -> 6: 第 2 个句子(B)的范围(索引 2 到 6,即 2,3,4,5)。
- 6 -> 9: 第 3 个句子(C)的范围。
Loss计算
在已经剔除 Padding 且被展平成一维长条的数据上,正确地执行自回归语言模型的“错位(Shift)”操作,确保模型只在单条序列内部预测下一个 Token,而绝对不会跨序列预测。
在标准的二维矩阵 [batch, seq_len] 中,我们通常只需要简单地切片 logits[:, :-1] 和 labels[:, 1:] 就能让输入和标签错开一位。但是,当启用了 use_rmpad,所有序列首尾相连变成了一条线时,如果直接整体错位,序列 A 的最后一个 Token 就会去预测序列 B 的第一个 Token,这会导致模型学到完全错误的因果关系。这段代码就是为了精准规避这个灾难。
if use_rmpad:
# We need to shift the tokens according to seq lens
# Otherwise, the first labels of the next seq will be the last labels of the current seq
shift_hidden_states = []
shift_labels = []
for i in range(len(seq_lens) - 1):
cur_hidden_states = hidden_states[seq_lens[i] : seq_lens[i + 1], :]
cur_shift_hidden_states = cur_hidden_states[:-1, :].contiguous()
cur_labels = labels[seq_lens[i] : seq_lens[i + 1]]
cur_shift_labels = cur_labels[1:].contiguous()
shift_hidden_states.append(cur_shift_hidden_states)
shift_labels.append(cur_shift_labels)
shift_hidden_states = torch.cat(shift_hidden_states, dim=0)
shift_labels = torch.cat(shift_labels, dim=0)
代码通过遍历累加的长度索引(seq_lens,在这里它的作用等同于上一问的 cu_seqlens),逐个序列进行安全错位。具体步骤如下:
- 准备容器与遍历边界
shift_hidden_states = []
shift_labels = []
for i in range(len(seq_lens) - 1):
首先初始化两个空列表,用来存放处理好的各个序列片段。接着,利用 seq_lens 数组进行循环。因为 seq_lens 记录的是累加边界(例如 [0, 3, 5, 9]),所以循环 len(seq_lens) - 1 次刚好能遍历完所有的独立序列。
- 精准切分单条序列
cur_hidden_states = hidden_states[seq_lens[i] : seq_lens[i + 1], :]
cur_labels = labels[seq_lens[i] : seq_lens[i + 1]]
在循环内部,代码利用边界索引 seq_lens[i] 到 seq_lens[i + 1],从首尾相连的一维长条 hidden_states(通常是模型最后一层的输出)和 labels 中,安全地把当前这条独立的序列“抠”出来。
- 序列内部的安全错位(Shift)
cur_shift_hidden_states = cur_hidden_states[:-1, :].contiguous()
cur_shift_labels = cur_labels[1:].contiguous()
这是最核心的自回归逻辑:
hidden_states[:-1]:丢弃当前序列的最后一个状态(因为它没有下一个 Token 可以预测了)。labels[1:]:丢弃当前序列的第一个标签(因为它是作为起始输入,不需要被预测)。 这样一错位,输入状态就和它需要预测的下一个标签完美对齐了。
- 重新拼接回一维长条
shift_hidden_states.append(cur_shift_hidden_states)
shift_labels.append(cur_shift_labels)
# 循环结束后
shift_hidden_states = torch.cat(shift_hidden_states, dim=0)
shift_labels = torch.cat(shift_labels, dim=0)
将切片并错位好的小片段塞入列表,最后使用 torch.cat 沿着第 0 维度(序列维度)重新拼接起来。此时,得到的数据依然是紧凑的、没有 Padding 的一维张量,但内部的因果关系已经完全正确,可以直接送入交叉熵损失函数(CrossEntropyLoss)进行计算。
总结
这套优化方案可以清晰地划分为四个阶段:
- 阶段一:数据整理与对齐(VisionCollator)
- 动作:将零散的样本收集起来,展平嵌套结构。
- 核心:使用智能的
pad_sequence(支持左侧/右侧填充)将不同长度的文本和图像特征对齐成标准的二维张量[batch_size, max_seq_len],并动态生成精准的attention_mask。 - 目的:满足 PyTorch
DataLoader对标准张量格式的基本要求,为后续处理提供统一的入口。
- 阶段二:算子劫持与替换(Monkey Patching)
- 动作:如果系统开启了
use_rmpad性能开关,在模型初始化阶段,动态导入底层 C++/CUDA 编写的高性能变长算子(如 FlashAttention 的 varlen 版本)。 - 核心:强行替换掉模型(如
Qwen3VLModel)原生的、低效的 Attention 和 Decoder Layer 的forward方法。 - 目的:实现非侵入式的底层加速,让模型在不知不觉中执行最高效的计算逻辑。
- 动作:如果系统开启了
- 阶段三:入场解包与物理去填(
_unpad_input)
- 动作:在模型前向传播的最开始,拦截带有 Padding 的标准批次数据。
- 核心:利用
attention_mask找到所有真实的 Token,将二维矩阵彻底“拍扁”成一维长条[total_valid_tokens, dim],物理丢弃所有 Padding。同时,计算出底层算子必需的累加边界索引(cu_seqlens)。 - 目的:为替换后的高性能算子准备绝对紧凑、零浪费的输入数据。
- 阶段四:出场错位与防串味(Shift Logic)
- 动作:在模型计算完毕、准备计算 Loss 之前,对一维长条数据进行自回归错位(Shift)。
- 核心:利用之前保存的边界索引(
seq_lens),将一维长条重新切分成独立的序列片段,在每个片段内部进行[:-1]和[1:]的错位操作,然后再重新拼接成一维长条。 - 目的:完美规避序列展平后“上一条序列的结尾预测下一条序列的开头”的灾难性逻辑错误,确保 Loss 计算的绝对正确。