Skip to content

Add MLX LoRA training and audio encoding primitives#51

Open
betweentwomidnights wants to merge 4 commits into
Stability-AI:mainfrom
betweentwomidnights:feature/mlx-lora-training
Open

Add MLX LoRA training and audio encoding primitives#51
betweentwomidnights wants to merge 4 commits into
Stability-AI:mainfrom
betweentwomidnights:feature/mlx-lora-training

Conversation

@betweentwomidnights

@betweentwomidnights betweentwomidnights commented Jun 15, 2026

Copy link
Copy Markdown

Summary

This adds pure-MLX primitives for training and applying Stable Audio 3 LoRA-family adapters within the existing optimized/mlx runtime.
These mirror the adapter types already supported by the official PyTorch implementation.

  • Trainable MLX Linear and Conv1d adapters
  • LoRA, DoRA rows/columns, BoRA, and their XS variants
  • SA3-compatible safetensors loading and saving
  • Canonical checkpoint-name mapping for the optimized MLX model layout
  • PyTorch-to-MLX Conv1d weight-layout conversion
  • Rectified-flow loss, training timestep samplers, and distribution shifts
  • SAME waveform-to-latent encoding, including codec alignment, overlapping
    chunking, and latent validity masks
  • Fixed-strength inference application for one or more checkpoints

The optimized runtime remains MLX/NumPy-only. PyTorch and the separate safetensors package are not added as runtime dependencies.

Motivation

These primitives were developed while adding Apple Silicon LoRA training to the open-source gary4local application. This PR adapts that downstream work to the official optimized model definitions and checkpoint conventions rather than importing the application-specific training pipeline.

The longer downstream implementation and adaptation record is available here for optional context:

https://github.com/betweentwomidnights/gary-localhost-installer-mac/blob/main/docs/sa3/STABLE_AUDIO_3_UPSTREAM_PR_NOTES.md

Checkpoint compatibility

Saved checkpoints use the existing parametrization-style tensor keys and lora_config metadata. Tests load MLX-generated checkpoints through the official PyTorch loader and compare every supported adapter variant directly against the official PyTorch adapter math.

The implementation also handles two optimized-runtime compatibility details:

  • to_local_embed.0/2 checkpoint names map to the optimized
    to_local_embed.seq.0/2 modules.
  • PyTorch Conv1d weights use [out, in, kernel], while MLX uses
    [out, kernel, in].

Audio pre-encoding

The MLX training loss operates on SAME latents. encode_audio() supplies the required waveform-to-latent boundary through the existing optimized SAME-S and SAME-L encoders. It pads waveforms to the codec's required alignment, supports overlapping chunked encoding for longer clips, and returns a ceiling-scaled validity mask for padded batches.

encode_audio() begins after audio has been decoded and resampled. It accepts 44.1 kHz stereo waveforms; dataset-level concerns such as file discovery, captions, and saving encoded latents remain with the calling training workflow.

Inference scope

apply_lora_checkpoint() and apply_lora_checkpoints() materialize requested strengths into a loaded model's in-memory weights. They do not modify the base checkpoint on disk.

This is intended for fixed-strength, load-time inference. It supports multiple checkpoints with independent strengths, but it is deliberately not presented as a persistent slider API: applying a new strength repeatedly to the same model instance would compound adapted weights.

Because Stable Audio 3 LoRAs are especially expressive when blended together as a set of adjustable strengths, as the existing Gradio UI supports for PyTorch users, a follow-up PR is planned around a reusable MLX generation
pipeline and a base-preserving adapter session. That layer will support persistent inference, multiple simultaneously loaded LoRAs with independently adjustable strengths, and MLX integration for the CLI/Gradio interface.

Cross-backend listening tests

A matched ear test applied the same downstream DoRA checkpoint in this optimized runtime and in Gary4local using the same prompt, seed, sampling steps, and adapter strength. The outputs were audibly different, but both had the expected adapter character and comparable quality; subjectively, they sounded like two pieces from the same song.

Waveform identity is not expected between the two pipelines. For the tested duration, the optimized runtime generates 87 latent positions while gary4local aligns to 88 and supplies a padding mask. The pipelines also consume MLX random state in a different order. A shared numeric seed only produces identical noise when random calls and tensor shapes also match, and the extra latent position changes both the sampled tensor shape and model attention context.

Deliberately out of scope

  • Dataset discovery, file decoding/resampling, and latent-cache persistence
  • A complete optimizer/training command
  • Changes to the existing CLI or Gradio interface
  • Application-specific defaults, job management, or UI

Those pieces can consume these primitives without making this initial API review depend on a larger interface refactor.

Validation

  • uvx ruff check .
  • uvx ruff format --check .
  • 26 passed across the focused MLX audio-encoding, LoRA, and training tests
  • Existing official PyTorch LoRA test passes
  • Real optimized medium-DiT forward, backward, and AdamW smoke step
  • Real SAME-L encoding of a 16-second stereo clip exactly matches the existing
    optimized inference path; chunked encoding differs by at most 5.5e-4
  • A 500-step rank-4 DoRA run using the proposed encoder, adapter, timestep,
    loss, optimizer, and checkpoint primitives completed with finite loss on
    every step; mean loss moved from 0.792 over the first 50 steps to 0.495
    over the final 50
  • The resulting 108-tensor checkpoint loads through the official PyTorch
    checkpoint reader and applies to all 36 expected optimized ARC targets with
    no missing or skipped layers
  • In a listening test of that newly trained checkpoint, a bell-arpeggio source
    trained with the garybell trigger audibly introduced its bell character
    into a lo-fi hip-hop generation using the same trigger. This provides a
    perceptual end-to-end check of pre-encoding, prompt conditioning, training,
    checkpoint loading, and inference application.
  • Separately, an existing 500-step DoRA checkpoint produced by Gary4local
    applied to all 36 expected optimized medium layers with no missing or
    skipped targets, providing a backwards-compatibility check for downstream
    checkpoints created before this upstream adaptation.

Together, these checks cover both directions of the intended boundary:
training a new adapter with the proposed official MLX primitives, and loading an existing downstream adapter into the official optimized runtime.

Add pure-MLX LoRA, DoRA, BoRA, and XS adapter injection, checkpoint interoperability, fixed-strength inference support, waveform-to-SAME-latent encoding, SA3 timestep sampling, distribution shifting, rectified-flow loss, and focused parity tests.

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds MLX/NumPy-only primitives to the optimized/mlx runtime for (1) training and applying Stable Audio 3 LoRA-family adapters, (2) training timestep sampling + rectified-flow loss, and (3) SAME waveform→latent pre-encoding (including padding/masking and optional chunked encoding). This brings the optimized MLX backend closer to feature parity with the official PyTorch LoRA ecosystem while keeping runtime dependencies minimal.

Changes:

  • Introduces MLX LoRA/DoRA/BoRA (+ XS) trainable layer wrappers, safetensors checkpoint save/load, and fixed-strength in-place application APIs.
  • Adds training helpers: timestep samplers, distribution shifting, and rectified-flow velocity loss with optional masking.
  • Adds SAME-aligned audio encoding utilities (patched pretransform, codec alignment padding, chunked encoding, and validity masks), plus new focused tests and README documentation.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.

Show a summary per file
File Description
optimized/mlx/models/defs/lora.py Implements MLX LoRA-family adapters, checkpoint IO, and in-place fixed-strength application.
optimized/mlx/models/defs/training.py Adds timestep sampling/shift utilities and rectified-flow loss for MLX training.
optimized/mlx/models/defs/audio_encoding.py Adds SAME waveform→latent encoding utilities with padding masks and optional chunked stitching.
optimized/mlx/README.md Documents the new LoRA training + audio encoding primitives and intended usage.
tests/test_mlx_lora.py Validates adapter injection, checkpoint round-trips, name/layout mapping, and math parity vs PyTorch.
tests/test_mlx_training.py Tests timestep sampler distribution, default shift behavior, and rectified-flow loss masking/targets.
tests/test_mlx_audio_encoding.py Tests patched pretransform layout, padding+masking, chunked-vs-unchunked equivalence, and input contract validation.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread optimized/mlx/models/defs/audio_encoding.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants