问题排查#
本文档描述了 TensorRT-LLM 中一些常见问题及其解决方案,包括安装、模型构建、模型执行以及输入/输出大小的问题。
安装错误#
在 TensorRT-LLM 的编译和安装过程中,许多构建错误可以通过简单地删除构建树并再次重建来解决。
在大多数情况下,这些问题是由如下工作流程引起的:旧的编译 -> 一些代码更改(仓库更新或用户编写) -> 后来的编译。
解决方案:尝试使用 --clean
运行构建脚本,或者在再次运行构建脚本之前尝试运行 rm -r build cpp/build
。
调试单元测试#
这是一个在单元测试中打印 MLP 输出张量值的示例(完整示例)。
使用
register_network_output
API 将中间张量注册为网络输出。
class MLP(Module):
def __init__(self, ...):
super().__init__()
# Do not modify the definition in `__init__` method
self.fc = ...
self.proj = ...
def forward(self, hidden_states):
inter = self.fc(hidden_states)
inter = tensorrt_llm.functional.relu(inter)
# Here register the tensor `inter` as our debug output tensor
self.register_network_output('inter', inter)
output = self.proj(inter)
return output
将中间张量标记为网络输出。
for k, v in gm.named_network_outputs():
net._mark_output(v, k, dtype)
在运行时打印张量。
print(outputs.keys())
print(outputs['inter'])
调试 E2E 模型#
这是一个在 GPT 模型中打印 MLP 输出张量值的示例。
在
tensorrt_llm/models/gpt/model.py
中注册 MLP 输出张量。
hidden_states = residual + attention_output.data
residual = hidden_states
hidden_states = self.post_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
# Register as model output
# ------------------------------------------------------
self.register_network_output('mlp_output', hidden_states)
# ------------------------------------------------------
hidden_states = residual + hidden_states
构建模型的 TensorRT 引擎。
使用 trtllm-build
构建引擎时启用 --enable_debug_output
选项
cd examples/gpt
# Download hf gpt2 model
rm -rf gpt2 && git clone https://hugging-face.cn/gpt2-medium gpt2
pushd gpt2 && rm pytorch_model.bin model.safetensors && wget -q https://hugging-face.cn/gpt2-medium/resolve/main/pytorch_model.bin && popd
# Convert to TensorRT-LLM checkpoint
python3 convert_checkpoint.py \
--model_dir gpt2 \
--dtype float16 \
--output_dir gpt2/trt_ckpt/fp16/1-gpu
# Build TensorRT-LLM engines with --enable_debug_output
trtllm-build \
--checkpoint_dir gpt2/trt_ckpt/fp16/1-gpu \
--enable_debug_output \
--output_dir gpt2/trt_engines/fp16/1-gpu
打印中间输出张量。
在 tensorrt_llm/runtime/generation.py
中添加调试信息。
stream = torch.cuda.current_stream().cuda_stream
instance_idx = step % 2
if self.cuda_graph_mode and self.runtime.cuda_graph_instances[
instance_idx] is not None:
# launch cuda graph
CUASSERT(
cudart.cudaGraphLaunch(
self.runtime.cuda_graph_instances[instance_idx], stream))
ok = True
else:
ok = self.runtime._run(context, stream)
if not ok:
raise RuntimeError(f"Executing TRT engine failed step={step}!")
if self.debug_mode:
torch.cuda.synchronize()
# -------------------------------------------
if step == 0:
print(self.debug_buffer.keys())
print(f"Step: {step}")
print(self.debug_buffer['transformer.layers.6.mlp_output'])
# -------------------------------------------
使用
--debug_mode
和--use_py_session
运行../run.py
。
python3 ../run.py \
--engine_dir gpt2/trt_engines/fp16/1-gpu \
--tokenizer_dir gpt2 \
--max_output_len 8 \
--debug_mode \
--use_py_session
查看张量的值。
......
dict_keys(['context_lengths', 'cache_indirection', 'position_ids', 'logits', 'last_token_ids', 'input_ids', 'kv_cache_block_pointers', 'host_kv_cache_block_pointers', 'sequence_length', 'host_past_key_value_lengths', 'host_sink_token_length', 'host_request_types', 'host_max_attention_window_sizes', 'host_context_lengths', 'transformer.layers.0.mlp_output', 'transformer.layers.1.mlp_output', 'transformer.layers.2.mlp_output', 'transformer.layers.3.mlp_output', 'transformer.layers.4.mlp_output', 'transformer.layers.5.mlp_output', 'transformer.layers.6.mlp_output', 'transformer.layers.7.mlp_output', 'transformer.layers.8.mlp_output', 'transformer.layers.9.mlp_output', 'transformer.layers.10.mlp_output', 'transformer.layers.11.mlp_output', 'transformer.layers.12.mlp_output', 'transformer.layers.13.mlp_output', 'transformer.layers.14.mlp_output', 'transformer.layers.15.mlp_output', 'transformer.layers.16.mlp_output', 'transformer.layers.17.mlp_output', 'transformer.layers.18.mlp_output', 'transformer.layers.19.mlp_output', 'transformer.layers.20.mlp_output', 'transformer.layers.21.mlp_output', 'transformer.layers.22.mlp_output', 'transformer.layers.23.mlp_output'])
Step: 0
tensor([[ 0.0294, -0.0260, -0.0776, ..., -0.0560, -0.0235, 0.0273],
[-0.0071, 0.5879, 0.1993, ..., -1.0449, -0.6299, 0.5957],
[-0.8779, 0.1050, 0.7090, ..., 0.0910, 1.0713, -0.2939],
...,
[ 0.1212, -0.0903, -0.5918, ..., -0.1045, -0.3445, 0.1082],
[-1.0723, -0.0732, 0.6157, ..., 0.3452, 0.2998, 0.2649],
[-0.7134, 0.9692, -0.1141, ..., -0.0096, 0.9521, 0.1437]],
device='cuda:0', dtype=torch.float16)
Step: 1
tensor([[-0.2107, 0.5874, 0.8179, ..., 0.7900, -0.6890, 0.6064]],
device='cuda:0', dtype=torch.float16)
Step: 2
tensor([[ 0.4192, -0.0047, 1.3887, ..., -0.9028, -0.0682, -0.2820]],
device='cuda:0', dtype=torch.float16)
Step: 3
tensor([[-0.7949, -0.5073, -0.1721, ..., -0.5830, -0.1378, -0.0070]],
device='cuda:0', dtype=torch.float16)
Step: 4
tensor([[-0.0804, 0.1272, -0.6255, ..., -0.1072, -0.0523, 0.7144]],
device='cuda:0', dtype=torch.float16)
Step: 5
tensor([[-0.3328, -0.8828, 0.3442, ..., 0.8149, -0.0630, 1.2305]],
device='cuda:0', dtype=torch.float16)
Step: 6
tensor([[-0.2225, -0.2079, -0.1459, ..., -0.3555, -0.1672, 0.1135]],
device='cuda:0', dtype=torch.float16)
Step: 7
tensor([[ 0.1290, -0.1556, 0.3977, ..., -0.8218, -0.3291, -0.8672]],
device='cuda:0', dtype=torch.float16)
Input [Text 0]: "Born in north-east France, Soyer trained as a"
Output [Text 0 Beam 0]: " chef before moving to London in the early"
调试执行错误#
如果问题来自插件,请尝试设置环境变量 CUDA_LAUNCH_BLOCKING=1
,使内核与其返回状态同步启动并立即检查。
如果问题来自输入张量的运行时形状,请仔细检查引擎的输入张量的形状(每个秩的秩和长度)和位置(CPU / GPU)是否符合构建时设置。
例如,获得如下错误信息的可能原因是,我们在引擎构建和运行之间使用了不匹配的配置,包括代码更改(仓库更新或用户重写)、输入形状太大或太小等。
unexpected shape for input 'XXX' for model 'YYY'. Expected [-1,-1,-1], got [8,16]. NOTE: Setting a non-zero max_batch_size in the model config requires a batch dimension to be prepended to each input shape. If you want to specify the full shape including the batch dim in your input dims config, try setting max_batch_size to zero. See the model configuration docs for more info on max_batch_size.
[TensorRT-LLM][ERROR] Assertion failed: Tensor 'input_ids' has invalid shape (8192), expected (-1) (/code/tensorrt_llm/cpp/tensorrt_llm/runtime/tllmRuntime.cpp:149)
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 8192 but got size 1024 for tensor number 1 in the list.
通过设置环境变量 export TLLM_LOG_LEVEL=TRACE
,我们可以获得有关 TensorRT 引擎和运行时上下文的更多信息。
在第一次前向计算之前,所有输入/输出张量的形状及其对应的允许范围都在如下表格中提供
[TensorRT-LLM][TRACE] Information of engine input / output.
[TensorRT-LLM][TRACE] =====================================================================
[TensorRT-LLM][TRACE] Name |I/O|Location|DataType| Shape |
[TensorRT-LLM][TRACE] ---------------------------------------------------------------------
[TensorRT-LLM][TRACE] input_ids | I | GPU | INT32 | (-1) |
[TensorRT-LLM][TRACE] position_ids | I | GPU | INT32 | (-1) |
[TensorRT-LLM][TRACE] last_token_ids | I | GPU | INT32 | (-1) |
[TensorRT-LLM][TRACE] kv_cache_block_offsets | I | GPU | INT32 |(1, -1, 2, -1)|
[TensorRT-LLM][TRACE] host_kv_cache_block_offsets | I | GPU | INT32 |(1, -1, 2, -1)|
[TensorRT-LLM][TRACE] host_kv_cache_pool_pointers | I | GPU | INT64 | (1, 2) |
[TensorRT-LLM][TRACE] host_kv_cache_pool_mapping | I | GPU | INT32 | (28) |
[TensorRT-LLM][TRACE] sequence_length | I | GPU | INT32 | (-1) |
[TensorRT-LLM][TRACE] host_request_types | I | GPU | INT32 | (-1) |
[TensorRT-LLM][TRACE] host_past_key_value_lengths | I | GPU | INT32 | (-1) |
[TensorRT-LLM][TRACE] context_lengths | I | GPU | INT32 | (-1) |
[TensorRT-LLM][TRACE] host_runtime_perf_knobs | I | GPU | INT64 | (16) |
[TensorRT-LLM][TRACE] host_context_lengths | I | GPU | INT32 | (-1) |
[TensorRT-LLM][TRACE] host_max_attention_window_sizes| I | GPU | INT32 | (28) |
[TensorRT-LLM][TRACE] host_sink_token_length | I | GPU | INT32 | (1) |
[TensorRT-LLM][TRACE] cache_indirection | I | GPU | INT32 | (-1, 1, -1) |
[TensorRT-LLM][TRACE] logits | O | GPU | FP32 | (-1, 65024) |
[TensorRT-LLM][TRACE] =====================================================================
[TensorRT-LLM][TRACE] Information of optimization profile.
[TensorRT-LLM][TRACE] Optimization Profile 0:
[TensorRT-LLM][TRACE] =============================================================================
[TensorRT-LLM][TRACE] Name | Min | Opt | Max |
[TensorRT-LLM][TRACE] -----------------------------------------------------------------------------
[TensorRT-LLM][TRACE] input_ids | (1) | (8) | (8192) |
[TensorRT-LLM][TRACE] position_ids | (1) | (8) | (8192) |
[TensorRT-LLM][TRACE] last_token_ids | (1) | (4) | (8) |
[TensorRT-LLM][TRACE] kv_cache_block_offsets | (1, 1, 2, 1) |(1, 4, 2, 16) |(1, 8, 2, 32) |
[TensorRT-LLM][TRACE] host_kv_cache_block_offsets | (1, 1, 2, 1) |(1, 4, 2, 16) |(1, 8, 2, 32) |
[TensorRT-LLM][TRACE] host_kv_cache_pool_pointers | (1, 2) | (1, 2) | (1, 2) |
[TensorRT-LLM][TRACE] host_kv_cache_pool_mapping | (28) | (28) | (28) |
[TensorRT-LLM][TRACE] sequence_length | (1) | (4) | (8) |
[TensorRT-LLM][TRACE] host_request_types | (1) | (4) | (8) |
[TensorRT-LLM][TRACE] host_past_key_value_lengths | (1) | (4) | (8) |
[TensorRT-LLM][TRACE] context_lengths | (1) | (4) | (8) |
[TensorRT-LLM][TRACE] host_runtime_perf_knobs | (16) | (16) | (16) |
[TensorRT-LLM][TRACE] host_context_lengths | (1) | (4) | (8) |
[TensorRT-LLM][TRACE] host_max_attention_window_sizes| (28) | (28) | (28) |
[TensorRT-LLM][TRACE] host_sink_token_length | (1) | (1) | (1) |
[TensorRT-LLM][TRACE] cache_indirection | (1, 1, 1) | (4, 1, 1024) | (8, 1, 2048) |
[TensorRT-LLM][TRACE] logits | (1, 65024) | (4, 65024) | (8, 65024) |
[TensorRT-LLM][TRACE] =============================================================================
在每次前向计算之前,TRT 引擎的所有输入/输出张量的真实形状都由如下表格提供
[TensorRT-LLM][TRACE] Information of context input / output.
[TensorRT-LLM][TRACE] Using Optimization Profile: 0
[TensorRT-LLM][TRACE] =================================================
[TensorRT-LLM][TRACE] Name |I/O| Shape |
[TensorRT-LLM][TRACE] -------------------------------------------------
[TensorRT-LLM][TRACE] input_ids | I | (33) |
[TensorRT-LLM][TRACE] position_ids | I | (33) |
[TensorRT-LLM][TRACE] last_token_ids | I | (3) |
[TensorRT-LLM][TRACE] kv_cache_block_offsets | I |(1, 3, 2, 4)|
[TensorRT-LLM][TRACE] host_kv_cache_block_offsets | I |(1, 3, 2, 4)|
[TensorRT-LLM][TRACE] host_kv_cache_pool_pointers | I | (1, 2) |
[TensorRT-LLM][TRACE] host_kv_cache_pool_mapping | I | (28) |
[TensorRT-LLM][TRACE] sequence_length | I | (3) |
[TensorRT-LLM][TRACE] host_request_types | I | (3) |
[TensorRT-LLM][TRACE] host_past_key_value_lengths | I | (3) |
[TensorRT-LLM][TRACE] context_lengths | I | (3) |
[TensorRT-LLM][TRACE] host_runtime_perf_knobs | I | (16) |
[TensorRT-LLM][TRACE] host_context_progress | I | (1) |
[TensorRT-LLM][TRACE] host_context_lengths | I | (3) |
[TensorRT-LLM][TRACE] host_max_attention_window_sizes| I | (28) |
[TensorRT-LLM][TRACE] host_sink_token_length | I | (1) |
[TensorRT-LLM][TRACE] cache_indirection | I |(3, 2, 256) |
[TensorRT-LLM][TRACE] logits | O | (3, 65024) |
[TensorRT-LLM][TRACE] =================================================
提示#
建议向 docker 或 nvidia-docker 运行命令添加选项
–shm-size=1g –ulimit memlock=-1
。 否则,在运行多个 GPU 推理时,您可能会看到 NCCL 错误。 有关详细信息,请参阅 https://docs.nvda.net.cn/deeplearning/nccl/user-guide/docs/troubleshooting.html#errors。在构建模型时,可能会发生与内存相关的问题,例如
[09/23/2023-03:13:00] [TRT] [E] 9: GPTLMHeadModel/layers/0/attention/qkv/PLUGIN_V2_Gemm_0: could not find any supported formats consistent with input/output data types
[09/23/2023-03:13:00] [TRT] [E] 9: [pluginV2Builder.cpp::reportPluginError::24] Error Code 9: Internal Error (GPTLMHeadModel/layers/0/attention/qkv/PLUGIN_V2_Gemm_0: could not find any supported formats consistent with input/output data types)
一个可能的解决方案是通过减少最大批量大小、输入和输出长度来减少所需的内存量。 另一种选择是启用插件,例如:--gpt_attention_plugin
。
MPI + Slurm
TensorRT-LLM 是一个MPI 感知包,它使用 mpi4py
。 如果您在 Slurm 环境中运行脚本,您可能会遇到干扰
--------------------------------------------------------------------------
PMI2_Init failed to initialize. Return code: 14
--------------------------------------------------------------------------
--------------------------------------------------------------------------
The application appears to have been direct launched using "srun",
but OMPI was not built with SLURM's PMI support and therefore cannot
execute. There are several options for building PMI support under
SLURM, depending upon the SLURM version you are using:
version 16.05 or later: you can use SLURM's PMIx support. This
requires that you configure and build SLURM --with-pmix.
Versions earlier than 16.05: you must use either SLURM's PMI-1 or
PMI-2 support. SLURM builds PMI-1 by default, or you can manually
install PMI-2. You must then build Open MPI using --with-pmi pointing
to the SLURM PMI library location.
Please configure as appropriate and try again.
--------------------------------------------------------------------------
您可能会遇到其他问题,例如程序启动时挂起。
根据经验,如果您在 Slurm 节点上以交互方式运行 TensorRT-LLM,请在命令前加上 mpirun -n 1
,以在专用的 MPI 环境中运行 TensorRT-LLM,而不是由您的 Slurm 分配提供的 MPI 环境。
例如:mpirun -n 1 python3 examples/gpt/build.py ...
无论使用多少 GPU,都必须始终为 -n 1
。 如果您为 2 GPU 程序使用 -n 2
,它将无法工作。 此处的 mpirun
不是用于协调多个进程,而是用于在 SLURM 上调用正确的环境。 内部 MPI 实现处理生成额外的进程。