KV Cache 管理器#

在基于 Transformer 的模型中,KV(Key-Value)Cache 是一种用于优化解码效率的机制,尤其是在自回归生成任务中。由于 KV Cache 需要内存来存储,它也是一个重要的资源。在 TensorRT-LLM 中,KV Cache 由 KVCacheManager 管理。

KV Cache 管理器介绍#

KVCacheManager 是一种资源管理器,继承自 BaseResourceManager。因此,它实现了 BaseResourceManager 声明的接口。

注意:随着项目的演进,这些接口可能会发生变化。

接口#

来自 BaseResourceManager 的接口包括

  • prepare_resources:在 PyExecutor 中,每个步骤的模型前向传播之前都会为当前批次调用。在 KVCacheManager 中,这涉及分配 KV Cache 内存。此分配因请求类型而异。对于首次进入上下文阶段的请求,需要为整个上下文分配 KV Cache。对于已处于生成阶段的请求,会为即将到来的步骤分配 KV Cache。如果 KV Cache 以块组织,并且块内有可用空间,实际分配可能不会发生。

  • update_resources:在每个步骤结束时为当前批次调用,以更新已分配的资源。对于 KV Cache,更新可能不是必需的,因此此函数目前不执行任何操作。如果在 Python 中支持 KV Cache 重用,例如 KV Cache Radix Tree 管理等更新会在此处发生。

  • free_resources:在请求完成时调用,以释放为该请求分配的资源。对于 KV Cache,如果未启用重用,应回收该请求使用的 KV Cache 内存。在 C++ 绑定实现中,这可能涉及调用绑定的 remove_sequence 方法来释放与该请求相关的 KV Cache 内存。

还有两个为 CapacityScheduler 设计的接口

  • get_max_resource_count:查询可用资源的最大数量。对于 KVCacheManager,这通常是 KV Cache 块的最大数量。

  • get_needed_resource_to_completion:计算单个请求完成所需的资源。CapacityScheduler 使用此接口汇总所需的总资源,并确定是否可以容纳新请求。

除了 BaseResourceManager 接口外,KVCacheManager 还具有与正在使用的 ModelEngine 相关的接口。对于 PyTorchModelEngine,常用接口包括

  • get_batch_cache_indices:接受一个 LlmRequest 列表,并返回一个 Dict[List[int]],指示每个请求的块 ID。

  • get_buffers:返回给定层的 KV Cache 池的缓冲区,供注意力后端使用。形状可能是 [num_blocks, 2, num_tokens_per_block, num_kv_heads, head_dim]。

  • get_num_free_blocks:返回可用于分配的空闲块数量。

还有用于预热 PyTorchModelEngine 的接口,尤其是在使用 CUDA 图时

  • add_padding_request:向 KV Cache 添加一个上下文长度为 1 的序列作为预热请求。如果在概念验证中未使用 CUDA 图,则此操作是可选的。

定制 KV Cache 管理器#

要定制 KVCacheManager,请实现所有必要的接口。然后,将其集成到 PyExecutor 中。对于 PyTorch 后端,相关代码位于 pytorch_model_registry.py 中。在 create_pytorch_model_based_executor 函数中,KVCacheManager 的实例化方式如下

    kv_cache_manager = KVCacheManager(
        executor_config.kv_cache_config,
        tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
        num_layers=model_engine.model.config.num_hidden_layers,
        num_kv_heads=model_engine.model.config.num_key_value_heads,
        head_dim=head_dim,
        tokens_per_block=tokens_per_block,
        max_seq_len=max_seq_len,
        max_batch_size=max_num_requests,
        mapping=mapping,
        dtype=kv_cache_dtype,
    )

对于本地测试或概念验证,请更新这些行以使用您的实现。然后,进行测试以确保 PyExecutor 使用您定制的 KVCacheManager 运行。