Training Module (aether.training)

This module manages the optimization of the VAE parameters to learn the mapping from Audio Spectra to Filter Parameters.

Training Loop (train.py)

The run_training function executes the main training loop using JAX/Flax.

Process

  1. Initialization:
    • Loads the dataset iterator.
    • Initializes the ParametricVAE model and a TrainState with the Adam optimizer.
  2. Epoch Loop:
    • Iterates for a fixed number of epochs (default 100).
    • Runs steps_per_epoch batches.
    • Beta Annealing: Computes a beta value (the KL weight) that ramps up from 0 to 0.001 over the first 100 epochs. Note that the loss function currently uses a fixed beta of 0.001 instead (see Loss Function below).
  3. Step Execution (train_step):
    • Performs the forward pass (Encoder -> Reparameterization -> Decoder).
    • Computes gradients of the combined_loss with respect to model parameters.
    • Updates parameters using the optimizer.
  4. Checkpointing: Saves the model parameters to checkpoints/model.msgpack using Flax serialization.
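
The epoch loop and step execution described above can be sketched in plain JAX. This is a minimal sketch under stated assumptions, not the code in train.py: `beta_schedule` assumes the ramp is linear, `toy_loss` is a hypothetical stand-in for combined_loss on a toy linear model, and the update is plain gradient descent rather than Adam through a Flax TrainState. The sketch also passes the annealed beta into the loss, which the current loss function does not do.

```python
import jax
import jax.numpy as jnp


def beta_schedule(epoch, warmup_epochs=100, beta_max=1e-3):
    """KL weight ramping from 0 to beta_max over warmup_epochs (linear ramp assumed)."""
    return beta_max * min(epoch / warmup_epochs, 1.0)


def toy_loss(params, batch, beta):
    """Hypothetical stand-in for combined_loss: MSE plus a beta-weighted penalty."""
    pred = batch["x"] @ params["w"] + params["b"]
    recon = jnp.mean((pred - batch["y"]) ** 2)
    penalty = jnp.mean(params["w"] ** 2)  # placeholder for the KL term
    return recon + beta * penalty


@jax.jit
def train_step(params, batch, beta, lr=0.1):
    """One step: forward pass, gradients w.r.t. params, parameter update."""
    loss, grads = jax.value_and_grad(toy_loss)(params, batch, beta)
    # Plain gradient descent for brevity; the module itself uses Adam.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


# Toy data: recover a random linear map.
kx, kw = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (64, 8))
batch = {"x": x, "y": x @ jax.random.normal(kw, (8, 4))}
params = {"w": jnp.zeros((8, 4)), "b": jnp.zeros((4,))}

losses = []
for epoch in range(5):
    beta = beta_schedule(epoch)
    for _ in range(10):  # stands in for steps_per_epoch
        params, loss = train_step(params, batch, beta)
    losses.append(float(loss))
```

Because `beta` is passed as an argument rather than closed over, the jitted step is not recompiled each time the schedule changes the value.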

CLI Usage

uv run python -m aether.training.train --data data/ --epochs 100

Recursive Loading

The training script recursively searches for .wav files in the data_dir. You can organize your data into subfolders (e.g., data/EchoThief, data/Guitar) and simply point to the parent directory:

uv run python -m aether.training.train --data_dir data --epochs 500
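
A recursive search of this kind can be sketched with the standard library. How the training script implements it is not shown here, so `find_wav_files` is a hypothetical helper, and the folder names simply mirror the example above.

```python
import tempfile
from pathlib import Path


def find_wav_files(data_dir):
    """Return every .wav file under data_dir, including subfolders, sorted by path."""
    return sorted(p for p in Path(data_dir).rglob("*.wav") if p.is_file())


# Build a throwaway directory tree to show the behavior.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "EchoThief").mkdir()
    (root / "Guitar").mkdir()
    (root / "EchoThief" / "a.wav").touch()
    (root / "Guitar" / "b.wav").touch()
    (root / "notes.txt").touch()  # ignored: not a .wav file
    found = [p.relative_to(root).as_posix() for p in find_wav_files(root)]

print(found)  # ['EchoThief/a.wav', 'Guitar/b.wav']
```

Sorting the paths keeps the file order deterministic across runs. Note that on case-sensitive filesystems this pattern does not match files ending in `.WAV`.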

Loss Function (loss.py)

The combined_loss function drives training: it pushes the VAE to produce filter parameters whose combined response reconstructs the dimensionality-reduced audio spectra.

Components

  1. Differentiable DSP (Decoder Output):
    • The VAE decoder outputs filter parameters: Frequency (f), Q (q), and Gain (g).
    • The loss function builds the complex frequency response of each filter separately from the Z-domain transfer function of a biquad filter, evaluated on the unit circle.
    • It sums these complex responses (implementing the Parallel Bandpass topology) and computes the Total Magnitude Response.
  2. Reconstruction Loss:
    • Computes the Mean Squared Error (MSE) between the Predicted Magnitude Response and the Input Target Spectrum.
    • This forces the VAE to learn filter parameters that approximate the input sound's spectrum.
  3. KL Divergence:
    • Regularizes the latent space z to be close to a standard normal distribution N(0, 1).
    • Weight: 0.001 (Fixed).
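
The three components above can be sketched in JAX as follows. This is a sketch under stated assumptions, not the code in loss.py: the band-pass coefficients follow the widely used RBJ Audio EQ Cookbook form (unit peak gain) scaled by a linear gain g, the 48 kHz sample rate is arbitrary, and the names `bandpass_response`, `parallel_magnitude`, and `combined_loss_sketch` are hypothetical.

```python
import jax
import jax.numpy as jnp


def bandpass_response(f, q, g, freqs, sample_rate):
    """Complex response of one band-pass biquad (RBJ form assumed), scaled by gain g."""
    w0 = 2.0 * jnp.pi * f / sample_rate
    alpha = jnp.sin(w0) / (2.0 * q)
    z_inv = jnp.exp(-1j * 2.0 * jnp.pi * freqs / sample_rate)  # z^-1 on the unit circle
    num = alpha - alpha * z_inv**2  # b0 = alpha, b1 = 0, b2 = -alpha
    den = (1.0 + alpha) - 2.0 * jnp.cos(w0) * z_inv + (1.0 - alpha) * z_inv**2
    return g * num / den


def parallel_magnitude(f, q, g, freqs, sample_rate):
    """Sum the complex responses of all bands, then take the magnitude."""
    per_band = jax.vmap(bandpass_response, in_axes=(0, 0, 0, None, None))(
        f, q, g, freqs, sample_rate
    )
    return jnp.abs(jnp.sum(per_band, axis=0))


def combined_loss_sketch(f, q, g, mu, logvar, target, freqs, sample_rate, beta=1e-3):
    """MSE reconstruction loss plus beta-weighted KL divergence to N(0, 1)."""
    recon = jnp.mean((parallel_magnitude(f, q, g, freqs, sample_rate) - target) ** 2)
    kl = -0.5 * jnp.mean(jnp.sum(1.0 + logvar - mu**2 - jnp.exp(logvar), axis=-1))
    return recon + beta * kl


sample_rate = 48000.0  # assumed for illustration
freqs = jnp.linspace(20.0, 20000.0, 256)
f = jnp.array([200.0, 1000.0, 5000.0])
q = jnp.array([2.0, 4.0, 1.0])
g = jnp.array([1.0, 0.5, 0.8])
mu, logvar = jnp.zeros(4), jnp.zeros(4)

target = parallel_magnitude(f, q, g, freqs, sample_rate)
loss = combined_loss_sketch(f, q, g, mu, logvar, target, freqs, sample_rate)

# Gradients flow back to the filter parameters through the complex response.
grads = jax.grad(combined_loss_sketch, argnums=(0, 1, 2))(
    f * 1.1, q, g, mu, logvar, target, freqs, sample_rate
)
```

Summing the complex responses before taking the magnitude matters: neighboring bands can reinforce or cancel each other depending on phase, which a sum of per-band magnitudes would miss.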

Visualization

Below is an example of the reconstruction performance. The blue line is the target spectrum, and the orange line is the VAE's reconstruction using parallel bandpass filters.

Figure: Reconstruction (target spectrum in blue, VAE reconstruction in orange)

