diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/BENCHMARK_OMNI2_TTFB.md b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/BENCHMARK_OMNI2_TTFB.md new file mode 100644 index 00000000..da97c72a --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/BENCHMARK_OMNI2_TTFB.md @@ -0,0 +1,453 @@ +# Qwen3-Omni TTFB / RTF benchmark on omni2 audio-in conversations + +End-to-end benchmark on 100 real multi-turn conversations with audio user +inputs (source: `/home/ubuntu/omni2`). Each conversation has a system prompt, +2–4 prior text turns, and a final user turn that is a `.wav` audio file. The +pipeline must produce a spoken assistant reply. + +This doc tracks the progressive optimizations that moved **TTFB from 2727 ms +→ 1759 ms** (−35 %) and the talker success rate (no max-token truncation) +from 0 % → 88 %. + +## Setup + +- trn2.48xlarge, 8 Neuron cores pinned via `NEURON_RT_VISIBLE_CORES=0-7` +- TP=8 for every submodel +- Dataset: `/home/ubuntu/omni2/merged_conversations_with_audio_x10_with_system.json` + (system prompt ~800 tokens, prompt lengths 1164–1494 tokens; all land in + the 2048 bucket) +- 100 conversations; audio files in `/home/ubuntu/omni2/speech_wav_16k/` + +```bash +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + +# Thinker-only benchmark +NEURON_RT_VISIBLE_CORES=0-7 python test_thinker_ttft_bench.py --num 100 + +# Full streaming TTFB / RTF benchmark (best serial configuration, step 4) +NEURON_RT_VISIBLE_CORES=0-7 CHUNK_SIZE=25 LEFT_CTX=5 \ + python test_ttfb_rtf_bench.py --num 100 \ + --max-thinker 200 --max-talker 500 \ + --neuron-c2w +``` + +## Thinker-only TTFT & throughput (`test_thinker_ttft_bench.py`) + +The `tensor_capture_hook` fires once per thinker forward (prefill + each +decode step), so we use it as a per-token timing tap. TTFT = time from +`adapter.generate()` start to the first hook fire (= end of prefill). 
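
As a rough illustration of the timing tap described above, the sketch below records a timestamp on every hook fire and derives TTFT and ITL from the gaps. The `TimingTap` class, its method names, and the hook call signature are assumptions made for this example; they are not the code in `test_thinker_ttft_bench.py`.

```python
import time
from statistics import mean


class TimingTap:
    """Collects one timestamp per forward pass (prefill + each decode step)."""

    def __init__(self, clock=time.perf_counter):
        self._clock = clock
        self.start = None
        self.fires = []

    def begin(self):
        # Call right before generate() starts.
        self.start = self._clock()
        self.fires = []

    def __call__(self, *args, **kwargs):
        # Hypothetical hook signature: the arguments are ignored, only the
        # time of the call matters.
        self.fires.append(self._clock())

    def ttft_ms(self):
        # First fire marks the end of prefill.
        return (self.fires[0] - self.start) * 1000.0

    def itl_ms(self):
        # Mean gap between consecutive fires after prefill.
        gaps = [b - a for a, b in zip(self.fires, self.fires[1:])]
        return mean(gaps) * 1000.0 if gaps else float("nan")


# Demo with a fake clock: prefill ends at 668 ms, then one token every 10 ms.
ticks = iter([0.0, 0.668, 0.678, 0.688, 0.698])
tap = TimingTap(clock=lambda: next(ticks))
tap.begin()
for _ in range(4):
    tap()
print(f"TTFT {tap.ttft_ms():.0f} ms, ITL {tap.itl_ms():.1f} ms")
```

In a real run the tap would be registered wherever the capture hook is installed and `begin()` called just before `adapter.generate()`.
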
+ +| metric | mean | p50 | p90 | p95 | +|---|---:|---:|---:|---:| +| TTFT (prefill) | **668 ms** | 667 | 672 | 678 | +| decode ITL | **10.2 ms** | 10.1 | 10.3 | 10.3 | +| tokens/s (overall) | **48.3** | 47.1 | 54.9 | 69.1 | +| RTF vs audio input | **0.37** | 0.33 | 0.65 | 0.79 | +| prompt tokens | 1294 | 1278 | 1400 | 1419 | + +All 100 samples succeeded. 100 conversations ran in 136 s wall time. + +## Full streaming TTFB / RTF (`test_ttfb_rtf_bench.py`) + +Streaming pipeline: thinker (Neuron) → talker (Neuron) → UCP (Neuron) → +code2wav. `code2wav` fires inline every `CHUNK_SIZE` codec tokens. TTFB = +request arrival → first audio chunk delivered to the host. + +### TTFB progression across configurations + +| configuration | TTFB mean | TTFB p50 | TTFB p90 | TTFB p95 | hit-max / 100 | +|---|---:|---:|---:|---:|---:| +| 1. baseline streaming (broken talker) | 2727 | 2666 | 3113 | 3564 | **100** | +| 2. + TensorRegistry fix + norm capture + HF sampling | 2763 | 2698 | 3140 | 4128 | 15 | +| 3. + CHUNK_SIZE=25 / LEFT_CTX=5 | 2276 | 2193 | 2670 | 3581 | 14 | +| 4. + Neuron `code2wav` | 2000 | 1915 | 2389 | 3316 | 12 | +| 5. + thinker↔talker pipelining | **1759** | **1778** | **1811** | **1822** | **12** | + +All TTFB values are in milliseconds. "hit-max" counts samples where the talker reached +`max_new_tokens=500` instead of naturally emitting `codec_eos_token_id` — +smaller is better. + +The **biggest tail-latency win** comes from step 5: p95 dropped from 3316 → 1822 ms +(−45 %). With pipelining, TTFB no longer scales with thinker output length — +the user gets first audio in a near-constant window regardless of whether the +thinker reply is 50 or 200 tokens. 
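
The percentage figures quoted for the progression table can be rechecked with a few lines of arithmetic. The sketch below uses only the mean and p95 TTFB values from the table above; the helper name `pct_change` and the dictionary layout are ours.

```python
def pct_change(before_ms, after_ms):
    """Relative change in percent; negative means a latency reduction."""
    return (after_ms - before_ms) / before_ms * 100.0


# (mean, p95) TTFB in ms, copied from the progression table.
ttfb = {
    "1. baseline": (2727, 3564),
    "4. + Neuron code2wav": (2000, 3316),
    "5. + pipelining": (1759, 1822),
}

base_mean, base_p95 = ttfb["1. baseline"]
mean4, p95_4 = ttfb["4. + Neuron code2wav"]
mean5, p95_5 = ttfb["5. + pipelining"]

print(f"step 4 -> 5, mean: {pct_change(mean4, mean5):.1f} %")
print(f"step 4 -> 5, p95:  {pct_change(p95_4, p95_5):.1f} %")
print(f"step 1 -> 5, mean: {pct_change(base_mean, mean5):.1f} %")
print(f"step 1 -> 5, p95:  {pct_change(base_p95, p95_5):.1f} %")
```

The gap between the mean and p95 reductions for step 4 → 5 is the point of the paragraph above: pipelining mostly removes the long-reply tail rather than shifting the typical case.
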
+ +### TTFB breakdown (step 4 — fully serial pipeline) + +| stage | mean | p50 | p90 | note | +|---|---:|---:|---:|---| +| thinker full generate (Neuron) | 1346 | 1263 | 1485 | prefill 668 + ~68 × 10 ms decode | +| build talker inputs + 25 talker steps + UCP | 532 | 543 | 553 | 25 × ~21 ms decode | +| first `code2wav` chunk | **122** | **122** | **122** | Neuron NEFF, T=30 bucket | +| **TTFB total** | **2000** | **1915** | **2389** | | + +### TTFB breakdown (step 5 — pipelined) + +In the pipelined run thinker and talker overlap, so the breakdown is no +longer a sum of stages. The dominant component is "wait for first 4 thinker +tokens", measured as `build_talker_blocked_ms`: + +| stage | mean | p50 | note | +|---|---:|---:|---| +| wait for first 4 thinker tokens | 765 | 762 | thinker prefill (~668 ms) + a few decode steps + bg-thread overhead | +| talker prefill + 25 decode steps + first `code2wav` chunk | ~990 | ~1015 | running concurrently with the rest of thinker decode; sometimes blocks in `get_trailing_slice` waiting for the next thinker token | +| **TTFB total** | **1759** | **1778** | | + +#### Why "thinker" is 1346 ms, not the 668 ms TTFT number + +The `test_thinker_ttft_bench.py` table reports **TTFT = 668 ms** (time to first +token = prefill) and **ITL = 10 ms** (per decode step). In the serial (step 4) +streaming pipeline, though, the talker cannot start until the thinker has +generated the **entire** assistant reply — HF's `_build_talker_inputs` needs the +full token sequence and the full layer-23 hidden tensor to assemble the +talker's prompt. + +So the "thinker" row in the step 4 TTFB breakdown covers **prefill + all decode +steps** of the thinker, not just TTFT. 
With mean 68 new tokens: + +``` +thinker_ms ≈ prefill + new_tokens × ITL + ≈ 668 + 68 × 10 + ≈ 1348 ms (measured: 1346 ms) +``` + +Concretely, this is the serial pipeline the bench runs today: + +``` +t=0 request arrives +t=668 thinker prefill done (first thinker token available — TTFT) +t=1346 thinker done (68 decode steps @ 10 ms, all tokens + hiddens ready) +t=~1400 talker prefill done (~50 ms build + 1 prefill forward on talker) +t=1878 25 talker decode steps done (25 × ~21 ms, each pairs with one UCP call) +t=2000 first code2wav chunk returned → first audio delivered +``` + +The 1346 ms thinker cost is what dominates TTFB now, and it's mostly **not** +prefill (bucket 2048) — it's the 68 serial decode steps after prefill. + +#### What the ceiling looks like if we pipeline + +If we were willing to change architecture and let the talker consume thinker +tokens as they stream out (instead of waiting for the full thinker sequence), +TTFB could in principle drop to roughly the thinker-prefill time + a short +warmup before the talker finds enough context to emit codec tokens: + +``` +t=0 request arrives +t=~668 thinker prefill done, thinker begins streaming tokens +t=~668+K*10 K additional thinker tokens buffered so talker has enough context + to start (K is small — tens of tokens) + (in parallel: talker prefill + first few decode steps) +t=TTFB first 25 codec tokens produced, first c2w chunk returned +``` + +Ballpark with K ≈ 30: `668 + 30 × 10 + talker_prefill + 25 × 21 + 122 ≈ +1400 ms` — a ~600 ms reduction from the current 2000 ms. This requires: + +1. Making `_build_talker_inputs` incremental so it can extend the talker + context one thinker token at a time (today it's a single batched + assemble). +2. Running thinker decode and talker decode concurrently — either two Python + threads with separate Neuron queues, or a host-side coroutine that + alternates `thinker_step()` / `talker_step()` calls. +3. 
Deciding when K is large enough to start the talker (a static threshold is + sufficient; adaptive would be nicer). + +The thinker and talker run on disjoint NEFFs, so there's no device-level +conflict; the work is "just" in the orchestration. + +### Full-run stats at step 4 (best serial configuration) + +| metric | mean | p50 | p90 | p95 | +|---|---:|---:|---:|---:| +| TTFB | 2000 ms | 1915 | 2389 | 3316 | +| thinker | 1345 ms | 1262 | 1485 | 2656 | +| total (end-to-end) | 5648 ms | 4804 | 11428 | 11701 | +| RTF (total / wav) | 0.61 | 0.39 | 0.89 | 1.19 | +| input audio | 5.18 s | 4.27 | 9.26 | 9.86 | +| output wav | 16.06 s | 13.08 | 39.50 | 39.50 | +| thinker tokens | 68 | 59 | 81 | 151 | + +100/100 succeeded. 88/100 talker runs ended at `codec_eos`; 12 hit +`max_new_tokens=500`. + +--- + +## Fixes made + +### 1. `TensorRegistry.clear()` wiped `modules_to_capture` across buckets + +**Symptom.** With `tensor_capture_config={"layers.23"}` configured, capture +worked perfectly at the first bucket (256) and returned the empty fallback +`torch.zeros(1, dtype=bfloat16)` at every larger bucket (512 / 1024 / 2048 / +4096). All our omni2 prompts land in the 2048 bucket, so capture produced +`(1,)` and `_assemble_hidden` crashed on `captured[0][:, :prompt_len, :]`. + +**Root cause.** `NeuronBaseModel._get_captured_tensors` is called once per HLO +trace (once per bucket) and ends with `registry.clear()`. Upstream +`TensorRegistry.clear` in +`neuronx_distributed/utils/tensor_capture/registry.py` replaces `model_info` +with a fresh `CapturedModelInfo([], 10, False)`, **erasing the configured +`modules_to_capture`**. Forward hooks installed by `enable_tensor_capture` keep +firing, but `register_tensor` no longer finds the module name in +`modules_to_capture` and falls through to the "manual" branch. Only the first +bucket gets a real capture; every subsequent bucket's NEFF bakes in a zero +fallback. + +**Fix** (in `src/_upstream_compat.py::_patch_tensor_registry_clear`). 
+Monkey-patch `configure()` to stash the last non-empty module list and +`clear()` to restore it instead of wiping. Five lines of glue, applied at +import time. + +After the fix, verified all five buckets emit real captures with the expected +shape `(1, bucket_size, 2048)`. + +### 2. Talker shim fabricated hidden, blocking `codec_eos` + +**Symptom.** In the baseline streaming run (v1 above), **100/100 samples hit +`max_new_tokens`**. The talker never emitted `codec_eos_token_id = 2150`. We +confirmed via argmax logging that the decoder locked into repetitive loops +like `[318, 318, 318, ...]`. + +**Root cause.** HF's talker generate loop reads the per-step hidden from +`output.hidden_states[-1]` and feeds it to `code_predictor` as `past_hidden`. +Our `NeuronTalkerShim` (`test_audio_out_full_neuron.py`) returned a +**fabricated** hidden built by re-embedding the argmax'd codec token: + +```python +tok = logits_last.argmax(dim=-1) +fake_hidden = hf_model.talker.get_input_embeddings()(tok).to(torch.bfloat16) +return BaseModelOutputWithPast(hidden_states=(fake_hidden,), ...) +``` + +That stand-in drifted far enough from the talker's real pre-lm_head hidden +that greedy decoding couldn't reach `codec_eos` at all. + +**Fix.** Recompile the talker with `TensorCaptureConfig(modules_to_capture=["norm"])` +so the NEFF emits the real post-RMSNorm hidden `[B, S, 1024]` as an extra +output. New compile script: `compile_talker.py`. New artifact: +`/tmp/qwen3_omni_compiled/talker_tp8_capnorm/`. Compile time: ~11 min. + +Then update `make_neuron_talker_shim` in `test_audio_out_full_neuron.py` to +parse `out[2]` from the NEFF output (logits, gathered_logits, **captured +norm**) and pass it as `hidden_states=(real_hidden,)` instead of `fake_hidden`. + +### 3. 
Talker `generate()` call missed HF's reference settings + +After fix 2, talker could reach `codec_eos` in principle, but 85 % of runs +still hit max with greedy decoding because the argmax trajectory occasionally +locks into loops (we saw `[318, 318, 318, ...]` looping at the end). + +HF's reference `Qwen3OmniMoeForConditionalGeneration.generate` uses: + +```python +suppress_tokens = [i for i in range(vocab - 1024, vocab) if i != codec_eos] +talker.generate(do_sample=True, top_k=50, top_p=0.8, temperature=0.9, + repetition_penalty=1.1, suppress_tokens=suppress_tokens, ...) +``` + +`suppress_tokens` masks out the 1 024 non-codec ids (text-token range left +over in the talker's shared vocab). Sampling + repetition penalty breaks the +`[318, 318, ...]` loops. + +**Fix.** `test_ttfb_rtf_bench.py` now passes the full HF-matching talker +config. Hit-max dropped from 100/100 → 15/100. + +### 4. `CHUNK_SIZE=25` (from 50) + +Streaming fires code2wav after every `CHUNK_SIZE` codec tokens. The baseline +CHUNK_SIZE=50 meant the user waited for 50 talker steps (~950 ms) plus one +big c2w call (~540 ms on CPU) before hearing the first audio. Halving to 25 +cuts both. + +**Fix.** `CHUNK_SIZE` and `LEFT_CTX` are now env-var-controlled in +`test_audio_streaming.py`. TTFB dropped 487 ms (from 2763 → 2276 ms). + +```bash +CHUNK_SIZE=25 LEFT_CTX=5 python test_ttfb_rtf_bench.py ... +``` + +Trade-off: the small `LEFT_CTX` re-compute at each chunk boundary adds a few +hundred ms to the total wall time. Net win for TTFB, net neutral for total. + +### 5. Code2Wav on Neuron + +With the previous fixes, TTFB breakdown was thinker 1345 ms + talker 540 ms ++ **first_c2w 387 ms on CPU**. That last CPU step was the largest +remaining non-Neuron cost in the critical path. + +**Compile.** `compile_code2wav.py` traces `Qwen3OmniMoeCode2Wav.forward` +(8-layer sliding-window transformer → upsample conv chain → BigVGAN decoder) +with `torch_neuronx.trace` at fixed input lengths. 
Bucket set `{30, 50, 128}` +covers the streaming chunk (CHUNK_SIZE=25 + LEFT_CTX=5 = 30) and the residual +tail at finalize. Single-core, fp32 (`--auto-cast=none`), no TP. Compile time: +~2.5 min per bucket. + +**Runtime shim.** `code2wav_neuron.py::NeuronCode2WavShim` replaces +`hf_model.code2wav`. At call time it picks the smallest bucket ≥ T, zero-pads +the codec-token tensor up to the bucket, runs the Neuron NEFF, and trims the +output back to `T * total_upsample` samples. `chunked_decode` is forwarded +through the same shim for symmetry. + +**Verified** bit-exact against CPU: `max_abs_diff = 0.00000`, +`cosine_similarity = 1.0000`. + +**Result.** First-chunk c2w: **387 ms → 122 ms** (3.2× faster). TTFB: 2276 → +**2000 ms**. + +Enable with the new flag on the bench: + +```bash +python test_ttfb_rtf_bench.py --num 100 --neuron-c2w +``` + +### 6. Thinker ↔ talker pipelining + +With everything on Neuron, the talker still waited for the **complete** +thinker output before starting. That's because HF's `_build_talker_inputs` +takes the full token sequence + full layer-23 hidden, then `talker.generate` +is called once with a pre-built `trailing_text_hidden` tensor. + +But HF's talker `prepare_inputs_for_generation` only reads +`trailing_text_hidden[:, generation_step]` at decode step `k`, which +corresponds to the `(k+4)`-th thinker assistant token. So the talker can in +principle start as soon as **4 thinker tokens** are available, and consume +the rest one-by-one as they stream in. + +**Implementation (`test_ttfb_pipelined_bench.py`).** + +1. **Background thread runs the thinker.** A custom + `StoppingCriteria` is installed (NxDI's `_sample` ignores HF's + `streamer` arg, but it does call `stopping_criteria(input_ids, ...)` on + every decode step — perfect tap point) that pushes each newly-appended + token into a `PipelineState` condition-variable buffer. The + `tensor_capture_hook` for layer-23 captures the hidden tensor in the + same callback path. + +2. 
**Main thread builds the talker prefill incrementally.** + `StreamingTalkerInputs.build_prefill()` blocks until (a) the prefill's + layer-23 hidden is captured (one Neuron forward), and (b) the first 4 + assistant tokens are in the buffer. Then it assembles only the prefill + slice that HF's `_get_talker_assistant_parts` would build — using + `assistant_hidden[:, :4]` plus the codec specials. + +3. **Talker decode reads the trailing buffer on demand.** + `Qwen3OmniMoeTalkerForConditionalGeneration.prepare_inputs_for_generation` + is wrapped (layered on top of the existing streaming-c2w wrapper) so + that, for each decode step `k`, it overwrites `kwargs["trailing_text_hidden"]` + with a tensor whose row `k` is `text_projection(embed(thinker_tokens[k+4]))`, + pulled from the streaming buffer. If the (k+4)-th token isn't out yet, + the call blocks on the condition variable until it arrives. Past the end + of the thinker output, the slice falls back to `tts_eos_embed`. + +**Result.** TTFB: 2000 → **1759 ms** (mean), and crucially p95: 3316 → +**1822 ms** (−45 %). The big tail-latency win is because TTFB no longer +scales with thinker output length — the user gets first audio within +~1800 ms whether the assistant reply is 50 or 200 tokens long. + +The mean improvement is more modest (~12 %) than the naive "subtract all +thinker decode" estimate (~600 ms) for two reasons: + +- **Neuron device queue serializes.** Both thinker and talker NEFFs are + compiled at TP=8 and run on the same 8 cores. They interleave on the + Neuron driver instead of running truly in parallel. The talker can start + earlier, but its forwards still queue behind in-flight thinker forwards. +- **Per-step CPU overhead.** Each talker decode step now does an extra + `text_projection` on a single token (~3 ms) plus condition-variable + signal/wait (~1 ms). Across 25 steps that's ~100 ms of extra serial work. 
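
The blocking hand-off between the thinker thread and the talker loop described in the implementation steps above can be sketched with a plain `threading.Condition`. `TokenBuffer`, `push`, `finish`, and `get` are hypothetical names for this illustration; the real `PipelineState` in `test_ttfb_pipelined_bench.py` may be structured differently, and the integer tokens stand in for the projected embeddings.

```python
import threading


class TokenBuffer:
    """Append-only buffer: a producer pushes tokens, a consumer blocks by index."""

    def __init__(self):
        self._cv = threading.Condition()
        self._tokens = []
        self._done = False

    def push(self, token):
        # Called by the producer (thinker thread) once per decode step.
        with self._cv:
            self._tokens.append(token)
            self._cv.notify_all()

    def finish(self):
        # Called once when the producer has emitted its last token.
        with self._cv:
            self._done = True
            self._cv.notify_all()

    def get(self, index, eos_fallback=None, timeout=None):
        """Return token `index`, waiting until it exists or the producer ends."""
        with self._cv:
            ready = self._cv.wait_for(
                lambda: len(self._tokens) > index or self._done,
                timeout=timeout,
            )
            if not ready:
                raise TimeoutError(f"token {index} not produced in time")
            if index < len(self._tokens):
                return self._tokens[index]
            return eos_fallback  # past the end of the producer's output


def fake_thinker(buf, tokens):
    for tok in tokens:
        buf.push(tok)
    buf.finish()


buf = TokenBuffer()
producer = threading.Thread(target=fake_thinker, args=(buf, [11, 12, 13, 14, 15, 16]))
producer.start()

PREFILL_TOKENS = 4  # the talker prefill consumes the first 4 assistant tokens
prefill = [buf.get(i) for i in range(PREFILL_TOKENS)]
# Decode step k reads token k + 4; steps past the end get the EOS stand-in.
trailing = [buf.get(k + PREFILL_TOKENS, eos_fallback="<tts_eos>") for k in range(4)]
producer.join()
print(prefill, trailing)
```

The consumer never polls: it sleeps on the condition variable until the producer signals, which is where the per-step signal/wait cost mentioned above comes from.
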
+ +A real ~600 ms win would require running the thinker and talker on +**different** Neuron core groups (e.g. cores 0-7 and 8-15 on a trn2 +instance) so that they can dispatch in parallel. That's a separate +compile-and-deploy change. + +Enable with the new bench script: + +```bash +python test_ttfb_pipelined_bench.py --num 100 --neuron-c2w +``` + +--- + +## New files + +| Path | Purpose | +|---|---| +| `test_thinker_ttft_bench.py` | Thinker-only TTFT / ITL / throughput on 100 convs | +| `test_ttfb_rtf_bench.py` | Full streaming TTFB / RTF on 100 convs (serial). `--neuron-c2w` for Neuron-backed code2wav. | +| `test_ttfb_pipelined_bench.py` | Full streaming TTFB / RTF on 100 convs with thinker↔talker pipelining (background thread + on-demand `trailing_text_hidden`). | +| `compile_talker.py` | Compile talker with `TensorCaptureConfig(["norm"])` | +| `compile_code2wav.py` | Compile code2wav at fixed T buckets | +| `code2wav_neuron.py` | Runtime shim that routes code2wav through the compiled NEFFs | +| `src/_upstream_compat.py` | Added `_patch_tensor_registry_clear` | + +Compiled artifacts: + +| Path | Contents | +|---|---| +| `/tmp/qwen3_omni_compiled/talker_tp8_capnorm/` | Talker with norm capture (replaces `talker_tp8/` for the audio pipeline) | +| `/tmp/qwen3_omni_compiled/code2wav_buckets/model_T{30,50,128}.pt` | Per-bucket code2wav NEFFs | + +## Modified files + +| Path | Change | +|---|---| +| `src/_upstream_compat.py` | Patch `TensorRegistry.configure` / `clear` to preserve `modules_to_capture` across bucket traces | +| `test_audio_out_full_neuron.py` | Point `TALKER_COMPILED` at `talker_tp8_capnorm`; shim now reads `out[2]` (captured norm) as the real hidden | +| `test_audio_streaming.py` | `CHUNK_SIZE` / `LEFT_CTX` read from env vars | + +## Remaining TTFB cost (after step 5 — pipelined) + +At the pipelined config (TTFB mean 1759 ms / p50 1778 ms): + +- **wait for first 4 thinker tokens** — ~765 ms + - thinker prefill 668 + 4 × ITL (40) + bg-thread overhead 
(~60) + - this is the floor the talker can't start before +- **talker prefill + 25 decode steps (overlapped with thinker decode)** — ~870 ms + - 25 × ~21 ms talker, plus startup, plus a few cv-blocks waiting for the next thinker token +- **first code2wav chunk** — 122 ms +- **TTFB total** — 1759 ms + +Everything is on Neuron and overlapped where possible. + +### Options to go below 1759 ms + +1. **Run thinker and talker on disjoint Neuron core groups** (~250-500 ms + headroom). Today both NEFFs are TP=8 and share cores 0-7, so the Neuron + driver serializes their forwards. On a trn2 instance with 16 cores, we + could compile a second copy of the talker on cores 8-15 and dispatch in + true parallel. The talker's 25 × 21 ms then overlaps the thinker decode + end-to-end; TTFB would drop toward `max(thinker_full, prefill + + 4·ITL + 25·talker_ITL + c2w)` ≈ 1300-1400 ms. + +2. **Shorter thinker replies (task-dependent).** The 68-token mean is set by + the dataset's average assistant reply length. Prompting the thinker to be + more concise (or truncating via a stop sequence) cuts decode time + linearly. With pipelining, this matters less than before — the talker + hides most of the decode — but on samples where the talker happens to + catch up to the thinker's output, fewer thinker tokens still helps. + +3. **Thinker speculative decoding.** NxDI supports EAGLE-style speculation. + A draft model that proposes 2-3 thinker tokens per target step pushes the + effective ITL toward 4-5 ms, narrowing both the "wait for first 4 tokens" + window and the per-step talker stall window. + +4. **Thinker prefill bucketing.** All 100 prompts land in bucket 2048 + because the system prompt is ~800 tokens. Splitting the system prompt + into a separately-cached prefix and running bucket-512 on the delta could + shave the 668 ms prefill to ~250 ms — directly reducing the + "wait for first 4 tokens" floor. + +5. 
**Talker + UCP fusion.** The 25 × ~21 ms today is one talker forward + one + UCP forward, both via separate NEFFs. Merging them into a single traced + op per step would save ~3 ms/step on cross-NEFF dispatch, ~75 ms total + at CHUNK_SIZE=25. + +6. **Smaller CHUNK_SIZE.** 25 is already aggressive. Going to 15 would save + ~200 ms on the talker portion but increases left-context recompute + overhead in code2wav. Worth measuring if a lower floor is needed. + +7. **Pure-C++ orchestration / GIL elimination.** ~100 ms of the pipelined + TTFB is Python overhead from the bg-thread / cv-block / per-step + `text_projection` glue. A C++ inference server that drives both Neuron + models via the runtime API directly (no Python in the hot loop) would + shed that. diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/README.md b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/README.md new file mode 100644 index 00000000..73a06eb1 --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/README.md @@ -0,0 +1,407 @@ +# Qwen3-Omni-30B-A3B-Instruct on AWS Neuron + +End-to-end inference of Qwen3-Omni-30B-A3B-Instruct on Trainium/Inferentia2 +via `neuronx_distributed_inference` (NxDI). Covers both ASR (speech→text) and +speech output (text→speech) pipelines. + +All five neural network modules run on Neuron with TP=8: + +| Module | Parameters | Role | +|---|---|---| +| Thinker MoE text decoder | 48 layers, 128 experts | generates text tokens from multimodal input | +| Vision encoder (Qwen3-VL ViT) | 27 layers | image → token embeddings | +| Audio encoder | 32-layer transformer | mel → audio token embeddings | +| Talker MoE | 20 layers, 128 experts | text + hidden → codec tokens | +| Unified Code Predictor | 5-layer dense GQA, 15-step unroll | expands each codec token to 15 residual codes | + +Code2Wav (codec→waveform) runs on CPU by default (~1 s per utterance). An +optional Neuron port, which cuts the first streaming chunk from 387 ms to +122 ms, is described in `BENCHMARK_OMNI2_TTFB.md`. 
+ +--- + +## End-to-end results + +### Audio output (text prompt → wav) + +Prompt: *"Please say hello and tell me about Neuron chips briefly."* +Output: 7.9 s @ 24 kHz, correct spoken answer. + +| Stage | Location | Time | +|---|---|---| +| Thinker text generation (80 tokens) | Neuron | 1.3 s | +| Layer-24 hidden assemble (from Neuron capture) | Neuron | 0.0 s | +| Talker decode loop (100 codec steps) | Neuron | 0.4 s | +| Unified Code Predictor (99 × 15 steps) | Neuron | 1.1 s | +| Code2Wav | CPU | 0.9 s | +| **Total** | | **~3.8 s** | + +Real-time factor **0.48x** (generates 2× faster than playback). + +**Progression** vs pure-CPU HF baseline: +| Version | Total | Note | +|---|---|---| +| CPU HF baseline | 91 s | reference | +| Thinker on Neuron only | 383 s | HF re-runs thinker on CPU inside `model.generate` | +| + skip HF thinker re-run | 107 s | replay thinker on CPU once for hidden states | +| + Neuron talker | 62 s | CPU talker (15 s) → Neuron (0.5 s) | +| + thinker hidden capture | 62 s | 45 s CPU re-forward → 0 s (directly from Neuron) | +| **+ Unified Code Predictor** | **3.8 s** | 59 s CPU CP (99 × 15 calls) → 1.1 s Neuron (99 calls) | + +### ASR (LibriSpeech test-clean, 100 samples) + +| Metric | Value | +|---|---| +| Samples | 100 | +| Total audio | 670 s | +| Total wall time | 66 s | +| Avg WER | 18.5 % (mostly casing/punctuation) | +| RTF | 0.12x (~8× real-time) | +| Per-clip latency (≤10 s clips) | 0.5–0.7 s | + +### Audio benchmarks (`test_audio_bench.py`) + +Four benchmarks run against the full Neuron pipeline. All generated wav files +are saved to `/tmp/qwen3_omni_bench/`. + +**1. Multi-length TTS** — scales gracefully; RTF actually improves as output +gets longer because the fixed thinker cost amortizes over more talker tokens. 
+ +| Tag | Thinker tokens | Codec tokens | Wav | Total | RTF | +|---|---:|---:|---:|---:|---:| +| short (`"Say hi."`) | 10 | 80 | 6.3 s | 2.76 s | 0.44x | +| medium | 115 | 150 | 11.9 s | 4.91 s | 0.41x | +| long | 128 | 250 | 19.9 s | 6.98 s | 0.35x | +| xlong | 128 | 400 | 31.9 s | 10.24 s | 0.32x | + +**2. Multi-speaker TTS** — all three speakers (`chelsie`, `ethan`, `aiden`) +produce audio of identical length and near-identical latency, confirming the +speaker-ID plumbing works correctly. + +| Speaker | Wav | Total | RTF | +|---|---:|---:|---:| +| chelsie | 11.9 s | 3.42 s | 0.29x | +| ethan | 11.9 s | 3.31 s | 0.28x | +| aiden | 11.9 s | 3.30 s | 0.28x | + +**3. Audio-in → audio-out** — LibriSpeech clip as input, model repeats it +back as spoken audio. Full multimodal path (audio encoder → thinker → talker +→ code2wav). + +| Stage | Time | +|---|---:| +| Thinker (incl. audio encoder) | 0.7 s | +| Build talker input | negligible | +| Talker generate | 2.9 s | +| Code2Wav | 1.2 s | +| **Total** | **4.8 s** | + +- Input: 3.5 s speech, reference "CONCORD RETURNED TO ITS PLACE AMIDST THE TENTS" +- Model heard → repeated as "Concorde returned to its place amidst the tents." +- Output wav: 15.9 s (model adds a bit of extra speech / phrasing) + +**4. Long TTS (up to 512 codec tokens)** — stress the code predictor / +code2wav chain in sustained mode. Latency scales linearly with codec length; +per-call cost stays stable at ~9 ms on the unified code predictor throughout. + +| Budget | Codec | Wav | Total | RTF | UCP time | UCP calls | UCP/call | +|---:|---:|---:|---:|---:|---:|---:|---:| +| 256 | 256 | 20.4 s | 6.92 s | 0.34x | 2.3 s | 255 | 9.0 ms | +| 400 | 400 | 31.9 s | 9.44 s | 0.30x | 3.7 s | 399 | 9.3 ms | +| 512 | 512 | 40.8 s | 11.68 s | 0.29x | 4.7 s | 511 | 9.2 ms | + +No drift in per-step UCP latency between 256 and 512 tokens. 
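
The per-call and RTF columns of the long-TTS table follow directly from its totals. A quick recomputation from the table's own numbers (the row tuples are copied from above; the variable names are ours):

```python
rows = [
    # (codec tokens, wav seconds, total seconds, UCP seconds, UCP calls)
    (256, 20.4, 6.92, 2.3, 255),
    (400, 31.9, 9.44, 3.7, 399),
    (512, 40.8, 11.68, 4.7, 511),
]

per_call_ms = []
for codec, wav_s, total_s, ucp_s, ucp_calls in rows:
    rtf = total_s / wav_s                 # wall time per second of audio
    ms = ucp_s / ucp_calls * 1000.0       # mean UCP latency per call
    per_call_ms.append(ms)
    print(f"{codec:>3} tokens: RTF {rtf:.2f}x, UCP {ms:.1f} ms/call")

# A small spread across budgets is what "no drift" means here.
spread = max(per_call_ms) - min(per_call_ms)
print(f"per-call spread across budgets: {spread:.2f} ms")
```

Because the UCP totals in the table are rounded to 0.1 s, the recomputed per-call values are only good to a few tenths of a millisecond.
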
+ +### TTFT / ITL (`test_ttft.py`) + +Per-stage time-to-first-token and inter-token latency, measured with a +`LogitsProcessor` that records `perf_counter()` on every token. Prompt +lengths 22, 28, and 57 tokens; talker budgets 80, 150, 250 codec tokens. + +| Stage | Metric | Value | +|---|---|---:| +| **Thinker** (48-layer MoE, TP=8) | TTFT (prefill) | 344–354 ms | +| | ITL mean | 12.1–13.1 ms | +| | ITL p50 / p95 | 12 / 22 ms | +| **Talker** (20-layer MoE, TP=8) | TTFT (prefill) | 24–30 ms | +| | ITL mean | 13.9–16.6 ms | +| | ITL p50 / p95 | 14 / 14 ms | +| **Code2Wav** (CPU, batch) | latency | 0.9–1.5 s per utterance | + +- Thinker TTFT ≈ 350 ms is dominated by the MoE prefill over a ~30-token + prompt (~12 ms/layer × 48 / 2 ≈ 290 ms of compute, plus bucket pad and + input move overhead). +- Thinker ITL ≈ 12 ms/token → **~80 tokens/s** text generation. +- Talker ITL ≈ 14 ms/token. Per talker step the pipeline also fires one + unified-CP NEFF call (~9 ms), so ~14 ms/step is consistent with + `talker (~5 ms) + UCP (~9 ms)`. +- Talker TTFT (~25 ms) is much lower than thinker TTFT because talker + prefill runs over a very short sequence (just the speaker token + special + prefix) — it's essentially a 14-token prefill through a 20-layer MoE. + +**End-to-end TTFB** (Time-To-First-Byte, prompt arrival → first wav sample +available on host): + +| Prompt length | Thinker tokens | Codec tokens | Wav length | Full TTFB | +|---:|---:|---:|---:|---:| +| 22 | 10 | 80 | 6.3 s | **2.73 s** | +| 28 | 94 | 150 | 11.9 s | **4.66 s** | +| 57 | 150 | 250 | 19.9 s | **7.13 s** | + +### Streaming code2wav (`test_audio_streaming.py`) + +Instead of waiting for the entire talker output and then running one large +`code2wav.chunked_decode`, we fire `code2wav` inline on 50-codec-token +chunks (~4 s audio each) as soon as they accumulate. Chunks are emitted +sequentially within the same thread — the talker pauses ~550 ms every 50 +codec tokens while the CPU `code2wav` decodes that chunk. 
This sacrifices a +little total wall time (per-chunk overhead adds up) in exchange for a +dramatically lower time-to-first-audio. + +Patch point: `Qwen3OmniMoeTalkerForConditionalGeneration.prepare_inputs_for_generation` +is wrapped at the class level; it intercepts `residual_codes` right after +HF builds it, appends to a shared list, and fires a chunk whenever the +list grows by 50. + +**First-audio latency comparison (batch vs streaming, same prompts)**: + +| Scenario | Wav length | Batch TTFB | Streaming TTFB | Improvement | +|---|---:|---:|---:|---:| +| short | 6.3 s | 2.73 s | **2.02 s** | −26 % | +| medium | 11.9 s | 4.66 s | **2.99 s** | −36 % | +| long | 23.8 s | 7.13 s | **3.41 s** | **−52 %** | +| xlong | 40.6 s | ~11.7 s (extrap.) | **4.00 s** | **−66 %** | + +Per-chunk code2wav cost stays steady (~550 ms per 50-token chunk). Total +wall time is slightly higher than batch mode (e.g. long: 9.66 s vs 7.13 s) +because small-chunk overhead (left-context re-compute, per-call Python +dispatch) dominates; but because user-perceived latency is TTFB, not total, +streaming is still a clear win for interactive use. + +Config knobs: `CHUNK_SIZE=50`, `LEFT_CTX=10` in `test_audio_streaming.py`. + +### Conversational audio-in benchmark (omni2, 100 convs) + +See [`BENCHMARK_OMNI2_TTFB.md`](BENCHMARK_OMNI2_TTFB.md) for a detailed TTFB +/ RTF benchmark on 100 real multi-turn audio-in conversations (prompts +1164–1494 tokens). Covers the progressive optimizations that took TTFB from +**2727 ms → 1759 ms** (−35 %, p95 from 3564 → 1822 ms / −49 %) and the +talker max-token truncation rate from 100 % → 12 %: + +1. Patched `TensorRegistry.clear()` so `layers.23` capture survives across + all bucket traces (prompts longer than 256 tokens previously hit a zero + fallback). +2. Recompiled the talker with `TensorCaptureConfig(["norm"])` and wired the + shim to use the real post-RMSNorm hidden, so greedy decoding can now reach + `codec_eos` instead of always looping on `[318, 318, …]`. +3. 
Switched talker `generate()` to HF's reference settings + (`do_sample=True, top_k=50, top_p=0.8, temperature=0.9, + repetition_penalty=1.1`, plus `suppress_tokens` for the non-codec id range). +4. `CHUNK_SIZE=25` / `LEFT_CTX=5` (was 50 / 10): TTFB −487 ms. +5. Ported `code2wav` to Neuron (bit-exact vs CPU): first-chunk c2w 387 ms → + 122 ms. +6. Pipelined thinker ↔ talker — talker starts as soon as 4 thinker tokens + are buffered and reads `trailing_text_hidden[k]` on demand. Mean TTFB + 2000 → 1759 ms; p95 cut nearly in half (3316 → 1822 ms). + +Best configuration: `NEURON_RT_VISIBLE_CORES=0-7 CHUNK_SIZE=25 LEFT_CTX=5 +python test_ttfb_pipelined_bench.py --num 100 --neuron-c2w`. + +--- + +## Repository layout + +``` +contrib/models/Qwen3-Omni-30B-A3B-Instruct/ +├── README.md (this file) +└── src/ + ├── modeling_qwen3_omni.py top-level thinker + vision config / weight conversion + ├── modeling_qwen3_omni_text.py thinker MoE text decoder (48 layers, reuses Qwen3-VL attention) + ├── modeling_qwen3_omni_audio.py audio encoder (32-layer transformer on Neuron, Conv2d frontend / post-proc on CPU) + ├── modeling_qwen3_omni_talker.py talker MoE body (20 layers) + TalkerInferenceConfig + HF→Neuron weight conversion + ├── modeling_qwen3_omni_code_predictor.py per-call (debug) and unified (production) code predictor + ├── _upstream_compat.py runtime patches to NxDI's HuggingFaceGenerationAdapter and qwen3_vl vision loader + └── _model_path.py model-path helper +``` + +### Test scripts (in `/home/ubuntu/`) + +| File | Purpose | +|---|---| +| `test_asr_qwen3_omni.py` | ASR benchmark (LibriSpeech); builds and loads thinker + vision + audio encoder. Also exposes `build_and_load_model` reused by other tests. | +| `test_audio_out_cpu.py` | Pure-CPU HF reference for audio output. | +| `test_audio_out_neuron.py` | Phase 1 mixed: thinker on Neuron, talker+code2wav on CPU. | +| `test_audio_out_full_neuron.py` | Full Neuron pipeline (thinker + talker + unified CP + code2wav-CPU). 
| +| `test_audio_bench.py` | Four-benchmark audio suite (multi-length, multi-speaker, audio-in→audio-out, long TTS). Dumps JSON + wavs. | +| `test_ttft.py` | TTFT / ITL micro-benchmark — records per-token timestamps for thinker + talker and computes end-to-end TTFB. | +| `test_audio_streaming.py` | Streaming code2wav — emits 50-codec-token audio chunks inline for low TTFB. `CHUNK_SIZE` / `LEFT_CTX` overridable via env. | + +### Benchmark scripts (in the contrib dir) + +| File | Purpose | +|---|---| +| `test_thinker_ttft_bench.py` | Thinker-only TTFT / ITL / throughput on the omni2 100-conv dataset. | +| `test_ttfb_rtf_bench.py` | Full streaming TTFB / RTF on the omni2 100-conv dataset (serial thinker→talker). `--neuron-c2w` routes code2wav through Neuron. | +| `test_ttfb_pipelined_bench.py` | Same dataset, but thinker and talker overlap: thinker streams tokens to a bg thread, talker reads `trailing_text_hidden[k]` on demand. Lowest TTFB and tightest tail latency. | +| `compile_talker.py` | Recompile the talker with `TensorCaptureConfig(["norm"])` so the NEFF exposes the real post-RMSNorm hidden (needed by the `code_predictor` path). Output: `talker_tp8_capnorm/`. | +| `compile_code2wav.py` | Compile the vocoder at fixed input buckets (default `{30, 50, 128, 300, 512}`, matching `DEFAULT_BUCKETS` in the script). Output: `code2wav_buckets/`. | +| `code2wav_neuron.py` | Runtime shim that replaces `hf_model.code2wav` with a bucket-dispatching Neuron wrapper. | + +--- + +## Setup + +```bash +source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate +``` + +Place HF model at `/home/ubuntu/models/Qwen3-Omni-30B-A3B-Instruct`. 
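A pre-flight check can catch a misplaced checkpoint before a long compile run. The sketch below is not part of this repository: `missing_model_files` is a hypothetical helper, and treating `config.json` as the only required file is an assumption (the compile scripts read it from the model directory, but a complete checkpoint also needs its weight shards).

```python
from pathlib import Path

# Default model location named in this README.
MODEL_DIR = Path("/home/ubuntu/models/Qwen3-Omni-30B-A3B-Instruct")


def missing_model_files(model_dir, required=("config.json",)):
    """Return the names in `required` that are not regular files in model_dir.

    Hypothetical helper: the default `required` tuple is an assumption,
    not an exhaustive checkpoint manifest.
    """
    model_dir = Path(model_dir)
    return [name for name in required if not (model_dir / name).is_file()]


if __name__ == "__main__":
    missing = missing_model_files(MODEL_DIR)
    if missing:
        print(f"{MODEL_DIR}: missing {', '.join(missing)}")
    else:
        print(f"{MODEL_DIR}: required files present")
```

Pass a longer `required` tuple if the deployment should also verify tokenizer or weight files; the names to check depend on how the checkpoint was downloaded.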
+ +Expected compiled artifacts in `/tmp/qwen3_omni_compiled/`: + +| Directory | Compiles | Size | +|---|---|---| +| `multimodal_tp8_cap23/` | Thinker (48L MoE) + vision, with layer-24 hidden capture | ~3 GB | +| `audio_encoder_tp8/` | 32-layer audio transformer | ~200 MB | +| `talker_tp8/` | Talker (20L MoE), no capture | ~1.2 GB | +| `talker_tp8_capnorm/` | Talker (20L MoE) with `norm` capture (needed by the audio-output pipeline; see `BENCHMARK_OMNI2_TTFB.md`) | ~1.2 GB | +| `code_predictor_unified_tp8/` | Unified CP (5L dense, 15-step unrolled) | ~150 MB | +| `code2wav_buckets/` | Optional Neuron code2wav, one NEFF per bucket size | ~500 MB total | + +--- + +## How to run + +### Audio output (text → speech) + +```bash +cd /home/ubuntu +NEURON_RT_VISIBLE_CORES=0-7 python test_audio_out_full_neuron.py \ + --prompt "Please say hello and tell me about Neuron chips briefly." \ + --out /tmp/out.wav +``` + +First run compiles all components (~45 min total). Subsequent runs reuse +the compiled artifacts and take ~60 s to load + 4 s to infer. + +### ASR + +```bash +NEURON_RT_VISIBLE_CORES=0-7 python test_asr_qwen3_omni.py --num-samples 100 +``` + +Opt-in flag `QWEN3_OMNI_CAPTURE_LAYER_HIDDEN=23` enables hidden-state capture +during thinker inference (required for the audio-output pipeline; compiled +separately from the ASR-only artifact to avoid the extra trace output). + +--- + +## Key design decisions + +### 1. Thinker weight sharing + +`NeuronQwen3OmniForCausalLM.checkpoint_loader_fn` loads HF safetensors once, +partitions by owning model (text vs vision), and returns a fresh shallow-copy +dict per builder call. Tensors are shared between text and vision +partitions. Without this, sharding text followed by vision peaked CPU RAM at +200 GB (and the second builder encountered missing vision keys because NxDI's +`preprocess_checkpoint` destructively deletes keys not owned by the current +model). With the fix, peak RAM stays around 20 GB. + +### 2. 
MRoPE + 3D position_ids + +The thinker uses interleaved MRoPE with mrope_section `[24, 20, 20]`. +`rotary_position_ids` with shape `[3, B, S]` is passed as input-generator +arg 21 (see `NeuronQwen3OmniTextModelWrapper._ROTARY_POSITION_IDS_INDEX`). +The upstream `pad_inputs` is patched (`test_asr_qwen3_omni.py:_patched_pad_inputs`) +to preserve this and the trailing `deepstack_vision_embeds` slot. + +### 3. Audio encoder — 20 heads don't divide TP=8 + +Padded num_heads 20 → 24 (next multiple of 8) and zero-fill the Q/K/V/out_proj +weight rows for the added heads (`NeuronAudioAttention.__init__` and +`convert_hf_to_neuron_state_dict`). Zero-padded heads produce zero output; +the zero column in `out_proj` ensures they don't contaminate the residual +stream. + +### 4. Audio encoder attention window + +`n_window=50` gives tiny 13-token attention blocks in the basic `cu_seqlens` +path, which corrupts long audio (≥15 s). Fixed by using +`_compute_inference_cu_seqlens` with `n_window_infer=800`, matching HF. + +### 5. Scatter bug at bucket boundaries + +HF's scatter path uses `fill_value = pad_limit - 1` which means when +`input_ids.shape[1] == bucket_size` exactly, the fill positions land on the +last *real* prompt token. Audio embeddings' padding zeros then clobber that +token (observed as "55." garbage on 16.8 s audio). Fixed by appending a pad +token when prompt length is a bucket boundary +(`modeling_qwen3_omni.py:forward`). + +### 6. Talker: shared_expert differs from routed experts + +Qwen3-Omni talker has `moe_intermediate_size=384` (routed experts) but +`shared_expert_intermediate_size=768` and a sigmoid-gated shared path. NxDI's +`initialize_moe_module` ties shared-expert size to `config.intermediate_size`, +so we build the shared expert as a separate `SharedExpertSwiGLU` module with +its own intermediate size and apply it alongside the routed MoE inside +`NeuronTalkerDecoderLayer`. + +### 7. 
Talker: 2 KV heads don't divide TP=8 + +Replicated the 2 KV heads into 8 (one per rank) during weight conversion +(`convert_talker_hf_to_neuron`, the `kv_pad` logic). Each replicated head +computes the same attention, so this is bit-exact up to bf16 noise. + +### 8. Talker ↔ HF generate integration + +The HF talker pipeline computes, for every decode step, a sum of +`last_id_hidden + code_predictor mid hiddens + trailing text hidden` on CPU, +then feeds the resulting `inputs_embeds` to `talker.model.forward`. We +install a shim (`NeuronTalkerShim`) that replaces `talker.model` and +routes the already-summed `inputs_embeds` through the Neuron NEFF via the +`vision_embeddings` input slot — the same pattern Qwen2.5-Omni uses. +`codec_head` is swapped to `nn.Identity` since the Neuron NEFF already +applies its internal `lm_head`. + +### 9. Thinker layer-24 hidden capture + +HF's talker inputs need per-token hidden states at `accept_hidden_layer=24` +(the 24th post-layer hidden, i.e., output of the 0-indexed layer 23). These +are extracted from the Neuron thinker for free via +`TensorCaptureConfig(modules_to_capture=["layers.23"])` plus +`output_logits=True`, instead of replaying the 30 B model on CPU (~45 s). +The capture hook is passed to `adapter.generate` as `tensor_capture_hook`. + +### 10. Unified Code Predictor + +HF's code predictor generates 15 residual codes per talker decode step with +15 sequential forward passes. Per-call Neuron overhead (~10 ms) × +(15 calls × 99 talker steps) would be ≈15 s of dispatch overhead alone, only +about 4× better than HF's CPU baseline (~60 s over the same workload). Instead, the entire 15-step +argmax-loop is unrolled into a single NEFF +(`UnifiedNeuronCodePredictor.forward`) that completes in ~11 ms per talker +step, for **54× speedup** over HF CPU. + +The unrolled trace uses a fixed 16-position buffer (2 prefill + 14 decode) +and re-runs full attention each round (no KV cache). 
The 15 codec embedding +tables and 15 LM heads are stacked into single tensors and indexed inside +the trace. + +--- + +## Known limitations + +- **bf16 numerical drift** — occasionally one of the 15 residual codes + differs from the golden output (e.g., step 13 may pick code 293 vs golden 1025 when + the top-2 logits are separated by <0.002). Audio quality is unaffected. +- **Code2Wav on Neuron (opt-in)** — `compile_code2wav.py` traces the + vocoder at a handful of fixed input lengths; `code2wav_neuron.py` + dispatches by chunk size at runtime. Bit-exact vs CPU and ~3× faster on + the per-chunk streaming call. Enable via `--neuron-c2w` in + `test_ttfb_rtf_bench.py`. Compile `/tmp/qwen3_omni_compiled/code2wav_buckets/` + once with the expected chunk sizes (defaults cover `CHUNK_SIZE=25`). +- **Talker compilation time** — ~10 min for the 20-layer MoE. +- **CPU HF model required at inference time** — for the `_get_talker_user_parts` + / `_get_talker_assistant_parts` helpers and `text_projection`/`hidden_projection` + projections. ~60 GB CPU RAM. A future optimization would lift those + projections onto Neuron too. diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/code2wav_neuron.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/code2wav_neuron.py new file mode 100644 index 00000000..cab24d8f --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/code2wav_neuron.py @@ -0,0 +1,104 @@ +"""Runtime shim: replace ``hf_model.code2wav`` CPU calls with Neuron NEFFs. + +Buckets are chosen at install time; at call time we pick the smallest bucket +>= T and pad the codec-tokens tensor up to it. The output is trimmed back to +``T * total_upsample`` samples to match CPU behavior. + +Install once per process via ``install_neuron_code2wav(hf_model)``. 
+""" +import os +from pathlib import Path +from typing import List, Optional + +import torch + +DEFAULT_BUCKETS_DIR = Path("/tmp/qwen3_omni_compiled/code2wav_buckets") + + +class NeuronCode2WavShim(torch.nn.Module): + """Holds one compiled NEFF per bucket size; dispatches on T at call time.""" + + def __init__(self, hf_c2w, buckets_dir: Path, buckets: Optional[List[int]] = None): + super().__init__() + # We want to keep ``config`` and ``total_upsample`` from the original so + # callers that read those still work (``chunked_decode`` uses + # ``self.total_upsample``). + self.hf_c2w = hf_c2w + self.config = hf_c2w.config + self.total_upsample = hf_c2w.total_upsample + + found = {} + for f in sorted(buckets_dir.glob("model_T*.pt")): + # Parse T from filename "model_T{int}.pt" + T = int(f.stem.split("_T")[-1]) + if buckets is None or T in buckets: + found[T] = f + if not found: + raise RuntimeError(f"No code2wav NEFFs found in {buckets_dir}") + + self._neffs = {} + for T in sorted(found): + print(f" [code2wav shim] loading T={T} from {found[T]}") + self._neffs[T] = torch.jit.load(str(found[T])) + self._bucket_sizes = sorted(self._neffs.keys()) + self._max_bucket = self._bucket_sizes[-1] + + def _pick_bucket(self, T: int) -> int: + for b in self._bucket_sizes: + if b >= T: + return b + # T exceeds the largest bucket — fall back to CPU. + return -1 + + def forward(self, codes: torch.Tensor) -> torch.Tensor: + B, Q, T = codes.shape + bucket = self._pick_bucket(T) + if bucket == -1: + # No NEFF big enough: use CPU + return self.hf_c2w(codes) + + if T == bucket: + padded_codes = codes + else: + # Right-pad with zeros (valid codec ids live in [0, codebook_size=2048)) + pad_amount = bucket - T + pad = torch.zeros((B, Q, pad_amount), dtype=codes.dtype, device=codes.device) + padded_codes = torch.cat([codes, pad], dim=-1) + + neuron = self._neffs[bucket] + wav = neuron(padded_codes) + # Output shape is (B, 1, bucket * total_upsample). Trim to real length. 
+ real_samples = T * self.total_upsample + wav = wav[..., :real_samples] + return wav + + # chunked_decode is inherited behavior on hf_c2w but our forward shim gets + # called with codes — we re-implement here for symmetry and to avoid HF + # accidentally calling the CPU forward. + def chunked_decode(self, codes: torch.Tensor, chunk_size: int = 300, + left_context_size: int = 25) -> torch.Tensor: + wavs = [] + start_index = 0 + while start_index < codes.shape[-1]: + end_index = min(start_index + chunk_size, codes.shape[-1]) + context_size = left_context_size if start_index - left_context_size > 0 else start_index + codes_chunk = codes[..., start_index - context_size: end_index] + wav_chunk = self.forward(codes_chunk) + wavs.append(wav_chunk[..., context_size * self.total_upsample:]) + start_index = end_index + return torch.cat(wavs, dim=-1) + + +def install_neuron_code2wav( + hf_model, + buckets_dir: Path = DEFAULT_BUCKETS_DIR, + buckets: Optional[List[int]] = None, +) -> NeuronCode2WavShim: + """Replace ``hf_model.code2wav`` with a Neuron-backed shim. + + Returns the shim (holding the original HF code2wav on ``.hf_c2w`` in case + callers want to fall back). + """ + shim = NeuronCode2WavShim(hf_model.code2wav, buckets_dir, buckets=buckets) + hf_model.code2wav = shim + return shim diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/compile_audio.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/compile_audio.py new file mode 100644 index 00000000..8ffe5847 --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/compile_audio.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +"""Compile Qwen3-Omni audio encoder transformer to a single Neuron core. + +Conv2d frontend stays on CPU. Transformer layers + postprocessor are traced +per bucket size via torch_neuronx.trace (no TP). 
+ +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + NEURON_RT_VISIBLE_CORES=16 python compile_audio.py +""" +import json +import logging +import sys +import time +from pathlib import Path + +import torch + +sys.path.insert(0, str(Path(__file__).parent / "src")) + +logging.basicConfig(level=logging.INFO) + +MODEL_PATH = "/home/ubuntu/models/Qwen3-Omni-30B-A3B-Instruct" +COMPILED_PATH = "/home/ubuntu/traced_model/Qwen3-Omni-audio" + +from modeling_qwen3_omni_audio import Qwen3OmniAudioEncoder + +config_path = Path(MODEL_PATH) / "config.json" +with open(config_path) as f: + full_config = json.load(f) + +audio_config = full_config.get("thinker_config", {}).get("audio_config", {}) +print(f"Audio config: d_model={audio_config.get('d_model')}, " + f"layers={audio_config.get('encoder_layers', audio_config.get('num_hidden_layers'))}, " + f"heads={audio_config.get('encoder_attention_heads')}") + +print("Loading audio encoder weights...") +t0 = time.perf_counter() +encoder = Qwen3OmniAudioEncoder.from_pretrained(MODEL_PATH, audio_config) +print(f"Weights loaded in {time.perf_counter() - t0:.1f}s") + +# Log the encoder class; the bucket list is internal to compile_neuron(). +print(f"Compiling audio encoder to Neuron (encoder class: {encoder.__class__.__name__})...") +t0 = time.perf_counter() +encoder.compile_neuron(COMPILED_PATH) +elapsed = time.perf_counter() - t0 +print(f"Audio encoder compilation complete in {elapsed:.1f}s") +print(f"Compiled audio encoder saved to: {COMPILED_PATH}") diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/compile_code2wav.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/compile_code2wav.py new file mode 100644 index 00000000..84998e58 --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/compile_code2wav.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +"""Compile code2wav (vocoder) on Neuron. + +code2wav is the ConvNeXt + upsample + BigVGAN stack that maps 16-channel codec +tokens to 24 kHz audio. 
It ran on CPU in the streaming bench, spending ~390 ms +on the first chunk and blocking TTFB. + +The model is a fixed-size graph given a fixed input length T (in codec tokens). +We trace one NEFF per bucket and dispatch at runtime by rounding T up to the +next bucket. Compile via ``torch_neuronx.trace`` (not the SPMD ModelBuilder) — +single-core, fp32 weights, no tensor parallelism. + +Output: /tmp/qwen3_omni_compiled/code2wav_buckets/model_T{T}.pt for each T. + +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + NEURON_RT_VISIBLE_CORES=0-7 python compile_code2wav.py +""" +import os +os.environ.setdefault("NEURON_RT_VISIBLE_CORES", "0-7") +os.environ["TRANSFORMERS_VERBOSITY"] = "error" + +import argparse +import time +from pathlib import Path + +import torch +import torch_neuronx +from transformers import Qwen3OmniMoeForConditionalGeneration + +MODEL_PATH = "/home/ubuntu/models/Qwen3-Omni-30B-A3B-Instruct" +# Bucket set tuned for the streaming bench: +# * streaming chunk: CHUNK_SIZE + LEFT_CTX = 25 + 5 = 30 +# * finalize chunk (tail): LEFT_CTX + 0..CHUNK_SIZE-1 ≤ 30 +# * non-streaming `chunked_decode` default: chunk_size=300 + left_context=25 +# We cover the streaming sizes and a large bucket for safety. +DEFAULT_BUCKETS = [30, 50, 128, 300, 512] + + +class Code2WavWrapper(torch.nn.Module): + """Wraps ``Qwen3OmniMoeCode2Wav.forward`` so it is trace-friendly. + + The original forward does a shape check that raises a Python error if + codes.shape[1] != num_quantizers. We keep that check out of the trace + (it's a static invariant) and only expose the compute. 
+ """ + + def __init__(self, c2w): + super().__init__() + self.c2w = c2w + + def forward(self, codes): + # codes: [1, num_quantizers=16, T], long + c2w = self.c2w + hidden = c2w.code_embedding(codes + c2w.code_offset).mean(1) + hidden = c2w.pre_transformer(inputs_embeds=hidden).last_hidden_state + hidden = hidden.permute(0, 2, 1) + for blocks in c2w.upsample: + for block in blocks: + hidden = block(hidden) + wav = hidden + for block in c2w.decoder: + wav = block(wav) + return wav.clamp(min=-1, max=1) + + +def compile_one(c2w_wrapper, T, out_path): + example = torch.randint(0, 2048, (1, 16, T), dtype=torch.long) + print(f" tracing T={T} ...") + t0 = time.time() + traced = torch_neuronx.trace( + c2w_wrapper, + example, + compiler_workdir=f"/tmp/c2w_workdir_T{T}", + # fp32 for correctness; c2w is fairly small so cost is modest. + compiler_args="--auto-cast=none", + ) + traced.save(str(out_path)) + # Quick sanity: run once + out = traced(example) + print(f" done in {time.time()-t0:.0f}s, out shape={tuple(out.shape)}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--out-dir", default="/tmp/qwen3_omni_compiled/code2wav_buckets") + parser.add_argument("--buckets", nargs="*", type=int, default=DEFAULT_BUCKETS) + args = parser.parse_args() + + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + print(f"Loading HF model (we only need .code2wav) ...") + t0 = time.time() + hf_model = Qwen3OmniMoeForConditionalGeneration.from_pretrained( + MODEL_PATH, dtype=torch.float32, low_cpu_mem_usage=True, device_map="cpu", + ) + hf_model.eval() + print(f" loaded in {time.time()-t0:.0f}s") + + wrapper = Code2WavWrapper(hf_model.code2wav).eval() + + for T in args.buckets: + out_path = out_dir / f"model_T{T}.pt" + if out_path.exists(): + print(f"T={T}: already compiled at {out_path}, skipping") + continue + print(f"T={T}: compiling to {out_path}") + compile_one(wrapper, T, out_path) + + +if __name__ == "__main__": + main() diff --git 
a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/compile_multimodal.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/compile_multimodal.py new file mode 100644 index 00000000..6213f2ea --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/compile_multimodal.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +"""Compile Qwen3-Omni multimodal model (text MoE + vision encoder) for Neuron. + +Both text and vision models use TP=16 with LNC=2, running on 32 physical cores. + +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + NEURON_RT_VISIBLE_CORES=0-31 python compile_multimodal.py +""" +import sys +import time +from pathlib import Path + +import torch + +sys.path.insert(0, str(Path(__file__).parent / "src")) + +from modeling_qwen3_omni import ( + NeuronQwen3OmniForCausalLM, + Qwen3OmniInferenceConfig, + load_qwen3_omni_multimodal_config, +) +from neuronx_distributed_inference.models.config import MoENeuronConfig, NeuronConfig + +MODEL_PATH = "/home/ubuntu/models/Qwen3-Omni-30B-A3B-Instruct" +COMPILED_PATH = "/home/ubuntu/traced_model/Qwen3-Omni-multimodal" +TP_DEGREE = 16 + +text_neuron_config = MoENeuronConfig( + tp_degree=TP_DEGREE, + batch_size=1, + seq_len=4096, + max_context_length=2048, + torch_dtype=torch.bfloat16, + on_device_sampling_config={"top_k": 1, "do_sample": False}, + blockwise_matmul_config={"use_torch_block_wise": True}, +) + +vision_neuron_config = NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=1, + seq_len=4096, + torch_dtype=torch.bfloat16, +) + +config = Qwen3OmniInferenceConfig( + text_neuron_config=text_neuron_config, + vision_neuron_config=vision_neuron_config, + load_config=load_qwen3_omni_multimodal_config(MODEL_PATH), +) + +model = NeuronQwen3OmniForCausalLM(MODEL_PATH, config) + +print(f"Compiling multimodal model with TP={TP_DEGREE} ...") +t0 = time.perf_counter() +model.compile(COMPILED_PATH) +elapsed = time.perf_counter() - t0 +print(f"Compilation complete in {elapsed:.1f}s") +print(f"Compiled model saved to: 
{COMPILED_PATH}") diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/compile_talker.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/compile_talker.py new file mode 100644 index 00000000..96bc78e9 --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/compile_talker.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +"""Compile the Qwen3-Omni talker MoE on Neuron with ``norm`` tensor capture. + +Why tensor_capture_config: HF talker's per-step generate loop reads the +transformer's final-layer hidden (pre-lm_head) via ``output.hidden_states[-1]`` +and feeds it to ``code_predictor`` to produce 15 residual codes. Our previous +compile had ``tensor_capture_config=None`` and the runtime shim fabricated the +hidden by re-embedding argmax'd tokens. That lossy stand-in drifts and never +lets the talker emit ``codec_eos_token_id`` — every bench sample maxed out at +``max_new_tokens``. Adding a capture on ``norm`` (the RMSNorm applied right +before lm_head) gives us the real hidden through the NEFF at negligible cost. + +Output path: ``/tmp/qwen3_omni_compiled/talker_tp8_capnorm``. 
+ +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + NEURON_RT_VISIBLE_CORES=0-7 python compile_talker.py +""" +import os +os.environ.setdefault("NEURON_RT_VISIBLE_CORES", "0-7") +os.environ["TRANSFORMERS_VERBOSITY"] = "error" + +import sys +from pathlib import Path +_HERE = Path(__file__).resolve().parent +_SRC = _HERE / "src" +if str(_SRC) not in sys.path: + sys.path.insert(0, str(_SRC)) + +import _upstream_compat # noqa: F401 — installs TensorRegistry.clear fix + +import argparse +import json +import time + +import torch + +from neuronx_distributed_inference.models.config import ( + MoENeuronConfig, TensorCaptureConfig, +) +from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config +from modeling_qwen3_omni_talker import ( + NeuronTalkerForCausalLM, TalkerInferenceConfig, +) + +MODEL_PATH = "/home/ubuntu/models/Qwen3-Omni-30B-A3B-Instruct" +# Buckets mirror the existing talker_tp8 compile. +TALKER_BUCKETS = [64, 128, 256, 512, 1024, 2048, 4096] +TP_DEGREE = 8 + + +def _talker_load_config(): + """Return a load_config hook that populates the config from + talker.text_config (the 20-layer MoE inside Qwen3-Omni). + + We reuse ``load_pretrained_config`` by handing it the nested + ``talker.text_config`` (which is itself a ``PretrainedConfig``). 
+ """ + from transformers import AutoConfig + full_cfg = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True) + return load_pretrained_config(hf_config=full_cfg.talker_config.text_config) + + +def build_config(): + neuron_config = MoENeuronConfig( + batch_size=1, + seq_len=4096, + max_context_length=4096, + ctx_batch_size=1, + tp_degree=TP_DEGREE, + torch_dtype=torch.bfloat16, + fused_qkv=False, + sequence_parallel_enabled=False, + flash_decoding_enabled=False, + qkv_kernel_enabled=False, + qkv_nki_kernel_enabled=False, + attn_kernel_enabled=False, + enable_bucketing=True, + context_encoding_buckets=TALKER_BUCKETS, + token_generation_buckets=TALKER_BUCKETS, + # Capture the talker's final RMSNorm output (pre-lm_head hidden). The + # underlying NeuronBaseModel attribute is ``norm``. The NEFF emits an + # extra output tensor of shape [B, S_bucket, hidden_size=1024] per + # forward. + tensor_capture_config=TensorCaptureConfig( + modules_to_capture=["norm"], + capture_inputs=False, + ), + output_logits=True, + blockwise_matmul_config={"use_torch_block_wise": True}, + ) + + cfg = TalkerInferenceConfig( + neuron_config=neuron_config, + load_config=_talker_load_config(), + ) + return cfg + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--out", default="/tmp/qwen3_omni_compiled/talker_tp8_capnorm") + args = parser.parse_args() + + cfg = build_config() + print(f"Creating NeuronTalkerForCausalLM (tp={TP_DEGREE}, buckets={TALKER_BUCKETS})") + app = NeuronTalkerForCausalLM(model_path=MODEL_PATH, config=cfg) + + print(f"Compiling to {args.out} ...") + t0 = time.time() + app.compile(args.out) + print(f"Compile took {time.time()-t0:.0f}s") + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/compile_tp8.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/compile_tp8.py new file mode 100644 index 00000000..7fd40e80 --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/compile_tp8.py @@ -0,0 +1,57 
@@ +#!/usr/bin/env python3 +"""Compile Qwen3-Omni text MoE + vision encoder at TP=8 (LNC=2). + +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + NEURON_RT_VISIBLE_CORES=0-15 python compile_tp8.py +""" +import sys +import time +from pathlib import Path + +import torch + +sys.path.insert(0, str(Path(__file__).parent / "src")) + +from modeling_qwen3_omni import ( + NeuronQwen3OmniForCausalLM, + Qwen3OmniInferenceConfig, + load_qwen3_omni_multimodal_config, +) +from neuronx_distributed_inference.models.config import MoENeuronConfig, NeuronConfig + +MODEL_PATH = "/home/ubuntu/models/Qwen3-Omni-30B-A3B-Instruct" +COMPILED_PATH = "/home/ubuntu/traced_model/Qwen3-Omni-tp8" +TP_DEGREE = 8 + +text_neuron_config = MoENeuronConfig( + tp_degree=TP_DEGREE, + batch_size=1, + seq_len=4096, + max_context_length=2048, + torch_dtype=torch.bfloat16, + on_device_sampling_config={"top_k": 1, "do_sample": False}, + blockwise_matmul_config={"use_torch_block_wise": True}, +) + +vision_neuron_config = NeuronConfig( + tp_degree=TP_DEGREE, + batch_size=1, + seq_len=4096, + torch_dtype=torch.bfloat16, +) + +config = Qwen3OmniInferenceConfig( + text_neuron_config=text_neuron_config, + vision_neuron_config=vision_neuron_config, + load_config=load_qwen3_omni_multimodal_config(MODEL_PATH), +) + +model = NeuronQwen3OmniForCausalLM(MODEL_PATH, config, skip_vision_encoder=True) + +print(f"Compiling text MoE + vision at TP={TP_DEGREE} (LNC=2) ...") +t0 = time.perf_counter() +model.compile(COMPILED_PATH) +elapsed = time.perf_counter() - t0 +print(f"Compilation complete in {elapsed:.1f}s") +print(f"Compiled model saved to: {COMPILED_PATH}") diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/examples/generate_qwen3_omni.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/examples/generate_qwen3_omni.py new file mode 100644 index 00000000..62c837c1 --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/examples/generate_qwen3_omni.py @@ -0,0 +1,364 @@ +#!/usr/bin/env 
python3 +""" +Generate text from Qwen3-Omni-30B-A3B-Instruct on Neuron. + +Supports three modes: + --mode text : Text-only generation (vision + MoE text on Neuron) + --mode image : Image + text generation (vision + MoE text on Neuron) + --mode audio : Audio + text generation (audio + vision + MoE text on Neuron) + +All neural network components run on Neuron. + +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + cd /home/ubuntu/whn-ndi + + # Set environment + export NEURON_RT_VISIBLE_CORES=0-7 # or 0-15 for larger TP + export QWEN3_OMNI_MODEL_PATH=/path/to/Qwen3-Omni-30B-A3B-Instruct + + # Text mode + python contrib/models/Qwen3-Omni-30B-A3B-Instruct/examples/generate_qwen3_omni.py \\ + --mode text --prompt "What is quantum computing?" + + # Image mode + python contrib/models/Qwen3-Omni-30B-A3B-Instruct/examples/generate_qwen3_omni.py \\ + --mode image --image /path/to/image.jpg --prompt "Describe this image." + + # Audio mode + python contrib/models/Qwen3-Omni-30B-A3B-Instruct/examples/generate_qwen3_omni.py \\ + --mode audio --audio /path/to/audio.wav --prompt "Transcribe the speech." 
+""" + +import os +import sys +import argparse +import time +from pathlib import Path + +# Add src to path +_SRC = Path(__file__).resolve().parent.parent / "src" +if str(_SRC) not in sys.path: + sys.path.insert(0, str(_SRC)) +import _upstream_compat # noqa: F401 + +import gc +import torch +from _model_path import resolve_model_path + + +def parse_args(): + parser = argparse.ArgumentParser(description="Generate with Qwen3-Omni on Neuron") + parser.add_argument("--mode", choices=["text", "image", "audio"], default="text") + parser.add_argument("--prompt", type=str, default="What is quantum computing?") + parser.add_argument("--image", type=str, default=None, help="Path to image file") + parser.add_argument("--audio", type=str, default=None, help="Path to audio file") + parser.add_argument("--model-path", type=str, default=None) + parser.add_argument("--compiled-path", type=str, default="/tmp/qwen3_omni_compiled") + parser.add_argument("--tp-degree", type=int, default=8) + parser.add_argument("--max-new-tokens", type=int, default=256) + parser.add_argument("--seq-len", type=int, default=4096) + return parser.parse_args() + + +def build_model(model_path, compiled_path, tp_degree, seq_len): + from neuronx_distributed_inference.models.config import ( + MoENeuronConfig, + NeuronConfig, + OnDeviceSamplingConfig, + ) + from neuronx_distributed_inference.utils.hf_adapter import ( + load_pretrained_config, + HuggingFaceGenerationAdapter, + ) + from transformers import AutoProcessor + + from modeling_qwen3_omni import ( + Qwen3OmniMoEInferenceConfig, + NeuronQwen3OmniForCausalLM, + ) + + text_buckets = [256, 512, 1024, 2048, seq_len] + vision_seq_len = 1012 + vision_buckets = [vision_seq_len] + + text_neuron_config = MoENeuronConfig( + batch_size=1, + seq_len=seq_len, + max_context_length=seq_len, + ctx_batch_size=1, + tp_degree=tp_degree, + torch_dtype=torch.bfloat16, + fused_qkv=False, + sequence_parallel_enabled=False, + flash_decoding_enabled=False, + 
qkv_kernel_enabled=False, + qkv_nki_kernel_enabled=False, + attn_kernel_enabled=False, + enable_bucketing=True, + context_encoding_buckets=text_buckets, + token_generation_buckets=text_buckets, + on_device_sampling_config=OnDeviceSamplingConfig(do_sample=False, top_k=1), + blockwise_matmul_config={"use_torch_block_wise": True}, + ) + vision_neuron_config = NeuronConfig( + batch_size=1, + seq_len=vision_seq_len, + tp_degree=tp_degree, + torch_dtype=torch.bfloat16, + enable_bucketing=True, + buckets=vision_buckets, + fused_qkv=False, + qkv_kernel_enabled=False, + attn_kernel_enabled=False, + ) + + config = Qwen3OmniMoEInferenceConfig( + text_neuron_config=text_neuron_config, + vision_neuron_config=vision_neuron_config, + load_config=load_pretrained_config(model_path), + ) + + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + + compiled_dir = os.path.join(compiled_path, f"multimodal_tp{tp_degree}") + + print("Creating model...") + t0 = time.time() + model = NeuronQwen3OmniForCausalLM(model_path=model_path, config=config) + print(f" Created in {time.time()-t0:.1f}s") + + if not os.path.exists(os.path.join(compiled_dir, "neuron_config.json")): + print("Compiling text + vision...") + t0 = time.time() + model.compile(compiled_dir) + processor.save_pretrained(compiled_dir) + print(f" Compiled in {time.time()-t0:.1f}s") + else: + print(" Compiled artifacts found") + + print("Loading compiled model...") + t0 = time.time() + model.load(compiled_dir) + processor = AutoProcessor.from_pretrained(compiled_dir) + print(f" Loaded in {time.time()-t0:.1f}s") + + adapter = HuggingFaceGenerationAdapter(model) + return adapter, processor, config + + +def build_audio_encoder(model, model_path, compiled_path, tp_degree): + from modeling_qwen3_omni_audio import NeuronQwen3OmniAudioEncoder + + print("\nBuilding audio encoder...") + from transformers import AutoModelForCausalLM + + print(" Loading HF model for audio weights...") + t0 = time.time() + hf_model = 
AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True, + torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, + ) + full_sd = hf_model.state_dict() + print(f" Loaded in {time.time()-t0:.1f}s") + + audio_sd = NeuronQwen3OmniAudioEncoder.convert_hf_to_neuron_state_dict( + full_sd, dtype=torch.bfloat16 + ) + del hf_model, full_sd + gc.collect() + + model.neuron_model.enable_audio_encoder(audio_sd) + + compiled_audio_dir = os.path.join(compiled_path, f"audio_encoder_tp{tp_degree}") + if not os.path.exists(os.path.join(compiled_audio_dir, "neuron_config.json")): + print(" Compiling audio encoder transformer...") + t0 = time.time() + model.neuron_model.compile_audio_encoder(compiled_audio_dir) + print(f" Compiled in {time.time()-t0:.1f}s") + else: + print(" Audio encoder compiled artifacts found") + + print(" Loading audio encoder...") + t0 = time.time() + model.neuron_model.load_audio_encoder(compiled_audio_dir) + print(f" Loaded in {time.time()-t0:.1f}s") + + del audio_sd + gc.collect() + + +def generate_text(adapter, processor, prompt, max_new_tokens): + from transformers import GenerationConfig + + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + {"role": "user", "content": [{"type": "text", "text": prompt}]}, + ] + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = processor(text=[text], return_tensors="pt", padding=True) + + gen_config = GenerationConfig( + do_sample=False, + eos_token_id=[151645], + pad_token_id=151645, + ) + + t0 = time.time() + output_ids = adapter.generate( + input_ids=inputs.input_ids, + attention_mask=inputs.attention_mask, + generation_config=gen_config, + max_new_tokens=max_new_tokens, + ) + gen_time = time.time() - t0 + + prompt_len = inputs.input_ids.shape[1] + new_tokens = output_ids[:, prompt_len:] + response = processor.batch_decode( + new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False 
+ )[0].strip() + + return response, gen_time, new_tokens.shape[1] + + +def generate_with_image(adapter, processor, prompt, image_path, max_new_tokens): + from transformers import GenerationConfig + + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + {"role": "user", "content": [ + {"type": "image", "image": image_path}, + {"type": "text", "text": prompt}, + ]}, + ] + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = processor(text=[text], images=[image_path], return_tensors="pt", padding=True) + + gen_config = GenerationConfig( + do_sample=False, + eos_token_id=[151645], + pad_token_id=151645, + ) + + generate_kwargs = { + "input_ids": inputs.input_ids, + "attention_mask": inputs.attention_mask, + "generation_config": gen_config, + "max_new_tokens": max_new_tokens, + } + if hasattr(inputs, "pixel_values") and inputs.pixel_values is not None: + generate_kwargs["pixel_values"] = inputs.pixel_values.to(torch.bfloat16) + if hasattr(inputs, "image_grid_thw") and inputs.image_grid_thw is not None: + generate_kwargs["image_grid_thw"] = inputs.image_grid_thw + + t0 = time.time() + output_ids = adapter.generate(**generate_kwargs) + gen_time = time.time() - t0 + + prompt_len = inputs.input_ids.shape[1] + new_tokens = output_ids[:, prompt_len:] + response = processor.batch_decode( + new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0].strip() + + return response, gen_time, new_tokens.shape[1] + + +def generate_with_audio(adapter, processor, prompt, audio_path, max_new_tokens): + from transformers import GenerationConfig + + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + {"role": "user", "content": [ + {"type": "audio", "audio": audio_path}, + {"type": "text", "text": prompt}, + ]}, + ] + + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + 
inputs = processor(text=[text], audio=[audio_path], return_tensors="pt", padding=True) + + gen_config = GenerationConfig( + do_sample=False, + eos_token_id=[151645], + pad_token_id=151645, + ) + + generate_kwargs = { + "input_ids": inputs.input_ids, + "attention_mask": inputs.attention_mask, + "generation_config": gen_config, + "max_new_tokens": max_new_tokens, + } + if hasattr(inputs, "input_features") and inputs.input_features is not None: + generate_kwargs["input_features"] = inputs.input_features.to(torch.bfloat16) + if hasattr(inputs, "feature_attention_mask") and inputs.feature_attention_mask is not None: + generate_kwargs["feature_attention_mask"] = inputs.feature_attention_mask + + t0 = time.time() + output_ids = adapter.generate(**generate_kwargs) + gen_time = time.time() - t0 + + prompt_len = inputs.input_ids.shape[1] + new_tokens = output_ids[:, prompt_len:] + response = processor.batch_decode( + new_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0].strip() + + return response, gen_time, new_tokens.shape[1] + + +def main(): + args = parse_args() + model_path = args.model_path or resolve_model_path() + compiled_path = args.compiled_path + + print("=" * 60) + print(f"Qwen3-Omni-30B-A3B-Instruct on Neuron (mode={args.mode})") + print(f" Model: {model_path}") + print(f" TP: {args.tp_degree}") + print(f" Seq len: {args.seq_len}") + print("=" * 60) + + adapter, processor, config = build_model( + model_path, compiled_path, args.tp_degree, args.seq_len + ) + + if args.mode == "audio": + build_audio_encoder( + adapter, model_path, compiled_path, args.tp_degree + ) + + print("\n" + "=" * 60) + print("Generating...") + print("=" * 60) + + if args.mode == "text": + response, gen_time, n_tokens = generate_text( + adapter, processor, args.prompt, args.max_new_tokens + ) + elif args.mode == "image": + if args.image is None: + print("ERROR: --image required for image mode") + sys.exit(1) + response, gen_time, n_tokens = generate_with_image( + 
adapter, processor, args.prompt, args.image, args.max_new_tokens + ) + elif args.mode == "audio": + if args.audio is None: + print("ERROR: --audio required for audio mode") + sys.exit(1) + response, gen_time, n_tokens = generate_with_audio( + adapter, processor, args.prompt, args.audio, args.max_new_tokens + ) + + print(f"\nPrompt: {args.prompt}") + print(f"Response: {response}") + print(f"\n Tokens: {n_tokens}") + print(f" Time: {gen_time:.2f}s") + print(f" Speed: {n_tokens/gen_time:.1f} tok/s") + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/run_demo.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/run_demo.py new file mode 100644 index 00000000..05fd9d4b --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/run_demo.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +""" +Quick demo: Run Qwen3-Omni-30B-A3B-Instruct thinker text model on Neuron. + +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + NEURON_RT_VISIBLE_CORES=0-31 python run_demo.py \ + --model-path /home/ubuntu/models/Qwen3-Omni-30B-A3B-Instruct \ + --compiled-model-path /home/ubuntu/traced_model/Qwen3-Omni-30B-A3B-Instruct \ + --tp-degree 32 \ + --prompt "Hello, who are you?" 
+""" +import argparse +import sys +import time +from pathlib import Path + +import torch +from transformers import AutoTokenizer + +from neuronx_distributed_inference.models.config import MoENeuronConfig +from neuronx_distributed_inference.utils.accuracy import get_generate_outputs + +sys.path.insert(0, str(Path(__file__).parent / "src")) +from modeling_qwen3_omni_moe import ( + NeuronQwen3OmniMoeForCausalLM, + Qwen3OmniMoeInferenceConfig, + load_qwen3_omni_thinker_text_config, +) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", required=True) + parser.add_argument("--compiled-model-path", required=True) + parser.add_argument("--tp-degree", type=int, default=32) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--seq-len", type=int, default=512) + parser.add_argument("--max-context-length", type=int, default=256) + parser.add_argument("--max-new-tokens", type=int, default=128) + parser.add_argument("--top-k", type=int, default=1) + parser.add_argument("--prompt", type=str, default="Hello, who are you?") + args = parser.parse_args() + + print(f"Model path: {args.model_path}") + print(f"TP degree: {args.tp_degree}") + + neuron_config = MoENeuronConfig( + tp_degree=args.tp_degree, + batch_size=args.batch_size, + seq_len=args.seq_len, + max_context_length=args.max_context_length, + torch_dtype=torch.bfloat16, + on_device_sampling_config={"top_k": args.top_k, "do_sample": False}, + ) + + config = Qwen3OmniMoeInferenceConfig( + neuron_config, + load_config=load_qwen3_omni_thinker_text_config(args.model_path), + ) + + model = NeuronQwen3OmniMoeForCausalLM(args.model_path, config) + + compiled_path = Path(args.compiled_model_path) + if not compiled_path.exists(): + print("Compiling model (this may take several minutes)...") + t0 = time.perf_counter() + model.compile(args.compiled_model_path) + print(f"Compilation took {time.perf_counter() - t0:.1f}s") + + print("Loading model to Neuron...") + t0 = 
time.perf_counter() + model.load(args.compiled_model_path) + print(f"Model loaded in {time.perf_counter() - t0:.1f}s") + + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + prompts = [args.prompt] * args.batch_size + + print(f"\nPrompt: {args.prompt}") + print("Generating...") + t0 = time.perf_counter() + _, output_tokens = get_generate_outputs( + model, + prompts, + tokenizer, + is_hf=False, + do_sample=False, + max_length=model.neuron_config.max_length, + ) + elapsed = time.perf_counter() - t0 + + for i, text in enumerate(output_tokens): + print(f"\nOutput[{i}]: {text}") + print(f"\nGeneration took {elapsed:.2f}s") + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/run_multimodal_demo.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/run_multimodal_demo.py new file mode 100644 index 00000000..c4e9be94 --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/run_multimodal_demo.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +""" +Multimodal demo: Run Qwen3-Omni-30B-A3B-Instruct with vision+text on Neuron. + +Usage (vision+text): + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + NEURON_RT_VISIBLE_CORES=0-31 python run_multimodal_demo.py \ + --model-path /home/ubuntu/models/Qwen3-Omni-30B-A3B-Instruct \ + --compiled-model-path /home/ubuntu/traced_model/Qwen3-Omni-multimodal \ + --image-path /path/to/image.jpg \ + --prompt "Describe this image." 
+ +Usage (text-only): + NEURON_RT_VISIBLE_CORES=0-31 python run_multimodal_demo.py \ + --model-path /home/ubuntu/models/Qwen3-Omni-30B-A3B-Instruct \ + --compiled-model-path /home/ubuntu/traced_model/Qwen3-Omni-multimodal \ + --prompt "The capital of France is" +""" +import argparse +import sys +import time +from pathlib import Path + +import torch +from transformers import AutoProcessor, AutoTokenizer + +sys.path.insert(0, str(Path(__file__).parent / "src")) +from modeling_qwen3_omni import ( + NeuronQwen3OmniForCausalLM, + Qwen3OmniInferenceConfig, + Qwen3OmniMoeNeuronConfig, + load_qwen3_omni_multimodal_config, +) +from neuronx_distributed_inference.models.config import MoENeuronConfig, NeuronConfig + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", required=True) + parser.add_argument("--compiled-model-path", required=True) + parser.add_argument("--tp-degree", type=int, default=16) + parser.add_argument("--vision-tp-degree", type=int, default=16) + parser.add_argument("--batch-size", type=int, default=1) + parser.add_argument("--seq-len", type=int, default=4096) + parser.add_argument("--max-context-length", type=int, default=2048) + parser.add_argument("--vision-seq-len", type=int, default=4096, + help="Max vision sequence length (patches)") + parser.add_argument("--max-new-tokens", type=int, default=256) + parser.add_argument("--top-k", type=int, default=1) + parser.add_argument("--prompt", type=str, default="Describe this image in detail.") + parser.add_argument("--image-path", type=str, default=None, + help="Path to input image (optional, text-only if not provided)") + args = parser.parse_args() + + print(f"Model path: {args.model_path}") + print(f"Text TP degree: {args.tp_degree}") + print(f"Vision TP degree: {args.vision_tp_degree}") + print(f"Mode: {'vision+text' if args.image_path else 'text-only'}") + + # Text model neuron config (MoE) + text_neuron_config = MoENeuronConfig( + tp_degree=args.tp_degree, + 
batch_size=args.batch_size, + seq_len=args.seq_len, + max_context_length=args.max_context_length, + torch_dtype=torch.bfloat16, + on_device_sampling_config={"top_k": args.top_k, "do_sample": False}, + blockwise_matmul_config={"use_torch_block_wise": True}, + ) + + # Vision model neuron config + vision_neuron_config = NeuronConfig( + tp_degree=args.vision_tp_degree, + batch_size=1, + seq_len=args.vision_seq_len, + torch_dtype=torch.bfloat16, + ) + + config = Qwen3OmniInferenceConfig( + text_neuron_config=text_neuron_config, + vision_neuron_config=vision_neuron_config, + load_config=load_qwen3_omni_multimodal_config(args.model_path), + ) + + model = NeuronQwen3OmniForCausalLM(args.model_path, config) + + compiled_path = Path(args.compiled_model_path) + if not compiled_path.exists(): + print("Compiling model (this may take a while)...") + t0 = time.perf_counter() + model.compile(args.compiled_model_path) + print(f"Compilation took {time.perf_counter() - t0:.1f}s") + + print("Loading model to Neuron...") + t0 = time.perf_counter() + model.load(args.compiled_model_path) + print(f"Model loaded in {time.perf_counter() - t0:.1f}s") + + # Prepare inputs — limit max_pixels so raw patches fit vision_seq_len + patch_size = 16 + max_pixels = args.vision_seq_len * (patch_size ** 2) + processor = AutoProcessor.from_pretrained( + args.model_path, trust_remote_code=True, + max_pixels=max_pixels, use_fast=False, + ) + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + images = None + if args.image_path: + from PIL import Image + img = Image.open(args.image_path).convert("RGB") + images = [[img]] + + input_ids, attention_mask, vision_inputs = NeuronQwen3OmniForCausalLM.prepare_input_args( + prompts=[args.prompt], + images=images, + processor=processor, + config=config, + ) + + print(f"\nPrompt: {args.prompt}") + print(f"Input shape: {input_ids.shape}") + if vision_inputs: 
+ print(f"Vision inputs: {list(vision_inputs.keys())}") + + print("Generating...") + t0 = time.perf_counter() + + # Use the HuggingFaceGenerationAdapter for generation + from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter + generation_model = HuggingFaceGenerationAdapter(model) + + outputs = generation_model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_new_tokens=args.max_new_tokens, + do_sample=False, + **(vision_inputs or {}), # text-only runs may yield no vision inputs + ) + + elapsed = time.perf_counter() - t0 + + for i in range(outputs.shape[0]): + text = tokenizer.decode(outputs[i], skip_special_tokens=True) + print(f"\nOutput[{i}]: {text}") + print(f"\nGeneration took {elapsed:.2f}s") + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/__init__.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/_model_path.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/_model_path.py new file mode 100644 index 00000000..4937af60 --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/_model_path.py @@ -0,0 +1,7 @@ +import os + +_DEFAULT_MODEL_ID = "Qwen/Qwen3-Omni-30B-A3B-Instruct" + + +def resolve_model_path() -> str: + return os.environ.get("QWEN3_OMNI_MODEL_PATH", _DEFAULT_MODEL_ID) diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/_upstream_compat.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/_upstream_compat.py new file mode 100644 index 00000000..2feea6a2 --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/_upstream_compat.py @@ -0,0 +1,197 @@ +# Upstream compatibility shims for Qwen3-Omni. +# Reuses the same patches as Qwen2.5-Omni since the upstream issues are shared. 
+ +import logging +import inspect + +from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter + +logger = logging.getLogger(__name__) + + +def _patch_prepare_inputs_for_generation(): + """Fix upstream NameError: tensor_capture_hook not defined.""" + src = inspect.getsource(HuggingFaceGenerationAdapter.prepare_inputs_for_generation) + references_hook = '"tensor_capture_hook": tensor_capture_hook' in src + already_extracted = 'tensor_capture_hook = kwargs.get("tensor_capture_hook"' in src + if already_extracted or not references_hook: + return + + original = HuggingFaceGenerationAdapter.prepare_inputs_for_generation + + def patched(self, input_ids, *args, **kwargs): + import torch + self.prev_kv_cache_populated = self.neuron_model.kv_cache_populated + if self.neuron_model.kv_cache_populated: + input_ids = input_ids[:, -1:] + + past_key_values = kwargs.pop("past_key_values", None) + attention_mask = kwargs.pop("attention_mask", None) + inputs_embeds = kwargs.pop("inputs_embeds", None) + sampling_params = kwargs.pop("sampling_params", None) + adapter_ids = kwargs.pop("adapter_ids", None) + divergence_idx = kwargs.pop("divergence_idx", None) + + accepted_indices = kwargs.get("accepted_indices", None) + current_length = kwargs.get("current_length", None) + medusa_mask = kwargs.get("medusa_mask", None) + scatter_index = kwargs.get("scatter_index", None) + position_ids = kwargs.get("position_ids", None) + input_capture_hook = kwargs.get("input_capture_hook", None) + tensor_capture_hook = kwargs.get("tensor_capture_hook", None) + + if attention_mask is not None and position_ids is None: + position_ids = attention_mask.long().cumsum(-1) - 1 + if self.input_start_offsets: + if len(self.input_start_offsets) > 1: + position_ids += torch.tensor( + self.input_start_offsets, + dtype=position_ids.dtype, + device=position_ids.device, + )[:, None] + else: + position_ids += self.input_start_offsets[0] + for i, offset in 
enumerate(self.input_start_offsets): + position_ids[i, 0:offset] = torch.arange(offset) + else: + position_ids.masked_fill_(attention_mask == 0, 1) + + if self.neuron_model.kv_cache_populated: + position_ids = torch.amax(position_ids, 1, keepdim=True) + position_ids = position_ids + 1 + + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache", False), + "attention_mask": attention_mask, + "medusa_args": (accepted_indices, current_length, medusa_mask, scatter_index), + "sampling_params": sampling_params, + "input_capture_hook": input_capture_hook, + "tensor_capture_hook": tensor_capture_hook, + "adapter_ids": adapter_ids, + } + ) + + tf_args = [] + if self.neuron_config.tensor_replacement_config: + from neuronx_distributed_inference.utils.tensor_replacement.registry import ( + TensorReplacementRegister, + ) + reg = TensorReplacementRegister.get_instance() + tf, masks = reg.step_args( + self.generation_step, + divergence_idx=True if divergence_idx else False, + ) + tf_args = tf + masks + + if tf_args: + model_inputs["tf_args"] = tf_args + + additional_kwargs = self.neuron_model.get_required_kwargs() + for arg in additional_kwargs: + model_inputs.update({arg: kwargs.get(arg, None)}) + + return model_inputs + + HuggingFaceGenerationAdapter.prepare_inputs_for_generation = patched + logger.info( + "Qwen3-Omni contrib: patched HuggingFaceGenerationAdapter." + "prepare_inputs_for_generation to extract tensor_capture_hook from kwargs." + ) + + +_patch_prepare_inputs_for_generation() + + +def _patch_vision_wrapper_load_state_dict(): + """Remap thinker.visual.* -> model.visual.* in safetensors loading. + + Qwen3-VL upstream expects model.visual.pos_embed.weight but + Qwen3-Omni safetensors store it as thinker.visual.pos_embed.weight. 
+ """ + import neuronx_distributed_inference.models.qwen3_vl.modeling_qwen3_vl_vision as vmod + + _original_load = vmod.load_state_dict + + def _remapped_load(state_dict_dir): + sd = _original_load(state_dict_dir) + if "model.visual.pos_embed.weight" not in sd and "thinker.visual.pos_embed.weight" in sd: + remapped = {} + for k, v in sd.items(): + if k.startswith("thinker.visual."): + remapped["model.visual." + k[len("thinker.visual."):]] = v + else: + remapped[k] = v + return remapped + return sd + + vmod.load_state_dict = _remapped_load + logger.info( + "Qwen3-Omni contrib: patched vision wrapper load_state_dict " + "to remap thinker.visual.* -> model.visual.*" + ) + + +_patch_vision_wrapper_load_state_dict() + + +def _patch_tensor_registry_clear(): + """Fix upstream NxD bug: TensorRegistry.clear() wipes modules_to_capture. + + Inside ``NeuronBaseModel._get_captured_tensors`` (called once per HLO + trace, once per bucket), the final line is ``registry.clear()``. Upstream + ``clear()`` replaces ``model_info`` with a fresh + ``CapturedModelInfo([], 10, False)`` — resetting ``modules_to_capture``. + The forward hooks installed by ``enable_tensor_capture`` still fire on + the next bucket's trace, but ``register_tensor`` now falls through to + the "manual" branch (module name no longer in ``modules_to_capture``). + + Net effect: only the FIRST bucket to trace captures a real tensor; every + subsequent bucket emits the empty fallback ``torch.zeros(1, dtype=bf16)``, + making layer-hidden-state capture (``layers.23`` for the talker pipeline) + unusable for any prompt that doesn't fit the first bucket. + + Fix: preserve the configured modules/flags through clear() by stashing + them on the singleton. 
+ """ + from neuronx_distributed.utils.tensor_capture.registry import ( + CapturedModelInfo, TensorRegistry, + ) + + if getattr(TensorRegistry, "_nxdi_clear_patched", False): + return + + _orig_configure = TensorRegistry.configure + + def configure(self, enabled=False, modules=None, max_tensors=None, capture_inputs=False): + cfg_modules = list(modules or []) + if cfg_modules: + self._nxdi_last_modules = cfg_modules + self._nxdi_last_max_tensors = max_tensors + self._nxdi_last_capture_inputs = capture_inputs + _orig_configure(self, enabled=enabled, modules=modules, + max_tensors=max_tensors, capture_inputs=capture_inputs) + + def clear(self): + modules = getattr(self, "_nxdi_last_modules", []) + max_tensors = getattr(self, "_nxdi_last_max_tensors", 10) + capture_inputs = getattr(self, "_nxdi_last_capture_inputs", False) + self.model_info = CapturedModelInfo(modules, max_tensors, capture_inputs) + + TensorRegistry.configure = configure + TensorRegistry.clear = clear + TensorRegistry._nxdi_clear_patched = True + logger.info( + "Qwen3-Omni contrib: patched TensorRegistry.clear to preserve " + "modules_to_capture across bucket traces." + ) + + +_patch_tensor_registry_clear() diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni.py new file mode 100644 index 00000000..71d9efe5 --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni.py @@ -0,0 +1,1045 @@ +"""Qwen3-Omni-30B-A3B-Instruct multimodal model for NxD Inference. + +Combines: + - Qwen3-VL vision encoder (reused directly) + - Qwen3-Omni audio encoder (Conv2d + transformer on Neuron) + - MoE text decoder (MRoPE attention + sparse MoE FFN) + +All neural network components run on Neuron. 
+""" + +import copy +import gc +import logging +import math +import os +from types import SimpleNamespace +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union + +import torch +from transformers.modeling_outputs import CausalLMOutputWithPast + +from neuronx_distributed_inference.models.config import ( + InferenceConfig, + MoENeuronConfig, + NeuronConfig, + SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP, + MOE_TKG_MK_INTERMEDIATE_PER_TP, +) +from neuronx_distributed_inference.models.image_to_text_model_base import ( + ImageToTextInferenceConfig, + NeuronBaseForImageToText, +) +from neuronx_distributed_inference.models.llama4.utils.encoder_utils import ( + generate_positions_from_mask, + pad_positions, +) +from neuronx_distributed_inference.models.model_wrapper import VISION_ENCODER_MODEL_TAG +from neuronx_distributed_inference.models.qwen3_vl.modeling_qwen3_vl_text import ( + NeuronQwen3VLTextForCausalLM, +) +from neuronx_distributed_inference.models.qwen3_vl.modeling_qwen3_vl_vision import ( + NeuronQwen3VLForImageEncoding, + NeuronQwen3VLVisionModel, + NeuronQwen3VLVisionModelWrapper, +) +from neuronx_distributed_inference.modules.autobucketing import generate_buckets + +from modeling_qwen3_omni_text import ( + NeuronQwen3OmniTextModel, + NeuronQwen3OmniTextModelWrapper, + convert_qwen3_omni_text_hf_to_neuron, +) +from modeling_qwen3_omni_audio import ( + AudioEncoderInferenceConfig, + NeuronQwen3OmniAudioEncoder, + NeuronQwen3OmniForAudioEncoding, +) + +logger = logging.getLogger("Neuron") + + +class Qwen3OmniMoEInferenceConfig(ImageToTextInferenceConfig): + """Inference config for Qwen3-Omni multimodal model. + + Handles the nested config structure: + Qwen3OmniMoeConfig -> thinker_config -> text_config, vision_config, audio_config + + Combines ImageToTextInferenceConfig (vision + text) with MoE settings + from Qwen3MoeInferenceConfig. 
+ """ + + @staticmethod + def _extract_thinker_config(obj): + thinker = getattr(obj, "thinker_config", None) + if thinker is None: + return + if hasattr(thinker, "__dict__") and not isinstance(thinker, dict): + thinker = vars(thinker) + if not isinstance(thinker, dict): + return + + def _to_dict(x): + if hasattr(x, "__dict__") and not isinstance(x, dict): + return vars(x) + return x + + if not hasattr(obj, "text_config") and "text_config" in thinker: + obj.text_config = _to_dict(thinker["text_config"]) + if not hasattr(obj, "vision_config") and "vision_config" in thinker: + obj.vision_config = _to_dict(thinker["vision_config"]) + if not hasattr(obj, "audio_config") and "audio_config" in thinker: + obj.audio_config = _to_dict(thinker["audio_config"]) + for token_key in [ + "audio_token_id", "image_token_id", "video_token_id", + "audio_start_token_id", "vision_start_token_id", "vision_end_token_id", + "vision_token_id", "pad_token_id", "position_id_per_seconds", + ]: + if token_key in thinker and not hasattr(obj, token_key): + setattr(obj, token_key, thinker[token_key]) + + def __init__( + self, + text_neuron_config, + vision_neuron_config, + fused_spec_config=None, + load_config=None, + metadata: Optional[Dict] = None, + **kwargs, + ): + # Extract sub-configs from thinker_config if present + thinker = kwargs.get("thinker_config", None) + if thinker is not None: + if hasattr(thinker, "__dict__") and not isinstance(thinker, dict): + thinker = vars(thinker) + if isinstance(thinker, dict): + if "text_config" not in kwargs and "text_config" in thinker: + tc = thinker["text_config"] + kwargs["text_config"] = ( + vars(tc) if hasattr(tc, "__dict__") and not isinstance(tc, dict) else tc + ) + if "vision_config" not in kwargs and "vision_config" in thinker: + vc = thinker["vision_config"] + kwargs["vision_config"] = ( + vars(vc) if hasattr(vc, "__dict__") and not isinstance(vc, dict) else vc + ) + if "audio_config" not in kwargs and "audio_config" in thinker: + ac = 
thinker["audio_config"] + kwargs["audio_config"] = ( + vars(ac) if hasattr(ac, "__dict__") and not isinstance(ac, dict) else ac + ) + for token_key in [ + "audio_token_id", "image_token_id", "video_token_id", + "audio_start_token_id", "vision_start_token_id", "vision_end_token_id", + "vision_token_id", "pad_token_id", + ]: + if token_key in thinker and token_key not in kwargs: + kwargs[token_key] = thinker[token_key] + + # Wrap load_config to extract thinker sub-configs + original_load_config = load_config + if original_load_config is not None: + extract = self._extract_thinker_config + def _wrapped_load_config(self_inner): + original_load_config(self_inner) + extract(self_inner) + load_config = _wrapped_load_config + + super().__init__( + text_neuron_config=text_neuron_config, + vision_neuron_config=vision_neuron_config, + fused_spec_config=fused_spec_config, + load_config=load_config, + metadata=metadata, + **kwargs, + ) + + self._add_moe_config() + self._add_special_config() + self._validate_supported_configs() + + def _add_moe_config(self): + """Apply MoE-specific settings to text_config (from Qwen3MoeInferenceConfig).""" + tc = self.text_config + + # num_local_experts alias for initialize_moe_module + if hasattr(tc, "num_experts") and not hasattr(tc, "num_local_experts"): + tc.num_local_experts = tc.num_experts + tc.n_shared_experts = 0 + + # GLU MLP required for MoE + tc.neuron_config.glu_mlp = True + + # Router config + tc.neuron_config.router_config.dtype = torch.float32 + tc.neuron_config.router_config.act_fn = "softmax" + tc.neuron_config.disable_numeric_cc_token = True + if hasattr(tc, "norm_topk_prob") and tc.norm_topk_prob: + tc.neuron_config.normalize_top_k_affinities = True + + # Set intermediate_size to moe_intermediate_size for MoE layers + if hasattr(tc, "moe_intermediate_size"): + tc.intermediate_size = tc.moe_intermediate_size + + # Intermediate size padding for MoE + moe_tp_degree = tc.neuron_config.moe_tp_degree + if hasattr(tc, 
"moe_intermediate_size"): + I_TP = tc.moe_intermediate_size // moe_tp_degree + if getattr(tc.neuron_config.blockwise_matmul_config, "use_shard_on_intermediate_dynamic_while", False): + if I_TP % SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP != 0: + padded = ( + math.ceil(I_TP / SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP) + * SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP + * moe_tp_degree + ) + tc.moe_intermediate_pad_size = max(padded - tc.moe_intermediate_size, 0) + tc.moe_intermediate_size = padded + + # MoE fused NKI kernel + I_TP = tc.moe_intermediate_size // moe_tp_degree + if ( + getattr(tc.neuron_config, "moe_fused_nki_kernel_enabled", False) + and I_TP % MOE_TKG_MK_INTERMEDIATE_PER_TP == 0 + ): + tc.moe_fused_nki_kernel_enabled = True + + def _add_special_config(self): + """Copy vision/text attributes and apply validation.""" + # Copy deepstack_visual_indexes from vision to text config + if hasattr(self.vision_config, "deepstack_visual_indexes"): + self.text_config.deepstack_visual_indexes = copy.deepcopy( + self.vision_config.deepstack_visual_indexes + ) + + # MRoPE section from text_config + if hasattr(self.text_config, "rope_scaling"): + rs = self.text_config.rope_scaling + if isinstance(rs, dict) and "mrope_section" in rs: + self.text_config.mrope_section = rs["mrope_section"] + + # Vision config derived attributes + if hasattr(self.vision_config, "hidden_size") and hasattr(self.vision_config, "num_heads"): + self.vision_config.head_dim = ( + self.vision_config.hidden_size // self.vision_config.num_heads + ) + self.vision_config.num_cores_per_group = 1 + + # Vision encoder uses fused QKV (HF stores qkv.weight, conversion maps to Wqkv) + self.vision_config.neuron_config.fused_qkv = True + + # Copy token IDs to top-level and text_config + for attr in [ + "image_token_id", "video_token_id", "audio_token_id", + "vision_start_token_id", "vision_end_token_id", + ]: + val = getattr(self, attr, None) + if val is not None: + setattr(self.text_config, attr, val) + + # Pad token + 
if hasattr(self, "pad_token_id"): + self.text_config.pad_token_id = self.pad_token_id + + # Qwen3 MoE text: no QKV bias, no output bias + self.text_config.attention_bias = False + self.text_config.qkv_bias = False + self.text_config.o_bias = False + + # Store audio_config as SimpleNamespace + if hasattr(self, "audio_config") and isinstance(self.audio_config, dict): + self.audio_config = SimpleNamespace(**self.audio_config) + + # Vision bucketing + if not self.vision_config.neuron_config.enable_bucketing: + VISION_SEQ_LENGTH = self.vision_config.neuron_config.seq_len + self.vision_config.neuron_config.enable_bucketing = True + self.vision_config.neuron_config.buckets = generate_buckets( + VISION_SEQ_LENGTH, VISION_SEQ_LENGTH + ) + + if self.text_config.neuron_config.seq_len > 10240: + os.environ["NEURON_RT_DBG_INTRA_RDH_CHANNEL_BUFFER_SIZE"] = f"{140 * 1024 * 1024}" + + def _validate_supported_configs(self): + unsupported_text = [ + "is_block_kv_layout", "is_prefix_caching", "is_chunked_prefill", + "is_medusa", "enable_fused_speculation", + ] + for cfg_name in unsupported_text: + if getattr(self.text_config.neuron_config, cfg_name, False): + setattr(self.text_config.neuron_config, cfg_name, False) + logger.warning(f"Qwen3-Omni text model does not support '{cfg_name}'. 
Disabled.") + + if self.text_config.neuron_config.attention_dp_degree > 1: + raise ValueError("Qwen3-Omni does not support attention data parallel") + if self.text_config.neuron_config.cp_degree > 1: + raise ValueError("Qwen3-Omni does not support context parallel") + + unsupported_vision = [ + "sequence_parallel_enabled", "flash_decoding_enabled", + "mlp_kernel_enabled", + "attn_block_tkg_nki_kernel_cache_update", + "attn_block_tkg_nki_kernel_enabled", + "qkv_kernel_enabled", "attn_kernel_enabled", + ] + for cfg_name in unsupported_vision: + if getattr(self.vision_config.neuron_config, cfg_name, False) is not False: + setattr(self.vision_config.neuron_config, cfg_name, False) + + def get_required_attributes(self) -> List[str]: + return [ + "text_config", + "vision_config", + "text_config.hidden_size", + "text_config.num_attention_heads", + "text_config.num_hidden_layers", + "text_config.num_key_value_heads", + "text_config.vocab_size", + "text_config.rms_norm_eps", + "text_config.rope_theta", + "text_config.moe_intermediate_size", + "text_config.num_experts", + "text_config.num_experts_per_tok", + "vision_config.deepstack_visual_indexes", + "vision_config.depth", + "vision_config.hidden_size", + "vision_config.num_heads", + "vision_config.patch_size", + "vision_config.spatial_merge_size", + ] + + @classmethod + def get_neuron_config_cls(cls) -> Type[NeuronConfig]: + return MoENeuronConfig + + +class NeuronQwen3OmniForCausalLM(NeuronBaseForImageToText): + """Qwen3-Omni multimodal model (vision + audio + MoE text) on Neuron. 
+ + - Vision encoder: Qwen3-VL ViT (reused directly) + - Audio encoder: Conv2d + Neuron transformer + proj1/GELU/proj2 + - Text decoder: MRoPE attention + MoE FFN with deepstack + """ + + text_model_cls = NeuronQwen3OmniTextModel + vision_model_cls = NeuronQwen3VLVisionModel + + text_model_wrapper = NeuronQwen3OmniTextModelWrapper + vision_model_wrapper = NeuronQwen3VLVisionModelWrapper + + def __init__(self, *args, **kwargs): + super().__init__( + self.text_model_cls, + self.vision_model_cls, + self.text_model_wrapper, + self.vision_model_wrapper, + *args, + **kwargs, + ) + self.rope_deltas = None + self.audio_encoder = None + self._cached_neuron_state_dict = None + + def checkpoint_loader_fn(self, mmap: bool = False): + """Convert the full state dict once, split into text-only and + vision-only shards, and return a fresh shallow copy to the caller. + + The underlying ModelBuilder mutates and deletes keys from the returned + dict during `preprocess_checkpoint` (it drops everything not in the + target model's state_dict). Sharing one dict across text + vision + builders would delete the vision keys on the first pass. Each call + therefore returns its own dict, while tensor storage is shared + between them. + """ + if self._cached_neuron_state_dict is None: + sd = super().checkpoint_loader_fn(mmap=mmap) + # Partition keys by owning model. The same tensor memory is + # shared between partitions; each builder will drop the keys + # it does not own during preprocess_checkpoint. 
+ text_prefixes = ("layers.", "embed_tokens.", "lm_head.", "norm.", "rank_util.") + vision_prefixes = ("blocks.", "patch_embed.", "merger.", + "deepstack_merger_list.", "rotary_pos_emb.", "pos_embed.") + text_sd, vision_sd = {}, {} + for k, v in sd.items(): + if any(k.startswith(p) for p in text_prefixes): + text_sd[k] = v + elif any(k.startswith(p) for p in vision_prefixes): + vision_sd[k] = v + self._cached_neuron_state_dict = {"text": text_sd, "vision": vision_sd} + del sd + + # Return a fresh shallow copy so preprocess_checkpoint's .pop calls + # don't mutate the cache. Tensors are shared (same underlying storage). + cache = self._cached_neuron_state_dict + return {**cache["text"], **cache["vision"]} + + def free_cached_state_dict(self): + """Call after all builders have finished sharding to release CPU memory.""" + self._cached_neuron_state_dict = None + + # --- Vision encoder --- + + def get_vision_compiler_args(self) -> str: + cc = self.vision_config.neuron_config.cc_pipeline_tiling_factor + return ( + f"--auto-cast=none --model-type=transformer " + f"--tensorizer-options='--enable-ccop-compute-overlap " + f"--cc-pipeline-tiling-factor={cc}' -O1 " + f"--internal-max-instruction-limit=15000000" + ) + + def get_compiler_args(self) -> str: + cc = self.text_config.neuron_config.cc_pipeline_tiling_factor + return ( + f"--auto-cast=none --model-type=transformer " + f"--tensorizer-options='--enable-ccop-compute-overlap " + f"--cc-pipeline-tiling-factor={cc}' -O1 " + f"--internal-max-instruction-limit=15000000" + ) + + def get_required_kwargs(self) -> List[str]: + return [ + "pixel_values", "image_grid_thw", + "input_features", "feature_attention_mask", + ] + + def enable_vision_encoder(self, enable_wlt_optimization: bool = True, **model_init_kwargs): + new_config = copy.deepcopy(self.config) + self.vision_encoder_model = self.vision_model_wrapper( + config=new_config, + model_cls=self.vision_model_cls, + tag=VISION_ENCODER_MODEL_TAG, + 
compiler_args=self.get_vision_compiler_args(), + model_init_kwargs=model_init_kwargs, + priority_model_idx=(0 if enable_wlt_optimization else None), + pipeline_execution=True, + return_ranked_to_cpu=False, + ) + self.vision_models.append(self.vision_encoder_model) + + # --- Audio encoder --- + + def enable_audio_encoder(self, state_dict=None): + audio_config = getattr(self.config, "audio_config", None) + if audio_config is None: + logger.warning("No audio_config found. Audio encoder not initialized.") + return + + dtype = torch.bfloat16 + if hasattr(self.config, "neuron_config"): + dtype = getattr(self.config.neuron_config, "torch_dtype", dtype) + + if state_dict is not None: + self.audio_encoder = NeuronQwen3OmniAudioEncoder.from_pretrained_state_dict( + audio_config, state_dict, dtype=dtype + ) + # Stash transformer weights for compile / load (checkpoint_loader_fn). + self._audio_transformer_state_dict = { + k: v for k, v in state_dict.items() if k.startswith("transformer.") + } + else: + self.audio_encoder = NeuronQwen3OmniAudioEncoder(audio_config, dtype=dtype) + self._audio_transformer_state_dict = None + + self.audio_encoder.eval() + logger.info("Audio encoder initialized (Neuron transformer pending compile/load)") + + def compile_audio_encoder(self, compiled_model_path, audio_neuron_config=None): + if self.audio_encoder is None: + raise RuntimeError("Call enable_audio_encoder() first") + + audio_config = getattr(self.config, "audio_config", None) + if isinstance(audio_config, dict): + audio_config = SimpleNamespace(**audio_config) + + if audio_neuron_config is None: + tp_degree = self.neuron_config.tp_degree + audio_neuron_config = NeuronConfig( + tp_degree=tp_degree, + torch_dtype=self.neuron_config.torch_dtype, + batch_size=1, + buckets=[256, 512, 1024, 1500, 2048, 3000], + ) + + audio_inf_config = AudioEncoderInferenceConfig( + neuron_config=audio_neuron_config, + audio_config=vars(audio_config) if hasattr(audio_config, "__dict__") else audio_config, + ) 
+ + # Pass the already-loaded transformer weights so checkpoint_loader_fn can return them. + transformer_sd = getattr(self, "_audio_transformer_state_dict", None) + + audio_app = NeuronQwen3OmniForAudioEncoding( + model_path=self.model_path, + config=audio_inf_config, + transformer_state_dict=transformer_sd, + ) + audio_app.compile(compiled_model_path) + logger.info("Audio encoder transformer compiled to %s", compiled_model_path) + return audio_app + + def load_audio_encoder(self, compiled_model_path, audio_app=None): + if self.audio_encoder is None: + raise RuntimeError("Call enable_audio_encoder() first") + + if audio_app is None: + inf_config = AudioEncoderInferenceConfig.load(compiled_model_path) + transformer_sd = getattr(self, "_audio_transformer_state_dict", None) + audio_app = NeuronQwen3OmniForAudioEncoding( + model_path=self.model_path, + config=inf_config, + transformer_state_dict=transformer_sd, + ) + audio_app.is_compiled = True + audio_app.traced_model = torch.jit.load( + compiled_model_path + "/model.pt" + ) + for mw in audio_app.models: + mw.model = audio_app.traced_model + audio_app.is_loaded_to_neuron = True + audio_app.load_weights(compiled_model_path) + + self.audio_encoder.transformer = audio_app.model + logger.info("Audio encoder transformer loaded from %s", compiled_model_path) + + # --- Image counting and splitting (from Qwen3-VL) --- + + def _count_images_per_batch_line(self, input_ids, attention_mask): + image_token_id = self.config.image_token_id + vision_start_token_id = self.config.vision_start_token_id + images_per_batch_line = [] + + for i in range(input_ids.shape[0]): + ids = input_ids[i] + if attention_mask is not None: + ids = ids[attention_mask[i] == 1] + vision_start_indices = torch.argwhere(ids == vision_start_token_id).squeeze(1) + if vision_start_indices.numel() == 0: + images_per_batch_line.append(0) + else: + vision_tokens = ids[vision_start_indices + 1] + num_images = (vision_tokens == image_token_id).sum().item() + 
images_per_batch_line.append(num_images) + + return images_per_batch_line + + def _split_vision_inputs_by_batch_line(self, pixel_values, image_grid_thw, images_per_batch_line): + result = [] + image_offset = 0 + patch_offset = 0 + + for num_images in images_per_batch_line: + if num_images == 0: + result.append((None, None)) + continue + + grid_thw_i = image_grid_thw[image_offset : image_offset + num_images] + num_patches = grid_thw_i.prod(dim=1).sum().item() + pixel_values_i = pixel_values[patch_offset : patch_offset + num_patches] + + result.append((pixel_values_i, grid_thw_i)) + image_offset += num_images + patch_offset += num_patches + + return result + + # --- Rope index (from Qwen3-VL) --- + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> tuple: + """Compute 3D MRoPE position IDs (copied from Qwen3-VL).""" + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave(video_grid_thw, video_grid_thw[:, 0], dim=0) + video_grid_thw[:, 0] = 1 + + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = getattr(self.config, "video_token_id", None) + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, input_ids.shape[0], input_ids.shape[1], + dtype=input_ids.dtype, device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + + for i, input_ids_i in enumerate(total_input_ids): + input_ids_i = input_ids_i[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + 
vision_start_indices = torch.argwhere(input_ids_i == vision_start_token_id).squeeze(1) + vision_tokens = input_ids_i[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + if video_token_id is not None: + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids_i.tolist() + llm_pos_ids_list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id is not None and video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = 
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions.to(total_input_ids.dtype) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device + ).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + return position_ids, mrope_position_deltas + + # --- Atomic prefill --- + + def forward_atomic_prefill( + self, + input_ids, + attention_mask, + position_ids, + seq_ids, + sampling_params, + pixel_values, + image_grid_thw, + audio_embeddings=None, + audio_positions=None, + input_capture_hook=None, + tensor_capture_hook=None, + ): + pad_limit = self.get_padding_length(input_ids) + + if pixel_values is not None and pixel_values.numel() > 0: + vision_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + vision_mask = vision_mask.to(torch.bool) + vision_mask = generate_positions_from_mask(vision_mask.squeeze()) + vision_mask = pad_positions(vision_mask, pad_limit, (pad_limit - 
1)) + + vision_embeddings, deepstack_vision_embeds = self.vision_encoder_model( + pixel_values.to(self.vision_config.neuron_config.torch_dtype), image_grid_thw + ) + else: + vision_embeddings, vision_mask, deepstack_vision_embeds = ( + self.text_model_wrapper.get_dummy_vision_inputs( + config=self.text_config, + input_ids=input_ids, + n_active_tokens=pad_limit, + fill_value=(pad_limit - 1), + ) + ) + + # Merge audio embeddings into vision embeddings for scattering + if audio_embeddings is not None and audio_positions is not None: + if vision_embeddings is not None and vision_embeddings.numel() > 0: + all_embeddings = torch.cat([ + vision_embeddings.cpu() if hasattr(vision_embeddings, 'is_cuda') and vision_embeddings.is_cuda else vision_embeddings, + audio_embeddings, + ], dim=0) + all_positions = torch.cat([ + generate_positions_from_mask( + (input_ids == self.config.image_token_id).squeeze() + ), + audio_positions, + ]) + vision_embeddings = all_embeddings + vision_mask = pad_positions(all_positions, pad_limit, (pad_limit - 1)) + else: + vision_embeddings = audio_embeddings + vision_mask = pad_positions(audio_positions, pad_limit, (pad_limit - 1)) + + rotary_position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_thw, + video_grid_thw=None, + attention_mask=attention_mask, + ) + + output = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + seq_ids=seq_ids, + sampling_params=sampling_params, + input_capture_hook=input_capture_hook, + tensor_capture_hook=tensor_capture_hook, + rotary_position_ids=rotary_position_ids, + vision_embeddings=vision_embeddings, + vision_mask=vision_mask, + deepstack_vision_embeds=deepstack_vision_embeds, + ) + return output, rope_deltas + + @staticmethod + def concat_causal_lm_outputs(outputs_list): + concatenated_logits = [] + concatenated_hidden_states = [] + concatenated_tokens = [] + + for output in outputs_list: + if isinstance(output.logits, torch.Tensor): + 
concatenated_logits.append(output.logits) + if isinstance(output.hidden_states, torch.Tensor): + concatenated_hidden_states.append(output.hidden_states) + elif isinstance(output.hidden_states, list): + concatenated_hidden_states.extend(output.hidden_states) + if hasattr(output, "tokens") and isinstance(output.tokens, torch.Tensor): + concatenated_tokens.append(output.tokens) + + concatenated_logits = torch.cat(concatenated_logits, dim=0) if concatenated_logits else None + concatenated_tokens = torch.cat(concatenated_tokens, dim=0) if concatenated_tokens else None + + concatenated_output = CausalLMOutputWithPast( + logits=concatenated_logits, + hidden_states=concatenated_hidden_states, + ) + if concatenated_tokens is not None: + concatenated_output.tokens = concatenated_tokens + return concatenated_output + + # --- Main forward --- + + def get_padding_length(self, input_ids): + buckets = self.context_encoding_model.config.neuron_config.buckets + for val in buckets: + if val >= input_ids.shape[1]: + return val + raise Exception("No bucket found for provided input_ids!") + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + seq_ids: Optional[torch.LongTensor] = None, + sampling_params: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + vision_mask: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + feature_attention_mask: Optional[torch.LongTensor] = None, + adapter_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + medusa_args=None, + input_capture_hook: Optional[Callable] = None, + tensor_capture_hook: Optional[Callable] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + is_context_encoding = 
input_ids.shape[-1] > 1 + # Ensure prompt_len < chosen bucket so pad_positions' fill_value + # (pad_limit - 1) lands on a padding token rather than the last real + # token. When prompt_len equals a bucket boundary (e.g. 256), the + # scatter in encode_vision_to_input overwrites position pad_limit-1 + # with a zero audio embedding, corrupting the prompt. + if is_context_encoding: + pad_limit = self.get_padding_length(input_ids) + if input_ids.shape[1] == pad_limit: + pad_token_id = getattr(self.config, "pad_token_id", None) or 151645 + input_ids = torch.cat([ + input_ids, + torch.full((input_ids.shape[0], 1), pad_token_id, + dtype=input_ids.dtype, device=input_ids.device), + ], dim=1) + if attention_mask is not None: + attention_mask = torch.cat([ + attention_mask, + torch.zeros((attention_mask.shape[0], 1), + dtype=attention_mask.dtype, + device=attention_mask.device), + ], dim=1) + if position_ids is not None: + position_ids = torch.cat([ + position_ids, + position_ids[:, -1:] + 1, + ], dim=1) + pad_limit = self.get_padding_length(input_ids) + + # --- Audio encoding --- + audio_embeddings = None + audio_positions = None + if ( + input_features is not None + and self.audio_encoder is not None + and is_context_encoding + ): + audio_token_id = getattr(self.config, "audio_token_id", 151646) + + with torch.no_grad(): + if feature_attention_mask is not None: + audio_feature_lengths = feature_attention_mask.sum(-1) + input_features_flat = input_features.permute(0, 2, 1)[ + feature_attention_mask.bool() + ].permute(1, 0) + else: + input_features_flat = input_features.squeeze(0).permute(1, 0) + audio_feature_lengths = torch.tensor( + [input_features_flat.shape[1]], dtype=torch.long + ) + + audio_embeddings = self.audio_encoder( + input_features_flat, + feature_lens=audio_feature_lengths, + ) + + audio_mask_bool = (input_ids == audio_token_id) + if audio_mask_bool.any() and audio_embeddings is not None: + audio_positions = generate_positions_from_mask( + 
audio_mask_bool.squeeze() + ) + + # --- Vision + Text prefill with atomic batching --- + has_vision = ( + pixel_values is not None + and is_context_encoding + and pixel_values.sum() != 0 + ) + + if has_vision: + batch_size = input_ids.shape[0] + images_per_batch_line = self._count_images_per_batch_line(input_ids, attention_mask) + vision_inputs_per_bl = self._split_vision_inputs_by_batch_line( + pixel_values, image_grid_thw, images_per_batch_line + ) + + if seq_ids is None: + seq_ids = torch.arange(batch_size) + + outputs = [] + rope_deltas_list = [] + for index in range(batch_size): + pv_i, grid_thw_i = vision_inputs_per_bl[index] + output, rope_deltas = self.forward_atomic_prefill( + input_ids[index].unsqueeze(0), + attention_mask[index].unsqueeze(0) if attention_mask is not None else None, + position_ids[index].unsqueeze(0) if position_ids is not None else None, + seq_ids[index].unsqueeze(0), + sampling_params[index].unsqueeze(0) if sampling_params is not None else None, + pv_i, + grid_thw_i, + audio_embeddings=audio_embeddings, + audio_positions=audio_positions, + input_capture_hook=input_capture_hook, + tensor_capture_hook=tensor_capture_hook, + ) + outputs.append(output) + rope_deltas_list.append(rope_deltas) + + self.rope_deltas = torch.cat(rope_deltas_list, dim=0) + return self.concat_causal_lm_outputs(outputs) + + # --- Text-only or audio-only prefill, or decode --- + vision_embeddings_combined = None + vision_mask_combined = None + + if audio_embeddings is not None and audio_positions is not None: + vision_embeddings_combined = audio_embeddings + vision_mask_combined = pad_positions(audio_positions, pad_limit, (pad_limit - 1)) + + if vision_embeddings_combined is None: + vision_embeddings_combined, vision_mask_combined, deepstack_vision_embeds = ( + self.text_model_wrapper.get_dummy_vision_inputs( + config=self.text_config, + input_ids=input_ids, + n_active_tokens=pad_limit, + fill_value=(pad_limit - 1), + ) + ) + else: + _, _, deepstack_vision_embeds = ( 
+ self.text_model_wrapper.get_dummy_vision_inputs( + config=self.text_config, + input_ids=input_ids, + n_active_tokens=pad_limit, + fill_value=(pad_limit - 1), + ) + ) + + # Compute rotary position IDs + if is_context_encoding: + rotary_position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_thw, + video_grid_thw=None, + attention_mask=attention_mask, + ) + self.rope_deltas = rope_deltas + else: + batch_size, seq_length = input_ids.shape + if self.rope_deltas is not None: + delta = self.rope_deltas.to(input_ids.device) + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + else: + delta = 0 + rotary_position_ids = copy.deepcopy(position_ids) + rotary_position_ids = rotary_position_ids.add(delta) + rotary_position_ids = rotary_position_ids.unsqueeze(0).expand(3, -1, -1) + + output_token = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + seq_ids=seq_ids, + sampling_params=sampling_params, + input_capture_hook=input_capture_hook, + tensor_capture_hook=tensor_capture_hook, + rotary_position_ids=rotary_position_ids, + vision_embeddings=vision_embeddings_combined, + vision_mask=vision_mask_combined, + deepstack_vision_embeds=deepstack_vision_embeds, + ) + return output_token + + # --- HF model loading and state dict conversion --- + + @staticmethod + def load_hf_model(model_path, **kwargs): + from transformers import AutoModelForCausalLM + return AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True, **kwargs + ) + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, + inference_config: Qwen3OmniMoEInferenceConfig, + ) -> dict: + """Convert Qwen3-Omni full state dict to NxDI format. + + HF keys: thinker.visual.*, thinker.audio_tower.*, thinker.model.*, thinker.lm_head.* + + Step 0: Remap thinker.visual.* -> visual.* so Qwen3-VL vision conversion works. 
+ Step 1: Vision encoder conversion (strips visual.* prefix, remaps attn keys) + Step 2: Audio encoder conversion (strips thinker.audio_tower.*, splits into frontend/transformer/postprocessor) + Step 3: MoE text model conversion (strips thinker.model.*, attention remap, expert stacking) + """ + # Step 0: Remap thinker.visual.* -> visual.* and map Qwen3-Omni vision + # merger/merger_list names to Qwen3-VL's merger/deepstack_merger_list schema: + # merger.ln_q.* -> merger.norm.* + # merger.mlp.0.* -> merger.linear_fc1.* + # merger.mlp.2.* -> merger.linear_fc2.* + # merger_list.N.* -> deepstack_merger_list.N.* (with same ln_q/mlp remap) + remapped = {} + for key, value in state_dict.items(): + if key.startswith("thinker.visual."): + new_key = "visual." + key[len("thinker.visual."):] + if new_key.startswith("visual.merger_list."): + new_key = "visual.deepstack_merger_list." + new_key[len("visual.merger_list."):] + new_key = new_key.replace(".ln_q.", ".norm.") + new_key = new_key.replace(".mlp.0.", ".linear_fc1.") + new_key = new_key.replace(".mlp.2.", ".linear_fc2.") + remapped[new_key] = value + else: + remapped[key] = value + state_dict = remapped + + # Step 1: Vision encoder conversion (Qwen3-VL: strips visual.*, remaps attn) + state_dict = NeuronQwen3VLForImageEncoding.convert_hf_to_neuron_state_dict( + state_dict, inference_config + ) + + # Step 2: Audio encoder conversion + audio_dtype = getattr( + inference_config.neuron_config, "torch_dtype", torch.bfloat16 + ) + state_dict = NeuronQwen3OmniAudioEncoder.convert_hf_to_neuron_state_dict( + state_dict, dtype=audio_dtype + ) + + # Step 3: MoE text model conversion + state_dict = convert_qwen3_omni_text_hf_to_neuron( + state_dict, inference_config.text_config + ) + + return state_dict + + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + if "embed_tokens.weight" in state_dict and "lm_head.weight" not in state_dict: + state_dict["lm_head.weight"] = state_dict["embed_tokens.weight"].clone() + + 
@classmethod + def get_config_cls(cls): + return Qwen3OmniMoEInferenceConfig + + @classmethod + def prepare_input_args(cls, prompts, images, processor, role="user", config=None): + return NeuronQwen3VLForImageEncoding.prepare_input_args( + prompts, images, processor, role, config + ) diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni_audio.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni_audio.py new file mode 100644 index 00000000..5bd729ba --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni_audio.py @@ -0,0 +1,719 @@ +"""Qwen3-Omni Audio Encoder for NxD Inference. + +Conv2d frontend (3 layers, downsample_hidden_size=480) + +32 transformer layers (d_model=1280, heads=20, ffn=5120) on Neuron + +proj1 + GELU + proj2 postprocessor on CPU. + +Key differences from Qwen2.5-Omni audio encoder: + - Conv2d frontend (2D mel processing) instead of Conv1d + - All attention projections have bias=True (k_proj included) + - proj1 + GELU + proj2 output instead of AvgPool + single proj + - No audio_bos_eos_token + - Different output length calculation (3-stage Conv2d downsampling) +""" + +import logging +import math +from types import SimpleNamespace +from typing import List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, +) +from neuronx_distributed_inference.models.application_base import NeuronApplicationBase +from neuronx_distributed_inference.models.config import InferenceConfig, NeuronConfig +from neuronx_distributed_inference.models.model_wrapper import ( + EncoderModelInstance, + ModelWrapper, +) + +logger = logging.getLogger(__name__) + + +def _get_feat_extract_output_lengths(input_lengths): + """Compute output lengths after 3x Conv2d (stride=2 each) + chunk-aware calculation. 
+ + Matches HF Qwen3OmniMoeAudioEncoder._get_feat_extract_output_lengths. + Each Conv2d with stride=2 halves the time dimension: (L-1)//2 + 1. + The chunking introduces a correction factor of 13 per full window. + """ + input_lengths_leave = input_lengths % 100 + feat_lengths = (input_lengths_leave - 1) // 2 + 1 + output_lengths = ((feat_lengths - 1) // 2 + 1 - 1) // 2 + 1 + (input_lengths // 100) * 13 + return output_lengths + + +# --------------------------------------------------------------------------- +# CPU components +# --------------------------------------------------------------------------- + +class SinusoidsPositionEmbedding(nn.Module): + def __init__(self, length, channels, max_timescale=10000): + super().__init__() + if channels % 2 != 0: + raise ValueError("SinusoidsPositionEmbedding needs even channels") + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp( + -log_timescale_increment * torch.arange(channels // 2).float() + ) + scaled_time = ( + torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + ) + self.register_buffer( + "positional_embedding", + torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1), + persistent=False, + ) + + def forward(self, seqlen: int): + return self.positional_embedding[:seqlen, :] + + +class AudioCPUFrontend(nn.Module): + """Conv2d frontend + positional embeddings + chunking (CPU). + + 3x Conv2d layers with stride=2 each: mel (1, 128, T) -> (480, freq_reduced, T/8). + Then linear projection to d_model and sinusoidal positional embeddings. 
+ """ + + def __init__(self, audio_config, dtype=torch.bfloat16): + super().__init__() + if isinstance(audio_config, dict): + audio_config = SimpleNamespace(**audio_config) + + d_model = audio_config.d_model # 1280 + num_mel_bins = audio_config.num_mel_bins # 128 + max_source_positions = audio_config.max_source_positions # 1500 + self.n_window = audio_config.n_window # 100 + self.d_model = d_model + downsample_hidden_size = getattr(audio_config, "downsample_hidden_size", 480) + + self.conv2d1 = nn.Conv2d(1, downsample_hidden_size, 3, 2, padding=1, dtype=dtype) + self.conv2d2 = nn.Conv2d( + downsample_hidden_size, downsample_hidden_size, 3, 2, padding=1, dtype=dtype + ) + self.conv2d3 = nn.Conv2d( + downsample_hidden_size, downsample_hidden_size, 3, 2, padding=1, dtype=dtype + ) + # Three stride-2 convs (kernel=3, padding=1) reduce the freq axis from 128 to 16 bins: + # freq=128: after conv1 (s=2,p=1): (128+2*1-3)//2+1 = 64 + # after conv2: (64+2*1-3)//2+1 = 32 + # after conv3: (32+2*1-3)//2+1 = 16 + # HF formula: (((num_mel_bins+1)//2+1)//2+1)//2, which gives the same result + freq_reduced = (((num_mel_bins + 1) // 2 + 1) // 2 + 1) // 2 + self.conv_out = nn.Linear( + downsample_hidden_size * freq_reduced, d_model, bias=False, dtype=dtype + ) + self.positional_embedding = SinusoidsPositionEmbedding( + max_source_positions, d_model + ) + + def forward(self, input_features, feature_lens): + """Process mel spectrogram through Conv2d frontend. 
+ + Args: + input_features: (n_mels, total_mel_len) mel spectrogram + feature_lens: (num_audios,) mel length for each audio + + Returns: + hidden_states: (total_valid_tokens, d_model) + aftercnn_lens: (num_audios,) valid tokens per audio + cu_seqlens, padded_mask_after_cnn: cumulative sequence lengths for attention masking, and the (num_chunks, max_chunk_tokens) bool mask of valid tokens + """ + aftercnn_lens = _get_feat_extract_output_lengths(feature_lens) + + # Split into chunks of n_window * 2 mel frames + chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long() + chunk_lengths = torch.tensor( + [self.n_window * 2] * chunk_num.sum().item(), + dtype=torch.long, device=feature_lens.device, + ) + tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:] + chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2) + chunk_lengths = torch.where( + chunk_lengths == 0, self.n_window * 2, chunk_lengths + ) + + # Split mel into chunks, pad to max chunk length + chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0) + padded_feature = nn.utils.rnn.pad_sequence( + chunk_list, batch_first=True + ).transpose(1, 2) # (num_chunks, mel_bins, max_chunk_time) + + # Compute per-chunk output lengths after 3x Conv2d + feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths) + padded_mask_after_cnn = nn.utils.rnn.pad_sequence( + [torch.ones(length, dtype=torch.bool, device=padded_feature.device) + for length in feature_lens_after_cnn], + batch_first=True, + ) + + # Conv2d frontend: (num_chunks, 1, mel_bins, time) -> (num_chunks, 480, freq, time) + padded_feature = padded_feature.unsqueeze(1) + padded_embed = F.gelu(self.conv2d1(padded_feature)) + padded_embed = F.gelu(self.conv2d2(padded_embed)) + padded_embed = F.gelu(self.conv2d3(padded_embed)) + + # Reshape: (batch, channels, freq, time) -> (batch, time, channels*freq) -> linear -> (batch, time, d_model) + b, c, f, t = padded_embed.size() + padded_embed = self.conv_out( + padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f) + ) + + # Add positional 
embeddings + positional_embedding = ( + self.positional_embedding.positional_embedding[:padded_embed.shape[1], :] + .unsqueeze(0) + .to(padded_embed.dtype) + ) + padded_embed = padded_embed + positional_embedding + + # Flatten valid tokens + hidden_states = padded_embed[padded_mask_after_cnn] + + # Compute cu_seqlens for block-diagonal attention + cu_seqlens = torch.cat([ + torch.zeros(1, device=feature_lens.device, dtype=torch.int32), + padded_mask_after_cnn.sum(1).cumsum(0).to(torch.int32), + ]) + + return hidden_states, aftercnn_lens, cu_seqlens, padded_mask_after_cnn + + +class AudioCPUPostprocessor(nn.Module): + """LayerNorm + proj1 + GELU + proj2 (CPU). + + No AvgPool (unlike Qwen2.5-Omni). No audio_bos_eos_token. + """ + + def __init__(self, audio_config, dtype=torch.bfloat16): + super().__init__() + if isinstance(audio_config, dict): + audio_config = SimpleNamespace(**audio_config) + + d_model = audio_config.d_model # 1280 + output_dim = audio_config.output_dim # 3584 + + self.ln_post = nn.LayerNorm(d_model) # stays float32 + self.proj1 = nn.Linear(d_model, d_model, dtype=dtype) + self.act = nn.GELU() + self.proj2 = nn.Linear(d_model, output_dim, dtype=dtype) + + def forward(self, hidden_states): + """Post-process transformer output. + + Args: + hidden_states: (total_tokens, d_model) transformer output + + Returns: + audio_embeddings: (total_tokens, output_dim) + """ + hidden_states = self.ln_post(hidden_states.float()).to(hidden_states.dtype) + hidden_states = self.proj1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.proj2(hidden_states) + return hidden_states + + +# --------------------------------------------------------------------------- +# Neuron-compiled transformer components (TP-parallel) +# --------------------------------------------------------------------------- + +class NeuronAudioAttention(nn.Module): + """TP-parallel self-attention for audio encoder. 
+ + All projections have bias=True (unlike Qwen2.5-Omni where k_proj has no bias). + + To support tp_degree values that do not evenly divide num_heads (the Qwen3-Omni + audio tower has 20 heads, which is not divisible by the text model's TP=8), we + pad the head count up to the next multiple of tp_degree (20 -> 24 at TP=8) and + zero-fill the added heads' QKV/output weights. The padding heads produce zeros + and are discarded by the output projection. + """ + + def __init__(self, d_model, num_heads, tp_degree, dtype=torch.bfloat16): + super().__init__() + self.d_model = d_model + self.num_heads_orig = num_heads + self.head_dim = d_model // num_heads + # Round num_heads up to the next multiple of tp_degree. + self.num_heads = ((num_heads + tp_degree - 1) // tp_degree) * tp_degree + self.num_heads_per_rank = self.num_heads // tp_degree + # Hidden size used internally (may be padded beyond the upstream d_model). + self.padded_hidden = self.num_heads * self.head_dim + self.scaling = self.head_dim ** -0.5 + + # q/k/v are ColumnParallel: output is split by tp. Use padded_hidden as the + # output size so each rank gets exactly num_heads_per_rank heads. + self.q_proj = ColumnParallelLinear( + d_model, self.padded_hidden, bias=True, gather_output=False, dtype=dtype, + ) + self.k_proj = ColumnParallelLinear( + d_model, self.padded_hidden, bias=True, gather_output=False, dtype=dtype, + ) + self.v_proj = ColumnParallelLinear( + d_model, self.padded_hidden, bias=True, gather_output=False, dtype=dtype, + ) + # out_proj is RowParallel: input is split by tp, output is d_model. 
+ self.out_proj = RowParallelLinear( + self.padded_hidden, d_model, bias=True, input_is_parallel=True, dtype=dtype, + ) + + def forward(self, hidden_states, attention_mask=None): + bsz, seq_len, _ = hidden_states.shape + + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + q = q.view(bsz, seq_len, self.num_heads_per_rank, self.head_dim).transpose(1, 2) + k = k.view(bsz, seq_len, self.num_heads_per_rank, self.head_dim).transpose(1, 2) + v = v.view(bsz, seq_len, self.num_heads_per_rank, self.head_dim).transpose(1, 2) + + scores = torch.matmul(q, k.transpose(-2, -1)) * self.scaling + if attention_mask is not None: + scores = scores + attention_mask + + attn_weights = F.softmax(scores.float(), dim=-1).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + + attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, seq_len, -1) + attn_output = self.out_proj(attn_output) + return attn_output + + +class NeuronAudioEncoderLayer(nn.Module): + """Pre-norm transformer layer with TP parallelism.""" + + def __init__(self, d_model, num_heads, ffn_dim, tp_degree, dtype=torch.bfloat16): + super().__init__() + self.self_attn = NeuronAudioAttention(d_model, num_heads, tp_degree, dtype) + self.self_attn_layer_norm = nn.LayerNorm(d_model) + self.fc1 = ColumnParallelLinear( + d_model, ffn_dim, bias=True, gather_output=False, dtype=dtype, + ) + self.fc2 = RowParallelLinear( + ffn_dim, d_model, bias=True, input_is_parallel=True, dtype=dtype, + ) + self.final_layer_norm = nn.LayerNorm(d_model) + + def forward(self, hidden_states, attention_mask=None): + residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states.float()).to(residual.dtype) + hidden_states = self.self_attn(hidden_states, attention_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states.float()).to(residual.dtype) + hidden_states = F.gelu(self.fc1(hidden_states)) + 
hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class NeuronAudioTransformerModel(nn.Module): + """Audio encoder transformer (32 layers, compiled on Neuron). + + Input: (1, padded_seq_len, d_model) + (1, 1, padded_seq_len, padded_seq_len) + Output: (1, padded_seq_len, d_model) + """ + + def __init__(self, config: InferenceConfig): + super().__init__() + audio_config = config.audio_config + if isinstance(audio_config, dict): + audio_config = SimpleNamespace(**audio_config) + + tp_degree = config.neuron_config.tp_degree + dtype = config.neuron_config.torch_dtype + + d_model = audio_config.d_model + num_heads = audio_config.encoder_attention_heads + ffn_dim = audio_config.encoder_ffn_dim + num_layers = audio_config.encoder_layers + + self.layers = nn.ModuleList([ + NeuronAudioEncoderLayer(d_model, num_heads, ffn_dim, tp_degree, dtype) + for _ in range(num_layers) + ]) + + def forward(self, hidden_states, attention_mask): + for layer in self.layers: + hidden_states = layer(hidden_states, attention_mask) + return hidden_states + + +# --------------------------------------------------------------------------- +# ModelWrapper and Application classes +# --------------------------------------------------------------------------- + +class AudioTransformerModelWrapper(ModelWrapper): + """Handles bucketing, padding, and Neuron compilation for audio transformer.""" + + def __init__(self, config, model_cls, tag="", compiler_args=None, + priority_model_idx=None, pipeline_execution=True, + return_ranked_to_cpu=False, model_init_kwargs={}): + super().__init__( + config, model_cls, tag, compiler_args, priority_model_idx, + pipeline_execution, return_ranked_to_cpu, model_init_kwargs, + ) + + def input_generator(self) -> List[Tuple[torch.Tensor]]: + inputs = [] + dtype = self.config.neuron_config.torch_dtype + d_model = self.config.audio_config.d_model + if isinstance(d_model, dict): + d_model = d_model.get("d_model", 
1280) + + for bucket in self.config.neuron_config.buckets: + hidden_states = torch.ones([1, bucket, d_model], dtype=dtype) + attention_mask = torch.zeros( + [1, 1, bucket, bucket], dtype=dtype, + ) + inputs.append((hidden_states, attention_mask)) + return inputs + + def get_model_instance(self): + return EncoderModelInstance(model_cls=self.model_cls, config=self.config) + + def get_target_bucket(self, seq_len): + for bucket in self.config.neuron_config.buckets: + if bucket >= seq_len: + return bucket + raise ValueError( + f"No bucket found for seq_len={seq_len}. " + f"Buckets: {self.config.neuron_config.buckets}" + ) + + def forward(self, hidden_states, attention_mask): + if self.model is None: + raise RuntimeError("Forward called before load.") + + seq_len = hidden_states.shape[1] + bucket = self.get_target_bucket(seq_len) + + if seq_len < bucket: + pad_len = bucket - seq_len + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_len)) + dtype = attention_mask.dtype + mask_pad = torch.full( + (1, 1, bucket, bucket), + torch.finfo(dtype).min, + dtype=dtype, + ) + mask_pad[:, :, :seq_len, :seq_len] = attention_mask + attention_mask = mask_pad + + output = self._forward(hidden_states, attention_mask) + # Handle ranked output from Neuron pipeline execution + if isinstance(output, list): + output = output[0][0] + if output.device.type != "cpu": + output = output.to("cpu") + return output[:, :seq_len, :] + + +class NeuronQwen3OmniForAudioEncoding(NeuronApplicationBase): + """Neuron application for audio encoder transformer layers.""" + + _model_cls = NeuronAudioTransformerModel + + def __init__(self, model_path, config=None, neuron_config=None, transformer_state_dict=None): + # NeuronApplicationBase signature: (model_path, config, neuron_config) + super().__init__(model_path=model_path, config=config, neuron_config=neuron_config) + self._transformer_state_dict = transformer_state_dict + self.model = AudioTransformerModelWrapper( + config=self.config, + 
model_cls=self._model_cls, + tag=self._model_cls.__name__, + compiler_args=self.get_compiler_args(), + priority_model_idx=0, + ) + self.models.append(self.model) + + def forward(self, hidden_states, attention_mask): + return self.models[0](hidden_states, attention_mask) + + def get_compiler_args(self): + return ( + "--auto-cast=none --model-type=transformer " + "--tensorizer-options='--enable-ccop-compute-overlap " + "--cc-pipeline-tiling-factor=2 ' -O1 " + "--internal-hlo2tensorizer-options='--verify-hlo=true'" + ) + + def checkpoint_loader_fn(self, mmap: bool = False): + """Return pre-converted transformer-only state dict. + + Bypass NeuronApplicationBase's HF-loading path since we only have + audio_tower weights, already extracted by the caller. + """ + if self._transformer_state_dict is None: + raise RuntimeError( + "transformer_state_dict must be provided to NeuronQwen3OmniForAudioEncoding " + "for compilation/sharding." + ) + # Only keep transformer.layers.* weights (drop frontend/postprocessor). + sd = {} + for k, v in self._transformer_state_dict.items(): + if k.startswith("transformer.layers."): + new_key = k[len("transformer."):] + sd[new_key] = v + return sd + + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + pass + + @staticmethod + def load_hf_model(model_path, **kwargs): + raise NotImplementedError( + "NeuronQwen3OmniForAudioEncoding loads weights via checkpoint_loader_fn, " + "not load_hf_model." 
+ ) + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict, inference_config): + """Extract audio_tower.layers.* keys and strip prefix for Neuron.""" + new_state_dict = {} + for key, value in state_dict.items(): + if key.startswith("thinker.audio_tower.layers."): + new_key = key[len("thinker.audio_tower."):] + elif key.startswith("audio_tower.layers."): + new_key = key[len("audio_tower."):] + else: + new_state_dict[key] = value + continue + new_state_dict[new_key] = value + return new_state_dict + + @classmethod + def get_config_cls(cls): + return AudioEncoderInferenceConfig + + +class AudioEncoderInferenceConfig(InferenceConfig): + """Config for audio encoder transformer compilation.""" + + def __init__(self, neuron_config, audio_config, **kwargs): + self.audio_config = audio_config + if isinstance(audio_config, dict): + self.audio_config = SimpleNamespace(**audio_config) + super().__init__(neuron_config=neuron_config, **kwargs) + self.num_cores_per_group = 1 + + def add_derived_config(self): + pass + + def get_required_attributes(self): + return [] + + +# --------------------------------------------------------------------------- +# Full Audio Encoder +# --------------------------------------------------------------------------- + +class NeuronQwen3OmniAudioEncoder(nn.Module): + """Qwen3-Omni Audio Encoder with Neuron acceleration. 
+ + CPU: mel -> Conv2d frontend -> positional embeddings -> chunking + Neuron (TP): 32 transformer layers with block-diagonal attention + CPU: LayerNorm -> proj1 -> GELU -> proj2 -> audio embeddings + """ + + def __init__(self, audio_config, neuron_config=None, dtype=torch.bfloat16): + super().__init__() + if isinstance(audio_config, dict): + audio_config = SimpleNamespace(**audio_config) + self.audio_config = audio_config + self.dtype = dtype + self.n_window = audio_config.n_window + self.n_window_infer = getattr(audio_config, "n_window_infer", 400) + + self.frontend = AudioCPUFrontend(audio_config, dtype=dtype) + self.postprocessor = AudioCPUPostprocessor(audio_config, dtype=dtype) + self.transformer = None + self._neuron_config = neuron_config + + def _prepare_attention_mask(self, seq_length, cu_seqlens, dtype): + attention_mask = torch.full( + [1, 1, seq_length, seq_length], + torch.finfo(dtype).min, + dtype=dtype, + ) + for i in range(1, len(cu_seqlens)): + s, e = cu_seqlens[i - 1], cu_seqlens[i] + attention_mask[..., s:e, s:e] = 0 + return attention_mask + + def _compute_inference_cu_seqlens(self, aftercnn_lens, padded_mask_after_cnn): + """Compute cu_seqlens using n_window_infer chunking (matches HF forward). + + The HF encoder uses n_window_infer to sub-chunk the attention windows + for inference efficiency: each audio is split into sub-windows of size + window_aftercnn = mask_time_dim * (n_window_infer / (n_window * 2)). 
+ """ + window_aftercnn = padded_mask_after_cnn.shape[-1] * ( + self.n_window_infer // (self.n_window * 2) + ) + cu_chunk_lens = [0] + for cnn_len in aftercnn_lens: + cnn_len = cnn_len.item() + cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn) + remainder = cnn_len % window_aftercnn + if remainder != 0: + cu_chunk_lens += [remainder] + cu_seqlens = torch.tensor( + cu_chunk_lens, device=aftercnn_lens.device + ).cumsum(-1, dtype=torch.int32) + return cu_seqlens + + def forward(self, input_features, feature_lens, aftercnn_lens=None): + """Process mel spectrogram through audio encoder. + + Args: + input_features: (n_mels, total_mel_len) mel spectrogram + feature_lens: (num_audios,) mel length for each audio + + Returns: + audio_embeddings: (total_audio_tokens, output_dim) + """ + # CPU: Conv2d frontend + chunking + hidden_states, aftercnn_lens_actual, _, padded_mask_after_cnn = self.frontend( + input_features, feature_lens + ) + + if self.transformer is None: + raise RuntimeError( + "Audio transformer must be compiled and loaded before inference." + ) + + # Neuron: transformer layers + seq_len = hidden_states.shape[0] + # HF uses n_window_infer-based sub-window chunking (window_aftercnn tokens + # per block) rather than one block per n_window chunk. Short audios + # produce the same block-diagonal structure either way, but for long + # audios the basic (per-n_window) grouping creates too-narrow attention + # windows (13 tokens) and the encoder output degrades into noise. 
+ cu_seqlens = self._compute_inference_cu_seqlens( + aftercnn_lens_actual, padded_mask_after_cnn + ) + attention_mask = self._prepare_attention_mask( + seq_len, cu_seqlens, self.dtype + ) + hidden_states = hidden_states.unsqueeze(0) + hidden_states = self.transformer(hidden_states, attention_mask) + hidden_states = hidden_states.squeeze(0) + + # CPU: Postprocessing (ln_post + proj1 + GELU + proj2) + return self.postprocessor(hidden_states) + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict, dtype=torch.bfloat16, + tp_degree=1, num_heads=20, head_dim=64): + """Convert HF state dict to split architecture format. + + Prefixes keys for three groups: + - frontend.*: conv2d1, conv2d2, conv2d3, conv_out, positional_embedding (CPU) + - transformer.layers.*: transformer layers (Neuron) + - postprocessor.*: ln_post, proj1, proj2 (CPU) + + When tp_degree does not divide num_heads, pad q/k/v/out_proj weights along + the head dimension with zeros so each TP rank gets an integer number of + heads. The added heads produce zero output; the out_proj's zero columns + ensure they don't affect the residual stream. 
+ """ + new_state_dict = {} + + ln_suffixes = ( + "self_attn_layer_norm.weight", "self_attn_layer_norm.bias", + "final_layer_norm.weight", "final_layer_norm.bias", + "ln_post.weight", "ln_post.bias", + ) + + frontend_prefixes = ( + "conv2d1.", "conv2d2.", "conv2d3.", "conv_out.", "positional_embedding.", + ) + postprocessor_prefixes = ("ln_post.", "proj1.", "proj2.") + + orig_hidden = num_heads * head_dim + padded_heads = ((num_heads + tp_degree - 1) // tp_degree) * tp_degree + padded_hidden = padded_heads * head_dim + pad_out = padded_hidden - orig_hidden # rows to add to q/k/v + + def _pad_attn(clean_key, tensor): + """Zero-pad q/k/v/out_proj tensors along the head dim.""" + if pad_out == 0: + return tensor + if (clean_key.endswith(".self_attn.q_proj.weight") + or clean_key.endswith(".self_attn.k_proj.weight") + or clean_key.endswith(".self_attn.v_proj.weight")): + # shape: (orig_hidden, d_model) -> (padded_hidden, d_model) + return F.pad(tensor, (0, 0, 0, pad_out)) + if (clean_key.endswith(".self_attn.q_proj.bias") + or clean_key.endswith(".self_attn.k_proj.bias") + or clean_key.endswith(".self_attn.v_proj.bias")): + # shape: (orig_hidden,) -> (padded_hidden,) + return F.pad(tensor, (0, pad_out)) + if clean_key.endswith(".self_attn.out_proj.weight"): + # shape: (d_model, orig_hidden) -> (d_model, padded_hidden) + return F.pad(tensor, (0, pad_out)) + return tensor + + for key, value in state_dict.items(): + if key.startswith("thinker.audio_tower."): + clean_key = key[len("thinker.audio_tower."):] + elif key.startswith("audio_tower."): + clean_key = key[len("audio_tower."):] + else: + new_state_dict[key] = value + continue + + if any(clean_key.endswith(s) for s in ln_suffixes): + target_dtype = torch.float32 + else: + target_dtype = dtype + + if any(clean_key.startswith(p) for p in frontend_prefixes): + new_state_dict["frontend." 
+ clean_key] = ( + value.clone().detach().contiguous().to(target_dtype) + ) + elif any(clean_key.startswith(p) for p in postprocessor_prefixes): + new_state_dict["postprocessor." + clean_key] = ( + value.clone().detach().contiguous().to(target_dtype) + ) + elif clean_key.startswith("layers."): + padded = _pad_attn(clean_key, value) + new_state_dict["transformer." + clean_key] = ( + padded.clone().detach().contiguous().to(target_dtype) + ) + else: + logger.warning("Unknown audio key: %s", clean_key) + + return new_state_dict + + @staticmethod + def from_pretrained_state_dict(audio_config, state_dict, dtype=torch.bfloat16): + """Create audio encoder and load CPU weights from converted state dict.""" + encoder = NeuronQwen3OmniAudioEncoder(audio_config, dtype=dtype) + + cpu_keys = {} + for key, value in state_dict.items(): + if key.startswith("frontend.") or key.startswith("postprocessor."): + cpu_keys[key] = value + + if cpu_keys: + missing, unexpected = encoder.load_state_dict(cpu_keys, strict=False) + missing = [k for k in missing if not k.startswith("transformer.")] + if missing: + logger.warning("Audio encoder CPU missing keys: %s", missing[:10]) + logger.info("Loaded %d CPU weights into audio encoder", len(cpu_keys)) + + return encoder diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni_code_predictor.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni_code_predictor.py new file mode 100644 index 00000000..50d09368 --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni_code_predictor.py @@ -0,0 +1,370 @@ +"""Qwen3-Omni Talker Code Predictor on Neuron. + +The code predictor runs once per talker decode step: + 1. Prefill with 2 tokens (past_hidden, last_id_hidden) → code 0 logits + KV + 2. 14 decode steps, each consuming the previous argmax'd code, producing + codes 1..14 and per-step hidden states. 
+ +A total of 15 residual codes are produced; only 14 decode hidden states are +consumed by the talker (mid_residual_hiddens). The prefill's KV cache has +length 2; decode extends it up to length 16. + +We compile a single NEFF: a 16-token-long causal self-attention over a fixed +input buffer, driven by a runtime state machine external to the NEFF (the +host Python code does the greedy argmax + embedding lookup between NEFF +invocations). The NEFF is invoked 15 times per talker step: + - Invocation 0: prefill (2 "valid" positions, 14 masked) + - Invocation 1..14: decode (i+2 valid positions at invocation i) + +This avoids having to trace a KV-cache scatter op or multiple NEFFs. + +Architecture (from config.talker_config.code_predictor_config): + - hidden_size=1024, num_hidden_layers=5, dense GQA + - num_attention_heads=16, num_key_value_heads=8, head_dim=128 + - intermediate_size=3072, SwiGLU MLP + - q_norm + k_norm (per-head-dim RMSNorm) + - Plain 1D RoPE (no MRoPE), theta=1e6 + - 15 codec_embedding tables + 15 lm_heads + +TP sharding at TP=8: 2 Q heads/rank, 1 KV head/rank, dense MLP sharded +as ColumnParallel (3072/8=384) + RowParallel. +""" + +import logging +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, +) + +logger = logging.getLogger("Neuron") + + +# Constants from config.talker_config.code_predictor_config +HIDDEN_SIZE = 1024 +NUM_LAYERS = 5 +NUM_ATTN_HEADS = 16 +NUM_KV_HEADS = 8 +HEAD_DIM = 128 +INTERMEDIATE = 3072 +VOCAB_SIZE = 2048 +NUM_CODE_GROUPS = 16 +NUM_EMBED_TABLES = NUM_CODE_GROUPS - 1 # 15 +NUM_LM_HEADS = NUM_CODE_GROUPS - 1 # 15 +# Total positions across a full prefill+decode cycle = 2 + 14 = 16. 
+MAX_SEQ_LEN = 16 +RMS_EPS = 1e-6 +ROPE_THETA = 1_000_000.0 + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = RMS_EPS): + super().__init__() + self.weight = nn.Parameter(torch.ones(dim)) + self.eps = eps + + def forward(self, x): + dtype = x.dtype + x32 = x.float() + var = x32.pow(2).mean(-1, keepdim=True) + x32 = x32 * torch.rsqrt(var + self.eps) + return (x32 * self.weight).to(dtype) + + +def _compute_rope(dim: int, max_pos: int, base: float = ROPE_THETA): + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + t = torch.arange(max_pos, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return emb.cos(), emb.sin() + + +def _rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +def _apply_rope(q, k, cos, sin): + cos = cos.unsqueeze(0).unsqueeze(0) + sin = sin.unsqueeze(0).unsqueeze(0) + return (q * cos) + (_rotate_half(q) * sin), (k * cos) + (_rotate_half(k) * sin) + + +class CPAttention(nn.Module): + def __init__(self, tp_degree: int, dtype=torch.bfloat16): + super().__init__() + self.tp_degree = tp_degree + self.num_heads_per_rank = NUM_ATTN_HEADS // tp_degree + self.num_kv_heads_per_rank = max(NUM_KV_HEADS // tp_degree, 1) + self.num_key_value_groups = self.num_heads_per_rank // self.num_kv_heads_per_rank + self.scaling = HEAD_DIM ** -0.5 + + self.q_proj = ColumnParallelLinear( + HIDDEN_SIZE, NUM_ATTN_HEADS * HEAD_DIM, bias=False, + gather_output=False, dtype=dtype, + ) + self.k_proj = ColumnParallelLinear( + HIDDEN_SIZE, NUM_KV_HEADS * HEAD_DIM, bias=False, + gather_output=False, dtype=dtype, + ) + self.v_proj = ColumnParallelLinear( + HIDDEN_SIZE, NUM_KV_HEADS * HEAD_DIM, bias=False, + gather_output=False, dtype=dtype, + ) + self.o_proj = RowParallelLinear( + NUM_ATTN_HEADS * HEAD_DIM, HIDDEN_SIZE, bias=False, + input_is_parallel=True, dtype=dtype, + ) + self.q_norm = RMSNorm(HEAD_DIM) + self.k_norm = RMSNorm(HEAD_DIM) + + 
def forward(self, x, cos, sin, causal_mask): + B, S, _ = x.shape + q = self.q_norm(self.q_proj(x).view(B, S, self.num_heads_per_rank, HEAD_DIM)).transpose(1, 2) + k = self.k_norm(self.k_proj(x).view(B, S, self.num_kv_heads_per_rank, HEAD_DIM)).transpose(1, 2) + v = self.v_proj(x).view(B, S, self.num_kv_heads_per_rank, HEAD_DIM).transpose(1, 2) + + q, k = _apply_rope(q, k, cos, sin) + + if self.num_key_value_groups > 1: + k_r = k.repeat_interleave(self.num_key_value_groups, dim=1) + v_r = v.repeat_interleave(self.num_key_value_groups, dim=1) + else: + k_r = k + v_r = v + + scores = torch.matmul(q, k_r.transpose(-2, -1)) * self.scaling + scores = scores + causal_mask + attn = F.softmax(scores.float(), dim=-1).to(q.dtype) + out = torch.matmul(attn, v_r).transpose(1, 2).contiguous().view(B, S, -1) + return self.o_proj(out) + + +class CPMLP(nn.Module): + def __init__(self, tp_degree: int, dtype=torch.bfloat16): + super().__init__() + self.gate_proj = ColumnParallelLinear( + HIDDEN_SIZE, INTERMEDIATE, bias=False, gather_output=False, dtype=dtype, + ) + self.up_proj = ColumnParallelLinear( + HIDDEN_SIZE, INTERMEDIATE, bias=False, gather_output=False, dtype=dtype, + ) + self.down_proj = RowParallelLinear( + INTERMEDIATE, HIDDEN_SIZE, bias=False, input_is_parallel=True, dtype=dtype, + ) + + def forward(self, x): + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class CPLayer(nn.Module): + def __init__(self, tp_degree: int, dtype=torch.bfloat16): + super().__init__() + self.input_layernorm = RMSNorm(HIDDEN_SIZE) + self.self_attn = CPAttention(tp_degree, dtype=dtype) + self.post_attention_layernorm = RMSNorm(HIDDEN_SIZE) + self.mlp = CPMLP(tp_degree, dtype=dtype) + + def forward(self, x, cos, sin, causal_mask): + x = x + self.self_attn(self.input_layernorm(x), cos, sin, causal_mask) + x = x + self.mlp(self.post_attention_layernorm(x)) + return x + + +class NeuronCodePredictor(nn.Module): + """One NEFF that re-runs the full causal self-attention over a 
+ MAX_SEQ_LEN=16 input buffer each step. + + Inputs (all fixed shape): + inputs_embeds: [1, MAX_SEQ_LEN, HIDDEN] + position_ids: [1, MAX_SEQ_LEN] + mask_1d: [1, MAX_SEQ_LEN] — 1 for valid, 0 for masked + + Internally builds a [1,1,MAX,MAX] causal mask that also masks out the + invalid suffix (set to -inf where mask_1d==0). + + Output: + hidden: [1, MAX_SEQ_LEN, HIDDEN] (pre-lm_head; caller picks + the last-valid position) + """ + + def __init__(self, tp_degree: int, dtype=torch.bfloat16): + super().__init__() + self.tp_degree = tp_degree + self.dtype = dtype + self.layers = nn.ModuleList([CPLayer(tp_degree, dtype=dtype) for _ in range(NUM_LAYERS)]) + self.norm = RMSNorm(HIDDEN_SIZE) + + cos, sin = _compute_rope(HEAD_DIM, MAX_SEQ_LEN) + self.register_buffer("cos_cache", cos.to(dtype), persistent=False) + self.register_buffer("sin_cache", sin.to(dtype), persistent=False) + + def forward(self, inputs_embeds, position_ids, mask_1d): + # Build causal + validity mask [1, 1, MAX, MAX] + # Use a "small" negative value (not finfo.min) to avoid -inf overflow + # when causal and key masks overlap, which on bf16/Neuron can produce NaN. + MASK_VAL = -1e4 + S = MAX_SEQ_LEN + base_dtype = inputs_embeds.dtype + causal = torch.triu( + torch.full((S, S), MASK_VAL, dtype=base_dtype), + diagonal=1, + ).view(1, 1, S, S) + # mask out invalid keys (k masked columns) + key_mask = (mask_1d == 0).view(1, 1, 1, S).to(base_dtype) * MASK_VAL + attn_mask = causal + key_mask # broadcasted, min value 2*MASK_VAL is fine + + cos = self.cos_cache[position_ids[0]] + sin = self.sin_cache[position_ids[0]] + + x = inputs_embeds + for layer in self.layers: + x = layer(x, cos, sin, attn_mask) + return self.norm(x) + + +# --------------------------------------------------------------------------- +# UnifiedNeuronCodePredictor: runs all 15 steps in one NEFF call. 
+# --------------------------------------------------------------------------- + +class UnifiedNeuronCodePredictor(nn.Module): + """Single-NEFF unrolled code predictor. + + Inputs: + past_hidden: [1, 1, 1024] (talker's last hidden) + last_id_hidden: [1, 1, 1024] (embedding of last predicted talker token) + + Outputs: + codes: [1, 15] (int32 residual codes) + mid_hiddens: [1, 14, 1024] (hidden states from decode steps 1..14) + + Internally: + - Builds a 16-token buffer [past_hidden, last_id_hidden, z, z, ..., z] + - Runs 15 unrolled rounds. Round 0 predicts code[0] from the last + valid position. Rounds 1..14 embed the previous code via + codec_embedding[gs-1], put it at slot (1+gs), extend mask, rerun the + 5-layer attention, and apply lm_head[gs]. + - This uses the same "rerun full attention" pattern as the non-unified + predictor (no KV cache). 15 × full 16-pos attention is still small + since each pass is only 5 layers × 16 positions × 1024-dim. + + The codec_embedding and lm_head tensors are replicated on every rank + (not TP-sharded) because they are small (15 × 2048 × 1024 ≈ 31.5M + elements each, about 63 MB in bf16) and used inside a tight loop. 
+ """ + + def __init__(self, tp_degree: int, dtype=torch.bfloat16): + super().__init__() + self.tp_degree = tp_degree + self.dtype = dtype + self.layers = nn.ModuleList([CPLayer(tp_degree, dtype=dtype) for _ in range(NUM_LAYERS)]) + self.norm = RMSNorm(HIDDEN_SIZE) + + cos, sin = _compute_rope(HEAD_DIM, MAX_SEQ_LEN) + self.register_buffer("cos_cache", cos.to(dtype), persistent=False) + self.register_buffer("sin_cache", sin.to(dtype), persistent=False) + + # Stacked codec_embedding tables: [15, VOCAB, HIDDEN] + self.codec_embedding_stacked = nn.Parameter( + torch.zeros(NUM_EMBED_TABLES, VOCAB_SIZE, HIDDEN_SIZE, dtype=dtype), + requires_grad=False, + ) + # Stacked lm_head weights: [15, VOCAB, HIDDEN] + self.lm_head_stacked = nn.Parameter( + torch.zeros(NUM_LM_HEADS, VOCAB_SIZE, HIDDEN_SIZE, dtype=dtype), + requires_grad=False, + ) + + def _attention_mask(self, mask_1d, base_dtype): + """Build [1, 1, MAX, MAX] causal + validity mask.""" + MASK_VAL = -1e4 + S = MAX_SEQ_LEN + causal = torch.triu( + torch.full((S, S), MASK_VAL, dtype=base_dtype), + diagonal=1, + ).view(1, 1, S, S) + key_mask = (mask_1d == 0).view(1, 1, 1, S).to(base_dtype) * MASK_VAL + return causal + key_mask + + def _run_layers(self, buf, attn_mask, cos, sin): + x = buf + for layer in self.layers: + x = layer(x, cos, sin, attn_mask) + return self.norm(x) + + def forward(self, past_hidden, last_id_hidden): + """Unroll 15 code-prediction steps in one NEFF call. + + past_hidden / last_id_hidden: [1, 1, HIDDEN], dtype=bfloat16. 
+ """ + dtype = past_hidden.dtype + device = past_hidden.device + B = past_hidden.shape[0] # always 1 for our pipeline + + # Fixed-shape 16-position buffer + buf = torch.zeros(B, MAX_SEQ_LEN, HIDDEN_SIZE, dtype=dtype, device=device) + # Slot 0 = past_hidden, slot 1 = last_id_hidden + buf = buf.clone() # ensure writable in trace + buf[:, 0:1, :] = past_hidden + buf[:, 1:2, :] = last_id_hidden + + # Validity mask: both prefill positions are valid + mask_1d = torch.zeros(B, MAX_SEQ_LEN, dtype=torch.int32, device=device) + mask_1d[:, 0] = 1 + mask_1d[:, 1] = 1 + + position_ids = torch.arange(MAX_SEQ_LEN, dtype=torch.int32, device=device).unsqueeze(0) + cos = self.cos_cache[position_ids[0]] # [MAX, HEAD_DIM] + sin = self.sin_cache[position_ids[0]] + + # Round 0: prefill, predict code[0] with lm_head[0] from position 1 + attn_mask = self._attention_mask(mask_1d, dtype) + hidden = self._run_layers(buf, attn_mask, cos, sin) # [B, MAX, H] + last_valid_hidden = hidden[:, 1:2, :] # [B, 1, H] + logits0 = last_valid_hidden @ self.lm_head_stacked[0].transpose(-1, -2).to(dtype) + # argmax doesn't have bf16 → int on Neuron; use float() + code = logits0.float().argmax(dim=-1) # [B, 1], int64 + code = code.to(torch.int32) + + codes_list = [code] + mid_hiddens_list = [] + + # Rounds 1..14: decode (produces codes 1..14, plus mid_hiddens 1..14) + # Total codes = 1 (prefill) + 14 (decode) = 15. slot goes 2..15. 
+ for gs in range(1, NUM_EMBED_TABLES): + # Embed the previously-predicted code using codec_embedding[gs-1] + # code has shape [B, 1]; F.embedding maps it to [B, 1, HIDDEN] + emb_table = self.codec_embedding_stacked[gs - 1] # [VOCAB, HIDDEN] + new_embed = F.embedding(code, emb_table) # [B, 1, HIDDEN] + + # Write into buf at position (1 + gs) + slot = 1 + gs + # Explicit clone to keep NEFF happy about mutable buffer writes + buf = buf.clone() + buf[:, slot:slot + 1, :] = new_embed + + # Extend validity mask + mask_1d = mask_1d.clone() + mask_1d[:, slot] = 1 + + attn_mask = self._attention_mask(mask_1d, dtype) + hidden = self._run_layers(buf, attn_mask, cos, sin) + # Take the newly-computed position's hidden as "mid residual" + this_hidden = hidden[:, slot:slot + 1, :] # [B, 1, H] + mid_hiddens_list.append(this_hidden) + + # Predict next code via lm_head[gs] + head_w = self.lm_head_stacked[gs].transpose(-1, -2).to(dtype) + next_logits = this_hidden @ head_w # [B, 1, V] + code = next_logits.float().argmax(dim=-1).to(torch.int32) # [B, 1] + codes_list.append(code) + + # Stack outputs + codes_out = torch.cat(codes_list, dim=1) # [B, 15] + mid_hiddens_out = torch.cat(mid_hiddens_list, dim=1) # [B, 14, H] + return codes_out, mid_hiddens_out diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni_moe.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni_moe.py new file mode 100644 index 00000000..ea6887db --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni_moe.py @@ -0,0 +1,775 @@ +# coding=utf-8 +# Copyright 2025 The Qwen team, Alibaba Group and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Qwen3-Omni-MoE text model (Thinker) for NxD Inference. + +Supports both text-only CausalLM mode and multimodal (vision+audio+text) mode. + +The thinker text model is architecturally identical to Qwen3-MoE with mRoPE +(multimodal rotary position embeddings) for 3D position encoding (time, height, width). +This implementation reuses the NxD MoE modules and Qwen3-VL's mRoPE. + +Key features: + 1. Config navigation: thinker_config -> text_config + 2. State dict prefix stripping: "thinker.model." / "thinker.lm_head." + 3. mRoPE with interleaved layout (same as Qwen3-VL) + 4. Vision embedding scatter for multimodal fusion (deepstack) +""" +import gc +import json +import logging +import os +import warnings +from typing import Dict, List, Optional, Tuple, Type, Union + +import math +import torch +from torch import nn + +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ColumnParallelLinear, ParallelEmbedding +from neuronx_distributed.utils import cpu_mode + +from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeRMSNorm + +from neuronx_distributed_inference.models.config import ( + InferenceConfig, + MoENeuronConfig, + SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP, + MOE_TKG_MK_INTERMEDIATE_PER_TP, +) +from neuronx_distributed_inference.models.model_base import NeuronBaseForCausalLM, NeuronBaseModel +from neuronx_distributed_inference.models.image_to_text_model_wrapper import ImageToTextModelWrapper +from neuronx_distributed_inference.models.model_wrapper import ( + 
CONTEXT_ENCODING_MODEL_TAG, + TOKEN_GENERATION_MODEL_TAG, +) +from neuronx_distributed_inference.modules.generation.sampling import prepare_sampling_params +from neuronx_distributed_inference.models.layer_boundary_marker import ( + ModuleMarkerEndWrapper, + ModuleMarkerStartWrapper, +) +from neuronx_distributed_inference.models.llama4.utils.encoder_utils import scatter_by_index_put +from neuronx_distributed_inference.modules.attention.attention_base import NeuronAttentionBase +from neuronx_distributed_inference.modules.attention.utils import RotaryEmbedding +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm +from neuronx_distributed_inference.modules.moe_v2 import initialize_moe_module +from neuronx_distributed.parallel_layers.mappings import ( + _reduce_scatter_along_dim, + gather_from_sequence_parallel_region, +) +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + +logger = logging.getLogger(__name__) + + +def get_rmsnorm_cls(): + return Qwen3MoeRMSNorm if cpu_mode() else CustomRMSNorm + + +# --------------------------------------------------------------------------- +# State dict conversion helpers (same logic as qwen3_moe) +# --------------------------------------------------------------------------- + +def _strip_thinker_prefix(state_dict: dict) -> dict: + """ + Strip the thinker prefix from HF Qwen3-Omni state dict keys. 
+ + HF keys look like: + thinker.model.embed_tokens.weight + thinker.model.layers.0.self_attn.q_proj.weight + thinker.lm_head.weight + + We map them to: + embed_tokens.weight + layers.0.self_attn.q_proj.weight + lm_head.weight + """ + prefixes = ["model.thinker.model.", "thinker.model."] + lm_head_prefixes = ["model.thinker.lm_head.", "thinker.lm_head."] + + # detect prefix + model_prefix = "" + for p in prefixes: + if any(k.startswith(p) for k in state_dict): + model_prefix = p + break + + lm_head_prefix = "" + for p in lm_head_prefixes: + if any(k.startswith(p) for k in state_dict): + lm_head_prefix = p + break + + stripped = {} + for key, value in state_dict.items(): + if model_prefix and key.startswith(model_prefix): + new_key = key[len(model_prefix):] + stripped[new_key] = value + elif lm_head_prefix and key.startswith(lm_head_prefix): + new_key = "lm_head." + key[len(lm_head_prefix):] + stripped[new_key] = value + elif not model_prefix: + # no prefix detected — keys are already bare + stripped[key] = value + + logger.info( + "Stripped thinker prefix: %d HF keys -> %d thinker text keys (prefix='%s')", + len(state_dict), len(stripped), model_prefix, + ) + return stripped + + +def _helper_concat_and_delete_qkv(sd: dict, layer: int, attr: str): + sd[f"layers.{layer}.self_attn.Wqkv.{attr}"] = torch.cat([ + sd[f"layers.{layer}.self_attn.q_proj.{attr}"], + sd[f"layers.{layer}.self_attn.k_proj.{attr}"], + sd[f"layers.{layer}.self_attn.v_proj.{attr}"], + ]) + del sd[f"layers.{layer}.self_attn.q_proj.{attr}"] + del sd[f"layers.{layer}.self_attn.k_proj.{attr}"] + del sd[f"layers.{layer}.self_attn.v_proj.{attr}"] + + +def convert_state_dict_to_fused_qkv(sd: dict, cfg: InferenceConfig) -> dict: + mods_to_skip = getattr(cfg.neuron_config, "modules_to_not_convert", None) or [] + for l in range(cfg.num_hidden_layers): + _helper_concat_and_delete_qkv(sd, l, "weight") + if ( + cfg.neuron_config.quantized_mlp_kernel_enabled or cfg.neuron_config.quantized + ) and 
f"layers.{l}.self_attn" not in mods_to_skip: + _helper_concat_and_delete_qkv(sd, l, "scale") + gc.collect() + return sd + + +def convert_qwen3_omni_moe_hf_to_neuron_state_dict(state_dict: dict, config) -> dict: + """ + Convert HF Qwen3-Omni-MoE thinker text state dict to NxD MoE format. + + Steps: + 1. Strip thinker.model.* prefix + 2. Add rank_util tensors for TP + 3. Rename q_norm/k_norm -> q_layernorm/k_layernorm + 4. Rename router: mlp.gate -> mlp.router.linear_router + 5. Reorganize expert weights into stacked 3D tensors + 6. Optionally fuse QKV + """ + assert config.neuron_config.glu_mlp is True, "Only GLU MLP is supported" + + neuron_state_dict = _strip_thinker_prefix(state_dict) + + # rank utilities + tp = config.neuron_config.tp_degree + neuron_state_dict["rank_util.rank"] = torch.arange(0, tp, dtype=torch.int32) + + for l in range(config.num_hidden_layers): + neuron_state_dict[f"layers.{l}.self_attn.rank_util.rank"] = torch.arange(0, tp, dtype=torch.int32) + + # rename qk norm + for proj in ("q", "k"): + old = f"layers.{l}.self_attn.{proj}_norm.weight" + new = f"layers.{l}.self_attn.{proj}_layernorm.weight" + if old in neuron_state_dict: + neuron_state_dict[new] = neuron_state_dict.pop(old).detach().clone() + + # rename router + gate_key = f"layers.{l}.mlp.gate.weight" + if gate_key in neuron_state_dict: + neuron_state_dict[f"layers.{l}.mlp.router.linear_router.weight"] = ( + neuron_state_dict.pop(gate_key).detach().clone() + ) + + # reorganize expert weights + sample_key = f"layers.{l}.mlp.experts.0.gate_proj.weight" + if sample_key not in neuron_state_dict: + continue + intermediate_size, hidden_size = neuron_state_dict[sample_key].shape + device = neuron_state_dict[sample_key].device + dtype = neuron_state_dict[sample_key].dtype + + gate_up_proj = torch.empty(config.num_experts, hidden_size, 2 * intermediate_size, dtype=dtype, device=device) + down_proj = torch.empty(config.num_experts, intermediate_size, hidden_size, dtype=dtype, device=device) + + 
for e in range(config.num_experts): + gp = neuron_state_dict.pop(f"layers.{l}.mlp.experts.{e}.gate_proj.weight").T.detach().clone() + up = neuron_state_dict.pop(f"layers.{l}.mlp.experts.{e}.up_proj.weight").T.detach().clone() + dp = neuron_state_dict.pop(f"layers.{l}.mlp.experts.{e}.down_proj.weight").T.detach().clone() + + gate_up_proj[e, :, :intermediate_size] = gp + gate_up_proj[e, :, intermediate_size:] = up + down_proj[e] = dp + + pad_size = getattr(config, "moe_intermediate_pad_size", 0) + if pad_size > 0: + gate_up_proj = gate_up_proj.reshape(config.num_experts, hidden_size, 2, -1) + gate_up_proj = torch.nn.functional.pad(gate_up_proj, (0, pad_size)) + gate_up_proj = gate_up_proj.reshape(config.num_experts, hidden_size, -1) + down_proj = torch.nn.functional.pad(down_proj, (0, 0, 0, pad_size)) + + neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up_proj + neuron_state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.down_proj.weight"] = down_proj + gc.collect() + + if config.neuron_config.fused_qkv: + neuron_state_dict = convert_state_dict_to_fused_qkv(neuron_state_dict, config) + + return neuron_state_dict + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + +class Qwen3OmniMoeInferenceConfig(InferenceConfig): + """ + Inference config for Qwen3-Omni-MoE thinker text model. + + Navigates the nested Omni config (thinker_config.text_config) and sets up + MoE parameters identically to Qwen3MoeInferenceConfig. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.num_local_experts = self.num_experts + self.n_shared_experts = 0 + + self.maybe_pad_intermediate() + self.enable_moe_fused_nki_kernel() + + self.intermediate_size = self.moe_intermediate_size + + self.neuron_config.router_config.dtype = torch.float32 + self.neuron_config.router_config.act_fn = "softmax" + self.neuron_config.disable_numeric_cc_token = True + self.neuron_config.normalize_top_k_affinities = True + + def maybe_pad_intermediate(self): + moe_tp = self.neuron_config.moe_tp_degree + i_tp = self.moe_intermediate_size // moe_tp + if getattr(self.neuron_config.blockwise_matmul_config, "use_shard_on_intermediate_dynamic_while", False): + if i_tp % SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP != 0: + padded = math.ceil(i_tp / SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP) * SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP * moe_tp + self.moe_intermediate_pad_size = max(padded - self.moe_intermediate_size, 0) + self.moe_intermediate_size = padded + + def enable_moe_fused_nki_kernel(self): + i_tp = self.moe_intermediate_size // self.neuron_config.moe_tp_degree + if getattr(self.neuron_config, "moe_fused_nki_kernel_enabled", False) and i_tp % MOE_TKG_MK_INTERMEDIATE_PER_TP == 0: + self.moe_fused_nki_kernel_enabled = True + + def get_required_attributes(self) -> List[str]: + return [ + "head_dim", + "hidden_act", + "hidden_size", + "max_position_embeddings", + "moe_intermediate_size", + "norm_topk_prob", + "num_attention_heads", + "num_experts", + "num_experts_per_tok", + "num_hidden_layers", + "num_key_value_heads", + "rms_norm_eps", + "rope_scaling", + "rope_theta", + "vocab_size", + ] + + @classmethod + def get_neuron_config_cls(cls): + return MoENeuronConfig + + +def load_qwen3_omni_thinker_text_config(model_path: str): + """ + Return a load_config hook that extracts thinker.text_config from the + Qwen3-Omni config.json and applies it to the InferenceConfig. 
+ """ + def load_config(self: InferenceConfig): + config_path = os.path.join(model_path, "config.json") + with open(config_path) as f: + full = json.load(f) + + thinker = full.get("thinker_config", {}) + text = thinker.get("text_config", {}) + if not text: + raise ValueError( + f"Could not find thinker_config.text_config in {config_path}" + ) + + # torch_dtype handling + hf_dtype = text.pop("torch_dtype", text.pop("dtype", None)) + if hf_dtype and self.neuron_config and not self.neuron_config.overrides_torch_dtype: + from neuronx_distributed_inference.models.config import to_torch_dtype + if isinstance(hf_dtype, str): + hf_dtype = to_torch_dtype(hf_dtype) + self.neuron_config.torch_dtype = hf_dtype + + # Remove keys that conflict with InferenceConfig internals + for skip in ("model_type", "transformers_version", "architectures", "_name_or_path"): + text.pop(skip, None) + + self.__dict__.update(text) + + return load_config + + +# --------------------------------------------------------------------------- +# Model components +# --------------------------------------------------------------------------- + +class NeuronQwen3OmniMoERotaryEmbedding(nn.Module): + """mRoPE for Qwen3-Omni — identical to Qwen3-VL's interleaved layout.""" + inv_freq: torch.Tensor + + def __init__(self, config): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + self.config = config + + rope_scaling = getattr(config, "rope_scaling", None) or {} + self.rope_type = rope_scaling.get("rope_type", "default") + assert self.rope_type == "default", f"Only 'default' rope_type supported, got {self.rope_type}" + + base = config.rope_theta + dim = config.head_dim + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim) + ) + self.attention_scaling = 1.0 + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.mrope_section = rope_scaling.get("mrope_section", [24, 20, 
20]) + + def forward(self, x, position_ids): + if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + + inv_freq_expanded = ( + self.inv_freq[None, None, :, None].float() + .expand(3, position_ids.shape[1], -1, 1) + ) + position_ids_expanded = position_ids[:, :, None, :].float() + + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = NeuronQwen3OmniMoERotaryEmbedding.neuron_compute_freqs_mrope(freqs, self.mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + @staticmethod + def neuron_compute_freqs_mrope(freqs: torch.Tensor, mrope_section: list) -> torch.Tensor: + """XLA-friendly interleaved mRoPE frequency computation.""" + last_dim = freqs.shape[-1] + indices = torch.arange(last_dim, device=freqs.device, dtype=torch.int64) + freqs_t = freqs[0].clone() + for dim, offset in enumerate((1, 2), start=1): + length = mrope_section[dim] * 3 + mask = (indices % 3 == offset) & (indices < length) + freqs_t = torch.where(mask, freqs[dim], freqs_t) + return freqs_t + + +class NeuronQwen3OmniMoEAttention(NeuronAttentionBase): + def __init__(self, config: Qwen3OmniMoeInferenceConfig): + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling and rope_scaling.get("mrope_section"): + rotary_emb = NeuronQwen3OmniMoERotaryEmbedding(config) + else: + rotary_emb = RotaryEmbedding( + config.head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + super().__init__( + config=config, + hidden_size=config.hidden_size, + num_attention_heads=config.num_attention_heads, + num_key_value_heads=config.num_key_value_heads, + head_dim=config.head_dim, + rotary_emb=rotary_emb, + rms_norm_eps=config.rms_norm_eps, + use_qk_norm=False, + ) + self.q_layernorm = get_rmsnorm_cls()(self.head_dim, 
self.rms_norm_eps) + self.k_layernorm = get_rmsnorm_cls()(self.head_dim, self.rms_norm_eps) + + if not parallel_state.model_parallel_is_initialized(): + raise ValueError( + "NeuronQwen3OmniMoEAttention must be initialized in a distributed env." + ) + + +class NeuronQwen3OmniMoeDecoderLayer(nn.Module): + def __init__(self, config: Qwen3OmniMoeInferenceConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = NeuronQwen3OmniMoEAttention(config=config) + self.moe_fused_nki_kernel_enabled = getattr(config, "moe_fused_nki_kernel_enabled", False) + + self.input_layernorm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) + + if self.moe_fused_nki_kernel_enabled: + self.mlp = initialize_moe_module( + config=config, rmsnorm=self.post_attention_layernorm, init_tkg_module=True, + ) + else: + self.mlp = initialize_moe_module(config=config) + + self.qkv_kernel_enabled = config.neuron_config.qkv_kernel_enabled + self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled + self.qkv_kernel_fused_rmsnorm = not self.sequence_parallel_enabled + self.moe_mask_padded_tokens = config.neuron_config.moe_mask_padded_tokens + self.config = config + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + padding_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated. Use `attention_mask` instead." 
+ ) + + residual = hidden_states + + hidden_states = ModuleMarkerStartWrapper()(hidden_states) + qkv_fused_rmsnorm = None + if self.input_layernorm: + if self.qkv_kernel_enabled and self.qkv_kernel_fused_rmsnorm: + qkv_fused_rmsnorm = self.input_layernorm + else: + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + rmsnorm=qkv_fused_rmsnorm, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + if not self.moe_fused_nki_kernel_enabled: + hidden_states = self.post_attention_layernorm(hidden_states) + is_spec = ( + self.config.neuron_config.enable_fused_speculation + and not self.config.neuron_config.is_prefill_stage + ) + hidden_states = self.mlp(hidden_states, padding_mask, is_speculative_decoding=is_spec)[0] + hidden_states = residual + hidden_states + + hidden_states = ModuleMarkerEndWrapper()(hidden_states) + return (hidden_states, present_key_value, cos_cache, sin_cache, None) + + +class NeuronQwen3OmniMoeModel(NeuronBaseModel): + def setup_attr_for_model(self, config: Qwen3OmniMoeInferenceConfig): + self.on_device_sampling = config.neuron_config.on_device_sampling_config is not None + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + def init_model(self, config: Qwen3OmniMoeInferenceConfig): + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + ) + self.layers = 
nn.ModuleList([ + NeuronQwen3OmniMoeDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = get_rmsnorm_cls()(self.hidden_size, eps=config.rms_norm_eps) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=not self.on_device_sampling, + bias=False, + ) + + def encode_vision_to_input(self, inputs_embeds, vision_embeddings, vision_mask) -> torch.Tensor: + return scatter_by_index_put(inputs_embeds, vision_embeddings, vision_mask) + + def deepstack_process_xla( + self, + hidden_states: torch.Tensor, + visual_embeds: torch.Tensor, + vision_mask_positions: torch.Tensor, + ) -> torch.Tensor: + if self.sequence_parallel_enabled: + from neuronx_distributed_inference.utils.distributed import get_tp_group + hidden_states = gather_from_sequence_parallel_region( + hidden_states, + self.sequence_dimension, + process_group=get_tp_group(self.config), + ) + + assert hidden_states.shape == visual_embeds.shape, ( + f"Shape mismatch: hidden_states.shape={hidden_states.shape}, " + f"visual_embeds.shape={visual_embeds.shape}" + ) + + expanded_visual_embeds = torch.zeros_like(hidden_states) + expanded_visual_embeds = scatter_by_index_put( + expanded_visual_embeds, visual_embeds, vision_mask_positions + ) + hidden_states = hidden_states + expanded_visual_embeds + + if self.sequence_parallel_enabled: + from neuronx_distributed_inference.utils.distributed import get_tp_group + hidden_states = _reduce_scatter_along_dim( + hidden_states, + self.sequence_dimension, + xm.REDUCE_MAX, + process_group=get_tp_group(self.config), + ) + + return hidden_states + + +# --------------------------------------------------------------------------- +# Top-level CausalLM +# --------------------------------------------------------------------------- + +class NeuronQwen3OmniMoeForCausalLM(NeuronBaseForCausalLM): + """ + Causal LM wrapper for the Qwen3-Omni-MoE thinker text model on Neuron. 
+ + Usage: + from modeling_qwen3_omni_moe import ( + NeuronQwen3OmniMoeForCausalLM, + Qwen3OmniMoeInferenceConfig, + load_qwen3_omni_thinker_text_config, + ) + from neuronx_distributed_inference.models.config import MoENeuronConfig + + neuron_config = MoENeuronConfig(tp_degree=8, batch_size=1, seq_len=512, ...) + config = Qwen3OmniMoeInferenceConfig( + neuron_config, + load_config=load_qwen3_omni_thinker_text_config(model_path), + ) + model = NeuronQwen3OmniMoeForCausalLM(model_path, config) + model.compile(compiled_model_path) + model.load(compiled_model_path) + """ + + _model_cls = NeuronQwen3OmniMoeModel + + @staticmethod + def load_hf_model(model_path, **kwargs): + from transformers import AutoModelForCausalLM + return AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, **kwargs) + + @classmethod + def get_config_cls(cls): + return Qwen3OmniMoeInferenceConfig + + @staticmethod + def convert_hf_to_neuron_state_dict(state_dict: dict, config) -> dict: + return convert_qwen3_omni_moe_hf_to_neuron_state_dict(state_dict, config) + + def enable_context_encoding(self): + self.compile_tag = CONTEXT_ENCODING_MODEL_TAG + super().enable_context_encoding() + + def enable_token_generation(self): + self.compile_tag = TOKEN_GENERATION_MODEL_TAG + super().enable_token_generation() + + def get_compiler_args(self): + if self.compile_tag == CONTEXT_ENCODING_MODEL_TAG: + opt = "-O1" + elif self.compile_tag == TOKEN_GENERATION_MODEL_TAG: + opt = "-O3" if self.neuron_config.moe_ep_degree > 1 else "-O1" + else: + opt = "-O1" + + args = ( + f"--enable-saturate-infinity --enable-mixed-precision-accumulation " + f"--model-type transformer {opt}" + ) + args += " --tensorizer-options='--enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2'" + args += " --auto-cast=none" + args += " --internal-enable-dge-levels vector_dynamic_offsets" + args += " --internal-hlo2tensorizer-options='--verify-hlo=true'" + + if self.neuron_config.scratchpad_page_size: + args += f" 
--hbm-scratchpad-page-size={self.neuron_config.scratchpad_page_size} " + + if self.neuron_config.attn_block_tkg_nki_kernel_enabled: + assert self.neuron_config.attn_block_tkg_nki_kernel_cascaded_attention, ( + "attn_block_tkg_nki_kernel_enabled requires attn_block_tkg_nki_kernel_cascaded_attention" + ) + self.neuron_config.pre_rope_rmsnorm = True + args += " --internal-max-instruction-limit=15000000" + + return args + + +# --------------------------------------------------------------------------- +# Text model wrapper for multimodal (ImageToText) mode +# --------------------------------------------------------------------------- + +class NeuronQwen3OmniMoeTextModelWrapper(ImageToTextModelWrapper): + """Wraps the MoE text model for multimodal inference with mRoPE position IDs.""" + + _ROTARY_POSITION_IDS_INDEX = 21 + + def _forward_with_pad(self, *args): + """Fix rotary_position_ids after parent's incorrect dim-0 batch slice.""" + args = list(args) + rpi = args[self._ROTARY_POSITION_IDS_INDEX] + + if rpi.dim() == 3 and rpi.shape[0] != 3: + rpi = rpi[:1].expand(3, -1, -1) + + if rpi.dim() == 3 and rpi.shape[1] < self.neuron_config.batch_size: + pad_size = self.neuron_config.batch_size - rpi.shape[1] + padding = rpi[:, :1, :].expand(-1, pad_size, -1) + rpi = torch.cat([rpi, padding], dim=1) + + args[self._ROTARY_POSITION_IDS_INDEX] = rpi + return super()._forward_with_pad(*args) + + @staticmethod + def get_dummy_vision_inputs(config, input_ids, n_active_tokens, fill_value): + input_batch_size, input_sequence_len = input_ids.shape[0], input_ids.shape[-1] + if input_sequence_len > 1: # prefill + vision_embeddings = torch.zeros( + input_batch_size, + config.neuron_config.seq_len, + config.hidden_size, + dtype=config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + size=(input_batch_size, n_active_tokens, 1), + fill_value=fill_value, + dtype=torch.int32, + ) + deepstack_vision_embeds = [ + torch.zeros( + input_batch_size, + config.neuron_config.seq_len, + 
config.hidden_size, + dtype=config.neuron_config.torch_dtype, + ) + for _ in getattr(config, "deepstack_visual_indexes", []) + ] + if len(deepstack_vision_embeds) > 0: + deepstack_vision_embeds = torch.stack(deepstack_vision_embeds) + else: + deepstack_vision_embeds = torch.zeros((0), dtype=config.neuron_config.torch_dtype) + else: # decode + vision_embeddings = torch.zeros((0), dtype=config.neuron_config.torch_dtype) + vision_mask = torch.zeros((0), dtype=torch.bool) + deepstack_vision_embeds = torch.zeros((0), dtype=config.neuron_config.torch_dtype) + return vision_embeddings, vision_mask, deepstack_vision_embeds + + def input_generator(self): + inputs = [] + for bucket in self.neuron_config.buckets: + n_active_tokens = ( + bucket + if self.neuron_config.bucket_n_active_tokens + else self.neuron_config.n_active_tokens + ) + + input_ids = torch.zeros( + (self.neuron_config.batch_size, n_active_tokens), dtype=torch.int32 + ) + attention_mask = torch.zeros( + (self.neuron_config.batch_size, bucket), dtype=torch.int32 + ) + position_ids = torch.zeros( + (self.neuron_config.batch_size, n_active_tokens), dtype=torch.int32 + ) + seq_ids = torch.zeros((self.neuron_config.batch_size), dtype=torch.int32) + + sampling_params_len = prepare_sampling_params(1).shape[1] + sampling_params = torch.zeros( + (self.neuron_config.batch_size, sampling_params_len), dtype=torch.float32 + ) + + vision_embeddings, vision_mask, deepstack_vision_embeds = ( + self.get_dummy_vision_inputs( + config=self.config, + input_ids=input_ids, + n_active_tokens=n_active_tokens, + fill_value=0, + ) + ) + + rotary_position_ids = torch.zeros( + (3, self.neuron_config.batch_size, n_active_tokens), dtype=torch.int32 + ) + + if self.tag == CONTEXT_ENCODING_MODEL_TAG or self.tag == TOKEN_GENERATION_MODEL_TAG: + inputs.append( + ( + input_ids, # 0 + attention_mask, # 1 + position_ids, # 2 + seq_ids, # 3 + sampling_params, # 4 + torch.empty(0), # 5 prev_hidden + torch.empty(0), # 6 adapter_ids + torch.empty(0), 
# 7 accepted_indices + torch.empty(0), # 8 current_length + torch.empty(0), # 9 medusa_mask + torch.empty(0), # 10 scatter_index + torch.empty(0), # 11 slot_mapping + torch.empty(0), # 12 active_block_table + torch.empty(0), # 13 num_queries + torch.empty(0), # 14 computed_context_lens + torch.empty(0), # 15 tile_q_indices + torch.empty(0), # 16 tile_block_tables + torch.empty(0), # 17 tile_masks + torch.empty(0), # 18 inputs_embeds + torch.empty(0), # 19 kv_cache + torch.empty(0), # 20 active_mask + rotary_position_ids, # 21 + vision_embeddings, # 22 + vision_mask, # 23 + deepstack_vision_embeds, # 24 + ) + ) + else: + raise ValueError(f"Unsupported model tag '{self.tag}'") + + return inputs diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni_talker.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni_talker.py new file mode 100644 index 00000000..334dde3d --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni_talker.py @@ -0,0 +1,633 @@ +"""Qwen3-Omni Talker MoE transformer on Neuron. + +Talker architecture (config.talker_config.text_config): + - 20 layers, hidden=1024, heads=16, kv_heads=2 (GQA g=8), head_dim=128 + - 128 experts (moe_intermediate=384), top-6, norm_topk_prob=True + - shared_expert (intermediate=768) gated by sigmoid(shared_expert_gate(x)) + - q_norm / k_norm (per-head-dim RMSNorm) + - MRoPE theta=1e6, mrope_section=[24,20,20], interleaved + +The HF talker pipeline: + inputs_embeds ──► self.model (20 MoE layers) ──► codec_head ──► codec token + (this file traces this block) + +text_projection / hidden_projection / code_predictor stay on CPU and are +orchestrated by host Python. This file wraps the 20-layer MoE body plus +codec_head (as `lm_head`) into Neuron, following the thinker pattern. 
+ +The NxDI stock MoE module (`initialize_moe_module`) ties shared_expert size to +`config.intermediate_size`; Qwen3-Omni talker has different sizes (MoE=384, +shared=768), so we wrap the MoE block and add a separate SharedExpertSwiGLU. +""" +import copy +import logging +import math +import os +import warnings +from types import SimpleNamespace +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import torch.nn.functional as F +from torch import nn + +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + ParallelEmbedding, + RowParallelLinear, +) + +from neuronx_distributed_inference.models.config import ( + InferenceConfig, + MoENeuronConfig, + NeuronConfig, +) +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.models.model_wrapper import ( + CONTEXT_ENCODING_MODEL_TAG, + TOKEN_GENERATION_MODEL_TAG, +) +from neuronx_distributed_inference.models.image_to_text_model_wrapper import ( + ImageToTextModelWrapper, +) +from neuronx_distributed_inference.modules.generation.sampling import prepare_sampling_params +from neuronx_distributed_inference.modules.moe_v2 import initialize_moe_module + +# Reuse the thinker's Qwen3-VL attention (MRoPE, q/k-norm, GQA) — same exact +# attention architecture as the talker. +from neuronx_distributed_inference.models.qwen3_vl.modeling_qwen3_vl_text import ( + NeuronQwen3VLAttention, + get_rmsnorm_cls, +) + +logger = logging.getLogger("Neuron") + + +# ----------------------------------------------------------------------------- +# Shared expert (not provided by initialize_moe_module when intermediate sizes +# differ between routed and shared experts). 
+# ----------------------------------------------------------------------------- + +class SharedExpertSwiGLU(nn.Module): + """Shared SwiGLU MLP with a sigmoid gate — matches HF Qwen3-Omni talker.""" + + def __init__(self, config: InferenceConfig): + super().__init__() + hidden = config.hidden_size + inter = config.shared_expert_intermediate_size + dtype = config.neuron_config.torch_dtype + self.gate_proj = ColumnParallelLinear( + hidden, inter, bias=False, gather_output=False, dtype=dtype, + ) + self.up_proj = ColumnParallelLinear( + hidden, inter, bias=False, gather_output=False, dtype=dtype, + ) + self.down_proj = RowParallelLinear( + inter, hidden, bias=False, input_is_parallel=True, dtype=dtype, + ) + # Output is a single sigmoid gate per token; output_size=1 isn't + # divisible by TP so we keep this one replicated (plain nn.Linear). + self.gate = nn.Linear(hidden, 1, bias=False, dtype=dtype) + + def forward(self, x): + y = self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + g = torch.sigmoid(self.gate(x).to(y.dtype)) + return g * y + + +# ----------------------------------------------------------------------------- +# Talker decoder layer: reuse thinker pattern, add shared_expert on top of +# routed MoE. 
+# ----------------------------------------------------------------------------- + +class NeuronTalkerDecoderLayer(nn.Module): + def __init__(self, config: InferenceConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = NeuronQwen3VLAttention(config) + + rmsnorm_cls = get_rmsnorm_cls() + self.input_layernorm = rmsnorm_cls(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = rmsnorm_cls(config.hidden_size, eps=config.rms_norm_eps) + + # Routed experts via NxDI (driven off config.intermediate_size=MOE_INTER) + self.mlp = initialize_moe_module(config=config) + # Separate shared expert (different intermediate size) + self.shared_expert = SharedExpertSwiGLU(config) + + self.qkv_kernel_enabled = config.neuron_config.qkv_kernel_enabled + self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled + self.qkv_kernel_fused_rmsnorm = not self.sequence_parallel_enabled + self.config = config + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + rotary_position_ids: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, ...]: + residual = hidden_states + + qkv_fused_rmsnorm = None + if self.input_layernorm is not None: + if self.qkv_kernel_enabled and self.qkv_kernel_fused_rmsnorm: + qkv_fused_rmsnorm = self.input_layernorm + else: + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + rotary_position_ids=rotary_position_ids, + rmsnorm=qkv_fused_rmsnorm, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + norm_out = 
self.post_attention_layernorm(hidden_states) + is_speculative_decoding = ( + self.config.neuron_config.enable_fused_speculation + and not self.config.neuron_config.is_prefill_stage + ) + routed = self.mlp(norm_out, padding_mask, is_speculative_decoding=is_speculative_decoding)[0] + shared = self.shared_expert(norm_out) + hidden_states = residual + routed + shared + + return (hidden_states, present_key_value, cos_cache, sin_cache, None) + + +# ----------------------------------------------------------------------------- +# Talker transformer (NeuronBaseModel) +# ----------------------------------------------------------------------------- + +class NeuronTalkerModel(NeuronBaseModel): + """Talker MoE transformer on Neuron. + + Input: inputs_embeds [B, S, H] — the host passes the already-computed + sum of text/code embeddings (HF's prepare_inputs_for_generation output). + We expose this via the `vision_embeddings` slot of ImageToTextModelWrapper + so existing input/output plumbing works without re-tracing the framework. + """ + + def setup_attr_for_model(self, config): + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + def init_model(self, config): + self.padding_idx = getattr(config, "pad_token_id", 0) or 0 + self.vocab_size = config.vocab_size + + if parallel_state.model_parallel_is_initialized(): + # embed_tokens is used for lookup during token-generation autoregressive + # stepping (when host passes input_ids instead of inputs_embeds). 
+ self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + pad=True, + ) + # codec_head on Neuron (avoids a big CPU matmul every step) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=not self.on_device_sampling, + bias=False, + pad=True, + ) + else: + self.embed_tokens = nn.Embedding( + self.vocab_size, self.hidden_size, self.padding_idx, + ) + self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False) + + self.layers = nn.ModuleList( + [NeuronTalkerDecoderLayer(config, i) for i in range(config.num_hidden_layers)] + ) + self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) + + def get_model_output( + self, + input_ids=None, + inputs_embeds=None, + vision_embeddings=None, + vision_mask=None, + is_for_context_encoding: bool = False, + adapter_ids=None, + **kwargs, + ): + # We apply vision-embed injection ourselves (both prefill and decode) + # and pass through to super() with vision_*=None so the upstream does + # not try to re-inject (its shape expectations don't match ours). 
+ if ( + vision_embeddings is not None + and vision_mask is not None + and vision_embeddings.numel() > 0 + ): + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if vision_embeddings.dtype != self.config.neuron_config.torch_dtype: + vision_embeddings = vision_embeddings.to(self.config.neuron_config.torch_dtype) + inputs_embeds = self.encode_vision_to_input( + inputs_embeds, vision_embeddings, vision_mask + ) + vision_embeddings = None + vision_mask = None + + return super().get_model_output( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + vision_embeddings=vision_embeddings, + vision_mask=vision_mask, + is_for_context_encoding=is_for_context_encoding, + adapter_ids=adapter_ids, + **kwargs, + ) + + def encode_vision_to_input(self, inputs_embeds, vision_embeddings, vision_mask): + """Replace inputs_embeds with the host-provided embedding wherever the + mask says valid. Unlike thinker's scatter-by-index, we accept that + host already built the full per-position embedding, and just write + wherever the mask says to. + + Talker's vision_embeddings comes in full seq_len size (4096) while + inputs_embeds matches the bucket size (say 64). We need to crop to + the bucket's active region. + """ + if vision_embeddings is None or vision_embeddings.numel() == 0: + return inputs_embeds + S_in = inputs_embeds.shape[1] + # Accept either bucket-sized or seq_len-sized vision_embeddings; + # always take the leading S_in positions. 
+ if vision_embeddings.shape[1] >= S_in: + ve = vision_embeddings[:, :S_in, :] + else: + pad = torch.zeros( + inputs_embeds.shape[0], S_in - vision_embeddings.shape[1], + inputs_embeds.shape[2], dtype=inputs_embeds.dtype, + device=inputs_embeds.device, + ) + ve = torch.cat([vision_embeddings, pad], dim=1) + if vision_mask is None or vision_mask.numel() == 0: + return ve # full replacement + # vision_mask shape: [B, n_active_tokens, 1], same sizing convention + vm = vision_mask + if vm.shape[1] >= S_in: + vm = vm[:, :S_in, :] + mask_bool = vm.bool() + return torch.where(mask_bool, ve, inputs_embeds) + + +# ----------------------------------------------------------------------------- +# Model wrapper: provides dummy vision/deepstack tensors during tracing +# ----------------------------------------------------------------------------- + +class NeuronTalkerModelWrapper(ImageToTextModelWrapper): + """Input generator that emits dummy vision_embeddings at token-generation + time too (so the traced NEFF bakes in the ADD/REPLACE path). 
+ """ + + _ROTARY_POSITION_IDS_INDEX = 21 + + @staticmethod + def get_dummy_vision_inputs(config, input_ids, n_active_tokens, fill_value): + B, S = input_ids.shape + if S > 1: + vision_embeddings = torch.zeros( + B, config.neuron_config.seq_len, config.hidden_size, + dtype=config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + size=(B, n_active_tokens, 1), fill_value=fill_value, dtype=torch.int32, + ) + else: + vision_embeddings = torch.zeros( + B, 1, config.hidden_size, + dtype=config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + size=(B, 1, 1), fill_value=fill_value, dtype=torch.int32, + ) + # Talker has no deepstack — pass empty tensor + deepstack_vision_embeds = torch.zeros( + (0), dtype=config.neuron_config.torch_dtype + ) + return vision_embeddings, vision_mask, deepstack_vision_embeds + + def input_generator(self): + inputs = [] + for bucket in self.neuron_config.buckets: + n_active_tokens = ( + bucket if self.neuron_config.bucket_n_active_tokens + else self.neuron_config.n_active_tokens + ) + input_ids = torch.zeros((self.neuron_config.batch_size, n_active_tokens), dtype=torch.int32) + attention_mask = torch.zeros((self.neuron_config.batch_size, bucket), dtype=torch.int32) + position_ids = torch.zeros((self.neuron_config.batch_size, n_active_tokens), dtype=torch.int32) + seq_ids = torch.zeros((self.neuron_config.batch_size), dtype=torch.int32) + sampling_params_len = prepare_sampling_params(1).shape[1] + sampling_params = torch.zeros( + (self.neuron_config.batch_size, sampling_params_len), dtype=torch.float32 + ) + ve, vm, ds = self.get_dummy_vision_inputs( + config=self.config, input_ids=input_ids, + n_active_tokens=n_active_tokens, fill_value=0, + ) + rotary_position_ids = torch.zeros( + (3, self.neuron_config.batch_size, n_active_tokens), dtype=torch.int32, + ) + if self.tag in (CONTEXT_ENCODING_MODEL_TAG, TOKEN_GENERATION_MODEL_TAG): + inputs.append(( + input_ids, attention_mask, position_ids, seq_ids, sampling_params, + 
torch.empty(0), # prev_hidden + torch.empty(0), # adapter_ids + torch.empty(0), # accepted_indices + torch.empty(0), # current_length + torch.empty(0), # medusa_mask + torch.empty(0), # scatter_index + torch.empty(0), # slot_mapping + torch.empty(0), # active_block_table + torch.empty(0), # num_queries + torch.empty(0), # computed_context_lens + torch.empty(0), # tile_q_indices + torch.empty(0), # tile_block_tables + torch.empty(0), # tile_masks + torch.empty(0), # inputs_embeds + torch.empty(0), # kv_cache + torch.empty(0), # active_mask + rotary_position_ids, # 21 + ve, # 22 + vm, # 23 + ds, # 24 + )) + else: + raise ValueError(f"Unsupported tag: {self.tag}") + return inputs + + +# ----------------------------------------------------------------------------- +# InferenceConfig +# ----------------------------------------------------------------------------- + +class TalkerInferenceConfig(InferenceConfig): + """Config wrapping the talker.text_config + MoE-specific Neuron fields.""" + + @classmethod + def get_neuron_config_cls(cls) -> Type[NeuronConfig]: + return MoENeuronConfig + + def get_required_attributes(self) -> List[str]: + return [ + "hidden_size", "num_attention_heads", "num_hidden_layers", + "num_key_value_heads", "vocab_size", "rms_norm_eps", "rope_theta", + "moe_intermediate_size", "num_experts", "num_experts_per_tok", + "shared_expert_intermediate_size", + ] + + def add_derived_config(self): + self.num_cores_per_group = 1 + # Talker attention defaults + self.attention_bias = False + self.qkv_bias = False + self.o_bias = False + # Explicit head_dim (already 128 in HF config) + if not hasattr(self, "head_dim") or self.head_dim is None: + self.head_dim = 128 + # MRoPE section (from rope_scaling) + rs = getattr(self, "rope_scaling", None) or {} + self.mrope_section = rs.get("mrope_section", [24, 20, 20]) + + # MoE adapters for initialize_moe_module: intermediate_size must be + # the MoE expert intermediate (NOT shared expert). 
+ self.intermediate_size = self.moe_intermediate_size + # num_local_experts alias + if not hasattr(self, "num_local_experts"): + self.num_local_experts = self.num_experts + # No shared experts via initialize_moe_module — we handle those + # ourselves in SharedExpertSwiGLU because of different intermediate + # size from the routed experts. + self.n_shared_experts = 0 + # GLU MLP required for experts + self.neuron_config.glu_mlp = True + # Router config + self.neuron_config.router_config.dtype = torch.float32 + self.neuron_config.router_config.act_fn = "softmax" + self.neuron_config.disable_numeric_cc_token = True + if getattr(self, "norm_topk_prob", True): + self.neuron_config.normalize_top_k_affinities = True + + +# ----------------------------------------------------------------------------- +# Application +# ----------------------------------------------------------------------------- + +class NeuronTalkerForCausalLM(NeuronBaseForCausalLM): + """Autoregressive talker on Neuron. + + Usage is similar to thinker: compile() then load(); then use adapter.generate + from host Python, with the host filling in `inputs_embeds` (via + vision_embeddings) for each prefill/decode call. + """ + + _model_cls = NeuronTalkerModel + + @classmethod + def get_config_cls(cls): + return TalkerInferenceConfig + + def get_model_wrapper_cls(self): + return NeuronTalkerModelWrapper + + def get_required_kwargs(self) -> List[str]: + return ["vision_embeddings", "vision_mask"] + + def get_compiler_args(self) -> str: + cc = self.neuron_config.cc_pipeline_tiling_factor + return ( + f"--auto-cast=none --model-type=transformer " + f"--tensorizer-options='--enable-ccop-compute-overlap " + f"--cc-pipeline-tiling-factor={cc}' -O1 " + f"--internal-max-instruction-limit=15000000" + ) + + @staticmethod + def update_state_dict_for_tied_weights(state_dict): + # Talker uses a separate codec_head (maps hidden -> codec vocab), + # never tied to embed_tokens. 
+ pass + + @staticmethod + def load_hf_model(model_path, **kwargs): + # We never instantiate the full HF model to get state — we read + # safetensors directly via get_state_dict override. + raise NotImplementedError("Use get_state_dict / checkpoint_loader_fn") + + @classmethod + def get_state_dict(cls, model_name_or_path: str, config) -> dict: + """Read only talker.* tensors from the HF safetensors shards and + run our conversion. Avoids loading the full 30B model. + """ + import json as _json + from safetensors.torch import safe_open + + index_path = os.path.join(model_name_or_path, "model.safetensors.index.json") + talker_raw = {} + if os.path.exists(index_path): + with open(index_path) as f: + weight_map = _json.load(f)["weight_map"] + wanted_shards = { + weight_map[k] for k in weight_map + if k.startswith("talker.") and not k.startswith("talker.code_predictor.") + } + for shard in sorted(wanted_shards): + with safe_open(os.path.join(model_name_or_path, shard), framework="pt") as sf: + for k in sf.keys(): + if k.startswith("talker.") and not k.startswith("talker.code_predictor."): + talker_raw[k[len("talker."):]] = sf.get_tensor(k) + else: + for fname in sorted(os.listdir(model_name_or_path)): + if fname.endswith(".safetensors"): + with safe_open(os.path.join(model_name_or_path, fname), framework="pt") as sf: + for k in sf.keys(): + if k.startswith("talker.") and not k.startswith("talker.code_predictor."): + talker_raw[k[len("talker."):]] = sf.get_tensor(k) + + return convert_talker_hf_to_neuron(talker_raw, config) + + +# ----------------------------------------------------------------------------- +# Weight conversion: HF 'talker.*' → Neuron keys +# ----------------------------------------------------------------------------- + +def convert_talker_hf_to_neuron(hf_sd: Dict[str, torch.Tensor], config: TalkerInferenceConfig) -> Dict: + """Convert HF talker state dict to the key layout our module expects. 
+ + HF keys (present in root state_dict): + model.codec_embedding.weight (falls back to model.embed_tokens.weight if absent) + model.layers.{l}.input_layernorm.weight + model.layers.{l}.post_attention_layernorm.weight + model.layers.{l}.self_attn.{q,k,v,o}_proj.weight + model.layers.{l}.self_attn.{q,k}_norm.weight + model.layers.{l}.mlp.gate.weight + model.layers.{l}.mlp.experts.{e}.{gate,up,down}_proj.weight + model.layers.{l}.mlp.shared_expert.{gate,up,down}_proj.weight + model.layers.{l}.mlp.shared_expert_gate.weight + model.norm.weight + codec_head.weight + (code_predictor and projections are also present but dropped here; they stay on CPU) + """ + import gc + num_experts = config.num_experts + moe_inter = config.moe_intermediate_size + hidden = config.hidden_size + tp = config.neuron_config.tp_degree + + out: Dict[str, torch.Tensor] = {} + + # Embedding table: the talker's codec token embedding, model.codec_embedding, + # which is what HF's talker returns from `get_input_embeddings()` (the + # talker's own codec vocab). model.embed_tokens is accepted as a fallback. + # Keep the key neutral: "embed_tokens.weight". 
+ if "model.codec_embedding.weight" in hf_sd: + out["embed_tokens.weight"] = hf_sd["model.codec_embedding.weight"].to( + config.neuron_config.torch_dtype + ).contiguous() + elif "model.embed_tokens.weight" in hf_sd: + out["embed_tokens.weight"] = hf_sd["model.embed_tokens.weight"].to( + config.neuron_config.torch_dtype + ).contiguous() + + # codec_head → lm_head + if "codec_head.weight" in hf_sd: + out["lm_head.weight"] = hf_sd["codec_head.weight"].to( + config.neuron_config.torch_dtype + ).contiguous() + + # Final norm + if "model.norm.weight" in hf_sd: + out["norm.weight"] = hf_sd["model.norm.weight"].to( + config.neuron_config.torch_dtype + ).contiguous() + + for l in range(config.num_hidden_layers): + base = f"model.layers.{l}" + tgt = f"layers.{l}" + + # Norms + out[f"{tgt}.input_layernorm.weight"] = hf_sd[f"{base}.input_layernorm.weight"].to( + config.neuron_config.torch_dtype + ).contiguous() + out[f"{tgt}.post_attention_layernorm.weight"] = hf_sd[f"{base}.post_attention_layernorm.weight"].to( + config.neuron_config.torch_dtype + ).contiguous() + + # Attention: NeuronQwen3VLAttention expects qkv_proj.{q,k,v}_proj + o_proj.o_proj + # format — match the thinker convert logic. 
+ out[f"{tgt}.self_attn.qkv_proj.q_proj.weight"] = hf_sd[f"{base}.self_attn.q_proj.weight"].to( + config.neuron_config.torch_dtype + ).contiguous() + out[f"{tgt}.self_attn.qkv_proj.k_proj.weight"] = hf_sd[f"{base}.self_attn.k_proj.weight"].to( + config.neuron_config.torch_dtype + ).contiguous() + out[f"{tgt}.self_attn.qkv_proj.v_proj.weight"] = hf_sd[f"{base}.self_attn.v_proj.weight"].to( + config.neuron_config.torch_dtype + ).contiguous() + out[f"{tgt}.self_attn.o_proj.o_proj.weight"] = hf_sd[f"{base}.self_attn.o_proj.weight"].to( + config.neuron_config.torch_dtype + ).contiguous() + # q_norm / k_norm map to q_layernorm / k_layernorm in thinker naming + out[f"{tgt}.self_attn.q_layernorm.weight"] = hf_sd[f"{base}.self_attn.q_norm.weight"].to( + config.neuron_config.torch_dtype + ).contiguous() + out[f"{tgt}.self_attn.k_layernorm.weight"] = hf_sd[f"{base}.self_attn.k_norm.weight"].to( + config.neuron_config.torch_dtype + ).contiguous() + out[f"{tgt}.self_attn.rank_util.rank"] = torch.arange(0, tp, dtype=torch.int32) + + # Routed MoE: gate → router.linear_router; stack experts into gate_up_proj/down_proj + out[f"{tgt}.mlp.router.linear_router.weight"] = hf_sd[f"{base}.mlp.gate.weight"].to( + config.neuron_config.torch_dtype + ).contiguous() + + dtype = config.neuron_config.torch_dtype + gate_up = torch.empty(num_experts, hidden, 2 * moe_inter, dtype=dtype) + down = torch.empty(num_experts, moe_inter, hidden, dtype=dtype) + for e in range(num_experts): + gw = hf_sd[f"{base}.mlp.experts.{e}.gate_proj.weight"] + uw = hf_sd[f"{base}.mlp.experts.{e}.up_proj.weight"] + dw = hf_sd[f"{base}.mlp.experts.{e}.down_proj.weight"] + gate_up[e, :, :moe_inter].copy_(gw.T.to(dtype)) + gate_up[e, :, moe_inter:].copy_(uw.T.to(dtype)) + down[e].copy_(dw.T.to(dtype)) + out[f"{tgt}.mlp.expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up + out[f"{tgt}.mlp.expert_mlps.mlp_op.down_proj.weight"] = down + + # Shared expert (SharedExpertSwiGLU: gate_proj/up_proj/down_proj/gate) + 
out[f"{tgt}.shared_expert.gate_proj.weight"] = hf_sd[f"{base}.mlp.shared_expert.gate_proj.weight"].to(dtype).contiguous() + out[f"{tgt}.shared_expert.up_proj.weight"] = hf_sd[f"{base}.mlp.shared_expert.up_proj.weight"].to(dtype).contiguous() + out[f"{tgt}.shared_expert.down_proj.weight"] = hf_sd[f"{base}.mlp.shared_expert.down_proj.weight"].to(dtype).contiguous() + out[f"{tgt}.shared_expert.gate.weight"] = hf_sd[f"{base}.mlp.shared_expert_gate.weight"].to(dtype).contiguous() + + gc.collect() + + out["rank_util.rank"] = torch.arange(0, tp, dtype=torch.int32) + return out diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni_text.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni_text.py new file mode 100644 index 00000000..bd4a505e --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni_text.py @@ -0,0 +1,361 @@ +"""Qwen3-Omni MoE text model for NxD Inference. + +Combines Qwen3-VL's multimodal attention (MRoPE, deepstack, vision scatter) +with Qwen3-MoE's sparse mixture-of-experts FFN layers. 
+""" + +import gc +import math +import warnings +from typing import Dict, Any, List, Optional, Tuple + +import torch +from torch import nn + +from neuronx_distributed.parallel_layers import parallel_state +from neuronx_distributed.parallel_layers.layers import ColumnParallelLinear, ParallelEmbedding + +from neuronx_distributed_inference.models.config import ( + InferenceConfig, + MoENeuronConfig, + SHARD_ON_INTERMEDIATE_DIMENSION_PER_TP, + MOE_TKG_MK_INTERMEDIATE_PER_TP, +) +from neuronx_distributed_inference.models.image_to_text_model_base import NeuronBaseForImageToText +from neuronx_distributed_inference.models.image_to_text_model_wrapper import ImageToTextModelWrapper +from neuronx_distributed_inference.models.model_base import NeuronBaseForCausalLM, NeuronBaseModel +from neuronx_distributed_inference.models.model_wrapper import ( + CONTEXT_ENCODING_MODEL_TAG, + TOKEN_GENERATION_MODEL_TAG, +) +from neuronx_distributed_inference.models.layer_boundary_marker import ( + ModuleMarkerEndWrapper, + ModuleMarkerStartWrapper, +) +from neuronx_distributed_inference.modules.moe_v2 import initialize_moe_module +from neuronx_distributed_inference.modules.generation.sampling import prepare_sampling_params + +# Reuse Qwen3-VL components (identical attention + MRoPE + vision integration) +from neuronx_distributed_inference.models.qwen3_vl.modeling_qwen3_vl_text import ( + NeuronQwen3VLAttention, + NeuronQwen3VLRotaryEmbedding, + NeuronQwen3VLTextModel, + get_rmsnorm_cls, +) +from neuronx_distributed_inference.models.llama4.utils.encoder_utils import scatter_by_index_put + + +class NeuronQwen3OmniMoEDecoderLayer(nn.Module): + """Decoder layer: Qwen3-VL attention (MRoPE) + Qwen3-MoE sparse FFN.""" + + def __init__(self, config: InferenceConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = NeuronQwen3VLAttention(config) + self.moe_fused_nki_kernel_enabled = getattr(config, "moe_fused_nki_kernel_enabled", False) + + 
self.input_layernorm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) + + if self.moe_fused_nki_kernel_enabled: + self.mlp = initialize_moe_module( + config=config, rmsnorm=self.post_attention_layernorm, init_tkg_module=True + ) + else: + self.mlp = initialize_moe_module(config=config) + + self.qkv_kernel_enabled = config.neuron_config.qkv_kernel_enabled + self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled + self.qkv_kernel_fused_rmsnorm = not self.sequence_parallel_enabled + self.moe_mask_padded_tokens = config.neuron_config.moe_mask_padded_tokens + self.config = config + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + rotary_position_ids: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, ...]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated. Use `attention_mask` instead." 
+ ) + + residual = hidden_states + + qkv_fused_rmsnorm = None + hidden_states = ModuleMarkerStartWrapper()(hidden_states) + if self.input_layernorm: + if self.qkv_kernel_enabled and self.qkv_kernel_fused_rmsnorm: + qkv_fused_rmsnorm = self.input_layernorm + else: + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + rotary_position_ids=rotary_position_ids, + rmsnorm=qkv_fused_rmsnorm, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + if not self.moe_fused_nki_kernel_enabled: + hidden_states = self.post_attention_layernorm(hidden_states) + is_speculative_decoding = ( + self.config.neuron_config.enable_fused_speculation + and not self.config.neuron_config.is_prefill_stage + ) + hidden_states = self.mlp(hidden_states, padding_mask, is_speculative_decoding=is_speculative_decoding)[0] + hidden_states = residual + hidden_states + + hidden_states = ModuleMarkerEndWrapper()(hidden_states) + return (hidden_states, present_key_value, cos_cache, sin_cache, None) + + +class NeuronQwen3OmniTextModel(NeuronQwen3VLTextModel): + """MoE text model with deepstack and vision scatter from Qwen3-VL.""" + + def init_model(self, config: InferenceConfig): + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + if parallel_state.model_parallel_is_initialized(): + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + config.pad_token_id, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + pad=True, + ) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=not self.on_device_sampling, + bias=False, + pad=True, + ) + else: + self.embed_tokens = nn.Embedding( + self.vocab_size, self.hidden_size, self.padding_idx, + ) + self.lm_head 
= nn.Linear(self.hidden_size, self.vocab_size, bias=False) + + self.layers = nn.ModuleList( + [NeuronQwen3OmniMoEDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = get_rmsnorm_cls()(config.hidden_size, eps=config.rms_norm_eps) + + +class NeuronQwen3OmniTextModelWrapper(ImageToTextModelWrapper): + """Wrapper with MRoPE input generator and deepstack dummy inputs. + + Identical to NeuronQwen3VLTextModelWrapper. + """ + + _ROTARY_POSITION_IDS_INDEX = 21 + + def _forward_with_pad(self, *args): + args = list(args) + rpi = args[self._ROTARY_POSITION_IDS_INDEX] + if rpi.dim() == 3 and rpi.shape[0] != 3: + rpi = rpi[:1].expand(3, -1, -1) + if rpi.dim() == 3 and rpi.shape[1] < self.neuron_config.batch_size: + pad_size = self.neuron_config.batch_size - rpi.shape[1] + padding = rpi[:, :1, :].expand(-1, pad_size, -1) + rpi = torch.cat([rpi, padding], dim=1) + args[self._ROTARY_POSITION_IDS_INDEX] = rpi + return super()._forward_with_pad(*args) + + @staticmethod + def get_dummy_vision_inputs(config, input_ids, n_active_tokens, fill_value): + input_batch_size, input_sequence_len = input_ids.shape[0], input_ids.shape[-1] + if input_sequence_len > 1: + vision_embeddings = torch.zeros( + input_batch_size, config.neuron_config.seq_len, config.hidden_size, + dtype=config.neuron_config.torch_dtype, + ) + vision_mask = torch.full( + size=(input_batch_size, n_active_tokens, 1), + fill_value=fill_value, + dtype=torch.int32, + ) + deepstack_vision_embeds = [ + torch.zeros( + input_batch_size, config.neuron_config.seq_len, config.hidden_size, + dtype=config.neuron_config.torch_dtype, + ) + for _ in config.deepstack_visual_indexes + ] + if len(deepstack_vision_embeds) > 0: + deepstack_vision_embeds = torch.stack(deepstack_vision_embeds) + else: + deepstack_vision_embeds = torch.zeros((0), dtype=config.neuron_config.torch_dtype) + else: + vision_embeddings = torch.zeros((0), dtype=config.neuron_config.torch_dtype) + vision_mask = 
torch.zeros((0), dtype=torch.bool) + deepstack_vision_embeds = torch.zeros((0), dtype=config.neuron_config.torch_dtype) + return vision_embeddings, vision_mask, deepstack_vision_embeds + + def input_generator(self): + inputs = [] + for bucket in self.neuron_config.buckets: + n_active_tokens = ( + bucket if self.neuron_config.bucket_n_active_tokens + else self.neuron_config.n_active_tokens + ) + input_ids = torch.zeros((self.neuron_config.batch_size, n_active_tokens), dtype=torch.int32) + attention_mask = torch.zeros((self.neuron_config.batch_size, bucket), dtype=torch.int32) + position_ids = torch.zeros((self.neuron_config.batch_size, n_active_tokens), dtype=torch.int32) + seq_ids = torch.zeros((self.neuron_config.batch_size), dtype=torch.int32) + sampling_params_len = prepare_sampling_params(1).shape[1] + sampling_params = torch.zeros((self.neuron_config.batch_size, sampling_params_len), dtype=torch.float32) + vision_embeddings, vision_mask, deepstack_vision_embeds = self.get_dummy_vision_inputs( + config=self.config, input_ids=input_ids, + n_active_tokens=n_active_tokens, fill_value=0, + ) + rotary_position_ids = torch.zeros( + (3, self.neuron_config.batch_size, n_active_tokens), dtype=torch.int32 + ) + if self.tag in (CONTEXT_ENCODING_MODEL_TAG, TOKEN_GENERATION_MODEL_TAG): + inputs.append(( + input_ids, # 0 + attention_mask, # 1 + position_ids, # 2 + seq_ids, # 3 + sampling_params, # 4 + torch.empty(0), # 5 prev_hidden + torch.empty(0), # 6 adapter_ids + torch.empty(0), # 7 accepted_indices + torch.empty(0), # 8 current_length + torch.empty(0), # 9 medusa_mask + torch.empty(0), # 10 scatter_index + torch.empty(0), # 11 slot_mapping + torch.empty(0), # 12 active_block_table + torch.empty(0), # 13 num_queries + torch.empty(0), # 14 computed_context_lens + torch.empty(0), # 15 tile_q_indices + torch.empty(0), # 16 tile_block_tables + torch.empty(0), # 17 tile_masks + torch.empty(0), # 18 inputs_embeds + torch.empty(0), # 19 kv_cache + torch.empty(0), # 20 
active_mask + rotary_position_ids, # 21 + vision_embeddings, # 22 + vision_mask, # 23 + deepstack_vision_embeds, # 24 + )) + else: + raise ValueError(f"Unsupported model tag '{self.tag}'") + return inputs + + +def convert_qwen3_omni_text_hf_to_neuron(state_dict: dict, config: InferenceConfig) -> dict: + """Convert HF Qwen3-Omni thinker text weights to Neuron format. + + Handles both MRoPE attention key remapping (Qwen3-VL style) and + MoE expert weight stacking (Qwen3-MoE style). + """ + assert config.neuron_config.glu_mlp is True + + new_sd: Dict[str, Any] = {} + + # Step 1: Strip thinker prefix from text-model keys; preserve already-converted + # vision (blocks.*) and audio (frontend.*, transformer.*, postprocessor.*) keys. + for k, v in state_dict.items(): + if k.startswith("thinker.model."): + new_key = k[len("thinker.model."):] + new_sd[new_key] = v + elif k.startswith("thinker.lm_head."): + new_key = k[len("thinker."):] + new_sd[new_key] = v + elif k.startswith("thinker."): + # Drop any other thinker.* keys (e.g., thinker.audio_tower leftovers) + continue + else: + new_sd[k] = v + + state_dict = new_sd + + # Step 2: Attention key remapping (Qwen3-VL style) + attention_renames = { + ".self_attn.q_proj.": ".self_attn.qkv_proj.q_proj.", + ".self_attn.k_proj.": ".self_attn.qkv_proj.k_proj.", + ".self_attn.v_proj.": ".self_attn.qkv_proj.v_proj.", + ".self_attn.o_proj.": ".self_attn.o_proj.o_proj.", + } + renamed_sd: Dict[str, Any] = {} + for k, v in state_dict.items(): + new_key = k + if not config.neuron_config.fused_qkv: + for old, new in attention_renames.items(): + if old in new_key: + new_key = new_key.replace(old, new) + break + if ".q_norm." in new_key: + new_key = new_key.replace(".q_norm.", ".q_layernorm.") + if ".k_norm." 
in new_key: + new_key = new_key.replace(".k_norm.", ".k_layernorm.") + renamed_sd[new_key] = v + state_dict = renamed_sd + + # Step 3: rank_util tensors + state_dict["rank_util.rank"] = torch.arange(0, config.neuron_config.tp_degree, dtype=torch.int32) + + # Step 4: MoE weight conversion (Qwen3-MoE style) + for l in range(config.num_hidden_layers): + state_dict[f"layers.{l}.self_attn.rank_util.rank"] = torch.arange( + 0, config.neuron_config.tp_degree, dtype=torch.int32 + ) + + # Router: gate -> router.linear_router + gate_key = f"layers.{l}.mlp.gate.weight" + if gate_key in state_dict: + state_dict[f"layers.{l}.mlp.router.linear_router.weight"] = state_dict.pop(gate_key) + + # Stack expert weights + expert_key_0 = f"layers.{l}.mlp.experts.0.gate_proj.weight" + if expert_key_0 not in state_dict: + continue + + intermediate_size, hidden_size = state_dict[expert_key_0].shape + device = state_dict[expert_key_0].device + dtype = state_dict[expert_key_0].dtype + + gate_up_proj = torch.empty(config.num_experts, hidden_size, 2 * intermediate_size, dtype=dtype, device=device) + for e in range(config.num_experts): + gw = state_dict.pop(f"layers.{l}.mlp.experts.{e}.gate_proj.weight") + uw = state_dict.pop(f"layers.{l}.mlp.experts.{e}.up_proj.weight") + # copy_() writes into the preallocated buffer and releases gw/uw after the copy; + # avoids holding a second transposed materialization in RAM. 
+ gate_up_proj[e, :, :intermediate_size].copy_(gw.T) + gate_up_proj[e, :, intermediate_size:].copy_(uw.T) + del gw, uw + + pad_size = getattr(config, "moe_intermediate_pad_size", 0) + if pad_size > 0: + gate_up_proj = gate_up_proj.reshape(config.num_experts, hidden_size, 2, -1) + gate_up_proj = torch.nn.functional.pad(gate_up_proj, (0, pad_size)) + gate_up_proj = gate_up_proj.reshape(config.num_experts, hidden_size, -1) + state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.gate_up_proj.weight"] = gate_up_proj + + down_proj = torch.empty(config.num_experts, intermediate_size, hidden_size, dtype=dtype, device=device) + for e in range(config.num_experts): + dw = state_dict.pop(f"layers.{l}.mlp.experts.{e}.down_proj.weight") + down_proj[e].copy_(dw.T) + del dw + if pad_size > 0: + down_proj = torch.nn.functional.pad(down_proj, (0, 0, 0, pad_size)) + state_dict[f"layers.{l}.mlp.expert_mlps.mlp_op.down_proj.weight"] = down_proj + + gc.collect() + + return state_dict diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni_vision.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni_vision.py new file mode 100644 index 00000000..550a2ecc --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src/modeling_qwen3_omni_vision.py @@ -0,0 +1,132 @@ +# coding=utf-8 +# Copyright 2025 The Qwen team, Alibaba Group and The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Qwen3-Omni vision encoder for NxD Inference. 
+ +Reuses the Qwen3-VL vision model code since the architecture is identical. +Only differences: state dict key mapping (thinker.visual.* → visual.*) and +the PatchMerger naming (ln_q + mlp.0/mlp.2 vs norm + linear_fc1/linear_fc2). + +The vision model is compiled and run separately from the text model, with +outputs passed through the ImageToText framework. +""" +import logging +from typing import List +from unittest.mock import patch as mock_patch + +import torch +import torch.nn as nn + +from neuronx_distributed_inference.models.config import InferenceConfig +from neuronx_distributed_inference.modules.checkpoint import load_state_dict +from neuronx_distributed_inference.models.qwen3_vl.modeling_qwen3_vl_vision import ( + NeuronQwen3VLForImageEncoding, + NeuronQwen3VLVisionModel, + NeuronQwen3VLVisionModelWrapper, +) + +logger = logging.getLogger(__name__) + + +def _load_state_dict_with_thinker_alias(model_path, *args, **kwargs): + """Wraps load_state_dict to alias thinker.visual.pos_embed.weight as model.visual.pos_embed.weight.""" + sd = load_state_dict(model_path, *args, **kwargs) + if "thinker.visual.pos_embed.weight" in sd and "model.visual.pos_embed.weight" not in sd: + sd["model.visual.pos_embed.weight"] = sd["thinker.visual.pos_embed.weight"] + return sd + + +class NeuronQwen3OmniVisionModel(NeuronQwen3VLVisionModel): + """ + Qwen3-Omni vision model — architecturally identical to Qwen3-VL. + The only code change is in state dict conversion (done externally). + """ + pass + + +class NeuronQwen3OmniVisionModelWrapper(NeuronQwen3VLVisionModelWrapper): + """Wraps NeuronQwen3OmniVisionModel for Neuron compilation. + + Patches load_state_dict during __init__ to handle Qwen3-Omni's + pos_embed key prefix (thinker.visual.* instead of model.visual.*).
+ """ + + def __init__(self, *args, **kwargs): + with mock_patch( + "neuronx_distributed_inference.models.qwen3_vl.modeling_qwen3_vl_vision.load_state_dict", + _load_state_dict_with_thinker_alias, + ): + super().__init__(*args, **kwargs) + + +class NeuronQwen3OmniForImageEncoding(NeuronQwen3VLForImageEncoding): + """Standalone vision encoder application for Qwen3-Omni.""" + + _model_cls = NeuronQwen3OmniVisionModel + + def get_model_wrapper_cls(self): + return NeuronQwen3OmniVisionModelWrapper + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, inference_config: InferenceConfig + ) -> dict: + """ + Converts HF Qwen3-Omni vision state dict to Neuron format. + + Key mappings: + thinker.visual.* → visual.* (then same as Qwen3-VL) + visual.merger.ln_q.* → merger.norm.* + visual.merger.mlp.0.* → merger.linear_fc1.* + visual.merger.mlp.2.* → merger.linear_fc2.* + visual.merger_list.N.ln_q.* → deepstack_merger_list.N.norm.* + visual.merger_list.N.mlp.0.* → deepstack_merger_list.N.linear_fc1.* + visual.merger_list.N.mlp.2.* → deepstack_merger_list.N.linear_fc2.* + visual.blocks.*.attn.qkv.* → blocks.*.attn.qkv_proj.Wqkv.* + visual.blocks.*.attn.proj.* → blocks.*.attn.o_proj.* + """ + new_state_dict = {} + + for key, value in state_dict.items(): + if "visual." not in key: + continue + + new_key = key + # Strip thinker prefix + if new_key.startswith("thinker.visual."): + new_key = new_key[len("thinker."):] + elif new_key.startswith("model.thinker.visual."): + new_key = new_key[len("model.thinker."):] + + # Strip visual. 
prefix (NxD vision model doesn't use it) + new_key = new_key.replace("visual.", "", 1) + + # Qwen3-Omni uses merger_list; NxD Qwen3-VL uses deepstack_merger_list + new_key = new_key.replace("merger_list.", "deepstack_merger_list.") + + # PatchMerger key renaming: ln_q → norm, mlp.0 → linear_fc1, mlp.2 → linear_fc2 + new_key = new_key.replace(".ln_q.", ".norm.") + new_key = new_key.replace(".mlp.0.", ".linear_fc1.") + new_key = new_key.replace(".mlp.2.", ".linear_fc2.") + + # Attention key renaming (same as Qwen3-VL) + if ".attn.qkv." in new_key: + new_key = new_key.replace(".attn.qkv.", ".attn.qkv_proj.Wqkv.") + elif ".attn.proj." in new_key: + new_key = new_key.replace(".attn.proj.", ".attn.o_proj.") + + new_state_dict[new_key] = value.clone().detach().contiguous() + + return new_state_dict diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test/__init__.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test/integration/__init__.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test/integration/test_model.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test/integration/test_model.py new file mode 100644 index 00000000..3739fbab --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test/integration/test_model.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 +""" +Integration tests for Qwen3-Omni-30B-A3B-Instruct (thinker text model) on NeuronX. + +Tests model compilation, loading, and inference on Neuron devices. +Requires: trn1.32xlarge or trn2.48xlarge with enough cores for tp_degree. 
+""" + +import pytest +import torch +import time +from pathlib import Path +from transformers import AutoTokenizer + +from neuronx_distributed_inference.models.config import MoENeuronConfig +from neuronx_distributed_inference.utils.accuracy import get_generate_outputs + +import sys +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) +from modeling_qwen3_omni_moe import ( + NeuronQwen3OmniMoeForCausalLM, + Qwen3OmniMoeInferenceConfig, + load_qwen3_omni_thinker_text_config, +) + +MODEL_PATH = "/home/ubuntu/models/Qwen3-Omni-30B-A3B-Instruct/" +COMPILED_MODEL_PATH = "/home/ubuntu/traced_model/Qwen3-Omni-30B-A3B-Instruct/" +TP_DEGREE = 32 +BATCH_SIZE = 1 +SEQ_LEN = 512 +MAX_CONTEXT_LENGTH = 256 + + +@pytest.fixture(scope="module") +def compiled_model(): + """Compile and load the Qwen3-Omni thinker text model.""" + neuron_config = MoENeuronConfig( + tp_degree=TP_DEGREE, + batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + max_context_length=MAX_CONTEXT_LENGTH, + torch_dtype=torch.bfloat16, + on_device_sampling_config={"top_k": 1, "do_sample": False}, + ) + + config = Qwen3OmniMoeInferenceConfig( + neuron_config, + load_config=load_qwen3_omni_thinker_text_config(MODEL_PATH), + ) + + model = NeuronQwen3OmniMoeForCausalLM(MODEL_PATH, config) + + compiled_path = Path(COMPILED_MODEL_PATH) + if not compiled_path.exists(): + print(f"Compiling model to {COMPILED_MODEL_PATH}...") + model.compile(COMPILED_MODEL_PATH) + print("Compilation complete.") + + model.load(COMPILED_MODEL_PATH) + return model + + +@pytest.fixture(scope="module") +def tokenizer(): + tok = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + return tok + + +def test_model_loads(compiled_model): + """Smoke test: model loads successfully.""" + assert compiled_model is not None + assert hasattr(compiled_model, "config") + assert hasattr(compiled_model.config, "neuron_config") + print("PASS: Model loaded successfully") + + +def 
test_model_generates(compiled_model, tokenizer): + """Test that model can generate text.""" + prompt = "The capital of France is" + + _, output_tokens = get_generate_outputs( + compiled_model, + [prompt], + tokenizer, + is_hf=False, + do_sample=False, + max_length=compiled_model.neuron_config.max_length, + ) + + output_text = output_tokens[0] + assert len(output_text) > len(prompt), "Output should be longer than prompt" + print(f"PASS: Generated: {output_text[:200]}") + + +def test_throughput(compiled_model, tokenizer): + """Measure token generation throughput.""" + prompt = "Hello" + + # warmup + get_generate_outputs( + compiled_model, [prompt], tokenizer, + is_hf=False, do_sample=False, + max_length=compiled_model.neuron_config.max_length, + ) + + start = time.perf_counter() + _, output_tokens = get_generate_outputs( + compiled_model, [prompt], tokenizer, + is_hf=False, do_sample=False, + max_length=compiled_model.neuron_config.max_length, + ) + elapsed = time.perf_counter() - start + + output_len = len(tokenizer.encode(output_tokens[0])) + input_len = len(tokenizer.encode(prompt)) + num_new = output_len - input_len + throughput = num_new / elapsed if elapsed > 0 else 0 + assert throughput > 1, f"Throughput {throughput:.2f} tok/s is too low" + print(f"PASS: Throughput {throughput:.2f} tok/s ({num_new} tokens in {elapsed:.2f}s)") + + +if __name__ == "__main__": + print("=" * 60) + print("Qwen3-Omni-30B-A3B-Instruct Integration Tests") + print("=" * 60) + + neuron_config = MoENeuronConfig( + tp_degree=TP_DEGREE, + batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + max_context_length=MAX_CONTEXT_LENGTH, + torch_dtype=torch.bfloat16, + on_device_sampling_config={"top_k": 1, "do_sample": False}, + ) + config = Qwen3OmniMoeInferenceConfig( + neuron_config, + load_config=load_qwen3_omni_thinker_text_config(MODEL_PATH), + ) + model = NeuronQwen3OmniMoeForCausalLM(MODEL_PATH, config) + compiled_path = Path(COMPILED_MODEL_PATH) + if not compiled_path.exists(): + 
print("Compiling...") + model.compile(COMPILED_MODEL_PATH) + model.load(COMPILED_MODEL_PATH) + + tok = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + + test_model_loads(model) + test_model_generates(model, tok) + test_throughput(model, tok) + print("\nAll tests passed!") diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test/unit/__init__.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test/unit/test_config_and_state_dict.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test/unit/test_config_and_state_dict.py new file mode 100644 index 00000000..8d3ec75e --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test/unit/test_config_and_state_dict.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +""" +Unit tests for Qwen3-Omni-MoE config loading and state dict conversion. +These tests run on CPU without Neuron devices. 
+""" +import json +import os +import tempfile + +import pytest +import torch + +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) + +from modeling_qwen3_omni_moe import ( + Qwen3OmniMoeInferenceConfig, + load_qwen3_omni_thinker_text_config, + _strip_thinker_prefix, + convert_qwen3_omni_moe_hf_to_neuron_state_dict, +) +from neuronx_distributed_inference.models.config import MoENeuronConfig + + +SAMPLE_CONFIG = { + "model_type": "qwen3_omni_moe", + "thinker_config": { + "text_config": { + "hidden_size": 2048, + "num_hidden_layers": 4, + "num_attention_heads": 32, + "num_key_value_heads": 4, + "head_dim": 128, + "vocab_size": 152064, + "max_position_embeddings": 65536, + "moe_intermediate_size": 768, + "num_experts": 8, + "num_experts_per_tok": 2, + "norm_topk_prob": True, + "rms_norm_eps": 1e-6, + "rope_theta": 1000000, + "hidden_act": "silu", + "tie_word_embeddings": False, + "shared_expert_intermediate_size": 0, + "rope_scaling": { + "interleaved": True, + "mrope_section": [24, 20, 20], + "rope_type": "default", + }, + }, + "pad_token_id": None, + }, +} + + +@pytest.fixture +def config_dir(tmp_path): + config_path = tmp_path / "config.json" + config_path.write_text(json.dumps(SAMPLE_CONFIG)) + return str(tmp_path) + + +@pytest.fixture +def neuron_config(): + return MoENeuronConfig( + tp_degree=4, + batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + ) + + +def test_config_loads_from_nested_structure(config_dir, neuron_config): + config = Qwen3OmniMoeInferenceConfig( + neuron_config, + load_config=load_qwen3_omni_thinker_text_config(config_dir), + ) + assert config.hidden_size == 2048 + assert config.num_hidden_layers == 4 + assert config.num_attention_heads == 32 + assert config.num_key_value_heads == 4 + assert config.head_dim == 128 + assert config.num_experts == 8 + assert config.num_experts_per_tok == 2 + assert config.moe_intermediate_size == 768 + assert config.vocab_size == 152064 + assert 
config.rope_theta == 1000000 + + +def test_config_moe_settings(config_dir, neuron_config): + config = Qwen3OmniMoeInferenceConfig( + neuron_config, + load_config=load_qwen3_omni_thinker_text_config(config_dir), + ) + assert config.num_local_experts == config.num_experts + assert config.n_shared_experts == 0 + assert config.intermediate_size == config.moe_intermediate_size + assert config.neuron_config.router_config.dtype == torch.float32 + assert config.neuron_config.router_config.act_fn == "softmax" + assert config.neuron_config.normalize_top_k_affinities is True + assert config.neuron_config.disable_numeric_cc_token is True + + +def test_config_missing_thinker_raises(tmp_path, neuron_config): + bad_config = {"model_type": "something_else"} + (tmp_path / "config.json").write_text(json.dumps(bad_config)) + with pytest.raises(ValueError, match="thinker_config.text_config"): + Qwen3OmniMoeInferenceConfig( + neuron_config, + load_config=load_qwen3_omni_thinker_text_config(str(tmp_path)), + ) + + +def test_strip_thinker_prefix(): + sd = { + "thinker.model.embed_tokens.weight": torch.randn(10, 10), + "thinker.model.layers.0.self_attn.q_proj.weight": torch.randn(10, 10), + "thinker.model.layers.0.mlp.gate.weight": torch.randn(10, 10), + "thinker.model.norm.weight": torch.randn(10), + "thinker.lm_head.weight": torch.randn(10, 10), + "talker.model.layers.0.weight": torch.randn(5, 5), + "code2wav.decoder.weight": torch.randn(3, 3), + } + stripped = _strip_thinker_prefix(sd) + assert "embed_tokens.weight" in stripped + assert "layers.0.self_attn.q_proj.weight" in stripped + assert "layers.0.mlp.gate.weight" in stripped + assert "norm.weight" in stripped + assert "lm_head.weight" in stripped + assert "talker.model.layers.0.weight" not in stripped + assert "code2wav.decoder.weight" not in stripped + assert len(stripped) == 5 + + +def test_strip_with_model_thinker_prefix(): + sd = { + "model.thinker.model.embed_tokens.weight": torch.randn(10, 10), + 
"model.thinker.lm_head.weight": torch.randn(10, 10), + } + stripped = _strip_thinker_prefix(sd) + assert "embed_tokens.weight" in stripped + assert "lm_head.weight" in stripped + + +def test_strip_no_prefix(): + sd = { + "embed_tokens.weight": torch.randn(10, 10), + "layers.0.self_attn.q_proj.weight": torch.randn(10, 10), + "lm_head.weight": torch.randn(10, 10), + } + stripped = _strip_thinker_prefix(sd) + assert "embed_tokens.weight" in stripped + assert len(stripped) == 3 + + +def _make_fake_thinker_state_dict(num_layers=2, num_experts=4, hidden=64, intermediate=32): + """Build a fake HF-format state dict with thinker prefix.""" + sd = {} + sd["thinker.model.embed_tokens.weight"] = torch.randn(100, hidden) + sd["thinker.model.norm.weight"] = torch.randn(hidden) + sd["thinker.lm_head.weight"] = torch.randn(100, hidden) + + for l in range(num_layers): + pfx = f"thinker.model.layers.{l}" + sd[f"{pfx}.self_attn.q_proj.weight"] = torch.randn(hidden, hidden) + sd[f"{pfx}.self_attn.k_proj.weight"] = torch.randn(hidden // 8, hidden) + sd[f"{pfx}.self_attn.v_proj.weight"] = torch.randn(hidden // 8, hidden) + sd[f"{pfx}.self_attn.o_proj.weight"] = torch.randn(hidden, hidden) + sd[f"{pfx}.self_attn.q_norm.weight"] = torch.randn(hidden // (hidden // 8)) + sd[f"{pfx}.self_attn.k_norm.weight"] = torch.randn(hidden // (hidden // 8)) + sd[f"{pfx}.input_layernorm.weight"] = torch.randn(hidden) + sd[f"{pfx}.post_attention_layernorm.weight"] = torch.randn(hidden) + sd[f"{pfx}.mlp.gate.weight"] = torch.randn(num_experts, hidden) + for e in range(num_experts): + sd[f"{pfx}.mlp.experts.{e}.gate_proj.weight"] = torch.randn(intermediate, hidden) + sd[f"{pfx}.mlp.experts.{e}.up_proj.weight"] = torch.randn(intermediate, hidden) + sd[f"{pfx}.mlp.experts.{e}.down_proj.weight"] = torch.randn(hidden, intermediate) + + return sd + + +def test_full_state_dict_conversion(config_dir, neuron_config): + config = Qwen3OmniMoeInferenceConfig( + neuron_config, + 
load_config=load_qwen3_omni_thinker_text_config(config_dir), + ) + # override for small test + config.num_hidden_layers = 2 + config.num_experts = 4 + config.num_local_experts = 4 + config.hidden_size = 64 + config.moe_intermediate_size = 32 + config.intermediate_size = 32 + + sd = _make_fake_thinker_state_dict(num_layers=2, num_experts=4, hidden=64, intermediate=32) + neuron_sd = convert_qwen3_omni_moe_hf_to_neuron_state_dict(sd, config) + + # Check prefix stripped + assert "embed_tokens.weight" in neuron_sd + assert "lm_head.weight" in neuron_sd + assert "norm.weight" in neuron_sd + + # Check rank utils added + assert "rank_util.rank" in neuron_sd + assert "layers.0.self_attn.rank_util.rank" in neuron_sd + + # Check qk norm renamed + assert "layers.0.self_attn.q_layernorm.weight" in neuron_sd + assert "layers.0.self_attn.k_layernorm.weight" in neuron_sd + assert "layers.0.self_attn.q_norm.weight" not in neuron_sd + + # Check router renamed + assert "layers.0.mlp.router.linear_router.weight" in neuron_sd + assert "layers.0.mlp.gate.weight" not in neuron_sd + + # Check expert weights reorganized + gate_up = neuron_sd["layers.0.mlp.expert_mlps.mlp_op.gate_up_proj.weight"] + assert gate_up.shape == (4, 64, 64) # (num_experts, hidden, 2*intermediate) + down = neuron_sd["layers.0.mlp.expert_mlps.mlp_op.down_proj.weight"] + assert down.shape == (4, 32, 64) # (num_experts, intermediate, hidden) + + # Check individual expert keys removed + assert "layers.0.mlp.experts.0.gate_proj.weight" not in neuron_sd + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test_asr.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test_asr.py new file mode 100644 index 00000000..b11a6782 --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test_asr.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python3 +""" +ASR benchmark: evaluate Qwen3-Omni-30B-A3B-Instruct on LibriSpeech test-clean. 
+ +Runs the MoE text model on Neuron (TP=8, LNC=2) and audio encoder transformer +layers on a single Neuron core (no TP). Conv2d frontend stays on CPU. + +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + pip install jiwer "datasets<4" soundfile librosa + + NEURON_RT_VISIBLE_CORES=0-31 python test_asr.py \ + --model-path /home/ubuntu/models/Qwen3-Omni-30B-A3B-Instruct \ + --compiled-model-path /home/ubuntu/traced_model/Qwen3-Omni-asr \ + --audio-compiled-path /home/ubuntu/traced_model/Qwen3-Omni-audio \ + --num-samples 100 +""" +import argparse +import json +import os +import sys +import time +from pathlib import Path + +import numpy as np +import torch + +sys.path.insert(0, str(Path(__file__).parent / "src")) + + +def compile_and_load_model(model_path, compiled_model_path, tp_degree, seq_len, + max_context_length, vision_tp_degree, vision_seq_len, + audio_compiled_path=None): + """Compile (if needed) and load the multimodal model on Neuron.""" + from modeling_qwen3_omni import ( + NeuronQwen3OmniForCausalLM, + Qwen3OmniInferenceConfig, + load_qwen3_omni_multimodal_config, + ) + from neuronx_distributed_inference.models.config import MoENeuronConfig, NeuronConfig + + text_neuron_config = MoENeuronConfig( + tp_degree=tp_degree, + batch_size=1, + seq_len=seq_len, + max_context_length=max_context_length, + torch_dtype=torch.bfloat16, + on_device_sampling_config={"top_k": 1, "do_sample": False}, + blockwise_matmul_config={"use_torch_block_wise": True}, + ) + + vision_neuron_config = NeuronConfig( + tp_degree=vision_tp_degree, + batch_size=1, + seq_len=vision_seq_len, + torch_dtype=torch.bfloat16, + ) + + config = Qwen3OmniInferenceConfig( + text_neuron_config=text_neuron_config, + vision_neuron_config=vision_neuron_config, + load_config=load_qwen3_omni_multimodal_config(model_path), + ) + + model = NeuronQwen3OmniForCausalLM(model_path, config, skip_vision_encoder=True) + + compiled_path = Path(compiled_model_path) + if not 
compiled_path.exists(): + print("Compiling multimodal model (this may take 20-40 minutes)...") + t0 = time.perf_counter() + model.compile(compiled_model_path) + print(f"Compilation took {time.perf_counter() - t0:.1f}s") + + print("Loading model to Neuron...") + t0 = time.perf_counter() + model.load(compiled_model_path) + print(f"Model loaded in {time.perf_counter() - t0:.1f}s") + + # Initialize audio encoder + if audio_compiled_path: + print("Loading audio encoder (Neuron + CPU hybrid)...") + t0 = time.perf_counter() + model.init_audio_encoder_neuron(model_path, audio_compiled_path) + print(f"Audio encoder loaded in {time.perf_counter() - t0:.1f}s") + else: + print("Loading audio encoder on CPU...") + model.init_audio_encoder(model_path) + print("Audio encoder loaded.") + + return model, config + + +def main(): + parser = argparse.ArgumentParser(description="ASR benchmark on LibriSpeech (Neuron)") + parser.add_argument("--model-path", type=str, + default="/home/ubuntu/models/Qwen3-Omni-30B-A3B-Instruct") + parser.add_argument("--compiled-model-path", type=str, + default="/home/ubuntu/traced_model/Qwen3-Omni-asr") + parser.add_argument("--audio-compiled-path", type=str, default=None, + help="Path for compiled audio encoder (Neuron). 
If not set, uses CPU.") + parser.add_argument("--num-samples", type=int, default=100) + parser.add_argument("--max-new-tokens", type=int, default=256) + parser.add_argument("--tp-degree", type=int, default=8) + parser.add_argument("--vision-tp-degree", type=int, default=8) + parser.add_argument("--seq-len", type=int, default=4096) + parser.add_argument("--max-context-length", type=int, default=2048) + parser.add_argument("--vision-seq-len", type=int, default=4096) + parser.add_argument("--split", type=str, default="test.clean", + help="LibriSpeech split: test.clean, test.other, etc.") + parser.add_argument("--output-json", type=str, default="asr_results.json", + help="Path to save per-sample results") + args = parser.parse_args() + + # ── 1. Load dataset ────────────────────────────────────────────────── + from datasets import load_dataset + + print(f"Loading openslr/librispeech_asr split={args.split} ...") + ds = load_dataset("openslr/librispeech_asr", split=args.split, streaming=True) + + samples = [] + for i, item in enumerate(ds): + if i >= args.num_samples: + break + samples.append(item) + print(f"Loaded {len(samples)} samples") + + # ── 2. Compile and load model on Neuron ────────────────────────────── + model, config = compile_and_load_model( + args.model_path, args.compiled_model_path, + args.tp_degree, args.seq_len, args.max_context_length, + args.vision_tp_degree, args.vision_seq_len, + args.audio_compiled_path, + ) + + # ── 3. 
Load processor ──────────────────────────────────────────────── + from transformers import Qwen3OmniMoeProcessor, AutoTokenizer + from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter + + processor = Qwen3OmniMoeProcessor.from_pretrained( + args.model_path, trust_remote_code=True + ) + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + adapter = HuggingFaceGenerationAdapter(model) + + # ── 4. Run ASR ─────────────────────────────────────────────────────── + ASR_PROMPT = "Please transcribe the following audio into English text." + results = [] + total_audio_duration = 0.0 + + print(f"\nRunning ASR on {len(samples)} samples ...") + inference_start = time.perf_counter() + + for idx, sample in enumerate(samples): + audio_array = np.array(sample["audio"]["array"]) + sampling_rate = sample["audio"]["sampling_rate"] + reference = sample["text"].strip() + audio_duration = len(audio_array) / sampling_rate + total_audio_duration += audio_duration + + conversation = [ + { + "role": "user", + "content": [ + {"type": "audio", "audio": audio_array}, + {"type": "text", "text": ASR_PROMPT}, + ], + } + ] + + text_input = processor.apply_chat_template( + conversation, add_generation_prompt=True, tokenize=False + ) + inputs = processor( + text=text_input, + audio=[audio_array], + sampling_rate=sampling_rate, + return_tensors="pt", + padding=True, + ) + + generate_kwargs = { + "input_ids": inputs.input_ids, + "attention_mask": inputs.attention_mask, + "max_new_tokens": args.max_new_tokens, + "do_sample": False, + } + if hasattr(inputs, "input_features") and inputs.input_features is not None: + generate_kwargs["input_features"] = inputs.input_features + if hasattr(inputs, "feature_attention_mask") and inputs.feature_attention_mask is not None: + generate_kwargs["feature_attention_mask"] = inputs.feature_attention_mask + + t0 = 
time.perf_counter() + output_ids = adapter.generate(**generate_kwargs) + gen_time = time.perf_counter() - t0 + + input_len = inputs["input_ids"].shape[1] + generated_ids = output_ids[:, input_len:] + eos_id = tokenizer.eos_token_id + gen_ids = generated_ids[0].tolist() + if eos_id in gen_ids: + gen_ids = gen_ids[:gen_ids.index(eos_id)] + hypothesis = tokenizer.decode(gen_ids, skip_special_tokens=True).strip() + + results.append({ + "id": sample.get("id", idx), + "reference": reference, + "hypothesis": hypothesis, + "audio_duration_s": round(audio_duration, 2), + "gen_time_s": round(gen_time, 3), + }) + + if (idx + 1) % 10 == 0 or idx == 0: + print(f" [{idx+1}/{len(samples)}] {gen_time:.2f}s " + f"ref: {reference[:50]}... | hyp: {hypothesis[:50]}...") + + total_inference_time = time.perf_counter() - inference_start + + # ── 5. Compute metrics ─────────────────────────────────────────────── + from jiwer import wer, cer + + references = [r["reference"] for r in results] + hypotheses = [r["hypothesis"] for r in results] + + refs_norm = [r.lower() for r in references] + hyps_norm = [h.lower() for h in hypotheses] + + overall_wer = wer(refs_norm, hyps_norm) + overall_cer = cer(refs_norm, hyps_norm) + + per_sample_wer = [wer(r, h) for r, h in zip(refs_norm, hyps_norm)] + for i, r in enumerate(results): + r["wer"] = round(per_sample_wer[i], 4) + + # ── 6. 
Report ──────────────────────────────────────────────────────── + rtf = total_inference_time / total_audio_duration if total_audio_duration > 0 else float("inf") + avg_gen_time = sum(r["gen_time_s"] for r in results) / len(results) + + print("\n" + "=" * 70) + print("ASR Benchmark Results (Neuron)") + print("=" * 70) + print(f"Model: {args.model_path}") + print(f"TP degree: {args.tp_degree} (text MoE), {args.vision_tp_degree} (vision)") + print(f"Audio encoder: {'Neuron' if args.audio_compiled_path else 'CPU'}") + print(f"Dataset: openslr/librispeech_asr ({args.split})") + print(f"Samples: {len(results)}") + print(f"Total audio: {total_audio_duration:.1f}s") + print(f"Total inference: {total_inference_time:.1f}s") + print(f"Avg per sample: {avg_gen_time:.3f}s") + print(f"RTF: {rtf:.2f}x") + print(f"WER: {overall_wer:.4f} ({overall_wer*100:.2f}%)") + print(f"CER: {overall_cer:.4f} ({overall_cer*100:.2f}%)") + print("=" * 70) + + print("\nSample results:") + for r in results[:5]: + print(f" REF: {r['reference']}") + print(f" HYP: {r['hypothesis']}") + print(f" WER: {r['wer']:.4f}") + print() + + output = { + "model": args.model_path, + "dataset": f"openslr/librispeech_asr:{args.split}", + "tp_degree": args.tp_degree, + "audio_on_neuron": args.audio_compiled_path is not None, + "num_samples": len(results), + "total_audio_duration_s": round(total_audio_duration, 2), + "total_inference_time_s": round(total_inference_time, 2), + "avg_gen_time_s": round(avg_gen_time, 3), + "rtf": round(rtf, 4), + "wer": round(overall_wer, 4), + "cer": round(overall_cer, 4), + "samples": results, + } + + output_path = Path(args.output_json) + output_path.write_text(json.dumps(output, indent=2, ensure_ascii=False)) + print(f"\nDetailed results saved to {output_path}") + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test_load_multimodal.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test_load_multimodal.py new file mode 100644 index 
00000000..0387afda --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test_load_multimodal.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 +"""Quick test: load both text and vision models.""" +import sys, os, time, torch, logging, traceback +from pathlib import Path + +logging.basicConfig(level=logging.INFO, format='%(name)s:%(levelname)s: %(message)s') + +sys.path.insert(0, str(Path(__file__).parent / "src")) + +from modeling_qwen3_omni import ( + NeuronQwen3OmniForCausalLM, + Qwen3OmniInferenceConfig, + load_qwen3_omni_multimodal_config, +) +from neuronx_distributed_inference.models.config import MoENeuronConfig, NeuronConfig + +model_path = '/home/ubuntu/models/Qwen3-Omni-30B-A3B-Instruct' +compiled_path = '/home/ubuntu/traced_model/Qwen3-Omni-multimodal' + +text_neuron_config = MoENeuronConfig( + tp_degree=16, batch_size=1, seq_len=4096, + max_context_length=2048, torch_dtype=torch.bfloat16, + on_device_sampling_config={'top_k': 1, 'do_sample': False}, + blockwise_matmul_config={'use_torch_block_wise': True}, +) + +vision_neuron_config = NeuronConfig( + tp_degree=16, batch_size=1, seq_len=4096, torch_dtype=torch.bfloat16, +) + +config = Qwen3OmniInferenceConfig( + text_neuron_config=text_neuron_config, + vision_neuron_config=vision_neuron_config, + load_config=load_qwen3_omni_multimodal_config(model_path), +) + +model = NeuronQwen3OmniForCausalLM(model_path, config) + +print('=== Loading model ===') +t0 = time.perf_counter() +try: + model.load(compiled_path) + print(f'SUCCESS: Model loaded in {time.perf_counter() - t0:.1f}s') +except Exception as e: + print(f'LOAD FAILED after {time.perf_counter() - t0:.1f}s') + traceback.print_exc() diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test_thinker_ttft_bench.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test_thinker_ttft_bench.py new file mode 100644 index 00000000..33996364 --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test_thinker_ttft_bench.py @@ -0,0 +1,258 @@ +#!/usr/bin/env 
python3 +"""Thinker-only TTFT / throughput benchmark on /home/ubuntu/omni2 conversations 0-99. + +Flow per conversation (talker/code2wav disabled — see test_ttfb_rtf_bench.py for +the full streaming TTFB version, which is currently blocked by the layers.23 +tensor-capture only being wired at bucket 256). + +Metrics: + * ttft_ms — from adapter.generate() start to the first token emission + (first hook fire after the prefill returns) + * prefill_ms — time to the first hook fire (== first forward complete) + * decode_steps — number of post-prefill forward calls (= new tokens - 1) + * decode_mean_ms — mean inter-token wall time during decode (ITL) + * decode_p90_ms — p90 of per-step decode time + * tokens_per_s — thinker_tokens / thinker_wall + * rtf_vs_audio — thinker_wall / input_audio_s (lower = faster than audio) + +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + NEURON_RT_VISIBLE_CORES=0-7 python test_thinker_ttft_bench.py --num 100 +""" +import os +os.environ.setdefault("NEURON_RT_VISIBLE_CORES", "0-7") +# The existing compiled artifacts include layers.23 capture; we still want it +# to flow (as a per-step timing signal), even though the shape is (1,) for +# buckets > 256. We only need the hook to *fire* each step. 
+os.environ.setdefault("QWEN3_OMNI_CAPTURE_LAYER_HIDDEN", "23") +os.environ["TRANSFORMERS_VERBOSITY"] = "error" + +import sys +from pathlib import Path +_HERE = Path(__file__).resolve().parent +_SRC = _HERE / "src" +if str(_SRC) not in sys.path: + sys.path.insert(0, str(_SRC)) +if "/home/ubuntu" not in sys.path: + sys.path.insert(0, "/home/ubuntu") + +import _upstream_compat # noqa: F401 + +import argparse +import json +import statistics +import time +import traceback + +import numpy as np +import soundfile as sf +import torch +from transformers import GenerationConfig + +import test_asr_qwen3_omni as asr # build_and_load_model applies pad_inputs patches + +CONV_JSON = "/home/ubuntu/omni2/merged_conversations_with_audio_x10_with_system.json" +AUDIO_DIR = "/home/ubuntu/omni2/speech_wav_16k" +MODEL_PATH = "/home/ubuntu/models/Qwen3-Omni-30B-A3B-Instruct" +COMPILED_PATH = "/tmp/qwen3_omni_compiled" + + +def build_messages(conv): + msgs = conv["messages"] + out = [] + for i, m in enumerate(msgs): + if i == len(msgs) - 1: + break # drop the reference assistant reply + role = m["role"] + content = m["content"] + if i == len(msgs) - 2 and role == "user": + fname = os.path.basename(content) + wav_path = os.path.join(AUDIO_DIR, fname) + out.append({"role": role, "content": [{"type": "audio", "audio": wav_path}]}) + else: + out.append({"role": role, "content": content}) + return out + + +def percentile(values, p): + if not values: + return float("nan") + s = sorted(values) + i = int(round((len(s) - 1) * p / 100)) + return s[i] + + +def run_one(adapter, processor, conv, idx, max_new_tokens): + messages = build_messages(conv) + wav_path = messages[-1]["content"][0]["audio"] + audio_np, sr = sf.read(wav_path) + if audio_np.ndim == 2: + audio_np = audio_np.mean(axis=1) + audio_np = audio_np.astype(np.float32) + input_audio_s = float(len(audio_np) / sr) + + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + target_sr = 
getattr(processor.feature_extractor, "sampling_rate", 16000) + if sr != target_sr: + import librosa + audio_for_fe = librosa.resample(audio_np, orig_sr=sr, target_sr=target_sr) + else: + audio_for_fe = audio_np + inputs = processor(text=[text], audio=[audio_for_fe], return_tensors="pt", padding=True) + prompt_tokens = int(inputs.input_ids.shape[1]) + + # The tensor_capture_hook fires once per forward pass (prefill + each decode step). + # We use it purely as a timing tap. + step_times = [] # absolute perf_counter timestamps + def _hook(_m, _tensors): + step_times.append(time.perf_counter()) + + gc_cfg = GenerationConfig(do_sample=False, eos_token_id=[151645], pad_token_id=151645) + gen_kwargs = dict( + input_ids=inputs.input_ids, + attention_mask=inputs.attention_mask, + generation_config=gc_cfg, + max_new_tokens=max_new_tokens, + tensor_capture_hook=_hook, + ) + if getattr(inputs, "input_features", None) is not None: + gen_kwargs["input_features"] = inputs.input_features.to(torch.bfloat16) + if getattr(inputs, "feature_attention_mask", None) is not None: + gen_kwargs["feature_attention_mask"] = inputs.feature_attention_mask + + t_start = time.perf_counter() + out_ids = adapter.generate(**gen_kwargs) + t_end = time.perf_counter() + + thinker_wall = t_end - t_start + new_tokens = int(out_ids.shape[1] - inputs.input_ids.shape[1]) + assistant_text = processor.batch_decode( + out_ids[:, inputs.input_ids.shape[1]:], + skip_special_tokens=True, clean_up_tokenization_spaces=False, + )[0].strip() + + if step_times: + ttft_ms = (step_times[0] - t_start) * 1000 + prefill_ms = ttft_ms # first forward covers the prefill + if len(step_times) >= 2: + decode_diffs_ms = [(step_times[i] - step_times[i - 1]) * 1000 + for i in range(1, len(step_times))] + decode_mean_ms = statistics.mean(decode_diffs_ms) + decode_p50_ms = percentile(decode_diffs_ms, 50) + decode_p90_ms = percentile(decode_diffs_ms, 90) + else: + decode_diffs_ms = [] + decode_mean_ms = decode_p50_ms = decode_p90_ms = 
None + else: + ttft_ms = prefill_ms = decode_mean_ms = decode_p50_ms = decode_p90_ms = None + decode_diffs_ms = [] + + tokens_per_s = new_tokens / thinker_wall if thinker_wall > 0 else None + rtf_vs_audio = thinker_wall / input_audio_s if input_audio_s > 0 else None + + return { + "idx": idx, + "wav_path": wav_path, + "input_audio_s": input_audio_s, + "prompt_tokens": prompt_tokens, + "new_tokens": new_tokens, + "thinker_wall_ms": thinker_wall * 1000, + "ttft_ms": ttft_ms, + "prefill_ms": prefill_ms, + "decode_steps": max(0, len(step_times) - 1), + "decode_mean_ms": decode_mean_ms, + "decode_p50_ms": decode_p50_ms, + "decode_p90_ms": decode_p90_ms, + "tokens_per_s": tokens_per_s, + "rtf_vs_audio": rtf_vs_audio, + "text": assistant_text, + } + + +def summary_row(name, xs, fmt="{:7.1f}"): + xs = [x for x in xs if x is not None] + if not xs: + print(f" {name:20s} (no data)") + return + print( + f" {name:20s} mean={fmt.format(statistics.mean(xs))} " + f"p50={fmt.format(percentile(xs, 50))} " + f"p90={fmt.format(percentile(xs, 90))} " + f"p95={fmt.format(percentile(xs, 95))} " + f"max={fmt.format(max(xs))}" + ) + + +def print_summary(results): + ok = [r for r in results if "error" not in r] + print("\n=== SUMMARY ===") + print(f" samples ok: {len(ok)}/{len(results)}") + if not ok: + return + summary_row("TTFT ms", [r["ttft_ms"] for r in ok]) + summary_row("prefill ms", [r["prefill_ms"] for r in ok]) + summary_row("decode ITL mean ms", [r["decode_mean_ms"] for r in ok]) + summary_row("decode ITL p90 ms", [r["decode_p90_ms"] for r in ok]) + summary_row("thinker wall ms", [r["thinker_wall_ms"] for r in ok]) + summary_row("tokens/s (overall)", [r["tokens_per_s"] for r in ok], fmt="{:7.1f}") + summary_row("RTF vs audio in", [r["rtf_vs_audio"] for r in ok], fmt="{:7.2f}") + summary_row("prompt tokens", [r["prompt_tokens"] for r in ok], fmt="{:7.0f}") + summary_row("new tokens", [r["new_tokens"] for r in ok], fmt="{:7.0f}") + summary_row("input audio s", [r["input_audio_s"] 
for r in ok], fmt="{:7.2f}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--num", type=int, default=100) + parser.add_argument("--start", type=int, default=0) + parser.add_argument("--max-new-tokens", type=int, default=200) + parser.add_argument("--out", default="/tmp/qwen3_omni_thinker_ttft.json") + args = parser.parse_args() + + with open(CONV_JSON) as f: + conversations = json.load(f) + + print(f"Loaded {len(conversations)} conversations; running [{args.start}, {args.start + args.num})") + print("Building Neuron thinker (+ audio encoder)...") + adapter, processor = asr.build_and_load_model(MODEL_PATH, COMPILED_PATH) + print("Ready.\n") + + results = [] + bench_start = time.perf_counter() + for k in range(args.num): + idx = args.start + k + if idx >= len(conversations): + break + try: + r = run_one(adapter, processor, conversations[idx], idx, args.max_new_tokens) + results.append(r) + itl = r.get("decode_mean_ms") + itl_str = f"{itl:5.1f}ms" if itl is not None else " n/a" + print( + f"[{k+1:3d}/{args.num}] conv {idx:3d} " + f"in={r['input_audio_s']:4.1f}s " + f"prompt={r['prompt_tokens']:4d}tok " + f"new={r['new_tokens']:3d}tok " + f"TTFT={r['ttft_ms'] if r['ttft_ms'] is not None else float('nan'):6.0f}ms " + f"ITL={itl_str} " + f"tok/s={r['tokens_per_s']:5.1f} " + f"wall={r['thinker_wall_ms']:5.0f}ms " + f"[{r['text'][:32]}]" + ) + except Exception as e: + traceback.print_exc() + print(f"[{k+1:3d}/{args.num}] conv {idx}: FAILED {e}") + results.append({"idx": idx, "error": str(e)}) + + with open(args.out, "w") as f: + json.dump({ + "cumulative_wall_s": time.perf_counter() - bench_start, + "results": results, + }, f, indent=2, ensure_ascii=False) + + print_summary(results) + print(f"\nJSON: {args.out}") + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test_ttfb_pipelined_bench.py b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test_ttfb_pipelined_bench.py new file mode 100644 index 00000000..67b16d0f --- /dev/null +++ 
b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test_ttfb_pipelined_bench.py @@ -0,0 +1,685 @@ +#!/usr/bin/env python3 +"""Pipelined TTFB/RTF benchmark: thinker and talker run concurrently. + +Baseline `test_ttfb_rtf_bench.py` is strictly serial: full thinker.generate → +build talker inputs → talker.generate. This bench overlaps them: + +1. Thinker runs in a background thread. A stopping-criteria callback pushes + every new thinker token into a shared `PipelineState` (NxDI's `_sample` + ignores `streamer`); the capture hook adds one layer-23 hidden per forward. +2. Main thread waits for the first 4 thinker tokens (needed to build the + talker's prefill input), then kicks off talker.generate. +3. Talker's `prepare_inputs_for_generation` is monkey-patched: when the HF + loop asks for `trailing_text_hidden[:, k]` at decode step k, the patched + function pulls the (k+4)-th thinker embedding from the streaming buffer, + blocking until that token is available. Usually it's already there + because talker decode (~21 ms/step) is slower than thinker decode + (~10 ms/step). + +Streaming code2wav stays the same as the serial bench: chunk-sized inline +c2w calls. + +Expected win: TTFB drops from ~2000 ms to ~1400 ms because we no longer +wait for the last ~60 thinker tokens before starting the talker. 
+ +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + NEURON_RT_VISIBLE_CORES=0-7 CHUNK_SIZE=25 LEFT_CTX=5 \ + python test_ttfb_pipelined_bench.py --num 100 --neuron-c2w +""" +import os +os.environ.setdefault("NEURON_RT_VISIBLE_CORES", "0-7") +os.environ.setdefault("QWEN3_OMNI_CAPTURE_LAYER_HIDDEN", "23") +os.environ["TRANSFORMERS_VERBOSITY"] = "error" + +import sys +from pathlib import Path + +_HERE = Path(__file__).resolve().parent +for _candidate in (_HERE / "src", Path("/home/ubuntu/whn-ndi/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src")): + if (_candidate / "_upstream_compat.py").exists() and str(_candidate) not in sys.path: + sys.path.insert(0, str(_candidate)) + break +if "/home/ubuntu" not in sys.path: + sys.path.insert(0, "/home/ubuntu") + +import _upstream_compat # noqa: F401 + +import argparse +import functools +import json +import queue +import statistics +import threading +import time +import traceback + +import numpy as np +import soundfile as sf +import torch +from transformers import GenerationConfig +from transformers.generation.streamers import BaseStreamer + +import test_audio_streaming as STR +sys.path.insert(0, str(_HERE)) +from code2wav_neuron import install_neuron_code2wav # noqa: E402 + +CONV_JSON = "/home/ubuntu/omni2/merged_conversations_with_audio_x10_with_system.json" +AUDIO_DIR = "/home/ubuntu/omni2/speech_wav_16k" + + +def build_messages(conv): + msgs = conv["messages"] + out = [] + for i, m in enumerate(msgs): + if i == len(msgs) - 1: + break + role = m["role"] + content = m["content"] + if i == len(msgs) - 2 and role == "user": + fname = os.path.basename(content) + wav_path = os.path.join(AUDIO_DIR, fname) + out.append({"role": role, "content": [{"type": "audio", "audio": wav_path}]}) + else: + out.append({"role": role, "content": content}) + return out + + +# --------------------------------------------------------------------------- +# Streaming thinker→talker plumbing +# 
--------------------------------------------------------------------------- + +class PipelineState: + """Shared between thinker thread and main thread. + + As thinker emits each assistant token, we accumulate the token id and the + layer-23 hidden, and push-notify the waiting talker side via a condition + variable. ``assistant_start_idx`` is the prompt length — we skip the + prompt tokens the streamer sees at the beginning. + """ + def __init__(self, assistant_start_idx: int): + self.lock = threading.Lock() + self.cond = threading.Condition(self.lock) + self.assistant_start_idx = assistant_start_idx + self.token_ids: list[int] = [] # assistant tokens only + self.layer23_hidden: list[torch.Tensor] = [] # one per thinker forward (prefill + decode) + self.thinker_done = False + self.thinker_error: Exception | None = None + + # Populated once the prefill's layer-23 output is captured (needed + # for _build_talker_inputs' USER turn portions). + self.prefill_hidden: torch.Tensor | None = None + + def push_token(self, token_id: int): + with self.cond: + self.token_ids.append(token_id) + self.cond.notify_all() + + def push_layer23(self, hid: torch.Tensor): + with self.cond: + self.layer23_hidden.append(hid) + if self.prefill_hidden is None: + # First call = prefill, shape [1, bucket, 2048]. Later calls + # are decode shape [1, 1, 2048]. 
+ self.prefill_hidden = hid + self.cond.notify_all() + + def mark_done(self, exc: Exception | None = None): + with self.cond: + self.thinker_done = True + self.thinker_error = exc + self.cond.notify_all() + + def wait_for_tokens(self, count: int, timeout: float = 30.0) -> bool: + """Block until at least ``count`` assistant tokens are available.""" + deadline = time.perf_counter() + timeout + with self.cond: + while len(self.token_ids) < count: + if self.thinker_done: + return len(self.token_ids) >= count + remaining = deadline - time.perf_counter() + if remaining <= 0: + return False + self.cond.wait(timeout=remaining) + return True + + +class TokenStreamStoppingCriteria: + """Abuses HF's StoppingCriteria plumbing as a per-step callback. + + NxDI's custom `_sample` ignores `streamer` kwarg, but it DOES call + `stopping_criteria(input_ids, None)` after every decode step with the + current ``input_ids`` buffer. We piggy-back on that to notify + ``PipelineState`` whenever a new token is appended. + """ + + def __init__(self, state: PipelineState, prompt_len: int): + self.state = state + self.prompt_len = prompt_len + self._last_len = prompt_len + + def __call__(self, input_ids: torch.LongTensor, scores, **kwargs) -> torch.Tensor: + cur_len = int(input_ids.shape[1]) + if cur_len > self._last_len: + # Push all tokens added since last call (usually just 1) + for idx in range(self._last_len, cur_len): + self.state.push_token(int(input_ids[0, idx].item())) + self._last_len = cur_len + # Never stop on our account + return torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device) + + +def run_thinker(adapter, gen_kwargs, state: PipelineState, prompt_len: int): + """Executed in a background thread. 
Runs the full thinker.generate while + piping tokens + layer-23 hiddens into ``state``.""" + try: + def _cap_hook(_m, tensors): + if tensors: + state.push_layer23(tensors[0].clone().to("cpu")) + + from transformers.generation.stopping_criteria import StoppingCriteriaList + sc = StoppingCriteriaList([TokenStreamStoppingCriteria(state, prompt_len)]) + + gen_kwargs = dict(gen_kwargs) + gen_kwargs["tensor_capture_hook"] = _cap_hook + gen_kwargs["stopping_criteria"] = sc + out_ids = adapter.generate(**gen_kwargs) + with state.cond: + state.out_ids = out_ids + state.mark_done() + except Exception as e: + state.mark_done(exc=e) + + +# --------------------------------------------------------------------------- +# Pipelined talker setup — incremental trailing_text_hidden +# --------------------------------------------------------------------------- + +class StreamingTalkerInputs: + """Replaces the single-shot ``_build_talker_inputs``. + + Holds: + * ``talker_embed``: the user-parts + first-4-assistant-tokens buffer + (built once both are ready) + * ``trailing_text_hidden``: a tensor we grow by one row per newly-arrived + thinker assistant token past the 4th + * ``talker_input_ids``: mirrors talker_embed's sequence length + + Reads from ``PipelineState`` under its condition variable. 
+ """ + def __init__(self, hf_model, conv_inputs, state: PipelineState, speaker: str = "ethan"): + self.hf_model = hf_model + self.state = state + self.cfg = hf_model.config + self.speaker_id = self.cfg.talker_config.speaker_id[speaker.lower()] + self.conv_inputs = conv_inputs # original user-turn inputs + self.dtype = torch.bfloat16 + + # Pre-compute static parts + embed_layer = hf_model.thinker.get_input_embeddings() + talker_special = torch.tensor( + [[self.cfg.tts_bos_token_id, self.cfg.tts_eos_token_id, self.cfg.tts_pad_token_id]], + dtype=torch.long, + ) + with torch.no_grad(): + self.tts_bos_embed, self.tts_eos_embed, self.tts_pad_embed = ( + hf_model.talker.text_projection(embed_layer(talker_special)).chunk(3, dim=1) + ) + self._embed_layer = embed_layer + self._codec_special_tokens = torch.tensor( + [[ + self.cfg.talker_config.codec_nothink_id, + self.cfg.talker_config.codec_think_bos_id, + self.cfg.talker_config.codec_think_eos_id, + self.speaker_id, + self.cfg.talker_config.codec_pad_id, + self.cfg.talker_config.codec_bos_id, + ]], dtype=torch.long, + ) + self._codec_special_embeds = hf_model.talker.get_input_embeddings()( + self._codec_special_tokens + ).to(self.dtype) + + def build_prefill(self) -> tuple[torch.Tensor, torch.Tensor]: + """Block until 4 assistant tokens + all prefill USER hidden are ready, + then assemble the talker_embed / talker_input_ids for talker prefill. + + Returns (talker_input_embed, talker_input_ids). 
+ """ + # Wait for prefill hidden (USER-turn hidden needed for _get_talker_user_parts) + with self.state.cond: + while self.state.prefill_hidden is None and not self.state.thinker_done: + self.state.cond.wait() + if self.state.thinker_done and self.state.thinker_error is not None: + raise self.state.thinker_error + # Wait for 4 assistant tokens (needed for the prefill slice) + ok = self.state.wait_for_tokens(4) + if not ok: + raise RuntimeError("thinker did not produce 4 assistant tokens in time") + + # -- Build current thinker hidden for USER parts only -- + # At this point, prefill_hidden has the USER + system + assistant-role + # header embedded. Decode-step hiddens (one per assistant token) are + # appended to layer23_hidden[1:]. For the talker USER parts, we only + # need positions up to assistant_start_idx — all in prefill_hidden. + prompt_len = self.state.assistant_start_idx + prefill_h = self.state.prefill_hidden[:, :prompt_len, :].to(self.dtype) + + # -- Build current thinker_embed up to assistant_start + 4 tokens -- + with self.state.cond: + first_asst_ids = list(self.state.token_ids[:4]) + assistant_ids = torch.tensor([first_asst_ids], dtype=torch.long) + prompt_ids = self.conv_inputs.input_ids + all_ids = torch.cat([prompt_ids, assistant_ids], dim=1) + with torch.no_grad(): + thinker_embed = self._embed_layer(all_ids).to(self.dtype) + + cfg = self.cfg + im_start_indexes = torch.cat(( + torch.nonzero(all_ids[0] == cfg.im_start_token_id).squeeze(), + torch.tensor([all_ids.shape[-1]], dtype=all_ids.dtype), + ), dim=-1) + multimodal_mask = ( + (all_ids == cfg.thinker_config.audio_token_id) + | (all_ids == cfg.thinker_config.image_token_id) + | (all_ids == cfg.thinker_config.video_token_id) + ) + + # assistant_hidden for the prefill = text_projection of first-4 assistant embeddings + assistant_embed_first4 = thinker_embed[:, prompt_len:prompt_len + 4] + assistant_hidden_first4 = self.hf_model.talker.text_projection( + assistant_embed_first4 + 
).to(self.dtype) + + # --- USER parts --- + # Pad thinker_hidden out to all_ids length so _get_talker_user_parts + # can index it; only the USER positions within prompt are really used. + hidden_full = torch.zeros( + (1, all_ids.shape[1], prefill_h.shape[-1]), dtype=self.dtype, + ) + hidden_full[:, :prompt_len, :] = prefill_h + talker_embeds = [] + talker_ids_list = [] + for i in range(len(im_start_indexes) - 1): + ims = im_start_indexes[i] + segend = im_start_indexes[i + 1] + role = all_ids[0][ims + 1] + if role == cfg.system_token_id: + continue + if role == cfg.user_token_id: + part = self.hf_model._get_talker_user_parts( + ims, segend, multimodal_mask, hidden_full, thinker_embed, + ) + talker_embeds.append(part) + talker_ids_list.append(all_ids[:, ims:segend]) + # Assistant turn with im_start inside prompt (prior turns). For the + # final assistant turn (our generated one), we manually build below. + elif role == cfg.assistant_token_id: + # Only the very last im_start_index is our freshly-started + # assistant turn. Previous assistants are in prompt context + # and skipped (HF does the same in _build_talker_inputs). + if i == len(im_start_indexes) - 2: + # This is the new assistant turn — build from the first 4 + # tokens now available. 
+ assistant_text_hidden = torch.cat(( + assistant_hidden_first4[:, :3], + self.tts_pad_embed.expand(-1, 4, -1), + self.tts_bos_embed, + assistant_hidden_first4[:, 3:4], + ), dim=1) + assistant_codec_hidden = torch.cat(( + torch.zeros( + (1, 3, cfg.talker_config.text_config.hidden_size), + dtype=self.dtype, + ), + self._codec_special_embeds, + ), dim=1) + input_embeds_asst = assistant_text_hidden + assistant_codec_hidden + input_ids_asst = torch.full( + (1, assistant_text_hidden.shape[1]), + fill_value=cfg.tts_pad_token_id, dtype=torch.long, + ) + talker_embeds.append(input_embeds_asst) + talker_ids_list.append(input_ids_asst) + # else: prior assistant turns — skipped (HF also skips them) + + talker_embed = torch.cat(talker_embeds, dim=1) + talker_input_ids = torch.cat(talker_ids_list, dim=1) + return talker_embed, talker_input_ids + + def get_trailing_slice(self, k: int) -> torch.Tensor: + """Return the k-th entry of trailing_text_hidden. + + trailing_text_hidden[k] corresponds to the (k+4)-th thinker assistant + embedding (via text_projection), per HF's assembly. For the tail + position past all thinker tokens, return tts_eos_embed. + """ + needed = k + 5 # need tokens[0..k+4], which is k+5 tokens + # Block until the (k+4)-th token is produced, or thinker finishes. 
+ ok = self.state.wait_for_tokens(needed, timeout=60.0) + with self.state.cond: + n = len(self.state.token_ids) + done = self.state.thinker_done + if n >= k + 5: + tok_id = self.state.token_ids[k + 4] + # text_projection of the embedding for this single token + with torch.no_grad(): + e = self._embed_layer(torch.tensor([[tok_id]], dtype=torch.long)).to(self.dtype) + h = self.hf_model.talker.text_projection(e).to(self.dtype) + return h # shape [1, 1, 1024] + # Past end of thinker output → tts_eos_embed + return self.tts_eos_embed + + +def install_pipelined_prepare_inputs(hf_model, sti: StreamingTalkerInputs): + """Wrap the current ``Qwen3OmniMoeTalkerForConditionalGeneration.prepare_inputs_for_generation`` + (possibly already patched by ``install_streaming_talker_hook``) with a + layer that rewrites ``kwargs["trailing_text_hidden"]`` before the call so + that each decode step's indexing picks up a streaming-sourced row. + """ + from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import ( + Qwen3OmniMoeTalkerForConditionalGeneration as Cls, + ) + # Wrap WHATEVER is currently set (may be HF original, may be the streaming + # hook's wrapper). We only install once per run — save and restore below. + prev_prep = Cls.prepare_inputs_for_generation + step_counter = {"n": 0} + + @functools.wraps(prev_prep) + def patched(self_talker, input_ids, *args, **kwargs): + # HF's talker prepare_inputs reads kwargs["trailing_text_hidden"] as a + # dense pre-built tensor indexed by ``generation_step``. We replace + # that slot on the fly with a tensor whose ``[:, gen_step]`` row is + # fetched from the streaming thinker output (blocks if not yet + # produced). Non-decode calls (prefill) have no such indexing and + # pass-through unchanged. 
+ if "trailing_text_hidden" in kwargs: + gen_step = kwargs.get("generation_step") + if gen_step is None: + gen_step = step_counter["n"] + if gen_step is not None and gen_step >= 0: + slice_h = sti.get_trailing_slice(gen_step) # [1, 1, hidden] + trailing = kwargs["trailing_text_hidden"] + if trailing is None or trailing.shape[1] <= gen_step: + # Build a minimal tensor sized to the current step. + hidden_dim = slice_h.shape[-1] + fresh = torch.zeros((1, gen_step + 1, hidden_dim), dtype=slice_h.dtype) + fresh[:, gen_step:gen_step + 1, :] = slice_h + kwargs["trailing_text_hidden"] = fresh + else: + trailing = trailing.clone() + trailing[:, gen_step:gen_step + 1, :] = slice_h + kwargs["trailing_text_hidden"] = trailing + + out = prev_prep(self_talker, input_ids, *args, **kwargs) + step_counter["n"] += 1 + return out + + Cls.prepare_inputs_for_generation = patched + + def teardown(): + Cls.prepare_inputs_for_generation = prev_prep + step_counter["n"] = 0 + return teardown + + +# --------------------------------------------------------------------------- +# Main run_one / main +# --------------------------------------------------------------------------- + +def run_one(adapter, processor, hf_model, shim, ucp, conv, idx, out_wav_dir, + max_thinker_tokens, max_talker_tokens, speaker="ethan"): + messages = build_messages(conv) + wav_path = messages[-1]["content"][0]["audio"] + audio_np, sr = sf.read(wav_path) + if audio_np.ndim == 2: + audio_np = audio_np.mean(axis=1) + audio_np = audio_np.astype(np.float32) + input_audio_s = float(len(audio_np) / sr) + + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + target_sr = getattr(processor.feature_extractor, "sampling_rate", 16000) + if sr != target_sr: + import librosa + audio_for_fe = librosa.resample(audio_np, orig_sr=sr, target_sr=target_sr) + else: + audio_for_fe = audio_np + inputs = processor(text=[text], audio=[audio_for_fe], return_tensors="pt", padding=True) + + # Streaming 
state & callback + stream_state = { + "codes_list": [], "per_step_times": [], + "emitted_up_to": 0, "chunk_index": 0, "c2w_per_chunk_s": [], + } + wav_chunks = [] + chunk_timing = [] + request_start = time.perf_counter() + + def on_audio(wav_np, chunk_index, c2w_ms, codec_tokens, final=False): + rel_t_ms = (time.perf_counter() - request_start) * 1000 + wav_chunks.append(wav_np) + chunk_timing.append({ + "idx": chunk_index, "t_ms": rel_t_ms, "c2w_ms": c2w_ms, + "codec_tokens": codec_tokens, + "wav_samples": int(len(wav_np)), "final": final, + }) + + STR.install_streaming_ucp(hf_model, ucp, stream_state) + STR.install_streaming_talker_hook(hf_model, stream_state, hf_model.code2wav, on_audio) + + # Pipeline state — the thinker thread will fill it + prompt_len = inputs.input_ids.shape[1] + state = PipelineState(assistant_start_idx=prompt_len) + sti = StreamingTalkerInputs(hf_model, inputs, state, speaker=speaker) + + # Thinker kwargs — run on bg thread + gc_cfg = GenerationConfig(do_sample=False, eos_token_id=[151645], pad_token_id=151645) + gen_kwargs = dict( + input_ids=inputs.input_ids, + attention_mask=inputs.attention_mask, + generation_config=gc_cfg, + max_new_tokens=max_thinker_tokens, + ) + if getattr(inputs, "input_features", None) is not None: + gen_kwargs["input_features"] = inputs.input_features.to(torch.bfloat16) + if getattr(inputs, "feature_attention_mask", None) is not None: + gen_kwargs["feature_attention_mask"] = inputs.feature_attention_mask + + t_thinker_start = time.perf_counter() + thinker_thread = threading.Thread( + target=run_thinker, args=(adapter, gen_kwargs, state, prompt_len), daemon=True, + ) + thinker_thread.start() + + # Build talker prefill inputs (blocks until 4 tokens + prefill hidden ready) + t_blk = time.perf_counter() + talker_embed, talker_input_ids = sti.build_prefill() + build_blocked_ms = (time.perf_counter() - t_blk) * 1000 + build_talker_s = time.perf_counter() - t_thinker_start + + # Install the pipelined prepare_inputs 
patch + teardown = install_pipelined_prepare_inputs(hf_model, sti) + + try: + # Talker config — match HF reference + talker_cfg = hf_model.config.talker_config + talker_vocab = talker_cfg.text_config.vocab_size + suppress_tokens = [ + i for i in range(talker_vocab - 1024, talker_vocab) + if i != talker_cfg.codec_eos_token_id + ] + + shim.reset_cache() + t0 = time.perf_counter() + hf_model.talker.generate( + inputs_embeds=talker_embed, + trailing_text_hidden=None, # patched prepare_inputs fills slot-wise + tts_pad_embed=sti.tts_pad_embed, + talker_input_ids=talker_input_ids, + max_new_tokens=max_talker_tokens, + do_sample=True, top_k=50, top_p=0.8, temperature=0.9, + repetition_penalty=1.1, suppress_tokens=suppress_tokens, + eos_token_id=talker_cfg.codec_eos_token_id, + output_hidden_states=True, + return_dict_in_generate=True, + ) + talker_s = time.perf_counter() - t0 + + # Emit residual codec tokens + STR.finalize_stream(stream_state, hf_model.code2wav, on_audio) + finally: + teardown() + thinker_thread.join(timeout=60.0) + + if state.thinker_error is not None: + raise state.thinker_error + + full_wav = np.concatenate(wav_chunks) if wav_chunks else np.zeros(0, dtype=np.float32) + out_wav_path = os.path.join(out_wav_dir, f"conv_{idx:03d}.wav") + sf.write(out_wav_path, full_wav, 24000) + + total_s = time.perf_counter() - request_start + ttfb_ms = chunk_timing[0]["t_ms"] if chunk_timing else None + wav_s = float(len(full_wav) / 24000) + rtf = total_s / wav_s if wav_s > 0 else None + + out_ids = getattr(state, "out_ids", None) + asst_text = "" + thinker_tokens = len(state.token_ids) + if out_ids is not None: + asst_text = processor.batch_decode( + out_ids[:, prompt_len:], + skip_special_tokens=True, clean_up_tokenization_spaces=False, + )[0].strip() + thinker_tokens = int(out_ids.shape[1] - prompt_len) + + return { + "idx": idx, + "wav_path": wav_path, + "input_audio_s": input_audio_s, + "prompt_tokens": int(prompt_len), + "thinker_tokens": thinker_tokens, + 
"build_talker_blocked_ms": build_blocked_ms, + "build_talker_total_s": build_talker_s, + "talker_s": talker_s, + "codec_tokens": int(len(stream_state["codes_list"])), + "num_chunks": len(chunk_timing), + "ttfb_ms": ttfb_ms, + "total_s": total_s, + "wav_s": wav_s, + "rtf": rtf, + "out_wav": out_wav_path, + "text": asst_text, + "chunks": chunk_timing, + } + + +def percentile(values, p): + if not values: + return float("nan") + s = sorted(values) + i = int(round((len(s) - 1) * p / 100)) + return s[i] + + +def print_summary(results): + ok = [r for r in results if "error" not in r] + print("\n=== SUMMARY ===") + print(f" samples ok: {len(ok)}/{len(results)}") + if not ok: + return + ttfbs = [r["ttfb_ms"] for r in ok if r.get("ttfb_ms") is not None] + rtfs = [r["rtf"] for r in ok if r.get("rtf") is not None] + total_ms = [r["total_s"] * 1000 for r in ok] + blocked_ms = [r["build_talker_blocked_ms"] for r in ok] + in_audio = [r["input_audio_s"] for r in ok] + out_wav = [r["wav_s"] for r in ok] + th_toks = [r["thinker_tokens"] for r in ok] + + def row(name, xs, fmt="{:6.0f}"): + s = statistics.mean(xs) + p50 = percentile(xs, 50) + p90 = percentile(xs, 90) + p95 = percentile(xs, 95) + print(f" {name:20s} mean={fmt.format(s)} p50={fmt.format(p50)} " + f"p90={fmt.format(p90)} p95={fmt.format(p95)}") + + row("TTFB ms", ttfbs) + row("blocked_build ms", blocked_ms) + row("total ms", total_ms) + row("RTF", rtfs, fmt="{:6.2f}") + row("input audio s", in_audio, fmt="{:6.2f}") + row("output wav s", out_wav, fmt="{:6.2f}") + row("thinker tokens", th_toks, fmt="{:6.0f}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--num", type=int, default=100) + parser.add_argument("--start", type=int, default=0) + parser.add_argument("--out", default="/tmp/qwen3_omni_pipelined.json") + parser.add_argument("--wav-dir", default="/tmp/qwen3_omni_pipelined_wavs") + parser.add_argument("--max-thinker", type=int, default=200) + parser.add_argument("--max-talker", type=int, 
default=500) + parser.add_argument("--speaker", default="ethan") + parser.add_argument("--neuron-c2w", action="store_true") + args = parser.parse_args() + + os.makedirs(args.wav_dir, exist_ok=True) + with open(CONV_JSON) as f: + conversations = json.load(f) + + print(f"Loaded {len(conversations)} conversations; running [{args.start}, {args.start + args.num})") + print("Building Neuron pipeline...") + adapter, processor, hf_model, shim, ucp = STR.build_all() + if args.neuron_c2w: + print("Installing Neuron code2wav shim...") + install_neuron_code2wav(hf_model) + print("Pipeline ready.\n") + + results = [] + bench_start = time.perf_counter() + for k in range(args.num): + idx = args.start + k + if idx >= len(conversations): + break + try: + r = run_one( + adapter, processor, hf_model, shim, ucp, + conversations[idx], idx, args.wav_dir, + max_thinker_tokens=args.max_thinker, + max_talker_tokens=args.max_talker, + speaker=args.speaker, + ) + results.append(r) + ttfb_str = f"{r['ttfb_ms']:5.0f}ms" if r['ttfb_ms'] is not None else " n/a" + print( + f"[{k+1:3d}/{args.num}] conv {idx:3d} " + f"in={r['input_audio_s']:4.1f}s " + f"prompt={r['prompt_tokens']:4d} " + f"blocked={r['build_talker_blocked_ms']:5.0f}ms " + f"new={r['thinker_tokens']:3d}tok " + f"TTFB={ttfb_str} " + f"total={r['total_s']*1000:5.0f}ms " + f"wav={r['wav_s']:4.1f}s " + f"RTF={r['rtf'] if r['rtf'] is not None else float('nan'):.2f} " + f"[{r['text'][:32]}]" + ) + except Exception as e: + traceback.print_exc() + print(f"[{k+1:3d}/{args.num}] conv {idx}: FAILED {e}") + results.append({"idx": idx, "error": str(e)}) + + with open(args.out, "w") as f: + json.dump({ + "cumulative_wall_s": time.perf_counter() - bench_start, + "results": results, + }, f, indent=2, ensure_ascii=False) + + print_summary(results) + print(f"\nJSON: {args.out}") + print(f"WAVs: {args.wav_dir}") + + +if __name__ == "__main__": + main() diff --git a/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test_ttfb_rtf_bench.py 
b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test_ttfb_rtf_bench.py new file mode 100644 index 00000000..818fb3b1 --- /dev/null +++ b/contrib/models/Qwen3-Omni-30B-A3B-Instruct/test_ttfb_rtf_bench.py @@ -0,0 +1,341 @@ +#!/usr/bin/env python3 +"""TTFB / RTF benchmark on /home/ubuntu/omni2 chat conversations 0-99. + +Per-conversation flow: + * system prompt + prior user/assistant turns (all plain text) + * FINAL user turn = audio (the JSON stores the wav path as its content) + * assistant reply is produced by thinker → talker (streaming) → code2wav (CPU) + +Metrics we compute: + * input_audio_s — duration of the audio user utterance + * ttfb_ms — request_start → first audio chunk delivered + * thinker_ms — thinker decode wall time + * total_ms — end-to-end (thinker → talker → finalize → stitch) + * wav_s — total emitted audio duration + * RTF — total_ms / wav_ms, lower = better (<1 = realtime) + +Usage: + source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate + NEURON_RT_VISIBLE_CORES=0-7 \ + python test_ttfb_rtf_bench.py --num 100 +""" +import os +os.environ.setdefault("NEURON_RT_VISIBLE_CORES", "0-7") +os.environ.setdefault("QWEN3_OMNI_CAPTURE_LAYER_HIDDEN", "23") +os.environ["TRANSFORMERS_VERBOSITY"] = "error" + +import sys +from pathlib import Path + +_HERE = Path(__file__).resolve().parent +# Qwen3-Omni model src lives in two locations; prefer the current project copy +# and fall back to whn-ndi (which has the identical files). Some operations +# (git branch switch) can remove the local src/ directory. +for _candidate in (_HERE / "src", Path("/home/ubuntu/whn-ndi/contrib/models/Qwen3-Omni-30B-A3B-Instruct/src")): + if (_candidate / "_upstream_compat.py").exists() and str(_candidate) not in sys.path: + sys.path.insert(0, str(_candidate)) + break +# The streaming/full-neuron helpers live at /home/ubuntu/*.py — make them importable. 
+if "/home/ubuntu" not in sys.path: + sys.path.insert(0, "/home/ubuntu") + +import _upstream_compat # noqa: F401 + +import argparse +import json +import statistics +import time +import traceback + +import numpy as np +import soundfile as sf +import torch +from transformers import GenerationConfig + +import test_audio_streaming as STR # build_all, install_*, _assemble_hidden, _build_talker_inputs, finalize_stream + +# Local helpers (co-located with this bench script) +sys.path.insert(0, str(_HERE)) +from code2wav_neuron import install_neuron_code2wav # noqa: E402 + +CONV_JSON = "/home/ubuntu/omni2/merged_conversations_with_audio_x10_with_system.json" +AUDIO_DIR = "/home/ubuntu/omni2/speech_wav_16k" + + +def build_messages(conv): + """Convert a JSON conversation into HF-style `messages` where the final user + turn becomes an audio block pointing at the stored wav's basename under AUDIO_DIR. + + The ground-truth assistant reply (last message) is dropped — we generate it. + """ + msgs = conv["messages"] + out = [] + for i, m in enumerate(msgs): + if i == len(msgs) - 1: + break # drop the reference assistant reply + role = m["role"] + content = m["content"] + if i == len(msgs) - 2 and role == "user": + fname = os.path.basename(content) + wav_path = os.path.join(AUDIO_DIR, fname) + out.append({"role": role, "content": [{"type": "audio", "audio": wav_path}]}) + else: + # Plain string content is valid in HF chat templates. 
+ out.append({"role": role, "content": content}) + return out + + +def run_one(adapter, processor, hf_model, shim, ucp, conv, idx, out_wav_dir, + max_thinker_tokens, max_talker_tokens, speaker="ethan"): + messages = build_messages(conv) + wav_path = messages[-1]["content"][0]["audio"] + audio_np, sr = sf.read(wav_path) + if audio_np.ndim == 2: + audio_np = audio_np.mean(axis=1) + audio_np = audio_np.astype(np.float32) + input_audio_s = float(len(audio_np) / sr) + + # --- Processor: chat template + feature extraction --- + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + target_sr = getattr(processor.feature_extractor, "sampling_rate", 16000) + if sr != target_sr: + import librosa + audio_for_fe = librosa.resample(audio_np, orig_sr=sr, target_sr=target_sr) + else: + audio_for_fe = audio_np + inputs = processor( + text=[text], audio=[audio_for_fe], + return_tensors="pt", padding=True, + ) + + # --- Streaming state & callback --- + stream_state = { + "codes_list": [], "per_step_times": [], + "emitted_up_to": 0, "chunk_index": 0, "c2w_per_chunk_s": [], + } + wav_chunks = [] + chunk_timing = [] + request_start = time.perf_counter() + + def on_audio(wav_np, chunk_index, c2w_ms, codec_tokens, final=False): + rel_t_ms = (time.perf_counter() - request_start) * 1000 + wav_chunks.append(wav_np) + chunk_timing.append({ + "idx": chunk_index, "t_ms": rel_t_ms, "c2w_ms": c2w_ms, + "codec_tokens": codec_tokens, + "wav_samples": int(len(wav_np)), "final": final, + }) + + STR.install_streaming_ucp(hf_model, ucp, stream_state) + STR.install_streaming_talker_hook(hf_model, stream_state, hf_model.code2wav, on_audio) + + # --- Thinker --- + gc_cfg = GenerationConfig(do_sample=False, eos_token_id=[151645], pad_token_id=151645) + captured = [] + + def _cap_hook(_m, tensors): + if tensors: + captured.append(tensors[0].clone().to("cpu")) + + gen_kwargs = dict( + input_ids=inputs.input_ids, + attention_mask=inputs.attention_mask, + 
generation_config=gc_cfg, + max_new_tokens=max_thinker_tokens, + tensor_capture_hook=_cap_hook, + ) + if getattr(inputs, "input_features", None) is not None: + gen_kwargs["input_features"] = inputs.input_features.to(torch.bfloat16) + if getattr(inputs, "feature_attention_mask", None) is not None: + gen_kwargs["feature_attention_mask"] = inputs.feature_attention_mask + + t0 = time.perf_counter() + out_ids = adapter.generate(**gen_kwargs) + thinker_s = time.perf_counter() - t0 + thinker_end_ms = (time.perf_counter() - request_start) * 1000 + thinker_hidden = STR._assemble_hidden(captured, inputs, out_ids) + thinker_tokens = int(out_ids.shape[1] - inputs.input_ids.shape[1]) + + # --- Build talker inputs (on CPU; uses thinker hidden state captured above) --- + t0 = time.perf_counter() + talker_embed, talker_id, tts_pad, trailing = STR._build_talker_inputs( + hf_model, out_ids, thinker_hidden, speaker=speaker, + ) + build_talker_s = time.perf_counter() - t0 + + # --- Talker (Neuron shim) + streaming code2wav fired from prepare_inputs hook --- + # Suppress non-codec vocab tokens so only codec ids (0..2047) + codec_eos + # (2150) can be picked. Matches HF's reference call in ``Qwen3OmniMoeForConditionalGeneration.generate``. 
+ talker_cfg = hf_model.config.talker_config + talker_vocab = talker_cfg.text_config.vocab_size + suppress_tokens = [ + i for i in range(talker_vocab - 1024, talker_vocab) + if i != talker_cfg.codec_eos_token_id + ] + + shim.reset_cache() + t0 = time.perf_counter() + hf_model.talker.generate( + inputs_embeds=talker_embed, + trailing_text_hidden=trailing, + tts_pad_embed=tts_pad, + talker_input_ids=talker_id, + max_new_tokens=max_talker_tokens, + do_sample=True, + top_k=50, + top_p=0.8, + temperature=0.9, + repetition_penalty=1.1, + suppress_tokens=suppress_tokens, + eos_token_id=talker_cfg.codec_eos_token_id, + output_hidden_states=True, + return_dict_in_generate=True, + ) + talker_s = time.perf_counter() - t0 + + # Emit any residual codec tokens that didn't fill a CHUNK_SIZE chunk. + STR.finalize_stream(stream_state, hf_model.code2wav, on_audio) + + full_wav = np.concatenate(wav_chunks) if wav_chunks else np.zeros(0, dtype=np.float32) + out_wav_path = os.path.join(out_wav_dir, f"conv_{idx:03d}.wav") + sf.write(out_wav_path, full_wav, 24000) + + total_s = time.perf_counter() - request_start + ttfb_ms = chunk_timing[0]["t_ms"] if chunk_timing else None + wav_s = float(len(full_wav) / 24000) + rtf = total_s / wav_s if wav_s > 0 else None + asst_text = processor.batch_decode( + out_ids[:, inputs.input_ids.shape[1]:], + skip_special_tokens=True, clean_up_tokenization_spaces=False, + )[0].strip() + + return { + "idx": idx, + "wav_path": wav_path, + "input_audio_s": input_audio_s, + "prompt_tokens": int(inputs.input_ids.shape[1]), + "thinker_tokens": thinker_tokens, + "thinker_s": thinker_s, + "thinker_end_ms": thinker_end_ms, + "build_talker_s": build_talker_s, + "talker_s": talker_s, + "codec_tokens": int(len(stream_state["codes_list"])), + "num_chunks": len(chunk_timing), + "ttfb_ms": ttfb_ms, + "total_s": total_s, + "wav_s": wav_s, + "rtf": rtf, + "out_wav": out_wav_path, + "text": asst_text, + "chunks": chunk_timing, + } + + +def percentile(values, p): + if not 
values: + return float("nan") + s = sorted(values) + i = int(round((len(s) - 1) * p / 100)) + return s[i] + + +def print_summary(results): + ok = [r for r in results if "error" not in r] + print("\n=== SUMMARY ===") + print(f" samples ok: {len(ok)}/{len(results)}") + if not ok: + return + ttfbs = [r["ttfb_ms"] for r in ok if r.get("ttfb_ms") is not None] + rtfs = [r["rtf"] for r in ok if r.get("rtf") is not None] + thinker_ms = [r["thinker_s"] * 1000 for r in ok] + total_ms = [r["total_s"] * 1000 for r in ok] + in_audio = [r["input_audio_s"] for r in ok] + out_wav = [r["wav_s"] for r in ok] + th_toks = [r["thinker_tokens"] for r in ok] + + def row(name, xs, fmt="{:6.0f}"): + s = statistics.mean(xs) + p50 = percentile(xs, 50) + p90 = percentile(xs, 90) + p95 = percentile(xs, 95) + print(f" {name:18s} mean={fmt.format(s)} p50={fmt.format(p50)} " + f"p90={fmt.format(p90)} p95={fmt.format(p95)}") + + row("TTFB ms", ttfbs) + row("thinker ms", thinker_ms) + row("total ms", total_ms) + row("RTF", rtfs, fmt="{:6.2f}") + row("input audio s", in_audio, fmt="{:6.2f}") + row("output wav s", out_wav, fmt="{:6.2f}") + row("thinker tokens", th_toks, fmt="{:6.0f}") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--num", type=int, default=100) + parser.add_argument("--start", type=int, default=0) + parser.add_argument("--out", default="/tmp/qwen3_omni_ttfb_rtf.json") + parser.add_argument("--wav-dir", default="/tmp/qwen3_omni_ttfb_rtf_wavs") + parser.add_argument("--max-thinker", type=int, default=200) + parser.add_argument("--max-talker", type=int, default=512) + parser.add_argument("--speaker", default="ethan") + parser.add_argument("--neuron-c2w", action="store_true", + help="Route code2wav through Neuron NEFFs (default: CPU)") + args = parser.parse_args() + + os.makedirs(args.wav_dir, exist_ok=True) + with open(CONV_JSON) as f: + conversations = json.load(f) + + print(f"Loaded {len(conversations)} conversations; " + f"running [{args.start}, 
{args.start + args.num})") + print("Building Neuron pipeline (thinker + audio + talker + UCP)...") + adapter, processor, hf_model, shim, ucp = STR.build_all() + if args.neuron_c2w: + print("Installing Neuron code2wav shim...") + install_neuron_code2wav(hf_model) + print("Pipeline ready.\n") + + results = [] + bench_start = time.perf_counter() + + for k in range(args.num): + idx = args.start + k + if idx >= len(conversations): + break + conv = conversations[idx] + try: + r = run_one( + adapter, processor, hf_model, shim, ucp, conv, idx, args.wav_dir, + max_thinker_tokens=args.max_thinker, + max_talker_tokens=args.max_talker, + speaker=args.speaker, + ) + results.append(r) + print(f"[{k+1:3d}/{args.num}] conv {idx:3d} " + f"in={r['input_audio_s']:4.1f}s " + f"prompt_tok={r['prompt_tokens']:4d} " + f"thinker={r['thinker_s']*1000:4.0f}ms/{r['thinker_tokens']:3d}tok " + f"ttfb={r['ttfb_ms'] if r['ttfb_ms'] is not None else float('nan'):5.0f}ms " + f"total={r['total_s']*1000:5.0f}ms " + f"wav={r['wav_s']:4.1f}s " + f"RTF={r['rtf'] if r['rtf'] is not None else float('nan'):.2f} " + f"[{r['text'][:36]}]") + except Exception as e: + traceback.print_exc() + print(f"[{k+1:3d}/{args.num}] conv {idx}: FAILED {e}") + results.append({"idx": idx, "error": str(e)}) + + with open(args.out, "w") as f: + json.dump({ + "cumulative_wall_s": time.perf_counter() - bench_start, + "results": results, + }, f, indent=2, ensure_ascii=False) + + print_summary(results) + print(f"\nFull JSON: {args.out}") + print(f"WAVs: {args.wav_dir}") + + +if __name__ == "__main__": + main()
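Reviewer note: `run_one` returns `ttfb_ms=None` and `rtf=None` when the talker emits no audio chunk, so anything that formats those fields needs a None-safe path. The sketch below restates the metric arithmetic from the end of `run_one` in isolation. `summarize_stream` and `fmt_or_na` are hypothetical names introduced here for illustration and are not part of the benchmark script; the 24 kHz rate is taken from the `sf.write(..., 24000)` call.

```python
from typing import List, Optional, Tuple

SAMPLE_RATE_HZ = 24000  # output rate used by sf.write() in run_one


def summarize_stream(
    chunk_t_ms: List[float], chunk_samples: List[int], total_s: float
) -> Tuple[Optional[float], float, Optional[float]]:
    """Derive (ttfb_ms, wav_s, rtf) from per-chunk arrival times and sizes.

    Hypothetical helper mirroring the arithmetic at the end of run_one.
    """
    # TTFB is the arrival time of the first chunk, or None if nothing was emitted.
    ttfb_ms = chunk_t_ms[0] if chunk_t_ms else None
    # Emitted audio duration in seconds.
    wav_s = sum(chunk_samples) / SAMPLE_RATE_HZ
    # RTF = wall time / audio duration; undefined when no audio was produced.
    rtf = total_s / wav_s if wav_s > 0 else None
    return ttfb_ms, wav_s, rtf


def fmt_or_na(value: Optional[float], spec: str, width: int = 5) -> str:
    """Format a metric with the given format spec, or right-align 'n/a' for None."""
    if value is None:
        return "n/a".rjust(width)
    return format(value, spec)
```

With this shape, a per-conversation log line can call `fmt_or_na(r["ttfb_ms"], "5.0f")` and `fmt_or_na(r["rtf"], ".2f")` without raising inside the `try` block, which would otherwise record a second, spurious error entry for a conversation whose result was already appended.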