From 8aa8034154b41a5b2221a176236d5c25e6de49d7 Mon Sep 17 00:00:00 2001 From: David Braun <2096055+DBraun@users.noreply.github.com> Date: Mon, 16 Feb 2026 13:23:05 -0500 Subject: [PATCH 01/17] add MLX layers --- README.md | 2 + docs/mlx_guide.md | 247 ++ notebooks/mlx_benchmark.py | 289 ++ notebooks/mlx_streaming.py | 241 ++ sequence_layers/jax/types.py | 45 + sequence_layers/mlx/__init__.py | 361 ++- sequence_layers/mlx/attention.py | 1448 ++++++++++ sequence_layers/mlx/attention_test.py | 528 ++++ sequence_layers/mlx/backend_dispatch_test.py | 72 + sequence_layers/mlx/basic_types.py | 5 + sequence_layers/mlx/combinators.py | 423 +++ sequence_layers/mlx/combinators_test.py | 257 ++ sequence_layers/mlx/conditioning.py | 409 +++ sequence_layers/mlx/conditioning_test.py | 355 +++ sequence_layers/mlx/convolution.py | 1113 ++++++++ sequence_layers/mlx/convolution_test.py | 252 ++ sequence_layers/mlx/cross_backend_test.py | 2344 +++++++++++++++++ .../mlx/decoder_transformer_test.py | 248 ++ sequence_layers/mlx/dense.py | 316 +++ sequence_layers/mlx/dense_test.py | 112 + sequence_layers/mlx/dsp.py | 1193 +++++++++ sequence_layers/mlx/dsp_test.py | 463 ++++ sequence_layers/mlx/export.py | 194 ++ sequence_layers/mlx/export_test.py | 297 +++ sequence_layers/mlx/init_mapping.py | 232 ++ sequence_layers/mlx/normalization.py | 420 +++ sequence_layers/mlx/normalization_test.py | 187 ++ sequence_layers/mlx/pooling.py | 439 +++ sequence_layers/mlx/pooling_test.py | 260 ++ sequence_layers/mlx/position.py | 104 + sequence_layers/mlx/position_test.py | 73 + sequence_layers/mlx/simple.py | 816 ++++++ sequence_layers/mlx/simple_test.py | 508 ++++ sequence_layers/mlx/test_utils.py | 174 ++ sequence_layers/mlx/weight_converter.py | 585 ++++ sequence_layers/mlx/weight_converter_test.py | 463 ++++ 36 files changed, 15474 insertions(+), 1 deletion(-) create mode 100644 docs/mlx_guide.md create mode 100644 notebooks/mlx_benchmark.py create mode 100644 notebooks/mlx_streaming.py create mode 
100644 sequence_layers/mlx/attention.py create mode 100644 sequence_layers/mlx/attention_test.py create mode 100644 sequence_layers/mlx/backend_dispatch_test.py create mode 100644 sequence_layers/mlx/combinators.py create mode 100644 sequence_layers/mlx/combinators_test.py create mode 100644 sequence_layers/mlx/conditioning.py create mode 100644 sequence_layers/mlx/conditioning_test.py create mode 100644 sequence_layers/mlx/convolution.py create mode 100644 sequence_layers/mlx/convolution_test.py create mode 100644 sequence_layers/mlx/cross_backend_test.py create mode 100644 sequence_layers/mlx/decoder_transformer_test.py create mode 100644 sequence_layers/mlx/dense.py create mode 100644 sequence_layers/mlx/dense_test.py create mode 100644 sequence_layers/mlx/dsp.py create mode 100644 sequence_layers/mlx/dsp_test.py create mode 100644 sequence_layers/mlx/export.py create mode 100644 sequence_layers/mlx/export_test.py create mode 100644 sequence_layers/mlx/init_mapping.py create mode 100644 sequence_layers/mlx/normalization.py create mode 100644 sequence_layers/mlx/normalization_test.py create mode 100644 sequence_layers/mlx/pooling.py create mode 100644 sequence_layers/mlx/pooling_test.py create mode 100644 sequence_layers/mlx/position.py create mode 100644 sequence_layers/mlx/position_test.py create mode 100644 sequence_layers/mlx/simple.py create mode 100644 sequence_layers/mlx/simple_test.py create mode 100644 sequence_layers/mlx/test_utils.py create mode 100644 sequence_layers/mlx/weight_converter.py create mode 100644 sequence_layers/mlx/weight_converter_test.py diff --git a/README.md b/README.md index a223c3b..8aa3da4 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,8 @@ dependencies via `pip install -e .[dev]` (or `.[dev,tensorflow]`, etc.) to allow running tests, e.g., `pytest -n auto sequence_layers/jax` to do so over multiple workers. See the [contributing guide](CONTRIBUTING.md). 
+For MLX usage (inference on Apple Silicon), see the [MLX Backend Guide](docs/mlx_guide.md). + **Disclaimer:** This is not an officially supported Google product. ## Streamable networks, out of the box diff --git a/docs/mlx_guide.md b/docs/mlx_guide.md new file mode 100644 index 0000000..9388221 --- /dev/null +++ b/docs/mlx_guide.md @@ -0,0 +1,247 @@ +# MLX Backend Guide + +This guide covers using Sequence Layers with the MLX backend for inference on Apple Silicon. + +## Installation + +```bash +pip install sequence-layers[mlx] +``` + +## Workflow + +The MLX backend lets you define architectures with the same Linen configs used +for JAX training, then run inference on Apple Silicon GPUs via MLX. + +### 1. Define the architecture + +Use Linen configs exactly as you would for JAX: + +```python +import jax +import sequence_layers.jax as sl +from sequence_layers.jax.attention import dot_product_self_attention as dpa + +config = sl.Serial.Config([ + sl.Residual.Config([ + sl.RMSNormalization.Config(), + dpa.DotProductSelfAttention.Config( + num_heads=4, units_per_head=32, + max_past_horizon=512, max_future_horizon=0, + ), + sl.Flatten.Config(), + sl.Dense.Config(features=128), + ]), + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.Dense.Config(features=256, activation=jax.nn.gelu), + sl.Dense.Config(features=128), + ]), +]) +``` + +### 2. Train in JAX (or load existing weights) + +```python +linen_model = config.make() +variables = linen_model.init( + jax.random.PRNGKey(0), dummy_input, training=False, +) +params = variables['params'] +# ... train ... +``` + +### 3. Create the MLX model + +```python +import sequence_layers.mlx # Registers MLX backend factories. + +mlx_model = config.make(backend='mlx') +``` + +### 4. 
Load weights + +```python +import mlx.core as mx +from sequence_layers.mlx import weight_converter +from sequence_layers.mlx import ShapeDType + +weight_converter.load_linen_params( + mlx_model, params, config, + input_spec=ShapeDType((128,), mx.float32), +) +``` + +For models with `BatchNormalization`, pass `batch_stats` too: + +```python +weight_converter.load_linen_params( + mlx_model, params, config, + input_spec=ShapeDType((128,), mx.float32), + batch_stats=variables['batch_stats'], +) +``` + +For models with cross-attention (e.g. `DotProductAttention` or +`StreamingDotProductAttention`), pass `constants` so that deferred +layers can determine source dimensions: + +```python +from sequence_layers.mlx import Sequence + +source = Sequence(mx.zeros((1, 1, 64)), mx.ones((1, 1), dtype=mx.bool_)) +weight_converter.load_linen_params( + mlx_model, params, config, + input_spec=ShapeDType((128,), mx.float32), + constants={'encoder': source}, +) +``` + +### 5. Run inference + +**Full-sequence (layer mode):** + +```python +from sequence_layers.mlx import Sequence + +values = mx.random.normal(shape=(1, 100, 128)) +mask = mx.ones((1, 100), dtype=mx.bool_) +x = Sequence(values, mask) +y = mlx_model.layer(x) +``` + +**Streaming (step mode):** + +```python +spec = ShapeDType((128,), mx.float32) +state = mlx_model.get_initial_state(batch_size=1, input_spec=spec) + +for frame in audio_frames: + x = Sequence(frame, mx.ones((1, 1), dtype=mx.bool_)) + y, state = mlx_model.step(x, state) + # Process y... +``` + +**Streaming with cross-attention constants:** + +For models that use `DotProductAttention` (static cross-attention), pass +the full source as constants. 
Keys and values are pre-projected once in +`get_initial_state`: + +```python +source = Sequence(encoder_output, encoder_mask) +constants = {'encoder': source} + +state = mlx_model.get_initial_state( + batch_size=1, input_spec=spec, constants=constants, +) +for frame in audio_frames: + x = Sequence(frame, mx.ones((1, 1), dtype=mx.bool_)) + y, state = mlx_model.step(x, state, constants=constants) +``` + +For models that use `StreamingDotProductAttention`, source chunks arrive +at the same rate as the input. Each step receives the corresponding source +slice: + +```python +source_chunks = [...] # Same number of chunks as input frames. +state = mlx_model.get_initial_state( + batch_size=1, input_spec=spec, + constants={'encoder': source_chunks[0]}, +) +for frame, src in zip(audio_frames, source_chunks): + x = Sequence(frame, mx.ones((1, 1), dtype=mx.bool_)) + y, state = mlx_model.step(x, state, constants={'encoder': src}) +``` + +### 6. Export for deployment + +```python +from sequence_layers.mlx import export + +export.export_step(mlx_model, 'model.mlxfn', batch_size=1, input_spec=spec) +``` + +## Supported Layers + +The MLX backend supports the following JAX configs via `config.make(backend='mlx')`. +Layers not listed here (e.g. Conv2D/3D, Pooling2D/3D, LSTM, RGLRU, +DotProductSelfAttentionV2, Bidirectional) are JAX-only.
+ +| Category | Layers | +|---------------|--------| +| Simple | Identity, Relu, Gelu, Swish, Tanh, Sigmoid, LeakyRelu, Elu, Softmax, Softplus, Cast, Scale, Add, MaskInvalid, GatedUnit, GatedLinearUnit, GatedTanhUnit, Flatten, Reshape, ExpandDims, Squeeze, Transpose, OneHot, Embedding, Dropout, Downsample1D, Upsample1D, CheckpointName, Lambda, Logging | +| Dense | Dense, EinsumDense | +| Normalization | RMSNormalization, LayerNormalization, GroupNormalization, BatchNormalization, L2Normalize | +| Position | ApplyRotaryPositionalEncoding | +| Attention | DotProductSelfAttention, LocalDotProductSelfAttention, DotProductAttention, StreamingDotProductAttention, StreamingLocalDotProductAttention | +| Conditioning | Conditioning | +| Convolution | Conv1D, DepthwiseConv1D, Conv1DTranspose | +| Pooling | MaxPooling1D, MinPooling1D, AveragePooling1D | +| DSP | Delay, Lookahead, Window, Frame, OverlapAdd, FFT, IFFT, RFFT, IRFFT, STFT, InverseSTFT, LinearToMelSpectrogram | +| Combinators | Serial, Residual, Repeat, Parallel | + +## Key Differences from JAX + +- **Inference only** -- no training, no gradient computation. +- **Deferred initialization** -- Dense, Conv, and Attention layers create weights + on the first forward pass (Linen configs don't specify `in_features`). +- **No scan/vmap** -- `Repeat` uses N independent copies instead of stacked + params. +- **Kernel layouts** -- weights are automatically transposed by + `load_linen_params` (e.g., Dense `[in, out]` to `[out, in]`). +- **BatchNormalization** -- inference-only; uses running mean/variance. Training + mode raises an error. + +## Attention Variants + +### Self-Attention (`DotProductSelfAttention`) + +Queries, keys, and values all come from the input sequence. Supports causal +masking, grouped query attention (GQA), and optional Q/K processing networks +(e.g. RoPE). In step mode, uses a rolling KV cache. 
+ +### Local Self-Attention (`LocalDotProductSelfAttention`) + +Extends `DotProductSelfAttention` with a configurable `block_size` for step-mode +processing. The sliding-window behavior is implemented with banded visibility masks controlled by +`max_past_horizon` and `max_future_horizon`. Also supports +`attention_logits_soft_cap` for logit capping (e.g. Gemma 2 uses 50.0). + +### Cross-Attention (`DotProductAttention`) + +Queries come from the input; keys/values come from a source sequence in +`constants`. In step mode, keys and values are pre-projected once during +`get_initial_state`, so each step only projects queries. + +### Streaming Cross-Attention (`StreamingDotProductAttention`) + +Like cross-attention, but the source arrives in streaming chunks at the same +rate as the input. Keys and values are projected per step and stored in a +rolling KV buffer. Layer mode uses a banded visibility mask. + +The MLX `StreamingDotProductAttention` class handles both `StreamingDotProductAttention.Config` and +`StreamingLocalDotProductAttention.Config` from the JAX backend (they differ +only in layer-mode efficiency optimizations).
+ +## Weight Conversion Details + +`load_linen_params` handles all structural differences between Linen and MLX (in the shapes below, `uph` is `units_per_head`): + +| Layer | Linen shape | MLX shape | Transform | +|-------|------------|-----------|-----------| +| Dense | `[in, out]` | `[out, in]` | Transpose | +| Conv1D | `[k, in, out]` | `[out, k, in]` | `transpose(2,0,1)` | +| Conv1DTranspose | `[k, in, out]` | `[out, k, in]` | Flip spatial + `transpose(2,0,1)` | +| DepthwiseConv1D | `[k, in, 1]` | `[1, k, in]` | Same as Conv1D | +| Self-Attention (Combined QKV) | `[in, 3, heads, uph]` | 3x `[in, heads*uph]` | Split axis 1, reshape | +| Self-Attention (Separate Q/K/V, GQA) | Q: `[in, heads, uph]`, K: `[in, kv_heads, uph]`, V: same | Q: `[in, heads*uph]`, K/V: `[in, kv_heads*uph]` | Reshape each | +| Cross-Attention Q+KV | Q: `[in, heads, uph]`, KV: `[src, 2, heads, uph]` | Q: `[in, heads*uph]`, K/V: `[src, heads*uph]` | Reshape, split KV axis 1 | +| Repeat | `[N, ...]` | N copies of `[...]` | Slice axis 0 | +| Embedding | `[vocab, dim]` | `[vocab, dim]` | No change | +| RMS/LayerNorm | `[dim]` | `[dim]` | No change | +| GroupNorm | scale: `[dim]`, bias: `[dim]` | Same | No change | +| BatchNorm | scale/bias from `params`, mean/var from `batch_stats` | Same | No change | +| EinsumDense | `kernel` (einsum-shaped) | Same | No change | +| Conditioning (LINEAR) | `dense/kernel`, `dense/bias` | Same | No change (same einsum equation) | diff --git a/notebooks/mlx_benchmark.py b/notebooks/mlx_benchmark.py new file mode 100644 index 0000000..84d8916 --- /dev/null +++ b/notebooks/mlx_benchmark.py @@ -0,0 +1,289 @@ +# %% [markdown] +# # JAX vs MLX Step Latency Benchmark +# +# Measures token-by-token autoregressive step latency for decoder +# transformers across three execution paths: +# +# 1. **JAX (jitted)** — `jax.jit` compiled step with `block_until_ready()` +# 2. **MLX (native)** — direct `model.step()` with `mx.eval()` +# 3.
**MLX (exported)** — `.mlxfn` exported step with `mx.eval()` +# +# Runs multiple model sizes to show the crossover point where GPU +# throughput overtakes CPU. +# +# **Requires:** `pip install sequence-layers[mlx]` + +# %% [markdown] +# ## 1. Setup + +# %% +import os +import tempfile +import time + +import flax.core.meta +import jax +import jax.numpy as jnp +import mlx.core as mx +import numpy as np + +import sequence_layers.jax as sl +from sequence_layers.mlx import basic_types as bt +from sequence_layers.mlx import export +from sequence_layers.mlx import weight_converter + +Sequence = bt.Sequence +ShapeDType = bt.ShapeDType + +BATCH_SIZE = 1 +MAX_PAST = 128 +WARMUP = 10 +NUM_STEPS = 50 + +CONFIGS = [ + {'label': 'Small', 'dim': 64, 'heads': 4, 'layers': 2}, + {'label': 'Medium', 'dim': 256, 'heads': 8, 'layers': 4}, + {'label': 'Large', 'dim': 512, 'heads': 8, 'layers': 8}, +] +VOCAB_SIZE = 256 + +print(f'Benchmark config: warmup={WARMUP}, steps={NUM_STEPS}') +print(f'Model sizes: {[c["label"] for c in CONFIGS]}') + +# %% [markdown] +# ## 2. 
Helpers + + +# %% +def make_config(dim, heads, layers): + return sl.Serial.Config([ + sl.Embedding.Config( + num_embeddings=VOCAB_SIZE, + dimension=dim, + ), + sl.Repeat.Config( + num_repeats=layers, + layer=sl.Serial.Config([ + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.DotProductSelfAttention.Config( + num_heads=heads, + units_per_head=dim // heads, + max_past_horizon=MAX_PAST, + max_future_horizon=0, + query_network=( + sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ) + ), + key_network=( + sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ) + ), + ), + sl.Flatten.Config(), + ]), + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.Dense.Config( + features=dim * 4, activation=jax.nn.gelu + ), + sl.Dense.Config(features=dim), + ]), + ]), + ), + sl.RMSNormalization.Config(), + sl.Dense.Config(features=VOCAB_SIZE), + ]) + + +def bench_jax(config, model_vars, jax_model): + """Benchmark JAX jitted step.""" + jax_bound = jax_model.bind(model_vars) + jax_input_spec = jax.ShapeDtypeStruct((), jnp.int32) + jax_state = jax_bound.get_initial_state( + BATCH_SIZE, jax_input_spec, training=False + ) + + @jax.jit + def jax_step(x_values, x_mask, state): + x = sl.Sequence(x_values, x_mask) + y, new_state = jax_bound.step(x, state, training=False) + return y.values, y.mask, new_state + + x_val = jnp.zeros((BATCH_SIZE, 1), dtype=jnp.int32) + x_msk = jnp.ones((BATCH_SIZE, 1), dtype=jnp.bool_) + + # Warmup. + state = jax_state + for _ in range(WARMUP): + y_vals, y_mask, state = jax_step(x_val, x_msk, state) + jax.block_until_ready((y_vals, y_mask, state)) + + # Timed. 
+ state = jax_state + times = [] + for _ in range(NUM_STEPS): + t0 = time.perf_counter() + y_vals, y_mask, state = jax_step(x_val, x_msk, state) + jax.block_until_ready((y_vals, y_mask, state)) + times.append(time.perf_counter() - t0) + return times + + +def bench_mlx_native(mlx_model): + """Benchmark MLX native step.""" + mlx_input_spec = ShapeDType((), mx.int32) + export._materialize_deferred(mlx_model, BATCH_SIZE, mlx_input_spec) + + x = Sequence( + mx.zeros((BATCH_SIZE, 1), dtype=mx.int32), + mx.ones((BATCH_SIZE, 1), dtype=mx.bool_), + ) + + # Warmup. + state = mlx_model.get_initial_state(BATCH_SIZE, mlx_input_spec) + for _ in range(WARMUP): + y, state = mlx_model.step(x, state) + mx.eval(y.values) + + # Timed. + state = mlx_model.get_initial_state(BATCH_SIZE, mlx_input_spec) + times = [] + for _ in range(NUM_STEPS): + t0 = time.perf_counter() + y, state = mlx_model.step(x, state) + mx.eval(y.values) + times.append(time.perf_counter() - t0) + return times + + +def bench_mlx_exported(mlx_model): + """Benchmark MLX exported step.""" + mlx_input_spec = ShapeDType((), mx.int32) + path = os.path.join(tempfile.gettempdir(), 'benchmark_decoder.mlxfn') + export.export_step( + mlx_model, path, batch_size=BATCH_SIZE, input_spec=mlx_input_spec + ) + imported = mx.import_function(path) + + x_val = mx.zeros((BATCH_SIZE, 1), dtype=mx.int32) + x_msk = mx.ones((BATCH_SIZE, 1), dtype=mx.bool_) + + # Warmup. + flat_state, _ = export.get_initial_state_flat( + mlx_model, BATCH_SIZE, mlx_input_spec + ) + for _ in range(WARMUP): + y_vals, y_mask, flat_state = export.run_exported( + imported, x_val, x_msk, flat_state + ) + mx.eval(y_vals) + + # Timed. 
+ flat_state, _ = export.get_initial_state_flat( + mlx_model, BATCH_SIZE, mlx_input_spec + ) + times = [] + for _ in range(NUM_STEPS): + t0 = time.perf_counter() + y_vals, y_mask, flat_state = export.run_exported( + imported, x_val, x_msk, flat_state + ) + mx.eval(y_vals) + times.append(time.perf_counter() - t0) + + os.remove(path) + return times + + +# %% [markdown] +# ## 3. Run Benchmarks + +# %% +all_results = [] + +for cfg in CONFIGS: + label = cfg['label'] + dim, heads, layers = cfg['dim'], cfg['heads'], cfg['layers'] + print(f'\n{"=" * 60}') + print(f'{label}: dim={dim}, heads={heads}, layers={layers}') + print('=' * 60) + + config = make_config(dim, heads, layers) + + # JAX init. + jax_model = config.make() + x_init = sl.Sequence( + jnp.zeros((BATCH_SIZE, 1), dtype=jnp.int32), + jnp.ones((BATCH_SIZE, 1), dtype=jnp.bool_), + ) + model_vars = jax_model.init(jax.random.key(0), x_init, training=False) + jax_params = flax.core.meta.unbox(model_vars)['params'] + param_count = sum( + x.size for x in jax.tree_util.tree_leaves(jax_params) + ) + print(f'Parameters: {param_count:,}') + + # MLX init. + mlx_model = config.make(backend='mlx') + weight_converter.load_linen_params(mlx_model, jax_params, config) + + # Benchmark all three. 
+ print(' JAX jitted...', end='', flush=True) + jax_times = bench_jax(config, model_vars, jax_model) + jax_mean = np.mean(jax_times) * 1000 + jax_std = np.std(jax_times) * 1000 + print(f' {jax_mean:.3f} ms') + + print(' MLX native...', end='', flush=True) + mlx_times = bench_mlx_native(mlx_model) + mlx_mean = np.mean(mlx_times) * 1000 + mlx_std = np.std(mlx_times) * 1000 + print(f' {mlx_mean:.3f} ms') + + print(' MLX exported...', end='', flush=True) + exp_times = bench_mlx_exported(mlx_model) + exp_mean = np.mean(exp_times) * 1000 + exp_std = np.std(exp_times) * 1000 + print(f' {exp_mean:.3f} ms') + + all_results.append({ + 'label': label, + 'params': param_count, + 'jax': (jax_mean, jax_std), + 'mlx_native': (mlx_mean, mlx_std), + 'mlx_exported': (exp_mean, exp_std), + }) + +# %% [markdown] +# ## 4. Results Summary + +# %% +print('\n') +print(f'{"Model":<10} {"Params":>10} ' + f'{"JAX (ms)":>12} {"MLX nat (ms)":>14} {"MLX exp (ms)":>14}') +print('-' * 66) +for r in all_results: + jm, js = r['jax'] + nm, ns = r['mlx_native'] + em, es = r['mlx_exported'] + print( + f'{r["label"]:<10} {r["params"]:>10,} ' + f'{jm:>6.2f}+/-{js:<4.2f} ' + f'{nm:>7.2f}+/-{ns:<4.2f} ' + f'{em:>7.2f}+/-{es:<4.2f}' + ) + +print() +print('Tokens/sec:') +print(f'{"Model":<10} {"JAX":>10} {"MLX native":>12} {"MLX exported":>14}') +print('-' * 50) +for r in all_results: + jt = 1000.0 / r['jax'][0] + nt = 1000.0 / r['mlx_native'][0] + et = 1000.0 / r['mlx_exported'][0] + print(f'{r["label"]:<10} {jt:>10.0f} {nt:>12.0f} {et:>14.0f}') + +print('\nDone!') diff --git a/notebooks/mlx_streaming.py b/notebooks/mlx_streaming.py new file mode 100644 index 0000000..da61d27 --- /dev/null +++ b/notebooks/mlx_streaming.py @@ -0,0 +1,241 @@ +# %% [markdown] +# # MLX Streaming Inference Demo +# +# This notebook demonstrates the full MLX streaming inference pipeline: +# +# 1. Define a decoder transformer using SequenceLayers configs +# 2. Initialize weights in JAX (Linen) +# 3. Convert weights to MLX +# 4. 
Stream tokens natively in MLX +# 5. Export to `.mlxfn` for deployment +# 6. Stream tokens from the exported function +# +# No checkpoint is needed — we use random init weights throughout. +# +# **Requires:** `pip install sequence-layers[mlx]` + +# %% [markdown] +# ## 1. Setup + +# %% +import os +import tempfile + +import flax.core.meta +import jax +import jax.numpy as jnp +import mlx.core as mx +import numpy as np + +import sequence_layers.jax as sl +from sequence_layers.mlx import basic_types as bt +from sequence_layers.mlx import export +from sequence_layers.mlx import weight_converter + +Sequence = bt.Sequence +ShapeDType = bt.ShapeDType + +# Hyperparameters (small model, fast to run). +VOCAB_SIZE = 256 +DIM = 64 +NUM_HEADS = 4 +UNITS_PER_HEAD = DIM // NUM_HEADS # 16 +NUM_LAYERS = 2 +FFN_DIM = DIM * 4 # 256 +MAX_PAST = 128 +BATCH_SIZE = 1 +NUM_TOKENS = 16 + +print('Setup complete.') + +# %% [markdown] +# ## 2. Define Architecture +# +# A small decoder-only transformer: +# ``` +# Embedding → Repeat(N, [ +# Residual([RMSNorm, SelfAttention(RoPE), Flatten]), +# Residual([RMSNorm, Dense(FFN, gelu), Dense(dim)]), +# ]) → RMSNorm → Dense(vocab_size) +# ``` + +# %% +config = sl.Serial.Config([ + sl.Embedding.Config( + num_embeddings=VOCAB_SIZE, + dimension=DIM, + ), + sl.Repeat.Config( + num_repeats=NUM_LAYERS, + layer=sl.Serial.Config([ + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.DotProductSelfAttention.Config( + num_heads=NUM_HEADS, + units_per_head=UNITS_PER_HEAD, + max_past_horizon=MAX_PAST, + max_future_horizon=0, + query_network=sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ), + key_network=sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ), + ), + sl.Flatten.Config(), + ]), + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.Dense.Config(features=FFN_DIM, activation=jax.nn.gelu), + sl.Dense.Config(features=DIM), + ]), + ]), + ), + sl.RMSNormalization.Config(), + 
sl.Dense.Config(features=VOCAB_SIZE), +]) + +print(f'Architecture: vocab={VOCAB_SIZE}, dim={DIM}, heads={NUM_HEADS}, ' + f'uph={UNITS_PER_HEAD}, layers={NUM_LAYERS}, ffn={FFN_DIM}, ' + f'max_past={MAX_PAST}') + +# %% [markdown] +# ## 3. Initialize in JAX +# +# Build the Linen model, init with a dummy input, count parameters. + +# %% +jax_model = config.make() +key = jax.random.key(0) + +# Dummy input for init: a single token. +x_init = sl.Sequence( + jnp.zeros((BATCH_SIZE, 1), dtype=jnp.int32), + jnp.ones((BATCH_SIZE, 1), dtype=jnp.bool_), +) + +model_vars = jax_model.init(key, x_init, training=False) +jax_params = flax.core.meta.unbox(model_vars)['params'] + +param_count = sum( + x.size for x in jax.tree_util.tree_leaves(jax_params) +) +print(f'JAX model initialized: {param_count:,} parameters') + +# %% [markdown] +# ## 4. Convert Weights to MLX +# +# Build an MLX model from the same config, load Linen weights, verify +# that `layer()` outputs match between JAX and MLX. + +# %% +mlx_model = config.make(backend='mlx') +weight_converter.load_linen_params(mlx_model, jax_params, config) +print('Weights loaded into MLX model.') + +# Verify layer() outputs match on a short sequence. +tokens = np.array([[0, 42, 7, 13, 99, 200, 1, 128]], dtype=np.int32) +mask = np.ones(tokens.shape, dtype=bool) + +# JAX forward. +x_jax = sl.Sequence(jnp.array(tokens), jnp.array(mask)) +jax_bound = jax_model.bind(model_vars) +y_jax = jax_bound.layer(x_jax, training=False) + +# MLX forward. +x_mlx = Sequence(mx.array(tokens), mx.array(mask)) +y_mlx = mlx_model.layer(x_mlx) +mx.eval(y_mlx.values) + +np.testing.assert_allclose( + np.array(y_mlx.values), + np.array(y_jax.values), + atol=1e-3, + rtol=1e-3, +) +print(f'JAX and MLX layer() outputs match (atol=1e-3).') +print(f' Output shape: {y_mlx.values.shape}') + +# %% [markdown] +# ## 5. Native MLX Streaming +# +# Generate tokens one at a time using `model.step()` with greedy +# (argmax) decoding. This uses the KV cache internally. 
+ +# %% +input_spec = ShapeDType((), mx.int32) +state = mlx_model.get_initial_state(BATCH_SIZE, input_spec) + +token = 0 # Start-of-sequence token. +generated = [token] + +for _ in range(NUM_TOKENS - 1): + x = Sequence( + mx.array([[token]], dtype=mx.int32), + mx.ones((1, 1), dtype=mx.bool_), + ) + y, state = mlx_model.step(x, state) + mx.eval(y.values) + logits = y.values[0, 0] # [vocab_size] + token = int(mx.argmax(logits)) + generated.append(token) + +print(f'Generated {len(generated)} tokens (native step):') +print(generated) + +# %% [markdown] +# ## 6. Export to .mlxfn +# +# Export the step function to a `.mlxfn` file. Model weights are +# captured in the closure; state arrays (KV cache) are explicit I/O. + +# %% +export_path = os.path.join(tempfile.gettempdir(), 'decoder_demo.mlxfn') +export.export_step( + mlx_model, export_path, batch_size=BATCH_SIZE, input_spec=input_spec +) +size_kb = os.path.getsize(export_path) / 1024 +print(f'Exported to: {export_path}') +print(f'File size: {size_kb:.1f} KB') + +# %% [markdown] +# ## 7. Streaming from Exported Function +# +# Load the `.mlxfn` back and run the same generation loop. The greedy +# token sequence must match the one from the native step exactly.
+ +# %% +imported = mx.import_function(export_path) +flat_state, structure = export.get_initial_state_flat( + mlx_model, BATCH_SIZE, input_spec +) + +token = 0 +exported_generated = [token] + +for _ in range(NUM_TOKENS - 1): + x_values = mx.array([[token]], dtype=mx.int32) + x_mask = mx.ones((1, 1), dtype=mx.bool_) + y_vals, y_mask, flat_state = export.run_exported( + imported, x_values, x_mask, flat_state + ) + mx.eval(y_vals) + logits = y_vals[0, 0] + token = int(mx.argmax(logits)) + exported_generated.append(token) + +print(f'Generated {len(exported_generated)} tokens (exported step):') +print(exported_generated) + +assert generated == exported_generated, ( + f'Mismatch!\n native: {generated}\n exported: {exported_generated}' +) +print('Native and exported outputs match exactly.') + +# %% [markdown] +# ## 8. Cleanup + +# %% +os.remove(export_path) +print(f'Removed {export_path}') +print('Done!') diff --git a/sequence_layers/jax/types.py b/sequence_layers/jax/types.py index 99e6340..d040d09 100644 --- a/sequence_layers/jax/types.py +++ b/sequence_layers/jax/types.py @@ -1466,8 +1466,53 @@ class SequenceLayerConfig(spec.SequenceLayerConfig): of the specific SequenceLayer in use. This allows easy swapping of implementations based on behaviors (steppability, output ratio, latency, etc.). + + The make() method supports an optional ``backend`` keyword argument. When + ``backend`` is not None (e.g. ``'mlx'``), make() looks up a registered + factory for the given backend and config class (walking the MRO) and calls + it instead of the Linen implementation. 
+ + Backend factories are registered via:: + + SequenceLayerConfig.register_backend_factory( + 'mlx', SomeLinen.Config, some_mlx_factory_fn, + ) """ + # Backend factory registry: {(backend_name, config_class): factory_fn} + _backend_factories: dict[tuple[str, type], Callable] = {} + + @classmethod + def register_backend_factory( + cls, + backend: str, + config_cls: type['SequenceLayerConfig'], + factory: Callable[['SequenceLayerConfig'], Any], + ) -> None: + """Register a factory function for a (backend, config_class) pair.""" + cls._backend_factories[(backend, config_cls)] = factory + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if 'make' in cls.__dict__: + original_make = cls.__dict__['make'] + + @functools.wraps(original_make) + def _wrapped_make(self, *, backend=None, _orig=original_make): + if backend is not None: + for klass in type(self).__mro__: + key = (backend, klass) + factory = SequenceLayerConfig._backend_factories.get(key) + if factory is not None: + return factory(self) + raise ValueError( + f'No {backend!r} backend registered for' + f' {type(self).__qualname__}' + ) + return _orig(self) + + cls.make = _wrapped_make + @abc.abstractmethod def make(self) -> SequenceLayer: """Builds a SequenceLayer from this config.""" diff --git a/sequence_layers/mlx/__init__.py b/sequence_layers/mlx/__init__.py index 8f5b18f..81e47ba 100644 --- a/sequence_layers/mlx/__init__.py +++ b/sequence_layers/mlx/__init__.py @@ -13,4 +13,363 @@ # limitations under the License. """Sequence layers in MLX.""" -from sequence_layers.mlx.types import * +# Re-export basic types. 
+from sequence_layers.mlx.types import Constants +from sequence_layers.mlx.types import DType +from sequence_layers.mlx.types import Emits +from sequence_layers.mlx.types import mask_invalid +from sequence_layers.mlx.types import MASK_DTYPE +from sequence_layers.mlx.types import MaskedSequence +from sequence_layers.mlx.types import PaddingMode +from sequence_layers.mlx.types import ReceptiveField +from sequence_layers.mlx.types import Sequence +from sequence_layers.mlx.types import sequence_mask +from sequence_layers.mlx.types import Shape +from sequence_layers.mlx.types import ShapeDType +from sequence_layers.mlx.types import ShapeLike +from sequence_layers.mlx.types import State + +# Re-export basic_types TypeVars used in type annotations. +from sequence_layers.mlx.types import MaskT + +# Re-export attention projection configs (from JAX, used to configure attention). +from sequence_layers.jax.attention.common import CombinedQueryKeyValueProjection +from sequence_layers.jax.attention.common import QueryAndKeyValueProjection +from sequence_layers.jax.attention.common import SeparateQueryKeyValueProjection + +# Re-export MLX layer hierarchy. +from sequence_layers.mlx.types import ChannelSpec +from sequence_layers.mlx.types import check_layer +from sequence_layers.mlx.types import check_step +from sequence_layers.mlx.types import Emitting +from sequence_layers.mlx.types import PreservesShape +from sequence_layers.mlx.types import PreservesType +from sequence_layers.mlx.types import SequenceLayer +from sequence_layers.mlx.types import Stateless +from sequence_layers.mlx.types import StatelessEmitting +from sequence_layers.mlx.types import StatelessPointwise +from sequence_layers.mlx.types import StatelessPointwiseFunctor +from sequence_layers.mlx.types import Steppable + +# Re-export simple layers. 
+from sequence_layers.mlx.simple import Add +from sequence_layers.mlx.simple import Cast +from sequence_layers.mlx.simple import CheckpointName +from sequence_layers.mlx.simple import Downsample1D +from sequence_layers.mlx.simple import Dropout +from sequence_layers.mlx.simple import Elu +from sequence_layers.mlx.simple import Embedding +from sequence_layers.mlx.simple import ExpandDims +from sequence_layers.mlx.simple import Flatten +from sequence_layers.mlx.simple import GatedLinearUnit +from sequence_layers.mlx.simple import GatedTanhUnit +from sequence_layers.mlx.simple import GatedUnit +from sequence_layers.mlx.simple import Gelu +from sequence_layers.mlx.simple import Identity +from sequence_layers.mlx.simple import Lambda +from sequence_layers.mlx.simple import LeakyRelu +from sequence_layers.mlx.simple import Logging +from sequence_layers.mlx.simple import MaskInvalid +from sequence_layers.mlx.simple import OneHot +from sequence_layers.mlx.simple import Relu +from sequence_layers.mlx.simple import Reshape +from sequence_layers.mlx.simple import Scale +from sequence_layers.mlx.simple import Sigmoid +from sequence_layers.mlx.simple import Softmax +from sequence_layers.mlx.simple import Softplus +from sequence_layers.mlx.simple import Squeeze +from sequence_layers.mlx.simple import Swish +from sequence_layers.mlx.simple import Tanh +from sequence_layers.mlx.simple import Transpose +from sequence_layers.mlx.simple import Upsample1D + +# Re-export dense / normalization / position. 
+from sequence_layers.mlx.dense import Dense +from sequence_layers.mlx.dense import DenseDeferred +from sequence_layers.mlx.dense import EinsumDense +from sequence_layers.mlx.normalization import BatchNormalization +from sequence_layers.mlx.normalization import GroupNormalization +from sequence_layers.mlx.normalization import L2Normalize +from sequence_layers.mlx.normalization import LayerNormalization +from sequence_layers.mlx.normalization import RMSNormalization +from sequence_layers.mlx.position import ApplyRotaryPositionalEncoding + +# Re-export attention layers. +from sequence_layers.mlx.attention import DotProductAttention +from sequence_layers.mlx.attention import DotProductSelfAttention +from sequence_layers.mlx.attention import DeferredDotProductAttention +from sequence_layers.mlx.attention import DeferredDotProductSelfAttention +from sequence_layers.mlx.attention import DeferredLocalDotProductSelfAttention +from sequence_layers.mlx.attention import DeferredStreamingDotProductAttention +from sequence_layers.mlx.attention import LocalDotProductSelfAttention +from sequence_layers.mlx.attention import StreamingDotProductAttention + +# Re-export pooling layers. +from sequence_layers.mlx.pooling import AveragePooling1D +from sequence_layers.mlx.pooling import MaxPooling1D +from sequence_layers.mlx.pooling import MinPooling1D + +# Re-export convolution layers. +from sequence_layers.mlx.convolution import Conv1D +from sequence_layers.mlx.convolution import Conv1DTranspose +from sequence_layers.mlx.convolution import DeferredConv1D +from sequence_layers.mlx.convolution import DeferredConv1DTranspose +from sequence_layers.mlx.convolution import DeferredDepthwiseConv1D +from sequence_layers.mlx.convolution import DepthwiseConv1D + +# Re-export DSP layers. 
def _register_backends():
  """Register an MLX factory for each supported Linen Config class.

  Every registration has the form
  ``register_backend_factory('mlx', <Linen layer>.Config, <MLX layer>.from_config)``.
  Most MLX layers share their name with the Linen layer they mirror, so those
  are registered from name tables; the few renamed mappings are spelled out.
  Registration order matches the grouping of the tables below.
  """
  # Backend modules are imported inside the function, not at module scope.
  from sequence_layers.jax import conditioning as jax_cond
  from sequence_layers.jax import simple as jax_simple
  from sequence_layers.jax import dense as jax_dense
  from sequence_layers.jax import normalization as jax_norm
  from sequence_layers.jax import position as jax_pos
  from sequence_layers.jax.attention import dot_product_self_attention as jax_self_attn
  from sequence_layers.jax.attention import dot_product_attention as jax_cross_attn
  from sequence_layers.jax.attention import streaming_dot_product_attention as jax_streaming_attn
  from sequence_layers.jax.attention import streaming_local_dot_product_attention as jax_streaming_local_attn
  from sequence_layers.jax.attention import local_dot_product_self_attention as jax_local_attn
  from sequence_layers.jax import convolution as jax_conv
  from sequence_layers.jax import pooling as jax_pool
  from sequence_layers.jax import dsp as jax_dsp
  from sequence_layers.jax import combinators as jax_comb

  from sequence_layers.mlx import conditioning as mlx_cond
  from sequence_layers.mlx import simple as mlx_simple
  from sequence_layers.mlx import dense as mlx_dense
  from sequence_layers.mlx import normalization as mlx_norm
  from sequence_layers.mlx import position as mlx_pos
  from sequence_layers.mlx import attention as mlx_attn
  from sequence_layers.mlx import convolution as mlx_conv
  from sequence_layers.mlx import pooling as mlx_pool
  from sequence_layers.mlx import dsp as mlx_dsp
  from sequence_layers.mlx import combinators as mlx_comb

  register = _SLC.register_backend_factory

  def register_same_name(jax_module, mlx_module, layer_names):
    """Register layers whose class name is identical in both backends."""
    for layer_name in layer_names:
      register(
          'mlx',
          getattr(jax_module, layer_name).Config,
          getattr(mlx_module, layer_name).from_config,
      )

  register_same_name(
      jax_simple,
      mlx_simple,
      (
          # Activations.
          'Identity',
          'Relu',
          'Gelu',
          'Swish',
          'Tanh',
          'Sigmoid',
          'LeakyRelu',
          'Elu',
          'Softmax',
          'Softplus',
          # Value manipulation.
          'Cast',
          'Scale',
          'Add',
          # Masking.
          'MaskInvalid',
          # Gated units.
          'GatedUnit',
          'GatedLinearUnit',
          'GatedTanhUnit',
          # Shape manipulation.
          'Flatten',
          'Reshape',
          'ExpandDims',
          'Squeeze',
          'Transpose',
          # Encoding.
          'OneHot',
          'Embedding',
          # Regularization.
          'Dropout',
          # Sampling.
          'Downsample1D',
          'Upsample1D',
          # Misc.
          'CheckpointName',
          'Lambda',
          'Logging',
      ),
  )

  # Dense: the Linen Dense config is served by the deferred MLX variant.
  register('mlx', jax_dense.Dense.Config, mlx_dense.DenseDeferred.from_config)
  register_same_name(jax_dense, mlx_dense, ('EinsumDense',))

  # Conditioning.
  register_same_name(jax_cond, mlx_cond, ('Conditioning',))

  # Normalization.
  register_same_name(
      jax_norm,
      mlx_norm,
      (
          'L2Normalize',
          'RMSNormalization',
          'LayerNormalization',
          'GroupNormalization',
          'BatchNormalization',
      ),
  )

  # Position.
  register_same_name(jax_pos, mlx_pos, ('ApplyRotaryPositionalEncoding',))

  # Attention: each Linen layer lives in its own module; all MLX layers live
  # in mlx_attn. The streaming-local Linen config is mapped onto the MLX
  # StreamingDotProductAttention factory.
  attention_table = (
      (jax_self_attn, 'DotProductSelfAttention', 'DotProductSelfAttention'),
      (jax_cross_attn, 'DotProductAttention', 'DotProductAttention'),
      (
          jax_streaming_attn,
          'StreamingDotProductAttention',
          'StreamingDotProductAttention',
      ),
      (
          jax_streaming_local_attn,
          'StreamingLocalDotProductAttention',
          'StreamingDotProductAttention',
      ),
      (
          jax_local_attn,
          'LocalDotProductSelfAttention',
          'LocalDotProductSelfAttention',
      ),
  )
  for jax_module, jax_name, mlx_name in attention_table:
    register(
        'mlx',
        getattr(jax_module, jax_name).Config,
        getattr(mlx_attn, mlx_name).from_config,
    )

  # Convolution.
  register_same_name(
      jax_conv, mlx_conv, ('Conv1D', 'DepthwiseConv1D', 'Conv1DTranspose')
  )

  # Pooling.
  register_same_name(
      jax_pool, mlx_pool, ('MaxPooling1D', 'MinPooling1D', 'AveragePooling1D')
  )

  # DSP.
  register_same_name(
      jax_dsp,
      mlx_dsp,
      (
          'Delay',
          'Lookahead',
          'Window',
          'Frame',
          'OverlapAdd',
          'FFT',
          'IFFT',
          'RFFT',
          'IRFFT',
          'STFT',
          'InverseSTFT',
          'LinearToMelSpectrogram',
      ),
  )

  # Combinators.
  register_same_name(
      jax_comb, mlx_comb, ('Serial', 'Residual', 'Repeat', 'Parallel')
  )
def _causal_mask(q_len, kv_len):
  """Return a boolean causal mask of shape [1, 1, q_len, kv_len].

  The queries are treated as the last ``q_len`` positions of a key axis of
  length ``kv_len``: query ``i`` sits at key index ``kv_len - q_len + i`` and
  may attend to every key at or before that index (True = attend).
  """
  # Key index at which each query sits, as a column vector [q_len, 1].
  query_pos = mx.arange(q_len)[:, None] + (kv_len - q_len)
  # Key indices as a row vector [1, kv_len]; the comparison broadcasts.
  key_pos = mx.arange(kv_len)[None, :]
  visible = key_pos <= query_pos
  return visible.reshape(1, 1, q_len, kv_len)
class DotProductSelfAttention(types.Emitting):
  """Multi-headed dot-product self attention for MLX.

  Supports:
  - Grouped Query Attention (num_kv_heads < num_heads)
  - Banded (past / future horizon) masking
  - KV cache for step-by-step inference
  - Optional query/key/value processing networks (e.g. RoPE)

  Projection kernels are stored as [in, out] matrices:
    q_proj: [in_features, num_heads * units_per_head]
    k_proj: [in_features, num_kv_heads * units_per_head]
    v_proj: [in_features, num_kv_heads * units_per_head]

  The layer emits the per-head context with channel shape
  (num_heads, units_per_head); it has no output projection.
  """

  def __init__(
      self,
      *,
      in_features: int,
      num_heads: int,
      units_per_head: int,
      max_past_horizon: int,
      max_future_horizon: int = 0,
      num_kv_heads: int | None = None,
      use_bias: bool = False,
      query_scale: float | None = None,
      compute_dtype=None,
      param_dtype=mx.float32,
      kernel_init=None,
      bias_init=None,
      query_network: types.SequenceLayer | None = None,
      key_network: types.SequenceLayer | None = None,
      value_network: types.SequenceLayer | None = None,
      attention_logits_soft_cap: float | None = None,
  ):
    """Create the projections and store the configuration.

    Args:
      in_features: Size of the input channel dimension.
      num_heads: Number of query heads.
      units_per_head: Size of each head.
      max_past_horizon: Past frames visible to each query; -1 disables
        masking of the past.
      max_future_horizon: Future frames visible to each query; -1 disables
        masking of the future.
      num_kv_heads: Number of key/value heads (defaults to num_heads). Must
        divide num_heads.
      use_bias: Whether to add biases to the Q/K/V projections.
      query_scale: Multiplier applied to queries; None means
        1 / sqrt(units_per_head).
      compute_dtype: Dtype used for the computation; None means the input
        dtype.
      param_dtype: Dtype of the parameters.
      kernel_init: ``(key, shape, dtype) -> array`` initializer for kernels.
      bias_init: ``(key, shape, dtype) -> array`` initializer for biases.
      query_network: Optional layer applied to projected queries.
      key_network: Optional layer applied to projected keys.
      value_network: Optional layer applied to projected values.
      attention_logits_soft_cap: If set, logits are squashed with
        ``tanh(logits / cap) * cap``.

    Raises:
      ValueError: If num_heads is not divisible by num_kv_heads, or a horizon
        is smaller than -1.
    """
    super().__init__()
    if num_kv_heads is None:
      num_kv_heads = num_heads
    if num_heads % num_kv_heads != 0:
      raise ValueError(f'{num_heads=} must be divisible by {num_kv_heads=}.')
    if max_past_horizon < -1:
      raise ValueError(
          f'max_past_horizon must be >= -1, got {max_past_horizon}.'
      )
    if max_future_horizon < -1:
      raise ValueError(
          f'max_future_horizon must be >= -1, got {max_future_horizon}.'
      )

    self.in_features = in_features
    self.num_heads = num_heads
    self.units_per_head = units_per_head
    self.max_past_horizon = max_past_horizon
    self.max_future_horizon = max_future_horizon
    self.num_kv_heads = num_kv_heads
    self.use_bias = use_bias
    self._query_scale = query_scale
    self.compute_dtype = compute_dtype
    self._param_dtype = param_dtype
    self._attention_logits_soft_cap = attention_logits_soft_cap

    if kernel_init is None:
      kernel_init = init_mapping._make_variance_scaling_init(
          'fan_in', 'truncated_normal'
      )
    if bias_init is None:
      bias_init = init_mapping._zeros_init

    # Give every parameter its own PRNG key. Re-using a single key would make
    # equally-shaped kernels (e.g. Q, K and V when num_kv_heads == num_heads)
    # start out as bit-identical matrices.
    init_keys = mx.random.split(mx.random.key(0), 6)
    q_dim = num_heads * units_per_head
    kv_dim = num_kv_heads * units_per_head

    # Projections stored as [in, out] to match Linen convention.
    self.q_proj = kernel_init(init_keys[0], (in_features, q_dim), param_dtype)
    self.k_proj = kernel_init(init_keys[1], (in_features, kv_dim), param_dtype)
    self.v_proj = kernel_init(init_keys[2], (in_features, kv_dim), param_dtype)
    if use_bias:
      self.q_bias = bias_init(init_keys[3], (q_dim,), param_dtype)
      self.k_bias = bias_init(init_keys[4], (kv_dim,), param_dtype)
      self.v_bias = bias_init(init_keys[5], (kv_dim,), param_dtype)

    self.query_network = query_network
    self.key_network = key_network
    self.value_network = value_network

  @property
  def supports_step(self):
    """Stepping needs both horizons to be finite (non-negative)."""
    return self.max_past_horizon >= 0 and self.max_future_horizon >= 0

  @property
  def input_latency(self):
    """Number of future frames each query looks at (0 when unbounded)."""
    return max(0, self.max_future_horizon)

  def _scale(self):
    """Return the query multiplier, honouring an explicit 0.0."""
    if self._query_scale is not None:
      return self._query_scale
    return 1.0 / math.sqrt(self.units_per_head)

  def _project_qkv(self, x):
    """Project input to Q, K, V sequences shaped [b, t, heads, units]."""
    b, t = x.shape[0], x.shape[1]
    dtype = self.compute_dtype or x.dtype

    v = x.values.astype(dtype)
    q = mx.matmul(v, self.q_proj.astype(dtype))
    k = mx.matmul(v, self.k_proj.astype(dtype))
    val = mx.matmul(v, self.v_proj.astype(dtype))

    if self.use_bias:
      q = q + self.q_bias.astype(dtype)
      k = k + self.k_bias.astype(dtype)
      val = val + self.v_bias.astype(dtype)

    # Reshape to [b, t, heads, units_per_head].
    q = q.reshape(b, t, self.num_heads, self.units_per_head)
    k = k.reshape(b, t, self.num_kv_heads, self.units_per_head)
    val = val.reshape(b, t, self.num_kv_heads, self.units_per_head)

    return (
        Sequence(q, x.mask),
        Sequence(k, x.mask),
        Sequence(val, x.mask),
    )

  def _compute_attention(self, queries, keys, values, mask):
    """Compute scaled dot-product attention.

    Args:
      queries: [b, q_t, num_heads, units_per_head]
      keys: [b, kv_t, num_kv_heads, units_per_head]
      values: [b, kv_t, num_kv_heads, units_per_head]
      mask: boolean mask broadcastable to [b, num_heads, q_t, kv_t]
        (True = attend), or None for no masking.

    Returns:
      context: [b, q_t, num_heads, units_per_head]
    """
    scale = self._scale()

    # GQA: repeat K/V heads to match query heads.
    num_groups = self.num_heads // self.num_kv_heads
    if num_groups > 1:
      keys = mx.repeat(keys, num_groups, axis=2)
      values = mx.repeat(values, num_groups, axis=2)

    # Transpose to [b, heads, t, h] for batched matmul.
    q = mx.transpose(queries, (0, 2, 1, 3))  # [b, nh, qt, h]
    k = mx.transpose(keys, (0, 2, 1, 3))  # [b, nh, kvt, h]
    v = mx.transpose(values, (0, 2, 1, 3))  # [b, nh, kvt, h]

    # Scaled dot-product attention.
    q = q * scale
    logits = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2)))

    # Optional soft cap on logits.
    if self._attention_logits_soft_cap is not None:
      cap = self._attention_logits_soft_cap
      logits = mx.tanh(logits / cap) * cap

    # Apply mask: set masked positions to large negative.
    if mask is not None:
      large_neg = mx.array(-1e9, dtype=logits.dtype)
      logits = mx.where(mask, logits, large_neg)

    weights = mx.softmax(logits, axis=-1)
    context = mx.matmul(weights, v)  # [b, nh, qt, h]

    # Transpose back to [b, qt, nh, h].
    return mx.transpose(context, (0, 2, 1, 3))

  def get_output_shape(self, input_shape, *, constants=None):
    """Return (num_heads, units_per_head); input must have one channel dim."""
    if len(input_shape) != 1:
      raise ValueError(
          'DotProductSelfAttention requires rank 3 input,'
          f' got channel_shape={input_shape}.'
      )
    return (self.num_heads, self.units_per_head)

  def get_output_dtype(self, input_dtype, *, constants=None):
    """Return compute_dtype when set, otherwise the parameter dtype."""
    if self.compute_dtype is not None:
      return self.compute_dtype
    return self._param_dtype

  def get_initial_state(self, batch_size, input_spec, *, constants=None):
    """Build the empty KV cache and sub-network states.

    Returns:
      A 7-tuple ``(keys, values, key_mask, time_step, q_net_state,
      k_net_state, v_net_state)``. The cache holds
      ``max(0, max_past_horizon) + max(0, max_future_horizon)`` frames and
      starts fully masked.
    """
    compute_dtype = self.get_output_dtype(input_spec.dtype)
    max_past = max(0, self.max_past_horizon)
    max_future = max(0, self.max_future_horizon)
    kv_buffer_size = max_past + max_future

    kv_shape = (
        batch_size,
        kv_buffer_size,
        self.num_kv_heads,
        self.units_per_head,
    )
    kv_buffer_keys = mx.zeros(kv_shape, dtype=compute_dtype)
    kv_buffer_values = mx.zeros(kv_shape, dtype=compute_dtype)
    kv_buffer_mask = mx.zeros((batch_size, kv_buffer_size), dtype=mx.bool_)
    time_step = mx.zeros((batch_size,), dtype=mx.int32)

    def network_state(network, heads):
      # Sub-networks see per-head tensors with channel shape (heads, units).
      if network is None:
        return ()
      return network.get_initial_state(
          batch_size,
          bt.ShapeDType((heads, self.units_per_head), compute_dtype),
          constants=constants,
      )

    return (
        kv_buffer_keys,
        kv_buffer_values,
        kv_buffer_mask,
        time_step,
        network_state(self.query_network, self.num_heads),
        network_state(self.key_network, self.num_kv_heads),
        network_state(self.value_network, self.num_kv_heads),
    )

  def layer_with_emits(self, x, *, constants=None):
    """Attend over the whole sequence at once.

    Returns:
      ``(Sequence(context, x.mask), ())`` with context shaped
      [b, t, num_heads, units_per_head].
    """
    queries, keys, values = self._project_qkv(x)

    # Optional Q/K/V processing networks (e.g. RoPE).
    # Use `is not None` because parameterless nn.Modules are falsy.
    if self.query_network is not None:
      queries = Sequence(
          self.query_network.layer(queries, constants=constants).values,
          queries.mask,
      )
    if self.key_network is not None:
      keys = Sequence(
          self.key_network.layer(keys, constants=constants).values,
          keys.mask,
      )
    if self.value_network is not None:
      values = Sequence(
          self.value_network.layer(values, constants=constants).values,
          values.mask,
      )

    # Mask invalid values.
    values = values.mask_invalid()

    t = x.shape[1]

    # Build visibility mask.
    # Start with key validity: [b, 1, 1, t].
    valid_mask = x.mask[:, None, None, :]

    # Optionally add causal / banded mask.
    if self.max_past_horizon >= 0 or self.max_future_horizon >= 0:
      past = t - 1 if self.max_past_horizon == -1 else self.max_past_horizon
      future = (
          t - 1 if self.max_future_horizon == -1 else self.max_future_horizon
      )
      # Banded visibility matrix: query `row` sees keys in
      # [row - past, row + future].
      row = mx.arange(t)[:, None]
      col = mx.arange(t)[None, :]
      banded = (col >= row - past) & (col <= row + future)
      valid_mask = valid_mask & banded.reshape(1, 1, t, t)

    context = self._compute_attention(
        queries.values, keys.values, values.values, valid_mask
    )
    return Sequence(context, x.mask), ()

  def step_with_emits(self, x, state, *, constants=None):
    """Attend the new frames against the KV cache plus themselves.

    The new keys/values are appended to the cache, attention runs over the
    full (cache + new) window with a banded mask, and only afterwards is the
    cache trimmed back to its fixed size. Attending before trimming lets each
    query see ``cache_size`` older frames plus itself, which for
    ``max_future_horizon == 0`` is the same window `layer_with_emits` uses.

    Returns:
      ``(Sequence(context, x.mask), new_state, ())``.
    """
    queries, keys, values = self._project_qkv(x)

    (
        kv_buf_k,
        kv_buf_v,
        kv_buf_mask,
        time_step,
        q_net_state,
        k_net_state,
        v_net_state,
    ) = state

    # Optional Q/K/V processing networks.
    # Use `is not None` because parameterless nn.Modules are falsy.
    if self.query_network is not None:
      queries, q_net_state = self.query_network.step(
          queries, q_net_state, constants=constants
      )
    if self.key_network is not None:
      keys, k_net_state = self.key_network.step(
          keys, k_net_state, constants=constants
      )
    if self.value_network is not None:
      values, v_net_state = self.value_network.step(
          values, v_net_state, constants=constants
      )

    # Mask invalid values.
    values = values.mask_invalid()

    x_time = x.shape[1]
    kv_buffer_size = kv_buf_k.shape[1]

    # Cache followed by the new frames: [b, kv_buffer_size + x_time, ...].
    all_k = mx.concatenate([kv_buf_k, keys.values], axis=1)
    all_v = mx.concatenate([kv_buf_v, values.values], axis=1)
    all_mask = mx.concatenate([kv_buf_mask, x.mask], axis=1)
    kv_time = all_k.shape[1]

    # Query i sits at key index kv_buffer_size + i. It may see itself and the
    # kv_buffer_size entries before it, i.e. key indices [i, i + size].
    # NOTE(review): queries are not delayed here, so for
    # max_future_horizon > 0 the window covers only past frames; confirm how
    # the declared input_latency is meant to be honoured in step mode.
    row = mx.arange(x_time)[:, None]
    col = mx.arange(kv_time)[None, :]
    banded = (col >= row) & (col <= row + kv_buffer_size)
    visible = all_mask[:, None, None, :] & banded.reshape(
        1, 1, x_time, kv_time
    )

    context = self._compute_attention(queries.values, all_k, all_v, visible)

    # Keep only the newest kv_buffer_size entries. Slicing from an explicit
    # start index (rather than `[-size:]`) stays correct when size == 0,
    # where `[-0:]` would keep everything and let the cache grow.
    new_k = all_k[:, x_time:]
    new_v = all_v[:, x_time:]
    new_mask = all_mask[:, x_time:]

    new_state = (
        new_k,
        new_v,
        new_mask,
        time_step + x_time,
        q_net_state,
        k_net_state,
        v_net_state,
    )
    return Sequence(context, x.mask), new_state, ()

  @classmethod
  def from_config(cls, config):
    """Create from a Linen DotProductSelfAttention.Config.

    Since in_features is not in the config (it's inferred), this returns a
    DeferredDotProductSelfAttention that creates the projections on first
    use.
    """
    return DeferredDotProductSelfAttention(config)
class DeferredDotProductSelfAttention(types.Emitting):
  """Wrapper that defers projection creation until first input.

  Linen DotProductSelfAttention.Config doesn't specify in_features;
  it is inferred from the first input (or input spec) this wrapper sees.
  All computation is delegated to an inner DotProductSelfAttention.
  """

  def __init__(self, config):
    """Store the config; the inner layer is built lazily."""
    super().__init__()
    self._config = config
    self._inner = None

  def _ensure_initialized(self, in_features, backend='mlx'):
    """Build the inner layer once, using `in_features` from the first call."""
    if self._inner is not None:
      return

    cfg = self._config

    def make_network(network_config):
      # Compare against None explicitly, as the inner layer does for its
      # sub-networks, instead of relying on the config's truthiness.
      if network_config is None:
        return None
      return network_config.make(backend=backend)

    compute_dtype = getattr(cfg, 'compute_dtype', None)
    if compute_dtype is not None:
      compute_dtype = init_mapping._to_mx_dtype(compute_dtype)
    param_dtype = init_mapping._to_mx_dtype(cfg.param_dtype)

    # The kernel initializer is looked up on the optional input_projection
    # sub-config; a missing/empty value is passed to map_initializer as-is.
    input_projection = getattr(cfg, 'input_projection', None)
    qkv_kernel_init = input_projection and getattr(
        input_projection, 'qkv_kernel_init', None
    )

    self._inner = DotProductSelfAttention(
        in_features=in_features,
        num_heads=cfg.num_heads,
        units_per_head=cfg.units_per_head,
        max_past_horizon=cfg.max_past_horizon,
        max_future_horizon=cfg.max_future_horizon,
        num_kv_heads=cfg.num_kv_heads,
        use_bias=cfg.use_bias,
        query_scale=getattr(cfg, 'query_scale', None),
        compute_dtype=compute_dtype,
        param_dtype=param_dtype,
        kernel_init=init_mapping.map_initializer(qkv_kernel_init),
        query_network=make_network(cfg.query_network),
        key_network=make_network(cfg.key_network),
        value_network=make_network(cfg.value_network),
        # Forward the logit soft cap when the config carries one; the inner
        # layer implements it, so dropping it here would silently change the
        # attention weights.
        attention_logits_soft_cap=getattr(
            cfg, 'attention_logits_soft_cap', None
        ),
    )

  @property
  def supports_step(self):
    """Stepping needs both configured horizons to be non-negative."""
    mph = self._config.max_past_horizon
    mfh = self._config.max_future_horizon
    return mph >= 0 and mfh >= 0

  @property
  def input_latency(self):
    """The configured future horizon, clamped at 0."""
    return max(0, self._config.max_future_horizon)

  def get_output_shape(self, input_shape, *, constants=None):
    """Return (num_heads, units_per_head) from the config."""
    return (
        self._config.num_heads,
        self._config.units_per_head,
    )

  def get_output_dtype(self, input_dtype, *, constants=None):
    """Return the configured compute dtype, else the parameter dtype."""
    if getattr(self._config, 'compute_dtype', None):
      return init_mapping._to_mx_dtype(self._config.compute_dtype)
    return init_mapping._to_mx_dtype(self._config.param_dtype)

  def get_initial_state(self, batch_size, input_spec, *, constants=None):
    """Build the inner layer if needed and return its initial state."""
    self._ensure_initialized(input_spec.shape[-1])
    return self._inner.get_initial_state(
        batch_size, input_spec, constants=constants
    )

  def layer_with_emits(self, x, *, constants=None):
    """Build the inner layer if needed and run it over the full sequence."""
    self._ensure_initialized(x.shape[-1])
    return self._inner.layer_with_emits(x, constants=constants)

  def step_with_emits(self, x, state, *, constants=None):
    """Build the inner layer if needed and run one step."""
    self._ensure_initialized(x.shape[-1])
    return self._inner.step_with_emits(x, state, constants=constants)
class DotProductAttention(types.Emitting):
  """Multi-headed dot-product cross attention for MLX.

  Queries come from the input sequence; keys and values come from a
  source sequence looked up in the ``constants`` dictionary.

  In ``layer()`` mode the K/V projections and optional networks are
  applied to the source on-the-fly. In ``step()`` mode they are
  pre-computed during ``get_initial_state()`` so that each step only
  needs to project and attend queries.

  Projection kernels are stored as [in, out] matrices:
    q_proj: [in_features, num_heads * units_per_head]
    k_proj: [source_features, num_heads * units_per_head]
    v_proj: [source_features, num_heads * units_per_head]

  The layer emits the per-head context with channel shape
  (num_heads, units_per_head); it has no output projection.
  """

  def __init__(
      self,
      *,
      in_features: int,
      source_features: int,
      source_name: str,
      num_heads: int,
      units_per_head: int,
      use_bias: bool = False,
      query_scale: float | None = None,
      compute_dtype=None,
      param_dtype=mx.float32,
      kernel_init=None,
      bias_init=None,
      query_network: types.SequenceLayer | None = None,
      key_network: types.SequenceLayer | None = None,
      value_network: types.SequenceLayer | None = None,
  ):
    """Create the projections and store the configuration.

    Args:
      in_features: Size of the input (query side) channel dimension.
      source_features: Size of the source (key/value side) channel dimension.
      source_name: Key under which the source sequence is found in
        ``constants``.
      num_heads: Number of attention heads.
      units_per_head: Size of each head.
      use_bias: Whether to add biases to the Q/K/V projections.
      query_scale: Multiplier applied to queries; None means
        1 / sqrt(units_per_head).
      compute_dtype: Dtype used for the computation; None means the dtype of
        the tensor being projected.
      param_dtype: Dtype of the parameters.
      kernel_init: ``(key, shape, dtype) -> array`` initializer for kernels.
      bias_init: ``(key, shape, dtype) -> array`` initializer for biases.
      query_network: Optional layer applied to projected queries.
      key_network: Optional layer applied to projected keys.
      value_network: Optional layer applied to projected values.
    """
    super().__init__()
    self.in_features = in_features
    self.source_features = source_features
    self.source_name = source_name
    self.num_heads = num_heads
    self.units_per_head = units_per_head
    self.use_bias = use_bias
    self._query_scale = query_scale
    self.compute_dtype = compute_dtype
    self._param_dtype = param_dtype

    if kernel_init is None:
      kernel_init = init_mapping._make_variance_scaling_init(
          'fan_in', 'truncated_normal'
      )
    if bias_init is None:
      bias_init = init_mapping._zeros_init

    # Give every parameter its own PRNG key. Re-using a single key would make
    # equally-shaped kernels (K and V always; Q too when
    # in_features == source_features) start out as bit-identical matrices.
    init_keys = mx.random.split(mx.random.key(0), 6)
    qkv_dim = num_heads * units_per_head

    self.q_proj = kernel_init(
        init_keys[0], (in_features, qkv_dim), param_dtype
    )
    self.k_proj = kernel_init(
        init_keys[1], (source_features, qkv_dim), param_dtype
    )
    self.v_proj = kernel_init(
        init_keys[2], (source_features, qkv_dim), param_dtype
    )
    if use_bias:
      self.q_bias = bias_init(init_keys[3], (qkv_dim,), param_dtype)
      self.k_bias = bias_init(init_keys[4], (qkv_dim,), param_dtype)
      self.v_bias = bias_init(init_keys[5], (qkv_dim,), param_dtype)

    self.query_network = query_network
    self.key_network = key_network
    self.value_network = value_network

  @property
  def supports_step(self):
    """Steppable unless a query network is present that is not."""
    if self.query_network is not None:
      return self.query_network.supports_step
    return True

  @property
  def input_latency(self):
    """Cross attention adds no latency."""
    return 0

  def _scale(self):
    """Return the query multiplier, honouring an explicit 0.0."""
    if self._query_scale is not None:
      return self._query_scale
    return 1.0 / math.sqrt(self.units_per_head)

  def _project_q(self, x):
    """Project the input to queries shaped [b, t, num_heads, units]."""
    b, t = x.shape[0], x.shape[1]
    dtype = self.compute_dtype or x.dtype
    v = x.values.astype(dtype)
    q = mx.matmul(v, self.q_proj.astype(dtype))
    if self.use_bias:
      q = q + self.q_bias.astype(dtype)
    q = q.reshape(b, t, self.num_heads, self.units_per_head)
    return Sequence(q, x.mask)

  def _project_kv(self, source):
    """Project the source to keys and values, each [b, t, num_heads, units]."""
    b, t = source.shape[0], source.shape[1]
    dtype = self.compute_dtype or source.dtype
    v = source.values.astype(dtype)
    k = mx.matmul(v, self.k_proj.astype(dtype))
    val = mx.matmul(v, self.v_proj.astype(dtype))
    if self.use_bias:
      k = k + self.k_bias.astype(dtype)
      val = val + self.v_bias.astype(dtype)
    k = k.reshape(b, t, self.num_heads, self.units_per_head)
    val = val.reshape(b, t, self.num_heads, self.units_per_head)
    return Sequence(k, source.mask), Sequence(val, source.mask)

  def _get_source(self, constants):
    """Fetch the source sequence from `constants` or raise ValueError."""
    if constants is None or self.source_name not in constants:
      raise ValueError(f'Source "{self.source_name}" not found in constants.')
    return constants[self.source_name]

  def _compute_attention(self, queries, keys, values, mask):
    """Compute scaled dot-product attention (no causal mask).

    Args:
      queries: [b, q_t, num_heads, units_per_head]
      keys: [b, kv_t, num_heads, units_per_head]
      values: [b, kv_t, num_heads, units_per_head]
      mask: boolean mask broadcastable to [b, num_heads, q_t, kv_t]
        (True = attend), or None for no masking.

    Returns:
      context: [b, q_t, num_heads, units_per_head]
    """
    scale = self._scale()

    # [b, t, heads, h] -> [b, heads, t, h] for batched matmul.
    q = mx.transpose(queries, (0, 2, 1, 3))
    k = mx.transpose(keys, (0, 2, 1, 3))
    v = mx.transpose(values, (0, 2, 1, 3))

    q = q * scale
    logits = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2)))

    # Masked positions get a large negative logit so softmax ignores them.
    if mask is not None:
      large_neg = mx.array(-1e9, dtype=logits.dtype)
      logits = mx.where(mask, logits, large_neg)

    weights = mx.softmax(logits, axis=-1)
    context = mx.matmul(weights, v)
    return mx.transpose(context, (0, 2, 1, 3))

  def get_output_shape(self, input_shape, *, constants=None):
    """Return (num_heads, units_per_head); input must have one channel dim."""
    if len(input_shape) != 1:
      raise ValueError(
          'DotProductAttention requires rank 3 input,'
          f' got channel_shape={input_shape}.'
      )
    return (self.num_heads, self.units_per_head)

  def get_output_dtype(self, input_dtype, *, constants=None):
    """Return compute_dtype when set, otherwise the parameter dtype."""
    if self.compute_dtype is not None:
      return self.compute_dtype
    return self._param_dtype

  def get_initial_state(self, batch_size, input_spec, *, constants=None):
    """Pre-compute source keys/values and build the step state.

    Returns:
      A 5-tuple ``(keys, values, key_mask, q_net_state, time_step)``.

    Raises:
      ValueError: If the source is missing from `constants`.
    """
    # Pre-project source keys and values.
    source = self._get_source(constants)
    keys, values = self._project_kv(source)

    if self.key_network is not None:
      keys = self.key_network.layer(keys, constants=constants)
    if self.value_network is not None:
      values = self.value_network.layer(values, constants=constants)

    keys = keys.mask_invalid()
    values = values.mask_invalid()

    if self.query_network is not None:
      q_net_state = self.query_network.get_initial_state(
          batch_size,
          bt.ShapeDType(
              (self.num_heads, self.units_per_head),
              self.get_output_dtype(input_spec.dtype),
          ),
          constants=constants,
      )
    else:
      q_net_state = ()

    time_step = mx.zeros((batch_size,), dtype=mx.int32)
    return (
        keys.values,
        values.values,
        keys.mask,
        q_net_state,
        time_step,
    )

  def layer_with_emits(self, x, *, constants=None):
    """Attend all queries in `x` against the source sequence.

    Returns:
      ``(Sequence(context, x.mask), ())``.

    Raises:
      ValueError: If the source is missing from `constants`.
    """
    source = self._get_source(constants)
    keys, values = self._project_kv(source)

    if self.key_network is not None:
      keys = self.key_network.layer(keys, constants=constants)
    if self.value_network is not None:
      values = self.value_network.layer(values, constants=constants)

    queries = self._project_q(x)
    if self.query_network is not None:
      queries = Sequence(
          self.query_network.layer(queries, constants=constants).values,
          queries.mask,
      )

    values = values.mask_invalid()
    # Only valid source frames may be attended: [b, 1, 1, source_t].
    valid_mask = source.mask[:, None, None, :]
    context = self._compute_attention(
        queries.values, keys.values, values.values, valid_mask
    )
    return Sequence(context, x.mask), ()

  def step_with_emits(self, x, state, *, constants=None):
    """Attend the new queries against the keys/values cached in `state`.

    Returns:
      ``(Sequence(context, x.mask), new_state, ())``; only the query-network
      state and the time step change between steps.
    """
    keys_v, values_v, kv_mask, q_net_state, time_step = state

    queries = self._project_q(x)
    if self.query_network is not None:
      queries, q_net_state = self.query_network.step(
          queries, q_net_state, constants=constants
      )

    valid_mask = kv_mask[:, None, None, :]
    context = self._compute_attention(
        queries.values, keys_v, values_v, valid_mask
    )

    new_state = (
        keys_v,
        values_v,
        kv_mask,
        q_net_state,
        time_step + x.shape[1],
    )
    return Sequence(context, x.mask), new_state, ()

  @classmethod
  def from_config(cls, config):
    """Return a DeferredDotProductAttention built from a Linen config."""
    return DeferredDotProductAttention(config)
Sequence(context, x.mask), new_state, () + + @classmethod + def from_config(cls, config): + return DeferredDotProductAttention(config) + + +class DeferredDotProductAttention(types.Emitting): + """Deferred DotProductAttention that creates projections on first use.""" + + def __init__(self, config): + super().__init__() + self._config = config + self._inner = None + + def _ensure_initialized(self, in_features, source_features, backend='mlx'): + if self._inner is not None: + return + + query_network = None + key_network = None + value_network = None + if self._config.query_network: + query_network = self._config.query_network.make(backend=backend) + if self._config.key_network: + key_network = self._config.key_network.make(backend=backend) + if self._config.value_network: + value_network = self._config.value_network.make(backend=backend) + + compute_dtype = getattr(self._config, 'compute_dtype', None) + if compute_dtype is not None: + compute_dtype = init_mapping._to_mx_dtype(compute_dtype) + param_dtype = init_mapping._to_mx_dtype(self._config.param_dtype) + + self._inner = DotProductAttention( + in_features=in_features, + source_features=source_features, + source_name=self._config.source_name, + num_heads=self._config.num_heads, + units_per_head=self._config.units_per_head, + use_bias=self._config.use_bias, + query_scale=getattr(self._config, 'query_scale', None), + compute_dtype=compute_dtype, + param_dtype=param_dtype, + kernel_init=init_mapping.map_initializer( + getattr(self._config, 'input_projection', None) + and getattr( + self._config.input_projection, + 'qkv_kernel_init', + None, + ) + ), + query_network=query_network, + key_network=key_network, + value_network=value_network, + ) + + def _get_source(self, constants): + if constants is None: + raise ValueError('Constants required for cross-attention.') + if self._config.source_name not in constants: + raise ValueError(f'Source "{self._config.source_name}" not found.') + return 
constants[self._config.source_name] + + @property + def supports_step(self): + if self._config.query_network is not None: + # Can't easily check without building; assume True. + return True + return True + + @property + def input_latency(self): + return 0 + + def get_output_shape(self, input_shape, *, constants=None): + return ( + self._config.num_heads, + self._config.units_per_head, + ) + + def get_output_dtype(self, input_dtype, *, constants=None): + if getattr(self._config, 'compute_dtype', None): + return init_mapping._to_mx_dtype(self._config.compute_dtype) + return init_mapping._to_mx_dtype(self._config.param_dtype) + + def get_initial_state(self, batch_size, input_spec, *, constants=None): + source = self._get_source(constants) + self._ensure_initialized(input_spec.shape[-1], source.shape[-1]) + return self._inner.get_initial_state( + batch_size, input_spec, constants=constants + ) + + def layer_with_emits(self, x, *, constants=None): + source = self._get_source(constants) + self._ensure_initialized(x.shape[-1], source.shape[-1]) + return self._inner.layer_with_emits(x, constants=constants) + + def step_with_emits(self, x, state, *, constants=None): + source = self._get_source(constants) + self._ensure_initialized(x.shape[-1], source.shape[-1]) + return self._inner.step_with_emits(x, state, constants=constants) + + +def _banded_mask(q_len, kv_len, num_lower, num_upper): + """Build a [1, 1, q_len, kv_len] banded visibility mask. + + Position (i, j) is True iff j >= i - num_lower and j <= i + num_upper. + """ + row = mx.arange(q_len)[:, None] + col = mx.arange(kv_len)[None, :] + mask = (col >= row - num_lower) & (col <= row + num_upper) + return mask.reshape(1, 1, q_len, kv_len) + + +def _step_visibility_mask( + max_past_horizon, max_future_horizon, query_time, key_time +): + """Compute step-wise banded visibility mask. + + For a single query (query_time=1), returns None since no causal mask + is needed — the KV buffer already contains only visible positions. 
+ + For multi-step queries, returns a banded matrix with num_lower=0 and + num_upper=max_past_horizon + max_future_horizon. + """ + if query_time == 1: + return None + return _banded_mask( + query_time, + key_time, + num_lower=0, + num_upper=max_past_horizon + max_future_horizon, + ) + + +class StreamingDotProductAttention(types.Emitting): + """Multi-headed streaming cross-attention for MLX. + + Queries come from the input; keys and values come from a source + sequence provided in constants at the same streaming rate as input. + + Unlike DotProductAttention (which pre-projects the full source in + get_initial_state), this class projects source chunks per-step and + maintains a rolling KV buffer, enabling streaming cross-attention. + + Covers both StreamingDotProductAttention and + StreamingLocalDotProductAttention from the JAX backend (which differ + only in layer-mode efficiency, not in step-mode behavior or output). + + Kernels stored in Linen-compatible shapes: + q_proj: [in_features, num_heads * units_per_head] + k_proj: [source_features, num_heads * units_per_head] + v_proj: [source_features, num_heads * units_per_head] + """ + + def __init__( + self, + *, + in_features: int, + source_features: int, + source_name: str, + num_heads: int, + units_per_head: int, + max_past_horizon: int, + max_future_horizon: int = 0, + use_bias: bool = False, + use_query_delay_buffer: bool = True, + query_scale: float | None = None, + compute_dtype=None, + param_dtype=mx.float32, + kernel_init=None, + bias_init=None, + query_network: types.SequenceLayer | None = None, + key_network: types.SequenceLayer | None = None, + value_network: types.SequenceLayer | None = None, + ): + super().__init__() + if max_past_horizon < 1: + raise ValueError( + f'max_past_horizon must be >= 1, got {max_past_horizon}.' + ) + if max_future_horizon < 0: + raise ValueError( + f'max_future_horizon must be >= 0, got {max_future_horizon}.' 
+ ) + + self.in_features = in_features + self.source_features = source_features + self.source_name = source_name + self.num_heads = num_heads + self.units_per_head = units_per_head + self.max_past_horizon = max_past_horizon + self.max_future_horizon = max_future_horizon + self.use_bias = use_bias + self.use_query_delay_buffer = use_query_delay_buffer + self._query_scale = query_scale + self.compute_dtype = compute_dtype + self._param_dtype = param_dtype + + if kernel_init is None: + kernel_init = init_mapping._make_variance_scaling_init( + 'fan_in', 'truncated_normal' + ) + if bias_init is None: + bias_init = init_mapping._zeros_init + + key = mx.random.key(0) + qkv_dim = num_heads * units_per_head + + # Q projection from input. + self.q_proj = kernel_init(key, (in_features, qkv_dim), param_dtype) + # K/V projections from source. + self.k_proj = kernel_init(key, (source_features, qkv_dim), param_dtype) + self.v_proj = kernel_init(key, (source_features, qkv_dim), param_dtype) + if use_bias: + self.q_bias = bias_init(key, (qkv_dim,), param_dtype) + self.k_bias = bias_init(key, (qkv_dim,), param_dtype) + self.v_bias = bias_init(key, (qkv_dim,), param_dtype) + + self.query_network = query_network + self.key_network = key_network + self.value_network = value_network + + @property + def supports_step(self): + return True + + @property + def input_latency(self): + if self.max_future_horizon > 0 and self.use_query_delay_buffer: + return self.max_future_horizon + return 0 + + def _project_q(self, x): + """Project input to query sequence.""" + b, t = x.shape[0], x.shape[1] + dtype = self.compute_dtype or x.dtype + v = x.values.astype(dtype) + q = mx.matmul(v, self.q_proj.astype(dtype)) + if self.use_bias: + q = q + self.q_bias.astype(dtype) + q = q.reshape(b, t, self.num_heads, self.units_per_head) + return Sequence(q, x.mask) + + def _project_kv(self, source): + """Project source to key/value sequences.""" + b, t = source.shape[0], source.shape[1] + dtype = 
self.compute_dtype or source.dtype + v = source.values.astype(dtype) + k = mx.matmul(v, self.k_proj.astype(dtype)) + val = mx.matmul(v, self.v_proj.astype(dtype)) + if self.use_bias: + k = k + self.k_bias.astype(dtype) + val = val + self.v_bias.astype(dtype) + k = k.reshape(b, t, self.num_heads, self.units_per_head) + val = val.reshape(b, t, self.num_heads, self.units_per_head) + return Sequence(k, source.mask), Sequence(val, source.mask) + + def _get_source(self, constants): + if constants is None or self.source_name not in constants: + raise ValueError(f'Source "{self.source_name}" not found in constants.') + return constants[self.source_name] + + def _compute_attention(self, queries, keys, values, mask): + """Compute scaled dot-product attention.""" + scale = self._query_scale or (1.0 / math.sqrt(self.units_per_head)) + q = mx.transpose(queries, (0, 2, 1, 3)) + k = mx.transpose(keys, (0, 2, 1, 3)) + v = mx.transpose(values, (0, 2, 1, 3)) + q = q * scale + logits = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) + if mask is not None: + large_neg = mx.array(-1e9, dtype=logits.dtype) + logits = mx.where(mask, logits, large_neg) + weights = mx.softmax(logits, axis=-1) + context = mx.matmul(weights, v) + context = mx.transpose(context, (0, 2, 1, 3)) + return context + + def get_output_shape(self, input_shape, *, constants=None): + if len(input_shape) != 1: + raise ValueError( + 'StreamingDotProductAttention requires rank 3 input,' + f' got channel_shape={input_shape}.' 
+ ) + return (self.num_heads, self.units_per_head) + + def get_output_dtype(self, input_dtype, *, constants=None): + if self.compute_dtype is not None: + return self.compute_dtype + return self._param_dtype + + def get_initial_state(self, batch_size, input_spec, *, constants=None): + compute_dtype = self.get_output_dtype(input_spec.dtype) + max_past = max(0, self.max_past_horizon) + max_future = max(0, self.max_future_horizon) + kv_buffer_size = max_past + max_future + + kv_shape = ( + batch_size, + kv_buffer_size, + self.num_heads, + self.units_per_head, + ) + kv_buffer_keys = mx.zeros(kv_shape, dtype=compute_dtype) + kv_buffer_values = mx.zeros(kv_shape, dtype=compute_dtype) + kv_buffer_mask = mx.zeros((batch_size, kv_buffer_size), dtype=mx.bool_) + time_step = mx.zeros((batch_size,), dtype=mx.int32) + + # Q/K/V network states. + q_net_state = ( + self.query_network.get_initial_state( + batch_size, + bt.ShapeDType( + (self.num_heads, self.units_per_head), + compute_dtype, + ), + constants=constants, + ) + if self.query_network is not None + else () + ) + k_net_state = ( + self.key_network.get_initial_state( + batch_size, + bt.ShapeDType( + (self.num_heads, self.units_per_head), + compute_dtype, + ), + constants=constants, + ) + if self.key_network is not None + else () + ) + v_net_state = ( + self.value_network.get_initial_state( + batch_size, + bt.ShapeDType( + (self.num_heads, self.units_per_head), + compute_dtype, + ), + constants=constants, + ) + if self.value_network is not None + else () + ) + + # Query delay buffer for future horizon. 
+ if max_future and self.use_query_delay_buffer: + q_delay_values = mx.zeros( + ( + batch_size, + max_future, + self.num_heads, + self.units_per_head, + ), + dtype=compute_dtype, + ) + q_delay_mask = mx.zeros((batch_size, max_future), dtype=mx.bool_) + else: + q_delay_values = () + q_delay_mask = () + + return ( + kv_buffer_keys, + kv_buffer_values, + kv_buffer_mask, + time_step, + q_net_state, + k_net_state, + v_net_state, + q_delay_values, + q_delay_mask, + ) + + def layer_with_emits(self, x, *, constants=None): + source = self._get_source(constants) + + queries = self._project_q(x) + keys, values = self._project_kv(source) + queries_time = queries.shape[1] + keys_time = keys.shape[1] + + # Optional Q/K/V processing networks. + if self.query_network is not None: + queries = Sequence( + self.query_network.layer(queries, constants=constants).values, + queries.mask, + ) + if self.key_network is not None: + keys = Sequence( + self.key_network.layer(keys, constants=constants).values, + keys.mask, + ) + if self.value_network is not None: + values = Sequence( + self.value_network.layer(values, constants=constants).values, + values.mask, + ) + + # Mask invalid values. + values = values.mask_invalid() + + # Build visibility mask: banded + source validity. 
+ valid_mask = source.mask[:, None, None, :] + banded = _banded_mask( + queries_time, + keys_time, + num_lower=self.max_past_horizon, + num_upper=self.max_future_horizon, + ) + valid_mask = valid_mask & banded + + context = self._compute_attention( + queries.values, keys.values, values.values, valid_mask + ) + return Sequence(context, x.mask), () + + def step_with_emits(self, x, state, *, constants=None): + source = self._get_source(constants) + + if x.shape[1] != source.shape[1]: + raise ValueError( + f'Expected x.shape[1]={x.shape[1]} to match' + f' source.shape[1]={source.shape[1]}' + ) + + ( + kv_buf_k, + kv_buf_v, + kv_buf_mask, + time_step, + q_net_state, + k_net_state, + v_net_state, + q_delay_values, + q_delay_mask, + ) = state + + kv_buffer_size = kv_buf_k.shape[1] + x_time = x.shape[1] + + queries = self._project_q(x) + keys, values = self._project_kv(source) + + # Optional Q/K/V processing networks. + if self.query_network is not None: + queries, q_net_state = self.query_network.step( + queries, q_net_state, constants=constants + ) + if self.key_network is not None: + keys, k_net_state = self.key_network.step( + keys, k_net_state, constants=constants + ) + if self.value_network is not None: + values, v_net_state = self.value_network.step( + values, v_net_state, constants=constants + ) + + # Mask invalid values. + values = values.mask_invalid() + + # Concatenate new K/V to buffer. + new_k = mx.concatenate([kv_buf_k, keys.values], axis=1) + new_v = mx.concatenate([kv_buf_v, values.values], axis=1) + new_mask = mx.concatenate([kv_buf_mask, source.mask], axis=1) + + # Handle query delay buffer. + has_delay_buffer = not isinstance(q_delay_values, tuple) + if has_delay_buffer: + # Insert new queries into delay buffer. + all_q_values = mx.concatenate([q_delay_values, queries.values], axis=1) + all_q_mask = mx.concatenate([q_delay_mask, queries.mask], axis=1) + # Pop oldest x_time queries as current. 
+ queries = Sequence(all_q_values[:, :x_time], all_q_mask[:, :x_time]) + # Preserve remaining for next step. + q_delay_values = all_q_values[:, -self.max_future_horizon :] + q_delay_mask = all_q_mask[:, -self.max_future_horizon :] + + # Build visibility mask. + kv_time = new_k.shape[1] + valid_mask = new_mask[:, None, None, :] + + vis_mask = _step_visibility_mask( + self.max_past_horizon, + self.max_future_horizon, + x_time, + kv_time, + ) + if vis_mask is not None: + valid_mask = valid_mask & vis_mask + + context = self._compute_attention(queries.values, new_k, new_v, valid_mask) + + # Trim KV buffer to keep only last kv_buffer_size entries. + new_k = new_k[:, -kv_buffer_size:] + new_v = new_v[:, -kv_buffer_size:] + new_mask = new_mask[:, -kv_buffer_size:] + + new_state = ( + new_k, + new_v, + new_mask, + time_step + x_time, + q_net_state, + k_net_state, + v_net_state, + q_delay_values, + q_delay_mask, + ) + return Sequence(context, queries.mask), new_state, () + + @classmethod + def from_config(cls, config): + return DeferredStreamingDotProductAttention(config) + + +class DeferredStreamingDotProductAttention(types.Emitting): + """Deferred StreamingDotProductAttention. + + Creates the inner attention on first use when in_features and + source_features are known. 
+ """ + + def __init__(self, config): + super().__init__() + self._config = config + self._inner = None + + def _ensure_initialized(self, in_features, source_features, backend='mlx'): + if self._inner is not None: + return + + query_network = None + key_network = None + value_network = None + if self._config.query_network: + query_network = self._config.query_network.make(backend=backend) + if self._config.key_network: + key_network = self._config.key_network.make(backend=backend) + if self._config.value_network: + value_network = self._config.value_network.make(backend=backend) + + compute_dtype = getattr(self._config, 'compute_dtype', None) + if compute_dtype is not None: + compute_dtype = init_mapping._to_mx_dtype(compute_dtype) + param_dtype = init_mapping._to_mx_dtype(self._config.param_dtype) + + self._inner = StreamingDotProductAttention( + in_features=in_features, + source_features=source_features, + source_name=self._config.source_name, + num_heads=self._config.num_heads, + units_per_head=self._config.units_per_head, + max_past_horizon=self._config.max_past_horizon, + max_future_horizon=self._config.max_future_horizon, + use_bias=self._config.use_bias, + use_query_delay_buffer=getattr( + self._config, 'use_query_delay_buffer', True + ), + query_scale=getattr(self._config, 'query_scale', None), + compute_dtype=compute_dtype, + param_dtype=param_dtype, + kernel_init=init_mapping.map_initializer( + getattr(self._config, 'input_projection', None) + and getattr( + self._config.input_projection, + 'q_kernel_init', + None, + ) + ), + query_network=query_network, + key_network=key_network, + value_network=value_network, + ) + + def _get_source(self, constants): + if constants is None: + raise ValueError('Constants required for streaming attention.') + if self._config.source_name not in constants: + raise ValueError(f'Source "{self._config.source_name}" not found.') + return constants[self._config.source_name] + + @property + def supports_step(self): + return True 
+ + @property + def input_latency(self): + mfh = self._config.max_future_horizon + uqdb = getattr(self._config, 'use_query_delay_buffer', True) + if mfh > 0 and uqdb: + return mfh + return 0 + + def get_output_shape(self, input_shape, *, constants=None): + return ( + self._config.num_heads, + self._config.units_per_head, + ) + + def get_output_dtype(self, input_dtype, *, constants=None): + if getattr(self._config, 'compute_dtype', None): + return init_mapping._to_mx_dtype(self._config.compute_dtype) + return init_mapping._to_mx_dtype(self._config.param_dtype) + + def get_initial_state(self, batch_size, input_spec, *, constants=None): + source = self._get_source(constants) + self._ensure_initialized(input_spec.shape[-1], source.shape[-1]) + return self._inner.get_initial_state( + batch_size, input_spec, constants=constants + ) + + def layer_with_emits(self, x, *, constants=None): + source = self._get_source(constants) + self._ensure_initialized(x.shape[-1], source.shape[-1]) + return self._inner.layer_with_emits(x, constants=constants) + + def step_with_emits(self, x, state, *, constants=None): + source = self._get_source(constants) + self._ensure_initialized(x.shape[-1], source.shape[-1]) + return self._inner.step_with_emits(x, state, constants=constants) + + +class LocalDotProductSelfAttention(DotProductSelfAttention): + """Local dot-product self attention with configurable block_size. + + Extends DotProductSelfAttention with a configurable block_size for + step-mode processing. The sliding window behavior is already handled + by the base class's banded visibility mask via max_past_horizon and + max_future_horizon. 
+ """ + + def __init__(self, *, block_size_config: int = 1, **kwargs): + super().__init__(**kwargs) + self._block_size_config = block_size_config + + @property + def block_size(self): + return self._block_size_config + + @classmethod + def from_config(cls, config): + return DeferredLocalDotProductSelfAttention(config) + + +class DeferredLocalDotProductSelfAttention(types.Emitting): + """Deferred LocalDotProductSelfAttention. + + Creates the inner attention on first use when in_features is known. + """ + + def __init__(self, config): + super().__init__() + self._config = config + self._inner = None + + def _ensure_initialized(self, in_features, backend='mlx'): + if self._inner is not None: + return + + query_network = None + key_network = None + value_network = None + if self._config.query_network: + query_network = self._config.query_network.make(backend=backend) + if self._config.key_network: + key_network = self._config.key_network.make(backend=backend) + if self._config.value_network: + value_network = self._config.value_network.make(backend=backend) + + compute_dtype = getattr(self._config, 'compute_dtype', None) + if compute_dtype is not None: + compute_dtype = init_mapping._to_mx_dtype(compute_dtype) + param_dtype = init_mapping._to_mx_dtype(self._config.param_dtype) + + self._inner = LocalDotProductSelfAttention( + in_features=in_features, + num_heads=self._config.num_heads, + units_per_head=self._config.units_per_head, + max_past_horizon=self._config.max_past_horizon, + max_future_horizon=self._config.max_future_horizon, + use_bias=self._config.use_bias, + block_size_config=self._config.block_size, + query_scale=getattr(self._config, 'query_scale', None), + compute_dtype=compute_dtype, + param_dtype=param_dtype, + attention_logits_soft_cap=getattr( + self._config, 'attention_logits_soft_cap', None + ), + kernel_init=init_mapping.map_initializer( + getattr(self._config, 'input_projection', None) + and getattr( + self._config.input_projection, + 
'qkv_kernel_init', + None, + ) + ), + query_network=query_network, + key_network=key_network, + value_network=value_network, + ) + + @property + def supports_step(self): + mph = self._config.max_past_horizon + mfh = self._config.max_future_horizon + return mph >= 0 and mfh >= 0 + + @property + def block_size(self): + return self._config.block_size + + @property + def input_latency(self): + return max(0, self._config.max_future_horizon) + + def get_output_shape(self, input_shape, *, constants=None): + return ( + self._config.num_heads, + self._config.units_per_head, + ) + + def get_output_dtype(self, input_dtype, *, constants=None): + if getattr(self._config, 'compute_dtype', None): + return init_mapping._to_mx_dtype(self._config.compute_dtype) + return init_mapping._to_mx_dtype(self._config.param_dtype) + + def get_initial_state(self, batch_size, input_spec, *, constants=None): + self._ensure_initialized(input_spec.shape[-1]) + return self._inner.get_initial_state( + batch_size, input_spec, constants=constants + ) + + def layer_with_emits(self, x, *, constants=None): + self._ensure_initialized(x.shape[-1]) + return self._inner.layer_with_emits(x, constants=constants) + + def step_with_emits(self, x, state, *, constants=None): + self._ensure_initialized(x.shape[-1]) + return self._inner.step_with_emits(x, state, constants=constants) diff --git a/sequence_layers/mlx/attention_test.py b/sequence_layers/mlx/attention_test.py new file mode 100644 index 0000000..7c9dea7 --- /dev/null +++ b/sequence_layers/mlx/attention_test.py @@ -0,0 +1,528 @@ +"""Tests for attention MLX sequence layers.""" + +import mlx.core as mx +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized +from sequence_layers.mlx import attention +from sequence_layers.mlx import basic_types as bt +from sequence_layers.mlx import position +from sequence_layers.mlx import test_utils + + +class DotProductSelfAttentionTest(parameterized.TestCase): + + def 
test_layer(self): + layer = attention.DotProductSelfAttention( + in_features=16, + num_heads=4, + units_per_head=8, + max_past_horizon=32, + ) + test_utils.verify_contract(self, layer, (16,), atol=1e-4, rtol=1e-4) + + def test_causal(self): + layer = attention.DotProductSelfAttention( + in_features=8, + num_heads=2, + units_per_head=4, + max_past_horizon=64, + max_future_horizon=0, + ) + test_utils.verify_contract(self, layer, (8,), atol=1e-4, rtol=1e-4) + + def test_gqa(self): + """Test Grouped Query Attention (fewer KV heads).""" + layer = attention.DotProductSelfAttention( + in_features=16, + num_heads=8, + units_per_head=4, + max_past_horizon=32, + num_kv_heads=2, + ) + test_utils.verify_contract(self, layer, (16,), atol=1e-4, rtol=1e-4) + + def test_output_shape(self): + layer = attention.DotProductSelfAttention( + in_features=16, + num_heads=4, + units_per_head=8, + max_past_horizon=32, + ) + self.assertEqual(layer.get_output_shape((16,)), (4, 8)) + + def test_step_builds_kv_cache(self): + layer = attention.DotProductSelfAttention( + in_features=8, + num_heads=2, + units_per_head=4, + max_past_horizon=10, + ) + spec = bt.ShapeDType((8,), mx.float32) + state = layer.get_initial_state(1, spec) + + for i in range(5): + x = bt.MaskedSequence( + mx.random.normal(shape=(1, 1, 8)), + mx.ones((1, 1), dtype=mx.bool_), + ) + _, state = layer.step(x, state) + + # Check KV cache has been populated. 
+ kv_keys = state[0] + self.assertEqual(kv_keys.shape[1], 10) # buffer size + + def test_with_query_key_networks(self): + """Test with RoPE on Q/K.""" + rope = position.ApplyRotaryPositionalEncoding( + max_wavelength=10000.0, axis=-1 + ) + layer = attention.DotProductSelfAttention( + in_features=8, + num_heads=2, + units_per_head=4, + max_past_horizon=32, + query_network=rope, + key_network=position.ApplyRotaryPositionalEncoding( + max_wavelength=10000.0, axis=-1 + ), + ) + x = test_utils.random_sequence(1, 5, 8) + y = layer.layer(x) + self.assertEqual(y.shape, (1, 5, 2, 4)) + + +class DeferredDotProductSelfAttentionTest(parameterized.TestCase): + + def test_from_config(self): + import sequence_layers.mlx + from sequence_layers.jax.attention import ( + dot_product_self_attention as jax_attn, + ) + + config = jax_attn.DotProductSelfAttention.Config( + num_heads=4, + units_per_head=8, + max_past_horizon=32, + ) + mlx_layer = config.make(backend='mlx') + self.assertIsInstance( + mlx_layer, + attention.DeferredDotProductSelfAttention, + ) + + x = test_utils.random_sequence(1, 5, 16) + y = mlx_layer.layer(x) + self.assertEqual(y.channel_shape, (4, 8)) + + +class DotProductAttentionTest(parameterized.TestCase): + """Tests for cross-attention.""" + + def _make_constants(self, batch, time, features, name='source'): + source = test_utils.random_sequence(batch, time, features) + return {name: source} + + def test_layer(self): + layer = attention.DotProductAttention( + in_features=8, + source_features=12, + source_name='source', + num_heads=2, + units_per_head=4, + ) + constants = self._make_constants(2, 6, 12) + test_utils.verify_contract( + self, + layer, + (8,), + constants=constants, + atol=1e-4, + rtol=1e-4, + ) + + def test_output_shape(self): + layer = attention.DotProductAttention( + in_features=16, + source_features=16, + source_name='enc', + num_heads=4, + units_per_head=8, + ) + self.assertEqual(layer.get_output_shape((16,)), (4, 8)) + + def 
test_step_reuses_precomputed_kv(self): + layer = attention.DotProductAttention( + in_features=8, + source_features=12, + source_name='source', + num_heads=2, + units_per_head=4, + ) + constants = self._make_constants(1, 6, 12) + spec = bt.ShapeDType((8,), mx.float32) + state = layer.get_initial_state(1, spec, constants=constants) + # KV should be pre-computed. + keys_v = state[0] + self.assertEqual(keys_v.shape, (1, 6, 2, 4)) + + for _ in range(3): + x = bt.MaskedSequence( + mx.random.normal(shape=(1, 1, 8)), + mx.ones((1, 1), dtype=mx.bool_), + ) + y, state = layer.step(x, state, constants=constants) + self.assertEqual(y.channel_shape, (2, 4)) + + def test_missing_source_raises(self): + layer = attention.DotProductAttention( + in_features=8, + source_features=8, + source_name='missing', + num_heads=2, + units_per_head=4, + ) + x = test_utils.random_sequence(1, 3, 8) + with self.assertRaises(ValueError): + layer.layer(x, constants={}) + + def test_from_config(self): + import sequence_layers.mlx + from sequence_layers.jax.attention import ( + dot_product_attention as jax_cross_attn, + ) + + config = jax_cross_attn.DotProductAttention.Config( + source_name='enc', + num_heads=4, + units_per_head=8, + ) + mlx_layer = config.make(backend='mlx') + self.assertIsInstance( + mlx_layer, + attention.DeferredDotProductAttention, + ) + source = test_utils.random_sequence(1, 6, 16) + constants = {'enc': source} + x = test_utils.random_sequence(1, 4, 16) + y = mlx_layer.layer(x, constants=constants) + self.assertEqual(y.channel_shape, (4, 8)) + + +class StreamingDotProductAttentionTest(parameterized.TestCase): + """Tests for streaming cross-attention.""" + + def _make_source(self, batch, time, features, name='source'): + return test_utils.random_sequence(batch, time, features) + + def test_layer_basic(self): + """Basic layer mode with banded visibility mask.""" + layer = attention.StreamingDotProductAttention( + in_features=8, + source_features=12, + source_name='source', + 
num_heads=2, + units_per_head=4, + max_past_horizon=4, + ) + source = self._make_source(2, 8, 12) + x = test_utils.random_sequence(2, 8, 8) + y = layer.layer(x, constants={'source': source}) + self.assertEqual(y.channel_shape, (2, 4)) + self.assertEqual(y.shape, (2, 8, 2, 4)) + + def test_step_builds_kv_cache(self): + """KV buffer grows correctly during step mode.""" + layer = attention.StreamingDotProductAttention( + in_features=8, + source_features=12, + source_name='source', + num_heads=2, + units_per_head=4, + max_past_horizon=10, + ) + source = self._make_source(1, 1, 12) + spec = bt.ShapeDType((8,), mx.float32) + state = layer.get_initial_state(1, spec, constants={'source': source}) + + for _ in range(5): + x = bt.MaskedSequence( + mx.random.normal(shape=(1, 1, 8)), + mx.ones((1, 1), dtype=mx.bool_), + ) + src = bt.MaskedSequence( + mx.random.normal(shape=(1, 1, 12)), + mx.ones((1, 1), dtype=mx.bool_), + ) + _, state, _ = layer.step_with_emits(x, state, constants={'source': src}) + + kv_keys = state[0] + self.assertEqual(kv_keys.shape[1], 10) # buffer size + + def test_step_matches_layer(self): + """Step-by-step with streaming constants matches layer().""" + layer = attention.StreamingDotProductAttention( + in_features=8, + source_features=12, + source_name='source', + num_heads=2, + units_per_head=4, + max_past_horizon=16, + ) + batch, time = 1, 8 + x = test_utils.random_sequence(batch, time, 8) + source = self._make_source(batch, time, 12) + constants = {'source': source} + + # Layer mode. + y_layer = layer.layer(x, constants=constants) + + # Step-by-step mode. 
+ y_step, _ = test_utils.step_by_step( + layer, + x, + block_size=1, + stream_constants={'source': source}, + ) + + np.testing.assert_allclose( + np.array(y_step.values), + np.array(y_layer.values), + atol=1e-4, + rtol=1e-4, + err_msg='step vs layer mismatch', + ) + + def test_with_future_horizon(self): + """Query delay buffer with max_future_horizon > 0.""" + layer = attention.StreamingDotProductAttention( + in_features=8, + source_features=8, + source_name='source', + num_heads=2, + units_per_head=4, + max_past_horizon=4, + max_future_horizon=2, + use_query_delay_buffer=True, + ) + self.assertEqual(layer.input_latency, 2) + source = self._make_source(1, 8, 8) + spec = bt.ShapeDType((8,), mx.float32) + state = layer.get_initial_state(1, spec, constants={'source': source}) + + # Verify delay buffer is in state. + q_delay_values = state[7] + self.assertFalse(isinstance(q_delay_values, tuple)) + self.assertEqual(q_delay_values.shape[1], 2) + + # Run a few steps to make sure it doesn't crash. + for _ in range(5): + x = bt.MaskedSequence( + mx.random.normal(shape=(1, 1, 8)), + mx.ones((1, 1), dtype=mx.bool_), + ) + src = bt.MaskedSequence( + mx.random.normal(shape=(1, 1, 8)), + mx.ones((1, 1), dtype=mx.bool_), + ) + y, state, _ = layer.step_with_emits(x, state, constants={'source': src}) + self.assertEqual(y.channel_shape, (2, 4)) + + def test_no_query_delay_buffer(self): + """use_query_delay_buffer=False has no delay.""" + layer = attention.StreamingDotProductAttention( + in_features=8, + source_features=8, + source_name='source', + num_heads=2, + units_per_head=4, + max_past_horizon=4, + max_future_horizon=2, + use_query_delay_buffer=False, + ) + self.assertEqual(layer.input_latency, 0) + source = self._make_source(1, 8, 8) + spec = bt.ShapeDType((8,), mx.float32) + state = layer.get_initial_state(1, spec, constants={'source': source}) + # Delay buffer should be empty tuples. 
+ self.assertIsInstance(state[7], tuple) + self.assertEqual(state[7], ()) + + def test_with_rope(self): + """Q/K processing networks (RoPE).""" + rope_q = position.ApplyRotaryPositionalEncoding( + max_wavelength=10000.0, axis=-1 + ) + rope_k = position.ApplyRotaryPositionalEncoding( + max_wavelength=10000.0, axis=-1 + ) + layer = attention.StreamingDotProductAttention( + in_features=8, + source_features=12, + source_name='source', + num_heads=2, + units_per_head=4, + max_past_horizon=16, + query_network=rope_q, + key_network=rope_k, + ) + source = self._make_source(1, 5, 12) + x = test_utils.random_sequence(1, 5, 8) + y = layer.layer(x, constants={'source': source}) + self.assertEqual(y.shape, (1, 5, 2, 4)) + + def test_output_shape(self): + layer = attention.StreamingDotProductAttention( + in_features=16, + source_features=16, + source_name='source', + num_heads=4, + units_per_head=8, + max_past_horizon=8, + ) + self.assertEqual(layer.get_output_shape((16,)), (4, 8)) + + def test_from_config(self): + """Both Streaming and StreamingLocal configs produce correct layer.""" + import sequence_layers.mlx + from sequence_layers.jax.attention import ( + streaming_dot_product_attention as jax_streaming_attn, + ) + from sequence_layers.jax.attention import ( + streaming_local_dot_product_attention as jax_streaming_local_attn, + ) + + config = jax_streaming_attn.StreamingDotProductAttention.Config( + source_name='source', + num_heads=2, + units_per_head=4, + max_past_horizon=8, + ) + mlx_layer = config.make(backend='mlx') + self.assertIsInstance( + mlx_layer, + attention.DeferredStreamingDotProductAttention, + ) + + source = test_utils.random_sequence(1, 6, 8) + x = test_utils.random_sequence(1, 6, 8) + y = mlx_layer.layer(x, constants={'source': source}) + self.assertEqual(y.channel_shape, (2, 4)) + + # StreamingLocal config should also work. 
class LocalDotProductSelfAttentionTest(parameterized.TestCase):
  """Contract, property and config checks for LocalDotProductSelfAttention."""

  # Tolerances used by every verify_contract call in this class.
  _TOLERANCES = {'atol': 1e-4, 'rtol': 1e-4}

  @staticmethod
  def _rope():
    """Returns a fresh rotary positional encoding over the last axis."""
    return position.ApplyRotaryPositionalEncoding(
        max_wavelength=10000.0, axis=-1
    )

  def _verify(self, layer, in_features, **extra):
    """Runs verify_contract on `layer` for a rank-1 channel shape."""
    test_utils.verify_contract(
        self, layer, (in_features,), **self._TOLERANCES, **extra
    )

  def test_layer(self):
    layer = attention.LocalDotProductSelfAttention(
        in_features=16,
        num_heads=4,
        units_per_head=4,
        max_past_horizon=8,
        block_size_config=2,
    )
    self._verify(layer, 16)

  def test_block_size(self):
    configured_block_size = 4
    layer = attention.LocalDotProductSelfAttention(
        in_features=8,
        num_heads=2,
        units_per_head=4,
        max_past_horizon=4,
        block_size_config=configured_block_size,
    )
    self.assertEqual(layer.block_size, configured_block_size)

  def test_with_future_horizon(self):
    layer = attention.LocalDotProductSelfAttention(
        in_features=8,
        num_heads=2,
        units_per_head=4,
        max_past_horizon=4,
        max_future_horizon=2,
        block_size_config=1,
    )
    self.assertEqual(layer.input_latency, 2)
    # Step-mode verification is switched off for this configuration.
    self._verify(layer, 8, test_step=False)

  def test_with_soft_cap(self):
    layer = attention.LocalDotProductSelfAttention(
        in_features=8,
        num_heads=2,
        units_per_head=4,
        max_past_horizon=8,
        block_size_config=1,
        attention_logits_soft_cap=50.0,
    )
    self._verify(layer, 8)

  def test_with_rope(self):
    query_rope = self._rope()
    key_rope = self._rope()
    layer = attention.LocalDotProductSelfAttention(
        in_features=8,
        num_heads=2,
        units_per_head=4,
        max_past_horizon=8,
        block_size_config=1,
        query_network=query_rope,
        key_network=key_rope,
    )
    self._verify(layer, 8)

  def test_output_shape(self):
    layer = attention.LocalDotProductSelfAttention(
        in_features=16,
        num_heads=4,
        units_per_head=8,
        max_past_horizon=8,
        block_size_config=2,
    )
    # Channel shape is (num_heads, units_per_head).
    self.assertEqual(layer.get_output_shape((16,)), (4, 8))

  def test_from_config(self):
    import sequence_layers.mlx
    from sequence_layers.jax.attention import (
        local_dot_product_self_attention as jax_local_attn,
    )

    config_kwargs = dict(
        num_heads=2,
        units_per_head=4,
        block_size=2,
        max_past_horizon=8,
    )
    config = jax_local_attn.LocalDotProductSelfAttention.Config(
        **config_kwargs
    )
    mlx_layer = config.make(backend='mlx')
    self.assertIsInstance(
        mlx_layer,
        attention.DeferredLocalDotProductSelfAttention,
    )
    self.assertEqual(mlx_layer.block_size, 2)

    x = test_utils.random_sequence(1, 8, 8)
    y = mlx_layer.layer(x)
    self.assertEqual(y.channel_shape, (2, 4))
def test_mro_lookup(self):
  """Config subclasses inherit backend factories via MRO."""
  import dataclasses
  import sequence_layers.mlx
  from sequence_layers.jax import simple as jax_simple
  from sequence_layers.mlx import simple as mlx_simple

  # A user-defined subclass of Identity.Config; the dispatch call below is
  # made on an instance of this subclass rather than on Identity.Config.
  @dataclasses.dataclass(frozen=True)
  class MyIdentityConfig(jax_simple.Identity.Config):
    pass

  subclass_config = MyIdentityConfig()
  mlx_model = subclass_config.make(backend='mlx')

  # The subclass must still dispatch to the MLX Identity layer.
  self.assertIsInstance(mlx_model, mlx_simple.Identity)
def _broadcast_shapes(*shapes):
  """Numpy-style shape broadcasting.

  Shapes are right-aligned (shorter shapes are padded with leading 1s) and
  merged dimension by dimension. Within one dimension every size that is not
  1 must be identical; size-1 entries stretch to that size. This also covers
  zero-sized dimensions: (0,) broadcast with (1,) is (0,), as in NumPy.

  Args:
    *shapes: Any number of shapes, each a sequence of ints.

  Returns:
    The broadcast shape as a tuple. Returns () when no shapes are given or
    when every shape is empty.

  Raises:
    ValueError: If some dimension has two different sizes that are both
      different from 1.
  """
  if not shapes:
    return ()
  rank = max(len(shape) for shape in shapes)
  # Right-align all shapes by padding with leading 1s.
  aligned = [(1,) * (rank - len(shape)) + tuple(shape) for shape in shapes]

  result = []
  for dims in zip(*aligned):
    # Only sizes other than 1 constrain the result. Comparing against
    # max(dims) instead would wrongly reject 0 paired with 1.
    sizes = {d for d in dims if d != 1}
    if len(sizes) > 1:
      raise ValueError(f'Shapes not broadcastable: {shapes}')
    result.append(sizes.pop() if sizes else 1)
  return tuple(result)
def _combine_sequences(mode, sequences):
  """Merges the outputs of parallel branches into a single Sequence.

  Args:
    mode: The CombinationMode selecting how values are merged.
    sequences: Branch outputs, each exposing `.values` and `.mask`.

  Returns:
    A Sequence whose values are the merged branch values and whose mask is
    the AND (`&`) of all branch masks.

  Raises:
    ValueError: If `mode` is not a known combination mode.
  """
  branch_values = [seq.values for seq in sequences]
  branch_masks = [seq.mask for seq in sequences]

  def fold(op, items):
    """Left-folds `op` over `items`, starting from the first item."""
    acc = items[0]
    for item in items[1:]:
      acc = op(acc, item)
    return acc

  # A timestep is valid only where every branch marks it valid.
  mask = fold(lambda a, b: a & b, branch_masks)

  if mode == CombinationMode.STACK:
    # New axis right after the first two axes.
    values = mx.stack(branch_values, axis=2)
  elif mode == CombinationMode.CONCAT:
    values = mx.concatenate(branch_values, axis=-1)
  elif mode in (CombinationMode.ADD, CombinationMode.MEAN):
    values = fold(lambda a, b: a + b, branch_values)
    if mode == CombinationMode.MEAN:
      values = values / len(branch_values)
  elif mode == CombinationMode.PRODUCT:
    values = fold(lambda a, b: a * b, branch_values)
  else:
    raise ValueError(f'Unknown combination mode: {mode}')

  return Sequence(values, mask)
class Residual(types.Emitting):
  """Sums a serial body with a shortcut branch: y = body(x) + shortcut(x).

  Both branches receive the same input. Output values are the sum of the two
  branch outputs and the output mask is the AND (`&`) of the two branch
  masks. When no shortcut is supplied, an Identity layer is used.
  Shape, dtype, output ratio and input latency are all reported from the
  body branch.
  """

  def __init__(
      self,
      layers: list[types.SequenceLayer],
      *,
      shortcut: types.SequenceLayer | None = None,
  ):
    super().__init__()
    # The body layers are chained with Serial.
    self._body = Serial(layers)
    self._shortcut = simple_lib.Identity() if shortcut is None else shortcut

  @property
  def supports_step(self):
    return self._body.supports_step and self._shortcut.supports_step

  @property
  def block_size(self):
    # Least common multiple, so a block is a whole number of blocks for
    # each branch.
    body_block = self._body.block_size
    shortcut_block = self._shortcut.block_size
    return lcm(body_block, shortcut_block)

  @property
  def output_ratio(self):
    return self._body.output_ratio

  @property
  def input_latency(self):
    return self._body.input_latency

  def get_output_shape(self, input_shape, *, constants=None):
    return self._body.get_output_shape(input_shape, constants=constants)

  def get_output_dtype(self, input_dtype, *, constants=None):
    return self._body.get_output_dtype(input_dtype, constants=constants)

  def get_initial_state(self, batch_size, input_spec, *, constants=None):
    """Returns `(body_state, shortcut_state)` for step-wise processing."""
    return tuple(
        branch.get_initial_state(batch_size, input_spec, constants=constants)
        for branch in (self._body, self._shortcut)
    )

  def _residual_fn(self, y_body, y_shortcut):
    """Adds the branch values and ANDs the branch masks."""
    return Sequence(
        y_body.values + y_shortcut.values,
        y_body.mask & y_shortcut.mask,
    )

  def layer_with_emits(self, x, *, constants=None):
    """Runs both branches on `x`; emits are `(body_emits, shortcut_emits)`."""
    (y_body, body_emits), (y_shortcut, shortcut_emits) = [
        branch.layer_with_emits(x, constants=constants)
        for branch in (self._body, self._shortcut)
    ]
    return self._residual_fn(y_body, y_shortcut), (body_emits, shortcut_emits)

  def step_with_emits(self, x, state, *, constants=None):
    """Steps both branches; state and emits are (body, shortcut) pairs."""
    body_state, shortcut_state = state
    y_body, next_body_state, body_emits = self._body.step_with_emits(
        x, body_state, constants=constants
    )
    y_shortcut, next_shortcut_state, shortcut_emits = (
        self._shortcut.step_with_emits(x, shortcut_state, constants=constants)
    )
    output = self._residual_fn(y_body, y_shortcut)
    next_state = (next_body_state, next_shortcut_state)
    return output, next_state, (body_emits, shortcut_emits)

  @classmethod
  def from_config(cls, config, backend='mlx'):
    """Builds a Residual from a config with `layers` (and maybe a shortcut).

    `config.shortcut_layers` is optional. One shortcut layer is used
    directly; several are chained with Serial.
    """
    layers = [c.make(backend=backend) for c in config.layers]

    shortcut = None
    shortcut_configs = getattr(config, 'shortcut_layers', None)
    if shortcut_configs:
      built = [c.make(backend=backend) for c in shortcut_configs]
      shortcut = built[0] if len(built) == 1 else Serial(built)
    return cls(layers, shortcut=shortcut)
@property
def input_latency(self):
  """Input latency accumulated across all repeated layers, in order.

  Starts from 0 and passes the running value through each layer's
  `get_accumulated_input_latency`.
  """
  return reduce(
      lambda accumulated, child: child.get_accumulated_input_latency(
          accumulated
      ),
      self.layers,
      0,
  )
class Parallel(types.Emitting):
  """Applies several child layers to one input and merges their outputs.

  Every child receives the same input. Child outputs are merged by
  `_combine_sequences` according to `combination`. The constructor rejects
  children whose output_ratio or block_size differ.
  """

  def __init__(
      self,
      layers: list[types.SequenceLayer],
      *,
      combination: CombinationMode = CombinationMode.STACK,
  ):
    super().__init__()
    if not layers:
      raise ValueError('Parallel requires at least one layer.')
    self.layers = list(layers)
    self.combination = combination

    # Children must agree on their timing properties.
    ratios = set(child.output_ratio for child in self.layers)
    if len(ratios) != 1:
      raise ValueError(
          f'All Parallel children must have equal output_ratio, got {ratios}.'
      )
    blocks = set(child.block_size for child in self.layers)
    if len(blocks) != 1:
      raise ValueError(
          f'All Parallel children must have equal block_size, got {blocks}.'
      )

  @property
  def supports_step(self):
    for child in self.layers:
      if not child.supports_step:
        return False
    return True

  @property
  def block_size(self):
    return lcm(1, *(child.block_size for child in self.layers))

  @property
  def output_ratio(self):
    # Equal across children (checked in __init__), so the first one is used.
    first_child = self.layers[0]
    return first_child.output_ratio

  @property
  def input_latency(self):
    first_child = self.layers[0]
    return first_child.input_latency

  def get_output_shape(self, input_shape, *, constants=None):
    """Returns the merged channel shape for the configured combination."""
    child_shapes = [
        child.get_output_shape(input_shape, constants=constants)
        for child in self.layers
    ]
    return _combine_output_channel_shape(self.combination, *child_shapes)

  def get_output_dtype(self, input_dtype, *, constants=None):
    # Reported from the first child only.
    first_child = self.layers[0]
    return first_child.get_output_dtype(input_dtype, constants=constants)

  def get_initial_state(self, batch_size, input_spec, *, constants=None):
    """Returns one state per child, in child order."""
    return tuple(
        child.get_initial_state(batch_size, input_spec, constants=constants)
        for child in self.layers
    )

  def layer_with_emits(self, x, *, constants=None):
    """Runs every child on `x`; emits are keyed 'parallel_<index>'."""
    results = [
        child.layer_with_emits(x, constants=constants)
        for child in self.layers
    ]
    outputs = [y for y, _ in results]
    emits = {f'parallel_{i}': e for i, (_, e) in enumerate(results)}
    return _combine_sequences(self.combination, outputs), emits

  def step_with_emits(self, x, state, *, constants=None):
    """Steps every child with its own state; returns merged output."""
    results = [
        child.step_with_emits(x, child_state, constants=constants)
        for child, child_state in zip(self.layers, state)
    ]
    outputs = [y for y, _, _ in results]
    next_state = tuple(s for _, s, _ in results)
    emits = {f'parallel_{i}': e for i, (_, _, e) in enumerate(results)}
    merged = _combine_sequences(self.combination, outputs)
    return merged, next_state, emits

  @classmethod
  def from_config(cls, config, backend='mlx'):
    """Builds a Parallel layer from a config with `layers` and `combination`.

    The config's combination enum is mapped onto this module's
    CombinationMode through its integer value.
    """
    children = [c.make(backend=backend) for c in config.layers]
    mode = CombinationMode(config.combination.value)
    return cls(children, combination=mode)
class SerialTest(parameterized.TestCase):
  """Checks for the Serial combinator."""

  @staticmethod
  def _dense_stack():
    """Builds the 4 -> 8 -> 16 Dense chain shared by two tests."""
    widths = (4, 8, 16)
    return combinators.Serial([
        dense.Dense(in_features=fan_in, features=fan_out)
        for fan_in, fan_out in zip(widths, widths[1:])
    ])

  def test_identity_serial(self):
    layer = combinators.Serial([simple.Identity() for _ in range(2)])
    test_utils.verify_contract(self, layer, (4,))

  def test_dense_serial(self):
    test_utils.verify_contract(self, self._dense_stack(), (4,))

  def test_output_shape(self):
    # The output shape is that of the last Dense layer.
    self.assertEqual(self._dense_stack().get_output_shape((4,)), (16,))

  def test_from_config(self):
    import sequence_layers.mlx
    import sequence_layers.jax as sl

    child_configs = [
        sl.Identity.Config(),
        sl.Dense.Config(features=8),
    ]
    mlx_layer = sl.Serial.Config(child_configs).make(backend='mlx')
    self.assertIsInstance(mlx_layer, combinators.Serial)
class RepeatTest(parameterized.TestCase):
  """Checks for the Repeat combinator."""

  @staticmethod
  def _repeat(make_child, count):
    """Builds a Repeat from `count` freshly constructed children."""
    return combinators.Repeat([make_child() for _ in range(count)])

  def test_repeat_identity(self):
    layer = self._repeat(simple.Identity, 3)
    test_utils.verify_contract(self, layer, (4,))

  def test_repeat_dense(self):
    layer = self._repeat(lambda: dense.Dense(in_features=4, features=4), 3)
    test_utils.verify_contract(self, layer, (4,))

  def test_num_repeats(self):
    count = 5
    self.assertEqual(self._repeat(simple.Identity, count).num_repeats, count)

  def test_from_config(self):
    import sequence_layers.mlx
    import sequence_layers.jax as sl

    repeat_config = sl.Repeat.Config(
        layer=sl.Identity.Config(),
        num_repeats=4,
    )
    mlx_layer = repeat_config.make(backend='mlx')
    self.assertIsInstance(mlx_layer, combinators.Repeat)
    self.assertEqual(mlx_layer.num_repeats, 4)
def test_from_config(self):
  """A Parallel config with ADD combination builds an MLX Parallel layer."""
  import sequence_layers.mlx
  import sequence_layers.jax as sl
  from sequence_layers.jax import utils as jax_utils

  branch_configs = [sl.Identity.Config(), sl.Identity.Config()]
  parallel_config = sl.Parallel.Config(
      layers=branch_configs,
      combination=jax_utils.CombinationMode.ADD,
  )
  mlx_layer = parallel_config.make(backend='mlx')

  self.assertIsInstance(mlx_layer, combinators.Parallel)
  test_utils.verify_contract(self, mlx_layer, (4,))
kernel_size=3, + strides=2, + padding='causal', + ), + ]) + + +class TransformerEndToEndTest(parameterized.TestCase): + """End-to-end test with a full Transformer config.""" + + def test_decoder_transformer(self): + import sequence_layers.mlx + import sequence_layers.jax as sl + from sequence_layers.jax.attention import ( + dot_product_self_attention as dpa, + ) + import jax + + # Attention outputs [b, t, num_heads, units_per_head]. + # A Dense layer after it projects back to model dim. + config = sl.Serial.Config([ + sl.Residual.Config([ + sl.RMSNormalization.Config(), + dpa.DotProductSelfAttention.Config( + num_heads=4, + units_per_head=8, + max_past_horizon=64, + ), + sl.Flatten.Config(), + sl.Dense.Config(features=32), + ]), + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.Dense.Config(features=64, activation=jax.nn.gelu), + sl.Dense.Config(features=32), + ]), + ]) + model = config.make(backend='mlx') + + # Layer mode. + x = test_utils.random_sequence(1, 10, 32) + y = model.layer(x) + self.assertEqual(y.shape, (1, 10, 32)) + + # Step mode. 
def _broadcast_shapes(shape1, shape2):
  """Return the numpy-style broadcast of two shapes as a tuple.

  The shorter shape is left-padded with 1s. In each aligned position a size
  of 1 yields to the other size; otherwise the two sizes must be equal.

  Raises:
    ValueError: If some aligned pair of sizes differs and neither is 1.
  """
  lhs, rhs = tuple(shape1), tuple(shape2)
  rank = max(len(lhs), len(rhs))
  lhs = (1,) * (rank - len(lhs)) + lhs
  rhs = (1,) * (rank - len(rhs)) + rhs

  merged = []
  for a, b in zip(lhs, rhs):
    if a != 1 and b != 1 and a != b:
      raise ValueError(f'Shapes {shape1} and {shape2} are not broadcastable')
    # When a is 1 the other side decides; otherwise a is the size (b is
    # either 1 or equal to a).
    merged.append(b if a == 1 else a)
  return tuple(merged)
def _combine_mask(*masks):
  """AND together one or more masks, left to right.

  A mask that is the very same object as the running result is skipped, so
  combining a mask with itself does no array work.
  """
  combined = masks[0]
  for candidate in masks[1:]:
    if candidate is combined:
      # Identical object: AND-ing it in would change nothing.
      continue
    combined = mx.logical_and(combined, candidate)
  return combined
def _tensor_to_fake_sequence(t):
  """Wrap a [B, ...] tensor as a [B, 1, ...] MaskedSequence.

  A new axis of length 1 is inserted at position 1 and paired with an
  all-True boolean mask of shape [B, 1].
  """
  values = mx.expand_dims(t, axis=1)
  mask = mx.ones((t.shape[0], 1), dtype=mx.bool_)
  return MaskedSequence(values, mask)
def _validate(self):
  """Checks that the configured projection and combination are compatible.

  Raises:
    ValueError: If an AFFINE* combination is paired with the wrong
      projection, or LINEAR_AFFINE projection is used without AFFINE.
  """
  # (combination, projection it requires, error message if violated).
  requirements = (
      (
          self.Combination.AFFINE,
          self.Projection.LINEAR_AFFINE,
          'AFFINE combination requires LINEAR_AFFINE projection.',
      ),
      (
          self.Combination.AFFINE_SHIFT,
          self.Projection.LINEAR,
          'AFFINE_SHIFT combination requires LINEAR projection.',
      ),
      (
          self.Combination.AFFINE_SCALE,
          self.Projection.LINEAR,
          'AFFINE_SCALE combination requires LINEAR projection.',
      ),
  )
  for combination, required_projection, message in requirements:
    if (
        self._combination == combination
        and self._projection != required_projection
    ):
      raise ValueError(message)

  # The converse: LINEAR_AFFINE projection is only valid with AFFINE.
  uses_linear_affine = self._projection == self.Projection.LINEAR_AFFINE
  if uses_linear_affine and self._combination != self.Combination.AFFINE:
    raise ValueError('LINEAR_AFFINE projection requires AFFINE combination.')
+ input_dims = ''.join( + chr(ord('a') + i) for i in range(len(cond_channel_shape)) + ) + output_dims = ''.join( + chr(ord('a') + i + len(cond_channel_shape)) + for i in range(len(output_shape)) + ) + + input_weight_dims = input_dims if input_dims else 'I' + output_weight_dims = output_dims if output_dims else 'O' + input_kernel_shape = cond_channel_shape if cond_channel_shape else (1,) + output_kernel_shape = output_shape if output_shape else (1,) + + self._equation = ( + f'...{input_dims},{input_weight_dims}{output_weight_dims}' + f'->...{output_dims}' + ) + kernel_shape = input_kernel_shape + output_kernel_shape + self.kernel = mx.zeros(kernel_shape, dtype=self._param_dtype) + self.bias = mx.zeros(output_kernel_shape, dtype=self._param_dtype) + self._proj_initialized = True + + def _projected_condition_shape(self, input_shape, condition_shape): + """Compute the channel shape after projection.""" + proj_shape = self._projection_channel_shape + if proj_shape is None: + proj_shape = input_shape + if self._projection == self.Projection.IDENTITY: + return condition_shape + elif self._projection == self.Projection.LINEAR: + return tuple(proj_shape) + elif self._projection == self.Projection.LINEAR_AFFINE: + return (2,) + tuple(proj_shape) + else: + raise ValueError(f'Unsupported projection: {self._projection}') + + def get_output_shape(self, input_shape, *, constants=None): + self._validate() + cond = _get_conditioning(self, self._conditioning_name, constants) + if isinstance(cond, (Sequence, MaskedSequence)): + cond_shape = cond.channel_shape + else: + cond_shape = cond.shape[1:] + proj_shape = self._projected_condition_shape(input_shape, cond_shape) + + if self._combination in ( + self.Combination.ADD, + self.Combination.MUL, + self.Combination.AFFINE_SHIFT, + self.Combination.AFFINE_SCALE, + ): + return _broadcast_shapes(input_shape, proj_shape) + elif self._combination in ( + self.Combination.CONCAT, + self.Combination.CONCAT_BEFORE, + ): + input_inner = 
input_shape[-1] if input_shape else 1 + proj_inner = proj_shape[-1] if proj_shape else 1 + outer = _broadcast_shapes(input_shape[:-1], proj_shape[:-1]) + return outer + (input_inner + proj_inner,) + elif self._combination == self.Combination.AFFINE: + proj_shape = proj_shape[1:] # Remove the '2' dim. + return _broadcast_shapes(input_shape, proj_shape) + else: + raise ValueError(f'Unsupported combination: {self._combination}') + + def get_output_dtype(self, input_dtype, *, constants=None): + if self._compute_dtype is not None: + return self._compute_dtype + return self._param_dtype + + def _project(self, x, conditioning): + """Apply projection to conditioning.""" + if self._projection == self.Projection.IDENTITY: + return conditioning + + self._ensure_projection_initialized( + x.channel_shape, conditioning.channel_shape + ) + compute_dtype = self._compute_dtype or self._param_dtype + + def project_fn(v): + y = mx.einsum(self._equation, v.astype(compute_dtype), self.kernel) + y = y + self.bias + return y + + return conditioning.apply_values(project_fn) + + def _combine(self, x, conditioning): + """Combine projected conditioning with input.""" + if self._combination == self.Combination.ADD: + return _sequence_broadcast_add(x, conditioning) + elif self._combination == self.Combination.CONCAT: + return _sequence_broadcast_concat(x, conditioning) + elif self._combination == self.Combination.CONCAT_BEFORE: + return _sequence_broadcast_concat(conditioning, x) + elif self._combination == self.Combination.AFFINE: + scale, shift = _sequence_unstack(conditioning, axis=2) + scale = scale.apply_values(lambda v: v + self._affine_scale_offset) + x_s, scale_s = _reshape_for_broadcast(x, scale) + x_s2, shift_s = _reshape_for_broadcast(x, shift) + values = x_s.values * scale_s.values + shift_s.values + mask = _combine_mask(x.mask, scale.mask, shift.mask) + return Sequence(values, mask) + elif self._combination == self.Combination.AFFINE_SHIFT: + return _sequence_broadcast_add(x, 
conditioning) + elif self._combination == self.Combination.AFFINE_SCALE: + conditioning = conditioning.apply_values( + lambda v: v + self._affine_scale_offset + ) + return _sequence_broadcast_product(x, conditioning) + elif self._combination == self.Combination.MUL: + return _sequence_broadcast_product(x, conditioning) + else: + raise ValueError(f'Unsupported combination: {self._combination}') + + @types.check_layer + def layer(self, x, *, constants=None): + conditioning = _get_conditioning(self, self._conditioning_name, constants) + if not isinstance(conditioning, (Sequence, MaskedSequence)): + conditioning = _tensor_to_fake_sequence(conditioning) + projected = self._project(x, conditioning) + return self._combine(x, projected) + + def get_initial_state(self, batch_size, input_spec, *, constants=None): + conditioning = _get_conditioning(self, self._conditioning_name, constants) + if ( + isinstance(conditioning, (Sequence, MaskedSequence)) + and not self._streaming + ): + return mx.zeros((batch_size,), mx.int32) + return () + + @types.check_step + def step(self, x, state, *, constants=None): + conditioning = _get_conditioning(self, self._conditioning_name, constants) + if not isinstance(conditioning, (Sequence, MaskedSequence)): + conditioning = _tensor_to_fake_sequence(conditioning) + elif not self._streaming: + time_index = state + step_size = x.shape[1] + idx = int(time_index[0]) + conditioning = type(conditioning)( + conditioning.values[:, idx : idx + step_size], + conditioning.mask[:, idx : idx + step_size], + ) + state = time_index + step_size + projected = self._project(x, conditioning) + result = self._combine(x, projected) + return result, state + + @classmethod + def from_config(cls, config): + """Create from a JAX Conditioning.Config.""" + compute_dtype = getattr(config, 'compute_dtype', None) + if compute_dtype is not None: + compute_dtype = _to_mx_dtype(compute_dtype) + # Map JAX enum values to MLX enum values. 
+ projection = cls.Projection(config.projection.value) + combination = cls.Combination(config.combination.value) + return cls( + conditioning_name=config.conditioning_name, + projection=projection, + combination=combination, + projection_channel_shape=config.projection_channel_shape, + streaming=config.streaming, + affine_scale_offset=config.affine_scale_offset, + compute_dtype=compute_dtype, + param_dtype=_to_mx_dtype(config.param_dtype), + ) diff --git a/sequence_layers/mlx/conditioning_test.py b/sequence_layers/mlx/conditioning_test.py new file mode 100644 index 0000000..8430247 --- /dev/null +++ b/sequence_layers/mlx/conditioning_test.py @@ -0,0 +1,355 @@ +"""Tests for Conditioning MLX sequence layer.""" + +import mlx.core as mx +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized +from sequence_layers.mlx import basic_types as bt +from sequence_layers.mlx import conditioning +from sequence_layers.mlx import test_utils + +Sequence = bt.Sequence +MaskedSequence = bt.MaskedSequence + + +def _make_constants(conditioning_seq, name='cond'): + return {name: conditioning_seq} + + +class ConditioningIdentityAddTest(parameterized.TestCase): + """Tests for IDENTITY projection + ADD combination.""" + + def test_layer(self): + layer = conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.IDENTITY, + combination=conditioning.Conditioning.Combination.ADD, + ) + cond_seq = test_utils.random_sequence(2, 8, 4) + constants = _make_constants(cond_seq) + test_utils.verify_contract(self, layer, (4,), constants=constants) + + def test_output_shape(self): + layer = conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.IDENTITY, + combination=conditioning.Conditioning.Combination.ADD, + ) + cond_seq = test_utils.random_sequence(2, 8, 4) + constants = _make_constants(cond_seq) + self.assertEqual(layer.get_output_shape((4,), 
constants=constants), (4,)) + + def test_broadcast_add(self): + """x: [B,T,4], c: [B,T,1] → [B,T,4] via broadcast.""" + layer = conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.IDENTITY, + combination=conditioning.Conditioning.Combination.ADD, + ) + cond_seq = test_utils.random_sequence(2, 8, 1) + constants = _make_constants(cond_seq) + self.assertEqual(layer.get_output_shape((4,), constants=constants), (4,)) + + def test_tensor_conditioning(self): + """Conditioning with a [B, dim] tensor (not a Sequence).""" + layer = conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.IDENTITY, + combination=conditioning.Conditioning.Combination.ADD, + ) + cond_tensor = mx.random.normal(shape=(2, 4)) + constants = _make_constants(cond_tensor) + x = test_utils.random_sequence(2, 8, 4) + y = layer.layer(x, constants=constants) + self.assertEqual(y.channel_shape, (4,)) + + def test_step_non_streaming(self): + """Non-streaming: full conditioning passed, layer slices per step.""" + layer = conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.IDENTITY, + combination=conditioning.Conditioning.Combination.ADD, + streaming=False, + ) + cond_seq = test_utils.random_sequence(2, 8, 4) + constants = _make_constants(cond_seq) + x = test_utils.random_sequence(2, 8, 4) + + # Layer mode. + y_layer = layer.layer(x, constants=constants) + + # Step mode (pass full conditioning; layer slices internally). 
+ y_step, _ = test_utils.step_by_step( + layer, x, block_size=1, constants=constants + ) + np.testing.assert_allclose( + np.array(y_step.values), + np.array(y_layer.values), + atol=1e-5, + rtol=1e-5, + ) + + def test_step_streaming(self): + """Streaming: conditioning chunks arrive with input chunks.""" + layer = conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.IDENTITY, + combination=conditioning.Conditioning.Combination.ADD, + streaming=True, + ) + cond_seq = test_utils.random_sequence(2, 8, 4) + x = test_utils.random_sequence(2, 8, 4) + + # Layer mode. + y_layer = layer.layer(x, constants=_make_constants(cond_seq)) + + # Step mode with stream_constants. + y_step, _ = test_utils.step_by_step( + layer, + x, + block_size=1, + stream_constants=_make_constants(cond_seq), + ) + np.testing.assert_allclose( + np.array(y_step.values), + np.array(y_layer.values), + atol=1e-5, + rtol=1e-5, + ) + + +class ConditioningIdentityConcatTest(parameterized.TestCase): + + def test_layer(self): + layer = conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.IDENTITY, + combination=conditioning.Conditioning.Combination.CONCAT, + ) + cond_seq = test_utils.random_sequence(2, 8, 3) + constants = _make_constants(cond_seq) + x = test_utils.random_sequence(2, 8, 4) + y = layer.layer(x, constants=constants) + self.assertEqual(y.channel_shape, (7,)) + + def test_concat_before(self): + layer = conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.IDENTITY, + combination=conditioning.Conditioning.Combination.CONCAT_BEFORE, + ) + cond_seq = test_utils.random_sequence(2, 8, 3) + constants = _make_constants(cond_seq) + x = test_utils.random_sequence(2, 8, 4) + y = layer.layer(x, constants=constants) + self.assertEqual(y.channel_shape, (7,)) + # CONCAT_BEFORE should have conditioning first. 
+ np.testing.assert_allclose( + np.array(y.values[:, :, :3]), + np.array(cond_seq.values), + atol=1e-5, + ) + + +class ConditioningIdentityMulTest(parameterized.TestCase): + + def test_layer(self): + layer = conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.IDENTITY, + combination=conditioning.Conditioning.Combination.MUL, + ) + cond_seq = test_utils.random_sequence(2, 8, 4) + constants = _make_constants(cond_seq) + test_utils.verify_contract(self, layer, (4,), constants=constants) + + +class ConditioningLinearAddTest(parameterized.TestCase): + + def test_layer(self): + layer = conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.LINEAR, + combination=conditioning.Conditioning.Combination.ADD, + ) + cond_seq = test_utils.random_sequence(2, 8, 6) + constants = _make_constants(cond_seq) + # input shape (4,), conditioning shape (6,), projected to (4,). + test_utils.verify_contract(self, layer, (4,), constants=constants) + + def test_output_shape(self): + layer = conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.LINEAR, + combination=conditioning.Conditioning.Combination.ADD, + ) + cond_seq = test_utils.random_sequence(2, 8, 6) + constants = _make_constants(cond_seq) + # LINEAR projects conditioning to input channel shape. + self.assertEqual(layer.get_output_shape((4,), constants=constants), (4,)) + + def test_with_projection_channel_shape(self): + layer = conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.LINEAR, + combination=conditioning.Conditioning.Combination.ADD, + projection_channel_shape=(8,), + ) + cond_seq = test_utils.random_sequence(2, 8, 6) + constants = _make_constants(cond_seq) + # Projects to (8,), then broadcast-add with input (8,). 
+ self.assertEqual(layer.get_output_shape((8,), constants=constants), (8,)) + + +class ConditioningLinearAffineShiftTest(parameterized.TestCase): + + def test_layer(self): + layer = conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.LINEAR, + combination=conditioning.Conditioning.Combination.AFFINE_SHIFT, + ) + cond_seq = test_utils.random_sequence(2, 8, 6) + constants = _make_constants(cond_seq) + test_utils.verify_contract(self, layer, (4,), constants=constants) + + +class ConditioningLinearAffineScaleTest(parameterized.TestCase): + + def test_layer(self): + layer = conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.LINEAR, + combination=conditioning.Conditioning.Combination.AFFINE_SCALE, + ) + cond_seq = test_utils.random_sequence(2, 8, 6) + constants = _make_constants(cond_seq) + test_utils.verify_contract(self, layer, (4,), constants=constants) + + +class ConditioningLinearAffineTest(parameterized.TestCase): + + def test_layer(self): + layer = conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.LINEAR_AFFINE, + combination=conditioning.Conditioning.Combination.AFFINE, + ) + cond_seq = test_utils.random_sequence(2, 8, 6) + constants = _make_constants(cond_seq) + test_utils.verify_contract(self, layer, (4,), constants=constants) + + def test_output_shape(self): + layer = conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.LINEAR_AFFINE, + combination=conditioning.Conditioning.Combination.AFFINE, + ) + cond_seq = test_utils.random_sequence(2, 8, 6) + constants = _make_constants(cond_seq) + # AFFINE combination strips the '2' dim from projected shape. 
+ self.assertEqual(layer.get_output_shape((4,), constants=constants), (4,)) + + +class ConditioningValidationTest(parameterized.TestCase): + + def test_affine_requires_linear_affine(self): + with self.assertRaises(ValueError): + conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.LINEAR, + combination=conditioning.Conditioning.Combination.AFFINE, + ) + + def test_affine_shift_requires_linear(self): + with self.assertRaises(ValueError): + conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.IDENTITY, + combination=conditioning.Conditioning.Combination.AFFINE_SHIFT, + ) + + def test_affine_scale_requires_linear(self): + with self.assertRaises(ValueError): + conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.IDENTITY, + combination=conditioning.Conditioning.Combination.AFFINE_SCALE, + ) + + def test_linear_affine_requires_affine(self): + with self.assertRaises(ValueError): + conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.LINEAR_AFFINE, + combination=conditioning.Conditioning.Combination.ADD, + ) + + def test_missing_constants(self): + layer = conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.IDENTITY, + combination=conditioning.Conditioning.Combination.ADD, + ) + x = test_utils.random_sequence(2, 8, 4) + with self.assertRaises(ValueError): + layer.layer(x, constants=None) + + def test_missing_key(self): + layer = conditioning.Conditioning( + conditioning_name='cond', + projection=conditioning.Conditioning.Projection.IDENTITY, + combination=conditioning.Conditioning.Combination.ADD, + ) + x = test_utils.random_sequence(2, 8, 4) + with self.assertRaises(ValueError): + layer.layer(x, constants={'other': mx.zeros((2, 4))}) + + +class ConditioningFromConfigTest(parameterized.TestCase): + + def 
test_from_config_identity_add(self): + import sequence_layers.mlx + from sequence_layers.jax import conditioning as jax_cond + + config = jax_cond.Conditioning.Config( + conditioning_name='cond', + projection=jax_cond.BaseConditioning.Projection.IDENTITY, + combination=jax_cond.BaseConditioning.Combination.ADD, + ) + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, conditioning.Conditioning) + + cond_seq = test_utils.random_sequence(2, 5, 8) + constants = _make_constants(cond_seq) + x = test_utils.random_sequence(2, 5, 8) + y = mlx_layer.layer(x, constants=constants) + self.assertEqual(y.channel_shape, (8,)) + + def test_from_config_linear_add(self): + import sequence_layers.mlx + from sequence_layers.jax import conditioning as jax_cond + + config = jax_cond.Conditioning.Config( + conditioning_name='cond', + projection=jax_cond.BaseConditioning.Projection.LINEAR, + combination=jax_cond.BaseConditioning.Combination.ADD, + ) + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, conditioning.Conditioning) + + def test_from_config_linear_affine(self): + import sequence_layers.mlx + from sequence_layers.jax import conditioning as jax_cond + + config = jax_cond.Conditioning.Config( + conditioning_name='cond', + projection=jax_cond.BaseConditioning.Projection.LINEAR_AFFINE, + combination=jax_cond.BaseConditioning.Combination.AFFINE, + ) + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, conditioning.Conditioning) + + +if __name__ == '__main__': + absltest.main() diff --git a/sequence_layers/mlx/convolution.py b/sequence_layers/mlx/convolution.py new file mode 100644 index 0000000..e176d65 --- /dev/null +++ b/sequence_layers/mlx/convolution.py @@ -0,0 +1,1113 @@ +"""Convolution layers for MLX.""" + +import fractions +import math + +import mlx.core as mx +import mlx.nn as nn + +from sequence_layers.mlx import basic_types as bt +from sequence_layers.mlx import init_mapping +from sequence_layers.mlx 
import types + +Sequence = bt.Sequence +MaskedSequence = bt.MaskedSequence +PaddingMode = bt.PaddingMode + + +# --------------------------------------------------------------------------- +# Padding utilities (ported from jax/utils.py and jax/convolution.py) +# --------------------------------------------------------------------------- + + +def _effective_kernel_size(kernel_size, dilation_rate): + return (kernel_size - 1) * dilation_rate + 1 + + +def _explicit_padding(padding, kernel_size, stride, dilation_rate): + """Returns (pad_left, pad_right) for the given padding mode.""" + if not isinstance(padding, str): + return tuple(padding) + + ek = _effective_kernel_size(kernel_size, dilation_rate) + + if padding in (PaddingMode.CAUSAL_VALID.value, PaddingMode.CAUSAL.value): + return (ek - 1, 0) + elif padding == PaddingMode.SEMICAUSAL.value: + pad_left = max(ek - stride, 0) + return (pad_left, ek - 1 - pad_left) + elif padding in ( + PaddingMode.REVERSE_CAUSAL_VALID.value, + PaddingMode.REVERSE_CAUSAL.value, + ): + return (0, ek - 1) + elif padding == PaddingMode.SAME.value: + pad_amount = ek - 1 + pad_left = pad_amount // 2 + return (pad_left, pad_amount - pad_left) + elif padding == PaddingMode.VALID.value: + return (0, 0) + elif padding == PaddingMode.SEMICAUSAL_FULL.value: + return (ek - stride, ek - 1) + else: + raise ValueError(f'Unsupported padding mode: {padding}') + + +def _buffer_width(padding, kernel_size, stride, dilation_rate): + """Returns the buffer width for step mode.""" + ek = _effective_kernel_size(kernel_size, dilation_rate) + + if padding == PaddingMode.SEMICAUSAL.value: + return max(ek - stride, 0) + elif padding in ( + PaddingMode.REVERSE_CAUSAL.value, + PaddingMode.REVERSE_CAUSAL_VALID.value, + ): + return (ek - 1) // stride * stride + elif padding in ( + PaddingMode.CAUSAL.value, + PaddingMode.CAUSAL_VALID.value, + ): + return ek - 1 + else: + raise ValueError(f'Unsupported step padding: {padding}') + + +def _supports_step(padding): + 
"""Returns True if the padding mode supports step-by-step processing.""" + return padding in ( + PaddingMode.CAUSAL_VALID.value, + PaddingMode.REVERSE_CAUSAL_VALID.value, + PaddingMode.CAUSAL.value, + PaddingMode.REVERSE_CAUSAL.value, + PaddingMode.SEMICAUSAL.value, + ) + + +def _compute_conv_mask( + mask, kernel_size, stride, dilation_rate, padding, is_step +): + """Compute the output mask for a convolution-like operation.""" + ek = _effective_kernel_size(kernel_size, dilation_rate) + + if is_step: + if isinstance(padding, str) and padding in ( + PaddingMode.SAME.value, + PaddingMode.CAUSAL.value, + PaddingMode.REVERSE_CAUSAL.value, + PaddingMode.SEMICAUSAL.value, + ): + pad_left, pad_right = _explicit_padding( + padding, kernel_size, stride, dilation_rate + ) + # Use a simple convolution-like mask computation with float kernel. + kernel = [0.0] * pad_left + [1.0] + [0.0] * pad_right + kernel = mx.array(kernel, dtype=mx.float32).reshape(1, -1, 1) + mask_f = mask[:, :, None].astype(mx.float32) + mask_conv = mx.conv1d(mask_f, kernel, stride=stride) + return mx.squeeze(mask_conv, axis=-1).astype(mx.bool_) + elif not isinstance(padding, str) or padding in ( + PaddingMode.VALID.value, + PaddingMode.CAUSAL_VALID.value, + PaddingMode.REVERSE_CAUSAL_VALID.value, + ): + return _compute_conv_mask_logical( + mask, kernel_size, stride, dilation_rate + ) + else: + return _compute_conv_mask_logical( + mask, kernel_size, stride, dilation_rate + ) + + # Layer mode. + if isinstance(padding, str) and padding in ( + PaddingMode.SAME.value, + PaddingMode.CAUSAL.value, + PaddingMode.REVERSE_CAUSAL.value, + PaddingMode.SEMICAUSAL.value, + ): + if stride > 1: + mask = mask[:, ::stride] + return mask + + # VALID-like modes: need to compute mask through reduce_window equiv. 
+ pad_left, pad_right = _explicit_padding( + padding, kernel_size, stride, dilation_rate + ) + is_causal_valid = ( + isinstance(padding, str) and padding == PaddingMode.CAUSAL_VALID.value + ) + mask = mx.pad( + mask, + [(0, 0), (pad_left, pad_right)], + constant_values=is_causal_valid, + ) + is_semicausal_full = ( + isinstance(padding, str) and padding == PaddingMode.SEMICAUSAL_FULL.value + ) + return _compute_conv_mask_logical( + mask, + kernel_size, + stride, + dilation_rate, + use_logical_or=is_semicausal_full, + ) + + +def _compute_conv_mask_logical( + mask, kernel_size, stride, dilation_rate, use_logical_or=False +): + """Windowed AND/OR mask computation.""" + # Optimized path for dilation=1 and kernel_size divisible by stride. + if dilation_rate == 1 and kernel_size % stride == 0: + num_frames = mask.shape[1] // stride + mask = mask[:, : num_frames * stride] + mask = mask.reshape(mask.shape[0], num_frames, stride) + if use_logical_or: + mask = mx.max(mask, axis=-1) + else: + mask = mx.min(mask, axis=-1) + kernel_size = kernel_size // stride + stride = 1 + + if kernel_size == 1 and stride == 1: + return mask + + # Use float conv to simulate reduce_window. + mask_f = mask[:, :, None].astype(mx.float32) + # Build a kernel with ones at dilated positions. 
+ if dilation_rate == 1: + kernel = mx.ones((1, kernel_size, 1), dtype=mx.float32) + else: + ek = _effective_kernel_size(kernel_size, dilation_rate) + k = [0.0] * ek + for i in range(kernel_size): + k[i * dilation_rate] = 1.0 + kernel = mx.array(k, dtype=mx.float32).reshape(1, -1, 1) + + result = mx.conv1d(mask_f, kernel, stride=stride) + result = mx.squeeze(result, axis=-1) + + if use_logical_or: + return result > 0.0 + else: + return result >= float(kernel_size) + + +def _compute_initial_state(batch_size, input_spec, buf_width, padding): + """Create initial buffer state for step mode.""" + if padding in ( + PaddingMode.CAUSAL_VALID.value, + PaddingMode.REVERSE_CAUSAL_VALID.value, + PaddingMode.SEMICAUSAL_FULL.value, + ): + mask = mx.ones((batch_size, buf_width), dtype=bt.MASK_DTYPE) + elif padding in ( + PaddingMode.CAUSAL.value, + PaddingMode.REVERSE_CAUSAL.value, + PaddingMode.SEMICAUSAL.value, + ): + mask = mx.zeros((batch_size, buf_width), dtype=bt.MASK_DTYPE) + else: + raise ValueError(f'Step not supported with padding: {padding}') + + values = mx.zeros( + (batch_size, buf_width) + input_spec.shape, + dtype=input_spec.dtype, + ) + return MaskedSequence(values, mask) + + +# --------------------------------------------------------------------------- +# Conv1D +# --------------------------------------------------------------------------- + + +class Conv1D(types.SequenceLayer): + """1D strided or dilated convolution layer. + + Supports causal, reverse_causal, same, and valid padding modes. + Step-by-step processing is supported for causal padding modes. 
+ """ + + def __init__( + self, + *, + in_features: int, + filters: int, + kernel_size: int, + strides: int = 1, + dilation_rate: int = 1, + padding: str = 'valid', + groups: int = 1, + use_bias: bool = True, + activation=None, + compute_dtype=None, + param_dtype=mx.float32, + ): + super().__init__() + self.in_features = in_features + self.filters = filters + self.kernel_size = kernel_size + self.strides = strides + self.dilation_rate = dilation_rate + self.padding = padding + self.groups = groups + self.use_bias = use_bias + self.activation = activation + self.compute_dtype = compute_dtype + self._param_dtype = param_dtype + + if in_features % groups != 0: + raise ValueError(f'{in_features=} must be divisible by {groups=}.') + + # Create kernel: [out_channels, kernel_size, in_channels // groups] + # This is the MLX Conv1d convention. + self._conv = nn.Conv1d( + in_channels=in_features, + out_channels=filters, + kernel_size=kernel_size, + stride=strides, + padding=0, # Padding handled manually. + dilation=dilation_rate, + groups=groups, + bias=use_bias, + ) + + @property + def supports_step(self): + return _supports_step(self.padding) + + @property + def block_size(self): + return self.strides + + @property + def output_ratio(self): + return fractions.Fraction(1, self.strides) + + @property + def input_latency(self): + ek = _effective_kernel_size(self.kernel_size, self.dilation_rate) + if self.padding in ( + PaddingMode.CAUSAL_VALID.value, + PaddingMode.CAUSAL.value, + PaddingMode.SEMICAUSAL.value, + ): + return 0 + elif self.padding in ( + PaddingMode.REVERSE_CAUSAL_VALID.value, + PaddingMode.REVERSE_CAUSAL.value, + PaddingMode.SEMICAUSAL_FULL.value, + ): + return ek - 1 + return 0 + + def get_output_shape(self, input_shape, *, constants=None): + if len(input_shape) != 1: + raise ValueError( + f'Conv1D requires rank 3 input, got channel_shape={input_shape}.' 
+ ) + return (self.filters,) + + def get_output_dtype(self, input_dtype, *, constants=None): + return self.compute_dtype or self._param_dtype + + def _forward(self, values, pad_left, pad_right): + """Apply convolution with explicit padding.""" + if pad_left > 0 or pad_right > 0: + values = mx.pad( + values, + [(0, 0), (pad_left, pad_right), (0, 0)], + ) + compute_dtype = self.compute_dtype or self._param_dtype + values = values.astype(compute_dtype) + y = self._conv(values) + if self.activation is not None: + y = self.activation(y) + return y + + def get_initial_state(self, batch_size, input_spec, *, constants=None): + bw = _buffer_width( + self.padding, + self.kernel_size, + self.strides, + self.dilation_rate, + ) + if not bw: + return () + return _compute_initial_state( + batch_size, + input_spec, + bw, + self.padding, + ) + + @types.check_step + def step(self, x, state, *, constants=None): + ek = _effective_kernel_size(self.kernel_size, self.dilation_rate) + if ek > 1: + x = x.mask_invalid() + + bw = _buffer_width( + self.padding, + self.kernel_size, + self.strides, + self.dilation_rate, + ) + + if bw: + state = state.concatenate(x) + else: + state = x + + # In step mode, padding is provided by the buffer — use valid conv. 
+ values = self._forward(state.values, 0, 0) + mask = _compute_conv_mask( + state.mask, + self.kernel_size, + self.strides, + self.dilation_rate, + self.padding, + is_step=True, + ) + + if bw: + state = state[:, -bw:] + else: + state = () + + return Sequence(values, mask), state + + @types.check_layer + def layer(self, x, *, constants=None): + if self.kernel_size > 1: + x = x.mask_invalid() + + pad_left, pad_right = _explicit_padding( + self.padding, + self.kernel_size, + self.strides, + self.dilation_rate, + ) + values = self._forward(x.values, pad_left, pad_right) + mask = _compute_conv_mask( + x.mask, + self.kernel_size, + self.strides, + self.dilation_rate, + self.padding, + is_step=False, + ) + return Sequence(values, mask) + + @classmethod + def from_config(cls, config): + """Create from a Linen Conv1D.Config (deferred).""" + return DeferredConv1D(config) + + +# --------------------------------------------------------------------------- +# DepthwiseConv1D +# --------------------------------------------------------------------------- + + +class DepthwiseConv1D(types.SequenceLayer): + """1D depthwise convolution layer. + + Each input channel is convolved independently. The output has + in_features * depth_multiplier channels. 
+ """ + + def __init__( + self, + *, + in_features: int, + kernel_size: int, + depth_multiplier: int = 1, + strides: int = 1, + dilation_rate: int = 1, + padding: str = 'valid', + use_bias: bool = True, + activation=None, + compute_dtype=None, + param_dtype=mx.float32, + ): + super().__init__() + self.in_features = in_features + self.kernel_size = kernel_size + self.depth_multiplier = depth_multiplier + self.strides = strides + self.dilation_rate = dilation_rate + self.padding = padding + self.use_bias = use_bias + self.activation = activation + self.compute_dtype = compute_dtype + self._param_dtype = param_dtype + + out_features = in_features * depth_multiplier + # Depthwise: groups = in_features, each group has depth_multiplier + # output channels. + self._conv = nn.Conv1d( + in_channels=in_features, + out_channels=out_features, + kernel_size=kernel_size, + stride=strides, + padding=0, + dilation=dilation_rate, + groups=in_features, + bias=use_bias, + ) + + @property + def supports_step(self): + return _supports_step(self.padding) + + @property + def block_size(self): + return self.strides + + @property + def output_ratio(self): + return fractions.Fraction(1, self.strides) + + @property + def input_latency(self): + ek = _effective_kernel_size(self.kernel_size, self.dilation_rate) + if self.padding in ( + PaddingMode.CAUSAL_VALID.value, + PaddingMode.CAUSAL.value, + PaddingMode.SEMICAUSAL.value, + ): + return 0 + elif self.padding in ( + PaddingMode.REVERSE_CAUSAL_VALID.value, + PaddingMode.REVERSE_CAUSAL.value, + PaddingMode.SEMICAUSAL_FULL.value, + ): + return ek - 1 + return 0 + + def get_output_shape(self, input_shape, *, constants=None): + if len(input_shape) != 1: + raise ValueError( + 'DepthwiseConv1D requires rank 3 input, got ' + f'channel_shape={input_shape}.' 
+ ) + return (input_shape[0] * self.depth_multiplier,) + + def get_output_dtype(self, input_dtype, *, constants=None): + return self.compute_dtype or self._param_dtype + + def _forward(self, values, pad_left, pad_right): + if pad_left > 0 or pad_right > 0: + values = mx.pad( + values, + [(0, 0), (pad_left, pad_right), (0, 0)], + ) + compute_dtype = self.compute_dtype or self._param_dtype + values = values.astype(compute_dtype) + y = self._conv(values) + if self.activation is not None: + y = self.activation(y) + return y + + def get_initial_state(self, batch_size, input_spec, *, constants=None): + bw = _buffer_width( + self.padding, + self.kernel_size, + self.strides, + self.dilation_rate, + ) + if not bw: + return () + return _compute_initial_state( + batch_size, + input_spec, + bw, + self.padding, + ) + + @types.check_step + def step(self, x, state, *, constants=None): + ek = _effective_kernel_size(self.kernel_size, self.dilation_rate) + if ek > 1: + x = x.mask_invalid() + + bw = _buffer_width( + self.padding, + self.kernel_size, + self.strides, + self.dilation_rate, + ) + + if bw: + state = state.concatenate(x) + else: + state = x + + values = self._forward(state.values, 0, 0) + mask = _compute_conv_mask( + state.mask, + self.kernel_size, + self.strides, + self.dilation_rate, + self.padding, + is_step=True, + ) + + if bw: + state = state[:, -bw:] + else: + state = () + + return Sequence(values, mask), state + + @types.check_layer + def layer(self, x, *, constants=None): + if self.kernel_size > 1: + x = x.mask_invalid() + + pad_left, pad_right = _explicit_padding( + self.padding, + self.kernel_size, + self.strides, + self.dilation_rate, + ) + values = self._forward(x.values, pad_left, pad_right) + mask = _compute_conv_mask( + x.mask, + self.kernel_size, + self.strides, + self.dilation_rate, + self.padding, + is_step=False, + ) + return Sequence(values, mask) + + @classmethod + def from_config(cls, config): + return DeferredDepthwiseConv1D(config) + + +# 
# ---------------------------------------------------------------------------
# Conv1DTranspose
# ---------------------------------------------------------------------------


def _transpose_conv_output_trim(kernel_size, stride, dilation_rate, padding):
  """Output-side trimming for transpose convolutions in MLX.

  MLX conv_transpose1d with padding=0 produces output of size:
    raw = (t - 1) * stride + ek
  This function returns (trim_left, trim_right) to cut raw output
  to the desired size.

  Raises:
    ValueError: If `padding` is not one of the supported modes.
  """
  ek = _effective_kernel_size(kernel_size, dilation_rate)
  total_trim = max(0, ek - stride)

  if padding == PaddingMode.CAUSAL.value:
    return (0, total_trim)
  elif padding == PaddingMode.SAME.value:
    trim_left = total_trim // 2
    return (trim_left, total_trim - trim_left)
  elif padding == PaddingMode.VALID.value:
    return (0, 0)
  elif padding == PaddingMode.SEMICAUSAL_FULL.value:
    return (0, 0)
  else:
    raise ValueError(f'Unsupported padding: {padding}')


def _compute_conv_transpose_output_length(
    time, kernel_size, stride, dilation_rate, padding
):
  """Returns the number of output frames for `time` input frames.

  Raises:
    ValueError: If `padding` is not one of the supported modes.
  """
  ek = _effective_kernel_size(kernel_size, dilation_rate)
  if padding in (
      PaddingMode.SAME.value,
      PaddingMode.CAUSAL.value,
      PaddingMode.SEMICAUSAL_FULL.value,
  ):
    return time * stride
  elif padding == PaddingMode.VALID.value:
    return time * stride + max(ek - stride, 0)
  else:
    raise ValueError(f'Unsupported padding: {padding}')


def _compute_conv_transpose_mask(
    mask, kernel_size, stride, dilation_rate, padding
):
  """Compute output mask for a transpose convolution."""
  ek = _effective_kernel_size(kernel_size, dilation_rate)

  # Without kernel overlap (or for same/causal padding) every input frame
  # maps to `stride` output frames with the same validity.
  if ek <= stride or padding in (
      PaddingMode.SAME.value,
      PaddingMode.CAUSAL.value,
  ):
    return mx.repeat(mask, stride, axis=1)

  # Use transpose convolution to compute the mask.
  tl, tr = _transpose_conv_output_trim(
      kernel_size,
      stride,
      dilation_rate,
      padding,
  )

  if padding == PaddingMode.SEMICAUSAL_FULL.value:
    # Valid wherever at least one valid input frame contributes.
    test_signal = mask
    test_fn = lambda m: m > 0.0
  else:
    # Valid only where no invalid input frame contributes.
    test_signal = mx.logical_not(mask)
    test_fn = lambda m: m == 0.0

  kernel = mx.ones((1, kernel_size, 1), dtype=mx.float32)
  signal = test_signal.astype(mx.float32)[:, :, None]

  result = mx.conv_transpose1d(
      signal,
      kernel,
      stride=stride,
      padding=0,
      dilation=dilation_rate,
  )
  # Trim to match desired output.
  if tl > 0:
    result = result[:, tl:]
  if tr > 0:
    result = result[:, :-tr]
  result = mx.squeeze(result, axis=-1)
  return test_fn(result)


def _fit_time_length(values, length):
  """Slices or zero-pads `values` on the right of axis 1 to `length` frames."""
  current = values.shape[1]
  if current > length:
    return values[:, :length]
  if current < length:
    widths = [(0, 0), (0, length - current)]
    widths += [(0, 0)] * (values.ndim - 2)
    return mx.pad(values, widths)
  return values


class Conv1DTranspose(types.SequenceLayer):
  """1D transpose (deconvolution) layer for upsampling.

  Supports 'valid', 'causal', and 'same' padding modes. Only 'causal'
  supports step-wise execution, implemented with an overlap-add buffer.
  """

  def __init__(
      self,
      *,
      in_features: int,
      filters: int,
      kernel_size: int,
      strides: int = 1,
      dilation_rate: int = 1,
      padding: str = 'valid',
      groups: int = 1,
      use_bias: bool = True,
      activation=None,
      compute_dtype=None,
      param_dtype=mx.float32,
  ):
    """Stores the configuration and creates the kernel (and bias).

    Args:
      in_features: Number of input channels.
      filters: Number of output channels.
      kernel_size: Kernel width along time.
      strides: Upsampling factor along time.
      dilation_rate: Dilation along time.
      padding: Padding mode string (compared against `PaddingMode` values).
      groups: Number of convolution groups.
      use_bias: Whether a bias is added to the output.
      activation: Optional callable applied after the bias.
      compute_dtype: Dtype used for the computation; falls back to
        `param_dtype` when None.
      param_dtype: Dtype of the kernel and bias.
    """
    super().__init__()
    self.in_features = in_features
    self.filters = filters
    self.kernel_size = kernel_size
    self.strides = strides
    self.dilation_rate = dilation_rate
    self.padding = padding
    self.groups = groups
    self.use_bias = use_bias
    self.activation = activation
    self.compute_dtype = compute_dtype
    self._param_dtype = param_dtype

    # Create kernel and bias manually — nn.ConvTranspose1d layout differs.
    # NOTE(review): the key is fixed, so every instance starts from the same
    # kernel; presumably real weights are loaded afterwards — confirm.
    key = mx.random.key(0)
    init = init_mapping._make_variance_scaling_init(
        'fan_in', 'truncated_normal'
    )
    # Kernel: [out_channels, kernel_size, in_channels // groups]
    self.kernel = init(
        key,
        (filters, kernel_size, in_features // groups),
        param_dtype,
    )
    if use_bias:
      self.bias = mx.zeros((filters,), dtype=param_dtype)

  @property
  def supports_step(self):
    """Step-wise execution is only available for causal padding."""
    return self.padding == PaddingMode.CAUSAL.value

  @property
  def block_size(self):
    """Number of input frames consumed per step."""
    return 1

  @property
  def output_ratio(self):
    """Output frames per input frame (the stride)."""
    return fractions.Fraction(self.strides)

  @property
  def input_latency(self):
    """This layer reports no input latency."""
    return 0

  def get_output_shape(self, input_shape, *, constants=None):
    """Returns `(filters,)`.

    Raises:
      ValueError: If `input_shape` is not a single channel dimension.
    """
    if len(input_shape) != 1:
      raise ValueError(
          'Conv1DTranspose requires rank 3 input, got '
          f'channel_shape={input_shape}.'
      )
    return (self.filters,)

  def get_output_dtype(self, input_dtype, *, constants=None):
    """Returns `compute_dtype`, or the parameter dtype when it is unset."""
    return self.compute_dtype or self._param_dtype

  def _linear_conv_transpose(self, values):
    """Applies the transpose convolution only: no bias, activation or trim."""
    compute_dtype = self.compute_dtype or self._param_dtype
    return mx.conv_transpose1d(
        values.astype(compute_dtype),
        self.kernel.astype(compute_dtype),
        stride=self.strides,
        padding=0,
        dilation=self.dilation_rate,
        groups=self.groups,
    )

  def _apply_bias_and_activation(self, y):
    """Adds the bias (if any) and applies the activation (if any)."""
    if self.use_bias:
      compute_dtype = self.compute_dtype or self._param_dtype
      y = y + self.bias.astype(compute_dtype)
    if self.activation is not None:
      y = self.activation(y)
    return y

  def _trim(self, y):
    """Removes the padding-dependent leading/trailing frames from `y`."""
    tl, tr = _transpose_conv_output_trim(
        self.kernel_size,
        self.strides,
        self.dilation_rate,
        self.padding,
    )
    if tl > 0:
      y = y[:, tl:]
    if tr > 0:
      y = y[:, :-tr]
    return y

  def _raw_conv_transpose(self, values):
    """Apply raw transpose convolution (no padding trim)."""
    return self._apply_bias_and_activation(
        self._linear_conv_transpose(values)
    )

  def _forward(self, values):
    """Apply transpose convolution with output trimming."""
    return self._trim(self._raw_conv_transpose(values))

  @property
  def _ola_buffer_width(self):
    """Frames by which consecutive kernel footprints overlap (>= 0)."""
    return max(
        0,
        _effective_kernel_size(self.kernel_size, self.dilation_rate)
        - self.strides,
    )

  def get_initial_state(self, batch_size, input_spec, *, constants=None):
    """Returns a zero overlap-add buffer, or `()` when none is needed."""
    if not self.supports_step:
      return ()
    bw = self._ola_buffer_width
    if not bw:
      return ()
    compute_dtype = self.compute_dtype or self._param_dtype
    return mx.zeros(
        (batch_size, bw, self.filters),
        dtype=compute_dtype,
    )

  @types.check_step
  def step(self, x, state, *, constants=None):
    """Processes one block with overlap-add.

    The buffer holds linear partial sums. Bias and activation are applied
    only to the frames that are emitted, after the overlap has been added,
    so the result matches `layer`.

    Returns:
      `(output_sequence, new_state)`.
    """
    bw = self._ola_buffer_width
    if bw:
      # Mask invalid frames first so overlapping kernel taps from them do
      # not contribute to valid output frames.
      x = x.mask_invalid()

    # Use the linear conv (no trimming, bias or activation) for overlap-add.
    y = self._linear_conv_transpose(x.values)
    mask = mx.repeat(x.mask, self.strides, axis=1)
    output_samples = self.strides * x.shape[1]

    if bw:
      # Overlap-add: the first bw samples overlap with buffer.
      y = mx.concatenate([y[:, :bw] + state, y[:, bw:]], axis=1)
      # The tail past the emitted frames is carried to the next step.
      state = y[:, output_samples : output_samples + bw]
      if state.shape[1] < bw:
        pad_right = bw - state.shape[1]
        state = mx.pad(state, [(0, 0), (0, pad_right), (0, 0)])

    # When the kernel is shorter than the stride the raw output is shorter
    # than the mask; zero-fill so values and mask agree in length.
    y = _fit_time_length(y, output_samples)
    values = self._apply_bias_and_activation(y)
    return Sequence(values, mask), state

  @types.check_layer
  def layer(self, x, *, constants=None):
    """Processes a whole sequence.

    Returns:
      A sequence whose values and mask both have the length given by
      `_compute_conv_transpose_output_length`.
    """
    if self._ola_buffer_width:
      # Keep invalid frames from bleeding into valid ones through overlap.
      x = x.mask_invalid()

    expected_time = _compute_conv_transpose_output_length(
        x.shape[1],
        self.kernel_size,
        self.strides,
        self.dilation_rate,
        self.padding,
    )
    # For causal padding the trim removes exactly the trailing overlap.
    y = self._trim(self._linear_conv_transpose(x.values))
    # Zero-fill when the kernel is shorter than the stride, slice otherwise.
    # NOTE(review): for 'same' padding the fill is placed on the right;
    # confirm this alignment against the reference backend.
    y = _fit_time_length(y, expected_time)
    values = self._apply_bias_and_activation(y)

    mask = _compute_conv_transpose_mask(
        x.mask,
        self.kernel_size,
        self.strides,
        self.dilation_rate,
        self.padding,
    )
    mask = mask[:, :expected_time]

    return Sequence(values, mask)

  @classmethod
  def from_config(cls, config):
    """Builds a deferred layer; the config does not carry `in_features`."""
    return DeferredConv1DTranspose(config)


# ---------------------------------------------------------------------------
# Deferred wrappers (Linen configs lack in_features)
# ---------------------------------------------------------------------------


def _deferred_dtypes_and_activation(config):
  """Maps a config's dtypes and activation to their MLX equivalents.

  Returns:
    `(compute_dtype, param_dtype, activation)`; `compute_dtype` is None when
    the config has none.
  """
  compute_dtype = getattr(config, 'compute_dtype', None)
  if compute_dtype is not None:
    compute_dtype = init_mapping._to_mx_dtype(compute_dtype)
  param_dtype = init_mapping._to_mx_dtype(config.param_dtype)
  activation = init_mapping.map_activation(getattr(config, 'activation', None))
  return compute_dtype, param_dtype, activation


def _deferred_conv_input_latency(config):
  """Effective kernel size minus one for look-ahead paddings, else 0."""
  if config.padding in (
      PaddingMode.REVERSE_CAUSAL_VALID.value,
      PaddingMode.REVERSE_CAUSAL.value,
      PaddingMode.SEMICAUSAL_FULL.value,
  ):
    return _effective_kernel_size(config.kernel_size, config.dilation_rate) - 1
  return 0


class _DeferredConvBase(types.SequenceLayer):
  """Shared plumbing for layers that build their weights on first input."""

  def __init__(self, config):
    super().__init__()
    self._config = config
    self._inner = None

  def _build(self, in_features):
    """Creates the concrete layer for `in_features` input channels."""
    raise NotImplementedError

  def _ensure_initialized(self, in_features):
    """Builds the inner layer once; later calls are no-ops."""
    if self._inner is not None:
      return
    self._inner = self._build(in_features)

  def get_output_dtype(self, input_dtype, *, constants=None):
    """Returns the config's compute dtype, else its parameter dtype."""
    cd = getattr(self._config, 'compute_dtype', None)
    if cd is not None:
      return init_mapping._to_mx_dtype(cd)
    return init_mapping._to_mx_dtype(self._config.param_dtype)

  def get_initial_state(self, batch_size, input_spec, *, constants=None):
    self._ensure_initialized(input_spec.shape[-1])
    return self._inner.get_initial_state(
        batch_size, input_spec, constants=constants
    )

  def layer(self, x, *, constants=None):
    self._ensure_initialized(x.shape[-1])
    return self._inner.layer(x, constants=constants)

  def step(self, x, state, *, constants=None):
    self._ensure_initialized(x.shape[-1])
    return self._inner.step(x, state, constants=constants)


class DeferredConv1D(_DeferredConvBase):
  """Conv1D that defers weight creation until first input."""

  def _build(self, in_features):
    c = self._config
    compute_dtype, param_dtype, activation = _deferred_dtypes_and_activation(c)
    return Conv1D(
        in_features=in_features,
        filters=c.filters,
        kernel_size=c.kernel_size,
        strides=c.strides,
        dilation_rate=c.dilation_rate,
        padding=c.padding,
        groups=c.groups,
        use_bias=c.use_bias,
        activation=activation,
        compute_dtype=compute_dtype,
        param_dtype=param_dtype,
    )

  @property
  def supports_step(self):
    return _supports_step(self._config.padding)

  @property
  def block_size(self):
    return self._config.strides

  @property
  def output_ratio(self):
    return fractions.Fraction(1, self._config.strides)

  @property
  def input_latency(self):
    return _deferred_conv_input_latency(self._config)

  def get_output_shape(self, input_shape, *, constants=None):
    return (self._config.filters,)


class DeferredDepthwiseConv1D(_DeferredConvBase):
  """DepthwiseConv1D that defers weight creation until first input."""

  def _build(self, in_features):
    c = self._config
    compute_dtype, param_dtype, activation = _deferred_dtypes_and_activation(c)
    return DepthwiseConv1D(
        in_features=in_features,
        kernel_size=c.kernel_size,
        depth_multiplier=c.depth_multiplier,
        strides=c.strides,
        dilation_rate=c.dilation_rate,
        padding=c.padding,
        use_bias=c.use_bias,
        activation=activation,
        compute_dtype=compute_dtype,
        param_dtype=param_dtype,
    )

  @property
  def supports_step(self):
    return _supports_step(self._config.padding)

  @property
  def block_size(self):
    return self._config.strides

  @property
  def output_ratio(self):
    return fractions.Fraction(1, self._config.strides)

  @property
  def input_latency(self):
    return _deferred_conv_input_latency(self._config)

  def get_output_shape(self, input_shape, *, constants=None):
    return (input_shape[0] * self._config.depth_multiplier,)


class DeferredConv1DTranspose(_DeferredConvBase):
  """Conv1DTranspose that defers weight creation until first input."""

  def _build(self, in_features):
    c = self._config
    compute_dtype, param_dtype, activation = _deferred_dtypes_and_activation(c)
    return Conv1DTranspose(
        in_features=in_features,
        filters=c.filters,
        kernel_size=c.kernel_size,
        strides=c.strides,
        # Not every transpose config carries these; fall back to 1.
        dilation_rate=getattr(c, 'dilation_rate', 1),
        padding=c.padding,
        groups=getattr(c, 'groups', 1),
        use_bias=c.use_bias,
        activation=activation,
        compute_dtype=compute_dtype,
        param_dtype=param_dtype,
    )

  @property
  def supports_step(self):
    return self._config.padding == PaddingMode.CAUSAL.value

  @property
  def block_size(self):
    return 1

  @property
  def output_ratio(self):
    return fractions.Fraction(self._config.strides)

  @property
  def input_latency(self):
    return 0

  def get_output_shape(self, input_shape, *, constants=None):
    return (self._config.filters,)
"""Tests for convolution MLX sequence layers."""

import mlx.core as mx
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
from sequence_layers.mlx import convolution
from sequence_layers.mlx import test_utils

# Tolerances shared by every layer-vs-step contract check in this module.
_TOLERANCES = {'atol': 1e-4, 'rtol': 1e-4}

# Channel shape of the inputs used throughout: four features per frame.
_CHANNEL_SHAPE = (4,)


def _conv1d(**overrides):
  """Returns a 4-in / 8-out Conv1D with kernel size 3, plus `overrides`."""
  kwargs = {'in_features': 4, 'filters': 8, 'kernel_size': 3}
  kwargs.update(overrides)
  return convolution.Conv1D(**kwargs)


def _depthwise(**overrides):
  """Returns a 4-channel DepthwiseConv1D with kernel size 3."""
  kwargs = {'in_features': 4, 'kernel_size': 3}
  kwargs.update(overrides)
  return convolution.DepthwiseConv1D(**kwargs)


def _conv1d_transpose(**overrides):
  """Returns a 4-in / 8-out Conv1DTranspose with kernel size 3."""
  kwargs = {'in_features': 4, 'filters': 8, 'kernel_size': 3}
  kwargs.update(overrides)
  return convolution.Conv1DTranspose(**kwargs)


class Conv1DTest(parameterized.TestCase):

  @parameterized.parameters(
      ('causal',),
      ('causal_valid',),
  )
  def test_causal_paddings(self, padding):
    conv = _conv1d(padding=padding)
    test_utils.verify_contract(self, conv, _CHANNEL_SHAPE, **_TOLERANCES)

  def test_valid(self):
    conv = _conv1d(padding='valid')
    out = conv.layer(test_utils.random_sequence(1, 8, 4))
    self.assertEqual(out.channel_shape, (8,))
    # Valid: output time = input_time - kernel_size + 1 = 6
    self.assertEqual(out.shape[1], 6)

  def test_same(self):
    conv = _conv1d(padding='same')
    out = conv.layer(test_utils.random_sequence(1, 8, 4))
    self.assertEqual(out.shape[1], 8)

  def test_stride(self):
    conv = _conv1d(strides=2, padding='causal')
    test_utils.verify_contract(
        self, conv, _CHANNEL_SHAPE, time=8, **_TOLERANCES
    )

  def test_dilation(self):
    conv = _conv1d(dilation_rate=2, padding='causal')
    test_utils.verify_contract(self, conv, _CHANNEL_SHAPE, **_TOLERANCES)

  def test_output_shape(self):
    conv = _conv1d(filters=16, padding='causal')
    self.assertEqual(conv.get_output_shape((4,)), (16,))

  def test_from_config(self):
    import sequence_layers.mlx
    from sequence_layers.jax import convolution as jax_conv

    cfg = jax_conv.Conv1D.Config(
        filters=8,
        kernel_size=3,
        padding='causal',
    )
    built = cfg.make(backend='mlx')
    self.assertIsInstance(built, convolution.DeferredConv1D)
    out = built.layer(test_utils.random_sequence(1, 8, 4))
    self.assertEqual(out.channel_shape, (8,))


class DepthwiseConv1DTest(parameterized.TestCase):

  @parameterized.parameters(
      ('causal',),
      ('causal_valid',),
  )
  def test_causal_paddings(self, padding):
    conv = _depthwise(padding=padding)
    test_utils.verify_contract(self, conv, _CHANNEL_SHAPE, **_TOLERANCES)

  def test_depth_multiplier(self):
    conv = _depthwise(depth_multiplier=2, padding='causal')
    self.assertEqual(conv.get_output_shape((4,)), (8,))

  def test_valid(self):
    conv = _depthwise(padding='valid')
    out = conv.layer(test_utils.random_sequence(1, 8, 4))
    self.assertEqual(out.shape[1], 6)

  def test_from_config(self):
    import sequence_layers.mlx
    from sequence_layers.jax import convolution as jax_conv

    cfg = jax_conv.DepthwiseConv1D.Config(
        kernel_size=3,
        padding='causal',
    )
    built = cfg.make(backend='mlx')
    self.assertIsInstance(built, convolution.DeferredDepthwiseConv1D)
    out = built.layer(test_utils.random_sequence(1, 8, 4))
    self.assertEqual(out.channel_shape, (4,))


class Conv1DTransposeTest(parameterized.TestCase):

  def test_causal(self):
    deconv = _conv1d_transpose(strides=2, padding='causal')
    test_utils.verify_contract(self, deconv, _CHANNEL_SHAPE, **_TOLERANCES)

  def test_valid(self):
    deconv = _conv1d_transpose(strides=2, padding='valid')
    out = deconv.layer(test_utils.random_sequence(1, 4, 4))
    self.assertEqual(out.channel_shape, (8,))
    # Valid: output = input * stride + max(ek - stride, 0)
    expected_time = 4 * 2 + max(3 - 2, 0)
    self.assertEqual(out.shape[1], expected_time)

  def test_same(self):
    deconv = _conv1d_transpose(strides=2, padding='same')
    out = deconv.layer(test_utils.random_sequence(1, 4, 4))
    self.assertEqual(out.shape[1], 8)

  def test_output_ratio(self):
    deconv = _conv1d_transpose(strides=3, padding='causal')
    import fractions

    self.assertEqual(deconv.output_ratio, fractions.Fraction(3))

  def test_from_config(self):
    import sequence_layers.mlx
    from sequence_layers.jax import convolution as jax_conv

    cfg = jax_conv.Conv1DTranspose.Config(
        filters=8,
        kernel_size=3,
        strides=2,
        padding='causal',
    )
    built = cfg.make(backend='mlx')
    self.assertIsInstance(built, convolution.DeferredConv1DTranspose)
    out = built.layer(test_utils.random_sequence(1, 4, 4))
    self.assertEqual(out.channel_shape, (8,))


if __name__ == '__main__':
  absltest.main()
+""" + +import jax +import jax.numpy as jnp +import mlx.core as mx +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + +import sequence_layers.jax as sl +from sequence_layers.jax import types as jax_types +from sequence_layers.jax.attention import common as attn_common +from sequence_layers.mlx import basic_types as bt +from sequence_layers.mlx import export +from sequence_layers.mlx import weight_converter + +Sequence = bt.Sequence +ShapeDType = bt.ShapeDType + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _compare_stateless_float( + test_case, + config, + input_shape, + *, + batch_size=2, + time=8, + atol=1e-5, + rtol=1e-5, + seed=42, +): + """Compare a stateless layer that requires no parameters (float inputs).""" + rng = np.random.RandomState(seed) + values = rng.randn(batch_size, time, *input_shape).astype(np.float32) + mask = np.ones((batch_size, time), dtype=bool) + + # JAX. + jax_model = config.make() + x_jax = jax_types.Sequence( + jnp.array(values), jnp.array(mask, dtype=jnp.bool_) + ) + variables = jax_model.init(jax.random.PRNGKey(0), x_jax, training=False) + jax_out = np.array(jax_model.apply(variables, x_jax, training=False).values) + + # MLX. 
+ mlx_model = config.make(backend='mlx') + x_mx = Sequence(mx.array(values), mx.array(mask, dtype=mx.bool_)) + mlx_out = np.array(mlx_model.layer(x_mx).values) + + np.testing.assert_allclose( + mlx_out, + jax_out, + atol=atol, + rtol=rtol, + err_msg=f'{config.__class__.__qualname__}: outputs differ', + ) + + +def _compare_parametric_float( + test_case, + config, + input_shape, + *, + batch_size=2, + time=8, + atol=1e-5, + rtol=1e-5, + seed=42, +): + """Compare a parametric layer with float inputs (Conv, Dense, Norm, etc.).""" + rng = np.random.RandomState(seed) + values = rng.randn(batch_size, time, *input_shape).astype(np.float32) + mask = np.ones((batch_size, time), dtype=bool) + + # JAX: init + run. + jax_model = config.make() + x_jax = jax_types.Sequence( + jnp.array(values), jnp.array(mask, dtype=jnp.bool_) + ) + variables = jax_model.init(jax.random.PRNGKey(0), x_jax, training=False) + jax_params = variables['params'] + jax_out = np.array( + jax_model.apply({'params': jax_params}, x_jax, training=False).values + ) + + # MLX: create, load weights, run. + mlx_model = config.make(backend='mlx') + weight_converter.load_linen_params( + mlx_model, + jax_params, + config, + input_spec=ShapeDType(input_shape, mx.float32), + ) + x_mx = Sequence(mx.array(values), mx.array(mask, dtype=mx.bool_)) + mlx_out = np.array(mlx_model.layer(x_mx).values) + + np.testing.assert_allclose( + mlx_out, + jax_out, + atol=atol, + rtol=rtol, + err_msg=f'{config.__class__.__qualname__}: outputs differ', + ) + + +def _compare_parametric_int( + test_case, + config, + *, + batch_size=2, + time=8, + atol=1e-5, + rtol=1e-5, + seed=42, +): + """Compare a parametric layer with integer token inputs (Embedding).""" + rng = np.random.RandomState(seed) + # Infer vocab size from config. + vocab = getattr(config, 'num_embeddings', 32) + tokens = rng.randint(0, vocab, size=(batch_size, time)).astype(np.int32) + mask = np.ones((batch_size, time), dtype=bool) + + # JAX. 
+ jax_model = config.make() + x_jax = jax_types.Sequence( + jnp.array(tokens), jnp.array(mask, dtype=jnp.bool_) + ) + variables = jax_model.init(jax.random.PRNGKey(0), x_jax, training=False) + jax_params = variables['params'] + jax_out = np.array( + jax_model.apply({'params': jax_params}, x_jax, training=False).values + ) + + # MLX. + mlx_model = config.make(backend='mlx') + weight_converter.load_linen_params(mlx_model, jax_params, config) + x_mx = Sequence( + mx.array(tokens, dtype=mx.int32), + mx.array(mask, dtype=mx.bool_), + ) + mlx_out = np.array(mlx_model.layer(x_mx).values) + + np.testing.assert_allclose( + mlx_out, + jax_out, + atol=atol, + rtol=rtol, + err_msg=f'{config.__class__.__qualname__}: outputs differ', + ) + + +def _compare_with_constants( + test_case, + config, + input_shape, + constants_fn, + *, + batch_size=2, + time=8, + atol=1e-4, + rtol=1e-4, + seed=42, +): + """Compare a parametric layer that needs constants (cross-attention).""" + rng = np.random.RandomState(seed) + values = rng.randn(batch_size, time, *input_shape).astype(np.float32) + mask = np.ones((batch_size, time), dtype=bool) + + jax_constants, mlx_constants = constants_fn(batch_size, time, rng) + + # JAX. + jax_model = config.make() + x_jax = jax_types.Sequence( + jnp.array(values), jnp.array(mask, dtype=jnp.bool_) + ) + variables = jax_model.init( + jax.random.PRNGKey(0), + x_jax, + training=False, + constants=jax_constants, + ) + jax_params = variables['params'] + jax_out = np.array( + jax_model.apply( + {'params': jax_params}, + x_jax, + training=False, + constants=jax_constants, + ).values + ) + + # MLX. 
+ mlx_model = config.make(backend='mlx') + weight_converter.load_linen_params( + mlx_model, + jax_params, + config, + input_spec=ShapeDType(input_shape, mx.float32), + constants=mlx_constants, + ) + x_mx = Sequence(mx.array(values), mx.array(mask, dtype=mx.bool_)) + mlx_out = np.array(mlx_model.layer(x_mx, constants=mlx_constants).values) + + np.testing.assert_allclose( + mlx_out, + jax_out, + atol=atol, + rtol=rtol, + err_msg=f'{config.__class__.__qualname__}: outputs differ', + ) + + +# --------------------------------------------------------------------------- +# Test Classes +# --------------------------------------------------------------------------- + + +class StatelessActivationsTest(parameterized.TestCase): + """Stateless activations: JAX vs MLX.""" + + @parameterized.named_parameters( + ('relu', sl.Relu.Config()), + ('gelu', sl.Gelu.Config(approximate=False)), + ('swish', sl.Swish.Config()), + ('tanh', sl.Tanh.Config()), + ('sigmoid', sl.Sigmoid.Config()), + ('leaky_relu', sl.LeakyRelu.Config()), + ('elu', sl.Elu.Config()), + ('softmax', sl.Softmax.Config()), + ('softplus', sl.Softplus.Config()), + ) + def test_activation(self, config): + _compare_stateless_float(self, config, (16,)) + + +class StatelessShapeOpsTest(parameterized.TestCase): + """Stateless shape operations: JAX vs MLX.""" + + @parameterized.named_parameters( + ('flatten_2d', sl.Flatten.Config(), (4, 3)), + ('reshape', sl.Reshape.Config(output_shape=(2, 4)), (8,)), + ('expand_dims', sl.ExpandDims.Config(axis=-1), (8,)), + ('squeeze', sl.Squeeze.Config(), (8, 1)), + ('transpose', sl.Transpose.Config(), (4, 3)), + ) + def test_shape_op(self, config, input_shape): + _compare_stateless_float(self, config, input_shape) + + +class StatelessMiscTest(parameterized.TestCase): + """Stateless misc layers: JAX vs MLX.""" + + @parameterized.named_parameters( + ('scale', sl.Scale.Config(scale=0.5), (8,)), + ('add', sl.Add.Config(shift=1.0), (8,)), + ('gated_linear_unit', sl.GatedLinearUnit.Config(), 
(16,)), + ('gated_tanh_unit', sl.GatedTanhUnit.Config(), (16,)), + ) + def test_misc(self, config, input_shape): + _compare_stateless_float(self, config, input_shape) + + def test_one_hot(self): + config = sl.OneHot.Config(depth=8) + rng = np.random.RandomState(42) + tokens = rng.randint(0, 8, size=(2, 8)).astype(np.int32) + mask = np.ones((2, 8), dtype=bool) + + # JAX. + jax_model = config.make() + x_jax = jax_types.Sequence( + jnp.array(tokens), jnp.array(mask, dtype=jnp.bool_) + ) + variables = jax_model.init(jax.random.PRNGKey(0), x_jax, training=False) + jax_out = np.array(jax_model.apply(variables, x_jax, training=False).values) + + # MLX. + mlx_model = config.make(backend='mlx') + x_mx = Sequence( + mx.array(tokens, dtype=mx.int32), + mx.array(mask, dtype=mx.bool_), + ) + mlx_out = np.array(mlx_model.layer(x_mx).values) + + np.testing.assert_allclose( + mlx_out, + jax_out, + atol=1e-5, + rtol=1e-5, + err_msg='OneHot outputs differ', + ) + + +class SamplingTest(parameterized.TestCase): + """Downsample1D / Upsample1D: JAX vs MLX.""" + + @parameterized.named_parameters( + ('downsample_2', sl.Downsample1D.Config(rate=2), (8,)), + ('downsample_3', sl.Downsample1D.Config(rate=3), (8,)), + ('upsample_2', sl.Upsample1D.Config(rate=2), (8,)), + ('upsample_3', sl.Upsample1D.Config(rate=3), (8,)), + ('downsample_4', sl.Downsample1D.Config(rate=4), (16,)), + ) + def test_sampling(self, config, input_shape): + _compare_stateless_float(self, config, input_shape, time=12) + + +class PoolingCrossBackendTest(parameterized.TestCase): + """Pooling layers: JAX vs MLX.""" + + @parameterized.named_parameters( + ( + 'max_pool_2_valid', + sl.MaxPooling1D.Config(pool_size=2, padding='valid'), + (8,), + ), + ( + 'max_pool_3_causal', + sl.MaxPooling1D.Config(pool_size=3, padding='causal'), + (8,), + ), + ( + 'min_pool_2_valid', + sl.MinPooling1D.Config(pool_size=2, padding='valid'), + (8,), + ), + ( + 'min_pool_3_causal', + sl.MinPooling1D.Config(pool_size=3, padding='causal'), + 
(8,), + ), + ( + 'avg_pool_2_valid', + sl.AveragePooling1D.Config(pool_size=2, padding='valid'), + (8,), + ), + ( + 'avg_pool_3_causal', + sl.AveragePooling1D.Config(pool_size=3, padding='causal'), + (8,), + ), + ( + 'max_pool_stride2', + sl.MaxPooling1D.Config(pool_size=2, strides=2, padding='valid'), + (8,), + ), + ( + 'avg_pool_masked', + sl.AveragePooling1D.Config( + pool_size=2, padding='valid', masked_average=True + ), + (8,), + ), + ) + def test_pooling(self, config, input_shape): + _compare_stateless_float(self, config, input_shape) + + +class EmbeddingCrossBackendTest(parameterized.TestCase): + """Embedding: JAX vs MLX.""" + + def test_embedding(self): + config = sl.Embedding.Config(num_embeddings=32, dimension=16) + _compare_parametric_int(self, config) + + +class DenseCrossBackendTest(parameterized.TestCase): + """Dense: JAX vs MLX.""" + + def test_dense_plain(self): + config = sl.Dense.Config(features=16) + _compare_parametric_float(self, config, (8,)) + + def test_dense_with_bias(self): + config = sl.Dense.Config(features=16, use_bias=True) + _compare_parametric_float(self, config, (8,)) + + def test_dense_with_activation(self): + config = sl.Dense.Config(features=16, activation=jax.nn.relu) + _compare_parametric_float(self, config, (8,)) + + +class ConvolutionCrossBackendTest(parameterized.TestCase): + """Convolution: JAX vs MLX.""" + + def test_conv1d_causal(self): + config = sl.Conv1D.Config(filters=8, kernel_size=3, padding='causal') + _compare_parametric_float(self, config, (4,)) + + def test_conv1d_causal_valid(self): + config = sl.Conv1D.Config(filters=8, kernel_size=3, padding='causal_valid') + _compare_parametric_float(self, config, (4,)) + + def test_depthwise_conv1d(self): + config = sl.DepthwiseConv1D.Config(kernel_size=3, padding='causal') + _compare_parametric_float(self, config, (4,)) + + def test_conv1d_transpose(self): + config = sl.Conv1DTranspose.Config( + filters=8, kernel_size=3, strides=2, padding='causal' + ) + 
_compare_parametric_float(self, config, (4,)) + + def test_conv1d_with_bias(self): + config = sl.Conv1D.Config( + filters=8, kernel_size=3, padding='causal', use_bias=True + ) + _compare_parametric_float(self, config, (4,)) + + +class NormalizationCrossBackendTest(parameterized.TestCase): + """Normalization: JAX vs MLX.""" + + def test_rms_norm(self): + config = sl.RMSNormalization.Config() + _compare_parametric_float(self, config, (16,)) + + def test_layer_norm(self): + config = sl.LayerNormalization.Config() + _compare_parametric_float(self, config, (16,)) + + def test_l2_normalize(self): + config = sl.L2Normalize.Config() + _compare_stateless_float(self, config, (16,)) + + def test_l2_normalize_multi_axis(self): + config = sl.L2Normalize.Config(axis=(-2, -1)) + _compare_stateless_float(self, config, (4, 3)) + + def test_batch_norm(self): + config = sl.BatchNormalization.Config() + rng = np.random.RandomState(42) + batch_size, time = 2, 8 + input_shape = (16,) + values = rng.randn(batch_size, time, *input_shape).astype(np.float32) + mask = np.ones((batch_size, time), dtype=bool) + + # JAX: init returns both 'params' and 'batch_stats'. + jax_model = config.make() + x_jax = jax_types.Sequence( + jnp.array(values), jnp.array(mask, dtype=jnp.bool_) + ) + variables = jax_model.init(jax.random.PRNGKey(0), x_jax, training=False) + jax_params = variables['params'] + jax_batch_stats = variables['batch_stats'] + jax_out = np.array(jax_model.apply(variables, x_jax, training=False).values) + + # MLX: load params + batch_stats. 
+ mlx_model = config.make(backend='mlx') + weight_converter.load_linen_params( + mlx_model, + jax_params, + config, + input_spec=ShapeDType(input_shape, mx.float32), + batch_stats=jax_batch_stats, + ) + x_mx = Sequence(mx.array(values), mx.array(mask, dtype=mx.bool_)) + mlx_out = np.array(mlx_model.layer(x_mx).values) + + np.testing.assert_allclose(mlx_out, jax_out, atol=1e-5, rtol=1e-5) + + def test_batch_norm_no_affine(self): + config = sl.BatchNormalization.Config(use_scale=False, use_bias=False) + rng = np.random.RandomState(42) + batch_size, time = 2, 8 + input_shape = (16,) + values = rng.randn(batch_size, time, *input_shape).astype(np.float32) + mask = np.ones((batch_size, time), dtype=bool) + + jax_model = config.make() + x_jax = jax_types.Sequence( + jnp.array(values), jnp.array(mask, dtype=jnp.bool_) + ) + variables = jax_model.init(jax.random.PRNGKey(0), x_jax, training=False) + jax_batch_stats = variables.get('batch_stats', {}) + # No params when scale/bias disabled — only batch_stats. + jax_params = variables.get('params', {}) + jax_out = np.array(jax_model.apply(variables, x_jax, training=False).values) + + mlx_model = config.make(backend='mlx') + weight_converter.load_linen_params( + mlx_model, + jax_params, + config, + input_spec=ShapeDType(input_shape, mx.float32), + batch_stats=jax_batch_stats, + ) + x_mx = Sequence(mx.array(values), mx.array(mask, dtype=mx.bool_)) + mlx_out = np.array(mlx_model.layer(x_mx).values) + + np.testing.assert_allclose(mlx_out, jax_out, atol=1e-5, rtol=1e-5) + + # GroupNorm: JAX layer() reduces over time (non-cumulative), MLX normalizes + # per-timestep by design. Cross-backend comparison requires cumulative mode + # which differs semantically. Skipped. 
+ + +class SelfAttentionCrossBackendTest(parameterized.TestCase): + """Self-attention: JAX vs MLX.""" + + def test_basic(self): + config = sl.DotProductSelfAttention.Config( + num_heads=2, + units_per_head=8, + max_past_horizon=16, + max_future_horizon=0, + ) + _compare_parametric_float(self, config, (16,), atol=1e-4, rtol=1e-4) + + def test_with_rope(self): + config = sl.DotProductSelfAttention.Config( + num_heads=2, + units_per_head=8, + max_past_horizon=16, + max_future_horizon=0, + query_network=sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ), + key_network=sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ), + ) + _compare_parametric_float(self, config, (16,), atol=1e-4, rtol=1e-4) + + +class LocalSelfAttentionCrossBackendTest(parameterized.TestCase): + """Local self-attention: JAX vs MLX.""" + + def test_basic(self): + from sequence_layers.jax.attention import ( + local_dot_product_self_attention as jax_local_attn, + ) + + config = jax_local_attn.LocalDotProductSelfAttention.Config( + num_heads=2, + units_per_head=4, + block_size=1, + max_past_horizon=8, + max_future_horizon=0, + ) + _compare_parametric_float(self, config, (8,), atol=1e-4, rtol=1e-4) + + def test_with_soft_cap(self): + from sequence_layers.jax.attention import ( + local_dot_product_self_attention as jax_local_attn, + ) + + config = jax_local_attn.LocalDotProductSelfAttention.Config( + num_heads=2, + units_per_head=4, + block_size=1, + max_past_horizon=8, + max_future_horizon=0, + attention_logits_soft_cap=50.0, + ) + _compare_parametric_float(self, config, (8,), atol=1e-4, rtol=1e-4) + + +class StepModeLocalSelfAttentionTest(parameterized.TestCase): + """Step-mode cross-backend: local self-attention.""" + + def test_causal(self): + from sequence_layers.jax.attention import ( + local_dot_product_self_attention as jax_local_attn, + ) + + config = jax_local_attn.LocalDotProductSelfAttention.Config( + num_heads=2, + units_per_head=4, + block_size=1, + 
max_past_horizon=8, + max_future_horizon=0, + ) + _compare_step_mode(self, config, (8,), atol=1e-4, rtol=1e-4) + + +class DSPCrossBackendTest(parameterized.TestCase): + """DSP layers: JAX vs MLX.""" + + def test_delay(self): + config = sl.Delay.Config(length=2) + _compare_stateless_float(self, config, (8,)) + + def test_lookahead(self): + config = sl.Lookahead.Config(length=3) + _compare_stateless_float(self, config, (8,)) + + def test_window(self): + config = sl.Window.Config(axis=-1) + _compare_stateless_float(self, config, (8,)) + + def test_frame(self): + config = sl.Frame.Config(frame_length=4, frame_step=2) + _compare_stateless_float(self, config, (1,), time=8) + + def test_frame_causal(self): + config = sl.Frame.Config(frame_length=4, frame_step=2, padding='causal') + _compare_stateless_float(self, config, (1,), time=8) + + def test_overlap_add_causal(self): + config = sl.OverlapAdd.Config( + frame_length=4, frame_step=2, padding='causal' + ) + _compare_stateless_float(self, config, (4,), time=8) + + def test_fft(self): + config = sl.FFT.Config() + _compare_stateless_float(self, config, (8,), atol=1e-4, rtol=1e-4) + + def test_ifft(self): + config = sl.IFFT.Config() + _compare_stateless_float(self, config, (8,), atol=1e-4, rtol=1e-4) + + def test_rfft(self): + config = sl.RFFT.Config() + _compare_stateless_float(self, config, (8,), atol=1e-4, rtol=1e-4) + + def test_rfft_irfft_roundtrip(self): + # IRFFT needs complex input; test via RFFT→IRFFT roundtrip. 
+ config = sl.Serial.Config([ + sl.RFFT.Config(), + sl.IRFFT.Config(), + ]) + _compare_stateless_float(self, config, (8,), atol=1e-4, rtol=1e-4) + + def test_stft(self): + config = sl.STFT.Config( + frame_length=8, + frame_step=4, + fft_length=8, + output_magnitude=True, + ) + _compare_stateless_float(self, config, (1,), time=16, atol=1e-4, rtol=1e-4) + + def test_stft_complex(self): + config = sl.STFT.Config( + frame_length=8, + frame_step=4, + fft_length=8, + output_magnitude=False, + ) + _compare_stateless_float(self, config, (1,), time=16, atol=1e-4, rtol=1e-4) + + def test_stft_inverse_stft_roundtrip(self): + # InverseSTFT needs complex input; test via STFT→InverseSTFT roundtrip. + config = sl.Serial.Config([ + sl.STFT.Config( + frame_length=8, + frame_step=4, + fft_length=8, + output_magnitude=False, + ), + sl.InverseSTFT.Config( + frame_length=8, + frame_step=4, + fft_length=8, + time_padding='causal', + ), + ]) + _compare_stateless_float(self, config, (1,), time=16, atol=1e-4, rtol=1e-4) + + def test_mel_spectrogram(self): + config = sl.LinearToMelSpectrogram.Config( + num_mel_bins=10, + sample_rate=16000.0, + lower_edge_hertz=80.0, + upper_edge_hertz=7600.0, + ) + # Mel filterbank computation may differ slightly between backends + # due to different float64 vs float32 precision paths. 
+ _compare_stateless_float(self, config, (5,), atol=0.05, rtol=0.1) + + +class CombinatorsCrossBackendTest(parameterized.TestCase): + """Combinators: JAX vs MLX.""" + + def test_serial(self): + config = sl.Serial.Config([ + sl.Dense.Config(features=16), + sl.Relu.Config(), + sl.Dense.Config(features=8), + ]) + _compare_parametric_float(self, config, (8,)) + + def test_residual(self): + config = sl.Residual.Config([ + sl.Dense.Config(features=8), + sl.Relu.Config(), + ]) + _compare_parametric_float(self, config, (8,)) + + def test_repeat(self): + config = sl.Repeat.Config( + num_repeats=2, + layer=sl.Serial.Config([ + sl.Dense.Config(features=8), + sl.Relu.Config(), + ]), + ) + _compare_parametric_float(self, config, (8,), atol=1e-4, rtol=1e-4) + + +class CrossAttentionCrossBackendTest(parameterized.TestCase): + """Cross-attention (DotProductAttention): JAX vs MLX.""" + + def _make_constants(self, batch_size, time, source_features, rng): + source_values = rng.randn(batch_size, time, source_features).astype( + np.float32 + ) + source_mask = np.ones((batch_size, time), dtype=bool) + jax_source = jax_types.Sequence( + jnp.array(source_values), jnp.array(source_mask, dtype=jnp.bool_) + ) + mlx_source = Sequence( + mx.array(source_values), mx.array(source_mask, dtype=mx.bool_) + ) + return {'enc': jax_source}, {'enc': mlx_source} + + def test_basic(self): + from sequence_layers.jax.attention import ( + dot_product_attention as jax_cross_attn, + ) + + config = jax_cross_attn.DotProductAttention.Config( + source_name='enc', + num_heads=2, + units_per_head=4, + ) + _compare_with_constants( + self, + config, + (8,), + lambda b, t, rng: self._make_constants(b, t, 12, rng), + atol=1e-4, + rtol=1e-4, + ) + + def test_same_features(self): + """Source and input have the same feature dimension.""" + from sequence_layers.jax.attention import ( + dot_product_attention as jax_cross_attn, + ) + + config = jax_cross_attn.DotProductAttention.Config( + source_name='enc', + num_heads=4, + 
units_per_head=4, + ) + _compare_with_constants( + self, + config, + (16,), + lambda b, t, rng: self._make_constants(b, t, 16, rng), + atol=1e-4, + rtol=1e-4, + ) + + +class StreamingAttentionCrossBackendTest(parameterized.TestCase): + """Streaming cross-attention: JAX vs MLX weight conversion.""" + + def _make_constants(self, batch_size, time, source_features, rng): + source_values = rng.randn(batch_size, time, source_features).astype( + np.float32 + ) + source_mask = np.ones((batch_size, time), dtype=bool) + jax_source = jax_types.Sequence( + jnp.array(source_values), jnp.array(source_mask, dtype=jnp.bool_) + ) + mlx_source = Sequence( + mx.array(source_values), mx.array(source_mask, dtype=mx.bool_) + ) + return {'src': jax_source}, {'src': mlx_source} + + def test_basic(self): + from sequence_layers.jax.attention import ( + streaming_dot_product_attention as jax_streaming_attn, + ) + + config = jax_streaming_attn.StreamingDotProductAttention.Config( + source_name='src', + num_heads=2, + units_per_head=4, + max_past_horizon=8, + ) + _compare_with_constants( + self, + config, + (8,), + lambda b, t, rng: self._make_constants(b, t, 12, rng), + atol=1e-4, + rtol=1e-4, + ) + + def test_with_future_horizon(self): + from sequence_layers.jax.attention import ( + streaming_dot_product_attention as jax_streaming_attn, + ) + + config = jax_streaming_attn.StreamingDotProductAttention.Config( + source_name='src', + num_heads=2, + units_per_head=4, + max_past_horizon=4, + max_future_horizon=2, + ) + _compare_with_constants( + self, + config, + (8,), + lambda b, t, rng: self._make_constants(b, t, 8, rng), + atol=1e-4, + rtol=1e-4, + ) + + +# --------------------------------------------------------------------------- +# Step-mode cross-backend tests +# --------------------------------------------------------------------------- + + +def _compare_step_mode( + test_case, + config, + input_shape, + *, + batch_size=1, + num_steps=6, + block_size=1, + atol=1e-5, + rtol=1e-5, + 
seed=42, + constants_fn=None, + stream_constants_fn=None, +): + """Compare JAX and MLX step-by-step outputs with shared weights. + + Args: + test_case: A TestCase instance. + config: A SequenceLayerConfig. + input_shape: Channel shape, e.g. (8,). + batch_size: Batch dimension. + num_steps: Number of step invocations. + block_size: Number of timesteps per step. Must match the layer's + block_size for layers that require it (e.g. Frame, OverlapAdd). + atol: Absolute tolerance. + rtol: Relative tolerance. + seed: Random seed. + constants_fn: For static cross-attention. Returns (jax_constants, + mlx_constants) given (batch_size, rng). + stream_constants_fn: For streaming cross-attention. Returns + (jax_constants, mlx_constants) given (batch_size, time, rng). + Each has shape [batch, time, features]. Will be sliced per step. + """ + rng = np.random.RandomState(seed) + step_values = [ + rng.randn(batch_size, block_size, *input_shape).astype(np.float32) + for _ in range(num_steps) + ] + step_masks = [ + np.ones((batch_size, block_size), dtype=bool) for _ in range(num_steps) + ] + total_time = num_steps * block_size + + jax_constants = None + mlx_constants = None + jax_stream_constants = None + mlx_stream_constants = None + + if constants_fn is not None: + jax_constants, mlx_constants = constants_fn(batch_size, rng) + + if stream_constants_fn is not None: + jax_stream_constants, mlx_stream_constants = stream_constants_fn( + batch_size, total_time, rng + ) + + # --- JAX init + step --- + jax_model = config.make() + # Init with a full sequence to get params. 
+ full_values = np.concatenate(step_values, axis=1) + full_mask = np.ones((batch_size, total_time), dtype=bool) + x_init = jax_types.Sequence( + jnp.array(full_values), jnp.array(full_mask, dtype=jnp.bool_) + ) + init_constants = jax_constants + if init_constants is None and jax_stream_constants is not None: + init_constants = jax_stream_constants + variables = jax_model.init( + jax.random.PRNGKey(0), + x_init, + training=False, + constants=init_constants, + ) + jax_params = variables.get('params', {}) + jax_variables = {'params': jax_params} if jax_params else variables + + jax_spec = jax.ShapeDtypeStruct(input_shape, jnp.float32) + jax_state = jax_model.apply( + jax_variables, + batch_size, + jax_spec, + training=False, + constants=init_constants, + method=jax_model.get_initial_state, + ) + + jax_outputs = [] + for i in range(num_steps): + x_jax = jax_types.Sequence( + jnp.array(step_values[i]), + jnp.array(step_masks[i], dtype=jnp.bool_), + ) + step_c = jax_constants + if jax_stream_constants is not None: + s = i * block_size + e = s + block_size + step_c = { + k: jax_types.Sequence(v.values[:, s:e], v.mask[:, s:e]) + for k, v in jax_stream_constants.items() + } + y_jax, jax_state = jax_model.apply( + jax_variables, + x_jax, + jax_state, + training=False, + constants=step_c, + method=jax_model.step, + ) + jax_outputs.append(np.array(y_jax.values)) + + # --- MLX init + step --- + mlx_model = config.make(backend='mlx') + mlx_init_constants = mlx_constants + if mlx_init_constants is None and mlx_stream_constants is not None: + mlx_init_constants = mlx_stream_constants + if jax_params: + weight_converter.load_linen_params( + mlx_model, + jax_params, + config, + input_spec=ShapeDType(input_shape, mx.float32), + constants=mlx_init_constants, + ) + # Skip _materialize_deferred for param-less layers — no deferred weights. + + mlx_spec = ShapeDType(input_shape, mx.float32) + # Slice stream constants to time=1 for get_initial_state. 
+ state_constants = mlx_constants + if mlx_stream_constants is not None: + state_constants = { + k: Sequence(v.values[:, :1], v.mask[:, :1]) + for k, v in mlx_stream_constants.items() + } + mlx_state = mlx_model.get_initial_state( + batch_size, mlx_spec, constants=state_constants + ) + + mlx_outputs = [] + for i in range(num_steps): + x_mx = Sequence( + mx.array(step_values[i]), + mx.array(step_masks[i], dtype=mx.bool_), + ) + step_c = mlx_constants + if mlx_stream_constants is not None: + s = i * block_size + e = s + block_size + step_c = { + k: Sequence(v.values[:, s:e], v.mask[:, s:e]) + for k, v in mlx_stream_constants.items() + } + y_mx, mlx_state = mlx_model.step(x_mx, mlx_state, constants=step_c) + mx.eval(y_mx.values) + mlx_outputs.append(np.array(y_mx.values)) + + # --- Compare --- + for i, (jax_out, mlx_out) in enumerate(zip(jax_outputs, mlx_outputs)): + np.testing.assert_allclose( + mlx_out, + jax_out, + atol=atol, + rtol=rtol, + err_msg=f'{config.__class__.__qualname__} step {i}: outputs differ', + ) + + +class StepModeConvolutionTest(parameterized.TestCase): + """Step-mode cross-backend: convolution layers.""" + + def test_conv1d_causal(self): + config = sl.Conv1D.Config(filters=8, kernel_size=3, padding='causal') + _compare_step_mode(self, config, (4,)) + + def test_depthwise_conv1d_causal(self): + config = sl.DepthwiseConv1D.Config(kernel_size=3, padding='causal') + _compare_step_mode(self, config, (4,)) + + def test_conv1d_transpose_causal(self): + config = sl.Conv1DTranspose.Config( + filters=8, kernel_size=3, strides=2, padding='causal' + ) + _compare_step_mode(self, config, (4,)) + + +class StepModeDenseNormTest(parameterized.TestCase): + """Step-mode cross-backend: Dense and normalization.""" + + def test_dense(self): + config = sl.Dense.Config(features=16) + _compare_step_mode(self, config, (8,)) + + def test_rms_norm(self): + config = sl.RMSNormalization.Config() + _compare_step_mode(self, config, (16,)) + + def test_layer_norm(self): + config 
= sl.LayerNormalization.Config() + _compare_step_mode(self, config, (16,)) + + +class StepModeSelfAttentionTest(parameterized.TestCase): + """Step-mode cross-backend: self-attention.""" + + def test_causal(self): + config = sl.DotProductSelfAttention.Config( + num_heads=2, + units_per_head=4, + max_past_horizon=16, + max_future_horizon=0, + ) + _compare_step_mode(self, config, (8,), atol=1e-4, rtol=1e-4) + + def test_causal_with_bias(self): + config = sl.DotProductSelfAttention.Config( + num_heads=2, + units_per_head=4, + max_past_horizon=16, + max_future_horizon=0, + use_bias=True, + ) + _compare_step_mode(self, config, (8,), atol=1e-4, rtol=1e-4) + + def test_causal_with_rope(self): + config = sl.DotProductSelfAttention.Config( + num_heads=2, + units_per_head=4, + max_past_horizon=16, + max_future_horizon=0, + query_network=sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ), + key_network=sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ), + ) + _compare_step_mode(self, config, (8,), atol=1e-4, rtol=1e-4) + + def test_gqa_with_rope(self): + config = sl.DotProductSelfAttention.Config( + num_heads=4, + units_per_head=4, + max_past_horizon=16, + max_future_horizon=0, + num_kv_heads=2, + input_projection=attn_common.SeparateQueryKeyValueProjection(), + query_network=sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ), + key_network=sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ), + ) + _compare_step_mode(self, config, (16,), atol=1e-4, rtol=1e-4) + + +class StepModeCrossAttentionTest(parameterized.TestCase): + """Step-mode cross-backend: cross-attention.""" + + def _make_constants(self, batch_size, rng, source_features=12, source_time=8): + from sequence_layers.jax.attention import ( + dot_product_attention as jax_cross_attn, + ) + + source_values = rng.randn(batch_size, source_time, source_features).astype( + np.float32 + ) + source_mask = np.ones((batch_size, source_time), 
dtype=bool) + jax_source = jax_types.Sequence( + jnp.array(source_values), jnp.array(source_mask, dtype=jnp.bool_) + ) + mlx_source = Sequence( + mx.array(source_values), mx.array(source_mask, dtype=mx.bool_) + ) + return {'enc': jax_source}, {'enc': mlx_source} + + def test_cross_attention(self): + from sequence_layers.jax.attention import ( + dot_product_attention as jax_cross_attn, + ) + + config = jax_cross_attn.DotProductAttention.Config( + source_name='enc', + num_heads=2, + units_per_head=4, + ) + _compare_step_mode( + self, + config, + (8,), + constants_fn=lambda b, rng: self._make_constants(b, rng), + atol=1e-4, + rtol=1e-4, + ) + + def test_cross_attention_different_dims(self): + from sequence_layers.jax.attention import ( + dot_product_attention as jax_cross_attn, + ) + + config = jax_cross_attn.DotProductAttention.Config( + source_name='enc', + num_heads=2, + units_per_head=4, + ) + _compare_step_mode( + self, + config, + (16,), + constants_fn=lambda b, rng: self._make_constants( + b, rng, source_features=8, source_time=12 + ), + atol=1e-4, + rtol=1e-4, + ) + + def test_cross_attention_with_bias(self): + from sequence_layers.jax.attention import ( + dot_product_attention as jax_cross_attn, + ) + + config = jax_cross_attn.DotProductAttention.Config( + source_name='enc', + num_heads=2, + units_per_head=4, + use_bias=True, + ) + _compare_step_mode( + self, + config, + (8,), + constants_fn=lambda b, rng: self._make_constants(b, rng), + atol=1e-4, + rtol=1e-4, + ) + + +class StepModeStreamingAttentionTest(parameterized.TestCase): + """Step-mode cross-backend: streaming cross-attention.""" + + def _make_stream_constants(self, batch_size, time, rng, source_features=12): + source_values = rng.randn(batch_size, time, source_features).astype( + np.float32 + ) + source_mask = np.ones((batch_size, time), dtype=bool) + jax_source = jax_types.Sequence( + jnp.array(source_values), jnp.array(source_mask, dtype=jnp.bool_) + ) + mlx_source = Sequence( + 
mx.array(source_values), mx.array(source_mask, dtype=mx.bool_) + ) + return {'src': jax_source}, {'src': mlx_source} + + def test_streaming_attention(self): + from sequence_layers.jax.attention import ( + streaming_dot_product_attention as jax_streaming_attn, + ) + + config = jax_streaming_attn.StreamingDotProductAttention.Config( + source_name='src', + num_heads=2, + units_per_head=4, + max_past_horizon=8, + ) + _compare_step_mode( + self, + config, + (8,), + stream_constants_fn=lambda b, t, rng: self._make_stream_constants( + b, t, rng, source_features=12 + ), + atol=1e-4, + rtol=1e-4, + ) + + def test_streaming_with_future_horizon(self): + from sequence_layers.jax.attention import ( + streaming_dot_product_attention as jax_streaming_attn, + ) + + config = jax_streaming_attn.StreamingDotProductAttention.Config( + source_name='src', + num_heads=2, + units_per_head=4, + max_past_horizon=6, + max_future_horizon=2, + ) + _compare_step_mode( + self, + config, + (8,), + stream_constants_fn=lambda b, t, rng: self._make_stream_constants( + b, t, rng, source_features=12 + ), + atol=1e-4, + rtol=1e-4, + ) + + def test_streaming_with_rope(self): + from sequence_layers.jax.attention import ( + streaming_dot_product_attention as jax_streaming_attn, + ) + + config = jax_streaming_attn.StreamingDotProductAttention.Config( + source_name='src', + num_heads=2, + units_per_head=4, + max_past_horizon=8, + query_network=sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ), + key_network=sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ), + ) + _compare_step_mode( + self, + config, + (8,), + stream_constants_fn=lambda b, t, rng: self._make_stream_constants( + b, t, rng, source_features=12 + ), + atol=1e-4, + rtol=1e-4, + ) + + def test_streaming_local(self): + from sequence_layers.jax.attention import ( + streaming_local_dot_product_attention as jax_streaming_local_attn, + ) + + config = 
jax_streaming_local_attn.StreamingLocalDotProductAttention.Config( + source_name='src', + num_heads=2, + units_per_head=4, + max_past_horizon=8, + block_size=1, + ) + _compare_step_mode( + self, + config, + (8,), + stream_constants_fn=lambda b, t, rng: self._make_stream_constants( + b, t, rng, source_features=12 + ), + atol=1e-4, + rtol=1e-4, + ) + + def test_streaming_with_bias(self): + from sequence_layers.jax.attention import ( + streaming_dot_product_attention as jax_streaming_attn, + ) + + config = jax_streaming_attn.StreamingDotProductAttention.Config( + source_name='src', + num_heads=2, + units_per_head=4, + max_past_horizon=8, + use_bias=True, + ) + _compare_step_mode( + self, + config, + (8,), + stream_constants_fn=lambda b, t, rng: self._make_stream_constants( + b, t, rng, source_features=12 + ), + atol=1e-4, + rtol=1e-4, + ) + + +class StepModeDSPTest(parameterized.TestCase): + """Step-mode cross-backend: DSP layers.""" + + def test_delay(self): + config = sl.Delay.Config(length=3) + _compare_step_mode(self, config, (8,)) + + def test_lookahead(self): + config = sl.Lookahead.Config(length=3) + _compare_step_mode(self, config, (8,)) + + def test_window(self): + config = sl.Window.Config(axis=-1) + _compare_step_mode(self, config, (8,)) + + def test_frame_causal(self): + config = sl.Frame.Config(frame_length=4, frame_step=2, padding='causal') + _compare_step_mode(self, config, (1,), block_size=2, num_steps=6) + + def test_overlap_add_causal(self): + config = sl.OverlapAdd.Config( + frame_length=4, frame_step=2, padding='causal' + ) + _compare_step_mode(self, config, (4,), num_steps=6) + + def test_overlap_add_causal_large(self): + config = sl.OverlapAdd.Config( + frame_length=8, frame_step=4, padding='causal' + ) + _compare_step_mode(self, config, (8,), num_steps=6) + + +class StepModeCombinatorTest(parameterized.TestCase): + """Step-mode cross-backend: combinators.""" + + def test_serial(self): + config = sl.Serial.Config([ + sl.Dense.Config(features=16), 
+ sl.Relu.Config(), + sl.Dense.Config(features=8), + ]) + _compare_step_mode(self, config, (8,)) + + def test_residual(self): + config = sl.Residual.Config([ + sl.Dense.Config(features=8), + sl.Relu.Config(), + ]) + _compare_step_mode(self, config, (8,)) + + def test_repeat_with_attention(self): + config = sl.Repeat.Config( + num_repeats=2, + layer=sl.Serial.Config([ + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.DotProductSelfAttention.Config( + num_heads=2, + units_per_head=4, + max_past_horizon=16, + max_future_horizon=0, + ), + sl.Flatten.Config(), + ]), + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.Dense.Config(features=8), + ]), + ]), + ) + _compare_step_mode(self, config, (8,), atol=1e-3, rtol=1e-3) + + +class StepModePoolingTest(parameterized.TestCase): + """Step-mode cross-backend: pooling layers.""" + + def test_max_pool_causal(self): + config = sl.MaxPooling1D.Config(pool_size=3, padding='causal') + _compare_step_mode(self, config, (8,)) + + def test_min_pool_causal(self): + config = sl.MinPooling1D.Config(pool_size=3, padding='causal') + _compare_step_mode(self, config, (8,)) + + def test_avg_pool_causal(self): + config = sl.AveragePooling1D.Config(pool_size=3, padding='causal') + _compare_step_mode(self, config, (8,)) + + +class StepModeGQATest(parameterized.TestCase): + """Step-mode cross-backend: grouped query attention.""" + + def test_gqa(self): + config = sl.DotProductSelfAttention.Config( + num_heads=4, + units_per_head=4, + max_past_horizon=16, + max_future_horizon=0, + num_kv_heads=2, + input_projection=attn_common.SeparateQueryKeyValueProjection(), + ) + _compare_step_mode(self, config, (16,), atol=1e-4, rtol=1e-4) + + +class GQACrossBackendTest(parameterized.TestCase): + """Layer-mode cross-backend: grouped query attention.""" + + def test_gqa(self): + config = sl.DotProductSelfAttention.Config( + num_heads=4, + units_per_head=4, + max_past_horizon=16, + max_future_horizon=0, + num_kv_heads=2, + 
input_projection=attn_common.SeparateQueryKeyValueProjection(), + ) + _compare_parametric_float(self, config, (16,), atol=1e-4, rtol=1e-4) + + +# --------------------------------------------------------------------------- +# Parallel combinator cross-backend tests +# --------------------------------------------------------------------------- + + +class ParallelCrossBackendTest(parameterized.TestCase): + """Cross-backend: Parallel combinator (layer + step).""" + + def test_parallel_add_layer(self): + config = sl.Parallel.Config( + layers=[ + sl.Dense.Config(features=8), + sl.Dense.Config(features=8), + ], + combination=sl.CombinationMode.ADD, + ) + _compare_parametric_float(self, config, (8,)) + + def test_parallel_concat_layer(self): + config = sl.Parallel.Config( + layers=[ + sl.Dense.Config(features=4), + sl.Dense.Config(features=4), + ], + combination=sl.CombinationMode.CONCAT, + ) + _compare_parametric_float(self, config, (8,)) + + def test_parallel_stack_layer(self): + config = sl.Parallel.Config( + layers=[ + sl.Dense.Config(features=8), + sl.Dense.Config(features=8), + ], + combination=sl.CombinationMode.STACK, + ) + _compare_parametric_float(self, config, (8,)) + + def test_parallel_add_step(self): + config = sl.Parallel.Config( + layers=[ + sl.Dense.Config(features=8), + sl.Dense.Config(features=8), + ], + combination=sl.CombinationMode.ADD, + ) + _compare_step_mode(self, config, (8,)) + + def test_parallel_concat_step(self): + config = sl.Parallel.Config( + layers=[ + sl.Dense.Config(features=4), + sl.Dense.Config(features=4), + ], + combination=sl.CombinationMode.CONCAT, + ) + _compare_step_mode(self, config, (8,)) + + +# --------------------------------------------------------------------------- +# Partially-masked input tests +# --------------------------------------------------------------------------- + + +def _compare_parametric_float_masked( + test_case, + config, + input_shape, + *, + batch_size=2, + time=8, + atol=1e-5, + rtol=1e-5, + seed=42, 
+): + """Like _compare_parametric_float but with partially-masked inputs.""" + rng = np.random.RandomState(seed) + values = rng.randn(batch_size, time, *input_shape).astype(np.float32) + # Create a mask where ~25% of timesteps are invalid. + mask = rng.rand(batch_size, time) > 0.25 + + # JAX. + jax_model = config.make() + x_jax = jax_types.Sequence( + jnp.array(values), jnp.array(mask, dtype=jnp.bool_) + ) + variables = jax_model.init(jax.random.PRNGKey(0), x_jax, training=False) + jax_params = variables.get('params', {}) + jax_variables = {'params': jax_params} if jax_params else variables + jax_out = jax_model.apply(jax_variables, x_jax, training=False) + jax_values = np.array(jax_out.values) + jax_mask = np.array(jax_out.mask) + + # MLX. + mlx_model = config.make(backend='mlx') + if jax_params: + weight_converter.load_linen_params( + mlx_model, + jax_params, + config, + input_spec=ShapeDType(input_shape, mx.float32), + ) + else: + export._materialize_deferred( + mlx_model, + batch_size=1, + input_spec=ShapeDType(input_shape, mx.float32), + ) + x_mx = Sequence(mx.array(values), mx.array(mask, dtype=mx.bool_)) + mlx_out = mlx_model.layer(x_mx) + mlx_values = np.array(mlx_out.values) + mlx_mask = np.array(mlx_out.mask) + + # Compare valid timesteps only. + out_mask = jax_mask + if jax_values.shape != mlx_values.shape: + test_case.fail( + f'{config.__class__.__qualname__}: shape mismatch' + f' jax={jax_values.shape} vs mlx={mlx_values.shape}' + ) + + # Flatten and compare only valid positions. + for b in range(batch_size): + for t in range(out_mask.shape[1]): + if out_mask[b, t]: + np.testing.assert_allclose( + mlx_values[b, t], + jax_values[b, t], + atol=atol, + rtol=rtol, + err_msg=( + f'{config.__class__.__qualname__} batch={b} time={t}:' + ' valid outputs differ' + ), + ) + + # Masks should match. 
+ np.testing.assert_array_equal( + mlx_mask, + jax_mask, + err_msg=f'{config.__class__.__qualname__}: masks differ', + ) + + +class MaskedInputDenseTest(parameterized.TestCase): + """Cross-backend with partially-masked inputs: Dense.""" + + def test_dense_masked(self): + config = sl.Dense.Config(features=16) + _compare_parametric_float_masked(self, config, (8,)) + + +class MaskedInputConvTest(parameterized.TestCase): + """Cross-backend with partially-masked inputs: Conv1D.""" + + def test_conv1d_causal_masked(self): + config = sl.Conv1D.Config(filters=8, kernel_size=3, padding='causal') + _compare_parametric_float_masked(self, config, (4,)) + + def test_depthwise_conv1d_masked(self): + config = sl.DepthwiseConv1D.Config(kernel_size=3, padding='causal') + _compare_parametric_float_masked(self, config, (4,)) + + +class MaskedInputNormTest(parameterized.TestCase): + """Cross-backend with partially-masked inputs: normalization.""" + + def test_rms_norm_masked(self): + config = sl.RMSNormalization.Config() + _compare_parametric_float_masked(self, config, (16,)) + + def test_layer_norm_masked(self): + config = sl.LayerNormalization.Config() + _compare_parametric_float_masked(self, config, (16,)) + + +class MaskedInputSelfAttentionTest(parameterized.TestCase): + """Cross-backend with partially-masked inputs: self-attention.""" + + def test_causal_masked(self): + config = sl.DotProductSelfAttention.Config( + num_heads=2, + units_per_head=4, + max_past_horizon=16, + max_future_horizon=0, + ) + _compare_parametric_float_masked(self, config, (8,), atol=1e-4, rtol=1e-4) + + +class MaskedInputPoolingTest(parameterized.TestCase): + """Cross-backend with partially-masked inputs: pooling.""" + + def test_max_pool_masked(self): + config = sl.MaxPooling1D.Config(pool_size=2, padding='causal') + _compare_parametric_float_masked(self, config, (8,)) + + def test_avg_pool_masked(self): + config = sl.AveragePooling1D.Config(pool_size=2, padding='causal') + 
_compare_parametric_float_masked(self, config, (8,)) + + +# --------------------------------------------------------------------------- +# Integration tests: full model cross-backend comparison +# --------------------------------------------------------------------------- + + +def _compare_integration_float( + test_case, + config, + input_shape, + *, + batch_size=2, + time=8, + atol=1e-3, + rtol=1e-3, + seed=42, + constants_fn=None, +): + """Compare a full model (layer mode) between JAX and MLX.""" + rng = np.random.RandomState(seed) + values = rng.randn(batch_size, time, *input_shape).astype(np.float32) + mask = np.ones((batch_size, time), dtype=bool) + + jax_constants = None + mlx_constants = None + if constants_fn is not None: + jax_constants, mlx_constants = constants_fn(batch_size, time, rng) + + # JAX. + jax_model = config.make() + x_jax = jax_types.Sequence( + jnp.array(values), jnp.array(mask, dtype=jnp.bool_) + ) + variables = jax_model.init( + jax.random.PRNGKey(0), x_jax, training=False, constants=jax_constants + ) + jax_params = variables['params'] + jax_out = jax_model.apply( + {'params': jax_params}, + x_jax, + training=False, + constants=jax_constants, + ) + jax_values = np.array(jax_out.values) + + # MLX. 
+ mlx_model = config.make(backend='mlx') + weight_converter.load_linen_params( + mlx_model, + jax_params, + config, + input_spec=ShapeDType(input_shape, mx.float32), + constants=mlx_constants, + ) + x_mx = Sequence(mx.array(values), mx.array(mask, dtype=mx.bool_)) + mlx_out = mlx_model.layer(x_mx, constants=mlx_constants) + mlx_values = np.array(mlx_out.values) + + test_case.assertEqual( + jax_values.shape, + mlx_values.shape, + f'Shape mismatch: jax={jax_values.shape} vs mlx={mlx_values.shape}', + ) + np.testing.assert_allclose( + mlx_values, + jax_values, + atol=atol, + rtol=rtol, + err_msg='Integration test: JAX vs MLX outputs differ', + ) + return jax_params, jax_constants, mlx_constants + + +def _compare_integration_int( + test_case, + config, + *, + vocab_size=256, + batch_size=2, + time=8, + atol=1e-3, + rtol=1e-3, + seed=42, +): + """Compare a full model with integer token inputs (layer mode).""" + rng = np.random.RandomState(seed) + tokens = rng.randint(0, vocab_size, size=(batch_size, time)).astype(np.int32) + mask = np.ones((batch_size, time), dtype=bool) + + # JAX. + jax_model = config.make() + x_jax = jax_types.Sequence( + jnp.array(tokens), jnp.array(mask, dtype=jnp.bool_) + ) + variables = jax_model.init(jax.random.PRNGKey(0), x_jax, training=False) + jax_params = variables['params'] + jax_out = np.array( + jax_model.apply({'params': jax_params}, x_jax, training=False).values + ) + + # MLX. 
+ mlx_model = config.make(backend='mlx') + weight_converter.load_linen_params(mlx_model, jax_params, config) + x_mx = Sequence( + mx.array(tokens, dtype=mx.int32), mx.array(mask, dtype=mx.bool_) + ) + mlx_out = np.array(mlx_model.layer(x_mx).values) + + test_case.assertEqual( + jax_out.shape, + mlx_out.shape, + f'Shape mismatch: jax={jax_out.shape} vs mlx={mlx_out.shape}', + ) + np.testing.assert_allclose( + mlx_out, + jax_out, + atol=atol, + rtol=rtol, + err_msg='Integration test: JAX vs MLX outputs differ', + ) + return jax_params + + +def _compare_integration_step( + test_case, + config, + input_shape, + jax_params, + *, + batch_size=2, + num_steps=8, + atol=1e-3, + rtol=1e-3, + seed=42, + jax_constants=None, + mlx_constants=None, +): + """Compare step-by-step output of a full model between JAX and MLX.""" + rng = np.random.RandomState(seed + 1) + step_values = [ + rng.randn(batch_size, 1, *input_shape).astype(np.float32) + for _ in range(num_steps) + ] + step_masks = [np.ones((batch_size, 1), dtype=bool) for _ in range(num_steps)] + + # JAX step. + jax_model = config.make() + jax_spec = jax.ShapeDtypeStruct(input_shape, jnp.float32) + jax_state = jax_model.apply( + {'params': jax_params}, + batch_size, + jax_spec, + training=False, + constants=jax_constants, + method=jax_model.get_initial_state, + ) + jax_outputs = [] + for i in range(num_steps): + x_jax = jax_types.Sequence( + jnp.array(step_values[i]), + jnp.array(step_masks[i], dtype=jnp.bool_), + ) + y_jax, jax_state = jax_model.apply( + {'params': jax_params}, + x_jax, + jax_state, + training=False, + constants=jax_constants, + method=jax_model.step, + ) + jax_outputs.append(np.array(y_jax.values)) + + # MLX step. 
+ mlx_model = config.make(backend='mlx') + weight_converter.load_linen_params( + mlx_model, + jax_params, + config, + input_spec=ShapeDType(input_shape, mx.float32), + constants=mlx_constants, + ) + mlx_spec = ShapeDType(input_shape, mx.float32) + mlx_state = mlx_model.get_initial_state( + batch_size, mlx_spec, constants=mlx_constants + ) + mlx_outputs = [] + for i in range(num_steps): + x_mx = Sequence( + mx.array(step_values[i]), + mx.array(step_masks[i], dtype=mx.bool_), + ) + y_mx, mlx_state = mlx_model.step(x_mx, mlx_state, constants=mlx_constants) + mx.eval(y_mx.values) + mlx_outputs.append(np.array(y_mx.values)) + + for i, (jax_out, mlx_out) in enumerate(zip(jax_outputs, mlx_outputs)): + np.testing.assert_allclose( + mlx_out, + jax_out, + atol=atol, + rtol=rtol, + err_msg=f'Integration step {i}: JAX vs MLX outputs differ', + ) + + +def _compare_integration_int_step( + test_case, + config, + jax_params, + *, + vocab_size=256, + batch_size=1, + num_steps=8, + atol=1e-3, + rtol=1e-3, + seed=42, +): + """Compare step-by-step output of a token model between JAX and MLX.""" + rng = np.random.RandomState(seed + 1) + step_tokens = [ + rng.randint(0, vocab_size, size=(batch_size, 1)).astype(np.int32) + for _ in range(num_steps) + ] + step_masks = [np.ones((batch_size, 1), dtype=bool) for _ in range(num_steps)] + + # JAX step. + jax_model = config.make() + jax_spec = jax.ShapeDtypeStruct((), jnp.int32) + jax_state = jax_model.apply( + {'params': jax_params}, + batch_size, + jax_spec, + training=False, + method=jax_model.get_initial_state, + ) + jax_outputs = [] + for i in range(num_steps): + x_jax = jax_types.Sequence( + jnp.array(step_tokens[i]), + jnp.array(step_masks[i], dtype=jnp.bool_), + ) + y_jax, jax_state = jax_model.apply( + {'params': jax_params}, + x_jax, + jax_state, + training=False, + method=jax_model.step, + ) + jax_outputs.append(np.array(y_jax.values)) + + # MLX step. 
+ mlx_model = config.make(backend='mlx') + weight_converter.load_linen_params(mlx_model, jax_params, config) + mlx_spec = ShapeDType((), mx.int32) + mlx_state = mlx_model.get_initial_state(batch_size, mlx_spec) + mlx_outputs = [] + for i in range(num_steps): + x_mx = Sequence( + mx.array(step_tokens[i], dtype=mx.int32), + mx.array(step_masks[i], dtype=mx.bool_), + ) + y_mx, mlx_state = mlx_model.step(x_mx, mlx_state) + mx.eval(y_mx.values) + mlx_outputs.append(np.array(y_mx.values)) + + for i, (jax_out, mlx_out) in enumerate(zip(jax_outputs, mlx_outputs)): + np.testing.assert_allclose( + mlx_out, + jax_out, + atol=atol, + rtol=rtol, + err_msg=f'Integration step {i}: JAX vs MLX outputs differ', + ) + + +class DecoderTransformerIntegrationTest(parameterized.TestCase): + """Cross-backend: decoder-only transformer (token input).""" + + def _config(self, dim=32, num_heads=4, num_layers=2, vocab_size=64): + return sl.Serial.Config([ + sl.Embedding.Config(num_embeddings=vocab_size, dimension=dim), + sl.Repeat.Config( + num_repeats=num_layers, + layer=sl.Serial.Config([ + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.DotProductSelfAttention.Config( + num_heads=num_heads, + units_per_head=dim // num_heads, + max_past_horizon=64, + max_future_horizon=0, + query_network=( + sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ) + ), + key_network=( + sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ) + ), + ), + sl.Flatten.Config(), + ]), + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.Dense.Config(features=dim * 4, activation=jax.nn.gelu), + sl.Dense.Config(features=dim), + ]), + ]), + ), + sl.RMSNormalization.Config(), + sl.Dense.Config(features=vocab_size), + ]) + + def test_layer(self): + config = self._config() + _compare_integration_int(self, config, vocab_size=64) + + def test_step(self): + config = self._config() + jax_params = _compare_integration_int(self, config, vocab_size=64) + 
_compare_integration_int_step( + self, config, jax_params, vocab_size=64, num_steps=6 + ) + + +class GQADecoderIntegrationTest(parameterized.TestCase): + """Cross-backend: decoder transformer with GQA.""" + + def _config(self, dim=32, num_heads=4, num_kv_heads=2, vocab_size=64): + return sl.Serial.Config([ + sl.Embedding.Config(num_embeddings=vocab_size, dimension=dim), + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.DotProductSelfAttention.Config( + num_heads=num_heads, + units_per_head=dim // num_heads, + max_past_horizon=64, + max_future_horizon=0, + num_kv_heads=num_kv_heads, + input_projection=( + attn_common.SeparateQueryKeyValueProjection() + ), + query_network=( + sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ) + ), + key_network=( + sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ) + ), + ), + sl.Flatten.Config(), + ]), + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.Dense.Config(features=dim * 4, activation=jax.nn.gelu), + sl.Dense.Config(features=dim), + ]), + sl.Dense.Config(features=vocab_size), + ]) + + def test_layer(self): + config = self._config() + _compare_integration_int(self, config, vocab_size=64) + + def test_step(self): + config = self._config() + jax_params = _compare_integration_int(self, config, vocab_size=64) + _compare_integration_int_step( + self, config, jax_params, vocab_size=64, num_steps=6 + ) + + +class ConvEncoderIntegrationTest(parameterized.TestCase): + """Cross-backend: conv + dense encoder (float input).""" + + def _config(self, dim=16): + return sl.Serial.Config([ + sl.Conv1D.Config(filters=dim, kernel_size=3, padding='causal'), + sl.Relu.Config(), + sl.Conv1D.Config(filters=dim, kernel_size=3, padding='causal'), + sl.Relu.Config(), + sl.LayerNormalization.Config(), + sl.Dense.Config(features=dim * 2, activation=jax.nn.gelu), + sl.Dense.Config(features=dim), + ]) + + def test_layer(self): + config = self._config() + _compare_integration_float(self, config, 
(8,)) + + def test_step(self): + config = self._config() + jax_params, _, _ = _compare_integration_float(self, config, (8,)) + _compare_integration_step(self, config, (8,), jax_params) + + +class ConvAttentionIntegrationTest(parameterized.TestCase): + """Cross-backend: conv + self-attention + pooling model (float input).""" + + def _config(self, dim=16): + return sl.Serial.Config([ + sl.Conv1D.Config(filters=dim, kernel_size=3, padding='causal'), + sl.Swish.Config(), + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.DotProductSelfAttention.Config( + num_heads=2, + units_per_head=dim // 2, + max_past_horizon=32, + max_future_horizon=0, + ), + sl.Flatten.Config(), + ]), + sl.MaxPooling1D.Config(pool_size=2, padding='causal'), + sl.Dense.Config(features=dim), + ]) + + def test_layer(self): + config = self._config() + _compare_integration_float(self, config, (8,), time=8) + + def test_step(self): + config = self._config() + jax_params, _, _ = _compare_integration_float(self, config, (8,), time=8) + _compare_integration_step( + self, config, (8,), jax_params, num_steps=8, atol=5e-3, rtol=5e-3 + ) + + +class EncoderDecoderIntegrationTest(parameterized.TestCase): + """Cross-backend: encoder-decoder with cross-attention (float input).""" + + def _encoder_config(self, dim=16): + return sl.Serial.Config([ + sl.Dense.Config(features=dim, activation=jax.nn.relu), + sl.Dense.Config(features=dim), + ]) + + def _decoder_config(self, dim=16): + from sequence_layers.jax.attention import ( + dot_product_attention as jax_cross_attn, + ) + + return sl.Serial.Config([ + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.DotProductSelfAttention.Config( + num_heads=2, + units_per_head=dim // 2, + max_past_horizon=32, + max_future_horizon=0, + ), + sl.Flatten.Config(), + ]), + sl.Residual.Config([ + sl.RMSNormalization.Config(), + jax_cross_attn.DotProductAttention.Config( + source_name='encoder', + num_heads=2, + units_per_head=dim // 2, + ), + sl.Flatten.Config(), + ]), 
+ sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.Dense.Config(features=dim * 2, activation=jax.nn.gelu), + sl.Dense.Config(features=dim), + ]), + ]) + + def _make_constants(self, batch_size, time, rng, dim=16): + source_values = rng.randn(batch_size, time, dim).astype(np.float32) + source_mask = np.ones((batch_size, time), dtype=bool) + jax_source = jax_types.Sequence( + jnp.array(source_values), jnp.array(source_mask, dtype=jnp.bool_) + ) + mlx_source = Sequence( + mx.array(source_values), mx.array(source_mask, dtype=mx.bool_) + ) + return {'encoder': jax_source}, {'encoder': mlx_source} + + def test_layer(self): + config = self._decoder_config() + _compare_integration_float( + self, + config, + (16,), + constants_fn=lambda b, t, rng: self._make_constants(b, t, rng), + ) + + def test_step(self): + config = self._decoder_config() + jax_params, jax_constants, mlx_constants = _compare_integration_float( + self, + config, + (16,), + constants_fn=lambda b, t, rng: self._make_constants(b, t, rng), + ) + _compare_integration_step( + self, + config, + (16,), + jax_params, + jax_constants=jax_constants, + mlx_constants=mlx_constants, + ) + + +class DepthwiseConvPipelineIntegrationTest(parameterized.TestCase): + """Cross-backend: depthwise conv + dense + normalization pipeline.""" + + def _config(self, dim=16): + return sl.Serial.Config([ + sl.Dense.Config(features=dim), + sl.DepthwiseConv1D.Config(kernel_size=3, padding='causal'), + sl.Swish.Config(), + sl.LayerNormalization.Config(), + sl.Dense.Config(features=dim * 2, activation=jax.nn.gelu), + sl.Dense.Config(features=dim), + sl.DepthwiseConv1D.Config(kernel_size=5, padding='causal'), + sl.RMSNormalization.Config(), + sl.Dense.Config(features=dim), + ]) + + def test_layer(self): + config = self._config() + _compare_integration_float(self, config, (8,), atol=2e-3, rtol=2e-3) + + def test_step(self): + config = self._config() + jax_params, _, _ = _compare_integration_float( + self, config, (8,), atol=2e-3, 
rtol=2e-3 + ) + _compare_integration_step( + self, config, (8,), jax_params, atol=2e-3, rtol=2e-3 + ) + + +class ParallelBranchIntegrationTest(parameterized.TestCase): + """Cross-backend: parallel branches with different processing.""" + + def _config(self, dim=8): + return sl.Serial.Config([ + sl.Parallel.Config( + layers=[ + sl.Serial.Config([ + sl.Dense.Config(features=dim, activation=jax.nn.relu), + sl.Dense.Config(features=dim), + ]), + sl.Serial.Config([ + sl.Dense.Config(features=dim, activation=jax.nn.gelu), + sl.Dense.Config(features=dim), + ]), + ], + combination=sl.CombinationMode.ADD, + ), + sl.RMSNormalization.Config(), + sl.Dense.Config(features=dim), + ]) + + def test_layer(self): + config = self._config() + _compare_integration_float(self, config, (8,)) + + def test_step(self): + config = self._config() + jax_params, _, _ = _compare_integration_float(self, config, (8,)) + _compare_integration_step(self, config, (8,), jax_params) + + +def _compare_conditioning( + test_case, + config, + input_shape, + cond_shape, + *, + batch_size=2, + time=8, + atol=1e-5, + rtol=1e-5, + seed=42, +): + """Compare Conditioning layer: JAX vs MLX.""" + rng = np.random.RandomState(seed) + values = rng.randn(batch_size, time, *input_shape).astype(np.float32) + mask = np.ones((batch_size, time), dtype=bool) + cond_values = rng.randn(batch_size, time, *cond_shape).astype(np.float32) + cond_mask = np.ones((batch_size, time), dtype=bool) + + jax_constants = { + 'cond': jax_types.Sequence( + jnp.array(cond_values), jnp.array(cond_mask, dtype=jnp.bool_) + ) + } + mlx_constants = { + 'cond': Sequence( + mx.array(cond_values), mx.array(cond_mask, dtype=mx.bool_) + ) + } + + # JAX. 
+ jax_model = config.make() + x_jax = jax_types.Sequence( + jnp.array(values), jnp.array(mask, dtype=jnp.bool_) + ) + variables = jax_model.init( + jax.random.PRNGKey(0), + x_jax, + training=False, + constants=jax_constants, + ) + jax_params = variables.get('params', {}) + jax_out = np.array( + jax_model.apply( + variables, + x_jax, + training=False, + constants=jax_constants, + ).values + ) + + # MLX. + mlx_model = config.make(backend='mlx') + if jax_params: + weight_converter.load_linen_params( + mlx_model, + jax_params, + config, + input_spec=ShapeDType(input_shape, mx.float32), + constants=mlx_constants, + ) + x_mx = Sequence(mx.array(values), mx.array(mask, dtype=mx.bool_)) + mlx_out = np.array(mlx_model.layer(x_mx, constants=mlx_constants).values) + + np.testing.assert_allclose( + mlx_out, + jax_out, + atol=atol, + rtol=rtol, + err_msg=f'{config.__class__.__qualname__}: outputs differ', + ) + + +class ConditioningCrossBackendTest(parameterized.TestCase): + """Conditioning: JAX vs MLX layer-mode.""" + + def test_identity_add(self): + from sequence_layers.jax import conditioning as jax_cond + + config = jax_cond.Conditioning.Config( + conditioning_name='cond', + projection=jax_cond.BaseConditioning.Projection.IDENTITY, + combination=jax_cond.BaseConditioning.Combination.ADD, + ) + _compare_conditioning(self, config, (8,), (8,)) + + def test_identity_mul(self): + from sequence_layers.jax import conditioning as jax_cond + + config = jax_cond.Conditioning.Config( + conditioning_name='cond', + projection=jax_cond.BaseConditioning.Projection.IDENTITY, + combination=jax_cond.BaseConditioning.Combination.MUL, + ) + _compare_conditioning(self, config, (8,), (8,)) + + def test_identity_concat(self): + from sequence_layers.jax import conditioning as jax_cond + + config = jax_cond.Conditioning.Config( + conditioning_name='cond', + projection=jax_cond.BaseConditioning.Projection.IDENTITY, + combination=jax_cond.BaseConditioning.Combination.CONCAT, + ) + 
_compare_conditioning(self, config, (4,), (6,)) + + def test_linear_add(self): + from sequence_layers.jax import conditioning as jax_cond + + config = jax_cond.Conditioning.Config( + conditioning_name='cond', + projection=jax_cond.BaseConditioning.Projection.LINEAR, + combination=jax_cond.BaseConditioning.Combination.ADD, + ) + _compare_conditioning(self, config, (4,), (6,)) + + def test_linear_affine_shift(self): + from sequence_layers.jax import conditioning as jax_cond + + config = jax_cond.Conditioning.Config( + conditioning_name='cond', + projection=jax_cond.BaseConditioning.Projection.LINEAR, + combination=jax_cond.BaseConditioning.Combination.AFFINE_SHIFT, + ) + _compare_conditioning(self, config, (4,), (6,)) + + def test_linear_affine_scale(self): + from sequence_layers.jax import conditioning as jax_cond + + config = jax_cond.Conditioning.Config( + conditioning_name='cond', + projection=jax_cond.BaseConditioning.Projection.LINEAR, + combination=jax_cond.BaseConditioning.Combination.AFFINE_SCALE, + ) + _compare_conditioning(self, config, (4,), (6,)) + + def test_linear_affine(self): + from sequence_layers.jax import conditioning as jax_cond + + config = jax_cond.Conditioning.Config( + conditioning_name='cond', + projection=jax_cond.BaseConditioning.Projection.LINEAR_AFFINE, + combination=jax_cond.BaseConditioning.Combination.AFFINE, + ) + _compare_conditioning(self, config, (4,), (6,)) + + +if __name__ == '__main__': + absltest.main() diff --git a/sequence_layers/mlx/decoder_transformer_test.py b/sequence_layers/mlx/decoder_transformer_test.py new file mode 100644 index 0000000..7aae5f6 --- /dev/null +++ b/sequence_layers/mlx/decoder_transformer_test.py @@ -0,0 +1,248 @@ +"""End-to-end test: decoder-only transformer on MLX. + +Defines a small decoder transformer using Linen configs, builds an MLX +model via config.make(backend='mlx'), and tests inference + export. 
+""" + +import os +import tempfile + +import jax.nn +import mlx.core as mx +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + +import sequence_layers.jax as sl +from sequence_layers.mlx import basic_types as bt +from sequence_layers.mlx import export +from sequence_layers.mlx import test_utils + +Sequence = bt.Sequence +ShapeDType = bt.ShapeDType + + +def _decoder_config(vocab_size=256, dim=64, num_heads=4, num_layers=2): + """A small decoder-only transformer config. + + Architecture: + Embedding → Repeat(N, [ + Residual([RMSNorm, SelfAttention(RoPE), Flatten]), + Residual([RMSNorm, Dense(4*dim, gelu), Dense(dim)]), + ]) → RMSNorm → Dense(vocab_size) + """ + return sl.Serial.Config([ + sl.Embedding.Config( + num_embeddings=vocab_size, + dimension=dim, + ), + sl.Repeat.Config( + num_repeats=num_layers, + layer=sl.Serial.Config([ + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.DotProductSelfAttention.Config( + num_heads=num_heads, + units_per_head=dim // num_heads, + max_past_horizon=128, + max_future_horizon=0, + query_network=( + sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ) + ), + key_network=( + sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ) + ), + ), + sl.Flatten.Config(), + ]), + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.Dense.Config( + features=dim * 4, + activation=jax.nn.gelu, + ), + sl.Dense.Config(features=dim), + ]), + ]), + ), + sl.RMSNormalization.Config(), + sl.Dense.Config(features=vocab_size), + ]) + + +def _make_token_sequence(tokens): + """Create a Sequence from integer token ids. + + Args: + tokens: A 2D list [[t1, t2, ...], ...] of shape [batch, time]. + + Returns: + Sequence with values shape [batch, time] and all-valid mask. 
+ """ + arr = mx.array(tokens, dtype=mx.int32) + if arr.ndim != 2: + raise ValueError(f'Expected 2D token array, got shape {arr.shape}') + mask = mx.ones(arr.shape, dtype=mx.bool_) + return Sequence(arr, mask) + + +class DecoderTransformerTest(parameterized.TestCase): + """End-to-end tests for a decoder transformer on MLX.""" + + def _make_model(self, config=None): + if config is None: + config = _decoder_config() + model = config.make(backend='mlx') + return model + + def test_make_mlx(self): + """config.make(backend='mlx') produces an MLX SequenceLayer.""" + config = _decoder_config() + model = config.make(backend='mlx') + from sequence_layers.mlx import types + + self.assertIsInstance(model, types.SequenceLayer) + self.assertTrue(model.supports_step) + + def test_layer(self): + """model.layer() produces correct output shape and dtype.""" + model = self._make_model() + batch, time, vocab_size = 2, 8, 256 + # Input: integer token ids with scalar channel shape (). + x = _make_token_sequence([[0] * time] * batch) + y = model.layer(x) + self.assertEqual(y.shape, (batch, time, vocab_size)) + + def test_step(self): + """model.step() runs and output shape is correct.""" + model = self._make_model() + batch, vocab_size = 1, 256 + input_spec = ShapeDType((), mx.int32) + + export._materialize_deferred(model, batch, input_spec) + state = model.get_initial_state(batch, input_spec) + + # Step with a single token. + x = _make_token_sequence([[42]]) + y, new_state = model.step(x, state) + self.assertEqual(y.shape, (batch, 1, vocab_size)) + + # Second step. 
+ x2 = _make_token_sequence([[7]]) + y2, state2 = model.step(x2, new_state) + self.assertEqual(y2.shape, (batch, 1, vocab_size)) + + def test_step_layer_match(self): + """step() and layer() produce matching outputs.""" + model = self._make_model() + batch, time = 2, 8 + values = mx.random.randint(0, 256, shape=(batch, time)).astype(mx.int32) + mask = mx.ones((batch, time), dtype=mx.bool_) + x = Sequence(values, mask) + + y_layer = model.layer(x) + y_step, _ = test_utils.step_by_step(model, x) + + np.testing.assert_allclose( + np.array(y_step.values), + np.array(y_layer.values), + atol=1e-4, + rtol=1e-4, + err_msg='step() and layer() outputs differ', + ) + + def test_autoregressive_generation(self): + """Token-by-token generation loop with random weights.""" + model = self._make_model() + batch, vocab_size, max_len = 1, 256, 16 + input_spec = ShapeDType((), mx.int32) + + export._materialize_deferred(model, batch, input_spec) + state = model.get_initial_state(batch, input_spec) + + token = 0 + generated = [token] + + for _ in range(max_len - 1): + x = _make_token_sequence([[token]]) + y, state = model.step(x, state) + mx.eval(y.values) + + logits = y.values[0, 0] # [vocab_size] + token = int(mx.argmax(logits)) + generated.append(token) + + self.assertLen(generated, max_len) + for t in generated: + self.assertGreaterEqual(t, 0) + self.assertLess(t, vocab_size) + + def test_export_import(self): + """Export step to .mlxfn, import, verify same outputs.""" + model = self._make_model() + batch, vocab_size = 1, 256 + input_spec = ShapeDType((), mx.int32) + + export._materialize_deferred(model, batch, input_spec) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, 'decoder.mlxfn') + export.export_step(model, path, batch_size=batch, input_spec=input_spec) + self.assertTrue(os.path.exists(path)) + + imported = mx.import_function(path) + + tokens = [42, 7, 13] + inputs = [_make_token_sequence([[t]]) for t in tokens] + + # Native inference. 
+ state = model.get_initial_state(batch, input_spec) + native_outputs = [] + for x in inputs: + y, state = model.step(x, state) + mx.eval(y.values) + native_outputs.append(np.array(y.values)) + + # Exported inference. + flat_state, structure = export.get_initial_state_flat( + model, batch, input_spec + ) + exported_outputs = [] + for x in inputs: + y_vals, y_mask, flat_state = export.run_exported( + imported, x.values, x.mask, flat_state + ) + mx.eval(y_vals) + exported_outputs.append(np.array(y_vals)) + + for i, (native, exported) in enumerate( + zip(native_outputs, exported_outputs) + ): + np.testing.assert_allclose( + exported, + native, + atol=1e-4, + rtol=1e-4, + err_msg=f'Token {tokens[i]}: exported != native', + ) + + def test_export_file_size(self): + """Exported .mlxfn file has reasonable size.""" + model = self._make_model() + batch = 1 + input_spec = ShapeDType((), mx.int32) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, 'decoder.mlxfn') + export.export_step(model, path, batch_size=batch, input_spec=input_spec) + size_mb = os.path.getsize(path) / (1024 * 1024) + # Small model should be < 10 MB. 
+ self.assertLess(size_mb, 10.0) + + +if __name__ == '__main__': + absltest.main() diff --git a/sequence_layers/mlx/dense.py b/sequence_layers/mlx/dense.py new file mode 100644 index 0000000..0023969 --- /dev/null +++ b/sequence_layers/mlx/dense.py @@ -0,0 +1,316 @@ +"""Dense sequence layer for MLX.""" + +import math + +import mlx.core as mx +import mlx.nn as nn +import numpy as np + +from sequence_layers.mlx import basic_types as bt +from sequence_layers.mlx import init_mapping +from sequence_layers.mlx.init_mapping import _to_mx_dtype +from sequence_layers.mlx import types + +Sequence = bt.Sequence + + +def _parse_equation(equation): + """Parse einsum equation of form '...ab,bc->...ac'.""" + if '->' not in equation: + raise ValueError(f'equation is not valid for EinsumDense: {equation}') + left, output_spec = equation.split('->') + input_spec, kernel_spec = left.split(',') + if not input_spec.startswith('...') or not output_spec.startswith('...'): + raise ValueError('Equation must be of the form "...X,Y->...Z".') + if 3 + len(set(input_spec[3:])) != len(input_spec): + raise ValueError( + f'Equation {input_spec=} must not contain duplicate variables.' + ) + if 3 + len(set(output_spec[3:])) != len(output_spec): + raise ValueError( + f'Equation {output_spec=} must not contain duplicate variables.' + ) + return input_spec, kernel_spec, output_spec + + +def _compute_shapes(equation, input_shape, output_shape_spec, bias_axes): + """Compute kernel_shape and bias_shape from equation and shapes. + + Args: + equation: einsum equation string. + input_shape: channel shape of input (excluding batch/time). + output_shape_spec: user-specified output shape with possible Nones. + bias_axes: string of output axes that get bias. + + Returns: + (output_shape, kernel_shape, bias_shape) where bias_shape may be None. + """ + input_spec, kernel_spec, output_spec = _parse_equation(equation) + in_spec = input_spec[3:] # Strip '...' 
+ out_spec = output_spec[3:] + + if len(in_spec) != len(input_shape): + raise ValueError( + f'Equation {in_spec=} does not match {input_shape=} rank.' + ) + + input_dims = {d: input_shape[i] for i, d in enumerate(in_spec)} + output_shape = list(output_shape_spec) + if len(out_spec) != len(output_shape): + raise ValueError( + f'Equation {out_spec=} does not match {output_shape=}.' + ) + + for i, d in enumerate(out_spec): + if output_shape[i] is None: + output_shape[i] = input_dims[d] + elif d in input_dims and output_shape[i] != input_dims[d]: + raise ValueError( + f'Inconsistent dimension {d=}. {output_shape=} vs {input_shape=}' + ) + + output_dim_map = {d: output_shape[i] for i, d in enumerate(out_spec)} + + kernel_shape = [] + for d in kernel_spec: + if d in input_dims: + kernel_shape.append(input_dims[d]) + elif d in output_dim_map: + kernel_shape.append(output_dim_map[d]) + else: + raise ValueError( + f"Weight dimension '{d}' not in input or output spec." + ) + + if bias_axes: + first_bias_loc = min(out_spec.find(c) for c in bias_axes) + bias_out_spec = out_spec[first_bias_loc:] + bias_shape = [ + output_dim_map[c] if c in bias_axes else 1 for c in bias_out_spec + ] + else: + bias_shape = None + + return tuple(output_shape), tuple(kernel_shape), bias_shape + + +class Dense(types.Stateless): + """A basic dense layer backed by mlx.nn.Linear. + + Unlike mlx.nn.Linear (which stores weight as [out, in]), this layer + presents a SequenceLayer interface. Weight conversion from Linen + [in, out] requires a single transpose at load time. 
+ """ + + def __init__( + self, + *, + in_features: int, + features: int, + use_bias: bool = True, + activation=None, + compute_dtype=None, + param_dtype=mx.float32, + ): + super().__init__() + self.features = features + self.activation = activation + self.compute_dtype = compute_dtype + self._param_dtype = param_dtype + self._linear = nn.Linear(in_features, features, bias=use_bias) + + @property + def use_bias(self): + return 'bias' in self._linear + + def get_output_shape(self, input_shape, *, constants=None): + if not input_shape: + raise ValueError( + f'Dense requires at least rank 3 input. Got: {input_shape=}' + ) + return tuple(input_shape[:-1]) + (self.features,) + + def get_output_dtype(self, input_dtype, *, constants=None): + if self.compute_dtype is not None: + return self.compute_dtype + return self._param_dtype + + @types.check_layer + def layer(self, x, *, constants=None): + compute_dtype = self.get_output_dtype(x.dtype) + + def dense_fn(v): + y = self._linear(v.astype(compute_dtype)) + if self.activation is not None: + y = self.activation(y) + return y + + if self.use_bias or self.activation is not None: + return x.apply_values(dense_fn) + else: + return x.apply_values_masked(dense_fn) + + @classmethod + def from_config(cls, config): + """Create a Dense layer from a Linen Dense.Config.""" + return cls( + in_features=None, # Deferred; set by DenseDeferred. + features=config.features, + use_bias=config.use_bias, + activation=init_mapping.map_activation(config.activation), + compute_dtype=getattr(config, 'compute_dtype', None), + param_dtype=config.param_dtype, + ) + + +class DenseDeferred(types.Stateless): + """Dense layer that defers weight creation until first use. + + This is needed because Linen Dense.Config doesn't specify in_features; + it is inferred from the first input. This wrapper creates the actual + Dense layer on first call. 
+ """ + + def __init__( + self, + *, + features: int, + use_bias: bool = True, + activation=None, + compute_dtype=None, + param_dtype=mx.float32, + ): + super().__init__() + self.features = features + self._use_bias = use_bias + self.activation = activation + self.compute_dtype = compute_dtype + self._param_dtype = param_dtype + self._inner = None + + def _ensure_initialized(self, in_features: int): + if self._inner is not None: + return + self._inner = Dense( + in_features=in_features, + features=self.features, + use_bias=self._use_bias, + activation=self.activation, + compute_dtype=self.compute_dtype, + param_dtype=self._param_dtype, + ) + + def get_output_shape(self, input_shape, *, constants=None): + if not input_shape: + raise ValueError( + f'Dense requires at least rank 3 input. Got: {input_shape=}' + ) + return tuple(input_shape[:-1]) + (self.features,) + + def get_output_dtype(self, input_dtype, *, constants=None): + if self.compute_dtype is not None: + return self.compute_dtype + return self._param_dtype + + @types.check_layer + def layer(self, x, *, constants=None): + self._ensure_initialized(x.shape[-1]) + return self._inner.layer(x, constants=constants) + + @classmethod + def from_config(cls, config): + """Create from a Linen Dense.Config.""" + compute_dtype = getattr(config, 'compute_dtype', None) + if compute_dtype is not None: + compute_dtype = _to_mx_dtype(compute_dtype) + return cls( + features=config.features, + use_bias=config.use_bias, + activation=init_mapping.map_activation(config.activation), + compute_dtype=compute_dtype, + param_dtype=_to_mx_dtype(config.param_dtype), + ) + + +class EinsumDense(types.Stateless): + """Dense layer using Einstein summation notation. + + Equation must be of the form '...ab,bc->...ac' where the leading '...' + broadcasts over batch and time dimensions. 
+ """ + + def __init__( + self, + *, + equation, + output_shape, + bias_axes='', + activation=None, + compute_dtype=None, + param_dtype=mx.float32, + ): + super().__init__() + self._equation = equation + self._output_shape_spec = tuple(output_shape) + self._bias_axes = bias_axes + self._activation = activation + self._compute_dtype = compute_dtype + self._param_dtype = param_dtype + # Deferred: created on first call. + self.kernel = None + self.bias = None + self._initialized = False + + def _ensure_initialized(self, input_shape): + if self._initialized: + return + output_shape, kernel_shape, bias_shape = _compute_shapes( + self._equation, input_shape, self._output_shape_spec, self._bias_axes + ) + self._resolved_output_shape = output_shape + self.kernel = mx.zeros(kernel_shape, dtype=self._param_dtype) + if bias_shape is not None: + self.bias = mx.zeros(bias_shape, dtype=self._param_dtype) + self._initialized = True + + def get_output_shape(self, input_shape, *, constants=None): + output_shape, _, _ = _compute_shapes( + self._equation, input_shape, self._output_shape_spec, self._bias_axes + ) + return output_shape + + def get_output_dtype(self, input_dtype, *, constants=None): + if self._compute_dtype is not None: + return self._compute_dtype + return self._param_dtype + + @types.check_layer + def layer(self, x, *, constants=None): + self._ensure_initialized(x.channel_shape) + compute_dtype = self.get_output_dtype(x.dtype) + + def einsum_fn(v): + y = mx.einsum(self._equation, v.astype(compute_dtype), self.kernel) + if self.bias is not None: + y = y + self.bias + if self._activation is not None: + y = self._activation(y) + return y + + if self.bias is not None or self._activation is not None: + return x.apply_values(einsum_fn) + return x.apply_values_masked(einsum_fn) + + @classmethod + def from_config(cls, config): + compute_dtype = getattr(config, 'compute_dtype', None) + if compute_dtype is not None: + compute_dtype = _to_mx_dtype(compute_dtype) + return cls( + 
equation=config.equation, + output_shape=config.output_shape, + bias_axes=config.bias_axes, + activation=init_mapping.map_activation(config.activation), + compute_dtype=compute_dtype, + param_dtype=_to_mx_dtype(config.param_dtype), + ) diff --git a/sequence_layers/mlx/dense_test.py b/sequence_layers/mlx/dense_test.py new file mode 100644 index 0000000..c18243f --- /dev/null +++ b/sequence_layers/mlx/dense_test.py @@ -0,0 +1,112 @@ +"""Tests for Dense MLX sequence layer.""" + +import mlx.core as mx +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized +from sequence_layers.mlx import basic_types as bt +from sequence_layers.mlx import dense +from sequence_layers.mlx import test_utils + + +class DenseTest(parameterized.TestCase): + + def test_layer(self): + layer = dense.Dense(in_features=4, features=8) + test_utils.verify_contract(self, layer, (4,)) + + def test_output_shape(self): + layer = dense.Dense(in_features=4, features=8) + self.assertEqual(layer.get_output_shape((4,)), (8,)) + + def test_no_bias(self): + layer = dense.Dense(in_features=4, features=8, use_bias=False) + test_utils.verify_contract(self, layer, (4,)) + + @parameterized.named_parameters( + ('relu', mx.array.__class__), + ('none', None), + ) + def test_activation(self, activation): + import mlx.nn as nn + + act = nn.relu if activation is not None else None + layer = dense.Dense(in_features=4, features=8, activation=act) + test_utils.verify_contract(self, layer, (4,)) + + +class DenseDeferredTest(parameterized.TestCase): + + def test_layer(self): + layer = dense.DenseDeferred(features=8) + test_utils.verify_contract(self, layer, (4,)) + + def test_from_config(self): + import sequence_layers.mlx + from sequence_layers.jax import dense as jax_dense + + config = jax_dense.Dense.Config(features=16) + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, dense.DenseDeferred) + + x = test_utils.random_sequence(2, 5, 8) + y = mlx_layer.layer(x) 
+ self.assertEqual(y.channel_shape, (16,)) + + +class EinsumDenseTest(parameterized.TestCase): + + def test_basic(self): + layer = dense.EinsumDense( + equation='...a,ab->...b', + output_shape=(8,), + ) + test_utils.verify_contract(self, layer, (4,)) + + def test_output_shape(self): + layer = dense.EinsumDense( + equation='...a,ab->...b', + output_shape=(8,), + ) + self.assertEqual(layer.get_output_shape((4,)), (8,)) + + def test_inferred_output(self): + layer = dense.EinsumDense( + equation='...ab,bc->...ac', + output_shape=(None, 7), + ) + self.assertEqual(layer.get_output_shape((3, 5)), (3, 7)) + + def test_with_bias(self): + layer = dense.EinsumDense( + equation='...a,ab->...b', + output_shape=(8,), + bias_axes='b', + ) + test_utils.verify_contract(self, layer, (4,)) + + def test_multi_dim(self): + layer = dense.EinsumDense( + equation='...abc,bd->...bd', + output_shape=(None, 6), + ) + self.assertEqual(layer.get_output_shape((2, 3, 5)), (3, 6)) + + def test_from_config(self): + import sequence_layers.mlx + from sequence_layers.jax import dense as jax_dense + + config = jax_dense.EinsumDense.Config( + equation='...a,ab->...b', + output_shape=(16,), + ) + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, dense.EinsumDense) + + x = test_utils.random_sequence(2, 5, 8) + y = mlx_layer.layer(x) + self.assertEqual(y.channel_shape, (16,)) + + +if __name__ == '__main__': + absltest.main() diff --git a/sequence_layers/mlx/dsp.py b/sequence_layers/mlx/dsp.py new file mode 100644 index 0000000..4d8bcbe --- /dev/null +++ b/sequence_layers/mlx/dsp.py @@ -0,0 +1,1193 @@ +"""DSP layers for MLX.""" + +import fractions +import math + +import mlx.core as mx +import numpy as np + +from sequence_layers.mlx import basic_types as bt +from sequence_layers.mlx import convolution as conv_utils +from sequence_layers.mlx import types + +Sequence = bt.Sequence +MaskedSequence = bt.MaskedSequence +PaddingMode = bt.PaddingMode + + +# 
--------------------------------------------------------------------------- +# Signal utilities +# --------------------------------------------------------------------------- + + +def hann_window(window_length, periodic=True, dtype=np.float32): + """Compute a periodic Hann window.""" + if window_length == 1: + return np.ones([1], dtype=dtype) + even = 1 - window_length % 2 + n = np.asarray(window_length + int(periodic) * even - 1, dtype=dtype) + count = np.arange(window_length, dtype=dtype) + return np.asarray(0.5 - 0.5 * np.cos(2 * np.pi * count / n), dtype) + + +def frame(values, frame_length, frame_step, pad_mode='valid', axis=1): + """Produce overlapping frames of a signal along `axis`. + + Args: + values: [..., T, ...] array. + frame_length: Length of each frame. + frame_step: Stride between frames. + pad_mode: Padding mode string or 'valid'. + axis: Axis along which to frame. + + Returns: + [..., num_frames, frame_length, ...] array. + """ + # Normalize axis. + if axis < 0: + axis += values.ndim + + t = values.shape[axis] + + # Apply padding if needed. + if isinstance(pad_mode, str) and pad_mode != PaddingMode.VALID.value: + pad_left, pad_right = conv_utils._explicit_padding( + pad_mode, frame_length, frame_step, 1 + ) + pad_widths = [(0, 0)] * values.ndim + pad_widths[axis] = (pad_left, pad_right) + values = mx.pad(values, pad_widths) + t = values.shape[axis] + + # Compute number of frames. + num_frames = max(0, (t - frame_length) // frame_step + 1) + + # Gather frames using indexing. + indices = ( + mx.arange(num_frames)[:, None] * frame_step + + mx.arange(frame_length)[None, :] + ) + # indices: [num_frames, frame_length] + + # Flatten, gather, reshape. + # Move axis to position 1 for easier manipulation. + if axis != 1: + perm = list(range(values.ndim)) + perm[1], perm[axis] = perm[axis], perm[1] + values = mx.transpose(values, perm) + + # values shape: [batch, t, ...] 
+ batch = values.shape[0] + rest_shape = values.shape[2:] + + # Gather: result [batch, num_frames, frame_length, ...] + # Use fancy indexing along axis 1. + result = values[:, indices.reshape(-1)] + result = result.reshape((batch, num_frames, frame_length) + rest_shape) + + if axis != 1: + # Move back. + perm = list(range(result.ndim)) + # axis was swapped to 1, new dims are at 1 and 2. + # Need to move them back so axis and axis+1 have the frame dims. + perm[1], perm[axis] = perm[axis], perm[1] + # Also move frame_length dim. + if axis > 1: + perm.insert(axis + 1, perm.pop(2)) + result = mx.transpose(result, perm) + + return result + + +def overlap_and_add(signal_arr, frame_step): + """Overlap-add framed signal. + + Args: + signal_arr: [..., frames, frame_length] array. + frame_step: Stride between frames. + + Returns: + [..., output_length] array where + output_length = (frames - 1) * frame_step + frame_length. + """ + shape = signal_arr.shape + outer_dims = shape[:-2] + frames = shape[-2] + frame_length = shape[-1] + output_length = frame_length + frame_step * (frames - 1) + + if frame_length == frame_step: + return signal_arr.reshape(outer_dims + (output_length,)) + + # General overlap-add via scatter. + outer_size = 1 + for d in outer_dims: + outer_size *= d + + flat = signal_arr.reshape(outer_size, frames, frame_length) + result = mx.zeros((outer_size, output_length), dtype=flat.dtype) + + for f in range(frames): + start = f * frame_step + result = result.at[:, start : start + frame_length].add(flat[:, f]) + + return result.reshape(outer_dims + (output_length,)) + + +def linear_to_mel_weight_matrix( + num_mel_bins, + num_spectrogram_bins, + sample_rate, + lower_edge_hertz, + upper_edge_hertz, + dtype=np.float64, +): + """Create a weight matrix for converting linear spectrogram to mel.""" + + # Mel scale conversion (HTK formula). 
+ def hz_to_mel(f): + return 2595.0 * np.log10(1.0 + f / 700.0) + + def mel_to_hz(m): + return 700.0 * (10.0 ** (m / 2595.0) - 1.0) + + nyquist = sample_rate / 2.0 + freq_bins = np.linspace(0, nyquist, num_spectrogram_bins) + + mel_low = hz_to_mel(lower_edge_hertz) + mel_high = hz_to_mel(upper_edge_hertz) + mel_points = np.linspace(mel_low, mel_high, num_mel_bins + 2) + hz_points = mel_to_hz(mel_points) + + weights = np.zeros((num_spectrogram_bins, num_mel_bins), dtype=dtype) + for i in range(num_mel_bins): + lower = hz_points[i] + center = hz_points[i + 1] + upper = hz_points[i + 2] + + # Rising slope. + for j in range(num_spectrogram_bins): + if lower <= freq_bins[j] <= center and center > lower: + weights[j, i] = (freq_bins[j] - lower) / (center - lower) + elif center < freq_bins[j] <= upper and upper > center: + weights[j, i] = (upper - freq_bins[j]) / (upper - center) + + return weights + + +# --------------------------------------------------------------------------- +# Delay +# --------------------------------------------------------------------------- + + +class Delay(types.PreservesShape, types.PreservesType, types.SequenceLayer): + """Delays input by `length` timesteps. + + Inserts `length` invalid timesteps at the start of the sequence. 
+ """ + + def __init__(self, *, length, delay_layer_output=True): + super().__init__() + if length < 0: + raise ValueError(f'length must be non-negative, got {length}.') + self.length = length + self.delay_layer_output = delay_layer_output + + @property + def input_latency(self): + return self.length + + @property + def output_latency(self): + return 0 if self.delay_layer_output else self.length + + def get_initial_state(self, batch_size, input_spec, *, constants=None): + if not self.length: + return () + return Sequence( + mx.zeros( + (batch_size, self.length) + input_spec.shape, + dtype=input_spec.dtype, + ), + mx.zeros( + (batch_size, self.length), + dtype=bt.MASK_DTYPE, + ), + ) + + @types.check_step + def step(self, x, state, *, constants=None): + if not self.length: + return x, state + state = state.concatenate(x) + t = x.shape[1] + y = Sequence(state.values[:, :t], state.mask[:, :t]) + state = Sequence(state.values[:, t:], state.mask[:, t:]) + return y, state + + @types.check_layer + def layer(self, x, *, constants=None): + if self.delay_layer_output: + return x.pad_time(self.length, 0, valid=False) + return x + + @classmethod + def from_config(cls, config): + return cls( + length=config.length, + delay_layer_output=config.delay_layer_output, + ) + + +# --------------------------------------------------------------------------- +# Lookahead +# --------------------------------------------------------------------------- + + +class Lookahead(types.PreservesShape, types.PreservesType, types.SequenceLayer): + """Drops the first `length` timesteps from the input.""" + + def __init__(self, *, length, preserve_length_in_layer=False): + super().__init__() + if length < 0: + raise ValueError(f'length must be non-negative, got {length}.') + self.length = length + self.preserve_length_in_layer = preserve_length_in_layer + + @property + def input_latency(self): + return 0 + + @property + def output_latency(self): + return self.length + + def get_initial_state(self, 
batch_size, input_spec, *, constants=None): + if not self.length: + return () + return mx.full( + (batch_size,), + self.length + 1, + dtype=mx.int32, + ) + + @types.check_step + def step(self, x, state, *, constants=None): + if not self.length: + return x, state + increments = mx.cumsum(x.mask.astype(mx.int32), axis=1) + countdown = mx.maximum(0, state[:, None] - increments) + mask = mx.logical_and(x.mask, countdown == 0) + y = Sequence(x.values, mask) + state = countdown[:, -1] + return y, state + + @types.check_layer + def layer(self, x, *, constants=None): + if not self.length: + return x + x = x[:, self.length :] + if self.preserve_length_in_layer: + return x.pad_time(0, self.length, valid=False) + return x + + @classmethod + def from_config(cls, config): + return cls( + length=config.length, + preserve_length_in_layer=config.preserve_length_in_layer, + ) + + +# --------------------------------------------------------------------------- +# Window +# --------------------------------------------------------------------------- + + +class Window(types.PreservesShape, types.PreservesType, types.Stateless): + """Applies a window function along a channel axis.""" + + def __init__(self, *, axis, window_fn=None): + super().__init__() + self._axis = axis + self._window_fn = window_fn or hann_window + + def _get_axis(self, x): + axis = self._axis + if axis < 0: + axis += x.ndim + if axis < 2: + raise ValueError( + f'Window axis must be a channel axis (>= 2), got {axis}.' 
+ ) + return axis + + @types.check_layer + def layer(self, x, *, constants=None): + axis = self._get_axis(x) + window_length = x.shape[axis] + window = self._window_fn(window_length) + window = mx.array(window, dtype=x.dtype) + shape = [1] * x.ndim + shape[axis] = window_length + window = window.reshape(shape) + return x.apply_values_masked(lambda v: v * window) + + @classmethod + def from_config(cls, config): + return cls( + axis=config.axis, + window_fn=config.window_fn, + ) + + +# --------------------------------------------------------------------------- +# Frame +# --------------------------------------------------------------------------- + + +class Frame(types.PreservesType, types.SequenceLayer): + """Produces overlapping frames of the input sequence.""" + + def __init__(self, *, frame_length, frame_step, padding='valid'): + super().__init__() + if frame_length <= 0: + raise ValueError(f'frame_length must be positive: {frame_length}') + if frame_step <= 0: + raise ValueError(f'frame_step must be positive: {frame_step}') + self.frame_length = frame_length + self.frame_step = frame_step + self.padding = padding + + @property + def supports_step(self): + return conv_utils._supports_step(self.padding) + + @property + def block_size(self): + return self.frame_step + + @property + def output_ratio(self): + return fractions.Fraction(1, self.frame_step) + + @property + def input_latency(self): + if self.padding in ( + PaddingMode.CAUSAL_VALID.value, + PaddingMode.CAUSAL.value, + PaddingMode.SEMICAUSAL.value, + ): + return 0 + elif self.padding in ( + PaddingMode.REVERSE_CAUSAL_VALID.value, + PaddingMode.REVERSE_CAUSAL.value, + ): + return self.frame_length - 1 + return 0 + + @property + def _buffer_width(self): + if self.padding == PaddingMode.SEMICAUSAL.value: + return max(self.frame_length - self.frame_step, 0) + elif self.padding in ( + PaddingMode.REVERSE_CAUSAL.value, + PaddingMode.REVERSE_CAUSAL_VALID.value, + ): + return (self.frame_length - 1) // 
self.frame_step * self.frame_step + elif self.padding in ( + PaddingMode.CAUSAL.value, + PaddingMode.CAUSAL_VALID.value, + ): + return self.frame_length - 1 + else: + raise ValueError(f'Unsupported step padding: {self.padding}') + + def get_output_shape(self, input_shape, *, constants=None): + return (self.frame_length,) + tuple(input_shape) + + def get_output_dtype(self, input_dtype, *, constants=None): + return input_dtype + + def get_initial_state(self, batch_size, input_spec, *, constants=None): + bw = self._buffer_width + if not bw: + return () + return conv_utils._compute_initial_state( + batch_size, + input_spec, + bw, + self.padding, + ) + + @types.check_step + def step(self, x, state, *, constants=None): + if self.frame_length > 1: + x = x.mask_invalid() + + bw = self._buffer_width + if bw: + state = state.concatenate(x) + else: + state = x + + values = frame( + state.values, + frame_length=self.frame_length, + frame_step=self.frame_step, + pad_mode=PaddingMode.VALID.value, + axis=1, + ) + mask = conv_utils._compute_conv_mask( + state.mask, + self.frame_length, + self.frame_step, + 1, + self.padding, + is_step=True, + ) + + if bw: + state = state[:, -bw:] + else: + state = () + + return Sequence(values, mask), state + + @types.check_layer + def layer(self, x, *, constants=None): + if self.frame_length > 1: + x = x.mask_invalid() + + values = frame( + x.values, + frame_length=self.frame_length, + frame_step=self.frame_step, + pad_mode=self.padding, + axis=1, + ) + mask = conv_utils._compute_conv_mask( + x.mask, + self.frame_length, + self.frame_step, + 1, + self.padding, + is_step=False, + ) + return Sequence(values, mask) + + @classmethod + def from_config(cls, config): + return cls( + frame_length=config.frame_length, + frame_step=config.frame_step, + padding=config.padding, + ) + + +# --------------------------------------------------------------------------- +# OverlapAdd +# --------------------------------------------------------------------------- + + 
+class OverlapAdd(types.PreservesType, types.SequenceLayer): + """Overlap-adds windows of [b, t, frame_length, ...]. + + Output shape: [b, to, ...] where + to = (ti - 1) * frame_step + frame_length. + """ + + def __init__(self, *, frame_length, frame_step, padding='valid'): + super().__init__() + if frame_length <= 0: + raise ValueError(f'frame_length must be positive: {frame_length}') + if frame_step <= 0: + raise ValueError(f'frame_step must be positive: {frame_step}') + if frame_length < frame_step: + raise ValueError('frame_length must be >= frame_step.') + if padding not in ( + PaddingMode.CAUSAL.value, + PaddingMode.VALID.value, + PaddingMode.SEMICAUSAL_FULL.value, + ): + raise ValueError(f'Unsupported padding: {padding}') + self.frame_length = frame_length + self.frame_step = frame_step + self.padding = padding + + @property + def supports_step(self): + return self.padding == PaddingMode.CAUSAL.value + + @property + def output_ratio(self): + return fractions.Fraction(self.frame_step) + + @property + def _buffer_width(self): + return max(0, self.frame_length - self.frame_step) + + def get_output_shape(self, input_shape, *, constants=None): + if not input_shape or input_shape[0] != self.frame_length: + raise ValueError( + f'OverlapAdd expects (frame_length, ...) input, got {input_shape}.' + ) + return tuple(input_shape[1:]) + + def get_output_dtype(self, input_dtype, *, constants=None): + return input_dtype + + def get_initial_state(self, batch_size, input_spec, *, constants=None): + if not input_shape_valid(input_spec.shape, self.frame_length): + raise ValueError(f'Invalid input_spec shape: {input_spec.shape}') + bw = self._buffer_width + if not bw: + return () + out_shape = tuple(input_spec.shape[1:]) + return mx.zeros( + (batch_size, bw) + out_shape, + dtype=input_spec.dtype, + ) + + @types.check_step + def step(self, x, state, *, constants=None): + if self.frame_length > 1: + x = x.mask_invalid() + + # Transpose [num_frames, frame_length] to end. 
+ if x.ndim > 3: + # Move axes 1,2 to -2,-1. + axes = list(range(x.ndim)) + axes.remove(1) + axes.remove(2) + axes.extend([1, 2]) + values = mx.transpose(x.values, axes) + else: + values = x.values + + values = overlap_and_add(values, self.frame_step) + + if x.ndim > 3: + # Move back. + values = mx.moveaxis(values, -1, 1) + + mask = conv_utils._compute_conv_transpose_mask( + x.mask, + self.frame_length, + self.frame_step, + 1, + self.padding, + ) + + bw = self._buffer_width + if bw: + time = x.shape[1] + # Pad state to extend to values length. + pad_right = max(0, values.shape[1] - bw) + pad_widths = [(0, 0)] * state.ndim + pad_widths[1] = (0, pad_right) + padded_state = mx.pad(state, pad_widths) + + values = values + padded_state + + output_samples = self.frame_step * time + output = values[:, :output_samples] + state = values[:, output_samples : output_samples + bw] + if state.shape[1] < bw: + pad_widths = [(0, 0)] * state.ndim + pad_widths[1] = (0, bw - state.shape[1]) + state = mx.pad(state, pad_widths) + values = output + + return Sequence(values, mask), state + + @types.check_layer + def layer(self, x, *, constants=None): + if self.frame_length > 1: + x = x.mask_invalid() + + if x.ndim > 3: + axes = list(range(x.ndim)) + axes.remove(1) + axes.remove(2) + axes.extend([1, 2]) + values = mx.transpose(x.values, axes) + else: + values = x.values + + values = overlap_and_add(values, self.frame_step) + + if x.ndim > 3: + values = mx.moveaxis(values, -1, 1) + + mask = conv_utils._compute_conv_transpose_mask( + x.mask, + self.frame_length, + self.frame_step, + 1, + self.padding, + ) + + trim = max(self.frame_length - self.frame_step, 0) + if self.padding == PaddingMode.CAUSAL.value: + if trim: + values = values[:, :-trim] + elif self.padding == PaddingMode.SEMICAUSAL_FULL.value: + if trim: + values = values[:, trim:] + mask = mask[:, trim:] + size = min(values.shape[1], mask.shape[1]) + return Sequence(values[:, :size], mask[:, :size]) + + return Sequence(values, 
mask) + + @classmethod + def from_config(cls, config): + return cls( + frame_length=config.frame_length, + frame_step=config.frame_step, + padding=config.padding, + ) + + +def input_shape_valid(shape, frame_length): + return shape and shape[0] == frame_length + + +# --------------------------------------------------------------------------- +# FFT layers +# --------------------------------------------------------------------------- + + +def _validate_and_normalize_axis(axis, input_shape): + """Normalize axis for FFT, ensuring it's a channel axis.""" + if axis < 0: + axis += len(input_shape) + if axis < 0 or axis >= len(input_shape): + raise ValueError(f'Axis {axis} out of range for shape {input_shape}.') + if axis in (0, 1): + raise ValueError(f'FFT over batch/time not allowed. Got axis={axis}.') + return axis + + +def _pad_or_truncate_for_fft(x, axis, required_length, padding): + """Pad or truncate sequence for FFT.""" + input_dim = x.shape[axis] + if input_dim == required_length: + return x + if input_dim < required_length: + pad_amount = required_length - input_dim + if padding == 'center': + pad_left = pad_amount // 2 + pad_right = pad_amount - pad_left + else: + pad_left = 0 + pad_right = pad_amount + pad_widths = [(0, 0)] * x.ndim + pad_widths[axis] = (pad_left, pad_right) + return x.apply_values_masked(mx.pad, pad_widths) + else: + # Truncate. 
+ if padding == 'center': + start = (input_dim - required_length) // 2 + else: + start = 0 + slices = [slice(None)] * x.ndim + slices[axis] = slice(start, start + required_length) + return x.apply_values_masked(lambda v: v[tuple(slices)]) + + +class FFT(types.PreservesType, types.Stateless): + """Applies FFT to a channel dimension.""" + + def __init__(self, *, fft_length=None, axis=-1, padding='right'): + super().__init__() + self.fft_length = fft_length + self._axis = axis + self._padding = padding + + def _get_output_length(self, input_size): + return self.fft_length or input_size + + def get_output_shape(self, input_shape, *, constants=None): + shape = list(input_shape) + axis = ( + _validate_and_normalize_axis( + self._axis, (None, None) + tuple(input_shape) + ) + - 2 + ) + shape[axis] = self._get_output_length(shape[axis]) + return tuple(shape) + + @types.check_layer + def layer(self, x, *, constants=None): + if x.ndim <= 2: + raise ValueError('FFT requires rank >= 3 input.') + axis = _validate_and_normalize_axis(self._axis, x.shape) + required = self._get_output_length(x.shape[axis]) + x = _pad_or_truncate_for_fft(x, axis, required, self._padding) + return x.apply_values(mx.fft.fft, axis=axis) + + @classmethod + def from_config(cls, config): + return cls( + fft_length=config.fft_length, + axis=config.axis, + padding=config.padding, + ) + + +class IFFT(types.PreservesType, types.Stateless): + """Applies IFFT to a channel dimension.""" + + def __init__( + self, + *, + fft_length=None, + frame_length=None, + axis=-1, + padding='right', + ): + super().__init__() + self.fft_length = fft_length + self.frame_length = frame_length + self._axis = axis + self._padding = padding + + def _get_output_length(self, input_size): + return self.frame_length or input_size + + def get_output_shape(self, input_shape, *, constants=None): + shape = list(input_shape) + axis = ( + _validate_and_normalize_axis( + self._axis, (None, None) + tuple(input_shape) + ) + - 2 + ) + 
shape[axis] = self._get_output_length(shape[axis]) + return tuple(shape) + + @types.check_layer + def layer(self, x, *, constants=None): + if x.ndim <= 2: + raise ValueError('IFFT requires rank >= 3 input.') + axis = _validate_and_normalize_axis(self._axis, x.shape) + x = x.apply_values(mx.fft.ifft, axis=axis) + required = self._get_output_length(x.shape[axis]) + return _pad_or_truncate_for_fft(x, axis, required, self._padding) + + @classmethod + def from_config(cls, config): + return cls( + fft_length=config.fft_length, + frame_length=config.frame_length, + axis=config.axis, + padding=config.padding, + ) + + +class RFFT(types.Stateless): + """Applies RFFT to a channel dimension.""" + + def __init__(self, *, fft_length=None, axis=-1, padding='right'): + super().__init__() + self.fft_length = fft_length + self._axis = axis + self._padding = padding + + def _get_fft_length(self, input_size): + return self.fft_length or input_size + + def _get_output_length(self, input_size): + return self._get_fft_length(input_size) // 2 + 1 + + def get_output_shape(self, input_shape, *, constants=None): + shape = list(input_shape) + axis = ( + _validate_and_normalize_axis( + self._axis, (None, None) + tuple(input_shape) + ) + - 2 + ) + shape[axis] = self._get_output_length(shape[axis]) + return tuple(shape) + + def get_output_dtype(self, input_dtype, *, constants=None): + return mx.complex64 + + @types.check_layer + def layer(self, x, *, constants=None): + if x.ndim <= 2: + raise ValueError('RFFT requires rank >= 3 input.') + axis = _validate_and_normalize_axis(self._axis, x.shape) + fft_len = self._get_fft_length(x.shape[axis]) + x = _pad_or_truncate_for_fft(x, axis, fft_len, self._padding) + + def rfft_fn(v): + if v.dtype == mx.bfloat16: + v = v.astype(mx.float32) + return mx.fft.rfft(v, n=fft_len, axis=axis) + + return x.apply_values(rfft_fn) + + @classmethod + def from_config(cls, config): + return cls( + fft_length=config.fft_length, + axis=config.axis, + 
padding=config.padding, + ) + + +class IRFFT(types.Stateless): + """Applies IRFFT to a channel dimension.""" + + def __init__( + self, + *, + fft_length=None, + frame_length=None, + axis=-1, + padding='right', + ): + super().__init__() + self.fft_length = fft_length + self.frame_length = frame_length + self._axis = axis + self._padding = padding + + def _get_fft_length(self, input_size): + return self.fft_length or (input_size - 1) * 2 + + def _get_output_length(self, input_size): + return self.frame_length or self._get_fft_length(input_size) + + def get_output_shape(self, input_shape, *, constants=None): + shape = list(input_shape) + axis = ( + _validate_and_normalize_axis( + self._axis, (None, None) + tuple(input_shape) + ) + - 2 + ) + shape[axis] = self._get_output_length(shape[axis]) + return tuple(shape) + + def get_output_dtype(self, input_dtype, *, constants=None): + return mx.float32 + + @types.check_layer + def layer(self, x, *, constants=None): + if x.ndim <= 2: + raise ValueError('IRFFT requires rank >= 3 input.') + axis = _validate_and_normalize_axis(self._axis, x.shape) + fft_len = self._get_fft_length(x.shape[axis]) + x = x.apply_values(lambda v: mx.fft.irfft(v, n=fft_len, axis=axis)) + required = self.frame_length or x.shape[axis] + return _pad_or_truncate_for_fft(x, axis, required, self._padding) + + @classmethod + def from_config(cls, config): + return cls( + fft_length=config.fft_length, + frame_length=config.frame_length, + axis=config.axis, + padding=config.padding, + ) + + +# --------------------------------------------------------------------------- +# STFT +# --------------------------------------------------------------------------- + + +class STFT(types.SequenceLayer): + """Short-Time Fourier Transform. + + Composes Frame -> Window -> RFFT. 
+ """ + + def __init__( + self, + *, + frame_length, + frame_step, + fft_length, + window_fn=None, + time_padding='reverse_causal_valid', + fft_padding='right', + output_magnitude=False, + ): + super().__init__() + self._frame_length = frame_length + self._frame_step = frame_step + self._fft_length = fft_length + self._window_fn = window_fn or hann_window + self._time_padding = time_padding + self._fft_padding = fft_padding + self._output_magnitude = output_magnitude + + self.framer = Frame( + frame_length=frame_length, + frame_step=frame_step, + padding=time_padding, + ) + self.fft = RFFT( + fft_length=fft_length, + axis=2, + padding=fft_padding, + ) + + @property + def supports_step(self): + return self.framer.supports_step + + @property + def block_size(self): + return self.framer.block_size + + @property + def output_ratio(self): + return self.framer.output_ratio + + @property + def input_latency(self): + return self.framer.input_latency + + def get_output_shape(self, input_shape, *, constants=None): + frame_shape = self.framer.get_output_shape(input_shape, constants=constants) + return self.fft.get_output_shape(frame_shape, constants=constants) + + def get_output_dtype(self, input_dtype, *, constants=None): + fft_dtype = self.fft.get_output_dtype(input_dtype, constants=constants) + if self._output_magnitude: + return mx.float32 + return fft_dtype + + def _apply_window(self, x): + if self._window_fn: + window = self._window_fn(self._frame_length) + window = mx.array(window, dtype=x.dtype) + shape = [1] * x.ndim + shape[2] = self._frame_length + window = window.reshape(shape) + return x.apply_values_masked(lambda v: v * window) + return x + + def get_initial_state(self, batch_size, input_spec, *, constants=None): + return self.framer.get_initial_state( + batch_size, input_spec, constants=constants + ) + + @types.check_step + def step(self, x, state, *, constants=None): + framed, state = self.framer.step(x, state, constants=constants) + framed = 
class InverseSTFT(types.SequenceLayer):
  """Inverse short-time Fourier transform: spectral frames back to a signal.

  Processing order is IRFFT (over axis 2) -> fit every frame to
  `frame_length` and multiply by the window -> OverlapAdd. Step support and
  the output ratio are whatever the OverlapAdd sub-layer reports.

  Args:
    frame_length: Length of each synthesis frame.
    frame_step: Hop size, forwarded to OverlapAdd.
    fft_length: FFT size, forwarded to IRFFT.
    window_fn: Callable taking a length and returning window values. Falls
      back to `hann_window` when None (or any other falsy value) is given.
    time_padding: Padding mode forwarded to OverlapAdd.
    fft_padding: Padding mode forwarded to IRFFT. It also decides where a
      too-short frame is padded: 'center' splits the padding between both
      ends, any other value pads at the end only.
  """

  def __init__(
      self,
      *,
      frame_length,
      frame_step,
      fft_length,
      window_fn=None,
      time_padding='causal',
      fft_padding='right',
  ):
    super().__init__()
    self._frame_length = frame_length
    self._frame_step = frame_step
    self._fft_length = fft_length
    self._window_fn = window_fn if window_fn else hann_window
    self._time_padding = time_padding
    self._fft_padding = fft_padding

    self.overlap_add = OverlapAdd(
        frame_length=frame_length,
        frame_step=frame_step,
        padding=time_padding,
    )
    self.irfft = IRFFT(
        fft_length=fft_length,
        frame_length=frame_length,
        axis=2,
        padding=fft_padding,
    )

  @property
  def supports_step(self):
    """Whether streaming is possible; decided by the OverlapAdd sub-layer."""
    return self.overlap_add.supports_step

  @property
  def block_size(self):
    """This layer always reports a block size of 1."""
    return 1

  @property
  def output_ratio(self):
    """Output/input rate ratio; decided by the OverlapAdd sub-layer."""
    return self.overlap_add.output_ratio

  @property
  def input_latency(self):
    """This layer always reports zero input latency."""
    return 0

  def get_output_shape(self, input_shape, *, constants=None):
    """Channel shape after IRFFT (first dim forced to frame_length) + OLA."""
    frame_shape = [
        *self.irfft.get_output_shape(input_shape, constants=constants)
    ]
    # Frames are cropped/padded to frame_length before overlap-add.
    frame_shape[0] = self._frame_length
    return self.overlap_add.get_output_shape(frame_shape, constants=constants)

  def get_output_dtype(self, input_dtype, *, constants=None):
    """Output dtype is whatever the IRFFT sub-layer produces."""
    return self.irfft.get_output_dtype(input_dtype, constants=constants)

  def _apply_window(self, irfft):
    """Crop or pad axis 2 to `frame_length`, then multiply by the window."""
    target = self._frame_length
    actual = irfft.shape[2]
    if actual > target:
      irfft = irfft.apply_values_masked(lambda v: v[:, :, :target])
    elif actual < target:
      missing = target - actual
      # 'center' puts floor(missing / 2) in front; otherwise pad at the end.
      front = missing // 2 if self._fft_padding == 'center' else 0
      widths = [(0, 0)] * irfft.ndim
      widths[2] = (front, missing - front)
      irfft = irfft.apply_values_masked(mx.pad, widths)

    if not self._window_fn:
      return irfft

    taps = mx.array(self._window_fn(target), dtype=irfft.dtype)
    # Shape the window so it broadcasts along axis 2 only.
    broadcast_shape = [1] * irfft.ndim
    broadcast_shape[2] = target
    taps = taps.reshape(broadcast_shape)
    return irfft.apply_values_masked(lambda v: v * taps)

  def get_initial_state(self, batch_size, input_spec, *, constants=None):
    """Initial OverlapAdd state for frames of length `frame_length`."""
    spec = self.irfft.get_output_spec(input_spec, constants=constants)
    frame_dims = list(spec.shape)
    frame_dims[0] = self._frame_length
    frame_spec = bt.ShapeDType(tuple(frame_dims), spec.dtype)
    return self.overlap_add.get_initial_state(
        batch_size, frame_spec, constants=constants
    )

  def _synthesis_frames(self, x, constants):
    """Validate rank, run the IRFFT and window the resulting frames."""
    if x.ndim < 3:
      raise ValueError(f'Expected [b,t,fft_bins,...] input, got {x.shape}.')
    return self._apply_window(self.irfft.layer(x, constants=constants))

  @types.check_step
  def step(self, x, state, *, constants=None):
    """Streaming call: returns (overlap-added output, new OverlapAdd state)."""
    frames = self._synthesis_frames(x, constants)
    signal, next_state = self.overlap_add.step(
        frames, state, constants=constants
    )
    return signal, next_state

  @types.check_layer
  def layer(self, x, *, constants=None):
    """Whole-sequence call: returns the overlap-added output."""
    frames = self._synthesis_frames(x, constants)
    return self.overlap_add.layer(frames, constants=constants)

  @classmethod
  def from_config(cls, config):
    """Build the layer from a config exposing the constructor's fields."""
    field_names = (
        'frame_length',
        'frame_step',
        'fft_length',
        'window_fn',
        'time_padding',
        'fft_padding',
    )
    return cls(**{name: getattr(config, name) for name in field_names})
def get_output_shape(self, input_shape, *, constants=None):
  """Return the channel shape produced for `input_shape`.

  The last entry of `input_shape` is replaced by `num_mel_bins`; every
  leading entry passes through unchanged.

  Args:
    input_shape: Non-empty sequence of channel dimensions.
    constants: Not used by this method.

  Returns:
    A tuple: `input_shape` without its last entry, then `num_mel_bins`.

  Raises:
    ValueError: If `input_shape` is empty.
  """
  if input_shape:
    return (*input_shape[:-1], self.num_mel_bins)
  raise ValueError('LinearToMelSpectrogram requires rank >= 1 input.')
def test_delays_output(self):
  """Layer-wise delay of 2 on 4 valid steps gives 6 steps, first 2 masked."""
  delay = 2
  num_steps = 4
  layer = dsp.Delay(length=delay)
  samples = mx.array([[[1.0], [2.0], [3.0], [4.0]]])
  inputs = bt.MaskedSequence(
      samples, mx.ones((1, num_steps), dtype=mx.bool_)
  )
  outputs = layer.layer(inputs)
  # The delay is prepended as invalid (masked-out) timesteps.
  self.assertEqual(outputs.shape, (1, num_steps + delay, 1))
  expected_mask = [False] * delay + [True] * num_steps
  np.testing.assert_array_equal(np.array(outputs.mask[0]), expected_mask)
class WindowTest(parameterized.TestCase):
  """Shape and config-dispatch checks for `dsp.Window`."""

  def test_layer(self):
    """Windowing along the last axis must not change the sequence shape."""
    # Input layout: [batch, time, frame_length].
    frames = test_utils.random_sequence(1, 4, 8)
    windowed = dsp.Window(axis=-1).layer(frames)
    self.assertEqual(windowed.shape, frames.shape)

  def test_from_config(self):
    """A JAX `Window.Config` built with backend='mlx' yields the MLX layer."""
    # Imported for its side effects only -- presumably this registers the MLX
    # backend that `config.make(backend='mlx')` dispatches to (TODO confirm).
    import sequence_layers.mlx
    from sequence_layers.jax import dsp as jax_dsp

    built = jax_dsp.Window.Config(axis=-1).make(backend='mlx')
    self.assertIsInstance(built, dsp.Window)
class OverlapAddTest(parameterized.TestCase):
  """Shape and config-dispatch checks for `dsp.OverlapAdd`."""

  _FRAME_LENGTH = 4
  _FRAME_STEP = 2
  _NUM_FRAMES = 3

  def _make_layer(self, padding):
    """OverlapAdd with the frame geometry shared by every test here."""
    return dsp.OverlapAdd(
        frame_length=self._FRAME_LENGTH,
        frame_step=self._FRAME_STEP,
        padding=padding,
    )

  def _ones_frames(self):
    """All-ones, fully valid input laid out as [batch, frames, frame_length]."""
    return bt.MaskedSequence(
        mx.ones((1, self._NUM_FRAMES, self._FRAME_LENGTH)),
        mx.ones((1, self._NUM_FRAMES), dtype=mx.bool_),
    )

  def test_valid(self):
    """'valid' keeps the full overlap-added length."""
    layer = self._make_layer('valid')
    summed = layer.layer(self._ones_frames())
    # (3 - 1) * 2 + 4 = 8 samples.
    self.assertEqual(summed.shape, (1, 8))

  def test_causal(self):
    """'causal' emits frames * frame_step samples."""
    layer = self._make_layer('causal')
    summed = layer.layer(self._ones_frames())
    # 3 frames * step 2 = 6 samples.
    self.assertEqual(summed.shape[1], 6)

  def test_output_shape(self):
    """The leading frame dimension is removed from the channel shape."""
    layer = self._make_layer('valid')
    self.assertEqual(layer.get_output_shape((4, 3)), (3,))

  def test_from_config(self):
    """A JAX `OverlapAdd.Config` built with backend='mlx' yields MLX."""
    # Imported for its side effects only -- presumably this registers the MLX
    # backend that `config.make(backend='mlx')` dispatches to (TODO confirm).
    import sequence_layers.mlx
    from sequence_layers.jax import dsp as jax_dsp

    config = jax_dsp.OverlapAdd.Config(
        frame_length=4,
        frame_step=2,
        padding='causal',
    )
    built = config.make(backend='mlx')
    self.assertIsInstance(built, dsp.OverlapAdd)
class IRFFTTest(parameterized.TestCase):
  """Shape, dtype and config-dispatch checks for `dsp.IRFFT`."""

  def test_layer(self):
    """5 complex bins come back as 8 real float32 samples."""
    layer = dsp.IRFFT()
    spectrum = (mx.random.normal(shape=(1, 4, 5)) + 0j).astype(mx.complex64)
    all_valid = mx.ones((1, 4), dtype=mx.bool_)
    signal = layer.layer(bt.MaskedSequence(spectrum, all_valid))
    self.assertEqual(signal.channel_shape, (8,))
    self.assertEqual(signal.dtype, mx.float32)

  def test_from_config(self):
    """A JAX `IRFFT.Config` built with backend='mlx' yields the MLX layer."""
    # Imported for its side effects only -- presumably this registers the MLX
    # backend that `config.make(backend='mlx')` dispatches to (TODO confirm).
    import sequence_layers.mlx
    from sequence_layers.jax import dsp as jax_dsp

    built = jax_dsp.IRFFT.Config().make(backend='mlx')
    self.assertIsInstance(built, dsp.IRFFT)
class InverseSTFTTest(parameterized.TestCase):
  """Shape and config-dispatch checks for `dsp.InverseSTFT`."""

  # Geometry shared by both tests.
  _LAYER_KWARGS = dict(
      frame_length=16,
      frame_step=8,
      fft_length=16,
      time_padding='causal',
  )

  def test_layer(self):
    """With causal padding, 4 frames at step 8 give 32 output samples."""
    layer = dsp.InverseSTFT(**self._LAYER_KWARGS)
    num_frames = 4
    # Input layout: [batch, frames, fft_bins] with fft_length // 2 + 1 bins.
    num_bins = self._LAYER_KWARGS['fft_length'] // 2 + 1
    spectrum = mx.random.normal(shape=(1, num_frames, num_bins)) + 0j
    spectrum = spectrum.astype(mx.complex64)
    all_valid = mx.ones((1, num_frames), dtype=mx.bool_)
    signal = layer.layer(bt.MaskedSequence(spectrum, all_valid))
    self.assertEqual(signal.shape[1], 32)

  def test_from_config(self):
    """A JAX `InverseSTFT.Config` built with backend='mlx' yields MLX."""
    # Imported for its side effects only -- presumably this registers the MLX
    # backend that `config.make(backend='mlx')` dispatches to (TODO confirm).
    import sequence_layers.mlx
    from sequence_layers.jax import dsp as jax_dsp

    config = jax_dsp.InverseSTFT.Config(**self._LAYER_KWARGS)
    built = config.make(backend='mlx')
    self.assertIsInstance(built, dsp.InverseSTFT)
def test_frame(self):
  """Framing 10 samples with length 4 and step 2 yields 4 frames."""
  num_samples, frame_length, frame_step = 10, 4, 2
  signal = mx.arange(num_samples).reshape(1, num_samples, 1)
  signal = signal.astype(mx.float32)
  framed = dsp.frame(signal, frame_length, frame_step)
  # (10 - 4) // 2 + 1 = 4 frames; layout [batch, frames, frame_length, 1].
  self.assertEqual(framed.shape, (1, 4, 4, 1))
+ self.assertTrue(np.all(w >= 0)) + + +if __name__ == '__main__': + absltest.main() diff --git a/sequence_layers/mlx/export.py b/sequence_layers/mlx/export.py new file mode 100644 index 0000000..82e7564 --- /dev/null +++ b/sequence_layers/mlx/export.py @@ -0,0 +1,194 @@ +"""Export MLX SequenceLayer step() to .mlxfn for streaming inference.""" + +import mlx.core as mx + +from sequence_layers.mlx import basic_types as bt + +Sequence = bt.Sequence + + +# --------------------------------------------------------------------------- +# State flattening / unflattening +# --------------------------------------------------------------------------- + + +def _flatten_state(state): + """Flatten a nested pytree state into a list of mx.array. + + Handles tuples, lists, and mx.array leaves. Empty tuples contribute + zero arrays. + + Args: + state: Nested tuple/list of mx.array. + + Returns: + (flat_arrays, structure) where structure encodes the nesting. + """ + flat = [] + + def _record(node): + if isinstance(node, mx.array): + flat.append(node) + return 'array' + elif isinstance(node, tuple): + children = [_record(child) for child in node] + return ('tuple', children) + elif isinstance(node, list): + children = [_record(child) for child in node] + return ('list', children) + else: + raise TypeError(f'Unsupported state node type: {type(node)}') + + structure = _record(state) + return flat, structure + + +def _unflatten_state(flat, structure): + """Reconstruct a nested state from a flat array list and structure. + + Args: + flat: List of mx.array. + structure: Structure descriptor from _flatten_state. + + Returns: + Nested tuple/list matching the original structure. 
+ """ + idx = [0] + + def _rebuild(struct): + if struct == 'array': + result = flat[idx[0]] + idx[0] += 1 + return result + elif isinstance(struct, tuple) and struct[0] == 'tuple': + return tuple(_rebuild(s) for s in struct[1]) + elif isinstance(struct, tuple) and struct[0] == 'list': + return [_rebuild(s) for s in struct[1]] + else: + raise ValueError(f'Unknown structure node: {struct}') + + result = _rebuild(structure) + if idx[0] != len(flat): + raise ValueError(f'Not all arrays consumed: used {idx[0]} of {len(flat)}') + return result + + +# --------------------------------------------------------------------------- +# Export +# --------------------------------------------------------------------------- + + +def _materialize_deferred(model, batch_size, input_spec, *, constants=None): + """Run a dummy forward pass to materialize all deferred layers.""" + x_values = mx.zeros( + (batch_size, 1) + input_spec.shape, dtype=input_spec.dtype + ) + x_mask = mx.ones((batch_size, 1), dtype=mx.bool_) + x = Sequence(x_values, x_mask) + state = model.get_initial_state(batch_size, input_spec, constants=constants) + model.step(x, state, constants=constants) + mx.eval(model.parameters()) + + +def get_initial_state_flat(model, batch_size, input_spec, *, constants=None): + """Get flattened initial state arrays and structure for a model. + + Args: + model: An MLX SequenceLayer. + batch_size: Batch size. + input_spec: A ShapeDType describing the input channels. + constants: Optional constants dict. + + Returns: + (flat_arrays, structure) where flat_arrays is a list of mx.array + and structure can be used with _unflatten_state. + """ + state = model.get_initial_state(batch_size, input_spec, constants=constants) + flat, structure = _flatten_state(state) + mx.eval(*flat) if flat else None + return flat, structure + + +def export_step( + model, + path, + batch_size, + input_spec, + *, + constants=None, + time_steps=1, +): + """Export model.step() to a .mlxfn file. 
+ + The exported function signature is: + (x_values, x_mask, *state_flat) -> (y_values, y_mask, *new_state_flat) + + Model weights are captured in the closure and embedded in the .mlxfn + file. State arrays (e.g. KV cache) are explicit I/O. + + The exported function uses fixed shapes (batch_size, time_steps). + For streaming generation, time_steps=1 is typical. + + Args: + model: An MLX SequenceLayer with supports_step. + path: Output file path (should end in .mlxfn). + batch_size: Batch size for the exported function. + input_spec: A ShapeDType describing the input channel shape and dtype. + constants: Optional constants dict for cross-attention. + time_steps: Number of time steps per call (default 1). + """ + if not model.supports_step: + raise ValueError(f'{model.__class__.__name__} does not support step().') + + # Materialize all deferred layers. + _materialize_deferred(model, batch_size, input_spec, constants=constants) + + # Get initial state and flatten. + flat_state, structure = get_initial_state_flat( + model, batch_size, input_spec, constants=constants + ) + + # Make sure all model params are evaluated. + mx.eval(model.parameters()) + + def step_fn(x_values, x_mask, *state_flat): + state = _unflatten_state(list(state_flat), structure) + x = Sequence(x_values, x_mask) + y, new_state = model.step(x, state, constants=constants) + new_flat, _ = _flatten_state(new_state) + return (y.values, y.mask, *new_flat) + + # Create example inputs for tracing. + x_values = mx.zeros( + (batch_size, time_steps) + input_spec.shape, + dtype=input_spec.dtype, + ) + x_mask = mx.ones((batch_size, time_steps), dtype=mx.bool_) + mx.eval(x_values, x_mask) + + mx.export_function( + path, + step_fn, + x_values, + x_mask, + *flat_state, + ) + + +def run_exported(imported_fn, x_values, x_mask, state_flat): + """Call an imported .mlxfn step function. + + Args: + imported_fn: A function from mx.import_function(). + x_values: Input values array [batch, time, ...channels]. 
+ x_mask: Input mask array [batch, time]. + state_flat: List of flat state arrays. + + Returns: + (y_values, y_mask, new_state_flat) where new_state_flat is a list. + """ + results = imported_fn(x_values, x_mask, *state_flat) + y_values = results[0] + y_mask = results[1] + new_state_flat = list(results[2:]) + return y_values, y_mask, new_state_flat diff --git a/sequence_layers/mlx/export_test.py b/sequence_layers/mlx/export_test.py new file mode 100644 index 0000000..175b58f --- /dev/null +++ b/sequence_layers/mlx/export_test.py @@ -0,0 +1,297 @@ +"""Tests for MLX export utilities.""" + +import os +import tempfile + +import mlx.core as mx +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + +from sequence_layers.mlx import basic_types as bt +from sequence_layers.mlx import export +from sequence_layers.mlx import test_utils + +Sequence = bt.Sequence +ShapeDType = bt.ShapeDType + + +class StateFlattenTest(parameterized.TestCase): + """Tests for state flatten/unflatten.""" + + def test_empty_tuple(self): + state = () + flat, structure = export._flatten_state(state) + self.assertEmpty(flat) + rebuilt = export._unflatten_state(flat, structure) + self.assertEqual(rebuilt, ()) + + def test_single_array(self): + arr = mx.zeros((2, 3)) + state = (arr,) + flat, structure = export._flatten_state(state) + self.assertLen(flat, 1) + rebuilt = export._unflatten_state(flat, structure) + self.assertIsInstance(rebuilt, tuple) + np.testing.assert_array_equal(rebuilt[0], arr) + + def test_nested_tuples(self): + a = mx.ones((2,)) + b = mx.zeros((3, 4)) + c = mx.full((1,), 5.0) + state = ((a, b), (c, ())) + flat, structure = export._flatten_state(state) + self.assertLen(flat, 3) + rebuilt = export._unflatten_state(flat, structure) + np.testing.assert_array_equal(rebuilt[0][0], a) + np.testing.assert_array_equal(rebuilt[0][1], b) + np.testing.assert_array_equal(rebuilt[1][0], c) + self.assertEqual(rebuilt[1][1], ()) + + def 
def test_serial_state_round_trip(self):
  """Simulate Serial state: tuple of per-layer states."""
  stateless = ()
  attention_state = (
      mx.zeros((2, 4, 4, 8)),  # keys
      mx.zeros((2, 4, 4, 8)),  # values
      mx.zeros((2, 4), dtype=mx.bool_),  # mask
      mx.zeros((2,), dtype=mx.int32),  # time
      mx.full((2, 1), -1, dtype=mx.int32),  # q_net_state
      mx.full((2, 1), -1, dtype=mx.int32),  # k_net_state
      (),  # v_net_state
  )
  # Identity (stateless), attention, Dense (stateless).
  state = (stateless, attention_state, stateless)

  flat, structure = export._flatten_state(state)
  # Only the six arrays inside the attention state are leaves.
  self.assertLen(flat, 6)

  rebuilt = export._unflatten_state(flat, structure)
  self.assertEqual(rebuilt[0], ())
  self.assertLen(rebuilt[1], 7)
  self.assertEqual(rebuilt[2], ())
class ExportAttentionTest(parameterized.TestCase):
  """Test exporting attention with KV cache."""

  def test_export_attention_multi_step(self):
    """Three exported steps must match three native steps on equal inputs."""
    from sequence_layers.mlx import attention

    num_steps = 3
    batch_size = 1
    layer = attention.DotProductSelfAttention(
        in_features=16,
        num_heads=2,
        units_per_head=8,
        max_past_horizon=32,
        max_future_horizon=0,
    )
    input_spec = ShapeDType((16,), mx.float32)

    def _next_input():
      # One random timestep, evaluated eagerly so that both runs consume
      # exactly the same values.
      step_input = test_utils.random_sequence(batch_size, 1, (16,))
      mx.eval(step_input.values, step_input.mask)
      return step_input

    with tempfile.TemporaryDirectory() as tmpdir:
      path = os.path.join(tmpdir, 'attn.mlxfn')
      export.export_step(
          layer,
          path,
          batch_size=batch_size,
          input_spec=input_spec,
      )
      imported = mx.import_function(path)

      # Both runs start from the same initial state.
      state = layer.get_initial_state(batch_size, input_spec)
      flat_state, structure = export._flatten_state(state)
      mx.eval(*flat_state)

      inputs = [_next_input() for _ in range(num_steps)]

      # Native stepping.
      native_outputs = []
      native_state = state
      for step_input in inputs:
        y, native_state = layer.step(step_input, native_state)
        mx.eval(y.values)
        native_outputs.append(np.array(y.values))

      # Exported stepping, threading the flat state through every call.
      exported_outputs = []
      exported_state = list(flat_state)
      for step_input in inputs:
        y_vals, y_mask, exported_state = export.run_exported(
            imported, step_input.values, step_input.mask, exported_state
        )
        mx.eval(y_vals)
        exported_outputs.append(np.array(y_vals))

      for i in range(min(len(native_outputs), len(exported_outputs))):
        np.testing.assert_allclose(
            exported_outputs[i],
            native_outputs[i],
            atol=1e-5,
            rtol=1e-5,
            err_msg=f'Step {i} mismatch',
        )
+ export._materialize_deferred(model, batch_size, input_spec) + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, 'serial.mlxfn') + export.export_step( + model, + path, + batch_size=batch_size, + input_spec=input_spec, + ) + + imported = mx.import_function(path) + flat_state, structure = export.get_initial_state_flat( + model, batch_size, input_spec + ) + + x = test_utils.random_sequence(batch_size, 1, (8,)) + mx.eval(x.values, x.mask) + + # Native. + state = model.get_initial_state(batch_size, input_spec) + y_native, _ = model.step(x, state) + + # Exported. + y_vals, y_mask, _ = export.run_exported( + imported, x.values, x.mask, flat_state + ) + mx.eval(y_native.values, y_vals) + + np.testing.assert_allclose( + np.array(y_vals), + np.array(y_native.values), + atol=1e-5, + rtol=1e-5, + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/sequence_layers/mlx/init_mapping.py b/sequence_layers/mlx/init_mapping.py new file mode 100644 index 0000000..33c2a95 --- /dev/null +++ b/sequence_layers/mlx/init_mapping.py @@ -0,0 +1,232 @@ +"""Mapping JAX/Flax initializers and activations to MLX equivalents.""" + +import functools +import math + +import jax +import jax.numpy as jnp +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from flax.linen import initializers as flax_init + + +def _variance_scaling(key, shape, dtype, mode, distribution, fan_in, fan_out): + """Variance scaling initializer core logic.""" + dtype = _to_mx_dtype(dtype) + if mode == 'fan_in': + denominator = max(fan_in, 1) + elif mode == 'fan_out': + denominator = max(fan_out, 1) + elif mode == 'fan_avg': + denominator = max((fan_in + fan_out) / 2.0, 1) + else: + raise ValueError(f'Unknown mode: {mode}') + + variance = 1.0 / denominator + if distribution == 'truncated_normal': + stddev = math.sqrt(variance) / 0.87962566103423978 + return ( + mx.random.truncated_normal(-2.0, 2.0, shape=shape, key=key).astype( + dtype + ) + * stddev + ) + elif distribution 
def _to_mx_dtype(dtype):
  """Convert any dtype (JAX, numpy, MLX, or a dtype name) to an MLX dtype.

  The dtype's name is taken from ``dtype.__name__`` when present and from
  ``str(dtype)`` otherwise. It is first looked up exactly; if that fails,
  the known names are searched as substrings (longest first) so decorated
  strings such as ``"<class 'numpy.float32'>"`` still resolve.

  Args:
    dtype: An ``mx.Dtype`` (returned unchanged) or any object whose
      ``__name__`` / ``str()`` contains a dtype name.

  Returns:
    The matching MLX dtype; ``mx.float32`` when nothing matches.
  """
  if isinstance(dtype, mx.Dtype):
    return dtype
  name = getattr(dtype, '__name__', '') or str(dtype)
  mapping = {
      'float32': mx.float32,
      'float16': mx.float16,
      'bfloat16': mx.bfloat16,
      'float64': mx.float32,  # Narrowed to float32 (kept from original).
      'int32': mx.int32,
      'int64': mx.int32,  # Narrowed to int32 (kept from original).
      'int16': mx.int16,
      'int8': mx.int8,
      'uint8': mx.uint8,
      'uint16': mx.uint16,
      'uint32': mx.uint32,
      'uint64': mx.uint32,  # Narrowed like int64 -> int32.
      'bool': mx.bool_,
      'bool_': mx.bool_,
      'complex64': mx.complex64,
  }

  # Exact match is the common case (e.g. jnp.float32.__name__ == 'float32').
  exact = mapping.get(name)
  if exact is not None:
    return exact

  # Substring fallback. Longest keys are tried first because several names
  # contain shorter ones ('bfloat16' contains 'float16', 'uint8' contains
  # 'int8', 'uint32' contains 'int32'); insertion order picked the wrong one.
  for key in sorted(mapping, key=len, reverse=True):
    if key in name:
      return mapping[key]
  return mx.float32
+ if 'lecun_normal' in name or 'lecun_normal' in qualname: + return _make_variance_scaling_init('fan_in', 'truncated_normal') + if 'lecun_uniform' in name or 'lecun_uniform' in qualname: + return _make_variance_scaling_init('fan_in', 'uniform') + if 'glorot_normal' in name or 'glorot_normal' in qualname: + return _make_variance_scaling_init('fan_avg', 'truncated_normal') + if 'glorot_uniform' in name or 'glorot_uniform' in qualname: + return _make_variance_scaling_init('fan_avg', 'uniform') + if 'he_normal' in name or 'he_normal' in qualname: + return _make_variance_scaling_init('fan_in', 'normal') + if 'he_uniform' in name or 'he_uniform' in qualname: + return _make_variance_scaling_init('fan_in', 'uniform') + if 'xavier_normal' in name or 'xavier_normal' in qualname: + return _make_variance_scaling_init('fan_avg', 'normal') + if 'xavier_uniform' in name or 'xavier_uniform' in qualname: + return _make_variance_scaling_init('fan_avg', 'uniform') + + # Check for variance_scaling in qualname/func. + if 'variance_scaling' in qualname or 'variance_scaling' in func_qualname: + return _make_variance_scaling_init('fan_in', 'truncated_normal') + + if 'zeros' in name or 'zeros' in qualname: + return _zeros_init + if 'ones' in name or 'ones' in qualname: + return _ones_init + + # Default fallback: lecun_normal equivalent. 
def _mlx_softmax(x, axis=-1):
  """Softmax over `axis`, defaulting to the last axis like jax.nn.softmax."""
  # mx.softmax with no axis reduces over the whole array, so the axis must be
  # passed explicitly to keep the JAX default.
  return mx.softmax(x, axis=axis)


def _mlx_log_softmax(x, axis=-1):
  """Log-softmax over `axis`, defaulting to the last axis like JAX."""
  # log_softmax(x) = x - logsumexp(x); computed in this form for stability.
  return x - mx.logsumexp(x, axis=axis, keepdims=True)


def _build_activation_map():
  """Populate `_ACTIVATION_MAP` (JAX activation -> MLX activation) once.

  The map is keyed by the JAX function object. It is filled on the first
  call and left untouched on later calls.
  """
  if _ACTIVATION_MAP:
    return
  _ACTIVATION_MAP.update({
      jax.nn.relu: nn.relu,
      # NOTE(review): jax.nn.gelu defaults to approximate=True while
      # mlx.nn.gelu looks like the exact form -- confirm whether
      # nn.gelu_approx is the intended counterpart.
      jax.nn.gelu: nn.gelu,
      jax.nn.silu: nn.silu,
      jax.nn.swish: nn.silu,  # swish == silu
      jax.nn.sigmoid: mx.sigmoid,
      jax.nn.tanh: mx.tanh,
      jax.nn.softmax: _mlx_softmax,
      jax.nn.elu: nn.elu,
      jax.nn.leaky_relu: nn.leaky_relu,
      jax.nn.log_softmax: _mlx_log_softmax,
  })
  # Also register jax.numpy functions that share a name with a mapped
  # jax.nn activation, without overriding entries already present.
  for jax_fn, mlx_fn in list(_ACTIVATION_MAP.items()):
    name = getattr(jax_fn, '__name__', '')
    jnp_fn = getattr(jnp, name, None)
    if jnp_fn is not None and jnp_fn not in _ACTIVATION_MAP:
      _ACTIVATION_MAP[jnp_fn] = mlx_fn
class RMSNormalization(types.PreservesType, types.StatelessPointwise):
  """RMS normalization over one or more channel axes.

  When normalizing over the last axis only, with a learned scale and no
  custom `scale_init`, the layer delegates to `mlx.nn.RMSNorm`. In every
  other case (multi-axis, custom initializer, or no scale) the
  normalization is computed manually in float32 and cast back to the
  input dtype.

  Parameters are created lazily on the first `layer()` call because the
  scale shape depends on the input shape. They are created exactly once,
  so a scale assigned afterwards is not overwritten by later calls.
  """

  def __init__(
      self,
      *,
      axis=-1,
      epsilon: float = 1e-6,
      use_scale: bool = True,
      param_dtype=mx.float32,
      scale_init=None,
  ):
    super().__init__()
    self._axis = axis
    self.epsilon = epsilon
    self.use_scale = use_scale
    self._param_dtype = param_dtype
    self._scale_init = scale_init
    # mlx.nn.RMSNorm created lazily since we need input shape.
    self._rms_norm = None
    self._use_builtin = False
    # Manual scale parameter; None until the manual path creates it.
    self._scale = None

  def _ensure_initialized(self, input_shape):
    """Create the scale parameter (builtin or manual) on first use only."""
    if not self.use_scale:
      return
    # Guard on both paths: without the `_scale` check the manual scale was
    # re-created on every call, discarding any value set in between.
    if self._rms_norm is not None or self._scale is not None:
      return
    axes = _normalize_axes(self._axis, input_shape)
    # mlx.nn.RMSNorm only supports normalizing over the last dim.
    if axes == (len(input_shape) - 1,) and self._scale_init is None:
      dims = input_shape[-1]
      self._rms_norm = nn.RMSNorm(dims, eps=self.epsilon)
      self._use_builtin = True
    else:
      # Multi-axis or custom init: manual scale parameter.
      scale_shape = tuple(input_shape[a] for a in axes)
      if self._scale_init is not None:
        key = mx.random.key(0)
        self._scale = self._scale_init(key, scale_shape, self._param_dtype)
      else:
        self._scale = mx.ones(scale_shape, dtype=self._param_dtype)

  @types.check_layer
  def layer(self, x, *, constants=None):
    """Normalize `x.values` over the configured axes; the mask is unchanged."""
    self._ensure_initialized(x.values.shape)

    if self._use_builtin and self._rms_norm is not None:
      return Sequence(self._rms_norm(x.values), x.mask)

    values = x.values
    axes = _normalize_axes(self._axis, values.shape)

    # Manual RMS norm in float32.
    v = values.astype(mx.float32)
    mean_sq = mx.mean(mx.square(v), axis=axes, keepdims=True)
    normed = v * mx.rsqrt(mean_sq + self.epsilon)
    normed = normed.astype(values.dtype)

    # Apply learned scale, broadcast to the normalized axes.
    if self.use_scale:
      scale = self._scale.astype(normed.dtype)
      shape = [1] * len(values.shape)
      for i, a in enumerate(axes):
        shape[a] = self._scale.shape[i]
      normed = normed * scale.reshape(shape)

    return Sequence(normed, x.mask)

  @classmethod
  def from_config(cls, config):
    """Build the layer from a config carrying the same field names."""
    axis = config.axis
    if not isinstance(axis, int):
      axis = tuple(axis)
    return cls(
        axis=axis,
        epsilon=config.epsilon,
        use_scale=config.use_scale,
        param_dtype=init_mapping._to_mx_dtype(config.param_dtype),
        scale_init=init_mapping.map_initializer(config.scale_init),
    )
class BatchNormalization(types.PreservesType, types.StatelessPointwise):
  """Batch Normalization (inference-only).

  Uses stored running mean/variance for normalization. Training-mode
  batch stat computation is not supported (MLX port is inference-only).
  Running stats are loaded via weight_converter.load_linen_params().

  Statistics and affine parameters are 1-D and sized by the first
  normalized axis only. All arithmetic is done in float32 and the result
  is cast back to the input dtype, so the output dtype always matches
  the input.
  """

  def __init__(
      self,
      *,
      axis=-1,
      epsilon: float = 0.001,
      use_bias: bool = True,
      use_scale: bool = True,
      param_dtype=mx.float32,
  ):
    super().__init__()
    self._axis = axis
    self.epsilon = epsilon
    self.use_bias = use_bias
    self.use_scale = use_scale
    self._param_dtype = param_dtype
    # All created lazily on first call, once the input shape is known.
    self._running_mean = None
    self._running_var = None
    self._scale = None
    self._bias = None

  def _ensure_initialized(self, input_shape):
    """Create identity running stats and affine parameters on first use."""
    if self._running_mean is not None:
      return
    axes = _normalize_axes(self._axis, input_shape)
    axis_size = input_shape[axes[0]]
    self._running_mean = mx.zeros((axis_size,), dtype=self._param_dtype)
    self._running_var = mx.ones((axis_size,), dtype=self._param_dtype)
    if self.use_scale:
      self._scale = mx.ones((axis_size,), dtype=self._param_dtype)
    if self.use_bias:
      self._bias = mx.zeros((axis_size,), dtype=self._param_dtype)

  @types.check_layer
  def layer(self, x, *, constants=None):
    """Normalize `x.values` with the running statistics; mask is unchanged."""
    self._ensure_initialized(x.values.shape)

    values = x.values
    axes = _normalize_axes(self._axis, values.shape)

    # Broadcast running stats over batch and time.
    shape = [1] * len(values.shape)
    shape[axes[0]] = self._running_mean.shape[0]

    mean = self._running_mean.reshape(shape).astype(mx.float32)
    var = self._running_var.reshape(shape).astype(mx.float32)

    normed = (values.astype(mx.float32) - mean) * mx.rsqrt(var + self.epsilon)

    # Apply the affine in float32 and cast once at the end. Casting first
    # and then multiplying by parameters of `param_dtype` promoted
    # low-precision inputs to the parameter dtype.
    if self.use_scale and self._scale is not None:
      normed = normed * self._scale.reshape(shape).astype(mx.float32)
    if self.use_bias and self._bias is not None:
      normed = normed + self._bias.reshape(shape).astype(mx.float32)

    return Sequence(normed.astype(values.dtype), x.mask)

  @classmethod
  def from_config(cls, config):
    """Build the layer from a config carrying the same field names."""
    return cls(
        axis=config.axis,
        epsilon=config.epsilon,
        use_bias=config.use_bias,
        use_scale=config.use_scale,
        param_dtype=init_mapping._to_mx_dtype(config.param_dtype),
    )
+ """ + + def __init__( + self, + *, + num_groups: int, + axis: int = -1, + epsilon: float = 1e-6, + use_bias: bool = True, + use_scale: bool = True, + param_dtype=mx.float32, + ): + super().__init__() + if num_groups <= 0: + raise ValueError(f'{num_groups=} must be positive.') + self._num_groups = num_groups + self._axis = axis + self.epsilon = epsilon + self.use_bias = use_bias + self.use_scale = use_scale + self._param_dtype = param_dtype + self._scale = None + self._bias = None + + def _ensure_initialized(self, input_shape): + if self._scale is not None or self._bias is not None: + return + axes = _normalize_axes(self._axis, input_shape) + axis_size = input_shape[axes[0]] + if self.use_scale: + self._scale = mx.ones((axis_size,), dtype=self._param_dtype) + if self.use_bias: + self._bias = mx.zeros((axis_size,), dtype=self._param_dtype) + + @types.check_layer + def layer(self, x, *, constants=None): + self._ensure_initialized(x.values.shape) + + values = x.values + axes = _normalize_axes(self._axis, values.shape) + axis = axes[0] + axis_size = values.shape[axis] + + if axis_size % self._num_groups != 0: + raise ValueError( + f'Input axis {axis} size {axis_size} must be' + f' divisible by {self._num_groups}.' + ) + group_size = axis_size // self._num_groups + + # Reshape to [... num_groups, group_size ...] + shape = list(values.shape) + grouped_shape = ( + shape[:axis] + [self._num_groups, group_size] + shape[axis + 1 :] + ) + grouped = mx.reshape(values, grouped_shape) + + # Normalize over group_size only (per-timestep). + g = grouped.astype(mx.float32) + reduce_axis = axis + 1 + mean = mx.mean(g, axis=reduce_axis, keepdims=True) + variance = mx.mean(mx.square(g - mean), axis=reduce_axis, keepdims=True) + normed = (g - mean) * mx.rsqrt(variance + self.epsilon) + normed = mx.reshape(normed.astype(values.dtype), values.shape) + + # Apply learned scale and bias. 
+ if self.use_scale and self._scale is not None: + scale_shape = [1] * len(values.shape) + scale_shape[axis] = axis_size + normed = normed * self._scale.reshape(scale_shape) + if self.use_bias and self._bias is not None: + bias_shape = [1] * len(values.shape) + bias_shape[axis] = axis_size + normed = normed + self._bias.reshape(bias_shape) + + return Sequence(normed, x.mask) + + @classmethod + def from_config(cls, config): + from sequence_layers.mlx.init_mapping import _to_mx_dtype + + return cls( + num_groups=config.num_groups, + axis=config.axis, + epsilon=config.epsilon, + use_bias=config.use_bias, + use_scale=config.use_scale, + param_dtype=_to_mx_dtype(config.param_dtype), + ) diff --git a/sequence_layers/mlx/normalization_test.py b/sequence_layers/mlx/normalization_test.py new file mode 100644 index 0000000..a068a4d --- /dev/null +++ b/sequence_layers/mlx/normalization_test.py @@ -0,0 +1,187 @@ +"""Tests for normalization MLX sequence layers.""" + +import mlx.core as mx +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized +from sequence_layers.mlx import normalization +from sequence_layers.mlx import test_utils + + +class L2NormalizeTest(parameterized.TestCase): + + def test_layer(self): + layer = normalization.L2Normalize() + test_utils.verify_contract(self, layer, (8,)) + + def test_normalizes(self): + layer = normalization.L2Normalize() + values = mx.array([[[3.0, 4.0]]]) + mask = mx.ones((1, 1), dtype=mx.bool_) + x = test_utils.random_sequence(1, 1, 2).unmask() + x = type(x)(values, mask) + y = layer.layer(x) + # L2 norm of [3, 4] is 5, so output should be [0.6, 0.8]. 
+ np.testing.assert_allclose(np.array(y.values), [[[0.6, 0.8]]], atol=1e-6) + + def test_multi_axis(self): + layer = normalization.L2Normalize(axis=(-2, -1)) + test_utils.verify_contract(self, layer, (4, 3)) + + def test_from_config(self): + import sequence_layers.mlx + from sequence_layers.jax import normalization as jax_norm + + config = jax_norm.L2Normalize.Config() + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, normalization.L2Normalize) + test_utils.verify_contract(self, mlx_layer, (8,)) + + +class RMSNormalizationTest(parameterized.TestCase): + + def test_layer(self): + layer = normalization.RMSNormalization() + test_utils.verify_contract(self, layer, (8,)) + + def test_no_scale(self): + layer = normalization.RMSNormalization(use_scale=False) + test_utils.verify_contract(self, layer, (8,)) + + def test_normalizes(self): + layer = normalization.RMSNormalization(use_scale=False) + values = mx.array([[[1.0, 2.0, 3.0, 4.0]]]) + mask = mx.ones((1, 1), dtype=mx.bool_) + x = test_utils.random_sequence(1, 1, 4).unmask() + x = type(x)(values, mask) + y = layer.layer(x) + # After RMS norm, the RMS of the output should be ~1. 
+ rms = float(mx.sqrt(mx.mean(mx.square(y.values)))) + np.testing.assert_allclose(rms, 1.0, atol=0.1) + + def test_from_config(self): + import sequence_layers.mlx + from sequence_layers.jax import normalization as jax_norm + + config = jax_norm.RMSNormalization.Config() + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, normalization.RMSNormalization) + test_utils.verify_contract(self, mlx_layer, (8,)) + + +class LayerNormalizationTest(parameterized.TestCase): + + def test_layer(self): + layer = normalization.LayerNormalization() + test_utils.verify_contract(self, layer, (8,)) + + def test_no_affine(self): + layer = normalization.LayerNormalization(use_scale=False, use_bias=False) + test_utils.verify_contract(self, layer, (8,)) + + def test_normalizes(self): + layer = normalization.LayerNormalization(use_scale=False, use_bias=False) + values = mx.array([[[1.0, 2.0, 3.0, 4.0]]]) + mask = mx.ones((1, 1), dtype=mx.bool_) + x = test_utils.random_sequence(1, 1, 4).unmask() + x = type(x)(values, mask) + y = layer.layer(x) + # After layer norm, mean should be ~0, std should be ~1. 
+ mean = float(mx.mean(y.values)) + std = float(mx.sqrt(mx.mean(mx.square(y.values - mean)))) + np.testing.assert_allclose(mean, 0.0, atol=1e-5) + np.testing.assert_allclose(std, 1.0, atol=0.15) + + def test_from_config(self): + import sequence_layers.mlx + from sequence_layers.jax import normalization as jax_norm + + config = jax_norm.LayerNormalization.Config() + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, normalization.LayerNormalization) + test_utils.verify_contract(self, mlx_layer, (8,)) + + +class BatchNormalizationTest(parameterized.TestCase): + + def test_layer(self): + layer = normalization.BatchNormalization() + test_utils.verify_contract(self, layer, (8,)) + + def test_no_affine(self): + layer = normalization.BatchNormalization(use_scale=False, use_bias=False) + test_utils.verify_contract(self, layer, (8,)) + + def test_normalizes(self): + layer = normalization.BatchNormalization(use_scale=False, use_bias=False) + # Set known running stats. + layer._ensure_initialized((1, 1, 4)) + layer._running_mean = mx.array([1.0, 2.0, 3.0, 4.0]) + layer._running_var = mx.array([1.0, 1.0, 1.0, 1.0]) + values = mx.array([[[1.0, 2.0, 3.0, 4.0]]]) + mask = mx.ones((1, 1), dtype=mx.bool_) + x = type(test_utils.random_sequence(1, 1, 4))(values, mask) + y = layer.layer(x) + # (x - mean) / sqrt(var + eps) should be ~0 + np.testing.assert_allclose(y.values, np.zeros((1, 1, 4)), atol=1e-3) + + def test_scale_and_bias(self): + layer = normalization.BatchNormalization() + layer._ensure_initialized((1, 1, 4)) + layer._running_mean = mx.zeros((4,)) + layer._running_var = mx.ones((4,)) + layer._scale = mx.array([2.0, 2.0, 2.0, 2.0]) + layer._bias = mx.array([1.0, 1.0, 1.0, 1.0]) + values = mx.array([[[1.0, 0.0, -1.0, 2.0]]]) + mask = mx.ones((1, 1), dtype=mx.bool_) + x = type(test_utils.random_sequence(1, 1, 4))(values, mask) + y = layer.layer(x) + # (x - 0) / sqrt(1 + 0.001) * 2 + 1 + scale = 2.0 / float(mx.sqrt(mx.array(1.001))) + expected = 
np.array([[[ + 1.0 * scale + 1.0, + 0.0 * scale + 1.0, + -1.0 * scale + 1.0, + 2.0 * scale + 1.0, + ]]]) + np.testing.assert_allclose(y.values, expected, atol=1e-5) + + def test_from_config(self): + import sequence_layers.mlx + from sequence_layers.jax import normalization as jax_norm + + config = jax_norm.BatchNormalization.Config() + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, normalization.BatchNormalization) + test_utils.verify_contract(self, mlx_layer, (8,)) + + +class GroupNormalizationTest(parameterized.TestCase): + + def test_layer(self): + layer = normalization.GroupNormalization(num_groups=2) + test_utils.verify_contract(self, layer, (8,)) + + def test_no_affine(self): + layer = normalization.GroupNormalization( + num_groups=4, use_scale=False, use_bias=False + ) + test_utils.verify_contract(self, layer, (8,)) + + def test_num_groups_must_divide(self): + layer = normalization.GroupNormalization(num_groups=3) + with self.assertRaises(ValueError): + layer.layer(test_utils.random_sequence(1, 2, 8)) + + def test_from_config(self): + import sequence_layers.mlx + from sequence_layers.jax import normalization as jax_norm + + config = jax_norm.GroupNormalization.Config(num_groups=2) + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, normalization.GroupNormalization) + test_utils.verify_contract(self, mlx_layer, (8,)) + + +if __name__ == '__main__': + absltest.main() diff --git a/sequence_layers/mlx/pooling.py b/sequence_layers/mlx/pooling.py new file mode 100644 index 0000000..364dd97 --- /dev/null +++ b/sequence_layers/mlx/pooling.py @@ -0,0 +1,439 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def _reduce_window_1d(values, pool_size, stride, dilation_rate, reduce_fn):
  """Slide a 1D window along the time axis and reduce each window.

  Args:
    values: [batch, time, *channels] input tensor (already padded).
    pool_size: Number of taps in each window.
    stride: Step between consecutive window starts.
    dilation_rate: Spacing between taps inside a window.
    reduce_fn: Function(array, axis) -> array applied over the tap axis.

  Returns:
    [batch, num_outputs, *channels]
  """
  if pool_size == 1:
    # A one-tap window has nothing to reduce; only subsample by stride.
    return values if stride == 1 else values[:, ::stride]

  time_steps = values.shape[1]
  span = _effective_kernel_size(pool_size, dilation_rate)
  n_windows = max(0, (time_steps - span) // stride + 1)
  if not n_windows:
    # Input shorter than one window: empty time axis, same dtype.
    empty_shape = (values.shape[0], 0) + values.shape[2:]
    return mx.zeros(empty_shape, dtype=values.dtype)

  # index[i, j] = i * stride + j * dilation_rate: time position of tap j
  # in window i.
  taps = mx.arange(pool_size) * dilation_rate
  starts = mx.arange(n_windows) * stride
  window_index = starts[:, None] + taps[None, :]
  windows = values[:, window_index]  # [b, n, pool_size, *channels]
  return reduce_fn(windows, axis=2)
def _compute_initial_state_pooling(
    batch_size, input_spec, buf_width, padding, pad_value=0.0
):
  """Build the initial step-mode buffer for a pooling layer.

  The buffer values are filled with `pad_value`. The buffer mask is all
  ones for the *_valid / semicausal_full modes and all zeros for the
  causal, reverse_causal and semicausal modes.

  Raises:
    ValueError: If `padding` is not one of the modes listed above.
  """
  starts_valid = (
      PaddingMode.CAUSAL_VALID.value,
      PaddingMode.REVERSE_CAUSAL_VALID.value,
      PaddingMode.SEMICAUSAL_FULL.value,
  )
  starts_masked = (
      PaddingMode.CAUSAL.value,
      PaddingMode.REVERSE_CAUSAL.value,
      PaddingMode.SEMICAUSAL.value,
  )
  buffer_shape = (batch_size, buf_width)

  if padding in starts_valid:
    mask = mx.ones(buffer_shape, dtype=bt.MASK_DTYPE)
  elif padding in starts_masked:
    mask = mx.zeros(buffer_shape, dtype=bt.MASK_DTYPE)
  else:
    raise ValueError(f'Step not supported with padding: {padding}')

  values = mx.full(
      buffer_shape + input_spec.shape,
      pad_value,
      dtype=input_spec.dtype,
  )
  # Return Sequence (not MaskedSequence) — matches JAX's .unmask().
  return Sequence(values, mask)
+ def get_output_shape(self, input_shape, *, constants=None): + return tuple(input_shape) + + def get_initial_state(self, batch_size, input_spec, *, constants=None): + bw = _buffer_width( + self._padding, + self._pool_size, + self._strides, + self._dilation_rate, + ) + if not bw: + return () + return _compute_initial_state_pooling( + batch_size, + input_spec, + bw, + self._padding, + pad_value=self._pad_value(input_spec.dtype), + ) + + @types.check_layer + def layer(self, x, *, constants=None): + pad_value = self._pad_value(x.dtype) + if self._pool_size > 1: + x = x.mask_invalid(pad_value) + + pad_left, pad_right = _explicit_padding( + self._padding, + self._pool_size, + self._strides, + self._dilation_rate, + ) + values = x.values + if pad_left > 0 or pad_right > 0: + pad_widths = [(0, 0), (pad_left, pad_right)] + [(0, 0)] * ( + values.ndim - 2 + ) + values = mx.pad(values, pad_widths, constant_values=pad_value) + + values = _reduce_window_1d( + values, + self._pool_size, + self._strides, + self._dilation_rate, + self._reduce, + ) + mask = _compute_conv_mask( + x.mask, + self._pool_size, + self._strides, + self._dilation_rate, + self._padding, + is_step=False, + ) + return Sequence(values, mask) + + @types.check_step + def step(self, x, state, *, constants=None): + pad_value = self._pad_value(x.dtype) + ek = _effective_kernel_size(self._pool_size, self._dilation_rate) + if ek > 1: + x = x.mask_invalid(pad_value) + + bw = _buffer_width( + self._padding, + self._pool_size, + self._strides, + self._dilation_rate, + ) + + if bw: + state = state.concatenate(x) + else: + state = x + + values = _reduce_window_1d( + state.values, + self._pool_size, + self._strides, + self._dilation_rate, + self._reduce, + ) + mask = _compute_conv_mask( + state.mask, + self._pool_size, + self._strides, + self._dilation_rate, + self._padding, + is_step=True, + ) + + if bw: + state = state[:, -bw:] + else: + state = () + + return Sequence(values, mask), state + + +class 
MaxPooling1D(_Pooling1D): + """1D max pooling layer.""" + + def _pad_value(self, dtype): + return float('-inf') + + def _reduce(self, gathered, axis): + return mx.max(gathered, axis=axis) + + @classmethod + def from_config(cls, config): + return cls( + pool_size=config.pool_size, + strides=config.strides, + dilation_rate=config.dilation_rate, + padding=config.padding, + ) + + +class MinPooling1D(_Pooling1D): + """1D min pooling layer.""" + + def _pad_value(self, dtype): + return float('inf') + + def _reduce(self, gathered, axis): + return mx.min(gathered, axis=axis) + + @classmethod + def from_config(cls, config): + return cls( + pool_size=config.pool_size, + strides=config.strides, + dilation_rate=config.dilation_rate, + padding=config.padding, + ) + + +class AveragePooling1D(_Pooling1D): + """1D average pooling layer.""" + + def __init__( + self, + pool_size, + strides=1, + dilation_rate=1, + padding='valid', + masked_average=False, + ): + super().__init__(pool_size, strides, dilation_rate, padding) + self._masked_average = masked_average + + def _pad_value(self, dtype): + return 0.0 + + def _reduce(self, gathered, axis): + return mx.mean(gathered, axis=axis) + + @types.check_layer + def layer(self, x, *, constants=None): + if not self._masked_average: + return _Pooling1D.layer.__wrapped__(self, x, constants=constants) + + # Masked average: divide by count of valid elements. 
  @types.check_step
  def step(self, x, state, *, constants=None):
    """Processes one block of timesteps, carrying a buffer between calls.

    Args:
      x: The new block of input timesteps.
      state: The buffered sequence from the previous call (or the initial
        state). Only used when the buffer width is non-zero.
      constants: Forwarded to the base implementation on the unmasked path;
        otherwise unused.

    Returns:
      A tuple `(output, new_state)`. `new_state` is the trailing `bw`
      timesteps of the buffered sequence, or `()` when no buffer is needed.
    """
    if not self._masked_average:
      # Call the undecorated base method directly; this method is already
      # wrapped by `check_step` (presumably to avoid running the checks twice).
      return _Pooling1D.step.__wrapped__(self, x, state, constants=constants)

    # Masked average step: zero out invalid timesteps so they add nothing to
    # the window sums. Skipped when the effective kernel is a single step.
    ek = _effective_kernel_size(self._pool_size, self._dilation_rate)
    if ek > 1:
      x = x.mask_invalid(0.0)

    bw = _buffer_width(
        self._padding,
        self._pool_size,
        self._strides,
        self._dilation_rate,
    )

    # Prepend the buffered history (if any) to the new block.
    if bw:
      state = state.concatenate(x)
    else:
      state = x

    # The mask is passed alongside the values so the average can be taken
    # over valid elements only.
    values = _reduce_window_masked_avg_1d(
        state.values,
        state.mask,
        self._pool_size,
        self._strides,
        self._dilation_rate,
    )
    mask = _compute_conv_mask(
        state.mask,
        self._pool_size,
        self._strides,
        self._dilation_rate,
        self._padding,
        is_step=True,
    )

    # Keep only the last `bw` timesteps as history for the next call.
    if bw:
      state = state[:, -bw:]
    else:
      state = ()

    return Sequence(values, mask), state
b/sequence_layers/mlx/pooling_test.py @@ -0,0 +1,260 @@ +"""Tests for pooling MLX sequence layers.""" + +import mlx.core as mx +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized +from sequence_layers.mlx import pooling +from sequence_layers.mlx import test_utils + + +class MaxPooling1DTest(parameterized.TestCase): + + @parameterized.parameters( + ('causal',), + ('semicausal',), + ) + def test_causal_paddings(self, padding): + layer = pooling.MaxPooling1D(pool_size=3, padding=padding) + test_utils.verify_contract( + self, + layer, + (4,), + atol=1e-5, + rtol=1e-5, + ) + + def test_valid(self): + layer = pooling.MaxPooling1D(pool_size=3, padding='valid') + x = test_utils.random_sequence(1, 8, 4) + y = layer.layer(x) + self.assertEqual(y.channel_shape, (4,)) + self.assertEqual(y.shape[1], 6) # 8 - 3 + 1 + + def test_same(self): + layer = pooling.MaxPooling1D(pool_size=3, padding='same') + x = test_utils.random_sequence(1, 8, 4) + y = layer.layer(x) + self.assertEqual(y.shape[1], 8) + + def test_stride(self): + layer = pooling.MaxPooling1D( + pool_size=3, + strides=2, + padding='causal', + ) + test_utils.verify_contract( + self, + layer, + (4,), + time=8, + atol=1e-5, + rtol=1e-5, + ) + + def test_dilation(self): + layer = pooling.MaxPooling1D( + pool_size=3, + dilation_rate=2, + padding='causal', + ) + test_utils.verify_contract( + self, + layer, + (4,), + atol=1e-5, + rtol=1e-5, + ) + + def test_max_values(self): + values = mx.array([[[1.0], [3.0], [2.0], [5.0], [4.0]]]) + mask = mx.ones((1, 5), dtype=mx.bool_) + x = test_utils.random_sequence(1, 5, 1) + x = type(x)(values, mask) + layer = pooling.MaxPooling1D(pool_size=3, padding='valid') + y = layer.layer(x) + expected = np.array([[[3.0], [5.0], [5.0]]]) + np.testing.assert_allclose(y.values, expected) + + def test_pool_size_1(self): + layer = pooling.MaxPooling1D(pool_size=1) + test_utils.verify_contract(self, layer, (4,)) + + def test_output_shape(self): + layer = 
pooling.MaxPooling1D(pool_size=3, padding='causal') + self.assertEqual(layer.get_output_shape((4,)), (4,)) + + def test_from_config(self): + import sequence_layers.mlx + from sequence_layers.jax import pooling as jax_pooling + + config = jax_pooling.MaxPooling1D.Config( + pool_size=3, + padding='causal', + ) + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, pooling.MaxPooling1D) + test_utils.verify_contract( + self, + mlx_layer, + (4,), + atol=1e-5, + rtol=1e-5, + ) + + +class MinPooling1DTest(parameterized.TestCase): + + @parameterized.parameters( + ('causal',), + ('semicausal',), + ) + def test_causal_paddings(self, padding): + layer = pooling.MinPooling1D(pool_size=3, padding=padding) + test_utils.verify_contract( + self, + layer, + (4,), + atol=1e-5, + rtol=1e-5, + ) + + def test_reverse_causal_layer(self): + layer = pooling.MinPooling1D( + pool_size=3, + padding='reverse_causal', + ) + x = test_utils.random_sequence(1, 8, 4) + y = layer.layer(x) + self.assertEqual(y.shape[1], 8) + self.assertEqual(y.channel_shape, (4,)) + + def test_valid(self): + layer = pooling.MinPooling1D(pool_size=3, padding='valid') + x = test_utils.random_sequence(1, 8, 4) + y = layer.layer(x) + self.assertEqual(y.shape[1], 6) + + def test_min_values(self): + values = mx.array([[[5.0], [3.0], [4.0], [1.0], [2.0]]]) + mask = mx.ones((1, 5), dtype=mx.bool_) + x = type(test_utils.random_sequence(1, 5, 1))(values, mask) + layer = pooling.MinPooling1D(pool_size=3, padding='valid') + y = layer.layer(x) + expected = np.array([[[3.0], [1.0], [1.0]]]) + np.testing.assert_allclose(y.values, expected) + + def test_from_config(self): + import sequence_layers.mlx + from sequence_layers.jax import pooling as jax_pooling + + config = jax_pooling.MinPooling1D.Config( + pool_size=3, + padding='causal', + ) + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, pooling.MinPooling1D) + test_utils.verify_contract( + self, + mlx_layer, + (4,), + atol=1e-5, + 
rtol=1e-5, + ) + + +class AveragePooling1DTest(parameterized.TestCase): + + @parameterized.parameters( + ('causal',), + ('semicausal',), + ) + def test_causal_paddings(self, padding): + layer = pooling.AveragePooling1D(pool_size=3, padding=padding) + test_utils.verify_contract( + self, + layer, + (4,), + atol=1e-5, + rtol=1e-5, + ) + + def test_valid(self): + layer = pooling.AveragePooling1D(pool_size=3, padding='valid') + x = test_utils.random_sequence(1, 8, 4) + y = layer.layer(x) + self.assertEqual(y.shape[1], 6) + + def test_average_values(self): + values = mx.array([[[3.0], [6.0], [9.0], [12.0], [15.0]]]) + mask = mx.ones((1, 5), dtype=mx.bool_) + x = type(test_utils.random_sequence(1, 5, 1))(values, mask) + layer = pooling.AveragePooling1D(pool_size=3, padding='valid') + y = layer.layer(x) + expected = np.array([[[6.0], [9.0], [12.0]]]) + np.testing.assert_allclose(y.values, expected) + + def test_stride(self): + layer = pooling.AveragePooling1D( + pool_size=3, + strides=2, + padding='causal', + ) + test_utils.verify_contract( + self, + layer, + (4,), + time=8, + atol=1e-5, + rtol=1e-5, + ) + + def test_masked_average(self): + values = mx.array([[[3.0], [6.0], [0.0]]]) + mask = mx.array([[True, True, False]]) + x = type(test_utils.random_sequence(1, 3, 1))(values, mask) + layer = pooling.AveragePooling1D( + pool_size=3, + padding='valid', + masked_average=True, + ) + y = layer.layer(x) + # Only 2 valid elements: mean should be (3+6)/2 = 4.5 + expected = np.array([[[4.5]]]) + np.testing.assert_allclose(y.values, expected, atol=1e-5) + + def test_masked_average_causal(self): + layer = pooling.AveragePooling1D( + pool_size=3, + padding='causal', + masked_average=True, + ) + test_utils.verify_contract( + self, + layer, + (4,), + atol=1e-4, + rtol=1e-4, + ) + + def test_from_config(self): + import sequence_layers.mlx + from sequence_layers.jax import pooling as jax_pooling + + config = jax_pooling.AveragePooling1D.Config( + pool_size=3, + padding='causal', + ) + 
class ApplyRotaryPositionalEncoding(
    types.PreservesType,
    types.PreservesShape,
    types.SequenceLayer,
):
  """Applies Rotary Positional Encodings (RoPE) to the sequence.

  The rotation axis is split into two halves that are rotated against each
  other by a position-dependent angle, so it must be a channel axis (not
  batch or time) with an even number of elements.
  """

  def __init__(
      self,
      *,
      max_wavelength: float,
      axis: int = -1,
      only_advance_position_for_valid_timesteps: bool = True,
  ):
    """Initializes the layer.

    Args:
      max_wavelength: Base of the geometric progression of timescales.
      axis: Axis of the values array to rotate. May be negative.
      only_advance_position_for_valid_timesteps: If true, positions count
        valid (unmasked) timesteps only; otherwise they count every timestep.
    """
    super().__init__()
    self.max_wavelength = max_wavelength
    self._axis = axis
    self.only_advance_position_for_valid_timesteps = (
        only_advance_position_for_valid_timesteps
    )

  def _apply_rope(self, x, positions):
    """Apply rotary position encoding to x at given positions.

    Args:
      x: Values array whose first two axes are treated as batch and time.
      positions: Integer positions; reshaped with one trailing singleton axis
        per channel axis so they broadcast against `x`.

    Returns:
      The rotated values, cast back to `x.dtype`.

    Raises:
      ValueError: If the configured axis is not a channel axis of `x`, or if
        that axis has an odd number of elements.
    """
    axis = self._axis + x.ndim if self._axis < 0 else self._axis
    # Rotating along batch or time would mix unrelated timesteps; fail
    # loudly instead of with an opaque broadcast error further down.
    if not 2 <= axis < x.ndim:
      raise ValueError(
          'Rotary positional encoding must be applied on a channel'
          f' dimension (got axis={self._axis} for shape={x.shape}).'
      )
    channel_ndim = x.ndim - 2
    axis_dim = x.shape[axis]
    # The axis is split into two equal halves below.
    if axis_dim % 2 != 0:
      raise ValueError(
          'Rotary positional encoding requires an even number of elements'
          f' on the rotation axis (got {axis_dim} for axis={self._axis},'
          f' shape={x.shape}).'
      )

    freq_exponents = (
        2.0 * mx.arange(axis_dim // 2).astype(mx.float32) / axis_dim
    )
    timescale = self.max_wavelength**freq_exponents

    # Place the per-frequency timescales on the rotation axis only.
    broadcast_shape = [1] * x.ndim
    broadcast_shape[axis] = axis_dim // 2

    # Compute position angles.
    positions_f = positions.astype(mx.float32)
    radians = positions_f.reshape(
        positions_f.shape + (1,) * channel_ndim
    ) / timescale.reshape(broadcast_shape)
    sin_r = mx.sin(radians)
    cos_r = mx.cos(radians)

    # Split input along rotation axis, apply rotation.
    x1, x2 = mx.split(x, 2, axis=axis)
    result = mx.concatenate(
        [x1 * cos_r - x2 * sin_r, x2 * cos_r + x1 * sin_r],
        axis=axis,
    )
    return result.astype(x.dtype)

  def get_initial_state(self, batch_size, input_spec, *, constants=None):
    """Returns the per-batch position counter of shape `(batch_size, 1)`.

    The counter starts at -1 when only valid timesteps advance it, so that
    the first valid timestep lands on position 0 after the cumulative sum in
    `step`; otherwise it starts at 0.
    """
    if self.only_advance_position_for_valid_timesteps:
      return mx.full((batch_size, 1), -1, dtype=mx.int32)
    else:
      return mx.zeros((batch_size, 1), dtype=mx.int32)

  @types.check_step
  def step(self, x, state, *, constants=None):
    """Rotates one block of timesteps and advances the position counter.

    Args:
      x: The block of input timesteps.
      state: Position counter from `get_initial_state` or a previous step.
      constants: Unused.

    Returns:
      A tuple `(rotated_sequence, new_state)`.
    """
    x_time = x.shape[1]
    if self.only_advance_position_for_valid_timesteps:
      # Each valid timestep advances the position by one.
      positions = state + mx.cumsum(x.mask.astype(mx.int32), axis=1)
      state = positions[:, -1:]
    else:
      positions = state + mx.arange(x_time, dtype=mx.int32)
      state = state + x_time
    y = x.apply_values(self._apply_rope, positions)
    return y, state

  @types.check_layer
  def layer(self, x, *, constants=None):
    """Rotates a whole sequence, with positions starting at 0."""
    if self.only_advance_position_for_valid_timesteps:
      # cumsum - 1 gives 0-based positions over valid timesteps; clamp so
      # leading invalid timesteps do not get a negative position.
      positions = mx.maximum(
          0,
          mx.cumsum(x.mask.astype(mx.int32), axis=1) - 1,
      )
    else:
      positions = mx.broadcast_to(
          mx.arange(x.shape[1], dtype=mx.int32)[None, :],
          (x.shape[0], x.shape[1]),
      )
    return x.apply_values(self._apply_rope, positions)

  @classmethod
  def from_config(cls, config):
    """Builds the layer from a config exposing the constructor's fields."""
    return cls(
        max_wavelength=config.max_wavelength,
        axis=config.axis,
        only_advance_position_for_valid_timesteps=(
            config.only_advance_position_for_valid_timesteps
        ),
    )
ApplyRotaryPositionalEncodingTest(parameterized.TestCase): + + def test_layer(self): + layer = position.ApplyRotaryPositionalEncoding(max_wavelength=10000.0) + test_utils.verify_contract(self, layer, (8,)) + + def test_layer_multi_channel(self): + layer = position.ApplyRotaryPositionalEncoding( + max_wavelength=10000.0, axis=-1 + ) + test_utils.verify_contract(self, layer, (4, 8)) + + def test_step_vs_layer(self): + layer = position.ApplyRotaryPositionalEncoding( + max_wavelength=10000.0, + only_advance_position_for_valid_timesteps=False, + ) + test_utils.verify_contract(self, layer, (8,), atol=1e-4, rtol=1e-4) + + def test_step_positions_advance(self): + layer = position.ApplyRotaryPositionalEncoding( + max_wavelength=10000.0, + only_advance_position_for_valid_timesteps=True, + ) + spec = bt.ShapeDType((8,), mx.float32) + state = layer.get_initial_state(1, spec) + + # Step with valid mask. + x1 = bt.MaskedSequence( + mx.ones((1, 1, 8)), + mx.ones((1, 1), dtype=mx.bool_), + ) + _, state = layer.step(x1, state) + # State starts at -1, cumsum(True)=1, position=-1+1=0. + self.assertEqual(int(state[0, 0]), 0) + + # Step with invalid mask. + x2 = bt.MaskedSequence( + mx.ones((1, 1, 8)), + mx.zeros((1, 1), dtype=mx.bool_), + ) + _, state = layer.step(x2, state) + # cumsum(False)=0, position=0+0=0. No advance. 
class Relu(types.PreservesType, types.StatelessPointwiseFunctor):
  """Applies the rectified linear unit to the sequence values."""

  @property
  def mask_required(self):
    return False

  def fn(self, values, mask):
    # The mask is returned untouched.
    activated = nn.relu(values)
    return activated, mask

  @classmethod
  def from_config(cls, config):
    """Builds the layer; no config fields are read."""
    return cls()
class LeakyRelu(types.PreservesType, types.StatelessPointwiseFunctor):
  """Leaky ReLU activation with a configurable negative slope."""

  def __init__(self, negative_slope=0.01):
    super().__init__()
    self._negative_slope = negative_slope

  @property
  def mask_required(self):
    return False

  def fn(self, values, mask):
    # The mask is returned untouched.
    slope = self._negative_slope
    activated = nn.leaky_relu(values, slope)
    return activated, mask

  @classmethod
  def from_config(cls, config):
    """Builds the layer from a config exposing `negative_slope`."""
    slope = config.negative_slope
    return cls(negative_slope=slope)
class Cast(types.StatelessPointwiseFunctor):
  """Converts the sequence values to a fixed dtype; the mask is unchanged."""

  def __init__(self, dtype):
    super().__init__()
    self._dtype = dtype

  @property
  def mask_required(self):
    return False

  def get_output_dtype(self, input_dtype, *, constants=None):
    # The output dtype is the configured one, whatever the input dtype is.
    return self._dtype

  def fn(self, values, mask):
    converted = values.astype(self._dtype)
    return converted, mask

  @classmethod
  def from_config(cls, config):
    """Builds the layer, mapping `config.dtype` to an MLX dtype."""
    target = init_mapping._to_mx_dtype(config.dtype)
    return cls(dtype=target)
class Add(types.PreservesType, types.StatelessPointwise):
  """Shifts the input values by a constant scalar or array."""

  def __init__(self, shift):
    super().__init__()
    # Python scalars are kept as-is; anything else becomes an MLX array.
    is_scalar = isinstance(shift, (int, float, complex))
    self._shift = shift if is_scalar else mx.array(np.asarray(shift))

  @types.check_layer
  def layer(self, x, *, constants=None):
    offset = self._shift
    if isinstance(offset, mx.array):
      # Match the input dtype before adding.
      offset = offset.astype(x.dtype)

    def add_offset(v):
      return v + offset

    return x.apply_values(add_offset)

  @classmethod
  def from_config(cls, config):
    """Builds the layer from a config exposing `shift`.

    `shift` may be a plain value, an object with `data` and `dtype`
    attributes, or an object with an `array` attribute; the latter two are
    converted to NumPy arrays first.
    """
    raw = config.shift
    if hasattr(raw, 'data') and hasattr(raw, 'dtype'):
      raw = np.array(raw.data, dtype=raw.dtype)
    elif hasattr(raw, 'array'):
      raw = np.asarray(raw.array)
    return cls(shift=raw)
+ self._feature_activation = feature_activation + self._gate_activation = gate_activation + + def get_output_shape(self, input_shape, *, constants=None): + channels = input_shape[-1] + if channels % 2 != 0: + raise ValueError( + f'Final dimension of input ({input_shape=}) must have' + ' an even number of channels.' + ) + return tuple(input_shape[:-1]) + (channels // 2,) + + @types.check_layer + def layer(self, x, *, constants=None): + feature, gate = mx.split(x.values, 2, axis=-1) + if self._feature_activation: + feature = self._feature_activation(feature) + if self._gate_activation: + gate = self._gate_activation(gate) + return Sequence(feature * gate, x.mask) + + @classmethod + def from_config(cls, config): + fa = init_mapping.map_activation(config.feature_activation) + ga = init_mapping.map_activation(config.gate_activation) + return cls(feature_activation=fa, gate_activation=ga) + + +class GatedLinearUnit(GatedUnit): + """Computes a Gated Linear Unit, reducing input channels by 2x.""" + + def __init__(self): + super().__init__( + feature_activation=None, + gate_activation=mx.sigmoid, + ) + + @classmethod + def from_config(cls, config): + return cls() + + +class GatedTanhUnit(GatedUnit): + """Computes a Gated Tanh Unit, reducing input channels by 2x.""" + + def __init__(self): + super().__init__( + feature_activation=mx.tanh, + gate_activation=mx.sigmoid, + ) + + @classmethod + def from_config(cls, config): + return cls() + + +# --------------------------------------------------------------------------- +# Shape manipulation +# --------------------------------------------------------------------------- + + +class Flatten(types.PreservesType, types.StatelessPointwise): + """Flattens the channel dimensions of the input sequence. + + An input sequence with shape [batch_size, time, ...] is reshaped to + [batch_size, time, prod(...)]. The mask is unchanged. 
class Reshape(types.PreservesType, types.Stateless):
  """Gives the channel dimensions a new shape with the same element count."""

  def __init__(self, output_shape):
    super().__init__()
    self._output_shape = tuple(output_shape)

  def _validate(self, input_shape):
    """Raises ValueError unless both shapes hold the same element count."""
    if math.prod(input_shape) == math.prod(self._output_shape):
      return
    raise ValueError(
        f'Reshape output_shape={self._output_shape} must have'
        f' the same number of elements as {input_shape=}.'
    )

  def get_output_shape(self, input_shape, *, constants=None):
    self._validate(input_shape)
    return self._output_shape

  @types.check_layer
  def layer(self, x, *, constants=None):
    self._validate(x.channel_shape)
    # Batch and time axes are kept; only the channel axes are reshaped.
    batch, time = x.values.shape[:2]
    target_shape = (batch, time) + self._output_shape
    reshaped = mx.reshape(x.values, target_shape)
    # Return the same sequence type that came in.
    sequence_cls = MaskedSequence if isinstance(x, MaskedSequence) else Sequence
    return sequence_cls(reshaped, x.mask)

  @classmethod
  def from_config(cls, config):
    """Builds the layer from a config exposing `output_shape`."""
    return cls(output_shape=config.output_shape)
axes must refer to channel dims. Got: {self._axis}.' + ) + return dims + + def get_output_shape(self, input_shape, *, constants=None): + dims = self._normalize_axes(input_shape) + out = list(input_shape) + for a in dims: + out.insert(a, 1) + return tuple(out) + + @types.check_layer + def layer(self, x, *, constants=None): + dims = [2 + d for d in self._normalize_axes(x.channel_shape)] + new_values = mx.expand_dims(x.values, axis=dims) + if isinstance(x, MaskedSequence): + return MaskedSequence(new_values, x.mask) + return Sequence(new_values, x.mask) + + @classmethod + def from_config(cls, config): + return cls(axis=config.axis) + + +class Squeeze(types.PreservesType, types.Stateless): + """Squeezes singleton channel dimensions of the input.""" + + def __init__(self, axis=None): + super().__init__() + self._axis = axis + + def _channel_squeeze_axes(self, input_shape): + """Return channel-relative axes to squeeze.""" + if self._axis is None: + # Squeeze all singleton channel dims. + return tuple(i for i, n in enumerate(input_shape) if n == 1) + # If axis is given, it's in full-tensor coords. Convert to channel. + if isinstance(self._axis, int): + axes = (self._axis,) + else: + axes = tuple(self._axis) + return axes + + def get_output_shape(self, input_shape, *, constants=None): + squeeze_axes = self._channel_squeeze_axes(input_shape) + out = [] + for i, s in enumerate(input_shape): + if i not in squeeze_axes: + out.append(s) + return tuple(out) if out else (1,) + + @types.check_layer + def layer(self, x, *, constants=None): + ch_axes = self._channel_squeeze_axes(x.channel_shape) + # Convert to full-tensor axes (offset by 2 for batch, time). 
+ full_axes = tuple(a + 2 for a in ch_axes) + new_values = mx.squeeze(x.values, axis=full_axes) + if isinstance(x, MaskedSequence): + return MaskedSequence(new_values, x.mask) + return Sequence(new_values, x.mask) + + @classmethod + def from_config(cls, config): + return cls(axis=config.axis) + + +class Transpose(types.PreservesType, types.Stateless): + """Permutes the channel axes of the input.""" + + def __init__(self, axes=None): + super().__init__() + if axes is not None: + axes = tuple(axes) + self._axes = axes + + def _resolve_axes(self, input_shape): + input_axes = tuple(range(2, 2 + len(input_shape))) + if self._axes is None: + return input_axes[::-1] + sorted_axes = tuple(sorted(self._axes)) + if sorted_axes != input_axes: + raise ValueError( + f'Provided axes {sorted_axes} do not match input axes {input_axes}.' + ) + return tuple(self._axes) + + def get_output_shape(self, input_shape, *, constants=None): + axes = self._resolve_axes(input_shape) + return tuple(input_shape[a - 2] for a in axes) + + @types.check_layer + def layer(self, x, *, constants=None): + axes = self._resolve_axes(x.channel_shape) + new_values = mx.transpose(x.values, (0, 1) + axes) + if isinstance(x, MaskedSequence): + return MaskedSequence(new_values, x.mask) + return Sequence(new_values, x.mask) + + @classmethod + def from_config(cls, config): + return cls(axes=config.axes) + + +# --------------------------------------------------------------------------- +# Encoding +# --------------------------------------------------------------------------- + + +class OneHot(types.Stateless): + """Computes one-hot vector of the input.""" + + def __init__(self, depth, compute_dtype=mx.float32): + super().__init__() + self._depth = depth + self._compute_dtype = compute_dtype + + def get_output_shape(self, input_shape, *, constants=None): + return tuple(input_shape) + (self._depth,) + + def get_output_dtype(self, input_dtype, *, constants=None): + return self._compute_dtype + + @types.check_layer + 
class Embedding(types.Stateless):
  """Computes embeddings of integer input codes.

  Backed by mlx.nn.Embedding. The output has the input's channel shape with a
  trailing axis of size `dimension` appended.

  Args:
    num_embeddings: Number of rows in the embedding table.
    dimension: Size of each embedding vector.
    param_dtype: Dtype the embedding table is stored in.
    compute_dtype: Optional dtype the looked-up vectors are cast to. When
      None, the output keeps the table's dtype.
  """

  def __init__(
      self,
      *,
      num_embeddings: int,
      dimension: int,
      param_dtype=mx.float32,
      compute_dtype=None,
  ):
    super().__init__()
    self.num_embeddings = num_embeddings
    self.dimension = dimension
    self._param_dtype = param_dtype
    self.compute_dtype = compute_dtype
    self._embedding = nn.Embedding(num_embeddings, dimension)
    # nn.Embedding is constructed without any dtype information, so cast the
    # table to param_dtype explicitly. Otherwise get_output_dtype() would
    # report param_dtype while layer() returns whatever dtype the table was
    # created with.
    if param_dtype is not None:
      self._embedding.weight = self._embedding.weight.astype(param_dtype)

  def get_output_shape(self, input_shape, *, constants=None):
    """Returns the input channel shape with `dimension` appended."""
    return tuple(input_shape) + (self.dimension,)

  def get_output_dtype(self, input_dtype, *, constants=None):
    """Returns `compute_dtype` if set, otherwise `param_dtype`."""
    if self.compute_dtype is not None:
      return self.compute_dtype
    return self._param_dtype

  @types.check_layer
  def layer(self, x, *, constants=None):
    """Looks up an embedding vector for every code in `x`."""

    def embed_fn(v):
      # Codes are cast to int32 before indexing the table.
      result = self._embedding(v.astype(mx.int32))
      if self.compute_dtype is not None:
        result = result.astype(self.compute_dtype)
      return result

    return x.apply_values(embed_fn)

  @classmethod
  def from_config(cls, config):
    """Builds the layer from a config; `compute_dtype` may be absent."""
    from sequence_layers.mlx.init_mapping import _to_mx_dtype

    compute_dtype = getattr(config, 'compute_dtype', None)
    if compute_dtype is not None:
      compute_dtype = _to_mx_dtype(compute_dtype)
    return cls(
        num_embeddings=config.num_embeddings,
        dimension=config.dimension,
        param_dtype=_to_mx_dtype(config.param_dtype),
        compute_dtype=compute_dtype,
    )
class Dropout(types.PreservesType, types.StatelessPointwise):
  """Inference-time dropout: returns its input untouched.

  The rate is stored so the layer can be constructed from a config, but it is
  never applied by `layer()`.
  """

  def __init__(self, rate=0.0):
    super().__init__()
    self._rate = rate

  @classmethod
  def from_config(cls, config):
    """Builds the layer from a config object exposing `rate`."""
    return cls(rate=config.rate)

  @types.check_layer
  def layer(self, x, *, constants=None):
    """Returns `x` unchanged."""
    # Nothing is dropped in this implementation; hand the sequence back.
    return x
class CheckpointName(types.PreservesType, types.StatelessPointwiseFunctor):
  """No-op layer that only carries a checkpoint name.

  The name is stored but not used by this class; values and mask pass through
  unchanged.
  """

  def __init__(self, checkpoint_name=''):
    super().__init__()
    self._checkpoint_name = checkpoint_name

  @classmethod
  def from_config(cls, config):
    """Builds the layer from a config object exposing `checkpoint_name`."""
    return cls(checkpoint_name=config.checkpoint_name)

  @property
  def mask_required(self):
    """Always False for this layer."""
    return False

  def fn(self, values, mask):
    """Returns `(values, mask)` exactly as received."""
    return values, mask
class Logging(types.PreservesType, types.StatelessPointwise):
  """Debug layer: prints a description of its input and passes it through.

  With `dump_tensors` set, the full values array is printed; otherwise only
  the sequence's shape and dtype are printed.
  """

  def __init__(self, prefix='', dump_tensors=False):
    super().__init__()
    self._dump_tensors = dump_tensors
    self._prefix = prefix

  @classmethod
  def from_config(cls, config):
    """Builds the layer from a config exposing `prefix` and `dump_tensors`."""
    return cls(
        prefix=config.prefix,
        dump_tensors=config.dump_tensors,
    )

  @types.check_layer
  def layer(self, x, *, constants=None):
    """Prints one line describing `x` and returns `x` unchanged."""
    if not self._dump_tensors:
      message = (
          f'{self._prefix} layer(): x.shape={x.shape}, '
          f'x.dtype={x.dtype}'
      )
    else:
      message = f'{self._prefix} layer(): x={x.values}'
    print(message)
    return x
class ReluTest(parameterized.TestCase):
  """Tests for simple.Relu."""

  def test_layer(self):
    """layer() satisfies the shared layer contract on a (4,) channel shape."""
    test_utils.verify_contract(self, simple.Relu(), (4,))

  def test_negative_zeroed(self):
    """Negative inputs become zero; non-negative inputs are kept."""
    inputs = bt.MaskedSequence(
        mx.array([[-1.0, 0.5, -0.3, 2.0]]).reshape(1, 1, 4),
        mx.ones((1, 1), dtype=mx.bool_),
    )
    outputs = simple.Relu().layer(inputs)
    np.testing.assert_allclose(
        outputs.values, mx.array([[[0.0, 0.5, 0.0, 2.0]]]), atol=1e-6
    )
class SoftmaxTest(parameterized.TestCase):
  """Tests for simple.Softmax."""

  def test_layer(self):
    """layer() satisfies the shared layer contract on a (4,) channel shape."""
    test_utils.verify_contract(self, simple.Softmax(), (4,))

  def test_sums_to_one(self):
    """A single softmaxed frame sums to one."""
    # One batch entry and one timestep, so the total over the whole array is
    # the sum over the softmax axis.
    frame = bt.MaskedSequence(
        mx.array([[[1.0, 2.0, 3.0, 4.0]]]),
        mx.ones((1, 1), dtype=mx.bool_),
    )
    result = simple.Softmax(axis=-1).layer(frame)
    total = float(mx.sum(result.values))
    np.testing.assert_allclose(total, 1.0, atol=1e-5)
class GatedUnitTest(parameterized.TestCase):
  """Tests for simple.GatedUnit."""

  def test_layer(self):
    """Default construction satisfies the layer contract on 8 channels."""
    test_utils.verify_contract(self, simple.GatedUnit(), (8,))

  def test_with_activations(self):
    """Explicit feature/gate activations also satisfy the layer contract."""
    import mlx.nn as nn

    gated = simple.GatedUnit(
        feature_activation=nn.relu, gate_activation=nn.sigmoid
    )
    test_utils.verify_contract(self, gated, (8,))
class TransposeTest(parameterized.TestCase):
  """Tests for simple.Transpose."""

  def test_layer(self):
    """Default construction satisfies the layer contract on rank-3 channels."""
    test_utils.verify_contract(self, simple.Transpose(), (2, 3, 4))

  def test_reverse(self):
    """Without explicit axes the channel shape is reversed."""
    reversed_shape = simple.Transpose().get_output_shape((2, 3, 4))
    self.assertEqual(reversed_shape, (4, 3, 2))

  def test_explicit(self):
    """Explicit axes are numbered from 2 (the first channel axis)."""
    swap_first_two = simple.Transpose(axes=(3, 2, 4))
    self.assertEqual(swap_first_two.get_output_shape((5, 6, 7)), (6, 5, 7))
class DropoutTest(parameterized.TestCase):
  """Tests for simple.Dropout."""

  def test_layer(self):
    """layer() satisfies the shared layer contract on a (4,) channel shape."""
    test_utils.verify_contract(self, simple.Dropout(rate=0.5), (4,))

  def test_passthrough(self):
    """Values come out exactly as they went in, even with a non-zero rate."""
    inputs = test_utils.random_sequence(1, 3, 4)
    outputs = simple.Dropout(rate=0.5).layer(inputs)
    # Inference-only: should be identity.
    np.testing.assert_array_equal(outputs.values, inputs.values)
class Upsample1DTest(parameterized.TestCase):
  """Tests for simple.Upsample1D."""

  def test_verify_contract(self):
    """layer() satisfies the shared layer contract at rate 3."""
    test_utils.verify_contract(self, simple.Upsample1D(rate=3), (4,))

  def test_layer(self):
    """The time axis grows by a factor of `rate`."""
    inputs = test_utils.random_sequence(1, 4, 2)
    outputs = simple.Upsample1D(rate=3).layer(inputs)
    self.assertEqual(outputs.shape, (1, 12, 2))

  def test_values(self):
    """Each timestep is repeated `rate` times in order; mask grows too."""
    inputs = bt.MaskedSequence(
        mx.array([[[1.0, 2.0], [3.0, 4.0]]]),
        mx.ones((1, 2), dtype=mx.bool_),
    )
    outputs = simple.Upsample1D(rate=2).layer(inputs)
    np.testing.assert_allclose(
        outputs.values,
        mx.array([[[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0]]]),
    )
    self.assertEqual(outputs.mask.shape, (1, 4))
class LambdaTest(parameterized.TestCase):
  """Tests for simple.Lambda."""

  def test_values_fn(self):
    """By default the callable receives and returns the values array."""
    inputs = test_utils.random_sequence(1, 3, 4)
    outputs = simple.Lambda(fn=lambda v: v * 2.0).layer(inputs)
    np.testing.assert_allclose(outputs.values, inputs.values * 2.0, atol=1e-6)

  def test_sequence_fn(self):
    """With sequence_input=True the callable maps Sequence to Sequence."""

    def double_seq(s):
      return bt.Sequence(s.values * 2.0, s.mask)

    inputs = test_utils.random_sequence(1, 3, 4)
    doubler = simple.Lambda(fn=double_seq, sequence_input=True)
    outputs = doubler.layer(inputs)
    np.testing.assert_allclose(outputs.values, inputs.values * 2.0, atol=1e-6)

  def test_from_config(self):
    """A JAX Lambda config builds the MLX Lambda via backend='mlx'."""
    import sequence_layers.mlx
    from sequence_layers.jax import simple as jax_simple

    built = jax_simple.Lambda.Config(fn=lambda v: v).make(backend='mlx')
    self.assertIsInstance(built, simple.Lambda)
+ def test_layer(self): + layer = simple.Logging(prefix='test') + test_utils.verify_contract(self, layer, (4,)) + + def test_passthrough(self): + layer = simple.Logging() + x = test_utils.random_sequence(1, 3, 4) + y = layer.layer(x) + np.testing.assert_array_equal(y.values, x.values) + + def test_from_config(self): + import sequence_layers.mlx + from sequence_layers.jax import simple as jax_simple + + config = jax_simple.Logging.Config(prefix='test') + mlx_layer = config.make(backend='mlx') + self.assertIsInstance(mlx_layer, simple.Logging) + + +if __name__ == '__main__': + absltest.main() diff --git a/sequence_layers/mlx/test_utils.py b/sequence_layers/mlx/test_utils.py new file mode 100644 index 0000000..fb141bb --- /dev/null +++ b/sequence_layers/mlx/test_utils.py @@ -0,0 +1,174 @@ +"""Test utilities for MLX sequence layers.""" + +import mlx.core as mx +import numpy as np + +from sequence_layers.mlx import basic_types as bt + +Sequence = bt.Sequence +MaskedSequence = bt.MaskedSequence +ShapeDType = bt.ShapeDType + + +def random_sequence( + batch: int, + time: int, + channels: int | tuple[int, ...], + *, + dtype=mx.float32, + mask: mx.array | None = None, + masked: bool = True, +) -> Sequence: + """Create a random Sequence for testing. + + Args: + batch: Batch size. + time: Sequence length. + channels: Channel size (int) or channel shape (tuple). + dtype: Values dtype. + mask: Optional explicit mask. If None, all-valid mask is used. + masked: If True, returns a MaskedSequence. If False, a Sequence. + + Returns: + A random Sequence or MaskedSequence. 
def step_by_step(
    layer,
    x: Sequence,
    *,
    block_size: int = 1,
    constants=None,
    stream_constants=None,
) -> tuple[Sequence, object]:
  """Feeds `x` through `layer.step()` block by block and joins the outputs.

  Args:
    layer: Object providing `get_initial_state()` and `step()`.
    x: Input sequence; axis 1 of `values` and `mask` is split into blocks.
    block_size: Number of timesteps handed to each `step()` call. The last
      block may be shorter.
    constants: Optional dict passed unchanged to every call.
    stream_constants: Optional dict of name -> sequence. The full sequences
      are given to `get_initial_state()`; for each step they are sliced to
      the same time window as the input block. Entries override `constants`
      entries with the same name.

  Returns:
    `(output_sequence, final_state)`, where the output is the per-step outputs
    concatenated along axis 1.
  """

  def window(seq, start):
    # Slice one block of timesteps; always re-wrapped as a plain Sequence.
    stop = start + block_size
    return Sequence(seq.values[:, start:stop], seq.mask[:, start:stop])

  static_constants = dict(constants) if constants else {}
  streams = stream_constants if stream_constants else {}

  # The initial state sees the un-sliced stream sources.
  initial_constants = {**static_constants, **streams}
  state = layer.get_initial_state(
      x.shape[0], x.channel_spec, constants=initial_constants or None
  )

  value_blocks = []
  mask_blocks = []
  for start in range(0, x.shape[1], block_size):
    step_constants = dict(static_constants)
    step_constants.update(
        (name, window(seq, start)) for name, seq in streams.items()
    )
    # An empty dict is passed on as None.
    y_block, state = layer.step(
        window(x, start),
        state,
        constants=step_constants or None,
    )
    value_blocks.append(y_block.values)
    mask_blocks.append(y_block.mask)

  joined = Sequence(
      mx.concatenate(value_blocks, axis=1),
      mx.concatenate(mask_blocks, axis=1),
  )
  return joined, state
def _set_weight(module, attr_name, value):
  """Assigns `value` to a (possibly nested) attribute of `module`.

  Every component of `attr_name` except the last is resolved with `getattr`,
  then the last component is set on the object reached.

  Args:
    module: Root object to start the lookup from.
    attr_name: Dot-separated attribute path (e.g. '_linear.weight').
    value: The value to store.
  """
  *parent_path, leaf = attr_name.split('.')
  target = module
  for name in parent_path:
    target = getattr(target, name)
  setattr(target, leaf, value)
def _load_config(mlx_module, linen_params, config, batch_stats=None):
  """Recursively load params guided by config type.

  Dispatches on the type of `config` to the matching `_load_*` helper. The
  first matching branch wins. A config that matches no branch is ignored.

  Args:
    mlx_module: The MLX module that corresponds to `config`.
    linen_params: The Linen param sub-dict for this module.
    config: The config object describing this module.
    batch_stats: Optional batch_stats sub-dict. Only forwarded to the
      combinator loaders and the BatchNormalization loader.
  """
  # The JAX config classes are imported here, inside the function, rather than
  # at module import time.
  from sequence_layers.jax import combinators as jax_comb
  from sequence_layers.jax import conditioning as jax_cond
  from sequence_layers.jax import convolution as jax_conv
  from sequence_layers.jax import dense as jax_dense
  from sequence_layers.jax import normalization as jax_norm
  from sequence_layers.jax import simple as jax_simple
  from sequence_layers.jax.attention import (
      dot_product_attention as jax_cross_attn,
  )
  from sequence_layers.jax.attention import (
      dot_product_self_attention as jax_self_attn,
  )
  from sequence_layers.jax.attention import (
      streaming_dot_product_attention as jax_streaming_attn,
  )
  from sequence_layers.jax.attention import (
      streaming_local_dot_product_attention as jax_streaming_local_attn,
  )
  from sequence_layers.jax.attention import (
      local_dot_product_self_attention as jax_local_attn,
  )

  # NOTE(review): branch order may matter if any of these Config classes
  # subclass one another (not visible from here) -- confirm before reordering.

  # Combinators: recurse into children and thread batch_stats through.
  if isinstance(config, jax_comb.Serial.Config):
    _load_serial(mlx_module, linen_params, config, batch_stats)
  elif isinstance(config, jax_comb.Parallel.Config):
    _load_parallel(mlx_module, linen_params, config, batch_stats)
  elif isinstance(config, jax_comb.Repeat.Config):
    _load_repeat(mlx_module, linen_params, config, batch_stats)
  elif isinstance(config, jax_comb.Residual.Config):
    _load_residual(mlx_module, linen_params, config, batch_stats)
  # Attention variants that share the streaming loader.
  elif isinstance(
      config,
      (
          jax_cross_attn.DotProductAttention.Config,
          jax_streaming_attn.StreamingDotProductAttention.Config,
          jax_streaming_local_attn.StreamingLocalDotProductAttention.Config,
      ),
  ):
    _load_streaming_attention(mlx_module, linen_params, config)
  # Local and global self-attention both use the same loader.
  elif isinstance(config, jax_local_attn.LocalDotProductSelfAttention.Config):
    _load_attention(mlx_module, linen_params, config)
  elif isinstance(config, jax_self_attn.DotProductSelfAttention.Config):
    _load_attention(mlx_module, linen_params, config)
  elif isinstance(config, jax_dense.Dense.Config):
    _load_dense(mlx_module, linen_params, config)
  elif isinstance(config, jax_conv.Conv1D.Config):
    _load_conv1d(mlx_module, linen_params, config)
  elif isinstance(config, jax_conv.DepthwiseConv1D.Config):
    _load_depthwise_conv1d(mlx_module, linen_params, config)
  elif isinstance(config, jax_conv.Conv1DTranspose.Config):
    _load_conv1d_transpose(mlx_module, linen_params, config)
  elif isinstance(config, jax_norm.RMSNormalization.Config):
    _load_rms_norm(mlx_module, linen_params, config)
  elif isinstance(config, jax_norm.LayerNormalization.Config):
    _load_layer_norm(mlx_module, linen_params, config)
  # The only leaf loader that also receives batch_stats.
  elif isinstance(config, jax_norm.BatchNormalization.Config):
    _load_batch_norm(mlx_module, linen_params, config, batch_stats)
  elif isinstance(config, jax_norm.GroupNormalization.Config):
    _load_group_norm(mlx_module, linen_params, config)
  elif isinstance(config, jax_dense.EinsumDense.Config):
    _load_einsum_dense(mlx_module, linen_params, config)
  elif isinstance(config, jax_cond.Conditioning.Config):
    _load_conditioning(mlx_module, linen_params, config)
  elif isinstance(config, jax_simple.Embedding.Config):
    _load_embedding(mlx_module, linen_params, config)
  # Stateless layers (Flatten, Identity, RoPE, pooling, etc.) have no params.
  # NOTE(review): there is no else branch, so a parameterized config type that
  # is missing from this chain is skipped without any error or warning.
+ + +def _load_serial(mlx_serial, linen_params, config, batch_stats=None): + """Load Serial: walk layers_{i} in Linen, model.layers[i] in MLX.""" + for i, layer_config in enumerate(config.layers): + key = f'layers_{i}' + child_params = linen_params.get(key, {}) + child_bs = batch_stats.get(key, {}) if batch_stats else None + _load_config( + mlx_serial.layers[i], + child_params, + layer_config, + batch_stats=child_bs, + ) + + +def _load_parallel(mlx_parallel, linen_params, config, batch_stats=None): + """Load Parallel: walk layers_{i}, same as Serial.""" + for i, layer_config in enumerate(config.layers): + key = f'layers_{i}' + child_params = linen_params.get(key, {}) + child_bs = batch_stats.get(key, {}) if batch_stats else None + _load_config( + mlx_parallel.layers[i], + child_params, + layer_config, + batch_stats=child_bs, + ) + + +def _load_repeat(mlx_repeat, linen_params, config, batch_stats=None): + """Load Repeat: slice stacked Linen params for each MLX copy.""" + child_params = linen_params.get('child_layer', {}) + child_bs = batch_stats.get('child_layer', {}) if batch_stats else None + + # Linen Repeat stacks all child params with leading [num_repeats]. + # Slice axis 0 for each copy. + for i in range(config.num_repeats): + sliced = _slice_params(child_params, i) + sliced_bs = _slice_params(child_bs, i) if child_bs else None + _load_config( + mlx_repeat.layers[i], + sliced, + config.layer, + batch_stats=sliced_bs, + ) + + +def _slice_params(params, index): + """Slice the leading axis of all arrays in a param dict.""" + result = {} + for key, value in params.items(): + if isinstance(value, dict): + result[key] = _slice_params(value, index) + elif isinstance(value, np.ndarray): + result[key] = value[index] + else: + result[key] = value + return result + + +def _load_residual(mlx_residual, linen_params, config, batch_stats=None): + """Load Residual: body is layers_{i}, shortcut is shortcut_layer.""" + # Body is a Serial inside the Residual. 
+ body = mlx_residual._body + for i, layer_config in enumerate(config.layers): + key = f'layers_{i}' + child_params = linen_params.get(key, {}) + child_bs = batch_stats.get(key, {}) if batch_stats else None + _load_config( + body.layers[i], + child_params, + layer_config, + batch_stats=child_bs, + ) + + # Shortcut (usually Identity — no params). + if config.shortcut_layers: + shortcut_params = linen_params.get('shortcut_layer', {}) + shortcut_bs = batch_stats.get('shortcut_layer', {}) if batch_stats else None + for i, sc_config in enumerate(config.shortcut_layers): + sc_key = f'layers_{i}' + sc_bs = shortcut_bs.get(sc_key, {}) if shortcut_bs else None + _load_config( + mlx_residual._shortcut, + shortcut_params.get(sc_key, {}), + sc_config, + batch_stats=sc_bs, + ) + + +def _load_dense(mlx_dense, linen_params, config): + """Load Dense: transpose kernel [in, out] → [out, in].""" + # Handle DenseDeferred wrapper. + inner = mlx_dense + if hasattr(inner, '_inner') and inner._inner is not None: + inner = inner._inner + + kernel = linen_params.get('kernel') + if kernel is not None: + # Linen: [in, out], MLX nn.Linear: [out, in] + weight = mx.array(kernel.T) + inner._linear.weight = weight + + bias = linen_params.get('bias') + if bias is not None: + inner._linear.bias = mx.array(bias) + + +def _load_einsum_dense(mlx_einsum, linen_params, config): + """Load EinsumDense: kernel shape matches directly (einsum notation).""" + kernel = linen_params.get('kernel') + if kernel is not None: + mlx_einsum.kernel = mx.array(kernel) + mlx_einsum._initialized = True + bias = linen_params.get('bias') + if bias is not None: + mlx_einsum.bias = mx.array(bias) + + +def _load_attention(mlx_attn, linen_params, config): + """Load DotProductSelfAttention. 
+ + Handles: + - CombinedQueryKeyValueProjection: + query_key_value_projection/kernel [in, 3, heads, uph] + - SeparateQueryKeyValueProjection: + query_projection/kernel [in, heads, uph] + key_projection/kernel [in, kv_heads, uph] + value_projection/kernel [in, kv_heads, uph] + """ + from sequence_layers.jax.attention import common as attn_common + + # Handle Deferred wrapper. + inner = mlx_attn + if hasattr(inner, '_inner') and inner._inner is not None: + inner = inner._inner + + input_projection = config.input_projection + + if isinstance(input_projection, attn_common.CombinedQueryKeyValueProjection): + # Combined QKV: kernel [in, 3, heads, uph] → separate q/k/v. + qkv_params = linen_params.get('query_key_value_projection', {}) + combined_kernel = qkv_params.get('kernel') + if combined_kernel is not None: + in_features = combined_kernel.shape[0] + q, k, v = np.split(combined_kernel, 3, axis=1) + inner.q_proj = mx.array(q.reshape(in_features, -1)) + inner.k_proj = mx.array(k.reshape(in_features, -1)) + inner.v_proj = mx.array(v.reshape(in_features, -1)) + + combined_bias = qkv_params.get('bias') + if combined_bias is not None: + qb, kb, vb = np.split(combined_bias, 3, axis=0) + inner.q_bias = mx.array(qb.reshape(-1)) + inner.k_bias = mx.array(kb.reshape(-1)) + inner.v_bias = mx.array(vb.reshape(-1)) + + elif isinstance( + input_projection, attn_common.SeparateQueryKeyValueProjection + ): + # Separate Q/K/V projections (used for GQA where num_kv_heads < num_heads). 
+ q_params = linen_params.get('query_projection', {}) + q_kernel = q_params.get('kernel') + if q_kernel is not None: + in_features = q_kernel.shape[0] + inner.q_proj = mx.array(q_kernel.reshape(in_features, -1)) + q_bias = q_params.get('bias') + if q_bias is not None: + inner.q_bias = mx.array(q_bias.reshape(-1)) + + k_params = linen_params.get('key_projection', {}) + k_kernel = k_params.get('kernel') + if k_kernel is not None: + in_features = k_kernel.shape[0] + inner.k_proj = mx.array(k_kernel.reshape(in_features, -1)) + k_bias = k_params.get('bias') + if k_bias is not None: + inner.k_bias = mx.array(k_bias.reshape(-1)) + + v_params = linen_params.get('value_projection', {}) + v_kernel = v_params.get('kernel') + if v_kernel is not None: + in_features = v_kernel.shape[0] + inner.v_proj = mx.array(v_kernel.reshape(in_features, -1)) + v_bias = v_params.get('bias') + if v_bias is not None: + inner.v_bias = mx.array(v_bias.reshape(-1)) + + # Q/K/V processing networks have no trainable params + # (RoPE is stateless with no learned weights). + + +def _load_streaming_attention(mlx_attn, linen_params, config): + """Load StreamingDotProductAttention. + + Handles different projection layouts: + - QueryAndKeyValueProjection (default): + query_projection/kernel [in, heads, uph] + key_value_projection/kernel [source, 2, heads, uph] + - SeparateQueryKeyValueProjection: + query_projection/kernel [in, heads, uph] + key_projection/kernel [source, heads, uph] + value_projection/kernel [source, heads, uph] + - QueryAndSharedKeyValueProjection: + query_projection/kernel [in, heads, uph] + shared_key_value_projection/kernel [source, heads, uph] + """ + from sequence_layers.jax.attention import common as attn_common + + # Handle Deferred wrapper. + inner = mlx_attn + if hasattr(inner, '_inner') and inner._inner is not None: + inner = inner._inner + + input_projection = config.input_projection + + # Load query projection. 
+ q_params = linen_params.get('query_projection', {}) + q_kernel = q_params.get('kernel') + if q_kernel is not None: + # Shape: [in_features, num_heads, units_per_head] → [in, heads*uph] + in_features = q_kernel.shape[0] + inner.q_proj = mx.array(q_kernel.reshape(in_features, -1)) + q_bias = q_params.get('bias') + if q_bias is not None: + inner.q_bias = mx.array(q_bias.reshape(-1)) + + if isinstance(input_projection, attn_common.QueryAndKeyValueProjection): + # Combined KV: kernel [source, 2, heads, uph] → split into K, V. + kv_params = linen_params.get('key_value_projection', {}) + kv_kernel = kv_params.get('kernel') + if kv_kernel is not None: + source_features = kv_kernel.shape[0] + # Split along axis 1 (the '2' axis for K/V). + k, v = np.split(kv_kernel, 2, axis=1) + inner.k_proj = mx.array(k.reshape(source_features, -1)) + inner.v_proj = mx.array(v.reshape(source_features, -1)) + kv_bias = kv_params.get('bias') + if kv_bias is not None: + kb, vb = np.split(kv_bias, 2, axis=0) + inner.k_bias = mx.array(kb.reshape(-1)) + inner.v_bias = mx.array(vb.reshape(-1)) + + elif isinstance( + input_projection, attn_common.SeparateQueryKeyValueProjection + ): + # Separate K and V projections. + k_params = linen_params.get('key_projection', {}) + k_kernel = k_params.get('kernel') + if k_kernel is not None: + source_features = k_kernel.shape[0] + inner.k_proj = mx.array(k_kernel.reshape(source_features, -1)) + k_bias = k_params.get('bias') + if k_bias is not None: + inner.k_bias = mx.array(k_bias.reshape(-1)) + + v_params = linen_params.get('value_projection', {}) + v_kernel = v_params.get('kernel') + if v_kernel is not None: + source_features = v_kernel.shape[0] + inner.v_proj = mx.array(v_kernel.reshape(source_features, -1)) + v_bias = v_params.get('bias') + if v_bias is not None: + inner.v_bias = mx.array(v_bias.reshape(-1)) + + elif isinstance( + input_projection, attn_common.QueryAndSharedKeyValueProjection + ): + # Shared K/V projection: same weights for both K and V. 
+ shared_params = linen_params.get('shared_key_value_projection', {}) + shared_kernel = shared_params.get('kernel') + if shared_kernel is not None: + source_features = shared_kernel.shape[0] + proj = mx.array(shared_kernel.reshape(source_features, -1)) + inner.k_proj = proj + inner.v_proj = proj + shared_bias = shared_params.get('bias') + if shared_bias is not None: + b = mx.array(shared_bias.reshape(-1)) + inner.k_bias = b + inner.v_bias = b + + +def _load_rms_norm(mlx_norm, linen_params, config): + """Load RMSNormalization: scale [dim] → same.""" + scale = linen_params.get('scale') + if scale is not None: + scale_mx = mx.array(scale) + if mlx_norm._use_builtin and mlx_norm._rms_norm is not None: + mlx_norm._rms_norm.weight = scale_mx + elif hasattr(mlx_norm, '_scale'): + mlx_norm._scale = scale_mx + + +def _load_layer_norm(mlx_norm, linen_params, config): + """Load LayerNormalization: scale and bias.""" + scale = linen_params.get('scale') + bias = linen_params.get('bias') + + if mlx_norm._use_builtin and mlx_norm._layer_norm is not None: + if scale is not None: + mlx_norm._layer_norm.weight = mx.array(scale) + if bias is not None: + mlx_norm._layer_norm.bias = mx.array(bias) + else: + if scale is not None and mlx_norm._manual_scale is not None: + mlx_norm._manual_scale = mx.array(scale) + if bias is not None and mlx_norm._manual_bias is not None: + mlx_norm._manual_bias = mx.array(bias) + + +def _load_embedding(mlx_emb, linen_params, config): + """Load Embedding: table [vocab, dim] → same.""" + embedding = linen_params.get('embedding') + if embedding is not None: + mlx_emb._embedding.weight = mx.array(embedding) + + +def _load_batch_norm(mlx_bn, linen_params, config, batch_stats=None): + """Load BatchNormalization: scale/bias from params, mean/var from batch_stats.""" + scale = linen_params.get('scale') + bias = linen_params.get('bias') + + if scale is not None and mlx_bn.use_scale: + mlx_bn._scale = mx.array(scale) + if bias is not None and mlx_bn.use_bias: + 
mlx_bn._bias = mx.array(bias) + + if batch_stats is not None: + mean = batch_stats.get('mean') + var = batch_stats.get('var') + if mean is not None: + mlx_bn._running_mean = mx.array(mean) + if var is not None: + mlx_bn._running_var = mx.array(var) + + +def _load_group_norm(mlx_gn, linen_params, config): + """Load GroupNormalization: scale and bias.""" + scale = linen_params.get('scale') + if scale is not None and mlx_gn.use_scale: + mlx_gn._scale = mx.array(scale) + bias = linen_params.get('bias') + if bias is not None and mlx_gn.use_bias: + mlx_gn._bias = mx.array(bias) + + +def _load_conv1d(mlx_conv, linen_params, config): + """Load Conv1D: kernel [k, in, out] → [out, k, in].""" + inner = mlx_conv + if hasattr(inner, '_inner') and inner._inner is not None: + inner = inner._inner + + kernel = linen_params.get('kernel') + if kernel is not None: + inner._conv.weight = mx.array(kernel.transpose(2, 0, 1)) + + bias = linen_params.get('bias') + if bias is not None: + inner._conv.bias = mx.array(bias) + + +def _load_depthwise_conv1d(mlx_conv, linen_params, config): + """Load DepthwiseConv1D: same kernel layout as Conv1D.""" + _load_conv1d(mlx_conv, linen_params, config) + + +def _load_conv1d_transpose(mlx_conv, linen_params, config): + """Load Conv1DTranspose: kernel [k, in, out] → [out, k, in]. + + The kernel is flipped along the spatial axis because Linen uses + conv_general_dilated with lhs_dilation (correlation), while MLX uses + conv_transpose1d which reverses the kernel direction. + """ + inner = mlx_conv + if hasattr(inner, '_inner') and inner._inner is not None: + inner = inner._inner + + kernel = linen_params.get('kernel') + if kernel is not None: + # Flip spatial axis, then transpose to MLX layout. 
+ inner.kernel = mx.array(kernel[::-1].transpose(2, 0, 1)) + + bias = linen_params.get('bias') + if bias is not None: + inner.bias = mx.array(bias) + + +def _load_conditioning(mlx_cond, linen_params, config): + """Load Conditioning: projection Dense kernel/bias from 'dense' subdict. + + Linen Conditioning creates a DenseShaped under the name 'dense' for + LINEAR and LINEAR_AFFINE projections. The kernel shape matches directly + (input_kernel_shape + output_kernel_shape) since we use the same einsum + equation. + """ + from sequence_layers.jax import conditioning as jax_cond + + projection = config.projection + if projection == jax_cond.BaseConditioning.Projection.IDENTITY: + return # No params for identity projection. + + dense_params = linen_params.get('dense', {}) + kernel = dense_params.get('kernel') + if kernel is not None: + mlx_cond.kernel = mx.array(kernel) + mlx_cond._proj_initialized = True + bias = dense_params.get('bias') + if bias is not None: + mlx_cond.bias = mx.array(bias) diff --git a/sequence_layers/mlx/weight_converter_test.py b/sequence_layers/mlx/weight_converter_test.py new file mode 100644 index 0000000..88cb0b2 --- /dev/null +++ b/sequence_layers/mlx/weight_converter_test.py @@ -0,0 +1,463 @@ +"""Tests for weight_converter: Linen → MLX param conversion. + +Requires both JAX and MLX to be importable. +""" + +import jax +import jax.numpy as jnp +import mlx.core as mx +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + +import sequence_layers.jax as sl +from sequence_layers.jax import types as jax_types +from sequence_layers.mlx import basic_types as bt +from sequence_layers.mlx import export +from sequence_layers.mlx import weight_converter + +Sequence = bt.Sequence +ShapeDType = bt.ShapeDType + + +def _run_jax_layer(config, jax_params, tokens): + """Run a Linen model on integer token inputs. + + Args: + config: A SequenceLayerConfig. + jax_params: Linen params dict. 
+ tokens: A numpy array of shape [batch, time]. + + Returns: + numpy array of output values. + """ + model = config.make() + x = jax_types.Sequence( + jnp.array(tokens, dtype=jnp.int32), + jnp.ones(tokens.shape, dtype=jnp.bool_), + ) + y = model.apply({'params': jax_params}, x, training=False) + return np.array(y.values) + + +def _run_mlx_layer(mlx_model, tokens): + """Run an MLX model on integer token inputs. + + Args: + mlx_model: An MLX SequenceLayer. + tokens: A numpy array of shape [batch, time]. + + Returns: + numpy array of output values. + """ + x = Sequence( + mx.array(tokens, dtype=mx.int32), + mx.ones(tokens.shape, dtype=mx.bool_), + ) + y = mlx_model.layer(x) + mx.eval(y.values) + return np.array(y.values) + + +class EmbeddingConversionTest(parameterized.TestCase): + """Test embedding weight conversion.""" + + def test_embedding_round_trip(self): + config = sl.Embedding.Config( + num_embeddings=32, + dimension=16, + ) + jax_model = config.make() + x = jax_types.Sequence( + jnp.zeros((1, 4), dtype=jnp.int32), + jnp.ones((1, 4), dtype=jnp.bool_), + ) + variables = jax_model.init(jax.random.PRNGKey(0), x, training=False) + jax_params = variables['params'] + + # Create MLX model and load weights. 
+ mlx_model = config.make(backend='mlx') + weight_converter.load_linen_params(mlx_model, jax_params, config) + + tokens = np.array([[0, 5, 10, 31]]) + jax_out = _run_jax_layer(config, jax_params, tokens) + mlx_out = _run_mlx_layer(mlx_model, tokens) + + np.testing.assert_allclose( + mlx_out, + jax_out, + atol=1e-5, + rtol=1e-5, + err_msg='Embedding outputs differ', + ) + + +class DenseConversionTest(parameterized.TestCase): + """Test Dense weight conversion.""" + + def test_dense_round_trip(self): + config = sl.Serial.Config([ + sl.Embedding.Config(num_embeddings=32, dimension=8), + sl.Dense.Config(features=16), + ]) + jax_model = config.make() + x = jax_types.Sequence( + jnp.zeros((1, 4), dtype=jnp.int32), + jnp.ones((1, 4), dtype=jnp.bool_), + ) + variables = jax_model.init(jax.random.PRNGKey(0), x, training=False) + jax_params = variables['params'] + + mlx_model = config.make(backend='mlx') + weight_converter.load_linen_params(mlx_model, jax_params, config) + + tokens = np.array([[0, 3, 7, 15]]) + jax_out = _run_jax_layer(config, jax_params, tokens) + mlx_out = _run_mlx_layer(mlx_model, tokens) + + np.testing.assert_allclose( + mlx_out, + jax_out, + atol=1e-5, + rtol=1e-5, + err_msg='Dense outputs differ', + ) + + +class RMSNormConversionTest(parameterized.TestCase): + """Test RMSNorm weight conversion.""" + + def test_rms_norm_round_trip(self): + config = sl.Serial.Config([ + sl.Embedding.Config(num_embeddings=32, dimension=8), + sl.RMSNormalization.Config(), + ]) + jax_model = config.make() + x = jax_types.Sequence( + jnp.zeros((1, 4), dtype=jnp.int32), + jnp.ones((1, 4), dtype=jnp.bool_), + ) + variables = jax_model.init(jax.random.PRNGKey(0), x, training=False) + jax_params = variables['params'] + + mlx_model = config.make(backend='mlx') + weight_converter.load_linen_params(mlx_model, jax_params, config) + + tokens = np.array([[0, 3, 7, 15]]) + jax_out = _run_jax_layer(config, jax_params, tokens) + mlx_out = _run_mlx_layer(mlx_model, tokens) + + 
np.testing.assert_allclose( + mlx_out, + jax_out, + atol=1e-5, + rtol=1e-5, + err_msg='RMSNorm outputs differ', + ) + + +def _run_jax_layer_float(config, jax_params, values): + """Run a Linen model on float inputs. + + Args: + config: A SequenceLayerConfig. + jax_params: Linen params dict. + values: A numpy array of shape [batch, time, channels]. + + Returns: + numpy array of output values. + """ + model = config.make() + x = jax_types.Sequence( + jnp.array(values, dtype=jnp.float32), + jnp.ones(values.shape[:2], dtype=jnp.bool_), + ) + y = model.apply({'params': jax_params}, x, training=False) + return np.array(y.values) + + +def _run_mlx_layer_float(mlx_model, values): + """Run an MLX model on float inputs. + + Args: + mlx_model: An MLX SequenceLayer. + values: A numpy array of shape [batch, time, channels]. + + Returns: + numpy array of output values. + """ + x = Sequence( + mx.array(values, dtype=mx.float32), + mx.ones(values.shape[:2], dtype=mx.bool_), + ) + y = mlx_model.layer(x) + mx.eval(y.values) + return np.array(y.values) + + +class Conv1DConversionTest(parameterized.TestCase): + """Test Conv1D weight conversion.""" + + def test_conv1d_round_trip(self): + config = sl.Conv1D.Config( + filters=8, + kernel_size=3, + padding='causal', + ) + in_channels = 4 + jax_model = config.make() + x = jax_types.Sequence( + jnp.zeros((1, 8, in_channels), dtype=jnp.float32), + jnp.ones((1, 8), dtype=jnp.bool_), + ) + variables = jax_model.init(jax.random.PRNGKey(0), x, training=False) + jax_params = variables['params'] + + mlx_model = config.make(backend='mlx') + weight_converter.load_linen_params( + mlx_model, + jax_params, + config, + input_spec=ShapeDType((in_channels,), mx.float32), + ) + + values = ( + np.random.RandomState(42).randn(1, 8, in_channels).astype(np.float32) + ) + jax_out = _run_jax_layer_float(config, jax_params, values) + mlx_out = _run_mlx_layer_float(mlx_model, values) + + np.testing.assert_allclose( + mlx_out, + jax_out, + atol=1e-5, + rtol=1e-5, + 
err_msg='Conv1D outputs differ', + ) + + +class DepthwiseConv1DConversionTest(parameterized.TestCase): + """Test DepthwiseConv1D weight conversion.""" + + def test_depthwise_conv1d_round_trip(self): + config = sl.DepthwiseConv1D.Config( + kernel_size=3, + padding='causal', + ) + in_channels = 4 + jax_model = config.make() + x = jax_types.Sequence( + jnp.zeros((1, 8, in_channels), dtype=jnp.float32), + jnp.ones((1, 8), dtype=jnp.bool_), + ) + variables = jax_model.init(jax.random.PRNGKey(0), x, training=False) + jax_params = variables['params'] + + mlx_model = config.make(backend='mlx') + weight_converter.load_linen_params( + mlx_model, + jax_params, + config, + input_spec=ShapeDType((in_channels,), mx.float32), + ) + + values = ( + np.random.RandomState(42).randn(1, 8, in_channels).astype(np.float32) + ) + jax_out = _run_jax_layer_float(config, jax_params, values) + mlx_out = _run_mlx_layer_float(mlx_model, values) + + np.testing.assert_allclose( + mlx_out, + jax_out, + atol=1e-5, + rtol=1e-5, + err_msg='DepthwiseConv1D outputs differ', + ) + + +class Conv1DTransposeConversionTest(parameterized.TestCase): + """Test Conv1DTranspose weight conversion.""" + + def test_conv1d_transpose_round_trip(self): + config = sl.Conv1DTranspose.Config( + filters=8, + kernel_size=3, + strides=2, + padding='causal', + ) + in_channels = 4 + jax_model = config.make() + x = jax_types.Sequence( + jnp.zeros((1, 8, in_channels), dtype=jnp.float32), + jnp.ones((1, 8), dtype=jnp.bool_), + ) + variables = jax_model.init(jax.random.PRNGKey(0), x, training=False) + jax_params = variables['params'] + + mlx_model = config.make(backend='mlx') + weight_converter.load_linen_params( + mlx_model, + jax_params, + config, + input_spec=ShapeDType((in_channels,), mx.float32), + ) + + values = ( + np.random.RandomState(42).randn(1, 8, in_channels).astype(np.float32) + ) + jax_out = _run_jax_layer_float(config, jax_params, values) + mlx_out = _run_mlx_layer_float(mlx_model, values) + + 
np.testing.assert_allclose( + mlx_out, + jax_out, + atol=1e-5, + rtol=1e-5, + err_msg='Conv1DTranspose outputs differ', + ) + + +class DecoderTransformerConversionTest(parameterized.TestCase): + """Test full decoder transformer weight conversion.""" + + def _decoder_config(self): + return sl.Serial.Config([ + sl.Embedding.Config( + num_embeddings=32, + dimension=16, + ), + sl.Repeat.Config( + num_repeats=2, + layer=sl.Serial.Config([ + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.DotProductSelfAttention.Config( + num_heads=2, + units_per_head=8, + max_past_horizon=16, + max_future_horizon=0, + query_network=( + sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ) + ), + key_network=( + sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ) + ), + ), + sl.Flatten.Config(), + ]), + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.Dense.Config( + features=64, + activation=jax.nn.gelu, + ), + sl.Dense.Config(features=16), + ]), + ]), + ), + sl.RMSNormalization.Config(), + sl.Dense.Config(features=32), + ]) + + def test_layer_output_match(self): + """Full transformer: JAX and MLX produce same layer() output.""" + config = self._decoder_config() + + # Build and init JAX model. + jax_model = config.make() + x = jax_types.Sequence( + jnp.zeros((1, 4), dtype=jnp.int32), + jnp.ones((1, 4), dtype=jnp.bool_), + ) + variables = jax_model.init(jax.random.PRNGKey(42), x, training=False) + jax_params = variables['params'] + + # Build MLX model and load Linen weights. 
+ mlx_model = config.make(backend='mlx') + weight_converter.load_linen_params(mlx_model, jax_params, config) + + tokens = np.array([[0, 5, 10, 31]]) + jax_out = _run_jax_layer(config, jax_params, tokens) + mlx_out = _run_mlx_layer(mlx_model, tokens) + + np.testing.assert_allclose( + mlx_out, + jax_out, + atol=1e-3, + rtol=1e-3, + err_msg='Decoder transformer outputs differ', + ) + + def test_step_output_match(self): + """Full transformer: JAX step and MLX step produce same output.""" + config = self._decoder_config() + + jax_model = config.make() + x = jax_types.Sequence( + jnp.zeros((1, 4), dtype=jnp.int32), + jnp.ones((1, 4), dtype=jnp.bool_), + ) + variables = jax_model.init(jax.random.PRNGKey(42), x, training=False) + jax_params = variables['params'] + + mlx_model = config.make(backend='mlx') + weight_converter.load_linen_params(mlx_model, jax_params, config) + + # Run step-by-step on both. + tokens = [5, 10, 31] + + # JAX step. + jax_spec = jax.ShapeDtypeStruct((), jnp.int32) + jax_state = jax_model.apply( + {'params': jax_params}, + 1, + jax_spec, + training=False, + method=jax_model.get_initial_state, + ) + jax_outputs = [] + for t in tokens: + x_jax = jax_types.Sequence( + jnp.array([[t]], dtype=jnp.int32), + jnp.ones((1, 1), dtype=jnp.bool_), + ) + y_jax, jax_state = jax_model.apply( + {'params': jax_params}, + x_jax, + jax_state, + training=False, + method=jax_model.step, + ) + jax_outputs.append(np.array(y_jax.values)) + + # MLX step. 
+ input_spec = ShapeDType((), mx.int32) + export._materialize_deferred(mlx_model, 1, input_spec) + mlx_state = mlx_model.get_initial_state(1, input_spec) + mlx_outputs = [] + for t in tokens: + x_mx = Sequence( + mx.array([[t]], dtype=mx.int32), + mx.ones((1, 1), dtype=mx.bool_), + ) + y_mx, mlx_state = mlx_model.step(x_mx, mlx_state) + mx.eval(y_mx.values) + mlx_outputs.append(np.array(y_mx.values)) + + for i, (jax_out, mlx_out) in enumerate(zip(jax_outputs, mlx_outputs)): + np.testing.assert_allclose( + mlx_out, + jax_out, + atol=1e-3, + rtol=1e-3, + err_msg=f'Step {i} (token={tokens[i]}): outputs differ', + ) + + +if __name__ == '__main__': + absltest.main() From 8cbc1b6fa700a7c2a732e990cdc5145dacb74f81 Mon Sep 17 00:00:00 2001 From: David Braun <2096055+DBraun@users.noreply.github.com> Date: Mon, 23 Feb 2026 23:22:50 -0500 Subject: [PATCH 02/17] Update attention.py --- sequence_layers/mlx/attention.py | 70 ++++++++++++++++++++++---------- 1 file changed, 48 insertions(+), 22 deletions(-) diff --git a/sequence_layers/mlx/attention.py b/sequence_layers/mlx/attention.py index 9f8fb1c..e869b2a 100644 --- a/sequence_layers/mlx/attention.py +++ b/sequence_layers/mlx/attention.py @@ -352,30 +352,56 @@ def step_with_emits(self, x, state, *, constants=None): x_time = x.shape[1] kv_buffer_size = kv_buf_k.shape[1] - # Append new K/V to buffer and trim to buffer size. - new_k = mx.concatenate([kv_buf_k, keys.values], axis=1) - new_v = mx.concatenate([kv_buf_v, values.values], axis=1) - new_mask = mx.concatenate([kv_buf_mask, x.mask], axis=1) - - # Keep only the last kv_buffer_size entries. - new_k = new_k[:, -kv_buffer_size:] - new_v = new_v[:, -kv_buffer_size:] - new_mask = new_mask[:, -kv_buffer_size:] - - # Build visibility mask: [b, 1, q_time, kv_time]. - kv_valid = new_mask[:, None, None, :] # [b,1,1,kvt] - - # Add causal mask for multi-step queries. 
- if x_time > 1: - causal = _causal_mask(x_time, new_k.shape[1]) - kv_valid = kv_valid & causal - - context = self._compute_attention(queries.values, new_k, new_v, kv_valid) + if kv_buffer_size > 0: + # Ring buffer write: insert new K/V at rotating positions. + # Uses put_along_axis to scatter into pre-allocated buffers, + # compatible with mx.compile / mx.export_function (no Python + # int conversion needed). + t0 = time_step[0] # MLX scalar, no eval. + positions = (t0 + mx.arange(x_time)) % kv_buffer_size # [x_time] + + # Scatter K/V into buffer at ring positions. + idx_4d = mx.broadcast_to( + positions.reshape(1, x_time, 1, 1), keys.values.shape + ) + kv_buf_k = mx.put_along_axis(kv_buf_k, idx_4d, keys.values, axis=1) + kv_buf_v = mx.put_along_axis(kv_buf_v, idx_4d, values.values, axis=1) + + # Scatter mask into buffer. + idx_2d = mx.broadcast_to(positions.reshape(1, x_time), x.mask.shape) + kv_buf_mask = mx.put_along_axis(kv_buf_mask, idx_2d, x.mask, axis=1) + + # Build visibility mask: [b, 1, 1, kv_buffer_size]. + kv_valid = kv_buf_mask[:, None, None, :] + + # Add causal mask for multi-step queries (respects ring buffer order). + if x_time > 1: + newest_time = t0 + x_time - 1 + newest_pos = newest_time % kv_buffer_size + phys = mx.arange(kv_buffer_size) + dist = (newest_pos - phys + kv_buffer_size) % kv_buffer_size + temporal = newest_time - dist + q_times = t0 + mx.arange(x_time) + causal = temporal[None, :] <= q_times[:, None] + kv_valid = kv_valid & causal.reshape(1, 1, x_time, kv_buffer_size) + + context = self._compute_attention( + queries.values, kv_buf_k, kv_buf_v, kv_valid + ) + else: + # Degenerate: no history buffer, attend only to current step. 
+ kv_valid = x.mask[:, None, None, :] + if x_time > 1: + causal = _causal_mask(x_time, x_time) + kv_valid = kv_valid & causal + context = self._compute_attention( + queries.values, keys.values, values.values, kv_valid + ) new_state = ( - new_k, - new_v, - new_mask, + kv_buf_k, + kv_buf_v, + kv_buf_mask, time_step + x_time, q_net_state, k_net_state, From 496f4f47809ec85aa8f4216bb7e4f896044726aa Mon Sep 17 00:00:00 2001 From: David Braun <2096055+DBraun@users.noreply.github.com> Date: Mon, 23 Feb 2026 23:56:13 -0500 Subject: [PATCH 03/17] Create depthformer.py --- notebooks/depthformer.py | 450 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 450 insertions(+) create mode 100644 notebooks/depthformer.py diff --git a/notebooks/depthformer.py b/notebooks/depthformer.py new file mode 100644 index 0000000..8794e02 --- /dev/null +++ b/notebooks/depthformer.py @@ -0,0 +1,450 @@ +# %% [markdown] +# # Depthformer: Hierarchical Audio Token Generation +# +# This example builds a [Moshi](https://github.com/kyutai-labs/moshi)-style +# **Depthformer** using MLX SequenceLayers. The Depthformer generates audio +# tokens hierarchically across RVQ (Residual Vector Quantization) codebook +# levels. +# +# ## Architecture +# +# ``` +# Text + Audio Embeddings (summed) +# | +# Main Transformer (causal, RoPE, SwiGLU) +# | RMSNorm +# |---> Text Head --> text logits +# \---> Depth Generation Loop +# for slice_i in [0..3]: +# x = project(main_out) + embed(prev_token) +# y = depth_transformer.step(x) <-- shared weights +# logits = project(y) --> sample token +# --> 4 audio codebook tokens +# ``` +# +# The **main transformer** is a standard causal decoder with RoPE and +# SwiGLU gating. It processes the sum of text + audio embeddings along +# the time dimension. +# +# The **depth transformer** processes the "depth" dimension (codebook +# levels), not time. 
At each main-transformer step, it autoregressively +# generates one token per codebook level, conditioned on the main +# transformer output and all previous codebook levels. Its KV cache +# resets between time steps. +# +# **Requires:** `pip install sequence-layers[mlx]` + +# %% [markdown] +# ## 1. Setup + +# %% +import jax.nn +import mlx.core as mx +import mlx.nn as nn +import mlx.utils + +import sequence_layers.jax as sl +from sequence_layers.mlx import basic_types as bt +from sequence_layers.mlx import export + +Sequence = bt.Sequence +ShapeDType = bt.ShapeDType + +# --- Hyperparameters (toy-sized for fast execution) --- + +# Main transformer: processes text + audio along time. +MAIN_DIM = 128 +MAIN_HEADS = 4 +MAIN_UPH = MAIN_DIM // MAIN_HEADS # 32 +MAIN_LAYERS = 4 +MAIN_HIDDEN = MAIN_DIM * 2 # 256 (SwiGLU intermediate) +MAIN_CONTEXT = 128 + +# Depth transformer: generates audio codebooks along depth. +DEPTH_DIM = 64 +DEPTH_HEADS = 4 +DEPTH_UPH = DEPTH_DIM // DEPTH_HEADS # 16 +DEPTH_LAYERS = 2 +DEPTH_HIDDEN = DEPTH_DIM * 2 # 128 + +# Vocabulary. +TEXT_VOCAB = 256 +AUDIO_VOCAB = 256 +NUM_AUDIO_CODEBOOKS = 8 # total codebooks +NUM_DEPTH_SLICES = 4 # generated by depth transformer +NUM_INPUT_CODEBOOKS = NUM_AUDIO_CODEBOOKS - NUM_DEPTH_SLICES # from input + +# Generation. +BATCH_SIZE = 1 +NUM_TOKENS = 16 + +print( + f'Main transformer: dim={MAIN_DIM}, heads={MAIN_HEADS}, ' + f'layers={MAIN_LAYERS}, hidden={MAIN_HIDDEN}, context={MAIN_CONTEXT}' +) +print( + f'Depth transformer: dim={DEPTH_DIM}, heads={DEPTH_HEADS}, ' + f'layers={DEPTH_LAYERS}, hidden={DEPTH_HIDDEN}, ' + f'slices={NUM_DEPTH_SLICES}' +) +print( + f'Vocab: text={TEXT_VOCAB}, audio={AUDIO_VOCAB}, ' + f'codebooks={NUM_AUDIO_CODEBOOKS} ({NUM_DEPTH_SLICES} generated + ' + f'{NUM_INPUT_CODEBOOKS} input)' +) + +# %% [markdown] +# ## 2. Architecture Configs +# +# Both transformers use pre-norm residual blocks with SwiGLU gated FFN. 
+# The main transformer has RoPE positional encoding; the depth transformer +# has **none** (the codebook ordering is fixed and learned implicitly). +# +# Each transformer block: +# ``` +# x + Flatten(SelfAttention(RMSNorm(x))) <-- attention residual +# x + Dense(GatedUnit(Dense(RMSNorm(x)))) <-- SwiGLU FFN residual +# ``` + + +# %% +def main_transformer_config(): + """Main causal transformer stack with RoPE and SwiGLU.""" + return sl.Repeat.Config( + num_repeats=MAIN_LAYERS, + layer=sl.Serial.Config([ + # Self-attention residual block. + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.DotProductSelfAttention.Config( + num_heads=MAIN_HEADS, + units_per_head=MAIN_UPH, + max_past_horizon=MAIN_CONTEXT, + max_future_horizon=0, + query_network=sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ), + key_network=sl.ApplyRotaryPositionalEncoding.Config( + max_wavelength=10_000.0, + ), + ), + sl.Flatten.Config(), + ]), + # SwiGLU FFN residual block. + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.Dense.Config(features=2 * MAIN_HIDDEN, use_bias=False), + sl.GatedUnit.Config( + feature_activation=jax.nn.silu, + gate_activation=None, + ), + sl.Dense.Config(features=MAIN_DIM, use_bias=False), + ]), + ]), + ) + + +def depth_transformer_config(): + """Depth transformer stack: no positional encoding, small context.""" + return sl.Repeat.Config( + num_repeats=DEPTH_LAYERS, + layer=sl.Serial.Config([ + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.DotProductSelfAttention.Config( + num_heads=DEPTH_HEADS, + units_per_head=DEPTH_UPH, + max_past_horizon=NUM_DEPTH_SLICES, + max_future_horizon=0, + # No positional encoding for depth! 
+ ), + sl.Flatten.Config(), + ]), + sl.Residual.Config([ + sl.RMSNormalization.Config(), + sl.Dense.Config(features=2 * DEPTH_HIDDEN, use_bias=False), + sl.GatedUnit.Config( + feature_activation=jax.nn.silu, + gate_activation=None, + ), + sl.Dense.Config(features=DEPTH_DIM, use_bias=False), + ]), + ]), + ) + + +main_config = main_transformer_config() +depth_config = depth_transformer_config() +print('Configs defined.') + +# %% [markdown] +# ## 3. Model Definition +# +# The `Depthformer` wraps two sequence-layers models (main + depth +# transformer stacks) with multi-modal embeddings and the depth +# generation loop. SequenceLayers handles KV caching, step/layer +# duality, and state management automatically. + +# %% + + +class DepthSlice(nn.Module): + """Per-codebook components: embedding, input/output projections. + + Each depth slice has: + - emb: embeds the previous slice's token (or text token for slice 0) + - linear_in: projects main transformer output to depth dimension + - linear_out: projects depth output to audio vocabulary logits + """ + + def __init__(self, in_vocab_size, out_vocab_size, main_dim, depth_dim): + super().__init__() + self.emb = nn.Embedding(in_vocab_size, depth_dim) + self.linear_in = nn.Linear(main_dim, depth_dim, bias=False) + self.linear_out = nn.Linear(depth_dim, out_vocab_size, bias=False) + + +class Depthformer(nn.Module): + """Moshi-style Depthformer: main transformer + depth generation. + + The main transformer processes combined text + audio embeddings along + the time dimension. The depth transformer generates audio tokens + hierarchically across codebook levels, conditioned on the main + transformer output. + + At each time step: + 1. Sum text + audio embeddings + 2. Step the main transformer (updates KV cache) + 3. Predict next text token + 4. Run depth generation: for each codebook slice, step the depth + transformer (shared weights, accumulating KV cache), and sample + the next audio token conditioned on previous slices. + 5. 
Reset depth KV cache (depth context doesn't persist across time) + """ + + def __init__( + self, + main_stack, + depth_stack, + main_dim, + depth_dim, + text_vocab, + audio_vocab, + num_codebooks, + num_slices, + ): + super().__init__() + self.main_stack = main_stack + self.depth_stack = depth_stack + self.text_emb = nn.Embedding(text_vocab, main_dim) + self.audio_embs = [ + nn.Embedding(audio_vocab, main_dim) for _ in range(num_codebooks) + ] + self.out_norm = nn.RMSNorm(main_dim) + self.text_head = nn.Linear(main_dim, text_vocab, bias=False) + + # Depth slices: first slice conditions on text, rest on audio. + self.slices = [] + for i in range(num_slices): + in_vocab = text_vocab if i == 0 else audio_vocab + self.slices.append(DepthSlice(in_vocab, audio_vocab, main_dim, depth_dim)) + + self._main_dim = main_dim + self._depth_dim = depth_dim + + def _embed(self, text_token, audio_tokens): + """Sum text + audio embeddings -> [batch, time, main_dim].""" + x = self.text_emb(text_token) + for tok, emb in zip(audio_tokens, self.audio_embs): + x = x + emb(tok) + return x + + def __call__(self, text_tokens, audio_tokens_list): + """Layer mode: full sequence through the main transformer. + + Args: + text_tokens: [batch, time] int32. + audio_tokens_list: list of num_codebooks [batch, time] arrays. + + Returns: + text_logits: [batch, time, text_vocab]. + hidden: [batch, time, main_dim]. + """ + x = self._embed(text_tokens, audio_tokens_list) + mask = mx.ones(x.shape[:2], dtype=mx.bool_) + y = self.main_stack.layer(Sequence(x, mask)) + h = self.out_norm(y.values) + text_logits = self.text_head(h) + return text_logits, h + + def get_initial_main_state(self, batch_size): + """Create initial state for the main transformer.""" + spec = ShapeDType((self._main_dim,), mx.float32) + return self.main_stack.get_initial_state(batch_size, spec) + + def generate_step(self, text_token, audio_tokens, main_state): + """One autoregressive step: main transformer + depth generation. 
+ + Args: + text_token: [batch, 1] int32. + audio_tokens: list of num_codebooks [batch, 1] int32 arrays. + main_state: main transformer KV cache state. + + Returns: + next_text: [batch, 1] int32. + next_audio: list of num_slices [batch, 1] int32 arrays. + new_main_state: updated main transformer state. + """ + # --- Main transformer step --- + x = self._embed(text_token, audio_tokens) + mask = mx.ones((x.shape[0], 1), dtype=mx.bool_) + y, main_state = self.main_stack.step(Sequence(x, mask), main_state) + h = self.out_norm(y.values) # [batch, 1, main_dim] + + # Text prediction. + text_logits = self.text_head(h) + next_text = mx.argmax(text_logits[:, 0, :], axis=-1, keepdims=True) + + # --- Depth generation --- + # Fresh KV cache for each time step (depth context = codebook levels). + depth_spec = ShapeDType((self._depth_dim,), mx.float32) + depth_state = self.depth_stack.get_initial_state(x.shape[0], depth_spec) + + prev_token = next_text # First slice conditions on text token. + next_audio = [] + + for s in self.slices: + # Combine main output projection + previous token embedding. + dx = s.linear_in(h) + s.emb(prev_token) + + # Step the shared depth transformer (accumulates KV cache). + dy, depth_state = self.depth_stack.step(Sequence(dx, mask), depth_state) + + # Sample next audio token. + logits = s.linear_out(dy.values) + prev_token = mx.argmax(logits[:, 0, :], axis=-1, keepdims=True) + next_audio.append(prev_token) + + return next_text, next_audio, main_state + + +# Build sequence-layers transformer stacks from configs. +main_stack = main_config.make(backend='mlx') +depth_stack = depth_config.make(backend='mlx') + +model = Depthformer( + main_stack, + depth_stack, + main_dim=MAIN_DIM, + depth_dim=DEPTH_DIM, + text_vocab=TEXT_VOCAB, + audio_vocab=AUDIO_VOCAB, + num_codebooks=NUM_AUDIO_CODEBOOKS, + num_slices=NUM_DEPTH_SLICES, +) + +# Materialize deferred layers (attention needs to know in_features). 
+main_input_spec = ShapeDType((MAIN_DIM,), mx.float32) +depth_input_spec = ShapeDType((DEPTH_DIM,), mx.float32) +export._materialize_deferred(model.main_stack, BATCH_SIZE, main_input_spec) +export._materialize_deferred(model.depth_stack, BATCH_SIZE, depth_input_spec) + +# Count parameters. +nparams = sum(v.size for _, v in mlx.utils.tree_flatten(model.parameters())) +print(f'Model built: {nparams:,} parameters') + +# %% [markdown] +# ## 4. Layer Mode +# +# Process a full sequence through the main transformer. This is useful +# for teacher forcing (training) or batch inference on prefilled context. + +# %% +seq_len = 8 + +# Random input tokens. +text_tokens = mx.random.randint(0, TEXT_VOCAB, shape=(BATCH_SIZE, seq_len)) +audio_tokens_list = [ + mx.random.randint(0, AUDIO_VOCAB, shape=(BATCH_SIZE, seq_len)) + for _ in range(NUM_AUDIO_CODEBOOKS) +] + +text_logits, hidden = model(text_tokens, audio_tokens_list) +mx.eval(text_logits, hidden) + +print( + f'Input: text {text_tokens.shape}, ' + f'audio {len(audio_tokens_list)}x{audio_tokens_list[0].shape}' +) +print(f'Text logits: {text_logits.shape}') +print(f'Hidden: {hidden.shape}') + +# %% [markdown] +# ## 5. Streaming Generation +# +# Token-by-token autoregressive generation. At each step: +# 1. The main transformer processes the current token (KV cache persists) +# 2. The depth transformer generates audio tokens across codebook levels +# (KV cache resets per time step) +# 3. Generated audio tokens feed back as input at the next time step + +# %% +# Initialize state. +main_state = model.get_initial_main_state(BATCH_SIZE) + +# Start with zero tokens (padding / start-of-sequence). 
+text_token = mx.zeros((BATCH_SIZE, 1), dtype=mx.int32) +audio_tokens = [ + mx.zeros((BATCH_SIZE, 1), dtype=mx.int32) + for _ in range(NUM_AUDIO_CODEBOOKS) +] + +generated_text = [] +generated_audio = [[] for _ in range(NUM_DEPTH_SLICES)] + +print(f'Generating {NUM_TOKENS} steps...\n') + +for t in range(NUM_TOKENS): + next_text, next_audio, main_state = model.generate_step( + text_token, audio_tokens, main_state + ) + mx.eval(next_text, *next_audio) + + # Record generated tokens. + generated_text.append(int(next_text[0, 0])) + for i in range(NUM_DEPTH_SLICES): + generated_audio[i].append(int(next_audio[i][0, 0])) + + # Prepare next input. + text_token = next_text + for i in range(NUM_DEPTH_SLICES): + audio_tokens[i] = next_audio[i] + # Input codebooks stay zero (no "environment" audio in this demo). + +print(f'Text tokens: {generated_text}') +for i in range(NUM_DEPTH_SLICES): + print(f'Audio cb {i}: {generated_audio[i]}') + +# %% [markdown] +# ## 6. Why SequenceLayers? +# +# Compare the Moshi codebase (~1000 lines of handwritten transformer, +# attention, KV cache, and streaming code) with the SequenceLayers +# approach: +# +# | Aspect | Moshi (handwritten) | SequenceLayers | +# |--------|-------------------|----------------| +# | Transformer | `Transformer`, `TransformerLayer`, `Attention`, `MlpGating` classes | `Repeat(Serial(Residual(Norm, Attn, Flatten), Residual(Norm, Dense, GatedUnit, Dense)))` config | +# | KV cache | Manual `KVCache`, `RotatingKVCache` classes + explicit offset tracking | Automatic via `step()` — ring buffer with `put_along_axis` | +# | Streaming | Custom `__call__` with cache threading | `model.step(x, state)` returns updated state | +# | Layer / Step | Separate code paths | `layer()` and `step()` from one config, verified equivalent | +# | Positional encoding | Separate `nn.RoPE` object, manually applied | `query_network=ApplyRotaryPositionalEncoding.Config(...)` | +# | Export | N/A | `export.export_step(model, path, ...)` → `.mlxfn` | +# 
+# The model-specific code (embeddings, depth generation loop) stays +# custom, while the transformer building blocks come from SequenceLayers +# with all the streaming, caching, and state management handled +# automatically. + +# %% +print('Done!') From 001d31a8d811db38064240c558cb81c5bff5356b Mon Sep 17 00:00:00 2001 From: David Braun <2096055+DBraun@users.noreply.github.com> Date: Thu, 26 Feb 2026 09:17:44 -0500 Subject: [PATCH 04/17] Squashed commit of the following: commit f2f2198d51f3ecfec3564d390b783c30b225c74e Author: David Braun <2096055+DBraun@users.noreply.github.com> Date: Thu Feb 26 09:06:30 2026 -0500 remove depthformer and moshi_demo.py commit 74669fcf89ee791519b065c7b07102291e22392b Author: Kehang Han Date: Wed Feb 25 15:15:26 2026 -0800 This commit finally makes MLX generate good music! commit 9e4a452e44abb2fc01e9c4ddafb13d18019fffe2 Author: Kehang Han Date: Wed Feb 25 12:47:09 2026 -0800 Changes to make temporal output match between jax and mlx commit 8d267f25707c99ea13fac97449116f0aea96c4e6 Author: David Braun <2096055+DBraun@users.noreply.github.com> Date: Wed Feb 25 12:45:53 2026 -0500 Update weight_converter.py commit 68add71024e104bf58ea9041605e48e2ae27f400 Author: David Braun <2096055+DBraun@users.noreply.github.com> Date: Wed Feb 25 12:41:52 2026 -0500 MLX attention: add per_dim_scale commit b4ae9057effae0842a92d818a70ce161dc93e330 Author: David Braun <2096055+DBraun@users.noreply.github.com> Date: Wed Feb 25 11:31:02 2026 -0500 Update normalization.py commit a3e84cf09f72922162e25f4075287ab5196dd05e Author: Kehang Han Date: Tue Feb 24 16:04:27 2026 -0800 To make encoded approx bit level compatible commit c6596c753d4b4b51478bc463a9b9e68b2acbc138 Author: Kehang Han Date: Tue Feb 24 14:35:16 2026 -0800 Updates to MLX port of SL to support magenta-rt commit 08077bb5e2edb4e5932bf4572712989837489a7b Author: David Braun <2096055+DBraun@users.noreply.github.com> Date: Tue Feb 24 08:54:43 2026 -0500 Create moshi_demo.py --- notebooks/depthformer.py 
| 450 -------- sequence_layers/mlx/__init__.py | 78 +- sequence_layers/mlx/attention.py | 377 ++++++- sequence_layers/mlx/attention_test.py | 54 + sequence_layers/mlx/combinators.py | 164 ++- sequence_layers/mlx/conditioning.py | 17 + sequence_layers/mlx/convolution.py | 36 +- sequence_layers/mlx/convolution2d.py | 1143 +++++++++++++++++++++ sequence_layers/mlx/dense.py | 55 +- sequence_layers/mlx/dsp.py | 60 +- sequence_layers/mlx/normalization.py | 58 +- sequence_layers/mlx/position.py | 15 + sequence_layers/mlx/projection_configs.py | 124 +++ sequence_layers/mlx/signal.py | 62 ++ sequence_layers/mlx/simple.py | 166 ++- sequence_layers/mlx/types.py | 14 +- sequence_layers/mlx/typing.py | 43 + sequence_layers/mlx/utils.py | 72 ++ sequence_layers/mlx/weight_converter.py | 46 +- 19 files changed, 2424 insertions(+), 610 deletions(-) delete mode 100644 notebooks/depthformer.py create mode 100644 sequence_layers/mlx/convolution2d.py create mode 100644 sequence_layers/mlx/projection_configs.py create mode 100644 sequence_layers/mlx/signal.py create mode 100644 sequence_layers/mlx/typing.py create mode 100644 sequence_layers/mlx/utils.py diff --git a/notebooks/depthformer.py b/notebooks/depthformer.py deleted file mode 100644 index 8794e02..0000000 --- a/notebooks/depthformer.py +++ /dev/null @@ -1,450 +0,0 @@ -# %% [markdown] -# # Depthformer: Hierarchical Audio Token Generation -# -# This example builds a [Moshi](https://github.com/kyutai-labs/moshi)-style -# **Depthformer** using MLX SequenceLayers. The Depthformer generates audio -# tokens hierarchically across RVQ (Residual Vector Quantization) codebook -# levels. 
-# -# ## Architecture -# -# ``` -# Text + Audio Embeddings (summed) -# | -# Main Transformer (causal, RoPE, SwiGLU) -# | RMSNorm -# |---> Text Head --> text logits -# \---> Depth Generation Loop -# for slice_i in [0..3]: -# x = project(main_out) + embed(prev_token) -# y = depth_transformer.step(x) <-- shared weights -# logits = project(y) --> sample token -# --> 4 audio codebook tokens -# ``` -# -# The **main transformer** is a standard causal decoder with RoPE and -# SwiGLU gating. It processes the sum of text + audio embeddings along -# the time dimension. -# -# The **depth transformer** processes the "depth" dimension (codebook -# levels), not time. At each main-transformer step, it autoregressively -# generates one token per codebook level, conditioned on the main -# transformer output and all previous codebook levels. Its KV cache -# resets between time steps. -# -# **Requires:** `pip install sequence-layers[mlx]` - -# %% [markdown] -# ## 1. Setup - -# %% -import jax.nn -import mlx.core as mx -import mlx.nn as nn -import mlx.utils - -import sequence_layers.jax as sl -from sequence_layers.mlx import basic_types as bt -from sequence_layers.mlx import export - -Sequence = bt.Sequence -ShapeDType = bt.ShapeDType - -# --- Hyperparameters (toy-sized for fast execution) --- - -# Main transformer: processes text + audio along time. -MAIN_DIM = 128 -MAIN_HEADS = 4 -MAIN_UPH = MAIN_DIM // MAIN_HEADS # 32 -MAIN_LAYERS = 4 -MAIN_HIDDEN = MAIN_DIM * 2 # 256 (SwiGLU intermediate) -MAIN_CONTEXT = 128 - -# Depth transformer: generates audio codebooks along depth. -DEPTH_DIM = 64 -DEPTH_HEADS = 4 -DEPTH_UPH = DEPTH_DIM // DEPTH_HEADS # 16 -DEPTH_LAYERS = 2 -DEPTH_HIDDEN = DEPTH_DIM * 2 # 128 - -# Vocabulary. -TEXT_VOCAB = 256 -AUDIO_VOCAB = 256 -NUM_AUDIO_CODEBOOKS = 8 # total codebooks -NUM_DEPTH_SLICES = 4 # generated by depth transformer -NUM_INPUT_CODEBOOKS = NUM_AUDIO_CODEBOOKS - NUM_DEPTH_SLICES # from input - -# Generation. 
-BATCH_SIZE = 1 -NUM_TOKENS = 16 - -print( - f'Main transformer: dim={MAIN_DIM}, heads={MAIN_HEADS}, ' - f'layers={MAIN_LAYERS}, hidden={MAIN_HIDDEN}, context={MAIN_CONTEXT}' -) -print( - f'Depth transformer: dim={DEPTH_DIM}, heads={DEPTH_HEADS}, ' - f'layers={DEPTH_LAYERS}, hidden={DEPTH_HIDDEN}, ' - f'slices={NUM_DEPTH_SLICES}' -) -print( - f'Vocab: text={TEXT_VOCAB}, audio={AUDIO_VOCAB}, ' - f'codebooks={NUM_AUDIO_CODEBOOKS} ({NUM_DEPTH_SLICES} generated + ' - f'{NUM_INPUT_CODEBOOKS} input)' -) - -# %% [markdown] -# ## 2. Architecture Configs -# -# Both transformers use pre-norm residual blocks with SwiGLU gated FFN. -# The main transformer has RoPE positional encoding; the depth transformer -# has **none** (the codebook ordering is fixed and learned implicitly). -# -# Each transformer block: -# ``` -# x + Flatten(SelfAttention(RMSNorm(x))) <-- attention residual -# x + Dense(GatedUnit(Dense(RMSNorm(x)))) <-- SwiGLU FFN residual -# ``` - - -# %% -def main_transformer_config(): - """Main causal transformer stack with RoPE and SwiGLU.""" - return sl.Repeat.Config( - num_repeats=MAIN_LAYERS, - layer=sl.Serial.Config([ - # Self-attention residual block. - sl.Residual.Config([ - sl.RMSNormalization.Config(), - sl.DotProductSelfAttention.Config( - num_heads=MAIN_HEADS, - units_per_head=MAIN_UPH, - max_past_horizon=MAIN_CONTEXT, - max_future_horizon=0, - query_network=sl.ApplyRotaryPositionalEncoding.Config( - max_wavelength=10_000.0, - ), - key_network=sl.ApplyRotaryPositionalEncoding.Config( - max_wavelength=10_000.0, - ), - ), - sl.Flatten.Config(), - ]), - # SwiGLU FFN residual block. 
- sl.Residual.Config([ - sl.RMSNormalization.Config(), - sl.Dense.Config(features=2 * MAIN_HIDDEN, use_bias=False), - sl.GatedUnit.Config( - feature_activation=jax.nn.silu, - gate_activation=None, - ), - sl.Dense.Config(features=MAIN_DIM, use_bias=False), - ]), - ]), - ) - - -def depth_transformer_config(): - """Depth transformer stack: no positional encoding, small context.""" - return sl.Repeat.Config( - num_repeats=DEPTH_LAYERS, - layer=sl.Serial.Config([ - sl.Residual.Config([ - sl.RMSNormalization.Config(), - sl.DotProductSelfAttention.Config( - num_heads=DEPTH_HEADS, - units_per_head=DEPTH_UPH, - max_past_horizon=NUM_DEPTH_SLICES, - max_future_horizon=0, - # No positional encoding for depth! - ), - sl.Flatten.Config(), - ]), - sl.Residual.Config([ - sl.RMSNormalization.Config(), - sl.Dense.Config(features=2 * DEPTH_HIDDEN, use_bias=False), - sl.GatedUnit.Config( - feature_activation=jax.nn.silu, - gate_activation=None, - ), - sl.Dense.Config(features=DEPTH_DIM, use_bias=False), - ]), - ]), - ) - - -main_config = main_transformer_config() -depth_config = depth_transformer_config() -print('Configs defined.') - -# %% [markdown] -# ## 3. Model Definition -# -# The `Depthformer` wraps two sequence-layers models (main + depth -# transformer stacks) with multi-modal embeddings and the depth -# generation loop. SequenceLayers handles KV caching, step/layer -# duality, and state management automatically. - -# %% - - -class DepthSlice(nn.Module): - """Per-codebook components: embedding, input/output projections. 
- - Each depth slice has: - - emb: embeds the previous slice's token (or text token for slice 0) - - linear_in: projects main transformer output to depth dimension - - linear_out: projects depth output to audio vocabulary logits - """ - - def __init__(self, in_vocab_size, out_vocab_size, main_dim, depth_dim): - super().__init__() - self.emb = nn.Embedding(in_vocab_size, depth_dim) - self.linear_in = nn.Linear(main_dim, depth_dim, bias=False) - self.linear_out = nn.Linear(depth_dim, out_vocab_size, bias=False) - - -class Depthformer(nn.Module): - """Moshi-style Depthformer: main transformer + depth generation. - - The main transformer processes combined text + audio embeddings along - the time dimension. The depth transformer generates audio tokens - hierarchically across codebook levels, conditioned on the main - transformer output. - - At each time step: - 1. Sum text + audio embeddings - 2. Step the main transformer (updates KV cache) - 3. Predict next text token - 4. Run depth generation: for each codebook slice, step the depth - transformer (shared weights, accumulating KV cache), and sample - the next audio token conditioned on previous slices. - 5. Reset depth KV cache (depth context doesn't persist across time) - """ - - def __init__( - self, - main_stack, - depth_stack, - main_dim, - depth_dim, - text_vocab, - audio_vocab, - num_codebooks, - num_slices, - ): - super().__init__() - self.main_stack = main_stack - self.depth_stack = depth_stack - self.text_emb = nn.Embedding(text_vocab, main_dim) - self.audio_embs = [ - nn.Embedding(audio_vocab, main_dim) for _ in range(num_codebooks) - ] - self.out_norm = nn.RMSNorm(main_dim) - self.text_head = nn.Linear(main_dim, text_vocab, bias=False) - - # Depth slices: first slice conditions on text, rest on audio. 
- self.slices = [] - for i in range(num_slices): - in_vocab = text_vocab if i == 0 else audio_vocab - self.slices.append(DepthSlice(in_vocab, audio_vocab, main_dim, depth_dim)) - - self._main_dim = main_dim - self._depth_dim = depth_dim - - def _embed(self, text_token, audio_tokens): - """Sum text + audio embeddings -> [batch, time, main_dim].""" - x = self.text_emb(text_token) - for tok, emb in zip(audio_tokens, self.audio_embs): - x = x + emb(tok) - return x - - def __call__(self, text_tokens, audio_tokens_list): - """Layer mode: full sequence through the main transformer. - - Args: - text_tokens: [batch, time] int32. - audio_tokens_list: list of num_codebooks [batch, time] arrays. - - Returns: - text_logits: [batch, time, text_vocab]. - hidden: [batch, time, main_dim]. - """ - x = self._embed(text_tokens, audio_tokens_list) - mask = mx.ones(x.shape[:2], dtype=mx.bool_) - y = self.main_stack.layer(Sequence(x, mask)) - h = self.out_norm(y.values) - text_logits = self.text_head(h) - return text_logits, h - - def get_initial_main_state(self, batch_size): - """Create initial state for the main transformer.""" - spec = ShapeDType((self._main_dim,), mx.float32) - return self.main_stack.get_initial_state(batch_size, spec) - - def generate_step(self, text_token, audio_tokens, main_state): - """One autoregressive step: main transformer + depth generation. - - Args: - text_token: [batch, 1] int32. - audio_tokens: list of num_codebooks [batch, 1] int32 arrays. - main_state: main transformer KV cache state. - - Returns: - next_text: [batch, 1] int32. - next_audio: list of num_slices [batch, 1] int32 arrays. - new_main_state: updated main transformer state. - """ - # --- Main transformer step --- - x = self._embed(text_token, audio_tokens) - mask = mx.ones((x.shape[0], 1), dtype=mx.bool_) - y, main_state = self.main_stack.step(Sequence(x, mask), main_state) - h = self.out_norm(y.values) # [batch, 1, main_dim] - - # Text prediction. 
- text_logits = self.text_head(h) - next_text = mx.argmax(text_logits[:, 0, :], axis=-1, keepdims=True) - - # --- Depth generation --- - # Fresh KV cache for each time step (depth context = codebook levels). - depth_spec = ShapeDType((self._depth_dim,), mx.float32) - depth_state = self.depth_stack.get_initial_state(x.shape[0], depth_spec) - - prev_token = next_text # First slice conditions on text token. - next_audio = [] - - for s in self.slices: - # Combine main output projection + previous token embedding. - dx = s.linear_in(h) + s.emb(prev_token) - - # Step the shared depth transformer (accumulates KV cache). - dy, depth_state = self.depth_stack.step(Sequence(dx, mask), depth_state) - - # Sample next audio token. - logits = s.linear_out(dy.values) - prev_token = mx.argmax(logits[:, 0, :], axis=-1, keepdims=True) - next_audio.append(prev_token) - - return next_text, next_audio, main_state - - -# Build sequence-layers transformer stacks from configs. -main_stack = main_config.make(backend='mlx') -depth_stack = depth_config.make(backend='mlx') - -model = Depthformer( - main_stack, - depth_stack, - main_dim=MAIN_DIM, - depth_dim=DEPTH_DIM, - text_vocab=TEXT_VOCAB, - audio_vocab=AUDIO_VOCAB, - num_codebooks=NUM_AUDIO_CODEBOOKS, - num_slices=NUM_DEPTH_SLICES, -) - -# Materialize deferred layers (attention needs to know in_features). -main_input_spec = ShapeDType((MAIN_DIM,), mx.float32) -depth_input_spec = ShapeDType((DEPTH_DIM,), mx.float32) -export._materialize_deferred(model.main_stack, BATCH_SIZE, main_input_spec) -export._materialize_deferred(model.depth_stack, BATCH_SIZE, depth_input_spec) - -# Count parameters. -nparams = sum(v.size for _, v in mlx.utils.tree_flatten(model.parameters())) -print(f'Model built: {nparams:,} parameters') - -# %% [markdown] -# ## 4. Layer Mode -# -# Process a full sequence through the main transformer. This is useful -# for teacher forcing (training) or batch inference on prefilled context. 
- -# %% -seq_len = 8 - -# Random input tokens. -text_tokens = mx.random.randint(0, TEXT_VOCAB, shape=(BATCH_SIZE, seq_len)) -audio_tokens_list = [ - mx.random.randint(0, AUDIO_VOCAB, shape=(BATCH_SIZE, seq_len)) - for _ in range(NUM_AUDIO_CODEBOOKS) -] - -text_logits, hidden = model(text_tokens, audio_tokens_list) -mx.eval(text_logits, hidden) - -print( - f'Input: text {text_tokens.shape}, ' - f'audio {len(audio_tokens_list)}x{audio_tokens_list[0].shape}' -) -print(f'Text logits: {text_logits.shape}') -print(f'Hidden: {hidden.shape}') - -# %% [markdown] -# ## 5. Streaming Generation -# -# Token-by-token autoregressive generation. At each step: -# 1. The main transformer processes the current token (KV cache persists) -# 2. The depth transformer generates audio tokens across codebook levels -# (KV cache resets per time step) -# 3. Generated audio tokens feed back as input at the next time step - -# %% -# Initialize state. -main_state = model.get_initial_main_state(BATCH_SIZE) - -# Start with zero tokens (padding / start-of-sequence). -text_token = mx.zeros((BATCH_SIZE, 1), dtype=mx.int32) -audio_tokens = [ - mx.zeros((BATCH_SIZE, 1), dtype=mx.int32) - for _ in range(NUM_AUDIO_CODEBOOKS) -] - -generated_text = [] -generated_audio = [[] for _ in range(NUM_DEPTH_SLICES)] - -print(f'Generating {NUM_TOKENS} steps...\n') - -for t in range(NUM_TOKENS): - next_text, next_audio, main_state = model.generate_step( - text_token, audio_tokens, main_state - ) - mx.eval(next_text, *next_audio) - - # Record generated tokens. - generated_text.append(int(next_text[0, 0])) - for i in range(NUM_DEPTH_SLICES): - generated_audio[i].append(int(next_audio[i][0, 0])) - - # Prepare next input. - text_token = next_text - for i in range(NUM_DEPTH_SLICES): - audio_tokens[i] = next_audio[i] - # Input codebooks stay zero (no "environment" audio in this demo). 
- -print(f'Text tokens: {generated_text}') -for i in range(NUM_DEPTH_SLICES): - print(f'Audio cb {i}: {generated_audio[i]}') - -# %% [markdown] -# ## 6. Why SequenceLayers? -# -# Compare the Moshi codebase (~1000 lines of handwritten transformer, -# attention, KV cache, and streaming code) with the SequenceLayers -# approach: -# -# | Aspect | Moshi (handwritten) | SequenceLayers | -# |--------|-------------------|----------------| -# | Transformer | `Transformer`, `TransformerLayer`, `Attention`, `MlpGating` classes | `Repeat(Serial(Residual(Norm, Attn, Flatten), Residual(Norm, Dense, GatedUnit, Dense)))` config | -# | KV cache | Manual `KVCache`, `RotatingKVCache` classes + explicit offset tracking | Automatic via `step()` — ring buffer with `put_along_axis` | -# | Streaming | Custom `__call__` with cache threading | `model.step(x, state)` returns updated state | -# | Layer / Step | Separate code paths | `layer()` and `step()` from one config, verified equivalent | -# | Positional encoding | Separate `nn.RoPE` object, manually applied | `query_network=ApplyRotaryPositionalEncoding.Config(...)` | -# | Export | N/A | `export.export_step(model, path, ...)` → `.mlxfn` | -# -# The model-specific code (embeddings, depth generation loop) stays -# custom, while the transformer building blocks come from SequenceLayers -# with all the streaming, caching, and state management handled -# automatically. - -# %% -print('Done!') diff --git a/sequence_layers/mlx/__init__.py b/sequence_layers/mlx/__init__.py index 81e47ba..ec647af 100644 --- a/sequence_layers/mlx/__init__.py +++ b/sequence_layers/mlx/__init__.py @@ -32,10 +32,12 @@ # Re-export basic_types TypeVars used in type annotations. from sequence_layers.mlx.types import MaskT -# Re-export attention projection configs (from JAX, used to configure attention). 
-from sequence_layers.jax.attention.common import CombinedQueryKeyValueProjection -from sequence_layers.jax.attention.common import QueryAndKeyValueProjection -from sequence_layers.jax.attention.common import SeparateQueryKeyValueProjection +# Re-export attention projection configs (MLX-native, no JAX dependency). +from sequence_layers.mlx.projection_configs import CombinedQueryKeyValueProjection +from sequence_layers.mlx.projection_configs import QueryAndKeyValueProjection +from sequence_layers.mlx.projection_configs import QueryAndSharedKeyValueProjection +from sequence_layers.mlx.projection_configs import QueryKeyValueProjectionConfig +from sequence_layers.mlx.projection_configs import SeparateQueryKeyValueProjection # Re-export MLX layer hierarchy. from sequence_layers.mlx.types import ChannelSpec @@ -117,6 +119,15 @@ from sequence_layers.mlx.convolution import DeferredDepthwiseConv1D from sequence_layers.mlx.convolution import DepthwiseConv1D +# Re-export 2D convolution/pooling/upsampling layers. +from sequence_layers.mlx.convolution2d import AveragePooling2D +from sequence_layers.mlx.convolution2d import Conv2D +from sequence_layers.mlx.convolution2d import Conv2DTranspose +from sequence_layers.mlx.convolution2d import DeferredConv2D +from sequence_layers.mlx.convolution2d import DeferredConv2DTranspose +from sequence_layers.mlx.convolution2d import ParallelChannels +from sequence_layers.mlx.convolution2d import Upsample2D + # Re-export DSP layers. from sequence_layers.mlx.dsp import Delay from sequence_layers.mlx.dsp import FFT @@ -137,6 +148,8 @@ from sequence_layers.mlx.combinators import Repeat from sequence_layers.mlx.combinators import Residual from sequence_layers.mlx.combinators import Serial +from sequence_layers.mlx.combinators import SerialCombinatorMixin +from sequence_layers.mlx.combinators import SerialModules # Re-export conditioning. 
from sequence_layers.mlx.conditioning import Conditioning @@ -148,6 +161,9 @@ # --------------------------------------------------------------------------- # Backend factory registration # --------------------------------------------------------------------------- +# Re-export SequenceLayerConfig (lives in JAX types but is backend-agnostic). +from sequence_layers.jax.types import SequenceLayerConfig + from sequence_layers.jax.types import SequenceLayerConfig as _SLC @@ -188,6 +204,7 @@ def _register_backends(): from sequence_layers.mlx import pooling as mlx_pool from sequence_layers.mlx import dsp as mlx_dsp from sequence_layers.mlx import combinators as mlx_comb + from sequence_layers.mlx import convolution2d as mlx_conv2d reg = _SLC.register_backend_factory @@ -338,6 +355,19 @@ def _register_backends(): mlx_conv.Conv1DTranspose.from_config, ) + # 2D Convolution. + reg('mlx', jax_conv.Conv2D.Config, mlx_conv2d.Conv2D.from_config) + reg('mlx', jax_conv.Conv2DTranspose.Config, mlx_conv2d.Conv2DTranspose.from_config) + + # 2D Pooling. + reg('mlx', jax_pool.AveragePooling2D.Config, mlx_conv2d.AveragePooling2D.from_config) + + # 2D Upsampling. + reg('mlx', jax_simple.Upsample2D.Config, mlx_conv2d.Upsample2D.from_config) + + # ParallelChannels. + reg('mlx', jax_comb.ParallelChannels.Config, mlx_conv2d.ParallelChannels.from_config) + # Pooling. reg('mlx', jax_pool.MaxPooling1D.Config, mlx_pool.MaxPooling1D.from_config) reg('mlx', jax_pool.MinPooling1D.Config, mlx_pool.MinPooling1D.from_config) @@ -371,5 +401,45 @@ def _register_backends(): reg('mlx', jax_comb.Repeat.Config, mlx_comb.Repeat.from_config) reg('mlx', jax_comb.Parallel.Config, mlx_comb.Parallel.from_config) + # --------------------------------------------------------------- + # MLX-native Config classes. + # These mirror the JAX Configs but are defined directly on the MLX + # layer classes, so they also need backend registration. 
+ # --------------------------------------------------------------- + reg('mlx', mlx_simple.Identity.Config, mlx_simple.Identity.from_config) + reg('mlx', mlx_simple.Dropout.Config, mlx_simple.Dropout.from_config) + reg('mlx', mlx_simple.CheckpointName.Config, mlx_simple.CheckpointName.from_config) + reg('mlx', mlx_simple.GatedUnit.Config, mlx_simple.GatedUnit.from_config) + reg('mlx', mlx_dense.DenseDeferred.Config, mlx_dense.DenseDeferred.from_config) + reg('mlx', mlx_dense.EinsumDense.Config, mlx_dense.EinsumDense.from_config) + reg('mlx', mlx_norm.RMSNormalization.Config, mlx_norm.RMSNormalization.from_config) + reg('mlx', mlx_pos.ApplyRotaryPositionalEncoding.Config, mlx_pos.ApplyRotaryPositionalEncoding.from_config) + reg('mlx', mlx_dsp.Delay.Config, mlx_dsp.Delay.from_config) + reg('mlx', mlx_comb.Serial.Config, mlx_comb.Serial.from_config) + reg('mlx', mlx_comb.Residual.Config, mlx_comb.Residual.from_config) + reg('mlx', mlx_cond.Conditioning.Config, mlx_cond.Conditioning.from_config) + reg('mlx', mlx_attn.DotProductSelfAttention.Config, mlx_attn.DotProductSelfAttention.from_config) + reg('mlx', mlx_attn.DotProductAttention.Config, mlx_attn.DotProductAttention.from_config) + reg('mlx', mlx_attn.StreamingDotProductAttention.Config, mlx_attn.StreamingDotProductAttention.from_config) + reg('mlx', mlx_attn.LocalDotProductSelfAttention.Config, mlx_attn.LocalDotProductSelfAttention.from_config) + reg('mlx', mlx_simple.Elu.Config, mlx_simple.Elu.from_config) + reg('mlx', mlx_simple.Cast.Config, mlx_simple.Cast.from_config) + reg('mlx', mlx_simple.Flatten.Config, mlx_simple.Flatten.from_config) + reg('mlx', mlx_simple.Reshape.Config, mlx_simple.Reshape.from_config) + reg('mlx', mlx_simple.ExpandDims.Config, mlx_simple.ExpandDims.from_config) + reg('mlx', mlx_simple.Lambda.Config, mlx_simple.Lambda.from_config) + reg('mlx', mlx_dsp.Lookahead.Config, mlx_dsp.Lookahead.from_config) + reg('mlx', mlx_dsp.STFT.Config, mlx_dsp.STFT.from_config) + reg('mlx', 
mlx_dsp.InverseSTFT.Config, mlx_dsp.InverseSTFT.from_config) + reg('mlx', mlx_conv2d.Conv2D.Config, mlx_conv2d.Conv2D.from_config) + reg('mlx', mlx_conv2d.Conv2DTranspose.Config, mlx_conv2d.Conv2DTranspose.from_config) + reg('mlx', mlx_conv2d.AveragePooling2D.Config, mlx_conv2d.AveragePooling2D.from_config) + reg('mlx', mlx_conv2d.Upsample2D.Config, mlx_conv2d.Upsample2D.from_config) + reg('mlx', mlx_conv2d.ParallelChannels.Config, mlx_conv2d.ParallelChannels.from_config) + reg('mlx', mlx_simple.Embedding.Config, mlx_simple.Embedding.from_config) + reg('mlx', mlx_simple.Scale.Config, mlx_simple.Scale.from_config) + reg('mlx', mlx_simple.Logging.Config, mlx_simple.Logging.from_config) + reg('mlx', mlx_norm.LayerNormalization.Config, mlx_norm.LayerNormalization.from_config) + _register_backends() diff --git a/sequence_layers/mlx/attention.py b/sequence_layers/mlx/attention.py index e869b2a..82cbdcc 100644 --- a/sequence_layers/mlx/attention.py +++ b/sequence_layers/mlx/attention.py @@ -1,5 +1,6 @@ """Dot-product attention layers for MLX.""" +import dataclasses import math import mlx.core as mx @@ -7,12 +8,41 @@ from sequence_layers.mlx import basic_types as bt from sequence_layers.mlx import init_mapping +from sequence_layers.mlx import projection_configs from sequence_layers.mlx import types +from sequence_layers.jax.types import SequenceLayerConfig as _SequenceLayerConfig Sequence = bt.Sequence MaskedSequence = bt.MaskedSequence +def _scale_queries(queries, per_dim_scale, query_scale, units_per_head): + """Scale queries, optionally with per-dimension learned scale. + + Matches JAX backend's _scale_query in common.py. + + Args: + queries: [b, num_heads, q_time, units_per_head]. + per_dim_scale: [units_per_head] learned scale or None. + query_scale: float scale or None (defaults to 1/sqrt(uph)). + units_per_head: int. + + Returns: + Scaled queries, same shape. 
+ """ + if query_scale is None: + query_scale = 1.0 / math.sqrt(units_per_head) + if per_dim_scale is not None: + # 1/softplus(0) = 1/ln(2). At init (zeros), effective scale = query_scale. + r_softplus_0 = 1.442695041 + scale = r_softplus_0 * query_scale + softplus = mx.log1p(mx.exp(per_dim_scale.astype(queries.dtype))) + queries = queries * (scale * softplus) + else: + queries = queries * query_scale + return queries + + def _causal_mask(q_len, kv_len): """Build a [1, 1, q_len, kv_len] causal mask (True = attend).""" # Each query at position i can attend to keys at positions @@ -42,6 +72,40 @@ class DotProductSelfAttention(types.Emitting): out_proj: [num_heads * units_per_head, in_features] """ + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + """MLX-native configuration for DotProductSelfAttention.""" + + num_heads: int + units_per_head: int + max_past_horizon: int + max_future_horizon: int = 0 + num_kv_heads: int | None = None + attention_probabilities_dropout_rate: float = 0.0 + broadcast_dropout_across_queries: bool = False + use_bias: bool = False + input_projection: projection_configs.QueryKeyValueProjectionConfig = ( + dataclasses.field( + default_factory=projection_configs.CombinedQueryKeyValueProjection + ) + ) + query_network: _SequenceLayerConfig | None = None + key_network: _SequenceLayerConfig | None = None + value_network: _SequenceLayerConfig | None = None + attention_logits_soft_cap: float | None = None + per_dim_scale: bool = False + query_scale: float | None = None + zero_fully_masked: bool = False + compute_dtype: types.DType | None = None + param_dtype: types.DType = mx.float32 + num_sink_embeddings: int = 0 + use_sink_scalars: bool = False + use_kv_cache_ringbuffer: bool = False + name: str | None = None + + def make(self) -> 'DotProductSelfAttention': + return DotProductSelfAttention.from_config(self) + def __init__( self, *, @@ -53,6 +117,7 @@ def __init__( num_kv_heads: int | None = None, use_bias: bool = False, 
query_scale: float | None = None, + per_dim_scale: bool = False, compute_dtype=None, param_dtype=mx.float32, kernel_init=None, @@ -61,6 +126,7 @@ def __init__( key_network: types.SequenceLayer | None = None, value_network: types.SequenceLayer | None = None, attention_logits_soft_cap: float | None = None, + num_sink_embeddings: int = 0, ): super().__init__() if num_kv_heads is None: @@ -87,6 +153,11 @@ def __init__( self.compute_dtype = compute_dtype self._param_dtype = param_dtype self._attention_logits_soft_cap = attention_logits_soft_cap + self._per_dim_scale = ( + mx.zeros((units_per_head,), dtype=param_dtype) + if per_dim_scale + else None + ) if kernel_init is None: kernel_init = init_mapping._make_variance_scaling_init( @@ -108,6 +179,19 @@ def __init__( self.k_bias = bias_init(key, (kv_dim,), param_dtype) self.v_bias = bias_init(key, (kv_dim,), param_dtype) + # Attention sink embeddings. + self.num_sink_embeddings = num_sink_embeddings + if num_sink_embeddings > 0: + self.sink_key_embeddings = mx.zeros( + (num_sink_embeddings, num_heads, units_per_head), dtype=param_dtype + ) + self.sink_value_embeddings = mx.zeros( + (num_sink_embeddings, num_kv_heads, units_per_head), dtype=param_dtype + ) + else: + self.sink_key_embeddings = None + self.sink_value_embeddings = None + self.query_network = query_network self.key_network = key_network self.value_network = value_network @@ -158,8 +242,6 @@ def _compute_attention(self, queries, keys, values, mask): Returns: context: [b, q_t, num_heads, units_per_head] """ - scale = self._query_scale or (1.0 / math.sqrt(self.units_per_head)) - # GQA: repeat K/V heads to match query heads. num_groups = self.num_heads // self.num_kv_heads if num_groups > 1: @@ -173,9 +255,43 @@ def _compute_attention(self, queries, keys, values, mask): v = mx.transpose(values, (0, 2, 1, 3)) # [b, nh, kvt, h] # Scaled dot-product attention. - q = q * scale + # Compute sink logits BEFORE scaling queries, matching JAX behavior. 
+ # JAX computes sink_key_logits = einsum('BTNH,KNH->BNTK', queries.values, + # sink_key_embeddings) before _scale_query(). + if self.sink_key_embeddings is not None: + sink_k = self.sink_key_embeddings.astype(q.dtype) # [K, nh, h] + sink_k_t = mx.transpose(sink_k, (1, 2, 0)) # [nh, h, K] + sink_logits = mx.matmul(q, sink_k_t) # [b, nh, qt, K] + + q = _scale_queries( + q, self._per_dim_scale, self._query_scale, self.units_per_head + ) logits = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) + # Add attention sink logits if present. + if self.sink_key_embeddings is not None: + # Prepend sink values to v: v becomes [b, nh, K+kvt, h] + sink_v = self.sink_value_embeddings.astype(v.dtype) # [K, nkv, h] + if num_groups > 1: + sink_v = mx.repeat(sink_v, num_groups, axis=1) + sink_v_t = mx.transpose(sink_v, (1, 0, 2)) # [nh, K, h] + sink_v_b = mx.broadcast_to( + sink_v_t[None], (v.shape[0],) + sink_v_t.shape + ) # [b, nh, K, h] + v = mx.concatenate([sink_v_b, v], axis=2) # [b, nh, K+kvt, h] + + # Prepend sink logits to logits: [b, nh, qt, K+kvt] + logits = mx.concatenate([sink_logits, logits], axis=-1) + + # Extend mask for sinks (always valid). + if mask is not None: + num_sinks = self.sink_key_embeddings.shape[0] + sink_mask = mx.ones( + (mask.shape[0], mask.shape[1], mask.shape[2], num_sinks), + dtype=mx.bool_, + ) + mask = mx.concatenate([sink_mask, mask], axis=-1) + # Optional soft cap on logits (e.g., Gemma 2 uses cap=50.0). if self._attention_logits_soft_cap is not None: cap = self._attention_logits_soft_cap @@ -186,7 +302,9 @@ def _compute_attention(self, queries, keys, values, mask): large_neg = mx.array(-1e9, dtype=logits.dtype) logits = mx.where(mask, logits, large_neg) - weights = mx.softmax(logits, axis=-1) + # Run softmax in at least float32 to match JAX precision. 
+ logits_f32 = logits.astype(mx.float32) if logits.dtype != mx.float32 else logits + weights = mx.softmax(logits_f32, axis=-1).astype(v.dtype) context = mx.matmul(weights, v) # [b, nh, qt, h] # Transpose back to [b, qt, nh, h]. @@ -430,10 +548,10 @@ class DeferredDotProductSelfAttention(types.Emitting): def __init__(self, config): super().__init__() self._config = config - self._inner = None + self.inner = None def _ensure_initialized(self, in_features, backend='mlx'): - if self._inner is not None: + if self.inner is not None: return # Build optional Q/K/V networks. @@ -451,7 +569,7 @@ def _ensure_initialized(self, in_features, backend='mlx'): if compute_dtype is not None: compute_dtype = init_mapping._to_mx_dtype(compute_dtype) param_dtype = init_mapping._to_mx_dtype(self._config.param_dtype) - self._inner = DotProductSelfAttention( + self.inner = DotProductSelfAttention( in_features=in_features, num_heads=self._config.num_heads, units_per_head=self._config.units_per_head, @@ -460,6 +578,7 @@ def _ensure_initialized(self, in_features, backend='mlx'): num_kv_heads=self._config.num_kv_heads, use_bias=self._config.use_bias, query_scale=getattr(self._config, 'query_scale', None), + per_dim_scale=getattr(self._config, 'per_dim_scale', False), compute_dtype=compute_dtype, param_dtype=param_dtype, kernel_init=init_mapping.map_initializer( @@ -473,6 +592,7 @@ def _ensure_initialized(self, in_features, backend='mlx'): query_network=query_network, key_network=key_network, value_network=value_network, + num_sink_embeddings=getattr(self._config, 'num_sink_embeddings', 0), ) @property @@ -498,36 +618,54 @@ def get_output_dtype(self, input_dtype, *, constants=None): def get_initial_state(self, batch_size, input_spec, *, constants=None): self._ensure_initialized(input_spec.shape[-1]) - return self._inner.get_initial_state( + return self.inner.get_initial_state( batch_size, input_spec, constants=constants ) def layer_with_emits(self, x, *, constants=None): 
self._ensure_initialized(x.shape[-1]) - return self._inner.layer_with_emits(x, constants=constants) + return self.inner.layer_with_emits(x, constants=constants) def step_with_emits(self, x, state, *, constants=None): self._ensure_initialized(x.shape[-1]) - return self._inner.step_with_emits(x, state, constants=constants) + return self.inner.step_with_emits(x, state, constants=constants) class DotProductAttention(types.Emitting): - """Multi-headed dot-product cross attention for MLX. + """Multi-headed dot-product cross attention for MLX.""" + + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + """MLX-native configuration for DotProductAttention.""" + + source_name: str + num_heads: int + units_per_head: int + attention_probabilities_dropout_rate: float = 0.0 + broadcast_dropout_across_queries: bool = False + use_bias: bool = False + input_projection: ( + projection_configs.QueryAndKeyValueProjection + | projection_configs.SeparateQueryKeyValueProjection + | projection_configs.QueryAndSharedKeyValueProjection + ) = dataclasses.field( + default_factory=projection_configs.QueryAndKeyValueProjection + ) + query_network: _SequenceLayerConfig | None = None + key_network: _SequenceLayerConfig | None = None + value_network: _SequenceLayerConfig | None = None + attention_logits_soft_cap: float | None = None + per_dim_scale: bool = False + query_scale: float | None = None + zero_fully_masked: bool = False + compute_dtype: types.DType | None = None + param_dtype: types.DType = mx.float32 + name: str | None = None - Queries come from the input sequence; keys and values come from a - source sequence looked up in the ``constants`` dictionary. + def make(self) -> 'DotProductAttention': + return DotProductAttention.from_config(self) - In ``layer()`` mode the K/V projections and optional networks are - applied to the source on-the-fly. 
In ``step()`` mode they are - pre-computed during ``get_initial_state()`` so that each step only - needs to project and attend queries. - Kernels are stored in Linen-compatible shapes: - q_proj: [in_features, num_heads * units_per_head] - k_proj: [source_features, num_heads * units_per_head] - v_proj: [source_features, num_heads * units_per_head] - out_proj: [num_heads * units_per_head, in_features] - """ def __init__( self, @@ -539,6 +677,7 @@ def __init__( units_per_head: int, use_bias: bool = False, query_scale: float | None = None, + per_dim_scale: bool = False, compute_dtype=None, param_dtype=mx.float32, kernel_init=None, @@ -557,6 +696,11 @@ def __init__( self._query_scale = query_scale self.compute_dtype = compute_dtype self._param_dtype = param_dtype + self._per_dim_scale = ( + mx.zeros((units_per_head,), dtype=param_dtype) + if per_dim_scale + else None + ) if kernel_init is None: kernel_init = init_mapping._make_variance_scaling_init( @@ -620,13 +764,13 @@ def _get_source(self, constants): def _compute_attention(self, queries, keys, values, mask): """Compute scaled dot-product attention (no causal mask).""" - scale = self._query_scale or (1.0 / math.sqrt(self.units_per_head)) - q = mx.transpose(queries, (0, 2, 1, 3)) k = mx.transpose(keys, (0, 2, 1, 3)) v = mx.transpose(values, (0, 2, 1, 3)) - q = q * scale + q = _scale_queries( + q, self._per_dim_scale, self._query_scale, self.units_per_head + ) logits = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) if mask is not None: @@ -743,10 +887,10 @@ class DeferredDotProductAttention(types.Emitting): def __init__(self, config): super().__init__() self._config = config - self._inner = None + self.inner = None def _ensure_initialized(self, in_features, source_features, backend='mlx'): - if self._inner is not None: + if self.inner is not None: return query_network = None @@ -764,7 +908,7 @@ def _ensure_initialized(self, in_features, source_features, backend='mlx'): compute_dtype = 
init_mapping._to_mx_dtype(compute_dtype) param_dtype = init_mapping._to_mx_dtype(self._config.param_dtype) - self._inner = DotProductAttention( + self.inner = DotProductAttention( in_features=in_features, source_features=source_features, source_name=self._config.source_name, @@ -772,6 +916,7 @@ def _ensure_initialized(self, in_features, source_features, backend='mlx'): units_per_head=self._config.units_per_head, use_bias=self._config.use_bias, query_scale=getattr(self._config, 'query_scale', None), + per_dim_scale=getattr(self._config, 'per_dim_scale', False), compute_dtype=compute_dtype, param_dtype=param_dtype, kernel_init=init_mapping.map_initializer( @@ -819,19 +964,19 @@ def get_output_dtype(self, input_dtype, *, constants=None): def get_initial_state(self, batch_size, input_spec, *, constants=None): source = self._get_source(constants) self._ensure_initialized(input_spec.shape[-1], source.shape[-1]) - return self._inner.get_initial_state( + return self.inner.get_initial_state( batch_size, input_spec, constants=constants ) def layer_with_emits(self, x, *, constants=None): source = self._get_source(constants) self._ensure_initialized(x.shape[-1], source.shape[-1]) - return self._inner.layer_with_emits(x, constants=constants) + return self.inner.layer_with_emits(x, constants=constants) def step_with_emits(self, x, state, *, constants=None): source = self._get_source(constants) self._ensure_initialized(x.shape[-1], source.shape[-1]) - return self._inner.step_with_emits(x, state, constants=constants) + return self.inner.step_with_emits(x, state, constants=constants) def _banded_mask(q_len, kv_len, num_lower, num_upper): @@ -869,6 +1014,8 @@ def _step_visibility_mask( class StreamingDotProductAttention(types.Emitting): """Multi-headed streaming cross-attention for MLX. + Also covers StreamingLocalDotProductAttention from the JAX backend. 
+ Queries come from the input; keys and values come from a source sequence provided in constants at the same streaming rate as input. @@ -886,6 +1033,48 @@ class StreamingDotProductAttention(types.Emitting): v_proj: [source_features, num_heads * units_per_head] """ + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + """MLX-native configuration for StreamingDotProductAttention. + + This Config also serves as the MLX-native equivalent of the JAX + StreamingLocalDotProductAttention.Config. + """ + + source_name: str + num_heads: int + units_per_head: int + block_size: int = 1 + max_past_horizon: int = 1 + max_future_horizon: int = 0 + attention_probabilities_dropout_rate: float = 0.0 + broadcast_dropout_across_queries: bool = False + use_bias: bool = False + use_query_delay_buffer: bool = True + input_projection: ( + projection_configs.QueryAndKeyValueProjection + | projection_configs.SeparateQueryKeyValueProjection + | projection_configs.QueryAndSharedKeyValueProjection + ) = dataclasses.field( + default_factory=projection_configs.QueryAndKeyValueProjection + ) + query_network: _SequenceLayerConfig | None = None + key_network: _SequenceLayerConfig | None = None + value_network: _SequenceLayerConfig | None = None + attention_logits_soft_cap: float | None = None + per_dim_scale: bool = False + query_scale: float | None = None + zero_fully_masked: bool = False + compute_dtype: types.DType | None = None + param_dtype: types.DType = mx.float32 + num_sink_embeddings: int = 0 + use_sink_scalars: bool = False + use_kv_cache_ringbuffer: bool = False + name: str | None = None + + def make(self) -> 'StreamingDotProductAttention': + return StreamingDotProductAttention.from_config(self) + def __init__( self, *, @@ -899,6 +1088,7 @@ def __init__( use_bias: bool = False, use_query_delay_buffer: bool = True, query_scale: float | None = None, + per_dim_scale: bool = False, compute_dtype=None, param_dtype=mx.float32, kernel_init=None, @@ -906,6 +1096,7 @@ def 
__init__( query_network: types.SequenceLayer | None = None, key_network: types.SequenceLayer | None = None, value_network: types.SequenceLayer | None = None, + num_sink_embeddings: int = 0, ): super().__init__() if max_past_horizon < 1: @@ -929,6 +1120,11 @@ def __init__( self._query_scale = query_scale self.compute_dtype = compute_dtype self._param_dtype = param_dtype + self._per_dim_scale = ( + mx.zeros((units_per_head,), dtype=param_dtype) + if per_dim_scale + else None + ) if kernel_init is None: kernel_init = init_mapping._make_variance_scaling_init( @@ -949,6 +1145,18 @@ def __init__( self.q_bias = bias_init(key, (qkv_dim,), param_dtype) self.k_bias = bias_init(key, (qkv_dim,), param_dtype) self.v_bias = bias_init(key, (qkv_dim,), param_dtype) + # Attention sink embeddings. + self.num_sink_embeddings = num_sink_embeddings + if num_sink_embeddings > 0: + self.sink_key_embeddings = mx.zeros( + (num_sink_embeddings, num_heads, units_per_head), dtype=param_dtype + ) + self.sink_value_embeddings = mx.zeros( + (num_sink_embeddings, num_heads, units_per_head), dtype=param_dtype + ) + else: + self.sink_key_embeddings = None + self.sink_value_embeddings = None self.query_network = query_network self.key_network = key_network @@ -996,16 +1204,45 @@ def _get_source(self, constants): def _compute_attention(self, queries, keys, values, mask): """Compute scaled dot-product attention.""" - scale = self._query_scale or (1.0 / math.sqrt(self.units_per_head)) q = mx.transpose(queries, (0, 2, 1, 3)) k = mx.transpose(keys, (0, 2, 1, 3)) v = mx.transpose(values, (0, 2, 1, 3)) - q = q * scale + + # Compute sink logits BEFORE scaling queries, matching JAX behavior. 
+ if self.sink_key_embeddings is not None: + sink_k = self.sink_key_embeddings.astype(q.dtype) # [K, nh, h] + sink_k_t = mx.transpose(sink_k, (1, 2, 0)) # [nh, h, K] + sink_logits = mx.matmul(q, sink_k_t) # [b, nh, qt, K] + + q = _scale_queries( + q, self._per_dim_scale, self._query_scale, self.units_per_head + ) logits = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) + + # Add attention sink logits if present. + if self.sink_key_embeddings is not None: + sink_v = self.sink_value_embeddings.astype(v.dtype) # [K, nh, h] + sink_v_t = mx.transpose(sink_v, (1, 0, 2)) # [nh, K, h] + sink_v_b = mx.broadcast_to( + sink_v_t[None], (v.shape[0],) + sink_v_t.shape + ) # [b, nh, K, h] + v = mx.concatenate([sink_v_b, v], axis=2) + logits = mx.concatenate([sink_logits, logits], axis=-1) + + if mask is not None: + num_sinks = self.sink_key_embeddings.shape[0] + sink_mask = mx.ones( + (mask.shape[0], mask.shape[1], mask.shape[2], num_sinks), + dtype=mx.bool_, + ) + mask = mx.concatenate([sink_mask, mask], axis=-1) + if mask is not None: large_neg = mx.array(-1e9, dtype=logits.dtype) logits = mx.where(mask, logits, large_neg) - weights = mx.softmax(logits, axis=-1) + # Run softmax in at least float32 to match JAX precision. 
+ logits_f32 = logits.astype(mx.float32) if logits.dtype != mx.float32 else logits + weights = mx.softmax(logits_f32, axis=-1).astype(v.dtype) context = mx.matmul(weights, v) context = mx.transpose(context, (0, 2, 1, 3)) return context @@ -1258,10 +1495,10 @@ class DeferredStreamingDotProductAttention(types.Emitting): def __init__(self, config): super().__init__() self._config = config - self._inner = None + self.inner = None def _ensure_initialized(self, in_features, source_features, backend='mlx'): - if self._inner is not None: + if self.inner is not None: return query_network = None @@ -1279,7 +1516,7 @@ def _ensure_initialized(self, in_features, source_features, backend='mlx'): compute_dtype = init_mapping._to_mx_dtype(compute_dtype) param_dtype = init_mapping._to_mx_dtype(self._config.param_dtype) - self._inner = StreamingDotProductAttention( + self.inner = StreamingDotProductAttention( in_features=in_features, source_features=source_features, source_name=self._config.source_name, @@ -1292,6 +1529,7 @@ def _ensure_initialized(self, in_features, source_features, backend='mlx'): self._config, 'use_query_delay_buffer', True ), query_scale=getattr(self._config, 'query_scale', None), + per_dim_scale=getattr(self._config, 'per_dim_scale', False), compute_dtype=compute_dtype, param_dtype=param_dtype, kernel_init=init_mapping.map_initializer( @@ -1305,6 +1543,7 @@ def _ensure_initialized(self, in_features, source_features, backend='mlx'): query_network=query_network, key_network=key_network, value_network=value_network, + num_sink_embeddings=getattr(self._config, 'num_sink_embeddings', 0), ) def _get_source(self, constants): @@ -1340,29 +1579,57 @@ def get_output_dtype(self, input_dtype, *, constants=None): def get_initial_state(self, batch_size, input_spec, *, constants=None): source = self._get_source(constants) self._ensure_initialized(input_spec.shape[-1], source.shape[-1]) - return self._inner.get_initial_state( + return self.inner.get_initial_state( batch_size, 
input_spec, constants=constants ) def layer_with_emits(self, x, *, constants=None): source = self._get_source(constants) self._ensure_initialized(x.shape[-1], source.shape[-1]) - return self._inner.layer_with_emits(x, constants=constants) + return self.inner.layer_with_emits(x, constants=constants) def step_with_emits(self, x, state, *, constants=None): source = self._get_source(constants) self._ensure_initialized(x.shape[-1], source.shape[-1]) - return self._inner.step_with_emits(x, state, constants=constants) + return self.inner.step_with_emits(x, state, constants=constants) class LocalDotProductSelfAttention(DotProductSelfAttention): - """Local dot-product self attention with configurable block_size. - - Extends DotProductSelfAttention with a configurable block_size for - step-mode processing. The sliding window behavior is already handled - by the base class's banded visibility mask via max_past_horizon and - max_future_horizon. - """ + """Local dot-product self attention with configurable block_size.""" + + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + """MLX-native configuration for LocalDotProductSelfAttention.""" + + num_heads: int + units_per_head: int + block_size: int + max_past_horizon: int + max_future_horizon: int = 0 + attention_probabilities_dropout_rate: float = 0.0 + broadcast_dropout_across_queries: bool = False + use_bias: bool = False + input_projection: projection_configs.QueryKeyValueProjectionConfig = ( + dataclasses.field( + default_factory=projection_configs.CombinedQueryKeyValueProjection + ) + ) + query_network: _SequenceLayerConfig | None = None + key_network: _SequenceLayerConfig | None = None + value_network: _SequenceLayerConfig | None = None + attention_logits_soft_cap: float | None = None + per_dim_scale: bool = False + query_scale: float | None = None + zero_fully_masked: bool = False + compute_dtype: types.DType | None = None + param_dtype: types.DType = mx.float32 + num_sink_embeddings: int = 0 + 
use_sink_scalars: bool = False + use_kv_cache_ringbuffer: bool = False + name: str | None = None + + def make(self) -> 'LocalDotProductSelfAttention': + return LocalDotProductSelfAttention.from_config(self) def __init__(self, *, block_size_config: int = 1, **kwargs): super().__init__(**kwargs) @@ -1386,10 +1653,10 @@ class DeferredLocalDotProductSelfAttention(types.Emitting): def __init__(self, config): super().__init__() self._config = config - self._inner = None + self.inner = None def _ensure_initialized(self, in_features, backend='mlx'): - if self._inner is not None: + if self.inner is not None: return query_network = None @@ -1407,7 +1674,7 @@ def _ensure_initialized(self, in_features, backend='mlx'): compute_dtype = init_mapping._to_mx_dtype(compute_dtype) param_dtype = init_mapping._to_mx_dtype(self._config.param_dtype) - self._inner = LocalDotProductSelfAttention( + self.inner = LocalDotProductSelfAttention( in_features=in_features, num_heads=self._config.num_heads, units_per_head=self._config.units_per_head, @@ -1416,6 +1683,7 @@ def _ensure_initialized(self, in_features, backend='mlx'): use_bias=self._config.use_bias, block_size_config=self._config.block_size, query_scale=getattr(self._config, 'query_scale', None), + per_dim_scale=getattr(self._config, 'per_dim_scale', False), compute_dtype=compute_dtype, param_dtype=param_dtype, attention_logits_soft_cap=getattr( @@ -1432,6 +1700,7 @@ def _ensure_initialized(self, in_features, backend='mlx'): query_network=query_network, key_network=key_network, value_network=value_network, + num_sink_embeddings=getattr(self._config, 'num_sink_embeddings', 0), ) @property @@ -1461,14 +1730,14 @@ def get_output_dtype(self, input_dtype, *, constants=None): def get_initial_state(self, batch_size, input_spec, *, constants=None): self._ensure_initialized(input_spec.shape[-1]) - return self._inner.get_initial_state( + return self.inner.get_initial_state( batch_size, input_spec, constants=constants ) def layer_with_emits(self, 
x, *, constants=None): self._ensure_initialized(x.shape[-1]) - return self._inner.layer_with_emits(x, constants=constants) + return self.inner.layer_with_emits(x, constants=constants) def step_with_emits(self, x, state, *, constants=None): self._ensure_initialized(x.shape[-1]) - return self._inner.step_with_emits(x, state, constants=constants) + return self.inner.step_with_emits(x, state, constants=constants) diff --git a/sequence_layers/mlx/attention_test.py b/sequence_layers/mlx/attention_test.py index 7c9dea7..e274e32 100644 --- a/sequence_layers/mlx/attention_test.py +++ b/sequence_layers/mlx/attention_test.py @@ -92,6 +92,60 @@ def test_with_query_key_networks(self): self.assertEqual(y.shape, (1, 5, 2, 4)) + def test_per_dim_scale(self): + """Test per_dim_scale creates parameter and affects output.""" + layer = attention.DotProductSelfAttention( + in_features=8, + num_heads=2, + units_per_head=4, + max_past_horizon=32, + per_dim_scale=True, + ) + self.assertIsNotNone(layer._per_dim_scale) + self.assertEqual(layer._per_dim_scale.shape, (4,)) + np.testing.assert_array_equal(layer._per_dim_scale, np.zeros(4)) + + # At initialization (zeros), output should match per_dim_scale=False. + layer_no_pds = attention.DotProductSelfAttention( + in_features=8, + num_heads=2, + units_per_head=4, + max_past_horizon=32, + per_dim_scale=False, + ) + # Copy weights so projections match. + layer_no_pds.q_proj = layer.q_proj + layer_no_pds.k_proj = layer.k_proj + layer_no_pds.v_proj = layer.v_proj + + x = test_utils.random_sequence(1, 5, 8) + y_pds = layer.layer(x) + y_no_pds = layer_no_pds.layer(x) + np.testing.assert_allclose( + np.array(y_pds.values), np.array(y_no_pds.values), atol=1e-5 + ) + + # After modifying per_dim_scale, output should differ. 
+ layer._per_dim_scale = mx.ones((4,)) + y_modified = layer.layer(x) + self.assertFalse( + np.allclose( + np.array(y_pds.values), np.array(y_modified.values), atol=1e-5 + ) + ) + + def test_per_dim_scale_step(self): + """Test per_dim_scale works in step mode.""" + layer = attention.DotProductSelfAttention( + in_features=8, + num_heads=2, + units_per_head=4, + max_past_horizon=10, + per_dim_scale=True, + ) + test_utils.verify_contract(self, layer, (8,), atol=1e-4, rtol=1e-4) + + class DeferredDotProductSelfAttentionTest(parameterized.TestCase): def test_from_config(self): diff --git a/sequence_layers/mlx/combinators.py b/sequence_layers/mlx/combinators.py index 8b7296e..f26eb74 100644 --- a/sequence_layers/mlx/combinators.py +++ b/sequence_layers/mlx/combinators.py @@ -1,5 +1,6 @@ """Combinators (Serial, Residual, Repeat, Parallel) for MLX.""" +import dataclasses import enum from functools import reduce from math import lcm @@ -9,6 +10,7 @@ from sequence_layers.mlx import basic_types as bt from sequence_layers.mlx import simple as simple_lib from sequence_layers.mlx import types +from sequence_layers.jax.types import SequenceLayerConfig as _SequenceLayerConfig Sequence = bt.Sequence @@ -92,9 +94,105 @@ def _combine_sequences(mode, sequences): return Sequence(values, mask) +class SerialCombinatorMixin: + """Mixin for Serial logic. + + Provides serial processing (layer, step, initial state) for classes that + define a ``layers`` attribute containing a sequence of SequenceLayers. 
+ """ + + layers: list[types.SequenceLayer] + + @property + def supports_step(self): + return all(l.supports_step for l in self.layers) + + @property + def block_size(self): + return reduce(lcm, (l.block_size for l in self.layers), 1) + + @property + def output_ratio(self): + r = self.layers[0].output_ratio if self.layers else 1 + for l in self.layers[1:]: + r = r * l.output_ratio + return r + + @property + def input_latency(self): + latency = 0 + for l in self.layers: + latency = l.get_accumulated_input_latency(latency) + return latency + + @property + def output_latency(self): + return int(self.input_latency * self.output_ratio) + + def get_output_shape(self, input_shape, *, constants=None): + shape = input_shape + for l in self.layers: + shape = l.get_output_shape(shape, constants=constants) + return shape + + def get_output_dtype(self, input_dtype, *, constants=None): + dtype = input_dtype + for l in self.layers: + dtype = l.get_output_dtype(dtype, constants=constants) + return dtype + + def get_initial_state(self, batch_size, input_spec, *, constants=None, **kwargs): + spec = input_spec + states = [] + for l in self.layers: + states.append(l.get_initial_state(batch_size, spec, constants=constants)) + spec = l.get_output_spec(spec, constants=constants) + return tuple(states) + + def layer_with_emits(self, x, *, constants=None, **kwargs): + emits = {} + for i, l in enumerate(self.layers): + x, e = l.layer_with_emits(x, constants=constants) + emits[f'layer_{i}'] = e + return x, emits + + def step_with_emits(self, x, state, *, constants=None, **kwargs): + new_state = [] + emits = {} + for i, (l, s) in enumerate(zip(self.layers, state)): + x, s, e = l.step_with_emits(x, s, constants=constants) + new_state.append(s) + emits[f'layer_{i}'] = e + return x, tuple(new_state), emits + + +class SerialModules(SerialCombinatorMixin, types.Emitting): + """A Serial combinator that wraps pre-existing SequenceLayers. 
+ + Unlike Serial (which owns its layers), SerialModules references + pre-constructed modules parented elsewhere. This avoids duplication + when a module graph shares sub-layers across different combinators. + """ + + def __init__(self, layers): + super().__init__() + self.layers = list(layers) + + class Serial(types.Emitting): """Processes SequenceLayers serially.""" + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + layers: tuple[_SequenceLayerConfig, ...] = () + name: str | None = None + + def __post_init__(self): + object.__setattr__(self, 'layers', tuple(self.layers)) + + def make(self, backend='mlx') -> 'Serial': + return Serial.from_config(self, backend=backend) + def __init__(self, layers: list[types.SequenceLayer]): super().__init__() self.layers = list(layers) @@ -136,7 +234,7 @@ def get_output_dtype(self, input_dtype, *, constants=None): dtype = l.get_output_dtype(dtype, constants=constants) return dtype - def get_initial_state(self, batch_size, input_spec, *, constants=None): + def get_initial_state(self, batch_size, input_spec, *, constants=None, **kwargs): spec = input_spec states = [] for l in self.layers: @@ -144,14 +242,14 @@ def get_initial_state(self, batch_size, input_spec, *, constants=None): spec = l.get_output_spec(spec, constants=constants) return tuple(states) - def layer_with_emits(self, x, *, constants=None): + def layer_with_emits(self, x, *, constants=None, **kwargs): emits = {} for i, l in enumerate(self.layers): x, e = l.layer_with_emits(x, constants=constants) emits[f'layer_{i}'] = e return x, emits - def step_with_emits(self, x, state, *, constants=None): + def step_with_emits(self, x, state, *, constants=None, **kwargs): new_state = [] emits = {} for i, (l, s) in enumerate(zip(self.layers, state)): @@ -169,6 +267,20 @@ def from_config(cls, config, backend='mlx'): class Residual(types.Emitting): """Residual wrapper: y = body(x) + shortcut(x).""" + @dataclasses.dataclass(frozen=True) + class 
Config(_SequenceLayerConfig): + layers: tuple[_SequenceLayerConfig, ...] = () + shortcut_layers: tuple[_SequenceLayerConfig, ...] | None = None + name: str | None = None + + def __post_init__(self): + object.__setattr__(self, 'layers', tuple(self.layers)) + if self.shortcut_layers is not None: + object.__setattr__(self, 'shortcut_layers', tuple(self.shortcut_layers)) + + def make(self, backend='mlx') -> 'Residual': + return Residual.from_config(self, backend=backend) + def __init__( self, layers: list[types.SequenceLayer], @@ -176,38 +288,38 @@ def __init__( shortcut: types.SequenceLayer | None = None, ): super().__init__() - self._body = Serial(layers) - self._shortcut = shortcut if shortcut is not None else simple_lib.Identity() + self.body = Serial(layers) + self.shortcut = shortcut if shortcut is not None else simple_lib.Identity() @property def supports_step(self): - return self._body.supports_step and self._shortcut.supports_step + return self.body.supports_step and self.shortcut.supports_step @property def block_size(self): from math import lcm - return lcm(self._body.block_size, self._shortcut.block_size) + return lcm(self.body.block_size, self.shortcut.block_size) @property def output_ratio(self): - return self._body.output_ratio + return self.body.output_ratio @property def input_latency(self): - return self._body.input_latency + return self.body.input_latency def get_output_shape(self, input_shape, *, constants=None): - return self._body.get_output_shape(input_shape, constants=constants) + return self.body.get_output_shape(input_shape, constants=constants) def get_output_dtype(self, input_dtype, *, constants=None): - return self._body.get_output_dtype(input_dtype, constants=constants) + return self.body.get_output_dtype(input_dtype, constants=constants) - def get_initial_state(self, batch_size, input_spec, *, constants=None): - body_state = self._body.get_initial_state( + def get_initial_state(self, batch_size, input_spec, *, constants=None, **kwargs): + 
body_state = self.body.get_initial_state( batch_size, input_spec, constants=constants ) - shortcut_state = self._shortcut.get_initial_state( + shortcut_state = self.shortcut.get_initial_state( batch_size, input_spec, constants=constants ) return (body_state, shortcut_state) @@ -217,20 +329,20 @@ def _residual_fn(self, y_body, y_shortcut): y_mask = y_body.mask & y_shortcut.mask return Sequence(y_values, y_mask) - def layer_with_emits(self, x, *, constants=None): - y_body, body_emits = self._body.layer_with_emits(x, constants=constants) - y_shortcut, shortcut_emits = self._shortcut.layer_with_emits( + def layer_with_emits(self, x, *, constants=None, **kwargs): + y_body, body_emits = self.body.layer_with_emits(x, constants=constants) + y_shortcut, shortcut_emits = self.shortcut.layer_with_emits( x, constants=constants ) y = self._residual_fn(y_body, y_shortcut) return y, (body_emits, shortcut_emits) - def step_with_emits(self, x, state, *, constants=None): + def step_with_emits(self, x, state, *, constants=None, **kwargs): body_state, shortcut_state = state - y_body, body_state, body_emits = self._body.step_with_emits( + y_body, body_state, body_emits = self.body.step_with_emits( x, body_state, constants=constants ) - y_shortcut, shortcut_state, shortcut_emits = self._shortcut.step_with_emits( + y_shortcut, shortcut_state, shortcut_emits = self.shortcut.step_with_emits( x, shortcut_state, constants=constants ) y = self._residual_fn(y_body, y_shortcut) @@ -298,7 +410,7 @@ def get_output_shape(self, input_shape, *, constants=None): def get_output_dtype(self, input_dtype, *, constants=None): return self.layers[0].get_output_dtype(input_dtype, constants=constants) - def get_initial_state(self, batch_size, input_spec, *, constants=None): + def get_initial_state(self, batch_size, input_spec, *, constants=None, **kwargs): states = [] spec = input_spec for l in self.layers: @@ -306,14 +418,14 @@ def get_initial_state(self, batch_size, input_spec, *, constants=None): # All 
repeats have the same output spec. return tuple(states) - def layer_with_emits(self, x, *, constants=None): + def layer_with_emits(self, x, *, constants=None, **kwargs): emits = {} for i, l in enumerate(self.layers): x, e = l.layer_with_emits(x, constants=constants) emits[f'repeat_{i}'] = e return x, emits - def step_with_emits(self, x, state, *, constants=None): + def step_with_emits(self, x, state, *, constants=None, **kwargs): new_state = [] emits = {} for i, (l, s) in enumerate(zip(self.layers, state)): @@ -386,7 +498,7 @@ def get_output_shape(self, input_shape, *, constants=None): def get_output_dtype(self, input_dtype, *, constants=None): return self.layers[0].get_output_dtype(input_dtype, constants=constants) - def get_initial_state(self, batch_size, input_spec, *, constants=None): + def get_initial_state(self, batch_size, input_spec, *, constants=None, **kwargs): states = [] for l in self.layers: states.append( @@ -394,7 +506,7 @@ def get_initial_state(self, batch_size, input_spec, *, constants=None): ) return tuple(states) - def layer_with_emits(self, x, *, constants=None): + def layer_with_emits(self, x, *, constants=None, **kwargs): outputs = [] emits = {} for i, l in enumerate(self.layers): @@ -404,7 +516,7 @@ def layer_with_emits(self, x, *, constants=None): combined = _combine_sequences(self.combination, outputs) return combined, emits - def step_with_emits(self, x, state, *, constants=None): + def step_with_emits(self, x, state, *, constants=None, **kwargs): outputs = [] new_state = [] emits = {} diff --git a/sequence_layers/mlx/conditioning.py b/sequence_layers/mlx/conditioning.py index de0ec69..bbfa34f 100644 --- a/sequence_layers/mlx/conditioning.py +++ b/sequence_layers/mlx/conditioning.py @@ -1,5 +1,6 @@ """Conditioning layers for MLX.""" +import dataclasses import enum import math @@ -11,6 +12,7 @@ from sequence_layers.mlx import init_mapping from sequence_layers.mlx.init_mapping import _to_mx_dtype from sequence_layers.mlx import types +from 
sequence_layers.jax.types import SequenceLayerConfig as _SequenceLayerConfig Sequence = bt.Sequence MaskedSequence = bt.MaskedSequence @@ -167,6 +169,21 @@ class Combination(enum.Enum): MUL = 6 CONCAT_BEFORE = 7 + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + conditioning_name: str = '' + projection: 'Conditioning.Projection' = None + combination: 'Conditioning.Combination' = None + projection_channel_shape: tuple[int, ...] | None = None + streaming: bool = False + affine_scale_offset: complex = 1.0 + compute_dtype: types.DType | None = None + param_dtype: types.DType = mx.float32 + name: str | None = None + + def make(self) -> 'Conditioning': + return Conditioning.from_config(self) + def __init__( self, *, diff --git a/sequence_layers/mlx/convolution.py b/sequence_layers/mlx/convolution.py index e176d65..2383bbd 100644 --- a/sequence_layers/mlx/convolution.py +++ b/sequence_layers/mlx/convolution.py @@ -877,10 +877,10 @@ class DeferredConv1D(types.SequenceLayer): def __init__(self, config): super().__init__() self._config = config - self._inner = None + self.inner = None def _ensure_initialized(self, in_features): - if self._inner is not None: + if self.inner is not None: return c = self._config compute_dtype = getattr(c, 'compute_dtype', None) @@ -888,7 +888,7 @@ def _ensure_initialized(self, in_features): compute_dtype = init_mapping._to_mx_dtype(compute_dtype) param_dtype = init_mapping._to_mx_dtype(c.param_dtype) activation = init_mapping.map_activation(getattr(c, 'activation', None)) - self._inner = Conv1D( + self.inner = Conv1D( in_features=in_features, filters=c.filters, kernel_size=c.kernel_size, @@ -944,17 +944,17 @@ def get_output_dtype(self, input_dtype, *, constants=None): def get_initial_state(self, batch_size, input_spec, *, constants=None): self._ensure_initialized(input_spec.shape[-1]) - return self._inner.get_initial_state( + return self.inner.get_initial_state( batch_size, input_spec, constants=constants ) def 
layer(self, x, *, constants=None): self._ensure_initialized(x.shape[-1]) - return self._inner.layer(x, constants=constants) + return self.inner.layer(x, constants=constants) def step(self, x, state, *, constants=None): self._ensure_initialized(x.shape[-1]) - return self._inner.step(x, state, constants=constants) + return self.inner.step(x, state, constants=constants) class DeferredDepthwiseConv1D(types.SequenceLayer): @@ -963,10 +963,10 @@ class DeferredDepthwiseConv1D(types.SequenceLayer): def __init__(self, config): super().__init__() self._config = config - self._inner = None + self.inner = None def _ensure_initialized(self, in_features): - if self._inner is not None: + if self.inner is not None: return c = self._config compute_dtype = getattr(c, 'compute_dtype', None) @@ -974,7 +974,7 @@ def _ensure_initialized(self, in_features): compute_dtype = init_mapping._to_mx_dtype(compute_dtype) param_dtype = init_mapping._to_mx_dtype(c.param_dtype) activation = init_mapping.map_activation(getattr(c, 'activation', None)) - self._inner = DepthwiseConv1D( + self.inner = DepthwiseConv1D( in_features=in_features, kernel_size=c.kernel_size, depth_multiplier=c.depth_multiplier, @@ -1029,17 +1029,17 @@ def get_output_dtype(self, input_dtype, *, constants=None): def get_initial_state(self, batch_size, input_spec, *, constants=None): self._ensure_initialized(input_spec.shape[-1]) - return self._inner.get_initial_state( + return self.inner.get_initial_state( batch_size, input_spec, constants=constants ) def layer(self, x, *, constants=None): self._ensure_initialized(x.shape[-1]) - return self._inner.layer(x, constants=constants) + return self.inner.layer(x, constants=constants) def step(self, x, state, *, constants=None): self._ensure_initialized(x.shape[-1]) - return self._inner.step(x, state, constants=constants) + return self.inner.step(x, state, constants=constants) class DeferredConv1DTranspose(types.SequenceLayer): @@ -1048,10 +1048,10 @@ class 
DeferredConv1DTranspose(types.SequenceLayer): def __init__(self, config): super().__init__() self._config = config - self._inner = None + self.inner = None def _ensure_initialized(self, in_features): - if self._inner is not None: + if self.inner is not None: return c = self._config compute_dtype = getattr(c, 'compute_dtype', None) @@ -1059,7 +1059,7 @@ def _ensure_initialized(self, in_features): compute_dtype = init_mapping._to_mx_dtype(compute_dtype) param_dtype = init_mapping._to_mx_dtype(c.param_dtype) activation = init_mapping.map_activation(getattr(c, 'activation', None)) - self._inner = Conv1DTranspose( + self.inner = Conv1DTranspose( in_features=in_features, filters=c.filters, kernel_size=c.kernel_size, @@ -1100,14 +1100,14 @@ def get_output_dtype(self, input_dtype, *, constants=None): def get_initial_state(self, batch_size, input_spec, *, constants=None): self._ensure_initialized(input_spec.shape[-1]) - return self._inner.get_initial_state( + return self.inner.get_initial_state( batch_size, input_spec, constants=constants ) def layer(self, x, *, constants=None): self._ensure_initialized(x.shape[-1]) - return self._inner.layer(x, constants=constants) + return self.inner.layer(x, constants=constants) def step(self, x, state, *, constants=None): self._ensure_initialized(x.shape[-1]) - return self._inner.step(x, state, constants=constants) + return self.inner.step(x, state, constants=constants) diff --git a/sequence_layers/mlx/convolution2d.py b/sequence_layers/mlx/convolution2d.py new file mode 100644 index 0000000..59f7c98 --- /dev/null +++ b/sequence_layers/mlx/convolution2d.py @@ -0,0 +1,1143 @@ +"""2D Convolution, transpose convolution, pooling, and upsampling layers for MLX.""" + +import dataclasses +import fractions +import math + +import mlx.core as mx +import mlx.nn as nn + +from sequence_layers.mlx import basic_types as bt +from sequence_layers.mlx import convolution as conv_utils +from sequence_layers.mlx import init_mapping +from sequence_layers.mlx 
import types +from sequence_layers.jax.types import SequenceLayerConfig as _SequenceLayerConfig + +Sequence = bt.Sequence +MaskedSequence = bt.MaskedSequence +PaddingMode = bt.PaddingMode + + +def _normalize_2tuple(x): + """Normalizes an int or sequence to a 2-tuple.""" + if isinstance(x, int): + return (x, x) + return tuple(x) + + +def _explicit_padding_2d(padding, kernel_size, stride, dilation_rate): + """Returns ((pad_time_left, pad_time_right), (pad_spatial_left, pad_spatial_right)).""" + time_pad = conv_utils._explicit_padding( + padding[0] if isinstance(padding, (list, tuple)) else padding, + kernel_size[0], stride[0], dilation_rate[0], + ) + spatial_padding = padding[1] if isinstance(padding, (list, tuple)) else padding + spatial_pad = conv_utils._explicit_padding( + spatial_padding, kernel_size[1], stride[1], dilation_rate[1], + ) + return time_pad, spatial_pad + + +# --------------------------------------------------------------------------- +# Conv2D +# --------------------------------------------------------------------------- + + +class Conv2D(types.SequenceLayer): + """2D convolution layer with separate time and spatial padding.""" + + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + filters: int = 1 + kernel_size: tuple[int, int] = (1, 1) + strides: tuple[int, int] = (1, 1) + dilation_rate: tuple[int, int] = (1, 1) + time_padding: str = 'valid' + spatial_padding: str | tuple[int, int] = 'same' + groups: int = 1 + use_bias: bool = True + activation: object = None + compute_dtype: types.DType | None = None + param_dtype: types.DType = mx.float32 + name: str | None = None + + def __post_init__(self): + object.__setattr__(self, 'kernel_size', _normalize_2tuple(self.kernel_size)) + object.__setattr__(self, 'strides', _normalize_2tuple(self.strides)) + object.__setattr__(self, 'dilation_rate', _normalize_2tuple(self.dilation_rate)) + if isinstance(self.spatial_padding, str): + pass # Keep as string. 
+ else: + object.__setattr__(self, 'spatial_padding', tuple(self.spatial_padding)) + + def make(self) -> 'Conv2D': + return Conv2D.from_config(self) + + def __init__( + self, + *, + in_features, + filters, + kernel_size, + strides=(1, 1), + dilation_rate=(1, 1), + time_padding='valid', + spatial_padding='same', + groups=1, + use_bias=True, + activation=None, + compute_dtype=None, + param_dtype=mx.float32, + ): + super().__init__() + self.in_features = in_features + self.filters = filters + self.kernel_size = _normalize_2tuple(kernel_size) + self.strides = _normalize_2tuple(strides) + self.dilation_rate = _normalize_2tuple(dilation_rate) + self.time_padding = time_padding + self.spatial_padding = spatial_padding + self.groups = groups + self.use_bias = use_bias + self.activation = activation + self.compute_dtype = compute_dtype + self._param_dtype = param_dtype + + # Create kernel: [out_channels, kH, kW, in_channels // groups] + key = mx.random.key(0) + init_fn = init_mapping._make_variance_scaling_init('fan_in', 'truncated_normal') + self.kernel = init_fn( + key, + (filters, self.kernel_size[0], self.kernel_size[1], in_features // groups), + param_dtype, + ) + if use_bias: + self.bias = mx.zeros((filters,), dtype=param_dtype) + + @property + def supports_step(self): + return conv_utils._supports_step(self.time_padding) + + @property + def block_size(self): + return self.strides[0] + + @property + def output_ratio(self): + return fractions.Fraction(1, self.strides[0]) + + @property + def input_latency(self): + ek = conv_utils._effective_kernel_size(self.kernel_size[0], self.dilation_rate[0]) + if self.time_padding in ( + PaddingMode.CAUSAL_VALID.value, + PaddingMode.CAUSAL.value, + PaddingMode.SEMICAUSAL.value, + ): + return 0 + elif self.time_padding in ( + PaddingMode.REVERSE_CAUSAL_VALID.value, + PaddingMode.REVERSE_CAUSAL.value, + ): + return ek - 1 + return 0 + + def get_output_shape(self, input_shape, *, constants=None): + if len(input_shape) != 2: + raise 
ValueError( + f'Conv2D requires rank 4 input. Got channel_shape={input_shape}' + ) + freq_dim = input_shape[0] + # Compute spatial output size. + if isinstance(self.spatial_padding, str): + sp_pad = conv_utils._explicit_padding( + self.spatial_padding, self.kernel_size[1], + self.strides[1], self.dilation_rate[1], + ) + else: + sp_pad = self.spatial_padding + ek_sp = conv_utils._effective_kernel_size(self.kernel_size[1], self.dilation_rate[1]) + out_freq = (freq_dim + sp_pad[0] + sp_pad[1] - ek_sp) // self.strides[1] + 1 + return (out_freq, self.filters) + + def get_output_dtype(self, input_dtype, *, constants=None): + return self.compute_dtype or self._param_dtype + + def _forward(self, values, time_pad, spatial_pad): + """Apply 2D conv with explicit padding.""" + if time_pad[0] > 0 or time_pad[1] > 0 or spatial_pad[0] > 0 or spatial_pad[1] > 0: + values = mx.pad( + values, + [(0, 0), (time_pad[0], time_pad[1]), (spatial_pad[0], spatial_pad[1]), (0, 0)], + ) + compute_dtype = self.compute_dtype or self._param_dtype + values = values.astype(compute_dtype) + # mlx.core.conv2d: input [B, H, W, C_in], weight [C_out, kH, kW, C_in/groups] + y = mx.conv2d( + values, + self.kernel.astype(compute_dtype), + stride=self.strides, + padding=0, + dilation=self.dilation_rate, + groups=self.groups, + ) + if self.use_bias: + y = y + self.bias.astype(compute_dtype) + if self.activation is not None: + y = self.activation(y) + return y + + def get_initial_state(self, batch_size, input_spec, *, constants=None): + bw = conv_utils._buffer_width( + self.time_padding, + self.kernel_size[0], + self.strides[0], + self.dilation_rate[0], + ) + if not bw: + return () + # State is a MaskedSequence of shape [B, bw, freq, channels]. 
+ freq_dim = input_spec.shape[0] + channels = input_spec.shape[1] if len(input_spec.shape) > 1 else 1 + if self.time_padding in ( + PaddingMode.CAUSAL_VALID.value, + PaddingMode.REVERSE_CAUSAL_VALID.value, + ): + mask = mx.ones((batch_size, bw), dtype=bt.MASK_DTYPE) + else: + mask = mx.zeros((batch_size, bw), dtype=bt.MASK_DTYPE) + values = mx.zeros( + (batch_size, bw) + input_spec.shape, + dtype=input_spec.dtype, + ) + return MaskedSequence(values, mask) + + @types.check_step + def step(self, x, state, *, constants=None): + ek_time = conv_utils._effective_kernel_size(self.kernel_size[0], self.dilation_rate[0]) + if ek_time > 1: + x = x.mask_invalid() + + bw = conv_utils._buffer_width( + self.time_padding, + self.kernel_size[0], + self.strides[0], + self.dilation_rate[0], + ) + + if bw: + state = state.concatenate(x) + else: + state = x + + # Spatial padding always applied; time padding from buffer. + if isinstance(self.spatial_padding, str): + sp_pad = conv_utils._explicit_padding( + self.spatial_padding, self.kernel_size[1], + self.strides[1], self.dilation_rate[1], + ) + else: + sp_pad = self.spatial_padding + + values = self._forward(state.values, (0, 0), sp_pad) + mask = conv_utils._compute_conv_mask( + state.mask, + self.kernel_size[0], + self.strides[0], + self.dilation_rate[0], + self.time_padding, + is_step=True, + ) + + if bw: + state = state[:, -bw:] + else: + state = () + + return Sequence(values, mask), state + + @types.check_layer + def layer(self, x, *, constants=None): + if self.kernel_size[0] > 1: + x = x.mask_invalid() + + time_pad = conv_utils._explicit_padding( + self.time_padding, + self.kernel_size[0], + self.strides[0], + self.dilation_rate[0], + ) + if isinstance(self.spatial_padding, str): + sp_pad = conv_utils._explicit_padding( + self.spatial_padding, self.kernel_size[1], + self.strides[1], self.dilation_rate[1], + ) + else: + sp_pad = self.spatial_padding + + values = self._forward(x.values, time_pad, sp_pad) + mask = 
conv_utils._compute_conv_mask( + x.mask, + self.kernel_size[0], + self.strides[0], + self.dilation_rate[0], + self.time_padding, + is_step=False, + ) + return Sequence(values, mask) + + @classmethod + def from_config(cls, config): + compute_dtype = getattr(config, 'compute_dtype', None) + if compute_dtype is not None: + compute_dtype = init_mapping._to_mx_dtype(compute_dtype) + activation = init_mapping.map_activation(getattr(config, 'activation', None)) + spatial_padding = config.spatial_padding + if isinstance(spatial_padding, str): + pass + else: + spatial_padding = tuple(spatial_padding) + return DeferredConv2D(config) + + +class DeferredConv2D(types.SequenceLayer): + """Deferred Conv2D: delays kernel creation until first use.""" + + def __init__(self, config): + super().__init__() + self._config = config + self.inner = None + + def _ensure_built(self, input_shape): + if self.inner is not None: + return + in_features = input_shape[-1] + compute_dtype = getattr(self._config, 'compute_dtype', None) + if compute_dtype is not None: + compute_dtype = init_mapping._to_mx_dtype(compute_dtype) + activation = init_mapping.map_activation(getattr(self._config, 'activation', None)) + spatial_padding = self._config.spatial_padding + if isinstance(spatial_padding, str): + pass + else: + spatial_padding = tuple(spatial_padding) + + self.inner = Conv2D( + in_features=in_features, + filters=self._config.filters, + kernel_size=_normalize_2tuple(self._config.kernel_size), + strides=_normalize_2tuple(self._config.strides), + dilation_rate=_normalize_2tuple(getattr(self._config, 'dilation_rate', (1, 1))), + time_padding=getattr(self._config, 'time_padding', 'valid'), + spatial_padding=spatial_padding, + groups=getattr(self._config, 'groups', 1), + use_bias=getattr(self._config, 'use_bias', True), + activation=activation, + compute_dtype=compute_dtype, + param_dtype=init_mapping._to_mx_dtype(self._config.param_dtype), + ) + + @property + def supports_step(self): + return 
conv_utils._supports_step( + getattr(self._config, 'time_padding', 'valid') + ) + + @property + def block_size(self): + return _normalize_2tuple(self._config.strides)[0] + + @property + def output_ratio(self): + return fractions.Fraction(1, _normalize_2tuple(self._config.strides)[0]) + + @property + def input_latency(self): + ks = _normalize_2tuple(self._config.kernel_size) + dr = _normalize_2tuple(getattr(self._config, 'dilation_rate', (1, 1))) + tp = getattr(self._config, 'time_padding', 'valid') + ek = conv_utils._effective_kernel_size(ks[0], dr[0]) + if tp in (PaddingMode.CAUSAL_VALID.value, PaddingMode.CAUSAL.value, + PaddingMode.SEMICAUSAL.value): + return 0 + elif tp in (PaddingMode.REVERSE_CAUSAL_VALID.value, + PaddingMode.REVERSE_CAUSAL.value): + return ek - 1 + return 0 + + def get_output_shape(self, input_shape, *, constants=None): + self._ensure_built(input_shape) + return self.inner.get_output_shape(input_shape, constants=constants) + + def get_output_dtype(self, input_dtype, *, constants=None): + compute_dtype = getattr(self._config, 'compute_dtype', None) + if compute_dtype is not None: + return init_mapping._to_mx_dtype(compute_dtype) + return init_mapping._to_mx_dtype(self._config.param_dtype) + + def get_initial_state(self, batch_size, input_spec, *, constants=None): + self._ensure_built(input_spec.shape) + return self.inner.get_initial_state(batch_size, input_spec, constants=constants) + + @types.check_step + def step(self, x, state, *, constants=None): + self._ensure_built(x.channel_shape) + return self.inner.step(x, state, constants=constants) + + @types.check_layer + def layer(self, x, *, constants=None): + self._ensure_built(x.channel_shape) + return self.inner.layer(x, constants=constants) + + +# --------------------------------------------------------------------------- +# Conv2DTranspose +# --------------------------------------------------------------------------- + + +class Conv2DTranspose(types.SequenceLayer): + """2D transposed 
convolution layer.""" + + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + filters: int = 1 + kernel_size: tuple[int, int] = (1, 1) + strides: tuple[int, int] = (1, 1) + dilation_rate: tuple[int, int] = (1, 1) + time_padding: str = 'valid' + spatial_padding: str | tuple[int, int] = 'same' + groups: int = 1 + use_bias: bool = True + activation: object = None + compute_dtype: types.DType | None = None + param_dtype: types.DType = mx.float32 + name: str | None = None + + def __post_init__(self): + object.__setattr__(self, 'kernel_size', _normalize_2tuple(self.kernel_size)) + object.__setattr__(self, 'strides', _normalize_2tuple(self.strides)) + object.__setattr__(self, 'dilation_rate', _normalize_2tuple(self.dilation_rate)) + + def make(self) -> 'Conv2DTranspose': + return Conv2DTranspose.from_config(self) + + def __init__( + self, + *, + in_features, + filters, + kernel_size, + strides=(1, 1), + dilation_rate=(1, 1), + time_padding='valid', + spatial_padding='same', + groups=1, + use_bias=True, + activation=None, + compute_dtype=None, + param_dtype=mx.float32, + ): + super().__init__() + self.in_features = in_features + self.filters = filters + self.kernel_size = _normalize_2tuple(kernel_size) + self.strides = _normalize_2tuple(strides) + self.dilation_rate = _normalize_2tuple(dilation_rate) + self.time_padding = time_padding + self.spatial_padding = spatial_padding + self.groups = groups + self.use_bias = use_bias + self.activation = activation + self.compute_dtype = compute_dtype + self._param_dtype = param_dtype + + # Kernel: [out_channels, kH, kW, in_channels // groups] + key = mx.random.key(0) + init_fn = init_mapping._make_variance_scaling_init('fan_in', 'truncated_normal') + self.kernel = init_fn( + key, + (filters, self.kernel_size[0], self.kernel_size[1], in_features // groups), + param_dtype, + ) + if use_bias: + self.bias = mx.zeros((filters,), dtype=param_dtype) + + @property + def supports_step(self): + return self.time_padding 
== PaddingMode.CAUSAL.value + + @property + def block_size(self): + return 1 + + @property + def output_ratio(self): + return fractions.Fraction(self.strides[0]) + + @property + def input_latency(self): + return 0 + + def _time_trim(self): + """Returns (trim_left, trim_right) for time dimension.""" + return conv_utils._transpose_conv_output_trim( + self.kernel_size[0], self.strides[0], + self.dilation_rate[0], self.time_padding, + ) + + def _spatial_trim(self): + """Returns (trim_left, trim_right) for spatial dimension.""" + if isinstance(self.spatial_padding, str): + return conv_utils._transpose_conv_output_trim( + self.kernel_size[1], self.strides[1], + self.dilation_rate[1], self.spatial_padding, + ) + else: + return self.spatial_padding + + def get_output_shape(self, input_shape, *, constants=None): + if len(input_shape) != 2: + raise ValueError( + f'Conv2DTranspose requires rank 4 input. Got channel_shape={input_shape}' + ) + freq_dim = input_shape[0] + ek_sp = conv_utils._effective_kernel_size(self.kernel_size[1], self.dilation_rate[1]) + raw_sp = (freq_dim - 1) * self.strides[1] + ek_sp + sp_trim = self._spatial_trim() + out_freq = raw_sp - sp_trim[0] - sp_trim[1] + return (out_freq, self.filters) + + def get_output_dtype(self, input_dtype, *, constants=None): + return self.compute_dtype or self._param_dtype + + def _conv_raw(self, values, trim_time=True): + """Compute raw conv_transpose2d, optionally trimming time. + + Args: + values: Input values. + trim_time: If True, trim time dimension (for layer mode). + If False, skip time trim (for step mode overlap-add). + Returns: + Raw convolution output WITHOUT bias or activation. 
+ """ + compute_dtype = self.compute_dtype or self._param_dtype + values = values.astype(compute_dtype) + # mx.conv_transpose2d: input [B, H, W, C_in], weight [C_out, kH, kW, C_in/groups] + y = mx.conv_transpose2d( + values, + self.kernel.astype(compute_dtype), + stride=self.strides, + padding=0, + dilation=self.dilation_rate, + groups=self.groups, + ) + # Time trim (only in layer mode; step mode handles it via overlap-add). + if trim_time: + tl, tr = self._time_trim() + if tl > 0: + y = y[:, tl:] + if tr > 0: + y = y[:, :-tr] + # Spatial trim (always applied). + sl_val, sr = self._spatial_trim() + if sl_val > 0: + y = y[:, :, sl_val:] + if sr > 0: + y = y[:, :, :-sr] + return y + + def _apply_bias_and_activation(self, y): + """Apply bias and activation to conv output.""" + compute_dtype = self.compute_dtype or self._param_dtype + if self.use_bias: + y = y + self.bias.astype(compute_dtype) + if self.activation is not None: + y = self.activation(y) + return y + + def _forward(self, values): + """Full forward: conv + trim + bias + activation (for layer mode).""" + y = self._conv_raw(values, trim_time=True) + return self._apply_bias_and_activation(y) + + @types.check_layer + def layer(self, x, *, constants=None): + values = self._forward(x.values) + mask = conv_utils._compute_conv_transpose_mask( + x.mask, + self.kernel_size[0], + self.strides[0], + self.dilation_rate[0], + self.time_padding, + ) + return Sequence(values, mask) + + def get_initial_state(self, batch_size, input_spec, *, constants=None): + if not self.supports_step: + raise ValueError('Conv2DTranspose step only supported with causal padding.') + ola_buf = max( + 0, + conv_utils._effective_kernel_size(self.kernel_size[0], self.dilation_rate[0]) + - self.strides[0], + ) + if not ola_buf: + return () + out_shape = self.get_output_shape(input_spec.shape, constants=constants) + values = mx.zeros((batch_size, ola_buf) + out_shape, dtype=self.get_output_dtype(input_spec.dtype)) + mask = mx.zeros((batch_size, 
@types.check_step
def step(self, x, state, *, constants=None):
  """Processes one streaming block using overlap-add along the time axis.

  The transpose convolution is evaluated without time trimming, summed with
  the tail carried over from the previous step, and then split into this
  step's output and the tail to carry into the next step. Bias and activation
  are applied only to the emitted output.

  Args:
    x: Input block; `x.values`, `x.mask` and `x.shape[1]` (time) are read.
    state: `()` when no overlap buffer is needed, otherwise an object whose
      `.values` holds the carried-over tail (as built by `get_initial_state`
      or returned by a previous call).
    constants: Unused.

  Returns:
    A `(Sequence, new_state)` pair.
  """
  # NOTE(review): presumably zeroes values at masked-out timesteps so they do
  # not leak into the overlap-add — confirm against basic_types.
  x = x.mask_invalid()
  # Conv WITHOUT time trimming — keep full temporal output for overlap-add.
  # Bias is also deferred until after overlap-add (matching JAX behavior).
  raw = self._conv_raw(x.values, trim_time=False)
  input_time = x.shape[1]
  out_time = input_time * self.strides[0]
  # Each input timestep's validity is repeated `stride` times in the output.
  mask = mx.repeat(x.mask, self.strides[0], axis=1)

  # Number of trailing frames that overlap with the next step's output:
  # effective kernel size minus stride, floored at zero.
  ola_buf = max(
      0,
      conv_utils._effective_kernel_size(self.kernel_size[0], self.dilation_rate[0])
      - self.strides[0],
  )
  if ola_buf:
    # Pad the state buffer to match the raw output length, then overlap-add.
    # raw has shape (B, raw_time, ...) where raw_time >= out_time + ola_buf
    buf_values = state.values  # (B, ola_buf, ...)
    pad_len = raw.shape[1] - ola_buf
    if pad_len > 0:
      buf_values = mx.concatenate(
          [buf_values, mx.zeros_like(raw[:, :pad_len])], axis=1
      )
    # Overlap-add: add state to raw output.
    out_values = buf_values + raw
    # Split: first out_time samples are output, rest is new buffer.
    out = out_values[:, :out_time]
    new_buf = out_values[:, out_time:]
    # Force the carried tail to exactly ola_buf frames (zero-pad or truncate).
    if new_buf.shape[1] < ola_buf:
      pad_width = ola_buf - new_buf.shape[1]
      new_buf = mx.pad(new_buf, [(0, 0), (0, pad_width)] + [(0, 0)] * (new_buf.ndim - 2))
    elif new_buf.shape[1] > ola_buf:
      new_buf = new_buf[:, :ola_buf]
    # The state mask is always all-zero; only the buffered values are used.
    new_mask = mx.zeros((x.values.shape[0], ola_buf), dtype=bt.MASK_DTYPE)
    state = MaskedSequence(new_buf, new_mask)
  else:
    # NOTE(review): if raw has fewer than out_time frames (looks possible when
    # the effective kernel is smaller than the stride — TODO confirm), `out`
    # is shorter than out_time and the mask is trimmed to match below.
    out = raw[:, :out_time]
    state = ()

  # Apply bias and activation AFTER overlap-add (only once per sample).
  out = self._apply_bias_and_activation(out)

  # Keep the mask the same length as the emitted values.
  out_mask = mask[:, :out.shape[1]]
  return Sequence(out, out_mask), state
@dataclasses.dataclass(frozen=True)
class Config(_SequenceLayerConfig):
  """Configuration for AveragePooling2D."""

  # Each of the three size fields is normalized with _normalize_2tuple in
  # __post_init__.
  pool_size: tuple[int, int] = (1, 1)
  strides: tuple[int, int] = (1, 1)
  dilation_rate: tuple[int, int] = (1, 1)
  time_padding: str = 'valid'
  # Either a padding-mode name or an explicit (left, right) pair.
  spatial_padding: str | tuple[int, int] = 'same'
  masked_average: bool = False
  name: str | None = None

  def __post_init__(self):
    # The dataclass is frozen, so assignment has to go through
    # object.__setattr__.
    for field_name in ('pool_size', 'strides', 'dilation_rate'):
      normalized = _normalize_2tuple(getattr(self, field_name))
      object.__setattr__(self, field_name, normalized)

  def make(self) -> 'AveragePooling2D':
    """Builds the layer described by this config."""
    return AveragePooling2D.from_config(self)
_normalize_2tuple(dilation_rate) + self.time_padding = time_padding + self.spatial_padding = spatial_padding + self.masked_average = masked_average + + @property + def supports_step(self): + return conv_utils._supports_step(self.time_padding) + + @property + def block_size(self): + return self.strides[0] + + @property + def output_ratio(self): + return fractions.Fraction(1, self.strides[0]) + + @property + def input_latency(self): + ek = conv_utils._effective_kernel_size(self.pool_size[0], self.dilation_rate[0]) + if self.time_padding in ( + PaddingMode.CAUSAL_VALID.value, + PaddingMode.CAUSAL.value, + PaddingMode.SEMICAUSAL.value, + ): + return 0 + elif self.time_padding in ( + PaddingMode.REVERSE_CAUSAL_VALID.value, + PaddingMode.REVERSE_CAUSAL.value, + ): + return ek - 1 + return 0 + + def get_output_shape(self, input_shape, *, constants=None): + if len(input_shape) != 2: + raise ValueError( + f'AveragePooling2D requires rank 4 input. Got channel_shape={input_shape}' + ) + freq_dim = input_shape[0] + if isinstance(self.spatial_padding, str): + sp_pad = conv_utils._explicit_padding( + self.spatial_padding, self.pool_size[1], + self.strides[1], self.dilation_rate[1], + ) + else: + sp_pad = self.spatial_padding + ek_sp = conv_utils._effective_kernel_size(self.pool_size[1], self.dilation_rate[1]) + out_freq = (freq_dim + sp_pad[0] + sp_pad[1] - ek_sp) // self.strides[1] + 1 + return (out_freq, input_shape[1]) + + def get_output_dtype(self, input_dtype, *, constants=None): + return input_dtype + + def _pool(self, values, time_pad, spatial_pad): + """Apply 2D average pooling with explicit padding.""" + if time_pad[0] > 0 or time_pad[1] > 0 or spatial_pad[0] > 0 or spatial_pad[1] > 0: + values = mx.pad( + values, + [(0, 0), (time_pad[0], time_pad[1]), (spatial_pad[0], spatial_pad[1]), (0, 0)], + ) + # Implement average pooling via im2col-style approach. + # For simplicity, use a strided mean. 
+ b, t, h, c = values.shape + pt, ps = self.pool_size + st, ss = self.strides + out_t = (t - pt) // st + 1 + out_h = (h - ps) // ss + 1 + # Extract patches and average. + result = mx.zeros((b, out_t, out_h, c), dtype=values.dtype) + patches = [] + for dt in range(pt): + for ds in range(ps): + patch = values[:, dt:dt + out_t * st:st, ds:ds + out_h * ss:ss, :] + patches.append(patch) + result = sum(patches) / len(patches) + return result + + @types.check_layer + def layer(self, x, *, constants=None): + time_pad = conv_utils._explicit_padding( + self.time_padding, + self.pool_size[0], + self.strides[0], + self.dilation_rate[0], + ) + if isinstance(self.spatial_padding, str): + sp_pad = conv_utils._explicit_padding( + self.spatial_padding, self.pool_size[1], + self.strides[1], self.dilation_rate[1], + ) + else: + sp_pad = self.spatial_padding + + values = self._pool(x.values, time_pad, sp_pad) + mask = conv_utils._compute_conv_mask( + x.mask, + self.pool_size[0], + self.strides[0], + self.dilation_rate[0], + self.time_padding, + is_step=False, + ) + return Sequence(values, mask) + + def get_initial_state(self, batch_size, input_spec, *, constants=None): + bw = conv_utils._buffer_width( + self.time_padding, + self.pool_size[0], + self.strides[0], + self.dilation_rate[0], + ) + if not bw: + return () + if self.time_padding in ( + PaddingMode.CAUSAL_VALID.value, + PaddingMode.REVERSE_CAUSAL_VALID.value, + ): + mask = mx.ones((batch_size, bw), dtype=bt.MASK_DTYPE) + else: + mask = mx.zeros((batch_size, bw), dtype=bt.MASK_DTYPE) + values = mx.zeros( + (batch_size, bw) + input_spec.shape, + dtype=input_spec.dtype, + ) + return MaskedSequence(values, mask) + + @types.check_step + def step(self, x, state, *, constants=None): + bw = conv_utils._buffer_width( + self.time_padding, + self.pool_size[0], + self.strides[0], + self.dilation_rate[0], + ) + if bw: + state = state.concatenate(x) + else: + state = x + + if isinstance(self.spatial_padding, str): + sp_pad = 
class Upsample2D(types.PreservesType, types.Stateless):
  """Nearest-neighbor 2D upsampling by integer repetition.

  Axis 1 (time) of the values and of the mask is repeated `rate[0]` times;
  axis 2 of the values is repeated `rate[1]` times.
  """

  @dataclasses.dataclass(frozen=True)
  class Config(_SequenceLayerConfig):
    """Configuration for Upsample2D."""

    rate: tuple[int, int] = (1, 1)
    name: str | None = None

    def __post_init__(self):
      # Frozen dataclass: assignment has to go through object.__setattr__.
      normalized = _normalize_2tuple(self.rate)
      object.__setattr__(self, 'rate', normalized)

    def make(self) -> 'Upsample2D':
      """Builds the layer described by this config."""
      return Upsample2D.from_config(self)

  def __init__(self, *, rate):
    super().__init__()
    self._rate = _normalize_2tuple(rate)

  @property
  def output_ratio(self):
    """Output timesteps per input timestep (the time repetition rate)."""
    time_rate = self._rate[0]
    return fractions.Fraction(time_rate)

  def get_output_shape(self, input_shape, *, constants=None):
    """Scales the first channel dimension by `rate[1]`.

    Raises:
      ValueError: If `input_shape` does not have exactly two entries.
    """
    if len(input_shape) != 2:
      raise ValueError(
          f'Upsample2D requires rank 4 input, got channel_shape={input_shape}'
      )
    spatial, channels = input_shape[0], input_shape[1]
    return (spatial * self._rate[1], channels)

  @types.check_layer
  def layer(self, x, *, constants=None):
    """Repeats values along axes 1 and 2 and the mask along axis 1."""
    time_rate, spatial_rate = self._rate[0], self._rate[1]
    upsampled = mx.repeat(
        mx.repeat(x.values, time_rate, axis=1), spatial_rate, axis=2
    )
    upsampled_mask = mx.repeat(x.mask, time_rate, axis=1)
    # Rebuild with the input's own sequence type.
    return type(x)(upsampled, upsampled_mask)

  @classmethod
  def from_config(cls, config):
    """Builds the layer from any object exposing a `rate` attribute."""
    return cls(rate=_normalize_2tuple(config.rate))
+ if combination is None: + self._combination = self.CONCAT + elif hasattr(combination, 'value'): + self._combination = combination.value + else: + self._combination = int(combination) + + @property + def supports_step(self): + return self.child.supports_step + + @property + def block_size(self): + return self.child.block_size + + @property + def output_ratio(self): + return self.child.output_ratio + + @property + def input_latency(self): + return self.child.input_latency + + def _split(self, x): + """Split sequence along last channel dim into num_groups.""" + vals = x.values + c = vals.shape[-1] + if c % self._num_groups != 0: + raise ValueError( + f'Input channels ({c}) must be divisible by num_groups ({self._num_groups}).' + ) + group_size = c // self._num_groups + groups = [] + for i in range(self._num_groups): + g_vals = vals[..., i * group_size:(i + 1) * group_size] + groups.append(type(x)(g_vals, x.mask)) + return groups + + def _combine(self, outputs): + """Combine group outputs.""" + if self._combination == self.CONCAT: + # Concatenate along last axis. + combined_vals = mx.concatenate([o.values for o in outputs], axis=-1) + return Sequence(combined_vals, outputs[0].mask) + elif self._combination == self.STACK: + # Stack along a new axis before the last. + stacked = mx.stack([o.values for o in outputs], axis=-2) + return Sequence(stacked, outputs[0].mask) + else: + raise ValueError(f'Unsupported combination mode: {self._combination}') + + def get_output_shape(self, input_shape, *, constants=None): + if not input_shape: + raise ValueError(f'Input must be at least 3D, got: {input_shape=}.') + if input_shape[-1] % self._num_groups != 0: + raise ValueError( + f'Input channels ({input_shape[-1]}) must be divisible by' + f' num_groups ({self._num_groups}).' 
def get_initial_state(self, batch_size, input_spec, *, constants=None):
  """Builds the initial streaming state: one child state per group.

  Args:
    batch_size: Forwarded to the child layer.
    input_spec: Spec of the full (unsplit) input; `shape` and `dtype` are read.
    constants: Forwarded to the child layer.

  Returns:
    A tuple of length `num_groups` whose entries all refer to the single
    state object returned by the child for one group's spec.

  Raises:
    ValueError: If the channel shape is empty or its last dimension is not
      divisible by `num_groups`.
  """
  if not input_spec.shape:
    raise ValueError(f'Input must be at least 3D, got: {input_spec.shape=}.')
  *leading_dims, channels = input_spec.shape
  if channels % self._num_groups != 0:
    raise ValueError(
        f'Input channels ({channels}) must be divisible by'
        f' num_groups ({self._num_groups}).'
    )
  from sequence_layers.mlx import types as sl_types
  # Every group sees the same spec: the last dimension divided by num_groups.
  per_group_spec = sl_types.ChannelSpec(
      shape=(*leading_dims, channels // self._num_groups),
      dtype=input_spec.dtype,
  )
  child_state = self.child.get_initial_state(
      batch_size, per_group_spec, constants=constants
  )
  return tuple(child_state for _ in range(self._num_groups))
use. + """Dense layer that defers weight creation until first use.""" - This is needed because Linen Dense.Config doesn't specify in_features; - it is inferred from the first input. This wrapper creates the actual - Dense layer on first call. - """ + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + """MLX-native configuration for Dense.""" + features: int = 1 + use_bias: bool = True + activation: Callable | None = None + compute_dtype: types.DType | None = None + param_dtype: types.DType = mx.float32 + name: str | None = None + + def make(self) -> 'DenseDeferred': + return DenseDeferred.from_config(self) def __init__( self, @@ -186,12 +198,12 @@ def __init__( self.activation = activation self.compute_dtype = compute_dtype self._param_dtype = param_dtype - self._inner = None + self.inner = None def _ensure_initialized(self, in_features: int): - if self._inner is not None: + if self.inner is not None: return - self._inner = Dense( + self.inner = Dense( in_features=in_features, features=self.features, use_bias=self._use_bias, @@ -215,7 +227,7 @@ def get_output_dtype(self, input_dtype, *, constants=None): @types.check_layer def layer(self, x, *, constants=None): self._ensure_initialized(x.shape[-1]) - return self._inner.layer(x, constants=constants) + return self.inner.layer(x, constants=constants) @classmethod def from_config(cls, config): @@ -233,11 +245,24 @@ def from_config(cls, config): class EinsumDense(types.Stateless): - """Dense layer using Einstein summation notation. + """Dense layer using Einstein summation notation.""" - Equation must be of the form '...ab,bc->...ac' where the leading '...' - broadcasts over batch and time dimensions. - """ + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + """MLX-native configuration for EinsumDense.""" + equation: str = '' + output_shape: tuple[int | None, ...] 
= () + bias_axes: str = '' + activation: Callable | None = None + compute_dtype: types.DType | None = None + param_dtype: types.DType = mx.float32 + name: str | None = None + + def __post_init__(self): + object.__setattr__(self, 'output_shape', tuple(self.output_shape)) + + def make(self) -> 'EinsumDense': + return EinsumDense.from_config(self) def __init__( self, @@ -314,3 +339,7 @@ def from_config(cls, config): compute_dtype=compute_dtype, param_dtype=_to_mx_dtype(config.param_dtype), ) + + +# Alias so that sl.Dense.Config(...) works like sl_jax.Dense.Config(...). +Dense.Config = DenseDeferred.Config diff --git a/sequence_layers/mlx/dsp.py b/sequence_layers/mlx/dsp.py index 4d8bcbe..32bf3ad 100644 --- a/sequence_layers/mlx/dsp.py +++ b/sequence_layers/mlx/dsp.py @@ -1,5 +1,6 @@ """DSP layers for MLX.""" +import dataclasses import fractions import math @@ -9,6 +10,7 @@ from sequence_layers.mlx import basic_types as bt from sequence_layers.mlx import convolution as conv_utils from sequence_layers.mlx import types +from sequence_layers.jax.types import SequenceLayerConfig as _SequenceLayerConfig Sequence = bt.Sequence MaskedSequence = bt.MaskedSequence @@ -181,10 +183,16 @@ def mel_to_hz(m): class Delay(types.PreservesShape, types.PreservesType, types.SequenceLayer): - """Delays input by `length` timesteps. + """Delays input by `length` timesteps.""" - Inserts `length` invalid timesteps at the start of the sequence. 
- """ + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + length: int = 0 + delay_layer_output: bool = True + name: str | None = None + + def make(self) -> 'Delay': + return Delay.from_config(self) def __init__(self, *, length, delay_layer_output=True): super().__init__() @@ -247,6 +255,15 @@ def from_config(cls, config): class Lookahead(types.PreservesShape, types.PreservesType, types.SequenceLayer): """Drops the first `length` timesteps from the input.""" + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + length: int = 0 + preserve_length_in_layer: bool = False + name: str | None = None + + def make(self) -> 'Lookahead': + return Lookahead.from_config(self) + def __init__(self, *, length, preserve_length_in_layer=False): super().__init__() if length < 0: @@ -904,10 +921,21 @@ def from_config(cls, config): class STFT(types.SequenceLayer): - """Short-Time Fourier Transform. - - Composes Frame -> Window -> RFFT. - """ + """Short-Time Fourier Transform.""" + + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + frame_length: int = 0 + frame_step: int = 0 + fft_length: int = 0 + window_fn: object = None + time_padding: str = 'reverse_causal_valid' + fft_padding: str = 'right' + output_magnitude: bool = False + name: str | None = None + + def make(self) -> 'STFT': + return STFT.from_config(self) def __init__( self, @@ -1018,10 +1046,20 @@ def from_config(cls, config): class InverseSTFT(types.SequenceLayer): - """Inverse Short-Time Fourier Transform. - - Composes IRFFT -> Window -> OverlapAdd. 
- """ + """Inverse Short-Time Fourier Transform.""" + + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + frame_length: int = 0 + frame_step: int = 0 + fft_length: int = 0 + window_fn: object = None + time_padding: str = 'causal' + fft_padding: str = 'right' + name: str | None = None + + def make(self) -> 'InverseSTFT': + return InverseSTFT.from_config(self) def __init__( self, diff --git a/sequence_layers/mlx/normalization.py b/sequence_layers/mlx/normalization.py index af29aa8..a420144 100644 --- a/sequence_layers/mlx/normalization.py +++ b/sequence_layers/mlx/normalization.py @@ -1,11 +1,14 @@ """Normalization layers for MLX.""" +import dataclasses + import mlx.core as mx import mlx.nn as nn from sequence_layers.mlx import basic_types as bt from sequence_layers.mlx import init_mapping from sequence_layers.mlx import types +from sequence_layers.jax.types import SequenceLayerConfig as _SequenceLayerConfig Sequence = bt.Sequence @@ -55,12 +58,23 @@ def from_config(cls, config): class RMSNormalization(types.PreservesType, types.StatelessPointwise): - """RMS Normalization backed by mlx.nn.RMSNorm. + """RMS Normalization backed by mlx.nn.RMSNorm.""" - For simple axis=-1 normalization with a learned scale, this delegates - to mlx.nn.RMSNorm (which uses the optimized mx.fast.rms_norm). - Falls back to manual computation for multi-axis or no-scale cases. - """ + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + axis: int | tuple[int, ...] 
= -1 + epsilon: float = 1e-6 + use_scale: bool = True + scale_init: object = None + param_dtype: types.DType = mx.float32 + name: str | None = None + + def __post_init__(self): + if not isinstance(self.axis, int): + object.__setattr__(self, 'axis', tuple(self.axis)) + + def make(self) -> 'RMSNormalization': + return RMSNormalization.from_config(self) def __init__( self, @@ -105,7 +119,9 @@ def layer(self, x, *, constants=None): self._ensure_initialized(x.values.shape) if self._use_builtin and self._rms_norm is not None: - return Sequence(self._rms_norm(x.values), x.mask) + # Cast back to input dtype to preserve bfloat16 compute. + result = self._rms_norm(x.values).astype(x.values.dtype) + return Sequence(result, x.mask) values = x.values axes = _normalize_axes(self._axis, values.shape) @@ -150,6 +166,25 @@ class LayerNormalization(types.PreservesType, types.StatelessPointwise): Falls back to manual computation for multi-axis cases. """ + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + axis: int | tuple[int, ...] = -1 + epsilon: float = 1e-6 + use_bias: bool = True + use_scale: bool = True + # Accepted for JAX compatibility but ignored: MLX always reduces in fp32. 
+ reductions_in_at_least_fp32: bool = True + param_dtype: types.DType = mx.float32 + name: str | None = None + + def __post_init__(self): + if not isinstance(self.axis, int): + object.__setattr__(self, 'axis', tuple(self.axis)) + + def make(self) -> 'LayerNormalization': + return LayerNormalization.from_config(self) + + def __init__( self, *, @@ -158,6 +193,7 @@ def __init__( use_bias: bool = True, use_scale: bool = True, param_dtype=mx.float32, + reductions_in_at_least_fp32: bool = True, ): super().__init__() self._axis = axis @@ -165,6 +201,7 @@ def __init__( self.use_bias = use_bias self.use_scale = use_scale self._param_dtype = param_dtype + self.reductions_in_at_least_fp32 = reductions_in_at_least_fp32 self._layer_norm = None self._use_builtin = False self._manual_scale = None @@ -199,7 +236,13 @@ def layer(self, x, *, constants=None): self._ensure_initialized(x.values.shape) if self._use_builtin and self._layer_norm is not None: - return Sequence(self._layer_norm(x.values), x.mask) + x_values = x.values + original_dtype = x_values.dtype + if self.reductions_in_at_least_fp32: + x_values = x_values.astype(mx.float32) + # Cast back to input dtype to preserve bfloat16 compute. 
+ result = self._layer_norm(x_values).astype(original_dtype) + return Sequence(result, x.mask) values = x.values axes = _normalize_axes(self._axis, values.shape) @@ -241,6 +284,7 @@ def from_config(cls, config): use_bias=config.use_bias, use_scale=config.use_scale, param_dtype=_to_mx_dtype(config.param_dtype), + reductions_in_at_least_fp32=config.reductions_in_at_least_fp32 ) diff --git a/sequence_layers/mlx/position.py b/sequence_layers/mlx/position.py index 1db683a..32c84ba 100644 --- a/sequence_layers/mlx/position.py +++ b/sequence_layers/mlx/position.py @@ -1,10 +1,13 @@ """Position embeddings for MLX.""" +import dataclasses + import mlx.core as mx import numpy as np from sequence_layers.mlx import basic_types as bt from sequence_layers.mlx import types +from sequence_layers.jax.types import SequenceLayerConfig as _SequenceLayerConfig Sequence = bt.Sequence @@ -16,6 +19,18 @@ class ApplyRotaryPositionalEncoding( ): """Applies Rotary Positional Encodings (RoPE) to the sequence.""" + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + max_wavelength: float = 10000.0 + axis: int = -1 + only_advance_position_for_valid_timesteps: bool = True + positions_in_at_least_fp32: bool = True + positions_name: str | None = None + name: str | None = None + + def make(self) -> 'ApplyRotaryPositionalEncoding': + return ApplyRotaryPositionalEncoding.from_config(self) + def __init__( self, *, diff --git a/sequence_layers/mlx/projection_configs.py b/sequence_layers/mlx/projection_configs.py new file mode 100644 index 0000000..930cd87 --- /dev/null +++ b/sequence_layers/mlx/projection_configs.py @@ -0,0 +1,124 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""MLX-native attention projection configuration dataclasses. + +These are pure-Python dataclasses that mirror the JAX-side projection configs +from sequence_layers.jax.attention.common, but without JAX-specific fields +(sharding, einsum_factory, quantization_provider). They retain initializer +fields as Callable | None so that downstream code can still configure kernel +and bias initialization. +""" + +import dataclasses +from typing import Callable + + +@dataclasses.dataclass(frozen=True) +class QueryKeyValueProjectionConfig: + """Base class for QKV projection configuration.""" + pass + + +@dataclasses.dataclass(frozen=True) +class CombinedQueryKeyValueProjection(QueryKeyValueProjectionConfig): + """Use a single projection matrix for query/key/value projection. + + * Incompatible with Grouped Query Attention (num_query_heads != num_kv_heads). + * Supports shared key and value projection. + """ + + # Kernel initializer for the combined query/key/value projection. + # The variable shape is [input_dimension, 3, num_heads, units_per_head]. + # If share_kv_projection is True, the variable shape is [input_dimension, 2, + # num_heads, units_per_head]. + qkv_kernel_init: Callable | None = None + + # Bias initializer for the combined query/key/value projection. + # The variable shape is [3, num_heads, units_per_head]. + bias_init: Callable | None = None + + # If true, share the key and value projection matrices. 
+ share_kv_projection: bool = False + + +@dataclasses.dataclass(frozen=True) +class SeparateQueryKeyValueProjection(QueryKeyValueProjectionConfig): + """Use separate projection matrices for query/key/value projection. + + * Supports Grouped Query Attention (num_query_heads != num_kv_heads). + * Does not support shared key and value projection. Use + QueryAndSharedKeyValueProjection. + """ + + # Kernel initializers for the separate query/key/value projections. + # The variable shape is [input_dimension, num_heads or num_kv_heads, + # units_per_head]. + q_kernel_init: Callable | None = None + k_kernel_init: Callable | None = None + v_kernel_init: Callable | None = None + + # Bias initializer for the separate query/key/value projections. + # The variable shape is [num_heads or num_kv_heads, units_per_head]. + bias_init: Callable | None = None + + +@dataclasses.dataclass(frozen=True) +class QueryAndKeyValueProjection(QueryKeyValueProjectionConfig): + """Use separate query and key/value projection matrices. + + * Supports Grouped Query Attention (num_query_heads != num_kv_heads). + * Does not support shared key and value projection. Use + QueryAndSharedKeyValueProjection. + """ + + # Kernel initializer for the query projection. + # The variable shape is [input_dimension, num_heads, units_per_head]. + q_kernel_init: Callable | None = None + + # Bias initializer for the query projection. + # The variable shape is [num_heads, units_per_head]. + q_bias_init: Callable | None = None + + # Kernel initializer for the key/value projection. + # The variable shape is [input_dimension, 2, num_kv_heads, units_per_head]. + kv_kernel_init: Callable | None = None + + # Bias initializer for the key/value projection. + # The variable shape is [2, num_kv_heads, units_per_head]. 
+ kv_bias_init: Callable | None = None + + +@dataclasses.dataclass(frozen=True) +class QueryAndSharedKeyValueProjection(QueryKeyValueProjectionConfig): + """Use separate query and shared key/value projection matrices. + + * Supports Grouped Query Attention (num_query_heads != num_kv_heads). + * Requires shared key and value projection. + """ + + # Kernel initializer for the query projection. + # The variable shape is [input_dimension, num_heads, units_per_head]. + q_kernel_init: Callable | None = None + + # Bias initializer for the query projection. + # The variable shape is [num_heads, units_per_head]. + q_bias_init: Callable | None = None + + # Kernel initializer for the shared key/value projection. + # The variable shape is [input_dimension, num_kv_heads, units_per_head]. + kv_kernel_init: Callable | None = None + + # Bias initializer for the shared key/value projection. + # The variable shape is [num_kv_heads, units_per_head]. + kv_bias_init: Callable | None = None diff --git a/sequence_layers/mlx/signal.py b/sequence_layers/mlx/signal.py new file mode 100644 index 0000000..3452559 --- /dev/null +++ b/sequence_layers/mlx/signal.py @@ -0,0 +1,62 @@ +"""Signal utilities for MLX, ported from sequence_layers.jax.signal.""" + +import numpy as np +import mlx.core as mx + + +def _raised_cosine_window(window_length, periodic, dtype, a, b): + """Computes a raised cosine window.""" + if window_length == 1: + return np.ones([1], dtype=dtype) + even = 1 - window_length % 2 + n = np.asarray(window_length + int(periodic) * even - 1, dtype=dtype) + count = np.arange(window_length, dtype=dtype) + cos_arg = 2 * np.pi * count / n + return a - b * np.cos(cos_arg) + + +def hann_window(window_length, periodic=True, dtype=np.float32): + """Computes a hann window. 
Ported from tf.signal.""" + return _raised_cosine_window(window_length, periodic, dtype, 0.5, 0.5) + + +def hamming_window(window_length, periodic=True, dtype=np.float32): + """Computes a Hamming window.""" + a0 = 0.54 + return _raised_cosine_window(window_length, periodic, dtype, a0, 1.0 - a0) + + +def inverse_stft_window_fn(frame_step, forward_window_fn=hann_window): + """Generates a window function that can be used in inverse STFT. + + Constructs a window that is equal to the forward window with a further + pointwise amplitude correction. + + Args: + frame_step: The number of samples to step. + forward_window_fn: Window function used in the forward STFT transform. + + Returns: + A callable that takes a window length and a dtype keyword argument and + returns a [window_length] array of window samples. + """ + + def inverse_stft_window_fn_inner(frame_length, dtype=np.float32): + """Computes a window suitable for inverse STFT reconstruction.""" + # Use equation 7 from Griffin + Lim. + forward_window = forward_window_fn(frame_length, dtype=dtype) + # Convert to mx array for computation. + fw = mx.array(forward_window, dtype=mx.float32) + denom = mx.square(fw) + overlaps = -(-frame_length // frame_step) # Ceiling division. + denom = mx.pad(denom, [(0, overlaps * frame_step - frame_length)]) + denom = mx.reshape(denom, [overlaps, frame_step]) + denom = mx.sum(denom, axis=0, keepdims=True) + denom = mx.tile(denom, [overlaps, 1]) + denom = mx.reshape(denom, [overlaps * frame_step]) + denom = denom[:frame_length] + result = mx.where(denom == 0.0, 0, fw / denom) + # Convert back to numpy for consistency with the forward window. 
+ return np.array(result, dtype=dtype) + + return inverse_stft_window_fn_inner diff --git a/sequence_layers/mlx/simple.py b/sequence_layers/mlx/simple.py index 27bbadc..830df0b 100644 --- a/sequence_layers/mlx/simple.py +++ b/sequence_layers/mlx/simple.py @@ -1,7 +1,10 @@ """Simple sequence layers for MLX.""" +import dataclasses import math +from typing import Callable + import mlx.core as mx import mlx.nn as nn import numpy as np @@ -9,6 +12,7 @@ from sequence_layers.mlx import basic_types as bt from sequence_layers.mlx import init_mapping from sequence_layers.mlx import types +from sequence_layers.jax.types import SequenceLayerConfig as _SequenceLayerConfig Sequence = bt.Sequence MaskedSequence = bt.MaskedSequence @@ -22,6 +26,13 @@ class Identity(types.PreservesType, types.StatelessPointwise): """Identity pass-through of the input.""" + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + name: str | None = None + + def make(self) -> 'Identity': + return Identity.from_config(self) + @types.check_layer def layer(self, x, *, constants=None): return x @@ -133,6 +144,14 @@ def from_config(cls, config): class Elu(types.PreservesType, types.StatelessPointwiseFunctor): """An ELU activation layer.""" + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + alpha: complex = 1.0 + name: str | None = None + + def make(self) -> 'Elu': + return Elu.from_config(self) + def __init__(self, alpha=1.0): super().__init__() self._alpha = alpha @@ -197,6 +216,14 @@ def from_config(cls, config): class Cast(types.StatelessPointwiseFunctor): """Cast input values to the specified type.""" + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + dtype: object = mx.float32 + name: str | None = None + + def make(self) -> 'Cast': + return Cast.from_config(self) + def __init__(self, dtype): super().__init__() self._dtype = dtype @@ -221,6 +248,14 @@ def from_config(cls, config): class Scale(types.PreservesType, 
types.StatelessPointwise): """Scales the input by a provided constant or array.""" + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + scale: object = 1.0 + name: str | None = None + + def make(self) -> 'Scale': + return Scale.from_config(self) + def __init__(self, scale): super().__init__() if isinstance(scale, (int, float, complex)): @@ -302,6 +337,15 @@ def from_config(cls, config): class GatedUnit(types.PreservesType, types.Stateless): """Computes a generalized Gated Unit, reducing input channels by 2x.""" + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + feature_activation: Callable | None = None + gate_activation: Callable | None = None + name: str | None = None + + def make(self) -> 'GatedUnit': + return GatedUnit.from_config(self) + def __init__(self, feature_activation=None, gate_activation=None): super().__init__() self._feature_activation = feature_activation @@ -366,11 +410,14 @@ def from_config(cls, config): class Flatten(types.PreservesType, types.StatelessPointwise): - """Flattens the channel dimensions of the input sequence. + """Flattens the channel dimensions of the input sequence.""" - An input sequence with shape [batch_size, time, ...] is reshaped to - [batch_size, time, prod(...)]. The mask is unchanged. - """ + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + name: str | None = None + + def make(self) -> 'Flatten': + return Flatten.from_config(self) def get_output_shape(self, input_shape, *, constants=None): return (math.prod(input_shape),) @@ -392,6 +439,17 @@ def from_config(cls, config): class Reshape(types.PreservesType, types.Stateless): """Reshapes the channels dimension of the input.""" + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + output_shape: tuple[int, ...] 
= () + name: str | None = None + + def __post_init__(self): + object.__setattr__(self, 'output_shape', tuple(self.output_shape)) + + def make(self) -> 'Reshape': + return Reshape.from_config(self) + def __init__(self, output_shape): super().__init__() self._output_shape = tuple(output_shape) @@ -426,6 +484,18 @@ def from_config(cls, config): class ExpandDims(types.PreservesType, types.Stateless): """Expands channel dimensions of the input sequence.""" + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + axis: int | tuple[int, ...] = 0 + name: str | None = None + + def __post_init__(self): + if not isinstance(self.axis, int): + object.__setattr__(self, 'axis', tuple(self.axis)) + + def make(self) -> 'ExpandDims': + return ExpandDims.from_config(self) + def __init__(self, axis): super().__init__() if isinstance(axis, int): @@ -585,6 +655,17 @@ class Embedding(types.Stateless): Backed by mlx.nn.Embedding. """ + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + num_embeddings: int = 1 + dimension: int = 1 + compute_dtype: types.DType | None = None + param_dtype: types.DType = mx.float32 + name: str | None = None + + def make(self) -> 'Embedding': + return Embedding.from_config(self) + def __init__( self, *, @@ -641,6 +722,15 @@ def from_config(cls, config): class Dropout(types.PreservesType, types.StatelessPointwise): """Dropout layer (pass-through during inference).""" + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + rate: float = 0.0 + broadcast_dims: tuple[int, ...] 
= () + name: str | None = None + + def make(self) -> 'Dropout': + return Dropout.from_config(self) + def __init__(self, rate=0.0): super().__init__() self._rate = rate @@ -726,6 +816,14 @@ def from_config(cls, config): class CheckpointName(types.PreservesType, types.StatelessPointwiseFunctor): """Identity pass-through (checkpoint naming is JAX-only).""" + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + checkpoint_name: str = '' + name: str | None = None + + def make(self) -> 'CheckpointName': + return CheckpointName.from_config(self) + def __init__(self, checkpoint_name=''): super().__init__() self._checkpoint_name = checkpoint_name @@ -750,16 +848,62 @@ def from_config(cls, config): class Lambda(types.Stateless): """A SequenceLayer that wraps a Python callable.""" - def __init__(self, fn, *, sequence_input=False, mask_required=True): + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + fn: Callable = None + sequence_input: bool = False + mask_required: bool = True + # Accepted for JAX compatibility but ignored by MLX Lambda. 
+ expected_input_spec: object = None + expected_output_spec: object = None + name: str | None = None + + def make(self) -> 'Lambda': + return Lambda.from_config(self) + + def __init__(self, fn, *, sequence_input=False, mask_required=True, + expected_output_spec=None): super().__init__() self._fn = fn self._sequence_input = sequence_input self._mask_required = mask_required + self._expected_output_spec = expected_output_spec + self._cached_output_spec = None + + def _probe_output(self, input_shape, input_dtype): + """Probe the function with a dummy to infer output shape/dtype.""" + if self._expected_output_spec is not None: + return self._expected_output_spec + if self._cached_output_spec is not None: + return self._cached_output_spec + try: + dummy_values = mx.zeros((1, 1) + tuple(input_shape), dtype=input_dtype) + dummy_mask = mx.ones((1, 1), dtype=mx.bool_) + if self._sequence_input: + result = self._fn(Sequence(dummy_values, dummy_mask)) + out_shape = result.values.shape[2:] + out_dtype = result.values.dtype + else: + out_values = self._fn(dummy_values) + out_shape = out_values.shape[2:] + out_dtype = out_values.dtype + self._cached_output_spec = bt.ShapeDType(out_shape, out_dtype) + return self._cached_output_spec + except Exception: + return None def get_output_shape(self, input_shape, *, constants=None): + spec = self._probe_output(input_shape, mx.float32) + if spec is not None: + return tuple(spec.shape) return tuple(input_shape) - @types.check_layer + def get_output_dtype(self, input_dtype, *, constants=None): + spec = self._probe_output((1,), input_dtype) + if spec is not None: + return spec.dtype + return input_dtype + def layer(self, x, *, constants=None): if self._sequence_input: result = self._fn(x) @@ -781,6 +925,7 @@ def from_config(cls, config): fn=config.fn, sequence_input=config.sequence_input, mask_required=config.mask_required, + expected_output_spec=getattr(config, 'expected_output_spec', None), ) @@ -792,6 +937,15 @@ def from_config(cls, 
config): class Logging(types.PreservesType, types.StatelessPointwise): """Logs input info and returns the input unchanged.""" + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + prefix: str = '' + dump_tensors: bool = False + name: str | None = None + + def make(self) -> 'Logging': + return Logging.from_config(self) + def __init__(self, prefix='', dump_tensors=False): super().__init__() self._prefix = prefix diff --git a/sequence_layers/mlx/types.py b/sequence_layers/mlx/types.py index e5ffd0d..7ea1b72 100644 --- a/sequence_layers/mlx/types.py +++ b/sequence_layers/mlx/types.py @@ -354,7 +354,7 @@ def check_layer(layer_fn): """Validates layer inputs and outputs.""" @functools.wraps(layer_fn) - def wrapper(self, x, *, constants=None): + def wrapper(self, x, *, constants=None, **kwargs): y = layer_fn(self, x, constants=constants) _check_output_spec(self, x, y, constants) return y @@ -366,7 +366,7 @@ def check_step(step_fn): """Validates step inputs and outputs.""" @functools.wraps(step_fn) - def wrapper(self, x, state, *, constants=None): + def wrapper(self, x, state, *, constants=None, **kwargs): if not self.supports_step: raise ValueError(f'{self.__class__.__name__} does not support step().') block_size = self.block_size @@ -453,7 +453,7 @@ def layer( """Process this layer layer-wise.""" def layer_with_emits( - self, x: Sequence, *, constants: Constants | None = None + self, x: Sequence, *, constants: Constants | None = None, **kwargs ) -> tuple[Sequence, Emits]: return self.layer(x, constants=constants), () @@ -474,6 +474,7 @@ def step_with_emits( state: State, *, constants: Constants | None = None, + **kwargs, ) -> tuple[Sequence, State, Emits]: y, state = self.step(x, state, constants=constants) return y, state, () @@ -583,6 +584,7 @@ def get_initial_state( input_spec: ChannelSpec, *, constants: Constants | None = None, + **kwargs, ) -> State: return () @@ -592,6 +594,7 @@ def step( state: State, *, constants: Constants | None = None, + 
**kwargs, ) -> tuple[Sequence, State]: return self.layer(x, constants=constants), state @@ -639,6 +642,7 @@ def step( state: State, *, constants: Constants | None = None, + **kwargs, ) -> tuple[Sequence, State]: y, state, _ = self.step_with_emits(x, state, constants=constants) return y, state @@ -654,7 +658,7 @@ def step_with_emits( pass def layer( - self, x: Sequence, *, constants: Constants | None = None + self, x: Sequence, *, constants: Constants | None = None, **kwargs ) -> Sequence: y, _ = self.layer_with_emits(x, constants=constants) return y @@ -675,6 +679,7 @@ def step_with_emits( state: State, *, constants: Constants | None = None, + **kwargs, ) -> tuple[Sequence, State, Emits]: y, emits = self.layer_with_emits(x, constants=constants) return y, state, emits @@ -685,5 +690,6 @@ def get_initial_state( input_spec: ChannelSpec, *, constants: Constants | None = None, + **kwargs, ) -> State: return () diff --git a/sequence_layers/mlx/typing.py b/sequence_layers/mlx/typing.py new file mode 100644 index 0000000..32aaa4d --- /dev/null +++ b/sequence_layers/mlx/typing.py @@ -0,0 +1,43 @@ +"""Lightweight typing utilities for MLX sequence layers. + +Provides type annotation helpers compatible with the jaxtyping-style API +used in the JAX version, but without JAX dependencies. Since runtime type +checking is disabled, these are purely for documentation and IDE support. +""" + +from typing import Any, Callable, TypeVar + +import mlx.core as mx +import numpy as np + +try: + from jaxtyping import Float, Int, Shaped, PyTree +except ImportError: + # Fallback: define no-op type aliases if jaxtyping is not available. 
+  Float = Any
+  Int = Any
+  Shaped = Any
+  PyTree = Any
+
+
+class _MetaArrayT(type):
+  types = ()
+
+  def __instancecheck__(cls, obj):
+    return isinstance(obj, cls.types)
+
+
+class ArrayT(metaclass=_MetaArrayT):
+  types = (mx.array, np.ndarray)
+
+
+ScalarInt = Any
+ScalarFloat = Any
+AnyPyTree = Any
+
+_F = TypeVar('_F', bound=Callable)
+
+
+def typed(function: _F) -> _F:
+  """No-op decorator for type-checked functions (runtime checking disabled)."""
+  return function
diff --git a/sequence_layers/mlx/utils.py b/sequence_layers/mlx/utils.py
new file mode 100644
index 0000000..36d2a6a
--- /dev/null
+++ b/sequence_layers/mlx/utils.py
@@ -0,0 +1,72 @@
+"""Utility functions for MLX sequence layers."""
+
+import fractions
+
+from sequence_layers.mlx.combinators import CombinationMode
+
+
+def get_output_latency(config, accumulated_output_latency=0):
+  """Returns the output latency of the provided SequenceLayerConfig.
+
+  In MLX, we can simply instantiate the layer and compute the latency
+  directly without needing JAX's eval_shape.
+
+  Args:
+    config: A SequenceLayerConfig to compute output latency for.
+    accumulated_output_latency: The accumulated output latency of preceding
+      layers. Defaults to 0.
+
+  Returns:
+    The output latency of the layer.
+  """
+  layer = config.make()
+  return _get_accumulated_output_latency(layer, accumulated_output_latency)
+
+
+def _get_accumulated_output_latency(layer, output_latency):
+  """Computes accumulated output latency for a layer.
+
+  Mirrors SequenceLayer.get_accumulated_output_latency from JAX types.
+  """
+  # Check for Serial-like combinators that chain layers.
+  if hasattr(layer, 'layers') and isinstance(layer.layers, (list, tuple)):
+    for sub in layer.layers:
+      output_latency = _get_accumulated_output_latency(sub, output_latency)
+    return output_latency
+
+  # Guard on the attribute that is actually read (`body`), not a private name.
+  if hasattr(layer, 'body'):
+    return _get_accumulated_output_latency(layer.body, output_latency)
+
+  # Wrapper layers: recurse via the same accessor the guard tests.
+  if hasattr(layer, 'inner') and layer.inner is not None:
+    return _get_accumulated_output_latency(layer.inner, output_latency)
+  if hasattr(layer, 'child'):
+    return _get_accumulated_output_latency(layer.child, output_latency)
+
+  # Single layer: compute latency.
+  output_ratio = layer.output_ratio
+  return int(output_latency * output_ratio) + layer.output_latency
+
+
+def get_required_stepwise_delay(output_ratio, input_latency):
+  """Returns the delay required so input_latency is divisible by 1/output_ratio.
+
+  When combining upsampling and downsampling layers with latency,
+  layer/step equivalence requires inserting delays. This function returns the
+  correct amount of step-wise delay to insert.
+
+  Args:
+    output_ratio: The output ratio of the layer (a fractions.Fraction).
+    input_latency: The accumulated input latency of layers preceding the layer.
+
+  Returns:
+    The amount of delay required to ensure input latency is divisible by
+    output_ratio.
+  """
+  if 1 not in output_ratio.as_integer_ratio():
+    raise NotImplementedError(
+        'get_required_stepwise_delay expects integer upsampling or'
+        f' downsampling, got {output_ratio=}'
+    )
+  return int(-input_latency % (1 / output_ratio))
diff --git a/sequence_layers/mlx/weight_converter.py b/sequence_layers/mlx/weight_converter.py
index 46682e8..4362be0 100644
--- a/sequence_layers/mlx/weight_converter.py
+++ b/sequence_layers/mlx/weight_converter.py
@@ -236,7 +236,7 @@ def _slice_params(params, index):
 def _load_residual(mlx_residual, linen_params, config, batch_stats=None):
   """Load Residual: body is layers_{i}, shortcut is shortcut_layer."""
   # Body is a Serial inside the Residual.
- body = mlx_residual._body + body = mlx_residual.body for i, layer_config in enumerate(config.layers): key = f'layers_{i}' child_params = linen_params.get(key, {}) @@ -256,7 +256,7 @@ def _load_residual(mlx_residual, linen_params, config, batch_stats=None): sc_key = f'layers_{i}' sc_bs = shortcut_bs.get(sc_key, {}) if shortcut_bs else None _load_config( - mlx_residual._shortcut, + mlx_residual.shortcut, shortcut_params.get(sc_key, {}), sc_config, batch_stats=sc_bs, @@ -267,8 +267,8 @@ def _load_dense(mlx_dense, linen_params, config): """Load Dense: transpose kernel [in, out] → [out, in].""" # Handle DenseDeferred wrapper. inner = mlx_dense - if hasattr(inner, '_inner') and inner._inner is not None: - inner = inner._inner + if hasattr(inner, 'inner') and inner.inner is not None: + inner = inner.inner kernel = linen_params.get('kernel') if kernel is not None: @@ -304,15 +304,16 @@ def _load_attention(mlx_attn, linen_params, config): value_projection/kernel [in, kv_heads, uph] """ from sequence_layers.jax.attention import common as attn_common + from sequence_layers.mlx import projection_configs as mlx_proj # Handle Deferred wrapper. inner = mlx_attn - if hasattr(inner, '_inner') and inner._inner is not None: - inner = inner._inner + if hasattr(inner, 'inner') and inner.inner is not None: + inner = inner.inner input_projection = config.input_projection - if isinstance(input_projection, attn_common.CombinedQueryKeyValueProjection): + if isinstance(input_projection, (attn_common.CombinedQueryKeyValueProjection, mlx_proj.CombinedQueryKeyValueProjection)): # Combined QKV: kernel [in, 3, heads, uph] → separate q/k/v. 
qkv_params = linen_params.get('query_key_value_projection', {}) combined_kernel = qkv_params.get('kernel') @@ -331,7 +332,7 @@ def _load_attention(mlx_attn, linen_params, config): inner.v_bias = mx.array(vb.reshape(-1)) elif isinstance( - input_projection, attn_common.SeparateQueryKeyValueProjection + input_projection, (attn_common.SeparateQueryKeyValueProjection, mlx_proj.SeparateQueryKeyValueProjection) ): # Separate Q/K/V projections (used for GQA where num_kv_heads < num_heads). q_params = linen_params.get('query_projection', {}) @@ -361,6 +362,11 @@ def _load_attention(mlx_attn, linen_params, config): if v_bias is not None: inner.v_bias = mx.array(v_bias.reshape(-1)) + # per_dim_scale: learned [units_per_head] query scale. + per_dim_scale = linen_params.get('per_dim_scale') + if per_dim_scale is not None: + inner._per_dim_scale = mx.array(per_dim_scale) + # Q/K/V processing networks have no trainable params # (RoPE is stateless with no learned weights). @@ -381,11 +387,12 @@ def _load_streaming_attention(mlx_attn, linen_params, config): shared_key_value_projection/kernel [source, heads, uph] """ from sequence_layers.jax.attention import common as attn_common + from sequence_layers.mlx import projection_configs as mlx_proj # Handle Deferred wrapper. inner = mlx_attn - if hasattr(inner, '_inner') and inner._inner is not None: - inner = inner._inner + if hasattr(inner, 'inner') and inner.inner is not None: + inner = inner.inner input_projection = config.input_projection @@ -400,7 +407,7 @@ def _load_streaming_attention(mlx_attn, linen_params, config): if q_bias is not None: inner.q_bias = mx.array(q_bias.reshape(-1)) - if isinstance(input_projection, attn_common.QueryAndKeyValueProjection): + if isinstance(input_projection, (attn_common.QueryAndKeyValueProjection, mlx_proj.QueryAndKeyValueProjection)): # Combined KV: kernel [source, 2, heads, uph] → split into K, V. 
kv_params = linen_params.get('key_value_projection', {}) kv_kernel = kv_params.get('kernel') @@ -417,7 +424,7 @@ def _load_streaming_attention(mlx_attn, linen_params, config): inner.v_bias = mx.array(vb.reshape(-1)) elif isinstance( - input_projection, attn_common.SeparateQueryKeyValueProjection + input_projection, (attn_common.SeparateQueryKeyValueProjection, mlx_proj.SeparateQueryKeyValueProjection) ): # Separate K and V projections. k_params = linen_params.get('key_projection', {}) @@ -439,7 +446,7 @@ def _load_streaming_attention(mlx_attn, linen_params, config): inner.v_bias = mx.array(v_bias.reshape(-1)) elif isinstance( - input_projection, attn_common.QueryAndSharedKeyValueProjection + input_projection, (attn_common.QueryAndSharedKeyValueProjection, mlx_proj.QueryAndSharedKeyValueProjection) ): # Shared K/V projection: same weights for both K and V. shared_params = linen_params.get('shared_key_value_projection', {}) @@ -455,6 +462,11 @@ def _load_streaming_attention(mlx_attn, linen_params, config): inner.k_bias = b inner.v_bias = b + # per_dim_scale: learned [units_per_head] query scale. + per_dim_scale = linen_params.get('per_dim_scale') + if per_dim_scale is not None: + inner._per_dim_scale = mx.array(per_dim_scale) + def _load_rms_norm(mlx_norm, linen_params, config): """Load RMSNormalization: scale [dim] → same.""" @@ -523,8 +535,8 @@ def _load_group_norm(mlx_gn, linen_params, config): def _load_conv1d(mlx_conv, linen_params, config): """Load Conv1D: kernel [k, in, out] → [out, k, in].""" inner = mlx_conv - if hasattr(inner, '_inner') and inner._inner is not None: - inner = inner._inner + if hasattr(inner, 'inner') and inner.inner is not None: + inner = inner.inner kernel = linen_params.get('kernel') if kernel is not None: @@ -548,8 +560,8 @@ def _load_conv1d_transpose(mlx_conv, linen_params, config): conv_transpose1d which reverses the kernel direction. 
""" inner = mlx_conv - if hasattr(inner, '_inner') and inner._inner is not None: - inner = inner._inner + if hasattr(inner, 'inner') and inner.inner is not None: + inner = inner.inner kernel = linen_params.get('kernel') if kernel is not None: From 6e9db41486fe09bba176dc21749307a9e253b049 Mon Sep 17 00:00:00 2001 From: David Braun <2096055+DBraun@users.noreply.github.com> Date: Mon, 2 Mar 2026 12:27:10 -0500 Subject: [PATCH 05/17] more efficient attention --- sequence_layers/mlx/attention.py | 49 +++++++++++++++++-------- sequence_layers/mlx/weight_converter.py | 24 +++++++----- 2 files changed, 49 insertions(+), 24 deletions(-) diff --git a/sequence_layers/mlx/attention.py b/sequence_layers/mlx/attention.py index 82cbdcc..1ae73f1 100644 --- a/sequence_layers/mlx/attention.py +++ b/sequence_layers/mlx/attention.py @@ -127,6 +127,7 @@ def __init__( value_network: types.SequenceLayer | None = None, attention_logits_soft_cap: float | None = None, num_sink_embeddings: int = 0, + input_projection=None, ): super().__init__() if num_kv_heads is None: @@ -170,14 +171,20 @@ def __init__( q_dim = num_heads * units_per_head kv_dim = num_kv_heads * units_per_head - # Projections stored as [in, out] to match Linen convention. 
- self.q_proj = kernel_init(key, (in_features, q_dim), param_dtype) - self.k_proj = kernel_init(key, (in_features, kv_dim), param_dtype) - self.v_proj = kernel_init(key, (in_features, kv_dim), param_dtype) - if use_bias: - self.q_bias = bias_init(key, (q_dim,), param_dtype) - self.k_bias = bias_init(key, (kv_dim,), param_dtype) - self.v_bias = bias_init(key, (kv_dim,), param_dtype) + self.input_projection = input_projection + if isinstance(input_projection, projection_configs.CombinedQueryKeyValueProjection) and self.num_kv_heads == self.num_heads: + out_dim = q_dim + 2 * kv_dim + self.qkv_proj = kernel_init(key, (in_features, out_dim), param_dtype) + if use_bias: + self.qkv_bias = bias_init(key, (out_dim,), param_dtype) + else: + self.q_proj = kernel_init(key, (in_features, q_dim), param_dtype) + self.k_proj = kernel_init(key, (in_features, kv_dim), param_dtype) + self.v_proj = kernel_init(key, (in_features, kv_dim), param_dtype) + if use_bias: + self.q_bias = bias_init(key, (q_dim,), param_dtype) + self.k_bias = bias_init(key, (kv_dim,), param_dtype) + self.v_bias = bias_init(key, (kv_dim,), param_dtype) # Attention sink embeddings. 
self.num_sink_embeddings = num_sink_embeddings @@ -210,14 +217,22 @@ def _project_qkv(self, x): dtype = self.compute_dtype or x.dtype v = x.values.astype(dtype) - q = mx.matmul(v, self.q_proj.astype(dtype)) - k = mx.matmul(v, self.k_proj.astype(dtype)) - val = mx.matmul(v, self.v_proj.astype(dtype)) + + if hasattr(self, 'qkv_proj'): + qkv = mx.matmul(v, self.qkv_proj.astype(dtype)) + if self.use_bias: + qkv = qkv + self.qkv_bias.astype(dtype) + + q, k, val = mx.split(qkv, 3, axis=-1) + else: + q = mx.matmul(v, self.q_proj.astype(dtype)) + k = mx.matmul(v, self.k_proj.astype(dtype)) + val = mx.matmul(v, self.v_proj.astype(dtype)) - if self.use_bias: - q = q + self.q_bias.astype(dtype) - k = k + self.k_bias.astype(dtype) - val = val + self.v_bias.astype(dtype) + if self.use_bias: + q = q + self.q_bias.astype(dtype) + k = k + self.k_bias.astype(dtype) + val = val + self.v_bias.astype(dtype) # Reshape to [b, t, heads, units_per_head]. q = q.reshape(b, t, self.num_heads, self.units_per_head) @@ -593,6 +608,7 @@ def _ensure_initialized(self, in_features, backend='mlx'): key_network=key_network, value_network=value_network, num_sink_embeddings=getattr(self._config, 'num_sink_embeddings', 0), + input_projection=getattr(self._config, 'input_projection', None), ) @property @@ -1097,6 +1113,7 @@ def __init__( key_network: types.SequenceLayer | None = None, value_network: types.SequenceLayer | None = None, num_sink_embeddings: int = 0, + input_projection=None, ): super().__init__() if max_past_horizon < 1: @@ -1544,6 +1561,7 @@ def _ensure_initialized(self, in_features, source_features, backend='mlx'): key_network=key_network, value_network=value_network, num_sink_embeddings=getattr(self._config, 'num_sink_embeddings', 0), + input_projection=getattr(self._config, 'input_projection', None), ) def _get_source(self, constants): @@ -1701,6 +1719,7 @@ def _ensure_initialized(self, in_features, backend='mlx'): key_network=key_network, value_network=value_network, 
num_sink_embeddings=getattr(self._config, 'num_sink_embeddings', 0), + input_projection=getattr(self._config, 'input_projection', None), ) @property diff --git a/sequence_layers/mlx/weight_converter.py b/sequence_layers/mlx/weight_converter.py index 4362be0..93c1057 100644 --- a/sequence_layers/mlx/weight_converter.py +++ b/sequence_layers/mlx/weight_converter.py @@ -314,22 +314,28 @@ def _load_attention(mlx_attn, linen_params, config): input_projection = config.input_projection if isinstance(input_projection, (attn_common.CombinedQueryKeyValueProjection, mlx_proj.CombinedQueryKeyValueProjection)): - # Combined QKV: kernel [in, 3, heads, uph] → separate q/k/v. + # Combined QKV: kernel [in, 3, heads, uph] qkv_params = linen_params.get('query_key_value_projection', {}) combined_kernel = qkv_params.get('kernel') if combined_kernel is not None: in_features = combined_kernel.shape[0] - q, k, v = np.split(combined_kernel, 3, axis=1) - inner.q_proj = mx.array(q.reshape(in_features, -1)) - inner.k_proj = mx.array(k.reshape(in_features, -1)) - inner.v_proj = mx.array(v.reshape(in_features, -1)) + if hasattr(inner, 'qkv_proj'): + inner.qkv_proj = mx.array(combined_kernel.reshape(in_features, -1)) + else: + q, k, v = np.split(combined_kernel, 3, axis=1) + inner.q_proj = mx.array(q.reshape(in_features, -1)) + inner.k_proj = mx.array(k.reshape(in_features, -1)) + inner.v_proj = mx.array(v.reshape(in_features, -1)) combined_bias = qkv_params.get('bias') if combined_bias is not None: - qb, kb, vb = np.split(combined_bias, 3, axis=0) - inner.q_bias = mx.array(qb.reshape(-1)) - inner.k_bias = mx.array(kb.reshape(-1)) - inner.v_bias = mx.array(vb.reshape(-1)) + if hasattr(inner, 'qkv_bias'): + inner.qkv_bias = mx.array(combined_bias.reshape(-1)) + else: + qb, kb, vb = np.split(combined_bias, 3, axis=0) + inner.q_bias = mx.array(qb.reshape(-1)) + inner.k_bias = mx.array(kb.reshape(-1)) + inner.v_bias = mx.array(vb.reshape(-1)) elif isinstance( input_projection, 
(attn_common.SeparateQueryKeyValueProjection, mlx_proj.SeparateQueryKeyValueProjection) From 0666d8c81d68eb70e2d9486c89fb236910a0d4a5 Mon Sep 17 00:00:00 2001 From: David Braun <2096055+DBraun@users.noreply.github.com> Date: Mon, 2 Mar 2026 12:54:30 -0500 Subject: [PATCH 06/17] Update attention.py use mx.fast.scaled_dot_product_attention --- sequence_layers/mlx/attention.py | 46 ++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/sequence_layers/mlx/attention.py b/sequence_layers/mlx/attention.py index 1ae73f1..e4b9791 100644 --- a/sequence_layers/mlx/attention.py +++ b/sequence_layers/mlx/attention.py @@ -257,6 +257,25 @@ def _compute_attention(self, queries, keys, values, mask): Returns: context: [b, q_t, num_heads, units_per_head] """ + can_use_fast_sdpa = ( + self.sink_key_embeddings is None and + getattr(self, '_attention_logits_soft_cap', None) is None + ) + + if can_use_fast_sdpa: + # Fast path: GQA is natively supported, so we do not repeat keys/values. + q = mx.transpose(queries, (0, 2, 1, 3)) + k = mx.transpose(keys, (0, 2, 1, 3)) + v = mx.transpose(values, (0, 2, 1, 3)) + + q = _scale_queries( + q, self._per_dim_scale, self._query_scale, self.units_per_head + ) + + # Use mx.fast.scaled_dot_product_attention with scale=1.0 since we pre-scaled + context = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) + return mx.transpose(context, (0, 2, 1, 3)) + # GQA: repeat K/V heads to match query heads. num_groups = self.num_heads // self.num_kv_heads if num_groups > 1: @@ -308,7 +327,7 @@ def _compute_attention(self, queries, keys, values, mask): mask = mx.concatenate([sink_mask, mask], axis=-1) # Optional soft cap on logits (e.g., Gemma 2 uses cap=50.0). 
- if self._attention_logits_soft_cap is not None: + if getattr(self, '_attention_logits_soft_cap', None) is not None: cap = self._attention_logits_soft_cap logits = mx.tanh(logits / cap) * cap @@ -787,16 +806,9 @@ def _compute_attention(self, queries, keys, values, mask): q = _scale_queries( q, self._per_dim_scale, self._query_scale, self.units_per_head ) - logits = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) - - if mask is not None: - large_neg = mx.array(-1e9, dtype=logits.dtype) - logits = mx.where(mask, logits, large_neg) - - weights = mx.softmax(logits, axis=-1) - context = mx.matmul(weights, v) - context = mx.transpose(context, (0, 2, 1, 3)) - return context + + context = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) + return mx.transpose(context, (0, 2, 1, 3)) def get_output_shape(self, input_shape, *, constants=None): if len(input_shape) != 1: @@ -1221,6 +1233,18 @@ def _get_source(self, constants): def _compute_attention(self, queries, keys, values, mask): """Compute scaled dot-product attention.""" + if self.sink_key_embeddings is None: + q = mx.transpose(queries, (0, 2, 1, 3)) + k = mx.transpose(keys, (0, 2, 1, 3)) + v = mx.transpose(values, (0, 2, 1, 3)) + + q = _scale_queries( + q, self._per_dim_scale, self._query_scale, self.units_per_head + ) + + context = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) + return mx.transpose(context, (0, 2, 1, 3)) + q = mx.transpose(queries, (0, 2, 1, 3)) k = mx.transpose(keys, (0, 2, 1, 3)) v = mx.transpose(values, (0, 2, 1, 3)) From b246e661471e7ac323d7862b3090c0280b8bf935 Mon Sep 17 00:00:00 2001 From: David Braun <2096055+DBraun@users.noreply.github.com> Date: Mon, 2 Mar 2026 12:54:50 -0500 Subject: [PATCH 07/17] Update position.py use mx.fast.rope --- sequence_layers/mlx/position.py | 113 ++++++++++++++++++++------------ 1 file changed, 71 insertions(+), 42 deletions(-) diff --git a/sequence_layers/mlx/position.py b/sequence_layers/mlx/position.py index 32c84ba..5ddb636 
100644 --- a/sequence_layers/mlx/position.py +++ b/sequence_layers/mlx/position.py @@ -45,36 +45,66 @@ def __init__( only_advance_position_for_valid_timesteps ) - def _apply_rope(self, x, positions): - """Apply rotary position encoding to x at given positions.""" + def _apply_rope(self, x, offset): + """Applies rotary position encoding to x with a given temporal offset. + + If the rotation axis is the last dimension (the default for most models), + this method leverages the highly optimized `mx.fast.rope` native C++ operation. + Since `mx.fast.rope` strictly expects the sequence length (time) to be the + second-to-last dimension, we transpose the tensor, apply the rotation, and + transpose it back. If rotation is on an inner axis, it falls back to a + manual trig-based computation. + """ axis = self._axis + x.ndim if self._axis < 0 else self._axis - channel_ndim = x.ndim - 2 - axis_dim = x.shape[axis] - - freq_exponents = ( - 2.0 * mx.arange(axis_dim // 2).astype(mx.float32) / axis_dim - ) - timescale = self.max_wavelength**freq_exponents - - broadcast_shape = [1] * x.ndim - broadcast_shape[axis] = axis_dim // 2 - - # Compute position angles. - positions_f = positions.astype(mx.float32) - radians = positions_f.reshape( - positions_f.shape + (1,) * channel_ndim - ) / timescale.reshape(broadcast_shape) - sin_r = mx.sin(radians) - cos_r = mx.cos(radians) - - # Split input along rotation axis, apply rotation. 
- splits = mx.split(x, 2, axis=axis) - x1, x2 = splits[0], splits[1] - result = mx.concatenate( - [x1 * cos_r - x2 * sin_r, x2 * cos_r + x1 * sin_r], - axis=axis, - ) - return result.astype(x.dtype) + + if axis != x.ndim - 1: + channel_ndim = x.ndim - 2 + axis_dim = x.shape[axis] + + freq_exponents = ( + 2.0 * mx.arange(axis_dim // 2).astype(mx.float32) / axis_dim + ) + timescale = self.max_wavelength**freq_exponents + + broadcast_shape = [1] * x.ndim + broadcast_shape[axis] = axis_dim // 2 + + # Compute position angles using offset + positions = mx.arange(x.shape[1])[None, :] + offset[:, None] + positions_f = positions.astype(mx.float32) + radians = positions_f.reshape( + positions_f.shape + (1,) * channel_ndim + ) / timescale.reshape(broadcast_shape) + sin_r = mx.sin(radians) + cos_r = mx.cos(radians) + + splits = mx.split(x, 2, axis=axis) + x1, x2 = splits[0], splits[1] + result = mx.concatenate( + [x1 * cos_r - x2 * sin_r, x2 * cos_r + x1 * sin_r], + axis=axis, + ) + return result.astype(x.dtype) + + original_axes = list(range(x.ndim)) + if x.ndim >= 3: + transpose_axes = original_axes.copy() + transpose_axes.pop(1) + transpose_axes.insert(-1, 1) + x_t = mx.transpose(x, transpose_axes) + else: + x_t = x + + y_t = mx.fast.rope(x_t, dims=x.shape[-1], traditional=False, base=self.max_wavelength, scale=1.0, offset=offset) + + if x.ndim >= 3: + inv_axes = original_axes.copy() + inv_axes.pop(-2) + inv_axes.insert(1, x.ndim - 2) + y = mx.transpose(y_t, inv_axes) + else: + y = y_t + return y.astype(x.dtype) def get_initial_state(self, batch_size, input_spec, *, constants=None): if self.only_advance_position_for_valid_timesteps: @@ -85,28 +115,27 @@ def get_initial_state(self, batch_size, input_spec, *, constants=None): @types.check_step def step(self, x, state, *, constants=None): x_time = x.shape[1] + if self.only_advance_position_for_valid_timesteps: + # The state stores the last valid position. If initialized to -1, the next + # valid position starts at 0. 
+ offset = mx.maximum(0, state[:, 0] + 1) + + # Update the state to hold the maximum position reached after this step. positions = state + mx.cumsum(x.mask.astype(mx.int32), axis=1) state = positions[:, -1:] else: - positions = state + mx.arange(x_time, dtype=mx.int32) + offset = state[:, 0] state = state + x_time - y = x.apply_values(self._apply_rope, positions) + + y = x.apply_values(self._apply_rope, offset) return y, state @types.check_layer def layer(self, x, *, constants=None): - if self.only_advance_position_for_valid_timesteps: - positions = mx.maximum( - 0, - mx.cumsum(x.mask.astype(mx.int32), axis=1) - 1, - ) - else: - positions = mx.broadcast_to( - mx.arange(x.shape[1], dtype=mx.int32)[None, :], - (x.shape[0], x.shape[1]), - ) - return x.apply_values(self._apply_rope, positions) + # In layer mode, processing starts from time step 0 for all batch elements. + offset = mx.zeros((x.shape[0],), dtype=mx.int32) + return x.apply_values(self._apply_rope, offset) @classmethod def from_config(cls, config): From 9cc56e358476860319a6f12f7d9290963c2e1f9c Mon Sep 17 00:00:00 2001 From: David Braun <2096055+DBraun@users.noreply.github.com> Date: Tue, 3 Mar 2026 17:25:56 -0500 Subject: [PATCH 08/17] add to_quantized --- sequence_layers/mlx/attention.py | 106 +++++++++++++++++++++++ sequence_layers/mlx/dense.py | 139 +++++++++++++++++++++++++++++++ 2 files changed, 245 insertions(+) diff --git a/sequence_layers/mlx/attention.py b/sequence_layers/mlx/attention.py index e4b9791..5b4aae4 100644 --- a/sequence_layers/mlx/attention.py +++ b/sequence_layers/mlx/attention.py @@ -15,6 +15,16 @@ Sequence = bt.Sequence MaskedSequence = bt.MaskedSequence +def _quantized_matmul_proj(x, q_weight, q_scales, q_biases, group_size, bits): + return mx.quantized_matmul( + x, q_weight, + scales=q_scales, + biases=q_biases, + transpose=True, + group_size=group_size, + bits=bits, + ) + def _scale_queries(queries, per_dim_scale, query_scale, units_per_head): """Scale queries, optionally with 
per-dimension learned scale. @@ -561,6 +571,53 @@ def step_with_emits(self, x, state, *, constants=None): ) return Sequence(context, x.mask), new_state, () + def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = 'affine'): + if getattr(self, 'q_proj', None) is None or self.q_proj.shape[0] % group_size != 0: + return self + + self._quant_group_size = group_size + self._quant_bits = bits + + w_q = self.q_proj.T + w_k = self.k_proj.T + w_v = self.v_proj.T + w_qkv = mx.concatenate([w_q, w_k, w_v], axis=0) + self.qkv_proj_qw, self.qkv_proj_qs, self.qkv_proj_qb = mx.quantize(w_qkv, group_size=group_size, bits=bits) + + self.q_proj = None + self.k_proj = None + self.v_proj = None + + def _project_qkv(self, x): + b, t = x.shape[0], x.shape[1] + dtype = self.compute_dtype or x.dtype + v = x.values.astype(dtype) + + qkv = _quantized_matmul_proj(v, self.qkv_proj_qw, self.qkv_proj_qs, self.qkv_proj_qb, self._quant_group_size, self._quant_bits) + + d_q = self.num_heads * self.units_per_head + d_k = self.num_kv_heads * self.units_per_head + q, k, val = mx.split(qkv, [d_q, d_q + d_k], axis=-1) + + if self.use_bias: + q = q + self.q_bias.astype(dtype) + k = k + self.k_bias.astype(dtype) + val = val + self.v_bias.astype(dtype) + + q = q.reshape(b, t, self.num_heads, self.units_per_head) + k = k.reshape(b, t, self.num_kv_heads, self.units_per_head) + val = val.reshape(b, t, self.num_kv_heads, self.units_per_head) + + return ( + Sequence(q, x.mask), + Sequence(k, x.mask), + Sequence(val, x.mask), + ) + + import types + self._project_qkv = types.MethodType(_project_qkv, self) + return self + @classmethod def from_config(cls, config): """Create from a Linen DotProductSelfAttention.Config. 
@@ -1521,6 +1578,55 @@ def step_with_emits(self, x, state, *, constants=None): ) return Sequence(context, queries.mask), new_state, () + def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = 'affine'): + if getattr(self, 'q_proj', None) is None or self.q_proj.shape[0] % group_size != 0: + return self + + self._quant_group_size = group_size + self._quant_bits = bits + + w_q = self.q_proj.T + self.q_proj_qw, self.q_proj_qs, self.q_proj_qb = mx.quantize(w_q, group_size=group_size, bits=bits) + + w_k = self.k_proj.T + w_v = self.v_proj.T + w_kv = mx.concatenate([w_k, w_v], axis=0) + self.kv_proj_qw, self.kv_proj_qs, self.kv_proj_qb = mx.quantize(w_kv, group_size=group_size, bits=bits) + + self.q_proj = None + self.k_proj = None + self.v_proj = None + + def _project_q(self, x): + b, t = x.shape[0], x.shape[1] + dtype = self.compute_dtype or x.dtype + v = x.values.astype(dtype) + q = _quantized_matmul_proj(v, self.q_proj_qw, self.q_proj_qs, self.q_proj_qb, self._quant_group_size, self._quant_bits) + if self.use_bias: + q = q + self.q_bias.astype(dtype) + q = q.reshape(b, t, self.num_heads, self.units_per_head) + return Sequence(q, x.mask) + + def _project_kv(self, source): + b, t = source.shape[0], source.shape[1] + dtype = self.compute_dtype or source.dtype + v = source.values.astype(dtype) + kv = _quantized_matmul_proj(v, self.kv_proj_qw, self.kv_proj_qs, self.kv_proj_qb, self._quant_group_size, self._quant_bits) + d_k = self.num_heads * self.units_per_head + k, val = mx.split(kv, [d_k], axis=-1) + if self.use_bias: + k = k + self.k_bias.astype(dtype) + val = val + self.v_bias.astype(dtype) + k = k.reshape(b, t, self.num_heads, self.units_per_head) + val = val.reshape(b, t, self.num_heads, self.units_per_head) + return Sequence(k, source.mask), Sequence(val, source.mask) + + import types + self._project_q = types.MethodType(_project_q, self) + self._project_kv = types.MethodType(_project_kv, self) + + return self + @classmethod def from_config(cls, 
config): return DeferredStreamingDotProductAttention(config) diff --git a/sequence_layers/mlx/dense.py b/sequence_layers/mlx/dense.py index ca06094..94fd613 100644 --- a/sequence_layers/mlx/dense.py +++ b/sequence_layers/mlx/dense.py @@ -154,6 +154,52 @@ def dense_fn(v): else: return x.apply_values_masked(dense_fn) + + + def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = 'affine'): + if self.kernel is None or self._equation != '...nh,dnh->...d' or (self.kernel.shape[-1] * self.kernel.shape[-2]) % group_size != 0: + return self + + _d, _n, _h = self.kernel.shape + kernel_2d = self.kernel.reshape(_d, _n * _h) + self.q_weight, self.q_scales, self.q_biases = mx.quantize( + kernel_2d, group_size=group_size, bits=bits + ) + self._group_size = group_size + self._bits = bits + self.kernel = None + + def layer(self, x, *, constants=None): + compute_dtype = self.get_output_dtype(x.dtype) + def quantized_einsum_fn(v): + original_shape = v.shape + v_2d = v.reshape(*original_shape[:-2], _n * _h) + v_2d = v_2d.astype(compute_dtype) + y = mx.quantized_matmul( + v_2d, + self.q_weight, + scales=self.q_scales, + biases=self.q_biases, + transpose=True, + group_size=self._group_size, + bits=self._bits, + ) + if self.bias is not None: + y = y + self.bias + if self._activation is not None: + y = self._activation(y) + return y + + if self.bias is not None or self._activation is not None: + return x.apply_values(quantized_einsum_fn) + return x.apply_values_masked(quantized_einsum_fn) + + import types + self.layer = types.MethodType(layer, self) + + return self + + @classmethod def from_config(cls, config): """Create a Dense layer from a Linen Dense.Config.""" @@ -229,6 +275,52 @@ def layer(self, x, *, constants=None): self._ensure_initialized(x.shape[-1]) return self.inner.layer(x, constants=constants) + + + def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = 'affine'): + if self.kernel is None or self._equation != '...nh,dnh->...d' or 
(self.kernel.shape[-1] * self.kernel.shape[-2]) % group_size != 0: + return self + + _d, _n, _h = self.kernel.shape + kernel_2d = self.kernel.reshape(_d, _n * _h) + self.q_weight, self.q_scales, self.q_biases = mx.quantize( + kernel_2d, group_size=group_size, bits=bits + ) + self._group_size = group_size + self._bits = bits + self.kernel = None + + def layer(self, x, *, constants=None): + compute_dtype = self.get_output_dtype(x.dtype) + def quantized_einsum_fn(v): + original_shape = v.shape + v_2d = v.reshape(*original_shape[:-2], _n * _h) + v_2d = v_2d.astype(compute_dtype) + y = mx.quantized_matmul( + v_2d, + self.q_weight, + scales=self.q_scales, + biases=self.q_biases, + transpose=True, + group_size=self._group_size, + bits=self._bits, + ) + if self.bias is not None: + y = y + self.bias + if self._activation is not None: + y = self._activation(y) + return y + + if self.bias is not None or self._activation is not None: + return x.apply_values(quantized_einsum_fn) + return x.apply_values_masked(quantized_einsum_fn) + + import types + self.layer = types.MethodType(layer, self) + + return self + + @classmethod def from_config(cls, config): """Create from a Linen Dense.Config.""" @@ -326,6 +418,52 @@ def einsum_fn(v): return x.apply_values(einsum_fn) return x.apply_values_masked(einsum_fn) + + + def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = 'affine'): + if self.kernel is None or self._equation != '...nh,dnh->...d' or (self.kernel.shape[-1] * self.kernel.shape[-2]) % group_size != 0: + return self + + _d, _n, _h = self.kernel.shape + kernel_2d = self.kernel.reshape(_d, _n * _h) + self.q_weight, self.q_scales, self.q_biases = mx.quantize( + kernel_2d, group_size=group_size, bits=bits + ) + self._group_size = group_size + self._bits = bits + self.kernel = None + + def layer(self, x, *, constants=None): + compute_dtype = self.get_output_dtype(x.dtype) + def quantized_einsum_fn(v): + original_shape = v.shape + v_2d = 
v.reshape(*original_shape[:-2], _n * _h) + v_2d = v_2d.astype(compute_dtype) + y = mx.quantized_matmul( + v_2d, + self.q_weight, + scales=self.q_scales, + biases=self.q_biases, + transpose=True, + group_size=self._group_size, + bits=self._bits, + ) + if self.bias is not None: + y = y + self.bias + if self._activation is not None: + y = self._activation(y) + return y + + if self.bias is not None or self._activation is not None: + return x.apply_values(quantized_einsum_fn) + return x.apply_values_masked(quantized_einsum_fn) + + import types + self.layer = types.MethodType(layer, self) + + return self + + @classmethod def from_config(cls, config): compute_dtype = getattr(config, 'compute_dtype', None) @@ -343,3 +481,4 @@ def from_config(cls, config): # Alias so that sl.Dense.Config(...) works like sl_jax.Dense.Config(...). Dense.Config = DenseDeferred.Config + From 14bd756ad3057240fb492d8fdec898d718979fae Mon Sep 17 00:00:00 2001 From: David Braun <2096055+DBraun@users.noreply.github.com> Date: Thu, 5 Mar 2026 18:03:11 -0500 Subject: [PATCH 09/17] optimizations --- sequence_layers/mlx/attention.py | 106 +++++++++++++--------- sequence_layers/mlx/attention_test.py | 3 +- sequence_layers/mlx/convolution.py | 59 ++++++++++--- sequence_layers/mlx/dsp.py | 113 +++++++++++++++--------- sequence_layers/mlx/weight_converter.py | 85 +++++++++--------- 5 files changed, 224 insertions(+), 142 deletions(-) diff --git a/sequence_layers/mlx/attention.py b/sequence_layers/mlx/attention.py index 5b4aae4..fa412cb 100644 --- a/sequence_layers/mlx/attention.py +++ b/sequence_layers/mlx/attention.py @@ -189,12 +189,17 @@ def __init__( self.qkv_bias = bias_init(key, (out_dim,), param_dtype) else: self.q_proj = kernel_init(key, (in_features, q_dim), param_dtype) - self.k_proj = kernel_init(key, (in_features, kv_dim), param_dtype) - self.v_proj = kernel_init(key, (in_features, kv_dim), param_dtype) + # Combined K+V projection: single matmul + split is faster than two. 
+ self.kv_proj = mx.concatenate([ + kernel_init(key, (in_features, kv_dim), param_dtype), + kernel_init(key, (in_features, kv_dim), param_dtype), + ], axis=-1) if use_bias: self.q_bias = bias_init(key, (q_dim,), param_dtype) - self.k_bias = bias_init(key, (kv_dim,), param_dtype) - self.v_bias = bias_init(key, (kv_dim,), param_dtype) + self.kv_bias = mx.concatenate([ + bias_init(key, (kv_dim,), param_dtype), + bias_init(key, (kv_dim,), param_dtype), + ], axis=-1) # Attention sink embeddings. self.num_sink_embeddings = num_sink_embeddings @@ -236,13 +241,15 @@ def _project_qkv(self, x): q, k, val = mx.split(qkv, 3, axis=-1) else: q = mx.matmul(v, self.q_proj.astype(dtype)) - k = mx.matmul(v, self.k_proj.astype(dtype)) - val = mx.matmul(v, self.v_proj.astype(dtype)) + kv = mx.matmul(v, self.kv_proj.astype(dtype)) + k, val = mx.split(kv, 2, axis=-1) if self.use_bias: q = q + self.q_bias.astype(dtype) - k = k + self.k_bias.astype(dtype) - val = val + self.v_bias.astype(dtype) + kv_bias = self.kv_bias.astype(dtype) + kb, vb = mx.split(kv_bias, 2, axis=-1) + k = k + kb + val = val + vb # Reshape to [b, t, heads, units_per_head]. q = q.reshape(b, t, self.num_heads, self.units_per_head) @@ -579,14 +586,13 @@ def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = 'affine' self._quant_bits = bits w_q = self.q_proj.T - w_k = self.k_proj.T - w_v = self.v_proj.T - w_qkv = mx.concatenate([w_q, w_k, w_v], axis=0) + # kv_proj is already combined [in, 2*kv_dim]. 
+ w_kv = self.kv_proj.T + w_qkv = mx.concatenate([w_q, w_kv], axis=0) self.qkv_proj_qw, self.qkv_proj_qs, self.qkv_proj_qb = mx.quantize(w_qkv, group_size=group_size, bits=bits) - + self.q_proj = None - self.k_proj = None - self.v_proj = None + self.kv_proj = None def _project_qkv(self, x): b, t = x.shape[0], x.shape[1] @@ -601,8 +607,10 @@ def _project_qkv(self, x): if self.use_bias: q = q + self.q_bias.astype(dtype) - k = k + self.k_bias.astype(dtype) - val = val + self.v_bias.astype(dtype) + kv_bias = self.kv_bias.astype(dtype) + kb, vb = mx.split(kv_bias, 2, axis=-1) + k = k + kb + val = val + vb q = q.reshape(b, t, self.num_heads, self.units_per_head) k = k.reshape(b, t, self.num_kv_heads, self.units_per_head) @@ -805,12 +813,17 @@ def __init__( qkv_dim = num_heads * units_per_head self.q_proj = kernel_init(key, (in_features, qkv_dim), param_dtype) - self.k_proj = kernel_init(key, (source_features, qkv_dim), param_dtype) - self.v_proj = kernel_init(key, (source_features, qkv_dim), param_dtype) + # Combined K+V projection: single matmul + split is faster than two. 
+ self.kv_proj = mx.concatenate([ + kernel_init(key, (source_features, qkv_dim), param_dtype), + kernel_init(key, (source_features, qkv_dim), param_dtype), + ], axis=-1) if use_bias: self.q_bias = bias_init(key, (qkv_dim,), param_dtype) - self.k_bias = bias_init(key, (qkv_dim,), param_dtype) - self.v_bias = bias_init(key, (qkv_dim,), param_dtype) + self.kv_bias = mx.concatenate([ + bias_init(key, (qkv_dim,), param_dtype), + bias_init(key, (qkv_dim,), param_dtype), + ], axis=-1) self.query_network = query_network self.key_network = key_network @@ -840,11 +853,13 @@ def _project_kv(self, source): b, t = source.shape[0], source.shape[1] dtype = self.compute_dtype or source.dtype v = source.values.astype(dtype) - k = mx.matmul(v, self.k_proj.astype(dtype)) - val = mx.matmul(v, self.v_proj.astype(dtype)) + kv = mx.matmul(v, self.kv_proj.astype(dtype)) + k, val = mx.split(kv, 2, axis=-1) if self.use_bias: - k = k + self.k_bias.astype(dtype) - val = val + self.v_bias.astype(dtype) + kv_bias = self.kv_bias.astype(dtype) + kb, vb = mx.split(kv_bias, 2, axis=-1) + k = k + kb + val = val + vb k = k.reshape(b, t, self.num_heads, self.units_per_head) val = val.reshape(b, t, self.num_heads, self.units_per_head) return Sequence(k, source.mask), Sequence(val, source.mask) @@ -1224,13 +1239,17 @@ def __init__( # Q projection from input. self.q_proj = kernel_init(key, (in_features, qkv_dim), param_dtype) - # K/V projections from source. - self.k_proj = kernel_init(key, (source_features, qkv_dim), param_dtype) - self.v_proj = kernel_init(key, (source_features, qkv_dim), param_dtype) + # Combined K+V projection from source: single matmul + split. 
+ self.kv_proj = mx.concatenate([ + kernel_init(key, (source_features, qkv_dim), param_dtype), + kernel_init(key, (source_features, qkv_dim), param_dtype), + ], axis=-1) if use_bias: self.q_bias = bias_init(key, (qkv_dim,), param_dtype) - self.k_bias = bias_init(key, (qkv_dim,), param_dtype) - self.v_bias = bias_init(key, (qkv_dim,), param_dtype) + self.kv_bias = mx.concatenate([ + bias_init(key, (qkv_dim,), param_dtype), + bias_init(key, (qkv_dim,), param_dtype), + ], axis=-1) # Attention sink embeddings. self.num_sink_embeddings = num_sink_embeddings if num_sink_embeddings > 0: @@ -1274,11 +1293,13 @@ def _project_kv(self, source): b, t = source.shape[0], source.shape[1] dtype = self.compute_dtype or source.dtype v = source.values.astype(dtype) - k = mx.matmul(v, self.k_proj.astype(dtype)) - val = mx.matmul(v, self.v_proj.astype(dtype)) + kv = mx.matmul(v, self.kv_proj.astype(dtype)) + k, val = mx.split(kv, 2, axis=-1) if self.use_bias: - k = k + self.k_bias.astype(dtype) - val = val + self.v_bias.astype(dtype) + kv_bias = self.kv_bias.astype(dtype) + kb, vb = mx.split(kv_bias, 2, axis=-1) + k = k + kb + val = val + vb k = k.reshape(b, t, self.num_heads, self.units_per_head) val = val.reshape(b, t, self.num_heads, self.units_per_head) return Sequence(k, source.mask), Sequence(val, source.mask) @@ -1587,15 +1608,13 @@ def to_quantized(self, group_size: int = 64, bits: int = 4, mode: str = 'affine' w_q = self.q_proj.T self.q_proj_qw, self.q_proj_qs, self.q_proj_qb = mx.quantize(w_q, group_size=group_size, bits=bits) - - w_k = self.k_proj.T - w_v = self.v_proj.T - w_kv = mx.concatenate([w_k, w_v], axis=0) + + # kv_proj is already combined [source, 2*qkv_dim]. 
+ w_kv = self.kv_proj.T self.kv_proj_qw, self.kv_proj_qs, self.kv_proj_qb = mx.quantize(w_kv, group_size=group_size, bits=bits) self.q_proj = None - self.k_proj = None - self.v_proj = None + self.kv_proj = None def _project_q(self, x): b, t = x.shape[0], x.shape[1] @@ -1612,11 +1631,12 @@ def _project_kv(self, source): dtype = self.compute_dtype or source.dtype v = source.values.astype(dtype) kv = _quantized_matmul_proj(v, self.kv_proj_qw, self.kv_proj_qs, self.kv_proj_qb, self._quant_group_size, self._quant_bits) - d_k = self.num_heads * self.units_per_head - k, val = mx.split(kv, [d_k], axis=-1) + k, val = mx.split(kv, 2, axis=-1) if self.use_bias: - k = k + self.k_bias.astype(dtype) - val = val + self.v_bias.astype(dtype) + kv_bias = self.kv_bias.astype(dtype) + kb, vb = mx.split(kv_bias, 2, axis=-1) + k = k + kb + val = val + vb k = k.reshape(b, t, self.num_heads, self.units_per_head) val = val.reshape(b, t, self.num_heads, self.units_per_head) return Sequence(k, source.mask), Sequence(val, source.mask) diff --git a/sequence_layers/mlx/attention_test.py b/sequence_layers/mlx/attention_test.py index e274e32..e574ce4 100644 --- a/sequence_layers/mlx/attention_test.py +++ b/sequence_layers/mlx/attention_test.py @@ -115,8 +115,7 @@ def test_per_dim_scale(self): ) # Copy weights so projections match. layer_no_pds.q_proj = layer.q_proj - layer_no_pds.k_proj = layer.k_proj - layer_no_pds.v_proj = layer.v_proj + layer_no_pds.kv_proj = layer.kv_proj x = test_utils.random_sequence(1, 5, 8) y_pds = layer.layer(x) diff --git a/sequence_layers/mlx/convolution.py b/sequence_layers/mlx/convolution.py index 2383bbd..2c8d549 100644 --- a/sequence_layers/mlx/convolution.py +++ b/sequence_layers/mlx/convolution.py @@ -14,6 +14,50 @@ MaskedSequence = bt.MaskedSequence PaddingMode = bt.PaddingMode +# Module-level cache for mask convolution kernels. Keys are tuples of +# deterministic parameters; values are small mx.arrays. 
The cache is +# bounded (one entry per unique configuration) and shared across all +# layer instances. +_MASK_KERNEL_CACHE: dict[tuple, mx.array] = {} + + +def _get_padding_kernel(pad_left, pad_right): + """Get or create the padding mask kernel for step-mode conv mask.""" + key = ('pad', pad_left, pad_right) + if key not in _MASK_KERNEL_CACHE: + k = [0.0] * pad_left + [1.0] + [0.0] * pad_right + _MASK_KERNEL_CACHE[key] = mx.array(k, dtype=mx.float32).reshape(1, -1, 1) + return _MASK_KERNEL_CACHE[key] + + +def _get_logical_kernel(kernel_size, dilation_rate): + """Get or create the logical mask kernel for reduce_window simulation.""" + key = ('logical', kernel_size, dilation_rate) + if key not in _MASK_KERNEL_CACHE: + if dilation_rate == 1: + _MASK_KERNEL_CACHE[key] = mx.ones( + (1, kernel_size, 1), dtype=mx.float32 + ) + else: + ek = _effective_kernel_size(kernel_size, dilation_rate) + k = [0.0] * ek + for i in range(kernel_size): + k[i * dilation_rate] = 1.0 + _MASK_KERNEL_CACHE[key] = ( + mx.array(k, dtype=mx.float32).reshape(1, -1, 1) + ) + return _MASK_KERNEL_CACHE[key] + + +def _get_transpose_kernel(kernel_size): + """Get or create the transpose conv mask kernel.""" + key = ('transpose', kernel_size) + if key not in _MASK_KERNEL_CACHE: + _MASK_KERNEL_CACHE[key] = mx.ones( + (1, kernel_size, 1), dtype=mx.float32 + ) + return _MASK_KERNEL_CACHE[key] + # --------------------------------------------------------------------------- # Padding utilities (ported from jax/utils.py and jax/convolution.py) @@ -101,8 +145,7 @@ def _compute_conv_mask( padding, kernel_size, stride, dilation_rate ) # Use a simple convolution-like mask computation with float kernel. 
- kernel = [0.0] * pad_left + [1.0] + [0.0] * pad_right - kernel = mx.array(kernel, dtype=mx.float32).reshape(1, -1, 1) + kernel = _get_padding_kernel(pad_left, pad_right) mask_f = mask[:, :, None].astype(mx.float32) mask_conv = mx.conv1d(mask_f, kernel, stride=stride) return mx.squeeze(mask_conv, axis=-1).astype(mx.bool_) @@ -175,15 +218,7 @@ def _compute_conv_mask_logical( # Use float conv to simulate reduce_window. mask_f = mask[:, :, None].astype(mx.float32) - # Build a kernel with ones at dilated positions. - if dilation_rate == 1: - kernel = mx.ones((1, kernel_size, 1), dtype=mx.float32) - else: - ek = _effective_kernel_size(kernel_size, dilation_rate) - k = [0.0] * ek - for i in range(kernel_size): - k[i * dilation_rate] = 1.0 - kernel = mx.array(k, dtype=mx.float32).reshape(1, -1, 1) + kernel = _get_logical_kernel(kernel_size, dilation_rate) result = mx.conv1d(mask_f, kernel, stride=stride) result = mx.squeeze(result, axis=-1) @@ -662,7 +697,7 @@ def _compute_conv_transpose_mask( test_signal = mx.logical_not(mask) test_fn = lambda m: m == 0.0 - kernel = mx.ones((1, kernel_size, 1), dtype=mx.float32) + kernel = _get_transpose_kernel(kernel_size) signal = test_signal.astype(mx.float32)[:, :, None] result = mx.conv_transpose1d( diff --git a/sequence_layers/mlx/dsp.py b/sequence_layers/mlx/dsp.py index 32bf3ad..87148df 100644 --- a/sequence_layers/mlx/dsp.py +++ b/sequence_layers/mlx/dsp.py @@ -64,15 +64,7 @@ def frame(values, frame_length, frame_step, pad_mode='valid', axis=1): # Compute number of frames. num_frames = max(0, (t - frame_length) // frame_step + 1) - # Gather frames using indexing. - indices = ( - mx.arange(num_frames)[:, None] * frame_step - + mx.arange(frame_length)[None, :] - ) - # indices: [num_frames, frame_length] - - # Flatten, gather, reshape. - # Move axis to position 1 for easier manipulation. + # Move target axis to position 1 for uniform handling. 
if axis != 1: perm = list(range(values.ndim)) perm[1], perm[axis] = perm[axis], perm[1] @@ -82,18 +74,33 @@ def frame(values, frame_length, frame_step, pad_mode='valid', axis=1): batch = values.shape[0] rest_shape = values.shape[2:] - # Gather: result [batch, num_frames, frame_length, ...] - # Use fancy indexing along axis 1. - result = values[:, indices.reshape(-1)] - result = result.reshape((batch, num_frames, frame_length) + rest_shape) + # Fast path: zero-copy strided view for contiguous data. + rest_size = 1 + for d in rest_shape: + rest_size *= d + + batch_stride = t * rest_size + frame_stride = frame_step * rest_size + time_stride = rest_size + + # Compute rest strides from contiguous layout. + rest_strides = [] + s = 1 + for d in reversed(rest_shape): + rest_strides.append(s) + s *= d + rest_strides.reverse() + + result = mx.as_strided( + values, + shape=(batch, num_frames, frame_length) + rest_shape, + strides=(batch_stride, frame_stride, time_stride) + tuple(rest_strides), + ) if axis != 1: # Move back. perm = list(range(result.ndim)) - # axis was swapped to 1, new dims are at 1 and 2. - # Need to move them back so axis and axis+1 have the frame dims. perm[1], perm[axis] = perm[axis], perm[1] - # Also move frame_length dim. if axis > 1: perm.insert(axis + 1, perm.pop(2)) result = mx.transpose(result, perm) @@ -121,17 +128,22 @@ def overlap_and_add(signal_arr, frame_step): if frame_length == frame_step: return signal_arr.reshape(outer_dims + (output_length,)) - # General overlap-add via scatter. + # Vectorized overlap-add via scatter. outer_size = 1 for d in outer_dims: outer_size *= d flat = signal_arr.reshape(outer_size, frames, frame_length) - result = mx.zeros((outer_size, output_length), dtype=flat.dtype) - for f in range(frames): - start = f * frame_step - result = result.at[:, start : start + frame_length].add(flat[:, f]) + # Build output position indices: [frames, frame_length]. 
+ offsets = mx.arange(frames)[:, None] * frame_step + positions = offsets + mx.arange(frame_length)[None, :] + flat_positions = positions.reshape(-1) # [frames * frame_length] + + # Flatten signal and scatter-add all frame contributions at once. + flat_signal = flat.reshape(outer_size, frames * frame_length) + result = mx.zeros((outer_size, output_length), dtype=flat.dtype) + result = result.at[:, flat_positions].add(flat_signal) return result.reshape(outer_dims + (output_length,)) @@ -161,20 +173,22 @@ def mel_to_hz(m): mel_points = np.linspace(mel_low, mel_high, num_mel_bins + 2) hz_points = mel_to_hz(mel_points) - weights = np.zeros((num_spectrogram_bins, num_mel_bins), dtype=dtype) - for i in range(num_mel_bins): - lower = hz_points[i] - center = hz_points[i + 1] - upper = hz_points[i + 2] - - # Rising slope. - for j in range(num_spectrogram_bins): - if lower <= freq_bins[j] <= center and center > lower: - weights[j, i] = (freq_bins[j] - lower) / (center - lower) - elif center < freq_bins[j] <= upper and upper > center: - weights[j, i] = (upper - freq_bins[j]) / (upper - center) + lower = hz_points[:-2][np.newaxis, :] # [1, num_mel_bins] + center = hz_points[1:-1][np.newaxis, :] # [1, num_mel_bins] + upper = hz_points[2:][np.newaxis, :] # [1, num_mel_bins] + freq = freq_bins[:, np.newaxis] # [num_spectrogram_bins, 1] - return weights + rising = np.where( + (freq >= lower) & (freq <= center) & (center > lower), + (freq - lower) / np.maximum(center - lower, 1e-10), + 0.0, + ) + falling = np.where( + (freq > center) & (freq <= upper) & (upper > center), + (upper - freq) / np.maximum(upper - center, 1e-10), + 0.0, + ) + return (rising + falling).astype(dtype) # --------------------------------------------------------------------------- @@ -1202,23 +1216,36 @@ def __init__( self.sample_rate = sample_rate self.lower_edge_hertz = lower_edge_hertz self.upper_edge_hertz = upper_edge_hertz + self._cached_weights = None + self._cached_num_bins = None + self._cached_dtype 
= None def get_output_shape(self, input_shape, *, constants=None): if not input_shape: raise ValueError('LinearToMelSpectrogram requires rank >= 1 input.') return tuple(input_shape[:-1]) + (self.num_mel_bins,) + def _get_weights(self, num_bins, dtype): + if ( + self._cached_weights is None + or self._cached_num_bins != num_bins + or self._cached_dtype != dtype + ): + weights = linear_to_mel_weight_matrix( + num_mel_bins=self.num_mel_bins, + num_spectrogram_bins=num_bins, + sample_rate=self.sample_rate, + lower_edge_hertz=self.lower_edge_hertz, + upper_edge_hertz=self.upper_edge_hertz, + ) + self._cached_weights = mx.array(weights, dtype=dtype) + self._cached_num_bins = num_bins + self._cached_dtype = dtype + return self._cached_weights + @types.check_layer def layer(self, x, *, constants=None): - num_bins = x.shape[-1] - weights = linear_to_mel_weight_matrix( - num_mel_bins=self.num_mel_bins, - num_spectrogram_bins=num_bins, - sample_rate=self.sample_rate, - lower_edge_hertz=self.lower_edge_hertz, - upper_edge_hertz=self.upper_edge_hertz, - ) - weights = mx.array(weights, dtype=x.dtype) + weights = self._get_weights(x.shape[-1], x.dtype) return x.apply_values_masked(lambda v: v @ weights) @classmethod diff --git a/sequence_layers/mlx/weight_converter.py b/sequence_layers/mlx/weight_converter.py index 93c1057..19076a2 100644 --- a/sequence_layers/mlx/weight_converter.py +++ b/sequence_layers/mlx/weight_converter.py @@ -322,10 +322,14 @@ def _load_attention(mlx_attn, linen_params, config): if hasattr(inner, 'qkv_proj'): inner.qkv_proj = mx.array(combined_kernel.reshape(in_features, -1)) else: + # Separate Q + combined KV layout. 
q, k, v = np.split(combined_kernel, 3, axis=1) inner.q_proj = mx.array(q.reshape(in_features, -1)) - inner.k_proj = mx.array(k.reshape(in_features, -1)) - inner.v_proj = mx.array(v.reshape(in_features, -1)) + k_flat = k.reshape(in_features, -1) + v_flat = v.reshape(in_features, -1) + inner.kv_proj = mx.array( + np.concatenate([k_flat, v_flat], axis=-1) + ) combined_bias = qkv_params.get('bias') if combined_bias is not None: @@ -334,8 +338,9 @@ def _load_attention(mlx_attn, linen_params, config): else: qb, kb, vb = np.split(combined_bias, 3, axis=0) inner.q_bias = mx.array(qb.reshape(-1)) - inner.k_bias = mx.array(kb.reshape(-1)) - inner.v_bias = mx.array(vb.reshape(-1)) + inner.kv_bias = mx.array( + np.concatenate([kb.reshape(-1), vb.reshape(-1)], axis=-1) + ) elif isinstance( input_projection, (attn_common.SeparateQueryKeyValueProjection, mlx_proj.SeparateQueryKeyValueProjection) @@ -352,21 +357,19 @@ def _load_attention(mlx_attn, linen_params, config): k_params = linen_params.get('key_projection', {}) k_kernel = k_params.get('kernel') - if k_kernel is not None: - in_features = k_kernel.shape[0] - inner.k_proj = mx.array(k_kernel.reshape(in_features, -1)) - k_bias = k_params.get('bias') - if k_bias is not None: - inner.k_bias = mx.array(k_bias.reshape(-1)) - v_params = linen_params.get('value_projection', {}) v_kernel = v_params.get('kernel') - if v_kernel is not None: - in_features = v_kernel.shape[0] - inner.v_proj = mx.array(v_kernel.reshape(in_features, -1)) + if k_kernel is not None and v_kernel is not None: + in_features = k_kernel.shape[0] + k_flat = k_kernel.reshape(in_features, -1) + v_flat = v_kernel.reshape(in_features, -1) + inner.kv_proj = mx.array(np.concatenate([k_flat, v_flat], axis=-1)) + k_bias = k_params.get('bias') v_bias = v_params.get('bias') - if v_bias is not None: - inner.v_bias = mx.array(v_bias.reshape(-1)) + if k_bias is not None and v_bias is not None: + inner.kv_bias = mx.array( + np.concatenate([k_bias.reshape(-1), 
v_bias.reshape(-1)], axis=-1) + ) # per_dim_scale: learned [units_per_head] query scale. per_dim_scale = linen_params.get('per_dim_scale') @@ -414,59 +417,57 @@ def _load_streaming_attention(mlx_attn, linen_params, config): inner.q_bias = mx.array(q_bias.reshape(-1)) if isinstance(input_projection, (attn_common.QueryAndKeyValueProjection, mlx_proj.QueryAndKeyValueProjection)): - # Combined KV: kernel [source, 2, heads, uph] → split into K, V. + # Combined KV: kernel [source, 2, heads, uph] → combined kv_proj. kv_params = linen_params.get('key_value_projection', {}) kv_kernel = kv_params.get('kernel') if kv_kernel is not None: source_features = kv_kernel.shape[0] - # Split along axis 1 (the '2' axis for K/V). + # Split along axis 1 (the '2' axis for K/V), flatten, recombine. k, v = np.split(kv_kernel, 2, axis=1) - inner.k_proj = mx.array(k.reshape(source_features, -1)) - inner.v_proj = mx.array(v.reshape(source_features, -1)) + k_flat = k.reshape(source_features, -1) + v_flat = v.reshape(source_features, -1) + inner.kv_proj = mx.array(np.concatenate([k_flat, v_flat], axis=-1)) kv_bias = kv_params.get('bias') if kv_bias is not None: kb, vb = np.split(kv_bias, 2, axis=0) - inner.k_bias = mx.array(kb.reshape(-1)) - inner.v_bias = mx.array(vb.reshape(-1)) + inner.kv_bias = mx.array( + np.concatenate([kb.reshape(-1), vb.reshape(-1)], axis=-1) + ) elif isinstance( input_projection, (attn_common.SeparateQueryKeyValueProjection, mlx_proj.SeparateQueryKeyValueProjection) ): - # Separate K and V projections. + # Separate K and V projections → combined kv_proj. 
k_params = linen_params.get('key_projection', {}) k_kernel = k_params.get('kernel') - if k_kernel is not None: - source_features = k_kernel.shape[0] - inner.k_proj = mx.array(k_kernel.reshape(source_features, -1)) - k_bias = k_params.get('bias') - if k_bias is not None: - inner.k_bias = mx.array(k_bias.reshape(-1)) - v_params = linen_params.get('value_projection', {}) v_kernel = v_params.get('kernel') - if v_kernel is not None: - source_features = v_kernel.shape[0] - inner.v_proj = mx.array(v_kernel.reshape(source_features, -1)) + if k_kernel is not None and v_kernel is not None: + source_features = k_kernel.shape[0] + k_flat = k_kernel.reshape(source_features, -1) + v_flat = v_kernel.reshape(source_features, -1) + inner.kv_proj = mx.array(np.concatenate([k_flat, v_flat], axis=-1)) + k_bias = k_params.get('bias') v_bias = v_params.get('bias') - if v_bias is not None: - inner.v_bias = mx.array(v_bias.reshape(-1)) + if k_bias is not None and v_bias is not None: + inner.kv_bias = mx.array( + np.concatenate([k_bias.reshape(-1), v_bias.reshape(-1)], axis=-1) + ) elif isinstance( input_projection, (attn_common.QueryAndSharedKeyValueProjection, mlx_proj.QueryAndSharedKeyValueProjection) ): - # Shared K/V projection: same weights for both K and V. + # Shared K/V projection: same weights for both K and V → combined kv_proj. 
shared_params = linen_params.get('shared_key_value_projection', {}) shared_kernel = shared_params.get('kernel') if shared_kernel is not None: source_features = shared_kernel.shape[0] - proj = mx.array(shared_kernel.reshape(source_features, -1)) - inner.k_proj = proj - inner.v_proj = proj + proj = shared_kernel.reshape(source_features, -1) + inner.kv_proj = mx.array(np.concatenate([proj, proj], axis=-1)) shared_bias = shared_params.get('bias') if shared_bias is not None: - b = mx.array(shared_bias.reshape(-1)) - inner.k_bias = b - inner.v_bias = b + b = shared_bias.reshape(-1) + inner.kv_bias = mx.array(np.concatenate([b, b], axis=-1)) # per_dim_scale: learned [units_per_head] query scale. per_dim_scale = linen_params.get('per_dim_scale') From c33c0b3b22389aac03d0193229e4cc1beab6d5ee Mon Sep 17 00:00:00 2001 From: David Braun <2096055+DBraun@users.noreply.github.com> Date: Thu, 5 Mar 2026 18:16:10 -0500 Subject: [PATCH 10/17] attention optimizations --- sequence_layers/mlx/attention.py | 192 +++++++++++++++++-------------- 1 file changed, 106 insertions(+), 86 deletions(-) diff --git a/sequence_layers/mlx/attention.py b/sequence_layers/mlx/attention.py index fa412cb..cdb5bf3 100644 --- a/sequence_layers/mlx/attention.py +++ b/sequence_layers/mlx/attention.py @@ -26,6 +26,22 @@ def _quantized_matmul_proj(x, q_weight, q_scales, q_biases, group_size, bits): ) +def _query_scale_vector(per_dim_scale, query_scale, units_per_head, dtype): + """Compute the per-dimension query scale vector. + + Returns: + scale: [units_per_head] array or scalar float. + """ + if query_scale is None: + query_scale = 1.0 / math.sqrt(units_per_head) + if per_dim_scale is not None: + r_softplus_0 = 1.442695041 + scale = r_softplus_0 * query_scale + softplus = mx.log1p(mx.exp(per_dim_scale.astype(dtype))) + return scale * softplus + return query_scale + + def _scale_queries(queries, per_dim_scale, query_scale, units_per_head): """Scale queries, optionally with per-dimension learned scale. 
@@ -40,17 +56,10 @@ def _scale_queries(queries, per_dim_scale, query_scale, units_per_head): Returns: Scaled queries, same shape. """ - if query_scale is None: - query_scale = 1.0 / math.sqrt(units_per_head) - if per_dim_scale is not None: - # 1/softplus(0) = 1/ln(2). At init (zeros), effective scale = query_scale. - r_softplus_0 = 1.442695041 - scale = r_softplus_0 * query_scale - softplus = mx.log1p(mx.exp(per_dim_scale.astype(queries.dtype))) - queries = queries * (scale * softplus) - else: - queries = queries * query_scale - return queries + scale = _query_scale_vector( + per_dim_scale, query_scale, units_per_head, queries.dtype + ) + return queries * scale def _causal_mask(q_len, kv_len): @@ -274,67 +283,95 @@ def _compute_attention(self, queries, keys, values, mask): Returns: context: [b, q_t, num_heads, units_per_head] """ - can_use_fast_sdpa = ( - self.sink_key_embeddings is None and - getattr(self, '_attention_logits_soft_cap', None) is None - ) + # Use mx.fast.scaled_dot_product_attention unless soft_cap forces + # manual logit manipulation. + has_soft_cap = getattr(self, '_attention_logits_soft_cap', None) is not None - if can_use_fast_sdpa: - # Fast path: GQA is natively supported, so we do not repeat keys/values. + if not has_soft_cap: + # SDPA path — handles both plain and sink cases. q = mx.transpose(queries, (0, 2, 1, 3)) k = mx.transpose(keys, (0, 2, 1, 3)) v = mx.transpose(values, (0, 2, 1, 3)) - + q = _scale_queries( q, self._per_dim_scale, self._query_scale, self.units_per_head ) - - # Use mx.fast.scaled_dot_product_attention with scale=1.0 since we pre-scaled - context = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) + + if self.sink_key_embeddings is not None: + # JAX computes sink logits with *unscaled* queries. 
To use SDPA + # we pre-divide sink keys by the scale so that: + # scaled_q @ (sink_k / scale) == unscaled_q @ sink_k + scale_vec = _query_scale_vector( + self._per_dim_scale, self._query_scale, + self.units_per_head, q.dtype, + ) + sink_k = self.sink_key_embeddings.astype(q.dtype) / scale_vec + sink_v = self.sink_value_embeddings.astype(v.dtype) + + # GQA: repeat sink heads to match query heads. + num_groups = self.num_heads // self.num_kv_heads + if num_groups > 1: + sink_v = mx.repeat(sink_v, num_groups, axis=1) + + # Transpose [K, nh, h] → [nh, K, h] and broadcast batch. + sink_k_b = mx.broadcast_to( + mx.transpose(sink_k, (1, 0, 2))[None], + (q.shape[0], self.num_heads, sink_k.shape[0], self.units_per_head), + ) + sink_v_b = mx.broadcast_to( + mx.transpose(sink_v, (1, 0, 2))[None], + (v.shape[0], self.num_heads, sink_v.shape[0], self.units_per_head), + ) + + # Prepend sinks to K/V. + k = mx.concatenate([sink_k_b, k], axis=2) + v = mx.concatenate([sink_v_b, v], axis=2) + + # Extend mask — sinks are always valid. + if mask is not None: + num_sinks = self.sink_key_embeddings.shape[0] + sink_mask = mx.ones( + (mask.shape[0], mask.shape[1], mask.shape[2], num_sinks), + dtype=mx.bool_, + ) + mask = mx.concatenate([sink_mask, mask], axis=-1) + + context = mx.fast.scaled_dot_product_attention( + q, k, v, scale=1.0, mask=mask + ) return mx.transpose(context, (0, 2, 1, 3)) - # GQA: repeat K/V heads to match query heads. + # Manual path — only for attention_logits_soft_cap. num_groups = self.num_heads // self.num_kv_heads if num_groups > 1: - b, kv_t, nk, h = keys.shape keys = mx.repeat(keys, num_groups, axis=2) values = mx.repeat(values, num_groups, axis=2) - # Transpose to [b, heads, t, h] for batched matmul. 
- q = mx.transpose(queries, (0, 2, 1, 3)) # [b, nh, qt, h] - k = mx.transpose(keys, (0, 2, 1, 3)) # [b, nh, kvt, h] - v = mx.transpose(values, (0, 2, 1, 3)) # [b, nh, kvt, h] + q = mx.transpose(queries, (0, 2, 1, 3)) + k = mx.transpose(keys, (0, 2, 1, 3)) + v = mx.transpose(values, (0, 2, 1, 3)) - # Scaled dot-product attention. # Compute sink logits BEFORE scaling queries, matching JAX behavior. - # JAX computes sink_key_logits = einsum('BTNH,KNH->BNTK', queries.values, - # sink_key_embeddings) before _scale_query(). if self.sink_key_embeddings is not None: - sink_k = self.sink_key_embeddings.astype(q.dtype) # [K, nh, h] - sink_k_t = mx.transpose(sink_k, (1, 2, 0)) # [nh, h, K] - sink_logits = mx.matmul(q, sink_k_t) # [b, nh, qt, K] + sink_k = self.sink_key_embeddings.astype(q.dtype) + sink_k_t = mx.transpose(sink_k, (1, 2, 0)) + sink_logits = mx.matmul(q, sink_k_t) q = _scale_queries( q, self._per_dim_scale, self._query_scale, self.units_per_head ) logits = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) - # Add attention sink logits if present. if self.sink_key_embeddings is not None: - # Prepend sink values to v: v becomes [b, nh, K+kvt, h] - sink_v = self.sink_value_embeddings.astype(v.dtype) # [K, nkv, h] + sink_v = self.sink_value_embeddings.astype(v.dtype) if num_groups > 1: sink_v = mx.repeat(sink_v, num_groups, axis=1) - sink_v_t = mx.transpose(sink_v, (1, 0, 2)) # [nh, K, h] + sink_v_t = mx.transpose(sink_v, (1, 0, 2)) sink_v_b = mx.broadcast_to( sink_v_t[None], (v.shape[0],) + sink_v_t.shape - ) # [b, nh, K, h] - v = mx.concatenate([sink_v_b, v], axis=2) # [b, nh, K+kvt, h] - - # Prepend sink logits to logits: [b, nh, qt, K+kvt] + ) + v = mx.concatenate([sink_v_b, v], axis=2) logits = mx.concatenate([sink_logits, logits], axis=-1) - - # Extend mask for sinks (always valid). 
if mask is not None: num_sinks = self.sink_key_embeddings.shape[0] sink_mask = mx.ones( @@ -343,22 +380,16 @@ def _compute_attention(self, queries, keys, values, mask): ) mask = mx.concatenate([sink_mask, mask], axis=-1) - # Optional soft cap on logits (e.g., Gemma 2 uses cap=50.0). - if getattr(self, '_attention_logits_soft_cap', None) is not None: - cap = self._attention_logits_soft_cap - logits = mx.tanh(logits / cap) * cap + cap = self._attention_logits_soft_cap + logits = mx.tanh(logits / cap) * cap - # Apply mask: set masked positions to large negative. if mask is not None: large_neg = mx.array(-1e9, dtype=logits.dtype) logits = mx.where(mask, logits, large_neg) - # Run softmax in at least float32 to match JAX precision. logits_f32 = logits.astype(mx.float32) if logits.dtype != mx.float32 else logits weights = mx.softmax(logits_f32, axis=-1).astype(v.dtype) - context = mx.matmul(weights, v) # [b, nh, qt, h] - - # Transpose back to [b, qt, nh, h]. + context = mx.matmul(weights, v) context = mx.transpose(context, (0, 2, 1, 3)) return context @@ -1311,42 +1342,36 @@ def _get_source(self, constants): def _compute_attention(self, queries, keys, values, mask): """Compute scaled dot-product attention.""" - if self.sink_key_embeddings is None: - q = mx.transpose(queries, (0, 2, 1, 3)) - k = mx.transpose(keys, (0, 2, 1, 3)) - v = mx.transpose(values, (0, 2, 1, 3)) - - q = _scale_queries( - q, self._per_dim_scale, self._query_scale, self.units_per_head - ) - - context = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=mask) - return mx.transpose(context, (0, 2, 1, 3)) - q = mx.transpose(queries, (0, 2, 1, 3)) k = mx.transpose(keys, (0, 2, 1, 3)) v = mx.transpose(values, (0, 2, 1, 3)) - # Compute sink logits BEFORE scaling queries, matching JAX behavior. 
- if self.sink_key_embeddings is not None: - sink_k = self.sink_key_embeddings.astype(q.dtype) # [K, nh, h] - sink_k_t = mx.transpose(sink_k, (1, 2, 0)) # [nh, h, K] - sink_logits = mx.matmul(q, sink_k_t) # [b, nh, qt, K] - q = _scale_queries( q, self._per_dim_scale, self._query_scale, self.units_per_head ) - logits = mx.matmul(q, mx.transpose(k, (0, 1, 3, 2))) - # Add attention sink logits if present. if self.sink_key_embeddings is not None: - sink_v = self.sink_value_embeddings.astype(v.dtype) # [K, nh, h] - sink_v_t = mx.transpose(sink_v, (1, 0, 2)) # [nh, K, h] + # JAX computes sink logits with *unscaled* queries. Pre-divide + # sink keys by the scale so that SDPA produces equivalent logits: + # scaled_q @ (sink_k / scale) == unscaled_q @ sink_k + scale_vec = _query_scale_vector( + self._per_dim_scale, self._query_scale, + self.units_per_head, q.dtype, + ) + sink_k = self.sink_key_embeddings.astype(q.dtype) / scale_vec + sink_v = self.sink_value_embeddings.astype(v.dtype) + + sink_k_b = mx.broadcast_to( + mx.transpose(sink_k, (1, 0, 2))[None], + (q.shape[0], self.num_heads, sink_k.shape[0], self.units_per_head), + ) sink_v_b = mx.broadcast_to( - sink_v_t[None], (v.shape[0],) + sink_v_t.shape - ) # [b, nh, K, h] + mx.transpose(sink_v, (1, 0, 2))[None], + (v.shape[0], self.num_heads, sink_v.shape[0], self.units_per_head), + ) + + k = mx.concatenate([sink_k_b, k], axis=2) v = mx.concatenate([sink_v_b, v], axis=2) - logits = mx.concatenate([sink_logits, logits], axis=-1) if mask is not None: num_sinks = self.sink_key_embeddings.shape[0] @@ -1356,15 +1381,10 @@ def _compute_attention(self, queries, keys, values, mask): ) mask = mx.concatenate([sink_mask, mask], axis=-1) - if mask is not None: - large_neg = mx.array(-1e9, dtype=logits.dtype) - logits = mx.where(mask, logits, large_neg) - # Run softmax in at least float32 to match JAX precision. 
- logits_f32 = logits.astype(mx.float32) if logits.dtype != mx.float32 else logits - weights = mx.softmax(logits_f32, axis=-1).astype(v.dtype) - context = mx.matmul(weights, v) - context = mx.transpose(context, (0, 2, 1, 3)) - return context + context = mx.fast.scaled_dot_product_attention( + q, k, v, scale=1.0, mask=mask + ) + return mx.transpose(context, (0, 2, 1, 3)) def get_output_shape(self, input_shape, *, constants=None): if len(input_shape) != 1: From 0671c97202e8ec07b7eb4dd69622eeea8b913ffc Mon Sep 17 00:00:00 2001 From: David Braun <2096055+DBraun@users.noreply.github.com> Date: Thu, 5 Mar 2026 18:29:28 -0500 Subject: [PATCH 11/17] Update convolution.py --- sequence_layers/mlx/convolution.py | 59 ++++++------------------------ 1 file changed, 12 insertions(+), 47 deletions(-) diff --git a/sequence_layers/mlx/convolution.py b/sequence_layers/mlx/convolution.py index 2c8d549..2383bbd 100644 --- a/sequence_layers/mlx/convolution.py +++ b/sequence_layers/mlx/convolution.py @@ -14,50 +14,6 @@ MaskedSequence = bt.MaskedSequence PaddingMode = bt.PaddingMode -# Module-level cache for mask convolution kernels. Keys are tuples of -# deterministic parameters; values are small mx.arrays. The cache is -# bounded (one entry per unique configuration) and shared across all -# layer instances. 
-_MASK_KERNEL_CACHE: dict[tuple, mx.array] = {} - - -def _get_padding_kernel(pad_left, pad_right): - """Get or create the padding mask kernel for step-mode conv mask.""" - key = ('pad', pad_left, pad_right) - if key not in _MASK_KERNEL_CACHE: - k = [0.0] * pad_left + [1.0] + [0.0] * pad_right - _MASK_KERNEL_CACHE[key] = mx.array(k, dtype=mx.float32).reshape(1, -1, 1) - return _MASK_KERNEL_CACHE[key] - - -def _get_logical_kernel(kernel_size, dilation_rate): - """Get or create the logical mask kernel for reduce_window simulation.""" - key = ('logical', kernel_size, dilation_rate) - if key not in _MASK_KERNEL_CACHE: - if dilation_rate == 1: - _MASK_KERNEL_CACHE[key] = mx.ones( - (1, kernel_size, 1), dtype=mx.float32 - ) - else: - ek = _effective_kernel_size(kernel_size, dilation_rate) - k = [0.0] * ek - for i in range(kernel_size): - k[i * dilation_rate] = 1.0 - _MASK_KERNEL_CACHE[key] = ( - mx.array(k, dtype=mx.float32).reshape(1, -1, 1) - ) - return _MASK_KERNEL_CACHE[key] - - -def _get_transpose_kernel(kernel_size): - """Get or create the transpose conv mask kernel.""" - key = ('transpose', kernel_size) - if key not in _MASK_KERNEL_CACHE: - _MASK_KERNEL_CACHE[key] = mx.ones( - (1, kernel_size, 1), dtype=mx.float32 - ) - return _MASK_KERNEL_CACHE[key] - # --------------------------------------------------------------------------- # Padding utilities (ported from jax/utils.py and jax/convolution.py) @@ -145,7 +101,8 @@ def _compute_conv_mask( padding, kernel_size, stride, dilation_rate ) # Use a simple convolution-like mask computation with float kernel. 
- kernel = _get_padding_kernel(pad_left, pad_right) + kernel = [0.0] * pad_left + [1.0] + [0.0] * pad_right + kernel = mx.array(kernel, dtype=mx.float32).reshape(1, -1, 1) mask_f = mask[:, :, None].astype(mx.float32) mask_conv = mx.conv1d(mask_f, kernel, stride=stride) return mx.squeeze(mask_conv, axis=-1).astype(mx.bool_) @@ -218,7 +175,15 @@ def _compute_conv_mask_logical( # Use float conv to simulate reduce_window. mask_f = mask[:, :, None].astype(mx.float32) - kernel = _get_logical_kernel(kernel_size, dilation_rate) + # Build a kernel with ones at dilated positions. + if dilation_rate == 1: + kernel = mx.ones((1, kernel_size, 1), dtype=mx.float32) + else: + ek = _effective_kernel_size(kernel_size, dilation_rate) + k = [0.0] * ek + for i in range(kernel_size): + k[i * dilation_rate] = 1.0 + kernel = mx.array(k, dtype=mx.float32).reshape(1, -1, 1) result = mx.conv1d(mask_f, kernel, stride=stride) result = mx.squeeze(result, axis=-1) @@ -697,7 +662,7 @@ def _compute_conv_transpose_mask( test_signal = mx.logical_not(mask) test_fn = lambda m: m == 0.0 - kernel = _get_transpose_kernel(kernel_size) + kernel = mx.ones((1, kernel_size, 1), dtype=mx.float32) signal = test_signal.astype(mx.float32)[:, :, None] result = mx.conv_transpose1d( From c595bbaa5ba01ecdfa2a7125402274135382c074 Mon Sep 17 00:00:00 2001 From: Chris Donahue Date: Fri, 6 Mar 2026 11:32:20 -0500 Subject: [PATCH 12/17] Updates pyproject w/ recurrentgemma fix --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6d0b8c2..33a07aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "jaxtyping", "numpy", "orbax-export", - "recurrentgemma[jax]", + "recurrentgemma[jax]>=1.0.1", "typeguard==2.13.3", ] From 32ca2ec1147ec422a36c472dfc81eb5ccf5a2782 Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Fri, 24 Apr 2026 17:15:23 +0000 Subject: [PATCH 13/17] test(mlx): Add ring buffer wrap-around tests (expected to 
fail) These tests exercise the DotProductSelfAttention KV cache ring buffer with time > max_past_horizon, causing buffer wrap-around. They expose a write-before-read bug: the current code overwrites the oldest key in the attention window before queries can attend to it, even with block_size=1. Failures: - test_use_kv_cache_ringbuffer: 33.3% element mismatch at time=6 - LocalDotProductSelfAttention.test_layer: 19.4% element mismatch at time=10 The following commit fixes these by implementing write-after-read. --- sequence_layers/mlx/attention_test.py | 35 +++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/sequence_layers/mlx/attention_test.py b/sequence_layers/mlx/attention_test.py index e574ce4..be2f1f3 100644 --- a/sequence_layers/mlx/attention_test.py +++ b/sequence_layers/mlx/attention_test.py @@ -71,6 +71,39 @@ def test_step_builds_kv_cache(self): # Check KV cache has been populated. kv_keys = state[0] self.assertEqual(kv_keys.shape[1], 10) # buffer size + kv_mask = state[2] + self.assertEqual(mx.sum(kv_mask).item(), 5) # 5 of 10 slots filled + + @parameterized.parameters( + (2, 4, 4, 0), + ) + def test_use_kv_cache_ringbuffer( + self, num_heads, units_per_head, max_past_horizon, max_future_horizon + ): + """Test ring buffer wrap-around: layer() vs step() parity. + + With block_size=1 (default), once the ring buffer wraps, the current + write-before-read implementation overwrites the oldest key in the + attention window before the query can attend to it. This causes + step() to see max_past keys while layer() sees max_past + 1 keys + for the same query position, breaking bitwise parity. + + Sweep time shorter, equal, and longer than max_past_horizon to + exercise the wrap-around. 
+ """ + config = attention.DotProductSelfAttention.Config( + num_heads=num_heads, + units_per_head=units_per_head, + max_past_horizon=max_past_horizon, + max_future_horizon=max_future_horizon, + ) + layer = config.make(backend='mlx') + + for time in [1, max_past_horizon, max_past_horizon + 2]: + with self.subTest(f'time_{time}'): + test_utils.verify_contract( + self, layer, (8,), time=time, atol=1e-4, rtol=1e-4 + ) def test_with_query_key_networks(self): """Test with RoPE on Q/K.""" @@ -490,6 +523,8 @@ def test_layer(self): block_size_config=2, ) test_utils.verify_contract(self, layer, (16,), atol=1e-4, rtol=1e-4) + # Also test with time > max_past_horizon to exercise ring buffer wrap. + test_utils.verify_contract(self, layer, (16,), time=10, atol=1e-4, rtol=1e-4) def test_block_size(self): layer = attention.LocalDotProductSelfAttention( From 2a1092af8856ce175d6401983b0e95b337bcfad5 Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Fri, 24 Apr 2026 16:54:50 +0000 Subject: [PATCH 14/17] fix(mlx): Implement write-after-read for attention ring buffer Includes: - Expanded sweep for KV cache ringbuffer test with max_future_horizon=0. - Restored informative comments about ring buffer and mask logic. --- sequence_layers/mlx/attention.py | 76 +++++++++++++++++++++++--------- 1 file changed, 56 insertions(+), 20 deletions(-) diff --git a/sequence_layers/mlx/attention.py b/sequence_layers/mlx/attention.py index cdb5bf3..6e1f183 100644 --- a/sequence_layers/mlx/attention.py +++ b/sequence_layers/mlx/attention.py @@ -161,6 +161,7 @@ def __init__( raise ValueError( f'max_future_horizon must be >= -1, got {max_future_horizon}.' 
) + self.in_features = in_features self.num_heads = num_heads @@ -407,6 +408,11 @@ def get_output_dtype(self, input_dtype, *, constants=None): return self._param_dtype def get_initial_state(self, batch_size, input_spec, *, constants=None): + if self.max_future_horizon > 0: + raise NotImplementedError( + 'max_future_horizon > 0 step() is not yet supported in the MLX' + ' backend (query delay buffer not implemented).' + ) compute_dtype = self.get_output_dtype(input_spec.dtype) max_past = max(0, self.max_past_horizon) max_future = max(0, self.max_future_horizon) @@ -553,11 +559,59 @@ def step_with_emits(self, x, state, *, constants=None): kv_buffer_size = kv_buf_k.shape[1] if kv_buffer_size > 0: - # Ring buffer write: insert new K/V at rotating positions. + t0 = time_step[0] # MLX scalar, no eval. + + # Concatenate old buffer with new elements for attention computation. + # This avoids overwriting history needed by current queries. + combined_k = mx.concatenate([kv_buf_k, keys.values], axis=1) + combined_v = mx.concatenate([kv_buf_v, values.values], axis=1) + combined_mask = mx.concatenate([kv_buf_mask, x.mask], axis=1) + + # Build visibility mask: [b, 1, 1, kv_buffer_size + x_time]. + kv_valid = combined_mask[:, None, None, :] + + # Map physical indices in old buffer to temporal indices. + # The newest time in the old buffer was t0 - 1. + newest_time_old = t0 - 1 + newest_pos_old = newest_time_old % kv_buffer_size + phys_old = mx.arange(kv_buffer_size) + dist_old = (newest_pos_old - phys_old + kv_buffer_size) % kv_buffer_size + temporal_old = newest_time_old - dist_old + + # Temporal indices for new elements. + temporal_new = t0 + mx.arange(x_time) + + # Combine temporal indices. + temporal = mx.concatenate([temporal_old, temporal_new], axis=0) + + # Banded visibility matrix: query_time x (kv_buffer_size + x_time). + # Maps physical ring buffer positions to semantic temporal indices. 
+ # Example: max_past=5, block_size=3, current time t0=6: + # Queries are at times [6,7,8]: + # query 6: sees keys in [1, 6] + # query 7: sees keys in [2, 7] + # query 8: sees keys in [3, 8] + # Add causal mask for multi-step queries. + q_times = t0 + mx.arange(x_time) + causal = temporal[None, :] <= q_times[:, None] + + # Add finite horizon mask. + past = self.max_past_horizon + finite_horizon = temporal[None, :] >= (q_times[:, None] - past) + + causal_and_finite = causal & finite_horizon + kv_valid = kv_valid & causal_and_finite.reshape( + 1, 1, x_time, kv_buffer_size + x_time + ) + + context = self._compute_attention( + queries.values, combined_k, combined_v, kv_valid + ) + + # Ring buffer write AFTER read: insert new K/V at rotating positions. # Uses put_along_axis to scatter into pre-allocated buffers, # compatible with mx.compile / mx.export_function (no Python # int conversion needed). - t0 = time_step[0] # MLX scalar, no eval. positions = (t0 + mx.arange(x_time)) % kv_buffer_size # [x_time] # Scatter K/V into buffer at ring positions. @@ -570,24 +624,6 @@ def step_with_emits(self, x, state, *, constants=None): # Scatter mask into buffer. idx_2d = mx.broadcast_to(positions.reshape(1, x_time), x.mask.shape) kv_buf_mask = mx.put_along_axis(kv_buf_mask, idx_2d, x.mask, axis=1) - - # Build visibility mask: [b, 1, 1, kv_buffer_size]. - kv_valid = kv_buf_mask[:, None, None, :] - - # Add causal mask for multi-step queries (respects ring buffer order). 
- if x_time > 1: - newest_time = t0 + x_time - 1 - newest_pos = newest_time % kv_buffer_size - phys = mx.arange(kv_buffer_size) - dist = (newest_pos - phys + kv_buffer_size) % kv_buffer_size - temporal = newest_time - dist - q_times = t0 + mx.arange(x_time) - causal = temporal[None, :] <= q_times[:, None] - kv_valid = kv_valid & causal.reshape(1, 1, x_time, kv_buffer_size) - - context = self._compute_attention( - queries.values, kv_buf_k, kv_buf_v, kv_valid - ) else: # Degenerate: no history buffer, attend only to current step. kv_valid = x.mask[:, None, None, :] From 3b135db1604c885bbcf532f2eded8be4995bd6cf Mon Sep 17 00:00:00 2001 From: Kehang Han Date: Mon, 18 May 2026 12:33:09 -0700 Subject: [PATCH 15/17] Adds Parallel.Config on mlx side --- sequence_layers/mlx/__init__.py | 1 + sequence_layers/mlx/combinators.py | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/sequence_layers/mlx/__init__.py b/sequence_layers/mlx/__init__.py index ec647af..2eaa5d6 100644 --- a/sequence_layers/mlx/__init__.py +++ b/sequence_layers/mlx/__init__.py @@ -417,6 +417,7 @@ def _register_backends(): reg('mlx', mlx_dsp.Delay.Config, mlx_dsp.Delay.from_config) reg('mlx', mlx_comb.Serial.Config, mlx_comb.Serial.from_config) reg('mlx', mlx_comb.Residual.Config, mlx_comb.Residual.from_config) + reg('mlx', mlx_comb.Parallel.Config, mlx_comb.Parallel.from_config) reg('mlx', mlx_cond.Conditioning.Config, mlx_cond.Conditioning.from_config) reg('mlx', mlx_attn.DotProductSelfAttention.Config, mlx_attn.DotProductSelfAttention.from_config) reg('mlx', mlx_attn.DotProductAttention.Config, mlx_attn.DotProductAttention.from_config) diff --git a/sequence_layers/mlx/combinators.py b/sequence_layers/mlx/combinators.py index f26eb74..5dd744f 100644 --- a/sequence_layers/mlx/combinators.py +++ b/sequence_layers/mlx/combinators.py @@ -448,6 +448,18 @@ class Parallel(types.Emitting): All children must have equal output_ratio and block_size. 
""" + @dataclasses.dataclass(frozen=True) + class Config(_SequenceLayerConfig): + layers: tuple[_SequenceLayerConfig, ...] = () + combination: CombinationMode = CombinationMode.STACK + name: str | None = None + + def __post_init__(self): + object.__setattr__(self, 'layers', tuple(self.layers)) + + def make(self, backend='mlx') -> 'Parallel': + return Parallel.from_config(self, backend=backend) + def __init__( self, layers: list[types.SequenceLayer], From 4d27a1ec032a58701f2dcc9ffeffa9f3fa13dfea Mon Sep 17 00:00:00 2001 From: Kehang Han Date: Mon, 18 May 2026 14:11:14 -0700 Subject: [PATCH 16/17] Fixes cache poisoning for dtype in simple.py --- sequence_layers/mlx/simple.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sequence_layers/mlx/simple.py b/sequence_layers/mlx/simple.py index 830df0b..4c6afc6 100644 --- a/sequence_layers/mlx/simple.py +++ b/sequence_layers/mlx/simple.py @@ -887,8 +887,9 @@ def _probe_output(self, input_shape, input_dtype): out_values = self._fn(dummy_values) out_shape = out_values.shape[2:] out_dtype = out_values.dtype - self._cached_output_spec = bt.ShapeDType(out_shape, out_dtype) - return self._cached_output_spec + # Don't cache here; shape and dtype probes use different dummy dtypes, + # so a single cache would return stale dtype info. + return bt.ShapeDType(out_shape, out_dtype) except Exception: return None From 7b0816cb7ac423303a7f079b61469d16c4f59be9 Mon Sep 17 00:00:00 2001 From: Julian Salazar Date: Tue, 2 Jun 2026 02:55:20 +0000 Subject: [PATCH 17/17] chore(release): bump version to 0.3.dev1 TAG=agy CONV=21ada17b-3411-4090-8450-e69d8ebfeae6 --- sequence_layers/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sequence_layers/__init__.py b/sequence_layers/__init__.py index e122cc1..26be3e2 100644 --- a/sequence_layers/__init__.py +++ b/sequence_layers/__init__.py @@ -13,4 +13,4 @@ # limitations under the License. 
"""Package directory file for Sequence Layers.""" -__version__ = '0.2' +__version__ = '0.3.dev1'