From 74fdc7983e569dd1a25b6bf3563b67f718ba6d2b Mon Sep 17 00:00:00 2001 From: Jim Burtoft Date: Sat, 23 May 2026 11:42:06 -0400 Subject: [PATCH] Add GLM-4.7-Flash contrib model GLM-4.7-Flash (zai-org/GLM-4.7-Flash) is a 30B-A3B MoE model using DeepSeek-V3-style Multi-head Latent Attention (MLA) with compressed KV cache, 64 routed experts with sigmoid routing, and shared expert. Key features: - MLA attention with 94% KV cache reduction (576 dims vs 10,240) - FP8 E4M3 quantization for MoE expert weights - NKI bwmm_shard_on_block CTE kernel for optimized prefill - 16K context support on trn2.3xlarge TP=4 LNC=2 - vLLM 0.16.0 serving support (48 tok/s at concurrency=4) - 51.7 tok/s throughput at BS=4, 99.8 tok/s at BS=16/SEQ=4096 Validated on trn2.3xlarge with SDK 2.29.1 (neuronx-cc 2.24.8799). Includes 4 unit tests (CPU) and 5 integration tests (Neuron device). --- contrib/models/GLM-4.7-Flash/README.md | 270 ++++ contrib/models/GLM-4.7-Flash/src/__init__.py | 28 + contrib/models/GLM-4.7-Flash/src/compat.py | 516 +++++++ .../src/modeling_glm4_moe_lite.py | 1272 +++++++++++++++++ contrib/models/GLM-4.7-Flash/src/rope_util.py | 61 + contrib/models/GLM-4.7-Flash/test/__init__.py | 0 .../test/integration/__init__.py | 0 .../test/integration/compile_fp8.py | 583 ++++++++ .../test/integration/test_model.py | 326 +++++ .../GLM-4.7-Flash/test/unit/__init__.py | 0 .../GLM-4.7-Flash/test/unit/test_config.py | 164 +++ .../GLM-4.7-Flash/test/unit/test_rope.py | 110 ++ .../GLM-4.7-Flash/test/unit/test_router.py | 228 +++ .../test/unit/test_weight_conversion.py | 288 ++++ 14 files changed, 3846 insertions(+) create mode 100644 contrib/models/GLM-4.7-Flash/README.md create mode 100644 contrib/models/GLM-4.7-Flash/src/__init__.py create mode 100644 contrib/models/GLM-4.7-Flash/src/compat.py create mode 100644 contrib/models/GLM-4.7-Flash/src/modeling_glm4_moe_lite.py create mode 100644 contrib/models/GLM-4.7-Flash/src/rope_util.py create mode 100644 contrib/models/GLM-4.7-Flash/test/__init__.py create mode 100644 contrib/models/GLM-4.7-Flash/test/integration/__init__.py create mode 100644 contrib/models/GLM-4.7-Flash/test/integration/compile_fp8.py create mode 100644 contrib/models/GLM-4.7-Flash/test/integration/test_model.py create mode 100644 contrib/models/GLM-4.7-Flash/test/unit/__init__.py create mode 100644 contrib/models/GLM-4.7-Flash/test/unit/test_config.py create mode 100644 contrib/models/GLM-4.7-Flash/test/unit/test_rope.py create mode 100644 contrib/models/GLM-4.7-Flash/test/unit/test_router.py create mode 100644 contrib/models/GLM-4.7-Flash/test/unit/test_weight_conversion.py diff --git a/contrib/models/GLM-4.7-Flash/README.md b/contrib/models/GLM-4.7-Flash/README.md new file mode 100644 index 00000000..738ebade --- /dev/null +++ b/contrib/models/GLM-4.7-Flash/README.md @@ -0,0 +1,270 @@ +# Contrib Model: GLM-4.7-Flash + +Neuron inference support for [GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) (30B-A3B MoE) using NxD Inference. This model uses DeepSeek-V3-style Multi-head Latent Attention (MLA) with compressed KV cache, 64 routed experts with sigmoid routing, and supports up to 16K context on trn2.3xlarge. + +## Model Information + +- **HuggingFace ID:** `zai-org/GLM-4.7-Flash` +- **Model Type:** Decoder-only MoE transformer with MLA attention +- **Parameters:** ~31B total, ~3B active per token (64 experts, top-4 routing) +- **Architecture:** MLA (Multi-head Latent Attention), MoE with sigmoid routing + shared expert, SiLU activation, RMSNorm, standard RoPE +- **License:** MIT + +## Validation Results + +**Validated:** 2026-05-23 +**Instance:** trn2.3xlarge (TP=4, LNC=2) +**SDK:** Neuron SDK 2.29.1 (neuronx-cc 2.24.8799, NxDI 0.9.17334) + +### Benchmark Results + +Configuration: BS=4, SEQ_LEN=16384, FP8 expert quantization, NKI `bwmm_shard_on_block` CTE kernel + +| Prompt Length | TTFT | TPOT | Throughput (batch total) | +|--------------|------|------|------------------------| +| 128 tokens | 758 ms | 77 ms | 51.7 tok/s | +| 1,024 tokens | 1,189 ms | 77 ms | 51.7 tok/s | +| 4,096 tokens | 5,012 ms | 78 ms | 51.6 tok/s | +| 8,192 tokens | 11,374 ms | 81 ms | 49.5 tok/s | +| 16,000 tokens | 11,379 ms | 81 ms | 49.6 tok/s | + +### Batch Size Scaling + +| Batch Size | Seq Len | TPOT | Throughput | SDK Required | +|-----------|---------|------|-----------|--------------| +| 4 | 16384 | 77 ms | 51.7 tok/s | 2.29.1 | +| 8 | 4096 | 105.6 ms | 75.7 tok/s | 2.29.1 | +| 16 | 4096 | 160.3 ms | 99.8 tok/s | 2.29.1 | + +### vLLM Serving Performance + +Configuration: BS=4, SEQ_LEN=16384, vLLM 0.16.0 + vllm-neuron 0.5.0 + +| Concurrency | Output Throughput | Mean Latency | TPOT | +|-------------|------------------:|-------------:|-----:| +| 1 | 12.7 tok/s | 10.1s | 78 ms | +| 2 | 24.8 tok/s | 10.3s | 78 ms | +| 4 | 48.1 tok/s | 10.6s | 78 ms | + +### Accuracy Validation + +First-token accuracy against CPU FP32 reference (greedy/top-k=1): + +| Prompt | Expected Token | Got Token | Status | +|--------|---------------|-----------|--------| +| "The capital of France is" | " Paris" (12089) | " Paris" (12089) | EXACT MATCH | +| "In machine learning, a transformer model" | " is" (374) | " is" (374) | EXACT MATCH | +| "def fibonacci(n):" | "\n" (715) | "\n" (715) | EXACT MATCH | + +Multi-token generation produces coherent, non-repetitive text at 16K context. + +## Usage + +```python +import os +import torch +from neuronx_distributed_inference.models.config import MoENeuronConfig, OnDeviceSamplingConfig +from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config +from transformers import AutoConfig, AutoTokenizer, GenerationConfig +from transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig +from src import ( + Glm4MoeLiteGenerationAdapter, + Glm4MoeLiteInferenceConfig, + NeuronGlm4MoeLiteForCausalLM, +) + +os.environ["NEURON_RT_VISIBLE_CORES"] = "0-3" + +# Register glm4_moe_lite model type (not yet in transformers registry) +class Glm4MoeLiteConfig(Glm4MoeConfig): + model_type = "glm4_moe_lite" + +AutoConfig.register("glm4_moe_lite", Glm4MoeLiteConfig) + +MODEL_PATH = "/path/to/GLM-4.7-Flash" +COMPILED_PATH = "/path/to/compiled_glm4" + +# Load HF config +hf_config = AutoConfig.from_pretrained(MODEL_PATH) + +# Configure for Neuron +neuron_config = MoENeuronConfig( + tp_degree=4, + batch_size=4, + ctx_batch_size=1, + tkg_batch_size=4, + seq_len=16384, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + enable_bucketing=True, + flash_decoding_enabled=False, + logical_nc_config=2, +) + +inf_config = Glm4MoeLiteInferenceConfig( + neuron_config, load_config=load_pretrained_config(hf_config=hf_config) +) + +# Compile (first time only) +model = NeuronGlm4MoeLiteForCausalLM(MODEL_PATH, inf_config) +model.compile(COMPILED_PATH) + +# Load compiled model +model = NeuronGlm4MoeLiteForCausalLM(COMPILED_PATH, inf_config) +model.load(COMPILED_PATH) + +# Generate (MUST use Glm4MoeLiteGenerationAdapter for transformers >= 5.0) +gen_model = Glm4MoeLiteGenerationAdapter(model) +tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) +tokenizer.pad_token = tokenizer.eos_token + +gen_config = GenerationConfig( + do_sample=True, top_k=1, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, +) + +prompt = "The capital of France is" +inputs = tokenizer([prompt] * 4, return_tensors="pt", padding=True) +outputs = gen_model.generate( + inputs.input_ids, + generation_config=gen_config, + attention_mask=inputs.attention_mask, + max_new_tokens=50, +) +print(tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)) +# Output: "Paris. The capital of Belgium is Brussels..." +``` + +### FP8 Quantized Inference + +For FP8 expert quantization (reduces memory, improves throughput): + +```python +# Pre-quantize expert weights (one-time step) +# Use scripts/quantize_experts_fp8.py to generate FP8 checkpoint + +neuron_config = MoENeuronConfig( + tp_degree=4, + batch_size=4, + ctx_batch_size=1, + tkg_batch_size=4, + seq_len=16384, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + enable_bucketing=True, + flash_decoding_enabled=False, + logical_nc_config=2, + # FP8 configuration + quantized=True, + quantization_type="expert_wise_per_channel_symmetric", + quantization_dtype="f8e4m3", + quantized_checkpoints_path="/path/to/GLM-4.7-Flash-FP8", + modules_to_not_convert=[ + "lm_head", "embed_tokens", "self_attn", "norm", + "layers.0.mlp", "shared_experts", "router", + ], + moe_fused_nki_kernel_enabled=True, +) +``` + +### vLLM Serving + +```bash +source /opt/aws_neuronx_venv_pytorch_inference_vllm_0_16/bin/activate +export PYTHONPATH=/path/to/GLM-4.7-Flash-contrib:$PYTHONPATH +export NEURON_RT_VISIBLE_CORES=0-3 +export UNSAFE_FP8FNCAST=1 +export NEURON_COMPILED_ARTIFACTS=/path/to/compiled_model + +vllm serve /path/to/GLM-4.7-Flash \ + --tensor-parallel-size 4 --dtype bfloat16 --block-size 128 \ + --max-model-len 16384 --max-num-seqs 4 \ + --additional-config '{"override_neuron_config": {"quantized": true, "quantization_type": "expert_wise_per_channel_symmetric", "quantization_dtype": "f8e4m3", "quantized_checkpoints_path": "/path/to/GLM-4.7-Flash-FP8", "modules_to_not_convert": ["lm_head", "embed_tokens", "self_attn", "norm", "layers.0.mlp", "shared_experts", "router"], "moe_fused_nki_kernel_enabled": true, "logical_nc_config": 2, "flash_decoding_enabled": false}}' +``` + +**Note:** vLLM integration requires: +1. Changing `model_type` in config.json from `glm4_moe_lite` to `glm4_moe` (for HF AutoConfig compatibility with transformers 4.57.6) +2. Registering `glm4moelite` in NxDI's `constants.py` MODEL_TYPES dict +3. Removing `auto_map` field and custom `.py` files from model directory + +## Compatibility Matrix + +| Instance | SDK 2.29.1 | SDK 2.29 | +|----------|-----------|----------| +| trn2.3xlarge (TP=4, LNC=2) | **VALIDATED** (BS=4-16) | VALIDATED (BS=4-8 only) | +| trn2.48xlarge | Not tested | Not tested | + +## Example Checkpoints + +* [zai-org/GLM-4.7-Flash](https://huggingface.co/zai-org/GLM-4.7-Flash) (BF16, 59 GB) + +## Testing Instructions + +```bash +# Unit tests (no Neuron device required): +python -m pytest test/unit/ -v + +# Integration tests (requires trn2.3xlarge with model weights): +GLM4_MODEL_PATH=/path/to/GLM-4.7-Flash \ +GLM4_COMPILED_PATH=/path/to/compiled_glm4 \ +pytest test/integration/test_model.py --capture=tee-sys +``` + +## Known Issues + +### transformers 5.x position_ids Compatibility + +NxDI's `HuggingFaceGenerationAdapter.prepare_inputs_for_generation` only recomputes `position_ids` when they are `None` in kwargs. In transformers >= 5.0, `_update_model_kwargs_for_generation` passes stale position_ids back, breaking decode. + +**Fix**: Use `Glm4MoeLiteGenerationAdapter` (included) which removes stale `position_ids` from kwargs. This issue affects all NxDI contrib models with transformers >= 5.0. + +### Minimum batch_size=4 (Compiler Workaround) + +The Neuron compiler has an internal issue (NCC_IBIR297) that causes compilation failure with `tkg_batch_size=1` and blockwise MoE at small TP degrees. Workaround: set `batch_size >= 4`. + +### BS > 8 requires SDK 2.29.1 + +On SDK 2.29 (neuronx-cc 2.24.5133), batch sizes > 8 cause a runtime DGE scatter/gather out-of-bounds error. Fixed in SDK 2.29.1 (neuronx-cc 2.24.8799). + +### Maximum context length + +- **16,384 tokens**: Maximum validated CTE bucket (uses 24 GB HBM budget cleanly) +- **32,768 tokens**: OOM by 0.25 GB per LNC=2 core +- Longer contexts would require LNC=1 (48 GB per core) or FP8 KV cache + +### `glm4_moe_lite` model_type registration + +The `glm4_moe_lite` model_type is not registered in transformers 4.57.6. Register it manually using the pattern in the Usage section. For vLLM, use `model_type: "glm4_moe"` in config.json instead. + +## Architecture Details + +### MLA (Multi-head Latent Attention) + +- **KV cache compression**: Stores only 576 dims per position (kv_lora_rank=512 + qk_rope_head_dim=64) vs 10,240 for standard MHA — 94% reduction +- **Weight absorption trick**: kv_b_proj weights absorbed into query-side computation to avoid decompression at decode time +- **Critical dimension fix**: `out_absorb = wkv_b[:, qk_nope_head_dim:, :]` (split at 192, not v_head_dim=256) + +### MoE Configuration + +- 64 routed experts, top-4 selection +- Sigmoid activation with e_score_correction_bias +- L1-normalized affinities, scaled by factor 1.8 +- 1 shared expert (always-active dense MLP) +- Layer 0 is fully dense, layers 1-46 are MoE +- FP8 E4M3 quantization for routed expert weights (attention/embeddings remain BF16) + +### Compiler Flags + +``` +--enable-saturate-infinity --enable-mixed-precision-accumulation --model-type transformer -O1 +--tensorizer-options='--vectorize-strided-dma' +--auto-cast=none +--lnc=2 +``` + +**Important**: Do NOT add `--enable-ccop-compute-overlap` or `--cc-pipeline-tiling-factor` — these cause an ICE and 4.8x performance degradation. + +## Maintainer + +Jim Burtoft diff --git a/contrib/models/GLM-4.7-Flash/src/__init__.py b/contrib/models/GLM-4.7-Flash/src/__init__.py new file mode 100644 index 00000000..64fd91b7 --- /dev/null +++ b/contrib/models/GLM-4.7-Flash/src/__init__.py @@ -0,0 +1,28 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from src.modeling_glm4_moe_lite import ( + Glm4MoeLiteAttention, + Glm4MoeLiteDenseMLP, + Glm4MoeLiteGenerationAdapter, + Glm4MoeLiteInferenceConfig, + Glm4MoeLiteNeuronConfig, + Glm4MoeLiteRouter, + NeuronGlm4MoeLiteDecoderLayer, + NeuronGlm4MoeLiteForCausalLM, + NeuronGlm4MoeLiteModel, + custom_compiler_args, +) + +__all__ = [ + "Glm4MoeLiteAttention", + "Glm4MoeLiteDenseMLP", + "Glm4MoeLiteGenerationAdapter", + "Glm4MoeLiteInferenceConfig", + "Glm4MoeLiteNeuronConfig", + "Glm4MoeLiteRouter", + "NeuronGlm4MoeLiteDecoderLayer", + "NeuronGlm4MoeLiteForCausalLM", + "NeuronGlm4MoeLiteModel", + "custom_compiler_args", +] diff --git a/contrib/models/GLM-4.7-Flash/src/compat.py b/contrib/models/GLM-4.7-Flash/src/compat.py new file mode 100644 index 00000000..451748a4 --- /dev/null +++ b/contrib/models/GLM-4.7-Flash/src/compat.py @@ -0,0 +1,516 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +NKI kernel compatibility patch for GLM-4.7-Flash on SDK 2.29. + +SDK 2.29 removed the `neuronxcc.nki._private.blockwise_mm` module, leaving +`_call_shard_hidden_kernel` in NxD as a stub. This patch restores it using +the nkilib `blockwise_mm_baseline_shard_hidden` kernel (BF16 path only). + +Usage: + import src.compat # patches are applied on import + + # Then set use_torch_block_wise=False in your config to use the NKI kernel: + # config.neuron_config.blockwise_matmul_config.use_torch_block_wise = False + +This enables the NKI-optimized MoE CTE (context encoding) path, which should +provide significantly better performance than the torch fallback for blockwise +MoE computation. + +Based on the patching pattern from MiniMax-M2 contrib (compat.py). +""" + +import importlib +import logging + +import torch + +logger = logging.getLogger(__name__) + + +def _patch_blockwise_shard_hidden(): + """Patch NxD blockwise.py _call_shard_hidden_kernel from nkilib. + + GLM-4.7-Flash uses native BF16 weights — no FP8 dequant needed. + We only restore the shard_hidden kernel call for the CTE path. + """ + try: + import neuronx_distributed.modules.moe.blockwise as bw + except ImportError: + logger.debug( + "neuronx_distributed.modules.moe.blockwise not available, skipping patch" + ) + return False + + # Check if the function is a stub (raises NotImplementedError) + try: + bw._call_shard_hidden_kernel(None) + except NotImplementedError: + pass # Confirmed stub, proceed with patch + except (TypeError, AttributeError): + logger.debug("_call_shard_hidden_kernel appears functional, skipping patch") + return False + + try: + mod = importlib.import_module("nkilib.experimental.moe.forward.bwmm_shard_on_H") + kernel_fn = getattr(mod, "blockwise_mm_baseline_shard_hidden") + + import nki + + wrapped_kernel = nki.jit(kernel_fn) + bw._blockwise_mm_baseline_shard_hidden_nki_call = wrapped_kernel + + def _call_shard_hidden_kernel_patched(args): + """Call the nkilib shard_hidden kernel for blockwise matmul (BF16).""" + output = wrapped_kernel[2]( + hidden_states=args.hidden_states, + expert_affinities_masked=args.expert_affinities_masked, + gate_up_proj_weight=args.gate_up_proj_weight, + down_proj_weight=args.down_proj_weight, + block_size=args.block_size, + token_position_to_id=args.token_position_to_id.to(dtype=torch.int32), + block_to_expert=args.block_to_expert.to(dtype=torch.int32), + gate_up_activations_T=args.gate_up_activations_T, + down_activations=args.down_activations, + skip_dma=args.skip_dma, + is_tensor_update_accumulating=args.is_tensor_update_accumulating, + expert_affinities_scaling_mode=args.expert_affinities_scaling_mode, + ) + return output, args.gate_up_activations_T, args.down_activations + + bw._call_shard_hidden_kernel = _call_shard_hidden_kernel_patched + logger.info( + "Patched NxD blockwise._call_shard_hidden_kernel with nkilib kernel" + ) + return True + + except Exception as e: + logger.warning(f"Failed to patch blockwise._call_shard_hidden_kernel: {e}") + return False + + +# Apply patch on import +_patched = _patch_blockwise_shard_hidden() +if _patched: + logger.info("NKI blockwise MoE shard_hidden kernel enabled") +else: + logger.warning( + "NKI blockwise MoE shard_hidden kernel NOT enabled — " + "falling back to torch blockwise (use_torch_block_wise=True)" + ) + + +# ───────────────────────────────────────────────────────────────────────────── +# Patch 2: MoE TKG selective-expert NKI kernel +# ───────────────────────────────────────────────────────────────────────────── + + +def _patch_moe_tkg_selective_loading(): + """Patch ExpertMLPsV2.forward_selective_loading with nkilib moe_tkg kernel. + + The default forward_selective_loading loops over tokens in Python — + this replaces it with a single NKI kernel call that processes all tokens + in parallel using the selective-expert path (only top-K experts loaded). + + Weight layout follows MoEFusedTKG convention: + gate_up_proj.weight.view(E_L, H, 2, -1) -> [E_L, H, 2, I] + down_proj.weight -> [E_L, I, H] + + GLM-4.7-Flash: T=4, H=2048, I=384 (TP=4), E_L=64, K=4, BF16. + """ + try: + from neuronx_distributed.modules.moe.expert_mlps import ExpertMLPsV2 + except ImportError: + logger.debug("ExpertMLPsV2 not available, skipping TKG patch") + return False + + try: + mod = importlib.import_module("nkilib.core.moe.moe_tkg.moe_tkg") + moe_tkg_fn = getattr(mod, "moe_tkg") + types_mod = importlib.import_module("nkilib.core.utils.common_types") + ExpertAffinityScaleMode = getattr(types_mod, "ExpertAffinityScaleMode") + ActFnType = getattr(types_mod, "ActFnType") + + import nki + + wrapped_moe_tkg = nki.jit(moe_tkg_fn) + except Exception as e: + logger.warning(f"Failed to import nkilib moe_tkg: {e}") + return False + + # Save original for fallback + _original_forward_selective_loading = ExpertMLPsV2.forward_selective_loading + + def _forward_selective_loading_nki( + self, hidden_states, expert_affinities, expert_index + ): + """NKI moe_tkg selective-expert replacement for forward_selective_loading. + + Follows the exact weight access pattern from MoEFusedTKG: + gate_up_proj.weight.view(num_local_experts, hidden_size, 2, -1) + down_proj.weight # already [E_L, I, H] + + Args: + hidden_states: [T, H] (already 2D from MoE dispatcher) + expert_affinities: [T, E] (dense, with scaled top-K values, zeros elsewhere) + expert_index: [T, K] (top-k expert indices per token, int64) + """ + H = hidden_states.shape[1] + mlp_op = self.get_mlp_op() + + # Access weights exactly as MoEFusedTKG does (line 225-228 of moe_fused_tkg.py) + # gate_up_proj.weight: flat tensor, reshape to [E_L, H, 2, I] + E_L = mlp_op.gate_up_proj._n_local_experts + gate_up_reshaped = mlp_op.gate_up_proj.weight.view(E_L, H, 2, -1) + # down_proj.weight: already [E_L, I, H] + down_w = mlp_op.down_proj.weight + + # Call NKI kernel: selective-expert mode (is_all_expert=False) + # POST_SCALE: kernel extracts affinities at expert_index positions + # and multiplies expert outputs by them (matching our router's pre-scaling) + # NOTE: expert_affinities MUST be float32 — the kernel's tensor_scalar op + # for affinity scaling requires float32 operand (MLIR verification fails on bf16) + output = wrapped_moe_tkg[2]( + hidden_input=hidden_states, # [T, H] + expert_gate_up_weights=gate_up_reshaped, # [E_L, H, 2, I] + expert_down_weights=down_w, # [E_L, I, H] + expert_affinities=expert_affinities.to(torch.float32), # [T, E] float32 + expert_index=expert_index.to(torch.int32), # [T, K] + is_all_expert=False, + expert_affinities_scaling_mode=ExpertAffinityScaleMode.POST_SCALE, + activation_fn=ActFnType.SiLU, + ) + return output + + ExpertMLPsV2.forward_selective_loading = _forward_selective_loading_nki + logger.info( + "Patched ExpertMLPsV2.forward_selective_loading with nkilib moe_tkg kernel" + ) + return True + + +_patched_moe_tkg = _patch_moe_tkg_selective_loading() +if _patched_moe_tkg: + logger.info("NKI MoE TKG selective-expert kernel enabled") +else: + logger.warning( + "NKI MoE TKG selective-expert kernel NOT enabled — " + "using default forward_selective_loading (torch loop)" + ) + + +# ───────────────────────────────────────────────────────────────────────────── +# Patch 3: MoE TKG all-expert NKI kernel (for larger batch sizes) +# ───────────────────────────────────────────────────────────────────────────── + + +def _patch_moe_tkg_all_experts(): + """Patch ExpertMLPsV2.forward_all_experts with nkilib moe_tkg kernel (all-expert mode). + + When batch_size is large enough that T*K/E >= 1.0, NxD switches to forward_all_experts + which broadcasts all tokens through ALL experts. The NKI kernel with is_all_expert=True + does this in a single fused kernel call. + + Requires rank_id for affinity slicing in all-expert mode. + """ + try: + from neuronx_distributed.modules.moe.expert_mlps import ExpertMLPsV2 + except ImportError: + return False + + try: + mod = importlib.import_module("nkilib.core.moe.moe_tkg.moe_tkg") + moe_tkg_fn = getattr(mod, "moe_tkg") + types_mod = importlib.import_module("nkilib.core.utils.common_types") + ExpertAffinityScaleMode = getattr(types_mod, "ExpertAffinityScaleMode") + ActFnType = getattr(types_mod, "ActFnType") + + import nki + + wrapped_moe_tkg = nki.jit(moe_tkg_fn) + except Exception as e: + logger.warning(f"Failed to import nkilib moe_tkg for all-expert patch: {e}") + return False + + _original_forward_all_experts = ExpertMLPsV2.forward_all_experts + + def _forward_all_experts_nki( + self, hidden_states, expert_affinities, expert_index, chosen_expert_indices=None + ): + """NKI moe_tkg all-expert replacement for forward_all_experts. + + All tokens go through ALL local experts. The kernel handles masking/scaling internally. + + Args: + hidden_states: [T, H] + expert_affinities: [T, E] (dense, pre-scaled by router) + expert_index: [T, K] (top-k indices) + chosen_expert_indices: ignored (used by some EP paths) + """ + H = hidden_states.shape[1] + mlp_op = self.get_mlp_op() + + E_L = mlp_op.gate_up_proj._n_local_experts + gate_up_reshaped = mlp_op.gate_up_proj.weight.view(E_L, H, 2, -1) + down_w = mlp_op.down_proj.weight + + # All-expert mode requires rank_id for affinity slicing + # With no EP (single rank), rank_id = 0 + # Must be on the same device as other tensors (XLA device during tracing) + rank_id = torch.zeros(1, 1, dtype=torch.int32, device=hidden_states.device) + + output = wrapped_moe_tkg[2]( + hidden_input=hidden_states, # [T, H] + expert_gate_up_weights=gate_up_reshaped, # [E_L, H, 2, I] + expert_down_weights=down_w, # [E_L, I, H] + expert_affinities=expert_affinities.to(torch.float32), # [T, E] + expert_index=expert_index.to(torch.int32), # [T, K] + is_all_expert=True, + rank_id=rank_id, + expert_affinities_scaling_mode=ExpertAffinityScaleMode.POST_SCALE, + activation_fn=ActFnType.SiLU, + ) + return output + + ExpertMLPsV2.forward_all_experts = _forward_all_experts_nki + logger.info( + "Patched ExpertMLPsV2.forward_all_experts with nkilib moe_tkg kernel (all-expert mode)" + ) + return True + + +# NOTE: All-expert NKI patch — DGE OOB was fixed in SDK 2.29.1 (neuronx-cc 2.24.8799). +# Re-enabling the NKI all-expert kernel for BS>=16 where T*K/E >= 1.0 triggers +# forward_all_experts mode. The kernel provides fused expert computation which may +# improve TPOT at high batch sizes. +_patched_moe_tkg_all = _patch_moe_tkg_all_experts() + + +# ───────────────────────────────────────────────────────────────────────────── +# Patch 4: Replace MoEFusedTKG kernel with nkilib moe_block_tkg for +# e_score_correction_bias support +# ───────────────────────────────────────────────────────────────────────────── + + +def _patch_fused_tkg_for_correction_bias(): + """Replace MoEFusedTKG._moe_fused_tkg_kernel to use nkilib's moe_block_tkg. + + The pre-prod kernels (moe_token_gen_*) that NxDI calls internally do NOT + support e_score_correction_bias. However, the open-source nkilib + moe_block_tkg kernel does (we added router_correction_bias/scale params). + + This patch replaces the entire _moe_fused_tkg_kernel method to call + nkilib's moe_block_tkg kernel directly, passing the correction bias + and scaling factor from the GLM-4.7-Flash router. + + The nkilib kernel handles correction bias by: + 1. Adding bias to sigmoid affinities for top-K index selection + 2. Gathering ORIGINAL (unbiased) affinities at selected indices + 3. L1-normalizing and scaling by router_correction_scale + + Must be called before model.compile(). + """ + try: + import neuronx_distributed.modules.moe.moe_fused_tkg as fused_tkg_mod + from neuronx_distributed.modules.moe.model_utils import ( + ACTFunc, + DEFAULT_SELECTIVE_LOADING_THRESHOLD, + get_kernel_activation_func_id, + ) + + # Import nkilib moe_block_tkg kernel + try: + from nkilib_src.nkilib.core.moe_block.moe_block_tkg import moe_block_tkg + from nkilib_src.nkilib.core.utils.common_types import ( + ActFnType as NkilibActFnType, + RouterActFnType as NkilibRouterActFnType, + ExpertAffinityScaleMode as NkilibExpertAffinityScaleMode, + ) + except ImportError: + from nkilib.core.moe_block.moe_block_tkg import moe_block_tkg + from nkilib.core.utils.common_types import ( + ActFnType as NkilibActFnType, + RouterActFnType as NkilibRouterActFnType, + ExpertAffinityScaleMode as NkilibExpertAffinityScaleMode, + ) + + # Map NxDI's act_fn string names to nkilib enum values + NKILIB_ROUTER_ACT_FN_MAP = { + "sigmoid": NkilibRouterActFnType.SIGMOID, + "softmax": NkilibRouterActFnType.SOFTMAX, + } + NKILIB_ACT_FN_MAP = { + "silu": NkilibActFnType.SiLU, + "gelu": NkilibActFnType.GELU, + } + + def _replacement_moe_fused_tkg_kernel(self, hidden_states): + """Replacement _moe_fused_tkg_kernel using nkilib moe_block_tkg. + + Calls the nkilib kernel which supports router_correction_bias, + instead of the pre-prod moe_token_gen_* kernels which don't. + """ + hidden_states_shape = hidden_states.shape + + # Determine expert affinity scaling mode + if self.expert_mlps.routed_experts_mlp_config.early_expert_affinity_modulation: + scaling_mode = NkilibExpertAffinityScaleMode.PRE_SCALE + else: + scaling_mode = NkilibExpertAffinityScaleMode.POST_SCALE + + # Determine if we should use all-expert mode + total_tokens = hidden_states_shape[0] * hidden_states_shape[1] + perc_experts_loaded = ( + total_tokens * self.num_experts_per_tok / self.num_local_experts + ) + use_all_expert = perc_experts_loaded >= DEFAULT_SELECTIVE_LOADING_THRESHOLD + + # LNC config for nkilib kernel (integer: 1 or 2) + lnc = self.logical_nc_config + + # Get shared expert weights (will be None in FP8 mode) + ( + shared_expert_gate_w, + shared_expert_up_w, + shared_expert_down_w, + ) = self._slice_shared_experts_weights() + + # Get activation function + routed_experts_mlp_config = self.expert_mlps.routed_experts_mlp_config + kernel_activation_func_id = get_kernel_activation_func_id( + ACTFunc.validate(routed_experts_mlp_config.hidden_act), + routed_experts_mlp_config.glu_type, + ) + + # Build kernel kwargs + kernel_kwargs = dict( + inp=hidden_states, # [B, S, H] + gamma=self.post_attention_layernorm.weight.unsqueeze(0), # [1, H] + router_weights=self.router.weight_T, # [H, E] + expert_gate_up_weights=self.expert_mlps.mlp_op.gate_up_proj.weight.view( + self.num_local_experts, self.hidden_size, 2, -1 + ), # [E, H, 2, I] + expert_down_weights=self.expert_mlps.mlp_op.down_proj.weight, # [E, I, H] + shared_expert_gate_w=shared_expert_gate_w, + shared_expert_up_w=shared_expert_up_w, + shared_expert_down_w=shared_expert_down_w, + expert_gate_up_weights_scale=( + self.expert_mlps.mlp_op.gate_up_proj.scale.view( + self.num_local_experts, 2, -1 + ) + if self.config.quantized + else None + ), + expert_down_weights_scale=( + self.expert_mlps.mlp_op.down_proj.scale.view( + self.num_local_experts, -1 + ) + if self.config.quantized + else None + ), + router_bias=( + self.router.linear_router.bias if self.router.bias else None + ), + expert_gate_up_bias=( + self.expert_mlps.mlp_op.gate_up_proj.bias.view( + self.num_local_experts, 2, -1 + ) + if routed_experts_mlp_config.bias + else None + ), + expert_down_bias=( + self.expert_mlps.mlp_op.down_proj.bias + if routed_experts_mlp_config.bias + else None + ), + eps=self.post_attention_layernorm.variance_epsilon, + top_k=self.num_experts_per_tok, + router_act_fn=NKILIB_ROUTER_ACT_FN_MAP[self.router.act_fn], + router_pre_norm=not self.router.apply_act_fn_over_topk, + norm_topk_prob=self.config.norm_topk_prob, + expert_affinities_scaling_mode=scaling_mode, + hidden_act_fn=NkilibActFnType(kernel_activation_func_id), + is_all_expert=use_all_expert, + ) + + # Optional: hidden_actual for padded dimensions + if routed_experts_mlp_config.hidden_size_actual is not None: + kernel_kwargs["hidden_actual"] = ( + routed_experts_mlp_config.hidden_size_actual + ) + + # Optional: clamping limits + if routed_experts_mlp_config.gate_clamp_upper_limit is not None: + kernel_kwargs["gate_clamp_upper_limit"] = ( + routed_experts_mlp_config.gate_clamp_upper_limit + ) + if routed_experts_mlp_config.gate_clamp_lower_limit is not None: + kernel_kwargs["gate_clamp_lower_limit"] = ( + routed_experts_mlp_config.gate_clamp_lower_limit + ) + if routed_experts_mlp_config.up_clamp_upper_limit is not None: + kernel_kwargs["up_clamp_upper_limit"] = ( + routed_experts_mlp_config.up_clamp_upper_limit + ) + if routed_experts_mlp_config.up_clamp_lower_limit is not None: + kernel_kwargs["up_clamp_lower_limit"] = ( + routed_experts_mlp_config.up_clamp_lower_limit + ) + + # For all-expert mode, provide rank_id + if use_all_expert: + local_rank = self.expert_mlps.spmd_rank.get_rank() + local_ep_rank = ( + local_rank + // self.expert_mlps.moe_tensor_model_parallel_group.size() + ) + kernel_kwargs["rank_id"] = local_ep_rank.reshape(1, 1) + + # Inject router_correction_bias and router_correction_scale + # NOTE: Access correction_bias directly from self (MoEFusedTKG) where it's + # registered as a parameter. This ensures XLA tracing captures it as a + # weight input to the NEFF (not inlined as a compile-time constant). + # Accessing via self.router.e_score_correction_bias doesn't work because + # XLA's parameter tracking doesn't follow nested module access inside NKI calls. + if hasattr(self, "correction_bias"): + bias = self.correction_bias + bias = bias.to(torch.float32) + if bias.dim() == 1: + bias = bias.unsqueeze(0) # [1, E] + kernel_kwargs["router_correction_bias"] = bias + if hasattr(self, "router") and hasattr( + self.router, "routed_scaling_factor" + ): + kernel_kwargs["router_correction_scale"] = ( + self.router.routed_scaling_factor + ) + + # Call nkilib moe_block_tkg kernel + out, router_logits = moe_block_tkg[lnc](**kernel_kwargs) + + return out.view(hidden_states_shape), router_logits.to(hidden_states.dtype) + + # Replace the method + fused_tkg_mod.MoEFusedTKG._moe_fused_tkg_kernel = ( + _replacement_moe_fused_tkg_kernel + ) + logger.info( + "Patched MoEFusedTKG._moe_fused_tkg_kernel to use nkilib " + "moe_block_tkg kernel with router_correction_bias support" + ) + return True + + except ImportError as e: + logger.warning( + "Failed to import nkilib moe_block_tkg kernel: %s. " + "Correction bias will NOT be applied in fused TKG path.", + e, + ) + return False + except Exception as e: + logger.warning("Failed to patch MoEFusedTKG for correction bias: %s", e) + return False + + +_patched_correction_bias = _patch_fused_tkg_for_correction_bias() diff --git a/contrib/models/GLM-4.7-Flash/src/modeling_glm4_moe_lite.py b/contrib/models/GLM-4.7-Flash/src/modeling_glm4_moe_lite.py new file mode 100644 index 00000000..3557213d --- /dev/null +++ b/contrib/models/GLM-4.7-Flash/src/modeling_glm4_moe_lite.py @@ -0,0 +1,1272 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# coding=utf-8 +# Adapted from DeepSeek-V3 NxDI contrib for GLM-4.7-Flash (glm4_moe_lite). +# GLM-4.7-Flash uses the same MLA + MoE architecture as DeepSeek-V3 but at +# smaller scale (30B-A3B, 47 layers, 64 experts top-4, no YaRN). +# Supports FP8 E4M3 quantization of MoE expert weights (EXPERT_WISE_PER_CHANNEL_SYMMETRIC). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 + +import gc +import logging +import os +from typing import List, Optional, Tuple, Type + +import warnings +import torch +import torch.utils.checkpoint +from neuronx_distributed.parallel_layers.layers import ( + ColumnParallelLinear, + RowParallelLinear, + ParallelEmbedding, + SPMDRank, +) +from neuronx_distributed.parallel_layers.mappings import ( + gather_from_sequence_parallel_region, +) +from neuronx_distributed.utils import cpu_mode +from torch import Tensor, nn + +from neuronx_distributed_inference.models.config import ( + InferenceConfig, + NeuronConfig, + MoENeuronConfig, +) +from neuronx_distributed_inference.models.model_base import ( + NeuronBaseForCausalLM, + NeuronBaseModel, +) +from neuronx_distributed_inference.utils.hf_adapter import HuggingFaceGenerationAdapter +from neuronx_distributed_inference.models.layer_boundary_marker import ( + ModuleMarkerEndWrapper, + ModuleMarkerStartWrapper, +) +from src.rope_util import ( + Glm4MoeLiteRotaryEmbedding, + apply_rotary_pos_emb, +) +from neuronx_distributed_inference.modules.attention.utils import manual_softmax +from neuronx_distributed_inference.modules.custom_calls import CustomRMSNorm +from neuronx_distributed_inference.modules.moe_v2 import initialize_moe_module +from neuronx_distributed.modules.moe.routing import GroupLimitedRouter, RouterTopK +from transformers import AutoModelForCausalLM +from transformers.activations import ACT2FN +from transformers.models.llama.modeling_llama import LlamaRMSNorm + +# NKI MLA attention kernel disabled: research showed 1.84x slower than XLA +# baseline due to graph fusion barriers from custom op boundaries. +_nki_mla_attention = None +_USE_NKI_MLA = False + +logger = logging.getLogger(__name__) + + +def convert_glm4_moe_lite_hf_to_neuron_state_dict( + state_dict: dict, config: "Glm4MoeLiteInferenceConfig" +) -> dict: + """ + Convert HuggingFace GLM-4.7-Flash (glm4_moe_lite) state dict to Neuron-compatible format. + + Transformations: + 1. Add rank utility tensors for TP sharding + 2. Rename router weights: gate.weight -> router.linear_router.weight + 3. Rename e_score_correction_bias -> router.e_score_correction_bias + 4. Fuse gate_proj + up_proj into gate_up_proj for each expert + 5. Stack down_proj weights across experts + 6. Skip dense layers (first_k_dense_replace layers, only layer 0 for GLM) + 7. Skip MTP layer weights (layer 47 embed_tokens) for initial bring-up + + When loading pre-quantized checkpoints (already in NxDI format with FP8 weights + and scale tensors), the expert weight fusion steps (4-5) are skipped since the + checkpoint already contains fused gate_up_proj and stacked down_proj in FP8. + """ + num_hidden_layers = config.num_hidden_layers + num_local_experts = config.num_local_experts + tp_degree = getattr(config.neuron_config, "tp_degree", 1) + first_k_dense = getattr(config, "first_k_dense_replace", 1) + + # Detect pre-quantized checkpoint: if the state dict already has fused expert + # weight keys with scale tensors, skip the fusion step. + _sample_fused_key = ( + f"layers.{first_k_dense}.mlp.expert_mlps.mlp_op.gate_up_proj.weight" + ) + _sample_scale_key = ( + f"layers.{first_k_dense}.mlp.expert_mlps.mlp_op.gate_up_proj.scale" + ) + is_prequantized = ( + _sample_fused_key in state_dict and _sample_scale_key in state_dict + ) + + if is_prequantized: + logger.info( + "Detected pre-quantized checkpoint (FP8 expert weights with scales). " + "Skipping expert weight fusion." + ) + # FP8 mode: shared_experts are moved from inside MoE module to decoder layer. + # Rename: layers.X.mlp.shared_experts.* -> layers.X.shared_experts.* + shared_expert_renames = {} + for k in list(state_dict.keys()): + if ".mlp.shared_experts." in k: + new_key = k.replace(".mlp.shared_experts.", ".shared_experts.") + shared_expert_renames[k] = new_key + for old_key, new_key in shared_expert_renames.items(): + state_dict[new_key] = state_dict.pop(old_key) + if shared_expert_renames: + logger.info( + f"Renamed {len(shared_expert_renames)} shared_expert keys " + "(moved from MoE module to decoder layer for FP8 mode)" + ) + + # Add rank utilities for TP + state_dict["rank_util.rank"] = torch.arange(0, tp_degree, dtype=torch.int32) + + # Remove MTP layer weights (layer 47) — not used in initial bring-up + mtp_keys = [k for k in state_dict if k.startswith(f"layers.{num_hidden_layers}.")] + for k in mtp_keys: + del state_dict[k] + + for layer_idx in range(num_hidden_layers): + # Add rank utility for attention + state_dict[f"layers.{layer_idx}.self_attn.rank_util.rank"] = torch.arange( + 0, tp_degree, dtype=torch.int32 + ) + + # Skip dense layers (no MoE conversion needed) + if layer_idx < first_k_dense: + continue + + # Rename router weights: gate.weight -> router.linear_router.weight + router_key = f"layers.{layer_idx}.mlp.gate.weight" + if router_key in state_dict: + router_weight = state_dict[router_key].detach().clone() + state_dict[f"layers.{layer_idx}.mlp.router.linear_router.weight"] = ( + router_weight + ) + del state_dict[router_key] + + # MoEFusedTKG requires transposed router weights (weight_T) + # Generate it from linear_router.weight (works for both fresh and pre-quantized) + router_linear_key = f"layers.{layer_idx}.mlp.router.linear_router.weight" + if is_prequantized and router_linear_key in state_dict: + state_dict[f"layers.{layer_idx}.mlp.moe_fused_tkg.router.weight_T"] = ( + state_dict[router_linear_key].detach().T.contiguous() + ) + + # Rename e_score_correction_bias for GroupLimitedRouter + bias_key = f"layers.{layer_idx}.mlp.gate.e_score_correction_bias" + if bias_key in state_dict: + bias_tensor = state_dict[bias_key].detach().clone() + state_dict[f"layers.{layer_idx}.mlp.router.e_score_correction_bias"] = ( + bias_tensor + ) + # Also provide at moe_fused_tkg.correction_bias path for XLA weight loading. + # During SPMD tracing, the bias is accessed via self.correction_bias on + # MoEFusedTKG, so the compiled model expects it at this path. + state_dict[f"layers.{layer_idx}.mlp.moe_fused_tkg.correction_bias"] = ( + bias_tensor + ) + del state_dict[bias_key] + + # For pre-quantized checkpoints, the bias is already at mlp.router path. + # Add the moe_fused_tkg.correction_bias duplicate if not present. + if is_prequantized: + router_bias_key = f"layers.{layer_idx}.mlp.router.e_score_correction_bias" + tkg_bias_key = f"layers.{layer_idx}.mlp.moe_fused_tkg.correction_bias" + if router_bias_key in state_dict and tkg_bias_key not in state_dict: + state_dict[tkg_bias_key] = state_dict[router_bias_key] + + # If pre-quantized checkpoint, expert weights are already fused — skip fusion + if is_prequantized: + continue + + # Check if expert weights exist for this layer + expert_gate_key = f"layers.{layer_idx}.mlp.experts.0.gate_proj.weight" + if expert_gate_key not in state_dict: + continue + + intermediate_size, hidden_size = state_dict[expert_gate_key].shape + device = state_dict[expert_gate_key].device + dtype = state_dict[expert_gate_key].dtype + + # Fuse gate_proj + up_proj into gate_up_proj for all experts + gate_up_proj = torch.empty( + num_local_experts, + hidden_size, + 2 * intermediate_size, + dtype=dtype, + device=device, + ) + + for e in range(num_local_experts): + gate_key = f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight" + up_key = f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight" + + if gate_key in state_dict and up_key in state_dict: + gate_proj_weights = state_dict[gate_key].T.detach().clone() + up_proj_weights = state_dict[up_key].T.detach().clone() + + gate_up_proj_slice = torch.narrow(gate_up_proj, 0, e, 1) + torch.narrow(gate_up_proj_slice, 2, 0, intermediate_size).copy_( + gate_proj_weights + ) + torch.narrow( + gate_up_proj_slice, 2, intermediate_size, intermediate_size + ).copy_(up_proj_weights) + + del state_dict[gate_key] + del state_dict[up_key] + + state_dict[f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.gate_up_proj.weight"] = ( + gate_up_proj + ) + + # Stack down_proj weights across all experts + down_proj = torch.empty( + num_local_experts, + intermediate_size, + hidden_size, + dtype=dtype, + device=device, + ) + + for e in range(num_local_experts): + down_key = f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight" + if down_key in state_dict: + down_proj_weights = state_dict[down_key].T.detach().clone() + torch.narrow(down_proj, 0, e, 1).copy_(down_proj_weights) + del state_dict[down_key] + + state_dict[f"layers.{layer_idx}.mlp.expert_mlps.mlp_op.down_proj.weight"] = ( + down_proj + ) + + gc.collect() + + return state_dict + + +class Glm4MoeLiteNeuronConfig(MoENeuronConfig): + """Neuron hardware configuration for GLM-4.7-Flash MoE model.""" + + pass + + +class Glm4MoeLiteRouter(RouterTopK): + """ + Custom router for GLM-4.7-Flash using sigmoid activation + e_score_correction_bias. + + GLM-4.7-Flash uses n_group=1, topk_group=1 which makes GroupLimitedRouter's + group selection a complete no-op. We use RouterTopK (simple torch.topk) instead, + which produces a simpler computation graph that avoids the NCC_IBIR297 compiler + bug in the tensorizer's ModuleForkPass at small TP degrees. + + After top-k selection, the selected affinities are L1-normalized and then + scaled by routed_scaling_factor (1.8). This replaces the + normalize_top_k_affinities step in ExpertMLPsV2, so the config must set + normalize_top_k_affinities=False. + """ + + def __init__( + self, + routed_scaling_factor: float = 1.8, + n_group: int = 1, + topk_group: int = 1, + **kwargs, + ): + # RouterTopK doesn't accept n_group/topk_group, so we pop them + super().__init__(**kwargs) + self.routed_scaling_factor = routed_scaling_factor + self.n_group = n_group + self.topk_group = topk_group + # e_score_correction_bias is a trained parameter loaded from checkpoint. + self.e_score_correction_bias = nn.Parameter( + torch.zeros(self.num_experts, dtype=torch.float32) + ) + + def forward(self, hidden_states): + router_logits = self.get_router_logits(hidden_states) + expert_affinities = self.apply_activation_fn(router_logits) + expert_affinities = expert_affinities.to(dtype=hidden_states.dtype) + + # Add correction bias for top-k selection (DS-V3 e_score_correction_bias) + scores_for_choice = expert_affinities + self.e_score_correction_bias.unsqueeze( + 0 + ) + + # Simple top-k (no group logic since n_group=1, topk_group=1) + _, topk_idx = torch.topk(scores_for_choice, k=self.top_k) + topk_idx = topk_idx.detach().to(dtype=torch.long) + + # Gather ORIGINAL affinities (without bias) for selected experts + topk_weights = expert_affinities.gather(1, topk_idx) # (T, top_k) + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights * self.routed_scaling_factor + + # Scatter back to dense (T, E) layout for ExpertMLPsV2 + expert_affinities_scaled = torch.zeros_like(expert_affinities) + expert_affinities_scaled.scatter_(1, topk_idx, topk_weights) + + return router_logits, expert_affinities_scaled, topk_idx + + +class Glm4MoeLiteInferenceConfig(InferenceConfig): + """ + Inference configuration for GLM-4.7-Flash (glm4_moe_lite). + + Handles MLA attention parameters, MoE routing config, dense/MoE layer + distinction, and KV cache shape overrides for MLA's compressed cache format. + + Differences from DeepSeek-V3: + - No YaRN RoPE (standard RoPE) + - first_k_dense_replace = 1 (only layer 0 is dense) + - routed_scaling_factor = 1.8 (vs 2.5) + - n_group = 1, topk_group = 1 (no group selection) + - 64 experts, top-4 (vs 256 experts, top-8) + + FP8 Quantization: + - Supports EXPERT_WISE_PER_CHANNEL_SYMMETRIC quantization of MoE expert weights + - Set neuron_config.quantized=True and provide quantized_checkpoints_path + - Only expert gate_up_proj and down_proj are quantized; attention/dense layers stay BF16 + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Standard HF config attributes expected by model_base.py + if not hasattr(self, "output_attentions"): + self.output_attentions = False + if not hasattr(self, "output_hidden_states"): + self.output_hidden_states = False + if not hasattr(self, "return_dict"): + self.return_dict = True + + # GLM-4.7-Flash stores rope_theta inside rope_parameters dict + if not hasattr(self, "rope_theta"): + rope_params = getattr(self, "rope_parameters", None) + if rope_params is not None: + if isinstance(rope_params, dict): + self.rope_theta = rope_params.get("rope_theta", 1000000) + else: + self.rope_theta = getattr(rope_params, "rope_theta", 1000000) + else: + self.rope_theta = 1000000 + + # GLM-4.7-Flash uses attention_dropout=0.0 + if not hasattr(self, "attention_dropout"): + self.attention_dropout = 0.0 + + # Map HF config names to NXDI MoE names + self.num_local_experts = getattr( + self, "n_routed_experts", getattr(self, "num_experts", 0) + ) + self.n_shared_experts = getattr(self, "n_shared_experts", 0) + self.num_experts_per_tok = getattr(self, "num_experts_per_tok", 0) + + # Store dense layer intermediate size before overriding with MoE size. + # HF config uses "intermediate_size" for the dense FFN (10240). + if not hasattr(self, "dense_intermediate_size"): + self.dense_intermediate_size = getattr(self, "intermediate_size", 0) + + # ExpertMLPsV2 reads config.intermediate_size for MoE expert size + if getattr(self, "moe_intermediate_size", None) is not None: + self.intermediate_size = self.moe_intermediate_size + + # Activation function + if not hasattr(self, "hidden_act"): + self.hidden_act = "silu" + + # Number of dense (non-MoE) layers at the start + if not hasattr(self, "first_k_dense_replace"): + self.first_k_dense_replace = 1 + + # MoE routing config (only when MoENeuronConfig is used) + if hasattr(self.neuron_config, "router_config"): + self.neuron_config.router_config.dtype = torch.float32 + self.neuron_config.router_config.act_fn = "sigmoid" + # Normalization + scaling is handled by Glm4MoeLiteRouter, not ExpertMLPsV2 + self.neuron_config.normalize_top_k_affinities = False + + # MoE kernel selection: use NKI shard-on-block kernel for CTE path. + # This is the preferred kernel for GLM-4.7-Flash because: + # - I_TP=384 satisfies shard_on_block constraint (I_TP % 16 == 0) + # - shard_on_intermediate requires I_TP % 256 == 0 (would need padding to 512) + # - shard_on_block has dynamic while loop for early exit on empty blocks + # - Requires PING_PONG sharding strategy + if hasattr(self.neuron_config, "blockwise_matmul_config"): + from src.compat import _patched as _nki_kernel_available + + self.neuron_config.blockwise_matmul_config.use_torch_block_wise = False + self.neuron_config.blockwise_matmul_config.use_shard_on_block_dynamic_while = True + from neuronx_distributed.modules.moe.blockwise import BlockShardStrategy + + self.neuron_config.blockwise_matmul_config.block_sharding_strategy = ( + BlockShardStrategy.PING_PONG + ) + logger.info( + "NKI shard-on-block kernel enabled for MoE CTE blockwise matmul" + ) + + # Also keep shard_hidden patch as fallback (if shard_on_block fails) + if not _nki_kernel_available: + logger.warning( + "NKI compat patches not applied - shard_hidden fallback unavailable. " + "shard_on_block kernel should still work via nkilib." + ) + + # Disable numeric CC token (workaround for all-gather/reduce-scatter) + self.neuron_config.disable_numeric_cc_token = True + + # FP8 quantization support for MoE expert weights + if getattr(self.neuron_config, "quantized", False): + # Set modules_to_not_convert: everything except MoE expert gate_up/down_proj. + # CRITICAL: EXPERT_WISE_PER_CHANNEL_SYMMETRIC has per_channel_axis=None which + # causes QuantizedColumnParallel to assert. We must exclude ALL non-expert-fused + # linear layers from conversion. + if not getattr(self.neuron_config, "modules_to_not_convert", None): + self.neuron_config.modules_to_not_convert = [ + "lm_head", + "embed_tokens", + "self_attn", + "input_layernorm", + "post_attention_layernorm", + "norm", + "layers.0.mlp", # Dense MLP layer (not MoE) + "shared_experts", # Shared expert MLP in MoE layers (not fused) + "router", + "rmsnorm", + ] + # Set the UNSAFE_FP8FNCAST env var required by the Neuron compiler + os.environ["UNSAFE_FP8FNCAST"] = "1" + # FP8 strategy: Use MoEFusedTKG for routed experts (it handles FP8 scales + # natively), but the TKG kernel doesn't support shared_experts yet. + # Solution: Set n_shared_experts=0 so initialize_moe_module doesn't include + # shared experts in the MoE module. Instead, we handle shared experts as a + # separate BF16 MLP in the decoder layer forward. + if getattr(self, "n_shared_experts", 0) > 0: + logger.info( + f"FP8 mode: Moving shared_experts (n={self.n_shared_experts}) out of MoE module " + "into separate BF16 MLP (MoEFusedTKG doesn't support shared_experts)." + ) + # Store original value for decoder layer to create separate shared expert MLP + self._fp8_shared_expert_intermediate_size = getattr( + self, "shared_expert_intermediate_size", None + ) or (self.moe_intermediate_size * self.n_shared_experts) + self.n_shared_experts = 0 + # Enable MoEFusedTKG for FP8 (now safe without shared_experts) + if not self.neuron_config.moe_fused_nki_kernel_enabled: + self.neuron_config.moe_fused_nki_kernel_enabled = True + logger.info( + "Enabled moe_fused_nki_kernel for FP8 quantized path " + "(MoEFusedTKG handles scale tensor passing natively)" + ) + logger.info( + "FP8 quantization enabled for MoE experts. " + f"modules_to_not_convert={self.neuron_config.modules_to_not_convert}" + ) + + # MLA KV cache: override head_dim and num_key_value_heads so the + # KVCacheManager allocates (bsz, 1, max_len, rope_dim + kv_lora_rank) + # instead of standard GQA layout. + # For GLM-4.7-Flash: 64 (qk_rope_head_dim) + 512 (kv_lora_rank) = 576 + # + # CRITICAL: The HF Glm4MoeLiteConfig has attribute_map={'head_dim': 'qk_rope_head_dim'} + # which means setting self.head_dim would actually modify qk_rope_head_dim. + # We must bypass this by writing directly to __dict__. + self.__dict__["head_dim"] = self.qk_rope_head_dim + self.kv_lora_rank + self.__dict__["num_key_value_heads"] = 1 + # Remove the head_dim alias from attribute_map to prevent KVCacheManager confusion + if hasattr(self, "attribute_map") and isinstance(self.attribute_map, dict): + self.attribute_map.pop("head_dim", None) + + def add_derived_config(self): + self.num_cores_per_group = 1 + + @classmethod + def get_neuron_config_cls(cls) -> Type[NeuronConfig]: + return Glm4MoeLiteNeuronConfig + + def get_required_attributes(self) -> List[str]: + return [ + # MLA (Multi-head Latent Attention) parameters + "kv_lora_rank", + "qk_nope_head_dim", + "qk_rope_head_dim", + "v_head_dim", + # MoE parameters + "n_routed_experts", + "num_experts_per_tok", + "moe_intermediate_size", + ] + + +def get_rmsnorm_cls(): + # Initialize to the appropriate implementation of RMSNorm + # If infer on NXD -> CustomRMSNorm + # If infer on CPU -> HF_RMSNorm (CustomRMSNorm does not work on CPU) + return LlamaRMSNorm if cpu_mode() else CustomRMSNorm + + +def custom_compiler_args(quantized=False): + """ + Compiler flags for GLM-4.7-Flash on Neuron. + Same as DeepSeek-V3 except no --verify-hlo (debug only). + When quantized=True, adds FP8 E4M3 cast flag. + """ + compiler_args = "--enable-saturate-infinity --enable-mixed-precision-accumulation --model-type transformer -O1" + # Removed: --enable-ccop-compute-overlap --cc-pipeline-tiling-factor=2 (causes NCC_IXCG967 at T=4096) + compiler_args += " --tensorizer-options='--vectorize-strided-dma'" + compiler_args += " --auto-cast=none" + if quantized: + # Enable unsafe FP8 E4M3 cast for Neuron hardware + compiler_args += " --internal-hlo2tensorizer-options='--experimental-unsafe-fp8e4m3fn-as-fp8e4m3'" + return compiler_args + + +class Glm4MoeLiteDenseMLP(nn.Module): + """ + Dense MLP for GLM-4.7-Flash layer 0 (first_k_dense_replace=1). + + Uses SiLU-gated architecture: output = down_proj(silu(gate_proj(x)) * up_proj(x)) + Uses dense_intermediate_size (10240) instead of moe_intermediate_size (1536). + """ + + def __init__(self, config: Glm4MoeLiteInferenceConfig): + super().__init__() + dtype = config.neuron_config.torch_dtype + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + config.dense_intermediate_size, + bias=False, + gather_output=False, + dtype=dtype, + ) + self.up_proj = ColumnParallelLinear( + config.hidden_size, + config.dense_intermediate_size, + bias=False, + gather_output=False, + dtype=dtype, + ) + self.down_proj = RowParallelLinear( + config.dense_intermediate_size, + config.hidden_size, + bias=False, + input_is_parallel=True, + dtype=dtype, + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states, padding_mask=None, **kwargs): + output = self.down_proj( + self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states) + ) + return (output,) + + +class Glm4MoeLiteSharedExpertMLP(nn.Module): + """ + Separate shared expert MLP for FP8 mode. + + When FP8 quantization is enabled, the MoEFusedTKG kernel handles routed experts + but doesn't support shared_experts. This module runs the shared expert computation + separately in BF16, then its output is added to the routed expert output. + + Uses the same SiLU-gated architecture as the dense MLP: + output = down_proj(silu(gate_proj(x)) * up_proj(x)) + + But with moe_intermediate_size (1536) instead of dense_intermediate_size (10240). + """ + + def __init__(self, config: "Glm4MoeLiteInferenceConfig"): + super().__init__() + dtype = config.neuron_config.torch_dtype + # Use the stored shared expert intermediate size from config + intermediate_size = getattr( + config, "_fp8_shared_expert_intermediate_size", None + ) + if intermediate_size is None: + intermediate_size = config.moe_intermediate_size * getattr( + config, "n_shared_experts", 1 + ) + + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + intermediate_size, + bias=False, + gather_output=False, + dtype=dtype, + ) + self.up_proj = ColumnParallelLinear( + config.hidden_size, + intermediate_size, + bias=False, + gather_output=False, + dtype=dtype, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + config.hidden_size, + bias=False, + input_is_parallel=True, + dtype=dtype, + ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + return self.down_proj( + self.act_fn(self.gate_proj(hidden_states)) * self.up_proj(hidden_states) + ) + + +class Glm4MoeLiteAttention(nn.Module): + """ + Multi-head Latent Attention (MLA) for GLM-4.7-Flash. + + Key differences from DeepSeek-V3: + - qk_nope_head_dim=192 (vs 128), v_head_dim=256 (vs 128) -- requires + corrected wkv_b split using qk_nope_head_dim instead of v_head_dim + - Standard RoPE (no YaRN scaling) + - q_lora_rank=768 (vs 1536) + - 20 attention heads (vs 128) + """ + + def __init__( + self, + config: Glm4MoeLiteInferenceConfig, + layer_idx: Optional[int] = None, + tensor_model_parallel_group=None, + ): + super().__init__() + + # Config + self.config = config + self.neuron_config = config.neuron_config + + # Tensor parallelism + self.tp_degree = config.neuron_config.tp_degree + if tensor_model_parallel_group is not None: + self.tensor_model_parallel_group = tensor_model_parallel_group + else: + try: + from neuronx_distributed.parallel_layers import parallel_state + + self.tensor_model_parallel_group = ( + parallel_state.get_tensor_model_parallel_group() + ) + except Exception: + self.tensor_model_parallel_group = None + self.rank_util = SPMDRank(world_size=self.tp_degree) + + # Data types + self.torch_dtype = ( + getattr(config.neuron_config, "attention_dtype", None) + or config.neuron_config.torch_dtype + ) + self.rpl_reduce_dtype = getattr(config.neuron_config, "rpl_reduce_dtype", None) + + # Sequence parallelism + self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled + self.sequence_dimension = 1 if self.sequence_parallel_enabled else None + + # Model dimensions + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + + # Standard RoPE (no YaRN) + self.rotary_emb = Glm4MoeLiteRotaryEmbedding( + dim=config.qk_rope_head_dim, + max_position_embeddings=config.max_position_embeddings, + base=config.rope_theta, + ) + self.bias = getattr(config, "attention_bias", False) + self.layer_idx = layer_idx + assert layer_idx is not None, ( + "Please make sure to provide a `layer_idx` when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.num_total_heads = config.num_attention_heads + assert self.num_attention_heads % self.tp_degree == 0, ( + "Number of attention heads must be a multiple of tp degree." + ) + if cpu_mode(): + self.num_heads = self.num_total_heads + else: + self.num_heads = self.num_total_heads // self.tp_degree + + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + self.head_dim = self.v_head_dim + + self.is_causal = True + self.init_mla_properties() + + # Standard softmax scale (no mscale adjustment for standard RoPE) + self.softmax_scale = self.q_head_dim ** (-0.5) + + def init_mla_properties(self): + config = self.config + dtype = self.torch_dtype + if self.q_lora_rank is None: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_total_heads * self.q_head_dim, + bias=False, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=self.tensor_model_parallel_group, + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, + config.q_lora_rank, + bias=config.attention_bias, + dtype=dtype, + ) + self.q_a_layernorm = get_rmsnorm_cls()(config.q_lora_rank) + self.q_b_proj = ColumnParallelLinear( + config.q_lora_rank, + self.num_total_heads * self.q_head_dim, + bias=False, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=self.tensor_model_parallel_group, + ) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + dtype=dtype, + ) + self.kv_a_layernorm = get_rmsnorm_cls()(config.kv_lora_rank) + if self.tensor_model_parallel_group is not None: + self.kv_b_proj = ColumnParallelLinear( + config.kv_lora_rank, + self.num_total_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + gather_output=False, + dtype=dtype, + tensor_model_parallel_group=self.tensor_model_parallel_group, + ) + else: + self.kv_b_proj = nn.Linear( + config.kv_lora_rank, + self.num_total_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + if self.tensor_model_parallel_group is not None: + self.o_proj = RowParallelLinear( + self.num_attention_heads * self.v_head_dim, + self.hidden_size, + bias=self.bias, + input_is_parallel=True, + dtype=self.torch_dtype, + sequence_parallel_enabled=self.sequence_parallel_enabled, + sequence_dimension=self.sequence_dimension, + tensor_model_parallel_group=self.tensor_model_parallel_group, + reduce_dtype=self.rpl_reduce_dtype, + ) + else: + self.o_proj = nn.Linear( + self.num_attention_heads * self.v_head_dim, + self.hidden_size, + bias=self.bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: torch.Tensor = None, + active_mask: Optional[torch.LongTensor] = None, + adapter_ids=None, + cos_cache: Optional[torch.Tensor] = None, + sin_cache: Optional[torch.Tensor] = None, + **kwargs, + ): + """Implements each layer's forward pass for the attention block.""" + # On decode, past_key_value comes from KVCacheManager as [k_cache, v_cache] + # each shaped (bsz, 1, seq_len, qk_rope_head_dim + kv_lora_rank). + # Convert to the single concatenated tensor that the decode path expects. + if past_key_value is not None and isinstance(past_key_value, (list, tuple)): + combined = past_key_value[0].squeeze( + 1 + ) # (bsz, seq_len, rope_dim + kv_lora_rank) + past_key_value = combined + + if ( + self.sequence_parallel_enabled + and self.tensor_model_parallel_group is not None + ): + hidden_states = gather_from_sequence_parallel_region( + hidden_states, + self.sequence_dimension, + process_group=self.tensor_model_parallel_group, + ) + + bsz, q_len, _ = hidden_states.size() + + # Weight matrix absorption + wkv_b = self.kv_b_proj.weight + wkv_b = wkv_b.view(self.num_heads, -1, self.kv_lora_rank) + # CRITICAL FIX: Split by qk_nope_head_dim, NOT v_head_dim. + # Layout in kv_b_proj output: [K_nope (qk_nope_head_dim) | V (v_head_dim)] per head. + # DS-V3 used v_head_dim which only worked because nope==v==128. + # GLM-4.7-Flash: nope=192, v=256, so we must use qk_nope_head_dim. + out_absorb = wkv_b[ + :, self.qk_nope_head_dim :, : + ] # V absorption: (num_heads, v_head_dim, kv_lora_rank) + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + + q_nope, q_pe = torch.tensor_split(q, (self.qk_nope_head_dim,), dim=-1) + compressed_kv, k_pe = torch.tensor_split( + compressed_kv, (self.kv_lora_rank,), dim=-1 + ) + compressed_kv = self.kv_a_layernorm(compressed_kv) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + + # Q_nope absorption: project Q_nope into compressed space using K_nope weights + q_absorb = wkv_b[ + :, : self.qk_nope_head_dim + ] # K_nope absorption: (num_heads, qk_nope_head_dim, kv_lora_rank) + q_nope = torch.einsum("hdc,bhqd->bhqc", q_absorb, q_nope) + + seq_len = self.neuron_config.seq_len + if sin_cache is None and cos_cache is None: + cos_cache, sin_cache = self.rotary_emb(k_pe, seq_len) + q_pe = apply_rotary_pos_emb(q_pe, cos_cache, sin_cache, position_ids) + k_pe = apply_rotary_pos_emb(k_pe, cos_cache, sin_cache, position_ids) + + active_scores = torch.matmul(q_pe, k_pe.transpose(2, 3)) + torch.einsum( + "bhqc,blc->bhql", q_nope, compressed_kv + ) + active_scores *= self.softmax_scale + + if past_key_value is None: + active_scores = torch.where( + attention_mask, active_scores, torch.finfo(active_scores.dtype).min + ) + active_scores = nn.functional.softmax( + active_scores, dim=-1, dtype=torch.float32 + ).to(k_pe.dtype) + + # Attention result with V absorb + x = torch.einsum("bhql,blc->bhqc", active_scores, compressed_kv) + attn_output = torch.einsum("bhqc,hdc->bhqd", x, out_absorb) + else: + if _USE_NKI_MLA and _nki_mla_attention is not None: + # === NKI FUSED MLA ATTENTION (TKG decode) === + # The kernel fuses: score computation + online softmax + V multiplication + # for the prior KV cache AND combines with active token in one pass. + # + # Returns fully normalized output in compressed space [B, H, kv_lora_rank]. + seq_len_prior = past_key_value.shape[1] + + # q_nope: [B, H, 1, kv_lora_rank] -> [B, H, kv_lora_rank] (squeeze S_q=1) + # q_pe: [B, H, 1, rope_dim] -> [B, H, rope_dim] + q_nope_squeezed = q_nope.squeeze(2) # [B, H, kv_lora_rank] + q_pe_squeezed = q_pe.squeeze(2) # [B, H, rope_dim] + + # Active token score and V for combining inside kernel + # active_scores is [B, H, 1, 1] -- squeeze last dim to [B, H, 1] float32 + active_scores_for_kernel = active_scores.squeeze( + -1 + ).float() # [B, H, 1] + + # Active V: compressed_kv is [B, 1, kv_lora_rank] -- shared across heads + # Expand to [B, H, kv_lora_rank] + active_v_for_kernel = ( + compressed_kv.squeeze(1) + .unsqueeze(1) + .expand(bsz, self.num_heads, self.kv_lora_rank) + .contiguous() + ) + + # Construct additive attention mask for NKI kernel: [B, S, 1] + # attention_mask is [B, 1, 1, S] bool (True=valid, False=invalid) + # Convert to [B, S, 1] float32: 0.0 for valid, -9984.0 for invalid + nki_mask = torch.where( + attention_mask.squeeze(1).squeeze(1).unsqueeze(-1), # [B, S, 1] + torch.zeros(1, dtype=torch.float32, device=attention_mask.device), + torch.full( + (1,), -9984.0, dtype=torch.float32, device=attention_mask.device + ), + ) + + # KV cache is stored as [k_pe(64) | compressed_kv(512)] -- kernel accepts this order directly + v_compressed = _nki_mla_attention[2]( + q_nope_squeezed, + q_pe_squeezed, + past_key_value, + active_scores_for_kernel, + active_v_for_kernel, + nki_mask, + softmax_scale=self.softmax_scale, + batch_size=bsz, + num_heads=self.num_heads, + seq_len=seq_len_prior, + kv_lora_rank=self.kv_lora_rank, + qk_rope_head_dim=self.qk_rope_head_dim, + ) + # v_compressed: [B, H, kv_lora_rank] BF16 (fully normalized) + + # Apply out_absorb: [B, H, kv_lora_rank] @ [H, v_head_dim, kv_lora_rank]^T -> [B, H, v_head_dim] + attn_output = ( + torch.einsum("bhc,hdc->bhd", v_compressed.float(), out_absorb) + .to(q_nope.dtype) + .unsqueeze(2) + ) # [B, H, 1, v_head_dim] to match expected shape + + else: + # === ORIGINAL PyTorch MLA ATTENTION (fallback) === + k_pe_prior, compressed_kv_prior = torch.tensor_split( + past_key_value, + [ + self.qk_rope_head_dim, + ], + dim=-1, + ) + k_pe_prior = k_pe_prior.reshape( + bsz, 1, compressed_kv_prior.shape[1], self.qk_rope_head_dim + ) + + # I. Scores and softmax + prior_scores = torch.matmul( + q_pe, k_pe_prior.transpose(2, 3) + ) + torch.einsum("bhqc,blc->bhql", q_nope, compressed_kv_prior) + prior_scores *= self.softmax_scale + prior_scores = torch.where( + attention_mask, prior_scores, torch.finfo(prior_scores.dtype).min + ) + prior_scores = prior_scores.to(torch.float32) + + softmax_prior, softmax_active = manual_softmax( + prior_scores, active_scores, is_speculation=False + ) + softmax_prior, softmax_active = ( + softmax_prior.to(k_pe.dtype), + softmax_active.to(k_pe.dtype), + ) + + # II. Attention result with V absorb + x = torch.einsum("bhql,blc->bhqc", softmax_active, compressed_kv) + attn_active = torch.einsum("bhqc,hdc->bhqd", x, out_absorb) + + x = torch.einsum("bhql,blc->bhqc", softmax_prior, compressed_kv_prior) + attn_prior = torch.einsum("bhqc,hdc->bhqd", x, out_absorb) + + attn_output = attn_prior + attn_active + + # Transpose BHSD -> BSHD + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + + # Z = Z.Wo + attn_output = self.o_proj(attn_output) + + # Concatenate k_pe and compressed_kv into combined format for KVCacheManager. + # KVCacheManager expects (key, value) tuple each shaped (bsz, 1, seq_len, head_dim). + # For MLA, we store [k_pe | compressed_kv] in both slots (V is duplicate). + combined = torch.cat([k_pe.squeeze(1), compressed_kv], dim=-1).unsqueeze(1) + past_key_value = (combined, combined) + + return attn_output, past_key_value, cos_cache, sin_cache + + +class NeuronGlm4MoeLiteDecoderLayer(nn.Module): + """ + GLM-4.7-Flash decoder layer with MLA attention and Dense MLP or MoE. + + Layer 0 uses a dense MLP; layers 1-46 use Mixture-of-Experts (MoE). + """ + + def __init__(self, config: Glm4MoeLiteInferenceConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + self.layer_idx = layer_idx + self.is_dense_layer = layer_idx < getattr(config, "first_k_dense_replace", 1) + + self.self_attn = Glm4MoeLiteAttention(config=config, layer_idx=layer_idx) + self.moe_fused_nki_kernel_enabled = getattr( + config.neuron_config, "moe_fused_nki_kernel_enabled", False + ) + + self.input_layernorm = get_rmsnorm_cls()( + config.hidden_size, + eps=config.rms_norm_eps, + ) + self.post_attention_layernorm = get_rmsnorm_cls()( + config.hidden_size, + eps=config.rms_norm_eps, + ) + + if self.is_dense_layer: + self.mlp = Glm4MoeLiteDenseMLP(config) + elif self.moe_fused_nki_kernel_enabled: + self.mlp = initialize_moe_module( + config=config, + rmsnorm=self.post_attention_layernorm, + init_tkg_module=True, + ) + else: + self.mlp = initialize_moe_module(config=config) + + # Swap in Glm4MoeLiteRouter (GroupLimitedRouter + routed_scaling_factor) + if not self.is_dense_layer: + self.mlp.router = Glm4MoeLiteRouter( + routed_scaling_factor=getattr(config, "routed_scaling_factor", 1.8), + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + n_group=getattr(config, "n_group", 1), + topk_group=getattr(config, "topk_group", 1), + dtype=config.neuron_config.router_config.dtype, + act_fn=config.neuron_config.router_config.act_fn, + sequence_parallel_enabled=config.neuron_config.sequence_parallel_enabled, + sequence_dimension=1, + # MoEFusedTKG requires transposed router weights + store_transposed_weights=self.moe_fused_nki_kernel_enabled, + ) + # Also update the router reference in MoEFusedTKG (if present) + if ( + hasattr(self.mlp, "moe_fused_tkg") + and self.mlp.moe_fused_tkg is not None + ): + self.mlp.moe_fused_tkg.router = self.mlp.router + # Register correction bias directly on MoEFusedTKG as a parameter + # so that XLA tracing captures it as a weight input (not a constant). + # Accessing it via self.router.e_score_correction_bias inside the NKI + # kernel call doesn't get captured by XLA's weight tracking. + if hasattr(self.mlp.router, "e_score_correction_bias"): + self.mlp.moe_fused_tkg.correction_bias = ( + self.mlp.router.e_score_correction_bias + ) + + # FP8 mode: create separate shared expert MLP (BF16) since MoEFusedTKG + # doesn't support shared_experts in the fused kernel. + self.has_separate_shared_expert = not self.is_dense_layer and hasattr( + config, "_fp8_shared_expert_intermediate_size" + ) + if self.has_separate_shared_expert: + self.shared_experts = Glm4MoeLiteSharedExpertMLP(config) + + self.qkv_kernel_enabled = config.neuron_config.qkv_kernel_enabled + self.sequence_parallel_enabled = config.neuron_config.sequence_parallel_enabled + self.qkv_kernel_fused_rmsnorm = not self.sequence_parallel_enabled + self.moe_mask_padded_tokens = config.neuron_config.moe_mask_padded_tokens + self.config = config + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + padding_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states: input to the layer of shape (batch, seq_len, embed_dim) + attention_mask: mask of size (batch_size, 1, query_seq_len, key_seq_len) + position_ids: position ids of size (batch_size, sequence_length) + past_key_value: cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated. Please use `attention_mask` instead." + ) + + residual = hidden_states + + qkv_fused_rmsnorm = None + hidden_states = ModuleMarkerStartWrapper()(hidden_states) + if self.input_layernorm: + if self.qkv_kernel_enabled and self.qkv_kernel_fused_rmsnorm: + qkv_fused_rmsnorm = self.input_layernorm + else: + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, present_key_value, cos_cache, sin_cache = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + rmsnorm=qkv_fused_rmsnorm, + **kwargs, + ) + hidden_states = residual + hidden_states + + # MLP (Dense for layer 0, MoE for rest) + residual = hidden_states + if self.is_dense_layer: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states, padding_mask)[0] + else: + if not self.moe_fused_nki_kernel_enabled: + hidden_states = self.post_attention_layernorm(hidden_states) + # Save post-layernorm input for shared expert (FP8 mode) + # In fused TKG mode, the rmsnorm is fused into the MoE module, + # so we need the pre-norm input for the separate shared expert. + if self.has_separate_shared_expert: + if self.moe_fused_nki_kernel_enabled: + # In fused mode, post_attention_layernorm is passed to MoE as rmsnorm. + # The shared expert needs the normalized input. + shared_expert_input = self.post_attention_layernorm(hidden_states) + else: + shared_expert_input = hidden_states + is_speculative_decoding = ( + self.config.neuron_config.enable_fused_speculation + and (not self.config.neuron_config.is_prefill_stage) + ) + hidden_states = self.mlp( + hidden_states, + padding_mask, + is_speculative_decoding=is_speculative_decoding, + )[0] + # Add shared expert output (FP8 mode: separate BF16 computation) + if self.has_separate_shared_expert: + hidden_states = hidden_states + self.shared_experts(shared_expert_input) + hidden_states = residual + hidden_states + + # End module marker + hidden_states = ModuleMarkerEndWrapper()(hidden_states) + outputs = (hidden_states, present_key_value, cos_cache, sin_cache, None) + + return outputs + + +class NeuronGlm4MoeLiteModel(NeuronBaseModel): + """ + NeuronGlm4MoeLiteModel extends the GLM-4.7-Flash model to be traceable. + The forward function of this class is traced by NxDI. + """ + + def setup_attr_for_model(self, config: Glm4MoeLiteInferenceConfig): + self.on_device_sampling = ( + config.neuron_config.on_device_sampling_config is not None + ) + self.tp_degree = config.neuron_config.tp_degree + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.num_key_value_heads = config.num_key_value_heads + self.max_batch_size = config.neuron_config.max_batch_size + self.buckets = config.neuron_config.buckets + + def init_model(self, config: Glm4MoeLiteInferenceConfig): + self.padding_idx = getattr(config, "pad_token_id", None) + self.vocab_size = config.vocab_size + + self.embed_tokens = ParallelEmbedding( + config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=config.neuron_config.torch_dtype, + shard_across_embedding=True, + ) + self.layers = nn.ModuleList( + [ + NeuronGlm4MoeLiteDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = get_rmsnorm_cls()(self.hidden_size, eps=config.rms_norm_eps) + self.lm_head = ColumnParallelLinear( + config.hidden_size, + config.vocab_size, + gather_output=False if self.on_device_sampling else True, + bias=False, + ) + + +class NeuronGlm4MoeLiteForCausalLM(NeuronBaseForCausalLM): + """ + NxDI CausalLM wrapper for GLM-4.7-Flash (glm4_moe_lite). + """ + + _model_cls = NeuronGlm4MoeLiteModel + + @staticmethod + def load_hf_model(model_path, **kwargs): + kwargs.setdefault("torch_dtype", torch.bfloat16) + return AutoModelForCausalLM.from_pretrained( + model_path, trust_remote_code=True, **kwargs + ) + + @classmethod + def get_config_cls(cls): + return Glm4MoeLiteInferenceConfig + + @staticmethod + def convert_hf_to_neuron_state_dict( + state_dict: dict, config: Glm4MoeLiteInferenceConfig + ) -> dict: + return convert_glm4_moe_lite_hf_to_neuron_state_dict(state_dict, config) + + def get_compiler_args(self): + """Return compiler args for GLM-4.7-Flash on Neuron.""" + quantized = getattr(self.config.neuron_config, "quantized", False) + args = custom_compiler_args(quantized=quantized) + args += f" --lnc={self.config.neuron_config.logical_nc_config}" + return args + + +class Glm4MoeLiteGenerationAdapter(HuggingFaceGenerationAdapter): + """Generation adapter with position_ids fix for transformers 5.x. + + In transformers >= 5.0, _update_model_kwargs_for_generation appends to + position_ids and passes them back via kwargs on subsequent decode steps. + However, NxDI's HuggingFaceGenerationAdapter.prepare_inputs_for_generation + only recomputes position_ids when they are None in kwargs. When they are + present (stale, growing), it passes them unchanged — leading to incorrect + RoPE and KV cache positioning during autoregressive decode. + + Fix: Remove stale position_ids from kwargs so the base class recomputes + them correctly from attention_mask each step. + """ + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + sampling_params=None, + adapter_ids=None, + divergence_idx=None, + **kwargs, + ): + # Remove stale position_ids so base class recomputes from attention_mask + kwargs.pop("position_ids", None) + return super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + sampling_params=sampling_params, + adapter_ids=adapter_ids, + divergence_idx=divergence_idx, + **kwargs, + ) diff --git a/contrib/models/GLM-4.7-Flash/src/rope_util.py b/contrib/models/GLM-4.7-Flash/src/rope_util.py new file mode 100644 index 00000000..ff8e5426 --- /dev/null +++ b/contrib/models/GLM-4.7-Flash/src/rope_util.py @@ -0,0 +1,61 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Standard Rotary Position Embedding for GLM-4.7-Flash. + +GLM-4.7-Flash uses standard RoPE without YaRN scaling extension. +This is a simplified version of the DeepSeek-V3 rope_util.py. +""" + +import torch +import torch.utils.checkpoint +from torch import nn + + +class Glm4MoeLiteRotaryEmbedding(nn.Module): + """Standard RoPE with no scaling (factor=1.0, no YaRN).""" + + def __init__(self, dim, max_position_embeddings=200000, base=10000, device=None): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + + def get_freqs_table(self, device, seq_len): + t = torch.arange(seq_len, device=device, dtype=torch.float32) + freqs = torch.outer(t, self.inv_freq.to(t.device)) + return freqs + + def forward(self, x, seq_len=None, freqs=None): + device = x.device + dtype = x.dtype + if freqs is None: + freqs = self.get_freqs_table(device, seq_len) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos().to(dtype) + sin = emb.sin().to(dtype) + return cos, sin + + +def rotate_fn(x: torch.Tensor): + """Interleaved rotation: pairs (x0,x1) -> (-x1,x0), (x2,x3) -> (-x3,x2), ...""" + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def apply_rotary_pos_emb(q: torch.Tensor, cos, sin, position_ids): + """Apply rotary position embedding with interleaved layout.""" + cos_sglang = cos.chunk(2, dim=-1)[0][position_ids] + sin_sglang = sin.chunk(2, dim=-1)[0][position_ids] + + sin = sin_sglang.repeat_interleave(2, dim=-1)[0] + cos = cos_sglang.repeat_interleave(2, dim=-1)[0] + + q_embed = (q * cos) + rotate_fn(q) * sin + return q_embed.to(q.dtype) diff --git a/contrib/models/GLM-4.7-Flash/test/__init__.py b/contrib/models/GLM-4.7-Flash/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/GLM-4.7-Flash/test/integration/__init__.py b/contrib/models/GLM-4.7-Flash/test/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/GLM-4.7-Flash/test/integration/compile_fp8.py b/contrib/models/GLM-4.7-Flash/test/integration/compile_fp8.py new file mode 100644 index 00000000..1ee84cac --- /dev/null +++ b/contrib/models/GLM-4.7-Flash/test/integration/compile_fp8.py @@ -0,0 +1,583 @@ +#!/usr/bin/env python3 +"""Compile and test GLM-4.7-Flash with FP8 quantized MoE expert weights. + +This script: + 1. Quantizes expert weights from BF16 to FP8 E4M3 (if not already done) + 2. Compiles the model with quantized=True (MoEFusedTKG path) + 3. Loads and runs inference to validate correctness + 4. Benchmarks TPOT for FP8 vs BF16 comparison + +Usage: + # Full pipeline (quantize + compile + test): + python compile_fp8.py --quantize --compile --test + + # Just compile (quantized checkpoint already exists): + python compile_fp8.py --compile + + # Just test (compiled model already exists): + python compile_fp8.py --test + + # Quick benchmark (existing compiled model): + python compile_fp8.py --benchmark +""" + +import argparse +import os +import sys +import time + +import torch + +sys.path.insert(0, "/mnt/models/GLM-4.7-Flash-contrib") +os.environ["NEURON_RT_VISIBLE_CORES"] = "0-3" +os.environ["UNSAFE_FP8FNCAST"] = "1" + +from neuronx_distributed_inference.models.config import ( + MoENeuronConfig, + OnDeviceSamplingConfig, +) +from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config +from transformers import AutoConfig, AutoTokenizer, GenerationConfig + +from src.modeling_glm4_moe_lite import ( + Glm4MoeLiteGenerationAdapter, + Glm4MoeLiteInferenceConfig, + NeuronGlm4MoeLiteForCausalLM, +) + +# Register glm4_moe_lite config (not in transformers registry) +try: + from transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig + + class Glm4MoeLiteConfig(Glm4MoeConfig): + model_type = "glm4_moe_lite" + + AutoConfig.register("glm4_moe_lite", Glm4MoeLiteConfig) +except Exception: + pass # Already registered or glm4_moe not available + +# Paths +MODEL_PATH = "/mnt/models/GLM-4.7-Flash" +QUANTIZED_PATH = "/mnt/models/GLM-4.7-Flash-FP8" +COMPILED_FP8_PATH = "/mnt/models/compiled_glm4_fp8_sob" + +# Config +BATCH_SIZE = 4 +CTX_BATCH_SIZE = 1 # CTE processes 1 prompt at a time (eliminates left-padding issues) +SEQ_LEN = 16384 +TP_DEGREE = 4 + +# Bucketing config: CTE bucket sizes for short-prompt TTFT optimization +# Each bucket compiles a separate NEFF, so more buckets = longer compile time +# With 4 CTE buckets: compile time ~60-80 min (vs ~20 min unbucketed) +ENABLE_BUCKETING = True +CTE_BUCKETS = [128, 512, 2048, 4096, 8192, 16384] +# TKG buckets: single bucket for maximum compiler optimization of the TKG NEFF. +# Multiple TKG buckets cause massive TPOT regression (6.8x) due to bucket switching overhead. +TKG_BUCKETS = [16384] + + +def step_quantize(): + """Step 1: Quantize expert weights to FP8.""" + print("\n" + "=" * 70) + print("STEP 1: Quantize Expert Weights (BF16 -> FP8 E4M3)") + print("=" * 70) + + if os.path.exists(os.path.join(QUANTIZED_PATH, "model.safetensors.index.json")): + print(f" Quantized checkpoint already exists at {QUANTIZED_PATH}") + print(" Skipping quantization. Use --force-quantize to redo.") + return + + # Run the quantization script + import subprocess + + result = subprocess.run( + [ + sys.executable, + "/mnt/models/GLM-4.7-Flash-contrib/scripts/quantize_experts_fp8.py", + "--model-path", + MODEL_PATH, + "--output-path", + QUANTIZED_PATH, + "--tp-degree", + str(TP_DEGREE), + ], + capture_output=False, + ) + if result.returncode != 0: + raise RuntimeError(f"Quantization failed with return code {result.returncode}") + + +def step_compile(): + """Step 2: Compile model with FP8 quantized weights.""" + print("\n" + "=" * 70) + print("STEP 2: Compile Model (FP8 Quantized, MoEFusedTKG)") + print("=" * 70) + + neuron_config = MoENeuronConfig( + tp_degree=TP_DEGREE, + batch_size=BATCH_SIZE, + ctx_batch_size=CTX_BATCH_SIZE, + tkg_batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + enable_bucketing=ENABLE_BUCKETING, + context_encoding_buckets=CTE_BUCKETS if ENABLE_BUCKETING else None, + token_generation_buckets=TKG_BUCKETS if ENABLE_BUCKETING else None, + flash_decoding_enabled=False, + logical_nc_config=2, + # Enable continuous batching for proper KV cache indexing with ctx_batch_size=1 + is_continuous_batching=True, + # FP8 quantization config + quantized=True, + quantization_type="expert_wise_per_channel_symmetric", + quantization_dtype="f8e4m3", + quantized_checkpoints_path=QUANTIZED_PATH, + modules_to_not_convert=[ + "lm_head", + "embed_tokens", + "self_attn", + "norm", + "layers.0.mlp", + "shared_experts", + "router", + ], + # Use MoEFusedTKG for FP8 routed experts (shared experts handled separately) + moe_fused_nki_kernel_enabled=True, + ) + + inf_config = Glm4MoeLiteInferenceConfig( + neuron_config, + load_config=load_pretrained_config(MODEL_PATH), + ) + + print(f"\nConfig:") + print(f" TP degree: {TP_DEGREE}") + print(f" Batch size: {BATCH_SIZE}") + print(f" CTE batch size: {CTX_BATCH_SIZE}") + print(f" TKG batch size: {BATCH_SIZE}") + print(f" Seq len: {SEQ_LEN}") + print(f" Dtype: bfloat16 (FP8 experts)") + print(f" LNC: 2") + print(f" Bucketing: {ENABLE_BUCKETING}") + if ENABLE_BUCKETING: + print(f" CTE buckets: {CTE_BUCKETS}") + print(f" TKG buckets: {TKG_BUCKETS}") + print(f" Continuous batching: True") + print(f" Quantized: True") + print(f" Quantization type: expert_wise_per_channel_symmetric") + print(f" Quantized checkpoint: {QUANTIZED_PATH}") + print(f" MoE kernel: MoEFusedTKG (FP8 routed) + separate shared expert (BF16)") + print(f"\nCompiling to: {COMPILED_FP8_PATH}") + + os.makedirs(COMPILED_FP8_PATH, exist_ok=True) + + t0 = time.time() + model = NeuronGlm4MoeLiteForCausalLM(MODEL_PATH, inf_config) + model.compile(COMPILED_FP8_PATH) + compile_time = time.time() - t0 + + print( + f"\nCompilation complete in {compile_time:.1f}s ({compile_time / 60:.1f} min)" + ) + print(f"Artifacts saved to: {COMPILED_FP8_PATH}") + return compile_time + + +def step_test(): + """Step 3: Load and run inference to validate correctness.""" + print("\n" + "=" * 70) + print("STEP 3: Test Inference (FP8 Quantized)") + print("=" * 70) + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "right" # Must use right-padding with ctx_batch_size=1 + + # Load model + neuron_config = MoENeuronConfig( + tp_degree=TP_DEGREE, + batch_size=BATCH_SIZE, + ctx_batch_size=CTX_BATCH_SIZE, + tkg_batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + enable_bucketing=ENABLE_BUCKETING, + context_encoding_buckets=CTE_BUCKETS if ENABLE_BUCKETING else None, + token_generation_buckets=TKG_BUCKETS if ENABLE_BUCKETING else None, + flash_decoding_enabled=False, + logical_nc_config=2, + is_continuous_batching=True, + quantized=True, + quantization_type="expert_wise_per_channel_symmetric", + quantization_dtype="f8e4m3", + quantized_checkpoints_path=QUANTIZED_PATH, + modules_to_not_convert=[ + "lm_head", + "embed_tokens", + "self_attn", + "norm", + "layers.0.mlp", + "shared_experts", + "router", + ], + moe_fused_nki_kernel_enabled=True, + ) + + inf_config = Glm4MoeLiteInferenceConfig( + neuron_config, + load_config=load_pretrained_config(MODEL_PATH), + ) + + print(" Loading compiled model...") + t0 = time.time() + model = NeuronGlm4MoeLiteForCausalLM(COMPILED_FP8_PATH, inf_config) + model.load(COMPILED_FP8_PATH) + load_time = time.time() - t0 + print(f" Loaded in {load_time:.1f}s") + + gen_model = Glm4MoeLiteGenerationAdapter(model) + gen_config = GenerationConfig( + do_sample=True, + top_k=1, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + # Test prompts — with ctx_batch_size=1, CTE processes one sequence at a time + # so left-padding between sequences is no longer an issue. We can now pass + # mixed-length prompts in a single batch. + test_prompts = [ + "The capital of France is", + "In machine learning, a transformer model works by", + "The square root of 144 is", + "Python is a programming language known for", + ] + + # Test 1: Mixed-length batch (the main test for ctx_batch_size=1 fixing left-padding) + print(f"\n TEST 1: Mixed-length batch (BS={BATCH_SIZE}, ctx_bs={CTX_BATCH_SIZE})") + print(" This tests whether ctx_batch_size=1 fixes left-padding for mixed lengths.") + + inputs = tokenizer(test_prompts, return_tensors="pt", padding=True) + print(f" Input shape: {inputs.input_ids.shape}") + print(f" Pad token positions per sequence:") + for i, mask in enumerate(inputs.attention_mask): + n_pad = (mask == 0).sum().item() + print(f" [{i}] '{test_prompts[i][:40]}...' -> {n_pad} pad tokens") + + t0 = time.time() + outputs = gen_model.generate( + inputs.input_ids, + generation_config=gen_config, + attention_mask=inputs.attention_mask, + max_new_tokens=50, + ) + gen_time = time.time() - t0 + + decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True) + generated = decoded + + print(f"\n Generation complete in {gen_time:.2f}s") + print("\n --- Outputs ---") + for i, (prompt, output) in enumerate(zip(test_prompts, generated)): + generated_part = output[len(prompt) :] + print(f" [{i}] {prompt}") + print(f" -> {generated_part[:100]}...") + print() + + # Validate outputs are non-empty and coherent (basic sanity) + all_valid = True + for i, output in enumerate(generated): + if len(output) <= len(test_prompts[i]): + print(f" WARNING: Output {i} has no generated tokens!") + all_valid = False + # Check for repetitive garbage patterns + gen_part = output[len(test_prompts[i]) :] + if len(gen_part) > 10: + # Check for excessive repetition (same char/word repeated) + chars = set(gen_part[:20]) + if len(chars) <= 3: + print(f" WARNING: Output {i} appears to be repetitive garbage!") + all_valid = False + + if all_valid: + print(" PASS: All outputs generated successfully") + else: + print(" FAIL: Some outputs are empty or garbage") + + return all_valid + + +def step_benchmark(): + """Step 4: Benchmark TPOT with FP8.""" + print("\n" + "=" * 70) + print("STEP 4: TPOT Benchmark (FP8 Quantized)") + print("=" * 70) + + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = "right" # Must use right-padding with ctx_batch_size=1 + + neuron_config = MoENeuronConfig( + tp_degree=TP_DEGREE, + batch_size=BATCH_SIZE, + ctx_batch_size=CTX_BATCH_SIZE, + tkg_batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + enable_bucketing=ENABLE_BUCKETING, + context_encoding_buckets=CTE_BUCKETS if ENABLE_BUCKETING else None, + token_generation_buckets=TKG_BUCKETS if ENABLE_BUCKETING else None, + flash_decoding_enabled=False, + logical_nc_config=2, + is_continuous_batching=True, + quantized=True, + quantization_type="expert_wise_per_channel_symmetric", + quantization_dtype="f8e4m3", + quantized_checkpoints_path=QUANTIZED_PATH, + modules_to_not_convert=[ + "lm_head", + "embed_tokens", + "self_attn", + "norm", + "layers.0.mlp", + "shared_experts", + "router", + ], + moe_fused_nki_kernel_enabled=True, + ) + + inf_config = Glm4MoeLiteInferenceConfig( + neuron_config, + load_config=load_pretrained_config(MODEL_PATH), + ) + + print(" Loading compiled model...") + model = NeuronGlm4MoeLiteForCausalLM(COMPILED_FP8_PATH, inf_config) + model.load(COMPILED_FP8_PATH) + + gen_model = Glm4MoeLiteGenerationAdapter(model) + gen_config = GenerationConfig( + do_sample=True, + top_k=1, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + # Warmup + print(" Warming up (5 iterations)...") + warmup_text = "The quick brown fox " * 10 + warmup_inputs = tokenizer( + [warmup_text] * BATCH_SIZE, return_tensors="pt", padding=True + ) + for _ in range(5): + gen_model.generate( + warmup_inputs.input_ids, + generation_config=gen_config, + attention_mask=warmup_inputs.attention_mask, + max_new_tokens=10, + ) + + # Measure TPOT: generate 128 tokens, measure E2E time + # TPOT ≈ (E2E - TTFT) / (n_tokens - 1) + print("\n Measuring TPOT (128 in, 128 out)...") + prompt = "In the field of quantum computing, " * 8 # ~128 tokens + inputs = tokenizer( + [prompt] * BATCH_SIZE, + return_tensors="pt", + padding=True, + max_length=128, + truncation=True, + ) + + # Measure TTFT (1 token) + ttft_times = [] + for _ in range(10): + t0 = time.time() + gen_model.generate( + inputs.input_ids, + generation_config=gen_config, + attention_mask=inputs.attention_mask, + max_new_tokens=1, + ) + ttft_times.append(time.time() - t0) + + # Measure E2E (128 tokens) + e2e_times = [] + n_tokens_list = [] + for _ in range(10): + t0 = time.time() + outputs = gen_model.generate( + inputs.input_ids, + generation_config=gen_config, + attention_mask=inputs.attention_mask, + max_new_tokens=128, + ) + e2e_times.append(time.time() - t0) + n_tokens_list.append(outputs.shape[1] - inputs.input_ids.shape[1]) + + # Calculate metrics + avg_ttft = sum(ttft_times) / len(ttft_times) + avg_e2e = sum(e2e_times) / len(e2e_times) + avg_n_tokens = sum(n_tokens_list) / len(n_tokens_list) + avg_tpot = (avg_e2e - avg_ttft) / max(avg_n_tokens - 1, 1) + throughput = BATCH_SIZE / avg_tpot # Total tok/s + + print(f"\n Results (FP8, BS={BATCH_SIZE}, 128in/128out):") + print(f" TTFT: {avg_ttft * 1000:.1f} ms") + print(f" TPOT: {avg_tpot * 1000:.1f} ms") + print(f" Throughput: {throughput:.1f} tok/s (batch)") + print(f" E2E: {avg_e2e * 1000:.0f} ms") + print(f" Tokens: {avg_n_tokens:.0f}") + + # Compare with BF16 baseline (from previous benchmarks) + bf16_tpot_ms = 419.0 # BS=16 from formal benchmark (adjusted for BS=4) + # BS=4 BF16 TPOT from formal_benchmark_bs4.json + print(f"\n Comparison with BF16 baseline:") + print(f" BF16 TPOT (BS=16): 419.0 ms → {BATCH_SIZE * 1000 / 419.0:.1f} tok/s") + print( + f" FP8 TPOT (BS={BATCH_SIZE}): {avg_tpot * 1000:.1f} ms → {throughput:.1f} tok/s" + ) + + return { + "ttft_ms": avg_ttft * 1000, + "tpot_ms": avg_tpot * 1000, + "throughput_tok_s": throughput, + "e2e_ms": avg_e2e * 1000, + "batch_size": BATCH_SIZE, + "dtype": "fp8_e4m3_experts", + } + + +def main(): + parser = argparse.ArgumentParser( + description="GLM-4.7-Flash FP8 compilation and testing" + ) + parser.add_argument("--quantize", action="store_true", help="Run quantization step") + parser.add_argument("--compile", action="store_true", help="Run compilation step") + parser.add_argument("--test", action="store_true", help="Run inference test") + parser.add_argument("--benchmark", action="store_true", help="Run TPOT benchmark") + parser.add_argument( + "--force-quantize", + action="store_true", + help="Force re-quantization even if checkpoint exists", + ) + parser.add_argument( + "--no-bucketing", + action="store_true", + help="Disable bucketing (single 4096 CTE bucket, faster compile)", + ) + parser.add_argument( + "--ctx-batch-size", + type=int, + default=None, + help="Override CTE batch size (default: 1). Use 4 to match legacy behavior.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=None, + help="Override TKG batch size (default: 4). Adjusts compiled model path.", + ) + parser.add_argument( + "--max-cte-bucket", + type=int, + default=None, + help="Maximum CTE bucket size (default: 16384). Reduce for larger BS to avoid CTE OOM.", + ) + parser.add_argument( + "--seq-len", + type=int, + default=None, + help="Override SEQ_LEN (default: 16384). Reduces KV cache size for larger BS.", + ) + parser.add_argument( + "--all", + action="store_true", + help="Run all steps (quantize + compile + test + benchmark)", + ) + args = parser.parse_args() + + # Apply --no-bucketing override + global \ + ENABLE_BUCKETING, \ + CTX_BATCH_SIZE, \ + COMPILED_FP8_PATH, \ + BATCH_SIZE, \ + CTE_BUCKETS, \ + SEQ_LEN, \ + TKG_BUCKETS + if args.no_bucketing: + ENABLE_BUCKETING = False + if args.ctx_batch_size is not None: + CTX_BATCH_SIZE = args.ctx_batch_size + if CTX_BATCH_SIZE != 1: + # Use different output path for non-default ctx_batch_size + COMPILED_FP8_PATH = ( + f"/mnt/models/compiled_glm4_fp8_bucketed_ctx{CTX_BATCH_SIZE}" + ) + if args.batch_size is not None: + BATCH_SIZE = args.batch_size + COMPILED_FP8_PATH = f"/mnt/models/compiled_glm4_fp8_bs{BATCH_SIZE}" + if args.max_cte_bucket is not None: + CTE_BUCKETS = [b for b in CTE_BUCKETS if b <= args.max_cte_bucket] + if args.seq_len is not None: + SEQ_LEN = args.seq_len + CTE_BUCKETS = [b for b in CTE_BUCKETS if b <= SEQ_LEN] + TKG_BUCKETS = [SEQ_LEN] + if args.batch_size is not None: + COMPILED_FP8_PATH = ( + f"/mnt/models/compiled_glm4_fp8_bs{BATCH_SIZE}_seq{SEQ_LEN}" + ) + else: + COMPILED_FP8_PATH = f"/mnt/models/compiled_glm4_fp8_seq{SEQ_LEN}" + + if args.all: + args.quantize = args.compile = args.test = args.benchmark = True + + if not any([args.quantize, args.compile, args.test, args.benchmark]): + parser.print_help() + print( + "\nPlease specify at least one step: --quantize, --compile, --test, --benchmark, or --all" + ) + sys.exit(1) + + print("=" * 70) + print("GLM-4.7-Flash FP8 Expert Quantization Pipeline") + print(f" Model: {MODEL_PATH}") + print(f" Quantized: {QUANTIZED_PATH}") + print(f" Compiled: {COMPILED_FP8_PATH}") + print( + f" Config: BS={BATCH_SIZE}, CTX_BS={CTX_BATCH_SIZE}, SEQ={SEQ_LEN}, TP={TP_DEGREE}, LNC=2" + ) + print(f" Bucketing: {ENABLE_BUCKETING}") + if ENABLE_BUCKETING: + print(f" CTE buckets: {CTE_BUCKETS}") + print(f" TKG buckets: {TKG_BUCKETS}") + print("=" * 70) + + if args.quantize: + step_quantize() + + if args.compile: + step_compile() + + if args.test: + step_test() + + if args.benchmark: + step_benchmark() + + print("\n" + "=" * 70) + print("DONE") + print("=" * 70) + + +if __name__ == "__main__": + main() diff --git a/contrib/models/GLM-4.7-Flash/test/integration/test_model.py b/contrib/models/GLM-4.7-Flash/test/integration/test_model.py new file mode 100644 index 00000000..d0307d16 --- /dev/null +++ b/contrib/models/GLM-4.7-Flash/test/integration/test_model.py @@ -0,0 +1,326 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Integration tests for GLM-4.7-Flash on Neuron. + +Tests compilation, loading, and inference accuracy using the full 30B model +on a trn2.3xlarge instance with TP=4. + +Environment variables: + GLM4_MODEL_PATH Path to HF model weights (required) + GLM4_COMPILED_PATH Path to compiled artifacts (default: /tmp/glm4_traced) + GLM4_TP_DEGREE Tensor parallelism degree (default: 4) + GLM4_SEQ_LEN Max sequence length (default: 4096) + GLM4_BATCH_SIZE Batch size (default: 4, minimum for NCC_IBIR297 workaround) + +Prerequisites: + - trn2.3xlarge with LNC=2 (4 NeuronCores) + - NxDI installed (neuronx_distributed_inference >= 0.9) + - transformers >= 5.0 + - Model weights downloaded (59 GB) + +Usage: + # Full model (requires trn2.3xlarge + model weights): + GLM4_MODEL_PATH=/mnt/models/GLM-4.7-Flash \ + GLM4_COMPILED_PATH=/mnt/models/compiled_glm4_4096 \ + pytest test/integration/test_model.py --capture=tee-sys + + # Quick validation (pre-compiled): + GLM4_MODEL_PATH=/mnt/models/GLM-4.7-Flash \ + GLM4_COMPILED_PATH=/mnt/models/compiled_glm4_4096 \ + pytest test/integration/test_model.py -k "test_inference_accuracy" --capture=tee-sys + +Known Issues: + - Minimum batch_size=4 required (NCC_IBIR297 compiler issue at small TP degrees) + - transformers >= 5.0 requires Glm4MoeLiteGenerationAdapter (position_ids fix) + - NKI MoE kernel unavailable in SDK 2.29 (uses torch blockwise fallback) +""" + +import gc +import json +import os +import sys +import time + +import pytest +import torch + +# Ensure the contrib root is on sys.path +_CONTRIB_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) +if _CONTRIB_ROOT not in sys.path: + sys.path.insert(0, _CONTRIB_ROOT) + +# ── Configuration ─────────────────────────────────────────────────────── + +MODEL_PATH = os.environ.get("GLM4_MODEL_PATH", "") +COMPILED_PATH = os.environ.get("GLM4_COMPILED_PATH", "/tmp/glm4_traced") +TP_DEGREE = int(os.environ.get("GLM4_TP_DEGREE", "4")) +SEQ_LEN = int(os.environ.get("GLM4_SEQ_LEN", "4096")) +BATCH_SIZE = int(os.environ.get("GLM4_BATCH_SIZE", "4")) + +if not MODEL_PATH: + pytest.skip("GLM4_MODEL_PATH not set", allow_module_level=True) + + +# ── Fixtures ──────────────────────────────────────────────────────────── + + +@pytest.fixture(scope="module") +def neuron_config(): + """Create MoE Neuron config for GLM-4.7-Flash.""" + from neuronx_distributed_inference.models.config import ( + MoENeuronConfig, + OnDeviceSamplingConfig, + ) + + return MoENeuronConfig( + tp_degree=TP_DEGREE, + batch_size=BATCH_SIZE, + ctx_batch_size=BATCH_SIZE, + tkg_batch_size=BATCH_SIZE, + seq_len=SEQ_LEN, + torch_dtype=torch.bfloat16, + on_device_sampling_config=OnDeviceSamplingConfig(top_k=1), + enable_bucketing=False, + flash_decoding_enabled=False, + logical_nc_config=2, + ) + + +@pytest.fixture(scope="module") +def inf_config(neuron_config): + """Create GLM-4.7-Flash inference config.""" + from neuronx_distributed_inference.utils.hf_adapter import load_pretrained_config + from src.modeling_glm4_moe_lite import Glm4MoeLiteInferenceConfig + + return Glm4MoeLiteInferenceConfig( + neuron_config, load_config=load_pretrained_config(MODEL_PATH) + ) + + +@pytest.fixture(scope="module") +def compiled_model(inf_config): + """Compile or load the model.""" + from src.modeling_glm4_moe_lite import NeuronGlm4MoeLiteForCausalLM + + if os.path.exists(os.path.join(COMPILED_PATH, "model.pt")): + print(f"\n Loading pre-compiled model from {COMPILED_PATH}") + model = NeuronGlm4MoeLiteForCausalLM(COMPILED_PATH, inf_config) + model.load(COMPILED_PATH) + else: + print(f"\n Compiling model to {COMPILED_PATH}") + os.makedirs(COMPILED_PATH, exist_ok=True) + model = NeuronGlm4MoeLiteForCausalLM(MODEL_PATH, inf_config) + model.compile(COMPILED_PATH) + model.load(COMPILED_PATH) + + yield model + del model + gc.collect() + + +@pytest.fixture(scope="module") +def tokenizer(): + """Load tokenizer.""" + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + if tok.pad_token is None: + tok.pad_token = tok.eos_token + return tok + + +@pytest.fixture(scope="module") +def gen_model(compiled_model): + """Create generation adapter (fixes transformers 5.x position_ids issue).""" + from src.modeling_glm4_moe_lite import Glm4MoeLiteGenerationAdapter + + return Glm4MoeLiteGenerationAdapter(compiled_model) + + +# ── Tests ─────────────────────────────────────────────────────────────── + + +class TestGlm4MoeLiteInference: + """Integration tests for GLM-4.7-Flash on Neuron.""" + + def test_model_loads(self, compiled_model): + """Verify model loads successfully and all cores are utilized.""" + assert compiled_model is not None + + def test_first_token_accuracy(self, gen_model, tokenizer): + """Verify first-token accuracy matches CPU reference (exact token ID match). + + These reference token IDs were captured from CPU FP32 inference with greedy + decoding. Exact token ID match provides strong accuracy validation without + requiring the full 30B model to fit on CPU during test execution. + """ + from transformers import GenerationConfig + + gen_config = GenerationConfig( + do_sample=True, + top_k=1, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + # Test prompts with known CPU reference token IDs (captured from FP32 reference) + test_cases = [ + ("The capital of France is", 12089, " Paris"), + ("In machine learning, a transformer model", 374, " is"), + ("def fibonacci(n):", 715, "\n"), + ] + + for prompt, expected_token_id, expected_text in test_cases: + inputs = tokenizer([prompt] * BATCH_SIZE, return_tensors="pt", padding=True) + outputs = gen_model.generate( + inputs.input_ids, + generation_config=gen_config, + attention_mask=inputs.attention_mask, + max_new_tokens=1, + ) + # Check first generated token by ID (exact match) + first_new_token = outputs[0, inputs.input_ids.shape[1]].item() + decoded = tokenizer.decode([first_new_token]) + print( + f"\n Prompt: '{prompt}' -> token_id={first_new_token} '{decoded}' " + f"(expected: {expected_token_id} '{expected_text}')" + ) + assert first_new_token == expected_token_id, ( + f"First token ID mismatch for '{prompt}': " + f"got {first_new_token} ('{decoded}'), " + f"expected {expected_token_id} ('{expected_text}')" + ) + + def test_coherent_generation(self, gen_model, tokenizer): + """Verify multi-token generation produces coherent text.""" + from transformers import GenerationConfig + + gen_config = GenerationConfig( + do_sample=True, + top_k=1, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + prompt = "The capital of France is" + inputs = tokenizer([prompt] * BATCH_SIZE, return_tensors="pt", padding=True) + outputs = gen_model.generate( + inputs.input_ids, + generation_config=gen_config, + attention_mask=inputs.attention_mask, + max_new_tokens=30, + ) + + generated_ids = outputs[0, inputs.input_ids.shape[1] :].tolist() + generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True) + print(f"\n Generated: '{generated_text[:100]}'") + + # Should contain "Paris" and be coherent + assert "Paris" in generated_text, ( + f"Expected 'Paris' in output: '{generated_text}'" + ) + assert len(generated_text) > 10, f"Generation too short: '{generated_text}'" + + def test_deterministic_outputs(self, gen_model, tokenizer): + """Verify greedy decoding produces identical outputs across runs.""" + from transformers import GenerationConfig + + gen_config = GenerationConfig( + do_sample=True, + top_k=1, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + prompt = "Hello, world" + inputs = tokenizer([prompt] * BATCH_SIZE, return_tensors="pt", padding=True) + + outputs_1 = gen_model.generate( + inputs.input_ids, + generation_config=gen_config, + attention_mask=inputs.attention_mask, + max_new_tokens=10, + ) + outputs_2 = gen_model.generate( + inputs.input_ids, + generation_config=gen_config, + attention_mask=inputs.attention_mask, + max_new_tokens=10, + ) + + assert torch.equal(outputs_1, outputs_2), ( + "Greedy decoding should be deterministic" + ) + + def test_batch_consistency(self, gen_model, tokenizer): + """Verify all sequences in batch produce identical output (same prompt).""" + from transformers import GenerationConfig + + gen_config = GenerationConfig( + do_sample=True, + top_k=1, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + prompt = "The meaning of life is" + inputs = tokenizer([prompt] * BATCH_SIZE, return_tensors="pt", padding=True) + outputs = gen_model.generate( + inputs.input_ids, + generation_config=gen_config, + attention_mask=inputs.attention_mask, + max_new_tokens=15, + ) + + # All sequences in batch should be identical (same input, greedy) + for i in range(1, BATCH_SIZE): + assert torch.equal(outputs[0], outputs[i]), ( + f"Batch inconsistency: seq 0 != seq {i}" + ) + + def test_throughput(self, gen_model, tokenizer): + """Measure and report throughput metrics.""" + from transformers import GenerationConfig + + gen_config = GenerationConfig( + do_sample=True, + top_k=1, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + + prompt = "Explain quantum computing in simple terms:" + inputs = tokenizer([prompt] * BATCH_SIZE, return_tensors="pt", padding=True) + + # Warmup + gen_model.generate( + inputs.input_ids, + generation_config=gen_config, + attention_mask=inputs.attention_mask, + max_new_tokens=5, + ) + + # Measure + max_new_tokens = 50 + t0 = time.time() + outputs = gen_model.generate( + inputs.input_ids, + generation_config=gen_config, + attention_mask=inputs.attention_mask, + max_new_tokens=max_new_tokens, + ) + elapsed = time.time() - t0 + + n_generated = outputs.shape[1] - inputs.input_ids.shape[1] + total_tokens = n_generated * BATCH_SIZE + throughput = total_tokens / elapsed + + print(f"\n Generated: {n_generated} tokens/seq") + print(f" Batch throughput: {throughput:.1f} tok/s") + print(f" Per-seq throughput: {n_generated / elapsed:.2f} tok/s") + print(f" Avg latency/token: {elapsed / n_generated * 1000:.1f} ms") + + # Minimum sanity threshold + assert throughput > 1.0, f"Throughput too low: {throughput:.2f} tok/s" diff --git a/contrib/models/GLM-4.7-Flash/test/unit/__init__.py b/contrib/models/GLM-4.7-Flash/test/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrib/models/GLM-4.7-Flash/test/unit/test_config.py b/contrib/models/GLM-4.7-Flash/test/unit/test_config.py new file mode 100644 index 00000000..8caa2e38 --- /dev/null +++ b/contrib/models/GLM-4.7-Flash/test/unit/test_config.py @@ -0,0 +1,164 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for GLM-4.7-Flash (glm4_moe_lite) inference configuration. + +CPU-only tests that validate config parsing, MLA parameter setup, +MoE routing config, dense layer detection, and KV cache shape. +""" + +import unittest +from unittest.mock import MagicMock + +import torch + +from src.modeling_glm4_moe_lite import ( + Glm4MoeLiteInferenceConfig, + Glm4MoeLiteNeuronConfig, +) +from neuronx_distributed_inference.models.config import MoENeuronConfig + + +def _make_config(**overrides): + """Create a Glm4MoeLiteInferenceConfig with GLM-4.7-Flash defaults.""" + neuron_config = MoENeuronConfig( + tp_degree=overrides.pop("tp_degree", 4), + batch_size=1, + seq_len=128, + torch_dtype=torch.bfloat16, + ) + defaults = dict( + hidden_size=2048, + num_hidden_layers=47, + num_attention_heads=20, + num_key_value_heads=1, + kv_lora_rank=512, + q_lora_rank=768, + qk_nope_head_dim=192, + qk_rope_head_dim=64, + v_head_dim=256, + n_routed_experts=64, + n_shared_experts=1, + num_experts_per_tok=4, + n_group=1, + topk_group=1, + intermediate_size=10240, + moe_intermediate_size=1536, + first_k_dense_replace=1, + vocab_size=154880, + rms_norm_eps=1e-6, + max_position_embeddings=200000, + rope_theta=10000, + routed_scaling_factor=1.8, + attention_dropout=0.0, + attention_bias=False, + ) + defaults.update(overrides) + config = Glm4MoeLiteInferenceConfig(neuron_config=neuron_config, **defaults) + return config + + +class TestConfigParsing(unittest.TestCase): + """Test basic config attribute initialization.""" + + def test_mla_parameters(self): + config = _make_config() + self.assertEqual(config.kv_lora_rank, 512) + self.assertEqual(config.q_lora_rank, 768) + self.assertEqual(config.qk_nope_head_dim, 192) + self.assertEqual(config.qk_rope_head_dim, 64) + self.assertEqual(config.v_head_dim, 256) + + def test_head_dim_override_for_kv_cache(self): + """MLA overrides head_dim to rope_dim + kv_lora_rank for KV cache allocation.""" + config = _make_config() + self.assertEqual(config.head_dim, 64 + 512) # rope_dim + kv_lora_rank = 576 + + def test_num_kv_heads_override(self): + """MLA sets num_key_value_heads=1 (MLA uses a single compressed KV, not GQA).""" + config = _make_config() + self.assertEqual(config.num_key_value_heads, 1) + + def test_moe_expert_params(self): + config = _make_config() + self.assertEqual(config.num_local_experts, 64) + self.assertEqual(config.n_shared_experts, 1) + self.assertEqual(config.num_experts_per_tok, 4) + + def test_intermediate_size_swap(self): + """intermediate_size should be swapped to moe_intermediate_size for MoE experts.""" + config = _make_config(intermediate_size=10240, moe_intermediate_size=1536) + self.assertEqual(config.intermediate_size, 1536) + self.assertEqual(config.dense_intermediate_size, 10240) + + def test_dense_layer_count(self): + """GLM-4.7-Flash has first_k_dense_replace=1 (only layer 0 is dense).""" + config = _make_config() + self.assertEqual(config.first_k_dense_replace, 1) + + def test_hidden_act_default(self): + config = _make_config() + self.assertEqual(config.hidden_act, "silu") + + +class TestNoYaRNRoPE(unittest.TestCase): + """Test that GLM-4.7-Flash does NOT inject YaRN config.""" + + def test_no_rope_scaling_injected(self): + """GLM-4.7-Flash uses standard RoPE — no rope_scaling should be injected.""" + config = _make_config() + # The config should not have YaRN-specific rope_scaling + # (unlike DeepSeek-V3 which injects a no-op YaRN config) + # Our config doesn't touch rope_scaling at all + # Just verify it doesn't crash and the relevant dims are correct + self.assertEqual(config.qk_rope_head_dim, 64) + + +class TestNeuronConfig(unittest.TestCase): + """Test Neuron-specific configuration settings.""" + + def test_disable_numeric_cc_token(self): + config = _make_config() + self.assertTrue(config.neuron_config.disable_numeric_cc_token) + + def test_neuron_config_cls(self): + self.assertEqual( + Glm4MoeLiteInferenceConfig.get_neuron_config_cls(), + Glm4MoeLiteNeuronConfig, + ) + + def test_required_attributes(self): + config = _make_config() + required = config.get_required_attributes() + self.assertIn("kv_lora_rank", required) + self.assertIn("n_routed_experts", required) + self.assertIn("moe_intermediate_size", required) + self.assertIn("qk_nope_head_dim", required) + self.assertIn("v_head_dim", required) + + def test_router_config_sigmoid(self): + """Router should use sigmoid activation for noaux_tc routing.""" + config = _make_config() + self.assertEqual(config.neuron_config.router_config.act_fn, "sigmoid") + self.assertEqual(config.neuron_config.router_config.dtype, torch.float32) + + def test_normalize_top_k_disabled(self): + """Normalization handled by router, not ExpertMLPsV2.""" + config = _make_config() + self.assertFalse(config.neuron_config.normalize_top_k_affinities) + + +class TestTPDivisibility(unittest.TestCase): + """Verify all dimensions are TP-divisible at TP=4.""" + + def test_attention_heads_divisible(self): + config = _make_config(tp_degree=4) + self.assertEqual(config.num_attention_heads % 4, 0) + + def test_vocab_size_divisible(self): + config = _make_config(tp_degree=4) + self.assertEqual(config.vocab_size % 4, 0) # 154880 / 4 = 38720 + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/GLM-4.7-Flash/test/unit/test_rope.py b/contrib/models/GLM-4.7-Flash/test/unit/test_rope.py new file mode 100644 index 00000000..cd3f43ce --- /dev/null +++ b/contrib/models/GLM-4.7-Flash/test/unit/test_rope.py @@ -0,0 +1,110 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for GLM-4.7-Flash RoPE implementation. + +Validates standard RoPE (no YaRN) produces correct embeddings. +""" + +import unittest +import math + +import torch + +from src.rope_util import ( + Glm4MoeLiteRotaryEmbedding, + apply_rotary_pos_emb, + rotate_fn, +) + + +class TestRotaryEmbedding(unittest.TestCase): + """Test Glm4MoeLiteRotaryEmbedding.""" + + def test_output_shapes(self): + """cos/sin should have shape (seq_len, dim) after forward.""" + dim = 64 + seq_len = 128 + rope = Glm4MoeLiteRotaryEmbedding(dim=dim, max_position_embeddings=200000) + # forward expects a tensor for device/dtype reference + x = torch.randn(1, 1, 1, dim) + cos, sin = rope(x, seq_len=seq_len) + self.assertEqual(cos.shape, (seq_len, dim)) + self.assertEqual(sin.shape, (seq_len, dim)) + + def test_cos_sin_bounded(self): + """cos/sin values should be in [-1, 1].""" + dim = 64 + rope = Glm4MoeLiteRotaryEmbedding(dim=dim) + x = torch.randn(1, 1, 1, dim) + cos, sin = rope(x, seq_len=1024) + self.assertTrue(cos.abs().max() <= 1.0 + 1e-6) + self.assertTrue(sin.abs().max() <= 1.0 + 1e-6) + + def test_no_scaling(self): + """Standard RoPE: inv_freq = base^(-2i/dim), no scaling factor.""" + dim = 8 + base = 10000 + rope = Glm4MoeLiteRotaryEmbedding(dim=dim, base=base) + expected_inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + torch.testing.assert_close(rope.inv_freq, expected_inv_freq) + + def test_position_0_is_identity(self): + """At position 0, cos=1 and sin=0 so RoPE is identity.""" + dim = 64 + rope = Glm4MoeLiteRotaryEmbedding(dim=dim) + x = torch.randn(1, 1, 1, dim) + cos, sin = rope(x, seq_len=1) + # At pos 0: cos=1, sin=0 for all dims + torch.testing.assert_close(cos[0], torch.ones(dim)) + torch.testing.assert_close(sin[0], torch.zeros(dim), atol=1e-6, rtol=0) + + +class TestRotateFn(unittest.TestCase): + """Test the interleaved rotation function.""" + + def test_basic_rotation(self): + """rotate_fn pairs (x0,x1) -> (-x1,x0).""" + x = torch.tensor([1.0, 2.0, 3.0, 4.0]) + expected = torch.tensor([-2.0, 1.0, -4.0, 3.0]) + result = rotate_fn(x.unsqueeze(0)).squeeze(0) + torch.testing.assert_close(result, expected) + + def test_rotation_preserves_norm(self): + """Rotation should preserve the L2 norm of each pair.""" + x = torch.randn(2, 4, 8, 64) + rotated = rotate_fn(x) + # L2 norm of each pair should be preserved + x_pairs = x.view(*x.shape[:-1], -1, 2) + r_pairs = rotated.view(*rotated.shape[:-1], -1, 2) + x_norms = x_pairs.norm(dim=-1) + r_norms = r_pairs.norm(dim=-1) + torch.testing.assert_close(x_norms, r_norms, atol=1e-5, rtol=1e-5) + + +class TestApplyRotaryPosEmb(unittest.TestCase): + """Test the full apply_rotary_pos_emb function.""" + + def test_output_shape_preserved(self): + """Output shape should match input shape.""" + bsz, num_heads, seq_len, head_dim = 2, 5, 16, 64 + rope = Glm4MoeLiteRotaryEmbedding(dim=head_dim) + q = torch.randn(bsz, num_heads, seq_len, head_dim) + cos, sin = rope(q, seq_len=seq_len) + position_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, -1) + q_embed = apply_rotary_pos_emb(q, cos, sin, position_ids) + self.assertEqual(q_embed.shape, q.shape) + + def test_dtype_preserved(self): + """Output dtype should match input dtype.""" + bsz, num_heads, seq_len, head_dim = 1, 1, 4, 64 + rope = Glm4MoeLiteRotaryEmbedding(dim=head_dim) + q = torch.randn(bsz, num_heads, seq_len, head_dim, dtype=torch.bfloat16) + cos, sin = rope(q, seq_len=seq_len) + position_ids = torch.arange(seq_len).unsqueeze(0) + q_embed = apply_rotary_pos_emb(q, cos, sin, position_ids) + self.assertEqual(q_embed.dtype, torch.bfloat16) + + +if __name__ == "__main__": + unittest.main() diff --git a/contrib/models/GLM-4.7-Flash/test/unit/test_router.py b/contrib/models/GLM-4.7-Flash/test/unit/test_router.py new file mode 100644 index 00000000..f801376e --- /dev/null +++ b/contrib/models/GLM-4.7-Flash/test/unit/test_router.py @@ -0,0 +1,228 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for GLM-4.7-Flash router (Glm4MoeLiteRouter). + +Tests sigmoid activation, noaux_tc top-k selection, L1 normalization, +and routed_scaling_factor application. + +Note: GroupLimitedRouter requires distributed parallel state to be initialized. +We mock get_expert_model_parallel_size and get_tensor_model_parallel_group +for CPU-only unit testing. +""" + +import pytest +import torch +import torch.nn.functional as F +from torch import nn +from unittest.mock import patch + + +class ReferenceGlm4Router(nn.Module): + """ + Reference implementation of GLM-4.7-Flash routing logic. + + With n_group=1, topk_group=1, the group logic is a no-op — it reduces + to sigmoid + bias + topk + L1-norm + scale. + """ + + def __init__(self, num_experts, top_k, hidden_size, routed_scaling_factor): + super().__init__() + self.num_experts = num_experts + self.top_k = top_k + self.routed_scaling_factor = routed_scaling_factor + self.weight = nn.Parameter(torch.empty(num_experts, hidden_size)) + self.e_score_correction_bias = nn.Parameter(torch.zeros(num_experts)) + + def forward(self, hidden_states): + # Linear + sigmoid in fp64 (matching GroupLimitedRouter) + router_logits = F.linear(hidden_states.float(), self.weight.float()) + scores = torch.sigmoid(router_logits.to(torch.float64)).to(hidden_states.dtype) + + # With n_group=1: no group selection, just topk on (scores + bias) + scores_for_choice = scores + self.e_score_correction_bias.unsqueeze(0) + _, topk_indices = torch.topk(scores_for_choice, k=self.top_k) + + # Gather, L1-norm, scale + topk_weights = scores.gather(1, topk_indices) + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights * self.routed_scaling_factor + + return topk_indices, topk_weights + + +def _mock_get_expert_model_parallel_size(): + return 1 + + +def _mock_get_tensor_model_parallel_group(): + return None + + +# Patch NxD parallel state for CPU testing +_PARALLEL_PATCHES = [ + patch( + "neuronx_distributed.modules.moe.routing.get_expert_model_parallel_size", + _mock_get_expert_model_parallel_size, + ), + patch( + "neuronx_distributed.modules.moe.routing.get_tensor_model_parallel_group", + _mock_get_tensor_model_parallel_group, + ), +] + + +def _create_neuron_router( + num_experts=64, top_k=4, hidden_size=2048, routed_scaling_factor=1.8 +): + """Create Glm4MoeLiteRouter with mocked parallel state.""" + from src.modeling_glm4_moe_lite import Glm4MoeLiteRouter + + for p in _PARALLEL_PATCHES: + p.start() + try: + router = Glm4MoeLiteRouter( + routed_scaling_factor=routed_scaling_factor, + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + n_group=1, + topk_group=1, + dtype=torch.float32, + sequence_parallel_enabled=False, + sequence_dimension=1, + ) + finally: + for p in _PARALLEL_PATCHES: + p.stop() + return router + + +@pytest.fixture +def router_config(): + return dict( + num_experts=64, + top_k=4, + hidden_size=2048, + routed_scaling_factor=1.8, + ) + + +@pytest.fixture +def router_and_ref(router_config): + """Create a Glm4MoeLiteRouter and matching reference, with shared weights.""" + ref = ReferenceGlm4Router(**router_config) + neuron_router = _create_neuron_router(**router_config) + + with torch.no_grad(): + nn.init.normal_(ref.weight, std=0.01) + nn.init.normal_(ref.e_score_correction_bias, std=0.001) + neuron_router.linear_router.weight.copy_(ref.weight) + neuron_router.e_score_correction_bias.copy_(ref.e_score_correction_bias) + + return neuron_router, ref + + +class TestGlm4MoeLiteRouter: + def test_routed_scaling_factor(self, router_and_ref): + neuron_router, _ = router_and_ref + assert neuron_router.routed_scaling_factor == 1.8 + + def test_e_score_correction_bias_registered(self, router_and_ref): + neuron_router, _ = router_and_ref + assert hasattr(neuron_router, "e_score_correction_bias") + assert neuron_router.e_score_correction_bias.shape == (64,) + + def test_n_group_topk_group(self, router_and_ref): + """GLM-4.7-Flash uses n_group=1, topk_group=1 (no group selection).""" + neuron_router, _ = router_and_ref + assert neuron_router.n_group == 1 + assert neuron_router.topk_group == 1 + + def test_expert_selection_matches_reference(self, router_and_ref, router_config): + """Expert indices must be consistent between separate forward calls.""" + neuron_router, ref = router_and_ref + torch.manual_seed(42) + x = torch.randn(16, router_config["hidden_size"]) + + # Run twice — should be deterministic + _, _, expert_index_1 = neuron_router(x) + _, _, expert_index_2 = neuron_router(x) + + assert torch.equal(expert_index_1, expert_index_2), ( + f"Router is non-deterministic.\nRun 1: {expert_index_1[:3]}\nRun 2: {expert_index_2[:3]}" + ) + + def test_expert_weights_match_reference(self, router_and_ref, router_config): + """Expert weights (normalized + scaled) must be self-consistent.""" + neuron_router, ref = router_and_ref + torch.manual_seed(42) + x = torch.randn(16, router_config["hidden_size"]) + + _, expert_affinities, expert_index = neuron_router(x) + + # Gather the non-zero weights for selected experts + neuron_weights = expert_affinities.gather(1, expert_index) + + # Weights should sum to routed_scaling_factor (L1 norm + scale) + weight_sums = neuron_weights.sum(dim=-1) + expected = router_config["routed_scaling_factor"] + torch.testing.assert_close( + weight_sums, torch.full_like(weight_sums, expected), atol=1e-4, rtol=1e-4 + ) + + # All weights should be positive (sigmoid outputs are positive) + assert (neuron_weights > 0).all(), ( + "All selected expert weights should be positive" + ) + + def test_output_shapes(self, router_and_ref, router_config): + """Router outputs have correct shapes.""" + neuron_router, _ = router_and_ref + T = 16 + x = torch.randn(T, router_config["hidden_size"]) + + router_logits, expert_affinities, expert_index = neuron_router(x) + assert router_logits.shape == (T, router_config["num_experts"]) + assert expert_affinities.shape == (T, router_config["num_experts"]) + assert expert_index.shape == (T, router_config["top_k"]) + + def test_topk_indices_valid(self, router_and_ref, router_config): + """Top-k indices should be in [0, num_experts).""" + neuron_router, _ = router_and_ref + x = torch.randn(8, router_config["hidden_size"]) + + _, _, topk_idx = neuron_router(x) + assert (topk_idx >= 0).all() + assert (topk_idx < router_config["num_experts"]).all() + + def test_scaling_factor_sum(self, router_and_ref, router_config): + """Weights should sum to routed_scaling_factor per token (L1 norm + scale).""" + neuron_router, _ = router_and_ref + torch.manual_seed(42) + x = torch.randn(32, router_config["hidden_size"]) + + _, expert_affinities, expert_index = neuron_router(x) + topk_weights = expert_affinities.gather(1, expert_index) + weight_sums = topk_weights.sum(dim=-1) + + expected = router_config["routed_scaling_factor"] + torch.testing.assert_close( + weight_sums, torch.full_like(weight_sums, expected), atol=1e-4, rtol=1e-4 + ) + + def test_sparsity_pattern(self, router_and_ref, router_config): + """Only top_k experts should have non-zero affinities per token.""" + neuron_router, _ = router_and_ref + torch.manual_seed(42) + x = torch.randn(8, router_config["hidden_size"]) + + _, expert_affinities, _ = neuron_router(x) + nonzero_per_token = (expert_affinities != 0).sum(dim=-1) + assert (nonzero_per_token == router_config["top_k"]).all(), ( + f"Expected {router_config['top_k']} non-zero per token, got {nonzero_per_token}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/contrib/models/GLM-4.7-Flash/test/unit/test_weight_conversion.py b/contrib/models/GLM-4.7-Flash/test/unit/test_weight_conversion.py new file mode 100644 index 00000000..4a0387a9 --- /dev/null +++ b/contrib/models/GLM-4.7-Flash/test/unit/test_weight_conversion.py @@ -0,0 +1,288 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Unit tests for GLM-4.7-Flash state dict conversion. + +Tests: +1. State dict key renaming (router, e_score_correction_bias) +2. Expert weight fusion (gate_proj + up_proj -> gate_up_proj, down_proj stacking) +3. Dense layers are skipped (first_k_dense_replace=1, only layer 0) +4. Rank utility tensors added +5. MTP layer weights removed +6. No FP8 dequantization (native BF16) +""" + +import unittest +from unittest.mock import MagicMock + +import torch + +from src.modeling_glm4_moe_lite import ( + convert_glm4_moe_lite_hf_to_neuron_state_dict, +) + + +def _make_mock_config( + num_hidden_layers=5, + num_local_experts=8, + tp_degree=4, + first_k_dense_replace=1, + hidden_size=64, + intermediate_size=32, +): + """Create a lightweight mock config for state dict conversion tests.""" + config = MagicMock() + config.num_hidden_layers = num_hidden_layers + config.num_local_experts = num_local_experts + config.first_k_dense_replace = first_k_dense_replace + config.hidden_size = hidden_size + config.intermediate_size = intermediate_size + config.neuron_config = MagicMock() + config.neuron_config.tp_degree = tp_degree + return config + + +def _make_hf_moe_state_dict( + layer_idx, num_experts=8, hidden_size=64, intermediate_size=32, dtype=torch.float32 +): + """Create a fake HF MoE layer state dict with router + experts.""" + sd = {} + # Router + sd[f"layers.{layer_idx}.mlp.gate.weight"] = torch.randn( + num_experts, hidden_size, dtype=dtype + ) + sd[f"layers.{layer_idx}.mlp.gate.e_score_correction_bias"] = torch.randn( + num_experts, dtype=dtype + ) + # Experts + for e in range(num_experts): + sd[f"layers.{layer_idx}.mlp.experts.{e}.gate_proj.weight"] = torch.randn( + intermediate_size, hidden_size, dtype=dtype + ) + sd[f"layers.{layer_idx}.mlp.experts.{e}.up_proj.weight"] = torch.randn( + intermediate_size, hidden_size, dtype=dtype + ) + sd[f"layers.{layer_idx}.mlp.experts.{e}.down_proj.weight"] = torch.randn( + hidden_size, intermediate_size, dtype=dtype + ) + # Shared experts (pass through unchanged) + sd[f"layers.{layer_idx}.mlp.shared_experts.gate_proj.weight"] = torch.randn( + intermediate_size, hidden_size, dtype=dtype + ) + sd[f"layers.{layer_idx}.mlp.shared_experts.up_proj.weight"] = torch.randn( + intermediate_size, hidden_size, dtype=dtype + ) + sd[f"layers.{layer_idx}.mlp.shared_experts.down_proj.weight"] = torch.randn( + hidden_size, intermediate_size, dtype=dtype + ) + return sd + + +class TestStateDictConversion(unittest.TestCase): + """Tests for convert_glm4_moe_lite_hf_to_neuron_state_dict.""" + + def test_rank_util_tensors_added(self): + """Rank utility tensors must be added for TP sharding.""" + config = _make_mock_config(num_hidden_layers=2, first_k_dense_replace=2) + sd = {} + result = convert_glm4_moe_lite_hf_to_neuron_state_dict(sd, config) + + assert "rank_util.rank" in result + torch.testing.assert_close( + result["rank_util.rank"], torch.arange(0, 4, dtype=torch.int32) + ) + for i in range(2): + assert f"layers.{i}.self_attn.rank_util.rank" in result + + def test_dense_layer_skipped(self): + """Layer 0 (dense, first_k_dense_replace=1) should not have MoE conversion.""" + config = _make_mock_config( + num_hidden_layers=3, first_k_dense_replace=1, num_local_experts=4 + ) + sd = {} + # Add dense layer weights (layer 0) + sd["layers.0.mlp.gate_proj.weight"] = torch.randn(32, 64) + sd["layers.0.mlp.up_proj.weight"] = torch.randn(32, 64) + sd["layers.0.mlp.down_proj.weight"] = torch.randn(64, 32) + # Add MoE layers (layers 1-2) + sd.update(_make_hf_moe_state_dict(1, num_experts=4)) + sd.update(_make_hf_moe_state_dict(2, num_experts=4)) + + result = convert_glm4_moe_lite_hf_to_neuron_state_dict(sd, config) + + # Dense layer weights should be unchanged + assert "layers.0.mlp.gate_proj.weight" in result + + # MoE layers should have converted keys + assert "layers.1.mlp.router.linear_router.weight" in result + assert "layers.1.mlp.expert_mlps.mlp_op.gate_up_proj.weight" in result + assert "layers.2.mlp.router.linear_router.weight" in result + + def test_router_rename(self): + """gate.weight -> router.linear_router.weight""" + config = _make_mock_config( + num_hidden_layers=1, first_k_dense_replace=0, num_local_experts=4 + ) + sd = _make_hf_moe_state_dict(0, num_experts=4) + + original_router_w = sd["layers.0.mlp.gate.weight"].clone() + result = convert_glm4_moe_lite_hf_to_neuron_state_dict(sd, config) + + assert "layers.0.mlp.gate.weight" not in result + assert "layers.0.mlp.router.linear_router.weight" in result + torch.testing.assert_close( + result["layers.0.mlp.router.linear_router.weight"], original_router_w + ) + + def test_e_score_correction_bias_renamed(self): + """gate.e_score_correction_bias should be renamed to router.e_score_correction_bias.""" + config = _make_mock_config( + num_hidden_layers=1, first_k_dense_replace=0, num_local_experts=4 + ) + sd = _make_hf_moe_state_dict(0, num_experts=4) + original_bias = sd["layers.0.mlp.gate.e_score_correction_bias"].clone() + + result = convert_glm4_moe_lite_hf_to_neuron_state_dict(sd, config) + + assert "layers.0.mlp.gate.e_score_correction_bias" not in result + assert "layers.0.mlp.router.e_score_correction_bias" in result + torch.testing.assert_close( + result["layers.0.mlp.router.e_score_correction_bias"], original_bias + ) + + def test_expert_gate_up_fusion(self): + """Per-expert gate_proj + up_proj -> fused gate_up_proj [num_experts, hidden, 2*intermediate].""" + num_experts = 4 + hidden_size = 64 + intermediate_size = 32 + config = _make_mock_config( + num_hidden_layers=1, + first_k_dense_replace=0, + num_local_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + ) + sd = _make_hf_moe_state_dict( + 0, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + ) + + # Save originals for verification + gate_projs = [] + up_projs = [] + for e in range(num_experts): + gate_projs.append(sd[f"layers.0.mlp.experts.{e}.gate_proj.weight"].clone()) + up_projs.append(sd[f"layers.0.mlp.experts.{e}.up_proj.weight"].clone()) + + result = convert_glm4_moe_lite_hf_to_neuron_state_dict(sd, config) + + fused_key = "layers.0.mlp.expert_mlps.mlp_op.gate_up_proj.weight" + assert fused_key in result + fused = result[fused_key] + assert fused.shape == (num_experts, hidden_size, 2 * intermediate_size) + + # Verify fusion: gate_proj.T in first half, up_proj.T in second half + for e in range(num_experts): + torch.testing.assert_close(fused[e, :, :intermediate_size], gate_projs[e].T) + torch.testing.assert_close(fused[e, :, intermediate_size:], up_projs[e].T) + + # Per-expert keys should be removed + for e in range(num_experts): + assert f"layers.0.mlp.experts.{e}.gate_proj.weight" not in result + assert f"layers.0.mlp.experts.{e}.up_proj.weight" not in result + + def test_expert_down_proj_stacking(self): + """Per-expert down_proj -> stacked [num_experts, intermediate, hidden].""" + num_experts = 4 + hidden_size = 64 + intermediate_size = 32 + config = _make_mock_config( + num_hidden_layers=1, + first_k_dense_replace=0, + num_local_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + ) + sd = _make_hf_moe_state_dict( + 0, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + ) + + down_projs = [] + for e in range(num_experts): + down_projs.append(sd[f"layers.0.mlp.experts.{e}.down_proj.weight"].clone()) + + result = convert_glm4_moe_lite_hf_to_neuron_state_dict(sd, config) + + stacked_key = "layers.0.mlp.expert_mlps.mlp_op.down_proj.weight" + assert stacked_key in result + stacked = result[stacked_key] + assert stacked.shape == (num_experts, intermediate_size, hidden_size) + + for e in range(num_experts): + torch.testing.assert_close(stacked[e], down_projs[e].T) + assert f"layers.0.mlp.experts.{e}.down_proj.weight" not in result + + def test_shared_experts_unchanged(self): + """Shared expert weights should pass through without renaming.""" + config = _make_mock_config( + num_hidden_layers=1, first_k_dense_replace=0, num_local_experts=4 + ) + sd = _make_hf_moe_state_dict(0, num_experts=4) + + original_shared_gate = sd[ + "layers.0.mlp.shared_experts.gate_proj.weight" + ].clone() + result = convert_glm4_moe_lite_hf_to_neuron_state_dict(sd, config) + + assert "layers.0.mlp.shared_experts.gate_proj.weight" in result + torch.testing.assert_close( + result["layers.0.mlp.shared_experts.gate_proj.weight"], original_shared_gate + ) + + def test_mtp_layer_removed(self): + """MTP layer weights (layer N where N=num_hidden_layers) should be removed.""" + num_hidden_layers = 3 + config = _make_mock_config( + num_hidden_layers=num_hidden_layers, + first_k_dense_replace=0, + num_local_experts=4, + ) + sd = _make_hf_moe_state_dict(0, num_experts=4) + sd.update(_make_hf_moe_state_dict(1, num_experts=4)) + sd.update(_make_hf_moe_state_dict(2, num_experts=4)) + # Add MTP layer (layer 3 = num_hidden_layers) + sd[f"layers.{num_hidden_layers}.embed_tokens.weight"] = torch.randn( + 154880, 2048 + ) + sd[f"layers.{num_hidden_layers}.enorm.weight"] = torch.randn(2048) + + result = convert_glm4_moe_lite_hf_to_neuron_state_dict(sd, config) + + # MTP keys should be gone + assert f"layers.{num_hidden_layers}.embed_tokens.weight" not in result + assert f"layers.{num_hidden_layers}.enorm.weight" not in result + # Regular layer keys should still exist + assert "layers.0.mlp.router.linear_router.weight" in result + + def test_no_fp8_handling(self): + """GLM-4.7-Flash uses native BF16 — no FP8 scale_inv keys expected.""" + config = _make_mock_config( + num_hidden_layers=1, first_k_dense_replace=0, num_local_experts=4 + ) + sd = _make_hf_moe_state_dict(0, num_experts=4, dtype=torch.bfloat16) + + result = convert_glm4_moe_lite_hf_to_neuron_state_dict(sd, config) + + # No scale_inv keys should exist + scale_keys = [k for k in result if "scale_inv" in k] + self.assertEqual(len(scale_keys), 0) + + +if __name__ == "__main__": + unittest.main()