This is an implementation of https://arxiv.org/abs/2410.03282 by Chemseddine et al. and https://arxiv.org/abs/2410.03282 by Máté et al. for learning samplers from unnormalized densities. We implement the gradient-flow, learned, and linear interpolations and provide examples on additional target distributions.
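The linear interpolation is commonly defined on energies (negative unnormalized log-densities) as f_t = (1 - t) f_0 + t f_1, with unnormalized density exp(-f_t). The sketch below assumes a standard Gaussian base for f_0 and an arbitrary example target. The names `gaussian_energy`, `linear_interpolation` and `shifted_gaussian_energy` are hypothetical and not part of this repository, whose linear variant may be parametrized differently.

```python
import math
from typing import Callable, Sequence

# An "energy" f is a negative unnormalized log-density: p(x) is proportional to exp(-f(x)).
Energy = Callable[[Sequence[float]], float]


def gaussian_energy(x: Sequence[float]) -> float:
    """Energy of a standard Gaussian base distribution, up to a constant."""
    return 0.5 * sum(xi * xi for xi in x)


def linear_interpolation(f0: Energy, f1: Energy, t: float) -> Energy:
    """Hypothetical helper: energy of the linear path f_t = (1 - t) f0 + t f1."""
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must lie in [0, 1]")

    def f_t(x: Sequence[float]) -> float:
        return (1.0 - t) * f0(x) + t * f1(x)

    return f_t


def shifted_gaussian_energy(x: Sequence[float]) -> float:
    """Example target: standard Gaussian centred at (3, ..., 3)."""
    return 0.5 * sum((xi - 3.0) ** 2 for xi in x)


if __name__ == "__main__":
    point = [1.5, 1.5]
    for t in (0.0, 0.5, 1.0):
        f_t = linear_interpolation(gaussian_energy, shifted_gaussian_energy, t)
        # Unnormalized density of the intermediate distribution at `point`.
        print(f"t={t:.1f}  exp(-f_t(point)) = {math.exp(-f_t(point)):.4f}")
```

Since exp(-f_t) is proportional to p_0^(1-t) p_1^t, this is the familiar geometric annealing path between base and target.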
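First, install the dependencies:

    uv sync

or, without uv:

    pip install -e .

Train a single variant, e.g. the gradient-flow variant on the Rings target:

    uv run python -m src.trainer --config rings_gf --num_modes 4 --radius 1.0 --sigma_rings 0.15

Run all three variants in parallel, e.g. on Neal's funnel:

    uv run python -m src.run_all_variants --target funnel --sigma_funnel 3.0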
Available targets:

- Asymmetric GMM (similar to the one from Máté et al.)
  - Parameter: --mean_offset (default: 8.0)
  - Configs: asymmetric_gmm_gf, asymmetric_gmm_learned, asymmetric_gmm_linear
- Rings: 2D ring distribution
  - Parameters: --num_modes (default: 4), --radius (default: 1.0), --sigma_rings (default: 0.15)
  - Configs: rings_gf, rings_learned, rings_linear
- Neal's Funnel: an example of a non-$\beta$-smooth distribution
  - Parameter: --sigma_funnel (default: 3.0)
  - Configs: funnel_gf, funnel_learned, funnel_linear
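Neal's funnel has a standard closed form: v ~ N(0, σ²), with σ set by --sigma_funnel, and x_i | v ~ N(0, e^v). The sketch below evaluates that log-density and the curvature ∂²log p/∂x_i² = -e^(-v). Its magnitude grows without bound as v → -∞, so the gradient of log p is not globally Lipschitz, which is why the funnel is not β-smooth for any finite β. The repository's own target under src/fab/target_distributions/ may differ in dimension or parametrization. `funnel_log_prob` and `funnel_x_curvature` are hypothetical names.

```python
import math
from typing import Sequence

LOG_2PI = math.log(2.0 * math.pi)


def funnel_log_prob(v: float, x: Sequence[float], sigma_funnel: float = 3.0) -> float:
    """Normalized log-density of Neal's funnel: v ~ N(0, sigma^2), x_i | v ~ N(0, e^v)."""
    log_p_v = -0.5 * (v / sigma_funnel) ** 2 - math.log(sigma_funnel) - 0.5 * LOG_2PI
    # Each x_i has variance e^v, i.e. precision e^(-v) and log-variance v.
    log_p_x = sum(-0.5 * xi * xi * math.exp(-v) - 0.5 * v - 0.5 * LOG_2PI for xi in x)
    return log_p_v + log_p_x


def funnel_x_curvature(v: float) -> float:
    """Second derivative of log p with respect to any x_i, which equals -e^(-v).

    Unbounded as v -> -inf, so no finite beta bounds the Hessian of log p.
    """
    return -math.exp(-v)


if __name__ == "__main__":
    for v in (0.0, -3.0, -6.0, -9.0):
        print(f"v={v:5.1f}  d2 log p / dx_i^2 = {funnel_x_curvature(v):.1f}")
```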
All configs are in src/configs/. Some important parameters:

- --config: the config to use (e.g. funnel_gf, asymmetric_gmm_learned, rings_linear)
- --sigma_funnel: Funnel sigma parameter (only used if the target is funnel)
- --mean_offset: GMM mean offset (only used if the target is asymmetric_gmm)
- --num_modes: number of modes for Rings (only used if the target is rings)
- --radius: radius for Rings
- --sigma_rings: sigma for Rings
- --ntrain: number of training steps
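Config names follow the pattern `<target>_<variant>`, so the three variants of one target can also be launched one after another. The sketch below only builds the command lines from the flags listed above. `trainer_commands` is a hypothetical helper, not part of the repository; src.run_all_variants remains the way to run the variants in parallel.

```python
import shlex
from typing import List

# Target and variant names as used in this README; config names follow the
# "<target>_<variant>" pattern (e.g. funnel_gf, rings_linear).
TARGETS = ("asymmetric_gmm", "rings", "funnel")
VARIANTS = ("gf", "learned", "linear")


def trainer_commands(target: str, **params: float) -> List[List[str]]:
    """Hypothetical helper: one `src.trainer` command line per variant."""
    if target not in TARGETS:
        raise ValueError(f"unknown target {target!r}; expected one of {TARGETS}")
    extra: List[str] = []
    for name, value in params.items():
        # e.g. num_modes=4 -> ["--num_modes", "4"]
        extra += [f"--{name}", str(value)]
    return [
        ["uv", "run", "python", "-m", "src.trainer", "--config", f"{target}_{variant}", *extra]
        for variant in VARIANTS
    ]


if __name__ == "__main__":
    for argv in trainer_commands("rings", num_modes=4, radius=1.0, sigma_rings=0.15):
        print(shlex.join(argv))
        # To launch them one after another instead of printing:
        # subprocess.run(argv, check=True)
```

The first printed line reproduces the single-variant Rings command from the usage section; further flags such as ntrain=... are passed through the same way.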
src/
├── trainer.py # Main training script
├── models.py # Neural Networks
├── loss.py # Loss functions
├── evaluation.py # Plotting and Eval
├── compute_action_direct.py # Computation of Action
├── run_all_variants.py # Run all three variants in parallel
├── default_factories.py # Model factory functions
├── flow_wrappers.py # Flow utilities
├── configs/ # Configuration classes
│ ├── base_config.py
│ ├── Funnel_configs.py
│ ├── GMM_configs.py
│ └── Rings_configs.py
└── fab/
├── utils/
└── target_distributions/ # Target distributions
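compute_action_direct.py computes the action of a curve. Assuming this is the usual kinetic action of a curve of densities p_t driven by a velocity field v_t, A = ∫_0^1 E_{x∼p_t}[|v_t(x)|²] dt, a minimal estimator uses a Monte Carlo mean over samples at each time and the trapezoidal rule in t. `estimate_action` is a hypothetical name, and the repository's implementation may discretize time or sample differently.

```python
import random
from typing import Callable, List, Sequence

Point = Sequence[float]


def estimate_action(
    times: Sequence[float],
    samples: Sequence[Sequence[Point]],
    velocity: Callable[[float, Point], Sequence[float]],
) -> float:
    """Estimate A = int_0^1 E_{x ~ p_t} |v_t(x)|^2 dt (hypothetical helper).

    samples[k] holds draws from p_{times[k]}; velocity(t, x) returns v_t(x).
    The inner expectation is a Monte Carlo mean, the time integral uses the
    trapezoidal rule.
    """
    if len(times) != len(samples) or len(times) < 2:
        raise ValueError("need one sample batch per time and at least two times")
    mean_sq_speed: List[float] = []
    for t, batch in zip(times, samples):
        sq = [sum(c * c for c in velocity(t, x)) for x in batch]
        mean_sq_speed.append(sum(sq) / len(sq))
    return sum(
        0.5 * (times[k + 1] - times[k]) * (mean_sq_speed[k] + mean_sq_speed[k + 1])
        for k in range(len(times) - 1)
    )


if __name__ == "__main__":
    # Translating Gaussian p_t = N(t * m, I) with constant velocity v_t(x) = m;
    # its exact action is |m|^2 = 9 (the squared W2 distance between endpoints).
    m = (3.0, 0.0)
    times = [k / 10 for k in range(11)]
    rng = random.Random(0)
    samples = [
        [(t * m[0] + rng.gauss(0, 1), t * m[1] + rng.gauss(0, 1)) for _ in range(256)]
        for t in times
    ]
    print(estimate_action(times, samples, lambda t, x: m))
```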
This codebase is still a work in progress: the wandb integration is not fully up to date, and the evaluation in general needs more work (the big eval is not used consistently).



