使用 Executor / cpp 运行时运行 gpt-2b + LoRA#
首先构建一个启用 LoRA 和 inflight-batching 的模型。
git-lfs clone https://hugging-face.cn/qychen/luotuo-lora-7b-0.1
git-lfs clone https://hugging-face.cn/kunishou/Japanese-Alpaca-LoRA-7b-v0
BASE_MODEL=llama-7b-hf
python examples/llama/convert_checkpoint.py --model_dir ${BASE_MODEL} \
--output_dir /tmp/llama_7b/trt_ckpt/fp16/1-gpu/ \
--dtype float16
trtllm-build --checkpoint_dir /tmp/llama_7b/trt_ckpt/fp16/1-gpu/ \
--output_dir /tmp/llama_7b_with_lora_qkv/trt_engines/fp16/1-gpu/ \
--remove_input_padding enable \
--gpt_attention_plugin float16 \
--context_fmha enable \
--paged_kv_cache enable \
--gemm_plugin float16 \
--lora_plugin float16 \
--max_batch_size 128 \
--max_input_len 512 \
--max_seq_len 562 \
--lora_dir Japanese-Alpaca-LoRA-7b-v0 \
--max_lora_rank 8 \
--lora_target_modules "attn_q" "attn_k" "attn_v"
要将 LoRA 传递到 cpp 运行时,必须将其转换为以下格式。以下脚本会将 Hugging Face LoRA 模型转换为正确的 NumPy 张量。
python3 tensorrt_llm/examples/hf_lora_convert.py -i Japanese-Alpaca-LoRA-7b-v0 -o Japanese-Alpaca-LoRA-7b-v0-weights --storage-type float16
python3 tensorrt_llm/examples/hf_lora_convert.py -i luotuo-lora-7b-0.1 -o luotuo-lora-7b-0.1-weights --storage-type float16
有关使用 Triton 的 Multi-LoRA 示例,请参阅 tensorrtllm_backend 文档。
LoRA 张量格式详细信息#
要使用 Executor
运行推理,Request
必须具有包含 task_id
、weights
和 config
参数的 LoraConfig
。
task_id
给定 LoRA 的唯一任务 ID。
要首次使用特定 LoRA 执行推理,必须提供 task_id
、weights
和 config
。LoRA 将被缓存,因此后续对同一任务的请求只需要 task_id
。如果缓存已满,则会逐出最旧的 LoRA,以便为新的 LoRA 腾出空间。如果 task_id
未缓存,则会返回错误。
weights
包含所有 LoRA 的权重。目前,这应包括所有 TP 和 PP 级别的权重。权重张量的形状为 [num_lora_modules_layers, D x Hi + Ho x D ]
。最后一个维度保存关联模块(例如,attn_qkv
)和模型层的输入/输出适配器权重。
每个输入/输出张量首先被展平,然后以上述格式连接在一起。第一个维度(大小为 num_lora_module_layers
)对于每个模块层都有一个条目(也就是说,attn_q layer1
有一个条目,而 attn_k layer1
有另一个条目)。
D=adapter_size (即 R 值), Hi=hidden_size_in, Ho=hidden_size_out。
config
是一个配置张量,用于标识 LoraWeights
的每个元素的 moduleId、layerId 和适配器大小。它的形状为 [num_lora_modules_layers, 3]
。最后一个维度保存 [module_id, layer_idx, adapter_size D (i.e. R value)]
。
此特性支持 https://arxiv.org/pdf/2106.09685.pdf 中描述的 LoRA。
LoRA 张量示例#
以下是一个具有 tp=1、pp=1、4 层和隐藏大小为 4 的模型的 LoraWeights
和 LoraConfig
张量示例。以下张量用于具有 q
和 k
适配器的 LoRA。
# loraConfig
[
[1, 0, 2]
[2, 0, 4]
[1, 1, 2]
[2, 1, 4]
[1, 2, 2] # Note that the final 2 layers only adapt `q`
[1, 3, 8]
]
# Note: The loraConfig tensor configures the loraWeights tensor.
# The contents of each row of loraWeights is specified be the corresponding row in loraConfig
# loraWeights
# Note: that 'in weights' and 'out weights' are 'A' and 'B' in the LoRA paper.
[
[ <2 x 4 in weights>, <4 x 2 out weights> <padding> ] # `q` adapter for layer 0
[ <4 x 4 in weights>, <4 x 4 out weights> <padding> ] # `k` adapter for layer 0
[ <2 x 4 in weights>, <4 x 2 out weights> <padding> ] # `q` adapter for layer 1
[ <4 x 4 in weights>, <4 x 4 out weights> <padding> ] # `k` adapter for layer 1
[ <2 x 4 in weights>, <4 x 2 out weights> <padding> ] # `q` adapter for layer 2
[ <8 x 4 in weights>, <4 x 8 out weights> ] # `q` adapter for layer 3. Note the final layer has a adapter size of 8
]
LoRA 模块 ID 映射#
模块名称(如 |
模块 ID |
描述 |
---|---|---|
attn_qkv |
0 |
组合的 qkv 适配器 |
attn_q |
1 |
q 适配器 |
attn_k |
2 |
k 适配器 |
attn_v |
3 |
v 适配器 |
attn_dense |
4 |
注意力中密集层的适配器 |
mlp_h_to_4h |
5 |
对于在注意力/ RMSNorm 之后的门控 mlp 层的 llama2 适配器:向上投影 |
mlp_4h_to_h |
6 |
对于在注意力/ RMSNorm 之后门控 mlp 层的 llama2 适配器:向下投影 |
mlp_gate |
7 |
对于在注意力/ RMSNorm 之后门控 mlp 层的 llama2 适配器:门 |
cross_attn_qkv |
8 |
交叉注意力的组合 qkv 适配器 |
cross_attn_q |
9 |
交叉注意力的 q 适配器 |
cross_attn_k |
10 |
交叉注意力的 k 适配器 |
cross_attn_v |
11 |
交叉注意力的 v 适配器 |
cross_attn_dense |
12 |
交叉注意力中密集层的适配器 |
moe_h_to_4h |
13 |
对于专家 mlp 层的 mixtral 适配器:向上投影 |
moe_4h_to_h |
14 |
对于专家 mlp 层的 mixtral 适配器:向下投影 |
moe_gate |
15 |
对于专家 mlp 层的 mixtral 适配器:门 |
moe_router |
16 |
对于专家路由器层的 mixtral 适配器 |
mlp_router |
17 |
对于共享专家门层的 qwen2-moe 适配器 |
mlp_gate_up |
18 |
在注意力/ RMSNorm 之后门控 mlp 层的适配器:门 + 向上投影 |
LoraCache 配置#
核心思想是我们将在 TRT-LLM 中拥有一个固定大小的 2 级 LoRA 缓存。更高级别的缓存位于主机上,而更低级别的缓存位于 GPU 上(不同于现有的 KV 缓存)。两者的大小都是用户可配置的。
CPU 缓存被配置为最大大小。GPU 缓存被配置为引擎加载后可用 GPU 内存的百分比。当请求进入时,LoRA 被存储在主机缓存中。
当请求被调度执行时,LoRA 被加载到 GPU 缓存中。
具有张量并行的 LoRA#
LoRA 的张量并行分区是特殊的。有两种情况:RowLinear
和 ColumnLinear
。假设我们有一个线性层,输入特征大小为 K
,输出特征大小为 N
。那么,权重的形状为 [K, N]
。
首先,考虑这个线性层是一个 ColumnLinear
层。当我们对权重进行划分时,我们通过 tp_size
按列分割权重。然后,会有 tp_size
个分割后的权重,这些权重的形状是 [K, N // tp_size]
。当我们在这样的 ColumnLinear
层上应用 LoRA 适配器时,原始的两个权重的形状是 [K, lora_rank]
和 [lora_rank, N]
。因此,我们只分割第二个权重,并获得 tp_size
个分割后的权重,其形状为 [lora_rank, N // tp_size]
。对于第一个权重,每个 GPU 保持相同的完整权重(形状为 [K, lora_rank]
)。
接下来,考虑这个线性层是一个 RowLinear
层。当我们对权重进行划分时,我们通过 tp_size
按行分割权重。然后,会有 tp_size
个分割后的权重,这些权重的形状是 [K // tp_size, N]
。当我们在这样的 RowLinear
层上应用 LoRA 适配器时,原始的两个权重的形状是 [K, lora_rank]
和 [lora_rank, N]
。因此,我们只分割第一个权重,并获得 tp_size
个分割后的权重,其形状为 [K // tp_size, lora_rank]
。对于第二个权重,每个 GPU 保持相同的完整权重(形状为 [lora_rank, N]
)。
DoRA#
TRTLLM 支持 https://arxiv.org/abs/2402.09353 中描述的 DoRA。要启用 DoRA,您必须将额外的 --dora_plugin enable
标志添加到 trtllm-build
命令中。
DoRA 缩放因子必须在提交到 TRTLLM 中的推理请求之前进行归一化。归一化需要基础模型权重。要归一化您的适配器,您可以使用 tensorrt_llm/examples/dora/normalize_weights.py
中提供的脚本。
使用 DoRA 时,LoraWeights
和 LoraConfig
的格式略有变化。LoraConfig
的形状变为 [module_id, layer_idx, adapter_size D (即 R 值), is_dora]
,其中 is_dora
是一个布尔标志,用于确定提供的适配器是否包含 DoRA 缩放因子。如果使用旧的配置形状,则假定适配器不具有 DoRA 缩放因子。LoraWeights
的形状变为 [num_lora_modules_layers, D x Hi + Ho x D + Ho]
,并且最后的 Ho
值是 DoRA 缩放向量。