Official Implementation of Variational Inference for Lévy Process-Driven SDEs via Neural Tilting
This repository implements:
- Tilted-Stable SDE — the proposed variational family that learns an exponential tilt of the prior Lévy measure via a quadratic neural parametrisation.
- Gaussian SDE — the Gaussian-diffusion variational baseline for direct comparison.
- Both models in two drift settings: Ornstein-Uhlenbeck (OU) and double-well potential (DW).
- Financial forecasting variants applied to rolling-window equity price prediction, with five baseline models (DeepAR, DLinear, N-HiTS, Neural Jump SDE, Neural MJD).
submission/
├── models/ # Model implementations
│ ├── tilted_stable_sde.py # Tilted-Stable SDE (OU drift)
│ ├── gaussian_sde.py # Gaussian SDE (OU drift)
│ ├── tilted_stable_sde_double_well.py
│ ├── gaussian_sde_double_well.py
│ └── components/ # Shared building blocks
│ ├── attention.py # Temporal attention / adaptive encoding
│ ├── drift.py # Drift network
│ ├── potential.py # Tilting potential (quadratic parametrisation)
│ ├── scale.py # Diffusion scale network
│ └── utils.py
│
├── training/ # Training scripts and infrastructure
│ ├── train_tilted_stable_sde_gpu.py
│ ├── train_gaussian_sde_gpu.py
│ ├── train_tilted_stable_sde_double_well_gpu.py
│ ├── train_gaussian_sde_double_well_gpu.py
│ ├── loss.py # ELBO and regularisation losses
│ ├── train_model.py # Generic training entry point
│ └── components/
│ ├── optimiser_config.py
│ ├── training_loop.py
│ └── training_monitor.py
│
├── evaluation/ # Evaluation and analysis
│ ├── evaluate_runs.py # Per-run metrics (MSE, MAE, CRPS, jump metrics)
│ ├── compare_models.py # Pairwise TS vs Gaussian comparison
│ ├── analyse_comparisons.py # Aggregate analysis, tables, and plots
│ ├── loss_functions.py
│ └── run_utils.py
│
├── generation/ # Dataset and posterior generation
│ ├── generate_batch_prior_tilted_stable_sde.py # OU datasets
│ ├── generate_batch_prior_tilted_stable_sde_double_well.py # DW datasets
│ └── generate_posteriors.py # Posterior sample visualisation
│
├── utils/
│ ├── dataset_utils.py
│ ├── training_utils.py
│ └── visualization_utils.py
│
├── configs/ # Example configurations
│ ├── ou/
│ │ ├── tilted_stable.yaml # TS-SDE on OU dataset
│ │ └── gaussian.yaml # Gaussian SDE on OU dataset
│ └── double_well/
│ ├── tilted_stable.yaml # TS-SDE on double-well dataset
│ └── gaussian.yaml # Gaussian SDE on double-well dataset
│
├── generate_ou_configs.py # Batch config generator for OU experiments
├── generate_double_well_configs.py # Batch config generator for DW experiments
├── requirements.txt
│
└── financial/ # Financial forecasting application
├── prepare_windows_prices.py # Build rolling-window datasets from Yahoo Finance
├── run_window.py # Train + evaluate one window (TS or Gaussian)
├── run_window_deepar.py # DeepAR baseline
├── run_window_dlinear.py # DLinear baseline
├── run_window_nhits.py # N-HiTS baseline
├── run_window_njsde.py # Neural Jump SDE baseline
├── run_window_neural_mjd.py # Neural MJD baseline
├── data/ # Download and windowing utilities
├── models/ # Financial model implementations
├── training/
├── forecasting/
├── evaluation/ # Aggregation and analysis scripts
└── configs/ # Experiment configs (one per setting)
Python 3.10+ is required.
python -m venv venv
source venv/bin/activate # Windows: venv\Scripts\activate
pip install -r requirements.txtThe training scripts auto-detect the available JAX backend and fall back to CPU if no GPU is found. To enable GPU support, install the appropriate JAX build after the step above:
# NVIDIA CUDA 12 (Linux):
pip install -U "jax[cuda12]"
# NVIDIA CUDA 11 (Linux):
pip install -U "jax[cuda11_pip]" \
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Apple Silicon / AMD GPU (macOS):
pip install jax-metalThe synthetic workflow has five stages: generate datasets → generate configs → train → evaluate → compare and analyse.
OU datasets (used by both TS-SDE and Gaussian SDE):
python generation/generate_batch_prior_tilted_stable_sde.pyDouble-well datasets:
python generation/generate_batch_prior_tilted_stable_sde_double_well.pyBoth scripts write to datasets/ and contain a parameter grid at the top of the file (alpha values, observation noise, number of realisations). Edit the grid there before running. Datasets are shared between the two model types — each dataset is generated once from the true tilted-stable process.
After datasets are in place, generate per-seed, per-alpha YAML configs automatically:
# OU experiment (scans datasets/tilted_stable_sde/)
python generate_ou_configs.py --obs-std 0.10 --n-seeds 50
# Double-well experiment (scans datasets/tilted_stable_sde_double_well/)
python generate_double_well_configs.py --obs-std 0.10 --n-seeds 10This writes configs to training_configs/tilted_stable/, training_configs/gaussian/, etc. To see what would be written without writing: add --dry-run.
For quick experiments, use the provided example configs in configs/ directly (see step 3).
# Tilted-Stable SDE — OU
python training/train_tilted_stable_sde_gpu.py --config configs/ou/tilted_stable.yaml
# Gaussian SDE — OU
python training/train_gaussian_sde_gpu.py --config configs/ou/gaussian.yaml
# Tilted-Stable SDE — double-well
python training/train_tilted_stable_sde_double_well_gpu.py --config configs/double_well/tilted_stable.yaml
# Gaussian SDE — double-well
python training/train_gaussian_sde_double_well_gpu.py --config configs/double_well/gaussian.yamlEach run saves its model, config snapshot, training metrics, and diagnostic plots to:
training_runs/{model_type}/{alpha}_{obs_std}/data_{data_seed}_train_{train_seed}/
To continue training from a checkpoint:
python training/train_tilted_stable_sde_gpu.py \
--config configs/ou/tilted_stable.yaml \
--parent-run "alpha_1.50_obsstd_0.10/data_42_train_713"python generation/generate_posteriors.py \
training_runs/tilted_stable_sde/alpha_1.50_obsstd_0.10/data_42_train_713 \
--n-posterior-samples 50Pass two run directories to overlay TS and Gaussian posteriors in the same plot:
python generation/generate_posteriors.py \
training_runs/tilted_stable_sde/alpha_1.50_obsstd_0.10/data_42_train_713 \
training_runs/gaussian_sde/alpha_1.50_obsstd_0.10/data_42_train_321 \
--n-posterior-samples 50Posterior samples are saved as .pkl files alongside the model and are reused by evaluate_runs.py.
python evaluation/evaluate_runs.py \
--run-path training_runs/tilted_stable_sde/alpha_1.50_obsstd_0.10/data_42_train_713 \
--n-posterior-samples 100Metrics (MSE, MAE, CRPS, jump-conditioned variants, drift parameter error) are written to evaluation/results/{model_type}/{param_dir}/{run_id}/evaluation_summary.json.
Once both models have been evaluated on the same dataset:
python evaluation/compare_models.py \
--eval-dirs \
evaluation/results/tilted_stable_sde/alpha_1.50_obsstd_0.10/<ts_run_id> \
evaluation/results/gaussian_sde/alpha_1.50_obsstd_0.10/<gauss_run_id>After running comparisons for all seeds and alpha values:
python evaluation/analyse_comparisons.pyThis produces CSV tables and plots in evaluation/results/analysis/, including per-alpha CRPS, MAE, MSE, and relative improvement figures.
The financial experiment runs a rolling-window forecast over equity log-price data downloaded automatically from Yahoo Finance.
python financial/prepare_windows_prices.py \
--config financial/configs/nvda_tilted_stable_alpha1.9_sig1.0_obsstd0.10_d2w32.yamlThis downloads the specified tickers, builds the rolling-window dataset, and writes it to financial/data/prepared/<config_name>/. The number of windows is printed on completion.
# Tilted-Stable SDE
python financial/run_window.py \
--config-name nvda_tilted_stable_alpha1.9_sig1.0_obsstd0.10_d2w32 \
--window-idx 0
# Gaussian SDE (use the corresponding Gaussian config name)
python financial/run_window.py \
--config-name nvda_gaussian_sig1.0_obsstd0.20_d2w32_creg2 \
--window-idx 0Repeat for all window indices (0 to N−1). Results are written to financial/results/<config_name>/window_<idx>/.
Each baseline has its own entry point with the same --config-name / --window-idx interface. The config names for the baselines are in financial/configs/.
python financial/run_window_deepar.py --config-name 10ticker_deepar_h256l2_obsstd0.10_train30d_fc2d_prices --window-idx 0
python financial/run_window_dlinear.py --config-name 10ticker_dlinear_obsstd0.10_train30d_fc2d_prices --window-idx 0
python financial/run_window_nhits.py --config-name 10ticker_nhits_obsstd0.10_train30d_fc2d_prices --window-idx 0
python financial/run_window_njsde.py --config-name 10ticker_njsde_reg_obsstd0.10_train30d_fc2d_prices --window-idx 0
python financial/run_window_neural_mjd.py --config-name 10ticker_neural_mjd_dropout0.1_obsstd0.10_train30d_fc2d_prices --window-idx 0After all windows and models are complete:
python financial/evaluation/aggregate_prices.py \
--model nvda_tilted_stable_alpha1.9_sig1.0_obsstd0.10_d2w32:"TS-SDE" \
--model nvda_gaussian_sig1.0_obsstd0.20_d2w32_creg2:"Gaussian-SDE" \
--model <deepar_config_name>:"DeepAR"This writes per-model and comparative metric summaries (CRPS, MAE, energy score, jump CRPS) to financial/results/analysis/.
For jump-conditioned metrics across models:
python financial/evaluation/jump_aggregate_prices.py \
--model nvda_tilted_stable_alpha1.9_sig1.0_obsstd0.10_d2w32:"TS-SDE" \
--model nvda_gaussian_sig1.0_obsstd0.20_d2w32_creg2:"Gaussian-SDE"All training configs are YAML files. Key fields shared across model types:
| Field | Description |
|---|---|
model.alpha |
Stability index of the driving Lévy process (1 < α < 2) |
data.data_seed |
Seed used to select the dataset realisation |
data.obs_std |
Observation noise standard deviation |
training.training_steps |
Number of gradient steps |
training.n_loss_samples |
Monte Carlo samples per ELBO gradient estimate |
training.n_latent_steps |
Latent time discretisation steps |
Tilted-Stable-specific fields:
| Field | Description |
|---|---|
model.tilting_width / tilting_depth |
MLP width/depth for the tilting potential |
model.n_attention_references |
Number of learnable temporal reference points |
model.a_min |
Minimum absolute value of the A coefficient (strict negativity) |
training.tilting_regularisation |
L2 penalty on tilting coefficients |
Gaussian-SDE-specific fields:
| Field | Description |
|---|---|
model.control_width / control_depth |
MLP width/depth for the control network |
training.control_regularisation |
L2 penalty on control network weights |