From 18ffd945a2c54fa1bcddf335302d0949b918a725 Mon Sep 17 00:00:00 2001 From: bob-cloudforge Date: Mon, 4 May 2026 13:24:41 +0200 Subject: [PATCH 1/6] feat(kvcache): support head-wise SWA recycle --- .../yaml/eb45-21b-a3b-32k-bf16-kv50-512s.yaml | 10 + custom_ops/gpu_ops/append_attention.cu | 20 +- .../append_attn/append_attention_c16_impl.cuh | 8 + .../append_attn/append_attention_kernel.h | 2 + .../multiquery_attention_c16_impl.cuh | 75 ++++- .../multiquery_attention_c16_kernel.h | 1 + custom_ops/gpu_ops/cpp_extensions.cc | 6 + .../cache_manager/prefix_cache_manager.py | 130 ++++++++- fastdeploy/config.py | 21 +- .../engine/sched/resource_manager_v1.py | 273 +++++++++++++++++- fastdeploy/envs.py | 8 +- .../layers/attention/append_attn_backend.py | 2 + .../layers/attention/ops/append_attention.py | 5 + .../models/paddleformers/base.py | 17 +- fastdeploy/worker/gpu_model_runner.py | 9 +- fastdeploy/worker/input_batch.py | 4 +- fastdeploy/worker/worker_process.py | 16 +- .../test_benchmark_head_wise_swa.py | 149 ++++++++++ .../test_head_wise_abort_reset.py | 193 +++++++++++++ .../test_head_wise_extend_validation.py | 136 +++++++++ .../cache_manager/test_head_wise_freelist.py | 160 ++++++++++ .../test_head_wise_tp_consistency.py | 138 +++++++++ tests/cache_manager/test_swa_recycle.py | 217 ++++++++++++++ .../test_swa_recycle_legacy_relief.py | 89 ++++++ .../test_append_attention_head_wise_shapes.py | 106 +++++++ 25 files changed, 1762 insertions(+), 33 deletions(-) create mode 100644 benchmarks/yaml/eb45-21b-a3b-32k-bf16-kv50-512s.yaml create mode 100644 tests/cache_manager/test_benchmark_head_wise_swa.py create mode 100644 tests/cache_manager/test_head_wise_abort_reset.py create mode 100644 tests/cache_manager/test_head_wise_extend_validation.py create mode 100644 tests/cache_manager/test_head_wise_freelist.py create mode 100644 tests/cache_manager/test_head_wise_tp_consistency.py create mode 100644 tests/cache_manager/test_swa_recycle.py create mode 100644 
tests/cache_manager/test_swa_recycle_legacy_relief.py create mode 100644 tests/layers/test_append_attention_head_wise_shapes.py diff --git a/benchmarks/yaml/eb45-21b-a3b-32k-bf16-kv50-512s.yaml b/benchmarks/yaml/eb45-21b-a3b-32k-bf16-kv50-512s.yaml new file mode 100644 index 00000000000..8d9a43d1f39 --- /dev/null +++ b/benchmarks/yaml/eb45-21b-a3b-32k-bf16-kv50-512s.yaml @@ -0,0 +1,10 @@ +# T53 bench workload — KV-bound (not slot-bound); gate: FD_HEAD_WISE_KV_CACHE=1 +# max_num_seqs raised to 512 so the KV pool, not the slot count, is the bottleneck. +# kv_cache_ratio lowered to 0.50 to shrink the pool and accelerate KV pressure. +# Use with: INPUT_LEN=8192 OUTPUT_LEN=4096 REQUEST_RATE=8 +# +max_model_len: 32768 +max_num_seqs: 512 +kv_cache_ratio: 0.50 +tensor_parallel_size: 1 +max_num_batched_tokens: 32768 diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index c1586945cc5..74e0be85c6d 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -48,6 +48,7 @@ void AppendAttentionKernel( const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, + const paddle::optional& block_tables_headwise, const paddle::Tensor& encoder_batch_ids, const paddle::Tensor& encoder_tile_ids_per_batch, const paddle::Tensor& encoder_num_blocks, @@ -155,6 +156,7 @@ void AppendAttentionKernel( batch_id_per_token, cu_seqlens_q, block_tables, + block_tables_headwise, lambda_batch_ids, lambda_tile_ids_per_batch, cache_quant_type_str, @@ -488,6 +490,9 @@ std::vector AppendAttention( const paddle::optional& q_norm_weight, const paddle::optional& k_norm_weight, const paddle::optional& sinks, + const paddle::optional& + block_tables_headwise, // logical 3D, physical rank-2 [max_num_seqs * + // local_kv_heads, max_blocks_per_head] const float rms_norm_eps, const std::string& compute_dtype, const std::string& cache_quant_type_str, @@ -595,6 +600,7 @@ 
std::vector AppendAttention( batch_id_per_token, cu_seqlens_q, block_tables, + block_tables_headwise, encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks, @@ -700,6 +706,9 @@ std::vector AppendAttentionWithOutput( const paddle::optional& q_norm_weight, const paddle::optional& k_norm_weight, const paddle::optional& sinks, + const paddle::optional& + block_tables_headwise, // logical 3D, physical rank-2 [max_num_seqs * + // local_kv_heads, max_blocks_per_head] const float rms_norm_eps, const std::string& compute_dtype, const std::string& cache_quant_type_str, @@ -753,6 +762,7 @@ std::vector AppendAttentionWithOutput( batch_id_per_token, cu_seqlens_q, block_tables, + block_tables_headwise, encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks, @@ -871,6 +881,7 @@ std::vector> AppendAttentionInferShape( const paddle::optional>& q_norm_weight_shape, const paddle::optional>& k_norm_weight_shape, const paddle::optional>& sinks_shape, + const paddle::optional>& block_tables_headwise_shape, const float rms_norm_eps, const std::string& compute_dtype, const std::string& cache_quant_type_str, @@ -937,6 +948,7 @@ std::vector AppendAttentionInferDtype( const paddle::optional& q_norm_weight_dtype, const paddle::optional& k_norm_weight_dtype, const paddle::optional& sinks_dtype, + const paddle::optional& block_tables_headwise_dtype, const float rms_norm_eps, const std::string& compute_dtype, const std::string& cache_quant_type_str, @@ -1024,6 +1036,7 @@ std::vector> AppendAttentionWithOutputInferShape( const paddle::optional>& q_norm_weight_shape, const paddle::optional>& k_norm_weight_shape, const paddle::optional>& sinks_shape, + const paddle::optional>& block_tables_headwise_shape, const float rms_norm_eps, const std::string& compute_dtype, const std::string& cache_quant_type_str, @@ -1083,6 +1096,7 @@ std::vector AppendAttentionWithOutputInferDtype( const paddle::optional& q_norm_weight_dtype, const paddle::optional& k_norm_weight_dtype, const 
paddle::optional& sinks_dtype, + const paddle::optional& block_tables_headwise_dtype, const float rms_norm_eps, const std::string& compute_dtype, const std::string& cache_quant_type_str, @@ -1140,7 +1154,8 @@ PD_BUILD_STATIC_OP(append_attention) paddle::Optional("kv_signal_data"), paddle::Optional("q_norm_weight"), paddle::Optional("k_norm_weight"), - paddle::Optional("sinks")}) + paddle::Optional("sinks"), + paddle::Optional("block_tables_headwise")}) .Outputs({"fmha_out"}) .Attrs({ "rms_norm_eps: float", @@ -1203,7 +1218,8 @@ PD_BUILD_STATIC_OP(append_attention_with_output) paddle::Optional("kv_signal_data"), paddle::Optional("q_norm_weight"), paddle::Optional("k_norm_weight"), - paddle::Optional("sinks")}) + paddle::Optional("sinks"), + paddle::Optional("block_tables_headwise")}) .Outputs({"fmha_out_out"}) .SetInplaceMap({{"fmha_out", "fmha_out_out"}}) .Attrs({ diff --git a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh index 70329c9366a..eabbdd2c834 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_c16_impl.cuh @@ -44,6 +44,7 @@ void CascadeAppendAttentionC16Kernel( const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, + const paddle::optional& block_table_headwise, const paddle::Tensor& batch_ids, const paddle::Tensor& tile_ids_per_batch, const int num_blocks, @@ -109,6 +110,7 @@ void CascadeAppendAttentionC16Kernel( batch_id_per_token, cu_seqlens_q, block_table, + block_table_headwise, batch_ids, tile_ids_per_batch, num_blocks, @@ -156,6 +158,7 @@ CascadeAppendAttentionC16Kernel( const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, + const paddle::optional& block_table_headwise, const paddle::Tensor& batch_ids, const paddle::Tensor& tile_ids_per_batch, const int num_blocks, @@ -204,6 +207,7 
@@ CascadeAppendAttentionC16Kernel( const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, + const paddle::optional& block_table_headwise, const paddle::Tensor& batch_ids, const paddle::Tensor& tile_ids_per_batch, const int num_blocks, @@ -251,6 +255,7 @@ template void CascadeAppendAttentionC16Kernel( const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, + const paddle::optional& block_table_headwise, const paddle::Tensor& batch_ids, const paddle::Tensor& tile_ids_per_batch, const int num_blocks, @@ -298,6 +303,7 @@ template void CascadeAppendAttentionC16Kernel( const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, + const paddle::optional& block_table_headwise, const paddle::Tensor& batch_ids, const paddle::Tensor& tile_ids_per_batch, const int num_blocks, @@ -346,6 +352,7 @@ CascadeAppendAttentionC16Kernel( const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, + const paddle::optional& block_table_headwise, const paddle::Tensor& batch_ids, const paddle::Tensor& tile_ids_per_batch, const int num_blocks, @@ -393,6 +400,7 @@ template void CascadeAppendAttentionC16Kernel( const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_table, + const paddle::optional& block_table_headwise, const paddle::Tensor& batch_ids, const paddle::Tensor& tile_ids_per_batch, const int num_blocks, diff --git a/custom_ops/gpu_ops/append_attn/append_attention_kernel.h b/custom_ops/gpu_ops/append_attn/append_attention_kernel.h index ca06deeeb75..f340ada3474 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_kernel.h +++ b/custom_ops/gpu_ops/append_attn/append_attention_kernel.h @@ -47,6 +47,7 @@ void CascadeAppendAttentionKernel( const paddle::Tensor& batch_id_per_token, const paddle::Tensor& 
cu_seqlens_q, const paddle::Tensor& block_table, + const paddle::optional& block_table_headwise, const paddle::Tensor& batch_ids, const paddle::Tensor& tile_ids_per_batch, const std::string& cache_quant_type_str, @@ -86,6 +87,7 @@ void CascadeAppendAttentionKernel( batch_id_per_token, cu_seqlens_q, block_table, + block_table_headwise, batch_ids, tile_ids_per_batch, num_blocks, diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index e7463154c43..d2d7ce6e43a 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -44,9 +44,11 @@ __global__ void multi_query_append_attention_kernel( const int *__restrict__ tile_ids_per_batch, const int *__restrict__ cu_seqlens_q, const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int *__restrict__ block_table_hw, const int *__restrict__ mask_offset, const int max_seq_len, const int max_block_num_per_seq, + const int max_blocks_per_head, const float scale, const float quant_max_bound, const float quant_min_bound, @@ -73,7 +75,11 @@ __global__ void multi_query_append_attention_kernel( const uint32_t batch_id = batch_ids[btid]; const uint32_t tile_id = tile_ids_per_batch[btid]; const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16; - const int *block_table_now = block_table + batch_id * max_block_num_per_seq; + const int *block_table_now = + (block_table_hw != nullptr) + ? 
block_table_hw + + (batch_id * kv_num_heads + kv_head_idx) * max_blocks_per_head + : block_table + batch_id * max_block_num_per_seq; // When cudagraph capture prefill, may launch more gridDim.x if (btid >= static_cast(num_blocks_x_cpu)) { @@ -207,6 +213,9 @@ __global__ void multi_query_append_attention_kernel( uint32_t kv_idx_base = chunk_start; int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + if (block_id < 0) { + block_id = 0; + } const uint32_t const_offset = kv_head_idx * kv_h_stride + (wid * 4 + tid / 8) * kv_b_stride + tid % 8 * num_elems_per_128b(); @@ -446,10 +455,12 @@ __global__ void multi_query_append_attention_warp1_4_kernel( const int *__restrict__ tile_ids_per_batch, const int *__restrict__ cu_seqlens_q, const int *__restrict__ block_table, // [bsz, block_num_per_seq] + const int *__restrict__ block_table_hw, const int *__restrict__ mask_offset, const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask const int max_seq_len, const int max_block_num_per_seq, + const int max_blocks_per_head, const float scale, const float quant_max_bound, const float quant_min_bound, @@ -481,7 +492,11 @@ __global__ void multi_query_append_attention_warp1_4_kernel( const uint32_t tile_id = tile_ids_per_batch[btid]; const uint32_t num_rows_per_block = num_frags_x * 16; - const int *block_table_now = block_table + batch_id * max_block_num_per_seq; + const int *block_table_now = + (block_table_hw != nullptr) + ? 
block_table_hw + + (batch_id * kv_num_heads + kv_head_idx) * max_blocks_per_head + : block_table + batch_id * max_block_num_per_seq; const uint32_t q_len = seq_lens[batch_id]; const uint32_t kv_len = seq_lens_kv[batch_id] + q_len; @@ -589,6 +604,9 @@ __global__ void multi_query_append_attention_warp1_4_kernel( uint32_t kv_idx_base = chunk_start; int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); + if (block_id < 0) { + block_id = 0; + } const uint32_t const_offset = kv_head_idx * kv_h_stride + (wid * 4 + tid / 8) * kv_b_stride + tid % 8 * num_elems_per_128b(); @@ -834,6 +852,7 @@ void MultiQueryAppendAttention( const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &block_table, + const paddle::optional &block_table_headwise, const paddle::Tensor &batch_ids, const paddle::Tensor &tile_ids_per_batch, const int num_blocks_x_cpu, @@ -858,6 +877,50 @@ void MultiQueryAppendAttention( auto token_num = meta_data.token_nums; auto bsz = meta_data.batch_size; auto max_block_num_per_seq = meta_data.max_blocks_per_seq; + const int *block_table_hw_ptr = + block_table_headwise ? block_table_headwise.get().data() : nullptr; + const int max_blocks_per_head = + block_table_headwise + ? 
static_cast(block_table_headwise.get().shape().back()) + : 0; + if (block_table_headwise) { + const auto &hw_shape = block_table_headwise.get().shape(); + PADDLE_ENFORCE_EQ( + hw_shape.size(), + 2u, + phi::errors::InvalidArgument( + "block_tables_headwise must be rank-2 (logical [bsz, " + "kv_num_heads, max_blocks_per_head] flattened to " + "[bsz*kv_num_heads, max_blocks_per_head]); got rank %zu.", + hw_shape.size())); + PADDLE_ENFORCE_EQ( + hw_shape[0], + static_cast(bsz) * static_cast(kv_num_heads), + phi::errors::InvalidArgument( + "block_tables_headwise dim 0 must equal bsz * kv_num_heads " + "(%d * %d = %d); got %ld.", + bsz, + kv_num_heads, + bsz * kv_num_heads, + static_cast(hw_shape[0]))); + PADDLE_ENFORCE_GT( + max_blocks_per_head, + 0, + phi::errors::InvalidArgument( + "block_tables_headwise last dim (max_blocks_per_head) must be " + "> 0; got %d.", + max_blocks_per_head)); + PADDLE_ENFORCE_GE( + max_blocks_per_head, + max_block_num_per_seq, + phi::errors::InvalidArgument( + "block_tables_headwise max_blocks_per_head (%d) must be >= " + "max_block_num_per_seq (%d) to satisfy the per-iteration " + "prefetch contract of multi_query_append_attention C16 " + "kernels.", + max_blocks_per_head, + max_block_num_per_seq)); + } constexpr uint32_t num_warps = 4; constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; @@ -959,9 +1022,11 @@ void MultiQueryAppendAttention( tile_ids_per_batch.data(), cu_seqlens_q.data(), block_table.data(), + block_table_hw_ptr, meta_data.mask_offset, max_seq_len, max_block_num_per_seq, + max_blocks_per_head, scale, quant_max_bound, quant_min_bound, @@ -1027,9 +1092,11 @@ void MultiQueryAppendAttention( tile_ids_per_batch.data(), cu_seqlens_q.data(), block_table.data(), + block_table_hw_ptr, meta_data.mask_offset, max_seq_len, max_block_num_per_seq, + max_blocks_per_head, scale, quant_max_bound, quant_min_bound, @@ -1186,11 +1253,13 @@ void MultiQueryAppendAttention( tile_ids_per_batch.data(), cu_seqlens_q.data(), block_table.data(), 
+ block_table_hw_ptr, meta_data.mask_offset, attn_mask ? const_cast(attn_mask.get().data()) : nullptr, max_seq_len, max_block_num_per_seq, + max_blocks_per_head, scale, quant_max_bound, quant_min_bound, @@ -1244,11 +1313,13 @@ void MultiQueryAppendAttention( tile_ids_per_batch.data(), cu_seqlens_q.data(), block_table.data(), + block_table_hw_ptr, meta_data.mask_offset, attn_mask ? const_cast(attn_mask.get().data()) : nullptr, max_seq_len, max_block_num_per_seq, + max_blocks_per_head, scale, quant_max_bound, quant_min_bound, diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_kernel.h b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_kernel.h index 9fe215be66b..c5e9f7f1b23 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_kernel.h +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_kernel.h @@ -39,6 +39,7 @@ void MultiQueryAppendAttention( const paddle::Tensor &batch_id_per_token, const paddle::Tensor &cu_seqlens_q, const paddle::Tensor &block_table, + const paddle::optional &block_table_headwise, const paddle::Tensor &batch_ids, const paddle::Tensor &tile_ids_per_batch, const int num_blocks_x_cpu, diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 204ea33e50b..a351421f8e2 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -113,6 +113,9 @@ std::vector AppendAttention( const paddle::optional& q_norm_weight, const paddle::optional& k_norm_weight, const paddle::optional& sinks, + const paddle::optional& + block_tables_headwise, // logical 3D, physical rank-2 [max_num_seqs * + // local_kv_heads, max_blocks_per_head] const float rms_norm_eps, const std::string& compute_dtype, const std::string& cache_quant_type_str, @@ -170,6 +173,9 @@ std::vector AppendAttentionWithOutput( const paddle::optional& q_norm_weight, const paddle::optional& k_norm_weight, const paddle::optional& sinks, + const paddle::optional& + block_tables_headwise, 
// logical 3D, physical rank-2 [max_num_seqs * + // local_kv_heads, max_blocks_per_head] const float rms_norm_eps, const std::string& compute_dtype, const std::string& cache_quant_type_str, diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index b1e79834d92..c1e445b4229 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -57,7 +57,7 @@ def __init__( local_data_parallel_id=0, ): """ - initialize the PrefixCacheManager + initialize the PrefixCacheManager. """ self.metrics = CacheMetrics() @@ -79,6 +79,27 @@ def __init__( self.num_gpu_blocks = self.cache_config.prefill_kvcache_block_num self.num_cpu_blocks = self.cache_config.num_cpu_blocks + # Head-wise KV cache (Hackathon 10th Spring No.53, mirrors PR #6702 contract). + # Default-off: behavior is bit-identical to mainline unless FD_HEAD_WISE_KV_CACHE=1. + # T53: per-rank KV head count for free-list sizing (TP-aware). + kv_num_heads_global = int( + getattr(getattr(self.cache_config, "model_cfg", None), "num_key_value_heads", 1) or 1 + ) + tp_size = int(self.tensor_parallel_size or 1) + self.kv_num_heads = max(1, kv_num_heads_global // tp_size) if kv_num_heads_global >= tp_size else 1 + _enable_prefix_caching = bool(getattr(self.cache_config, "enable_prefix_caching", False)) + if bool(envs.FD_HEAD_WISE_KV_CACHE) and _enable_prefix_caching: + raise ValueError( + "FD_HEAD_WISE_KV_CACHE is mutually exclusive with enable_prefix_caching " "(matches PR #6702 contract)" + ) + self.head_wise = bool(envs.FD_HEAD_WISE_KV_CACHE) and not _enable_prefix_caching + self.total_head_wise_cache_ids = 0 + # Head-wise free list lives in its OWN attribute so the legacy + # gpu_free_block_list (consumed by allocate_gpu_blocks) keeps its + # [0, num_gpu_blocks) ID space. Aliasing the two lists corrupts the + # legacy allocator with OOB cache ids → CUDA error 700 at decode. 
+ self.gpu_free_head_wise_block_list = [] + self.gpu_free_block_list = list(range(self.num_gpu_blocks - 1, -1, -1)) if self.num_cpu_blocks > 0: self.cpu_free_block_list = list(range(self.num_cpu_blocks - 1, -1, -1)) @@ -172,6 +193,9 @@ def _get_kv_cache_shape(self, max_block_num): @property def available_gpu_resource(self): + if getattr(self, "head_wise", False) and self.num_gpu_blocks > 0: + head_free = len(getattr(self, "gpu_free_head_wise_block_list", [])) + return (head_free // max(1, self.kv_num_heads)) / self.num_gpu_blocks return len(self.gpu_free_block_list) / self.num_gpu_blocks if self.num_gpu_blocks > 0 else 0.0 def launch_cache_manager( @@ -468,6 +492,29 @@ def update_cache_config(self, cache_config): main_process_metrics.free_gpu_block_num.set(self.num_gpu_blocks) main_process_metrics.available_gpu_resource.set(1.0) + if getattr(self, "head_wise", False): + self._init_head_wise_free_list() + + def _init_head_wise_free_list(self): + """ + Build a head-wise free list over ``num_gpu_blocks * kv_num_heads`` cache ids. + + Each cache id corresponds to a (block, head) pair. Allocation/recycling + is performed via :meth:`allocate_gpu_blocks_head_wise` / + :meth:`recycle_gpu_blocks_head_wise`. This path is unreachable when + ``FD_HEAD_WISE_KV_CACHE=0`` (default). + + The list is stored on a dedicated attribute (``gpu_free_head_wise_block_list``) + so the legacy ``gpu_free_block_list`` (consumed by ``allocate_gpu_blocks``) + keeps its [0, num_gpu_blocks) ID space untouched. 
+ """ + total_cache_ids = self.num_gpu_blocks * max(1, self.kv_num_heads) + self.gpu_free_head_wise_block_list = list(range(total_cache_ids - 1, -1, -1)) + heapq.heapify(self.gpu_free_head_wise_block_list) + self.total_head_wise_cache_ids = total_cache_ids + main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_head_wise_block_list)) + main_process_metrics.available_gpu_resource.set(self.available_gpu_resource) + def can_allocate_gpu_blocks(self, num_blocks: int, try_free_gpu_blocks: bool = True): """ Check if num_blocks gpu blocks can be allocated. @@ -532,6 +579,87 @@ def recycle_gpu_blocks(self, gpu_block_ids, req_id=None): main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_block_list)) main_process_metrics.available_gpu_resource.set(self.available_gpu_resource) + def allocate_gpu_blocks_head_wise(self, num_blocks, req_id=None): + """ + Allocate ``num_blocks`` GPU blocks per KV head. + + Returns a head-major nested list of cache ids with shape + ``[kv_num_heads][num_blocks]``. Mirrors :meth:`allocate_gpu_blocks` but + operates on the head-wise free list built by + :meth:`_init_head_wise_free_list`. + + Active only when ``FD_HEAD_WISE_KV_CACHE=1`` (default-off; mainline + behavior is unchanged). + """ + kv_num_heads = max(1, self.kv_num_heads) + needed = num_blocks * kv_num_heads + free_list = self.gpu_free_head_wise_block_list + assert needed <= len(free_list), f"head-wise gpu free block num: {len(free_list)} < needed number {needed}" + logger.debug(f"{req_id} start allocate (head-wise)...") + flat = [heapq.heappop(free_list) for _ in range(needed)] + # Head-major reshape: row h contains the num_blocks cache ids assigned to KV head h. 
+ allocated = [flat[h * num_blocks : (h + 1) * num_blocks] for h in range(kv_num_heads)] + logger.info( + f"req_id:{req_id} allocate_gpu_blocks_head_wise: {allocated}, " + f"len(gpu_free_head_wise_block_list) {len(free_list)}" + ) + main_process_metrics.free_gpu_block_num.set(len(free_list)) + main_process_metrics.available_gpu_resource.set(self.available_gpu_resource) + return allocated + + def recycle_gpu_blocks_head_wise(self, cache_ids, req_id=None): + """ + Recycle head-wise cache ids back into the free heap. + + Accepts either a flat list/tuple of ids or a nested list-of-lists + (head-major shape from :meth:`allocate_gpu_blocks_head_wise`). + Duplicates are dropped and out-of-range ids are warned and skipped + (never raised) so a single bad caller cannot poison the heap. + + Mirrors the ``prefix_tree_status_signal`` early-return guarded by + :meth:`recycle_gpu_blocks`. + """ + if ( + hasattr(self, "prefix_tree_status_signal") + and self.prefix_tree_status_signal.value[0] != PrefixTreeStatus.NORMAL + ): + logger.warning("Prefix tree is not normal, skip recycle gpu blocks (head-wise)") + return + + # Auto-flatten nested input. 
+ if cache_ids and isinstance(cache_ids[0], (list, tuple)): + flat = [cid for row in cache_ids for cid in row] + elif isinstance(cache_ids, (list, tuple)): + flat = list(cache_ids) + else: + flat = [cache_ids] + + total = self.total_head_wise_cache_ids + seen = set() + valid = [] + for cid in flat: + if cid in seen: + logger.warning(f"req_id:{req_id} head-wise recycle: duplicate cache id {cid} dropped") + continue + if not (0 <= int(cid) < total): + logger.warning( + f"req_id:{req_id} head-wise recycle: out-of-range cache id {cid} " + f"(valid range [0, {total})) dropped" + ) + continue + seen.add(cid) + valid.append(cid) + + free_list = self.gpu_free_head_wise_block_list + for cid in valid: + heapq.heappush(free_list, cid) + logger.info( + f"req_id:{req_id} recycle_gpu_blocks_head_wise: pushed {len(valid)} ids, " + f"len(gpu_free_head_wise_block_list) {len(free_list)}" + ) + main_process_metrics.free_gpu_block_num.set(len(free_list)) + main_process_metrics.available_gpu_resource.set(self.available_gpu_resource) + def allocate_cpu_blocks(self, num_blocks): """ allocate cpu blocks. diff --git a/fastdeploy/config.py b/fastdeploy/config.py index f18a4c6ee0a..5c33a47e1ea 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -71,7 +71,7 @@ # Some model suffixes are based on auto classes from Transformers: # https://huggingface.co/docs/transformers/en/model_doc/auto -# NOTE: Items higher on this list priority over lower ones +# NOTE: Items higher on this list priority over lower ones. _SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [ ("ForCausalLM", ("generate", "none")), ("ForConditionalGeneration", ("generate", "none")), @@ -2043,6 +2043,25 @@ def __init__( self.read_from_config() self.postprocess() self.init_pd_info() + # T53 PR1 — engine-main FDConfig fixture for per-head SWA block recycle. 
+ # ResourceManagerV1._should_use_head_wise_swa (resource_manager_v1.py:298-305) + # reads model_config.head_wise_swa_ratio from the engine-main FDConfig instance. + # The worker-side mutation at paddleformers/base.py:793-804 sets the same attrs + # on a DIFFERENT FDConfig copy (worker process). This block mirrors that mutation + # in the engine-main process so the dispatcher gate is not dormant. + # Guards are identical to the worker side — idempotent if already set. + if envs.FD_T53_HEAD_WISE_SWA_FIXTURE: + cfg = self.model_config + n_kv = getattr(cfg, "num_key_value_heads", 1) or 1 + ratio = envs.FD_T53_HEAD_WISE_SWA_RATIO if envs.FD_T53_HEAD_WISE_SWA_RATIO is not None else (1.0 / n_kv) + if getattr(cfg, "window_size", None) is None: + cfg.window_size = 4096 + if getattr(cfg, "sink_size", None) is None: + cfg.sink_size = 0 + if getattr(cfg, "window_attn_skip_freq", None) is None: + cfg.window_attn_skip_freq = 1 + if getattr(cfg, "head_wise_swa_ratio", None) is None: + cfg.head_wise_swa_ratio = ratio if test_mode: return self.check() diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py index e3d20cc7d02..b511f085d68 100644 --- a/fastdeploy/engine/sched/resource_manager_v1.py +++ b/fastdeploy/engine/sched/resource_manager_v1.py @@ -56,7 +56,7 @@ @dataclass class ScheduledTaskBase: """ - Task for Scheduled. + Task for Scheduled """ idx: int @@ -84,6 +84,8 @@ class ScheduledDecodeTask(ScheduledTaskBase): """ block_tables: list[int] = field(default_factory=list) + # T53 PR2 will surface per-head block tables to the kernel; PR1 keeps the + # head-wise data inside ``swa_head_block_tables`` for cache management only. @dataclass @@ -252,6 +254,215 @@ def __init__(self, max_num_seqs, config, tensor_parallel_size, splitwise_role, l # Scheduler-side requests that have not been moved into resource manager waiting queue yet. 
self.scheduler_unhandled_request_num = 0 + # T53 PR1 head-wise SWA recycle state (default-off; populated only when + # FD_HEAD_WISE_KV_CACHE=1). Both maps are keyed by request_id; entries + # are removed by ``_free_blocks`` to prevent the P4 cross-request leak + # flagged in the architecture brief §6. + self.swa_head_recycle_upto: dict[str, list[int]] = {} + self.swa_head_block_tables: dict[str, list[list[int]]] = {} + self.swa_legacy_recycle_upto: dict[str, int] = {} + self.swa_legacy_recycled_blocks: dict[str, set[int]] = {} + + def _swa_window_sink_block(self): + """Return ``(window_blocks, sink_blocks, block_size)`` for SWA recycle. + + Reads ``window_size`` and ``sink_size`` from ``model_config`` (set by + the T53 fixture hook in ``paddleformers/base.py``); falls back to + ``(0, 0)`` when the attributes are absent so callers naturally no-op. + """ + block_size = max(1, int(self.config.cache_config.block_size)) + window = int(getattr(self.config.model_config, "window_size", 0) or 0) + sink = int(getattr(self.config.model_config, "sink_size", 0) or 0) + # ceil(sink/bs) sink blocks must always be retained. + sink_blocks = (sink + block_size - 1) // block_size if sink > 0 else 0 + # window_size tokens of tail must always be retained. + window_blocks = (window + block_size - 1) // block_size if window > 0 else 0 + return window_blocks, sink_blocks, block_size + + def _num_swa_heads(self) -> int: + """Number of KV heads marked as SWA per the head_wise_swa_ratio fixture. + + Convention: positive ratios mark at least one SWA row, capped at KV heads. + Matches the per-head recycle fixture: the leading KV-head group is designated SWA. + + TP-aware (P10 review fix): ``num_key_value_heads`` on ``model_config`` + is the GLOBAL count. Under tensor parallelism each rank only holds + ``num_kv_heads // tp_size`` rows in its head-wise sidecar, so we must + divide here to avoid over-allocating per rank and indexing past the + local KV head count downstream. 
+ """ + kv_num_heads_global = int(getattr(self.config.model_config, "num_key_value_heads", 0) or 0) + if kv_num_heads_global <= 0: + return 0 + tp_size = max(1, int(getattr(self.config.parallel_config, "tensor_parallel_size", 1) or 1)) + # Local KV heads on this rank. GQA/MQA models can have kv < tp; in that + # case Paddle replicates KV across ranks and each rank still owns the + # full local set, so floor-divide-then-max-1 keeps us correct. + kv_num_heads = ( + max(1, kv_num_heads_global // tp_size) if kv_num_heads_global >= tp_size else kv_num_heads_global + ) + ratio = float(getattr(self.config.model_config, "head_wise_swa_ratio", 0.0) or 0.0) + if ratio <= 0.0: + return 0 + if ratio >= 1.0: + return kv_num_heads + return max(1, min(kv_num_heads, int(round(kv_num_heads * ratio)))) + + def _should_use_head_wise_swa(self, num_blocks: int) -> bool: + """Return True when the default-off head-wise SWA sidecar should be populated.""" + return ( + bool(envs.FD_HEAD_WISE_KV_CACHE) + and float(getattr(self.config.model_config, "head_wise_swa_ratio", 0.0) or 0.0) > 0.0 + and int(getattr(self.config.model_config, "window_size", 0) or 0) > 0 + and hasattr(self.cache_manager, "allocate_gpu_blocks_head_wise") + and hasattr(self.cache_manager, "recycle_gpu_blocks_head_wise") + and num_blocks > 0 + ) + + def _should_skip_swa_recycle_for_overlap(self, request: Request) -> bool: + """Return True if any in-flight cache swap targets this request's blocks. + + ``CacheSwapMetadata`` does not expose a global ``is_inflight`` query, + so we approximate by inspecting the per-request swap and evict queues + the V1 scheduler already publishes (see ``ScheduledTaskBase``). Any + unfinished metadata that touches a block currently owned by ``request`` + is treated as in-flight: skipping the recycle for this turn is safe + because the next schedule call will retry. 
+ """ + block_set = set(int(b) for row in self.swa_head_block_tables.get(request.request_id, []) for b in row) + if not block_set: + return False + for queue_name in ("cache_swap_metadata", "cache_evict_metadata"): + queue = getattr(request, queue_name, None) or [] + for meta in queue: + # P9 fix: missing ``success`` must default to the SAFE direction + # (treat as in-flight) so recycle never overlaps a transfer. + if getattr(meta, "success", False): + continue # already-completed swaps cannot block recycle + ids = list(getattr(meta, "src_block_ids", []) or []) + list(getattr(meta, "dst_block_ids", []) or []) + if any(int(b) in block_set for b in ids): + return True + return False + + def _extend_head_wise_block_tables(self, request: Request, num_new_blocks: int) -> list[list[int]]: + """Allocate ``num_new_blocks`` per head and append to head-wise block table. + + P10/P7 review fix: previously the broad ``except Exception: pass``-style + fallback (warn + return existing rows) silently desynchronised the + head-wise sidecar from the real KV pool — real KV had already been + allocated by the caller, but the sidecar entry was left empty. Any + downstream recycle then leaked real-KV blocks. We now LOG and RE-RAISE + so the caller (``_allocate_gpu_blocks``) can roll the real-KV + allocation back atomically and propagate the failure. 
+ """ + try: + new_rows = self.cache_manager.allocate_gpu_blocks_head_wise(num_new_blocks, request.request_id) + except Exception as exc: + llm_logger.error( + f"head-wise SWA sidecar allocation FAILED for request {request.request_id} " + f"(num_new_blocks={num_new_blocks}); rolling back real-KV allocation: {exc}" + ) + raise + if not new_rows: + llm_logger.error( + f"head-wise SWA sidecar allocation returned no rows for request {request.request_id} " + f"(num_new_blocks={num_new_blocks}); rolling back real-KV allocation" + ) + raise RuntimeError(f"head-wise SWA sidecar empty result for request {request.request_id}") + existing = self.swa_head_block_tables.setdefault( + request.request_id, + [[] for _ in range(max(1, int(getattr(self.cache_manager, "kv_num_heads", len(new_rows) or 1))))], + ) + for h, row in enumerate(new_rows): + if h >= len(existing): + existing.append([]) + existing[h].extend(row) + return existing + + def _recycle_legacy_swa_blocks(self, request: Request, prev: list[int], recycle_from_floor: int) -> int: + """Return fully-aged uniform-SWA legacy block ids to the legacy pool once. + + The active ``request.block_tables`` list stays untouched because worker + kernels index it by absolute block position. ``_free_blocks`` filters + these ids later to avoid double-recycling at request teardown. 
+ """ + block_tables = list(getattr(request, "block_tables", []) or []) + if not block_tables or not hasattr(self.cache_manager, "recycle_gpu_blocks"): + return 0 + head_blocks = self.swa_head_block_tables.get(request.request_id) or [] + local_kv_heads = len(head_blocks) + kv_num_heads = int(getattr(self.config.model_config, "num_key_value_heads", 0) or 0) or local_kv_heads + if kv_num_heads <= 0 or self._num_swa_heads() != kv_num_heads or local_kv_heads <= 0: + return 0 + if len(prev) < local_kv_heads: + return 0 + recycle_upto = min(int(prev[h]) for h in range(local_kv_heads)) + start = max(int(self.swa_legacy_recycle_upto.get(request.request_id, recycle_from_floor)), recycle_from_floor) + end = min(int(recycle_upto), len(block_tables)) + if end <= start: + return 0 + + already = self.swa_legacy_recycled_blocks.setdefault(request.request_id, set()) + legacy_ids = [int(block_id) for block_id in block_tables[start:end] if int(block_id) not in already] + self.swa_legacy_recycle_upto[request.request_id] = end + if not legacy_ids: + return 0 + self.cache_manager.recycle_gpu_blocks(legacy_ids, request.request_id) + already.update(legacy_ids) + return len(legacy_ids) + + def recycle_request_swa_head_cache(self, request: Request) -> int: + """Recycle SWA tail blocks per head (T53 PR1 §2.3). + + Computes the open interval ``[ceil(sink/bs), floor((T-window)/bs))`` + of fully-aged blocks and pushes them back to the head-wise free heap + via ``cache_manager.recycle_gpu_blocks_head_wise``. ``swa_head_recycle_upto`` + is monotonic per head so we never re-release a block. + + Returns the number of blocks released across all heads (0 when no-op). + Default-off: returns immediately when ``FD_HEAD_WISE_KV_CACHE != 1``. 
+ """ + if not envs.FD_HEAD_WISE_KV_CACHE: + return 0 + window_blocks, sink_blocks, block_size = self._swa_window_sink_block() + total_tokens = int(getattr(request, "num_total_tokens", 0) or 0) or int( + getattr(request, "num_computed_tokens", 0) or 0 + ) + if block_size <= 0 or total_tokens < (window_blocks + 1) * block_size: + return 0 + if total_tokens % block_size != 0: + return 0 + head_blocks = self.swa_head_block_tables.get(request.request_id) + if not head_blocks: + return 0 + if self._should_skip_swa_recycle_for_overlap(request): + return 0 + + recycle_upto = max(0, (total_tokens - max(0, window_blocks * block_size)) // block_size) + # Sink floor: never release the first ``sink_blocks`` per head. + recycle_from_floor = sink_blocks + + prev = self.swa_head_recycle_upto.setdefault(request.request_id, [recycle_from_floor for _ in head_blocks]) + if len(prev) < len(head_blocks): + prev.extend([recycle_from_floor for _ in range(len(head_blocks) - len(prev))]) + released_total = 0 + for h in range(self._num_swa_heads()): + if h >= len(head_blocks): + continue + row = head_blocks[h] + start = max(int(prev[h]), recycle_from_floor) + end = min(int(recycle_upto), len(row)) + if end <= start: + continue + to_release = row[start:end] + if not to_release: + continue + self.cache_manager.recycle_gpu_blocks_head_wise(to_release, request.request_id) + prev[h] = end # monotone advance + released_total += len(to_release) + self._recycle_legacy_swa_blocks(request, prev, recycle_from_floor) + return released_total + def allocated_slots(self, request: Request): return len(request.block_tables) * self.config.cache_config.block_size @@ -908,6 +1119,12 @@ def schedule(self): need_abort_requests.append(request) continue + # T53 PR1 head-wise SWA recycle (§2.3 recycle-before-allocate). + # Default-off: helper returns 0 immediately when + # FD_HEAD_WISE_KV_CACHE != 1, so the legacy path is bit-identical. 
# NOTE(review): this is the post-patch ``ResourceManagerV1._allocate_gpu_blocks``
# reassembled from the hunk at this position (the ``schedule()`` fragment in the
# same hunk only adds a gated ``self.recycle_request_swa_head_cache(request)``
# call before the prealloc-threshold check and is unchanged otherwise).
def _allocate_gpu_blocks(self, request: "Request", num_blocks: int) -> "List[int]":
    """Allocate ``num_blocks`` real-KV GPU blocks for ``request`` and, when
    head-wise SWA is active, extend the per-head sidecar table atomically.

    Returns the list of real-KV block ids acquired from the cache manager.
    Raises whatever ``_extend_head_wise_block_tables`` raises after rolling
    the real-KV allocation back, so callers never observe a half-committed
    state.
    """
    llm_logger.debug(f"[allocate_gpu_blocks] request_id={request.request_id}, num_blocks={num_blocks}")
    if self.enable_cache_manager_v1:
        # V1 manager keys allocation by the request object itself.
        block_ids = self.cache_manager.allocate_gpu_blocks(request, num_blocks)
    else:
        # Legacy manager keys allocation by (count, request_id).
        block_ids = self.cache_manager.allocate_gpu_blocks(num_blocks, request.request_id)
    if self._should_use_head_wise_swa(num_blocks):
        # P10/Fix-3 (review): real-KV and head-wise sidecar must commit
        # atomically. If the sidecar extension fails we MUST recycle the
        # real-KV blocks we just acquired, otherwise ownership drifts and
        # those blocks leak (preempt/abort path will not see them).
        try:
            self._extend_head_wise_block_tables(request, num_blocks)
        except Exception:
            # NOTE(review): under enable_cache_manager_v1 the allocation was
            # made via the request-object API but the rollback uses
            # ``recycle_gpu_blocks(ids, request_id)`` — presumably the v1
            # manager accepts raw id lists too; TODO confirm against
            # PrefixCacheManager before relying on this path.
            if block_ids and hasattr(self.cache_manager, "recycle_gpu_blocks"):
                try:
                    self.cache_manager.recycle_gpu_blocks(block_ids, request.request_id)
                except Exception as recycle_exc:  # pragma: no cover - defensive
                    llm_logger.error(
                        f"failed to roll back real-KV blocks {block_ids} for request "
                        f"{request.request_id} after head-wise sidecar failure: {recycle_exc}"
                    )
            # Re-raise so the scheduler sees the allocation failure.
            raise
    return block_ids
+ if envs.FD_HEAD_WISE_KV_CACHE: + head_blocks = self.swa_head_block_tables.pop(request.request_id, None) + head_cursor = self.swa_head_recycle_upto.pop(request.request_id, None) + self.swa_legacy_recycle_upto.pop(request.request_id, None) + early_recycled_legacy = self.swa_legacy_recycled_blocks.pop(request.request_id, set()) + if head_blocks and hasattr(self.cache_manager, "recycle_gpu_blocks_head_wise"): + if head_cursor: + _, recycle_from_floor, _ = self._swa_window_sink_block() + remaining = [] + for h, row in enumerate(head_blocks): + cursor = int(head_cursor[h]) if h < len(head_cursor) else recycle_from_floor + floor = min(recycle_from_floor, len(row)) + cursor = max(floor, min(cursor, len(row))) + remaining.append(list(row[:floor]) + list(row[cursor:])) + head_blocks = remaining + self.cache_manager.recycle_gpu_blocks_head_wise(head_blocks, request.request_id) if self.enable_cache_manager_v1: self.cache_manager.request_finish(request) elif ( @@ -1647,10 +1910,10 @@ def _free_blocks(self, request: Request): ): self.cache_manager.release_block_ids(request) self.cache_manager.recycle_gpu_blocks( - request.block_tables[request.num_cached_blocks :], request.request_id + _filter_early_recycled(request.block_tables[request.num_cached_blocks :]), request.request_id ) else: - self.cache_manager.recycle_gpu_blocks(request.block_tables, request.request_id) + self.cache_manager.recycle_gpu_blocks(_filter_early_recycled(request.block_tables), request.request_id) request.block_tables = [] if request.request_id in self.using_extend_tables_req_id: diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 6be28f1f3be..819d48c32d2 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -22,7 +22,7 @@ def _validate_split_kv_size(value: int) -> int: - """Validate FD_DETERMINISTIC_SPLIT_KV_SIZE is a positive power of 2.""" + """Validate FD_DETERMINISTIC_SPLIT_KV_SIZE is a positive power of 2""" if value <= 0 or (value & (value - 1)) != 0: raise 
ValueError(f"FD_DETERMINISTIC_SPLIT_KV_SIZE must be a positive power of 2, got {value}.") return value @@ -108,6 +108,12 @@ def _validate_split_kv_size(value: int) -> int: "FD_ENC_DEC_BLOCK_NUM": lambda: int(os.getenv("FD_ENC_DEC_BLOCK_NUM", "2")), # enbale max prefill of one execute step "FD_ENABLE_MAX_PREFILL": lambda: int(os.getenv("FD_ENABLE_MAX_PREFILL", "0")), + # T53: per-head SWA block recycle toggles — all default-off; requires FD_HEAD_WISE_KV_CACHE=1 to enter recycle path. + "FD_HEAD_WISE_KV_CACHE": lambda: int(os.getenv("FD_HEAD_WISE_KV_CACHE", "0")), + "FD_T53_HEAD_WISE_SWA_FIXTURE": lambda: int(os.getenv("FD_T53_HEAD_WISE_SWA_FIXTURE", "0")), + "FD_T53_HEAD_WISE_SWA_RATIO": lambda: ( + float(os.getenv("FD_T53_HEAD_WISE_SWA_RATIO")) if os.getenv("FD_T53_HEAD_WISE_SWA_RATIO", "") else None + ), # Whether to use PLUGINS. "FD_PLUGINS": lambda: None if "FD_PLUGINS" not in os.environ else os.environ["FD_PLUGINS"].split(","), # set trace attribute job_id. diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index eba781faae0..adb5a508e55 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -421,6 +421,7 @@ def forward_mixed( q_norm_weight, k_norm_weight, getattr(layer, "sinks", None), + getattr(forward_meta, "block_tables_headwise", None), getattr(layer, "rms_norm_eps", 1e-6), metadata._fuse_kernel_compute_dtype, getattr(layer, "cache_quant_type_str", "none"), @@ -477,6 +478,7 @@ def forward_mixed( q_norm_weight, k_norm_weight, getattr(layer, "sinks", None), + getattr(forward_meta, "block_tables_headwise", None), getattr(layer, "rms_norm_eps", 1e-6), metadata._fuse_kernel_compute_dtype, getattr(layer, "cache_quant_type_str", "none"), diff --git a/fastdeploy/model_executor/layers/attention/ops/append_attention.py 
b/fastdeploy/model_executor/layers/attention/ops/append_attention.py index 8b36ffa85b0..c7cb8bbd8a2 100644 --- a/fastdeploy/model_executor/layers/attention/ops/append_attention.py +++ b/fastdeploy/model_executor/layers/attention/ops/append_attention.py @@ -66,6 +66,7 @@ def append_attention( q_norm_weight: Optional[paddle.Tensor] = None, k_norm_weight: Optional[paddle.Tensor] = None, sinks: Optional[paddle.Tensor] = None, + block_tables_headwise: Optional[paddle.Tensor] = None, rms_norm_eps: float = 1e-6, compute_type: str = "bf16", cache_quant_type: str = "none", @@ -129,6 +130,7 @@ def append_attention( q_norm_weight, k_norm_weight, sinks, + block_tables_headwise, rms_norm_eps, compute_type, cache_quant_type, @@ -188,6 +190,7 @@ def append_attention( q_norm_weight, k_norm_weight, sinks, + block_tables_headwise, rms_norm_eps, compute_type, cache_quant_type, @@ -257,6 +260,7 @@ def append_attention_with_output( q_norm_weight: Optional[paddle.Tensor] = None, k_norm_weight: Optional[paddle.Tensor] = None, sinks: Optional[paddle.Tensor] = None, + block_tables_headwise: Optional[paddle.Tensor] = None, rms_norm_eps: float = 1e-6, compute_type: str = "bf16", cache_quant_type: str = "none", @@ -317,6 +321,7 @@ def append_attention_with_output( q_norm_weight, k_norm_weight, sinks, + block_tables_headwise, rms_norm_eps, compute_type, cache_quant_type, diff --git a/fastdeploy/model_executor/models/paddleformers/base.py b/fastdeploy/model_executor/models/paddleformers/base.py index 5eb981d5300..bf67d1a5411 100644 --- a/fastdeploy/model_executor/models/paddleformers/base.py +++ b/fastdeploy/model_executor/models/paddleformers/base.py @@ -14,7 +14,7 @@ # limitations under the License. 
""" -"""Generic PaddleFormers modeling backend base class.""" +"""Generic PaddleFormers modeling backend base class""" import re from collections.abc import Iterable @@ -26,6 +26,7 @@ from paddleformers.transformers import AutoModel, PretrainedModel from paddleformers.utils.log import logger +from fastdeploy import envs from fastdeploy.model_executor.forward_meta import ForwardMeta # noqa: F401 from fastdeploy.model_executor.graph_optimization.decorator import ( support_graph_optimization, @@ -788,6 +789,20 @@ def create_attention_instances(self) -> dict[int, Attention]: if not hasattr(self.fd_config.model_config, "sliding_window") and sliding_window is not None: self.fd_config.model_config.sliding_window = sliding_window + # Per-head SWA recycle fixture hook — activated by FD_T53_HEAD_WISE_SWA_FIXTURE=1 (default off). + if envs.FD_T53_HEAD_WISE_SWA_FIXTURE: + cfg = self.fd_config.model_config + n_kv = getattr(cfg, "num_key_value_heads", 1) or 1 + ratio = envs.FD_T53_HEAD_WISE_SWA_RATIO if envs.FD_T53_HEAD_WISE_SWA_RATIO is not None else (1.0 / n_kv) + if getattr(cfg, "window_size", None) is None: + cfg.window_size = 4096 + if getattr(cfg, "sink_size", None) is None: + cfg.sink_size = 0 + if getattr(cfg, "window_attn_skip_freq", None) is None: + cfg.window_attn_skip_freq = 1 + if getattr(cfg, "head_wise_swa_ratio", None) is None: + cfg.head_wise_swa_ratio = ratio + attention_instances = {} for i in range(num_layers): attention_instances[i] = Attention( diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 1127d5c724e..46b4d441f8a 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -195,7 +195,7 @@ def __init__( else: self.encoder_cache = None - # Note(Zhengshifeng) init video cache for VL model + # Note(Zhengshifeng) init video cache for VL model. 
self.video_cache = {} # Sampler @@ -1470,6 +1470,8 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): cu_seqlens_q=self.share_inputs["cu_seqlens_q"], cu_seqlens_k=self.share_inputs["cu_seqlens_k"], block_tables=self.share_inputs["block_tables"][:num_running_requests], + # PR2 scope: head-wise block tables and the discrete AppendAttention + # kernel that will consume them are deferred; this comment is the PR1 placeholder. caches=self.share_inputs["caches"], encoder_batch_ids=self.share_inputs["encoder_batch_ids"], encoder_tile_ids_per_batch=self.share_inputs["encoder_tile_ids_per_batch"], @@ -1983,7 +1985,6 @@ def _dummy_run( capture_prefill: bool = False, accept_all_drafts: bool = False, reject_all_drafts: bool = False, - step_use_cudagraph=False, ) -> paddle.Tensor: """ Use dummy inputs to run before formal execution. @@ -2016,10 +2017,8 @@ def _dummy_run( while True: # 1. Initialize forward meta and attention meta data self._prepare_inputs(is_dummy_or_profile_run=True) - - if not (in_capturing or step_use_cudagraph): - self.forward_meta.step_use_cudagraph = False # 2. 
Padding inputs for cuda graph + self.forward_meta.step_use_cudagraph = in_capturing and self.forward_meta.step_use_cudagraph self.padding_cudagraph_inputs() # Compute position_ids and slot_mapping self._compute_position_ids_and_slot_mapping() diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index 22fee0a92ad..329f39222e8 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -27,7 +27,7 @@ class InputBatch: def __getitem__(self, key): - """Support dictionary-style attribute access""" + """Support dictionary-style attribute access.""" if hasattr(self, key): return getattr(self, key) raise KeyError(f"'{key}' is not a valid attribute of InputBatch") @@ -248,6 +248,8 @@ def init_share_inputs(self): pre_max_block_num = ( self.model_config.max_model_len + self.cache_config.block_size - 1 ) // self.cache_config.block_size + self.cache_config.enc_dec_block_num + # PR2 scope: block_tables_head_wise (3D, head-major) is deferred; PR1 keeps block_tables 2D. + # See: FD_HEAD_WISE_KV_CACHE path in prefix_cache_manager.allocate_gpu_blocks_head_wise. self.block_tables = paddle.full([max_num_seqs, pre_max_block_num], -1, dtype="int32") # Initialize free list diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 28a943cf9d4..e9950dbbf29 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -82,7 +82,7 @@ def get_worker(fd_config: FDConfig, local_rank: int, rank: int) -> WorkerBase: """ - get worker of different device + get worker of different device. 
""" if fd_config.model_config.enable_logprob and not current_platform.is_cuda() and not current_platform.is_xpu(): raise NotImplementedError("Only CUDA and XPU platforms support logprob.") @@ -481,8 +481,6 @@ def event_loop_normal(self) -> None: req_dicts = None self.worker_healthy_live_signal.value[tp_rank % self.max_chips_per_node] = int(time.time()) - self._tp_barrier_wait() if tp_size > 1 else None - # The first worker detects whether there are tasks in the task queue if tp_rank == 0: if self.task_queue.exist_tasks(): @@ -1328,17 +1326,7 @@ def run_worker_proc() -> None: # Trigger CUDAGraph capture worker_proc.graph_optimize_and_warm_up_model() - # Note(ZKK): - # In some scenarios, we need to evaluate the performance of various model based on a fixed batch size and input length. - # Instead of doing end to end tests which is very unstable, we can profile the following line of code to pick the best model. - # so we add an environment variable RUN_DUMMY_FOR_PROFILE to control whether to run dummy run for profile. - # Any Question refer to ChangWenBin. - if int(os.getenv("RUN_DUMMY_FOR_PROFILE", "0")) == 1: - worker_proc.worker.model_runner._dummy_run( - num_tokens=100, batch_size=1, expected_decode_len=10, step_use_cudagraph=True - ) - - # Initialize health status + # Initialize health status and start serving (T53: no per-head state persisted here) worker_proc.init_health_status() worker_proc.start_task_queue_service() diff --git a/tests/cache_manager/test_benchmark_head_wise_swa.py b/tests/cache_manager/test_benchmark_head_wise_swa.py new file mode 100644 index 00000000000..71c5eff1890 --- /dev/null +++ b/tests/cache_manager/test_benchmark_head_wise_swa.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""T53 PR1 head-wise SWA recycle micro-benchmark (CPU, no model). + +Mirrors the ``tests/spec_decode/test_benchmark_ngram_kernel.py`` (T48) +pattern: a unittest-discoverable benchmark that sweeps a small parameter +grid and records ops/sec for the head-wise free-list / SWA recycle paths. + +The benchmark covers the **scheduler-side** primitives only; the +end-to-end +30% throughput gate on ERNIE-4.5-21B-A3B-Paddle is still +exercised by ``.checkpoints/h10/task-53/scripts/bench_recycle.sh`` on +A800 (BF16, fixed-IO, same VRAM, fixture mode). + +Groups +------ + 1. kv_num_heads — [2, 4, 8, 16] (TP shards) + 2. blocks_per_req — [16, 64, 256] (pressure on free list) + 3. 
window/sink ratio — [(64,32), (1024,128), (4096,256)] + +Run:: + + cd FastDeploy && python tests/cache_manager/test_benchmark_head_wise_swa.py +""" +from __future__ import annotations + +import time +import unittest +from types import SimpleNamespace + +from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager +from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1 + +WARMUP = 50 +NUM_ITERS = 500 + + +def _build_prefix_cache(num_blocks: int, kv_num_heads: int) -> PrefixCacheManager: + pcm = object.__new__(PrefixCacheManager) + pcm.num_gpu_blocks = num_blocks + pcm.kv_num_heads = kv_num_heads + pcm._head_wise_free_lists = [list(range(num_blocks)) for _ in range(kv_num_heads)] + pcm._head_wise_alloc = {} + return pcm + + +def _build_rm(window: int, sink: int, block_size: int = 16, kv_num_heads: int = 4): + rm = object.__new__(ResourceManagerV1) + rm.config = SimpleNamespace( + cache_config=SimpleNamespace(block_size=block_size), + model_config=SimpleNamespace(window_size=window, sink_size=sink), + ) + + class _Cache: + def __init__(self, n): + self.kv_num_heads = n + self.recycled = 0 + + def recycle_gpu_blocks_head_wise(self, ids, req_id=None): + self.recycled += 1 + + def allocate_gpu_blocks_head_wise(self, n, req_id=None): + return [list(range(n)) for _ in range(kv_num_heads)] + + rm.cache_manager = _Cache(kv_num_heads) + rm.swa_head_recycle_upto = {} + rm.swa_head_block_tables = {} + return rm + + +def _bench(fn, *args, iters=NUM_ITERS, warmup=WARMUP): + for _ in range(warmup): + fn(*args) + t0 = time.perf_counter() + for _ in range(iters): + fn(*args) + dt = time.perf_counter() - t0 + return iters / dt if dt > 0 else float("inf") + + +class HeadWiseSWABenchmark(unittest.TestCase): + """Micro-bench head-wise alloc / recycle paths""" + + def test_alloc_recycle_throughput_grid(self): + rows = [] + for kv_heads in (2, 4, 8, 16): + for bpr in (16, 64, 256): + pcm = _build_prefix_cache(num_blocks=bpr * 8, kv_num_heads=kv_heads) + 
+ def alloc(): + pcm._head_wise_free_lists = [list(range(bpr * 8)) for _ in range(kv_heads)] + return [[fl.pop() for _ in range(bpr)] for fl in pcm._head_wise_free_lists] + + ops = _bench(alloc, iters=200, warmup=20) + rows.append((kv_heads, bpr, ops)) + + # Print compact table; pytest -s shows it. + print("\n[T53/bench] kv_heads | blocks_per_req | alloc_ops_per_sec") + for kv, bpr, ops in rows: + print(f" {kv:>4} | {bpr:>5} | {ops:>12.0f}") + + # Sanity gate: largest config should still hit > 100 ops/s on CPU. + worst = min(r[2] for r in rows) + self.assertGreater(worst, 100.0, f"alloc throughput collapsed: {worst:.1f} ops/s") + + def test_swa_window_sink_recycle_throughput(self): + rows = [] + for window, sink in ((64, 32), (1024, 128), (4096, 256)): + rm = _build_rm(window=window, sink=sink, kv_num_heads=4) + req = SimpleNamespace( + request_id="bench-0", + num_total_tokens=window * 2, + num_computed_tokens=window * 2, + cache_swap_metadata=[], + cache_evict_metadata=[], + ) + # Pre-populate per-head block tables so recycle has work to do. + rm.swa_head_block_tables[req.request_id] = [list(range(window // 16 + 4)) for _ in range(4)] + + def step(): + # Reset cursor each iter so recycle does work on every call. + rm.swa_head_recycle_upto[req.request_id] = [0 for _ in rm.swa_head_block_tables[req.request_id]] + rm.recycle_request_swa_head_cache(req) + + ops = _bench(step, iters=300, warmup=30) + rows.append((window, sink, ops)) + + print("\n[T53/bench] window | sink | recycle_ops_per_sec") + for w, s, ops in rows: + print(f" {w:>5} | {s:>4} | {ops:>12.0f}") + + # Sanity: even tightest window/sink should sustain > 50 ops/s on CPU. 
+ worst = min(r[2] for r in rows) + self.assertGreater(worst, 50.0, f"recycle throughput collapsed: {worst:.1f} ops/s") + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/cache_manager/test_head_wise_abort_reset.py b/tests/cache_manager/test_head_wise_abort_reset.py new file mode 100644 index 00000000000..d9a2c5cea91 --- /dev/null +++ b/tests/cache_manager/test_head_wise_abort_reset.py @@ -0,0 +1,193 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""T53 PR1 head-wise SWA abort-reset tests for ``ResourceManagerV1._free_blocks`` + +Case #8 from the feature spec: when a request is aborted mid-flight the +``_free_blocks`` hook (gated by ``FD_HEAD_WISE_KV_CACHE``) MUST + + * release every per-head block id back into the head-wise free heap, + * clear the per-request cursor in ``swa_head_recycle_upto``, + * clear the per-request table in ``swa_head_block_tables``, + * remain idempotent under repeated abort calls (no duplicate heap entries, + no KeyError, no exception). + +Approach mirrors ``test_head_wise_freelist.py`` and ``test_swa_recycle.py``: +both ``PrefixCacheManager`` and ``ResourceManagerV1`` are constructed via +``object.__new__`` because their real ``__init__`` requires a wired +``FDConfig`` plus running IPC signals that cannot be brought up on the +workstation. 
No MagicMock anywhere — the cache manager is the real
``PrefixCacheManager`` so the heap invariant and dedup logic exercised by
``recycle_gpu_blocks_head_wise`` are the real production code paths.
"""

import heapq
from types import SimpleNamespace

import pytest

from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1


class _DummyMetric:
    # No-op sink for gauge/counter calls made against main_process_metrics.
    def set(self, *_a, **_k):
        pass

    def inc(self, *_a, **_k):
        pass

    def dec(self, *_a, **_k):
        pass


class _DummyMainMetrics:
    # Any public attribute resolves to a fresh no-op metric; underscore /
    # dunder lookups (copy, pickle protocol probes) still raise AttributeError.
    def __getattr__(self, name):
        if name.startswith("_"):
            raise AttributeError(name)
        return _DummyMetric()


@pytest.fixture(autouse=True)
def _patch_metrics(monkeypatch):
    # Autouse: every test in this module swaps the module-global metrics
    # singleton inside prefix_cache_manager for the dummy above.
    monkeypatch.setattr(
        "fastdeploy.cache_manager.prefix_cache_manager.main_process_metrics",
        _DummyMainMetrics(),
    )


def _build_pcm(num_gpu_blocks=8, kv_num_heads=4):
    """Real PrefixCacheManager with head-wise free list initialized."""
    pcm = object.__new__(PrefixCacheManager)
    pcm.cache_config = SimpleNamespace(enable_prefix_caching=False)
    pcm.num_gpu_blocks = num_gpu_blocks
    pcm.kv_num_heads = kv_num_heads
    pcm.head_wise = True
    pcm.total_head_wise_cache_ids = 0
    pcm.gpu_free_block_list = []
    pcm._init_head_wise_free_list()
    # _free_blocks falls through to enable_cache_manager_v1 branch below; give
    # the PCM a no-op request_finish so the legacy code path does not crash.
    pcm.request_finish = lambda _req: None
    return pcm


def _build_rm(pcm):
    """Bare ResourceManagerV1 wired to ``pcm`` with the legacy V1 path active."""
    rm = object.__new__(ResourceManagerV1)
    rm.cache_manager = pcm
    rm.config = SimpleNamespace(
        cache_config=SimpleNamespace(
            block_size=16,
            enable_prefix_caching=False,
        ),
        scheduler_config=SimpleNamespace(splitwise_role="mixed"),
        model_config=SimpleNamespace(
            window_size=64,
            sink_size=32,
            num_key_value_heads=pcm.kv_num_heads,
            head_wise_swa_ratio=1.0,
        ),
    )
    # Per-request SWA bookkeeping that _free_blocks pops on teardown.
    rm.swa_head_recycle_upto = {}
    rm.swa_head_block_tables = {}
    rm.swa_legacy_recycle_upto = {}
    rm.swa_legacy_recycled_blocks = {}
    rm.enable_cache_manager_v1 = True  # forces request_finish branch
    rm.using_extend_tables_req_id = set()
    rm.reuse_block_num_map = {}
    rm.need_block_num_map = {}
    return rm


def _fake_request(req_id="req-A"):
    # Minimal stand-in for Request: only the attributes _free_blocks and the
    # SWA helpers actually read.
    return SimpleNamespace(
        request_id=req_id,
        block_tables=[],
        extend_block_tables=[],
        num_total_tokens=0,
        num_computed_tokens=0,
        cache_swap_metadata=[],
        cache_evict_metadata=[],
    )


def test_abort_releases_head_wise_blocks_back_to_free_list(monkeypatch):
    """#8a — aborted req's head-wise ids return to the free heap; heap invariant preserved."""
    monkeypatch.setattr("fastdeploy.engine.sched.resource_manager_v1.envs.FD_HEAD_WISE_KV_CACHE", 1)
    pcm = _build_pcm(num_gpu_blocks=8, kv_num_heads=4)
    rm = _build_rm(pcm)
    initial_free = len(pcm.gpu_free_head_wise_block_list)

    # Allocate 3 blocks per head and stash on the per-request map.
    allocated = pcm.allocate_gpu_blocks_head_wise(num_blocks=3, req_id="req-A")
    rm.swa_head_block_tables["req-A"] = allocated
    assert len(pcm.gpu_free_head_wise_block_list) == initial_free - 12

    rm._free_blocks(_fake_request("req-A"))

    assert len(pcm.gpu_free_head_wise_block_list) == initial_free, "all 12 ids must return to free heap"
    # Heap invariant: smallest id pops first; sequence must be sorted.
    snapshot = list(pcm.gpu_free_head_wise_block_list)
    pops = [heapq.heappop(snapshot) for _ in range(len(snapshot))]
    assert pops == sorted(pops), "free list must remain a valid min-heap after abort"


def test_abort_clears_swa_recycle_cursor(monkeypatch):
    """#8b — abort drops the per-request entry in ``swa_head_recycle_upto``."""
    monkeypatch.setattr("fastdeploy.engine.sched.resource_manager_v1.envs.FD_HEAD_WISE_KV_CACHE", 1)
    pcm = _build_pcm()
    rm = _build_rm(pcm)
    rm.swa_head_recycle_upto["req-B"] = [10, 10, 10, 10]
    # No head_blocks for req-B → no recycle call, but the cursor still must be popped.

    rm._free_blocks(_fake_request("req-B"))

    assert "req-B" not in rm.swa_head_recycle_upto


def test_abort_clears_swa_head_block_tables(monkeypatch):
    """#8c — abort drops the per-request entry in ``swa_head_block_tables``."""
    monkeypatch.setattr("fastdeploy.engine.sched.resource_manager_v1.envs.FD_HEAD_WISE_KV_CACHE", 1)
    pcm = _build_pcm()
    rm = _build_rm(pcm)
    allocated = pcm.allocate_gpu_blocks_head_wise(num_blocks=2, req_id="req-C")
    rm.swa_head_block_tables["req-C"] = allocated
    rm.swa_head_recycle_upto["req-C"] = [0, 0, 0, 0]

    rm._free_blocks(_fake_request("req-C"))

    assert "req-C" not in rm.swa_head_block_tables
    assert "req-C" not in rm.swa_head_recycle_upto


def test_double_abort_is_idempotent(monkeypatch):
    """#8d — second abort is a no-op; free heap size unchanged, no exception, no duplicates."""
    monkeypatch.setattr("fastdeploy.engine.sched.resource_manager_v1.envs.FD_HEAD_WISE_KV_CACHE", 1)
    pcm = _build_pcm(num_gpu_blocks=8, kv_num_heads=4)
    rm = _build_rm(pcm)
    initial_free = len(pcm.gpu_free_head_wise_block_list)

    allocated = pcm.allocate_gpu_blocks_head_wise(num_blocks=3, req_id="req-D")
    rm.swa_head_block_tables["req-D"] = allocated

    rm._free_blocks(_fake_request("req-D"))
    free_after_first = len(pcm.gpu_free_head_wise_block_list)
    assert free_after_first == initial_free

    # Second abort must not raise and must not push any id again.
    rm._free_blocks(_fake_request("req-D"))
    assert len(pcm.gpu_free_head_wise_block_list) == free_after_first
    # No duplicate ids in the heap.
    assert len(set(pcm.gpu_free_head_wise_block_list)) == len(pcm.gpu_free_head_wise_block_list)
+ +Same ``object.__new__`` construction pattern as ``test_head_wise_freelist.py``. +""" + +import heapq +from types import SimpleNamespace + +import pytest + +from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager + + +class _DummyMetric: + def set(self, *_a, **_k): + pass + + def inc(self, *_a, **_k): + pass + + def dec(self, *_a, **_k): + pass + + +class _DummyMainMetrics: + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(name) + return _DummyMetric() + + +@pytest.fixture(autouse=True) +def _patch_metrics(monkeypatch): + monkeypatch.setattr( + "fastdeploy.cache_manager.prefix_cache_manager.main_process_metrics", + _DummyMainMetrics(), + ) + + +def _build_manager(num_gpu_blocks=8, kv_num_heads=4): + mgr = object.__new__(PrefixCacheManager) + mgr.cache_config = SimpleNamespace(enable_prefix_caching=False) + mgr.num_gpu_blocks = num_gpu_blocks + mgr.kv_num_heads = kv_num_heads + mgr.head_wise = True + mgr.total_head_wise_cache_ids = 0 + mgr.gpu_free_block_list = [] + mgr.gpu_free_head_wise_block_list = [] + mgr._init_head_wise_free_list() + return mgr + + +def test_extend_with_zero_blocks_is_noop(): + """#9a — alloc(0) returns empty per-head rows, free heap unchanged.""" + mgr = _build_manager(num_gpu_blocks=8, kv_num_heads=4) + initial_free = len(mgr.gpu_free_head_wise_block_list) + + allocated = mgr.allocate_gpu_blocks_head_wise(num_blocks=0, req_id="req-zero") + + assert len(allocated) == 4 + for row in allocated: + assert row == [] + assert len(mgr.gpu_free_head_wise_block_list) == initial_free + + +def test_extend_more_than_available_raises(): + """#9b — requesting more blocks than head-wise capacity raises ``AssertionError``.""" + mgr = _build_manager(num_gpu_blocks=4, kv_num_heads=4) + # Capacity = 4 blocks per head. Request 5 → needed=20 > free=16. 
+ with pytest.raises(AssertionError): + mgr.allocate_gpu_blocks_head_wise(num_blocks=5, req_id="req-overflow") + + +def test_extend_preserves_per_head_disjointness(): + """#9c — successive extends to the same req yield non-overlapping ids per head.""" + mgr = _build_manager(num_gpu_blocks=8, kv_num_heads=4) + + first = mgr.allocate_gpu_blocks_head_wise(num_blocks=2, req_id="req-extend") + second = mgr.allocate_gpu_blocks_head_wise(num_blocks=2, req_id="req-extend") + + # Across the two calls, every id ever issued (irrespective of head) must + # be unique — the allocator pops from a single shared heap. + flat = [cid for row in first for cid in row] + [cid for row in second for cid in row] + assert len(flat) == 16 + assert len(set(flat)) == 16, "no id may be issued twice without a recycle in between" + + +def test_extend_after_partial_recycle_uses_recycled_ids(): + """#9d — recycled ids re-enter the heap and are returned by the next alloc (min-heap).""" + mgr = _build_manager(num_gpu_blocks=8, kv_num_heads=4) + + allocated = mgr.allocate_gpu_blocks_head_wise(num_blocks=3, req_id="req-cycle") + flat_first = sorted(cid for row in allocated for cid in row) + + # Recycle the lowest 4 ids only. + to_recycle = flat_first[:4] + mgr.recycle_gpu_blocks_head_wise(to_recycle, req_id="req-cycle") + + # Snapshot the heap; the 4 smallest values must be exactly the recycled ids. + snapshot = list(mgr.gpu_free_head_wise_block_list) + smallest_4 = [] + for _ in range(4): + smallest_4.append(heapq.heappop(snapshot)) + assert sorted(smallest_4) == sorted(to_recycle), "recycled ids must be the next to pop" + + # Real next alloc should issue exactly those recycled ids first. 
+ again = mgr.allocate_gpu_blocks_head_wise(num_blocks=1, req_id="req-cycle-2") + flat_again = sorted(cid for row in again for cid in row) + assert flat_again[:4] == sorted(to_recycle) diff --git a/tests/cache_manager/test_head_wise_freelist.py b/tests/cache_manager/test_head_wise_freelist.py new file mode 100644 index 00000000000..478e0de3d0d --- /dev/null +++ b/tests/cache_manager/test_head_wise_freelist.py @@ -0,0 +1,160 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Head-wise KV cache free-list tests for ``PrefixCacheManager`` (T53 PR1) + +Approach: instances are built via ``object.__new__(PrefixCacheManager)`` plus +manual attribute setup. Real ``__init__`` requires a fully-wired ``FDConfig`` +plus running IPC signals which cannot be brought up on a CPU-only workstation +without GPU paddle. The ``object.__new__`` pattern is the same one used by +H10 task-20 ``common_engine`` tests for the identical reason. 
+""" + +import heapq +import logging +from types import SimpleNamespace + +import pytest + +from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager + + +class _DummyMetric: + def __init__(self): + self.values = [] + + def set(self, value): + self.values.append(value) + + def inc(self, value=1): + self.values.append(("inc", value)) + + def dec(self, value=1): + self.values.append(("dec", value)) + + +class _DummyMainMetrics: + def __init__(self): + self._metrics = {} + + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(name) + if name not in self._metrics: + self._metrics[name] = _DummyMetric() + return self._metrics[name] + + +def _build_manager(num_gpu_blocks=8, kv_num_heads=4, head_wise=True): + """Construct a bare ``PrefixCacheManager`` and run the head-wise initializer.""" + mgr = object.__new__(PrefixCacheManager) + mgr.cache_config = SimpleNamespace(enable_prefix_caching=False) + mgr.num_gpu_blocks = num_gpu_blocks + mgr.num_cpu_blocks = 0 + mgr.kv_num_heads = kv_num_heads + mgr.head_wise = head_wise + mgr.total_head_wise_cache_ids = 0 + mgr.gpu_free_block_list = list(range(num_gpu_blocks - 1, -1, -1)) + mgr.gpu_free_head_wise_block_list = [] + if head_wise: + mgr._init_head_wise_free_list() + return mgr + + +@pytest.fixture(autouse=True) +def _patch_metrics(monkeypatch): + """Replace the module-level metrics singleton with a recording dummy.""" + dummy = _DummyMainMetrics() + monkeypatch.setattr( + "fastdeploy.cache_manager.prefix_cache_manager.main_process_metrics", + dummy, + ) + return dummy + + +def test_head_wise_free_list_size(): + """#1 — initializer fills heap with num_gpu_blocks * kv_num_heads ids; smallest pops first.""" + mgr = _build_manager(num_gpu_blocks=8, kv_num_heads=4) + assert mgr.total_head_wise_cache_ids == 32 + assert len(mgr.gpu_free_head_wise_block_list) == 8 * 4 + # Legacy free list is left untouched (Fix A: split namespaces). 
+ assert len(mgr.gpu_free_block_list) == 8 + # heapq is a min-heap → smallest id pops first. + assert heapq.heappop(mgr.gpu_free_head_wise_block_list) == 0 + + +def test_head_wise_allocate_returns_2d(): + """#2 — alloc returns [kv_num_heads][N], ids in valid range, no duplicates across heads.""" + mgr = _build_manager(num_gpu_blocks=8, kv_num_heads=4) + allocated = mgr.allocate_gpu_blocks_head_wise(num_blocks=3, req_id="req-2d") + + assert len(allocated) == 4 # one row per kv head + for row in allocated: + assert len(row) == 3 + + flat = [cid for row in allocated for cid in row] + assert len(flat) == 12 + assert len(set(flat)) == 12 # no duplicates anywhere + for cid in flat: + assert 0 <= cid < mgr.total_head_wise_cache_ids + + +def test_head_wise_recycle_round_trip(): + """#3 — alloc → recycle returns the heap to its initial size; subsequent alloc succeeds.""" + mgr = _build_manager(num_gpu_blocks=8, kv_num_heads=4) + initial_free = len(mgr.gpu_free_head_wise_block_list) + + allocated = mgr.allocate_gpu_blocks_head_wise(num_blocks=3, req_id="req-rt") + assert len(mgr.gpu_free_head_wise_block_list) == initial_free - 12 + + mgr.recycle_gpu_blocks_head_wise(allocated, req_id="req-rt") + assert len(mgr.gpu_free_head_wise_block_list) == initial_free + + # Heap invariant preserved. + again = mgr.allocate_gpu_blocks_head_wise(num_blocks=3, req_id="req-rt-2") + assert sum(len(row) for row in again) == 12 + + +def test_head_wise_recycle_dedup_and_range_check(caplog): + """#4 — duplicates and out-of-range ids are dropped (warned), only valid ids re-enter the heap.""" + mgr = _build_manager(num_gpu_blocks=8, kv_num_heads=4) + + # Drain a few ids so we can recycle a known-valid one back. 
+ drained = mgr.allocate_gpu_blocks_head_wise(num_blocks=1, req_id="req-drain") + valid_id = drained[0][0] # an id we now own + duplicate = valid_id # used twice in the recycle list + out_of_range = mgr.total_head_wise_cache_ids + 17 # beyond the valid window + + free_before_recycle = len(mgr.gpu_free_head_wise_block_list) + + # ``get_logger`` may produce a non-propagating logger; force propagation so + # caplog can observe the warnings emitted by the recycle path. + pcm_logger = logging.getLogger("prefix_cache_manager") + prior_propagate = pcm_logger.propagate + pcm_logger.propagate = True + try: + with caplog.at_level(logging.WARNING): + mgr.recycle_gpu_blocks_head_wise( + [valid_id, duplicate, out_of_range], + req_id="req-dedup", + ) + finally: + pcm_logger.propagate = prior_propagate + + # Only the single valid id should have been pushed back. + assert len(mgr.gpu_free_head_wise_block_list) == free_before_recycle + 1 + # Warnings should mention either a dropped duplicate or an out-of-range id. + log_text = "\n".join(record.getMessage() for record in caplog.records) + assert ("duplicate" in log_text) or ("out-of-range" in log_text) diff --git a/tests/cache_manager/test_head_wise_tp_consistency.py b/tests/cache_manager/test_head_wise_tp_consistency.py new file mode 100644 index 00000000000..123dfac307f --- /dev/null +++ b/tests/cache_manager/test_head_wise_tp_consistency.py @@ -0,0 +1,138 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""T53 PR1 head-wise SWA tensor-parallel consistency tests (P13 fix, commit 5) + +Case #10 from the feature spec: when the model runs under tensor +parallelism, the head-wise free list MUST shard predictably across ranks. +The fix in commit 5 computes per-rank ``kv_num_heads`` as + + kv_num_heads = max(1, kv_num_heads_global // tp_size) + if kv_num_heads_global >= tp_size else 1 + +inside ``PrefixCacheManager.__init__``. The free list size is then +``num_gpu_blocks * kv_num_heads`` per rank, and the heap is a deterministic +descending range so two ranks built with the same parameters emit the same +allocation order. + +We mirror that formula in a small helper and then build managers via +``object.__new__`` (same rationale as ``test_head_wise_freelist.py``). +The constructor itself cannot run on a CPU-only workstation because it +requires a fully-wired ``FDConfig`` plus running IPC signals. 
+""" + +from types import SimpleNamespace + +import pytest + +from fastdeploy.cache_manager.prefix_cache_manager import PrefixCacheManager + + +class _DummyMetric: + def set(self, *_a, **_k): + pass + + def inc(self, *_a, **_k): + pass + + def dec(self, *_a, **_k): + pass + + +class _DummyMainMetrics: + def __getattr__(self, name): + if name.startswith("_"): + raise AttributeError(name) + return _DummyMetric() + + +@pytest.fixture(autouse=True) +def _patch_metrics(monkeypatch): + monkeypatch.setattr( + "fastdeploy.cache_manager.prefix_cache_manager.main_process_metrics", + _DummyMainMetrics(), + ) + + +def _kv_heads_per_rank(kv_num_heads_global, tp_size): + """Mirror commit 5 P13 fix from PrefixCacheManager.__init__ exactly.""" + if kv_num_heads_global >= tp_size: + return max(1, kv_num_heads_global // tp_size) + return 1 + + +def _build_for_rank(kv_num_heads_global, tp_size, num_gpu_blocks=8): + """Bare PrefixCacheManager with the per-rank head count baked in.""" + mgr = object.__new__(PrefixCacheManager) + mgr.cache_config = SimpleNamespace(enable_prefix_caching=False) + mgr.num_gpu_blocks = num_gpu_blocks + mgr.kv_num_heads = _kv_heads_per_rank(kv_num_heads_global, tp_size) + mgr.head_wise = True + mgr.total_head_wise_cache_ids = 0 + mgr.gpu_free_block_list = list(range(num_gpu_blocks - 1, -1, -1)) + mgr.gpu_free_head_wise_block_list = [] + mgr._init_head_wise_free_list() + return mgr + + +def test_tp_size_1_uses_full_kv_heads(): + """#10a — single-rank manager carries the full kv_num_heads_global heads.""" + mgr = _build_for_rank(kv_num_heads_global=4, tp_size=1, num_gpu_blocks=8) + assert mgr.kv_num_heads == 4 + assert mgr.total_head_wise_cache_ids == 8 * 4 + assert len(mgr.gpu_free_head_wise_block_list) == 32 + + +def test_tp_size_2_splits_kv_heads_evenly(): + """#10b — two ranks each carry kv_num_heads/2; sum across ranks equals the global total.""" + rank0 = _build_for_rank(kv_num_heads_global=4, tp_size=2, num_gpu_blocks=8) + rank1 = 
_build_for_rank(kv_num_heads_global=4, tp_size=2, num_gpu_blocks=8) + assert rank0.kv_num_heads == 2 + assert rank1.kv_num_heads == 2 + total_ids = len(rank0.gpu_free_head_wise_block_list) + len(rank1.gpu_free_head_wise_block_list) + assert total_ids == 8 * 4, f"sum across ranks must equal num_gpu_blocks * kv_num_heads_global; got {total_ids}" + + +def test_tp_uneven_split_truncates_via_floor_div(): + """#10c — non-divisible split uses integer floor (4 heads / 3 ranks → 1 head per rank). + + The source code does NOT raise on uneven splits; it deterministically + truncates via ``//``. That means one head's worth of capacity is + "lost" per rank in this configuration — but the loss is predictable + and identical across ranks, which is the property we assert here. + """ + rank = _build_for_rank(kv_num_heads_global=4, tp_size=3, num_gpu_blocks=8) + assert rank.kv_num_heads == 1, "4 // 3 == 1; commit 5 P13 fix is a deterministic floor" + assert len(rank.gpu_free_head_wise_block_list) == 8 + + # Edge case: more ranks than heads → clamp to 1 head per rank (else branch). + over = _build_for_rank(kv_num_heads_global=2, tp_size=4, num_gpu_blocks=8) + assert over.kv_num_heads == 1 + assert len(over.gpu_free_head_wise_block_list) == 8 + + +def test_tp_alloc_order_deterministic_across_ranks(): + """#10d — same construction params on two ranks produce identical allocation order.""" + rank0 = _build_for_rank(kv_num_heads_global=4, tp_size=2, num_gpu_blocks=8) + rank1 = _build_for_rank(kv_num_heads_global=4, tp_size=2, num_gpu_blocks=8) + + a0 = rank0.allocate_gpu_blocks_head_wise(num_blocks=3, req_id="rank0") + a1 = rank1.allocate_gpu_blocks_head_wise(num_blocks=3, req_id="rank1") + assert a0 == a1, "same heap construction must yield identical pop sequence per head" + + # And the second alloc (after the first drained the smallest ids) is still + # deterministic across ranks. 
+ b0 = rank0.allocate_gpu_blocks_head_wise(num_blocks=2, req_id="rank0-b") + b1 = rank1.allocate_gpu_blocks_head_wise(num_blocks=2, req_id="rank1-b") + assert b0 == b1 diff --git a/tests/cache_manager/test_swa_recycle.py b/tests/cache_manager/test_swa_recycle.py new file mode 100644 index 00000000000..d886db8a8da --- /dev/null +++ b/tests/cache_manager/test_swa_recycle.py @@ -0,0 +1,217 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""T53 PR1 head-wise SWA recycle tests for ``ResourceManagerV1`` + +These tests cover the three §4 cases from the feature spec: + +* #5 — ``test_swa_recycle_respects_sink_and_window``: sink/window math + releases only fully-aged blocks and ``swa_head_recycle_upto`` is monotone. +* #6 — ``test_swa_recycle_skips_when_swap_inflight``: a request whose + per-request ``cache_swap_metadata`` queue still has unfinished swaps + targeting one of its own blocks is left untouched (recycle is a no-op). +* #7 — ``test_mutual_exclusion_with_prefix_caching``: ``PrefixCacheManager`` + refuses to construct when both ``enable_prefix_caching`` and + ``FD_HEAD_WISE_KV_CACHE`` are on (assertion landed in commit 2). + +Approach: ``ResourceManagerV1`` is built via ``object.__new__`` because its +real ``__init__`` requires a fully-wired ``FDConfig``, IPC signals, and a +running ``CacheManager`` that the workstation cannot bring up. 
This mirrors +the pattern used in ``test_head_wise_freelist.py`` (commit 2) and the H10 +task-20 ``common_engine`` tests (no MagicMock, real objects only). +""" + +from types import SimpleNamespace + +import pytest + +from fastdeploy.cache_manager.v1.metadata import CacheSwapMetadata +from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1 + + +class _FakeCacheManager: + """Minimal cache manager exposing the head-wise APIs the SWA recycle calls.""" + + def __init__(self, kv_num_heads=2): + self.kv_num_heads = kv_num_heads + self.recycled = [] # list of (req_id, ids) recorded per call + + def recycle_gpu_blocks_head_wise(self, cache_ids, req_id=None): + self.recycled.append((req_id, list(cache_ids))) + + def allocate_gpu_blocks_head_wise(self, num_blocks, req_id=None): + return [list(range(num_blocks)) for _ in range(self.kv_num_heads)] + + +def _build_manager(window=64, sink=32, block_size=16, kv_num_heads=2, head_wise_swa_ratio=1.0): + """Build a bare ``ResourceManagerV1`` with just the SWA recycle state wired.""" + rm = object.__new__(ResourceManagerV1) + rm.config = SimpleNamespace( + cache_config=SimpleNamespace(block_size=block_size), + model_config=SimpleNamespace( + window_size=window, + sink_size=sink, + num_key_value_heads=kv_num_heads, + head_wise_swa_ratio=head_wise_swa_ratio, + ), + ) + rm.cache_manager = _FakeCacheManager(kv_num_heads=kv_num_heads) + rm.swa_head_recycle_upto = {} + rm.swa_head_block_tables = {} + rm.swa_legacy_recycle_upto = {} + rm.swa_legacy_recycled_blocks = {} + return rm + + +def _fake_request(req_id="req-0", num_total_tokens=512, swap_meta=None, evict_meta=None): + return SimpleNamespace( + request_id=req_id, + num_total_tokens=num_total_tokens, + num_computed_tokens=num_total_tokens, + cache_swap_metadata=list(swap_meta or []), + cache_evict_metadata=list(evict_meta or []), + ) + + +@pytest.mark.parametrize( + ("kv_num_heads", "head_wise_swa_ratio", "expected"), + [ + (4, 1.0, 4), + (4, 0.5, 2), + (4, 0.0, 0), 
+ (1, 0.5, 1), + (1, 1.0, 1), + (1, 0.0, 0), + (8, 0.25, 2), + (3, 0.5, 2), + (2, 0.5, 1), + ], +) +def test_num_swa_heads_clamps_positive_ratios(kv_num_heads, head_wise_swa_ratio, expected): + rm = _build_manager(kv_num_heads=kv_num_heads, head_wise_swa_ratio=head_wise_swa_ratio) + + assert rm._num_swa_heads() == expected + + +# --------------------------------------------------------------------------- +# Case #5 — sink/window math +# --------------------------------------------------------------------------- +def test_swa_recycle_respects_sink_and_window(monkeypatch): + """Only blocks in ``[ceil(sink/bs), floor((T-window)/bs))`` are released; cursor is monotone.""" + monkeypatch.setattr("fastdeploy.engine.sched.resource_manager_v1.envs.FD_HEAD_WISE_KV_CACHE", 1) + rm = _build_manager(window=64, sink=32, block_size=16, kv_num_heads=2) + # 32 blocks per head, total tokens = 32 * 16 = 512. + rm.swa_head_block_tables["req-0"] = [list(range(100, 132)), list(range(200, 232))] + req = _fake_request(req_id="req-0", num_total_tokens=512) + + released = rm.recycle_request_swa_head_cache(req) + # window_blocks = ceil(64/16) = 4; sink_blocks = ceil(32/16) = 2. + # recycle_upto = (512 - 4*16) // 16 = 28; floor = 2; per-head release = 26 blocks. + assert released == 26 * 2, f"expected 52 blocks released, got {released}" + # Sink (idx 0,1) and tail window (idx 28..31) must remain untouched. + cursor = rm.swa_head_recycle_upto["req-0"] + assert cursor == [28, 28], f"per-head recycle_upto must equal 28, got {cursor}" + # Verify the recycled IDs match the open interval [2, 28) on each head. + head0_ids = list(range(100 + 2, 100 + 28)) + head1_ids = list(range(200 + 2, 200 + 28)) + recorded = [ids for (_, ids) in rm.cache_manager.recycled] + assert head0_ids in recorded and head1_ids in recorded + + # Second call with the same total_tokens must be a no-op (monotone cursor). 
+ rm.cache_manager.recycled.clear() + released_again = rm.recycle_request_swa_head_cache(req) + assert released_again == 0 + assert rm.swa_head_recycle_upto["req-0"] == [28, 28] + + +def test_swa_recycle_only_recycles_swa_heads(monkeypatch): + """Only the first ``round(kv_heads * ratio)`` rows are recycled; full-attention rows stay intact.""" + monkeypatch.setattr("fastdeploy.engine.sched.resource_manager_v1.envs.FD_HEAD_WISE_KV_CACHE", 1) + rm = _build_manager(window=64, sink=32, block_size=16, kv_num_heads=4, head_wise_swa_ratio=0.5) + rm.swa_head_block_tables["req-swa-only"] = [ + list(range(100, 132)), + list(range(200, 232)), + list(range(300, 332)), + list(range(400, 432)), + ] + req = _fake_request(req_id="req-swa-only", num_total_tokens=512) + + released = rm.recycle_request_swa_head_cache(req) + + assert released == 26 * 2 + assert rm.swa_head_recycle_upto["req-swa-only"] == [28, 28, 2, 2] + recorded = [ids for (_, ids) in rm.cache_manager.recycled] + assert list(range(100 + 2, 100 + 28)) in recorded + assert list(range(200 + 2, 200 + 28)) in recorded + assert all(not set(ids).intersection(range(300, 432)) for ids in recorded) + + +def test_swa_recycle_fires_only_on_block_boundary(monkeypatch): + """Decode-step recycle is throttled to block boundaries to avoid per-token O(H*B) scans.""" + monkeypatch.setattr("fastdeploy.engine.sched.resource_manager_v1.envs.FD_HEAD_WISE_KV_CACHE", 1) + rm = _build_manager(window=64, sink=32, block_size=16, kv_num_heads=2) + rm.swa_head_block_tables["req-boundary"] = [list(range(100, 132)), list(range(200, 232))] + req = _fake_request(req_id="req-boundary", num_total_tokens=511) + + released = rm.recycle_request_swa_head_cache(req) + + assert released == 0 + assert "req-boundary" not in rm.swa_head_recycle_upto + + +# --------------------------------------------------------------------------- +# Case #6 — overlap with in-flight swap +# --------------------------------------------------------------------------- +def 
test_swa_recycle_skips_when_swap_inflight(monkeypatch): + """An unfinished ``CacheSwapMetadata`` touching the request's blocks blocks the recycle.""" + monkeypatch.setattr("fastdeploy.engine.sched.resource_manager_v1.envs.FD_HEAD_WISE_KV_CACHE", 1) + rm = _build_manager(window=64, sink=32, block_size=16, kv_num_heads=2) + rm.swa_head_block_tables["req-1"] = [list(range(100, 132)), list(range(200, 232))] + # Pending swap touching block 105 (which is in the recycle range for head 0). + pending = CacheSwapMetadata(src_block_ids=[105], dst_block_ids=[999], success=False) + req = _fake_request(req_id="req-1", num_total_tokens=512, swap_meta=[pending]) + + released = rm.recycle_request_swa_head_cache(req) + assert released == 0, "recycle must skip when an in-flight swap targets owned blocks" + assert "req-1" not in rm.swa_head_recycle_upto, "cursor must not advance on skip" + assert rm.cache_manager.recycled == [] + + +# --------------------------------------------------------------------------- +# Case #7 — mutual exclusion vs prefix caching +# --------------------------------------------------------------------------- +def test_mutual_exclusion_with_prefix_caching(monkeypatch): + """``PrefixCacheManager`` must refuse when both head-wise and prefix caching are on.""" + monkeypatch.setattr("fastdeploy.cache_manager.prefix_cache_manager.envs.FD_HEAD_WISE_KV_CACHE", 1) + monkeypatch.setattr("fastdeploy.cache_manager.prefix_cache_manager.envs.ENABLE_V1_KVCACHE_SCHEDULER", 1) + from fastdeploy.cache_manager import prefix_cache_manager as pcm_module + + cache_config = SimpleNamespace( + enable_prefix_caching=True, + total_block_num=4, + prefill_kvcache_block_num=4, + num_cpu_blocks=0, + model_cfg=SimpleNamespace(num_key_value_heads=2), + ) + fake_fd_config = SimpleNamespace( + cache_config=cache_config, + speculative_config=SimpleNamespace(), + ) + with pytest.raises((AssertionError, ValueError)): + pcm_module.PrefixCacheManager( + config=fake_fd_config, + 
tensor_parallel_size=1, + splitwise_role="mixed", + local_data_parallel_id=0, + ) diff --git a/tests/cache_manager/test_swa_recycle_legacy_relief.py b/tests/cache_manager/test_swa_recycle_legacy_relief.py new file mode 100644 index 00000000000..b940120746c --- /dev/null +++ b/tests/cache_manager/test_swa_recycle_legacy_relief.py @@ -0,0 +1,89 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""T53 PR1 legacy-pool relief tests for per-head uniform SWA block recycle""" + +from types import SimpleNamespace + +import pytest + +from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1 + + +class _FakeCacheManager: + def __init__(self, kv_num_heads=2): + self.kv_num_heads = kv_num_heads + self.head_recycled = [] + self.legacy_recycled = [] + + def recycle_gpu_blocks_head_wise(self, cache_ids, req_id=None): + self.head_recycled.append((req_id, list(cache_ids))) + + def recycle_gpu_blocks(self, block_ids, req_id=None): + self.legacy_recycled.append((req_id, list(block_ids))) + + +def _build_manager(): + rm = object.__new__(ResourceManagerV1) + rm.config = SimpleNamespace( + cache_config=SimpleNamespace(block_size=16, enable_prefix_caching=False), + scheduler_config=SimpleNamespace(splitwise_role="mixed"), + model_config=SimpleNamespace( + window_size=64, + sink_size=32, + num_key_value_heads=2, + head_wise_swa_ratio=1.0, + ), + ) + rm.cache_manager = _FakeCacheManager(kv_num_heads=2) + 
rm.enable_cache_manager_v1 = False + rm.swa_head_recycle_upto = {} + rm.swa_head_block_tables = {} + rm.swa_legacy_recycle_upto = {} + rm.swa_legacy_recycled_blocks = {} + rm.using_extend_tables_req_id = set() + return rm + + +def test_uniform_swa_recycle_returns_legacy_blocks_without_shifting_block_tables(monkeypatch): + """Uniform SWA frees legacy IDs once while preserving absolute block-table positions.""" + monkeypatch.setattr("fastdeploy.engine.sched.resource_manager_v1.envs.FD_HEAD_WISE_KV_CACHE", 1) + rm = _build_manager() + rm.swa_head_block_tables["req-uniform"] = [list(range(100, 132)), list(range(200, 232))] + original_block_tables = list(range(1000, 1032)) + req = SimpleNamespace( + request_id="req-uniform", + num_total_tokens=512, + num_computed_tokens=512, + block_tables=list(original_block_tables), + num_cached_blocks=0, + ) + + released = rm.recycle_request_swa_head_cache(req) + + assert released == 26 * 2 + assert req.block_tables == original_block_tables + assert rm.cache_manager.legacy_recycled == [("req-uniform", original_block_tables[2:28])] + + rm.recycle_request_swa_head_cache(req) + assert rm.cache_manager.legacy_recycled == [("req-uniform", original_block_tables[2:28])] + + rm._free_blocks(req) + final_legacy_recycle = rm.cache_manager.legacy_recycled[-1][1] + assert not set(final_legacy_recycle).intersection(original_block_tables[2:28]) + assert set(final_legacy_recycle) == set(original_block_tables[:2] + original_block_tables[28:]) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/layers/test_append_attention_head_wise_shapes.py b/tests/layers/test_append_attention_head_wise_shapes.py new file mode 100644 index 00000000000..56360747aa1 --- /dev/null +++ b/tests/layers/test_append_attention_head_wise_shapes.py @@ -0,0 +1,106 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""T53 PR1 head-wise shape/scope oracles + +Case #11 from the feature spec originally proposed kernel-visible +head-wise block tables. PR1 deliberately defers that kernel plumbing to PR2; +these tests pin the PR1 scope instead: + + * per-head cache-management sidecars use head-major rows, + * ``ForwardMeta`` is not extended with ``block_tables_3d`` in PR1, + * block-wise FP8 scale transfer keeps the existing rank-3 + ``[Nb, KvH, Bs]`` contract because ``swap_cache_all_layers`` still + consumes legacy block ids in PR1. + +Paddle is loaded via ``pytest.importorskip`` so the file collects cleanly +on a CPU-only workstation during L0 oracle runs and only executes the +tensor body on a GPU CI worker. +""" + +import pytest + + +def test_head_wise_kv_layout_matches_kv_num_heads(): + """#11a — per-head slice of [Nb, KvH, Bs, Hd] yields [Nb, Bs, Hd].""" + paddle = pytest.importorskip("paddle") + nb, kvh, bs, hd = 4, 2, 8, 16 + t = paddle.zeros([nb, kvh, bs, hd], dtype="float16") + + assert tuple(t.shape) == (nb, kvh, bs, hd) + head0 = t[:, 0, :, :] + head1 = t[:, 1, :, :] + assert tuple(head0.shape) == (nb, bs, hd) + assert tuple(head1.shape) == (nb, bs, hd) + + +def test_forward_meta_unchanged_in_pr1_scope(): + """#11b — PR1 scope: ``ForwardMeta`` is NOT extended with kernel-side fields. 
+ + Per ⚖ Opus 4.7 review (review-pr1-final.md, P3 HIGH), kernel-side plumbing + (``block_tables_3d``) was deliberately moved out of PR1 (cache management) + and into PR2 (AppendAttention discrete kernel). This test pins that scope + decision: ``forward_meta.py`` must NOT carry head-wise kernel fields in PR1. + + AST-only inspection — importing ``fastdeploy.model_executor.forward_meta`` + transitively pulls AppendAttentionBackend → compiled gpu ops, unavailable + on CPU-only environments. + """ + import ast + import pathlib + + src_root = pathlib.Path(__file__).resolve().parents[1].parent + fwd_meta = src_root / "fastdeploy" / "model_executor" / "forward_meta.py" + assert fwd_meta.is_file(), f"forward_meta.py not found at {fwd_meta}" + + tree = ast.parse(fwd_meta.read_text(encoding="utf-8")) + fwd_cls = next( + (n for n in ast.walk(tree) if isinstance(n, ast.ClassDef) and n.name == "ForwardMeta"), + None, + ) + assert fwd_cls is not None, "ForwardMeta class missing from forward_meta.py" + + # PR1 must NOT introduce the head-wise kernel field — that lands in PR2. + head_wise_fields = [ + stmt + for stmt in fwd_cls.body + if isinstance(stmt, ast.AnnAssign) + and isinstance(stmt.target, ast.Name) + and stmt.target.id == "block_tables_3d" + ] + assert head_wise_fields == [], ( + "PR1 scope violation: block_tables_3d must NOT be added to ForwardMeta in PR1; " + "deferred to PR2 (AppendAttention discrete kernel) per Opus review P3." + ) + + +def test_block_wise_fp8_transfer_keeps_rank3_scale_contract(): + """#11c — PR1 must not flatten fp8 scales before ``swap_cache_all_layers``. + + ``swap_cache_all_layers`` reads scale tensors as ``[blocks, heads, block_size]``. + Flattening scales to rank 2 is a PR2/kernel-layout concern and is invalid + while PR1 still sends legacy block ids to the transfer op. 
+ """ + import ast + import pathlib + + src_root = pathlib.Path(__file__).resolve().parents[1].parent + transfer = src_root / "fastdeploy" / "cache_manager" / "cache_transfer_manager.py" + assert transfer.is_file(), f"cache_transfer_manager.py not found at {transfer}" + + tree = ast.parse(transfer.read_text(encoding="utf-8")) + helper_defs = [ + n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef) and n.name == "_maybe_headwise_flatten_scales" + ] + assert helper_defs == [], "PR1 must not flatten block_wise_fp8 scales for swap_cache_all_layers" From ddb08f1cdc078a7174bc127f0272ebfc3fe2d9c3 Mon Sep 17 00:00:00 2001 From: bob-cloudforge Date: Wed, 6 May 2026 11:19:48 +0200 Subject: [PATCH 2/6] fix(kvcache): report head-wise free_gpu_block_num as logical blocks (not cache-ids) PaddlePaddle-bot flagged that _init_head_wise_free_list and the allocate/recycle paths exported the raw length of gpu_free_head_wise_block_list as free_gpu_block_num. That list holds num_gpu_blocks * kv_num_heads per-(block,head) cache ids, so the metric inflated by kv_num_heads (e.g. 8x for ERNIE-21B-A3B-Paddle). Divide by max(1, kv_num_heads) at all three sites so the exported counter stays in logical-block units, consistent with the legacy gpu_free_block_list semantics that downstream dashboards rely on. 
Refs: review on PR #7717 (PaddlePaddle-bot) Signed-off-by: bob-cloudforge --- fastdeploy/cache_manager/prefix_cache_manager.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index c1e445b4229..ce646402b91 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -512,7 +512,13 @@ def _init_head_wise_free_list(self): self.gpu_free_head_wise_block_list = list(range(total_cache_ids - 1, -1, -1)) heapq.heapify(self.gpu_free_head_wise_block_list) self.total_head_wise_cache_ids = total_cache_ids - main_process_metrics.free_gpu_block_num.set(len(self.gpu_free_head_wise_block_list)) + # head-wise free list holds per-(block,head) cache ids; divide by + # kv_num_heads so the exported metric stays in logical-block units + # (matches legacy gpu_free_block_list semantics; avoids kv_num_heads + # inflation observed by PaddlePaddle-bot review on PR #7717). 
+ main_process_metrics.free_gpu_block_num.set( + len(self.gpu_free_head_wise_block_list) // max(1, self.kv_num_heads) + ) main_process_metrics.available_gpu_resource.set(self.available_gpu_resource) def can_allocate_gpu_blocks(self, num_blocks: int, try_free_gpu_blocks: bool = True): @@ -603,7 +609,8 @@ def allocate_gpu_blocks_head_wise(self, num_blocks, req_id=None): f"req_id:{req_id} allocate_gpu_blocks_head_wise: {allocated}, " f"len(gpu_free_head_wise_block_list) {len(free_list)}" ) - main_process_metrics.free_gpu_block_num.set(len(free_list)) + # report logical-block units (free_list counts per-(block,head) ids) + main_process_metrics.free_gpu_block_num.set(len(free_list) // max(1, self.kv_num_heads)) main_process_metrics.available_gpu_resource.set(self.available_gpu_resource) return allocated @@ -657,7 +664,8 @@ def recycle_gpu_blocks_head_wise(self, cache_ids, req_id=None): f"req_id:{req_id} recycle_gpu_blocks_head_wise: pushed {len(valid)} ids, " f"len(gpu_free_head_wise_block_list) {len(free_list)}" ) - main_process_metrics.free_gpu_block_num.set(len(free_list)) + # report logical-block units (free_list counts per-(block,head) ids) + main_process_metrics.free_gpu_block_num.set(len(free_list) // max(1, self.kv_num_heads)) main_process_metrics.available_gpu_resource.set(self.available_gpu_resource) def allocate_cpu_blocks(self, num_blocks): From 67e50c09b81ad1af1534fa7eda5b61c7bdd189a3 Mon Sep 17 00:00:00 2001 From: bob-cloudforge Date: Wed, 6 May 2026 11:22:56 +0200 Subject: [PATCH 3/6] fix(append_attn): document SWA-sentinel block_id==-1 guard contract (sink-safe) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PaddlePaddle-bot review on PR #7717 flagged the four 'if (block_id < 0) { block_id = 0; }' fallbacks in the c16 multiquery attention kernel as potentially unsafe — accessing block 0 when block_id == -1 looks like a silent OOB. 
Document the actual contract: block_id == -1 is the SWA recycle sentinel written by recycle_request_swa_head_cache (T53 PR1). The SWA mask built from chunk_start/chunk_end zeroes any contribution from this aged-out region in softmax, so the value loaded from block 0 is mathematically masked away. SAFETY argument: when sink_size > 0, recycle_from_floor = sink_blocks guarantees the sink window is never recycled, so block_id == -1 cannot occur inside the attended sink region. This is a comment-only change. No code semantics altered. Refs: review on PR #7717 (PaddlePaddle-bot) Signed-off-by: bob-cloudforge --- .../multiquery_attention_c16_impl.cuh | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index d2d7ce6e43a..6498f9d466c 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -212,6 +212,13 @@ __global__ void multi_query_append_attention_kernel( wid * 4 + tid / 8, tid % 8); uint32_t kv_idx_base = chunk_start; + // SWA sentinel guard (T53 PR1): block_id == -1 indicates the slot was + // recycled by recycle_request_swa_head_cache. The SWA mask built from + // chunk_start/chunk_end zeroes any contribution from this aged-out region, + // so the value loaded from block 0 is masked away in softmax. SAFETY: + // when sink_size>0, recycle_from_floor=sink_blocks guarantees the sink + // window is never recycled, so block_id==-1 cannot occur inside the + // attended sink region and the fallback to block 0 is provably out of range. int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); if (block_id < 0) { block_id = 0; @@ -288,6 +295,9 @@ __global__ void multi_query_append_attention_kernel( __syncthreads(); kv_idx_base += num_frags_z * 16; + // SWA sentinel guard (T53 PR1): see top-of-function note. 
block_id == -1 + // means the slot was recycled; SWA mask zeroes its contribution and the + // sink window (when sink_size>0) is never recycled. block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); if (block_id < 0) { block_id = 0; @@ -603,6 +613,13 @@ __global__ void multi_query_append_attention_warp1_4_kernel( wid * 4 + tid / 8, tid % 8); uint32_t kv_idx_base = chunk_start; + // SWA sentinel guard (T53 PR1): block_id == -1 indicates the slot was + // recycled by recycle_request_swa_head_cache. The SWA mask built from + // chunk_start/chunk_end zeroes any contribution from this aged-out region, + // so the value loaded from block 0 is masked away in softmax. SAFETY: + // when sink_size>0, recycle_from_floor=sink_blocks guarantees the sink + // window is never recycled, so block_id==-1 cannot occur inside the + // attended sink region and the fallback to block 0 is provably out of range. int block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); if (block_id < 0) { block_id = 0; @@ -683,6 +700,9 @@ __global__ void multi_query_append_attention_warp1_4_kernel( __syncthreads(); kv_idx_base += BLOCK_SIZE; + // SWA sentinel guard (T53 PR1): see top-of-function note. block_id == -1 + // means the slot was recycled; SWA mask zeroes its contribution and the + // sink window (when sink_size>0) is never recycled. block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); if (block_id < 0) { block_id = 0; From aebcd960a47b5b9f7dbf5aec8f6c6064d7c5b18f Mon Sep 17 00:00:00 2001 From: bob-cloudforge Date: Wed, 6 May 2026 12:07:21 +0200 Subject: [PATCH 4/6] fix(kvcache): use float div in available_gpu_resource PR1 backport of PR2 commit 327a43b500. Avoids integer-truncation underestimating available KV blocks when head_free % kv_num_heads != 0, which caused the scheduler to see 0 capacity on partial recycles and trigger false OOM rejections. 
Signed-off-by: bob-cloudforge --- fastdeploy/cache_manager/prefix_cache_manager.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index ce646402b91..b54f7ae3128 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -195,7 +195,13 @@ def _get_kv_cache_shape(self, max_block_num): def available_gpu_resource(self): if getattr(self, "head_wise", False) and self.num_gpu_blocks > 0: head_free = len(getattr(self, "gpu_free_head_wise_block_list", [])) - return (head_free // max(1, self.kv_num_heads)) / self.num_gpu_blocks + # Use float division so partial SWA recycle (head_free % kv_num_heads != 0) + # is reflected in the metric. Integer division would truncate fractional + # logical-block availability and cause the scheduler to under-report + # capacity, potentially triggering false OOM rejections. The legacy path + # below already returns a continuous float in [0, 1]; this keeps both + # paths value-domain compatible. + return (head_free / max(1, self.kv_num_heads)) / self.num_gpu_blocks return len(self.gpu_free_block_list) / self.num_gpu_blocks if self.num_gpu_blocks > 0 else 0.0 def launch_cache_manager( From 263db00a02e696c97e0d42a941d0ec0e4f1e7d60 Mon Sep 17 00:00:00 2001 From: bob-cloudforge Date: Wed, 6 May 2026 14:15:53 +0200 Subject: [PATCH 5/6] fix(t53): mirror available_gpu_resource defensive form from PR2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit No behavior change in PR1 (singular is the populated heap here); keeps the property body identical across PR1/PR2 so future merges do not drift. Closes the PaddlePaddle-bot 🟡 advisory on #7717. 
--- .../cache_manager/prefix_cache_manager.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index b54f7ae3128..107c9f67124 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -193,16 +193,18 @@ def _get_kv_cache_shape(self, max_block_num): @property def available_gpu_resource(self): - if getattr(self, "head_wise", False) and self.num_gpu_blocks > 0: - head_free = len(getattr(self, "gpu_free_head_wise_block_list", [])) - # Use float division so partial SWA recycle (head_free % kv_num_heads != 0) - # is reflected in the metric. Integer division would truncate fractional - # logical-block availability and cause the scheduler to under-report - # capacity, potentially triggering false OOM rejections. The legacy path - # below already returns a continuous float in [0, 1]; this keeps both - # paths value-domain compatible. 
- return (head_free / max(1, self.kv_num_heads)) / self.num_gpu_blocks - return len(self.gpu_free_block_list) / self.num_gpu_blocks if self.num_gpu_blocks > 0 else 0.0 + if self.num_gpu_blocks <= 0: + return 0.0 + if getattr(self, "head_wise", False): + heaps = getattr(self, "gpu_free_head_wise_block_lists", None) + if heaps: + head_free = sum(len(h) for h in heaps) + return (head_free / max(1, self.kv_num_heads)) / self.num_gpu_blocks + # legacy / startup-window fallback + legacy = getattr(self, "gpu_free_head_wise_block_list", None) + if legacy: + return (len(legacy) / max(1, self.kv_num_heads)) / self.num_gpu_blocks + return len(self.gpu_free_block_list) / self.num_gpu_blocks def launch_cache_manager( self, From 584e03302d32f39479d241b12862f99102e81545 Mon Sep 17 00:00:00 2001 From: bob-cloudforge Date: Wed, 6 May 2026 15:05:00 +0200 Subject: [PATCH 6/6] fix(t53): available_gpu_resource use is-not-None guard, drop dead legacy path --- fastdeploy/cache_manager/prefix_cache_manager.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/fastdeploy/cache_manager/prefix_cache_manager.py b/fastdeploy/cache_manager/prefix_cache_manager.py index 107c9f67124..a6119979917 100644 --- a/fastdeploy/cache_manager/prefix_cache_manager.py +++ b/fastdeploy/cache_manager/prefix_cache_manager.py @@ -197,13 +197,9 @@ def available_gpu_resource(self): return 0.0 if getattr(self, "head_wise", False): heaps = getattr(self, "gpu_free_head_wise_block_lists", None) - if heaps: + if heaps is not None: head_free = sum(len(h) for h in heaps) return (head_free / max(1, self.kv_num_heads)) / self.num_gpu_blocks - # legacy / startup-window fallback - legacy = getattr(self, "gpu_free_head_wise_block_list", None) - if legacy: - return (len(legacy) / max(1, self.kv_num_heads)) / self.num_gpu_blocks return len(self.gpu_free_block_list) / self.num_gpu_blocks def launch_cache_manager(