一、引言:从算子级优化到计算图级优化
视频生成模型的推理优化是一个多层次、系统性的工程挑战。在模型推理的早期阶段,优化重点通常集中在算子层面,例如通过优化卷积、注意力等核心算子的计算效率来直接提升浮点运算性能。然而,随着单算子性能逐渐逼近硬件极限,计算图层面的优化便成为释放更大潜力的关键。计算图优化关注的是算子之间的调度、内存复用以及控制流开销,其核心在于提升整体执行图效率。一个高效的执行图能够最大限度地减少框架与硬件的交互开销,避免不必要的内存搬运,并使得更激进的算子融合与内存规划成为可能。
本文将聚焦于推理执行流程本身,探讨如何借助 torch.compile 对 Self-Forcing 的推理流程进行整图编译(full graph compilation),以系统性地降低 Python 解释与调度开销,并为后续更深层次的图级优化奠定基础。
二、Self-Forcing 推理特性
与整图编译的挑战
Self-Forcing 是一种将 Wan2.1 等全帧并行扩散模型改造为因果注意力架构的训练与推理范式。与传统双向扩散模型在推理阶段需要同时处理全部视频帧不同,Self-Forcing 采用逐块(block-wise)的自回归生成策略:每次生成一小段 latent(通常为 3 帧),并通过 KV Cache 复用历史上下文。
在这种设计下,注意力计算的复杂度由传统扩散模型的下降为,其中 B 表示单次生成的 block 大小。这一特性使得 Self-Forcing 在低延迟、流式视频生成等应用场景中具备显著优势。
从编译优化的角度看,Self-Forcing 的实现同时具备“适合编译”和“难以编译”的双重特性。一方面,其推理过程高度结构化,计算模式在每个 step 内基本固定;另一方面,原始实现中广泛存在以下问题:
1)依赖张量值的 Python 控制流;
2)通过 .item()、tolist() 等方式将张量退回 Host 端;
3)KV Cache 的动态索引与切片操作;
4)Python 层缓存与调试逻辑的混入。
上述因素都可能会在 torch.compile 过程中触发 Graph Break,使编译器只能生成多个碎片化的子图,从而难以获得实质性的性能收益。
三、整图编译策略与实现选择
在具体实现上,我们采用了一种渐进式的优化策略:首先对关键模块使用 torch.compile 进行局部封装,以评估其潜在收益;随后,在此基础上系统性地识别并消除 Graph Break,逐步推进至整图编译(Full Graph Compilation)。
我们最终选择在注意力模块(CausalWanAttentionBlock)的 forward 方法上使用如下配置:
其中,dynamic=True 允许编译器以符号形状(symbolic shapes)的形式表示部分运行期才能确定的维度,从而在不触发额外 graph break 或重新编译的前提下,支持一定范围内的输入形状变化。这一特性在进行前期优化的时候比较重要:尽管单次推理的视频分辨率和总帧数通常是固定的,但序列在多卡之间的切分方式、以及 KV Cache 在不同 step 中的有效长度与写入区间,都可能会引入运行期的形状差异,通过符号化这些维度,可以避免因形状变化而频繁触发重新编译,从而提升整体推理稳定性与性能。
fullgraph=True 则是实现整图优化的关键。该选项要求整个函数必须被编译为单一 FX 计算图,一旦遇到无法追踪的操作便直接报错并中止编译。虽然这一“严格模式”显著提高了工程改造的难度,但它能够在编译阶段完整暴露所有潜在的 Graph Break 点,避免隐式的子图切换开销,并为后续的算子融合和 CUDA Graph 捕获提供必要前提。
在实践中,fullgraph 模式的价值不仅体现在最终性能上,更体现在其对代码结构与数据流设计的“约束”作用。
四、Graph Break 的成因分析与消除方法
在本章节中,我们的代码示例基于 Self Forcing 的 官方实现(https://github.com/guandeh17/Self-Forcing),部分提及的序列并行代码基于内部工程实现。
控制流与标量提取
在 torch.compile 中导致图断开的最常见原因之一是 Python 端语义依赖于运行时张量值。在 Self-Forcing 的原始实现中,大量逻辑通过 .item() 将张量转换为 Python 标量,用于计算序列长度、帧索引以及 KV Cache 的读写位置,例如:
frame_seqlen = math.prod(grid_sizes[0][1:]).item()local_end_index = kv_cache["local_end_index"].item() + current_end
在 fullgraph=True 模式下,这类代码会直接触发如下错误:
Unsupported Tensor.item() call with capture_scalar_outputs=False尽管 PyTorch 提供了 capture_scalar_outputs=True 来支持此类用法,但该方案会引入额外的 Host-Device 同步和标量封装开销,并削弱编译器对数据流的静态分析能力。因此基于性能考虑,我们选择彻底消除 .item() 调用,而不是依赖该选项。
核心思路是:只要某个量可以在 GPU 上以张量形式计算,就应避免将其退回 CPU。例如,上述代码可以改写为:
frame_seqlen = torch.prod(grid_sizes[0][1:])current_start_frame = current_start // frame_seqlen
更加概括的说,我们在推理热路径上系统性地移除了所有张量到 Python 的转换操作,包括但不限于:
.item()
int(tensor)
.tolist()
其他类似的标量化操作
原因在于,从 torch.compile 的视角来看,Tensor 并不等价于具体数值,而是一种可被分析和优化的计算关系表示。一旦中间结果被转换为 Python 对象,相关数据依赖将很有可能脱离编译器的控制范围,导致计算图无法被完整捕获和优化。相反,保持计算逻辑完全以张量形式表达,有助于最大化编译器的优化空间,并确保推理过程中 CUDA 执行的高效性与并行性。
数据依赖与动态形状
另一类更为隐蔽的 Graph Break 源于数据依赖导致的动态形状推导失败。在 RoPE 的计算逻辑中,原始实现通过 tolist() 将 grid_sizes 张量转换为 Python 列表,并在循环中动态计算序列长度:
for f, h, w in grid_sizes.tolist():seq_len = f * h * wx_i = x[i, :seq_len]
这一实现同时引入了两个问题:一方面,tolist() 本身会触发 Graph Break,使相关计算逻辑脱离计算图;另一方面,由图外标量 f、h、w 参与计算得到的 seq_len 被用于张量切片,导致编译器无法为输出张量的形状建立有效的符号约束。最终,这种数据依赖关系会在编译阶段表现为守卫失败(data-dependent guard failure),报错如下:
Could not guard on data-dependent expression u0*u1*u2 < 0 (unhinted: u0*u1*u2 < 0). (Size-like symbols: none)但事实上,从推理优化的角度来看,这种实现方式本身是不必要的。对于 Self-Forcing 推理流程而言,grid_sizes 实际由视频帧数、latent 分辨率以及 patch 大小共同决定,而这些参数在推理服务启动时就已经确定,或仅存在有限几组离散取值。因此,在推理优化场景中,我们将 grid_sizes 视为配置期常量,并针对固定配置对计算图进行特化(specialization)。
尽管这种做法在形式上降低了实现的通用性,但在推理系统中,为固定分辨率和序列长度构建多组计算图是一种常见且行之有效的工程实践。其设计思想也与 CUDA Graph 实践中的分桶(bucketing)与填充(padding)机制高度一致。
KV Cache 的动态索引问题
KV Cache 的读写是推理代码中导致 Graph Break 的另一个高频来源。原始实现中,KV Cache 的访问依赖于运行时计算得到的索引:
# local_start_index 和 local_end_index 是 Tensor 类型# 读取 KV Cachex = kv_cache["k"][:, local_start_index:local_end_index]# 写入 KV Cachekv_cache["k"][:, local_start_index:local_end_index] = roped_key
由于当前主流版本的 torch.compile 实现尚不支持将张量作为切片边界,上述代码在 fullgraph 模式下会导致编译失败。
通过对因果注意力具体计算过程的分析可以发现,在未启用滑动窗口注意力的前提下,KV Cache 的访问范围实际上具有明确的闭式解:历史 KV 的读取起点始终为 0,而写入位置则可以由全局序列起始位置 current_start 与当前 block 的 token 数直接确定。在序列并行(SP)场景中,由于 attention 计算开始前已经完成了当前 block 所有序列的 all-gather,这一结论同样成立。
基于上述观察,我们将 KV Cache 中依赖运行期索引的动态切片逻辑重构为等价的静态索引访问,并在工程实现中通过自定义的 tilelang kernel 实现高效写入。该改造不仅彻底消除了由 KV Cache 读写引发的 Graph Break,也在实际推理中带来了较为明显的性能收益。
Host 调用与 Python 层缓存
除了张量相关的问题,Python 层的缓存与调试逻辑同样可能会破坏整图编译。例如,社区中常见的 RoPE 优化方案通常使用 LRU 缓存 cos/sin 值,但该缓存机制依赖 Python 字典的状态更新,推理过程中仍会引入 Host 端参与。
在实际工程中,我们将这类依赖运行期状态的动态缓存改写为预计算逻辑:在模型初始化阶段提前生成所需的 cos / sin 张量,并以连续 Tensor 的形式常驻于 GPU 显存中;在推理阶段仅通过张量索引完成访问,从而彻底消除 Python 侧的参与。
类似地,诸如 time.time() 等调试代码也会直接触发 Graph Break,应该在优化阶段彻底移除。
五、实验结果与总结
在完成 Graph Break 的系统性消除并启用整图编译后,我们在生成 5 秒、480P 视频 的推理任务上进行了性能评测,模型规模为 14B 参数。消融实验结果表明,仅通过 torch.compile 的整图优化,便可在端到端层面获得约 47.6% 的加速效果,将推理耗时从 8.86 秒 降低至 6.00 秒,且未观察到明显的精度退化。
综上所述,本文展示了在自回归视频生成推理场景下,基于 torch.compile 实现整图编译的一套工程实践经验。我们的经验表明,整图编译的核心价值并不仅在于“自动加速”,更在于其对数据依赖、控制流以及工程实现方式所施加的强约束。这种约束能够显式暴露系统中的隐性复杂度,为进一步的底层算子融合与系统级优化奠定基础。
-End-
作者丨storyicon、在喝可乐的派派
往期精彩指路
丨丨
丨丨
推荐站内搜索:最好用的开发软件、免费开源系统、渗透测试工具云盘下载、最新渗透测试资料、最新黑客工具下载……




还没有评论,来说两句吧...