
Add LoRA inference support to the MLX path #57

Open
brxs wants to merge 2 commits into Stability-AI:main from brxs:mlx-lora


brxs commented Jun 27, 2026


What

Adds LoRA support to the pure-MLX inference path (optimized/mlx), which had
none; LoRA was previously available only via the PyTorch Gradio path. New
--lora and --lora-strength flags on sa3_mlx.py merge one or more adapters into
the DiT weights at load time: no runtime parametrization, no per-step cost, and
--lora-strength 0 is a bit-exact bypass.
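
For readers unfamiliar with merge-at-load, here is a minimal sketch of the idea for the plain lora type only. It folds the low-rank update into the base weight as W + strength * (alpha / r) * (B @ A) and returns the base object untouched at strength 0. The helper names, the alpha / r scaling, and the nested-list matrices are assumptions made for illustration; they are not taken from lora_merge.py.

```python
def matmul(b, a):
    """Multiply b (out x r) by a (r x in) using plain nested lists."""
    return [[sum(b_row[k] * a[k][j] for k in range(len(a)))
             for j in range(len(a[0]))] for b_row in b]


def merge_lora(base, lora_a, lora_b, alpha, strength):
    """Return base + strength * (alpha / r) * (lora_b @ lora_a).

    Hypothetical helper: at strength 0 the base object is returned as-is,
    so no arithmetic touches the weights and the bypass is bit-exact.
    """
    if strength == 0:
        return base
    rank = len(lora_a)
    scale = strength * alpha / rank
    delta = matmul(lora_b, lora_a)
    return [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(base, delta)]


base = [[1.0, 2.0], [3.0, 4.0]]   # (out=2, in=2)
lora_a = [[1.0, 0.0]]             # (r=1, in=2)
lora_b = [[0.5], [0.0]]           # (out=2, r=1)

merged = merge_lora(base, lora_a, lora_b, alpha=1.0, strength=1.0)
bypass = merge_lora(base, lora_a, lora_b, alpha=1.0, strength=0.0)
```

Because the merged weight replaces the base weight once, the denoising loop runs the same matmuls as an adapter-free model, which is where the "no per-step cost" property comes from.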

Details

  • models/defs/lora_merge.py: merge-at-load for all nine adapter types
    (lora, dora-rows/cols, bora, and the four -xs variants), mirroring
    stable_audio_3/models/lora/model.py's forwards and the
    accumulate-deltas-against-the-original semantics of
    merge_loras_into_base_model. -xs SVD bases are recomputed from the base
    weight (they are not stored in the checkpoint), with the reference's sign
    convention.
  • Reads both conventions: SA3-native (train_lora.py; config in safetensors
    metadata) and HuggingFace PEFT (adapter_model.safetensors +
    adapter_config.json).
  • Wired through both loaders (dit_mlx.py, dit_mlx_medium.py). Handles fused
    to_qkv, the to_local_embed remap, and the Conv1d (out,in,k)↔(out,k,in)
    layout.
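
The Conv1d layout point can be illustrated with a small sketch. PyTorch stores Conv1d weights as (out, in, k), while MLX stores them as (out, k, in), so a delta computed in one layout has to have its last two axes swapped before it is added to a weight in the other. The helper below is hypothetical and uses nested lists instead of arrays; which layout lora_merge.py computes in is not stated here.

```python
def swap_last_two(weight):
    """Swap the last two axes of a 3-D nested list: (out, a, b) -> (out, b, a)."""
    return [[list(col) for col in zip(*plane)] for plane in weight]


torch_layout = [[[1, 2, 3], [4, 5, 6]]]    # (out=1, in=2, k=3)
mlx_layout = swap_last_two(torch_layout)   # (out=1, k=3, in=2)
round_trip = swap_last_two(mlx_layout)     # back to (out=1, in=2, k=3)
```

Applying the swap twice returns the original tensor, which is the property a layout round-trip test would check.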

Trust / safety

Only .safetensors adapters are accepted; pickle .ckpt/.pt/.bin is
refused — this path never calls torch.load.
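
A suffix-based gate along these lines is one way to enforce that boundary before any file is opened. This is a sketch under assumptions: the function name and messages are hypothetical, and although the PR does name a LoraError type, using it for this refusal is a guess.

```python
from pathlib import Path


class LoraError(Exception):
    """Stand-in for the PR's error type; its real definition is not shown here."""


# Extensions that usually indicate pickle-based checkpoints.
PICKLE_SUFFIXES = {".ckpt", ".pt", ".bin"}


def check_adapter_path(path):
    """Return the path if it looks like a safetensors file, else raise LoraError."""
    p = Path(path)
    suffix = p.suffix.lower()
    if suffix in PICKLE_SUFFIXES:
        raise LoraError(f"refusing pickle-based adapter {str(p)!r}; "
                        "convert it to .safetensors first")
    if suffix != ".safetensors":
        raise LoraError(f"unsupported adapter format {suffix!r} for {str(p)!r}")
    return p


accepted = check_adapter_path("adapter_model.safetensors")
```

Checking the extension only decides which files are ever read; the safety property itself comes from parsing with a safetensors reader rather than an unpickler.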

Testing

  • scripts/test_lora_merge.py (pytest- and standalone-runnable, no
    torch/weights needed): all nine types (zero-init→identity + nonzero delta),
    exact reconstruction for standard LoRA and PEFT, Conv1d round-trip, the name
    remap, strength scaling + bit-exact bypass, pickle refusal, base-mismatch
    errors, and the 0-merge warning. All pass.
  • End-to-end: a public PEFT adapter
    (motiftechnologies/stable-audio-3-maqam-lora, medium) merged into the medium
    DiT (168 layers) and produced measurably different audio through the CLI;
    --lora-strength 0 was bit-identical.

Notes

A base-mismatched adapter (e.g. a medium adapter with --dit sm-music) raises a
clear error; non-matching layers are skipped with a warning. optimized/ is
excluded from the repo's ruff config, so these files follow the existing
hand-formatted style there.
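
The error and warning behavior described above might look roughly like the following. This is a hedged sketch, not the PR's code: plan_merge, the shape conventions (A is (r, in), B is (out, r), base is (out, in)), the layer names, and the message texts are all assumptions, and LoraError is redefined locally as a stand-in.

```python
import warnings


class LoraError(Exception):
    """Stand-in for the PR's error type; its real definition is not shown here."""


def plan_merge(base_shapes, adapter_shapes):
    """Return the layer names that can be merged.

    base_shapes maps name -> (out, in); adapter_shapes maps
    name -> ((r, in), (out, r)). Unknown names are skipped with a warning,
    shape mismatches raise LoraError, and an empty result warns.
    """
    merged = []
    for name, (a_shape, b_shape) in adapter_shapes.items():
        if name not in base_shapes:
            warnings.warn(f"skipping {name}: no matching layer in the base model")
            continue
        out_dim, in_dim = base_shapes[name]
        if a_shape[1] != in_dim or b_shape[0] != out_dim or a_shape[0] != b_shape[1]:
            raise LoraError(f"{name}: adapter shapes {a_shape}/{b_shape} do not fit "
                            f"base weight {(out_dim, in_dim)}; wrong base model?")
        merged.append(name)
    if not merged:
        warnings.warn("LoRA merge touched 0 layers")
    return merged


base_shapes = {"blocks.0.to_out": (4, 4)}
ok = plan_merge(base_shapes, {"blocks.0.to_out": ((2, 4), (4, 2))})
```

Raising on a shape mismatch while only warning on an unknown name keeps a wrong-base adapter from failing deep inside an array operation, and the zero-layer warning stops a merge that did nothing from looking like a success.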

brxs and others added 2 commits June 27, 2026 17:18
Merge LoRA adapters into the DiT weights at load time via a new --lora /
--lora-strength flag on sa3_mlx.py — no runtime parametrization, no per-step
cost, and a bit-exact bypass at strength 0.

- models/defs/lora_merge.py: merge logic for all nine adapter types (lora,
  dora-rows/cols, bora, and the four -xs variants), reading both the SA3-native
  (train_lora.py) and PEFT safetensors conventions. -xs SVD bases are recomputed
  from the base weight (matching the PyTorch reference).
- dit_mlx.py / dit_mlx_medium.py: load_dit gains lora_paths / lora_strength.
- sa3_mlx.py: --lora / --lora-strength CLI flags, wired through both loaders.
- Trust: only .safetensors adapters are accepted; pickle .ckpt/.pt is refused
  (this path never calls torch.load).

Validated on the medium DiT with a public PEFT adapter (168 layers merged):
output changes measurably and strength 0 is bit-identical to no LoRA.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…o-op

- scripts/test_lora_merge.py: pytest-compatible unit tests for the merge math
  (all nine adapter types, standard/PEFT exact reconstruction, Conv1d layout
  round-trip, to_local_embed remap, strength scaling + bit-exact bypass, the
  pickle trust boundary, base mismatch, and the 0-merge warning). Runs with or
  without pytest.
- lora_merge: raise a clear LoraError when an adapter's shapes don't fit the
  base weight (wrong base for --dit) instead of a raw numpy error; warn when a
  merge touches 0 layers rather than reporting a silent no-op.
- lora_merge: replace the in-band "\0restore" sentinel key with a
  [delta, restore] accumulator value; guard the degenerate _mag_2d squeeze.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
