Skip to content

【Hackathon 10th Spring No.53】[Feature][Kernel] Optimize AppendAttention for discrete head-wise block_idx [cf]#7718

Open
bob-cloudforge wants to merge 9 commits intoPaddlePaddle:developfrom
CloudForge-Solutions:task/h10-053-pr2-discrete-block-idx-v4
Open

【Hackathon 10th Spring No.53】[Feature][Kernel] Optimize AppendAttention for discrete head-wise block_idx [cf]#7718
bob-cloudforge wants to merge 9 commits intoPaddlePaddle:developfrom
CloudForge-Solutions:task/h10-053-pr2-discrete-block-idx-v4

Conversation

@bob-cloudforge
Copy link
Copy Markdown

@bob-cloudforge bob-cloudforge commented May 4, 2026

PR2 Body — 【Hackathon 10th Spring No.53】[Feature][Kernel] Optimize AppendAttention for discrete head-wise block_idx [cf]

Companion PR stacked on PR1 (#7717). Until PR1 lands on develop, this branch carries the PR1 producer commits as its base, so the GitHub diff against develop is the stacked PR1 + PR2 surface (31 files, +2360/-49). The PR2-only delta is summarized below.


Motivation

Hackathon 10th Spring Task No.53 PR2 of 2. Spec: https://github.com/PaddlePaddle/community/blob/master/hackathon/hackathon_10th/【Hackathon_10th】开源贡献个人挑战赛春节特别季—任务合集.md#no53.

When SWA and full-attention heads coexist in one layer, the current AppendAttention path walks the same uniform block_tables row for every KV head. The discrete block_tables_headwise layout (rank-2 logical [batch, kv_head, block], physical [batch * local_kv_heads, max_blocks_per_head]) lets SWA-head CTAs walk a shorter / sparser row while full heads preserve the existing full-context row. That reduces unnecessary block-id loads and K/V page reads under the required recycle OFF benchmark.

The ABI is additive: callers without block_tables_headwise use the legacy path unchanged; callers with the head-wise table take the new kernel-visible fast path.

Modifications

Total stacked diff: 31 files, +2360/-49, grouped below. The PR2-only delta block lists what this PR adds on top of PR1.

Stacked surface (PR1 producer + PR2 kernel + shared tests)

Area Files +/− Purpose
Kernel (custom_ops/gpu_ops/) 7 +136 / −19 append_attention.cu, append_attn/{append_attention_c16_impl.cuh, append_attention_kernel.h, multiquery_attention_c16_impl.cuh, multiquery_attention_c16_kernel.h, template_config.json}, cpp_extensions.cc
Runtime (fastdeploy/) 13 +849 / −30 cache_manager/prefix_cache_manager.py, engine/sched/resource_manager_v1.py, worker/{gpu_model_runner.py, input_batch.py, worker_process.py}, model_executor/{forward_meta.py, layers/attention/append_attn_backend.py, layers/attention/ops/append_attention.py, models/paddleformers/base.py}, engine/request.py, spec_decode/mtp.py, config.py, envs.py
Tests (tests/) 9 +1360 / 0 tests/cache_manager/test_{per_head_heaps, head_wise_freelist, head_wise_extend_validation, head_wise_abort_reset, head_wise_tp_consistency, swa_recycle, swa_recycle_legacy_relief, benchmark_head_wise_swa}.py, tests/layers/test_append_attention_head_wise_shapes.py
Bench / config 2 +15 / 0 benchmarks/yaml/eb45-21b-a3b-32k-bf16-kv50-512s.yaml, .gitignore

PR2-only delta (changes added on top of PR1 #7717)

File Change
custom_ops/gpu_ops/append_attention.cu Thread block_tables_headwise through AppendAttentionKernel, AppendAttention, and AppendAttentionWithOutput; add PD_CHECK(.dtype() == INT32) dtype guards on every Python-supplied .data<int>() read (set_max_lengths, encoder_num_blocks, kv_num_blocks, decoder_num_blocks, mask_offset); make block_tables_headwise keyword-only on the Python op; add sink_size / head_wise_full_hidden parameters; thread sink_size into append_attention_with_output_gpu() (was hardcoded 0).
custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh c16 kernel point-of-use: replace uniform block_tables row walk with per-head row selection from block_tables_headwise when present; preserve existing block_id < 0 → 0 clamp at the load site (-1 sentinel = evicted SWA slot, mask zeroes contribution). c8/c4 variants deferred to PR3.
custom_ops/gpu_ops/append_attn/{append_attention_c16_impl.cuh, append_attention_kernel.h, multiquery_attention_c16_kernel.h, template_config.json}, custom_ops/gpu_ops/cpp_extensions.cc Thread the optional block_tables_headwise tensor through kernel headers, template config, and the PHI op signature.
fastdeploy/model_executor/layers/attention/append_attn_backend.py Add _get_block_tables_headwise(forward_meta) helper (per-call read of forward_meta, then forward_meta.cache_manager, else None); thread the tensor as a kwarg into both append_attention() and append_attention_with_output() call sites; pass sink_size and head_wise_full_hidden to the with-output path.
fastdeploy/model_executor/layers/attention/ops/append_attention.py Make block_tables_headwise keyword-only on both ops; guard head_wise_full_hidden > 0 in the use_output=True path with assert head_wise_full_hidden == 0 (dual-call merge stays in append_attention() only; with-output path deferred to PR3).
fastdeploy/engine/sched/resource_manager_v1.py Add assert (kv_num_heads_global < tp_size) or (kv_num_heads_global % tp_size == 0) GQA divisibility guard before kv_num_heads_global // tp_size.
tests/layers/test_append_attention_head_wise_shapes.py Shape-level smoke test for the kernel-visible head-wise contract (additive on top of PR1's allocator tests).

The c16 kernel is the only flavor consumed in PR2. c8 / c4 / write-path mirrors and the graph-blacklist update are intentionally deferred to PR3. Safety in PR2 = legacy uniform block_tables walk + existing block_id<0 fallback + SWA mask zero-contribution.

Clean-room note: PR2 uses public PR #6702 only as behavior/reference context. No Co-authored-by trailer; prose acknowledgement only.

Usage or Command

No user-facing API change. The optimized path is active when PR1 provides block_tables_headwise and head-wise SWA is enabled:

export FD_HEAD_WISE_KV_CACHE=1
export FD_T53_HEAD_WISE_SWA_RATIO=0.5    # leading half of KV heads designated SWA

Spec acceptance must be measured with timely SWA recycle OFF, comparing 1D uniform block_idx against 2D discrete block_idx.

Accuracy Tests

Spec PR2 acceptance — recycle OFF; H/B card; 1D uniform vs 2D discrete; both TTFT and TBT improve ≥5%:

block_idx mode Hardware TTFT (ms) TBT (ms) Δ TTFT Δ TBT
1D (uniform) H100 / H20 / B200 TBD TBD baseline baseline
2D (discrete, optimized) same TBD TBD +TBD% ≥5 ✓ +TBD% ≥5 ✓

Benchmark: FastDeploy/benchmarks/serving/benchmark_serving.py with benchmarks/yaml/eb45-21b-a3b-32k-bf16-kv50-512s.yaml.

Hardware request to reviewers (cc @luotao1): PR2 acceptance requires H/B card per spec. A800 numbers (when present) are preview-only and labelled as such; FULL bench run is one-time pre-merge.

Correctness gates before push:

  • block_tables_headwise=None legacy path unchanged.
  • use_output=True and use_output=False both consume the same head-wise table contract.
  • 1D vs 2D numeric parity for FP16/BF16/cache-quant variants; -1 sentinel rows skip before K/V pointer derivation.
  • GSM8K parity within ±0.1 pp.
  • All 9 head-wise tests under tests/cache_manager/ and tests/layers/ green locally.

CI run: https://github.com/PaddlePaddle/FastDeploy/pull/7718/checks

Depends on: #7717 (ResourceManagerV1 head-wise SWA recycle, producer for block_tables_headwise).

Checklist

Adds rank-2 block_tables_headwise plumbing for c16 multi-query attention path.

Updates template_config.json so the codegen produces explicit instantiations matching the new impl signature (added optional block_table_headwise param).
@paddle-bot
Copy link
Copy Markdown

paddle-bot Bot commented May 4, 2026

Thanks for your contribution!

@paddle-bot paddle-bot Bot added the contributor External developers label May 4, 2026
@CLAassistant
Copy link
Copy Markdown

CLAassistant commented May 4, 2026

CLA assistant check
All committers have signed the CLA.

PaddlePaddle-bot

This comment was marked as outdated.

@PaddlePaddle-bot
Copy link
Copy Markdown

PaddlePaddle-bot commented May 4, 2026

🤖 Paddle-CI-Agent | ci_status_monitor | 2026-05-06 16:43:29

CI报告基于以下代码生成(30分钟更新一次):


1 任务总览

所有 required 任务均已通过(当前无 required 任务配置),有 1 个 optional 任务失败(不阻塞合并)。

总执行(rerun次数) 总任务 ✅ 通过 ❌ 失败 ⏳ 运行中 ⏸️ 等待中 跳过
2(0) 2 1 1 0 0 0

⚠️ 注意:以下 7 个 Workflow 处于 action_required 状态(等待审批后才会执行):CI_XPU、ILUVATAR-CI、PR Build and Test、Approval、CI_HPU、Check PR Template、Codestyle-Check。这些 Workflow 需人工审批触发。


2 任务状态汇总

2.1 Required任务 : 0/0 通过

当前未配置必选任务(或 GitHub Branch Protection Rules 未设置)。

2.2 可选任务 — 1/2 通过

可选任务不阻塞合并,失败仅供参考。

状态 任务 耗时 日志 重跑
Trigger Jenkins for PR 11m7s Job -
其余 1 个可选任务通过 - - -

3 失败详情(仅 required)

无 required 失败任务。

@bob-cloudforge bob-cloudforge changed the title feat(append_attn): head-wise SWA recycle + discrete-block-idx ABI 【Hackathon 10th Spring No.53】[Feature][Kernel] Optimize AppendAttention for discrete head-wise block_idx [cf] May 4, 2026
- gpu_model_runner: _maybe_slice_block_tables_headwise now is_dummy_or_profile_run-aware so captured CUDA graph records non-null sidecar; identity-stride dummy seeding aligned with kernel shape assert (dim0 == bsz * kv_num_heads).
- input_batch: InputBatch.swap_states + ProposerInputBatch.swap_states clone-then-copy swap block_tables_headwise[i*kv_local:(i+1)*kv_local] row groups so head-wise rows follow slot moves on both target and proposer paths.
- gpu_model_runner._process_reorder: in-place clear forward_batch_reqs_list before repopulating from share_inputs.index_to_batch_id; prevents stale tail entries from leaking into logprob-settings consumers (Option A: post-hoc rebuild).
- gpu_model_runner: docstring corrected to match C16 kernel sentinel handling (multiquery_attention_c16_impl.cuh L215-223 / L605-613); -1 sentinel reads block 0 as harmless placeholder, SWA mask zeroes the contribution. No fallback to flat block_tables.
- benchmarks/yaml: add eb45-21b-a3b-32k-bf16-kv50-512s.yaml for PR2 bench geometry.

Refs: T53 PR2 PaddlePaddle#7718.
self.input_batch is not constructed yet during _dummy_prefill_inputs
and CUDA-graph capture, so reading self.input_batch.kv_num_heads_local
crashed the worker before the bench server could start. Use
self.model_config.kv_num_heads (set in init_share_inputs before warmup)
which has the same TP-aware value.
The PR1 head-wise allocator (PaddlePaddle#7717) emits flat global block IDs in
[0, num_gpu_blocks * kv_num_heads) from a single shared min-heap, but
the PR2 discrete kernel (PaddlePaddle#7718) ABI L1 expects per-head local IDs in
{-1} ∪ [0, num_gpu_blocks). This causes cudaIllegalAddress on any
request whose allocated IDs cross the num_gpu_blocks boundary
(i.e. immediately on head index ≥ ceil(num_gpu_blocks / num_blocks)).

This commit normalizes IDs at the backend boundary in append_attn_backend.py
using `local = flat % num_gpu_blocks` (sentinel -1 preserved), with a
fail-fast assert to catch any residual OOB. The hotfix is bench-only;
the canonical fix (per-head independent allocator pools) is deferred to
PR1 v5 (RFC-PR1-reanchored.md §3).

Also adds FD_T53_HEAD_WISE_SWA_RATIO ∈ [0.0, 1.0] validator.

Refs: .checkpoints/h10/task-53/design/PR2-HOTFIX-SPEC.md (Option B, OPUS-GATE PASS)
     .checkpoints/h10/task-53/design/CONTRACT-ORACLE.md (I2, I7)
     .checkpoints/h10/task-53/design/RFC-PR2-reanchored.md (ABI L1)

Files: 2 changed (1 backend hotfix, 1 envs validator)
…mixed

Boolean fancy indexing and .item() CPU sync inside forward_mixed
crash CUDA graph capture (cudaError 900 cudaErrorStreamCaptureUnsupported).
The paddle.where normalization is graph-safe (static-shape elementwise ops).
Assert was debug-only; normalization alone is the actual OOB fix.
- prefix_cache_manager: replace shared flat heap with kv_num_heads
  independent heaps; allocate/recycle now per-head with rank-2
  [kv_num_heads][N] nested-list contract per RFC-PR2 §3
- gpu_model_runner: warmup base = idx * fill_blocks (not cross-head
  flat); rank-2 buffer shape preserved per kernel ABI
- append_attn_backend: revert flat % num_gpu_blocks HOTFIX (silent
  aliasing); replace with FD_T53_DEBUG_BLOCK_TABLES gated assert
- tests: 4 per-head value-space invariants, no MagicMock
- .gitignore: ignore runs/ bench output dir

Closes T53-PR2-OOB-blocker (kernel ABI now matches producer).
….data<int>()

Adds dtype guards before .data<int>() reads of:
- set_max_lengths, encoder_num_blocks, kv_num_blocks, decoder_num_blocks
  (in AppendAttentionKernel, lines 100-105/186/187/285)
- mask_offset.get() (in AppendAttention L599 and AppendAttentionWithOutput L763)

Catches accidental INT64/FP dtype before UB. Matches existing PD_CHECK style
from set_flags.cu / set_mask_value.cu.
…p_size

Guards against silent under-allocation when kv_num_heads_global is not a
multiple of tp_size (and >= tp_size). The kv<tp replication path is
explicitly excluded from the assert, preserving existing GQA/MQA behavior.
PaddlePaddle-bot

This comment was marked as outdated.

PaddlePaddle-bot

This comment was marked as outdated.

Copy link
Copy Markdown

@PaddlePaddle-bot PaddlePaddle-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤖 Paddle-CI-Agent | pr_review | 2026-05-06 16:56:58

📋 Review 摘要

PR 概述:为 AppendAttention 内核新增 block_tables_headwise 可选参数,实现 SWA/全注意力混合场景下的逐头离散 block 索引分配框架(PR2/3 最小可行切片)

变更范围custom_ops/append_attn/(CUDA kernel)、cache_manager/engine/sched/model_executor/layers/attention/config.py

影响面 Tag[OP] [KVCache] [Scheduler] [FDConfig]

📝 PR 规范检查

标题包含两个 Tag([Feature] 和非官方 [Kernel]),且附有非标准前缀 【Hackathon 10th Spring No.53】 和后缀 [cf]## Checklist 使用了自定义条目,未遵循 §D2 标准模板。

标题建议(可直接复制):

  • [Feature] Optimize AppendAttention for discrete head-wise block_idx

PR 描述建议(可直接复制,必须复刻 checklist §D2 模板的完整结构):

## Motivation
当 SWA(Sliding Window Attention)与全注意力 head 在同一 Layer 中共存时,现有 AppendAttention 路径对所有 KV head 使用同一 uniform `block_tables` 行,SWA head 会读取无效的全上下文 block。新增的 `block_tables_headwise` 二维布局(`[bsz*kv_num_heads, max_blocks_per_head]`)使 SWA-head CTA 仅遍历其有效短行,full-attention head 保留原有全上下文行,在 recycle OFF 基准测试下减少无效 block-id 加载与 K/V page 读取,预期 TTFT/TBT 改善 ≥5%。本 PR 为最小可行切片(PR2/3),保持 ABI 增量兼容。

## Modifications

| 文件 | 变更说明 |
|---|---|
| `custom_ops/gpu_ops/append_attention.cu` | `AppendAttention`/`AppendAttentionWithOutput` Op 新增可选参数 `block_tables_headwise``PD_BUILD_STATIC_OP` 注册同步更新;补充 `mask_offset` 等 INT32 dtype 守卫 |
| `custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh` | `multi_query_append_attention_kernel``_warp1_4_kernel` 新增 `block_table_hw`/`max_blocks_per_head` 参数,per-head 行选择逻辑;`block_id<0` 分支改为内联三目表达式并注明 `-1` 哨兵语义 |
| `custom_ops/gpu_ops/append_attn/template_config.json` | 更新 `function_signature` 以包含新参数 |
| `custom_ops/gpu_ops/cpp_extensions.cc` | Python 绑定同步新增参数 |
| `fastdeploy/model_executor/layers/attention/ops/append_attention.py` | `append_attention()` 新增 `block_tables_headwise` 关键字参数与 `sink_size`/`head_wise_full_hidden`;SWA+全注意力双调用合并逻辑;`append_attention_with_output()``head_wise_full_hidden>0` 路径抛出 `NotImplementedError`(T53 PR3 范围) |
| `fastdeploy/model_executor/layers/attention/append_attn_backend.py` | 新增 `_get_block_tables_headwise()` 辅助方法;将 headwise 张量作为 kwarg 传入两条调用路径 |
| `fastdeploy/cache_manager/prefix_cache_manager.py` | 新增逐头独立 min-heap 空闲列表(`gpu_free_head_wise_block_lists`);`allocate_gpu_blocks_head_wise`/`recycle_gpu_blocks_head_wise` 方法;`available_gpu_resource` 属性适配 head-wise 路径 |
| `fastdeploy/engine/sched/resource_manager_v1.py` | 新增 `_swa_window_sink_block``_num_swa_heads``_head_wise_swa_active``_should_use_head_wise_swa` 等辅助方法;per-head 块表分配/扩展/回收框架 |
| `fastdeploy/config.py` | FDConfig.__init__ 末尾注入 T53 engine-main 进程 fixture(`FD_T53_HEAD_WISE_SWA_FIXTURE` 门控) |
| `fastdeploy/engine/request.py` | `Request` 新增 `head_block_tables` 字段 |
| `fastdeploy/envs.py` | 新增 `FD_HEAD_WISE_KV_CACHE``FD_T53_HEAD_WISE_SWA_FIXTURE``FD_T53_HEAD_WISE_SWA_RATIO` 环境变量 |
| `tests/cache_manager/``tests/layers/` | 新增 9 个单元测试覆盖 head-wise 分配/回收/shape 验证 |

## Usage or Command

```bash
export FD_HEAD_WISE_KV_CACHE=1
# 性能基准(验收:recycle OFF;1D uniform vs 2D discrete;TTFT/TBT 改善 ≥5%)
python benchmarks/serving/benchmark_serving.py \
    --config benchmarks/yaml/eb45-21b-a3b-32k-bf16-kv50-512s.yaml
```

## Accuracy Tests

| `block_idx` 模式 | 硬件 | TTFT (ms) | TBT (ms) | Δ TTFT | Δ TBT |
|---|---|---|---|---|---|
| 1D (uniform) | H100 / H20 / B200 | TBD | TBD | baseline | baseline |
| 2D (discrete, optimized) | same | TBD | TBD | **+TBD%** ≥5 ✓ | **+TBD%** ≥5 ✓ |

GSM8K parity 在 ±0.1 pp 范围内(待补充数据)。

## Checklist

- [x] Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [x] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

问题

级别 文件 概述
🔴 Bug fastdeploy/cache_manager/prefix_cache_manager.py:197 available_gpu_resource 在 head-wise 模式下读取已清空的单数属性,恒返回 0.0,调度器误判 GPU 资源耗尽
🟡 建议 fastdeploy/engine/sched/resource_manager_v1.py:296 assert 用于运行时 GQA/MQA 约束校验,Python -O 下会被跳过
🟡 建议 fastdeploy/engine/sched/resource_manager_v1.py:315 _should_use_head_wise_swa 缺少 FD_HEAD_WISE_KV_CACHE 门控,原生带 window_size 的模型可能误触发 head-wise 路径导致 AttributeError

总体评价

整体架构思路清晰,per-head 独立 min-heap 方案正确解决了 block-id 混叠问题,ABI 增量兼容设计合理。available_gpu_resource 恒零是阻塞性 Bug,会导致 head-wise 模式下所有请求被拒绝,需优先修复;assert 替换和 _should_use_head_wise_swa 门控修复成本低,建议一并处理后再进行功能测试。

@property
def available_gpu_resource(self):
if getattr(self, "head_wise", False) and self.num_gpu_blocks > 0:
head_free = len(getattr(self, "gpu_free_head_wise_block_list", []))
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug head_free 读取的是 gpu_free_head_wise_block_list单数,在 _init_head_wise_free_list 中被显式清空为 []),而非实际存储空闲块的 gpu_free_head_wise_block_lists复数,per-head 独立堆)。

head_wise=True 时,head_free 恒为 0,导致 available_gpu_resource 恒返回 0.0,调度器将误判 GPU 资源全部耗尽,所有请求被拒绝或系统 stall。

建议修复:

@property
def available_gpu_resource(self):
    if getattr(self, "head_wise", False) and self.num_gpu_blocks > 0:
        lists = getattr(self, "gpu_free_head_wise_block_lists", None)
        if lists:
            total_free = sum(len(h) for h in lists)
            head_free = total_free // max(1, self.kv_num_heads)
            return head_free / self.num_gpu_blocks
    return len(self.gpu_free_block_list) / self.num_gpu_blocks if self.num_gpu_blocks > 0 else 0.0

"""
kv_num_heads_global = int(getattr(self.config.model_config, "num_key_value_heads", 0) or 0)
if kv_num_heads_global <= 0:
return 0
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 assert 在 Python -O(优化模式)下会被完全跳过,无法保护运行时约束。

建议改为显式抛出:

if not ((kv_num_heads_global < tp_size) or (kv_num_heads_global % tp_size == 0)):
    raise ValueError(
        f"GQA/MQA constraint violated: kv_num_heads={kv_num_heads_global} "
        f"not divisible by tp_size={tp_size}"
    )

if ratio <= 0.0:
return 0
if ratio >= 1.0:
return kv_num_heads
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 _should_use_head_wise_swa 仅检查 window_size > 0hasattr,但 allocate_gpu_blocks_head_wisePrefixCacheManager 的实例方法(hasattr 恒为 True)。若模型原生配置了 window_size > 0 但未设置 FD_HEAD_WISE_KV_CACHE=1_init_head_wise_free_list 不会被调用,gpu_free_head_wise_block_lists 不存在,调用 allocate_gpu_blocks_head_wise 会抛出 AttributeError

建议补充环境变量门控:

def _should_use_head_wise_swa(self, num_blocks: int) -> bool:
    return (
        bool(envs.FD_HEAD_WISE_KV_CACHE)
        and int(getattr(self.config.model_config, "window_size", 0) or 0) > 0
        and hasattr(self.cache_manager, "allocate_gpu_blocks_head_wise")
        and hasattr(self.cache_manager, "recycle_gpu_blocks_head_wise")
        and num_blocks > 0
    )

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

contributor External developers

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants