图重写模块#
TensorRT-LLM 使用声明式方法来定义神经网络,并包含优化底层图的技术。 它提供了一个类似于 PyTorch 的 Module 的包装器。 当用户调用 forward
方法时,层被降低到 TensorRT 的 ILayer
s 并成为 INetworkDefinition
的一部分。 图重写 (GW) 模块可用于在 ILayer
/INetworkDefinition
级别操作网络。
何时使用图重写?#
对于网络操作,TensorRT-LLM 中有两个选项
模块重写:此方法在触发
forward
方法(即,创建 TensorRT 图)之前修改Module
实例的成员。 它在网络表示的最高级别上工作,并有助于修改操作序列(例如,修改用于 SmoothQuant 的 GEMM + 激活)。图重写:图重写在触发
forward
方法后操作 TensorRT 的INetworkDefinition
。 它在更细粒度的ILayer
级别上运行,并且可以改变跨多个 Module 实例的结构。 它通常用于层融合。
图重写 (GW) 非常适合在以下条件下使用
仅当
ILayer
/INetworkDefinition
可用时。当模块重写会导致嵌套控制流或分散的功能时。
图重写 API#
提供了几个核心 API 用于图重写
FLayerInfo 用于检索功能的高级信息#
对于位于 functional.py
中的所有层,一旦降低到 INetworkDefinition
,原始输入信息就会丢失,尤其是对于 TensorRT 插件,它们在 Python 世界中是不透明的。 FLayerInfo
将它们的原始信息保存为高级签名,其中包含像 Tensor
、Python 属性等输入。 有一个网络范围的单例称为 FLayerInfoMemo
,用于将每个 ILayer
映射到其对应的 FLayerInfo
。
对于 FLayerInfo
FLayerInfo.replace_input_with
:将一些输入张量替换为另一个张量。FLayerInfo.replace_output_uses_with
:将原始输出张量的使用重定向到一组新的张量。
对于 FLayerInfoMemo
FLayerInfoMemo.instance()
:获取单例实例。FLayerInfoMemo.get
:获取ILayer
的相应FLayerInfo
。
FLayerInfo
在 GW 期间与实际 ILayer
保持一致,使其可以安全使用。
模式和模式管理器#
有两种模式
PatternRewriter
:用于定义重写模式,它实际上会改变网络。match
:匹配模式;如果匹配到某个层,则返回 true。rewrite
:操作层。match_and_rewrite
:结合了match
和rewrite
,用于需要从match
传递到rewrite
的复杂状态。
PatternAnalyzer
:用于定义分析模式,该模式从网络收集信息。match
:匹配模式。analyze
:对层列表执行分析。
有两个管理器用于管理多个 PatternRewriter
或 PatternAnalyzer
RewritePatternManager
:add
:添加具有其标签和优点的模式;该优点指定其特权。get
:按标签获取模式。rewrite
:将包含的重写模式应用于网络。
AnalysisPatternManager
:add
:添加具有其标签和优点的模式;该优点指定其特权。get
:按标签获取模式。analyze
:将包含的分析模式应用于网络。
@record_signature 用于装饰需要 FLayerInfo 的功能#
@record_signature
装饰器用于记录功能的 FLayerInfo
。 虽然 FLayerInfo 对于分析或重写某些功能时对 GW 至关重要,但它以“根据需要添加”的方式使用。 如果您要添加 GW 模式,请确保该功能需要 @record_signature
装饰器。
经典工作流程#
有用于定义 GW 模式的特定例程。 让我们从一个简单的示例开始:用减法层替换求和层,这也可以在 test_graph_rewriting.py
文件中找到。
class NaivePatternRewriter_ReplaceAddWithSub(PatternRewriter):
def __init__(self):
super().__init__('replace_add_with_sub',
root_layer={trt.LayerType.ELEMENTWISE},
separate_match_rewrite=True)
def match(self, layer: Layer):
# The rewriter will stop at the first matched layer, and then the Rewriter will enter the rewrite() to do the rewriting.
return layer.as_layer().op == trt.ElementWiseOperation.SUM
def rewrite(self, layer: Layer) -> None:
# The layer here should be an Elementwise_SUM layer.
with net_guard(layer.network):
# There are several stages to replace some subgraph with another subgraph:
# Stage 1: Get the input tensors and output tensors of the subgraph to replace.
# - For Elementwise_SUM, there are two inputs and one output.
a, b = layer.get_inputs(0, 1)
o = layer.get_outputs(0)[0]
# Stage 2: Create a new subgraph that takes the old one's inputs.
# - Here we insert an Elementwise_SUB layer, and 'c' is the output.
c = a - b
# Stage 3: Redirect all the layers depending on the outputs of the old subgraph to the new subgraph's.
# - After this, the SUM becomes dangling and will be pruned by TensorRT when building the engine.
# - Note that there is no API in TensorRT python to remove a layer explicitly; `replace_all_uses_with` is the only way to "remove" a layer.
o.replace_all_uses_with(c)
# Stage 4: Mark all the layers in the old subgraph as removed.
# - This helps the PatternRewriter to skip the removed layers.
layer.mark_as_removed()
在此示例中,我们处理的是 ILayer
而不是插件,因此 FLayerInfo
是不必要的。 如图 rewrite
方法所示,有四个阶段在几乎所有重写模式中共享。
请注意,在 GW 中,我们从不直接重写层。 相反,我们分两个步骤进行:首先,创建另一个具有相同输入的层,并剥夺原始输出的所有用户,将它们重定向到新层的输出。 这样,旧层将悬空,并在引擎构建阶段由 TensorRT 自动修剪。 这是 TensorRT 的一个限制,因为 remove-layer-like API 在 Python 中不可用。
在第 2 阶段中,我们依赖于网络构建阶段常用的运算符和层。 理想情况下,您可以在 GW 期间将它们替换为任何网络结构。
对于 FLayerInfo
的用法,让我们重写 gpt_attention
以启用 remove-padding
功能。 gpt_attention
实际上是
一个 TensorRT 插件,因此我们需要 FLayerInfo
来保存原始的 Tensor 类型的输入,以帮助创建新的 gpt_attention
层。
class GPTAttentionPluginRemovePaddingRewritePass(PatternRewriter):
def __init__(self):
super().__init__('gpt_attention_plugin_remove_padding',
root_layer={trt.LayerType.PLUGIN_V2})
def match_and_rewrite(self, layer: Layer) -> bool:
if layer.as_layer().type != trt.LayerType.PLUGIN_V2 or \
layer.as_layer().plugin.plugin_namespace != 'tensorrt_llm' or \
layer.as_layer().plugin.plugin_type != 'GPTAttention':
return False
# Retrieve the FLayerInfo
flayer = FLayerInfoMemo.instance().get(layer.name)
assert flayer
# Although the layer is a plugin, which is a black box, we get some high-level input information from the FLayerInfo.
tensor_input: Tensor = flayer.get_input('qkv')
if tensor_input.shape[0] == 1: # Already in remove-padding mode
return False
# Some information could be passed in from external
assert self.args is not None, "args should be passed in from RewritePatternManager.rewrite()"
batch_size, in_len, hidden_size = self.args['batch_size'], self.args['in_len'], self.args['hidden_size']
with net_guard(layer.network):
new_inputs = flayer.clone_inputs()
# Step 1: Create new inputs and replace the original arglist.
input = Tensor(
name='qkv',
dtype=trt.float16,
shape=(1, batch_size * in_len, hidden_size),
)
new_inputs['qkv'] = input
# Step 2: Create a new plugin instance.
new_outs = gpt_attention(**new_inputs)
# Step 3: Deprive all the users of the old plugin instance.
flayer.replace_outputs_uses_with(layer.network, new_outs)
# Step 4: Remove the old plugin instance.
layer.mark_as_removed()
return True
这与第一个例子非常相似,重点在于 FLayerInfo
部分。通过下面的代码,我们可以获得该层的原始输入,从而能够修改与移除填充相关的输入,并创建一个新层来替换它。
flayer = FLayerInfoMemo.instance().get(layer.name)
assert flayer
new_inputs = flayer.clone_inputs()
# Step 1: Create new inputs and replace the original arglist.
input = Tensor(
name='tensor',
dtype=trt.float16,
shape=(1, batch_size * in_len, hidden_size),
)
new_inputs['tensor'] = input
# Step 2: Create a new plugin instance.
new_outs = gpt_attention(**new_inputs)
有关真实示例,请参阅 graph_rewriting.py
中的 FuseAttentionWithBiasPass
。