在 PyTorch 后端添加新模型#
目录#
介绍#
本指南提供了在 PyTorch 后端添加新模型的分步过程。
前提条件#
在开始之前,请确保您已具备以下条件:
已成功安装 TensorRT-LLM。请按照这些说明进行操作。
分步指南#
模型配置#
假设您想支持一个名为 MyModel
的新模型。如果该模型已在 HuggingFace 的 transformers 中支持,您应该引入 PyTorch 模型代码并复用 HuggingFace 的配置类。例如,我们的 tensorrt_llm/_torch/models/modeling_llama.py
改编自 HuggingFace 的 modeling_llama.py;在模型代码中,我们复用了配置类。
from transformers import LlamaConfig
如果该模型未在 HuggingFace 的 transformers 中注册,您需要参照 HuggingFace 的 configuration_llama.py 在您的 configuration_mymodel.py
中定义配置类。
from transformers.configuration_utils import PretrainedConfig
class MyConfig(PretrainedConfig):
def __init__(self, ...):
...
模型定义#
移除任何不必要的代码(例如,训练特定代码),然后重写一些 PyTorch 模块。对于典型的 Transformer 解码器模型,您需要像这样实现您的 modeling_mymodel.py
from typing import Optional
import torch
from torch import nn
from tensorrt_llm._torch.attention_backend import AttentionMetadata
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import DecoderModel, DecoderModelForCausalLM
from tensorrt_llm._torch.modules.attention import Attention
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
from configuration_mymodel import MyConfig
class MyAttention(Attention):
def __init__(self, model_config: ModelConfig[MyConfig], layer_idx: Optional[int] = None):
# Use model_config to initialize the Attention module
super().__init__(...)
class MyDecoderLayer(DecoderLayer):
def __init__(self, model_config: ModelConfig[MyConfig], layer_idx: int):
super().__init__()
# Use model_config to initialize the submodules
self.input_layernorm = ...
self.self_attn = MyAttention(model_config, layer_idx)
self.post_attention_layernorm = ...
self.mlp = ...
def forward(self, hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, **kwargs):
# Define the forward computation of a single decoder layer
...
class MyModel(DecoderModel):
def __init__(self, model_config: ModelConfig[MyConfig]):
super().__init__(model_config)
# Use model_config to initialize the submodules
self.embed_tokens = ...
self.layers = nn.ModuleList([
MyDecoderLayer(model_config, layer_idx) for layer_idx in range(model_config.pretrained_config.num_hidden_layers)
])
def forward(self,
attn_metadata: AttentionMetadata,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None):
# Define the forward computation of the model
...
class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):
def __init__(self, model_config: ModelConfig[MyConfig]):
super().__init__(MyModel(model_config),
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size)
请注意,MyAttention
继承自我们的 Attention
模块(位于 tensorrt_llm/_torch/modules/attention.py
中),以便注意力计算与我们的 PyTorch 运行时兼容。与此相关,模块输入也应进行调整。
The
attn_metadata
存储了来自批量输入和 KV 缓存的元数据,供注意力后端使用。它由运行时创建并传递,模型开发者需要确保attn_metadata
正确传递给了注意力模块。输入张量(即
input_ids
、position_ids
、hidden_states
)处于 packed 模式。第一个维度对应于批次中的令牌数量。
此外,MyDecoderLayer
、MyModel
和 MyModelForCausalLM
分别是 DecoderLayer
、DecoderModel
和 DecoderModelForCausalLM
的子类。基类定义了接口并提供了定义模型层、加载权重等的通用框架。
您可以选择用我们的实现替换原生 PyTorch 模块,以启用功能或获得更高性能:
Linear
(位于tensorrt_llm/_torch/modules/linear.py
中):启用张量并行和量化。Embedding
(位于tensorrt_llm/_torch/modules/embedding.py
中):为嵌入启用张量并行。RotaryEmbedding
(位于tensorrt_llm/_torch/modules/rotary_embedding.py
中):启用高性能旋转嵌入。RMSNorm
(位于tensorrt_llm/_torch/modules/rms_norm.py
中):启用高性能 RMS 范数。
要获取具体参考,请查看 tensorrt_llm/_torch/models/modeling_llama.py
。
权重加载#
基类 DecoderModelForCausalLM
提供了一个 load_weights
方法,用于从检查点文件加载权重并将其分配给模型中对应的层。但是,如果默认方法不适用于 MyModelForCausalLM
,您需要实现自己的 load_weights
方法。
class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):
def load_weights(self, weights: dict):
# Define the weight loading logic
...
例如,Huggingface 的 LLaMA 模型使用三个线性层进行 Q/K/V 投影,导致检查点中有三个权重张量。
>>> weights
{
...,
"model.layers.0.self_attn.q_proj.weight": torch.Tensor([hidden_size, hidden_size]),
"model.layers.0.self_attn.k_proj.weight": torch.Tensor([hidden_size, hidden_size]),
"model.layers.0.self_attn.v_proj.weight": torch.Tensor([hidden_size, hidden_size]),
...,
}
然而,我们的 LLaMA 模型将这三个层融合为一个线性层。
>>> llama.model.layers[0].self_attn.qkv_proj.weight.data
torch.Tensor([hidden_size * 3, hidden_size])
因此,load_weights
需要从原始检查点中收集这三个权重张量,将它们连接起来,并分配给融合后的线性层。考虑到张量并行和量化,这个过程会更复杂。我们建议在实现您模型级的 load_weights
方法时,调用预定义的模块级 load_weights
方法(例如 Linear
和 Embedding
)。
总而言之,load_weights
应该处理 MyModelForCausalLM
与从检查点加载的权重之间的任何差异,以便 MyModelForCausalLM
可以执行与原始模型等效的前向计算。
模型注册#
新模型需要注册,以便 PyTorch 运行时能够识别它。注册只需为 MyModelForCausalLM
添加 register_auto_model
装饰器即可。
from tensorrt_llm._torch.models.modeling_utils import register_auto_model
@register_auto_model("MyModelForCausalLM")
class MyModelForCausalLM(DecoderModelForCausalLM[MyModel, MyConfig]):
def __init__(self, model_config: ModelConfig[MyConfig]):
...
核心模型#
要将新模型添加到核心模型中,应将 modeling_mymodel.py
(以及可能的 configuration_mymodel.py
)放在 tensorrt_llm/_torch/models
目录下。然后,您需要在 tensorrt_llm/_torch/models/__init__.py
中导入模型代码。
from .modeling_mymodel import MyModelForCausalLM
__all__ = [
...,
"MyModelForCausalLM",
]
树外模型#
或者,您可以将新模型注册为树外模型,这样您就可以使用新模型而无需修改 TensorRT-LLM 代码库。为此,将 modeling_mymodel.py
(以及可能的 configuration_mymodel.py
)放在您的工作目录下,并在您的脚本中导入模型代码。
from tensorrt_llm._torch import LLM
import modeling_mymodel
def main():
llm = LLM(...)
if __name__ == '__main__':
main()
我们在 examples/pytorch/out_of_tree_example
中提供了一个树外模型的示例。该模型在 modeling_opt.py
中实现,您可以通过以下方式运行该示例:
python examples/pytorch/out_of_tree_example/main.py