This module manages the optimization of the VAE parameters to learn the mapping from Audio Spectra to Filter Parameters.
The `run_training` function executes the main training loop using JAX/Flax.

- Initialization:
  - Loads the dataset iterator.
  - Initializes the `ParametricVAE` model and `TrainState` with the Adam optimizer.
- Epoch Loop:
  - Iterates for a fixed number of epochs (default 100).
  - Runs `steps_per_epoch` batches.
  - Beta Annealing: Calculates a `beta` value (KL weight) that ramps up from 0 to 0.001 over the first 100 epochs (Note: currently the loss function uses a fixed beta of 0.001, see below).
- Step Execution (`train_step`):
  - Performs the forward pass (Encoder -> Reparameterization -> Decoder).
  - Computes gradients of the `combined_loss` with respect to model parameters.
  - Updates parameters using the optimizer.
- Checkpointing: Saves the model parameters to `checkpoints/model.msgpack` using Flax serialization.

```bash
uv run python -m aether.training.train --data data/ --epochs 100
```
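
The beta ramp and a single gradient step can be sketched in plain JAX. This is a minimal sketch, not the module's code: `beta_schedule`, `toy_loss`, the parameter names, and the learning rate are hypothetical, and the linear ramp shape is assumed, since the text only gives its endpoints. Plain gradient descent on a toy linear model stands in for `ParametricVAE`, `TrainState`, and Adam so that the example stays self-contained.

```python
import jax
import jax.numpy as jnp

BETA_MAX = 1e-3       # final KL weight stated in the text
WARMUP_EPOCHS = 100   # ramp length stated in the text
LEARNING_RATE = 1e-2  # hypothetical value, chosen only for this sketch


def beta_schedule(epoch, beta_max=BETA_MAX, warmup_epochs=WARMUP_EPOCHS):
    """Assumed linear ramp from 0 to beta_max, constant afterwards."""
    frac = jnp.clip(epoch / warmup_epochs, 0.0, 1.0)
    return beta_max * frac


def toy_loss(params, batch, beta, rng):
    """Hypothetical stand-in for combined_loss on a linear toy model."""
    mean = batch @ params["w_mean"]                  # encoder: latent mean
    logvar = batch @ params["w_logvar"]              # encoder: latent log-variance
    eps = jax.random.normal(rng, mean.shape)
    z = mean + jnp.exp(0.5 * logvar) * eps           # reparameterization
    recon = z @ params["w_dec"]                      # toy decoder
    mse = jnp.mean((recon - batch) ** 2)
    kl = -0.5 * jnp.mean(
        jnp.sum(1.0 + logvar - mean**2 - jnp.exp(logvar), axis=-1)
    )
    return mse + beta * kl


@jax.jit
def train_step(params, batch, beta, rng):
    """Forward pass, gradients, and a plain gradient-descent update."""
    loss, grads = jax.value_and_grad(toy_loss)(params, batch, beta, rng)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss


# Tiny driver: 3 epochs, one random batch of 8 spectra with 16 bins each.
key = jax.random.PRNGKey(0)
k_mean, k_logvar, k_dec, k_data, key = jax.random.split(key, 5)
params = {
    "w_mean": 0.1 * jax.random.normal(k_mean, (16, 4)),
    "w_logvar": 0.1 * jax.random.normal(k_logvar, (16, 4)),
    "w_dec": 0.1 * jax.random.normal(k_dec, (4, 16)),
}
batch = jax.random.normal(k_data, (8, 16))

losses = []
for epoch in range(3):
    key, step_key = jax.random.split(key)
    params, loss = train_step(params, batch, beta_schedule(epoch), step_key)
    losses.append(loss)
```

Passing `beta` as an argument to the jitted step, rather than closing over it, lets the schedule change every epoch without triggering recompilation.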
### Recursive Loading
The training script recursively searches for `.wav` files in the `data_dir`. You can organize your data into subfolders (e.g., `data/EchoThief`, `data/Guitar`) and simply point to the parent directory:
```bash
uv run python -m aether.training.train --data_dir data --epochs 500
```

The `combined_loss` function drives training: it pushes the VAE to produce filter parameters that reconstruct dimensionality-reduced audio spectra.

- Differentiable DSP (Decoder Output):
  - The VAE decoder outputs filter parameters: Frequency (`f`), Q (`q`), and Gain (`g`).
  - The loss function constructs the separate Complex Frequency Responses of the filters using the Z-transform formula for Biquad filters.
  - It sums these complex responses (implementing the Parallel Bandpass topology) and computes the Total Magnitude Response.
- Reconstruction Loss:
  - Computes the Mean Squared Error (MSE) between the Predicted Magnitude Response and the Input Target Spectrum.
  - This forces the VAE to learn filter parameters that approximate the input sound's spectrum.
- KL Divergence:
  - Regularizes the latent space `z` to be close to a standard normal distribution `N(0, 1)`.
  - Weight: 0.001 (Fixed).
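
The three loss terms can be sketched in JAX as follows. This is an illustration under stated assumptions, not the project's `combined_loss`. The text only specifies biquad responses summed in parallel, an MSE term, and a KL weight of 0.001. The RBJ "Audio EQ Cookbook" band-pass coefficients (constant 0 dB peak gain), the linear gain scaling, the 48 kHz sample rate, the `(mean, logvar)` encoder outputs, and all function names below are assumptions.

```python
import jax
import jax.numpy as jnp

SAMPLE_RATE = 48_000.0  # assumed; the text does not state a sample rate


def bandpass_response(f, q, g, eval_freqs, sample_rate=SAMPLE_RATE):
    """Complex responses of gain-scaled band-pass biquads (assumed RBJ form).

    f, q, g: arrays of shape (n_filters,); eval_freqs: (n_bins,) in Hz.
    Returns a complex array of shape (n_filters, n_bins).
    """
    w0 = 2.0 * jnp.pi * f / sample_rate
    alpha = jnp.sin(w0) / (2.0 * q)
    b0, b2 = alpha, -alpha                           # b1 is zero for this form
    a0, a1, a2 = 1.0 + alpha, -2.0 * jnp.cos(w0), 1.0 - alpha
    # z^-1 evaluated on the unit circle at each analysis frequency
    zinv = jnp.exp(-1j * 2.0 * jnp.pi * eval_freqs / sample_rate)[None, :]
    num = b0[:, None] + b2[:, None] * zinv**2
    den = a0[:, None] + a1[:, None] * zinv + a2[:, None] * zinv**2
    return g[:, None] * num / den


def total_magnitude(f, q, g, eval_freqs):
    """Parallel topology: sum the complex responses, then take the magnitude."""
    return jnp.abs(jnp.sum(bandpass_response(f, q, g, eval_freqs), axis=0))


def kl_divergence(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, 1)) per example, diagonal Gaussian."""
    return -0.5 * jnp.sum(1.0 + logvar - mean**2 - jnp.exp(logvar), axis=-1)


def combined_loss_sketch(f, q, g, mean, logvar, target, eval_freqs, beta=1e-3):
    """MSE on the magnitude response plus the beta-weighted KL term."""
    pred = total_magnitude(f, q, g, eval_freqs)
    mse = jnp.mean((pred - target) ** 2)
    return mse + beta * jnp.mean(kl_divergence(mean, logvar))


# One filter centred at 6 kHz, evaluated below, at, and above its centre.
eval_freqs = jnp.array([1000.0, 6000.0, 12000.0])
f = jnp.array([6000.0])
q = jnp.array([2.0])
g = jnp.array([0.5])
mag = total_magnitude(f, q, g, eval_freqs)

# With a matching target and a standard-normal posterior, both terms vanish.
mean = jnp.zeros((1, 4))
logvar = jnp.zeros((1, 4))
loss = combined_loss_sketch(f, q, g, mean, logvar, mag, eval_freqs)
```

Because every operation is a `jax.numpy` primitive, `jax.grad` can differentiate the loss with respect to `f`, `q`, and `g`, which is what allows the decoder to be trained through the filter response.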
Below is an example of the reconstruction performance. The blue line is the target spectrum, and the orange line is the VAE's reconstruction using parallel bandpass filters.
