图重写模块#

TensorRT-LLM 使用声明式方法来定义神经网络,并包含优化底层图的技术。 它提供了一个类似于 PyTorch 的 Module 的包装器。 当用户调用 forward 方法时,层被降低到 TensorRT 的 ILayers 并成为 INetworkDefinition 的一部分。 图重写 (GW) 模块可用于在 ILayer/INetworkDefinition 级别操作网络。

何时使用图重写?#

对于网络操作,TensorRT-LLM 中有两个选项

  1. 模块重写:此方法在触发 forward 方法(即,创建 TensorRT 图)之前修改 Module 实例的成员。 它在网络表示的最高级别上工作,并有助于修改操作序列(例如,修改用于 SmoothQuant 的 GEMM + 激活)。

  2. 图重写:图重写在触发 forward 方法后操作 TensorRT 的 INetworkDefinition。 它在更细粒度的 ILayer 级别上运行,并且可以改变跨多个 Module 实例的结构。 它通常用于层融合。

图重写 (GW) 非常适合在以下条件下使用

  1. 仅当 ILayer/INetworkDefinition 可用时。

  2. 当模块重写会导致嵌套控制流或分散的功能时。

图重写 API#

提供了几个核心 API 用于图重写

FLayerInfo 用于检索功能的高级信息#

对于位于 functional.py 中的所有层,一旦降低到 INetworkDefinition,原始输入信息就会丢失,尤其是对于 TensorRT 插件,它们在 Python 世界中是不透明的。 FLayerInfo 将它们的原始信息保存为高级签名,其中包含像 Tensor、Python 属性等输入。 有一个网络范围的单例称为 FLayerInfoMemo,用于将每个 ILayer 映射到其对应的 FLayerInfo

对于 FLayerInfo

  • FLayerInfo.replace_input_with:将一些输入张量替换为另一个张量。

  • FLayerInfo.replace_output_uses_with:将原始输出张量的使用重定向到一组新的张量。

对于 FLayerInfoMemo

  • FLayerInfoMemo.instance():获取单例实例。

  • FLayerInfoMemo.get:获取 ILayer 的相应 FLayerInfo

FLayerInfo 在 GW 期间与实际 ILayer 保持一致,使其可以安全使用。

模式和模式管理器#

有两种模式

  • PatternRewriter:用于定义重写模式,它实际上会改变网络。

    • match:匹配模式;如果匹配到某个层,则返回 true。

    • rewrite:操作层。

    • match_and_rewrite:结合了 matchrewrite,用于需要从 match 传递到 rewrite 的复杂状态。

  • PatternAnalyzer:用于定义分析模式,该模式从网络收集信息。

    • match:匹配模式。

    • analyze:对层列表执行分析。

有两个管理器用于管理多个 PatternRewriterPatternAnalyzer

  • RewritePatternManager:

    • add:添加具有其标签和优点的模式;该优点指定其特权。

    • get:按标签获取模式。

    • rewrite:将包含的重写模式应用于网络。

  • AnalysisPatternManager:

    • add:添加具有其标签和优点的模式;该优点指定其特权。

    • get:按标签获取模式。

    • analyze:将包含的分析模式应用于网络。

@record_signature 用于装饰需要 FLayerInfo 的功能#

@record_signature 装饰器用于记录功能的 FLayerInfo。 虽然 FLayerInfo 对于分析或重写某些功能时对 GW 至关重要,但它以“根据需要添加”的方式使用。 如果您要添加 GW 模式,请确保该功能需要 @record_signature 装饰器。

经典工作流程#

有用于定义 GW 模式的特定例程。 让我们从一个简单的示例开始:用减法层替换求和层,这也可以在 test_graph_rewriting.py 文件中找到。

class NaivePatternRewriter_ReplaceAddWithSub(PatternRewriter):

    def __init__(self):
        super().__init__('replace_add_with_sub',
                         root_layer={trt.LayerType.ELEMENTWISE},
                         separate_match_rewrite=True)

    def match(self, layer: Layer):
        # The rewriter will stop at the first matched layer, and then the Rewriter will enter the rewrite() to do the rewriting.
        return layer.as_layer().op == trt.ElementWiseOperation.SUM

    def rewrite(self, layer: Layer) -> None:
        # The layer here should be an Elementwise_SUM layer.
        with net_guard(layer.network):
            # There are several stages to replace some subgraph with another subgraph:

            # Stage 1: Get the input tensors and output tensors of the subgraph to replace.
            # - For Elementwise_SUM, there are two inputs and one output.
            a, b = layer.get_inputs(0, 1)
            o = layer.get_outputs(0)[0]

            # Stage 2: Create a new subgraph that takes the old one's inputs.
            # - Here we insert an Elementwise_SUB layer, and 'c' is the output.
            c = a - b

            # Stage 3: Redirect all the layers depending on the outputs of the old subgraph to the new subgraph's.
            # - After this, the SUM becomes dangling and will be pruned by TensorRT when building the engine.
            # - Note that there is no API in TensorRT python to remove a layer explicitly; `replace_all_uses_with` is the only way to "remove" a layer.
            o.replace_all_uses_with(c)

            # Stage 4: Mark all the layers in the old subgraph as removed.
            # - This helps the PatternRewriter to skip the removed layers.
            layer.mark_as_removed()

在此示例中,我们处理的是 ILayer 而不是插件,因此 FLayerInfo 是不必要的。 如图 rewrite 方法所示,有四个阶段在几乎所有重写模式中共享。

请注意,在 GW 中,我们从不直接重写层。 相反,我们分两个步骤进行:首先,创建另一个具有相同输入的层,并剥夺原始输出的所有用户,将它们重定向到新层的输出。 这样,旧层将悬空,并在引擎构建阶段由 TensorRT 自动修剪。 这是 TensorRT 的一个限制,因为 remove-layer-like API 在 Python 中不可用。

在第 2 阶段中,我们依赖于网络构建阶段常用的运算符和层。 理想情况下,您可以在 GW 期间将它们替换为任何网络结构。

对于 FLayerInfo 的用法,让我们重写 gpt_attention 以启用 remove-padding 功能。 gpt_attention 实际上是

一个 TensorRT 插件,因此我们需要 FLayerInfo 来保存原始的 Tensor 类型的输入,以帮助创建新的 gpt_attention 层。

class GPTAttentionPluginRemovePaddingRewritePass(PatternRewriter):

    def __init__(self):
        super().__init__('gpt_attention_plugin_remove_padding',
                         root_layer={trt.LayerType.PLUGIN_V2})

    def match_and_rewrite(self, layer: Layer) -> bool:
        if layer.as_layer().type != trt.LayerType.PLUGIN_V2 or \
                layer.as_layer().plugin.plugin_namespace != 'tensorrt_llm' or \
                layer.as_layer().plugin.plugin_type != 'GPTAttention':
            return False

        # Retrieve the FLayerInfo
        flayer = FLayerInfoMemo.instance().get(layer.name)
        assert flayer
        # Although the layer is a plugin, which is a black box, we get some high-level input information from the FLayerInfo.
        tensor_input: Tensor = flayer.get_input('qkv')
        if tensor_input.shape[0] == 1:  # Already in remove-padding mode
            return False

        # Some information could be passed in from external
        assert self.args is not None, "args should be passed in from RewritePatternManager.rewrite()"
        batch_size, in_len, hidden_size = self.args['batch_size'], self.args['in_len'], self.args['hidden_size']

        with net_guard(layer.network):
            new_inputs = flayer.clone_inputs()

            # Step 1: Create new inputs and replace the original arglist.
            input = Tensor(
                name='qkv',
                dtype=trt.float16,
                shape=(1, batch_size * in_len, hidden_size),
            )
            new_inputs['qkv'] = input

            # Step 2: Create a new plugin instance.
            new_outs = gpt_attention(**new_inputs)

            # Step 3: Deprive all the users of the old plugin instance.
            flayer.replace_outputs_uses_with(layer.network, new_outs)

            # Step 4: Remove the old plugin instance.
            layer.mark_as_removed()

        return True

这与第一个例子非常相似,重点在于 FLayerInfo 部分。通过下面的代码,我们可以获得该层的原始输入,从而能够修改与移除填充相关的输入,并创建一个新层来替换它。

flayer = FLayerInfoMemo.instance().get(layer.name)
assert flayer
new_inputs = flayer.clone_inputs()

# Step 1: Create new inputs and replace the original arglist.
input = Tensor(
    name='tensor',
    dtype=trt.float16,
    shape=(1, batch_size * in_len, hidden_size),
)
new_inputs['tensor'] = input

# Step 2: Create a new plugin instance.
new_outs = gpt_attention(**new_inputs)

有关真实示例,请参阅 graph_rewriting.py 中的 FuseAttentionWithBiasPass