Add config support for Gemma 4 architecture features:
- Per-layer hidden_dims (heterogeneous MLP widths)
- Per-layer head_dim and num_groups (heterogeneous attention)
- KV cache sharing (shared_kv_layer_period, shared_kv_sources)
- PLE config (ple_dim, ple_vocab_size, ple_num_layers)
- Per-layer scalar (has_layer_scalar)
- Value norm config (value_norm)
- Logit soft-cap validation
- model_shape.rs: per-layer attention dims (max_head_dim, max_num_groups), max hidden_dim for double-wide MLP, per-layer RoPE dims
- cache_layers.rs: KV cache sharing via Rc::clone for shared layers, skip KV update for shared layers
- scratch_buffers.rs: PLE gate/projection/embedding buffers
- state.rs: PLE auxiliary buffers in ForwardPassState
- rope_buffers.rs: per-type RoPE dim support
…se mul

Add Metal and CPU kernel support for:
- head_dim=512 (Gemma 4 global attention layers) in single-pass and two-pass attention kernels
- ElementWiseMulStrided kernel for PLE gating
…ecoder

- decoder.rs: model-level PLE computation (embed, project, norm, gate, inject per layer), per-layer scalar multiplication, shared KV layer handling
- layer/executables.rs: per-layer PLE injection, V-norm (all-ones value normalization), per-layer head_dim/num_groups reshape
- embedding.rs: logit_soft_cap support, LookupOnly embedding mode
- attention.rs: head_dim=512 kernel registration
- qk_norm.rs: QK normalization support
- generate_traces.py: HuggingFace reference trace generation for Gemma 4
Add backward-compatible support for the config structure exported by trymirai/lalamo#197:
- PLEModelConfig: deserialize nested PLE config with field name mapping (model_projection_scale → ple_projection_scale, input_scale → ple_combination_scale)
- Per-layer fields: hidden_dim, kv_source_layer, ple_config on TransformerLayerConfig, with fallback derivation of top-level arrays
- normalize_values: bool → value_norm_config conversion for V-norm
- global_rope_dim/global_head_dim → partial_rope_dim derivation for global attention layers
- All new fields use #[serde(default)] for backward compatibility

Both our existing flat config format and their nested format now deserialize correctly into the same internal DecoderConfig.
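The "fallback derivation of top-level arrays" can be sketched roughly as follows. This is a simplified stand-in, not the actual uzu config types: the struct shape and helper name here are illustrative, assuming only that each layer may optionally override the model-wide `hidden_dim`.

```rust
// Illustrative, simplified shapes -- not the actual uzu config types.
#[derive(Clone, Default)]
struct TransformerLayerConfig {
    // Per-layer override; absent in the existing flat config format.
    hidden_dim: Option<usize>,
}

// Derive a top-level per-layer array only when some layer overrides the
// default; otherwise keep None so the flat format round-trips unchanged.
fn derive_hidden_dims(
    layers: &[TransformerLayerConfig],
    default_dim: usize,
) -> Option<Vec<usize>> {
    if layers.iter().all(|l| l.hidden_dim.is_none()) {
        None
    } else {
        Some(
            layers
                .iter()
                .map(|l| l.hidden_dim.unwrap_or(default_dim))
                .collect(),
        )
    }
}
```

With this shape, the flat format (no per-layer overrides) yields `None` and the nested lalamo format yields a dense per-layer array with defaults filled in.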
…comments

Fix the fused-norm interaction with PLE and layer_scalar: the fused RMSNorm PRs (trymirai#319, trymirai#320) defer the MLP residual add to the next layer's pre_attention_norm, but PLE and layer_scalar must operate on the full residual sum. Add an explicit TensorAddSwap before PLE/scalar, and disable fused residual_add on the subsequent layer's pre_attention_norm and output_norm accordingly. Fix trace output to copy Main directly for layers with an explicit residual add. Also remove unnecessary comments per AGENTS.md coding standards.
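Why the explicit add is needed can be shown with a toy scalar computation (plain numbers standing in for tensors; this is a conceptual sketch, not the kernel code):

```rust
// Toy illustration: with fused norms, the MLP output is held separately and
// added during the next layer's pre_attention_norm. A per-layer scalar must
// see the completed sum, since s * (r + m) != s * r + m in general.
fn scalar_after_explicit_add(residual: f32, mlp_out: f32, scalar: f32) -> f32 {
    scalar * (residual + mlp_out) // correct: explicit residual add first
}

fn scalar_on_partial_residual(residual: f32, mlp_out: f32, scalar: f32) -> f32 {
    scalar * residual + mlp_out // wrong: deferred add bypasses the scalar
}
```

The same distributivity argument applies to PLE injection, which also reads the full residual stream.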
…ests

- Restructure PleModelWeights and PleLayerComponents into proper structs with non-optional fields, stored as Option<T> on parent types
- Array<B> → B::Buffer for zero-bias buffers (matches delta_net convention)
- Replace panic! with a descriptive error for layer_scalar unsupported dtype
- Rename abbreviated variables (nh/ng/hd) to full names
- Vec → Box<[T]> for immutable config fields (InnerModelConfig)
- Functional struct update for apply_mixer_conversions
- Simplify Option chains with or_else
- pub → pub(crate) for the normalize_values deserialization shim
- Add #[source] to ParameterError in qk_norm
- Add HEAD_DIM=512 attention kernel tests (single-pass + two-pass)
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 65c24fd6e3
```rust
max_rope_dim = max_rope_dim.max(attn.partial_rope_dim.unwrap_or(hd));
let nh = attn.num_heads.unwrap_or(decoder_config.num_heads);
let ng = attn.num_groups.unwrap_or(decoder_config.num_groups);
let rope_dim = hd;
```
Preserve partial RoPE dims when sizing rotary buffers
This computes rope_dim from head_dim unconditionally, which drops attn.partial_rope_dim for models that rotate only part of each head. In those configs, RoPE cos/sin tables become too wide and the rope kernel rotates channels that should remain untouched, changing attention outputs for existing partial-RoPE models. Please derive the per-layer rope dimension from partial_rope_dim.unwrap_or(head_dim) (as before) when building global/local rope sizes.
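A minimal sketch of the suggested fix, using a simplified stand-in config (the field names follow the diff; the helper function is hypothetical):

```rust
// Simplified stand-in for the per-layer attention config in the diff.
struct AttentionLayerConfig {
    head_dim: Option<usize>,
    partial_rope_dim: Option<usize>,
}

// Per-layer rope dimension: prefer the partial RoPE width when set, so the
// cos/sin tables cover only the channels that are actually rotated.
fn layer_rope_dim(attn: &AttentionLayerConfig, default_head_dim: usize) -> usize {
    let head_dim = attn.head_dim.unwrap_or(default_head_dim);
    attn.partial_rope_dim.unwrap_or(head_dim)
}
```

Using this everywhere the rotary buffers are sized keeps full-rotation layers unchanged while restoring the narrower tables for partial-RoPE layers.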
```rust
let has_layer_scalar = if self.has_layer_scalar {
    true
} else {
    tf.layer_configs.iter().any(|l| l.ple_config.as_ref().is_some_and(|p| p.has_layer_scalar))
```
Keep layer-scalar enablement per layer
The conversion collapses per-layer ple_config.has_layer_scalar into a single global boolean via any(...), and that value is then written into every DecoderLayerConfig. If only some transformer layers actually have layer_scalar weights, later layer construction will try to load missing weights and panic, or apply scalar where it should be disabled. This should be propagated per layer (or explicitly validated as homogeneous before conversion).
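One way to keep the flag per layer, sketched against simplified stand-in types (not the actual uzu structs):

```rust
// Simplified stand-ins for the config types referenced in the review.
struct PleLayerConfig {
    has_layer_scalar: bool,
}

struct LayerConfig {
    ple_config: Option<PleLayerConfig>,
}

// Resolve layer_scalar per layer: the global flag enables it everywhere,
// otherwise each layer's own ple_config decides independently.
fn per_layer_scalar_flags(global_flag: bool, layers: &[LayerConfig]) -> Vec<bool> {
    layers
        .iter()
        .map(|l| {
            global_flag
                || l.ple_config.as_ref().is_some_and(|p| p.has_layer_scalar)
        })
        .collect()
}
```

Each DecoderLayerConfig would then take its own entry from this vector, so layers without layer_scalar weights never attempt to load them.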
```rust
decoder_config
    .hidden_dims
    .as_ref()
    .map(|dims| dims[layer_index])
```
Guard hidden_dims indexing against short configs
Layer construction indexes hidden_dims with dims[layer_index] without bounds checks, but decoder deserialization does not enforce that hidden_dims.len() == num_layers. A shorter hidden_dims array now causes an index-out-of-bounds panic during model load. Please validate lengths at config parse time or use checked access with a fallback to hidden_dim.
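The checked-access fallback the review asks for could look like this (a hypothetical helper; the `hidden_dims`/fallback names follow the config described above):

```rust
// Fall back to the model-wide hidden_dim when the per-layer array is absent
// or shorter than num_layers, instead of panicking on out-of-bounds access.
fn layer_hidden_dim(
    hidden_dims: Option<&[usize]>,
    layer_index: usize,
    fallback: usize,
) -> usize {
    hidden_dims
        .and_then(|dims| dims.get(layer_index).copied())
        .unwrap_or(fallback)
}
```

Validating `hidden_dims.len() == num_layers` once at config parse time is the stricter alternative, since it surfaces a malformed config as a load-time error rather than silently using the fallback width.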
Summary
Add text-only inference support for Google's Gemma 4 E2B architecture:
Requires trymirai/lalamo#197 (model conversion support)
Benchmark (4-bit quantized, Apple M3 16GB)
Prompt 64 tokens, generation 256 tokens, 5 runs. MLX via mlx-lm@main (0.31.2, unreleased). Same quantization config (4-bit, group_size=64, affine). Initial implementation prioritizes correctness; optimization opportunities identified (KV projection waste, buffer reuse).
Trace Validation
Full-precision model validated against HuggingFace reference (`scripts/generate_traces.py`).

Scope

Text modality only, consistent with how other multimodal models (Gemma 3) are supported in uzu. Vision/audio encoders can be added incrementally.
Known Limitations
Test Plan
- `cargo +nightly fmt --check` passes
- `CARGO_ENCODED_RUSTFLAGS="-Dwarnings" cargo build` passes (zero warnings)
- `cargo test --package uzu --features tracing --lib` passes (97/97)
Developed and tested on Apple M3 with 16GB RAM. Both full-precision (9.1GB) and 4-bit quantized (6.6GB) models load and generate correctly.