使用 Executor / cpp 运行时运行 gpt-2b + LoRA#

首先构建一个启用 LoRA 和 inflight-batching 的模型。

git-lfs clone https://hugging-face.cn/qychen/luotuo-lora-7b-0.1
git-lfs clone https://hugging-face.cn/kunishou/Japanese-Alpaca-LoRA-7b-v0
BASE_MODEL=llama-7b-hf

python examples/llama/convert_checkpoint.py --model_dir ${BASE_MODEL} \
    --output_dir /tmp/llama_7b/trt_ckpt/fp16/1-gpu/ \
    --dtype float16

trtllm-build --checkpoint_dir /tmp/llama_7b/trt_ckpt/fp16/1-gpu/ \
    --output_dir /tmp/llama_7b_with_lora_qkv/trt_engines/fp16/1-gpu/ \
    --remove_input_padding enable \
    --gpt_attention_plugin float16 \
    --context_fmha enable \
    --paged_kv_cache enable \
    --gemm_plugin float16 \
    --lora_plugin float16 \
    --max_batch_size 128 \
    --max_input_len 512 \
    --max_seq_len 562 \
    --lora_dir Japanese-Alpaca-LoRA-7b-v0 \
    --max_lora_rank 8 \
    --lora_target_modules "attn_q" "attn_k" "attn_v"

要将 LoRA 传递到 cpp 运行时,必须将其转换为以下格式。以下脚本会将 Hugging Face LoRA 模型转换为正确的 NumPy 张量。

python3 tensorrt_llm/examples/hf_lora_convert.py -i Japanese-Alpaca-LoRA-7b-v0 -o Japanese-Alpaca-LoRA-7b-v0-weights --storage-type float16
python3 tensorrt_llm/examples/hf_lora_convert.py -i luotuo-lora-7b-0.1 -o luotuo-lora-7b-0.1-weights --storage-type float16

有关使用 Triton 的 Multi-LoRA 示例,请参阅 tensorrtllm_backend 文档

LoRA 张量格式详细信息#

要使用 Executor 运行推理,Request 必须具有包含 task_idweightsconfig 参数的 LoraConfig

task_id 给定 LoRA 的唯一任务 ID。

要首次使用特定 LoRA 执行推理,必须提供 task_idweightsconfig。LoRA 将被缓存,因此后续对同一任务的请求只需要 task_id。如果缓存已满,则会逐出最旧的 LoRA,以便为新的 LoRA 腾出空间。如果 task_id 未缓存,则会返回错误。

weights 包含所有 LoRA 的权重。目前,这应包括所有 TP 和 PP 级别的权重。权重张量的形状为 [num_lora_modules_layers, D x Hi + Ho x D ]。最后一个维度保存关联模块(例如,attn_qkv)和模型层的输入/输出适配器权重。

每个输入/输出张量首先被展平,然后以上述格式连接在一起。第一个维度(大小为 num_lora_module_layers)对于每个模块层都有一个条目(也就是说,attn_q layer1 有一个条目,而 attn_k layer1 有另一个条目)。

D=adapter_size (即 R 值), Hi=hidden_size_in, Ho=hidden_size_out。

config 是一个配置张量,用于标识 LoraWeights 的每个元素的 moduleId、layerId 和适配器大小。它的形状为 [num_lora_modules_layers, 3]。最后一个维度保存 [module_id, layer_idx, adapter_size D (i.e. R value)]

此特性支持 https://arxiv.org/pdf/2106.09685.pdf 中描述的 LoRA。

LoRA 张量示例#

以下是一个具有 tp=1、pp=1、4 层和隐藏大小为 4 的模型的 LoraWeightsLoraConfig 张量示例。以下张量用于具有 qk 适配器的 LoRA。

# loraConfig
[
  [1, 0, 2]
  [2, 0, 4]
  [1, 1, 2]
  [2, 1, 4]
  [1, 2, 2]  # Note that the final 2 layers only adapt `q`
  [1, 3, 8]
]
# Note: The loraConfig tensor configures the loraWeights tensor.
#       The contents of each row of loraWeights is specified be the corresponding row in loraConfig

# loraWeights
# Note: that 'in weights' and 'out weights' are 'A' and 'B' in the LoRA paper.
[
  [ <2 x 4 in weights>, <4 x 2 out weights> <padding> ]  # `q` adapter for layer 0
  [ <4 x 4 in weights>, <4 x 4 out weights> <padding> ]  # `k` adapter for layer 0
  [ <2 x 4 in weights>, <4 x 2 out weights> <padding> ]  # `q` adapter for layer 1
  [ <4 x 4 in weights>, <4 x 4 out weights> <padding> ]  # `k` adapter for layer 1
  [ <2 x 4 in weights>, <4 x 2 out weights> <padding> ]  # `q` adapter for layer 2
  [ <8 x 4 in weights>, <4 x 8 out weights>           ]  # `q` adapter for layer 3. Note the final layer has a adapter size of 8
]

LoRA 模块 ID 映射#

模块名称(如 convert_checkpoint.py 脚本中所指定)

模块 ID

描述

attn_qkv

0

组合的 qkv 适配器

attn_q

1

q 适配器

attn_k

2

k 适配器

attn_v

3

v 适配器

attn_dense

4

注意力中密集层的适配器

mlp_h_to_4h

5

对于在注意力/ RMSNorm 之后的门控 mlp 层的 llama2 适配器:向上投影

mlp_4h_to_h

6

对于在注意力/ RMSNorm 之后门控 mlp 层的 llama2 适配器:向下投影

mlp_gate

7

对于在注意力/ RMSNorm 之后门控 mlp 层的 llama2 适配器:门

cross_attn_qkv

8

交叉注意力的组合 qkv 适配器

cross_attn_q

9

交叉注意力的 q 适配器

cross_attn_k

10

交叉注意力的 k 适配器

cross_attn_v

11

交叉注意力的 v 适配器

cross_attn_dense

12

交叉注意力中密集层的适配器

moe_h_to_4h

13

对于专家 mlp 层的 mixtral 适配器:向上投影

moe_4h_to_h

14

对于专家 mlp 层的 mixtral 适配器:向下投影

moe_gate

15

对于专家 mlp 层的 mixtral 适配器:门

moe_router

16

对于专家路由器层的 mixtral 适配器

mlp_router

17

对于共享专家门层的 qwen2-moe 适配器

mlp_gate_up

18

在注意力/ RMSNorm 之后门控 mlp 层的适配器:门 + 向上投影

LoraCache 配置#

核心思想是我们将在 TRT-LLM 中拥有一个固定大小的 2 级 LoRA 缓存。更高级别的缓存位于主机上,而更低级别的缓存位于 GPU 上(不同于现有的 KV 缓存)。两者的大小都是用户可配置的。

CPU 缓存被配置为最大大小。GPU 缓存被配置为引擎加载后可用 GPU 内存的百分比。当请求进入时,LoRA 被存储在主机缓存中。

当请求被调度执行时,LoRA 被加载到 GPU 缓存中。

具有张量并行的 LoRA#

LoRA 的张量并行分区是特殊的。有两种情况:RowLinearColumnLinear。假设我们有一个线性层,输入特征大小为 K,输出特征大小为 N。那么,权重的形状为 [K, N]

首先,考虑这个线性层是一个 ColumnLinear 层。当我们对权重进行划分时,我们通过 tp_size 按列分割权重。然后,会有 tp_size 个分割后的权重,这些权重的形状是 [K, N // tp_size]。当我们在这样的 ColumnLinear 层上应用 LoRA 适配器时,原始的两个权重的形状是 [K, lora_rank][lora_rank, N]。因此,我们只分割第二个权重,并获得 tp_size 个分割后的权重,其形状为 [lora_rank, N // tp_size]。对于第一个权重,每个 GPU 保持相同的完整权重(形状为 [K, lora_rank])。

接下来,考虑这个线性层是一个 RowLinear 层。当我们对权重进行划分时,我们通过 tp_size 按行分割权重。然后,会有 tp_size 个分割后的权重,这些权重的形状是 [K // tp_size, N]。当我们在这样的 RowLinear 层上应用 LoRA 适配器时,原始的两个权重的形状是 [K, lora_rank][lora_rank, N]。因此,我们只分割第一个权重,并获得 tp_size 个分割后的权重,其形状为 [K // tp_size, lora_rank]。对于第二个权重,每个 GPU 保持相同的完整权重(形状为 [lora_rank, N])。

DoRA#

TRTLLM 支持 https://arxiv.org/abs/2402.09353 中描述的 DoRA。要启用 DoRA,您必须将额外的 --dora_plugin enable 标志添加到 trtllm-build 命令中。

DoRA 缩放因子必须在提交到 TRTLLM 中的推理请求之前进行归一化。归一化需要基础模型权重。要归一化您的适配器,您可以使用 tensorrt_llm/examples/dora/normalize_weights.py 中提供的脚本。

使用 DoRA 时,LoraWeightsLoraConfig 的格式略有变化。LoraConfig 的形状变为 [module_id, layer_idx, adapter_size D (即 R 值), is_dora],其中 is_dora 是一个布尔标志,用于确定提供的适配器是否包含 DoRA 缩放因子。如果使用旧的配置形状,则假定适配器不具有 DoRA 缩放因子。LoraWeights 的形状变为 [num_lora_modules_layers, D x Hi + Ho x D + Ho],并且最后的 Ho 值是 DoRA 缩放向量。