From e8d4c5c458a48533f492e35252abcd21d3968bb2 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Mon, 24 Nov 2025 21:25:55 +0000 Subject: [PATCH 01/29] feat: TPU host offload for KV cache Signed-off-by: Juncheng Gu --- examples/gke/benchmarks/README.md | 117 + examples/gke/benchmarks/benchmark-pod.yaml | 55 + examples/gke/benchmarks/deploy-baseline.yaml | 39 + .../gke/benchmarks/deploy-cpu-offload.yaml | 41 + examples/gke/benchmarks/service.yaml | 15 + examples/gke/hf_secret.yaml | 8 + examples/gke/pod_tpu_commons_cpu_offload.yaml | 32 + ..._tpu_commons_cpu_offload_verification.yaml | 39 + .../gke/pod_tpu_host_offload_unit_tests.yaml | 36 + examples/multi_modal_inference.py | 1 - examples/offline_inference_kv_cache.py | 85 + ...offline_inference_kv_cache_verification.py | 177 ++ .../offload/tpu_offload_accuracy_test.py | 109 + .../tpu_offload_connector_scheduler_test.py | 493 +++++ .../tpu_offload_connector_worker_test.py | 441 ++++ .../offload/tpu_offload_cpu_backend_test.py | 83 + .../offload/tpu_offload_manager_test.py | 342 +++ .../offload/tpu_offload_utils_test.py | 157 ++ tests/kernels/host_dma_test.py | 178 ++ tpu_inference/distributed/offload/__init__.py | 0 .../distributed/offload/cpu_backend.py | 109 + .../distributed/offload/offload_manager.py | 375 ++++ .../offload/tpu_offload_connector.py | 1928 +++++++++++++++++ tpu_inference/distributed/offload/utils.py | 266 +++ tpu_inference/kernels/dma/__init__.py | 0 tpu_inference/kernels/dma/host_dma.py | 102 + tpu_inference/platforms/tpu_platform.py | 4 - tpu_inference/runner/kv_cache_manager.py | 25 + tpu_inference/runner/tpu_runner.py | 3 + tpu_inference/worker/tpu_worker.py | 32 + 30 files changed, 5287 insertions(+), 5 deletions(-) create mode 100644 examples/gke/benchmarks/README.md create mode 100644 examples/gke/benchmarks/benchmark-pod.yaml create mode 100644 examples/gke/benchmarks/deploy-baseline.yaml create mode 100644 examples/gke/benchmarks/deploy-cpu-offload.yaml create mode 100644 examples/gke/benchmarks/service.yaml create mode 100644 examples/gke/hf_secret.yaml create mode 100644 examples/gke/pod_tpu_commons_cpu_offload.yaml create mode 100644 examples/gke/pod_tpu_commons_cpu_offload_verification.yaml create mode 100644 examples/gke/pod_tpu_host_offload_unit_tests.yaml create mode 100644 examples/offline_inference_kv_cache.py create mode 100644 examples/offline_inference_kv_cache_verification.py create mode 100644 tests/distributed/offload/tpu_offload_accuracy_test.py create mode 100644 tests/distributed/offload/tpu_offload_connector_scheduler_test.py create mode 100644 tests/distributed/offload/tpu_offload_connector_worker_test.py create mode 100644 tests/distributed/offload/tpu_offload_cpu_backend_test.py create mode 100644 tests/distributed/offload/tpu_offload_manager_test.py create mode 100644 tests/distributed/offload/tpu_offload_utils_test.py create mode 100644 tests/kernels/host_dma_test.py create mode 100644 tpu_inference/distributed/offload/__init__.py create mode 100644 tpu_inference/distributed/offload/cpu_backend.py create mode 100644 tpu_inference/distributed/offload/offload_manager.py create mode 100644 tpu_inference/distributed/offload/tpu_offload_connector.py create mode 100644 tpu_inference/distributed/offload/utils.py create mode 100644 tpu_inference/kernels/dma/__init__.py create mode 100644 tpu_inference/kernels/dma/host_dma.py diff --git a/examples/gke/benchmarks/README.md b/examples/gke/benchmarks/README.md new file mode 100644 index 000000000..9d1136637 --- /dev/null +++ 
b/examples/gke/benchmarks/README.md @@ -0,0 +1,117 @@ +# Benchmarks using SGLang bench_serving tool + +This guide outlines the steps to deploy a vLLM serving instance on Google Kubernetes Engine (GKE) with TPUs, create a service to expose it, and then run the SGLang `bench_serving.py` benchmark against it. Two deployment options for vLLM are provided: a baseline without host offload and one with TPU host offload for KV cache. + +## Prerequisites + +* `kubectl` configured to connect to your GKE cluster. +* `gcloud` CLI installed and authenticated. +* A GKE cluster with TPU nodes (the steps below have been verified on a `ct6e-standard-8t` GKE node). +* Access to the Llama-3.3-70B model on Hugging Face. + +## 1. Create Hugging Face Token Secret + +A Hugging Face token is required to pull the model. Create a Kubernetes secret with your token: + +```bash +kubectl create secret generic hf-token-secret --from-literal=token='' +``` + +Replace the empty `token` value with your actual Hugging Face token. + +## 2. Deploy vLLM Pod (Choose One) + +Choose one of the following deployment options for your vLLM pod. Ensure the correct container image is set in the pod spec. + +### Option A: Baseline vLLM (No Host Offload) + +This deployment uses a standard vLLM setup without any TPU host offload connector. The KV cache resides entirely in TPU HBM. + +```bash +kubectl apply -f deploy-baseline.yaml +``` + +### Option B: vLLM with TPU Host Offload + +This deployment configures vLLM to use a `TPUOffloadConnector` to offload the KV cache to host CPU memory. This is specified by the `--kv-transfer-config` argument. + +```bash +kubectl apply -f deploy-cpu-offload.yaml +``` + +## 3. Deploy Service + +Deploy a LoadBalancer service to expose your vLLM deployment. This will provide an external IP address to send benchmark requests to. + +```bash +kubectl apply -f service.yaml +``` + +After deployment, get the external IP of the service: + +```bash +kubectl get service tpu-offline-inference -o jsonpath='{.status.loadBalancer.ingress[0].ip}' +``` + +This command will directly output the external IP address. It might take a few minutes for the IP to be provisioned. + +## 4. Run Benchmark + +Instead of installing SGLang locally, we can run the benchmark from within the Kubernetes cluster using a dedicated pod. This approach avoids local dependency management and ensures the benchmark runs in a consistent environment. + +### a. Configure the Benchmark Pod + +A sample pod specification is provided in `benchmark-pod.yaml`. Before deploying it, you need to configure the environment variables within the file, especially the `IP` of the vLLM service. + +Open `benchmark-pod.yaml` and set the `IP` environment variable to the external IP address of your `tpu-offline-inference` service obtained in step 3. + +You can also adjust the following benchmark parameters via environment variables in the `benchmark-pod.yaml` file: + +* `GSP_NUM_GROUPS`: The number of unique system prompts. +* `GSP_PROMPTS_PER_GROUP`: The number of questions per system prompt. +* `GSP_SYSTEM_PROMPT_LEN`: The token length of the system prompt. +* `GSP_QUESTION_LEN`: The token length of the question. +* `GSP_OUTPUT_LEN`: The desired output token length. +* `MODEL`: The model to benchmark. + +### b. Deploy the Benchmark Pod + +Once configured, deploy the benchmark pod: + +```bash +kubectl apply -f benchmark-pod.yaml +``` + +The pod will start, clone the SGLang repository, install dependencies, and run the benchmark. + +### c.
Monitor the Benchmark + +You can monitor the progress of the benchmark by checking the logs of the pod: + +```bash +kubectl logs -f sglang-benchmark +``` + +The pod is configured with `restartPolicy: Never`, so it will run the benchmark once and then complete. + +## 5. Understanding `generated-shared-prefix` Dataset + +The `generated-shared-prefix` dataset is designed to benchmark serving performance for workloads where multiple requests share a common, long prefix. This is common in applications using system prompts or few-shot examples. + +**How it works:** + +1. **System Prompt Generation:** A specified number of unique "system prompts" are generated. Each is a long sequence of random tokens. +2. **Question Generation:** Shorter "questions" (random tokens) are generated. +3. **Prompt Combination:** Each system prompt is combined with multiple unique questions to form final prompts. This creates groups of prompts where each prompt in a group shares the exact same system prompt as a prefix. +4. **Request Creation:** Each final prompt is packaged with its desired output length. +5. **Shuffling:** The entire set of generated requests is randomly shuffled. This interleaves requests from different groups, simulating realistic traffic where shared prefixes are not necessarily processed sequentially. +6. **Caching:** The generated dataset is cached locally for faster subsequent runs with the same parameters. + +**Key Parameters for `generated-shared-prefix`:** + +* `--gsp-num-groups`: The number of unique system prompts to generate. Each system prompt forms a "group" of requests. +* `--gsp-prompts-per-group`: The number of unique questions that will be appended to each system prompt. This determines how many requests will share a given system prompt. +* `--gsp-system-prompt-len`: The length (in tokens) of each generated system prompt. +* `--gsp-question-len`: The length (in tokens) of each generated question. +* `--gsp-output-len`: The desired length (in tokens) of the generated output for each request. +* `--seed`: (Optional) An integer seed for random number generation, ensuring reproducible prompt generation and request shuffling across runs. 
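**Illustrative sketch of the dataset construction:**

The short Python sketch below mirrors the `generated-shared-prefix` construction described above, operating on raw token IDs so that it stays self-contained. The function name `build_generated_shared_prefix`, the `vocab_size` default, and the dict-based request format are assumptions made for this illustration; the actual `bench_serving.py` builds text prompts with the model's tokenizer and caches the generated dataset to disk.

```python
import random


def build_generated_shared_prefix(num_groups: int,
                                  prompts_per_group: int,
                                  system_prompt_len: int,
                                  question_len: int,
                                  output_len: int,
                                  seed: int = 42,
                                  vocab_size: int = 32000) -> list[dict]:
    """Sketch of the generated-shared-prefix dataset, using token IDs only."""
    rng = random.Random(seed)
    requests = []
    for _ in range(num_groups):
        # 1. One long shared prefix ("system prompt") per group.
        system_prompt = [rng.randrange(vocab_size) for _ in range(system_prompt_len)]
        for _ in range(prompts_per_group):
            # 2./3. Each request appends a unique short question to the shared prefix.
            question = [rng.randrange(vocab_size) for _ in range(question_len)]
            # 4. Package each final prompt with its desired output length.
            requests.append({
                "prompt_token_ids": system_prompt + question,
                "output_len": output_len,
            })
    # 5. Shuffle so requests from different groups are interleaved.
    rng.shuffle(requests)
    return requests


if __name__ == "__main__":
    # Values match the defaults set in benchmark-pod.yaml.
    reqs = build_generated_shared_prefix(num_groups=2,
                                         prompts_per_group=16,
                                         system_prompt_len=2048,
                                         question_len=256,
                                         output_len=512)
    print(f"{len(reqs)} requests, "
          f"first prompt length: {len(reqs[0]['prompt_token_ids'])} tokens")
```

With the defaults from `benchmark-pod.yaml` (`GSP_NUM_GROUPS=2`, `GSP_PROMPTS_PER_GROUP=16`), this yields 32 requests sharing one of two 2048-token prefixes, which is the pattern that makes prefix caching and KV cache host offload effective for this workload.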
diff --git a/examples/gke/benchmarks/benchmark-pod.yaml b/examples/gke/benchmarks/benchmark-pod.yaml new file mode 100644 index 000000000..05e2da502 --- /dev/null +++ b/examples/gke/benchmarks/benchmark-pod.yaml @@ -0,0 +1,55 @@ +apiVersion: v1 +kind: Pod +metadata: + name: sglang-benchmark +spec: + containers: + - name: sglang-benchmark-container + image: python:3.9-slim + command: ["/bin/bash", "-c"] + args: + - | + set -ex + apt-get update && apt-get install -y git + git clone -b v0.5.2 https://github.com/sgl-project/sglang.git + cd sglang + pip install --upgrade pip + pip install protobuf aiohttp numpy requests tqdm transformers + python3 python/sglang/bench_serving.py \ + --host=$(IP) \ + --port=$(PORT) \ + --dataset-name='generated-shared-prefix' \ + --model=$(MODEL) \ + --tokenizer=$(MODEL) \ + --backend=vllm \ + --gsp-num-groups=$(GSP_NUM_GROUPS) \ + --gsp-prompts-per-group=$(GSP_PROMPTS_PER_GROUP) \ + --gsp-system-prompt-len=$(GSP_SYSTEM_PROMPT_LEN) \ + --gsp-question-len=$(GSP_QUESTION_LEN) \ + --gsp-output-len=$(GSP_OUTPUT_LEN) \ + --request-rate=800 \ + --max-concurrency=300 \ + --seed 42 + env: + - name: IP + value: "34.162.66.198" # Replace with the external IP of your deployed service + - name: PORT + value: "80" + - name: MODEL + value: "meta-llama/Llama-3.3-70B-Instruct" + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + - name: GSP_NUM_GROUPS + value: "2" + - name: GSP_PROMPTS_PER_GROUP + value: "16" + - name: GSP_SYSTEM_PROMPT_LEN + value: "2048" + - name: GSP_QUESTION_LEN + value: "256" + - name: GSP_OUTPUT_LEN + value: "512" + restartPolicy: Never diff --git a/examples/gke/benchmarks/deploy-baseline.yaml b/examples/gke/benchmarks/deploy-baseline.yaml new file mode 100644 index 000000000..a72dd6619 --- /dev/null +++ b/examples/gke/benchmarks/deploy-baseline.yaml @@ -0,0 +1,39 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: tpu-offline-inference +spec: + replicas: 1 + selector: + matchLabels: + app: tpu-offline-inference + template: + metadata: + labels: + app: tpu-offline-inference + spec: + nodeSelector: + cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice + cloud.google.com/gke-tpu-topology: 2x4 # Specify the physical topology for the TPU slice. + containers: + - name: tpu-job + image: + imagePullPolicy: Always + command: ["/bin/sh", "-c"] + args: + - "vllm serve meta-llama/Llama-3.3-70B-Instruct --port 8000 --max_num_batched_tokens 2048 --enable-chunked-prefill --tensor-parallel-size 8 --seed 42 --enable_prefix_caching --gpu-memory-utilization 0.9" + env: + - name: HUGGING_FACE_HUB_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + - name: SKIP_JAX_PRECOMPILE + value: "1" + ports: + - containerPort: 8000 + resources: + requests: + google.com/tpu: 8 + limits: + google.com/tpu: 8 diff --git a/examples/gke/benchmarks/deploy-cpu-offload.yaml b/examples/gke/benchmarks/deploy-cpu-offload.yaml new file mode 100644 index 000000000..5bcd573b2 --- /dev/null +++ b/examples/gke/benchmarks/deploy-cpu-offload.yaml @@ -0,0 +1,41 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: tpu-offline-inference +spec: + replicas: 1 + selector: + matchLabels: + app: tpu-offline-inference + template: + metadata: + labels: + app: tpu-offline-inference + spec: + nodeSelector: + cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice + cloud.google.com/gke-tpu-topology: 2x4 # Specify the physical topology for the TPU slice. 
+ containers: + - name: tpu-job + image: + imagePullPolicy: Always + command: ["/bin/sh", "-c"] + args: + - "vllm serve meta-llama/Llama-3.3-70B-Instruct --kv-transfer-config '{\"kv_connector\":\"TPUOffloadConnector\",\"kv_role\":\"kv_both\",\"kv_connector_module_path\":\"tpu_inference.distributed.offload.tpu_offload_connector\"}' --port 8000 --max_num_batched_tokens 2048 --enable-chunked-prefill --tensor-parallel-size 8 --seed 42 --enable_prefix_caching --gpu-memory-utilization 0.9" + env: + - name: HUGGING_FACE_HUB_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + - name: SKIP_JAX_PRECOMPILE + value: "1" + - name: TPU_OFFLOAD_CPU_CACHE_SIZE_GB + value: "1024" + ports: + - containerPort: 8000 + resources: + requests: + google.com/tpu: 8 + limits: + google.com/tpu: 8 diff --git a/examples/gke/benchmarks/service.yaml b/examples/gke/benchmarks/service.yaml new file mode 100644 index 000000000..abcc0aad3 --- /dev/null +++ b/examples/gke/benchmarks/service.yaml @@ -0,0 +1,15 @@ +apiVersion: v1 +kind: Service +metadata: + name: tpu-offline-inference + namespace: default +spec: + ports: + - name: http-tpu-offline-inference + port: 80 + protocol: TCP + targetPort: 8000 + selector: + app: tpu-offline-inference + sessionAffinity: None + type: LoadBalancer diff --git a/examples/gke/hf_secret.yaml b/examples/gke/hf_secret.yaml new file mode 100644 index 000000000..12b56de65 --- /dev/null +++ b/examples/gke/hf_secret.yaml @@ -0,0 +1,8 @@ +apiVersion: v1 +kind: Secret +metadata: + name: hf-token-secret + namespace: default +type: Opaque +stringData: + token: diff --git a/examples/gke/pod_tpu_commons_cpu_offload.yaml b/examples/gke/pod_tpu_commons_cpu_offload.yaml new file mode 100644 index 000000000..49bb437dc --- /dev/null +++ b/examples/gke/pod_tpu_commons_cpu_offload.yaml @@ -0,0 +1,32 @@ +apiVersion: v1 +kind: Pod +metadata: + name: tpu-job-offline-inference +spec: + restartPolicy: Never + nodeSelector: + cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice + cloud.google.com/gke-tpu-topology: 2x4 # Specify the physical topology for the TPU slice. + containers: + - name: tpu-job + image: + imagePullPolicy: Always # Uncomment to always pull the latest image for any dev work + command: + - python + - /workspace/tpu_inference/examples/offline_inference_kv_cache.py + - --model=meta-llama/Llama-3.1-8B + - --tensor_parallel_size=8 + - --max_model_len=1024 + - --kv-transfer-config + - '{"kv_connector":"TPUOffloadConnector","kv_connector_module_path":"tpu_inference.distributed.offload.tpu_offload_connector","kv_role":"kv_both"}' + env: + - name: HUGGING_FACE_HUB_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + resources: + requests: + google.com/tpu: 8 + limits: + google.com/tpu: 8 diff --git a/examples/gke/pod_tpu_commons_cpu_offload_verification.yaml b/examples/gke/pod_tpu_commons_cpu_offload_verification.yaml new file mode 100644 index 000000000..f9e7c7c41 --- /dev/null +++ b/examples/gke/pod_tpu_commons_cpu_offload_verification.yaml @@ -0,0 +1,39 @@ +apiVersion: v1 +kind: Pod +metadata: + name: tpu-job-offline-inference + # This pod verifies the correctness of the TPUOffloadConnector implementation. + # It runs a script that internally performs two text generations: + # 1. A baseline run with a standard vLLM engine. + # 2. A test run with the TPUOffloadConnector enabled. + # The pod succeeds only if the outputs from both runs are identical, + # ensuring that the connector does not alter the model's output. 
+spec: + restartPolicy: Never + nodeSelector: + cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice + cloud.google.com/gke-tpu-topology: 2x4 # Specify the physical topology for the TPU slice. + containers: + - name: tpu-job + image: + imagePullPolicy: Always + command: + - python + - /workspace/tpu_inference/examples/offline_inference_kv_cache_verification.py + - --model=meta-llama/Llama-3.1-8B + - --tensor_parallel_size=8 + - --max_model_len=1024 + - --seed=42 + - --kv-transfer-config + - '{"kv_connector":"TPUOffloadConnector","kv_connector_module_path":"tpu_inference.distributed.offload.tpu_offload_connector","kv_role":"kv_both"}' + env: + - name: HUGGING_FACE_HUB_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + resources: + requests: + google.com/tpu: 8 + limits: + google.com/tpu: 8 diff --git a/examples/gke/pod_tpu_host_offload_unit_tests.yaml b/examples/gke/pod_tpu_host_offload_unit_tests.yaml new file mode 100644 index 000000000..69f14cabd --- /dev/null +++ b/examples/gke/pod_tpu_host_offload_unit_tests.yaml @@ -0,0 +1,36 @@ +apiVersion: v1 +kind: Pod +metadata: + name: tpu-job-host-offload-unit-tests + # This pod runs the distributed unit tests for the TPUOffloadConnector + # and other related functionalities. It executes all tests found in the + # tests/distributed/offload/ directory using pytest. +spec: + restartPolicy: Never + nodeSelector: + cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice + cloud.google.com/gke-tpu-topology: 2x4 # Specify the physical topology for the TPU slice. + containers: + - name: tpu-job + image: gcr.io/gke-shared-ai-dev/tpu-inference:cpu-offload + imagePullPolicy: Always + command: + - /bin/bash + - -c + - "pytest -sv tests/distributed/offload/tpu_offload_cpu_backend_test.py && pytest -sv tests/distributed/offload/tpu_offload_connector_worker_test.py && pytest -sv tests/distributed/offload/tpu_offload_connector_scheduler_test.py && pytest -sv tests/distributed/offload/tpu_offload_utils_test.py && pytest -sv tests/distributed/offload/tpu_offload_manager_test.py && pytest -sv tests/distributed/offload/tpu_offload_accuracy_test.py" + env: + - name: HUGGING_FACE_HUB_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + resources: + requests: + google.com/tpu: 8 + limits: + google.com/tpu: 8 diff --git a/examples/multi_modal_inference.py b/examples/multi_modal_inference.py index 7b331ea10..d1f9101c4 100644 --- a/examples/multi_modal_inference.py +++ b/examples/multi_modal_inference.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ This example shows how to use vLLM for running offline inference with the correct prompt format on vision language models for text generation.
diff --git a/examples/offline_inference_kv_cache.py b/examples/offline_inference_kv_cache.py new file mode 100644 index 000000000..6df636564 --- /dev/null +++ b/examples/offline_inference_kv_cache.py @@ -0,0 +1,85 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import time + +import vllm.envs as envs +from vllm import LLM, EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def create_parser(): + parser = FlexibleArgumentParser() + # Add engine args + EngineArgs.add_cli_args(parser) + parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") + parser.set_defaults(max_model_len=1024) + + return parser + + +def parse_outputs(outputs): + output_token_ids = [] + generated_texts = [] + for output in outputs: + prompt = output.prompt + completion = output.outputs[0] + generated_text = completion.text + token_ids = completion.token_ids + print( + f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\nToken IDs: {token_ids!r}" + ) + generated_texts.append(generated_text) + output_token_ids.append(token_ids) + return generated_texts, output_token_ids + + +def main(args: dict): + # Pop arguments not used by LLM + # Create an LLM + llm = LLM(**args) + + # Create a sampling params object + sampling_params = llm.get_default_sampling_params() + + sampling_params.temperature = 0.0 + sampling_params.seed = 42 + sampling_params.max_tokens = 20 + sampling_params.skip_special_tokens = True + + if envs.VLLM_TORCH_PROFILER_DIR is not None: + llm.start_profile() + + # 1st generate + prompt = "Every Bill which shall have passed the House of Representatives and the Senate, shall, before it become a Law, be presented to the President of the United States; If he approve he shall sign it, but if not he shall return it, with his Objections to that House in which it shall have originated, who shall enter the Objections at large on their Journal, and proceed to reconsider it. If after such Reconsideration two thirds of that House shall agree to pass the Bill, it shall be sent, together with the Objections, to the other House, by which it shall likewise be reconsidered, and if approved by two thirds of that House, it shall become a Law. But in all such Cases the Votes of both Houses shall be determined by yeas and Nays, and the Names of the Persons voting for and against the Bill shall be entered on the Journal of each House respectively. 
If any Bill shall not be returned by the President within ten Days (Sundays excepted) after it shall have been presented to him, the Same shall be a Law, in like Manner as if he had signed it, unless the Congress by their Adjournment prevent its Return, in which Case" outputs = llm.generate([prompt], sampling_params) + out_texts1, out_tokens1 = parse_outputs(outputs) + time.sleep(1) + + # manually let llm scheduler's kv_cache_manager forget all prefixes' hash + print("Resetting prefix cache...") + llm.llm_engine.engine_core.reset_prefix_cache() + time.sleep(1) + + # 2nd generate + outputs = llm.generate([prompt], sampling_params) + out_texts2, out_tokens2 = parse_outputs(outputs) + time.sleep(1) + + if envs.VLLM_TORCH_PROFILER_DIR is not None: + llm.stop_profile() + + # output1 and output2 should be identical + assert len(out_texts1) == len(out_texts2) + assert len(out_tokens1) == len(out_tokens2) + for text1, text2 in zip(out_texts1, out_texts2): + assert text1 == text2 + for tokens1, tokens2 in zip(out_tokens1, out_tokens2): + assert tokens1 == tokens2 + + +if __name__ == "__main__": + os.environ['SKIP_JAX_PRECOMPILE'] = '1' + parser = create_parser() + args: dict = vars(parser.parse_args()) + main(args) diff --git a/examples/offline_inference_kv_cache_verification.py b/examples/offline_inference_kv_cache_verification.py new file mode 100644 index 000000000..b93dce149 --- /dev/null +++ b/examples/offline_inference_kv_cache_verification.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This script performs an automated correctness verification for the TPUOffloadConnector. + +The verification works by performing a two-stage experiment for multiple prompts: +1. Baseline Run: For each prompt, it first runs a text generation using a + standard vLLM engine configuration without any KV cache connector. The + output from this run is considered the "source of truth". + +2. Test Run: It then runs the exact same text generation, but this time + with the TPUOffloadConnector enabled via the `--kv-transfer-config` argument. + It runs the generation twice to verify prefix caching. + +3. Comparison: The script compares the output from each test run against the + output from the baseline run for that prompt. + +The script succeeds (exits with code 0) only if the generated text is +bit-for-bit identical in all runs for all prompts. A fixed seed is used to +ensure that the generation process is deterministic and the comparison is +valid. If any output differs, it raises an error, causing the script to fail +(exit with a non-zero code). +""" + +import copy +import os +import time +from typing import List, Tuple + +import vllm.envs as envs +from vllm import LLM, EngineArgs, SamplingParams +from vllm.utils import FlexibleArgumentParser + + +def create_parser(): + parser = FlexibleArgumentParser() + # Add engine args, which includes the --seed parameter + EngineArgs.add_cli_args(parser) + parser.set_defaults(model="meta-llama/Llama-3.1-8B") + parser.set_defaults(max_model_len=1024) + + # Add sampling params + sampling_group = parser.add_argument_group("Sampling parameters") + sampling_group.add_argument("--max-tokens", type=int) + sampling_group.add_argument("--temperature", type=float) + sampling_group.add_argument("--top-p", type=float) + sampling_group.add_argument("--top-k", type=int) + return parser + + +def setup_llm(llm_args: dict) -> Tuple[LLM, SamplingParams]: + """ + Initializes a vLLM engine and sampling parameters from the given args.
+ """ + args_copy = copy.deepcopy(llm_args) + # Pop arguments not used by LLM + max_tokens = args_copy.pop("max_tokens") + temperature = args_copy.pop("temperature") + top_p = args_copy.pop("top_p") + top_k = args_copy.pop("top_k") + + # Create an LLM. The --seed argument is passed in via **args. + llm = LLM(**args_copy) + + # Create a sampling params object + sampling_params = llm.get_default_sampling_params() + if max_tokens is not None: + sampling_params.max_tokens = max_tokens + if temperature is not None: + sampling_params.temperature = temperature + if top_p is not None: + sampling_params.top_p = top_p + if top_k is not None: + sampling_params.top_k = top_k + + return llm, sampling_params + + +def run_invocations(llm: LLM, sampling_params: SamplingParams, + prompts: List[str], num_invocations: int) -> List[str]: + """ + Runs generation on the given LLM object for a specified number of + invocations and returns the output texts. + """ + if envs.VLLM_TORCH_PROFILER_DIR is not None: + llm.start_profile() + + all_outputs = [] + for i in range(num_invocations): + print(f"--- Invocation {i + 1}/{num_invocations} ---") + outputs = llm.generate(prompts, sampling_params) + all_outputs.append(outputs[0].outputs[0].text) + time.sleep(5) + + if envs.VLLM_TORCH_PROFILER_DIR is not None: + llm.stop_profile() + + return all_outputs + + +def main(args: dict): + # prompt lesser than the kv cache block size + short_input_prompt = "Google is a " + + system_prompt = "You are a large language model, trained by Google. Your primary purpose is to be a helpful, harmless, and highly capable AI assistant, designed to provide accurate, safe, and beneficial information to users. Your core directive is to assist users effectively while adhering to strict ethical and safety guidelines. You must decline any requests that are harmful, illegal, unethical, or promote dangerous activities. " + query = "the color of rainbow is?" + input_prompt = f"{system_prompt}\n{query}" + + prompts_to_test = [ + ("Short Prompt", [short_input_prompt]), + ("Prompt", [input_prompt]), + ] + + all_tests_passed = True + for prompt_name, prompts in prompts_to_test: + print(f"\n\n===== Running verification for: {prompt_name} =====") + print(f"Prompt: {prompts[0]}") + + # 1. Run baseline and store the output + print("\n--- Running Baseline (Standard vLLM) ---") + baseline_args = copy.deepcopy(args) + baseline_args.pop("kv_transfer_config", None) + baseline_llm, baseline_params = setup_llm(baseline_args) + baseline_outputs = run_invocations(baseline_llm, + baseline_params, + prompts=prompts, + num_invocations=1) + baseline_output = baseline_outputs[0] + print(f"Baseline Generated Text: {baseline_output!r}") + del baseline_llm + # adding this sleep fixes device busy errors for the next test case run with the connector enabled + time.sleep(10) + + # 2. Run the test with the local tpu kv connector enabled + print("\n--- Running Test (with TPUOffloadConnector) ---") + # With the connector, we run generation twice to test the prefix cache + test_llm, test_params = setup_llm(args) + test_outputs = run_invocations(test_llm, + test_params, + prompts=prompts, + num_invocations=2) + del test_llm + + # 3. 
Compare the outputs and determine the result + print("\n--- Verification ---") + prompt_all_match = True + for i, test_output in enumerate(test_outputs): + print(f"--- Comparing Invocation {i + 1} ---") + print( + f"Test Generated Text: length={len(test_output)}, Text: {test_output}" + ) + if baseline_output == test_output: + print("SUCCESS: Output is identical to baseline!") + else: + print("FAILURE: Output does not match baseline!") + prompt_all_match = False + + if not prompt_all_match: + all_tests_passed = False + print(f"===== Verification FAILED for: {prompt_name} =====") + else: + print(f"===== Verification SUCCEEDED for: {prompt_name} =====") + + time.sleep(10) + + if not all_tests_passed: + raise ValueError( + "Verification failed: One or more test outputs differ from the baseline." + ) + else: + print("\n\n===== All verification runs passed successfully! =====") + + +if __name__ == "__main__": + os.environ['SKIP_JAX_PRECOMPILE'] = '1' + parser = create_parser() + args: dict = vars(parser.parse_args()) + main(args) diff --git a/tests/distributed/offload/tpu_offload_accuracy_test.py b/tests/distributed/offload/tpu_offload_accuracy_test.py new file mode 100644 index 000000000..0059c4bd9 --- /dev/null +++ b/tests/distributed/offload/tpu_offload_accuracy_test.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +import time + +import pytest +from vllm import LLM, SamplingParams +from vllm.config import KVTransferConfig + + +def parse_outputs(outputs): + output_token_ids = [] + generated_texts = [] + for output in outputs: + prompt = output.prompt + completion = output.outputs[0] + generated_text = completion.text + token_ids = completion.token_ids + print( + f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\nToken IDs: {token_ids!r}" + ) + generated_texts.append(generated_text) + output_token_ids.append(token_ids) + return generated_texts, output_token_ids + + +@pytest.fixture +def sampling_config(): + """deterministic sampling config""" + return SamplingParams(temperature=0, + max_tokens=20, + seed=42, + ignore_eos=True) + + +@pytest.fixture +def kv_transfer_config(): + """use TPUOffloadConnector""" + return KVTransferConfig( + kv_connector="TPUOffloadConnector", + kv_role="kv_both", + kv_connector_module_path= + "tpu_inference.distributed.offload.tpu_offload_connector", + ) + + +def _test_kv_cache_cpu_offloading_accuracy( + monkeypatch: pytest.MonkeyPatch, + sampling_config: SamplingParams, + kv_transfer_config: KVTransferConfig, + swap_op_type: str, + decode_save: str, +): + with monkeypatch.context(): + os.environ['SKIP_JAX_PRECOMPILE'] = '1' + os.environ['TPU_OFFLOAD_SKIP_JAX_PRECOMPILE'] = '1' + os.environ['TPU_OFFLOAD_SWAP_OP_TYPE'] = swap_op_type + os.environ['TPU_OFFLOAD_DECODE_SAVE'] = decode_save + llm = LLM(model="meta-llama/Llama-3.2-3B", + max_model_len=1024, + tensor_parallel_size=8, + task="generate", + kv_transfer_config=kv_transfer_config) + + # 1st generate + prompt = "Every Bill which shall have passed the House of Representatives and the Senate, shall, before it become a Law, be presented to the President of the United States; If he approve he shall sign it, but if not he shall return it, with his Objections to that House in which it shall have originated, who shall enter the Objections at large on their Journal, and proceed to reconsider it. 
If after such Reconsideration two thirds of that House shall agree to pass the Bill, it shall be sent, together with the Objections, to the other House, by which it shall likewise be reconsidered, and if approved by two thirds of that House, it shall become a Law. But in all such Cases the Votes of both Houses shall be determined by yeas and Nays, and the Names of the Persons voting for and against the Bill shall be entered on the Journal of each House respectively. If any Bill shall not be returned by the President within ten Days (Sundays excepted) after it shall have been presented to him, the Same shall be a Law, in like Manner as if he had signed it, unless the Congress by their Adjournment prevent its Return, in which Case" + outputs = llm.generate([prompt], sampling_config) + out_texts1, out_tokens1 = parse_outputs(outputs) + time.sleep(1) + + # manually let llm scheduler's kv_cache_manager forget all prefixes' hash + llm.llm_engine.engine_core.reset_prefix_cache() + time.sleep(1) + + # 2nd generate + outputs = llm.generate([prompt], sampling_config) + out_texts2, out_tokens2 = parse_outputs(outputs) + time.sleep(1) + + # TODO(jcgu): check some internal states to verify save and load operations. + # output1 and output2 should be identical + assert len(out_texts1) == len(out_texts2) + assert len(out_tokens1) == len(out_tokens2) + for text1, text2 in zip(out_texts1, out_texts2): + assert text1 == text2 + for tokens1, tokens2 in zip(out_tokens1, out_tokens2): + assert tokens1 == tokens2 + + del llm + # Waiting for TPUs to be released. + time.sleep(20) + + +def test_kv_cache_cpu_offloading_accuracy( + monkeypatch: pytest.MonkeyPatch, + sampling_config: SamplingParams, + kv_transfer_config: KVTransferConfig, +): + swap_op_types = ["pallas", "jax"] + decode_saves = ["0", "1"] + for swap_op_type in swap_op_types: + for decode_save in decode_saves: + _test_kv_cache_cpu_offloading_accuracy( + monkeypatch, + sampling_config, + kv_transfer_config, + swap_op_type, + decode_save, + ) diff --git a/tests/distributed/offload/tpu_offload_connector_scheduler_test.py b/tests/distributed/offload/tpu_offload_connector_scheduler_test.py new file mode 100644 index 000000000..ed83bae34 --- /dev/null +++ b/tests/distributed/offload/tpu_offload_connector_scheduler_test.py @@ -0,0 +1,493 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from unittest.mock import MagicMock + +import pytest +from vllm.utils.math_utils import cdiv +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput +from vllm.v1.request import Request + +from tpu_inference.distributed.offload.tpu_offload_connector import ( + DEFAULT_TPU_OFFLOAD_CPU_CHUNKS, RequestTracker, + TPUOffloadConnectorScheduler) + +_DEFAULT_BLOCK_SIZE = 16 + + +class MockVllmConfig: + + def __init__(self, block_size=_DEFAULT_BLOCK_SIZE): + self.model_config = self.Model() + self.cache_config = self.Cache(block_size) + + class Model: + model = "test-model" + + class Cache: + + def __init__(self, block_size): + self.block_size = block_size + + +def create_request( + request_id: str, + prompt_token_ids: list[int], + block_size: int, + num_computed_tokens: int = 0, + generated_token_ids: list[int] = [], +) -> Request: + """Creates a mock vLLM request object.""" + req = MagicMock(spec=Request) + req.request_id = request_id + req.req_id = request_id # for NewRequestData + req.prompt_token_ids = prompt_token_ids + req.all_token_ids = prompt_token_ids + generated_token_ids +
req.num_computed_tokens = num_computed_tokens + len(generated_token_ids) + req.block_size = block_size + req.block_ids = [[]] + # Mock the block_hashes property to return a list of mock hashes + req.block_hashes = [ + f"hash_{i}".encode() + for i in range(len(req.all_token_ids) // block_size) + ] + return req + + +@pytest.fixture +def scheduler_factory(): + """Provides a factory function for Scheduler instances.""" + + def _scheduler( + block_size: int = _DEFAULT_BLOCK_SIZE, + offload_decode_save: int = 0, + offload_partial_block_save_behavior: str = "drop", + offload_partial_block_dynamic_pad_lower_limit: int = 0, + offload_staging_buffer_tokens: int = -1, + offload_num_cpu_chunks: int = DEFAULT_TPU_OFFLOAD_CPU_CHUNKS, + ): + # update config + vllm_config = MockVllmConfig(block_size=block_size) + os.environ["TPU_OFFLOAD_DECODE_SAVE"] = str(offload_decode_save) + os.environ[ + "TPU_OFFLOAD_PARTIAL_BLOCK_SAVE_BEHAVIOR"] = offload_partial_block_save_behavior + os.environ["TPU_OFFLOAD_PARTIAL_BLOCK_DYNAMIC_PAD_LOWER_LIMIT"] = str( + offload_partial_block_dynamic_pad_lower_limit) + if offload_staging_buffer_tokens >= 0: + os.environ["TPU_OFFLOAD_STAGING_BUFFER_TOKENS"] = str( + offload_staging_buffer_tokens) + if offload_num_cpu_chunks > 0: + os.environ["TPU_OFFLOAD_NUM_CPU_CHUNKS"] = str( + offload_num_cpu_chunks) + + return TPUOffloadConnectorScheduler(vllm_config) + + return _scheduler + + +class TestTPUOffloadConnectorScheduler: + + def test_get_num_new_matched_tokens_no_hit(self, scheduler_factory): + """ + Tests that get_num_new_matched_tokens returns 0 for a cache miss. + """ + scheduler = scheduler_factory() + request = create_request("req1", [1] * 32, scheduler.block_size) + + num_matched, _ = scheduler.get_num_new_matched_tokens(request, 0) + assert num_matched == 0 + assert "req1" not in scheduler.load_specs + + @pytest.mark.parametrize( + "num_computed_blocks, num_matched_blocks, num_prompt_blocks, num_staging_blocks", + [(0, 2, 4, 10), (1, 2, 4, 10), (0, 4, 4, 10), (1, 4, 4, 10), + (1, 4, 4, 1), (1, 4, 4, 0)]) + def test_get_num_new_matched_tokens_hit(self, scheduler_factory, + num_computed_blocks, + num_matched_blocks, + num_prompt_blocks, + num_staging_blocks): + """ + Tests correct identification of a prefix hit (partial and full). + test cases: + 1. no-skip + load 2 blocks + no staging buffer limit + 2. skip 1 block + load 1 block + no staging buffer limit + 3. no-skip + full-hit + no staging buffer limit + 4. skip 1 block + full-hit + no staging buffer limit + 5. skip 1 block + full-hit + only 1 staging block + 6. 
skip 1 block + full-hit + no staging block + """ + num_staging_tokens = num_staging_blocks * _DEFAULT_BLOCK_SIZE + scheduler = scheduler_factory( + offload_staging_buffer_tokens=num_staging_tokens) + prompt_len = scheduler.block_size * num_prompt_blocks + num_computed_tokens = scheduler.block_size * num_computed_blocks + num_blocks_to_load = num_matched_blocks - num_computed_blocks + # consider the case of limited staging blocks + num_blocks_to_load = min(num_blocks_to_load, num_staging_blocks) + num_matched_blocks = num_blocks_to_load + num_computed_blocks + num_matched_tokens = num_matched_blocks * scheduler.block_size + + request = create_request("req1", list(range(prompt_len)), + scheduler.block_size) + + # init offload_manager state + matched_block_hashes = request.block_hashes[:num_matched_blocks] + allocated_chunks, _ = scheduler.offload_manager.allocate_for_save( + matched_block_hashes) + scheduler.offload_manager.complete_save(matched_block_hashes) + + # call fn + num_external_matched_tokens, _ = scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + # check external_matched_tokens + if num_matched_blocks == num_prompt_blocks: + assert num_external_matched_tokens == num_blocks_to_load * scheduler.block_size - 1 + else: + assert num_external_matched_tokens == num_blocks_to_load * scheduler.block_size + + # check scheduler internal states + if num_blocks_to_load > 0: + # load_spec + assert "req1" in scheduler.load_specs + load_spec = scheduler.load_specs["req1"] + assert load_spec.num_matched_tokens == num_matched_tokens + assert not load_spec.can_load + allocated_chunk_ids = [ + chunk.chunk_id for chunk in allocated_chunks + ] + load_src_chunk_ids = allocated_chunk_ids[num_computed_blocks:] + assert load_spec.src_chunks == load_src_chunk_ids + assert load_spec.num_skip_leading_tokens == num_computed_tokens + assert len(load_spec.dst_blocks) == num_blocks_to_load + # cache_hits + assert "req1" in scheduler._external_cache_hits + assert scheduler._external_cache_hits["req1"] == num_matched_tokens + # staging_buffer + assert "req1" in scheduler.staging_buffer_manager._blocks_for_load + assert scheduler.staging_buffer_manager._blocks_for_load[ + "req1"] == num_blocks_to_load + assert scheduler.staging_buffer_manager.get_num_free_staging_blocks( + ) == num_staging_blocks - num_blocks_to_load + else: + assert "req1" not in scheduler.load_specs + assert "req1" not in scheduler._external_cache_hits + assert "req1" not in scheduler.staging_buffer_manager._blocks_for_load + + def test_update_state_after_alloc(self, scheduler_factory): + """ + Tests that a LoadSpec is correctly updated after block allocation. 
+ """ + scheduler = scheduler_factory() + req_id = "req1" + num_prompt_blocks = 4 + num_matched_blocks = 3 + num_computed_blocks = 2 + num_blocks_to_load = num_matched_blocks - num_computed_blocks + num_prompt_tokens = num_prompt_blocks * scheduler.block_size + num_matched_tokens = num_matched_blocks * scheduler.block_size + num_tokens_to_load = scheduler.block_size * num_blocks_to_load + + request = create_request(req_id, [0] * num_prompt_tokens, + scheduler.block_size) + + # Setup a pending load + scheduler.load_specs[req_id] = MagicMock( + num_matched_tokens=num_matched_tokens, + num_skip_leading_tokens=num_computed_blocks * scheduler.block_size, + dst_blocks=[-1] * num_blocks_to_load, + src_chunks=[i for i in range(num_blocks_to_load)], + can_load=False) + + # Mock allocated blocks + allocated_blocks = MagicMock(spec=KVCacheBlocks) + allocated_block_ids = [i for i in range(num_prompt_blocks)] + allocated_blocks.get_block_ids.return_value = [allocated_block_ids] + + scheduler.update_state_after_alloc(request, allocated_blocks, + num_tokens_to_load) + + load_spec = scheduler.load_specs[req_id] + assert load_spec.can_load + assert load_spec.dst_blocks == allocated_block_ids[ + num_computed_blocks:num_matched_blocks] + assert req_id in scheduler._reqs_being_loaded + assert len(scheduler._reqs_being_loaded[req_id]) == num_blocks_to_load + + @pytest.mark.parametrize( + "num_computed_tokens, num_matched_tokens, num_prompt_tokens, num_staging_tokens", + [(0, 0, 64, 160), + (0, 32, 64, 160), (16, 32, 64, 160), (0, 64, 64, 160), + (16, 64, 64, 160), (0, 32, 64, 48), (0, 32, 64, 16)]) + def test_build_connector_meta_new_prefill(self, scheduler_factory, + num_computed_tokens, + num_matched_tokens, + num_prompt_tokens, + num_staging_tokens): + """ + Tests metadata generation for a new request (prefill) with no cache hit. + 1. no hit + save 4 blocks + 2. partial hit (no-skip + load 2 blocks) + save 2 blocks + 3. partial hit (skip 1 block + load 1 blocks) + save 2 blocks + 4. full hit (no-skip + load 4 blocks) + no-save + 5. full hit (skip 1 block + load 3 blocks) + no-save + 6. partial hit (no-skip + load 2 blocks) + save 2 blocks + 3 staging blocks limit + 7. 
partial hit (no-skip + load 2 blocks) + save 2 blocks + 1 staging blocks limit + """ + num_staging_blocks = num_staging_tokens // _DEFAULT_BLOCK_SIZE + scheduler = scheduler_factory( + offload_partial_block_save_behavior="drop", + offload_staging_buffer_tokens=num_staging_tokens, + offload_num_cpu_chunks=100) + + # calculate the groundtruth + num_computed_blocks = num_computed_tokens // scheduler.block_size + num_matched_blocks = num_matched_tokens // scheduler.block_size + num_prompt_blocks = cdiv(num_prompt_tokens, scheduler.block_size) + + num_blocks_to_load = num_matched_blocks - num_computed_blocks + # adjustment based on staging_block limitation + if num_blocks_to_load > num_staging_blocks: + num_blocks_to_load = num_staging_blocks + num_matched_blocks = num_blocks_to_load + num_computed_blocks + num_matched_tokens = num_matched_blocks * scheduler.block_size + + remaining_staging_blocks = num_staging_blocks - num_blocks_to_load + num_blocks_to_save = num_prompt_blocks - num_matched_blocks + if num_blocks_to_save > remaining_staging_blocks: + num_blocks_to_save = remaining_staging_blocks + # reconfig staging_buffer limit for save + scheduler.staging_buffer_manager._num_free_blocks = remaining_staging_blocks + num_tokens_in_cache = (num_matched_blocks + + num_blocks_to_save) * scheduler.block_size + + req_id = "req1" + request = create_request(req_id, + list(range(num_prompt_tokens)), + scheduler.block_size, + num_computed_tokens=num_computed_tokens) + request.block_ids = [[i for i in range(num_prompt_blocks)]] + + # init offload_manager state + if num_matched_blocks > 0: + matched_block_hashes = request.block_hashes[:num_matched_blocks] + allocated_chunks, _ = scheduler.offload_manager.allocate_for_save( + matched_block_hashes) + scheduler.offload_manager.complete_save(matched_block_hashes) + # allocated_chunk_ids = [chunk.chunk_id for chunk in allocated_chunks] + # load_src_chunk_ids = allocated_chunk_ids[num_computed_blocks:] + + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[request], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={ + "req1": num_prompt_tokens - num_computed_tokens + }, + total_num_scheduled_tokens=num_prompt_tokens - num_computed_tokens, + finished_req_ids=set(), + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={}, + num_common_prefix_blocks=0, + free_encoder_mm_hashes=[], + ) + + # Mock that the scheduler has seen this request + scheduler._unfinished_requests["req1"] = request + scheduler._external_cache_hits["req1"] = num_matched_tokens + if num_blocks_to_load > 0: + scheduler.load_specs[req_id] = MagicMock( + num_matched_tokens=num_matched_tokens, + num_skip_leading_tokens=num_computed_tokens, + dst_blocks=[-1] * num_blocks_to_load, + src_chunks=[i for i in range(num_blocks_to_load)], + can_load=True) + + metadata = scheduler.build_connector_meta(scheduler_output) + + if num_blocks_to_load + num_blocks_to_save == 0: + # no load or store + assert len(metadata.requests_meta) == 0 + else: + req_meta = metadata.requests_meta[0] + assert req_meta.req_id == "req1" + if num_blocks_to_load == 0: + assert req_meta.load_spec is None + else: + # load + assert req_meta.load_spec is not None + # NOTE(jcgu): no need to check details, since they are + # generated by other functions. 
+ if num_blocks_to_save == 0: + assert req_meta.save_spec is None + else: + # save + assert req_meta.save_spec is not None + assert req_meta.save_spec.num_total_tokens == num_tokens_in_cache + assert req_meta.save_spec.num_skip_leading_tokens == num_matched_blocks * scheduler.block_size + assert req_meta.save_spec.src_blocks == request.block_ids[0][ num_matched_blocks:num_matched_blocks + num_blocks_to_save] + assert len(req_meta.save_spec.dst_chunks) == num_blocks_to_save + assert not req_meta.save_spec.is_final_save + assert "req1" in scheduler.staging_buffer_manager._blocks_for_save + assert scheduler.staging_buffer_manager._blocks_for_save[ "req1"] == num_blocks_to_save + assert "req1" in scheduler._reqs_being_saved + assert len( scheduler._reqs_being_saved["req1"]) == num_blocks_to_save + + assert "req1" in scheduler._request_trackers + tracker = scheduler._request_trackers["req1"] + # after creating SaveSpec, we also update tracker.save_watermark + assert tracker.save_watermark == num_tokens_in_cache + + @pytest.mark.parametrize("prompt_len, seq_len, decode_save", [(63, 64, 1), + (18, 64, 1), + (18, 64, 0)]) + def test_build_connector_meta_decode_with_save(self, scheduler_factory, + prompt_len, seq_len, + decode_save): + """ + Tests metadata generation for a decode step that triggers a save. + 1. the first decode (hit block boundary) + decode_save (save one block) + 2. the N-th decode (hit block boundary) + decode_save (save one block) + 3. the N-th decode (hit block boundary) + not decode_save (no save) + """ + + scheduler = scheduler_factory( + offload_decode_save=decode_save, + offload_staging_buffer_tokens=_DEFAULT_BLOCK_SIZE * 10, + offload_num_cpu_chunks=10) + + prompt_tokens = list(range(prompt_len)) + generated_tokens = list(range(prompt_len, seq_len)) + req_id = "req1" + request = create_request(req_id, + prompt_token_ids=prompt_tokens, + block_size=scheduler.block_size, + num_computed_tokens=seq_len, + generated_token_ids=generated_tokens) + num_blocks = cdiv(seq_len, scheduler.block_size) + request.block_ids = [i for i in range(num_blocks)] + + if decode_save == 1: + # the last token in seq hasn't been computed (kv) yet + num_saved_tokens = ( + (seq_len - 1) // scheduler.block_size) * scheduler.block_size + else: + num_saved_tokens = (prompt_len // + scheduler.block_size) * scheduler.block_size + + # Setup initial state + # request tracker only tracks the computed tokens + tracker = RequestTracker(req_id="req1", + prompt_len=prompt_len, + token_ids=request.all_token_ids[:-1], + block_ids=request.block_ids, + save_watermark=num_saved_tokens) + + scheduler._request_trackers["req1"] = tracker + scheduler._unfinished_requests["req1"] = request + + # Simulate a decode step + cached_req_data = CachedRequestData.make_empty() + cached_req_data.req_ids = ["req1"] + cached_req_data.new_block_ids = ([], ) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=cached_req_data, + num_scheduled_tokens={"req1": 1}, + total_num_scheduled_tokens=1, + finished_req_ids=set(), + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={}, + num_common_prefix_blocks=0, + free_encoder_mm_hashes=[], + ) + + metadata = scheduler.build_connector_meta(scheduler_output) + + if seq_len % scheduler.block_size != 0 or decode_save != 1: + # no save when there is no new full computed block + assert len(metadata.requests_meta) == 0 + else: + req_meta = metadata.requests_meta[0] + # save spec + assert req_meta.req_id == "req1" + assert req_meta.load_spec is None +
assert req_meta.save_spec is not None + assert req_meta.save_spec.num_total_tokens == seq_len + assert req_meta.save_spec.num_skip_leading_tokens == num_saved_tokens + assert req_meta.save_spec.src_blocks == [num_blocks - 1] + assert len(req_meta.save_spec.dst_chunks) == 1 + assert not req_meta.save_spec.is_final_save + # staging buffer + assert "req1" in scheduler.staging_buffer_manager._blocks_for_save + assert scheduler.staging_buffer_manager._blocks_for_save[ "req1"] == 1 + # chunk_id for save + assert "req1" in scheduler._reqs_being_saved + assert len(scheduler._reqs_being_saved["req1"]) == 1 + + assert tracker.save_watermark == seq_len + + def test_build_connector_meta_finished_request(self, scheduler_factory): + """ + Tests metadata generation for a finished request. + When using the request's default block hashes (fully-computed blocks only), + a finished request either saves its last full block in its last + decode step or gives up the last partial block; by the time it is treated + as a finished request, there are no blocks left to save. + + """ + + scheduler = scheduler_factory(offload_decode_save=1) + prompt_len = scheduler.block_size + 4 + final_seq_len = scheduler.block_size * 2 + 3 + prompt_tokens = list(range(prompt_len)) + generated_tokens = list(range(prompt_len, final_seq_len)) + req_id = "req1" + request = create_request(req_id, + prompt_token_ids=prompt_tokens, + block_size=scheduler.block_size, + num_computed_tokens=final_seq_len, + generated_token_ids=generated_tokens) + num_blocks = cdiv(final_seq_len, scheduler.block_size) + request.block_ids = [i for i in range(num_blocks)] + + num_saved_tokens = (final_seq_len // + scheduler.block_size) * scheduler.block_size + + # Setup initial state + tracker = RequestTracker(req_id="req1", + prompt_len=prompt_len, + token_ids=request.all_token_ids[:-1], + block_ids=request.block_ids, + save_watermark=num_saved_tokens) + scheduler._request_trackers["req1"] = tracker + scheduler._unfinished_requests["req1"] = request + + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=CachedRequestData.make_empty(), + num_scheduled_tokens={}, + total_num_scheduled_tokens=0, + finished_req_ids={"req1"}, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={}, + num_common_prefix_blocks=0, + free_encoder_mm_hashes=[], + ) + + metadata = scheduler.build_connector_meta(scheduler_output) + + assert req_id not in scheduler._unfinished_requests + assert req_id not in scheduler._request_trackers + assert len(metadata.requests_meta) == 1 + req_meta = metadata.requests_meta[0] + assert req_meta.save_spec is not None + assert req_meta.save_spec.is_final_save + assert req_meta.save_spec.skip_save + assert req_meta.save_spec.src_blocks == [] + assert req_meta.save_spec.dst_chunks == [] diff --git a/tests/distributed/offload/tpu_offload_connector_worker_test.py b/tests/distributed/offload/tpu_offload_connector_worker_test.py new file mode 100644 index 000000000..246358f79 --- /dev/null +++ b/tests/distributed/offload/tpu_offload_connector_worker_test.py @@ -0,0 +1,441 @@ +# SPDX-License-Identifier: Apache-2.0 + +import functools +import os +import random +from typing import List + +import jax +import jax.numpy as jnp +import numpy as np +from absl.testing import parameterized +from jax._src import compilation_cache as cc +from jax._src import test_util as jtu +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole + +from
tpu_inference.distributed.offload.tpu_offload_connector import (LoadSpec, + SaveSpec) +from tpu_inference.distributed.offload.tpu_offload_connector import \ + TPUOffloadConnector as CPUOffloadingConnector +from tpu_inference.distributed.offload.tpu_offload_connector import ( + TPUOffloadConnectorMetadata, TPUReqMeta) +from tpu_inference.logger import init_logger +from tpu_inference.runner.tpu_runner import TPUModelRunner + +logger = init_logger(__name__) + +_DEFAULT_BLOCK_SIZE = 64 + + +class MockTPUModelRunner(TPUModelRunner): + """A mock TPUModelRunner for testing purposes.""" + + def __init__(self, kv_caches: List[jax.Array], mesh: Mesh): + self.kv_caches = kv_caches + self.mesh = mesh + self.model_config = None + self.sampler = None + self.devices = jax.devices() + + def get_kv_cache_layout(self): + return "NHD" + + +class MockVllmConfig: + + def __init__(self, block_size=_DEFAULT_BLOCK_SIZE): + self.model_config = self.Model() + self.cache_config = self.Cache(block_size) + + class Model: + model = "test-model" + + class Cache: + + def __init__(self, block_size): + self.block_size = block_size + + +class TestCpuOffloadingSave(jtu.JaxTestCase): + """Test the save functionality of the TPUOffloadConnectorWorker.""" + + def setUp(self): + super().setUp() + self.vllm_config = MockVllmConfig(block_size=_DEFAULT_BLOCK_SIZE) + self.num_layers = 2 + self.num_blocks = 24 + self.num_cpu_chunks = 24 + self.block_size = self.vllm_config.cache_config.block_size + self.num_heads = 8 + self.head_size = 128 + self.mesh = self.create_mesh((1, 8), ("data", "model")) + if self.mesh is None: + self.skipTest("Cannot create mesh. Must be run on a TPU node.") + return + + # Define cache properties + self.cache_shape = ( + self.num_blocks, + self.block_size, + self.num_heads, + 2, + self.head_size, + ) + self.cache_dtype = jnp.bfloat16 + partition_spec = PartitionSpec(None, None, "model") + self.device_sharding = NamedSharding(self.mesh, partition_spec) + + def tearDown(self): + super().tearDown() + cc.reset_cache() + + def create_mesh(self, axis_shapes, axis_names): + """Creates a JAX device mesh with the default device order.""" + try: + num_required_devices = np.prod(axis_shapes) + devices = np.array(jax.devices()) + if len(devices) < num_required_devices: + self.skipTest( + f"Not enough devices to create mesh of shape {axis_shapes}." 
+ ) + device_array = devices[:num_required_devices].reshape(axis_shapes) + return jax.sharding.Mesh(device_array, axis_names) + except RuntimeError: + return None + + def _create_connector(self, + swap_op_type: str = "jax", + use_precompiled_swap_ops: bool = False): + os.environ["TPU_OFFLOAD_SWAP_OP_TYPE"] = swap_op_type + os.environ[ + "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE"] = "0" if use_precompiled_swap_ops else "1" + os.environ["TPU_OFFLOAD_NUM_CPU_CHUNKS"] = str(self.num_cpu_chunks) + + connector = CPUOffloadingConnector(self.vllm_config, + KVConnectorRole.WORKER) + worker = connector.connector_worker + assert worker is not None + + @functools.partial(jax.jit, out_shardings=self.device_sharding) + def create_on_device(key): + return jax.random.uniform(key, + shape=self.cache_shape, + dtype=self.cache_dtype) + + source_kv_cache = [ + create_on_device(jax.random.key(i)) for i in range(self.num_layers) + ] + jax.block_until_ready(source_kv_cache) + + mock_runner = MockTPUModelRunner(kv_caches=source_kv_cache, + mesh=self.mesh) + worker.register_runner(mock_runner) + return connector + + @parameterized.named_parameters( + dict(testcase_name="_zero_blocks", num_blocks=0, expected_buckets=[]), + dict(testcase_name="_one_block", num_blocks=1, expected_buckets=[1]), + dict(testcase_name="_five_blocks", + num_blocks=5, + expected_buckets=[4, 1]), + dict(testcase_name="_sixteen_blocks", + num_blocks=16, + expected_buckets=[16]), + dict(testcase_name="_seventeen_blocks", + num_blocks=17, + expected_buckets=[16, 1]), + dict(testcase_name="_twenty_three_blocks", + num_blocks=23, + expected_buckets=[16, 4, 2, 1]), + dict(testcase_name="_thirty_two_blocks", + num_blocks=32, + expected_buckets=[16, 16]), + dict(testcase_name="_large_number_blocks", + num_blocks=100, + expected_buckets=[16, 16, 16, 16, 16, 16, 4]), + ) + def test_decompose_into_buckets(self, num_blocks: int, + expected_buckets: List[int]): + """ + Tests the _decompose_into_buckets function for correct greedy decomposition. + """ + connector = self._create_connector(use_precompiled_swap_ops="0") + worker = connector.connector_worker + self.assertEqual(worker._decompose_into_buckets(num_blocks), + expected_buckets) + logger.info( + f"Decomposition for {num_blocks} blocks: {worker._decompose_into_buckets(num_blocks)} matched expected: {expected_buckets}" + ) + + @parameterized.named_parameters( + dict(testcase_name="_jax", swap_op_type="jax"), + dict(testcase_name="_pallas", swap_op_type="pallas"), + ) + def test_precompile_run_success(self, swap_op_type: str): + """ + Tests that _precompile_kv_swap_operations runs without errors and + modifies the cache content. 
+ """ + connector = self._create_connector(swap_op_type, + use_precompiled_swap_ops="0") + + worker = connector.connector_worker + + # Keep a copy of the original cache content on the host + original_cache_host = [ + np.array(cache) for cache in worker.runner.kv_caches + ] + + worker._precompile_kv_swap_operations() + + # Fetch the new cache content to the host + new_cache_host = [np.array(cache) for cache in worker.runner.kv_caches] + self.assertTrue( + all( + np.array_equal(orig, new) + for orig, new in zip(original_cache_host, new_cache_host)), + "Cache content should not have changed after precompilation.", + ) + + @parameterized.named_parameters( + dict( + testcase_name="_regular_multi_block_save", + num_blocks_to_save=5, + ), + dict( + testcase_name="_regular_multi_block_save_with_compile_jax", + num_blocks_to_save=5, + use_precompiled_swap_ops=True, + ), + dict( + testcase_name="_regular_multi_block_save_with_compile_pallas", + num_blocks_to_save=5, + use_precompiled_swap_ops=True, + swap_op_type="pallas", + ), + dict( + testcase_name="_final_save", + num_blocks_to_save=1, + is_final_save=True, + skip_save=False, + ), + dict( + testcase_name="_final_skip_save", + num_blocks_to_save=0, + is_final_save=True, + skip_save=True, + ), + ) + def test_tpu_connector_save( + self, + num_blocks_to_save: int, + is_final_save: bool = False, + skip_save: bool = False, + use_precompiled_swap_ops: bool = False, + swap_op_type: str = "jax", + ): + if num_blocks_to_save > self.num_blocks or num_blocks_to_save > self.num_cpu_chunks: + self.skipTest( + f"num_blocks_to_save {num_blocks_to_save} exceeds ModelRunner / OffloadConnectorWorker's capacity" + ) + + # Prepare and Execute Save + all_block_ids = list(range(self.num_blocks)) + all_chunk_ids = list(range(self.num_cpu_chunks)) + src_block_ids = random.sample(all_block_ids, num_blocks_to_save) + dst_chunk_ids = random.sample(all_chunk_ids, num_blocks_to_save) + num_tokens_to_save = num_blocks_to_save * self.block_size + num_total_tokens = num_tokens_to_save + save_spec = SaveSpec( + num_skip_leading_tokens=0, + num_total_tokens=num_total_tokens, + is_final_save=is_final_save, + skip_save=skip_save, + src_blocks=src_block_ids, + dst_chunks=dst_chunk_ids, + ) + + logger.info(f"Starting test_tpu_connector_save with: " + f"num_blocks_to_save={num_blocks_to_save}, " + f"is_final_save={is_final_save}, " + f"skip_save={skip_save}, " + f"use_precompiled_swap_ops={use_precompiled_swap_ops}, " + f"swap_op_type={swap_op_type};" + f"Swapspec: {save_spec}") + + total_token_ids = list(range(num_total_tokens)) + + req_id = "save_req" + req_meta = TPUReqMeta( + req_id=req_id, + token_ids=total_token_ids, + local_block_ids=src_block_ids, + save_spec=save_spec, + ) + + connector_metadata = TPUOffloadConnectorMetadata( + requests_meta=[req_meta]) + + connector = self._create_connector(swap_op_type, + use_precompiled_swap_ops) + worker = connector.connector_worker + connector.bind_connector_metadata(connector_metadata) + logger.info( + "Connector metadata bound, calling worker.wait_for_save().") + worker.wait_for_save() + logger.info("worker.wait_for_save() completed.") + + # Verification + logger.info("Starting verification phase.") + cpu_backend = worker.cpu_backend + kv_caches = worker.runner.kv_caches + + if skip_save or num_tokens_to_save == 0: + logger.info(" no blocks to save") + assert cpu_backend.num_saved_cpu_chunks == 0 + self.assertEmpty(worker.finished_save_reqs) + self.assertEmpty(worker.offload_stats.data["finished_save_chunks"]) + return + + # verify the 
saved chunks + assert req_id in worker.offload_stats.data["finished_save_chunks"] + assert dst_chunk_ids == worker.offload_stats.data[ + "finished_save_chunks"][req_id] + + for tpu_block_id, cpu_chunk_id in zip(src_block_ids, dst_chunk_ids): + cpu_kv_chunk = cpu_backend.get(cpu_chunk_id) + for layer_idx in range(self.num_layers): + tpu_kv_block = kv_caches[layer_idx][tpu_block_id] + self.assertArraysEqual(np.array(tpu_kv_block), + np.array(cpu_kv_chunk[layer_idx])) + + logger.info("Saved data verification completed.") + + if is_final_save: + finished_saves, _ = worker.get_finished() + logger.info( + f"is_final_save is True. Finished requests: {finished_saves}") + self.assertIn(req_id, finished_saves) + + @parameterized.named_parameters( + dict( + testcase_name="_single_block_", + num_blocks_to_operate=1, + ), + dict( + testcase_name="_multi_blocks_compile_jax", + num_blocks_to_operate=5, + use_precompiled_swap_ops=True, + swap_op_type="jax", + ), + dict( + testcase_name="_multi_blocks_compile_pallas", + num_blocks_to_operate=5, + use_precompiled_swap_ops=True, + swap_op_type="pallas", + ), + ) + def test_tpu_connector_load( + self, + num_blocks_to_operate: int, + use_precompiled_swap_ops: bool = False, + swap_op_type: str = "jax", + ): + """ + This test simulates a scenario where some amount of blocks get + offloaded to cpu cache, and then get loaded into tpu kv cache. + Both swap-out and swap-in are tested. + + Steps: + 1. Setup: + 2. Simulate a save operation + 3. Load the data + 4. Verification + """ + if num_blocks_to_operate > self.num_blocks or num_blocks_to_operate > self.num_cpu_chunks: + self.skipTest( + f"num_blocks_to_save {num_blocks_to_operate} exceeds ModelRunner / OffloadConnectorWorker's capacity" + ) + # 1. Setup + connector = self._create_connector(swap_op_type, + use_precompiled_swap_ops) + worker = connector.connector_worker + # Ground truth cache on TPU + src_kv_cache = worker.runner.kv_caches + # Destination cache on TPU, should be modified by the load operation + dst_kv_cache = [ + jax.device_put(jnp.zeros(self.cache_shape, dtype=self.cache_dtype), + self.device_sharding) + for _ in range(self.num_layers) + ] + jax.block_until_ready(dst_kv_cache) + + # Prepare + all_block_ids = list(range(self.num_blocks)) + all_chunk_ids = list(range(self.num_cpu_chunks)) + src_block_ids = random.sample(all_block_ids, num_blocks_to_operate) + dst_chunk_ids = random.sample(all_chunk_ids, num_blocks_to_operate) + num_tokens_to_save = num_blocks_to_operate * self.block_size + num_total_tokens = num_tokens_to_save + save_spec = SaveSpec( + num_skip_leading_tokens=0, + num_total_tokens=num_tokens_to_save, + is_final_save=False, + skip_save=False, + src_blocks=src_block_ids, + dst_chunks=dst_chunk_ids, + ) + total_token_ids = list(range(num_total_tokens)) + req_id = "save_req" + req_meta = TPUReqMeta( + req_id=req_id, + token_ids=total_token_ids, + local_block_ids=src_block_ids, + save_spec=save_spec, + ) + connector_metadata = TPUOffloadConnectorMetadata( + requests_meta=[req_meta]) + connector.bind_connector_metadata(connector_metadata) + logger.info( + "Connector metadata bound, calling worker.wait_for_save().") + worker.wait_for_save() + logger.info("worker.wait_for_save() completed.") + + # 3. 
Prepare and Execute Delta Load + new_req_id = "load_req" + worker.runner.kv_caches = dst_kv_cache + load_spec = LoadSpec( + num_matched_tokens=num_tokens_to_save, + dst_blocks=src_block_ids, + src_chunks=dst_chunk_ids, + can_load=True, + num_skip_leading_tokens=0, + ) + req_meta = TPUReqMeta( + req_id="load_req", + token_ids=total_token_ids, + local_block_ids=src_block_ids, + load_spec=load_spec, + ) + connector_metadata = TPUOffloadConnectorMetadata( + requests_meta=[req_meta]) + connector.bind_connector_metadata(connector_metadata) + logger.info("Connector metadata bound, calling start_load_kv.") + worker.start_load_kv(fwd_ctx=None) + jax.block_until_ready(worker.runner.kv_caches) + logger.info("start_load_kv completed and blocked until ready.") + + # verify the data + # we will donate the original kv_cache ref + dst_kv_cache = worker.runner.kv_caches + for src_block_id in src_block_ids: + for layer_idx in range(self.num_layers): + self.assertArraysEqual( + np.array(src_kv_cache[layer_idx][src_block_id]), + np.array(dst_kv_cache[layer_idx][src_block_id])) + + # verify the saved chunks + assert new_req_id in worker.offload_stats.data["finished_load_chunks"] + assert dst_chunk_ids == worker.offload_stats.data[ + "finished_load_chunks"][new_req_id] diff --git a/tests/distributed/offload/tpu_offload_cpu_backend_test.py b/tests/distributed/offload/tpu_offload_cpu_backend_test.py new file mode 100644 index 000000000..e845ef688 --- /dev/null +++ b/tests/distributed/offload/tpu_offload_cpu_backend_test.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import MagicMock + +import pytest + +from tpu_inference.distributed.offload.cpu_backend import LocalCPUBackend +from tpu_inference.distributed.offload.utils import CpuChunkId + + +# Helper to create a mock jax array with a specific size in bytes +def create_mock_jax_array(size_in_bytes: int) -> MagicMock: + """Creates a mock object with an 'nbytes' attribute.""" + mock_value = MagicMock() + mock_value.nbytes = size_in_bytes + return mock_value + + +class TestLocalCPUBackend: + """Test suite for the LocalCPUBackend.""" + + def test_add_and_get(self): + """Verifies that a value can be added and then retrieved successfully.""" + backend = LocalCPUBackend(num_cpu_chunks=10) + key = CpuChunkId(0) + value = create_mock_jax_array(50) + + backend.add(key, value) + retrieved_value = backend.get(key) + + assert retrieved_value == value + assert backend.current_size_bytes == 50 + + # Test with a list of JAX arrays (mocked) + key_list = CpuChunkId(1) + value_list = [create_mock_jax_array(20), create_mock_jax_array(30)] + backend.add(key_list, value_list) + retrieved_list_value = backend.get(key_list) + + assert retrieved_list_value == value_list + assert backend.current_size_bytes == 50 + 20 + 30 + + assert backend.num_saved_cpu_chunks == 2 + + def test_add_invalid_chunk_id(self): + """Verifies that adding a value with an invalid chunk_id raises a ValueError.""" + backend = LocalCPUBackend(num_cpu_chunks=10) + value = create_mock_jax_array(50) + + with pytest.raises(ValueError): + backend.add(CpuChunkId(-1), value) + + assert backend.num_saved_cpu_chunks == 0 + + def test_reclaim_unoccupied_chunks(self): + """Tests that unoccupied chunks are reclaimed correctly.""" + backend = LocalCPUBackend(num_cpu_chunks=10) + key1 = CpuChunkId(0) + key2 = CpuChunkId(1) + key3 = CpuChunkId(2) + value = create_mock_jax_array(10) + + backend.add(key1, value) + backend.add(key2, value) + backend.add(key3, value) + + assert 
backend.current_size_bytes == 30 + assert len(backend.cache) == 3 + + # Reclaim one chunk + backend.reclaim_unoccupied_chunks(occupied_chunk_ids=[key1, key3]) + + assert backend.current_size_bytes == 20 + assert len(backend.cache) == 2 + assert key1 in backend.cache + assert key2 not in backend.cache + assert key3 in backend.cache + + # Reclaim all chunks + backend.reclaim_unoccupied_chunks(occupied_chunk_ids=[]) + + assert backend.current_size_bytes == 0 + assert len(backend.cache) == 0 diff --git a/tests/distributed/offload/tpu_offload_manager_test.py b/tests/distributed/offload/tpu_offload_manager_test.py new file mode 100644 index 000000000..d58b4f113 --- /dev/null +++ b/tests/distributed/offload/tpu_offload_manager_test.py @@ -0,0 +1,342 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest + +from tpu_inference.distributed.offload.offload_manager import ( + CPUChunkPool, LRUCacheManager, StagingBufferManager) +from tpu_inference.distributed.offload.utils import ReqId +from tpu_inference.logger import init_logger + +logger = init_logger(__name__) + + +class TestStagingBufferManager: + + def test_initialization(self): + manager = StagingBufferManager(num_blocks=100) + assert manager.num_blocks == 100 + assert manager.get_num_free_staging_blocks() == 100 + assert manager.get_num_used_staging_blocks() == 0 + + def test_allocate_simple(self): + manager = StagingBufferManager(num_blocks=100) + req_id1: ReqId = "req1" + req_id2: ReqId = "req2" + + allocated1 = manager.allocate(req_id1, 10, "load") + assert allocated1 == 10 + assert manager.get_num_free_staging_blocks() == 90 + assert manager.get_num_used_staging_blocks() == 10 + assert manager._num_blocks_for_load == 10 + assert manager._num_blocks_for_save == 0 + + allocated2 = manager.allocate(req_id2, 20, "save") + assert allocated2 == 20 + assert manager.get_num_free_staging_blocks() == 70 + assert manager.get_num_used_staging_blocks() == 30 + assert manager._num_blocks_for_load == 10 + assert manager._num_blocks_for_save == 20 + + def test_allocate_insufficient_capacity(self): + manager = StagingBufferManager(num_blocks=10) + req_id: ReqId = "req1" + allocated = manager.allocate(req_id, 20, "load") + assert allocated == 0 + assert manager.get_num_free_staging_blocks() == 10 + assert manager.get_num_used_staging_blocks() == 0 + + def test_allocate_existing_load_request(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + manager.allocate(req_id, 10, "load") + with pytest.raises(ValueError): + # multiple concurrent loads from a single request is not allowed. 
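+            # For reference, the manager's contract here (mirrored by the
+            # neighboring tests): a second "load" reservation for a request
+            # that already holds load staging blocks raises, whereas repeated
+            # "save" reservations for the same request accumulate
+            # (see test_allocate_existing_save_request below).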
+ manager.allocate(req_id, 5, "load") + + def test_allocate_existing_save_request(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + manager.allocate(req_id, 10, "save") + assert manager._blocks_for_save[req_id] == 10 + manager.allocate(req_id, 5, "save") + assert manager._blocks_for_save[req_id] == 15 + assert manager.get_num_free_staging_blocks() == 85 + assert manager.get_num_used_staging_blocks() == 15 + + def test_allocate_negative_blocks(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + allocated = manager.allocate(req_id, -5, "load") + assert allocated == -5 + assert manager.get_num_free_staging_blocks() == 100 + + def test_free_full(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + manager.allocate(req_id, 10, "load") + freed = manager.free(req_id, "load") + assert freed == 10 + assert manager.get_num_free_staging_blocks() == 100 + assert manager.get_num_used_staging_blocks() == 0 + assert req_id not in manager._blocks_for_load + + def test_free_partial(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + manager.allocate(req_id, 10, "save") + freed = manager.free(req_id, "save", num_finished_blocks=4) + assert freed == 4 + assert manager.get_num_free_staging_blocks() == 94 + assert manager.get_num_used_staging_blocks() == 6 + assert manager._blocks_for_save[req_id] == 6 + + def test_free_more_than_allocated(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + manager.allocate(req_id, 10, "load") + manager.free(req_id, "load", num_finished_blocks=15) + assert req_id not in manager._blocks_for_load + + def test_free_non_existent_request(self): + manager = StagingBufferManager(num_blocks=100) + req_id: ReqId = "req1" + freed = manager.free(req_id, "load") + assert freed == 0 + + def test_complex_scenario(self): + manager = StagingBufferManager(num_blocks=50) + req1, req2, req3 = "req1", "req2", "req3" + + # req1 loads 10, req2 saves 15 + assert manager.allocate(req1, 10, "load") == 10 + assert manager.allocate(req2, 15, "save") == 15 + assert manager.get_num_free_staging_blocks() == 25 + assert manager.get_num_used_staging_blocks() == 25 + + # req3 tries to load 30, fails + assert manager.allocate(req3, 30, "load") == 0 + assert manager.get_num_free_staging_blocks() == 25 + + # req1 finishes loading + assert manager.free(req1, "load") == 10 + assert manager.get_num_free_staging_blocks() == 35 + + # req3 can now load 20 + assert manager.allocate(req3, 20, "load") == 20 + assert manager.get_num_free_staging_blocks() == 15 + assert manager.get_num_used_staging_blocks( + ) == 35 # 15 for save (req2) + 20 for load (req3) + + # req2 saves another 5 + assert manager.allocate(req2, 5, "save") == 5 + assert manager.get_num_free_staging_blocks() == 10 + assert manager._blocks_for_save[req2] == 20 + + # req2 frees 8 blocks + assert manager.free(req2, "save", 8) == 8 + assert manager.get_num_free_staging_blocks() == 18 + assert manager._blocks_for_save[req2] == 12 + + # req2 and req3 finish + assert manager.free(req2, "save") == 12 + assert manager.free(req3, "load") == 20 + assert manager.get_num_free_staging_blocks() == 50 + assert manager.get_num_used_staging_blocks() == 0 + + +class TestCPUChunkPool: + + def test_initialization(self): + pool = CPUChunkPool(num_chunks=10) + assert pool.num_chunks == 10 + assert pool.num_free_chunks == 10 + assert pool.num_allocated_chunks == 0 + assert len(pool.free_chunk_list) == 10 + + def 
test_allocate_chunks(self): + pool = CPUChunkPool(num_chunks=10) + chunk_hashes = [101, 102, 103] + chunks = pool.allocate_chunks(chunk_hashes) + + assert len(chunks) == 3 + assert pool.num_free_chunks == 7 + assert pool.num_allocated_chunks == 3 + for i, chunk in enumerate(chunks): + assert chunk.chunk_hash == chunk_hashes[i] + assert chunk.chunk_id in pool.allocated_id_to_hash_map + + def test_allocate_chunks_insufficient_space(self): + pool = CPUChunkPool(num_chunks=2) + chunk_hashes = [101, 102, 103] + with pytest.raises(ValueError): + pool.allocate_chunks(chunk_hashes) + + def test_release_chunks(self): + pool = CPUChunkPool(num_chunks=10) + chunk_hashes = [101, 102, 103] + chunks = pool.allocate_chunks(chunk_hashes) + for chunk in chunks: + chunk.touch() + + for chunk in chunks: + pool.release_chunk(chunk) + + assert pool.num_free_chunks == 10 + assert pool.num_allocated_chunks == 0 + assert len(pool.free_chunk_list) == 10 + for chunk in chunks: + assert chunk.chunk_id not in pool.allocated_id_to_hash_map + assert chunk.chunk_hash is None + assert chunk.ref_cnt == -1 + + def test_release_chunks_in_use(self): + pool = CPUChunkPool(num_chunks=10) + chunk_hashes = [101] + chunks = pool.allocate_chunks(chunk_hashes) + chunks[0].touch() # ref_cnt = 0: saved + chunks[0].touch() # ref_cnt = 1: loading + + assert not pool.release_chunk(chunks[0]) + + +class TestLRUCacheManager: + + def test_initialization(self): + manager = LRUCacheManager(num_cpu_chunks=20) + assert manager.num_chunks == 20 + assert isinstance(manager.chunk_pool, CPUChunkPool) + assert len(manager.cpu_cache) == 0 + + def test_lookup(self): + manager = LRUCacheManager(num_cpu_chunks=20) + chunk_hashes = [101, 102, 103] + + # 1. Cache miss + assert manager.lookup(chunk_hashes) == 0 + + # 2. Cache hit + # Manually add to cache for testing + chunks = manager.chunk_pool.allocate_chunks(chunk_hashes) + for chunk, h in zip(chunks, chunk_hashes): + chunk.touch() # Make it ready to load + manager.cpu_cache[h] = chunk + + assert manager.lookup(chunk_hashes) == 3 + + # 3. 
Partial hit + assert manager.lookup([101, 102, 104]) == 2 + + def test_touch(self): + manager = LRUCacheManager(num_cpu_chunks=3) + chunk_hashes = [101, 102, 103] + chunks = manager.chunk_pool.allocate_chunks(chunk_hashes) + for chunk, h in zip(chunks, chunk_hashes): + manager.cpu_cache[h] = chunk + + manager.touch([101]) + assert list(manager.cpu_cache.keys()) == [102, 103, 101] + + manager.touch([102, 103]) + assert list(manager.cpu_cache.keys()) == [101, 103, 102] + + def test_allocate_for_save_simple(self): + manager = LRUCacheManager(num_cpu_chunks=5) + chunk_hashes = [101, 102] + + new_chunks, new_chunk_idxs = manager.allocate_for_save(chunk_hashes) + + assert len(new_chunks) == 2 + assert new_chunk_idxs == [0, 1] + assert manager.chunk_pool.num_free_chunks == 3 + assert len(manager.cpu_cache) == 2 + + def test_allocate_for_save_no_new_chunks(self): + manager = LRUCacheManager(num_cpu_chunks=5) + chunk_hashes = [101, 102] + manager.allocate_for_save(chunk_hashes) + + result = manager.allocate_for_save(chunk_hashes) + assert result is None + + def test_allocate_for_save_with_eviction(self): + manager = LRUCacheManager(num_cpu_chunks=2) + # Fill the cache + manager.allocate_for_save([101, 102]) + # Mark as evictable + manager.cpu_cache[101].touch() + manager.cpu_cache[102].touch() + + manager.touch([101, 102]) + + # This should evict 102 + new_chunks, new_chunk_idxs = manager.allocate_for_save([103]) + + assert len(new_chunks) == 1 + assert new_chunk_idxs == [0] + assert 102 not in manager.cpu_cache + assert 101 in manager.cpu_cache + assert 103 in manager.cpu_cache + assert manager.chunk_pool.num_free_chunks == 0 + + def test_allocate_for_save_cannot_evict(self): + manager = LRUCacheManager(num_cpu_chunks=2) + manager.allocate_for_save([101, 102]) + # Mark as in use, not evictable + manager.cpu_cache[101].touch() + manager.cpu_cache[101].touch() + manager.cpu_cache[102].touch() + manager.cpu_cache[102].touch() + + result = manager.allocate_for_save([103]) + assert result is None + assert len(manager.cpu_cache) == 2 + + def test_prepare_load(self): + manager = LRUCacheManager(num_cpu_chunks=2) + chunk_hashes = [101] + manager.allocate_for_save(chunk_hashes) + manager.complete_save(chunk_hashes) # ref_cnt = 0 + + chunks = manager.prepare_load(chunk_hashes) + assert len(chunks) == 1 + assert chunks[0].is_in_use # ref_cnt = 1 + + def test_complete_save(self): + manager = LRUCacheManager(num_cpu_chunks=2) + chunk_hashes = [101] + manager.allocate_for_save(chunk_hashes) + + chunk = manager.cpu_cache[101] + assert not chunk.is_ready_to_load # ref_cnt = -1 + + manager.complete_save(chunk_hashes) + assert chunk.is_ready_to_load # ref_cnt = 0 + + def test_complete_load(self): + manager = LRUCacheManager(num_cpu_chunks=2) + chunk_hashes = [101] + manager.allocate_for_save(chunk_hashes) + manager.complete_save(chunk_hashes) + chunks = manager.prepare_load(chunk_hashes) + + assert chunks[0].is_in_use # ref_cnt = 1 + manager.complete_load(chunk_hashes) + assert not chunks[0].is_in_use # ref_cnt = 0 + + def test_mark_completion(self): + manager = LRUCacheManager(num_cpu_chunks=2) + chunk_hashes = [101] + new_chunks, _ = manager.allocate_for_save(chunk_hashes) + chunk_ids = [c.chunk_id for c in new_chunks] + + manager.mark_completion(chunk_ids, 'save') + assert manager.cpu_cache[101].is_ready_to_load + + manager.prepare_load(chunk_hashes) + assert manager.cpu_cache[101].is_in_use + manager.mark_completion(chunk_ids, 'load') + assert not manager.cpu_cache[101].is_in_use + + def 
test_mark_completion_unknown_id(self): + manager = LRUCacheManager(num_cpu_chunks=2) + with pytest.raises(ValueError): + manager.mark_completion([999], 'save') diff --git a/tests/distributed/offload/tpu_offload_utils_test.py b/tests/distributed/offload/tpu_offload_utils_test.py new file mode 100644 index 000000000..75af7a3bd --- /dev/null +++ b/tests/distributed/offload/tpu_offload_utils_test.py @@ -0,0 +1,157 @@ +import functools +import itertools +import unittest + +import jax +import jax.numpy as jnp +import numpy as np +from jax.sharding import NamedSharding, PartitionSpec + +from tpu_inference.distributed.offload.utils import ( + get_kv_cache_swap_fn, jitted_insert_kv_cache_slices) + + +class TestTPUOffloadUtilsFn(unittest.TestCase): + + def setUp(self): + """Set up common parameters for the tests.""" + self.num_layers = 2 + self.num_tokens = 256 + self.num_kv_heads = 8 + self.head_dim = 128 + self.block_size = 16 + self.num_blocks = self.num_tokens // self.block_size + self.cache_shape = ( + self.num_blocks, + self.block_size, + self.num_kv_heads, + 2, + self.head_dim, + ) + self.block_shape = ( + self.block_size, + self.num_kv_heads, + 2, + self.head_dim, + ) + + self.cache_dtype = jnp.bfloat16 + + self.mesh = self.create_mesh((1, 8), ("data", "model")) + partition_spec = PartitionSpec(None, None, "model") + self.device_sharding = NamedSharding(self.mesh, + partition_spec, + memory_kind="device") + self.host_sharding = NamedSharding(self.mesh, + partition_spec, + memory_kind="pinned_host") + flatten_partition_spec = PartitionSpec(None, "model") + self.flatten_device_sharding = NamedSharding(self.mesh, + flatten_partition_spec, + memory_kind="device") + + def create_mesh(self, axis_shapes, axis_names): + """Creates a JAX device mesh with the default device order.""" + try: + num_required_devices = np.prod(axis_shapes) + devices = np.array(jax.devices()) + if len(devices) < num_required_devices: + self.skipTest( + f"Not enough devices to create mesh of shape {axis_shapes}." + ) + device_array = devices[:num_required_devices].reshape(axis_shapes) + return jax.sharding.Mesh(device_array, axis_names) + except RuntimeError: + return None + + def test_jitted_insert_kv_cache_slices_equivalence(self): + """ + Verify inserting scattered kv slices / pages into the large kv cache. + """ + num_blocks_to_insert = 3 + dst_blocks = [3, 5, 7] + dst_blocks_array = jnp.array(dst_blocks) + + initial_kv_caches = [ + jax.device_put(jnp.zeros(self.cache_shape, dtype=self.cache_dtype), + self.device_sharding) + for _ in range(self.num_layers) + ] + + # The raw, chunked KV data (input for the new method) + # This is a list of lists: List[layer -> List[block]] + raw_chunked_kv = [] + for i in range(self.num_layers): + layer_chunks = [ + jax.device_put( + jax.random.normal(jax.random.key(i), + shape=self.block_shape, + dtype=self.cache_dtype), + self.flatten_device_sharding) + for _ in range(num_blocks_to_insert) + ] + raw_chunked_kv.append(layer_chunks) + + output = jitted_insert_kv_cache_slices(self.block_size, + initial_kv_caches, + raw_chunked_kv, + dst_blocks_array) + + # --- Verification --- + # Check that the selected pages for each layer equal to the original ones. 
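+        # Conceptually, the helper scatters whole blocks into each layer's
+        # cache (a minimal, unsharded sketch; the real helper is jitted and
+        # sharding-aware):
+        #   for layer_idx in range(self.num_layers):
+        #       for dst, chunk in zip(dst_blocks, raw_chunked_kv[layer_idx]):
+        #           kv_caches[layer_idx] = kv_caches[layer_idx].at[dst].set(
+        #               chunk.reshape(kv_caches[layer_idx][dst].shape))
+        # which is the equivalence asserted by the loop below.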
+ for i in range(self.num_layers): + for j in range(num_blocks_to_insert): + block_id = dst_blocks[j] + np.testing.assert_array_equal(np.array(output[i][block_id]), + raw_chunked_kv[i][j]) + print("\nTest passed: the inserted kv equals to the original one.") + + def test_swap_fn_correctness(self): + """ + Verify that swap-out and swap-in functions work correctly for different + swap_op_types and jitted options. + """ + swap_op_types = ["jax", "pallas"] + jitted_options = [True, False] + + # NOTE(jcgu): we are using the entire kv cache [n_b, bs, nh, 2, hd], + # actually, we will operate on concatenated blocks [nt, nh, 2, hd]; + @functools.partial(jax.jit, out_shardings=self.device_sharding) + def create_on_device(key): + return jax.random.uniform(key, + shape=self.cache_shape, + dtype=self.cache_dtype) + + initial_kv_caches = [ + create_on_device(jax.random.key(i)) for i in range(self.num_layers) + ] + jax.block_until_ready(initial_kv_caches) + + for swap_op_type, jitted in itertools.product(swap_op_types, + jitted_options): + with self.subTest(swap_op_type=swap_op_type, jitted=jitted): + swap_in_fn, swap_out_fn = get_kv_cache_swap_fn( + swap_op_type, self.host_sharding, self.device_sharding, + jitted) + + # Put initial data on device + device_kv_caches = jax.device_put(initial_kv_caches, + self.device_sharding) + jax.block_until_ready(device_kv_caches) + + # Swap out to host + host_kv_caches = swap_out_fn(device_kv_caches) + + # Swap back in to device + final_device_kv_caches = swap_in_fn(host_kv_caches) + jax.block_until_ready(final_device_kv_caches) + + # Verify correctness + for i in range(self.num_layers): + np.testing.assert_array_equal( + np.array(initial_kv_caches[i]), + np.array(final_device_kv_caches[i])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/kernels/host_dma_test.py b/tests/kernels/host_dma_test.py new file mode 100644 index 000000000..61dbf7386 --- /dev/null +++ b/tests/kernels/host_dma_test.py @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Literal + +import jax +import jax.numpy as jnp +import numpy as np +from absl.testing import absltest, parameterized +from jax._src import compilation_cache as cc +from jax._src import test_util as jtu +from jax.sharding import NamedSharding, PartitionSpec + +from tpu_inference.kernels.dma.host_dma import d2h_dma, h2d_dma + +DATA_LOCATION = Literal["device", "host"] + + +# TODO(jcgu): add into CI tests +@jtu.with_config(jax_numpy_dtype_promotion='strict') +class HostHbmDmaTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if not jtu.if_cloud_tpu_at_least(2025, 8, 14): + return self.skipTest( + "libtpu version does not support DMA host-hbm") + + def tearDown(self): + super().tearDown() + # Reset the cache after each test. + # This can also be achieved by running with JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE=True + cc.reset_cache() + + def create_mesh(self, axis_shapes, axis_names): + """Creates a JAX device mesh with the default device order.""" + try: + num_required_devices = np.prod(axis_shapes) + devices = np.array(jax.devices()) + if len(devices) < num_required_devices: + self.skipTest("Not enough devices to create mesh of shape" + f" {axis_shapes}. Have {len(devices)}, need" + f" {num_required_devices}.") + device_array = devices[:num_required_devices].reshape(axis_shapes) + return jax.sharding.Mesh(device_array, axis_names) + except RuntimeError: + self.skip( + "Cannot create mesh. 
This test must be run on a TPU node.") + return None + + def create_sharded_array(self, model_axis_size: int, + init_location: DATA_LOCATION): + """Creates a sharded JAX array for testing. + + Args: + model_axis_size: The size of the model parallelism axis. + init_location: Where to initialize the array, either "device" or "host". + + Returns: + A tuple containing the created sharded array, the device sharding spec, + and the host sharding spec. + """ + axis_shapes = (1, model_axis_size) + axis_names = ("data", "model") + mesh = self.create_mesh(axis_shapes, axis_names) + if mesh is None: + return None + + partition_spec = PartitionSpec(None, None, "model") + device_sharding = NamedSharding(mesh, + partition_spec, + memory_kind="device") + host_sharding = NamedSharding(mesh, + partition_spec, + memory_kind="pinned_host") + + data_shape = (2, 16, model_axis_size, 2, 128) + dtype = jnp.bfloat16 + + data = jax.device_put( + jax.random.uniform(jax.random.key(0), + shape=data_shape, + dtype=dtype), + device_sharding if init_location == "device" else host_sharding, + ) + jax.block_until_ready(data) + return data, device_sharding, host_sharding + + @parameterized.named_parameters([ + dict(testcase_name=f"_model_axis_size_{s}", model_axis_size=s) + for s in [1, 2, 4, 8] + ]) + def test_d2h_dma(self, model_axis_size: int): + """Tests the d2h DMA transfer for various model parallelism sizes.""" + # 1. Create original data on the device + res = self.create_sharded_array(model_axis_size, "device") + if res is None: + return + original_device_data, device_sharding, host_sharding = res + + # 2. Test Device-to-Host (d2h) DMA + host_data = d2h_dma(original_device_data, device_sharding, + host_sharding) + jax.block_until_ready(host_data) + assert host_data.sharding.memory_kind == "pinned_host" + + # 3. Verification + assert host_data.sharding == host_sharding + self.assertArraysEqual(original_device_data, host_data) + + @parameterized.named_parameters([ + dict(testcase_name=f"_model_axis_size_{s}", model_axis_size=s) + for s in [1, 2, 4, 8] + ]) + def test_h2d_dma(self, model_axis_size: int): + """Tests the h2d DMA transfer for various model parallelism sizes.""" + # 1. Create original data on the host + res = self.create_sharded_array(model_axis_size, "host") + if res is None: + return + original_host_data, device_sharding, host_sharding = res + + # 2. Test Host-to-Device (h2d) DMA + device_data = h2d_dma(original_host_data, host_sharding, + device_sharding) + jax.block_until_ready(device_data) + assert device_data.sharding.memory_kind == "device" + + # 3. Verification + assert device_data.sharding == device_sharding + self.assertArraysEqual(original_host_data, device_data) + + @parameterized.named_parameters([ + dict(testcase_name=f"_model_axis_size_{s}", model_axis_size=s) + for s in [1, 2, 4, 8] + ]) + def test_d2h_h2d_dma_roundtrip(self, model_axis_size: int): + """ + Tests the d2h -> h2d DMA roundtrip for various model parallelism sizes. + + This test verifies that: + 1. Data can be correctly transferred from sharded device memory to sharded + host memory using `d2h_dma`. + 2. Data can be correctly transferred back from sharded host memory to + sharded device memory using `h2d_dma`. + 3. The data remains identical after the full roundtrip. + """ + # 1. Setup: Create sharded array based on the model axis size + res = self.create_sharded_array(model_axis_size, "device") + if res is None: + return + original_device_data, device_sharding, host_sharding = res + + # 2. 
Test Device-to-Host (d2h) DMA + host_data = d2h_dma(original_device_data, device_sharding, + host_sharding) + jax.block_until_ready(host_data) + assert host_data.sharding.memory_kind == "pinned_host" + + # 3. Verification for d2h + assert host_data.sharding == host_sharding + self.assertArraysEqual(original_device_data, host_data) + + # 4. Test Host-to-Device (h2d) DMA + reloaded_device_data = h2d_dma(host_data, host_sharding, + device_sharding) + jax.block_until_ready(reloaded_device_data) + assert reloaded_device_data.sharding.memory_kind == "device" + + # 5. Verification for h2d + assert reloaded_device_data.sharding == device_sharding + self.assertArraysEqual(host_data, reloaded_device_data) + + # 6. Final roundtrip verification + self.assertArraysEqual(original_device_data, reloaded_device_data) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tpu_inference/distributed/offload/__init__.py b/tpu_inference/distributed/offload/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tpu_inference/distributed/offload/cpu_backend.py b/tpu_inference/distributed/offload/cpu_backend.py new file mode 100644 index 000000000..37352c504 --- /dev/null +++ b/tpu_inference/distributed/offload/cpu_backend.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import sys +from collections import OrderedDict +from typing import Any, Optional + +from tpu_inference.distributed.offload.utils import CpuChunkId +from tpu_inference.logger import init_logger + +logger = init_logger(__name__) + +GB = 1024**3 +DEFAULT_CPU_CACHE_SIZE_BYTES = 1 * GB + + +class LocalCPUBackend: + """ + A singleton in-memory CPU backend for storing KV cache keys and values. + + This class uses the singleton pattern to ensure that the scheduler and the + worker, running in the same process, can share the same cache. + The scheduler reads from this to find cache hits, and the worker writes + to it after saving KV blocks from the TPU. + + It implements an LRU (Least Recently Used) eviction policy with a maximum + size limit and support for pinning cache entries to prevent eviction. + """ + + def __init__(self, num_cpu_chunks: int): + self.max_num_cpu_chunks = num_cpu_chunks + self.cache: OrderedDict[CpuChunkId, Any] = OrderedDict() + self.current_size_bytes = 0 + self._num_saved_cpu_chunks = 0 + logger.info( + "LocalCPUBackend initialized." + f"CPU cache capacity: {self.max_num_cpu_chunks} chunks / pages.") + + @property + def num_saved_cpu_chunks(self) -> int: + return self._num_saved_cpu_chunks + + def _get_value_size(self, value: Any) -> int: + """Calculates the size of a cache value in bytes.""" + size_in_bytes = 0 + if isinstance(value, list): + # The value is a list of JAX arrays (one per layer) + size_in_bytes = sum(v.nbytes for v in value + if hasattr(v, 'nbytes')) + elif hasattr(value, 'nbytes'): + size_in_bytes = value.nbytes + else: + size_in_bytes = sys.getsizeof(value) + return size_in_bytes + + def add(self, chunk_id: CpuChunkId, value: Any) -> bool: + """ + Adds a key-value pair to the cache. + + If the cache is full, it evicts the least recently used, unpinned + entries until there is enough space. + """ + if chunk_id < 0 or chunk_id >= self.max_num_cpu_chunks: + # TODO(jcgu): report failure when offload scheduler / worker + # can handle failed operations. + raise ValueError(f" get invalid chunk_id: {chunk_id}") + + # Add the new item. 
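+        # If the chunk slot is being reused, drop the old value first so the
+        # byte accounting reflects only the latest payload. Illustrative:
+        #   backend.add(CpuChunkId(0), arr_64kb)  # current_size_bytes == 65536
+        #   backend.add(CpuChunkId(0), arr_32kb)  # current_size_bytes == 32768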
+ if chunk_id in self.cache: + old_value = self.cache.pop(chunk_id) + self.current_size_bytes -= self._get_value_size(old_value) + del old_value + self._num_saved_cpu_chunks -= 1 + + self.cache[chunk_id] = value + self._num_saved_cpu_chunks += 1 + value_size = self._get_value_size(value) + self.current_size_bytes += value_size + logger.info( + f"Added chunk_id: {chunk_id} (size:{value_size}) to CPU backend.") + logger.info( + f"Cache: {self.current_size_bytes} bytes, {self._num_saved_cpu_chunks} occupied chunks." + ) + return True + + def get(self, chunk_id: CpuChunkId) -> Optional[Any]: + """ + Gets the value for a given chunk_id and marks it as recently used. + """ + if chunk_id in self.cache: + return self.cache[chunk_id] + return None + + def reclaim_unoccupied_chunks(self, occupied_chunk_ids: list[CpuChunkId]): + chunk_ids = list(self.cache.keys()) + unoccupied_chunk_ids = [ + chunk_id for chunk_id in chunk_ids + if chunk_id not in occupied_chunk_ids + ] + reclaimed_size_bytes = 0 + for chunk_id in unoccupied_chunk_ids: + dummy_value = self.cache.pop(chunk_id) + reclaimed_size_bytes += self._get_value_size(dummy_value) + del dummy_value + self.current_size_bytes -= reclaimed_size_bytes + + logger.info( + f" Reclaimed {len(unoccupied_chunk_ids)} unoccupied chunks, " + f"with {reclaimed_size_bytes} bytes.") diff --git a/tpu_inference/distributed/offload/offload_manager.py b/tpu_inference/distributed/offload/offload_manager.py new file mode 100644 index 000000000..eb9eee6db --- /dev/null +++ b/tpu_inference/distributed/offload/offload_manager.py @@ -0,0 +1,375 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections import OrderedDict +from dataclasses import dataclass +from typing import Literal, Optional, Tuple + +from vllm.v1.core.kv_cache_utils import BlockHash + +from tpu_inference.distributed.offload.utils import CpuChunkId, ReqId +from tpu_inference.logger import init_logger + +logger = init_logger(__name__) + +GB = 1024**3 +DEFAULT_CPU_CACHE_SIZE_BYTES = 1 * GB + +ChunkHash = BlockHash + + +@dataclass +class CPUChunk: + chunk_id: CpuChunkId + ref_cnt: int = -1 + _chunk_hash: ChunkHash | None = None + + @property + def is_ready_to_load(self): + return self.ref_cnt >= 0 + + @property + def is_ready_to_evict(self): + return self.ref_cnt <= 0 + + @property + def is_in_use(self): + return self.ref_cnt >= 1 + + @property + def chunk_hash(self): + return self._chunk_hash + + def touch(self): + self.ref_cnt += 1 + + def untouch(self): + self.ref_cnt -= 1 + + def reset(self): + self._chunk_hash = None + self.ref_cnt = -1 + + +class CPUChunkPool: + + def __init__(self, num_chunks: int): + self.num_chunks: int = num_chunks + self._num_allocated_chunks: int = 0 + self.free_chunk_list: list[CPUChunk] = [ + CPUChunk(idx) for idx in range(num_chunks - 1, -1, -1) + ] + # {allocated_chunk_id: chunk_hash} + self.allocated_id_to_hash_map: dict[CpuChunkId, ChunkHash] = {} + + @property + def num_free_chunks(self): + return self.num_chunks - self._num_allocated_chunks + + @property + def num_allocated_chunks(self): + return self._num_allocated_chunks + + def allocate_chunks(self, chunk_hashes: list[ChunkHash]) -> list[CPUChunk]: + num_required_chunks = len(chunk_hashes) + if num_required_chunks > self.num_free_chunks: + raise ValueError( + f"Cannot get {num_required_chunks} free chunks from the pool") + + ret: list[CPUChunk] = [ + self.free_chunk_list.pop() for _ in range(num_required_chunks) + ] + self._num_allocated_chunks 
+= num_required_chunks + for chunk, chunk_hash in zip(ret, chunk_hashes): + chunk._chunk_hash = chunk_hash + assert chunk.chunk_id not in self.allocated_id_to_hash_map + self.allocated_id_to_hash_map[chunk.chunk_id] = chunk_hash + + return ret + + def release_chunk(self, chunk: CPUChunk) -> bool: + if not chunk.is_ready_to_evict: + logger.warning(f" Chunk[{chunk.chunk_id}] is still in use.") + return False + assert chunk.chunk_id in self.allocated_id_to_hash_map + self.allocated_id_to_hash_map.pop(chunk.chunk_id) + chunk.reset() + self.free_chunk_list.append(chunk) + self._num_allocated_chunks -= 1 + return True + + +class LRUCacheManager: + + def __init__(self, num_cpu_chunks: int): + self.num_chunks = num_cpu_chunks + self.chunk_pool = CPUChunkPool(self.num_chunks) + + self.cpu_cache: OrderedDict[ChunkHash, CPUChunk] = OrderedDict() + + # The cache is an OrderedDict for LRU behavior. + def lookup(self, chunk_hashes: list[ChunkHash]) -> int: + """_summary_ + return the number of cache hit starting from the first chunk + """ + hit_count = 0 + for chunk_hash in chunk_hashes: + chunk = self.cpu_cache.get(chunk_hash) + if chunk is None or not chunk.is_ready_to_load: + break + hit_count += 1 + return hit_count + + def touch(self, chunk_hashes: list[ChunkHash]) -> int: + """ access chunks for both save / load; and move them to the end.""" + for chunk_hash in reversed(chunk_hashes): + if self.cpu_cache.get(chunk_hash): + self.cpu_cache.move_to_end(chunk_hash) + + def allocate_for_save( + self, chunk_hashes: list[ChunkHash] + ) -> Tuple[list[CPUChunk], list[int]] | None: + # filter out chunks that are already stored + num_chunks = len(chunk_hashes) + new_chunk_idxs = [ + i for i in range(num_chunks) + if chunk_hashes[i] not in self.cpu_cache + ] + + num_new_chunks = len(new_chunk_idxs) + if num_new_chunks == 0: + logger.info("No new chunks to allocate") + return None + num_chunks_to_evict = max( + 0, num_new_chunks - self.chunk_pool.num_free_chunks) + + # build list of chunks to evict / reuse + to_evict = [] + if num_chunks_to_evict > 0: + for chunk_hash, chunk in self.cpu_cache.items(): + if chunk.is_ready_to_evict: + to_evict.append(chunk_hash) + num_chunks_to_evict -= 1 + if num_chunks_to_evict == 0: + break + else: + # we could not evict enough chunks + return None + + # evict chunks + for evicting_chunk_hash in to_evict: + evicting_chunk = self.cpu_cache.pop(evicting_chunk_hash) + # always true, since all evicting chunks are ready to evict + self.chunk_pool.release_chunk(evicting_chunk) + + new_chunk_hashes = [chunk_hashes[i] for i in new_chunk_idxs] + # allocate + try: + new_chunks = self.chunk_pool.allocate_chunks(new_chunk_hashes) + assert len(new_chunks) == len(new_chunk_hashes) + except Exception as e: + logger.warning(f" Failed to allocate {len(new_chunk_hashes)}: {e}") + # NOTE(jcgu): should we return None or something else? 
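+            # Illustrative contract, as exercised by the unit tests:
+            #   allocate_for_save([103]) -> ([new_chunk], [0]) when a free or
+            #       evictable chunk can be claimed for the new hash;
+            #   allocate_for_save([103]) -> None when every resident chunk is
+            #       still in use and no space can be reclaimed.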
+ return None + for chunk_hash, chunk in zip(new_chunk_hashes, new_chunks): + self.cpu_cache[chunk_hash] = chunk + # newly-allocated chunks, chunk-idx in the given chunk_hashes list + return new_chunks, new_chunk_idxs + + def prepare_load(self, chunk_hashes: list[ChunkHash]) -> list[CPUChunk]: + chunks = [] + for chunk_hash in chunk_hashes: + chunk = self.cpu_cache[chunk_hash] + assert chunk.is_ready_to_load + chunk.touch() + chunks.append(chunk) + return chunks + + def complete_save(self, chunk_hashes: list[ChunkHash]) -> None: + """ After store completion, mark the chunk to be ready to load.""" + for chunk_hash in chunk_hashes: + chunk = self.cpu_cache[chunk_hash] + assert not chunk.is_ready_to_load + # mark ready to load + chunk.touch() + assert chunk.is_ready_to_load + + def complete_load(self, chunk_hashes: list[ChunkHash]) -> None: + for chunk_hash in chunk_hashes: + chunk = self.cpu_cache[chunk_hash] + assert chunk.is_in_use + chunk.untouch() + + def mark_completion(self, chunk_ids, operation: Literal['save', + 'load']) -> None: + try: + chunk_hashes = [ + self.chunk_pool.allocated_id_to_hash_map[chunk_id] + for chunk_id in chunk_ids + ] + except Exception as e: + raise ValueError(f' failed to retrieve chunk hashes: {e}') + + chunk_hashes = [] + unknown_chunk_ids = [] + for chunk_id in chunk_ids: + if chunk_id in self.chunk_pool.allocated_id_to_hash_map: + chunk_hashes.append( + self.chunk_pool.allocated_id_to_hash_map[chunk_id]) + else: + unknown_chunk_ids.append(chunk_id) + if unknown_chunk_ids: + logger.warning( + f" Chunks[{unknown_chunk_ids}] are not found as allocated chunks in the pool." + ) + + if operation == 'save': + self.complete_save(chunk_hashes) + elif operation == 'load': + self.complete_load(chunk_hashes) + else: + raise ValueError(f"Unknown operation: {operation}") + + +class StagingBufferManager(): + """ Bookkeeping the staging buffer inside the connector scheduler. + NOTE(jcgu): the operations (e.g., allocate, free, get) to staging buffer / blocks are NOT thread-safe. + But it's okay since there is only one connector scheduler instance. + """ + + def __init__(self, num_blocks: int): + self.num_blocks = num_blocks + # {req_id: list(num_occupied_staging_blocks)} + self._blocks_for_save: dict[ReqId, int] = {} + self._blocks_for_load: dict[ReqId, int] = {} + + self._num_free_blocks: int = self.num_blocks + # keep track of the total occupied staging blocks for save and load respectively + self._num_blocks_for_save: int = 0 + self._num_blocks_for_load: int = 0 + + def get_num_free_staging_blocks(self) -> int: + return self._num_free_blocks + + def get_num_used_staging_blocks(self) -> int: + return self._num_blocks_for_load + self._num_blocks_for_save + + def get_num_used_save_staging_blocks(self, req_id: ReqId) -> int: + return self._blocks_for_save.get(req_id, 0) + + def get_num_used_load_staging_blocks(self, req_id: ReqId) -> int: + return self._blocks_for_load.get(req_id, 0) + + def allocate(self, req_id: ReqId, num_blocks: int, + usage: Literal["load", "save"]) -> int: + if num_blocks < 0: + logger.warning( + f" get {num_blocks} staging blocks to allocate for Req:{req_id}." + ) + return num_blocks + if num_blocks > self._num_free_blocks: + # do not have enough capacity, return 0 + return 0 + + if usage == "load": + if req_id in self._blocks_for_load: + # NOTE(jcgu): before completing the previous load, new load + # should not be triggered for the same request (is this correct?) 
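+                # For context, allocate() reports capacity through its return
+                # value: it returns num_blocks when the reservation fits and 0
+                # when the free pool is too small; only this duplicate-load
+                # case raises.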
+ raise ValueError( + f" Req({req_id}) already has {self._blocks_for_load[req_id]}, and should not have new loads." + ) + else: + self._blocks_for_load[req_id] = num_blocks + self._num_blocks_for_load += num_blocks + elif usage == "save": + if req_id in self._blocks_for_save: + self._blocks_for_save[req_id] += num_blocks + else: + self._blocks_for_save[req_id] = num_blocks + self._num_blocks_for_save += num_blocks + else: + raise ValueError( + f" Staging buffer manager should not get usage: {usage}") + self._num_free_blocks -= num_blocks + + logger.info( + f" allocate {num_blocks} staging blocks to Req:{req_id} for {usage}." + ) + return num_blocks + + def free(self, + req_id: ReqId, + usage: Literal["load", "save"], + num_finished_blocks: Optional[int] = None) -> int: + """ + when num_finished_blocks is not given, we will assume the request is finished and should be removed. + """ + num_freed_blocks = 0 + # NOTE(jcgu): assuming FIFO execution order for a single request's save and + # load operations respectively + if usage == "load": + if req_id not in self._blocks_for_load: + logger.warning( + f" there is no record of staging buffer (usage: {usage}) for Req:{req_id}" + ) + return 0 + if num_finished_blocks is None: + num_freed_blocks = self._blocks_for_load[req_id] + else: + num_freed_blocks = num_finished_blocks + if self._blocks_for_load[req_id] < num_freed_blocks: + logger.warning( + f" Req({req_id}) has {num_finished_blocks} load staging buffer to free, but only has {self._blocks_for_load[req_id]} on record." + ) + + self._blocks_for_load[req_id] -= num_freed_blocks + if self._blocks_for_load[req_id] <= 0: + del self._blocks_for_load[req_id] + self._num_blocks_for_load -= num_freed_blocks + elif usage == "save": + if req_id not in self._blocks_for_save: + logger.warning( + f" there is no record of staging buffer (usage: {usage}) for Req:{req_id}" + ) + return 0 + if num_finished_blocks is None: + num_freed_blocks = self._blocks_for_save[req_id] + else: + num_freed_blocks = num_finished_blocks + if self._blocks_for_save[req_id] < num_freed_blocks: + logger.warning( + f" Req({req_id}) has {num_finished_blocks} save staging buffer to free, but only has {self._blocks_for_save[req_id]} on record." + ) + + self._blocks_for_save[req_id] -= num_freed_blocks + if self._blocks_for_save[req_id] <= 0: + del self._blocks_for_save[req_id] + self._num_blocks_for_save -= num_freed_blocks + else: + raise ValueError( + f" Staging buffer manager should not get usage: {usage}") + self._num_free_blocks += num_freed_blocks + + logger.info( + f" free {num_freed_blocks} staging blocks (usage: {usage}) from Req:{req_id}" + ) + return num_freed_blocks + + def get_usage(self, with_details: bool = False): + usage_str = (f"Staging Buffer: total={self.num_blocks}, " + f"free={self._num_free_blocks}, " + f"used_for_load={self._num_blocks_for_load}, " + f"used_for_save={self._num_blocks_for_save};") + if with_details: + blocks_for_save_str = " save_details:{" + for req, bn in self._blocks_for_save.items(): + blocks_for_save_str += f"{req}:{bn}," + blocks_for_save_str += "} " + + blocks_for_load_str = " load_details:{" + for req, bn in self._blocks_for_load.items(): + blocks_for_load_str += f"{req}:{bn}," + blocks_for_load_str += "}." 
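+            # Illustrative output of the detailed form (a single line in
+            # practice):
+            #   "Staging Buffer: total=50, free=10, used_for_load=20,
+            #    used_for_save=20; save_details:{req2:20,} load_details:{req3:20,}."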
+ usage_str += blocks_for_save_str + blocks_for_load_str + + return usage_str diff --git a/tpu_inference/distributed/offload/tpu_offload_connector.py b/tpu_inference/distributed/offload/tpu_offload_connector.py new file mode 100644 index 000000000..52be57ee7 --- /dev/null +++ b/tpu_inference/distributed/offload/tpu_offload_connector.py @@ -0,0 +1,1928 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Scheduler side execution: +TPUOffloadConnectorScheduler manages the state of KV cache loading and saving for +each request. It acts as a state machine, tracking the progress of requests +across multiple scheduling steps and generating work orders (TPUReqMeta) for +the TPUOffloadConnectorWorker. + +Core Components: +- RequestTracker: The primary state object for a request. It tracks the + cumulative tokens and blocks processed, and how many of those tokens have + been saved to the CPU cache. A tracker is created when a request is first + scheduled and lives until the request is finished. + +- LoadSpec: A temporary state object created when a new request has a prefix + that matches data in the CPU cache (`get_num_new_matched_tokens`). It + holds the number of matched tokens and a `can_load` flag, which is set + to True only after the vLLM scheduler allocates the necessary blocks for + the load (`update_state_after_alloc`). + +- SaveSpec: A part of the work order sent to the worker. It instructs the + worker to save a specific slice of the KV cache from TPU to CPU. It + contains `num_skip_leading_tokens` to indicate which part of the request's + KV cache is new and needs saving, and an `is_final_save` flag to signal + the last save operation for a request. + +- TPUReqMeta: The unified work order for a single request in a single step, + sent from the scheduler to the worker. It can contain a `load_spec` (to + load from CPU to TPU), a `save_spec` (to save from TPU to CPU), or both. + +State Machine Flow (from the perspective of a request): + +1. RECEIVED -> AWAITING_ALLOCATION + - A new request arrives. + - `get_num_new_matched_tokens` checks the CPU backend for a matching + token prefix. + - If a match is found (N > 0 tokens), a `LoadSpec(num_matched_tokens=N, can_load=False)` + is created. The request now waits for the vLLM scheduler to allocate + physical blocks for these N tokens. + +2. AWAITING_ALLOCATION -> SCHEDULED + - The vLLM scheduler allocates blocks for the request. + - `update_state_after_alloc` is called. If a `LoadSpec` exists, its + `can_load` flag is set to True, greenlighting the load operation. + The request is now considered scheduled for processing in this step. + +3. SCHEDULED -> IN_FLIGHT or COMPLETED + - This transition is handled by `build_connector_meta` which calls the + central decision-making function, `_prepare_req_meta`. + - LoadSpec Preparation: The `LoadSpec` (if it exists and `can_load` + is True) is passed directly into the `TPUReqMeta`. The worker will + use `num_matched_tokens` to slice the correct prefix from the request's + `token_ids` and fetch the corresponding data from the CPU cache. + - SaveSpec Preparation: `_prepare_req_meta` determines if a save is + needed by comparing the total tokens processed so far + (`len(tracker.token_ids)`) with the number of tokens already saved + (`tracker.num_saved_tokens`). + - If `len(token_ids) > num_saved_tokens`, a `SaveSpec` is created. + - `num_skip_leading_tokens` is set to `tracker.num_saved_tokens`. 
This + tells the worker to ignore the prefix that's already in the CPU + cache and only save the new data. + - The scheduler then *transactionally* updates `tracker.num_saved_tokens` + to the new total length, ensuring this slice of data is not saved + again. + - If the scheduler has not finished the request, it transitions to + IN_FLIGHT. Its tracker is updated for the next scheduling step. + - If the scheduler has finished the request, it transitions to + COMPLETED. The tracker is removed, and a final `SaveSpec` is + generated. + - is_final_save: This flag is set to `True` only when the + scheduler marks a request as finished. It is a signal + for the worker, indicating that after this save is complete, the + request's lifecycle is over and its resources + can be safely freed. + +Worker Side Execution: +- The TPUOffloadConnectorWorker receives the `TPUOffloadConnectorMetadata` containing the list of + `TPUReqMeta` objects. +- `start_load_kv`: Iterates through the metadata. If a `meta.load_spec` + exists, it reads the corresponding data from the CPU backend and copies it + into the allocated blocks on the TPU. This is a blocking operation. +- `wait_for_save`: Iterates through the metadata. If a `meta.save_spec` + exists, it submits an asynchronous task to copy the specified slice of + KV data from TPU to CPU and update the CPU backend. It then waits for all + submitted save tasks for the current step to complete. +""" +import copy +import os +import time +from collections import defaultdict +from concurrent.futures import Future, ThreadPoolExecutor +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Literal, Optional, get_args + +import jax +import jax.numpy as jnp +from jax.sharding import Mesh +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.distributed.kv_transfer.kv_connector.v1.metrics import \ + KVConnectorStats +from vllm.utils.math_utils import cdiv +from vllm.v1.core.kv_cache_utils import BlockHash +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig +from vllm.v1.outputs import KVConnectorOutput + +if TYPE_CHECKING: + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.request import Request + from vllm.forward_context import ForwardContext + +from tpu_inference.distributed.offload.cpu_backend import LocalCPUBackend +from tpu_inference.distributed.offload.offload_manager import ( + LRUCacheManager, StagingBufferManager) +from tpu_inference.distributed.offload.utils import ( + CPU_OFFLOADING_SWAP_OP_TYPE, CpuChunkId, KVCacheSwapFn, ReqId, + TokenProcessor, get_default_kv_connector_staging_buffer_tokens, + get_kv_cache_swap_fn, jitted_insert_kv_cache_slices) +from tpu_inference.logger import init_logger +from tpu_inference.runner.kv_cache_manager import KVCacheManager +from tpu_inference.runner.tpu_runner import TPUModelRunner + +logger = init_logger(__name__) + +# kv cache layout needed by cpu offloading mechanism +REQUIRED_KV_CACHE_LAYOUT = "NHD" + +# default swap op type +DEFAULT_HOST_HBM_SWAP_OP_TYPE = "jax" + +BLOCK_SIZE_BUCKETS = [1, 2, 4, 8, 16] + +# we keep our operations at vllm's block granularity, +# and want to provide the following three preferences when handling +# the last partial block during save: +# 1. [supported] drop: drop the entire partial block +# 2. pad: pad to a full block +# 3. dynamic: keep the partial block as is. 
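+# Illustrative example with block_size=64 and 70 processed tokens:
+#   drop    -> save the single full block (64 tokens), discard the last 6
+#   pad     -> pad the trailing 6 tokens out to a full block, save 2 blocks
+#   dynamic -> save all 70 tokens, the last chunk holding a 6-token partial
+# Only "drop" is currently supported.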
+PARTIAL_BLOCK_SAVE_BEHAVIOR = Literal["drop", "pad", "dynamic"] + +DEFAULT_TPU_OFFLOAD_CPU_CHUNKS = 1024 + + +@dataclass +class SaveSpec: + """A confirmed work order for the worker to save KV data.""" + num_skip_leading_tokens: int + # total processed tokens for matching / saving + num_total_tokens: int + src_blocks: list[int] + dst_chunks: list[int] + # final save for the (newly) finished request + is_final_save: bool = False + # A direct signal to the worker to skip the data transfer but still + # process the completion signal if is_final_save is True. + skip_save: bool = False + + +@dataclass +class LoadSpec: + """Internal scheduler state for a potential load operation.""" + num_matched_tokens: int + src_chunks: list[int] + dst_blocks: list[int] + can_load: bool = False + num_skip_leading_tokens: int = 0 + + +@dataclass +class TPUReqMeta: + """A unified work order for a single request in a single step.""" + # The unique identifier for the request. + req_id: str + # For a load operation, this contains the prefix of tokens to be loaded + # from the cache. For a save operation, this contains the new tokens + # that have just been computed. + token_ids: list[int] + # The full list of physical blocks corresponding to the `token_ids`. + local_block_ids: list[int] + # An optional `SaveSpec` object. If present, it instructs the worker to + # perform a save operation. + save_spec: Optional[SaveSpec] = None + # An optional `LoadSpec` object. If present, it instructs the worker to + # perform a load operation. + load_spec: Optional[LoadSpec] = None + + def __repr__(self) -> str: + load_info = f"load_spec_exists={self.load_spec is not None}" + if self.load_spec: + load_info += ( + f", num_matched_tokens={self.load_spec.num_matched_tokens}, " + f"can_load={self.load_spec.can_load}, " + f"num_skip_leading_tokens={self.load_spec.num_skip_leading_tokens}, " + f"src_chunks={self.load_spec.src_chunks}, " + f"dst_blocks={self.load_spec.dst_blocks}") + save_info = f"save_spec_exists={self.save_spec is not None}" + if self.save_spec: + save_info += ( + f", num_skip_leading_tokens={self.save_spec.num_skip_leading_tokens}, " + f"num_total_tokens={self.save_spec.num_total_tokens}, " + f"is_final_save={self.save_spec.is_final_save}, " + f"skip_save={self.save_spec.skip_save}, " + f"dst_chunks={self.save_spec.dst_chunks}, " + f"src_blocks={self.save_spec.src_blocks}") + + return (f"TPUReqMeta(req_id={self.req_id}, " + f"num_token_ids={len(self.token_ids)}, " + f"num_local_block_ids={len(self.local_block_ids)}, " + f"{load_info}, {save_info})") + + +@dataclass +class RequestTracker: + """Tracks the evolving state of a single request across multiple scheduling steps.""" + # The unique identifier for the request. + req_id: str + # The total number of tokens in the original prompt. + prompt_len: int + # The full, cumulative list of physical block numbers allocated to this + # request so far. + block_ids: list[int] + # The full, cumulative list of token IDs that have been processed for this + # request so far. This list only contains the + # tokens to be computed, not the prefix loaded from cache. + token_ids: list[int] + # The number of tokens that were a hit in the CPU cache at the beginning + # of the request. This is constant for the lifetime of the request. + num_external_hits: int = 0 + # A high-water mark indicating how many tokens from the start of the + # computed tokens (`token_ids`) have already been saved to the CPU cache. 
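+    # This field corresponds to what the module docstring calls
+    # `tracker.num_saved_tokens`. Illustrative example with block_size=64:
+    # once a save has covered tokens [0, 128), save_watermark == 128 and the
+    # next SaveSpec is built with num_skip_leading_tokens=128, so only newer
+    # tokens are offloaded.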
+ save_watermark: int = 0 + # Whether the request is in the decoding phase (generating one token at a time). + is_decode_phase: bool = False + + def update(self, new_block_ids: list[int], new_token_ids: list[int]): + """Appends new block IDs and token IDs to the tracker.""" + if new_block_ids is None: + new_block_ids = [] + elif len(new_block_ids) == 0: + new_block_ids = [] + elif isinstance(new_block_ids, tuple): + new_block_ids = new_block_ids[0] + elif isinstance(new_block_ids, list): + pass + else: + raise ValueError( + f"Unsupported new_block_ids type {type(new_block_ids)}") + self.block_ids.extend(new_block_ids) + self.token_ids.extend(new_token_ids) + + # NOTE(jcgu): is it always true? will MTP affect this judegment? + # When a request is scheduled again, and the number of new tokens + # is 1 (excluding chunked prefill), the request is in decode phase. + if len(new_token_ids) == 1: + self.is_decode_phase = True + + def __repr__(self) -> str: + output_str = " - RequestTracker: " + \ + f"req_id={self.req_id}, " + \ + f"prompt_len={self.prompt_len}, " + \ + f"num_tokens={len(self.token_ids)}, " + \ + f"num_blocks={len(self.block_ids)}, " + \ + f"save_watermark={self.save_watermark}" + return output_str + + +@dataclass +class KVOffloadConnectorStats(KVConnectorStats): + """Container for transfer performance metrics""" + + def __post_init__(self): + if not self.data: + # Empty container init, no data is passed in. + self.reset() + + def reset(self): + # Must be serializable + self.data: dict[str, dict[str, list[int]]] = { + "finished_save_chunks": dict(), + "finished_load_chunks": dict(), + } + + def record_save(self, req: ReqId, saved_chunk_ids: list[int]): + if req not in self.data["finished_save_chunks"]: + self.data["finished_save_chunks"][req] = [] + self.data["finished_save_chunks"][req].extend( + copy.deepcopy(saved_chunk_ids)) + + def record_load(self, req: ReqId, loaded_chunk_ids: list[int]): + if req not in self.data["finished_load_chunks"]: + self.data["finished_load_chunks"][req] = [] + self.data["finished_load_chunks"][req].extend( + copy.deepcopy(loaded_chunk_ids)) + + def clone_and_reset(self) -> "KVOffloadConnectorStats": + old = copy.copy(self) + self.reset() + return old + + def is_empty(self) -> bool: + return self.num_finished_blocks == 0 + + def aggregate(self, other: KVConnectorStats) -> KVConnectorStats: + return self + + def reduce(self) -> dict[str, int | float]: + # Compute compact representative stats suitable for CLI logging + if self.is_empty(): + return { + "Num finished save blocks ": 0, + "Num finished load blocks ": 0, + } + + finished_save_chunks = sum( + len(chunk_list) + for chunk_list in self.data["finished_save_chunks"].values()) + finished_load_chunks = sum( + len(chunk_list) + for chunk_list in self.data["finished_load_chunks"].values()) + + return { + "Num finished save chunks ": finished_save_chunks, + "Num finished load chunks": finished_load_chunks, + } + + @property + def num_finished_blocks(self) -> int: + return len(self.data["finished_save_chunks"]) + len( + self.data["finished_load_chunks"]) + + +# The metadata used for communicating between scheduler and worker connectors. 
+@dataclass +class TPUOffloadConnectorMetadata(KVConnectorMetadata): + requests_meta: list[TPUReqMeta] = field(default_factory=list) + + +class TPUOffloadConnector(KVConnectorBase_V1): + + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: KVCacheConfig | None = None, + ): + super().__init__(vllm_config, role, kv_cache_config) + logger.info("TPUOffloadConnector: Entering __init__") + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler = \ + TPUOffloadConnectorScheduler(vllm_config) + self.connector_worker = None + elif role == KVConnectorRole.WORKER: + self.connector_scheduler = None + # The worker needs a reference to the base connector to access + # the metadata object set by the engine. + self.connector_worker = TPUOffloadConnectorWorker( + vllm_config, self) + + ############################################################ + # Class Methods + ############################################################ + @classmethod + def get_required_kvcache_layout(cls, vllm_config: VllmConfig): + if vllm_config.model_config is None: + logger.warning_once("Unable to detect current VLLM config. " + "Fallback to default kv cache layout.") + return None + + # TODO(jcgu): test mla + use_mla = vllm_config.model_config.use_mla + if use_mla: + # which fallback to the default behavior. + return None + + logger.info_once( + "TPUOffloadConnector currently only supports %s KV cache layout.", + REQUIRED_KV_CACHE_LAYOUT) + return REQUIRED_KV_CACHE_LAYOUT + + @classmethod + def build_kv_connector_stats( + cls, + data: dict[str, dict[str, int]] | None = None + ) -> KVConnectorStats | None: + return (KVOffloadConnectorStats( + data=data) if data is not None else KVOffloadConnectorStats()) + + ############################################################ + # Scheduler Side Methods + ############################################################ + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> TPUOffloadConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: list[jax.Array]): + logger.info("TPUOffloadConnector: Entering register_kv_caches") + """ + We don't register kv_caches in connector, we call `register_runner` and + use runner.kv_caches directly instead because the ref of runner.kv_caches + would be reassigned during model forward. 
+ """ + pass + + def register_runner(self, runner: TPUModelRunner) -> None: + logger.info("TPUOffloadConnector: Entering register_runner") + assert self.connector_worker is not None + self.connector_worker.register_runner(runner) + + def start_load_kv(self, fwd_ctx: "ForwardContext") -> None: + """Starts loading the KV cache for the given requests.""" + assert self.connector_worker is not None + self.connector_worker.start_load_kv(fwd_ctx) + + def wait_for_layer_load(self, layer_name: str) -> None: + logger.info("TPUOffloadConnector: Entering wait_for_layer_load") + """TPU connector doesn't support layer wise load.""" + pass + + def save_kv_layer(self, **kwargs) -> None: + logger.info("TPUOffloadConnector: Entering save_kv_layer") + """TPU connector doesn't support layer wise save.""" + pass + + def wait_for_save(self): + assert isinstance(self._connector_metadata, + TPUOffloadConnectorMetadata) + self.connector_worker.wait_for_save() + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + assert self.connector_worker is not None + return self.connector_worker.get_finished() + + def update_connector_output(self, connector_output: KVConnectorOutput): + assert self.connector_scheduler is not None + self.connector_scheduler.update_connector_output(connector_output) + + def get_kv_connector_stats(self) -> KVConnectorStats | None: + if self.connector_worker is None: + return None + return self.connector_worker.get_kv_connector_stats() + + +class TPUOffloadConnectorScheduler(): + + def __init__(self, vllm_config: "VllmConfig"): + logger.info("TPUOffloadConnectorScheduler: Entering __init__") + self.vllm_config = vllm_config + self.block_size = vllm_config.cache_config.block_size + + # offloading manager + self.num_cpu_chunks = int( + os.getenv("TPU_OFFLOAD_NUM_CPU_CHUNKS", + str(DEFAULT_TPU_OFFLOAD_CPU_CHUNKS))) + self.offload_manager = LRUCacheManager( + num_cpu_chunks=self.num_cpu_chunks) + + self._request_trackers: dict[ReqId, RequestTracker] = {} + # This dictionary holds the full vLLM Request object for all requests + # that are currently in a running state (i.e., have been scheduled but + # are not yet finished). It's used to access the complete prompt token + # list when processing incremental updates for cached/running requests, + # as the scheduler output for these requests is minimal. + self._unfinished_requests: dict[ReqId, "Request"] = {} + self.load_specs: dict[ReqId, LoadSpec] = {} + + # {reqid: total_num_matched_tokens_in_cpu_backend} + self._external_cache_hits: dict[ReqId, int] = {} + + # request ID -> set(block hashes being saved/loaded) + self._reqs_being_saved = defaultdict[ReqId, set[CpuChunkId]](set) + self._reqs_being_loaded = defaultdict[ReqId, set[CpuChunkId]](set) + + model_name = self.vllm_config.model_config.model + self.token_processor = TokenProcessor(model_name=model_name, + chunk_size=self.block_size) + + self.decode_save = os.getenv("TPU_OFFLOAD_DECODE_SAVE", "0") == "1" + # NOTE(jcgu): currently, let's make chunk_size == block_size + # chunk_size == n * block_size lead to + # 1. multi-size chunks + # 2. 
complicated resize (split, concatenate) operations due to
+        #    real-chunk-size in save and load
+        self.cpu_chunk_size = self.block_size
+
+        # define partial_block saving behavior
+        self.partial_block_save_behavior: PARTIAL_BLOCK_SAVE_BEHAVIOR = \
+            os.getenv("TPU_OFFLOAD_PARTIAL_BLOCK_SAVE_BEHAVIOR", "drop")
+        assert self.partial_block_save_behavior in get_args(
+            PARTIAL_BLOCK_SAVE_BEHAVIOR
+        ), f"{self.partial_block_save_behavior} not in {get_args(PARTIAL_BLOCK_SAVE_BEHAVIOR)}"
+        self.partial_block_dynamic_pad_lower_limit = \
+            int(os.getenv("TPU_OFFLOAD_PARTIAL_BLOCK_DYNAMIC_PAD_LOWER_LIMIT", "0"))
+        if self.partial_block_save_behavior == "dynamic":
+            if self.partial_block_dynamic_pad_lower_limit <= 0:
+                self.partial_block_save_behavior = "drop"
+            elif self.partial_block_dynamic_pad_lower_limit >= self.block_size:
+                self.partial_block_save_behavior = "pad"
+            logger.info(
+                f" partial_block_save_behavior is configured to {self.partial_block_save_behavior}, but we only support drop now."
+            )
+            self.partial_block_save_behavior = "drop"
+
+        # config staging buffer
+        # NOTE(jcgu): Need to find a way to grab page_size_bytes in scheduler
+        # otherwise, we can only use # of tokens as input, instead of buffer size in GB
+        _default_staging_buffer_tokens = get_default_kv_connector_staging_buffer_tokens(
+        )
+        num_staging_buffer_tokens = int(
+            os.getenv("TPU_OFFLOAD_STAGING_BUFFER_TOKENS",
+                      str(_default_staging_buffer_tokens)))
+        self.num_staging_blocks = num_staging_buffer_tokens // self.block_size
+        self.staging_buffer_manager = StagingBufferManager(
+            num_blocks=self.num_staging_blocks)
+
+        logger.info(
+            f"TPUOffloadConnectorScheduler initialized with: "
+            f"block_size={self.block_size}, "
+            f"cpu_chunk_size={self.cpu_chunk_size}, "
+            f"num_cpu_chunks={self.num_cpu_chunks}, "
+            f"model_name={model_name}, "
+            f"decode_save={self.decode_save}, "
+            f"partial_block_save_behavior={self.partial_block_save_behavior}, "
+            f"partial_block_dynamic_pad_lower_limit={self.partial_block_dynamic_pad_lower_limit}, "
+            f"num_staging_blocks={self.num_staging_blocks}.")
+
+    def _get_request_block_hashes(self, req: "Request") -> list[BlockHash]:
+        # request's original block_hashes do not include the last partial block
+        # TODO(jcgu): switch back to token_processor
+        return req.block_hashes
+
+    def get_num_new_matched_tokens(
+        self,
+        request: "Request",
+        num_computed_tokens: int,
+    ) -> tuple[int, bool]:
+        """
+        Checks for an external KV cache hit against the local CPU backend.
+        """
+        assert num_computed_tokens % self.block_size == 0, f"{num_computed_tokens} % {self.block_size} != 0"
+        # get block_hash
+        block_hashes = self._get_request_block_hashes(request)
+        num_total_blocks = len(block_hashes)
+        prompt_token_ids = request.prompt_token_ids
+        logger.info(f"Request {request.request_id}: Checking for cache hit. "
+                    f"Prompt length: {len(prompt_token_ids)}, "
+                    f"Block_hashes ({num_total_blocks}),"
+                    f"Already computed tokens: {num_computed_tokens}. 
") + + # look for blocks in the cache + num_hits = self.offload_manager.lookup(block_hashes) + matched_block_hashes = block_hashes[:num_hits] + self.offload_manager.touch(block_hashes) + num_matched_blocks = len(matched_block_hashes) + num_matched_tokens = min(num_matched_blocks * self.block_size, + len(prompt_token_ids)) + num_computed_blocks = num_computed_tokens // self.block_size + num_blocks_to_load = num_matched_blocks - num_computed_blocks + logger.info( + f"Request {request.request_id}: Found {num_matched_tokens} (out of {len(prompt_token_ids)} prompt tokens) matched tokens ({num_matched_blocks} blocks) in CPU backend (computed_blocks: {num_computed_blocks}, blocks_to_load: {num_blocks_to_load})." + ) + + if num_blocks_to_load > 0: + # planning staging blocks for load + # NOTE(jcgu): do not worry about the inconsistency of the staging buffer status; + # there is only one connector scheduler who is operating on it. + num_avail_staging_blocks = self.staging_buffer_manager.get_num_free_staging_blocks( + ) + if num_blocks_to_load > num_avail_staging_blocks: + # reduce blocks_to_load (and matched tokens) when there are insufficient staging blocks. + logger.info( + f" Req({request.request_id}) found {num_matched_blocks} blocks ({num_matched_tokens} tokens), but only {num_avail_staging_blocks} staging blocks available." + ) + num_blocks_to_load = num_avail_staging_blocks + num_matched_tokens = (num_blocks_to_load + + num_computed_blocks) * self.block_size + + # still have something to load + if num_blocks_to_load > 0: + # get the src chunk ids to load + block_hashes_to_load = block_hashes[num_computed_blocks:( + num_computed_blocks + num_blocks_to_load)] + chunks_to_load = self.offload_manager.prepare_load( + block_hashes_to_load) + src_chunk_ids = [chunk.chunk_id for chunk in chunks_to_load] + + # NOTE(jcgu): fill real dst_blocks later when blocks get allocated. + dummy_dst_blocks = [-1] * num_blocks_to_load + self.load_specs[request.request_id] = LoadSpec( + num_matched_tokens=num_matched_tokens, + src_chunks=src_chunk_ids, + dst_blocks=dummy_dst_blocks, + num_skip_leading_tokens=num_computed_tokens, + ) + num_allocated_blocks = self.staging_buffer_manager.allocate( + request.request_id, + num_blocks=num_blocks_to_load, + usage="load") + assert num_allocated_blocks == num_blocks_to_load >= 0, f" failed to allocate {num_allocated_blocks} (load) staging blocks for request {request.request_id}, expected {num_blocks_to_load}." + + # record the matched tokens in the cache, it will be needed in + # init save_spec + self._external_cache_hits[ + request.request_id] = num_matched_tokens + + is_full_prefix_hit = (num_matched_tokens > 0 + and num_matched_tokens == len(prompt_token_ids)) + num_matched_for_scheduler = num_matched_tokens + if is_full_prefix_hit: + # When the entire prompt is found in the CPU cache (a "full hit"), + # report N-1 matched tokens to the vLLM scheduler instead + # of the true N. If we report a 100% match (N + # matched tokens for a prompt of length N), the scheduler sees + # zero new tokens and may not schedule the request for a prefill + # step at all and hits + # https://github.com/vllm-project/vllm/blob/b8b302cde434df8c9289a2b465406b47ebab1c2d/vllm/v1/core/sched/scheduler.py#L438 assetion. + # By reporting N-1, we ensure the scheduler allocates resources + # for and schedules the computation of the "last" token of the + # prompt. 
The worker (`start_load_kv`) still loads the KV of the N
+            # matched tokens, but the final token's KV will not be used; it is
+            # "re-computed" in the following forward pass (the loaded data in
+            # the slot gets overridden). From there, the request can
+            # seamlessly transition to the decoding phase.
+            num_matched_for_scheduler = num_matched_tokens - 1
+            logger.info(
+                f"Request {request.request_id}: Full prompt hit. Reporting {num_matched_for_scheduler} matched tokens. Actual hit from backend is {num_matched_tokens} tokens"
+            )
+
+        # Note on unpinning for the full prefix hit case: Although we report N-1 tokens
+        # to the scheduler, the RequestTracker (created later in
+        # `build_connector_meta`) stores the true, full N prompt tokens.
+        # The `get_finished` method on the worker side uses this complete
+        # token list to regenerate the keys, ensuring that all N keys
+        # originally pinned during this lookup are gracefully unpinned upon
+        # request completion.
+        # We don't need to load tokens that are already computed locally in vLLM.
+        num_to_load = max(0, num_matched_for_scheduler - num_computed_tokens)
+        logger.info(
+            f"Request {request.request_id}: After accounting for {num_computed_tokens} computed tokens, reporting {num_to_load} tokens to load."
+        )
+
+        # external_computed_tokens, load_kv_async
+        return num_to_load, False
+
+    def _adjust_last_partial_block(self,
+                                   last_partial_block_num_tokens: int) -> bool:
+        """
+        Adjust how many tokens get saved based on the pre-configured save
+        behavior when the request's last block is only partially filled.
+        To keep all saved KV aligned with block_size, we may
+        1. drop the partial block
+        2. pad the partial block to be a full block
+        3. drop or pad based on the actual num_tokens in the last partial block
+
+        Input: num of tokens in the last partial block (could be 0)
+        Output: whether the last partial block should be kept (True) or
+            dropped (False)
+        """
+        if self.partial_block_save_behavior == "pad":
+            return last_partial_block_num_tokens > 0
+        elif self.partial_block_save_behavior == "drop":
+            return False
+        elif self.partial_block_save_behavior == "dynamic":
+            return (last_partial_block_num_tokens >=
+                    self.partial_block_dynamic_pad_lower_limit)
+
+    def update_state_after_alloc(self, request: "Request",
+                                 blocks: "KVCacheBlocks",
+                                 num_external_tokens: int):
+        """
+        This hook is not used for the save logic.
+ Update the dst_blocks in the load_spec + """ + logger.info( + f"TPUOffloadConnectorScheduler: Entering update_state_after_alloc Request {request.request_id}: Scheduler allocated " + f"{num_external_tokens} external tokens.") + self._unfinished_requests[request.request_id] = request + if num_external_tokens == 0: + return + if request.request_id in self.load_specs: + block_hashes = self._get_request_block_hashes(request) + all_blocks = blocks.get_block_ids()[0] + logger.info( + f" Request: {request.request_id} has {len(all_blocks)} blocks / {len(block_hashes)} block hashes.)" + ) + load_spec = self.load_specs[request.request_id] + assert load_spec.num_skip_leading_tokens % self.block_size == 0 + skip_leading_blocks = load_spec.num_skip_leading_tokens // self.block_size + + total_matched_blocks = len( + load_spec.dst_blocks) + skip_leading_blocks + assert total_matched_blocks == cdiv( + load_spec.num_matched_tokens, self.block_size + ), f"{total_matched_blocks} != {load_spec.num_matched_tokens}" + dst_blocks = all_blocks[skip_leading_blocks:total_matched_blocks] + load_spec.dst_blocks = dst_blocks + load_spec.can_load = True + self._reqs_being_loaded[request.request_id] |= set( + load_spec.src_chunks) + logger.info( + f"Request {request.request_id} ({len(dst_blocks)} dst_blocks) is ready to load." + ) + + def _prepare_req_meta( + self, + tracker: RequestTracker, + load_spec: Optional[LoadSpec], + is_finished: bool, + ) -> Optional[TPUReqMeta]: + """ + Central decision-making function. Determines if a save or load is + needed and prepares the metadata. Also performs the transactional + update of the tracker's save state. + """ + req_id = tracker.req_id + _request = self._unfinished_requests[req_id] + block_hashes = self._get_request_block_hashes(_request) + self.offload_manager.touch(block_hashes) + + # only consider the tokens covered by block_hashes + num_total_blocks = len(block_hashes) + num_total_tokens = min(num_total_blocks * self.block_size, + len(tracker.token_ids)) + num_full_blocks = num_total_tokens // self.block_size + num_full_blocks_tokens = num_full_blocks * self.block_size + # adjust last partial block + last_partial_block_num_tokens = num_total_tokens - num_full_blocks_tokens + need_last_block = self._adjust_last_partial_block( + last_partial_block_num_tokens) + adjusted_num_total_tokens = num_total_tokens if need_last_block else num_full_blocks_tokens + adjusted_num_total_blocks = num_full_blocks + (1 if need_last_block + else 0) + assert adjusted_num_total_blocks <= len(tracker.block_ids) + + has_new_tokens = adjusted_num_total_tokens > tracker.save_watermark + should_save = False + # Determine if a save is needed for this step + # when there are new token KVs (adjusted by saving behavior): + # 1. Prefill: always save + # 2. 
Decode (with save_decode=True) + # 2.1 regular decode (not finished): accumulate until getting a full block + # 2.2 request finished: save + if has_new_tokens: + if not tracker.is_decode_phase: + # Prefill: always save the new-computed blocks + should_save = True + elif self.decode_save: + if is_finished: + # After decode, if there are new final new tokens to save + should_save = True + else: + # During decode, we do not drop or pad, just accumulate tokens until the next block boundary + next_block_boundary = ( + tracker.save_watermark // self.block_size + + 1) * self.block_size + logger.info( + f"in decode phase, next_block_boundary: {next_block_boundary}, " + ) + if num_total_tokens == next_block_boundary: + # always save the full block for decode (not affected by saving_behavior) + assert num_total_tokens == adjusted_num_total_tokens, f" decode_save: {num_total_tokens} != (adjusted) {adjusted_num_total_tokens}" + should_save = True + + logger.info(f" - Preparing meta for req (save): {tracker.req_id}, " + f"is_finished={is_finished}, " + f"total_tokens={num_total_tokens}, " + f"adjusted_num_total_tokens={adjusted_num_total_tokens}, " + f"adjusted_num_total_blocks={adjusted_num_total_blocks}, " + f"saved_tokens={tracker.save_watermark}, " + f"has_new={has_new_tokens}, " + f"is_decode={tracker.is_decode_phase}, " + f"should_save={should_save}") + + # A SaveSpec is always prepared for a finished request to signal completion, + # even if we don't save the underlying KV data. This is to ensure the TPUOffloadConnectorWorker + # can correctly report finished request. + save_spec = None + if should_save: + # get src block_ids for save + # NOTE(jcgu): recompute skip_leading_blocks + # if tracker.save_watermark has partial tokens in the last block + # and we saved (i.e., pad) the entire block to cpu_backend, now we + # want to save the kv of the new tokens in that block; because of + # the new tokens in that block's token sequence, the block will + # have a new key (hash value) in cpu_backend, so we should treat + # the block as a new cache and save the entire block. + # Example: + # we have saved: + # blocks: [------b0------] [------b1------] + # tokens: [t0, t1, t2, t3] [t4, t5,] + # cpu-backend:{key0: b0, key1:b1(2 tokens, padded)} + # + # Now, we have 2 new tokens in the sequence + # blocks: [------b0------] [------b1------] + # tokens: [t0, t1, t2, t3] [t4, t5, t6, t7] + # cpu-backend:{key0: b0, key1:b1(2 tokens, padded), + # key1_2: b1_2(4 tokens)} + # In cpu-backend, since b1's token-sequence has been changed, it + # will have a new key. + # + # if we always drop the partial-filled block when saving, then there + # will no such an issue. + num_skip_leading_blocks = tracker.save_watermark // self.block_size + num_skip_leading_tokens = num_skip_leading_blocks * self.block_size + num_blocks_to_save = adjusted_num_total_blocks - num_skip_leading_blocks + + # planning staging blocks for save + num_avail_staging_blocks = self.staging_buffer_manager.get_num_free_staging_blocks( + ) + if num_blocks_to_save > num_avail_staging_blocks: + # reduce blocks_to_save due to limited free staging blocks + logger.info( + f" Req({tracker.req_id}) have {num_blocks_to_save} ({adjusted_num_total_blocks} - {num_skip_leading_blocks}) blocks to save, but only {num_avail_staging_blocks} staging blocks available." 
+ ) + num_blocks_to_save = num_avail_staging_blocks + adjusted_num_total_blocks = num_skip_leading_blocks + num_blocks_to_save + adjusted_num_total_tokens = adjusted_num_total_blocks * self.block_size + + if num_blocks_to_save > 0: + block_hashes_to_save = block_hashes[ + num_skip_leading_blocks:adjusted_num_total_blocks] + allocate_output = self.offload_manager.allocate_for_save( + block_hashes_to_save) + if allocate_output is not None: + # there are enough chunks to save + chunks_for_save, chunk_idxs = allocate_output + assert num_blocks_to_save == len(chunks_for_save) + src_block_ids = tracker.block_ids[ + num_skip_leading_blocks:adjusted_num_total_blocks] + + dst_chunks = [chunk.chunk_id for chunk in chunks_for_save] + src_blocks = [src_block_ids[idx] for idx in chunk_idxs] + + # This is a real save operation. + save_spec = SaveSpec( + num_skip_leading_tokens=num_skip_leading_tokens, + num_total_tokens=adjusted_num_total_tokens, + is_final_save=is_finished, + skip_save=False, + src_blocks=src_blocks, + dst_chunks=dst_chunks, + ) + self._reqs_being_saved[req_id] |= set(dst_chunks) + num_allocated_blocks = self.staging_buffer_manager.allocate( + tracker.req_id, + num_blocks=num_blocks_to_save, + usage="save") + assert num_allocated_blocks == num_blocks_to_save >= 0, f" failed to allocate {num_allocated_blocks} (save) staging blocks for request {tracker.req_id}, expected {num_blocks_to_save}." + + if adjusted_num_total_tokens > tracker.save_watermark: + logger.info( + f" -> Old watermark {tracker.save_watermark}, new save_watermark count: {adjusted_num_total_tokens}" + ) + tracker.save_watermark = adjusted_num_total_tokens + + if is_finished and save_spec is None: + # For finished requests, there must be a no-op save to update the state in the worker side. + # This is a "completion-only" signal because should_save is False. + # NOTE(jcgu): num_total_tokens will be used to unpin tokens; + # apply the number of saved tokens; + # TODO(jcgu): rm the no-op save, since save status has been updated + # through kv_connector_output.kv_connector_stats + save_spec = SaveSpec( + num_skip_leading_tokens=tracker.save_watermark, + num_total_tokens=tracker.save_watermark, + src_blocks=[], + dst_chunks=[], + is_final_save=True, + skip_save=True, + ) + + # 2. Determine if a work order is needed. + if not save_spec and not (load_spec and load_spec.can_load): + return None + + # 3. Construct and return the final work order. + return TPUReqMeta( + req_id=tracker.req_id, + token_ids=tracker.token_ids, + local_block_ids=tracker.block_ids, + save_spec=save_spec, + load_spec=load_spec, + ) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput) -> TPUOffloadConnectorMetadata: + metadata = TPUOffloadConnectorMetadata() + + # Phase 1: Handle and clean up finished requests + # This block handles requests that have completed their generation. + # We pop their state from our tracking dictionaries and call _prepare_req_meta + # one last time. This ensures any final, unsaved tokens are captured and + # signals to the worker that this is the final save for the request. + logger.info( + f"Phase 1: Processing {len(scheduler_output.finished_req_ids)} finished requests." + ) + for finished_req_id in scheduler_output.finished_req_ids: + logger.info(f" - Processing finished req: {finished_req_id}") + tracker = self._request_trackers[finished_req_id] + + if not tracker: + logger.warning( + f" - No tracker found for finished req: {finished_req_id}. Skipping." 
+ ) + continue + + # Prepare one final metadata object if there's a final save needed. + # `is_finished` is set to True to flag this as the last save operation. + req_meta = self._prepare_req_meta(tracker, + load_spec=None, + is_finished=True) + if req_meta: + logger.info( + f" - Creating final save metadata for req: {finished_req_id}" + ) + metadata.requests_meta.append(req_meta) + + # Pop tracker and other state first. + self._request_trackers.pop(finished_req_id, None) + self._unfinished_requests.pop(finished_req_id, None) + self.load_specs.pop(finished_req_id, None) + + # Phase 2: Process newly scheduled requests + # This block handles requests being scheduled for the very first time. + # It creates the initial RequestTracker and prepares the first work order. + logger.info( + f"Phase 2: Processing {len(scheduler_output.scheduled_new_reqs)} new requests." + ) + for request in scheduler_output.scheduled_new_reqs: + req_id = request.req_id + + _request = self._unfinished_requests[req_id] + logger.info( + f" - Processing new req: {req_id}, {len(_request.block_hashes)} block_hashes." + ) + num_new_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + + # Get the external cache hit count from our new, reliable source. + num_external_hits = self._external_cache_hits.pop(req_id, 0) + + # Determine the total length of tokens the tracker should hold. + # This is vLLM's already computed tokens + newly scheduled tokens. + num_total_tokens_for_tracker = request.num_computed_tokens + num_new_scheduled_tokens + tokens_for_tracker = request.prompt_token_ids[: + num_total_tokens_for_tracker] + logger.info( + f" - num_new_scheduled_tokens: {num_new_scheduled_tokens}, num_vllm_computed: {request.num_computed_tokens}, num_external_hits: {num_external_hits}" + ) + logger.info( + f" - Slicing prompt[:{num_total_tokens_for_tracker}] -> len(tokens_for_tracker): {len(tokens_for_tracker)}" + ) + + # Set the initial high-water mark for `save_watermark`. + # This is the maximum of what vLLM has computed and what's in our external cache. + initial_save_watermark = max(request.num_computed_tokens, + num_external_hits) + + # Create and store the tracker, which will maintain the request's + # state for its entire lifetime. + assert req_id not in self._request_trackers, f"Request {req_id} already has a tracker." + # TODO(jcgu): reduce duplicated info in request tracker + tracker = RequestTracker( + req_id=req_id, + prompt_len=len(request.prompt_token_ids), + block_ids=copy.deepcopy(request.block_ids[0]), + token_ids=tokens_for_tracker, + num_external_hits=num_external_hits, + # The high-water mark for saved tokens starts after the cached prefix. + save_watermark=initial_save_watermark, + ) + self._request_trackers[req_id] = tracker + logger.info( + f" - Created tracker for {req_id} with initial state: {tracker}" + ) + + # Immediately prepare metadata for this new request. This could include + # both a load operation (for the cached part) and a save operation + # (for the newly computed part). 
+ load_spec = self.load_specs.get(req_id) + req_meta = self._prepare_req_meta(tracker, + load_spec, + is_finished=False) + if req_meta: + logger.info(f" - Creating metadata for new req: {req_id} " + f"(has_load={req_meta.load_spec is not None}, " + f"has_save={req_meta.save_spec is not None})") + metadata.requests_meta.append(req_meta) + + # Phase 3: Process cached (running) requests + # This block handles requests that have already been pre-filled at least + # once and are now being processed again (e.g., for chunked prefill). + cached_reqs = scheduler_output.scheduled_cached_reqs + logger.info( + f"Phase 3: Processing {len(cached_reqs.req_ids)} cached requests.") + for i, req_id in enumerate(cached_reqs.req_ids): + tracker = self._request_trackers[req_id] + full_request = self._unfinished_requests.get(req_id) + _block_hashes = full_request.block_hashes + logger.info( + f" - Processing cached req: {req_id}, {len(_block_hashes)} block_hashes." + ) + + if full_request is None: + logger.warning( + f" - No full request found for cached req: {req_id}. Skipping." + ) + continue + + # num_new_tokens: The number of *additional* tokens the scheduler is + # processing in this step for this ongoing request. + num_new_tokens = scheduler_output.num_scheduled_tokens[req_id] + + # current_token_count: This is the crucial calculation to find our + # place in the full prompt. It's the length of the token prefix + # already processed in previous steps. + current_token_count = len(tracker.token_ids) + + logger.info( + f" - len(full_request.all_token_ids): {len(full_request.all_token_ids)}" + ) + # new_token_ids: The slice of the full token sequence corresponding to the + # new work being done in this step. + new_token_ids = full_request.all_token_ids[ + current_token_count:current_token_count + num_new_tokens] + + # new_blocks: The new physical blocks allocated for the new_token_ids. + new_blocks = cached_reqs.new_block_ids[i] + if new_blocks is None: + new_blocks = [] + + logger.info( + f" - num_new_tokens: {num_new_tokens}, current_token_count: {current_token_count}" + ) + logger.info( + f" - Slicing prompt -> len(new_token_ids): {len(new_token_ids)}" + ) + logger.info(f" - New blocks allocated: {len(new_blocks)}") + + # Update the tracker with the incremental data. + tracker.update(new_blocks, new_token_ids) + logger.info(f" - Updated tracker for {req_id}: " + f"total_tokens={len(tracker.token_ids)}, " + f"total_blocks={len(tracker.block_ids)}") + + # Immediately prepare metadata for this updated request. This will + # typically be a save operation for the new tokens. + req_meta = self._prepare_req_meta(tracker, + load_spec=None, + is_finished=False) + if req_meta: + logger.info( + f" - Creating metadata for cached req: {req_id} " + f"(has_save={req_meta.save_spec is not None})") + metadata.requests_meta.append(req_meta) + + if metadata.requests_meta: + logger.info( + f"Prepared {len(metadata.requests_meta)} requests for worker.") + return metadata + + def update_connector_output(self, connector_output: KVConnectorOutput): + """ + Update KVConnector state from worker-side connectors output. + + Args: + connector_output (KVConnectorOutput): the worker-side + connectors output. 
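+
+        The stats payload consumed here maps request ids to the chunk ids
+        that finished in this step, e.g. (illustrative values only):
+            {"finished_save_chunks": {"req-0": [3, 4]},
+             "finished_load_chunks": {"req-1": [7]}}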
+ """ + logger.info( + f"TPUOffloadConnectorScheduler: getting workers' output: finished_sending: {connector_output.finished_sending}, finished_recving: {connector_output.finished_recving}" + ) + + # per iteration, update the finished staging blocks + if connector_output.kv_connector_stats and connector_output.kv_connector_stats.data is not None: + assert isinstance(connector_output.kv_connector_stats, + KVOffloadConnectorStats) + assert "finished_save_chunks" in connector_output.kv_connector_stats.data + assert "finished_load_chunks" in connector_output.kv_connector_stats.data + for req_id, saved_chunk_ids in connector_output.kv_connector_stats.data[ + "finished_save_chunks"].items(): + num_saved_chunks = len(saved_chunk_ids) + logger.info( + f" finished_save_chunks for {req_id}: {saved_chunk_ids}") + # free staging blocks + self.staging_buffer_manager.free( + req_id, usage="save", num_finished_blocks=num_saved_chunks) + # update in-flight save + for saved_chunk_id in saved_chunk_ids: + assert saved_chunk_id in self._reqs_being_saved[req_id] + self._reqs_being_saved[req_id].remove(saved_chunk_id) + if len(self._reqs_being_saved[req_id]) == 0: + self._reqs_being_saved.pop(req_id, None) + # update the status of occupied cpu chunks + self.offload_manager.mark_completion(saved_chunk_ids, "save") + + for req_id, loaded_chunk_ids in connector_output.kv_connector_stats.data[ + "finished_load_chunks"].items(): + num_loaded_chunks = len(loaded_chunk_ids) + logger.info( + f" finished_load_chunks for {req_id}: {num_loaded_chunks}" + ) + self.staging_buffer_manager.free( + req_id, + usage="load", + num_finished_blocks=num_loaded_chunks) + # update in-flight save + for loaded_chunk_id in loaded_chunk_ids: + assert loaded_chunk_id in self._reqs_being_loaded[req_id] + self._reqs_being_loaded[req_id].remove(loaded_chunk_id) + if len(self._reqs_being_loaded[req_id]) == 0: + self._reqs_being_loaded.pop(req_id, None) + # update the status of occupied cpu chunks + self.offload_manager.mark_completion(loaded_chunk_ids, "load") + + # clean up the status of the finished requests + # save + for req_id in connector_output.finished_sending or []: + if req_id in self._reqs_being_saved: + assert len(self._reqs_being_saved[req_id]) == 0 + self._reqs_being_saved.pop(req_id) + num_freed_blocks = self.staging_buffer_manager.free(req_id, + usage="save") + logger.info( + f" freed {num_freed_blocks} staging blocks (save) from {req_id}" + ) + + # load + for req_id in connector_output.finished_recving or []: + if req_id in self._reqs_being_loaded: + assert len(self._reqs_being_loaded[req_id]) == 0 + self._reqs_being_loaded.pop(req_id) + num_freed_blocks = self.staging_buffer_manager.free(req_id, + usage="load") + logger.info( + f" freed {num_freed_blocks} staging blocks (load) from {req_id}" + ) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Called when a request has finished, before its blocks are freed. + + True if the request is being saved/sent asynchronously and blocks + should not be freed until the request_id is returned from + get_finished(). + Optional KVTransferParams to be included in the request outputs + returned by the engine. + return: + delay_free_blocks, kv_xfer_params + """ + logger.info("TPUOffloadConnectorScheduler: Entering request_finished") + # Return True to indicate the request is being saved asynchronously + # and its blocks should not be freed yet. 
+ + req_id = request.request_id + if req_id in self._reqs_being_saved and len( + self._reqs_being_saved[req_id]) > 0: + return True, None + if req_id in self._reqs_being_loaded and len( + self._reqs_being_loaded[req_id]) > 0: + return True, None + + logger.info( + f"TPUOffloadConnectorScheduler: finished request: {req_id}") + self._reqs_being_saved.pop(req_id, None) + self._reqs_being_loaded.pop(req_id, None) + + return False, None + + +class TPUOffloadConnectorWorker: + + def __init__(self, vllm_config: VllmConfig, + connector: "TPUOffloadConnector"): + logger.info("TPUOffloadConnectorWorker: Entering __init__") + self.vllm_config = vllm_config + self.connector = connector + self.block_size = vllm_config.cache_config.block_size + + self.runner: Optional[TPUModelRunner] = None + self.mesh: Optional[Mesh] = None + self.swap_op_type = os.getenv("TPU_OFFLOAD_SWAP_OP_TYPE", + default=DEFAULT_HOST_HBM_SWAP_OP_TYPE) + assert self.swap_op_type in get_args(CPU_OFFLOADING_SWAP_OP_TYPE) + # TODO(jcgu): check libtpu compatibility for pallas dma kernel + logger.info( + f"(cpu offloading) swap operation type is {self.swap_op_type}") + + self.use_bucketed_swap_ops = os.getenv( + "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE", "0") == "0" + logger.info( + f"(cpu offloading) use_bucketed_swap_ops={self.use_bucketed_swap_ops}" + ) + + self.swap_in_fn: KVCacheSwapFn = None + self.swap_out_fn: KVCacheSwapFn = None + + # cpu cache + self.num_cpu_chunks = int( + os.getenv("TPU_OFFLOAD_NUM_CPU_CHUNKS", + str(DEFAULT_TPU_OFFLOAD_CPU_CHUNKS))) + self.cpu_backend = LocalCPUBackend(num_cpu_chunks=self.num_cpu_chunks) + # The worker needs its own token processor to generate keys. + model_name = self.vllm_config.model_config.model + logger.info( + f"Model name is {model_name}, KV block_size={self.block_size}") + self.token_processor = TokenProcessor(model_name=model_name, + chunk_size=self.block_size) + + self.cpu_chunk_size = self.block_size + # Thread pool for asynchronous TPU->CPU copies + self.save_executor = ThreadPoolExecutor(max_workers=4, + thread_name_prefix="tpu_saver") + self.finished_save_reqs: set[ReqId] = set() + self.finished_load_reqs: set[ReqId] = set() + # Tracks if wait_for_save has been called for the current step's metadata. 
+ self._processed_save_for_step = False + + # record finished save / load blocks (with req_ids) for each iteration + self.offload_stats = KVOffloadConnectorStats() + + def __del__(self): + logger.info("TPUOffloadConnectorWorker: Entering __del__") + self.save_executor.shutdown(wait=True) + + def register_runner(self, runner: TPUModelRunner): + logger.info("TPUOffloadConnectorWorker: Entering register_runner") + self.runner = runner + self.devices = runner.devices + self.mesh = runner.mesh + # Get the spec of the kv_caches + kv_caches = runner.kv_caches + if kv_caches: + self.kv_cache_layout = runner.get_kv_cache_layout() + kv_layer = kv_caches[0] + self.num_layers = len(kv_caches) + self.shape = list(kv_layer.shape) + self.dtype = kv_layer.dtype + self.device_sharding = kv_layer.sharding + + # NOTE(jcgu): needed when sliced-kv is [num_tokens, num_head, head_dim] + self.flatten_device_sharding = jax.sharding.NamedSharding( + mesh=self.device_sharding.mesh, + spec=jax.sharding.PartitionSpec(None, "model"), + memory_kind="device") + + self.flatten_host_sharding = jax.sharding.NamedSharding( + mesh=self.device_sharding.mesh, + spec=jax.sharding.PartitionSpec(None, "model"), + memory_kind="pinned_host") + + self.swap_in_fn, self.swap_out_fn = get_kv_cache_swap_fn( + self.swap_op_type, + host_sharding=self.flatten_host_sharding, + device_sharding=self.flatten_device_sharding) + + logger.info( + "KV Cache details registered in TPUOffloadConnectorWorker:") + logger.info(f" - Num layers: {self.num_layers}") + logger.info(f" - Shape per layer: {self.shape}") + logger.info(f" - DType: {self.dtype}") + logger.info(f" - Device sharding: {self.device_sharding}") + logger.info( + f" - Flatten Device sharding: {self.flatten_device_sharding}") + logger.info(f" - Layout: {self.kv_cache_layout}") + else: + raise ValueError( + "TPUOffloadConnectorWorker registered with no KV caches.") + + # Pre-compile the JIT functions for KV cache swapping. + if self.use_bucketed_swap_ops: + self._precompile_kv_swap_operations() + + def _decompose_into_buckets(self, num_blocks: int) -> list[int]: + """ + Decomposes a number into a sum of numbers from the BLOCK_SIZE_BUCKETS + list using a greedy approach. + """ + sorted_buckets = sorted(BLOCK_SIZE_BUCKETS, reverse=True) + chunks = [] + remaining = num_blocks + while remaining > 0: + for bucket_size in sorted_buckets: + if remaining >= bucket_size: + chunks.append(bucket_size) + remaining -= bucket_size + break + else: + # This should not happen if 1 is in the buckets + raise ValueError( + "Could not decompose number with the given buckets.") + return chunks + + def _precompile_kv_swap_operations(self): + """ + Pre-compiles the JIT-compiled functions used for KV cache swapping + with a variety of common block sizes to avoid runtime recompilation. + """ + if os.getenv("TPU_OFFLOAD_SKIP_JAX_PRECOMPILE", "0") == "1": + logger.info( + "Skipping KV swap pre-compilation due to environment variable." + ) + return + + logger.info("Starting pre-compilation of KV cache swap operations") + start_time = time.time() + paged_kv_for_compilation = self.runner.kv_caches + for num_blocks in BLOCK_SIZE_BUCKETS: + try: + logger.info(f" - Compiling for {num_blocks} blocks...") + dummy_block_ids = jnp.arange(num_blocks) + + # 1. Pre-compile gather (used in save) + flat_dummy_kv_caches_tpu = KVCacheManager._jitted_gather_kv_cache( + paged_kv_for_compilation, dummy_block_ids) + jax.block_until_ready(flat_dummy_kv_caches_tpu) + + # 2. 
Pre-compile TPU -> CPU transfer (used in save) + dummy_kv_cpu = self.swap_out_fn(flat_dummy_kv_caches_tpu) + jax.block_until_ready(dummy_kv_cpu) + + # 3. Pre-compile CPU -> TPU transfer (used in load) + split_size_list = [self.block_size] * num_blocks + chunked_dummy_kv_cpu = [ + jax.lax.split(flat_layer_cache, split_size_list, axis=0) + for flat_layer_cache in dummy_kv_cpu + ] + chunked_dummy_kv_tpu = self.swap_in_fn(chunked_dummy_kv_cpu) + jax.block_until_ready(chunked_dummy_kv_tpu) + + # 4. Pre-compile insert (used in load). + # The result is passed to the next iteration's gather to avoid + # using a "deleted" array. + logger.info( + f" - Calling jitted_insert_kv_cache_slices with paged_kv_for_compilation len: {len(paged_kv_for_compilation)}, first_element_shape: {paged_kv_for_compilation[0].shape}, " + f"chunked_dummy_kv_tpu len: {len(chunked_dummy_kv_tpu)}") + paged_kv_for_compilation = jitted_insert_kv_cache_slices( + self.block_size, paged_kv_for_compilation, + chunked_dummy_kv_tpu, dummy_block_ids) + jax.block_until_ready(paged_kv_for_compilation) + except Exception as e: + logger.warning( + f" - Failed to pre-compile for {num_blocks} blocks: {e}", + exc_info=True) + + self.runner.kv_caches = paged_kv_for_compilation + duration = time.time() - start_time + logger.info("KV cache swap pre-compilation finished in %.2f [secs].", + duration) + + def _bucketed_gather_kv_cache( + self, + kv_caches: list[jax.Array], + block_ids: jax.Array, + ) -> list[jax.Array]: + """ + Gathers KV cache data for the given block_ids by breaking the operation + into bucket-aligned chunks to leverage JIT compilation cache. + """ + num_blocks = len(block_ids) + if num_blocks == 0: + return [] + if num_blocks in BLOCK_SIZE_BUCKETS: + return KVCacheManager._jitted_gather_kv_cache(kv_caches, block_ids) + + decomposed_block_sizes = self._decompose_into_buckets(num_blocks) + logger.info( + f"Decomposing gather for {num_blocks} blocks into bucket sizes {decomposed_block_sizes}" + ) + gathered_chunks = [] + block_offset = 0 + for decomposed_block_size in decomposed_block_sizes: + block_slice = jax.lax.dynamic_slice_in_dim(block_ids, + block_offset, + decomposed_block_size, + axis=0) + gathered_chunk = KVCacheManager._jitted_gather_kv_cache( + kv_caches, block_slice) + gathered_chunks.append(gathered_chunk) + block_offset += decomposed_block_size + + # Reassemble the results from all chunks + return jax.tree.map(lambda *x: jnp.concatenate(x, axis=0), + *gathered_chunks) + + def _bucketed_swap_out_fn( + self, + flat_kv_caches_tpu: list[jax.Array]) -> list[list[jax.Array]]: + """ + Swaps out KV cache data from TPU to CPU in bucket-aligned chunks, + returning a list of block-sized chunks per layer. + """ + num_tokens = flat_kv_caches_tpu[0].shape[0] + num_blocks = num_tokens // self.block_size + if num_blocks == 0: + return [[] for _ in range(self.num_layers)] + + # Fast path: handle bucket-sized transfers + if num_blocks in BLOCK_SIZE_BUCKETS: + flat_kv_caches_cpu = self.swap_out_fn(flat_kv_caches_tpu) + split_size_list = [self.block_size] * num_blocks + return [ + jax.lax.split(flat_layer_cache, split_size_list, axis=0) + for flat_layer_cache in flat_kv_caches_cpu + ] + + # Bucket decomposition path + decomposed_block_sizes = self._decompose_into_buckets(num_blocks) + logger.info( + f"Decomposing swap-out for {num_blocks} blocks into bucket sizes {decomposed_block_sizes}" + ) + # This will be a list of lists, where each inner list holds the chunks + # for a layer. 
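+        # Illustrative layout (assuming 2 layers and 3 saved blocks):
+        #   final_chunks_per_layer == [[c0, c1, c2],   # layer 0
+        #                              [c0, c1, c2]]   # layer 1
+        # where each chunk holds block_size tokens of that layer's KV data.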
+ final_chunks_per_layer = [[] for _ in range(self.num_layers)] + token_offset = 0 + for decomposed_block_size in decomposed_block_sizes: + chunk_size_in_tokens = decomposed_block_size * self.block_size + + # Slice the TPU tensor for the current bucket + tpu_chunk = [ + jax.lax.dynamic_slice_in_dim(layer_cache, + token_offset, + chunk_size_in_tokens, + axis=0) + for layer_cache in flat_kv_caches_tpu + ] + + # Swap the bucket to CPU, result is a flat tensor for this bucket. We are doing the chunking inside this function to avoid returning any jnp.concatenate + # of kv cache for the the bucketed blocks + cpu_chunk_flat_per_layer = self.swap_out_fn(tpu_chunk) + # Split the flat bucket tensor into block-sized chunks and append + split_size_list = [self.block_size] * decomposed_block_size + for i, layer_cache in enumerate(cpu_chunk_flat_per_layer): + chunks = jax.lax.split(layer_cache, split_size_list, axis=0) + final_chunks_per_layer[i].extend(chunks) + + token_offset += chunk_size_in_tokens + + return final_chunks_per_layer + + def _bucketed_swap_in_fn( + self, + assembled_kv_on_cpu: list[list[jax.Array]], + ) -> list[list[jax.Array]]: + """ + Swaps in KV cache data from CPU to TPU in bucket-aligned chunks, + assembling a complete staging buffer on the TPU. + """ + num_blocks = len(assembled_kv_on_cpu[0]) + if num_blocks == 0: + return [[] for _ in range(self.num_layers)] + if num_blocks in BLOCK_SIZE_BUCKETS: + return self.swap_in_fn(assembled_kv_on_cpu) + + decomposed_block_sizes = self._decompose_into_buckets(num_blocks) + logger.info( + f"Decomposing swap-in for {num_blocks} blocks into bucket sizes {decomposed_block_sizes}" + ) + + tpu_chunks_per_layer = [[] for _ in range(self.num_layers)] + block_offset = 0 + for decomposed_block_size in decomposed_block_sizes: + cpu_chunks_for_bucket = [ + layer_chunks[block_offset:block_offset + decomposed_block_size] + for layer_chunks in assembled_kv_on_cpu + ] + tpu_chunks_for_bucket = self.swap_in_fn(cpu_chunks_for_bucket) + for i in range(self.num_layers): + tpu_chunks_per_layer[i].extend(tpu_chunks_for_bucket[i]) + block_offset += decomposed_block_size + + return tpu_chunks_per_layer + + def _bucketed_jitted_insert_kv_cache_slices( + self, + kv_caches: list[jax.Array], + kv_cache_slices: list[list[jax.Array]], + dst_blocks: jax.Array, + ) -> list[jax.Array]: + """ + Inserts KV cache slices into the main cache in bucket-aligned chunks. 
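+        For example (illustrative), 11 destination blocks are processed
+        greedily as buckets of 8 + 2 + 1 (given BLOCK_SIZE_BUCKETS =
+        [1, 2, 4, 8, 16]), so only bucket-sized shapes reach the jitted
+        insert.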
+ """ + num_blocks = len(dst_blocks) + if num_blocks == 0: + return kv_caches + if num_blocks in BLOCK_SIZE_BUCKETS: + return jitted_insert_kv_cache_slices(self.block_size, kv_caches, + kv_cache_slices, dst_blocks) + + decomposed_block_sizes = self._decompose_into_buckets(num_blocks) + logger.info( + f"Decomposing insert for {num_blocks} blocks into bucket sizes {decomposed_block_sizes}" + ) + + updated_kv_caches = kv_caches + block_offset = 0 + for decomposed_block_size in decomposed_block_sizes: + slices_for_bucket = [ + layer_slices[block_offset:block_offset + decomposed_block_size] + for layer_slices in kv_cache_slices + ] + dst_blocks_for_bucket = jax.lax.dynamic_slice_in_dim( + dst_blocks, block_offset, decomposed_block_size, axis=0) + + updated_kv_caches = jitted_insert_kv_cache_slices( + self.block_size, updated_kv_caches, slices_for_bucket, + dst_blocks_for_bucket) + + block_offset += decomposed_block_size + + return updated_kv_caches + + def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int], + full_token_ids: list[int], + save_spec: SaveSpec) -> ReqId: + """ + Extracts KV cache blocks from TPU, copies them to CPU, and updates the + CPU backend with the new cache keys and their corresponding token data. + """ + if not self.runner or not self.runner.kv_caches: + logger.error(f"Cannot save blocks for request {req_id}: runner or " + "KV caches not registered.") + return req_id + + blocks_to_save = save_spec.src_blocks + dst_chunks = save_spec.dst_chunks + + num_total_tokens = save_spec.num_total_tokens + num_skip_leading_tokens = save_spec.num_skip_leading_tokens + num_blocks_to_save = len(blocks_to_save) + + assert num_total_tokens <= len( + full_token_ids), f"{num_total_tokens} > {len(full_token_ids)}" + + num_tokens_to_save = num_total_tokens - num_skip_leading_tokens + if num_tokens_to_save <= 0 and not save_spec.is_final_save: + logger.info(f"Request {req_id}: No new tokens to save.") + return req_id + + process_token_ids = full_token_ids[:num_total_tokens] + tokens_to_save = process_token_ids[num_skip_leading_tokens:] + + logger.info(f"Request {req_id} save details: " + f"full_block_ids len={len(full_block_ids)}, " + f"num_skip_leading_tokens={num_skip_leading_tokens}, " + f"num_total_tokens={num_total_tokens}, " + f"num_tokens_to_save={num_tokens_to_save}, " + f"blocks_to_save({len(blocks_to_save)}: {blocks_to_save}, " + f"dst_chunks({len(dst_chunks)}: {dst_chunks} ") + + if not blocks_to_save and tokens_to_save: + logger.warning( + f"Request {req_id}: Tokens to save but no corresponding blocks found." + ) + return req_id + + if not tokens_to_save: + logger.info( + f"Request {req_id}: No new tokens to save, but processing as final save." 
+ ) + return req_id + + # Verify if blocks_to_save is a contiguous subarray of full_block_ids + first_src_block = blocks_to_save[0] + last_src_block = blocks_to_save[-1] + try: + first_block_idx_in_full = full_block_ids.index(first_src_block) + last_block_idx_in_full = full_block_ids.index(last_src_block) + if not (last_block_idx_in_full - first_block_idx_in_full + 1 + == len(blocks_to_save)): + raise ValueError( + f"Request({req_id}): blocks_to_save {blocks_to_save} does not exist in full_block_ids {full_block_ids}" + ) + except ValueError: + raise ValueError( + f"Request({req_id}): blocks_to_save {blocks_to_save} contains blocks not present in local_block_ids {full_block_ids}" + ) + + try: + start_time = time.time() + blocks_to_save = jnp.array(blocks_to_save) + if self.use_bucketed_swap_ops: + flat_kv_caches_tpu = self._bucketed_gather_kv_cache( + self.runner.kv_caches, blocks_to_save) + else: + flat_kv_caches_tpu = KVCacheManager._jitted_gather_kv_cache( + self.runner.kv_caches, blocks_to_save) + + jax.block_until_ready(flat_kv_caches_tpu) + logger.info( + f"extracted_blocks_tpu: {flat_kv_caches_tpu[0].shape}, {flat_kv_caches_tpu[0].sharding}" + ) + + chunks_on_cpu = None + if self.use_bucketed_swap_ops: + chunks_on_cpu = self._bucketed_swap_out_fn(flat_kv_caches_tpu) + else: + flat_kv_caches_cpu = self.swap_out_fn(flat_kv_caches_tpu) + if flat_kv_caches_cpu: + jax.block_until_ready(flat_kv_caches_cpu) + # NOTE(jcgu): we keep cpu_chunk_size == block_size + split_size_list = [self.cpu_chunk_size + ] * num_blocks_to_save + chunks_on_cpu = [ + jax.lax.split(flat_layer_cache, + split_size_list, + axis=0) + for flat_layer_cache in flat_kv_caches_cpu + ] + + if chunks_on_cpu and chunks_on_cpu[0]: + jax.block_until_ready(chunks_on_cpu) + + duration = time.time() - start_time + logger.info(f"Successfully saved {len(blocks_to_save)} blocks for " + f"request {req_id} to CPU in {duration:.4f} seconds.") + + total_size_bytes = sum( + sum(chunk.nbytes for chunk in layer_chunks) + for layer_chunks in chunks_on_cpu) + logger.info( + f"Total size of chunks_on_cpu: {total_size_bytes / 1024**2:.2f} MB" + ) + + post_transfer_start_time = time.time() + + for i in range(num_blocks_to_save): + chunk_id = dst_chunks[i] + cur_chunk_cross_layers = [ + chunks_on_cpu[j][i] for j in range(self.num_layers) + ] + self.cpu_backend.add(chunk_id, cur_chunk_cross_layers) + logger.info(f"Request {req_id}: Saving to CPU chunk: " + f"chunk_id={chunk_id}, " + f" local_chunk_idx={i}") + + logger.info( + f"Request {req_id}: Added {num_blocks_to_save} chunks to CPU backend." + ) + + post_transfer_duration = time.time() - post_transfer_start_time + logger.info( + f"Request {req_id}: e2e host processing of {num_blocks_to_save} chunks took {post_transfer_duration:.4f} seconds." + ) + except Exception as e: + logger.error(f"Error saving blocks for request {req_id}: {e}", + exc_info=True) + + return req_id + + def wait_for_save(self): + """ + Initiates and waits for all pending asynchronous save operations for the + current step to complete. + """ + # This method is idempotent. If the save operations for the current + # step's metadata have already been processed, we can exit early. 
+ if self._processed_save_for_step: + return + + # logger.info("TPUOffloadConnectorWorker: Entering wait_for_save") + metadata = self.connector._get_connector_metadata() + if not isinstance(metadata, TPUOffloadConnectorMetadata): + logger.info( + "wait_for_save:not an instances of TPUOffloadConnectorMetadata" + ) + self._processed_save_for_step = True + return + + if not metadata.requests_meta: + # logger.info("wait_for_save:no reqs to save") + self._processed_save_for_step = True + return + + pending_save_futures: list[tuple[Future, TPUReqMeta]] = [] + # Handle save requests + for meta in metadata.requests_meta: + if meta.save_spec: + if meta.save_spec.skip_save: + logger.info( + f"Request {meta.req_id}: Scheduler signaled to skip save." + ) + if meta.save_spec.is_final_save: + logger.info( + f"Request {meta.req_id}: Final save is a no-op. Marking as finished." + ) + # self.finished_save_reqs.add(meta.req_id) + continue + + # If there are tokens to save, submit the task to the thread pool. + logger.info(f"Submitting save task for request {meta.req_id}") + future = self.save_executor.submit(self._save_blocks_to_cpu, + meta.req_id, + meta.local_block_ids, + meta.token_ids, + meta.save_spec) + pending_save_futures.append((future, meta)) + + if not pending_save_futures: + self._processed_save_for_step = True + return + + logger.info(f"Waiting for {len(pending_save_futures)} save " + "operations to complete...") + start_time = time.time() + + for future, meta in pending_save_futures: + try: + # The result of _save_blocks_to_cpu is the request_id + finished_req_id = future.result() + logger.info( + f"Save operation completed for request {finished_req_id}") + + if len(meta.save_spec.src_blocks) > 0: + self.offload_stats.record_save( + req=finished_req_id, + saved_chunk_ids=meta.save_spec.dst_chunks) + + if meta.save_spec and meta.save_spec.is_final_save: + logger.info( + f"Request {finished_req_id}: Final save completed. Marking as finished." + ) + self.finished_save_reqs.add(finished_req_id) + + except Exception as e: + logger.error(f"A save operation failed: {e}", exc_info=True) + + duration = time.time() - start_time + logger.info(f"All {len(pending_save_futures)} save operations " + f"completed in {duration:.4f} seconds.") + self._processed_save_for_step = True + + def start_load_kv(self, fwd_ctx: "ForwardContext") -> None: + """ + This function is the worker-side entry point for loading data from the + local CPU backend into the TPU's sharded KV cache. It is a blocking + operation that ensures the cache is fully updated before the model's + forward pass begins. + """ + # Reset the save processing flag at the start of a new step. + self._processed_save_for_step = False + metadata = self.connector._get_connector_metadata() + if not isinstance( + metadata, + TPUOffloadConnectorMetadata) or not metadata.requests_meta: + logger.info("No load operations scheduled for this step.") + return + + if not self.device_sharding: + raise RuntimeError( + "KV cache sharding info not available. Was register_runner called?" 
+ ) + + assert self.runner is not None and self.runner.kv_caches is not None + + # Process each request that needs its KV cache loaded + load_times = [] + for meta in metadata.requests_meta: + if not (meta.load_spec and meta.load_spec.can_load): + continue + + request_load_start_time = time.time() + logger.info( + "TPUOffloadConnectorWorker: Starting KV cache load process.") + dst_blocks = meta.load_spec.dst_blocks + src_chunks = meta.load_spec.src_chunks + num_blocks_to_load = len(dst_blocks) + num_matched_tokens = meta.load_spec.num_matched_tokens + num_skip_leading_tokens = meta.load_spec.num_skip_leading_tokens + num_tokens_to_load_delta = num_matched_tokens - num_skip_leading_tokens + assert num_skip_leading_tokens % self.block_size == 0, f"{num_skip_leading_tokens} % {self.block_size} != 0" + + if num_tokens_to_load_delta <= 0: + logger.info( + f"Request {meta.req_id}: No new tokens to load. Skipping.") + continue + + assert num_blocks_to_load > 0, f"Request({meta.req_id}) has no dst blocks to load." + # Verify if dst_blocks is a contiguous subarray of meta.local_block_ids + first_dst_block = dst_blocks[0] + last_dst_block = dst_blocks[-1] + try: + first_block_idx_in_local = meta.local_block_ids.index( + first_dst_block) + last_block_idx_in_local = meta.local_block_ids.index( + last_dst_block) + if not (last_block_idx_in_local - first_block_idx_in_local + 1 + == len(dst_blocks)): + raise ValueError( + f"Request({meta.req_id}): dst_blocks {dst_blocks} does not exist in local_block_ids {meta.local_block_ids}" + ) + except ValueError: + raise ValueError( + f"Request({meta.req_id}): dst_blocks {dst_blocks} contains blocks not present in local_block_ids {meta.local_block_ids}" + ) + + logger.info( + f"Processing KV load for request {meta.req_id}: " + f"Total matched: {num_matched_tokens}, " + f"Already computed: {num_skip_leading_tokens}. " + f"Fetching delta of {num_tokens_to_load_delta} tokens from cache for " + f"{num_blocks_to_load} blocks.") + + # Assemble the per-layer data for the delta tokens on the CPU. + # We create a list of lists, where the outer list represents layers + # and the inner lists will hold the data chunks for that layer. + assembled_kv_on_cpu = [[] for _ in range(self.num_layers)] + # Fetch and chunks from the backend. + for i in range(num_blocks_to_load): + src_chunk_id = src_chunks[i] + cached_value = self.cpu_backend.get(src_chunk_id) + if cached_value: + for j in range(self.num_layers): + assembled_kv_on_cpu[j].append(cached_value[j]) + else: + logger.error( + f"Chunk[{src_chunk_id}] not found in CPU backend for request {meta.req_id}. Inconsistent state detected." 
+ ) + return + + # swap-in + # output: [[cpu_chunk_size * num_chunks] * num_layer] + if self.use_bucketed_swap_ops: + # Use the bucketed wrappers for a uniform two-step process + raw_chunked_kv_on_tpu = self._bucketed_swap_in_fn( + assembled_kv_on_cpu) + else: + raw_chunked_kv_on_tpu = self.swap_in_fn(assembled_kv_on_cpu) + jax.block_until_ready(raw_chunked_kv_on_tpu) + + if self.use_bucketed_swap_ops: + self.runner.kv_caches = self._bucketed_jitted_insert_kv_cache_slices( + self.runner.kv_caches, + raw_chunked_kv_on_tpu, + jnp.array(dst_blocks), + ) + else: + self.runner.kv_caches = jitted_insert_kv_cache_slices( + self.block_size, + self.runner.kv_caches, + raw_chunked_kv_on_tpu, + jnp.array(dst_blocks), + ) + jax.block_until_ready(self.runner.kv_caches) + logger.info( + f"Request {meta.req_id}: Loaded {num_tokens_to_load_delta} tokens into " + f"{num_blocks_to_load} new blocks.") + + load_times.append(time.time() - request_load_start_time) + self.finished_load_reqs.add(meta.req_id) + if num_blocks_to_load > 0: + self.offload_stats.record_load(req=meta.req_id, + loaded_chunk_ids=src_chunks) + + if load_times: + aggregate_load_time = sum(load_times) + logger.info( + f"TPUOffloadConnectorWorker: Aggregate KV cache load time for {len(load_times)} requests: {aggregate_load_time:.4f} seconds" + ) + + def get_kv_connector_stats(self) -> KVConnectorStats | None: + """ + Get the KV transfer stats for the connector. + """ + # Clear stats for next iteration + if not self.offload_stats.is_empty(): + return self.offload_stats.clone_and_reset() + return None + + def get_finished(self) -> tuple[set[str], set[str]]: + """ + Returns the sets of request IDs for completed save and load operations. + """ + # Safeguard call to wait_for_save(). + # In the final step for a request, the vLLM engine may not call + # `worker.execute_model()` if there's no computation to be done. + # This skips the usual `wait_for_save()` call, preventing the final + # save operation (marked with `is_final_save=True`) from being + # processed. Calling it here ensures that any pending save operations + # for the current step's metadata are executed, and the finished + # request IDs are correctly identified and reported back to the engine + # for resource cleanup. The `wait_for_save` method is idempotent, + # so this call is a no-op in the normal execution path. 
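# --------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the idempotency pattern the
# comment above relies on. Save work is submitted to a thread pool once per
# step; a boolean flag lets `wait_for_save()` be called both from the normal
# execution path and, as a safeguard, from `get_finished()` without
# re-submitting or re-waiting. Names below are simplified assumptions.
from concurrent.futures import ThreadPoolExecutor

class _SaveStepSketch:
    def __init__(self):
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._processed_save_for_step = False
        self._futures = []

    def submit_saves(self, save_tasks):
        # Called at most once per step with the step's save closures.
        self._processed_save_for_step = False
        self._futures = [self._executor.submit(t) for t in save_tasks]

    def wait_for_save(self):
        if self._processed_save_for_step:   # second call is a no-op
            return
        for f in self._futures:
            f.result()                      # block until the copies land
        self._futures = []
        self._processed_save_for_step = True

# Usage: calling wait_for_save() twice only blocks once.
_s = _SaveStepSketch()
_s.submit_saves([lambda: None])
_s.wait_for_save()
_s.wait_for_save()   # safeguard path, returns immediately
# --------------------------------------------------------------------------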
+ logger.info("TPUOffloadConnectorWorker: Entering get_finished") + self.wait_for_save() + + finished_saves = self.finished_save_reqs + self.finished_save_reqs = set() + finished_loads = self.finished_load_reqs + self.finished_load_reqs = set() + logger.info(f"Finished saves: {finished_saves}, " + f"Finished loads: {finished_loads}") + return finished_saves, finished_loads diff --git a/tpu_inference/distributed/offload/utils.py b/tpu_inference/distributed/offload/utils.py new file mode 100644 index 000000000..2643224e5 --- /dev/null +++ b/tpu_inference/distributed/offload/utils.py @@ -0,0 +1,266 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the LMCache project + +import functools +import hashlib +from dataclasses import dataclass +from typing import Callable, Iterable, List, Literal, Optional, Tuple + +import jax +from vllm.config import get_current_vllm_config +from vllm.distributed.kv_transfer.kv_connector.factory import \ + KVConnectorFactory + +from tpu_inference.kernels.dma.host_dma import d2h_dma, h2d_dma +from tpu_inference.logger import init_logger + +ReqId = str + +CpuChunkId = int + +# Corresponds to the initial hash value +NONE_HASH = 0 + +logger = init_logger(__name__) + +CPU_OFFLOADING_SWAP_OP_TYPE = Literal["jax", "pallas"] + +DEFAULT_TPU_OFFLOAD_STAGING_BUFFER_TOKENS = 8192 + + +@dataclass(order=True) +class CacheKey: + """ + A key for the cache engine. + """ + model_name: str + chunk_hash: int + + def __hash__(self): + return hash(( + self.model_name, + self.chunk_hash, + )) + + def __eq__(self, other): + if type(self) is type(other): + return (self.model_name == other.model_name + and self.chunk_hash == other.chunk_hash) + return False + + +class TokenProcessor: + + def __init__(self, model_name: str, chunk_size: int = 16): + self.model_name = model_name + self.chunk_size = chunk_size + logger.info(f"TokenProcessor initialized with chunk_size={chunk_size}") + + def _hash_tokens( + self, + tokens: List[int], + prefix_hash: Optional[int] = None, + ) -> int: + hasher = hashlib.sha256() + hasher.update(str(prefix_hash).encode('utf-8')) + hasher.update(str(tuple(tokens)).encode('utf-8')) + return int(hasher.hexdigest(), 16) + + def process_tokens( + self, + tokens: Optional[List[int]] = None, + ) -> Iterable[Tuple[int, int, CacheKey]]: + """Process the tokens and return the corresponding cache keys.""" + if not tokens: + return + + total_len = len(tokens) + prefix_hash = NONE_HASH + + for i in range(0, total_len, self.chunk_size): + chunk = tokens[i:i + self.chunk_size] + prefix_hash = self._hash_tokens(chunk, prefix_hash) + start_idx = i + end_idx = min(start_idx + self.chunk_size, total_len) + logger.info( + f"Processing chunk: start={start_idx}, end={end_idx}, hash={prefix_hash}" + ) + yield ( + start_idx, + end_idx, + CacheKey(model_name=self.model_name, chunk_hash=prefix_hash), + ) + + +def get_kv_connector_cache_layout(): + """ + Retrieve the required kv cache layout for the configured kv connector + Return: None, when no kv_transfer_config is found; otherwise, the layout str + """ + vllm_config = get_current_vllm_config() + kv_config = vllm_config.kv_transfer_config + if kv_config is not None: + connector_cls = KVConnectorFactory.get_connector_class(kv_config) + required_kvcache_layout = \ + connector_cls.get_required_kvcache_layout(vllm_config) + if required_kvcache_layout is not None: + return required_kvcache_layout + logger.info_once( + "Connectors do not specify a kv cache layout, defaulting to NHD.") + return None + 
+ +def get_default_kv_connector_staging_buffer_tokens() -> int: + return DEFAULT_TPU_OFFLOAD_STAGING_BUFFER_TOKENS + + +SwapFn = Callable[ + [ + List[jax.Array], # src_kv_caches + jax.sharding.NamedSharding, # src_sharding + jax.sharding.NamedSharding, # dst_sharding + Literal["h2d", "d2h"], # direction + ], + List[jax.Array], # return value +] + +KVCacheSwapFn = Callable[[List[jax.Array]], List[jax.Array]] + + +# NOTE(jcgu): keep the same interface as the pallas one +def jax_swap_kv_caches( + src_kv_caches: List[jax.Array], + src_sharding: jax.sharding.NamedSharding, + dst_sharding: jax.sharding.NamedSharding, + direction: Literal["h2d", "d2h"], +) -> List[jax.Array]: + """Swap in / out multi-layer kv_cache using jax device_put + + Args: + src_kv_caches: [kv_cache of each layer] + src_sharding: kv_caches' original sharding + dst_sharding: kv_caches' target sharding (different memory_kind) + direction: h2d -> swap_in, d2h -> swap_out + Returns: + a list of jax.Array objects with the dst_sharding + """ + + def _jax_device_put(input_array): + return jax.device_put(input_array, dst_sharding) + + return jax.tree.map(_jax_device_put, src_kv_caches) + + +def pallas_swap_kv_caches( + src_kv_caches: List[jax.Array], + src_sharding: jax.sharding.NamedSharding, + dst_sharding: jax.sharding.NamedSharding, + direction: Literal["h2d", "d2h"], +) -> List[jax.Array]: + """Swap in / out multi-layer kv_cache using pallas dma kernel + + Args: + src_kv_caches: [kv_cache of each layer] + src_sharding: kv_caches' original sharding + dst_sharding: kv_caches' target sharding (different memory_kind) + direction: h2d -> swap_in, d2h -> swap_out + Returns: + a list of jax.Array objects with the dst_sharding + """ + + def swap_in_fn(inputs, input_sharding, out_sharding): + + def _swap_in(host_sharded_array): + return h2d_dma(host_sharded_array, input_sharding, out_sharding) + + return jax.tree.map(_swap_in, inputs) + + def swap_out_fn(inputs, input_sharding, out_sharding): + + def _swap_out(hbm_sharded_array): + return d2h_dma(hbm_sharded_array, input_sharding, out_sharding) + + return jax.tree.map(_swap_out, inputs) + + if direction == "d2h": + return swap_out_fn(src_kv_caches, src_sharding, dst_sharding) + elif direction == "h2d": + return swap_in_fn(src_kv_caches, src_sharding, dst_sharding) + + +def get_kv_cache_swap_fn( + swap_op_type: CPU_OFFLOADING_SWAP_OP_TYPE, + host_sharding: jax.sharding.NamedSharding, + device_sharding: jax.sharding.NamedSharding, + jitted: bool = True, +) -> Tuple[KVCacheSwapFn, KVCacheSwapFn]: + """get the right swap_in and swap_out functions + + Args: + swap_op_type : (str) pallas or jax + host_sharding: + device_sharding: + + Returns: + A tuple containing the jitted swap-in and swap-out functions. 
+ """ + _swap_fn: SwapFn = pallas_swap_kv_caches if swap_op_type == "pallas" else jax_swap_kv_caches + if jitted: + _swap_in_fn = jax.jit( + _swap_fn, + static_argnames=["src_sharding", "dst_sharding", "direction"], + out_shardings=device_sharding) + _swap_out_fn = jax.jit( + _swap_fn, + static_argnames=["src_sharding", "dst_sharding", "direction"], + out_shardings=host_sharding) + else: + _swap_in_fn = _swap_fn + _swap_out_fn = _swap_fn + + # swap_in (h2d) + swap_in_fn = functools.partial(_swap_in_fn, + src_sharding=host_sharding, + dst_sharding=device_sharding, + direction="h2d") + # swap_out (d2h) + swap_out_fn = functools.partial(_swap_out_fn, + src_sharding=device_sharding, + dst_sharding=host_sharding, + direction="d2h") + return swap_in_fn, swap_out_fn + + +@functools.partial( + jax.jit, + static_argnames=("block_size"), + donate_argnames=( + "kv_caches", + "kv_cache_slices", + ), +) +def jitted_insert_kv_cache_slices( + block_size, + kv_caches: List[jax.Array], + kv_cache_slices: List[List[jax.Array]], + block_numbers: jax.Array, +) -> List[jax.Array]: + """ + JIT-compiled function to insert KV cache slices into the physical + cache for all layers at once. This fuses reshape, and scatter + operations into a single efficient kernel. + """ + + def _update_layer(cache, slices): + """The function to apply to each layer's cache and slices.""" + # new_shape = (1, block_size, *slices[0].shape[1:]) + for (i, block_idx) in enumerate(block_numbers): + # reshaped_block = slices[i].reshape(new_shape) + reshaped_block = jax.lax.expand_dims(slices[i], dimensions=(0, )) + cache = jax.lax.dynamic_update_slice_in_dim(cache, + reshaped_block, + block_idx, + axis=0) + return cache + + return jax.tree.map(_update_layer, kv_caches, kv_cache_slices) diff --git a/tpu_inference/kernels/dma/__init__.py b/tpu_inference/kernels/dma/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tpu_inference/kernels/dma/host_dma.py b/tpu_inference/kernels/dma/host_dma.py new file mode 100644 index 000000000..68a53f9d0 --- /dev/null +++ b/tpu_inference/kernels/dma/host_dma.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 +""" Host <-> HBM DMA kernel""" +import jax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu + + +def host_hbm_dma(x_ref, y_ref): + """ + DMA a jax array between host and hbm + Input jax array ref: x_ref + Output jax array ref: y_ref + """ + + def body(sem): + pltpu.async_copy(x_ref, y_ref, sem).wait() + + pl.run_scoped(body, pltpu.SemaphoreType.DMA) + + +# NOTE(jcgu): input / out arrays should have the same sharding, but different memory_kind +# NOTE(jcgu): only support NamedSharding, does not support SingleDeviceSharding +def d2h_dma( + input_array: jax.Array, + input_sharding: jax.sharding.NamedSharding, + out_sharding: jax.sharding.NamedSharding, +) -> jax.Array: + """ DMA a device jax array to host memory. 
+ Args: + input_array: input jax array on device hbm + input_sharding: input's device sharding + out_sharding: output's host sharding + Returns: + jax array on host memory with the same sharding + """ + + @jax.jit + def _d2h_dma_call(x): + return pl.pallas_call( + host_hbm_dma, + in_specs=[ + pl.BlockSpec(memory_space=pl.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pl.HOST), + out_shape=pltpu.HOST(shape=x.shape, dtype=x.dtype), + name="d2h_dma_kernel", + )(x) + + d2h_dma_kernel = jax.jit( + jax.shard_map( + _d2h_dma_call, + mesh=input_sharding.mesh, + in_specs=input_sharding.spec, + out_specs=out_sharding.spec, + check_vma=False, + ), + out_shardings=out_sharding, + ) + + return d2h_dma_kernel(input_array) + + +# NOTE(jcgu): input / out arrays should have the same sharding, but different memory_kind +# NOTE(jcgu): only support NamedSharding, does not support SingleDeviceSharding +def h2d_dma( + input_array: jax.Array, + input_sharding: jax.sharding.NamedSharding, + out_sharding: jax.sharding.NamedSharding, +) -> jax.Array: + """ DMA a host jax array to device hbm. + Args: + input_array: input jax array on host memory + input_sharding: the host sharding for input + out_sharding: the device sharding for output + Returns: + jax array on device hbm with the assigned sharding + """ + + @jax.jit + def _h2d_dma_call(x): + return pl.pallas_call( + host_hbm_dma, + in_specs=[ + pl.BlockSpec(memory_space=pl.HOST), + ], + out_specs=pl.BlockSpec(memory_space=pl.ANY), + out_shape=jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), + name="h2d_dma_kernel", + )(x) + + h2d_dma_kernel = jax.jit( + jax.shard_map( + _h2d_dma_call, + mesh=input_sharding.mesh, + in_specs=input_sharding.spec, + out_specs=out_sharding.spec, + check_vma=False, + ), + out_shardings=out_sharding, + ) + + return h2d_dma_kernel(input_array) diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py index 9f2a78526..544b77de6 100644 --- a/tpu_inference/platforms/tpu_platform.py +++ b/tpu_inference/platforms/tpu_platform.py @@ -217,10 +217,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "Forcing --disable_chunked_mm_input.") scheduler_config.disable_chunked_mm_input = True - kv_transfer_config = vllm_config.kv_transfer_config - if kv_transfer_config is not None: - assert kv_transfer_config.kv_connector == "TPUConnector" - # Late initialization to avoid circular import from tpu_inference.models.jax.utils.quantization.quantization_utils import \ update_vllm_config_for_qwix_quantization diff --git a/tpu_inference/runner/kv_cache_manager.py b/tpu_inference/runner/kv_cache_manager.py index 348521715..cba70ffaf 100644 --- a/tpu_inference/runner/kv_cache_manager.py +++ b/tpu_inference/runner/kv_cache_manager.py @@ -11,12 +11,16 @@ from vllm.attention.layer import Attention from vllm.config import get_layers_from_vllm_config from vllm.utils.math_utils import cdiv +from vllm.v1.attention.backends.utils import (get_kv_cache_layout, + set_kv_cache_layout) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec, MLAAttentionSpec, SlidingWindowSpec) from tpu_inference import utils from tpu_inference import utils as common_utils +from tpu_inference.distributed.offload.utils import \ + get_kv_connector_cache_layout from tpu_inference.logger import init_logger from tpu_inference.runner import utils as runner_utils from tpu_inference.runner.input_batch import CachedRequestState, InputBatch @@ -29,6 +33,10 @@ logger = init_logger(__name__) +# default layout 
(order) used by kv cache manager +# N=num_blocks, H=num_heads and D=head_size +DEFAULT_KV_CACHE_LAYOUT = "NHD" + class KVCacheManager: @@ -165,6 +173,10 @@ def get_kv_cache_spec(self): f"Unknown attention type: {attn_module.attn_type}") return kv_cache_spec + def get_kv_cache_layout(self): + # return the layout (mostly "NHD" or "HND") of kv cache + return get_kv_cache_layout() + def maybe_reinitialize_input_batch(self, kv_cache_config: KVCacheConfig) -> None: block_sizes = [ @@ -195,6 +207,19 @@ def maybe_reinitialize_input_batch(self, def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.maybe_reinitialize_input_batch(kv_cache_config) + # set the kv cache layout which is needed by kv connectors + # NOTE(jcgu): please update the default value when the order changes + set_kv_cache_layout(DEFAULT_KV_CACHE_LAYOUT) + + # verify kv cache layout is matched between the cache manager and + # the kv connector (if configured) + _required_kv_layout = get_kv_connector_cache_layout() + if (_required_kv_layout + and _required_kv_layout != DEFAULT_KV_CACHE_LAYOUT): + raise ValueError( + f"KV cache layout ({DEFAULT_KV_CACHE_LAYOUT}) does not match with the " + f"kv_connector's required layout ({_required_kv_layout})") + # uniform page size. representative_spec = kv_cache_config.kv_cache_groups[0].kv_cache_spec page_size_bytes = representative_spec.page_size_bytes diff --git a/tpu_inference/runner/tpu_runner.py b/tpu_inference/runner/tpu_runner.py index ac67eae30..893f7ac60 100644 --- a/tpu_inference/runner/tpu_runner.py +++ b/tpu_inference/runner/tpu_runner.py @@ -524,6 +524,9 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]: def get_kv_cache_spec(self): return self.kv_cache_manager.get_kv_cache_spec() + def get_kv_cache_layout(self): + return self.kv_cache_manager.get_kv_cache_layout() + def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self.kv_cache_config = kv_cache_config self.use_hybrid_kvcache = len(kv_cache_config.kv_cache_groups) > 1 diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py index 9b52d43e4..d9a3df752 100644 --- a/tpu_inference/worker/tpu_worker.py +++ b/tpu_inference/worker/tpu_worker.py @@ -26,6 +26,8 @@ from tpu_inference import envs, utils from tpu_inference.distributed import jax_parallel_state +from tpu_inference.distributed.offload.utils import \ + get_default_kv_connector_staging_buffer_tokens from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port, get_node_id) from tpu_inference.layers.common.sharding import ShardingConfigManager @@ -289,6 +291,31 @@ def determine_available_memory(self) -> int: total_hbm_limit_gb = round(total_hbm_limit / utils.GBYTES, 2) total_hbm_limit_cap_gb = round(total_hbm_limit_cap / utils.GBYTES, 2) total_hbm_used_gb = round(total_hbm_used / utils.GBYTES, 2) + + if self.vllm_config.kv_transfer_config is not None: + kv_transfer_config = self.vllm_config.kv_transfer_config + if kv_transfer_config.kv_connector == "TPUOffloadConnector" and kv_transfer_config.kv_connector_module_path == "tpu_inference.distributed.offload.tpu_offload_connector": + # If kv offloading is enabled, we need to account for the memory used by the KV transfer buffer. 
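# --------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the size of the deduction made
# below. With assumed figures of an 8192-token staging buffer, 16-token
# blocks, 32 layers and a 256 KiB per-layer page, the reserved HBM works out
# as follows. All numbers here are examples, not measurements.
staging_buffer_tokens = 8192                 # TPU_OFFLOAD_STAGING_BUFFER_TOKENS
block_size = 16                              # cache_config.block_size
num_layers = 32
vllm_page_size_bytes = 256 * 1024            # per-layer page for one block

staging_buffer_pages = staging_buffer_tokens // block_size            # 512
stage_buffer_size_bytes = (staging_buffer_pages * num_layers
                           * vllm_page_size_bytes)                    # 4 GiB
assert staging_buffer_pages == 512
assert stage_buffer_size_bytes == 4 * 1024**3
# --------------------------------------------------------------------------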
+ _default_staging_buffer_tokens = get_default_kv_connector_staging_buffer_tokens( + ) + staging_buffer_tokens = int( + os.getenv("TPU_OFFLOAD_STAGING_BUFFER_TOKENS", + str(_default_staging_buffer_tokens))) + # calculate staging buffer size + staging_buffer_pages = staging_buffer_tokens // self.vllm_config.cache_config.block_size + + kv_cache_specs = self.model_runner.get_kv_cache_spec() + num_layers = len(kv_cache_specs) + vllm_page_size_bytes = get_uniform_page_size(kv_cache_specs) + # rpa_page_size_bytes = get_rpa_page_size_bytes(self.model_runner.mesh, + # kv_cache_specs) + stage_buffer_size_bytes = staging_buffer_pages * num_layers * vllm_page_size_bytes + + total_hbm_avail = total_hbm_avail - stage_buffer_size_bytes + logger.info( + f" ALERT: KV offloading enabled. Deducting {stage_buffer_size_bytes} Bytes ({staging_buffer_pages} pages) from available HBM for staging buffer." + ) + total_hbm_avail_gb = round(total_hbm_avail / utils.GBYTES, 2) logger.info(f"Memory statistics | " @@ -432,6 +459,11 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: return kv_cache_specs + def get_kv_connector_handshake_metadata(self) -> dict | None: + """Get KV connector metadata from this worker if available.""" + # NOTE: we are not using it right now. + return + def initialize_from_config( self, kv_cache_config: KVCacheConfig, From cd5cce2823bf59495561d39d99662ed0add2b6eb Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Mon, 24 Nov 2025 21:40:25 +0000 Subject: [PATCH 02/29] tweaks Signed-off-by: Juncheng Gu --- .../distributed/offload/tpu_offload_connector_worker_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/distributed/offload/tpu_offload_connector_worker_test.py b/tests/distributed/offload/tpu_offload_connector_worker_test.py index 246358f79..6642aaa22 100644 --- a/tests/distributed/offload/tpu_offload_connector_worker_test.py +++ b/tests/distributed/offload/tpu_offload_connector_worker_test.py @@ -47,6 +47,7 @@ class MockVllmConfig: def __init__(self, block_size=_DEFAULT_BLOCK_SIZE): self.model_config = self.Model() self.cache_config = self.Cache(block_size) + self.kv_transfer_config = self.KVTransferConfig() class Model: model = "test-model" @@ -56,6 +57,10 @@ class Cache: def __init__(self, block_size): self.block_size = block_size + class KVTransferConfig: + ip = "ip" + port = 1234 + class TestCpuOffloadingSave(jtu.JaxTestCase): """Test the save functionality of the TPUOffloadConnectorWorker.""" From 21ae0def5297d3f60ec59fdbc0c4ce17c54e15bb Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Mon, 24 Nov 2025 22:55:09 +0000 Subject: [PATCH 03/29] offload envs Signed-off-by: Juncheng Gu --- .../tpu_offload_connector_scheduler_test.py | 7 ----- .../distributed/offload/cpu_backend.py | 3 --- .../distributed/offload/offload_manager.py | 3 --- .../offload/tpu_offload_connector.py | 27 +++++++------------ tpu_inference/envs.py | 20 ++++++++++++++ 5 files changed, 29 insertions(+), 31 deletions(-) diff --git a/tests/distributed/offload/tpu_offload_connector_scheduler_test.py b/tests/distributed/offload/tpu_offload_connector_scheduler_test.py index ed83bae34..3ab5dbdf2 100644 --- a/tests/distributed/offload/tpu_offload_connector_scheduler_test.py +++ b/tests/distributed/offload/tpu_offload_connector_scheduler_test.py @@ -62,18 +62,12 @@ def scheduler_factory(): def _scheduler( block_size: int = _DEFAULT_BLOCK_SIZE, offload_decode_save: int = 0, - offload_partial_block_save_behavior: str = "drop", - offload_partial_block_dynamic_pad_lower_limit: int = 0, 
offload_staging_buffer_tokens: int = -1, offload_num_cpu_chunks: int = DEFAULT_TPU_OFFLOAD_CPU_CHUNKS, ): # update config vllm_config = MockVllmConfig(block_size=block_size) os.environ["TPU_OFFLOAD_DECODE_SAVE"] = str(offload_decode_save) - os.environ[ - "TPU_OFFLOAD_PARTIAL_BLOCK_SAVE_BEHAVIOR"] = offload_partial_block_save_behavior - os.environ["TPU_OFFLOAD_PARTIAL_BLOCK_DYNAMIC_PAD_LOWER_LIMIT"] = str( - offload_partial_block_dynamic_pad_lower_limit) if offload_staging_buffer_tokens >= 0: os.environ["TPU_OFFLOAD_STAGING_BUFFER_TOKENS"] = str( offload_staging_buffer_tokens) @@ -238,7 +232,6 @@ def test_build_connector_meta_new_prefill(self, scheduler_factory, """ num_staging_blocks = num_staging_tokens // _DEFAULT_BLOCK_SIZE scheduler = scheduler_factory( - offload_partial_block_save_behavior="drop", offload_staging_buffer_tokens=num_staging_tokens, offload_num_cpu_chunks=100) diff --git a/tpu_inference/distributed/offload/cpu_backend.py b/tpu_inference/distributed/offload/cpu_backend.py index 37352c504..3199c5086 100644 --- a/tpu_inference/distributed/offload/cpu_backend.py +++ b/tpu_inference/distributed/offload/cpu_backend.py @@ -10,9 +10,6 @@ logger = init_logger(__name__) -GB = 1024**3 -DEFAULT_CPU_CACHE_SIZE_BYTES = 1 * GB - class LocalCPUBackend: """ diff --git a/tpu_inference/distributed/offload/offload_manager.py b/tpu_inference/distributed/offload/offload_manager.py index eb9eee6db..c4f5cfcf0 100644 --- a/tpu_inference/distributed/offload/offload_manager.py +++ b/tpu_inference/distributed/offload/offload_manager.py @@ -12,9 +12,6 @@ logger = init_logger(__name__) -GB = 1024**3 -DEFAULT_CPU_CACHE_SIZE_BYTES = 1 * GB - ChunkHash = BlockHash diff --git a/tpu_inference/distributed/offload/tpu_offload_connector.py b/tpu_inference/distributed/offload/tpu_offload_connector.py index 52be57ee7..c0c07e3ad 100644 --- a/tpu_inference/distributed/offload/tpu_offload_connector.py +++ b/tpu_inference/distributed/offload/tpu_offload_connector.py @@ -112,13 +112,13 @@ from vllm.v1.request import Request from vllm.forward_context import ForwardContext +from tpu_inference import envs from tpu_inference.distributed.offload.cpu_backend import LocalCPUBackend from tpu_inference.distributed.offload.offload_manager import ( LRUCacheManager, StagingBufferManager) from tpu_inference.distributed.offload.utils import ( CPU_OFFLOADING_SWAP_OP_TYPE, CpuChunkId, KVCacheSwapFn, ReqId, - TokenProcessor, get_default_kv_connector_staging_buffer_tokens, - get_kv_cache_swap_fn, jitted_insert_kv_cache_slices) + TokenProcessor, get_kv_cache_swap_fn, jitted_insert_kv_cache_slices) from tpu_inference.logger import init_logger from tpu_inference.runner.kv_cache_manager import KVCacheManager from tpu_inference.runner.tpu_runner import TPUModelRunner @@ -480,9 +480,7 @@ def __init__(self, vllm_config: "VllmConfig"): self.block_size = vllm_config.cache_config.block_size # offloading manager - self.num_cpu_chunks = int( - os.getenv("TPU_OFFLOAD_NUM_CPU_CHUNKS", - str(DEFAULT_TPU_OFFLOAD_CPU_CHUNKS))) + self.num_cpu_chunks = envs.TPU_OFFLOAD_NUM_CPU_CHUNKS self.offload_manager = LRUCacheManager( num_cpu_chunks=self.num_cpu_chunks) @@ -506,7 +504,7 @@ def __init__(self, vllm_config: "VllmConfig"): self.token_processor = TokenProcessor(model_name=model_name, chunk_size=self.block_size) - self.decode_save = os.getenv("TPU_OFFLOAD_DECODE_SAVE", "0") == "1" + self.decode_save = envs.TPU_OFFLOAD_DECODE_SAVE # NOTE(jcgu): currently, let's make chunk_size == block_size # chunk_size == n * block_size lead to # 1. 
multi-size chunks @@ -514,6 +512,7 @@ def __init__(self, vllm_config: "VllmConfig"): # real-chunk-size in save and load self.cpu_chunk_size = self.block_size + # TODO(jcgu): rm # define partial_block saving behavior self.partial_block_save_behavior: PARTIAL_BLOCK_SAVE_BEHAVIOR = \ os.getenv("TPU_OFFLOAD_PARTIAL_BLOCK_SAVE_BEHAVIOR", "drop") @@ -535,11 +534,7 @@ def __init__(self, vllm_config: "VllmConfig"): # config staging buffer # NOTE(jcgu): Need to find a way to grab page_size_bytes in scheduler # otherwise, we can only use # of tokens as input, instead of buffer size in GB - _default_staging_buffer_tokens = get_default_kv_connector_staging_buffer_tokens( - ) - num_staging_buffer_tokens = int( - os.getenv("TPU_OFFLOAD_STAGING_BUFFER_TOKENS", - str(_default_staging_buffer_tokens))) + num_staging_buffer_tokens = envs.TPU_OFFLOAD_STAGING_BUFFER_TOKENS self.num_staging_blocks = num_staging_buffer_tokens // self.block_size self.staging_buffer_manager = StagingBufferManager( num_blocks=self.num_staging_blocks) @@ -1214,15 +1209,13 @@ def __init__(self, vllm_config: VllmConfig, self.runner: Optional[TPUModelRunner] = None self.mesh: Optional[Mesh] = None - self.swap_op_type = os.getenv("TPU_OFFLOAD_SWAP_OP_TYPE", - default=DEFAULT_HOST_HBM_SWAP_OP_TYPE) + self.swap_op_type = envs.TPU_OFFLOAD_SWAP_OP_TYPE assert self.swap_op_type in get_args(CPU_OFFLOADING_SWAP_OP_TYPE) # TODO(jcgu): check libtpu compatibility for pallas dma kernel logger.info( f"(cpu offloading) swap operation type is {self.swap_op_type}") - self.use_bucketed_swap_ops = os.getenv( - "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE", "0") == "0" + self.use_bucketed_swap_ops = not envs.TPU_OFFLOAD_SKIP_JAX_PRECOMPILE logger.info( f"(cpu offloading) use_bucketed_swap_ops={self.use_bucketed_swap_ops}" ) @@ -1231,9 +1224,7 @@ def __init__(self, vllm_config: VllmConfig, self.swap_out_fn: KVCacheSwapFn = None # cpu cache - self.num_cpu_chunks = int( - os.getenv("TPU_OFFLOAD_NUM_CPU_CHUNKS", - str(DEFAULT_TPU_OFFLOAD_CPU_CHUNKS))) + self.num_cpu_chunks = envs.TPU_OFFLOAD_NUM_CPU_CHUNKS self.cpu_backend = LocalCPUBackend(num_cpu_chunks=self.num_cpu_chunks) # The worker needs its own token processor to generate keys. 
model_name = self.vllm_config.model_config.model diff --git a/tpu_inference/envs.py b/tpu_inference/envs.py index 9201e1a11..332d73049 100644 --- a/tpu_inference/envs.py +++ b/tpu_inference/envs.py @@ -24,6 +24,11 @@ NUM_SLICES: int = 1 RAY_USAGE_STATS_ENABLED: str = "0" VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: str = "shm" + TPU_OFFLOAD_SKIP_JAX_PRECOMPILE: bool = False + TPU_OFFLOAD_SWAP_OP_TYPE: str = "jax" + TPU_OFFLOAD_DECODE_SAVE: bool = False + TPU_OFFLOAD_NUM_CPU_CHUNKS: int = 1024 + TPU_OFFLOAD_STAGING_BUFFER_TOKENS: int = 8192 def env_with_choices( @@ -122,6 +127,21 @@ def _get_validated_env() -> str | None: # Ray compiled DAG channel type for TPU "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE": env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]), + # kv offload to dram: save kv in the decode phase + "TPU_OFFLOAD_DECODE_SAVE": + lambda: bool(int(os.getenv("TPU_OFFLOAD_DECODE_SAVE", "0"))), + # kv offload to dram: swap function type: jax, or pallas + "TPU_OFFLOAD_SWAP_OP_TYPE": + lambda: os.getenv("TPU_OFFLOAD_SWAP_OP_TYPE", "jax"), + # kv offload to dram: dram space size in # of chunks / blocks + "TPU_OFFLOAD_NUM_CPU_CHUNKS": + lambda: int(os.getenv("TPU_OFFLOAD_NUM_CPU_CHUNKS", "1024")), + # kv offload to dram: dram space size in # of chunks / blocks + "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE": + lambda: bool(int(os.getenv("TPU_OFFLOAD_SKIP_JAX_PRECOMPILE", "0"))), + # kv offload to dram: size of staging buffer (hbm) for swap + "TPU_OFFLOAD_STAGING_BUFFER_TOKENS": + lambda: int(os.getenv("TPU_OFFLOAD_STAGING_BUFFER_TOKENS", "16384")), } From a5ec87d47647fb0616f2692ed4b639282a293444 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Mon, 24 Nov 2025 23:03:44 +0000 Subject: [PATCH 04/29] rm saving behavior Signed-off-by: Juncheng Gu --- .../offload/tpu_offload_connector.py | 48 +------------------ tpu_inference/distributed/offload/utils.py | 6 --- tpu_inference/worker/tpu_worker.py | 8 +--- 3 files changed, 3 insertions(+), 59 deletions(-) diff --git a/tpu_inference/distributed/offload/tpu_offload_connector.py b/tpu_inference/distributed/offload/tpu_offload_connector.py index c0c07e3ad..eb60147fc 100644 --- a/tpu_inference/distributed/offload/tpu_offload_connector.py +++ b/tpu_inference/distributed/offload/tpu_offload_connector.py @@ -128,9 +128,6 @@ # kv cache layout needed by cpu offloading mechanism REQUIRED_KV_CACHE_LAYOUT = "NHD" -# default swap op type -DEFAULT_HOST_HBM_SWAP_OP_TYPE = "jax" - BLOCK_SIZE_BUCKETS = [1, 2, 4, 8, 16] # we keep our operations at vllm's block granularity, @@ -139,9 +136,7 @@ # 1. [supported] drop: drop the entire partial block # 2. pad: pad to a full block # 3. dynamic: keep the partial block as is. 
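# --------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the effect of the "drop"
# policy that this change keeps as the only supported behavior. Only whole
# blocks are offloaded; tokens in a trailing, partially filled block are
# simply not saved. The numbers below are made-up examples.
def droppable_save_extent(num_tokens: int, block_size: int) -> tuple[int, int]:
    """Return (num_full_blocks, num_saved_tokens) under the "drop" policy."""
    num_full_blocks = num_tokens // block_size
    return num_full_blocks, num_full_blocks * block_size

assert droppable_save_extent(37, 16) == (2, 32)   # trailing 5 tokens dropped
assert droppable_save_extent(48, 16) == (3, 48)   # exact multiple: nothing dropped
# --------------------------------------------------------------------------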
-PARTIAL_BLOCK_SAVE_BEHAVIOR = Literal["drop", "pad", "dynamic"] - -DEFAULT_TPU_OFFLOAD_CPU_CHUNKS = 1024 +PARTIAL_BLOCK_SAVE_BEHAVIOR = Literal["drop"] @dataclass @@ -512,24 +507,7 @@ def __init__(self, vllm_config: "VllmConfig"): # real-chunk-size in save and load self.cpu_chunk_size = self.block_size - # TODO(jcgu): rm - # define partial_block saving behavior - self.partial_block_save_behavior: PARTIAL_BLOCK_SAVE_BEHAVIOR = \ - os.getenv("TPU_OFFLOAD_PARTIAL_BLOCK_SAVE_BEHAVIOR", "drop") - assert self.partial_block_save_behavior in get_args( - PARTIAL_BLOCK_SAVE_BEHAVIOR - ), f"{self.partial_block_save_behavior} not in {get_args(PARTIAL_BLOCK_SAVE_BEHAVIOR)}" - self.partial_block_dynamic_pad_lower_limit = \ - int(os.getenv("TPU_OFFLOAD_PARTIAL_BLOCK_DYNAMIC_PAD_LOWER_LIMIT", "0")) - if self.partial_block_save_behavior == "dynamic": - if self.partial_block_dynamic_pad_lower_limit <= 0: - self.partial_block_save_behavior == "drop" - elif self.partial_block_dynamic_pad_lower_limit >= self.block_size: - self.partial_block_save_behavior == "pad" - logger.info( - f" partial_block_save_behavior is configed to {self.partial_block_save_behavior}, but we only support drop now." - ) - self.partial_block_save_behavior = "drop" + self.partial_block_save_behavior: PARTIAL_BLOCK_SAVE_BEHAVIOR = "drop" # config staging buffer # NOTE(jcgu): Need to find a way to grab page_size_bytes in scheduler @@ -547,7 +525,6 @@ def __init__(self, vllm_config: "VllmConfig"): f"model_name={model_name}, " f"decode_save={self.decode_save}, " f"partial_block_save_behavior={self.partial_block_save_behavior}, " - f"partial_block_dynamic_pad_lower_limit={self.partial_block_dynamic_pad_lower_limit}, " f"num_staging_blocks={self.num_staging_blocks}.") def _get_request_block_hashes(self, req: "Request") -> list[BlockHash]: @@ -668,27 +645,6 @@ def get_num_new_matched_tokens( # external_computed_tokens, load_kv_async return num_to_load, False - def _adjust_last_partial_block(self, - last_partial_block_num_tokens: int) -> bool: - """ - adjust prompt token / len based on pre-configed save behavior - when the last block of request's token is partially used. - In order to keep all the saved kv be aligned with block_size, - we may - 1. drop the partial block - 2. pad the partial block to be a full block - 3. 
drop or pad based on actual num_tokens in the last partial block - - Input: num of tokens in the last partial block (could be 0) - Output: the last partial block should be kept (True) or dropped (False) - """ - if self.partial_block_save_behavior == "pad": - return True if last_partial_block_num_tokens > 0 else False - elif self.partial_block_save_behavior == "drop": - return False - elif self.partial_block_save_behavior == "dynamic": - return True if last_partial_block_num_tokens >= self.partial_block_dynamic_pad_lower_limit else False - def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int): diff --git a/tpu_inference/distributed/offload/utils.py b/tpu_inference/distributed/offload/utils.py index 2643224e5..c3767983f 100644 --- a/tpu_inference/distributed/offload/utils.py +++ b/tpu_inference/distributed/offload/utils.py @@ -25,8 +25,6 @@ CPU_OFFLOADING_SWAP_OP_TYPE = Literal["jax", "pallas"] -DEFAULT_TPU_OFFLOAD_STAGING_BUFFER_TOKENS = 8192 - @dataclass(order=True) class CacheKey: @@ -110,10 +108,6 @@ def get_kv_connector_cache_layout(): return None -def get_default_kv_connector_staging_buffer_tokens() -> int: - return DEFAULT_TPU_OFFLOAD_STAGING_BUFFER_TOKENS - - SwapFn = Callable[ [ List[jax.Array], # src_kv_caches diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py index d9a3df752..3a6e27bca 100644 --- a/tpu_inference/worker/tpu_worker.py +++ b/tpu_inference/worker/tpu_worker.py @@ -26,8 +26,6 @@ from tpu_inference import envs, utils from tpu_inference.distributed import jax_parallel_state -from tpu_inference.distributed.offload.utils import \ - get_default_kv_connector_staging_buffer_tokens from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port, get_node_id) from tpu_inference.layers.common.sharding import ShardingConfigManager @@ -296,11 +294,7 @@ def determine_available_memory(self) -> int: kv_transfer_config = self.vllm_config.kv_transfer_config if kv_transfer_config.kv_connector == "TPUOffloadConnector" and kv_transfer_config.kv_connector_module_path == "tpu_inference.distributed.offload.tpu_offload_connector": # If kv offloading is enabled, we need to account for the memory used by the KV transfer buffer. 
- _default_staging_buffer_tokens = get_default_kv_connector_staging_buffer_tokens( - ) - staging_buffer_tokens = int( - os.getenv("TPU_OFFLOAD_STAGING_BUFFER_TOKENS", - str(_default_staging_buffer_tokens))) + staging_buffer_tokens = envs.TPU_OFFLOAD_STAGING_BUFFER_TOKENS # calculate staging buffer size staging_buffer_pages = staging_buffer_tokens // self.vllm_config.cache_config.block_size From 0fc7dad5d9e00c9d87dd26c726eebb7df2bc66ec Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Mon, 24 Nov 2025 23:05:09 +0000 Subject: [PATCH 05/29] tweaks Signed-off-by: Juncheng Gu --- .../offload/tpu_offload_connector_scheduler_test.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/distributed/offload/tpu_offload_connector_scheduler_test.py b/tests/distributed/offload/tpu_offload_connector_scheduler_test.py index 3ab5dbdf2..07fddc91e 100644 --- a/tests/distributed/offload/tpu_offload_connector_scheduler_test.py +++ b/tests/distributed/offload/tpu_offload_connector_scheduler_test.py @@ -10,8 +10,7 @@ from vllm.v1.request import Request from tpu_inference.distributed.offload.tpu_offload_connector import ( - DEFAULT_TPU_OFFLOAD_CPU_CHUNKS, RequestTracker, - TPUOffloadConnectorScheduler) + RequestTracker, TPUOffloadConnectorScheduler) _DEFAULT_BLOCK_SIZE = 16 @@ -63,7 +62,7 @@ def _scheduler( block_size: int = _DEFAULT_BLOCK_SIZE, offload_decode_save: int = 0, offload_staging_buffer_tokens: int = -1, - offload_num_cpu_chunks: int = DEFAULT_TPU_OFFLOAD_CPU_CHUNKS, + offload_num_cpu_chunks: int = -1, ): # update config vllm_config = MockVllmConfig(block_size=block_size) From df3b091c5608cdbb2531ebc20c03e6718940ab4c Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Mon, 24 Nov 2025 23:19:13 +0000 Subject: [PATCH 06/29] staging_tokens --> staging_blocks Signed-off-by: Juncheng Gu --- .../tpu_offload_connector_scheduler_test.py | 20 +++++++++---------- .../offload/tpu_offload_connector.py | 19 +++++++----------- tpu_inference/envs.py | 18 ++++++++--------- tpu_inference/worker/tpu_worker.py | 4 +--- 4 files changed, 26 insertions(+), 35 deletions(-) diff --git a/tests/distributed/offload/tpu_offload_connector_scheduler_test.py b/tests/distributed/offload/tpu_offload_connector_scheduler_test.py index 07fddc91e..d0b2bdced 100644 --- a/tests/distributed/offload/tpu_offload_connector_scheduler_test.py +++ b/tests/distributed/offload/tpu_offload_connector_scheduler_test.py @@ -61,15 +61,15 @@ def scheduler_factory(): def _scheduler( block_size: int = _DEFAULT_BLOCK_SIZE, offload_decode_save: int = 0, - offload_staging_buffer_tokens: int = -1, + offload_num_staging_blocks: int = -1, offload_num_cpu_chunks: int = -1, ): # update config vllm_config = MockVllmConfig(block_size=block_size) os.environ["TPU_OFFLOAD_DECODE_SAVE"] = str(offload_decode_save) - if offload_staging_buffer_tokens >= 0: - os.environ["TPU_OFFLOAD_STAGING_BUFFER_TOKENS"] = str( - offload_staging_buffer_tokens) + if offload_num_staging_blocks >= 0: + os.environ["TPU_OFFLOAD_NUM_STAGING_BLOCKS"] = str( + offload_num_staging_blocks) if offload_num_cpu_chunks > 0: os.environ["TPU_OFFLOAD_NUM_CPU_CHUNKS"] = str( offload_num_cpu_chunks) @@ -111,9 +111,8 @@ def test_get_num_new_matched_tokens_hit(self, scheduler_factory, 5. skip 1 block + full-hit + only 1 staging block 6. 
skip 1 block + full-hit + no staging block """ - num_staging_tokens = num_staging_blocks * _DEFAULT_BLOCK_SIZE scheduler = scheduler_factory( - offload_staging_buffer_tokens=num_staging_tokens) + offload_num_staging_blocks=num_staging_blocks) prompt_len = scheduler.block_size * num_prompt_blocks num_computed_tokens = scheduler.block_size * num_computed_blocks num_blocks_to_load = num_matched_blocks - num_computed_blocks @@ -231,7 +230,7 @@ def test_build_connector_meta_new_prefill(self, scheduler_factory, """ num_staging_blocks = num_staging_tokens // _DEFAULT_BLOCK_SIZE scheduler = scheduler_factory( - offload_staging_buffer_tokens=num_staging_tokens, + offload_num_staging_blocks=num_staging_blocks, offload_num_cpu_chunks=100) # calculate the groundtruth @@ -347,10 +346,9 @@ def test_build_connector_meta_decode_with_save(self, scheduler_factory, 2. th N-th decode (hit block bounary) + not decode_save (no save) """ - scheduler = scheduler_factory( - offload_decode_save=decode_save, - offload_staging_buffer_tokens=_DEFAULT_BLOCK_SIZE * 10, - offload_num_cpu_chunks=10) + scheduler = scheduler_factory(offload_decode_save=decode_save, + offload_num_staging_blocks=10, + offload_num_cpu_chunks=10) prompt_tokens = list(range(prompt_len)) generated_tokens = list(range(prompt_len, seq_len)) diff --git a/tpu_inference/distributed/offload/tpu_offload_connector.py b/tpu_inference/distributed/offload/tpu_offload_connector.py index eb60147fc..7b23f3c8c 100644 --- a/tpu_inference/distributed/offload/tpu_offload_connector.py +++ b/tpu_inference/distributed/offload/tpu_offload_connector.py @@ -511,9 +511,8 @@ def __init__(self, vllm_config: "VllmConfig"): # config staging buffer # NOTE(jcgu): Need to find a way to grab page_size_bytes in scheduler - # otherwise, we can only use # of tokens as input, instead of buffer size in GB - num_staging_buffer_tokens = envs.TPU_OFFLOAD_STAGING_BUFFER_TOKENS - self.num_staging_blocks = num_staging_buffer_tokens // self.block_size + # otherwise, we can only use # of blocks as input, instead of buffer size in GB + self.num_staging_blocks = envs.TPU_OFFLOAD_NUM_STAGING_BLOCKS self.staging_buffer_manager = StagingBufferManager( num_blocks=self.num_staging_blocks) @@ -698,19 +697,15 @@ def _prepare_req_meta( block_hashes = self._get_request_block_hashes(_request) self.offload_manager.touch(block_hashes) - # only consider the tokens covered by block_hashes + # only consider the tokens covered by block_hashes; + # currently full blocks only num_total_blocks = len(block_hashes) num_total_tokens = min(num_total_blocks * self.block_size, len(tracker.token_ids)) num_full_blocks = num_total_tokens // self.block_size - num_full_blocks_tokens = num_full_blocks * self.block_size - # adjust last partial block - last_partial_block_num_tokens = num_total_tokens - num_full_blocks_tokens - need_last_block = self._adjust_last_partial_block( - last_partial_block_num_tokens) - adjusted_num_total_tokens = num_total_tokens if need_last_block else num_full_blocks_tokens - adjusted_num_total_blocks = num_full_blocks + (1 if need_last_block - else 0) + num_full_block_tokens = num_full_blocks * self.block_size + adjusted_num_total_tokens = num_full_block_tokens + adjusted_num_total_blocks = num_full_blocks assert adjusted_num_total_blocks <= len(tracker.block_ids) has_new_tokens = adjusted_num_total_tokens > tracker.save_watermark diff --git a/tpu_inference/envs.py b/tpu_inference/envs.py index 332d73049..75e95cd59 100644 --- a/tpu_inference/envs.py +++ b/tpu_inference/envs.py @@ -28,7 +28,7 @@ 
TPU_OFFLOAD_SWAP_OP_TYPE: str = "jax" TPU_OFFLOAD_DECODE_SAVE: bool = False TPU_OFFLOAD_NUM_CPU_CHUNKS: int = 1024 - TPU_OFFLOAD_STAGING_BUFFER_TOKENS: int = 8192 + TPU_OFFLOAD_NUM_STAGING_BLOCKS: int = 128 def env_with_choices( @@ -127,21 +127,21 @@ def _get_validated_env() -> str | None: # Ray compiled DAG channel type for TPU "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE": env_with_choices("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE", "shm", ["shm"]), - # kv offload to dram: save kv in the decode phase - "TPU_OFFLOAD_DECODE_SAVE": - lambda: bool(int(os.getenv("TPU_OFFLOAD_DECODE_SAVE", "0"))), + # kv offload to dram: skip pre-compiling swap-related jax functions + "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE": + lambda: bool(int(os.getenv("TPU_OFFLOAD_SKIP_JAX_PRECOMPILE", "0"))), # kv offload to dram: swap function type: jax, or pallas "TPU_OFFLOAD_SWAP_OP_TYPE": lambda: os.getenv("TPU_OFFLOAD_SWAP_OP_TYPE", "jax"), + # kv offload to dram: save kv in the decode phase + "TPU_OFFLOAD_DECODE_SAVE": + lambda: bool(int(os.getenv("TPU_OFFLOAD_DECODE_SAVE", "0"))), # kv offload to dram: dram space size in # of chunks / blocks "TPU_OFFLOAD_NUM_CPU_CHUNKS": lambda: int(os.getenv("TPU_OFFLOAD_NUM_CPU_CHUNKS", "1024")), - # kv offload to dram: dram space size in # of chunks / blocks - "TPU_OFFLOAD_SKIP_JAX_PRECOMPILE": - lambda: bool(int(os.getenv("TPU_OFFLOAD_SKIP_JAX_PRECOMPILE", "0"))), # kv offload to dram: size of staging buffer (hbm) for swap - "TPU_OFFLOAD_STAGING_BUFFER_TOKENS": - lambda: int(os.getenv("TPU_OFFLOAD_STAGING_BUFFER_TOKENS", "16384")), + "TPU_OFFLOAD_NUM_STAGING_BLOCKS": + lambda: int(os.getenv("TPU_OFFLOAD_NUM_STAGING_BLOCKS", "128")), } diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py index 3a6e27bca..3d2739c30 100644 --- a/tpu_inference/worker/tpu_worker.py +++ b/tpu_inference/worker/tpu_worker.py @@ -294,9 +294,7 @@ def determine_available_memory(self) -> int: kv_transfer_config = self.vllm_config.kv_transfer_config if kv_transfer_config.kv_connector == "TPUOffloadConnector" and kv_transfer_config.kv_connector_module_path == "tpu_inference.distributed.offload.tpu_offload_connector": # If kv offloading is enabled, we need to account for the memory used by the KV transfer buffer. 
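# --------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the relationship between the
# old token-based knob and the new block-based knob introduced here. The
# block_size value is an assumption chosen only for the example.
def tokens_to_staging_blocks(staging_buffer_tokens: int, block_size: int) -> int:
    return staging_buffer_tokens // block_size

# e.g. the earlier 8192-token default corresponds to 128 staging blocks when
# block_size == 64, which lines up with the new TPU_OFFLOAD_NUM_STAGING_BLOCKS
# default of 128.
assert tokens_to_staging_blocks(8192, 64) == 128
# --------------------------------------------------------------------------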
- staging_buffer_tokens = envs.TPU_OFFLOAD_STAGING_BUFFER_TOKENS - # calculate staging buffer size - staging_buffer_pages = staging_buffer_tokens // self.vllm_config.cache_config.block_size + staging_buffer_pages = envs.TPU_OFFLOAD_NUM_STAGING_BLOCKS kv_cache_specs = self.model_runner.get_kv_cache_spec() num_layers = len(kv_cache_specs) From aca95f14ef8a9a1c96be7b30dad1994a6fc84a52 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Mon, 24 Nov 2025 23:33:39 +0000 Subject: [PATCH 07/29] updte gke yaml Signed-off-by: Juncheng Gu --- examples/gke/benchmarks/deploy-cpu-offload.yaml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/gke/benchmarks/deploy-cpu-offload.yaml b/examples/gke/benchmarks/deploy-cpu-offload.yaml index 5bcd573b2..8bddddbe1 100644 --- a/examples/gke/benchmarks/deploy-cpu-offload.yaml +++ b/examples/gke/benchmarks/deploy-cpu-offload.yaml @@ -30,8 +30,10 @@ spec: key: token - name: SKIP_JAX_PRECOMPILE value: "1" - - name: TPU_OFFLOAD_CPU_CACHE_SIZE_GB - value: "1024" + - name: TPU_OFFLOAD_NUM_CPU_CHUNKS + value: "4096" + - name: TPU_OFFLOAD_NUM_STAGING_BLOCKS + value: "256" ports: - containerPort: 8000 resources: From ace918a642693c107f08af9a4a04392123199223 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Tue, 25 Nov 2025 00:05:20 +0000 Subject: [PATCH 08/29] tweaks Signed-off-by: Juncheng Gu --- tpu_inference/distributed/offload/tpu_offload_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tpu_inference/distributed/offload/tpu_offload_connector.py b/tpu_inference/distributed/offload/tpu_offload_connector.py index 7b23f3c8c..ce506f584 100644 --- a/tpu_inference/distributed/offload/tpu_offload_connector.py +++ b/tpu_inference/distributed/offload/tpu_offload_connector.py @@ -246,7 +246,7 @@ def update(self, new_block_ids: list[int], new_token_ids: list[int]): self.block_ids.extend(new_block_ids) self.token_ids.extend(new_token_ids) - # NOTE(jcgu): is it always true? will MTP affect this judegment? + # NOTE(jcgu): is it always true? will MTP affect this judgement? # When a request is scheduled again, and the number of new tokens # is 1 (excluding chunked prefill), the request is in decode phase. if len(new_token_ids) == 1: @@ -711,7 +711,7 @@ def _prepare_req_meta( has_new_tokens = adjusted_num_total_tokens > tracker.save_watermark should_save = False # Determine if a save is needed for this step - # when there are new token KVs (adjusted by saving behavior): + # when there are new token KVs: # 1. Prefill: always save # 2. 
Decode (with save_decode=True) # 2.1 regular decode (not finished): accumulate until getting a full block From 894c747df2b131590863ea07bbf140efca4af7d8 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Tue, 25 Nov 2025 19:26:55 +0000 Subject: [PATCH 09/29] fix imports in kv_cache tests Signed-off-by: Juncheng Gu --- examples/{ => offload}/gke/benchmarks/README.md | 0 examples/{ => offload}/gke/benchmarks/benchmark-pod.yaml | 0 examples/{ => offload}/gke/benchmarks/deploy-baseline.yaml | 0 examples/{ => offload}/gke/benchmarks/deploy-cpu-offload.yaml | 0 examples/{ => offload}/gke/benchmarks/service.yaml | 0 examples/{ => offload}/gke/hf_secret.yaml | 0 examples/{ => offload}/gke/pod_tpu_commons_cpu_offload.yaml | 2 +- .../gke/pod_tpu_commons_cpu_offload_verification.yaml | 2 +- .../{ => offload}/gke/pod_tpu_host_offload_unit_tests.yaml | 0 examples/{ => offload}/offline_inference_kv_cache.py | 2 +- .../{ => offload}/offline_inference_kv_cache_verification.py | 4 +++- 11 files changed, 6 insertions(+), 4 deletions(-) rename examples/{ => offload}/gke/benchmarks/README.md (100%) rename examples/{ => offload}/gke/benchmarks/benchmark-pod.yaml (100%) rename examples/{ => offload}/gke/benchmarks/deploy-baseline.yaml (100%) rename examples/{ => offload}/gke/benchmarks/deploy-cpu-offload.yaml (100%) rename examples/{ => offload}/gke/benchmarks/service.yaml (100%) rename examples/{ => offload}/gke/hf_secret.yaml (100%) rename examples/{ => offload}/gke/pod_tpu_commons_cpu_offload.yaml (92%) rename examples/{ => offload}/gke/pod_tpu_commons_cpu_offload_verification.yaml (93%) rename examples/{ => offload}/gke/pod_tpu_host_offload_unit_tests.yaml (100%) rename examples/{ => offload}/offline_inference_kv_cache.py (98%) rename examples/{ => offload}/offline_inference_kv_cache_verification.py (97%) diff --git a/examples/gke/benchmarks/README.md b/examples/offload/gke/benchmarks/README.md similarity index 100% rename from examples/gke/benchmarks/README.md rename to examples/offload/gke/benchmarks/README.md diff --git a/examples/gke/benchmarks/benchmark-pod.yaml b/examples/offload/gke/benchmarks/benchmark-pod.yaml similarity index 100% rename from examples/gke/benchmarks/benchmark-pod.yaml rename to examples/offload/gke/benchmarks/benchmark-pod.yaml diff --git a/examples/gke/benchmarks/deploy-baseline.yaml b/examples/offload/gke/benchmarks/deploy-baseline.yaml similarity index 100% rename from examples/gke/benchmarks/deploy-baseline.yaml rename to examples/offload/gke/benchmarks/deploy-baseline.yaml diff --git a/examples/gke/benchmarks/deploy-cpu-offload.yaml b/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml similarity index 100% rename from examples/gke/benchmarks/deploy-cpu-offload.yaml rename to examples/offload/gke/benchmarks/deploy-cpu-offload.yaml diff --git a/examples/gke/benchmarks/service.yaml b/examples/offload/gke/benchmarks/service.yaml similarity index 100% rename from examples/gke/benchmarks/service.yaml rename to examples/offload/gke/benchmarks/service.yaml diff --git a/examples/gke/hf_secret.yaml b/examples/offload/gke/hf_secret.yaml similarity index 100% rename from examples/gke/hf_secret.yaml rename to examples/offload/gke/hf_secret.yaml diff --git a/examples/gke/pod_tpu_commons_cpu_offload.yaml b/examples/offload/gke/pod_tpu_commons_cpu_offload.yaml similarity index 92% rename from examples/gke/pod_tpu_commons_cpu_offload.yaml rename to examples/offload/gke/pod_tpu_commons_cpu_offload.yaml index 49bb437dc..368e44da3 100644 --- a/examples/gke/pod_tpu_commons_cpu_offload.yaml +++ 
b/examples/offload/gke/pod_tpu_commons_cpu_offload.yaml @@ -13,7 +13,7 @@ spec: imagePullPolicy: Always # Uncomment to always pull the latest image for any dev work command: - python - - /workspace/tpu_inference/examples/offline_inference_kv_cache.py + - /workspace/tpu_inference/examples/offload/offline_inference_kv_cache.py - --model=meta-llama/Llama-3.1-8B - --tensor_parallel_size=8 - --max_model_len=1024 diff --git a/examples/gke/pod_tpu_commons_cpu_offload_verification.yaml b/examples/offload/gke/pod_tpu_commons_cpu_offload_verification.yaml similarity index 93% rename from examples/gke/pod_tpu_commons_cpu_offload_verification.yaml rename to examples/offload/gke/pod_tpu_commons_cpu_offload_verification.yaml index f9e7c7c41..2ebdc67ee 100644 --- a/examples/gke/pod_tpu_commons_cpu_offload_verification.yaml +++ b/examples/offload/gke/pod_tpu_commons_cpu_offload_verification.yaml @@ -19,7 +19,7 @@ spec: imagePullPolicy: Always command: - python - - /workspace/tpu_inference/examples/offline_inference_kv_cache_verification.py + - /workspace/tpu_inference/examples/offload/offline_inference_kv_cache_verification.py - --model=meta-llama/Llama-3.1-8B - --tensor_parallel_size=8 - --max_model_len=1024 diff --git a/examples/gke/pod_tpu_host_offload_unit_tests.yaml b/examples/offload/gke/pod_tpu_host_offload_unit_tests.yaml similarity index 100% rename from examples/gke/pod_tpu_host_offload_unit_tests.yaml rename to examples/offload/gke/pod_tpu_host_offload_unit_tests.yaml diff --git a/examples/offline_inference_kv_cache.py b/examples/offload/offline_inference_kv_cache.py similarity index 98% rename from examples/offline_inference_kv_cache.py rename to examples/offload/offline_inference_kv_cache.py index 6df636564..ffbe00f22 100644 --- a/examples/offline_inference_kv_cache.py +++ b/examples/offload/offline_inference_kv_cache.py @@ -5,7 +5,7 @@ import vllm.envs as envs from vllm import LLM, EngineArgs -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser def create_parser(): diff --git a/examples/offline_inference_kv_cache_verification.py b/examples/offload/offline_inference_kv_cache_verification.py similarity index 97% rename from examples/offline_inference_kv_cache_verification.py rename to examples/offload/offline_inference_kv_cache_verification.py index b93dce149..be51edf89 100644 --- a/examples/offline_inference_kv_cache_verification.py +++ b/examples/offload/offline_inference_kv_cache_verification.py @@ -28,7 +28,7 @@ import vllm.envs as envs from vllm import LLM, EngineArgs, SamplingParams -from vllm.utils import FlexibleArgumentParser +from vllm.utils.argparse_utils import FlexibleArgumentParser def create_parser(): @@ -89,6 +89,8 @@ def run_invocations(llm: LLM, sampling_params: SamplingParams, print(f"--- Invocation {i + 1}/{num_invocations} ---") outputs = llm.generate(prompts, sampling_params) all_outputs.append(outputs[0].outputs[0].text) + # reset prefix cache + llm.llm_engine.engine_core.reset_prefix_cache() time.sleep(5) if envs.VLLM_TORCH_PROFILER_DIR is not None: From a24e4bb255e5f675d2cd646031c2ed4cdc9f3e58 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Tue, 25 Nov 2025 19:40:44 +0000 Subject: [PATCH 10/29] tweaks Signed-off-by: Juncheng Gu --- examples/multi_modal_inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/multi_modal_inference.py b/examples/multi_modal_inference.py index d1f9101c4..7b331ea10 100644 --- a/examples/multi_modal_inference.py +++ b/examples/multi_modal_inference.py @@ -1,4 +1,5 @@ # 
SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ This example shows how to use vLLM for running offline inference with the correct prompt format on vision language models for text generation. From 616ac130d3e54099d1e295651d8ea219317993bd Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Tue, 25 Nov 2025 19:49:55 +0000 Subject: [PATCH 11/29] tweaks Signed-off-by: Juncheng Gu --- tpu_inference/platforms/tpu_platform.py | 1 + tpu_inference/worker/tpu_worker.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py index 544b77de6..8502bcd8e 100644 --- a/tpu_inference/platforms/tpu_platform.py +++ b/tpu_inference/platforms/tpu_platform.py @@ -217,6 +217,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "Forcing --disable_chunked_mm_input.") scheduler_config.disable_chunked_mm_input = True + # Late initialization to avoid circular import from tpu_inference.models.jax.utils.quantization.quantization_utils import \ update_vllm_config_for_qwix_quantization diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py index 3d2739c30..8b5e33f98 100644 --- a/tpu_inference/worker/tpu_worker.py +++ b/tpu_inference/worker/tpu_worker.py @@ -299,8 +299,6 @@ def determine_available_memory(self) -> int: kv_cache_specs = self.model_runner.get_kv_cache_spec() num_layers = len(kv_cache_specs) vllm_page_size_bytes = get_uniform_page_size(kv_cache_specs) - # rpa_page_size_bytes = get_rpa_page_size_bytes(self.model_runner.mesh, - # kv_cache_specs) stage_buffer_size_bytes = staging_buffer_pages * num_layers * vllm_page_size_bytes total_hbm_avail = total_hbm_avail - stage_buffer_size_bytes From 6f8ae20ca5869d3902a236ea3e6629aad54185b0 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Wed, 26 Nov 2025 05:46:41 +0000 Subject: [PATCH 12/29] multi-request worker test Signed-off-by: Juncheng Gu --- .../tpu_offload_connector_worker_test.py | 273 ++++++++++++------ 1 file changed, 180 insertions(+), 93 deletions(-) diff --git a/tests/distributed/offload/tpu_offload_connector_worker_test.py b/tests/distributed/offload/tpu_offload_connector_worker_test.py index 6642aaa22..4773464ae 100644 --- a/tests/distributed/offload/tpu_offload_connector_worker_test.py +++ b/tests/distributed/offload/tpu_offload_connector_worker_test.py @@ -25,7 +25,7 @@ logger = init_logger(__name__) -_DEFAULT_BLOCK_SIZE = 64 +_DEFAULT_BLOCK_SIZE = 256 class MockTPUModelRunner(TPUModelRunner): @@ -62,14 +62,14 @@ class KVTransferConfig: port = 1234 -class TestCpuOffloadingSave(jtu.JaxTestCase): +class TestTPUOffloadConnectorWorker(jtu.JaxTestCase): """Test the save functionality of the TPUOffloadConnectorWorker.""" def setUp(self): super().setUp() self.vllm_config = MockVllmConfig(block_size=_DEFAULT_BLOCK_SIZE) - self.num_layers = 2 - self.num_blocks = 24 + self.num_layers = 80 + self.num_blocks = 128 self.num_cpu_chunks = 24 self.block_size = self.vllm_config.cache_config.block_size self.num_heads = 8 @@ -204,30 +204,52 @@ def test_precompile_run_success(self, swap_op_type: str): ) @parameterized.named_parameters( + dict( + testcase_name="_regular_single_block_save", + num_blocks_to_save=1, + num_requests=1, + ), + dict( + testcase_name="_regular_multi_requests_single_block_save", + num_blocks_to_save=2, + num_requests=4, + ), dict( testcase_name="_regular_multi_block_save", num_blocks_to_save=5, + num_requests=1, ), dict( 
testcase_name="_regular_multi_block_save_with_compile_jax", num_blocks_to_save=5, + num_requests=1, + use_precompiled_swap_ops=True, + ), + dict( + testcase_name= + "_regular_multi_request_single_block_save_with_compile_jax", + num_blocks_to_save=1, + num_requests=6, use_precompiled_swap_ops=True, ), dict( testcase_name="_regular_multi_block_save_with_compile_pallas", num_blocks_to_save=5, + num_requests=1, use_precompiled_swap_ops=True, swap_op_type="pallas", ), dict( testcase_name="_final_save", num_blocks_to_save=1, + num_requests=1, is_final_save=True, skip_save=False, ), dict( testcase_name="_final_skip_save", num_blocks_to_save=0, + num_requests=1, is_final_save=True, skip_save=True, ), @@ -235,52 +257,64 @@ def test_precompile_run_success(self, swap_op_type: str): def test_tpu_connector_save( self, num_blocks_to_save: int, + num_requests: int = 1, is_final_save: bool = False, skip_save: bool = False, use_precompiled_swap_ops: bool = False, swap_op_type: str = "jax", ): - if num_blocks_to_save > self.num_blocks or num_blocks_to_save > self.num_cpu_chunks: + total_num_blocks_to_save = num_blocks_to_save * num_requests + if total_num_blocks_to_save > self.num_blocks or total_num_blocks_to_save > self.num_cpu_chunks: self.skipTest( - f"num_blocks_to_save {num_blocks_to_save} exceeds ModelRunner / OffloadConnectorWorker's capacity" + f"num_blocks_to_save {total_num_blocks_to_save} exceeds ModelRunner / OffloadConnectorWorker's capacity" ) # Prepare and Execute Save all_block_ids = list(range(self.num_blocks)) all_chunk_ids = list(range(self.num_cpu_chunks)) - src_block_ids = random.sample(all_block_ids, num_blocks_to_save) - dst_chunk_ids = random.sample(all_chunk_ids, num_blocks_to_save) - num_tokens_to_save = num_blocks_to_save * self.block_size - num_total_tokens = num_tokens_to_save - save_spec = SaveSpec( - num_skip_leading_tokens=0, - num_total_tokens=num_total_tokens, - is_final_save=is_final_save, - skip_save=skip_save, - src_blocks=src_block_ids, - dst_chunks=dst_chunk_ids, - ) + src_block_ids = random.sample(all_block_ids, total_num_blocks_to_save) + dst_chunk_ids = random.sample(all_chunk_ids, total_num_blocks_to_save) + + src_block_ids_split = np.array_split(src_block_ids, num_requests) + dst_chunk_ids_split = np.array_split(dst_chunk_ids, num_requests) + + requests_meta = [] + for i in range(num_requests): + req_id = f"save_req_{i}" + src_blocks = src_block_ids_split[i].tolist() + dst_chunks = dst_chunk_ids_split[i].tolist() + + num_tokens_to_save_per_req = len(src_blocks) * self.block_size + + save_spec = SaveSpec( + num_skip_leading_tokens=0, + num_total_tokens=num_tokens_to_save_per_req, + is_final_save=is_final_save, + skip_save=skip_save, + src_blocks=src_blocks, + dst_chunks=dst_chunks, + ) + + total_token_ids = list(range(num_tokens_to_save_per_req)) + + req_meta = TPUReqMeta( + req_id=req_id, + token_ids=total_token_ids, + local_block_ids=src_blocks, + save_spec=save_spec, + ) + requests_meta.append(req_meta) logger.info(f"Starting test_tpu_connector_save with: " f"num_blocks_to_save={num_blocks_to_save}, " + f"num_requests={num_requests}, " f"is_final_save={is_final_save}, " f"skip_save={skip_save}, " f"use_precompiled_swap_ops={use_precompiled_swap_ops}, " - f"swap_op_type={swap_op_type};" - f"Swapspec: {save_spec}") - - total_token_ids = list(range(num_total_tokens)) - - req_id = "save_req" - req_meta = TPUReqMeta( - req_id=req_id, - token_ids=total_token_ids, - local_block_ids=src_block_ids, - save_spec=save_spec, - ) + f"swap_op_type={swap_op_type};") 
connector_metadata = TPUOffloadConnectorMetadata( - requests_meta=[req_meta]) + requests_meta=requests_meta) connector = self._create_connector(swap_op_type, use_precompiled_swap_ops) @@ -296,7 +330,7 @@ def test_tpu_connector_save( cpu_backend = worker.cpu_backend kv_caches = worker.runner.kv_caches - if skip_save or num_tokens_to_save == 0: + if skip_save or total_num_blocks_to_save == 0: logger.info(" no blocks to save") assert cpu_backend.num_saved_cpu_chunks == 0 self.assertEmpty(worker.finished_save_reqs) @@ -304,16 +338,27 @@ def test_tpu_connector_save( return # verify the saved chunks - assert req_id in worker.offload_stats.data["finished_save_chunks"] - assert dst_chunk_ids == worker.offload_stats.data[ - "finished_save_chunks"][req_id] - - for tpu_block_id, cpu_chunk_id in zip(src_block_ids, dst_chunk_ids): - cpu_kv_chunk = cpu_backend.get(cpu_chunk_id) - for layer_idx in range(self.num_layers): - tpu_kv_block = kv_caches[layer_idx][tpu_block_id] - self.assertArraysEqual(np.array(tpu_kv_block), - np.array(cpu_kv_chunk[layer_idx])) + all_req_ids = {f"save_req_{i}" for i in range(num_requests)} + self.assertSetEqual( + all_req_ids, + set(worker.offload_stats.data["finished_save_chunks"].keys())) + + for i in range(num_requests): + req_id = f"save_req_{i}" + src_blocks = src_block_ids_split[i].tolist() + dst_chunks = dst_chunk_ids_split[i].tolist() + self.assertListEqual( + dst_chunks, + worker.offload_stats.data["finished_save_chunks"][req_id]) + + for tpu_block_id, cpu_chunk_id in zip(src_blocks, dst_chunks): + cpu_kv_chunk = cpu_backend.get(cpu_chunk_id) + for layer_idx in range(self.num_layers): + tpu_kv_block = kv_caches[layer_idx][tpu_block_id] + assert cpu_kv_chunk[ + layer_idx].sharding.memory_kind == 'pinned_host' + self.assertArraysEqual(np.array(tpu_kv_block), + np.array(cpu_kv_chunk[layer_idx])) logger.info("Saved data verification completed.") @@ -321,22 +366,30 @@ def test_tpu_connector_save( finished_saves, _ = worker.get_finished() logger.info( f"is_final_save is True. Finished requests: {finished_saves}") - self.assertIn(req_id, finished_saves) + self.assertSetEqual(all_req_ids, finished_saves) @parameterized.named_parameters( dict( testcase_name="_single_block_", num_blocks_to_operate=1, + num_requests=1, + ), + dict( + testcase_name="_multi_requests_", + num_blocks_to_operate=2, + num_requests=4, ), dict( testcase_name="_multi_blocks_compile_jax", num_blocks_to_operate=5, + num_requests=1, use_precompiled_swap_ops=True, swap_op_type="jax", ), dict( testcase_name="_multi_blocks_compile_pallas", num_blocks_to_operate=5, + num_requests=1, use_precompiled_swap_ops=True, swap_op_type="pallas", ), @@ -344,6 +397,7 @@ def test_tpu_connector_save( def test_tpu_connector_load( self, num_blocks_to_operate: int, + num_requests: int = 1, use_precompiled_swap_ops: bool = False, swap_op_type: str = "jax", ): @@ -358,9 +412,10 @@ def test_tpu_connector_load( 3. Load the data 4. Verification """ - if num_blocks_to_operate > self.num_blocks or num_blocks_to_operate > self.num_cpu_chunks: + total_num_blocks_to_operate = num_blocks_to_operate * num_requests + if total_num_blocks_to_operate > self.num_blocks or total_num_blocks_to_operate > self.num_cpu_chunks: self.skipTest( - f"num_blocks_to_save {num_blocks_to_operate} exceeds ModelRunner / OffloadConnectorWorker's capacity" + f"num_blocks_to_save {total_num_blocks_to_operate} exceeds ModelRunner / OffloadConnectorWorker's capacity" ) # 1. 
Setup connector = self._create_connector(swap_op_type, @@ -376,31 +431,43 @@ def test_tpu_connector_load( ] jax.block_until_ready(dst_kv_cache) - # Prepare + # 2. Simulate a save operation all_block_ids = list(range(self.num_blocks)) all_chunk_ids = list(range(self.num_cpu_chunks)) - src_block_ids = random.sample(all_block_ids, num_blocks_to_operate) - dst_chunk_ids = random.sample(all_chunk_ids, num_blocks_to_operate) - num_tokens_to_save = num_blocks_to_operate * self.block_size - num_total_tokens = num_tokens_to_save - save_spec = SaveSpec( - num_skip_leading_tokens=0, - num_total_tokens=num_tokens_to_save, - is_final_save=False, - skip_save=False, - src_blocks=src_block_ids, - dst_chunks=dst_chunk_ids, - ) - total_token_ids = list(range(num_total_tokens)) - req_id = "save_req" - req_meta = TPUReqMeta( - req_id=req_id, - token_ids=total_token_ids, - local_block_ids=src_block_ids, - save_spec=save_spec, - ) + src_block_ids = random.sample(all_block_ids, + total_num_blocks_to_operate) + dst_chunk_ids = random.sample(all_chunk_ids, + total_num_blocks_to_operate) + + src_block_ids_split = np.array_split(src_block_ids, num_requests) + dst_chunk_ids_split = np.array_split(dst_chunk_ids, num_requests) + + save_requests_meta = [] + for i in range(num_requests): + req_id = f"save_req_{i}" + src_blocks = src_block_ids_split[i].tolist() + dst_chunks = dst_chunk_ids_split[i].tolist() + num_tokens_to_save_per_req = len(src_blocks) * self.block_size + + save_spec = SaveSpec( + num_skip_leading_tokens=0, + num_total_tokens=num_tokens_to_save_per_req, + is_final_save=False, + skip_save=False, + src_blocks=src_blocks, + dst_chunks=dst_chunks, + ) + total_token_ids = list(range(num_tokens_to_save_per_req)) + req_meta = TPUReqMeta( + req_id=req_id, + token_ids=total_token_ids, + local_block_ids=src_blocks, + save_spec=save_spec, + ) + save_requests_meta.append(req_meta) + connector_metadata = TPUOffloadConnectorMetadata( - requests_meta=[req_meta]) + requests_meta=save_requests_meta) connector.bind_connector_metadata(connector_metadata) logger.info( "Connector metadata bound, calling worker.wait_for_save().") @@ -408,39 +475,59 @@ def test_tpu_connector_load( logger.info("worker.wait_for_save() completed.") # 3. 
Prepare and Execute Delta Load - new_req_id = "load_req" worker.runner.kv_caches = dst_kv_cache - load_spec = LoadSpec( - num_matched_tokens=num_tokens_to_save, - dst_blocks=src_block_ids, - src_chunks=dst_chunk_ids, - can_load=True, - num_skip_leading_tokens=0, - ) - req_meta = TPUReqMeta( - req_id="load_req", - token_ids=total_token_ids, - local_block_ids=src_block_ids, - load_spec=load_spec, - ) + + load_requests_meta = [] + for i in range(num_requests): + req_id = f"load_req_{i}" + src_blocks = src_block_ids_split[i].tolist() + dst_chunks = dst_chunk_ids_split[i].tolist() + num_tokens_to_load_per_req = len(src_blocks) * self.block_size + + load_spec = LoadSpec( + num_matched_tokens=num_tokens_to_load_per_req, + dst_blocks=src_blocks, + src_chunks=dst_chunks, + can_load=True, + num_skip_leading_tokens=0, + ) + total_token_ids = list(range(num_tokens_to_load_per_req)) + req_meta = TPUReqMeta( + req_id=req_id, + token_ids=total_token_ids, + local_block_ids=src_blocks, + load_spec=load_spec, + ) + load_requests_meta.append(req_meta) + connector_metadata = TPUOffloadConnectorMetadata( - requests_meta=[req_meta]) + requests_meta=load_requests_meta) connector.bind_connector_metadata(connector_metadata) logger.info("Connector metadata bound, calling start_load_kv.") worker.start_load_kv(fwd_ctx=None) jax.block_until_ready(worker.runner.kv_caches) logger.info("start_load_kv completed and blocked until ready.") + # 4. Verification # verify the data - # we will donate the original kv_cache ref dst_kv_cache = worker.runner.kv_caches - for src_block_id in src_block_ids: - for layer_idx in range(self.num_layers): - self.assertArraysEqual( - np.array(src_kv_cache[layer_idx][src_block_id]), - np.array(dst_kv_cache[layer_idx][src_block_id])) - - # verify the saved chunks - assert new_req_id in worker.offload_stats.data["finished_load_chunks"] - assert dst_chunk_ids == worker.offload_stats.data[ - "finished_load_chunks"][new_req_id] + for i in range(num_requests): + src_blocks = src_block_ids_split[i].tolist() + for src_block_id in src_blocks: + for layer_idx in range(self.num_layers): + self.assertArraysEqual( + np.array(src_kv_cache[layer_idx][src_block_id]), + np.array(dst_kv_cache[layer_idx][src_block_id])) + + # verify the loaded chunks + all_load_req_ids = {f"load_req_{i}" for i in range(num_requests)} + self.assertSetEqual( + all_load_req_ids, + set(worker.offload_stats.data["finished_load_chunks"].keys())) + + for i in range(num_requests): + req_id = f"load_req_{i}" + dst_chunks = dst_chunk_ids_split[i].tolist() + self.assertListEqual( + dst_chunks, + worker.offload_stats.data["finished_load_chunks"][req_id]) From ff4d31fe6740b4ec702f4f684550d9d1b5b13706 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Wed, 26 Nov 2025 17:43:39 +0000 Subject: [PATCH 13/29] debug: add jax block Signed-off-by: Juncheng Gu --- tpu_inference/distributed/offload/tpu_offload_connector.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tpu_inference/distributed/offload/tpu_offload_connector.py b/tpu_inference/distributed/offload/tpu_offload_connector.py index ce506f584..d30bfbdcf 100644 --- a/tpu_inference/distributed/offload/tpu_offload_connector.py +++ b/tpu_inference/distributed/offload/tpu_offload_connector.py @@ -1375,6 +1375,7 @@ def _bucketed_swap_out_fn( # Fast path: handle bucket-sized transfers if num_blocks in BLOCK_SIZE_BUCKETS: flat_kv_caches_cpu = self.swap_out_fn(flat_kv_caches_tpu) + jax.block_until_ready(flat_kv_caches_cpu) split_size_list = [self.block_size] * num_blocks return [ 
jax.lax.split(flat_layer_cache, split_size_list, axis=0) @@ -1405,6 +1406,7 @@ def _bucketed_swap_out_fn( # Swap the bucket to CPU, result is a flat tensor for this bucket. We are doing the chunking inside this function to avoid returning any jnp.concatenate # of kv cache for the the bucketed blocks cpu_chunk_flat_per_layer = self.swap_out_fn(tpu_chunk) + jax.block_until_ready(cpu_chunk_flat_per_layer) # Split the flat bucket tensor into block-sized chunks and append split_size_list = [self.block_size] * decomposed_block_size for i, layer_cache in enumerate(cpu_chunk_flat_per_layer): From 43f8f1edc4eebbf46998229c3f850ec48f2ec64e Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Mon, 1 Dec 2025 21:55:23 +0000 Subject: [PATCH 14/29] worker_test: multi requests; acc_test: precompile Signed-off-by: Juncheng Gu --- .../offload/tpu_offload_accuracy_test.py | 24 +++++--- .../tpu_offload_connector_worker_test.py | 61 ++++++++++++++----- .../offload/tpu_offload_connector.py | 56 +++++++---------- 3 files changed, 83 insertions(+), 58 deletions(-) diff --git a/tests/distributed/offload/tpu_offload_accuracy_test.py b/tests/distributed/offload/tpu_offload_accuracy_test.py index 0059c4bd9..a5f538f8f 100644 --- a/tests/distributed/offload/tpu_offload_accuracy_test.py +++ b/tests/distributed/offload/tpu_offload_accuracy_test.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import itertools import os import time @@ -49,12 +50,13 @@ def _test_kv_cache_cpu_offloading_accuracy( sampling_config: SamplingParams, kv_transfer_config: KVTransferConfig, swap_op_type: str, + skip_precompile: str, decode_save: str, ): with monkeypatch.context(): os.environ['SKIP_JAX_PRECOMPILE'] = '1' - os.environ['TPU_OFFLOAD_SKIP_JAX_PRECOMPILE'] = '1' os.environ['TPU_OFFLOAD_SWAP_OP_TYPE'] = swap_op_type + os.environ['TPU_OFFLOAD_SKIP_JAX_PRECOMPILE'] = skip_precompile os.environ['TPU_OFFLOAD_DECODE_SAVE'] = decode_save llm = LLM(model="meta-llama/Llama-3.2-3B", max_model_len=1024, @@ -98,12 +100,14 @@ def test_kv_cache_cpu_offloading_accuracy( ): swap_op_types = ["pallas", "jax"] decode_saves = ["0", "1"] - for swap_op_type in swap_op_types: - for decode_save in decode_saves: - _test_kv_cache_cpu_offloading_accuracy( - monkeypatch, - sampling_config, - kv_transfer_config, - swap_op_type, - decode_save, - ) + skip_precompile = ["0", "1"] + for swap_op_type, decode_save, _skip_precompile in itertools.product( + swap_op_types, decode_saves, skip_precompile): + _test_kv_cache_cpu_offloading_accuracy( + monkeypatch, + sampling_config, + kv_transfer_config, + swap_op_type, + _skip_precompile, + decode_save, + ) diff --git a/tests/distributed/offload/tpu_offload_connector_worker_test.py b/tests/distributed/offload/tpu_offload_connector_worker_test.py index 4773464ae..79fa63bfd 100644 --- a/tests/distributed/offload/tpu_offload_connector_worker_test.py +++ b/tests/distributed/offload/tpu_offload_connector_worker_test.py @@ -70,7 +70,7 @@ def setUp(self): self.vllm_config = MockVllmConfig(block_size=_DEFAULT_BLOCK_SIZE) self.num_layers = 80 self.num_blocks = 128 - self.num_cpu_chunks = 24 + self.num_cpu_chunks = 128 self.block_size = self.vllm_config.cache_config.block_size self.num_heads = 8 self.head_size = 128 @@ -205,40 +205,57 @@ def test_precompile_run_success(self, swap_op_type: str): @parameterized.named_parameters( dict( - testcase_name="_regular_single_block_save", + testcase_name="_single_block", num_blocks_to_save=1, num_requests=1, ), dict( - testcase_name="_regular_multi_requests_single_block_save", - 
num_blocks_to_save=2, - num_requests=4, + testcase_name="_multi_requests_single_block", + num_blocks_to_save=1, + num_requests=6, ), dict( - testcase_name="_regular_multi_block_save", + testcase_name="_multi_blocks", num_blocks_to_save=5, num_requests=1, ), dict( - testcase_name="_regular_multi_block_save_with_compile_jax", + testcase_name="_multi_requests_multi_blocks", + num_blocks_to_save=5, + num_requests=6, + ), + dict( + testcase_name="_multi_blocks_with_compile_jax", num_blocks_to_save=5, num_requests=1, use_precompiled_swap_ops=True, ), dict( - testcase_name= - "_regular_multi_request_single_block_save_with_compile_jax", + testcase_name="_multi_requests_single_block_with_compile_jax", num_blocks_to_save=1, num_requests=6, use_precompiled_swap_ops=True, ), dict( - testcase_name="_regular_multi_block_save_with_compile_pallas", + testcase_name="_multi_requests_multi_blocks_with_compile_jax", + num_blocks_to_save=5, + num_requests=6, + use_precompiled_swap_ops=True, + ), + dict( + testcase_name="_multi_blocks_with_compile_pallas", num_blocks_to_save=5, num_requests=1, use_precompiled_swap_ops=True, swap_op_type="pallas", ), + dict( + testcase_name="_multi_requests_multi_blocks_with_compile_pallas", + num_blocks_to_save=5, + num_requests=6, + use_precompiled_swap_ops=True, + swap_op_type="pallas", + ), dict( testcase_name="_final_save", num_blocks_to_save=1, @@ -370,13 +387,13 @@ def test_tpu_connector_save( @parameterized.named_parameters( dict( - testcase_name="_single_block_", + testcase_name="_single_block", num_blocks_to_operate=1, num_requests=1, ), dict( - testcase_name="_multi_requests_", - num_blocks_to_operate=2, + testcase_name="_multi_requests_single_block", + num_blocks_to_operate=1, num_requests=4, ), dict( @@ -387,9 +404,23 @@ def test_tpu_connector_save( swap_op_type="jax", ), dict( - testcase_name="_multi_blocks_compile_pallas", + testcase_name="_multi_requests_single_block_compile_jax", + num_blocks_to_operate=1, + num_requests=6, + use_precompiled_swap_ops=True, + swap_op_type="jax", + ), + dict( + testcase_name="_multi_requests_multi_blocks_compile_jax", num_blocks_to_operate=5, - num_requests=1, + num_requests=6, + use_precompiled_swap_ops=True, + swap_op_type="jax", + ), + dict( + testcase_name="_multi_requests_multi_blocks_compile_pallas", + num_blocks_to_operate=5, + num_requests=6, use_precompiled_swap_ops=True, swap_op_type="pallas", ), diff --git a/tpu_inference/distributed/offload/tpu_offload_connector.py b/tpu_inference/distributed/offload/tpu_offload_connector.py index d30bfbdcf..9c7f49a92 100644 --- a/tpu_inference/distributed/offload/tpu_offload_connector.py +++ b/tpu_inference/distributed/offload/tpu_offload_connector.py @@ -118,7 +118,7 @@ LRUCacheManager, StagingBufferManager) from tpu_inference.distributed.offload.utils import ( CPU_OFFLOADING_SWAP_OP_TYPE, CpuChunkId, KVCacheSwapFn, ReqId, - TokenProcessor, get_kv_cache_swap_fn, jitted_insert_kv_cache_slices) + get_kv_cache_swap_fn, jitted_insert_kv_cache_slices) from tpu_inference.logger import init_logger from tpu_inference.runner.kv_cache_manager import KVCacheManager from tpu_inference.runner.tpu_runner import TPUModelRunner @@ -496,8 +496,6 @@ def __init__(self, vllm_config: "VllmConfig"): self._reqs_being_loaded = defaultdict[ReqId, set[CpuChunkId]](set) model_name = self.vllm_config.model_config.model - self.token_processor = TokenProcessor(model_name=model_name, - chunk_size=self.block_size) self.decode_save = envs.TPU_OFFLOAD_DECODE_SAVE # NOTE(jcgu): currently, let's make chunk_size == 
block_size @@ -528,7 +526,7 @@ def __init__(self, vllm_config: "VllmConfig"): def _get_request_block_hashes(self, req: "Request") -> list[BlockHash]: # request's original block_hashes do not include the last partial block - # TODO(jcgu): switch back to token_processor + # TODO(jcgu): add an option to use local token_processor return req.block_hashes def get_num_new_matched_tokens( @@ -1160,19 +1158,14 @@ def __init__(self, vllm_config: VllmConfig, self.runner: Optional[TPUModelRunner] = None self.mesh: Optional[Mesh] = None + self.swap_in_fn: KVCacheSwapFn = None + self.swap_out_fn: KVCacheSwapFn = None self.swap_op_type = envs.TPU_OFFLOAD_SWAP_OP_TYPE - assert self.swap_op_type in get_args(CPU_OFFLOADING_SWAP_OP_TYPE) # TODO(jcgu): check libtpu compatibility for pallas dma kernel - logger.info( - f"(cpu offloading) swap operation type is {self.swap_op_type}") - + assert self.swap_op_type in get_args(CPU_OFFLOADING_SWAP_OP_TYPE) self.use_bucketed_swap_ops = not envs.TPU_OFFLOAD_SKIP_JAX_PRECOMPILE - logger.info( - f"(cpu offloading) use_bucketed_swap_ops={self.use_bucketed_swap_ops}" - ) - - self.swap_in_fn: KVCacheSwapFn = None - self.swap_out_fn: KVCacheSwapFn = None + logger.info(f" swap operation type is {self.swap_op_type}, " + f"use_bucketed_swap_ops={self.use_bucketed_swap_ops}.") # cpu cache self.num_cpu_chunks = envs.TPU_OFFLOAD_NUM_CPU_CHUNKS @@ -1181,13 +1174,11 @@ def __init__(self, vllm_config: VllmConfig, model_name = self.vllm_config.model_config.model logger.info( f"Model name is {model_name}, KV block_size={self.block_size}") - self.token_processor = TokenProcessor(model_name=model_name, - chunk_size=self.block_size) self.cpu_chunk_size = self.block_size # Thread pool for asynchronous TPU->CPU copies - self.save_executor = ThreadPoolExecutor(max_workers=4, - thread_name_prefix="tpu_saver") + self.save_executor = ThreadPoolExecutor( + max_workers=4, thread_name_prefix="tpu_save_handler") self.finished_save_reqs: set[ReqId] = set() self.finished_load_reqs: set[ReqId] = set() # Tracks if wait_for_save has been called for the current step's metadata. @@ -1298,10 +1289,11 @@ def _precompile_kv_swap_operations(self): # 3. 
Pre-compile CPU -> TPU transfer (used in load) split_size_list = [self.block_size] * num_blocks - chunked_dummy_kv_cpu = [ - jax.lax.split(flat_layer_cache, split_size_list, axis=0) - for flat_layer_cache in dummy_kv_cpu - ] + chunked_dummy_kv_cpu = jax.tree.map( + lambda flat_layer_cache: jax.lax.split( + flat_layer_cache, split_size_list, axis=0), + dummy_kv_cpu) + chunked_dummy_kv_tpu = self.swap_in_fn(chunked_dummy_kv_cpu) jax.block_until_ready(chunked_dummy_kv_tpu) @@ -1374,13 +1366,13 @@ def _bucketed_swap_out_fn( # Fast path: handle bucket-sized transfers if num_blocks in BLOCK_SIZE_BUCKETS: + split_size_list = [self.block_size] * num_blocks flat_kv_caches_cpu = self.swap_out_fn(flat_kv_caches_tpu) jax.block_until_ready(flat_kv_caches_cpu) - split_size_list = [self.block_size] * num_blocks - return [ - jax.lax.split(flat_layer_cache, split_size_list, axis=0) - for flat_layer_cache in flat_kv_caches_cpu - ] + return jax.tree.map( + lambda flat_layer_cache: jax.lax.split( + flat_layer_cache, split_size_list, axis=0), + flat_kv_caches_cpu) # Bucket decomposition path decomposed_block_sizes = self._decompose_into_buckets(num_blocks) @@ -1580,12 +1572,10 @@ def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int], # NOTE(jcgu): we keep cpu_chunk_size == block_size split_size_list = [self.cpu_chunk_size ] * num_blocks_to_save - chunks_on_cpu = [ - jax.lax.split(flat_layer_cache, - split_size_list, - axis=0) - for flat_layer_cache in flat_kv_caches_cpu - ] + chunks_on_cpu = jax.tree.map( + lambda flat_layer_cache: jax.lax.split( + flat_layer_cache, split_size_list, axis=0), + flat_kv_caches_cpu) if chunks_on_cpu and chunks_on_cpu[0]: jax.block_until_ready(chunks_on_cpu) From 97153b9565f2c5145ed25aeb3202c5b324fa8dde Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Mon, 1 Dec 2025 22:16:57 +0000 Subject: [PATCH 15/29] add feature test Signed-off-by: Juncheng Gu --- .buildkite/features/KV_Cache_Offload.yml | 51 ++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 .buildkite/features/KV_Cache_Offload.yml diff --git a/.buildkite/features/KV_Cache_Offload.yml b/.buildkite/features/KV_Cache_Offload.yml new file mode 100644 index 000000000..b6a93090d --- /dev/null +++ b/.buildkite/features/KV_Cache_Offload.yml @@ -0,0 +1,51 @@ +# KV Cache Offload +# feature support matrix +steps: + - label: "Correctness tests for KV Cache Offload" + key: "KV_Cache_Offload_CorrectnessTest" + soft_fail: true + agents: + queue: tpu_v6e_queue + commands: + - | + .buildkite/scripts/run_in_docker.sh \ + python3 -m pytest -s -v /workspace/tpu_inference/tests/distributed/offload/tpu_offload_connector_scheduler_test.py \ + /workspace/tpu_inference/tests/distributed/offload/tpu_offload_connector_worker_test.py \ + /workspace/tpu_inference/tests/distributed/offload/tpu_offload_cpu_backend_test.py \ + /workspace/tpu_inference/tests/distributed/offload/tpu_offload_manager_test.py \ + /workspace/tpu_inference/tests/distributed/offload/tpu_offload_utils_test.py \ + /workspace/tpu_inference/tests/distributed/offload/tpu_offload_accuracy_test.py + - label: "Record correctness test result for KV Cache Offload" + key: "record_KV_Cache_Offload_CorrectnessTest" + depends_on: "KV_Cache_Offload_CorrectnessTest" + env: + CI_TARGET: "KV Cache Offload" + CI_STAGE: "CorrectnessTest" + CI_CATEGORY: "feature support matrix" + agents: + queue: cpu + commands: + - | + .buildkite/scripts/record_step_result.sh KV_Cache_Offload_CorrectnessTest + + - label: "Performance tests for KV Cache Offload" + key: 
"KV_Cache_Offload_PerformanceTest" + depends_on: "record_KV_Cache_Offload_CorrectnessTest" + soft_fail: true + agents: + queue: tpu_v6e_queue + commands: + - | + buildkite-agent meta-data set "KV_Cache_Offload_PerformanceTest" "to be added" + - label: "Record performance test result for KV Cache Offload" + key: "record_KV_Cache_Offload_PerformanceTest" + depends_on: "KV_Cache_Offload_PerformanceTest" + env: + CI_TARGET: "KV Cache Offload" + CI_STAGE: "PerformanceTest" + CI_CATEGORY: "feature support matrix" + agents: + queue: cpu + commands: + - | + .buildkite/scripts/record_step_result.sh KV_Cache_Offload_PerformanceTest From 560caf8d7af9dab7c8dca83b5780c8b6bfcb34d3 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Mon, 1 Dec 2025 22:57:17 +0000 Subject: [PATCH 16/29] follow up changes in the upstream; and update test scripts Signed-off-by: Juncheng Gu --- .buildkite/features/KV_Cache_Offload.yml | 7 +------ examples/offload/gke/pod_tpu_host_offload_unit_tests.yaml | 7 +------ .../offload/tpu_offload_connector_worker_test.py | 1 + tpu_inference/worker/tpu_worker.py | 3 ++- 4 files changed, 5 insertions(+), 13 deletions(-) diff --git a/.buildkite/features/KV_Cache_Offload.yml b/.buildkite/features/KV_Cache_Offload.yml index b6a93090d..e00dd6209 100644 --- a/.buildkite/features/KV_Cache_Offload.yml +++ b/.buildkite/features/KV_Cache_Offload.yml @@ -9,12 +9,7 @@ steps: commands: - | .buildkite/scripts/run_in_docker.sh \ - python3 -m pytest -s -v /workspace/tpu_inference/tests/distributed/offload/tpu_offload_connector_scheduler_test.py \ - /workspace/tpu_inference/tests/distributed/offload/tpu_offload_connector_worker_test.py \ - /workspace/tpu_inference/tests/distributed/offload/tpu_offload_cpu_backend_test.py \ - /workspace/tpu_inference/tests/distributed/offload/tpu_offload_manager_test.py \ - /workspace/tpu_inference/tests/distributed/offload/tpu_offload_utils_test.py \ - /workspace/tpu_inference/tests/distributed/offload/tpu_offload_accuracy_test.py + python3 -m pytest -s -v /workspace/tpu_inference/tests/distributed/offload/ - label: "Record correctness test result for KV Cache Offload" key: "record_KV_Cache_Offload_CorrectnessTest" depends_on: "KV_Cache_Offload_CorrectnessTest" diff --git a/examples/offload/gke/pod_tpu_host_offload_unit_tests.yaml b/examples/offload/gke/pod_tpu_host_offload_unit_tests.yaml index 69f14cabd..7fc2180f2 100644 --- a/examples/offload/gke/pod_tpu_host_offload_unit_tests.yaml +++ b/examples/offload/gke/pod_tpu_host_offload_unit_tests.yaml @@ -17,12 +17,7 @@ spec: command: - /bin/bash - -c - - "pytest -sv tests/distributed/offload/tpu_offload_cpu_backend_test.py" - - "pytest -sv tests/distributed/offload/tpu_offload_connector_worker_test.py" - - "pytest -sv tests/distributed/offload/tpu_offload_connector_scheduler_test.py" - - "pytest -sv tests/distributed/offload/tpu_offload_utils_test.py" - - "pytest -sv tests/distributed/offload/tpu_offload_manager_test.py" - - "pytest -sv tests/distributed/offload/tpu_offload_accuracy_test.py" + - "pytest -sv tests/distributed/offload/" env: - name: HUGGING_FACE_HUB_TOKEN valueFrom: diff --git a/tests/distributed/offload/tpu_offload_connector_worker_test.py b/tests/distributed/offload/tpu_offload_connector_worker_test.py index 79fa63bfd..235368953 100644 --- a/tests/distributed/offload/tpu_offload_connector_worker_test.py +++ b/tests/distributed/offload/tpu_offload_connector_worker_test.py @@ -94,6 +94,7 @@ def setUp(self): def tearDown(self): super().tearDown() cc.reset_cache() + jax.clear_caches() def create_mesh(self, 
axis_shapes, axis_names): """Creates a JAX device mesh with the default device order.""" diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py index 8b5e33f98..a41c92e7b 100644 --- a/tpu_inference/worker/tpu_worker.py +++ b/tpu_inference/worker/tpu_worker.py @@ -298,7 +298,8 @@ def determine_available_memory(self) -> int: kv_cache_specs = self.model_runner.get_kv_cache_spec() num_layers = len(kv_cache_specs) - vllm_page_size_bytes = get_uniform_page_size(kv_cache_specs) + vllm_page_size_bytes = get_uniform_page_size( + list(kv_cache_specs.values())) stage_buffer_size_bytes = staging_buffer_pages * num_layers * vllm_page_size_bytes total_hbm_avail = total_hbm_avail - stage_buffer_size_bytes From 12c4885ff35ee04ac08ecbe5f4c64048fc062b59 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Wed, 3 Dec 2025 23:34:48 +0000 Subject: [PATCH 17/29] update ci tests Signed-off-by: Juncheng Gu --- .../features/KV_Cache_Host_Offloading.yml | 45 ------------------- .buildkite/features/KV_Cache_Offload.yml | 2 +- .buildkite/pipeline_jax.yml | 1 + .../offload/tpu_offload_accuracy_test.py | 1 - .../tpu_offload_connector_worker_test.py | 5 ++- 5 files changed, 5 insertions(+), 49 deletions(-) delete mode 100644 .buildkite/features/KV_Cache_Host_Offloading.yml diff --git a/.buildkite/features/KV_Cache_Host_Offloading.yml b/.buildkite/features/KV_Cache_Host_Offloading.yml deleted file mode 100644 index 020392441..000000000 --- a/.buildkite/features/KV_Cache_Host_Offloading.yml +++ /dev/null @@ -1,45 +0,0 @@ -# KV cache host offloading -# feature support matrix -steps: - - label: "Correctness tests for KV cache host offloading" - key: "KV_Cache_Host_Offloading_CorrectnessTest" - soft_fail: true - agents: - queue: tpu_v6e_queue - commands: - - | - buildkite-agent meta-data set "KV_Cache_Host_Offloading_CorrectnessTest" "to be added" - - label: "Record correctness test result for KV cache host offloading" - key: "record_KV_Cache_Host_Offloading_CorrectnessTest" - depends_on: "KV_Cache_Host_Offloading_CorrectnessTest" - env: - CI_TARGET: "KV cache host offloading" - CI_STAGE: "CorrectnessTest" - CI_CATEGORY: "feature support matrix" - agents: - queue: cpu - commands: - - | - .buildkite/scripts/record_step_result.sh KV_Cache_Host_Offloading_CorrectnessTest - - - label: "Performance tests for KV cache host offloading" - key: "KV_Cache_Host_Offloading_PerformanceTest" - depends_on: "record_KV_Cache_Host_Offloading_CorrectnessTest" - soft_fail: true - agents: - queue: tpu_v6e_queue - commands: - - | - buildkite-agent meta-data set "KV_Cache_Host_Offloading_PerformanceTest" "to be added" - - label: "Record performance test result for KV cache host offloading" - key: "record_KV_Cache_Host_Offloading_PerformanceTest" - depends_on: "KV_Cache_Host_Offloading_PerformanceTest" - env: - CI_TARGET: "KV cache host offloading" - CI_STAGE: "PerformanceTest" - CI_CATEGORY: "feature support matrix" - agents: - queue: cpu - commands: - - | - .buildkite/scripts/record_step_result.sh KV_Cache_Host_Offloading_PerformanceTest diff --git a/.buildkite/features/KV_Cache_Offload.yml b/.buildkite/features/KV_Cache_Offload.yml index e00dd6209..2aa20049a 100644 --- a/.buildkite/features/KV_Cache_Offload.yml +++ b/.buildkite/features/KV_Cache_Offload.yml @@ -9,7 +9,7 @@ steps: commands: - | .buildkite/scripts/run_in_docker.sh \ - python3 -m pytest -s -v /workspace/tpu_inference/tests/distributed/offload/ + python3 -m pytest -s -v /workspace/tpu_inference/tests/distributed/offload/tpu_offload_accuracy_test.py - label: 
"Record correctness test result for KV Cache Offload" key: "record_KV_Cache_Offload_CorrectnessTest" depends_on: "KV_Cache_Offload_CorrectnessTest" diff --git a/.buildkite/pipeline_jax.yml b/.buildkite/pipeline_jax.yml index 19c232ae8..f6eb53be3 100644 --- a/.buildkite/pipeline_jax.yml +++ b/.buildkite/pipeline_jax.yml @@ -122,6 +122,7 @@ steps: --ignore=/workspace/tpu_inference/tests/e2e \ --ignore=/workspace/tpu_inference/tpu_inference/mock \ --ignore=/workspace/tpu_inference/tests/layers/vllm/test_compressed_tensors_moe.py \ + --ignore=/workspace/tpu_inference/tests/distributed/offload/test_offload_accuracy_test.py \ --cov-config=/workspace/tpu_inference/.coveragerc --cov tpu_inference --cov-report term-missing --cov-fail-under=69 - label: "JAX unit tests - kernels" diff --git a/tests/distributed/offload/tpu_offload_accuracy_test.py b/tests/distributed/offload/tpu_offload_accuracy_test.py index a5f538f8f..34a553add 100644 --- a/tests/distributed/offload/tpu_offload_accuracy_test.py +++ b/tests/distributed/offload/tpu_offload_accuracy_test.py @@ -60,7 +60,6 @@ def _test_kv_cache_cpu_offloading_accuracy( os.environ['TPU_OFFLOAD_DECODE_SAVE'] = decode_save llm = LLM(model="meta-llama/Llama-3.2-3B", max_model_len=1024, - tensor_parallel_size=8, task="generate", kv_transfer_config=kv_transfer_config) diff --git a/tests/distributed/offload/tpu_offload_connector_worker_test.py b/tests/distributed/offload/tpu_offload_connector_worker_test.py index 235368953..f020b864c 100644 --- a/tests/distributed/offload/tpu_offload_connector_worker_test.py +++ b/tests/distributed/offload/tpu_offload_connector_worker_test.py @@ -72,9 +72,10 @@ def setUp(self): self.num_blocks = 128 self.num_cpu_chunks = 128 self.block_size = self.vllm_config.cache_config.block_size - self.num_heads = 8 + num_devices = len(list(jax.devices())) + self.num_heads = num_devices self.head_size = 128 - self.mesh = self.create_mesh((1, 8), ("data", "model")) + self.mesh = self.create_mesh((1, num_devices), ("data", "model")) if self.mesh is None: self.skipTest("Cannot create mesh. 
Must be run on a TPU node.") return From 2901e56e5215ebf35676529a6420024da933f351 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Wed, 3 Dec 2025 23:58:51 +0000 Subject: [PATCH 18/29] update unit-test yml Signed-off-by: Juncheng Gu --- .buildkite/features/KV_Cache_Offload.yml | 5 ++++- .buildkite/pipeline_jax.yml | 21 +++++++++++++++++-- .../offload/tpu_offload_utils_test.py | 5 +++-- tests/kernels/host_dma_test.py | 6 +----- 4 files changed, 27 insertions(+), 10 deletions(-) diff --git a/.buildkite/features/KV_Cache_Offload.yml b/.buildkite/features/KV_Cache_Offload.yml index 2aa20049a..6c9f34bc5 100644 --- a/.buildkite/features/KV_Cache_Offload.yml +++ b/.buildkite/features/KV_Cache_Offload.yml @@ -4,8 +4,11 @@ steps: - label: "Correctness tests for KV Cache Offload" key: "KV_Cache_Offload_CorrectnessTest" soft_fail: true + env: + USE_V6E8_QUEUE: "True" + VLLM_LOG_LEVEL: "INFO" agents: - queue: tpu_v6e_queue + queue: tpu_v6e_8_queue commands: - | .buildkite/scripts/run_in_docker.sh \ diff --git a/.buildkite/pipeline_jax.yml b/.buildkite/pipeline_jax.yml index f6eb53be3..215e9b4ad 100644 --- a/.buildkite/pipeline_jax.yml +++ b/.buildkite/pipeline_jax.yml @@ -122,7 +122,7 @@ steps: --ignore=/workspace/tpu_inference/tests/e2e \ --ignore=/workspace/tpu_inference/tpu_inference/mock \ --ignore=/workspace/tpu_inference/tests/layers/vllm/test_compressed_tensors_moe.py \ - --ignore=/workspace/tpu_inference/tests/distributed/offload/test_offload_accuracy_test.py \ + --ignore=/workspace/tpu_inference/tests/distributed/offload \ --cov-config=/workspace/tpu_inference/.coveragerc --cov tpu_inference --cov-report term-missing --cov-fail-under=69 - label: "JAX unit tests - kernels" @@ -138,6 +138,7 @@ steps: --ignore=/workspace/tpu_inference/tests/kernels/ragged_paged_attention_kernel_v2_test.py \ --ignore=/workspace/tpu_inference/tests/kernels/ragged_kv_cache_update_v2_test.py \ --ignore=/workspace/tpu_inference/tests/kernels/collectives \ + --ignore=/workspace/tpu_inference/tests/kernels/host_dma_test.py \ --ignore=/workspace/tpu_inference/tests/kernels/fused_moe_v1_test.py else echo "Skipping: no changes detected in kernels, tests/kernels, or requirements.txt" @@ -256,6 +257,21 @@ steps: echo "Skipping: NIGHTLY environment variable not set" exit 0 fi + + - label: "kv cache offload tests on multi chips" + key: test_17 + soft_fail: true + env: + USE_V6E8_QUEUE: "True" + VLLM_LOG_LEVEL: "INFO" + agents: + queue: tpu_v6e_8_queue + commands: + - | + .buildkite/scripts/run_in_docker.sh \ + python3 -m pytest -s -v -x /workspace/tpu_inference/tests/distributed/offload/ \ + /workspace/tpu_inference/tests/kernels/host_dma_test.py \ + --ignore=/workspace/tpu_inference/tests/distributed/offload/tpu_offload_accuracy_test.py # ----------------------------------------------------------------- # NOTIFICATION STEP # ----------------------------------------------------------------- @@ -278,9 +294,10 @@ steps: - test_13 - test_15 - test_16 + - test_17 agents: queue: cpu commands: - | .buildkite/scripts/check_results.sh \ - "TPU JAX Tests Failed" test_0 test_1 test_2 test_3 test_4 test_5 test_6 test_7 test_8 test_9 test_10 test_11 test_12 test_13 test_15 test_16 + "TPU JAX Tests Failed" test_0 test_1 test_2 test_3 test_4 test_5 test_6 test_7 test_8 test_9 test_10 test_11 test_12 test_13 test_15 test_16 test_17 diff --git a/tests/distributed/offload/tpu_offload_utils_test.py b/tests/distributed/offload/tpu_offload_utils_test.py index 75af7a3bd..739ca0a79 100644 --- a/tests/distributed/offload/tpu_offload_utils_test.py 
+++ b/tests/distributed/offload/tpu_offload_utils_test.py @@ -17,7 +17,8 @@ def setUp(self): """Set up common parameters for the tests.""" self.num_layers = 2 self.num_tokens = 256 - self.num_kv_heads = 8 + num_devices = len(list(jax.devices())) + self.num_kv_heads = num_devices self.head_dim = 128 self.block_size = 16 self.num_blocks = self.num_tokens // self.block_size @@ -37,7 +38,7 @@ def setUp(self): self.cache_dtype = jnp.bfloat16 - self.mesh = self.create_mesh((1, 8), ("data", "model")) + self.mesh = self.create_mesh((1, num_devices), ("data", "model")) partition_spec = PartitionSpec(None, None, "model") self.device_sharding = NamedSharding(self.mesh, partition_spec, diff --git a/tests/kernels/host_dma_test.py b/tests/kernels/host_dma_test.py index 61dbf7386..626195343 100644 --- a/tests/kernels/host_dma_test.py +++ b/tests/kernels/host_dma_test.py @@ -6,7 +6,6 @@ import jax.numpy as jnp import numpy as np from absl.testing import absltest, parameterized -from jax._src import compilation_cache as cc from jax._src import test_util as jtu from jax.sharding import NamedSharding, PartitionSpec @@ -15,7 +14,6 @@ DATA_LOCATION = Literal["device", "host"] -# TODO(jcgu): add into CI tests @jtu.with_config(jax_numpy_dtype_promotion='strict') class HostHbmDmaTest(jtu.JaxTestCase): @@ -27,9 +25,7 @@ def setUp(self): def tearDown(self): super().tearDown() - # Reset the cache after each test. - # This can also be achieved by running with JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE=True - cc.reset_cache() + jax.clear_caches() def create_mesh(self, axis_shapes, axis_names): """Creates a JAX device mesh with the default device order.""" From 7b0a20ae33dfe48153b313d64628b950db5f15c7 Mon Sep 17 00:00:00 2001 From: dannawang Date: Thu, 4 Dec 2025 04:32:27 +0000 Subject: [PATCH 19/29] Update test Signed-off-by: dannawang --- .../offload/tpu_offload_connector_worker_test.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/distributed/offload/tpu_offload_connector_worker_test.py b/tests/distributed/offload/tpu_offload_connector_worker_test.py index f020b864c..19c7cd0e4 100644 --- a/tests/distributed/offload/tpu_offload_connector_worker_test.py +++ b/tests/distributed/offload/tpu_offload_connector_worker_test.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import functools +import gc import os import random from typing import List @@ -94,9 +95,17 @@ def setUp(self): def tearDown(self): super().tearDown() + # Destroy references explicitly + if hasattr(self, 'connector'): + del self.connector + + # Force JAX to release memory cc.reset_cache() jax.clear_caches() + # Force Python GC + gc.collect() + def create_mesh(self, axis_shapes, axis_names): """Creates a JAX device mesh with the default device order.""" try: From 63b0c0b7c3e2543521c3da9e005a41c815eb8f3e Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Thu, 4 Dec 2025 05:59:09 +0000 Subject: [PATCH 20/29] fix gke kv cache verification with sampling_param.temperature=0 Signed-off-by: Juncheng Gu --- ...offline_inference_kv_cache_verification.py | 27 ++++--------------- .../tpu_offload_connector_worker_test.py | 4 +-- 2 files changed, 7 insertions(+), 24 deletions(-) diff --git a/examples/offload/offline_inference_kv_cache_verification.py b/examples/offload/offline_inference_kv_cache_verification.py index be51edf89..3c58c258d 100644 --- a/examples/offload/offline_inference_kv_cache_verification.py +++ b/examples/offload/offline_inference_kv_cache_verification.py @@ -38,12 +38,6 @@ def create_parser(): 
parser.set_defaults(model="meta-llama/Llama-3.1-8B") parser.set_defaults(max_model_len=1024) - # Add sampling params - sampling_group = parser.add_argument_group("Sampling parameters") - sampling_group.add_argument("--max-tokens", type=int) - sampling_group.add_argument("--temperature", type=float) - sampling_group.add_argument("--top-p", type=float) - sampling_group.add_argument("--top-k", type=int) return parser @@ -52,25 +46,14 @@ def setup_llm(llm_args: dict) -> Tuple[LLM, SamplingParams]: Initializes a vLLM engine and sampling parameters from the given args. """ args_copy = copy.deepcopy(llm_args) - # Pop arguments not used by LLM - max_tokens = args_copy.pop("max_tokens") - temperature = args_copy.pop("temperature") - top_p = args_copy.pop("top_p") - top_k = args_copy.pop("top_k") - # Create an LLM. The --seed argument is passed in via **args. llm = LLM(**args_copy) - # Create a sampling params object - sampling_params = llm.get_default_sampling_params() - if max_tokens is not None: - sampling_params.max_tokens = max_tokens - if temperature is not None: - sampling_params.temperature = temperature - if top_p is not None: - sampling_params.top_p = top_p - if top_k is not None: - sampling_params.top_k = top_k + # Create a sampling params + sampling_params = SamplingParams(temperature=0, + max_tokens=20, + seed=42, + ignore_eos=True) return llm, sampling_params diff --git a/tests/distributed/offload/tpu_offload_connector_worker_test.py b/tests/distributed/offload/tpu_offload_connector_worker_test.py index 19c7cd0e4..954d867b6 100644 --- a/tests/distributed/offload/tpu_offload_connector_worker_test.py +++ b/tests/distributed/offload/tpu_offload_connector_worker_test.py @@ -26,7 +26,7 @@ logger = init_logger(__name__) -_DEFAULT_BLOCK_SIZE = 256 +_DEFAULT_BLOCK_SIZE = 64 class MockTPUModelRunner(TPUModelRunner): @@ -97,7 +97,7 @@ def tearDown(self): super().tearDown() # Destroy references explicitly if hasattr(self, 'connector'): - del self.connector + del self.connector # Force JAX to release memory cc.reset_cache() From 8b79f6882e3fd43c71ff29f06ca476d253fd5332 Mon Sep 17 00:00:00 2001 From: dannawang Date: Thu, 4 Dec 2025 21:50:54 +0000 Subject: [PATCH 21/29] Change sampling params to configrable Signed-off-by: dannawang --- ...offline_inference_kv_cache_verification.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/examples/offload/offline_inference_kv_cache_verification.py b/examples/offload/offline_inference_kv_cache_verification.py index 3c58c258d..ec3f9f4e9 100644 --- a/examples/offload/offline_inference_kv_cache_verification.py +++ b/examples/offload/offline_inference_kv_cache_verification.py @@ -38,6 +38,11 @@ def create_parser(): parser.set_defaults(model="meta-llama/Llama-3.1-8B") parser.set_defaults(max_model_len=1024) + # Add sampling params + sampling_group = parser.add_argument_group("Sampling parameters") + sampling_group.add_argument("--max-tokens", type=int) + sampling_group.add_argument("--top-p", type=float) + sampling_group.add_argument("--top-k", type=int) return parser @@ -46,14 +51,24 @@ def setup_llm(llm_args: dict) -> Tuple[LLM, SamplingParams]: Initializes a vLLM engine and sampling parameters from the given args. """ args_copy = copy.deepcopy(llm_args) + # Pop arguments not used by LLM + max_tokens = args_copy.pop("max_tokens") + top_p = args_copy.pop("top_p") + top_k = args_copy.pop("top_k") + # Create an LLM. The --seed argument is passed in via **args. 
llm = LLM(**args_copy) - # Create a sampling params - sampling_params = SamplingParams(temperature=0, - max_tokens=20, - seed=42, - ignore_eos=True) + # Create a sampling params object + sampling_params = llm.get_default_sampling_params() + sampling_params.temperature = 0 + sampling_params.ignore_eos = True + if max_tokens is not None: + sampling_params.max_tokens = max_tokens + if top_p is not None: + sampling_params.top_p = top_p + if top_k is not None: + sampling_params.top_k = top_k return llm, sampling_params From a3ff52bec4b6096c2a632d2927a0cc7c1fe21ee7 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Fri, 5 Dec 2025 23:29:27 +0000 Subject: [PATCH 22/29] config pre-mapped buffer of tpu Signed-off-by: Juncheng Gu --- examples/offload/gke/benchmarks/deploy-cpu-offload.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml b/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml index 8bddddbe1..9f72ea703 100644 --- a/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml +++ b/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml @@ -34,6 +34,12 @@ spec: value: "4096" - name: TPU_OFFLOAD_NUM_STAGING_BLOCKS value: "256" + # config the pre-mapped CPU buffer for TPUs + # https://docs.cloud.google.com/tpu/docs/performance-guide#tpu_model_performance + - name: TPU_PREMAPPED_BUFFER_SIZE + value: "68719476736" # 64 GB + - name: TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES + value: "68719476736" # 64 GB ports: - containerPort: 8000 resources: From 02493293fbd3f7ec755cf78547a988354282f5b1 Mon Sep 17 00:00:00 2001 From: dannawang Date: Sat, 6 Dec 2025 00:52:06 +0000 Subject: [PATCH 23/29] Update benchmark pods Signed-off-by: dannawang --- examples/offload/gke/benchmarks/deploy-baseline.yaml | 2 +- .../offload/gke/benchmarks/deploy-cpu-offload.yaml | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/offload/gke/benchmarks/deploy-baseline.yaml b/examples/offload/gke/benchmarks/deploy-baseline.yaml index a72dd6619..048d2fa77 100644 --- a/examples/offload/gke/benchmarks/deploy-baseline.yaml +++ b/examples/offload/gke/benchmarks/deploy-baseline.yaml @@ -21,7 +21,7 @@ spec: imagePullPolicy: Always command: ["/bin/sh", "-c"] args: - - "vllm serve meta-llama/Llama-3.3-70B-Instruct --port 8000 --max_num_batched_tokens 2048 --enable-chunked-prefill --tensor-parallel-size 8 --seed 42 --enable_prefix_caching --gpu-memory-utilization 0.9" + - "vllm serve meta-llama/Llama-3.3-70B-Instruct --port 8000 --enable-chunked-prefill --tensor-parallel-size 8 --seed 42 --enable_prefix_caching --gpu-memory-utilization 0.9" env: - name: HUGGING_FACE_HUB_TOKEN valueFrom: diff --git a/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml b/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml index 9f72ea703..b39246b64 100644 --- a/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml +++ b/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml @@ -15,13 +15,21 @@ spec: nodeSelector: cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice cloud.google.com/gke-tpu-topology: 2x4 # Specify the physical topology for the TPU slice. + initContainers: + - name: increase-vm-max-map-count + image: busybox + # WARNING: This changes the HOST memory settings (vm.max_map_count), not just the container. + # Required to prevent vLLM crashes due to memory mapping limits. 
+ command: ["sysctl", "-w", "vm.max_map_count=1048576"] + securityContext: + privileged: true containers: - name: tpu-job image: imagePullPolicy: Always command: ["/bin/sh", "-c"] args: - - "vllm serve meta-llama/Llama-3.3-70B-Instruct --kv-transfer-config '{\"kv_connector\":\"TPUOffloadConnector\",\"kv_role\":\"kv_both\",\"kv_connector_module_path\":\"tpu_inference.distributed.offload.tpu_offload_connector\"}' --port 8000 --max_num_batched_tokens 2048 --enable-chunked-prefill --tensor-parallel-size 8 --seed 42 --enable_prefix_caching --gpu-memory-utilization 0.9" + - "vllm serve meta-llama/Llama-3.3-70B-Instruct --kv-transfer-config '{\"kv_connector\":\"TPUOffloadConnector\",\"kv_role\":\"kv_both\",\"kv_connector_module_path\":\"tpu_inference.distributed.offload.tpu_offload_connector\"}' --port 8000 --enable-chunked-prefill --tensor-parallel-size 8 --seed 42 --enable_prefix_caching --gpu-memory-utilization 0.9" env: - name: HUGGING_FACE_HUB_TOKEN valueFrom: From 9ac152e2254b1aa8e9c4936a7ea7493d3b2f5443 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Sat, 6 Dec 2025 05:05:47 +0000 Subject: [PATCH 24/29] tweaks Signed-off-by: Juncheng Gu --- tpu_inference/worker/tpu_worker.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py index a41c92e7b..2be5a8527 100644 --- a/tpu_inference/worker/tpu_worker.py +++ b/tpu_inference/worker/tpu_worker.py @@ -488,8 +488,3 @@ def sync_weights( def shutdown(self) -> None: return - - # Ray executor do not need handshake metadata - # as we pass the kv_parameters through proxy server - def get_kv_connector_handshake_metadata(self) -> None: - pass From b0ddb8c7dc04dbb4941d32028525d4ac664d48fb Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Wed, 10 Dec 2025 22:55:55 +0000 Subject: [PATCH 25/29] fix load_spec for unscheduled requests; fix cached request with both save and load Signed-off-by: Juncheng Gu --- .../offload/tpu_offload_connector.py | 141 ++++++++++++------ 1 file changed, 92 insertions(+), 49 deletions(-) diff --git a/tpu_inference/distributed/offload/tpu_offload_connector.py b/tpu_inference/distributed/offload/tpu_offload_connector.py index 9c7f49a92..5221b9e02 100644 --- a/tpu_inference/distributed/offload/tpu_offload_connector.py +++ b/tpu_inference/distributed/offload/tpu_offload_connector.py @@ -101,7 +101,6 @@ KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) from vllm.distributed.kv_transfer.kv_connector.v1.metrics import \ KVConnectorStats -from vllm.utils.math_utils import cdiv from vllm.v1.core.kv_cache_utils import BlockHash from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig @@ -487,6 +486,11 @@ def __init__(self, vllm_config: "VllmConfig"): # as the scheduler output for these requests is minimal. self._unfinished_requests: dict[ReqId, "Request"] = {} self.load_specs: dict[ReqId, LoadSpec] = {} + # requests with load ops that have been considered by vllm scheduler, + # not all of them will be scheduled, the scheduled ones will be + # moved to load_specs. 
+ # it should be cleaned after ConnectorMetadata's creation + self._pre_load_specs: dict[ReqId, LoadSpec] = {} # {reqid: total_num_matched_tokens_in_cpu_backend} self._external_cache_hits: dict[ReqId, int] = {} @@ -552,10 +556,12 @@ def get_num_new_matched_tokens( matched_block_hashes = block_hashes[:num_hits] self.offload_manager.touch(block_hashes) num_matched_blocks = len(matched_block_hashes) - num_matched_tokens = min(num_matched_blocks * self.block_size, - len(prompt_token_ids)) + # num_matched_tokens = min(num_matched_blocks * self.block_size, + # len(prompt_token_ids)) + num_matched_tokens = num_matched_blocks * self.block_size + assert num_matched_tokens <= len(prompt_token_ids) num_computed_blocks = num_computed_tokens // self.block_size - num_blocks_to_load = num_matched_blocks - num_computed_blocks + num_blocks_to_load = max(num_matched_blocks - num_computed_blocks, 0) logger.info( f"Request {request.request_id}: Found {num_matched_tokens} (out of {len(prompt_token_ids)} prompt tokens) matched tokens ({num_matched_blocks} blocks) in CPU backend (computed_blocks: {num_computed_blocks}, blocks_to_load: {num_blocks_to_load})." ) @@ -572,36 +578,30 @@ def get_num_new_matched_tokens( f" Req({request.request_id}) found {num_matched_blocks} blocks ({num_matched_tokens} tokens), but only {num_avail_staging_blocks} staging blocks available." ) num_blocks_to_load = num_avail_staging_blocks - num_matched_tokens = (num_blocks_to_load + - num_computed_blocks) * self.block_size + num_matched_blocks = num_blocks_to_load + num_computed_blocks + num_matched_tokens = num_matched_blocks * self.block_size # still have something to load if num_blocks_to_load > 0: - # get the src chunk ids to load - block_hashes_to_load = block_hashes[num_computed_blocks:( - num_computed_blocks + num_blocks_to_load)] - chunks_to_load = self.offload_manager.prepare_load( - block_hashes_to_load) - src_chunk_ids = [chunk.chunk_id for chunk in chunks_to_load] - - # NOTE(jcgu): fill real dst_blocks later when blocks get allocated. + # NOTE(jcgu): put dummy chunk / block ids; + # fill real ids later when the requests gets scheduled + src_chunk_ids = [-1] * num_blocks_to_load dummy_dst_blocks = [-1] * num_blocks_to_load - self.load_specs[request.request_id] = LoadSpec( + self._pre_load_specs[request.request_id] = LoadSpec( num_matched_tokens=num_matched_tokens, src_chunks=src_chunk_ids, dst_blocks=dummy_dst_blocks, num_skip_leading_tokens=num_computed_tokens, ) - num_allocated_blocks = self.staging_buffer_manager.allocate( + num_allocated_staging_blocks = self.staging_buffer_manager.allocate( request.request_id, num_blocks=num_blocks_to_load, usage="load") - assert num_allocated_blocks == num_blocks_to_load >= 0, f" failed to allocate {num_allocated_blocks} (load) staging blocks for request {request.request_id}, expected {num_blocks_to_load}." + assert num_allocated_staging_blocks == num_blocks_to_load >= 0, f" failed to allocate {num_allocated_staging_blocks} (load) staging blocks for request {request.request_id}, expected {num_blocks_to_load}." 
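A small worked example of the sizing logic above, not part of the diff; the numbers are invented and only the structure mirrors the code:

```python
block_size = 64
num_matched_blocks = 10                 # full-block hits in the CPU backend
num_computed_tokens = 3 * block_size    # prefix already resident in HBM
num_avail_staging_blocks = 5            # free staging blocks this step

num_computed_blocks = num_computed_tokens // block_size                 # 3
num_blocks_to_load = max(num_matched_blocks - num_computed_blocks, 0)   # 7

# Cap the load by the staging buffer and shrink the reported hit to match,
# so the scheduler never counts on more external tokens than can be copied in.
if num_blocks_to_load > num_avail_staging_blocks:
    num_blocks_to_load = num_avail_staging_blocks                       # 5
    num_matched_blocks = num_blocks_to_load + num_computed_blocks       # 8

num_matched_tokens = num_matched_blocks * block_size
print(num_blocks_to_load, num_matched_tokens)   # 5 512
```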
- # record the matched tokens in the cache, it will be needed in - # init save_spec - self._external_cache_hits[ - request.request_id] = num_matched_tokens + # record the matched tokens in the cache, it will be needed in + # init save_spec + self._external_cache_hits[request.request_id] = num_matched_tokens is_full_prefix_hit = (num_matched_tokens > 0 and num_matched_tokens == len(prompt_token_ids)) @@ -655,24 +655,38 @@ def update_state_after_alloc(self, request: "Request", self._unfinished_requests[request.request_id] = request if num_external_tokens == 0: return - if request.request_id in self.load_specs: + + # retrieve the load_spec + load_spec = self._pre_load_specs.pop(request.request_id, None) + if load_spec: + assert load_spec.num_skip_leading_tokens % self.block_size == 0 + assert len(load_spec.src_chunks) == len(load_spec.dst_blocks) + skip_leading_blocks = load_spec.num_skip_leading_tokens // self.block_size + num_blocks_to_load = len(load_spec.src_chunks) + num_matched_blocks = num_blocks_to_load + skip_leading_blocks + assert num_matched_blocks == load_spec.num_matched_tokens // self.block_size, f"{num_matched_blocks} != {load_spec.num_matched_tokens} // {self.block_size}" + block_hashes = self._get_request_block_hashes(request) all_blocks = blocks.get_block_ids()[0] logger.info( - f" Request: {request.request_id} has {len(all_blocks)} blocks / {len(block_hashes)} block hashes.)" + f" Request: {request.request_id} has {len(all_blocks)} blocks / {len(block_hashes)} block hashes." ) - load_spec = self.load_specs[request.request_id] - assert load_spec.num_skip_leading_tokens % self.block_size == 0 - skip_leading_blocks = load_spec.num_skip_leading_tokens // self.block_size - total_matched_blocks = len( - load_spec.dst_blocks) + skip_leading_blocks - assert total_matched_blocks == cdiv( - load_spec.num_matched_tokens, self.block_size - ), f"{total_matched_blocks} != {load_spec.num_matched_tokens}" - dst_blocks = all_blocks[skip_leading_blocks:total_matched_blocks] + # get the src chunk ids to load + block_hashes_to_load = block_hashes[ + skip_leading_blocks:num_matched_blocks] + chunks_to_load = self.offload_manager.prepare_load( + block_hashes_to_load) + src_chunk_ids = [chunk.chunk_id for chunk in chunks_to_load] + + # get dst block ids + dst_blocks = all_blocks[skip_leading_blocks:num_matched_blocks] + + # update load spec + load_spec.src_chunks = src_chunk_ids load_spec.dst_blocks = dst_blocks load_spec.can_load = True + self.load_specs[request.request_id] = load_spec self._reqs_being_loaded[request.request_id] |= set( load_spec.src_chunks) logger.info( @@ -949,10 +963,10 @@ def build_connector_meta( f" - Created tracker for {req_id} with initial state: {tracker}" ) - # Immediately prepare metadata for this new request. This could include - # both a load operation (for the cached part) and a save operation - # (for the newly computed part). - load_spec = self.load_specs.get(req_id) + # Immediately prepare metadata for this new request. + # This could include both a load operation (for the cached part) + # and a save operation (for the newly computed part). + load_spec = self.load_specs.pop(req_id, None) req_meta = self._prepare_req_meta(tracker, load_spec, is_finished=False) @@ -1018,10 +1032,11 @@ def build_connector_meta( f"total_tokens={len(tracker.token_ids)}, " f"total_blocks={len(tracker.block_ids)}") - # Immediately prepare metadata for this updated request. This will - # typically be a save operation for the new tokens. 
+ # for cached requests whose kv pages get evicted, there will be + # load operations. + load_spec = self.load_specs.pop(req_id, None) req_meta = self._prepare_req_meta(tracker, - load_spec=None, + load_spec=load_spec, is_finished=False) if req_meta: logger.info( @@ -1032,6 +1047,23 @@ def build_connector_meta( if metadata.requests_meta: logger.info( f"Prepared {len(metadata.requests_meta)} requests for worker.") + + # after building connector_metadata, all load_specs should be consumed + assert len( + self.load_specs + ) == 0, f" load_specs still has {list(self.load_specs.keys())}" + + # clean up the temporary states of requests that are not scheduled + for req_id, _load_spec in self._pre_load_specs.items(): + logger.info(f"non-scheduled-request:{req_id}") + _freed_num_staging_blocks = self.staging_buffer_manager.free( + req_id, "load") + assert _freed_num_staging_blocks == len( + _load_spec.src_chunks + ), f"{_freed_num_staging_blocks} != {len(_load_spec.src_chunks)}" + self._pre_load_specs.clear() + self._external_cache_hits.clear() + return metadata def update_connector_output(self, connector_output: KVConnectorOutput): @@ -1066,6 +1098,11 @@ def update_connector_output(self, connector_output: KVConnectorOutput): self._reqs_being_saved[req_id].remove(saved_chunk_id) if len(self._reqs_being_saved[req_id]) == 0: self._reqs_being_saved.pop(req_id, None) + else: + logger.info( + f" remaining_saving_blocks:{req_id}, {self._reqs_being_saved[req_id]}." + ) + # update the status of occupied cpu chunks self.offload_manager.mark_completion(saved_chunk_ids, "save") @@ -1127,20 +1164,25 @@ def request_finished( return: delay_free_blocks, kv_xfer_params """ - logger.info("TPUOffloadConnectorScheduler: Entering request_finished") + logger.info(" Entering request_finished") # Return True to indicate the request is being saved asynchronously # and its blocks should not be freed yet. 
req_id = request.request_id if req_id in self._reqs_being_saved and len( self._reqs_being_saved[req_id]) > 0: + logger.info( + f"not_free_with_save:{req_id}, {self._reqs_being_saved[req_id]}" + ) return True, None if req_id in self._reqs_being_loaded and len( self._reqs_being_loaded[req_id]) > 0: + logger.info( + f"not_free_with_load:{req_id}, {self._reqs_being_loaded[req_id]}" + ) return True, None - logger.info( - f"TPUOffloadConnectorScheduler: finished request: {req_id}") + logger.info(f" finished request: {req_id}") self._reqs_being_saved.pop(req_id, None) self._reqs_being_loaded.pop(req_id, None) @@ -1511,13 +1553,14 @@ def _save_blocks_to_cpu(self, req_id: ReqId, full_block_ids: list[int], process_token_ids = full_token_ids[:num_total_tokens] tokens_to_save = process_token_ids[num_skip_leading_tokens:] - logger.info(f"Request {req_id} save details: " - f"full_block_ids len={len(full_block_ids)}, " - f"num_skip_leading_tokens={num_skip_leading_tokens}, " - f"num_total_tokens={num_total_tokens}, " - f"num_tokens_to_save={num_tokens_to_save}, " - f"blocks_to_save({len(blocks_to_save)}: {blocks_to_save}, " - f"dst_chunks({len(dst_chunks)}: {dst_chunks} ") + logger.info( + f"Request {req_id} save details: " + f"full_block_ids len={len(full_block_ids)}, " + f"num_skip_leading_tokens={num_skip_leading_tokens}, " + f"num_total_tokens={num_total_tokens}, " + f"num_tokens_to_save={num_tokens_to_save}, " + f"blocks_to_save({len(blocks_to_save)}: {blocks_to_save}), " + f"dst_chunks({len(dst_chunks)}: {dst_chunks}) ") if not blocks_to_save and tokens_to_save: logger.warning( From 43e2fb42583858418d902933c79aa3473f9dd938 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Thu, 11 Dec 2025 03:23:06 +0000 Subject: [PATCH 26/29] cpu chunk: ready_to_evict: ref_cnt==0 Signed-off-by: Juncheng Gu --- tpu_inference/distributed/offload/offload_manager.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tpu_inference/distributed/offload/offload_manager.py b/tpu_inference/distributed/offload/offload_manager.py index c4f5cfcf0..faf88b8fe 100644 --- a/tpu_inference/distributed/offload/offload_manager.py +++ b/tpu_inference/distributed/offload/offload_manager.py @@ -17,6 +17,12 @@ @dataclass class CPUChunk: + """ + ref_cnt: + -1: init, not saved + 0: saved, ready_to_evict, ready_to_load + >=1: loadings, ready_to_load, in_use + """ chunk_id: CpuChunkId ref_cnt: int = -1 _chunk_hash: ChunkHash | None = None @@ -27,7 +33,7 @@ def is_ready_to_load(self): @property def is_ready_to_evict(self): - return self.ref_cnt <= 0 + return self.ref_cnt == 0 @property def is_in_use(self): From 0d39925b041dbb3dda4bda6fbf815375088090bf Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Fri, 12 Dec 2025 05:47:29 +0000 Subject: [PATCH 27/29] update unit tests Signed-off-by: Juncheng Gu --- .../tpu_offload_connector_scheduler_test.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/tests/distributed/offload/tpu_offload_connector_scheduler_test.py b/tests/distributed/offload/tpu_offload_connector_scheduler_test.py index d0b2bdced..60476edb8 100644 --- a/tests/distributed/offload/tpu_offload_connector_scheduler_test.py +++ b/tests/distributed/offload/tpu_offload_connector_scheduler_test.py @@ -141,22 +141,18 @@ def test_get_num_new_matched_tokens_hit(self, scheduler_factory, assert num_external_matched_tokens == num_blocks_to_load * scheduler.block_size # check scheduler internal states + # cache_hits + assert "req1" in scheduler._external_cache_hits + assert 
scheduler._external_cache_hits["req1"] == num_matched_tokens if num_blocks_to_load > 0: # load_spec - assert "req1" in scheduler.load_specs - load_spec = scheduler.load_specs["req1"] + assert "req1" in scheduler._pre_load_specs + load_spec = scheduler._pre_load_specs["req1"] assert load_spec.num_matched_tokens == num_matched_tokens assert not load_spec.can_load - allocated_chunk_ids = [ - chunk.chunk_id for chunk in allocated_chunks - ] - load_src_chunk_ids = allocated_chunk_ids[num_computed_blocks:] - assert load_spec.src_chunks == load_src_chunk_ids + assert len(load_spec.src_chunks) == num_blocks_to_load assert load_spec.num_skip_leading_tokens == num_computed_tokens assert len(load_spec.dst_blocks) == num_blocks_to_load - # cache_hits - assert "req1" in scheduler._external_cache_hits - assert scheduler._external_cache_hits["req1"] == num_matched_tokens # staging_buffer assert "req1" in scheduler.staging_buffer_manager._blocks_for_load assert scheduler.staging_buffer_manager._blocks_for_load[ @@ -164,8 +160,7 @@ def test_get_num_new_matched_tokens_hit(self, scheduler_factory, assert scheduler.staging_buffer_manager.get_num_free_staging_blocks( ) == num_staging_blocks - num_blocks_to_load else: - assert "req1" not in scheduler.load_specs - assert "req1" not in scheduler._external_cache_hits + assert "req1" not in scheduler._pre_load_specs assert "req1" not in scheduler.staging_buffer_manager._blocks_for_load def test_update_state_after_alloc(self, scheduler_factory): @@ -185,8 +180,14 @@ def test_update_state_after_alloc(self, scheduler_factory): request = create_request(req_id, [0] * num_prompt_tokens, scheduler.block_size) + # init offload_manager state + matched_block_hashes = request.block_hashes[:num_matched_blocks] + allocated_chunks, _ = scheduler.offload_manager.allocate_for_save( + matched_block_hashes) + scheduler.offload_manager.complete_save(matched_block_hashes) + # Setup a pending load - scheduler.load_specs[req_id] = MagicMock( + scheduler._pre_load_specs[req_id] = MagicMock( num_matched_tokens=num_matched_tokens, num_skip_leading_tokens=num_computed_blocks * scheduler.block_size, dst_blocks=[-1] * num_blocks_to_load, From 7d81d9010444bca4467e8f3dab65eb07c60e67a2 Mon Sep 17 00:00:00 2001 From: Juncheng Gu Date: Fri, 12 Dec 2025 07:01:01 +0000 Subject: [PATCH 28/29] put offload folder under tpu_inference Signed-off-by: Juncheng Gu --- .buildkite/features/KV_Cache_Offload.yml | 2 +- .buildkite/pipeline_jax.yml | 6 +++--- .../offload/gke/benchmarks/deploy-cpu-offload.yaml | 2 +- .../offload/gke/pod_tpu_commons_cpu_offload.yaml | 2 +- .../pod_tpu_commons_cpu_offload_verification.yaml | 2 +- .../gke/pod_tpu_host_offload_unit_tests.yaml | 2 +- .../offload/tpu_offload_accuracy_test.py | 3 +-- .../tpu_offload_connector_scheduler_test.py | 2 +- .../offload/tpu_offload_connector_worker_test.py | 9 ++++----- .../offload/tpu_offload_cpu_backend_test.py | 4 ++-- .../offload/tpu_offload_manager_test.py | 7 ++++--- .../offload/tpu_offload_utils_test.py | 4 ++-- .../{distributed => }/offload/__init__.py | 0 .../{distributed => }/offload/cpu_backend.py | 3 +-- .../{distributed => }/offload/offload_manager.py | 3 +-- .../offload/tpu_offload_connector.py | 14 +++++++------- tpu_inference/{distributed => }/offload/utils.py | 1 - tpu_inference/runner/kv_cache_manager.py | 3 +-- tpu_inference/worker/tpu_worker.py | 3 ++- 19 files changed, 34 insertions(+), 38 deletions(-) rename tests/{distributed => }/offload/tpu_offload_accuracy_test.py (97%) rename tests/{distributed => 
}/offload/tpu_offload_connector_scheduler_test.py (99%) rename tests/{distributed => }/offload/tpu_offload_connector_worker_test.py (98%) rename tests/{distributed => }/offload/tpu_offload_cpu_backend_test.py (94%) rename tests/{distributed => }/offload/tpu_offload_manager_test.py (98%) rename tests/{distributed => }/offload/tpu_offload_utils_test.py (97%) rename tpu_inference/{distributed => }/offload/__init__.py (100%) rename tpu_inference/{distributed => }/offload/cpu_backend.py (96%) rename tpu_inference/{distributed => }/offload/offload_manager.py (99%) rename tpu_inference/{distributed => }/offload/tpu_offload_connector.py (99%) rename tpu_inference/{distributed => }/offload/utils.py (99%) diff --git a/.buildkite/features/KV_Cache_Offload.yml b/.buildkite/features/KV_Cache_Offload.yml index 6c9f34bc5..17a88786b 100644 --- a/.buildkite/features/KV_Cache_Offload.yml +++ b/.buildkite/features/KV_Cache_Offload.yml @@ -12,7 +12,7 @@ steps: commands: - | .buildkite/scripts/run_in_docker.sh \ - python3 -m pytest -s -v /workspace/tpu_inference/tests/distributed/offload/tpu_offload_accuracy_test.py + python3 -m pytest -s -v /workspace/tpu_inference/tests/offload/tpu_offload_accuracy_test.py - label: "Record correctness test result for KV Cache Offload" key: "record_KV_Cache_Offload_CorrectnessTest" depends_on: "KV_Cache_Offload_CorrectnessTest" diff --git a/.buildkite/pipeline_jax.yml b/.buildkite/pipeline_jax.yml index 215e9b4ad..2041a6d2f 100644 --- a/.buildkite/pipeline_jax.yml +++ b/.buildkite/pipeline_jax.yml @@ -122,7 +122,7 @@ steps: --ignore=/workspace/tpu_inference/tests/e2e \ --ignore=/workspace/tpu_inference/tpu_inference/mock \ --ignore=/workspace/tpu_inference/tests/layers/vllm/test_compressed_tensors_moe.py \ - --ignore=/workspace/tpu_inference/tests/distributed/offload \ + --ignore=/workspace/tpu_inference/tests/offload \ --cov-config=/workspace/tpu_inference/.coveragerc --cov tpu_inference --cov-report term-missing --cov-fail-under=69 - label: "JAX unit tests - kernels" @@ -269,9 +269,9 @@ steps: commands: - | .buildkite/scripts/run_in_docker.sh \ - python3 -m pytest -s -v -x /workspace/tpu_inference/tests/distributed/offload/ \ + python3 -m pytest -s -v -x /workspace/tpu_inference/tests/offload/ \ /workspace/tpu_inference/tests/kernels/host_dma_test.py \ - --ignore=/workspace/tpu_inference/tests/distributed/offload/tpu_offload_accuracy_test.py + --ignore=/workspace/tpu_inference/tests/offload/tpu_offload_accuracy_test.py # ----------------------------------------------------------------- # NOTIFICATION STEP # ----------------------------------------------------------------- diff --git a/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml b/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml index b39246b64..fa4d03a8a 100644 --- a/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml +++ b/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml @@ -29,7 +29,7 @@ spec: imagePullPolicy: Always command: ["/bin/sh", "-c"] args: - - "vllm serve meta-llama/Llama-3.3-70B-Instruct --kv-transfer-config '{\"kv_connector\":\"TPUOffloadConnector\",\"kv_role\":\"kv_both\",\"kv_connector_module_path\":\"tpu_inference.distributed.offload.tpu_offload_connector\"}' --port 8000 --enable-chunked-prefill --tensor-parallel-size 8 --seed 42 --enable_prefix_caching --gpu-memory-utilization 0.9" + - "vllm serve meta-llama/Llama-3.3-70B-Instruct --kv-transfer-config 
'{\"kv_connector\":\"TPUOffloadConnector\",\"kv_role\":\"kv_both\",\"kv_connector_module_path\":\"tpu_inference.offload.tpu_offload_connector\"}' --port 8000 --enable-chunked-prefill --tensor-parallel-size 8 --seed 42 --enable_prefix_caching --gpu-memory-utilization 0.9" env: - name: HUGGING_FACE_HUB_TOKEN valueFrom: diff --git a/examples/offload/gke/pod_tpu_commons_cpu_offload.yaml b/examples/offload/gke/pod_tpu_commons_cpu_offload.yaml index 368e44da3..7b9145953 100644 --- a/examples/offload/gke/pod_tpu_commons_cpu_offload.yaml +++ b/examples/offload/gke/pod_tpu_commons_cpu_offload.yaml @@ -18,7 +18,7 @@ spec: - --tensor_parallel_size=8 - --max_model_len=1024 - --kv-transfer-config - - '{"kv_connector":"TPUOffloadConnector","kv_connector_module_path":"tpu_inference.distributed.offload.tpu_offload_connector","kv_role":"kv_both"}' + - '{"kv_connector":"TPUOffloadConnector","kv_connector_module_path":"tpu_inference.offload.tpu_offload_connector","kv_role":"kv_both"}' env: - name: HUGGING_FACE_HUB_TOKEN valueFrom: diff --git a/examples/offload/gke/pod_tpu_commons_cpu_offload_verification.yaml b/examples/offload/gke/pod_tpu_commons_cpu_offload_verification.yaml index 2ebdc67ee..e9bd9748b 100644 --- a/examples/offload/gke/pod_tpu_commons_cpu_offload_verification.yaml +++ b/examples/offload/gke/pod_tpu_commons_cpu_offload_verification.yaml @@ -25,7 +25,7 @@ spec: - --max_model_len=1024 - --seed=42 - --kv-transfer-config - - '{"kv_connector":"TPUOffloadConnector","kv_connector_module_path":"tpu_inference.distributed.offload.tpu_offload_connector","kv_role":"kv_both"}' + - '{"kv_connector":"TPUOffloadConnector","kv_connector_module_path":"tpu_inference.offload.tpu_offload_connector","kv_role":"kv_both"}' env: - name: HUGGING_FACE_HUB_TOKEN valueFrom: diff --git a/examples/offload/gke/pod_tpu_host_offload_unit_tests.yaml b/examples/offload/gke/pod_tpu_host_offload_unit_tests.yaml index 7fc2180f2..7712c7bff 100644 --- a/examples/offload/gke/pod_tpu_host_offload_unit_tests.yaml +++ b/examples/offload/gke/pod_tpu_host_offload_unit_tests.yaml @@ -17,7 +17,7 @@ spec: command: - /bin/bash - -c - - "pytest -sv tests/distributed/offload/" + - "pytest -sv tests/offload/" env: - name: HUGGING_FACE_HUB_TOKEN valueFrom: diff --git a/tests/distributed/offload/tpu_offload_accuracy_test.py b/tests/offload/tpu_offload_accuracy_test.py similarity index 97% rename from tests/distributed/offload/tpu_offload_accuracy_test.py rename to tests/offload/tpu_offload_accuracy_test.py index 34a553add..fd597f361 100644 --- a/tests/distributed/offload/tpu_offload_accuracy_test.py +++ b/tests/offload/tpu_offload_accuracy_test.py @@ -40,8 +40,7 @@ def kv_transfer_config(): return KVTransferConfig( kv_connector="TPUOffloadConnector", kv_role="kv_both", - kv_connector_module_path= - "tpu_inference.distributed.offload.tpu_offload_connector", + kv_connector_module_path="tpu_inference.offload.tpu_offload_connector", ) diff --git a/tests/distributed/offload/tpu_offload_connector_scheduler_test.py b/tests/offload/tpu_offload_connector_scheduler_test.py similarity index 99% rename from tests/distributed/offload/tpu_offload_connector_scheduler_test.py rename to tests/offload/tpu_offload_connector_scheduler_test.py index 60476edb8..30e31e334 100644 --- a/tests/distributed/offload/tpu_offload_connector_scheduler_test.py +++ b/tests/offload/tpu_offload_connector_scheduler_test.py @@ -9,7 +9,7 @@ from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.request import Request -from 
tpu_inference.distributed.offload.tpu_offload_connector import ( +from tpu_inference.offload.tpu_offload_connector import ( RequestTracker, TPUOffloadConnectorScheduler) _DEFAULT_BLOCK_SIZE = 16 diff --git a/tests/distributed/offload/tpu_offload_connector_worker_test.py b/tests/offload/tpu_offload_connector_worker_test.py similarity index 98% rename from tests/distributed/offload/tpu_offload_connector_worker_test.py rename to tests/offload/tpu_offload_connector_worker_test.py index 954d867b6..c23eb1146 100644 --- a/tests/distributed/offload/tpu_offload_connector_worker_test.py +++ b/tests/offload/tpu_offload_connector_worker_test.py @@ -15,13 +15,12 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole -from tpu_inference.distributed.offload.tpu_offload_connector import (LoadSpec, - SaveSpec) -from tpu_inference.distributed.offload.tpu_offload_connector import \ +from tpu_inference.logger import init_logger +from tpu_inference.offload.tpu_offload_connector import LoadSpec, SaveSpec +from tpu_inference.offload.tpu_offload_connector import \ TPUOffloadConnector as CPUOffloadingConnector -from tpu_inference.distributed.offload.tpu_offload_connector import ( +from tpu_inference.offload.tpu_offload_connector import ( TPUOffloadConnectorMetadata, TPUReqMeta) -from tpu_inference.logger import init_logger from tpu_inference.runner.tpu_runner import TPUModelRunner logger = init_logger(__name__) diff --git a/tests/distributed/offload/tpu_offload_cpu_backend_test.py b/tests/offload/tpu_offload_cpu_backend_test.py similarity index 94% rename from tests/distributed/offload/tpu_offload_cpu_backend_test.py rename to tests/offload/tpu_offload_cpu_backend_test.py index e845ef688..69094bf64 100644 --- a/tests/distributed/offload/tpu_offload_cpu_backend_test.py +++ b/tests/offload/tpu_offload_cpu_backend_test.py @@ -4,8 +4,8 @@ import pytest -from tpu_inference.distributed.offload.cpu_backend import LocalCPUBackend -from tpu_inference.distributed.offload.utils import CpuChunkId +from tpu_inference.offload.cpu_backend import LocalCPUBackend +from tpu_inference.offload.utils import CpuChunkId # Helper to create a mock jax array with a specific size in bytes diff --git a/tests/distributed/offload/tpu_offload_manager_test.py b/tests/offload/tpu_offload_manager_test.py similarity index 98% rename from tests/distributed/offload/tpu_offload_manager_test.py rename to tests/offload/tpu_offload_manager_test.py index d58b4f113..5d1674eae 100644 --- a/tests/distributed/offload/tpu_offload_manager_test.py +++ b/tests/offload/tpu_offload_manager_test.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 import pytest -from tpu_inference.distributed.offload.offload_manager import ( - CPUChunkPool, LRUCacheManager, StagingBufferManager) -from tpu_inference.distributed.offload.utils import ReqId from tpu_inference.logger import init_logger +from tpu_inference.offload.offload_manager import (CPUChunkPool, + LRUCacheManager, + StagingBufferManager) +from tpu_inference.offload.utils import ReqId logger = init_logger(__name__) diff --git a/tests/distributed/offload/tpu_offload_utils_test.py b/tests/offload/tpu_offload_utils_test.py similarity index 97% rename from tests/distributed/offload/tpu_offload_utils_test.py rename to tests/offload/tpu_offload_utils_test.py index 739ca0a79..e5c28af27 100644 --- a/tests/distributed/offload/tpu_offload_utils_test.py +++ b/tests/offload/tpu_offload_utils_test.py @@ -7,8 +7,8 @@ import numpy as np from 
jax.sharding import NamedSharding, PartitionSpec -from tpu_inference.distributed.offload.utils import ( - get_kv_cache_swap_fn, jitted_insert_kv_cache_slices) +from tpu_inference.offload.utils import (get_kv_cache_swap_fn, + jitted_insert_kv_cache_slices) class TestTPUOffloadUtilsFn(unittest.TestCase): diff --git a/tpu_inference/distributed/offload/__init__.py b/tpu_inference/offload/__init__.py similarity index 100% rename from tpu_inference/distributed/offload/__init__.py rename to tpu_inference/offload/__init__.py diff --git a/tpu_inference/distributed/offload/cpu_backend.py b/tpu_inference/offload/cpu_backend.py similarity index 96% rename from tpu_inference/distributed/offload/cpu_backend.py rename to tpu_inference/offload/cpu_backend.py index 3199c5086..e4613a4a7 100644 --- a/tpu_inference/distributed/offload/cpu_backend.py +++ b/tpu_inference/offload/cpu_backend.py @@ -1,12 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import sys from collections import OrderedDict from typing import Any, Optional -from tpu_inference.distributed.offload.utils import CpuChunkId from tpu_inference.logger import init_logger +from tpu_inference.offload.utils import CpuChunkId logger = init_logger(__name__) diff --git a/tpu_inference/distributed/offload/offload_manager.py b/tpu_inference/offload/offload_manager.py similarity index 99% rename from tpu_inference/distributed/offload/offload_manager.py rename to tpu_inference/offload/offload_manager.py index faf88b8fe..fc792d4d2 100644 --- a/tpu_inference/distributed/offload/offload_manager.py +++ b/tpu_inference/offload/offload_manager.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections import OrderedDict from dataclasses import dataclass @@ -7,8 +6,8 @@ from vllm.v1.core.kv_cache_utils import BlockHash -from tpu_inference.distributed.offload.utils import CpuChunkId, ReqId from tpu_inference.logger import init_logger +from tpu_inference.offload.utils import CpuChunkId, ReqId logger = init_logger(__name__) diff --git a/tpu_inference/distributed/offload/tpu_offload_connector.py b/tpu_inference/offload/tpu_offload_connector.py similarity index 99% rename from tpu_inference/distributed/offload/tpu_offload_connector.py rename to tpu_inference/offload/tpu_offload_connector.py index 5221b9e02..cdeb55a9e 100644 --- a/tpu_inference/distributed/offload/tpu_offload_connector.py +++ b/tpu_inference/offload/tpu_offload_connector.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Scheduler side execution: TPUOffloadConnectorScheduler manages the state of KV cache loading and saving for @@ -112,13 +111,14 @@ from vllm.forward_context import ForwardContext from tpu_inference import envs -from tpu_inference.distributed.offload.cpu_backend import LocalCPUBackend -from tpu_inference.distributed.offload.offload_manager import ( - LRUCacheManager, StagingBufferManager) -from tpu_inference.distributed.offload.utils import ( - CPU_OFFLOADING_SWAP_OP_TYPE, CpuChunkId, KVCacheSwapFn, ReqId, - get_kv_cache_swap_fn, jitted_insert_kv_cache_slices) from tpu_inference.logger import init_logger +from tpu_inference.offload.cpu_backend import LocalCPUBackend +from tpu_inference.offload.offload_manager import (LRUCacheManager, + StagingBufferManager) +from tpu_inference.offload.utils import (CPU_OFFLOADING_SWAP_OP_TYPE, + CpuChunkId, KVCacheSwapFn, 
ReqId, + get_kv_cache_swap_fn, + jitted_insert_kv_cache_slices) from tpu_inference.runner.kv_cache_manager import KVCacheManager from tpu_inference.runner.tpu_runner import TPUModelRunner diff --git a/tpu_inference/distributed/offload/utils.py b/tpu_inference/offload/utils.py similarity index 99% rename from tpu_inference/distributed/offload/utils.py rename to tpu_inference/offload/utils.py index c3767983f..3a99a57e6 100644 --- a/tpu_inference/distributed/offload/utils.py +++ b/tpu_inference/offload/utils.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the LMCache project import functools import hashlib diff --git a/tpu_inference/runner/kv_cache_manager.py b/tpu_inference/runner/kv_cache_manager.py index cba70ffaf..5e4d0af54 100644 --- a/tpu_inference/runner/kv_cache_manager.py +++ b/tpu_inference/runner/kv_cache_manager.py @@ -19,9 +19,8 @@ from tpu_inference import utils from tpu_inference import utils as common_utils -from tpu_inference.distributed.offload.utils import \ - get_kv_connector_cache_layout from tpu_inference.logger import init_logger +from tpu_inference.offload.utils import get_kv_connector_cache_layout from tpu_inference.runner import utils as runner_utils from tpu_inference.runner.input_batch import CachedRequestState, InputBatch from tpu_inference.runner.kv_cache import create_kv_caches diff --git a/tpu_inference/worker/tpu_worker.py b/tpu_inference/worker/tpu_worker.py index 2be5a8527..85c8ba083 100644 --- a/tpu_inference/worker/tpu_worker.py +++ b/tpu_inference/worker/tpu_worker.py @@ -292,7 +292,8 @@ def determine_available_memory(self) -> int: if self.vllm_config.kv_transfer_config is not None: kv_transfer_config = self.vllm_config.kv_transfer_config - if kv_transfer_config.kv_connector == "TPUOffloadConnector" and kv_transfer_config.kv_connector_module_path == "tpu_inference.distributed.offload.tpu_offload_connector": + if kv_transfer_config.kv_connector == "TPUOffloadConnector" and \ + kv_transfer_config.kv_connector_module_path == "tpu_inference.offload.tpu_offload_connector": # If kv offloading is enabled, we need to account for the memory used by the KV transfer buffer. 
staging_buffer_pages = envs.TPU_OFFLOAD_NUM_STAGING_BLOCKS From b5a294714f99de41dc23a9b2a1a66055696632bc Mon Sep 17 00:00:00 2001 From: dannawang Date: Tue, 16 Dec 2025 05:06:13 +0000 Subject: [PATCH 29/29] Update benchmark pods Signed-off-by: dannawang --- .../gke/benchmarks/deploy-baseline.yaml | 2 +- .../gke/benchmarks/deploy-cpu-offload.yaml | 23 +++++++++++++++---- tpu_inference/envs.py | 4 ++++ .../offload/tpu_offload_connector.py | 3 ++- 4 files changed, 25 insertions(+), 7 deletions(-) diff --git a/examples/offload/gke/benchmarks/deploy-baseline.yaml b/examples/offload/gke/benchmarks/deploy-baseline.yaml index 048d2fa77..fdaab3147 100644 --- a/examples/offload/gke/benchmarks/deploy-baseline.yaml +++ b/examples/offload/gke/benchmarks/deploy-baseline.yaml @@ -29,7 +29,7 @@ spec: name: hf-token-secret key: token - name: SKIP_JAX_PRECOMPILE - value: "1" + value: "0" ports: - containerPort: 8000 resources: diff --git a/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml b/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml index fa4d03a8a..0996a4122 100644 --- a/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml +++ b/examples/offload/gke/benchmarks/deploy-cpu-offload.yaml @@ -16,11 +16,24 @@ spec: cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice cloud.google.com/gke-tpu-topology: 2x4 # Specify the physical topology for the TPU slice. initContainers: - - name: increase-vm-max-map-count + - name: tpu-node-setup image: busybox - # WARNING: This changes the HOST memory settings (vm.max_map_count), not just the container. - # Required to prevent vLLM crashes due to memory mapping limits. - command: ["sysctl", "-w", "vm.max_map_count=1048576"] + command: ["/bin/sh", "-c"] + args: + - | + # WARNING: This changes the HOST memory settings, not just the container. + # Required to prevent vLLM crashes due to memory mapping limits. + sysctl -w vm.max_map_count=8388608 + + # Check if the VFIO IOMMU module parameter exists, and if so, increase the + # limit on DMA mappings. This allows the TPU driver to pin and map a + # larger number of memory pages for direct hardware access. + if [ -f /sys/module/vfio_iommu_type1/parameters/dma_entry_limit ]; then + echo 2000000 > /sys/module/vfio_iommu_type1/parameters/dma_entry_limit + echo "Successfully increased dma_entry_limit to 2000000" + else + echo "Warning: vfio_iommu_type1 module parameter not found. Ensure the module is loaded." 
+ fi securityContext: privileged: true containers: @@ -37,7 +50,7 @@ spec: name: hf-token-secret key: token - name: SKIP_JAX_PRECOMPILE - value: "1" + value: "0" - name: TPU_OFFLOAD_NUM_CPU_CHUNKS value: "4096" - name: TPU_OFFLOAD_NUM_STAGING_BLOCKS diff --git a/tpu_inference/envs.py b/tpu_inference/envs.py index 75e95cd59..546e81875 100644 --- a/tpu_inference/envs.py +++ b/tpu_inference/envs.py @@ -29,6 +29,7 @@ TPU_OFFLOAD_DECODE_SAVE: bool = False TPU_OFFLOAD_NUM_CPU_CHUNKS: int = 1024 TPU_OFFLOAD_NUM_STAGING_BLOCKS: int = 128 + TPU_OFFLOAD_SAVE_THREADS: int = 1 def env_with_choices( @@ -142,6 +143,9 @@ def _get_validated_env() -> str | None: # kv offload to dram: size of staging buffer (hbm) for swap "TPU_OFFLOAD_NUM_STAGING_BLOCKS": lambda: int(os.getenv("TPU_OFFLOAD_NUM_STAGING_BLOCKS", "128")), + # kv offload to dram: number of threads for asynchronous TPU -> CPU data transfer + "TPU_OFFLOAD_SAVE_THREADS": + lambda: int(os.getenv("TPU_OFFLOAD_SAVE_THREADS", "1")), } diff --git a/tpu_inference/offload/tpu_offload_connector.py b/tpu_inference/offload/tpu_offload_connector.py index cdeb55a9e..407d9eab6 100644 --- a/tpu_inference/offload/tpu_offload_connector.py +++ b/tpu_inference/offload/tpu_offload_connector.py @@ -1219,8 +1219,9 @@ def __init__(self, vllm_config: VllmConfig, self.cpu_chunk_size = self.block_size # Thread pool for asynchronous TPU->CPU copies + self.num_save_threads = envs.TPU_OFFLOAD_SAVE_THREADS self.save_executor = ThreadPoolExecutor( - max_workers=4, thread_name_prefix="tpu_save_handler") + max_workers=self.num_save_threads, thread_name_prefix="tpu_save_handler") self.finished_save_reqs: set[ReqId] = set() self.finished_load_reqs: set[ReqId] = set() # Tracks if wait_for_save has been called for the current step's metadata.
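For operators tuning the save-thread knob introduced in this last patch, TPU_OFFLOAD_SAVE_THREADS is read from the environment like the other TPU_OFFLOAD_* settings, so it can be set in the vLLM container's env list in deploy-cpu-offload.yaml. The snippet below is only an illustrative sketch: the value "4" is an assumed example, not a recommended default, and the connector falls back to a single save thread when the variable is unset.

```yaml
# Hypothetical addition to the env section of deploy-cpu-offload.yaml:
# widen the asynchronous TPU -> CPU save pool used by the offload connector.
- name: TPU_OFFLOAD_SAVE_THREADS
  value: "4"   # assumed example; envs.py defaults this to 1 when unset
```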