Add LoRA inference support to the MLX path #57
Open
brxs wants to merge 2 commits into
Merge LoRA adapters into the DiT weights at load time via a new --lora / --lora-strength flag on sa3_mlx.py: no runtime parametrization, no per-step cost, and a bit-exact bypass at strength 0.

- models/defs/lora_merge.py: merge logic for all nine adapter types (lora, dora-rows/cols, bora, and the four -xs variants), reading both the SA3-native (train_lora.py) and PEFT safetensors conventions. -xs SVD bases are recomputed from the base weight (matching the PyTorch reference).
- dit_mlx.py / dit_mlx_medium.py: load_dit gains lora_paths / lora_strength.
- sa3_mlx.py: --lora / --lora-strength CLI flags, wired through both loaders.
- Trust: only .safetensors adapters are accepted; pickle .ckpt/.pt is refused (this path never calls torch.load).

Validated on the medium DiT with a public PEFT adapter (168 layers merged): output changes measurably and strength 0 is bit-identical to no LoRA.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…o-op

- scripts/test_lora_merge.py: pytest-compatible unit tests for the merge math (all nine adapter types, standard/PEFT exact reconstruction, Conv1d layout round-trip, to_local_embed remap, strength scaling + bit-exact bypass, the pickle trust boundary, base mismatch, and the 0-merge warning). Runs with or without pytest.
- lora_merge: raise a clear LoraError when an adapter's shapes don't fit the base weight (wrong base for --dit) instead of a raw numpy error; warn when a merge touches 0 layers rather than reporting a silent no-op.
- lora_merge: replace the in-band "\0restore" sentinel key with a [delta, restore] accumulator value; guard the degenerate _mag_2d squeeze.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
What
Adds LoRA support to the pure-MLX inference path (optimized/mlx), which had none; LoRA was only available via the PyTorch Gradio path. A new --lora / --lora-strength flag on sa3_mlx.py merges one or more adapters into the DiT weights at load time: no runtime parametrization, no per-step cost, and --lora-strength 0 is a bit-exact bypass.
Details
- models/defs/lora_merge.py: merge-at-load for all nine adapter types (lora, dora-rows/cols, bora, and the four -xs variants), mirroring the forwards in stable_audio_3/models/lora/model.py and the accumulate-deltas-against-the-original semantics of merge_loras_into_base_model. The -xs SVD bases are recomputed from the base weight (they are not stored in the checkpoint), using the reference's sign convention.
- Reads both adapter conventions: SA3-native (train_lora.py; config in safetensors metadata) and HuggingFace PEFT (adapter_model.safetensors + adapter_config.json).
- Wired through both loaders (dit_mlx.py, dit_mlx_medium.py). Handles fused to_qkv, the to_local_embed remap, and the Conv1d (out,in,k)↔(out,k,in) layout.
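For readers unfamiliar with merge-at-load, here is a minimal sketch of the idea for the plain lora type only. It is not the code in lora_merge.py: the helper name merge_lora, the (down, up, scale) tuple layout, and the use of NumPy are assumptions made for illustration, and the DoRA/BoRA/-xs variants need more than this.

```python
import numpy as np


def merge_lora(base, adapters, strength=1.0):
    """Fold standard LoRA deltas into one base weight (hypothetical helper).

    base:     (out, in) weight matrix
    adapters: list of (down, up, scale) with down (rank, in), up (out, rank)
    strength: global multiplier, playing the role of --lora-strength
    """
    if strength == 0 or not adapters:
        # Bypass: hand back the original array untouched, so the result is
        # bit-identical to loading without any adapter.
        return base

    # Sum every adapter's delta before touching the base weight, so each
    # delta is applied against the original rather than a partial merge.
    delta = np.zeros(base.shape, dtype=np.float32)
    for down, up, scale in adapters:
        delta += scale * (up.astype(np.float32) @ down.astype(np.float32))

    merged = base.astype(np.float32) + strength * delta
    return merged.astype(base.dtype)


# Toy shapes: a 4x6 weight with one rank-2 adapter.
rng = np.random.default_rng(0)
base = rng.standard_normal((4, 6)).astype(np.float32)
down = rng.standard_normal((2, 6)).astype(np.float32)
up = rng.standard_normal((4, 2)).astype(np.float32)

merged = merge_lora(base, [(down, up, 0.5)], strength=1.0)
bypassed = merge_lora(base, [(down, up, 0.5)], strength=0.0)
zero_init = merge_lora(base, [(down, np.zeros_like(up), 0.5)])
```

Because the delta is folded into the weight once, the forward pass afterwards is an ordinary matmul, which is where the "no per-step cost" property comes from. Returning the original array at strength 0, rather than adding a zero delta, is one simple way to make the bypass bit-exact regardless of dtype round-trips.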
Trust / safety
Only .safetensors adapters are accepted; pickle .ckpt/.pt/.bin is refused (this path never calls torch.load).
Testing
- scripts/test_lora_merge.py (pytest- and standalone-runnable, no torch/weights needed): all nine types (zero-init→identity + nonzero delta), exact reconstruction for standard LoRA and PEFT, Conv1d round-trip, the name remap, strength scaling + bit-exact bypass, pickle refusal, base-mismatch errors, and the 0-merge warning. All pass.
- A public PEFT adapter (motiftechnologies/stable-audio-3-maqam-lora, medium) merged into the medium DiT (168 layers) and produced measurably different audio through the CLI; --lora-strength 0 was bit-identical.
Notes
- A base-mismatched adapter (e.g. a medium adapter with --dit sm-music) raises a clear error; non-matching layers are skipped with a warning.
- optimized/ is excluded from the repo's ruff config, so these files follow the existing hand-style there.
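The two guard rails described under Trust / safety and Notes (refusing pickle formats, and raising a clear error on a base mismatch) could be sketched roughly as follows. LoraError is the name used in this PR, but the class body and the helpers check_adapter_path and check_lora_shapes are hypothetical stand-ins, not the functions in lora_merge.py.

```python
from pathlib import Path


class LoraError(Exception):
    """Stand-in for the PR's LoraError; the real class may differ."""


def check_adapter_path(path):
    """Refuse anything that is not a .safetensors file, before any read."""
    suffix = Path(path).suffix.lower()
    if suffix != ".safetensors":
        raise LoraError(
            f"refusing {path!r}: only .safetensors adapters are accepted "
            f"(got {suffix or 'no extension'})"
        )
    return Path(path)


def check_lora_shapes(layer, base_shape, down_shape, up_shape):
    """Raise LoraError when LoRA factors cannot fit an (out, in) base weight."""
    out_features, in_features = base_shape
    rank, down_in = down_shape
    up_out, up_rank = up_shape
    if (up_out, down_in) != (out_features, in_features) or rank != up_rank:
        raise LoraError(
            f"{layer}: adapter shapes down={down_shape} up={up_shape} do not "
            f"fit base weight {base_shape}; was the adapter trained for a "
            f"different --dit?"
        )
```

Checking the extension before opening the file means a pickle checkpoint is rejected without ever being deserialized, and checking shapes up front turns what would otherwise be a raw array-shape error into a message that points at the likely cause.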