Add MLX LoRA training and audio encoding primitives #51
Open
betweentwomidnights wants to merge 4 commits into
Add pure-MLX LoRA, DoRA, BoRA, and XS adapter injection, checkpoint interoperability, fixed-strength inference support, waveform-to-SAME-latent encoding, SA3 timestep sampling, distribution shifting, rectified-flow loss, and focused parity tests.
Pull request overview
Adds MLX/NumPy-only primitives to the optimized/mlx runtime for (1) training and applying Stable Audio 3 LoRA-family adapters, (2) training timestep sampling + rectified-flow loss, and (3) SAME waveform→latent pre-encoding (including padding/masking and optional chunked encoding). This brings the optimized MLX backend closer to feature parity with the official PyTorch LoRA ecosystem while keeping runtime dependencies minimal.
Changes:
- Introduces MLX LoRA/DoRA/BoRA (+ XS) trainable layer wrappers, safetensors checkpoint save/load, and fixed-strength in-place application APIs.
- Adds training helpers: timestep samplers, distribution shifting, and rectified-flow velocity loss with optional masking.
- Adds SAME-aligned audio encoding utilities (patched pretransform, codec alignment padding, chunked encoding, and validity masks), plus new focused tests and README documentation.
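For readers unfamiliar with the loss named in the list above, here is a minimal pure-Python sketch of a rectified-flow velocity loss with a padding mask. It is not the code in `training.py`. The interpolation convention (`x_t = (1 - t) * x0 + t * noise`, target `noise - x0`), the function names, and the normalization by the number of valid positions are all assumptions chosen for illustration.

```python
import random


def rectified_flow_pair(x0, noise, t):
    """Return (x_t, target velocity) for one latent sequence.

    Assumed convention: x_t = (1 - t) * x0 + t * noise and v = noise - x0.
    """
    x_t = [(1.0 - t) * a + t * b for a, b in zip(x0, noise)]
    target = [b - a for a, b in zip(x0, noise)]
    return x_t, target


def masked_mse(pred, target, mask):
    """Mean squared error over positions where mask is 1; padding is ignored."""
    total = sum(m * (p - q) ** 2 for p, q, m in zip(pred, target, mask))
    return total / max(sum(mask), 1)


rng = random.Random(0)
x0 = [rng.gauss(0.0, 1.0) for _ in range(6)]     # clean latents, last 2 are padding
noise = [rng.gauss(0.0, 1.0) for _ in range(6)]
mask = [1, 1, 1, 1, 0, 0]

x_t, target = rectified_flow_pair(x0, noise, t=0.3)

# A prediction that is wrong only at padded positions still scores zero.
bad_padding = target[:4] + [99.0, -99.0]
ignored = masked_mse(bad_padding, target, mask)

# Without the mask, the same prediction is heavily penalized.
unmasked = masked_mse(bad_padding, target, [1] * 6)
```

The mask matters for padded batches: without it, the model would be trained to predict velocities for latent positions that contain no real audio.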
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| optimized/mlx/models/defs/lora.py | Implements MLX LoRA-family adapters, checkpoint IO, and in-place fixed-strength application. |
| optimized/mlx/models/defs/training.py | Adds timestep sampling/shift utilities and rectified-flow loss for MLX training. |
| optimized/mlx/models/defs/audio_encoding.py | Adds SAME waveform→latent encoding utilities with padding masks and optional chunked stitching. |
| optimized/mlx/README.md | Documents the new LoRA training + audio encoding primitives and intended usage. |
| tests/test_mlx_lora.py | Validates adapter injection, checkpoint round-trips, name/layout mapping, and math parity vs PyTorch. |
| tests/test_mlx_training.py | Tests timestep sampler distribution, default shift behavior, and rectified-flow loss masking/targets. |
| tests/test_mlx_audio_encoding.py | Tests patched pretransform layout, padding+masking, chunked-vs-unchunked equivalence, and input contract validation. |
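The padding and masking behaviour covered by the audio-encoding tests can be pictured with a small sketch. The PR description says waveforms are padded to the codec's required alignment and that a ceiling-scaled validity mask is returned for padded batches. Everything else here is an assumption for illustration: the helper names, the toy alignment of 4 samples per latent, and the simplification that the alignment equals the number of samples per latent.

```python
def ceil_div(n, d):
    """Integer ceiling division for non-negative n and positive d."""
    return (n + d - 1) // d


def pad_and_mask(num_samples, batch_samples, samples_per_latent):
    """Return (padded_length, validity mask) for one clip in a padded batch.

    num_samples:        real audio samples in this clip
    batch_samples:      length of the longest clip in the batch
    samples_per_latent: hypothetical codec alignment (samples per latent frame)
    """
    # Round the batch length up to the next multiple of the alignment.
    padded_length = ceil_div(batch_samples, samples_per_latent) * samples_per_latent
    num_latents = padded_length // samples_per_latent
    # Ceiling scaling: a latent is valid if it covers any real audio at all.
    valid_latents = ceil_div(num_samples, samples_per_latent)
    mask = [1] * valid_latents + [0] * (num_latents - valid_latents)
    return padded_length, mask


padded_length, mask = pad_and_mask(num_samples=10, batch_samples=14, samples_per_latent=4)
print(padded_length, mask)
```

With these toy numbers the batch is padded from 14 to 16 samples, giving four latent positions. The 10-sample clip partially fills the third latent, so ceiling scaling marks three positions as valid and one as padding.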
Summary
This adds pure-MLX primitives for training and applying Stable Audio 3 LoRA-family adapters within the existing `optimized/mlx` runtime. These mirror the adapter types already supported by the official PyTorch implementation.
- [...] `Linear` and `Conv1d` adapters
- [...] `Conv1d` weight-layout conversion
- [...] chunking, and latent validity masks
The optimized runtime remains MLX/NumPy-only. PyTorch and the separate `safetensors` package are not added as runtime dependencies.
Motivation
These primitives were developed while adding Apple Silicon LoRA training to the open-source gary4local application. This PR adapts that downstream work to the official optimized model definitions and checkpoint conventions rather than importing the application-specific training pipeline.
The longer downstream implementation and adaptation record is available here for optional context:
https://github.com/betweentwomidnights/gary-localhost-installer-mac/blob/main/docs/sa3/STABLE_AUDIO_3_UPSTREAM_PR_NOTES.md
Checkpoint compatibility
Saved checkpoints use the existing parametrization-style tensor keys and `lora_config` metadata. Tests load MLX-generated checkpoints through the official PyTorch loader and compare every supported adapter variant directly against the official PyTorch adapter math.
The implementation also handles two optimized-runtime compatibility details:
- `to_local_embed.0/2` checkpoint names map to the optimized `to_local_embed.seq.0/2` modules.
- `Conv1d` weights use `[out, in, kernel]`, while MLX uses `[out, kernel, in]`.
Audio pre-encoding
The MLX training loss operates on SAME latents. `encode_audio()` supplies the required waveform-to-latent boundary through the existing optimized SAME-S and SAME-L encoders. It pads waveforms to the codec's required alignment, supports overlapping chunked encoding for longer clips, and returns a ceiling-scaled validity mask for padded batches.
`encode_audio()` begins after audio has been decoded and resampled. It accepts 44.1 kHz stereo waveforms; dataset-level concerns such as file discovery, captions, and saving encoded latents remain with the calling training workflow.
Inference scope
`apply_lora_checkpoint()` and `apply_lora_checkpoints()` materialize requested strengths into a loaded model's in-memory weights. They do not modify the base checkpoint on disk.
This is intended for fixed-strength, load-time inference. It supports multiple checkpoints with independent strengths, but it is deliberately not presented as a persistent slider API: applying a new strength repeatedly to the same model instance would compound adapted weights.
Because Stable Audio 3 LoRAs are especially expressive when blended together as a set of adjustable strengths, as the existing Gradio UI supports for PyTorch users, a follow-up PR is planned around a reusable MLX generation pipeline and a base-preserving adapter session. That layer will support persistent inference, multiple simultaneously loaded LoRAs with independently adjustable strengths, and MLX integration for the CLI/Gradio interface.
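To make the compounding concern concrete, the sketch below folds a plain LoRA delta into a tiny weight matrix twice and compares the result with recomputing from an untouched base. It is an illustration only, not the `apply_lora_checkpoint()` implementation: the helper names and matrices are hypothetical, it uses nested lists instead of MLX arrays, and it covers only the additive LoRA case (DoRA, BoRA, and XS use different math).

```python
def lora_delta(a, b, scale):
    """Return scale * (b @ a) for small nested-list matrices."""
    inner, cols = len(a), len(a[0])
    return [
        [scale * sum(b_row[k] * a[k][c] for k in range(inner)) for c in range(cols)]
        for b_row in b
    ]


def add(w, d):
    """Element-wise sum of two equally shaped nested-list matrices."""
    return [[x + y for x, y in zip(w_row, d_row)] for w_row, d_row in zip(w, d)]


base = [[1.0, 0.0], [0.0, 1.0]]   # hypothetical base weight
a = [[1.0, 2.0]]                  # rank-1 down projection, shape [1, 2]
b = [[0.5], [0.25]]               # rank-1 up projection, shape [2, 1]

# In-place style: each call folds a delta into whatever the weights are now.
w = add(base, lora_delta(a, b, 1.0))   # apply strength 1.0
w = add(w, lora_delta(a, b, 0.5))      # "set" strength 0.5, but the deltas stack
compounded = w

# Base-preserving style: always recompute from the untouched base weights.
session = add(base, lora_delta(a, b, 0.5))

print(compounded)
print(session)
```

In this toy case the second in-place call leaves the model at an effective strength of 1.5 rather than 0.5, which is the behaviour a base-preserving adapter session is meant to avoid.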
Cross-backend listening tests
A matched listening test applied the same downstream DoRA checkpoint in this optimized runtime and in gary4local using the same prompt, seed, sampling steps, and adapter strength. The outputs were audibly different, but both had the expected adapter character and comparable quality; subjectively, they sounded like two pieces from the same song.
Waveform identity is not expected between the two pipelines. For the tested duration, the optimized runtime generates 87 latent positions while gary4local aligns to 88 and supplies a padding mask. The pipelines also consume MLX random state in a different order. A shared numeric seed only produces identical noise when random calls and tensor shapes also match, and the extra latent position changes both the sampled tensor shape and model attention context.
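The seed behaviour described above can be reproduced in miniature with Python's standard `random` module. This is a stand-in, not MLX: MLX's random API works differently, and the `[channels, length]` grid and channel count are assumptions made for the example. It shows only the general point that a shared seed yields matching noise when the sequence of random calls and the tensor shape also match.

```python
import random


def noise_grid(seed, channels, length):
    """Fill a [channels, length] grid row by row from one seeded generator."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(length)] for _ in range(channels)]


a = noise_grid(seed=1234, channels=2, length=87)
b = noise_grid(seed=1234, channels=2, length=88)

# With this sequential generator the first row still agrees for 87 positions...
same_first_row = a[0] == b[0][:87]
# ...but every later row is offset by the extra draw, so the values diverge.
same_second_row = a[1] == b[1][:87]

# Consuming the generator before sampling also shifts every value.
rng = random.Random(1234)
_ = rng.gauss(0.0, 1.0)                       # one extra, earlier random call
shifted = [rng.gauss(0.0, 1.0) for _ in range(87)]
same_after_extra_call = shifted == a[0]

print(same_first_row, same_second_row, same_after_extra_call)
```

Either difference alone, one extra latent position or one extra random call, is enough to give the two pipelines different starting noise from the same numeric seed.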
Deliberately out of scope
Those pieces can consume these primitives without making this initial API review depend on a larger interface refactor.
Validation
- `uvx ruff check .`
- `uvx ruff format --check .`
- `26 passed` across the focused MLX audio-encoding, LoRA, and training tests
- [...] optimized inference path; chunked encoding differs by at most `5.5e-4`
- [...] loss, optimizer, and checkpoint primitives completed with finite loss on every step; mean loss moved from `0.792` over the first 50 steps to `0.495` over the final 50
- [...] checkpoint reader and applies to all 36 expected optimized ARC targets with no missing or skipped layers
- [...] trained with the `garybell` trigger audibly introduced its bell character into a lo-fi hip-hop generation using the same trigger. This provides a perceptual end-to-end check of pre-encoding, prompt conditioning, training, checkpoint loading, and inference application.
- [...] applied to all 36 expected optimized medium layers with no missing or skipped targets, providing a backwards-compatibility check for downstream checkpoints created before this upstream adaptation.
Together, these checks cover both directions of the intended boundary:
training a new adapter with the proposed official MLX primitives, and loading an existing downstream adapter into the official optimized runtime.
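The loading direction depends on the two compatibility details listed under Checkpoint compatibility: the `to_local_embed.0/2` to `to_local_embed.seq.0/2` name mapping and the `Conv1d` layout change from `[out, in, kernel]` to `[out, kernel, in]`. The sketch below shows one way such conversions could look. The function names and the example key are hypothetical, and nested lists stand in for real tensors. The actual implementation in `lora.py` may differ.

```python
import re


def map_checkpoint_name(name):
    """Map `to_local_embed.<idx>` to `to_local_embed.seq.<idx>` for idx 0 or 2."""
    return re.sub(r"\bto_local_embed\.([02])(?=\.|$)", r"to_local_embed.seq.\1", name)


def conv1d_to_mlx_layout(weight):
    """Reorder a nested-list weight from [out, in, kernel] to [out, kernel, in]."""
    return [
        [[in_row[k] for in_row in out_block] for k in range(len(out_block[0]))]
        for out_block in weight
    ]


# Hypothetical key: only the `to_local_embed.0/2` segment comes from the PR text.
key = map_checkpoint_name("model.to_local_embed.2.weight")

# Toy weight with shape [out=1, in=2, kernel=3].
w = [[[1, 2, 3], [4, 5, 6]]]
w_mlx = conv1d_to_mlx_layout(w)   # shape [out=1, kernel=3, in=2]

print(key)
print(w_mlx)
```

The transpose swaps only the last two axes, so applying the same reordering again restores the original layout, which is convenient for round-trip checks between the two backends.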