Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
453 changes: 453 additions & 0 deletions contrib/models/Qwen3-Omni-30B-A3B-Instruct/BENCHMARK_OMNI2_TTFB.md

Large diffs are not rendered by default.

407 changes: 407 additions & 0 deletions contrib/models/Qwen3-Omni-30B-A3B-Instruct/README.md

Large diffs are not rendered by default.

104 changes: 104 additions & 0 deletions contrib/models/Qwen3-Omni-30B-A3B-Instruct/code2wav_neuron.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Runtime shim: replace ``hf_model.code2wav`` CPU calls with Neuron NEFFs.

Buckets are chosen at install time; at call time we pick the smallest bucket
>= T and pad the codec-tokens tensor up to it. The output is trimmed back to
``T * total_upsample`` samples to match CPU behavior.

Install once per process via ``install_neuron_code2wav(hf_model)``.
"""
import os
from pathlib import Path
from typing import List, Optional

import torch

DEFAULT_BUCKETS_DIR = Path("/tmp/qwen3_omni_compiled/code2wav_buckets")


class NeuronCode2WavShim(torch.nn.Module):
"""Holds one compiled NEFF per bucket size; dispatches on T at call time."""

def __init__(self, hf_c2w, buckets_dir: Path, buckets: Optional[List[int]] = None):
super().__init__()
# We want to keep ``config`` and ``total_upsample`` from the original so
# callers that read those still work (``chunked_decode`` uses
# ``self.total_upsample``).
self.hf_c2w = hf_c2w
self.config = hf_c2w.config
self.total_upsample = hf_c2w.total_upsample

found = {}
for f in sorted(buckets_dir.glob("model_T*.pt")):
# Parse T from filename "model_T{int}.pt"
T = int(f.stem.split("_T")[-1])
if buckets is None or T in buckets:
found[T] = f
if not found:
raise RuntimeError(f"No code2wav NEFFs found in {buckets_dir}")

self._neffs = {}
for T in sorted(found):
print(f" [code2wav shim] loading T={T} from {found[T]}")
self._neffs[T] = torch.jit.load(str(found[T]))
self._bucket_sizes = sorted(self._neffs.keys())
self._max_bucket = self._bucket_sizes[-1]

def _pick_bucket(self, T: int) -> int:
for b in self._bucket_sizes:
if b >= T:
return b
# T exceeds the largest bucket — fall back to CPU.
return -1

def forward(self, codes: torch.Tensor) -> torch.Tensor:
B, Q, T = codes.shape
bucket = self._pick_bucket(T)
if bucket == -1:
# No NEFF big enough: use CPU
return self.hf_c2w(codes)

if T == bucket:
padded_codes = codes
else:
# Right-pad with zeros (valid codec ids live in [0, codebook_size=2048))
pad_amount = bucket - T
pad = torch.zeros((B, Q, pad_amount), dtype=codes.dtype, device=codes.device)
padded_codes = torch.cat([codes, pad], dim=-1)

neuron = self._neffs[bucket]
wav = neuron(padded_codes)
# Output shape is (B, 1, bucket * total_upsample). Trim to real length.
real_samples = T * self.total_upsample
wav = wav[..., :real_samples]
return wav

# chunked_decode is inherited behavior on hf_c2w but our forward shim gets
# called with codes — we re-implement here for symmetry and to avoid HF
# accidentally calling the CPU forward.
def chunked_decode(self, codes: torch.Tensor, chunk_size: int = 300,
left_context_size: int = 25) -> torch.Tensor:
wavs = []
start_index = 0
while start_index < codes.shape[-1]:
end_index = min(start_index + chunk_size, codes.shape[-1])
context_size = left_context_size if start_index - left_context_size > 0 else start_index
codes_chunk = codes[..., start_index - context_size: end_index]
wav_chunk = self.forward(codes_chunk)
wavs.append(wav_chunk[..., context_size * self.total_upsample:])
start_index = end_index
return torch.cat(wavs, dim=-1)


def install_neuron_code2wav(
hf_model,
buckets_dir: Path = DEFAULT_BUCKETS_DIR,
buckets: Optional[List[int]] = None,
) -> NeuronCode2WavShim:
"""Replace ``hf_model.code2wav`` with a Neuron-backed shim.

Returns the shim (holding the original HF code2wav on ``.hf_c2w`` in case
callers want to fall back).
"""
shim = NeuronCode2WavShim(hf_model.code2wav, buckets_dir, buckets=buckets)
hf_model.code2wav = shim
return shim
47 changes: 47 additions & 0 deletions contrib/models/Qwen3-Omni-30B-A3B-Instruct/compile_audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/usr/bin/env python3
"""Compile Qwen3-Omni audio encoder transformer to a single Neuron core.

Conv2d frontend stays on CPU. Transformer layers + postprocessor are traced
per bucket size via torch_neuronx.trace (no TP).

Usage:
source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate
NEURON_RT_VISIBLE_CORES=16 python compile_audio.py
"""
import json
import logging
import sys
import time
from pathlib import Path

import torch

sys.path.insert(0, str(Path(__file__).parent / "src"))

logging.basicConfig(level=logging.INFO)

MODEL_PATH = "/home/ubuntu/models/Qwen3-Omni-30B-A3B-Instruct"
COMPILED_PATH = "/home/ubuntu/traced_model/Qwen3-Omni-audio"

from modeling_qwen3_omni_audio import Qwen3OmniAudioEncoder

config_path = Path(MODEL_PATH) / "config.json"
with open(config_path) as f:
full_config = json.load(f)

audio_config = full_config.get("thinker_config", {}).get("audio_config", {})
print(f"Audio config: d_model={audio_config.get('d_model')}, "
f"layers={audio_config.get('encoder_layers', audio_config.get('num_hidden_layers'))}, "
f"heads={audio_config.get('encoder_attention_heads')}")

print("Loading audio encoder weights...")
t0 = time.perf_counter()
encoder = Qwen3OmniAudioEncoder.from_pretrained(MODEL_PATH, audio_config)
print(f"Weights loaded in {time.perf_counter() - t0:.1f}s")

print(f"Compiling audio encoder to Neuron (buckets: {encoder.__class__.__name__})...")
t0 = time.perf_counter()
encoder.compile_neuron(COMPILED_PATH)
elapsed = time.perf_counter() - t0
print(f"Audio encoder compilation complete in {elapsed:.1f}s")
print(f"Compiled audio encoder saved to: {COMPILED_PATH}")
113 changes: 113 additions & 0 deletions contrib/models/Qwen3-Omni-30B-A3B-Instruct/compile_code2wav.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#!/usr/bin/env python3
"""Compile code2wav (vocoder) on Neuron.

code2wav is the ConvNeXt + upsample + BigVGAN stack that maps 16-channel codec
tokens to 24 kHz audio. It ran on CPU in the streaming bench, spending ~390 ms
on the first chunk and blocking TTFB.

The model is a fixed-size graph given a fixed input length T (in codec tokens).
We trace one NEFF per bucket and dispatch at runtime by rounding T up to the
next bucket. Compile via ``torch_neuronx.trace`` (not the SPMD ModelBuilder) —
single-core, fp32 weights, no tensor parallelism.

Output: /tmp/qwen3_omni_compiled/code2wav_buckets/model_T{T}.pt for each T.

Usage:
source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate
NEURON_RT_VISIBLE_CORES=0-7 python compile_code2wav.py
"""
import os
os.environ.setdefault("NEURON_RT_VISIBLE_CORES", "0-7")
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

import argparse
import time
from pathlib import Path

import torch
import torch_neuronx
from transformers import Qwen3OmniMoeForConditionalGeneration

MODEL_PATH = "/home/ubuntu/models/Qwen3-Omni-30B-A3B-Instruct"
# Bucket set tuned for the streaming bench:
# * streaming chunk: CHUNK_SIZE + LEFT_CTX = 25 + 5 = 30
# * finalize chunk (tail): LEFT_CTX + 0..CHUNK_SIZE-1 ≤ 30
# * non-streaming `chunked_decode` default: chunk_size=300 + left_context=25
# We cover the streaming sizes and a large bucket for safety.
DEFAULT_BUCKETS = [30, 50, 128, 300, 512]


class Code2WavWrapper(torch.nn.Module):
"""Wraps ``Qwen3OmniMoeCode2Wav.forward`` so it is trace-friendly.

The original forward does a shape check that raises a Python error if
codes.shape[1] != num_quantizers. We keep that check out of the trace
(it's a static invariant) and only expose the compute.
"""

def __init__(self, c2w):
super().__init__()
self.c2w = c2w

def forward(self, codes):
# codes: [1, num_quantizers=16, T], long
c2w = self.c2w
hidden = c2w.code_embedding(codes + c2w.code_offset).mean(1)
hidden = c2w.pre_transformer(inputs_embeds=hidden).last_hidden_state
hidden = hidden.permute(0, 2, 1)
for blocks in c2w.upsample:
for block in blocks:
hidden = block(hidden)
wav = hidden
for block in c2w.decoder:
wav = block(wav)
return wav.clamp(min=-1, max=1)


def compile_one(c2w_wrapper, T, out_path):
example = torch.randint(0, 2048, (1, 16, T), dtype=torch.long)
print(f" tracing T={T} ...")
t0 = time.time()
traced = torch_neuronx.trace(
c2w_wrapper,
example,
compiler_workdir=f"/tmp/c2w_workdir_T{T}",
# fp32 for correctness; c2w is fairly small so cost is modest.
compiler_args="--auto-cast=none",
)
traced.save(str(out_path))
# Quick sanity: run once
out = traced(example)
print(f" done in {time.time()-t0:.0f}s, out shape={tuple(out.shape)}")


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--out-dir", default="/tmp/qwen3_omni_compiled/code2wav_buckets")
parser.add_argument("--buckets", nargs="*", type=int, default=DEFAULT_BUCKETS)
args = parser.parse_args()

out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)

print(f"Loading HF model (we only need .code2wav) ...")
t0 = time.time()
hf_model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
MODEL_PATH, dtype=torch.float32, low_cpu_mem_usage=True, device_map="cpu",
)
hf_model.eval()
print(f" loaded in {time.time()-t0:.0f}s")

wrapper = Code2WavWrapper(hf_model.code2wav).eval()

for T in args.buckets:
out_path = out_dir / f"model_T{T}.pt"
if out_path.exists():
print(f"T={T}: already compiled at {out_path}, skipping")
continue
print(f"T={T}: compiling to {out_path}")
compile_one(wrapper, T, out_path)


if __name__ == "__main__":
main()
59 changes: 59 additions & 0 deletions contrib/models/Qwen3-Omni-30B-A3B-Instruct/compile_multimodal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#!/usr/bin/env python3
"""Compile Qwen3-Omni multimodal model (text MoE + vision encoder) for Neuron.

Both text and vision models use TP=16 with LNC=2, running on 32 physical cores.

Usage:
source /opt/aws_neuronx_venv_pytorch_2_9_nxd_inference/bin/activate
NEURON_RT_VISIBLE_CORES=0-31 python compile_multimodal.py
"""
import sys
import time
from pathlib import Path

import torch

sys.path.insert(0, str(Path(__file__).parent / "src"))

from modeling_qwen3_omni import (
NeuronQwen3OmniForCausalLM,
Qwen3OmniInferenceConfig,
load_qwen3_omni_multimodal_config,
)
from neuronx_distributed_inference.models.config import MoENeuronConfig, NeuronConfig

MODEL_PATH = "/home/ubuntu/models/Qwen3-Omni-30B-A3B-Instruct"
COMPILED_PATH = "/home/ubuntu/traced_model/Qwen3-Omni-multimodal"
TP_DEGREE = 16

text_neuron_config = MoENeuronConfig(
tp_degree=TP_DEGREE,
batch_size=1,
seq_len=4096,
max_context_length=2048,
torch_dtype=torch.bfloat16,
on_device_sampling_config={"top_k": 1, "do_sample": False},
blockwise_matmul_config={"use_torch_block_wise": True},
)

vision_neuron_config = NeuronConfig(
tp_degree=TP_DEGREE,
batch_size=1,
seq_len=4096,
torch_dtype=torch.bfloat16,
)

config = Qwen3OmniInferenceConfig(
text_neuron_config=text_neuron_config,
vision_neuron_config=vision_neuron_config,
load_config=load_qwen3_omni_multimodal_config(MODEL_PATH),
)

model = NeuronQwen3OmniForCausalLM(MODEL_PATH, config)

print(f"Compiling multimodal model with TP={TP_DEGREE} ...")
t0 = time.perf_counter()
model.compile(COMPILED_PATH)
elapsed = time.perf_counter() - t0
print(f"Compilation complete in {elapsed:.1f}s")
print(f"Compiled model saved to: {COMPILED_PATH}")
Loading