From 22d58b7cf7b7a9a7e3dd424e03257e5f21a70341 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Wed, 13 May 2026 10:49:42 +0530 Subject: [PATCH 01/21] Contrib: add Qwen3.6-27B vLLM APC baseline --- contrib/models/Qwen3.6-27B/README.md | 331 ++ .../scripts/openai_compat_server.py | 383 ++ contrib/models/Qwen3.6-27B/src/__init__.py | 41 + .../models/Qwen3.6-27B/src/modeling_qwen35.py | 3245 +++++++++++++++++ .../Qwen3.6-27B/src/modeling_qwen35_vision.py | 819 +++++ .../Qwen3.6-27B/src/modeling_qwen35_vl.py | 662 ++++ .../Qwen3.6-27B/src/nki_kernels/__init__.py | 10 + .../src/nki_kernels/nki_deltanet.py | 334 ++ .../src/nki_kernels/nki_deltanet_chunked.py | 546 +++ .../src/nki_kernels/nki_deltanet_fused.py | 595 +++ contrib/models/Qwen3.6-27B/test/__init__.py | 0 .../Qwen3.6-27B/test/integration/__init__.py | 0 .../integration/qwen36_27b_compile_fp8.py | 288 ++ .../test/integration/test_model.py | 605 +++ .../models/Qwen3.6-27B/test/unit/__init__.py | 0 .../Qwen3.6-27B/test/unit/test_config.py | 201 + .../test/unit/test_deltanet_decay.py | 68 + .../test/unit/test_hybrid_cache_manager.py | 314 ++ .../test/unit/test_weight_conversion.py | 436 +++ contrib/models/Qwen3.6-27B/vllm/README.md | 262 ++ .../Qwen3.6-27B/vllm/hf_qwen35_config.py | 68 + .../Qwen3.6-27B/vllm/install_qwen36_vllm.sh | 61 + .../Qwen3.6-27B/vllm/patch_nxdi_registry.py | 71 + .../Qwen3.6-27B/vllm/qwen36_chat_proxy.py | 182 + .../Qwen3.6-27B/vllm/run_offline_inference.py | 159 + .../models/Qwen3.6-27B/vllm/serve_qwen36.py | 21 + .../models/Qwen3.6-27B/vllm/sitecustomize.py | 9 + .../Qwen3.6-27B/vllm/start_vllm_server.sh | 147 + 28 files changed, 9858 insertions(+) create mode 100644 contrib/models/Qwen3.6-27B/README.md create mode 100644 contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py create mode 100644 contrib/models/Qwen3.6-27B/src/__init__.py create mode 100644 contrib/models/Qwen3.6-27B/src/modeling_qwen35.py create mode 100644 contrib/models/Qwen3.6-27B/src/modeling_qwen35_vision.py 
create mode 100644 contrib/models/Qwen3.6-27B/src/modeling_qwen35_vl.py create mode 100644 contrib/models/Qwen3.6-27B/src/nki_kernels/__init__.py create mode 100644 contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet.py create mode 100644 contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_chunked.py create mode 100644 contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py create mode 100644 contrib/models/Qwen3.6-27B/test/__init__.py create mode 100644 contrib/models/Qwen3.6-27B/test/integration/__init__.py create mode 100644 contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py create mode 100644 contrib/models/Qwen3.6-27B/test/integration/test_model.py create mode 100644 contrib/models/Qwen3.6-27B/test/unit/__init__.py create mode 100644 contrib/models/Qwen3.6-27B/test/unit/test_config.py create mode 100644 contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py create mode 100644 contrib/models/Qwen3.6-27B/test/unit/test_hybrid_cache_manager.py create mode 100644 contrib/models/Qwen3.6-27B/test/unit/test_weight_conversion.py create mode 100644 contrib/models/Qwen3.6-27B/vllm/README.md create mode 100644 contrib/models/Qwen3.6-27B/vllm/hf_qwen35_config.py create mode 100755 contrib/models/Qwen3.6-27B/vllm/install_qwen36_vllm.sh create mode 100644 contrib/models/Qwen3.6-27B/vllm/patch_nxdi_registry.py create mode 100644 contrib/models/Qwen3.6-27B/vllm/qwen36_chat_proxy.py create mode 100644 contrib/models/Qwen3.6-27B/vllm/run_offline_inference.py create mode 100644 contrib/models/Qwen3.6-27B/vllm/serve_qwen36.py create mode 100644 contrib/models/Qwen3.6-27B/vllm/sitecustomize.py create mode 100755 contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh diff --git a/contrib/models/Qwen3.6-27B/README.md b/contrib/models/Qwen3.6-27B/README.md new file mode 100644 index 00000000..cdba94ba --- /dev/null +++ b/contrib/models/Qwen3.6-27B/README.md @@ -0,0 +1,331 @@ +# Contrib Model: Qwen3.6-27B + +NeuronX Distributed Inference 
implementation of Qwen3.6-27B, a 27B parameter dense model from Alibaba Cloud with a hybrid DeltaNet + GQA attention architecture. + +## Relationship to Qwen3.5-27B + +Qwen3.6-27B is a **post-training update** of Qwen3.5-27B with improved agentic coding and thinking preservation. The models share **identical architecture** (`qwen3_5` model_type, `Qwen3_5ForConditionalGeneration`) -- only weights differ. This contrib reuses the same NxDI implementation as [Qwen3.5-27B](../Qwen3.5-27B/) (PR #128). Any code updates to Qwen3.5-27B should be propagated to this contrib and vice versa. + +### Config differences from Qwen3.5-27B + +| Field | Value | Impact | +|-------|-------|--------| +| `output_gate_type` | `"swish"` | **Ignored** -- not used by HF transformers or NxDI (gate uses sigmoid) | +| `language_model_only` | `false` | Informational, not used by model code | +| `bos_token_id` | `248044` | New but not architecture-relevant | +| `pad_token_id` | `null` | New at text_config level (already handled) | +| `partial_rotary_factor` | `0.25` | Already in rope_parameters, redundant copy | +| `transformers_version` | `4.57.1` | Updated from `4.57.0.dev0` | + +No architecture changes are required relative to the Qwen3.5-27B hybrid +implementation. This contrib packages the NxDI Qwen3.6-27B model code, +DeltaNet NKI kernels, FP8/vLLM serving helpers, and validation coverage for the +Qwen3.6 weights. 
+ +## Model Family + +| Model | HuggingFace ID | Params | Instance | +|-------|----------------|--------|----------| +| **Qwen3.6-27B** | `Qwen/Qwen3.6-27B` | 27B | trn2.3xlarge (TP=4) | + +**License:** Apache 2.0 + +## Architecture Details + +| Feature | Value | +|---------|-------| +| Layers | 64 (48 DeltaNet + 16 GQA) | +| Layer Pattern | [3 DeltaNet + 1 GQA] x 16 | +| Hidden Size | 5120 | +| GQA Attention | 24 heads, 4 KV heads, head_dim=256 | +| DeltaNet Attention | 48 value heads, 16 key heads, k_dim=v_dim=128 | +| Dense MLP | SwiGLU (gate_proj + up_proj: 5120 -> 17408, down_proj: 17408 -> 5120) | +| Position Encoding | Partial RoPE (25% of head_dim = 64 dims), mRoPE for VL | +| Vocabulary | 248,320 | +| Normalization | RMSNorm with +1 weight convention | +| Activation | SiLU gated MLP | + +### Unique Architecture Features + +- **Hybrid DeltaNet + GQA:** 48 of 64 layers use Gated DeltaNet (linear recurrent attention), 16 layers use standard GQA with KV cache. The pattern repeats every 4 layers: 3 DeltaNet + 1 GQA. +- **DeltaNet Linear Attention:** Uses the delta rule for recurrent state updates with gated decay. Per-step: `state *= exp(g); delta = (v - state^T @ k) * beta; state += outer(k, delta); output = state^T @ q`. Runs as a chunked algorithm for context encoding, per-token recurrence for token generation. +- **Custom NKI Kernels:** Three NKI kernels implement the DeltaNet forward pass on Neuron: a per-token recurrent kernel (TKG), a per-chunk kernel (legacy), and a fused single-kernel chunked forward (CTE). The fused kernel uses a Neumann series for intra-chunk correction with state persistence in SBUF across chunks. +- **GQA Output Gate:** Attention layers use a sigmoid output gate. `q_proj` is 2x sized and interleaved: `[head0_query | head0_gate | head1_query | ...]`. The gate is split during weight conversion and applied after attention. +- **Partial RoPE:** Only 25% of head_dim (64 of 256 dimensions) receives rotary embeddings. 
The remaining 192 dimensions are identity (no rotation). +- **+1 RMSNorm Convention:** HF weights use `output = norm(x) * (1 + weight)` where weight is initialized to zeros. Converted to standard `output = norm(x) * weight` during loading by adding 1.0 to all RMSNorm weights (except DeltaNet internal norms, which use standard convention). +- **Vision-Language Support:** Optional ViT encoder runs on CPU (HBM fully consumed by 27B text decoder). Vision embeddings are injected via a scatter mask at traced input positions. + +## Test Results + +### Unit Tests (CPU) + +| Test Module | Tests | Status | +|-------------|-------|--------| +| test_config.py | 26 | 26/26 PASS | +| test_weight_conversion.py | 16 | 16/16 PASS | +| **Total** | **42** | **42/42 PASS** | + +Unit tests are architecture-level and do not depend on weights. Identical results to Qwen3.5-27B. + +### Quality Validation (Qwen3.6-27B, trn2.3xlarge, TP=4, SDK 2.29) + +7/7 text-only quality tests passed with `enable_thinking=False`: + +| Test | Expected | Result | +|------|----------|--------| +| Speed of light | 299,792,458 m/s | PASS | +| 17 * 23 | 391 | PASS | +| 60mph * 2.5h | 150 miles | PASS | +| is_prime function | Correct Python | PASS | +| French translation | Bonjour, comment allez-vous ? 
| PASS | +| Capital of Japan | Tokyo | PASS | +| sqrt(144) | 12 | PASS | + +## Performance Benchmarks + +### Qwen3.6-27B on trn2.3xlarge (TP=4, LNC=2, SDK 2.29, BF16) + +**TTFT (Time To First Token)** + +| Input Length | P50 (ms) | P95 (ms) | +|-------------|----------|----------| +| 16 tokens | 305.3 | 305.6 | +| 64 tokens | 305.4 | 305.9 | +| 128 tokens | 306.6 | 306.8 | +| 256 tokens | 306.2 | 306.3 | + +**TPOT / Throughput** + +| Output Length | TPOT P50 (ms) | tok/s P50 | E2E P50 (ms) | +|--------------|---------------|-----------|---------------| +| 16 | 54.3 | 18.4 | 1,121 | +| 32 | 54.4 | 18.4 | 1,993 | +| 64 | 54.2 | 18.5 | 3,720 | +| 128 | 54.2 | 18.5 | 4,912 | + +### Comparison with Qwen3.5-27B + +| Metric | Qwen3.5-27B | Qwen3.6-27B | Delta | +|--------|------------|------------|-------| +| TPOT P50 | 53 ms | 54.2 ms | +2.3% | +| Throughput | 18.9 tok/s | 18.5 tok/s | -2.1% | +| TTFT (128 tok) | 576 ms | 306.6 ms | -47% * | + +\* TTFT improvement is due to compilation config differences (256-token bucket vs 128-token bucket), not model differences. Architectural performance is equivalent. + +### Long-Context vLLM Baseline + +A 128K FP8-MLP artifact was validated on trn2.3xlarge (TP=4, LNC=2, SDK 2.29) +with the vLLM Neuron plugin, Qwen chunked prefill, and native vLLM APC enabled. 
+ +| Metric | Result | +|--------|--------| +| Max model length | 131,072 tokens | +| Context encoding bucket | 512 | +| Prefill throughput | 404-428 tok/s from 512 through 64K prompt tokens | +| Decode throughput | 26.3-26.6 tok/s | +| 64K quality | needle retrieval prompts returned all expected codes | +| State reset | repeated short-after-long validation passed after 32K and 64K requests | +| Peak Neuron device memory | ~53.25 GB decimal during the 64K eval | + +Native vLLM prefix caching/APC was also validated with exact greedy output +matches: + +| APC Scenario | Cold | Warm | Speedup | Result | +|--------------|------|------|---------|--------| +| Server exact-repeat, ~10.8K prompt tokens | 26.68s | 1.67s | 16.0x | exact text match | +| Offline exact-repeat | 26.19s | 2.38s | 11.0x | exact token-ID match | +| Offline partial-prefix reuse | 25.52s | 1.70s | 15.0x | exact token-ID match | +| Server cross-prefix reuse | 25.17s | 1.36s | 18.5x | exact text match | + +### Key Observations + +- **BF16 TP=4 is HBM-limited:** The pure BF16 path is limited to short contexts on trn2.3xlarge. The validated 128K baseline uses MLP-only FP8 weights plus the hybrid cache manager. +- **DeltaNet enables efficient TKG:** Token generation uses O(1) per-token recurrence instead of O(n) KV cache attention for 48/64 layers. +- **vLLM APC is high leverage:** Repeated-prefix requests avoid replaying long chunked prefill and are the largest observed latency win for chat/RAG-style workloads. +- **Performance equivalent to Qwen3.5-27B:** The BF16 TPOT difference is within measurement noise. Expected since architectures are identical. 
+ +## Usage + +### Text-Only (trn2.3xlarge, TP=4) + +```python +import json +import torch +from transformers import AutoTokenizer, GenerationConfig +from neuronx_distributed_inference.models.config import NeuronConfig, OnDeviceSamplingConfig +from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter + +from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM + +model_path = "/path/to/Qwen3.6-27B" +compiled_path = "/scratch/qwen36_traced/" + +neuron_config = NeuronConfig( + tp_degree=4, + batch_size=1, + ctx_batch_size=1, + tkg_batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + logical_nc_config=2, + enable_bucketing=False, + flash_decoding_enabled=False, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + save_sharded_checkpoint=True, +) + +# Read config.json directly (model_type 'qwen3_5' may not be +# registered in all transformers versions) +import os +with open(os.path.join(model_path, "config.json")) as f: + hf_config = json.load(f) +text_config = hf_config.get("text_config", hf_config) +config_dict = dict(text_config) +config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044) +config_dict.setdefault("tie_word_embeddings", False) + +config = Qwen35InferenceConfig( + neuron_config=neuron_config, + **config_dict, +) + +model = NeuronQwen35ForCausalLM(model_path, config) +model.compile(compiled_path) + +# Reload from compiled artifacts +model = NeuronQwen35ForCausalLM(compiled_path) +model.load(compiled_path) + +tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right") +gen_config = GenerationConfig( + do_sample=True, top_k=1, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, +) + +inputs = tokenizer("The capital of France is", return_tensors="pt") +gen_model = HuggingFaceGenerationAdapter(model) +outputs = gen_model.generate( + inputs.input_ids, + generation_config=gen_config, + attention_mask=inputs.attention_mask, + 
max_new_tokens=50, +) +print(tokenizer.decode(outputs[0], skip_special_tokens=True)) +``` + +### Vision-Language (trn2.3xlarge, TP=4) + +The VL pipeline uses the text decoder on Neuron and the vision encoder on CPU: + +```python +from src.modeling_qwen35_vl import NeuronQwen35VLForCausalLM, Qwen35VLInferenceConfig + +vl_model = NeuronQwen35VLForCausalLM( + model_path="/path/to/Qwen3.6-27B", + config=vl_config, +) +vl_model.compile(compiled_path) +vl_model.load(compiled_path) + +# See test/integration/test_model.py for full VL usage example +``` + +### DeltaNet Kernel Selection + +The DeltaNet forward path can be controlled via environment variables: + +| Env Var | Forward Path | Use Case | +|---------|-------------|----------| +| `USE_NKI_FUSED=1` | Fused chunked NKI kernel | Best CTE performance (default for SDK 2.29) | +| `USE_NKI_CHUNKED=1` | Per-chunk NKI kernel | Legacy, superseded by fused | +| `USE_NKI=1` | Per-token NKI kernel | TKG (always used for token generation) | +| `DELTANET_SEQUENTIAL=1` | Sequential PyTorch | Debugging/reference | +| *(none)* | PyTorch chunked | Default fallback for CTE | + +## Caveats + +1. **BF16 HBM pressure at TP=4:** The pure BF16 model consumes nearly all HBM on trn2.3xlarge. Use the FP8/vLLM path for the validated 128K artifact, or a larger instance for additional batching/headroom. + +2. **SDK 2.29+ required:** The NKI DeltaNet kernels require NKI 0.3.0 (SDK 2.29). No library modifications needed -- runs on stock SDK 2.29 DLAMI (`/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/`). + +3. **No mini model test:** Unlike DeepSeek-V3, a mini model cannot be provided because DeltaNet layers require NKI kernels that only execute on Neuron devices. Integration tests require a trn2 instance with the full 27B weights. + +4. **Vision encoder runs on CPU:** The ViT cannot be placed on Neuron because HBM is fully consumed by the text decoder. This adds ~918ms latency per image. 
Future optimization: quantize the text decoder to free HBM, or use a larger instance. + +5. **Compilation time:** The short-context BF16 path compiles in roughly 13 minutes. The validated 128K FP8/vLLM artifact takes longer because it includes long-context cache shapes and presharded checkpoints. + +6. **+1 RMSNorm convention:** Qwen3.5/3.6 uses `output = norm(x) * (1 + weight)` for most RMSNorm layers, but DeltaNet internal norms use standard `output = norm(x) * weight`. The weight conversion handles this automatically, but custom weight loading must be aware of both conventions. + +7. **DeltaNet numerical stability:** DeltaNet kernels rely on normalized Q/K inputs and bounded decay handling. The chunked path includes regression coverage for decay handling; changes to the fused kernel should be validated against the CPU reference and long-context stress prompts. + +8. **Shared codebase with Qwen3.5-27B:** This contrib uses the same `Qwen35*` class names and `modeling_qwen35*.py` filenames as the [Qwen3.5-27B contrib](../Qwen3.5-27B/). This is intentional -- both models share the `qwen3_5` model_type. The code is identical; only the HuggingFace model ID and weights differ. + +## Maximum Sequence Length + +| seq_len | Path | Status | Notes | +|---------|------|--------|-------| +| 128 | BF16 NxDI | **PASS** | BF16 baseline/quality checks | +| 256 | BF16 NxDI | **PASS** | BF16 benchmark bucket | +| 512 | BF16 NxDI | **PASS** | 4 DeltaNet chunks | +| 65,536 | FP8/vLLM | **PASS** | chunked prefill, quality, and state-reset validation | +| 131,072 | FP8/vLLM | **PASS** | compiled and served with 512-token CTE bucket | + +For production long-context serving on trn2.3xlarge, use the FP8/vLLM artifact +and 512-token context encoding bucket. Larger instances are recommended for +larger batches or additional serving headroom. 
+ +## Compatibility Matrix + +| Instance | TP | LNC | Status | Notes | +|----------|-----|-----|--------|-------| +| trn2.3xlarge | 4 | 2 | **PASS** | BF16 short-context and FP8 128K vLLM/APC validated | +| trn2.12xlarge | 16 | 2 | Expected PASS | Untested, recommended for batching/headroom | + +### SDK Configuration + +| Component | Version | +|-----------|---------| +| NxDI | 0.9.17334 | +| neuronx-cc | 2.24.5133 | +| torch | 2.9.1 | +| transformers | 4.57.6 | +| NKI | 0.3.0 | +| NxDI venv | `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/` | + +## Testing + +### Unit Tests (CPU only, no device needed) + +```bash +cd contrib/models/Qwen3.6-27B/ +# On DLAMI: source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate +pytest test/unit/ -v +``` + +Tests: config parsing (26), weight conversion (16) = **42 tests**. This count covers `test_config.py` and `test_weight_conversion.py` only; `test/unit/` also contains `test_deltanet_decay.py` and `test_hybrid_cache_manager.py`, which this command collects as well. + +### Integration Tests (needs trn2.3xlarge with 4 NeuronCores) + +```bash +cd contrib/models/Qwen3.6-27B/ + +QWEN35_MODEL_PATH=/mnt/models/Qwen3.6-27B \ +QWEN35_COMPILED_PATH=/mnt/models/qwen36_traced \ +pytest test/integration/test_model.py --capture=tee-sys +``` + +Tests: model loads, generates, coherence, top-token valid, capital test, TTFT, throughput, multi-prompt = **8 tests**. + +Note: The env var is `QWEN35_MODEL_PATH` (not `QWEN36`) because the code uses the `qwen3_5` model_type internally. + +## Example Checkpoints + +- `Qwen/Qwen3.6-27B` (BF16, ~52 GB) + +## Maintainer + +AWS Neuron + +**Last Updated:** 2026-04-23 diff --git a/contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py b/contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py new file mode 100644 index 00000000..fe7f45d5 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/scripts/openai_compat_server.py @@ -0,0 +1,383 @@ +#!/usr/bin/env python3 +"""Minimal OpenAI-compatible HTTP server for the Qwen3.6-27B NxDI artifact. + +This intentionally avoids uvicorn/fastapi runtime dependencies so it can run in +the stock Neuron inference venv. 
It supports non-streaming: + - GET /health + - GET /v1/models + - POST /v1/completions + - POST /v1/chat/completions +""" + +import argparse +import json +import sys +import threading +import time +import traceback +import uuid +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from typing import Any, Dict, List + +import torch + + +def _json_response(handler: BaseHTTPRequestHandler, status: int, payload: Dict[str, Any]): + body = json.dumps(payload, ensure_ascii=False).encode("utf-8") + handler.send_response(status) + handler.send_header("Content-Type", "application/json") + handler.send_header("Content-Length", str(len(body))) + handler.send_header("Access-Control-Allow-Origin", "*") + handler.send_header("Access-Control-Allow-Headers", "authorization,content-type") + handler.send_header("Access-Control-Allow-Methods", "GET,POST,OPTIONS") + handler.end_headers() + handler.wfile.write(body) + + +def _error(handler: BaseHTTPRequestHandler, status: int, message: str): + _json_response( + handler, + status, + {"error": {"message": message, "type": "server_error", "code": status}}, + ) + + +def _first_text_prompt(prompt: Any) -> str: + if isinstance(prompt, str): + return prompt + if isinstance(prompt, list) and prompt: + return str(prompt[0]) + return str(prompt) + + +def _token_scalar(tokens: Any) -> int: + if hasattr(tokens, "detach"): + tokens = tokens.detach().cpu() + if tokens.ndim == 0: + return int(tokens.item()) + return int(tokens.reshape(-1)[0].item()) + + +class QwenOpenAIServer: + def __init__(self, args: argparse.Namespace): + self.args = args + self.model_id = args.model_id + self.lock = threading.Lock() + self._load_model() + + def _load_model(self): + if self.args.contrib_root not in sys.path: + sys.path.insert(0, self.args.contrib_root) + + from transformers import AutoTokenizer, GenerationConfig + from neuronx_distributed_inference.modules.generation.sampling import ( + prepare_sampling_params, + ) + from src.modeling_qwen35 
import NeuronQwen35ForCausalLM + + print("Loading tokenizer from", self.args.model_path, flush=True) + self.tokenizer = AutoTokenizer.from_pretrained( + self.args.model_path, + padding_side="right", + ) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + + print("Loading NxDI artifact from", self.args.compiled_path, flush=True) + t0 = time.perf_counter() + self.model = NeuronQwen35ForCausalLM(self.args.compiled_path) + self.model.load(self.args.compiled_path) + self.model.reset() + self.prepare_sampling_params = prepare_sampling_params + self.GenerationConfig = GenerationConfig + print(f"Model loaded in {time.perf_counter() - t0:.2f}s", flush=True) + + def _chat_prompt(self, messages: List[Dict[str, Any]], enable_thinking: bool = False) -> str: + try: + return self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=enable_thinking, + ) + except TypeError: + return self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + except Exception: + lines = [] + for msg in messages: + role = msg.get("role", "user") + content = msg.get("content", "") + lines.append(f"{role}: {content}") + lines.append("assistant:") + return "\n".join(lines) + + def _generate(self, prompt: str, body: Dict[str, Any]) -> Dict[str, Any]: + max_tokens = int(body.get("max_tokens", body.get("max_completion_tokens", 128)) or 128) + if max_tokens <= 0: + raise ValueError("max_tokens must be positive") + if max_tokens > self.args.max_new_tokens_limit: + raise ValueError( + f"max_tokens={max_tokens} exceeds server limit {self.args.max_new_tokens_limit}" + ) + + input_ids = torch.tensor( + [self.tokenizer(prompt, add_special_tokens=False).input_ids], + dtype=torch.long, + ) + prompt_tokens = int(input_ids.shape[1]) + if prompt_tokens <= 0: + raise ValueError("prompt must contain at least one token") + if prompt_tokens + max_tokens > self.args.seq_len: + raise 
ValueError( + f"prompt_tokens + max_tokens = {prompt_tokens + max_tokens} exceeds " + f"seq_len={self.args.seq_len}" + ) + + temperature = float(body.get("temperature", 0.0) or 0.0) + top_p = float(body.get("top_p", 1.0) or 1.0) + top_k = int(body.get("top_k", 1) or 1) + # NxDI's traced on-device sampler for this artifact uses do_sample=True. + # OpenAI temperature=0 means greedy, but passing literal 0 into that + # sampler divides logits by zero. top_k=1 with temperature=1 is the + # deterministic greedy path used by the validated HF adapter tests. + sampler_temperature = temperature + if temperature <= 0.0: + sampler_temperature = 1.0 + top_p = 1.0 + top_k = 1 + sampling_params = self.prepare_sampling_params( + batch_size=1, + top_k=[top_k], + top_p=[top_p], + temperature=[sampler_temperature], + ) + seq_ids = torch.tensor([0], dtype=torch.int32) + + with self.lock: + if hasattr(self.model, "reset"): + self.model.reset() + t0 = time.perf_counter() + first_token = None + for start in range(0, prompt_tokens, self.args.chunk_size): + end = min(start + self.args.chunk_size, prompt_tokens) + valid = end - start + chunk_ids = input_ids[:, start:end] + attention_mask = torch.ones((1, valid), dtype=torch.long) + position_ids = torch.arange( + start, + end, + dtype=torch.long, + ).unsqueeze(0) + + with torch.no_grad(): + out = self.model( + input_ids=chunk_ids, + attention_mask=attention_mask, + position_ids=position_ids, + seq_ids=seq_ids, + sampling_params=sampling_params, + return_dict=True, + ) + first_token = _token_scalar(out.tokens) + + if first_token is None: + raise RuntimeError("prefill produced no token") + + new_ids = [] + current_token = first_token + vocab_size = len(self.tokenizer) + raw_eos_id = self.tokenizer.eos_token_id + eos_ids = ( + set(raw_eos_id) + if isinstance(raw_eos_id, (list, tuple, set)) + else {raw_eos_id} + ) + decode_ids = torch.empty((1, 1), dtype=torch.int32) + decode_position_ids = torch.empty((1, 1), dtype=torch.int32) + 
decode_attention_mask = torch.ones( + (1, prompt_tokens + max_tokens), + dtype=torch.int32, + ) + finish_reason = "length" + with torch.no_grad(): + for step in range(max_tokens): + if current_token in eos_ids: + finish_reason = "stop" + break + if current_token < 0 or current_token >= vocab_size: + raise RuntimeError(f"model generated invalid token id: {current_token}") + new_ids.append(current_token) + if step == max_tokens - 1: + break + + pos_value = prompt_tokens + step + decode_ids[0, 0] = current_token + decode_position_ids[0, 0] = pos_value + active_attention_mask = decode_attention_mask[:, : pos_value + 1] + out = self.model( + input_ids=decode_ids, + attention_mask=active_attention_mask, + position_ids=decode_position_ids, + seq_ids=seq_ids, + sampling_params=sampling_params, + return_dict=True, + ) + current_token = _token_scalar(out.tokens) + elapsed = time.perf_counter() - t0 + + invalid = [tok for tok in new_ids if tok < 0 or tok >= vocab_size] + if invalid: + raise RuntimeError(f"model generated invalid token ids: {invalid[:8]}") + + text = self.tokenizer.decode(new_ids, skip_special_tokens=True) + for stop in body.get("stop") or []: + if isinstance(stop, str) and stop in text: + text = text.split(stop, 1)[0] + + return { + "text": text, + "prompt_tokens": prompt_tokens, + "completion_tokens": len(new_ids), + "elapsed": elapsed, + "tokens": new_ids, + "finish_reason": finish_reason, + } + + +def make_handler(server_state: QwenOpenAIServer): + class Handler(BaseHTTPRequestHandler): + protocol_version = "HTTP/1.1" + + def log_message(self, fmt, *args): + print(f"{self.address_string()} - {fmt % args}", flush=True) + + def do_OPTIONS(self): + _json_response(self, 200, {}) + + def do_GET(self): + if self.path == "/health": + _json_response(self, 200, {"status": "ok", "model": server_state.model_id}) + elif self.path == "/v1/models": + _json_response( + self, + 200, + { + "object": "list", + "data": [ + { + "id": server_state.model_id, + "object": 
"model", + "created": int(time.time()), + "owned_by": "local", + } + ], + }, + ) + else: + _error(self, 404, f"unknown route: {self.path}") + + def do_POST(self): + try: + length = int(self.headers.get("content-length", "0")) + body = json.loads(self.rfile.read(length).decode("utf-8") or "{}") + if body.get("stream"): + raise ValueError("stream=true is not supported by this minimal server yet") + + if self.path == "/v1/completions": + result = server_state._generate(_first_text_prompt(body.get("prompt", "")), body) + _json_response( + self, + 200, + { + "id": f"cmpl-{uuid.uuid4().hex}", + "object": "text_completion", + "created": int(time.time()), + "model": server_state.model_id, + "choices": [ + { + "index": 0, + "text": result["text"], + "finish_reason": result["finish_reason"], + } + ], + "usage": { + "prompt_tokens": result["prompt_tokens"], + "completion_tokens": result["completion_tokens"], + "total_tokens": result["prompt_tokens"] + + result["completion_tokens"], + }, + "x_latency_seconds": result["elapsed"], + }, + ) + elif self.path == "/v1/chat/completions": + messages = body.get("messages") or [] + if not isinstance(messages, list): + raise ValueError("messages must be a list") + result = server_state._generate( + server_state._chat_prompt( + messages, + enable_thinking=bool(body.get("enable_thinking", False)), + ), + body, + ) + _json_response( + self, + 200, + { + "id": f"chatcmpl-{uuid.uuid4().hex}", + "object": "chat.completion", + "created": int(time.time()), + "model": server_state.model_id, + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": result["text"], + }, + "finish_reason": result["finish_reason"], + } + ], + "usage": { + "prompt_tokens": result["prompt_tokens"], + "completion_tokens": result["completion_tokens"], + "total_tokens": result["prompt_tokens"] + + result["completion_tokens"], + }, + "x_latency_seconds": result["elapsed"], + }, + ) + else: + _error(self, 404, f"unknown route: {self.path}") + 
except Exception as exc: + traceback.print_exc() + _error(self, 500, str(exc)) + + return Handler + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--host", default="0.0.0.0") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--model-id", default="qwen3.6-27b-neuron") + parser.add_argument("--model-path", required=True) + parser.add_argument("--compiled-path", required=True) + parser.add_argument("--contrib-root", required=True) + parser.add_argument("--seq-len", type=int, default=65536) + parser.add_argument("--chunk-size", type=int, default=512) + parser.add_argument("--max-new-tokens-limit", type=int, default=512) + args = parser.parse_args() + + state = QwenOpenAIServer(args) + httpd = ThreadingHTTPServer((args.host, args.port), make_handler(state)) + print(f"Serving {args.model_id} on http://{args.host}:{args.port}", flush=True) + httpd.serve_forever() + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/contrib/models/Qwen3.6-27B/src/__init__.py b/contrib/models/Qwen3.6-27B/src/__init__.py new file mode 100644 index 00000000..7e79aa03 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/__init__.py @@ -0,0 +1,41 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from src.modeling_qwen35 import ( + NeuronGatedDeltaNet, + NeuronQwen35Attention, + NeuronQwen35DecoderLayer, + NeuronQwen35ForCausalLM, + NeuronQwen35Model, + Qwen35DecoderModelInstance, + Qwen35InferenceConfig, + Qwen35MLP, + Qwen35ModelWrapper, +) +from src.modeling_qwen35_vision import ( + NeuronQwen35VisionForImageEncoding, + NeuronQwen35VisionModel, +) +from src.modeling_qwen35_vl import ( + NeuronQwen35VLForCausalLM, + Qwen35VLInferenceConfig, +) + +__all__ = [ + # Text decoder + "NeuronGatedDeltaNet", + "NeuronQwen35Attention", + "NeuronQwen35DecoderLayer", + "NeuronQwen35ForCausalLM", + "NeuronQwen35Model", + "Qwen35DecoderModelInstance", + "Qwen35InferenceConfig", + "Qwen35MLP", + "Qwen35ModelWrapper", + # Vision encoder + "NeuronQwen35VisionForImageEncoding", + "NeuronQwen35VisionModel", + # Vision-language + "NeuronQwen35VLForCausalLM", + "Qwen35VLInferenceConfig", +] diff --git a/contrib/models/Qwen3.6-27B/src/modeling_qwen35.py b/contrib/models/Qwen3.6-27B/src/modeling_qwen35.py new file mode 100644 index 00000000..ed3c3f2c --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/modeling_qwen35.py @@ -0,0 +1,3245 @@ +""" +NxDI contrib: Qwen3.5-27B / Qwen3.6-27B (qwen3_5 -- dense model) + +Supports both Qwen3.5-27B and Qwen3.6-27B. These models share identical +architecture (qwen3_5 model_type). Qwen3.6-27B is a post-training update +with improved agentic coding and thinking preservation -- no architecture +changes, only weight differences. + +Hybrid DeltaNet + Standard Attention + Dense MLP architecture. +Adapted from Qwen3.5-35B-A3B (MoE) -- MoE removed, dense MLP added. 
+ +48 of 64 layers use Gated DeltaNet (linear recurrent attention) +16 of 64 layers use standard GQA with KV cache + output gate +All 64 layers use a dense SwiGLU MLP (intermediate_size=17408) + +Architecture details: +- DeltaNet layers: separate in_proj_{qkv, z, a, b}, causal conv1d on QKV, gated delta rule +- Attention layers: q_proj doubled (Q + gate), partial RoPE (25% of head_dim), sigmoid output gate +- Dense MLP: standard SwiGLU (gate_proj, up_proj, down_proj) -- no MoE, no router, no experts +- KV cache: NxDI KVCacheManager for attention layers; DeltaNet layers store recurrent+conv + state as nn.Parameter buffers and return dummy KV tuples + +Config compatibility notes: +- Qwen3.6-27B adds output_gate_type="swish" to text_config. This field is + unused by both HF transformers and this NxDI code (gate uses sigmoid, as + confirmed across transformers v4.57.6, v5.6.0, and GitHub main). Safe to ignore. +""" + +import gc +import math +import logging +import os +import sys +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm + +try: + from neuronxcc.nki._private_kernels.attention import attention_isa_kernel +except ImportError: + from neuronxcc.nki.kernels.attention import attention_isa_kernel + +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, +) +from neuronx_distributed.utils import cpu_mode + +try: + from nki import jit as nki_jit # NKI 0.3.0+ (SDK 2.29) +except ImportError: + from torch_neuronx.xla_impl.ops import nki_jit # NKI 0.2.x (SDK 2.28) +from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeRMSNorm + +from src.nki_kernels.nki_deltanet import 
logger = logging.getLogger(__name__)

try:
    _flash_fwd_call = nki_jit()(attention_isa_kernel)
except TypeError:
    # nki_jit() rejected the zero-argument call form -- presumably the `nki.jit`
    # variant was imported above (TODO confirm); use the torch_neuronx wrapper.
    from torch_neuronx.xla_impl.ops import nki_jit as _torch_xla_nki_jit

    _flash_fwd_call = _torch_xla_nki_jit()(attention_isa_kernel)

# Option B: Direct nkilib flash attention for head_dim > 128
USE_NKILIB_KERNEL = os.environ.get("USE_NKILIB_KERNEL", "0") == "1"

_nkilib_flash_attn = None
if USE_NKILIB_KERNEL:
    try:
        import neuronxcc.nki as _nki
        from neuronx_distributed_inference.modules.attention.attention_base import (
            peel_decorations as _peel_decorations,
            get_platform_target as _get_platform_target,
        )
        from neuronxcc.nki.compiler import (
            skip_middle_end_transformations as _skip_middle_end,
            enable_stack_allocator as _enable_stack_allocator,
        )

        import importlib

        # Location of the nkilib fork. The previous hard-coded path stays the
        # default; NKILIB_FORK_PATH overrides it on hosts with another layout.
        _fork_path = os.environ.get(
            "NKILIB_FORK_PATH", "/home/ubuntu/nki-library-fork/nkilib_src"
        )
        if os.path.isdir(_fork_path) and _fork_path not in sys.path:
            sys.path.insert(0, _fork_path)
        # Drop any already-imported nkilib modules so the import below is
        # resolved again against the updated sys.path.
        _to_remove = [_name for _name in sys.modules if _name.startswith("nkilib")]
        for _name in _to_remove:
            del sys.modules[_name]
        import nki.language as _stub_nl
        import neuronxcc.nki.language as _real_nl

        # Copy the listed names from nki.language onto neuronxcc.nki.language
        # when only the former defines them.
        for _attr in [
            "NKIObject",
            "float8_e4m3fn",
            "float8_e4m3fn_x4",
            "float8_e5m2_x4",
            "float4_e2m1fn_x4",
        ]:
            if not hasattr(_real_nl, _attr) and hasattr(_stub_nl, _attr):
                setattr(_real_nl, _attr, getattr(_stub_nl, _attr))
        from nkilib.core.attention.attention_cte import (
            attention_cte as _attention_cte_raw,
            _MAX_HEAD_DIM,
        )

        # Explicit raise instead of `assert`: the check must survive `python -O`.
        # The surrounding handler turns it into a warning, as before.
        if _MAX_HEAD_DIM != 256:
            raise RuntimeError(
                f"nkilib fork has _MAX_HEAD_DIM={_MAX_HEAD_DIM}, expected 256. "
                f"System nkilib may have been loaded instead of fork."
            )
        logger.info(
            f"Loaded nkilib attention_cte from fork (_MAX_HEAD_DIM={_MAX_HEAD_DIM})"
        )

        _raw_fn = _peel_decorations(_attention_cte_raw)
        _platform = _get_platform_target()
        _nkilib_flash_attn = _nki.jit(
            _raw_fn,
            mode="torchxla",
            platform_target=_platform,
            show_compiler_tb=True,
            debug_kernel=True,
        )
        _nkilib_flash_attn = _skip_middle_end(_nkilib_flash_attn)
        _nkilib_flash_attn = _enable_stack_allocator(
            _nkilib_flash_attn, log_level=logging.INFO
        )
        logger.info("Option B: nkilib flash attention loaded for head_dim > 128")
    except Exception as e:
        # Best effort: any failure disables Option B rather than aborting the
        # import. exc_info routes the traceback through the logger.
        logger.warning(
            f"Option B: Failed to load nkilib flash attention: {e}", exc_info=True
        )
        _nkilib_flash_attn = None

# Option A: Detect if patch_attn_kernel was imported
NKILIB_PATCH_ACTIVE = False
try:
    from importlib import import_module as _import_module

    _attn_mod = _import_module("neuronxcc.nki._pre_prod_kernels.attn_fwd")
    if hasattr(_attn_mod, "_original_attention_nki_kernel_adapter"):
        NKILIB_PATCH_ACTIVE = True
        logger.info("Option A detected: _pre_prod_kernels patched with nkilib kernel")
except Exception as exc:
    # Detection is optional; leave a debug trace instead of discarding the
    # reason silently.
    logger.debug("Option A detection skipped: %s", exc)
# ============================================================
# Newton-Raphson Refined RMSNorm
# ============================================================
USE_NEWTON_RMSNORM = os.environ.get("USE_NEWTON_RMSNORM") == "1"
USE_PYTHON_RMSNORM = os.environ.get("USE_PYTHON_RMSNORM") == "1"


class NewtonRMSNorm(nn.Module):
    """RMS normalisation whose reciprocal square root gets one Newton-Raphson polish.

    The computation runs in float32 and the result is cast back to the input
    dtype. When ``hidden_size`` is ``None`` the module has no learnable scale
    and ``weight`` is ``None``.
    """

    def __init__(self, hidden_size=None, eps=1e-6):
        super().__init__()
        has_scale = hidden_size is not None
        self.weight = nn.Parameter(torch.ones(hidden_size)) if has_scale else None
        self.hidden_size = hidden_size
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` over its last dimension."""
        in_dtype = hidden_states.dtype
        x32 = hidden_states.to(torch.float32)
        shifted = x32.pow(2).mean(-1, keepdim=True) + self.variance_epsilon
        inv_rms = torch.rsqrt(shifted)
        # One Newton-Raphson iteration for 1/sqrt(s): y <- y * (3 - s*y*y) / 2.
        inv_rms = inv_rms * (3.0 - shifted * inv_rms * inv_rms) * 0.5
        normed = x32 * inv_rms
        if self.weight is None:
            return normed.to(in_dtype)
        return (normed * self.weight.float()).to(in_dtype)


def get_rmsnorm_cls():
    """Pick the RMSNorm class for the current runtime and environment flags.

    CPU mode or ``USE_PYTHON_RMSNORM=1`` selects ``Qwen3MoeRMSNorm``; otherwise
    ``USE_NEWTON_RMSNORM=1`` selects ``NewtonRMSNorm`` and the fallback is
    ``CustomRMSNorm``.
    """
    wants_python_impl = cpu_mode() or USE_PYTHON_RMSNORM
    if wants_python_impl:
        return Qwen3MoeRMSNorm
    if USE_NEWTON_RMSNORM:
        return NewtonRMSNorm
    return CustomRMSNorm


def l2norm(x, dim=-1, eps=1e-6):
    """Return ``x`` scaled to unit L2 norm along ``dim`` via ``F.normalize``."""
    unit = F.normalize(x, p=2, dim=dim, eps=eps)
    return unit
def __init__(self, config, layer_idx: int) -> None:
    """Build the projections, sharded parameter containers and state buffers.

    Args:
        config: Model config. Reads ``hidden_size``, ``linear_num_value_heads``,
            ``linear_num_key_heads``, ``linear_key_head_dim``,
            ``linear_value_head_dim``, ``linear_conv_kernel_dim``,
            ``rms_norm_eps``, ``head_dim``, ``num_key_value_heads`` and, from
            ``config.neuron_config``, ``tp_degree``, ``torch_dtype`` plus the
            optional ``max_batch_size`` / ``batch_size`` (both default to 1).
            The optional flags ``use_hybrid_cache_manager``,
            ``use_qwen_hybrid_chunked_prefill`` and
            ``use_qwen_hybrid_chunked_prefill_nki`` default to ``False``.
        layer_idx: Index of this layer; stored as ``self.layer_idx``.

    Raises:
        ValueError: If the value-head or key-head count is not divisible by
            ``tp_degree``.
    """
    super().__init__()
    tc = config

    self.hidden_size = tc.hidden_size  # 5120
    self.tp_degree = tc.neuron_config.tp_degree
    self.global_num_v_heads = tc.linear_num_value_heads  # 48
    self.global_num_k_heads = tc.linear_num_key_heads  # 16
    self.head_k_dim = tc.linear_key_head_dim  # 128
    self.head_v_dim = tc.linear_value_head_dim  # 128
    # Heads are split evenly across tensor-parallel ranks, so both head
    # counts must be exact multiples of tp_degree.
    if self.global_num_v_heads % self.tp_degree != 0:
        raise ValueError(
            f"linear_num_value_heads={self.global_num_v_heads} must be divisible "
            f"by tp_degree={self.tp_degree}"
        )
    if self.global_num_k_heads % self.tp_degree != 0:
        raise ValueError(
            f"linear_num_key_heads={self.global_num_k_heads} must be divisible "
            f"by tp_degree={self.tp_degree}"
        )
    # "global_*" values describe the full model; the unprefixed ones are the
    # per-rank share after dividing the head counts by tp_degree.
    self.num_v_heads = self.global_num_v_heads // self.tp_degree
    self.num_k_heads = self.global_num_k_heads // self.tp_degree
    self.global_key_dim = self.head_k_dim * self.global_num_k_heads  # 2048
    self.global_value_dim = self.head_v_dim * self.global_num_v_heads  # 6144
    self.key_dim = self.head_k_dim * self.num_k_heads  # 512 at TP=4
    self.value_dim = self.head_v_dim * self.num_v_heads  # 1536 at TP=4
    self.conv_kernel_size = tc.linear_conv_kernel_dim  # 4
    self.layer_idx = layer_idx
    self.rms_norm_eps = tc.rms_norm_eps
    # Optional feature flags; absent attributes mean "disabled".
    self.use_hybrid_cache_manager = getattr(tc, "use_hybrid_cache_manager", False)
    self.use_qwen_hybrid_chunked_prefill = getattr(
        tc, "use_qwen_hybrid_chunked_prefill", False
    )
    self.use_qwen_hybrid_chunked_prefill_nki = getattr(
        tc, "use_qwen_hybrid_chunked_prefill_nki", False
    )

    # KV cache dummy shape info
    self.head_dim = tc.head_dim  # 256
    tp_degree = tc.neuron_config.tp_degree
    raw_kv_heads = tc.num_key_value_heads
    # With fewer KV heads than ranks, the count is raised to tp_degree so
    # that each rank ends up with exactly one head.
    if raw_kv_heads < tp_degree:
        replicated_kv_heads = tp_degree
    else:
        replicated_kv_heads = raw_kv_heads
    self.kv_heads_per_rank = replicated_kv_heads // tp_degree

    # Conv1d on concatenated QKV (NOT Z). Store the depthwise kernel in a
    # ColumnParallelLinear parameter container so NxD's checkpoint sharder
    # can split it by output channel. Forward still uses it as Conv1d
    # weight after unsqueezing the singleton input-channel dimension.
    self.global_conv_dim = self.global_key_dim * 2 + self.global_value_dim  # 10240
    self.conv_dim = self.key_dim * 2 + self.value_dim  # 2560 at TP=4
    self.conv1d_weight = ColumnParallelLinear(
        self.conv_kernel_size,
        self.global_conv_dim,
        bias=False,
        gather_output=False,
    )

    # Input/output projections are the large DeltaNet tensors. Shard them
    # with tensor parallelism; convert_qwen35_hf_to_neuron_state_dict()
    # reorders in_proj_qkv into per-rank [Q_local | K_local | V_local]
    # blocks before NxD slices the output dimension.
    self.in_proj_qkv = ColumnParallelLinear(
        self.hidden_size,
        self.global_key_dim * 2 + self.global_value_dim,
        bias=False,
        gather_output=False,
    )
    self.in_proj_z = ColumnParallelLinear(
        self.hidden_size,
        self.global_value_dim,
        bias=False,
        gather_output=False,
    )
    self.in_proj_b = ColumnParallelLinear(
        self.hidden_size,
        self.global_num_v_heads,
        bias=False,
        gather_output=False,
    )
    self.in_proj_a = ColumnParallelLinear(
        self.hidden_size,
        self.global_num_v_heads,
        bias=False,
        gather_output=False,
    )

    # Same parameter-container pattern for per-value-head decay vectors.
    # These are used as vectors in forward but sharded by output dim during
    # checkpoint conversion/loading.
    self.dt_bias_weight = ColumnParallelLinear(
        1,
        self.global_num_v_heads,
        bias=False,
        gather_output=False,
    )
    self.A_log_weight = ColumnParallelLinear(
        1,
        self.global_num_v_heads,
        bias=False,
        gather_output=False,
    )

    # Output norm and projection
    self.norm = Qwen3MoeRMSNorm(self.head_v_dim, eps=self.rms_norm_eps)
    self.out_proj = RowParallelLinear(
        self.global_value_dim,
        self.hidden_size,
        bias=False,
        input_is_parallel=True,
    )

    # State buffers for CTE -> TKG carry-over
    # Buffers are sized by max_batch_size and created as zero-filled,
    # non-trainable parameters in the configured torch dtype.
    alloc_batch_size = getattr(config.neuron_config, "max_batch_size", 1)
    self._phase_batch_size = getattr(config.neuron_config, "batch_size", 1)
    # Shape: (alloc_batch_size, num_v_heads, head_k_dim, head_v_dim).
    self.recurrent_state_buffer = nn.Parameter(
        torch.zeros(
            alloc_batch_size,
            self.num_v_heads,
            self.head_k_dim,
            self.head_v_dim,
            dtype=config.neuron_config.torch_dtype,
        ),
        requires_grad=False,
    )
    # Shape: (alloc_batch_size, conv_dim, conv_kernel_size - 1).
    self.conv_state_buffer = nn.Parameter(
        torch.zeros(
            alloc_batch_size,
            self.conv_dim,
            self.conv_kernel_size - 1,
            dtype=config.neuron_config.torch_dtype,
        ),
        requires_grad=False,
    )
def _nki_recurrent_forward(self, query, key, value, g, beta):
    """Run the per-token NKI recurrence over a whole sequence (context encoding).

    Query and key are L2-normalised and the query is scaled by
    ``1/sqrt(k_dim)``. Batch and head axes are merged, and the kernel is
    invoked once per merged row with ``g`` and ``beta`` broadcast across the
    value dimension.

    Returns:
        ``(output, final_state)`` reshaped to ``(B, H, S, v_dim)`` and
        ``(B, H, k_dim, v_dim)`` respectively.
    """
    query = l2norm(query, dim=-1)
    key = l2norm(key, dim=-1)
    batch, heads, seq_len, k_dim = query.shape
    v_dim = value.shape[-1]
    query = query * (1.0 / (k_dim**0.5))

    n_rows = batch * heads

    def _merge_heads(tensor, width):
        # (B, H, S, width) -> (B*H, S, width)
        return tensor.reshape(n_rows, seq_len, width).contiguous()

    def _spread_gate(tensor):
        # (B, H, S) -> (B*H, S, v_dim), one copy of the scalar per value channel
        return (
            tensor.reshape(n_rows, seq_len)
            .unsqueeze(-1)
            .expand(-1, -1, v_dim)
            .contiguous()
        )

    q_rows = _merge_heads(query, k_dim)
    k_rows = _merge_heads(key, k_dim)
    v_rows = _merge_heads(value, v_dim)
    g_rows = _spread_gate(g)
    beta_rows = _spread_gate(beta)

    # One kernel launch per (batch, head) row; each returns (output, state).
    per_row = [
        _deltanet_nki_kernel_state(
            q_rows[row],
            k_rows[row],
            v_rows[row],
            g_rows[row],
            beta_rows[row],
        )
        for row in range(n_rows)
    ]

    output = torch.stack([out for out, _ in per_row], dim=0)
    output = output.reshape(batch, heads, seq_len, v_dim)

    final_state = torch.stack([state for _, state in per_row], dim=0)
    final_state = final_state.reshape(batch, heads, k_dim, v_dim)

    return output, final_state
chunk_size) + g_cs = g_reshaped.cumsum(dim=-1) + g_last_per_chunk = g_cs[:, :, :, -1:] + g_last_expanded = g_last_per_chunk.expand(-1, -1, -1, chunk_size) + + query_chunks = query.reshape(B, H, num_chunks, chunk_size, k_dim) + key_chunks = key.reshape(B, H, num_chunks, chunk_size, k_dim) + value_chunks = value.reshape(B, H, num_chunks, chunk_size, v_dim) + + beta_chunks = ( + beta.reshape(B, H, num_chunks, chunk_size) + .unsqueeze(-1) + .expand(-1, -1, -1, -1, v_dim) + ) + gc_chunks = g_cs.unsqueeze(-1).expand(-1, -1, -1, -1, v_dim) + gl_chunks = g_last_expanded.unsqueeze(-1).expand(-1, -1, -1, -1, v_dim) + + BH = B * H + query_chunks = query_chunks.reshape( + BH, num_chunks, chunk_size, k_dim + ).contiguous() + key_chunks = key_chunks.reshape(BH, num_chunks, chunk_size, k_dim).contiguous() + value_chunks = value_chunks.reshape( + BH, num_chunks, chunk_size, v_dim + ).contiguous() + beta_chunks = beta_chunks.reshape( + BH, num_chunks, chunk_size, v_dim + ).contiguous() + gc_chunks = gc_chunks.reshape(BH, num_chunks, chunk_size, v_dim).contiguous() + gl_chunks = gl_chunks.reshape(BH, num_chunks, chunk_size, v_dim).contiguous() + + device = query.device + lower_mask = torch.tril( + torch.ones(chunk_size, chunk_size, dtype=torch.float32, device=device), + diagonal=-1, + ) + identity_mat = torch.eye(chunk_size, dtype=torch.float32, device=device) + lower_mask_diag = torch.tril( + torch.ones(chunk_size, chunk_size, dtype=torch.float32, device=device), + diagonal=0, + ) + + initial_state_flat = None + if initial_state is not None: + initial_state_flat = initial_state.reshape(BH, k_dim, v_dim).float().contiguous() + + all_outputs = [] + all_states = [] + for bh in range(BH): + if initial_state_flat is None: + state = torch.zeros(k_dim, v_dim, dtype=torch.float32, device=device) + else: + state = initial_state_flat[bh] + + head_chunks = [] + for c_idx in range(num_chunks): + q_chunk = query_chunks[bh, c_idx].contiguous() + k_chunk = key_chunks[bh, c_idx].contiguous() + 
v_chunk = value_chunks[bh, c_idx].contiguous() + beta_chunk = beta_chunks[bh, c_idx].contiguous() + gc_chunk = gc_chunks[bh, c_idx].contiguous() + gl_chunk = gl_chunks[bh, c_idx].contiguous() + + out_chunk, state = _deltanet_nki_chunk_step( + q_chunk, + k_chunk, + v_chunk, + beta_chunk, + gc_chunk, + gl_chunk, + state, + lower_mask, + identity_mat, + lower_mask_diag, + ) + head_chunks.append(out_chunk) + + head_output = torch.cat(head_chunks, dim=0) + all_outputs.append(head_output) + all_states.append(state) + + output = torch.stack(all_outputs, dim=0) + output = output.reshape(B, H, total_seq_len, v_dim) + output = output[:, :, :S] + + if output_final_state: + final_state = torch.stack(all_states, dim=0) + last_recurrent_state = final_state.reshape(B, H, k_dim, v_dim) + else: + last_recurrent_state = None + + return output, last_recurrent_state + + def _fused_chunked_forward( + self, query, key, value, g, beta, output_final_state=False + ): + """Fused single-kernel chunked forward for CTE — SSD-style. + + Processes all chunks in a single NKI kernel call per (B,H) pair. + State persists in SBUF across chunks (no HBM round-trips). + Cumsum of g computed in-kernel via tensor_tensor_scan. + + This is the optimized version of _nki_chunked_forward with: + 1. Single kernel call per (B,H) instead of B*H*num_chunks + 2. State in SBUF across all chunks (biggest perf win) + 3. In-kernel cumsum (avoids PyTorch cumsum overhead) + 4. 
tensor_scalar for broadcasts (no explicit loops) + """ + chunk_size = 128 + + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + B, H, S, k_dim = query.shape + v_dim = value.shape[-1] + scale = 1.0 / (k_dim**0.5) + query = query * scale + + # Pad sequence to multiple of chunk_size + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + # Pass raw per-token log-decay. The fused NKI kernel forms decay as + # exp(cumsum(g)_i - cumsum(g)_j), so no pre-kernel clamp is needed. + + BH = B * H + # Flatten to (BH, S, dim) for per-(b,h) kernel calls + query_flat = query.reshape(BH, total_seq_len, k_dim).contiguous() + key_flat = key.reshape(BH, total_seq_len, k_dim).contiguous() + value_flat = value.reshape(BH, total_seq_len, v_dim).contiguous() + + # g and beta: (BH, S) -> (BH, S, 1) for the kernel's (S, 1) input layout + g_flat = g.reshape(BH, total_seq_len).unsqueeze(-1).contiguous() + beta_flat = beta.reshape(BH, total_seq_len).unsqueeze(-1).contiguous() + + # Create constant mask tensors (shared across all B*H calls) + device = query.device + lower_mask = torch.tensor( + _make_lower_mask(), dtype=torch.float32, device=device + ) + identity_mat = torch.tensor( + _make_identity(), dtype=torch.float32, device=device + ) + lower_mask_diag = torch.tensor( + _make_lower_mask_diag(), dtype=torch.float32, device=device + ) + + all_outputs = [] + all_states = [] + for bh in range(BH): + out_bh, state_bh = _deltanet_fused_kernel( + query_flat[bh], # (S, 128) + key_flat[bh], # (S, 128) + value_flat[bh], # (S, 128) + g_flat[bh], # (S, 1) — RAW g, not cumsum + beta_flat[bh], # (S, 1) — sigmoid(b) + lower_mask, # (128, 128) + identity_mat, # (128, 128) + lower_mask_diag, # (128, 128) + ) + all_outputs.append(out_bh) + 
def _sequential_forward(
    self,
    query,
    key,
    value,
    g,
    beta,
    output_final_state=False,
    initial_state=None,
):
    """Sequential full-sequence gated delta rule for CTE.

    Uses the same per-step recurrence as _recurrent_step but loops over the
    full sequence. Avoids the slice-assignment loop in _chunk_forward that
    may compile incorrectly on Neuron/XLA.

    Args:
        query, key: ``(B, H, S, K)`` tensors; L2-normalised here, and the
            query is additionally scaled by ``1/sqrt(K)``.
        value: ``(B, H, S, V)`` tensor.
        g: ``(B, H, S)`` per-token log-decay; ``exp(g)`` multiplies the state.
        beta: ``(B, H, S)`` per-token update strength.
        output_final_state: When true, also return the state after the last
            token; otherwise the second return value is ``None``.
        initial_state: Optional ``(B, H, K, V)`` state to resume from, matching
            the keyword accepted by ``_chunk_forward`` and
            ``_nki_chunked_forward``. ``None`` starts from zeros.

    Returns:
        ``(output, final_state)`` with ``output`` of shape ``(B, H, S, V)``.
    """
    query = l2norm(query, dim=-1)
    key = l2norm(key, dim=-1)

    B, H, S, k_dim = query.shape
    v_dim = value.shape[-1]
    query = query * (1.0 / (k_dim**0.5))

    if initial_state is None:
        state = query.new_zeros(B, H, k_dim, v_dim)
    else:
        # Same dtype handling as _chunk_forward.
        state = initial_state.to(dtype=query.dtype)

    # exp() is elementwise, so compute it once instead of once per step.
    decay = g.exp()

    step_outputs = []
    for t in range(S):
        q_t = query[:, :, t]  # (B, H, K)
        k_t = key[:, :, t]  # (B, H, K)
        v_t = value[:, :, t]  # (B, H, V)
        beta_t = beta[:, :, t].unsqueeze(-1)  # (B, H, 1)
        alpha_t = decay[:, :, t].unsqueeze(-1).unsqueeze(-1)  # (B, H, 1, 1)

        # Gated delta rule: decay the state, read out what it already
        # predicts for k_t, then write the beta-scaled correction.
        state = state * alpha_t
        kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2)  # (B, H, V)
        delta = (v_t - kv_mem) * beta_t  # (B, H, V)
        state = state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)  # (B, H, K, V)

        o_t = (state * q_t.unsqueeze(-1)).sum(dim=-2)  # (B, H, V)
        step_outputs.append(o_t.unsqueeze(2))

    if step_outputs:
        output = torch.cat(step_outputs, dim=2)  # (B, H, S, V)
    else:
        # torch.cat rejects an empty list; an empty sequence yields an
        # empty output and leaves the state untouched.
        output = query.new_zeros(B, H, 0, v_dim)

    final_state = state if output_final_state else None
    return output, final_state
(k_dim**0.5) + query = query * scale + + pad_size = (chunk_size - S % chunk_size) % chunk_size + if pad_size > 0: + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = S + pad_size + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + + num_chunks = total_seq_len // chunk_size + query = query.reshape(B, H, num_chunks, chunk_size, k_dim) + key = key.reshape(B, H, num_chunks, chunk_size, k_dim) + value = value.reshape(B, H, num_chunks, chunk_size, v_dim) + k_beta = k_beta.reshape(B, H, num_chunks, chunk_size, k_dim) + v_beta = v_beta.reshape(B, H, num_chunks, chunk_size, v_dim) + g = g.reshape(B, H, num_chunks, chunk_size) + + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=0, + ) + + g = g.cumsum(dim=-1) + decay_mask = (g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().tril() + + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + + if initial_state is None: + last_recurrent_state = torch.zeros( + B, H, k_dim, v_dim, dtype=query.dtype, device=query.device + ) + else: + last_recurrent_state = initial_state.to(dtype=query.dtype) + core_attn_out = torch.zeros_like(value) + mask2 = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=1, + ) + + for i in range(num_chunks): + q_i = query[:, :, i] + k_i = key[:, :, i] + v_i = value[:, :, i] + + attn_i = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_( + mask2, 0 + ) + + v_prime = 
k_cumdecay[:, :, i] @ last_recurrent_state + v_new = v_i - v_prime + + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn_i @ v_new + + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + ( + k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None] + ).transpose(-1, -2) + @ v_new + ) + + core_attn_out = core_attn_out.reshape(B, H, -1, v_dim) + core_attn_out = core_attn_out[:, :, :S] + + if not output_final_state: + last_recurrent_state = None + + return core_attn_out, last_recurrent_state + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask=None, + position_ids=None, + past_key_value=None, + **kwargs, + ): + """Forward pass compatible with NxDI decoder layer interface.""" + batch_size, seq_len, _ = hidden_states.shape + + seq_ids = kwargs.get("seq_ids", None) + qwen_chunked_prefill_active = ( + self.use_qwen_hybrid_chunked_prefill + and past_key_value is not None + and seq_len > 1 + ) + is_decode = past_key_value is not None and not qwen_chunked_prefill_active + + # Padding mask for DeltaNet: [B, S, 1] with 1.0 for real tokens, 0.0 for padding. + # Passed from get_model_output where it's computed from input_ids != pad_token_id. + # Embeddings are already zeroed for padding tokens; this mask additionally + # zeros the decay gate so the recurrent state is preserved unchanged + # through padding positions (no spurious decay). 
+ valid_mask_1d = kwargs.get("deltanet_padding_mask", None) # [B, S, 1] or None + hybrid_cache_active = self.use_hybrid_cache_manager + recurrent_state_cache = None + conv_state_cache = None + if hybrid_cache_active and past_key_value is not None: + recurrent_state_cache, conv_state_cache = past_key_value + + # Project inputs + deltanet_fp32 = os.environ.get("DELTANET_FP32") == "1" + if deltanet_fp32 and isinstance(self.in_proj_qkv, nn.Linear): + hs_f32 = hidden_states.float() + qkv = F.linear(hs_f32, self.in_proj_qkv.weight.float()).to( + hidden_states.dtype + ) + z = F.linear(hs_f32, self.in_proj_z.weight.float()).to(hidden_states.dtype) + b = F.linear(hs_f32, self.in_proj_b.weight.float()).to(hidden_states.dtype) + a = F.linear(hs_f32, self.in_proj_a.weight.float()).to(hidden_states.dtype) + else: + qkv = self.in_proj_qkv(hidden_states) + z = self.in_proj_z(hidden_states) + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + # Split QKV + query = qkv[..., : self.key_dim] + key = qkv[..., self.key_dim : self.key_dim * 2] + value = qkv[..., self.key_dim * 2 :] + + # Causal Conv1d on QKV + mixed = torch.cat([query, key, value], dim=-1) + mixed = mixed.transpose(1, 2) + + if is_decode: + if conv_state_cache is not None: + conv_state = conv_state_cache[:batch_size] + elif seq_ids is not None: + conv_state = torch.index_select(self.conv_state_buffer, 0, seq_ids) + else: + conv_state = self.conv_state_buffer[:batch_size] + conv_input = torch.cat([conv_state, mixed], dim=-1) + + w = self._conv1d_weight().squeeze(1) + conv_out = torch.zeros_like(mixed) + for k in range(4): + conv_out = ( + conv_out + + w[:, k].unsqueeze(0).unsqueeze(-1) * conv_input[:, :, k : k + 1] + ) + mixed_post_conv = F.silu(conv_out) + + new_conv_state = torch.cat([conv_state[:, :, 1:], mixed], dim=-1) + alloc_bs = self.conv_state_buffer.shape[0] + if hybrid_cache_active: + new_conv_state = new_conv_state.to(self.conv_state_buffer.dtype) + elif seq_ids is not None: + # BS=1 
optimization: scatter to index 0 of size-1 buffer = direct replacement + # Add buffer dependency for input_output_alias + new_conv_state = ( + new_conv_state.to(self.conv_state_buffer.dtype) + + self.conv_state_buffer * 0 + ) + elif batch_size < alloc_bs: + pad_size = alloc_bs - batch_size + new_conv_state = torch.cat( + [ + new_conv_state, + self.conv_state_buffer[batch_size:] * 0, + ], + dim=0, + ) + else: + new_conv_state = new_conv_state + self.conv_state_buffer * 0 + else: + if qwen_chunked_prefill_active and conv_state_cache is not None: + conv_state = conv_state_cache[:batch_size] + if position_ids is not None: + reset_mask = (position_ids[:, :1].long() == 0).to( + dtype=conv_state.dtype, device=conv_state.device + ) + conv_state = conv_state * (1.0 - reset_mask[:, None, :]) + conv_input = torch.cat([conv_state, mixed], dim=-1) + w = self._conv1d_weight().squeeze(1) + conv_out = torch.zeros_like(mixed) + for k in range(self.conv_kernel_size): + conv_out = conv_out + w[:, k].unsqueeze(0).unsqueeze(-1) * conv_input[ + :, :, k : k + seq_len + ] + mixed_post_conv = F.silu(conv_out) + if valid_mask_1d is not None: + state_len = self.conv_kernel_size - 1 + num_valid = valid_mask_1d.squeeze(-1).sum(dim=-1, keepdim=True).long() + idx_base = (state_len + num_valid - state_len).clamp(min=0) + offsets = torch.arange(state_len, device=mixed.device).unsqueeze(0) + gather_idx = idx_base + offsets + gather_idx = gather_idx.unsqueeze(1).expand(-1, self.conv_dim, -1) + new_conv_state = torch.gather(conv_input, 2, gather_idx) + else: + new_conv_state = conv_input[:, :, -self.conv_kernel_size + 1 :].contiguous() + else: + mixed_post_conv = F.silu( + F.conv1d( + mixed, + self._conv1d_weight(), + bias=None, + padding=self.conv_kernel_size - 1, + groups=self.conv_dim, + )[:, :, :seq_len] + ) + + if valid_mask_1d is not None: + # valid_mask_1d is [B, S, 1]; count valid tokens per batch + num_valid = ( + valid_mask_1d.squeeze(-1).sum(dim=-1, keepdim=True).long() + ) # [B, 1] + 
idx_base = num_valid - 3 + idx_base = idx_base.clamp(min=0) + offsets = torch.arange(3, device=mixed.device).unsqueeze(0) + gather_idx = idx_base + offsets # [B, 3] + gather_idx = gather_idx.unsqueeze(1).expand(-1, self.conv_dim, -1) + new_conv_state = torch.gather(mixed, 2, gather_idx) + else: + new_conv_state = mixed[:, :, -3:].contiguous() + + alloc_bs = self.conv_state_buffer.shape[0] + if hybrid_cache_active: + new_conv_state = new_conv_state.to(self.conv_state_buffer.dtype) + elif seq_ids is not None: + # BS=1 optimization: scatter to index 0 = direct replacement + new_conv_state = ( + new_conv_state.to(self.conv_state_buffer.dtype) + + self.conv_state_buffer * 0 + ) + elif batch_size < alloc_bs: + pad_size = alloc_bs - batch_size + new_conv_state = torch.cat( + [ + new_conv_state, + torch.zeros( + pad_size, + self.conv_dim, + self.conv_kernel_size - 1, + dtype=new_conv_state.dtype, + device=new_conv_state.device, + ), + ], + dim=0, + ) + new_conv_state = new_conv_state + self.conv_state_buffer * 0 + else: + new_conv_state = new_conv_state + self.conv_state_buffer * 0 + + mixed_post_conv = mixed_post_conv.transpose(1, 2) + + # Zero out conv1d output for padding positions. + # Conv1d with kernel_size=4 leaks real token info into the first + # few padding positions. Zeroing here ensures Q, K, V are exactly + # zero for all padding positions so the recurrence is unaffected. 
+ if valid_mask_1d is not None: + mixed_post_conv = ( + mixed_post_conv * valid_mask_1d + ) # [B, S, conv_dim] * [B, S, 1] + + query = mixed_post_conv[..., : self.key_dim] + key = mixed_post_conv[..., self.key_dim : self.key_dim * 2] + value = mixed_post_conv[..., self.key_dim * 2 :] + + # Reshape to heads + query = query.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim) + key = key.reshape(batch_size, seq_len, self.num_k_heads, self.head_k_dim) + value = value.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + + # Compute gating + beta = b.sigmoid() + g = -self._A_log().float().exp() * F.softplus(a.float() + self._dt_bias()) + + if valid_mask_1d is not None: + # Zero g for padding → alpha=exp(0)=1 → state preserved through padding + # Zero beta for padding → no state update from padding tokens + mask_2d = valid_mask_1d.squeeze(-1).float() # [B, S] + g = g * mask_2d.unsqueeze(-1) + beta = beta * mask_2d.unsqueeze(-1) + + # Expand K heads to match V heads (16 -> 48) using expand+reshape + if self.num_v_heads // self.num_k_heads > 1: + rep = self.num_v_heads // self.num_k_heads # 3 + query = ( + query.unsqueeze(3) + .expand(-1, -1, -1, rep, -1) + .reshape(batch_size, seq_len, self.num_v_heads, self.head_k_dim) + ) + key = ( + key.unsqueeze(3) + .expand(-1, -1, -1, rep, -1) + .reshape(batch_size, seq_len, self.num_v_heads, self.head_k_dim) + ) + + # Transpose to (B, H, S, dim) + query = query.transpose(1, 2).contiguous().float() + key = key.transpose(1, 2).contiguous().float() + value = value.transpose(1, 2).contiguous().float() + g = g.transpose(1, 2).contiguous().float() + beta = beta.transpose(1, 2).contiguous().float() + + if is_decode: + # TKG: single-step recurrent update + if recurrent_state_cache is not None: + recurrent_state = recurrent_state_cache[:batch_size].float() + elif seq_ids is not None: + recurrent_state = torch.index_select( + self.recurrent_state_buffer, 0, seq_ids + ).float() + else: + recurrent_state = 
self.recurrent_state_buffer[:batch_size].float() + + output, new_state = self._recurrent_step( + query, key, value, g, beta, recurrent_state + ) + new_state_bf16 = new_state.to(self.recurrent_state_buffer.dtype) + alloc_bs = self.recurrent_state_buffer.shape[0] + if hybrid_cache_active: + new_rec_state = new_state_bf16 + elif seq_ids is not None: + # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement + # Add buffer dependency for input_output_alias + new_rec_state = new_state_bf16 + self.recurrent_state_buffer * 0 + elif batch_size < alloc_bs: + new_rec_state = torch.cat( + [ + new_state_bf16, + self.recurrent_state_buffer[batch_size:] * 0, + ], + dim=0, + ) + else: + new_rec_state = new_state_bf16 + self.recurrent_state_buffer * 0 + else: + # CTE: fused NKI kernel by default (PyTorch _chunk_forward can hit + # neuronx-cc codegen ICE NCC_INLA001 with these DeltaNet dimensions). + # Override with env vars for debugging/benchmarking. + use_nki_fused = os.environ.get("USE_NKI_FUSED", "1") != "0" + use_nki_chunked = os.environ.get("USE_NKI_CHUNKED") == "1" + use_nki = os.environ.get("USE_NKI") == "1" + use_sequential = os.environ.get("DELTANET_SEQUENTIAL") == "1" + use_pytorch_chunk = os.environ.get("USE_PYTORCH_CHUNK") == "1" + + if qwen_chunked_prefill_active and recurrent_state_cache is not None: + initial_state = recurrent_state_cache[:batch_size].float() + if position_ids is not None: + reset_mask = (position_ids[:, :1].long() == 0).to( + dtype=initial_state.dtype, device=initial_state.device + ) + initial_state = initial_state * (1.0 - reset_mask[:, :, None, None]) + if self.use_qwen_hybrid_chunked_prefill_nki: + output, final_state = self._nki_chunked_forward( + query, + key, + value, + g, + beta, + output_final_state=True, + initial_state=initial_state, + ) + else: + output, final_state = self._chunk_forward( + query, + key, + value, + g, + beta, + output_final_state=True, + initial_state=initial_state, + ) + elif use_pytorch_chunk: + 
output, final_state = self._chunk_forward( + query, key, value, g, beta, output_final_state=True + ) + elif use_nki_chunked: + output, final_state = self._nki_chunked_forward( + query, key, value, g, beta, output_final_state=True + ) + elif use_nki: + output, final_state = self._nki_recurrent_forward( + query, key, value, g, beta + ) + elif use_sequential: + output, final_state = self._sequential_forward( + query, key, value, g, beta, output_final_state=True + ) + elif use_nki_fused: + output, final_state = self._fused_chunked_forward( + query, key, value, g, beta, output_final_state=True + ) + else: + output, final_state = self._fused_chunked_forward( + query, key, value, g, beta, output_final_state=True + ) + + if final_state is not None: + final_state_bf16 = final_state.to(self.recurrent_state_buffer.dtype) + alloc_bs = self.recurrent_state_buffer.shape[0] + if hybrid_cache_active: + new_rec_state = final_state_bf16 + elif seq_ids is not None: + # BS=1 optimization: scatter to index 0 of size-1 buffer = direct replacement + # Add buffer dependency for input_output_alias + new_rec_state = final_state_bf16 + self.recurrent_state_buffer * 0 + elif batch_size < alloc_bs: + new_rec_state = torch.cat( + [ + final_state_bf16, + torch.zeros( + alloc_bs - batch_size, + self.num_v_heads, + self.head_k_dim, + self.head_v_dim, + dtype=final_state_bf16.dtype, + device=final_state_bf16.device, + ), + ], + dim=0, + ) + new_rec_state = new_rec_state + self.recurrent_state_buffer * 0 + else: + new_rec_state = final_state_bf16 + self.recurrent_state_buffer * 0 + else: + new_rec_state = self.recurrent_state_buffer * 1 + + # Output: norm, gate, project + output = output.to(hidden_states.dtype) + output = output.transpose(1, 2).contiguous() + output = output.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + output = self.norm(output) + z_gate = z.reshape(batch_size, seq_len, self.num_v_heads, self.head_v_dim) + output = output * F.silu(z_gate) + output = 
output.reshape(batch_size, seq_len, self.value_dim) + output = self.out_proj(output) + + if hybrid_cache_active: + return output, (new_rec_state, new_conv_state), new_rec_state, new_conv_state + + # Return dummy KV for KVCacheManager + dummy_k = torch.zeros( + batch_size, + self.kv_heads_per_rank, + seq_len, + self.head_dim, + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + dummy_v = torch.zeros_like(dummy_k) + + return output, (dummy_k, dummy_v), new_rec_state, new_conv_state + + +# ============================================================ +# InferenceConfig (Dense -- no MoE) +# ============================================================ + + +class Qwen35InferenceConfig(InferenceConfig): + """Config for Qwen3.5/3.6-27B (dense) with hybrid DeltaNet + Attention.""" + + def __init__(self, *args, **kwargs): + # Set defaults BEFORE super().__init__() because it calls validate_config() + # which checks get_required_attributes(). These can be overridden by + # kwargs or load_config. + + # Layer types for hybrid dispatch: [3 DeltaNet + 1 GQA] repeated. 
+ if "layer_types" not in kwargs and not any( + hasattr(a, "layer_types") for a in args if hasattr(a, "__dict__") + ): + num_layers = kwargs.get("num_hidden_layers", 64) + if num_layers % 4 != 0: + raise ValueError( + f"Qwen3.5 hybrid layer count must be divisible by 4, got {num_layers}" + ) + layer_types = [] + for _ in range(num_layers // 4): + layer_types.extend( + [ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + ] + ) + kwargs.setdefault("layer_types", layer_types) + + # DeltaNet-specific config defaults + kwargs.setdefault("linear_num_value_heads", 48) + kwargs.setdefault("linear_num_key_heads", 16) + kwargs.setdefault("linear_key_head_dim", 128) + kwargs.setdefault("linear_value_head_dim", 128) + kwargs.setdefault("linear_conv_kernel_dim", 4) + kwargs.setdefault("use_hybrid_cache_manager", False) + kwargs.setdefault("use_qwen_hybrid_chunked_prefill", False) + kwargs.setdefault("use_qwen_hybrid_chunked_prefill_nki", False) + + super().__init__(*args, **kwargs) + + # Attention output gate + self.attn_output_gate = getattr(self, "attn_output_gate", True) + + # Partial RoPE + self.partial_rotary_factor = getattr(self, "partial_rotary_factor", 0.25) + self.rope_dim = int(self.head_dim * self.partial_rotary_factor) # 64 + + # mRoPE (multimodal RoPE) for VL support + rope_params = getattr(self, "rope_parameters", {}) or {} + self.mrope_section = rope_params.get("mrope_section", [11, 11, 10]) + self.mrope_interleaved = rope_params.get("mrope_interleaved", True) + + # Standard HF config attributes expected by NxDI + if not hasattr(self, "output_attentions"): + self.output_attentions = False + if not hasattr(self, "output_hidden_states"): + self.output_hidden_states = False + + def get_required_attributes(self) -> List[str]: + return [ + "head_dim", + "hidden_act", + "hidden_size", + "intermediate_size", + "max_position_embeddings", + "num_attention_heads", + "num_hidden_layers", + "num_key_value_heads", + "rms_norm_eps", + 
"rope_theta", + "vocab_size", + # DeltaNet-specific + "linear_num_value_heads", + "linear_num_key_heads", + "linear_key_head_dim", + "linear_value_head_dim", + "linear_conv_kernel_dim", + "layer_types", + ] + + @classmethod + def get_neuron_config_cls(cls): + return NeuronConfig + + +# ============================================================ +# Attention (standard GQA for 16 of 64 layers) +# With output gate: q_proj is 2x sized, split into (query, gate) +# With partial RoPE: only first rope_dim dimensions get rotary +# ============================================================ + + +class Qwen35MRoPEEmbedding(nn.Module): + """Multimodal Rotary Position Embedding (mRoPE) for Qwen3.5. + + Handles 3D position information (temporal, height, width) for VL models. + Position IDs have shape (3, batch_size, seq_len) for T/H/W dimensions. + For text-only (2D position_ids), broadcasts to 3D with identical positions. + """ + + def __init__(self, config): + super().__init__() + self.head_dim = config.head_dim # 256 + self.rope_dim = config.rope_dim # 64 + self.mrope_section = config.mrope_section # [11, 11, 10] + self.mrope_interleaved = getattr(config, "mrope_interleaved", True) + self.rope_theta = config.rope_theta + + # Validate mrope_section sums to rope_dim // 2 = 32 + assert sum(self.mrope_section) == self.rope_dim // 2, ( + f"mrope_section {self.mrope_section} sums to {sum(self.mrope_section)}, " + f"expected {self.rope_dim // 2}" + ) + + def forward(self, x, position_ids_3d): + """Compute cos/sin from 3D position IDs. 
+ + Args: + x: hidden_states (for device/dtype inference) + position_ids_3d: (3, batch_size, seq_len) -- T, H, W positions + + Returns: + cos: (batch_size, seq_len, rope_dim) + sin: (batch_size, seq_len, rope_dim) + """ + device = x.device + dtype = torch.float32 + + if position_ids_3d.ndim == 2: + position_ids_3d = position_ids_3d[None, ...].expand( + 3, position_ids_3d.shape[0], -1 + ) + + inv_freq = 1.0 / ( + self.rope_theta + ** ( + torch.arange(0, self.rope_dim, 2, dtype=dtype, device=device) + / self.rope_dim + ) + ) + inv_freq = inv_freq[None, None, :, None].expand( + 3, position_ids_3d.shape[1], -1, 1 + ) + positions = position_ids_3d[:, :, None, :].float() + freqs = (inv_freq.float() @ positions).transpose(2, 3) + + # Match HF Qwen3.6 mRoPE layout exactly: start from the temporal + # frequencies, then splice H/W frequencies into interleaved positions. + freqs_t = freqs[0] + if self.mrope_interleaved: + for dim, offset in enumerate((1, 2), start=1): + length = self.mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + + emb = torch.cat((freqs_t, freqs_t), dim=-1) + cos = emb.cos().to(dtype=x.dtype) + sin = emb.sin().to(dtype=x.dtype) + + return cos, sin + + +class NeuronQwen35Attention(NeuronAttentionBase): + """Standard GQA attention for Qwen3.5 with output gate and partial RoPE. + + 24 Q heads, 4 KV heads (6:1 GQA), head_dim=256 for 27B dense. + q_proj is doubled (query + gate), split at load time. + Only first rope_dim=64 of head_dim=256 gets rotary encoding. + + Uses NeuronAttentionBase infrastructure for QKV projection, KV cache, + RoPE, and attention computation. Overrides forward() to insert the + sigmoid output gate between attention output and o_proj. 
+ """ + + def __init__(self, config): + # Partial RoPE: create mRoPE embedding with rope_dim (64) + self.rope_dim = config.rope_dim # 64 = head_dim * partial_rotary_factor + + # Create QK norm modules (will be passed to base class) + rms_norm_eps = config.rms_norm_eps + q_ln = get_rmsnorm_cls()(config.head_dim, rms_norm_eps) + k_ln = get_rmsnorm_cls()(config.head_dim, rms_norm_eps) + + # Partial RoPE: use standard RotaryEmbedding. + # For VL with 3D mRoPE positions, cos/sin are pre-computed externally in + # get_model_output() using Qwen35MRoPEEmbedding and passed as cos_cache/sin_cache. + rotary_emb = RotaryEmbedding( + self.rope_dim, # Only 64 dims get rotary embedding + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + super().__init__( + config=config, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=config.head_dim, + rotary_emb=rotary_emb, + rms_norm_eps=rms_norm_eps, + use_qk_norm=False, + q_layernorm=q_ln, + k_layernorm=k_ln, + ) + + # Separate mRoPE module for VL 3D position_ids + self.mrope_emb = Qwen35MRoPEEmbedding(config) + + # Output gate projection: hidden_size -> num_heads * head_dim + # Populated from the second half of q_proj during state dict conversion. + self.output_gate_proj = ColumnParallelLinear( + config.hidden_size, + config.num_attention_heads * config.head_dim, + bias=False, + gather_output=False, + ) + + def apply_rotary_embedding( + self, Q, K, V, position_ids, cos_cache, sin_cache, use_polar_compatible_rope + ): + """Partial RoPE: only apply rotary embedding to first rope_dim dimensions. 
+ + Q shape: (B, H, S, head_dim) where head_dim=256 + cos/sin shape: (B, S, rope_dim) where rope_dim=64 (from RotaryEmbedding(dim=64)) + + Split Q/K along last dim into: + q_rope (first 64 dims) -- apply RoPE + q_pass (remaining 192 dims) -- pass through unchanged + """ + from neuronx_distributed_inference.modules.attention.utils import ( + apply_rotary_pos_emb, + ) + + if self.rotary_emb is not None: + if cos_cache is None or sin_cache is None: + cos_cache, sin_cache = self.rotary_emb(V, position_ids) + + # Split into rope and pass-through portions + Q_orig_dtype = Q.dtype + q_rope = Q[..., : self.rope_dim] # (B, H, S, 64) + q_pass = Q[..., self.rope_dim :] # (B, H, S, 192) + k_rope = K[..., : self.rope_dim] + k_pass = K[..., self.rope_dim :] + + # Apply RoPE only to the rope portion + q_rope, k_rope = apply_rotary_pos_emb(q_rope, k_rope, cos_cache, sin_cache) + + # Concatenate back (ensure bf16 is maintained) + Q = torch.cat([q_rope, q_pass], dim=-1).to(Q_orig_dtype) + K = torch.cat([k_rope, k_pass], dim=-1).to(Q_orig_dtype) + + return Q, K, cos_cache, sin_cache + + def perform_prefill(self, Q, K, V, q_len, bsz, attention_mask=None): + """Prefill path with NKI flash attention for head_dim=256.""" + head_dim = Q.shape[-1] + + # Option B: nkilib flash attention for head_dim > 128 + if _nkilib_flash_attn is not None: + q_contig = Q.contiguous() + k_contig = K.contiguous() + v_contig = V.contiguous() + scale = 1.0 / math.sqrt(head_dim) + result = _nkilib_flash_attn( + q_contig, k_contig, v_contig, scale=scale, use_causal_mask=True + ) + return result, None + + # Option A: kernel patched globally + if NKILIB_PATCH_ACTIVE: + return _flash_fwd_call(Q, K, V, use_causal_mask=True), None + + # Fallback: softmax path (use 3D tensors to avoid compiler ICE with 4D patterns) + if head_dim > 128: + # GQA: expand K/V heads to match Q heads + num_q_heads = Q.shape[1] + num_kv_heads = K.shape[1] + if num_q_heads != num_kv_heads: + kv_rep = num_q_heads // num_kv_heads + K = ( + 
K.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + .reshape(bsz, num_q_heads, q_len, head_dim) + ) + V = ( + V.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + .reshape(bsz, num_q_heads, q_len, head_dim) + ) + # Reshape to 3D (B*H, S, d) to avoid neuronx-cc codegen ICE with 4D + # attention weight tensors (NCC_INLA001: Expected 2D tensor but got 4D AP) + Q_3d = Q.reshape(bsz * num_q_heads, q_len, head_dim) + K_3d = K.reshape(bsz * num_q_heads, q_len, head_dim) + V_3d = V.reshape(bsz * num_q_heads, q_len, head_dim) + attn_weights = torch.bmm(Q_3d, K_3d.transpose(-1, -2)) / math.sqrt(head_dim) + # Build causal mask for 3D: (1, S, S) broadcast over B*H + causal_mask = torch.triu( + torch.full( + (q_len, q_len), + -65504.0, + dtype=attn_weights.dtype, + device=attn_weights.device, + ), + diagonal=1, + ).unsqueeze(0) + attn_weights = attn_weights + causal_mask + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + Q.dtype + ) + attn_output = torch.bmm(attn_weights, V_3d) + # Reshape back to 4D (B, H, S, d) + return attn_output.reshape(bsz, num_q_heads, q_len, head_dim), None + + return _flash_fwd_call(Q, K, V, use_causal_mask=True), None + + def perform_qwen_chunked_prefill(self, Q, K, V, past_key_value, position_ids): + """Exact chunked CTE over the full decode cache. + + The current chunk K/V tensors are scattered into the full cache at + absolute position_ids, then attention for this chunk is computed over + all cache positions up to the chunk end. This keeps full-attention + layers correct when model-local chunked prefill feeds context in + multiple CTE-bucket calls. 
+ """ + k_cache, v_cache = past_key_value + B, q_heads, q_len, head_dim = Q.shape + kv_heads = K.shape[1] + cache_len = k_cache.shape[2] + + pos = position_ids.long() + k_index = pos[:, None, :, None].expand(B, kv_heads, q_len, head_dim) + k_cache = torch.scatter(k_cache, dim=2, index=k_index, src=K.to(k_cache.dtype)) + v_cache = torch.scatter(v_cache, dim=2, index=k_index, src=V.to(v_cache.dtype)) + + if q_heads != kv_heads: + kv_rep = q_heads // kv_heads + K_full = ( + k_cache.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + .reshape(B, q_heads, cache_len, head_dim) + ) + V_full = ( + v_cache.unsqueeze(2) + .expand(-1, -1, kv_rep, -1, -1) + .reshape(B, q_heads, cache_len, head_dim) + ) + else: + K_full = k_cache + V_full = v_cache + + attn_weights = torch.matmul(Q, K_full.transpose(-1, -2)) / math.sqrt(head_dim) + cache_positions = torch.arange(cache_len, device=position_ids.device).view(1, 1, 1, -1) + causal_mask = cache_positions <= pos[:, None, :, None] + attn_weights = attn_weights.masked_fill(~causal_mask, -65504.0) + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(Q.dtype) + return torch.matmul(attn_weights, V_full) + + def forward( + self, + hidden_states, + attention_mask=None, + position_ids=None, + past_key_value=None, + cos_cache=None, + sin_cache=None, + rmsnorm=None, + adapter_ids=None, + active_mask=None, + **kwargs, + ): + """Forward with output gate applied BEFORE o_proj. + + Override NeuronAttentionBase.forward() to insert the sigmoid gate + between the attention output and o_proj, matching the HF reference: + gate = sigmoid(gate_proj(pre_attn_hidden)) + attn_output = attn_output * gate + attn_output = o_proj(attn_output) + """ + bsz, q_len, _ = hidden_states.shape + + # Use standard 2D position_ids for prep_qkv_tensors. 
+ rope_pos_ids = position_ids + + # Compute gate from input hidden states (before QKV projection) + gate = self.output_gate_proj(hidden_states) # (B, S, num_heads * head_dim) + + # Standard QKV prep (projections, QK norm, RoPE) + Q, K, V, cos_cache, sin_cache, _residual = self.prep_qkv_tensors( + rope_pos_ids, + hidden_states, + past_key_value, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + rmsnorm=rmsnorm, + ) + + qwen_chunked_prefill_active = ( + past_key_value is not None + and q_len > 1 + and getattr(self.config, "use_qwen_hybrid_chunked_prefill", False) + ) + + if past_key_value is None: + # Context encoding (prefill) + attn_output, _flash_strategy = self.perform_prefill( + Q, K, V, q_len, bsz, attention_mask + ) + elif qwen_chunked_prefill_active: + attn_output = self.perform_qwen_chunked_prefill( + Q, K, V, past_key_value, position_ids + ) + else: + # Token generation (decode) + tkg_mask = attention_mask + if tkg_mask is not None and tkg_mask.ndim == 2: + tkg_mask = tkg_mask.unsqueeze(1).unsqueeze(2) # (B, S) -> (B, 1, 1, S) + attn_output = self.compute_for_token_gen( + Q, K, V, position_ids, past_key_value, tkg_mask, active_mask + ) + + # attn_output is (B, H, S, head_dim) -- transpose to (B, S, H*head_dim) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) + + # Apply sigmoid output gate BEFORE o_proj (matching HF reference) + attn_output = attn_output * torch.sigmoid(gate) + + # Apply o_proj + attn_output = self.get_o_proj()(attn_output, adapter_ids=adapter_ids) + + # Ensure K, V are in model dtype (bf16) for KV cache update + # (prevents mixed-precision dynamic-update-slice in neuronx-cc) + K = K.to(self.torch_dtype) + V = V.to(self.torch_dtype) + past_key_value = (K, V) + return attn_output, past_key_value, cos_cache, sin_cache + + +# ============================================================ +# Dense MLP (replaces MoE) +# 
============================================================ + + +class Qwen35MLP(nn.Module): + """Dense SwiGLU MLP for Qwen3.5/3.6-27B. + + gate_proj: hidden_size -> intermediate_size (5120 -> 17408) + up_proj: hidden_size -> intermediate_size (5120 -> 17408) + down_proj: intermediate_size -> hidden_size (17408 -> 5120) + + output = down_proj(silu(gate_proj(x)) * up_proj(x)) + """ + + def __init__(self, config): + super().__init__() + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=False, + gather_output=False, + ) + self.up_proj = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=False, + gather_output=False, + ) + self.down_proj = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=False, + input_is_parallel=True, + ) + + def forward(self, hidden_states): + gate = self.gate_proj(hidden_states) + up = self.up_proj(hidden_states) + hidden_states = F.silu(gate) * up + hidden_states = self.down_proj(hidden_states) + return hidden_states + + +# ============================================================ +# Decoder Layer (hybrid dispatch -- DeltaNet or GQA + Dense MLP) +# ============================================================ + + +class NeuronQwen35DecoderLayer(nn.Module): + """Hybrid decoder layer: dispatches to DeltaNet or standard attention. + Uses dense MLP for all layers (no MoE). 
+ """ + + def __init__(self, config: Qwen35InferenceConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_type = config.layer_types[layer_idx] + self.layer_idx = layer_idx + self.config = config + + # Attention (DeltaNet or standard GQA) + if self.layer_type == "linear_attention": + self.linear_attn = NeuronGatedDeltaNet(config, layer_idx) + else: + self.self_attn = NeuronQwen35Attention(config=config) + + # Dense MLP (all layers) + self.mlp = Qwen35MLP(config) + + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = get_rmsnorm_cls()( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask=None, + position_ids=None, + past_key_value=None, + padding_mask=None, + cos_cache=None, + sin_cache=None, + **kwargs, + ): + residual = hidden_states + + hidden_states = ModuleMarkerStartWrapper()(hidden_states) + hidden_states = self.input_layernorm(hidden_states) + + if self.layer_type == "linear_attention": + # DeltaNet path + attn_out, dummy_kv, new_rec_state, new_conv_state = self.linear_attn( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + **kwargs, + ) + hidden_states = residual + attn_out + present_key_value = dummy_kv + deltanet_states = ( + None + if getattr(self.config, "use_hybrid_cache_manager", False) + else (new_rec_state, new_conv_state) + ) + else: + deltanet_states = None + # Standard attention path + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + cos_cache=cos_cache, + sin_cache=sin_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Dense MLP FFN + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + 
hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + hidden_states = ModuleMarkerEndWrapper()(hidden_states) + outputs = ( + hidden_states, + present_key_value, + cos_cache, + sin_cache, + None, + deltanet_states, + ) + return outputs + + +# ============================================================ +# Hybrid Cache Manager (opt-in) +# ============================================================ + + +class HybridDeltaNetCacheManager(KVCacheManager): + """Layer-type-aware cache manager for Qwen3.5/Qwen3.6 hybrid dense models.""" + + def __init__(self, config: Qwen35InferenceConfig, num_kv_head, **kwargs): + self.layer_types = list(config.layer_types) + self._validate_hybrid_config(config) + super().__init__(config, num_kv_head=num_kv_head, **kwargs) + + dtype = ( + config.neuron_config.attention_dtype + if config.neuron_config.attention_dtype is not None + else config.neuron_config.torch_dtype + ) + cache_dtype = getattr(self, "cache_dtype", dtype) + max_batch_size = ( + config.neuron_config.kv_cache_batch_size + + config.neuron_config.kv_cache_padding_size + ) + tp_degree = config.neuron_config.tp_degree + if config.linear_num_value_heads % tp_degree != 0: + raise ValueError( + f"linear_num_value_heads={config.linear_num_value_heads} must be divisible " + f"by tp_degree={tp_degree}" + ) + if config.linear_num_key_heads % tp_degree != 0: + raise ValueError( + f"linear_num_key_heads={config.linear_num_key_heads} must be divisible " + f"by tp_degree={tp_degree}" + ) + local_num_value_heads = config.linear_num_value_heads // tp_degree + local_num_key_heads = config.linear_num_key_heads // tp_degree + recurrent_shape = [ + max_batch_size, + local_num_value_heads, + config.linear_key_head_dim, + config.linear_value_head_dim, + ] + conv_dim = ( + 2 * local_num_key_heads * config.linear_key_head_dim + + local_num_value_heads * config.linear_value_head_dim + ) + conv_shape = [ + max_batch_size, + conv_dim, + config.linear_conv_kernel_dim - 
1, + ] + + params = [] + for layer_idx, layer_type in enumerate(self.layer_types): + if layer_type == "linear_attention": + params.append( + nn.Parameter(torch.zeros(recurrent_shape, dtype=dtype), requires_grad=False) + ) + params.append( + nn.Parameter(torch.zeros(conv_shape, dtype=dtype), requires_grad=False) + ) + else: + k_shape = self.k_shapes[layer_idx] if hasattr(self, "k_shapes") else self.k_shape + v_shape = self.v_shapes[layer_idx] if hasattr(self, "v_shapes") else self.v_shape + params.append( + nn.Parameter(torch.zeros(k_shape, dtype=cache_dtype), requires_grad=False) + ) + params.append( + nn.Parameter(torch.zeros(v_shape, dtype=cache_dtype), requires_grad=False) + ) + + self.past_key_values = nn.ParameterList(params) + + @staticmethod + def _validate_hybrid_config(config: Qwen35InferenceConfig): + nc = config.neuron_config + unsupported = [] + if nc.is_block_kv_layout: + unsupported.append("block KV layout") + if getattr(nc, "kv_quant_config", None) is not None or getattr(nc, "kv_cache_quant", False): + unsupported.append("KV cache quantization") + if nc.enable_fused_speculation or nc.speculation_length > 0 or nc.is_medusa: + unsupported.append("speculative decoding") + if getattr(nc, "enable_eagle_speculation", False) or getattr(nc, "is_eagle_draft", False): + unsupported.append("EAGLE speculation") + if nc.flash_decoding_enabled: + unsupported.append("flash decoding") + if nc.attention_dp_degree > 1: + unsupported.append("attention data parallelism") + if nc.kv_cache_tiling: + unsupported.append("KV cache tiling") + if nc.padding_side != "right": + unsupported.append("left padding") + if nc.is_continuous_batching: + unsupported.append("continuous batching") + if unsupported: + raise ValueError( + "HybridDeltaNetCacheManager v1 does not support: " + + ", ".join(unsupported) + ) + + def _is_deltanet_layer(self, idx: int) -> bool: + return self.layer_types[idx] == "linear_attention" + + def get_seq_length(self, past_key_values=None): + for idx, 
layer_type in enumerate(self.layer_types): + if layer_type != "linear_attention": + if past_key_values is None: + _, v_cache = self._fetch_cache(idx) + elif len(past_key_values) == len(self.past_key_values): + v_cache = past_key_values[2 * idx + 1] + else: + v_cache = past_key_values[idx][1] + return v_cache.shape[2] + return 0 + + def get_deltanet_state_by_layer_id(self, idx, kvcache_buffer=None, seq_ids=None): + recurrent_state, conv_state = self._fetch_cache(idx, kvcache_buffer) + if seq_ids is not None: + cache_idx = self.get_cache_update_index_for_seq_ids(seq_ids) + recurrent_state = torch.index_select(recurrent_state, dim=0, index=cache_idx) + conv_state = torch.index_select(conv_state, dim=0, index=cache_idx) + elif self.kv_cache_padding_size > 0: + recurrent_state = recurrent_state[: -self.kv_cache_padding_size] + conv_state = conv_state[: -self.kv_cache_padding_size] + return recurrent_state, conv_state + + def get_cache( + self, + seq_len: int, + skip_slice=False, + kvcache_buffer=None, + seq_ids=None, + windowed_context_encoding_window_idx=-1, + **kwargs, + ): + past_key_values = [] + for idx in range(len(self.past_key_values) // 2): + if self._is_deltanet_layer(idx): + past_key_values.append( + list(self.get_deltanet_state_by_layer_id(idx, kvcache_buffer, seq_ids)) + ) + else: + past_key_values.append( + list( + self.get_kv_by_layer_id( + idx=idx, + skip_slice=skip_slice, + seq_len=seq_len, + kvcache_buffer=kvcache_buffer, + seq_ids=seq_ids, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + ) + ) + return past_key_values + + def update_cache( + self, + is_for_context_encoding: bool, + seq_ids: torch.Tensor, + position_ids: torch.Tensor, + new_key_values: List[torch.Tensor], + seq_len: int, + scatter_index=None, + kv_active_mask=None, + kvcache_buffer=None, + windowed_context_encoding_window_idx: int = -1, + **kwargs, + ): + updated_cache = [] + for idx, kv_per_layer in enumerate(new_key_values): + if 
self._is_deltanet_layer(idx): + recurrent_state, conv_state = self.update_deltanet_state_by_layer_id( + idx=idx, + seq_ids=seq_ids, + state_per_layer=kv_per_layer, + kvcache_buffer=kvcache_buffer, + ) + elif kwargs.get("qwen_chunked_prefill_update", False): + recurrent_state, conv_state = self.update_qwen_chunked_kv_by_layer_id( + idx=idx, + seq_ids=seq_ids, + position_ids=position_ids, + kv_per_layer=kv_per_layer, + kvcache_buffer=kvcache_buffer, + valid_mask=kwargs.get("qwen_chunked_valid_mask", None), + ) + else: + recurrent_state, conv_state = self.update_kv_by_layer_id( + idx=idx, + is_for_context_encoding=is_for_context_encoding, + seq_ids=seq_ids, + position_ids=position_ids, + kv_per_layer=kv_per_layer, + seq_len=seq_len, + scatter_index=scatter_index, + kv_active_mask=kv_active_mask, + kvcache_buffer=kvcache_buffer, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + updated_cache.append(recurrent_state) + updated_cache.append(conv_state) + return updated_cache + + def update_qwen_chunked_kv_by_layer_id( + self, + idx: int, + seq_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_per_layer: Tuple[torch.Tensor, torch.Tensor], + kvcache_buffer=None, + valid_mask=None, + ): + latest_k, latest_v = kv_per_layer + k_cache, v_cache = self._fetch_cache(idx, kvcache_buffer) + latest_k = latest_k.to(k_cache.dtype) + latest_v = latest_v.to(v_cache.dtype) + + if seq_ids is not None: + cache_idx = self.get_cache_update_index_for_seq_ids(seq_ids) + selected_k = torch.index_select(k_cache, dim=0, index=cache_idx) + selected_v = torch.index_select(v_cache, dim=0, index=cache_idx) + else: + cache_idx = None + selected_k = k_cache[: latest_k.shape[0]] + selected_v = v_cache[: latest_v.shape[0]] + + pos = position_ids.long() + k_index = pos[:, None, :, None].expand_as(latest_k) + v_index = pos[:, None, :, None].expand_as(latest_v) + + if valid_mask is not None: + valid = valid_mask.to(torch.bool)[:, None, :, None] + old_k = 
torch.gather(selected_k, dim=2, index=k_index) + old_v = torch.gather(selected_v, dim=2, index=v_index) + latest_k = torch.where(valid, latest_k, old_k) + latest_v = torch.where(valid, latest_v, old_v) + + updated_k = torch.scatter(selected_k, dim=2, index=k_index, src=latest_k) + updated_v = torch.scatter(selected_v, dim=2, index=v_index, src=latest_v) + + if cache_idx is not None: + k_row_index = cache_idx.view(-1, 1, 1, 1).expand_as(updated_k) + v_row_index = cache_idx.view(-1, 1, 1, 1).expand_as(updated_v) + k_cache = torch.scatter(k_cache, dim=0, index=k_row_index, src=updated_k) + v_cache = torch.scatter(v_cache, dim=0, index=v_row_index, src=updated_v) + return k_cache, v_cache + + if updated_k.shape[0] == k_cache.shape[0]: + return updated_k + k_cache * 0, updated_v + v_cache * 0 + + pad_rows = k_cache.shape[0] - updated_k.shape[0] + if pad_rows > 0: + updated_k = torch.cat([updated_k, k_cache[updated_k.shape[0] :] * 0], dim=0) + updated_v = torch.cat([updated_v, v_cache[updated_v.shape[0] :] * 0], dim=0) + return updated_k + k_cache * 0, updated_v + v_cache * 0 + + def update_deltanet_state_by_layer_id( + self, + idx: int, + seq_ids: torch.Tensor, + state_per_layer: Tuple[torch.Tensor, torch.Tensor], + kvcache_buffer=None, + ): + latest_recurrent, latest_conv = state_per_layer + recurrent_cache, conv_cache = self._fetch_cache(idx, kvcache_buffer) + latest_recurrent = latest_recurrent.to(recurrent_cache.dtype) + latest_conv = latest_conv.to(conv_cache.dtype) + + if latest_recurrent.shape[0] == recurrent_cache.shape[0] and seq_ids is None: + return ( + latest_recurrent + recurrent_cache * 0, + latest_conv + conv_cache * 0, + ) + + if seq_ids is not None: + cache_idx = self.get_cache_update_index_for_seq_ids(seq_ids) + recurrent_index = cache_idx.view(-1, 1, 1, 1).expand_as(latest_recurrent) + conv_index = cache_idx.view(-1, 1, 1).expand_as(latest_conv) + recurrent_cache = torch.scatter( + input=recurrent_cache, + dim=0, + index=recurrent_index, + 
src=latest_recurrent, + ) + conv_cache = torch.scatter( + input=conv_cache, + dim=0, + index=conv_index, + src=latest_conv, + ) + return recurrent_cache, conv_cache + + pad_size = recurrent_cache.shape[0] - latest_recurrent.shape[0] + if pad_size > 0: + latest_recurrent = torch.cat( + [latest_recurrent, recurrent_cache[latest_recurrent.shape[0] :] * 0], + dim=0, + ) + latest_conv = torch.cat( + [latest_conv, conv_cache[latest_conv.shape[0] :] * 0], + dim=0, + ) + return latest_recurrent + recurrent_cache * 0, latest_conv + conv_cache * 0 + + +# ============================================================ +# Model +# ============================================================ + + +class NeuronQwen35Model(NeuronBaseModel): + def setup_attr_for_model(self, config: Qwen35InferenceConfig): + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + def init_model(self, config: Qwen35InferenceConfig): + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + ) + self.layers = nn.ModuleList( + [ + NeuronQwen35DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = get_rmsnorm_cls()(self.hidden_size, eps=config.rms_norm_eps) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=False if self.on_device_sampling else True, + bias=False, + ) + + # mRoPE embedding for VL + self.mrope_emb = Qwen35MRoPEEmbedding(config) + + def 
init_inference_optimization(self, config: Qwen35InferenceConfig): + super().init_inference_optimization(config) + if getattr(config, "use_hybrid_cache_manager", False): + self.kv_mgr = HybridDeltaNetCacheManager( + config, + num_kv_head=self.num_key_value_heads, + global_rank=self.rank_util, + attention_chunk_size=self.attention_chunk_size, + sliding_window=self.sliding_window, + windowed_context_encoding_size=self.windowed_context_encoding_size, + layer_to_cache_size_mapping=self.layer_to_cache_size_mapping, + ) + + @property + def _deltanet_state_params(self): + """Return DeltaNet state nn.Parameters in alias order.""" + params = [] + for layer in self.layers: + if hasattr(layer, "linear_attn"): + params.append(layer.linear_attn.recurrent_state_buffer) + params.append(layer.linear_attn.conv_state_buffer) + return params + + def encode_vision_to_input(self, inputs_embeds, vision_embeddings, vision_mask): + """Scatter vision embeddings into text input embeddings at image token positions.""" + _, max_positions, embedding_dim = inputs_embeds.shape + h_new = inputs_embeds.clone() + vision_flat = vision_embeddings.view(-1, embedding_dim) + positions_flat = vision_mask.view(-1) + h_new.view(-1, embedding_dim).index_put_( + (positions_flat,), vision_flat, accumulate=False + ) + return h_new + + def get_model_output( + self, + input_ids=None, + seq_ids=None, + attention_mask=None, + position_ids=None, + past_key_values=None, + active_mask=None, + inputs_embeds=None, + prev_hidden=None, + adapter_ids=None, + rotary_position_ids=None, + update_cache=False, + is_for_context_encoding=False, + vision_embeddings=None, + vision_mask=None, + local_attn_mask=None, + windowed_context_encoding_window_idx=-1, + padding_mask=None, + **kwargs, + ): + """Override to collect DeltaNet state tensors from decoder layers.""" + batch_size, seq_length = input_ids.shape[:2] + if self.config.neuron_config.layer_boundary_markers: + input_ids = ModuleMarkerStartWrapper()(input_ids) + + 
past_key_values_length = 0 + if past_key_values is not None: + if hasattr(self.kv_mgr, "get_seq_length"): + past_key_values_length = self.kv_mgr.get_seq_length(past_key_values) + else: + past_key_values_length = past_key_values[0][1].shape[2] + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # CRITICAL: Zero out embeddings for padding tokens so DeltaNet recurrence + # is not polluted. DeltaNet has no attention mask -- it processes all + # sequence positions through a linear recurrence. Padding tokens have + # real embedding vectors which corrupt the recurrence state. + # The mask is [B, S, 1] float with 1.0 for real tokens, 0.0 for padding. + if ( + is_for_context_encoding + and attention_mask is not None + and attention_mask.ndim == 2 + ): + deltanet_padding_mask = attention_mask.unsqueeze(-1).to( + inputs_embeds.dtype + ) + else: + deltanet_padding_mask = ( + (input_ids != self.padding_idx).unsqueeze(-1).to(inputs_embeds.dtype) + ) + if is_for_context_encoding: + inputs_embeds = inputs_embeds * deltanet_padding_mask + + # Vision embedding injection. Text-only calls still pass dummy vision + # tensors to keep the traced input signature stable; those tensors have + # one dummy entry per text token and must not overwrite text embeddings. 
+ if (vision_embeddings is not None) and (vision_mask is not None): + if vision_embeddings.dtype != self.config.neuron_config.torch_dtype: + vision_embeddings = vision_embeddings.to( + self.config.neuron_config.torch_dtype + ) + has_real_vision_inputs = ( + vision_embeddings.ndim == 3 + and vision_mask.ndim == 3 + and vision_embeddings.shape[1] != seq_length + ) + if is_for_context_encoding and has_real_vision_inputs: + inputs_embeds = self.encode_vision_to_input( + inputs_embeds, vision_embeddings, vision_mask + ) + elif is_for_context_encoding and vision_embeddings.numel() > 0: + inputs_embeds = inputs_embeds + vision_embeddings.sum() * 0 + inputs_embeds = ( + inputs_embeds + vision_mask.sum().to(inputs_embeds.dtype) * 0 + ) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + hidden_states = inputs_embeds + + # Get KV cache for TKG and for model-local chunked CTE. 
+ use_qwen_chunked_prefill = ( + is_for_context_encoding + and getattr(self.config, "use_qwen_hybrid_chunked_prefill", False) + ) + cache_size = ( + self.config.neuron_config.seq_len + if use_qwen_chunked_prefill + else self.n_positions + ) + if (not is_for_context_encoding) or use_qwen_chunked_prefill: + if self.kv_mgr is not None: + past_key_values = self.kv_mgr.get_cache( + seq_ids=seq_ids, + seq_len=cache_size, + is_for_context_encoding=is_for_context_encoding, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + **kwargs, + ) + + # Decoder layers + next_decoder_cache = () + deltanet_state_tensors = [] + cos_cache = None + sin_cache = None + + # Convert 2D attention_mask to 4D causal mask for CTE + if ( + attention_mask is not None + and attention_mask.ndim == 2 + and is_for_context_encoding + ): + causal = torch.ones( + (seq_length, seq_length), + dtype=torch.bool, + device=attention_mask.device, + ).tril() + padding_4d = attention_mask[:, None, None, :].to(torch.bool) + attention_mask = (causal[None, None, :, :] & padding_4d).to( + attention_mask.dtype + ) + + # Pre-compute mRoPE cos/sin + if rotary_position_ids is not None and rotary_position_ids.ndim == 3: + cos_cache, sin_cache = self.mrope_emb(inputs_embeds, rotary_position_ids) + + for idx, decoder_layer in enumerate(self.layers): + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) + + layer_outputs = decoder_layer( + hidden_states, + seq_ids=seq_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + active_mask=active_mask, + adapter_ids=adapter_ids, + cos_cache=cos_cache, + sin_cache=sin_cache, + rotary_position_ids=rotary_position_ids, + kv_mgr=self.kv_mgr, + get_kv_per_layer=False, + update_kv_per_layer=False, + idx=idx, + is_for_context_encoding=is_for_context_encoding, + seq_len=cache_size, + residual=None, + local_mask=local_attn_mask, + 
windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + padding_mask=padding_mask, + deltanet_padding_mask=deltanet_padding_mask, + qwen_chunked_prefill_update=use_qwen_chunked_prefill, + qwen_chunked_valid_mask=deltanet_padding_mask.squeeze(-1) + if use_qwen_chunked_prefill + else None, + **kwargs, + ) + + hidden_states = layer_outputs[0] + kv = layer_outputs[1] + next_decoder_cache += (kv,) + cos_cache, sin_cache = layer_outputs[2:4] + + # Collect DeltaNet state tensors + deltanet_states = layer_outputs[5] if len(layer_outputs) > 5 else None + if deltanet_states is not None: + deltanet_state_tensors.append(deltanet_states[0]) + deltanet_state_tensors.append(deltanet_states[1]) + + # Update KV cache + if update_cache: + next_decoder_cache = self.kv_mgr.update_cache( + is_for_context_encoding=is_for_context_encoding, + seq_ids=seq_ids, + position_ids=position_ids, + new_key_values=next_decoder_cache, + seq_len=cache_size, + windowed_context_encoding_window_idx=windowed_context_encoding_window_idx, + qwen_chunked_prefill_update=use_qwen_chunked_prefill, + qwen_chunked_valid_mask=deltanet_padding_mask.squeeze(-1) + if use_qwen_chunked_prefill + else None, + **kwargs, + ) + + hidden_states = self.norm(hidden_states) + + self._deltanet_updated_states = deltanet_state_tensors + + return (hidden_states, next_decoder_cache) + + def forward( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden=None, + adapter_ids=None, + accepted_indices=None, + current_length=None, + medusa_mask=None, + scatter_index=None, + slot_mapping=None, + active_block_table=None, + num_queries=None, + computed_context_lens=None, + tile_q_indices=None, + tile_block_tables=None, + tile_masks=None, + inputs_embeds=None, + kv_cache=None, + active_mask=None, + rotary_position_id=None, + vision_embeddings=None, + vision_mask=None, + ): + """Override base forward to append DeltaNet state tensors to output.""" + prev_hidden = 
self.set_none_if_empty(prev_hidden) + adapter_ids = self.set_none_if_empty(adapter_ids) + accepted_indices = self.set_none_if_empty(accepted_indices) + current_length = self.set_none_if_empty(current_length) + medusa_mask = self.set_none_if_empty(medusa_mask) + scatter_index = self.set_none_if_empty(scatter_index) + slot_mapping = self.set_none_if_empty(slot_mapping) + active_block_table = self.set_none_if_empty(active_block_table) + num_queries = self.set_none_if_empty(num_queries) + computed_context_lens = self.set_none_if_empty(computed_context_lens) + tile_q_indices = self.set_none_if_empty(tile_q_indices) + tile_block_tables = self.set_none_if_empty(tile_block_tables) + tile_masks = self.set_none_if_empty(tile_masks) + inputs_embeds = self.set_none_if_empty(inputs_embeds) + kv_cache = self.set_none_if_empty(kv_cache) + active_mask = self.set_none_if_empty(active_mask) + rotary_position_id = self.set_none_if_empty(rotary_position_id) + vision_embeddings = self.set_none_if_empty(vision_embeddings) + vision_mask = self.set_none_if_empty(vision_mask) + + is_for_context_encoding = position_ids.shape[-1] != 1 and not ( + hasattr(self.neuron_config, "speculation_length") + and position_ids.shape[-1] == self.neuron_config.speculation_length + ) + + seq_ids = seq_ids.to(torch.int32) + attn_mask = attention_mask + + hidden_states, updated_kv_cache = self.get_model_output( + input_ids=input_ids, + seq_ids=seq_ids, + attention_mask=attn_mask, + position_ids=position_ids, + active_mask=active_mask, + inputs_embeds=inputs_embeds, + adapter_ids=adapter_ids, + rotary_position_ids=rotary_position_id, + update_cache=True, + is_for_context_encoding=is_for_context_encoding, + padding_mask=None, + active_block_table=active_block_table, + scatter_index=slot_mapping + if getattr(self, "is_block_kv_layout", False) + else scatter_index, + vision_embeddings=vision_embeddings, + vision_mask=vision_mask, + ) + + batch_size = input_ids.shape[0] + if not getattr(self, "sliced_hidden", 
False): + if not is_for_context_encoding: + pass + else: + if getattr(self.config, "use_qwen_hybrid_chunked_prefill", False): + if attention_mask is not None and attention_mask.ndim == 2: + index = ( + attention_mask.to(torch.long).sum(dim=1, keepdim=True) + - 1 + ).clamp(min=0) + else: + index = ( + (input_ids != self.padding_idx) + .sum(dim=1, keepdim=True) + .long() + - 1 + ).clamp(min=0) + else: + index = torch.max(position_ids, dim=1, keepdim=True).indices + index = index.unsqueeze(1).expand(batch_size, 1, self.hidden_size) + hidden_states = torch.gather(hidden_states, dim=1, index=index) + + logits = self.lm_head(hidden_states) + logits = logits.float() + + if hasattr(self.lm_head, "pad_size"): + if self.lm_head.gather_output: + rank_id = torch.tensor(0, device=logits.device, dtype=torch.int32) + world_size = 1 + else: + from neuronx_distributed.parallel_layers import parallel_state + + rank_id = self.rank_util.get_rank() + world_size = torch.distributed.get_world_size( + group=self.lm_head.tensor_parallel_group + ) + from neuronx_distributed_inference.models.model_base import ( + mask_padded_logits, + ) + + logits = mask_padded_logits( + logits, rank_id, world_size, pad_size=self.lm_head.pad_size + ) + + if self.on_device_sampling: + res = self._sample_on_device( + logits, sampling_params, False, is_for_context_encoding + ) + else: + res = logits + + outputs = [res] + if self.neuron_config.output_logits: + outputs += [logits] + outputs += updated_kv_cache + + # Append DeltaNet state tensors (for input_output_aliases) + if ( + not getattr(self.config, "use_hybrid_cache_manager", False) + and hasattr(self, "_deltanet_updated_states") + ): + outputs += self._deltanet_updated_states + + return outputs + + +# ============================================================ +# State Dict Converter (Dense -- no MoE weight handling) +# ============================================================ + + +def convert_qwen35_hf_to_neuron_state_dict(neuron_state_dict, config): 
+ """Convert HF Qwen3.5/3.6-27B weights to NxDI format. + + Weight mappings per layer type: + + DeltaNet layers (linear_attention): + HF: layers.X.linear_attn.{in_proj_qkv, in_proj_z, in_proj_a, in_proj_b, + conv1d, A_log, dt_bias, norm, out_proj} + NxDI: projections keep names; conv1d/A_log/dt_bias are remapped into + ColumnParallelLinear parameter containers so NxD can shard them. + + Full attention layers: + HF: layers.X.self_attn.q_proj.weight: (12288, 5120) -- doubled for gate + NxDI: layers.X.self_attn.Wqkv.weight (fused Q+K+V, gate separated) + layers.X.self_attn.output_gate_proj.weight (gate part) + HF: layers.X.self_attn.{k_proj, v_proj, o_proj, q_norm, k_norm} + NxDI: layers.X.self_attn.{..., q_layernorm, k_layernorm} + + Dense MLP (all layers): + HF: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight + NxDI: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight (same names) + """ + # Add rank_util + neuron_state_dict["rank_util.rank"] = torch.arange( + 0, + config.neuron_config.tp_degree, + dtype=torch.int32, + ) + + def _reorder_deltanet_qkv_for_tp(qkv_weight: torch.Tensor) -> torch.Tensor: + """Pack [Q_all | K_all | V_all] into per-rank Q/K/V blocks. + + ColumnParallelLinear slices the first dimension contiguously. DeltaNet + needs each rank to receive its local query, key, and value heads + together, so the full HF tensor is repacked as: + [rank0 Q | rank0 K | rank0 V | rank1 Q | rank1 K | rank1 V | ...]. 
+ """ + tp_degree = config.neuron_config.tp_degree + num_k_heads = config.linear_num_key_heads + num_v_heads = config.linear_num_value_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim + if num_k_heads % tp_degree != 0: + raise ValueError( + f"linear_num_key_heads={num_k_heads} must be divisible by tp_degree={tp_degree}" + ) + if num_v_heads % tp_degree != 0: + raise ValueError( + f"linear_num_value_heads={num_v_heads} must be divisible by tp_degree={tp_degree}" + ) + + key_dim = num_k_heads * head_k_dim + value_dim = num_v_heads * head_v_dim + q_weight = qkv_weight[:key_dim].reshape(num_k_heads, head_k_dim, -1) + k_weight = qkv_weight[key_dim : 2 * key_dim].reshape(num_k_heads, head_k_dim, -1) + v_weight = qkv_weight[2 * key_dim : 2 * key_dim + value_dim].reshape( + num_v_heads, head_v_dim, -1 + ) + local_k_heads = num_k_heads // tp_degree + local_v_heads = num_v_heads // tp_degree + blocks = [] + for rank in range(tp_degree): + blocks.append( + q_weight[ + rank * local_k_heads : (rank + 1) * local_k_heads + ].reshape(-1, qkv_weight.shape[1]) + ) + blocks.append( + k_weight[ + rank * local_k_heads : (rank + 1) * local_k_heads + ].reshape(-1, qkv_weight.shape[1]) + ) + blocks.append( + v_weight[ + rank * local_v_heads : (rank + 1) * local_v_heads + ].reshape(-1, qkv_weight.shape[1]) + ) + return torch.cat(blocks, dim=0).contiguous() + + def _reorder_deltanet_qkv_channels_for_tp(channel_tensor: torch.Tensor) -> torch.Tensor: + """Repack a first-dimension Q/K/V channel tensor into TP rank blocks.""" + tp_degree = config.neuron_config.tp_degree + num_k_heads = config.linear_num_key_heads + num_v_heads = config.linear_num_value_heads + head_k_dim = config.linear_key_head_dim + head_v_dim = config.linear_value_head_dim + key_dim = num_k_heads * head_k_dim + value_dim = num_v_heads * head_v_dim + q_tensor = channel_tensor[:key_dim] + k_tensor = channel_tensor[key_dim : 2 * key_dim] + v_tensor = channel_tensor[2 * key_dim : 2 * 
key_dim + value_dim] + local_key_dim = key_dim // tp_degree + local_value_dim = value_dim // tp_degree + blocks = [] + for rank in range(tp_degree): + blocks.append(q_tensor[rank * local_key_dim : (rank + 1) * local_key_dim]) + blocks.append(k_tensor[rank * local_key_dim : (rank + 1) * local_key_dim]) + blocks.append( + v_tensor[rank * local_value_dim : (rank + 1) * local_value_dim] + ) + return torch.cat(blocks, dim=0).contiguous() + + # CRITICAL: Convert (1+weight) RMSNorm weights to standard RMSNorm weights. + # Qwen3.5 uses RMSNorm with `output = norm(x) * (1 + weight)` where weight + # is initialized to zeros. Standard NxDI RMSNorm uses `output = norm(x) * weight` + # where weight is initialized to ones. To convert: new_weight = old_weight + 1.0 + norm_keys_to_convert = [] + for l in range(config.num_hidden_layers): + norm_keys_to_convert.append(f"layers.{l}.input_layernorm.weight") + norm_keys_to_convert.append(f"layers.{l}.post_attention_layernorm.weight") + if config.layer_types[l] == "full_attention": + norm_keys_to_convert.append(f"layers.{l}.self_attn.q_norm.weight") + norm_keys_to_convert.append(f"layers.{l}.self_attn.k_norm.weight") + norm_keys_to_convert.append("norm.weight") + + for nk in norm_keys_to_convert: + if nk in neuron_state_dict: + old_val = neuron_state_dict[nk] + neuron_state_dict[nk] = old_val.float() + 1.0 + if "layers.0." in nk or nk == "norm.weight": + logger.debug( + f"[NORM FIX] {nk}: mean {old_val.float().mean():.4f} -> {neuron_state_dict[nk].mean():.4f}" + ) + else: + if "layers.0." 
in nk or nk == "norm.weight": + logger.warning(f"[NORM FIX] key not found: {nk}") + + for l in range(config.num_hidden_layers): + layer_type = config.layer_types[l] + + # === DeltaNet layers === + if layer_type == "linear_attention": + qkv_key = f"layers.{l}.linear_attn.in_proj_qkv.weight" + if qkv_key in neuron_state_dict and config.neuron_config.tp_degree > 1: + neuron_state_dict[qkv_key] = _reorder_deltanet_qkv_for_tp( + neuron_state_dict[qkv_key] + ) + + conv_key = f"layers.{l}.linear_attn.conv1d.weight" + conv_weight_key = f"layers.{l}.linear_attn.conv1d_weight.weight" + if conv_key in neuron_state_dict: + conv_weight = neuron_state_dict.pop(conv_key) + if config.neuron_config.tp_degree > 1: + conv_weight = _reorder_deltanet_qkv_channels_for_tp(conv_weight) + neuron_state_dict[conv_weight_key] = conv_weight.squeeze(1).contiguous() + + for vector_name in ("A_log", "dt_bias"): + vector_key = f"layers.{l}.linear_attn.{vector_name}" + vector_weight_key = f"layers.{l}.linear_attn.{vector_name}_weight.weight" + if vector_key in neuron_state_dict: + neuron_state_dict[vector_weight_key] = ( + neuron_state_dict.pop(vector_key).reshape(-1, 1).contiguous() + ) + + # === Attention layers === + if layer_type == "full_attention": + neuron_state_dict[f"layers.{l}.self_attn.rank_util.rank"] = torch.arange( + 0, + config.neuron_config.tp_degree, + dtype=torch.int32, + ) + + # QK norms: q_norm -> q_layernorm, k_norm -> k_layernorm + q_norm_key = f"layers.{l}.self_attn.q_norm.weight" + k_norm_key = f"layers.{l}.self_attn.k_norm.weight" + if q_norm_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.q_layernorm.weight"] = ( + neuron_state_dict.pop(q_norm_key).detach().clone() + ) + if k_norm_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.k_layernorm.weight"] = ( + neuron_state_dict.pop(k_norm_key).detach().clone() + ) + + # q_proj is doubled: (12288, 5120) = (num_heads * head_dim * 2, hidden) + # INTERLEAVED: [head0_query(256) | 
head0_gate(256) | head1_query(256) | ...] + q_proj_key = f"layers.{l}.self_attn.q_proj.weight" + if q_proj_key in neuron_state_dict: + q_proj_w = neuron_state_dict.pop(q_proj_key) + num_heads = config.num_attention_heads # 24 + head_dim = config.head_dim # 256 + q_proj_w = q_proj_w.reshape(num_heads, head_dim * 2, config.hidden_size) + query_w = q_proj_w[:, :head_dim, :] # (24, 256, 5120) + gate_w = q_proj_w[:, head_dim:, :] # (24, 256, 5120) + query_w = query_w.reshape( + num_heads * head_dim, config.hidden_size + ) # (6144, 5120) + gate_w = gate_w.reshape( + num_heads * head_dim, config.hidden_size + ) # (6144, 5120) + + neuron_state_dict[q_proj_key] = query_w + neuron_state_dict[f"layers.{l}.self_attn.output_gate_proj.weight"] = ( + gate_w + ) + + # Fuse QKV + if config.neuron_config.fused_qkv: + q_key = f"layers.{l}.self_attn.q_proj.weight" + k_key = f"layers.{l}.self_attn.k_proj.weight" + v_key = f"layers.{l}.self_attn.v_proj.weight" + if q_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.self_attn.Wqkv.weight"] = torch.cat( + [ + neuron_state_dict[q_key], + neuron_state_dict[k_key], + neuron_state_dict[v_key], + ] + ) + del neuron_state_dict[q_key] + del neuron_state_dict[k_key] + del neuron_state_dict[v_key] + + # Dense MLP: no weight conversion needed -- HF and NxDI use same names + # HF: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight + # NxDI: layers.X.mlp.{gate_proj, up_proj, down_proj}.weight + + gc.collect() + + return neuron_state_dict + + +# ============================================================ +# Custom ModelWrapper and DecoderModelInstance for DeltaNet state aliasing +# ============================================================ + + +class Qwen35DecoderModelInstance(DecoderModelInstance): + """Custom DecoderModelInstance that adds DeltaNet state buffers to input_output_aliases.""" + + def get(self, bucket_rank, **kwargs): + """Override to add DeltaNet state aliases after KV cache aliases.""" + module, input_output_aliases = 
super().get(bucket_rank, **kwargs) + + num_output_from_trace = 1 if not self.neuron_config.output_logits else 2 + + if module.kv_mgr is not None: + num_kv = len(module.kv_mgr.past_key_values) + else: + num_kv = 0 + + state_start_idx = num_output_from_trace + num_kv + + if ( + not getattr(module.config, "use_hybrid_cache_manager", False) + and hasattr(module, "_deltanet_state_params") + ): + for i, param in enumerate(module._deltanet_state_params): + input_output_aliases[param] = state_start_idx + i + + return module, input_output_aliases + + +class Qwen35ModelWrapper(ModelWrapper): + """Custom ModelWrapper for VL support with mRoPE and vision inputs.""" + + def get_model_instance(self): + return Qwen35DecoderModelInstance( + model_cls=self.model_cls, + config=self.config, + **self.model_init_kwargs, + ) + + def input_generator(self): + """Generate inputs including mrope_position_ids, vision_embeddings, and vision_mask.""" + base_inputs = super().input_generator() + extended_inputs = [] + + for bucket_inputs in base_inputs: + input_ids = bucket_inputs[0] + batch_size = input_ids.shape[0] + n_active_tokens = input_ids.shape[1] + + is_cte = n_active_tokens > 1 + + if is_cte: + mrope_position_ids = ( + torch.arange(0, n_active_tokens, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + + vision_embeddings = torch.zeros( + (batch_size, n_active_tokens, self.config.hidden_size), + dtype=self.config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + (batch_size, n_active_tokens, 1), + fill_value=n_active_tokens - 1, + dtype=torch.int32, + ) + else: + mrope_position_ids = torch.zeros((0,), dtype=torch.int32) + vision_embeddings = torch.zeros( + (0,), dtype=self.config.neuron_config.torch_dtype + ) + vision_mask = torch.zeros((0,), dtype=torch.int32) + + padded = list(bucket_inputs) + while len(padded) < 21: + padded.append(torch.zeros((0,), dtype=torch.int32)) + padded.append(mrope_position_ids) # position 21 + 
padded.append(vision_embeddings) # position 22 + padded.append(vision_mask) # position 23 + + extended_inputs.append(tuple(padded)) + + return extended_inputs + + def pad_inputs(self, *args, pad_type="first_fit"): + """Override to pad mrope_position_ids and vision inputs to bucket size.""" + orig_mrope = args[21] if len(args) >= 22 else None + orig_vis_emb = args[22] if len(args) >= 23 else None + orig_vis_mask = args[23] if len(args) >= 24 else None + + padded_args = super().pad_inputs(*args, pad_type=pad_type) + + if len(padded_args) >= 24 and orig_mrope is not None: + padded_seq_len = padded_args[0].shape[1] + batch_size = padded_args[0].shape[0] + is_cte = padded_seq_len > 1 + + if is_cte: + current_mrope = orig_mrope + current_vis_emb = orig_vis_emb + current_vis_mask = orig_vis_mask + + if ( + current_mrope.ndim == 3 + and current_mrope.shape[-1] != padded_seq_len + ): + orig_len = current_mrope.shape[-1] + pad_size = padded_seq_len - orig_len + last_pos = current_mrope[:, :, -1:] + pad_offsets = torch.arange( + 1, pad_size + 1, dtype=current_mrope.dtype + ) + pad_offsets = ( + pad_offsets.unsqueeze(0).unsqueeze(0).expand(3, batch_size, -1) + ) + mrope_pad = last_pos + pad_offsets + mrope_position_ids = torch.cat([current_mrope, mrope_pad], dim=-1) + elif current_mrope.ndim == 3: + mrope_position_ids = current_mrope + else: + mrope_position_ids = ( + torch.arange(0, padded_seq_len, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + + if ( + current_vis_emb is not None + and current_vis_emb.ndim == 3 + and current_vis_emb.shape[1] < padded_seq_len + ): + pad_emb = torch.zeros( + ( + batch_size, + padded_seq_len - current_vis_emb.shape[1], + current_vis_emb.shape[2], + ), + dtype=current_vis_emb.dtype, + ) + vision_embeddings = torch.cat([current_vis_emb, pad_emb], dim=1) + elif current_vis_emb is not None and current_vis_emb.ndim == 3: + vision_embeddings = current_vis_emb[:, :padded_seq_len] + else: + 
vision_embeddings = torch.zeros( + (batch_size, padded_seq_len, self.config.hidden_size), + dtype=self.config.neuron_config.torch_dtype, + ) + + if ( + current_vis_mask is not None + and current_vis_mask.ndim == 3 + and current_vis_mask.shape[1] < padded_seq_len + ): + pad_mask = torch.full( + (batch_size, padded_seq_len - current_vis_mask.shape[1], 1), + fill_value=padded_seq_len - 1, + dtype=torch.int32, + ) + vision_mask = torch.cat([current_vis_mask, pad_mask], dim=1) + elif current_vis_mask is not None and current_vis_mask.ndim == 3: + vision_mask = current_vis_mask[:, :padded_seq_len] + else: + vision_mask = torch.full( + (batch_size, padded_seq_len, 1), + fill_value=padded_seq_len - 1, + dtype=torch.int32, + ) + + padded_args = ( + *padded_args[:21], + mrope_position_ids, + vision_embeddings, + vision_mask, + ) + + padded_args = list(padded_args) + padded_args[23] = padded_args[23].clamp(max=padded_seq_len - 1) + padded_args = tuple(padded_args) + + return padded_args + + +# ============================================================ +# Top-Level Model +# ============================================================ + + +class NeuronQwen35ForCausalLM(NeuronBaseForCausalLM): + _model_cls = NeuronQwen35Model + + def get_model_wrapper_cls(self): + """Return custom ModelWrapper with DeltaNet state aliasing.""" + return Qwen35ModelWrapper + + @staticmethod + def load_hf_model(model_path, **kwargs): + """Load HF model weights. + + The model is a VL model (Qwen3_5ForConditionalGeneration) but we + only need the text backbone. 
+ """ + from transformers import AutoModelForCausalLM + + kwargs.setdefault("trust_remote_code", True) + return AutoModelForCausalLM.from_pretrained(model_path, **kwargs) + + @classmethod + def get_config_cls(cls): + return Qwen35InferenceConfig + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict, config): + """Strip VL wrapper prefix and convert to NxDI format.""" + new_sd = {} + for k, v in state_dict.items(): + if k.startswith("language_model."): + new_k = k.replace("language_model.", "", 1) + new_sd[new_k] = v + elif k.startswith("model.language_model."): + new_k = k.replace("model.language_model.", "", 1) + new_sd[new_k] = v + elif k.startswith("model.visual") or k.startswith("visual"): + continue # Skip vision encoder + elif k.startswith("model."): + new_sd[k.replace("model.", "", 1)] = v + elif k.startswith("mtp."): + continue # Skip MTP + elif k.startswith("lm_head."): + new_sd[k] = v + else: + new_sd[k] = v + + return convert_qwen35_hf_to_neuron_state_dict(new_sd, config) + + def enable_context_encoding(self): + self.compile_tag = CONTEXT_ENCODING_MODEL_TAG + super().enable_context_encoding() + + def enable_token_generation(self): + self.compile_tag = TOKEN_GENERATION_MODEL_TAG + super().enable_token_generation() + + def _copy_past_key_values(self, outputs): + """Override to also copy DeltaNet state buffers on CPU.""" + super()._copy_past_key_values(outputs) + if getattr(self.config, "use_hybrid_cache_manager", False): + return + + num_output_from_trace = 1 + if ( + self.neuron_config.output_logits + and self.neuron_config.on_device_sampling_config + ): + num_output_from_trace = 2 + + if ( + hasattr(self, "token_generation_model") + and self.token_generation_model is not None + ): + tkg_model = self.token_generation_model.model + cte_model = self.context_encoding_model.model + else: + return + + if tkg_model.kv_mgr is not None: + num_kv = len(tkg_model.kv_mgr.past_key_values) + else: + num_kv = 0 + + state_start = num_output_from_trace + 
num_kv + + tkg_params = getattr(tkg_model, "_deltanet_state_params", []) + cte_params = getattr(cte_model, "_deltanet_state_params", []) + + if len(tkg_params) > 0 and state_start + len(tkg_params) <= len(outputs): + for i, (tkg_param, cte_param) in enumerate(zip(tkg_params, cte_params)): + new_state = outputs[state_start + i] + tkg_param.data = new_state + cte_param.data = new_state + + def get_required_kwargs(self): + """Return extra kwargs for HF generation loop.""" + return ["llava_args"] + + def _get_model_outputs( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden, + adapter_ids, + medusa_args, + llava_args, + slot_mapping=None, + block_table=None, + full_context_lens=None, + computed_context_lens=None, + tf_args=None, + ): + """Override to pass all 24 positional args explicitly.""" + is_prefill = self._is_prefill(position_ids) or ( + getattr(self.config, "use_qwen_hybrid_chunked_prefill", False) + and input_ids.shape[-1] > 1 + ) + + seq_len = input_ids.shape[1] + batch_size = input_ids.shape[0] + + if llava_args and len(llava_args) >= 2: + vision_embeddings = llava_args[0] + vision_mask = llava_args[1] + if len(llava_args) >= 3: + mrope_position_ids = llava_args[2] + else: + mrope_position_ids = None + elif is_prefill: + vision_embeddings = torch.zeros( + (batch_size, seq_len, self.config.hidden_size), + dtype=self.config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + (batch_size, seq_len, 1), + fill_value=seq_len - 1, + dtype=torch.int32, + ) + mrope_position_ids = None + else: + vision_embeddings = torch.zeros((0,), dtype=torch.float32) + vision_mask = torch.zeros((0,), dtype=torch.int32) + mrope_position_ids = None + + if is_prefill: + if mrope_position_ids is None: + mrope_position_ids = ( + torch.arange(0, seq_len, dtype=torch.int32) + .unsqueeze(0) + .unsqueeze(0) + .expand(3, batch_size, -1) + .contiguous() + ) + else: + mrope_position_ids = torch.zeros((0,), dtype=torch.int32) + + empties 
= [torch.empty(0) for _ in range(14)] + + if is_prefill: + ctx_bs = self.context_encoding_model.neuron_config.batch_size + output_logits = [] + + for cb in range(0, batch_size, ctx_bs): + cb_end = min(cb + ctx_bs, batch_size) + actual_chunk = cb_end - cb + + chunk_input_ids = input_ids[cb:cb_end] + chunk_attn_mask = attention_mask[cb:cb_end] + chunk_pos_ids = position_ids[cb:cb_end] + chunk_seq_ids = seq_ids[cb:cb_end] + chunk_sampling = sampling_params[cb:cb_end] + chunk_prev_hidden = ( + prev_hidden[cb:cb_end] + if prev_hidden is not None + and hasattr(prev_hidden, "ndim") + and prev_hidden.ndim > 0 + and prev_hidden.shape[0] > 0 + else prev_hidden + ) + chunk_adapter_ids = ( + adapter_ids[cb:cb_end] + if adapter_ids is not None + and hasattr(adapter_ids, "ndim") + and adapter_ids.ndim > 0 + and adapter_ids.shape[0] > 0 + else adapter_ids + ) + + if mrope_position_ids.ndim == 3: + chunk_mrope = mrope_position_ids[:, cb:cb_end, :] + else: + chunk_mrope = mrope_position_ids + + if vision_embeddings.ndim == 3: + chunk_vis_emb = vision_embeddings[cb:cb_end] + chunk_vis_mask = vision_mask[cb:cb_end] + else: + chunk_vis_emb = vision_embeddings + chunk_vis_mask = vision_mask + + if actual_chunk < ctx_bs: + pad_n = ctx_bs - actual_chunk + chunk_input_ids = torch.cat( + [chunk_input_ids, chunk_input_ids[:1].expand(pad_n, -1)], dim=0 + ) + chunk_attn_mask = torch.cat( + [chunk_attn_mask, chunk_attn_mask[:1].expand(pad_n, -1)], dim=0 + ) + chunk_pos_ids = torch.cat( + [chunk_pos_ids, chunk_pos_ids[:1].expand(pad_n, -1)], dim=0 + ) + pad_seq = torch.arange( + batch_size, batch_size + pad_n, dtype=chunk_seq_ids.dtype + ) + chunk_seq_ids = torch.cat([chunk_seq_ids, pad_seq], dim=0) + chunk_sampling = torch.cat( + [chunk_sampling, chunk_sampling[:1].expand(pad_n, -1)], dim=0 + ) + if ( + chunk_prev_hidden is not None + and hasattr(chunk_prev_hidden, "ndim") + and chunk_prev_hidden.ndim > 0 + and chunk_prev_hidden.shape[0] > 0 + ): + chunk_prev_hidden = torch.cat( + [ + 
chunk_prev_hidden, + chunk_prev_hidden[:1].expand(pad_n, -1), + ], + dim=0, + ) + if ( + chunk_adapter_ids is not None + and hasattr(chunk_adapter_ids, "ndim") + and chunk_adapter_ids.ndim > 0 + and chunk_adapter_ids.shape[0] > 0 + ): + chunk_adapter_ids = torch.cat( + [ + chunk_adapter_ids, + chunk_adapter_ids[:1].expand(pad_n, -1), + ], + dim=0, + ) + if chunk_mrope.ndim == 3: + chunk_mrope = torch.cat( + [chunk_mrope, chunk_mrope[:, :1, :].expand(-1, pad_n, -1)], + dim=1, + ) + if chunk_vis_emb.ndim == 3: + chunk_vis_emb = torch.cat( + [ + chunk_vis_emb, + torch.zeros( + (pad_n,) + chunk_vis_emb.shape[1:], + dtype=chunk_vis_emb.dtype, + ), + ], + dim=0, + ) + chunk_vis_mask = torch.cat( + [ + chunk_vis_mask, + torch.full( + (pad_n,) + chunk_vis_mask.shape[1:], + fill_value=seq_len - 1, + dtype=chunk_vis_mask.dtype, + ), + ], + dim=0, + ) + + chunk_out = self.context_encoding_model( + chunk_input_ids, + chunk_attn_mask, + chunk_pos_ids, + chunk_seq_ids, + chunk_sampling, + chunk_prev_hidden, + chunk_adapter_ids, + *empties, + chunk_mrope, + chunk_vis_emb, + chunk_vis_mask, + ) + if actual_chunk < ctx_bs: + chunk_out = chunk_out[:actual_chunk] + output_logits.append(chunk_out) + + outputs = ( + torch.cat(output_logits, dim=0) + if len(output_logits) > 1 + else output_logits[0] + ) + self.kv_cache_populated = True + is_run_on_neuron = self.context_encoding_model.is_neuron() + else: + outputs = self.token_generation_model( + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + prev_hidden, + adapter_ids, + *empties, + mrope_position_ids, + vision_embeddings, + vision_mask, + ) + is_run_on_neuron = self.token_generation_model.is_neuron() + + return outputs, is_run_on_neuron + + def get_compiler_args(self): + if self.compile_tag == CONTEXT_ENCODING_MODEL_TAG: + optimization_level = "-O1" + else: + optimization_level = "-O1" + + compiler_args = ( + "--enable-saturate-infinity " + "--enable-mixed-precision-accumulation " + f"--model-type 
transformer {optimization_level} " + "--auto-cast=none " + ) + return compiler_args diff --git a/contrib/models/Qwen3.6-27B/src/modeling_qwen35_vision.py b/contrib/models/Qwen3.6-27B/src/modeling_qwen35_vision.py new file mode 100644 index 00000000..761d7e95 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/modeling_qwen35_vision.py @@ -0,0 +1,819 @@ +""" +Qwen3.5-27B / Qwen3.6-27B (Dense) Vision Encoder for NeuronX Distributed Inference. + +Ports the Qwen3.5/3.6 ViT encoder to run on Neuron. The vision encoder +architecture is identical across Qwen3.5-27B and Qwen3.6-27B (same patch +embed, same rotary, same merger) -- only out_hidden_size changes vs the MoE +variant (5120 vs 2048, read from config). + +The vision encoder runs as a separate compiled model from the text decoder, +compiled and loaded via NeuronBaseForImageToText. +""" + +import logging +import math +import os +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# CRITICAL: Use finite negative value instead of -inf for Neuron attention masks. +# The Neuron compiler's bfloat16 handling of -inf produces NaN that bleeds from +# padding positions into ALL positions through the transformer layers. +# -65504.0 is large enough for softmax masking but avoids NaN overflow. 
def apply_rotary_pos_emb_vision(q, k, cos, sin):
    """Rotate vision query and key tensors with rotary position embeddings.

    Each tensor ``t`` becomes ``t * cos + rotate_half(t) * sin``, where
    ``rotate_half`` swaps the two halves of the last dimension and negates
    the half that moves to the front.

    Args:
        q: Query tensor; the docstring convention in this file is
            ``(seq_len, num_heads, head_dim)``.
        k: Key tensor, same layout as ``q``.
        cos: Cosine table, broadcast over the heads axis.
        sin: Sine table, broadcast over the heads axis.

    Returns:
        Tuple ``(q_embed, k_embed)``, each cast back to its input dtype.
    """

    def _swap_halves(t):
        half = t.shape[-1] // 2
        front, back = t[..., :half], t[..., half:]
        return torch.cat((-back, front), dim=-1)

    # Insert a singleton heads axis so the tables broadcast against q and k.
    cos_b = cos.unsqueeze(-2)
    sin_b = sin.unsqueeze(-2)

    rotated = tuple(
        ((t * cos_b) + (_swap_halves(t) * sin_b)).to(t.dtype) for t in (q, k)
    )
    return rotated
class NeuronQwen35VisionMLP(nn.Module):
    """Feed-forward part of a vision block: fc1 -> GELU -> fc2.

    Both projections carry a bias. ``linear_fc1`` is a column-parallel layer
    that gathers its output, and ``linear_fc2`` is a row-parallel layer fed
    with a non-parallel input.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        inner = config.intermediate_size
        self.linear_fc1 = nxd_layers.ColumnParallelLinear(
            width,
            inner,
            bias=True,
            gather_output=True,
        )
        self.linear_fc2 = nxd_layers.RowParallelLinear(
            inner,
            width,
            bias=True,
            input_is_parallel=False,
        )
        self.act_fn = nn.GELU()

    def forward(self, hidden_states):
        """Apply the up-projection, the activation and the down-projection."""
        expanded = self.linear_fc1(hidden_states)
        activated = self.act_fn(expanded)
        return self.linear_fc2(activated)
class CPUVisionModel(nn.Module):
    """Vision encoder assembled from plain PyTorch layers (no Neuron layers).

    Intended for running the vision encoder on CPU when it cannot be loaded
    on the accelerator next to the text decoder. The layout is a stack of
    transformer blocks followed by a merger (norm -> spatial merge -> fc1 ->
    GELU -> fc2).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        layers = []
        for _ in range(config.depth):
            layers.append(self._make_block(config))
        self.blocks = nn.ModuleList(layers)

        # The merger folds spatial_merge_size**2 neighbouring tokens into one.
        merged_width = config.hidden_size * (config.spatial_merge_size**2)
        self.merger_norm = nn.LayerNorm(config.hidden_size, eps=1e-6)
        self.merger_fc1 = nn.Linear(merged_width, merged_width)
        self.merger_act = nn.GELU()
        self.merger_fc2 = nn.Linear(merged_width, config.out_hidden_size)

    @staticmethod
    def _make_block(config):
        """Create one transformer block as a tree of bare ``nn.Module`` holders."""
        width = config.hidden_size
        heads = config.num_heads

        # Attention sub-module: fused QKV projection plus output projection.
        attention = nn.Module()
        attention.hidden_size = width
        attention.num_heads = heads
        attention.head_dim = width // heads
        attention.scaling = attention.head_dim**-0.5
        attention.qkv = nn.Linear(width, 3 * width, bias=True)
        attention.proj = nn.Linear(width, width, bias=True)

        # Feed-forward sub-module.
        feed_forward = nn.Module()
        feed_forward.linear_fc1 = nn.Linear(
            width, config.intermediate_size, bias=True
        )
        feed_forward.linear_fc2 = nn.Linear(
            config.intermediate_size, width, bias=True
        )
        feed_forward.act_fn = nn.GELU()

        layer = nn.Module()
        layer.norm1 = nn.LayerNorm(width, eps=1e-6)
        layer.norm2 = nn.LayerNorm(width, eps=1e-6)
        layer.attn = attention
        layer.mlp = feed_forward
        return layer

    @staticmethod
    def _rotate_half(t):
        """Swap the halves of the last dimension, negating the new front half."""
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    def _forward_attention(self, attn, hidden_states, attention_mask, cos, sin):
        """Run non-causal multi-head attention over a flat token sequence."""
        num_tokens = hidden_states.shape[0]

        fused = attn.qkv(hidden_states)
        fused = fused.reshape(num_tokens, 3, attn.num_heads, attn.head_dim)
        q, k, v = fused.permute(1, 0, 2, 3).unbind(0)

        if cos is not None and sin is not None:
            cos_b = cos.unsqueeze(-2)
            sin_b = sin.unsqueeze(-2)
            q = (q * cos_b) + (self._rotate_half(q) * sin_b)
            k = (k * cos_b) + (self._rotate_half(k) * sin_b)

        # (seq, heads, dim) -> (1, heads, seq, dim)
        q, k, v = (t.transpose(0, 1).unsqueeze(0) for t in (q, k, v))

        scores = (q @ k.transpose(-1, -2)) * attn.scaling
        if attention_mask is not None:
            scores = scores + attention_mask
        # Softmax is evaluated in float32 and cast back to the working dtype.
        probs = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)

        context = (probs @ v).squeeze(0).transpose(0, 1)
        return attn.proj(context.reshape(num_tokens, -1))

    def forward(self, hidden_states, attention_mask, cos, sin):
        """Encode tokens through all blocks and the merger.

        Args:
            hidden_states: ``(seq_len, hidden_size)`` token features.
            attention_mask: Additive mask added to the attention scores, or
                ``None``.
            cos: Rotary cosine table, or ``None`` to skip rotary embedding.
            sin: Rotary sine table, or ``None`` to skip rotary embedding.

        Returns:
            Tensor of shape ``(seq_len / spatial_merge_size**2,
            out_hidden_size)``.
        """
        for layer in self.blocks:
            attended = self._forward_attention(
                layer.attn, layer.norm1(hidden_states), attention_mask, cos, sin
            )
            hidden_states = hidden_states + attended

            mlp = layer.mlp
            inner = mlp.act_fn(mlp.linear_fc1(layer.norm2(hidden_states)))
            hidden_states = hidden_states + mlp.linear_fc2(inner)

        normed = self.merger_norm(hidden_states)
        merged_width = self.config.hidden_size * (
            self.config.spatial_merge_size**2
        )
        merged = normed.view(-1, merged_width)
        return self.merger_fc2(self.merger_act(self.merger_fc1(merged)))
def load_compiled(self, compiled_model_path):
    """Load pre-compiled standalone vision encoder(s) with ``torch.jit.load``.

    Three layouts are accepted:

    1. A single file: loaded as the only compiled model.
    2. A directory with files named ``vision_encoder_{bucket_size}.pt``
       (e.g. ``vision_encoder_256.pt``): every file whose bucket size parses
       as an integer is loaded, and ``vision_seq_len_buckets`` is replaced
       by the sorted list of loaded bucket sizes.
    3. A directory with a single ``vision_encoder.pt``: used when no bucket
       file could be loaded.

    Args:
        compiled_model_path: Path to a ``.pt`` file or to a directory
            containing the compiled file(s).

    Raises:
        FileNotFoundError: If the path does not exist, or the directory
            holds neither a usable bucket file nor ``vision_encoder.pt``.
    """
    import glob as glob_module

    logger.info(f"Loading pre-compiled vision encoder from {compiled_model_path}")

    if os.path.isfile(compiled_model_path):
        # Single file mode (legacy)
        self._compiled_model = torch.jit.load(compiled_model_path)
        self._compiled_buckets = None
        logger.info("Vision encoder loaded successfully (single bucket)")
        return

    if not os.path.isdir(compiled_model_path):
        raise FileNotFoundError(
            f"Vision encoder path not found: {compiled_model_path}"
        )

    # Directory mode: look for bucket-specific files first.
    prefix, suffix = "vision_encoder_", ".pt"
    bucket_files = sorted(
        glob_module.glob(os.path.join(compiled_model_path, "vision_encoder_*.pt"))
    )
    buckets = {}
    for bf in bucket_files:
        # Extract bucket size from filename: vision_encoder_256.pt -> 256
        basename = os.path.basename(bf)
        try:
            bucket_size = int(basename[len(prefix) : -len(suffix)])
        except ValueError:
            logger.warning(f" Skipping unrecognized file: {bf}")
            continue
        # Loading sits outside the ``try`` so a genuine load failure is not
        # mistaken for an unrecognized filename.
        buckets[bucket_size] = torch.jit.load(bf)
        logger.info(f" Loaded vision bucket {bucket_size} from {bf}")

    if buckets:
        self._compiled_buckets = buckets
        self._compiled_model = None
        # Update vision_seq_len_buckets to match compiled buckets
        self.vision_seq_len_buckets = sorted(self._compiled_buckets.keys())
        logger.info(
            f"Vision encoder loaded with {len(self._compiled_buckets)} buckets: "
            f"{self.vision_seq_len_buckets}"
        )
        return

    # No usable bucket file (none present, or none with a numeric size):
    # fall back to a single vision_encoder.pt instead of leaving an empty
    # bucket table behind.
    single_path = os.path.join(compiled_model_path, "vision_encoder.pt")
    if not os.path.exists(single_path):
        raise FileNotFoundError(
            f"No vision encoder files found in {compiled_model_path}"
        )
    self._compiled_model = torch.jit.load(single_path)
    self._compiled_buckets = None
    logger.info("Vision encoder loaded successfully (single file in dir)")
+ + Args: + model_path: Path to HF model directory + """ + from pathlib import Path + from safetensors import safe_open + + st_files = sorted( + p + for p in Path(model_path).glob("*.safetensors") + if p.suffix == ".safetensors" + ) + loaded = 0 + for sf_path in st_files: + with safe_open(str(sf_path), framework="pt") as f: + for key in f.keys(): + if key == "model.visual.patch_embed.proj.weight": + self.patch_embed.proj.weight.data.copy_(f.get_tensor(key)) + loaded += 1 + elif key == "model.visual.patch_embed.proj.bias": + self.patch_embed.proj.bias.data.copy_(f.get_tensor(key)) + loaded += 1 + elif key == "model.visual.pos_embed.weight": + self.pos_embed.weight.data.copy_(f.get_tensor(key)) + loaded += 1 + logger.info(f"Loaded {loaded} CPU-side vision weight tensors from HF") + + def load_cpu_model(self, model_path): + """Load a CPU-only vision encoder from HF safetensors. + + Use this when HBM is insufficient for the Neuron-compiled vision encoder + (e.g., 27B dense model fills trn2.3xlarge HBM). 
+ + Args: + model_path: Path to HF model directory with safetensors + """ + from pathlib import Path + from safetensors import safe_open + + config = self.vision_config + cpu_model = CPUVisionModel(config) + + # Build key mapping from HF safetensors to CPU model + key_map = {} + for i in range(config.depth): + hf_pre = f"model.visual.blocks.{i}" + loc_pre = f"blocks.{i}" + for suffix in [ + "attn.qkv.weight", + "attn.qkv.bias", + "attn.proj.weight", + "attn.proj.bias", + "mlp.linear_fc1.weight", + "mlp.linear_fc1.bias", + "mlp.linear_fc2.weight", + "mlp.linear_fc2.bias", + "norm1.weight", + "norm1.bias", + "norm2.weight", + "norm2.bias", + ]: + key_map[f"{hf_pre}.{suffix}"] = f"{loc_pre}.{suffix}" + + key_map["model.visual.merger.norm.weight"] = "merger_norm.weight" + key_map["model.visual.merger.norm.bias"] = "merger_norm.bias" + key_map["model.visual.merger.linear_fc1.weight"] = "merger_fc1.weight" + key_map["model.visual.merger.linear_fc1.bias"] = "merger_fc1.bias" + key_map["model.visual.merger.linear_fc2.weight"] = "merger_fc2.weight" + key_map["model.visual.merger.linear_fc2.bias"] = "merger_fc2.bias" + + st_files = sorted(Path(model_path).glob("model*.safetensors")) + loaded = 0 + state_dict = cpu_model.state_dict() + + for sf_path in st_files: + with safe_open(str(sf_path), framework="pt") as f: + for key in f.keys(): + if key in key_map: + local_key = key_map[key] + if local_key in state_dict: + state_dict[local_key].copy_(f.get_tensor(key)) + loaded += 1 + + cpu_model.load_state_dict(state_dict) + cpu_model = cpu_model.to(torch.bfloat16).eval() + self._cpu_model = cpu_model + logger.info( + f"Loaded CPU vision encoder: {loaded} weights, " + f"{sum(p.numel() for p in cpu_model.parameters()) / 1e6:.1f}M params" + ) + + def _get_vision_bucket(self, seq_len): + """Find the smallest bucket that fits the sequence length.""" + for bucket in sorted(self.vision_seq_len_buckets): + if seq_len <= bucket: + return bucket + return self.vision_seq_len_buckets[-1] + + 
def rot_pos_emb(self, grid_thw):
    """Gather rotary frequencies for the (row, col) position of each token.

    Tokens are enumerated merge-window by merge-window: for each image the
    order is block row, block column, row inside the window, column inside
    the window, and that sequence is repeated once per frame.

    Args:
        grid_thw: Tensor of ``(frames, height, width)`` rows, one per image.

    Returns:
        Tensor with one row per token, holding the frequency-table row for
        the token's row index followed by the one for its column index.
    """
    merge = self.vision_config.spatial_merge_size
    grids = grid_thw.tolist()

    # The frequency table only has to reach the longest side in the batch.
    longest_side = max(max(h, w) for _, h, w in grids)
    freq_table = self.rotary_pos_emb(longest_side)
    device = freq_table.device

    token_count = sum(t * h * w for t, h, w in grids)
    pos_ids = torch.empty((token_count, 2), dtype=torch.long, device=device)

    within = torch.arange(merge, device=device)
    cursor = 0
    for frames, height, width in grids:
        n_block_rows = height // merge
        n_block_cols = width // merge
        window = (n_block_rows, n_block_cols, merge, merge)

        # Absolute row/col = window origin + offset inside the window.
        rows = torch.arange(n_block_rows, device=device).view(
            n_block_rows, 1, 1, 1
        ) * merge + within.view(1, 1, merge, 1)
        cols = torch.arange(n_block_cols, device=device).view(
            1, n_block_cols, 1, 1
        ) * merge + within.view(1, 1, 1, merge)

        coords = torch.stack(
            (rows.expand(window).reshape(-1), cols.expand(window).reshape(-1)),
            dim=-1,
        )
        if frames > 1:
            # Every frame of a video shares the same spatial coordinates.
            coords = coords.repeat(frames, 1)

        end = cursor + coords.shape[0]
        pos_ids[cursor:end] = coords
        cursor = end

    return freq_table[pos_ids].flatten(1)
t, h, w in grid_thw_list: + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) + weight_tensor = torch.tensor( + weight_list, dtype=self.pos_embed.weight.dtype, device=device + ) + pos_embeds = self.pos_embed(idx_tensor).to(device) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split( + [h * w for h, w in zip(grid_hs, grid_ws)] + ) + + merge_size = self.vision_config.spatial_merge_size + patch_pos_embeds_permute = [] + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view( + t, h // merge_size, merge_size, w // merge_size, merge_size, -1 + ) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + + return torch.cat(patch_pos_embeds_permute) + + def _build_vision_attention_mask(self, grid_thw, seq_len, dtype): + 
def forward(self, pixel_values, image_grid_thw):
    """Encode images: CPU preprocessing followed by the vision transformer.

    The transformer is picked in this order: multi-bucket compiled models,
    a single compiled model, the CPU model, and finally the NxDI traced
    model in ``self.model``. The first three receive bfloat16 inputs as four
    positional tensors; the traced model receives the tensors uncast, with
    ``(cos, sin)`` passed as one tuple.

    Args:
        pixel_values: Raw pixel values from the HF processor.
        image_grid_thw: ``(num_images, 3)`` tensor of temporal, height and
            width sizes in patches.

    Returns:
        Vision embeddings with one row per merged token; rows produced by
        bucket padding are removed.

    Raises:
        RuntimeError: In multi-bucket mode, if no compiled model exists for
            the selected bucket or the input is longer than that bucket.
    """
    # 1. Patch embedding (CPU, Conv3d)
    hidden_states = self.patch_embed(pixel_values)

    # 2. Positional embedding (CPU, bilinear interpolation)
    pos_embeds = self.fast_pos_embed_interpolate(image_grid_thw)
    hidden_states = hidden_states + pos_embeds

    # 3. Rotary position embeddings (CPU)
    rotary_pos_emb = self.rot_pos_emb(image_grid_thw)
    emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
    cos, sin = emb.cos(), emb.sin()

    # 4. Vision attention mask (block-diagonal)
    seq_len = hidden_states.shape[0]
    attention_mask = self._build_vision_attention_mask(
        image_grid_thw, seq_len, hidden_states.dtype
    )

    # 5. Bucket and pad to the bucket length (applied in every mode)
    bucket_len = self._get_vision_bucket(seq_len)
    if seq_len < bucket_len:
        pad_len = bucket_len - seq_len
        hidden_states = F.pad(hidden_states, (0, 0, 0, pad_len))
        cos = F.pad(cos, (0, 0, 0, pad_len))
        sin = F.pad(sin, (0, 0, 0, pad_len))
        # Extend mask with _MASK_NEG_INF for padded positions (NOT -inf,
        # which causes NaN on Neuron)
        mask = torch.full(
            (1, 1, bucket_len, bucket_len), _MASK_NEG_INF, dtype=hidden_states.dtype
        )
        mask[:, :, :seq_len, :seq_len] = attention_mask
        attention_mask = mask

    # 6. Pick the vision model (Neuron compiled or CPU fallback)
    if self._compiled_buckets is not None:
        # Multi-bucket mode: select the compiled model for this bucket
        if bucket_len not in self._compiled_buckets:
            raise RuntimeError(
                f"No compiled vision encoder for bucket size {bucket_len}. "
                f"Available buckets: {sorted(self._compiled_buckets.keys())}. "
                f"Input seq_len={seq_len} requires bucket {bucket_len}."
            )
        if seq_len > bucket_len:
            # The bucket lookup falls back to the largest bucket; a longer
            # input cannot be fed to a graph compiled for bucket_len.
            raise RuntimeError(
                f"Input seq_len={seq_len} exceeds the largest compiled vision "
                f"bucket ({bucket_len})."
            )
        runner = self._compiled_buckets[bucket_len]
    elif self._compiled_model is not None:
        # Single compiled model (legacy)
        runner = self._compiled_model
    elif self._cpu_model is not None:
        runner = self._cpu_model
    else:
        runner = None

    if runner is None:
        # NxDI traced model: takes (hidden_states, attention_mask, position_embeddings)
        vision_output = self.model(hidden_states, attention_mask, (cos, sin))
    else:
        bf16_inputs = tuple(
            t.to(torch.bfloat16) for t in (hidden_states, attention_mask, cos, sin)
        )
        if runner is self._cpu_model:
            # CPU-only mode: inference only, so skip autograd bookkeeping.
            with torch.no_grad():
                vision_output = runner(*bf16_inputs)
        else:
            vision_output = runner(*bf16_inputs)

    # 7. Unpad: only keep valid merged tokens
    merge_size = self.vision_config.spatial_merge_size
    total_merged_tokens = sum(
        t * (h // merge_size) * (w // merge_size)
        for t, h, w in image_grid_thw.tolist()
    )
    vision_output = vision_output[:total_merged_tokens]

    return vision_output
+ + Args: + image_path: Path to image file + processor: HF AutoProcessor + + Returns: + pixel_values, image_grid_thw + """ + from PIL import Image + + image = Image.open(image_path).convert("RGB") + inputs = processor(images=image, return_tensors="pt") + return inputs["pixel_values"], inputs["image_grid_thw"] diff --git a/contrib/models/Qwen3.6-27B/src/modeling_qwen35_vl.py b/contrib/models/Qwen3.6-27B/src/modeling_qwen35_vl.py new file mode 100644 index 00000000..e3afbb1b --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/modeling_qwen35_vl.py @@ -0,0 +1,662 @@ +""" +Qwen3.5-27B / Qwen3.6-27B Vision-Language Model Orchestrator for NeuronX Distributed Inference. + +This is the top-level VL model that wires together: +- The vision encoder (modeling_qwen35_vision.py) +- The text decoder (modeling_qwen35.py, dense model with vision injection) + +It handles: +- Multimodal RoPE (mRoPE) with interleaved layout +- Vision embedding injection via scatter_by_index_put +- Separate compilation and loading of vision and text models +- The CTE+TKG generation loop with vision inputs + +Architecture follows the NxDI NeuronBaseForImageToText pattern established +by Qwen3-VL in SDK 2.28, adapted for Qwen3.5/3.6 dense model's unique features: +- No deepstack (Qwen3.5/3.6 does not use intermediate vision feature injection) +- DeltaNet linear attention layers in the text decoder +- Dense SwiGLU MLP layers in the text decoder +- Interleaved mRoPE (THWTHW... 
def get_rope_index(
    input_ids,
    image_grid_thw=None,
    video_grid_thw=None,
    attention_mask=None,
    image_token_id=248056,
    video_token_id=248057,
    vision_start_token_id=248053,
    spatial_merge_size=2,
):
    """Compute 3D multimodal RoPE position IDs for Qwen3.5.

    For text tokens, all 3 axes carry the same sequential position.
    For vision tokens, each axis encodes the temporal/height/width grid
    coordinate, offset by the position reached by the preceding text.

    Adapted from HuggingFace Qwen3_5Model.get_rope_index().

    Args:
        input_ids: (batch_size, seq_len) integer token IDs.
        image_grid_thw: (num_images, 3) tensor of (t, h, w) patch grids, or
            None when the batch holds no images.
        video_grid_thw: (num_videos, 3) tensor of (t, h, w) patch grids, or
            None. Each video is split into per-frame entries with t == 1.
        attention_mask: (batch_size, seq_len) mask; only entries equal to 1
            receive positions. Defaults to all ones.
        image_token_id: Placeholder token ID for image patches.
        video_token_id: Placeholder token ID for video patches.
        vision_start_token_id: Token ID that precedes each vision span.
        spatial_merge_size: Side length of the patch-merge window; h and w
            are divided by it to get the grid seen by the language model.

    Returns:
        A tuple ``(position_ids, mrope_position_deltas)``:
        - position_ids: (3, batch_size, seq_len) tensor with the dtype and
          device of ``input_ids``; axis 0 is (temporal, height, width).
          Masked-out slots stay 0.
        - mrope_position_deltas: (batch_size, 1) tensor holding
          ``max_position + 1 - seq_len`` per row, for use during TKG decoding.
    """
    if video_grid_thw is not None:
        # One grid entry per frame: repeat each video t times, then force t=1.
        # repeat_interleave returns a new tensor, so the caller's is untouched.
        video_grid_thw = torch.repeat_interleave(
            video_grid_thw, video_grid_thw[:, 0], dim=0
        )
        video_grid_thw[:, 0] = 1

    image_grid_thw_list = (
        image_grid_thw.tolist() if image_grid_thw is not None else None
    )
    video_grid_thw_list = (
        video_grid_thw.tolist() if video_grid_thw is not None else None
    )

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    attention_mask = attention_mask.to(input_ids.device)

    position_ids = torch.zeros(
        3,
        input_ids.shape[0],
        input_ids.shape[1],
        dtype=input_ids.dtype,
        device=input_ids.device,
    )

    mrope_position_deltas = []
    # Grid lists are shared by the whole batch, so these cursors persist
    # across rows.
    image_index, video_index = 0, 0

    for i, row in enumerate(input_ids):
        valid = attention_mask[i] == 1
        ids = row[valid]
        row_len = len(row)

        if ids.numel() == 0:
            # Fully masked row: nothing to position. Without this guard
            # torch.cat() on an empty list (and .max() on an empty tensor)
            # would raise. Positions stay 0 and the delta follows the same
            # "max + 1 - seq_len" formula with max + 1 == 0.
            mrope_position_deltas.append(-row_len)
            continue

        image_nums, video_nums = 0, 0
        vision_start_indices = torch.argwhere(ids == vision_start_token_id).squeeze(1)
        # Ignore a vision-start marker that is the very last token (e.g. a
        # truncated prompt): there is no following token to classify.
        vision_start_indices = vision_start_indices[
            vision_start_indices + 1 < ids.shape[0]
        ]
        if len(vision_start_indices) > 0:
            vision_tokens = ids[vision_start_indices + 1]
            image_nums = int((vision_tokens == image_token_id).sum())
            video_nums = int((vision_tokens == video_token_id).sum())

        input_tokens = ids.tolist()
        llm_pos_ids_list = []
        st = 0
        remain_images, remain_videos = image_nums, video_nums

        for _ in range(image_nums + video_nums):
            # Locate the next image / video placeholder at or after `st`;
            # a sentinel past the end means "none left of that kind".
            if image_token_id in input_tokens and remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
                ed_image = len(input_tokens) + 1
            if video_token_id in input_tokens and remain_videos > 0:
                ed_video = input_tokens.index(video_token_id, st)
            else:
                ed_video = len(input_tokens) + 1

            if ed_image < ed_video:
                t, h, w = image_grid_thw_list[image_index]
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = video_grid_thw_list[video_index]
                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t = t
            llm_grid_h = h // spatial_merge_size
            llm_grid_w = w // spatial_merge_size

            # Text between the previous vision span and this one: sequential
            # positions, identical on all three axes.
            text_len = ed - st
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

            # Vision span: enumerate the (t, h, w) grid in row-major order.
            t_index = (
                torch.arange(llm_grid_t)
                .view(-1, 1)
                .expand(-1, llm_grid_h * llm_grid_w)
                .flatten()
            )
            h_index = (
                torch.arange(llm_grid_h)
                .view(1, -1, 1)
                .expand(llm_grid_t, -1, llm_grid_w)
                .flatten()
            )
            w_index = (
                torch.arange(llm_grid_w)
                .view(1, 1, -1)
                .expand(llm_grid_t, llm_grid_h, -1)
                .flatten()
            )
            llm_pos_ids_list.append(
                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
            )
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            # Trailing text after the last vision span (or the whole row
            # when it contains no vision tokens).
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
            )

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        position_ids[..., i, valid] = llm_positions.to(position_ids.device)
        mrope_position_deltas.append(int(llm_positions.max()) + 1 - row_len)

    mrope_position_deltas = torch.tensor(
        mrope_position_deltas, device=input_ids.device
    ).unsqueeze(1)
    return position_ids, mrope_position_deltas
class NeuronQwen35VLForCausalLM:
    """Top-level VL model for Qwen3.5/3.6-27B on Neuron.

    This class manages:
    - Separate compilation/loading of vision encoder and text decoder
    - CPU-side mRoPE computation
    - Vision embedding injection into text decoder
    - The CTE+TKG generation loop

    Note: This is NOT an NeuronBaseForImageToText subclass because the
    text decoder (NeuronQwen35ForCausalLM) has extensive custom overrides
    (DeltaNet state management, custom forward, custom ModelWrapper) that
    don't fit the base class pattern. Instead, this class composes the two
    models and handles the VL orchestration directly.
    """

    def __init__(self, model_path, text_config, vision_config=None, processor=None):
        """
        Args:
            model_path: Path to HF model directory
            text_config: Qwen35InferenceConfig for text decoder
            vision_config: Qwen35VLInferenceConfig (or None for text-only)
            processor: HF AutoProcessor for image preprocessing
        """
        self.model_path = model_path
        self.text_config = text_config
        self.vl_config = vision_config
        self.processor = processor

        # Text decoder (existing implementation)
        self.text_model = NeuronQwen35ForCausalLM(
            model_path=model_path, config=text_config
        )

        # Vision encoder (lazy init -- only built if vl_config provided).
        # _vl_config must exist even in text-only mode because generate()
        # reads it whenever pixel inputs are supplied.
        self.vision_model_wrapper = None
        self._vl_config = None
        if vision_config is not None:
            self._init_vision_model(vision_config)

        # mRoPE state
        self.rope_deltas = None

    def _init_vision_model(self, vl_config):
        """Initialize the vision encoder wrapper.

        Args:
            vl_config: Qwen35VLInferenceConfig; its ``vision_config`` dict is
                exposed to the wrapper as attribute-style config.
        """
        from types import SimpleNamespace

        vision_cfg = SimpleNamespace(**vl_config.vision_config)
        self.vision_model_wrapper = NeuronQwen35VisionModelWrapper(
            config=vision_cfg,
            model_cls=None,  # Standalone mode (no NxDI parallel layers)
            vision_seq_len_buckets=vl_config.vision_seq_len_buckets,
        )
        self._vl_config = vl_config

    def compile(self, compiled_model_path):
        """Compile both text and vision models.

        For the vision encoder, use compile_vision_encoder.py separately
        (standalone torch_neuronx.trace compilation). Then use load() to
        load the pre-compiled vision encoder.

        Args:
            compiled_model_path: Output directory; the text decoder is
                compiled into its ``text_model`` subdirectory.
        """
        # Compile text decoder
        text_path = os.path.join(compiled_model_path, "text_model")
        os.makedirs(text_path, exist_ok=True)
        self.text_model.compile(text_path)

        # Vision encoder is compiled separately via compile_vision_encoder.py
        if self.vision_model_wrapper is not None:
            logger.info(
                "Vision encoder must be compiled separately using "
                "compile_vision_encoder.py. Use load() to load the "
                "pre-compiled vision encoder."
            )

    def load(self, compiled_model_path, vision_compiled_path=None):
        """Load both compiled models.

        Args:
            compiled_model_path: Path to compiled text model (or parent dir)
            vision_compiled_path: Path to compiled vision encoder .pt file.
                If None, looks for 'vision_encoder.pt' in compiled_model_path.
        """
        text_path = os.path.join(compiled_model_path, "text_model")
        if os.path.exists(text_path):
            self.text_model.load(text_path)
        else:
            # Backward compatibility: text model compiled at root
            self.text_model.load(compiled_model_path)

        # Load vision encoder
        if self.vision_model_wrapper is not None:
            if vision_compiled_path is None:
                vision_compiled_path = os.path.join(
                    compiled_model_path, "vision_encoder.pt"
                )
            if os.path.exists(vision_compiled_path):
                self.vision_model_wrapper.load_compiled(vision_compiled_path)
                # Also load CPU-side weights (patch_embed, pos_embed)
                self.vision_model_wrapper.load_vision_weights_from_hf(self.model_path)
                logger.info("Vision encoder loaded from pre-compiled model")
            else:
                logger.warning(
                    f"No compiled vision encoder found at {vision_compiled_path}. "
                    "Vision encoding will not be available."
                )

    # Qwen3.5 stop token IDs (loaded from config/tokenizer)
    _DEFAULT_EOS_TOKEN_IDS = {
        248044,  # <|endoftext|> -- text config eos_token_id
        248046,  # <|im_end|> -- tokenizer eos_token / end of assistant turn
    }

    @staticmethod
    def _all_finished(next_token, eos_token_ids):
        """Return True when every sequence in the batch just emitted an EOS token."""
        return all(tok in eos_token_ids for tok in next_token.reshape(-1).tolist())

    def _build_vision_args(self, input_ids, pixel_values, image_grid_thw, position_ids):
        """Encode vision inputs and pack them as ``llava_args`` for the text model.

        Returns a list ``[vis_emb_padded, positions_padded]`` (each padded to
        the input sequence length), followed by the int32 3D mRoPE position
        IDs when ``position_ids`` is 3-dimensional.
        """
        # The vision encoder processes both image and video frames identically
        # (they share the same ViT architecture). The HF processor outputs a
        # single pixel_values tensor for images, and video frames are treated
        # as multiple images with temporal grid > 1.
        # NOTE(review): only image_grid_thw is forwarded to the encoder here;
        # confirm how video-only inputs are expected to reach it.
        vision_embeddings = self.vision_model_wrapper(pixel_values, image_grid_thw)
        # vision_embeddings: (total_merged_tokens, out_hidden_size)

        # Boolean mask of ALL vision token positions
        # (both image_token_id and video_token_id placeholders)
        vision_bool_mask = (input_ids == self._vl_config.image_token_id) | (
            input_ids == self._vl_config.video_token_id
        )  # (BS, seq_len)

        # For batch_size=1 (primary path): extract positions from batch element 0.
        # For batch_size>1: each element may have different image token positions;
        # we'd need per-element scatter. Currently only batch_size=1 is supported
        # for VL (the compiled model uses batch_size=1 for CTE).
        if input_ids.shape[0] > 1:
            logger.warning(
                "VL generation with batch_size > 1 is not fully supported. "
                "Using batch element 0 for vision scatter positions."
            )

        positions = (
            vision_bool_mask[0].nonzero(as_tuple=False).squeeze(-1)
        )  # (n_vision_tokens,)

        # Reshape vision_embeddings to (1, n_vision_tokens, hidden_size)
        n_vis = positions.shape[0]
        hidden_size = vision_embeddings.shape[-1]
        vis_emb = vision_embeddings[:n_vis].unsqueeze(0)  # (1, <=n_vis, hidden)

        # Pad to match input sequence length for compiled graph compatibility
        seq_len = input_ids.shape[1]
        pad_limit = seq_len  # Must match the bucket size

        # Pad vision_embeddings to (1, pad_limit, hidden_size). The pad width
        # is derived from the rows actually present so the result is always
        # pad_limit long, even if the encoder returned fewer rows than there
        # are placeholder tokens.
        n_emb = vis_emb.shape[1]
        if n_emb < n_vis:
            logger.warning(
                "Vision encoder returned %d embeddings for %d placeholder "
                "tokens; missing embeddings are zero-filled.",
                n_emb,
                n_vis,
            )
        if n_emb < pad_limit:
            pad_emb = torch.zeros(
                (1, pad_limit - n_emb, hidden_size),
                dtype=vis_emb.dtype,
                device=vis_emb.device,
            )
            vis_emb_padded = torch.cat([vis_emb, pad_emb], dim=1)
        else:
            vis_emb_padded = vis_emb[:, :pad_limit]

        # Pad positions to (1, pad_limit, 1) with a SAFE fill value.
        # CRITICAL: fill_value must be a valid index (within [0, pad_limit-1]).
        # NOTE: Do NOT use large sentinel values (e.g., 2**30) as they cause
        # DGE out-of-bounds crashes in the Neuron runtime.
        # NOTE(review): pad_limit - 1 is the last slot of input_ids as passed
        # in; the zero-embedding scatter is only harmless if that slot is
        # padding or the scatter is otherwise masked -- confirm against the
        # text model's handling of llava_args.
        positions_padded = torch.full(
            (1, pad_limit, 1),
            fill_value=pad_limit - 1,
            dtype=torch.int32,
        )
        positions_padded[0, :n_vis, 0] = positions[:pad_limit].to(torch.int32)

        llava_args = [vis_emb_padded, positions_padded]

        # Append 3D mRoPE position IDs for the text model.
        # position_ids shape: (3, batch_size, seq_len) from get_rope_index.
        # _get_model_outputs receives this at slot 21 and pre-computes
        # mRoPE cos/sin in get_model_output() for all decoder layers.
        if position_ids.ndim == 3:
            mrope_pos = position_ids[:, :, :seq_len].to(torch.int32).contiguous()
            llava_args.append(mrope_pos)

        return llava_args

    def generate(
        self,
        input_ids,
        attention_mask=None,
        pixel_values=None,
        image_grid_thw=None,
        video_grid_thw=None,
        max_new_tokens=32,
        temperature=0.0,
        top_p=1.0,
        top_k=0,
        eos_token_ids=None,
        **kwargs,
    ):
        """Generate text from text and/or vision inputs.

        Args:
            input_ids: (batch_size, seq_len) token IDs
            attention_mask: (batch_size, seq_len) attention mask
            pixel_values: Vision pixel values from HF processor (or None for text-only)
            image_grid_thw: (num_images, 3) grid dimensions
            video_grid_thw: (num_videos, 3) grid dimensions
            max_new_tokens: Maximum new tokens to generate; values <= 0
                return a copy of ``input_ids`` without running the model
            temperature: Sampling temperature (0.0 = greedy/argmax)
            top_p: Nucleus sampling threshold (1.0 = disabled)
            top_k: Top-k sampling (0 = disabled)
            eos_token_ids: Set of token IDs to stop generation on
                (default: {248044, 248046})
            **kwargs: Accepted for call-site compatibility; unused.

        Returns:
            generated_ids: (batch_size, seq_len + n_generated) token IDs, where
            n_generated <= max_new_tokens (generation stops early once every
            sequence in the batch emits an EOS token in the same step).
        """
        if eos_token_ids is None:
            eos_token_ids = self._DEFAULT_EOS_TOKEN_IDS

        if max_new_tokens <= 0:
            # Nothing requested: do not emit the prefill token either.
            return input_ids.clone()

        # Reset text model state for a fresh generation.
        # This ensures CTE runs (not TKG) even if a prior generate() was called.
        # DeltaNet recurrent states don't need explicit zeroing because the CTE
        # NKI kernel always starts from zero state.
        self.text_model.reset()

        vl_config = getattr(self, "_vl_config", None)
        has_vision = pixel_values is not None and pixel_values.numel() > 0

        # Step 1: Compute 3D mRoPE position IDs
        if has_vision and vl_config is not None:
            position_ids, self.rope_deltas = get_rope_index(
                input_ids,
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                attention_mask=attention_mask,
                image_token_id=vl_config.image_token_id,
                video_token_id=vl_config.video_token_id,
                vision_start_token_id=vl_config.vision_start_token_id,
                spatial_merge_size=vl_config.spatial_merge_size,
            )
        else:
            # Text-only: use standard sequential position IDs
            position_ids = torch.arange(input_ids.shape[1]).unsqueeze(0)
            self.rope_deltas = None

        # Step 2: Run vision encoder and prepare injection args
        llava_args = []
        if has_vision and self.vision_model_wrapper is not None:
            llava_args = self._build_vision_args(
                input_ids, pixel_values, image_grid_thw, position_ids
            )

        # Step 3: Context encoding (prefill)
        generated_ids = input_ids.clone()

        # CRITICAL: Always pass an explicit attention_mask for CTE.
        # The base class _infer_attention_mask() assumes sequential position_ids
        # (position_ids[i] >= i). When position_ids come from mRoPE temporal
        # axis (non-sequential, e.g., all vision tokens share position 4),
        # the inferred mask incorrectly masks out most of the sequence.
        # Fix: provide a real all-ones mask for the actual token positions.
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # For slot 2 (position_ids): use SEQUENTIAL positions regardless of mRoPE.
        # Slot 2 is only used for: (1) logit position selection via torch.max(),
        # (2) attention mask inference (which we bypass with explicit mask above).
        # The actual RoPE computation uses slot 21 (rotary_position_ids) from
        # _get_model_outputs, NOT slot 2. Using sequential slot 2 ensures
        # correct logit selection and avoids any position_ids-related issues.
        seq_len = input_ids.shape[1]
        cte_position_ids = torch.arange(seq_len, dtype=torch.long).unsqueeze(0)

        with torch.no_grad():
            output = self.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=cte_position_ids,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=False,
                llava_args=llava_args,
            )

        logits = output[0] if isinstance(output, tuple) else output.logits
        next_token = self._sample_token(logits[:, -1, :], temperature, top_p, top_k)
        generated_ids = torch.cat([generated_ids, next_token.unsqueeze(-1)], dim=-1)

        # Check EOS after first token (batch-safe: .item() would raise for
        # batch_size > 1)
        if self._all_finished(next_token, eos_token_ids):
            return generated_ids

        # Step 4: Token generation (TKG) loop
        for _ in range(max_new_tokens - 1):
            pos_ids = torch.tensor([[generated_ids.shape[1] - 1]])
            if self.rope_deltas is not None:
                # Shift by the per-row mRoPE delta computed during prefill.
                pos_ids = pos_ids + self.rope_deltas

            last_token = generated_ids[:, -1:]
            with torch.no_grad():
                output = self.text_model(
                    input_ids=last_token,
                    position_ids=pos_ids,
                    output_attentions=False,
                    output_hidden_states=False,
                    return_dict=False,
                )
            logits = output[0] if isinstance(output, tuple) else output.logits
            next_token = self._sample_token(logits[:, -1, :], temperature, top_p, top_k)
            generated_ids = torch.cat([generated_ids, next_token.unsqueeze(-1)], dim=-1)

            # Stop on EOS
            if self._all_finished(next_token, eos_token_ids):
                break

        return generated_ids

    @staticmethod
    def _sample_token(logits, temperature=0.0, top_p=1.0, top_k=0):
        """Sample a token from logits with optional temperature/top-p/top-k.

        The caller's ``logits`` tensor is never modified: temperature scaling
        creates a new tensor before any in-place masking.

        Args:
            logits: (batch_size, vocab_size) unnormalized logits
            temperature: Sampling temperature. <= 0.0 = greedy (argmax).
            top_p: Nucleus sampling threshold. 1.0 = disabled.
            top_k: Top-k filtering. 0 = disabled.

        Returns:
            token_id: (batch_size,) sampled token IDs
        """
        if temperature <= 0.0:
            return torch.argmax(logits, dim=-1)

        # Apply temperature
        logits = logits / temperature

        # Top-k filtering
        if top_k > 0:
            top_k = min(top_k, logits.shape[-1])
            # Threshold is the k-th largest logit of each row.
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = float("-inf")

        # Top-p (nucleus) filtering
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(
                torch.softmax(sorted_logits, dim=-1), dim=-1
            )
            # Remove tokens with cumulative probability above the threshold
            sorted_indices_to_remove = cumulative_probs > top_p
            # Shift right so the first token above threshold is kept
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                ..., :-1
            ].clone()
            sorted_indices_to_remove[..., 0] = False
            # Scatter back to original indexing
            indices_to_remove = sorted_indices_to_remove.scatter(
                -1, sorted_indices, sorted_indices_to_remove
            )
            logits[indices_to_remove] = float("-inf")

        # Sample from the filtered distribution
        probs = torch.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)

    @staticmethod
    def prepare_input_args(text_prompt, image_path, processor, role="user"):
        """Prepare inputs for vision+text generation.

        Args:
            text_prompt: Text prompt string
            image_path: Path to image file (or None for text-only)
            processor: HF AutoProcessor
            role: Message role (default "user")

        Returns:
            input_ids, attention_mask, vision_inputs dict
        """
        content = []
        if image_path is not None:
            import base64
            import mimetypes
            from pathlib import Path

            # Label the data URL with the file's real image type; fall back
            # to JPEG when the extension is unknown or not an image type.
            mime_type = mimetypes.guess_type(str(image_path))[0]
            if mime_type is None or not mime_type.startswith("image/"):
                mime_type = "image/jpeg"

            image_data = Path(image_path).read_bytes()
            b64 = base64.b64encode(image_data).decode("utf-8")
            content.append(
                {
                    "type": "image",
                    "url": f"data:{mime_type};base64,{b64}",
                }
            )
        content.append({"type": "text", "text": text_prompt})

        messages = [{"role": role, "content": content}]
        inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        )

        input_ids = inputs["input_ids"]
        attention_mask = inputs.get("attention_mask", torch.ones_like(input_ids))

        # Only forward the vision keys the processor actually produced.
        vision_inputs = {}
        if "pixel_values" in inputs:
            vision_inputs["pixel_values"] = inputs["pixel_values"]
        if "image_grid_thw" in inputs:
            vision_inputs["image_grid_thw"] = inputs["image_grid_thw"]
        if "video_grid_thw" in inputs:
            vision_inputs["video_grid_thw"] = inputs["video_grid_thw"]

        return input_ids, attention_mask, vision_inputs
@nki.jit
def deltanet_recurrent_fwd(
    query: nl.ndarray,  # (S, 128) float32
    key: nl.ndarray,  # (S, 128) float32
    value: nl.ndarray,  # (S, 128) float32
    g_in: nl.ndarray,  # (S, 128) float32, log-decay broadcast to 128
    beta_in: nl.ndarray,  # (S, 128) float32, write gate broadcast to 128
) -> nl.ndarray:
    """NKI kernel for DeltaNet recurrent forward -- single (batch, head).

    Iterates over sequence tokens with sequential_range.
    State matrix (128 x 128) lives in SBUF and is zero-initialized on every
    call, so no state is carried in from (or out to) the caller.

    Per token t the loop performs:
        1. state = state * exp(g_t)
        2. kv_mem = state^T @ k_t
        3. delta = (v_t - kv_mem) * beta_t
        4. state += outer(k_t, delta)
        5. o_t = state^T @ q_t

    Args:
        query: (S, 128) float32
        key: (S, 128) float32
        value: (S, 128) float32
        g_in: (S, 128) float32
        beta_in: (S, 128) float32

    Returns:
        output: (S, 128) float32
    """
    seq_len, dim = query.shape

    # Output tensor in HBM
    output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm)

    # Stride: for 2D (S, D), dim0 stride = D=128, dim1 stride = 1
    # (assumes the inputs are contiguous, as stated in the module docstring)
    seq_stride = dim

    # Initialize recurrent state in SBUF: (128, 128)
    state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.memset(dst=state, value=0.0)

    # Sequential loop over tokens (state-dependent)
    for t in nl.sequential_range(seq_len):
        # Flat element offset of row t in each (S, 128) input/output tensor.
        tok_offset = t * seq_stride

        # ---- Load inputs for token t ----
        # Each load places the 128 values of row t into a (128, 1) tile,
        # i.e. one value per partition.
        q_t = nl.ndarray((P_MAX, 1), dtype=query.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=q_t,
            src=query.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        k_t = nl.ndarray((P_MAX, 1), dtype=key.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=k_t,
            src=key.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        v_t = nl.ndarray((P_MAX, 1), dtype=value.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=v_t,
            src=value.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        g_t = nl.ndarray((P_MAX, 1), dtype=g_in.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=g_t,
            src=g_in.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        beta_t = nl.ndarray((P_MAX, 1), dtype=beta_in.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=beta_t,
            src=beta_in.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        # ---- Step 1: Decay state -- state = state * exp(g_t) ----
        exp_g = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.activation(dst=exp_g, op=nl.exp, data=g_t, bias=None, scale=1.0)

        # Result goes to a scratch tile and is copied back into `state`.
        state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_scalar(
            dst=state_decayed,
            data=state,
            op0=nl.multiply,
            operand0=exp_g,
            engine=nisa.vector_engine,
        )
        nisa.tensor_copy(dst=state, src=state_decayed)

        # ---- Step 2: Read memory -- kv_mem = state^T @ k_t ----
        # Matmul result lands in PSUM and is copied to SBUF before use.
        kv_mem_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum)
        nisa.nc_matmul(dst=kv_mem_psum, stationary=state, moving=k_t)
        kv_mem = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_copy(dst=kv_mem, src=kv_mem_psum)

        # ---- Step 3: delta = (v_t - kv_mem) * beta_t ----
        v_sub = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_tensor(dst=v_sub, data1=v_t, data2=kv_mem, op=nl.subtract)

        delta = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_scalar(
            dst=delta,
            data=v_sub,
            op0=nl.multiply,
            operand0=beta_t,
            engine=nisa.vector_engine,
        )

        # ---- Step 4: state += outer(k_t, delta) ----
        # Broadcast multiply: outer[i,j] = k_t[i] * delta[j]
        # 1) Transpose delta (128,1) -> (1,128) in PSUM
        # 2) Copy PSUM (1,128) -> SBUF (128,128) -- partition broadcast
        # 3) Multiply by k_t (128,1) which broadcasts across free dim
        # This avoids the nc_matmul P=1 outer product (wastes 127/128 TE lanes).

        # Transpose delta to get values along free dimension
        delta_row_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum)
        nisa.nc_transpose(dst=delta_row_psum, data=delta)

        # Copy PSUM (1, 128) -> SBUF (1, 128) first (NKI 0.3.0 requires matching P dims)
        delta_row_sb = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_copy(dst=delta_row_sb, src=delta_row_psum)

        # Broadcast (1, 128) SBUF -> (128, 128) SBUF via nc_stream_shuffle
        # Each partition row gets the same delta values
        # The destination is filled in P_MAX // 32 slabs of 32 partitions,
        # matching the 32-entry _BROADCAST_MASK.
        delta_broadcast = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
        for i_shuf in nl.static_range(P_MAX // 32):
            nisa.nc_stream_shuffle(
                src=delta_row_sb[0:1, 0:P_MAX],
                dst=delta_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX],
                shuffle_mask=_BROADCAST_MASK,
            )

        # Element-wise multiply: outer[i,j] = delta_broadcast[i,j] * k_t[i,0]
        # tensor_scalar broadcasts (P,1) k_t across all F columns
        outer_prod = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_scalar(
            dst=outer_prod,
            data=delta_broadcast,
            op0=nl.multiply,
            operand0=k_t,
            engine=nisa.vector_engine,
        )

        # Accumulate into state
        state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_tensor(dst=state_new, data1=state, data2=outer_prod, op=nl.add)
        nisa.tensor_copy(dst=state, src=state_new)

        # ---- Step 5: o_t = state^T @ q_t ----
        # Uses the state AFTER this token's update.
        o_t_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum)
        nisa.nc_matmul(dst=o_t_psum, stationary=state, moving=q_t)
        o_t = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_copy(dst=o_t, src=o_t_psum)

        # ---- Store output for token t ----
        nisa.dma_copy(
            dst=output.ap(pattern=[[1, dim]], offset=tok_offset),
            src=o_t,
        )

    return output


@nki.jit
def deltanet_recurrent_fwd_state(
    query: nl.ndarray,  # (S, 128) float32
    key: nl.ndarray,  # (S, 128) float32
    value: nl.ndarray,  # (S, 128) float32
    g_in: nl.ndarray,  # (S, 128) float32, log-decay broadcast to 128
    beta_in: nl.ndarray,  # (S, 128) float32, write gate broadcast to 128
):
    """NKI kernel for DeltaNet recurrent forward with final state output.

    Same recurrence as deltanet_recurrent_fwd, but ALSO writes the final
    recurrent state (128, 128) to an output HBM buffer. This enables
    CTE -> TKG state carry-over.

    The state is zero-initialized on every call; there is no state input.

    Args:
        query: (S, 128) float32
        key: (S, 128) float32
        value: (S, 128) float32
        g_in: (S, 128) float32
        beta_in: (S, 128) float32

    Returns:
        output: (S, 128) float32 -- per-token output
        final_state: (128, 128) float32 -- recurrent state after last token
    """
    seq_len, dim = query.shape

    # Output tensors in HBM
    output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm)
    final_state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm)

    # Stride: for 2D (S, D), dim0 stride = D=128, dim1 stride = 1
    # (assumes the inputs are contiguous, as stated in the module docstring)
    seq_stride = dim

    # Initialize recurrent state in SBUF: (128, 128)
    state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.memset(dst=state, value=0.0)

    # Sequential loop over tokens (state-dependent)
    for t in nl.sequential_range(seq_len):
        # Flat element offset of row t in each (S, 128) input/output tensor.
        tok_offset = t * seq_stride

        # ---- Load inputs for token t ----
        # Each load places the 128 values of row t into a (128, 1) tile,
        # i.e. one value per partition.
        q_t = nl.ndarray((P_MAX, 1), dtype=query.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=q_t,
            src=query.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        k_t = nl.ndarray((P_MAX, 1), dtype=key.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=k_t,
            src=key.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        v_t = nl.ndarray((P_MAX, 1), dtype=value.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=v_t,
            src=value.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        g_t = nl.ndarray((P_MAX, 1), dtype=g_in.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=g_t,
            src=g_in.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        beta_t = nl.ndarray((P_MAX, 1), dtype=beta_in.dtype, buffer=nl.sbuf)
        nisa.dma_copy(
            dst=beta_t,
            src=beta_in.ap(pattern=[[1, P_MAX]], offset=tok_offset),
        )

        # ---- Step 1: Decay state -- state = state * exp(g_t) ----
        exp_g = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.activation(dst=exp_g, op=nl.exp, data=g_t, bias=None, scale=1.0)

        # Result goes to a scratch tile and is copied back into `state`.
        state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_scalar(
            dst=state_decayed,
            data=state,
            op0=nl.multiply,
            operand0=exp_g,
            engine=nisa.vector_engine,
        )
        nisa.tensor_copy(dst=state, src=state_decayed)

        # ---- Step 2: Read memory -- kv_mem = state^T @ k_t ----
        # Matmul result lands in PSUM and is copied to SBUF before use.
        kv_mem_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum)
        nisa.nc_matmul(dst=kv_mem_psum, stationary=state, moving=k_t)
        kv_mem = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_copy(dst=kv_mem, src=kv_mem_psum)

        # ---- Step 3: delta = (v_t - kv_mem) * beta_t ----
        v_sub = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_tensor(dst=v_sub, data1=v_t, data2=kv_mem, op=nl.subtract)

        delta = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_scalar(
            dst=delta,
            data=v_sub,
            op0=nl.multiply,
            operand0=beta_t,
            engine=nisa.vector_engine,
        )

        # ---- Step 4: state += outer(k_t, delta) ----
        # Broadcast multiply: outer[i,j] = k_t[i] * delta[j]
        delta_row_psum = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.psum)
        nisa.nc_transpose(dst=delta_row_psum, data=delta)

        # Copy PSUM (1, 128) -> SBUF (1, 128) first (NKI 0.3.0 requires matching P dims)
        delta_row_sb = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_copy(dst=delta_row_sb, src=delta_row_psum)

        # Broadcast (1, 128) SBUF -> (128, 128) SBUF via nc_stream_shuffle
        # The destination is filled in P_MAX // 32 slabs of 32 partitions,
        # matching the 32-entry _BROADCAST_MASK.
        delta_broadcast = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
        for i_shuf in nl.static_range(P_MAX // 32):
            nisa.nc_stream_shuffle(
                src=delta_row_sb[0:1, 0:P_MAX],
                dst=delta_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX],
                shuffle_mask=_BROADCAST_MASK,
            )

        # outer[i,j] = delta_broadcast[i,j] * k_t[i,0]
        outer_prod = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_scalar(
            dst=outer_prod,
            data=delta_broadcast,
            op0=nl.multiply,
            operand0=k_t,
            engine=nisa.vector_engine,
        )

        # Accumulate into state via a scratch tile, then copy back.
        state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_tensor(dst=state_new, data1=state, data2=outer_prod, op=nl.add)
        nisa.tensor_copy(dst=state, src=state_new)

        # ---- Step 5: o_t = state^T @ q_t ----
        # Uses the state AFTER this token's update.
        o_t_psum = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.psum)
        nisa.nc_matmul(dst=o_t_psum, stationary=state, moving=q_t)
        o_t = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_copy(dst=o_t, src=o_t_psum)

        # ---- Store output for token t ----
        nisa.dma_copy(
            dst=output.ap(pattern=[[1, dim]], offset=tok_offset),
            src=o_t,
        )

    # ---- Write final state to HBM ----
    # state is (128, 128) in SBUF, copy to final_state in HBM
    # Use dma_copy with full tile: P_MAX rows, dim cols
    nisa.dma_copy(dst=final_state, src=state)

    return output, final_state
import nki
import nki.isa as nisa
import nki.language as nl

P_MAX = 128

# Broadcast partition 0 to all partitions in a 32-wide group.
_BROADCAST_MASK = [0] * 32


@nki.jit
def deltanet_chunk_step(
    query,  # (128, 128) float32 -- one chunk, l2-normed+scaled
    key,  # (128, 128) float32 -- one chunk, l2-normed
    value,  # (128, 128) float32 -- one chunk
    beta_broadcast,  # (128, 128) float32 -- write gate broadcast to 128
    g_cumsum,  # (128, 128) float32 -- cumsum of g within chunk, broadcast
    g_last,  # (128, 128) float32 -- g_cumsum[-1], constant in chunk, broadcast
    state_in,  # (128, 128) float32 -- recurrent state from previous chunk
    lower_mask,  # (128, 128) float32 -- strict lower triangular
    identity,  # (128, 128) float32 -- identity matrix
    lower_mask_diag,  # (128, 128) float32 -- lower tri with diagonal
):
    """Process one chunk of the gated DeltaNet recurrence on full tiles.

    Every operand is copied into SBUF as a whole ``(P_MAX, dim)`` tile, so the
    kernel performs no sequence-indexed DMA.  Rows (the partition axis) index
    the tokens of the chunk; columns index features.

    The computation proceeds in these steps:

    1. Scale ``key`` and ``value`` by the write gate (``k_beta``, ``v_beta``).
    2. Build pairwise decay matrices ``exp(gc[i] - gc[j])`` from the
       cumulative log-decay, masked to the strict / inclusive lower triangle.
       The difference is formed *before* exponentiation so that no standalone
       ``exp(-gc[j])`` term is ever evaluated.
    3. Form ``A = -(k_beta @ k^T) * decay`` (strictly lower triangular) and
       compute ``N = inv(I - A)`` by forward substitution on the two diagonal
       ``P_MAX // 2`` blocks, then fill the lower-left block as
       ``N22 @ A21 @ N11``.
    4. Apply ``N`` to obtain the corrected values and cumulative-decay keys,
       combine with the incoming state to get ``v_new``, and produce
       ``output = (q * exp(gc)) @ state + attn_intra @ v_new``.
    5. Update the state:
       ``state * exp(g_last) + (k * exp(g_last - gc))^T @ v_new``.

    Throughout, ``nisa.nc_matmul(stationary=X, moving=Y)`` is used as
    ``X^T @ Y``; a product against ``identity`` therefore acts as a transpose.

    Args:
        query: Chunk of queries (l2-normed and scaled by the caller).
        key: Chunk of keys (l2-normed by the caller).
        value: Chunk of values.
        beta_broadcast: Write gate, broadcast across the feature axis.
        g_cumsum: Within-chunk cumulative sum of the log-decay, broadcast
            across columns.  Only column 0 is read.
        g_last: Last entry of ``g_cumsum``, broadcast to a full tile.  Only
            column 0 is read.
        state_in: Recurrent state produced by the previous chunk.
        lower_mask: Strict lower-triangular 0/1 mask.
        identity: Identity matrix.
        lower_mask_diag: Lower-triangular 0/1 mask including the diagonal.

    Returns:
        output: ``(P_MAX, dim)`` tile in shared HBM with ``query``'s dtype --
            the chunk output.
        state_out: ``(P_MAX, dim)`` float32 tile in shared HBM -- the updated
            recurrent state, to be passed as ``state_in`` of the next chunk.
    """
    dim = query.shape[1]  # feature width; 128 per the argument contract

    # The triangular solve treats the chunk as a 2x2 grid of half-size blocks.
    half = P_MAX // 2

    # Output tensors in HBM
    output = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.shared_hbm)
    state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm)

    # Load all inputs into SBUF
    q_c = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.sbuf)
    nisa.dma_copy(dst=q_c, src=query)

    k_c = nl.ndarray((P_MAX, dim), dtype=key.dtype, buffer=nl.sbuf)
    nisa.dma_copy(dst=k_c, src=key)

    v_c = nl.ndarray((P_MAX, dim), dtype=value.dtype, buffer=nl.sbuf)
    nisa.dma_copy(dst=v_c, src=value)

    beta_c = nl.ndarray((P_MAX, dim), dtype=beta_broadcast.dtype, buffer=nl.sbuf)
    nisa.dma_copy(dst=beta_c, src=beta_broadcast)

    gc_c = nl.ndarray((P_MAX, dim), dtype=g_cumsum.dtype, buffer=nl.sbuf)
    nisa.dma_copy(dst=gc_c, src=g_cumsum)

    gl_c = nl.ndarray((P_MAX, dim), dtype=g_last.dtype, buffer=nl.sbuf)
    nisa.dma_copy(dst=gl_c, src=g_last)

    state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.dma_copy(dst=state, src=state_in)

    # Load masks
    eye = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.dma_copy(dst=eye, src=identity)

    Lmask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.dma_copy(dst=Lmask, src=lower_mask)

    Lmask_d = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.dma_copy(dst=Lmask_d, src=lower_mask_diag)

    # ============================================================
    # k_beta = K * beta, v_beta = V * beta
    # ============================================================
    k_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(dst=k_beta, data1=k_c, data2=beta_c, op=nl.multiply)

    v_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(dst=v_beta, data1=v_c, data2=beta_c, op=nl.multiply)

    # ============================================================
    # Stable decay factors from cumulative log-decay
    #
    # The caller passes g_cumsum and g_last broadcast to (128, 128). Extract
    # one column and build pairwise decays as exp(gc[i] - gc[j]) so no
    # individual exp(-gc[j]) term can overflow.
    # ============================================================
    # Column 0 of the broadcast tiles gives one scalar per token (partition).
    gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=gc_p[0:P_MAX, 0:1], src=gc_c[0:P_MAX, 0:1])

    gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=gl_p[0:P_MAX, 0:1], src=gl_c[0:P_MAX, 0:1])

    exp_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.activation(
        dst=exp_gc_p[0:P_MAX, 0:1],
        op=nl.exp,
        data=gc_p[0:P_MAX, 0:1],
        bias=None,
        scale=1.0,
    )

    exp_gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.activation(
        dst=exp_gl_p[0:P_MAX, 0:1],
        op=nl.exp,
        data=gl_p[0:P_MAX, 0:1],
        bias=None,
        scale=1.0,
    )

    # Turn the gc column into a row: place it in column 0 of a zeroed square
    # tile, transpose, and read back row 0.
    gc_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.memset(dst=gc_padded, value=0.0)
    nisa.tensor_copy(dst=gc_padded[0:P_MAX, 0:1], src=gc_p[0:P_MAX, 0:1])

    gc_row_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_transpose(dst=gc_row_psum, data=gc_padded)

    gc_row = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=gc_row[0:1, 0:P_MAX], src=gc_row_psum[0:1, 0:P_MAX])

    # Replicate the row to every partition, one 32-partition group at a time,
    # so that gc_row_broadcast[i, j] == gc[j].
    gc_row_broadcast = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    for i_shuf in nl.static_range(P_MAX // 32):
        nisa.nc_stream_shuffle(
            src=gc_row[0:1, 0:P_MAX],
            dst=gc_row_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX],
            shuffle_mask=_BROADCAST_MASK,
        )

    # Strict-lower decay: exp(gc[i] - gc[j]) for i > j, 0 elsewhere.  Both
    # terms are masked before the subtraction so masked-out entries feed
    # exp(0) rather than an arbitrary difference; the final multiply by the
    # mask then zeroes them.
    gc_col_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_scalar(
        dst=gc_col_strict,
        data=Lmask,
        op0=nl.multiply,
        operand0=gc_p,
        engine=nisa.vector_engine,
    )
    gc_row_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(
        dst=gc_row_strict, data1=gc_row_broadcast, data2=Lmask, op=nl.multiply
    )
    g_diff_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(
        dst=g_diff_strict,
        data1=gc_col_strict,
        data2=gc_row_strict,
        op=nl.subtract,
    )
    decay_strict_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.activation(
        dst=decay_strict_raw,
        op=nl.exp,
        data=g_diff_strict,
        bias=None,
        scale=1.0,
    )
    decay_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(
        dst=decay_strict, data1=decay_strict_raw, data2=Lmask, op=nl.multiply
    )

    # Same construction with the diagonal included: exp(gc[i] - gc[j]), i >= j.
    gc_col_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_scalar(
        dst=gc_col_diag,
        data=Lmask_d,
        op0=nl.multiply,
        operand0=gc_p,
        engine=nisa.vector_engine,
    )
    gc_row_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(
        dst=gc_row_diag, data1=gc_row_broadcast, data2=Lmask_d, op=nl.multiply
    )
    g_diff_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(
        dst=g_diff_diag,
        data1=gc_col_diag,
        data2=gc_row_diag,
        op=nl.subtract,
    )
    decay_diag_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.activation(
        dst=decay_diag_raw,
        op=nl.exp,
        data=g_diff_diag,
        bias=None,
        scale=1.0,
    )
    decay_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(
        dst=decay_diag, data1=decay_diag_raw, data2=Lmask_d, op=nl.multiply
    )

    # ============================================================
    # Phase 1: Build A matrix (intra-chunk correction)
    # QK = k_beta @ k^T -- contract over features
    # ============================================================
    # Matmul against the identity yields the transposes k_beta^T and k^T.
    kb_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_matmul(dst=kb_T_psum, stationary=k_beta, moving=eye)
    kb_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=kb_T, src=kb_T_psum)

    k_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_matmul(dst=k_T_psum, stationary=k_c, moving=eye)
    k_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=k_T, src=k_T_psum)

    QK_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_matmul(dst=QK_psum, stationary=kb_T, moving=k_T)
    QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=QK, src=QK_psum)

    # QK_decay[i,j] = QK[i,j] * exp(gc[i] - gc[j]) for i > j.
    QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(dst=QK_decay, data1=QK, data2=decay_strict, op=nl.multiply)

    # A = -QK_decay * lower_mask
    neg_QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_scalar(
        dst=neg_QK_decay,
        data=QK_decay,
        op0=nl.multiply,
        operand0=-1.0,
        engine=nisa.vector_engine,
    )
    A_mat = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(dst=A_mat, data1=neg_QK_decay, data2=Lmask, op=nl.multiply)

    # ============================================================
    # Stable triangular solve: N = inv(I - A_mat)
    #
    # A_mat is strictly lower triangular. Solve two half-size diagonal blocks
    # row-by-row, then merge the lower-left block. This is equivalent to the
    # nilpotent Neumann series but avoids repeated squaring of A.
    # ============================================================
    # P_acc accumulates N one row at a time; it starts at zero.
    P_acc = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.memset(dst=P_acc, value=0.0)

    # A^T is the stationary operand so that the matmul below yields A @ P_acc.
    A_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_transpose(dst=A_T_psum, data=A_mat)
    A_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=A_T, src=A_T_psum)

    # Column selectors for the left / right half, replicated to all rows.
    col_mask_left_row = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.memset(dst=col_mask_left_row, value=0.0)
    nisa.memset(dst=col_mask_left_row[0:1, 0:half], value=1.0)
    col_mask_left = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    for i_shuf in nl.static_range(P_MAX // 32):
        nisa.nc_stream_shuffle(
            src=col_mask_left_row[0:1, 0:P_MAX],
            dst=col_mask_left[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX],
            shuffle_mask=_BROADCAST_MASK,
        )

    col_mask_right_row = nl.ndarray((1, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.memset(dst=col_mask_right_row, value=0.0)
    nisa.memset(dst=col_mask_right_row[0:1, half:P_MAX], value=1.0)
    col_mask_right = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    for i_shuf in nl.static_range(P_MAX // 32):
        nisa.nc_stream_shuffle(
            src=col_mask_right_row[0:1, 0:P_MAX],
            dst=col_mask_right[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX],
            shuffle_mask=_BROADCAST_MASK,
        )

    # Row selector for the bottom half: column `half` of the inclusive
    # lower-triangular mask is 1 exactly for rows >= half.
    block_row_mask_bottom = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(
        dst=block_row_mask_bottom[0:P_MAX, 0:1],
        src=Lmask_d[0:P_MAX, half : half + 1],
    )

    # Upper-left block.  Iteration i evaluates (A @ P_acc + I), keeps only
    # row i restricted to the left-half columns, and adds it into P_acc.
    # Because A is strictly lower triangular, row i depends only on rows
    # already resolved by earlier iterations.
    for solve_i in nl.static_range(half):
        row_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
        nisa.nc_matmul(dst=row_psum, stationary=A_T, moving=P_acc)
        row_prod = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_copy(dst=row_prod, src=row_psum)

        row_with_eye = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_tensor(dst=row_with_eye, data1=row_prod, data2=eye, op=nl.add)

        row_col_masked = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_tensor(
            dst=row_col_masked,
            data1=row_with_eye,
            data2=col_mask_left,
            op=nl.multiply,
        )

        # One-hot per-partition scalar selecting row solve_i.
        row_mask = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_copy(
            dst=row_mask[0:P_MAX, 0:1],
            src=eye[0:P_MAX, solve_i : solve_i + 1],
        )
        row_update = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_scalar(
            dst=row_update,
            data=row_col_masked,
            op0=nl.multiply,
            operand0=row_mask,
            engine=nisa.vector_engine,
        )

        P_next = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_tensor(dst=P_next, data1=P_acc, data2=row_update, op=nl.add)
        nisa.tensor_copy(dst=P_acc, src=P_next)

    # Lower-right block.  Same substitution for rows half..P_MAX-1, keeping
    # only the right-half columns.  The top rows of P_acc are zero in those
    # columns, so only the lower-right block of A contributes.
    for solve_i in nl.static_range(half):
        row_idx = half + solve_i

        row_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
        nisa.nc_matmul(dst=row_psum, stationary=A_T, moving=P_acc)
        row_prod = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_copy(dst=row_prod, src=row_psum)

        row_with_eye = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_tensor(dst=row_with_eye, data1=row_prod, data2=eye, op=nl.add)

        row_col_masked = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_tensor(
            dst=row_col_masked,
            data1=row_with_eye,
            data2=col_mask_right,
            op=nl.multiply,
        )

        row_mask = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_copy(
            dst=row_mask[0:P_MAX, 0:1],
            src=eye[0:P_MAX, row_idx : row_idx + 1],
        )
        row_update = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_scalar(
            dst=row_update,
            data=row_col_masked,
            op0=nl.multiply,
            operand0=row_mask,
            engine=nisa.vector_engine,
        )

        P_next = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_tensor(dst=P_next, data1=P_acc, data2=row_update, op=nl.add)
        nisa.tensor_copy(dst=P_acc, src=P_next)

    # P_acc now holds blockdiag(N11, N22).  The missing lower-left block is
    # N21 = N22 @ A21 @ N11, obtained as the lower-left block of
    # P_acc @ A_mat @ P_acc.
    N_diag_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_transpose(dst=N_diag_T_psum, data=P_acc)
    N_diag_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=N_diag_T, src=N_diag_T_psum)

    # tmp = P_acc @ A_mat
    tmp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_matmul(dst=tmp_psum, stationary=N_diag_T, moving=A_mat)
    tmp = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=tmp, src=tmp_psum)

    tmp_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_transpose(dst=tmp_T_psum, data=tmp)
    tmp_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=tmp_T, src=tmp_T_psum)

    # N21 (full tile) = tmp @ P_acc
    N21_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_matmul(dst=N21_psum, stationary=tmp_T, moving=P_acc)
    N21 = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=N21, src=N21_psum)

    # Keep only bottom rows x left columns, then merge into P_acc.
    N21_col_masked = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(
        dst=N21_col_masked,
        data1=N21,
        data2=col_mask_left,
        op=nl.multiply,
    )
    N21_block = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_scalar(
        dst=N21_block,
        data=N21_col_masked,
        op0=nl.multiply,
        operand0=block_row_mask_bottom,
        engine=nisa.vector_engine,
    )
    nisa.tensor_tensor(dst=P_acc, data1=P_acc, data2=N21_block, op=nl.add)

    # ============================================================
    # Apply N: value_corr = N @ v_beta, k_cumdecay = N @ (k_beta * exp_gc)
    # ============================================================
    N_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_transpose(dst=N_T_psum, data=P_acc)
    N_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=N_T, src=N_T_psum)

    vc_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_matmul(dst=vc_psum, stationary=N_T, moving=v_beta)
    value_corr = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=value_corr, src=vc_psum)

    # Row-scale k_beta by exp(gc[i]) via the per-partition scalar.
    kb_exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_scalar(
        dst=kb_exp_gc,
        data=k_beta,
        op0=nl.multiply,
        operand0=exp_gc_p,
        engine=nisa.vector_engine,
    )

    kcd_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_matmul(dst=kcd_psum, stationary=N_T, moving=kb_exp_gc)
    k_cumdecay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=k_cumdecay, src=kcd_psum)

    # ============================================================
    # Phase 2: Inter-chunk state propagation
    # attn_intra = (q @ k^T) * decay_mask * lower_mask_diag
    # ============================================================
    q_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_matmul(dst=q_T_psum, stationary=q_c, moving=eye)
    q_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=q_T, src=q_T_psum)

    qk_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_matmul(dst=qk_psum, stationary=q_T, moving=k_T)
    qk_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=qk_raw, src=qk_psum)

    # decay_diag already carries the inclusive lower-triangular mask.
    attn_intra = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(dst=attn_intra, data1=qk_raw, data2=decay_diag, op=nl.multiply)

    # ============================================================
    # v_prime = k_cumdecay @ state
    # ============================================================
    kcd_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_matmul(dst=kcd_T_psum, stationary=k_cumdecay, moving=eye)
    kcd_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=kcd_T, src=kcd_T_psum)

    vp_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_matmul(dst=vp_psum, stationary=kcd_T, moving=state)
    v_prime = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=v_prime, src=vp_psum)

    v_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(dst=v_new, data1=value_corr, data2=v_prime, op=nl.subtract)

    # ============================================================
    # attn_inter = (q * exp(g_cumsum)) @ state
    # ============================================================
    q_exp = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_scalar(
        dst=q_exp,
        data=q_c,
        op0=nl.multiply,
        operand0=exp_gc_p,
        engine=nisa.vector_engine,
    )

    qe_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_matmul(dst=qe_T_psum, stationary=q_exp, moving=eye)
    qe_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=qe_T, src=qe_T_psum)

    ai_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_matmul(dst=ai_psum, stationary=qe_T, moving=state)
    attn_inter = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=attn_inter, src=ai_psum)

    # ============================================================
    # attn_intra @ v_new
    # ============================================================
    ai_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_matmul(dst=ai_T_psum, stationary=attn_intra, moving=eye)
    ai_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=ai_T, src=ai_T_psum)

    intra_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_matmul(dst=intra_psum, stationary=ai_T, moving=v_new)
    intra_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=intra_out, src=intra_psum)

    # ============================================================
    # chunk_output = attn_inter + intra_out
    # ============================================================
    chunk_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(dst=chunk_out, data1=attn_inter, data2=intra_out, op=nl.add)

    nisa.dma_copy(dst=output, src=chunk_out)

    # ============================================================
    # State update: state_new = state * exp(g_last)
    #                         + (k * exp(g_last - gc))^T @ v_new
    # ============================================================
    # Subtract in log space first, then exponentiate once.
    gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(
        dst=gl_minus_gc_p,
        data1=gl_p,
        data2=gc_p,
        op=nl.subtract,
    )
    exp_gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf)
    nisa.activation(
        dst=exp_gl_minus_gc_p,
        op=nl.exp,
        data=gl_minus_gc_p,
        bias=None,
        scale=1.0,
    )

    k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_scalar(
        dst=k_raw_decay,
        data=k_c,
        op0=nl.multiply,
        operand0=exp_gl_minus_gc_p,
        engine=nisa.vector_engine,
    )

    # k_raw_decay is the stationary operand, so this is k_raw_decay^T @ v_new.
    kv_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum)
    nisa.nc_matmul(dst=kv_psum, stationary=k_raw_decay, moving=v_new)
    kv_outer = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_copy(dst=kv_outer, src=kv_psum)

    state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_scalar(
        dst=state_decayed,
        data=state,
        op0=nl.multiply,
        operand0=exp_gl_p,
        engine=nisa.vector_engine,
    )

    state_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf)
    nisa.tensor_tensor(dst=state_new, data1=state_decayed, data2=kv_outer, op=nl.add)

    nisa.dma_copy(dst=state_out, src=state_new)

    return output, state_out
b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py new file mode 100644 index 00000000..6008ae5a --- /dev/null +++ b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py @@ -0,0 +1,595 @@ +"""Fused single-kernel DeltaNet chunked forward for CTE (context encoding). + +SSD-style architecture: processes ALL chunks for one (batch, head) pair in +a single NKI kernel call. State (128x128) persists in SBUF across chunks — +no HBM round-trips for inter-chunk state propagation. + +Key optimizations over nki_deltanet_chunked.py: + 1. Single kernel call per (B,H) instead of B*H*num_chunks calls + 2. State in SBUF across all chunks (no HBM state read/write per chunk) + 3. In-kernel cumsum via tensor_tensor_scan (no PyTorch cumsum) + 4. Masks and constants loaded once, reused across chunks + 5. Uses tensor_scalar for partition-broadcast (no explicit broadcast loops) + 6. nc_transpose (Vector Engine) for all 128x128 transposes instead of + nc_matmul(moving=eye) (Tensor Engine) — frees TE for actual math + +NKI 0.3.0 (SDK 2.29). k_dim = v_dim = 128 = P_MAX exactly. +Chunk size = 128 = P_MAX (one tile per chunk). 
+ +Mathematical framework (same as nki_deltanet_chunked.py): + Per-chunk Neumann-series power-doubling for intra-chunk correction: + A = -QK_decay * lower_mask + N = (I+A)(I+A^2)(I+A^4)...(I+A^64) [6 rounds] + value_corr = N @ v_beta + k_cumdecay = N @ (k_beta * exp(gc)) + + Inter-chunk state propagation: + v_prime = k_cumdecay @ state + v_new = value_corr - v_prime + attn_inter = (q * exp(gc)) @ state + attn_intra = (q @ k^T) * decay_mask * lower_mask_diag + output = attn_inter + attn_intra @ v_new + state = exp(g_last) * (state + k_raw_decay^T @ v_new) +""" + +import numpy as np + +import nki +import nki.isa as nisa +import nki.language as nl + +P_MAX = 128 # Partition dim = chunk_size = k_dim = v_dim +CHUNK_SIZE = 128 + +# Broadcast partition 0 to all partitions in a 32-wide group +_BROADCAST_MASK = [0] * 32 + + +def _make_lower_mask(): + """Strict lower triangular (128x128) as numpy constant.""" + return np.tril(np.ones((CHUNK_SIZE, CHUNK_SIZE), dtype=np.float32), k=-1) + + +def _make_lower_mask_diag(): + """Lower triangular with diagonal (128x128) as numpy constant.""" + return np.tril(np.ones((CHUNK_SIZE, CHUNK_SIZE), dtype=np.float32), k=0) + + +def _make_identity(): + """Identity matrix (128x128) as numpy constant.""" + return np.eye(CHUNK_SIZE, dtype=np.float32) + + +@nki.jit +def deltanet_fused_chunked_fwd( + query: nl.ndarray, # (S, 128) float32 — l2-normed and scaled + key: nl.ndarray, # (S, 128) float32 — l2-normed + value: nl.ndarray, # (S, 128) float32 + g_in: nl.ndarray, # (S, 1) float32 — per-token log-decay (NOT cumsum) + beta_in: nl.ndarray, # (S, 1) float32 — per-token write gate + lower_mask: nl.ndarray, # (128, 128) float32 — strict lower tri + identity: nl.ndarray, # (128, 128) float32 — identity + lower_mask_diag: nl.ndarray, # (128, 128) float32 — lower tri with diag +): + """Fused chunked DeltaNet forward — single kernel call per (batch, head). 
+ + Processes all chunks sequentially within the kernel, keeping the recurrent + state (128x128) in SBUF across chunks. Returns per-token output and + final state. + + Input requirements: + - S must be divisible by 128 (pad before calling) + - query must be l2-normed and scaled by 1/sqrt(k_dim) + - key must be l2-normed + - g_in is RAW log-decay (cumsum computed in-kernel via tensor_tensor_scan) + - beta_in is sigmoid(b) (write gate) + + Returns: + output: (S, 128) float32 + final_state: (128, 128) float32 + """ + seq_len = query.shape[0] + dim = query.shape[1] # 128 + num_chunks = seq_len // CHUNK_SIZE + + # Output tensors in HBM + output = nl.ndarray((seq_len, dim), dtype=query.dtype, buffer=nl.shared_hbm) + final_state_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.shared_hbm) + + # ================================================================ + # Load constant masks into SBUF once (reused across all chunks) + # ================================================================ + eye = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=eye, src=identity) + + Lmask = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask, src=lower_mask) + + Lmask_d = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy(dst=Lmask_d, src=lower_mask_diag) + + # Ones vector for cumsum scan: (1, CHUNK_SIZE) + ones_1xC = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=ones_1xC, value=1.0) + + # Zero initial for cumsum scan + zero_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=zero_11, value=0.0) + + # ================================================================ + # Initialize recurrent state in SBUF — persists across ALL chunks + # ================================================================ + state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=state, value=0.0) + + # 
================================================================ + # Sequential chunk processing + # ================================================================ + for i_chunk in nl.sequential_range(num_chunks): + chunk_start = i_chunk * CHUNK_SIZE + + # ---- Load chunk data from HBM ---- + q_c = nl.ndarray((P_MAX, dim), dtype=query.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=q_c, + src=query[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + k_c = nl.ndarray((P_MAX, dim), dtype=key.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=k_c, + src=key[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + v_c = nl.ndarray((P_MAX, dim), dtype=value.dtype, buffer=nl.sbuf) + nisa.dma_copy( + dst=v_c, + src=value[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + ) + + # g: (CHUNK_SIZE, 1) — raw log-decay per token + g_chunk_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=g_chunk_p[0:CHUNK_SIZE, 0:1], + src=g_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1], + ) + + # beta: (CHUNK_SIZE, 1) — write gate scalar per token + beta_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.dma_copy( + dst=beta_p[0:CHUNK_SIZE, 0:1], + src=beta_in[chunk_start : chunk_start + CHUNK_SIZE, 0:1], + ) + + # ---- In-kernel cumsum of g via tensor_tensor_scan ---- + # Need g as (1, CHUNK_SIZE) for scan along free dim. 
+ # Transpose: (CHUNK_SIZE, 1) -> (1, CHUNK_SIZE) via nc_transpose + g_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=g_padded, value=0.0) + nisa.tensor_copy( + dst=g_padded[0:CHUNK_SIZE, 0:1], + src=g_chunk_p[0:CHUNK_SIZE, 0:1], + ) + + g_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=g_tp_psum, data=g_padded) + + g_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=g_row[0:1, 0:CHUNK_SIZE], + src=g_tp_psum[0:1, 0:CHUNK_SIZE], + ) + + # cumsum: gc_row[t] = 1.0 * gc_row[t-1] + g_row[t] + gc_row = nl.ndarray((1, CHUNK_SIZE), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor_scan( + dst=gc_row[0:1, 0:CHUNK_SIZE], + data0=ones_1xC[0:1, 0:CHUNK_SIZE], + data1=g_row[0:1, 0:CHUNK_SIZE], + initial=zero_11[0:1, 0:1], + op0=nl.multiply, + op1=nl.add, + ) + + # Transpose gc back to (CHUNK_SIZE, 1) partition layout + gc_padded = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=gc_padded, value=0.0) + nisa.tensor_copy( + dst=gc_padded[0:1, 0:CHUNK_SIZE], + src=gc_row[0:1, 0:CHUNK_SIZE], + ) + + gc_tp_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=gc_tp_psum, data=gc_padded) + + # gc_p: (P_MAX, 1) — cumulative sum of g per token in this chunk + gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=gc_p[0:CHUNK_SIZE, 0:1], + src=gc_tp_psum[0:CHUNK_SIZE, 0:1], + ) + + # g_last = gc[-1] (scalar) — needed for state decay + gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=gl_11[0:1, 0:1], + src=gc_row[0:1, CHUNK_SIZE - 1 : CHUNK_SIZE], + ) + + # ---- Compute exp(gc), exp(-gc), exp(g_last) as (P_MAX, 1) scalars ---- + # These (P_MAX, 1) tensors are used with tensor_scalar to broadcast + # across the free dimension without explicit (P_MAX, dim) copies. 
+ + exp_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + neg_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_gc_p, + data=gc_p, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + exp_neg_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_neg_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=neg_gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, + ) + + # exp(g_last): scalar, then broadcast to (P_MAX, 1) + exp_gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_11, + op=nl.exp, + data=gl_11, + bias=None, + scale=1.0, + ) + + exp_gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=exp_gl_11[0:1, 0:1], + dst=exp_gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) + + # ============================================================ + # k_beta = K * beta, v_beta = V * beta + # tensor_scalar broadcasts beta_p (P_MAX, 1) across free dim + # ============================================================ + k_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_beta, + data=k_c, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + v_beta = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=v_beta, + data=v_c, + op0=nl.multiply, + operand0=beta_p, + engine=nisa.vector_engine, + ) + + # ============================================================ + # Phase 1: Build A matrix (intra-chunk correction) + # Transpose K and K_beta for matmul + # ============================================================ + kb_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + 
nisa.nc_transpose(dst=kb_T_psum, data=k_beta) + kb_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kb_T, src=kb_T_psum) + + k_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=k_T_psum, data=k_c) + k_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_T, src=k_T_psum) + + # QK = k_beta^T @ k (contract over features) + QK_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=QK_psum, stationary=kb_T, moving=k_T) + QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK, src=QK_psum) + + # ============================================================ + # Decay mask: QK_decay[i,j] = QK[i,j] * exp(gc[i]) * exp(-gc[j]) + # + # Apply the strict causal mask before the split exp(gc) / exp(-gc) + # scaling. Upper-triangular entries are mathematically unused, but + # scaling them first can create very large finite values that poison + # later matmuls before the mask is applied. + # ============================================================ + QK_masked = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=QK_masked, data1=QK, data2=Lmask, op=nl.multiply) + + # Row scaling: QK_row[i,:] = QK[i,:] * exp(gc[i]) + # Then transpose, column scale, transpose back. + # Uses tensor_scalar with (P_MAX,1) operand for row scaling. 
+ QK_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=QK_row, + data=QK_masked, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + # Transpose to scale columns (now rows in transposed view) + QK_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=QK_r_T_psum, data=QK_row) + QK_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK_r_T, src=QK_r_T_psum) + + QK_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=QK_r_T_col, + data=QK_r_T, + op0=nl.multiply, + operand0=exp_neg_gc_p, + engine=nisa.vector_engine, + ) + + # Transpose back + QK_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=QK_d_psum, data=QK_r_T_col) + QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=QK_decay, src=QK_d_psum) + + # A = -QK_decay * lower_mask + neg_QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=neg_QK_decay, + data=QK_decay, + op0=nl.multiply, + operand0=-1.0, + engine=nisa.vector_engine, + ) + A_mat = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=A_mat, data1=neg_QK_decay, data2=Lmask, op=nl.multiply) + + # ============================================================ + # Neumann power-doubling: N = (I+A)(I+A^2)...(I+A^{64}) + # 6 rounds → resolves rank up to 2^6 = 64 (sufficient for chunk=128) + # ============================================================ + P_acc = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=P_acc, data1=eye, data2=A_mat, op=nl.add) + + A_pow = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=A_pow, src=A_mat) + + for _round in nl.sequential_range(6): + # A_pow = A_pow^2: transpose A_pow, then matmul + Ap_T_psum = 
nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=Ap_T_psum, data=A_pow) + Ap_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=Ap_T, src=Ap_T_psum) + + Ap_sq_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=Ap_sq_psum, stationary=Ap_T, moving=A_pow) + nisa.tensor_copy(dst=A_pow, src=Ap_sq_psum) + + # P_acc = (I + A_pow) @ P_acc: transpose IpA, then matmul + IpA = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=IpA, data1=eye, data2=A_pow, op=nl.add) + + IpA_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=IpA_T_psum, data=IpA) + IpA_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=IpA_T, src=IpA_T_psum) + + Pacc_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=Pacc_psum, stationary=IpA_T, moving=P_acc) + nisa.tensor_copy(dst=P_acc, src=Pacc_psum) + + # ============================================================ + # Apply N: value_corr = N @ v_beta + # k_cumdecay = N @ (k_beta * exp(gc)) + # ============================================================ + N_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=N_T_psum, data=P_acc) + N_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=N_T, src=N_T_psum) + + vc_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=vc_psum, stationary=N_T, moving=v_beta) + value_corr = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=value_corr, src=vc_psum) + + # k_beta * exp(gc): row-scaled + kb_exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=kb_exp_gc, + data=k_beta, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + kcd_psum = nl.ndarray((P_MAX, 
dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kcd_psum, stationary=N_T, moving=kb_exp_gc) + k_cumdecay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=k_cumdecay, src=kcd_psum) + + # ============================================================ + # Phase 2: Inter-chunk state propagation + # attn_intra = (q @ k^T) * decay_mask * lower_mask_diag + # ============================================================ + q_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=q_T_psum, data=q_c) + q_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=q_T, src=q_T_psum) + + qk_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=qk_psum, stationary=q_T, moving=k_T) + qk_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_raw, src=qk_psum) + + # Mask before split scaling for the same reason as the A matrix above: + # upper-triangular decay factors are unused and can be numerically huge. 
+ qk_masked = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=qk_masked, data1=qk_raw, data2=Lmask_d, op=nl.multiply) + + # Row-scale by exp(gc) + qk_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=qk_row, + data=qk_masked, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + # Transpose, column-scale by exp(-gc), transpose back + qk_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=qk_r_T_psum, data=qk_row) + qk_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_r_T, src=qk_r_T_psum) + + qk_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=qk_r_T_col, + data=qk_r_T, + op0=nl.multiply, + operand0=exp_neg_gc_p, + engine=nisa.vector_engine, + ) + + qk_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=qk_d_psum, data=qk_r_T_col) + qk_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qk_decay, src=qk_d_psum) + + attn_intra = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=attn_intra, data1=qk_decay, data2=Lmask_d, op=nl.multiply + ) + + # ============================================================ + # v_prime = k_cumdecay @ state (state is in SBUF!) 
+ # ============================================================ + kcd_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=kcd_T_psum, data=k_cumdecay) + kcd_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kcd_T, src=kcd_T_psum) + + vp_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=vp_psum, stationary=kcd_T, moving=state) + v_prime = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=v_prime, src=vp_psum) + + v_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_new, data1=value_corr, data2=v_prime, op=nl.subtract) + + # ============================================================ + # attn_inter = (q * exp(gc)) @ state (state is in SBUF!) + # ============================================================ + q_exp = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=q_exp, + data=q_c, + op0=nl.multiply, + operand0=exp_gc_p, + engine=nisa.vector_engine, + ) + + qe_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=qe_T_psum, data=q_exp) + qe_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=qe_T, src=qe_T_psum) + + ai_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=ai_psum, stationary=qe_T, moving=state) + attn_inter = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=attn_inter, src=ai_psum) + + # ============================================================ + # attn_intra @ v_new + # ============================================================ + ai_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=ai_T_psum, data=attn_intra) + ai_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=ai_T, src=ai_T_psum) + + intra_psum 
= nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=intra_psum, stationary=ai_T, moving=v_new) + intra_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=intra_out, src=intra_psum) + + # ============================================================ + # chunk_output = attn_inter + intra_out + # ============================================================ + chunk_out = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=chunk_out, data1=attn_inter, data2=intra_out, op=nl.add) + + # Store output chunk to HBM + nisa.dma_copy( + dst=output[chunk_start : chunk_start + CHUNK_SIZE, 0:dim], + src=chunk_out, + ) + + # ============================================================ + # State update: state = exp(g_last) * (state + k_raw_decay^T @ v_new) + # state is updated IN-PLACE in SBUF — no HBM round-trip! + # ============================================================ + + # k_raw_decay contributes as exp(g_last) * (k * exp(-gc))^T @ v_new. + # Compute the equivalent stable form k * exp(g_last - gc), so the + # factor is always <= 1 for valid causal positions. 
+ exp_gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=exp_gl_minus_gc_p, + data1=exp_gl_p, + data2=exp_neg_gc_p, + op=nl.multiply, + ) + + # k_raw_decay = k * exp(g_last - gc) + k_raw_decay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=k_raw_decay, + data=k_c, + op0=nl.multiply, + operand0=exp_gl_minus_gc_p, + engine=nisa.vector_engine, + ) + + # k_raw_decay^T @ v_new → (dim, dim) outer product sum + # nc_matmul: result[M,N] = sum_K stationary[K,M] * moving[K,N] + # stationary=k_raw_decay (P_MAX, dim), moving=v_new (P_MAX, dim) + # Result: sum over tokens -> (dim, dim) + kv_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kv_psum, stationary=k_raw_decay, moving=v_new) + kv_outer = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kv_outer, src=kv_psum) + + # state = state * exp(g_last) + kv_outer + # tensor_scalar broadcasts exp_gl_p (P_MAX, 1) across free dim. 
+ state_decayed = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=state_decayed, + data=state, + op0=nl.multiply, + operand0=exp_gl_p, + engine=nisa.vector_engine, + ) + nisa.tensor_tensor(dst=state, data1=state_decayed, data2=kv_outer, op=nl.add) + + # ---- Write final state to HBM ---- + nisa.dma_copy(dst=final_state_out, src=state) + + return output, final_state_out diff --git a/contrib/models/Qwen3.6-27B/test/__init__.py b/contrib/models/Qwen3.6-27B/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Qwen3.6-27B/test/integration/__init__.py b/contrib/models/Qwen3.6-27B/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py b/contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py new file mode 100644 index 00000000..1c06f89f --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/integration/qwen36_27b_compile_fp8.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python3 +"""Compile Qwen3.6-27B 64K with a scoped FP8 weight-quantization ablation. + +This script intentionally starts from the validated 64K hybrid/chunked-prefill +baseline and changes only weight quantization. The first supported mode is +``mlp_only``: MLP linear weights are converted to FP8 while attention, DeltaNet, +normalization, embeddings, lm_head, KV cache, and recurrent state remain BF16. 
from __future__ import annotations

import argparse
import gc
import json
import os
import sys
from pathlib import Path

import torch


def _repo_root(path: str | None) -> Path:
    """Return the repository root used to locate the contrib model sources.

    Args:
        path: Explicit repository root, or ``None``/empty to derive it from
            this script's own location.

    Returns:
        The resolved absolute path. When ``path`` is not given, the root is
        ``parents[5]`` of this file.
    """
    if path:
        return Path(path).expanduser().resolve()
    return Path(__file__).resolve().parents[5]


def _load_text_config(model_path: Path) -> dict:
    """Load the text-model section of ``config.json`` as keyword arguments.

    Args:
        model_path: Directory containing the HF ``config.json``.

    Returns:
        A copy of ``text_config`` (or of the whole config when there is no
        ``text_config`` key) with ``pad_token_id``, ``rope_theta`` (only when
        ``rope_parameters`` is present) and ``tie_word_embeddings`` filled in.
    """
    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with (model_path / "config.json").open(encoding="utf-8") as f:
        full_config = json.load(f)
    text_config = full_config.get("text_config", full_config)
    config_dict = dict(text_config)
    # The pad token falls back to the EOS id (or the hard-coded default id).
    config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044)
    rope_parameters = text_config.get("rope_parameters")
    # A null "rope_parameters" entry is treated like a missing one.
    if isinstance(rope_parameters, dict):
        config_dict["rope_theta"] = rope_parameters.get("rope_theta", 10000000)
    config_dict.setdefault("tie_word_embeddings", False)
    return config_dict


def _mlp_only_modules_to_not_convert(num_layers: int) -> list[str]:
    """Exclude numerically sensitive or unsupported modules from FP8 conversion.

    Args:
        num_layers: Number of decoder layers to emit per-layer entries for.

    Returns:
        Module names that must stay unquantized: the global embedding, head,
        norm and rotary modules, plus each layer's attention, linear-attention
        and layernorm modules, listed both with and without the ``model.``
        prefix.
    """
    modules = [
        "embed_tokens",
        "model.embed_tokens",
        "lm_head",
        "norm",
        "model.norm",
        "rotary_emb",
        "model.rotary_emb",
    ]
    for layer_idx in range(num_layers):
        for prefix in ("layers", "model.layers"):
            modules.extend(
                [
                    f"{prefix}.{layer_idx}.self_attn",
                    f"{prefix}.{layer_idx}.linear_attn",
                    f"{prefix}.{layer_idx}.input_layernorm",
                    f"{prefix}.{layer_idx}.post_attention_layernorm",
                ]
            )
    return modules


def _quantized_checkpoint_ready(path: Path) -> bool:
    """Return True when ``path`` is a file or a non-empty directory."""
    if path.is_file():
        return True
    if path.is_dir():
        return any(path.iterdir())
    return False


def _is_mlp_weight(name: str) -> bool:
    """Return True for ``….mlp.{gate,up,down}_proj.weight`` tensor names.

    The name must have at least four dot-separated parts, so a bare
    ``mlp.gate_proj.weight`` without a parent prefix does not match.
    """
    parts = name.split(".")
    return (
        len(parts) >= 4
        and parts[-3] == "mlp"
        and parts[-2] in {"gate_proj", "up_proj", "down_proj"}
        and parts[-1] == "weight"
    )


def _scale_name(weight_name: str) -> str:
    """Return the scale-tensor name that accompanies ``weight_name``.

    Raises:
        ValueError: If ``weight_name`` does not end in ``.weight``; slicing
            such a name would silently produce a corrupted key.
    """
    suffix = ".weight"
    if not weight_name.endswith(suffix):
        raise ValueError(f"Expected a name ending in '.weight', got {weight_name!r}")
    return weight_name[: -len(suffix)] + ".weight_scale"


def _clear_quantized_checkpoint_dir(path: Path) -> None:
    """Create ``path`` if needed and delete its safetensors/JSON files.

    Only regular files directly inside ``path`` are removed; a directory whose
    name happens to end in ``.json`` or ``.safetensors`` is left alone instead
    of making ``unlink()`` fail.
    """
    path.mkdir(parents=True, exist_ok=True)
    for child in path.iterdir():
        if child.is_file() and child.name.endswith((".safetensors", ".json")):
            child.unlink()


def _save_mlp_only_fp8_state_dict(model_path: Path, output_path: Path) -> None:
    """Create a sharded FP8 checkpoint directly from HF safetensors.

    Loading the HF architecture requires a newer Transformers than the Neuron
    venv uses internally. For this MLP-only ablation, we do not need model
    execution: the checkpoint transform is a direct tensor rewrite.

    Args:
        model_path: Directory holding the source safetensors checkpoint
            (sharded with an index, or a single ``model.safetensors``).
        output_path: Directory that receives the rewritten shards. Its
            existing ``*.safetensors`` and ``*.json`` files are deleted first.

    Raises:
        ValueError: If ``output_path`` is the source directory; clearing it
            would destroy the checkpoint being read.
        FileNotFoundError: If no safetensors checkpoint exists in
            ``model_path``.
    """
    # Checked before anything else: the output directory is wiped below, so
    # writing in place would delete the source shards before they are loaded.
    if output_path.expanduser().resolve() == model_path.expanduser().resolve():
        raise ValueError(
            "Quantized output path must differ from the source model path: "
            f"{model_path}"
        )

    from safetensors.torch import load_file, save_file  # noqa: WPS433
    from neuronx_distributed.quantization.quantization_utils import (  # noqa: WPS433
        quantize_fp8_per_channel,
    )

    index_path = model_path / "model.safetensors.index.json"
    if index_path.exists():
        with index_path.open(encoding="utf-8") as f:
            source_index = json.load(f)
        source_weight_map = source_index["weight_map"]
        filenames = sorted(set(source_weight_map.values()))
    elif (model_path / "model.safetensors").exists():
        source_weight_map = None
        filenames = ["model.safetensors"]
    else:
        raise FileNotFoundError(f"No safetensors checkpoint found in {model_path}")

    _clear_quantized_checkpoint_dir(output_path)
    output_weight_map: dict[str, str] = {}
    total_size = 0
    quantized_count = 0

    # Shards are rewritten one at a time and released before the next load.
    for filename in filenames:
        shard = load_file(str(model_path / filename))
        output_shard = {}
        for name, tensor in shard.items():
            if _is_mlp_weight(name):
                weight, scale = quantize_fp8_per_channel(
                    tensor,
                    torch.float8_e4m3fn,
                    channel_axis=0,
                )
                output_shard[name] = weight
                output_shard[_scale_name(name)] = scale
                output_weight_map[_scale_name(name)] = filename
                total_size += weight.numel() * weight.element_size()
                total_size += scale.numel() * scale.element_size()
                quantized_count += 1
            else:
                output_shard[name] = tensor
                total_size += tensor.numel() * tensor.element_size()
            output_weight_map[name] = filename

        save_file(output_shard, str(output_path / filename), metadata={"format": "pt"})
        del shard
        del output_shard
        gc.collect()

    # An index is only written when the source checkpoint had one.
    if source_weight_map is not None:
        index_out = output_path / "model.safetensors.index.json"
        with index_out.open("w", encoding="utf-8") as f:
            json.dump(
                {
                    "metadata": {"total_size": total_size},
                    "weight_map": output_weight_map,
                },
                f,
                indent=2,
                sort_keys=True,
            )

    print("MANUAL_FP8_MLP_WEIGHT_COUNT", quantized_count, flush=True)


def _build_config(args: argparse.Namespace):
    """Build the inference config for the MLP-only FP8 compile.

    Args:
        args: Parsed CLI arguments; reads ``model_path``, ``tp_degree``,
            ``seq_len``, ``cte_bucket``, ``logical_nc_config`` and
            ``quantized_checkpoints_path``.

    Returns:
        A ``(inf_config, modules_to_not_convert)`` tuple.
    """
    # ``src`` is resolved through the sys.path entries that main() inserts.
    from neuronx_distributed_inference.models.config import (  # noqa: WPS433
        NeuronConfig,
        OnDeviceSamplingConfig,
    )
    from src.modeling_qwen35 import Qwen35InferenceConfig  # noqa: WPS433

    model_path = Path(args.model_path).expanduser().resolve()
    config_dict = _load_text_config(model_path)
    num_layers = int(config_dict["num_hidden_layers"])
    modules_to_not_convert = _mlp_only_modules_to_not_convert(num_layers)

    neuron_config = NeuronConfig(
        tp_degree=args.tp_degree,
        batch_size=1,
        ctx_batch_size=1,
        tkg_batch_size=1,
        seq_len=args.seq_len,
        max_context_length=args.cte_bucket,
        max_length=args.seq_len,
        context_encoding_buckets=[args.cte_bucket],
        torch_dtype=torch.bfloat16,
        on_device_sampling_config=OnDeviceSamplingConfig(
            do_sample=False,
            top_k=1,
            top_p=1.0,
            temperature=1.0,
        ),
        enable_bucketing=False,
        logical_nc_config=args.logical_nc_config,
        save_sharded_checkpoint=True,
        quantized=True,
        quantized_checkpoints_path=str(
            Path(args.quantized_checkpoints_path).expanduser().resolve()
        ),
        quantization_type="per_channel_symmetric",
        quantization_dtype="f8e4m3",
        modules_to_not_convert=modules_to_not_convert,
        kv_cache_quant=False,
        quantized_mlp_kernel_enabled=False,
        activation_quantization_type=None,
    )

    # Only fill these in when the checkpoint config does not already set them.
    config_dict.setdefault("use_hybrid_cache_manager", True)
    config_dict.setdefault("use_qwen_hybrid_chunked_prefill", True)
    config_dict.setdefault("use_qwen_hybrid_chunked_prefill_nki", True)

    inf_config = Qwen35InferenceConfig(neuron_config=neuron_config, **config_dict)
    return inf_config, modules_to_not_convert
def main() -> int:
    """Quantize the MLP weights to FP8 if needed, then compile the model.

    Returns:
        ``0`` on success, used as the process exit status.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--repo-root", default=None)
    for path_flag in (
        "--model-path",
        "--compiled-path",
        "--quantized-checkpoints-path",
    ):
        cli.add_argument(path_flag, required=True)
    for int_flag, int_default in (
        ("--seq-len", 65536),
        ("--cte-bucket", 512),
        ("--tp-degree", 4),
        ("--logical-nc-config", 2),
    ):
        cli.add_argument(int_flag, type=int, default=int_default)
    for switch in ("--force-quantize", "--quantize-only", "--load-after-compile"):
        cli.add_argument(switch, action="store_true")
    args = cli.parse_args()

    # Make both the repository and the contrib model package importable
    # before the deferred ``src`` import below.
    repo = _repo_root(args.repo_root)
    for extra_path in (repo, repo / "contrib" / "models" / "Qwen3.6-27B"):
        sys.path.insert(0, str(extra_path))

    from src.modeling_qwen35 import NeuronQwen35ForCausalLM  # noqa: WPS433

    model_dir = Path(args.model_path).expanduser().resolve()
    compiled_dir = Path(args.compiled_path).expanduser().resolve()
    quantized_dir = Path(args.quantized_checkpoints_path).expanduser().resolve()

    inf_config, modules_to_not_convert = _build_config(args)

    trace_shape = {
        "seq_len": args.seq_len,
        "max_context_length": args.cte_bucket,
        "context_encoding_buckets": [args.cte_bucket],
    }
    print("FP8_MODE mlp_only", flush=True)
    print("MODEL_PATH", str(model_dir), flush=True)
    print("COMPILED_PATH", str(compiled_dir), flush=True)
    print("QUANTIZED_CHECKPOINTS_PATH", str(quantized_dir), flush=True)
    print("MODULES_TO_NOT_CONVERT_COUNT", len(modules_to_not_convert), flush=True)
    print("CONTEXT_TRACE_SHAPE", json.dumps(trace_shape, sort_keys=True), flush=True)

    # Reuse an existing quantized checkpoint unless a rebuild is forced.
    if not args.force_quantize and _quantized_checkpoint_ready(quantized_dir):
        print("QUANTIZE_SKIP existing checkpoint found", flush=True)
    else:
        print("QUANTIZE_START manual_mlp_only", flush=True)
        _save_mlp_only_fp8_state_dict(model_dir, quantized_dir)
        print("QUANTIZE_DONE", flush=True)

    if args.quantize_only:
        return 0

    print("COMPILE_START", flush=True)
    model = NeuronQwen35ForCausalLM(str(model_dir), inf_config)
    model.compile(str(compiled_dir))
    del model
    gc.collect()
    print("COMPILE_DONE", flush=True)

    if args.load_after_compile:
        # Optional smoke check that the compiled artifacts load.
        model = NeuronQwen35ForCausalLM(str(compiled_dir))
        model.load(str(compiled_dir))
        print("LOAD_AFTER_COMPILE_OK", flush=True)

    return 0


if __name__ == "__main__":
    raise SystemExit(main())
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Integration tests for Qwen3.6-27B on Neuron.

Tests compilation, loading, inference accuracy, and performance using
the full 27B model with pre-downloaded HuggingFace weights on a trn2 instance.

Qwen3.6-27B shares identical architecture with Qwen3.5-27B (qwen3_5 model_type).
These tests use the same Qwen35* classes and QWEN35_* env vars because the
underlying code is shared.

Note: A mini model option is not provided because DeltaNet layers require NKI
kernels that only execute on Neuron devices, and the hybrid DeltaNet + GQA
architecture needs at least TP=4 for the full model to fit in HBM.

Environment variables:
    QWEN35_MODEL_PATH        Path to HF model weights (required)
    QWEN35_COMPILED_PATH     Path to compiled artifacts (default: /tmp/qwen35_27b_traced)
    QWEN35_TP_DEGREE         Tensor parallelism degree (default: 4)
    QWEN35_SEQ_LEN           Max sequence length (default: 128)
    QWEN35_USE_HYBRID_CACHE  Set to 1 to enable the hybrid cache manager (default: 0)
    QWEN35_RECORD_HBM        Set to 1 to run the optional HBM recording test (default: 0)
    TTFT_THRESHOLD_MS        Max TTFT in ms (default: 5000)
    THROUGHPUT_THRESHOLD     Min throughput in tok/s (default: 5.0)

Prerequisites:
    - trn2.3xlarge or larger with TP >= 4 NeuronCores available
    - NXDI installed (neuronx_distributed_inference)
    - HuggingFace weights downloaded to QWEN35_MODEL_PATH
    - SDK 2.29+ (NKI 0.3.0 required for DeltaNet kernels)

Usage:
    # Full model (trn2.3xlarge, TP=4):
    QWEN35_MODEL_PATH=/mnt/models/Qwen3.6-27B \\
    QWEN35_COMPILED_PATH=/mnt/models/qwen36_traced \\
    pytest test/integration/test_model.py --capture=tee-sys
"""

import gc
import json
import os
import shutil
import subprocess
import sys
import time

import pytest
import torch

# Ensure the contrib root (Qwen3.6-27B/) is on sys.path
_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if _CONTRIB_ROOT not in sys.path:
    sys.path.insert(0, _CONTRIB_ROOT)

# ── Configuration from environment ──────────────────────────────────────

MODEL_PATH = os.environ.get("QWEN35_MODEL_PATH", "")
COMPILED_PATH = os.environ.get("QWEN35_COMPILED_PATH", "/tmp/qwen35_27b_traced")
TP_DEGREE = int(os.environ.get("QWEN35_TP_DEGREE", "4"))
SEQ_LEN = int(os.environ.get("QWEN35_SEQ_LEN", "128"))
TTFT_THRESHOLD_MS = float(os.environ.get("TTFT_THRESHOLD_MS", "5000"))
THROUGHPUT_THRESHOLD = float(os.environ.get("THROUGHPUT_THRESHOLD", "5.0"))
USE_HYBRID_CACHE = os.environ.get("QWEN35_USE_HYBRID_CACHE", "0") == "1"
RECORD_HBM = os.environ.get("QWEN35_RECORD_HBM", "0") == "1"

requires_model_path = pytest.mark.skipif(
    not MODEL_PATH,
    reason=(
        "QWEN35_MODEL_PATH not set. Integration tests require the full 27B model "
        "weights. Set QWEN35_MODEL_PATH=/path/to/Qwen3.6-27B to run these tests."
    ),
)
requires_hbm_recording = pytest.mark.skipif(
    not RECORD_HBM,
    reason=(
        "QWEN35_RECORD_HBM=1 not set. This optional test records Neuron HBM "
        "usage for dummy-KV vs hybrid-cache comparisons."
    ),
)


# ── Fixtures ────────────────────────────────────────────────────────────


@pytest.fixture(scope="module")
def model_path():
    """Return path to model weights."""
    return MODEL_PATH


@pytest.fixture(scope="module")
def compiled_model(model_path):
    """Compile (when no artifacts exist yet) and load the model on Neuron."""
    from neuronx_distributed_inference.models.config import (
        NeuronConfig,
        OnDeviceSamplingConfig,
    )
    from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM

    neuron_config = NeuronConfig(
        tp_degree=TP_DEGREE,
        batch_size=1,
        ctx_batch_size=1,
        tkg_batch_size=1,
        seq_len=SEQ_LEN,
        torch_dtype=torch.bfloat16,
        on_device_sampling_config=OnDeviceSamplingConfig(top_k=1),
        enable_bucketing=False,
        flash_decoding_enabled=False,
        logical_nc_config=2,
        save_sharded_checkpoint=True,
    )

    # Read config.json directly (model_type 'qwen3_5' may not be in
    # AutoConfig registry for all transformers versions)
    with open(os.path.join(model_path, "config.json"), encoding="utf-8") as f:
        full_config = json.load(f)
    text_config = full_config.get("text_config", full_config)

    config_dict = dict(text_config)
    config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044)
    if "rope_parameters" in text_config:
        config_dict["rope_theta"] = text_config["rope_parameters"].get(
            "rope_theta", 10000000
        )
    config_dict.setdefault("tie_word_embeddings", False)

    inf_config = Qwen35InferenceConfig(
        neuron_config=neuron_config,
        use_hybrid_cache_manager=USE_HYBRID_CACHE,
        **config_dict,
    )

    # Compile if no existing artifacts
    compiled_path = COMPILED_PATH
    neff_path = os.path.join(compiled_path, "model.pt")
    if not os.path.exists(neff_path):
        print(f"Compiling to {compiled_path}...")
        model = NeuronQwen35ForCausalLM(model_path, inf_config)
        model.compile(compiled_path)
        del model
        gc.collect()

    # Load
    print(f"Loading from {compiled_path}...")
    model = NeuronQwen35ForCausalLM(compiled_path)
    model.load(compiled_path)
    return model


@pytest.fixture(scope="module")
def tokenizer(model_path):
    """Load tokenizer, falling back to EOS as the pad token when none is set."""
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained(model_path, padding_side="right")
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok


@pytest.fixture(scope="module")
def generation_config(tokenizer):
    """Create generation config."""
    from transformers import GenerationConfig

    return GenerationConfig(
        do_sample=True,
        top_k=1,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )


def _generate(model, tokenizer, generation_config, prompt, max_new_tokens=20):
    """Generate text using the NXDI model.

    Returns:
        A ``(token_ids, decoded_text)`` tuple for the first sequence; the ids
        are whatever ``generate`` returns for it, as a plain list.
    """
    import transformers

    from neuronx_distributed_inference.utils.hf_adapter import (
        HuggingFaceGenerationAdapter,
    )

    inputs = tokenizer(prompt, padding=True, return_tensors="pt")
    gen_model = HuggingFaceGenerationAdapter(model)
    gen_model.generation_config.transformers_version = transformers.__version__
    generation_config.transformers_version = transformers.__version__
    outputs = gen_model.generate(
        inputs.input_ids,
        generation_config=generation_config,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
    )
    return outputs[0].tolist(), tokenizer.decode(outputs[0], skip_special_tokens=True)


def _is_repetitive(text, max_repeat=5):
    """Return True when ``max_repeat`` consecutive words of ``text`` are identical."""
    words = text.split()
    if len(words) < max_repeat:
        return False
    return any(
        len(set(words[start : start + max_repeat])) == 1
        for start in range(len(words) - max_repeat + 1)
    )


def _dict_or_empty(value):
    """Return ``value`` when it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


def _parse_peak_neuron_memory(stdout):
    """Extract peak memory figures from neuron-monitor JSON-lines output.

    Lines that are blank, are not valid JSON, or do not decode to a JSON
    object are ignored, and missing or null sections count as zero, so partial
    monitor output cannot crash the parser.

    Returns:
        ``(peak_device, peak_tensors, samples)``: the largest
        ``neuron_device`` value seen, the largest per-runtime sum of
        NeuronCore ``tensors`` values, and the number of runtime entries
        processed.
    """
    peak_device = 0
    peak_tensors = 0
    samples = 0
    for line in stdout.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            report = json.loads(line)
        except json.JSONDecodeError:
            continue
        if not isinstance(report, dict):
            continue
        for runtime in report.get("neuron_runtime_data") or []:
            runtime_report = _dict_or_empty(_dict_or_empty(runtime).get("report"))
            memory_used = _dict_or_empty(runtime_report.get("memory_used"))
            used = _dict_or_empty(memory_used.get("neuron_runtime_used_bytes"))
            peak_device = max(peak_device, int(used.get("neuron_device", 0) or 0))
            breakdown = _dict_or_empty(used.get("usage_breakdown"))
            nc_usage = _dict_or_empty(breakdown.get("neuroncore_memory_usage"))
            tensor_bytes = sum(
                int(_dict_or_empty(core).get("tensors", 0) or 0)
                for core in nc_usage.values()
            )
            peak_tensors = max(peak_tensors, tensor_bytes)
            samples += 1
    return peak_device, peak_tensors, samples


def _capture_neuron_hbm(tmp_path, fn):
    """Run ``fn`` while neuron-monitor samples memory and return the peaks.

    Monitor output is redirected to files under ``tmp_path`` rather than to
    pipes. Nothing reads a pipe while ``fn`` runs, so a long run could fill
    the pipe buffer, stall the monitor, and drop the samples that matter.

    Args:
        tmp_path: Directory (a ``pathlib.Path``) for the monitor config and
            its captured output.
        fn: Zero-argument callable to execute while monitoring.

    Returns:
        ``(result, peak_device, peak_tensors, samples)`` where ``result`` is
        the return value of ``fn``.
    """
    if shutil.which("neuron-monitor") is None:
        pytest.skip("neuron-monitor is not available")

    monitor_config = {
        "period": "0.5s",
        "neuron_runtimes": [
            {
                "tag_filter": ".*",
                "metrics": [{"type": "memory_used", "period": "0.5s"}],
            }
        ],
    }
    config_path = tmp_path / "neuron-monitor.json"
    config_path.write_text(json.dumps(monitor_config))

    stdout_path = tmp_path / "neuron-monitor.stdout"
    stderr_path = tmp_path / "neuron-monitor.stderr"
    with stdout_path.open("w") as out_file, stderr_path.open("w") as err_file:
        proc = subprocess.Popen(
            ["neuron-monitor", "--config-file", str(config_path)],
            stdout=out_file,
            stderr=err_file,
        )
        try:
            # Give the monitor time to start sampling before and after fn().
            time.sleep(1.0)
            result = fn()
            time.sleep(1.0)
        finally:
            proc.terminate()
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait(timeout=5)

    stdout = stdout_path.read_text(errors="replace")
    stderr = stderr_path.read_text(errors="replace")

    peak_device, peak_tensors, samples = _parse_peak_neuron_memory(stdout)
    assert samples > 0, f"neuron-monitor produced no runtime samples: {stderr}"
    assert peak_device > 0, "Expected non-zero Neuron device HBM usage"
    return result, peak_device, peak_tensors, samples


# ── Smoke Tests ─────────────────────────────────────────────────────────


@requires_model_path
def test_model_loads(compiled_model):
    """Model compiles and loads successfully."""
    assert compiled_model is not None
    assert hasattr(compiled_model, "neuron_config")
    print("  Model loaded successfully")


@requires_model_path
def test_model_generates(compiled_model, tokenizer, generation_config):
    """Model generates at least 5 tokens."""
    tokens, text = _generate(
        compiled_model,
        tokenizer,
        generation_config,
        "Hello, I am a language model",
        max_new_tokens=20,
    )
    input_len = len(tokenizer.encode("Hello, I am a language model"))
    new_tokens = len(tokens) - input_len
    assert new_tokens >= 5, f"Expected >= 5 new tokens, got {new_tokens}"
    print(f"  Generated {new_tokens} tokens: {text[:100]}...")


# ── Accuracy Tests ──────────────────────────────────────────────────────


@requires_model_path
def test_output_coherence(compiled_model, tokenizer, generation_config):
    """Output should contain multiple words and not be excessively repetitive."""
    _, text = _generate(
        compiled_model,
        tokenizer,
        generation_config,
        "The capital of France is",
        max_new_tokens=30,
    )
    generated = text[len("The capital of France is") :].strip()
    words = generated.split()
    assert len(words) >= 3, f"Expected >= 3 words, got {len(words)}: '{generated}'"
    assert not _is_repetitive(generated), (
        f"Output is excessively repetitive: '{generated}'"
    )
    print(f"  Output coherent ({len(words)} words): {generated[:80]}...")


@requires_model_path
def test_top_token_valid(compiled_model, tokenizer, generation_config):
    """First generated token should be a valid decodable token."""
    tokens, _ = _generate(
        compiled_model,
        tokenizer,
        generation_config,
        "Hello!",
        max_new_tokens=1,
    )
    input_len = len(tokenizer.encode("Hello!"))
    # Fail with a readable message instead of an IndexError when the model
    # returned nothing beyond the prompt.
    assert len(tokens) > input_len, (
        f"No token generated beyond the {input_len}-token prompt"
    )
    first_new = tokens[input_len]
    assert 0 <= first_new < len(tokenizer), (
        f"Token {first_new} out of vocab range"
    )
    decoded = tokenizer.decode([first_new])
    assert len(decoded) > 0, f"Token {first_new} decoded to empty string"
    print(f"  First token: {first_new} -> '{decoded}'")
@requires_model_path
def test_olympics_prompt_no_invalid_tokens(
    compiled_model, tokenizer, generation_config
):
    """Check that a 32-token generation yields only in-range token ids.

    Regression test for NaN logits producing the int32-min token id: every
    id found after the encoded prompt must satisfy
    ``0 <= id < len(tokenizer)``, and at least 5 such ids must be present.
    """
    prompt = "Give me a summary of the 2020 Olympics in 100 tokens."
    all_ids, _ = _generate(
        compiled_model, tokenizer, generation_config, prompt, max_new_tokens=32
    )
    # Everything past the encoded prompt length is treated as newly generated.
    prompt_len = len(tokenizer.encode(prompt))
    generated = all_ids[prompt_len:]
    id_limit = len(tokenizer)
    invalid = [tok_id for tok_id in generated if not 0 <= tok_id < id_limit]

    assert len(generated) >= 5, f"Expected >= 5 generated tokens, got {generated}"
    assert not invalid, f"Generated invalid token ids: {invalid}"


@requires_model_path
def test_capital_of_france(compiled_model, tokenizer, generation_config):
    """Expect 'Paris' (case-insensitive) in the continuation of the prompt."""
    prompt = "The capital of France is"
    _, text = _generate(
        compiled_model, tokenizer, generation_config, prompt, max_new_tokens=30
    )
    # Drop the first len(prompt) characters of the decoded text, then trim.
    generated = text[len(prompt) :].strip()
    assert "paris" in generated.lower(), (
        f"Expected 'Paris' in output, got: '{generated}'"
    )
    print(f" Capital of France: {generated}")


# ── Performance Tests ───────────────────────────────────────────────────


@requires_model_path
def test_performance_ttft(compiled_model, tokenizer, generation_config):
    """Assert the mean single-token latency stays below TTFT_THRESHOLD_MS.

    One untimed warmup call is followed by three timed calls, each asking
    for a single new token; the wall-clock durations are averaged in
    milliseconds and compared against the threshold.
    """
    prompt = "Hello, I am a language model"

    def first_token_ms():
        """Return the wall-clock milliseconds of one max_new_tokens=1 call."""
        started = time.perf_counter()
        _generate(
            compiled_model, tokenizer, generation_config, prompt, max_new_tokens=1
        )
        return (time.perf_counter() - started) * 1000

    # Warmup: the measured value is deliberately discarded.
    first_token_ms()

    # Measure
    samples_ms = [first_token_ms() for _ in range(3)]
    avg_ms = sum(samples_ms) / len(samples_ms)
    print(f" TTFT: {avg_ms:.1f} ms (threshold: {TTFT_THRESHOLD_MS} ms)")
    assert avg_ms < TTFT_THRESHOLD_MS, (
        f"TTFT {avg_ms:.1f}ms > threshold {TTFT_THRESHOLD_MS}ms"
    )
+@requires_model_path +def test_performance_throughput(compiled_model, tokenizer, generation_config): + """Throughput should meet minimum threshold.""" + prompt = "Once upon a time" + num_new_tokens = 20 + + # Warmup + _generate(compiled_model, tokenizer, generation_config, prompt, max_new_tokens=5) + + # Measure + t0 = time.perf_counter() + tokens, _ = _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=num_new_tokens, + ) + elapsed = time.perf_counter() - t0 + + input_len = len(tokenizer.encode(prompt)) + actual_new = len(tokens) - input_len + throughput = actual_new / elapsed if elapsed > 0 else 0 + + print( + f" Throughput: {throughput:.1f} tok/s ({actual_new} tokens in {elapsed:.2f}s)" + ) + print(f" Threshold: {THROUGHPUT_THRESHOLD} tok/s") + assert throughput > THROUGHPUT_THRESHOLD, ( + f"Throughput {throughput:.1f} tok/s < threshold {THROUGHPUT_THRESHOLD}" + ) + + +@requires_model_path +@requires_hbm_recording +def test_hybrid_cache_hbm_snapshot(compiled_model, tokenizer, generation_config, tmp_path): + """Record peak Neuron HBM for dummy-KV vs hybrid-cache comparison runs.""" + prompt = "Give me a summary of the 2020 Olympics in 100 tokens." 
+ max_new_tokens = int(os.environ.get("QWEN35_HBM_NEW_TOKENS", "32")) + + (_, text), peak_device, peak_tensors, samples = _capture_neuron_hbm( + tmp_path, + lambda: _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=max_new_tokens, + ), + ) + + mode = "hybrid" if USE_HYBRID_CACHE else "dummy_kv" + print( + " HBM " + f"mode={mode} peak_device_bytes={peak_device} " + f"peak_tensor_bytes={peak_tensors} samples={samples}" + ) + assert len(text) > len(prompt) + + +# ── Multi-Prompt Quality Test ────────────────────────────────────────── + + +@requires_model_path +def test_multi_prompt_generation(compiled_model, tokenizer, generation_config): + """Multiple prompts should produce coherent outputs.""" + prompts = [ + "The capital of France is", + "def fibonacci(n):", + "The largest ocean on Earth is", + "To make a chocolate cake, you need", + ] + + for prompt in prompts: + _, text = _generate( + compiled_model, + tokenizer, + generation_config, + prompt, + max_new_tokens=30, + ) + generated = text[len(prompt) :].strip() + words = generated.split() + assert len(words) >= 2, ( + f"Prompt '{prompt}' generated too few words: '{generated}'" + ) + assert not _is_repetitive(generated), ( + f"Prompt '{prompt}' produced repetitive output: '{generated}'" + ) + print(f" '{prompt[:30]}...' -> {generated[:60]}...") + + +# ── Standalone runner ─────────────────────────────────────────────────── + +if __name__ == "__main__": + print("=" * 60) + print("Qwen3.6-27B Integration Tests") + print("=" * 60) + + if not MODEL_PATH: + print("\nQWEN35_MODEL_PATH not set. 
Provide the model path to run tests:") + print(" QWEN35_MODEL_PATH=/path/to/Qwen3.6-27B \\") + print(" QWEN35_COMPILED_PATH=/mnt/models/qwen35_traced \\") + print(" python -m pytest test/integration/test_model.py --capture=tee-sys") + sys.exit(0) + + # Setup + from transformers import AutoTokenizer, GenerationConfig as GenConfig + + tok = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side="right") + if tok.pad_token is None: + tok.pad_token = tok.eos_token + gen_cfg = GenConfig( + do_sample=True, + top_k=1, + pad_token_id=tok.pad_token_id, + eos_token_id=tok.eos_token_id, + ) + + # Build model + import json + + from neuronx_distributed_inference.models.config import ( + NeuronConfig, + OnDeviceSamplingConfig, + ) + from src.modeling_qwen35 import Qwen35InferenceConfig, NeuronQwen35ForCausalLM + + nc = NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=1, + ctx_batch_size=1, + tkg_batch_size=1, + seq_len=SEQ_LEN, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + enable_bucketing=False, + flash_decoding_enabled=False, + logical_nc_config=2, + save_sharded_checkpoint=True, + ) + + with open(os.path.join(MODEL_PATH, "config.json")) as f: + full_config = json.load(f) + text_config = full_config.get("text_config", full_config) + config_dict = dict(text_config) + config_dict["pad_token_id"] = text_config.get("eos_token_id", 248044) + if "rope_parameters" in text_config: + config_dict["rope_theta"] = text_config["rope_parameters"].get( + "rope_theta", 10000000 + ) + config_dict.setdefault("tie_word_embeddings", False) + ic = Qwen35InferenceConfig(neuron_config=nc, **config_dict) + + cp = COMPILED_PATH + if not os.path.exists(os.path.join(cp, "model.pt")): + print(f"Compiling to {cp}...") + m = NeuronQwen35ForCausalLM(MODEL_PATH, ic) + m.compile(cp) + del m + gc.collect() + + print(f"Loading from {cp}...") + model = NeuronQwen35ForCausalLM(cp) + model.load(cp) + + tests = [ + ("model_loads", lambda: test_model_loads(model)), + 
("model_generates", lambda: test_model_generates(model, tok, gen_cfg)), + ("output_coherence", lambda: test_output_coherence(model, tok, gen_cfg)), + ("top_token_valid", lambda: test_top_token_valid(model, tok, gen_cfg)), + ("capital_of_france", lambda: test_capital_of_france(model, tok, gen_cfg)), + ("performance_ttft", lambda: test_performance_ttft(model, tok, gen_cfg)), + ( + "performance_throughput", + lambda: test_performance_throughput(model, tok, gen_cfg), + ), + ( + "multi_prompt_generation", + lambda: test_multi_prompt_generation(model, tok, gen_cfg), + ), + ] + + passed = 0 + for name, fn in tests: + print(f"\n--- {name} ---") + try: + fn() + print(f" PASS") + passed += 1 + except Exception as e: + print(f" FAIL: {e}") + + print(f"\n{'=' * 60}") + print(f"Results: {passed}/{len(tests)} passed") + print(f"{'=' * 60}") diff --git a/contrib/models/Qwen3.6-27B/test/unit/__init__.py b/contrib/models/Qwen3.6-27B/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Qwen3.6-27B/test/unit/test_config.py b/contrib/models/Qwen3.6-27B/test/unit/test_config.py new file mode 100644 index 00000000..571ad522 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_config.py @@ -0,0 +1,201 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for Qwen3.5/3.6-27B inference configuration. + +CPU-only tests that validate config parsing, layer type setup, +DeltaNet parameter defaults, RoPE configuration, and weight conversion logic. +These tests are architecture-level and apply to both Qwen3.5-27B and Qwen3.6-27B. 
+""" + +import os +import sys +import unittest +from unittest.mock import MagicMock + +import torch + +# Ensure the contrib root (Qwen3.6-27B/) is on sys.path +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from src.modeling_qwen35 import ( + Qwen35InferenceConfig, + convert_qwen35_hf_to_neuron_state_dict, +) +from neuronx_distributed_inference.models.config import NeuronConfig + + +def _make_config(**overrides): + """Create a Qwen35InferenceConfig with reasonable defaults.""" + neuron_config = NeuronConfig( + tp_degree=overrides.pop("tp_degree", 4), + batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + ) + defaults = dict( + hidden_size=5120, + num_hidden_layers=64, + num_attention_heads=24, + num_key_value_heads=4, + head_dim=256, + intermediate_size=17408, + vocab_size=248320, + rms_norm_eps=1e-6, + max_position_embeddings=131072, + rope_theta=10000, + hidden_act="silu", + # DeltaNet-specific + linear_num_value_heads=48, + linear_num_key_heads=16, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_conv_kernel_dim=4, + ) + defaults.update(overrides) + config = Qwen35InferenceConfig(neuron_config=neuron_config, **defaults) + return config + + +class TestConfigParsing(unittest.TestCase): + """Test basic config attribute initialization.""" + + def test_hidden_size(self): + config = _make_config() + self.assertEqual(config.hidden_size, 5120) + + def test_num_hidden_layers(self): + config = _make_config() + self.assertEqual(config.num_hidden_layers, 64) + + def test_num_attention_heads(self): + config = _make_config() + self.assertEqual(config.num_attention_heads, 24) + + def test_num_key_value_heads(self): + config = _make_config() + self.assertEqual(config.num_key_value_heads, 4) + + def test_head_dim(self): + config = _make_config() + self.assertEqual(config.head_dim, 256) + + def test_intermediate_size(self): + config = 
_make_config() + self.assertEqual(config.intermediate_size, 17408) + + def test_vocab_size(self): + config = _make_config() + self.assertEqual(config.vocab_size, 248320) + + def test_hidden_act(self): + config = _make_config() + self.assertEqual(config.hidden_act, "silu") + + +class TestLayerTypes(unittest.TestCase): + """Test hybrid layer type assignment (3 DeltaNet + 1 GQA) x 16.""" + + def test_layer_types_length(self): + config = _make_config() + self.assertEqual(len(config.layer_types), 64) + + def test_layer_types_pattern(self): + """Every 4th layer (3, 7, 11, ...) should be full_attention.""" + config = _make_config() + for i in range(64): + expected = "full_attention" if i % 4 == 3 else "linear_attention" + self.assertEqual(config.layer_types[i], expected, f"Layer {i} mismatch") + + def test_deltanet_layer_count(self): + config = _make_config() + dn_count = sum(1 for t in config.layer_types if t == "linear_attention") + self.assertEqual(dn_count, 48) + + def test_gqa_layer_count(self): + config = _make_config() + gqa_count = sum(1 for t in config.layer_types if t == "full_attention") + self.assertEqual(gqa_count, 16) + + +class TestDeltaNetConfig(unittest.TestCase): + """Test DeltaNet-specific configuration defaults.""" + + def test_linear_num_value_heads(self): + config = _make_config() + self.assertEqual(config.linear_num_value_heads, 48) + + def test_linear_num_key_heads(self): + config = _make_config() + self.assertEqual(config.linear_num_key_heads, 16) + + def test_linear_key_head_dim(self): + config = _make_config() + self.assertEqual(config.linear_key_head_dim, 128) + + def test_linear_value_head_dim(self): + config = _make_config() + self.assertEqual(config.linear_value_head_dim, 128) + + def test_linear_conv_kernel_dim(self): + config = _make_config() + self.assertEqual(config.linear_conv_kernel_dim, 4) + + +class TestRoPEConfig(unittest.TestCase): + """Test partial RoPE configuration.""" + + def test_partial_rotary_factor(self): + config = 
_make_config() + self.assertAlmostEqual(config.partial_rotary_factor, 0.25) + + def test_rope_dim(self): + """rope_dim = head_dim * partial_rotary_factor = 256 * 0.25 = 64.""" + config = _make_config() + self.assertEqual(config.rope_dim, 64) + + def test_attn_output_gate(self): + config = _make_config() + self.assertTrue(config.attn_output_gate) + + def test_mrope_section(self): + config = _make_config() + self.assertEqual(config.mrope_section, [11, 11, 10]) + + def test_mrope_interleaved(self): + config = _make_config() + self.assertTrue(config.mrope_interleaved) + + +class TestNeuronConfig(unittest.TestCase): + """Test Neuron-specific configuration settings.""" + + def test_neuron_config_cls(self): + """Qwen3.5/3.6-27B is dense -- uses NeuronConfig, NOT MoENeuronConfig.""" + self.assertEqual( + Qwen35InferenceConfig.get_neuron_config_cls(), + NeuronConfig, + ) + + def test_required_attributes(self): + config = _make_config() + required = config.get_required_attributes() + self.assertIn("hidden_size", required) + self.assertIn("num_hidden_layers", required) + self.assertIn("linear_num_value_heads", required) + self.assertIn("linear_key_head_dim", required) + self.assertIn("layer_types", required) + + def test_output_attentions_default(self): + config = _make_config() + self.assertFalse(config.output_attentions) + + def test_output_hidden_states_default(self): + config = _make_config() + self.assertFalse(config.output_hidden_states) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py b/contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py new file mode 100644 index 00000000..416a431a --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py @@ -0,0 +1,68 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for fused DeltaNet log-decay bounding.""" + +import os +import sys +import unittest + +import torch + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from src.modeling_qwen35 import ( + FUSED_DELTANET_DECAY_MAX, + FUSED_DELTANET_DECAY_MIN, + _bound_fused_deltanet_log_decay, +) + + +def _chunked_cumsum(g, batch_size, num_heads, total_seq_len, chunk_size): + num_chunks = total_seq_len // chunk_size + return g.reshape(batch_size, num_heads, num_chunks, chunk_size).cumsum(dim=-1) + + +class TestFusedDeltaNetDecayBounding(unittest.TestCase): + def test_preserves_non_extreme_decay(self): + batch_size, num_heads, total_seq_len, chunk_size = 2, 3, 16, 8 + g = torch.full( + (batch_size, num_heads, total_seq_len), + -0.125, + dtype=torch.float32, + ) + + bounded = _bound_fused_deltanet_log_decay( + g, batch_size, num_heads, total_seq_len, chunk_size + ) + + torch.testing.assert_close(bounded, g) + + def test_bounds_per_chunk_cumulative_decay(self): + batch_size, num_heads, total_seq_len, chunk_size = 2, 3, 16, 8 + g = torch.full( + (batch_size, num_heads, total_seq_len), + -10.0, + dtype=torch.float32, + ) + + bounded = _bound_fused_deltanet_log_decay( + g, batch_size, num_heads, total_seq_len, chunk_size + ) + bounded_cumsum = _chunked_cumsum( + bounded, batch_size, num_heads, total_seq_len, chunk_size + ) + expected_cumsum = _chunked_cumsum( + g, batch_size, num_heads, total_seq_len, chunk_size + ).clamp(min=FUSED_DELTANET_DECAY_MIN, max=FUSED_DELTANET_DECAY_MAX) + + torch.testing.assert_close(bounded_cumsum, expected_cumsum) + self.assertGreaterEqual(float(bounded_cumsum.min()), FUSED_DELTANET_DECAY_MIN) + self.assertLessEqual(float(bounded_cumsum.max()), FUSED_DELTANET_DECAY_MAX) + self.assertTrue(torch.isfinite(bounded).all()) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/contrib/models/Qwen3.6-27B/test/unit/test_hybrid_cache_manager.py b/contrib/models/Qwen3.6-27B/test/unit/test_hybrid_cache_manager.py new file mode 100644 index 00000000..fa887ca2 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_hybrid_cache_manager.py @@ -0,0 +1,314 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import os +import sys +import unittest +from math import prod +from unittest.mock import patch + +import torch + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from neuronx_distributed_inference.models.config import NeuronConfig +from src.modeling_qwen35 import HybridDeltaNetCacheManager, Qwen35InferenceConfig + + +def _make_config(**overrides): + neuron_overrides = overrides.pop("neuron_overrides", {}) + neuron_kwargs = dict( + tp_degree=overrides.pop("tp_degree", 4), + batch_size=1, + max_batch_size=2, + kv_cache_batch_size=2, + seq_len=16, + torch_dtype=torch.bfloat16, + ) + neuron_kwargs.update(neuron_overrides) + neuron_config = NeuronConfig(**neuron_kwargs) + defaults = dict( + hidden_size=5120, + num_hidden_layers=64, + num_attention_heads=24, + num_key_value_heads=4, + head_dim=256, + intermediate_size=17408, + vocab_size=248320, + rms_norm_eps=1e-6, + max_position_embeddings=131072, + rope_theta=10000, + hidden_act="silu", + tie_word_embeddings=False, + linear_num_value_heads=48, + linear_num_key_heads=16, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_conv_kernel_dim=4, + use_hybrid_cache_manager=True, + ) + defaults.update(overrides) + return Qwen35InferenceConfig(neuron_config=neuron_config, **defaults) + + +def _numel(shape): + return prod(int(dim) for dim in shape) + + +def _managed_cache_numel(mgr): + return sum(param.numel() for param in mgr.past_key_values) + + +def _deltanet_state_numel(config, max_batch_size): + recurrent = ( + 
max_batch_size + * config.linear_num_value_heads + * config.linear_key_head_dim + * config.linear_value_head_dim + ) + conv_dim = ( + 2 * config.linear_num_key_heads * config.linear_key_head_dim + + config.linear_num_value_heads * config.linear_value_head_dim + ) + conv = max_batch_size * conv_dim * (config.linear_conv_kernel_dim - 1) + return recurrent + conv + + +class TestHybridDeltaNetCacheManager(unittest.TestCase): + def test_allocates_per_layer_cache_shapes(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + self.assertEqual(len(mgr.past_key_values), config.num_hidden_layers * 2) + self.assertEqual(list(mgr.past_key_values[0].shape), [2, 48, 128, 128]) + self.assertEqual(list(mgr.past_key_values[1].shape), [2, 10240, 3]) + self.assertEqual(mgr.layer_types[3], "full_attention") + self.assertEqual(mgr.past_key_values[6].dim(), 4) + self.assertEqual(mgr.past_key_values[7].shape[2], 16) + + def test_get_cache_slices_only_full_attention_layers(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + cache = mgr.get_cache(seq_len=4, seq_ids=torch.tensor([1])) + recurrent_state, conv_state = cache[0] + full_k, full_v = cache[3] + + self.assertEqual(list(recurrent_state.shape), [1, 48, 128, 128]) + self.assertEqual(list(conv_state.shape), [1, 10240, 3]) + self.assertEqual(full_k.shape[0], 2) + self.assertEqual(full_v.shape[0], 2) + self.assertEqual(full_k.shape[2], 4) + self.assertEqual(full_v.shape[2], 4) + + def test_get_seq_length_uses_first_full_attention_layer(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + nested_cache = mgr.get_cache(seq_len=5, seq_ids=torch.tensor([0])) + flat_cache = [tensor for layer_cache in nested_cache for tensor in layer_cache] + + self.assertEqual(nested_cache[0][1].shape[2], 3) + self.assertEqual(mgr.get_seq_length(nested_cache), 5) + 
self.assertEqual(mgr.get_seq_length(flat_cache), 5) + + def test_get_cache_selects_deltanet_state_rows_by_seq_ids(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + with torch.no_grad(): + mgr.past_key_values[0][0, ...].fill_(7) + mgr.past_key_values[0][1, ...].fill_(13) + mgr.past_key_values[1][0, ...].fill_(17) + mgr.past_key_values[1][1, ...].fill_(19) + + recurrent_state, conv_state = mgr.get_cache( + seq_len=4, + seq_ids=torch.tensor([1, 0]), + )[0] + + self.assertTrue(torch.all(recurrent_state[0] == 13)) + self.assertTrue(torch.all(recurrent_state[1] == 7)) + self.assertTrue(torch.all(conv_state[0] == 19)) + self.assertTrue(torch.all(conv_state[1] == 17)) + + def test_deltanet_update_scatters_by_seq_id(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + recurrent = torch.ones((1, 48, 128, 128), dtype=torch.bfloat16) + conv = torch.ones((1, 10240, 3), dtype=torch.bfloat16) + + updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( + idx=0, + seq_ids=torch.tensor([1]), + state_per_layer=(recurrent, conv), + ) + + self.assertTrue(torch.all(updated_recurrent[0] == 0)) + self.assertTrue(torch.all(updated_conv[0] == 0)) + self.assertTrue(torch.all(updated_recurrent[1] == 1)) + self.assertTrue(torch.all(updated_conv[1] == 1)) + + def test_deltanet_full_batch_update_replaces_state_cache(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + recurrent = torch.ones((2, 48, 128, 128), dtype=torch.bfloat16) + conv = torch.ones((2, 10240, 3), dtype=torch.bfloat16) + recurrent[0].fill_(3) + recurrent[1].fill_(5) + conv[0].fill_(11) + conv[1].fill_(13) + + updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( + idx=0, + seq_ids=torch.tensor([0, 1]), + state_per_layer=(recurrent, conv), + ) + + self.assertTrue(torch.all(updated_recurrent[0] == 
3)) + self.assertTrue(torch.all(updated_recurrent[1] == 5)) + self.assertTrue(torch.all(updated_conv[0] == 11)) + self.assertTrue(torch.all(updated_conv[1] == 13)) + + def test_deltanet_update_maps_out_of_range_seq_id_to_padding_row(self): + config = _make_config(neuron_overrides={"kv_cache_padding_size": 1}) + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + recurrent = torch.ones((1, 48, 128, 128), dtype=torch.bfloat16) + conv = torch.ones((1, 10240, 3), dtype=torch.bfloat16) + + updated_recurrent, updated_conv = mgr.update_deltanet_state_by_layer_id( + idx=0, + seq_ids=torch.tensor([99]), + state_per_layer=(recurrent, conv), + ) + + self.assertTrue(torch.all(updated_recurrent[0] == 0)) + self.assertTrue(torch.all(updated_recurrent[1] == 0)) + self.assertTrue(torch.all(updated_recurrent[2] == 1)) + self.assertTrue(torch.all(updated_conv[2] == 1)) + + def test_deltanet_state_shapes_do_not_scale_with_sequence_length(self): + short_config = _make_config(neuron_overrides={"seq_len": 128}) + long_config = _make_config(neuron_overrides={"seq_len": 2048}) + short_mgr = HybridDeltaNetCacheManager( + short_config, num_kv_head=short_config.num_key_value_heads + ) + long_mgr = HybridDeltaNetCacheManager( + long_config, num_kv_head=long_config.num_key_value_heads + ) + + self.assertEqual(short_mgr.past_key_values[0].shape, long_mgr.past_key_values[0].shape) + self.assertEqual(short_mgr.past_key_values[1].shape, long_mgr.past_key_values[1].shape) + self.assertLess(short_mgr.past_key_values[7].shape[2], long_mgr.past_key_values[7].shape[2]) + + def test_get_cache_trims_padding_row_without_seq_ids(self): + config = _make_config(neuron_overrides={"kv_cache_padding_size": 1}) + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + recurrent_state, conv_state = mgr.get_cache(seq_len=4)[0] + + self.assertEqual(list(recurrent_state.shape), [2, 48, 128, 128]) + self.assertEqual(list(conv_state.shape), [2, 10240, 3]) + + 
def test_update_cache_dispatches_deltanet_and_full_attention_layers(self): + config = _make_config() + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + new_key_values = [] + for idx in range(4): + first = mgr.past_key_values[2 * idx] + second = mgr.past_key_values[2 * idx + 1] + new_key_values.append( + ( + torch.full_like(first, fill_value=idx + 1), + torch.full_like(second, fill_value=idx + 11), + ) + ) + + position_ids = torch.arange(16, dtype=torch.long).unsqueeze(0).expand(2, -1) + full_k_update = torch.full_like(mgr.past_key_values[6], fill_value=4) + full_v_update = torch.full_like(mgr.past_key_values[7], fill_value=14) + with patch.object( + mgr, "update_kv_by_layer_id", return_value=(full_k_update, full_v_update) + ) as update_kv: + updated = mgr.update_cache( + is_for_context_encoding=True, + seq_ids=torch.tensor([0, 1], dtype=torch.int32), + position_ids=position_ids, + new_key_values=new_key_values, + seq_len=16, + ) + + self.assertEqual(update_kv.call_count, 1) + self.assertEqual(update_kv.call_args.kwargs["idx"], 3) + self.assertTrue(torch.all(updated[0] == 1)) + self.assertTrue(torch.all(updated[1] == 11)) + self.assertTrue(torch.all(updated[6] == 4)) + self.assertTrue(torch.all(updated[7] == 14)) + + def test_managed_cache_removes_dummy_kv_for_deltanet_layers(self): + config = _make_config(neuron_overrides={"seq_len": 1024}) + mgr = HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + max_batch_size = ( + config.neuron_config.kv_cache_batch_size + + config.neuron_config.kv_cache_padding_size + ) + full_kv_per_layer = _numel(mgr.k_shape) + _numel(mgr.v_shape) + deltanet_layers = config.layer_types.count("linear_attention") + legacy_total_numel = ( + full_kv_per_layer * config.num_hidden_layers + + _deltanet_state_numel(config, max_batch_size) * deltanet_layers + ) + expected_savings = full_kv_per_layer * deltanet_layers + + self.assertEqual( + legacy_total_numel - _managed_cache_numel(mgr), + 
expected_savings, + ) + self.assertLess(_managed_cache_numel(mgr), legacy_total_numel) + + def test_rejects_unsupported_hybrid_modes(self): + unsupported_cases = [ + ({"padding_side": "left"}, "left padding"), + ({"flash_decoding_enabled": True}, "flash decoding"), + ] + + for neuron_overrides, expected_error in unsupported_cases: + with self.subTest(expected_error=expected_error): + config = _make_config(neuron_overrides=neuron_overrides) + with self.assertRaisesRegex(ValueError, expected_error): + HybridDeltaNetCacheManager( + config, num_kv_head=config.num_key_value_heads + ) + + config = _make_config() + config.neuron_config.kv_cache_quant = True + with self.assertRaisesRegex(ValueError, "KV cache quantization"): + HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + config = _make_config( + neuron_overrides={ + "attention_dp_degree": 2, + "batch_size": 2, + "ctx_batch_size": 2, + "tkg_batch_size": 2, + "max_batch_size": 2, + "kv_cache_batch_size": 2, + "is_continuous_batching": True, + } + ) + with self.assertRaisesRegex(ValueError, "attention data parallelism"): + HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + config = _make_config() + config.neuron_config.kv_cache_tiling = True + with self.assertRaisesRegex(ValueError, "KV cache tiling"): + HybridDeltaNetCacheManager(config, num_kv_head=config.num_key_value_heads) + + def test_legacy_config_default_is_disabled(self): + config = _make_config(use_hybrid_cache_manager=False) + self.assertFalse(config.use_hybrid_cache_manager) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.6-27B/test/unit/test_weight_conversion.py b/contrib/models/Qwen3.6-27B/test/unit/test_weight_conversion.py new file mode 100644 index 00000000..252da3f4 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/test/unit/test_weight_conversion.py @@ -0,0 +1,436 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for Qwen3.5/3.6-27B HF-to-NxDI weight conversion. + +CPU-only tests that validate: +- RMSNorm (+1 convention) weight conversion +- GQA q_proj interleaved split (query + gate) +- QK norm key renaming (q_norm -> q_layernorm, k_norm -> k_layernorm) +- Fused QKV concatenation +- DeltaNet layer weights pass through unchanged +- VL wrapper prefix stripping +- rank_util injection + +These tests are architecture-level and apply to both Qwen3.5-27B and Qwen3.6-27B. +""" + +import os +import sys +import unittest + +import torch + +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +from src.modeling_qwen35 import ( + Qwen35InferenceConfig, + NeuronQwen35ForCausalLM, + convert_qwen35_hf_to_neuron_state_dict, +) +from neuronx_distributed_inference.models.config import NeuronConfig + + +def _make_mini_config(num_layers=4, tp_degree=2, fused_qkv=True): + """Create a small Qwen35InferenceConfig for testing.""" + neuron_config = NeuronConfig( + tp_degree=tp_degree, + batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + fused_qkv=fused_qkv, + ) + config = Qwen35InferenceConfig( + neuron_config=neuron_config, + hidden_size=256, + num_hidden_layers=num_layers, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=64, + intermediate_size=512, + vocab_size=1000, + rms_norm_eps=1e-6, + max_position_embeddings=4096, + rope_theta=10000, + hidden_act="silu", + linear_num_value_heads=8, + linear_num_key_heads=4, + linear_key_head_dim=32, + linear_value_head_dim=32, + linear_conv_kernel_dim=4, + ) + return config + + +def _make_mini_state_dict(config): + """Create a minimal HF-style state dict for conversion testing.""" + sd = {} + H = config.hidden_size # 256 + I = config.intermediate_size # 512 + V = config.vocab_size # 1000 + num_heads = config.num_attention_heads # 4 + num_kv = config.num_key_value_heads # 2 + 
head_dim = config.head_dim # 64 + + sd["embed_tokens.weight"] = torch.randn(V, H, dtype=torch.bfloat16) * 0.02 + sd["lm_head.weight"] = torch.randn(V, H, dtype=torch.bfloat16) * 0.02 + sd["norm.weight"] = torch.zeros(H, dtype=torch.bfloat16) # +1 convention: zeros + + for l in range(config.num_hidden_layers): + sd[f"layers.{l}.input_layernorm.weight"] = torch.zeros(H, dtype=torch.bfloat16) + sd[f"layers.{l}.post_attention_layernorm.weight"] = torch.zeros( + H, dtype=torch.bfloat16 + ) + + # Dense MLP (all layers) + sd[f"layers.{l}.mlp.gate_proj.weight"] = ( + torch.randn(I, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.mlp.up_proj.weight"] = ( + torch.randn(I, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.mlp.down_proj.weight"] = ( + torch.randn(H, I, dtype=torch.bfloat16) * 0.02 + ) + + if config.layer_types[l] == "full_attention": + # GQA layer: q_proj is interleaved [head0_q | head0_gate | head1_q | ...] + q_proj = ( + torch.randn(num_heads * head_dim * 2, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.q_proj.weight"] = q_proj + sd[f"layers.{l}.self_attn.k_proj.weight"] = ( + torch.randn(num_kv * head_dim, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.v_proj.weight"] = ( + torch.randn(num_kv * head_dim, H, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.o_proj.weight"] = ( + torch.randn(H, num_heads * head_dim, dtype=torch.bfloat16) * 0.02 + ) + sd[f"layers.{l}.self_attn.q_norm.weight"] = torch.zeros( + head_dim, dtype=torch.bfloat16 + ) + sd[f"layers.{l}.self_attn.k_norm.weight"] = torch.zeros( + head_dim, dtype=torch.bfloat16 + ) + else: + # DeltaNet layer: minimal required weights + key_dim = config.linear_num_key_heads * config.linear_key_head_dim # 128 + value_dim = ( + config.linear_num_value_heads * config.linear_value_head_dim + ) # 256 + conv_dim = key_dim * 2 + value_dim # 512 + sd[f"layers.{l}.linear_attn.in_proj_qkv.weight"] = ( + torch.randn(conv_dim, H, dtype=torch.bfloat16) * 0.02 + ) 
class TestNormConversion(unittest.TestCase):
    """Check the (+1 convention) RMSNorm weight conversion.

    The mini state dict stores every converted norm weight as zeros, so each
    of them is expected to come out of the conversion as ones.
    """

    @staticmethod
    def _converted():
        """Return ``(config, converted_state_dict)`` for the default mini model."""
        cfg = _make_mini_config()
        state = _make_mini_state_dict(cfg)
        return cfg, convert_qwen35_hf_to_neuron_state_dict(state, cfg)

    def _assert_ones(self, weight, message):
        """Fail with *message* unless *weight* is element-wise close to 1."""
        self.assertTrue(torch.allclose(weight, torch.ones_like(weight)), message)

    def test_norm_weight_adds_one(self):
        """Weights initialized to zero should become 1.0 after conversion."""
        _, converted = self._converted()
        final_norm = converted["norm.weight"]
        # The final norm starts as zeros, so +1 must yield exactly ones.
        torch.testing.assert_close(final_norm, torch.ones_like(final_norm))

    def test_input_layernorm_adds_one(self):
        """Every layer's input_layernorm weight is shifted by +1."""
        cfg, converted = self._converted()
        for idx in range(cfg.num_hidden_layers):
            self._assert_ones(
                converted[f"layers.{idx}.input_layernorm.weight"],
                f"Layer {idx} input_layernorm not converted",
            )

    def test_post_attn_layernorm_adds_one(self):
        """Every layer's post_attention_layernorm weight is shifted by +1."""
        cfg, converted = self._converted()
        for idx in range(cfg.num_hidden_layers):
            self._assert_ones(
                converted[f"layers.{idx}.post_attention_layernorm.weight"],
                f"Layer {idx} post_attention_layernorm not converted",
            )

    def test_qk_norm_adds_one(self):
        """Q/K norms on GQA layers should also get +1 applied."""
        cfg, converted = self._converted()
        for idx in range(cfg.num_hidden_layers):
            if cfg.layer_types[idx] != "full_attention":
                continue
            # Look both renamed keys up before asserting on either of them.
            q_weight = converted[f"layers.{idx}.self_attn.q_layernorm.weight"]
            k_weight = converted[f"layers.{idx}.self_attn.k_layernorm.weight"]
            self._assert_ones(q_weight, f"Layer {idx} q_layernorm not converted")
            self._assert_ones(k_weight, f"Layer {idx} k_layernorm not converted")
class TestQKNormRename(unittest.TestCase):
    """Check the q_norm -> q_layernorm and k_norm -> k_layernorm key renaming."""

    def test_old_keys_removed(self):
        """Full-attention layers must not keep the HF ``*_norm`` weight keys."""
        cfg = _make_mini_config()
        state = _make_mini_state_dict(cfg)
        converted = convert_qwen35_hf_to_neuron_state_dict(state, cfg)
        for idx in range(cfg.num_hidden_layers):
            if cfg.layer_types[idx] != "full_attention":
                continue
            self.assertNotIn(f"layers.{idx}.self_attn.q_norm.weight", converted)
            self.assertNotIn(f"layers.{idx}.self_attn.k_norm.weight", converted)

    def test_new_keys_present(self):
        """Full-attention layers must expose the renamed ``*_layernorm`` keys."""
        cfg = _make_mini_config()
        state = _make_mini_state_dict(cfg)
        converted = convert_qwen35_hf_to_neuron_state_dict(state, cfg)
        for idx in range(cfg.num_hidden_layers):
            if cfg.layer_types[idx] != "full_attention":
                continue
            self.assertIn(f"layers.{idx}.self_attn.q_layernorm.weight", converted)
            self.assertIn(f"layers.{idx}.self_attn.k_layernorm.weight", converted)
class TestRankUtil(unittest.TestCase):
    """Check that the conversion injects ``rank_util.rank`` tensors."""

    def test_rank_util_present(self):
        """The top-level rank tensor is int32 ``[0, 1, 2, 3]`` for tp_degree=4."""
        cfg = _make_mini_config(tp_degree=4)
        state = _make_mini_state_dict(cfg)
        converted = convert_qwen35_hf_to_neuron_state_dict(state, cfg)
        self.assertIn("rank_util.rank", converted)
        want = torch.arange(0, 4, dtype=torch.int32)
        torch.testing.assert_close(converted["rank_util.rank"], want)

    def test_gqa_layer_rank_util(self):
        """Each full-attention layer carries its own copy of the rank tensor."""
        cfg = _make_mini_config(tp_degree=4)
        state = _make_mini_state_dict(cfg)
        converted = convert_qwen35_hf_to_neuron_state_dict(state, cfg)
        for idx in range(cfg.num_hidden_layers):
            if cfg.layer_types[idx] != "full_attention":
                continue
            rank_key = f"layers.{idx}.self_attn.rank_util.rank"
            self.assertIn(rank_key, converted)
            want = torch.arange(0, 4, dtype=torch.int32)
            torch.testing.assert_close(converted[rank_key], want)
test_language_model_prefix_stripped(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + + # Wrap with VL prefix + vl_sd = {} + for k, v in sd.items(): + vl_sd[f"language_model.{k}"] = v + vl_sd["visual.encoder.weight"] = torch.zeros(10) # should be skipped + vl_sd["mtp.something"] = torch.zeros(5) # should be skipped + + result = NeuronQwen35ForCausalLM.convert_hf_to_neuron_state_dict(vl_sd, config) + self.assertNotIn("visual.encoder.weight", result) + self.assertNotIn("mtp.something", result) + self.assertIn("norm.weight", result) + + def test_model_language_model_prefix_stripped(self): + config = _make_mini_config() + sd = _make_mini_state_dict(config) + + vl_sd = {} + for k, v in sd.items(): + vl_sd[f"model.language_model.{k}"] = v + + result = NeuronQwen35ForCausalLM.convert_hf_to_neuron_state_dict(vl_sd, config) + self.assertIn("norm.weight", result) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/Qwen3.6-27B/vllm/README.md b/contrib/models/Qwen3.6-27B/vllm/README.md new file mode 100644 index 00000000..54d6f904 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/README.md @@ -0,0 +1,262 @@ +# Qwen3.6-27B vLLM on Neuron + +This folder contains the first-pass vLLM integration helpers for the +Qwen3.6-27B contrib model. + +The current goal is **vLLM serving through the Neuron/NxDI plugin** for the +validated Qwen3.6 artifact, including long prompts through vLLM's native +chunked-prefill scheduler. + +## Which vLLM Neuron Package? + +Use the vLLM-on-Neuron environment that matches the installed Neuron SDK first. +For SDK 2.29, the AWS Neuron guide lists the NxDI/vLLM plugin stack as +`vLLM 0.16.0` with plugin version `0.5.0`. The +`vllm-project/vllm-neuron` repository is useful source/reference material, but +its README currently describes a beta plugin path tied to older `vLLM 0.11.0` +and SDK 2.26.1. Do not downgrade the working SDK 2.29 environment just to use +that repository. 
+ +On a DLAMI, prefer the preinstalled vLLM/Neuron environment when available. If +the instance does not have one, install the Neuron-compatible vLLM plugin/fork +using the current AWS guide, then run the contrib registry patch below. + +## What Works First + +- Register the contrib `qwen3_5` text model with the NxDI model registry inside + the vLLM environment. +- Start vLLM with `VLLM_PLUGINS=neuron`. +- Load a small-context model or a precompiled artifact with + `NEURON_COMPILED_ARTIFACTS`. +- Run a short OpenAI-compatible smoke prompt. + +## Chunked Prefill Note + +The Neuron plugin disables vLLM chunked prefill by default and installs a custom +continuous-batching scheduler. For this Qwen3.6 artifact we need vLLM's native +chunked-prefill scheduler so prompts longer than the 512-token context graph are +fed to the precompiled model in 512-token chunks. The launcher sets +`DISABLE_NEURON_CUSTOM_SCHEDULER=1` when `--enable-vllm-chunked-prefill` is +passed. It also launches with `--generation-config vllm` so model +`generation_config.json` does not silently override deterministic sampling +defaults. + +## Install The Contrib Registry Patch + +Activate the vLLM/Neuron environment on the instance, then run: + +```bash +cd /home/ubuntu/inferentia-gdn +contrib/models/Qwen3.6-27B/vllm/install_qwen36_vllm.sh +``` + +If your vLLM environment is not in a standard location: + +```bash +contrib/models/Qwen3.6-27B/vllm/install_qwen36_vllm.sh \ + /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference +``` + +The installer only patches the active environment. It does not modify core repo +files. 
+ +## Start vLLM + +Small-context compile/load path: + +```bash +contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --max-model-len 512 \ + --port 8000 +``` + +Precompiled artifact path: + +```bash +contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --port 8000 +``` + +Long-prompt precompiled artifact path: + +```bash +contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --block-size 256 \ + --enable-vllm-chunked-prefill \ + --port 8000 +``` + +Native vLLM prefix-cache experiment: + +```bash +contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --block-size 256 \ + --enable-vllm-chunked-prefill \ + --enable-prefix-caching \ + --mamba-cache-mode align \ + --port 8000 +``` + +Treat this as an experiment, not a production mode, until validation passes. +Standard vLLM APC reuses attention KV blocks; Qwen3.6 also needs DeltaNet +recurrent state and conv state at block boundaries. If native APC does not +produce exact greedy matches and a clear warm-hit speedup, the next step is a +hybrid APC path that caches those GDN states alongside attention KV. 
+ +Production chat proxy: + +```bash +contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --block-size 256 \ + --enable-vllm-chunked-prefill \ + --port 8001 +``` + +Then expose the guarded OpenAI-compatible endpoint on port 8000: + +```bash +python contrib/models/Qwen3.6-27B/vllm/qwen36_chat_proxy.py \ + --backend-url http://127.0.0.1:8001 \ + --port 8000 +``` + +The proxy forces `chat_template_kwargs={"enable_thinking": false}` for +`/v1/chat/completions` by default. It rejects raw `/v1/completions` because raw +prompts bypass the Qwen chat template and can pollute the hybrid model state. +It also hoists `system` and `developer` messages to a single leading `system` +message because the Qwen chat template rejects system messages that appear later +in the conversation. Use `--allow-thinking` or `--allow-completions` only for +explicit debugging. + +Offline long-prompt smoke: + +```bash +python contrib/models/Qwen3.6-27B/vllm/run_offline_inference.py \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --block-size 256 \ + --enable-vllm-chunked-prefill \ + --chat \ + --prompt "$(python - <<'PY' +print('Summarize this document in one paragraph. 
' + 'Neuron inference ' * 700) +PY +)" +``` + +Offline token-exact prefix-cache validation: + +```bash +python validation_scripts/qwen36_vllm_prefix_cache_offline.py \ + --repo-root /home/ubuntu/inferentia-gdn \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --block-size 256 \ + --enable-vllm-chunked-prefill \ + --mamba-cache-mode align +``` + +Offline partial-prefix validation: + +```bash +python validation_scripts/qwen36_vllm_prefix_cache_partial_offline.py \ + --repo-root /home/ubuntu/inferentia-gdn \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --block-size 256 \ + --enable-vllm-chunked-prefill \ + --mamba-cache-mode align +``` + +Server-side prefix-cache validation through the guarded proxy: + +```bash +python validation_scripts/qwen36_prefix_cache_validation.py \ + --base-url http://127.0.0.1:8000 \ + --model qwen3.6-27b-neuron-128k-fp8-mlp +``` + +The acceptance gate is strict: repeated greedy calls must produce identical +output, and warm-hit latency should be materially lower than cold-fill latency. +For hybrid Qwen3.6, prefix-cache validation is not complete until the GDN +recurrent/conv state behavior is proven, not just attention KV cache hits. 
+ +Native APC validation run on Trn2 with the FP8 128K artifact: + +- server exact-repeat, `~10.8K` prompt tokens: `26.68s` cold to `1.67s` warm, + `16.0x` speedup, exact greedy text match; +- offline exact-repeat, token IDs exposed: `26.19s` cold to `2.38s` warm, + `11.0x` speedup, exact greedy token-ID match; +- offline partial-prefix reuse, token IDs exposed: `25.52s` no-cache target to + `1.70s` APC target after a different shared-prefix warmup request, `15.0x` + speedup, exact greedy token-ID match. +- server hardening, exact repeat: `25.38s` cold to `1.55s` warm, `16.35x` + speedup, exact text match; +- server hardening, cross-prefix reuse after unrelated prefix: `25.17s` cold to + `1.36s` warm, exact text match; +- shared-prefix concurrency at 1/2/4 requests returned all requested markers + exactly; the artifact still queues because it is compiled for `max_num_seqs=1`. + +Validation run on Trn2 with the FP8 128K artifact: + +- state-reset artifact: `/opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1`; +- OpenAI-compatible `/v1/chat/completions` behind the proxy passes focused + quality checks without callers passing `chat_template_kwargs`; +- repeated short-after-long validation passes after 32K and 64K requests, + confirming DeltaNet recurrent/conv state is reset for new requests; +- 32K and 64K needle retrieval prompts return all expected codes; +- measured prefill is `404-428 tok/s` from 512 through 64K prompt tokens; +- measured decode is `26.3-26.6 tok/s`; +- peak Neuron device memory is about `53.25 GB` decimal for the 64K eval. + +Raw `/v1/completions` prompts are not chat-templated and can pollute the hybrid +state if sent directly to the backend. Keep the backend private and expose the +proxy on the public port for production calls. 
+ +## Offline Smoke + +```bash +python contrib/models/Qwen3.6-27B/vllm/run_offline_inference.py \ + --model-path /opt/dlami/nvme/models/Qwen3.6-27B \ + --compiled-artifacts /opt/dlami/nvme/qwen_artifacts/qwen36_27b_128k_fp8_mlp_only_vllm_statereset_run1 \ + --max-model-len 131072 \ + --seq-len 131072 \ + --cte-bucket 512 \ + --chat \ + --prompt "What is 17 * 23? Answer with the number only." +``` + +## Next Milestone + +Validate native vLLM prefix caching with the token-exact offline harness. If it +does not pass, implement hybrid APC by saving/restoring DeltaNet recurrent and +conv state at block boundaries. diff --git a/contrib/models/Qwen3.6-27B/vllm/hf_qwen35_config.py b/contrib/models/Qwen3.6-27B/vllm/hf_qwen35_config.py new file mode 100644 index 00000000..f764048a --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/hf_qwen35_config.py @@ -0,0 +1,68 @@ +"""Minimal Hugging Face config registration for Qwen3.5/Qwen3.6 vLLM smoke. + +The Neuron vLLM environment can lag upstream Transformers. vLLM validates the +HF config before the NxDI model registry gets a chance to instantiate the +contrib model, so register a permissive config class for the new model_type. 
class Qwen35Config(PretrainedConfig):
    """Permissive top-level config for the ``qwen3_5`` model type.

    Accepts an optional nested ``text_config`` (dict or ``Qwen35TextConfig``)
    and mirrors its fields onto the top-level config so they can be read
    without descending into ``text_config``.
    """

    model_type = "qwen3_5"
    sub_configs = {"text_config": Qwen35TextConfig}

    def __init__(self, text_config=None, **kwargs):
        """Build the config, hoisting text-config fields into ``kwargs``.

        Args:
            text_config: ``None``, a dict of text-config fields, or an
                already constructed config object.
            **kwargs: Top-level fields; explicitly passed values always win
                over values hoisted from ``text_config``.
        """
        resolved = (
            Qwen35TextConfig(**text_config)
            if isinstance(text_config, dict)
            else text_config
        )
        self.text_config = resolved

        if resolved is not None:
            # Identity fields belong to the nested config and are not copied.
            not_hoisted = ("architectures", "model_type")
            for field_name, field_value in resolved.to_dict().items():
                if field_name in not_hoisted:
                    continue
                kwargs.setdefault(field_name, field_value)

            # Mirror rope_theta from rope_parameters when nothing set it yet.
            rope = getattr(resolved, "rope_parameters", None)
            if isinstance(rope, dict):
                kwargs.setdefault("rope_theta", rope.get("rope_theta"))

        super().__init__(**kwargs)
100755 index 00000000..f21536eb --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/install_qwen36_vllm.sh @@ -0,0 +1,61 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CONTRIB_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" + +if [[ $# -gt 0 ]]; then + VENV="$1" +else + VENV="" + for candidate in \ + /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference \ + /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16 \ + /opt/aws_neuronx_venv_pytorch_inference_vllm_0_13 \ + /opt/aws_neuronx_venv_pytorch_inference_vllm_0_12 \ + /opt/aws_neuronx_venv_pytorch_inference_vllm_0_11 + do + if [[ -x "${candidate}/bin/python" ]]; then + VENV="${candidate}" + break + fi + done +fi + +if [[ -z "${VENV}" || ! -x "${VENV}/bin/python" ]]; then + echo "ERROR: Could not find a vLLM/Neuron Python environment." >&2 + echo "Usage: $0 /path/to/venv" >&2 + exit 1 +fi + +PYTHON="${VENV}/bin/python" +export PATH="${VENV}/bin:${PATH}" +export PYTHONPATH="${CONTRIB_ROOT}:${PYTHONPATH:-}" + +echo "vLLM/Neuron env : ${VENV}" +echo "Contrib root : ${CONTRIB_ROOT}" + +"${PYTHON}" "${SCRIPT_DIR}/patch_nxdi_registry.py" --contrib-root "${CONTRIB_ROOT}" + +"${PYTHON}" - <<'PY' +import importlib.util +from neuronx_distributed_inference.utils.constants import MODEL_TYPES + +if importlib.util.find_spec("vllm") is None: + raise RuntimeError("vLLM is not installed in this environment") + +if importlib.util.find_spec("vllm_neuron") is None: + print( + "WARNING: vllm_neuron package was not found. If this environment uses " + "an AWS vLLM fork with built-in Neuron support this may be fine; " + "otherwise install the Neuron vLLM plugin that matches this SDK.", + ) + +for key in ("qwen3_5", "qwen3_5_text"): + assert key in MODEL_TYPES, f"{key} missing from MODEL_TYPES" + assert "causal-lm" in MODEL_TYPES[key], f"{key}/causal-lm missing" +print("Qwen3.6 vLLM registry verification OK") +PY + +echo "Installation complete." 
+echo "Remember to set PYTHONPATH=${CONTRIB_ROOT} when starting vLLM." diff --git a/contrib/models/Qwen3.6-27B/vllm/patch_nxdi_registry.py b/contrib/models/Qwen3.6-27B/vllm/patch_nxdi_registry.py new file mode 100644 index 00000000..91fe41c5 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/patch_nxdi_registry.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +"""Register Qwen3.6 contrib model in the installed NxDI registry. + +This patches the active Python environment, not the repository checkout. The +runtime still needs PYTHONPATH to include contrib/models/Qwen3.6-27B so that +`src.modeling_qwen35` can be imported by the vLLM process. +""" + +from __future__ import annotations + +import argparse +from pathlib import Path + + +MARKER_BEGIN = "# QWEN36_CONTRIB_VLLM_REGISTER_BEGIN" +MARKER_END = "# QWEN36_CONTRIB_VLLM_REGISTER_END" + +REGISTRATION_BLOCK = f""" + +{MARKER_BEGIN} +# Registered by contrib/models/Qwen3.6-27B/vllm/install_qwen36_vllm.sh. +# Requires PYTHONPATH to include the Qwen3.6-27B contrib directory at runtime. 
def main() -> int:
    """Append the Qwen3.6 registration block to the installed NxDI constants module.

    Command-line flags:
        --contrib-root  Path to the Qwen3.6 contrib directory; it must contain
                        ``src/modeling_qwen35.py``.
        --dry-run       Report the patch target without writing anything.

    Returns:
        0 on success, including the "already patched" and dry-run cases.

    Raises:
        FileNotFoundError: If ``--contrib-root`` does not contain
            ``src/modeling_qwen35.py``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--contrib-root", required=True)
    parser.add_argument("--dry-run", action="store_true")
    args = parser.parse_args()

    contrib_root = Path(args.contrib_root).expanduser().resolve()
    if not (contrib_root / "src" / "modeling_qwen35.py").exists():
        raise FileNotFoundError(f"Qwen3.6 contrib root looks invalid: {contrib_root}")

    path = _constants_path()
    # Python source files are UTF-8 by default; read and write with an explicit
    # encoding so the patch does not depend on the process locale.
    text = path.read_text(encoding="utf-8")
    if MARKER_BEGIN in text:
        # The marker makes the patch idempotent across repeated installs.
        print(f"Registry already patched: {path}")
        return 0

    patched = text.rstrip() + REGISTRATION_BLOCK + "\n"
    print(f"Patch target: {path}")
    if args.dry_run:
        print("Dry run; no files written")
        return 0

    path.write_text(patched, encoding="utf-8")
    print("Patched NxDI MODEL_TYPES with qwen3_5 and qwen3_5_text")
    return 0
def _json_response(handler: BaseHTTPRequestHandler, status: int, payload: dict[str, Any]):
    """Write *payload* to the client as a JSON response with HTTP *status*."""
    body = json.dumps(payload).encode("utf-8")
    handler.send_response(status)
    handler.send_header("Content-Type", "application/json")
    handler.send_header("Content-Length", str(len(body)))
    handler.end_headers()
    handler.wfile.write(body)


def _message_text(content: Any) -> str:
    """Flatten an OpenAI message ``content`` value into plain text.

    * ``None`` (an explicit JSON ``null``) becomes the empty string instead
      of the literal text ``"None"``.
    * A string is returned unchanged.
    * A list contributes its plain-string items and the ``"text"`` value of
      its dict items, joined by newlines; other items are ignored.
    * Anything else is converted with ``str()``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: list[str] = []
        for item in content:
            if isinstance(item, dict):
                text = item.get("text")
                if isinstance(text, str):
                    parts.append(text)
            elif isinstance(item, str):
                parts.append(item)
        return "\n".join(parts)
    return str(content)


def _normalize_messages_for_qwen(messages: Any) -> Any:
    """Make common OpenAI message layouts acceptable to the Qwen chat template.

    All ``system`` and ``developer`` messages are pulled out of the list and
    merged, in order and separated by blank lines, into one leading
    ``system`` message; empty texts are skipped. The remaining messages keep
    their relative order. Inputs that are not a list, or that contain no
    system/developer message, are returned unchanged (same object).
    """
    if not isinstance(messages, list):
        return messages

    system_parts: list[str] = []
    normal_messages: list[Any] = []
    for message in messages:
        if not isinstance(message, dict):
            normal_messages.append(message)
            continue

        role = message.get("role")
        if role in {"system", "developer"}:
            system_parts.append(_message_text(message.get("content", "")))
        else:
            normal_messages.append(message)

    if not system_parts:
        return messages

    system_message = {
        "role": "system",
        "content": "\n\n".join(part for part in system_parts if part),
    }
    return [system_message, *normal_messages]


class Qwen36ProxyHandler(BaseHTTPRequestHandler):
    """Guard proxy that rewrites chat requests and relays them to the backend.

    The class attributes are the proxy configuration; they are shared by all
    handler instances.
    """

    backend_url: str = "http://127.0.0.1:8001"
    force_disable_thinking: bool = True
    allow_completions: bool = False

    # Headers that describe the proxy<->peer connection and must not be relayed.
    _HOP_BY_HOP = frozenset({"transfer-encoding", "connection"})

    def log_message(self, fmt: str, *args):  # noqa: D401
        print(f"{self.address_string()} - {fmt % args}", flush=True)

    def _route(self) -> str:
        """Return the request path without query string or trailing slash.

        Route checks use this instead of the raw ``self.path`` so that
        ``/v1/completions?x=1`` or ``/v1/completions/`` cannot slip past the
        raw-completions guard.
        """
        return self.path.split("?", 1)[0].rstrip("/")

    def _relay(self, status: int, headers: Any, body: bytes) -> None:
        """Send a backend response (status, headers, body) to the client."""
        self.send_response(status)
        for key, value in headers.items():
            if key.lower() in self._HOP_BY_HOP:
                continue
            self.send_header(key, value)
        self.end_headers()
        self.wfile.write(body)

    def _forward(self, method: str, body: bytes | None = None):
        """Forward the current request to the backend and relay its reply.

        The backend response is read completely before it is relayed.
        Backend HTTP error statuses are passed through unchanged; a backend
        that cannot be reached at all is reported to the client as 502.
        """
        headers = {
            key: value
            for key, value in self.headers.items()
            if key.lower() not in {"host", "content-length", "connection"}
        }
        url = self.backend_url.rstrip("/") + self.path
        req = urllib.request.Request(url, data=body, headers=headers, method=method)
        try:
            with urllib.request.urlopen(req, timeout=None) as resp:
                status, resp_headers, payload = resp.status, resp.headers, resp.read()
        except urllib.error.HTTPError as exc:
            # HTTPError is a URLError subclass, so it must be handled first.
            status, resp_headers, payload = exc.code, exc.headers, exc.read()
        except urllib.error.URLError as exc:
            _json_response(
                self,
                502,
                {
                    "error": {
                        "message": f"Qwen3.6 backend is unreachable: {exc.reason}",
                        "type": "bad_gateway",
                        "code": "qwen36_backend_unreachable",
                    }
                },
            )
            return
        self._relay(status, resp_headers, payload)

    def do_GET(self):  # noqa: N802
        self._forward("GET")

    def do_POST(self):  # noqa: N802
        """Guard and rewrite POST requests, then forward them to the backend.

        * ``/v1/completions`` is rejected with 400 unless ``allow_completions``.
        * ``/v1/chat/completions`` JSON objects get ``enable_thinking`` forced
          (or defaulted) to ``False`` and their messages normalized.
        * Any other request, and any chat body that is not a JSON object, is
          forwarded unchanged.
        """
        try:
            length = int(self.headers.get("Content-Length", "0") or "0")
            if length < 0:
                raise ValueError("negative Content-Length")
        except ValueError:
            _json_response(
                self,
                400,
                {
                    "error": {
                        "message": "Invalid Content-Length header.",
                        "type": "invalid_request_error",
                        "code": "invalid_content_length",
                    }
                },
            )
            return
        raw_body = self.rfile.read(length) if length else b""

        route = self._route()
        if route == "/v1/completions" and not self.allow_completions:
            _json_response(
                self,
                400,
                {
                    "error": {
                        "message": (
                            "Raw /v1/completions is disabled for Qwen3.6. "
                            "Use /v1/chat/completions so the Qwen chat template "
                            "and non-thinking mode are applied."
                        ),
                        "type": "invalid_request_error",
                        "code": "qwen36_chat_required",
                    }
                },
            )
            return

        if route == "/v1/chat/completions" and raw_body:
            try:
                payload = json.loads(raw_body)
            except ValueError:
                # Covers JSONDecodeError and invalid UTF-8; the backend will
                # reject the body itself.
                payload = None

            if isinstance(payload, dict):
                template_kwargs = payload.get("chat_template_kwargs")
                if not isinstance(template_kwargs, dict):
                    template_kwargs = {}
                if self.force_disable_thinking:
                    template_kwargs["enable_thinking"] = False
                else:
                    template_kwargs.setdefault("enable_thinking", False)
                payload["chat_template_kwargs"] = template_kwargs
                if "messages" in payload:
                    # Do not invent a null "messages" field when it is absent.
                    payload["messages"] = _normalize_messages_for_qwen(
                        payload["messages"]
                    )
                raw_body = json.dumps(payload).encode("utf-8")

        self._forward("POST", raw_body)
b/contrib/models/Qwen3.6-27B/vllm/run_offline_inference.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +"""Offline vLLM smoke runner for Qwen3.6-27B on Neuron.""" + +from __future__ import annotations + +import argparse +import json +import os +import sys +import time +from pathlib import Path + + +def _contrib_root(repo_root: str | None) -> Path: + if repo_root: + return Path(repo_root).expanduser().resolve() / "contrib" / "models" / "Qwen3.6-27B" + return Path(__file__).resolve().parents[1] + + +def _override_config(args: argparse.Namespace) -> dict: + neuron_config = { + "tp_degree": args.tensor_parallel_size, + "batch_size": args.max_num_seqs, + "ctx_batch_size": 1, + "tkg_batch_size": args.max_num_seqs, + "seq_len": args.seq_len, + "max_length": args.seq_len, + "max_context_length": args.cte_bucket, + "context_encoding_buckets": [args.cte_bucket], + "token_generation_buckets": [args.seq_len], + "enable_bucketing": False, + "logical_nc_config": args.logical_nc_config, + "torch_dtype": "bfloat16", + "save_sharded_checkpoint": True, + } + if args.enable_vllm_chunked_prefill: + neuron_config.update( + { + "is_block_kv_layout": True, + "chunked_prefill_config": { + "max_num_seqs": args.max_num_seqs, + "tkg_model_enabled": True, + "kernel_q_tile_size": 128, + "kernel_kv_tile_size": 1024, + }, + } + ) + return { + "max_prompt_length": args.cte_bucket, + "override_neuron_config": neuron_config, + } + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--repo-root", default=None) + parser.add_argument("--model-path", required=True) + parser.add_argument("--compiled-artifacts", default=None) + parser.add_argument("--prompt", default="What is 17 * 23? 
Answer with the number only.") + parser.add_argument("--chat", action="store_true") + parser.add_argument("--enable-vllm-chunked-prefill", action="store_true") + parser.add_argument("--enable-prefix-caching", action="store_true") + parser.add_argument("--mamba-cache-mode", default=None) + parser.add_argument("--mamba-cache-dtype", default=None) + parser.add_argument("--mamba-ssm-cache-dtype", default=None) + parser.add_argument("--max-tokens", type=int, default=64) + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--top-k", type=int, default=1) + parser.add_argument("--tensor-parallel-size", type=int, default=4) + parser.add_argument("--logical-nc-config", type=int, default=2) + parser.add_argument("--max-num-seqs", type=int, default=1) + parser.add_argument("--max-model-len", type=int, default=512) + parser.add_argument("--seq-len", type=int, default=512) + parser.add_argument("--cte-bucket", type=int, default=512) + parser.add_argument("--block-size", type=int, default=256) + args = parser.parse_args() + + contrib_root = _contrib_root(args.repo_root) + script_dir = Path(__file__).resolve().parent + sys.path.insert(0, str(script_dir)) + sys.path.insert(0, str(contrib_root)) + os.environ["PYTHONPATH"] = ( + f"{script_dir}:{contrib_root}:{os.environ.get('PYTHONPATH', '')}" + ) + os.environ.setdefault("VLLM_NEURON_FRAMEWORK", "neuronx-distributed-inference") + os.environ.setdefault("VLLM_PLUGINS", "neuron") + if args.enable_vllm_chunked_prefill: + os.environ["DISABLE_NEURON_CUSTOM_SCHEDULER"] = "1" + if args.compiled_artifacts: + os.environ["NEURON_COMPILED_ARTIFACTS"] = str( + Path(args.compiled_artifacts).expanduser().resolve() + ) + + from hf_qwen35_config import register_qwen35_config # noqa: WPS433 + + register_qwen35_config() + + from vllm import LLM, SamplingParams # noqa: WPS433 + + prompt = args.prompt + if args.chat: + from transformers import AutoTokenizer # noqa: WPS433 + + tokenizer = AutoTokenizer.from_pretrained( + 
args.model_path, + trust_remote_code=True, + ) + prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": args.prompt}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + + additional_config = _override_config(args) + print("VLLM_QWEN36_CONFIG", json.dumps(additional_config, sort_keys=True), flush=True) + + llm_kwargs = { + "model": str(Path(args.model_path).expanduser().resolve()), + "trust_remote_code": True, + "dtype": "bfloat16", + "tensor_parallel_size": args.tensor_parallel_size, + "max_num_seqs": args.max_num_seqs, + "max_model_len": args.max_model_len, + "enable_prefix_caching": args.enable_prefix_caching, + "enable_chunked_prefill": args.enable_vllm_chunked_prefill, + "additional_config": additional_config, + } + if args.mamba_cache_mode is not None: + llm_kwargs["mamba_cache_mode"] = args.mamba_cache_mode + if args.mamba_cache_dtype is not None: + llm_kwargs["mamba_cache_dtype"] = args.mamba_cache_dtype + if args.mamba_ssm_cache_dtype is not None: + llm_kwargs["mamba_ssm_cache_dtype"] = args.mamba_ssm_cache_dtype + if args.enable_vllm_chunked_prefill: + llm_kwargs["max_num_batched_tokens"] = args.cte_bucket + llm_kwargs["block_size"] = args.block_size + llm = LLM(**llm_kwargs) + + sampling = SamplingParams( + temperature=args.temperature, + top_k=args.top_k, + max_tokens=args.max_tokens, + ) + start = time.perf_counter() + outputs = llm.generate([prompt], sampling) + elapsed = time.perf_counter() - start + text = outputs[0].outputs[0].text + token_ids = outputs[0].outputs[0].token_ids + + print("PROMPT", prompt) + print("OUTPUT", text) + print("TOKENS", list(token_ids)) + print("ELAPSED_SECONDS", f"{elapsed:.3f}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/contrib/models/Qwen3.6-27B/vllm/serve_qwen36.py b/contrib/models/Qwen3.6-27B/vllm/serve_qwen36.py new file mode 100644 index 00000000..85e12ef6 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/serve_qwen36.py @@ -0,0 
+1,21 @@ +#!/usr/bin/env python3 +"""vLLM CLI wrapper that registers Qwen3.6 aliases before validation.""" + +from __future__ import annotations + +import sys + +from hf_qwen35_config import register_qwen35_config + + +def main() -> int: + register_qwen35_config() + + from vllm.entrypoints.cli.main import main as vllm_main + + sys.argv = ["vllm", "serve", *sys.argv[1:]] + return int(vllm_main() or 0) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/contrib/models/Qwen3.6-27B/vllm/sitecustomize.py b/contrib/models/Qwen3.6-27B/vllm/sitecustomize.py new file mode 100644 index 00000000..dcec3056 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/sitecustomize.py @@ -0,0 +1,9 @@ +"""Auto-register Qwen3.5/Qwen3.6 HF config when this folder is on PYTHONPATH. + +Do not import vLLM here. Neuron helper commands such as libneuronpjrt-path run +inside Python subprocesses and expect clean stdout. +""" + +from hf_qwen35_config import register_qwen35_hf_config + +register_qwen35_hf_config() diff --git a/contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh b/contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh new file mode 100755 index 00000000..46342690 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/vllm/start_vllm_server.sh @@ -0,0 +1,147 @@ +#!/usr/bin/env bash +set -euo pipefail + +MODEL_PATH="" +COMPILED_ARTIFACTS="" +MAX_MODEL_LEN="512" +SEQ_LEN="512" +CTE_BUCKET="512" +TP_DEGREE="4" +LNC="2" +MAX_NUM_SEQS="1" +PORT="8000" +HOST="0.0.0.0" +ENABLE_CHUNKED_PREFILL="0" +ENABLE_PREFIX_CACHING="0" +MAMBA_CACHE_MODE="" +MAMBA_CACHE_DTYPE="" +MAMBA_SSM_CACHE_DTYPE="" +BLOCK_SIZE="" + +while [[ $# -gt 0 ]]; do + case "$1" in + --model-path) MODEL_PATH="$2"; shift 2 ;; + --compiled-artifacts) COMPILED_ARTIFACTS="$2"; shift 2 ;; + --max-model-len) MAX_MODEL_LEN="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --cte-bucket) CTE_BUCKET="$2"; shift 2 ;; + --tensor-parallel-size) TP_DEGREE="$2"; shift 2 ;; + --logical-nc-config) LNC="$2"; shift 2 ;; + 
--max-num-seqs) MAX_NUM_SEQS="$2"; shift 2 ;; + --enable-vllm-chunked-prefill) ENABLE_CHUNKED_PREFILL="1"; shift ;; + --enable-prefix-caching) ENABLE_PREFIX_CACHING="1"; shift ;; + --disable-prefix-caching|--no-enable-prefix-caching) ENABLE_PREFIX_CACHING="0"; shift ;; + --mamba-cache-mode) MAMBA_CACHE_MODE="$2"; shift 2 ;; + --mamba-cache-dtype) MAMBA_CACHE_DTYPE="$2"; shift 2 ;; + --mamba-ssm-cache-dtype) MAMBA_SSM_CACHE_DTYPE="$2"; shift 2 ;; + --block-size) BLOCK_SIZE="$2"; shift 2 ;; + --host) HOST="$2"; shift 2 ;; + --port) PORT="$2"; shift 2 ;; + *) echo "Unknown argument: $1" >&2; exit 2 ;; + esac +done + +if [[ -z "${MODEL_PATH}" ]]; then + echo "ERROR: --model-path is required" >&2 + exit 2 +fi + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +CONTRIB_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" +export PYTHONPATH="${SCRIPT_DIR}:${CONTRIB_ROOT}:${PYTHONPATH:-}" +export VLLM_NEURON_FRAMEWORK="neuronx-distributed-inference" +export VLLM_PLUGINS="${VLLM_PLUGINS:-neuron}" + +if [[ -n "${COMPILED_ARTIFACTS}" ]]; then + export NEURON_COMPILED_ARTIFACTS="${COMPILED_ARTIFACTS}" +fi +if [[ -z "${BLOCK_SIZE}" ]]; then + BLOCK_SIZE="256" +fi +if [[ "${ENABLE_CHUNKED_PREFILL}" == "1" ]]; then + export DISABLE_NEURON_CUSTOM_SCHEDULER="1" +fi + +ADDITIONAL_CONFIG="$( + python3 - < Date: Wed, 13 May 2026 10:59:40 +0530 Subject: [PATCH 02/21] Docs: align Qwen3.6 README with contrib guidelines --- contrib/models/Qwen3.6-27B/README.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/contrib/models/Qwen3.6-27B/README.md b/contrib/models/Qwen3.6-27B/README.md index cdba94ba..5e244702 100644 --- a/contrib/models/Qwen3.6-27B/README.md +++ b/contrib/models/Qwen3.6-27B/README.md @@ -26,7 +26,7 @@ Qwen3.6 weights. 
| Model | HuggingFace ID | Params | Instance | |-------|----------------|--------|----------| -| **Qwen3.6-27B** | `Qwen/Qwen3.6-27B` | 27B | trn2.3xlarge (TP=4) | +| **Qwen3.6-27B** | [`Qwen/Qwen3.6-27B`](https://huggingface.co/Qwen/Qwen3.6-27B) | 27B | trn2.3xlarge (TP=4) | **License:** Apache 2.0 @@ -63,9 +63,11 @@ Qwen3.6 weights. |-------------|-------|--------| | test_config.py | 26 | 26/26 PASS | | test_weight_conversion.py | 16 | 16/16 PASS | -| **Total** | **42** | **42/42 PASS** | +| test_hybrid_cache_manager.py | 13 | 13/13 PASS | +| test_deltanet_decay.py | 2 | 2/2 PASS | +| **Total** | **57** | **57/57 PASS** | -Unit tests are architecture-level and do not depend on weights. Identical results to Qwen3.5-27B. +Unit tests are architecture-level and do not depend on weights. Coverage includes config parsing, weight conversion, hybrid cache allocation/update behavior, and DeltaNet decay handling. ### Quality Validation (Qwen3.6-27B, trn2.3xlarge, TP=4, SDK 2.29) @@ -304,7 +306,7 @@ cd contrib/models/Qwen3.6-27B/ pytest test/unit/ -v ``` -Tests: config parsing (26), weight conversion (16) = **42 tests**. +Tests: config parsing (26), weight conversion (16), hybrid cache manager (13), and DeltaNet decay handling (2) = **57 tests**. 
### Integration Tests (needs trn2.3xlarge with 4 NeuronCores) @@ -322,7 +324,7 @@ Note: The env var is `QWEN35_MODEL_PATH` (not `QWEN36`) because the code uses th ## Example Checkpoints -- `Qwen/Qwen3.6-27B` (BF16, ~52 GB) +- [`Qwen/Qwen3.6-27B`](https://huggingface.co/Qwen/Qwen3.6-27B) (BF16, ~52 GB) ## Maintainer From 6d6ae625428ee52b0c1df8eafd1e2f1ed24af746 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Wed, 13 May 2026 12:28:50 +0530 Subject: [PATCH 03/21] Docs: add Qwen3.6 TTFT and TPOT benchmarks --- contrib/models/Qwen3.6-27B/README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/contrib/models/Qwen3.6-27B/README.md b/contrib/models/Qwen3.6-27B/README.md index 5e244702..a6327f07 100644 --- a/contrib/models/Qwen3.6-27B/README.md +++ b/contrib/models/Qwen3.6-27B/README.md @@ -130,6 +130,17 @@ with the vLLM Neuron plugin, Qwen chunked prefill, and native vLLM APC enabled. | State reset | repeated short-after-long validation passed after 32K and 64K requests | | Peak Neuron device memory | ~53.25 GB decimal during the 64K eval | +TTFT/TPOT details for the same 128K FP8/vLLM artifact: + +| Metric | Result | Notes | +|--------|--------|-------| +| Decode TPOT | ~37.6-38.0 ms/token | Derived from 26.3-26.6 tok/s decode | +| Cold 512-token TTFT | ~1.2-1.3s | Derived from measured prefill throughput plus one decode step | +| Cold 32K-token TTFT | ~76.6-81.1s | Derived from measured prefill throughput plus one decode step | +| Cold 64K-token TTFT | ~153-162s | Derived from measured prefill throughput plus one decode step | +| Warm APC latency, ~10.8K prompt | 1.36-2.38s | Exact-repeat, partial-prefix, and cross-prefix validation runs | +| Cold APC baseline, ~10.8K prompt | 25.17-26.68s | Same prompts with prefix cache disabled or cold | + Native vLLM prefix caching/APC was also validated with exact greedy output matches: From 8760bf58958b2194d698492450804be9bb662590 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Mon, 18 May 2026 20:58:17 
+0530 Subject: [PATCH 04/21] Docs: summarize Hybrid APC follow-up status --- contrib/models/Qwen3.6-27B/README.md | 52 ++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/contrib/models/Qwen3.6-27B/README.md b/contrib/models/Qwen3.6-27B/README.md index a6327f07..60f40db9 100644 --- a/contrib/models/Qwen3.6-27B/README.md +++ b/contrib/models/Qwen3.6-27B/README.md @@ -151,6 +151,58 @@ matches: | Offline partial-prefix reuse | 25.52s | 1.70s | 15.0x | exact token-ID match | | Server cross-prefix reuse | 25.17s | 1.36s | 18.5x | exact text match | +### Hybrid APC Follow-up Status + +Follow-up work on the `experimental` branch extended the baseline vLLM/APC +path toward Qwen3.6 Hybrid APC, where attention KV prefix reuse is only correct +when the matching DeltaNet recurrent/conv checkpoint is also available. + +What has been implemented and proven in that branch: + +- Scheduler-side safety gating prevents vLLM from reading an attention prefix + unless a matching GDN checkpoint is registered. +- Qwen request prep consumes scheduler-authorized, request-scoped restore keys + instead of relying on prefix length alone. +- The CTE restore path handles suffix-only execution over a restored prefix: + suffix tokens, slot mapping, `computed_context_lens`, `num_queries`, and GDN + restore metadata are kept aligned. +- BF16 single-request backed-prefix validation passes with cold/warm exactness + on the 2K checkpoint-boundary case. The proven shape restores a 256-token GDN + checkpoint, executes a 16-token suffix, and matches cold output. +- The safety fallback also passes: if attention KV has a prefix hit but no GDN + checkpoint exists, prefix reads are disabled and the request recomputes cold. + +Current blocker: + +- True generated-token batch-2 validation needs both `tkg_batch_size=2` and + `ctx_batch_size=2`. 
+- A batch-2 artifact with `tkg_batch_size=2` but `ctx_batch_size=1` failed in + vLLM-Neuron host-logits sampling because two prefills were packed into one + CTE row, then logits were reordered for two live request ids. +- Single-bucket `ctx_batch_size=2` / `tkg_batch_size=2` BF16 artifacts for CTE + bucket 256 and CTE bucket 512 compiled successfully. +- The combined multi-bucket artifact (`cte_buckets=256,512`, + `prefix_buckets=256,512`) started compiling and the TKG priority HLO passed, + but the smaller Trainium instance became SSH-unresponsive during all-HLO CTE + compilation. This appears to be a Neuron/NxDI compile-capacity or compile + orchestration issue, not a model-correctness failure. + +Expected outcome after the batch-2 artifact or an equivalent prefill-only +proof is available: + +- Batched Hybrid APC can preserve the same correctness rule as the + single-request path: + `usable_prefix_hit = attention_KV_prefix_hit AND matching_GDN_checkpoint_hit`. +- Warm repeated-prefix and partial-prefix requests should avoid replaying the + shared cold prefill while restoring the required GDN state. +- This is the path expected to turn the current exact single-request APC proof + into a measured cold-prefill performance win for batched serving. + +The fused CTE kernel and FP8 path are not the current correctness blockers. +The BF16 per-chunk CTE path is the reference path for Hybrid APC validation: +the fused BF16 CTE artifact has shown NaNs around token 105-106, and FP8 should +be revisited after the BF16 batch-2 serving contract is proven. + ### Key Observations - **BF16 TP=4 is HBM-limited:** The pure BF16 path is limited to short contexts on trn2.3xlarge. The validated 128K baseline uses MLP-only FP8 weights plus the hybrid cache manager. 
From ac7df718cef6beed3bc1bf2fffa2f6a64e57b6c2 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Mon, 18 May 2026 22:11:08 +0530 Subject: [PATCH 05/21] Docs: correct Trn2 compatibility matrix --- contrib/models/Qwen3.6-27B/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/contrib/models/Qwen3.6-27B/README.md b/contrib/models/Qwen3.6-27B/README.md index 60f40db9..650a7012 100644 --- a/contrib/models/Qwen3.6-27B/README.md +++ b/contrib/models/Qwen3.6-27B/README.md @@ -346,7 +346,8 @@ larger batches or additional serving headroom. | Instance | TP | LNC | Status | Notes | |----------|-----|-----|--------|-------| | trn2.3xlarge | 4 | 2 | **PASS** | BF16 short-context and FP8 128K vLLM/APC validated | -| trn2.12xlarge | 16 | 2 | Expected PASS | Untested, recommended for batching/headroom | +| trn2.48xlarge | 4 | 2 | Expected PASS | Untested for this contrib; use the same TP=4 artifact shape when compiling for trn2.3xlarge deployment | +| trn2u.48xlarge | 4 | 2 | Expected PASS | Untested for this contrib; same portability caveat as trn2.48xlarge | ### SDK Configuration From 11c550ce625a5243a016120edd26683d4d87be40 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Fri, 22 May 2026 10:22:49 +0530 Subject: [PATCH 06/21] Stabilize Qwen fused DeltaNet decay (cherry picked from commit 8af0219d62a258556ec106eef8c008edbcb3f285) --- .../src/nki_kernels/nki_deltanet_fused.py | 215 ++++++++++-------- 1 file changed, 115 insertions(+), 100 deletions(-) diff --git a/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py index 6008ae5a..f6b6e0ee 100644 --- a/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py +++ b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py @@ -16,8 +16,9 @@ NKI 0.3.0 (SDK 2.29). k_dim = v_dim = 128 = P_MAX exactly. Chunk size = 128 = P_MAX (one tile per chunk). 
-Mathematical framework (same as nki_deltanet_chunked.py): +Mathematical framework: Per-chunk Neumann-series power-doubling for intra-chunk correction: + QK_decay[i,j] = QK[i,j] * exp(gc[i] - gc[j]) for i > j A = -QK_decay * lower_mask N = (I+A)(I+A^2)(I+A^4)...(I+A^64) [6 rounds] value_corr = N @ v_beta @@ -216,7 +217,7 @@ def deltanet_fused_chunked_fwd( src=gc_row[0:1, CHUNK_SIZE - 1 : CHUNK_SIZE], ) - # ---- Compute exp(gc), exp(-gc), exp(g_last) as (P_MAX, 1) scalars ---- + # ---- Compute exp(gc) and exp(g_last) as (P_MAX, 1) scalars ---- # These (P_MAX, 1) tensors are used with tensor_scalar to broadcast # across the free dimension without explicit (P_MAX, dim) copies. @@ -229,22 +230,15 @@ def deltanet_fused_chunked_fwd( scale=1.0, ) - neg_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_scalar( - dst=neg_gc_p, - data=gc_p, - op0=nl.multiply, - operand0=-1.0, - engine=nisa.vector_engine, - ) - exp_neg_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) - nisa.activation( - dst=exp_neg_gc_p[0:P_MAX, 0:1], - op=nl.exp, - data=neg_gc_p[0:P_MAX, 0:1], - bias=None, - scale=1.0, - ) + # g_last: scalar, then broadcast to (P_MAX, 1) for direct + # exp(g_last - gc) in the state update. + gl_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gl_11[0:1, 0:1], + dst=gl_p[i_shuf * 32 : i_shuf * 32 + 32, 0:1], + shuffle_mask=_BROADCAST_MASK, + ) # exp(g_last): scalar, then broadcast to (P_MAX, 1) exp_gl_11 = nl.ndarray((1, 1), dtype=nl.float32, buffer=nl.sbuf) @@ -264,6 +258,88 @@ def deltanet_fused_chunked_fwd( shuffle_mask=_BROADCAST_MASK, ) + # ============================================================ + # Stable pairwise decay factors from cumulative log-decay. + # + # The original fused path used split scaling: + # exp(gc[i]) * exp(-gc[j]) + # That can materialize huge unused intermediates. 
Build the same + # causal decay matrices as the per-chunk kernel using exp(gc[i]-gc[j]) + # and mask after the exp so upper-triangular values cannot leak into + # later matmuls. + # ============================================================ + gc_row_broadcast = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + for i_shuf in nl.static_range(P_MAX // 32): + nisa.nc_stream_shuffle( + src=gc_row[0:1, 0:P_MAX], + dst=gc_row_broadcast[i_shuf * 32 : i_shuf * 32 + 32, 0:P_MAX], + shuffle_mask=_BROADCAST_MASK, + ) + + gc_col_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=gc_col_strict, + data=Lmask, + op0=nl.multiply, + operand0=gc_p, + engine=nisa.vector_engine, + ) + gc_row_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gc_row_strict, data1=gc_row_broadcast, data2=Lmask, op=nl.multiply + ) + g_diff_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=g_diff_strict, + data1=gc_col_strict, + data2=gc_row_strict, + op=nl.subtract, + ) + decay_strict_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=decay_strict_raw, + op=nl.exp, + data=g_diff_strict, + bias=None, + scale=1.0, + ) + decay_strict = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=decay_strict, data1=decay_strict_raw, data2=Lmask, op=nl.multiply + ) + + gc_col_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=gc_col_diag, + data=Lmask_d, + op0=nl.multiply, + operand0=gc_p, + engine=nisa.vector_engine, + ) + gc_row_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=gc_row_diag, data1=gc_row_broadcast, data2=Lmask_d, op=nl.multiply + ) + g_diff_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=g_diff_diag, + data1=gc_col_diag, + 
data2=gc_row_diag, + op=nl.subtract, + ) + decay_diag_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=decay_diag_raw, + op=nl.exp, + data=g_diff_diag, + bias=None, + scale=1.0, + ) + decay_diag = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=decay_diag, data1=decay_diag_raw, data2=Lmask_d, op=nl.multiply + ) + # ============================================================ # k_beta = K * beta, v_beta = V * beta # tensor_scalar broadcasts beta_p (P_MAX, 1) across free dim @@ -306,49 +382,9 @@ def deltanet_fused_chunked_fwd( QK = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_copy(dst=QK, src=QK_psum) - # ============================================================ - # Decay mask: QK_decay[i,j] = QK[i,j] * exp(gc[i]) * exp(-gc[j]) - # - # Apply the strict causal mask before the split exp(gc) / exp(-gc) - # scaling. Upper-triangular entries are mathematically unused, but - # scaling them first can create very large finite values that poison - # later matmuls before the mask is applied. - # ============================================================ - QK_masked = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_tensor(dst=QK_masked, data1=QK, data2=Lmask, op=nl.multiply) - - # Row scaling: QK_row[i,:] = QK[i,:] * exp(gc[i]) - # Then transpose, column scale, transpose back. - # Uses tensor_scalar with (P_MAX,1) operand for row scaling. 
- QK_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_scalar( - dst=QK_row, - data=QK_masked, - op0=nl.multiply, - operand0=exp_gc_p, - engine=nisa.vector_engine, - ) - - # Transpose to scale columns (now rows in transposed view) - QK_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) - nisa.nc_transpose(dst=QK_r_T_psum, data=QK_row) - QK_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_copy(dst=QK_r_T, src=QK_r_T_psum) - - QK_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_scalar( - dst=QK_r_T_col, - data=QK_r_T, - op0=nl.multiply, - operand0=exp_neg_gc_p, - engine=nisa.vector_engine, - ) - - # Transpose back - QK_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) - nisa.nc_transpose(dst=QK_d_psum, data=QK_r_T_col) + # QK_decay[i,j] = QK[i,j] * exp(gc[i] - gc[j]) for i > j. QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_copy(dst=QK_decay, src=QK_d_psum) + nisa.tensor_tensor(dst=QK_decay, data1=QK, data2=decay_strict, op=nl.multiply) # A = -QK_decay * lower_mask neg_QK_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) @@ -382,6 +418,7 @@ def deltanet_fused_chunked_fwd( Ap_sq_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) nisa.nc_matmul(dst=Ap_sq_psum, stationary=Ap_T, moving=A_pow) nisa.tensor_copy(dst=A_pow, src=Ap_sq_psum) + nisa.tensor_tensor(dst=A_pow, data1=A_pow, data2=Lmask, op=nl.multiply) # P_acc = (I + A_pow) @ P_acc: transpose IpA, then matmul IpA = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) @@ -395,6 +432,7 @@ def deltanet_fused_chunked_fwd( Pacc_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) nisa.nc_matmul(dst=Pacc_psum, stationary=IpA_T, moving=P_acc) nisa.tensor_copy(dst=P_acc, src=Pacc_psum) + nisa.tensor_tensor(dst=P_acc, data1=P_acc, data2=Lmask_d, op=nl.multiply) # 
============================================================ # Apply N: value_corr = N @ v_beta @@ -439,40 +477,9 @@ def deltanet_fused_chunked_fwd( qk_raw = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_copy(dst=qk_raw, src=qk_psum) - # Mask before split scaling for the same reason as the A matrix above: - # upper-triangular decay factors are unused and can be numerically huge. - qk_masked = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_tensor(dst=qk_masked, data1=qk_raw, data2=Lmask_d, op=nl.multiply) - - # Row-scale by exp(gc) - qk_row = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_scalar( - dst=qk_row, - data=qk_masked, - op0=nl.multiply, - operand0=exp_gc_p, - engine=nisa.vector_engine, - ) - - # Transpose, column-scale by exp(-gc), transpose back - qk_r_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) - nisa.nc_transpose(dst=qk_r_T_psum, data=qk_row) - qk_r_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_copy(dst=qk_r_T, src=qk_r_T_psum) - - qk_r_T_col = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_scalar( - dst=qk_r_T_col, - data=qk_r_T, - op0=nl.multiply, - operand0=exp_neg_gc_p, - engine=nisa.vector_engine, - ) - - qk_d_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) - nisa.nc_transpose(dst=qk_d_psum, data=qk_r_T_col) + # qk_decay[i,j] = (q @ k^T)[i,j] * exp(gc[i] - gc[j]) for i >= j. qk_decay = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_copy(dst=qk_decay, src=qk_d_psum) + nisa.tensor_tensor(dst=qk_decay, data1=qk_raw, data2=decay_diag, op=nl.multiply) attn_intra = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_tensor( @@ -548,14 +555,22 @@ def deltanet_fused_chunked_fwd( # ============================================================ # k_raw_decay contributes as exp(g_last) * (k * exp(-gc))^T @ v_new. 
- # Compute the equivalent stable form k * exp(g_last - gc), so the - # factor is always <= 1 for valid causal positions. - exp_gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + # Compute the equivalent stable form k * exp(g_last - gc) directly so + # no exp(-gc) intermediate can overflow. + gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_tensor( - dst=exp_gl_minus_gc_p, - data1=exp_gl_p, - data2=exp_neg_gc_p, - op=nl.multiply, + dst=gl_minus_gc_p, + data1=gl_p, + data2=gc_p, + op=nl.subtract, + ) + exp_gl_minus_gc_p = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.activation( + dst=exp_gl_minus_gc_p[0:P_MAX, 0:1], + op=nl.exp, + data=gl_minus_gc_p[0:P_MAX, 0:1], + bias=None, + scale=1.0, ) # k_raw_decay = k * exp(g_last - gc) From 0a68b82c32f0b409c76ecaa81ee953ea4f048249 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Fri, 22 May 2026 11:46:39 +0530 Subject: [PATCH 07/21] Add isolated fused DeltaNet NKI validation (cherry picked from commit fd30e32a6b149f68eb433ce2308ff93b2aaa7617) --- .../scripts/validate_deltanet_fused_nki.py | 260 ++++++++++++++++++ 1 file changed, 260 insertions(+) create mode 100644 contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py diff --git a/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py b/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py new file mode 100644 index 00000000..d0b2874d --- /dev/null +++ b/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py @@ -0,0 +1,260 @@ +#!/usr/bin/env python3 +"""Validate and optionally inspect/profile the fused Qwen DeltaNet NKI kernel. + +The CPU reference stays off the XLA device so the generated NEFFs are from the +NKI kernel under test, not from reference PyTorch ops. 
+""" + +from __future__ import annotations + +import argparse +import json +import math +import os +import sys +from pathlib import Path +from typing import Any + + +P_MAX = 128 + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Validate/profile deltanet_fused_chunked_fwd against CPU math." + ) + parser.add_argument("--seed", type=int, default=1234) + parser.add_argument("--seq-len", type=int, default=256) + parser.add_argument("--runs", type=int, default=1) + parser.add_argument("--target", default="trn2") + parser.add_argument("--lnc", type=int, default=1) + parser.add_argument("--visible-cores", default="0") + parser.add_argument("--inspect", action="store_true") + parser.add_argument("--dge", action="store_true") + parser.add_argument( + "--inspect-dir", + default="/mnt/trainium_artifacts/profiles/deltanet_fused_isolated", + ) + parser.add_argument("--atol", type=float, default=3.0e-2) + parser.add_argument("--rtol", type=float, default=3.0e-2) + parser.add_argument("--value-scale", type=float, default=0.05) + parser.add_argument("--state-scale", type=float, default=0.01) + parser.add_argument("--gate-scale", type=float, default=0.01) + parser.add_argument("--fail-on-mismatch", action="store_true") + return parser.parse_args() + + +def configure_environment(args: argparse.Namespace) -> Path: + if args.seq_len <= 0 or args.seq_len % P_MAX != 0: + raise ValueError("--seq-len must be a positive multiple of 128") + os.environ.setdefault("NEURON_PLATFORM_TARGET_OVERRIDE", args.target) + os.environ.setdefault("NEURON_CC_FLAGS", f"--target {args.target} --lnc {args.lnc}") + os.environ.setdefault("NEURON_RT_VISIBLE_CORES", args.visible_cores) + + inspect_dir = Path(args.inspect_dir).expanduser().resolve() + if args.inspect: + inspect_dir.mkdir(parents=True, exist_ok=True) + os.environ["NEURON_RT_INSPECT_ENABLE"] = "1" + os.environ["NEURON_RT_INSPECT_DEVICE_PROFILE"] = "1" + os.environ["NEURON_RT_INSPECT_SYSTEM_PROFILE"] = 
"0" + os.environ["NEURON_RT_INSPECT_OUTPUT_DIR"] = str(inspect_dir) + os.environ["XLA_IR_DEBUG"] = "1" + os.environ["XLA_HLO_DEBUG"] = "1" + os.environ["NEURON_FRAMEWORK_DEBUG"] = "1" + if args.dge: + os.environ["NEURON_RT_ENABLE_DGE_NOTIFICATIONS"] = "1" + return inspect_dir + + +def add_qwen_to_path() -> None: + script_path = Path(__file__).resolve() + qwen_root = script_path.parents[1] + sys.path.insert(0, str(qwen_root)) + + +def make_inputs(torch: Any, args: argparse.Namespace) -> dict[str, Any]: + generator = torch.Generator(device="cpu") + generator.manual_seed(args.seed) + + def randn(shape: tuple[int, ...], scale: float) -> Any: + return torch.randn(shape, generator=generator, dtype=torch.float32) * scale + + query = randn((args.seq_len, P_MAX), args.value_scale) + key = randn((args.seq_len, P_MAX), args.value_scale) + value = randn((args.seq_len, P_MAX), args.value_scale) + state_in = randn((P_MAX, P_MAX), args.state_scale) + + query = torch.nn.functional.normalize(query, p=2, dim=-1) / math.sqrt(P_MAX) + key = torch.nn.functional.normalize(key, p=2, dim=-1) + + beta = torch.sigmoid(randn((args.seq_len, 1), 1.0)) + g_raw = -torch.nn.functional.softplus(randn((args.seq_len, 1), 1.0)) + g_raw = g_raw * args.gate_scale + + lower_mask = torch.tril(torch.ones((P_MAX, P_MAX), dtype=torch.float32), diagonal=-1) + lower_mask_diag = torch.tril(torch.ones((P_MAX, P_MAX), dtype=torch.float32)) + identity = torch.eye(P_MAX, dtype=torch.float32) + + return { + "query": query.contiguous(), + "key": key.contiguous(), + "value": value.contiguous(), + "g_raw": g_raw.contiguous(), + "beta": beta.contiguous(), + "state_in": state_in.contiguous(), + "lower_mask": lower_mask.contiguous(), + "identity": identity.contiguous(), + "lower_mask_diag": lower_mask_diag.contiguous(), + } + + +def reference_math(torch: Any, inputs: dict[str, Any]) -> tuple[Any, Any]: + lower = inputs["lower_mask"] + lower_diag = inputs["lower_mask_diag"] + eye = inputs["identity"] + state = 
inputs["state_in"].clone() + outputs = [] + + for start in range(0, inputs["query"].shape[0], P_MAX): + end = start + P_MAX + q = inputs["query"][start:end] + k = inputs["key"][start:end] + v = inputs["value"][start:end] + g = inputs["g_raw"][start:end] + beta = inputs["beta"][start:end] + + gc = torch.cumsum(g, dim=0) + gl = gc[-1:] + k_beta = k * beta + v_beta = v * beta + + decay = torch.exp(gc - gc.T) + decay_strict = decay * lower + decay_diag = decay * lower_diag + + qk_beta = k_beta @ k.T + a_mat = -(qk_beta * decay_strict) * lower + + # Mirror the fused kernel: Neumann power-doubling, not triangular solve. + p_acc = eye + a_mat + a_pow = a_mat.clone() + for _ in range(6): + a_pow = (a_pow @ a_pow) * lower + p_acc = ((eye + a_pow) @ p_acc) * lower_diag + + exp_gc = torch.exp(gc) + value_corr = p_acc @ v_beta + k_cumdecay = p_acc @ (k_beta * exp_gc) + attn_intra = (q @ k.T) * decay_diag + + v_new = value_corr - (k_cumdecay @ state) + chunk_out = ((q * exp_gc) @ state) + (attn_intra @ v_new) + outputs.append(chunk_out) + + k_raw_decay = k * torch.exp(gl - gc) + state = (state * torch.exp(gl)) + (k_raw_decay.T @ v_new) + + return torch.cat(outputs, dim=0).contiguous(), state.contiguous() + + +def tensor_metrics(torch: Any, actual: Any, expected: Any) -> dict[str, float | bool]: + diff = actual - expected + expected_norm = torch.linalg.vector_norm(expected).item() + diff_norm = torch.linalg.vector_norm(diff).item() + actual_flat = actual.reshape(-1).to(torch.float64) + expected_flat = expected.reshape(-1).to(torch.float64) + denom = torch.linalg.vector_norm(actual_flat) * torch.linalg.vector_norm(expected_flat) + cosine = ( + float(torch.dot(actual_flat, expected_flat) / denom) + if denom.item() != 0.0 + else float("nan") + ) + return { + "finite": bool(torch.isfinite(actual).all().item()), + "max_abs": float(torch.max(torch.abs(diff)).item()), + "mean_abs": float(torch.mean(torch.abs(diff)).item()), + "diff_norm": float(diff_norm), + "expected_norm": 
float(expected_norm), + "relative_norm": float(diff_norm / max(expected_norm, 1.0e-12)), + "cosine": cosine, + } + + +def main() -> int: + args = parse_args() + inspect_dir = configure_environment(args) + add_qwen_to_path() + + import torch + import torch_xla.core.xla_model as xm + + from src.nki_kernels.nki_deltanet_fused import deltanet_fused_chunked_fwd + + inputs = make_inputs(torch, args) + ref_out, ref_state = reference_math(torch, inputs) + + device = xm.xla_device() + xla_inputs = {name: tensor.to(device=device) for name, tensor in inputs.items()} + + out_cpu = state_cpu = None + for _ in range(args.runs): + out_dev, state_dev = deltanet_fused_chunked_fwd( + xla_inputs["query"], + xla_inputs["key"], + xla_inputs["value"], + xla_inputs["g_raw"], + xla_inputs["beta"], + xla_inputs["state_in"], + xla_inputs["lower_mask"], + xla_inputs["identity"], + xla_inputs["lower_mask_diag"], + ) + xm.mark_step() + out_cpu = out_dev.detach().cpu().float() + state_cpu = state_dev.detach().cpu().float() + + assert out_cpu is not None + assert state_cpu is not None + + output_close = torch.allclose(out_cpu, ref_out, atol=args.atol, rtol=args.rtol) + state_close = torch.allclose(state_cpu, ref_state, atol=args.atol, rtol=args.rtol) + output_finite = bool(torch.isfinite(out_cpu).all().item()) + state_finite = bool(torch.isfinite(state_cpu).all().item()) + passed = bool(output_close and state_close and output_finite and state_finite) + + result = { + "passed": passed, + "seed": args.seed, + "seq_len": args.seq_len, + "runs": args.runs, + "atol": args.atol, + "rtol": args.rtol, + "inspect": args.inspect, + "dge": args.dge, + "output_finite": output_finite, + "state_finite": state_finite, + "inspect_dir": str(inspect_dir), + "environment": { + key: os.environ.get(key) + for key in ( + "NEURON_CC_FLAGS", + "NEURON_PLATFORM_TARGET_OVERRIDE", + "NEURON_RT_VISIBLE_CORES", + "NEURON_RT_INSPECT_ENABLE", + "NEURON_RT_ENABLE_DGE_NOTIFICATIONS", + ) + }, + "nki_vs_reference": { + "output": 
tensor_metrics(torch, out_cpu, ref_out), + "state": tensor_metrics(torch, state_cpu, ref_state), + }, + } + print(json.dumps(result, indent=2, sort_keys=True)) + + if args.fail_on_mismatch and not passed: + return 2 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From d84b3c798fff413d7bf7afaacceb82a368701254 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Fri, 22 May 2026 11:49:20 +0530 Subject: [PATCH 08/21] Load fused NKI kernel directly in validator (cherry picked from commit 399538964b64340a2ce2ebf0813f986a746496ab) --- .../scripts/validate_deltanet_fused_nki.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py b/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py index d0b2874d..19926dcc 100644 --- a/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py +++ b/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py @@ -8,6 +8,7 @@ from __future__ import annotations import argparse +import importlib.util import json import math import os @@ -72,6 +73,23 @@ def add_qwen_to_path() -> None: sys.path.insert(0, str(qwen_root)) +def load_fused_kernel(): + kernel_path = ( + Path(__file__).resolve().parents[1] + / "src" + / "nki_kernels" + / "nki_deltanet_fused.py" + ) + spec = importlib.util.spec_from_file_location( + "qwen36_nki_deltanet_fused_under_test", + kernel_path, + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module.deltanet_fused_chunked_fwd + + def make_inputs(torch: Any, args: argparse.Namespace) -> dict[str, Any]: generator = torch.Generator(device="cpu") generator.manual_seed(args.seed) @@ -188,7 +206,7 @@ def main() -> int: import torch import torch_xla.core.xla_model as xm - from src.nki_kernels.nki_deltanet_fused import deltanet_fused_chunked_fwd + deltanet_fused_chunked_fwd = load_fused_kernel() inputs = 
make_inputs(torch, args) ref_out, ref_state = reference_math(torch, inputs) From 3e56e5814d89d1fedfa3536adde0bdd05ef4c277 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Fri, 22 May 2026 18:24:22 +0530 Subject: [PATCH 09/21] Fix fused DeltaNet solve stability (cherry picked from commit ae2613d0fb3be7892ddea136a7e1507a21b75948) --- .../scripts/validate_deltanet_fused_nki.py | 24 +-- .../src/nki_kernels/nki_deltanet_fused.py | 151 +++++++++--------- .../test/unit/test_deltanet_decay.py | 96 +++++------ 3 files changed, 136 insertions(+), 135 deletions(-) diff --git a/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py b/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py index 19926dcc..505a8fb3 100644 --- a/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py +++ b/contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py @@ -126,6 +126,13 @@ def randn(shape: tuple[int, ...], scale: float) -> Any: } +def stable_causal_decay(torch: Any, gc: Any, mask: Any) -> Any: + """Compute exp(gc[i] - gc[j]) only where the causal mask is active.""" + diff = gc - gc.T + masked_diff = torch.where(mask.bool(), diff, torch.zeros_like(diff)) + return torch.exp(masked_diff) * mask + + def reference_math(torch: Any, inputs: dict[str, Any]) -> tuple[Any, Any]: lower = inputs["lower_mask"] lower_diag = inputs["lower_mask_diag"] @@ -146,26 +153,19 @@ def reference_math(torch: Any, inputs: dict[str, Any]) -> tuple[Any, Any]: k_beta = k * beta v_beta = v * beta - decay = torch.exp(gc - gc.T) - decay_strict = decay * lower - decay_diag = decay * lower_diag + decay_strict = stable_causal_decay(torch, gc, lower) + decay_diag = stable_causal_decay(torch, gc, lower_diag) qk_beta = k_beta @ k.T a_mat = -(qk_beta * decay_strict) * lower - # Mirror the fused kernel: Neumann power-doubling, not triangular solve. 
- p_acc = eye + a_mat - a_pow = a_mat.clone() - for _ in range(6): - a_pow = (a_pow @ a_pow) * lower - p_acc = ((eye + a_pow) @ p_acc) * lower_diag + lhs = eye - a_mat exp_gc = torch.exp(gc) - value_corr = p_acc @ v_beta - k_cumdecay = p_acc @ (k_beta * exp_gc) + solve_rhs = v_beta - ((k_beta * exp_gc) @ state) + v_new = torch.linalg.solve_triangular(lhs, solve_rhs, upper=False) attn_intra = (q @ k.T) * decay_diag - v_new = value_corr - (k_cumdecay @ state) chunk_out = ((q * exp_gc) @ state) + (attn_intra @ v_new) outputs.append(chunk_out) diff --git a/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py index f6b6e0ee..d3f1638b 100644 --- a/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py +++ b/contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py @@ -17,16 +17,12 @@ Chunk size = 128 = P_MAX (one tile per chunk). Mathematical framework: - Per-chunk Neumann-series power-doubling for intra-chunk correction: + Per-chunk direct triangular solve for intra-chunk correction: QK_decay[i,j] = QK[i,j] * exp(gc[i] - gc[j]) for i > j A = -QK_decay * lower_mask - N = (I+A)(I+A^2)(I+A^4)...(I+A^64) [6 rounds] - value_corr = N @ v_beta - k_cumdecay = N @ (k_beta * exp(gc)) + v_new = solve((I - A), v_beta - (k_beta * exp(gc)) @ state) Inter-chunk state propagation: - v_prime = k_cumdecay @ state - v_new = value_corr - v_prime attn_inter = (q * exp(gc)) @ state attn_intra = (q @ k^T) * decay_mask * lower_mask_diag output = attn_inter + attn_intra @ v_new @@ -399,56 +395,20 @@ def deltanet_fused_chunked_fwd( nisa.tensor_tensor(dst=A_mat, data1=neg_QK_decay, data2=Lmask, op=nl.multiply) # ============================================================ - # Neumann power-doubling: N = (I+A)(I+A^2)...(I+A^{64}) - # 6 rounds → resolves rank up to 2^6 = 64 (sufficient for chunk=128) - # ============================================================ - P_acc = nl.ndarray((P_MAX, P_MAX), 
dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_tensor(dst=P_acc, data1=eye, data2=A_mat, op=nl.add) - - A_pow = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_copy(dst=A_pow, src=A_mat) - - for _round in nl.sequential_range(6): - # A_pow = A_pow^2: transpose A_pow, then matmul - Ap_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) - nisa.nc_transpose(dst=Ap_T_psum, data=A_pow) - Ap_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_copy(dst=Ap_T, src=Ap_T_psum) - - Ap_sq_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) - nisa.nc_matmul(dst=Ap_sq_psum, stationary=Ap_T, moving=A_pow) - nisa.tensor_copy(dst=A_pow, src=Ap_sq_psum) - nisa.tensor_tensor(dst=A_pow, data1=A_pow, data2=Lmask, op=nl.multiply) - - # P_acc = (I + A_pow) @ P_acc: transpose IpA, then matmul - IpA = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_tensor(dst=IpA, data1=eye, data2=A_pow, op=nl.add) - - IpA_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) - nisa.nc_transpose(dst=IpA_T_psum, data=IpA) - IpA_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_copy(dst=IpA_T, src=IpA_T_psum) - - Pacc_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) - nisa.nc_matmul(dst=Pacc_psum, stationary=IpA_T, moving=P_acc) - nisa.tensor_copy(dst=P_acc, src=Pacc_psum) - nisa.tensor_tensor(dst=P_acc, data1=P_acc, data2=Lmask_d, op=nl.multiply) - - # ============================================================ - # Apply N: value_corr = N @ v_beta - # k_cumdecay = N @ (k_beta * exp(gc)) + # Build the single RHS needed for v_new. + # + # Materializing N = inv(I - A) would compute: + # value_corr = N @ v_beta + # k_cumdecay = N @ (k_beta * exp(gc)) + # v_new = value_corr - k_cumdecay @ state + # + # By associativity: + # v_new = N @ (v_beta - (k_beta * exp(gc)) @ state) + # + # Solve this RHS directly. 
This is equivalent to the nilpotent + # Neumann series, but avoids repeated matrix squaring, which is + # numerically unstable for realistic Qwen decay gates. # ============================================================ - N_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) - nisa.nc_transpose(dst=N_T_psum, data=P_acc) - N_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_copy(dst=N_T, src=N_T_psum) - - vc_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) - nisa.nc_matmul(dst=vc_psum, stationary=N_T, moving=v_beta) - value_corr = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_copy(dst=value_corr, src=vc_psum) - - # k_beta * exp(gc): row-scaled kb_exp_gc = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) nisa.tensor_scalar( dst=kb_exp_gc, @@ -458,10 +418,67 @@ def deltanet_fused_chunked_fwd( engine=nisa.vector_engine, ) - kcd_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) - nisa.nc_matmul(dst=kcd_psum, stationary=N_T, moving=kb_exp_gc) - k_cumdecay = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_copy(dst=k_cumdecay, src=kcd_psum) + kbe_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=kbe_T_psum, data=kb_exp_gc) + kbe_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kbe_T, src=kbe_T_psum) + + kbe_state_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=kbe_state_psum, stationary=kbe_T, moving=state) + kbe_state = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=kbe_state, src=kbe_state_psum) + + solve_rhs = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=solve_rhs, data1=v_beta, data2=kbe_state, op=nl.subtract) + + # ============================================================ + # Direct forward substitution for: + # v_new = solve((I - 
A_mat), solve_rhs) + # + # A_mat is strictly lower triangular, so row i only depends on rows + # < i. The full-matmul plus row-select form keeps the shape static + # and compiler-safe while updating exactly one solved row per step. + # ============================================================ + v_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.memset(dst=v_new, value=0.0) + + A_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) + nisa.nc_transpose(dst=A_T_psum, data=A_mat) + A_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=A_T, src=A_T_psum) + + for solve_i in nl.static_range(P_MAX): + row_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) + nisa.nc_matmul(dst=row_psum, stationary=A_T, moving=v_new) + row_prod = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy(dst=row_prod, src=row_psum) + + row_with_rhs = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor( + dst=row_with_rhs, + data1=row_prod, + data2=solve_rhs, + op=nl.add, + ) + + row_mask = nl.ndarray((P_MAX, 1), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_copy( + dst=row_mask[0:P_MAX, 0:1], + src=eye[0:P_MAX, solve_i : solve_i + 1], + ) + + row_update = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_scalar( + dst=row_update, + data=row_with_rhs, + op0=nl.multiply, + operand0=row_mask, + engine=nisa.vector_engine, + ) + + v_next = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) + nisa.tensor_tensor(dst=v_next, data1=v_new, data2=row_update, op=nl.add) + nisa.tensor_copy(dst=v_new, src=v_next) # ============================================================ # Phase 2: Inter-chunk state propagation @@ -486,22 +503,6 @@ def deltanet_fused_chunked_fwd( dst=attn_intra, data1=qk_decay, data2=Lmask_d, op=nl.multiply ) - # ============================================================ - # v_prime = k_cumdecay @ state (state 
is in SBUF!) - # ============================================================ - kcd_T_psum = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.psum) - nisa.nc_transpose(dst=kcd_T_psum, data=k_cumdecay) - kcd_T = nl.ndarray((P_MAX, P_MAX), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_copy(dst=kcd_T, src=kcd_T_psum) - - vp_psum = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.psum) - nisa.nc_matmul(dst=vp_psum, stationary=kcd_T, moving=state) - v_prime = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_copy(dst=v_prime, src=vp_psum) - - v_new = nl.ndarray((P_MAX, dim), dtype=nl.float32, buffer=nl.sbuf) - nisa.tensor_tensor(dst=v_new, data1=value_corr, data2=v_prime, op=nl.subtract) - # ============================================================ # attn_inter = (q * exp(gc)) @ state (state is in SBUF!) # ============================================================ diff --git a/contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py b/contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py index 416a431a..80927b6f 100644 --- a/contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py +++ b/contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py @@ -1,67 +1,67 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 -"""Unit tests for fused DeltaNet log-decay bounding.""" +"""CPU-only regressions for the fused DeltaNet decay reference math.""" +import importlib.util import os -import sys +import types import unittest import torch -_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) -if _CONTRIB_ROOT not in sys.path: - sys.path.insert(0, _CONTRIB_ROOT) -from src.modeling_qwen35 import ( - FUSED_DELTANET_DECAY_MAX, - FUSED_DELTANET_DECAY_MIN, - _bound_fused_deltanet_log_decay, +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +_VALIDATOR_PATH = os.path.join( + _CONTRIB_ROOT, + "scripts", + "validate_deltanet_fused_nki.py", ) -def _chunked_cumsum(g, batch_size, num_heads, total_seq_len, chunk_size): - num_chunks = total_seq_len // chunk_size - return g.reshape(batch_size, num_heads, num_chunks, chunk_size).cumsum(dim=-1) - - -class TestFusedDeltaNetDecayBounding(unittest.TestCase): - def test_preserves_non_extreme_decay(self): - batch_size, num_heads, total_seq_len, chunk_size = 2, 3, 16, 8 - g = torch.full( - (batch_size, num_heads, total_seq_len), - -0.125, - dtype=torch.float32, - ) - - bounded = _bound_fused_deltanet_log_decay( - g, batch_size, num_heads, total_seq_len, chunk_size +def _load_validator(): + spec = importlib.util.spec_from_file_location( + "qwen36_validate_deltanet_fused_nki", + _VALIDATOR_PATH, + ) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + spec.loader.exec_module(module) + return module + + +class TestFusedDeltaNetDecayMath(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.validator = _load_validator() + + def test_stable_causal_decay_masks_before_exp(self): + gc = torch.linspace(0.0, -300.0, 128, dtype=torch.float32).reshape(128, 1) + lower = torch.tril(torch.ones((128, 128), dtype=torch.float32), diagonal=-1) + lower_diag = torch.tril(torch.ones((128, 128), dtype=torch.float32)) + + strict_decay = 
self.validator.stable_causal_decay(torch, gc, lower) + diag_decay = self.validator.stable_causal_decay(torch, gc, lower_diag) + + self.assertTrue(torch.isfinite(strict_decay).all()) + self.assertTrue(torch.isfinite(diag_decay).all()) + self.assertTrue(torch.equal(strict_decay.triu(), torch.zeros_like(strict_decay.triu()))) + torch.testing.assert_close(torch.diagonal(diag_decay), torch.ones(128)) + + def test_reference_math_is_finite_for_realistic_gate_scale(self): + args = types.SimpleNamespace( + seed=1234, + seq_len=256, + value_scale=0.05, + state_scale=0.01, + gate_scale=1.0, ) - torch.testing.assert_close(bounded, g) - - def test_bounds_per_chunk_cumulative_decay(self): - batch_size, num_heads, total_seq_len, chunk_size = 2, 3, 16, 8 - g = torch.full( - (batch_size, num_heads, total_seq_len), - -10.0, - dtype=torch.float32, - ) + inputs = self.validator.make_inputs(torch, args) + output, state = self.validator.reference_math(torch, inputs) - bounded = _bound_fused_deltanet_log_decay( - g, batch_size, num_heads, total_seq_len, chunk_size - ) - bounded_cumsum = _chunked_cumsum( - bounded, batch_size, num_heads, total_seq_len, chunk_size - ) - expected_cumsum = _chunked_cumsum( - g, batch_size, num_heads, total_seq_len, chunk_size - ).clamp(min=FUSED_DELTANET_DECAY_MIN, max=FUSED_DELTANET_DECAY_MAX) - - torch.testing.assert_close(bounded_cumsum, expected_cumsum) - self.assertGreaterEqual(float(bounded_cumsum.min()), FUSED_DELTANET_DECAY_MIN) - self.assertLessEqual(float(bounded_cumsum.max()), FUSED_DELTANET_DECAY_MAX) - self.assertTrue(torch.isfinite(bounded).all()) + self.assertTrue(torch.isfinite(output).all()) + self.assertTrue(torch.isfinite(state).all()) if __name__ == "__main__": From e938d71dd6ac64128445961e070bc3a189246678 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Fri, 22 May 2026 19:38:47 +0530 Subject: [PATCH 10/21] Record Qwen fused direct-solve validation --- .../README.md | 61 +++++++++++++ .../context_sweep_partial_20260522T1348Z.json | 
69 +++++++++++++++ ...ctsolve_chat_coherence_20260522T1332Z.json | 56 ++++++++++++ ...rectsolve_decode_bench_20260522T1348Z.json | 85 +++++++++++++++++++ ...rectsolve_perf_capture_20260522T1348Z.json | 71 ++++++++++++++++ 5 files changed, 342 insertions(+) create mode 100644 profile_artifacts/qwen36_fused_directsolve_20260522/README.md create mode 100644 profile_artifacts/qwen36_fused_directsolve_20260522/context_sweep_partial_20260522T1348Z.json create mode 100644 profile_artifacts/qwen36_fused_directsolve_20260522/qwen36_directsolve_chat_coherence_20260522T1332Z.json create mode 100644 profile_artifacts/qwen36_fused_directsolve_20260522/qwen36_directsolve_decode_bench_20260522T1348Z.json create mode 100644 profile_artifacts/qwen36_fused_directsolve_20260522/qwen36_directsolve_perf_capture_20260522T1348Z.json diff --git a/profile_artifacts/qwen36_fused_directsolve_20260522/README.md b/profile_artifacts/qwen36_fused_directsolve_20260522/README.md new file mode 100644 index 00000000..4e213f8e --- /dev/null +++ b/profile_artifacts/qwen36_fused_directsolve_20260522/README.md @@ -0,0 +1,61 @@ +# Qwen3.6 Fused Direct-Solve Validation + +Validation target: + +- Branch commit: `ae2613d` source, replayed onto PR 164 clean branch +- Artifact: `qwen36_27b_128k_fp8_mlp_edgebf16_hybrid_apc_nki_fusedstable_directsolve_retry_b256_cte256_512_pfx16k_slots64_tkg8192_32768_131072_async_20260522T130050Z` +- Runtime host: `trn2.3xlarge` +- Runtime path: offline vLLM/NxDI, on-device greedy sampling + +## Summary + +The fused DeltaNet CTE path now uses a direct triangular RHS solve instead of Neumann power-doubling. This fixes the fused-kernel instability observed with realistic Qwen gates while keeping the fused CTE path available for validation. 
+ +## Coherence + +`qwen36_directsolve_chat_coherence_20260522T1332Z.json` + +- Overall pass: `true` +- Chat template used `enable_thinking=false` +- Fact, code, and prefix-cache prompts produced non-repetitive real text +- Smoke decode throughput: about `20.5 tok/s` + +## Decode + +`qwen36_directsolve_decode_bench_20260522T1348Z.json` + +- Average decode throughput: `21.63 tok/s` +- TPOT: `46.2 ms/token` +- 128-token decode average latency: `5.92 s` +- Artifact uses on-device greedy sampling with `output_logits=false` + +## Cold And Warm Prefill + +`context_sweep_partial_20260522T1348Z.json` + +| Prompt tokens | Cold TTFT | Cold prefill | Warm TTFT | Warm prefill | +| ---: | ---: | ---: | ---: | ---: | +| 512 | 1.31 s | 390 tok/s | 0.42 s | 1.2k tok/s | +| 4096 | 7.03 s | 582 tok/s | 0.42 s | 9.8k tok/s | +| 8192 | 13.61 s | 602 tok/s | 0.43 s | 18.9k tok/s | +| 16384 | 27.84 s | 589 tok/s | 0.45 s | 36.3k tok/s | + +The 32K row did not complete because this artifact was compiled with `prefix_buckets` only through `16384`: + +```text +Prefix len 16640 exceeds largest bucket 16384 for context_encoding_model +``` + +That is an artifact bucket coverage limitation, not a direct-solve correctness failure. + +## Memory + +`qwen36_directsolve_perf_capture_20260522T1348Z.json` + +- Neuron HBM peak sum: `60.1 GiB` +- Host process RSS peak: `46.3 GiB` +- Main logical cores peaked around `14.57 GiB` each on cores `0`, `2`, `4`, and `6` + +## Follow-Up + +For long-context validation, recompile the same branch with prefix buckets beyond `16384`, ideally through the intended 64K/128K validation range. 
diff --git a/profile_artifacts/qwen36_fused_directsolve_20260522/context_sweep_partial_20260522T1348Z.json b/profile_artifacts/qwen36_fused_directsolve_20260522/context_sweep_partial_20260522T1348Z.json new file mode 100644 index 00000000..776f2d85 --- /dev/null +++ b/profile_artifacts/qwen36_fused_directsolve_20260522/context_sweep_partial_20260522T1348Z.json @@ -0,0 +1,69 @@ +{ + "artifact": "/mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_128k_fp8_mlp_edgebf16_hybrid_apc_nki_fusedstable_directsolve_retry_b256_cte256_512_pfx16k_slots64_tkg8192_32768_131072_async_20260522T130050Z", + "artifact_config": { + "context_encoding_buckets": [ + 256, + 512 + ], + "pa_num_blocks": 512, + "prefix_buckets": [ + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384 + ], + "seq_len": 131072, + "tkg_batch_size": 1, + "token_generation_buckets": [ + 8192, + 32768, + 131072 + ] + }, + "failed_at": { + "error": "Prefix len 16640 exceeds largest bucket 16384 for context_encoding_model", + "target_prompt_tokens": 32768 + }, + "max_tokens": 1, + "rows": [ + { + "cold_effective_prompt_tokens_per_second": 390.3766293031504, + "cold_ttft_seconds": 1.311553923998872, + "repeat_exact": true, + "target_prompt_tokens": 512, + "token_range_passed": true, + "warm_effective_prompt_tokens_per_second": 1214.4647409237177, + "warm_ttft_seconds": 0.4215849030006211 + }, + { + "cold_effective_prompt_tokens_per_second": 582.2723653461799, + "cold_ttft_seconds": 7.034508665999965, + "repeat_exact": true, + "target_prompt_tokens": 4096, + "token_range_passed": true, + "warm_effective_prompt_tokens_per_second": 9832.557853165847, + "warm_ttft_seconds": 0.4165752249991783 + }, + { + "cold_effective_prompt_tokens_per_second": 602.0664367066421, + "cold_ttft_seconds": 13.606471811999654, + "repeat_exact": true, + "target_prompt_tokens": 8192, + "token_range_passed": true, + "warm_effective_prompt_tokens_per_second": 18907.071820446214, + "warm_ttft_seconds": 0.4332770340006391 + }, + { + 
"cold_effective_prompt_tokens_per_second": 588.6039525245649, + "cold_ttft_seconds": 27.835355045999677, + "repeat_exact": true, + "target_prompt_tokens": 16384, + "token_range_passed": true, + "warm_effective_prompt_tokens_per_second": 36293.66416713788, + "warm_ttft_seconds": 0.4514286550001998 + } + ] +} diff --git a/profile_artifacts/qwen36_fused_directsolve_20260522/qwen36_directsolve_chat_coherence_20260522T1332Z.json b/profile_artifacts/qwen36_fused_directsolve_20260522/qwen36_directsolve_chat_coherence_20260522T1332Z.json new file mode 100644 index 00000000..7123f598 --- /dev/null +++ b/profile_artifacts/qwen36_fused_directsolve_20260522/qwen36_directsolve_chat_coherence_20260522T1332Z.json @@ -0,0 +1,56 @@ +{ + "artifact": "/mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_128k_fp8_mlp_edgebf16_hybrid_apc_nki_fusedstable_directsolve_retry_b256_cte256_512_pfx16k_slots64_tkg8192_32768_131072_async_20260522T130050Z", + "branch": "qwen-fused-neumann-stable-decay", + "commit": "ae2613d", + "pass": true, + "rows": [ + { + "checks": { + "code_has_def_or_loop": true, + "fact_mentions_paris": true, + "has_text": true, + "not_all_bang": true, + "not_repetitive": true, + "prefix_mentions_cache": true + }, + "completion_tokens": 64, + "elapsed_s": 3.126171286000499, + "label": "fact", + "pass": true, + "text": " Paris. 
\ud83c\uddeb\ud83c\uddf7\ud83d\uddfc\ud83c\udf77\ud83e\udd56\ud83c\udfa8\ud83c\udfdb\ufe0f\ud83c\udf0d\ud83d\udc51\ud83d\udcdc\ud83c\udfad\ud83c\udf7d\ufe0f\ud83d\udeb2\ud83c\udf33\ud83c\udff0\ud83c\udfb6\ud83d\udcda\ud83c\udfaa\ud83c\udfa8\ud83c\udfad", + "tok_s": 20.47232673609484 + }, + { + "checks": { + "code_has_def_or_loop": true, + "fact_mentions_paris": true, + "has_text": true, + "not_all_bang": true, + "not_repetitive": true, + "prefix_mentions_cache": true + }, + "completion_tokens": 64, + "elapsed_s": 3.1215076250000493, + "label": "code", + "pass": true, + "text": "def fib(n):\n if n <= 0:\n return 0\n elif n == 1:\n return 1\n \n a, b = 0, 1\n for _ in range(2, n + 1):\n a, b = b, a", + "tok_s": 20.502913235715383 + }, + { + "checks": { + "code_has_def_or_loop": true, + "fact_mentions_paris": true, + "has_text": true, + "not_all_bang": true, + "not_repetitive": true, + "prefix_mentions_cache": true + }, + "completion_tokens": 64, + "elapsed_s": 3.123161378000077, + "label": "prefix_cache", + "pass": true, + "text": "* **Avoids Redundant Computation: Prefix caching stores the hidden states (activations) of previously processed tokens, allowing the model to skip re-computing the attention and feed-forward layers for identical prompt prefixes.\n* **Reduces Memory and Latency: By reusing cached results, the", + "tok_s": 20.492056686799366 + } + ] +} diff --git a/profile_artifacts/qwen36_fused_directsolve_20260522/qwen36_directsolve_decode_bench_20260522T1348Z.json b/profile_artifacts/qwen36_fused_directsolve_20260522/qwen36_directsolve_decode_bench_20260522T1348Z.json new file mode 100644 index 00000000..1e984319 --- /dev/null +++ b/profile_artifacts/qwen36_fused_directsolve_20260522/qwen36_directsolve_decode_bench_20260522T1348Z.json @@ -0,0 +1,85 @@ +{ + "artifact": 
"/mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_128k_fp8_mlp_edgebf16_hybrid_apc_nki_fusedstable_directsolve_retry_b256_cte256_512_pfx16k_slots64_tkg8192_32768_131072_async_20260522T130050Z", + "artifact_neuron_config": { + "context_encoding_buckets": [ + 256, + 512 + ], + "ctx_batch_size": 1, + "max_context_length": 131072, + "max_length": 131072, + "on_device_sampling_config": { + "deterministic": false, + "do_sample": false, + "dynamic": false, + "global_topk": 256, + "on_device_sampling_config": true, + "sampling_dp_degree": 1, + "temperature": 1.0, + "top_k": 1, + "top_k_kernel_enabled": true, + "top_p": 1.0 + }, + "output_logits": false, + "pa_block_size": 256, + "pa_num_blocks": 512, + "prefix_buckets": [ + 256, + 512, + 1024, + 2048, + 4096, + 8192, + 16384 + ], + "seq_len": 131072, + "tkg_batch_size": 1, + "token_generation_buckets": [ + 8192, + 32768, + 131072 + ] + }, + "async_mode": true, + "avg_elapsed_s": 5.918099790499582, + "avg_tok_s": 21.628582366108347, + "cte_buckets": [ + 256, + 512 + ], + "max_model_len": 8192, + "max_num_seqs": 1, + "max_tokens": 128, + "model_path": "/home/ubuntu/models/Qwen3.6-27B", + "pa_num_blocks": 512, + "prompt": "Explain prefix caching in two concise paragraphs.", + "runs": [ + { + "completion_tokens": 128, + "elapsed_s": 5.923555245999523, + "run": 1, + "text": "\n\nPrefix caching is an optimization technique used in large language model inference that significantly reduces computational overhead by reusing previously computed key-value (K/V) states. When a user's prompt shares a common prefix with a previous request\u2014such as a standard system instruction or a shared codebase\u2014the model can bypass the redundant attention calculations for those identical tokens. 
Instead of recomputing the K/V pairs for the shared context, the system simply loads the cached results from memory, allowing the model to focus its processing power exclusively on the new, unique tokens.\n\nThis approach dramatically accelerates response times and improves throughput, particularly in scenarios involving repetitive prompts or", + "tok_s": 21.608644586617956 + }, + { + "completion_tokens": 128, + "elapsed_s": 5.912644334999641, + "run": 2, + "text": "\n\nPrefix caching is an optimization technique used in large language model inference that significantly reduces computational overhead by reusing previously computed key-value (K/V) states. When a user's prompt shares a common prefix with a previous request\u2014such as a standard system instruction or a shared codebase\u2014the model can bypass the redundant attention calculations for those identical tokens. Instead of recomputing the K/V pairs for the shared context, the system simply loads the cached results from memory, allowing the model to focus its processing power exclusively on the new, unique tokens.\n\nThis approach dramatically accelerates response times and improves throughput, particularly in scenarios involving repetitive prompts or", + "tok_s": 21.64852014559874 + } + ], + "seq_len": 131072, + "token_generation_batches": null, + "token_generation_buckets": [ + 8192, + 32768, + 131072 + ], + "warmup": { + "completion_tokens": 8, + "elapsed_s": 1.13700048300052, + "text": "\n\nPrefix caching is an optimization technique used", + "tok_s": 7.036056817573349 + } +} diff --git a/profile_artifacts/qwen36_fused_directsolve_20260522/qwen36_directsolve_perf_capture_20260522T1348Z.json b/profile_artifacts/qwen36_fused_directsolve_20260522/qwen36_directsolve_perf_capture_20260522T1348Z.json new file mode 100644 index 00000000..1fa6af85 --- /dev/null +++ b/profile_artifacts/qwen36_fused_directsolve_20260522/qwen36_directsolve_perf_capture_20260522T1348Z.json @@ -0,0 +1,71 @@ +{ + "artifact": 
"/mnt/trainium_artifacts/qwen_artifacts/qwen36_27b_128k_fp8_mlp_edgebf16_hybrid_apc_nki_fusedstable_directsolve_retry_b256_cte256_512_pfx16k_slots64_tkg8192_32768_131072_async_20260522T130050Z", + "context_json": "/home/ubuntu/validation_logs/fp8_128k/qwen36_directsolve_context_sweep_20260522T1348Z.json", + "decode_json": "/home/ubuntu/validation_logs/fp8_128k/qwen36_directsolve_decode_bench_20260522T1348Z.json", + "phases": [ + { + "log_path": "/home/ubuntu/validation_logs/fp8_128k/qwen36_directsolve_context_sweep_20260522T1348Z.log", + "memory": { + "max_device_mem_peak_sum_bytes": 64543326208, + "max_device_mem_present_sum_bytes": 1302872064, + "max_host_process_rss_kb": 48594600, + "max_peak_per_core_bytes": { + "neuron_core0": 15642320896, + "neuron_core1": 493510656, + "neuron_core2": 15642320896, + "neuron_core3": 493510656, + "neuron_core4": 15642320896, + "neuron_core5": 493510656, + "neuron_core6": 15642320896, + "neuron_core7": 493510656 + }, + "max_present_per_core_bytes": { + "neuron_core0": 67239936, + "neuron_core1": 32448512, + "neuron_core2": 57049088, + "neuron_core3": 56238080, + "neuron_core4": 635699200, + "neuron_core5": 8896512, + "neuron_core6": 635699200, + "neuron_core7": 8896512 + }, + "samples": 256 + }, + "memory_log_path": "/home/ubuntu/validation_logs/fp8_128k/qwen36_directsolve_context_sweep_memory_20260522T1348Z.jsonl", + "name": "context_sweep", + "returncode": 1 + }, + { + "log_path": "/home/ubuntu/validation_logs/fp8_128k/qwen36_directsolve_decode_bench_20260522T1348Z.log", + "memory": { + "max_device_mem_peak_sum_bytes": 64543326208, + "max_device_mem_present_sum_bytes": 1360543744, + "max_host_process_rss_kb": 48343516, + "max_peak_per_core_bytes": { + "neuron_core0": 15642320896, + "neuron_core1": 493510656, + "neuron_core2": 15642320896, + "neuron_core3": 493510656, + "neuron_core4": 15642320896, + "neuron_core5": 493510656, + "neuron_core6": 15642320896, + "neuron_core7": 493510656 + }, + "max_present_per_core_bytes": { + 
"neuron_core0": 67239936, + "neuron_core1": 25825280, + "neuron_core2": 44564480, + "neuron_core3": 32448512, + "neuron_core4": 635699200, + "neuron_core5": 32448512, + "neuron_core6": 635699200, + "neuron_core7": 8896512 + }, + "samples": 119 + }, + "memory_log_path": "/home/ubuntu/validation_logs/fp8_128k/qwen36_directsolve_decode_bench_memory_20260522T1348Z.jsonl", + "name": "decode_bench", + "returncode": 0 + } + ] +} From 4c44489783c9f04c7be98574f196aabec825d24f Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Fri, 22 May 2026 19:42:24 +0530 Subject: [PATCH 11/21] Document Qwen fused direct-solve PR delta --- .../PR_BODY.md | 37 +++++++++++++++++++ .../README.md | 15 ++++++++ 2 files changed, 52 insertions(+) create mode 100644 profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md diff --git a/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md b/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md new file mode 100644 index 00000000..0bd34fae --- /dev/null +++ b/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md @@ -0,0 +1,37 @@ +# Qwen3.6 Fused DeltaNet Direct-Solve Follow-Up + +This is a clean branch on top of PR 164, `contrib/qwen36-27b-vllm-apc-pr` at `ac7df71`. It is meant to extend the existing vLLM APC PR without bringing in the full experimental branch stack. + +## What Changed On Top Of PR 164 + +- Stabilized the Qwen fused DeltaNet CTE kernel. +- Added an isolated fused NKI validation script. +- Made the validator load the fused kernel directly. +- Replaced the fused kernel's Neumann power-doubling solve with a direct triangular RHS solve. +- Updated CPU DeltaNet decay regression coverage for realistic gate scales. +- Added compact validation artifacts for coherence, decode, prefill, and memory. + +## Why + +The previous fused path could produce unstable or incoherent outputs with realistic Qwen gate values. 
The Neumann power-doubling solve is mathematically convenient, but it forms repeated full-matrix powers and is numerically fragile for this recurrence. The direct triangular RHS solve computes the causal recurrence without those large intermediate powers and matches the stable chunked-kernel approach. + +## Validation + +Local checks: + +- `python3 -m py_compile contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py` +- `python3 -m pytest contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py -q` +- Result: `2 passed` + +Trn2 artifact validation: + +- Coherence pass: `true` +- Decode throughput: `21.63 tok/s` +- TPOT: `46.2 ms/token` +- Cold prefill: about `590 tok/s` from 4K through 16K +- Warm prefill: up to `36.3k tok/s` at 16K with APC reuse +- HBM peak sum: `60.1 GiB` + +Known limitation: + +- The compiled artifact used for this validation has `prefix_buckets` through `16384`; the 32K sweep failed with `Prefix len 16640 exceeds largest bucket 16384`. A long-context follow-up compile needs larger prefix buckets. diff --git a/profile_artifacts/qwen36_fused_directsolve_20260522/README.md b/profile_artifacts/qwen36_fused_directsolve_20260522/README.md index 4e213f8e..b18f188e 100644 --- a/profile_artifacts/qwen36_fused_directsolve_20260522/README.md +++ b/profile_artifacts/qwen36_fused_directsolve_20260522/README.md @@ -11,6 +11,21 @@ Validation target: The fused DeltaNet CTE path now uses a direct triangular RHS solve instead of Neumann power-doubling. This fixes the fused-kernel instability observed with realistic Qwen gates while keeping the fused CTE path available for validation. +## Delta From PR 164 + +This branch is intentionally based on the current vLLM APC PR head, `contrib/qwen36-27b-vllm-apc-pr` at `ac7df71`. It does not include the full experimental branch history. 
+ +The clean branch adds only the fused DeltaNet follow-up work on top of PR 164: + +- Stabilizes the Qwen fused DeltaNet CTE kernel implementation in `nki_deltanet_fused.py`. +- Adds an isolated fused NKI validator for realistic Qwen-style gate/decay coverage. +- Loads the fused kernel directly in the validator so it can run outside package import edge cases. +- Replaces the fused kernel's Neumann power-doubling solve with the same direct triangular RHS solve strategy used by the stable chunked path. +- Updates the CPU DeltaNet decay regression test to cover realistic gate scales and direct-solve behavior. +- Records the direct-solve artifact validation results in this directory. + +The already-open vLLM APC PR remains the base contribution for the Qwen3.6 model, vLLM APC integration, docs, and baseline benchmark material. This branch is the proposed add-on that makes the fused DeltaNet path coherent and measurable. + ## Coherence `qwen36_directsolve_chat_coherence_20260522T1332Z.json` From b609c8708b6f085745ae3d31866d853784f1f51d Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Fri, 22 May 2026 19:44:12 +0530 Subject: [PATCH 12/21] Clarify Qwen fused direct-solve branch lineage --- .../PR_BODY.md | 27 ++++++++++++++++- .../README.md | 29 ++++++++++++++++--- 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md b/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md index 0bd34fae..4a5ede72 100644 --- a/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md +++ b/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md @@ -1,9 +1,32 @@ # Qwen3.6 Fused DeltaNet Direct-Solve Follow-Up -This is a clean branch on top of PR 164, `contrib/qwen36-27b-vllm-apc-pr` at `ac7df71`. It is meant to extend the existing vLLM APC PR without bringing in the full experimental branch stack. +This is a clean extraction on top of PR 164, `contrib/qwen36-27b-vllm-apc-pr` at `ac7df71`. 
It is meant to show the fused DeltaNet direct-solve follow-up without bringing in the full experimental branch stack. + +## Branch Lineage + +The actual development history was: + +```text +PR 164 / vLLM APC baseline + -> experimental + -> qwen-fused-neumann-stable-decay +``` + +The `experimental` branch added substantial runtime and validation work after PR 164: + +- Hybrid APC checkpoint cache, lifecycle, restore/commit masks, and strict metadata contracts. +- vLLM/NxDI scheduler bridge changes for cached chunked prefill, backed prefix reads, request-id propagation, and suffix continuation handling. +- Qwen chunked prefill fixes for CTE bucket alignment, prefix-cache slot mapping, GDN checkpoint commits, and chunk-boundary handling. +- FP8 128K artifact configuration guards, validation max-prompt alignment, and artifact audit checks. +- OpenAI/vLLM validation harnesses for exactness, context sweeps, TTFT/TPOT, decode benchmarking, memory capture, and API compatibility. +- Decode-path and sampling fixes, including on-device sampling/logits-path validation and chat-template thinking controls. + +The final fused branch added the direct-solve fused DeltaNet fix on top of that experimental runtime stack. ## What Changed On Top Of PR 164 +This clean branch extracts only the fused DeltaNet follow-up commits: + - Stabilized the Qwen fused DeltaNet CTE kernel. - Added an isolated fused NKI validation script. - Made the validator load the fused kernel directly. @@ -11,6 +34,8 @@ This is a clean branch on top of PR 164, `contrib/qwen36-27b-vllm-apc-pr` at `ac - Updated CPU DeltaNet decay regression coverage for realistic gate scales. - Added compact validation artifacts for coherence, decode, prefill, and memory. +The artifact results below were produced from `qwen-fused-neumann-stable-decay`, so they validate the direct-solve fused kernel inside the full `experimental` lineage. 
They should not be read as proof that PR 164 plus only these extracted commits reproduces every Hybrid APC runtime fix from `experimental`. + ## Why The previous fused path could produce unstable or incoherent outputs with realistic Qwen gate values. The Neumann power-doubling solve is mathematically convenient, but it forms repeated full-matrix powers and is numerically fragile for this recurrence. The direct triangular RHS solve computes the causal recurrence without those large intermediate powers and matches the stable chunked-kernel approach. diff --git a/profile_artifacts/qwen36_fused_directsolve_20260522/README.md b/profile_artifacts/qwen36_fused_directsolve_20260522/README.md index b18f188e..d9f3708b 100644 --- a/profile_artifacts/qwen36_fused_directsolve_20260522/README.md +++ b/profile_artifacts/qwen36_fused_directsolve_20260522/README.md @@ -11,11 +11,32 @@ Validation target: The fused DeltaNet CTE path now uses a direct triangular RHS solve instead of Neumann power-doubling. This fixes the fused-kernel instability observed with realistic Qwen gates while keeping the fused CTE path available for validation. -## Delta From PR 164 +## Lineage From PR 164 -This branch is intentionally based on the current vLLM APC PR head, `contrib/qwen36-27b-vllm-apc-pr` at `ac7df71`. It does not include the full experimental branch history. +The development lineage was: -The clean branch adds only the fused DeltaNet follow-up work on top of PR 164: +```text +PR 164 / vLLM APC baseline + -> experimental + -> qwen-fused-neumann-stable-decay +``` + +The `experimental` branch accumulated the runtime and validation work needed to make Qwen3.6 Hybrid APC usable beyond the original PR 164 baseline: + +- Hybrid APC checkpoint cache, lifecycle, restore/commit masks, and strict metadata contracts. +- vLLM/NxDI scheduler bridge changes for cached chunked prefill, backed prefix reads, request-id propagation, and suffix continuation handling. 
+- Qwen chunked prefill fixes for CTE bucket alignment, prefix-cache slot mapping, GDN checkpoint commits, and chunk-boundary handling. +- FP8 128K artifact configuration guards, validation max-prompt alignment, and artifact audit checks. +- OpenAI/vLLM validation harnesses for exactness, context sweeps, TTFT/TPOT, decode benchmarking, memory capture, and API compatibility. +- Decode-path and sampling fixes, including on-device sampling/logits-path validation and chat-template thinking controls. + +The final fused branch then adds the direct-solve fused DeltaNet follow-up on top of `experimental`. + +## Clean PR Extraction + +This clean branch is based on the current PR 164 head, `contrib/qwen36-27b-vllm-apc-pr` at `ac7df71`, and intentionally does not include the full experimental branch history. + +It extracts only the fused DeltaNet follow-up work: - Stabilizes the Qwen fused DeltaNet CTE kernel implementation in `nki_deltanet_fused.py`. - Adds an isolated fused NKI validator for realistic Qwen-style gate/decay coverage. @@ -24,7 +45,7 @@ The clean branch adds only the fused DeltaNet follow-up work on top of PR 164: - Updates the CPU DeltaNet decay regression test to cover realistic gate scales and direct-solve behavior. - Records the direct-solve artifact validation results in this directory. -The already-open vLLM APC PR remains the base contribution for the Qwen3.6 model, vLLM APC integration, docs, and baseline benchmark material. This branch is the proposed add-on that makes the fused DeltaNet path coherent and measurable. +The validation artifact referenced below was compiled from `qwen-fused-neumann-stable-decay`, so these results reflect the final fused branch running on top of the `experimental` runtime stack. 
If this clean extraction is used to extend PR 164 directly, reviewers should treat the artifact results as validation of the fused direct-solve change in the full experimental lineage, not proof that PR 164 plus these extracted commits alone reproduces every Hybrid APC runtime behavior. ## Coherence From f895246469cb0fb97d965060c141e6f9e6a7f964 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Fri, 22 May 2026 19:51:23 +0530 Subject: [PATCH 13/21] Summarize Qwen experimental branch changes --- .../PR_BODY.md | 14 ++++++++++++++ .../qwen36_fused_directsolve_20260522/README.md | 17 +++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md b/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md index 4a5ede72..8a07be52 100644 --- a/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md +++ b/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md @@ -23,6 +23,20 @@ The `experimental` branch added substantial runtime and validation work after PR The final fused branch added the direct-solve fused DeltaNet fix on top of that experimental runtime stack. +## Major Changes From PR 164 To The Tested Branch + +The full tested branch differs from PR 164 by roughly 105 source/result files. The important changes are: + +- **Hybrid APC runtime:** checkpoint cache, restore/commit masks, backed prefix reads, checkpoint-slot lifecycle, and metadata validation. +- **vLLM scheduler bridge:** request-id propagation, cached chunked-prefill continuations, active suffix accounting, no-prefix fallback handling, and backed-prefix authorization. +- **Qwen model execution:** Hybrid APC chunked prefill, GDN checkpoint commit/restore, text-only CTE inputs, compact CTE masks, prefix/suffix boundary handling, and decode-path safety. +- **NxDI prefix-cache plumbing:** vectorized APC args, prefix-cache bucket selection, padded-row safety, cached decode rows, and async checkpoint lifecycle. 
+- **DeltaNet NKI kernels:** chunked/fused validation paths, DeltaNet backend compile controls, masked Neumann experiments, and the final fused direct triangular RHS solve. +- **FP8/artifact compile path:** Qwen FP8 compile config coverage, artifact config audits, 128K validation alignment, `pa_num_blocks` checks, and larger TKG bucket support. +- **Serving/API compatibility:** OpenAI-compatible proxy/server behavior, chat-template `enable_thinking=false`, stop-sequence handling, and startup/offline helpers. +- **Validation harnesses:** exactness validation, OpenAI chat APC validation, boundary APC probes, context sweeps, offline decode benchmark, BF16 sweep, artifact audit, and memory/perf capture. +- **Tests and results:** added Hybrid APC, scheduler, model-alias, compile-config, artifact-audit, sampling, async, prefix-cache, and DeltaNet tests plus recorded performance/memory artifacts. + ## What Changed On Top Of PR 164 This clean branch extracts only the fused DeltaNet follow-up commits: diff --git a/profile_artifacts/qwen36_fused_directsolve_20260522/README.md b/profile_artifacts/qwen36_fused_directsolve_20260522/README.md index d9f3708b..bb811ab5 100644 --- a/profile_artifacts/qwen36_fused_directsolve_20260522/README.md +++ b/profile_artifacts/qwen36_fused_directsolve_20260522/README.md @@ -32,6 +32,23 @@ The `experimental` branch accumulated the runtime and validation work needed to The final fused branch then adds the direct-solve fused DeltaNet follow-up on top of `experimental`. +## Major Changes From PR 164 To The Tested Branch + +The tested branch differs from the vLLM APC PR by roughly 105 source/result files. The important changes are: + +- **Hybrid APC runtime:** added `hybrid_apc.py`, Hybrid APC request records, backed prefix reads, restore/commit masks, checkpoint-slot lifecycle handling, and stricter metadata validation. 
+- **vLLM scheduler bridge:** added `qwen36_hybrid_apc_scheduler_patch.py` and patched request-id propagation, cached chunked-prefill continuations, active suffix accounting, no-prefix fallback handling, and backed-prefix authorization. +- **Qwen model execution:** extended `modeling_qwen35.py` for Hybrid APC chunked prefill, GDN checkpoint commit/restore, text-only CTE inputs, compact CTE masks, prefix/suffix boundary handling, and decode-path safety. +- **NxDI prefix-cache plumbing:** updated `model_base.py`, `model_wrapper.py`, `async_execution.py`, and KV-cache helpers for vectorized APC args, prefix-cache bucket selection, padded-row safety, cached decode rows, and async checkpoint lifecycle. +- **DeltaNet NKI kernels:** added chunked and fused validation paths, DeltaNet backend compile controls, masked Neumann experiments, and the final fused direct triangular RHS solve. +- **FP8/artifact compile path:** expanded Qwen FP8 compile config coverage, artifact config audits, 128K/FP8 validation alignment, `pa_num_blocks` checks, and larger TKG bucket support. +- **Serving/API compatibility:** updated the OpenAI-compatible proxy/server behavior, chat-template `enable_thinking=false` handling, stop-sequence handling, and offline/server startup helpers. +- **Validation harnesses:** added exactness validation, OpenAI chat APC validation, boundary APC probes, context sweeps, offline decode benchmark, BF16 length sweep, artifact config audit, and memory/perf capture flows. +- **Tests:** added focused unit coverage for Hybrid APC manager/cache behavior, scheduler patching, model aliases, compile config, artifact config audit, sampling, async execution, prefix-cache bucket selection, and fused DeltaNet decay. +- **Result artifacts:** recorded 4K Hybrid APC TTFT/TPOT/memory results, 128K FP8 exactness/HBM estimates, decode fast-path probes, pfx128k context sweeps, and the fused direct-solve results in this directory. 
+ +This is why the clean branch separates the **result presentation** from the full experimental runtime stack: reviewers can inspect the fused direct-solve result directly, while the large runtime lineage remains explicit. + ## Clean PR Extraction This clean branch is based on the current PR 164 head, `contrib/qwen36-27b-vllm-apc-pr` at `ac7df71`, and intentionally does not include the full experimental branch history. From 6d758a26c3c96ae95f7022a42bb3dd03983b37ee Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Fri, 22 May 2026 19:54:04 +0530 Subject: [PATCH 14/21] Expand Qwen fused direct-solve PR draft --- .../PR_BODY.md | 257 +++++++++++++++--- 1 file changed, 213 insertions(+), 44 deletions(-) diff --git a/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md b/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md index 8a07be52..65eefc93 100644 --- a/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md +++ b/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md @@ -1,6 +1,10 @@ -# Qwen3.6 Fused DeltaNet Direct-Solve Follow-Up +# Qwen3.6: Fused DeltaNet Direct-Solve Follow-Up -This is a clean extraction on top of PR 164, `contrib/qwen36-27b-vllm-apc-pr` at `ac7df71`. It is meant to show the fused DeltaNet direct-solve follow-up without bringing in the full experimental branch stack. +## Summary + +This branch is a reviewer-friendly presentation of the fused DeltaNet direct-solve result for Qwen3.6. It is intentionally based on PR 164, `contrib/qwen36-27b-vllm-apc-pr` at `ac7df71`, so reviewers can see the direct-solve change and validation artifacts without also reviewing the full experimental branch stack. + +The important result is that the fused DeltaNet CTE path is now coherent with realistic Qwen gate values when the Neumann power-doubling solve is replaced by a direct triangular RHS solve. 
## Branch Lineage @@ -12,65 +16,230 @@ PR 164 / vLLM APC baseline -> qwen-fused-neumann-stable-decay ``` -The `experimental` branch added substantial runtime and validation work after PR 164: +PR 164 is the original Qwen3.6 vLLM APC baseline. After that, the `experimental` branch accumulated the runtime and validation work needed to make Hybrid APC usable and measurable. The final `qwen-fused-neumann-stable-decay` branch was created from `experimental` and added the fused DeltaNet stability work. + +This clean branch extracts the direct-solve fused DeltaNet work and its result artifacts onto PR 164 for review. It does not include the entire `experimental` branch history. + +## Why This Exists + +The original fused DeltaNet path used a Neumann power-doubling solve for the recurrence. That approach is mathematically convenient, but it is fragile for realistic Qwen gate scales because it repeatedly forms full matrix powers and can amplify numerical error. In practice, the fused path could produce unstable or incoherent tokens. -- Hybrid APC checkpoint cache, lifecycle, restore/commit masks, and strict metadata contracts. -- vLLM/NxDI scheduler bridge changes for cached chunked prefill, backed prefix reads, request-id propagation, and suffix continuation handling. -- Qwen chunked prefill fixes for CTE bucket alignment, prefix-cache slot mapping, GDN checkpoint commits, and chunk-boundary handling. -- FP8 128K artifact configuration guards, validation max-prompt alignment, and artifact audit checks. -- OpenAI/vLLM validation harnesses for exactness, context sweeps, TTFT/TPOT, decode benchmarking, memory capture, and API compatibility. -- Decode-path and sampling fixes, including on-device sampling/logits-path validation and chat-template thinking controls. +The chunked DeltaNet path already used a more stable direct triangular solve. 
This branch ports that idea to the fused path: compute the causal recurrence through a direct triangular RHS solve instead of Neumann power-doubling. -The final fused branch added the direct-solve fused DeltaNet fix on top of that experimental runtime stack. +The goal is not to claim the fused path is now the final production baseline by itself. The goal is to make the fused-kernel stability fix reviewable and to preserve the validation evidence that it produces coherent output inside the full experimental runtime lineage. ## Major Changes From PR 164 To The Tested Branch -The full tested branch differs from PR 164 by roughly 105 source/result files. The important changes are: +The full tested branch differs from PR 164 by roughly 105 source/result files. The main work streams from `experimental` were: + +### Hybrid APC Runtime + +- Added Hybrid APC request records and cache metadata. +- Added checkpoint-slot lifecycle management. +- Added restore and commit masks for GDN/recurrent state reuse. +- Added backed prefix reads and stricter unbacked-read guards. +- Added explicit metadata contracts so runtime decisions are scheduler-authorized instead of inferred locally. + +### vLLM / NxDI Scheduler Bridge + +- Added Qwen-specific vLLM scheduler patching for Hybrid APC. +- Propagated request IDs into the Neuron model runner path. +- Recognized cached chunked-prefill continuations. +- Tracked active scheduled suffix lengths. +- Added no-prefix fallback handling. +- Authorized backed prefix continuations through scheduler metadata. + +### Qwen Model Execution + +- Extended Qwen chunked prefill for Hybrid APC. +- Added GDN checkpoint commit/restore handling. +- Added text-only CTE input handling. +- Added compact CTE masks. +- Fixed prefix/suffix boundary handling. +- Guarded decode rows from unnecessary APC restore handling. +- Added chat-template thinking controls for validation and serving. 
+ +### NxDI Prefix-Cache Plumbing + +- Updated prefix-cache model wrapper paths for vectorized APC args. +- Fixed prefix-cache bucket selection and padded-row safety. +- Added cached decode row handling. +- Added async checkpoint lifecycle cleanup. +- Added unit tests around bucket selection, async execution, and Hybrid APC prefix cache behavior. + +### DeltaNet NKI Kernels + +- Added DeltaNet backend compile controls. +- Added chunked and fused DeltaNet validator paths. +- Tested masked Neumann variants. +- Stabilized the fused DeltaNet kernel. +- Replaced the fused Neumann power-doubling solve with the direct triangular RHS solve. + +### FP8 / Artifact Compile Path + +- Added FP8 MLP-only compile configuration coverage. +- Added artifact config audit guardrails. +- Aligned validation max prompt length with compiled artifacts. +- Added `pa_num_blocks` and bucket sanity checks. +- Added 128K/TKG bucket validation support. + +### Serving And API Compatibility + +- Updated OpenAI-compatible proxy/server behavior. +- Normalized chat-template `enable_thinking=false`. +- Fixed stop-sequence handling. +- Added server startup/offline inference helpers. +- Added OpenAI API probe scripts and results. -- **Hybrid APC runtime:** checkpoint cache, restore/commit masks, backed prefix reads, checkpoint-slot lifecycle, and metadata validation. -- **vLLM scheduler bridge:** request-id propagation, cached chunked-prefill continuations, active suffix accounting, no-prefix fallback handling, and backed-prefix authorization. -- **Qwen model execution:** Hybrid APC chunked prefill, GDN checkpoint commit/restore, text-only CTE inputs, compact CTE masks, prefix/suffix boundary handling, and decode-path safety. -- **NxDI prefix-cache plumbing:** vectorized APC args, prefix-cache bucket selection, padded-row safety, cached decode rows, and async checkpoint lifecycle. 
-- **DeltaNet NKI kernels:** chunked/fused validation paths, DeltaNet backend compile controls, masked Neumann experiments, and the final fused direct triangular RHS solve. -- **FP8/artifact compile path:** Qwen FP8 compile config coverage, artifact config audits, 128K validation alignment, `pa_num_blocks` checks, and larger TKG bucket support. -- **Serving/API compatibility:** OpenAI-compatible proxy/server behavior, chat-template `enable_thinking=false`, stop-sequence handling, and startup/offline helpers. -- **Validation harnesses:** exactness validation, OpenAI chat APC validation, boundary APC probes, context sweeps, offline decode benchmark, BF16 sweep, artifact audit, and memory/perf capture. -- **Tests and results:** added Hybrid APC, scheduler, model-alias, compile-config, artifact-audit, sampling, async, prefix-cache, and DeltaNet tests plus recorded performance/memory artifacts. +### Validation And Results -## What Changed On Top Of PR 164 +- Added exactness validation. +- Added OpenAI chat APC validation. +- Added boundary APC probes. +- Added context sweeps. +- Added offline decode benchmarking. +- Added memory/perf capture. +- Recorded 4K/128K FP8 Hybrid APC results, decode results, and fused direct-solve results. -This clean branch extracts only the fused DeltaNet follow-up commits: +## What This Clean Branch Adds -- Stabilized the Qwen fused DeltaNet CTE kernel. -- Added an isolated fused NKI validation script. -- Made the validator load the fused kernel directly. -- Replaced the fused kernel's Neumann power-doubling solve with a direct triangular RHS solve. -- Updated CPU DeltaNet decay regression coverage for realistic gate scales. -- Added compact validation artifacts for coherence, decode, prefill, and memory. +This branch extracts only the fused DeltaNet follow-up commits: -The artifact results below were produced from `qwen-fused-neumann-stable-decay`, so they validate the direct-solve fused kernel inside the full `experimental` lineage. 
They should not be read as proof that PR 164 plus only these extracted commits reproduces every Hybrid APC runtime fix from `experimental`. +- `Stabilize Qwen fused DeltaNet decay` +- `Add isolated fused DeltaNet NKI validation` +- `Load fused NKI kernel directly in validator` +- `Fix fused DeltaNet solve stability` +- Validation/result documentation commits -## Why +Concretely, it changes: -The previous fused path could produce unstable or incoherent outputs with realistic Qwen gate values. The Neumann power-doubling solve is mathematically convenient, but it forms repeated full-matrix powers and is numerically fragile for this recurrence. The direct triangular RHS solve computes the causal recurrence without those large intermediate powers and matches the stable chunked-kernel approach. +- `contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py` +- `contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py` +- `contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py` +- `profile_artifacts/qwen36_fused_directsolve_20260522/*` -## Validation +## Implementation Details -Local checks: +The fused kernel previously used Neumann power-doubling to solve the recurrent DeltaNet update. The direct-solve version computes the lower-triangular causal recurrence explicitly in the RHS solve path. This avoids the repeated full-matrix power operations that were unstable under realistic gate values. -- `python3 -m py_compile contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py` -- `python3 -m pytest contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py -q` -- Result: `2 passed` +The validator was also made standalone enough to load the fused NKI kernel directly. This matters because review and debug runs should not depend on package import side effects. 
+ +The CPU regression test was updated to cover realistic decay/gate scales and to catch the class of instability that showed up in the fused branch. + +## Validation Results + +Validation artifacts are stored under: + +```text +profile_artifacts/qwen36_fused_directsolve_20260522/ +``` + +### Local Checks + +```bash +python3 -m py_compile \ + contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py \ + contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py \ + contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py + +python3 -m pytest contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py -q +``` + +Result: + +```text +2 passed +``` + +### Artifact Under Test + +```text +qwen36_27b_128k_fp8_mlp_edgebf16_hybrid_apc_nki_fusedstable_directsolve_retry_b256_cte256_512_pfx16k_slots64_tkg8192_32768_131072_async_20260522T130050Z +``` -Trn2 artifact validation: +Runtime: -- Coherence pass: `true` -- Decode throughput: `21.63 tok/s` +- Instance: `trn2.3xlarge` +- Runtime path: offline vLLM/NxDI +- Sampling: on-device greedy +- `output_logits=false` +- TKG buckets: `[8192, 32768, 131072]` +- Prefix buckets in this artifact: up to `16384` + +### Coherence + +File: + +```text +qwen36_directsolve_chat_coherence_20260522T1332Z.json +``` + +Result: + +- Overall pass: `true` +- Chat template: `enable_thinking=false` +- Fact, code, and prefix-cache prompts produced real non-repetitive output. 
+- Smoke decode throughput: about `20.5 tok/s` + +### Decode + +File: + +```text +qwen36_directsolve_decode_bench_20260522T1348Z.json +``` + +Result: + +- Average decode throughput: `21.63 tok/s` - TPOT: `46.2 ms/token` -- Cold prefill: about `590 tok/s` from 4K through 16K -- Warm prefill: up to `36.3k tok/s` at 16K with APC reuse -- HBM peak sum: `60.1 GiB` +- 128-token decode latency: `5.92 s` + +### Cold / Warm Prefill + +File: + +```text +context_sweep_partial_20260522T1348Z.json +``` + +| Prompt tokens | Cold TTFT | Cold prefill | Warm TTFT | Warm prefill | +| ---: | ---: | ---: | ---: | ---: | +| 512 | 1.31 s | 390 tok/s | 0.42 s | 1.2k tok/s | +| 4096 | 7.03 s | 582 tok/s | 0.42 s | 9.8k tok/s | +| 8192 | 13.61 s | 602 tok/s | 0.43 s | 18.9k tok/s | +| 16384 | 27.84 s | 589 tok/s | 0.45 s | 36.3k tok/s | + +### Memory + +File: + +```text +qwen36_directsolve_perf_capture_20260522T1348Z.json +``` + +Result: + +- Neuron HBM peak sum: `60.1 GiB` +- Host process RSS peak: `46.3 GiB` +- Main logical cores peaked around `14.57 GiB` each on cores `0`, `2`, `4`, and `6` + +## Known Limitations + +The compiled artifact used for this validation has `prefix_buckets` only through `16384`. The 32K row failed with: + +```text +Prefix len 16640 exceeds largest bucket 16384 for context_encoding_model +``` + +That is an artifact bucket coverage limitation, not a direct-solve correctness failure. A long-context follow-up compile should include prefix buckets beyond 16K, ideally through the target 64K/128K range. + +The artifact results were produced from `qwen-fused-neumann-stable-decay`, which includes the full `experimental` runtime lineage. This clean branch shows the fused direct-solve extraction and result evidence, but it should not be read as proof that PR 164 plus only these extracted commits reproduces every Hybrid APC runtime behavior from `experimental`. 
+ +## What Is Intentionally Not Included + +This clean branch does not include the full 80+ commit `experimental` stack. It also does not include large raw logs, obsolete investigation branches, or temporary scripts. Those were useful during development but would make this review branch hard to inspect. -Known limitation: +## Recommended Next Step -- The compiled artifact used for this validation has `prefix_buckets` through `16384`; the 32K sweep failed with `Prefix len 16640 exceeds largest bucket 16384`. A long-context follow-up compile needs larger prefix buckets. +Use this branch as the reviewer-facing result branch for the fused direct-solve change. If reviewers require source-level reproducibility for the full artifact behavior, stack or merge a curated `experimental` runtime-stabilization branch beneath this direct-solve work. From 935dab988897545e4b920a5c9fc98e2355d2a0a2 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Fri, 22 May 2026 20:01:29 +0530 Subject: [PATCH 15/21] Document Qwen fused direct-solve validation --- contrib/models/Qwen3.6-27B/README.md | 41 +++++- .../fused_directsolve_validation_20260522.md | 125 ++++++++++++++++++ 2 files changed, 159 insertions(+), 7 deletions(-) create mode 100644 contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md diff --git a/contrib/models/Qwen3.6-27B/README.md b/contrib/models/Qwen3.6-27B/README.md index 650a7012..030d80ee 100644 --- a/contrib/models/Qwen3.6-27B/README.md +++ b/contrib/models/Qwen3.6-27B/README.md @@ -49,7 +49,7 @@ Qwen3.6 weights. - **Hybrid DeltaNet + GQA:** 48 of 64 layers use Gated DeltaNet (linear recurrent attention), 16 layers use standard GQA with KV cache. The pattern repeats every 4 layers: 3 DeltaNet + 1 GQA. - **DeltaNet Linear Attention:** Uses the delta rule for recurrent state updates with gated decay. Per-step: `state *= exp(g); delta = (v - state^T @ k) * beta; state += outer(k, delta); output = state^T @ q`. 
Runs as a chunked algorithm for context encoding, per-token recurrence for token generation. -- **Custom NKI Kernels:** Three NKI kernels implement the DeltaNet forward pass on Neuron: a per-token recurrent kernel (TKG), a per-chunk kernel (legacy), and a fused single-kernel chunked forward (CTE). The fused kernel uses a Neumann series for intra-chunk correction with state persistence in SBUF across chunks. +- **Custom NKI Kernels:** Three NKI kernels implement the DeltaNet forward pass on Neuron: a per-token recurrent kernel (TKG), a per-chunk kernel (legacy), and a fused single-kernel chunked forward (CTE). The fused CTE kernel uses the same direct lower-triangular intra-chunk solve strategy as the stable chunked path, with state persistence in SBUF across chunks. Earlier Neumann power-doubling experiments were not stable enough for Qwen3.6 gate scales. - **GQA Output Gate:** Attention layers use a sigmoid output gate. `q_proj` is 2x sized and interleaved: `[head0_query | head0_gate | head1_query | ...]`. The gate is split during weight conversion and applied after attention. - **Partial RoPE:** Only 25% of head_dim (64 of 256 dimensions) receives rotary embeddings. The remaining 192 dimensions are identity (no rotation). - **+1 RMSNorm Convention:** HF weights use `output = norm(x) * (1 + weight)` where weight is initialized to zeros. Converted to standard `output = norm(x) * weight` during loading by adding 1.0 to all RMSNorm weights (except DeltaNet internal norms, which use standard convention). @@ -151,6 +151,27 @@ matches: | Offline partial-prefix reuse | 25.52s | 1.70s | 15.0x | exact token-ID match | | Server cross-prefix reuse | 25.17s | 1.36s | 18.5x | exact text match | +### Fused Direct-Solve Validation + +The fused DeltaNet CTE kernel was revalidated with a direct triangular solve +instead of the earlier Neumann power-doubling correction. 
Detailed results and +raw result-file references are recorded in +[`docs/fused_directsolve_validation_20260522.md`](docs/fused_directsolve_validation_20260522.md). + +| Metric | Result | Notes | +|--------|--------|-------| +| Coherence | PASS | Fact, code, and prefix-cache prompts produced real non-repetitive text with `enable_thinking=false` | +| Decode throughput | 21.63 tok/s | Offline vLLM/NxDI path, on-device greedy sampling, 128-token decode | +| Decode TPOT | 46.2 ms/token | Same run as decode throughput | +| Cold 512-token TTFT | 1.31s | 390 tok/s cold prefill | +| Cold 16K-token TTFT | 27.84s | 589 tok/s cold prefill | +| Warm 16K-token TTFT | 0.45s | 36.3K tok/s effective warm prefill | +| Peak Neuron HBM | 60.1 GiB | Sum across logical cores on trn2.3xlarge | + +The validated direct-solve artifact used prefix buckets through 16K. A 32K +prompt exceeded that artifact's largest prefix bucket, so longer-context fused +validation requires recompiling the same code with larger prefix buckets. + ### Hybrid APC Follow-up Status Follow-up work on the `experimental` branch extended the baseline vLLM/APC @@ -198,10 +219,11 @@ proof is available: - This is the path expected to turn the current exact single-request APC proof into a measured cold-prefill performance win for batched serving. -The fused CTE kernel and FP8 path are not the current correctness blockers. -The BF16 per-chunk CTE path is the reference path for Hybrid APC validation: -the fused BF16 CTE artifact has shown NaNs around token 105-106, and FP8 should -be revisited after the BF16 batch-2 serving contract is proven. +The fused CTE kernel is no longer blocked on the earlier Neumann-series NaN +failure: the direct-solve fused path passed the coherence and performance +validation summarized above. The remaining Hybrid APC follow-up is serving +contract coverage, especially generated-token batch-2 validation with matching +`ctx_batch_size=2` and `tkg_batch_size=2` artifacts. 
### Key Observations @@ -323,7 +345,7 @@ The DeltaNet forward path can be controlled via environment variables: 6. **+1 RMSNorm convention:** Qwen3.5/3.6 uses `output = norm(x) * (1 + weight)` for most RMSNorm layers, but DeltaNet internal norms use standard `output = norm(x) * weight`. The weight conversion handles this automatically, but custom weight loading must be aware of both conventions. -7. **DeltaNet numerical stability:** DeltaNet kernels rely on normalized Q/K inputs and bounded decay handling. The chunked path includes regression coverage for decay handling; changes to the fused kernel should be validated against the CPU reference and long-context stress prompts. +7. **DeltaNet numerical stability:** DeltaNet kernels rely on normalized Q/K inputs and bounded decay handling. The chunked and fused CTE paths now use direct triangular solves for intra-chunk correction; changes to either path should be validated against the CPU reference, the fused NKI validator, and long-context stress prompts. 8. **Shared codebase with Qwen3.5-27B:** This contrib uses the same `Qwen35*` class names and `modeling_qwen35*.py` filenames as the [Qwen3.5-27B contrib](../Qwen3.5-27B/). This is intentional -- both models share the `qwen3_5` model_type. The code is identical; only the HuggingFace model ID and weights differ. @@ -341,6 +363,11 @@ For production long-context serving on trn2.3xlarge, use the FP8/vLLM artifact and 512-token context encoding bucket. Larger instances are recommended for larger batches or additional serving headroom. +The fused direct-solve artifact summarized in this README was compiled with +prefix buckets through 16K. It is suitable for fused-kernel correctness and +short-to-mid-context performance validation, but not for proving 64K/128K warm +APC behavior without recompilation. 
+ ## Compatibility Matrix | Instance | TP | LNC | Status | Notes | @@ -394,4 +421,4 @@ Note: The env var is `QWEN35_MODEL_PATH` (not `QWEN36`) because the code uses th AWS Neuron -**Last Updated:** 2026-04-23 +**Last Updated:** 2026-05-22 diff --git a/contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md b/contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md new file mode 100644 index 00000000..cd6aee66 --- /dev/null +++ b/contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md @@ -0,0 +1,125 @@ +# Qwen3.6 Fused Direct-Solve Validation + +This note records the reviewer-facing validation for the fused DeltaNet CTE +direct-solve branch. The raw JSON captures are kept in +`profile_artifacts/qwen36_fused_directsolve_20260522/`. + +## Artifact + +| Field | Value | +|-------|-------| +| Source branch | `qwen-fused-neumann-stable-decay` | +| Source commit | `ae2613d` | +| Clean PR branch | `qwen36-pr164-directsolve-baseline` | +| Artifact | `qwen36_27b_128k_fp8_mlp_edgebf16_hybrid_apc_nki_fusedstable_directsolve_retry_b256_cte256_512_pfx16k_slots64_tkg8192_32768_131072_async_20260522T130050Z` | +| Runtime | trn2.3xlarge, TP=4, LNC=2 | +| Serving path | Offline vLLM/NxDI, on-device greedy sampling | + +The clean PR branch is based on PR 164 and extracts the fused DeltaNet +direct-solve change. The measured artifact was compiled from the full +experimental runtime lineage: + +```text +PR 164 / vLLM APC baseline + -> experimental + -> qwen-fused-neumann-stable-decay +``` + +That lineage matters because the artifact also uses the Hybrid APC runtime, +FP8 128K compile settings, and vLLM/NxDI serving fixes developed after the +original PR 164 baseline. + +## What Changed + +The fused DeltaNet CTE kernel previously used Neumann power-doubling for the +intra-chunk correction. That approximation was fast, but it was not stable for +realistic Qwen3.6 gate/decay values and produced invalid output in fused BF16 +experiments. 
+ +The direct-solve version replaces that correction with the lower-triangular +solve strategy used by the stable chunked CTE path. This keeps the fused +single-kernel structure while avoiding the observed Neumann instability. + +## Coherence + +Raw file: +`../../../../profile_artifacts/qwen36_fused_directsolve_20260522/qwen36_directsolve_chat_coherence_20260522T1332Z.json` + +| Check | Result | +|-------|--------| +| Overall pass | PASS | +| Chat template | `enable_thinking=false` | +| Fact prompt | Coherent | +| Code prompt | Coherent | +| Prefix-cache prompt | Coherent | +| Smoke decode | ~20.5 tok/s | + +The sampled outputs were non-repetitive, real text. This addresses the earlier +failure mode where the fused path could emit invalid repeated tokens. + +## Decode + +Raw file: +`../../../../profile_artifacts/qwen36_fused_directsolve_20260522/qwen36_directsolve_decode_bench_20260522T1348Z.json` + +| Metric | Result | +|--------|--------| +| Average decode throughput | 21.63 tok/s | +| TPOT | 46.2 ms/token | +| 128-token decode average latency | 5.92s | +| Sampling | On-device greedy, `output_logits=false` | + +## Cold And Warm Prefill + +Raw file: +`../../../../profile_artifacts/qwen36_fused_directsolve_20260522/context_sweep_partial_20260522T1348Z.json` + +| Prompt tokens | Cold TTFT | Cold prefill | Warm TTFT | Warm prefill | +|--------------:|----------:|-------------:|----------:|-------------:| +| 512 | 1.31s | 390 tok/s | 0.42s | 1.2K tok/s | +| 4,096 | 7.03s | 582 tok/s | 0.42s | 9.8K tok/s | +| 8,192 | 13.61s | 602 tok/s | 0.43s | 18.9K tok/s | +| 16,384 | 27.84s | 589 tok/s | 0.45s | 36.3K tok/s | + +The 32K row did not complete because this artifact was compiled with +`prefix_buckets` only through 16K: + +```text +Prefix len 16640 exceeds largest bucket 16384 for context_encoding_model +``` + +That is a bucket-coverage limitation of this artifact, not a fused direct-solve +correctness failure. 
Longer-context warm APC validation requires recompiling +with larger prefix buckets. + +## Memory + +Raw file: +`../../../../profile_artifacts/qwen36_fused_directsolve_20260522/qwen36_directsolve_perf_capture_20260522T1348Z.json` + +| Metric | Result | +|--------|--------| +| Neuron HBM peak sum | 60.1 GiB | +| Host process RSS peak | 46.3 GiB | +| Main logical-core peaks | ~14.57 GiB on cores 0, 2, 4, and 6 | + +## Validation Commands + +Local CPU checks used for this branch: + +```bash +python3 -m py_compile \ + contrib/models/Qwen3.6-27B/src/nki_kernels/nki_deltanet_fused.py \ + contrib/models/Qwen3.6-27B/scripts/validate_deltanet_fused_nki.py \ + contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py + +python3 -m pytest contrib/models/Qwen3.6-27B/test/unit/test_deltanet_decay.py -q +``` + +Result: + +```text +2 passed +``` + +Neuron validation used the compiled artifact above on trn2.3xlarge. From 724f9f2244e3f237df6991fa7317e8264fa90b61 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Fri, 22 May 2026 20:05:54 +0530 Subject: [PATCH 16/21] Clarify Qwen3.6-only validation scope --- contrib/models/Qwen3.6-27B/README.md | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/contrib/models/Qwen3.6-27B/README.md b/contrib/models/Qwen3.6-27B/README.md index 030d80ee..a2cbf444 100644 --- a/contrib/models/Qwen3.6-27B/README.md +++ b/contrib/models/Qwen3.6-27B/README.md @@ -6,6 +6,10 @@ NeuronX Distributed Inference implementation of Qwen3.6-27B, a 27B parameter den Qwen3.6-27B is a **post-training update** of Qwen3.5-27B with improved agentic coding and thinking preservation. The models share **identical architecture** (`qwen3_5` model_type, `Qwen3_5ForConditionalGeneration`) -- only weights differ. This contrib reuses the same NxDI implementation as [Qwen3.5-27B](../Qwen3.5-27B/) (PR #128). Any code updates to Qwen3.5-27B should be propagated to this contrib and vice versa. +This README reports Qwen3.6-27B validation only. 
Qwen3.5-27B is referenced for +architecture and code lineage; it was not re-benchmarked as part of this +Qwen3.6 contrib validation. + ### Config differences from Qwen3.5-27B | Field | Value | Impact | @@ -105,16 +109,6 @@ Unit tests are architecture-level and do not depend on weights. Coverage include | 64 | 54.2 | 18.5 | 3,720 | | 128 | 54.2 | 18.5 | 4,912 | -### Comparison with Qwen3.5-27B - -| Metric | Qwen3.5-27B | Qwen3.6-27B | Delta | -|--------|------------|------------|-------| -| TPOT P50 | 53 ms | 54.2 ms | +2.3% | -| Throughput | 18.9 tok/s | 18.5 tok/s | -2.1% | -| TTFT (128 tok) | 576 ms | 306.6 ms | -47% * | - -\* TTFT improvement is due to compilation config differences (256-token bucket vs 128-token bucket), not model differences. Architectural performance is equivalent. - ### Long-Context vLLM Baseline A 128K FP8-MLP artifact was validated on trn2.3xlarge (TP=4, LNC=2, SDK 2.29) @@ -230,7 +224,7 @@ contract coverage, especially generated-token batch-2 validation with matching - **BF16 TP=4 is HBM-limited:** The pure BF16 path is limited to short contexts on trn2.3xlarge. The validated 128K baseline uses MLP-only FP8 weights plus the hybrid cache manager. - **DeltaNet enables efficient TKG:** Token generation uses O(1) per-token recurrence instead of O(n) KV cache attention for 48/64 layers. - **vLLM APC is high leverage:** Repeated-prefix requests avoid replaying long chunked prefill and are the largest observed latency win for chat/RAG-style workloads. -- **Performance equivalent to Qwen3.5-27B:** The BF16 TPOT difference is within measurement noise. Expected since architectures are identical. +- **Qwen3.6-only measurements:** The benchmark tables above are Qwen3.6 results. Qwen3.5 is referenced only because the two contrib models share the same architecture and implementation lineage. 
## Usage From 79a2d03a36ae4d13e94148330ecde7e872047027 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Fri, 22 May 2026 20:15:35 +0530 Subject: [PATCH 17/21] Verify Qwen validation environment versions --- contrib/models/Qwen3.6-27B/README.md | 27 ++++++++++++------- .../fused_directsolve_validation_20260522.md | 23 ++++++++++++++++ 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/contrib/models/Qwen3.6-27B/README.md b/contrib/models/Qwen3.6-27B/README.md index a2cbf444..bf648025 100644 --- a/contrib/models/Qwen3.6-27B/README.md +++ b/contrib/models/Qwen3.6-27B/README.md @@ -73,7 +73,7 @@ Qwen3.6 weights. Unit tests are architecture-level and do not depend on weights. Coverage includes config parsing, weight conversion, hybrid cache allocation/update behavior, and DeltaNet decay handling. -### Quality Validation (Qwen3.6-27B, trn2.3xlarge, TP=4, SDK 2.29) +### Quality Validation (Qwen3.6-27B, trn2.3xlarge, TP=4) 7/7 text-only quality tests passed with `enable_thinking=False`: @@ -89,7 +89,7 @@ Unit tests are architecture-level and do not depend on weights. Coverage include ## Performance Benchmarks -### Qwen3.6-27B on trn2.3xlarge (TP=4, LNC=2, SDK 2.29, BF16) +### Qwen3.6-27B on trn2.3xlarge (TP=4, LNC=2, BF16) **TTFT (Time To First Token)** @@ -111,7 +111,7 @@ Unit tests are architecture-level and do not depend on weights. Coverage include ### Long-Context vLLM Baseline -A 128K FP8-MLP artifact was validated on trn2.3xlarge (TP=4, LNC=2, SDK 2.29) +A 128K FP8-MLP artifact was validated on trn2.3xlarge (TP=4, LNC=2) with the vLLM Neuron plugin, Qwen chunked prefill, and native vLLM APC enabled. 
| Metric | Result | @@ -319,7 +319,7 @@ The DeltaNet forward path can be controlled via environment variables: | Env Var | Forward Path | Use Case | |---------|-------------|----------| -| `USE_NKI_FUSED=1` | Fused chunked NKI kernel | Best CTE performance (default for SDK 2.29) | +| `USE_NKI_FUSED=1` | Fused chunked NKI kernel | Best CTE performance in the validated build | | `USE_NKI_CHUNKED=1` | Per-chunk NKI kernel | Legacy, superseded by fused | | `USE_NKI=1` | Per-token NKI kernel | TKG (always used for token generation) | | `DELTANET_SEQUENTIAL=1` | Sequential PyTorch | Debugging/reference | @@ -329,7 +329,7 @@ The DeltaNet forward path can be controlled via environment variables: 1. **BF16 HBM pressure at TP=4:** The pure BF16 model consumes nearly all HBM on trn2.3xlarge. Use the FP8/vLLM path for the validated 128K artifact, or a larger instance for additional batching/headroom. -2. **SDK 2.29+ required:** The NKI DeltaNet kernels require NKI 0.3.0 (SDK 2.29). No library modifications needed -- runs on stock SDK 2.29 DLAMI (`/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/`). +2. **NKI 0.3.0+ required:** The NKI DeltaNet kernels were validated with the `nki` package version shown in the package table below. No library modifications were needed in the validated NxDI venv (`/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/`). 3. **No mini model test:** Unlike DeepSeek-V3, a mini model cannot be provided because DeltaNet layers require NKI kernels that only execute on Neuron devices. Integration tests require a trn2 instance with the full 27B weights. @@ -370,15 +370,24 @@ APC behavior without recompilation. 
| trn2.48xlarge | 4 | 2 | Expected PASS | Untested for this contrib; use the same TP=4 artifact shape when compiling for trn2.3xlarge deployment | | trn2u.48xlarge | 4 | 2 | Expected PASS | Untested for this contrib; same portability caveat as trn2.48xlarge | -### SDK Configuration +### Validation Host Package Versions + +These versions were checked on a running trn2.3xlarge validation host on +2026-05-22. Treat them as the observed validation environment, not as a generic +SDK release label. | Component | Version | |-----------|---------| -| NxDI | 0.9.17334 | -| neuronx-cc | 2.24.5133 | +| neuronx-distributed-inference | 0.9.17334+ced6ae4e | +| neuronx-distributed | 0.18.27753+1cafd54f | +| neuronx-cc | 2.24.8799.0+6f62ff7c | +| nki | 0.3.0+23928721754.g18aa1271 | | torch | 2.9.1 | +| torch-neuronx | 2.9.0.2.13.26312+8e870898 | +| torch-xla | 2.9.0 | | transformers | 4.57.6 | -| NKI | 0.3.0 | +| aws-neuronx-runtime-lib | 2.31.24.0-0b044f4ce | +| aws-neuronx-tools | 2.29.22.0-b486b0ade | | NXDI venv | `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/` | ## Testing diff --git a/contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md b/contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md index cd6aee66..04f2af44 100644 --- a/contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md +++ b/contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md @@ -15,6 +15,10 @@ direct-solve branch. The raw JSON captures are kept in | Runtime | trn2.3xlarge, TP=4, LNC=2 | | Serving path | Offline vLLM/NxDI, on-device greedy sampling | +The Trn2 validation host was rechecked on 2026-05-22 before publishing these +notes: it was reachable, reported instance type `trn2.3xlarge`, logical Neuron +core config `2`, and contained the artifact directory at 36G. + The clean PR branch is based on PR 164 and extracts the fused DeltaNet direct-solve change. 
The measured artifact was compiled from the full experimental runtime lineage: @@ -103,6 +107,25 @@ Raw file: | Host process RSS peak | 46.3 GiB | | Main logical-core peaks | ~14.57 GiB on cores 0, 2, 4, and 6 | +## Observed Package Versions + +These versions were read from the running validation host and the active NxDI +venv on 2026-05-22. + +| Component | Version | +|-----------|---------| +| neuronx-distributed-inference | 0.9.17334+ced6ae4e | +| neuronx-distributed | 0.18.27753+1cafd54f | +| neuronx-cc | 2.24.8799.0+6f62ff7c | +| nki | 0.3.0+23928721754.g18aa1271 | +| torch | 2.9.1 | +| torch-neuronx | 2.9.0.2.13.26312+8e870898 | +| torch-xla | 2.9.0 | +| transformers | 4.57.6 | +| aws-neuronx-runtime-lib | 2.31.24.0-0b044f4ce | +| aws-neuronx-tools | 2.29.22.0-b486b0ade | +| NXDI venv | `/opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/` | + ## Validation Commands Local CPU checks used for this branch: From 3698a4cb9ae64cd386a594210bf6e9fb5a000c3d Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Fri, 22 May 2026 20:18:32 +0530 Subject: [PATCH 18/21] Correct Qwen3.6 lineage to PR 140 --- contrib/models/Qwen3.6-27B/README.md | 4 ++-- .../docs/fused_directsolve_validation_20260522.md | 3 ++- .../qwen36_fused_directsolve_20260522/PR_BODY.md | 9 +++++++-- .../qwen36_fused_directsolve_20260522/README.md | 6 ++++-- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/contrib/models/Qwen3.6-27B/README.md b/contrib/models/Qwen3.6-27B/README.md index bf648025..995dcf46 100644 --- a/contrib/models/Qwen3.6-27B/README.md +++ b/contrib/models/Qwen3.6-27B/README.md @@ -2,9 +2,9 @@ NeuronX Distributed Inference implementation of Qwen3.6-27B, a 27B parameter dense model from Alibaba Cloud with a hybrid DeltaNet + GQA attention architecture. -## Relationship to Qwen3.5-27B +## Relationship to PR #140 and Qwen3.5-27B -Qwen3.6-27B is a **post-training update** of Qwen3.5-27B with improved agentic coding and thinking preservation. 
The models share **identical architecture** (`qwen3_5` model_type, `Qwen3_5ForConditionalGeneration`) -- only weights differ. This contrib reuses the same NxDI implementation as [Qwen3.5-27B](../Qwen3.5-27B/) (PR #128). Any code updates to Qwen3.5-27B should be propagated to this contrib and vice versa. +Qwen3.6-27B is a **post-training update** of Qwen3.5-27B with improved agentic coding and thinking preservation. The models share **identical architecture** (`qwen3_5` model_type, `Qwen3_5ForConditionalGeneration`) -- only weights differ. This contrib builds on Jim Burtoft's Qwen3.6-27B contrib work in PR #140 and the shared Qwen3.5/Qwen3.6 hybrid architecture pattern. This README reports Qwen3.6-27B validation only. Qwen3.5-27B is referenced for architecture and code lineage; it was not re-benchmarked as part of this diff --git a/contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md b/contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md index 04f2af44..06902ccd 100644 --- a/contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md +++ b/contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md @@ -24,11 +24,12 @@ direct-solve change. The measured artifact was compiled from the full experimental runtime lineage: ```text -PR 164 / vLLM APC baseline +PR 164 / vLLM APC baseline, building on PR #140 -> experimental -> qwen-fused-neumann-stable-decay ``` +PR 164 itself builds on Jim Burtoft's Qwen3.6-27B contrib work in PR #140. That lineage matters because the artifact also uses the Hybrid APC runtime, FP8 128K compile settings, and vLLM/NxDI serving fixes developed after the original PR 164 baseline. 
diff --git a/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md b/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md index 65eefc93..97b13147 100644 --- a/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md +++ b/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md @@ -11,12 +11,17 @@ The important result is that the fused DeltaNet CTE path is now coherent with re The actual development history was: ```text -PR 164 / vLLM APC baseline +PR 164 / vLLM APC baseline, building on PR #140 -> experimental -> qwen-fused-neumann-stable-decay ``` -PR 164 is the original Qwen3.6 vLLM APC baseline. After that, the `experimental` branch accumulated the runtime and validation work needed to make Hybrid APC usable and measurable. The final `qwen-fused-neumann-stable-decay` branch was created from `experimental` and added the fused DeltaNet stability work. +PR 164 is the original Qwen3.6 vLLM APC baseline for this branch, and PR 164 +itself builds on Jim Burtoft's Qwen3.6-27B contrib work in PR #140. After that, +the `experimental` branch accumulated the runtime and validation work needed to +make Hybrid APC usable and measurable. The final +`qwen-fused-neumann-stable-decay` branch was created from `experimental` and +added the fused DeltaNet stability work. This clean branch extracts the direct-solve fused DeltaNet work and its result artifacts onto PR 164 for review. It does not include the entire `experimental` branch history. 
diff --git a/profile_artifacts/qwen36_fused_directsolve_20260522/README.md b/profile_artifacts/qwen36_fused_directsolve_20260522/README.md index bb811ab5..49e1d793 100644 --- a/profile_artifacts/qwen36_fused_directsolve_20260522/README.md +++ b/profile_artifacts/qwen36_fused_directsolve_20260522/README.md @@ -16,12 +16,14 @@ The fused DeltaNet CTE path now uses a direct triangular RHS solve instead of Ne The development lineage was: ```text -PR 164 / vLLM APC baseline +PR 164 / vLLM APC baseline, building on PR #140 -> experimental -> qwen-fused-neumann-stable-decay ``` -The `experimental` branch accumulated the runtime and validation work needed to make Qwen3.6 Hybrid APC usable beyond the original PR 164 baseline: +PR 164 itself builds on Jim Burtoft's Qwen3.6-27B contrib work in PR #140. The +`experimental` branch accumulated the runtime and validation work needed to make +Qwen3.6 Hybrid APC usable beyond the original PR 164 baseline: - Hybrid APC checkpoint cache, lifecycle, restore/commit masks, and strict metadata contracts. - vLLM/NxDI scheduler bridge changes for cached chunked prefill, backed prefix reads, request-id propagation, and suffix continuation handling. From df6d936e95a35d6ab4d58858cf2388034f9ae7e7 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Fri, 22 May 2026 20:22:21 +0530 Subject: [PATCH 19/21] Document Qwen HBM comparison caveat --- contrib/models/Qwen3.6-27B/README.md | 9 +++++++++ .../docs/fused_directsolve_validation_20260522.md | 13 +++++++++++++ .../qwen36_fused_directsolve_20260522/PR_BODY.md | 8 ++++++++ .../qwen36_fused_directsolve_20260522/README.md | 12 ++++++++++++ 4 files changed, 42 insertions(+) diff --git a/contrib/models/Qwen3.6-27B/README.md b/contrib/models/Qwen3.6-27B/README.md index 995dcf46..574b82f0 100644 --- a/contrib/models/Qwen3.6-27B/README.md +++ b/contrib/models/Qwen3.6-27B/README.md @@ -166,6 +166,15 @@ The validated direct-solve artifact used prefix buckets through 16K. 
A 32K prompt exceeded that artifact's largest prefix bucket, so longer-context fused validation requires recompiling the same code with larger prefix buckets. +Memory note: the `60.1 GiB` direct-solve number is a Neuron high-water peak +from the Hybrid APC artifact, not a prompt-length-only 16K allocation. The +artifact was compiled with `pa_num_blocks=512`, `max_gdn_checkpoint_slots=64`, +and token-generation buckets `[8192, 32768, 131072]`. The PR 164 vLLM/APC +baseline README reports `~53.25 GB` decimal during a 64K eval, but this branch +does not include that run's raw memory log or artifact config. Treat the higher +direct-solve HBM as a real observation that needs like-for-like A/B validation, +not as proof that the direct triangular solve itself increases memory. + ### Hybrid APC Follow-up Status Follow-up work on the `experimental` branch extended the baseline vLLM/APC diff --git a/contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md b/contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md index 06902ccd..1fae0425 100644 --- a/contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md +++ b/contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md @@ -108,6 +108,19 @@ Raw file: | Host process RSS peak | 46.3 GiB | | Main logical-core peaks | ~14.57 GiB on cores 0, 2, 4, and 6 | +The HBM number above is a high-water peak from the Hybrid APC artifact. It was +identical for the context-sweep and decode-benchmark phases in the raw capture, +so it should be read primarily as artifact/runtime static allocation rather than +as memory consumed by a 16K prompt alone. The artifact config includes +`pa_num_blocks=512`, `max_gdn_checkpoint_slots=64`, and token-generation buckets +`[8192, 32768, 131072]`. + +PR 164 reports `~53.25 GB` decimal during a 64K vLLM/APC eval, but this clean +branch does not include that run's raw memory log or artifact config. 
A strict +memory regression claim needs a like-for-like rerun of the PR 164 artifact and +the direct-solve artifact with the same memory capture script and comparable +cache/bucket settings. + ## Observed Package Versions These versions were read from the running validation host and the active NxDI diff --git a/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md b/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md index 97b13147..aa0dafc3 100644 --- a/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md +++ b/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md @@ -229,6 +229,14 @@ Result: - Host process RSS peak: `46.3 GiB` - Main logical cores peaked around `14.57 GiB` each on cores `0`, `2`, `4`, and `6` +Memory caveat: this is a Neuron high-water peak from the Hybrid APC artifact, +not a prompt-length-only 16K allocation. The artifact was compiled with +`pa_num_blocks=512`, `max_gdn_checkpoint_slots=64`, and TKG buckets +`[8192, 32768, 131072]`. PR 164 reports `~53.25 GB` decimal during a 64K +vLLM/APC eval, but this clean branch does not include that run's raw memory log +or artifact config. A strict memory regression claim needs like-for-like A/B +measurement with the same capture script and comparable cache/bucket settings. + ## Known Limitations The compiled artifact used for this validation has `prefix_buckets` only through `16384`. 
The 32K row failed with: diff --git a/profile_artifacts/qwen36_fused_directsolve_20260522/README.md b/profile_artifacts/qwen36_fused_directsolve_20260522/README.md index 49e1d793..45dc4b21 100644 --- a/profile_artifacts/qwen36_fused_directsolve_20260522/README.md +++ b/profile_artifacts/qwen36_fused_directsolve_20260522/README.md @@ -111,6 +111,18 @@ That is an artifact bucket coverage limitation, not a direct-solve correctness f - Host process RSS peak: `46.3 GiB` - Main logical cores peaked around `14.57 GiB` each on cores `0`, `2`, `4`, and `6` +Memory comparison caveat: + +- The `60.1 GiB` value is a Neuron high-water peak and appears in both the + context-sweep and decode-benchmark captures. +- The direct-solve artifact was compiled with `pa_num_blocks=512`, + `max_gdn_checkpoint_slots=64`, and TKG buckets `[8192, 32768, 131072]`. +- PR 164 reports `~53.25 GB` decimal during a 64K vLLM/APC eval, but this + directory does not include that run's raw memory log or artifact config. +- Treat the higher direct-solve HBM as a real observation requiring like-for-like + A/B validation, not as proof that the direct triangular solve itself increases + memory. + ## Follow-Up For long-context validation, recompile the same branch with prefix buckets beyond `16384`, ideally through the intended 64K/128K validation range. 
From 24babfbe614440b5ed0aaf036327bc19d91ec548 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Fri, 22 May 2026 22:38:20 +0530 Subject: [PATCH 20/21] Document Qwen fused validation follow-ups --- contrib/models/Qwen3.6-27B/README.md | 23 ++++++++-------- .../fused_directsolve_validation_20260522.md | 27 ++++++++++--------- .../PR_BODY.md | 14 +++++----- .../README.md | 13 +++++---- 4 files changed, 42 insertions(+), 35 deletions(-) diff --git a/contrib/models/Qwen3.6-27B/README.md b/contrib/models/Qwen3.6-27B/README.md index 574b82f0..3b8b42a9 100644 --- a/contrib/models/Qwen3.6-27B/README.md +++ b/contrib/models/Qwen3.6-27B/README.md @@ -166,14 +166,17 @@ The validated direct-solve artifact used prefix buckets through 16K. A 32K prompt exceeded that artifact's largest prefix bucket, so longer-context fused validation requires recompiling the same code with larger prefix buckets. -Memory note: the `60.1 GiB` direct-solve number is a Neuron high-water peak -from the Hybrid APC artifact, not a prompt-length-only 16K allocation. The -artifact was compiled with `pa_num_blocks=512`, `max_gdn_checkpoint_slots=64`, -and token-generation buckets `[8192, 32768, 131072]`. The PR 164 vLLM/APC -baseline README reports `~53.25 GB` decimal during a 64K eval, but this branch -does not include that run's raw memory log or artifact config. Treat the higher -direct-solve HBM as a real observation that needs like-for-like A/B validation, -not as proof that the direct triangular solve itself increases memory. +Memory note: the `60.1 GiB` direct-solve number is `64.54 GB` decimal and is a +Neuron high-water peak from the Hybrid APC artifact, not a prompt-length-only +16K allocation. The artifact was compiled with `pa_num_blocks=512`, +`max_gdn_checkpoint_slots=64`, and token-generation buckets +`[8192, 32768, 131072]`. 
The 64-slot GDN checkpoint bank alone is expected to +reserve about `9.85 GB` decimal across TP=4 ranks +(`38.49 MB/checkpoint/rank * 64 * 4`), which explains most of the gap from the +PR 164 vLLM/APC README's `~53.25 GB` decimal 64K eval number. Treat the higher +direct-solve HBM as a Hybrid APC artifact/config observation requiring +like-for-like A/B validation, not as proof that the direct triangular solve +itself increases memory. ### Hybrid APC Follow-up Status @@ -429,8 +432,4 @@ Note: The env var is `QWEN35_MODEL_PATH` (not `QWEN36`) because the code uses th - [`Qwen/Qwen3.6-27B`](https://huggingface.co/Qwen/Qwen3.6-27B) (BF16, ~52 GB) -## Maintainer - -AWS Neuron - **Last Updated:** 2026-05-22 diff --git a/contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md b/contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md index 1fae0425..1678aabd 100644 --- a/contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md +++ b/contrib/models/Qwen3.6-27B/docs/fused_directsolve_validation_20260522.md @@ -108,18 +108,21 @@ Raw file: | Host process RSS peak | 46.3 GiB | | Main logical-core peaks | ~14.57 GiB on cores 0, 2, 4, and 6 | -The HBM number above is a high-water peak from the Hybrid APC artifact. It was -identical for the context-sweep and decode-benchmark phases in the raw capture, -so it should be read primarily as artifact/runtime static allocation rather than -as memory consumed by a 16K prompt alone. The artifact config includes -`pa_num_blocks=512`, `max_gdn_checkpoint_slots=64`, and token-generation buckets -`[8192, 32768, 131072]`. - -PR 164 reports `~53.25 GB` decimal during a 64K vLLM/APC eval, but this clean -branch does not include that run's raw memory log or artifact config. A strict -memory regression claim needs a like-for-like rerun of the PR 164 artifact and -the direct-solve artifact with the same memory capture script and comparable -cache/bucket settings. 
+The HBM number above is a high-water peak from the Hybrid APC artifact. In +decimal units it is `64.54 GB`, so the comparison against PR 164's `~53.25 GB` +decimal 64K vLLM/APC eval is an `~11.29 GB` delta. It was identical for the +context-sweep and decode-benchmark phases in the raw capture, so it should be +read primarily as artifact/runtime static allocation rather than as memory +consumed by a 16K prompt alone. + +The artifact config includes `pa_num_blocks=512`, +`max_gdn_checkpoint_slots=64`, and token-generation buckets +`[8192, 32768, 131072]`. The 64-slot GDN checkpoint bank is expected to reserve +about `38.49 MB` per checkpoint per TP rank, or `9.85 GB` decimal across +TP=4. That accounts for most of the observed delta; the remaining difference +can plausibly come from bucket/runtime/capture differences, but needs a +like-for-like rerun of the PR 164 artifact and the direct-solve artifact with +the same memory capture script and comparable cache/bucket settings. ## Observed Package Versions diff --git a/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md b/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md index aa0dafc3..1c59bb06 100644 --- a/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md +++ b/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md @@ -230,12 +230,14 @@ Result: - Main logical cores peaked around `14.57 GiB` each on cores `0`, `2`, `4`, and `6` Memory caveat: this is a Neuron high-water peak from the Hybrid APC artifact, -not a prompt-length-only 16K allocation. The artifact was compiled with -`pa_num_blocks=512`, `max_gdn_checkpoint_slots=64`, and TKG buckets -`[8192, 32768, 131072]`. PR 164 reports `~53.25 GB` decimal during a 64K -vLLM/APC eval, but this clean branch does not include that run's raw memory log -or artifact config. A strict memory regression claim needs like-for-like A/B -measurement with the same capture script and comparable cache/bucket settings. 
+not a prompt-length-only 16K allocation. In decimal units the direct-solve peak +is `64.54 GB`, versus the PR 164 vLLM/APC README's `~53.25 GB` decimal 64K +eval number. The artifact was compiled with `pa_num_blocks=512`, +`max_gdn_checkpoint_slots=64`, and TKG buckets `[8192, 32768, 131072]`. The +64-slot GDN checkpoint bank is expected to reserve about `9.85 GB` decimal +across TP=4, explaining most of the delta. A strict memory regression claim +still needs like-for-like A/B measurement with the same capture script and +comparable cache/bucket settings. ## Known Limitations diff --git a/profile_artifacts/qwen36_fused_directsolve_20260522/README.md b/profile_artifacts/qwen36_fused_directsolve_20260522/README.md index 45dc4b21..2d684d08 100644 --- a/profile_artifacts/qwen36_fused_directsolve_20260522/README.md +++ b/profile_artifacts/qwen36_fused_directsolve_20260522/README.md @@ -114,14 +114,17 @@ That is an artifact bucket coverage limitation, not a direct-solve correctness f Memory comparison caveat: - The `60.1 GiB` value is a Neuron high-water peak and appears in both the - context-sweep and decode-benchmark captures. + context-sweep and decode-benchmark captures. In decimal units it is + `64.54 GB`. - The direct-solve artifact was compiled with `pa_num_blocks=512`, `max_gdn_checkpoint_slots=64`, and TKG buckets `[8192, 32768, 131072]`. -- PR 164 reports `~53.25 GB` decimal during a 64K vLLM/APC eval, but this - directory does not include that run's raw memory log or artifact config. +- The 64-slot GDN checkpoint bank is expected to reserve about `38.49 MB` per + checkpoint per TP rank, or `9.85 GB` decimal across TP=4, which explains most + of the gap to the PR 164 vLLM/APC README's `~53.25 GB` decimal 64K eval + number. - Treat the higher direct-solve HBM as a real observation requiring like-for-like - A/B validation, not as proof that the direct triangular solve itself increases - memory. 
+ A/B validation of the artifact/cache configuration, not as proof that the + direct triangular solve itself increases memory. ## Follow-Up From 882e9da03ae7684fa9994f5b7612e37439c50099 Mon Sep 17 00:00:00 2001 From: Deepankar Singh Date: Fri, 22 May 2026 22:47:30 +0530 Subject: [PATCH 21/21] Remove recommended next step from PR body --- .../qwen36_fused_directsolve_20260522/PR_BODY.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md b/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md index 1c59bb06..7410e2eb 100644 --- a/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md +++ b/profile_artifacts/qwen36_fused_directsolve_20260522/PR_BODY.md @@ -254,7 +254,3 @@ The artifact results were produced from `qwen-fused-neumann-stable-decay`, which ## What Is Intentionally Not Included This clean branch does not include the full 80+ commit `experimental` stack. It also does not include large raw logs, obsolete investigation branches, or temporary scripts. Those were useful during development but would make this review branch hard to inspect. - -## Recommended Next Step - -Use this branch as the reviewer-facing result branch for the fused direct-solve change. If reviewers require source-level reproducibility for the full artifact behavior, stack or merge a curated `experimental` runtime-stabilization branch beneath this direct-solve work.