Gemma 4 E2B text model support#326

Open
Adi2K wants to merge 8 commits into trymirai:main from Adi2K:gemma4-clean

Conversation


@Adi2K Adi2K commented Apr 7, 2026

Summary

Add text-only inference support for Google's Gemma 4 E2B architecture:

  • Per-Layer Embedding (PLE): Model-level PLE computation with gated projection, per-layer injection, dedicated ArrayIds and scratch buffers
  • Heterogeneous attention: Per-layer head_dim (256/512) and num_groups, head_dim=512 Metal+CPU kernel variants, per-layer buffer reshaping
  • KV cache sharing: Shared layers (15-34) alias KV from source layers via Rc, skip KV update for shared layers
  • Double-wide MLP: Per-layer hidden_dims, max-based buffer sizing for scratch buffers
  • Per-layer scalar: Layer-wise scaling factor applied after residual + PLE
  • V-norm: QK normalization with all-ones value normalization buffer
  • Proportional RoPE: Per-type (global/local) head_dim for frequency computation
  • Fused norm compatibility: Explicit residual add before PLE/scalar to work correctly with fused RMSNorm
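The PLE injection path above can be sketched in plain Rust. This is a simplified illustration with hypothetical names, using f32 slices in place of uzu's Metal scratch buffers; in the actual implementation the gated multiply runs on GPU (ElementWiseMulStrided) and the combination scale comes from the PLE config.

```rust
// Hypothetical sketch of per-layer embedding (PLE) injection with a gated
// projection. `residual` is the layer's residual stream after the explicit
// residual add; `ple_embedding` is this layer's row of the PLE table.
fn ple_inject(
    residual: &mut [f32],   // length = model_dim
    ple_embedding: &[f32],  // length = ple_dim
    projection: &[f32],     // row-major [model_dim x ple_dim]
    gate: &[f32],           // gate activations, length = model_dim
    combination_scale: f32, // scale applied when combining with the residual
) {
    let model_dim = residual.len();
    let ple_dim = ple_embedding.len();
    for i in 0..model_dim {
        // Project the PLE vector up to model_dim.
        let mut projected = 0.0f32;
        for j in 0..ple_dim {
            projected += projection[i * ple_dim + j] * ple_embedding[j];
        }
        // Gate elementwise, then add into the residual with a fixed scale.
        residual[i] += combination_scale * gate[i] * projected;
    }
}
```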

Requires trymirai/lalamo#197 (model conversion support)

Benchmark (4-bit quantized, Apple M3 16GB)

Metric            uzu    MLX     Delta
Prefill (t/s)     516.9  1039.5  -50.3%
Generation (t/s)  38.8   57.1    -32.1%

Prompt 64 tokens, generation 256 tokens, 5 runs. MLX via mlx-lm@main (0.31.2, unreleased). Same quantization config (4-bit, group_size=64, affine). Initial implementation prioritizes correctness; optimization opportunities identified (KV projection waste, buffer reuse).

Trace Validation

Full-precision model validated against HuggingFace reference (scripts/generate_traces.py):

  • 0/1 token violations (correct argmax prediction)
  • Layers 0-14: <1.5% output error (bf16 precision floor)
  • Layers 15-34 (shared): <2.5% output error (accumulated bf16 drift)
  • Mixer (attention) error: 0.23%-0.99% across all 35 layers

Scope

Text modality only, consistent with how other multimodal models (Gemma 3) are supported in uzu. Vision/audio encoders can be added incrementally.

Known Limitations

  • KV projection waste: Shared layers still compute unused K/V projections (~33% wasted compute, performance-only). Fix: skip K/V projection for shared layers.
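The proposed fix reduces to gating K/V projection on whether the layer owns its cache. A minimal sketch, assuming a per-layer `kv_source_layer` field as described in the config commit (stand-in types, not uzu's real ones):

```rust
// Stand-in for the per-layer config: Some(src) means this layer aliases
// the KV cache of layer `src` and never writes its own K/V.
struct LayerCfg {
    kv_source_layer: Option<usize>,
}

// Only layers that own their KV cache need to compute K/V projections;
// shared layers (15-34 in Gemma 4 E2B) can skip them entirely.
fn needs_kv_projection(cfg: &LayerCfg) -> bool {
    cfg.kv_source_layer.is_none()
}
```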

Test Plan

  • cargo +nightly fmt --check passes
  • CARGO_ENCODED_RUSTFLAGS="-Dwarnings" cargo build passes (zero warnings)
  • cargo test --package uzu --features tracing --lib passes (97/97)
  • Tracer integration test passes (0 token violations)
  • Full-precision model generates coherent text (18.2 t/s)
  • 4-bit quantized model generates coherent text (40.7 t/s)
  • HEAD_DIM=512 attention kernel tests added (single-pass + two-pass)

Hardware Note

Developed and tested on Apple M3 with 16GB RAM. Both full-precision (9.1GB) and 4-bit quantized (6.6GB) models load and generate correctly.

Adi2K added 8 commits April 7, 2026 19:27
Add config support for Gemma 4 architecture features:
- Per-layer hidden_dims (heterogeneous MLP widths)
- Per-layer head_dim and num_groups (heterogeneous attention)
- KV cache sharing (shared_kv_layer_period, shared_kv_sources)
- PLE config (ple_dim, ple_vocab_size, ple_num_layers)
- Per-layer scalar (has_layer_scalar)
- Value norm config (value_norm)
- Logit soft-cap validation
- model_shape.rs: per-layer attention dims (max_head_dim, max_num_groups),
  max hidden_dim for double-wide MLP, per-layer RoPE dims
- cache_layers.rs: KV cache sharing via Rc::clone for shared layers,
  skip KV update for shared layers
- scratch_buffers.rs: PLE gate/projection/embedding buffers
- state.rs: PLE auxiliary buffers in ForwardPassState
- rope_buffers.rs: per-type RoPE dim support
…se mul

Add Metal and CPU kernel support for:
- head_dim=512 (Gemma 4 global attention layers) in single-pass and
  two-pass attention kernels
- ElementWiseMulStrided kernel for PLE gating
…ecoder

- decoder.rs: model-level PLE computation (embed, project, norm, gate,
  inject per-layer), per-layer scalar multiplication, shared KV layer
  handling
- layer/executables.rs: per-layer PLE injection, V-norm (all-ones
  value normalization), per-layer head_dim/num_groups reshape
- embedding.rs: logit_soft_cap support, LookupOnly embedding mode
- attention.rs: head_dim=512 kernel registration
- qk_norm.rs: QK normalization support
- generate_traces.py: HuggingFace reference trace generation for Gemma 4
Add backward-compatible support for the config structure exported by
trymirai/lalamo#197:

- PLEModelConfig: deserialize nested PLE config with field name mapping
  (model_projection_scale → ple_projection_scale, input_scale →
  ple_combination_scale)
- Per-layer fields: hidden_dim, kv_source_layer, ple_config on
  TransformerLayerConfig, with fallback derivation of top-level arrays
- normalize_values: bool → value_norm_config conversion for V-norm
- global_rope_dim/global_head_dim → partial_rope_dim derivation for
  global attention layers
- All new fields use #[serde(default)] for backward compatibility

Both our existing flat config format and their nested format now
deserialize correctly into the same internal DecoderConfig.
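The fallback derivation of top-level arrays from per-layer fields can be illustrated with a simplified sketch (hypothetical helper and stand-in types; the real code operates on TransformerLayerConfig and DecoderConfig):

```rust
// Stand-in for a per-layer config entry from the nested format.
struct LayerConfig {
    hidden_dim: Option<usize>,
}

// If the flat format already provided a top-level hidden_dims array, keep
// it; otherwise derive one from per-layer overrides, falling back to the
// model-wide hidden_dim for layers that don't override it.
fn derive_hidden_dims(
    top_level: Option<Vec<usize>>,
    layers: &[LayerConfig],
    default_hidden_dim: usize,
) -> Option<Vec<usize>> {
    top_level.or_else(|| {
        // Only synthesize the array if at least one layer overrides the width.
        if layers.iter().any(|l| l.hidden_dim.is_some()) {
            Some(
                layers
                    .iter()
                    .map(|l| l.hidden_dim.unwrap_or(default_hidden_dim))
                    .collect(),
            )
        } else {
            None
        }
    })
}
```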
…comments

Fix fused-norm interaction with PLE and layer_scalar: the fused RMSNorm
PRs (trymirai#319, trymirai#320) defer the MLP residual add to the next layer's
pre_attention_norm, but PLE and layer_scalar must operate on the full
residual sum. Add explicit TensorAddSwap before PLE/scalar, and disable
fused residual_add on the subsequent layer's pre_attention_norm and
output_norm accordingly. Fix trace output to copy Main directly for
layers with explicit residual add.

Also remove unnecessary comments per AGENTS.md coding standards.
…ests

- Restructure PleModelWeights and PleLayerComponents into proper structs
  with non-optional fields, stored as Option<T> on parent types
- Array<B> → B::Buffer for zero-bias buffers (matches delta_net convention)
- Replace panic! with descriptive error for layer_scalar unsupported dtype
- Rename abbreviated variables (nh/ng/hd) to full names
- Vec → Box<[T]> for immutable config fields (InnerModelConfig)
- Functional struct update for apply_mixer_conversions
- Simplify Option chains with or_else
- pub → pub(crate) for normalize_values deserialization shim
- Add #[source] to ParameterError in qk_norm
- Add HEAD_DIM=512 attention kernel tests (single-pass + two-pass)

@chatgpt-codex-connector (bot) left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 65c24fd6e3


max_rope_dim = max_rope_dim.max(attn.partial_rope_dim.unwrap_or(hd));
let nh = attn.num_heads.unwrap_or(decoder_config.num_heads);
let ng = attn.num_groups.unwrap_or(decoder_config.num_groups);
let rope_dim = hd;

P1 Badge Preserve partial RoPE dims when sizing rotary buffers

This computes rope_dim from head_dim unconditionally, which drops attn.partial_rope_dim for models that rotate only part of each head. In those configs, RoPE cos/sin tables become too wide and the rope kernel rotates channels that should remain untouched, changing attention outputs for existing partial-RoPE models. Please derive the per-layer rope dimension from partial_rope_dim.unwrap_or(head_dim) (as before) when building global/local rope sizes.

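The suggested fix amounts to something like the following (simplified, hypothetical helper; the real code sizes the global/local rope tables in model_shape.rs):

```rust
// Derive the per-layer rope dimension from partial_rope_dim when present,
// falling back to head_dim, so partial-RoPE models keep rotating only the
// intended channels.
fn rope_dim_for_layer(partial_rope_dim: Option<usize>, head_dim: usize) -> usize {
    partial_rope_dim.unwrap_or(head_dim)
}
```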

let has_layer_scalar = if self.has_layer_scalar {
true
} else {
tf.layer_configs.iter().any(|l| l.ple_config.as_ref().is_some_and(|p| p.has_layer_scalar))

P1 Badge Keep layer-scalar enablement per layer

The conversion collapses per-layer ple_config.has_layer_scalar into a single global boolean via any(...), and that value is then written into every DecoderLayerConfig. If only some transformer layers actually have layer_scalar weights, later layer construction will try to load missing weights and panic, or apply scalar where it should be disabled. This should be propagated per layer (or explicitly validated as homogeneous before conversion).

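Per-layer propagation could look roughly like this (stand-in types; the real conversion writes into each DecoderLayerConfig):

```rust
// Stand-ins for the nested config types referenced in the review.
struct PleConfig {
    has_layer_scalar: bool,
}
struct LayerConfig {
    ple_config: Option<PleConfig>,
}

// Instead of collapsing has_layer_scalar into one global bool via any(...),
// compute an enablement flag per layer: the top-level flag, or that layer's
// own ple_config setting.
fn layer_scalar_flags(global: bool, layers: &[LayerConfig]) -> Vec<bool> {
    layers
        .iter()
        .map(|l| {
            global
                || l.ple_config
                    .as_ref()
                    .is_some_and(|p| p.has_layer_scalar)
        })
        .collect()
}
```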

decoder_config
.hidden_dims
.as_ref()
.map(|dims| dims[layer_index])

P2 Badge Guard hidden_dims indexing against short configs

Layer construction indexes hidden_dims with dims[layer_index] without bounds checks, but decoder deserialization does not enforce that hidden_dims.len() == num_layers. A shorter hidden_dims array now causes an index-out-of-bounds panic during model load. Please validate lengths at config parse time or use checked access with a fallback to hidden_dim.

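The checked-access variant suggested here is straightforward (hypothetical helper name):

```rust
// Avoid an index-out-of-bounds panic at model load when hidden_dims is
// shorter than num_layers: use checked access with a fallback to the
// model-wide hidden_dim.
fn layer_hidden_dim(
    hidden_dims: Option<&[usize]>,
    layer_index: usize,
    fallback_hidden_dim: usize,
) -> usize {
    hidden_dims
        .and_then(|dims| dims.get(layer_index).copied())
        .unwrap_or(fallback_hidden_dim)
}
```

Validating `hidden_dims.len() == num_layers` at config parse time would surface the error earlier and is arguably preferable to a silent fallback.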
