From 1f1e9d1779713da99ea815a3d1200e21ad7e6548 Mon Sep 17 00:00:00 2001 From: Manuel Raimann Date: Thu, 12 Feb 2026 08:33:23 +0100 Subject: [PATCH 01/48] refactor(docs): Documentation Overhaul --- .claude/commands/commit-push-pr.md | 96 +++++++ .claude/commands/commit-push.md | 77 ++++++ .claude/commands/commit.md | 76 ++++++ CHANGELOG.md | 230 ++++++++++++++++ CLAUDE.md | 3 + Cargo.toml | 28 +- Makefile | 6 + README.md | 176 ++++-------- cliff.toml | 99 +++++++ examples/advanced_features.rs | 277 +++++++++++++++++++ examples/async_api_optimization.rs | 368 -------------------------- examples/basic_optimization.rs | 32 +++ examples/benchmark_convergence.rs | 124 --------- examples/ml_hyperparameter_tuning.rs | 275 ------------------- examples/parameter_api.rs | 56 ---- examples/parameter_types.rs | 73 +++++ examples/pruning_and_callbacks.rs | 133 ++++++++++ examples/sampler_comparison.rs | 94 +++++++ examples/visualization_demo.rs | 45 ---- src/error.rs | 60 +++-- src/fanova.rs | 75 +++++- src/importance.rs | 22 ++ src/lib.rs | 213 ++++----------- src/multi_objective.rs | 56 +++- src/param.rs | 33 ++- src/parameter.rs | 152 +++++++---- src/pareto.rs | 103 +++++-- src/pruner/hyperband.rs | 46 +++- src/pruner/median.rs | 36 +++ src/pruner/mod.rs | 38 +++ src/pruner/nop.rs | 12 + src/pruner/patient.rs | 31 +++ src/pruner/percentile.rs | 34 +++ src/pruner/successive_halving.rs | 50 ++++ src/pruner/threshold.rs | 30 +++ src/pruner/wilcoxon.rs | 41 +++ src/sampler/bohb.rs | 58 +++- src/sampler/cma_es.rs | 66 ++++- src/sampler/differential_evolution.rs | 64 ++++- src/sampler/gp.rs | 67 ++++- src/sampler/grid.rs | 74 ++++-- src/sampler/moead.rs | 89 ++++++- src/sampler/motpe.rs | 72 ++++- src/sampler/nsga2.rs | 50 +++- src/sampler/nsga3.rs | 56 +++- src/sampler/random.rs | 41 ++- src/sampler/sobol.rs | 58 +++- src/sampler/tpe/mod.rs | 62 ++++- src/sampler/tpe/multivariate.rs | 52 +++- src/sampler/tpe/sampler.rs | 2 +- src/storage/journal.rs | 106 +++++++- 
src/storage/memory.rs | 55 +++- src/storage/mod.rs | 49 +++- src/study.rs | 316 ++++++++++++++++------ src/trial.rs | 162 +++++++++--- src/visualization.rs | 56 +++- 56 files changed, 3298 insertions(+), 1557 deletions(-) create mode 100644 .claude/commands/commit-push-pr.md create mode 100644 .claude/commands/commit-push.md create mode 100644 .claude/commands/commit.md create mode 100644 CHANGELOG.md create mode 100644 CLAUDE.md create mode 100644 Makefile create mode 100644 cliff.toml create mode 100644 examples/advanced_features.rs delete mode 100644 examples/async_api_optimization.rs create mode 100644 examples/basic_optimization.rs delete mode 100644 examples/benchmark_convergence.rs delete mode 100644 examples/ml_hyperparameter_tuning.rs delete mode 100644 examples/parameter_api.rs create mode 100644 examples/parameter_types.rs create mode 100644 examples/pruning_and_callbacks.rs create mode 100644 examples/sampler_comparison.rs delete mode 100644 examples/visualization_demo.rs diff --git a/.claude/commands/commit-push-pr.md b/.claude/commands/commit-push-pr.md new file mode 100644 index 0000000..b29a7df --- /dev/null +++ b/.claude/commands/commit-push-pr.md @@ -0,0 +1,96 @@ +--- +allowed-tools: Bash(git add:*), Bash(git status:*), Bash(git commit:*), Bash(git diff:*), Bash(git log:*), Bash(git config:*), Bash(git push:*), Bash(git branch:*), Bash(git rev-parse:*), Bash(gh pr:*) +description: Create a git commit, push, and open a PR +--- + +## Context + +- Current git status: !`git status` +- Current git diff (staged and unstaged changes): !`git diff HEAD` +- Current branch: !`git branch --show-current` +- Remote tracking: !`git rev-parse --abbrev-ref --symbolic-full-name @{u} 2>/dev/null || echo "no upstream"` +- Recent commits (for style reference): !`git log --oneline -10` +- Author commits (for style reference): !`git log --author="$(git config user.email)" --oneline -10` +- All commits on this branch not on main: !`git log --oneline main..HEAD 
2>/dev/null` +- User input (optional, can specify target branch): $ARGUMENTS + +## Your task + +Based on the above changes, create a git commit, push, and open a pull request. + +1. If there are no staged changes, stage only the relevant changed files by name (never use `git add -A` or `git add .`). +2. Do NOT commit files that likely contain secrets (.env, credentials.json, etc). +3. Write a commit message following the guidelines below, then commit. +4. Push to the remote. Use `git push -u origin <branch>` if there is no upstream set. +5. Create a pull request using `gh pr create` targeting the base branch: + - If the user specified a target branch in their input, use that. + - Otherwise, default to `main`. +6. The PR title should match the commit title (or summarize all branch commits if multiple). +7. Write the PR body following the PR format below. +8. Return the PR URL when done. + +## Commit message guidelines + +- Use imperative mood (e.g., "add feature" not "added feature") +- First line: brief summary, max 72 characters +- Focus on the "why" and "what", not the "how" +- Be specific but concise +- Many commits do not need a body if the title is self-explanatory + - Litmus test: "Would a developer understand this commit from the title + diff?" If yes, skip the body. +- Do NOT include "Generated with ...", "Co-Authored-By ...", or any AI attribution + +### Conventional commit prefixes + +Match the prefix to the nature of the change. These must align with `cliff.toml` commit parsers so they appear correctly in the changelog. + +**Changelog: "Added"** +- `feat:` / `feat(scope):` — new feature + +**Changelog: "Fixed"** +- `fix:` / `fix(scope):` — bug fix + +**Changelog: "Changed"** +- `refactor:` / `refactor(scope):` — code restructuring without behavior change +- `perf:` / `perf(scope):` — performance improvement +- `docs:` / `docs(scope):` — documentation changes (README, guides, etc.)
+- `style:` / `style(scope):` — formatting only (no logic change) +- `chore:` / `chore(scope):` — tooling, deps, config, scripts, AI config (.claude/, CLAUDE.md) + +**Changelog: "Removed"** +- `revert:` / `revert(scope):` — revert a previous commit + +**Excluded from changelog** +- `test:` / `test(scope):` — adding or updating tests +- `ci:` / `ci(scope):` — CI/CD pipeline changes + +### Body rules + +- When a body is needed (multiple important things in one commit), use bullet points. +- Body is separated from the title by a blank line. +- One bullet point per concept/change. +- Don't explain obvious things like "added unit tests for X". + +### Commit format + +Always pass the commit message via a HEREDOC: + +``` +git commit -m "$(cat <<'EOF' +<title line> + +<optional body> +EOF +)" +``` + +## PR format + +Use a HEREDOC for the body. Look at ALL commits on the branch (not just the latest) to write the summary. + +``` +gh pr create --title "<title>" --base <target-branch> --body "$(cat <<'EOF' +## Summary +<1-3 bullet points covering all branch commits> +EOF +)" +``` diff --git a/.claude/commands/commit-push.md b/.claude/commands/commit-push.md new file mode 100644 index 0000000..a3f38c7 --- /dev/null +++ b/.claude/commands/commit-push.md @@ -0,0 +1,77 @@ +--- +allowed-tools: Bash(git add:*), Bash(git status:*), Bash(git commit:*), Bash(git diff:*), Bash(git log:*), Bash(git config:*), Bash(git push:*), Bash(git branch:*), Bash(git rev-parse:*) +description: Create a git commit and push to remote +--- + +## Context + +- Current git status: !`git status` +- Current git diff (staged and unstaged changes): !`git diff HEAD` +- Current branch: !`git branch --show-current` +- Remote tracking: !`git rev-parse --abbrev-ref --symbolic-full-name @{u} 2>/dev/null || echo "no upstream"` +- Recent commits (for style reference): !`git log --oneline -10` +- Author commits (for style reference): !`git log --author="$(git config user.email)" --oneline -10` +- User input (optional): $ARGUMENTS
+ +## Your task + +Based on the above changes, create a single git commit and push it to the remote. + +1. If there are no staged changes, stage only the relevant changed files by name (never use `git add -A` or `git add .`). +2. Do NOT commit files that likely contain secrets (.env, credentials.json, etc). +3. Write a commit message following the guidelines below, then commit. +4. Push to the remote. Use `git push -u origin <branch>` if there is no upstream set. + +## Commit message guidelines + +- Use imperative mood (e.g., "add feature" not "added feature") +- First line: brief summary, max 72 characters +- Focus on the "why" and "what", not the "how" +- Be specific but concise +- Many commits do not need a body if the title is self-explanatory + - Litmus test: "Would a developer understand this commit from the title + diff?" If yes, skip the body. +- Do NOT include "Generated with ...", "Co-Authored-By ...", or any AI attribution + +### Conventional commit prefixes + +Match the prefix to the nature of the change. These must align with `cliff.toml` commit parsers so they appear correctly in the changelog. + +**Changelog: "Added"** +- `feat:` / `feat(scope):` — new feature + +**Changelog: "Fixed"** +- `fix:` / `fix(scope):` — bug fix + +**Changelog: "Changed"** +- `refactor:` / `refactor(scope):` — code restructuring without behavior change +- `perf:` / `perf(scope):` — performance improvement +- `docs:` / `docs(scope):` — documentation changes (README, guides, etc.) +- `style:` / `style(scope):` — formatting only (no logic change) +- `chore:` / `chore(scope):` — tooling, deps, config, scripts, AI config (.claude/, CLAUDE.md) + +**Changelog: "Removed"** +- `revert:` / `revert(scope):` — revert a previous commit + +**Excluded from changelog** +- `test:` / `test(scope):` — adding or updating tests +- `ci:` / `ci(scope):` — CI/CD pipeline changes + +### Body rules + +- When a body is needed (multiple important things in one commit), use bullet points. 
+- Body is separated from the title by a blank line. +- One bullet point per concept/change. +- Don't explain obvious things like "added unit tests for X". + +### Commit format + +Always pass the commit message via a HEREDOC: + +``` +git commit -m "$(cat <<'EOF' +<title line> + +<optional body> +EOF +)" +``` diff --git a/.claude/commands/commit.md b/.claude/commands/commit.md new file mode 100644 index 0000000..50d3362 --- /dev/null +++ b/.claude/commands/commit.md @@ -0,0 +1,76 @@ +--- +allowed-tools: Bash(git add:*), Bash(git status:*), Bash(git commit:*), Bash(git diff:*), Bash(git log:*), Bash(git config:*), Bash(git branch:*) +description: Create a git commit +--- + +## Context + +- Current git status: !`git status` +- Current git diff (staged and unstaged changes): !`git diff HEAD` +- Current branch: !`git branch --show-current` +- Recent commits (for style reference): !`git log --oneline -10` +- Author commits (for style reference): !`git log --author="$(git config user.email)" --oneline -10` +- User input (optional): $ARGUMENTS + +## Your task + +Based on the above changes, create a single git commit. + +1. If there are no staged changes, stage only the relevant changed files by name (never use `git add -A` or `git add .`). +2. Do NOT commit files that likely contain secrets (.env, credentials.json, etc). +3. Write a commit message following the guidelines below, then commit. +4. Do NOT push to remote. + +## Commit message guidelines + +- Use imperative mood (e.g., "add feature" not "added feature") +- First line: brief summary, max 72 characters +- Focus on the "why" and "what", not the "how" +- Be specific but concise +- Many commits do not need a body if the title is self-explanatory + - Litmus test: "Would a developer understand this commit from the title + diff?" If yes, skip the body. +- Do NOT include "Generated with ...", "Co-Authored-By ...", or any AI attribution + +### Conventional commit prefixes + +Match the prefix to the nature of the change. 
These must align with `cliff.toml` commit parsers so they appear correctly in the changelog. + +**Changelog: "Added"** +- `feat:` / `feat(scope):` — new feature + +**Changelog: "Fixed"** +- `fix:` / `fix(scope):` — bug fix + +**Changelog: "Changed"** +- `refactor:` / `refactor(scope):` — code restructuring without behavior change +- `perf:` / `perf(scope):` — performance improvement +- `docs:` / `docs(scope):` — documentation changes (README, guides, etc.) +- `style:` / `style(scope):` — formatting only (no logic change) +- `chore:` / `chore(scope):` — tooling, deps, config, scripts, AI config (.claude/, CLAUDE.md) + +**Changelog: "Removed"** +- `revert:` / `revert(scope):` — revert a previous commit + +**Excluded from changelog** +- `test:` / `test(scope):` — adding or updating tests +- `ci:` / `ci(scope):` — CI/CD pipeline changes + +### Body rules + +- When a body is needed (multiple important things in one commit), use bullet points. +- Body is separated from the title by a blank line. +- One bullet point per concept/change. +- Don't explain obvious things like "added unit tests for X". + +### Commit format + +Always pass the commit message via a HEREDOC: + +``` +git commit -m "$(cat <<'EOF' +<title line> + +<optional body> +EOF +)" +``` diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..2b0acf9 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,230 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+ +## [0.9.1] - 2026-02-12 + +### Fixed + +- Fix journal file handling to support both reading and writing modes + +## [0.9.0] - 2026-02-12 + +### Added + +- Storage trait abstraction with pluggable backends +- JSONL journal storage backend for persistent studies (behind `journal` feature) +- Differential Evolution sampler +- Gaussian Process sampler with Expected Improvement acquisition function +- NSGA-III sampler for many-objective optimization with reference-point decomposition +- MOEA/D sampler for decomposition-based multi-objective optimization +- `StudyBuilder` for fluent study construction + +### Changed + +- Move `next_trial_id` counter from `Study` into `Storage` trait +- Replace `rand 0.10` with `fastrand 2.3` for simpler, faster random number generation +- Remove `visualization` and `fanova` feature flags (now always available) +- Simplify CI test commands by removing feature matrix + +### Fixed + +- Seed importance tests to eliminate flakiness +- Use atomic counter for temp paths in journal tests + +### Removed + +- Reverted SQLite storage backend (may return in a future release) + +## [0.8.1] - 2026-02-11 + +### Fixed + +- Resolve broken rustdoc intra-doc links for `Error::NoCompletedTrials` + +## [0.8.0] - 2026-02-11 + +### Added + +- Multi-objective optimization with NSGA-II +- Multi-Objective TPE (MOTPE) sampler +- Pareto front analysis utilities +- CSV and JSON data export for visualization +- HTML visualization reports with Plotly.js charts +- fANOVA (functional ANOVA) parameter importance via random forest + +### Fixed + +- Rename variable to pass typos check + +## [0.7.2] - 2026-02-11 + +### Changed + +- Update `nalgebra` dependency to version 0.34 +- Add advisory ignore for unmaintained transitive dependency + +## [0.7.1] - 2026-02-11 + +### Changed + +- Update `tracing` dependency to version 0.1.29 + +## [0.7.0] - 2026-02-11 + +### Added + +- CMA-ES (Covariance Matrix Adaptation Evolution Strategy) sampler behind `cma-es` feature flag +- 
Sobol quasi-random sampler behind `sobol` feature flag +- BOHB (Bayesian Optimization with HyperBand) sampler for budget-aware optimization +- Parameter importance analysis via Spearman rank correlation +- Constraint handling with feasibility-aware trial ranking +- `optimize_with_retries()` for automatic retry of failed trials +- `optimize_with_checkpoint()` with atomic save writes for crash recovery +- `Study::summary()` and `Display` impl for study overview +- `IntoIterator` for `&Study` and `iter()` method +- Tracing integration behind `tracing` feature flag +- Serde serialization support behind `serde` feature flag +- Benchmark suite with criterion and standard test functions + +## [0.6.0] - 2026-02-11 + +### Added + +- Pruning system with `Pruner` trait and `NopPruner` default implementation +- Intermediate value reporting on trials for pruner integration +- `TrialPruned` error variant and `Pruned` trial state +- `ThresholdPruner` for fixed-bound trial pruning +- `MedianPruner` for statistics-based trial pruning +- `PercentilePruner` for configurable percentile-based trial pruning +- `PatientPruner` for patience-based trial pruning +- `SuccessiveHalvingPruner` for SHA-based trial pruning +- `HyperbandPruner` for multi-bracket trial pruning +- `WilcoxonPruner` for statistics-based trial pruning +- `Study::minimize()` and `Study::maximize()` constructor shortcuts +- `From<RangeInclusive>` for `FloatParam` and `IntParam` +- Timeout-based optimization with `optimize_until` +- `Study::top_trials(n)` for retrieving the best N trials +- Ask-and-tell interface with `ask()` and `tell()` methods +- Trial user attributes for logging and analysis +- `enqueue_trial()` for pre-specified parameter evaluation + +## [0.5.1] - 2026-02-10 + +### Changed + +- Update random number generation to use `rand::make_rng()` and upgrade `rand` to version 0.10 + +## [0.5.0] - 2026-02-06 + +### Added + +- Typed parameter API with `FloatParam`, `IntParam`, `CategoricalParam`, `BoolParam`, 
and `EnumParam` +- `#[derive(Categorical)]` proc macro for deriving categorical parameters from enums (behind `derive` feature) +- `.name()` builder method on all parameter types for custom labels +- `CompletedTrial::get()` for typed parameter access +- `Display` impl on `ParamValue` +- Prelude module at `optimizer::prelude::*` + +### Changed + +- Reorganize sampler module structure and update imports +- Remove `log` dependency + +## [0.4.0] - 2026-02-02 + +### Added + +- Multivariate TPE sampler for correlated parameter search +- Gamma strategies for TPE sampler (linear, sqrt, fixed) with examples +- Example: async API parameter optimization +- Example: ML hyperparameter tuning + +### Fixed + +- Handle end value in `suggest` method to avoid panic on underflow + +## [0.3.1] - 2026-01-31 + +### Added + +- Grid search sampler for exhaustive parameter exploration +- `suggest_bool()` method for boolean parameter suggestion +- `SuggestableRange` trait and `suggest_range()` method for parameter suggestion from ranges +- Documentation, keywords, categories, and readme fields in `Cargo.toml` + +### Changed + +- Replace `TpeError` with unified `Error` type across the library +- Add permissions for read access to contents in CI and scheduled workflows + +## [0.3.0] - 2026-01-30 + +### Added + +- Cross-target compilation checks in CI + +### Changed + +- Internal code refactoring and cleanup + +## [0.2.0] - 2026-01-30 + +### Added + +- CI step for unused dependencies check with `cargo-machete` +- CI step to publish to crates.io +- Default typo extension for 'Tpe' + +### Changed + +- Remove `serde` dependency (was unused) +- Remove unused `ordered-float` dependency + +## [0.1.1] - 2026-01-30 + +### Fixed + +- Minor release fixes + +## [0.1.0] - 2026-01-30 + +### Added + +- Initial implementation of the optimization library +- TPE (Tree-structured Parzen Estimator) sampler with optional fixed bandwidth KDE +- Random sampler +- Async optimization support +- CI workflows for 
testing, coverage (Codecov), auditing, and publishing +- README with project overview, features, and quick start guide + +--- + +## Release Workflow + +This changelog is maintained automatically with [git-cliff](https://git-cliff.org/). + +1. Write commits using [Conventional Commits](https://www.conventionalcommits.org/) (`feat:`, `fix:`, `refactor:`, etc.). +2. Tag the release: `git tag v0.X.0` +3. Regenerate the changelog: `make changelog` +4. Commit the updated `CHANGELOG.md` and push. + +[0.9.1]: https://github.com/raimannma/rust-optimizer/compare/v0.9.0...v0.9.1 +[0.9.0]: https://github.com/raimannma/rust-optimizer/compare/v0.8.1...v0.9.0 +[0.8.1]: https://github.com/raimannma/rust-optimizer/compare/v0.8.0...v0.8.1 +[0.8.0]: https://github.com/raimannma/rust-optimizer/compare/v0.7.2...v0.8.0 +[0.7.2]: https://github.com/raimannma/rust-optimizer/compare/v0.7.1...v0.7.2 +[0.7.1]: https://github.com/raimannma/rust-optimizer/compare/v0.7.0...v0.7.1 +[0.7.0]: https://github.com/raimannma/rust-optimizer/compare/v0.6.0...v0.7.0 +[0.6.0]: https://github.com/raimannma/rust-optimizer/compare/v0.5.1...v0.6.0 +[0.5.1]: https://github.com/raimannma/rust-optimizer/compare/v0.5.0...v0.5.1 +[0.5.0]: https://github.com/raimannma/rust-optimizer/compare/v0.4.0...v0.5.0 +[0.4.0]: https://github.com/raimannma/rust-optimizer/compare/v0.3.1...v0.4.0 +[0.3.1]: https://github.com/raimannma/rust-optimizer/compare/v0.3.0...v0.3.1 +[0.3.0]: https://github.com/raimannma/rust-optimizer/compare/v0.2.0...v0.3.0 +[0.2.0]: https://github.com/raimannma/rust-optimizer/compare/v0.1.1...v0.2.0 +[0.1.1]: https://github.com/raimannma/rust-optimizer/compare/v0.1.0...v0.1.1 +[0.1.0]: https://github.com/raimannma/rust-optimizer/releases/tag/v0.1.0 diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..fa93e9d --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,3 @@ +## Git Guidelines + +- If the user asks to make a commit, do not add any co-author by claude or any other co-author note. 
diff --git a/Cargo.toml b/Cargo.toml index 0d45868..d1c11d2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,11 +8,11 @@ edition = "2024" rust-version = "1.89" license = "MIT" authors = ["Manuel Raimann <raimannma@outlook.de"] -description = "A Rust library for optimization algorithms." +description = "Bayesian and population-based optimization library with an Optuna-like API for hyperparameter tuning and black-box optimization" repository = "https://github.com/raimannma/rust-optimizer" documentation = "https://docs.rs/optimizer" keywords = ["optimization", "hyperparameter", "tpe", "grid-search", "bayesian"] -categories = ["algorithm", "science", "data-structures"] +categories = ["algorithms", "science", "mathematics"] readme = "README.md" [dependencies] @@ -54,22 +54,26 @@ name = "optimization" harness = false [[example]] -name = "async_api_optimization" -path = "examples/async_api_optimization.rs" -required-features = ["async"] +name = "basic_optimization" +path = "examples/basic_optimization.rs" [[example]] -name = "ml_hyperparameter_tuning" -path = "examples/ml_hyperparameter_tuning.rs" +name = "parameter_types" +path = "examples/parameter_types.rs" +required-features = ["derive"] [[example]] -name = "parameter_api" -path = "examples/parameter_api.rs" -required-features = ["derive"] +name = "sampler_comparison" +path = "examples/sampler_comparison.rs" + +[[example]] +name = "pruning_and_callbacks" +path = "examples/pruning_and_callbacks.rs" [[example]] -name = "visualization_demo" -path = "examples/visualization_demo.rs" +name = "advanced_features" +path = "examples/advanced_features.rs" +required-features = ["async", "journal"] [[test]] name = "journal_tests" diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..d1a6f66 --- /dev/null +++ b/Makefile @@ -0,0 +1,6 @@ +.PHONY: changelog + +# Regenerate CHANGELOG.md from git tags using git-cliff. 
+# Install: cargo install git-cliff +changelog: + git-cliff --output CHANGELOG.md diff --git a/README.md b/README.md index 56cc9f1..970013c 100644 --- a/README.md +++ b/README.md @@ -1,155 +1,69 @@ # optimizer -A Rust library for black-box optimization with multiple sampling strategies. +Bayesian and population-based optimization library with an Optuna-like API +for hyperparameter tuning and black-box optimization. Supports 12 samplers, +8 pruners, multi-objective optimization, async parallelism, and persistent storage. [![Docs](https://docs.rs/optimizer/badge.svg)](https://docs.rs/optimizer) [![Crates.io](https://img.shields.io/crates/v/optimizer.svg)](https://crates.io/crates/optimizer) [![codecov](https://codecov.io/gh/raimannma/rust-optimizer/graph/badge.svg?token=WOE77XJ4M6)](https://codecov.io/gh/raimannma/rust-optimizer) -## Features - -- Optuna-like API for hyperparameter optimization -- Multiple sampling strategies: - - **Random Search** - Simple random sampling for baseline comparisons - - **TPE (Tree-Parzen Estimator)** - Bayesian optimization for efficient search - - **Grid Search** - Exhaustive search over a specified parameter grid -- Float, integer, categorical, boolean, and enum parameter types -- Log-scale and stepped parameter sampling -- Sync and async optimization with parallel trial evaluation -- `#[derive(Categorical)]` for enum parameters - ## Quick Start ```rust -use optimizer::parameter::{FloatParam, Parameter}; -use optimizer::sampler::tpe::TpeSampler; -use optimizer::{Direction, Study}; - -// Create a study with TPE sampler -let sampler = TpeSampler::builder().seed(42).build().unwrap(); -let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - -// Define parameter search space -let x_param = FloatParam::new(-10.0, 10.0); - -// Optimize x^2 for 20 trials -study - .optimize_with_sampler(20, |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, optimizer::Error>(x * x) - }) - .unwrap(); - -// Get the best result -let 
best = study.best_trial().unwrap(); -println!("Best value: {}", best.value); -for (id, label) in &best.param_labels { - println!(" {}: {:?}", label, best.params[id]); -} -``` - -## Samplers - -### Random Search - -```rust -use optimizer::{Direction, Study}; -use optimizer::sampler::random::RandomSampler; - -let study: Study<f64> = Study::with_sampler( - Direction::Minimize, - RandomSampler::with_seed(42), -); -``` - -### TPE (Tree-Parzen Estimator) - -```rust -use optimizer::{Direction, Study}; -use optimizer::sampler::tpe::TpeSampler; - -let sampler = TpeSampler::builder() - .gamma(0.15) // Quantile for good/bad split - .n_startup_trials(20) // Random trials before TPE kicks in - .n_ei_candidates(32) // Candidates to evaluate - .seed(42) - .build() - .unwrap(); - -let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); -``` - -#### Gamma Strategies +use optimizer::prelude::*; -The gamma parameter controls what fraction of trials are considered "good" when building the TPE model. 
Instead of a fixed value, you can use adaptive strategies: +let study: Study<f64> = Study::new(Direction::Minimize); +let x = FloatParam::new(-10.0, 10.0).name("x"); -| Strategy | Description | Formula | -|----------|-------------|---------| -| `FixedGamma` | Constant value (default: 0.25) | `γ = constant` | -| `LinearGamma` | Linear interpolation over trials | `γ = γ_min + (γ_max - γ_min) * min(n/n_max, 1)` | -| `SqrtGamma` | Optuna-style inverse sqrt scaling | `γ = min(γ_max, factor/√n / n)` | -| `HyperoptGamma` | Hyperopt-style adaptive | `γ = min(γ_max, (base + 1) / n)` | +study.optimize(50, |trial| { + let val = x.suggest(trial)?; + Ok::<_, Error>((val - 3.0).powi(2)) +}).unwrap(); -```rust -use optimizer::sampler::tpe::{TpeSampler, SqrtGamma, LinearGamma}; - -// Optuna-style gamma that decreases with more trials -let sampler = TpeSampler::builder() - .gamma_strategy(SqrtGamma::default()) - .build() - .unwrap(); - -// Linear interpolation from 0.1 to 0.3 over 100 trials -let sampler = TpeSampler::builder() - .gamma_strategy(LinearGamma::new(0.1, 0.3, 100).unwrap()) - .build() - .unwrap(); -``` - -You can also implement custom strategies: - -```rust -use optimizer::sampler::tpe::{TpeSampler, GammaStrategy}; - -#[derive(Debug, Clone)] -struct MyGamma { base: f64 } - -impl GammaStrategy for MyGamma { - fn gamma(&self, n_trials: usize) -> f64 { - (self.base + 0.01 * n_trials as f64).min(0.5) - } - fn clone_box(&self) -> Box<dyn GammaStrategy> { - Box::new(self.clone()) - } -} - -let sampler = TpeSampler::builder() - .gamma_strategy(MyGamma { base: 0.1 }) - .build() - .unwrap(); +let best = study.best_trial().unwrap(); +println!("Best x = {:.4}, f(x) = {:.4}", best.get(&x).unwrap(), best.value); ``` -### Grid Search - -```rust -use optimizer::{Direction, Study}; -use optimizer::sampler::grid::GridSearchSampler; - -let sampler = GridSearchSampler::builder() - .n_points_per_param(10) // Number of points per parameter dimension - .build(); +## Features at a Glance 
-let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); -``` +- **[Samplers](https://docs.rs/optimizer/latest/optimizer/sampler/)** — Random, TPE, Multivariate TPE, Grid, Sobol, CMA-ES, Gaussian Process, Differential Evolution, BOHB, NSGA-II, NSGA-III, MOEA/D +- **[Pruners](https://docs.rs/optimizer/latest/optimizer/pruner/)** — Median, Percentile, Threshold, Patient, Hyperband, Successive Halving, Wilcoxon, Nop +- **[Parameters](https://docs.rs/optimizer/latest/optimizer/parameter/)** — Float, Int, Categorical, Bool, and Enum types with `.name()` labels and typed access +- **[Multi-objective](https://docs.rs/optimizer/latest/optimizer/multi_objective/)** — Pareto front extraction with NSGA-II/III and MOEA/D +- **[Async & parallel](https://docs.rs/optimizer/latest/optimizer/struct.Study.html#method.optimize_parallel)** — Concurrent trial evaluation with Tokio +- **[Storage backends](https://docs.rs/optimizer/latest/optimizer/storage/)** — In-memory (default) or JSONL journal for persistence and resumption +- **[Visualization](https://docs.rs/optimizer/latest/optimizer/fn.generate_html_report.html)** — HTML reports with optimization history, parameter importance, and Pareto fronts +- **[Analysis](https://docs.rs/optimizer/latest/optimizer/struct.Study.html#method.fanova)** — fANOVA and Spearman correlation for parameter importance ## Feature Flags -- `async` - Enable async optimization methods (requires tokio) -- `derive` - Enable `#[derive(Categorical)]` for enum parameters +| Flag | Enables | Default | +|------|---------|---------| +| `async` | Async/parallel optimization (Tokio) | No | +| `derive` | `#[derive(Categorical)]` for enum parameters | No | +| `serde` | Serialization of trials and parameters | No | +| `journal` | JSONL storage backend (implies `serde`) | No | +| `sobol` | Sobol quasi-random sampler | No | +| `cma-es` | CMA-ES sampler (requires `nalgebra`) | No | +| `gp` | Gaussian Process sampler (requires `nalgebra`) | No | +| 
`tracing` | Structured logging with `tracing` | No | + +## Examples + +```sh +cargo run --example basic_optimization # Minimize a quadratic — simplest possible usage +cargo run --example sampler_comparison # Compare Random, TPE, and Grid on the same problem +cargo run --example pruning_and_callbacks # Trial pruning with MedianPruner + early stopping +cargo run --example parameter_types --features derive # All 5 param types + #[derive(Categorical)] +cargo run --example advanced_features --features async,journal # Async, journal storage, ask-and-tell, multi-objective +``` -## Documentation +## Learn More -Full API documentation is available at [docs.rs/optimizer](https://docs.rs/optimizer). +- [API documentation](https://docs.rs/optimizer) +- [Changelog](CHANGELOG.md) +- [GitHub Issues](https://github.com/raimannma/rust-optimizer/issues) ## License diff --git a/cliff.toml b/cliff.toml new file mode 100644 index 0000000..c2166f6 --- /dev/null +++ b/cliff.toml @@ -0,0 +1,99 @@ +# git-cliff configuration +# https://git-cliff.org/docs/configuration + +[changelog] +# changelog header +header = """ +# Changelog + +All notable changes to this project will be documented in this file. 
+ +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).\n +""" +# template for the changelog body +# https://keats.github.io/tera/docs/#introduction +body = """ +{%- macro remote_url() -%} + https://github.com/raimannma/rust-optimizer +{%- endmacro -%} + +{% if version -%} + ## [{{ version | trim_start_matches(pat="v") }}] - {{ timestamp | date(format="%Y-%m-%d") }} +{% else -%} + ## [Unreleased] +{% endif -%} + +{% for group, commits in commits | group_by(attribute="group") %} + ### {{ group | striptags | trim }} + {% for commit in commits %} + - {{ commit.message | split(pat="\n") | first | trim }}\ + {%- if commit.breaking %} (**BREAKING**){% endif -%} + {% endfor %} +{% endfor %} +""" +# template for the changelog footer +footer = """ +{%- macro remote_url() -%} + https://github.com/raimannma/rust-optimizer +{%- endmacro -%} + +{% for release in releases -%} + {% if release.version -%} + {% if release.previous.version -%} + [{{ release.version | trim_start_matches(pat="v") }}]: \ + {{ self::remote_url() }}/compare/{{ release.previous.version }}...{{ release.version }} + {% else -%} + [{{ release.version | trim_start_matches(pat="v") }}]: \ + {{ self::remote_url() }}/releases/tag/{{ release.version }} + {% endif -%} + {% else -%} + {% if release.previous.version -%} + [Unreleased]: {{ self::remote_url() }}/compare/{{ release.previous.version }}...HEAD + {% endif -%} + {% endif -%} +{% endfor %} +""" +# remove the leading and trailing whitespace from the templates +trim = true + +[git] +# parse the commits based on https://www.conventionalcommits.org +conventional_commits = true +# filter out the commits that are not conventional +filter_unconventional = false +# process each line of a commit as an individual commit +split_commits = false +# regex for preprocessing the commit messages +commit_preprocessors = [ + { pattern = '\((\w+\s)?#([0-9]+)\)', 
replace = "([#${2}](https://github.com/raimannma/rust-optimizer/issues/${2}))" }, +] +# regex for parsing and grouping commits +commit_parsers = [ + { message = "^feat", group = "Added" }, + { message = "^fix", group = "Fixed" }, + { message = "^refactor", group = "Changed" }, + { message = "^perf", group = "Changed" }, + { message = "^doc", group = "Changed" }, + { message = "^style", group = "Changed" }, + { message = "^ci", skip = true }, + { message = "^chore\\(release\\)", skip = true }, + { message = "^chore: release", skip = true }, + { message = "^chore", group = "Changed" }, + { message = "^revert", group = "Removed" }, + { body = ".*security", group = "Security" }, +] +# protect breaking changes from being skipped due to matching a skipping commit_parser +protect_breaking_commits = true +# filter out the commits that are not matched by commit parsers +filter_commits = true +# regex for matching git tags +tag_pattern = "v[0-9].*" +# regex for skipping tags +skip_tags = "" +# regex for ignoring tags +ignore_tags = "" +# sort the tags topologically +topo_order = false +# sort the commits inside sections by oldest first +sort_commits = "oldest" diff --git a/examples/advanced_features.rs b/examples/advanced_features.rs new file mode 100644 index 0000000..41eed01 --- /dev/null +++ b/examples/advanced_features.rs @@ -0,0 +1,277 @@ +//! Advanced Features Example +//! +//! This example demonstrates four advanced capabilities of the optimizer crate: +//! +//! 1. **Async parallel optimization** — evaluate multiple trials concurrently +//! 2. **Journal storage** — persist trials to disk and resume studies later +//! 3. **Ask-and-tell interface** — decouple sampling from evaluation +//! 4. **Multi-objective optimization** — optimize competing objectives simultaneously +//! +//! 
Run with: `cargo run --example advanced_features --features "async,journal"` + +use std::time::Instant; + +use optimizer::multi_objective::MultiObjectiveStudy; +use optimizer::prelude::*; + +// ============================================================================ +// Section 1: Async Parallel Optimization +// ============================================================================ + +/// Runs multiple trials concurrently using tokio, reducing wall-clock time +/// when the objective function involves I/O or other async work. +async fn async_parallel_optimization() -> optimizer::Result<()> { + println!("=== Section 1: Async Parallel Optimization ===\n"); + + let sampler = TpeSampler::builder() + .n_startup_trials(5) + .seed(42) + .build() + .expect("Failed to build TPE sampler"); + + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + + let x = FloatParam::new(-5.0, 5.0).name("x"); + let y = FloatParam::new(-5.0, 5.0).name("y"); + + let n_trials = 30; + let concurrency = 4; + + println!("Running {n_trials} trials with {concurrency} concurrent workers..."); + let start = Instant::now(); + + // optimize_parallel spawns up to `concurrency` trials at once. + // The closure must take ownership of Trial and return (Trial, value). 
+ study + .optimize_parallel(n_trials, concurrency, { + let x = x.clone(); + let y = y.clone(); + move |mut trial| { + let x = x.clone(); + let y = y.clone(); + async move { + let xv = x.suggest(&mut trial)?; + let yv = y.suggest(&mut trial)?; + + // Simulate async I/O (e.g., calling an external service) + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + + // Sphere function: minimum at origin + let value = xv * xv + yv * yv; + Ok::<_, optimizer::Error>((trial, value)) + } + } + }) + .await?; + + let elapsed = start.elapsed(); + let best = study.best_trial()?; + + println!( + "Completed in {elapsed:.2?} (vs ~{:.0?} sequential)", + std::time::Duration::from_millis(10 * n_trials as u64) + ); + println!( + "Best: f({:.3}, {:.3}) = {:.6}\n", + best.get(&x).unwrap(), + best.get(&y).unwrap(), + best.value + ); + + Ok(()) +} + +// ============================================================================ +// Section 2: Journal Storage +// ============================================================================ + +/// Persists trials to a JSONL file so that a study can be resumed later. +/// Useful for long-running experiments or crash recovery. 
+fn journal_storage_demo() -> optimizer::Result<()> { + println!("=== Section 2: Journal Storage ===\n"); + + let path = std::env::temp_dir().join("optimizer_advanced_example.jsonl"); + + // Clean up from any previous run + let _ = std::fs::remove_file(&path); + + let x = FloatParam::new(-5.0, 5.0).name("x"); + + // --- First run: optimize 20 trials and persist to disk --- + { + let storage = JournalStorage::<f64>::new(&path); + let study: Study<f64> = Study::builder() + .minimize() + .sampler(TpeSampler::new()) + .storage(storage) + .build(); + + study.optimize(20, |trial| { + let xv = x.suggest(trial)?; + Ok::<_, optimizer::Error>(xv * xv) + })?; + + println!( + "First run: {} trials saved to {}", + study.n_trials(), + path.display() + ); + } + + // --- Second run: resume from the journal file --- + { + // JournalStorage::open loads existing trials from disk + let storage = JournalStorage::<f64>::open(&path)?; + let study: Study<f64> = Study::builder() + .minimize() + .sampler(TpeSampler::new()) + .storage(storage) + .build(); + + // The sampler sees the prior 20 trials, so it starts informed + let before = study.n_trials(); + study.optimize(10, |trial| { + let xv = x.suggest(trial)?; + Ok::<_, optimizer::Error>(xv * xv) + })?; + + let best = study.best_trial()?; + println!( + "Resumed: {} → {} trials, best f({:.4}) = {:.6}", + before, + study.n_trials(), + best.get(&x).unwrap(), + best.value + ); + } + + // Clean up the temporary file + let _ = std::fs::remove_file(&path); + + println!(); + Ok(()) +} + +// ============================================================================ +// Section 3: Ask-and-Tell Interface +// ============================================================================ + +/// Decouples trial creation from evaluation. 
Useful when: +/// - Evaluations happen outside the optimizer (e.g., in a separate process) +/// - You want to batch evaluations before reporting results +/// - You need custom scheduling logic +fn ask_and_tell_demo() -> optimizer::Result<()> { + println!("=== Section 3: Ask-and-Tell Interface ===\n"); + + let study: Study<f64> = Study::new(Direction::Minimize); + + let x = FloatParam::new(-5.0, 5.0).name("x"); + let y = FloatParam::new(-5.0, 5.0).name("y"); + + // Ask for a batch of trials, evaluate externally, then tell results + for batch in 0..3 { + let batch_size = 5; + let mut trials = Vec::with_capacity(batch_size); + + // ask() creates trials with sampled parameters + for _ in 0..batch_size { + let mut trial = study.ask(); + let xv = x.suggest(&mut trial)?; + let yv = y.suggest(&mut trial)?; + + // Store values alongside the trial for later evaluation + trials.push((trial, xv, yv)); + } + + // Evaluate the batch (could be sent to workers, GPUs, etc.) + for (trial, xv, yv) in trials { + let value = xv * xv + yv * yv; + // tell() reports the result back to the study + study.tell(trial, Ok::<_, &str>(value)); + } + + println!( + "Batch {}: evaluated {} trials (total: {})", + batch + 1, + batch_size, + study.n_trials() + ); + } + + let best = study.best_trial()?; + println!( + "Best: f({:.3}, {:.3}) = {:.6}\n", + best.get(&x).unwrap(), + best.get(&y).unwrap(), + best.value + ); + + Ok(()) +} + +// ============================================================================ +// Section 4: Multi-Objective Optimization +// ============================================================================ + +/// Optimizes two competing objectives simultaneously. +/// Returns the Pareto front — the set of solutions where no objective can +/// be improved without worsening the other. 
+fn multi_objective_demo() -> optimizer::Result<()> { + println!("=== Section 4: Multi-Objective Optimization ===\n"); + + // Two objectives, both minimized + let study = MultiObjectiveStudy::new(vec![Direction::Minimize, Direction::Minimize]); + + let x = FloatParam::new(0.0, 1.0).name("x"); + + // Classic bi-objective problem: f1(x) = x², f2(x) = (x - 1)² + // The Pareto front is the curve where improving f1 worsens f2 and vice versa. + study.optimize(50, |trial| { + let xv = x.suggest(trial)?; + let f1 = xv * xv; + let f2 = (xv - 1.0) * (xv - 1.0); + Ok::<_, optimizer::Error>(vec![f1, f2]) + })?; + + let front = study.pareto_front(); + println!( + "Ran {} trials, Pareto front has {} solutions:", + study.n_trials(), + front.len() + ); + + // Show a few Pareto-optimal trade-offs + let mut sorted_front = front.clone(); + sorted_front.sort_by(|a, b| a.values[0].partial_cmp(&b.values[0]).unwrap()); + + for (i, trial) in sorted_front.iter().take(5).enumerate() { + println!( + " {}: x={:.3}, f1={:.4}, f2={:.4}", + i + 1, + trial.get(&x).unwrap(), + trial.values[0], + trial.values[1] + ); + } + if sorted_front.len() > 5 { + println!(" ... and {} more", sorted_front.len() - 5); + } + + println!(); + Ok(()) +} + +// ============================================================================ +// Main +// ============================================================================ + +#[tokio::main] +async fn main() -> optimizer::Result<()> { + async_parallel_optimization().await?; + journal_storage_demo()?; + ask_and_tell_demo()?; + multi_objective_demo()?; + + println!("All sections completed successfully!"); + Ok(()) +} diff --git a/examples/async_api_optimization.rs b/examples/async_api_optimization.rs deleted file mode 100644 index 7b41c4f..0000000 --- a/examples/async_api_optimization.rs +++ /dev/null @@ -1,368 +0,0 @@ -//! Async API Parameter Optimization Example -//! -//! This example shows how to use async/parallel optimization to tune -//! 
configuration parameters for a web service. Each evaluation simulates -//! an async operation (like deploying and load-testing a service). -//! -//! # Key Concepts Demonstrated -//! -//! - Async optimization with `optimize_parallel` -//! - Running multiple trials concurrently for faster optimization -//! - Boolean and categorical parameter types -//! - Measuring speedup from parallelism -//! -//! # When to Use Async Optimization -//! -//! Use async/parallel optimization when your objective function involves: -//! - Network requests (API calls, database queries) -//! - File I/O operations -//! - External service calls -//! - Any operation where you're waiting for I/O rather than computing -//! -//! With parallelism, you can evaluate multiple configurations simultaneously, -//! significantly reducing total optimization time. -//! -//! Run with: `cargo run --example async_api_optimization --features async` - -use std::time::{Duration, Instant}; - -use optimizer::prelude::*; - -// ============================================================================ -// Configuration: Service parameters we want to tune -// ============================================================================ - -/// Configuration for a web service. 
-/// -/// In a real application, these parameters would control: -/// - Memory allocation (cache sizes) -/// - Connection management (pool sizes, timeouts) -/// - Request handling (batching, compression) -/// - Protocol options (HTTP version, load balancing) -struct ServiceConfig { - cache_size_mb: i64, - connection_pool_size: i64, - request_timeout_ms: i64, - retry_count: i64, - batch_size: i64, - compression_level: i64, - use_http2: bool, - load_balancing: String, -} - -// ============================================================================ -// Objective Function: Evaluate a service configuration -// ============================================================================ - -/// Simulates deploying and load-testing a service configuration. -/// -/// In a real scenario, this function might: -/// 1. Deploy the configuration to a staging environment -/// 2. Run load tests against the service -/// 3. Collect metrics (latency, throughput, error rate) -/// 4. Return a composite score -/// -/// The async sleep simulates the I/O time of these operations. -/// This is where parallel execution helps - while one trial is waiting -/// for I/O, other trials can run. 
-#[allow(clippy::too_many_arguments)] -async fn evaluate_service(config: &ServiceConfig) -> f64 { - // Simulate async I/O (deployment, load testing, metric collection) - tokio::time::sleep(Duration::from_millis(50)).await; - - // Calculate a score based on how close we are to optimal values - // Lower score = better configuration - let mut score = 0.0; - - // Cache size: too small = cache misses, too large = wasted memory - // Optimal around 512MB - let cache_optimal = 512.0; - score += ((config.cache_size_mb as f64 - cache_optimal) / 256.0).powi(2); - - // Connection pool: too small = contention, too large = resource waste - // Optimal around 100 - let pool_optimal = 100.0; - score += ((config.connection_pool_size as f64 - pool_optimal) / 50.0).powi(2); - - // Timeout: too short = false failures, too long = slow recovery - // Optimal around 5000ms - let timeout_optimal = 5000.0; - score += ((config.request_timeout_ms as f64 - timeout_optimal) / 2000.0).powi(2); - - // Retries: too few = fragile, too many = amplifies failures - // Optimal around 3 - let retry_optimal = 3.0; - score += ((config.retry_count as f64 - retry_optimal) / 2.0).powi(2); - - // Batch size: trade-off between latency and throughput - // Optimal around 64 - let batch_optimal = 64.0; - score += ((config.batch_size as f64 - batch_optimal) / 32.0).powi(2); - - // Compression level: trade-off between CPU and bandwidth - // Optimal around 6 - let compression_optimal = 6.0; - score += ((config.compression_level as f64 - compression_optimal) / 3.0).powi(2); - - // HTTP/2 is generally better for our use case - if !config.use_http2 { - score += 0.5; - } - - // Load balancing strategy affects performance - score += match config.load_balancing.as_str() { - "round_robin" => 0.0, // Best for our use case - "least_connections" => 0.1, // Good alternative - "ip_hash" => 0.2, // OK for session affinity - "random" => 0.3, // Not ideal - _ => 1.0, - }; - - // Add noise to simulate real-world variability - let 
noise = (config.cache_size_mb as f64 * 0.1).sin() * 0.05; - - score + noise -} - -/// The async objective function for each trial. -/// -/// For async optimization, the objective function must: -/// 1. Take ownership of the Trial (not a mutable reference) -/// 2. Return a Future -/// 3. Return both the Trial and the result value as a tuple -/// -/// This ownership pattern allows the trial to be used across await points. -#[allow(clippy::too_many_arguments)] -async fn objective( - mut trial: Trial, - cache_size_mb_param: &IntParam, - connection_pool_size_param: &IntParam, - request_timeout_ms_param: &IntParam, - retry_count_param: &IntParam, - batch_size_param: &IntParam, - compression_level_param: &IntParam, - use_http2_param: &BoolParam, - load_balancing_param: &CategoricalParam<&str>, -) -> optimizer::Result<(Trial, f64)> { - // Sample configuration parameters using parameter definitions - let cache_size_mb = cache_size_mb_param.suggest(&mut trial)?; - let connection_pool_size = connection_pool_size_param.suggest(&mut trial)?; - let request_timeout_ms = request_timeout_ms_param.suggest(&mut trial)?; - let retry_count = retry_count_param.suggest(&mut trial)?; - let batch_size = batch_size_param.suggest(&mut trial)?; - let compression_level = compression_level_param.suggest(&mut trial)?; - let use_http2 = use_http2_param.suggest(&mut trial)?; - let load_balancing = load_balancing_param.suggest(&mut trial)?; - - // Build configuration - let config = ServiceConfig { - cache_size_mb, - connection_pool_size, - request_timeout_ms, - retry_count, - batch_size, - compression_level, - use_http2, - load_balancing: load_balancing.to_string(), - }; - - // Evaluate (this is the async part) - let score = evaluate_service(&config).await; - - // Return both the trial and the score - Ok((trial, score)) -} - -// ============================================================================ -// Helper Functions -// 
============================================================================ - -/// Prints the results of the optimization. -fn print_results(study: &Study<f64>, elapsed: Duration, n_trials: usize) { - println!("\n{}", "=".repeat(60)); - println!("\nOptimization completed!"); - println!("Total trials: {}", study.n_trials()); - println!("Time elapsed: {elapsed:.2?}"); - - // Calculate speedup from parallelism - // Each trial takes ~50ms, so sequential would take n_trials * 50ms - let sequential_time = n_trials as f64 * 0.050; - let actual_time = elapsed.as_secs_f64(); - println!( - "Effective parallelism: {:.1}x speedup", - sequential_time / actual_time - ); -} - -/// Prints the best configuration found. -#[allow(clippy::too_many_arguments)] -fn print_best_config( - study: &Study<f64>, - cache_size_mb_param: &IntParam, - connection_pool_size_param: &IntParam, - request_timeout_ms_param: &IntParam, - retry_count_param: &IntParam, - batch_size_param: &IntParam, - compression_level_param: &IntParam, - use_http2_param: &BoolParam, - load_balancing_param: &CategoricalParam<&str>, -) -> optimizer::Result<()> { - let best = study.best_trial()?; - - println!("\nBest configuration found:"); - println!(" Score: {:.6}", best.value); - println!("\n Parameters:"); - println!( - " cache_size_mb: {}", - best.get(cache_size_mb_param).unwrap() - ); - println!( - " connection_pool_size: {}", - best.get(connection_pool_size_param).unwrap() - ); - println!( - " request_timeout_ms: {}", - best.get(request_timeout_ms_param).unwrap() - ); - println!(" retry_count: {}", best.get(retry_count_param).unwrap()); - println!(" batch_size: {}", best.get(batch_size_param).unwrap()); - println!( - " compression_level: {}", - best.get(compression_level_param).unwrap() - ); - println!(" use_http2: {}", best.get(use_http2_param).unwrap()); - println!( - " load_balancing: {}", - best.get(load_balancing_param).unwrap() - ); - - Ok(()) -} - -/// Prints the top N trials. 
-fn print_top_trials(study: &Study<f64>, n: usize) { - println!("\nTop {n} trials:"); - - let mut trials = study.trials(); - trials.sort_by(|a, b| a.value.partial_cmp(&b.value).unwrap()); - - for (i, trial) in trials.iter().take(n).enumerate() { - println!( - " {}. Trial #{}: score = {:.6}", - i + 1, - trial.id, - trial.value - ); - } -} - -// ============================================================================ -// Main: Set up and run the async optimization -// ============================================================================ - -#[tokio::main] -async fn main() -> optimizer::Result<()> { - println!("=== Async API Parameter Optimization Example ===\n"); - - // Step 1: Create a TPE sampler - let sampler = TpeSampler::builder() - .n_startup_trials(8) - .gamma(0.2) - .seed(123) - .build() - .expect("Failed to build TPE sampler"); - - // Step 2: Create a study to minimize the score - let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - - // Step 3: Define parameter search spaces - let cache_size_mb_param = IntParam::new(64, 1024).name("cache_size_mb").step(64); - let connection_pool_size_param = IntParam::new(10, 200).name("connection_pool_size").step(10); - let request_timeout_ms_param = IntParam::new(1000, 10000) - .name("request_timeout_ms") - .step(500); - let retry_count_param = IntParam::new(0, 5).name("retry_count"); - let batch_size_param = IntParam::new(1, 256).name("batch_size").log_scale(); - let compression_level_param = IntParam::new(0, 9).name("compression_level"); - let use_http2_param = BoolParam::new().name("use_http2"); - let load_balancing_param = CategoricalParam::new(vec![ - "round_robin", - "least_connections", - "random", - "ip_hash", - ]) - .name("load_balancing"); - - // Clone params for use after the closure moves them - let cache_size_mb_p = cache_size_mb_param.clone(); - let connection_pool_size_p = connection_pool_size_param.clone(); - let request_timeout_ms_p = request_timeout_ms_param.clone(); - 
let retry_count_p = retry_count_param.clone(); - let batch_size_p = batch_size_param.clone(); - let compression_level_p = compression_level_param.clone(); - let use_http2_p = use_http2_param.clone(); - let load_balancing_p = load_balancing_param.clone(); - - // Step 4: Configure optimization - let n_trials = 40; - let concurrency = 4; // Run 4 trials in parallel - - println!("Starting parallel optimization with {concurrency} concurrent evaluations...\n"); - - let start = Instant::now(); - - // Step 5: Run parallel async optimization - // - // optimize_parallel: - // - Runs up to `concurrency` trials simultaneously - // - Each trial calls the objective function - // - Uses a semaphore to limit concurrent evaluations - // - Collects results as trials complete - // - // The sampler gets access to trial history for informed sampling. - study - .optimize_parallel(n_trials, concurrency, move |trial| { - let cache_size_mb_param = cache_size_mb_param.clone(); - let connection_pool_size_param = connection_pool_size_param.clone(); - let request_timeout_ms_param = request_timeout_ms_param.clone(); - let retry_count_param = retry_count_param.clone(); - let batch_size_param = batch_size_param.clone(); - let compression_level_param = compression_level_param.clone(); - let use_http2_param = use_http2_param.clone(); - let load_balancing_param = load_balancing_param.clone(); - async move { - objective( - trial, - &cache_size_mb_param, - &connection_pool_size_param, - &request_timeout_ms_param, - &retry_count_param, - &batch_size_param, - &compression_level_param, - &use_http2_param, - &load_balancing_param, - ) - .await - } - }) - .await?; - - let elapsed = start.elapsed(); - - // Step 5: Print results - print_results(&study, elapsed, n_trials); - print_best_config( - &study, - &cache_size_mb_p, - &connection_pool_size_p, - &request_timeout_ms_p, - &retry_count_p, - &batch_size_p, - &compression_level_p, - &use_http2_p, - &load_balancing_p, - )?; - print_top_trials(&study, 5); - - 
Ok(()) -} diff --git a/examples/basic_optimization.rs b/examples/basic_optimization.rs new file mode 100644 index 0000000..097853a --- /dev/null +++ b/examples/basic_optimization.rs @@ -0,0 +1,32 @@ +//! Basic optimization example — the "hello world" of the optimizer crate. +//! +//! Minimizes a simple quadratic function f(x) = (x - 3)² using the default +//! random sampler. No feature flags are required. +//! +//! Run with: `cargo run --example basic_optimization` + +use optimizer::prelude::*; + +fn main() { + // Create a study that minimizes the objective function. + // The default sampler is random; for smarter sampling, pass a TpeSampler. + let study: Study<f64> = Study::new(Direction::Minimize); + + // Search for x in [-10, 10]. The optimizer will suggest values from this range. + let x = FloatParam::new(-10.0, 10.0).name("x"); + + // Run 50 trials, each evaluating f(x) = (x - 3)² + study + .optimize(50, |trial| { + let x_val = x.suggest(trial)?; + let value = (x_val - 3.0).powi(2); + Ok::<_, Error>(value) + }) + .unwrap(); + + // Retrieve and display the best result + let best = study.best_trial().unwrap(); + println!("Best trial #{}", best.id); + println!(" x = {:.4}", best.get(&x).unwrap()); + println!(" f(x) = {:.4}", best.value); +} diff --git a/examples/benchmark_convergence.rs b/examples/benchmark_convergence.rs deleted file mode 100644 index 9f478a7..0000000 --- a/examples/benchmark_convergence.rs +++ /dev/null @@ -1,124 +0,0 @@ -use std::ops::ControlFlow; -use std::time::Instant; - -use optimizer::parameter::Parameter; -use optimizer::sampler::random::RandomSampler; -use optimizer::sampler::tpe::TpeSampler; -use optimizer::{FloatParam, Study}; - -/// Standard optimization test functions. 
-mod functions { - pub fn sphere(x: &[f64]) -> f64 { - x.iter().map(|xi| xi * xi).sum() - } - - pub fn rosenbrock(x: &[f64]) -> f64 { - x.windows(2) - .map(|w| 100.0 * (w[1] - w[0] * w[0]).powi(2) + (1.0 - w[0]).powi(2)) - .sum() - } - - pub fn rastrigin(x: &[f64]) -> f64 { - let n = x.len() as f64; - 10.0 * n - + x.iter() - .map(|xi| xi * xi - 10.0 * (2.0 * std::f64::consts::PI * xi).cos()) - .sum::<f64>() - } -} - -fn run_convergence( - name: &str, - sampler_name: &str, - study: Study<f64>, - params: &[FloatParam], - objective: fn(&[f64]) -> f64, - n_trials: usize, -) { - let start = Instant::now(); - - study - .optimize_with_callback( - n_trials, - |trial| { - let x: Vec<f64> = params - .iter() - .map(|p| p.suggest(trial)) - .collect::<Result<_, _>>() - .unwrap(); - Ok::<_, optimizer::Error>(objective(&x)) - }, - |study, _trial| { - let elapsed = start.elapsed().as_millis(); - let best = study.best_value().unwrap(); - let n = study.n_trials(); - println!("{n},{best},{elapsed},{sampler_name},{name}"); - ControlFlow::Continue(()) - }, - ) - .unwrap(); -} - -fn main() { - println!("trial,best_value,elapsed_ms,sampler,function"); - - let dims = 5; - let params: Vec<FloatParam> = (0..dims) - .map(|i| FloatParam::new(-5.0, 5.0).name(format!("x{i}"))) - .collect(); - let n_trials = 200; - - // Sphere: Random vs TPE - run_convergence( - "sphere_5d", - "random", - Study::minimize(RandomSampler::with_seed(1)), - ¶ms, - functions::sphere, - n_trials, - ); - run_convergence( - "sphere_5d", - "tpe", - Study::minimize(TpeSampler::builder().seed(1).build().unwrap()), - ¶ms, - functions::sphere, - n_trials, - ); - - // Rosenbrock: Random vs TPE - run_convergence( - "rosenbrock_5d", - "random", - Study::minimize(RandomSampler::with_seed(2)), - ¶ms, - functions::rosenbrock, - n_trials, - ); - run_convergence( - "rosenbrock_5d", - "tpe", - Study::minimize(TpeSampler::builder().seed(2).build().unwrap()), - ¶ms, - functions::rosenbrock, - n_trials, - ); - - // Rastrigin: Random vs 
TPE - run_convergence( - "rastrigin_5d", - "random", - Study::minimize(RandomSampler::with_seed(3)), - ¶ms, - functions::rastrigin, - n_trials, - ); - run_convergence( - "rastrigin_5d", - "tpe", - Study::minimize(TpeSampler::builder().seed(3).build().unwrap()), - ¶ms, - functions::rastrigin, - n_trials, - ); -} diff --git a/examples/ml_hyperparameter_tuning.rs b/examples/ml_hyperparameter_tuning.rs deleted file mode 100644 index d73973d..0000000 --- a/examples/ml_hyperparameter_tuning.rs +++ /dev/null @@ -1,275 +0,0 @@ -//! Machine Learning Hyperparameter Tuning Example -//! -//! This example shows how to use the optimizer library to find the best -//! hyperparameters for a machine learning model. We simulate a gradient -//! boosting model (like XGBoost or LightGBM) and search for optimal settings. -//! -//! # Key Concepts Demonstrated -//! -//! - Creating a Study with a TPE (Tree-Parzen Estimator) sampler -//! - Defining an objective function that the optimizer will minimize -//! - Using different parameter types: floats, integers, log-scale, stepped -//! - Using callbacks to monitor progress and implement early stopping -//! -//! # How It Works -//! -//! 1. Create a `Study` - this manages the optimization process -//! 2. Define an objective function that takes a `Trial` and returns a score -//! 3. Inside the objective, use `trial.suggest_*()` to sample parameters -//! 4. The optimizer runs many trials, learning which parameter regions work best -//! 5. After optimization, retrieve the best parameters found -//! -//! Run with: `cargo run --example ml_hyperparameter_tuning` - -use std::ops::ControlFlow; - -use optimizer::prelude::*; - -// ============================================================================ -// Configuration: Hyperparameters we want to tune -// ============================================================================ - -/// Holds all the hyperparameters for our model. 
-/// -/// In a real application, you would pass these to your ML framework -/// (e.g., XGBoost, LightGBM, scikit-learn). -struct ModelConfig { - learning_rate: f64, - max_depth: i64, - n_estimators: i64, - subsample: f64, - colsample_bytree: f64, - min_child_weight: i64, - reg_alpha: f64, - reg_lambda: f64, -} - -// ============================================================================ -// Objective Function: What we want to optimize -// ============================================================================ - -/// Simulates training a model and returns the validation loss. -/// -/// In a real scenario, this function would: -/// 1. Create a model with the given hyperparameters -/// 2. Train it on your training data -/// 3. Evaluate it on validation data -/// 4. Return the validation metric (e.g., RMSE, log loss, accuracy) -/// -/// The optimizer will try to MINIMIZE this value (we set Direction::Minimize). -#[allow(clippy::too_many_arguments)] -fn evaluate_model(config: &ModelConfig) -> f64 { - // Simulated optimal hyperparameters: - // learning_rate ~ 0.05, max_depth ~ 6, n_estimators ~ 200 - // subsample ~ 0.8, colsample_bytree ~ 0.8, min_child_weight ~ 3 - // reg_alpha ~ 0.1, reg_lambda ~ 1.0 - - let mut loss = 0.15; // Base loss - - // Each term penalizes deviation from the optimal value - loss += (config.learning_rate - 0.05).powi(2) * 100.0; - loss += ((config.max_depth - 6) as f64).powi(2) * 0.01; - loss += ((config.n_estimators - 200) as f64).powi(2) * 0.00001; - loss += (config.subsample - 0.8).powi(2) * 10.0; - loss += (config.colsample_bytree - 0.8).powi(2) * 10.0; - loss += ((config.min_child_weight - 3) as f64).powi(2) * 0.05; - loss += (config.reg_alpha - 0.1).powi(2) * 5.0; - loss += (config.reg_lambda - 1.0).powi(2) * 2.0; - - // Add some noise to simulate real-world variability - let noise = (config.learning_rate * 1000.0).sin() * 0.01; - - loss + noise -} - -/// The objective function that the optimizer calls for each trial. 
-/// -/// This function: -/// 1. Uses parameter definitions passed as arguments -/// 2. Builds a model configuration from the suggested values -/// 3. Evaluates the model and returns the loss -/// -/// The optimizer learns from the results to suggest better parameters -/// in future trials. -#[allow(clippy::too_many_arguments)] -fn objective( - trial: &mut Trial, - learning_rate_param: &FloatParam, - max_depth_param: &IntParam, - n_estimators_param: &IntParam, - subsample_param: &FloatParam, - colsample_bytree_param: &FloatParam, - min_child_weight_param: &IntParam, - reg_alpha_param: &FloatParam, - reg_lambda_param: &FloatParam, -) -> optimizer::Result<f64> { - let learning_rate = learning_rate_param.suggest(trial)?; - let max_depth = max_depth_param.suggest(trial)?; - let n_estimators = n_estimators_param.suggest(trial)?; - let subsample = subsample_param.suggest(trial)?; - let colsample_bytree = colsample_bytree_param.suggest(trial)?; - let min_child_weight = min_child_weight_param.suggest(trial)?; - let reg_alpha = reg_alpha_param.suggest(trial)?; - let reg_lambda = reg_lambda_param.suggest(trial)?; - - // Build configuration and evaluate - let config = ModelConfig { - learning_rate, - max_depth, - n_estimators, - subsample, - colsample_bytree, - min_child_weight, - reg_alpha, - reg_lambda, - }; - - let loss = evaluate_model(&config); - - Ok(loss) -} - -// ============================================================================ -// Callback Function: Monitor progress and implement early stopping -// ============================================================================ - -/// Called after each successful trial completes. -/// -/// Use callbacks to: -/// - Log progress to console or file -/// - Save checkpoints -/// - Implement early stopping when a good solution is found -/// - Track metrics over time -/// -/// Return `ControlFlow::Continue(())` to keep optimizing. -/// Return `ControlFlow::Break(())` to stop early. 
-fn on_trial_complete(study: &Study<f64>, trial: &CompletedTrial<f64>) -> ControlFlow<()> { - // Print trial number and objective value - print!("{:>5} ", study.n_trials()); - for value in trial.params.values() { - print!("{value:>12} "); - } - println!("{:>12.6}", trial.value); - - // Early stopping: if we find an excellent solution, stop early - if trial.value < 0.16 { - println!("\nEarly stopping: found excellent solution!"); - return ControlFlow::Break(()); - } - - ControlFlow::Continue(()) -} - -// ============================================================================ -// Main: Set up and run the optimization -// ============================================================================ - -fn main() -> optimizer::Result<()> { - println!("=== ML Hyperparameter Tuning Example ===\n"); - - // Step 1: Create a sampler - // - // TPE (Tree-Parzen Estimator) is a Bayesian optimization algorithm. - // It learns from previous trials to suggest better parameters. - // - n_startup_trials: Number of random trials before TPE kicks in - // - gamma: What fraction of trials are considered "good" (lower = more selective) - // - seed: For reproducibility - let sampler = TpeSampler::builder() - .n_startup_trials(10) - .gamma(0.25) - .seed(42) - .build() - .expect("Failed to build TPE sampler"); - - // Step 2: Create a study - // - // The study manages the optimization process. We want to MINIMIZE - // the loss (lower is better). Use Direction::Maximize for metrics - // where higher is better (like accuracy). - let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - - // Print header - println!("Starting hyperparameter optimization...\n"); - println!( - "{:>5} {:>12} (parameters...) 
{:>12}", - "Trial", "Params", "Loss" - ); - println!("{}", "-".repeat(60)); - - // Step 3: Define parameter search spaces - let learning_rate_param = FloatParam::new(0.001, 0.3) - .name("learning_rate") - .log_scale(); - let max_depth_param = IntParam::new(3, 12).name("max_depth"); - let n_estimators_param = IntParam::new(50, 500).name("n_estimators").step(50); - let subsample_param = FloatParam::new(0.5, 1.0).name("subsample"); - let colsample_bytree_param = FloatParam::new(0.5, 1.0).name("colsample_bytree"); - let min_child_weight_param = IntParam::new(1, 10).name("min_child_weight"); - let reg_alpha_param = FloatParam::new(1e-3, 10.0).name("reg_alpha").log_scale(); - let reg_lambda_param = FloatParam::new(1e-3, 10.0).name("reg_lambda").log_scale(); - - // Step 4: Run optimization - // - // optimize_with_callback runs the objective function for up to - // n_trials iterations. After each trial, it calls the callback. - // The sampler gets access to trial history for informed sampling. - let n_trials = 50; - - study.optimize_with_callback( - n_trials, - |trial| { - objective( - trial, - &learning_rate_param, - &max_depth_param, - &n_estimators_param, - &subsample_param, - &colsample_bytree_param, - &min_child_weight_param, - ®_alpha_param, - ®_lambda_param, - ) - }, - on_trial_complete, - )?; - - // Step 4: Get the best result - println!("\n{}", "=".repeat(110)); - println!("\nOptimization completed!"); - println!("Total trials: {}", study.n_trials()); - - let best = study.best_trial()?; - println!("\nBest trial:"); - println!(" Loss: {:.6}", best.value); - println!(" Parameters:"); - println!( - " learning_rate: {:.6}", - best.get(&learning_rate_param).unwrap() - ); - println!(" max_depth: {}", best.get(&max_depth_param).unwrap()); - println!( - " n_estimators: {}", - best.get(&n_estimators_param).unwrap() - ); - println!(" subsample: {:.6}", best.get(&subsample_param).unwrap()); - println!( - " colsample_bytree: {:.6}", - 
best.get(&colsample_bytree_param).unwrap() - ); - println!( - " min_child_weight: {}", - best.get(&min_child_weight_param).unwrap() - ); - println!(" reg_alpha: {:.6}", best.get(®_alpha_param).unwrap()); - println!( - " reg_lambda: {:.6}", - best.get(®_lambda_param).unwrap() - ); - - // Step 5: Use the best parameters (in a real app) - // - // Now you would take best.params and use them to train your final model - // on the full dataset. - - Ok(()) -} diff --git a/examples/parameter_api.rs b/examples/parameter_api.rs deleted file mode 100644 index 119fcf4..0000000 --- a/examples/parameter_api.rs +++ /dev/null @@ -1,56 +0,0 @@ -use optimizer::prelude::*; -use optimizer_derive::Categorical; - -#[derive(Clone, Debug, Categorical)] -enum Activation { - Relu, - Sigmoid, - Tanh, -} - -fn main() { - let study: Study<f64> = Study::new(Direction::Minimize); - - // Define parameters outside the objective function - let lr_param = FloatParam::new(1e-5, 1e-1).name("lr").log_scale(); - let n_layers_param = IntParam::new(1, 5).name("n_layers"); - let units_param = IntParam::new(32, 512).name("units").step(32); - let optimizer_param = CategoricalParam::new(vec!["sgd", "adam", "rmsprop"]).name("optimizer"); - let activation_param = EnumParam::<Activation>::new().name("activation"); - let batch_size_param = IntParam::new(16, 256).name("batch_size").log_scale(); - let use_dropout_param = BoolParam::new().name("use_dropout"); - - study - .optimize(20, |trial| { - let lr = lr_param.suggest(trial)?; - let n_layers = n_layers_param.suggest(trial)?; - let units = units_param.suggest(trial)?; - let optimizer = optimizer_param.suggest(trial)?; - let use_dropout = use_dropout_param.suggest(trial)?; - let activation = activation_param.suggest(trial)?; - let batch_size = batch_size_param.suggest(trial)?; - - // Simulate a loss function - let loss = lr * (n_layers as f64) + (units as f64) * 0.001 - - if use_dropout { 0.1 } else { 0.0 }; - - println!( - "Trial {}: lr={lr:.6}, layers={n_layers}, 
units={units}, opt={optimizer}, \ - dropout={use_dropout}, activation={activation:?}, batch={batch_size} -> loss={loss:.4}", - trial.id() - ); - - Ok::<_, Error>(loss) - }) - .unwrap(); - - let best = study.best_trial().unwrap(); - println!("\nBest trial: value={:.4}", best.value); - println!(" lr: {:.6}", best.get(&lr_param).unwrap()); - println!(" n_layers: {}", best.get(&n_layers_param).unwrap()); - println!(" units: {}", best.get(&units_param).unwrap()); - println!(" optimizer: {}", best.get(&optimizer_param).unwrap()); - println!(" activation: {:?}", best.get(&activation_param).unwrap()); - println!(" batch_size: {}", best.get(&batch_size_param).unwrap()); - println!(" use_dropout: {}", best.get(&use_dropout_param).unwrap()); -} diff --git a/examples/parameter_types.rs b/examples/parameter_types.rs new file mode 100644 index 0000000..c509c5b --- /dev/null +++ b/examples/parameter_types.rs @@ -0,0 +1,73 @@ +//! Parameter types example — demonstrates all five parameter types and the derive feature. +//! +//! Shows `FloatParam`, `IntParam`, `CategoricalParam`, `BoolParam`, and `EnumParam` +//! with `.name()` labels, `#[derive(Categorical)]` for enums, and typed access +//! to results via `CompletedTrial::get()`. +//! +//! Run with: `cargo run --example parameter_types --features derive` + +use optimizer::prelude::*; +use optimizer_derive::Categorical; + +/// Activation functions — `#[derive(Categorical)]` auto-generates the +/// `Categorical` trait, mapping each variant to a sequential index. 
+#[derive(Clone, Debug, Categorical)] +enum Activation { + Relu, + Sigmoid, + Tanh, + Gelu, +} + +fn main() { + let study: Study<f64> = Study::new(Direction::Minimize); + + // --- Define one of each parameter type, each with a human-readable .name() --- + + // Float: learning rate on a log scale (common for ML hyperparameters) + let lr = FloatParam::new(1e-5, 1e-1).log_scale().name("lr"); + + // Int: number of hidden layers (stepped by 1, the default) + let n_layers = IntParam::new(1, 5).name("n_layers"); + + // Categorical: optimizer algorithm chosen from a list of strings + let optimizer = CategoricalParam::new(vec!["sgd", "adam", "rmsprop"]).name("optimizer"); + + // Bool: whether to apply dropout + let use_dropout = BoolParam::new().name("use_dropout"); + + // Enum: activation function — uses #[derive(Categorical)] above + let activation = EnumParam::<Activation>::new().name("activation"); + + // --- Run the optimization --- + study + .optimize(30, |trial| { + let lr_val = lr.suggest(trial)?; + let layers = n_layers.suggest(trial)?; + let opt = optimizer.suggest(trial)?; + let dropout = use_dropout.suggest(trial)?; + let act = activation.suggest(trial)?; + + // Simulated loss that depends on all parameters + let loss = lr_val * f64::from(layers as i32) + + if opt == "adam" { -0.05 } else { 0.0 } + + if dropout { -0.02 } else { 0.0 } + + match act { + Activation::Gelu => -0.03, + Activation::Relu => -0.01, + _ => 0.0, + }; + + Ok::<_, Error>(loss) + }) + .unwrap(); + + // --- Retrieve best trial and read back each parameter with typed .get() --- + let best = study.best_trial().unwrap(); + println!("Best trial #{} — loss = {:.6}", best.id, best.value); + println!(" lr = {:.6}", best.get(&lr).unwrap()); + println!(" n_layers = {}", best.get(&n_layers).unwrap()); + println!(" optimizer = {}", best.get(&optimizer).unwrap()); + println!(" use_dropout = {}", best.get(&use_dropout).unwrap()); + println!(" activation = {:?}", best.get(&activation).unwrap()); +} diff 
--git a/examples/pruning_and_callbacks.rs b/examples/pruning_and_callbacks.rs new file mode 100644 index 0000000..207e55b --- /dev/null +++ b/examples/pruning_and_callbacks.rs @@ -0,0 +1,133 @@ +//! Pruning and early-stopping example — demonstrates trial pruning with `MedianPruner` +//! and early stopping via `optimize_with_callback`. +//! +//! Simulates a training loop where each trial trains for multiple "epochs". The pruner +//! stops unpromising trials early, and a callback halts the entire study once a target +//! loss is reached. +//! +//! Run with: `cargo run --example pruning_and_callbacks` + +use std::ops::ControlFlow; + +use optimizer::TrialState; +use optimizer::prelude::*; + +fn main() -> optimizer::Result<()> { + let n_trials: usize = 30; + let n_epochs: u64 = 20; + let target_loss = 0.15; + + // Build a study with a seeded random sampler and MedianPruner. + // MedianPruner compares each trial's intermediate value against the median of + // completed trials at the same step — trials performing below median are pruned. + let study: Study<f64> = Study::builder() + .minimize() + .sampler(RandomSampler::with_seed(42)) + .pruner( + MedianPruner::new(Direction::Minimize) + .n_warmup_steps(3) // let every trial run at least 3 epochs before pruning + .n_min_trials(3), // need 3 completed trials before pruning kicks in + ) + .build(); + + let learning_rate = FloatParam::new(1e-4, 1.0).name("learning_rate"); + let momentum = FloatParam::new(0.0, 0.99).name("momentum"); + + // Use optimize_with_callback to get both pruning AND early stopping. + // The callback fires after each completed (or pruned) trial and can halt the study. + study.optimize_with_callback( + n_trials, + // --- Objective function: simulated training loop with pruning --- + |trial| { + let lr = learning_rate.suggest(trial)?; + let mom = momentum.suggest(trial)?; + + // Simulate training for n_epochs, reporting intermediate loss each epoch. 
+ // Good hyperparameters (lr ≈ 0.01, momentum ≈ 0.8) converge to low loss; + // bad combos plateau high — giving the pruner something to cut. + let mut loss = 1.0; + for epoch in 0..n_epochs { + let lr_penalty = (lr.log10() - 0.01_f64.log10()).powi(2); // 0 at lr=0.01 + let mom_penalty = (mom - 0.8).powi(2); // 0 at momentum=0.8 + let base_loss = 0.02 + 0.05 * lr_penalty + 1.5 * mom_penalty; + let progress = (epoch as f64 + 1.0) / n_epochs as f64; + // Loss decays from 1.0 toward base_loss over epochs. + loss = base_loss + (1.0 - base_loss) * (-3.5 * progress).exp(); + + // Report the intermediate value so the pruner can evaluate this trial. + trial.report(epoch, loss); + + // Check whether the pruner recommends stopping this trial early. + if trial.should_prune() { + // Signal that this trial was pruned — the study records it as Pruned. + Err(TrialPruned)?; + } + } + + Ok::<_, Error>(loss) + }, + // --- Callback: early stopping when we hit the target --- + |study, completed_trial| { + let n_complete = study.n_trials(); + let n_pruned = study + .trials() + .iter() + .filter(|t| t.state == TrialState::Pruned) + .count(); + + match completed_trial.state { + TrialState::Pruned => { + println!( + " Trial {:>3} PRUNED at epoch {} (loss = {:.4}) \ + [{n_complete} done, {n_pruned} pruned]", + completed_trial.id, + completed_trial.intermediate_values.len(), + completed_trial + .intermediate_values + .last() + .map_or(f64::NAN, |v| v.1), + ); + } + TrialState::Complete => { + println!( + " Trial {:>3} complete: loss = {:.4} \ + [{n_complete} done, {n_pruned} pruned]", + completed_trial.id, completed_trial.value, + ); + } + _ => {} + } + + // Stop the entire study once we find a good enough result. 
+ if completed_trial.state == TrialState::Complete && completed_trial.value < target_loss + { + println!("\n Early stopping: reached target loss {target_loss}!"); + return ControlFlow::Break(()); + } + + ControlFlow::Continue(()) + }, + )?; + + // --- Results --- + let best = study.best_trial().expect("at least one completed trial"); + let total = study.n_trials(); + let pruned = study + .trials() + .iter() + .filter(|t| t.state == TrialState::Pruned) + .count(); + + println!("\n--- Results ---"); + println!(" Total trials : {total}"); + println!(" Pruned : {pruned}"); + println!(" Completed : {}", total - pruned); + println!(" Best trial #{}: loss = {:.6}", best.id, best.value); + println!( + " learning_rate = {:.6}", + best.get(&learning_rate).unwrap() + ); + println!(" momentum = {:.4}", best.get(&momentum).unwrap()); + + Ok(()) +} diff --git a/examples/sampler_comparison.rs b/examples/sampler_comparison.rs new file mode 100644 index 0000000..b743e2a --- /dev/null +++ b/examples/sampler_comparison.rs @@ -0,0 +1,94 @@ +//! Sampler comparison example — benchmarks Random, TPE, and Grid samplers on the same problem. +//! +//! Runs the Sphere function f(x, y) = x² + y² with each sampler and compares the best +//! value found. This shows how sampler choice affects optimization quality. +//! +//! Run with: `cargo run --example sampler_comparison` + +use optimizer::prelude::*; + +/// Shared objective function: Sphere function with global minimum at (0, 0). +/// Simple enough to solve well, but 2-D so samplers have room to differ. +fn sphere(x: f64, y: f64) -> f64 { + x.powi(2) + y.powi(2) +} + +/// Run an optimization study and return the best value found. +fn run_study(study: Study<f64>, n_trials: usize) -> f64 { + // Use asymmetric ranges so the Grid sampler tracks each parameter independently. 
+ let x = FloatParam::new(-5.0, 5.0).name("x"); + let y = FloatParam::new(-3.0, 3.0).name("y"); + + study + .optimize(n_trials, |trial| { + let x_val = x.suggest(trial)?; + let y_val = y.suggest(trial)?; + Ok::<_, Error>(sphere(x_val, y_val)) + }) + .unwrap(); + + let best = study.best_trial().unwrap(); + println!( + " Best trial #{:>3}: x = {:>7.4}, y = {:>7.4}, f(x,y) = {:.6}", + best.id, + best.get(&x).unwrap(), + best.get(&y).unwrap(), + best.value, + ); + best.value +} + +fn main() { + let n_trials: usize = 100; + println!("Comparing samplers on Sphere(x, y) = x² + y²"); + println!(" Search space: x ∈ [-5, 5], y ∈ [-3, 3]"); + println!(" Known minimum: f(0, 0) = 0"); + println!(" Trials per sampler: {n_trials}"); + println!(); + + // --- Random sampler (baseline) --- + // Pure random search: samples uniformly at random. Fast but not guided. + println!("1. Random sampler:"); + let random_best = run_study(Study::minimize(RandomSampler::with_seed(42)), n_trials); + + // --- TPE sampler (Bayesian) --- + // Tree-structured Parzen Estimator: builds a probabilistic model of good vs bad + // regions and focuses sampling where improvements are likely. + println!("\n2. TPE sampler (Bayesian):"); + let tpe = TpeSampler::builder() + .n_startup_trials(10) // random exploration for the first 10 trials + .n_ei_candidates(24) // candidates evaluated per Expected Improvement step + .gamma(0.25) // top 25% of trials define the "good" distribution + .seed(42) + .build() + .unwrap(); + let tpe_best = run_study(Study::minimize(tpe), n_trials); + + // --- Grid sampler (exhaustive) --- + // Evaluates evenly spaced grid points. Each parameter gets its own grid that + // is sampled in order, so n_points_per_param must be >= n_trials. + println!("\n3. 
Grid sampler (exhaustive):"); + let grid = GridSearchSampler::builder() + .n_points_per_param(n_trials) // one grid point per trial per parameter + .build(); + let grid_best = run_study(Study::minimize(grid), n_trials); + + // --- Summary --- + println!("\n--- Summary ---"); + println!(" Random : {random_best:.6}"); + println!(" TPE : {tpe_best:.6}"); + println!(" Grid : {grid_best:.6}"); + println!(); + + // Find the winner + let results = [ + ("Random", random_best), + ("TPE", tpe_best), + ("Grid", grid_best), + ]; + let (winner, _) = results + .iter() + .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap()) + .unwrap(); + println!("Winner: {winner} (closest to known minimum of 0.0)"); +} diff --git a/examples/visualization_demo.rs b/examples/visualization_demo.rs deleted file mode 100644 index e2efd76..0000000 --- a/examples/visualization_demo.rs +++ /dev/null @@ -1,45 +0,0 @@ -use optimizer::prelude::*; - -fn main() { - // Multi-parameter optimization with TPE sampler. - let sampler = TpeSampler::builder().seed(42).build().unwrap(); - let mut study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - study.set_pruner(MedianPruner::new(Direction::Minimize)); - - let lr = FloatParam::new(1e-5, 1e-1) - .log_scale() - .name("learning_rate"); - let n_layers = IntParam::new(1, 5).name("n_layers"); - let dropout = FloatParam::new(0.0, 0.5).step(0.05).name("dropout"); - let batch_size = CategoricalParam::new(vec![16, 32, 64, 128]).name("batch_size"); - - study - .optimize(80, |trial| { - let lr_val = lr.suggest(trial)?; - let layers = n_layers.suggest(trial)?; - let drop = dropout.suggest(trial)?; - let bs = batch_size.suggest(trial)?; - - // Simulate training with intermediate reporting. 
- let mut loss = 1.0; - for epoch in 0..10 { - loss *= 0.7 + 0.3 * lr_val.ln().abs() / 12.0; - loss += drop * 0.05; - loss += (1.0 / bs as f64) * 0.1; - loss -= layers as f64 * 0.02; - trial.report(epoch, loss); - if trial.should_prune() { - return Err(TrialPruned.into()); - } - } - - Ok::<_, Error>(loss) - }) - .unwrap(); - - println!("{}", study.summary()); - - let path = "optimization_report.html"; - generate_html_report(&study, path).unwrap(); - println!("\nReport saved to {path}"); -} diff --git a/src/error.rs b/src/error.rs index 130534e..70b2d98 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,21 @@ +//! Error types for the optimizer crate. +//! +//! All fallible operations in the crate return [`Result<T>`], which is an +//! alias for `core::result::Result<T, Error>`. The [`Error`] enum covers +//! parameter validation, sampling conflicts, pruning signals, and +//! feature-gated I/O errors. + +/// Errors returned by optimizer operations. +/// +/// Most variants are returned during parameter validation or trial +/// management. The [`TrialPruned`](Error::TrialPruned) variant has special +/// significance — it signals early stopping and is typically raised via +/// the [`TrialPruned`](super::TrialPruned) convenience type. #[derive(Debug, thiserror::Error)] pub enum Error { - /// Returned when the lower bound is greater than the upper bound. + /// The lower bound exceeds the upper bound in a + /// [`FloatParam`](crate::parameter::FloatParam) or + /// [`IntParam`](crate::parameter::IntParam). #[error("invalid bounds: low ({low}) must be less than or equal to high ({high})")] InvalidBounds { /// The lower bound value. @@ -9,19 +24,22 @@ pub enum Error { high: f64, }, - /// Returned when log scale is used with non-positive bounds. + /// Log-scale is enabled but the lower bound is not positive (float) or + /// is less than 1 (integer). 
#[error("invalid log bounds: low must be positive for log scale")] InvalidLogBounds, - /// Returned when step size is not positive. + /// The step size provided to a parameter is not positive. #[error("invalid step: step must be positive")] InvalidStep, - /// Returned when categorical choices are empty. + /// A [`CategoricalParam`](crate::parameter::CategoricalParam) was created + /// with an empty choices vector. #[error("categorical choices cannot be empty")] EmptyChoices, - /// Returned when a parameter is suggested with a different configuration. + /// The same [`ParamId`](crate::parameter::ParamId) was suggested twice + /// with a different distribution configuration. #[error("parameter conflict for '{name}': {reason}")] ParameterConflict { /// The name of the conflicting parameter. @@ -30,27 +48,29 @@ pub enum Error { reason: String, }, - /// Returned when requesting the best trial but no trials have completed. + /// [`Study::best_trial`](crate::Study::best_trial) or similar was called + /// before any trial completed successfully. #[error("no completed trials available")] NoCompletedTrials, - /// Returned when gamma is not in the valid range (0.0, 1.0). + /// The gamma value for TPE sampling is outside the open interval (0, 1). #[error("invalid gamma: {0} must be in (0.0, 1.0)")] InvalidGamma(f64), - /// Returned when bandwidth is not positive. + /// A KDE bandwidth value is not positive. #[error("invalid bandwidth: {0} must be positive")] InvalidBandwidth(f64), - /// Returned when KDE is created with empty samples. + /// A kernel density estimator was constructed with no samples. #[error("KDE requires at least one sample")] EmptySamples, - /// Returned when multivariate KDE samples have zero dimensions. + /// Multivariate KDE samples have zero dimensions. #[error("multivariate KDE samples must have at least one dimension")] ZeroDimensions, - /// Returned when multivariate KDE samples have inconsistent dimensions. 
+ /// A sample in the multivariate KDE has a different number of dimensions + /// than the first sample. #[error( "dimension mismatch: expected {expected} dimensions but sample {sample_index} has {got}" )] @@ -63,7 +83,7 @@ pub enum Error { sample_index: usize, }, - /// Returned when bandwidth vector length doesn't match the number of dimensions. + /// The bandwidth vector length does not match the number of KDE dimensions. #[error("bandwidth dimension mismatch: expected {expected} bandwidths but got {got}")] BandwidthDimensionMismatch { /// The expected number of bandwidths. @@ -72,11 +92,14 @@ pub enum Error { got: usize, }, - /// Returned when a trial is pruned (stopped early by the objective function). + /// The objective signalled that this trial should be pruned (stopped + /// early). Typically raised via `Err(TrialPruned)?` inside the + /// objective closure. #[error("trial was pruned")] TrialPruned, - /// Returned when the objective returns the wrong number of values. + /// The multi-objective closure returned a different number of values + /// than the number of directions configured on the study. #[error("objective dimension mismatch: expected {expected} values, got {got}")] ObjectiveDimensionMismatch { /// The expected number of objective values. @@ -85,21 +108,24 @@ pub enum Error { got: usize, }, - /// Returned when an internal invariant is violated. + /// An internal invariant was violated. This indicates a bug in the + /// library rather than a user error. #[error("internal error: {0}")] Internal(&'static str), - /// Returned when an async task fails. + /// An async worker task failed. Only available with the `async` feature. #[cfg(feature = "async")] #[error("async task error: {0}")] TaskError(String), - /// Returned when a storage operation fails. + /// A storage I/O operation failed. Only available with the `journal` + /// feature. 
#[cfg(feature = "journal")] #[error("storage error: {0}")] Storage(String), } +/// A convenience alias for `core::result::Result<T, Error>`. pub type Result<T> = core::result::Result<T, Error>; /// Convenience type for signalling a pruned trial from an objective function. diff --git a/src/fanova.rs b/src/fanova.rs index 52882f6..91b6af0 100644 --- a/src/fanova.rs +++ b/src/fanova.rs @@ -1,26 +1,83 @@ //! fANOVA (functional ANOVA) parameter importance via random forest. //! -//! Decomposes the variance of the objective function into contributions -//! from individual parameters (main effects) and parameter interactions. +//! fANOVA decomposes the variance of the objective function into +//! contributions from individual parameters (**main effects**) and +//! parameter pairs (**interaction effects**). This helps answer the +//! question: *"Which parameters matter most, and do any parameters +//! interact?"* //! -//! The algorithm: -//! 1. Fits a random forest to `(parameters) -> objective_value` -//! 2. Applies functional ANOVA decomposition to the forest -//! 3. Computes main effects (single-parameter importance) -//! 4. Computes interaction effects (pairwise parameter importance) +//! # Algorithm +//! +//! 1. Fit a random forest to the mapping `(parameters) → objective` +//! 2. Apply functional ANOVA decomposition to the trained forest +//! 3. Compute main effects: the variance explained by each parameter alone +//! 4. Compute interaction effects: the additional variance explained by +//! pairs of parameters beyond their individual contributions +//! 5. Normalize so all importances sum to 1.0 +//! +//! # When to use +//! +//! - **After optimization**: call [`Study::fanova()`](crate::Study::fanova) +//! or [`Study::fanova_with_config()`](crate::Study::fanova_with_config) +//! to identify which parameters had the most impact +//! - **Interaction detection**: unlike Spearman correlation +//! ([`Study::param_importance()`](crate::Study::param_importance)), +//! 
fANOVA can detect non-linear relationships and parameter interactions +//! - **Hyperparameter tuning**: focus tuning effort on high-importance +//! parameters and fix low-importance ones to reasonable defaults +//! +//! # Reference +//! +//! Hutter, F., Hoos, H. & Leyton-Brown, K. (2014). "An Efficient +//! Approach for Assessing Hyperparameter Importance." ICML 2014. +//! +//! # Example +//! +//! ``` +//! use optimizer::prelude::*; +//! +//! let study: Study<f64> = Study::new(Direction::Minimize); +//! let x = FloatParam::new(0.0, 10.0).name("x"); +//! let y = FloatParam::new(0.0, 10.0).name("y"); +//! +//! study +//! .optimize(50, |trial| { +//! let xv = x.suggest(trial)?; +//! let yv = y.suggest(trial)?; +//! // x matters much more than y +//! Ok::<_, optimizer::Error>(3.0 * xv + 0.1 * yv) +//! }) +//! .unwrap(); +//! +//! let result = study.fanova().unwrap(); +//! // Main effects sorted by descending importance +//! assert_eq!(result.main_effects[0].0, "x"); +//! ``` /// Result of fANOVA analysis. +/// +/// All importance values are fractions of total variance and sum to 1.0 +/// across main effects and interactions combined. #[derive(Debug, Clone)] pub struct FanovaResult { /// Per-parameter importance (fraction of total variance explained). - /// Sorted by descending importance. + /// + /// Sorted by descending importance. Each entry is + /// `(parameter_name, importance)` where importance is in `[0.0, 1.0]`. pub main_effects: Vec<(String, f64)>, /// Pairwise interaction importance (fraction of total variance explained). - /// Sorted by descending importance. + /// + /// Sorted by descending importance. Each entry is + /// `((param_a, param_b), importance)`. Only pairs with non-negligible + /// interaction (> 1e-10) are included. pub interactions: Vec<((String, String), f64)>, } /// Configuration for fANOVA analysis. +/// +/// Use [`Default::default()`] for reasonable settings, or customize +/// the random forest parameters for specific needs. 
Pass to +/// [`Study::fanova_with_config()`](crate::Study::fanova_with_config). #[derive(Debug, Clone)] pub struct FanovaConfig { /// Number of trees in the random forest (default: 64). diff --git a/src/importance.rs b/src/importance.rs index 8590302..90b6e76 100644 --- a/src/importance.rs +++ b/src/importance.rs @@ -1,4 +1,26 @@ //! Parameter importance via Spearman rank correlation. +//! +//! Compute the absolute Spearman rank correlation between each parameter +//! and the objective value to estimate which parameters most influence +//! the outcome. This is a lightweight, non-parametric alternative to +//! [`fANOVA`](crate::fanova) that works well for monotonic relationships. +//! +//! # How it works +//! +//! 1. Rank parameter values and objective values independently +//! 2. Compute the Pearson correlation on the ranks (= Spearman ρ) +//! 3. Take the absolute value (direction of correlation is not relevant +//! for importance) +//! +//! # When to use +//! +//! - **Quick importance check**: call +//! [`Study::param_importance()`](crate::Study::param_importance) after +//! optimization for a fast, interpretable ranking +//! - **Monotonic relationships**: Spearman captures monotonic (not just +//! linear) correlations but may miss non-monotonic effects or interactions +//! - For interaction detection or non-linear importance, use +//! [`fANOVA`](crate::fanova) instead /// Assign average ranks to a slice of `f64` values (handles ties). #[allow(clippy::cast_precision_loss, clippy::float_cmp)] diff --git a/src/lib.rs b/src/lib.rs index 1d79dcd..8df4aec 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,194 +9,79 @@ #![deny(clippy::pedantic)] #![deny(clippy::std_instead_of_core)] -//! A black-box optimization library with multiple sampling strategies. +//! Bayesian and population-based optimization library with an Optuna-like API +//! for hyperparameter tuning and black-box optimization. It ships 12 samplers +//! 
(from random search to CMA-ES and NSGA-III), 8 pruners, async/parallel +//! evaluation, and optional journal-based persistence — all with zero required +//! feature flags for the common case. //! -//! This library provides an Optuna-like API for hyperparameter optimization -//! with support for multiple sampling algorithms: +//! # Getting Started //! -//! - **Random Search** - Simple random sampling for baseline comparisons -//! - **TPE (Tree-Parzen Estimator)** - Bayesian optimization for efficient search -//! - **Grid Search** - Exhaustive search over a specified parameter grid -//! - **Sobol (QMC)** - Quasi-random sampling for better space coverage (requires `sobol` feature) -//! - **CMA-ES** - Covariance Matrix Adaptation Evolution Strategy for continuous optimization (requires `cma-es` feature) -//! - **DE** - Differential Evolution for population-based global optimization -//! - **GP** - Gaussian Process Bayesian optimization with Expected Improvement (requires `gp` feature) -//! - **BOHB** - Bayesian Optimization + `HyperBand` for budget-aware TPE sampling -//! - **NSGA-II** - Non-dominated Sorting Genetic Algorithm II for multi-objective optimization -//! - **NSGA-III** - Reference-point-based NSGA for many-objective (3+) optimization -//! - **MOEA/D** - Decomposition-based multi-objective with Tchebycheff, Weighted Sum, or PBI -//! - **MOTPE** - Multi-Objective Tree-Parzen Estimator for Bayesian multi-objective optimization -//! -//! Additional features include: -//! -//! - Float, integer, and categorical parameter types -//! - Log-scale and stepped parameter sampling -//! - Synchronous and async optimization -//! - Parallel trial evaluation with bounded concurrency -//! -//! # Quick Start +//! Minimize a function in five lines — no feature flags needed: //! //! ``` //! use optimizer::prelude::*; //! -//! // Create a study with TPE sampler -//! let sampler = TpeSampler::builder().seed(42).build().unwrap(); -//! 
let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); -//! -//! // Define parameter search space +//! let study: Study<f64> = Study::new(Direction::Minimize); //! let x = FloatParam::new(-10.0, 10.0).name("x"); //! -//! // Optimize x^2 for 20 trials //! study -//! .optimize(20, |trial| { -//! let x_val = x.suggest(trial)?; -//! Ok::<_, Error>(x_val * x_val) +//! .optimize(50, |trial| { +//! let v = x.suggest(trial)?; +//! Ok::<_, Error>((v - 3.0).powi(2)) //! }) //! .unwrap(); //! -//! // Get the best result //! let best = study.best_trial().unwrap(); -//! println!("x = {}", best.get(&x).unwrap()); -//! ``` -//! -//! # Creating a Study -//! -//! A [`Study`] manages optimization trials. Create one with an optimization direction: -//! -//! ``` -//! use optimizer::sampler::random::RandomSampler; -//! use optimizer::sampler::tpe::TpeSampler; -//! use optimizer::{Direction, Study}; -//! -//! // Minimize with default random sampler -//! let study: Study<f64> = Study::new(Direction::Minimize); -//! -//! // Maximize with TPE sampler -//! let study: Study<f64> = Study::with_sampler(Direction::Maximize, TpeSampler::new()); -//! -//! // With seeded sampler for reproducibility -//! let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(42)); -//! ``` -//! -//! # Suggesting Parameters -//! -//! Within the objective function, use parameter types to suggest values: -//! -//! ``` -//! use optimizer::parameter::{BoolParam, CategoricalParam, FloatParam, IntParam, Parameter}; -//! use optimizer::{Direction, Study}; -//! -//! let study: Study<f64> = Study::new(Direction::Minimize); -//! -//! // Define parameter search spaces -//! let x_param = FloatParam::new(0.0, 1.0); -//! let lr_param = FloatParam::new(1e-5, 1e-1).log_scale(); -//! let step_param = FloatParam::new(0.0, 1.0).step(0.1); -//! let n_param = IntParam::new(1, 10); -//! let batch_param = IntParam::new(16, 256).log_scale(); -//! 
let units_param = IntParam::new(32, 512).step(32); -//! let flag_param = BoolParam::new(); -//! let optimizer_param = CategoricalParam::new(vec!["sgd", "adam", "rmsprop"]); -//! -//! study -//! .optimize(10, |trial| { -//! let x = x_param.suggest(trial)?; -//! let lr = lr_param.suggest(trial)?; -//! let step = step_param.suggest(trial)?; -//! let n = n_param.suggest(trial)?; -//! let batch = batch_param.suggest(trial)?; -//! let units = units_param.suggest(trial)?; -//! let flag = flag_param.suggest(trial)?; -//! let optimizer = optimizer_param.suggest(trial)?; -//! -//! Ok::<_, optimizer::Error>(x * n as f64) -//! }) -//! .unwrap(); -//! ``` -//! -//! # Available Samplers -//! -//! ## Random Search -//! -//! The simplest sampling strategy, useful for baselines: -//! -//! ``` -//! use optimizer::sampler::random::RandomSampler; -//! use optimizer::{Direction, Study}; -//! -//! let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(42)); -//! ``` -//! -//! ## TPE (Tree-Parzen Estimator) -//! -//! Bayesian optimization that learns from previous trials: -//! -//! ``` -//! use optimizer::sampler::tpe::TpeSampler; -//! -//! let sampler = TpeSampler::builder() -//! .gamma(0.15) // Quantile for good/bad split -//! .n_startup_trials(20) // Random trials before TPE -//! .n_ei_candidates(32) // Candidates to evaluate -//! .seed(42) // Reproducibility -//! .build() -//! .unwrap(); -//! ``` -//! -//! ## Grid Search -//! -//! Exhaustive search over a discretized parameter space: -//! +//! println!("x = {:.4}, f(x) = {:.4}", best.get(&x).unwrap(), best.value); //! ``` -//! use optimizer::sampler::grid::GridSearchSampler; -//! use optimizer::{Direction, Study}; //! -//! let sampler = GridSearchSampler::builder() -//! .n_points_per_param(10) // Points per parameter dimension -//! .build(); +//! # Core Concepts //! -//! let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); -//! ``` -//! -//! 
# Async and Parallel Optimization +//! | Type | Role | +//! |------|------| +//! | [`Study`] | Drive an optimization loop: create trials, record results, track the best. | +//! | [`Trial`] | A single evaluation of the objective function, carrying suggested parameter values. | +//! | [`Parameter`] | Define the search space — [`FloatParam`], [`IntParam`], [`CategoricalParam`], [`BoolParam`], [`EnumParam`]. | +//! | [`Sampler`](sampler::Sampler) | Strategy for choosing the next point to evaluate (TPE, CMA-ES, random, etc.). | +//! | [`Direction`] | Whether the study minimizes or maximizes the objective value. | //! -//! With the `async` feature enabled, you can run trials asynchronously: +//! # Sampler Guide //! -//! ```ignore -//! use optimizer::{Study, Direction}; -//! use optimizer::parameter::{FloatParam, Parameter}; +//! ## Single-objective samplers //! -//! let x_param = FloatParam::new(0.0, 1.0); +//! | Sampler | Algorithm | Best for | Feature flag | +//! |---------|-----------|----------|--------------| +//! | [`RandomSampler`] | Uniform random | Baselines, high-dimensional | — | +//! | [`TpeSampler`] | Tree-Parzen Estimator | General-purpose Bayesian | — | +//! | [`GridSearchSampler`] | Exhaustive grid | Small, discrete spaces | — | +//! | [`SobolSampler`] | Sobol quasi-random sequence | Space-filling, low dimensions | `sobol` | +//! | [`CmaEsSampler`] | CMA-ES | Continuous, moderate dimensions | `cma-es` | +//! | [`GpSampler`] | Gaussian Process + EI | Expensive objectives, few trials | `gp` | +//! | [`DifferentialEvolutionSampler`] | Differential Evolution | Non-convex, population-based | — | +//! | [`BohbSampler`] | BOHB (TPE + `HyperBand`) | Budget-aware early stopping | — | //! -//! // Sequential async -//! study.optimize_async(10, |mut trial| { -//! let x_param = x_param.clone(); -//! async move { -//! let x = x_param.suggest(&mut trial)?; -//! Ok((trial, x * x)) -//! } -//! }).await?; +//! ## Multi-objective samplers //! -//! 
// Parallel with bounded concurrency -//! study.optimize_parallel(10, 4, |mut trial| { -//! let x_param = x_param.clone(); -//! async move { -//! let x = x_param.suggest(&mut trial)?; -//! Ok((trial, x * x)) -//! } -//! }).await?; -//! ``` +//! | Sampler | Algorithm | Best for | Feature flag | +//! |---------|-----------|----------|--------------| +//! | [`Nsga2Sampler`] | NSGA-II | 2-3 objectives | — | +//! | [`Nsga3Sampler`] | NSGA-III (reference-point) | 3+ objectives | — | +//! | [`MoeadSampler`] | MOEA/D (decomposition) | Many objectives, structured fronts | — | +//! | [`MotpeSampler`] | Multi-Objective TPE | Bayesian multi-objective | — | //! //! # Feature Flags //! -//! - `async`: Enable async optimization methods (requires tokio) -//! - `derive`: Enable `#[derive(Categorical)]` for enum parameters -//! - `serde`: Enable `Serialize`/`Deserialize` on public types and `Study::save()`/`Study::load()` -//! - `sobol`: Enable the Sobol quasi-random sampler for better space coverage -//! - `cma-es`: Enable the CMA-ES sampler for continuous optimization -//! - `gp`: Enable the Gaussian Process sampler for Bayesian optimization -//! - `visualization`: Generate self-contained HTML reports with interactive Plotly.js charts -//! - `tracing`: Emit structured log events via the [`tracing`](https://docs.rs/tracing) crate at key optimization points +//! | Flag | What it enables | Default | +//! |------|----------------|---------| +//! | `async` | Async/parallel optimization via tokio ([`Study::optimize_async`], [`Study::optimize_parallel`]) | off | +//! | `derive` | `#[derive(Categorical)]` for enum parameters | off | +//! | `serde` | `Serialize`/`Deserialize` on public types, [`Study::save`]/[`Study::load`] | off | +//! | `journal` | [`JournalStorage`] — JSONL persistence with file locking (enables `serde`) | off | +//! | `sobol` | [`SobolSampler`] — quasi-random low-discrepancy sequences | off | +//! 
| `cma-es` | [`CmaEsSampler`] — Covariance Matrix Adaptation Evolution Strategy | off | +//! | `gp` | [`GpSampler`] — Gaussian Process surrogate with Expected Improvement | off | +//! | `tracing` | Structured log events via [`tracing`](https://docs.rs/tracing) at key optimization points | off | /// Emit a `tracing::info!` event when the `tracing` feature is enabled. /// No-op otherwise. diff --git a/src/multi_objective.rs b/src/multi_objective.rs index 5a4b3b8..d47ce37 100644 --- a/src/multi_objective.rs +++ b/src/multi_objective.rs @@ -1,9 +1,28 @@ //! Multi-objective optimization via a dedicated study type. //! -//! [`MultiObjectiveStudy`] manages trials that return multiple objective -//! values. It supports arbitrary numbers of objectives with per-objective -//! directions (minimize or maximize). Use [`pareto_front()`](MultiObjectiveStudy::pareto_front) -//! to retrieve the Pareto-optimal solutions. +//! [`MultiObjectiveStudy`] manages trials that return **multiple** objective +//! values simultaneously. It supports arbitrary numbers of objectives with +//! per-objective directions (minimize or maximize). +//! +//! # Key concepts +//! +//! In multi-objective optimization there is usually no single best solution. +//! Instead, there is a **Pareto front** — the set of solutions where no +//! objective can be improved without worsening another. Use +//! [`pareto_front()`](MultiObjectiveStudy::pareto_front) to retrieve these +//! non-dominated solutions after optimization. +//! +//! A solution **dominates** another if it is at least as good in all +//! objectives and strictly better in at least one. Solutions that are not +//! dominated by any other are called **Pareto-optimal**. +//! +//! # Samplers +//! +//! By default a random sampler is used. For smarter search, pass a +//! [`MultiObjectiveSampler`] such as [`Nsga2Sampler`](crate::Nsga2Sampler), +//! [`Nsga3Sampler`](crate::Nsga3Sampler), or +//! [`MoeadSampler`](crate::MoeadSampler) via +//! 
[`MultiObjectiveStudy::with_sampler`]. //! //! # Examples //! @@ -46,6 +65,11 @@ use crate::types::{Direction, TrialState}; // --------------------------------------------------------------------------- /// A completed trial with multiple objective values. +/// +/// Each trial stores its sampled parameter values, the vector of +/// objective values (one per objective), and optional constraint values. +/// Retrieve typed parameter values with [`get()`](Self::get) and check +/// constraint feasibility with [`is_feasible()`](Self::is_feasible). #[derive(Clone, Debug)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct MultiObjectiveTrial { @@ -111,9 +135,14 @@ impl MultiObjectiveTrial { /// Trait for samplers aware of multi-objective history. /// -/// Separate from [`Sampler`] because NSGA-II needs access to -/// `&[MultiObjectiveTrial]` (with vector-valued objectives) and -/// `&[Direction]` (one direction per objective). +/// Separate from [`Sampler`] because multi-objective algorithms (e.g., +/// NSGA-II) need access to the full vector of objective values per trial +/// (`&[MultiObjectiveTrial]`) and the per-objective directions +/// (`&[Direction]`). +/// +/// Implementations include [`Nsga2Sampler`](crate::Nsga2Sampler), +/// [`Nsga3Sampler`](crate::Nsga3Sampler), and +/// [`MoeadSampler`](crate::MoeadSampler). pub trait MultiObjectiveSampler: Send + Sync { /// Samples a parameter value from the given distribution. fn sample( @@ -181,9 +210,12 @@ impl Sampler for MoSamplerBridge { /// A study for multi-objective optimization. /// -/// Manages trials that return multiple objective values. Supports +/// Manage trials that return multiple objective values. Supports /// arbitrary numbers of objectives with independent minimize/maximize -/// directions. +/// directions. After optimization, call [`pareto_front()`](Self::pareto_front) +/// to retrieve the non-dominated solutions. 
+/// +/// For single-objective optimization, use [`Study`](crate::Study) instead. /// /// # Examples /// @@ -269,7 +301,11 @@ impl MultiObjectiveStudy { self.completed_trials.read().clone() } - /// Returns the Pareto-optimal trials (front 0). + /// Return the Pareto-optimal trials (the non-dominated front). + /// + /// Uses fast non-dominated sorting (Deb et al., 2002) from the + /// [`pareto`](crate::pareto) module. Returns an empty vec if no + /// trials have completed. #[must_use] pub fn pareto_front(&self) -> Vec<MultiObjectiveTrial> { let trials = self.completed_trials.read(); diff --git a/src/param.rs b/src/param.rs index 6dd3e60..864a19e 100644 --- a/src/param.rs +++ b/src/param.rs @@ -1,18 +1,35 @@ -//! Parameter value storage types. +//! Raw parameter value storage. +//! +//! [`ParamValue`] is the type-erased representation of a sampled parameter. +//! Users rarely construct `ParamValue` directly — the +//! [`Parameter::suggest`](crate::parameter::Parameter::suggest) method returns +//! the already-typed value (e.g., `f64` for [`FloatParam`](crate::parameter::FloatParam)). +//! +//! `ParamValue` is useful when inspecting raw trial data via +//! [`Trial::params`](crate::Trial::params) or +//! [`CompletedTrial::params`](crate::sampler::CompletedTrial). -/// Represents a sampled parameter value. +/// A type-erased sampled parameter value. /// -/// This enum stores different parameter value types uniformly. -/// For categorical parameters, the `Categorical` variant stores -/// the index into the choices array. +/// Stores float, integer, or categorical (index) values uniformly. +/// For categorical parameters the `Categorical` variant stores the +/// zero-based index into the choices array, not the choice itself. +/// +/// # Display +/// +/// `ParamValue` implements [`Display`](core::fmt::Display): floats and +/// integers print their numeric value, and categoricals print `category(i)`. 
#[derive(Clone, Debug, PartialEq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub enum ParamValue { - /// A floating-point parameter value. + /// A floating-point parameter value (from [`FloatParam`](crate::parameter::FloatParam)). Float(f64), - /// An integer parameter value. + /// An integer parameter value (from [`IntParam`](crate::parameter::IntParam)). Int(i64), - /// A categorical parameter value, stored as an index into the choices array. + /// A categorical index into the choices array (from + /// [`CategoricalParam`](crate::parameter::CategoricalParam), + /// [`BoolParam`](crate::parameter::BoolParam), or + /// [`EnumParam`](crate::parameter::EnumParam)). Categorical(usize), } diff --git a/src/parameter.rs b/src/parameter.rs index 34f5fd3..6c48931 100644 --- a/src/parameter.rs +++ b/src/parameter.rs @@ -1,8 +1,19 @@ -//! Central parameter trait and built-in parameter types. +//! Parameter trait and five built-in parameter types. //! -//! The [`Parameter`] trait provides a unified way to define parameter types -//! and suggest values from a [`Trial`]. Built-in implementations -//! cover floats, integers, categoricals, booleans, and enum types. +//! The [`Parameter`] trait provides a unified way to define search-space +//! dimensions and sample values from a [`Trial`]. Five implementations +//! cover the most common hyperparameter types: +//! +//! | Type | Sampled value | Typical use | +//! |------|---------------|-------------| +//! | [`FloatParam`] | `f64` | Learning rate, dropout probability | +//! | [`IntParam`] | `i64` | Layer count, batch size | +//! | [`CategoricalParam`] | `T: Clone` | Optimizer name, activation function | +//! | [`BoolParam`] | `bool` | Feature toggle | +//! | [`EnumParam`] | `T: Categorical` | Typed enum variant selection | +//! +//! All five types support `.name()` for a human-readable label and +//! `.suggest(&mut trial)` as a shorthand for `trial.suggest_param(¶m)`. //! //! # Example //! 
@@ -14,10 +25,17 @@ //! //! let lr = FloatParam::new(1e-5, 1e-1) //! .log_scale() +//! .name("learning_rate") +//! .suggest(&mut trial) +//! .unwrap(); +//! let layers = IntParam::new(1, 10) +//! .name("n_layers") +//! .suggest(&mut trial) +//! .unwrap(); +//! let dropout = BoolParam::new() +//! .name("use_dropout") //! .suggest(&mut trial) //! .unwrap(); -//! let layers = IntParam::new(1, 10).suggest(&mut trial).unwrap(); -//! let dropout = BoolParam::new().suggest(&mut trial).unwrap(); //! ``` use core::fmt::Debug; @@ -42,7 +60,7 @@ static NEXT_PARAM_ID: AtomicU64 = AtomicU64::new(0); pub struct ParamId(u64); impl ParamId { - /// Creates a new unique `ParamId`. + /// Create a new unique `ParamId`. pub fn new() -> Self { Self(NEXT_PARAM_ID.fetch_add(1, Ordering::Relaxed)) } @@ -60,52 +78,67 @@ impl core::fmt::Display for ParamId { } } -/// A trait for defining parameter types that can be suggested by a [`Trial`]. +/// Define a parameter type that can be suggested by a [`Trial`]. /// /// Implementors specify the distribution to sample from and how to convert -/// the raw [`ParamValue`] back into a typed value. +/// the raw [`ParamValue`] back into a typed value. See the five built-in +/// implementations: [`FloatParam`], [`IntParam`], [`CategoricalParam`], +/// [`BoolParam`], and [`EnumParam`]. pub trait Parameter: Debug { /// The typed value returned after sampling. type Value; - /// Returns the unique identifier for this parameter. + /// Return the unique identifier for this parameter. fn id(&self) -> ParamId; - /// Returns the distribution that this parameter samples from. + /// Return the distribution that this parameter samples from. fn distribution(&self) -> Distribution; - /// Converts a raw [`ParamValue`] into the typed value. + /// Convert a raw [`ParamValue`] into the typed value. /// /// # Errors /// - /// Returns an error if the `ParamValue` variant doesn't match what this parameter expects. 
+ /// Return an error if the `ParamValue` variant does not match what this parameter expects. fn cast_param_value(&self, param_value: &ParamValue) -> Result<Self::Value>; - /// Validates the parameter configuration. + /// Validate the parameter configuration. /// /// Called before sampling. The default implementation accepts all configurations. /// /// # Errors /// - /// Returns an error if the parameter configuration is invalid. + /// Return an error if the parameter configuration is invalid. fn validate(&self) -> Result<()> { Ok(()) } - /// Returns a human-readable label for this parameter. + /// Return a human-readable label for this parameter. /// - /// Defaults to the `Debug` output of the parameter. + /// Defaults to the `Debug` output of the parameter. Override with + /// the `.name()` builder method on concrete types. fn label(&self) -> String { format!("{self:?}") } - /// Suggests a value for this parameter from the given trial. + /// Suggest a value for this parameter from the given trial. /// /// This is a convenience method that delegates to [`Trial::suggest_param`]. /// + /// # Examples + /// + /// ``` + /// use optimizer::Trial; + /// use optimizer::parameter::{FloatParam, Parameter}; + /// + /// let mut trial = Trial::new(0); + /// let param = FloatParam::new(-5.0, 5.0).name("x"); + /// let value: f64 = param.suggest(&mut trial).unwrap(); + /// assert!((-5.0..=5.0).contains(&value)); + /// ``` + /// /// # Errors /// - /// Returns an error if validation fails, the parameter conflicts with + /// Return an error if validation fails, the parameter conflicts with /// a previously suggested parameter of the same id, or sampling fails. fn suggest(&self, trial: &mut Trial) -> Result<Self::Value> where @@ -117,7 +150,7 @@ pub trait Parameter: Debug { /// A floating-point parameter with optional log-scale and step size. 
/// -/// # Example +/// # Examples /// /// ``` /// use optimizer::Trial; @@ -128,13 +161,14 @@ pub trait Parameter: Debug { /// // Simple range /// let x = FloatParam::new(0.0, 1.0).suggest(&mut trial).unwrap(); /// -/// // Log-scale +/// // Log-scale with a human-readable name /// let lr = FloatParam::new(1e-5, 1e-1) /// .log_scale() +/// .name("learning_rate") /// .suggest(&mut trial) /// .unwrap(); /// -/// // Stepped +/// // Stepped (values will be multiples of 0.25) /// let step = FloatParam::new(0.0, 1.0) /// .step(0.25) /// .suggest(&mut trial) @@ -151,7 +185,7 @@ pub struct FloatParam { } impl FloatParam { - /// Creates a new float parameter with the given bounds. + /// Create a new float parameter sampling uniformly from `[low, high]`. #[must_use] pub fn new(low: f64, high: f64) -> Self { Self { @@ -164,21 +198,21 @@ impl FloatParam { } } - /// Enables log-scale sampling. + /// Enable log-scale sampling (bounds must be positive). #[must_use] pub fn log_scale(mut self) -> Self { self.log_scale = true; self } - /// Sets a step size for discretized sampling. + /// Set a step size for discretized sampling. #[must_use] pub fn step(mut self, step: f64) -> Self { self.step = Some(step); self } - /// Sets a human-readable name for this parameter. + /// Set a human-readable name for this parameter. /// /// When set, this name is used as the parameter's label instead of /// the default `Debug` output. @@ -245,7 +279,7 @@ impl Parameter for FloatParam { /// An integer parameter with optional log-scale and step size. 
/// -/// # Example +/// # Examples /// /// ``` /// use optimizer::Trial; @@ -254,15 +288,19 @@ impl Parameter for FloatParam { /// let mut trial = Trial::new(0); /// /// // Simple range -/// let n = IntParam::new(1, 10).suggest(&mut trial).unwrap(); +/// let n = IntParam::new(1, 10) +/// .name("n_layers") +/// .suggest(&mut trial) +/// .unwrap(); /// /// // Log-scale /// let batch = IntParam::new(1, 1024) /// .log_scale() +/// .name("batch_size") /// .suggest(&mut trial) /// .unwrap(); /// -/// // Stepped +/// // Stepped (multiples of 32) /// let units = IntParam::new(32, 512).step(32).suggest(&mut trial).unwrap(); /// ``` #[derive(Clone, Debug)] @@ -276,7 +314,7 @@ pub struct IntParam { } impl IntParam { - /// Creates a new integer parameter with the given bounds. + /// Create a new integer parameter sampling uniformly from `[low, high]`. #[must_use] pub fn new(low: i64, high: i64) -> Self { Self { @@ -289,21 +327,21 @@ impl IntParam { } } - /// Enables log-scale sampling. + /// Enable log-scale sampling (bounds must be ≥ 1). #[must_use] pub fn log_scale(mut self) -> Self { self.log_scale = true; self } - /// Sets a step size for discretized sampling. + /// Set a step size for discretized sampling. #[must_use] pub fn step(mut self, step: i64) -> Self { self.step = Some(step); self } - /// Sets a human-readable name for this parameter. + /// Set a human-readable name for this parameter. /// /// When set, this name is used as the parameter's label instead of /// the default `Debug` output. @@ -370,7 +408,10 @@ impl Parameter for IntParam { /// A categorical parameter that selects from a list of choices. /// -/// # Example +/// The generic type `T` is the element type of the choices vector. +/// The sampler picks an index and the corresponding element is returned. 
+/// +/// # Examples /// /// ``` /// use optimizer::Trial; @@ -378,6 +419,7 @@ impl Parameter for IntParam { /// /// let mut trial = Trial::new(0); /// let opt = CategoricalParam::new(vec!["sgd", "adam", "rmsprop"]) +/// .name("optimizer") /// .suggest(&mut trial) /// .unwrap(); /// ``` @@ -389,7 +431,7 @@ pub struct CategoricalParam<T: Clone> { } impl<T: Clone> CategoricalParam<T> { - /// Creates a new categorical parameter with the given choices. + /// Create a new categorical parameter with the given choices. #[must_use] pub fn new(choices: Vec<T>) -> Self { Self { @@ -399,7 +441,7 @@ impl<T: Clone> CategoricalParam<T> { } } - /// Sets a human-readable name for this parameter. + /// Set a human-readable name for this parameter. /// /// When set, this name is used as the parameter's label instead of /// the default `Debug` output. @@ -444,16 +486,19 @@ impl<T: Clone + Debug> Parameter for CategoricalParam<T> { } } -/// A boolean parameter (equivalent to a categorical with `[false, true]`). +/// A boolean parameter (equivalent to a two-choice categorical: `false` / `true`). /// -/// # Example +/// # Examples /// /// ``` /// use optimizer::Trial; /// use optimizer::parameter::{BoolParam, Parameter}; /// /// let mut trial = Trial::new(0); -/// let dropout = BoolParam::new().suggest(&mut trial).unwrap(); +/// let use_dropout = BoolParam::new() +/// .name("use_dropout") +/// .suggest(&mut trial) +/// .unwrap(); /// ``` #[derive(Clone, Debug)] pub struct BoolParam { @@ -462,7 +507,7 @@ pub struct BoolParam { } impl BoolParam { - /// Creates a new boolean parameter. + /// Create a new boolean parameter. #[must_use] pub fn new() -> Self { Self { @@ -471,7 +516,7 @@ impl BoolParam { } } - /// Sets a human-readable name for this parameter. + /// Set a human-readable name for this parameter. /// /// When set, this name is used as the parameter's label instead of /// the default `Debug` output. 
@@ -513,10 +558,10 @@ impl Parameter for BoolParam { } } -/// A trait for enum types that can be used as categorical parameters. +/// Map an enum type to sequential indices for use as a categorical parameter. /// -/// This trait maps enum variants to sequential indices and back. It can be -/// derived automatically for fieldless enums using `#[derive(Categorical)]` +/// This trait converts enum variants to sequential indices and back. It can +/// be derived automatically for fieldless enums using `#[derive(Categorical)]` /// when the `derive` feature is enabled. /// /// # Example @@ -558,20 +603,24 @@ pub trait Categorical: Sized + Clone { /// The number of variants in the enum. const N_CHOICES: usize; - /// Creates an instance from a variant index. + /// Create an instance from a variant index. /// /// # Panics /// /// Panics if `index >= N_CHOICES`. fn from_index(index: usize) -> Self; - /// Returns the index of this variant. + /// Return the index of this variant. fn to_index(&self) -> usize; } /// A parameter that selects from the variants of an enum implementing [`Categorical`]. /// -/// # Example +/// Prefer this over [`CategoricalParam`] when the choices map to a Rust enum, +/// because the returned value is already the correct variant — no string +/// matching required. +/// +/// # Examples /// /// ``` /// use optimizer::Trial; @@ -604,7 +653,10 @@ pub trait Categorical: Sized + Clone { /// } /// /// let mut trial = Trial::new(0); -/// let opt = EnumParam::<Optimizer>::new().suggest(&mut trial).unwrap(); +/// let opt = EnumParam::<Optimizer>::new() +/// .name("optimizer") +/// .suggest(&mut trial) +/// .unwrap(); /// ``` #[derive(Clone, Debug)] pub struct EnumParam<T: Categorical> { @@ -614,7 +666,7 @@ pub struct EnumParam<T: Categorical> { } impl<T: Categorical> EnumParam<T> { - /// Creates a new enum parameter. + /// Create a new enum parameter over all variants of `T`. 
#[must_use] pub fn new() -> Self { Self { @@ -624,7 +676,7 @@ impl<T: Categorical> EnumParam<T> { } } - /// Sets a human-readable name for this parameter. + /// Set a human-readable name for this parameter. /// /// When set, this name is used as the parameter's label instead of /// the default `Debug` output. diff --git a/src/pareto.rs b/src/pareto.rs index 4cfc822..c6a23c9 100644 --- a/src/pareto.rs +++ b/src/pareto.rs @@ -1,15 +1,71 @@ //! Pareto front analysis utilities for multi-objective optimization. //! -//! Provides functions for analyzing and working with Pareto fronts: +//! In multi-objective optimization there is generally no single best +//! solution. Instead, the goal is to find the **Pareto front** — the set +//! of solutions where no objective can be improved without worsening +//! another. This module provides tools for computing and analyzing Pareto +//! fronts. //! -//! - [`hypervolume`] — measure the quality of a Pareto front -//! - [`non_dominated_sort`] — rank solutions into successive fronts -//! - [`pareto_front_indices`] — filter to non-dominated solutions only -//! - [`crowding_distance`] — measure diversity within a front +//! # Available functions //! -//! Internally also provides fast non-dominated sorting (Deb et al., 2002) -//! used by [`MultiObjectiveStudy::pareto_front()`](crate::MultiObjectiveStudy::pareto_front) +//! | Function | Purpose | +//! |---|---| +//! | [`hypervolume`] | Measure the quality of a Pareto front (volume of dominated space) | +//! | [`non_dominated_sort`] | Rank solutions into successive fronts (front 0, 1, …) | +//! | [`pareto_front_indices`] | Filter to non-dominated (Pareto-optimal) solutions only | +//! | [`crowding_distance`] | Measure diversity/spread within a single front | +//! +//! # When to use +//! +//! - **Evaluating front quality**: Use [`hypervolume`] to compare two +//! Pareto fronts — a higher hypervolume indicates a better-quality front. +//! 
- **Ranking all solutions**: Use [`non_dominated_sort`] to partition +//! solutions into successive fronts, useful for selection in evolutionary +//! algorithms. +//! - **Extracting the best solutions**: Use [`pareto_front_indices`] to get +//! only the non-dominated set. +//! - **Diversity measurement**: Use [`crowding_distance`] to quantify how +//! spread out solutions are within a front, which helps maintain diversity. +//! +//! Internally, this module also provides the fast non-dominated sorting +//! algorithm (Deb et al., 2002) used by +//! [`MultiObjectiveStudy::pareto_front()`](crate::MultiObjectiveStudy::pareto_front) //! and [`Nsga2Sampler`](crate::Nsga2Sampler). +//! +//! # Example +//! +//! ``` +//! use optimizer::Direction; +//! use optimizer::pareto::{ +//! crowding_distance, hypervolume, non_dominated_sort, pareto_front_indices, +//! }; +//! +//! let solutions = vec![ +//! vec![1.0, 5.0], // Pareto-optimal +//! vec![5.0, 1.0], // Pareto-optimal +//! vec![3.0, 3.0], // Pareto-optimal +//! vec![4.0, 4.0], // Dominated by (3, 3) +//! ]; +//! let dirs = [Direction::Minimize, Direction::Minimize]; +//! +//! // Non-dominated sorting: front 0 has indices {0, 1, 2} +//! let fronts = non_dominated_sort(&solutions, &dirs); +//! assert_eq!(fronts.len(), 2); +//! +//! // Pareto front indices (shortcut for fronts[0]) +//! let mut front = pareto_front_indices(&solutions, &dirs); +//! front.sort(); +//! assert_eq!(front, vec![0, 1, 2]); +//! +//! // Hypervolume with reference point (6, 6) +//! let front_values: Vec<_> = front.iter().map(|&i| solutions[i].clone()).collect(); +//! let hv = hypervolume(&front_values, &[6.0, 6.0], &dirs); +//! assert!(hv > 0.0); +//! +//! // Crowding distance for diversity analysis +//! let cd = crowding_distance(&front_values, &dirs); +//! assert!(cd[0].is_infinite()); // boundary solution +//! 
``` use crate::types::Direction; @@ -200,12 +256,17 @@ pub(crate) fn crowding_distance_indexed(front_indices: &[usize], values: &[Vec<f /// Compute the hypervolume indicator of a Pareto front. /// /// The hypervolume is the volume of the objective space dominated by -/// the Pareto front and bounded by a reference point. Higher values -/// indicate a better front. +/// the Pareto front and bounded by a reference point. A **higher** +/// hypervolume indicates a better front (closer to the ideal and more +/// spread out). /// /// Each entry in `front` is one solution's objective values. /// `reference_point` should be worse than all front members in every -/// objective (e.g., the worst acceptable values). +/// objective (e.g., the worst acceptable values). Solutions that do +/// not strictly dominate the reference point are ignored. +/// +/// Uses recursive slicing for dimensions > 1. Complexity grows with +/// the number of objectives and front size. /// /// # Panics /// @@ -345,12 +406,14 @@ fn non_dominated_minimize(points: &[Vec<f64>]) -> Vec<Vec<f64>> { /// Compute non-dominated sorting of a set of solutions. /// -/// Returns a vec of fronts, where `fronts[0]` is the Pareto front, -/// `fronts[1]` is the next best, etc. Each inner vec contains indices -/// into the original `solutions` slice. +/// Return a vec of fronts, where `fronts[0]` is the Pareto front +/// (non-dominated solutions), `fronts[1]` is the next-best front +/// (dominated only by front 0), and so on. Each inner vec contains +/// indices into the original `solutions` slice. /// -/// Uses the fast non-dominated sorting algorithm from -/// Deb et al. (2002) with O(M N²) complexity. +/// Use the fast non-dominated sorting algorithm from +/// Deb et al. (2002) with O(M × N²) complexity, where M is the +/// number of objectives and N is the number of solutions. 
#[must_use] pub fn non_dominated_sort(solutions: &[Vec<f64>], directions: &[Direction]) -> Vec<Vec<usize>> { fast_non_dominated_sort(solutions, directions) @@ -359,7 +422,8 @@ pub fn non_dominated_sort(solutions: &[Vec<f64>], directions: &[Direction]) -> V /// Filter solutions to return only non-dominated (Pareto-optimal) indices. /// /// Equivalent to `non_dominated_sort(solutions, directions)[0]` but -/// communicates the intent more clearly. +/// communicates the intent more clearly. Use this when you only need +/// the Pareto front and not the full ranking. #[must_use] pub fn pareto_front_indices(solutions: &[Vec<f64>], directions: &[Direction]) -> Vec<usize> { let fronts = fast_non_dominated_sort(solutions, directions); @@ -368,10 +432,13 @@ pub fn pareto_front_indices(solutions: &[Vec<f64>], directions: &[Direction]) -> /// Compute crowding distance for diversity measurement. /// -/// Returns one distance value per solution in `front` (same order). +/// Return one distance value per solution in `front` (same order). /// Boundary solutions (best/worst in any objective) receive /// [`f64::INFINITY`]. Interior solutions get a finite positive value -/// proportional to the gap between their neighbors. +/// proportional to the gap between their neighbors in each objective. +/// +/// Crowding distance is used by NSGA-II to prefer well-spread +/// solutions when two solutions are in the same front. /// /// `directions` is accepted for API consistency but does not affect /// the result, since crowding distance measures spacing regardless of diff --git a/src/pruner/hyperband.rs b/src/pruner/hyperband.rs index 98c6cca..ac82acf 100644 --- a/src/pruner/hyperband.rs +++ b/src/pruner/hyperband.rs @@ -1,3 +1,47 @@ +//! `HyperBand` pruner — adaptive budget scheduling with multiple SHA brackets. +//! +//! `HyperBand` addresses the main weakness of +//! [`SuccessiveHalvingPruner`](super::SuccessiveHalvingPruner): sensitivity to +//! the `min_resource` setting. 
It runs multiple Successive Halving brackets in +//! parallel, each with a different trade-off between the number of trials and +//! the starting budget: +//! +//! - **Bracket 0**: many trials, small starting budget (aggressive early pruning) +//! - **Bracket `s_max`**: few trials, full budget (no pruning) +//! +//! Trials are assigned to brackets in round-robin order. This ensures that +//! the overall search is robust regardless of how informative early steps are. +//! +//! # When to use +//! +//! - When you don't know how many epochs/steps are needed before performance +//! becomes predictive +//! - As a drop-in upgrade over [`SuccessiveHalvingPruner`](super::SuccessiveHalvingPruner) +//! when you can afford more total trials +//! - For large-scale hyperparameter searches where compute savings matter most +//! +//! # Configuration +//! +//! | Option | Default | Description | +//! |--------|---------|-------------| +//! | `min_resource` | 1 | Smallest budget for the most aggressive bracket | +//! | `max_resource` | 81 | Full budget (last rung in every bracket) | +//! | `reduction_factor` | 3 | At each rung, keep top 1/η trials | +//! | `direction` | `Minimize` | Optimization direction | +//! +//! # Example +//! +//! ``` +//! use optimizer::Direction; +//! use optimizer::pruner::HyperbandPruner; +//! +//! let pruner = HyperbandPruner::new() +//! .min_resource(1) +//! .max_resource(81) +//! .reduction_factor(3) +//! .direction(Direction::Minimize); +//! ``` + use core::sync::atomic::{AtomicU64, Ordering}; use std::collections::HashMap; use std::sync::Mutex; @@ -6,7 +50,7 @@ use super::Pruner; use crate::sampler::CompletedTrial; use crate::types::{Direction, TrialState}; -/// Hyperband pruner that manages multiple Successive Halving brackets. +/// `HyperBand` pruner that manages multiple Successive Halving brackets. 
/// /// Hyperband addresses SHA's sensitivity to the `min_resource` choice by /// running multiple brackets, each with a different tradeoff between the diff --git a/src/pruner/median.rs b/src/pruner/median.rs index 122d483..0e97c8f 100644 --- a/src/pruner/median.rs +++ b/src/pruner/median.rs @@ -1,3 +1,39 @@ +//! Median pruner — the recommended default pruner for most use cases. +//! +//! At each step, the current trial's intermediate value is compared against +//! the median of all completed trials' values at the same step. Trials +//! performing worse than the median are pruned. +//! +//! This is a convenience wrapper around [`PercentilePruner`](super::PercentilePruner) +//! with a fixed percentile of 50%. +//! +//! # When to use +//! +//! - **Default choice** for any iterative objective (e.g., neural network training) +//! - Works well when intermediate values are a reasonable proxy for final performance +//! - Prunes roughly half of unpromising trials, giving a good speed/accuracy balance +//! +//! If your intermediate values are noisy, consider [`WilcoxonPruner`](super::WilcoxonPruner) +//! or wrapping this pruner in a [`PatientPruner`](super::PatientPruner). +//! +//! # Configuration +//! +//! | Option | Default | Description | +//! |--------|---------|-------------| +//! | `n_warmup_steps` | 0 | Skip pruning in the first N steps | +//! | `n_min_trials` | 1 | Require at least N completed trials before pruning | +//! +//! # Example +//! +//! ``` +//! use optimizer::Direction; +//! use optimizer::pruner::MedianPruner; +//! +//! let pruner = MedianPruner::new(Direction::Minimize) +//! .n_warmup_steps(5) +//! .n_min_trials(3); +//! ``` + use super::Pruner; use super::percentile::compute_percentile; use crate::sampler::CompletedTrial; diff --git a/src/pruner/mod.rs b/src/pruner/mod.rs index 614dfa1..03d7f01 100644 --- a/src/pruner/mod.rs +++ b/src/pruner/mod.rs @@ -3,6 +3,44 @@ //! Pruners decide whether to stop (prune) a trial early based on its //! 
intermediate values compared to other trials. This is useful for //! discarding unpromising trials before they complete, saving compute. +//! +//! # How pruning works +//! +//! During optimization, each trial reports intermediate values at discrete +//! steps (e.g., validation loss after each training epoch). A pruner inspects +//! these values and compares them against completed trials to decide whether +//! the current trial should be stopped early. +//! +//! The typical flow is: +//! +//! 1. Call [`Trial::report`](crate::Trial::report) to record an intermediate value. +//! 2. Call [`Trial::should_prune`](crate::Trial::should_prune) to check the pruner's decision. +//! 3. If the pruner says prune, return [`TrialPruned`](crate::TrialPruned) from the objective. +//! +//! # Available pruners +//! +//! | Pruner | Algorithm | Best for | +//! |--------|-----------|----------| +//! | [`MedianPruner`] | Prune below median at each step | General-purpose default | +//! | [`PercentilePruner`] | Prune below configurable percentile | Tunable aggressiveness | +//! | [`ThresholdPruner`] | Prune outside fixed bounds | Known divergence limits | +//! | [`PatientPruner`] | Require N consecutive prune signals | Noisy intermediate values | +//! | [`SuccessiveHalvingPruner`] | Keep top 1/η fraction at each rung | Budget-aware pruning | +//! | [`HyperbandPruner`] | Multiple SHA brackets with different budgets | Robust to budget choice | +//! | [`WilcoxonPruner`] | Statistical signed-rank test vs. best trial | Rigorous noisy pruning | +//! | [`NopPruner`] | Never prune | Disabling pruning explicitly | +//! +//! # When to use pruning +//! +//! Pruning is most beneficial when: +//! +//! - The objective function has a natural notion of "steps" (e.g., training epochs) +//! - Early steps are informative about final performance +//! - Trials are expensive enough that stopping bad ones early saves significant time +//! +//! Start with [`MedianPruner`] for most use cases. 
Switch to [`WilcoxonPruner`] +//! if your intermediate values are noisy, or to [`HyperbandPruner`] if you want +//! automatic budget scheduling. mod hyperband; mod median; diff --git a/src/pruner/nop.rs b/src/pruner/nop.rs index 9ccfad9..ad9f337 100644 --- a/src/pruner/nop.rs +++ b/src/pruner/nop.rs @@ -1,3 +1,15 @@ +//! No-op pruner — never prune any trial. +//! +//! This is the default pruner used when no pruner is configured on a +//! [`Study`](crate::Study). It unconditionally returns `false` for every +//! pruning decision, allowing all trials to run to completion. +//! +//! # When to use +//! +//! - When you want to explicitly disable pruning +//! - As a baseline to compare against other pruners +//! - Already used by default — you rarely need to configure this manually + use super::Pruner; use crate::sampler::CompletedTrial; diff --git a/src/pruner/patient.rs b/src/pruner/patient.rs index 0f56f1c..ab94fd6 100644 --- a/src/pruner/patient.rs +++ b/src/pruner/patient.rs @@ -1,3 +1,34 @@ +//! Patient pruner — require consecutive prune signals before actually pruning. +//! +//! Wraps any other pruner and adds a patience window: the inner pruner +//! must recommend pruning for `patience` consecutive steps before the +//! trial is actually pruned. This prevents premature pruning when +//! intermediate values are noisy and a single bad step doesn't indicate +//! a truly bad trial. +//! +//! # When to use +//! +//! - When your intermediate values have high variance (e.g., mini-batch loss) +//! - When the inner pruner is too aggressive on its own +//! - To add robustness to any statistical pruner without changing its threshold +//! +//! # Configuration +//! +//! | Option | Default | Description | +//! |--------|---------|-------------| +//! | `inner` | *(required)* | The underlying pruner to wrap | +//! | `patience` | *(required)* | Number of consecutive prune signals required | +//! +//! # Example +//! +//! ``` +//! 
use optimizer::pruner::{PatientPruner, ThresholdPruner}; +//! +//! // Only prune after the threshold is exceeded 3 times in a row +//! let inner = ThresholdPruner::new().upper(100.0); +//! let pruner = PatientPruner::new(inner, 3); +//! ``` + use std::collections::HashMap; use std::sync::Mutex; diff --git a/src/pruner/percentile.rs b/src/pruner/percentile.rs index 7a5d6f6..0495b6b 100644 --- a/src/pruner/percentile.rs +++ b/src/pruner/percentile.rs @@ -1,3 +1,37 @@ +//! Percentile pruner — prune trials outside the top N% at each step. +//! +//! A generalization of [`MedianPruner`](super::MedianPruner) that lets you +//! control how aggressively to prune. At each step, the current trial's +//! intermediate value is compared against the given percentile of all +//! completed trials' values at the same step. +//! +//! # When to use +//! +//! - When you want finer control over pruning aggressiveness than median pruning +//! - Lower percentiles (e.g., 25%) are more aggressive — only keep the best quarter +//! - Higher percentiles (e.g., 75%) are more lenient — keep the top three quarters +//! - Percentile 50% is equivalent to [`MedianPruner`](super::MedianPruner) +//! +//! # Configuration +//! +//! | Option | Default | Description | +//! |--------|---------|-------------| +//! | `percentile` | *(required)* | Keep trials in the top N% — range `(0, 100)` | +//! | `n_warmup_steps` | 0 | Skip pruning in the first N steps | +//! | `n_min_trials` | 1 | Require at least N completed trials before pruning | +//! +//! # Example +//! +//! ``` +//! use optimizer::Direction; +//! use optimizer::pruner::PercentilePruner; +//! +//! // Keep only the top 25% of trials (aggressive pruning) +//! let pruner = PercentilePruner::new(25.0, Direction::Minimize) +//! .n_warmup_steps(5) +//! .n_min_trials(3); +//! 
``` + use super::Pruner; use crate::sampler::CompletedTrial; use crate::types::{Direction, TrialState}; diff --git a/src/pruner/successive_halving.rs b/src/pruner/successive_halving.rs index 376ddea..0408101 100644 --- a/src/pruner/successive_halving.rs +++ b/src/pruner/successive_halving.rs @@ -1,3 +1,53 @@ +//! Successive Halving (SHA) pruner — budget-aware pruning at exponential rungs. +//! +//! Trials are evaluated at exponentially-spaced "rungs" (checkpoints). At each +//! rung, only the top 1/η fraction of trials survive to the next rung. This +//! is a principled way to allocate compute budget: give many trials a small +//! budget, then progressively invest more in the best ones. +//! +//! For example, with `min_resource=1`, `max_resource=81`, `reduction_factor=3`: +//! +//! | Rung | Step | Survivors | +//! |------|------|-----------| +//! | 0 | 1 | top 1/3 | +//! | 1 | 3 | top 1/3 | +//! | 2 | 9 | top 1/3 | +//! | 3 | 27 | top 1/3 | +//! | 4 | 81 | all (full budget) | +//! +//! # When to use +//! +//! - When your objective has a natural "budget" dimension (epochs, iterations) +//! - When early performance is a reasonable predictor of final performance +//! - When you want a principled alternative to median pruning +//! +//! If you're unsure about the right `min_resource`, consider +//! [`HyperbandPruner`](super::HyperbandPruner) which runs multiple brackets +//! to hedge against that choice. +//! +//! # Configuration +//! +//! | Option | Default | Description | +//! |--------|---------|-------------| +//! | `min_resource` | 1 | Step at which the first rung is placed | +//! | `max_resource` | 81 | Full budget (final rung, no pruning) | +//! | `reduction_factor` | 3 | At each rung, keep top 1/η trials | +//! | `min_early_stopping_rate` | 0 | Skip the first N rungs | +//! | `direction` | `Minimize` | Optimization direction | +//! +//! # Example +//! +//! ``` +//! use optimizer::Direction; +//! use optimizer::pruner::SuccessiveHalvingPruner; +//! +//! 
let pruner = SuccessiveHalvingPruner::new() +//! .min_resource(1) +//! .max_resource(81) +//! .reduction_factor(3) +//! .direction(Direction::Minimize); +//! ``` + use super::Pruner; use crate::sampler::CompletedTrial; use crate::types::{Direction, TrialState}; diff --git a/src/pruner/threshold.rs b/src/pruner/threshold.rs index 7058a71..837df9f 100644 --- a/src/pruner/threshold.rs +++ b/src/pruner/threshold.rs @@ -1,3 +1,33 @@ +//! Threshold pruner — prune trials whose values fall outside fixed bounds. +//! +//! Unlike statistical pruners that compare against other trials, the +//! threshold pruner uses absolute bounds. Any trial whose latest +//! intermediate value exceeds the upper bound or falls below the lower +//! bound is pruned immediately. +//! +//! # When to use +//! +//! - When you know hard limits for valid intermediate values (e.g., loss should +//! never exceed 100.0) +//! - To catch diverging or NaN-producing trials early +//! - Often combined with other pruners via [`PatientPruner`](super::PatientPruner) +//! +//! # Configuration +//! +//! | Option | Default | Description | +//! |--------|---------|-------------| +//! | `upper` | `None` | Prune if value exceeds this bound | +//! | `lower` | `None` | Prune if value falls below this bound | +//! +//! # Example +//! +//! ``` +//! use optimizer::pruner::ThresholdPruner;; +//! +//! // Prune if the reported value exceeds 100.0 or falls below 0.0 +//! let pruner = ThresholdPruner::new().upper(100.0).lower(0.0); +//! ``` + use super::Pruner; use crate::sampler::CompletedTrial; diff --git a/src/pruner/wilcoxon.rs b/src/pruner/wilcoxon.rs index d148f20..9717e37 100644 --- a/src/pruner/wilcoxon.rs +++ b/src/pruner/wilcoxon.rs @@ -1,3 +1,44 @@ +//! Wilcoxon pruner — statistically rigorous pruning for noisy objectives. +//! +//! Uses the Wilcoxon signed-rank test to compare the current trial's +//! intermediate values against the best completed trial at matching steps. +//! 
The test accounts for the paired, step-aligned nature of the comparison +//! and only prunes when the difference is statistically significant. +//! +//! This is more principled than [`MedianPruner`](super::MedianPruner) for +//! noisy objectives because a single bad step won't trigger pruning — the +//! test considers the full distribution of paired differences. +//! +//! # When to use +//! +//! - When intermediate values have high variance (e.g., mini-batch loss, +//! stochastic reward signals) +//! - When you want a statistical guarantee that pruned trials are truly worse +//! - When you have enough steps (at least 6) for a meaningful test +//! +//! For less noisy objectives, [`MedianPruner`](super::MedianPruner) is simpler +//! and often sufficient. +//! +//! # Configuration +//! +//! | Option | Default | Description | +//! |--------|---------|-------------| +//! | `p_value_threshold` | 0.05 | Significance level — lower is more conservative | +//! | `n_warmup_steps` | 0 | Skip pruning in the first N steps | +//! | `n_min_trials` | 1 | Require at least N completed trials before pruning | +//! +//! # Example +//! +//! ``` +//! use optimizer::Direction; +//! use optimizer::pruner::WilcoxonPruner; +//! +//! let pruner = WilcoxonPruner::new(Direction::Minimize) +//! .p_value_threshold(0.05) +//! .n_warmup_steps(5) +//! .n_min_trials(1); +//! ``` + use core::cmp::Ordering; use super::Pruner; diff --git a/src/sampler/bohb.rs b/src/sampler/bohb.rs index 2444e8f..2e2d0b8 100644 --- a/src/sampler/bohb.rs +++ b/src/sampler/bohb.rs @@ -1,13 +1,13 @@ //! BOHB (Bayesian Optimization + `HyperBand`) sampler. //! -//! BOHB combines TPE's model-guided sampling with Hyperband's budget-aware +//! BOHB combines TPE's model-guided sampling with `HyperBand`'s budget-aware //! evaluation. Instead of building one global TPE model, BOHB conditions //! its TPE model on trials evaluated at a specific budget level, giving -//! 
better-calibrated proposals for each rung of the Hyperband schedule. +//! better-calibrated proposals for each rung of the `HyperBand` schedule. //! //! # How it works //! -//! 1. Compute all Hyperband rung steps (budget levels) from the config. +//! 1. Compute all `HyperBand` rung steps (budget levels) from the config. //! 2. On each `sample()` call, scan the history's `intermediate_values` //! to find the **largest budget level** with enough observations //! (`>= min_points_in_model`). @@ -16,6 +16,26 @@ //! 4. Delegate to an internal [`TpeSampler`] for the actual sampling. //! 5. Fall back to random sampling if no budget level has enough data. //! +//! # When to use +//! +//! - You are tuning hyperparameters for models that support early stopping +//! (e.g., neural networks with configurable epoch counts). +//! - You want to combine model-guided search with aggressive pruning of +//! unpromising configurations. +//! - Your objective has a natural "budget" axis (epochs, iterations, data +//! fraction) reported via [`Trial::report`](crate::Trial::report). +//! +//! Pair `BohbSampler` with [`matching_pruner`](BohbSampler::matching_pruner) +//! to get a `HyperbandPruner` whose budget schedule is consistent with +//! the sampler's conditioning levels. +//! +//! # Configuration +//! +//! - `min_resource` / `max_resource` — budget range (default: 1 … 81) +//! - `reduction_factor` (η) — successive halving factor (default: 3) +//! - `min_points_in_model` — observations needed before TPE replaces random (default: 10) +//! - All [`TpeSamplerBuilder`](super::tpe::TpeSamplerBuilder) options (gamma, seed, etc.) +//! //! # Examples //! //! ``` @@ -27,7 +47,7 @@ //! let study: Study<f64> = Study::with_sampler_and_pruner(Direction::Minimize, bohb, pruner); //! ``` //! -//! Using the builder for custom configuration: +//! Custom configuration via builder: //! //! ``` //! 
use optimizer::sampler::bohb::BohbSampler; @@ -50,7 +70,7 @@ use crate::sampler::tpe::TpeSampler; use crate::sampler::{CompletedTrial, Sampler}; use crate::types::Direction; -/// A BOHB sampler that combines TPE with Hyperband budget awareness. +/// A BOHB sampler that combines TPE with `HyperBand` budget awareness. /// /// BOHB filters trial history by budget level before delegating to TPE, /// so the surrogate model is conditioned on trials evaluated at the same @@ -58,7 +78,25 @@ use crate::types::Direction; /// than using a single global model across all budgets. /// /// Use [`BohbSampler::matching_pruner`] to create a [`HyperbandPruner`] -/// with matching parameters. +/// with matching `HyperBand` parameters. +/// +/// # Examples +/// +/// ``` +/// use optimizer::parameter::{FloatParam, Parameter}; +/// use optimizer::sampler::bohb::BohbSampler; +/// use optimizer::{Direction, Study}; +/// +/// let bohb = BohbSampler::builder() +/// .min_resource(1) +/// .max_resource(27) +/// .reduction_factor(3) +/// .seed(42) +/// .build() +/// .unwrap(); +/// let pruner = bohb.matching_pruner(Direction::Minimize); +/// let study: Study<f64> = Study::with_sampler_and_pruner(Direction::Minimize, bohb, pruner); +/// ``` pub struct BohbSampler { min_resource: u64, max_resource: u64, @@ -225,6 +263,14 @@ impl Sampler for BohbSampler { /// Builder for configuring a [`BohbSampler`]. /// +/// # Defaults +/// +/// - `min_resource`: 1 +/// - `max_resource`: 81 +/// - `reduction_factor`: 3 (η) +/// - `min_points_in_model`: 10 +/// - TPE: default settings (gamma = 0.25, etc.) +/// /// # Examples /// /// ``` diff --git a/src/sampler/cma_es.rs b/src/sampler/cma_es.rs index eadaec8..f243a88 100644 --- a/src/sampler/cma_es.rs +++ b/src/sampler/cma_es.rs @@ -1,15 +1,55 @@ //! CMA-ES (Covariance Matrix Adaptation Evolution Strategy) sampler. //! -//! CMA-ES maintains a multivariate Gaussian distribution over continuous -//! 
parameters and adapts its mean, covariance matrix, and step-size based -//! on trial rankings. It is one of the most effective derivative-free -//! optimizers for continuous search spaces. +//! CMA-ES is a stochastic, population-based optimizer that maintains a +//! multivariate Gaussian distribution over continuous parameters and adapts +//! its **mean**, **covariance matrix**, and **step-size** (σ) based on +//! trial rankings. It is widely regarded as one of the most effective +//! derivative-free optimizers for continuous search spaces. //! -//! Categorical parameters are sampled uniformly at random (not part of -//! the CMA-ES vector). If all parameters are categorical, the sampler -//! falls back to pure random sampling. +//! # Algorithm overview //! -//! Requires the `cma-es` feature flag. +//! Each generation: +//! 1. **Sample** λ (population size) candidates from N(m, σ²C). +//! 2. **Evaluate** and **rank** the candidates by objective value. +//! 3. **Update** the mean toward the best μ candidates (weighted recombination). +//! 4. **Adapt** the covariance matrix C via rank-one and rank-μ updates, and +//! adapt σ via cumulative step-size adaptation (CSA). +//! +//! Over time the search distribution narrows and rotates to align with the +//! landscape, efficiently exploiting structure in the objective function. +//! +//! # When to use +//! +//! - **Continuous parameters only** (float/int). Categorical parameters are +//! sampled uniformly at random and do not participate in the CMA-ES model. +//! - **Moderate dimensionality** — works well up to ~100 continuous dimensions. +//! Beyond that, the O(n²) covariance matrix becomes expensive to maintain. +//! - **Non-separable objectives** — CMA-ES learns parameter correlations +//! through the covariance matrix, making it especially effective on +//! rotated or ill-conditioned landscapes. +//! - **Moderate evaluation budgets** — typically needs ≈10×n to 100×n +//! 
evaluations to converge, where n is the number of continuous dimensions. +//! +//! For very cheap evaluations in low dimensions (d ≤ 20), consider +//! [`GpSampler`](super::gp::GpSampler) instead. For high-dimensional +//! separable problems, TPE may be more efficient. +//! +//! # Configuration +//! +//! | Option | Default | Description | +//! |--------|---------|-------------| +//! | `sigma0` | `avg_range / 4` | Initial step size — controls exploration breadth | +//! | `population_size` | `4 + ⌊3 ln n⌋` | Candidates per generation (λ) | +//! | `seed` | random | RNG seed for reproducibility | +//! +//! # Feature flag +//! +//! Requires the **`cma-es`** feature (adds the `nalgebra` dependency): +//! +//! ```toml +//! [dependencies] +//! optimizer = { version = "...", features = ["cma-es"] } +//! ``` //! //! # Examples //! @@ -17,8 +57,14 @@ //! use optimizer::sampler::cma_es::CmaEsSampler; //! use optimizer::{Direction, Study}; //! -//! let sampler = CmaEsSampler::with_seed(42); -//! let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); +//! // Minimize a 2-D sphere function with CMA-ES +//! let sampler = CmaEsSampler::builder() +//! .sigma0(1.0) +//! .population_size(10) +//! .seed(42) +//! .build(); +//! +//! let mut study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); //! ``` use std::collections::HashMap; diff --git a/src/sampler/differential_evolution.rs b/src/sampler/differential_evolution.rs index 0be8ffc..3a11e55 100644 --- a/src/sampler/differential_evolution.rs +++ b/src/sampler/differential_evolution.rs @@ -1,22 +1,66 @@ //! Differential Evolution (DE) sampler. //! -//! DE is a population-based metaheuristic that maintains a population of -//! candidate solutions and creates new candidates by combining (mutating + -//! crossing over) existing ones. It is competitive with CMA-ES on many -//! problems and simpler to implement. +//! DE is a population-based metaheuristic that maintains a pool of candidate +//! 
solutions and creates new candidates through **mutation** (combining +//! difference vectors of existing members) and **binomial crossover**. A +//! trial vector replaces its parent only if it is at least as good, +//! guaranteeing that the population never gets worse. //! -//! Categorical parameters are sampled uniformly at random (not part of the -//! DE vector). If all parameters are categorical, the sampler falls back to -//! pure random sampling. +//! # Algorithm overview +//! +//! Each generation, for every population member *xᵢ*: +//! 1. **Mutation** — create a mutant vector *v* from other population +//! members using the selected [`DifferentialEvolutionStrategy`]: +//! - `Rand1`: `v = x_r1 + F * (x_r2 - x_r3)` +//! - `Best1`: `v = x_best + F * (x_r1 - x_r2)` +//! - `CurrentToBest1`: `v = x_i + F * (x_best - x_i) + F * (x_r1 - x_r2)` +//! 2. **Crossover** — create a trial vector *u* by mixing *v* and *xᵢ* +//! dimension-by-dimension with probability CR. +//! 3. **Selection** — replace *xᵢ* with *u* if `f(u) ≤ f(xᵢ)`. +//! +//! # When to use +//! +//! - **Continuous parameters** (float/int). Categorical parameters are +//! sampled uniformly at random and do not participate in DE. +//! - **Moderate to large search spaces** — DE scales better than GP-based +//! methods to higher dimensions, though it may need more evaluations. +//! - **Multi-modal landscapes** — the `Rand1` strategy maintains diversity +//! and avoids premature convergence. +//! - **No feature flags required** — DE is available with default features. +//! +//! For non-separable problems in moderate dimensions, consider +//! [`CmaEsSampler`](super::cma_es::CmaEsSampler) which learns parameter +//! correlations. For expensive functions with few dimensions, consider +//! [`GpSampler`](super::gp::GpSampler). +//! +//! # Configuration +//! +//! | Option | Default | Description | +//! |--------|---------|-------------| +//! 
| `population_size` | `max(10n, 15)` | Candidates per generation | +//! | `mutation_factor` (F) | 0.8 | Differential amplification — higher = more exploration | +//! | `crossover_rate` (CR) | 0.9 | Probability of taking a dimension from the mutant | +//! | `strategy` | `Rand1` | Mutation strategy (see [`DifferentialEvolutionStrategy`]) | +//! | `seed` | random | RNG seed for reproducibility | //! //! # Examples //! //! ``` -//! use optimizer::sampler::differential_evolution::DifferentialEvolutionSampler; +//! use optimizer::sampler::differential_evolution::{ +//! DifferentialEvolutionSampler, DifferentialEvolutionStrategy, +//! }; //! use optimizer::{Direction, Study}; //! -//! let sampler = DifferentialEvolutionSampler::with_seed(42); -//! let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); +//! // Minimize with DE using the Best1 strategy for faster convergence +//! let sampler = DifferentialEvolutionSampler::builder() +//! .mutation_factor(0.7) +//! .crossover_rate(0.9) +//! .strategy(DifferentialEvolutionStrategy::Best1) +//! .population_size(20) +//! .seed(42) +//! .build(); +//! +//! let mut study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); //! ``` use std::collections::HashMap; diff --git a/src/sampler/gp.rs b/src/sampler/gp.rs index 2ae49b1..fe5d8f6 100644 --- a/src/sampler/gp.rs +++ b/src/sampler/gp.rs @@ -1,15 +1,60 @@ //! Gaussian Process (GP) sampler with Expected Improvement acquisition. //! -//! A classical Bayesian optimization sampler that uses a Gaussian Process -//! surrogate model with a Matérn 5/2 kernel and Expected Improvement (EI) +//! A classical Bayesian optimization sampler that builds a Gaussian Process +//! surrogate model with a **Matérn 5/2 kernel** (with ARD lengthscales) and +//! selects the next trial by maximizing the **Expected Improvement (EI)** //! acquisition function. Best suited for small, expensive evaluations in //! low-dimensional continuous spaces (d ≤ 20). //! -//! 
Categorical parameters are sampled uniformly at random (not part of -//! the GP model). If all parameters are categorical, the sampler falls -//! back to pure random sampling. +//! # Algorithm overview //! -//! Requires the `gp` feature flag. +//! 1. **Startup phase** — the first `n_startup_trials` trials are sampled +//! uniformly at random to build an initial dataset. +//! 2. **Fit GP** — training observations are standardized (zero mean, unit +//! variance) and a GP with Matérn 5/2 kernel is fitted via Cholesky +//! decomposition. ARD lengthscales are set to the per-dimension standard +//! deviation of the training inputs. +//! 3. **Maximize EI** — `n_candidates` random points are evaluated under +//! the GP posterior and the point with the highest Expected Improvement +//! is returned as the next trial. +//! +//! The GP uses at most 100 training points (the most recent ones) to keep +//! the O(n³) fitting cost manageable. +//! +//! # When to use +//! +//! - **Expensive objective functions** where every evaluation is costly +//! (e.g. physical experiments, large simulations). The GP surrogate +//! amortizes this cost by making fewer evaluations. +//! - **Low-dimensional continuous spaces** — typically d ≤ 20. Beyond that, +//! the GP becomes unreliable and alternatives like +//! [`CmaEsSampler`](super::cma_es::CmaEsSampler) or +//! [`TpeSampler`](super::tpe::TpeSampler) are preferable. +//! - **Smooth, low-noise objectives** — the GP assumes smoothness through +//! the Matérn 5/2 kernel. Very noisy objectives require increasing +//! `noise_variance`. +//! +//! Categorical parameters are sampled uniformly at random and do not +//! participate in the GP model. If all parameters are categorical, the +//! sampler falls back to pure random sampling. +//! +//! # Configuration +//! +//! | Option | Default | Description | +//! |--------|---------|-------------| +//! | `n_startup_trials` | 10 | Random trials before GP-guided sampling begins | +//! 
| `n_candidates` | 1000 | Random candidates for EI maximization | +//! | `noise_variance` | 1e-6 | Observation noise added to kernel diagonal | +//! | `seed` | random | RNG seed for reproducibility | +//! +//! # Feature flag +//! +//! Requires the **`gp`** feature (adds the `nalgebra` dependency): +//! +//! ```toml +//! [dependencies] +//! optimizer = { version = "...", features = ["gp"] } +//! ``` //! //! # Examples //! @@ -17,8 +62,14 @@ //! use optimizer::sampler::gp::GpSampler; //! use optimizer::{Direction, Study}; //! -//! let sampler = GpSampler::with_seed(42); -//! let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); +//! // Minimize an expensive function with GP-based Bayesian optimization +//! let sampler = GpSampler::builder() +//! .n_startup_trials(5) +//! .n_candidates(500) +//! .seed(42) +//! .build(); +//! +//! let mut study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); //! ``` use std::collections::HashMap; diff --git a/src/sampler/grid.rs b/src/sampler/grid.rs index dd6768f..c3ce2b7 100644 --- a/src/sampler/grid.rs +++ b/src/sampler/grid.rs @@ -1,7 +1,39 @@ -//! Grid search sampler implementation. +//! Grid search sampler — exhaustive evaluation of discretized parameter spaces. //! -//! `GridSearchSampler` performs exhaustive grid search over the parameter space, -//! systematically evaluating all combinations of discretized parameter values. +//! [`GridSearchSampler`] divides each parameter range into a fixed number of +//! evenly spaced points (or uses the explicit step size when defined) and +//! evaluates them sequentially. This guarantees complete coverage of the +//! search grid at the cost of scaling exponentially with the number of +//! parameters. +//! +//! # When to use +//! +//! - **Small, discrete spaces** — when you have a handful of categorical or +//! integer parameters and want to evaluate every combination. +//! - **Reproducibility** — grid search is fully deterministic with no random +//! 
component. +//! - **Benchmarking** — compare grid search results against adaptive samplers +//! to measure their benefit. +//! +//! Avoid grid search for high-dimensional or large continuous spaces; +//! prefer [`TpeSampler`](super::tpe::TpeSampler) or +//! [`RandomSampler`](super::random::RandomSampler) instead. +//! +//! # Configuration +//! +//! | Option | Default | Description | +//! |---|---|---| +//! | `n_points_per_param` | 10 | Points per continuous parameter (ignored when `step` is set) | +//! +//! # Example +//! +//! ``` +//! use optimizer::prelude::*; +//! use optimizer::sampler::grid::GridSearchSampler; +//! +//! let sampler = GridSearchSampler::builder().n_points_per_param(5).build(); +//! let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); +//! ``` use std::collections::HashMap; @@ -295,38 +327,38 @@ struct GridState { grids: HashMap<String, CachedGrid>, } -/// A grid search sampler that exhaustively evaluates all grid points. +/// Exhaustive grid search sampler. /// -/// `GridSearchSampler` divides the parameter space into a grid and systematically -/// samples each point. This is useful when you want to evaluate all combinations -/// of parameter values, especially for discrete or small parameter spaces. +/// Divide each parameter range into evenly spaced points and evaluate them +/// sequentially. When a parameter has an explicit `step` size, the step grid +/// is used instead of auto-discretization. /// -/// # Grid Exhaustion +/// Grid state is tracked **per distribution key** (bounds + step + log-scale). +/// Parameters with identical distributions share the same grid counter, so +/// use distinct ranges when multiple parameters span the same domain. /// -/// The sampler tracks its position in the grid for each distribution independently. 
-/// When all grid points for a distribution have been sampled, subsequent calls to -/// `sample()` for that distribution will **panic** with the message: -/// `"GridSearchSampler: all grid points exhausted"`. +/// # Grid exhaustion /// -/// To avoid panics, use [`is_exhausted()`](Self::is_exhausted) to check if all -/// points have been sampled before calling `sample()`. You can also use -/// [`grid_size()`](Self::grid_size) to determine the total number of grid points -/// that will be sampled. +/// When all grid points for a distribution have been sampled, the next +/// `sample()` call for that distribution **panics**. Use +/// [`is_exhausted()`](Self::is_exhausted) to check before sampling, or +/// set `n_points_per_param` high enough to cover the planned number of +/// trials. /// -/// # Thread Safety +/// # Thread safety /// -/// `GridSearchSampler` is thread-safe (`Send + Sync`) and uses internal locking -/// to ensure safe concurrent access to grid state. +/// `GridSearchSampler` is `Send + Sync` and uses internal locking for +/// safe concurrent access. /// /// # Examples /// /// ``` /// use optimizer::sampler::grid::GridSearchSampler; /// -/// // Create with default settings (10 points per parameter) +/// // Default: 10 points per parameter /// let sampler = GridSearchSampler::new(); /// -/// // Create with custom settings using the builder +/// // Custom grid density /// let sampler = GridSearchSampler::builder().n_points_per_param(20).build(); /// ``` pub struct GridSearchSampler { diff --git a/src/sampler/moead.rs b/src/sampler/moead.rs index d6d3997..6f11d44 100644 --- a/src/sampler/moead.rs +++ b/src/sampler/moead.rs @@ -1,8 +1,58 @@ //! MOEA/D (Multi-Objective Evolutionary Algorithm based on Decomposition) sampler. //! -//! Decomposes a multi-objective problem into scalar subproblems using -//! weight vectors and solves them collaboratively. Supports Weighted Sum, -//! Tchebycheff, and Penalty-based Boundary Intersection (PBI) scalarization. +//! 
MOEA/D takes a fundamentally different approach from Pareto-based +//! algorithms like NSGA-II/III. It **decomposes** the multi-objective +//! problem into a set of scalar subproblems using evenly distributed +//! weight vectors (Das-Dennis points), then solves them collaboratively +//! through **neighborhood-based mating and replacement**. +//! +//! # Algorithm +//! +//! 1. **Decompose** — generate weight vectors on the unit simplex and +//! assign one scalar subproblem per weight vector. +//! 2. **Build neighborhoods** — for each subproblem, find its T nearest +//! neighbors by Euclidean distance between weight vectors. +//! 3. **Mate from neighborhood** — select parents from the neighborhood +//! of each subproblem and produce offspring via SBX crossover + +//! polynomial mutation. +//! 4. **Scalarize and update** — evaluate offspring using a scalarization +//! function and update neighboring subproblems if the offspring improves +//! their scalar value. +//! 5. **Update ideal point** — track the best value seen per objective. +//! +//! # Scalarization methods +//! +//! | Method | Formula | Best for | +//! |--------|---------|----------| +//! | [`Tchebycheff`](Decomposition::Tchebycheff) (default) | `max(wᵢ * \|fᵢ - zᵢ*\|)` | General purpose, handles non-convex fronts | +//! | [`WeightedSum`](Decomposition::WeightedSum) | `Σ(wᵢ * fᵢ)` | Convex Pareto fronts only | +//! | [`Pbi`](Decomposition::Pbi) | `d₁ + θ * d₂` | Fine-grained convergence/diversity control | +//! +//! # When to use +//! +//! - Problems where you want **evenly distributed** solutions along the +//! Pareto front (one solution per weight direction). +//! - Many-objective optimization (3+ objectives) — scales well because +//! each subproblem is a simple scalar optimization. +//! - Problems with **non-convex** Pareto fronts (use Tchebycheff or PBI). +//! - When you need explicit control over the trade-off distribution via +//! weight vectors. +//! +//! For Pareto-based approaches, see +//! 
[`Nsga2Sampler`](super::nsga2::Nsga2Sampler) (crowding distance) or +//! [`Nsga3Sampler`](super::nsga3::Nsga3Sampler) (reference-point niching). +//! +//! # Configuration +//! +//! | Parameter | Builder method | Default | +//! |-----------|---------------|---------| +//! | Population size | [`population_size`](MoeadSamplerBuilder::population_size) | Number of Das-Dennis weight vectors | +//! | Neighborhood size (T) | [`neighborhood_size`](MoeadSamplerBuilder::neighborhood_size) | `min(20, pop_size)` | +//! | Decomposition method | [`decomposition`](MoeadSamplerBuilder::decomposition) | Tchebycheff | +//! | Crossover probability | [`crossover_prob`](MoeadSamplerBuilder::crossover_prob) | 1.0 | +//! | SBX distribution index | [`crossover_eta`](MoeadSamplerBuilder::crossover_eta) | 20.0 | +//! | Mutation distribution index | [`mutation_eta`](MoeadSamplerBuilder::mutation_eta) | 20.0 | +//! | Random seed | [`seed`](MoeadSamplerBuilder::seed) | random | //! //! # Examples //! @@ -37,15 +87,29 @@ use crate::multi_objective::MultiObjectiveTrial; use crate::param::ParamValue; use crate::types::Direction; -/// Decomposition (scalarization) method for MOEA/D. +/// Decomposition (scalarization) method for [`MoeadSampler`]. +/// +/// Control how multi-objective values are reduced to a single scalar +/// for each subproblem. The default is [`Tchebycheff`](Self::Tchebycheff), +/// which handles both convex and non-convex Pareto fronts. #[derive(Debug, Clone, Default)] pub enum Decomposition { - /// Weighted sum: `sum(w_i * f_i)`. + /// Weighted sum: `Σ(wᵢ * fᵢ)`. + /// + /// Simplest method but can only find solutions on convex regions + /// of the Pareto front. WeightedSum, - /// Tchebycheff: `max(w_i * |f_i - z_i*|)`. + /// Tchebycheff: `max(wᵢ * |fᵢ - zᵢ*|)`. + /// + /// Handles non-convex Pareto fronts. The most commonly used + /// decomposition method (default). #[default] Tchebycheff, - /// Penalty-based Boundary Intersection with parameter theta. 
+ /// Penalty-based Boundary Intersection: `d₁ + θ * d₂`. + /// + /// Provides fine-grained control over the convergence/diversity + /// balance via the penalty parameter `theta`. Higher `theta` + /// favors solutions closer to the weight direction. Pbi { /// Penalty parameter controlling the balance between convergence /// and diversity. Default: 5.0. @@ -55,9 +119,14 @@ pub enum Decomposition { /// MOEA/D sampler for multi-objective optimization. /// -/// Decomposes the multi-objective problem into scalar subproblems -/// using weight vectors, solving them collaboratively via -/// neighborhood-based mating and replacement. +/// Decompose a multi-objective problem into scalar subproblems using +/// weight vectors and solve them collaboratively via neighborhood-based +/// mating. Supports [`Tchebycheff`](Decomposition::Tchebycheff), +/// [`WeightedSum`](Decomposition::WeightedSum), and +/// [`Pbi`](Decomposition::Pbi) scalarization. +/// +/// Create with [`MoeadSampler::new`], [`MoeadSampler::with_seed`], or +/// [`MoeadSampler::builder`] for full configuration. pub struct MoeadSampler { state: Mutex<MoeadState>, } diff --git a/src/sampler/motpe.rs b/src/sampler/motpe.rs index 4525218..823b791 100644 --- a/src/sampler/motpe.rs +++ b/src/sampler/motpe.rs @@ -1,19 +1,38 @@ //! Multi-Objective Tree-Parzen Estimator (MOTPE) sampler. //! -//! Extends TPE to handle multi-objective optimization by using Pareto -//! non-dominated sorting to define "good" vs "bad" trial regions for -//! the KDE models, replacing the single-objective gamma-based split. +//! MOTPE extends TPE to multi-objective optimization by replacing the gamma-based +//! split with Pareto non-dominated sorting. This lets the sampler propose +//! parameters that push the Pareto front forward across all objectives +//! simultaneously. //! //! # Algorithm //! -//! In single-objective TPE, trials are sorted by value and split at a -//! gamma percentile into good/bad groups. MOTPE replaces this with: +//! 
In single-objective TPE, trials are sorted by value and split at a gamma +//! percentile into good/bad groups. MOTPE replaces this with: //! -//! 1. Compute non-dominated sorting on all completed trials -//! 2. Use the Pareto front (rank 0) as "good" trials -//! 3. Use dominated trials as "bad" trials -//! 4. Build KDE l(x) from good, g(x) from bad -//! 5. Sample candidates and score by l(x)/g(x) +//! 1. Compute non-dominated sorting on all completed trials. +//! 2. Use the Pareto front (rank 0) as "good" trials. +//! 3. Use dominated trials (rank 1+) as "bad" trials. +//! 4. Build KDE l(x) from good, g(x) from bad. +//! 5. Sample candidates and score by l(x)/g(x). +//! +//! # When to use +//! +//! - You have 2+ objectives and want model-guided search (not pure evolutionary). +//! - Your objectives are relatively smooth and continuous. +//! - You want a Pareto-aware version of TPE without the overhead of full +//! population-based algorithms like NSGA-II or NSGA-III. +//! +//! For single-objective problems, use [`TpeSampler`](super::tpe::TpeSampler) instead. +//! For many-objective (3+) problems with reference-point decomposition, consider +//! NSGA-III or MOEA/D. +//! +//! # Configuration +//! +//! - `n_startup_trials` — number of random trials before MOTPE kicks in (default: 11) +//! - `n_ei_candidates` — candidates evaluated per sample (default: 24) +//! - `kde_bandwidth` — optional fixed KDE bandwidth; `None` uses Scott's rule +//! - `seed` — optional seed for reproducibility //! //! # Examples //! @@ -50,14 +69,21 @@ use crate::{pareto, rng_util}; /// Multi-Objective TPE (MOTPE) sampler for multi-objective Bayesian optimization. /// -/// Uses Pareto non-dominated sorting to split completed trials into -/// "good" (non-dominated, rank 0) and "bad" (dominated) groups, then -/// fits kernel density estimators to each group and samples new points -/// that maximize l(x)/g(x). 
+/// Use Pareto non-dominated sorting to split completed trials into "good" +/// (non-dominated, rank 0) and "bad" (dominated) groups, then fit kernel +/// density estimators to each group and sample new points that maximize +/// l(x)/g(x). /// /// During the startup phase (fewer than `n_startup_trials` completed), /// MOTPE falls back to random sampling. /// +/// # When to use +/// +/// Use `MotpeSampler` when optimizing 2+ objectives and you want a +/// model-guided sampler that adapts proposals based on the current +/// Pareto front. For single-objective problems, use +/// [`TpeSampler`](super::tpe::TpeSampler) instead. +/// /// # Examples /// /// ``` @@ -74,6 +100,17 @@ use crate::{pareto, rng_util}; /// /// let study = /// MultiObjectiveStudy::with_sampler(vec![Direction::Minimize, Direction::Minimize], sampler); +/// +/// let x = FloatParam::new(0.0, 1.0); +/// study +/// .optimize(30, |trial| { +/// let xv = x.suggest(trial)?; +/// Ok::<_, optimizer::Error>(vec![xv, 1.0 - xv]) +/// }) +/// .unwrap(); +/// +/// let front = study.pareto_front(); +/// assert!(!front.is_empty()); /// ``` pub struct MotpeSampler { /// Number of trials before MOTPE kicks in (uses random sampling before this). @@ -520,6 +557,13 @@ impl MultiObjectiveSampler for MotpeSampler { /// Builder for configuring a [`MotpeSampler`]. /// +/// # Defaults +/// +/// - `n_startup_trials`: 11 +/// - `n_ei_candidates`: 24 +/// - `kde_bandwidth`: None (Scott's rule) +/// - `seed`: None (OS entropy) +/// /// # Examples /// /// ``` diff --git a/src/sampler/nsga2.rs b/src/sampler/nsga2.rs index acbbab5..0a8f873 100644 --- a/src/sampler/nsga2.rs +++ b/src/sampler/nsga2.rs @@ -1,7 +1,45 @@ //! NSGA-II (Non-dominated Sorting Genetic Algorithm II) sampler. //! -//! Implements multi-objective optimization using non-dominated sorting, -//! crowding distance, SBX crossover, and polynomial mutation. +//! NSGA-II is one of the most widely used evolutionary multi-objective +//! optimization algorithms. 
It ranks the population using **non-dominated +//! sorting** (fast O(MN²) algorithm) and breaks ties within the same +//! Pareto front using **crowding distance**, which favors solutions in +//! less-crowded regions of the objective space. +//! +//! # Algorithm +//! +//! Each generation proceeds as follows: +//! +//! 1. **Non-dominated sorting** — partition the combined parent+offspring +//! population into Pareto fronts F₁, F₂, … +//! 2. **Crowding distance** — for each front, compute per-solution crowding +//! distance (sum of normalized neighbor gaps in each objective). +//! 3. **Selection** — fill the next population front-by-front. When a front +//! only partially fits, prefer solutions with higher crowding distance. +//! 4. **Binary tournament** — select parents using (rank, crowding distance) +//! comparisons. +//! 5. **SBX crossover + polynomial mutation** — generate offspring. +//! +//! # When to use +//! +//! - Two-objective problems where you want a well-spread Pareto front. +//! - General-purpose multi-objective optimization with moderate population +//! sizes. +//! - Problems that benefit from diversity preservation via crowding distance. +//! +//! For problems with **three or more objectives**, consider +//! [`Nsga3Sampler`](super::nsga3::Nsga3Sampler) (reference-point niching) +//! or [`MoeadSampler`](super::moead::MoeadSampler) (decomposition). +//! +//! # Configuration +//! +//! | Parameter | Builder method | Default | +//! |-----------|---------------|---------| +//! | Population size | [`population_size`](Nsga2SamplerBuilder::population_size) | `4 + floor(3 * ln(n_params))`, min 4 | +//! | Crossover probability | [`crossover_prob`](Nsga2SamplerBuilder::crossover_prob) | 0.9 | +//! | SBX distribution index | [`crossover_eta`](Nsga2SamplerBuilder::crossover_eta) | 20.0 | +//! | Mutation distribution index | [`mutation_eta`](Nsga2SamplerBuilder::mutation_eta) | 20.0 | +//! | Random seed | [`seed`](Nsga2SamplerBuilder::seed) | random | //! //! 
# Examples //! @@ -39,8 +77,12 @@ use crate::types::Direction; /// NSGA-II sampler for multi-objective optimization. /// -/// Provides non-dominated sorting, crowding distance selection, -/// SBX crossover, and polynomial mutation. +/// Use non-dominated sorting with crowding-distance tie-breaking to +/// evolve a well-spread Pareto front. Best suited for bi-objective +/// problems; for 3+ objectives prefer [`Nsga3Sampler`](super::nsga3::Nsga3Sampler). +/// +/// Create with [`Nsga2Sampler::new`], [`Nsga2Sampler::with_seed`], or +/// [`Nsga2Sampler::builder`] for full configuration. pub struct Nsga2Sampler { state: Mutex<Nsga2State>, } diff --git a/src/sampler/nsga3.rs b/src/sampler/nsga3.rs index 09f414c..44028ef 100644 --- a/src/sampler/nsga3.rs +++ b/src/sampler/nsga3.rs @@ -1,9 +1,49 @@ //! NSGA-III (Non-dominated Sorting Genetic Algorithm III) sampler. //! -//! Uses reference-point-based niching for better diversity in -//! many-objective (3+) optimization problems. Das-Dennis structured -//! reference points guide the search toward a well-distributed -//! Pareto front. +//! NSGA-III extends NSGA-II to handle **many-objective** (3+) problems +//! where crowding distance loses effectiveness. Instead of crowding +//! distance, it uses **reference-point-based niching** with structured +//! Das-Dennis reference points distributed on the unit simplex to guide +//! the population toward a well-diversified Pareto front. +//! +//! # Algorithm +//! +//! Each generation proceeds as follows: +//! +//! 1. **Non-dominated sorting** — same as NSGA-II, partition the +//! combined population into Pareto fronts F₁, F₂, … +//! 2. **Normalize objectives** — translate by ideal point and scale by +//! intercepts so all objectives lie in roughly \[0, 1\]. +//! 3. **Associate with reference points** — assign each solution to the +//! closest Das-Dennis reference direction by perpendicular distance. +//! 4. **Niching selection** — when the last front only partially fits, +//! 
prefer solutions associated with under-represented reference points +//! (lowest niche count first, closest distance second). +//! 5. **SBX crossover + polynomial mutation** — generate offspring via +//! rank-based tournament selection. +//! +//! # When to use +//! +//! - **Three or more objectives** — NSGA-III maintains diversity far +//! better than NSGA-II as the number of objectives grows. +//! - Problems where you want a **uniformly distributed** Pareto front +//! guided by structured reference points. +//! - Scales well up to ~10 objectives with appropriate division settings. +//! +//! For bi-objective problems, [`Nsga2Sampler`](super::nsga2::Nsga2Sampler) +//! is simpler and equally effective. For decomposition-based optimization, +//! see [`MoeadSampler`](super::moead::MoeadSampler). +//! +//! # Configuration +//! +//! | Parameter | Builder method | Default | +//! |-----------|---------------|---------| +//! | Population size | [`population_size`](Nsga3SamplerBuilder::population_size) | Number of Das-Dennis reference points | +//! | Das-Dennis divisions (H) | [`n_divisions`](Nsga3SamplerBuilder::n_divisions) | Auto-chosen from population size and objectives | +//! | Crossover probability | [`crossover_prob`](Nsga3SamplerBuilder::crossover_prob) | 1.0 | +//! | SBX distribution index | [`crossover_eta`](Nsga3SamplerBuilder::crossover_eta) | 30.0 | +//! | Mutation distribution index | [`mutation_eta`](Nsga3SamplerBuilder::mutation_eta) | 20.0 | +//! | Random seed | [`seed`](Nsga3SamplerBuilder::seed) | random | //! //! # Examples //! @@ -49,8 +89,12 @@ use crate::types::Direction; /// NSGA-III sampler for multi-objective optimization. /// -/// Uses reference-point-based niching to maintain diversity, -/// especially effective for problems with 3 or more objectives. +/// Use reference-point niching with Das-Dennis structured points to +/// maintain diversity in many-objective (3+) problems. 
For bi-objective +/// problems, [`Nsga2Sampler`](super::nsga2::Nsga2Sampler) is simpler. +/// +/// Create with [`Nsga3Sampler::new`], [`Nsga3Sampler::with_seed`], or +/// [`Nsga3Sampler::builder`] for full configuration. pub struct Nsga3Sampler { state: Mutex<Nsga3State>, } diff --git a/src/sampler/random.rs b/src/sampler/random.rs index e033a4e..292fe4a 100644 --- a/src/sampler/random.rs +++ b/src/sampler/random.rs @@ -1,4 +1,31 @@ -//! Random sampler implementation. +//! Random sampler — uniform independent sampling. +//! +//! [`RandomSampler`] draws each parameter value independently and uniformly +//! at random, ignoring trial history entirely. It respects log-scale and +//! step-size constraints defined by the parameter distribution. +//! +//! # When to use +//! +//! - **Baseline comparison** — run Random alongside smarter samplers to +//! quantify their benefit. +//! - **Startup phase** — many model-based samplers (TPE, GP, CMA-ES) use +//! random sampling for their first *n* trials before fitting a surrogate. +//! - **Very high dimensions** — when the search space is too large for +//! structured exploration, random search with enough budget can be +//! surprisingly competitive. +//! +//! For better uniform coverage without model fitting, consider +//! [`SobolSampler`](super::sobol::SobolSampler) (requires the `sobol` +//! feature flag). +//! +//! # Example +//! +//! ``` +//! use optimizer::prelude::*; +//! use optimizer::sampler::random::RandomSampler; +//! +//! let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(42)); +//! ``` use parking_lot::Mutex; @@ -7,11 +34,15 @@ use crate::param::ParamValue; use crate::rng_util; use crate::sampler::{CompletedTrial, Sampler}; -/// A simple random sampler that samples uniformly from distributions. +/// Uniform independent random sampler. /// -/// This sampler ignores the trial history and samples uniformly at random, -/// respecting log scale and step size constraints. 
It serves as a baseline -/// sampler and is used during the startup phase of more sophisticated samplers. +/// Sample each parameter value uniformly at random, respecting log-scale and +/// step-size constraints. Trial history is ignored — every sample is drawn +/// independently. +/// +/// This is the default sampler used by [`Study::new`](crate::Study::new) +/// and during the startup phase of model-based samplers such as +/// [`TpeSampler`](super::tpe::TpeSampler). /// /// # Examples /// diff --git a/src/sampler/sobol.rs b/src/sampler/sobol.rs index 3c4bb18..2dad0b8 100644 --- a/src/sampler/sobol.rs +++ b/src/sampler/sobol.rs @@ -1,4 +1,48 @@ //! Quasi-random sampler using Sobol low-discrepancy sequences. +//! +//! [`SobolSampler`] generates points from a Sobol sequence (scrambled via the +//! Burley 2020 algorithm) to fill the parameter space more uniformly than +//! pure random sampling. Where [`RandomSampler`](super::random::RandomSampler) +//! may cluster points in some regions by chance, Sobol sequences are +//! constructed to spread points evenly across all dimensions. +//! +//! # When to use +//! +//! - **Better-than-random baseline** — when you want uniform coverage +//! without the cost of model fitting (TPE, GP, etc.). +//! - **Startup phase replacement** — use Sobol instead of random for the +//! initial exploration phase of adaptive samplers. +//! - **Moderate dimensionality** — Sobol uniformity is strongest up to +//! ~20 dimensions; beyond that the advantage over random sampling +//! diminishes. +//! - **Deterministic exploration** — Sobol sequences are fully deterministic +//! for a given seed, making experiments reproducible. +//! +//! # How it works +//! +//! Each trial maps to a Sobol sequence index, and each parameter within a +//! trial maps to a separate Sobol dimension. The resulting quasi-random +//! point in \[0, 1) is then scaled to the parameter's distribution (linear, +//! log-scale, or step grid). +//! +//! 
**Important:** parameters must be suggested in the same order across +//! trials for consistent dimension assignment. +//! +//! Requires the **`sobol`** feature flag: +//! +//! ```toml +//! [dependencies] +//! optimizer = { version = "...", features = ["sobol"] } +//! ``` +//! +//! # Example +//! +//! ``` +//! use optimizer::prelude::*; +//! use optimizer::sampler::sobol::SobolSampler; +//! +//! let study: Study<f64> = Study::with_sampler(Direction::Minimize, SobolSampler::with_seed(42)); +//! ``` use parking_lot::Mutex; use sobol_burley::sample; @@ -17,13 +61,11 @@ struct SobolState { /// Quasi-random sampler using Sobol low-discrepancy sequences. /// -/// Provides better uniform coverage of the parameter space than -/// [`RandomSampler`](super::random::RandomSampler). Useful as a baseline or -/// for the startup phase of model-based samplers. -/// -/// Unlike random sampling, Sobol sequences are deterministic and fill the -/// space more evenly, reducing the number of trials needed to adequately -/// cover the search space. +/// Produce better uniform coverage of the parameter space than +/// [`RandomSampler`](super::random::RandomSampler) by using a +/// scrambled Sobol sequence (Burley 2020). Useful as a standalone +/// baseline or as a drop-in replacement for the random startup +/// phase of model-based samplers. /// /// Each trial uses a different Sobol sequence index, and each parameter /// within a trial maps to a different Sobol dimension. Parameters must be @@ -33,7 +75,7 @@ struct SobolState { /// Sobol sequences are most effective in moderate dimensions (up to ~20). /// For very high dimensions, the uniformity advantage diminishes. /// -/// Requires the `sobol` feature flag. +/// Requires the **`sobol`** feature flag. /// /// # Examples /// diff --git a/src/sampler/tpe/mod.rs b/src/sampler/tpe/mod.rs index 3013111..2c17e3d 100644 --- a/src/sampler/tpe/mod.rs +++ b/src/sampler/tpe/mod.rs @@ -1,7 +1,63 @@ -//! 
Tree-Parzen Estimator (TPE) sampler implementation and utilities. +//! Tree-Parzen Estimator (TPE) sampler family for Bayesian optimization. //! -//! This module provides TPE-based sampling for Bayesian optimization, -//! including support for intersection search space calculation. +//! TPE is a sequential model-based optimization algorithm that models P(x|y) instead +//! of P(y|x). It splits completed trials into "good" (below the gamma quantile) and +//! "bad" groups, fits a kernel density estimator (KDE) to each, and proposes new +//! points by maximizing the l(x)/g(x) ratio — an approximation of Expected Improvement. +//! +//! # Samplers +//! +//! | Sampler | Models parameters | Best for | +//! |---------|-------------------|----------| +//! | [`TpeSampler`] | Independently | General-purpose single-objective optimization | +//! | [`MultivariateTpeSampler`] | Jointly | Problems with correlated parameters | +//! +//! # Gamma strategies +//! +//! The gamma quantile controls how many trials are considered "good". This module +//! provides four built-in strategies via the [`GammaStrategy`] trait: +//! +//! | Strategy | Formula | Default | +//! |----------|---------|---------| +//! | [`FixedGamma`] | Constant value | gamma = 0.25 | +//! | [`LinearGamma`] | Linear ramp from min to max | 0.10 → 0.25 over 100 trials | +//! | [`SqrtGamma`] | 1/√n decay (Optuna-style) | factor = 1.0, max = 0.25 | +//! | [`HyperoptGamma`] | (base+1)/n (Hyperopt-style) | base = 24, max = 0.25 | +//! +//! You can also implement [`GammaStrategy`] for a custom splitting rule. +//! +//! # Search-space utilities +//! +//! The [`search_space`] submodule provides [`IntersectionSearchSpace`] for computing +//! the common parameter set across trials, and [`GroupDecomposedSearchSpace`] for +//! splitting parameters into independent groups based on co-occurrence. +//! +//! # Examples +//! +//! Basic TPE with default settings: +//! +//! ``` +//! use optimizer::sampler::tpe::TpeSampler; +//! 
use optimizer::{Direction, Study}; +//! +//! let study: Study<f64> = Study::with_sampler(Direction::Minimize, TpeSampler::new()); +//! ``` +//! +//! Multivariate TPE for correlated parameters: +//! +//! ``` +//! use optimizer::sampler::tpe::MultivariateTpeSampler; +//! use optimizer::{Direction, Study}; +//! +//! let sampler = MultivariateTpeSampler::builder() +//! .gamma(0.15) +//! .n_startup_trials(20) +//! .group(true) +//! .seed(42) +//! .build() +//! .unwrap(); +//! let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); +//! ``` mod gamma; mod multivariate; diff --git a/src/sampler/tpe/multivariate.rs b/src/sampler/tpe/multivariate.rs index 556545f..72ea840 100644 --- a/src/sampler/tpe/multivariate.rs +++ b/src/sampler/tpe/multivariate.rs @@ -167,15 +167,53 @@ pub enum ConstantLiarStrategy { /// A Multivariate Tree-Parzen Estimator (TPE) sampler for Bayesian optimization. /// -/// This sampler extends the standard TPE approach by modeling joint distributions -/// over all parameters, allowing it to capture parameter correlations. +/// Unlike the standard [`super::TpeSampler`], which samples each parameter +/// independently, this sampler models joint distributions over all parameters +/// using multivariate KDE. This captures correlations between parameters and +/// can significantly improve optimization on problems where parameters interact +/// (e.g., Rosenbrock, coupled hyperparameters). /// -/// # Fields +/// When the search space varies between trials (conditional parameters), the +/// sampler automatically falls back to independent TPE or uniform sampling for +/// parameters outside the intersection search space. 
/// -/// - `gamma_strategy`: Strategy for computing the gamma quantile -/// - `n_startup_trials`: Number of random trials before TPE sampling begins -/// - `n_ei_candidates`: Number of candidates to evaluate per joint sample -/// - `group`: Whether to decompose search space into independent groups +/// # When to use +/// +/// - Parameters are correlated or interact with each other. +/// - The search space is mostly fixed across trials. +/// - You need parallel optimization (enable [`ConstantLiarStrategy`]). +/// +/// Prefer [`super::TpeSampler`] when parameters are independent or the search space changes +/// dynamically. +/// +/// # Examples +/// +/// ``` +/// use optimizer::parameter::{FloatParam, Parameter}; +/// use optimizer::sampler::tpe::MultivariateTpeSampler; +/// use optimizer::{Direction, Study}; +/// +/// let sampler = MultivariateTpeSampler::builder() +/// .gamma(0.15) +/// .n_startup_trials(20) +/// .seed(42) +/// .build() +/// .unwrap(); +/// +/// let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); +/// let x = FloatParam::new(-5.0, 5.0); +/// let y = FloatParam::new(-5.0, 5.0); +/// +/// study +/// .optimize(30, |trial| { +/// let xv = x.suggest(trial)?; +/// let yv = y.suggest(trial)?; +/// Ok::<_, optimizer::Error>(xv * xv + yv * yv) +/// }) +/// .unwrap(); +/// +/// assert!(study.best_value().unwrap() < 1.0); +/// ``` pub struct MultivariateTpeSampler { /// Strategy for computing the gamma quantile. 
gamma_strategy: Arc<dyn GammaStrategy>, diff --git a/src/sampler/tpe/sampler.rs b/src/sampler/tpe/sampler.rs index 6a40692..16ff43b 100644 --- a/src/sampler/tpe/sampler.rs +++ b/src/sampler/tpe/sampler.rs @@ -102,7 +102,7 @@ use crate::sampler::{CompletedTrial, Sampler}; /// /// // Create with custom settings using the builder /// let sampler = TpeSampler::builder() -/// .gamma(0.15) // Shorthand for Fixednew(0.15) +/// .gamma(0.15) // Shorthand for FixedGamma::new(0.15) /// .n_startup_trials(20) /// .n_ei_candidates(32) /// .seed(42) diff --git a/src/storage/journal.rs b/src/storage/journal.rs index 9e4ca68..367adfb 100644 --- a/src/storage/journal.rs +++ b/src/storage/journal.rs @@ -1,4 +1,72 @@ //! JSONL-based journal storage backend. +//! +//! [`JournalStorage`] persists completed trials as one JSON object per +//! line ([JSONL / JSON Lines](https://jsonlines.org/)) while keeping a +//! full copy in memory for fast read access. +//! +//! # File format +//! +//! Each line is a self-contained JSON serialization of a +//! [`CompletedTrial<V>`](crate::sampler::CompletedTrial). The file +//! is append-only — no existing lines are ever modified or deleted. +//! +//! ```text +//! {"id":0,"params":{...},"value":1.23,"state":"Completed",...} +//! {"id":1,"params":{...},"value":0.87,"state":"Completed",...} +//! ``` +//! +//! # File locking +//! +//! Concurrent access is coordinated with `fs2` file locks: +//! +//! - **Writes** acquire an *exclusive* lock so only one process +//! appends at a time. +//! - **Reads** ([`refresh`](super::Storage::refresh)) acquire a +//! *shared* lock so readers never see a partially written line. +//! +//! This makes it safe for multiple processes to share the same JSONL +//! file — for example, distributed workers each running their own +//! [`Study`](crate::Study) with a `JournalStorage` pointing to a +//! shared path. +//! +//! # Resuming a study +//! +//! Use [`JournalStorage::open`] to reload previously persisted trials +//! 
and continue optimization from where you left off: +//! +//! ```no_run +//! use optimizer::prelude::*; +//! use optimizer::storage::JournalStorage; +//! +//! // First run — creates the file. +//! let storage = JournalStorage::<f64>::new("trials.jsonl"); +//! let mut study = Study::builder().minimize().storage(storage).build(); +//! study +//! .optimize(50, |trial| { +//! let x = FloatParam::new(-5.0, 5.0).suggest(trial)?; +//! Ok::<_, optimizer::Error>(x * x) +//! }) +//! .unwrap(); +//! +//! // Later run — reloads previous 50 trials, then adds 50 more. +//! let storage = JournalStorage::<f64>::open("trials.jsonl").unwrap(); +//! let mut study = Study::builder().minimize().storage(storage).build(); +//! study +//! .optimize(50, |trial| { +//! let x = FloatParam::new(-5.0, 5.0).suggest(trial)?; +//! Ok::<_, optimizer::Error>(x * x) +//! }) +//! .unwrap(); +//! ``` +//! +//! # When to use +//! +//! - **Persistence** — survive process crashes or intentional restarts. +//! - **Multi-process** — several workers collaborating on a single study. +//! - **Inspection** — `cat trials.jsonl | jq .` for quick debugging. +//! +//! For pure in-memory usage without disk I/O, use +//! [`MemoryStorage`](super::MemoryStorage) instead (the default). use core::marker::PhantomData; use std::fs::{File, OpenOptions}; @@ -14,22 +82,30 @@ use serde::de::DeserializeOwned; use super::{MemoryStorage, Storage}; use crate::sampler::CompletedTrial; -/// A storage backend that appends completed trials as JSON lines to a file. +/// Append-only JSONL storage backend with file locking. /// -/// Trials are kept in memory for fast read access and simultaneously -/// persisted to a JSONL file. Multiple processes can safely share -/// the same file: writes use an exclusive file lock, reads use a -/// shared file lock. +/// Trials are kept in memory (via an inner [`MemoryStorage`]) for fast +/// read access and simultaneously appended to a JSONL file on disk. 
+/// Multiple processes can safely share the same file thanks to +/// `fs2` file locks — writes use an exclusive lock, reads use a +/// shared lock. /// -/// The type parameter `V` is the objective value type (typically `f64`). -/// It must be serializable so that trials can be written to disk. +/// The type parameter `V` is the objective value type (typically +/// `f64`). It must implement [`Serialize`](serde::Serialize) and +/// [`DeserializeOwned`](serde::de::DeserializeOwned) so trials can be +/// written to and read from disk. /// -/// # Examples +/// See the module-level documentation for file format details +/// and a resumption example. +/// +/// # Example /// /// ```no_run +/// use optimizer::prelude::*; /// use optimizer::storage::JournalStorage; /// -/// let storage: JournalStorage<f64> = JournalStorage::new("trials.jsonl"); +/// let storage = JournalStorage::<f64>::new("trials.jsonl"); +/// let mut study = Study::builder().minimize().storage(storage).build(); /// ``` pub struct JournalStorage<V = f64> { memory: MemoryStorage<V>, @@ -40,14 +116,15 @@ pub struct JournalStorage<V = f64> { } impl<V: Serialize + DeserializeOwned + Send + Sync> JournalStorage<V> { - /// Creates a new journal storage that writes to the given path. + /// Create a new journal storage that writes to the given path. /// /// The file does not need to exist yet — it will be created on the /// first write. Existing trials in the file are **not** loaded /// until [`refresh`](Storage::refresh) is called (which happens /// automatically at the start of each trial via the [`Study`](crate::Study)). /// - /// To pre-load existing trials, use [`JournalStorage::open`]. + /// To pre-load existing trials at construction time, use + /// [`JournalStorage::open`] instead. 
#[must_use] pub fn new(path: impl AsRef<Path>) -> Self { Self { @@ -58,13 +135,14 @@ impl<V: Serialize + DeserializeOwned + Send + Sync> JournalStorage<V> { } } - /// Opens an existing journal file and loads all stored trials. + /// Open an existing journal file and load all stored trials. /// - /// If the file does not exist, returns an empty storage (no error). + /// If the file does not exist, return an empty storage (no error). + /// This is the primary way to **resume** a study after a restart. /// /// # Errors /// - /// Returns a [`Storage`](crate::Error::Storage) error if the file + /// Return a [`Storage`](crate::Error::Storage) error if the file /// exists but cannot be read or parsed. pub fn open(path: impl AsRef<Path>) -> crate::Result<Self> { let path = path.as_ref().to_path_buf(); diff --git a/src/storage/memory.rs b/src/storage/memory.rs index d9c2d29..676fc41 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -1,3 +1,32 @@ +//! In-memory storage backend. +//! +//! [`MemoryStorage`] is the default backend used by every +//! [`Study`](crate::Study). Trials are stored in a +//! `Vec<CompletedTrial<V>>` behind a [`parking_lot::RwLock`] for +//! thread-safe access. +//! +//! # When to use +//! +//! - **Single-process** studies where persistence is not needed. +//! - **Testing** or **prototyping** — zero configuration required. +//! - When you want the **fastest** possible read/write performance +//! (no disk I/O). +//! +//! For persistent storage that survives process restarts, see +//! [`JournalStorage`](super::JournalStorage) (requires the `journal` +//! feature). +//! +//! # Example +//! +//! ``` +//! use optimizer::prelude::*; +//! use optimizer::storage::MemoryStorage; +//! +//! // Explicit memory storage (equivalent to the default) +//! let storage = MemoryStorage::<f64>::new(); +//! let study = Study::builder().minimize().storage(storage).build(); +//! 
``` + use core::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; @@ -8,14 +37,28 @@ use crate::sampler::CompletedTrial; /// In-memory trial storage (the default). /// -/// This is a thin wrapper around `Arc<RwLock<Vec<CompletedTrial<V>>>>`. +/// Wrap a `Vec<CompletedTrial<V>>` behind a read-write lock so that +/// trials can be appended from any thread. This is the backend that +/// [`Study`](crate::Study) uses when no explicit storage is provided. +/// +/// Use [`with_trials`](Self::with_trials) to seed a study with +/// previously collected data. +/// +/// # Example +/// +/// ``` +/// use optimizer::storage::{MemoryStorage, Storage}; +/// +/// let storage = MemoryStorage::<f64>::new(); +/// assert_eq!(storage.trials_arc().read().len(), 0); +/// ``` pub struct MemoryStorage<V> { trials: Arc<RwLock<Vec<CompletedTrial<V>>>>, next_id: AtomicU64, } impl<V> MemoryStorage<V> { - /// Creates a new, empty in-memory store. + /// Create a new, empty in-memory store. #[must_use] pub fn new() -> Self { Self { @@ -24,7 +67,10 @@ impl<V> MemoryStorage<V> { } } - /// Creates an in-memory store pre-populated with `trials`. + /// Create an in-memory store pre-populated with `trials`. + /// + /// The internal ID counter is set to one past the highest trial ID + /// so that subsequent trials receive unique IDs. #[must_use] pub fn with_trials(trials: Vec<CompletedTrial<V>>) -> Self { let next_id = trials.iter().map(|t| t.id).max().map_or(0, |id| id + 1); @@ -34,7 +80,8 @@ impl<V> MemoryStorage<V> { } } - /// Ensures the ID counter is at least `min_value`. + /// Ensure the ID counter is at least `min_value`. + #[cfg(feature = "journal")] pub(crate) fn bump_next_id(&self, min_value: u64) { self.next_id.fetch_max(min_value, Ordering::SeqCst); } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index d5ceb00..1ed6357 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -1,10 +1,42 @@ //! Trial storage backends. //! -//! 
The [`Storage`] trait defines how completed trials are stored and
-//! accessed. [`MemoryStorage`] keeps trials in memory (the default).
-//! With the `journal` feature enabled, [`JournalStorage`] appends
-//! trials to a JSONL file with file-level locking so multiple
-//! processes can safely share state.
+//! The [`Storage`] trait defines how completed trials are persisted and
+//! retrieved. Every [`Study`](crate::Study) owns an `Arc<dyn Storage<V>>`
+//! so storage is transparently shared across threads.
+//!
+//! # Available backends
+//!
+//! | Backend | Description | Feature flag |
+//! |---------|-------------|-------------|
+//! | [`MemoryStorage`] | In-memory `Vec` behind a read-write lock (the default) | — |
+//! | [`JournalStorage`] | JSONL file with `fs2` file locking for multi-process sharing | `journal` |
+//!
+//! # When to swap backends
+//!
+//! The default [`MemoryStorage`] is sufficient for single-process studies
+//! where persistence is not needed. Switch to [`JournalStorage`] when you
+//! want to:
+//!
+//! - **Resume** a study after a process restart.
+//! - **Share state** across multiple processes writing to the same file.
+//! - **Inspect** trial history in a human-readable JSONL file.
+//!
+//! # Implementing a custom backend
+//!
+//! Implement the [`Storage`] trait to plug in your own backend (e.g. a
+//! database). The trait has three required methods: [`push`](Storage::push),
+//! [`trials_arc`](Storage::trials_arc), [`next_trial_id`](Storage::next_trial_id),
+//! plus an optional [`refresh`](Storage::refresh) override for external data sources.
+//!
+//! Inject your storage into a study via the builder:
+//!
+//! ```
+//! use optimizer::prelude::*;
+//! use optimizer::storage::MemoryStorage;
+//!
+//! let storage = MemoryStorage::<f64>::new();
+//! let study = Study::builder().minimize().storage(storage).build();
+//!
``` #[cfg(feature = "journal")] mod journal; @@ -26,7 +58,8 @@ use crate::sampler::CompletedTrial; /// default implementation is [`MemoryStorage`], which keeps trials in /// a plain `Vec` behind a read-write lock. /// -/// Implementations must be safe to use from multiple threads. +/// Implementations must be `Send + Sync` because a study may be shared +/// across threads (e.g. via [`optimize_parallel`](crate::Study::optimize_parallel)). pub trait Storage<V>: Send + Sync { /// Append a completed trial to the store. fn push(&self, trial: CompletedTrial<V>); @@ -38,14 +71,14 @@ pub trait Storage<V>: Send + Sync { /// lock for efficient, allocation-free access. fn trials_arc(&self) -> &Arc<RwLock<Vec<CompletedTrial<V>>>>; - /// Atomically returns the next unique trial ID. + /// Atomically return the next unique trial ID. /// /// Each call increments an internal counter so that consecutive /// calls always produce distinct IDs. fn next_trial_id(&self) -> u64; /// Reload from an external source (e.g. a file written by another - /// process). Returns `true` if the in-memory buffer was updated. + /// process). Return `true` if the in-memory buffer was updated. /// /// The default implementation is a no-op that returns `false`. fn refresh(&self) -> bool { diff --git a/src/study.rs b/src/study.rs index 8577d5a..e8effcf 100644 --- a/src/study.rs +++ b/src/study.rs @@ -63,7 +63,7 @@ impl<V> Study<V> where V: PartialOrd, { - /// Creates a new study with the given optimization direction. + /// Create a new study with the given optimization direction. /// /// Uses the default `RandomSampler` for parameter sampling. /// @@ -87,7 +87,7 @@ where Self::with_sampler(direction, RandomSampler::new()) } - /// Returns a [`StudyBuilder`] for constructing a study with a fluent API. + /// Return a [`StudyBuilder`] for constructing a study with a fluent API. /// /// # Examples /// @@ -111,7 +111,7 @@ where } } - /// Creates a study that minimizes the objective value. 
+ /// Create a study that minimizes the objective value. /// /// This is a shorthand for `Study::with_sampler(Direction::Minimize, sampler)`. /// @@ -136,7 +136,7 @@ where Self::with_sampler(Direction::Minimize, sampler) } - /// Creates a study that maximizes the objective value. + /// Create a study that maximizes the objective value. /// /// This is a shorthand for `Study::with_sampler(Direction::Maximize, sampler)`. /// @@ -161,7 +161,7 @@ where Self::with_sampler(Direction::Maximize, sampler) } - /// Creates a new study with a custom sampler. + /// Create a new study with a custom sampler. /// /// # Arguments /// @@ -189,7 +189,7 @@ where ) } - /// Builds a trial factory for sampler integration when `V = f64`. + /// Build a trial factory for sampler integration when `V = f64`. fn make_trial_factory( sampler: &Arc<dyn Sampler>, storage: &Arc<dyn crate::storage::Storage<V>>, @@ -220,10 +220,28 @@ where }) } - /// Creates a study with a custom sampler and storage backend. + /// Create a study with a custom sampler and storage backend. /// /// This is the most general constructor — all other constructors - /// delegate to this one. + /// delegate to this one. Use it when you need a non-default storage + /// backend (e.g., [`JournalStorage`](crate::storage::JournalStorage)). + /// + /// # Arguments + /// + /// * `direction` - Whether to minimize or maximize the objective function. + /// * `sampler` - The sampler to use for parameter sampling. + /// * `storage` - The storage backend for completed trials. 
+ ///
+ /// # Examples
+ ///
+ /// ```
+ /// use optimizer::sampler::random::RandomSampler;
+ /// use optimizer::storage::MemoryStorage;
+ /// use optimizer::{Direction, Study};
+ ///
+ /// let storage = MemoryStorage::<f64>::new();
+ /// let study = Study::with_sampler_and_storage(Direction::Minimize, RandomSampler::new(), storage);
+ /// ```
 pub fn with_sampler_and_storage(
 direction: Direction,
 sampler: impl Sampler + 'static,
@@ -247,28 +265,15 @@ where
 }
 }
- /// Returns the optimization direction.
+ /// Return the optimization direction.
 #[must_use]
 pub fn direction(&self) -> Direction {
 self.direction
 }
- /// Sets a new sampler for the study.
- ///
- /// # Arguments
- ///
- /// * `sampler` - The sampler to use for parameter sampling.
+ /// Create a study with a custom sampler and pruner.
 ///
- /// # Examples
- ///
- /// ```
- /// use optimizer::sampler::tpe::TpeSampler;
- /// use optimizer::{Direction, Study};
- ///
- /// let mut study: Study<f64> = Study::new(Direction::Minimize);
- /// study.set_sampler(TpeSampler::new());
- /// ```
- /// Creates a new study with a custom sampler and pruner.
+ /// Uses the default [`MemoryStorage`](crate::storage::MemoryStorage) backend.
 ///
 /// # Arguments
 ///
@@ -310,6 +315,21 @@ where
 }
 }
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::sampler::tpe::TpeSampler; + /// use optimizer::{Direction, Study}; + /// + /// let mut study: Study<f64> = Study::new(Direction::Minimize); + /// study.set_sampler(TpeSampler::new()); + /// ``` pub fn set_sampler(&mut self, sampler: impl Sampler + 'static) where V: 'static, @@ -318,11 +338,18 @@ where self.trial_factory = Self::make_trial_factory(&self.sampler, &self.storage, &self.pruner); } - /// Sets a new pruner for the study. + /// Replace the pruner used for future trials. /// - /// # Arguments + /// The new pruner takes effect for all trials created after this call. /// - /// * `pruner` - The pruner to use for trial pruning. + /// # Examples + /// + /// ``` + /// use optimizer::prelude::*; + /// + /// let mut study: Study<f64> = Study::new(Direction::Minimize); + /// study.set_pruner(MedianPruner::new(Direction::Minimize)); + /// ``` pub fn set_pruner(&mut self, pruner: impl Pruner + 'static) where V: 'static, @@ -331,13 +358,13 @@ where self.trial_factory = Self::make_trial_factory(&self.sampler, &self.storage, &self.pruner); } - /// Returns a reference to the study's pruner. + /// Return a reference to the study's current pruner. #[must_use] pub fn pruner(&self) -> &dyn Pruner { &*self.pruner } - /// Enqueues a specific parameter configuration to be evaluated next. + /// Enqueue a specific parameter configuration to be evaluated next. /// /// The next call to [`ask()`](Self::ask) or the next trial in [`optimize()`](Self::optimize) /// will use these exact parameters instead of sampling from the sampler. @@ -377,7 +404,7 @@ where self.enqueued_params.lock().push_back(params); } - /// Returns the trial ID of the current best trial from the given slice. + /// Return the trial ID of the current best trial from the given slice. 
#[cfg(feature = "tracing")] fn best_id(&self, trials: &[CompletedTrial<V>]) -> Option<u64> { let direction = self.direction; @@ -388,7 +415,7 @@ where .map(|t| t.id) } - /// Creates a new trial with pre-set parameter values. + /// Create a new trial with pre-set parameter values. /// /// The trial gets a new unique ID but reuses the given parameters. When /// `suggest_param` is called on the resulting trial, fixed values are @@ -404,18 +431,20 @@ where trial } - /// Returns the number of enqueued parameter configurations. + /// Return the number of enqueued parameter configurations. + /// + /// See [`enqueue`](Self::enqueue) for how to add configurations. #[must_use] pub fn n_enqueued(&self) -> usize { self.enqueued_params.lock().len() } - /// Generates the next unique trial ID. + /// Generate the next unique trial ID. pub(crate) fn next_trial_id(&self) -> u64 { self.storage.next_trial_id() } - /// Creates a new trial with a unique ID. + /// Create a new trial with a unique ID. /// /// The trial starts in the `Running` state and can be used to suggest /// parameter values. After the objective function is evaluated, call @@ -456,7 +485,7 @@ where trial } - /// Records a completed trial with its objective value. + /// Record a completed trial with its objective value. /// /// This method stores the trial's parameters, distributions, and objective /// value in the study's history. The stored data is used by samplers to @@ -499,7 +528,7 @@ where self.storage.push(completed); } - /// Records a failed trial with an error message. + /// Record a failed trial with an error message. /// /// Failed trials are not stored in the study's history and do not /// contribute to future sampling decisions. This method is useful @@ -582,11 +611,15 @@ where } } - /// Records a pruned trial, preserving its intermediate values. + /// Record a pruned trial, preserving its intermediate values. 
/// /// Pruned trials are stored alongside completed trials so that samplers /// can optionally learn from partial evaluations. The trial's state is - /// set to `Pruned`. + /// set to [`Pruned`](crate::TrialState::Pruned). + /// + /// In practice you rarely call this directly — returning + /// `Err(TrialPruned)` from an objective function handles pruning + /// automatically. /// /// # Arguments /// @@ -611,9 +644,9 @@ where self.storage.push(completed); } - /// Returns an iterator over all completed trials. + /// Return all completed trials as a `Vec`. /// - /// The iterator yields references to `CompletedTrial` values, which contain + /// The returned vector contains clones of `CompletedTrial` values, which contain /// the trial's parameters, distributions, and objective value. /// /// Note: This method acquires a read lock on the completed trials, so the @@ -643,7 +676,7 @@ where self.storage.trials_arc().read().clone() } - /// Returns the number of completed trials. + /// Return the number of completed trials. /// /// Failed trials are not counted. /// @@ -667,7 +700,9 @@ where self.storage.trials_arc().read().len() } - /// Returns the number of pruned trials. + /// Return the number of pruned trials. + /// + /// Pruned trials are those that were stopped early by the pruner. #[must_use] pub fn n_pruned_trials(&self) -> usize { self.storage @@ -678,7 +713,7 @@ where .count() } - /// Compares two completed trials using constraint-aware ranking. + /// Compare two completed trials using constraint-aware ranking. /// /// 1. Feasible trials always rank above infeasible trials. /// 2. Among feasible trials, rank by objective value (respecting direction). @@ -708,7 +743,7 @@ where } } - /// Returns the trial with the best objective value. + /// Return the trial with the best objective value. /// /// The "best" trial depends on the optimization direction: /// - `Direction::Minimize`: Returns the trial with the lowest objective value. 
@@ -762,7 +797,7 @@ where Ok(best.clone()) } - /// Returns the best objective value found so far. + /// Return the best objective value found so far. /// /// The "best" value depends on the optimization direction: /// - `Direction::Minimize`: Returns the lowest objective value. @@ -803,13 +838,33 @@ where self.best_trial().map(|trial| trial.value) } - /// Returns the top `n` trials sorted by objective value. + /// Return the top `n` trials sorted by objective value. /// /// For `Direction::Minimize`, returns trials with the lowest values. /// For `Direction::Maximize`, returns trials with the highest values. /// Only includes completed trials (not failed or pruned). /// /// If fewer than `n` completed trials exist, returns all of them. + /// + /// # Examples + /// + /// ``` + /// use optimizer::parameter::{FloatParam, Parameter}; + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// let x = FloatParam::new(0.0, 10.0); + /// + /// for val in [5.0, 1.0, 3.0] { + /// let mut t = study.create_trial(); + /// let _ = x.suggest(&mut t); + /// study.complete_trial(t, val); + /// } + /// + /// let top2 = study.top_trials(2); + /// assert_eq!(top2.len(), 2); + /// assert!(top2[0].value <= top2[1].value); + /// ``` #[must_use] pub fn top_trials(&self, n: usize) -> Vec<CompletedTrial<V>> where @@ -828,7 +883,7 @@ where completed } - /// Runs optimization with the given objective function. + /// Run optimization with the given objective function. /// /// This method runs `n_trials` evaluations sequentially. For each trial: /// 1. A new trial is created @@ -936,7 +991,7 @@ where Ok(()) } - /// Runs optimization asynchronously with the given objective function. + /// Run optimization asynchronously with the given objective function. 
/// /// This method runs `n_trials` evaluations sequentially, but the objective /// function can be async (e.g., for I/O-bound operations like network requests @@ -1036,7 +1091,7 @@ where Ok(()) } - /// Runs optimization with bounded parallelism for concurrent trial evaluation. + /// Run optimization with bounded parallelism for concurrent trial evaluation. /// /// This method runs up to `concurrency` trials simultaneously, allowing /// efficient use of async I/O-bound objective functions. A semaphore limits @@ -1163,7 +1218,7 @@ where Ok(()) } - /// Runs optimization with a callback for monitoring progress. + /// Run optimization with a callback for monitoring progress. /// /// This method is similar to `optimize`, but calls a callback function after /// each completed trial. The callback can inspect the study state and the @@ -1305,7 +1360,7 @@ where Ok(()) } - /// Runs optimization until the given duration has elapsed. + /// Run optimization until the given duration has elapsed. /// /// Trials that are already running when the timeout is reached will /// complete — we never interrupt mid-trial. The actual elapsed time @@ -1392,7 +1447,7 @@ where Ok(()) } - /// Runs optimization until the given duration has elapsed, with a callback. + /// Run optimization until the given duration has elapsed, with a callback. /// /// Like [`optimize_until`](Self::optimize_until), but calls a callback after /// each completed trial. The callback can stop optimization early by returning @@ -1526,16 +1581,17 @@ where Ok(()) } - /// Runs optimization asynchronously until the given duration has elapsed. + /// Run optimization asynchronously until the given duration has elapsed. /// /// The async variant of [`optimize_until`](Self::optimize_until). Trials are - /// run sequentially, but the objective function can be async. + /// run sequentially, but the objective function can be async (useful for + /// I/O-bound evaluations). 
/// /// # Arguments /// /// * `duration` - The maximum wall-clock time to spend on optimization. /// * `objective` - A function that takes a `Trial` and returns a `Future` - /// that resolves to a tuple of `(Trial, Result<V, E>)`. + /// that resolves to a tuple of `(Trial, V)` or an error. /// /// # Errors /// @@ -1585,7 +1641,7 @@ where Ok(()) } - /// Runs optimization with bounded parallelism until the given duration has elapsed. + /// Run optimization with bounded parallelism until the given duration has elapsed. /// /// The parallel variant of [`optimize_until`](Self::optimize_until). Runs up to /// `concurrency` trials simultaneously using async tasks. New trials are spawned @@ -1675,7 +1731,7 @@ where Ok(()) } - /// Runs optimization with automatic retry for failed trials. + /// Run optimization with automatic retry for failed trials. /// /// If the objective function returns an error, the same parameter /// configuration is retried up to `max_retries` times. Only after all @@ -1789,7 +1845,7 @@ impl<V> Study<V> where V: PartialOrd + Clone + fmt::Display, { - /// Export completed trials to CSV format. + /// Write completed trials to a writer in CSV format. /// /// Columns: `trial_id`, `value`, `state`, then one column per unique /// parameter label, then one column per unique user-attribute key. @@ -1800,6 +1856,25 @@ where /// # Errors /// /// Returns an I/O error if writing fails. 
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::parameter::{FloatParam, Parameter}; + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// let x = FloatParam::new(0.0, 10.0).name("x"); + /// + /// let mut trial = study.create_trial(); + /// let _ = x.suggest(&mut trial); + /// study.complete_trial(trial, 0.42); + /// + /// let mut buf = Vec::new(); + /// study.to_csv(&mut buf).unwrap(); + /// let csv = String::from_utf8(buf).unwrap(); + /// assert!(csv.contains("trial_id")); + /// ``` pub fn to_csv(&self, mut writer: impl std::io::Write) -> std::io::Result<()> { use std::collections::BTreeMap; @@ -1892,7 +1967,10 @@ where Ok(()) } - /// Export completed trials to a CSV file. + /// Export completed trials to a CSV file at the given path. + /// + /// Convenience wrapper around [`to_csv`](Self::to_csv) that creates a + /// buffered file writer. /// /// # Errors /// @@ -1902,7 +1980,7 @@ where self.to_csv(std::io::BufWriter::new(file)) } - /// Returns a human-readable summary of the study. + /// Return a human-readable summary of the study. /// /// The summary includes: /// - Optimization direction and total trial count @@ -1973,10 +2051,24 @@ impl<V> Study<V> where V: PartialOrd + Clone, { - /// Returns an iterator over all completed trials. + /// Return an iterator over all completed trials. /// /// This clones the internal trial list, so it is suitable for /// analysis and iteration but not for hot paths. 
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// let trial = study.create_trial(); + /// study.complete_trial(trial, 1.0); + /// + /// for t in study.iter() { + /// println!("Trial {} → {}", t.id, t.value); + /// } + /// ``` #[must_use] pub fn iter(&self) -> std::vec::IntoIter<CompletedTrial<V>> { self.trials().into_iter() @@ -1987,7 +2079,7 @@ impl<V> Study<V> where V: PartialOrd + Clone + Into<f64>, { - /// Computes parameter importance scores using Spearman rank correlation. + /// Compute parameter importance scores using Spearman rank correlation. /// /// For each parameter, the absolute Spearman correlation between its values /// and the objective values is computed across all completed trials. Scores @@ -2086,7 +2178,7 @@ where scores } - /// Computes parameter importance using fANOVA (functional ANOVA) with + /// Compute parameter importance using fANOVA (functional ANOVA) with /// default configuration. /// /// Fits a random forest to the trial data and decomposes variance into @@ -2097,11 +2189,33 @@ where /// # Errors /// /// Returns [`crate::Error::NoCompletedTrials`] if fewer than 2 trials have completed. 
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::parameter::{FloatParam, Parameter}; + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// let x = FloatParam::new(0.0, 10.0).name("x"); + /// let y = FloatParam::new(0.0, 10.0).name("y"); + /// + /// study + /// .optimize(30, |trial| { + /// let xv = x.suggest(trial)?; + /// let yv = y.suggest(trial)?; + /// Ok::<_, optimizer::Error>(xv * xv + 0.1 * yv) + /// }) + /// .unwrap(); + /// + /// let result = study.fanova().unwrap(); + /// assert!(!result.main_effects.is_empty()); + /// ``` pub fn fanova(&self) -> crate::Result<crate::fanova::FanovaResult> { self.fanova_with_config(&crate::fanova::FanovaConfig::default()) } - /// Computes parameter importance using fANOVA with custom configuration. + /// Compute parameter importance using fANOVA with custom configuration. /// /// See [`Self::fanova`] for details. The [`FanovaConfig`](crate::fanova::FanovaConfig) /// allows tuning the number of trees, tree depth, and random seed. @@ -2316,7 +2430,30 @@ impl Study<f64> { } impl<V: PartialOrd + Send + Sync + 'static> Study<V> { - /// Creates a study with a custom sampler, pruner, and storage backend. + /// Create a study with a custom sampler, pruner, and storage backend. + /// + /// The most flexible constructor, allowing full control over all components. + /// + /// # Arguments + /// + /// * `direction` - Whether to minimize or maximize the objective function. + /// * `sampler` - The sampler to use for parameter sampling. + /// * `pruner` - The pruner to use for trial pruning. + /// * `storage` - The storage backend for completed trials. 
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::prelude::*; + /// use optimizer::storage::MemoryStorage; + /// + /// let study = Study::with_sampler_pruner_and_storage( + /// Direction::Minimize, + /// TpeSampler::new(), + /// MedianPruner::new(Direction::Minimize), + /// MemoryStorage::<f64>::new(), + /// ); + /// ``` pub fn with_sampler_pruner_and_storage( direction: Direction, sampler: impl Sampler + 'static, @@ -2373,49 +2510,55 @@ pub struct StudyBuilder<V: PartialOrd = f64> { } impl<V: PartialOrd> StudyBuilder<V> { - /// Sets the optimization direction to minimize. + /// Set the optimization direction to minimize (the default). #[must_use] pub fn minimize(mut self) -> Self { self.direction = Direction::Minimize; self } - /// Sets the optimization direction to maximize. + /// Set the optimization direction to maximize. #[must_use] pub fn maximize(mut self) -> Self { self.direction = Direction::Maximize; self } - /// Sets the optimization direction. + /// Set the optimization direction explicitly. #[must_use] pub fn direction(mut self, direction: Direction) -> Self { self.direction = direction; self } - /// Sets the sampler used for parameter suggestions. + /// Set the sampler used for parameter suggestions. + /// + /// Defaults to [`RandomSampler`] if not specified. #[must_use] pub fn sampler(mut self, sampler: impl Sampler + 'static) -> Self { self.sampler = Some(Box::new(sampler)); self } - /// Sets the pruner used for early stopping of trials. + /// Set the pruner used for early stopping of trials. + /// + /// Defaults to [`NopPruner`] (no pruning) if not specified. #[must_use] pub fn pruner(mut self, pruner: impl Pruner + 'static) -> Self { self.pruner = Some(Box::new(pruner)); self } - /// Sets a custom storage backend. + /// Set a custom storage backend. + /// + /// Defaults to [`MemoryStorage`](crate::storage::MemoryStorage) if not specified. 
#[must_use] pub fn storage(mut self, storage: impl crate::storage::Storage<V> + 'static) -> Self { self.storage = Some(Box::new(storage)); self } - /// Builds the [`Study`] with the configured options. + /// Build the [`Study`] with the configured options. #[must_use] pub fn build(self) -> Study<V> where @@ -2450,15 +2593,31 @@ impl<V> Study<V> where V: PartialOrd + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static, { - /// Creates a study backed by a JSONL journal file. + /// Create a study backed by a JSONL journal file. /// /// Any existing trials in the file are loaded into memory and the /// trial ID counter is set to one past the highest stored ID. New /// trials are written through to the file on completion. /// + /// # Arguments + /// + /// * `direction` - Whether to minimize or maximize the objective function. + /// * `sampler` - The sampler to use for parameter sampling. + /// * `path` - Path to the JSONL journal file (created if absent). + /// /// # Errors /// /// Returns a [`Storage`](crate::Error::Storage) error if loading fails. + /// + /// # Examples + /// + /// ```no_run + /// use optimizer::sampler::tpe::TpeSampler; + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = + /// Study::with_journal(Direction::Minimize, TpeSampler::new(), "trials.jsonl").unwrap(); + /// ``` pub fn with_journal( direction: Direction, sampler: impl Sampler + 'static, @@ -2470,9 +2629,9 @@ where } impl Study<f64> { - /// Generates an HTML report with interactive Plotly.js charts. + /// Generate an HTML report with interactive Plotly.js charts. /// - /// Creates a self-contained HTML file that can be opened in any browser. + /// Create a self-contained HTML file that can be opened in any browser. /// See [`generate_html_report`](crate::visualization::generate_html_report) /// for details on the included charts. 
/// @@ -2519,6 +2678,7 @@ impl<V: PartialOrd + Clone + serde::Serialize> Study<V> { /// Export trials as a pretty-printed JSON array to a file. /// /// Each element in the array is a serialized [`CompletedTrial`]. + /// Requires the `serde` feature. /// /// # Errors /// @@ -2529,7 +2689,7 @@ impl<V: PartialOrd + Clone + serde::Serialize> Study<V> { serde_json::to_writer_pretty(file, &trials).map_err(std::io::Error::other) } - /// Saves the study state to a JSON file. + /// Save the study state to a JSON file. /// /// # Errors /// @@ -2561,7 +2721,7 @@ impl<V: PartialOrd + Clone + serde::Serialize> Study<V> { #[cfg(feature = "serde")] impl<V: PartialOrd + Clone + Default + serde::Serialize> Study<V> { - /// Runs optimization with automatic checkpointing every `interval` trials. + /// Run optimization with automatic checkpointing every `interval` trials. /// /// This is convenience sugar over [`optimize_with_callback`](Self::optimize_with_callback) /// combined with [`save`](Self::save). The checkpoint is written atomically so @@ -2595,7 +2755,7 @@ impl<V: PartialOrd + Clone + Default + serde::Serialize> Study<V> { #[cfg(feature = "serde")] impl<V: PartialOrd + Send + Sync + Clone + serde::de::DeserializeOwned + 'static> Study<V> { - /// Loads a study from a JSON file. + /// Load a study from a JSON file. /// /// The loaded study uses a `RandomSampler` by default. Call /// [`set_sampler()`](Self::set_sampler) to restore the original sampler diff --git a/src/trial.rs b/src/trial.rs index 4f90fd3..e46077e 100644 --- a/src/trial.rs +++ b/src/trial.rs @@ -1,4 +1,23 @@ -//! Trial implementation for tracking sampled parameters and trial state. +//! Trial lifecycle management for optimization runs. +//! +//! A [`Trial`] represents a single evaluation of the objective function. The study +//! creates trials, the objective function samples parameters from them via +//! [`Parameter::suggest`](crate::parameter::Parameter::suggest), and reports +//! 
intermediate values for pruning decisions. +//! +//! # Lifecycle +//! +//! 1. **Created** — `Study` creates a trial with [`Trial::new`] or internally via +//! `Trial::with_sampler`. +//! 2. **Running** — The objective calls [`Trial::suggest_param`] to sample parameters +//! and optionally [`Trial::report`] / [`Trial::should_prune`] for early stopping. +//! 3. **Completed / Failed / Pruned** — The study marks the trial's final state. +//! +//! # User Attributes +//! +//! Trials support arbitrary key-value metadata via [`Trial::set_user_attr`] and +//! [`Trial::user_attr`], useful for logging hyperparameters, hardware info, or +//! debug notes alongside the optimization results. use std::collections::HashMap; use std::sync::Arc; @@ -57,14 +76,26 @@ impl From<bool> for AttrValue { } } -/// A trial represents a single evaluation of the objective function. +/// A single evaluation of the objective function. /// /// Each trial has a unique ID and stores the sampled parameters along with -/// their distributions. The trial progresses through states: Running -> Complete/Failed. +/// their distributions. The trial progresses through states: +/// `Running` → `Complete` / `Failed` / `Pruned`. /// -/// Trials use a sampler to generate parameter values. When created through -/// `Study::create_trial()`, the trial receives the study's sampler and access -/// to the history of completed trials for informed sampling. +/// Trials use a [`Sampler`](crate::sampler::Sampler) to generate parameter +/// values. When created through [`Study::create_trial`](crate::Study::create_trial), +/// the trial receives the study's sampler and access to the history of +/// completed trials for informed sampling. 
+/// +/// # Examples +/// +/// ``` +/// use optimizer::Trial; +/// use optimizer::parameter::{FloatParam, Parameter}; +/// +/// let mut trial = Trial::new(0); +/// let x = FloatParam::new(-5.0, 5.0).suggest(&mut trial).unwrap(); +/// ``` #[derive(Clone)] pub struct Trial { /// Unique identifier for this trial. @@ -113,13 +144,14 @@ impl core::fmt::Debug for Trial { } impl Trial { - /// Creates a new trial with the given ID. + /// Create a new trial with the given ID. /// /// The trial starts in the `Running` state with no parameters sampled. - /// This constructor creates a trial without a sampler, which will use - /// local random sampling for suggest methods. + /// This constructor creates a trial without a sampler, which will fall + /// back to random sampling for [`suggest_param`](Self::suggest_param) calls. /// - /// For trials that use the study's sampler, use `Trial::with_sampler` instead. + /// For trials that use the study's sampler, the study creates them + /// internally via `Trial::with_sampler`. /// /// # Arguments /// @@ -151,10 +183,10 @@ impl Trial { } } - /// Creates a new trial with a sampler and access to trial history. + /// Create a new trial with a sampler and access to trial history. /// - /// This constructor is used by `Study::create_trial()` to create trials - /// that use the study's sampler for informed parameter suggestions. + /// Used internally by `Study::create_trial()` to create trials that use + /// the study's sampler for informed parameter suggestions. /// /// # Arguments /// @@ -183,19 +215,19 @@ impl Trial { } } - /// Sets pre-filled parameters on this trial. + /// Set pre-filled parameters on this trial. /// - /// When `suggest_param` is called for a parameter that has a fixed value, - /// the fixed value is used instead of sampling. + /// When [`suggest_param`](Self::suggest_param) is called for a parameter + /// that has a fixed value, the fixed value is used instead of sampling. 
pub(crate) fn set_fixed_params(&mut self, params: HashMap<ParamId, ParamValue>) { self.fixed_params = params; } - /// Samples a value from the given distribution using the sampler. + /// Sample a value from the given distribution using the sampler. /// - /// If the trial has a sampler, it delegates to the sampler's sample method - /// with the history of completed trials. Otherwise, it uses the `RandomSampler` - /// as a fallback. + /// If the trial has a sampler, delegates to the sampler's sample method + /// with the history of completed trials. Otherwise, falls back to + /// [`RandomSampler`](crate::sampler::random::RandomSampler). fn sample_value(&self, distribution: &Distribution) -> ParamValue { if let (Some(sampler), Some(history)) = (&self.sampler, &self.history) { let history_guard = history.read(); @@ -208,40 +240,55 @@ impl Trial { } } - /// Returns the unique ID of this trial. + /// Return the unique ID of this trial. #[must_use] pub fn id(&self) -> u64 { self.id } - /// Returns the current state of this trial. + /// Return the current state of this trial. #[must_use] pub fn state(&self) -> TrialState { self.state } - /// Returns a reference to the sampled parameters. + /// Return a reference to the sampled parameters, keyed by [`ParamId`](crate::parameter::ParamId). #[must_use] pub fn params(&self) -> &HashMap<ParamId, ParamValue> { &self.params } - /// Returns a reference to the parameter distributions. + /// Return a reference to the parameter distributions, keyed by [`ParamId`](crate::parameter::ParamId). #[must_use] pub fn distributions(&self) -> &HashMap<ParamId, Distribution> { &self.distributions } - /// Returns a reference to the parameter labels. + /// Return a reference to the parameter labels, keyed by [`ParamId`](crate::parameter::ParamId). #[must_use] pub fn param_labels(&self) -> &HashMap<ParamId, String> { &self.param_labels } - /// Reports an intermediate objective value at a given step. 
+ /// Report an intermediate objective value at a given step. + /// + /// Call this during iterative training (e.g., once per epoch) so the + /// [`Pruner`](crate::pruner::Pruner) can decide whether to stop the trial + /// early. Steps should be monotonically increasing; duplicate steps + /// overwrite the previous value. /// - /// Steps should be monotonically increasing (e.g., epoch number). - /// Duplicate steps overwrite the previous value. + /// # Examples + /// + /// ``` + /// use optimizer::Trial; + /// + /// let mut trial = Trial::new(0); + /// for epoch in 0..10 { + /// let loss = 1.0 / (epoch as f64 + 1.0); + /// trial.report(epoch, loss); + /// } + /// assert_eq!(trial.intermediate_values().len(), 10); + /// ``` pub fn report(&mut self, step: u64, value: f64) { if let Some(entry) = self .intermediate_values @@ -256,8 +303,11 @@ impl Trial { /// Ask whether this trial should be pruned at the current step. /// - /// Returns `true` if the pruner recommends stopping this trial. - /// The caller should return `Err(TrialPruned)` from the objective. + /// Return `true` if the pruner recommends stopping this trial based on + /// the intermediate values reported so far. When `true`, the objective + /// should return early with `Err(TrialPruned)?`. + /// + /// Always returns `false` when no pruner is configured. #[must_use] pub fn should_prune(&self) -> bool { let (Some(pruner), Some(history)) = (&self.pruner, &self.history) else { @@ -274,59 +324,87 @@ impl Trial { prune } - /// Returns all intermediate values reported so far. + /// Return all intermediate values reported so far as `(step, value)` pairs. #[must_use] pub fn intermediate_values(&self) -> &[(u64, f64)] { &self.intermediate_values } - /// Sets a user attribute on this trial. + /// Set a user attribute on this trial. + /// + /// User attributes are arbitrary key-value pairs for logging, debugging, + /// or analysis. 
Values can be `f64`, `i64`, `String`, `&str`, or `bool` + /// (anything implementing `Into<AttrValue>`). + /// + /// # Examples + /// + /// ``` + /// use optimizer::Trial; + /// + /// let mut trial = Trial::new(0); + /// trial.set_user_attr("gpu", "A100"); + /// trial.set_user_attr("batch_size", 64_i64); + /// trial.set_user_attr("accuracy", 0.95); + /// ``` pub fn set_user_attr(&mut self, key: impl Into<String>, value: impl Into<AttrValue>) { self.user_attrs.insert(key.into(), value.into()); } - /// Gets a user attribute by key. + /// Return a user attribute by key, or `None` if it does not exist. #[must_use] pub fn user_attr(&self, key: &str) -> Option<&AttrValue> { self.user_attrs.get(key) } - /// Returns all user attributes. + /// Return all user attributes as a map. #[must_use] pub fn user_attrs(&self) -> &HashMap<String, AttrValue> { &self.user_attrs } - /// Sets constraint values for this trial. + /// Set constraint values for this trial. + /// + /// Each element represents one constraint. A value ≤ 0.0 means the + /// constraint is satisfied (feasible); a value > 0.0 means violated. + /// Constrained samplers (e.g., NSGA-II with constraints) use these values + /// to prefer feasible solutions. /// - /// Each value represents a constraint; a value <= 0.0 means the constraint - /// is satisfied (feasible). A value > 0.0 means the constraint is violated. + /// # Examples + /// + /// ``` + /// use optimizer::Trial; + /// + /// let mut trial = Trial::new(0); + /// // Two constraints: first satisfied, second violated + /// trial.set_constraints(vec![-0.5, 0.3]); + /// assert_eq!(trial.constraint_values(), &[-0.5, 0.3]); + /// ``` pub fn set_constraints(&mut self, values: Vec<f64>) { self.constraint_values = values; } - /// Returns the constraint values for this trial. + /// Return the constraint values for this trial. #[must_use] pub fn constraint_values(&self) -> &[f64] { &self.constraint_values } - /// Sets the trial state to Complete. 
+ /// Set the trial state to `Complete`. pub(crate) fn set_complete(&mut self) { self.state = TrialState::Complete; } - /// Sets the trial state to Failed. + /// Set the trial state to `Failed`. pub(crate) fn set_failed(&mut self) { self.state = TrialState::Failed; } - /// Sets the trial state to Pruned. + /// Set the trial state to `Pruned`. pub(crate) fn set_pruned(&mut self) { self.state = TrialState::Pruned; } - /// Suggests a parameter value using a [`Parameter`] definition. + /// Suggest a parameter value using a [`Parameter`] definition. /// /// This is the primary entry point for sampling parameters. It handles /// validation, caching, conflict detection, sampling, and conversion. diff --git a/src/visualization.rs b/src/visualization.rs index 3f1a69e..f607f25 100644 --- a/src/visualization.rs +++ b/src/visualization.rs @@ -1,7 +1,41 @@ //! HTML report generation for optimization visualization. //! -//! Generates self-contained HTML files with embedded Plotly.js charts -//! for offline visualization of optimization results. +//! Generate self-contained HTML files with embedded +//! [Plotly.js](https://plotly.com/javascript/) charts for offline +//! visualization of optimization results. No feature flag is required — +//! this module is always available. +//! +//! # Charts included +//! +//! | Chart | Description | +//! |---|---| +//! | **Optimization history** | Objective value vs trial number with best-so-far line | +//! | **Slice plots** | Objective value vs each parameter (1D scatter per param) | +//! | **Parallel coordinates** | Multi-parameter relationship view (color = objective) | +//! | **Parameter importance** | Horizontal bar chart of Spearman-based importance | +//! | **Trial timeline** | Duration/index of each trial, color-coded by state | +//! | **Intermediate values** | Per-trial learning curves (if pruning data available) | +//! +//! # Usage +//! +//! Call [`Study::export_html()`](crate::Study::export_html) or +//! 
[`generate_html_report()`] directly: +//! +//! ```no_run +//! use optimizer::prelude::*; +//! +//! let study: Study<f64> = Study::new(Direction::Minimize); +//! # let x = FloatParam::new(0.0, 1.0); +//! # study.optimize(10, |trial| { +//! # let v = x.suggest(trial)?; +//! # Ok::<_, optimizer::Error>(v * v) +//! # }).unwrap(); +//! study.export_html("report.html").unwrap(); +//! ``` +//! +//! The output is a single HTML file that can be opened in any browser. +//! An internet connection is needed on first load to fetch `Plotly.js` +//! from a CDN. use core::fmt::Write as _; use std::collections::BTreeMap; @@ -15,17 +49,19 @@ use crate::types::{Direction, TrialState}; /// Generate an HTML report with interactive Plotly.js charts. /// -/// Creates a self-contained HTML file at `path` containing: -/// - **Optimization history**: Objective value vs trial number with best-so-far line -/// - **Slice plots**: Objective value vs each parameter (1D scatter) -/// - **Parallel coordinates**: Multi-parameter relationship view -/// - **Trial timeline**: Duration index of each trial (horizontal bar) -/// - **Intermediate values**: Learning curves per trial (if pruning data available) -/// - **Parameter importance**: Bar chart (if enough completed trials) +/// Create a self-contained HTML file at `path` containing up to six +/// interactive charts. Charts that require data not present in the study +/// (e.g., intermediate values) are automatically omitted. +/// +/// The report includes: optimization history, slice plots, parallel +/// coordinates, parameter importance, trial timeline, and intermediate +/// values (when available). +/// +/// This is also available as [`Study::export_html()`](crate::Study::export_html). /// /// # Errors /// -/// Returns an I/O error if the file cannot be created or written. +/// Return an I/O error if the file cannot be created or written. 
pub fn generate_html_report( study: &crate::Study<f64>, path: impl AsRef<Path>, From cc0dce5daf034cc4f7d7f6bd23cfc3beecced615 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 10:13:48 +0100 Subject: [PATCH 02/48] ci: add doc tests and update examples to match current codebase --- .github/workflows/ci.yml | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a747dad..2150e89 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -51,6 +51,8 @@ jobs: - uses: Swatinem/rust-cache@v2 - name: Run tests run: cargo test --verbose --all-features --all-targets + - name: Run doc tests + run: cargo test --verbose --all-features --doc examples: name: Examples @@ -62,16 +64,16 @@ jobs: rustup override set stable rustup update stable - uses: Swatinem/rust-cache@v2 - - name: Run sync example (ml_hyperparameter_tuning) - run: cargo run --example ml_hyperparameter_tuning - - name: Run sync example (benchmark_convergence) - run: cargo run --example benchmark_convergence - - name: Run derive example (parameter_api) - run: cargo run --example parameter_api --features derive - - name: Run async example (async_api_optimization) - run: cargo run --example async_api_optimization --features async - - name: Run visualization example (visualization_demo) - run: cargo run --example visualization_demo + - name: Run basic_optimization + run: cargo run --example basic_optimization + - name: Run sampler_comparison + run: cargo run --example sampler_comparison + - name: Run pruning_and_callbacks + run: cargo run --example pruning_and_callbacks + - name: Run parameter_types + run: cargo run --example parameter_types --features derive + - name: Run advanced_features + run: cargo run --example advanced_features --features "async,journal" docs: name: Docs From 6a4b27c46f3584e90565463beb4df82de38166fd Mon Sep 17 00:00:00 2001 From: Manuel Raimann 
<raimannma@outlook.de> Date: Thu, 12 Feb 2026 10:28:05 +0100 Subject: [PATCH 03/48] refactor: move all type re-exports out of crate root into their modules - Sampler, pruner, storage, parameter, and multi_objective types are no longer re-exported at the crate root; access via module paths instead (e.g. optimizer::sampler::TpeSampler, optimizer::parameter::FloatParam) - Prelude continues to re-export everything for convenience - Add module-level pub use re-exports in sampler/mod.rs and parameter.rs - Update derive macro to reference optimizer::parameter::Categorical - Rename sampler::differential_evolution to sampler::de - Fix all downstream imports in tests, benches, and doctests - Fix all rustdoc intra-doc links to use explicit module paths --- benches/optimization.rs | 4 +- benches/samplers.rs | 7 +- optimizer-derive/src/lib.rs | 6 +- src/lib.rs | 97 ++++++------------- src/multi_objective.rs | 12 +-- src/parameter.rs | 2 +- src/pareto.rs | 4 +- src/pruner/hyperband.rs | 2 +- src/pruner/percentile.rs | 2 +- src/pruner/successive_halving.rs | 2 +- .../{differential_evolution.rs => de.rs} | 10 +- src/sampler/mod.rs | 18 +++- src/study.rs | 4 +- tests/differential_evolution_tests.rs | 4 +- tests/integration.rs | 6 +- tests/multi_objective_tests.rs | 3 +- tests/serde_tests.rs | 4 +- 17 files changed, 82 insertions(+), 105 deletions(-) rename src/sampler/{differential_evolution.rs => de.rs} (98%) diff --git a/benches/optimization.rs b/benches/optimization.rs index c4f58dc..cfdaafe 100644 --- a/benches/optimization.rs +++ b/benches/optimization.rs @@ -2,10 +2,10 @@ mod test_functions; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; -use optimizer::parameter::Parameter; +use optimizer::Study; +use optimizer::parameter::{FloatParam, Parameter}; use optimizer::sampler::random::RandomSampler; use optimizer::sampler::tpe::TpeSampler; -use optimizer::{FloatParam, Study}; fn make_params(dims: usize) -> Vec<FloatParam> { (0..dims) diff --git 
a/benches/samplers.rs b/benches/samplers.rs index f2e9ec3..79da8a6 100644 --- a/benches/samplers.rs +++ b/benches/samplers.rs @@ -1,12 +1,11 @@ use std::collections::HashMap; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; -use optimizer::parameter::Parameter; -use optimizer::sampler::Sampler; +use optimizer::parameter::{FloatParam, Parameter}; use optimizer::sampler::grid::GridSearchSampler; use optimizer::sampler::random::RandomSampler; use optimizer::sampler::tpe::TpeSampler; -use optimizer::{CompletedTrial, FloatParam}; +use optimizer::sampler::{CompletedTrial, Sampler}; /// Build a synthetic history of `n` completed trials over `dims` float parameters. fn build_history(n: usize, dims: usize) -> Vec<CompletedTrial<f64>> { @@ -33,7 +32,7 @@ fn build_history(n: usize, dims: usize) -> Vec<CompletedTrial<f64>> { let value: f64 = param_values .values() .map(|v| { - let optimizer::ParamValue::Float(f) = v else { + let optimizer::parameter::ParamValue::Float(f) = v else { unreachable!() }; f * f diff --git a/optimizer-derive/src/lib.rs b/optimizer-derive/src/lib.rs index 4d019f9..3e30885 100644 --- a/optimizer-derive/src/lib.rs +++ b/optimizer-derive/src/lib.rs @@ -4,13 +4,13 @@ use syn::{Data, DeriveInput, Fields, parse_macro_input}; /// Derive macro for the `Categorical` trait on fieldless enums. /// -/// Generates an implementation of `optimizer::Categorical` that maps +/// Generates an implementation of `optimizer::parameter::Categorical` that maps /// enum variants to/from sequential indices. /// /// # Example /// /// ```ignore -/// use optimizer::Categorical; +/// use optimizer::parameter::Categorical; /// /// #[derive(Clone, Categorical)] /// enum Color { @@ -49,7 +49,7 @@ pub fn derive_categorical(input: TokenStream) -> TokenStream { let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); let expanded = quote! 
{ - impl #impl_generics optimizer::Categorical for #name #ty_generics #where_clause { + impl #impl_generics optimizer::parameter::Categorical for #name #ty_generics #where_clause { const N_CHOICES: usize = #n_choices; fn from_index(index: usize) -> Self { diff --git a/src/lib.rs b/src/lib.rs index 8df4aec..0636f54 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -42,7 +42,7 @@ //! |------|------| //! | [`Study`] | Drive an optimization loop: create trials, record results, track the best. | //! | [`Trial`] | A single evaluation of the objective function, carrying suggested parameter values. | -//! | [`Parameter`] | Define the search space — [`FloatParam`], [`IntParam`], [`CategoricalParam`], [`BoolParam`], [`EnumParam`]. | +//! | [`Parameter`](parameter::Parameter) | Define the search space — [`FloatParam`](parameter::FloatParam), [`IntParam`](parameter::IntParam), [`CategoricalParam`](parameter::CategoricalParam), [`BoolParam`](parameter::BoolParam), [`EnumParam`](parameter::EnumParam). | //! | [`Sampler`](sampler::Sampler) | Strategy for choosing the next point to evaluate (TPE, CMA-ES, random, etc.). | //! | [`Direction`] | Whether the study minimizes or maximizes the objective value. | //! @@ -52,23 +52,23 @@ //! //! | Sampler | Algorithm | Best for | Feature flag | //! |---------|-----------|----------|--------------| -//! | [`RandomSampler`] | Uniform random | Baselines, high-dimensional | — | -//! | [`TpeSampler`] | Tree-Parzen Estimator | General-purpose Bayesian | — | -//! | [`GridSearchSampler`] | Exhaustive grid | Small, discrete spaces | — | -//! | [`SobolSampler`] | Sobol quasi-random sequence | Space-filling, low dimensions | `sobol` | -//! | [`CmaEsSampler`] | CMA-ES | Continuous, moderate dimensions | `cma-es` | -//! | [`GpSampler`] | Gaussian Process + EI | Expensive objectives, few trials | `gp` | -//! | [`DifferentialEvolutionSampler`] | Differential Evolution | Non-convex, population-based | — | -//! 
| [`BohbSampler`] | BOHB (TPE + `HyperBand`) | Budget-aware early stopping | — | +//! | [`RandomSampler`](sampler::RandomSampler) | Uniform random | Baselines, high-dimensional | — | +//! | [`TpeSampler`](sampler::TpeSampler) | Tree-Parzen Estimator | General-purpose Bayesian | — | +//! | [`GridSearchSampler`](sampler::GridSearchSampler) | Exhaustive grid | Small, discrete spaces | — | +//! | [`SobolSampler`](sampler::SobolSampler) | Sobol quasi-random sequence | Space-filling, low dimensions | `sobol` | +//! | [`CmaEsSampler`](sampler::CmaEsSampler) | CMA-ES | Continuous, moderate dimensions | `cma-es` | +//! | [`GpSampler`](sampler::GpSampler) | Gaussian Process + EI | Expensive objectives, few trials | `gp` | +//! | [`DifferentialEvolutionSampler`](sampler::DifferentialEvolutionSampler) | Differential Evolution | Non-convex, population-based | — | +//! | [`BohbSampler`](sampler::BohbSampler) | BOHB (TPE + `HyperBand`) | Budget-aware early stopping | — | //! //! ## Multi-objective samplers //! //! | Sampler | Algorithm | Best for | Feature flag | //! |---------|-----------|----------|--------------| -//! | [`Nsga2Sampler`] | NSGA-II | 2-3 objectives | — | -//! | [`Nsga3Sampler`] | NSGA-III (reference-point) | 3+ objectives | — | -//! | [`MoeadSampler`] | MOEA/D (decomposition) | Many objectives, structured fronts | — | -//! | [`MotpeSampler`] | Multi-Objective TPE | Bayesian multi-objective | — | +//! | [`Nsga2Sampler`](sampler::Nsga2Sampler) | NSGA-II | 2-3 objectives | — | +//! | [`Nsga3Sampler`](sampler::Nsga3Sampler) | NSGA-III (reference-point) | 3+ objectives | — | +//! | [`MoeadSampler`](sampler::MoeadSampler) | MOEA/D (decomposition) | Many objectives, structured fronts | — | +//! | [`MotpeSampler`](sampler::MotpeSampler) | Multi-Objective TPE | Bayesian multi-objective | — | //! //! # Feature Flags //! @@ -77,10 +77,10 @@ //! | `async` | Async/parallel optimization via tokio ([`Study::optimize_async`], [`Study::optimize_parallel`]) | off | //! 
| `derive` | `#[derive(Categorical)]` for enum parameters | off | //! | `serde` | `Serialize`/`Deserialize` on public types, [`Study::save`]/[`Study::load`] | off | -//! | `journal` | [`JournalStorage`] — JSONL persistence with file locking (enables `serde`) | off | -//! | `sobol` | [`SobolSampler`] — quasi-random low-discrepancy sequences | off | -//! | `cma-es` | [`CmaEsSampler`] — Covariance Matrix Adaptation Evolution Strategy | off | -//! | `gp` | [`GpSampler`] — Gaussian Process surrogate with Expected Improvement | off | +//! | `journal` | [`JournalStorage`](storage::JournalStorage) — JSONL persistence with file locking (enables `serde`) | off | +//! | `sobol` | [`SobolSampler`](sampler::SobolSampler) — quasi-random low-discrepancy sequences | off | +//! | `cma-es` | [`CmaEsSampler`](sampler::CmaEsSampler) — Covariance Matrix Adaptation Evolution Strategy | off | +//! | `gp` | [`GpSampler`](sampler::GpSampler) — Gaussian Process surrogate with Expected Improvement | off | //! | `tracing` | Structured log events via [`tracing`](https://docs.rs/tracing) at key optimization points | off | /// Emit a `tracing::info!` event when the `tracing` feature is enabled. 
@@ -127,38 +127,8 @@ mod visualization; pub use error::{Error, Result, TrialPruned}; pub use fanova::{FanovaConfig, FanovaResult}; -pub use multi_objective::{MultiObjectiveSampler, MultiObjectiveStudy, MultiObjectiveTrial}; #[cfg(feature = "derive")] pub use optimizer_derive::Categorical; -pub use param::ParamValue; -pub use parameter::{ - BoolParam, Categorical, CategoricalParam, EnumParam, FloatParam, IntParam, ParamId, Parameter, -}; -pub use pruner::{ - HyperbandPruner, MedianPruner, NopPruner, PatientPruner, PercentilePruner, Pruner, - SuccessiveHalvingPruner, ThresholdPruner, WilcoxonPruner, -}; -pub use sampler::CompletedTrial; -pub use sampler::bohb::BohbSampler; -#[cfg(feature = "cma-es")] -pub use sampler::cma_es::CmaEsSampler; -pub use sampler::differential_evolution::{ - DifferentialEvolutionSampler, DifferentialEvolutionStrategy, -}; -#[cfg(feature = "gp")] -pub use sampler::gp::GpSampler; -pub use sampler::grid::GridSearchSampler; -pub use sampler::moead::{Decomposition, MoeadSampler}; -pub use sampler::motpe::MotpeSampler; -pub use sampler::nsga2::Nsga2Sampler; -pub use sampler::nsga3::Nsga3Sampler; -pub use sampler::random::RandomSampler; -#[cfg(feature = "sobol")] -pub use sampler::sobol::SobolSampler; -pub use sampler::tpe::TpeSampler; -#[cfg(feature = "journal")] -pub use storage::JournalStorage; -pub use storage::{MemoryStorage, Storage}; #[cfg(feature = "serde")] pub use study::StudySnapshot; pub use study::{Study, StudyBuilder}; @@ -177,33 +147,28 @@ pub mod prelude { pub use crate::error::{Error, Result, TrialPruned}; pub use crate::fanova::{FanovaConfig, FanovaResult}; - pub use crate::multi_objective::{MultiObjectiveStudy, MultiObjectiveTrial}; - pub use crate::param::ParamValue; + pub use crate::multi_objective::{ + MultiObjectiveSampler, MultiObjectiveStudy, MultiObjectiveTrial, + }; pub use crate::parameter::{ - BoolParam, Categorical, CategoricalParam, EnumParam, FloatParam, IntParam, Parameter, + BoolParam, Categorical, 
CategoricalParam, EnumParam, FloatParam, IntParam, ParamValue, + Parameter, }; pub use crate::pruner::{ HyperbandPruner, MedianPruner, NopPruner, PatientPruner, PercentilePruner, Pruner, - SuccessiveHalvingPruner, ThresholdPruner, + SuccessiveHalvingPruner, ThresholdPruner, WilcoxonPruner, }; - pub use crate::sampler::CompletedTrial; - pub use crate::sampler::bohb::BohbSampler; #[cfg(feature = "cma-es")] - pub use crate::sampler::cma_es::CmaEsSampler; - pub use crate::sampler::differential_evolution::{ - DifferentialEvolutionSampler, DifferentialEvolutionStrategy, - }; + pub use crate::sampler::CmaEsSampler; #[cfg(feature = "gp")] - pub use crate::sampler::gp::GpSampler; - pub use crate::sampler::grid::GridSearchSampler; - pub use crate::sampler::moead::{Decomposition, MoeadSampler}; - pub use crate::sampler::motpe::MotpeSampler; - pub use crate::sampler::nsga2::Nsga2Sampler; - pub use crate::sampler::nsga3::Nsga3Sampler; - pub use crate::sampler::random::RandomSampler; + pub use crate::sampler::GpSampler; #[cfg(feature = "sobol")] - pub use crate::sampler::sobol::SobolSampler; - pub use crate::sampler::tpe::TpeSampler; + pub use crate::sampler::SobolSampler; + pub use crate::sampler::{ + BohbSampler, CompletedTrial, Decomposition, DifferentialEvolutionSampler, + DifferentialEvolutionStrategy, GridSearchSampler, MoeadSampler, MotpeSampler, Nsga2Sampler, + Nsga3Sampler, RandomSampler, TpeSampler, + }; #[cfg(feature = "journal")] pub use crate::storage::JournalStorage; pub use crate::storage::{MemoryStorage, Storage}; diff --git a/src/multi_objective.rs b/src/multi_objective.rs index d47ce37..8dd144a 100644 --- a/src/multi_objective.rs +++ b/src/multi_objective.rs @@ -19,9 +19,9 @@ //! # Samplers //! //! By default a random sampler is used. For smarter search, pass a -//! [`MultiObjectiveSampler`] such as [`Nsga2Sampler`](crate::Nsga2Sampler), -//! [`Nsga3Sampler`](crate::Nsga3Sampler), or -//! [`MoeadSampler`](crate::MoeadSampler) via +//! 
[`MultiObjectiveSampler`] such as [`Nsga2Sampler`](crate::sampler::Nsga2Sampler), +//! [`Nsga3Sampler`](crate::sampler::Nsga3Sampler), or +//! [`MoeadSampler`](crate::sampler::MoeadSampler) via //! [`MultiObjectiveStudy::with_sampler`]. //! //! # Examples @@ -140,9 +140,9 @@ impl MultiObjectiveTrial { /// (`&[MultiObjectiveTrial]`) and the per-objective directions /// (`&[Direction]`). /// -/// Implementations include [`Nsga2Sampler`](crate::Nsga2Sampler), -/// [`Nsga3Sampler`](crate::Nsga3Sampler), and -/// [`MoeadSampler`](crate::MoeadSampler). +/// Implementations include [`Nsga2Sampler`](crate::sampler::Nsga2Sampler), +/// [`Nsga3Sampler`](crate::sampler::Nsga3Sampler), and +/// [`MoeadSampler`](crate::sampler::MoeadSampler). pub trait MultiObjectiveSampler: Send + Sync { /// Samples a parameter value from the given distribution. fn sample( diff --git a/src/parameter.rs b/src/parameter.rs index 6c48931..c266034 100644 --- a/src/parameter.rs +++ b/src/parameter.rs @@ -46,7 +46,7 @@ use crate::distribution::{ CategoricalDistribution, Distribution, FloatDistribution, IntDistribution, }; use crate::error::{Error, Result}; -use crate::param::ParamValue; +pub use crate::param::ParamValue; use crate::trial::Trial; static NEXT_PARAM_ID: AtomicU64 = AtomicU64::new(0); diff --git a/src/pareto.rs b/src/pareto.rs index c6a23c9..a7b321d 100644 --- a/src/pareto.rs +++ b/src/pareto.rs @@ -29,8 +29,8 @@ //! //! Internally, this module also provides the fast non-dominated sorting //! algorithm (Deb et al., 2002) used by -//! [`MultiObjectiveStudy::pareto_front()`](crate::MultiObjectiveStudy::pareto_front) -//! and [`Nsga2Sampler`](crate::Nsga2Sampler). +//! [`MultiObjectiveStudy::pareto_front()`](crate::multi_objective::MultiObjectiveStudy::pareto_front) +//! and [`Nsga2Sampler`](crate::sampler::Nsga2Sampler). //! //! # Example //! 
diff --git a/src/pruner/hyperband.rs b/src/pruner/hyperband.rs index ac82acf..697ea6c 100644 --- a/src/pruner/hyperband.rs +++ b/src/pruner/hyperband.rs @@ -317,7 +317,7 @@ mod tests { CompletedTrial::with_intermediate_values( id, - HashMap::<ParamId, crate::ParamValue>::new(), + HashMap::<ParamId, crate::parameter::ParamValue>::new(), HashMap::new(), HashMap::new(), 0.0, diff --git a/src/pruner/percentile.rs b/src/pruner/percentile.rs index 0495b6b..f12df01 100644 --- a/src/pruner/percentile.rs +++ b/src/pruner/percentile.rs @@ -233,7 +233,7 @@ mod tests { CompletedTrial::with_intermediate_values( id, - HashMap::<ParamId, crate::ParamValue>::new(), + HashMap::<ParamId, crate::parameter::ParamValue>::new(), HashMap::new(), HashMap::new(), 0.0, diff --git a/src/pruner/successive_halving.rs b/src/pruner/successive_halving.rs index 0408101..64d8791 100644 --- a/src/pruner/successive_halving.rs +++ b/src/pruner/successive_halving.rs @@ -285,7 +285,7 @@ mod tests { CompletedTrial::with_intermediate_values( id, - HashMap::<ParamId, crate::ParamValue>::new(), + HashMap::<ParamId, crate::parameter::ParamValue>::new(), HashMap::new(), HashMap::new(), 0.0, diff --git a/src/sampler/differential_evolution.rs b/src/sampler/de.rs similarity index 98% rename from src/sampler/differential_evolution.rs rename to src/sampler/de.rs index 3a11e55..7641084 100644 --- a/src/sampler/differential_evolution.rs +++ b/src/sampler/de.rs @@ -46,9 +46,7 @@ //! # Examples //! //! ``` -//! use optimizer::sampler::differential_evolution::{ -//! DifferentialEvolutionSampler, DifferentialEvolutionStrategy, -//! }; +//! use optimizer::sampler::de::{DifferentialEvolutionSampler, DifferentialEvolutionStrategy}; //! use optimizer::{Direction, Study}; //! //! 
// Minimize with DE using the Best1 strategy for faster convergence @@ -101,7 +99,7 @@ pub enum DifferentialEvolutionStrategy { /// # Examples /// /// ``` -/// use optimizer::sampler::differential_evolution::DifferentialEvolutionSampler; +/// use optimizer::sampler::de::DifferentialEvolutionSampler; /// use optimizer::{Direction, Study}; /// /// // Default configuration @@ -115,7 +113,7 @@ pub enum DifferentialEvolutionStrategy { /// ); /// /// // Custom configuration via builder -/// use optimizer::sampler::differential_evolution::DifferentialEvolutionStrategy; +/// use optimizer::sampler::de::DifferentialEvolutionStrategy; /// let sampler = DifferentialEvolutionSampler::builder() /// .mutation_factor(0.8) /// .crossover_rate(0.9) @@ -183,7 +181,7 @@ impl Default for DifferentialEvolutionSampler { /// # Examples /// /// ``` -/// use optimizer::sampler::differential_evolution::{ +/// use optimizer::sampler::de::{ /// DifferentialEvolutionSamplerBuilder, DifferentialEvolutionStrategy, /// }; /// diff --git a/src/sampler/mod.rs b/src/sampler/mod.rs index e1813e0..b512c49 100644 --- a/src/sampler/mod.rs +++ b/src/sampler/mod.rs @@ -3,7 +3,7 @@ pub mod bohb; #[cfg(feature = "cma-es")] pub mod cma_es; -pub mod differential_evolution; +pub mod de; pub(crate) mod genetic; #[cfg(feature = "gp")] pub mod gp; @@ -19,6 +19,22 @@ pub mod tpe; use std::collections::HashMap; +pub use bohb::BohbSampler; +#[cfg(feature = "cma-es")] +pub use cma_es::CmaEsSampler; +pub use de::{DifferentialEvolutionSampler, DifferentialEvolutionStrategy}; +#[cfg(feature = "gp")] +pub use gp::GpSampler; +pub use grid::GridSearchSampler; +pub use moead::{Decomposition, MoeadSampler}; +pub use motpe::MotpeSampler; +pub use nsga2::Nsga2Sampler; +pub use nsga3::Nsga3Sampler; +pub use random::RandomSampler; +#[cfg(feature = "sobol")] +pub use sobol::SobolSampler; +pub use tpe::TpeSampler; + use crate::distribution::Distribution; use crate::param::ParamValue; use crate::parameter::{ParamId, Parameter}; 
diff --git a/src/study.rs b/src/study.rs index e8effcf..3827942 100644 --- a/src/study.rs +++ b/src/study.rs @@ -382,8 +382,8 @@ where /// ``` /// use std::collections::HashMap; /// - /// use optimizer::parameter::{FloatParam, IntParam, Parameter}; - /// use optimizer::{Direction, ParamValue, Study}; + /// use optimizer::parameter::{FloatParam, IntParam, ParamValue, Parameter}; + /// use optimizer::{Direction, Study}; /// /// let study: Study<f64> = Study::new(Direction::Minimize); /// let x = FloatParam::new(0.0, 10.0); diff --git a/tests/differential_evolution_tests.rs b/tests/differential_evolution_tests.rs index 75328bc..eb7a927 100644 --- a/tests/differential_evolution_tests.rs +++ b/tests/differential_evolution_tests.rs @@ -1,7 +1,5 @@ use optimizer::prelude::*; -use optimizer::sampler::differential_evolution::{ - DifferentialEvolutionSampler, DifferentialEvolutionStrategy, -}; +use optimizer::sampler::de::{DifferentialEvolutionSampler, DifferentialEvolutionStrategy}; #[test] fn sphere_function() { diff --git a/tests/integration.rs b/tests/integration.rs index 9e8eb76..b895606 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -1707,7 +1707,7 @@ fn test_ask_and_tell_with_custom_value_type() { use std::collections::HashMap; -use optimizer::ParamValue; +use optimizer::parameter::ParamValue; #[test] fn test_enqueue_params_evaluated_first() { @@ -2291,7 +2291,7 @@ fn test_builder_with_sampler() { #[test] fn test_builder_with_pruner() { - use optimizer::NopPruner; + use optimizer::pruner::NopPruner; let study: Study<f64> = Study::builder().pruner(NopPruner).build(); @@ -2303,7 +2303,7 @@ fn test_builder_chaining() { let study: Study<f64> = Study::builder() .maximize() .sampler(RandomSampler::with_seed(42)) - .pruner(optimizer::NopPruner) + .pruner(optimizer::pruner::NopPruner) .build(); assert_eq!(study.direction(), Direction::Maximize); diff --git a/tests/multi_objective_tests.rs b/tests/multi_objective_tests.rs index ec5e200..dc30af9 100644 --- 
a/tests/multi_objective_tests.rs +++ b/tests/multi_objective_tests.rs @@ -1,11 +1,12 @@ //! Integration tests for multi-objective optimization. +use optimizer::Direction; use optimizer::multi_objective::MultiObjectiveStudy; use optimizer::parameter::{CategoricalParam, FloatParam, Parameter}; +use optimizer::sampler::Decomposition; use optimizer::sampler::moead::MoeadSampler; use optimizer::sampler::nsga2::Nsga2Sampler; use optimizer::sampler::nsga3::Nsga3Sampler; -use optimizer::{Decomposition, Direction}; // --------------------------------------------------------------------------- // Pareto utility tests (via public MultiObjectiveStudy) diff --git a/tests/serde_tests.rs b/tests/serde_tests.rs index 66aa745..202bc19 100644 --- a/tests/serde_tests.rs +++ b/tests/serde_tests.rs @@ -2,9 +2,9 @@ use std::collections::HashMap; -use optimizer::parameter::{FloatParam, IntParam, Parameter}; +use optimizer::parameter::{FloatParam, IntParam, ParamValue, Parameter}; use optimizer::sampler::CompletedTrial; -use optimizer::{Direction, ParamValue, Study, StudySnapshot, TrialState}; +use optimizer::{Direction, Study, StudySnapshot, TrialState}; #[test] fn round_trip_save_load() { From ee59c9cdd01d07da4dd0dc80ab681bd989cb723f Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 10:33:51 +0100 Subject: [PATCH 04/48] refactor(examples): split multi-concept examples into focused single-topic files - Split pruning_and_callbacks into pruning and early_stopping - Split advanced_features into async_parallel, journal_storage, ask_and_tell, multi_objective - Each example now requires only its own feature flag - Trim sampler_comparison winner logic and verbose header - Update CI workflow and README to match new example names --- .github/workflows/ci.yml | 20 ++- Cargo.toml | 27 ++- README.md | 14 +- examples/advanced_features.rs | 277 ------------------------------ examples/ask_and_tell.rs | 51 ++++++ examples/async_parallel.rs | 52 ++++++ 
examples/early_stopping.rs | 43 +++++ examples/journal_storage.rs | 68 ++++++++ examples/multi_objective.rs | 49 ++++++ examples/pruning.rs | 67 ++++++++ examples/pruning_and_callbacks.rs | 133 -------------- examples/sampler_comparison.rs | 18 +- 12 files changed, 376 insertions(+), 443 deletions(-) delete mode 100644 examples/advanced_features.rs create mode 100644 examples/ask_and_tell.rs create mode 100644 examples/async_parallel.rs create mode 100644 examples/early_stopping.rs create mode 100644 examples/journal_storage.rs create mode 100644 examples/multi_objective.rs create mode 100644 examples/pruning.rs delete mode 100644 examples/pruning_and_callbacks.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2150e89..1f44f41 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -66,14 +66,22 @@ jobs: - uses: Swatinem/rust-cache@v2 - name: Run basic_optimization run: cargo run --example basic_optimization - - name: Run sampler_comparison - run: cargo run --example sampler_comparison - - name: Run pruning_and_callbacks - run: cargo run --example pruning_and_callbacks - name: Run parameter_types run: cargo run --example parameter_types --features derive - - name: Run advanced_features - run: cargo run --example advanced_features --features "async,journal" + - name: Run sampler_comparison + run: cargo run --example sampler_comparison + - name: Run pruning + run: cargo run --example pruning + - name: Run early_stopping + run: cargo run --example early_stopping + - name: Run async_parallel + run: cargo run --example async_parallel --features async + - name: Run journal_storage + run: cargo run --example journal_storage --features journal + - name: Run ask_and_tell + run: cargo run --example ask_and_tell + - name: Run multi_objective + run: cargo run --example multi_objective docs: name: Docs diff --git a/Cargo.toml b/Cargo.toml index d1c11d2..93d67ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -67,13 +67,30 @@ name = 
"sampler_comparison" path = "examples/sampler_comparison.rs" [[example]] -name = "pruning_and_callbacks" -path = "examples/pruning_and_callbacks.rs" +name = "pruning" +path = "examples/pruning.rs" [[example]] -name = "advanced_features" -path = "examples/advanced_features.rs" -required-features = ["async", "journal"] +name = "early_stopping" +path = "examples/early_stopping.rs" + +[[example]] +name = "async_parallel" +path = "examples/async_parallel.rs" +required-features = ["async"] + +[[example]] +name = "journal_storage" +path = "examples/journal_storage.rs" +required-features = ["journal"] + +[[example]] +name = "ask_and_tell" +path = "examples/ask_and_tell.rs" + +[[example]] +name = "multi_objective" +path = "examples/multi_objective.rs" [[test]] name = "journal_tests" diff --git a/README.md b/README.md index 970013c..9e31556 100644 --- a/README.md +++ b/README.md @@ -52,11 +52,15 @@ println!("Best x = {:.4}, f(x) = {:.4}", best.get(&x).unwrap(), best.value); ## Examples ```sh -cargo run --example basic_optimization # Minimize a quadratic — simplest possible usage -cargo run --example sampler_comparison # Compare Random, TPE, and Grid on the same problem -cargo run --example pruning_and_callbacks # Trial pruning with MedianPruner + early stopping -cargo run --example parameter_types --features derive # All 5 param types + #[derive(Categorical)] -cargo run --example advanced_features --features async,journal # Async, journal storage, ask-and-tell, multi-objective +cargo run --example basic_optimization # Minimize a quadratic — simplest possible usage +cargo run --example parameter_types --features derive # All 5 param types + #[derive(Categorical)] +cargo run --example sampler_comparison # Compare Random, TPE, and Grid on the same problem +cargo run --example pruning # Trial pruning with MedianPruner +cargo run --example early_stopping # Halt a study when a target is reached +cargo run --example async_parallel --features async # Evaluate trials concurrently 
with tokio +cargo run --example journal_storage --features journal # Persist trials to disk and resume later +cargo run --example ask_and_tell # Decouple sampling from evaluation +cargo run --example multi_objective # Optimize competing objectives + Pareto front ``` ## Learn More diff --git a/examples/advanced_features.rs b/examples/advanced_features.rs deleted file mode 100644 index 41eed01..0000000 --- a/examples/advanced_features.rs +++ /dev/null @@ -1,277 +0,0 @@ -//! Advanced Features Example -//! -//! This example demonstrates four advanced capabilities of the optimizer crate: -//! -//! 1. **Async parallel optimization** — evaluate multiple trials concurrently -//! 2. **Journal storage** — persist trials to disk and resume studies later -//! 3. **Ask-and-tell interface** — decouple sampling from evaluation -//! 4. **Multi-objective optimization** — optimize competing objectives simultaneously -//! -//! Run with: `cargo run --example advanced_features --features "async,journal"` - -use std::time::Instant; - -use optimizer::multi_objective::MultiObjectiveStudy; -use optimizer::prelude::*; - -// ============================================================================ -// Section 1: Async Parallel Optimization -// ============================================================================ - -/// Runs multiple trials concurrently using tokio, reducing wall-clock time -/// when the objective function involves I/O or other async work. 
-async fn async_parallel_optimization() -> optimizer::Result<()> { - println!("=== Section 1: Async Parallel Optimization ===\n"); - - let sampler = TpeSampler::builder() - .n_startup_trials(5) - .seed(42) - .build() - .expect("Failed to build TPE sampler"); - - let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - - let x = FloatParam::new(-5.0, 5.0).name("x"); - let y = FloatParam::new(-5.0, 5.0).name("y"); - - let n_trials = 30; - let concurrency = 4; - - println!("Running {n_trials} trials with {concurrency} concurrent workers..."); - let start = Instant::now(); - - // optimize_parallel spawns up to `concurrency` trials at once. - // The closure must take ownership of Trial and return (Trial, value). - study - .optimize_parallel(n_trials, concurrency, { - let x = x.clone(); - let y = y.clone(); - move |mut trial| { - let x = x.clone(); - let y = y.clone(); - async move { - let xv = x.suggest(&mut trial)?; - let yv = y.suggest(&mut trial)?; - - // Simulate async I/O (e.g., calling an external service) - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - - // Sphere function: minimum at origin - let value = xv * xv + yv * yv; - Ok::<_, optimizer::Error>((trial, value)) - } - } - }) - .await?; - - let elapsed = start.elapsed(); - let best = study.best_trial()?; - - println!( - "Completed in {elapsed:.2?} (vs ~{:.0?} sequential)", - std::time::Duration::from_millis(10 * n_trials as u64) - ); - println!( - "Best: f({:.3}, {:.3}) = {:.6}\n", - best.get(&x).unwrap(), - best.get(&y).unwrap(), - best.value - ); - - Ok(()) -} - -// ============================================================================ -// Section 2: Journal Storage -// ============================================================================ - -/// Persists trials to a JSONL file so that a study can be resumed later. -/// Useful for long-running experiments or crash recovery. 
-fn journal_storage_demo() -> optimizer::Result<()> { - println!("=== Section 2: Journal Storage ===\n"); - - let path = std::env::temp_dir().join("optimizer_advanced_example.jsonl"); - - // Clean up from any previous run - let _ = std::fs::remove_file(&path); - - let x = FloatParam::new(-5.0, 5.0).name("x"); - - // --- First run: optimize 20 trials and persist to disk --- - { - let storage = JournalStorage::<f64>::new(&path); - let study: Study<f64> = Study::builder() - .minimize() - .sampler(TpeSampler::new()) - .storage(storage) - .build(); - - study.optimize(20, |trial| { - let xv = x.suggest(trial)?; - Ok::<_, optimizer::Error>(xv * xv) - })?; - - println!( - "First run: {} trials saved to {}", - study.n_trials(), - path.display() - ); - } - - // --- Second run: resume from the journal file --- - { - // JournalStorage::open loads existing trials from disk - let storage = JournalStorage::<f64>::open(&path)?; - let study: Study<f64> = Study::builder() - .minimize() - .sampler(TpeSampler::new()) - .storage(storage) - .build(); - - // The sampler sees the prior 20 trials, so it starts informed - let before = study.n_trials(); - study.optimize(10, |trial| { - let xv = x.suggest(trial)?; - Ok::<_, optimizer::Error>(xv * xv) - })?; - - let best = study.best_trial()?; - println!( - "Resumed: {} → {} trials, best f({:.4}) = {:.6}", - before, - study.n_trials(), - best.get(&x).unwrap(), - best.value - ); - } - - // Clean up the temporary file - let _ = std::fs::remove_file(&path); - - println!(); - Ok(()) -} - -// ============================================================================ -// Section 3: Ask-and-Tell Interface -// ============================================================================ - -/// Decouples trial creation from evaluation. 
Useful when: -/// - Evaluations happen outside the optimizer (e.g., in a separate process) -/// - You want to batch evaluations before reporting results -/// - You need custom scheduling logic -fn ask_and_tell_demo() -> optimizer::Result<()> { - println!("=== Section 3: Ask-and-Tell Interface ===\n"); - - let study: Study<f64> = Study::new(Direction::Minimize); - - let x = FloatParam::new(-5.0, 5.0).name("x"); - let y = FloatParam::new(-5.0, 5.0).name("y"); - - // Ask for a batch of trials, evaluate externally, then tell results - for batch in 0..3 { - let batch_size = 5; - let mut trials = Vec::with_capacity(batch_size); - - // ask() creates trials with sampled parameters - for _ in 0..batch_size { - let mut trial = study.ask(); - let xv = x.suggest(&mut trial)?; - let yv = y.suggest(&mut trial)?; - - // Store values alongside the trial for later evaluation - trials.push((trial, xv, yv)); - } - - // Evaluate the batch (could be sent to workers, GPUs, etc.) - for (trial, xv, yv) in trials { - let value = xv * xv + yv * yv; - // tell() reports the result back to the study - study.tell(trial, Ok::<_, &str>(value)); - } - - println!( - "Batch {}: evaluated {} trials (total: {})", - batch + 1, - batch_size, - study.n_trials() - ); - } - - let best = study.best_trial()?; - println!( - "Best: f({:.3}, {:.3}) = {:.6}\n", - best.get(&x).unwrap(), - best.get(&y).unwrap(), - best.value - ); - - Ok(()) -} - -// ============================================================================ -// Section 4: Multi-Objective Optimization -// ============================================================================ - -/// Optimizes two competing objectives simultaneously. -/// Returns the Pareto front — the set of solutions where no objective can -/// be improved without worsening the other. 
-fn multi_objective_demo() -> optimizer::Result<()> { - println!("=== Section 4: Multi-Objective Optimization ===\n"); - - // Two objectives, both minimized - let study = MultiObjectiveStudy::new(vec![Direction::Minimize, Direction::Minimize]); - - let x = FloatParam::new(0.0, 1.0).name("x"); - - // Classic bi-objective problem: f1(x) = x², f2(x) = (x - 1)² - // The Pareto front is the curve where improving f1 worsens f2 and vice versa. - study.optimize(50, |trial| { - let xv = x.suggest(trial)?; - let f1 = xv * xv; - let f2 = (xv - 1.0) * (xv - 1.0); - Ok::<_, optimizer::Error>(vec![f1, f2]) - })?; - - let front = study.pareto_front(); - println!( - "Ran {} trials, Pareto front has {} solutions:", - study.n_trials(), - front.len() - ); - - // Show a few Pareto-optimal trade-offs - let mut sorted_front = front.clone(); - sorted_front.sort_by(|a, b| a.values[0].partial_cmp(&b.values[0]).unwrap()); - - for (i, trial) in sorted_front.iter().take(5).enumerate() { - println!( - " {}: x={:.3}, f1={:.4}, f2={:.4}", - i + 1, - trial.get(&x).unwrap(), - trial.values[0], - trial.values[1] - ); - } - if sorted_front.len() > 5 { - println!(" ... and {} more", sorted_front.len() - 5); - } - - println!(); - Ok(()) -} - -// ============================================================================ -// Main -// ============================================================================ - -#[tokio::main] -async fn main() -> optimizer::Result<()> { - async_parallel_optimization().await?; - journal_storage_demo()?; - ask_and_tell_demo()?; - multi_objective_demo()?; - - println!("All sections completed successfully!"); - Ok(()) -} diff --git a/examples/ask_and_tell.rs b/examples/ask_and_tell.rs new file mode 100644 index 0000000..0125c0d --- /dev/null +++ b/examples/ask_and_tell.rs @@ -0,0 +1,51 @@ +//! Ask-and-tell interface — decouple sampling from evaluation. +//! +//! Use `ask()` to get a trial with sampled parameters, evaluate it however +//! 
you like (workers, GPUs, external processes), then `tell()` the result. +//! This is useful for batch evaluation or custom scheduling. +//! +//! Run with: `cargo run --example ask_and_tell` + +use optimizer::prelude::*; + +fn main() -> optimizer::Result<()> { + let study: Study<f64> = Study::new(Direction::Minimize); + + let x = FloatParam::new(-5.0, 5.0).name("x"); + let y = FloatParam::new(-5.0, 5.0).name("y"); + + for batch in 0..3 { + let batch_size = 5; + let mut trials = Vec::with_capacity(batch_size); + + // ask() creates trials with sampled parameters + for _ in 0..batch_size { + let mut trial = study.ask(); + let xv = x.suggest(&mut trial)?; + let yv = y.suggest(&mut trial)?; + trials.push((trial, xv, yv)); + } + + // Evaluate the batch (could be sent to workers, GPUs, etc.) + for (trial, xv, yv) in trials { + let value = xv * xv + yv * yv; + study.tell(trial, Ok::<_, &str>(value)); + } + + println!( + "Batch {}: evaluated {batch_size} trials (total: {})", + batch + 1, + study.n_trials(), + ); + } + + let best = study.best_trial()?; + println!( + "Best: f({:.3}, {:.3}) = {:.6}", + best.get(&x).unwrap(), + best.get(&y).unwrap(), + best.value, + ); + + Ok(()) +} diff --git a/examples/async_parallel.rs b/examples/async_parallel.rs new file mode 100644 index 0000000..f3c3949 --- /dev/null +++ b/examples/async_parallel.rs @@ -0,0 +1,52 @@ +//! Async parallel optimization — evaluate multiple trials concurrently. +//! +//! Uses `optimize_parallel` with tokio to run several trials at once, +//! reducing wall-clock time when the objective involves I/O or async work. +//! +//! 
Run with: `cargo run --example async_parallel --features async` + +use optimizer::prelude::*; + +#[tokio::main] +async fn main() -> optimizer::Result<()> { + let study: Study<f64> = Study::minimize(TpeSampler::new()); + + let x = FloatParam::new(-5.0, 5.0).name("x"); + let y = FloatParam::new(-5.0, 5.0).name("y"); + + let n_trials = 30; + let concurrency = 4; + + println!("Running {n_trials} trials with {concurrency} concurrent workers..."); + + study + .optimize_parallel(n_trials, concurrency, { + let x = x.clone(); + let y = y.clone(); + move |mut trial| { + let x = x.clone(); + let y = y.clone(); + async move { + let xv = x.suggest(&mut trial)?; + let yv = y.suggest(&mut trial)?; + + // Simulate async I/O (e.g. calling an external service) + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + + let value = xv * xv + yv * yv; + Ok::<_, optimizer::Error>((trial, value)) + } + } + }) + .await?; + + let best = study.best_trial()?; + println!( + "Best: f({:.3}, {:.3}) = {:.6}", + best.get(&x).unwrap(), + best.get(&y).unwrap(), + best.value, + ); + + Ok(()) +} diff --git a/examples/early_stopping.rs b/examples/early_stopping.rs new file mode 100644 index 0000000..7f57c88 --- /dev/null +++ b/examples/early_stopping.rs @@ -0,0 +1,43 @@ +//! Early stopping — halt an entire study once a target is reached. +//! +//! Use `optimize_with_callback` to inspect each completed trial and return +//! `ControlFlow::Break(())` when the study should stop (e.g. a quality +//! threshold is met or a time budget is exhausted). +//! +//! 
Run with: `cargo run --example early_stopping` + +use std::ops::ControlFlow; + +use optimizer::prelude::*; + +fn main() -> optimizer::Result<()> { + let study: Study<f64> = Study::new(Direction::Minimize); + let x = FloatParam::new(-10.0, 10.0).name("x"); + + let target = 0.01; + + study.optimize_with_callback( + 100, // upper bound — we expect to stop much earlier + |trial| { + let xv = x.suggest(trial)?; + Ok::<_, Error>((xv - 3.0).powi(2)) + }, + |_study, completed| { + if completed.value < target { + println!("Target {target} reached at trial #{}", completed.id); + return ControlFlow::Break(()); + } + ControlFlow::Continue(()) + }, + )?; + + let best = study.best_trial()?; + println!( + "Stopped after {} trials — best f({:.4}) = {:.6}", + study.n_trials(), + best.get(&x).unwrap(), + best.value, + ); + + Ok(()) +} diff --git a/examples/journal_storage.rs b/examples/journal_storage.rs new file mode 100644 index 0000000..c85eba4 --- /dev/null +++ b/examples/journal_storage.rs @@ -0,0 +1,68 @@ +//! Journal storage — persist trials to disk and resume later. +//! +//! `JournalStorage` writes every trial to a JSONL file so that a study can +//! be resumed after a crash or across separate runs. +//! +//! 
Run with: `cargo run --example journal_storage --features journal` + +use optimizer::prelude::*; + +fn main() -> optimizer::Result<()> { + let path = std::env::temp_dir().join("optimizer_journal_example.jsonl"); + + // Clean up from any previous run + let _ = std::fs::remove_file(&path); + + let x = FloatParam::new(-5.0, 5.0).name("x"); + + // --- First run: optimize 20 trials and persist to disk --- + { + let storage = JournalStorage::<f64>::new(&path); + let study: Study<f64> = Study::builder() + .minimize() + .sampler(TpeSampler::new()) + .storage(storage) + .build(); + + study.optimize(20, |trial| { + let xv = x.suggest(trial)?; + Ok::<_, optimizer::Error>(xv * xv) + })?; + + println!( + "First run: {} trials saved to {}", + study.n_trials(), + path.display(), + ); + } + + // --- Second run: resume from the journal file --- + { + let storage = JournalStorage::<f64>::open(&path)?; + let study: Study<f64> = Study::builder() + .minimize() + .sampler(TpeSampler::new()) + .storage(storage) + .build(); + + let before = study.n_trials(); + study.optimize(10, |trial| { + let xv = x.suggest(trial)?; + Ok::<_, optimizer::Error>(xv * xv) + })?; + + let best = study.best_trial()?; + println!( + "Resumed: {} → {} trials, best f({:.4}) = {:.6}", + before, + study.n_trials(), + best.get(&x).unwrap(), + best.value, + ); + } + + // Clean up + let _ = std::fs::remove_file(&path); + + Ok(()) +} diff --git a/examples/multi_objective.rs b/examples/multi_objective.rs new file mode 100644 index 0000000..3d66a8a --- /dev/null +++ b/examples/multi_objective.rs @@ -0,0 +1,49 @@ +//! Multi-objective optimization — optimize competing objectives simultaneously. +//! +//! `MultiObjectiveStudy` returns the Pareto front: the set of solutions where +//! no objective can be improved without worsening another. +//! +//! 
Run with: `cargo run --example multi_objective` + +use optimizer::multi_objective::MultiObjectiveStudy; +use optimizer::prelude::*; + +fn main() -> optimizer::Result<()> { + let study = MultiObjectiveStudy::new(vec![Direction::Minimize, Direction::Minimize]); + + let x = FloatParam::new(0.0, 1.0).name("x"); + + // Classic bi-objective: f1(x) = x², f2(x) = (x-1)² + // The Pareto front is the curve where improving f1 worsens f2. + study.optimize(50, |trial| { + let xv = x.suggest(trial)?; + let f1 = xv * xv; + let f2 = (xv - 1.0) * (xv - 1.0); + Ok::<_, optimizer::Error>(vec![f1, f2]) + })?; + + let front = study.pareto_front(); + println!( + "Ran {} trials, Pareto front has {} solutions:", + study.n_trials(), + front.len(), + ); + + let mut sorted = front.clone(); + sorted.sort_by(|a, b| a.values[0].partial_cmp(&b.values[0]).unwrap()); + + for (i, trial) in sorted.iter().take(5).enumerate() { + println!( + " {}: x={:.3}, f1={:.4}, f2={:.4}", + i + 1, + trial.get(&x).unwrap(), + trial.values[0], + trial.values[1], + ); + } + if sorted.len() > 5 { + println!(" ... and {} more", sorted.len() - 5); + } + + Ok(()) +} diff --git a/examples/pruning.rs b/examples/pruning.rs new file mode 100644 index 0000000..afbfea2 --- /dev/null +++ b/examples/pruning.rs @@ -0,0 +1,67 @@ +//! Trial pruning — stop unpromising trials early with `MedianPruner`. +//! +//! When your objective involves an iterative loop (e.g. training epochs), +//! the pruner compares intermediate values across trials and kills the +//! ones that fall below the median — saving compute on bad configurations. +//! +//! Run with: `cargo run --example pruning` + +use optimizer::prelude::*; + +fn main() -> optimizer::Result<()> { + // MedianPruner prunes trials whose intermediate value falls below the + // median of previously completed trials at the same step. 
+ let study: Study<f64> = Study::builder() + .minimize() + .sampler(RandomSampler::with_seed(42)) + .pruner( + MedianPruner::new(Direction::Minimize) + .n_warmup_steps(3) // run at least 3 epochs before pruning + .n_min_trials(3), // need 3 completed trials before pruning kicks in + ) + .build(); + + let lr = FloatParam::new(1e-4, 1.0).name("learning_rate"); + let momentum = FloatParam::new(0.0, 0.99).name("momentum"); + + let n_epochs: u64 = 20; + + study.optimize(30, |trial| { + let lr_val = lr.suggest(trial)?; + let mom = momentum.suggest(trial)?; + + // Simulated training loop — good hyperparameters converge to low loss, + // bad ones plateau high, giving the pruner something to cut. + let mut loss = 1.0; + for epoch in 0..n_epochs { + let lr_penalty = (lr_val.log10() - 0.01_f64.log10()).powi(2); + let mom_penalty = (mom - 0.8).powi(2); + let base_loss = 0.02 + 0.05 * lr_penalty + 1.5 * mom_penalty; + let progress = (epoch as f64 + 1.0) / n_epochs as f64; + loss = base_loss + (1.0 - base_loss) * (-3.5 * progress).exp(); + + // Report intermediate value so the pruner can evaluate this trial. + trial.report(epoch, loss); + + // Check whether the pruner recommends stopping early. + if trial.should_prune() { + Err(TrialPruned)?; + } + } + + Ok::<_, Error>(loss) + })?; + + // --- Results --- + let best = study.best_trial()?; + println!( + "Completed {} trials ({} pruned)", + study.n_trials(), + study.n_pruned_trials() + ); + println!("Best trial #{}: loss = {:.6}", best.id, best.value); + println!(" learning_rate = {:.6}", best.get(&lr).unwrap()); + println!(" momentum = {:.4}", best.get(&momentum).unwrap()); + + Ok(()) +} diff --git a/examples/pruning_and_callbacks.rs b/examples/pruning_and_callbacks.rs deleted file mode 100644 index 207e55b..0000000 --- a/examples/pruning_and_callbacks.rs +++ /dev/null @@ -1,133 +0,0 @@ -//! Pruning and early-stopping example — demonstrates trial pruning with `MedianPruner` -//! and early stopping via `optimize_with_callback`. 
-//! -//! Simulates a training loop where each trial trains for multiple "epochs". The pruner -//! stops unpromising trials early, and a callback halts the entire study once a target -//! loss is reached. -//! -//! Run with: `cargo run --example pruning_and_callbacks` - -use std::ops::ControlFlow; - -use optimizer::TrialState; -use optimizer::prelude::*; - -fn main() -> optimizer::Result<()> { - let n_trials: usize = 30; - let n_epochs: u64 = 20; - let target_loss = 0.15; - - // Build a study with a seeded random sampler and MedianPruner. - // MedianPruner compares each trial's intermediate value against the median of - // completed trials at the same step — trials performing below median are pruned. - let study: Study<f64> = Study::builder() - .minimize() - .sampler(RandomSampler::with_seed(42)) - .pruner( - MedianPruner::new(Direction::Minimize) - .n_warmup_steps(3) // let every trial run at least 3 epochs before pruning - .n_min_trials(3), // need 3 completed trials before pruning kicks in - ) - .build(); - - let learning_rate = FloatParam::new(1e-4, 1.0).name("learning_rate"); - let momentum = FloatParam::new(0.0, 0.99).name("momentum"); - - // Use optimize_with_callback to get both pruning AND early stopping. - // The callback fires after each completed (or pruned) trial and can halt the study. - study.optimize_with_callback( - n_trials, - // --- Objective function: simulated training loop with pruning --- - |trial| { - let lr = learning_rate.suggest(trial)?; - let mom = momentum.suggest(trial)?; - - // Simulate training for n_epochs, reporting intermediate loss each epoch. - // Good hyperparameters (lr ≈ 0.01, momentum ≈ 0.8) converge to low loss; - // bad combos plateau high — giving the pruner something to cut. 
- let mut loss = 1.0; - for epoch in 0..n_epochs { - let lr_penalty = (lr.log10() - 0.01_f64.log10()).powi(2); // 0 at lr=0.01 - let mom_penalty = (mom - 0.8).powi(2); // 0 at momentum=0.8 - let base_loss = 0.02 + 0.05 * lr_penalty + 1.5 * mom_penalty; - let progress = (epoch as f64 + 1.0) / n_epochs as f64; - // Loss decays from 1.0 toward base_loss over epochs. - loss = base_loss + (1.0 - base_loss) * (-3.5 * progress).exp(); - - // Report the intermediate value so the pruner can evaluate this trial. - trial.report(epoch, loss); - - // Check whether the pruner recommends stopping this trial early. - if trial.should_prune() { - // Signal that this trial was pruned — the study records it as Pruned. - Err(TrialPruned)?; - } - } - - Ok::<_, Error>(loss) - }, - // --- Callback: early stopping when we hit the target --- - |study, completed_trial| { - let n_complete = study.n_trials(); - let n_pruned = study - .trials() - .iter() - .filter(|t| t.state == TrialState::Pruned) - .count(); - - match completed_trial.state { - TrialState::Pruned => { - println!( - " Trial {:>3} PRUNED at epoch {} (loss = {:.4}) \ - [{n_complete} done, {n_pruned} pruned]", - completed_trial.id, - completed_trial.intermediate_values.len(), - completed_trial - .intermediate_values - .last() - .map_or(f64::NAN, |v| v.1), - ); - } - TrialState::Complete => { - println!( - " Trial {:>3} complete: loss = {:.4} \ - [{n_complete} done, {n_pruned} pruned]", - completed_trial.id, completed_trial.value, - ); - } - _ => {} - } - - // Stop the entire study once we find a good enough result. 
- if completed_trial.state == TrialState::Complete && completed_trial.value < target_loss - { - println!("\n Early stopping: reached target loss {target_loss}!"); - return ControlFlow::Break(()); - } - - ControlFlow::Continue(()) - }, - )?; - - // --- Results --- - let best = study.best_trial().expect("at least one completed trial"); - let total = study.n_trials(); - let pruned = study - .trials() - .iter() - .filter(|t| t.state == TrialState::Pruned) - .count(); - - println!("\n--- Results ---"); - println!(" Total trials : {total}"); - println!(" Pruned : {pruned}"); - println!(" Completed : {}", total - pruned); - println!(" Best trial #{}: loss = {:.6}", best.id, best.value); - println!( - " learning_rate = {:.6}", - best.get(&learning_rate).unwrap() - ); - println!(" momentum = {:.4}", best.get(&momentum).unwrap()); - - Ok(()) -} diff --git a/examples/sampler_comparison.rs b/examples/sampler_comparison.rs index b743e2a..3034ab3 100644 --- a/examples/sampler_comparison.rs +++ b/examples/sampler_comparison.rs @@ -40,10 +40,7 @@ fn run_study(study: Study<f64>, n_trials: usize) -> f64 { fn main() { let n_trials: usize = 100; - println!("Comparing samplers on Sphere(x, y) = x² + y²"); - println!(" Search space: x ∈ [-5, 5], y ∈ [-3, 3]"); - println!(" Known minimum: f(0, 0) = 0"); - println!(" Trials per sampler: {n_trials}"); + println!("Comparing samplers on Sphere(x, y) = x² + y² ({n_trials} trials each)"); println!(); // --- Random sampler (baseline) --- @@ -78,17 +75,4 @@ fn main() { println!(" Random : {random_best:.6}"); println!(" TPE : {tpe_best:.6}"); println!(" Grid : {grid_best:.6}"); - println!(); - - // Find the winner - let results = [ - ("Random", random_best), - ("TPE", tpe_best), - ("Grid", grid_best), - ]; - let (winner, _) = results - .iter() - .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap()) - .unwrap(); - println!("Winner: {winner} (closest to known minimum of 0.0)"); } From 964b4d5749ab338f9ed507b9d73d1be9a1e9c478 Mon Sep 17 00:00:00 2001 From: 
Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 12:25:20 +0100 Subject: [PATCH 05/48] feat: add Objective trait and unify optimize API - Add `Objective<V>` trait with lifecycle hooks (`before_trial`, `after_trial`, `max_retries`) in new `src/objective.rs` - Replace 14+ optimize variants with 6 methods: `optimize`, `optimize_with`, and async/parallel counterparts - `optimize*` methods accept closures directly (FnMut for sync, Fn for async); `optimize_with*` methods accept `impl Objective<V>` for struct-based objectives with hooks and retries - Remove `optimize_until`, `optimize_with_callback`, `optimize_with_retries`, `optimize_with_checkpoint`, and all deprecated `_with_sampler` methods --- examples/async_parallel.rs | 32 +- examples/early_stopping.rs | 53 +- src/lib.rs | 3 + src/objective.rs | 132 ++++ src/study.rs | 1245 ++++++++++++------------------------ tests/async_tests.rs | 118 +--- tests/integration.rs | 535 +++++++--------- tests/serde_tests.rs | 72 +-- 8 files changed, 872 insertions(+), 1318 deletions(-) create mode 100644 src/objective.rs diff --git a/examples/async_parallel.rs b/examples/async_parallel.rs index f3c3949..8a31f30 100644 --- a/examples/async_parallel.rs +++ b/examples/async_parallel.rs @@ -1,7 +1,8 @@ //! Async parallel optimization — evaluate multiple trials concurrently. //! //! Uses `optimize_parallel` with tokio to run several trials at once, -//! reducing wall-clock time when the objective involves I/O or async work. +//! reducing wall-clock time when the objective involves blocking work. +//! Each sync closure is internally wrapped in `spawn_blocking`. //! //! 
Run with: `cargo run --example async_parallel --features async` @@ -19,25 +20,18 @@ async fn main() -> optimizer::Result<()> { println!("Running {n_trials} trials with {concurrency} concurrent workers..."); + let xc = x.clone(); + let yc = y.clone(); study - .optimize_parallel(n_trials, concurrency, { - let x = x.clone(); - let y = y.clone(); - move |mut trial| { - let x = x.clone(); - let y = y.clone(); - async move { - let xv = x.suggest(&mut trial)?; - let yv = y.suggest(&mut trial)?; - - // Simulate async I/O (e.g. calling an external service) - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - - let value = xv * xv + yv * yv; - Ok::<_, optimizer::Error>((trial, value)) - } - } - }) + .optimize_parallel( + n_trials, + concurrency, + move |trial: &mut optimizer::Trial| { + let xv = xc.suggest(trial)?; + let yv = yc.suggest(trial)?; + Ok::<_, optimizer::Error>(xv * xv + yv * yv) + }, + ) .await?; let best = study.best_trial()?; diff --git a/examples/early_stopping.rs b/examples/early_stopping.rs index 7f57c88..efbf78c 100644 --- a/examples/early_stopping.rs +++ b/examples/early_stopping.rs @@ -1,8 +1,8 @@ //! Early stopping — halt an entire study once a target is reached. //! -//! Use `optimize_with_callback` to inspect each completed trial and return -//! `ControlFlow::Break(())` when the study should stop (e.g. a quality -//! threshold is met or a time budget is exhausted). +//! Implements the [`Objective`] trait on a custom struct and uses the +//! [`after_trial`](Objective::after_trial) hook to return +//! `ControlFlow::Break(())` when the best value drops below a threshold. //! //! Run with: `cargo run --example early_stopping` @@ -10,26 +10,41 @@ use std::ops::ControlFlow; use optimizer::prelude::*; +/// An objective that minimises `(x - 3)^2` and stops early once the +/// value drops below `target`. 
+struct EarlyStopObjective { + x: FloatParam, + target: f64, +} + +impl Objective<f64> for EarlyStopObjective { + type Error = Error; + + fn evaluate(&self, trial: &mut Trial) -> Result<f64> { + let v = self.x.suggest(trial)?; + Ok((v - 3.0).powi(2)) + } + + fn after_trial(&self, _study: &Study<f64>, trial: &CompletedTrial<f64>) -> ControlFlow<()> { + if trial.value < self.target { + println!("Target {} reached at trial #{}", self.target, trial.id); + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } + } +} + fn main() -> optimizer::Result<()> { let study: Study<f64> = Study::new(Direction::Minimize); let x = FloatParam::new(-10.0, 10.0).name("x"); - let target = 0.01; - - study.optimize_with_callback( - 100, // upper bound — we expect to stop much earlier - |trial| { - let xv = x.suggest(trial)?; - Ok::<_, Error>((xv - 3.0).powi(2)) - }, - |_study, completed| { - if completed.value < target { - println!("Target {target} reached at trial #{}", completed.id); - return ControlFlow::Break(()); - } - ControlFlow::Continue(()) - }, - )?; + let objective = EarlyStopObjective { + x: x.clone(), + target: 0.01, + }; + + study.optimize_with(100, objective)?; let best = study.best_trial()?; println!( diff --git a/src/lib.rs b/src/lib.rs index 0636f54..a696b6a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -113,6 +113,7 @@ mod fanova; mod importance; mod kde; pub mod multi_objective; +pub mod objective; mod param; pub mod parameter; pub mod pareto; @@ -127,6 +128,7 @@ mod visualization; pub use error::{Error, Result, TrialPruned}; pub use fanova::{FanovaConfig, FanovaResult}; +pub use objective::Objective; #[cfg(feature = "derive")] pub use optimizer_derive::Categorical; #[cfg(feature = "serde")] @@ -150,6 +152,7 @@ pub mod prelude { pub use crate::multi_objective::{ MultiObjectiveSampler, MultiObjectiveStudy, MultiObjectiveTrial, }; + pub use crate::objective::Objective; pub use crate::parameter::{ BoolParam, Categorical, CategoricalParam, EnumParam, FloatParam, 
IntParam, ParamValue, Parameter, diff --git a/src/objective.rs b/src/objective.rs new file mode 100644 index 0000000..d1ce710 --- /dev/null +++ b/src/objective.rs @@ -0,0 +1,132 @@ +//! The [`Objective`] trait defines what gets optimized. +//! +//! For simple closures, pass them directly to +//! [`Study::optimize`](crate::Study::optimize): +//! +//! ``` +//! use optimizer::prelude::*; +//! +//! let study: Study<f64> = Study::new(Direction::Minimize); +//! let x = FloatParam::new(-10.0, 10.0).name("x"); +//! +//! study +//! .optimize(50, |trial| { +//! let v = x.suggest(trial)?; +//! Ok::<_, Error>((v - 3.0).powi(2)) +//! }) +//! .unwrap(); +//! ``` +//! +//! For richer control — early stopping, retries, or per-trial logging — +//! implement [`Objective`] on a struct and pass it to +//! [`Study::optimize_with`](crate::Study::optimize_with): +//! +//! ``` +//! use std::ops::ControlFlow; +//! +//! use optimizer::Objective; +//! use optimizer::prelude::*; +//! +//! struct QuadraticWithEarlyStopping { +//! x: FloatParam, +//! target: f64, +//! } +//! +//! impl Objective<f64> for QuadraticWithEarlyStopping { +//! type Error = Error; +//! +//! fn evaluate(&self, trial: &mut Trial) -> Result<f64> { +//! let v = self.x.suggest(trial)?; +//! Ok((v - 3.0).powi(2)) +//! } +//! +//! fn after_trial(&self, _study: &Study<f64>, trial: &CompletedTrial<f64>) -> ControlFlow<()> { +//! if trial.value < self.target { +//! ControlFlow::Break(()) +//! } else { +//! ControlFlow::Continue(()) +//! } +//! } +//! } +//! +//! let study: Study<f64> = Study::new(Direction::Minimize); +//! let obj = QuadraticWithEarlyStopping { +//! x: FloatParam::new(-10.0, 10.0).name("x"), +//! target: 1.0, +//! }; +//! study.optimize_with(200, obj).unwrap(); +//! assert!(study.best_value().unwrap() < 1.0); +//! ``` + +use core::ops::ControlFlow; + +use crate::sampler::CompletedTrial; +use crate::study::Study; +use crate::trial::Trial; + +/// Defines an objective function with lifecycle hooks for optimization. 
+/// +/// The only required method is [`evaluate`](Objective::evaluate), which +/// computes the objective value for a given trial. Optional hooks provide +/// early stopping ([`before_trial`](Objective::before_trial), +/// [`after_trial`](Objective::after_trial)) and automatic retries +/// ([`max_retries`](Objective::max_retries)). +/// +/// # When to use `Objective` vs a closure +/// +/// - **Closure** — pass directly to [`Study::optimize`](crate::Study::optimize) +/// for simple evaluate-only objectives. +/// - **`Objective` struct** — implement this trait when you need hooks +/// (`before_trial`, `after_trial`) or retries. +/// +/// # Thread safety +/// +/// The async optimization methods (`optimize_async`, `optimize_parallel`) +/// additionally require `Send + Sync + 'static` on the objective. The +/// sync `optimize` method has no thread-safety requirements. +pub trait Objective<V: PartialOrd = f64> { + /// The error type returned by [`evaluate`](Objective::evaluate). + type Error: ToString + 'static; + + /// Evaluate the objective function for a single trial. + /// + /// Sample parameters from `trial` via + /// [`Parameter::suggest`](crate::parameter::Parameter::suggest) and + /// return the objective value. Return `Err(TrialPruned)` to prune a + /// trial early. + /// + /// # Errors + /// + /// Any error whose type implements `ToString`. Pruning errors + /// (`Error::TrialPruned` or `TrialPruned`) are handled specially — + /// the trial is recorded as pruned rather than failed. + fn evaluate(&self, trial: &mut Trial) -> Result<V, Self::Error>; + + /// Called before each trial is created. + /// + /// Return `ControlFlow::Break(())` to stop the optimization loop + /// before the next trial starts. + /// + /// Default: always continues. + fn before_trial(&self, _study: &Study<V>) -> ControlFlow<()> { + ControlFlow::Continue(()) + } + + /// Called after each **completed** trial (not failed or pruned). 
+ /// + /// Return `ControlFlow::Break(())` to stop the optimization loop. + /// + /// Default: always continues. + fn after_trial(&self, _study: &Study<V>, _trial: &CompletedTrial<V>) -> ControlFlow<()> { + ControlFlow::Continue(()) + } + + /// Maximum number of retries for a failed trial. + /// + /// When `evaluate` returns a non-pruning error and retries remain, + /// the same parameter configuration is re-evaluated. Set to `0` + /// (the default) to disable retries. + fn max_retries(&self) -> usize { + 0 + } +} diff --git a/src/study.rs b/src/study.rs index 3827942..fd9acb6 100644 --- a/src/study.rs +++ b/src/study.rs @@ -2,14 +2,10 @@ use core::any::Any; use core::fmt; -#[cfg(feature = "async")] -use core::future::Future; use core::marker::PhantomData; use core::ops::ControlFlow; -use core::time::Duration; use std::collections::{HashMap, VecDeque}; use std::sync::Arc; -use std::time::Instant; use parking_lot::{Mutex, RwLock}; @@ -883,26 +879,16 @@ where completed } - /// Run optimization with the given objective function. + /// Run optimization with a closure. /// - /// This method runs `n_trials` evaluations sequentially. For each trial: - /// 1. A new trial is created - /// 2. The objective function is called with the trial - /// 3. If successful, the trial is recorded as completed - /// 4. If the objective returns an error, the trial is recorded as failed - /// - /// Failed trials do not stop the optimization; the process continues with - /// the next trial. - /// - /// # Arguments - /// - /// * `n_trials` - The number of trials to run. - /// * `objective` - A closure that takes a mutable reference to a `Trial` and - /// returns the objective value or an error. + /// Runs up to `n_trials` evaluations of `objective` sequentially. + /// For lifecycle hooks (early stopping, retries), implement the + /// [`Objective`](crate::Objective) trait and use + /// [`optimize_with`](Self::optimize_with) instead. 
/// /// # Errors /// - /// Returns `Error::NoCompletedTrials` if all trials failed (no successful trials). + /// Returns `Error::NoCompletedTrials` if no trials completed successfully. /// /// # Examples /// @@ -911,29 +897,25 @@ where /// use optimizer::sampler::random::RandomSampler; /// use optimizer::{Direction, Study}; /// - /// // Minimize x^2 /// let sampler = RandomSampler::with_seed(42); /// let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - /// /// let x_param = FloatParam::new(-10.0, 10.0); /// /// study - /// .optimize(10, |trial| { + /// .optimize(10, |trial: &mut optimizer::Trial| { /// let x = x_param.suggest(trial)?; /// Ok::<_, optimizer::Error>(x * x) /// }) /// .unwrap(); /// - /// // At least one trial should have completed /// assert!(study.n_trials() > 0); - /// let best = study.best_value().unwrap(); - /// assert!(best >= 0.0); + /// assert!(study.best_value().unwrap() >= 0.0); /// ``` pub fn optimize<F, E>(&self, n_trials: usize, mut objective: F) -> crate::Result<()> where - F: FnMut(&mut Trial) -> core::result::Result<V, E>, + F: FnMut(&mut Trial) -> Result<V, E>, E: ToString + 'static, - V: Default, + V: Clone + Default, { #[cfg(feature = "tracing")] let _span = @@ -941,270 +923,28 @@ where for _ in 0..n_trials { let mut trial = self.create_trial(); - match objective(&mut trial) { Ok(value) => { #[cfg(feature = "tracing")] let trial_id = trial.id(); self.complete_trial(trial, value); - - #[cfg(feature = "tracing")] - { - tracing::info!(trial_id, "trial completed"); - let trials = self.storage.trials_arc().read(); - if trials - .iter() - .filter(|t| t.state == TrialState::Complete) - .count() - == 1 - || trials.last().map(|t| t.id) == self.best_id(&trials) - { - tracing::info!(trial_id, "new best value found"); - } - } + trace_info!(trial_id, "trial completed"); } - Err(e) => { + Err(e) if is_trial_pruned(&e) => { #[cfg(feature = "tracing")] let trial_id = trial.id(); - if is_trial_pruned(&e) { - 
self.prune_trial(trial); - trace_info!(trial_id, "trial pruned"); - } else { - self.fail_trial(trial, e.to_string()); - trace_debug!(trial_id, "trial failed"); - } - } - } - } - - // Return error if no trials completed successfully - let has_complete = self - .storage - .trials_arc() - .read() - .iter() - .any(|t| t.state == TrialState::Complete); - if !has_complete { - return Err(crate::Error::NoCompletedTrials); - } - - Ok(()) - } - - /// Run optimization asynchronously with the given objective function. - /// - /// This method runs `n_trials` evaluations sequentially, but the objective - /// function can be async (e.g., for I/O-bound operations like network requests - /// or file operations). - /// - /// The objective function takes ownership of the `Trial` and must return it - /// along with the result. This allows async operations to use the trial - /// across await points. - /// - /// # Arguments - /// - /// * `n_trials` - The number of trials to run. - /// * `objective` - A function that takes a `Trial` and returns a `Future` - /// that resolves to a tuple of `(Trial, Result<V, E>)`. - /// - /// # Errors - /// - /// Returns `Error::NoCompletedTrials` if all trials failed (no successful trials). 
- /// - /// # Examples - /// - /// ``` - /// use optimizer::parameter::{FloatParam, Parameter}; - /// use optimizer::sampler::random::RandomSampler; - /// use optimizer::{Direction, Study}; - /// - /// # #[cfg(feature = "async")] - /// # async fn example() -> optimizer::Result<()> { - /// // Minimize x^2 with async objective - /// let sampler = RandomSampler::with_seed(42); - /// let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - /// - /// let x_param = FloatParam::new(-10.0, 10.0); - /// - /// study - /// .optimize_async(10, |mut trial| { - /// let x_param = x_param.clone(); - /// async move { - /// let x = x_param.suggest(&mut trial)?; - /// // Simulate async work (e.g., network request) - /// let value = x * x; - /// Ok::<_, optimizer::Error>((trial, value)) - /// } - /// }) - /// .await?; - /// - /// // At least one trial should have completed - /// assert!(study.n_trials() > 0); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "async")] - pub async fn optimize_async<F, Fut, E>( - &self, - n_trials: usize, - objective: F, - ) -> crate::Result<()> - where - F: Fn(Trial) -> Fut, - Fut: Future<Output = core::result::Result<(Trial, V), E>>, - E: ToString, - { - #[cfg(feature = "tracing")] - let _span = - tracing::info_span!("optimize_async", n_trials, direction = ?self.direction).entered(); - - for _ in 0..n_trials { - let trial = self.create_trial(); - #[cfg(feature = "tracing")] - let trial_id = trial.id(); - - match objective(trial).await { - Ok((trial, value)) => { - self.complete_trial(trial, value); - trace_info!(trial_id, "trial completed"); + self.prune_trial(trial); + trace_info!(trial_id, "trial pruned"); } Err(e) => { - // For async, we don't have the trial back on error - // We'll just count this as a failed trial without recording it - let _ = e.to_string(); - trace_debug!(trial_id, "trial failed"); - } - } - } - - // Return error if no trials completed successfully - let has_complete = self - .storage - .trials_arc() - 
.read() - .iter() - .any(|t| t.state == TrialState::Complete); - if !has_complete { - return Err(crate::Error::NoCompletedTrials); - } - - Ok(()) - } - - /// Run optimization with bounded parallelism for concurrent trial evaluation. - /// - /// This method runs up to `concurrency` trials simultaneously, allowing - /// efficient use of async I/O-bound objective functions. A semaphore limits - /// the number of concurrent evaluations. - /// - /// The objective function takes ownership of the `Trial` and must return it - /// along with the result. This allows async operations to use the trial - /// across await points. - /// - /// # Arguments - /// - /// * `n_trials` - The total number of trials to run. - /// * `concurrency` - The maximum number of trials to run simultaneously. - /// * `objective` - A function that takes a `Trial` and returns a `Future` - /// that resolves to a tuple of `(Trial, V)` or an error. - /// - /// # Errors - /// - /// Returns `Error::NoCompletedTrials` if all trials failed (no successful trials). - /// Returns `Error::TaskError` if the semaphore is closed or a spawned task panics. 
- /// - /// # Examples - /// - /// ``` - /// use optimizer::parameter::{FloatParam, Parameter}; - /// use optimizer::sampler::random::RandomSampler; - /// use optimizer::{Direction, Study}; - /// - /// # #[cfg(feature = "async")] - /// # async fn example() -> optimizer::Result<()> { - /// // Minimize x^2 with parallel async evaluation - /// let sampler = RandomSampler::with_seed(42); - /// let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - /// - /// let x_param = FloatParam::new(-10.0, 10.0); - /// - /// study - /// .optimize_parallel(10, 4, move |mut trial| { - /// let x_param = x_param.clone(); - /// async move { - /// let x = x_param.suggest(&mut trial)?; - /// // Async objective function (e.g., network request) - /// let value = x * x; - /// Ok::<_, optimizer::Error>((trial, value)) - /// } - /// }) - /// .await?; - /// - /// // All trials should have completed - /// assert_eq!(study.n_trials(), 10); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "async")] - pub async fn optimize_parallel<F, Fut, E>( - &self, - n_trials: usize, - concurrency: usize, - objective: F, - ) -> crate::Result<()> - where - F: Fn(Trial) -> Fut + Send + Sync + 'static, - Fut: Future<Output = core::result::Result<(Trial, V), E>> + Send, - E: ToString + Send + 'static, - V: Send + 'static, - { - use tokio::sync::Semaphore; - - #[cfg(feature = "tracing")] - let _span = tracing::info_span!("optimize_parallel", n_trials, concurrency, direction = ?self.direction).entered(); - - let semaphore = Arc::new(Semaphore::new(concurrency)); - let objective = Arc::new(objective); - - let mut handles = Vec::with_capacity(n_trials); - - for _ in 0..n_trials { - let permit = semaphore - .clone() - .acquire_owned() - .await - .map_err(|e| crate::Error::TaskError(e.to_string()))?; - let trial = self.create_trial(); - let objective = Arc::clone(&objective); - - let handle = tokio::spawn(async move { - let result = objective(trial).await; - drop(permit); // Release semaphore 
permit when done - result - }); - - handles.push(handle); - } - - // Wait for all tasks and record results - for handle in handles { - match handle - .await - .map_err(|e| crate::Error::TaskError(e.to_string()))? - { - Ok((trial, value)) => { #[cfg(feature = "tracing")] let trial_id = trial.id(); - self.complete_trial(trial, value); - trace_info!(trial_id, "trial completed"); - } - Err(e) => { - let _ = e.to_string(); + self.fail_trial(trial, e.to_string()); + trace_debug!(trial_id, "trial failed"); } } } - // Return error if no trials completed successfully let has_complete = self .storage .trials_arc() @@ -1218,130 +958,126 @@ where Ok(()) } - /// Run optimization with a callback for monitoring progress. - /// - /// This method is similar to `optimize`, but calls a callback function after - /// each completed trial. The callback can inspect the study state and the - /// completed trial, and can optionally stop optimization early by returning - /// `ControlFlow::Break(())`. - /// - /// # Arguments + /// Run optimization with an [`Objective`](crate::Objective) implementation. /// - /// * `n_trials` - The maximum number of trials to run. - /// * `objective` - A closure that takes a mutable reference to a `Trial` and - /// returns the objective value or an error. - /// * `callback` - A closure called after each successful trial. Returns - /// `ControlFlow::Continue(())` to proceed or `ControlFlow::Break(())` to stop. + /// Like [`optimize`](Self::optimize), but accepts a struct implementing + /// [`Objective`](crate::Objective) for lifecycle hooks + /// ([`before_trial`](crate::Objective::before_trial), + /// [`after_trial`](crate::Objective::after_trial)) and automatic retries + /// ([`max_retries`](crate::Objective::max_retries)). /// /// # Errors /// - /// Returns `Error::NoCompletedTrials` if no trials completed successfully - /// before optimization stopped (either by completing all trials or early stopping). 
- /// Returns `Error::Internal` if a completed trial is not found after adding (internal invariant violation). + /// Returns `Error::NoCompletedTrials` if no trials completed successfully. /// /// # Examples /// /// ``` /// use std::ops::ControlFlow; /// - /// use optimizer::parameter::{FloatParam, Parameter}; - /// use optimizer::sampler::random::RandomSampler; - /// use optimizer::{Direction, Study}; - /// - /// // Stop early when we find a good enough value - /// let sampler = RandomSampler::with_seed(42); - /// let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + /// use optimizer::prelude::*; /// - /// let x_param = FloatParam::new(-10.0, 10.0); + /// struct QuadraticObj { + /// x: FloatParam, + /// target: f64, + /// } /// - /// study - /// .optimize_with_callback( - /// 100, - /// |trial| { - /// let x = x_param.suggest(trial)?; - /// Ok::<_, optimizer::Error>(x * x) - /// }, - /// |_study, completed_trial| { - /// // Stop early if we find a value less than 1.0 - /// if completed_trial.value < 1.0 { - /// ControlFlow::Break(()) - /// } else { - /// ControlFlow::Continue(()) - /// } - /// }, - /// ) - /// .unwrap(); + /// impl Objective<f64> for QuadraticObj { + /// type Error = Error; + /// fn evaluate(&self, trial: &mut Trial) -> Result<f64> { + /// let v = self.x.suggest(trial)?; + /// Ok((v - 3.0).powi(2)) + /// } + /// fn after_trial(&self, _: &Study<f64>, t: &CompletedTrial<f64>) -> ControlFlow<()> { + /// if t.value < self.target { + /// ControlFlow::Break(()) + /// } else { + /// ControlFlow::Continue(()) + /// } + /// } + /// } /// - /// // May have stopped early, but should have at least one trial - /// assert!(study.n_trials() > 0); - /// ``` - pub fn optimize_with_callback<F, C, E>( + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// let obj = QuadraticObj { + /// x: FloatParam::new(-10.0, 10.0), + /// target: 1.0, + /// }; + /// study.optimize_with(200, obj).unwrap(); + /// 
assert!(study.best_value().unwrap() < 1.0); + /// ``` + #[allow(clippy::needless_pass_by_value)] + pub fn optimize_with( &self, n_trials: usize, - mut objective: F, - mut callback: C, + objective: impl crate::objective::Objective<V>, ) -> crate::Result<()> where V: Clone + Default, - F: FnMut(&mut Trial) -> core::result::Result<V, E>, - C: FnMut(&Study<V>, &CompletedTrial<V>) -> ControlFlow<()>, - E: ToString + 'static, { #[cfg(feature = "tracing")] let _span = - tracing::info_span!("optimize", n_trials, direction = ?self.direction).entered(); + tracing::info_span!("optimize_with", n_trials, direction = ?self.direction).entered(); + + let max_retries = objective.max_retries(); for _ in 0..n_trials { - let mut trial = self.create_trial(); + if let ControlFlow::Break(()) = objective.before_trial(self) { + break; + } - match objective(&mut trial) { - Ok(value) => { - #[cfg(feature = "tracing")] - let trial_id = trial.id(); - self.complete_trial(trial, value); + let mut trial = self.create_trial(); + let mut retries = 0; + loop { + match objective.evaluate(&mut trial) { + Ok(value) => { + #[cfg(feature = "tracing")] + let trial_id = trial.id(); + self.complete_trial(trial, value); - #[cfg(feature = "tracing")] - { - tracing::info!(trial_id, "trial completed"); - let trials = self.storage.trials_arc().read(); - if trials - .iter() - .filter(|t| t.state == TrialState::Complete) - .count() - == 1 - || trials.last().map(|t| t.id) == self.best_id(&trials) + #[cfg(feature = "tracing")] { - tracing::info!(trial_id, "new best value found"); + tracing::info!(trial_id, "trial completed"); + let trials = self.storage.trials_arc().read(); + if trials + .iter() + .filter(|t| t.state == TrialState::Complete) + .count() + == 1 + || trials.last().map(|t| t.id) == self.best_id(&trials) + { + tracing::info!(trial_id, "new best value found"); + } } - } - - // Get the just-completed trial for the callback - let trials = self.storage.trials_arc().read(); - let Some(completed) = 
trials.last() else { - return Err(crate::Error::Internal( - "completed trial not found after adding", - )); - }; - - // Call the callback and check if we should stop - // Note: We need to drop the read lock before calling callback - // to avoid potential deadlock if callback accesses the study - let completed_clone = completed.clone(); - drop(trials); - if let ControlFlow::Break(()) = callback(self, &completed_clone) { + // Fire after_trial hook + let trials = self.storage.trials_arc().read(); + if let Some(completed) = trials.last() { + let completed_clone = completed.clone(); + drop(trials); + if let ControlFlow::Break(()) = + objective.after_trial(self, &completed_clone) + { + // Return early — at least one trial completed. + return Ok(()); + } + } break; } - } - Err(e) => { - #[cfg(feature = "tracing")] - let trial_id = trial.id(); - if is_trial_pruned(&e) { - self.prune_trial(trial); - trace_info!(trial_id, "trial pruned"); - } else { - self.fail_trial(trial, e.to_string()); - trace_debug!(trial_id, "trial failed"); + Err(e) if !is_trial_pruned(&e) && retries < max_retries => { + retries += 1; + trial = self.create_trial_with_params(trial.params().clone()); + } + Err(e) => { + #[cfg(feature = "tracing")] + let trial_id = trial.id(); + if is_trial_pruned(&e) { + self.prune_trial(trial); + trace_info!(trial_id, "trial pruned"); + } else { + self.fail_trial(trial, e.to_string()); + trace_debug!(trial_id, "trial failed"); + } + break; } } } @@ -1360,210 +1096,85 @@ where Ok(()) } - /// Run optimization until the given duration has elapsed. + + /// Run async optimization with a closure. /// - /// Trials that are already running when the timeout is reached will - /// complete — we never interrupt mid-trial. The actual elapsed time - /// may therefore slightly exceed the specified duration. + /// Each evaluation is wrapped in + /// [`spawn_blocking`](tokio::task::spawn_blocking), keeping the async + /// runtime responsive for CPU-bound objectives. 
Trials run sequentially. /// - /// # Arguments - /// - /// * `duration` - The maximum wall-clock time to spend on optimization. - /// * `objective` - A closure that takes a mutable reference to a `Trial` and - /// returns the objective value or an error. + /// For lifecycle hooks, use [`optimize_with_async`](Self::optimize_with_async). /// /// # Errors /// - /// Returns `Error::NoCompletedTrials` if no trials completed successfully - /// before the timeout. + /// Returns `Error::NoCompletedTrials` if no trials completed successfully. + /// Returns `Error::TaskError` if a spawned blocking task panics. /// /// # Examples /// /// ``` - /// use std::time::Duration; - /// /// use optimizer::parameter::{FloatParam, Parameter}; /// use optimizer::sampler::random::RandomSampler; /// use optimizer::{Direction, Study}; /// + /// # #[cfg(feature = "async")] + /// # async fn example() -> optimizer::Result<()> { /// let sampler = RandomSampler::with_seed(42); /// let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - /// /// let x_param = FloatParam::new(-10.0, 10.0); /// /// study - /// .optimize_until(Duration::from_millis(100), |trial| { + /// .optimize_async(10, move |trial: &mut optimizer::Trial| { /// let x = x_param.suggest(trial)?; /// Ok::<_, optimizer::Error>(x * x) /// }) - /// .unwrap(); - /// - /// assert!(study.n_trials() > 0); - /// ``` - pub fn optimize_until<F, E>(&self, duration: Duration, mut objective: F) -> crate::Result<()> - where - F: FnMut(&mut Trial) -> core::result::Result<V, E>, - E: ToString + 'static, - V: Default, - { - #[cfg(feature = "tracing")] - let _span = tracing::info_span!("optimize", duration_secs = duration.as_secs(), direction = ?self.direction).entered(); - - let deadline = Instant::now() + duration; - while Instant::now() < deadline { - let mut trial = self.create_trial(); - - match objective(&mut trial) { - Ok(value) => { - #[cfg(feature = "tracing")] - let trial_id = trial.id(); - self.complete_trial(trial, 
value); - trace_info!(trial_id, "trial completed"); - } - Err(e) => { - #[cfg(feature = "tracing")] - let trial_id = trial.id(); - if is_trial_pruned(&e) { - self.prune_trial(trial); - trace_info!(trial_id, "trial pruned"); - } else { - self.fail_trial(trial, e.to_string()); - trace_debug!(trial_id, "trial failed"); - } - } - } - } - - let has_complete = self - .storage - .trials_arc() - .read() - .iter() - .any(|t| t.state == TrialState::Complete); - if !has_complete { - return Err(crate::Error::NoCompletedTrials); - } - - Ok(()) - } - - /// Run optimization until the given duration has elapsed, with a callback. - /// - /// Like [`optimize_until`](Self::optimize_until), but calls a callback after - /// each completed trial. The callback can stop optimization early by returning - /// `ControlFlow::Break(())`. - /// - /// # Arguments - /// - /// * `duration` - The maximum wall-clock time to spend on optimization. - /// * `objective` - A closure that takes a mutable reference to a `Trial` and - /// returns the objective value or an error. - /// * `callback` - A closure called after each successful trial. Returns - /// `ControlFlow::Continue(())` to proceed or `ControlFlow::Break(())` to stop. - /// - /// # Errors - /// - /// Returns `Error::NoCompletedTrials` if no trials completed successfully. - /// Returns `Error::Internal` if a completed trial is not found after adding. 
- /// - /// # Examples - /// - /// ``` - /// use std::ops::ControlFlow; - /// use std::time::Duration; - /// - /// use optimizer::parameter::{FloatParam, Parameter}; - /// use optimizer::sampler::random::RandomSampler; - /// use optimizer::{Direction, Study}; - /// - /// let sampler = RandomSampler::with_seed(42); - /// let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - /// - /// let x_param = FloatParam::new(-10.0, 10.0); - /// - /// study - /// .optimize_until_with_callback( - /// Duration::from_secs(1), - /// |trial| { - /// let x = x_param.suggest(trial)?; - /// Ok::<_, optimizer::Error>(x * x) - /// }, - /// |_study, completed_trial| { - /// if completed_trial.value < 1.0 { - /// ControlFlow::Break(()) - /// } else { - /// ControlFlow::Continue(()) - /// } - /// }, - /// ) - /// .unwrap(); + /// .await?; /// /// assert!(study.n_trials() > 0); + /// # Ok(()) + /// # } /// ``` - pub fn optimize_until_with_callback<F, C, E>( - &self, - duration: Duration, - mut objective: F, - mut callback: C, - ) -> crate::Result<()> + #[cfg(feature = "async")] + pub async fn optimize_async<F, E>(&self, n_trials: usize, objective: F) -> crate::Result<()> where - V: Clone + Default, - F: FnMut(&mut Trial) -> core::result::Result<V, E>, - C: FnMut(&Study<V>, &CompletedTrial<V>) -> ControlFlow<()>, - E: ToString + 'static, + F: Fn(&mut Trial) -> Result<V, E> + Send + Sync + 'static, + E: ToString + Send + 'static, + V: Clone + Default + Send + 'static, { #[cfg(feature = "tracing")] - let _span = tracing::info_span!("optimize", duration_secs = duration.as_secs(), direction = ?self.direction).entered(); - - let deadline = Instant::now() + duration; - while Instant::now() < deadline { - let mut trial = self.create_trial(); - - match objective(&mut trial) { - Ok(value) => { - #[cfg(feature = "tracing")] - let trial_id = trial.id(); - self.complete_trial(trial, value); - - #[cfg(feature = "tracing")] - { - tracing::info!(trial_id, "trial completed"); - let 
trials = self.storage.trials_arc().read(); - if trials - .iter() - .filter(|t| t.state == TrialState::Complete) - .count() - == 1 - || trials.last().map(|t| t.id) == self.best_id(&trials) - { - tracing::info!(trial_id, "new best value found"); - } - } + let _span = + tracing::info_span!("optimize_async", n_trials, direction = ?self.direction).entered(); - let trials = self.storage.trials_arc().read(); - let Some(completed) = trials.last() else { - return Err(crate::Error::Internal( - "completed trial not found after adding", - )); - }; + let objective = Arc::new(objective); - let completed_clone = completed.clone(); - drop(trials); + for _ in 0..n_trials { + let obj = Arc::clone(&objective); + let mut trial = self.create_trial(); + let result = tokio::task::spawn_blocking(move || { + let res = obj(&mut trial); + (trial, res) + }) + .await + .map_err(|e| crate::Error::TaskError(e.to_string()))?; - if let ControlFlow::Break(()) = callback(self, &completed_clone) { - break; - } + match result { + (t, Ok(value)) => { + #[cfg(feature = "tracing")] + let trial_id = t.id(); + self.complete_trial(t, value); + trace_info!(trial_id, "trial completed"); } - Err(e) => { + (t, Err(e)) if is_trial_pruned(&e) => { #[cfg(feature = "tracing")] - let trial_id = trial.id(); - if is_trial_pruned(&e) { - self.prune_trial(trial); - trace_info!(trial_id, "trial pruned"); - } else { - self.fail_trial(trial, e.to_string()); - trace_debug!(trial_id, "trial failed"); - } + let trial_id = t.id(); + self.prune_trial(t); + trace_info!(trial_id, "trial pruned"); + } + (t, Err(e)) => { + #[cfg(feature = "tracing")] + let trial_id = t.id(); + self.fail_trial(t, e.to_string()); + trace_debug!(trial_id, "trial failed"); } } } @@ -1581,49 +1192,83 @@ where Ok(()) } - /// Run optimization asynchronously until the given duration has elapsed. - /// - /// The async variant of [`optimize_until`](Self::optimize_until). 
Trials are - /// run sequentially, but the objective function can be async (useful for - /// I/O-bound evaluations). - /// - /// # Arguments + /// Run async optimization with an [`Objective`](crate::Objective) implementation. /// - /// * `duration` - The maximum wall-clock time to spend on optimization. - /// * `objective` - A function that takes a `Trial` and returns a `Future` - /// that resolves to a tuple of `(Trial, V)` or an error. + /// Like [`optimize_async`](Self::optimize_async), but accepts a struct + /// implementing [`Objective`](crate::Objective) for lifecycle hooks and + /// automatic retries. /// /// # Errors /// /// Returns `Error::NoCompletedTrials` if no trials completed successfully. + /// Returns `Error::TaskError` if a spawned blocking task panics. #[cfg(feature = "async")] - pub async fn optimize_until_async<F, Fut, E>( - &self, - duration: Duration, - objective: F, - ) -> crate::Result<()> + pub async fn optimize_with_async<O>(&self, n_trials: usize, objective: O) -> crate::Result<()> where - F: Fn(Trial) -> Fut, - Fut: Future<Output = core::result::Result<(Trial, V), E>>, - E: ToString, + O: crate::objective::Objective<V> + Send + Sync + 'static, + O::Error: Send, + V: Clone + Default + Send + 'static, { #[cfg(feature = "tracing")] - let _span = tracing::info_span!("optimize_until_async", duration_secs = duration.as_secs(), direction = ?self.direction).entered(); + let _span = + tracing::info_span!("optimize_with_async", n_trials, direction = ?self.direction) + .entered(); - let deadline = Instant::now() + duration; - while Instant::now() < deadline { - let trial = self.create_trial(); - #[cfg(feature = "tracing")] - let trial_id = trial.id(); + let objective = Arc::new(objective); + let max_retries = objective.max_retries(); - match objective(trial).await { - Ok((trial, value)) => { - self.complete_trial(trial, value); - trace_info!(trial_id, "trial completed"); - } - Err(e) => { - let _ = e.to_string(); - trace_debug!(trial_id, "trial 
failed"); + for _ in 0..n_trials { + if let ControlFlow::Break(()) = objective.before_trial(self) { + break; + } + + let mut trial = self.create_trial(); + let mut retries = 0; + loop { + let obj = Arc::clone(&objective); + let result = tokio::task::spawn_blocking(move || { + let res = obj.evaluate(&mut trial); + (trial, res) + }) + .await + .map_err(|e| crate::Error::TaskError(e.to_string()))?; + + match result { + (t, Ok(value)) => { + #[cfg(feature = "tracing")] + let trial_id = t.id(); + self.complete_trial(t, value); + trace_info!(trial_id, "trial completed"); + + // Fire after_trial hook + let trials = self.storage.trials_arc().read(); + if let Some(completed) = trials.last() { + let completed_clone = completed.clone(); + drop(trials); + if let ControlFlow::Break(()) = + objective.after_trial(self, &completed_clone) + { + return Ok(()); + } + } + break; + } + (t, Err(e)) if !is_trial_pruned(&e) && retries < max_retries => { + retries += 1; + trial = self.create_trial_with_params(t.params().clone()); + } + (t, Err(e)) => { + #[cfg(feature = "tracing")] + let trial_id = t.id(); + if is_trial_pruned(&e) { + self.prune_trial(t); + trace_info!(trial_id, "trial pruned"); + } else { + self.fail_trial(t, e.to_string()); + trace_debug!(trial_id, "trial failed"); + } + break; + } } } } @@ -1641,79 +1286,138 @@ where Ok(()) } - /// Run optimization with bounded parallelism until the given duration has elapsed. + /// Run parallel optimization with a closure. /// - /// The parallel variant of [`optimize_until`](Self::optimize_until). Runs up to - /// `concurrency` trials simultaneously using async tasks. New trials are spawned - /// as long as the deadline has not been reached; trials already running when the - /// deadline passes will complete. + /// Spawns up to `concurrency` evaluations concurrently using + /// [`spawn_blocking`](tokio::task::spawn_blocking). Results are + /// collected via a [`JoinSet`](tokio::task::JoinSet). 
/// - /// # Arguments - /// - /// * `duration` - The maximum wall-clock time to spend spawning new trials. - /// * `concurrency` - The maximum number of trials to run simultaneously. - /// * `objective` - A function that takes a `Trial` and returns a `Future` - /// that resolves to a tuple of `(Trial, V)` or an error. + /// For lifecycle hooks, use + /// [`optimize_with_parallel`](Self::optimize_with_parallel). /// /// # Errors /// /// Returns `Error::NoCompletedTrials` if no trials completed successfully. /// Returns `Error::TaskError` if the semaphore is closed or a spawned task panics. + /// + /// # Examples + /// + /// ``` + /// use optimizer::parameter::{FloatParam, Parameter}; + /// use optimizer::sampler::random::RandomSampler; + /// use optimizer::{Direction, Study}; + /// + /// # #[cfg(feature = "async")] + /// # async fn example() -> optimizer::Result<()> { + /// let sampler = RandomSampler::with_seed(42); + /// let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + /// let x_param = FloatParam::new(-10.0, 10.0); + /// + /// study + /// .optimize_parallel(10, 4, move |trial: &mut optimizer::Trial| { + /// let x = x_param.suggest(trial)?; + /// Ok::<_, optimizer::Error>(x * x) + /// }) + /// .await?; + /// + /// assert_eq!(study.n_trials(), 10); + /// # Ok(()) + /// # } + /// ``` #[cfg(feature = "async")] - pub async fn optimize_until_parallel<F, Fut, E>( + #[allow(clippy::missing_panics_doc)] + pub async fn optimize_parallel<F, E>( &self, - duration: Duration, + n_trials: usize, concurrency: usize, objective: F, ) -> crate::Result<()> where - F: Fn(Trial) -> Fut + Send + Sync + 'static, - Fut: Future<Output = core::result::Result<(Trial, V), E>> + Send, + F: Fn(&mut Trial) -> Result<V, E> + Send + Sync + 'static, E: ToString + Send + 'static, - V: Send + 'static, + V: Clone + Default + Send + 'static, { use tokio::sync::Semaphore; + use tokio::task::JoinSet; #[cfg(feature = "tracing")] - let _span = 
tracing::info_span!("optimize_until_parallel", duration_secs = duration.as_secs(), concurrency, direction = ?self.direction).entered(); + let _span = tracing::info_span!("optimize_parallel", n_trials, concurrency, direction = ?self.direction).entered(); - let deadline = Instant::now() + duration; - let semaphore = Arc::new(Semaphore::new(concurrency)); let objective = Arc::new(objective); + let semaphore = Arc::new(Semaphore::new(concurrency)); + let mut join_set: JoinSet<(Trial, Result<V, E>)> = JoinSet::new(); + let mut spawned = 0; + + while spawned < n_trials { + // If the join set is full, drain one result to free a slot. + while join_set.len() >= concurrency { + let result = join_set + .join_next() + .await + .expect("join_set should not be empty") + .map_err(|e| crate::Error::TaskError(e.to_string()))?; + match result { + (t, Ok(value)) => { + #[cfg(feature = "tracing")] + let trial_id = t.id(); + self.complete_trial(t, value); + trace_info!(trial_id, "trial completed"); + } + (t, Err(e)) => { + #[cfg(feature = "tracing")] + let trial_id = t.id(); + if is_trial_pruned(&e) { + self.prune_trial(t); + trace_info!(trial_id, "trial pruned"); + } else { + self.fail_trial(t, e.to_string()); + trace_debug!(trial_id, "trial failed"); + } + } + } + } - let mut handles = Vec::new(); - - while Instant::now() < deadline { let permit = semaphore .clone() .acquire_owned() .await .map_err(|e| crate::Error::TaskError(e.to_string()))?; - let trial = self.create_trial(); - let objective = Arc::clone(&objective); - let handle = tokio::spawn(async move { - let result = objective(trial).await; + let mut trial = self.create_trial(); + let obj = Arc::clone(&objective); + join_set.spawn(async move { + let result = tokio::task::spawn_blocking(move || { + let res = obj(&mut trial); + (trial, res) + }) + .await + .expect("spawn_blocking should not panic"); drop(permit); result }); - - handles.push(handle); + spawned += 1; } - for handle in handles { - match handle - .await - 
.map_err(|e| crate::Error::TaskError(e.to_string()))? - { - Ok((trial, value)) => { + // Drain remaining in-flight tasks. + while let Some(result) = join_set.join_next().await { + let result = result.map_err(|e| crate::Error::TaskError(e.to_string()))?; + match result { + (t, Ok(value)) => { #[cfg(feature = "tracing")] - let trial_id = trial.id(); - self.complete_trial(trial, value); + let trial_id = t.id(); + self.complete_trial(t, value); trace_info!(trial_id, "trial completed"); } - Err(e) => { - let _ = e.to_string(); + (t, Err(e)) => { + #[cfg(feature = "tracing")] + let trial_id = t.id(); + if is_trial_pruned(&e) { + self.prune_trial(t); + trace_info!(trial_id, "trial pruned"); + } else { + self.fail_trial(t, e.to_string()); + trace_debug!(trial_id, "trial failed"); + } } } } @@ -1731,102 +1435,138 @@ where Ok(()) } - /// Run optimization with automatic retry for failed trials. - /// - /// If the objective function returns an error, the same parameter - /// configuration is retried up to `max_retries` times. Only after all - /// retries are exhausted is the trial recorded as permanently failed. + /// Run parallel optimization with an [`Objective`](crate::Objective) implementation. /// - /// `n_trials` counts unique parameter configurations, not total - /// evaluations. A trial retried 3 times still counts as 1 toward the - /// `n_trials` limit. - /// - /// # Arguments - /// - /// * `n_trials` - The number of unique configurations to evaluate. - /// * `max_retries` - Maximum retry attempts per failed trial. - /// * `objective` - A closure that takes a mutable reference to a `Trial` - /// and returns the objective value or an error. + /// Like [`optimize_parallel`](Self::optimize_parallel), but accepts a struct + /// implementing [`Objective`](crate::Objective) for lifecycle hooks and + /// automatic retries. 
The [`after_trial`](crate::Objective::after_trial) + /// hook fires as each result arrives — returning `Break` stops spawning + /// new trials while in-flight tasks drain. /// /// # Errors /// /// Returns `Error::NoCompletedTrials` if no trials completed successfully. - /// - /// # Examples - /// - /// ``` - /// use optimizer::parameter::{FloatParam, Parameter}; - /// use optimizer::sampler::random::RandomSampler; - /// use optimizer::{Direction, Study}; - /// - /// let sampler = RandomSampler::with_seed(42); - /// let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - /// let x_param = FloatParam::new(-10.0, 10.0); - /// - /// let call_count = std::cell::Cell::new(0u32); - /// study - /// .optimize_with_retries(5, 2, |trial| { - /// let x = x_param.suggest(trial)?; - /// call_count.set(call_count.get() + 1); - /// // Fail once every other call to exercise retry - /// if call_count.get() % 2 == 0 { - /// Err::<f64, _>(optimizer::Error::Internal("transient")) - /// } else { - /// Ok(x * x) - /// } - /// }) - /// .unwrap(); - /// - /// assert_eq!(study.n_trials(), 5); - /// ``` - pub fn optimize_with_retries<F, E>( + /// Returns `Error::TaskError` if the semaphore is closed or a spawned task panics. 
+ #[cfg(feature = "async")] + #[allow(clippy::missing_panics_doc, clippy::too_many_lines)] + pub async fn optimize_with_parallel<O>( &self, n_trials: usize, - max_retries: usize, - mut objective: F, + concurrency: usize, + objective: O, ) -> crate::Result<()> where - F: FnMut(&mut Trial) -> core::result::Result<V, E>, - E: ToString + 'static, - V: Default, + O: crate::objective::Objective<V> + Send + Sync + 'static, + O::Error: Send, + V: Clone + Default + Send + 'static, { + use tokio::sync::Semaphore; + use tokio::task::JoinSet; + #[cfg(feature = "tracing")] - let _span = tracing::info_span!("optimize_with_retries", n_trials, max_retries, direction = ?self.direction).entered(); + let _span = tracing::info_span!("optimize_with_parallel", n_trials, concurrency, direction = ?self.direction).entered(); - for _ in 0..n_trials { - let mut trial = self.create_trial(); - let mut retries = 0; - loop { - match objective(&mut trial) { - Ok(value) => { + let objective = Arc::new(objective); + let semaphore = Arc::new(Semaphore::new(concurrency)); + let mut join_set: JoinSet<(Trial, Result<V, O::Error>)> = JoinSet::new(); + let mut spawned = 0; + + 'spawn: while spawned < n_trials { + if let ControlFlow::Break(()) = objective.before_trial(self) { + break; + } + + // If the join set is full, drain one result to free a slot. 
+ while join_set.len() >= concurrency { + let result = join_set + .join_next() + .await + .expect("join_set should not be empty") + .map_err(|e| crate::Error::TaskError(e.to_string()))?; + match result { + (t, Ok(value)) => { #[cfg(feature = "tracing")] - let trial_id = trial.id(); - self.complete_trial(trial, value); + let trial_id = t.id(); + self.complete_trial(t, value); trace_info!(trial_id, "trial completed"); - break; - } - Err(_) if retries < max_retries => { - retries += 1; - // Create a new trial with the same parameters - trial = self.create_trial_with_params(trial.params().clone()); + + let trials = self.storage.trials_arc().read(); + if let Some(completed) = trials.last() { + let completed_clone = completed.clone(); + drop(trials); + if let ControlFlow::Break(()) = + objective.after_trial(self, &completed_clone) + { + break 'spawn; + } + } } - Err(e) => { + (t, Err(e)) => { #[cfg(feature = "tracing")] - let trial_id = trial.id(); + let trial_id = t.id(); if is_trial_pruned(&e) { - self.prune_trial(trial); + self.prune_trial(t); trace_info!(trial_id, "trial pruned"); } else { - self.fail_trial(trial, e.to_string()); - trace_debug!(trial_id, "trial permanently failed"); + self.fail_trial(t, e.to_string()); + trace_debug!(trial_id, "trial failed"); } - break; + } + } + } + + let permit = semaphore + .clone() + .acquire_owned() + .await + .map_err(|e| crate::Error::TaskError(e.to_string()))?; + + let mut trial = self.create_trial(); + let obj = Arc::clone(&objective); + join_set.spawn(async move { + let result = tokio::task::spawn_blocking(move || { + let res = obj.evaluate(&mut trial); + (trial, res) + }) + .await + .expect("spawn_blocking should not panic"); + drop(permit); + result + }); + spawned += 1; + } + + // Drain remaining in-flight tasks. 
+ while let Some(result) = join_set.join_next().await { + let result = result.map_err(|e| crate::Error::TaskError(e.to_string()))?; + match result { + (t, Ok(value)) => { + #[cfg(feature = "tracing")] + let trial_id = t.id(); + self.complete_trial(t, value); + trace_info!(trial_id, "trial completed"); + // Still fire after_trial for bookkeeping, but don't break — we're draining. + let trials = self.storage.trials_arc().read(); + if let Some(completed) = trials.last() { + let completed_clone = completed.clone(); + drop(trials); + let _ = objective.after_trial(self, &completed_clone); + } + } + (t, Err(e)) => { + #[cfg(feature = "tracing")] + let trial_id = t.id(); + if is_trial_pruned(&e) { + self.prune_trial(t); + trace_info!(trial_id, "trial pruned"); + } else { + self.fail_trial(t, e.to_string()); + trace_debug!(trial_id, "trial failed"); } } } } - // Return error if no trials completed successfully let has_complete = self .storage .trials_arc() @@ -2324,111 +2064,6 @@ where } } -// Specialized implementation for Study<f64> that provides deprecated `_with_sampler` aliases. -// -// For Study<f64>, the generic methods from `impl<V> Study<V>` (like `optimize()`, -// `create_trial()`) now automatically use the sampler via the `trial_factory`. -// The `_with_sampler` method names are deprecated in favor of the generic names. -#[allow(clippy::missing_errors_doc)] -impl Study<f64> { - /// Deprecated: use `create_trial()` instead. - /// - /// The generic `create_trial()` now automatically integrates with the sampler - /// for `Study<f64>`. - #[deprecated( - since = "0.2.0", - note = "use `create_trial()` instead — it now uses the sampler automatically for Study<f64>" - )] - #[must_use] - pub fn create_trial_with_sampler(&self) -> Trial { - self.create_trial() - } - - /// Deprecated: use `optimize()` instead. - /// - /// The generic `optimize()` now automatically integrates with the sampler - /// for `Study<f64>`. 
- #[deprecated( - since = "0.2.0", - note = "use `optimize()` instead — it now uses the sampler automatically for Study<f64>" - )] - pub fn optimize_with_sampler<F, E>(&self, n_trials: usize, objective: F) -> crate::Result<()> - where - F: FnMut(&mut Trial) -> core::result::Result<f64, E>, - E: ToString + 'static, - { - self.optimize(n_trials, objective) - } - - /// Deprecated: use `optimize_with_callback()` instead. - /// - /// The generic `optimize_with_callback()` now automatically integrates with the - /// sampler for `Study<f64>`. - #[deprecated( - since = "0.2.0", - note = "use `optimize_with_callback()` instead — it now uses the sampler automatically for Study<f64>" - )] - pub fn optimize_with_callback_sampler<F, C, E>( - &self, - n_trials: usize, - objective: F, - callback: C, - ) -> crate::Result<()> - where - F: FnMut(&mut Trial) -> core::result::Result<f64, E>, - C: FnMut(&Study<f64>, &CompletedTrial<f64>) -> ControlFlow<()>, - E: ToString + 'static, - { - self.optimize_with_callback(n_trials, objective, callback) - } - - /// Deprecated: use `optimize_async()` instead. - /// - /// The generic `optimize_async()` now automatically integrates with the sampler - /// for `Study<f64>`. - #[cfg(feature = "async")] - #[deprecated( - since = "0.2.0", - note = "use `optimize_async()` instead — it now uses the sampler automatically for Study<f64>" - )] - pub async fn optimize_async_with_sampler<F, Fut, E>( - &self, - n_trials: usize, - objective: F, - ) -> crate::Result<()> - where - F: Fn(Trial) -> Fut, - Fut: Future<Output = core::result::Result<(Trial, f64), E>>, - E: ToString, - { - self.optimize_async(n_trials, objective).await - } - - /// Deprecated: use `optimize_parallel()` instead. - /// - /// The generic `optimize_parallel()` now automatically integrates with the - /// sampler for `Study<f64>`. 
- #[cfg(feature = "async")] - #[deprecated( - since = "0.2.0", - note = "use `optimize_parallel()` instead — it now uses the sampler automatically for Study<f64>" - )] - pub async fn optimize_parallel_with_sampler<F, Fut, E>( - &self, - n_trials: usize, - concurrency: usize, - objective: F, - ) -> crate::Result<()> - where - F: Fn(Trial) -> Fut + Send + Sync + 'static, - Fut: Future<Output = core::result::Result<(Trial, f64), E>> + Send, - E: ToString + Send + 'static, - { - self.optimize_parallel(n_trials, concurrency, objective) - .await - } -} - impl<V: PartialOrd + Send + Sync + 'static> Study<V> { /// Create a study with a custom sampler, pruner, and storage backend. /// @@ -2719,40 +2354,6 @@ impl<V: PartialOrd + Clone + serde::Serialize> Study<V> { } } -#[cfg(feature = "serde")] -impl<V: PartialOrd + Clone + Default + serde::Serialize> Study<V> { - /// Run optimization with automatic checkpointing every `interval` trials. - /// - /// This is convenience sugar over [`optimize_with_callback`](Self::optimize_with_callback) - /// combined with [`save`](Self::save). The checkpoint is written atomically so - /// a crash mid-write will never leave a corrupt file. - /// - /// # Errors - /// - /// Returns an error if the optimization itself fails (see - /// [`optimize`](Self::optimize) for details). Checkpoint I/O errors are - /// silently ignored (best-effort). 
- pub fn optimize_with_checkpoint<F, E>( - &self, - n_trials: usize, - checkpoint_interval: usize, - checkpoint_path: impl AsRef<std::path::Path>, - objective: F, - ) -> crate::Result<()> - where - F: FnMut(&mut Trial) -> core::result::Result<V, E>, - E: ToString + 'static, - { - let path = checkpoint_path.as_ref().to_owned(); - self.optimize_with_callback(n_trials, objective, |study, _trial| { - if study.n_trials().is_multiple_of(checkpoint_interval) { - let _ = study.save(&path); - } - ControlFlow::Continue(()) - }) - } -} - #[cfg(feature = "serde")] impl<V: PartialOrd + Send + Sync + Clone + serde::de::DeserializeOwned + 'static> Study<V> { /// Load a study from a JSON file. diff --git a/tests/async_tests.rs b/tests/async_tests.rs index b99b1df..ba9754f 100644 --- a/tests/async_tests.rs +++ b/tests/async_tests.rs @@ -17,12 +17,9 @@ async fn test_optimize_async_basic() { let x_param = FloatParam::new(-10.0, 10.0); study - .optimize_async(10, move |mut trial| { - let x_param = x_param.clone(); - async move { - let x = x_param.suggest(&mut trial)?; - Ok::<_, Error>((trial, x * x)) - } + .optimize_async(10, move |trial: &mut optimizer::Trial| { + let x = x_param.suggest(trial)?; + Ok::<_, Error>(x * x) }) .await .expect("async optimization should succeed"); @@ -45,12 +42,9 @@ async fn test_optimize_async_with_tpe() { let x_param = FloatParam::new(-5.0, 5.0); study - .optimize_async(15, move |mut trial| { - let x_param = x_param.clone(); - async move { - let x = x_param.suggest(&mut trial)?; - Ok::<_, Error>((trial, x * x)) - } + .optimize_async(15, move |trial: &mut optimizer::Trial| { + let x = x_param.suggest(trial)?; + Ok::<_, Error>(x * x) }) .await .expect("async optimization with sampler should succeed"); @@ -68,12 +62,9 @@ async fn test_optimize_parallel() { let x_param = FloatParam::new(-10.0, 10.0); study - .optimize_parallel(20, 4, move |mut trial| { - let x_param = x_param.clone(); - async move { - let x = x_param.suggest(&mut trial)?; - Ok::<_, 
Error>((trial, x * x)) - } + .optimize_parallel(20, 4, move |trial: &mut optimizer::Trial| { + let x = x_param.suggest(trial)?; + Ok::<_, Error>(x * x) }) .await .expect("parallel optimization should succeed"); @@ -95,14 +86,10 @@ async fn test_optimize_parallel_with_tpe() { let y_param = FloatParam::new(-5.0, 5.0); study - .optimize_parallel(15, 3, move |mut trial| { - let x_param = x_param.clone(); - let y_param = y_param.clone(); - async move { - let x = x_param.suggest(&mut trial)?; - let y = y_param.suggest(&mut trial)?; - Ok::<_, Error>((trial, x * x + y * y)) - } + .optimize_parallel(15, 3, move |trial: &mut optimizer::Trial| { + let x = x_param.suggest(trial)?; + let y = y_param.suggest(trial)?; + Ok::<_, Error>(x * x + y * y) }) .await .expect("parallel optimization with sampler should succeed"); @@ -115,27 +102,8 @@ async fn test_optimize_async_all_failures() { let study: Study<f64> = Study::new(Direction::Minimize); let result = study - .optimize_async(5, |trial| async move { - let _ = trial; - Err::<(_, f64), &str>("always fails") - }) - .await; - - assert!( - matches!(result, Err(Error::NoCompletedTrials)), - "should return NoCompletedTrials when all trials fail" - ); -} - -#[tokio::test] -#[allow(deprecated)] -async fn test_optimize_async_with_sampler_all_failures() { - let study: Study<f64> = Study::new(Direction::Minimize); - - let result = study - .optimize_async_with_sampler(5, |trial| async move { - let _ = trial; - Err::<(_, f64), &str>("always fails") + .optimize_async(5, |_trial: &mut optimizer::Trial| { + Err::<f64, &str>("always fails") }) .await; @@ -150,27 +118,8 @@ async fn test_optimize_parallel_all_failures() { let study: Study<f64> = Study::new(Direction::Minimize); let result = study - .optimize_parallel(5, 2, |trial| async move { - let _ = trial; - Err::<(_, f64), &str>("always fails") - }) - .await; - - assert!( - matches!(result, Err(Error::NoCompletedTrials)), - "should return NoCompletedTrials when all trials fail" - ); -} - 
-#[tokio::test] -#[allow(deprecated)] -async fn test_optimize_parallel_with_sampler_all_failures() { - let study: Study<f64> = Study::new(Direction::Minimize); - - let result = study - .optimize_parallel_with_sampler(5, 2, |trial| async move { - let _ = trial; - Err::<(_, f64), &str>("always fails") + .optimize_parallel(5, 2, |_trial: &mut optimizer::Trial| { + Err::<f64, &str>("always fails") }) .await; @@ -190,16 +139,13 @@ async fn test_optimize_async_partial_failures() { let x_param = FloatParam::new(0.0, 10.0); study - .optimize_async(10, move |mut trial| { + .optimize_async(10, move |trial: &mut optimizer::Trial| { let count = counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - let x_param = x_param.clone(); - async move { - if count.is_multiple_of(2) { - let x = x_param.suggest(&mut trial)?; - Ok::<_, Error>((trial, x)) - } else { - Err(Error::NoCompletedTrials) // Use as error type - } + if count.is_multiple_of(2) { + let x = x_param.suggest(trial)?; + Ok::<_, Error>(x) + } else { + Err(Error::NoCompletedTrials) // Use as error type } }) .await @@ -218,12 +164,9 @@ async fn test_optimize_parallel_high_concurrency() { // Run with concurrency higher than n_trials study - .optimize_parallel(5, 10, move |mut trial| { - let x_param = x_param.clone(); - async move { - let x = x_param.suggest(&mut trial)?; - Ok::<_, Error>((trial, x)) - } + .optimize_parallel(5, 10, move |trial: &mut optimizer::Trial| { + let x = x_param.suggest(trial)?; + Ok::<_, Error>(x) }) .await .expect("should handle high concurrency"); @@ -240,12 +183,9 @@ async fn test_optimize_parallel_single_concurrency() { // Run with concurrency of 1 (sequential) study - .optimize_parallel(10, 1, move |mut trial| { - let x_param = x_param.clone(); - async move { - let x = x_param.suggest(&mut trial)?; - Ok::<_, Error>((trial, x)) - } + .optimize_parallel(10, 1, move |trial: &mut optimizer::Trial| { + let x = x_param.suggest(trial)?; + Ok::<_, Error>(x) }) .await .expect("should work with 
single concurrency"); diff --git a/tests/integration.rs b/tests/integration.rs index b895606..a231230 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -570,28 +570,36 @@ fn test_tpe_with_integer_parameters() { #[test] fn test_callback_early_stopping() { - use std::cell::Cell; use std::ops::ControlFlow; - let study: Study<f64> = Study::new(Direction::Minimize); - let trials_run = Cell::new(0); - let x_param = FloatParam::new(0.0, 10.0); + use optimizer::Objective; + use optimizer::sampler::CompletedTrial; + + struct EarlyStopAfter5 { + x_param: FloatParam, + } + + impl Objective<f64> for EarlyStopAfter5 { + type Error = Error; + fn evaluate(&self, trial: &mut Trial) -> Result<f64, Error> { + let x = self.x_param.suggest(trial)?; + Ok(x) + } + fn after_trial(&self, study: &Study<f64>, _trial: &CompletedTrial<f64>) -> ControlFlow<()> { + if study.n_trials() >= 5 { + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } + } + } + let study: Study<f64> = Study::new(Direction::Minimize); study - .optimize_with_callback( + .optimize_with( 100, - |trial| { - trials_run.set(trials_run.get() + 1); - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x) - }, - |_study, _trial| { - // Stop after 5 trials - if trials_run.get() >= 5 { - ControlFlow::Break(()) - } else { - ControlFlow::Continue(()) - } + EarlyStopAfter5 { + x_param: FloatParam::new(0.0, 10.0), }, ) .expect("optimization should succeed"); @@ -769,48 +777,10 @@ fn test_optimize_all_trials_fail() { } #[test] -fn test_optimize_with_callback_all_trials_fail() { - use std::ops::ControlFlow; - - let study: Study<f64> = Study::new(Direction::Minimize); - - let result = study.optimize_with_callback( - 5, - |_trial| Err::<f64, &str>("always fails"), - |_study, _trial| ControlFlow::Continue(()), - ); - - assert!( - matches!(result, Err(Error::NoCompletedTrials)), - "should return NoCompletedTrials when all trials fail" - ); -} - -#[test] -#[allow(deprecated)] -fn 
test_optimize_with_sampler_all_trials_fail() { - let study: Study<f64> = Study::new(Direction::Minimize); - - let result = study.optimize_with_sampler(5, |_trial| Err::<f64, &str>("always fails")); - - assert!( - matches!(result, Err(Error::NoCompletedTrials)), - "should return NoCompletedTrials when all trials fail" - ); -} - -#[test] -#[allow(deprecated)] -fn test_optimize_with_callback_sampler_all_trials_fail() { - use std::ops::ControlFlow; - +fn test_optimize_with_all_trials_fail() { let study: Study<f64> = Study::new(Direction::Minimize); - let result = study.optimize_with_callback_sampler( - 5, - |_trial| Err::<f64, &str>("always fails"), - |_study, _trial| ControlFlow::Continue(()), - ); + let result = study.optimize(5, |_trial| Err::<f64, &str>("always fails")); assert!( matches!(result, Err(Error::NoCompletedTrials)), @@ -964,27 +934,6 @@ fn test_tpe_with_step_distributions() { assert!(best.value < 100.0, "should find reasonable solution"); } -#[test] -#[allow(deprecated)] -fn test_create_trial_vs_create_trial_with_sampler() { - let sampler = RandomSampler::with_seed(42); - let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - - // create_trial() creates trial with sampler integration for Study<f64> - let trial1 = study.create_trial(); - assert_eq!(trial1.id(), 0); - - // create_trial_with_sampler() is deprecated but still works - let trial2 = study.create_trial_with_sampler(); - assert_eq!(trial2.id(), 1); - - // Both should work for suggesting parameters - let x_param = FloatParam::new(0.0, 1.0); - let mut trial3 = study.create_trial(); - let x = x_param.suggest(&mut trial3).unwrap(); - assert!((0.0..=1.0).contains(&x)); -} - #[test] fn test_manual_trial_completion() { let study: Study<f64> = Study::new(Direction::Minimize); @@ -1058,19 +1007,34 @@ fn test_tpe_empty_good_or_bad_values_fallback() { fn test_callback_early_stopping_on_first_trial() { use std::ops::ControlFlow; - let study: Study<f64> = 
Study::new(Direction::Minimize); - let x_param = FloatParam::new(0.0, 10.0); + use optimizer::Objective; + use optimizer::sampler::CompletedTrial; + struct StopImmediately { + x_param: FloatParam, + } + + impl Objective<f64> for StopImmediately { + type Error = Error; + fn evaluate(&self, trial: &mut Trial) -> Result<f64, Error> { + let x = self.x_param.suggest(trial)?; + Ok(x) + } + fn after_trial( + &self, + _study: &Study<f64>, + _trial: &CompletedTrial<f64>, + ) -> ControlFlow<()> { + ControlFlow::Break(()) + } + } + + let study: Study<f64> = Study::new(Direction::Minimize); study - .optimize_with_callback( + .optimize_with( 100, - |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x) - }, - |_study, _trial| { - // Stop immediately after first trial - ControlFlow::Break(()) + StopImmediately { + x_param: FloatParam::new(0.0, 10.0), }, ) .expect("optimization should succeed"); @@ -1082,23 +1046,35 @@ fn test_callback_early_stopping_on_first_trial() { fn test_callback_sampler_early_stopping() { use std::ops::ControlFlow; + use optimizer::Objective; + use optimizer::sampler::CompletedTrial; + + struct StopAfter3 { + x_param: FloatParam, + } + + impl Objective<f64> for StopAfter3 { + type Error = Error; + fn evaluate(&self, trial: &mut Trial) -> Result<f64, Error> { + let x = self.x_param.suggest(trial)?; + Ok(x) + } + fn after_trial(&self, study: &Study<f64>, _trial: &CompletedTrial<f64>) -> ControlFlow<()> { + if study.n_trials() >= 3 { + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } + } + } + let sampler = RandomSampler::with_seed(42); let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - let x_param = FloatParam::new(0.0, 10.0); - study - .optimize_with_callback( + .optimize_with( 100, - |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x) - }, - |study, _trial| { - if study.n_trials() >= 3 { - ControlFlow::Break(()) - } else { - ControlFlow::Continue(()) - } + StopAfter3 { + x_param: 
FloatParam::new(0.0, 10.0), }, ) .expect("optimization should succeed"); @@ -1364,165 +1340,6 @@ fn test_completed_trial_get() { assert!((1..=10).contains(&n_val)); } -// ============================================================================= -// Tests for timeout-based optimization -// ============================================================================= - -#[test] -fn test_optimize_until_runs_for_approximately_specified_duration() { - use std::time::{Duration, Instant}; - - let study: Study<f64> = Study::new(Direction::Minimize); - let x_param = FloatParam::new(-10.0, 10.0); - - let duration = Duration::from_millis(200); - let start = Instant::now(); - - study - .optimize_until(duration, |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x * x) - }) - .unwrap(); - - let elapsed = start.elapsed(); - assert!( - elapsed >= duration, - "should run for at least the specified duration, elapsed: {elapsed:?}" - ); - // Allow generous upper bound — the last trial may overshoot - assert!( - elapsed < duration + Duration::from_millis(200), - "should not overshoot excessively, elapsed: {elapsed:?}" - ); -} - -#[test] -fn test_optimize_until_completes_at_least_one_trial() { - use std::time::Duration; - - let study: Study<f64> = Study::new(Direction::Minimize); - let x_param = FloatParam::new(-10.0, 10.0); - - study - .optimize_until(Duration::from_millis(100), |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x * x) - }) - .unwrap(); - - assert!( - study.n_trials() >= 1, - "should complete at least one trial, got {}", - study.n_trials() - ); -} - -#[test] -fn test_optimize_until_works_with_minimize() { - use std::time::Duration; - - let sampler = RandomSampler::with_seed(42); - let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - let x_param = FloatParam::new(-10.0, 10.0); - - study - .optimize_until(Duration::from_millis(100), |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x * x) - }) - .unwrap(); - 
- let best = study.best_value().unwrap(); - assert!(best >= 0.0, "x^2 should be non-negative"); -} - -#[test] -fn test_optimize_until_works_with_maximize() { - use std::time::Duration; - - let sampler = RandomSampler::with_seed(42); - let study: Study<f64> = Study::with_sampler(Direction::Maximize, sampler); - let x_param = FloatParam::new(0.0, 10.0); - - study - .optimize_until(Duration::from_millis(100), |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x) - }) - .unwrap(); - - let best = study.best_value().unwrap(); - assert!(best >= 0.0); -} - -#[test] -fn test_optimize_until_with_callback_early_stopping() { - use std::ops::ControlFlow; - use std::time::Duration; - - let study: Study<f64> = Study::new(Direction::Minimize); - let x_param = FloatParam::new(0.0, 10.0); - - study - .optimize_until_with_callback( - Duration::from_secs(10), // long timeout — callback should stop early - |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x) - }, - |study, _trial| { - if study.n_trials() >= 5 { - ControlFlow::Break(()) - } else { - ControlFlow::Continue(()) - } - }, - ) - .unwrap(); - - assert_eq!( - study.n_trials(), - 5, - "callback should have stopped after 5 trials" - ); -} - -#[test] -fn test_optimize_until_all_trials_fail() { - use std::time::Duration; - - let study: Study<f64> = Study::new(Direction::Minimize); - - let result = study.optimize_until(Duration::from_millis(50), |_trial| { - Err::<f64, &str>("always fails") - }); - - assert!( - matches!(result, Err(Error::NoCompletedTrials)), - "should return NoCompletedTrials when all trials fail" - ); -} - -#[test] -fn test_optimize_until_with_non_f64_value_type() { - use std::time::Duration; - - let study: Study<i32> = Study::new(Direction::Minimize); - let x_param = IntParam::new(-10, 10); - - study - .optimize_until(Duration::from_millis(100), |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x.abs() as i32) - }) - .unwrap(); - - assert!(study.n_trials() >= 1); - let best = 
study.best_trial().unwrap(); - assert!(best.value >= 0); -} - // ============================================================================= // Tests for top_trials // ============================================================================= @@ -1966,55 +1783,111 @@ fn test_display_matches_summary() { } // ============================================================================= -// Tests: optimize_with_retries +// Tests: optimize_with retries via Objective trait // ============================================================================= #[test] fn test_retries_successful_trials_not_retried() { + use std::sync::Arc; + use std::sync::atomic::{AtomicU32, Ordering}; + + use optimizer::Objective; + + struct SuccessObj { + x_param: FloatParam, + call_count: Arc<AtomicU32>, + } + + impl Objective<f64> for SuccessObj { + type Error = Error; + fn evaluate(&self, trial: &mut Trial) -> Result<f64, Error> { + let x = self.x_param.suggest(trial)?; + self.call_count.fetch_add(1, Ordering::Relaxed); + Ok(x * x) + } + fn max_retries(&self) -> usize { + 3 + } + } + let study: Study<f64> = Study::new(Direction::Minimize); - let x_param = FloatParam::new(0.0, 10.0); - let call_count = std::cell::Cell::new(0u32); + let call_count = Arc::new(AtomicU32::new(0)); + let obj = SuccessObj { + x_param: FloatParam::new(0.0, 10.0), + call_count: Arc::clone(&call_count), + }; - study - .optimize_with_retries(5, 3, |trial| { - let x = x_param.suggest(trial)?; - call_count.set(call_count.get() + 1); - Ok::<_, Error>(x * x) - }) - .unwrap(); + study.optimize_with(5, obj).unwrap(); // All trials succeed on first try — exactly 5 calls - assert_eq!(call_count.get(), 5); + assert_eq!(call_count.load(Ordering::Relaxed), 5); assert_eq!(study.n_trials(), 5); } #[test] fn test_retries_failed_trials_retried_up_to_max() { + use std::sync::Arc; + use std::sync::atomic::{AtomicU32, Ordering}; + + use optimizer::Objective; + + struct AlwaysFailObj { + x_param: FloatParam, + call_count: 
Arc<AtomicU32>, + } + + impl Objective<f64> for AlwaysFailObj { + type Error = String; + fn evaluate(&self, trial: &mut Trial) -> Result<f64, String> { + let _ = self.x_param.suggest(trial).map_err(|e| e.to_string())?; + self.call_count.fetch_add(1, Ordering::Relaxed); + Err("always fails".to_string()) + } + fn max_retries(&self) -> usize { + 3 + } + } + let study: Study<f64> = Study::new(Direction::Minimize); - let x_param = FloatParam::new(0.0, 10.0); - let call_count = std::cell::Cell::new(0u32); + let call_count = Arc::new(AtomicU32::new(0)); + let obj = AlwaysFailObj { + x_param: FloatParam::new(0.0, 10.0), + call_count: Arc::clone(&call_count), + }; - let result = study.optimize_with_retries(1, 3, |trial| { - let _ = x_param.suggest(trial).unwrap(); - call_count.set(call_count.get() + 1); - Err::<f64, _>("always fails") - }); + let result = study.optimize_with(1, obj); // 1 initial attempt + 3 retries = 4 total calls - assert_eq!(call_count.get(), 4); + assert_eq!(call_count.load(Ordering::Relaxed), 4); // No trials completed assert!(matches!(result, Err(Error::NoCompletedTrials))); } #[test] fn test_retries_permanently_failed_after_exhaustion() { + use optimizer::Objective; + + struct AlwaysFailObj { + x_param: FloatParam, + } + + impl Objective<f64> for AlwaysFailObj { + type Error = String; + fn evaluate(&self, trial: &mut Trial) -> Result<f64, String> { + let _ = self.x_param.suggest(trial).map_err(|e| e.to_string())?; + Err("transient error".to_string()) + } + fn max_retries(&self) -> usize { + 2 + } + } + let study: Study<f64> = Study::new(Direction::Minimize); - let x_param = FloatParam::new(0.0, 10.0); + let obj = AlwaysFailObj { + x_param: FloatParam::new(0.0, 10.0), + }; - let result = study.optimize_with_retries(3, 2, |trial| { - let _ = x_param.suggest(trial).unwrap(); - Err::<f64, _>("transient error") - }); + let result = study.optimize_with(3, obj); assert!( matches!(result, Err(Error::NoCompletedTrials)), @@ -2029,26 +1902,47 @@ fn 
test_retries_permanently_failed_after_exhaustion() { #[test] fn test_retries_uses_same_parameters() { - let study: Study<f64> = Study::new(Direction::Minimize); - let x_param = FloatParam::new(0.0, 10.0); - let seen_values = std::cell::RefCell::new(Vec::new()); - let call_count = std::cell::Cell::new(0u32); + use std::sync::atomic::{AtomicU32, Ordering}; + use std::sync::{Arc, Mutex}; - study - .optimize_with_retries(1, 2, |trial| { - let x = x_param.suggest(trial).map_err(|e| e.to_string())?; - seen_values.borrow_mut().push(x); - call_count.set(call_count.get() + 1); + use optimizer::Objective; + + struct RetryObj { + x_param: FloatParam, + seen_values: Arc<Mutex<Vec<f64>>>, + call_count: Arc<AtomicU32>, + } + + impl Objective<f64> for RetryObj { + type Error = String; + fn evaluate(&self, trial: &mut Trial) -> Result<f64, String> { + let x = self.x_param.suggest(trial).map_err(|e| e.to_string())?; + self.seen_values.lock().unwrap().push(x); + let count = self.call_count.fetch_add(1, Ordering::Relaxed) + 1; // Fail first two attempts, succeed on third - if call_count.get() < 3 { - Err::<f64, _>("transient".to_string()) + if count < 3 { + Err("transient".to_string()) } else { Ok(x * x) } - }) - .unwrap(); + } + fn max_retries(&self) -> usize { + 2 + } + } + + let study: Study<f64> = Study::new(Direction::Minimize); + let seen_values = Arc::new(Mutex::new(Vec::new())); + let call_count = Arc::new(AtomicU32::new(0)); + let obj = RetryObj { + x_param: FloatParam::new(0.0, 10.0), + seen_values: Arc::clone(&seen_values), + call_count: Arc::clone(&call_count), + }; + + study.optimize_with(1, obj).unwrap(); - let values = seen_values.borrow(); + let values = seen_values.lock().unwrap(); assert_eq!(values.len(), 3, "should be called 3 times (1 + 2 retries)"); // All three calls should have gotten the same parameter value assert_eq!(values[0], values[1]); @@ -2057,25 +1951,44 @@ fn test_retries_uses_same_parameters() { #[test] fn 
test_retries_n_trials_counts_unique_configs() { - let study: Study<f64> = Study::new(Direction::Minimize); - let x_param = FloatParam::new(0.0, 10.0); - let call_count = std::cell::Cell::new(0u32); + use std::sync::Arc; + use std::sync::atomic::{AtomicU32, Ordering}; - study - .optimize_with_retries(3, 2, |trial| { - let x = x_param.suggest(trial).map_err(|e| e.to_string())?; - call_count.set(call_count.get() + 1); + use optimizer::Objective; + + struct FailFirstObj { + x_param: FloatParam, + call_count: Arc<AtomicU32>, + } + + impl Objective<f64> for FailFirstObj { + type Error = String; + fn evaluate(&self, trial: &mut Trial) -> Result<f64, String> { + let x = self.x_param.suggest(trial).map_err(|e| e.to_string())?; + let count = self.call_count.fetch_add(1, Ordering::Relaxed) + 1; // Fail first attempt of each config, succeed on retry - if call_count.get() % 2 == 1 { - Err::<f64, _>("transient".to_string()) + if count % 2 == 1 { + Err("transient".to_string()) } else { Ok(x * x) } - }) - .unwrap(); + } + fn max_retries(&self) -> usize { + 2 + } + } + + let study: Study<f64> = Study::new(Direction::Minimize); + let call_count = Arc::new(AtomicU32::new(0)); + let obj = FailFirstObj { + x_param: FloatParam::new(0.0, 10.0), + call_count: Arc::clone(&call_count), + }; + + study.optimize_with(3, obj).unwrap(); // 3 unique configs, each needing 2 calls = 6 total calls - assert_eq!(call_count.get(), 6); + assert_eq!(call_count.load(Ordering::Relaxed), 6); // But only 3 completed trials assert_eq!(study.n_trials(), 3); } @@ -2087,7 +2000,7 @@ fn test_retries_with_zero_max_retries_same_as_optimize() { let call_count = std::cell::Cell::new(0u32); study - .optimize_with_retries(5, 0, |trial| { + .optimize(5, |trial| { let x = x_param.suggest(trial)?; call_count.set(call_count.get() + 1); Ok::<_, Error>(x * x) diff --git a/tests/serde_tests.rs b/tests/serde_tests.rs index 202bc19..b734152 100644 --- a/tests/serde_tests.rs +++ b/tests/serde_tests.rs @@ -173,73 +173,27 @@ fn 
round_trip_preserves_trial_id_counter() { } #[test] -fn checkpoint_file_created_at_interval() { - let study: Study<f64> = Study::new(Direction::Minimize); - let x = FloatParam::new(-10.0, 10.0).name("x"); - - let dir = tempdir(); - let checkpoint = dir.join("checkpoint.json"); - - study - .optimize_with_checkpoint(10, 3, &checkpoint, |trial| { - let v = x.suggest(trial)?; - Ok::<_, optimizer::Error>(v * v) - }) - .unwrap(); - - // Checkpoint should exist (written at trials 3, 6, 9) - assert!(checkpoint.exists(), "checkpoint file was not created"); - - // Load it and verify it's valid - let loaded: Study<f64> = Study::load(&checkpoint).unwrap(); - // Last checkpoint was at trial 9, so it should have 9 trials - assert_eq!(loaded.n_trials(), 9); - - std::fs::remove_dir_all(&dir).ok(); -} - -#[test] -fn checkpoint_overwrites_previous() { - let study: Study<f64> = Study::new(Direction::Minimize); - let x = FloatParam::new(0.0, 1.0); - - let dir = tempdir(); - let checkpoint = dir.join("checkpoint.json"); - - study - .optimize_with_checkpoint(6, 3, &checkpoint, |trial| { - let v = x.suggest(trial)?; - Ok::<_, optimizer::Error>(v) - }) - .unwrap(); - - // The checkpoint at trial 6 should overwrite the one from trial 3 - let loaded: Study<f64> = Study::load(&checkpoint).unwrap(); - assert_eq!(loaded.n_trials(), 6); - - std::fs::remove_dir_all(&dir).ok(); -} - -#[test] -fn resume_from_checkpoint_continues_trial_ids() { +fn save_and_resume_continues_trial_ids() { let study: Study<f64> = Study::new(Direction::Minimize); let x = FloatParam::new(-5.0, 5.0).name("x"); let dir = tempdir(); - let checkpoint = dir.join("resume.json"); + let save_path = dir.join("resume.json"); - // Run 10 trials with checkpointing + // Run 10 trials study - .optimize_with_checkpoint(10, 5, &checkpoint, |trial| { + .optimize(10, |trial| { let v = x.suggest(trial)?; Ok::<_, optimizer::Error>(v * v) }) .unwrap(); - // Load and continue - let loaded: Study<f64> = Study::load(&checkpoint).unwrap(); + // 
Save and reload + study.save(&save_path).unwrap(); + let loaded: Study<f64> = Study::load(&save_path).unwrap(); assert_eq!(loaded.n_trials(), 10); + // Continue with 5 more trials let remaining = 15 - loaded.n_trials(); loaded .optimize(remaining, |trial| { @@ -261,24 +215,26 @@ fn resume_from_checkpoint_continues_trial_ids() { } #[test] -fn atomic_write_no_temp_file_left_behind() { +fn save_uses_atomic_write() { let study: Study<f64> = Study::new(Direction::Minimize); let x = FloatParam::new(0.0, 1.0); let dir = tempdir(); - let checkpoint = dir.join("atomic.json"); + let save_path = dir.join("atomic.json"); study - .optimize_with_checkpoint(3, 3, &checkpoint, |trial| { + .optimize(3, |trial| { let v = x.suggest(trial)?; Ok::<_, optimizer::Error>(v) }) .unwrap(); + study.save(&save_path).unwrap(); + // The temp file should have been renamed, not left behind let tmp_path = dir.join(".atomic.json.tmp"); assert!(!tmp_path.exists(), "temp file was not cleaned up"); - assert!(checkpoint.exists(), "checkpoint file was not created"); + assert!(save_path.exists(), "save file was not created"); std::fs::remove_dir_all(&dir).ok(); } From d81d1de4ff543ce8f6c4da9ac4085655d25064e1 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 12:43:44 +0100 Subject: [PATCH 06/48] refactor(tests): split integration.rs into focused subfolders - Delete 20 duplicate tests already covered by parameter_tests.rs - Move 11 pure Trial unit tests into src/trial.rs - Split remaining 84 integration tests into tests/study/ (9 modules) - Group sampler tests into tests/sampler/ (7 modules) - Group pruner tests into tests/pruner/ (2 modules) --- src/trial.rs | 161 ++ tests/integration.rs | 2252 ----------------- tests/pruner/main.rs | 2 + .../median.rs} | 0 .../threshold.rs} | 0 .../{bohb_integration.rs => sampler/bohb.rs} | 0 tests/{cma_es_tests.rs => sampler/cma_es.rs} | 2 - .../differential_evolution.rs} | 0 tests/{gp_tests.rs => sampler/gp.rs} | 2 - 
tests/sampler/main.rs | 15 + .../multivariate_tpe.rs} | 0 tests/sampler/random.rs | 142 ++ tests/sampler/tpe.rs | 381 +++ tests/study/ask_tell.rs | 98 + tests/study/builder.rs | 92 + tests/study/constraints.rs | 101 + tests/study/enqueue.rs | 189 ++ tests/study/iterator.rs | 43 + tests/study/main.rs | 15 + tests/study/objective.rs | 346 +++ tests/study/summary.rs | 71 + tests/study/top_trials.rs | 76 + tests/study/workflow.rs | 259 ++ 23 files changed, 1991 insertions(+), 2256 deletions(-) delete mode 100644 tests/integration.rs create mode 100644 tests/pruner/main.rs rename tests/{median_pruner_tests.rs => pruner/median.rs} (100%) rename tests/{threshold_pruner_tests.rs => pruner/threshold.rs} (100%) rename tests/{bohb_integration.rs => sampler/bohb.rs} (100%) rename tests/{cma_es_tests.rs => sampler/cma_es.rs} (99%) rename tests/{differential_evolution_tests.rs => sampler/differential_evolution.rs} (100%) rename tests/{gp_tests.rs => sampler/gp.rs} (99%) create mode 100644 tests/sampler/main.rs rename tests/{multivariate_tpe_integration.rs => sampler/multivariate_tpe.rs} (100%) create mode 100644 tests/sampler/random.rs create mode 100644 tests/sampler/tpe.rs create mode 100644 tests/study/ask_tell.rs create mode 100644 tests/study/builder.rs create mode 100644 tests/study/constraints.rs create mode 100644 tests/study/enqueue.rs create mode 100644 tests/study/iterator.rs create mode 100644 tests/study/main.rs create mode 100644 tests/study/objective.rs create mode 100644 tests/study/summary.rs create mode 100644 tests/study/top_trials.rs create mode 100644 tests/study/workflow.rs diff --git a/src/trial.rs b/src/trial.rs index e46077e..1ef939e 100644 --- a/src/trial.rs +++ b/src/trial.rs @@ -483,3 +483,164 @@ impl Trial { Ok(result) } } + +#[cfg(test)] +#[allow(clippy::float_cmp)] +mod tests { + use crate::parameter::{BoolParam, CategoricalParam, FloatParam, IntParam, Parameter}; + use crate::types::TrialState; + + #[test] + fn trial_state() { + // from 
test_trial_state (L643 of integration.rs) + let trial = super::Trial::new(0); + assert_eq!(trial.state(), TrialState::Running); + } + + #[test] + fn trial_params_access() { + // from test_trial_params_access (L651) + let x_param = FloatParam::new(0.0, 1.0); + let n_param = IntParam::new(1, 10); + let mut trial = super::Trial::new(0); + + x_param.suggest(&mut trial).unwrap(); + n_param.suggest(&mut trial).unwrap(); + + let params = trial.params(); + assert_eq!(params.len(), 2); + } + + #[test] + fn trial_debug_format() { + // from test_trial_debug_format (L792) + let param = FloatParam::new(0.0, 1.0); + let mut trial = super::Trial::new(42); + param.suggest(&mut trial).unwrap(); + + let debug_str = format!("{trial:?}"); + + assert!(debug_str.contains("Trial")); + assert!(debug_str.contains("42")); + assert!(debug_str.contains("has_sampler")); + } + + #[test] + fn distributions_access() { + // from test_distributions_access (L960) + let x_param = FloatParam::new(0.0, 1.0); + let n_param = IntParam::new(1, 10); + let opt_param = CategoricalParam::new(vec!["a", "b", "c"]); + let mut trial = super::Trial::new(0); + + x_param.suggest(&mut trial).unwrap(); + n_param.suggest(&mut trial).unwrap(); + opt_param.suggest(&mut trial).unwrap(); + + let dists = trial.distributions(); + assert_eq!(dists.len(), 3); + } + + #[test] + fn multiple_parameters_independent_caching() { + // from test_multiple_parameters_independent_caching (L356) + let x_param = FloatParam::new(0.0, 1.0); + let y_param = FloatParam::new(0.0, 1.0); + let n_param = IntParam::new(1, 10); + let opt_param = CategoricalParam::new(vec!["a", "b"]); + let mut trial = super::Trial::new(0); + + let x = x_param.suggest(&mut trial).unwrap(); + let y = y_param.suggest(&mut trial).unwrap(); + let n = n_param.suggest(&mut trial).unwrap(); + let opt = opt_param.suggest(&mut trial).unwrap(); + + assert_eq!(x, x_param.suggest(&mut trial).unwrap()); + assert_eq!(y, y_param.suggest(&mut trial).unwrap()); + assert_eq!(n, 
n_param.suggest(&mut trial).unwrap()); + assert_eq!(opt, opt_param.suggest(&mut trial).unwrap()); + } + + #[test] + fn suggest_bool_multiple_parameters() { + // from test_suggest_bool_multiple_parameters (L1131) + let dropout_param = BoolParam::new(); + let batchnorm_param = BoolParam::new(); + let skip_param = BoolParam::new(); + let mut trial = super::Trial::new(0); + + let a = dropout_param.suggest(&mut trial).unwrap(); + let b = batchnorm_param.suggest(&mut trial).unwrap(); + let c = skip_param.suggest(&mut trial).unwrap(); + + assert_eq!(a, dropout_param.suggest(&mut trial).unwrap()); + assert_eq!(b, batchnorm_param.suggest(&mut trial).unwrap()); + assert_eq!(c, skip_param.suggest(&mut trial).unwrap()); + } + + #[test] + fn param_name() { + // from test_param_name (L1312) + let param = FloatParam::new(0.0, 1.0).name("learning_rate"); + let mut trial = super::Trial::new(0); + param.suggest(&mut trial).unwrap(); + + let labels = trial.param_labels(); + let label = labels.values().next().unwrap(); + assert_eq!(label, "learning_rate"); + } + + #[test] + fn step_float_snaps_to_grid() { + // from test_step_float_snaps_to_grid (L676) + let param = FloatParam::new(0.0, 1.0).step(0.25); + let mut trial = super::Trial::new(0); + + let x = param.suggest(&mut trial).unwrap(); + + let valid_values = [0.0, 0.25, 0.5, 0.75, 1.0]; + let is_valid = valid_values.iter().any(|&v| (x - v).abs() < 1e-10); + assert!(is_valid, "stepped float {x} should snap to grid"); + } + + #[test] + fn step_int_snaps_to_grid() { + // from test_step_int_snaps_to_grid (L689) + let param = IntParam::new(0, 100).step(25); + let mut trial = super::Trial::new(0); + + let n = param.suggest(&mut trial).unwrap(); + + assert!( + n % 25 == 0 && (0..=100).contains(&n), + "stepped int {n} should snap to grid" + ); + } + + #[test] + fn int_bounds_with_low_equals_high() { + // from test_int_bounds_with_low_equals_high (L1086) + let mut trial = super::Trial::new(0); + + let n_param = IntParam::new(5, 5); + let n 
= n_param.suggest(&mut trial).unwrap(); + assert_eq!(n, 5); + + let x_param = FloatParam::new(3.0, 3.0); + let x = x_param.suggest(&mut trial).unwrap(); + assert_eq!(x, 3.0); + } + + #[test] + fn single_value_float_range() { + // from test_single_value_float_range (L1296) + let param = FloatParam::new(4.2, 4.2); + let mut trial = super::Trial::new(0); + + let x = param.suggest(&mut trial).unwrap(); + assert!( + (x - 4.2).abs() < f64::EPSILON, + "single-value range should return that value" + ); + } +} diff --git a/tests/integration.rs b/tests/integration.rs deleted file mode 100644 index a231230..0000000 --- a/tests/integration.rs +++ /dev/null @@ -1,2252 +0,0 @@ -//! Integration tests for the optimizer library. - -#![allow( - clippy::cast_sign_loss, - clippy::cast_precision_loss, - clippy::cast_possible_truncation -)] - -use optimizer::parameter::{BoolParam, CategoricalParam, FloatParam, IntParam, Parameter}; -use optimizer::sampler::random::RandomSampler; -use optimizer::sampler::tpe::TpeSampler; -use optimizer::{Direction, Error, Study, Trial}; - -// ============================================================================= -// Test: optimize simple quadratic function with TPE, finds near-optimal -// ============================================================================= - -#[test] -fn test_tpe_optimizes_quadratic_function() { - // Minimize f(x) = (x - 3)^2 where x in [-10, 10] - // Optimal: x = 3, f(3) = 0 - let sampler = TpeSampler::builder() - .seed(42) - .n_startup_trials(10) - .n_ei_candidates(24) - .build() - .unwrap(); - - let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - - let x_param = FloatParam::new(-10.0, 10.0); - - study - .optimize(100, |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>((x - 3.0).powi(2)) - }) - .expect("optimization should succeed"); - - let best = study.best_trial().expect("should have at least one trial"); - - // TPE should find a reasonable value over 100 trials - // With random 
startup + TPE, we expect to get within a few units of optimal - assert!( - best.value < 5.0, - "TPE should find near-optimal: best value {} should be < 5.0", - best.value - ); -} - -#[test] -fn test_tpe_optimizes_multivariate_function() { - // Minimize f(x, y) = x^2 + y^2 where x, y in [-5, 5] - // Optimal: (0, 0), f(0, 0) = 0 - let sampler = TpeSampler::builder() - .seed(123) - .n_startup_trials(10) - .build() - .unwrap(); - - let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - - let x_param = FloatParam::new(-5.0, 5.0); - let y_param = FloatParam::new(-5.0, 5.0); - - study - .optimize(100, |trial| { - let x = x_param.suggest(trial)?; - let y = y_param.suggest(trial)?; - Ok::<_, Error>(x * x + y * y) - }) - .expect("optimization should succeed"); - - let best = study.best_trial().expect("should have at least one trial"); - - // TPE should find a reasonably good solution - assert!( - best.value < 5.0, - "TPE should find near-optimal: best value {} should be < 5.0", - best.value - ); -} - -#[test] -fn test_tpe_maximization() { - // Maximize f(x) = -(x - 2)^2 + 10 where x in [-10, 10] - // Optimal: x = 2, f(2) = 10 - let sampler = TpeSampler::builder() - .seed(456) - .n_startup_trials(5) - .build() - .unwrap(); - - let study: Study<f64> = Study::with_sampler(Direction::Maximize, sampler); - - let x_param = FloatParam::new(-10.0, 10.0); - - study - .optimize(50, |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(-(x - 2.0).powi(2) + 10.0) - }) - .expect("optimization should succeed"); - - let best = study.best_trial().expect("should have at least one trial"); - - assert!( - best.value > 5.0, - "TPE should find reasonably good solution: best value {} should be > 5.0", - best.value - ); -} - -// ============================================================================= -// Test: RandomSampler samples uniformly across range -// ============================================================================= - -#[test] -fn 
test_random_sampler_uniform_float_distribution() { - let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(42)); - - let n_samples = 1000; - let mut samples = Vec::with_capacity(n_samples); - - let x_param = FloatParam::new(0.0, 1.0); - - study - .optimize(n_samples, |trial| { - let x = x_param.suggest(trial)?; - samples.push(x); - Ok::<_, Error>(x) - }) - .unwrap(); - - // All samples should be in range - for &s in &samples { - assert!((0.0..=1.0).contains(&s), "sample {s} out of range [0, 1]"); - } - - // Check distribution is roughly uniform by looking at quartiles - samples.sort_by(|a, b| a.partial_cmp(b).unwrap()); - - let q1 = samples[n_samples / 4]; - let q2 = samples[n_samples / 2]; - let q3 = samples[3 * n_samples / 4]; - - assert!((q1 - 0.25).abs() < 0.1, "Q1 {q1} should be close to 0.25"); - assert!( - (q2 - 0.5).abs() < 0.1, - "Q2 (median) {q2} should be close to 0.5" - ); - assert!((q3 - 0.75).abs() < 0.1, "Q3 {q3} should be close to 0.75"); -} - -#[test] -fn test_random_sampler_uniform_int_distribution() { - let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(123)); - - let n_samples = 5000; - let mut counts = [0u32; 10]; // counts for values 1-10 - - let n_param = IntParam::new(1, 10); - - study - .optimize(n_samples, |trial| { - let n = n_param.suggest(trial)?; - assert!((1..=10).contains(&n), "sample {n} out of range [1, 10]"); - counts[(n - 1) as usize] += 1; - Ok::<_, Error>(n as f64) - }) - .unwrap(); - - let expected = n_samples as f64 / 10.0; - for (i, &count) in counts.iter().enumerate() { - let diff = (count as f64 - expected).abs() / expected; - assert!( - diff < 0.2, - "value {} appeared {} times, expected ~{}, diff = {:.1}%", - i + 1, - count, - expected, - diff * 100.0 - ); - } -} - -#[test] -fn test_random_sampler_uniform_categorical_distribution() { - let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(456)); - - let 
n_samples = 2000; - let mut counts = [0u32; 4]; - let choices = ["a", "b", "c", "d"]; - - let cat_param = CategoricalParam::new(choices.to_vec()); - - study - .optimize(n_samples, |trial| { - let choice = cat_param.suggest(trial)?; - let idx = choices.iter().position(|&c| c == choice).unwrap(); - counts[idx] += 1; - Ok::<_, Error>(idx as f64) - }) - .unwrap(); - - let expected = n_samples as f64 / 4.0; - for (i, &count) in counts.iter().enumerate() { - let diff = (count as f64 - expected).abs() / expected; - assert!( - diff < 0.15, - "category {} appeared {} times, expected ~{}, diff = {:.1}%", - i, - count, - expected, - diff * 100.0 - ); - } -} - -#[test] -fn test_random_sampler_reproducibility() { - let study1: Study<f64> = - Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(999)); - let study2: Study<f64> = - Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(999)); - - let mut values1 = Vec::new(); - let mut values2 = Vec::new(); - - let x_param1 = FloatParam::new(0.0, 100.0); - let x_param2 = FloatParam::new(0.0, 100.0); - - study1 - .optimize(100, |trial| { - let x = x_param1.suggest(trial)?; - values1.push(x); - Ok::<_, Error>(x) - }) - .unwrap(); - - study2 - .optimize(100, |trial| { - let x = x_param2.suggest(trial)?; - values2.push(x); - Ok::<_, Error>(x) - }) - .unwrap(); - - for (i, (v1, v2)) in values1.iter().zip(values2.iter()).enumerate() { - assert_eq!( - v1, v2, - "values at trial {i} should be identical with same seed: {v1} vs {v2}" - ); - } -} - -// ============================================================================= -// Test: suggest_param returns cached values on repeated calls -// ============================================================================= - -#[test] -fn test_suggest_float_caching() { - let param = FloatParam::new(0.0, 10.0); - let mut trial = Trial::new(0); - - let x1 = param.suggest(&mut trial).unwrap(); - let x2 = param.suggest(&mut trial).unwrap(); - let x3 = param.suggest(&mut 
trial).unwrap(); - - assert_eq!(x1, x2, "repeated suggest should return cached value"); - assert_eq!(x2, x3, "repeated suggest should return cached value"); -} - -#[test] -fn test_suggest_float_log_caching() { - let param = FloatParam::new(1e-5, 1e-1).log_scale(); - let mut trial = Trial::new(0); - - let x1 = param.suggest(&mut trial).unwrap(); - let x2 = param.suggest(&mut trial).unwrap(); - - assert_eq!( - x1, x2, - "repeated suggest float log should return cached value" - ); -} - -#[test] -fn test_suggest_float_step_caching() { - let param = FloatParam::new(0.0, 1.0).step(0.1); - let mut trial = Trial::new(0); - - let x1 = param.suggest(&mut trial).unwrap(); - let x2 = param.suggest(&mut trial).unwrap(); - - assert_eq!( - x1, x2, - "repeated suggest float step should return cached value" - ); -} - -#[test] -fn test_suggest_int_caching() { - let param = IntParam::new(1, 100); - let mut trial = Trial::new(0); - - let n1 = param.suggest(&mut trial).unwrap(); - let n2 = param.suggest(&mut trial).unwrap(); - - assert_eq!(n1, n2, "repeated suggest int should return cached value"); -} - -#[test] -fn test_suggest_int_log_caching() { - let param = IntParam::new(1, 1024).log_scale(); - let mut trial = Trial::new(0); - - let n1 = param.suggest(&mut trial).unwrap(); - let n2 = param.suggest(&mut trial).unwrap(); - - assert_eq!( - n1, n2, - "repeated suggest int log should return cached value" - ); -} - -#[test] -fn test_suggest_int_step_caching() { - let param = IntParam::new(32, 512).step(32); - let mut trial = Trial::new(0); - - let n1 = param.suggest(&mut trial).unwrap(); - let n2 = param.suggest(&mut trial).unwrap(); - - assert_eq!( - n1, n2, - "repeated suggest int step should return cached value" - ); -} - -#[test] -fn test_suggest_categorical_caching() { - let param = CategoricalParam::new(vec!["sgd", "adam", "rmsprop"]); - let mut trial = Trial::new(0); - - let c1 = param.suggest(&mut trial).unwrap(); - let c2 = param.suggest(&mut trial).unwrap(); - - assert_eq!( - 
c1, c2, - "repeated suggest categorical should return cached value" - ); -} - -#[test] -fn test_multiple_parameters_independent_caching() { - let x_param = FloatParam::new(0.0, 1.0); - let y_param = FloatParam::new(0.0, 1.0); - let n_param = IntParam::new(1, 10); - let opt_param = CategoricalParam::new(vec!["a", "b"]); - let mut trial = Trial::new(0); - - // Suggest multiple parameters - let x = x_param.suggest(&mut trial).unwrap(); - let y = y_param.suggest(&mut trial).unwrap(); - let n = n_param.suggest(&mut trial).unwrap(); - let opt = opt_param.suggest(&mut trial).unwrap(); - - // All should be cached independently - assert_eq!(x, x_param.suggest(&mut trial).unwrap()); - assert_eq!(y, y_param.suggest(&mut trial).unwrap()); - assert_eq!(n, n_param.suggest(&mut trial).unwrap()); - assert_eq!(opt, opt_param.suggest(&mut trial).unwrap()); -} - -// ============================================================================= -// Test: parameter conflict returns error -// ============================================================================= - -#[test] -fn test_parameter_conflict_same_param_different_distribution() { - // With ParamId-based API, conflict happens when the same ParamId is used - // with a different distribution. This can happen via suggest_param with - // a param that has a mismatched distribution for an already-stored id. - // Since each FloatParam::new() gets a unique id, conflicts only happen - // when the same param object is reused with different internal state, - // which is not possible with the immutable API. - // We test that different param objects don't conflict (they have different ids). 
- let param1 = FloatParam::new(0.0, 1.0); - let param2 = FloatParam::new(0.0, 2.0); - let mut trial = Trial::new(0); - - trial.suggest_param(¶m1).unwrap(); - // Different param object = different id = no conflict - let result = trial.suggest_param(¶m2); - assert!(result.is_ok()); -} - -#[test] -fn test_empty_categorical_returns_error() { - let param = CategoricalParam::<&str>::new(vec![]); - let mut trial = Trial::new(0); - - let result = trial.suggest_param(¶m); - assert!(matches!(result, Err(Error::EmptyChoices))); -} - -// ============================================================================= -// Additional integration tests -// ============================================================================= - -#[test] -fn test_study_basic_workflow() { - let study: Study<f64> = Study::new(Direction::Minimize); - let x_param = FloatParam::new(-5.0, 5.0); - - study - .optimize(10, |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x * x) - }) - .expect("optimization should succeed"); - - assert_eq!(study.n_trials(), 10); - let best = study.best_trial().expect("should have best trial"); - assert!(best.value >= 0.0, "x^2 should be non-negative"); -} - -#[test] -fn test_study_with_failures() { - let study: Study<f64> = Study::new(Direction::Minimize); - let x_param = FloatParam::new(-5.0, 5.0); - - // Every other trial fails - let mut counter = 0; - study - .optimize(10, |trial| { - counter += 1; - if counter % 2 == 0 { - return Err::<f64, &str>("intentional failure"); - } - let x = x_param.suggest(trial).map_err(|_| "param error")?; - Ok(x * x) - }) - .expect("optimization should succeed with some failures"); - - // Only half the trials should have succeeded - assert_eq!(study.n_trials(), 5, "only 5 trials should have completed"); -} - -#[test] -fn test_no_completed_trials_error() { - let study: Study<f64> = Study::new(Direction::Minimize); - - let result = study.best_trial(); - assert!(matches!(result, Err(Error::NoCompletedTrials))); -} - -#[test] 
-fn test_invalid_bounds_errors() { - let mut trial = Trial::new(0); - - // low > high for float - let result = trial.suggest_param(&FloatParam::new(10.0, 5.0)); - assert!(matches!(result, Err(Error::InvalidBounds { .. }))); - - // low > high for int - let result = trial.suggest_param(&IntParam::new(100, 50)); - assert!(matches!(result, Err(Error::InvalidBounds { .. }))); -} - -#[test] -fn test_invalid_log_bounds_errors() { - let mut trial = Trial::new(0); - - // low <= 0 for log float - let result = trial.suggest_param(&FloatParam::new(0.0, 1.0).log_scale()); - assert!(matches!(result, Err(Error::InvalidLogBounds))); - - let result = trial.suggest_param(&FloatParam::new(-1.0, 1.0).log_scale()); - assert!(matches!(result, Err(Error::InvalidLogBounds))); - - // low < 1 for log int - let result = trial.suggest_param(&IntParam::new(0, 100).log_scale()); - assert!(matches!(result, Err(Error::InvalidLogBounds))); -} - -#[test] -fn test_invalid_step_errors() { - let mut trial = Trial::new(0); - - // step <= 0 for float - let result = trial.suggest_param(&FloatParam::new(0.0, 1.0).step(0.0)); - assert!(matches!(result, Err(Error::InvalidStep))); - - let result = trial.suggest_param(&FloatParam::new(0.0, 1.0).step(-0.1)); - assert!(matches!(result, Err(Error::InvalidStep))); - - // step <= 0 for int - let result = trial.suggest_param(&IntParam::new(0, 100).step(0)); - assert!(matches!(result, Err(Error::InvalidStep))); -} - -#[test] -fn test_tpe_with_categorical_parameter() { - let sampler = TpeSampler::builder() - .seed(42) - .n_startup_trials(5) - .build() - .unwrap(); - - let study: Study<f64> = Study::with_sampler(Direction::Maximize, sampler); - - let model_param = CategoricalParam::new(vec!["linear", "quadratic", "cubic"]); - let x_param = FloatParam::new(0.0, 2.0); - - // Optimization where the best choice depends on the categorical - study - .optimize(30, |trial| { - let choice = model_param.suggest(trial)?; - let x = x_param.suggest(trial)?; - - // cubic model is 
best at x=1 - let value = match choice { - "linear" => x, - "quadratic" => x * x, - "cubic" => -((x - 1.0).powi(2)) + 10.0, // peak at x=1, max value 10 - _ => unreachable!(), - }; - Ok::<_, Error>(value) - }) - .expect("optimization should succeed"); - - let best = study.best_trial().expect("should have best trial"); - assert!( - best.value > 5.0, - "should find good solution, got {}", - best.value - ); -} - -#[test] -fn test_tpe_with_integer_parameters() { - let sampler = TpeSampler::builder() - .seed(789) - .n_startup_trials(5) - .build() - .unwrap(); - - let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - - let n_param = IntParam::new(1, 10); - - // Minimize (n - 7)^2 where n in [1, 10] - study - .optimize(30, |trial| { - let n = n_param.suggest(trial)?; - Ok::<_, Error>(((n - 7) as f64).powi(2)) - }) - .expect("optimization should succeed"); - - let best = study.best_trial().expect("should have best trial"); - - assert!( - best.value < 5.0, - "should find n close to 7, best value = {}", - best.value - ); -} - -#[test] -fn test_callback_early_stopping() { - use std::ops::ControlFlow; - - use optimizer::Objective; - use optimizer::sampler::CompletedTrial; - - struct EarlyStopAfter5 { - x_param: FloatParam, - } - - impl Objective<f64> for EarlyStopAfter5 { - type Error = Error; - fn evaluate(&self, trial: &mut Trial) -> Result<f64, Error> { - let x = self.x_param.suggest(trial)?; - Ok(x) - } - fn after_trial(&self, study: &Study<f64>, _trial: &CompletedTrial<f64>) -> ControlFlow<()> { - if study.n_trials() >= 5 { - ControlFlow::Break(()) - } else { - ControlFlow::Continue(()) - } - } - } - - let study: Study<f64> = Study::new(Direction::Minimize); - study - .optimize_with( - 100, - EarlyStopAfter5 { - x_param: FloatParam::new(0.0, 10.0), - }, - ) - .expect("optimization should succeed"); - - assert_eq!(study.n_trials(), 5, "should have stopped after 5 trials"); -} - -#[test] -fn test_study_trials_iteration() { - let study: Study<f64> = 
Study::new(Direction::Minimize); - let x_param = FloatParam::new(0.0, 1.0); - - study - .optimize(5, |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x) - }) - .unwrap(); - - let trials = study.trials(); - assert_eq!(trials.len(), 5); - - for trial in &trials { - assert!( - !trial.params.is_empty(), - "each trial should have parameters" - ); - } -} - -#[test] -fn test_study_direction() { - let study_min: Study<f64> = Study::new(Direction::Minimize); - assert_eq!(study_min.direction(), Direction::Minimize); - - let study_max: Study<f64> = Study::new(Direction::Maximize); - assert_eq!(study_max.direction(), Direction::Maximize); -} - -#[test] -fn test_trial_state() { - use optimizer::TrialState; - - let trial = Trial::new(0); - assert_eq!(trial.state(), TrialState::Running); -} - -#[test] -fn test_trial_params_access() { - let x_param = FloatParam::new(0.0, 1.0); - let n_param = IntParam::new(1, 10); - let mut trial = Trial::new(0); - - x_param.suggest(&mut trial).unwrap(); - n_param.suggest(&mut trial).unwrap(); - - let params = trial.params(); - assert_eq!(params.len(), 2); -} - -#[test] -fn test_log_scale_float_range() { - let param = FloatParam::new(1e-5, 1e-1).log_scale(); - let mut trial = Trial::new(0); - - let lr = param.suggest(&mut trial).unwrap(); - assert!( - (1e-5..=1e-1).contains(&lr), - "log-scale value {lr} out of range" - ); -} - -#[test] -fn test_step_float_snaps_to_grid() { - let param = FloatParam::new(0.0, 1.0).step(0.25); - let mut trial = Trial::new(0); - - let x = param.suggest(&mut trial).unwrap(); - - // x should be one of: 0.0, 0.25, 0.5, 0.75, 1.0 - let valid_values = [0.0, 0.25, 0.5, 0.75, 1.0]; - let is_valid = valid_values.iter().any(|&v| (x - v).abs() < 1e-10); - assert!(is_valid, "stepped float {x} should snap to grid"); -} - -#[test] -fn test_step_int_snaps_to_grid() { - let param = IntParam::new(0, 100).step(25); - let mut trial = Trial::new(0); - - let n = param.suggest(&mut trial).unwrap(); - - // n should be one of: 
0, 25, 50, 75, 100 - assert!( - n % 25 == 0 && (0..=100).contains(&n), - "stepped int {n} should snap to grid" - ); -} - -#[test] -fn test_best_value() { - let study: Study<f64> = Study::new(Direction::Minimize); - let x_param = FloatParam::new(0.0, 10.0); - - study - .optimize(10, |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x) - }) - .unwrap(); - - let best_value = study.best_value().expect("should have best value"); - let best_trial = study.best_trial().expect("should have best trial"); - - assert_eq!( - best_value, best_trial.value, - "best_value should match best_trial.value" - ); -} - -// ============================================================================= -// Additional coverage tests -// ============================================================================= - -#[test] -fn test_study_set_sampler() { - let mut study: Study<f64> = Study::new(Direction::Minimize); - - let tpe = TpeSampler::builder() - .seed(42) - .n_startup_trials(5) - .build() - .unwrap(); - study.set_sampler(tpe); - - let x_param = FloatParam::new(-5.0, 5.0); - - study - .optimize(10, |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x * x) - }) - .expect("optimization should succeed with new sampler"); - - assert_eq!(study.n_trials(), 10); -} - -#[test] -fn test_study_with_i32_value_type() { - let study: Study<i32> = Study::new(Direction::Minimize); - let x_param = IntParam::new(-10, 10); - - study - .optimize(10, |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x.abs() as i32) - }) - .expect("optimization should succeed"); - - assert_eq!(study.n_trials(), 10); - let best = study.best_trial().expect("should have best trial"); - assert!(best.value >= 0, "absolute value should be non-negative"); -} - -#[test] -fn test_optimize_all_trials_fail() { - let study: Study<f64> = Study::new(Direction::Minimize); - - let result = study.optimize(5, |_trial| Err::<f64, &str>("always fails")); - - assert!( - matches!(result, 
Err(Error::NoCompletedTrials)), - "should return NoCompletedTrials when all trials fail" - ); -} - -#[test] -fn test_optimize_with_all_trials_fail() { - let study: Study<f64> = Study::new(Direction::Minimize); - - let result = study.optimize(5, |_trial| Err::<f64, &str>("always fails")); - - assert!( - matches!(result, Err(Error::NoCompletedTrials)), - "should return NoCompletedTrials when all trials fail" - ); -} - -#[test] -fn test_trial_debug_format() { - let param = FloatParam::new(0.0, 1.0); - let mut trial = Trial::new(42); - param.suggest(&mut trial).unwrap(); - - let debug_str = format!("{trial:?}"); - - assert!(debug_str.contains("Trial")); - assert!(debug_str.contains("42")); - assert!(debug_str.contains("has_sampler")); -} - -#[test] -fn test_tpe_sampler_builder_default_trait() { - use optimizer::sampler::tpe::TpeSamplerBuilder; - - let builder = TpeSamplerBuilder::default(); - let sampler = builder.build().unwrap(); - - let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - let x_param = FloatParam::new(0.0, 1.0); - - study - .optimize(5, |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x) - }) - .unwrap(); - - assert_eq!(study.n_trials(), 5); -} - -#[test] -fn test_tpe_sampler_default_trait() { - let sampler = TpeSampler::default(); - - let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - let x_param = FloatParam::new(0.0, 1.0); - - study - .optimize(5, |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x) - }) - .unwrap(); - - assert_eq!(study.n_trials(), 5); -} - -#[test] -fn test_tpe_with_fixed_kde_bandwidth() { - let sampler = TpeSampler::builder() - .seed(42) - .n_startup_trials(5) - .kde_bandwidth(0.5) - .build() - .unwrap(); - - let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - let x_param = FloatParam::new(-5.0, 5.0); - - study - .optimize(20, |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x * x) - }) - .expect("optimization should 
succeed"); - - let best = study.best_trial().unwrap(); - assert!(best.value < 10.0, "should find reasonable solution"); -} - -#[test] -fn test_tpe_sampler_invalid_kde_bandwidth() { - let result = TpeSampler::with_config(0.25, 10, 24, Some(-1.0), None); - assert!(matches!(result, Err(Error::InvalidBandwidth(_)))); -} - -#[test] -fn test_tpe_split_trials_with_two_trials() { - let sampler = TpeSampler::builder() - .seed(42) - .n_startup_trials(2) - .build() - .unwrap(); - - let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - let x_param = FloatParam::new(0.0, 10.0); - - study - .optimize(5, |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x) - }) - .expect("optimization should succeed with small history"); - - assert_eq!(study.n_trials(), 5); -} - -#[test] -fn test_tpe_with_log_scale_int() { - let sampler = TpeSampler::builder() - .seed(42) - .n_startup_trials(5) - .build() - .unwrap(); - - let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - let batch_param = IntParam::new(1, 1024).log_scale(); - - study - .optimize(20, |trial| { - let batch_size = batch_param.suggest(trial)?; - Ok::<_, Error>(((batch_size as f64).log2() - 5.0).powi(2)) - }) - .expect("optimization should succeed"); - - let best = study.best_trial().unwrap(); - assert!(best.value < 10.0, "should find reasonable solution"); -} - -#[test] -fn test_tpe_with_step_distributions() { - let sampler = TpeSampler::builder() - .seed(42) - .n_startup_trials(5) - .build() - .unwrap(); - - let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - let x_param = FloatParam::new(0.0, 10.0).step(0.5); - let n_param = IntParam::new(0, 100).step(10); - - study - .optimize(20, |trial| { - let x = x_param.suggest(trial)?; - let n = n_param.suggest(trial)?; - Ok::<_, Error>((x - 5.0).powi(2) + ((n - 50) as f64).powi(2)) - }) - .expect("optimization should succeed"); - - let best = study.best_trial().unwrap(); - assert!(best.value < 100.0, 
"should find reasonable solution"); -} - -#[test] -fn test_manual_trial_completion() { - let study: Study<f64> = Study::new(Direction::Minimize); - let x_param = FloatParam::new(0.0, 10.0); - - // Manually create and complete trials - let mut trial = study.create_trial(); - let x = x_param.suggest(&mut trial).unwrap(); - study.complete_trial(trial, x * x); - - let mut trial2 = study.create_trial(); - let y = x_param.suggest(&mut trial2).unwrap(); - study.complete_trial(trial2, y * y); - - // Manually fail a trial - let trial3 = study.create_trial(); - study.fail_trial(trial3, "test failure"); - - // Only 2 completed trials - assert_eq!(study.n_trials(), 2); -} - -#[test] -fn test_distributions_access() { - let x_param = FloatParam::new(0.0, 1.0); - let n_param = IntParam::new(1, 10); - let opt_param = CategoricalParam::new(vec!["a", "b", "c"]); - let mut trial = Trial::new(0); - - x_param.suggest(&mut trial).unwrap(); - n_param.suggest(&mut trial).unwrap(); - opt_param.suggest(&mut trial).unwrap(); - - let dists = trial.distributions(); - assert_eq!(dists.len(), 3); -} - -#[test] -fn test_tpe_empty_good_or_bad_values_fallback() { - let sampler = TpeSampler::builder() - .seed(42) - .n_startup_trials(5) - .gamma(0.1) - .build() - .unwrap(); - - let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - let x_param = FloatParam::new(0.0, 10.0); - let y_param = FloatParam::new(0.0, 10.0); - - // First optimize with one parameter - study - .optimize(10, |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x) - }) - .unwrap(); - - // Now try with a different parameter - TPE won't have history for "y" - study - .optimize(5, |trial| { - let y = y_param.suggest(trial)?; - Ok::<_, Error>(y) - }) - .unwrap(); - - assert_eq!(study.n_trials(), 15); -} - -#[test] -fn test_callback_early_stopping_on_first_trial() { - use std::ops::ControlFlow; - - use optimizer::Objective; - use optimizer::sampler::CompletedTrial; - - struct StopImmediately { - x_param: 
FloatParam, - } - - impl Objective<f64> for StopImmediately { - type Error = Error; - fn evaluate(&self, trial: &mut Trial) -> Result<f64, Error> { - let x = self.x_param.suggest(trial)?; - Ok(x) - } - fn after_trial( - &self, - _study: &Study<f64>, - _trial: &CompletedTrial<f64>, - ) -> ControlFlow<()> { - ControlFlow::Break(()) - } - } - - let study: Study<f64> = Study::new(Direction::Minimize); - study - .optimize_with( - 100, - StopImmediately { - x_param: FloatParam::new(0.0, 10.0), - }, - ) - .expect("optimization should succeed"); - - assert_eq!(study.n_trials(), 1, "should have stopped after 1 trial"); -} - -#[test] -fn test_callback_sampler_early_stopping() { - use std::ops::ControlFlow; - - use optimizer::Objective; - use optimizer::sampler::CompletedTrial; - - struct StopAfter3 { - x_param: FloatParam, - } - - impl Objective<f64> for StopAfter3 { - type Error = Error; - fn evaluate(&self, trial: &mut Trial) -> Result<f64, Error> { - let x = self.x_param.suggest(trial)?; - Ok(x) - } - fn after_trial(&self, study: &Study<f64>, _trial: &CompletedTrial<f64>) -> ControlFlow<()> { - if study.n_trials() >= 3 { - ControlFlow::Break(()) - } else { - ControlFlow::Continue(()) - } - } - } - - let sampler = RandomSampler::with_seed(42); - let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - study - .optimize_with( - 100, - StopAfter3 { - x_param: FloatParam::new(0.0, 10.0), - }, - ) - .expect("optimization should succeed"); - - assert_eq!(study.n_trials(), 3); -} - -#[test] -fn test_int_bounds_with_low_equals_high() { - let mut trial = Trial::new(0); - - // When low == high, should return that exact value - let n_param = IntParam::new(5, 5); - let n = n_param.suggest(&mut trial).unwrap(); - assert_eq!(n, 5); - - let x_param = FloatParam::new(3.0, 3.0); - let x = x_param.suggest(&mut trial).unwrap(); - assert_eq!(x, 3.0); -} - -#[test] -fn test_best_trial_with_nan_values() { - let study: Study<f64> = Study::new(Direction::Minimize); - let 
x_param = FloatParam::new(0.0, 10.0); - - study - .optimize(5, |trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x) - }) - .unwrap(); - - let best = study.best_trial(); - assert!(best.is_ok()); -} - -// ============================================================================= -// Tests for BoolParam -// ============================================================================= - -#[test] -fn test_suggest_bool_caching() { - let param = BoolParam::new(); - let mut trial = Trial::new(0); - - let b1 = param.suggest(&mut trial).unwrap(); - let b2 = param.suggest(&mut trial).unwrap(); - - assert_eq!(b1, b2, "repeated suggest bool should return cached value"); -} - -#[test] -fn test_suggest_bool_multiple_parameters() { - let dropout_param = BoolParam::new(); - let batchnorm_param = BoolParam::new(); - let skip_param = BoolParam::new(); - let mut trial = Trial::new(0); - - let a = dropout_param.suggest(&mut trial).unwrap(); - let b = batchnorm_param.suggest(&mut trial).unwrap(); - let c = skip_param.suggest(&mut trial).unwrap(); - - // All should be cached independently - assert_eq!(a, dropout_param.suggest(&mut trial).unwrap()); - assert_eq!(b, batchnorm_param.suggest(&mut trial).unwrap()); - assert_eq!(c, skip_param.suggest(&mut trial).unwrap()); -} - -#[test] -fn test_suggest_bool_in_optimization() { - let study: Study<f64> = Study::new(Direction::Minimize); - let use_feature_param = BoolParam::new(); - let x_param = FloatParam::new(0.0, 10.0); - - study - .optimize(10, |trial| { - let use_feature = use_feature_param.suggest(trial)?; - let x = x_param.suggest(trial)?; - - let value = if use_feature { x } else { x * 2.0 }; - Ok::<_, Error>(value) - }) - .unwrap(); - - assert_eq!(study.n_trials(), 10); -} - -#[test] -fn test_suggest_bool_with_tpe() { - let sampler = TpeSampler::builder() - .seed(42) - .n_startup_trials(5) - .build() - .unwrap(); - - let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - let use_large_param = 
BoolParam::new(); - let x_param = FloatParam::new(0.0, 10.0); - - study - .optimize(20, |trial| { - let use_large = use_large_param.suggest(trial)?; - let x = x_param.suggest(trial)?; - // The value depends on use_large flag - let base = if use_large { x * 2.0 } else { x }; - Ok::<_, Error>(base) - }) - .unwrap(); - - let best = study.best_trial().unwrap(); - assert!(best.value < 10.0); -} - -// ============================================================================= -// Tests for FloatParam and IntParam ranges -// ============================================================================= - -#[test] -fn test_float_param_exclusive_range() { - let param = FloatParam::new(0.0, 1.0); - let mut trial = Trial::new(0); - - let x = param.suggest(&mut trial).unwrap(); - assert!((0.0..=1.0).contains(&x), "value {x} out of range 0.0..1.0"); -} - -#[test] -fn test_float_param_inclusive_range() { - let param = FloatParam::new(0.0, 1.0); - let mut trial = Trial::new(0); - - let x = param.suggest(&mut trial).unwrap(); - assert!((0.0..=1.0).contains(&x), "value {x} out of range 0.0..=1.0"); -} - -#[test] -fn test_int_param_range() { - let param = IntParam::new(1, 10); - let mut trial = Trial::new(0); - - let n = param.suggest(&mut trial).unwrap(); - assert!((1..=10).contains(&n), "value {n} out of range 1..=10"); -} - -#[test] -fn test_param_caching_float() { - let param = FloatParam::new(0.0, 1.0); - let mut trial = Trial::new(0); - - let x1 = param.suggest(&mut trial).unwrap(); - let x2 = param.suggest(&mut trial).unwrap(); - - assert_eq!(x1, x2, "repeated suggest should return cached value"); -} - -#[test] -fn test_param_caching_int() { - let param = IntParam::new(1, 100); - let mut trial = Trial::new(0); - - let n1 = param.suggest(&mut trial).unwrap(); - let n2 = param.suggest(&mut trial).unwrap(); - - assert_eq!(n1, n2, "repeated suggest should return cached value"); -} - -#[test] -fn test_multiple_params_in_optimization() { - let study: Study<f64> = 
Study::new(Direction::Minimize); - let x_param = FloatParam::new(-10.0, 10.0); - let n_param = IntParam::new(1, 5); - - study - .optimize(10, |trial| { - let x = x_param.suggest(trial)?; - let n = n_param.suggest(trial)?; - Ok::<_, Error>(x * x + n as f64) - }) - .unwrap(); - - assert_eq!(study.n_trials(), 10); -} - -#[test] -fn test_params_with_tpe() { - let sampler = TpeSampler::builder() - .seed(42) - .n_startup_trials(5) - .build() - .unwrap(); - - let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - let x_param = FloatParam::new(-5.0, 5.0); - let n_param = IntParam::new(1, 10); - - study - .optimize(30, |trial| { - let x = x_param.suggest(trial)?; - let n = n_param.suggest(trial)?; - Ok::<_, Error>(x * x + (n as f64 - 5.0).powi(2)) - }) - .unwrap(); - - let best = study.best_trial().unwrap(); - assert!(best.value < 10.0, "TPE should find good solution"); -} - -#[test] -fn test_single_value_int_range() { - let param = IntParam::new(5, 5); - let mut trial = Trial::new(0); - - let n = param.suggest(&mut trial).unwrap(); - assert_eq!(n, 5, "single-value range should return that value"); -} - -#[test] -fn test_single_value_float_range() { - let param = FloatParam::new(4.2, 4.2); - let mut trial = Trial::new(0); - - let x = param.suggest(&mut trial).unwrap(); - assert!( - (x - 4.2).abs() < f64::EPSILON, - "single-value range should return that value" - ); -} - -// ============================================================================= -// Tests for new API features -// ============================================================================= - -#[test] -fn test_param_name() { - let param = FloatParam::new(0.0, 1.0).name("learning_rate"); - let mut trial = Trial::new(0); - param.suggest(&mut trial).unwrap(); - - let labels = trial.param_labels(); - let label = labels.values().next().unwrap(); - assert_eq!(label, "learning_rate"); -} - -#[test] -fn test_completed_trial_get() { - let study: Study<f64> = Study::new(Direction::Minimize); 
- let x_param = FloatParam::new(-10.0, 10.0).name("x"); - let n_param = IntParam::new(1, 10).name("n"); - - study - .optimize(5, |trial| { - let x = x_param.suggest(trial)?; - let n = n_param.suggest(trial)?; - Ok::<_, Error>(x * x + n as f64) - }) - .unwrap(); - - let best = study.best_trial().unwrap(); - let x_val: f64 = best.get(&x_param).unwrap(); - let n_val: i64 = best.get(&n_param).unwrap(); - assert!((-10.0..=10.0).contains(&x_val)); - assert!((1..=10).contains(&n_val)); -} - -// ============================================================================= -// Tests for top_trials -// ============================================================================= - -#[test] -fn test_top_trials_minimize() { - let study: Study<f64> = Study::new(Direction::Minimize); - - // Manually complete trials with known values - for &val in &[5.0, 1.0, 3.0, 2.0, 4.0] { - let trial = study.create_trial(); - study.complete_trial(trial, val); - } - - let top3 = study.top_trials(3); - assert_eq!(top3.len(), 3); - assert_eq!(top3[0].value, 1.0); - assert_eq!(top3[1].value, 2.0); - assert_eq!(top3[2].value, 3.0); -} - -#[test] -fn test_top_trials_maximize() { - let study: Study<f64> = Study::new(Direction::Maximize); - - for &val in &[5.0, 1.0, 3.0, 2.0, 4.0] { - let trial = study.create_trial(); - study.complete_trial(trial, val); - } - - let top3 = study.top_trials(3); - assert_eq!(top3.len(), 3); - assert_eq!(top3[0].value, 5.0); - assert_eq!(top3[1].value, 4.0); - assert_eq!(top3[2].value, 3.0); -} - -#[test] -fn test_top_trials_n_greater_than_total() { - let study: Study<f64> = Study::new(Direction::Minimize); - - for &val in &[3.0, 1.0] { - let trial = study.create_trial(); - study.complete_trial(trial, val); - } - - let top = study.top_trials(10); - assert_eq!(top.len(), 2); - assert_eq!(top[0].value, 1.0); - assert_eq!(top[1].value, 3.0); -} - -#[test] -fn test_top_trials_empty() { - let study: Study<f64> = Study::new(Direction::Minimize); - - let top = 
study.top_trials(5); - assert!(top.is_empty()); -} - -#[test] -fn test_top_trials_excludes_pruned() { - let study: Study<f64> = Study::new(Direction::Minimize); - - // Complete some trials - for &val in &[5.0, 1.0, 3.0] { - let trial = study.create_trial(); - study.complete_trial(trial, val); - } - - // Prune a trial (it gets a default value of 0.0 but should be excluded) - let trial = study.create_trial(); - study.prune_trial(trial); - - let top = study.top_trials(5); - assert_eq!(top.len(), 3, "pruned trial should be excluded"); - assert_eq!(top[0].value, 1.0); -} - -// ============================================================================= -// Test: ask-and-tell interface -// ============================================================================= - -#[test] -fn test_ask_and_tell_basic() { - let study: Study<f64> = Study::new(Direction::Minimize); - let x_param = FloatParam::new(0.0, 10.0); - - for _ in 0..10 { - let mut trial = study.ask(); - let x = x_param.suggest(&mut trial).unwrap(); - let value = x * x; - study.tell(trial, Ok::<_, &str>(value)); - } - - assert_eq!(study.n_trials(), 10); - assert!(study.best_value().unwrap() >= 0.0); -} - -#[test] -fn test_ask_and_tell_with_failures() { - let study: Study<f64> = Study::new(Direction::Minimize); - let x_param = FloatParam::new(-5.0, 5.0); - - // Alternate success and failure - for i in 0..10 { - let mut trial = study.ask(); - let x = x_param.suggest(&mut trial).unwrap(); - if i % 2 == 0 { - study.tell(trial, Ok::<_, &str>(x * x)); - } else { - study.tell(trial, Err::<f64, _>("simulated failure")); - } - } - - // Only successful trials are counted - assert_eq!(study.n_trials(), 5); -} - -#[test] -fn test_ask_and_tell_with_tpe_sampler() { - let sampler = TpeSampler::builder() - .seed(42) - .n_startup_trials(5) - .build() - .unwrap(); - let study: Study<f64> = Study::minimize(sampler); - let x_param = FloatParam::new(-10.0, 10.0); - - for _ in 0..30 { - let mut trial = study.ask(); - let x = 
x_param.suggest(&mut trial).unwrap(); - study.tell(trial, Ok::<_, &str>((x - 3.0).powi(2))); - } - - assert_eq!(study.n_trials(), 30); - assert!( - study.best_value().unwrap() < 5.0, - "TPE ask-and-tell should find a reasonable value" - ); -} - -#[test] -fn test_ask_and_tell_batch() { - let study: Study<f64> = Study::new(Direction::Minimize); - let x_param = FloatParam::new(0.0, 10.0); - - // Ask a batch of trials - let batch: Vec<_> = (0..5) - .map(|_| { - let mut t = study.ask(); - let x = x_param.suggest(&mut t).unwrap(); - (t, x) - }) - .collect(); - - // Tell results for the batch - for (trial, x) in batch { - study.tell(trial, Ok::<_, &str>(x * x)); - } - - assert_eq!(study.n_trials(), 5); -} - -#[test] -fn test_ask_and_tell_with_custom_value_type() { - // Ask-and-tell works with non-f64 value types too - let study: Study<i32> = Study::new(Direction::Maximize); - - for i in 0..5 { - let trial = study.ask(); - study.tell(trial, Ok::<_, &str>(i * 10)); - } - - assert_eq!(study.n_trials(), 5); - assert_eq!(study.best_value().unwrap(), 40); -} - -// ============================================================================= -// Tests: enqueue trials -// ============================================================================= - -use std::collections::HashMap; - -use optimizer::parameter::ParamValue; - -#[test] -fn test_enqueue_params_evaluated_first() { - let study: Study<f64> = Study::new(Direction::Minimize); - let x = FloatParam::new(0.0, 10.0); - let y = IntParam::new(1, 100); - - // Enqueue a specific configuration - study.enqueue(HashMap::from([ - (x.id(), ParamValue::Float(5.0)), - (y.id(), ParamValue::Int(42)), - ])); - - // The first trial should use the enqueued params - let mut trial = study.ask(); - let x_val = x.suggest(&mut trial).unwrap(); - let y_val = y.suggest(&mut trial).unwrap(); - - assert_eq!(x_val, 5.0); - assert_eq!(y_val, 42); -} - -#[test] -fn test_enqueue_fifo_order() { - let study: Study<f64> = Study::new(Direction::Minimize); - 
let x = FloatParam::new(0.0, 10.0); - - // Enqueue two configs - study.enqueue(HashMap::from([(x.id(), ParamValue::Float(1.0))])); - study.enqueue(HashMap::from([(x.id(), ParamValue::Float(2.0))])); - - // First trial gets first enqueued value - let mut trial1 = study.ask(); - assert_eq!(x.suggest(&mut trial1).unwrap(), 1.0); - - // Second trial gets second enqueued value - let mut trial2 = study.ask(); - assert_eq!(x.suggest(&mut trial2).unwrap(), 2.0); -} - -#[test] -fn test_enqueue_then_normal_sampling_resumes() { - let sampler = RandomSampler::with_seed(42); - let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - let x = FloatParam::new(0.0, 10.0); - - // Enqueue one config - study.enqueue(HashMap::from([(x.id(), ParamValue::Float(5.0))])); - - // First trial uses enqueued value - let mut trial1 = study.ask(); - assert_eq!(x.suggest(&mut trial1).unwrap(), 5.0); - study.tell(trial1, Ok::<_, &str>(25.0)); - - // Second trial uses normal sampling (not 5.0) - let mut trial2 = study.ask(); - let x_val = x.suggest(&mut trial2).unwrap(); - // The sampled value should be in [0, 10] but extremely unlikely to be exactly 5.0 - assert!((0.0..=10.0).contains(&x_val)); -} - -#[test] -fn test_enqueue_with_optimize() { - let study: Study<f64> = Study::new(Direction::Minimize); - let x = FloatParam::new(0.0, 10.0); - - // Enqueue two specific configs - study.enqueue(HashMap::from([(x.id(), ParamValue::Float(1.0))])); - study.enqueue(HashMap::from([(x.id(), ParamValue::Float(2.0))])); - - let mut values = Vec::new(); - - study - .optimize(5, |trial| { - let x_val = x.suggest(trial)?; - values.push(x_val); - Ok::<_, Error>(x_val * x_val) - }) - .unwrap(); - - // First two trials should use enqueued values - assert_eq!(values[0], 1.0); - assert_eq!(values[1], 2.0); - // All 5 trials should have completed - assert_eq!(study.n_trials(), 5); -} - -#[test] -fn test_enqueue_partial_params_fall_back_to_sampling() { - let study: Study<f64> = 
Study::new(Direction::Minimize); - let x = FloatParam::new(0.0, 10.0); - let y = IntParam::new(1, 100); - - // Enqueue only x, not y - study.enqueue(HashMap::from([(x.id(), ParamValue::Float(3.0))])); - - let mut trial = study.ask(); - let x_val = x.suggest(&mut trial).unwrap(); - let y_val = y.suggest(&mut trial).unwrap(); - - // x should be the enqueued value - assert_eq!(x_val, 3.0); - // y should be sampled (within range) - assert!((1..=100).contains(&y_val)); -} - -#[test] -fn test_enqueue_trials_appear_in_completed_trials() { - let study: Study<f64> = Study::new(Direction::Minimize); - let x = FloatParam::new(0.0, 10.0); - - study.enqueue(HashMap::from([(x.id(), ParamValue::Float(7.0))])); - - study - .optimize(1, |trial| { - let x_val = x.suggest(trial)?; - Ok::<_, Error>(x_val) - }) - .unwrap(); - - let trials = study.trials(); - assert_eq!(trials.len(), 1); - assert_eq!(trials[0].value, 7.0); - assert_eq!( - *trials[0].params.get(&x.id()).unwrap(), - ParamValue::Float(7.0) - ); -} - -#[test] -fn test_enqueue_with_ask_and_tell() { - let study: Study<f64> = Study::new(Direction::Minimize); - let x = FloatParam::new(0.0, 10.0); - - study.enqueue(HashMap::from([(x.id(), ParamValue::Float(4.0))])); - - let mut trial = study.ask(); - let x_val = x.suggest(&mut trial).unwrap(); - assert_eq!(x_val, 4.0); - - study.tell(trial, Ok::<_, &str>(x_val * x_val)); - assert_eq!(study.n_trials(), 1); - assert_eq!(study.best_value().unwrap(), 16.0); -} - -#[test] -fn test_n_enqueued() { - let study: Study<f64> = Study::new(Direction::Minimize); - let x = FloatParam::new(0.0, 10.0); - - assert_eq!(study.n_enqueued(), 0); - - study.enqueue(HashMap::from([(x.id(), ParamValue::Float(1.0))])); - assert_eq!(study.n_enqueued(), 1); - - study.enqueue(HashMap::from([(x.id(), ParamValue::Float(2.0))])); - assert_eq!(study.n_enqueued(), 2); - - // Creating a trial dequeues one - let _ = study.ask(); - assert_eq!(study.n_enqueued(), 1); - - let _ = study.ask(); - 
assert_eq!(study.n_enqueued(), 0); -} - -#[test] -fn test_enqueue_counted_in_n_trials() { - let study: Study<f64> = Study::new(Direction::Minimize); - let x = FloatParam::new(0.0, 10.0); - - study.enqueue(HashMap::from([(x.id(), ParamValue::Float(1.0))])); - study.enqueue(HashMap::from([(x.id(), ParamValue::Float(2.0))])); - - study - .optimize(5, |trial| { - let x_val = x.suggest(trial)?; - Ok::<_, Error>(x_val) - }) - .unwrap(); - - // All 5 trials count, including the 2 enqueued ones - assert_eq!(study.n_trials(), 5); -} - -// ============================================================================= -// Test: Study summary and Display -// ============================================================================= - -#[test] -fn test_summary_with_completed_trials() { - let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(1)); - let x = FloatParam::new(0.0, 10.0).name("x"); - - study - .optimize(5, |trial| { - let val = x.suggest(trial)?; - Ok::<_, Error>(val * val) - }) - .unwrap(); - - let summary = study.summary(); - assert!(summary.contains("Minimize")); - assert!(summary.contains("5 trials")); - assert!(summary.contains("Best value:")); - assert!(summary.contains("x = ")); -} - -#[test] -fn test_summary_no_completed_trials() { - let study: Study<f64> = Study::new(Direction::Maximize); - let summary = study.summary(); - assert!(summary.contains("Maximize")); - assert!(summary.contains("0 trials")); - assert!(!summary.contains("Best value:")); -} - -#[test] -fn test_summary_with_pruned_trials() { - let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(1)); - let x = FloatParam::new(0.0, 10.0).name("x"); - - // Manually create some complete and pruned trials - for _ in 0..3 { - let mut trial = study.create_trial(); - let val = x.suggest(&mut trial).unwrap(); - study.complete_trial(trial, val); - } - for _ in 0..2 { - let mut trial = study.create_trial(); - let _ = x.suggest(&mut 
trial).unwrap(); - study.prune_trial(trial); - } - - let summary = study.summary(); - // Should show breakdown when there are pruned trials - if study.n_pruned_trials() > 0 { - assert!(summary.contains("complete")); - assert!(summary.contains("pruned")); - } -} - -#[test] -fn test_display_matches_summary() { - let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(1)); - let x = FloatParam::new(0.0, 10.0).name("x"); - - study - .optimize(3, |trial| { - let val = x.suggest(trial)?; - Ok::<_, Error>(val) - }) - .unwrap(); - - assert_eq!(format!("{study}"), study.summary()); -} - -// ============================================================================= -// Tests: optimize_with retries via Objective trait -// ============================================================================= - -#[test] -fn test_retries_successful_trials_not_retried() { - use std::sync::Arc; - use std::sync::atomic::{AtomicU32, Ordering}; - - use optimizer::Objective; - - struct SuccessObj { - x_param: FloatParam, - call_count: Arc<AtomicU32>, - } - - impl Objective<f64> for SuccessObj { - type Error = Error; - fn evaluate(&self, trial: &mut Trial) -> Result<f64, Error> { - let x = self.x_param.suggest(trial)?; - self.call_count.fetch_add(1, Ordering::Relaxed); - Ok(x * x) - } - fn max_retries(&self) -> usize { - 3 - } - } - - let study: Study<f64> = Study::new(Direction::Minimize); - let call_count = Arc::new(AtomicU32::new(0)); - let obj = SuccessObj { - x_param: FloatParam::new(0.0, 10.0), - call_count: Arc::clone(&call_count), - }; - - study.optimize_with(5, obj).unwrap(); - - // All trials succeed on first try — exactly 5 calls - assert_eq!(call_count.load(Ordering::Relaxed), 5); - assert_eq!(study.n_trials(), 5); -} - -#[test] -fn test_retries_failed_trials_retried_up_to_max() { - use std::sync::Arc; - use std::sync::atomic::{AtomicU32, Ordering}; - - use optimizer::Objective; - - struct AlwaysFailObj { - x_param: FloatParam, - call_count: 
Arc<AtomicU32>, - } - - impl Objective<f64> for AlwaysFailObj { - type Error = String; - fn evaluate(&self, trial: &mut Trial) -> Result<f64, String> { - let _ = self.x_param.suggest(trial).map_err(|e| e.to_string())?; - self.call_count.fetch_add(1, Ordering::Relaxed); - Err("always fails".to_string()) - } - fn max_retries(&self) -> usize { - 3 - } - } - - let study: Study<f64> = Study::new(Direction::Minimize); - let call_count = Arc::new(AtomicU32::new(0)); - let obj = AlwaysFailObj { - x_param: FloatParam::new(0.0, 10.0), - call_count: Arc::clone(&call_count), - }; - - let result = study.optimize_with(1, obj); - - // 1 initial attempt + 3 retries = 4 total calls - assert_eq!(call_count.load(Ordering::Relaxed), 4); - // No trials completed - assert!(matches!(result, Err(Error::NoCompletedTrials))); -} - -#[test] -fn test_retries_permanently_failed_after_exhaustion() { - use optimizer::Objective; - - struct AlwaysFailObj { - x_param: FloatParam, - } - - impl Objective<f64> for AlwaysFailObj { - type Error = String; - fn evaluate(&self, trial: &mut Trial) -> Result<f64, String> { - let _ = self.x_param.suggest(trial).map_err(|e| e.to_string())?; - Err("transient error".to_string()) - } - fn max_retries(&self) -> usize { - 2 - } - } - - let study: Study<f64> = Study::new(Direction::Minimize); - let obj = AlwaysFailObj { - x_param: FloatParam::new(0.0, 10.0), - }; - - let result = study.optimize_with(3, obj); - - assert!( - matches!(result, Err(Error::NoCompletedTrials)), - "all trials should permanently fail" - ); - assert_eq!( - study.n_trials(), - 0, - "no completed trials should be recorded" - ); -} - -#[test] -fn test_retries_uses_same_parameters() { - use std::sync::atomic::{AtomicU32, Ordering}; - use std::sync::{Arc, Mutex}; - - use optimizer::Objective; - - struct RetryObj { - x_param: FloatParam, - seen_values: Arc<Mutex<Vec<f64>>>, - call_count: Arc<AtomicU32>, - } - - impl Objective<f64> for RetryObj { - type Error = String; - fn evaluate(&self, trial: 
&mut Trial) -> Result<f64, String> { - let x = self.x_param.suggest(trial).map_err(|e| e.to_string())?; - self.seen_values.lock().unwrap().push(x); - let count = self.call_count.fetch_add(1, Ordering::Relaxed) + 1; - // Fail first two attempts, succeed on third - if count < 3 { - Err("transient".to_string()) - } else { - Ok(x * x) - } - } - fn max_retries(&self) -> usize { - 2 - } - } - - let study: Study<f64> = Study::new(Direction::Minimize); - let seen_values = Arc::new(Mutex::new(Vec::new())); - let call_count = Arc::new(AtomicU32::new(0)); - let obj = RetryObj { - x_param: FloatParam::new(0.0, 10.0), - seen_values: Arc::clone(&seen_values), - call_count: Arc::clone(&call_count), - }; - - study.optimize_with(1, obj).unwrap(); - - let values = seen_values.lock().unwrap(); - assert_eq!(values.len(), 3, "should be called 3 times (1 + 2 retries)"); - // All three calls should have gotten the same parameter value - assert_eq!(values[0], values[1]); - assert_eq!(values[1], values[2]); -} - -#[test] -fn test_retries_n_trials_counts_unique_configs() { - use std::sync::Arc; - use std::sync::atomic::{AtomicU32, Ordering}; - - use optimizer::Objective; - - struct FailFirstObj { - x_param: FloatParam, - call_count: Arc<AtomicU32>, - } - - impl Objective<f64> for FailFirstObj { - type Error = String; - fn evaluate(&self, trial: &mut Trial) -> Result<f64, String> { - let x = self.x_param.suggest(trial).map_err(|e| e.to_string())?; - let count = self.call_count.fetch_add(1, Ordering::Relaxed) + 1; - // Fail first attempt of each config, succeed on retry - if count % 2 == 1 { - Err("transient".to_string()) - } else { - Ok(x * x) - } - } - fn max_retries(&self) -> usize { - 2 - } - } - - let study: Study<f64> = Study::new(Direction::Minimize); - let call_count = Arc::new(AtomicU32::new(0)); - let obj = FailFirstObj { - x_param: FloatParam::new(0.0, 10.0), - call_count: Arc::clone(&call_count), - }; - - study.optimize_with(3, obj).unwrap(); - - // 3 unique configs, each needing 
2 calls = 6 total calls - assert_eq!(call_count.load(Ordering::Relaxed), 6); - // But only 3 completed trials - assert_eq!(study.n_trials(), 3); -} - -#[test] -fn test_retries_with_zero_max_retries_same_as_optimize() { - let study: Study<f64> = Study::new(Direction::Minimize); - let x_param = FloatParam::new(0.0, 10.0); - let call_count = std::cell::Cell::new(0u32); - - study - .optimize(5, |trial| { - let x = x_param.suggest(trial)?; - call_count.set(call_count.get() + 1); - Ok::<_, Error>(x * x) - }) - .unwrap(); - - assert_eq!(call_count.get(), 5); - assert_eq!(study.n_trials(), 5); -} - -// ============================================================================= -// Tests: IntoIterator for &Study -// ============================================================================= - -#[test] -fn test_into_iterator_iterates_all_trials() { - let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(42)); - let x_param = FloatParam::new(0.0, 10.0); - - for _ in 0..5 { - let mut trial = study.create_trial(); - let x = x_param.suggest(&mut trial).unwrap(); - study.complete_trial(trial, x * x); - } - - let mut count = 0; - for trial in &study { - assert_eq!(trial.state, optimizer::TrialState::Complete); - count += 1; - } - assert_eq!(count, 5); -} - -#[test] -fn test_into_iterator_empty_study() { - let study: Study<f64> = Study::new(Direction::Minimize); - - let count = (&study).into_iter().count(); - assert_eq!(count, 0); -} - -#[test] -fn test_into_iterator_preserves_insertion_order() { - let study: Study<f64> = Study::new(Direction::Minimize); - - for i in 0..3 { - let trial = study.create_trial(); - study.complete_trial(trial, f64::from(i)); - } - - let ids: Vec<u64> = (&study).into_iter().map(|t| t.id).collect(); - assert_eq!(ids, vec![0, 1, 2]); -} - -// ============================================================================= -// Tests: Constraint handling -// 
============================================================================= - -#[test] -fn test_is_feasible_all_satisfied() { - let study: Study<f64> = Study::new(Direction::Minimize); - let mut trial = study.create_trial(); - trial.set_constraints(vec![-1.0, 0.0, -0.5]); - study.complete_trial(trial, 1.0); - - let completed = study.best_trial().unwrap(); - assert!(completed.is_feasible()); -} - -#[test] -fn test_is_feasible_one_violated() { - let study: Study<f64> = Study::new(Direction::Minimize); - let mut trial = study.create_trial(); - trial.set_constraints(vec![-1.0, 0.5, -0.5]); - study.complete_trial(trial, 1.0); - - let completed = study.best_trial().unwrap(); - assert!(!completed.is_feasible()); -} - -#[test] -fn test_is_feasible_empty_constraints() { - let study: Study<f64> = Study::new(Direction::Minimize); - let trial = study.create_trial(); - study.complete_trial(trial, 1.0); - - let completed = study.best_trial().unwrap(); - assert!(completed.is_feasible()); -} - -#[test] -fn test_best_trial_prefers_feasible() { - let study: Study<f64> = Study::new(Direction::Minimize); - - // Infeasible trial with better objective - let mut trial1 = study.create_trial(); - trial1.set_constraints(vec![1.0]); - study.complete_trial(trial1, 0.1); - - // Feasible trial with worse objective - let mut trial2 = study.create_trial(); - trial2.set_constraints(vec![-1.0]); - study.complete_trial(trial2, 100.0); - - let best = study.best_trial().unwrap(); - assert_eq!(best.id, 1); // feasible trial wins - assert_eq!(best.value, 100.0); -} - -#[test] -fn test_best_trial_feasible_by_objective() { - let study: Study<f64> = Study::new(Direction::Minimize); - - // Feasible, worse objective - let mut trial1 = study.create_trial(); - trial1.set_constraints(vec![-1.0]); - study.complete_trial(trial1, 10.0); - - // Feasible, better objective - let mut trial2 = study.create_trial(); - trial2.set_constraints(vec![-0.5]); - study.complete_trial(trial2, 2.0); - - let best = 
study.best_trial().unwrap(); - assert_eq!(best.id, 1); // lower objective wins among feasible - assert_eq!(best.value, 2.0); -} - -#[test] -fn test_top_trials_ranks_feasible_above_infeasible() { - let study: Study<f64> = Study::new(Direction::Minimize); - - // Infeasible, low violation - let mut t0 = study.create_trial(); - t0.set_constraints(vec![0.5]); - study.complete_trial(t0, 1.0); - - // Feasible, worst objective among feasible - let mut t1 = study.create_trial(); - t1.set_constraints(vec![-1.0]); - study.complete_trial(t1, 50.0); - - // Feasible, best objective among feasible - let mut t2 = study.create_trial(); - t2.set_constraints(vec![-0.1]); - study.complete_trial(t2, 5.0); - - // Infeasible, high violation - let mut t3 = study.create_trial(); - t3.set_constraints(vec![3.0]); - study.complete_trial(t3, 0.5); - - let top = study.top_trials(4); - let ids: Vec<u64> = top.iter().map(|t| t.id).collect(); - // Feasible sorted by objective first (5.0, 50.0), then infeasible by violation (0.5, 3.0) - assert_eq!(ids, vec![2, 1, 0, 3]); -} - -// ============================================================================= -// Test: StudyBuilder -// ============================================================================= - -#[test] -fn test_builder_defaults() { - let study: Study<f64> = Study::builder().build(); - assert_eq!(study.direction(), Direction::Minimize); -} - -#[test] -fn test_builder_maximize() { - let study: Study<f64> = Study::builder().maximize().build(); - assert_eq!(study.direction(), Direction::Maximize); -} - -#[test] -fn test_builder_minimize() { - let study: Study<f64> = Study::builder().minimize().build(); - assert_eq!(study.direction(), Direction::Minimize); -} - -#[test] -fn test_builder_direction() { - let study: Study<f64> = Study::builder().direction(Direction::Maximize).build(); - assert_eq!(study.direction(), Direction::Maximize); -} - -#[test] -fn test_builder_with_sampler() { - let x = FloatParam::new(-5.0, 5.0); - let study: 
Study<f64> = Study::builder().sampler(TpeSampler::new()).build(); - - study - .optimize(10, |trial| { - let val = x.suggest(trial)?; - Ok::<_, Error>(val * val) - }) - .unwrap(); - - assert_eq!(study.trials().len(), 10); -} - -#[test] -fn test_builder_with_pruner() { - use optimizer::pruner::NopPruner; - - let study: Study<f64> = Study::builder().pruner(NopPruner).build(); - - assert_eq!(study.direction(), Direction::Minimize); -} - -#[test] -fn test_builder_chaining() { - let study: Study<f64> = Study::builder() - .maximize() - .sampler(RandomSampler::with_seed(42)) - .pruner(optimizer::pruner::NopPruner) - .build(); - - assert_eq!(study.direction(), Direction::Maximize); -} - -#[test] -fn test_builder_with_custom_value_type() { - let study: Study<i32> = Study::builder().maximize().build(); - assert_eq!(study.direction(), Direction::Maximize); -} - -#[test] -fn test_builder_optimizes_correctly() { - let x = FloatParam::new(-10.0, 10.0); - let study: Study<f64> = Study::builder() - .minimize() - .sampler(TpeSampler::builder().seed(42).build().unwrap()) - .build(); - - study - .optimize(100, |trial| { - let val = x.suggest(trial)?; - Ok::<_, Error>((val - 3.0) * (val - 3.0)) - }) - .unwrap(); - - let best = study.best_trial().unwrap(); - assert!( - best.value < 5.0, - "best value should be < 5.0, got {}", - best.value - ); -} diff --git a/tests/pruner/main.rs b/tests/pruner/main.rs new file mode 100644 index 0000000..b9eaa5d --- /dev/null +++ b/tests/pruner/main.rs @@ -0,0 +1,2 @@ +mod median; +mod threshold; diff --git a/tests/median_pruner_tests.rs b/tests/pruner/median.rs similarity index 100% rename from tests/median_pruner_tests.rs rename to tests/pruner/median.rs diff --git a/tests/threshold_pruner_tests.rs b/tests/pruner/threshold.rs similarity index 100% rename from tests/threshold_pruner_tests.rs rename to tests/pruner/threshold.rs diff --git a/tests/bohb_integration.rs b/tests/sampler/bohb.rs similarity index 100% rename from tests/bohb_integration.rs 
rename to tests/sampler/bohb.rs diff --git a/tests/cma_es_tests.rs b/tests/sampler/cma_es.rs similarity index 99% rename from tests/cma_es_tests.rs rename to tests/sampler/cma_es.rs index 84b3700..2c3f992 100644 --- a/tests/cma_es_tests.rs +++ b/tests/sampler/cma_es.rs @@ -1,5 +1,3 @@ -#![cfg(feature = "cma-es")] - use optimizer::prelude::*; use optimizer::sampler::cma_es::CmaEsSampler; diff --git a/tests/differential_evolution_tests.rs b/tests/sampler/differential_evolution.rs similarity index 100% rename from tests/differential_evolution_tests.rs rename to tests/sampler/differential_evolution.rs diff --git a/tests/gp_tests.rs b/tests/sampler/gp.rs similarity index 99% rename from tests/gp_tests.rs rename to tests/sampler/gp.rs index b2dc23f..6ddf3b5 100644 --- a/tests/gp_tests.rs +++ b/tests/sampler/gp.rs @@ -1,5 +1,3 @@ -#![cfg(feature = "gp")] - use optimizer::prelude::*; use optimizer::sampler::gp::GpSampler; diff --git a/tests/sampler/main.rs b/tests/sampler/main.rs new file mode 100644 index 0000000..9f0f825 --- /dev/null +++ b/tests/sampler/main.rs @@ -0,0 +1,15 @@ +#![allow( + clippy::cast_sign_loss, + clippy::cast_precision_loss, + clippy::cast_possible_truncation +)] + +mod bohb; +#[cfg(feature = "cma-es")] +mod cma_es; +mod differential_evolution; +#[cfg(feature = "gp")] +mod gp; +mod multivariate_tpe; +mod random; +mod tpe; diff --git a/tests/multivariate_tpe_integration.rs b/tests/sampler/multivariate_tpe.rs similarity index 100% rename from tests/multivariate_tpe_integration.rs rename to tests/sampler/multivariate_tpe.rs diff --git a/tests/sampler/random.rs b/tests/sampler/random.rs new file mode 100644 index 0000000..911f192 --- /dev/null +++ b/tests/sampler/random.rs @@ -0,0 +1,142 @@ +use optimizer::parameter::{CategoricalParam, FloatParam, IntParam, Parameter}; +use optimizer::sampler::random::RandomSampler; +use optimizer::{Direction, Error, Study}; + +#[test] +fn test_random_sampler_uniform_float_distribution() { + let study: Study<f64> = 
Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(42)); + + let n_samples = 1000; + let mut samples = Vec::with_capacity(n_samples); + + let x_param = FloatParam::new(0.0, 1.0); + + study + .optimize(n_samples, |trial| { + let x = x_param.suggest(trial)?; + samples.push(x); + Ok::<_, Error>(x) + }) + .unwrap(); + + // All samples should be in range + for &s in &samples { + assert!((0.0..=1.0).contains(&s), "sample {s} out of range [0, 1]"); + } + + // Check distribution is roughly uniform by looking at quartiles + samples.sort_by(|a, b| a.partial_cmp(b).unwrap()); + + let q1 = samples[n_samples / 4]; + let q2 = samples[n_samples / 2]; + let q3 = samples[3 * n_samples / 4]; + + assert!((q1 - 0.25).abs() < 0.1, "Q1 {q1} should be close to 0.25"); + assert!( + (q2 - 0.5).abs() < 0.1, + "Q2 (median) {q2} should be close to 0.5" + ); + assert!((q3 - 0.75).abs() < 0.1, "Q3 {q3} should be close to 0.75"); +} + +#[test] +fn test_random_sampler_uniform_int_distribution() { + let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(123)); + + let n_samples = 5000; + let mut counts = [0u32; 10]; // counts for values 1-10 + + let n_param = IntParam::new(1, 10); + + study + .optimize(n_samples, |trial| { + let n = n_param.suggest(trial)?; + assert!((1..=10).contains(&n), "sample {n} out of range [1, 10]"); + counts[(n - 1) as usize] += 1; + Ok::<_, Error>(n as f64) + }) + .unwrap(); + + let expected = n_samples as f64 / 10.0; + for (i, &count) in counts.iter().enumerate() { + let diff = (count as f64 - expected).abs() / expected; + assert!( + diff < 0.2, + "value {} appeared {} times, expected ~{}, diff = {:.1}%", + i + 1, + count, + expected, + diff * 100.0 + ); + } +} + +#[test] +fn test_random_sampler_uniform_categorical_distribution() { + let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(456)); + + let n_samples = 2000; + let mut counts = [0u32; 4]; + let choices = ["a", "b", "c", "d"]; 
+ + let cat_param = CategoricalParam::new(choices.to_vec()); + + study + .optimize(n_samples, |trial| { + let choice = cat_param.suggest(trial)?; + let idx = choices.iter().position(|&c| c == choice).unwrap(); + counts[idx] += 1; + Ok::<_, Error>(idx as f64) + }) + .unwrap(); + + let expected = n_samples as f64 / 4.0; + for (i, &count) in counts.iter().enumerate() { + let diff = (count as f64 - expected).abs() / expected; + assert!( + diff < 0.15, + "category {} appeared {} times, expected ~{}, diff = {:.1}%", + i, + count, + expected, + diff * 100.0 + ); + } +} + +#[test] +fn test_random_sampler_reproducibility() { + let study1: Study<f64> = + Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(999)); + let study2: Study<f64> = + Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(999)); + + let mut values1 = Vec::new(); + let mut values2 = Vec::new(); + + let x_param1 = FloatParam::new(0.0, 100.0); + let x_param2 = FloatParam::new(0.0, 100.0); + + study1 + .optimize(100, |trial| { + let x = x_param1.suggest(trial)?; + values1.push(x); + Ok::<_, Error>(x) + }) + .unwrap(); + + study2 + .optimize(100, |trial| { + let x = x_param2.suggest(trial)?; + values2.push(x); + Ok::<_, Error>(x) + }) + .unwrap(); + + for (i, (v1, v2)) in values1.iter().zip(values2.iter()).enumerate() { + assert_eq!( + v1, v2, + "values at trial {i} should be identical with same seed: {v1} vs {v2}" + ); + } +} diff --git a/tests/sampler/tpe.rs b/tests/sampler/tpe.rs new file mode 100644 index 0000000..8091077 --- /dev/null +++ b/tests/sampler/tpe.rs @@ -0,0 +1,381 @@ +use optimizer::parameter::{BoolParam, CategoricalParam, FloatParam, IntParam, Parameter}; +use optimizer::sampler::tpe::TpeSampler; +use optimizer::{Direction, Error, Study}; + +#[test] +fn test_tpe_optimizes_quadratic_function() { + // Minimize f(x) = (x - 3)^2 where x in [-10, 10] + // Optimal: x = 3, f(3) = 0 + let sampler = TpeSampler::builder() + .seed(42) + .n_startup_trials(10) + 
.n_ei_candidates(24) + .build() + .unwrap(); + + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + + let x_param = FloatParam::new(-10.0, 10.0); + + study + .optimize(100, |trial| { + let x = x_param.suggest(trial)?; + Ok::<_, Error>((x - 3.0).powi(2)) + }) + .expect("optimization should succeed"); + + let best = study.best_trial().expect("should have at least one trial"); + + // TPE should find a reasonable value over 100 trials + // With random startup + TPE, we expect to get within a few units of optimal + assert!( + best.value < 5.0, + "TPE should find near-optimal: best value {} should be < 5.0", + best.value + ); +} + +#[test] +fn test_tpe_optimizes_multivariate_function() { + // Minimize f(x, y) = x^2 + y^2 where x, y in [-5, 5] + // Optimal: (0, 0), f(0, 0) = 0 + let sampler = TpeSampler::builder() + .seed(123) + .n_startup_trials(10) + .build() + .unwrap(); + + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + + let x_param = FloatParam::new(-5.0, 5.0); + let y_param = FloatParam::new(-5.0, 5.0); + + study + .optimize(100, |trial| { + let x = x_param.suggest(trial)?; + let y = y_param.suggest(trial)?; + Ok::<_, Error>(x * x + y * y) + }) + .expect("optimization should succeed"); + + let best = study.best_trial().expect("should have at least one trial"); + + // TPE should find a reasonably good solution + assert!( + best.value < 5.0, + "TPE should find near-optimal: best value {} should be < 5.0", + best.value + ); +} + +#[test] +fn test_tpe_maximization() { + // Maximize f(x) = -(x - 2)^2 + 10 where x in [-10, 10] + // Optimal: x = 2, f(2) = 10 + let sampler = TpeSampler::builder() + .seed(456) + .n_startup_trials(5) + .build() + .unwrap(); + + let study: Study<f64> = Study::with_sampler(Direction::Maximize, sampler); + + let x_param = FloatParam::new(-10.0, 10.0); + + study + .optimize(50, |trial| { + let x = x_param.suggest(trial)?; + Ok::<_, Error>(-(x - 2.0).powi(2) + 10.0) + }) + 
.expect("optimization should succeed"); + + let best = study.best_trial().expect("should have at least one trial"); + + assert!( + best.value > 5.0, + "TPE should find reasonably good solution: best value {} should be > 5.0", + best.value + ); +} + +#[test] +fn test_tpe_with_categorical_parameter() { + let sampler = TpeSampler::builder() + .seed(42) + .n_startup_trials(5) + .build() + .unwrap(); + + let study: Study<f64> = Study::with_sampler(Direction::Maximize, sampler); + + let model_param = CategoricalParam::new(vec!["linear", "quadratic", "cubic"]); + let x_param = FloatParam::new(0.0, 2.0); + + // Optimization where the best choice depends on the categorical + study + .optimize(30, |trial| { + let choice = model_param.suggest(trial)?; + let x = x_param.suggest(trial)?; + + // cubic model is best at x=1 + let value = match choice { + "linear" => x, + "quadratic" => x * x, + "cubic" => -((x - 1.0).powi(2)) + 10.0, // peak at x=1, max value 10 + _ => unreachable!(), + }; + Ok::<_, Error>(value) + }) + .expect("optimization should succeed"); + + let best = study.best_trial().expect("should have best trial"); + assert!( + best.value > 5.0, + "should find good solution, got {}", + best.value + ); +} + +#[test] +fn test_tpe_with_integer_parameters() { + let sampler = TpeSampler::builder() + .seed(789) + .n_startup_trials(5) + .build() + .unwrap(); + + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + + let n_param = IntParam::new(1, 10); + + // Minimize (n - 7)^2 where n in [1, 10] + study + .optimize(30, |trial| { + let n = n_param.suggest(trial)?; + Ok::<_, Error>(((n - 7) as f64).powi(2)) + }) + .expect("optimization should succeed"); + + let best = study.best_trial().expect("should have best trial"); + + assert!( + best.value < 5.0, + "should find n close to 7, best value = {}", + best.value + ); +} + +#[test] +fn test_tpe_with_log_scale_int() { + let sampler = TpeSampler::builder() + .seed(42) + .n_startup_trials(5) + .build() + 
.unwrap(); + + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + let batch_param = IntParam::new(1, 1024).log_scale(); + + study + .optimize(20, |trial| { + let batch_size = batch_param.suggest(trial)?; + Ok::<_, Error>(((batch_size as f64).log2() - 5.0).powi(2)) + }) + .expect("optimization should succeed"); + + let best = study.best_trial().unwrap(); + assert!(best.value < 10.0, "should find reasonable solution"); +} + +#[test] +fn test_tpe_with_step_distributions() { + let sampler = TpeSampler::builder() + .seed(42) + .n_startup_trials(5) + .build() + .unwrap(); + + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + let x_param = FloatParam::new(0.0, 10.0).step(0.5); + let n_param = IntParam::new(0, 100).step(10); + + study + .optimize(20, |trial| { + let x = x_param.suggest(trial)?; + let n = n_param.suggest(trial)?; + Ok::<_, Error>((x - 5.0).powi(2) + ((n - 50) as f64).powi(2)) + }) + .expect("optimization should succeed"); + + let best = study.best_trial().unwrap(); + assert!(best.value < 100.0, "should find reasonable solution"); +} + +#[test] +fn test_tpe_with_fixed_kde_bandwidth() { + let sampler = TpeSampler::builder() + .seed(42) + .n_startup_trials(5) + .kde_bandwidth(0.5) + .build() + .unwrap(); + + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + let x_param = FloatParam::new(-5.0, 5.0); + + study + .optimize(20, |trial| { + let x = x_param.suggest(trial)?; + Ok::<_, Error>(x * x) + }) + .expect("optimization should succeed"); + + let best = study.best_trial().unwrap(); + assert!(best.value < 10.0, "should find reasonable solution"); +} + +#[test] +fn test_tpe_sampler_invalid_kde_bandwidth() { + let result = TpeSampler::with_config(0.25, 10, 24, Some(-1.0), None); + assert!(matches!(result, Err(Error::InvalidBandwidth(_)))); +} + +#[test] +fn test_tpe_split_trials_with_two_trials() { + let sampler = TpeSampler::builder() + .seed(42) + .n_startup_trials(2) + .build() + 
.unwrap(); + + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + let x_param = FloatParam::new(0.0, 10.0); + + study + .optimize(5, |trial| { + let x = x_param.suggest(trial)?; + Ok::<_, Error>(x) + }) + .expect("optimization should succeed with small history"); + + assert_eq!(study.n_trials(), 5); +} + +#[test] +fn test_tpe_empty_good_or_bad_values_fallback() { + let sampler = TpeSampler::builder() + .seed(42) + .n_startup_trials(5) + .gamma(0.1) + .build() + .unwrap(); + + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + let x_param = FloatParam::new(0.0, 10.0); + let y_param = FloatParam::new(0.0, 10.0); + + // First optimize with one parameter + study + .optimize(10, |trial| { + let x = x_param.suggest(trial)?; + Ok::<_, Error>(x) + }) + .unwrap(); + + // Now try with a different parameter - TPE won't have history for "y" + study + .optimize(5, |trial| { + let y = y_param.suggest(trial)?; + Ok::<_, Error>(y) + }) + .unwrap(); + + assert_eq!(study.n_trials(), 15); +} + +#[test] +fn test_tpe_sampler_builder_default_trait() { + use optimizer::sampler::tpe::TpeSamplerBuilder; + + let builder = TpeSamplerBuilder::default(); + let sampler = builder.build().unwrap(); + + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + let x_param = FloatParam::new(0.0, 1.0); + + study + .optimize(5, |trial| { + let x = x_param.suggest(trial)?; + Ok::<_, Error>(x) + }) + .unwrap(); + + assert_eq!(study.n_trials(), 5); +} + +#[test] +fn test_tpe_sampler_default_trait() { + let sampler = TpeSampler::default(); + + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + let x_param = FloatParam::new(0.0, 1.0); + + study + .optimize(5, |trial| { + let x = x_param.suggest(trial)?; + Ok::<_, Error>(x) + }) + .unwrap(); + + assert_eq!(study.n_trials(), 5); +} + +#[test] +fn test_suggest_bool_with_tpe() { + let sampler = TpeSampler::builder() + .seed(42) + .n_startup_trials(5) + .build() + 
.unwrap(); + + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + let use_large_param = BoolParam::new(); + let x_param = FloatParam::new(0.0, 10.0); + + study + .optimize(20, |trial| { + let use_large = use_large_param.suggest(trial)?; + let x = x_param.suggest(trial)?; + // The value depends on use_large flag + let base = if use_large { x * 2.0 } else { x }; + Ok::<_, Error>(base) + }) + .unwrap(); + + let best = study.best_trial().unwrap(); + assert!(best.value < 10.0); +} + +#[test] +fn test_params_with_tpe() { + let sampler = TpeSampler::builder() + .seed(42) + .n_startup_trials(5) + .build() + .unwrap(); + + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + let x_param = FloatParam::new(-5.0, 5.0); + let n_param = IntParam::new(1, 10); + + study + .optimize(30, |trial| { + let x = x_param.suggest(trial)?; + let n = n_param.suggest(trial)?; + Ok::<_, Error>(x * x + (n as f64 - 5.0).powi(2)) + }) + .unwrap(); + + let best = study.best_trial().unwrap(); + assert!(best.value < 10.0, "TPE should find good solution"); +} diff --git a/tests/study/ask_tell.rs b/tests/study/ask_tell.rs new file mode 100644 index 0000000..cf7d1d0 --- /dev/null +++ b/tests/study/ask_tell.rs @@ -0,0 +1,98 @@ +use optimizer::parameter::{FloatParam, Parameter}; +use optimizer::sampler::tpe::TpeSampler; +use optimizer::{Direction, Study}; + +#[test] +fn test_ask_and_tell_basic() { + let study: Study<f64> = Study::new(Direction::Minimize); + let x_param = FloatParam::new(0.0, 10.0); + + for _ in 0..10 { + let mut trial = study.ask(); + let x = x_param.suggest(&mut trial).unwrap(); + let value = x * x; + study.tell(trial, Ok::<_, &str>(value)); + } + + assert_eq!(study.n_trials(), 10); + assert!(study.best_value().unwrap() >= 0.0); +} + +#[test] +fn test_ask_and_tell_with_failures() { + let study: Study<f64> = Study::new(Direction::Minimize); + let x_param = FloatParam::new(-5.0, 5.0); + + // Alternate success and failure + for i in 0..10 { 
+ let mut trial = study.ask(); + let x = x_param.suggest(&mut trial).unwrap(); + if i % 2 == 0 { + study.tell(trial, Ok::<_, &str>(x * x)); + } else { + study.tell(trial, Err::<f64, _>("simulated failure")); + } + } + + // Only successful trials are counted + assert_eq!(study.n_trials(), 5); +} + +#[test] +fn test_ask_and_tell_with_tpe_sampler() { + let sampler = TpeSampler::builder() + .seed(42) + .n_startup_trials(5) + .build() + .unwrap(); + let study: Study<f64> = Study::minimize(sampler); + let x_param = FloatParam::new(-10.0, 10.0); + + for _ in 0..30 { + let mut trial = study.ask(); + let x = x_param.suggest(&mut trial).unwrap(); + study.tell(trial, Ok::<_, &str>((x - 3.0).powi(2))); + } + + assert_eq!(study.n_trials(), 30); + assert!( + study.best_value().unwrap() < 5.0, + "TPE ask-and-tell should find a reasonable value" + ); +} + +#[test] +fn test_ask_and_tell_batch() { + let study: Study<f64> = Study::new(Direction::Minimize); + let x_param = FloatParam::new(0.0, 10.0); + + // Ask a batch of trials + let batch: Vec<_> = (0..5) + .map(|_| { + let mut t = study.ask(); + let x = x_param.suggest(&mut t).unwrap(); + (t, x) + }) + .collect(); + + // Tell results for the batch + for (trial, x) in batch { + study.tell(trial, Ok::<_, &str>(x * x)); + } + + assert_eq!(study.n_trials(), 5); +} + +#[test] +fn test_ask_and_tell_with_custom_value_type() { + // Ask-and-tell works with non-f64 value types too + let study: Study<i32> = Study::new(Direction::Maximize); + + for i in 0..5 { + let trial = study.ask(); + study.tell(trial, Ok::<_, &str>(i * 10)); + } + + assert_eq!(study.n_trials(), 5); + assert_eq!(study.best_value().unwrap(), 40); +} diff --git a/tests/study/builder.rs b/tests/study/builder.rs new file mode 100644 index 0000000..8b421ff --- /dev/null +++ b/tests/study/builder.rs @@ -0,0 +1,92 @@ +use optimizer::parameter::{FloatParam, Parameter}; +use optimizer::sampler::random::RandomSampler; +use optimizer::sampler::tpe::TpeSampler; +use 
optimizer::{Direction, Error, Study}; + +#[test] +fn test_builder_defaults() { + let study: Study<f64> = Study::builder().build(); + assert_eq!(study.direction(), Direction::Minimize); +} + +#[test] +fn test_builder_maximize() { + let study: Study<f64> = Study::builder().maximize().build(); + assert_eq!(study.direction(), Direction::Maximize); +} + +#[test] +fn test_builder_minimize() { + let study: Study<f64> = Study::builder().minimize().build(); + assert_eq!(study.direction(), Direction::Minimize); +} + +#[test] +fn test_builder_direction() { + let study: Study<f64> = Study::builder().direction(Direction::Maximize).build(); + assert_eq!(study.direction(), Direction::Maximize); +} + +#[test] +fn test_builder_with_sampler() { + let x = FloatParam::new(-5.0, 5.0); + let study: Study<f64> = Study::builder().sampler(TpeSampler::new()).build(); + + study + .optimize(10, |trial| { + let val = x.suggest(trial)?; + Ok::<_, Error>(val * val) + }) + .unwrap(); + + assert_eq!(study.trials().len(), 10); +} + +#[test] +fn test_builder_with_pruner() { + use optimizer::pruner::NopPruner; + + let study: Study<f64> = Study::builder().pruner(NopPruner).build(); + + assert_eq!(study.direction(), Direction::Minimize); +} + +#[test] +fn test_builder_chaining() { + let study: Study<f64> = Study::builder() + .maximize() + .sampler(RandomSampler::with_seed(42)) + .pruner(optimizer::pruner::NopPruner) + .build(); + + assert_eq!(study.direction(), Direction::Maximize); +} + +#[test] +fn test_builder_with_custom_value_type() { + let study: Study<i32> = Study::builder().maximize().build(); + assert_eq!(study.direction(), Direction::Maximize); +} + +#[test] +fn test_builder_optimizes_correctly() { + let x = FloatParam::new(-10.0, 10.0); + let study: Study<f64> = Study::builder() + .minimize() + .sampler(TpeSampler::builder().seed(42).build().unwrap()) + .build(); + + study + .optimize(100, |trial| { + let val = x.suggest(trial)?; + Ok::<_, Error>((val - 3.0) * (val - 3.0)) + }) + .unwrap(); 
+ + let best = study.best_trial().unwrap(); + assert!( + best.value < 5.0, + "best value should be < 5.0, got {}", + best.value + ); +} diff --git a/tests/study/constraints.rs b/tests/study/constraints.rs new file mode 100644 index 0000000..592ac93 --- /dev/null +++ b/tests/study/constraints.rs @@ -0,0 +1,101 @@ +use optimizer::{Direction, Study}; + +#[test] +fn test_is_feasible_all_satisfied() { + let study: Study<f64> = Study::new(Direction::Minimize); + let mut trial = study.create_trial(); + trial.set_constraints(vec![-1.0, 0.0, -0.5]); + study.complete_trial(trial, 1.0); + + let completed = study.best_trial().unwrap(); + assert!(completed.is_feasible()); +} + +#[test] +fn test_is_feasible_one_violated() { + let study: Study<f64> = Study::new(Direction::Minimize); + let mut trial = study.create_trial(); + trial.set_constraints(vec![-1.0, 0.5, -0.5]); + study.complete_trial(trial, 1.0); + + let completed = study.best_trial().unwrap(); + assert!(!completed.is_feasible()); +} + +#[test] +fn test_is_feasible_empty_constraints() { + let study: Study<f64> = Study::new(Direction::Minimize); + let trial = study.create_trial(); + study.complete_trial(trial, 1.0); + + let completed = study.best_trial().unwrap(); + assert!(completed.is_feasible()); +} + +#[test] +fn test_best_trial_prefers_feasible() { + let study: Study<f64> = Study::new(Direction::Minimize); + + // Infeasible trial with better objective + let mut trial1 = study.create_trial(); + trial1.set_constraints(vec![1.0]); + study.complete_trial(trial1, 0.1); + + // Feasible trial with worse objective + let mut trial2 = study.create_trial(); + trial2.set_constraints(vec![-1.0]); + study.complete_trial(trial2, 100.0); + + let best = study.best_trial().unwrap(); + assert_eq!(best.id, 1); // feasible trial wins + assert_eq!(best.value, 100.0); +} + +#[test] +fn test_best_trial_feasible_by_objective() { + let study: Study<f64> = Study::new(Direction::Minimize); + + // Feasible, worse objective + let mut trial1 = 
study.create_trial(); + trial1.set_constraints(vec![-1.0]); + study.complete_trial(trial1, 10.0); + + // Feasible, better objective + let mut trial2 = study.create_trial(); + trial2.set_constraints(vec![-0.5]); + study.complete_trial(trial2, 2.0); + + let best = study.best_trial().unwrap(); + assert_eq!(best.id, 1); // lower objective wins among feasible + assert_eq!(best.value, 2.0); +} + +#[test] +fn test_top_trials_ranks_feasible_above_infeasible() { + let study: Study<f64> = Study::new(Direction::Minimize); + + // Infeasible, low violation + let mut t0 = study.create_trial(); + t0.set_constraints(vec![0.5]); + study.complete_trial(t0, 1.0); + + // Feasible, worst objective among feasible + let mut t1 = study.create_trial(); + t1.set_constraints(vec![-1.0]); + study.complete_trial(t1, 50.0); + + // Feasible, best objective among feasible + let mut t2 = study.create_trial(); + t2.set_constraints(vec![-0.1]); + study.complete_trial(t2, 5.0); + + // Infeasible, high violation + let mut t3 = study.create_trial(); + t3.set_constraints(vec![3.0]); + study.complete_trial(t3, 0.5); + + let top = study.top_trials(4); + let ids: Vec<u64> = top.iter().map(|t| t.id).collect(); + // Feasible sorted by objective first (5.0, 50.0), then infeasible by violation (0.5, 3.0) + assert_eq!(ids, vec![2, 1, 0, 3]); +} diff --git a/tests/study/enqueue.rs b/tests/study/enqueue.rs new file mode 100644 index 0000000..c84e008 --- /dev/null +++ b/tests/study/enqueue.rs @@ -0,0 +1,189 @@ +use std::collections::HashMap; + +use optimizer::parameter::{FloatParam, IntParam, ParamValue, Parameter}; +use optimizer::sampler::random::RandomSampler; +use optimizer::{Direction, Error, Study}; + +#[test] +fn test_enqueue_params_evaluated_first() { + let study: Study<f64> = Study::new(Direction::Minimize); + let x = FloatParam::new(0.0, 10.0); + let y = IntParam::new(1, 100); + + // Enqueue a specific configuration + study.enqueue(HashMap::from([ + (x.id(), ParamValue::Float(5.0)), + (y.id(), 
ParamValue::Int(42)), + ])); + + // The first trial should use the enqueued params + let mut trial = study.ask(); + let x_val = x.suggest(&mut trial).unwrap(); + let y_val = y.suggest(&mut trial).unwrap(); + + assert_eq!(x_val, 5.0); + assert_eq!(y_val, 42); +} + +#[test] +fn test_enqueue_fifo_order() { + let study: Study<f64> = Study::new(Direction::Minimize); + let x = FloatParam::new(0.0, 10.0); + + // Enqueue two configs + study.enqueue(HashMap::from([(x.id(), ParamValue::Float(1.0))])); + study.enqueue(HashMap::from([(x.id(), ParamValue::Float(2.0))])); + + // First trial gets first enqueued value + let mut trial1 = study.ask(); + assert_eq!(x.suggest(&mut trial1).unwrap(), 1.0); + + // Second trial gets second enqueued value + let mut trial2 = study.ask(); + assert_eq!(x.suggest(&mut trial2).unwrap(), 2.0); +} + +#[test] +fn test_enqueue_then_normal_sampling_resumes() { + let sampler = RandomSampler::with_seed(42); + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + let x = FloatParam::new(0.0, 10.0); + + // Enqueue one config + study.enqueue(HashMap::from([(x.id(), ParamValue::Float(5.0))])); + + // First trial uses enqueued value + let mut trial1 = study.ask(); + assert_eq!(x.suggest(&mut trial1).unwrap(), 5.0); + study.tell(trial1, Ok::<_, &str>(25.0)); + + // Second trial uses normal sampling (not 5.0) + let mut trial2 = study.ask(); + let x_val = x.suggest(&mut trial2).unwrap(); + // The sampled value should be in [0, 10] but extremely unlikely to be exactly 5.0 + assert!((0.0..=10.0).contains(&x_val)); +} + +#[test] +fn test_enqueue_with_optimize() { + let study: Study<f64> = Study::new(Direction::Minimize); + let x = FloatParam::new(0.0, 10.0); + + // Enqueue two specific configs + study.enqueue(HashMap::from([(x.id(), ParamValue::Float(1.0))])); + study.enqueue(HashMap::from([(x.id(), ParamValue::Float(2.0))])); + + let mut values = Vec::new(); + + study + .optimize(5, |trial| { + let x_val = x.suggest(trial)?; + 
values.push(x_val); + Ok::<_, Error>(x_val * x_val) + }) + .unwrap(); + + // First two trials should use enqueued values + assert_eq!(values[0], 1.0); + assert_eq!(values[1], 2.0); + // All 5 trials should have completed + assert_eq!(study.n_trials(), 5); +} + +#[test] +fn test_enqueue_partial_params_fall_back_to_sampling() { + let study: Study<f64> = Study::new(Direction::Minimize); + let x = FloatParam::new(0.0, 10.0); + let y = IntParam::new(1, 100); + + // Enqueue only x, not y + study.enqueue(HashMap::from([(x.id(), ParamValue::Float(3.0))])); + + let mut trial = study.ask(); + let x_val = x.suggest(&mut trial).unwrap(); + let y_val = y.suggest(&mut trial).unwrap(); + + // x should be the enqueued value + assert_eq!(x_val, 3.0); + // y should be sampled (within range) + assert!((1..=100).contains(&y_val)); +} + +#[test] +fn test_enqueue_trials_appear_in_completed_trials() { + let study: Study<f64> = Study::new(Direction::Minimize); + let x = FloatParam::new(0.0, 10.0); + + study.enqueue(HashMap::from([(x.id(), ParamValue::Float(7.0))])); + + study + .optimize(1, |trial| { + let x_val = x.suggest(trial)?; + Ok::<_, Error>(x_val) + }) + .unwrap(); + + let trials = study.trials(); + assert_eq!(trials.len(), 1); + assert_eq!(trials[0].value, 7.0); + assert_eq!( + *trials[0].params.get(&x.id()).unwrap(), + ParamValue::Float(7.0) + ); +} + +#[test] +fn test_enqueue_with_ask_and_tell() { + let study: Study<f64> = Study::new(Direction::Minimize); + let x = FloatParam::new(0.0, 10.0); + + study.enqueue(HashMap::from([(x.id(), ParamValue::Float(4.0))])); + + let mut trial = study.ask(); + let x_val = x.suggest(&mut trial).unwrap(); + assert_eq!(x_val, 4.0); + + study.tell(trial, Ok::<_, &str>(x_val * x_val)); + assert_eq!(study.n_trials(), 1); + assert_eq!(study.best_value().unwrap(), 16.0); +} + +#[test] +fn test_n_enqueued() { + let study: Study<f64> = Study::new(Direction::Minimize); + let x = FloatParam::new(0.0, 10.0); + + assert_eq!(study.n_enqueued(), 0); + + 
study.enqueue(HashMap::from([(x.id(), ParamValue::Float(1.0))])); + assert_eq!(study.n_enqueued(), 1); + + study.enqueue(HashMap::from([(x.id(), ParamValue::Float(2.0))])); + assert_eq!(study.n_enqueued(), 2); + + // Creating a trial dequeues one + let _ = study.ask(); + assert_eq!(study.n_enqueued(), 1); + + let _ = study.ask(); + assert_eq!(study.n_enqueued(), 0); +} + +#[test] +fn test_enqueue_counted_in_n_trials() { + let study: Study<f64> = Study::new(Direction::Minimize); + let x = FloatParam::new(0.0, 10.0); + + study.enqueue(HashMap::from([(x.id(), ParamValue::Float(1.0))])); + study.enqueue(HashMap::from([(x.id(), ParamValue::Float(2.0))])); + + study + .optimize(5, |trial| { + let x_val = x.suggest(trial)?; + Ok::<_, Error>(x_val) + }) + .unwrap(); + + // All 5 trials count, including the 2 enqueued ones + assert_eq!(study.n_trials(), 5); +} diff --git a/tests/study/iterator.rs b/tests/study/iterator.rs new file mode 100644 index 0000000..4578a85 --- /dev/null +++ b/tests/study/iterator.rs @@ -0,0 +1,43 @@ +use optimizer::parameter::{FloatParam, Parameter}; +use optimizer::sampler::random::RandomSampler; +use optimizer::{Direction, Study}; + +#[test] +fn test_into_iterator_iterates_all_trials() { + let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(42)); + let x_param = FloatParam::new(0.0, 10.0); + + for _ in 0..5 { + let mut trial = study.create_trial(); + let x = x_param.suggest(&mut trial).unwrap(); + study.complete_trial(trial, x * x); + } + + let mut count = 0; + for trial in &study { + assert_eq!(trial.state, optimizer::TrialState::Complete); + count += 1; + } + assert_eq!(count, 5); +} + +#[test] +fn test_into_iterator_empty_study() { + let study: Study<f64> = Study::new(Direction::Minimize); + + let count = (&study).into_iter().count(); + assert_eq!(count, 0); +} + +#[test] +fn test_into_iterator_preserves_insertion_order() { + let study: Study<f64> = Study::new(Direction::Minimize); + + for i in 0..3 { + 
let trial = study.create_trial(); + study.complete_trial(trial, f64::from(i)); + } + + let ids: Vec<u64> = (&study).into_iter().map(|t| t.id).collect(); + assert_eq!(ids, vec![0, 1, 2]); +} diff --git a/tests/study/main.rs b/tests/study/main.rs new file mode 100644 index 0000000..d9a788c --- /dev/null +++ b/tests/study/main.rs @@ -0,0 +1,15 @@ +#![allow( + clippy::cast_sign_loss, + clippy::cast_precision_loss, + clippy::cast_possible_truncation +)] + +mod ask_tell; +mod builder; +mod constraints; +mod enqueue; +mod iterator; +mod objective; +mod summary; +mod top_trials; +mod workflow; diff --git a/tests/study/objective.rs b/tests/study/objective.rs new file mode 100644 index 0000000..6bb1944 --- /dev/null +++ b/tests/study/objective.rs @@ -0,0 +1,346 @@ +use optimizer::parameter::{FloatParam, Parameter}; +use optimizer::sampler::random::RandomSampler; +use optimizer::{Direction, Error, Study, Trial}; + +#[test] +fn test_callback_early_stopping() { + use std::ops::ControlFlow; + + use optimizer::Objective; + use optimizer::sampler::CompletedTrial; + + struct EarlyStopAfter5 { + x_param: FloatParam, + } + + impl Objective<f64> for EarlyStopAfter5 { + type Error = Error; + fn evaluate(&self, trial: &mut Trial) -> Result<f64, Error> { + let x = self.x_param.suggest(trial)?; + Ok(x) + } + fn after_trial(&self, study: &Study<f64>, _trial: &CompletedTrial<f64>) -> ControlFlow<()> { + if study.n_trials() >= 5 { + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } + } + } + + let study: Study<f64> = Study::new(Direction::Minimize); + study + .optimize_with( + 100, + EarlyStopAfter5 { + x_param: FloatParam::new(0.0, 10.0), + }, + ) + .expect("optimization should succeed"); + + assert_eq!(study.n_trials(), 5, "should have stopped after 5 trials"); +} + +#[test] +fn test_callback_early_stopping_on_first_trial() { + use std::ops::ControlFlow; + + use optimizer::Objective; + use optimizer::sampler::CompletedTrial; + + struct StopImmediately { + x_param: 
FloatParam, + } + + impl Objective<f64> for StopImmediately { + type Error = Error; + fn evaluate(&self, trial: &mut Trial) -> Result<f64, Error> { + let x = self.x_param.suggest(trial)?; + Ok(x) + } + fn after_trial( + &self, + _study: &Study<f64>, + _trial: &CompletedTrial<f64>, + ) -> ControlFlow<()> { + ControlFlow::Break(()) + } + } + + let study: Study<f64> = Study::new(Direction::Minimize); + study + .optimize_with( + 100, + StopImmediately { + x_param: FloatParam::new(0.0, 10.0), + }, + ) + .expect("optimization should succeed"); + + assert_eq!(study.n_trials(), 1, "should have stopped after 1 trial"); +} + +#[test] +fn test_callback_sampler_early_stopping() { + use std::ops::ControlFlow; + + use optimizer::Objective; + use optimizer::sampler::CompletedTrial; + + struct StopAfter3 { + x_param: FloatParam, + } + + impl Objective<f64> for StopAfter3 { + type Error = Error; + fn evaluate(&self, trial: &mut Trial) -> Result<f64, Error> { + let x = self.x_param.suggest(trial)?; + Ok(x) + } + fn after_trial(&self, study: &Study<f64>, _trial: &CompletedTrial<f64>) -> ControlFlow<()> { + if study.n_trials() >= 3 { + ControlFlow::Break(()) + } else { + ControlFlow::Continue(()) + } + } + } + + let sampler = RandomSampler::with_seed(42); + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + study + .optimize_with( + 100, + StopAfter3 { + x_param: FloatParam::new(0.0, 10.0), + }, + ) + .expect("optimization should succeed"); + + assert_eq!(study.n_trials(), 3); +} + +#[test] +fn test_retries_successful_trials_not_retried() { + use std::sync::Arc; + use std::sync::atomic::{AtomicU32, Ordering}; + + use optimizer::Objective; + + struct SuccessObj { + x_param: FloatParam, + call_count: Arc<AtomicU32>, + } + + impl Objective<f64> for SuccessObj { + type Error = Error; + fn evaluate(&self, trial: &mut Trial) -> Result<f64, Error> { + let x = self.x_param.suggest(trial)?; + self.call_count.fetch_add(1, Ordering::Relaxed); + Ok(x * x) + } + fn 
max_retries(&self) -> usize { + 3 + } + } + + let study: Study<f64> = Study::new(Direction::Minimize); + let call_count = Arc::new(AtomicU32::new(0)); + let obj = SuccessObj { + x_param: FloatParam::new(0.0, 10.0), + call_count: Arc::clone(&call_count), + }; + + study.optimize_with(5, obj).unwrap(); + + // All trials succeed on first try — exactly 5 calls + assert_eq!(call_count.load(Ordering::Relaxed), 5); + assert_eq!(study.n_trials(), 5); +} + +#[test] +fn test_retries_failed_trials_retried_up_to_max() { + use std::sync::Arc; + use std::sync::atomic::{AtomicU32, Ordering}; + + use optimizer::Objective; + + struct AlwaysFailObj { + x_param: FloatParam, + call_count: Arc<AtomicU32>, + } + + impl Objective<f64> for AlwaysFailObj { + type Error = String; + fn evaluate(&self, trial: &mut Trial) -> Result<f64, String> { + let _ = self.x_param.suggest(trial).map_err(|e| e.to_string())?; + self.call_count.fetch_add(1, Ordering::Relaxed); + Err("always fails".to_string()) + } + fn max_retries(&self) -> usize { + 3 + } + } + + let study: Study<f64> = Study::new(Direction::Minimize); + let call_count = Arc::new(AtomicU32::new(0)); + let obj = AlwaysFailObj { + x_param: FloatParam::new(0.0, 10.0), + call_count: Arc::clone(&call_count), + }; + + let result = study.optimize_with(1, obj); + + // 1 initial attempt + 3 retries = 4 total calls + assert_eq!(call_count.load(Ordering::Relaxed), 4); + // No trials completed + assert!(matches!(result, Err(Error::NoCompletedTrials))); +} + +#[test] +fn test_retries_permanently_failed_after_exhaustion() { + use optimizer::Objective; + + struct AlwaysFailObj { + x_param: FloatParam, + } + + impl Objective<f64> for AlwaysFailObj { + type Error = String; + fn evaluate(&self, trial: &mut Trial) -> Result<f64, String> { + let _ = self.x_param.suggest(trial).map_err(|e| e.to_string())?; + Err("transient error".to_string()) + } + fn max_retries(&self) -> usize { + 2 + } + } + + let study: Study<f64> = Study::new(Direction::Minimize); + let obj 
= AlwaysFailObj { + x_param: FloatParam::new(0.0, 10.0), + }; + + let result = study.optimize_with(3, obj); + + assert!( + matches!(result, Err(Error::NoCompletedTrials)), + "all trials should permanently fail" + ); + assert_eq!( + study.n_trials(), + 0, + "no completed trials should be recorded" + ); +} + +#[test] +fn test_retries_uses_same_parameters() { + use std::sync::atomic::{AtomicU32, Ordering}; + use std::sync::{Arc, Mutex}; + + use optimizer::Objective; + + struct RetryObj { + x_param: FloatParam, + seen_values: Arc<Mutex<Vec<f64>>>, + call_count: Arc<AtomicU32>, + } + + impl Objective<f64> for RetryObj { + type Error = String; + fn evaluate(&self, trial: &mut Trial) -> Result<f64, String> { + let x = self.x_param.suggest(trial).map_err(|e| e.to_string())?; + self.seen_values.lock().unwrap().push(x); + let count = self.call_count.fetch_add(1, Ordering::Relaxed) + 1; + // Fail first two attempts, succeed on third + if count < 3 { + Err("transient".to_string()) + } else { + Ok(x * x) + } + } + fn max_retries(&self) -> usize { + 2 + } + } + + let study: Study<f64> = Study::new(Direction::Minimize); + let seen_values = Arc::new(Mutex::new(Vec::new())); + let call_count = Arc::new(AtomicU32::new(0)); + let obj = RetryObj { + x_param: FloatParam::new(0.0, 10.0), + seen_values: Arc::clone(&seen_values), + call_count: Arc::clone(&call_count), + }; + + study.optimize_with(1, obj).unwrap(); + + let values = seen_values.lock().unwrap(); + assert_eq!(values.len(), 3, "should be called 3 times (1 + 2 retries)"); + // All three calls should have gotten the same parameter value + assert_eq!(values[0], values[1]); + assert_eq!(values[1], values[2]); +} + +#[test] +fn test_retries_n_trials_counts_unique_configs() { + use std::sync::Arc; + use std::sync::atomic::{AtomicU32, Ordering}; + + use optimizer::Objective; + + struct FailFirstObj { + x_param: FloatParam, + call_count: Arc<AtomicU32>, + } + + impl Objective<f64> for FailFirstObj { + type Error = String; + fn 
evaluate(&self, trial: &mut Trial) -> Result<f64, String> { + let x = self.x_param.suggest(trial).map_err(|e| e.to_string())?; + let count = self.call_count.fetch_add(1, Ordering::Relaxed) + 1; + // Fail first attempt of each config, succeed on retry + if count % 2 == 1 { + Err("transient".to_string()) + } else { + Ok(x * x) + } + } + fn max_retries(&self) -> usize { + 2 + } + } + + let study: Study<f64> = Study::new(Direction::Minimize); + let call_count = Arc::new(AtomicU32::new(0)); + let obj = FailFirstObj { + x_param: FloatParam::new(0.0, 10.0), + call_count: Arc::clone(&call_count), + }; + + study.optimize_with(3, obj).unwrap(); + + // 3 unique configs, each needing 2 calls = 6 total calls + assert_eq!(call_count.load(Ordering::Relaxed), 6); + // But only 3 completed trials + assert_eq!(study.n_trials(), 3); +} + +#[test] +fn test_retries_with_zero_max_retries_same_as_optimize() { + let study: Study<f64> = Study::new(Direction::Minimize); + let x_param = FloatParam::new(0.0, 10.0); + let call_count = std::cell::Cell::new(0u32); + + study + .optimize(5, |trial| { + let x = x_param.suggest(trial)?; + call_count.set(call_count.get() + 1); + Ok::<_, Error>(x * x) + }) + .unwrap(); + + assert_eq!(call_count.get(), 5); + assert_eq!(study.n_trials(), 5); +} diff --git a/tests/study/summary.rs b/tests/study/summary.rs new file mode 100644 index 0000000..8f688af --- /dev/null +++ b/tests/study/summary.rs @@ -0,0 +1,71 @@ +use optimizer::parameter::{FloatParam, Parameter}; +use optimizer::sampler::random::RandomSampler; +use optimizer::{Direction, Error, Study}; + +#[test] +fn test_summary_with_completed_trials() { + let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(1)); + let x = FloatParam::new(0.0, 10.0).name("x"); + + study + .optimize(5, |trial| { + let val = x.suggest(trial)?; + Ok::<_, Error>(val * val) + }) + .unwrap(); + + let summary = study.summary(); + assert!(summary.contains("Minimize")); + 
assert!(summary.contains("5 trials")); + assert!(summary.contains("Best value:")); + assert!(summary.contains("x = ")); +} + +#[test] +fn test_summary_no_completed_trials() { + let study: Study<f64> = Study::new(Direction::Maximize); + let summary = study.summary(); + assert!(summary.contains("Maximize")); + assert!(summary.contains("0 trials")); + assert!(!summary.contains("Best value:")); +} + +#[test] +fn test_summary_with_pruned_trials() { + let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(1)); + let x = FloatParam::new(0.0, 10.0).name("x"); + + // Manually create some complete and pruned trials + for _ in 0..3 { + let mut trial = study.create_trial(); + let val = x.suggest(&mut trial).unwrap(); + study.complete_trial(trial, val); + } + for _ in 0..2 { + let mut trial = study.create_trial(); + let _ = x.suggest(&mut trial).unwrap(); + study.prune_trial(trial); + } + + let summary = study.summary(); + // Should show breakdown when there are pruned trials + if study.n_pruned_trials() > 0 { + assert!(summary.contains("complete")); + assert!(summary.contains("pruned")); + } +} + +#[test] +fn test_display_matches_summary() { + let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(1)); + let x = FloatParam::new(0.0, 10.0).name("x"); + + study + .optimize(3, |trial| { + let val = x.suggest(trial)?; + Ok::<_, Error>(val) + }) + .unwrap(); + + assert_eq!(format!("{study}"), study.summary()); +} diff --git a/tests/study/top_trials.rs b/tests/study/top_trials.rs new file mode 100644 index 0000000..a3844af --- /dev/null +++ b/tests/study/top_trials.rs @@ -0,0 +1,76 @@ +use optimizer::{Direction, Study}; + +#[test] +fn test_top_trials_minimize() { + let study: Study<f64> = Study::new(Direction::Minimize); + + // Manually complete trials with known values + for &val in &[5.0, 1.0, 3.0, 2.0, 4.0] { + let trial = study.create_trial(); + study.complete_trial(trial, val); + } + + let top3 = 
study.top_trials(3); + assert_eq!(top3.len(), 3); + assert_eq!(top3[0].value, 1.0); + assert_eq!(top3[1].value, 2.0); + assert_eq!(top3[2].value, 3.0); +} + +#[test] +fn test_top_trials_maximize() { + let study: Study<f64> = Study::new(Direction::Maximize); + + for &val in &[5.0, 1.0, 3.0, 2.0, 4.0] { + let trial = study.create_trial(); + study.complete_trial(trial, val); + } + + let top3 = study.top_trials(3); + assert_eq!(top3.len(), 3); + assert_eq!(top3[0].value, 5.0); + assert_eq!(top3[1].value, 4.0); + assert_eq!(top3[2].value, 3.0); +} + +#[test] +fn test_top_trials_n_greater_than_total() { + let study: Study<f64> = Study::new(Direction::Minimize); + + for &val in &[3.0, 1.0] { + let trial = study.create_trial(); + study.complete_trial(trial, val); + } + + let top = study.top_trials(10); + assert_eq!(top.len(), 2); + assert_eq!(top[0].value, 1.0); + assert_eq!(top[1].value, 3.0); +} + +#[test] +fn test_top_trials_empty() { + let study: Study<f64> = Study::new(Direction::Minimize); + + let top = study.top_trials(5); + assert!(top.is_empty()); +} + +#[test] +fn test_top_trials_excludes_pruned() { + let study: Study<f64> = Study::new(Direction::Minimize); + + // Complete some trials + for &val in &[5.0, 1.0, 3.0] { + let trial = study.create_trial(); + study.complete_trial(trial, val); + } + + // Prune a trial (it gets a default value of 0.0 but should be excluded) + let trial = study.create_trial(); + study.prune_trial(trial); + + let top = study.top_trials(5); + assert_eq!(top.len(), 3, "pruned trial should be excluded"); + assert_eq!(top[0].value, 1.0); +} diff --git a/tests/study/workflow.rs b/tests/study/workflow.rs new file mode 100644 index 0000000..f41bff8 --- /dev/null +++ b/tests/study/workflow.rs @@ -0,0 +1,259 @@ +use optimizer::parameter::{BoolParam, FloatParam, IntParam, Parameter}; +use optimizer::sampler::tpe::TpeSampler; +use optimizer::{Direction, Error, Study}; + +#[test] +fn test_study_basic_workflow() { + let study: Study<f64> = 
Study::new(Direction::Minimize); + let x_param = FloatParam::new(-5.0, 5.0); + + study + .optimize(10, |trial| { + let x = x_param.suggest(trial)?; + Ok::<_, Error>(x * x) + }) + .expect("optimization should succeed"); + + assert_eq!(study.n_trials(), 10); + let best = study.best_trial().expect("should have best trial"); + assert!(best.value >= 0.0, "x^2 should be non-negative"); +} + +#[test] +fn test_study_with_failures() { + let study: Study<f64> = Study::new(Direction::Minimize); + let x_param = FloatParam::new(-5.0, 5.0); + + // Every other trial fails + let mut counter = 0; + study + .optimize(10, |trial| { + counter += 1; + if counter % 2 == 0 { + return Err::<f64, &str>("intentional failure"); + } + let x = x_param.suggest(trial).map_err(|_| "param error")?; + Ok(x * x) + }) + .expect("optimization should succeed with some failures"); + + // Only half the trials should have succeeded + assert_eq!(study.n_trials(), 5, "only 5 trials should have completed"); +} + +#[test] +fn test_no_completed_trials_error() { + let study: Study<f64> = Study::new(Direction::Minimize); + + let result = study.best_trial(); + assert!(matches!(result, Err(Error::NoCompletedTrials))); +} + +#[test] +fn test_study_direction() { + let study_min: Study<f64> = Study::new(Direction::Minimize); + assert_eq!(study_min.direction(), Direction::Minimize); + + let study_max: Study<f64> = Study::new(Direction::Maximize); + assert_eq!(study_max.direction(), Direction::Maximize); +} + +#[test] +fn test_study_trials_iteration() { + let study: Study<f64> = Study::new(Direction::Minimize); + let x_param = FloatParam::new(0.0, 1.0); + + study + .optimize(5, |trial| { + let x = x_param.suggest(trial)?; + Ok::<_, Error>(x) + }) + .unwrap(); + + let trials = study.trials(); + assert_eq!(trials.len(), 5); + + for trial in &trials { + assert!( + !trial.params.is_empty(), + "each trial should have parameters" + ); + } +} + +#[test] +fn test_study_set_sampler() { + let mut study: Study<f64> = 
Study::new(Direction::Minimize); + + let tpe = TpeSampler::builder() + .seed(42) + .n_startup_trials(5) + .build() + .unwrap(); + study.set_sampler(tpe); + + let x_param = FloatParam::new(-5.0, 5.0); + + study + .optimize(10, |trial| { + let x = x_param.suggest(trial)?; + Ok::<_, Error>(x * x) + }) + .expect("optimization should succeed with new sampler"); + + assert_eq!(study.n_trials(), 10); +} + +#[test] +fn test_study_with_i32_value_type() { + let study: Study<i32> = Study::new(Direction::Minimize); + let x_param = IntParam::new(-10, 10); + + study + .optimize(10, |trial| { + let x = x_param.suggest(trial)?; + Ok::<_, Error>(x.abs() as i32) + }) + .expect("optimization should succeed"); + + assert_eq!(study.n_trials(), 10); + let best = study.best_trial().expect("should have best trial"); + assert!(best.value >= 0, "absolute value should be non-negative"); +} + +#[test] +fn test_optimize_all_trials_fail() { + let study: Study<f64> = Study::new(Direction::Minimize); + + let result = study.optimize(5, |_trial| Err::<f64, &str>("always fails")); + + assert!( + matches!(result, Err(Error::NoCompletedTrials)), + "should return NoCompletedTrials when all trials fail" + ); +} + +#[test] +fn test_best_value() { + let study: Study<f64> = Study::new(Direction::Minimize); + let x_param = FloatParam::new(0.0, 10.0); + + study + .optimize(10, |trial| { + let x = x_param.suggest(trial)?; + Ok::<_, Error>(x) + }) + .unwrap(); + + let best_value = study.best_value().expect("should have best value"); + let best_trial = study.best_trial().expect("should have best trial"); + + assert_eq!( + best_value, best_trial.value, + "best_value should match best_trial.value" + ); +} + +#[test] +fn test_best_trial_with_nan_values() { + let study: Study<f64> = Study::new(Direction::Minimize); + let x_param = FloatParam::new(0.0, 10.0); + + study + .optimize(5, |trial| { + let x = x_param.suggest(trial)?; + Ok::<_, Error>(x) + }) + .unwrap(); + + let best = study.best_trial(); + 
assert!(best.is_ok()); +} + +#[test] +fn test_manual_trial_completion() { + let study: Study<f64> = Study::new(Direction::Minimize); + let x_param = FloatParam::new(0.0, 10.0); + + // Manually create and complete trials + let mut trial = study.create_trial(); + let x = x_param.suggest(&mut trial).unwrap(); + study.complete_trial(trial, x * x); + + let mut trial2 = study.create_trial(); + let y = x_param.suggest(&mut trial2).unwrap(); + study.complete_trial(trial2, y * y); + + // Manually fail a trial + let trial3 = study.create_trial(); + study.fail_trial(trial3, "test failure"); + + // Only 2 completed trials + assert_eq!(study.n_trials(), 2); +} + +#[test] +fn test_multiple_params_in_optimization() { + let study: Study<f64> = Study::new(Direction::Minimize); + let x_param = FloatParam::new(-10.0, 10.0); + let n_param = IntParam::new(1, 5); + + study + .optimize(10, |trial| { + let x = x_param.suggest(trial)?; + let n = n_param.suggest(trial)?; + Ok::<_, Error>(x * x + n as f64) + }) + .unwrap(); + + assert_eq!(study.n_trials(), 10); +} + +#[test] +fn test_suggest_bool_in_optimization() { + let study: Study<f64> = Study::new(Direction::Minimize); + let use_feature_param = BoolParam::new(); + let x_param = FloatParam::new(0.0, 10.0); + + study + .optimize(10, |trial| { + let use_feature = use_feature_param.suggest(trial)?; + let x = x_param.suggest(trial)?; + + let value = if use_feature { x } else { x * 2.0 }; + Ok::<_, Error>(value) + }) + .unwrap(); + + assert_eq!(study.n_trials(), 10); +} + +#[test] +fn test_completed_trial_get() { + let study: Study<f64> = Study::new(Direction::Minimize); + let x_param = FloatParam::new(-10.0, 10.0).name("x"); + let n_param = IntParam::new(1, 10).name("n"); + + study + .optimize(5, |trial| { + let x = x_param.suggest(trial)?; + let n = n_param.suggest(trial)?; + Ok::<_, Error>(x * x + n as f64) + }) + .unwrap(); + + let best = study.best_trial().unwrap(); + let x_val: f64 = best.get(&x_param).unwrap(); + let n_val: i64 = 
best.get(&n_param).unwrap(); + assert!((-10.0..=10.0).contains(&x_val)); + assert!((1..=10).contains(&n_val)); +} + +#[test] +fn test_single_value_int_range() { + let param = IntParam::new(5, 5); + let mut trial = optimizer::Trial::new(0); + + let n = param.suggest(&mut trial).unwrap(); + assert_eq!(n, 5, "single-value range should return that value"); +} From c20a53dfbac3c48346dcd5f93967940675d63a31 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 12:52:02 +0100 Subject: [PATCH 07/48] fix: bump minimum tokio version to 1.30 for JoinSet::len() --- Cargo.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 93d67ab..c746972 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ readme = "README.md" fastrand = "2.3" thiserror = "2" parking_lot = "0.12" -tokio = { version = "1", features = ["sync", "rt-multi-thread"], optional = true } +tokio = { version = "1.30", features = ["sync", "rt-multi-thread"], optional = true } optimizer-derive = { version = "0.1.0", path = "optimizer-derive", optional = true } serde = { version = "1", features = ["derive"], optional = true } serde_json = { version = "1", optional = true } @@ -40,7 +40,7 @@ cma-es = ["dep:nalgebra"] gp = ["dep:nalgebra"] [dev-dependencies] -tokio = { version = "1", features = ["rt-multi-thread", "macros", "time"] } +tokio = { version = "1.30", features = ["rt-multi-thread", "macros", "time"] } optimizer-derive = { version = "0.1.0", path = "optimizer-derive" } serde_json = "1" criterion = { version = "0.8", features = ["html_reports"] } From 47b5f9cec812a1976fa6c63a96753af054a1f182 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 13:09:14 +0100 Subject: [PATCH 08/48] feat: unify optimize and optimize_with via blanket Objective impl - Add blanket `impl Objective<V> for Fn(&mut Trial) -> Result<V, E>` so closures work directly with `optimize` - Rewrite optimize, 
optimize_async, optimize_parallel to accept `impl Objective<V>` with before_trial/after_trial hooks - Remove optimize_with, optimize_with_async, optimize_with_parallel - Remove max_retries and retry logic from Objective trait - Add explicit closure type annotations for HRTB inference - Convert FnMut test closures to Fn via RefCell/Cell --- benches/optimization.rs | 8 +- examples/basic_optimization.rs | 2 +- examples/early_stopping.rs | 2 +- examples/journal_storage.rs | 4 +- examples/multi_objective.rs | 2 +- examples/parameter_types.rs | 2 +- examples/pruning.rs | 2 +- examples/sampler_comparison.rs | 2 +- src/fanova.rs | 2 +- src/lib.rs | 2 +- src/multi_objective.rs | 4 +- src/objective.rs | 49 ++- src/sampler/mod.rs | 2 +- src/sampler/moead.rs | 2 +- src/sampler/motpe.rs | 7 +- src/sampler/nsga2.rs | 2 +- src/sampler/nsga3.rs | 2 +- src/sampler/tpe/multivariate.rs | 2 +- src/storage/journal.rs | 4 +- src/study.rs | 475 ++++-------------------- src/visualization.rs | 2 +- tests/export_tests.rs | 10 +- tests/fanova_tests.rs | 6 +- tests/journal_tests.rs | 10 +- tests/multi_objective_tests.rs | 46 +-- tests/parameter_tests.rs | 2 +- tests/sampler/bohb.rs | 6 +- tests/sampler/cma_es.rs | 20 +- tests/sampler/differential_evolution.rs | 26 +- tests/sampler/gp.rs | 20 +- tests/sampler/multivariate_tpe.rs | 20 +- tests/sampler/random.rs | 38 +- tests/sampler/tpe.rs | 30 +- tests/serde_tests.rs | 12 +- tests/study/builder.rs | 4 +- tests/study/enqueue.rs | 12 +- tests/study/objective.rs | 206 +--------- tests/study/summary.rs | 4 +- tests/study/workflow.rs | 30 +- tests/user_attr_tests.rs | 16 +- tests/visualization_tests.rs | 12 +- 41 files changed, 316 insertions(+), 793 deletions(-) diff --git a/benches/optimization.rs b/benches/optimization.rs index cfdaafe..3435781 100644 --- a/benches/optimization.rs +++ b/benches/optimization.rs @@ -23,7 +23,7 @@ fn bench_tpe_sphere(c: &mut Criterion) { b.iter(|| { let study = 
Study::minimize(TpeSampler::builder().seed(42).build().unwrap()); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let x: Vec<f64> = params .iter() .map(|p| p.suggest(trial)) @@ -48,7 +48,7 @@ fn bench_tpe_rosenbrock(c: &mut Criterion) { b.iter(|| { let study = Study::minimize(TpeSampler::builder().seed(42).build().unwrap()); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let x: Vec<f64> = params .iter() .map(|p| p.suggest(trial)) @@ -72,7 +72,7 @@ fn bench_random_vs_tpe(c: &mut Criterion) { b.iter(|| { let study = Study::minimize(RandomSampler::with_seed(42)); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let x: Vec<f64> = params .iter() .map(|p| p.suggest(trial)) @@ -88,7 +88,7 @@ fn bench_random_vs_tpe(c: &mut Criterion) { b.iter(|| { let study = Study::minimize(TpeSampler::builder().seed(42).build().unwrap()); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let x: Vec<f64> = params .iter() .map(|p| p.suggest(trial)) diff --git a/examples/basic_optimization.rs b/examples/basic_optimization.rs index 097853a..e868d75 100644 --- a/examples/basic_optimization.rs +++ b/examples/basic_optimization.rs @@ -17,7 +17,7 @@ fn main() { // Run 50 trials, each evaluating f(x) = (x - 3)² study - .optimize(50, |trial| { + .optimize(50, |trial: &mut optimizer::Trial| { let x_val = x.suggest(trial)?; let value = (x_val - 3.0).powi(2); Ok::<_, Error>(value) diff --git a/examples/early_stopping.rs b/examples/early_stopping.rs index efbf78c..7bfe707 100644 --- a/examples/early_stopping.rs +++ b/examples/early_stopping.rs @@ -44,7 +44,7 @@ fn main() -> optimizer::Result<()> { target: 0.01, }; - study.optimize_with(100, objective)?; + study.optimize(100, objective)?; let best = study.best_trial()?; println!( diff --git a/examples/journal_storage.rs b/examples/journal_storage.rs index c85eba4..a3410b6 100644 --- a/examples/journal_storage.rs +++ 
b/examples/journal_storage.rs @@ -24,7 +24,7 @@ fn main() -> optimizer::Result<()> { .storage(storage) .build(); - study.optimize(20, |trial| { + study.optimize(20, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; Ok::<_, optimizer::Error>(xv * xv) })?; @@ -46,7 +46,7 @@ fn main() -> optimizer::Result<()> { .build(); let before = study.n_trials(); - study.optimize(10, |trial| { + study.optimize(10, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; Ok::<_, optimizer::Error>(xv * xv) })?; diff --git a/examples/multi_objective.rs b/examples/multi_objective.rs index 3d66a8a..804e47a 100644 --- a/examples/multi_objective.rs +++ b/examples/multi_objective.rs @@ -15,7 +15,7 @@ fn main() -> optimizer::Result<()> { // Classic bi-objective: f1(x) = x², f2(x) = (x-1)² // The Pareto front is the curve where improving f1 worsens f2. - study.optimize(50, |trial| { + study.optimize(50, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let f1 = xv * xv; let f2 = (xv - 1.0) * (xv - 1.0); diff --git a/examples/parameter_types.rs b/examples/parameter_types.rs index c509c5b..5efc823 100644 --- a/examples/parameter_types.rs +++ b/examples/parameter_types.rs @@ -41,7 +41,7 @@ fn main() { // --- Run the optimization --- study - .optimize(30, |trial| { + .optimize(30, |trial: &mut optimizer::Trial| { let lr_val = lr.suggest(trial)?; let layers = n_layers.suggest(trial)?; let opt = optimizer.suggest(trial)?; diff --git a/examples/pruning.rs b/examples/pruning.rs index afbfea2..284afe5 100644 --- a/examples/pruning.rs +++ b/examples/pruning.rs @@ -26,7 +26,7 @@ fn main() -> optimizer::Result<()> { let n_epochs: u64 = 20; - study.optimize(30, |trial| { + study.optimize(30, |trial: &mut optimizer::Trial| { let lr_val = lr.suggest(trial)?; let mom = momentum.suggest(trial)?; diff --git a/examples/sampler_comparison.rs b/examples/sampler_comparison.rs index 3034ab3..9e484e2 100644 --- a/examples/sampler_comparison.rs +++ b/examples/sampler_comparison.rs @@ 
-20,7 +20,7 @@ fn run_study(study: Study<f64>, n_trials: usize) -> f64 { let y = FloatParam::new(-3.0, 3.0).name("y"); study - .optimize(n_trials, |trial| { + .optimize(n_trials, |trial: &mut optimizer::Trial| { let x_val = x.suggest(trial)?; let y_val = y.suggest(trial)?; Ok::<_, Error>(sphere(x_val, y_val)) diff --git a/src/fanova.rs b/src/fanova.rs index 91b6af0..89cd452 100644 --- a/src/fanova.rs +++ b/src/fanova.rs @@ -41,7 +41,7 @@ //! let y = FloatParam::new(0.0, 10.0).name("y"); //! //! study -//! .optimize(50, |trial| { +//! .optimize(50, |trial: &mut optimizer::Trial| { //! let xv = x.suggest(trial)?; //! let yv = y.suggest(trial)?; //! // x matters much more than y diff --git a/src/lib.rs b/src/lib.rs index a696b6a..1424b33 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -26,7 +26,7 @@ //! let x = FloatParam::new(-10.0, 10.0).name("x"); //! //! study -//! .optimize(50, |trial| { +//! .optimize(50, |trial: &mut optimizer::Trial| { //! let v = x.suggest(trial)?; //! Ok::<_, Error>((v - 3.0).powi(2)) //! }) diff --git a/src/multi_objective.rs b/src/multi_objective.rs index 8dd144a..a82b1ad 100644 --- a/src/multi_objective.rs +++ b/src/multi_objective.rs @@ -35,7 +35,7 @@ //! let x = FloatParam::new(0.0, 1.0); //! //! study -//! .optimize(20, |trial| { +//! .optimize(20, |trial: &mut optimizer::Trial| { //! let xv = x.suggest(trial)?; //! Ok::<_, optimizer::Error>(vec![xv, 1.0 - xv]) //! }) @@ -229,7 +229,7 @@ impl Sampler for MoSamplerBridge { /// let x = FloatParam::new(0.0, 1.0); /// /// study -/// .optimize(30, |trial| { +/// .optimize(30, |trial: &mut optimizer::Trial| { /// let xv = x.suggest(trial)?; /// Ok::<_, optimizer::Error>(vec![xv, 1.0 - xv]) /// }) diff --git a/src/objective.rs b/src/objective.rs index d1ce710..c22babf 100644 --- a/src/objective.rs +++ b/src/objective.rs @@ -1,6 +1,9 @@ //! The [`Objective`] trait defines what gets optimized. //! -//! For simple closures, pass them directly to +//! # Closures work directly +//! +//! 
Any `Fn(&mut Trial) -> Result<V, E>` closure automatically implements +//! [`Objective`], so you can pass closures straight to //! [`Study::optimize`](crate::Study::optimize): //! //! ``` @@ -10,16 +13,18 @@ //! let x = FloatParam::new(-10.0, 10.0).name("x"); //! //! study -//! .optimize(50, |trial| { +//! .optimize(50, |trial: &mut optimizer::Trial| { //! let v = x.suggest(trial)?; //! Ok::<_, Error>((v - 3.0).powi(2)) //! }) //! .unwrap(); //! ``` //! -//! For richer control — early stopping, retries, or per-trial logging — -//! implement [`Objective`] on a struct and pass it to -//! [`Study::optimize_with`](crate::Study::optimize_with): +//! # Structs for lifecycle hooks +//! +//! For richer control — early stopping or per-trial logging — implement +//! [`Objective`] on a struct and pass it to the same +//! [`Study::optimize`](crate::Study::optimize) method: //! //! ``` //! use std::ops::ControlFlow; @@ -54,7 +59,7 @@ //! x: FloatParam::new(-10.0, 10.0).name("x"), //! target: 1.0, //! }; -//! study.optimize_with(200, obj).unwrap(); +//! study.optimize(200, obj).unwrap(); //! assert!(study.best_value().unwrap() < 1.0); //! ``` @@ -69,15 +74,13 @@ use crate::trial::Trial; /// The only required method is [`evaluate`](Objective::evaluate), which /// computes the objective value for a given trial. Optional hooks provide /// early stopping ([`before_trial`](Objective::before_trial), -/// [`after_trial`](Objective::after_trial)) and automatic retries -/// ([`max_retries`](Objective::max_retries)). +/// [`after_trial`](Objective::after_trial)). /// -/// # When to use `Objective` vs a closure +/// # Closures implement `Objective` automatically /// -/// - **Closure** — pass directly to [`Study::optimize`](crate::Study::optimize) -/// for simple evaluate-only objectives. -/// - **`Objective` struct** — implement this trait when you need hooks -/// (`before_trial`, `after_trial`) or retries. 
+/// A blanket implementation covers all `Fn(&mut Trial) -> Result<V, E>` +/// closures, so you can pass closures directly to +/// [`Study::optimize`](crate::Study::optimize) without wrapping them. /// /// # Thread safety /// @@ -120,13 +123,19 @@ pub trait Objective<V: PartialOrd = f64> { fn after_trial(&self, _study: &Study<V>, _trial: &CompletedTrial<V>) -> ControlFlow<()> { ControlFlow::Continue(()) } +} - /// Maximum number of retries for a failed trial. - /// - /// When `evaluate` returns a non-pruning error and retries remain, - /// the same parameter configuration is re-evaluated. Set to `0` - /// (the default) to disable retries. - fn max_retries(&self) -> usize { - 0 +/// Blanket implementation: any `Fn(&mut Trial) -> Result<V, E>` is an +/// `Objective` with no lifecycle hooks. +impl<F, V, E> Objective<V> for F +where + F: Fn(&mut Trial) -> Result<V, E>, + V: PartialOrd, + E: ToString + 'static, +{ + type Error = E; + + fn evaluate(&self, trial: &mut Trial) -> Result<V, E> { + self(trial) } } diff --git a/src/sampler/mod.rs b/src/sampler/mod.rs index b512c49..cfa5537 100644 --- a/src/sampler/mod.rs +++ b/src/sampler/mod.rs @@ -138,7 +138,7 @@ impl<V> CompletedTrial<V> { /// let x = FloatParam::new(-10.0, 10.0); /// /// study - /// .optimize(5, |trial| { + /// .optimize(5, |trial: &mut optimizer::Trial| { /// let val = x.suggest(trial)?; /// Ok::<_, optimizer::Error>(val * val) /// }) diff --git a/src/sampler/moead.rs b/src/sampler/moead.rs index 6f11d44..4e6bb78 100644 --- a/src/sampler/moead.rs +++ b/src/sampler/moead.rs @@ -68,7 +68,7 @@ //! //! let x = FloatParam::new(0.0, 1.0); //! study -//! .optimize(100, |trial| { +//! .optimize(100, |trial: &mut optimizer::Trial| { //! let xv = x.suggest(trial)?; //! Ok::<_, optimizer::Error>(vec![xv, 1.0 - xv]) //! }) diff --git a/src/sampler/motpe.rs b/src/sampler/motpe.rs index 823b791..abffc78 100644 --- a/src/sampler/motpe.rs +++ b/src/sampler/motpe.rs @@ -48,7 +48,7 @@ //! //! 
let x = FloatParam::new(0.0, 1.0); //! study -//! .optimize(30, |trial| { +//! .optimize(30, |trial: &mut optimizer::Trial| { //! let xv = x.suggest(trial)?; //! Ok::<_, optimizer::Error>(vec![xv, 1.0 - xv]) //! }) @@ -103,7 +103,7 @@ use crate::{pareto, rng_util}; /// /// let x = FloatParam::new(0.0, 1.0); /// study -/// .optimize(30, |trial| { +/// .optimize(30, |trial: &mut optimizer::Trial| { /// let xv = x.suggest(trial)?; /// Ok::<_, optimizer::Error>(vec![xv, 1.0 - xv]) /// }) @@ -660,6 +660,7 @@ mod tests { use super::*; use crate::distribution::{CategoricalDistribution, FloatDistribution, IntDistribution}; use crate::parameter::ParamId; + use crate::trial::Trial; fn create_mo_trial( id: u64, @@ -930,7 +931,7 @@ mod tests { let x = FloatParam::new(0.0, 1.0); study - .optimize(30, |trial| { + .optimize(30, |trial: &mut Trial| { let xv = x.suggest(trial)?; Ok::<_, crate::Error>(vec![xv, 1.0 - xv]) }) diff --git a/src/sampler/nsga2.rs b/src/sampler/nsga2.rs index 0a8f873..f806032 100644 --- a/src/sampler/nsga2.rs +++ b/src/sampler/nsga2.rs @@ -55,7 +55,7 @@ //! //! let x = FloatParam::new(0.0, 1.0); //! study -//! .optimize(50, |trial| { +//! .optimize(50, |trial: &mut optimizer::Trial| { //! let xv = x.suggest(trial)?; //! Ok::<_, optimizer::Error>(vec![xv * xv, (xv - 1.0).powi(2)]) //! }) diff --git a/src/sampler/nsga3.rs b/src/sampler/nsga3.rs index 44028ef..7ac840c 100644 --- a/src/sampler/nsga3.rs +++ b/src/sampler/nsga3.rs @@ -66,7 +66,7 @@ //! let x = FloatParam::new(0.0, 1.0); //! let y = FloatParam::new(0.0, 1.0); //! study -//! .optimize(100, |trial| { +//! .optimize(100, |trial: &mut optimizer::Trial| { //! let xv = x.suggest(trial)?; //! let yv = y.suggest(trial)?; //! 
Ok::<_, optimizer::Error>(vec![xv, yv, (1.0 - xv - yv).abs()]) diff --git a/src/sampler/tpe/multivariate.rs b/src/sampler/tpe/multivariate.rs index 72ea840..502125f 100644 --- a/src/sampler/tpe/multivariate.rs +++ b/src/sampler/tpe/multivariate.rs @@ -205,7 +205,7 @@ pub enum ConstantLiarStrategy { /// let y = FloatParam::new(-5.0, 5.0); /// /// study -/// .optimize(30, |trial| { +/// .optimize(30, |trial: &mut optimizer::Trial| { /// let xv = x.suggest(trial)?; /// let yv = y.suggest(trial)?; /// Ok::<_, optimizer::Error>(xv * xv + yv * yv) diff --git a/src/storage/journal.rs b/src/storage/journal.rs index 367adfb..1ee477b 100644 --- a/src/storage/journal.rs +++ b/src/storage/journal.rs @@ -42,7 +42,7 @@ //! let storage = JournalStorage::<f64>::new("trials.jsonl"); //! let mut study = Study::builder().minimize().storage(storage).build(); //! study -//! .optimize(50, |trial| { +//! .optimize(50, |trial: &mut optimizer::Trial| { //! let x = FloatParam::new(-5.0, 5.0).suggest(trial)?; //! Ok::<_, optimizer::Error>(x * x) //! }) @@ -52,7 +52,7 @@ //! let storage = JournalStorage::<f64>::open("trials.jsonl").unwrap(); //! let mut study = Study::builder().minimize().storage(storage).build(); //! study -//! .optimize(50, |trial| { +//! .optimize(50, |trial: &mut optimizer::Trial| { //! let x = FloatParam::new(-5.0, 5.0).suggest(trial)?; //! Ok::<_, optimizer::Error>(x * x) //! }) diff --git a/src/study.rs b/src/study.rs index fd9acb6..dc4c648 100644 --- a/src/study.rs +++ b/src/study.rs @@ -411,22 +411,6 @@ where .map(|t| t.id) } - /// Create a new trial with pre-set parameter values. - /// - /// The trial gets a new unique ID but reuses the given parameters. When - /// `suggest_param` is called on the resulting trial, fixed values are - /// returned instead of sampling. 
- fn create_trial_with_params(&self, params: HashMap<ParamId, ParamValue>) -> Trial { - let id = self.next_trial_id(); - let mut trial = if let Some(factory) = &self.trial_factory { - factory(id) - } else { - Trial::new(id) - }; - trial.set_fixed_params(params); - trial - } - /// Return the number of enqueued parameter configurations. /// /// See [`enqueue`](Self::enqueue) for how to add configurations. @@ -879,12 +863,15 @@ where completed } - /// Run optimization with a closure. + /// Run optimization with an objective. /// - /// Runs up to `n_trials` evaluations of `objective` sequentially. - /// For lifecycle hooks (early stopping, retries), implement the - /// [`Objective`](crate::Objective) trait and use - /// [`optimize_with`](Self::optimize_with) instead. + /// Accepts any [`Objective`](crate::Objective) implementation, including + /// plain closures (`Fn(&mut Trial) -> Result<V, E>`) thanks to the + /// blanket impl. Struct-based objectives can override + /// [`before_trial`](crate::Objective::before_trial) and + /// [`after_trial`](crate::Objective::after_trial) for early stopping. + /// + /// Runs up to `n_trials` evaluations sequentially. 
/// /// # Errors /// @@ -911,10 +898,13 @@ where /// assert!(study.n_trials() > 0); /// assert!(study.best_value().unwrap() >= 0.0); /// ``` - pub fn optimize<F, E>(&self, n_trials: usize, mut objective: F) -> crate::Result<()> + #[allow(clippy::needless_pass_by_value)] + pub fn optimize( + &self, + n_trials: usize, + objective: impl crate::objective::Objective<V>, + ) -> crate::Result<()> where - F: FnMut(&mut Trial) -> Result<V, E>, - E: ToString + 'static, V: Clone + Default, { #[cfg(feature = "tracing")] @@ -922,13 +912,44 @@ where tracing::info_span!("optimize", n_trials, direction = ?self.direction).entered(); for _ in 0..n_trials { + if let ControlFlow::Break(()) = objective.before_trial(self) { + break; + } + let mut trial = self.create_trial(); - match objective(&mut trial) { + match objective.evaluate(&mut trial) { Ok(value) => { #[cfg(feature = "tracing")] let trial_id = trial.id(); self.complete_trial(trial, value); - trace_info!(trial_id, "trial completed"); + + #[cfg(feature = "tracing")] + { + tracing::info!(trial_id, "trial completed"); + let trials = self.storage.trials_arc().read(); + if trials + .iter() + .filter(|t| t.state == TrialState::Complete) + .count() + == 1 + || trials.last().map(|t| t.id) == self.best_id(&trials) + { + tracing::info!(trial_id, "new best value found"); + } + } + + // Fire after_trial hook + let trials = self.storage.trials_arc().read(); + if let Some(completed) = trials.last() { + let completed_clone = completed.clone(); + drop(trials); + if let ControlFlow::Break(()) = + objective.after_trial(self, &completed_clone) + { + // Return early — at least one trial completed. 
+ return Ok(()); + } + } } Err(e) if is_trial_pruned(&e) => { #[cfg(feature = "tracing")] @@ -945,144 +966,6 @@ where } } - let has_complete = self - .storage - .trials_arc() - .read() - .iter() - .any(|t| t.state == TrialState::Complete); - if !has_complete { - return Err(crate::Error::NoCompletedTrials); - } - - Ok(()) - } - - /// Run optimization with an [`Objective`](crate::Objective) implementation. - /// - /// Like [`optimize`](Self::optimize), but accepts a struct implementing - /// [`Objective`](crate::Objective) for lifecycle hooks - /// ([`before_trial`](crate::Objective::before_trial), - /// [`after_trial`](crate::Objective::after_trial)) and automatic retries - /// ([`max_retries`](crate::Objective::max_retries)). - /// - /// # Errors - /// - /// Returns `Error::NoCompletedTrials` if no trials completed successfully. - /// - /// # Examples - /// - /// ``` - /// use std::ops::ControlFlow; - /// - /// use optimizer::prelude::*; - /// - /// struct QuadraticObj { - /// x: FloatParam, - /// target: f64, - /// } - /// - /// impl Objective<f64> for QuadraticObj { - /// type Error = Error; - /// fn evaluate(&self, trial: &mut Trial) -> Result<f64> { - /// let v = self.x.suggest(trial)?; - /// Ok((v - 3.0).powi(2)) - /// } - /// fn after_trial(&self, _: &Study<f64>, t: &CompletedTrial<f64>) -> ControlFlow<()> { - /// if t.value < self.target { - /// ControlFlow::Break(()) - /// } else { - /// ControlFlow::Continue(()) - /// } - /// } - /// } - /// - /// let study: Study<f64> = Study::new(Direction::Minimize); - /// let obj = QuadraticObj { - /// x: FloatParam::new(-10.0, 10.0), - /// target: 1.0, - /// }; - /// study.optimize_with(200, obj).unwrap(); - /// assert!(study.best_value().unwrap() < 1.0); - /// ``` - #[allow(clippy::needless_pass_by_value)] - pub fn optimize_with( - &self, - n_trials: usize, - objective: impl crate::objective::Objective<V>, - ) -> crate::Result<()> - where - V: Clone + Default, - { - #[cfg(feature = "tracing")] - let _span = - 
tracing::info_span!("optimize_with", n_trials, direction = ?self.direction).entered(); - - let max_retries = objective.max_retries(); - - for _ in 0..n_trials { - if let ControlFlow::Break(()) = objective.before_trial(self) { - break; - } - - let mut trial = self.create_trial(); - let mut retries = 0; - loop { - match objective.evaluate(&mut trial) { - Ok(value) => { - #[cfg(feature = "tracing")] - let trial_id = trial.id(); - self.complete_trial(trial, value); - - #[cfg(feature = "tracing")] - { - tracing::info!(trial_id, "trial completed"); - let trials = self.storage.trials_arc().read(); - if trials - .iter() - .filter(|t| t.state == TrialState::Complete) - .count() - == 1 - || trials.last().map(|t| t.id) == self.best_id(&trials) - { - tracing::info!(trial_id, "new best value found"); - } - } - - // Fire after_trial hook - let trials = self.storage.trials_arc().read(); - if let Some(completed) = trials.last() { - let completed_clone = completed.clone(); - drop(trials); - if let ControlFlow::Break(()) = - objective.after_trial(self, &completed_clone) - { - // Return early — at least one trial completed. - return Ok(()); - } - } - break; - } - Err(e) if !is_trial_pruned(&e) && retries < max_retries => { - retries += 1; - trial = self.create_trial_with_params(trial.params().clone()); - } - Err(e) => { - #[cfg(feature = "tracing")] - let trial_id = trial.id(); - if is_trial_pruned(&e) { - self.prune_trial(trial); - trace_info!(trial_id, "trial pruned"); - } else { - self.fail_trial(trial, e.to_string()); - trace_debug!(trial_id, "trial failed"); - } - break; - } - } - } - } - // Return error if no trials completed successfully let has_complete = self .storage @@ -1097,13 +980,14 @@ where Ok(()) } - /// Run async optimization with a closure. + /// Run async optimization with an objective. 
/// - /// Each evaluation is wrapped in + /// Like [`optimize`](Self::optimize), but each evaluation is wrapped in /// [`spawn_blocking`](tokio::task::spawn_blocking), keeping the async /// runtime responsive for CPU-bound objectives. Trials run sequentially. /// - /// For lifecycle hooks, use [`optimize_with_async`](Self::optimize_with_async). + /// Accepts any [`Objective`](crate::Objective) implementation, including + /// plain closures. Struct-based objectives can override lifecycle hooks. /// /// # Errors /// @@ -1135,10 +1019,10 @@ where /// # } /// ``` #[cfg(feature = "async")] - pub async fn optimize_async<F, E>(&self, n_trials: usize, objective: F) -> crate::Result<()> + pub async fn optimize_async<O>(&self, n_trials: usize, objective: O) -> crate::Result<()> where - F: Fn(&mut Trial) -> Result<V, E> + Send + Sync + 'static, - E: ToString + Send + 'static, + O: crate::objective::Objective<V> + Send + Sync + 'static, + O::Error: Send, V: Clone + Default + Send + 'static, { #[cfg(feature = "tracing")] @@ -1148,10 +1032,14 @@ where let objective = Arc::new(objective); for _ in 0..n_trials { + if let ControlFlow::Break(()) = objective.before_trial(self) { + break; + } + let obj = Arc::clone(&objective); let mut trial = self.create_trial(); let result = tokio::task::spawn_blocking(move || { - let res = obj(&mut trial); + let res = obj.evaluate(&mut trial); (trial, res) }) .await @@ -1163,6 +1051,18 @@ where let trial_id = t.id(); self.complete_trial(t, value); trace_info!(trial_id, "trial completed"); + + // Fire after_trial hook + let trials = self.storage.trials_arc().read(); + if let Some(completed) = trials.last() { + let completed_clone = completed.clone(); + drop(trials); + if let ControlFlow::Break(()) = + objective.after_trial(self, &completed_clone) + { + return Ok(()); + } + } } (t, Err(e)) if is_trial_pruned(&e) => { #[cfg(feature = "tracing")] @@ -1192,108 +1092,16 @@ where Ok(()) } - /// Run async optimization with an 
[`Objective`](crate::Objective) implementation. - /// - /// Like [`optimize_async`](Self::optimize_async), but accepts a struct - /// implementing [`Objective`](crate::Objective) for lifecycle hooks and - /// automatic retries. - /// - /// # Errors - /// - /// Returns `Error::NoCompletedTrials` if no trials completed successfully. - /// Returns `Error::TaskError` if a spawned blocking task panics. - #[cfg(feature = "async")] - pub async fn optimize_with_async<O>(&self, n_trials: usize, objective: O) -> crate::Result<()> - where - O: crate::objective::Objective<V> + Send + Sync + 'static, - O::Error: Send, - V: Clone + Default + Send + 'static, - { - #[cfg(feature = "tracing")] - let _span = - tracing::info_span!("optimize_with_async", n_trials, direction = ?self.direction) - .entered(); - - let objective = Arc::new(objective); - let max_retries = objective.max_retries(); - - for _ in 0..n_trials { - if let ControlFlow::Break(()) = objective.before_trial(self) { - break; - } - - let mut trial = self.create_trial(); - let mut retries = 0; - loop { - let obj = Arc::clone(&objective); - let result = tokio::task::spawn_blocking(move || { - let res = obj.evaluate(&mut trial); - (trial, res) - }) - .await - .map_err(|e| crate::Error::TaskError(e.to_string()))?; - - match result { - (t, Ok(value)) => { - #[cfg(feature = "tracing")] - let trial_id = t.id(); - self.complete_trial(t, value); - trace_info!(trial_id, "trial completed"); - - // Fire after_trial hook - let trials = self.storage.trials_arc().read(); - if let Some(completed) = trials.last() { - let completed_clone = completed.clone(); - drop(trials); - if let ControlFlow::Break(()) = - objective.after_trial(self, &completed_clone) - { - return Ok(()); - } - } - break; - } - (t, Err(e)) if !is_trial_pruned(&e) && retries < max_retries => { - retries += 1; - trial = self.create_trial_with_params(t.params().clone()); - } - (t, Err(e)) => { - #[cfg(feature = "tracing")] - let trial_id = t.id(); - if is_trial_pruned(&e) 
{ - self.prune_trial(t); - trace_info!(trial_id, "trial pruned"); - } else { - self.fail_trial(t, e.to_string()); - trace_debug!(trial_id, "trial failed"); - } - break; - } - } - } - } - - let has_complete = self - .storage - .trials_arc() - .read() - .iter() - .any(|t| t.state == TrialState::Complete); - if !has_complete { - return Err(crate::Error::NoCompletedTrials); - } - - Ok(()) - } - - /// Run parallel optimization with a closure. + /// Run parallel optimization with an objective. /// /// Spawns up to `concurrency` evaluations concurrently using /// [`spawn_blocking`](tokio::task::spawn_blocking). Results are /// collected via a [`JoinSet`](tokio::task::JoinSet). /// - /// For lifecycle hooks, use - /// [`optimize_with_parallel`](Self::optimize_with_parallel). + /// Accepts any [`Objective`](crate::Objective) implementation, including + /// plain closures. The [`after_trial`](crate::Objective::after_trial) + /// hook fires as each result arrives — returning `Break` stops spawning + /// new trials while in-flight tasks drain. /// /// # Errors /// @@ -1325,131 +1133,8 @@ where /// # } /// ``` #[cfg(feature = "async")] - #[allow(clippy::missing_panics_doc)] - pub async fn optimize_parallel<F, E>( - &self, - n_trials: usize, - concurrency: usize, - objective: F, - ) -> crate::Result<()> - where - F: Fn(&mut Trial) -> Result<V, E> + Send + Sync + 'static, - E: ToString + Send + 'static, - V: Clone + Default + Send + 'static, - { - use tokio::sync::Semaphore; - use tokio::task::JoinSet; - - #[cfg(feature = "tracing")] - let _span = tracing::info_span!("optimize_parallel", n_trials, concurrency, direction = ?self.direction).entered(); - - let objective = Arc::new(objective); - let semaphore = Arc::new(Semaphore::new(concurrency)); - let mut join_set: JoinSet<(Trial, Result<V, E>)> = JoinSet::new(); - let mut spawned = 0; - - while spawned < n_trials { - // If the join set is full, drain one result to free a slot. 
- while join_set.len() >= concurrency { - let result = join_set - .join_next() - .await - .expect("join_set should not be empty") - .map_err(|e| crate::Error::TaskError(e.to_string()))?; - match result { - (t, Ok(value)) => { - #[cfg(feature = "tracing")] - let trial_id = t.id(); - self.complete_trial(t, value); - trace_info!(trial_id, "trial completed"); - } - (t, Err(e)) => { - #[cfg(feature = "tracing")] - let trial_id = t.id(); - if is_trial_pruned(&e) { - self.prune_trial(t); - trace_info!(trial_id, "trial pruned"); - } else { - self.fail_trial(t, e.to_string()); - trace_debug!(trial_id, "trial failed"); - } - } - } - } - - let permit = semaphore - .clone() - .acquire_owned() - .await - .map_err(|e| crate::Error::TaskError(e.to_string()))?; - - let mut trial = self.create_trial(); - let obj = Arc::clone(&objective); - join_set.spawn(async move { - let result = tokio::task::spawn_blocking(move || { - let res = obj(&mut trial); - (trial, res) - }) - .await - .expect("spawn_blocking should not panic"); - drop(permit); - result - }); - spawned += 1; - } - - // Drain remaining in-flight tasks. - while let Some(result) = join_set.join_next().await { - let result = result.map_err(|e| crate::Error::TaskError(e.to_string()))?; - match result { - (t, Ok(value)) => { - #[cfg(feature = "tracing")] - let trial_id = t.id(); - self.complete_trial(t, value); - trace_info!(trial_id, "trial completed"); - } - (t, Err(e)) => { - #[cfg(feature = "tracing")] - let trial_id = t.id(); - if is_trial_pruned(&e) { - self.prune_trial(t); - trace_info!(trial_id, "trial pruned"); - } else { - self.fail_trial(t, e.to_string()); - trace_debug!(trial_id, "trial failed"); - } - } - } - } - - let has_complete = self - .storage - .trials_arc() - .read() - .iter() - .any(|t| t.state == TrialState::Complete); - if !has_complete { - return Err(crate::Error::NoCompletedTrials); - } - - Ok(()) - } - - /// Run parallel optimization with an [`Objective`](crate::Objective) implementation. 
- /// - /// Like [`optimize_parallel`](Self::optimize_parallel), but accepts a struct - /// implementing [`Objective`](crate::Objective) for lifecycle hooks and - /// automatic retries. The [`after_trial`](crate::Objective::after_trial) - /// hook fires as each result arrives — returning `Break` stops spawning - /// new trials while in-flight tasks drain. - /// - /// # Errors - /// - /// Returns `Error::NoCompletedTrials` if no trials completed successfully. - /// Returns `Error::TaskError` if the semaphore is closed or a spawned task panics. - #[cfg(feature = "async")] #[allow(clippy::missing_panics_doc, clippy::too_many_lines)] - pub async fn optimize_with_parallel<O>( + pub async fn optimize_parallel<O>( &self, n_trials: usize, concurrency: usize, @@ -1464,7 +1149,7 @@ where use tokio::task::JoinSet; #[cfg(feature = "tracing")] - let _span = tracing::info_span!("optimize_with_parallel", n_trials, concurrency, direction = ?self.direction).entered(); + let _span = tracing::info_span!("optimize_parallel", n_trials, concurrency, direction = ?self.direction).entered(); let objective = Arc::new(objective); let semaphore = Arc::new(Semaphore::new(concurrency)); @@ -1838,7 +1523,7 @@ where /// let x = FloatParam::new(0.0, 10.0).name("x"); /// /// study - /// .optimize(20, |trial| { + /// .optimize(20, |trial: &mut optimizer::Trial| { /// let xv = x.suggest(trial)?; /// Ok::<_, optimizer::Error>(xv * xv) /// }) @@ -1941,7 +1626,7 @@ where /// let y = FloatParam::new(0.0, 10.0).name("y"); /// /// study - /// .optimize(30, |trial| { + /// .optimize(30, |trial: &mut optimizer::Trial| { /// let xv = x.suggest(trial)?; /// let yv = y.suggest(trial)?; /// Ok::<_, optimizer::Error>(xv * xv + 0.1 * yv) diff --git a/src/visualization.rs b/src/visualization.rs index f607f25..3c9558d 100644 --- a/src/visualization.rs +++ b/src/visualization.rs @@ -26,7 +26,7 @@ //! //! let study: Study<f64> = Study::new(Direction::Minimize); //! # let x = FloatParam::new(0.0, 1.0); -//! 
# study.optimize(10, |trial| { +//! # study.optimize(10, |trial: &mut optimizer::Trial| { //! # let v = x.suggest(trial)?; //! # Ok::<_, optimizer::Error>(v * v) //! # }).unwrap(); diff --git a/tests/export_tests.rs b/tests/export_tests.rs index 9a43149..20b4a5c 100644 --- a/tests/export_tests.rs +++ b/tests/export_tests.rs @@ -18,7 +18,7 @@ fn csv_includes_all_trial_data() { let y = IntParam::new(1, 5).name("y"); study - .optimize(3, |trial| { + .optimize(3, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, optimizer::Error>(xv + yv as f64) @@ -124,7 +124,7 @@ fn csv_output_is_parseable() { let layers = IntParam::new(1, 5).name("n_layers"); study - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let l = lr.suggest(trial)?; let n = layers.suggest(trial)?; Ok::<_, optimizer::Error>(l * n as f64) @@ -153,7 +153,7 @@ fn export_csv_writes_file() { let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(42)); let x = FloatParam::new(0.0, 10.0).name("x"); study - .optimize(3, |trial| { + .optimize(3, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; Ok::<_, optimizer::Error>(xv * xv) }) @@ -179,7 +179,7 @@ fn export_json_writes_file() { let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(42)); let x = FloatParam::new(0.0, 10.0).name("x"); study - .optimize(3, |trial| { + .optimize(3, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; Ok::<_, optimizer::Error>(xv * xv) }) @@ -214,7 +214,7 @@ fn csv_includes_user_attributes() { let x = FloatParam::new(0.0, 10.0).name("x"); study - .optimize(2, |trial| { + .optimize(2, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; trial.set_user_attr("training_time_secs", 45.2); Ok::<_, optimizer::Error>(xv * xv) diff --git a/tests/fanova_tests.rs b/tests/fanova_tests.rs index 88eb3c8..cb69950 100644 --- a/tests/fanova_tests.rs +++ b/tests/fanova_tests.rs @@ -10,7 
+10,7 @@ fn fanova_dominant_parameter() { let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(42)); study - .optimize(50, |trial| { + .optimize(50, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let _yv = y.suggest(trial)?; Ok::<_, Error>(xv * xv) @@ -34,7 +34,7 @@ fn fanova_interaction() { let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(7)); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, Error>(xv * yv) @@ -62,7 +62,7 @@ fn fanova_consistent_with_correlation() { let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(99)); study - .optimize(80, |trial| { + .optimize(80, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, Error>(3.0 * xv + 0.5 * yv) diff --git a/tests/journal_tests.rs b/tests/journal_tests.rs index fab5790..4eef94e 100644 --- a/tests/journal_tests.rs +++ b/tests/journal_tests.rs @@ -119,7 +119,7 @@ fn study_with_journal_integration() { let study = Study::with_journal(Direction::Minimize, RandomSampler::with_seed(1), &path).unwrap(); study - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let val = x.suggest(trial)?; Ok::<_, optimizer::Error>(val * val) }) @@ -134,7 +134,7 @@ fn study_with_journal_integration() { // Continue optimizing study2 - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let val = x.suggest(trial)?; Ok::<_, optimizer::Error>(val * val) }) @@ -158,7 +158,7 @@ fn ids_are_unique_after_reload() { let study = Study::with_journal(Direction::Minimize, RandomSampler::with_seed(1), &path).unwrap(); study - .optimize(3, |trial| { + .optimize(3, |trial: &mut optimizer::Trial| { let _ = FloatParam::new(0.0, 1.0).suggest(trial)?; Ok::<_, optimizer::Error>(1.0) }) @@ -169,7 +169,7 @@ fn ids_are_unique_after_reload() { let study = 
Study::with_journal(Direction::Minimize, RandomSampler::with_seed(2), &path).unwrap(); study - .optimize(3, |trial| { + .optimize(3, |trial: &mut optimizer::Trial| { let _ = FloatParam::new(0.0, 1.0).suggest(trial)?; Ok::<_, optimizer::Error>(1.0) }) @@ -194,7 +194,7 @@ fn pruned_trials_are_stored() { // Complete one, prune one let x = FloatParam::new(0.0, 1.0); study - .optimize(3, |trial| { + .optimize(3, |trial: &mut optimizer::Trial| { let _ = x.suggest(trial)?; if trial.id() == 1 { Err(optimizer::TrialPruned)?; diff --git a/tests/multi_objective_tests.rs b/tests/multi_objective_tests.rs index dc30af9..c66e9df 100644 --- a/tests/multi_objective_tests.rs +++ b/tests/multi_objective_tests.rs @@ -18,7 +18,7 @@ fn test_basic_two_objective_random() { let x = FloatParam::new(0.0, 1.0); study - .optimize(30, |trial| { + .optimize(30, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; Ok::<_, optimizer::Error>(vec![xv, 1.0 - xv]) }) @@ -50,7 +50,7 @@ fn test_dimension_mismatch_error() { let study = MultiObjectiveStudy::new(vec![Direction::Minimize, Direction::Minimize]); let x = FloatParam::new(0.0, 1.0); - let result = study.optimize(1, |trial| { + let result = study.optimize(1, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; // Return wrong number of values Ok::<_, optimizer::Error>(vec![xv]) @@ -104,7 +104,7 @@ fn test_n_trials_counting() { let x = FloatParam::new(0.0, 1.0); study - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; Ok::<_, optimizer::Error>(vec![xv, 1.0 - xv]) }) @@ -124,7 +124,7 @@ fn test_three_objectives() { let y = FloatParam::new(0.0, 1.0); study - .optimize(30, |trial| { + .optimize(30, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, optimizer::Error>(vec![xv, yv, 1.0 - xv - yv]) @@ -150,7 +150,7 @@ fn test_trials_accessor() { let x = FloatParam::new(0.0, 1.0); study - .optimize(3, |trial| { + .optimize(3, |trial: &mut 
optimizer::Trial| { let xv = x.suggest(trial)?; Ok::<_, optimizer::Error>(vec![xv, 1.0 - xv]) }) @@ -178,7 +178,7 @@ fn test_nsga2_zdt1() { MultiObjectiveStudy::with_sampler(vec![Direction::Minimize, Direction::Minimize], sampler); study - .optimize(200, |trial| { + .optimize(200, |trial: &mut optimizer::Trial| { let xs: Vec<f64> = params .iter() .map(|p| p.suggest(trial)) @@ -224,7 +224,7 @@ fn test_nsga2_with_seed_reproducible() { sampler, ); study - .optimize(30, |trial| { + .optimize(30, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, optimizer::Error>(vec![xv, yv]) @@ -256,7 +256,7 @@ fn test_nsga2_builder() { let x = FloatParam::new(0.0, 1.0); study - .optimize(30, |trial| { + .optimize(30, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; Ok::<_, optimizer::Error>(vec![xv, 1.0 - xv]) }) @@ -275,7 +275,7 @@ fn test_nsga2_categorical_params() { let cat = CategoricalParam::new(vec!["a", "b", "c"]); study - .optimize(30, |trial| { + .optimize(30, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let cv = cat.suggest(trial)?; let bonus = match cv { @@ -301,7 +301,7 @@ fn test_nsga2_constraints() { let x = FloatParam::new(0.0, 1.0); study - .optimize(50, |trial| { + .optimize(50, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; // Constraint: x >= 0.3 (i.e. 
0.3 - x <= 0) trial.set_constraints(vec![0.3 - xv]); @@ -326,7 +326,7 @@ fn test_multi_objective_trial_get() { let x = FloatParam::new(0.0, 10.0).name("x"); study - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; Ok::<_, optimizer::Error>(vec![xv, 10.0 - xv]) }) @@ -345,7 +345,7 @@ fn test_multi_objective_trial_is_feasible() { let x = FloatParam::new(0.0, 1.0); study - .optimize(10, |trial| { + .optimize(10, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; trial.set_constraints(vec![0.5 - xv]); // feasible if x >= 0.5 Ok::<_, optimizer::Error>(vec![xv, 1.0 - xv]) @@ -369,7 +369,7 @@ fn test_multi_objective_trial_user_attrs() { let x = FloatParam::new(0.0, 1.0); study - .optimize(3, |trial| { + .optimize(3, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; trial.set_user_attr("iteration", 42_i64); Ok::<_, optimizer::Error>(vec![xv, 1.0 - xv]) @@ -409,7 +409,7 @@ fn test_nsga3_zdt1() { MultiObjectiveStudy::with_sampler(vec![Direction::Minimize, Direction::Minimize], sampler); study - .optimize(200, |trial| { + .optimize(200, |trial: &mut optimizer::Trial| { let xs: Vec<f64> = params .iter() .map(|p| p.suggest(trial)) @@ -458,7 +458,7 @@ fn test_nsga3_four_objectives() { let study = MultiObjectiveStudy::with_sampler(directions, sampler); study - .optimize(500, |trial| { + .optimize(500, |trial: &mut optimizer::Trial| { let xs: Vec<f64> = params .iter() .map(|p| p.suggest(trial)) @@ -506,7 +506,7 @@ fn test_nsga3_reproducible() { sampler, ); study - .optimize(30, |trial| { + .optimize(30, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, optimizer::Error>(vec![xv, yv]) @@ -539,7 +539,7 @@ fn test_nsga3_builder() { let x = FloatParam::new(0.0, 1.0); study - .optimize(30, |trial| { + .optimize(30, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; Ok::<_, optimizer::Error>(vec![xv, 1.0 - xv]) }) @@ -557,7 +557,7 @@ fn test_nsga3_constraints() 
{ let x = FloatParam::new(0.0, 1.0); study - .optimize(50, |trial| { + .optimize(50, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; trial.set_constraints(vec![0.3 - xv]); Ok::<_, optimizer::Error>(vec![xv, 1.0 - xv]) @@ -588,7 +588,7 @@ fn test_moead_zdt1_tchebycheff() { MultiObjectiveStudy::with_sampler(vec![Direction::Minimize, Direction::Minimize], sampler); study - .optimize(200, |trial| { + .optimize(200, |trial: &mut optimizer::Trial| { let xs: Vec<f64> = params .iter() .map(|p| p.suggest(trial)) @@ -635,7 +635,7 @@ fn test_moead_zdt1_weighted_sum() { MultiObjectiveStudy::with_sampler(vec![Direction::Minimize, Direction::Minimize], sampler); study - .optimize(200, |trial| { + .optimize(200, |trial: &mut optimizer::Trial| { let xs: Vec<f64> = params .iter() .map(|p| p.suggest(trial)) @@ -666,7 +666,7 @@ fn test_moead_zdt1_pbi() { MultiObjectiveStudy::with_sampler(vec![Direction::Minimize, Direction::Minimize], sampler); study - .optimize(200, |trial| { + .optimize(200, |trial: &mut optimizer::Trial| { let xs: Vec<f64> = params .iter() .map(|p| p.suggest(trial)) @@ -695,7 +695,7 @@ fn test_moead_reproducible() { sampler, ); study - .optimize(30, |trial| { + .optimize(30, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, optimizer::Error>(vec![xv, yv]) @@ -729,7 +729,7 @@ fn test_moead_builder() { let x = FloatParam::new(0.0, 1.0); study - .optimize(30, |trial| { + .optimize(30, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; Ok::<_, optimizer::Error>(vec![xv, 1.0 - xv]) }) diff --git a/tests/parameter_tests.rs b/tests/parameter_tests.rs index 37e205f..3d4f308 100644 --- a/tests/parameter_tests.rs +++ b/tests/parameter_tests.rs @@ -180,7 +180,7 @@ fn parameter_api_with_study() { let study: Study<f64> = Study::new(Direction::Minimize); study - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; let n = n_param.suggest(trial)?; let dropout 
= dropout_param.suggest(trial)?; diff --git a/tests/sampler/bohb.rs b/tests/sampler/bohb.rs index 980de7c..f9c24de 100644 --- a/tests/sampler/bohb.rs +++ b/tests/sampler/bohb.rs @@ -27,7 +27,7 @@ fn bohb_converges_on_quadratic() { let x_param = FloatParam::new(-10.0, 10.0); study - .optimize(60, |trial| { + .optimize(60, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; // Report intermediate values at budget steps 1, 3, 9 @@ -66,7 +66,7 @@ fn bohb_with_pruning() { let x_param = FloatParam::new(-5.0, 5.0); study - .optimize(40, |trial| { + .optimize(40, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; let obj = x * x; @@ -112,7 +112,7 @@ fn bohb_uses_budget_conditioned_history() { let x_param = FloatParam::new(0.0, 10.0); study - .optimize(30, |trial| { + .optimize(30, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; // Intermediate values that guide optimization toward x=2 trial.report(1, (x - 2.0).powi(2) + 1.0); diff --git a/tests/sampler/cma_es.rs b/tests/sampler/cma_es.rs index 2c3f992..711cccd 100644 --- a/tests/sampler/cma_es.rs +++ b/tests/sampler/cma_es.rs @@ -10,7 +10,7 @@ fn sphere_function() { let y = FloatParam::new(-5.0, 5.0).name("y"); study - .optimize(200, |trial| { + .optimize(200, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, Error>(xv * xv + yv * yv) @@ -34,7 +34,7 @@ fn rosenbrock_function() { let y = FloatParam::new(-5.0, 5.0).name("y"); study - .optimize(300, |trial| { + .optimize(300, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; let val = (1.0 - xv).powi(2) + 100.0 * (yv - xv * xv).powi(2); @@ -60,7 +60,7 @@ fn bounds_respected() { let y = FloatParam::new(0.0, 10.0).name("y"); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, Error>(xv + yv) @@ -84,7 +84,7 @@ fn mixed_params_float_and_categorical() { let 
cat = CategoricalParam::new(vec!["a", "b", "c"]).name("cat"); study - .optimize(50, |trial| { + .optimize(50, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let cv = cat.suggest(trial)?; let penalty = match cv { @@ -114,7 +114,7 @@ fn seeded_reproducibility() { let sampler = CmaEsSampler::with_seed(seed); let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); study - .optimize(50, |trial| { + .optimize(50, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, Error>(xv * xv + yv * yv) @@ -137,7 +137,7 @@ fn different_seeds_different_results() { let sampler = CmaEsSampler::with_seed(seed); let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); study - .optimize(20, |trial| { + .optimize(20, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, Error>(xv * xv + yv * yv) @@ -162,7 +162,7 @@ fn single_dimension() { let x = FloatParam::new(-10.0, 10.0).name("x"); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; Ok::<_, Error>((xv - 3.0).powi(2)) }) @@ -184,7 +184,7 @@ fn integer_params() { let n = IntParam::new(1, 20).name("n"); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let nv = n.suggest(trial)?; // Minimum at n = 10 Ok::<_, Error>(((nv - 10) * (nv - 10)) as f64) @@ -212,7 +212,7 @@ fn log_scale_params() { let lr = FloatParam::new(1e-5, 1.0).log_scale().name("lr"); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let lrv = lr.suggest(trial)?; // Minimum at lr = 0.01 Ok::<_, Error>((lrv.ln() - 0.01_f64.ln()).powi(2)) @@ -241,7 +241,7 @@ fn custom_population_size_and_sigma() { let y = FloatParam::new(-5.0, 5.0).name("y"); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, Error>(xv * xv + yv * yv) diff 
--git a/tests/sampler/differential_evolution.rs b/tests/sampler/differential_evolution.rs index eb7a927..4c75d5f 100644 --- a/tests/sampler/differential_evolution.rs +++ b/tests/sampler/differential_evolution.rs @@ -10,7 +10,7 @@ fn sphere_function() { let y = FloatParam::new(-5.0, 5.0).name("y"); study - .optimize(200, |trial| { + .optimize(200, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, Error>(xv * xv + yv * yv) @@ -37,7 +37,7 @@ fn rosenbrock_function() { let y = FloatParam::new(-5.0, 5.0).name("y"); study - .optimize(400, |trial| { + .optimize(400, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; let val = (1.0 - xv).powi(2) + 100.0 * (yv - xv * xv).powi(2); @@ -67,7 +67,7 @@ fn rastrigin_function() { let y = FloatParam::new(-5.12, 5.12).name("y"); study - .optimize(500, |trial| { + .optimize(500, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; let val = 20.0 @@ -95,7 +95,7 @@ fn bounds_respected() { let y = FloatParam::new(0.0, 10.0).name("y"); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, Error>(xv + yv) @@ -123,7 +123,7 @@ fn strategy_best1() { let y = FloatParam::new(-5.0, 5.0).name("y"); study - .optimize(200, |trial| { + .optimize(200, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, Error>(xv * xv + yv * yv) @@ -151,7 +151,7 @@ fn strategy_current_to_best1() { let y = FloatParam::new(-5.0, 5.0).name("y"); study - .optimize(200, |trial| { + .optimize(200, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, Error>(xv * xv + yv * yv) @@ -175,7 +175,7 @@ fn mixed_params_float_and_categorical() { let cat = CategoricalParam::new(vec!["a", "b", "c"]).name("cat"); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut 
optimizer::Trial| { let xv = x.suggest(trial)?; let cv = cat.suggest(trial)?; let penalty = match cv { @@ -204,7 +204,7 @@ fn seeded_reproducibility() { let sampler = DifferentialEvolutionSampler::with_seed(seed); let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); study - .optimize(50, |trial| { + .optimize(50, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, Error>(xv * xv + yv * yv) @@ -227,7 +227,7 @@ fn different_seeds_different_results() { let sampler = DifferentialEvolutionSampler::with_seed(seed); let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); study - .optimize(20, |trial| { + .optimize(20, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, Error>(xv * xv + yv * yv) @@ -252,7 +252,7 @@ fn single_dimension() { let x = FloatParam::new(-10.0, 10.0).name("x"); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; Ok::<_, Error>((xv - 3.0).powi(2)) }) @@ -274,7 +274,7 @@ fn integer_params() { let n = IntParam::new(1, 20).name("n"); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let nv = n.suggest(trial)?; // Minimum at n = 10 Ok::<_, Error>(((nv - 10) * (nv - 10)) as f64) @@ -302,7 +302,7 @@ fn log_scale_params() { let lr = FloatParam::new(1e-5, 1.0).log_scale().name("lr"); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let lrv = lr.suggest(trial)?; // Minimum at lr = 0.01 Ok::<_, Error>((lrv.ln() - 0.01_f64.ln()).powi(2)) @@ -332,7 +332,7 @@ fn custom_mutation_and_crossover() { let y = FloatParam::new(-5.0, 5.0).name("y"); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, Error>(xv * xv + yv * yv) diff --git a/tests/sampler/gp.rs b/tests/sampler/gp.rs index 6ddf3b5..24b051d 100644 --- 
a/tests/sampler/gp.rs +++ b/tests/sampler/gp.rs @@ -10,7 +10,7 @@ fn sphere_function() { let y = FloatParam::new(-5.0, 5.0).name("y"); study - .optimize(80, |trial| { + .optimize(80, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, Error>(xv * xv + yv * yv) @@ -34,7 +34,7 @@ fn rosenbrock_function() { let y = FloatParam::new(-5.0, 5.0).name("y"); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; let val = (1.0 - xv).powi(2) + 100.0 * (yv - xv * xv).powi(2); @@ -59,7 +59,7 @@ fn bounds_respected() { let y = FloatParam::new(0.0, 10.0).name("y"); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, Error>(xv + yv) @@ -83,7 +83,7 @@ fn mixed_params_float_and_categorical() { let cat = CategoricalParam::new(vec!["a", "b", "c"]).name("cat"); study - .optimize(50, |trial| { + .optimize(50, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let cv = cat.suggest(trial)?; let penalty = match cv { @@ -112,7 +112,7 @@ fn seeded_reproducibility() { let sampler = GpSampler::with_seed(seed); let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); study - .optimize(50, |trial| { + .optimize(50, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, Error>(xv * xv + yv * yv) @@ -135,7 +135,7 @@ fn different_seeds_different_results() { let sampler = GpSampler::with_seed(seed); let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); study - .optimize(20, |trial| { + .optimize(20, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, Error>(xv * xv + yv * yv) @@ -160,7 +160,7 @@ fn single_dimension() { let x = FloatParam::new(-10.0, 10.0).name("x"); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| 
{ let xv = x.suggest(trial)?; Ok::<_, Error>((xv - 3.0).powi(2)) }) @@ -182,7 +182,7 @@ fn integer_params() { let n = IntParam::new(1, 20).name("n"); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let nv = n.suggest(trial)?; Ok::<_, Error>(((nv - 10) * (nv - 10)) as f64) }) @@ -209,7 +209,7 @@ fn log_scale_params() { let lr = FloatParam::new(1e-5, 1.0).log_scale().name("lr"); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let lrv = lr.suggest(trial)?; Ok::<_, Error>((lrv.ln() - 0.01_f64.ln()).powi(2)) }) @@ -238,7 +238,7 @@ fn builder_configuration() { let y = FloatParam::new(-5.0, 5.0).name("y"); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, Error>(xv * xv + yv * yv) diff --git a/tests/sampler/multivariate_tpe.rs b/tests/sampler/multivariate_tpe.rs index db84523..a84ca34 100644 --- a/tests/sampler/multivariate_tpe.rs +++ b/tests/sampler/multivariate_tpe.rs @@ -55,7 +55,7 @@ fn test_multivariate_tpe_rosenbrock_finds_good_solution() { let y_param = FloatParam::new(-2.0, 4.0); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; let y = y_param.suggest(trial)?; Ok::<_, Error>(rosenbrock(x, y)) @@ -90,7 +90,7 @@ fn test_independent_tpe_rosenbrock() { let y_param = FloatParam::new(-2.0, 4.0); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; let y = y_param.suggest(trial)?; Ok::<_, Error>(rosenbrock(x, y)) @@ -133,7 +133,7 @@ fn test_multivariate_tpe_outperforms_on_correlated_problem() { let y_param = FloatParam::new(-2.0, 4.0); study - .optimize(n_trials, |trial| { + .optimize(n_trials, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; let y = y_param.suggest(trial)?; Ok::<_, Error>(rosenbrock(x, y)) @@ -156,7 +156,7 @@ fn 
test_multivariate_tpe_outperforms_on_correlated_problem() { let y_param = FloatParam::new(-2.0, 4.0); study - .optimize(n_trials, |trial| { + .optimize(n_trials, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; let y = y_param.suggest(trial)?; Ok::<_, Error>(rosenbrock(x, y)) @@ -218,7 +218,7 @@ fn test_multivariate_tpe_independent_problem() { let y_param = FloatParam::new(-5.0, 5.0); study - .optimize(50, |trial| { + .optimize(50, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; let y = y_param.suggest(trial)?; Ok::<_, Error>(sphere(x, y)) @@ -250,7 +250,7 @@ fn test_independent_tpe_independent_problem() { let y_param = FloatParam::new(-5.0, 5.0); study - .optimize(50, |trial| { + .optimize(50, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; let y = y_param.suggest(trial)?; Ok::<_, Error>(sphere(x, y)) @@ -290,7 +290,7 @@ fn test_both_samplers_work_on_independent_problem() { let y_param = FloatParam::new(-5.0, 5.0); study - .optimize(n_trials, |trial| { + .optimize(n_trials, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; let y = y_param.suggest(trial)?; Ok::<_, Error>(sphere(x, y)) @@ -312,7 +312,7 @@ fn test_both_samplers_work_on_independent_problem() { let y_param = FloatParam::new(-5.0, 5.0); study - .optimize(n_trials, |trial| { + .optimize(n_trials, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; let y = y_param.suggest(trial)?; Ok::<_, Error>(sphere(x, y)) @@ -360,7 +360,7 @@ fn test_multivariate_tpe_with_group_decomposition() { let y_param = FloatParam::new(-5.0, 5.0); study - .optimize(50, |trial| { + .optimize(50, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; let y = y_param.suggest(trial)?; Ok::<_, Error>(sphere(x, y)) @@ -396,7 +396,7 @@ fn test_multivariate_tpe_mixed_parameter_types() { let mode_param = CategoricalParam::new(vec!["a", "b", "c"]); study - .optimize(50, |trial| { + .optimize(50, |trial: &mut optimizer::Trial| { let x = 
x_param.suggest(trial)?; let n = n_param.suggest(trial)?; let mode = mode_param.suggest(trial)?; diff --git a/tests/sampler/random.rs b/tests/sampler/random.rs index 911f192..66a9264 100644 --- a/tests/sampler/random.rs +++ b/tests/sampler/random.rs @@ -1,3 +1,5 @@ +use std::cell::RefCell; + use optimizer::parameter::{CategoricalParam, FloatParam, IntParam, Parameter}; use optimizer::sampler::random::RandomSampler; use optimizer::{Direction, Error, Study}; @@ -7,18 +9,20 @@ fn test_random_sampler_uniform_float_distribution() { let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(42)); let n_samples = 1000; - let mut samples = Vec::with_capacity(n_samples); + let samples = RefCell::new(Vec::with_capacity(n_samples)); let x_param = FloatParam::new(0.0, 1.0); study - .optimize(n_samples, |trial| { + .optimize(n_samples, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; - samples.push(x); + samples.borrow_mut().push(x); Ok::<_, Error>(x) }) .unwrap(); + let mut samples = samples.into_inner(); + // All samples should be in range for &s in &samples { assert!((0.0..=1.0).contains(&s), "sample {s} out of range [0, 1]"); @@ -44,19 +48,20 @@ fn test_random_sampler_uniform_int_distribution() { let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(123)); let n_samples = 5000; - let mut counts = [0u32; 10]; // counts for values 1-10 + let counts = RefCell::new([0u32; 10]); // counts for values 1-10 let n_param = IntParam::new(1, 10); study - .optimize(n_samples, |trial| { + .optimize(n_samples, |trial: &mut optimizer::Trial| { let n = n_param.suggest(trial)?; assert!((1..=10).contains(&n), "sample {n} out of range [1, 10]"); - counts[(n - 1) as usize] += 1; + counts.borrow_mut()[(n - 1) as usize] += 1; Ok::<_, Error>(n as f64) }) .unwrap(); + let counts = counts.into_inner(); let expected = n_samples as f64 / 10.0; for (i, &count) in counts.iter().enumerate() { let diff = (count as f64 - 
expected).abs() / expected; @@ -76,20 +81,21 @@ fn test_random_sampler_uniform_categorical_distribution() { let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(456)); let n_samples = 2000; - let mut counts = [0u32; 4]; + let counts = RefCell::new([0u32; 4]); let choices = ["a", "b", "c", "d"]; let cat_param = CategoricalParam::new(choices.to_vec()); study - .optimize(n_samples, |trial| { + .optimize(n_samples, |trial: &mut optimizer::Trial| { let choice = cat_param.suggest(trial)?; let idx = choices.iter().position(|&c| c == choice).unwrap(); - counts[idx] += 1; + counts.borrow_mut()[idx] += 1; Ok::<_, Error>(idx as f64) }) .unwrap(); + let counts = counts.into_inner(); let expected = n_samples as f64 / 4.0; for (i, &count) in counts.iter().enumerate() { let diff = (count as f64 - expected).abs() / expected; @@ -111,28 +117,30 @@ fn test_random_sampler_reproducibility() { let study2: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(999)); - let mut values1 = Vec::new(); - let mut values2 = Vec::new(); + let values1 = RefCell::new(Vec::new()); + let values2 = RefCell::new(Vec::new()); let x_param1 = FloatParam::new(0.0, 100.0); let x_param2 = FloatParam::new(0.0, 100.0); study1 - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let x = x_param1.suggest(trial)?; - values1.push(x); + values1.borrow_mut().push(x); Ok::<_, Error>(x) }) .unwrap(); study2 - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let x = x_param2.suggest(trial)?; - values2.push(x); + values2.borrow_mut().push(x); Ok::<_, Error>(x) }) .unwrap(); + let values1 = values1.into_inner(); + let values2 = values2.into_inner(); for (i, (v1, v2)) in values1.iter().zip(values2.iter()).enumerate() { assert_eq!( v1, v2, diff --git a/tests/sampler/tpe.rs b/tests/sampler/tpe.rs index 8091077..3e37899 100644 --- a/tests/sampler/tpe.rs +++ b/tests/sampler/tpe.rs @@ -18,7 +18,7 @@ fn 
test_tpe_optimizes_quadratic_function() { let x_param = FloatParam::new(-10.0, 10.0); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; Ok::<_, Error>((x - 3.0).powi(2)) }) @@ -51,7 +51,7 @@ fn test_tpe_optimizes_multivariate_function() { let y_param = FloatParam::new(-5.0, 5.0); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; let y = y_param.suggest(trial)?; Ok::<_, Error>(x * x + y * y) @@ -83,7 +83,7 @@ fn test_tpe_maximization() { let x_param = FloatParam::new(-10.0, 10.0); study - .optimize(50, |trial| { + .optimize(50, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; Ok::<_, Error>(-(x - 2.0).powi(2) + 10.0) }) @@ -113,7 +113,7 @@ fn test_tpe_with_categorical_parameter() { // Optimization where the best choice depends on the categorical study - .optimize(30, |trial| { + .optimize(30, |trial: &mut optimizer::Trial| { let choice = model_param.suggest(trial)?; let x = x_param.suggest(trial)?; @@ -150,7 +150,7 @@ fn test_tpe_with_integer_parameters() { // Minimize (n - 7)^2 where n in [1, 10] study - .optimize(30, |trial| { + .optimize(30, |trial: &mut optimizer::Trial| { let n = n_param.suggest(trial)?; Ok::<_, Error>(((n - 7) as f64).powi(2)) }) @@ -177,7 +177,7 @@ fn test_tpe_with_log_scale_int() { let batch_param = IntParam::new(1, 1024).log_scale(); study - .optimize(20, |trial| { + .optimize(20, |trial: &mut optimizer::Trial| { let batch_size = batch_param.suggest(trial)?; Ok::<_, Error>(((batch_size as f64).log2() - 5.0).powi(2)) }) @@ -200,7 +200,7 @@ fn test_tpe_with_step_distributions() { let n_param = IntParam::new(0, 100).step(10); study - .optimize(20, |trial| { + .optimize(20, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; let n = n_param.suggest(trial)?; Ok::<_, Error>((x - 5.0).powi(2) + ((n - 50) as f64).powi(2)) @@ -224,7 +224,7 @@ fn test_tpe_with_fixed_kde_bandwidth() { let 
x_param = FloatParam::new(-5.0, 5.0); study - .optimize(20, |trial| { + .optimize(20, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; Ok::<_, Error>(x * x) }) @@ -252,7 +252,7 @@ fn test_tpe_split_trials_with_two_trials() { let x_param = FloatParam::new(0.0, 10.0); study - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; Ok::<_, Error>(x) }) @@ -276,7 +276,7 @@ fn test_tpe_empty_good_or_bad_values_fallback() { // First optimize with one parameter study - .optimize(10, |trial| { + .optimize(10, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; Ok::<_, Error>(x) }) @@ -284,7 +284,7 @@ fn test_tpe_empty_good_or_bad_values_fallback() { // Now try with a different parameter - TPE won't have history for "y" study - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let y = y_param.suggest(trial)?; Ok::<_, Error>(y) }) @@ -304,7 +304,7 @@ fn test_tpe_sampler_builder_default_trait() { let x_param = FloatParam::new(0.0, 1.0); study - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; Ok::<_, Error>(x) }) @@ -321,7 +321,7 @@ fn test_tpe_sampler_default_trait() { let x_param = FloatParam::new(0.0, 1.0); study - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; Ok::<_, Error>(x) }) @@ -343,7 +343,7 @@ fn test_suggest_bool_with_tpe() { let x_param = FloatParam::new(0.0, 10.0); study - .optimize(20, |trial| { + .optimize(20, |trial: &mut optimizer::Trial| { let use_large = use_large_param.suggest(trial)?; let x = x_param.suggest(trial)?; // The value depends on use_large flag @@ -369,7 +369,7 @@ fn test_params_with_tpe() { let n_param = IntParam::new(1, 10); study - .optimize(30, |trial| { + .optimize(30, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; let n = n_param.suggest(trial)?; Ok::<_, Error>(x * x + (n as f64 - 5.0).powi(2)) diff --git 
a/tests/serde_tests.rs b/tests/serde_tests.rs index b734152..1495710 100644 --- a/tests/serde_tests.rs +++ b/tests/serde_tests.rs @@ -13,7 +13,7 @@ fn round_trip_save_load() { let n = IntParam::new(1, 100).name("n"); study - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let x_val = x.suggest(trial)?; let n_val = n.suggest(trial)?; Ok::<_, optimizer::Error>(x_val * x_val + n_val as f64) @@ -51,7 +51,7 @@ fn json_output_is_human_readable() { let x = FloatParam::new(0.0, 1.0).name("x"); study - .optimize(2, |trial| { + .optimize(2, |trial: &mut optimizer::Trial| { let v = x.suggest(trial)?; Ok::<_, optimizer::Error>(v) }) @@ -153,7 +153,7 @@ fn round_trip_preserves_trial_id_counter() { let x = FloatParam::new(0.0, 1.0); study - .optimize(10, |trial| { + .optimize(10, |trial: &mut optimizer::Trial| { let v = x.suggest(trial)?; Ok::<_, optimizer::Error>(v) }) @@ -182,7 +182,7 @@ fn save_and_resume_continues_trial_ids() { // Run 10 trials study - .optimize(10, |trial| { + .optimize(10, |trial: &mut optimizer::Trial| { let v = x.suggest(trial)?; Ok::<_, optimizer::Error>(v * v) }) @@ -196,7 +196,7 @@ fn save_and_resume_continues_trial_ids() { // Continue with 5 more trials let remaining = 15 - loaded.n_trials(); loaded - .optimize(remaining, |trial| { + .optimize(remaining, |trial: &mut optimizer::Trial| { let v = x.suggest(trial)?; Ok::<_, optimizer::Error>(v * v) }) @@ -223,7 +223,7 @@ fn save_uses_atomic_write() { let save_path = dir.join("atomic.json"); study - .optimize(3, |trial| { + .optimize(3, |trial: &mut optimizer::Trial| { let v = x.suggest(trial)?; Ok::<_, optimizer::Error>(v) }) diff --git a/tests/study/builder.rs b/tests/study/builder.rs index 8b421ff..fc25a33 100644 --- a/tests/study/builder.rs +++ b/tests/study/builder.rs @@ -33,7 +33,7 @@ fn test_builder_with_sampler() { let study: Study<f64> = Study::builder().sampler(TpeSampler::new()).build(); study - .optimize(10, |trial| { + .optimize(10, |trial: &mut optimizer::Trial| { let 
val = x.suggest(trial)?; Ok::<_, Error>(val * val) }) @@ -77,7 +77,7 @@ fn test_builder_optimizes_correctly() { .build(); study - .optimize(100, |trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let val = x.suggest(trial)?; Ok::<_, Error>((val - 3.0) * (val - 3.0)) }) diff --git a/tests/study/enqueue.rs b/tests/study/enqueue.rs index c84e008..8e20f37 100644 --- a/tests/study/enqueue.rs +++ b/tests/study/enqueue.rs @@ -1,3 +1,4 @@ +use std::cell::RefCell; use std::collections::HashMap; use optimizer::parameter::{FloatParam, IntParam, ParamValue, Parameter}; @@ -73,16 +74,17 @@ fn test_enqueue_with_optimize() { study.enqueue(HashMap::from([(x.id(), ParamValue::Float(1.0))])); study.enqueue(HashMap::from([(x.id(), ParamValue::Float(2.0))])); - let mut values = Vec::new(); + let values = RefCell::new(Vec::new()); study - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let x_val = x.suggest(trial)?; - values.push(x_val); + values.borrow_mut().push(x_val); Ok::<_, Error>(x_val * x_val) }) .unwrap(); + let values = values.into_inner(); // First two trials should use enqueued values assert_eq!(values[0], 1.0); assert_eq!(values[1], 2.0); @@ -117,7 +119,7 @@ fn test_enqueue_trials_appear_in_completed_trials() { study.enqueue(HashMap::from([(x.id(), ParamValue::Float(7.0))])); study - .optimize(1, |trial| { + .optimize(1, |trial: &mut optimizer::Trial| { let x_val = x.suggest(trial)?; Ok::<_, Error>(x_val) }) @@ -178,7 +180,7 @@ fn test_enqueue_counted_in_n_trials() { study.enqueue(HashMap::from([(x.id(), ParamValue::Float(2.0))])); study - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let x_val = x.suggest(trial)?; Ok::<_, Error>(x_val) }) diff --git a/tests/study/objective.rs b/tests/study/objective.rs index 6bb1944..8d28bf1 100644 --- a/tests/study/objective.rs +++ b/tests/study/objective.rs @@ -30,7 +30,7 @@ fn test_callback_early_stopping() { let study: Study<f64> = Study::new(Direction::Minimize); study - 
.optimize_with( + .optimize( 100, EarlyStopAfter5 { x_param: FloatParam::new(0.0, 10.0), @@ -69,7 +69,7 @@ fn test_callback_early_stopping_on_first_trial() { let study: Study<f64> = Study::new(Direction::Minimize); study - .optimize_with( + .optimize( 100, StopImmediately { x_param: FloatParam::new(0.0, 10.0), @@ -109,7 +109,7 @@ fn test_callback_sampler_early_stopping() { let sampler = RandomSampler::with_seed(42); let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); study - .optimize_with( + .optimize( 100, StopAfter3 { x_param: FloatParam::new(0.0, 10.0), @@ -121,226 +121,42 @@ fn test_callback_sampler_early_stopping() { } #[test] -fn test_retries_successful_trials_not_retried() { - use std::sync::Arc; - use std::sync::atomic::{AtomicU32, Ordering}; - +fn test_objective_struct_basic() { use optimizer::Objective; - struct SuccessObj { + struct SquareObj { x_param: FloatParam, - call_count: Arc<AtomicU32>, } - impl Objective<f64> for SuccessObj { + impl Objective<f64> for SquareObj { type Error = Error; fn evaluate(&self, trial: &mut Trial) -> Result<f64, Error> { let x = self.x_param.suggest(trial)?; - self.call_count.fetch_add(1, Ordering::Relaxed); Ok(x * x) } - fn max_retries(&self) -> usize { - 3 - } } let study: Study<f64> = Study::new(Direction::Minimize); - let call_count = Arc::new(AtomicU32::new(0)); - let obj = SuccessObj { + let obj = SquareObj { x_param: FloatParam::new(0.0, 10.0), - call_count: Arc::clone(&call_count), }; - study.optimize_with(5, obj).unwrap(); + study.optimize(5, obj).unwrap(); - // All trials succeed on first try — exactly 5 calls - assert_eq!(call_count.load(Ordering::Relaxed), 5); assert_eq!(study.n_trials(), 5); } #[test] -fn test_retries_failed_trials_retried_up_to_max() { - use std::sync::Arc; - use std::sync::atomic::{AtomicU32, Ordering}; - - use optimizer::Objective; - - struct AlwaysFailObj { - x_param: FloatParam, - call_count: Arc<AtomicU32>, - } - - impl Objective<f64> for AlwaysFailObj { - type 
Error = String; - fn evaluate(&self, trial: &mut Trial) -> Result<f64, String> { - let _ = self.x_param.suggest(trial).map_err(|e| e.to_string())?; - self.call_count.fetch_add(1, Ordering::Relaxed); - Err("always fails".to_string()) - } - fn max_retries(&self) -> usize { - 3 - } - } - - let study: Study<f64> = Study::new(Direction::Minimize); - let call_count = Arc::new(AtomicU32::new(0)); - let obj = AlwaysFailObj { - x_param: FloatParam::new(0.0, 10.0), - call_count: Arc::clone(&call_count), - }; - - let result = study.optimize_with(1, obj); - - // 1 initial attempt + 3 retries = 4 total calls - assert_eq!(call_count.load(Ordering::Relaxed), 4); - // No trials completed - assert!(matches!(result, Err(Error::NoCompletedTrials))); -} - -#[test] -fn test_retries_permanently_failed_after_exhaustion() { - use optimizer::Objective; - - struct AlwaysFailObj { - x_param: FloatParam, - } - - impl Objective<f64> for AlwaysFailObj { - type Error = String; - fn evaluate(&self, trial: &mut Trial) -> Result<f64, String> { - let _ = self.x_param.suggest(trial).map_err(|e| e.to_string())?; - Err("transient error".to_string()) - } - fn max_retries(&self) -> usize { - 2 - } - } - - let study: Study<f64> = Study::new(Direction::Minimize); - let obj = AlwaysFailObj { - x_param: FloatParam::new(0.0, 10.0), - }; - - let result = study.optimize_with(3, obj); - - assert!( - matches!(result, Err(Error::NoCompletedTrials)), - "all trials should permanently fail" - ); - assert_eq!( - study.n_trials(), - 0, - "no completed trials should be recorded" - ); -} - -#[test] -fn test_retries_uses_same_parameters() { - use std::sync::atomic::{AtomicU32, Ordering}; - use std::sync::{Arc, Mutex}; - - use optimizer::Objective; - - struct RetryObj { - x_param: FloatParam, - seen_values: Arc<Mutex<Vec<f64>>>, - call_count: Arc<AtomicU32>, - } - - impl Objective<f64> for RetryObj { - type Error = String; - fn evaluate(&self, trial: &mut Trial) -> Result<f64, String> { - let x = 
self.x_param.suggest(trial).map_err(|e| e.to_string())?; - self.seen_values.lock().unwrap().push(x); - let count = self.call_count.fetch_add(1, Ordering::Relaxed) + 1; - // Fail first two attempts, succeed on third - if count < 3 { - Err("transient".to_string()) - } else { - Ok(x * x) - } - } - fn max_retries(&self) -> usize { - 2 - } - } - - let study: Study<f64> = Study::new(Direction::Minimize); - let seen_values = Arc::new(Mutex::new(Vec::new())); - let call_count = Arc::new(AtomicU32::new(0)); - let obj = RetryObj { - x_param: FloatParam::new(0.0, 10.0), - seen_values: Arc::clone(&seen_values), - call_count: Arc::clone(&call_count), - }; - - study.optimize_with(1, obj).unwrap(); - - let values = seen_values.lock().unwrap(); - assert_eq!(values.len(), 3, "should be called 3 times (1 + 2 retries)"); - // All three calls should have gotten the same parameter value - assert_eq!(values[0], values[1]); - assert_eq!(values[1], values[2]); -} - -#[test] -fn test_retries_n_trials_counts_unique_configs() { - use std::sync::Arc; - use std::sync::atomic::{AtomicU32, Ordering}; - - use optimizer::Objective; - - struct FailFirstObj { - x_param: FloatParam, - call_count: Arc<AtomicU32>, - } - - impl Objective<f64> for FailFirstObj { - type Error = String; - fn evaluate(&self, trial: &mut Trial) -> Result<f64, String> { - let x = self.x_param.suggest(trial).map_err(|e| e.to_string())?; - let count = self.call_count.fetch_add(1, Ordering::Relaxed) + 1; - // Fail first attempt of each config, succeed on retry - if count % 2 == 1 { - Err("transient".to_string()) - } else { - Ok(x * x) - } - } - fn max_retries(&self) -> usize { - 2 - } - } - - let study: Study<f64> = Study::new(Direction::Minimize); - let call_count = Arc::new(AtomicU32::new(0)); - let obj = FailFirstObj { - x_param: FloatParam::new(0.0, 10.0), - call_count: Arc::clone(&call_count), - }; - - study.optimize_with(3, obj).unwrap(); - - // 3 unique configs, each needing 2 calls = 6 total calls - 
assert_eq!(call_count.load(Ordering::Relaxed), 6); - // But only 3 completed trials - assert_eq!(study.n_trials(), 3); -} - -#[test] -fn test_retries_with_zero_max_retries_same_as_optimize() { - let study: Study<f64> = Study::new(Direction::Minimize); +fn test_closure_and_objective_produce_same_results() { let x_param = FloatParam::new(0.0, 10.0); - let call_count = std::cell::Cell::new(0u32); + let study: Study<f64> = Study::new(Direction::Minimize); study - .optimize(5, |trial| { + .optimize(5, |trial: &mut Trial| { let x = x_param.suggest(trial)?; - call_count.set(call_count.get() + 1); Ok::<_, Error>(x * x) }) .unwrap(); - assert_eq!(call_count.get(), 5); assert_eq!(study.n_trials(), 5); } diff --git a/tests/study/summary.rs b/tests/study/summary.rs index 8f688af..f868f0d 100644 --- a/tests/study/summary.rs +++ b/tests/study/summary.rs @@ -8,7 +8,7 @@ fn test_summary_with_completed_trials() { let x = FloatParam::new(0.0, 10.0).name("x"); study - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let val = x.suggest(trial)?; Ok::<_, Error>(val * val) }) @@ -61,7 +61,7 @@ fn test_display_matches_summary() { let x = FloatParam::new(0.0, 10.0).name("x"); study - .optimize(3, |trial| { + .optimize(3, |trial: &mut optimizer::Trial| { let val = x.suggest(trial)?; Ok::<_, Error>(val) }) diff --git a/tests/study/workflow.rs b/tests/study/workflow.rs index f41bff8..4b998ec 100644 --- a/tests/study/workflow.rs +++ b/tests/study/workflow.rs @@ -8,7 +8,7 @@ fn test_study_basic_workflow() { let x_param = FloatParam::new(-5.0, 5.0); study - .optimize(10, |trial| { + .optimize(10, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; Ok::<_, Error>(x * x) }) @@ -25,11 +25,11 @@ fn test_study_with_failures() { let x_param = FloatParam::new(-5.0, 5.0); // Every other trial fails - let mut counter = 0; + let counter = std::cell::Cell::new(0u32); study - .optimize(10, |trial| { - counter += 1; - if counter % 2 == 0 { + .optimize(10, |trial: &mut 
optimizer::Trial| { + counter.set(counter.get() + 1); + if counter.get().is_multiple_of(2) { return Err::<f64, &str>("intentional failure"); } let x = x_param.suggest(trial).map_err(|_| "param error")?; @@ -64,7 +64,7 @@ fn test_study_trials_iteration() { let x_param = FloatParam::new(0.0, 1.0); study - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; Ok::<_, Error>(x) }) @@ -95,7 +95,7 @@ fn test_study_set_sampler() { let x_param = FloatParam::new(-5.0, 5.0); study - .optimize(10, |trial| { + .optimize(10, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; Ok::<_, Error>(x * x) }) @@ -110,7 +110,7 @@ fn test_study_with_i32_value_type() { let x_param = IntParam::new(-10, 10); study - .optimize(10, |trial| { + .optimize(10, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; Ok::<_, Error>(x.abs() as i32) }) @@ -125,7 +125,9 @@ fn test_study_with_i32_value_type() { fn test_optimize_all_trials_fail() { let study: Study<f64> = Study::new(Direction::Minimize); - let result = study.optimize(5, |_trial| Err::<f64, &str>("always fails")); + let result = study.optimize(5, |_trial: &mut optimizer::Trial| { + Err::<f64, &str>("always fails") + }); assert!( matches!(result, Err(Error::NoCompletedTrials)), @@ -139,7 +141,7 @@ fn test_best_value() { let x_param = FloatParam::new(0.0, 10.0); study - .optimize(10, |trial| { + .optimize(10, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; Ok::<_, Error>(x) }) @@ -160,7 +162,7 @@ fn test_best_trial_with_nan_values() { let x_param = FloatParam::new(0.0, 10.0); study - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; Ok::<_, Error>(x) }) @@ -199,7 +201,7 @@ fn test_multiple_params_in_optimization() { let n_param = IntParam::new(1, 5); study - .optimize(10, |trial| { + .optimize(10, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; let n = n_param.suggest(trial)?; 
Ok::<_, Error>(x * x + n as f64) @@ -216,7 +218,7 @@ fn test_suggest_bool_in_optimization() { let x_param = FloatParam::new(0.0, 10.0); study - .optimize(10, |trial| { + .optimize(10, |trial: &mut optimizer::Trial| { let use_feature = use_feature_param.suggest(trial)?; let x = x_param.suggest(trial)?; @@ -235,7 +237,7 @@ fn test_completed_trial_get() { let n_param = IntParam::new(1, 10).name("n"); study - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; let n = n_param.suggest(trial)?; Ok::<_, Error>(x * x + n as f64) diff --git a/tests/user_attr_tests.rs b/tests/user_attr_tests.rs index 866560b..bf544d9 100644 --- a/tests/user_attr_tests.rs +++ b/tests/user_attr_tests.rs @@ -7,7 +7,7 @@ fn set_and_get_float_attr() { let x = FloatParam::new(0.0, 1.0); study - .optimize(1, |trial| { + .optimize(1, |trial: &mut optimizer::Trial| { let _ = x.suggest(trial)?; trial.set_user_attr("score", 42.5); assert_eq!(trial.user_attr("score"), Some(&AttrValue::Float(42.5))); @@ -22,7 +22,7 @@ fn set_and_get_int_attr() { let x = FloatParam::new(0.0, 1.0); study - .optimize(1, |trial| { + .optimize(1, |trial: &mut optimizer::Trial| { let _ = x.suggest(trial)?; trial.set_user_attr("epoch", 42_i64); assert_eq!(trial.user_attr("epoch"), Some(&AttrValue::Int(42))); @@ -37,7 +37,7 @@ fn set_and_get_string_attr() { let x = FloatParam::new(0.0, 1.0); study - .optimize(1, |trial| { + .optimize(1, |trial: &mut optimizer::Trial| { let _ = x.suggest(trial)?; trial.set_user_attr("model", "resnet50"); assert_eq!( @@ -55,7 +55,7 @@ fn set_and_get_bool_attr() { let x = FloatParam::new(0.0, 1.0); study - .optimize(1, |trial| { + .optimize(1, |trial: &mut optimizer::Trial| { let _ = x.suggest(trial)?; trial.set_user_attr("converged", true); assert_eq!(trial.user_attr("converged"), Some(&AttrValue::Bool(true))); @@ -70,7 +70,7 @@ fn attrs_propagate_to_completed_trial() { let x = FloatParam::new(0.0, 1.0); study - .optimize(1, |trial| { + 
.optimize(1, |trial: &mut optimizer::Trial| { let _ = x.suggest(trial)?; trial.set_user_attr("time_secs", 1.5); trial.set_user_attr("tag", "baseline"); @@ -92,7 +92,7 @@ fn overwrite_attr_replaces_value() { let x = FloatParam::new(0.0, 1.0); study - .optimize(1, |trial| { + .optimize(1, |trial: &mut optimizer::Trial| { let _ = x.suggest(trial)?; trial.set_user_attr("key", "old"); trial.set_user_attr("key", "new"); @@ -117,7 +117,7 @@ fn missing_attr_returns_none() { let x = FloatParam::new(0.0, 1.0); study - .optimize(1, |trial| { + .optimize(1, |trial: &mut optimizer::Trial| { let _ = x.suggest(trial)?; assert_eq!(trial.user_attr("nonexistent"), None); Ok::<_, optimizer::Error>(1.0) @@ -134,7 +134,7 @@ fn user_attrs_map_returns_all() { let x = FloatParam::new(0.0, 1.0); study - .optimize(1, |trial| { + .optimize(1, |trial: &mut optimizer::Trial| { let _ = x.suggest(trial)?; trial.set_user_attr("a", 1.0); trial.set_user_attr("b", true); diff --git a/tests/visualization_tests.rs b/tests/visualization_tests.rs index 19a3b3a..349ec47 100644 --- a/tests/visualization_tests.rs +++ b/tests/visualization_tests.rs @@ -9,7 +9,7 @@ fn html_report_creates_file() { let y = IntParam::new(1, 5).name("y"); study - .optimize(10, |trial| { + .optimize(10, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, optimizer::Error>(xv + yv as f64) @@ -32,7 +32,7 @@ fn html_report_contains_all_chart_sections() { let y = FloatParam::new(-5.0, 5.0).name("y"); study - .optimize(20, |trial| { + .optimize(20, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; let yv = y.suggest(trial)?; Ok::<_, optimizer::Error>(xv * xv + yv * yv) @@ -85,7 +85,7 @@ fn html_report_single_param_no_parcoords() { let x = FloatParam::new(0.0, 10.0).name("x"); study - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; Ok::<_, optimizer::Error>(xv * xv) }) @@ -109,7 +109,7 @@ fn html_report_maximize_direction() { let 
x = FloatParam::new(0.0, 10.0).name("x"); study - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; Ok::<_, optimizer::Error>(xv) }) @@ -130,7 +130,7 @@ fn export_html_convenience_method() { let x = FloatParam::new(0.0, 10.0).name("x"); study - .optimize(5, |trial| { + .optimize(5, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; Ok::<_, optimizer::Error>(xv * xv) }) @@ -156,7 +156,7 @@ fn html_report_with_intermediate_values() { let x = FloatParam::new(0.0, 10.0).name("x"); study - .optimize(10, |trial| { + .optimize(10, |trial: &mut optimizer::Trial| { let xv = x.suggest(trial)?; for step in 0..5 { let val = xv * xv + step as f64; From d20f09c66a89e8d6461eb7f6b3da305880c2a88a Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 13:19:06 +0100 Subject: [PATCH 09/48] refactor: shorten sampler type names MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - GridSearchSampler → GridSampler - DifferentialEvolutionSampler → DESampler - DifferentialEvolutionStrategy → DEStrategy - DifferentialEvolutionSamplerBuilder → DESamplerBuilder --- benches/samplers.rs | 4 +- examples/sampler_comparison.rs | 2 +- src/lib.rs | 9 +- src/sampler/de.rs | 114 ++++++++++-------------- src/sampler/grid.rs | 78 ++++++++-------- src/sampler/mod.rs | 4 +- tests/sampler/differential_evolution.rs | 35 ++++---- 7 files changed, 109 insertions(+), 137 deletions(-) diff --git a/benches/samplers.rs b/benches/samplers.rs index 79da8a6..4cb2e7f 100644 --- a/benches/samplers.rs +++ b/benches/samplers.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use optimizer::parameter::{FloatParam, Parameter}; -use optimizer::sampler::grid::GridSearchSampler; +use optimizer::sampler::grid::GridSampler; use optimizer::sampler::random::RandomSampler; use optimizer::sampler::tpe::TpeSampler; use 
optimizer::sampler::{CompletedTrial, Sampler}; @@ -97,7 +97,7 @@ fn bench_grid_sample(c: &mut Criterion) { |b, _| { b.iter(|| { // Fresh sampler each iteration since grid tracks used points - let sampler = GridSearchSampler::builder() + let sampler = GridSampler::builder() .n_points_per_param(grid_points) .build(); sampler.sample(&dist, 0, &history) diff --git a/examples/sampler_comparison.rs b/examples/sampler_comparison.rs index 9e484e2..f1c1b7a 100644 --- a/examples/sampler_comparison.rs +++ b/examples/sampler_comparison.rs @@ -65,7 +65,7 @@ fn main() { // Evaluates evenly spaced grid points. Each parameter gets its own grid that // is sampled in order, so n_points_per_param must be >= n_trials. println!("\n3. Grid sampler (exhaustive):"); - let grid = GridSearchSampler::builder() + let grid = GridSampler::builder() .n_points_per_param(n_trials) // one grid point per trial per parameter .build(); let grid_best = run_study(Study::minimize(grid), n_trials); diff --git a/src/lib.rs b/src/lib.rs index 1424b33..808012d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -54,11 +54,11 @@ //! |---------|-----------|----------|--------------| //! | [`RandomSampler`](sampler::RandomSampler) | Uniform random | Baselines, high-dimensional | — | //! | [`TpeSampler`](sampler::TpeSampler) | Tree-Parzen Estimator | General-purpose Bayesian | — | -//! | [`GridSearchSampler`](sampler::GridSearchSampler) | Exhaustive grid | Small, discrete spaces | — | +//! | [`GridSearchSampler`](sampler::GridSampler) | Exhaustive grid | Small, discrete spaces | — | //! | [`SobolSampler`](sampler::SobolSampler) | Sobol quasi-random sequence | Space-filling, low dimensions | `sobol` | //! | [`CmaEsSampler`](sampler::CmaEsSampler) | CMA-ES | Continuous, moderate dimensions | `cma-es` | //! | [`GpSampler`](sampler::GpSampler) | Gaussian Process + EI | Expensive objectives, few trials | `gp` | -//! 
| [`DifferentialEvolutionSampler`](sampler::DifferentialEvolutionSampler) | Differential Evolution | Non-convex, population-based | — | +//! | [`DESampler`](sampler::DESampler) | Differential Evolution | Non-convex, population-based | — | //! | [`BohbSampler`](sampler::BohbSampler) | BOHB (TPE + `HyperBand`) | Budget-aware early stopping | — | //! //! ## Multi-objective samplers @@ -168,9 +168,8 @@ pub mod prelude { #[cfg(feature = "sobol")] pub use crate::sampler::SobolSampler; pub use crate::sampler::{ - BohbSampler, CompletedTrial, Decomposition, DifferentialEvolutionSampler, - DifferentialEvolutionStrategy, GridSearchSampler, MoeadSampler, MotpeSampler, Nsga2Sampler, - Nsga3Sampler, RandomSampler, TpeSampler, + BohbSampler, CompletedTrial, DESampler, DEStrategy, Decomposition, GridSampler, + MoeadSampler, MotpeSampler, Nsga2Sampler, Nsga3Sampler, RandomSampler, TpeSampler, }; #[cfg(feature = "journal")] pub use crate::storage::JournalStorage; diff --git a/src/sampler/de.rs b/src/sampler/de.rs index 7641084..bd6be26 100644 --- a/src/sampler/de.rs +++ b/src/sampler/de.rs @@ -10,7 +10,7 @@ //! //! Each generation, for every population member *xᵢ*: //! 1. **Mutation** — create a mutant vector *v* from other population -//! members using the selected [`DifferentialEvolutionStrategy`]: +//! members using the selected [`DEStrategy`]: //! - `Rand1`: `v = x_r1 + F * (x_r2 - x_r3)` //! - `Best1`: `v = x_best + F * (x_r1 - x_r2)` //! - `CurrentToBest1`: `v = x_i + F * (x_best - x_i) + F * (x_r1 - x_r2)` @@ -40,20 +40,20 @@ //! | `population_size` | `max(10n, 15)` | Candidates per generation | //! | `mutation_factor` (F) | 0.8 | Differential amplification — higher = more exploration | //! | `crossover_rate` (CR) | 0.9 | Probability of taking a dimension from the mutant | -//! | `strategy` | `Rand1` | Mutation strategy (see [`DifferentialEvolutionStrategy`]) | +//! | `strategy` | `Rand1` | Mutation strategy (see [`DEStrategy`]) | //! 
| `seed` | random | RNG seed for reproducibility | //! //! # Examples //! //! ``` -//! use optimizer::sampler::de::{DifferentialEvolutionSampler, DifferentialEvolutionStrategy}; +//! use optimizer::sampler::de::{DESampler, DEStrategy}; //! use optimizer::{Direction, Study}; //! //! // Minimize with DE using the Best1 strategy for faster convergence -//! let sampler = DifferentialEvolutionSampler::builder() +//! let sampler = DESampler::builder() //! .mutation_factor(0.7) //! .crossover_rate(0.9) -//! .strategy(DifferentialEvolutionStrategy::Best1) +//! .strategy(DEStrategy::Best1) //! .population_size(20) //! .seed(42) //! .build(); @@ -74,7 +74,7 @@ use crate::sampler::{CompletedTrial, Sampler}; /// /// Controls how mutant vectors are created from the current population. #[derive(Clone, Copy, Debug, Default)] -pub enum DifferentialEvolutionStrategy { +pub enum DEStrategy { /// DE/rand/1: `v = x_r1 + F * (x_r2 - x_r3)` /// /// The most robust strategy. Uses three random population members. 
@@ -99,46 +99,36 @@ pub enum DifferentialEvolutionStrategy { /// # Examples /// /// ``` -/// use optimizer::sampler::de::DifferentialEvolutionSampler; +/// use optimizer::sampler::de::DESampler; /// use optimizer::{Direction, Study}; /// /// // Default configuration -/// let study: Study<f64> = -/// Study::with_sampler(Direction::Minimize, DifferentialEvolutionSampler::new()); +/// let study: Study<f64> = Study::with_sampler(Direction::Minimize, DESampler::new()); /// /// // With seed for reproducibility -/// let study: Study<f64> = Study::with_sampler( -/// Direction::Minimize, -/// DifferentialEvolutionSampler::with_seed(42), -/// ); +/// let study: Study<f64> = Study::with_sampler(Direction::Minimize, DESampler::with_seed(42)); /// /// // Custom configuration via builder -/// use optimizer::sampler::de::DifferentialEvolutionStrategy; -/// let sampler = DifferentialEvolutionSampler::builder() +/// use optimizer::sampler::de::DEStrategy; +/// let sampler = DESampler::builder() /// .mutation_factor(0.8) /// .crossover_rate(0.9) -/// .strategy(DifferentialEvolutionStrategy::Best1) +/// .strategy(DEStrategy::Best1) /// .population_size(30) /// .seed(42) /// .build(); /// let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); /// ``` -pub struct DifferentialEvolutionSampler { +pub struct DESampler { state: Mutex<State>, } -impl DifferentialEvolutionSampler { +impl DESampler { /// Creates a new DE sampler with default settings and a random seed. 
#[must_use] pub fn new() -> Self { Self { - state: Mutex::new(State::new( - None, - 0.8, - 0.9, - DifferentialEvolutionStrategy::Rand1, - None, - )), + state: Mutex::new(State::new(None, 0.8, 0.9, DEStrategy::Rand1, None)), } } @@ -146,30 +136,24 @@ impl DifferentialEvolutionSampler { #[must_use] pub fn with_seed(seed: u64) -> Self { Self { - state: Mutex::new(State::new( - None, - 0.8, - 0.9, - DifferentialEvolutionStrategy::Rand1, - Some(seed), - )), + state: Mutex::new(State::new(None, 0.8, 0.9, DEStrategy::Rand1, Some(seed))), } } - /// Creates a builder for configuring a `DifferentialEvolutionSampler`. + /// Creates a builder for configuring a `DESampler`. #[must_use] - pub fn builder() -> DifferentialEvolutionSamplerBuilder { - DifferentialEvolutionSamplerBuilder::new() + pub fn builder() -> DESamplerBuilder { + DESamplerBuilder::new() } } -impl Default for DifferentialEvolutionSampler { +impl Default for DESampler { fn default() -> Self { Self::new() } } -/// Builder for configuring a [`DifferentialEvolutionSampler`]. +/// Builder for configuring a [`DESampler`]. 
/// /// All options have sensible defaults: /// - `population_size`: `max(10 * n_dims, 15)` (auto-computed from parameter count) @@ -181,34 +165,32 @@ impl Default for DifferentialEvolutionSampler { /// # Examples /// /// ``` -/// use optimizer::sampler::de::{ -/// DifferentialEvolutionSamplerBuilder, DifferentialEvolutionStrategy, -/// }; +/// use optimizer::sampler::de::{DESamplerBuilder, DEStrategy}; /// -/// let sampler = DifferentialEvolutionSamplerBuilder::new() +/// let sampler = DESamplerBuilder::new() /// .mutation_factor(0.5) /// .crossover_rate(0.7) -/// .strategy(DifferentialEvolutionStrategy::CurrentToBest1) +/// .strategy(DEStrategy::CurrentToBest1) /// .population_size(20) /// .seed(42) /// .build(); /// ``` #[derive(Debug, Clone)] -pub struct DifferentialEvolutionSamplerBuilder { +pub struct DESamplerBuilder { population_size: Option<usize>, mutation_factor: f64, crossover_rate: f64, - strategy: DifferentialEvolutionStrategy, + strategy: DEStrategy, seed: Option<u64>, } -impl Default for DifferentialEvolutionSamplerBuilder { +impl Default for DESamplerBuilder { fn default() -> Self { Self::new() } } -impl DifferentialEvolutionSamplerBuilder { +impl DESamplerBuilder { /// Creates a new builder with default settings. #[must_use] pub fn new() -> Self { @@ -216,7 +198,7 @@ impl DifferentialEvolutionSamplerBuilder { population_size: None, mutation_factor: 0.8, crossover_rate: 0.9, - strategy: DifferentialEvolutionStrategy::Rand1, + strategy: DEStrategy::Rand1, seed: None, } } @@ -261,9 +243,9 @@ impl DifferentialEvolutionSamplerBuilder { /// Sets the mutation strategy. /// - /// Default: [`DifferentialEvolutionStrategy::Rand1`]. + /// Default: [`DEStrategy::Rand1`]. 
#[must_use] - pub fn strategy(mut self, strategy: DifferentialEvolutionStrategy) -> Self { + pub fn strategy(mut self, strategy: DEStrategy) -> Self { self.strategy = strategy; self } @@ -275,10 +257,10 @@ impl DifferentialEvolutionSamplerBuilder { self } - /// Builds the configured [`DifferentialEvolutionSampler`]. + /// Builds the configured [`DESampler`]. #[must_use] - pub fn build(self) -> DifferentialEvolutionSampler { - DifferentialEvolutionSampler { + pub fn build(self) -> DESampler { + DESampler { state: Mutex::new(State::new( self.population_size, self.mutation_factor, @@ -345,7 +327,7 @@ struct State { /// Crossover rate (CR). crossover_rate: f64, /// Mutation strategy. - strategy: DifferentialEvolutionStrategy, + strategy: DEStrategy, /// Current phase. phase: Phase, /// Discovered dimension info (populated during discovery). @@ -383,7 +365,7 @@ impl State { user_population_size: Option<usize>, mutation_factor: f64, crossover_rate: f64, - strategy: DifferentialEvolutionStrategy, + strategy: DEStrategy, seed: Option<u64>, ) -> Self { let rng = seed.map_or_else(fastrand::Rng::new, fastrand::Rng::with_seed); @@ -601,21 +583,21 @@ fn create_mutant_with_rng(state: &mut State, target_idx: usize, n_continuous: us let pop_size = state.population_size; match state.strategy { - DifferentialEvolutionStrategy::Rand1 => { + DEStrategy::Rand1 => { let indices = select_random_indices(&mut state.rng, pop_size, 3, &[target_idx]); let (r1, r2, r3) = (indices[0], indices[1], indices[2]); (0..n_continuous) .map(|j| pop[r1][j] + f * (pop[r2][j] - pop[r3][j])) .collect() } - DifferentialEvolutionStrategy::Best1 => { + DEStrategy::Best1 => { let indices = select_random_indices(&mut state.rng, pop_size, 2, &[target_idx]); let (r1, r2) = (indices[0], indices[1]); (0..n_continuous) .map(|j| pop[best_idx][j] + f * (pop[r1][j] - pop[r2][j])) .collect() } - DifferentialEvolutionStrategy::CurrentToBest1 => { + DEStrategy::CurrentToBest1 => { let indices = select_random_indices(&mut 
state.rng, pop_size, 2, &[target_idx]); let (r1, r2) = (indices[0], indices[1]); (0..n_continuous) @@ -693,7 +675,7 @@ fn generate_initial_population(state: &mut State) -> Vec<Candidate> { // Sampler trait implementation // --------------------------------------------------------------------------- -impl Sampler for DifferentialEvolutionSampler { +impl Sampler for DESampler { #[allow(clippy::cast_precision_loss)] fn sample( &self, @@ -968,7 +950,7 @@ mod tests { #[test] fn test_de_sampler_basic_float() { - let sampler = DifferentialEvolutionSampler::with_seed(42); + let sampler = DESampler::with_seed(42); let dist = Distribution::Float(FloatDistribution { low: -5.0, high: 5.0, @@ -1000,7 +982,7 @@ mod tests { }); let sample_values = |seed: u64| { - let sampler = DifferentialEvolutionSampler::with_seed(seed); + let sampler = DESampler::with_seed(seed); (0..20) .map(|i| sampler.sample(&dist, i, &[])) .collect::<Vec<_>>() @@ -1016,22 +998,16 @@ mod tests { #[test] fn test_de_strategy_default() { - assert!(matches!( - DifferentialEvolutionStrategy::default(), - DifferentialEvolutionStrategy::Rand1 - )); + assert!(matches!(DEStrategy::default(), DEStrategy::Rand1)); } #[test] fn test_builder_defaults() { - let builder = DifferentialEvolutionSamplerBuilder::new(); + let builder = DESamplerBuilder::new(); assert!(builder.population_size.is_none()); assert!((builder.mutation_factor - 0.8).abs() < f64::EPSILON); assert!((builder.crossover_rate - 0.9).abs() < f64::EPSILON); - assert!(matches!( - builder.strategy, - DifferentialEvolutionStrategy::Rand1 - )); + assert!(matches!(builder.strategy, DEStrategy::Rand1)); assert!(builder.seed.is_none()); } } diff --git a/src/sampler/grid.rs b/src/sampler/grid.rs index c3ce2b7..b991f35 100644 --- a/src/sampler/grid.rs +++ b/src/sampler/grid.rs @@ -1,6 +1,6 @@ //! Grid search sampler — exhaustive evaluation of discretized parameter spaces. //! -//! [`GridSearchSampler`] divides each parameter range into a fixed number of +//! 
[`GridSampler`] divides each parameter range into a fixed number of //! evenly spaced points (or uses the explicit step size when defined) and //! evaluates them sequentially. This guarantees complete coverage of the //! search grid at the cost of scaling exponentially with the number of @@ -29,9 +29,9 @@ //! //! ``` //! use optimizer::prelude::*; -//! use optimizer::sampler::grid::GridSearchSampler; +//! use optimizer::sampler::grid::GridSampler; //! -//! let sampler = GridSearchSampler::builder().n_points_per_param(5).build(); +//! let sampler = GridSampler::builder().n_points_per_param(5).build(); //! let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); //! ``` @@ -353,22 +353,22 @@ struct GridState { /// # Examples /// /// ``` -/// use optimizer::sampler::grid::GridSearchSampler; +/// use optimizer::sampler::grid::GridSampler; /// /// // Default: 10 points per parameter -/// let sampler = GridSearchSampler::new(); +/// let sampler = GridSampler::new(); /// /// // Custom grid density -/// let sampler = GridSearchSampler::builder().n_points_per_param(20).build(); +/// let sampler = GridSampler::builder().n_points_per_param(20).build(); /// ``` -pub struct GridSearchSampler { +pub struct GridSampler { /// Number of grid points per parameter (used when auto-discretizing). n_points_per_param: usize, /// Thread-safe internal state for tracking grid positions. state: Mutex<GridState>, } -impl GridSearchSampler { +impl GridSampler { /// Creates a new grid search sampler with default settings. 
/// /// Default settings: @@ -386,9 +386,9 @@ impl GridSearchSampler { /// # Examples /// /// ``` - /// use optimizer::sampler::grid::GridSearchSampler; + /// use optimizer::sampler::grid::GridSampler; /// - /// let sampler = GridSearchSampler::builder().n_points_per_param(20).build(); + /// let sampler = GridSampler::builder().n_points_per_param(20).build(); /// ``` #[must_use] pub fn builder() -> GridSearchSamplerBuilder { @@ -396,13 +396,13 @@ impl GridSearchSampler { } } -impl Default for GridSearchSampler { +impl Default for GridSampler { fn default() -> Self { Self::new() } } -impl GridSearchSampler { +impl GridSampler { /// Returns `true` if all grid points for all tracked distributions have been sampled. /// /// A distribution is considered exhausted when its `current_index` equals the number @@ -416,9 +416,9 @@ impl GridSearchSampler { /// # Examples /// /// ``` - /// use optimizer::sampler::grid::GridSearchSampler; + /// use optimizer::sampler::grid::GridSampler; /// - /// let sampler = GridSearchSampler::new(); + /// let sampler = GridSampler::new(); /// // Initially exhausted (no distributions tracked yet) /// assert!(sampler.is_exhausted()); /// ``` @@ -446,9 +446,9 @@ impl GridSearchSampler { /// # Examples /// /// ``` - /// use optimizer::sampler::grid::GridSearchSampler; + /// use optimizer::sampler::grid::GridSampler; /// - /// let sampler = GridSearchSampler::new(); + /// let sampler = GridSampler::new(); /// // No distributions tracked yet /// assert_eq!(sampler.grid_size(), 0); /// ``` @@ -459,7 +459,7 @@ impl GridSearchSampler { } } -/// Builder for configuring a [`GridSearchSampler`]. +/// Builder for configuring a [`GridSampler`]. /// /// # Examples /// @@ -511,7 +511,7 @@ impl GridSearchSamplerBuilder { self } - /// Builds the configured [`GridSearchSampler`]. + /// Builds the configured [`GridSampler`]. 
/// /// # Examples /// @@ -523,8 +523,8 @@ impl GridSearchSamplerBuilder { /// .build(); /// ``` #[must_use] - pub fn build(self) -> GridSearchSampler { - GridSearchSampler { + pub fn build(self) -> GridSampler { + GridSampler { n_points_per_param: self.n_points_per_param, state: Mutex::new(GridState::default()), } @@ -566,7 +566,7 @@ fn distribution_key(dist: &Distribution) -> String { } } -impl Sampler for GridSearchSampler { +impl Sampler for GridSampler { fn sample( &self, distribution: &Distribution, @@ -876,7 +876,7 @@ mod tests { #[test] fn test_sampler_exhausts_after_expected_samples() { - let sampler = GridSearchSampler::new(); + let sampler = GridSampler::new(); let dist = Distribution::Categorical(CategoricalDistribution { n_choices: 3 }); // Sample all 3 points @@ -890,7 +890,7 @@ mod tests { #[test] fn test_sampler_exhaustion_with_int_distribution() { - let sampler = GridSearchSampler::builder().n_points_per_param(5).build(); + let sampler = GridSampler::builder().n_points_per_param(5).build(); let dist = Distribution::Int(IntDistribution { low: 0, high: 100, @@ -910,7 +910,7 @@ mod tests { #[test] #[should_panic(expected = "GridSearchSampler: all grid points exhausted")] fn test_sampler_panics_after_exhaustion() { - let sampler = GridSearchSampler::new(); + let sampler = GridSampler::new(); let dist = Distribution::Categorical(CategoricalDistribution { n_choices: 2 }); // Sample all 2 points @@ -925,14 +925,14 @@ mod tests { #[test] fn test_is_exhausted_before_sampling() { - let sampler = GridSearchSampler::new(); + let sampler = GridSampler::new(); // Newly created sampler is vacuously exhausted (no distributions tracked) assert!(sampler.is_exhausted()); } #[test] fn test_is_exhausted_during_sampling() { - let sampler = GridSearchSampler::new(); + let sampler = GridSampler::new(); let dist = Distribution::Categorical(CategoricalDistribution { n_choices: 3 }); // After first sample, not exhausted @@ -950,7 +950,7 @@ mod tests { #[test] fn 
test_is_exhausted_multiple_distributions() { - let sampler = GridSearchSampler::new(); + let sampler = GridSampler::new(); // Use different n_choices so they have different distribution keys let dist1 = Distribution::Categorical(CategoricalDistribution { n_choices: 2 }); let dist2 = Distribution::Categorical(CategoricalDistribution { n_choices: 3 }); @@ -976,7 +976,7 @@ mod tests { #[test] fn test_builder_default() { - let sampler = GridSearchSampler::builder().build(); + let sampler = GridSampler::builder().build(); let dist = Distribution::Float(FloatDistribution { low: 0.0, high: 1.0, @@ -993,7 +993,7 @@ mod tests { #[test] fn test_builder_custom_n_points() { - let sampler = GridSearchSampler::builder().n_points_per_param(3).build(); + let sampler = GridSampler::builder().n_points_per_param(3).build(); let dist = Distribution::Float(FloatDistribution { low: 0.0, high: 1.0, @@ -1011,7 +1011,7 @@ mod tests { #[test] fn test_new_default() { - let sampler = GridSearchSampler::new(); + let sampler = GridSampler::new(); let dist = Distribution::Float(FloatDistribution { low: 0.0, high: 1.0, @@ -1031,8 +1031,8 @@ mod tests { #[test] fn test_reproducibility_same_grid_order() { // Two samplers with the same configuration should produce the same grid order - let sampler1 = GridSearchSampler::builder().n_points_per_param(5).build(); - let sampler2 = GridSearchSampler::builder().n_points_per_param(5).build(); + let sampler1 = GridSampler::builder().n_points_per_param(5).build(); + let sampler2 = GridSampler::builder().n_points_per_param(5).build(); let dist = Distribution::Float(FloatDistribution { low: 0.0, @@ -1051,8 +1051,8 @@ mod tests { #[test] fn test_reproducibility_int_distribution() { - let sampler1 = GridSearchSampler::new(); - let sampler2 = GridSearchSampler::new(); + let sampler1 = GridSampler::new(); + let sampler2 = GridSampler::new(); let dist = Distribution::Int(IntDistribution { low: 0, @@ -1073,8 +1073,8 @@ mod tests { #[test] fn 
test_reproducibility_categorical() { - let sampler1 = GridSearchSampler::new(); - let sampler2 = GridSearchSampler::new(); + let sampler1 = GridSampler::new(); + let sampler2 = GridSampler::new(); let dist = Distribution::Categorical(CategoricalDistribution { n_choices: 4 }); @@ -1091,13 +1091,13 @@ mod tests { #[test] fn test_grid_size_empty() { - let sampler = GridSearchSampler::new(); + let sampler = GridSampler::new(); assert_eq!(sampler.grid_size(), 0); } #[test] fn test_grid_size_single_distribution() { - let sampler = GridSearchSampler::builder().n_points_per_param(5).build(); + let sampler = GridSampler::builder().n_points_per_param(5).build(); let dist = Distribution::Float(FloatDistribution { low: 0.0, high: 1.0, @@ -1115,7 +1115,7 @@ mod tests { #[test] fn test_grid_size_multiple_distributions() { - let sampler = GridSearchSampler::builder().n_points_per_param(3).build(); + let sampler = GridSampler::builder().n_points_per_param(3).build(); let dist1 = Distribution::Float(FloatDistribution { low: 0.0, high: 1.0, diff --git a/src/sampler/mod.rs b/src/sampler/mod.rs index cfa5537..bea4200 100644 --- a/src/sampler/mod.rs +++ b/src/sampler/mod.rs @@ -22,10 +22,10 @@ use std::collections::HashMap; pub use bohb::BohbSampler; #[cfg(feature = "cma-es")] pub use cma_es::CmaEsSampler; -pub use de::{DifferentialEvolutionSampler, DifferentialEvolutionStrategy}; +pub use de::{DESampler, DEStrategy}; #[cfg(feature = "gp")] pub use gp::GpSampler; -pub use grid::GridSearchSampler; +pub use grid::GridSampler; pub use moead::{Decomposition, MoeadSampler}; pub use motpe::MotpeSampler; pub use nsga2::Nsga2Sampler; diff --git a/tests/sampler/differential_evolution.rs b/tests/sampler/differential_evolution.rs index 4c75d5f..1ffa5e9 100644 --- a/tests/sampler/differential_evolution.rs +++ b/tests/sampler/differential_evolution.rs @@ -1,9 +1,9 @@ use optimizer::prelude::*; -use optimizer::sampler::de::{DifferentialEvolutionSampler, DifferentialEvolutionStrategy}; +use 
optimizer::sampler::de::{DESampler, DEStrategy}; #[test] fn sphere_function() { - let sampler = DifferentialEvolutionSampler::with_seed(42); + let sampler = DESampler::with_seed(42); let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); let x = FloatParam::new(-5.0, 5.0).name("x"); @@ -27,10 +27,7 @@ fn sphere_function() { #[test] fn rosenbrock_function() { - let sampler = DifferentialEvolutionSampler::builder() - .population_size(20) - .seed(42) - .build(); + let sampler = DESampler::builder().population_size(20).seed(42).build(); let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); let x = FloatParam::new(-5.0, 5.0).name("x"); @@ -55,7 +52,7 @@ fn rosenbrock_function() { #[test] fn rastrigin_function() { - let sampler = DifferentialEvolutionSampler::builder() + let sampler = DESampler::builder() .population_size(30) .mutation_factor(0.7) .crossover_rate(0.9) @@ -88,7 +85,7 @@ fn rastrigin_function() { #[test] fn bounds_respected() { - let sampler = DifferentialEvolutionSampler::with_seed(123); + let sampler = DESampler::with_seed(123); let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); let x = FloatParam::new(-2.0, 3.0).name("x"); @@ -112,8 +109,8 @@ fn bounds_respected() { #[test] fn strategy_best1() { - let sampler = DifferentialEvolutionSampler::builder() - .strategy(DifferentialEvolutionStrategy::Best1) + let sampler = DESampler::builder() + .strategy(DEStrategy::Best1) .population_size(15) .seed(42) .build(); @@ -140,8 +137,8 @@ fn strategy_best1() { #[test] fn strategy_current_to_best1() { - let sampler = DifferentialEvolutionSampler::builder() - .strategy(DifferentialEvolutionStrategy::CurrentToBest1) + let sampler = DESampler::builder() + .strategy(DEStrategy::CurrentToBest1) .population_size(15) .seed(42) .build(); @@ -168,7 +165,7 @@ fn strategy_current_to_best1() { #[test] fn mixed_params_float_and_categorical() { - let sampler = DifferentialEvolutionSampler::with_seed(42); + let 
sampler = DESampler::with_seed(42); let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); let x = FloatParam::new(-5.0, 5.0).name("x"); @@ -201,7 +198,7 @@ fn seeded_reproducibility() { let y = FloatParam::new(-5.0, 5.0).name("y"); let run = |seed: u64| { - let sampler = DifferentialEvolutionSampler::with_seed(seed); + let sampler = DESampler::with_seed(seed); let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); study .optimize(50, |trial: &mut optimizer::Trial| { @@ -224,7 +221,7 @@ fn different_seeds_different_results() { let y = FloatParam::new(-5.0, 5.0).name("y"); let run = |seed: u64| { - let sampler = DifferentialEvolutionSampler::with_seed(seed); + let sampler = DESampler::with_seed(seed); let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); study .optimize(20, |trial: &mut optimizer::Trial| { @@ -246,7 +243,7 @@ fn different_seeds_different_results() { #[test] fn single_dimension() { - let sampler = DifferentialEvolutionSampler::with_seed(42); + let sampler = DESampler::with_seed(42); let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); let x = FloatParam::new(-10.0, 10.0).name("x"); @@ -268,7 +265,7 @@ fn single_dimension() { #[test] fn integer_params() { - let sampler = DifferentialEvolutionSampler::with_seed(42); + let sampler = DESampler::with_seed(42); let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); let n = IntParam::new(1, 20).name("n"); @@ -296,7 +293,7 @@ fn integer_params() { #[test] fn log_scale_params() { - let sampler = DifferentialEvolutionSampler::with_seed(42); + let sampler = DESampler::with_seed(42); let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); let lr = FloatParam::new(1e-5, 1.0).log_scale().name("lr"); @@ -320,7 +317,7 @@ fn log_scale_params() { #[test] fn custom_mutation_and_crossover() { - let sampler = DifferentialEvolutionSampler::builder() + let sampler = DESampler::builder() 
.mutation_factor(0.5) .crossover_rate(0.7) .population_size(10) From e68a3788bef22131de294751c475b944d731c69b Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 13:21:37 +0100 Subject: [PATCH 10/48] refactor: move RandomMultiObjectiveSampler into sampler::random --- src/multi_objective.rs | 27 +-------------------------- src/sampler/random.rs | 23 +++++++++++++++++++++++ 2 files changed, 24 insertions(+), 26 deletions(-) diff --git a/src/multi_objective.rs b/src/multi_objective.rs index a82b1ad..9b924f3 100644 --- a/src/multi_objective.rs +++ b/src/multi_objective.rs @@ -55,7 +55,7 @@ use crate::distribution::Distribution; use crate::param::ParamValue; use crate::parameter::{ParamId, Parameter}; use crate::pruner::NopPruner; -use crate::sampler::random::RandomSampler; +use crate::sampler::random::RandomMultiObjectiveSampler; use crate::sampler::{CompletedTrial, Sampler}; use crate::trial::{AttrValue, Trial}; use crate::types::{Direction, TrialState}; @@ -154,31 +154,6 @@ pub trait MultiObjectiveSampler: Send + Sync { ) -> ParamValue; } -// --------------------------------------------------------------------------- -// RandomMultiObjectiveSampler -// --------------------------------------------------------------------------- - -/// Default MO sampler that delegates to [`RandomSampler`]. 
-pub(crate) struct RandomMultiObjectiveSampler(RandomSampler); - -impl RandomMultiObjectiveSampler { - pub(crate) fn new() -> Self { - Self(RandomSampler::new()) - } -} - -impl MultiObjectiveSampler for RandomMultiObjectiveSampler { - fn sample( - &self, - distribution: &Distribution, - trial_id: u64, - _history: &[MultiObjectiveTrial], - _directions: &[Direction], - ) -> ParamValue { - self.0.sample(distribution, trial_id, &[]) - } -} - // --------------------------------------------------------------------------- // MoSamplerBridge — bridges MultiObjectiveSampler to Sampler trait // --------------------------------------------------------------------------- diff --git a/src/sampler/random.rs b/src/sampler/random.rs index 292fe4a..7dc31ef 100644 --- a/src/sampler/random.rs +++ b/src/sampler/random.rs @@ -30,9 +30,11 @@ use parking_lot::Mutex; use crate::distribution::Distribution; +use crate::multi_objective::{MultiObjectiveSampler, MultiObjectiveTrial}; use crate::param::ParamValue; use crate::rng_util; use crate::sampler::{CompletedTrial, Sampler}; +use crate::types::Direction; /// Uniform independent random sampler. /// @@ -79,6 +81,27 @@ impl RandomSampler { } } +/// Default multi-objective sampler that delegates to [`RandomSampler`]. 
+pub(crate) struct RandomMultiObjectiveSampler(RandomSampler); + +impl RandomMultiObjectiveSampler { + pub(crate) fn new() -> Self { + Self(RandomSampler::new()) + } +} + +impl MultiObjectiveSampler for RandomMultiObjectiveSampler { + fn sample( + &self, + distribution: &Distribution, + trial_id: u64, + _history: &[MultiObjectiveTrial], + _directions: &[Direction], + ) -> ParamValue { + self.0.sample(distribution, trial_id, &[]) + } +} + impl Default for RandomSampler { fn default() -> Self { Self::new() From 12f7f25baaecd7f8a26bbb02f3a6021c76b61cc8 Mon Sep 17 00:00:00 2001 From: Manuel <raimannma@outlook.de> Date: Thu, 12 Feb 2026 13:29:44 +0100 Subject: [PATCH 11/48] Update README.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9e31556..5ba0d2d 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ println!("Best x = {:.4}, f(x) = {:.4}", best.get(&x).unwrap(), best.value); - **[Multi-objective](https://docs.rs/optimizer/latest/optimizer/multi_objective/)** — Pareto front extraction with NSGA-II/III and MOEA/D - **[Async & parallel](https://docs.rs/optimizer/latest/optimizer/struct.Study.html#method.optimize_parallel)** — Concurrent trial evaluation with Tokio - **[Storage backends](https://docs.rs/optimizer/latest/optimizer/storage/)** — In-memory (default) or JSONL journal for persistence and resumption -- **[Visualization](https://docs.rs/optimizer/latest/optimizer/fn.generate_html_report.html)** — HTML reports with optimization history, parameter importance, and Pareto fronts +- **[Visualization](https://docs.rs/optimizer/latest/optimizer/fn.generate_html_report.html)** — HTML reports with optimization history and parameter importance - **[Analysis](https://docs.rs/optimizer/latest/optimizer/struct.Study.html#method.fanova)** — fANOVA and Spearman correlation for parameter importance ## Feature Flags From 
6f1c070b7d6613c855f7e3c56e55d68bfb27e92f Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 13:38:36 +0100 Subject: [PATCH 12/48] test: actually use NaN values in test_best_trial_with_nan_values --- tests/study/workflow.rs | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/study/workflow.rs b/tests/study/workflow.rs index 4b998ec..a14bf11 100644 --- a/tests/study/workflow.rs +++ b/tests/study/workflow.rs @@ -161,13 +161,20 @@ fn test_best_trial_with_nan_values() { let study: Study<f64> = Study::new(Direction::Minimize); let x_param = FloatParam::new(0.0, 10.0); - study - .optimize(5, |trial: &mut optimizer::Trial| { - let x = x_param.suggest(trial)?; - Ok::<_, Error>(x) - }) - .unwrap(); + // Mix NaN and valid objective values + let mut trial = study.create_trial(); + let _ = x_param.suggest(&mut trial).unwrap(); + study.complete_trial(trial, f64::NAN); + + let mut trial = study.create_trial(); + let _ = x_param.suggest(&mut trial).unwrap(); + study.complete_trial(trial, 5.0); + + let mut trial = study.create_trial(); + let _ = x_param.suggest(&mut trial).unwrap(); + study.complete_trial(trial, f64::NAN); + // best_trial succeeds even when some trials have NaN values let best = study.best_trial(); assert!(best.is_ok()); } From cc3937d4f7eeccd61e030cafb6171ca909d91dc4 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 13:49:49 +0100 Subject: [PATCH 13/48] Add LICENSE.md --- LICENSE.md | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 LICENSE.md diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..4c98eae --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 optimizer + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, 
including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. From 2d88526d8d5325814501f95efd08100579d9a095 Mon Sep 17 00:00:00 2001 From: Manuel <raimannma@outlook.de> Date: Thu, 12 Feb 2026 13:52:26 +0100 Subject: [PATCH 14/48] Update examples/pruning.rs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- examples/pruning.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/pruning.rs b/examples/pruning.rs index 284afe5..bb6bc26 100644 --- a/examples/pruning.rs +++ b/examples/pruning.rs @@ -55,13 +55,13 @@ fn main() -> optimizer::Result<()> { // --- Results --- let best = study.best_trial()?; println!( - "Completed {} trials ({} pruned)", + "Recorded {} trials ({} pruned)", study.n_trials(), study.n_pruned_trials() ); println!("Best trial #{}: loss = {:.6}", best.id, best.value); println!(" learning_rate = {:.6}", best.get(&lr).unwrap()); - println!(" momentum = {:.4}", best.get(&momentum).unwrap()); +println!(" momentum = {:.4}", best.get(&momentum).unwrap()); Ok(()) } From 968eb67d075e5a5ad13a8ad1f1626522d6f089a9 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 13:54:19 +0100 Subject: [PATCH 15/48] 
fix: return error instead of panicking on out-of-bounds categorical index --- src/parameter.rs | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/src/parameter.rs b/src/parameter.rs index c266034..74782f8 100644 --- a/src/parameter.rs +++ b/src/parameter.rs @@ -467,7 +467,11 @@ impl<T: Clone + Debug> Parameter for CategoricalParam<T> { fn cast_param_value(&self, param_value: &ParamValue) -> Result<T> { match param_value { - ParamValue::Categorical(index) => Ok(self.choices[*index].clone()), + ParamValue::Categorical(index) => self + .choices + .get(*index) + .cloned() + .ok_or(Error::Internal("categorical index out of bounds")), _ => Err(Error::Internal( "Categorical distribution should return Categorical value", )), @@ -708,7 +712,8 @@ impl<T: Categorical + Debug> Parameter for EnumParam<T> { fn cast_param_value(&self, param_value: &ParamValue) -> Result<T> { match param_value { - ParamValue::Categorical(index) => Ok(T::from_index(*index)), + ParamValue::Categorical(index) if *index < T::N_CHOICES => Ok(T::from_index(*index)), + ParamValue::Categorical(_) => Err(Error::Internal("categorical index out of bounds")), _ => Err(Error::Internal( "Categorical distribution should return Categorical value", )), @@ -887,6 +892,17 @@ mod tests { assert!(param.cast_param_value(&ParamValue::Float(1.0)).is_err()); } + #[test] + fn categorical_param_cast_out_of_bounds() { + let param = CategoricalParam::new(vec!["sgd", "adam", "rmsprop"]); + assert!(param.cast_param_value(&ParamValue::Categorical(3)).is_err()); + assert!( + param + .cast_param_value(&ParamValue::Categorical(usize::MAX)) + .is_err() + ); + } + #[test] fn bool_param_distribution() { let param = BoolParam::new(); @@ -955,6 +971,17 @@ mod tests { assert!(param.cast_param_value(&ParamValue::Float(1.0)).is_err()); } + #[test] + fn enum_param_cast_out_of_bounds() { + let param = EnumParam::<TestEnum>::new(); + 
assert!(param.cast_param_value(&ParamValue::Categorical(3)).is_err()); + assert!( + param + .cast_param_value(&ParamValue::Categorical(usize::MAX)) + .is_err() + ); + } + #[test] fn float_param_suggest_via_trial() { let param = FloatParam::new(0.0, 1.0); From 3723f91849bfdbfdcc1fa13b15fbb129ff51e905 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 13:56:09 +0100 Subject: [PATCH 16/48] refactor: update rustfmt configuration by commenting out unused options --- rustfmt.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rustfmt.toml b/rustfmt.toml index 197a4e5..a7a3898 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,8 +1,8 @@ # Nightly Features -imports_granularity = "Module" -group_imports = "StdExternalCrate" +#imports_granularity = "Module" +#group_imports = "StdExternalCrate" #format_strings = true -format_code_in_doc_comments = true +#format_code_in_doc_comments = true # Stable Features merge_derives = true From 8713a9363650c4a1be5e521c323ae9ace4eacfe2 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 13:59:14 +0100 Subject: [PATCH 17/48] refactor: adjust indentation for better readability in pruning.rs --- examples/pruning.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pruning.rs b/examples/pruning.rs index bb6bc26..1b75513 100644 --- a/examples/pruning.rs +++ b/examples/pruning.rs @@ -61,7 +61,7 @@ fn main() -> optimizer::Result<()> { ); println!("Best trial #{}: loss = {:.6}", best.id, best.value); println!(" learning_rate = {:.6}", best.get(&lr).unwrap()); -println!(" momentum = {:.4}", best.get(&momentum).unwrap()); + println!(" momentum = {:.4}", best.get(&momentum).unwrap()); Ok(()) } From 04d822479ec9d8ccc28bbf7bdb8afb4e8708d237 Mon Sep 17 00:00:00 2001 From: Manuel <raimannma@outlook.de> Date: Thu, 12 Feb 2026 14:12:52 +0100 Subject: [PATCH 18/48] Update Cargo.toml Co-authored-by: Copilot 
<175728472+Copilot@users.noreply.github.com> --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index c746972..0d0d453 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,7 +7,7 @@ version = "0.9.1" edition = "2024" rust-version = "1.89" license = "MIT" -authors = ["Manuel Raimann <raimannma@outlook.de"] +authors = ["Manuel Raimann <raimannma@outlook.de>"] description = "Bayesian and population-based optimization library with an Optuna-like API for hyperparameter tuning and black-box optimization" repository = "https://github.com/raimannma/rust-optimizer" documentation = "https://docs.rs/optimizer" From c561caaa342ee236ae3b790c19a3e158b1d2dfd7 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 14:02:24 +0100 Subject: [PATCH 19/48] fix: validate NaN/Inf in deserialized trials during journal loading - Add `CompletedTrial::validate()` checking all f64 fields are finite - Call validate after deserializing each trial in journal storage - Make `distribution` and `param` modules public (types already in public fields) --- src/lib.rs | 4 +- src/sampler/mod.rs | 68 +++++++++++++++++++++++++++ src/storage/journal.rs | 1 + tests/journal_tests.rs | 103 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 174 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 808012d..a7c507e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -107,14 +107,14 @@ macro_rules! 
trace_debug { ($($arg:tt)*) => {}; } -mod distribution; +pub mod distribution; mod error; mod fanova; mod importance; mod kde; pub mod multi_objective; pub mod objective; -mod param; +pub mod param; pub mod parameter; pub mod pareto; pub mod pruner; diff --git a/src/sampler/mod.rs b/src/sampler/mod.rs index bea4200..4e20f4e 100644 --- a/src/sampler/mod.rs +++ b/src/sampler/mod.rs @@ -175,6 +175,74 @@ impl<V> CompletedTrial<V> { pub fn user_attrs(&self) -> &HashMap<String, AttrValue> { &self.user_attrs } + + /// Validates that all floating-point fields are finite (not NaN or + /// Infinity). + /// + /// Checks distribution bounds, parameter values, constraints, and + /// intermediate values. Returns a description of the first invalid + /// field found, or `Ok(())` if everything is valid. + /// + /// # Errors + /// + /// Returns a `String` describing the first non-finite value found. + pub fn validate(&self) -> core::result::Result<(), String> { + for (id, dist) in &self.distributions { + if let Distribution::Float(fd) = dist { + if !fd.low.is_finite() { + return Err(format!( + "trial {}: float distribution for param {id} has non-finite low bound ({})", + self.id, fd.low + )); + } + if !fd.high.is_finite() { + return Err(format!( + "trial {}: float distribution for param {id} has non-finite high bound ({})", + self.id, fd.high + )); + } + if let Some(step) = fd.step + && !step.is_finite() + { + return Err(format!( + "trial {}: float distribution for param {id} has non-finite step ({step})", + self.id + )); + } + } + } + + for (id, pv) in &self.params { + if let ParamValue::Float(v) = pv + && !v.is_finite() + { + return Err(format!( + "trial {}: param {id} has non-finite float value ({v})", + self.id + )); + } + } + + for (i, &c) in self.constraints.iter().enumerate() { + if !c.is_finite() { + return Err(format!( + "trial {}: constraint[{i}] is non-finite ({c})", + self.id + )); + } + } + + for &(step, v) in &self.intermediate_values { + if !v.is_finite() { + return 
Err(format!( + "trial {}: intermediate value at step {step} is non-finite ({v})", + self.id + )); + } + } + + Ok(()) + } } /// A pending (running) trial with its parameters and distributions, but no objective value yet. diff --git a/src/storage/journal.rs b/src/storage/journal.rs index 1ee477b..55ad087 100644 --- a/src/storage/journal.rs +++ b/src/storage/journal.rs @@ -244,6 +244,7 @@ fn load_trials_from_file<V: DeserializeOwned>( } let trial: CompletedTrial<V> = serde_json::from_str(line).map_err(|e| crate::Error::Storage(e.to_string()))?; + trial.validate().map_err(crate::Error::Storage)?; trials.push(trial); } diff --git a/tests/journal_tests.rs b/tests/journal_tests.rs index 4eef94e..6d5f473 100644 --- a/tests/journal_tests.rs +++ b/tests/journal_tests.rs @@ -214,3 +214,106 @@ fn pruned_trials_are_stored() { std::fs::remove_file(&path).ok(); } + +#[test] +fn rejects_non_finite_values_in_journal() { + // serde_json rejects 1e999 ("number out of range"), so non-finite + // floats cannot sneak in through standard JSON. Verify the overall + // loading path catches the error regardless of which layer rejects it. 
+ let path = temp_path(); + std::fs::write( + &path, + r#"{"id":0,"params":{},"distributions":{"0":{"Float":{"low":0.0,"high":1e999,"log_scale":false,"step":null}}},"param_labels":{},"value":1.0,"intermediate_values":[],"state":"Complete","user_attrs":{},"constraints":[]}"#, + ) + .unwrap(); + + assert!(JournalStorage::<f64>::open(&path).is_err()); + std::fs::remove_file(&path).ok(); +} + +#[test] +fn validate_rejects_non_finite_distribution_bound() { + use optimizer::distribution::{Distribution, FloatDistribution}; + + let pid = FloatParam::new(0.0, 1.0).id(); + let mut trial = sample_trial(0, 1.0); + trial.distributions.insert( + pid, + Distribution::Float(FloatDistribution { + low: 0.0, + high: f64::INFINITY, + log_scale: false, + step: None, + }), + ); + let err = trial.validate().unwrap_err(); + assert!(err.contains("non-finite"), "unexpected: {err}"); +} + +#[test] +fn validate_rejects_nan_constraint() { + let mut trial = sample_trial(0, 1.0); + trial.constraints.push(f64::NAN); + let err = trial.validate().unwrap_err(); + assert!(err.contains("non-finite"), "unexpected: {err}"); +} + +#[test] +fn validate_rejects_non_finite_param_value() { + use optimizer::param::ParamValue; + + let pid = FloatParam::new(0.0, 1.0).id(); + let mut trial = sample_trial(0, 1.0); + trial + .params + .insert(pid, ParamValue::Float(f64::NEG_INFINITY)); + let err = trial.validate().unwrap_err(); + assert!(err.contains("non-finite"), "unexpected: {err}"); +} + +#[test] +fn validate_rejects_nan_intermediate_value() { + let mut trial = sample_trial(0, 1.0); + trial.intermediate_values.push((0, f64::NAN)); + let err = trial.validate().unwrap_err(); + assert!(err.contains("non-finite"), "unexpected: {err}"); +} + +#[test] +fn validate_accepts_valid_trial() { + use optimizer::distribution::{Distribution, FloatDistribution}; + use optimizer::param::ParamValue; + + let pid = FloatParam::new(0.0, 1.0).id(); + let mut trial = sample_trial(0, 1.0); + trial.params.insert(pid, 
ParamValue::Float(0.5)); + trial.distributions.insert( + pid, + Distribution::Float(FloatDistribution { + low: 0.0, + high: 1.0, + log_scale: false, + step: None, + }), + ); + trial.constraints.push(-1.0); + trial.intermediate_values.push((0, 0.5)); + assert!(trial.validate().is_ok()); +} + +#[test] +fn accepts_valid_journal_with_distributions() { + let path = temp_path(); + std::fs::write( + &path, + r#"{"id":0,"params":{"0":{"Float":0.5}},"distributions":{"0":{"Float":{"low":0.0,"high":1.0,"log_scale":false,"step":null}}},"param_labels":{},"value":0.25,"intermediate_values":[],"state":"Complete","user_attrs":{},"constraints":[-1.0]}"#, + ) + .unwrap(); + + let storage = JournalStorage::<f64>::open(&path).unwrap(); + let loaded = storage.trials_arc().read().clone(); + assert_eq!(loaded.len(), 1); + assert_eq!(loaded[0].value, 0.25); + + std::fs::remove_file(&path).ok(); +} From 271748700f127bdd63e51f8e9da521de4395149a Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 14:13:34 +0100 Subject: [PATCH 20/48] perf(tpe): use quickselect instead of full sort in split_trials --- src/sampler/tpe/sampler.rs | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/src/sampler/tpe/sampler.rs b/src/sampler/tpe/sampler.rs index 16ff43b..47d2d09 100644 --- a/src/sampler/tpe/sampler.rs +++ b/src/sampler/tpe/sampler.rs @@ -288,15 +288,6 @@ impl TpeSampler { return (vec![], vec![]); } - // Sort trials by value (ascending for minimization) - let mut sorted_indices: Vec<usize> = (0..history.len()).collect(); - sorted_indices.sort_by(|&a, &b| { - history[a] - .value - .partial_cmp(&history[b].value) - .unwrap_or(core::cmp::Ordering::Equal) - }); - // Compute gamma using the strategy and clamp to valid range let gamma = self .gamma_strategy @@ -309,14 +300,20 @@ impl TpeSampler { .max(1) .min(history.len() - 1); - let good: Vec<_> = sorted_indices[..n_good] - .iter() - .map(|&i| &history[i]) - 
.collect(); - let bad: Vec<_> = sorted_indices[n_good..] - .iter() - .map(|&i| &history[i]) - .collect(); + // Use quickselect (O(n)) to partition indices instead of full sort (O(n log n)). + // We only need to know which trials are in the top gamma-quantile, not their order. + let mut indices: Vec<usize> = (0..history.len()).collect(); + if n_good > 0 { + indices.select_nth_unstable_by(n_good - 1, |&a, &b| { + history[a] + .value + .partial_cmp(&history[b].value) + .unwrap_or(core::cmp::Ordering::Equal) + }); + } + + let good: Vec<_> = indices[..n_good].iter().map(|&i| &history[i]).collect(); + let bad: Vec<_> = indices[n_good..].iter().map(|&i| &history[i]).collect(); (good, bad) } From 750bbad0d8a7dbfa9343c3f3646827b2a1f99a5b Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 14:19:07 +0100 Subject: [PATCH 21/48] test: assert pruned count unconditionally in summary test Remove the conditional guard on pruned trial assertions since the test explicitly prunes 2 trials. Assert the expected count first, then check summary content unconditionally to catch regressions. 
--- tests/study/summary.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/study/summary.rs b/tests/study/summary.rs index f868f0d..2195e38 100644 --- a/tests/study/summary.rs +++ b/tests/study/summary.rs @@ -47,12 +47,10 @@ fn test_summary_with_pruned_trials() { study.prune_trial(trial); } + assert_eq!(study.n_pruned_trials(), 2); let summary = study.summary(); - // Should show breakdown when there are pruned trials - if study.n_pruned_trials() > 0 { - assert!(summary.contains("complete")); - assert!(summary.contains("pruned")); - } + assert!(summary.contains("complete")); + assert!(summary.contains("pruned")); } #[test] From baebbae64ce881571f782d45959d13733be49973 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 14:27:00 +0100 Subject: [PATCH 22/48] docs: add pre-commit guidelines for formatting and testing --- .claude/commands/commit-push-pr.md | 5 +++++ .claude/commands/commit-push.md | 5 +++++ .claude/commands/commit.md | 5 +++++ 3 files changed, 15 insertions(+) diff --git a/.claude/commands/commit-push-pr.md b/.claude/commands/commit-push-pr.md index b29a7df..f757343 100644 --- a/.claude/commands/commit-push-pr.md +++ b/.claude/commands/commit-push-pr.md @@ -94,3 +94,8 @@ gh pr create --title "<title>" --base <target-branch> --body "$(cat <<'EOF' EOF )" ``` + +## Before you commit + +1. Run `cargo +nightly fmt --all && cargo +nightly clippy --all-features --all-targets --fix --allow-dirty` to auto-format and fix lint issues. +2. Run all tests and doctests to ensure nothing is broken. diff --git a/.claude/commands/commit-push.md b/.claude/commands/commit-push.md index a3f38c7..16907a8 100644 --- a/.claude/commands/commit-push.md +++ b/.claude/commands/commit-push.md @@ -75,3 +75,8 @@ git commit -m "$(cat <<'EOF' EOF )" ``` + +## Before you commit + +1. 
Run `cargo +nightly fmt --all && cargo +nightly clippy --all-features --all-targets --fix --allow-dirty` to auto-format and fix lint issues. +2. Run all tests and doctests to ensure nothing is broken. diff --git a/.claude/commands/commit.md b/.claude/commands/commit.md index 50d3362..d0f4a28 100644 --- a/.claude/commands/commit.md +++ b/.claude/commands/commit.md @@ -74,3 +74,8 @@ git commit -m "$(cat <<'EOF' EOF )" ``` + +## Before you commit + +1. Run `cargo +nightly fmt --all && cargo +nightly clippy --all-features --all-targets --fix --allow-dirty` to auto-format and fix lint issues. +2. Run all tests and doctests to ensure nothing is broken. From 74cf0643eb2d7b8d44dbf3c5b17f38beec074203 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 14:34:07 +0100 Subject: [PATCH 23/48] perf: replace RNG mutex with per-call seed derivation in stateless samplers - Replace Mutex<fastrand::Rng> with a stored seed + MurmurHash3 mixer in RandomSampler, TpeSampler, and MotpeSampler so parallel workers no longer serialize on a shared lock - Add mix_seed() and distribution_fingerprint() to rng_util for deterministic per-call RNG derivation from (seed, trial_id, distribution) - Include an AtomicU64 call counter to disambiguate parameters that share the same distribution within a trial --- src/rng_util.rs | 53 ++++++++++++++++++++++++++++++++++++ src/sampler/motpe.rs | 35 +++++++++++++----------- src/sampler/random.rs | 55 ++++++++++++++++++++++---------------- src/sampler/tpe/sampler.rs | 41 ++++++++++++++-------------- tests/sampler/tpe.rs | 4 +-- 5 files changed, 127 insertions(+), 61 deletions(-) diff --git a/src/rng_util.rs b/src/rng_util.rs index ee273e1..8833f89 100644 --- a/src/rng_util.rs +++ b/src/rng_util.rs @@ -1,5 +1,58 @@ +use crate::distribution::Distribution; + /// Generate a random `f64` in the range `[low, high)`. 
#[inline] pub(crate) fn f64_range(rng: &mut fastrand::Rng, low: f64, high: f64) -> f64 { low + rng.f64() * (high - low) } + +/// Combine a base seed, trial id, and distribution fingerprint into a +/// deterministic per-call seed using `MurmurHash3`'s 64-bit finalizer. +#[inline] +pub(crate) fn mix_seed(base: u64, trial_id: u64, dist_fingerprint: u64) -> u64 { + let mut h = base + .wrapping_mul(0xff51_afd7_ed55_8ccd) + .wrapping_add(trial_id) + .wrapping_mul(0xc4ce_b9fe_1a85_ec53) + .wrapping_add(dist_fingerprint); + h ^= h >> 33; + h = h.wrapping_mul(0xff51_afd7_ed55_8ccd); + h ^= h >> 33; + h = h.wrapping_mul(0xc4ce_b9fe_1a85_ec53); + h ^= h >> 33; + h +} + +/// Stable `u64` fingerprint for a [`Distribution`], using variant tags and +/// `f64::to_bits()` for float fields so that distinct distributions within +/// the same trial produce different RNG streams. +#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] +pub(crate) fn distribution_fingerprint(distribution: &Distribution) -> u64 { + match distribution { + Distribution::Float(d) => { + let mut h: u64 = 1; + h = h.wrapping_mul(31).wrapping_add(d.low.to_bits()); + h = h.wrapping_mul(31).wrapping_add(d.high.to_bits()); + h = h.wrapping_mul(31).wrapping_add(u64::from(d.log_scale)); + if let Some(step) = d.step { + h = h.wrapping_mul(31).wrapping_add(step.to_bits()); + } + h + } + Distribution::Int(d) => { + let mut h: u64 = 2; + h = h.wrapping_mul(31).wrapping_add(d.low as u64); + h = h.wrapping_mul(31).wrapping_add(d.high as u64); + h = h.wrapping_mul(31).wrapping_add(u64::from(d.log_scale)); + if let Some(step) = d.step { + h = h.wrapping_mul(31).wrapping_add(step as u64); + } + h + } + Distribution::Categorical(d) => { + let mut h: u64 = 3; + h = h.wrapping_mul(31).wrapping_add(d.n_choices as u64); + h + } + } +} diff --git a/src/sampler/motpe.rs b/src/sampler/motpe.rs index abffc78..cdcf559 100644 --- a/src/sampler/motpe.rs +++ b/src/sampler/motpe.rs @@ -58,7 +58,7 @@ //! 
assert!(!front.is_empty()); //! ``` -use parking_lot::Mutex; +use core::sync::atomic::{AtomicU64, Ordering}; use crate::distribution::Distribution; use crate::kde::KernelDensityEstimator; @@ -119,8 +119,10 @@ pub struct MotpeSampler { n_ei_candidates: usize, /// Optional fixed bandwidth for KDE. If None, uses Scott's rule. kde_bandwidth: Option<f64>, - /// Thread-safe RNG for sampling. - rng: Mutex<fastrand::Rng>, + /// Base seed for deterministic per-call RNG derivation (no mutex needed). + seed: u64, + /// Monotonic counter to disambiguate calls with identical (`trial_id`, distribution). + call_seq: AtomicU64, } impl MotpeSampler { @@ -136,7 +138,8 @@ impl MotpeSampler { n_startup_trials: 11, n_ei_candidates: 24, kde_bandwidth: None, - rng: Mutex::new(fastrand::Rng::new()), + seed: fastrand::u64(..), + call_seq: AtomicU64::new(0), } } @@ -147,7 +150,8 @@ impl MotpeSampler { n_startup_trials: 11, n_ei_candidates: 24, kde_bandwidth: None, - rng: Mutex::new(fastrand::Rng::with_seed(seed)), + seed, + call_seq: AtomicU64::new(0), } } @@ -423,11 +427,16 @@ impl MultiObjectiveSampler for MotpeSampler { fn sample( &self, distribution: &Distribution, - _trial_id: u64, + trial_id: u64, history: &[MultiObjectiveTrial], directions: &[Direction], ) -> ParamValue { - let mut rng = self.rng.lock(); + let seq = self.call_seq.fetch_add(1, Ordering::Relaxed); + let mut rng = fastrand::Rng::with_seed(rng_util::mix_seed( + self.seed, + trial_id, + rng_util::distribution_fingerprint(distribution).wrapping_add(seq), + )); // Fall back to random sampling during startup phase let n_complete = history @@ -628,16 +637,12 @@ impl MotpeSamplerBuilder { /// Builds the configured [`MotpeSampler`]. 
#[must_use] pub fn build(self) -> MotpeSampler { - let rng = match self.seed { - Some(s) => fastrand::Rng::with_seed(s), - None => fastrand::Rng::new(), - }; - MotpeSampler { n_startup_trials: self.n_startup_trials, n_ei_candidates: self.n_ei_candidates, kde_bandwidth: self.kde_bandwidth, - rng: Mutex::new(rng), + seed: self.seed.unwrap_or_else(|| fastrand::u64(..)), + call_seq: AtomicU64::new(0), } } } @@ -699,8 +704,8 @@ mod tests { // With no history, should use random sampling let history: Vec<MultiObjectiveTrial> = vec![]; - for _ in 0..50 { - let value = sampler.sample(&dist, 0, &history, &directions); + for i in 0..50 { + let value = sampler.sample(&dist, i, &history, &directions); if let ParamValue::Float(v) = value { assert!((0.0..=1.0).contains(&v)); } else { diff --git a/src/sampler/random.rs b/src/sampler/random.rs index 7dc31ef..6c750b3 100644 --- a/src/sampler/random.rs +++ b/src/sampler/random.rs @@ -27,7 +27,7 @@ //! let study: Study<f64> = Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(42)); //! ``` -use parking_lot::Mutex; +use core::sync::atomic::{AtomicU64, Ordering}; use crate::distribution::Distribution; use crate::multi_objective::{MultiObjectiveSampler, MultiObjectiveTrial}; @@ -58,7 +58,9 @@ use crate::types::Direction; /// let sampler = RandomSampler::with_seed(42); /// ``` pub struct RandomSampler { - rng: Mutex<fastrand::Rng>, + seed: u64, + /// Monotonic counter to disambiguate calls with identical (`trial_id`, distribution). 
+ call_seq: AtomicU64, } impl RandomSampler { @@ -66,7 +68,8 @@ impl RandomSampler { #[must_use] pub fn new() -> Self { Self { - rng: Mutex::new(fastrand::Rng::new()), + seed: fastrand::u64(..), + call_seq: AtomicU64::new(0), } } @@ -76,7 +79,8 @@ impl RandomSampler { #[must_use] pub fn with_seed(seed: u64) -> Self { Self { - rng: Mutex::new(fastrand::Rng::with_seed(seed)), + seed, + call_seq: AtomicU64::new(0), } } } @@ -113,10 +117,15 @@ impl Sampler for RandomSampler { fn sample( &self, distribution: &Distribution, - _trial_id: u64, + trial_id: u64, _history: &[CompletedTrial], ) -> ParamValue { - let mut rng = self.rng.lock(); + let seq = self.call_seq.fetch_add(1, Ordering::Relaxed); + let mut rng = fastrand::Rng::with_seed(rng_util::mix_seed( + self.seed, + trial_id, + rng_util::distribution_fingerprint(distribution).wrapping_add(seq), + )); match distribution { Distribution::Float(d) => { @@ -181,8 +190,8 @@ mod tests { step: None, }); - for _ in 0..100 { - let value = sampler.sample(&dist, 0, &[]); + for i in 0..100 { + let value = sampler.sample(&dist, i, &[]); if let ParamValue::Float(v) = value { assert!((0.0..=1.0).contains(&v)); } else { @@ -201,8 +210,8 @@ mod tests { step: None, }); - for _ in 0..100 { - let value = sampler.sample(&dist, 0, &[]); + for i in 0..100 { + let value = sampler.sample(&dist, i, &[]); if let ParamValue::Float(v) = value { assert!((1e-5..=1.0).contains(&v)); } else { @@ -221,8 +230,8 @@ mod tests { step: Some(0.25), }); - for _ in 0..100 { - let value = sampler.sample(&dist, 0, &[]); + for i in 0..100 { + let value = sampler.sample(&dist, i, &[]); if let ParamValue::Float(v) = value { assert!((0.0..=1.0).contains(&v)); // Check it's on the step grid @@ -245,8 +254,8 @@ mod tests { step: None, }); - for _ in 0..100 { - let value = sampler.sample(&dist, 0, &[]); + for i in 0..100 { + let value = sampler.sample(&dist, i, &[]); if let ParamValue::Int(v) = value { assert!((0..=10).contains(&v)); } else { @@ -265,8 +274,8 @@ mod 
tests { step: None, }); - for _ in 0..100 { - let value = sampler.sample(&dist, 0, &[]); + for i in 0..100 { + let value = sampler.sample(&dist, i, &[]); if let ParamValue::Int(v) = value { assert!((1..=1000).contains(&v)); } else { @@ -285,8 +294,8 @@ mod tests { step: Some(2), }); - for _ in 0..100 { - let value = sampler.sample(&dist, 0, &[]); + for i in 0..100 { + let value = sampler.sample(&dist, i, &[]); if let ParamValue::Int(v) = value { assert!((0..=10).contains(&v)); // Check it's on the step grid: 0, 2, 4, 6, 8, 10 @@ -302,8 +311,8 @@ mod tests { let sampler = RandomSampler::with_seed(42); let dist = Distribution::Categorical(CategoricalDistribution { n_choices: 5 }); - for _ in 0..100 { - let value = sampler.sample(&dist, 0, &[]); + for i in 0..100 { + let value = sampler.sample(&dist, i, &[]); if let ParamValue::Categorical(idx) = value { assert!(idx < 5); } else { @@ -323,9 +332,9 @@ mod tests { step: None, }); - for _ in 0..10 { - let v1 = sampler1.sample(&dist, 0, &[]); - let v2 = sampler2.sample(&dist, 0, &[]); + for i in 0..10 { + let v1 = sampler1.sample(&dist, i, &[]); + let v2 = sampler2.sample(&dist, i, &[]); assert_eq!(v1, v2); } } diff --git a/src/sampler/tpe/sampler.rs b/src/sampler/tpe/sampler.rs index 47d2d09..a9ad698 100644 --- a/src/sampler/tpe/sampler.rs +++ b/src/sampler/tpe/sampler.rs @@ -56,10 +56,9 @@ //! ``` use core::fmt::Debug; +use core::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; -use parking_lot::Mutex; - use crate::distribution::Distribution; use crate::error::{Error, Result}; use crate::kde::KernelDensityEstimator; @@ -129,8 +128,10 @@ pub struct TpeSampler { n_ei_candidates: usize, /// Optional fixed bandwidth for KDE. If None, uses Scott's rule. kde_bandwidth: Option<f64>, - /// Thread-safe RNG for sampling. - rng: Mutex<fastrand::Rng>, + /// Base seed for deterministic per-call RNG derivation (no mutex needed). 
+ seed: u64, + /// Monotonic counter to disambiguate calls with identical (`trial_id`, distribution). + call_seq: AtomicU64, } impl TpeSampler { @@ -148,7 +149,8 @@ impl TpeSampler { n_startup_trials: 10, n_ei_candidates: 24, kde_bandwidth: None, - rng: Mutex::new(fastrand::Rng::new()), + seed: fastrand::u64(..), + call_seq: AtomicU64::new(0), } } @@ -248,17 +250,13 @@ impl TpeSampler { return Err(Error::InvalidBandwidth(bw)); } - let rng = match seed { - Some(s) => fastrand::Rng::with_seed(s), - None => fastrand::Rng::new(), - }; - Ok(Self { gamma_strategy: Arc::new(gamma_strategy), n_startup_trials, n_ei_candidates, kde_bandwidth, - rng: Mutex::new(rng), + seed: seed.unwrap_or_else(|| fastrand::u64(..)), + call_seq: AtomicU64::new(0), }) } @@ -849,17 +847,13 @@ impl TpeSamplerBuilder { return Err(Error::InvalidBandwidth(bw)); } - let rng = match self.seed { - Some(s) => fastrand::Rng::with_seed(s), - None => fastrand::Rng::new(), - }; - Ok(TpeSampler { gamma_strategy, n_startup_trials: self.n_startup_trials, n_ei_candidates: self.n_ei_candidates, kde_bandwidth: self.kde_bandwidth, - rng: Mutex::new(rng), + seed: self.seed.unwrap_or_else(|| fastrand::u64(..)), + call_seq: AtomicU64::new(0), }) } } @@ -875,10 +869,15 @@ impl Sampler for TpeSampler { fn sample( &self, distribution: &Distribution, - _trial_id: u64, + trial_id: u64, history: &[CompletedTrial], ) -> ParamValue { - let mut rng = self.rng.lock(); + let seq = self.call_seq.fetch_add(1, Ordering::Relaxed); + let mut rng = fastrand::Rng::with_seed(rng_util::mix_seed( + self.seed, + trial_id, + rng_util::distribution_fingerprint(distribution).wrapping_add(seq), + )); // Fall back to random sampling during startup phase if history.len() < self.n_startup_trials { @@ -1077,8 +1076,8 @@ mod tests { // With fewer than n_startup_trials, should use random sampling let history: Vec<CompletedTrial> = vec![]; - for _ in 0..100 { - let value = sampler.sample(&dist, 0, &history); + for i in 0..100 { + let value = 
sampler.sample(&dist, i, &history); if let ParamValue::Float(v) = value { assert!((0.0..=1.0).contains(&v)); } else { diff --git a/tests/sampler/tpe.rs b/tests/sampler/tpe.rs index 3e37899..799bb0d 100644 --- a/tests/sampler/tpe.rs +++ b/tests/sampler/tpe.rs @@ -74,7 +74,7 @@ fn test_tpe_maximization() { // Optimal: x = 2, f(2) = 10 let sampler = TpeSampler::builder() .seed(456) - .n_startup_trials(5) + .n_startup_trials(15) .build() .unwrap(); @@ -83,7 +83,7 @@ fn test_tpe_maximization() { let x_param = FloatParam::new(-10.0, 10.0); study - .optimize(50, |trial: &mut optimizer::Trial| { + .optimize(100, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; Ok::<_, Error>(-(x - 2.0).powi(2) + 10.0) }) From 5da63d89763517cce56ea83ee0393d85ade4c852 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 14:39:28 +0100 Subject: [PATCH 24/48] perf: reduce CompletedTrial cloning overhead - Add Trial::into_completed() and into_multi_objective_trial() to move fields instead of cloning 5 HashMaps/Vecs per trial completion - Fire after_trial callback before pushing to storage, eliminating the clone-from-storage pattern at all 4 call sites - Optimize top_trials(n) to sort indices and clone only N trials instead of cloning all completed trials - Remove unused set_complete/set_pruned methods --- src/multi_objective.rs | 14 +---- src/objective.rs | 5 ++ src/study.rs | 112 +++++++++++++-------------------------- src/trial.rs | 47 ++++++++++++---- tests/study/objective.rs | 8 ++- 5 files changed, 87 insertions(+), 99 deletions(-) diff --git a/src/multi_objective.rs b/src/multi_objective.rs index 9b924f3..d2e702a 100644 --- a/src/multi_objective.rs +++ b/src/multi_objective.rs @@ -321,18 +321,8 @@ impl MultiObjectiveStudy { } /// Records a completed trial. 
- fn complete_trial(&self, mut trial: Trial, values: Vec<f64>) { - trial.set_complete(); - let mo_trial = MultiObjectiveTrial { - id: trial.id(), - params: trial.params().clone(), - distributions: trial.distributions().clone(), - param_labels: trial.param_labels().clone(), - values, - state: TrialState::Complete, - user_attrs: trial.user_attrs().clone(), - constraints: trial.constraint_values().to_vec(), - }; + fn complete_trial(&self, trial: Trial, values: Vec<f64>) { + let mo_trial = trial.into_multi_objective_trial(values, TrialState::Complete); self.completed_trials.write().push(mo_trial); } diff --git a/src/objective.rs b/src/objective.rs index c22babf..4cad201 100644 --- a/src/objective.rs +++ b/src/objective.rs @@ -117,6 +117,11 @@ pub trait Objective<V: PartialOrd = f64> { /// Called after each **completed** trial (not failed or pruned). /// + /// The trial is passed directly as the argument *before* it is pushed + /// to storage, so `study.n_trials()` and `study.trials()` do not yet + /// include this trial. The trial is always pushed to storage after this + /// callback returns, regardless of the return value. + /// /// Return `ControlFlow::Break(())` to stop the optimization loop. /// /// Default: always continues. 
diff --git a/src/study.rs b/src/study.rs index dc4c648..2fcba0e 100644 --- a/src/study.rs +++ b/src/study.rs @@ -491,20 +491,8 @@ where /// /// assert_eq!(study.n_trials(), 1); /// ``` - pub fn complete_trial(&self, mut trial: Trial, value: V) { - trial.set_complete(); - let mut completed = CompletedTrial::with_intermediate_values( - trial.id(), - trial.params().clone(), - trial.distributions().clone(), - trial.param_labels().clone(), - value, - trial.intermediate_values().to_vec(), - trial.user_attrs().clone(), - ); - completed.state = TrialState::Complete; - completed.constraints = trial.constraint_values().to_vec(); - + pub fn complete_trial(&self, trial: Trial, value: V) { + let completed = trial.into_completed(value, TrialState::Complete); self.storage.push(completed); } @@ -604,23 +592,11 @@ where /// # Arguments /// /// * `trial` - The trial that was pruned. - pub fn prune_trial(&self, mut trial: Trial) + pub fn prune_trial(&self, trial: Trial) where V: Default, { - trial.set_pruned(); - let mut completed = CompletedTrial::with_intermediate_values( - trial.id(), - trial.params().clone(), - trial.distributions().clone(), - trial.param_labels().clone(), - V::default(), - trial.intermediate_values().to_vec(), - trial.user_attrs().clone(), - ); - completed.state = TrialState::Pruned; - completed.constraints = trial.constraint_values().to_vec(); - + let completed = trial.into_completed(V::default(), TrialState::Pruned); self.storage.push(completed); } @@ -852,15 +828,17 @@ where { let trials = self.storage.trials_arc().read(); let direction = self.direction; - let mut completed: Vec<_> = trials + // Sort indices instead of cloning all trials, then clone only the top N. 
+ let mut indices: Vec<usize> = trials .iter() - .filter(|t| t.state == TrialState::Complete) - .cloned() + .enumerate() + .filter(|(_, t)| t.state == TrialState::Complete) + .map(|(i, _)| i) .collect(); // Sort best-first: reverse the compare_trials ordering (which is designed for max_by) - completed.sort_by(|a, b| Self::compare_trials(b, a, direction)); - completed.truncate(n); - completed + indices.sort_by(|&a, &b| Self::compare_trials(&trials[b], &trials[a], direction)); + indices.truncate(n); + indices.iter().map(|&i| trials[i].clone()).collect() } /// Run optimization with an objective. @@ -921,7 +899,12 @@ where Ok(value) => { #[cfg(feature = "tracing")] let trial_id = trial.id(); - self.complete_trial(trial, value); + + let completed = trial.into_completed(value, TrialState::Complete); + + // Fire after_trial hook before pushing to storage + let flow = objective.after_trial(self, &completed); + self.storage.push(completed); #[cfg(feature = "tracing")] { @@ -938,17 +921,8 @@ where } } - // Fire after_trial hook - let trials = self.storage.trials_arc().read(); - if let Some(completed) = trials.last() { - let completed_clone = completed.clone(); - drop(trials); - if let ControlFlow::Break(()) = - objective.after_trial(self, &completed_clone) - { - // Return early — at least one trial completed. 
- return Ok(()); - } + if let ControlFlow::Break(()) = flow { + return Ok(()); } } Err(e) if is_trial_pruned(&e) => { @@ -1049,19 +1023,14 @@ where (t, Ok(value)) => { #[cfg(feature = "tracing")] let trial_id = t.id(); - self.complete_trial(t, value); + + let completed = t.into_completed(value, TrialState::Complete); + let flow = objective.after_trial(self, &completed); + self.storage.push(completed); trace_info!(trial_id, "trial completed"); - // Fire after_trial hook - let trials = self.storage.trials_arc().read(); - if let Some(completed) = trials.last() { - let completed_clone = completed.clone(); - drop(trials); - if let ControlFlow::Break(()) = - objective.after_trial(self, &completed_clone) - { - return Ok(()); - } + if let ControlFlow::Break(()) = flow { + return Ok(()); } } (t, Err(e)) if is_trial_pruned(&e) => { @@ -1172,18 +1141,14 @@ where (t, Ok(value)) => { #[cfg(feature = "tracing")] let trial_id = t.id(); - self.complete_trial(t, value); + + let completed = t.into_completed(value, TrialState::Complete); + let flow = objective.after_trial(self, &completed); + self.storage.push(completed); trace_info!(trial_id, "trial completed"); - let trials = self.storage.trials_arc().read(); - if let Some(completed) = trials.last() { - let completed_clone = completed.clone(); - drop(trials); - if let ControlFlow::Break(()) = - objective.after_trial(self, &completed_clone) - { - break 'spawn; - } + if let ControlFlow::Break(()) = flow { + break 'spawn; } } (t, Err(e)) => { @@ -1228,15 +1193,12 @@ where (t, Ok(value)) => { #[cfg(feature = "tracing")] let trial_id = t.id(); - self.complete_trial(t, value); - trace_info!(trial_id, "trial completed"); + + let completed = t.into_completed(value, TrialState::Complete); // Still fire after_trial for bookkeeping, but don't break — we're draining. 
- let trials = self.storage.trials_arc().read(); - if let Some(completed) = trials.last() { - let completed_clone = completed.clone(); - drop(trials); - let _ = objective.after_trial(self, &completed_clone); - } + let _ = objective.after_trial(self, &completed); + self.storage.push(completed); + trace_info!(trial_id, "trial completed"); } (t, Err(e)) => { #[cfg(feature = "tracing")] diff --git a/src/trial.rs b/src/trial.rs index 1ef939e..9d069e0 100644 --- a/src/trial.rs +++ b/src/trial.rs @@ -26,6 +26,7 @@ use parking_lot::RwLock; use crate::distribution::Distribution; use crate::error::{Error, Result}; +use crate::multi_objective::MultiObjectiveTrial; use crate::param::ParamValue; use crate::parameter::{ParamId, Parameter}; use crate::pruner::Pruner; @@ -389,21 +390,11 @@ impl Trial { &self.constraint_values } - /// Set the trial state to `Complete`. - pub(crate) fn set_complete(&mut self) { - self.state = TrialState::Complete; - } - /// Set the trial state to `Failed`. pub(crate) fn set_failed(&mut self) { self.state = TrialState::Failed; } - /// Set the trial state to `Pruned`. - pub(crate) fn set_pruned(&mut self) { - self.state = TrialState::Pruned; - } - /// Suggest a parameter value using a [`Parameter`] definition. /// /// This is the primary entry point for sampling parameters. It handles @@ -482,6 +473,42 @@ impl Trial { Ok(result) } + + /// Consume this trial and move its fields into a [`CompletedTrial`]. + /// + /// This avoids cloning the trial's `HashMap`s and `Vec`s by moving + /// ownership directly into the completed trial. 
+ pub(crate) fn into_completed<V>(self, value: V, state: TrialState) -> CompletedTrial<V> { + CompletedTrial { + id: self.id, + params: self.params, + distributions: self.distributions, + param_labels: self.param_labels, + value, + intermediate_values: self.intermediate_values, + state, + user_attrs: self.user_attrs, + constraints: self.constraint_values, + } + } + + /// Consume this trial and move its fields into a [`MultiObjectiveTrial`]. + pub(crate) fn into_multi_objective_trial( + self, + values: Vec<f64>, + state: TrialState, + ) -> MultiObjectiveTrial { + MultiObjectiveTrial { + id: self.id, + params: self.params, + distributions: self.distributions, + param_labels: self.param_labels, + values, + state, + user_attrs: self.user_attrs, + constraints: self.constraint_values, + } + } } #[cfg(test)] diff --git a/tests/study/objective.rs b/tests/study/objective.rs index 8d28bf1..e379ec9 100644 --- a/tests/study/objective.rs +++ b/tests/study/objective.rs @@ -20,7 +20,9 @@ fn test_callback_early_stopping() { Ok(x) } fn after_trial(&self, study: &Study<f64>, _trial: &CompletedTrial<f64>) -> ControlFlow<()> { - if study.n_trials() >= 5 { + // The current trial has not yet been pushed to storage when + // after_trial fires, so n_trials() == 4 means this is the 5th. + if study.n_trials() >= 4 { ControlFlow::Break(()) } else { ControlFlow::Continue(()) @@ -98,7 +100,9 @@ fn test_callback_sampler_early_stopping() { Ok(x) } fn after_trial(&self, study: &Study<f64>, _trial: &CompletedTrial<f64>) -> ControlFlow<()> { - if study.n_trials() >= 3 { + // The current trial has not yet been pushed to storage when + // after_trial fires, so n_trials() == 2 means this is the 3rd. 
+ if study.n_trials() >= 2 { ControlFlow::Break(()) } else { ControlFlow::Continue(()) From ddbbe294bcc842f58eee271c5684dceb72a79bd1 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 14:43:24 +0100 Subject: [PATCH 25/48] perf(journal): use incremental refresh instead of re-reading entire file MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Track a byte offset so refresh() only parses lines appended since the last read, reducing total I/O from O(n²) to O(n) over an n-trial run. --- src/storage/journal.rs | 103 ++++++++++++++++++++++++++++++++++------- tests/journal_tests.rs | 56 ++++++++++++++++++++++ 2 files changed, 143 insertions(+), 16 deletions(-) diff --git a/src/storage/journal.rs b/src/storage/journal.rs index 55ad087..6833bd6 100644 --- a/src/storage/journal.rs +++ b/src/storage/journal.rs @@ -69,8 +69,9 @@ //! [`MemoryStorage`](super::MemoryStorage) instead (the default). use core::marker::PhantomData; +use core::sync::atomic::{AtomicU64, Ordering}; use std::fs::{File, OpenOptions}; -use std::io::{BufRead, BufReader, Seek, SeekFrom, Write}; +use std::io::{BufRead, BufReader, Read as _, Seek, SeekFrom, Write}; use std::path::{Path, PathBuf}; use std::sync::Arc; @@ -112,6 +113,8 @@ pub struct JournalStorage<V = f64> { path: PathBuf, /// Serialise in-process writes so we only hold the file lock briefly. write_lock: Mutex<()>, + /// Byte offset of last-read position for incremental refresh. + file_offset: AtomicU64, _marker: PhantomData<V>, } @@ -131,6 +134,7 @@ impl<V: Serialize + DeserializeOwned + Send + Sync> JournalStorage<V> { memory: MemoryStorage::new(), path: path.as_ref().to_path_buf(), write_lock: Mutex::new(()), + file_offset: AtomicU64::new(0), _marker: PhantomData, } } @@ -146,11 +150,12 @@ impl<V: Serialize + DeserializeOwned + Send + Sync> JournalStorage<V> { /// exists but cannot be read or parsed. 
pub fn open(path: impl AsRef<Path>) -> crate::Result<Self> { let path = path.as_ref().to_path_buf(); - let trials = load_trials_from_file(&path)?; + let (trials, offset) = load_trials_from_file(&path)?; Ok(Self { memory: MemoryStorage::with_trials(trials), path, write_lock: Mutex::new(()), + file_offset: AtomicU64::new(offset), _marker: PhantomData, }) } @@ -180,6 +185,11 @@ impl<V: Serialize + DeserializeOwned + Send + Sync> JournalStorage<V> { file.sync_data() .map_err(|e| crate::Error::Storage(e.to_string()))?; + let pos = file + .stream_position() + .map_err(|e| crate::Error::Storage(e.to_string()))?; + self.file_offset.store(pos, Ordering::SeqCst); + file.unlock() .map_err(|e| crate::Error::Storage(e.to_string()))?; @@ -203,36 +213,97 @@ impl<V: Serialize + DeserializeOwned + Send + Sync> Storage<V> for JournalStorag } fn refresh(&self) -> bool { - let Ok(loaded) = load_trials_from_file::<V>(&self.path) else { + let Ok(file) = File::open(&self.path) else { return false; }; - let mut guard = self.memory.trials_arc().write(); - if loaded.len() > guard.len() { - if let Some(max_id) = loaded.iter().map(|t| t.id).max() { - self.memory.bump_next_id(max_id + 1); - } - *guard = loaded; - true + + if file.lock_shared().is_err() { + return false; + } + + let offset = self.file_offset.load(Ordering::SeqCst); + + let file_size = if let Ok(m) = file.metadata() { + m.len() } else { - false + let _ = file.unlock(); + return false; + }; + + if file_size <= offset { + let _ = file.unlock(); + return false; + } + + let mut buf = String::new(); + let mut handle = &file; + if handle.seek(SeekFrom::Start(offset)).is_err() { + let _ = file.unlock(); + return false; } + if handle.read_to_string(&mut buf).is_err() { + let _ = file.unlock(); + return false; + } + + let _ = file.unlock(); + + let bytes_read = buf.len() as u64; + let mut new_trials = Vec::new(); + + for line in buf.lines() { + let line = line.trim(); + if line.is_empty() { + continue; + } + let trial: CompletedTrial<V> 
= match serde_json::from_str(line) { + Ok(t) => t, + Err(_) => return false, + }; + if trial.validate().is_err() { + return false; + } + new_trials.push(trial); + } + + if new_trials.is_empty() { + self.file_offset + .store(offset + bytes_read, Ordering::SeqCst); + return false; + } + + let mut guard = self.memory.trials_arc().write(); + if let Some(max_id) = new_trials.iter().map(|t| t.id).max() { + self.memory.bump_next_id(max_id + 1); + } + guard.extend(new_trials); + self.file_offset + .store(offset + bytes_read, Ordering::SeqCst); + true } } -/// Read all trials from a JSONL file. Returns an empty vec if the -/// file does not exist. +/// Read all trials from a JSONL file. Returns an empty vec (and +/// offset 0) if the file does not exist. The returned `u64` is the +/// file size at the time of reading, suitable for initialising the +/// incremental-refresh offset. fn load_trials_from_file<V: DeserializeOwned>( path: &Path, -) -> crate::Result<Vec<CompletedTrial<V>>> { +) -> crate::Result<(Vec<CompletedTrial<V>>, u64)> { let file = match File::open(path) { Ok(f) => f, - Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()), + Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok((Vec::new(), 0)), Err(e) => return Err(crate::Error::Storage(e.to_string())), }; file.lock_shared() .map_err(|e| crate::Error::Storage(e.to_string()))?; + let file_size = file + .metadata() + .map_err(|e| crate::Error::Storage(e.to_string()))? 
+ .len(); + let reader = BufReader::new(&file); let mut trials = Vec::new(); @@ -251,5 +322,5 @@ fn load_trials_from_file<V: DeserializeOwned>( file.unlock() .map_err(|e| crate::Error::Storage(e.to_string()))?; - Ok(trials) + Ok((trials, file_size)) } diff --git a/tests/journal_tests.rs b/tests/journal_tests.rs index 6d5f473..4c25e0e 100644 --- a/tests/journal_tests.rs +++ b/tests/journal_tests.rs @@ -3,6 +3,8 @@ use std::collections::HashMap; use std::sync::Arc; +use std::io::Write; + use optimizer::parameter::{FloatParam, Parameter}; use optimizer::sampler::CompletedTrial; use optimizer::sampler::random::RandomSampler; @@ -317,3 +319,57 @@ fn accepts_valid_journal_with_distributions() { std::fs::remove_file(&path).ok(); } + +#[test] +fn refresh_skips_own_writes() { + let path = temp_path(); + let storage = JournalStorage::new(&path); + + for i in 0..5 { + storage.push(sample_trial(i, i as f64)); + // Our own push advanced the offset, so refresh should find nothing new. + assert!(!storage.refresh(), "refresh returned true after push {i}"); + } + + assert_eq!(storage.trials_arc().read().len(), 5); + std::fs::remove_file(&path).ok(); +} + +#[test] +fn refresh_picks_up_external_writes() { + let path = temp_path(); + let storage = JournalStorage::new(&path); + + // Push 3 trials through the storage (advances offset). + for i in 0..3 { + storage.push(sample_trial(i, i as f64)); + } + assert_eq!(storage.trials_arc().read().len(), 3); + + // Simulate an external process appending 2 more lines directly. + { + let mut file = std::fs::OpenOptions::new() + .append(true) + .open(&path) + .unwrap(); + for i in 3..5u64 { + let trial = sample_trial(i, i as f64); + let line = serde_json::to_string(&trial).unwrap(); + writeln!(file, "{line}").unwrap(); + } + file.sync_all().unwrap(); + } + + // refresh() should pick up the 2 external trials. 
+ assert!(storage.refresh(), "refresh should detect external writes"); + assert_eq!(storage.trials_arc().read().len(), 5); + + // A second refresh should be a no-op. + assert!( + !storage.refresh(), + "second refresh should return false (no new data)" + ); + assert_eq!(storage.trials_arc().read().len(), 5); + + std::fs::remove_file(&path).ok(); +} From 1b158e4d1b395ed670f2c1b074f6592488531980 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 14:46:46 +0100 Subject: [PATCH 26/48] perf: pre-allocate vectors in analysis methods - Use Vec::with_capacity() in param_importance() and fanova_with_config() where iteration count is known or bounded - Fix flaky MultivariateTpeSampler doctest by increasing trials from 30 to 50 --- src/sampler/tpe/multivariate.rs | 2 +- src/study.rs | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/sampler/tpe/multivariate.rs b/src/sampler/tpe/multivariate.rs index 502125f..cb9c4e3 100644 --- a/src/sampler/tpe/multivariate.rs +++ b/src/sampler/tpe/multivariate.rs @@ -205,7 +205,7 @@ pub enum ConstantLiarStrategy { /// let y = FloatParam::new(-5.0, 5.0); /// /// study -/// .optimize(30, |trial: &mut optimizer::Trial| { +/// .optimize(50, |trial: &mut optimizer::Trial| { /// let xv = x.suggest(trial)?; /// let yv = y.suggest(trial)?; /// Ok::<_, optimizer::Error>(xv * xv + yv * yv) diff --git a/src/study.rs b/src/study.rs index 2fcba0e..0769f83 100644 --- a/src/study.rs +++ b/src/study.rs @@ -1517,12 +1517,12 @@ where // Collect all parameter IDs across trials. let all_param_ids: BTreeSet<_> = complete.iter().flat_map(|t| t.params.keys()).collect(); - let mut scores: Vec<(String, f64)> = Vec::new(); + let mut scores: Vec<(String, f64)> = Vec::with_capacity(all_param_ids.len()); for ¶m_id in &all_param_ids { // Collect (param_value_f64, objective_f64) for trials that have this param. 
- let mut param_vals = Vec::new(); - let mut obj_vals = Vec::new(); + let mut param_vals = Vec::with_capacity(complete.len()); + let mut obj_vals = Vec::with_capacity(complete.len()); for trial in &complete { if let Some(pv) = trial.params.get(param_id) { @@ -1645,8 +1645,8 @@ where } // Build feature matrix (only trials that have all parameters). - let mut data = Vec::new(); - let mut targets = Vec::new(); + let mut data = Vec::with_capacity(complete.len()); + let mut targets = Vec::with_capacity(complete.len()); for trial in &complete { let mut row = Vec::with_capacity(all_param_ids.len()); From 9ab94002d951866a7a60261162baf3bb2f08baf4 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 14:51:03 +0100 Subject: [PATCH 27/48] perf(tpe): eliminate unnecessary Vec allocations in KDE/TPE sampling - In-place log transformation in sample_tpe_float (saves 2 allocs per log-scale float sample) - Pass owned Vec<i64> to sample_tpe_int to reduce peak memory during int-to-float conversion - Stack-allocate categorical count arrays for <=32 choices in sample_tpe_categorical --- src/sampler/motpe.rs | 55 ++++++++++++++++++++++++++------- src/sampler/tpe/multivariate.rs | 55 ++++++++++++++++++++++++++------- src/sampler/tpe/sampler.rs | 55 ++++++++++++++++++++++++++------- 3 files changed, 130 insertions(+), 35 deletions(-) diff --git a/src/sampler/motpe.rs b/src/sampler/motpe.rs index cdcf559..883cfe0 100644 --- a/src/sampler/motpe.rs +++ b/src/sampler/motpe.rs @@ -263,8 +263,20 @@ impl MotpeSampler { let (internal_low, internal_high, good_internal, bad_internal) = if log_scale { let i_low = low.ln(); let i_high = high.ln(); - let g: Vec<f64> = good_values.iter().map(|&v| v.ln()).collect(); - let b: Vec<f64> = bad_values.iter().map(|&v| v.ln()).collect(); + let g = { + let mut v = good_values; + for x in &mut v { + *x = x.ln(); + } + v + }; + let b = { + let mut v = bad_values; + for x in &mut v { + *x = x.ln(); + } + v + }; (i_low, i_high, 
g, b) } else { (low, high, good_values, bad_values) @@ -339,12 +351,12 @@ impl MotpeSampler { high: i64, log_scale: bool, step: Option<i64>, - good_values: &[i64], - bad_values: &[i64], + good_values: Vec<i64>, + bad_values: Vec<i64>, rng: &mut fastrand::Rng, ) -> i64 { - let good_floats: Vec<f64> = good_values.iter().map(|&v| v as f64).collect(); - let bad_floats: Vec<f64> = bad_values.iter().map(|&v| v as f64).collect(); + let good_floats: Vec<f64> = good_values.into_iter().map(|v| v as f64).collect(); + let bad_floats: Vec<f64> = bad_values.into_iter().map(|v| v as f64).collect(); let float_value = self.sample_tpe_float( low as f64, @@ -375,9 +387,30 @@ impl MotpeSampler { bad_indices: &[usize], rng: &mut fastrand::Rng, ) -> usize { - let mut good_counts = vec![0usize; n_choices]; - let mut bad_counts = vec![0usize; n_choices]; + // Stack-allocate for the common case (<=32 choices), heap for rare large cases + let mut good_buf = [0usize; 32]; + let mut bad_buf = [0usize; 32]; + let mut weight_buf = [0.0f64; 32]; + + let mut good_vec; + let mut bad_vec; + let mut weight_vec; + + let (good_counts, bad_counts, weights): (&mut [usize], &mut [usize], &mut [f64]) = + if n_choices <= 32 { + ( + &mut good_buf[..n_choices], + &mut bad_buf[..n_choices], + &mut weight_buf[..n_choices], + ) + } else { + good_vec = vec![0usize; n_choices]; + bad_vec = vec![0usize; n_choices]; + weight_vec = vec![0.0f64; n_choices]; + (&mut good_vec, &mut bad_vec, &mut weight_vec) + }; + // Count occurrences in good and bad groups for &idx in good_indices { if idx < n_choices { good_counts[idx] += 1; @@ -393,7 +426,7 @@ impl MotpeSampler { let good_total = good_indices.len() as f64 + n_choices as f64; let bad_total = bad_indices.len() as f64 + n_choices as f64; - let mut weights = vec![0.0f64; n_choices]; + // Calculate l(x)/g(x) ratio for each category for i in 0..n_choices { let l_prob = (good_counts[i] as f64 + 1.0) / good_total; let g_prob = (bad_counts[i] as f64 + 1.0) / bad_total; @@ 
-521,8 +554,8 @@ impl MultiObjectiveSampler for MotpeSampler { d.high, d.log_scale, d.step, - &good_values, - &bad_values, + good_values, + bad_values, &mut rng, ); ParamValue::Int(value) diff --git a/src/sampler/tpe/multivariate.rs b/src/sampler/tpe/multivariate.rs index cb9c4e3..731fc44 100644 --- a/src/sampler/tpe/multivariate.rs +++ b/src/sampler/tpe/multivariate.rs @@ -1356,8 +1356,8 @@ impl MultivariateTpeSampler { d.high, d.log_scale, d.step, - &good_values, - &bad_values, + good_values, + bad_values, rng, ); ParamValue::Int(value) @@ -1412,8 +1412,20 @@ impl MultivariateTpeSampler { let (internal_low, internal_high, good_internal, bad_internal) = if log_scale { let i_low = low.ln(); let i_high = high.ln(); - let g: Vec<f64> = good_values.iter().map(|&v| v.ln()).collect(); - let b: Vec<f64> = bad_values.iter().map(|&v| v.ln()).collect(); + let g = { + let mut v = good_values; + for x in &mut v { + *x = x.ln(); + } + v + }; + let b = { + let mut v = bad_values; + for x in &mut v { + *x = x.ln(); + } + v + }; (i_low, i_high, g, b) } else { (low, high, good_values, bad_values) @@ -1487,13 +1499,13 @@ impl MultivariateTpeSampler { high: i64, log_scale: bool, step: Option<i64>, - good_values: &[i64], - bad_values: &[i64], + good_values: Vec<i64>, + bad_values: Vec<i64>, rng: &mut fastrand::Rng, ) -> i64 { // Convert to floats for KDE - let good_floats: Vec<f64> = good_values.iter().map(|&v| v as f64).collect(); - let bad_floats: Vec<f64> = bad_values.iter().map(|&v| v as f64).collect(); + let good_floats: Vec<f64> = good_values.into_iter().map(|v| v as f64).collect(); + let bad_floats: Vec<f64> = bad_values.into_iter().map(|v| v as f64).collect(); // Use float TPE sampling let float_value = self.sample_tpe_float( @@ -1598,10 +1610,30 @@ impl MultivariateTpeSampler { bad_indices: &[usize], rng: &mut fastrand::Rng, ) -> usize { - // Count occurrences in good and bad groups - let mut good_counts = vec![0usize; n_choices]; - let mut bad_counts = vec![0usize; 
n_choices]; + // Stack-allocate for the common case (<=32 choices), heap for rare large cases + let mut good_buf = [0usize; 32]; + let mut bad_buf = [0usize; 32]; + let mut weight_buf = [0.0f64; 32]; + + let mut good_vec; + let mut bad_vec; + let mut weight_vec; + + let (good_counts, bad_counts, weights): (&mut [usize], &mut [usize], &mut [f64]) = + if n_choices <= 32 { + ( + &mut good_buf[..n_choices], + &mut bad_buf[..n_choices], + &mut weight_buf[..n_choices], + ) + } else { + good_vec = vec![0usize; n_choices]; + bad_vec = vec![0usize; n_choices]; + weight_vec = vec![0.0f64; n_choices]; + (&mut good_vec, &mut bad_vec, &mut weight_vec) + }; + // Count occurrences in good and bad groups for &idx in good_indices { if idx < n_choices { good_counts[idx] += 1; @@ -1618,7 +1650,6 @@ impl MultivariateTpeSampler { let bad_total = bad_indices.len() as f64 + n_choices as f64; // Calculate l(x)/g(x) ratio for each category - let mut weights = vec![0.0f64; n_choices]; for i in 0..n_choices { let l_prob = (good_counts[i] as f64 + 1.0) / good_total; let g_prob = (bad_counts[i] as f64 + 1.0) / bad_total; diff --git a/src/sampler/tpe/sampler.rs b/src/sampler/tpe/sampler.rs index a9ad698..fb9efe6 100644 --- a/src/sampler/tpe/sampler.rs +++ b/src/sampler/tpe/sampler.rs @@ -373,8 +373,20 @@ impl TpeSampler { let (internal_low, internal_high, good_internal, bad_internal) = if log_scale { let i_low = low.ln(); let i_high = high.ln(); - let g: Vec<f64> = good_values.iter().map(|&v| v.ln()).collect(); - let b: Vec<f64> = bad_values.iter().map(|&v| v.ln()).collect(); + let g = { + let mut v = good_values; + for x in &mut v { + *x = x.ln(); + } + v + }; + let b = { + let mut v = bad_values; + for x in &mut v { + *x = x.ln(); + } + v + }; (i_low, i_high, g, b) } else { (low, high, good_values, bad_values) @@ -454,13 +466,13 @@ impl TpeSampler { high: i64, log_scale: bool, step: Option<i64>, - good_values: &[i64], - bad_values: &[i64], + good_values: Vec<i64>, + bad_values: Vec<i64>, rng: 
&mut fastrand::Rng, ) -> i64 { // Convert to floats for KDE - let good_floats: Vec<f64> = good_values.iter().map(|&v| v as f64).collect(); - let bad_floats: Vec<f64> = bad_values.iter().map(|&v| v as f64).collect(); + let good_floats: Vec<f64> = good_values.into_iter().map(|v| v as f64).collect(); + let bad_floats: Vec<f64> = bad_values.into_iter().map(|v| v as f64).collect(); // Use float TPE sampling let float_value = self.sample_tpe_float( @@ -497,10 +509,30 @@ impl TpeSampler { bad_indices: &[usize], rng: &mut fastrand::Rng, ) -> usize { - // Count occurrences in good and bad groups - let mut good_counts = vec![0usize; n_choices]; - let mut bad_counts = vec![0usize; n_choices]; + // Stack-allocate for the common case (<=32 choices), heap for rare large cases + let mut good_buf = [0usize; 32]; + let mut bad_buf = [0usize; 32]; + let mut weight_buf = [0.0f64; 32]; + + let mut good_vec; + let mut bad_vec; + let mut weight_vec; + + let (good_counts, bad_counts, weights): (&mut [usize], &mut [usize], &mut [f64]) = + if n_choices <= 32 { + ( + &mut good_buf[..n_choices], + &mut bad_buf[..n_choices], + &mut weight_buf[..n_choices], + ) + } else { + good_vec = vec![0usize; n_choices]; + bad_vec = vec![0usize; n_choices]; + weight_vec = vec![0.0f64; n_choices]; + (&mut good_vec, &mut bad_vec, &mut weight_vec) + }; + // Count occurrences in good and bad groups for &idx in good_indices { if idx < n_choices { good_counts[idx] += 1; @@ -517,7 +549,6 @@ impl TpeSampler { let bad_total = bad_indices.len() as f64 + n_choices as f64; // Calculate l(x)/g(x) ratio for each category - let mut weights = vec![0.0f64; n_choices]; for i in 0..n_choices { let l_prob = (good_counts[i] as f64 + 1.0) / good_total; let g_prob = (bad_counts[i] as f64 + 1.0) / bad_total; @@ -967,8 +998,8 @@ impl Sampler for TpeSampler { d.high, d.log_scale, d.step, - &good_values, - &bad_values, + good_values, + bad_values, &mut rng, ); ParamValue::Int(value) From 4781107ede2260123353bd0841014b7c0c41777a Mon 
Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 15:03:02 +0100 Subject: [PATCH 28/48] refactor(study): split monolithic study.rs into focused submodules - mod.rs: core struct (pub(crate) fields), constructors, trial management - builder.rs: StudyBuilder fluent API - optimize.rs: sync optimization loop - async_impl.rs: optimize_async/optimize_parallel (feature-gated) - analysis.rs: best_trial, top_trials, param_importance, fanova - export.rs: CSV, summary, Display, export_html - persistence.rs: StudySnapshot, save/load, with_journal - iter.rs: iter(), IntoIterator --- src/study.rs | 2060 -------------------------------------- src/study/analysis.rs | 379 +++++++ src/study/async_impl.rs | 283 ++++++ src/study/builder.rs | 135 +++ src/study/export.rs | 257 +++++ src/study/iter.rs | 43 + src/study/mod.rs | 766 ++++++++++++++ src/study/optimize.rs | 123 +++ src/study/persistence.rs | 148 +++ 9 files changed, 2134 insertions(+), 2060 deletions(-) delete mode 100644 src/study.rs create mode 100644 src/study/analysis.rs create mode 100644 src/study/async_impl.rs create mode 100644 src/study/builder.rs create mode 100644 src/study/export.rs create mode 100644 src/study/iter.rs create mode 100644 src/study/mod.rs create mode 100644 src/study/optimize.rs create mode 100644 src/study/persistence.rs diff --git a/src/study.rs b/src/study.rs deleted file mode 100644 index 0769f83..0000000 --- a/src/study.rs +++ /dev/null @@ -1,2060 +0,0 @@ -//! Study implementation for managing optimization trials. 
- -use core::any::Any; -use core::fmt; -use core::marker::PhantomData; -use core::ops::ControlFlow; -use std::collections::{HashMap, VecDeque}; -use std::sync::Arc; - -use parking_lot::{Mutex, RwLock}; - -use crate::param::ParamValue; -use crate::parameter::ParamId; -use crate::pruner::{NopPruner, Pruner}; -use crate::sampler::random::RandomSampler; -use crate::sampler::{CompletedTrial, Sampler}; -use crate::trial::Trial; -use crate::types::{Direction, TrialState}; - -/// A study manages the optimization process, tracking trials and their results. -/// -/// The study is parameterized by the objective value type `V`, which defaults to `f64`. -/// The only constraint on `V` is `PartialOrd`, allowing comparison of objective values -/// to determine which trial is best. -/// -/// When `V = f64`, the study passes trial history to the sampler for informed -/// parameter suggestions (e.g., TPE sampler uses history to guide sampling). -/// -/// # Examples -/// -/// ``` -/// use optimizer::{Direction, Study}; -/// -/// // Create a study to minimize an objective function -/// let study: Study<f64> = Study::new(Direction::Minimize); -/// assert_eq!(study.direction(), Direction::Minimize); -/// ``` -pub struct Study<V = f64> -where - V: PartialOrd, -{ - /// The optimization direction. - direction: Direction, - /// The sampler used to generate parameter values. - sampler: Arc<dyn Sampler>, - /// The pruner used to decide whether to stop trials early. - pruner: Arc<dyn Pruner>, - /// Trial storage backend (default: [`MemoryStorage`](crate::storage::MemoryStorage)). - storage: Arc<dyn crate::storage::Storage<V>>, - /// Optional factory for creating sampler-aware trials. - /// Set automatically for `Study<f64>` so that `create_trial()` and all - /// optimization methods use the sampler without requiring `_with_sampler` suffixes. - trial_factory: Option<Arc<dyn Fn(u64) -> Trial + Send + Sync>>, - /// Queue of parameter configurations to evaluate next. 
- enqueued_params: Arc<Mutex<VecDeque<HashMap<ParamId, ParamValue>>>>, -} - -impl<V> Study<V> -where - V: PartialOrd, -{ - /// Create a new study with the given optimization direction. - /// - /// Uses the default `RandomSampler` for parameter sampling. - /// - /// # Arguments - /// - /// * `direction` - Whether to minimize or maximize the objective function. - /// - /// # Examples - /// - /// ``` - /// use optimizer::{Direction, Study}; - /// - /// let study: Study<f64> = Study::new(Direction::Minimize); - /// assert_eq!(study.direction(), Direction::Minimize); - /// ``` - #[must_use] - pub fn new(direction: Direction) -> Self - where - V: Send + Sync + 'static, - { - Self::with_sampler(direction, RandomSampler::new()) - } - - /// Return a [`StudyBuilder`] for constructing a study with a fluent API. - /// - /// # Examples - /// - /// ``` - /// use optimizer::prelude::*; - /// - /// let study: Study<f64> = Study::builder() - /// .minimize() - /// .sampler(TpeSampler::new()) - /// .pruner(NopPruner) - /// .build(); - /// ``` - #[must_use] - pub fn builder() -> StudyBuilder<V> { - StudyBuilder { - direction: Direction::Minimize, - sampler: None, - pruner: None, - storage: None, - _marker: PhantomData, - } - } - - /// Create a study that minimizes the objective value. - /// - /// This is a shorthand for `Study::with_sampler(Direction::Minimize, sampler)`. - /// - /// # Arguments - /// - /// * `sampler` - The sampler to use for parameter sampling. - /// - /// # Examples - /// - /// ``` - /// use optimizer::Study; - /// use optimizer::sampler::tpe::TpeSampler; - /// - /// let study: Study<f64> = Study::minimize(TpeSampler::new()); - /// assert_eq!(study.direction(), optimizer::Direction::Minimize); - /// ``` - #[must_use] - pub fn minimize(sampler: impl Sampler + 'static) -> Self - where - V: Send + Sync + 'static, - { - Self::with_sampler(Direction::Minimize, sampler) - } - - /// Create a study that maximizes the objective value. 
- /// - /// This is a shorthand for `Study::with_sampler(Direction::Maximize, sampler)`. - /// - /// # Arguments - /// - /// * `sampler` - The sampler to use for parameter sampling. - /// - /// # Examples - /// - /// ``` - /// use optimizer::Study; - /// use optimizer::sampler::tpe::TpeSampler; - /// - /// let study: Study<f64> = Study::maximize(TpeSampler::new()); - /// assert_eq!(study.direction(), optimizer::Direction::Maximize); - /// ``` - #[must_use] - pub fn maximize(sampler: impl Sampler + 'static) -> Self - where - V: Send + Sync + 'static, - { - Self::with_sampler(Direction::Maximize, sampler) - } - - /// Create a new study with a custom sampler. - /// - /// # Arguments - /// - /// * `direction` - Whether to minimize or maximize the objective function. - /// * `sampler` - The sampler to use for parameter sampling. - /// - /// # Examples - /// - /// ``` - /// use optimizer::sampler::random::RandomSampler; - /// use optimizer::{Direction, Study}; - /// - /// let sampler = RandomSampler::with_seed(42); - /// let study: Study<f64> = Study::with_sampler(Direction::Maximize, sampler); - /// assert_eq!(study.direction(), Direction::Maximize); - /// ``` - pub fn with_sampler(direction: Direction, sampler: impl Sampler + 'static) -> Self - where - V: Send + Sync + 'static, - { - Self::with_sampler_and_storage( - direction, - sampler, - crate::storage::MemoryStorage::<V>::new(), - ) - } - - /// Build a trial factory for sampler integration when `V = f64`. - fn make_trial_factory( - sampler: &Arc<dyn Sampler>, - storage: &Arc<dyn crate::storage::Storage<V>>, - pruner: &Arc<dyn Pruner>, - ) -> Option<Arc<dyn Fn(u64) -> Trial + Send + Sync>> - where - V: 'static, - { - // Try to downcast the storage's trial buffer to the f64 specialization. - // This succeeds only when V = f64, enabling automatic sampler integration. 
- let trials_arc = storage.trials_arc(); - let any_ref: &dyn Any = trials_arc; - let f64_trials: Option<&Arc<RwLock<Vec<CompletedTrial<f64>>>>> = any_ref.downcast_ref(); - - f64_trials.map(|trials| { - let sampler = Arc::clone(sampler); - let trials = Arc::clone(trials); - let pruner = Arc::clone(pruner); - let factory: Arc<dyn Fn(u64) -> Trial + Send + Sync> = Arc::new(move |id| { - Trial::with_sampler( - id, - Arc::clone(&sampler), - Arc::clone(&trials), - Arc::clone(&pruner), - ) - }); - factory - }) - } - - /// Create a study with a custom sampler and storage backend. - /// - /// This is the most general constructor — all other constructors - /// delegate to this one. Use it when you need a non-default storage - /// backend (e.g., [`JournalStorage`](crate::storage::JournalStorage)). - /// - /// # Arguments - /// - /// * `direction` - Whether to minimize or maximize the objective function. - /// * `sampler` - The sampler to use for parameter sampling. - /// * `storage` - The storage backend for completed trials. - /// - /// # Examples - /// - /// ``` - /// use optimizer::sampler::random::RandomSampler; - /// use optimizer::storage::MemoryStorage; - /// use optimizer::{Direction, Study}; - /// - /// let storage = MemoryStorage::<f64>::new(); - /// let study = Study::with_sampler_and_storage(Direction::Minimize, RandomSampler::new(), storage); - /// ``` - pub fn with_sampler_and_storage( - direction: Direction, - sampler: impl Sampler + 'static, - storage: impl crate::storage::Storage<V> + 'static, - ) -> Self - where - V: 'static, - { - let sampler: Arc<dyn Sampler> = Arc::new(sampler); - let pruner: Arc<dyn Pruner> = Arc::new(NopPruner); - let storage: Arc<dyn crate::storage::Storage<V>> = Arc::new(storage); - let trial_factory = Self::make_trial_factory(&sampler, &storage, &pruner); - - Self { - direction, - sampler, - pruner, - storage, - trial_factory, - enqueued_params: Arc::new(Mutex::new(VecDeque::new())), - } - } - - /// Return the optimization direction. 
- #[must_use] - pub fn direction(&self) -> Direction { - self.direction - } - - /// Creates a study with a custom sampler and pruner. - /// - /// Uses the default [`MemoryStorage`](crate::storage::MemoryStorage) backend. - /// - /// # Arguments - /// - /// * `direction` - Whether to minimize or maximize the objective function. - /// * `sampler` - The sampler to use for parameter sampling. - /// * `pruner` - The pruner to use for trial pruning. - /// - /// # Examples - /// - /// ``` - /// use optimizer::pruner::NopPruner; - /// use optimizer::sampler::random::RandomSampler; - /// use optimizer::{Direction, Study}; - /// - /// let sampler = RandomSampler::with_seed(42); - /// let study: Study<f64> = Study::with_sampler_and_pruner(Direction::Minimize, sampler, NopPruner); - /// ``` - pub fn with_sampler_and_pruner( - direction: Direction, - sampler: impl Sampler + 'static, - pruner: impl Pruner + 'static, - ) -> Self - where - V: Send + Sync + 'static, - { - let sampler: Arc<dyn Sampler> = Arc::new(sampler); - let pruner: Arc<dyn Pruner> = Arc::new(pruner); - let storage: Arc<dyn crate::storage::Storage<V>> = - Arc::new(crate::storage::MemoryStorage::<V>::new()); - let trial_factory = Self::make_trial_factory(&sampler, &storage, &pruner); - - Self { - direction, - sampler, - pruner, - storage, - trial_factory, - enqueued_params: Arc::new(Mutex::new(VecDeque::new())), - } - } - - /// Replace the sampler used for future parameter suggestions. - /// - /// The new sampler takes effect for all subsequent calls to - /// [`create_trial`](Self::create_trial), [`ask`](Self::ask), and the - /// `optimize*` family. Already-completed trials are unaffected. 
- /// - /// # Examples - /// - /// ``` - /// use optimizer::sampler::tpe::TpeSampler; - /// use optimizer::{Direction, Study}; - /// - /// let mut study: Study<f64> = Study::new(Direction::Minimize); - /// study.set_sampler(TpeSampler::new()); - /// ``` - pub fn set_sampler(&mut self, sampler: impl Sampler + 'static) - where - V: 'static, - { - self.sampler = Arc::new(sampler); - self.trial_factory = Self::make_trial_factory(&self.sampler, &self.storage, &self.pruner); - } - - /// Replace the pruner used for future trials. - /// - /// The new pruner takes effect for all trials created after this call. - /// - /// # Examples - /// - /// ``` - /// use optimizer::prelude::*; - /// - /// let mut study: Study<f64> = Study::new(Direction::Minimize); - /// study.set_pruner(MedianPruner::new(Direction::Minimize)); - /// ``` - pub fn set_pruner(&mut self, pruner: impl Pruner + 'static) - where - V: 'static, - { - self.pruner = Arc::new(pruner); - self.trial_factory = Self::make_trial_factory(&self.sampler, &self.storage, &self.pruner); - } - - /// Return a reference to the study's current pruner. - #[must_use] - pub fn pruner(&self) -> &dyn Pruner { - &*self.pruner - } - - /// Enqueue a specific parameter configuration to be evaluated next. - /// - /// The next call to [`ask()`](Self::ask) or the next trial in [`optimize()`](Self::optimize) - /// will use these exact parameters instead of sampling from the sampler. - /// - /// Multiple configurations can be enqueued; they are evaluated in FIFO order. - /// If an enqueued configuration is missing a parameter that the objective calls - /// `suggest()` on, that parameter falls back to normal sampling. - /// - /// # Arguments - /// - /// * `params` - A map from parameter IDs to the values to use. 
- /// - /// # Examples - /// - /// ``` - /// use std::collections::HashMap; - /// - /// use optimizer::parameter::{FloatParam, IntParam, ParamValue, Parameter}; - /// use optimizer::{Direction, Study}; - /// - /// let study: Study<f64> = Study::new(Direction::Minimize); - /// let x = FloatParam::new(0.0, 10.0); - /// let y = IntParam::new(1, 100); - /// - /// // Evaluate these specific configurations first - /// study.enqueue(HashMap::from([ - /// (x.id(), ParamValue::Float(0.001)), - /// (y.id(), ParamValue::Int(3)), - /// ])); - /// - /// // Next trial will use x=0.001, y=3 - /// let mut trial = study.ask(); - /// assert_eq!(x.suggest(&mut trial).unwrap(), 0.001); - /// assert_eq!(y.suggest(&mut trial).unwrap(), 3); - /// ``` - pub fn enqueue(&self, params: HashMap<ParamId, ParamValue>) { - self.enqueued_params.lock().push_back(params); - } - - /// Return the trial ID of the current best trial from the given slice. - #[cfg(feature = "tracing")] - fn best_id(&self, trials: &[CompletedTrial<V>]) -> Option<u64> { - let direction = self.direction; - trials - .iter() - .filter(|t| t.state == TrialState::Complete) - .max_by(|a, b| Self::compare_trials(a, b, direction)) - .map(|t| t.id) - } - - /// Return the number of enqueued parameter configurations. - /// - /// See [`enqueue`](Self::enqueue) for how to add configurations. - #[must_use] - pub fn n_enqueued(&self) -> usize { - self.enqueued_params.lock().len() - } - - /// Generate the next unique trial ID. - pub(crate) fn next_trial_id(&self) -> u64 { - self.storage.next_trial_id() - } - - /// Create a new trial with a unique ID. - /// - /// The trial starts in the `Running` state and can be used to suggest - /// parameter values. After the objective function is evaluated, call - /// `complete_trial` or `fail_trial` to record the result. 
- /// - /// For `Study<f64>`, this method automatically integrates with the study's - /// sampler and trial history, so there is no need to call a separate - /// `create_trial_with_sampler()` method. - /// - /// # Examples - /// - /// ``` - /// use optimizer::{Direction, Study}; - /// - /// let study: Study<f64> = Study::new(Direction::Minimize); - /// let trial = study.create_trial(); - /// assert_eq!(trial.id(), 0); - /// - /// let trial2 = study.create_trial(); - /// assert_eq!(trial2.id(), 1); - /// ``` - #[must_use] - pub fn create_trial(&self) -> Trial { - self.storage.refresh(); - - let id = self.next_trial_id(); - let mut trial = if let Some(factory) = &self.trial_factory { - factory(id) - } else { - Trial::new(id) - }; - - // If there are enqueued params, inject them into this trial - if let Some(fixed_params) = self.enqueued_params.lock().pop_front() { - trial.set_fixed_params(fixed_params); - } - - trial - } - - /// Record a completed trial with its objective value. - /// - /// This method stores the trial's parameters, distributions, and objective - /// value in the study's history. The stored data is used by samplers to - /// inform future parameter suggestions. - /// - /// # Arguments - /// - /// * `trial` - The trial that was evaluated. - /// * `value` - The objective value returned by the objective function. 
- /// - /// # Examples - /// - /// ``` - /// use optimizer::parameter::{FloatParam, Parameter}; - /// use optimizer::{Direction, Study}; - /// - /// let study: Study<f64> = Study::new(Direction::Minimize); - /// let x_param = FloatParam::new(0.0, 1.0); - /// let mut trial = study.create_trial(); - /// let x = x_param.suggest(&mut trial).unwrap(); - /// let objective_value = x * x; - /// study.complete_trial(trial, objective_value); - /// - /// assert_eq!(study.n_trials(), 1); - /// ``` - pub fn complete_trial(&self, trial: Trial, value: V) { - let completed = trial.into_completed(value, TrialState::Complete); - self.storage.push(completed); - } - - /// Record a failed trial with an error message. - /// - /// Failed trials are not stored in the study's history and do not - /// contribute to future sampling decisions. This method is useful - /// when the objective function raises an error that should not stop - /// the optimization process. - /// - /// # Arguments - /// - /// * `trial` - The trial that failed. - /// * `_error` - An error message describing why the trial failed. - /// - /// # Examples - /// - /// ``` - /// use optimizer::{Direction, Study}; - /// - /// let study: Study<f64> = Study::new(Direction::Minimize); - /// let trial = study.create_trial(); - /// study.fail_trial(trial, "objective function raised an exception"); - /// - /// // Failed trials are not counted - /// assert_eq!(study.n_trials(), 0); - /// ``` - pub fn fail_trial(&self, mut trial: Trial, _error: impl ToString) { - trial.set_failed(); - // Failed trials are not stored in completed_trials - // They could be stored in a separate list for debugging if needed - } - - /// Request a new trial with suggested parameters. - /// - /// This is the first half of the ask-and-tell interface. After calling - /// `ask()`, use parameter types to suggest values on the returned trial, - /// evaluate your objective externally, then pass the trial back to - /// [`tell()`](Self::tell) with the result. 
- /// - /// # Examples - /// - /// ``` - /// use optimizer::parameter::{FloatParam, Parameter}; - /// use optimizer::{Direction, Study}; - /// - /// let study: Study<f64> = Study::new(Direction::Minimize); - /// let x = FloatParam::new(0.0, 10.0); - /// - /// let mut trial = study.ask(); - /// let x_val = x.suggest(&mut trial).unwrap(); - /// let value = x_val * x_val; - /// study.tell(trial, Ok::<_, &str>(value)); - /// ``` - #[must_use] - pub fn ask(&self) -> Trial { - self.create_trial() - } - - /// Report the result of a trial obtained from [`ask()`](Self::ask). - /// - /// Pass `Ok(value)` for a successful evaluation or `Err(reason)` for a - /// failure. Failed trials are not stored in the study's history. - /// - /// # Examples - /// - /// ``` - /// use optimizer::{Direction, Study}; - /// - /// let study: Study<f64> = Study::new(Direction::Minimize); - /// - /// let trial = study.ask(); - /// study.tell(trial, Ok::<_, &str>(42.0)); - /// assert_eq!(study.n_trials(), 1); - /// - /// let trial = study.ask(); - /// study.tell(trial, Err::<f64, _>("evaluation failed")); - /// assert_eq!(study.n_trials(), 1); // failed trials not counted - /// ``` - pub fn tell(&self, trial: Trial, value: core::result::Result<V, impl ToString>) { - match value { - Ok(v) => self.complete_trial(trial, v), - Err(e) => self.fail_trial(trial, e), - } - } - - /// Record a pruned trial, preserving its intermediate values. - /// - /// Pruned trials are stored alongside completed trials so that samplers - /// can optionally learn from partial evaluations. The trial's state is - /// set to [`Pruned`](crate::TrialState::Pruned). - /// - /// In practice you rarely call this directly — returning - /// `Err(TrialPruned)` from an objective function handles pruning - /// automatically. - /// - /// # Arguments - /// - /// * `trial` - The trial that was pruned. 
- pub fn prune_trial(&self, trial: Trial) - where - V: Default, - { - let completed = trial.into_completed(V::default(), TrialState::Pruned); - self.storage.push(completed); - } - - /// Return all completed trials as a `Vec`. - /// - /// The returned vector contains clones of `CompletedTrial` values, which contain - /// the trial's parameters, distributions, and objective value. - /// - /// Note: This method acquires a read lock on the completed trials, so the - /// returned vector is a clone of the internal storage. - /// - /// # Examples - /// - /// ``` - /// use optimizer::parameter::{FloatParam, Parameter}; - /// use optimizer::{Direction, Study}; - /// - /// let study: Study<f64> = Study::new(Direction::Minimize); - /// let x_param = FloatParam::new(0.0, 1.0); - /// let mut trial = study.create_trial(); - /// let _ = x_param.suggest(&mut trial); - /// study.complete_trial(trial, 0.5); - /// - /// for completed in study.trials() { - /// println!("Trial {} has value {:?}", completed.id, completed.value); - /// } - /// ``` - #[must_use] - pub fn trials(&self) -> Vec<CompletedTrial<V>> - where - V: Clone, - { - self.storage.trials_arc().read().clone() - } - - /// Return the number of completed trials. - /// - /// Failed trials are not counted. - /// - /// # Examples - /// - /// ``` - /// use optimizer::parameter::{FloatParam, Parameter}; - /// use optimizer::{Direction, Study}; - /// - /// let study: Study<f64> = Study::new(Direction::Minimize); - /// assert_eq!(study.n_trials(), 0); - /// - /// let x_param = FloatParam::new(0.0, 1.0); - /// let mut trial = study.create_trial(); - /// let _ = x_param.suggest(&mut trial); - /// study.complete_trial(trial, 0.5); - /// assert_eq!(study.n_trials(), 1); - /// ``` - #[must_use] - pub fn n_trials(&self) -> usize { - self.storage.trials_arc().read().len() - } - - /// Return the number of pruned trials. - /// - /// Pruned trials are those that were stopped early by the pruner. 
- #[must_use] - pub fn n_pruned_trials(&self) -> usize { - self.storage - .trials_arc() - .read() - .iter() - .filter(|t| t.state == TrialState::Pruned) - .count() - } - - /// Compare two completed trials using constraint-aware ranking. - /// - /// 1. Feasible trials always rank above infeasible trials. - /// 2. Among feasible trials, rank by objective value (respecting direction). - /// 3. Among infeasible trials, rank by total constraint violation (lower is better). - fn compare_trials( - a: &CompletedTrial<V>, - b: &CompletedTrial<V>, - direction: Direction, - ) -> core::cmp::Ordering { - match (a.is_feasible(), b.is_feasible()) { - (true, false) => core::cmp::Ordering::Greater, - (false, true) => core::cmp::Ordering::Less, - (false, false) => { - let va: f64 = a.constraints.iter().map(|c| c.max(0.0)).sum(); - let vb: f64 = b.constraints.iter().map(|c| c.max(0.0)).sum(); - vb.partial_cmp(&va).unwrap_or(core::cmp::Ordering::Equal) - } - (true, true) => { - let ordering = a.value.partial_cmp(&b.value); - match direction { - Direction::Minimize => { - ordering.map_or(core::cmp::Ordering::Equal, core::cmp::Ordering::reverse) - } - Direction::Maximize => ordering.unwrap_or(core::cmp::Ordering::Equal), - } - } - } - } - - /// Return the trial with the best objective value. - /// - /// The "best" trial depends on the optimization direction: - /// - `Direction::Minimize`: Returns the trial with the lowest objective value. - /// - `Direction::Maximize`: Returns the trial with the highest objective value. - /// - /// When constraints are present, feasible trials always rank above infeasible - /// trials. Among infeasible trials, those with lower total constraint violation - /// are preferred. - /// - /// # Errors - /// - /// Returns `Error::NoCompletedTrials` if no trials have been completed. 
- /// - /// # Examples - /// - /// ``` - /// use optimizer::parameter::{FloatParam, Parameter}; - /// use optimizer::{Direction, Study}; - /// - /// let study: Study<f64> = Study::new(Direction::Minimize); - /// - /// // Error when no trials completed - /// assert!(study.best_trial().is_err()); - /// - /// let x_param = FloatParam::new(0.0, 1.0); - /// - /// let mut trial1 = study.create_trial(); - /// let _ = x_param.suggest(&mut trial1); - /// study.complete_trial(trial1, 0.8); - /// - /// let mut trial2 = study.create_trial(); - /// let _ = x_param.suggest(&mut trial2); - /// study.complete_trial(trial2, 0.3); - /// - /// let best = study.best_trial().unwrap(); - /// assert_eq!(best.value, 0.3); // Minimize: lower is better - /// ``` - pub fn best_trial(&self) -> crate::Result<CompletedTrial<V>> - where - V: Clone, - { - let trials = self.storage.trials_arc().read(); - let direction = self.direction; - - let best = trials - .iter() - .filter(|t| t.state == TrialState::Complete) - .max_by(|a, b| Self::compare_trials(a, b, direction)) - .ok_or(crate::Error::NoCompletedTrials)?; - - Ok(best.clone()) - } - - /// Return the best objective value found so far. - /// - /// The "best" value depends on the optimization direction: - /// - `Direction::Minimize`: Returns the lowest objective value. - /// - `Direction::Maximize`: Returns the highest objective value. - /// - /// # Errors - /// - /// Returns `Error::NoCompletedTrials` if no trials have been completed. 
- /// - /// # Examples - /// - /// ``` - /// use optimizer::parameter::{FloatParam, Parameter}; - /// use optimizer::{Direction, Study}; - /// - /// let study: Study<f64> = Study::new(Direction::Maximize); - /// - /// // Error when no trials completed - /// assert!(study.best_value().is_err()); - /// - /// let x_param = FloatParam::new(0.0, 1.0); - /// - /// let mut trial1 = study.create_trial(); - /// let _ = x_param.suggest(&mut trial1); - /// study.complete_trial(trial1, 0.3); - /// - /// let mut trial2 = study.create_trial(); - /// let _ = x_param.suggest(&mut trial2); - /// study.complete_trial(trial2, 0.8); - /// - /// let best = study.best_value().unwrap(); - /// assert_eq!(best, 0.8); // Maximize: higher is better - /// ``` - pub fn best_value(&self) -> crate::Result<V> - where - V: Clone, - { - self.best_trial().map(|trial| trial.value) - } - - /// Return the top `n` trials sorted by objective value. - /// - /// For `Direction::Minimize`, returns trials with the lowest values. - /// For `Direction::Maximize`, returns trials with the highest values. - /// Only includes completed trials (not failed or pruned). - /// - /// If fewer than `n` completed trials exist, returns all of them. 
- /// - /// # Examples - /// - /// ``` - /// use optimizer::parameter::{FloatParam, Parameter}; - /// use optimizer::{Direction, Study}; - /// - /// let study: Study<f64> = Study::new(Direction::Minimize); - /// let x = FloatParam::new(0.0, 10.0); - /// - /// for val in [5.0, 1.0, 3.0] { - /// let mut t = study.create_trial(); - /// let _ = x.suggest(&mut t); - /// study.complete_trial(t, val); - /// } - /// - /// let top2 = study.top_trials(2); - /// assert_eq!(top2.len(), 2); - /// assert!(top2[0].value <= top2[1].value); - /// ``` - #[must_use] - pub fn top_trials(&self, n: usize) -> Vec<CompletedTrial<V>> - where - V: Clone, - { - let trials = self.storage.trials_arc().read(); - let direction = self.direction; - // Sort indices instead of cloning all trials, then clone only the top N. - let mut indices: Vec<usize> = trials - .iter() - .enumerate() - .filter(|(_, t)| t.state == TrialState::Complete) - .map(|(i, _)| i) - .collect(); - // Sort best-first: reverse the compare_trials ordering (which is designed for max_by) - indices.sort_by(|&a, &b| Self::compare_trials(&trials[b], &trials[a], direction)); - indices.truncate(n); - indices.iter().map(|&i| trials[i].clone()).collect() - } - - /// Run optimization with an objective. - /// - /// Accepts any [`Objective`](crate::Objective) implementation, including - /// plain closures (`Fn(&mut Trial) -> Result<V, E>`) thanks to the - /// blanket impl. Struct-based objectives can override - /// [`before_trial`](crate::Objective::before_trial) and - /// [`after_trial`](crate::Objective::after_trial) for early stopping. - /// - /// Runs up to `n_trials` evaluations sequentially. - /// - /// # Errors - /// - /// Returns `Error::NoCompletedTrials` if no trials completed successfully. 
- /// - /// # Examples - /// - /// ``` - /// use optimizer::parameter::{FloatParam, Parameter}; - /// use optimizer::sampler::random::RandomSampler; - /// use optimizer::{Direction, Study}; - /// - /// let sampler = RandomSampler::with_seed(42); - /// let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - /// let x_param = FloatParam::new(-10.0, 10.0); - /// - /// study - /// .optimize(10, |trial: &mut optimizer::Trial| { - /// let x = x_param.suggest(trial)?; - /// Ok::<_, optimizer::Error>(x * x) - /// }) - /// .unwrap(); - /// - /// assert!(study.n_trials() > 0); - /// assert!(study.best_value().unwrap() >= 0.0); - /// ``` - #[allow(clippy::needless_pass_by_value)] - pub fn optimize( - &self, - n_trials: usize, - objective: impl crate::objective::Objective<V>, - ) -> crate::Result<()> - where - V: Clone + Default, - { - #[cfg(feature = "tracing")] - let _span = - tracing::info_span!("optimize", n_trials, direction = ?self.direction).entered(); - - for _ in 0..n_trials { - if let ControlFlow::Break(()) = objective.before_trial(self) { - break; - } - - let mut trial = self.create_trial(); - match objective.evaluate(&mut trial) { - Ok(value) => { - #[cfg(feature = "tracing")] - let trial_id = trial.id(); - - let completed = trial.into_completed(value, TrialState::Complete); - - // Fire after_trial hook before pushing to storage - let flow = objective.after_trial(self, &completed); - self.storage.push(completed); - - #[cfg(feature = "tracing")] - { - tracing::info!(trial_id, "trial completed"); - let trials = self.storage.trials_arc().read(); - if trials - .iter() - .filter(|t| t.state == TrialState::Complete) - .count() - == 1 - || trials.last().map(|t| t.id) == self.best_id(&trials) - { - tracing::info!(trial_id, "new best value found"); - } - } - - if let ControlFlow::Break(()) = flow { - return Ok(()); - } - } - Err(e) if is_trial_pruned(&e) => { - #[cfg(feature = "tracing")] - let trial_id = trial.id(); - self.prune_trial(trial); - 
trace_info!(trial_id, "trial pruned"); - } - Err(e) => { - #[cfg(feature = "tracing")] - let trial_id = trial.id(); - self.fail_trial(trial, e.to_string()); - trace_debug!(trial_id, "trial failed"); - } - } - } - - // Return error if no trials completed successfully - let has_complete = self - .storage - .trials_arc() - .read() - .iter() - .any(|t| t.state == TrialState::Complete); - if !has_complete { - return Err(crate::Error::NoCompletedTrials); - } - - Ok(()) - } - - /// Run async optimization with an objective. - /// - /// Like [`optimize`](Self::optimize), but each evaluation is wrapped in - /// [`spawn_blocking`](tokio::task::spawn_blocking), keeping the async - /// runtime responsive for CPU-bound objectives. Trials run sequentially. - /// - /// Accepts any [`Objective`](crate::Objective) implementation, including - /// plain closures. Struct-based objectives can override lifecycle hooks. - /// - /// # Errors - /// - /// Returns `Error::NoCompletedTrials` if no trials completed successfully. - /// Returns `Error::TaskError` if a spawned blocking task panics. 
- /// - /// # Examples - /// - /// ``` - /// use optimizer::parameter::{FloatParam, Parameter}; - /// use optimizer::sampler::random::RandomSampler; - /// use optimizer::{Direction, Study}; - /// - /// # #[cfg(feature = "async")] - /// # async fn example() -> optimizer::Result<()> { - /// let sampler = RandomSampler::with_seed(42); - /// let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - /// let x_param = FloatParam::new(-10.0, 10.0); - /// - /// study - /// .optimize_async(10, move |trial: &mut optimizer::Trial| { - /// let x = x_param.suggest(trial)?; - /// Ok::<_, optimizer::Error>(x * x) - /// }) - /// .await?; - /// - /// assert!(study.n_trials() > 0); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "async")] - pub async fn optimize_async<O>(&self, n_trials: usize, objective: O) -> crate::Result<()> - where - O: crate::objective::Objective<V> + Send + Sync + 'static, - O::Error: Send, - V: Clone + Default + Send + 'static, - { - #[cfg(feature = "tracing")] - let _span = - tracing::info_span!("optimize_async", n_trials, direction = ?self.direction).entered(); - - let objective = Arc::new(objective); - - for _ in 0..n_trials { - if let ControlFlow::Break(()) = objective.before_trial(self) { - break; - } - - let obj = Arc::clone(&objective); - let mut trial = self.create_trial(); - let result = tokio::task::spawn_blocking(move || { - let res = obj.evaluate(&mut trial); - (trial, res) - }) - .await - .map_err(|e| crate::Error::TaskError(e.to_string()))?; - - match result { - (t, Ok(value)) => { - #[cfg(feature = "tracing")] - let trial_id = t.id(); - - let completed = t.into_completed(value, TrialState::Complete); - let flow = objective.after_trial(self, &completed); - self.storage.push(completed); - trace_info!(trial_id, "trial completed"); - - if let ControlFlow::Break(()) = flow { - return Ok(()); - } - } - (t, Err(e)) if is_trial_pruned(&e) => { - #[cfg(feature = "tracing")] - let trial_id = t.id(); - self.prune_trial(t); - 
trace_info!(trial_id, "trial pruned"); - } - (t, Err(e)) => { - #[cfg(feature = "tracing")] - let trial_id = t.id(); - self.fail_trial(t, e.to_string()); - trace_debug!(trial_id, "trial failed"); - } - } - } - - let has_complete = self - .storage - .trials_arc() - .read() - .iter() - .any(|t| t.state == TrialState::Complete); - if !has_complete { - return Err(crate::Error::NoCompletedTrials); - } - - Ok(()) - } - - /// Run parallel optimization with an objective. - /// - /// Spawns up to `concurrency` evaluations concurrently using - /// [`spawn_blocking`](tokio::task::spawn_blocking). Results are - /// collected via a [`JoinSet`](tokio::task::JoinSet). - /// - /// Accepts any [`Objective`](crate::Objective) implementation, including - /// plain closures. The [`after_trial`](crate::Objective::after_trial) - /// hook fires as each result arrives — returning `Break` stops spawning - /// new trials while in-flight tasks drain. - /// - /// # Errors - /// - /// Returns `Error::NoCompletedTrials` if no trials completed successfully. - /// Returns `Error::TaskError` if the semaphore is closed or a spawned task panics. 
- /// - /// # Examples - /// - /// ``` - /// use optimizer::parameter::{FloatParam, Parameter}; - /// use optimizer::sampler::random::RandomSampler; - /// use optimizer::{Direction, Study}; - /// - /// # #[cfg(feature = "async")] - /// # async fn example() -> optimizer::Result<()> { - /// let sampler = RandomSampler::with_seed(42); - /// let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); - /// let x_param = FloatParam::new(-10.0, 10.0); - /// - /// study - /// .optimize_parallel(10, 4, move |trial: &mut optimizer::Trial| { - /// let x = x_param.suggest(trial)?; - /// Ok::<_, optimizer::Error>(x * x) - /// }) - /// .await?; - /// - /// assert_eq!(study.n_trials(), 10); - /// # Ok(()) - /// # } - /// ``` - #[cfg(feature = "async")] - #[allow(clippy::missing_panics_doc, clippy::too_many_lines)] - pub async fn optimize_parallel<O>( - &self, - n_trials: usize, - concurrency: usize, - objective: O, - ) -> crate::Result<()> - where - O: crate::objective::Objective<V> + Send + Sync + 'static, - O::Error: Send, - V: Clone + Default + Send + 'static, - { - use tokio::sync::Semaphore; - use tokio::task::JoinSet; - - #[cfg(feature = "tracing")] - let _span = tracing::info_span!("optimize_parallel", n_trials, concurrency, direction = ?self.direction).entered(); - - let objective = Arc::new(objective); - let semaphore = Arc::new(Semaphore::new(concurrency)); - let mut join_set: JoinSet<(Trial, Result<V, O::Error>)> = JoinSet::new(); - let mut spawned = 0; - - 'spawn: while spawned < n_trials { - if let ControlFlow::Break(()) = objective.before_trial(self) { - break; - } - - // If the join set is full, drain one result to free a slot. 
- while join_set.len() >= concurrency { - let result = join_set - .join_next() - .await - .expect("join_set should not be empty") - .map_err(|e| crate::Error::TaskError(e.to_string()))?; - match result { - (t, Ok(value)) => { - #[cfg(feature = "tracing")] - let trial_id = t.id(); - - let completed = t.into_completed(value, TrialState::Complete); - let flow = objective.after_trial(self, &completed); - self.storage.push(completed); - trace_info!(trial_id, "trial completed"); - - if let ControlFlow::Break(()) = flow { - break 'spawn; - } - } - (t, Err(e)) => { - #[cfg(feature = "tracing")] - let trial_id = t.id(); - if is_trial_pruned(&e) { - self.prune_trial(t); - trace_info!(trial_id, "trial pruned"); - } else { - self.fail_trial(t, e.to_string()); - trace_debug!(trial_id, "trial failed"); - } - } - } - } - - let permit = semaphore - .clone() - .acquire_owned() - .await - .map_err(|e| crate::Error::TaskError(e.to_string()))?; - - let mut trial = self.create_trial(); - let obj = Arc::clone(&objective); - join_set.spawn(async move { - let result = tokio::task::spawn_blocking(move || { - let res = obj.evaluate(&mut trial); - (trial, res) - }) - .await - .expect("spawn_blocking should not panic"); - drop(permit); - result - }); - spawned += 1; - } - - // Drain remaining in-flight tasks. - while let Some(result) = join_set.join_next().await { - let result = result.map_err(|e| crate::Error::TaskError(e.to_string()))?; - match result { - (t, Ok(value)) => { - #[cfg(feature = "tracing")] - let trial_id = t.id(); - - let completed = t.into_completed(value, TrialState::Complete); - // Still fire after_trial for bookkeeping, but don't break — we're draining. 
- let _ = objective.after_trial(self, &completed); - self.storage.push(completed); - trace_info!(trial_id, "trial completed"); - } - (t, Err(e)) => { - #[cfg(feature = "tracing")] - let trial_id = t.id(); - if is_trial_pruned(&e) { - self.prune_trial(t); - trace_info!(trial_id, "trial pruned"); - } else { - self.fail_trial(t, e.to_string()); - trace_debug!(trial_id, "trial failed"); - } - } - } - } - - let has_complete = self - .storage - .trials_arc() - .read() - .iter() - .any(|t| t.state == TrialState::Complete); - if !has_complete { - return Err(crate::Error::NoCompletedTrials); - } - - Ok(()) - } -} - -impl<V> Study<V> -where - V: PartialOrd + Clone + fmt::Display, -{ - /// Write completed trials to a writer in CSV format. - /// - /// Columns: `trial_id`, `value`, `state`, then one column per unique - /// parameter label, then one column per unique user-attribute key. - /// - /// Parameters without labels use a generated name (`param_<id>`). - /// Pruned trials have an empty `value` cell. - /// - /// # Errors - /// - /// Returns an I/O error if writing fails. - /// - /// # Examples - /// - /// ``` - /// use optimizer::parameter::{FloatParam, Parameter}; - /// use optimizer::{Direction, Study}; - /// - /// let study: Study<f64> = Study::new(Direction::Minimize); - /// let x = FloatParam::new(0.0, 10.0).name("x"); - /// - /// let mut trial = study.create_trial(); - /// let _ = x.suggest(&mut trial); - /// study.complete_trial(trial, 0.42); - /// - /// let mut buf = Vec::new(); - /// study.to_csv(&mut buf).unwrap(); - /// let csv = String::from_utf8(buf).unwrap(); - /// assert!(csv.contains("trial_id")); - /// ``` - pub fn to_csv(&self, mut writer: impl std::io::Write) -> std::io::Result<()> { - use std::collections::BTreeMap; - - let trials = self.storage.trials_arc().read(); - - // Collect all unique parameter labels (sorted for deterministic column order). 
- let mut param_columns: BTreeMap<ParamId, String> = BTreeMap::new(); - for trial in trials.iter() { - for &id in trial.params.keys() { - param_columns.entry(id).or_insert_with(|| { - trial - .param_labels - .get(&id) - .cloned() - .unwrap_or_else(|| id.to_string()) - }); - } - } - // Fill in labels from other trials that might have better labels. - for trial in trials.iter() { - for (&id, label) in &trial.param_labels { - param_columns.entry(id).or_insert_with(|| label.clone()); - } - } - - // Collect all unique attribute keys (sorted). - let mut attr_keys: Vec<String> = Vec::new(); - for trial in trials.iter() { - for key in trial.user_attrs.keys() { - if !attr_keys.contains(key) { - attr_keys.push(key.clone()); - } - } - } - attr_keys.sort(); - - let param_ids: Vec<ParamId> = param_columns.keys().copied().collect(); - - // Write header. - write!(writer, "trial_id,value,state")?; - for id in ¶m_ids { - write!(writer, ",{}", csv_escape(¶m_columns[id]))?; - } - for key in &attr_keys { - write!(writer, ",{}", csv_escape(key))?; - } - writeln!(writer)?; - - // Write one row per trial. - for trial in trials.iter() { - write!(writer, "{}", trial.id)?; - - // Value: empty for pruned trials. - if trial.state == TrialState::Complete { - write!(writer, ",{}", trial.value)?; - } else { - write!(writer, ",")?; - } - - write!( - writer, - ",{}", - match trial.state { - TrialState::Complete => "Complete", - TrialState::Pruned => "Pruned", - TrialState::Failed => "Failed", - TrialState::Running => "Running", - } - )?; - - for id in ¶m_ids { - if let Some(pv) = trial.params.get(id) { - write!(writer, ",{pv}")?; - } else { - write!(writer, ",")?; - } - } - - for key in &attr_keys { - if let Some(attr) = trial.user_attrs.get(key) { - write!(writer, ",{}", csv_escape(&format_attr(attr)))?; - } else { - write!(writer, ",")?; - } - } - - writeln!(writer)?; - } - - Ok(()) - } - - /// Export completed trials to a CSV file at the given path. 
- /// - /// Convenience wrapper around [`to_csv`](Self::to_csv) that creates a - /// buffered file writer. - /// - /// # Errors - /// - /// Returns an I/O error if the file cannot be created or written. - pub fn export_csv(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> { - let file = std::fs::File::create(path)?; - self.to_csv(std::io::BufWriter::new(file)) - } - - /// Return a human-readable summary of the study. - /// - /// The summary includes: - /// - Optimization direction and total trial count - /// - Breakdown by state (complete, pruned) when applicable - /// - Best trial value and parameters (if any completed trials exist) - /// - /// # Examples - /// - /// ``` - /// use optimizer::parameter::{FloatParam, Parameter}; - /// use optimizer::{Direction, Study}; - /// - /// let study: Study<f64> = Study::new(Direction::Minimize); - /// let x = FloatParam::new(0.0, 10.0).name("x"); - /// - /// let mut trial = study.create_trial(); - /// let _ = x.suggest(&mut trial).unwrap(); - /// study.complete_trial(trial, 0.42); - /// - /// let summary = study.summary(); - /// assert!(summary.contains("Minimize")); - /// assert!(summary.contains("0.42")); - /// ``` - #[must_use] - pub fn summary(&self) -> String { - use fmt::Write; - - let trials = self.storage.trials_arc().read(); - let n_complete = trials - .iter() - .filter(|t| t.state == TrialState::Complete) - .count(); - let n_pruned = trials - .iter() - .filter(|t| t.state == TrialState::Pruned) - .count(); - - let direction_str = match self.direction { - Direction::Minimize => "Minimize", - Direction::Maximize => "Maximize", - }; - - let mut s = format!("Study: {direction_str} | {n} trials", n = trials.len()); - if n_pruned > 0 { - let _ = write!(s, " ({n_complete} complete, {n_pruned} pruned)"); - } - - drop(trials); - - if let Ok(best) = self.best_trial() { - let _ = write!(s, "\nBest value: {} (trial #{})", best.value, best.id); - if !best.params.is_empty() { - s.push_str("\nBest parameters:"); - 
let mut params: Vec<_> = best.params.iter().collect(); - params.sort_by_key(|(id, _)| *id); - for (id, value) in params { - let label = best.param_labels.get(id).map_or("?", String::as_str); - let _ = write!(s, "\n {label} = {value}"); - } - } - } - - s - } -} - -impl<V> Study<V> -where - V: PartialOrd + Clone, -{ - /// Return an iterator over all completed trials. - /// - /// This clones the internal trial list, so it is suitable for - /// analysis and iteration but not for hot paths. - /// - /// # Examples - /// - /// ``` - /// use optimizer::{Direction, Study}; - /// - /// let study: Study<f64> = Study::new(Direction::Minimize); - /// let trial = study.create_trial(); - /// study.complete_trial(trial, 1.0); - /// - /// for t in study.iter() { - /// println!("Trial {} → {}", t.id, t.value); - /// } - /// ``` - #[must_use] - pub fn iter(&self) -> std::vec::IntoIter<CompletedTrial<V>> { - self.trials().into_iter() - } -} - -impl<V> Study<V> -where - V: PartialOrd + Clone + Into<f64>, -{ - /// Compute parameter importance scores using Spearman rank correlation. - /// - /// For each parameter, the absolute Spearman correlation between its values - /// and the objective values is computed across all completed trials. Scores - /// are normalized so they sum to 1.0 and sorted in descending order. - /// - /// Parameters that appear in fewer than 2 trials are omitted. - /// Returns an empty `Vec` if the study has fewer than 2 completed trials. 
- /// - /// # Examples - /// - /// ``` - /// use optimizer::parameter::{FloatParam, Parameter}; - /// use optimizer::{Direction, Study}; - /// - /// let study: Study<f64> = Study::new(Direction::Minimize); - /// let x = FloatParam::new(0.0, 10.0).name("x"); - /// - /// study - /// .optimize(20, |trial: &mut optimizer::Trial| { - /// let xv = x.suggest(trial)?; - /// Ok::<_, optimizer::Error>(xv * xv) - /// }) - /// .unwrap(); - /// - /// let importance = study.param_importance(); - /// assert_eq!(importance.len(), 1); - /// assert_eq!(importance[0].0, "x"); - /// ``` - #[must_use] - #[allow(clippy::cast_precision_loss)] - pub fn param_importance(&self) -> Vec<(String, f64)> { - use std::collections::BTreeSet; - - use crate::importance::spearman; - use crate::param::ParamValue; - use crate::types::TrialState; - - let trials = self.storage.trials_arc().read(); - let complete: Vec<_> = trials - .iter() - .filter(|t| t.state == TrialState::Complete) - .collect(); - - if complete.len() < 2 { - return Vec::new(); - } - - // Collect all parameter IDs across trials. - let all_param_ids: BTreeSet<_> = complete.iter().flat_map(|t| t.params.keys()).collect(); - - let mut scores: Vec<(String, f64)> = Vec::with_capacity(all_param_ids.len()); - - for ¶m_id in &all_param_ids { - // Collect (param_value_f64, objective_f64) for trials that have this param. - let mut param_vals = Vec::with_capacity(complete.len()); - let mut obj_vals = Vec::with_capacity(complete.len()); - - for trial in &complete { - if let Some(pv) = trial.params.get(param_id) { - let f = match *pv { - ParamValue::Float(v) => v, - ParamValue::Int(v) => v as f64, - ParamValue::Categorical(v) => v as f64, - }; - param_vals.push(f); - obj_vals.push(trial.value.clone().into()); - } - } - - if param_vals.len() < 2 { - continue; - } - - let corr = spearman(¶m_vals, &obj_vals).abs(); - - // Determine label: use param_labels if available, else "param_{id}". 
- let label = complete - .iter() - .find_map(|t| t.param_labels.get(param_id)) - .map_or_else(|| param_id.to_string(), Clone::clone); - - scores.push((label, corr)); - } - - // Normalize so scores sum to 1.0. - let sum: f64 = scores.iter().map(|(_, s)| *s).sum(); - if sum > 0.0 { - for entry in &mut scores { - entry.1 /= sum; - } - } - - // Sort descending by score. - scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal)); - - scores - } - - /// Compute parameter importance using fANOVA (functional ANOVA) with - /// default configuration. - /// - /// Fits a random forest to the trial data and decomposes variance into - /// per-parameter main effects and pairwise interaction effects. This is - /// more accurate than correlation-based importance ([`Self::param_importance`]) - /// and can detect non-linear relationships and parameter interactions. - /// - /// # Errors - /// - /// Returns [`crate::Error::NoCompletedTrials`] if fewer than 2 trials have completed. - /// - /// # Examples - /// - /// ``` - /// use optimizer::parameter::{FloatParam, Parameter}; - /// use optimizer::{Direction, Study}; - /// - /// let study: Study<f64> = Study::new(Direction::Minimize); - /// let x = FloatParam::new(0.0, 10.0).name("x"); - /// let y = FloatParam::new(0.0, 10.0).name("y"); - /// - /// study - /// .optimize(30, |trial: &mut optimizer::Trial| { - /// let xv = x.suggest(trial)?; - /// let yv = y.suggest(trial)?; - /// Ok::<_, optimizer::Error>(xv * xv + 0.1 * yv) - /// }) - /// .unwrap(); - /// - /// let result = study.fanova().unwrap(); - /// assert!(!result.main_effects.is_empty()); - /// ``` - pub fn fanova(&self) -> crate::Result<crate::fanova::FanovaResult> { - self.fanova_with_config(&crate::fanova::FanovaConfig::default()) - } - - /// Compute parameter importance using fANOVA with custom configuration. - /// - /// See [`Self::fanova`] for details. 
The [`FanovaConfig`](crate::fanova::FanovaConfig) - /// allows tuning the number of trees, tree depth, and random seed. - /// - /// # Errors - /// - /// Returns [`crate::Error::NoCompletedTrials`] if fewer than 2 trials have completed. - #[allow(clippy::cast_precision_loss)] - pub fn fanova_with_config( - &self, - config: &crate::fanova::FanovaConfig, - ) -> crate::Result<crate::fanova::FanovaResult> { - use std::collections::BTreeSet; - - use crate::fanova::compute_fanova; - use crate::param::ParamValue; - use crate::types::TrialState; - - let trials = self.storage.trials_arc().read(); - let complete: Vec<_> = trials - .iter() - .filter(|t| t.state == TrialState::Complete) - .collect(); - - if complete.len() < 2 { - return Err(crate::Error::NoCompletedTrials); - } - - // Collect all parameter IDs in a stable order. - let all_param_ids: Vec<_> = { - let set: BTreeSet<_> = complete.iter().flat_map(|t| t.params.keys()).collect(); - set.into_iter().collect() - }; - - if all_param_ids.is_empty() { - return Ok(crate::fanova::FanovaResult { - main_effects: Vec::new(), - interactions: Vec::new(), - }); - } - - // Build feature matrix (only trials that have all parameters). - let mut data = Vec::with_capacity(complete.len()); - let mut targets = Vec::with_capacity(complete.len()); - - for trial in &complete { - let mut row = Vec::with_capacity(all_param_ids.len()); - let mut has_all = true; - - for &pid in &all_param_ids { - if let Some(pv) = trial.params.get(pid) { - row.push(match *pv { - ParamValue::Float(v) => v, - ParamValue::Int(v) => v as f64, - ParamValue::Categorical(v) => v as f64, - }); - } else { - has_all = false; - break; - } - } - - if has_all { - data.push(row); - targets.push(trial.value.clone().into()); - } - } - - if data.len() < 2 { - return Err(crate::Error::NoCompletedTrials); - } - - // Build feature names from parameter labels. 
- let feature_names: Vec<String> = all_param_ids - .iter() - .map(|&pid| { - complete - .iter() - .find_map(|t| t.param_labels.get(pid)) - .map_or_else(|| pid.to_string(), Clone::clone) - }) - .collect(); - - Ok(compute_fanova(&data, &targets, &feature_names, config)) - } -} - -impl<V> IntoIterator for &Study<V> -where - V: PartialOrd + Clone, -{ - type Item = CompletedTrial<V>; - type IntoIter = std::vec::IntoIter<CompletedTrial<V>>; - - fn into_iter(self) -> Self::IntoIter { - self.iter() - } -} - -impl<V> fmt::Display for Study<V> -where - V: PartialOrd + Clone + fmt::Display, -{ - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(&self.summary()) - } -} - -impl<V: PartialOrd + Send + Sync + 'static> Study<V> { - /// Create a study with a custom sampler, pruner, and storage backend. - /// - /// The most flexible constructor, allowing full control over all components. - /// - /// # Arguments - /// - /// * `direction` - Whether to minimize or maximize the objective function. - /// * `sampler` - The sampler to use for parameter sampling. - /// * `pruner` - The pruner to use for trial pruning. - /// * `storage` - The storage backend for completed trials. 
- /// - /// # Examples - /// - /// ``` - /// use optimizer::prelude::*; - /// use optimizer::storage::MemoryStorage; - /// - /// let study = Study::with_sampler_pruner_and_storage( - /// Direction::Minimize, - /// TpeSampler::new(), - /// MedianPruner::new(Direction::Minimize), - /// MemoryStorage::<f64>::new(), - /// ); - /// ``` - pub fn with_sampler_pruner_and_storage( - direction: Direction, - sampler: impl Sampler + 'static, - pruner: impl Pruner + 'static, - storage: impl crate::storage::Storage<V> + 'static, - ) -> Self { - let sampler: Arc<dyn Sampler> = Arc::new(sampler); - let pruner: Arc<dyn Pruner> = Arc::new(pruner); - let storage: Arc<dyn crate::storage::Storage<V>> = Arc::new(storage); - let trial_factory = Self::make_trial_factory(&sampler, &storage, &pruner); - - Self { - direction, - sampler, - pruner, - storage, - trial_factory, - enqueued_params: Arc::new(Mutex::new(VecDeque::new())), - } - } -} - -/// A builder for constructing [`Study`] instances with a fluent API. -/// -/// Created via [`Study::builder()`]. Collects sampler, pruner, direction, -/// and storage options before constructing the study. 
-/// -/// # Defaults -/// -/// - Direction: [`Minimize`](Direction::Minimize) -/// - Sampler: [`RandomSampler`] -/// - Pruner: [`NopPruner`] -/// - Storage: [`MemoryStorage`](crate::storage::MemoryStorage) -/// -/// # Examples -/// -/// ``` -/// use optimizer::prelude::*; -/// -/// let study: Study<f64> = Study::builder() -/// .maximize() -/// .sampler(TpeSampler::new()) -/// .pruner(MedianPruner::new(Direction::Maximize).n_warmup_steps(5)) -/// .build(); -/// -/// assert_eq!(study.direction(), Direction::Maximize); -/// ``` -pub struct StudyBuilder<V: PartialOrd = f64> { - direction: Direction, - sampler: Option<Box<dyn Sampler>>, - pruner: Option<Box<dyn Pruner>>, - storage: Option<Box<dyn crate::storage::Storage<V>>>, - _marker: PhantomData<V>, -} - -impl<V: PartialOrd> StudyBuilder<V> { - /// Set the optimization direction to minimize (the default). - #[must_use] - pub fn minimize(mut self) -> Self { - self.direction = Direction::Minimize; - self - } - - /// Set the optimization direction to maximize. - #[must_use] - pub fn maximize(mut self) -> Self { - self.direction = Direction::Maximize; - self - } - - /// Set the optimization direction explicitly. - #[must_use] - pub fn direction(mut self, direction: Direction) -> Self { - self.direction = direction; - self - } - - /// Set the sampler used for parameter suggestions. - /// - /// Defaults to [`RandomSampler`] if not specified. - #[must_use] - pub fn sampler(mut self, sampler: impl Sampler + 'static) -> Self { - self.sampler = Some(Box::new(sampler)); - self - } - - /// Set the pruner used for early stopping of trials. - /// - /// Defaults to [`NopPruner`] (no pruning) if not specified. - #[must_use] - pub fn pruner(mut self, pruner: impl Pruner + 'static) -> Self { - self.pruner = Some(Box::new(pruner)); - self - } - - /// Set a custom storage backend. - /// - /// Defaults to [`MemoryStorage`](crate::storage::MemoryStorage) if not specified. 
- #[must_use] - pub fn storage(mut self, storage: impl crate::storage::Storage<V> + 'static) -> Self { - self.storage = Some(Box::new(storage)); - self - } - - /// Build the [`Study`] with the configured options. - #[must_use] - pub fn build(self) -> Study<V> - where - V: Send + Sync + 'static, - { - let sampler = self - .sampler - .unwrap_or_else(|| Box::new(RandomSampler::new())); - let pruner = self.pruner.unwrap_or_else(|| Box::new(NopPruner)); - let storage = self - .storage - .unwrap_or_else(|| Box::new(crate::storage::MemoryStorage::<V>::new())); - - let sampler: Arc<dyn Sampler> = Arc::from(sampler); - let pruner: Arc<dyn Pruner> = Arc::from(pruner); - let storage: Arc<dyn crate::storage::Storage<V>> = Arc::from(storage); - let trial_factory = Study::make_trial_factory(&sampler, &storage, &pruner); - - Study { - direction: self.direction, - sampler, - pruner, - storage, - trial_factory, - enqueued_params: Arc::new(Mutex::new(VecDeque::new())), - } - } -} - -#[cfg(feature = "journal")] -impl<V> Study<V> -where - V: PartialOrd + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static, -{ - /// Create a study backed by a JSONL journal file. - /// - /// Any existing trials in the file are loaded into memory and the - /// trial ID counter is set to one past the highest stored ID. New - /// trials are written through to the file on completion. - /// - /// # Arguments - /// - /// * `direction` - Whether to minimize or maximize the objective function. - /// * `sampler` - The sampler to use for parameter sampling. - /// * `path` - Path to the JSONL journal file (created if absent). - /// - /// # Errors - /// - /// Returns a [`Storage`](crate::Error::Storage) error if loading fails. 
- /// - /// # Examples - /// - /// ```no_run - /// use optimizer::sampler::tpe::TpeSampler; - /// use optimizer::{Direction, Study}; - /// - /// let study: Study<f64> = - /// Study::with_journal(Direction::Minimize, TpeSampler::new(), "trials.jsonl").unwrap(); - /// ``` - pub fn with_journal( - direction: Direction, - sampler: impl Sampler + 'static, - path: impl AsRef<std::path::Path>, - ) -> crate::Result<Self> { - let storage = crate::storage::JournalStorage::<V>::open(path)?; - Ok(Self::with_sampler_and_storage(direction, sampler, storage)) - } -} - -impl Study<f64> { - /// Generate an HTML report with interactive Plotly.js charts. - /// - /// Create a self-contained HTML file that can be opened in any browser. - /// See [`generate_html_report`](crate::visualization::generate_html_report) - /// for details on the included charts. - /// - /// # Errors - /// - /// Returns an I/O error if the file cannot be created or written. - pub fn export_html(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> { - crate::visualization::generate_html_report(self, path) - } -} - -/// A serializable snapshot of a study's state. -/// -/// Since [`Study`] contains non-serializable fields (samplers, atomics, etc.), -/// this struct captures the essential state needed to save and restore a study. -/// -/// # Schema versioning -/// -/// The `version` field enables future schema evolution without breaking existing files. -/// The current version is `1`. -/// -/// # Sampler state -/// -/// Sampler state is **not** included in the snapshot. After loading, the study -/// uses a default `RandomSampler`. Call [`Study::set_sampler`] to restore -/// the desired sampler configuration. -#[cfg(feature = "serde")] -#[derive(serde::Serialize, serde::Deserialize)] -pub struct StudySnapshot<V> { - /// Schema version for forward compatibility. - pub version: u32, - /// The optimization direction. - pub direction: Direction, - /// All completed (and pruned) trials. 
- pub trials: Vec<CompletedTrial<V>>, - /// The next trial ID to assign. - pub next_trial_id: u64, - /// Optional metadata (creation timestamp, sampler description, etc.). - pub metadata: HashMap<String, String>, -} - -#[cfg(feature = "serde")] -impl<V: PartialOrd + Clone + serde::Serialize> Study<V> { - /// Export trials as a pretty-printed JSON array to a file. - /// - /// Each element in the array is a serialized [`CompletedTrial`]. - /// Requires the `serde` feature. - /// - /// # Errors - /// - /// Returns an I/O error if the file cannot be created or written. - pub fn export_json(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> { - let file = std::fs::File::create(path)?; - let trials = self.trials(); - serde_json::to_writer_pretty(file, &trials).map_err(std::io::Error::other) - } - - /// Save the study state to a JSON file. - /// - /// # Errors - /// - /// Returns an I/O error if the file cannot be created or written. - pub fn save(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> { - let path = path.as_ref(); - let trials = self.trials(); - let next_trial_id = trials.iter().map(|t| t.id).max().map_or(0, |id| id + 1); - let snapshot = StudySnapshot { - version: 1, - direction: self.direction, - trials, - next_trial_id, - metadata: HashMap::new(), - }; - - // Atomic write: write to a temp file in the same directory, then rename. - // This prevents corrupt files if the process crashes mid-write. - let parent = path.parent().unwrap_or(std::path::Path::new(".")); - let tmp_path = parent.join(format!( - ".{}.tmp", - path.file_name().unwrap_or_default().to_string_lossy() - )); - let file = std::fs::File::create(&tmp_path)?; - serde_json::to_writer_pretty(file, &snapshot).map_err(std::io::Error::other)?; - std::fs::rename(&tmp_path, path) - } -} - -#[cfg(feature = "serde")] -impl<V: PartialOrd + Send + Sync + Clone + serde::de::DeserializeOwned + 'static> Study<V> { - /// Load a study from a JSON file. 
- /// - /// The loaded study uses a `RandomSampler` by default. Call - /// [`set_sampler()`](Self::set_sampler) to restore the original sampler - /// configuration. - /// - /// # Errors - /// - /// Returns an I/O error if the file cannot be read or parsed. - pub fn load(path: impl AsRef<std::path::Path>) -> std::io::Result<Self> { - let file = std::fs::File::open(path)?; - let snapshot: StudySnapshot<V> = serde_json::from_reader(file) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - let storage = crate::storage::MemoryStorage::with_trials(snapshot.trials); - Ok(Self::with_sampler_and_storage( - snapshot.direction, - RandomSampler::new(), - storage, - )) - } -} - -/// Returns `true` if the error represents a pruned trial. -/// -/// Checks via `Any` downcasting whether `e` is `Error::TrialPruned` or -/// the standalone `TrialPruned` struct. -fn is_trial_pruned<E: 'static>(e: &E) -> bool { - let any: &dyn Any = e; - if let Some(err) = any.downcast_ref::<crate::Error>() { - matches!(err, crate::Error::TrialPruned) - } else { - any.downcast_ref::<crate::error::TrialPruned>().is_some() - } -} - -/// Escape a string for CSV output. If the value contains a comma, quote, or -/// newline, wrap it in double-quotes and double any embedded quotes. -fn csv_escape(s: &str) -> String { - if s.contains(',') || s.contains('"') || s.contains('\n') { - format!("\"{}\"", s.replace('"', "\"\"")) - } else { - s.to_string() - } -} - -/// Format an `AttrValue` as a string for CSV cells. 
-fn format_attr(attr: &crate::trial::AttrValue) -> String { - use crate::trial::AttrValue; - match attr { - AttrValue::Float(v) => v.to_string(), - AttrValue::Int(v) => v.to_string(), - AttrValue::String(v) => v.clone(), - AttrValue::Bool(v) => v.to_string(), - } -} diff --git a/src/study/analysis.rs b/src/study/analysis.rs new file mode 100644 index 0000000..90a367b --- /dev/null +++ b/src/study/analysis.rs @@ -0,0 +1,379 @@ +use crate::sampler::CompletedTrial; +use crate::types::TrialState; + +use super::Study; + +impl<V> Study<V> +where + V: PartialOrd, +{ + /// Return the trial with the best objective value. + /// + /// The "best" trial depends on the optimization direction: + /// - `Direction::Minimize`: Returns the trial with the lowest objective value. + /// - `Direction::Maximize`: Returns the trial with the highest objective value. + /// + /// When constraints are present, feasible trials always rank above infeasible + /// trials. Among infeasible trials, those with lower total constraint violation + /// are preferred. + /// + /// # Errors + /// + /// Returns `Error::NoCompletedTrials` if no trials have been completed. 
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::parameter::{FloatParam, Parameter}; + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// + /// // Error when no trials completed + /// assert!(study.best_trial().is_err()); + /// + /// let x_param = FloatParam::new(0.0, 1.0); + /// + /// let mut trial1 = study.create_trial(); + /// let _ = x_param.suggest(&mut trial1); + /// study.complete_trial(trial1, 0.8); + /// + /// let mut trial2 = study.create_trial(); + /// let _ = x_param.suggest(&mut trial2); + /// study.complete_trial(trial2, 0.3); + /// + /// let best = study.best_trial().unwrap(); + /// assert_eq!(best.value, 0.3); // Minimize: lower is better + /// ``` + pub fn best_trial(&self) -> crate::Result<CompletedTrial<V>> + where + V: Clone, + { + let trials = self.storage.trials_arc().read(); + let direction = self.direction; + + let best = trials + .iter() + .filter(|t| t.state == TrialState::Complete) + .max_by(|a, b| Self::compare_trials(a, b, direction)) + .ok_or(crate::Error::NoCompletedTrials)?; + + Ok(best.clone()) + } + + /// Return the best objective value found so far. + /// + /// The "best" value depends on the optimization direction: + /// - `Direction::Minimize`: Returns the lowest objective value. + /// - `Direction::Maximize`: Returns the highest objective value. + /// + /// # Errors + /// + /// Returns `Error::NoCompletedTrials` if no trials have been completed. 
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::parameter::{FloatParam, Parameter}; + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Maximize); + /// + /// // Error when no trials completed + /// assert!(study.best_value().is_err()); + /// + /// let x_param = FloatParam::new(0.0, 1.0); + /// + /// let mut trial1 = study.create_trial(); + /// let _ = x_param.suggest(&mut trial1); + /// study.complete_trial(trial1, 0.3); + /// + /// let mut trial2 = study.create_trial(); + /// let _ = x_param.suggest(&mut trial2); + /// study.complete_trial(trial2, 0.8); + /// + /// let best = study.best_value().unwrap(); + /// assert_eq!(best, 0.8); // Maximize: higher is better + /// ``` + pub fn best_value(&self) -> crate::Result<V> + where + V: Clone, + { + self.best_trial().map(|trial| trial.value) + } + + /// Return the top `n` trials sorted by objective value. + /// + /// For `Direction::Minimize`, returns trials with the lowest values. + /// For `Direction::Maximize`, returns trials with the highest values. + /// Only includes completed trials (not failed or pruned). + /// + /// If fewer than `n` completed trials exist, returns all of them. 
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::parameter::{FloatParam, Parameter}; + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// let x = FloatParam::new(0.0, 10.0); + /// + /// for val in [5.0, 1.0, 3.0] { + /// let mut t = study.create_trial(); + /// let _ = x.suggest(&mut t); + /// study.complete_trial(t, val); + /// } + /// + /// let top2 = study.top_trials(2); + /// assert_eq!(top2.len(), 2); + /// assert!(top2[0].value <= top2[1].value); + /// ``` + #[must_use] + pub fn top_trials(&self, n: usize) -> Vec<CompletedTrial<V>> + where + V: Clone, + { + let trials = self.storage.trials_arc().read(); + let direction = self.direction; + // Sort indices instead of cloning all trials, then clone only the top N. + let mut indices: Vec<usize> = trials + .iter() + .enumerate() + .filter(|(_, t)| t.state == TrialState::Complete) + .map(|(i, _)| i) + .collect(); + // Sort best-first: reverse the compare_trials ordering (which is designed for max_by) + indices.sort_by(|&a, &b| Self::compare_trials(&trials[b], &trials[a], direction)); + indices.truncate(n); + indices.iter().map(|&i| trials[i].clone()).collect() + } +} + +impl<V> Study<V> +where + V: PartialOrd + Clone + Into<f64>, +{ + /// Compute parameter importance scores using Spearman rank correlation. + /// + /// For each parameter, the absolute Spearman correlation between its values + /// and the objective values is computed across all completed trials. Scores + /// are normalized so they sum to 1.0 and sorted in descending order. + /// + /// Parameters that appear in fewer than 2 trials are omitted. + /// Returns an empty `Vec` if the study has fewer than 2 completed trials. 
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::parameter::{FloatParam, Parameter}; + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// let x = FloatParam::new(0.0, 10.0).name("x"); + /// + /// study + /// .optimize(20, |trial: &mut optimizer::Trial| { + /// let xv = x.suggest(trial)?; + /// Ok::<_, optimizer::Error>(xv * xv) + /// }) + /// .unwrap(); + /// + /// let importance = study.param_importance(); + /// assert_eq!(importance.len(), 1); + /// assert_eq!(importance[0].0, "x"); + /// ``` + #[must_use] + #[allow(clippy::cast_precision_loss)] + pub fn param_importance(&self) -> Vec<(String, f64)> { + use std::collections::BTreeSet; + + use crate::importance::spearman; + use crate::param::ParamValue; + use crate::types::TrialState; + + let trials = self.storage.trials_arc().read(); + let complete: Vec<_> = trials + .iter() + .filter(|t| t.state == TrialState::Complete) + .collect(); + + if complete.len() < 2 { + return Vec::new(); + } + + // Collect all parameter IDs across trials. + let all_param_ids: BTreeSet<_> = complete.iter().flat_map(|t| t.params.keys()).collect(); + + let mut scores: Vec<(String, f64)> = Vec::with_capacity(all_param_ids.len()); + + for ¶m_id in &all_param_ids { + // Collect (param_value_f64, objective_f64) for trials that have this param. + let mut param_vals = Vec::with_capacity(complete.len()); + let mut obj_vals = Vec::with_capacity(complete.len()); + + for trial in &complete { + if let Some(pv) = trial.params.get(param_id) { + let f = match *pv { + ParamValue::Float(v) => v, + ParamValue::Int(v) => v as f64, + ParamValue::Categorical(v) => v as f64, + }; + param_vals.push(f); + obj_vals.push(trial.value.clone().into()); + } + } + + if param_vals.len() < 2 { + continue; + } + + let corr = spearman(¶m_vals, &obj_vals).abs(); + + // Determine label: use param_labels if available, else "param_{id}". 
+ let label = complete + .iter() + .find_map(|t| t.param_labels.get(param_id)) + .map_or_else(|| param_id.to_string(), Clone::clone); + + scores.push((label, corr)); + } + + // Normalize so scores sum to 1.0. + let sum: f64 = scores.iter().map(|(_, s)| *s).sum(); + if sum > 0.0 { + for entry in &mut scores { + entry.1 /= sum; + } + } + + // Sort descending by score. + scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(core::cmp::Ordering::Equal)); + + scores + } + + /// Compute parameter importance using fANOVA (functional ANOVA) with + /// default configuration. + /// + /// Fits a random forest to the trial data and decomposes variance into + /// per-parameter main effects and pairwise interaction effects. This is + /// more accurate than correlation-based importance ([`Self::param_importance`]) + /// and can detect non-linear relationships and parameter interactions. + /// + /// # Errors + /// + /// Returns [`crate::Error::NoCompletedTrials`] if fewer than 2 trials have completed. + /// + /// # Examples + /// + /// ``` + /// use optimizer::parameter::{FloatParam, Parameter}; + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// let x = FloatParam::new(0.0, 10.0).name("x"); + /// let y = FloatParam::new(0.0, 10.0).name("y"); + /// + /// study + /// .optimize(30, |trial: &mut optimizer::Trial| { + /// let xv = x.suggest(trial)?; + /// let yv = y.suggest(trial)?; + /// Ok::<_, optimizer::Error>(xv * xv + 0.1 * yv) + /// }) + /// .unwrap(); + /// + /// let result = study.fanova().unwrap(); + /// assert!(!result.main_effects.is_empty()); + /// ``` + pub fn fanova(&self) -> crate::Result<crate::fanova::FanovaResult> { + self.fanova_with_config(&crate::fanova::FanovaConfig::default()) + } + + /// Compute parameter importance using fANOVA with custom configuration. + /// + /// See [`Self::fanova`] for details. 
The [`FanovaConfig`](crate::fanova::FanovaConfig) + /// allows tuning the number of trees, tree depth, and random seed. + /// + /// # Errors + /// + /// Returns [`crate::Error::NoCompletedTrials`] if fewer than 2 trials have completed. + #[allow(clippy::cast_precision_loss)] + pub fn fanova_with_config( + &self, + config: &crate::fanova::FanovaConfig, + ) -> crate::Result<crate::fanova::FanovaResult> { + use std::collections::BTreeSet; + + use crate::fanova::compute_fanova; + use crate::param::ParamValue; + use crate::types::TrialState; + + let trials = self.storage.trials_arc().read(); + let complete: Vec<_> = trials + .iter() + .filter(|t| t.state == TrialState::Complete) + .collect(); + + if complete.len() < 2 { + return Err(crate::Error::NoCompletedTrials); + } + + // Collect all parameter IDs in a stable order. + let all_param_ids: Vec<_> = { + let set: BTreeSet<_> = complete.iter().flat_map(|t| t.params.keys()).collect(); + set.into_iter().collect() + }; + + if all_param_ids.is_empty() { + return Ok(crate::fanova::FanovaResult { + main_effects: Vec::new(), + interactions: Vec::new(), + }); + } + + // Build feature matrix (only trials that have all parameters). + let mut data = Vec::with_capacity(complete.len()); + let mut targets = Vec::with_capacity(complete.len()); + + for trial in &complete { + let mut row = Vec::with_capacity(all_param_ids.len()); + let mut has_all = true; + + for &pid in &all_param_ids { + if let Some(pv) = trial.params.get(pid) { + row.push(match *pv { + ParamValue::Float(v) => v, + ParamValue::Int(v) => v as f64, + ParamValue::Categorical(v) => v as f64, + }); + } else { + has_all = false; + break; + } + } + + if has_all { + data.push(row); + targets.push(trial.value.clone().into()); + } + } + + if data.len() < 2 { + return Err(crate::Error::NoCompletedTrials); + } + + // Build feature names from parameter labels. 
+ let feature_names: Vec<String> = all_param_ids + .iter() + .map(|&pid| { + complete + .iter() + .find_map(|t| t.param_labels.get(pid)) + .map_or_else(|| pid.to_string(), Clone::clone) + }) + .collect(); + + Ok(compute_fanova(&data, &targets, &feature_names, config)) + } +} diff --git a/src/study/async_impl.rs b/src/study/async_impl.rs new file mode 100644 index 0000000..ca9e0a5 --- /dev/null +++ b/src/study/async_impl.rs @@ -0,0 +1,283 @@ +use core::ops::ControlFlow; +use std::sync::Arc; + +use crate::trial::Trial; +use crate::types::TrialState; + +use super::{Study, is_trial_pruned}; + +impl<V> Study<V> +where + V: PartialOrd, +{ + /// Run async optimization with an objective. + /// + /// Like [`optimize`](Self::optimize), but each evaluation is wrapped in + /// [`spawn_blocking`](tokio::task::spawn_blocking), keeping the async + /// runtime responsive for CPU-bound objectives. Trials run sequentially. + /// + /// Accepts any [`Objective`](crate::Objective) implementation, including + /// plain closures. Struct-based objectives can override lifecycle hooks. + /// + /// # Errors + /// + /// Returns `Error::NoCompletedTrials` if no trials completed successfully. + /// Returns `Error::TaskError` if a spawned blocking task panics. 
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::parameter::{FloatParam, Parameter}; + /// use optimizer::sampler::random::RandomSampler; + /// use optimizer::{Direction, Study}; + /// + /// # #[cfg(feature = "async")] + /// # async fn example() -> optimizer::Result<()> { + /// let sampler = RandomSampler::with_seed(42); + /// let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + /// let x_param = FloatParam::new(-10.0, 10.0); + /// + /// study + /// .optimize_async(10, move |trial: &mut optimizer::Trial| { + /// let x = x_param.suggest(trial)?; + /// Ok::<_, optimizer::Error>(x * x) + /// }) + /// .await?; + /// + /// assert!(study.n_trials() > 0); + /// # Ok(()) + /// # } + /// ``` + pub async fn optimize_async<O>(&self, n_trials: usize, objective: O) -> crate::Result<()> + where + O: crate::objective::Objective<V> + Send + Sync + 'static, + O::Error: Send, + V: Clone + Default + Send + 'static, + { + #[cfg(feature = "tracing")] + let _span = + tracing::info_span!("optimize_async", n_trials, direction = ?self.direction).entered(); + + let objective = Arc::new(objective); + + for _ in 0..n_trials { + if let ControlFlow::Break(()) = objective.before_trial(self) { + break; + } + + let obj = Arc::clone(&objective); + let mut trial = self.create_trial(); + let result = tokio::task::spawn_blocking(move || { + let res = obj.evaluate(&mut trial); + (trial, res) + }) + .await + .map_err(|e| crate::Error::TaskError(e.to_string()))?; + + match result { + (t, Ok(value)) => { + #[cfg(feature = "tracing")] + let trial_id = t.id(); + + let completed = t.into_completed(value, TrialState::Complete); + let flow = objective.after_trial(self, &completed); + self.storage.push(completed); + trace_info!(trial_id, "trial completed"); + + if let ControlFlow::Break(()) = flow { + return Ok(()); + } + } + (t, Err(e)) if is_trial_pruned(&e) => { + #[cfg(feature = "tracing")] + let trial_id = t.id(); + self.prune_trial(t); + trace_info!(trial_id, "trial 
pruned"); + } + (t, Err(e)) => { + #[cfg(feature = "tracing")] + let trial_id = t.id(); + self.fail_trial(t, e.to_string()); + trace_debug!(trial_id, "trial failed"); + } + } + } + + let has_complete = self + .storage + .trials_arc() + .read() + .iter() + .any(|t| t.state == TrialState::Complete); + if !has_complete { + return Err(crate::Error::NoCompletedTrials); + } + + Ok(()) + } + + /// Run parallel optimization with an objective. + /// + /// Spawns up to `concurrency` evaluations concurrently using + /// [`spawn_blocking`](tokio::task::spawn_blocking). Results are + /// collected via a [`JoinSet`](tokio::task::JoinSet). + /// + /// Accepts any [`Objective`](crate::Objective) implementation, including + /// plain closures. The [`after_trial`](crate::Objective::after_trial) + /// hook fires as each result arrives — returning `Break` stops spawning + /// new trials while in-flight tasks drain. + /// + /// # Errors + /// + /// Returns `Error::NoCompletedTrials` if no trials completed successfully. + /// Returns `Error::TaskError` if the semaphore is closed or a spawned task panics. 
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::parameter::{FloatParam, Parameter}; + /// use optimizer::sampler::random::RandomSampler; + /// use optimizer::{Direction, Study}; + /// + /// # #[cfg(feature = "async")] + /// # async fn example() -> optimizer::Result<()> { + /// let sampler = RandomSampler::with_seed(42); + /// let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + /// let x_param = FloatParam::new(-10.0, 10.0); + /// + /// study + /// .optimize_parallel(10, 4, move |trial: &mut optimizer::Trial| { + /// let x = x_param.suggest(trial)?; + /// Ok::<_, optimizer::Error>(x * x) + /// }) + /// .await?; + /// + /// assert_eq!(study.n_trials(), 10); + /// # Ok(()) + /// # } + /// ``` + #[allow(clippy::missing_panics_doc, clippy::too_many_lines)] + pub async fn optimize_parallel<O>( + &self, + n_trials: usize, + concurrency: usize, + objective: O, + ) -> crate::Result<()> + where + O: crate::objective::Objective<V> + Send + Sync + 'static, + O::Error: Send, + V: Clone + Default + Send + 'static, + { + use tokio::sync::Semaphore; + use tokio::task::JoinSet; + + #[cfg(feature = "tracing")] + let _span = tracing::info_span!("optimize_parallel", n_trials, concurrency, direction = ?self.direction).entered(); + + let objective = Arc::new(objective); + let semaphore = Arc::new(Semaphore::new(concurrency)); + let mut join_set: JoinSet<(Trial, Result<V, O::Error>)> = JoinSet::new(); + let mut spawned = 0; + + 'spawn: while spawned < n_trials { + if let ControlFlow::Break(()) = objective.before_trial(self) { + break; + } + + // If the join set is full, drain one result to free a slot. 
+ while join_set.len() >= concurrency { + let result = join_set + .join_next() + .await + .expect("join_set should not be empty") + .map_err(|e| crate::Error::TaskError(e.to_string()))?; + match result { + (t, Ok(value)) => { + #[cfg(feature = "tracing")] + let trial_id = t.id(); + + let completed = t.into_completed(value, TrialState::Complete); + let flow = objective.after_trial(self, &completed); + self.storage.push(completed); + trace_info!(trial_id, "trial completed"); + + if let ControlFlow::Break(()) = flow { + break 'spawn; + } + } + (t, Err(e)) => { + #[cfg(feature = "tracing")] + let trial_id = t.id(); + if is_trial_pruned(&e) { + self.prune_trial(t); + trace_info!(trial_id, "trial pruned"); + } else { + self.fail_trial(t, e.to_string()); + trace_debug!(trial_id, "trial failed"); + } + } + } + } + + let permit = semaphore + .clone() + .acquire_owned() + .await + .map_err(|e| crate::Error::TaskError(e.to_string()))?; + + let mut trial = self.create_trial(); + let obj = Arc::clone(&objective); + join_set.spawn(async move { + let result = tokio::task::spawn_blocking(move || { + let res = obj.evaluate(&mut trial); + (trial, res) + }) + .await + .expect("spawn_blocking should not panic"); + drop(permit); + result + }); + spawned += 1; + } + + // Drain remaining in-flight tasks. + while let Some(result) = join_set.join_next().await { + let result = result.map_err(|e| crate::Error::TaskError(e.to_string()))?; + match result { + (t, Ok(value)) => { + #[cfg(feature = "tracing")] + let trial_id = t.id(); + + let completed = t.into_completed(value, TrialState::Complete); + // Still fire after_trial for bookkeeping, but don't break — we're draining. 
+ let _ = objective.after_trial(self, &completed); + self.storage.push(completed); + trace_info!(trial_id, "trial completed"); + } + (t, Err(e)) => { + #[cfg(feature = "tracing")] + let trial_id = t.id(); + if is_trial_pruned(&e) { + self.prune_trial(t); + trace_info!(trial_id, "trial pruned"); + } else { + self.fail_trial(t, e.to_string()); + trace_debug!(trial_id, "trial failed"); + } + } + } + } + + let has_complete = self + .storage + .trials_arc() + .read() + .iter() + .any(|t| t.state == TrialState::Complete); + if !has_complete { + return Err(crate::Error::NoCompletedTrials); + } + + Ok(()) + } +} diff --git a/src/study/builder.rs b/src/study/builder.rs new file mode 100644 index 0000000..c3751c8 --- /dev/null +++ b/src/study/builder.rs @@ -0,0 +1,135 @@ +use core::marker::PhantomData; +use std::collections::VecDeque; +use std::sync::Arc; + +use parking_lot::Mutex; + +use crate::pruner::{NopPruner, Pruner}; +use crate::sampler::Sampler; +use crate::sampler::random::RandomSampler; +use crate::types::Direction; + +use super::Study; + +/// A builder for constructing [`Study`] instances with a fluent API. +/// +/// Created via [`Study::builder()`]. Collects sampler, pruner, direction, +/// and storage options before constructing the study. 
+/// +/// # Defaults +/// +/// - Direction: [`Minimize`](Direction::Minimize) +/// - Sampler: [`RandomSampler`] +/// - Pruner: [`NopPruner`] +/// - Storage: [`MemoryStorage`](crate::storage::MemoryStorage) +/// +/// # Examples +/// +/// ``` +/// use optimizer::prelude::*; +/// +/// let study: Study<f64> = Study::builder() +/// .maximize() +/// .sampler(TpeSampler::new()) +/// .pruner(MedianPruner::new(Direction::Maximize).n_warmup_steps(5)) +/// .build(); +/// +/// assert_eq!(study.direction(), Direction::Maximize); +/// ``` +pub struct StudyBuilder<V: PartialOrd = f64> { + direction: Direction, + sampler: Option<Box<dyn Sampler>>, + pruner: Option<Box<dyn Pruner>>, + storage: Option<Box<dyn crate::storage::Storage<V>>>, + _marker: PhantomData<V>, +} + +impl<V: PartialOrd> StudyBuilder<V> { + /// Create a new builder with default settings. + pub(super) fn new() -> Self { + Self { + direction: Direction::Minimize, + sampler: None, + pruner: None, + storage: None, + _marker: PhantomData, + } + } + + /// Set the optimization direction to minimize (the default). + #[must_use] + pub fn minimize(mut self) -> Self { + self.direction = Direction::Minimize; + self + } + + /// Set the optimization direction to maximize. + #[must_use] + pub fn maximize(mut self) -> Self { + self.direction = Direction::Maximize; + self + } + + /// Set the optimization direction explicitly. + #[must_use] + pub fn direction(mut self, direction: Direction) -> Self { + self.direction = direction; + self + } + + /// Set the sampler used for parameter suggestions. + /// + /// Defaults to [`RandomSampler`] if not specified. + #[must_use] + pub fn sampler(mut self, sampler: impl Sampler + 'static) -> Self { + self.sampler = Some(Box::new(sampler)); + self + } + + /// Set the pruner used for early stopping of trials. + /// + /// Defaults to [`NopPruner`] (no pruning) if not specified. 
+ #[must_use] + pub fn pruner(mut self, pruner: impl Pruner + 'static) -> Self { + self.pruner = Some(Box::new(pruner)); + self + } + + /// Set a custom storage backend. + /// + /// Defaults to [`MemoryStorage`](crate::storage::MemoryStorage) if not specified. + #[must_use] + pub fn storage(mut self, storage: impl crate::storage::Storage<V> + 'static) -> Self { + self.storage = Some(Box::new(storage)); + self + } + + /// Build the [`Study`] with the configured options. + #[must_use] + pub fn build(self) -> Study<V> + where + V: Send + Sync + 'static, + { + let sampler = self + .sampler + .unwrap_or_else(|| Box::new(RandomSampler::new())); + let pruner = self.pruner.unwrap_or_else(|| Box::new(NopPruner)); + let storage = self + .storage + .unwrap_or_else(|| Box::new(crate::storage::MemoryStorage::<V>::new())); + + let sampler: Arc<dyn Sampler> = Arc::from(sampler); + let pruner: Arc<dyn Pruner> = Arc::from(pruner); + let storage: Arc<dyn crate::storage::Storage<V>> = Arc::from(storage); + let trial_factory = Study::make_trial_factory(&sampler, &storage, &pruner); + + Study { + direction: self.direction, + sampler, + pruner, + storage, + trial_factory, + enqueued_params: Arc::new(Mutex::new(VecDeque::new())), + } + } +} diff --git a/src/study/export.rs b/src/study/export.rs new file mode 100644 index 0000000..913ab2b --- /dev/null +++ b/src/study/export.rs @@ -0,0 +1,257 @@ +use core::fmt; + +use crate::parameter::ParamId; +use crate::types::{Direction, TrialState}; + +use super::Study; + +impl<V> Study<V> +where + V: PartialOrd + Clone + fmt::Display, +{ + /// Write completed trials to a writer in CSV format. + /// + /// Columns: `trial_id`, `value`, `state`, then one column per unique + /// parameter label, then one column per unique user-attribute key. + /// + /// Parameters without labels use a generated name (`param_<id>`). + /// Pruned trials have an empty `value` cell. + /// + /// # Errors + /// + /// Returns an I/O error if writing fails. 
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::parameter::{FloatParam, Parameter}; + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// let x = FloatParam::new(0.0, 10.0).name("x"); + /// + /// let mut trial = study.create_trial(); + /// let _ = x.suggest(&mut trial); + /// study.complete_trial(trial, 0.42); + /// + /// let mut buf = Vec::new(); + /// study.to_csv(&mut buf).unwrap(); + /// let csv = String::from_utf8(buf).unwrap(); + /// assert!(csv.contains("trial_id")); + /// ``` + pub fn to_csv(&self, mut writer: impl std::io::Write) -> std::io::Result<()> { + use std::collections::BTreeMap; + + let trials = self.storage.trials_arc().read(); + + // Collect all unique parameter labels (sorted for deterministic column order). + let mut param_columns: BTreeMap<ParamId, String> = BTreeMap::new(); + for trial in trials.iter() { + for &id in trial.params.keys() { + param_columns.entry(id).or_insert_with(|| { + trial + .param_labels + .get(&id) + .cloned() + .unwrap_or_else(|| id.to_string()) + }); + } + } + // Fill in labels from other trials that might have better labels. + for trial in trials.iter() { + for (&id, label) in &trial.param_labels { + param_columns.entry(id).or_insert_with(|| label.clone()); + } + } + + // Collect all unique attribute keys (sorted). + let mut attr_keys: Vec<String> = Vec::new(); + for trial in trials.iter() { + for key in trial.user_attrs.keys() { + if !attr_keys.contains(key) { + attr_keys.push(key.clone()); + } + } + } + attr_keys.sort(); + + let param_ids: Vec<ParamId> = param_columns.keys().copied().collect(); + + // Write header. + write!(writer, "trial_id,value,state")?; + for id in ¶m_ids { + write!(writer, ",{}", csv_escape(¶m_columns[id]))?; + } + for key in &attr_keys { + write!(writer, ",{}", csv_escape(key))?; + } + writeln!(writer)?; + + // Write one row per trial. 
+ for trial in trials.iter() { + write!(writer, "{}", trial.id)?; + + // Value: empty for pruned trials. + if trial.state == TrialState::Complete { + write!(writer, ",{}", trial.value)?; + } else { + write!(writer, ",")?; + } + + write!( + writer, + ",{}", + match trial.state { + TrialState::Complete => "Complete", + TrialState::Pruned => "Pruned", + TrialState::Failed => "Failed", + TrialState::Running => "Running", + } + )?; + + for id in ¶m_ids { + if let Some(pv) = trial.params.get(id) { + write!(writer, ",{pv}")?; + } else { + write!(writer, ",")?; + } + } + + for key in &attr_keys { + if let Some(attr) = trial.user_attrs.get(key) { + write!(writer, ",{}", csv_escape(&format_attr(attr)))?; + } else { + write!(writer, ",")?; + } + } + + writeln!(writer)?; + } + + Ok(()) + } + + /// Export completed trials to a CSV file at the given path. + /// + /// Convenience wrapper around [`to_csv`](Self::to_csv) that creates a + /// buffered file writer. + /// + /// # Errors + /// + /// Returns an I/O error if the file cannot be created or written. + pub fn export_csv(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> { + let file = std::fs::File::create(path)?; + self.to_csv(std::io::BufWriter::new(file)) + } + + /// Return a human-readable summary of the study. 
+ /// + /// The summary includes: + /// - Optimization direction and total trial count + /// - Breakdown by state (complete, pruned) when applicable + /// - Best trial value and parameters (if any completed trials exist) + /// + /// # Examples + /// + /// ``` + /// use optimizer::parameter::{FloatParam, Parameter}; + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// let x = FloatParam::new(0.0, 10.0).name("x"); + /// + /// let mut trial = study.create_trial(); + /// let _ = x.suggest(&mut trial).unwrap(); + /// study.complete_trial(trial, 0.42); + /// + /// let summary = study.summary(); + /// assert!(summary.contains("Minimize")); + /// assert!(summary.contains("0.42")); + /// ``` + #[must_use] + pub fn summary(&self) -> String { + use fmt::Write; + + let trials = self.storage.trials_arc().read(); + let n_complete = trials + .iter() + .filter(|t| t.state == TrialState::Complete) + .count(); + let n_pruned = trials + .iter() + .filter(|t| t.state == TrialState::Pruned) + .count(); + + let direction_str = match self.direction { + Direction::Minimize => "Minimize", + Direction::Maximize => "Maximize", + }; + + let mut s = format!("Study: {direction_str} | {n} trials", n = trials.len()); + if n_pruned > 0 { + let _ = write!(s, " ({n_complete} complete, {n_pruned} pruned)"); + } + + drop(trials); + + if let Ok(best) = self.best_trial() { + let _ = write!(s, "\nBest value: {} (trial #{})", best.value, best.id); + if !best.params.is_empty() { + s.push_str("\nBest parameters:"); + let mut params: Vec<_> = best.params.iter().collect(); + params.sort_by_key(|(id, _)| *id); + for (id, value) in params { + let label = best.param_labels.get(id).map_or("?", String::as_str); + let _ = write!(s, "\n {label} = {value}"); + } + } + } + + s + } +} + +impl<V> fmt::Display for Study<V> +where + V: PartialOrd + Clone + fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + 
f.write_str(&self.summary()) + } +} + +impl Study<f64> { + /// Generate an HTML report with interactive Plotly.js charts. + /// + /// Create a self-contained HTML file that can be opened in any browser. + /// See [`generate_html_report`](crate::visualization::generate_html_report) + /// for details on the included charts. + /// + /// # Errors + /// + /// Returns an I/O error if the file cannot be created or written. + pub fn export_html(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> { + crate::visualization::generate_html_report(self, path) + } +} + +/// Escape a string for CSV output. If the value contains a comma, quote, or +/// newline, wrap it in double-quotes and double any embedded quotes. +fn csv_escape(s: &str) -> String { + if s.contains(',') || s.contains('"') || s.contains('\n') { + format!("\"{}\"", s.replace('"', "\"\"")) + } else { + s.to_string() + } +} + +/// Format an `AttrValue` as a string for CSV cells. +fn format_attr(attr: &crate::trial::AttrValue) -> String { + use crate::trial::AttrValue; + match attr { + AttrValue::Float(v) => v.to_string(), + AttrValue::Int(v) => v.to_string(), + AttrValue::String(v) => v.clone(), + AttrValue::Bool(v) => v.to_string(), + } +} diff --git a/src/study/iter.rs b/src/study/iter.rs new file mode 100644 index 0000000..01c8e88 --- /dev/null +++ b/src/study/iter.rs @@ -0,0 +1,43 @@ +use crate::sampler::CompletedTrial; + +use super::Study; + +impl<V> Study<V> +where + V: PartialOrd + Clone, +{ + /// Return an iterator over all completed trials. + /// + /// This clones the internal trial list, so it is suitable for + /// analysis and iteration but not for hot paths. 
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// let trial = study.create_trial(); + /// study.complete_trial(trial, 1.0); + /// + /// for t in study.iter() { + /// println!("Trial {} → {}", t.id, t.value); + /// } + /// ``` + #[must_use] + pub fn iter(&self) -> std::vec::IntoIter<CompletedTrial<V>> { + self.trials().into_iter() + } +} + +impl<V> IntoIterator for &Study<V> +where + V: PartialOrd + Clone, +{ + type Item = CompletedTrial<V>; + type IntoIter = std::vec::IntoIter<CompletedTrial<V>>; + + fn into_iter(self) -> Self::IntoIter { + self.iter() + } +} diff --git a/src/study/mod.rs b/src/study/mod.rs new file mode 100644 index 0000000..f2981c5 --- /dev/null +++ b/src/study/mod.rs @@ -0,0 +1,766 @@ +//! Study implementation for managing optimization trials. + +use core::any::Any; +use std::collections::{HashMap, VecDeque}; +use std::sync::Arc; + +use parking_lot::{Mutex, RwLock}; + +use crate::param::ParamValue; +use crate::parameter::ParamId; +use crate::pruner::{NopPruner, Pruner}; +use crate::sampler::random::RandomSampler; +use crate::sampler::{CompletedTrial, Sampler}; +use crate::trial::Trial; +use crate::types::{Direction, TrialState}; + +mod analysis; +mod builder; +mod export; +mod iter; +mod optimize; +mod persistence; + +#[cfg(feature = "async")] +mod async_impl; + +pub use builder::StudyBuilder; +#[cfg(feature = "serde")] +pub use persistence::StudySnapshot; + +/// A study manages the optimization process, tracking trials and their results. +/// +/// The study is parameterized by the objective value type `V`, which defaults to `f64`. +/// The only constraint on `V` is `PartialOrd`, allowing comparison of objective values +/// to determine which trial is best. +/// +/// When `V = f64`, the study passes trial history to the sampler for informed +/// parameter suggestions (e.g., TPE sampler uses history to guide sampling). 
+/// +/// # Examples +/// +/// ``` +/// use optimizer::{Direction, Study}; +/// +/// // Create a study to minimize an objective function +/// let study: Study<f64> = Study::new(Direction::Minimize); +/// assert_eq!(study.direction(), Direction::Minimize); +/// ``` +pub struct Study<V = f64> +where + V: PartialOrd, +{ + /// The optimization direction. + pub(crate) direction: Direction, + /// The sampler used to generate parameter values. + pub(crate) sampler: Arc<dyn Sampler>, + /// The pruner used to decide whether to stop trials early. + pub(crate) pruner: Arc<dyn Pruner>, + /// Trial storage backend (default: [`MemoryStorage`](crate::storage::MemoryStorage)). + pub(crate) storage: Arc<dyn crate::storage::Storage<V>>, + /// Optional factory for creating sampler-aware trials. + /// Set automatically for `Study<f64>` so that `create_trial()` and all + /// optimization methods use the sampler without requiring `_with_sampler` suffixes. + pub(crate) trial_factory: Option<Arc<dyn Fn(u64) -> Trial + Send + Sync>>, + /// Queue of parameter configurations to evaluate next. + pub(crate) enqueued_params: Arc<Mutex<VecDeque<HashMap<ParamId, ParamValue>>>>, +} + +impl<V> Study<V> +where + V: PartialOrd, +{ + /// Create a new study with the given optimization direction. + /// + /// Uses the default `RandomSampler` for parameter sampling. + /// + /// # Arguments + /// + /// * `direction` - Whether to minimize or maximize the objective function. + /// + /// # Examples + /// + /// ``` + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// assert_eq!(study.direction(), Direction::Minimize); + /// ``` + #[must_use] + pub fn new(direction: Direction) -> Self + where + V: Send + Sync + 'static, + { + Self::with_sampler(direction, RandomSampler::new()) + } + + /// Return a [`StudyBuilder`] for constructing a study with a fluent API. 
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::prelude::*; + /// + /// let study: Study<f64> = Study::builder() + /// .minimize() + /// .sampler(TpeSampler::new()) + /// .pruner(NopPruner) + /// .build(); + /// ``` + #[must_use] + pub fn builder() -> StudyBuilder<V> { + StudyBuilder::new() + } + + /// Create a study that minimizes the objective value. + /// + /// This is a shorthand for `Study::with_sampler(Direction::Minimize, sampler)`. + /// + /// # Arguments + /// + /// * `sampler` - The sampler to use for parameter sampling. + /// + /// # Examples + /// + /// ``` + /// use optimizer::Study; + /// use optimizer::sampler::tpe::TpeSampler; + /// + /// let study: Study<f64> = Study::minimize(TpeSampler::new()); + /// assert_eq!(study.direction(), optimizer::Direction::Minimize); + /// ``` + #[must_use] + pub fn minimize(sampler: impl Sampler + 'static) -> Self + where + V: Send + Sync + 'static, + { + Self::with_sampler(Direction::Minimize, sampler) + } + + /// Create a study that maximizes the objective value. + /// + /// This is a shorthand for `Study::with_sampler(Direction::Maximize, sampler)`. + /// + /// # Arguments + /// + /// * `sampler` - The sampler to use for parameter sampling. + /// + /// # Examples + /// + /// ``` + /// use optimizer::Study; + /// use optimizer::sampler::tpe::TpeSampler; + /// + /// let study: Study<f64> = Study::maximize(TpeSampler::new()); + /// assert_eq!(study.direction(), optimizer::Direction::Maximize); + /// ``` + #[must_use] + pub fn maximize(sampler: impl Sampler + 'static) -> Self + where + V: Send + Sync + 'static, + { + Self::with_sampler(Direction::Maximize, sampler) + } + + /// Create a new study with a custom sampler. + /// + /// # Arguments + /// + /// * `direction` - Whether to minimize or maximize the objective function. + /// * `sampler` - The sampler to use for parameter sampling. 
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::sampler::random::RandomSampler; + /// use optimizer::{Direction, Study}; + /// + /// let sampler = RandomSampler::with_seed(42); + /// let study: Study<f64> = Study::with_sampler(Direction::Maximize, sampler); + /// assert_eq!(study.direction(), Direction::Maximize); + /// ``` + pub fn with_sampler(direction: Direction, sampler: impl Sampler + 'static) -> Self + where + V: Send + Sync + 'static, + { + Self::with_sampler_and_storage( + direction, + sampler, + crate::storage::MemoryStorage::<V>::new(), + ) + } + + /// Build a trial factory for sampler integration when `V = f64`. + pub(crate) fn make_trial_factory( + sampler: &Arc<dyn Sampler>, + storage: &Arc<dyn crate::storage::Storage<V>>, + pruner: &Arc<dyn Pruner>, + ) -> Option<Arc<dyn Fn(u64) -> Trial + Send + Sync>> + where + V: 'static, + { + // Try to downcast the storage's trial buffer to the f64 specialization. + // This succeeds only when V = f64, enabling automatic sampler integration. + let trials_arc = storage.trials_arc(); + let any_ref: &dyn Any = trials_arc; + let f64_trials: Option<&Arc<RwLock<Vec<CompletedTrial<f64>>>>> = any_ref.downcast_ref(); + + f64_trials.map(|trials| { + let sampler = Arc::clone(sampler); + let trials = Arc::clone(trials); + let pruner = Arc::clone(pruner); + let factory: Arc<dyn Fn(u64) -> Trial + Send + Sync> = Arc::new(move |id| { + Trial::with_sampler( + id, + Arc::clone(&sampler), + Arc::clone(&trials), + Arc::clone(&pruner), + ) + }); + factory + }) + } + + /// Create a study with a custom sampler and storage backend. + /// + /// This is the most general constructor — all other constructors + /// delegate to this one. Use it when you need a non-default storage + /// backend (e.g., [`JournalStorage`](crate::storage::JournalStorage)). + /// + /// # Arguments + /// + /// * `direction` - Whether to minimize or maximize the objective function. + /// * `sampler` - The sampler to use for parameter sampling. 
+ /// * `storage` - The storage backend for completed trials. + /// + /// # Examples + /// + /// ``` + /// use optimizer::sampler::random::RandomSampler; + /// use optimizer::storage::MemoryStorage; + /// use optimizer::{Direction, Study}; + /// + /// let storage = MemoryStorage::<f64>::new(); + /// let study = Study::with_sampler_and_storage(Direction::Minimize, RandomSampler::new(), storage); + /// ``` + pub fn with_sampler_and_storage( + direction: Direction, + sampler: impl Sampler + 'static, + storage: impl crate::storage::Storage<V> + 'static, + ) -> Self + where + V: 'static, + { + let sampler: Arc<dyn Sampler> = Arc::new(sampler); + let pruner: Arc<dyn Pruner> = Arc::new(NopPruner); + let storage: Arc<dyn crate::storage::Storage<V>> = Arc::new(storage); + let trial_factory = Self::make_trial_factory(&sampler, &storage, &pruner); + + Self { + direction, + sampler, + pruner, + storage, + trial_factory, + enqueued_params: Arc::new(Mutex::new(VecDeque::new())), + } + } + + /// Return the optimization direction. + #[must_use] + pub fn direction(&self) -> Direction { + self.direction + } + + /// Creates a study with a custom sampler and pruner. + /// + /// Uses the default [`MemoryStorage`](crate::storage::MemoryStorage) backend. + /// + /// # Arguments + /// + /// * `direction` - Whether to minimize or maximize the objective function. + /// * `sampler` - The sampler to use for parameter sampling. + /// * `pruner` - The pruner to use for trial pruning. 
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::pruner::NopPruner; + /// use optimizer::sampler::random::RandomSampler; + /// use optimizer::{Direction, Study}; + /// + /// let sampler = RandomSampler::with_seed(42); + /// let study: Study<f64> = Study::with_sampler_and_pruner(Direction::Minimize, sampler, NopPruner); + /// ``` + pub fn with_sampler_and_pruner( + direction: Direction, + sampler: impl Sampler + 'static, + pruner: impl Pruner + 'static, + ) -> Self + where + V: Send + Sync + 'static, + { + let sampler: Arc<dyn Sampler> = Arc::new(sampler); + let pruner: Arc<dyn Pruner> = Arc::new(pruner); + let storage: Arc<dyn crate::storage::Storage<V>> = + Arc::new(crate::storage::MemoryStorage::<V>::new()); + let trial_factory = Self::make_trial_factory(&sampler, &storage, &pruner); + + Self { + direction, + sampler, + pruner, + storage, + trial_factory, + enqueued_params: Arc::new(Mutex::new(VecDeque::new())), + } + } + + /// Replace the sampler used for future parameter suggestions. + /// + /// The new sampler takes effect for all subsequent calls to + /// [`create_trial`](Self::create_trial), [`ask`](Self::ask), and the + /// `optimize*` family. Already-completed trials are unaffected. + /// + /// # Examples + /// + /// ``` + /// use optimizer::sampler::tpe::TpeSampler; + /// use optimizer::{Direction, Study}; + /// + /// let mut study: Study<f64> = Study::new(Direction::Minimize); + /// study.set_sampler(TpeSampler::new()); + /// ``` + pub fn set_sampler(&mut self, sampler: impl Sampler + 'static) + where + V: 'static, + { + self.sampler = Arc::new(sampler); + self.trial_factory = Self::make_trial_factory(&self.sampler, &self.storage, &self.pruner); + } + + /// Replace the pruner used for future trials. + /// + /// The new pruner takes effect for all trials created after this call. 
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::prelude::*; + /// + /// let mut study: Study<f64> = Study::new(Direction::Minimize); + /// study.set_pruner(MedianPruner::new(Direction::Minimize)); + /// ``` + pub fn set_pruner(&mut self, pruner: impl Pruner + 'static) + where + V: 'static, + { + self.pruner = Arc::new(pruner); + self.trial_factory = Self::make_trial_factory(&self.sampler, &self.storage, &self.pruner); + } + + /// Return a reference to the study's current pruner. + #[must_use] + pub fn pruner(&self) -> &dyn Pruner { + &*self.pruner + } + + /// Enqueue a specific parameter configuration to be evaluated next. + /// + /// The next call to [`ask()`](Self::ask) or the next trial in [`optimize()`](Self::optimize) + /// will use these exact parameters instead of sampling from the sampler. + /// + /// Multiple configurations can be enqueued; they are evaluated in FIFO order. + /// If an enqueued configuration is missing a parameter that the objective calls + /// `suggest()` on, that parameter falls back to normal sampling. + /// + /// # Arguments + /// + /// * `params` - A map from parameter IDs to the values to use. 
+ /// + /// # Examples + /// + /// ``` + /// use std::collections::HashMap; + /// + /// use optimizer::parameter::{FloatParam, IntParam, ParamValue, Parameter}; + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// let x = FloatParam::new(0.0, 10.0); + /// let y = IntParam::new(1, 100); + /// + /// // Evaluate these specific configurations first + /// study.enqueue(HashMap::from([ + /// (x.id(), ParamValue::Float(0.001)), + /// (y.id(), ParamValue::Int(3)), + /// ])); + /// + /// // Next trial will use x=0.001, y=3 + /// let mut trial = study.ask(); + /// assert_eq!(x.suggest(&mut trial).unwrap(), 0.001); + /// assert_eq!(y.suggest(&mut trial).unwrap(), 3); + /// ``` + pub fn enqueue(&self, params: HashMap<ParamId, ParamValue>) { + self.enqueued_params.lock().push_back(params); + } + + /// Return the trial ID of the current best trial from the given slice. + #[cfg(feature = "tracing")] + pub(crate) fn best_id(&self, trials: &[CompletedTrial<V>]) -> Option<u64> { + let direction = self.direction; + trials + .iter() + .filter(|t| t.state == TrialState::Complete) + .max_by(|a, b| Self::compare_trials(a, b, direction)) + .map(|t| t.id) + } + + /// Return the number of enqueued parameter configurations. + /// + /// See [`enqueue`](Self::enqueue) for how to add configurations. + #[must_use] + pub fn n_enqueued(&self) -> usize { + self.enqueued_params.lock().len() + } + + /// Generate the next unique trial ID. + pub(crate) fn next_trial_id(&self) -> u64 { + self.storage.next_trial_id() + } + + /// Create a new trial with a unique ID. + /// + /// The trial starts in the `Running` state and can be used to suggest + /// parameter values. After the objective function is evaluated, call + /// `complete_trial` or `fail_trial` to record the result. 
+ /// + /// For `Study<f64>`, this method automatically integrates with the study's + /// sampler and trial history, so there is no need to call a separate + /// `create_trial_with_sampler()` method. + /// + /// # Examples + /// + /// ``` + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// let trial = study.create_trial(); + /// assert_eq!(trial.id(), 0); + /// + /// let trial2 = study.create_trial(); + /// assert_eq!(trial2.id(), 1); + /// ``` + #[must_use] + pub fn create_trial(&self) -> Trial { + self.storage.refresh(); + + let id = self.next_trial_id(); + let mut trial = if let Some(factory) = &self.trial_factory { + factory(id) + } else { + Trial::new(id) + }; + + // If there are enqueued params, inject them into this trial + if let Some(fixed_params) = self.enqueued_params.lock().pop_front() { + trial.set_fixed_params(fixed_params); + } + + trial + } + + /// Record a completed trial with its objective value. + /// + /// This method stores the trial's parameters, distributions, and objective + /// value in the study's history. The stored data is used by samplers to + /// inform future parameter suggestions. + /// + /// # Arguments + /// + /// * `trial` - The trial that was evaluated. + /// * `value` - The objective value returned by the objective function. 
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::parameter::{FloatParam, Parameter}; + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// let x_param = FloatParam::new(0.0, 1.0); + /// let mut trial = study.create_trial(); + /// let x = x_param.suggest(&mut trial).unwrap(); + /// let objective_value = x * x; + /// study.complete_trial(trial, objective_value); + /// + /// assert_eq!(study.n_trials(), 1); + /// ``` + pub fn complete_trial(&self, trial: Trial, value: V) { + let completed = trial.into_completed(value, TrialState::Complete); + self.storage.push(completed); + } + + /// Record a failed trial with an error message. + /// + /// Failed trials are not stored in the study's history and do not + /// contribute to future sampling decisions. This method is useful + /// when the objective function raises an error that should not stop + /// the optimization process. + /// + /// # Arguments + /// + /// * `trial` - The trial that failed. + /// * `_error` - An error message describing why the trial failed. + /// + /// # Examples + /// + /// ``` + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// let trial = study.create_trial(); + /// study.fail_trial(trial, "objective function raised an exception"); + /// + /// // Failed trials are not counted + /// assert_eq!(study.n_trials(), 0); + /// ``` + pub fn fail_trial(&self, mut trial: Trial, _error: impl ToString) { + trial.set_failed(); + // Failed trials are not stored in completed_trials + // They could be stored in a separate list for debugging if needed + } + + /// Request a new trial with suggested parameters. + /// + /// This is the first half of the ask-and-tell interface. After calling + /// `ask()`, use parameter types to suggest values on the returned trial, + /// evaluate your objective externally, then pass the trial back to + /// [`tell()`](Self::tell) with the result. 
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::parameter::{FloatParam, Parameter}; + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// let x = FloatParam::new(0.0, 10.0); + /// + /// let mut trial = study.ask(); + /// let x_val = x.suggest(&mut trial).unwrap(); + /// let value = x_val * x_val; + /// study.tell(trial, Ok::<_, &str>(value)); + /// ``` + #[must_use] + pub fn ask(&self) -> Trial { + self.create_trial() + } + + /// Report the result of a trial obtained from [`ask()`](Self::ask). + /// + /// Pass `Ok(value)` for a successful evaluation or `Err(reason)` for a + /// failure. Failed trials are not stored in the study's history. + /// + /// # Examples + /// + /// ``` + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// + /// let trial = study.ask(); + /// study.tell(trial, Ok::<_, &str>(42.0)); + /// assert_eq!(study.n_trials(), 1); + /// + /// let trial = study.ask(); + /// study.tell(trial, Err::<f64, _>("evaluation failed")); + /// assert_eq!(study.n_trials(), 1); // failed trials not counted + /// ``` + pub fn tell(&self, trial: Trial, value: core::result::Result<V, impl ToString>) { + match value { + Ok(v) => self.complete_trial(trial, v), + Err(e) => self.fail_trial(trial, e), + } + } + + /// Record a pruned trial, preserving its intermediate values. + /// + /// Pruned trials are stored alongside completed trials so that samplers + /// can optionally learn from partial evaluations. The trial's state is + /// set to [`Pruned`](crate::TrialState::Pruned). + /// + /// In practice you rarely call this directly — returning + /// `Err(TrialPruned)` from an objective function handles pruning + /// automatically. + /// + /// # Arguments + /// + /// * `trial` - The trial that was pruned. 
+ pub fn prune_trial(&self, trial: Trial) + where + V: Default, + { + let completed = trial.into_completed(V::default(), TrialState::Pruned); + self.storage.push(completed); + } + + /// Return all completed trials as a `Vec`. + /// + /// The returned vector contains clones of `CompletedTrial` values, which contain + /// the trial's parameters, distributions, and objective value. + /// + /// Note: This method acquires a read lock on the completed trials, so the + /// returned vector is a clone of the internal storage. + /// + /// # Examples + /// + /// ``` + /// use optimizer::parameter::{FloatParam, Parameter}; + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// let x_param = FloatParam::new(0.0, 1.0); + /// let mut trial = study.create_trial(); + /// let _ = x_param.suggest(&mut trial); + /// study.complete_trial(trial, 0.5); + /// + /// for completed in study.trials() { + /// println!("Trial {} has value {:?}", completed.id, completed.value); + /// } + /// ``` + #[must_use] + pub fn trials(&self) -> Vec<CompletedTrial<V>> + where + V: Clone, + { + self.storage.trials_arc().read().clone() + } + + /// Return the number of completed trials. + /// + /// Failed trials are not counted. + /// + /// # Examples + /// + /// ``` + /// use optimizer::parameter::{FloatParam, Parameter}; + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = Study::new(Direction::Minimize); + /// assert_eq!(study.n_trials(), 0); + /// + /// let x_param = FloatParam::new(0.0, 1.0); + /// let mut trial = study.create_trial(); + /// let _ = x_param.suggest(&mut trial); + /// study.complete_trial(trial, 0.5); + /// assert_eq!(study.n_trials(), 1); + /// ``` + #[must_use] + pub fn n_trials(&self) -> usize { + self.storage.trials_arc().read().len() + } + + /// Return the number of pruned trials. + /// + /// Pruned trials are those that were stopped early by the pruner. 
+ #[must_use] + pub fn n_pruned_trials(&self) -> usize { + self.storage + .trials_arc() + .read() + .iter() + .filter(|t| t.state == TrialState::Pruned) + .count() + } + + /// Compare two completed trials using constraint-aware ranking. + /// + /// 1. Feasible trials always rank above infeasible trials. + /// 2. Among feasible trials, rank by objective value (respecting direction). + /// 3. Among infeasible trials, rank by total constraint violation (lower is better). + pub(crate) fn compare_trials( + a: &CompletedTrial<V>, + b: &CompletedTrial<V>, + direction: Direction, + ) -> core::cmp::Ordering { + match (a.is_feasible(), b.is_feasible()) { + (true, false) => core::cmp::Ordering::Greater, + (false, true) => core::cmp::Ordering::Less, + (false, false) => { + let va: f64 = a.constraints.iter().map(|c| c.max(0.0)).sum(); + let vb: f64 = b.constraints.iter().map(|c| c.max(0.0)).sum(); + vb.partial_cmp(&va).unwrap_or(core::cmp::Ordering::Equal) + } + (true, true) => { + let ordering = a.value.partial_cmp(&b.value); + match direction { + Direction::Minimize => { + ordering.map_or(core::cmp::Ordering::Equal, core::cmp::Ordering::reverse) + } + Direction::Maximize => ordering.unwrap_or(core::cmp::Ordering::Equal), + } + } + } + } +} + +impl<V: PartialOrd + Send + Sync + 'static> Study<V> { + /// Create a study with a custom sampler, pruner, and storage backend. + /// + /// The most flexible constructor, allowing full control over all components. + /// + /// # Arguments + /// + /// * `direction` - Whether to minimize or maximize the objective function. + /// * `sampler` - The sampler to use for parameter sampling. + /// * `pruner` - The pruner to use for trial pruning. + /// * `storage` - The storage backend for completed trials. 
+ /// + /// # Examples + /// + /// ``` + /// use optimizer::prelude::*; + /// use optimizer::storage::MemoryStorage; + /// + /// let study = Study::with_sampler_pruner_and_storage( + /// Direction::Minimize, + /// TpeSampler::new(), + /// MedianPruner::new(Direction::Minimize), + /// MemoryStorage::<f64>::new(), + /// ); + /// ``` + pub fn with_sampler_pruner_and_storage( + direction: Direction, + sampler: impl Sampler + 'static, + pruner: impl Pruner + 'static, + storage: impl crate::storage::Storage<V> + 'static, + ) -> Self { + let sampler: Arc<dyn Sampler> = Arc::new(sampler); + let pruner: Arc<dyn Pruner> = Arc::new(pruner); + let storage: Arc<dyn crate::storage::Storage<V>> = Arc::new(storage); + let trial_factory = Self::make_trial_factory(&sampler, &storage, &pruner); + + Self { + direction, + sampler, + pruner, + storage, + trial_factory, + enqueued_params: Arc::new(Mutex::new(VecDeque::new())), + } + } +} + +/// Returns `true` if the error represents a pruned trial. +/// +/// Checks via `Any` downcasting whether `e` is `Error::TrialPruned` or +/// the standalone `TrialPruned` struct. +pub(super) fn is_trial_pruned<E: 'static>(e: &E) -> bool { + let any: &dyn Any = e; + if let Some(err) = any.downcast_ref::<crate::Error>() { + matches!(err, crate::Error::TrialPruned) + } else { + any.downcast_ref::<crate::error::TrialPruned>().is_some() + } +} diff --git a/src/study/optimize.rs b/src/study/optimize.rs new file mode 100644 index 0000000..8ffb8e8 --- /dev/null +++ b/src/study/optimize.rs @@ -0,0 +1,123 @@ +use core::ops::ControlFlow; + +use crate::types::TrialState; + +use super::{Study, is_trial_pruned}; + +impl<V> Study<V> +where + V: PartialOrd, +{ + /// Run optimization with an objective. + /// + /// Accepts any [`Objective`](crate::Objective) implementation, including + /// plain closures (`Fn(&mut Trial) -> Result<V, E>`) thanks to the + /// blanket impl. 
Struct-based objectives can override + /// [`before_trial`](crate::Objective::before_trial) and + /// [`after_trial`](crate::Objective::after_trial) for early stopping. + /// + /// Runs up to `n_trials` evaluations sequentially. + /// + /// # Errors + /// + /// Returns `Error::NoCompletedTrials` if no trials completed successfully. + /// + /// # Examples + /// + /// ``` + /// use optimizer::parameter::{FloatParam, Parameter}; + /// use optimizer::sampler::random::RandomSampler; + /// use optimizer::{Direction, Study}; + /// + /// let sampler = RandomSampler::with_seed(42); + /// let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + /// let x_param = FloatParam::new(-10.0, 10.0); + /// + /// study + /// .optimize(10, |trial: &mut optimizer::Trial| { + /// let x = x_param.suggest(trial)?; + /// Ok::<_, optimizer::Error>(x * x) + /// }) + /// .unwrap(); + /// + /// assert!(study.n_trials() > 0); + /// assert!(study.best_value().unwrap() >= 0.0); + /// ``` + #[allow(clippy::needless_pass_by_value)] + pub fn optimize( + &self, + n_trials: usize, + objective: impl crate::objective::Objective<V>, + ) -> crate::Result<()> + where + V: Clone + Default, + { + #[cfg(feature = "tracing")] + let _span = + tracing::info_span!("optimize", n_trials, direction = ?self.direction).entered(); + + for _ in 0..n_trials { + if let ControlFlow::Break(()) = objective.before_trial(self) { + break; + } + + let mut trial = self.create_trial(); + match objective.evaluate(&mut trial) { + Ok(value) => { + #[cfg(feature = "tracing")] + let trial_id = trial.id(); + + let completed = trial.into_completed(value, TrialState::Complete); + + // Fire after_trial hook before pushing to storage + let flow = objective.after_trial(self, &completed); + self.storage.push(completed); + + #[cfg(feature = "tracing")] + { + tracing::info!(trial_id, "trial completed"); + let trials = self.storage.trials_arc().read(); + if trials + .iter() + .filter(|t| t.state == TrialState::Complete) + 
.count() + == 1 + || trials.last().map(|t| t.id) == self.best_id(&trials) + { + tracing::info!(trial_id, "new best value found"); + } + } + + if let ControlFlow::Break(()) = flow { + return Ok(()); + } + } + Err(e) if is_trial_pruned(&e) => { + #[cfg(feature = "tracing")] + let trial_id = trial.id(); + self.prune_trial(trial); + trace_info!(trial_id, "trial pruned"); + } + Err(e) => { + #[cfg(feature = "tracing")] + let trial_id = trial.id(); + self.fail_trial(trial, e.to_string()); + trace_debug!(trial_id, "trial failed"); + } + } + } + + // Return error if no trials completed successfully + let has_complete = self + .storage + .trials_arc() + .read() + .iter() + .any(|t| t.state == TrialState::Complete); + if !has_complete { + return Err(crate::Error::NoCompletedTrials); + } + + Ok(()) + } +} diff --git a/src/study/persistence.rs b/src/study/persistence.rs new file mode 100644 index 0000000..78b94cf --- /dev/null +++ b/src/study/persistence.rs @@ -0,0 +1,148 @@ +use std::collections::HashMap; + +use crate::sampler::CompletedTrial; +use crate::types::Direction; + +use super::Study; + +/// A serializable snapshot of a study's state. +/// +/// Since [`Study`] contains non-serializable fields (samplers, atomics, etc.), +/// this struct captures the essential state needed to save and restore a study. +/// +/// # Schema versioning +/// +/// The `version` field enables future schema evolution without breaking existing files. +/// The current version is `1`. +/// +/// # Sampler state +/// +/// Sampler state is **not** included in the snapshot. After loading, the study +/// uses a default `RandomSampler`. Call [`Study::set_sampler`] to restore +/// the desired sampler configuration. +#[cfg(feature = "serde")] +#[derive(serde::Serialize, serde::Deserialize)] +pub struct StudySnapshot<V> { + /// Schema version for forward compatibility. + pub version: u32, + /// The optimization direction. + pub direction: Direction, + /// All completed (and pruned) trials. 
+ pub trials: Vec<CompletedTrial<V>>, + /// The next trial ID to assign. + pub next_trial_id: u64, + /// Optional metadata (creation timestamp, sampler description, etc.). + pub metadata: HashMap<String, String>, +} + +#[cfg(feature = "serde")] +impl<V: PartialOrd + Clone + serde::Serialize> Study<V> { + /// Export trials as a pretty-printed JSON array to a file. + /// + /// Each element in the array is a serialized [`CompletedTrial`]. + /// Requires the `serde` feature. + /// + /// # Errors + /// + /// Returns an I/O error if the file cannot be created or written. + pub fn export_json(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> { + let file = std::fs::File::create(path)?; + let trials = self.trials(); + serde_json::to_writer_pretty(file, &trials).map_err(std::io::Error::other) + } + + /// Save the study state to a JSON file. + /// + /// # Errors + /// + /// Returns an I/O error if the file cannot be created or written. + pub fn save(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> { + let path = path.as_ref(); + let trials = self.trials(); + let next_trial_id = trials.iter().map(|t| t.id).max().map_or(0, |id| id + 1); + let snapshot = StudySnapshot { + version: 1, + direction: self.direction, + trials, + next_trial_id, + metadata: HashMap::new(), + }; + + // Atomic write: write to a temp file in the same directory, then rename. + // This prevents corrupt files if the process crashes mid-write. + let parent = path.parent().unwrap_or(std::path::Path::new(".")); + let tmp_path = parent.join(format!( + ".{}.tmp", + path.file_name().unwrap_or_default().to_string_lossy() + )); + let file = std::fs::File::create(&tmp_path)?; + serde_json::to_writer_pretty(file, &snapshot).map_err(std::io::Error::other)?; + std::fs::rename(&tmp_path, path) + } +} + +#[cfg(feature = "serde")] +impl<V: PartialOrd + Send + Sync + Clone + serde::de::DeserializeOwned + 'static> Study<V> { + /// Load a study from a JSON file. 
+ /// + /// The loaded study uses a `RandomSampler` by default. Call + /// [`set_sampler()`](Self::set_sampler) to restore the original sampler + /// configuration. + /// + /// # Errors + /// + /// Returns an I/O error if the file cannot be read or parsed. + pub fn load(path: impl AsRef<std::path::Path>) -> std::io::Result<Self> { + use crate::sampler::random::RandomSampler; + + let file = std::fs::File::open(path)?; + let snapshot: StudySnapshot<V> = serde_json::from_reader(file) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + let storage = crate::storage::MemoryStorage::with_trials(snapshot.trials); + Ok(Self::with_sampler_and_storage( + snapshot.direction, + RandomSampler::new(), + storage, + )) + } +} + +#[cfg(feature = "journal")] +impl<V> Study<V> +where + V: PartialOrd + Send + Sync + serde::Serialize + serde::de::DeserializeOwned + 'static, +{ + /// Create a study backed by a JSONL journal file. + /// + /// Any existing trials in the file are loaded into memory and the + /// trial ID counter is set to one past the highest stored ID. New + /// trials are written through to the file on completion. + /// + /// # Arguments + /// + /// * `direction` - Whether to minimize or maximize the objective function. + /// * `sampler` - The sampler to use for parameter sampling. + /// * `path` - Path to the JSONL journal file (created if absent). + /// + /// # Errors + /// + /// Returns a [`Storage`](crate::Error::Storage) error if loading fails. 
+ /// + /// # Examples + /// + /// ```no_run + /// use optimizer::sampler::tpe::TpeSampler; + /// use optimizer::{Direction, Study}; + /// + /// let study: Study<f64> = + /// Study::with_journal(Direction::Minimize, TpeSampler::new(), "trials.jsonl").unwrap(); + /// ``` + pub fn with_journal( + direction: Direction, + sampler: impl crate::sampler::Sampler + 'static, + path: impl AsRef<std::path::Path>, + ) -> crate::Result<Self> { + let storage = crate::storage::JournalStorage::<V>::open(path)?; + Ok(Self::with_sampler_and_storage(direction, sampler, storage)) + } +} From d80d2bb1fbca72284e2538fc540d912aa0197467 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 15:25:27 +0100 Subject: [PATCH 29/48] refactor(sampler): extract shared utilities to reduce duplication - Add sampler/common.rs with distribution helpers (internal_bounds, from_internal, to_internal, sample_random) used by 8 samplers - Add sampler/tpe/common.rs with TPE sampling functions (sample_tpe_float, sample_tpe_int, sample_tpe_categorical) shared by TpeSampler, MultivariateTpeSampler, and MotpeSampler - Remove ~940 lines of near-identical code across sampler modules --- src/sampler/cma_es.rs | 101 +---------- src/sampler/common.rs | 116 +++++++++++++ src/sampler/de.rs | 91 +--------- src/sampler/genetic.rs | 37 +---- src/sampler/gp.rs | 113 +------------ src/sampler/mod.rs | 1 + src/sampler/motpe.rs | 265 ++--------------------------- src/sampler/random.rs | 45 +---- src/sampler/tpe/common.rs | 218 ++++++++++++++++++++++++ src/sampler/tpe/mod.rs | 1 + src/sampler/tpe/multivariate.rs | 57 +------ src/sampler/tpe/sampler.rs | 285 +++----------------------------- 12 files changed, 385 insertions(+), 945 deletions(-) create mode 100644 src/sampler/common.rs create mode 100644 src/sampler/tpe/common.rs diff --git a/src/sampler/cma_es.rs b/src/sampler/cma_es.rs index f243a88..ac8b501 100644 --- a/src/sampler/cma_es.rs +++ b/src/sampler/cma_es.rs @@ -77,6 +77,8 @@ 
use crate::param::ParamValue; use crate::rng_util; use crate::sampler::{CompletedTrial, Sampler}; +use super::common::{from_internal, internal_bounds, sample_random}; + /// CMA-ES sampler for continuous optimization. /// /// Adapts a multivariate Gaussian to concentrate around promising regions @@ -671,58 +673,6 @@ fn clip_to_bounds(x: &mut DVector<f64>, dimensions: &[DimensionInfo]) { } } -/// Convert an internal-space value back to a `ParamValue`. -#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -fn from_internal(value: f64, distribution: &Distribution) -> ParamValue { - match distribution { - Distribution::Float(d) => { - let v = if d.log_scale { value.exp() } else { value }; - let v = if let Some(step) = d.step { - let k = ((v - d.low) / step).round(); - d.low + k * step - } else { - v - }; - ParamValue::Float(v.clamp(d.low, d.high)) - } - Distribution::Int(d) => { - let v = if d.log_scale { value.exp() } else { value }; - let v = if let Some(step) = d.step { - let k = ((v - d.low as f64) / step as f64).round() as i64; - d.low + k * step - } else { - v.round() as i64 - }; - ParamValue::Int(v.clamp(d.low, d.high)) - } - Distribution::Categorical(_) => { - unreachable!("from_internal should not be called for categorical distributions") - } - } -} - -/// Compute internal-space bounds for a distribution. -#[allow(clippy::cast_precision_loss)] -fn internal_bounds(distribution: &Distribution) -> Option<(f64, f64)> { - match distribution { - Distribution::Float(d) => { - if d.log_scale { - Some((d.low.ln(), d.high.ln())) - } else { - Some((d.low, d.high)) - } - } - Distribution::Int(d) => { - if d.log_scale { - Some(((d.low as f64).ln(), (d.high as f64).ln())) - } else { - Some((d.low as f64, d.high as f64)) - } - } - Distribution::Categorical(_) => None, - } -} - /// Sample a value from the standard normal distribution using Box-Muller transform. 
fn sample_standard_normal(rng: &mut fastrand::Rng) -> f64 { // Box-Muller transform @@ -731,51 +681,6 @@ fn sample_standard_normal(rng: &mut fastrand::Rng) -> f64 { (-2.0 * u1.ln()).sqrt() * u2.cos() } -/// Sample a categorical value randomly. -fn sample_random_categorical(rng: &mut fastrand::Rng, distribution: &Distribution) -> ParamValue { - match distribution { - Distribution::Categorical(d) => ParamValue::Categorical(rng.usize(0..d.n_choices)), - _ => unreachable!("sample_random_categorical called with non-categorical distribution"), - } -} - -/// Sample a random value for any distribution (used during discovery phase). -#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -fn sample_random(rng: &mut fastrand::Rng, distribution: &Distribution) -> ParamValue { - match distribution { - Distribution::Float(d) => { - let value = if d.log_scale { - let log_low = d.low.ln(); - let log_high = d.high.ln(); - rng_util::f64_range(rng, log_low, log_high).exp() - } else if let Some(step) = d.step { - let n_steps = ((d.high - d.low) / step).floor() as i64; - let k = rng.i64(0..=n_steps); - d.low + (k as f64) * step - } else { - rng_util::f64_range(rng, d.low, d.high) - }; - ParamValue::Float(value) - } - Distribution::Int(d) => { - let value = if d.log_scale { - let log_low = (d.low as f64).ln(); - let log_high = (d.high as f64).ln(); - let raw = rng_util::f64_range(rng, log_low, log_high).exp().round() as i64; - raw.clamp(d.low, d.high) - } else if let Some(step) = d.step { - let n_steps = (d.high - d.low) / step; - let k = rng.i64(0..=n_steps); - d.low + k * step - } else { - rng.i64(d.low..=d.high) - }; - ParamValue::Int(value) - } - Distribution::Categorical(d) => ParamValue::Categorical(rng.usize(0..d.n_choices)), - } -} - // --------------------------------------------------------------------------- // Sampler trait implementation // --------------------------------------------------------------------------- @@ -920,7 +825,7 @@ fn sample_active( if 
let Some(&cat_idx) = candidate.categorical_values.get(&dim_idx) { ParamValue::Categorical(cat_idx) } else { - sample_random_categorical(&mut state.rng, distribution) + sample_random(&mut state.rng, distribution) } } } diff --git a/src/sampler/common.rs b/src/sampler/common.rs new file mode 100644 index 0000000..a2640bd --- /dev/null +++ b/src/sampler/common.rs @@ -0,0 +1,116 @@ +//! Shared distribution-level utilities used across multiple samplers. + +use crate::distribution::Distribution; +use crate::param::ParamValue; +use crate::rng_util; + +/// Compute internal-space bounds for a distribution. +#[allow(clippy::cast_precision_loss)] +pub(crate) fn internal_bounds(distribution: &Distribution) -> Option<(f64, f64)> { + match distribution { + Distribution::Float(d) => { + if d.log_scale { + Some((d.low.ln(), d.high.ln())) + } else { + Some((d.low, d.high)) + } + } + Distribution::Int(d) => { + if d.log_scale { + Some(((d.low as f64).ln(), (d.high as f64).ln())) + } else { + Some((d.low as f64, d.high as f64)) + } + } + Distribution::Categorical(_) => None, + } +} + +/// Convert an internal-space value back to a `ParamValue`. 
+#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] +pub(crate) fn from_internal(value: f64, distribution: &Distribution) -> ParamValue { + match distribution { + Distribution::Float(d) => { + let v = if d.log_scale { value.exp() } else { value }; + let v = if let Some(step) = d.step { + let k = ((v - d.low) / step).round(); + d.low + k * step + } else { + v + }; + ParamValue::Float(v.clamp(d.low, d.high)) + } + Distribution::Int(d) => { + let v = if d.log_scale { value.exp() } else { value }; + let v = if let Some(step) = d.step { + let k = ((v - d.low as f64) / step as f64).round() as i64; + d.low + k * step + } else { + v.round() as i64 + }; + ParamValue::Int(v.clamp(d.low, d.high)) + } + Distribution::Categorical(_) => { + unreachable!("from_internal should not be called for categorical distributions") + } + } +} + +/// Convert a `ParamValue` to its internal-space representation. +#[allow(clippy::cast_precision_loss, dead_code)] +pub(crate) fn to_internal(value: &ParamValue, distribution: &Distribution) -> f64 { + match (value, distribution) { + (ParamValue::Float(v), Distribution::Float(d)) => { + if d.log_scale { + v.ln() + } else { + *v + } + } + (ParamValue::Int(v), Distribution::Int(d)) => { + if d.log_scale { + (*v as f64).ln() + } else { + *v as f64 + } + } + _ => 0.0, + } +} + +/// Sample a random value for any distribution. 
+#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] +pub(crate) fn sample_random(rng: &mut fastrand::Rng, distribution: &Distribution) -> ParamValue { + match distribution { + Distribution::Float(d) => { + let value = if d.log_scale { + let log_low = d.low.ln(); + let log_high = d.high.ln(); + rng_util::f64_range(rng, log_low, log_high).exp() + } else if let Some(step) = d.step { + let n_steps = ((d.high - d.low) / step).floor() as i64; + let k = rng.i64(0..=n_steps); + d.low + (k as f64) * step + } else { + rng_util::f64_range(rng, d.low, d.high) + }; + ParamValue::Float(value) + } + Distribution::Int(d) => { + let value = if d.log_scale { + let log_low = (d.low as f64).ln(); + let log_high = (d.high as f64).ln(); + let raw = rng_util::f64_range(rng, log_low, log_high).exp().round() as i64; + raw.clamp(d.low, d.high) + } else if let Some(step) = d.step { + let n_steps = (d.high - d.low) / step; + let k = rng.i64(0..=n_steps); + d.low + k * step + } else { + rng.i64(d.low..=d.high) + }; + ParamValue::Int(value) + } + Distribution::Categorical(d) => ParamValue::Categorical(rng.usize(0..d.n_choices)), + } +} diff --git a/src/sampler/de.rs b/src/sampler/de.rs index bd6be26..274078a 100644 --- a/src/sampler/de.rs +++ b/src/sampler/de.rs @@ -70,6 +70,8 @@ use crate::param::ParamValue; use crate::rng_util; use crate::sampler::{CompletedTrial, Sampler}; +use super::common::{from_internal, internal_bounds, sample_random}; + /// Differential Evolution mutation strategy. /// /// Controls how mutant vectors are created from the current population. @@ -396,95 +398,6 @@ impl State { // Helpers // --------------------------------------------------------------------------- -/// Compute internal-space bounds for a distribution. 
-#[allow(clippy::cast_precision_loss)] -fn internal_bounds(distribution: &Distribution) -> Option<(f64, f64)> { - match distribution { - Distribution::Float(d) => { - if d.log_scale { - Some((d.low.ln(), d.high.ln())) - } else { - Some((d.low, d.high)) - } - } - Distribution::Int(d) => { - if d.log_scale { - Some(((d.low as f64).ln(), (d.high as f64).ln())) - } else { - Some((d.low as f64, d.high as f64)) - } - } - Distribution::Categorical(_) => None, - } -} - -/// Convert an internal-space value back to a `ParamValue`. -#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -fn from_internal(value: f64, distribution: &Distribution) -> ParamValue { - match distribution { - Distribution::Float(d) => { - let v = if d.log_scale { value.exp() } else { value }; - let v = if let Some(step) = d.step { - let k = ((v - d.low) / step).round(); - d.low + k * step - } else { - v - }; - ParamValue::Float(v.clamp(d.low, d.high)) - } - Distribution::Int(d) => { - let v = if d.log_scale { value.exp() } else { value }; - let v = if let Some(step) = d.step { - let k = ((v - d.low as f64) / step as f64).round() as i64; - d.low + k * step - } else { - v.round() as i64 - }; - ParamValue::Int(v.clamp(d.low, d.high)) - } - Distribution::Categorical(_) => { - unreachable!("from_internal should not be called for categorical distributions") - } - } -} - -/// Sample a random value for any distribution. 
-#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -fn sample_random(rng: &mut fastrand::Rng, distribution: &Distribution) -> ParamValue { - match distribution { - Distribution::Float(d) => { - let value = if d.log_scale { - let log_low = d.low.ln(); - let log_high = d.high.ln(); - rng_util::f64_range(rng, log_low, log_high).exp() - } else if let Some(step) = d.step { - let n_steps = ((d.high - d.low) / step).floor() as i64; - let k = rng.i64(0..=n_steps); - d.low + (k as f64) * step - } else { - rng_util::f64_range(rng, d.low, d.high) - }; - ParamValue::Float(value) - } - Distribution::Int(d) => { - let value = if d.log_scale { - let log_low = (d.low as f64).ln(); - let log_high = (d.high as f64).ln(); - let raw = rng_util::f64_range(rng, log_low, log_high).exp().round() as i64; - raw.clamp(d.low, d.high) - } else if let Some(step) = d.step { - let n_steps = (d.high - d.low) / step; - let k = rng.i64(0..=n_steps); - d.low + k * step - } else { - rng.i64(d.low..=d.high) - }; - ParamValue::Int(value) - } - Distribution::Categorical(d) => ParamValue::Categorical(rng.usize(0..d.n_choices)), - } -} - /// Sample a random value in internal space for a continuous dimension. fn sample_random_internal(rng: &mut fastrand::Rng, bounds: (f64, f64)) -> f64 { rng_util::f64_range(rng, bounds.0, bounds.1) diff --git a/src/sampler/genetic.rs b/src/sampler/genetic.rs index d2a539a..d6d6fa2 100644 --- a/src/sampler/genetic.rs +++ b/src/sampler/genetic.rs @@ -406,42 +406,7 @@ pub(crate) fn polynomial_mutation_f64( (x + delta_q * range).clamp(low, high) } -/// Random sampling for a single distribution. 
-#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -pub(crate) fn sample_random(rng: &mut fastrand::Rng, distribution: &Distribution) -> ParamValue { - match distribution { - Distribution::Float(d) => { - let value = if d.log_scale { - let log_low = d.low.ln(); - let log_high = d.high.ln(); - rng_util::f64_range(rng, log_low, log_high).exp() - } else if let Some(step) = d.step { - let n_steps = ((d.high - d.low) / step).floor() as i64; - let k = rng.i64(0..=n_steps); - d.low + (k as f64) * step - } else { - rng_util::f64_range(rng, d.low, d.high) - }; - ParamValue::Float(value) - } - Distribution::Int(d) => { - let value = if d.log_scale { - let log_low = (d.low as f64).ln(); - let log_high = (d.high as f64).ln(); - let raw = rng_util::f64_range(rng, log_low, log_high).exp().round() as i64; - raw.clamp(d.low, d.high) - } else if let Some(step) = d.step { - let n_steps = (d.high - d.low) / step; - let k = rng.i64(0..=n_steps); - d.low + k * step - } else { - rng.i64(d.low..=d.high) - }; - ParamValue::Int(value) - } - Distribution::Categorical(d) => ParamValue::Categorical(rng.usize(0..d.n_choices)), - } -} +pub(crate) use super::common::sample_random; // --------------------------------------------------------------------------- // Das-Dennis reference point generation diff --git a/src/sampler/gp.rs b/src/sampler/gp.rs index fe5d8f6..05baa3b 100644 --- a/src/sampler/gp.rs +++ b/src/sampler/gp.rs @@ -82,6 +82,8 @@ use crate::param::ParamValue; use crate::rng_util; use crate::sampler::{CompletedTrial, Sampler}; +use super::common::{from_internal, internal_bounds, sample_random, to_internal}; + // --------------------------------------------------------------------------- // Public API // --------------------------------------------------------------------------- @@ -536,58 +538,6 @@ fn optimize_acquisition( // Data preprocessing helpers // --------------------------------------------------------------------------- -/// Compute internal-space 
bounds for a distribution. -#[allow(clippy::cast_precision_loss)] -fn internal_bounds(distribution: &Distribution) -> Option<(f64, f64)> { - match distribution { - Distribution::Float(d) => { - if d.log_scale { - Some((d.low.ln(), d.high.ln())) - } else { - Some((d.low, d.high)) - } - } - Distribution::Int(d) => { - if d.log_scale { - Some(((d.low as f64).ln(), (d.high as f64).ln())) - } else { - Some((d.low as f64, d.high as f64)) - } - } - Distribution::Categorical(_) => None, - } -} - -/// Convert a value from internal space to a `ParamValue` in original space. -#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -fn from_internal(value: f64, distribution: &Distribution) -> ParamValue { - match distribution { - Distribution::Float(d) => { - let v = if d.log_scale { value.exp() } else { value }; - let v = if let Some(step) = d.step { - let k = ((v - d.low) / step).round(); - d.low + k * step - } else { - v - }; - ParamValue::Float(v.clamp(d.low, d.high)) - } - Distribution::Int(d) => { - let v = if d.log_scale { value.exp() } else { value }; - let v = if let Some(step) = d.step { - let k = ((v - d.low as f64) / step as f64).round() as i64; - d.low + k * step - } else { - v.round() as i64 - }; - ParamValue::Int(v.clamp(d.low, d.high)) - } - Distribution::Categorical(_) => { - unreachable!("from_internal should not be called for categorical distributions") - } - } -} - /// Convert an internal-space value to normalized [0, 1] using bounds. fn to_normalized(value: f64, lo: f64, hi: f64) -> f64 { if (hi - lo).abs() < 1e-15 { @@ -602,65 +552,6 @@ fn from_normalized(value: f64, lo: f64, hi: f64) -> f64 { lo + value * (hi - lo) } -/// Convert a `ParamValue` to its internal-space representation. 
-#[allow(clippy::cast_precision_loss)] -fn to_internal(value: &ParamValue, distribution: &Distribution) -> f64 { - match (value, distribution) { - (ParamValue::Float(v), Distribution::Float(d)) => { - if d.log_scale { - v.ln() - } else { - *v - } - } - (ParamValue::Int(v), Distribution::Int(d)) => { - if d.log_scale { - (*v as f64).ln() - } else { - *v as f64 - } - } - _ => 0.0, - } -} - -/// Sample a random value for any distribution. -#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)] -fn sample_random(rng: &mut fastrand::Rng, distribution: &Distribution) -> ParamValue { - match distribution { - Distribution::Float(d) => { - let value = if d.log_scale { - let log_low = d.low.ln(); - let log_high = d.high.ln(); - rng_util::f64_range(rng, log_low, log_high).exp() - } else if let Some(step) = d.step { - let n_steps = ((d.high - d.low) / step).floor() as i64; - let k = rng.i64(0..=n_steps); - d.low + (k as f64) * step - } else { - rng_util::f64_range(rng, d.low, d.high) - }; - ParamValue::Float(value) - } - Distribution::Int(d) => { - let value = if d.log_scale { - let log_low = (d.low as f64).ln(); - let log_high = (d.high as f64).ln(); - let raw = rng_util::f64_range(rng, log_low, log_high).exp().round() as i64; - raw.clamp(d.low, d.high) - } else if let Some(step) = d.step { - let n_steps = (d.high - d.low) / step; - let k = rng.i64(0..=n_steps); - d.low + k * step - } else { - rng.i64(d.low..=d.high) - }; - ParamValue::Int(value) - } - Distribution::Categorical(d) => ParamValue::Categorical(rng.usize(0..d.n_choices)), - } -} - // --------------------------------------------------------------------------- // Extract training data from history // --------------------------------------------------------------------------- diff --git a/src/sampler/mod.rs b/src/sampler/mod.rs index 4e20f4e..a660d11 100644 --- a/src/sampler/mod.rs +++ b/src/sampler/mod.rs @@ -3,6 +3,7 @@ pub mod bohb; #[cfg(feature = "cma-es")] pub mod cma_es; +pub(crate) mod 
common; pub mod de; pub(crate) mod genetic; #[cfg(feature = "gp")] diff --git a/src/sampler/motpe.rs b/src/sampler/motpe.rs index 883cfe0..a84580d 100644 --- a/src/sampler/motpe.rs +++ b/src/sampler/motpe.rs @@ -61,9 +61,10 @@ use core::sync::atomic::{AtomicU64, Ordering}; use crate::distribution::Distribution; -use crate::kde::KernelDensityEstimator; use crate::multi_objective::{MultiObjectiveSampler, MultiObjectiveTrial}; use crate::param::ParamValue; +use crate::sampler::common; +use crate::sampler::tpe::common as tpe_common; use crate::types::{Direction, TrialState}; use crate::{pareto, rng_util}; @@ -205,248 +206,6 @@ impl MotpeSampler { (good, bad) } - - /// Samples uniformly from a distribution (used during startup phase). - #[allow( - clippy::cast_possible_truncation, - clippy::cast_precision_loss, - clippy::unused_self - )] - fn sample_uniform(distribution: &Distribution, rng: &mut fastrand::Rng) -> ParamValue { - match distribution { - Distribution::Float(d) => { - let value = if d.log_scale { - let log_low = d.low.ln(); - let log_high = d.high.ln(); - rng_util::f64_range(rng, log_low, log_high).exp() - } else if let Some(step) = d.step { - let n_steps = ((d.high - d.low) / step).floor() as i64; - let k = rng.i64(0..=n_steps); - d.low + (k as f64) * step - } else { - rng_util::f64_range(rng, d.low, d.high) - }; - ParamValue::Float(value) - } - Distribution::Int(d) => { - let value = if d.log_scale { - let log_low = (d.low as f64).ln(); - let log_high = (d.high as f64).ln(); - let raw = rng_util::f64_range(rng, log_low, log_high).exp().round() as i64; - raw.clamp(d.low, d.high) - } else if let Some(step) = d.step { - let n_steps = (d.high - d.low) / step; - let k = rng.i64(0..=n_steps); - d.low + k * step - } else { - rng.i64(d.low..=d.high) - }; - ParamValue::Int(value) - } - Distribution::Categorical(d) => ParamValue::Categorical(rng.usize(0..d.n_choices)), - } - } - - /// Samples using TPE for float distributions. 
- #[allow(clippy::too_many_arguments)] - fn sample_tpe_float( - &self, - low: f64, - high: f64, - log_scale: bool, - step: Option<f64>, - good_values: Vec<f64>, - bad_values: Vec<f64>, - rng: &mut fastrand::Rng, - ) -> f64 { - // Transform to internal space (log space if needed) - let (internal_low, internal_high, good_internal, bad_internal) = if log_scale { - let i_low = low.ln(); - let i_high = high.ln(); - let g = { - let mut v = good_values; - for x in &mut v { - *x = x.ln(); - } - v - }; - let b = { - let mut v = bad_values; - for x in &mut v { - *x = x.ln(); - } - v - }; - (i_low, i_high, g, b) - } else { - (low, high, good_values, bad_values) - }; - - // Fit KDEs to good and bad groups - let l_kde = match self.kde_bandwidth { - Some(bw) => KernelDensityEstimator::with_bandwidth(good_internal, bw), - None => KernelDensityEstimator::new(good_internal), - }; - let g_kde = match self.kde_bandwidth { - Some(bw) => KernelDensityEstimator::with_bandwidth(bad_internal, bw), - None => KernelDensityEstimator::new(bad_internal), - }; - - // If KDE construction fails, fall back to uniform sampling - let (Ok(l_kde), Ok(g_kde)) = (l_kde, g_kde) else { - return rng_util::f64_range(rng, low, high); - }; - - // Generate candidates from l(x) and select the one with best l(x)/g(x) - let mut best_candidate = internal_low; - let mut best_ratio = f64::NEG_INFINITY; - - for _ in 0..self.n_ei_candidates { - let candidate = l_kde.sample(rng).clamp(internal_low, internal_high); - - let l_density = l_kde.pdf(candidate); - let g_density = g_kde.pdf(candidate); - - let ratio = if g_density < f64::EPSILON { - if l_density > f64::EPSILON { - f64::INFINITY - } else { - 0.0 - } - } else { - l_density / g_density - }; - - if ratio > best_ratio { - best_ratio = ratio; - best_candidate = candidate; - } - } - - // Transform back from internal space - let mut value = if log_scale { - best_candidate.exp() - } else { - best_candidate - }; - - // Apply step constraint if present - if let 
Some(step) = step { - let k = ((value - low) / step).round(); - value = low + k * step; - } - - value.clamp(low, high) - } - - /// Samples using TPE for integer distributions. - #[allow( - clippy::too_many_arguments, - clippy::cast_precision_loss, - clippy::cast_possible_truncation - )] - fn sample_tpe_int( - &self, - low: i64, - high: i64, - log_scale: bool, - step: Option<i64>, - good_values: Vec<i64>, - bad_values: Vec<i64>, - rng: &mut fastrand::Rng, - ) -> i64 { - let good_floats: Vec<f64> = good_values.into_iter().map(|v| v as f64).collect(); - let bad_floats: Vec<f64> = bad_values.into_iter().map(|v| v as f64).collect(); - - let float_value = self.sample_tpe_float( - low as f64, - high as f64, - log_scale, - step.map(|s| s as f64), - good_floats, - bad_floats, - rng, - ); - - let int_value = float_value.round() as i64; - let int_value = if let Some(step) = step { - let k = ((int_value - low) as f64 / step as f64).round() as i64; - low + k * step - } else { - int_value - }; - - int_value.clamp(low, high) - } - - /// Samples using TPE for categorical distributions. 
- #[allow(clippy::cast_precision_loss)] - fn sample_tpe_categorical( - n_choices: usize, - good_indices: &[usize], - bad_indices: &[usize], - rng: &mut fastrand::Rng, - ) -> usize { - // Stack-allocate for the common case (<=32 choices), heap for rare large cases - let mut good_buf = [0usize; 32]; - let mut bad_buf = [0usize; 32]; - let mut weight_buf = [0.0f64; 32]; - - let mut good_vec; - let mut bad_vec; - let mut weight_vec; - - let (good_counts, bad_counts, weights): (&mut [usize], &mut [usize], &mut [f64]) = - if n_choices <= 32 { - ( - &mut good_buf[..n_choices], - &mut bad_buf[..n_choices], - &mut weight_buf[..n_choices], - ) - } else { - good_vec = vec![0usize; n_choices]; - bad_vec = vec![0usize; n_choices]; - weight_vec = vec![0.0f64; n_choices]; - (&mut good_vec, &mut bad_vec, &mut weight_vec) - }; - - // Count occurrences in good and bad groups - for &idx in good_indices { - if idx < n_choices { - good_counts[idx] += 1; - } - } - for &idx in bad_indices { - if idx < n_choices { - bad_counts[idx] += 1; - } - } - - // Laplace smoothing - let good_total = good_indices.len() as f64 + n_choices as f64; - let bad_total = bad_indices.len() as f64 + n_choices as f64; - - // Calculate l(x)/g(x) ratio for each category - for i in 0..n_choices { - let l_prob = (good_counts[i] as f64 + 1.0) / good_total; - let g_prob = (bad_counts[i] as f64 + 1.0) / bad_total; - weights[i] = l_prob / g_prob; - } - - // Sample proportionally to weights - let total_weight: f64 = weights.iter().sum(); - let threshold = rng.f64() * total_weight; - - let mut cumulative = 0.0; - for (i, &w) in weights.iter().enumerate() { - cumulative += w; - if cumulative >= threshold { - return i; - } - } - - n_choices - 1 - } } impl Default for MotpeSampler { @@ -477,14 +236,14 @@ impl MultiObjectiveSampler for MotpeSampler { .filter(|t| t.state == TrialState::Complete) .count(); if n_complete < self.n_startup_trials { - return Self::sample_uniform(distribution, &mut rng); + return 
common::sample_random(&mut rng, distribution); } // Split trials into good (Pareto front) and bad (dominated) let (good_trials, bad_trials) = Self::split_trials(history, directions); if good_trials.is_empty() || bad_trials.is_empty() { - return Self::sample_uniform(distribution, &mut rng); + return common::sample_random(&mut rng, distribution); } match distribution { @@ -510,16 +269,18 @@ impl MultiObjectiveSampler for MotpeSampler { .collect(); if good_values.is_empty() || bad_values.is_empty() { - return Self::sample_uniform(distribution, &mut rng); + return common::sample_random(&mut rng, distribution); } - let value = self.sample_tpe_float( + let value = tpe_common::sample_tpe_float( d.low, d.high, d.log_scale, d.step, good_values, bad_values, + self.n_ei_candidates, + self.kde_bandwidth, &mut rng, ); ParamValue::Float(value) @@ -546,16 +307,18 @@ impl MultiObjectiveSampler for MotpeSampler { .collect(); if good_values.is_empty() || bad_values.is_empty() { - return Self::sample_uniform(distribution, &mut rng); + return common::sample_random(&mut rng, distribution); } - let value = self.sample_tpe_int( + let value = tpe_common::sample_tpe_int( d.low, d.high, d.log_scale, d.step, good_values, bad_values, + self.n_ei_candidates, + self.kde_bandwidth, &mut rng, ); ParamValue::Int(value) @@ -582,10 +345,10 @@ impl MultiObjectiveSampler for MotpeSampler { .collect(); if good_indices.is_empty() || bad_indices.is_empty() { - return Self::sample_uniform(distribution, &mut rng); + return common::sample_random(&mut rng, distribution); } - let index = Self::sample_tpe_categorical( + let index = tpe_common::sample_tpe_categorical( d.n_choices, &good_indices, &bad_indices, diff --git a/src/sampler/random.rs b/src/sampler/random.rs index 6c750b3..2850376 100644 --- a/src/sampler/random.rs +++ b/src/sampler/random.rs @@ -127,50 +127,7 @@ impl Sampler for RandomSampler { rng_util::distribution_fingerprint(distribution).wrapping_add(seq), )); - match distribution { - 
Distribution::Float(d) => { - let value = if d.log_scale { - // Sample uniformly in log space - let log_low = d.low.ln(); - let log_high = d.high.ln(); - let log_value = rng_util::f64_range(&mut rng, log_low, log_high); - log_value.exp() - } else if let Some(step) = d.step { - // Sample from step grid - let n_steps = ((d.high - d.low) / step).floor() as i64; - let k = rng.i64(0..=n_steps); - d.low + (k as f64) * step - } else { - // Uniform sampling - rng_util::f64_range(&mut rng, d.low, d.high) - }; - ParamValue::Float(value) - } - Distribution::Int(d) => { - let value = if d.log_scale { - // Sample uniformly in log space, then round - let log_low = (d.low as f64).ln(); - let log_high = (d.high as f64).ln(); - let log_value = rng_util::f64_range(&mut rng, log_low, log_high); - let raw = log_value.exp().round() as i64; - // Clamp to bounds since rounding might push outside - raw.clamp(d.low, d.high) - } else if let Some(step) = d.step { - // Sample from step grid - let n_steps = (d.high - d.low) / step; - let k = rng.i64(0..=n_steps); - d.low + k * step - } else { - // Uniform sampling - rng.i64(d.low..=d.high) - }; - ParamValue::Int(value) - } - Distribution::Categorical(d) => { - let index = rng.usize(0..d.n_choices); - ParamValue::Categorical(index) - } - } + super::common::sample_random(&mut rng, distribution) } } diff --git a/src/sampler/tpe/common.rs b/src/sampler/tpe/common.rs new file mode 100644 index 0000000..c566ae9 --- /dev/null +++ b/src/sampler/tpe/common.rs @@ -0,0 +1,218 @@ +//! Shared TPE sampling functions used by both `TpeSampler` and `MotpeSampler`. + +use crate::kde::KernelDensityEstimator; +use crate::rng_util; + +/// Samples using TPE for float distributions. 
+#[allow(clippy::too_many_arguments)] +pub(crate) fn sample_tpe_float( + low: f64, + high: f64, + log_scale: bool, + step: Option<f64>, + good_values: Vec<f64>, + bad_values: Vec<f64>, + n_ei_candidates: usize, + kde_bandwidth: Option<f64>, + rng: &mut fastrand::Rng, +) -> f64 { + // Transform to internal space (log space if needed) + let (internal_low, internal_high, good_internal, bad_internal) = if log_scale { + let i_low = low.ln(); + let i_high = high.ln(); + let g = { + let mut v = good_values; + for x in &mut v { + *x = x.ln(); + } + v + }; + let b = { + let mut v = bad_values; + for x in &mut v { + *x = x.ln(); + } + v + }; + (i_low, i_high, g, b) + } else { + (low, high, good_values, bad_values) + }; + + // Fit KDEs to good and bad groups + let l_kde = match kde_bandwidth { + Some(bw) => KernelDensityEstimator::with_bandwidth(good_internal, bw), + None => KernelDensityEstimator::new(good_internal), + }; + let g_kde = match kde_bandwidth { + Some(bw) => KernelDensityEstimator::with_bandwidth(bad_internal, bw), + None => KernelDensityEstimator::new(bad_internal), + }; + + // If KDE construction fails, fall back to uniform sampling + let (Ok(l_kde), Ok(g_kde)) = (l_kde, g_kde) else { + return rng_util::f64_range(rng, low, high); + }; + + // Generate candidates from l(x) and select the one with best l(x)/g(x) ratio + let mut best_candidate = internal_low; + let mut best_ratio = f64::NEG_INFINITY; + + for _ in 0..n_ei_candidates { + let candidate = l_kde.sample(rng).clamp(internal_low, internal_high); + + let l_density = l_kde.pdf(candidate); + let g_density = g_kde.pdf(candidate); + + // Compute l(x)/g(x) ratio, handling zero density + let ratio = if g_density < f64::EPSILON { + if l_density > f64::EPSILON { + f64::INFINITY + } else { + 0.0 + } + } else { + l_density / g_density + }; + + if ratio > best_ratio { + best_ratio = ratio; + best_candidate = candidate; + } + } + + // Transform back from internal space + let mut value = if log_scale { + 
best_candidate.exp() + } else { + best_candidate + }; + + // Apply step constraint if present + if let Some(step) = step { + let k = ((value - low) / step).round(); + value = low + k * step; + } + + // Ensure value is within bounds + value.clamp(low, high) +} + +/// Samples using TPE for integer distributions. +#[allow( + clippy::too_many_arguments, + clippy::cast_precision_loss, + clippy::cast_possible_truncation +)] +pub(crate) fn sample_tpe_int( + low: i64, + high: i64, + log_scale: bool, + step: Option<i64>, + good_values: Vec<i64>, + bad_values: Vec<i64>, + n_ei_candidates: usize, + kde_bandwidth: Option<f64>, + rng: &mut fastrand::Rng, +) -> i64 { + // Convert to floats for KDE + let good_floats: Vec<f64> = good_values.into_iter().map(|v| v as f64).collect(); + let bad_floats: Vec<f64> = bad_values.into_iter().map(|v| v as f64).collect(); + + // Use float TPE sampling + let float_value = sample_tpe_float( + low as f64, + high as f64, + log_scale, + step.map(|s| s as f64), + good_floats, + bad_floats, + n_ei_candidates, + kde_bandwidth, + rng, + ); + + // Round to nearest integer + let int_value = float_value.round() as i64; + + // Apply step constraint if present + let int_value = if let Some(step) = step { + let k = ((int_value - low) as f64 / step as f64).round() as i64; + low + k * step + } else { + int_value + }; + + // Ensure value is within bounds + int_value.clamp(low, high) +} + +/// Samples using TPE for categorical distributions. 
+#[allow(clippy::cast_precision_loss)] +pub(crate) fn sample_tpe_categorical( + n_choices: usize, + good_indices: &[usize], + bad_indices: &[usize], + rng: &mut fastrand::Rng, +) -> usize { + // Stack-allocate for the common case (<=32 choices), heap for rare large cases + let mut good_buf = [0usize; 32]; + let mut bad_buf = [0usize; 32]; + let mut weight_buf = [0.0f64; 32]; + + let mut good_vec; + let mut bad_vec; + let mut weight_vec; + + let (good_counts, bad_counts, weights): (&mut [usize], &mut [usize], &mut [f64]) = + if n_choices <= 32 { + ( + &mut good_buf[..n_choices], + &mut bad_buf[..n_choices], + &mut weight_buf[..n_choices], + ) + } else { + good_vec = vec![0usize; n_choices]; + bad_vec = vec![0usize; n_choices]; + weight_vec = vec![0.0f64; n_choices]; + (&mut good_vec, &mut bad_vec, &mut weight_vec) + }; + + // Count occurrences in good and bad groups + for &idx in good_indices { + if idx < n_choices { + good_counts[idx] += 1; + } + } + for &idx in bad_indices { + if idx < n_choices { + bad_counts[idx] += 1; + } + } + + // Add smoothing (Laplace smoothing) to avoid zero probabilities + let good_total = good_indices.len() as f64 + n_choices as f64; + let bad_total = bad_indices.len() as f64 + n_choices as f64; + + // Calculate l(x)/g(x) ratio for each category + for i in 0..n_choices { + let l_prob = (good_counts[i] as f64 + 1.0) / good_total; + let g_prob = (bad_counts[i] as f64 + 1.0) / bad_total; + weights[i] = l_prob / g_prob; + } + + // Sample proportionally to weights + let total_weight: f64 = weights.iter().sum(); + let threshold = rng.f64() * total_weight; + + let mut cumulative = 0.0; + for (i, &w) in weights.iter().enumerate() { + cumulative += w; + if cumulative >= threshold { + return i; + } + } + + // Fallback to last index (shouldn't happen) + n_choices - 1 +} diff --git a/src/sampler/tpe/mod.rs b/src/sampler/tpe/mod.rs index 2c17e3d..a856de2 100644 --- a/src/sampler/tpe/mod.rs +++ b/src/sampler/tpe/mod.rs @@ -59,6 +59,7 @@ //! 
let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); //! ``` +pub(crate) mod common; mod gamma; mod multivariate; mod sampler; diff --git a/src/sampler/tpe/multivariate.rs b/src/sampler/tpe/multivariate.rs index 731fc44..de8b37e 100644 --- a/src/sampler/tpe/multivariate.rs +++ b/src/sampler/tpe/multivariate.rs @@ -810,55 +810,10 @@ impl MultivariateTpeSampler { ) -> HashMap<ParamId, ParamValue> { search_space .iter() - .map(|(id, dist)| (*id, Self::sample_uniform_single(dist, rng))) + .map(|(id, dist)| (*id, crate::sampler::common::sample_random(rng, dist))) .collect() } - /// Samples a single parameter uniformly at random from its distribution. - fn sample_uniform_single(distribution: &Distribution, rng: &mut fastrand::Rng) -> ParamValue { - match distribution { - Distribution::Float(d) => { - let value = if d.log_scale { - let log_low = d.low.ln(); - let log_high = d.high.ln(); - rng_util::f64_range(rng, log_low, log_high).exp() - } else if let Some(step) = d.step { - #[allow(clippy::cast_possible_truncation)] - let n_steps = ((d.high - d.low) / step).floor() as i64; - let k = rng.i64(0..=n_steps); - #[allow(clippy::cast_precision_loss)] - let result = d.low + (k as f64) * step; - result - } else { - rng_util::f64_range(rng, d.low, d.high) - }; - ParamValue::Float(value) - } - Distribution::Int(d) => { - #[allow(clippy::cast_precision_loss)] - let value = if d.log_scale { - let log_low = (d.low as f64).ln(); - let log_high = (d.high as f64).ln(); - #[allow(clippy::cast_possible_truncation)] - let raw = rng_util::f64_range(rng, log_low, log_high).exp().round() as i64; - raw.clamp(d.low, d.high) - } else if let Some(step) = d.step { - #[allow(clippy::cast_possible_truncation)] - let n_steps = (d.high - d.low) / step; - let k = rng.i64(0..=n_steps); - d.low + k * step - } else { - rng.i64(d.low..=d.high) - }; - ParamValue::Int(value) - } - Distribution::Categorical(d) => { - let index = rng.usize(0..d.n_choices); - 
ParamValue::Categorical(index) - } - } - } - /// Samples parameters jointly using multivariate TPE. /// /// This method samples all parameters in the search space jointly, capturing @@ -1025,7 +980,7 @@ impl MultivariateTpeSampler { // Sample ungrouped parameters uniformly (no history for them) let mut rng = self.rng.lock(); for (id, dist) in &ungrouped_params { - let value = Self::sample_uniform_single(dist, &mut rng); + let value = crate::sampler::common::sample_random(&mut rng, dist); result.insert(*id, value); } } @@ -1312,7 +1267,7 @@ impl MultivariateTpeSampler { .collect(); if good_values.is_empty() || bad_values.is_empty() { - return Self::sample_uniform_single(distribution, rng); + return crate::sampler::common::sample_random(rng, distribution); } let value = self.sample_tpe_float( @@ -1348,7 +1303,7 @@ impl MultivariateTpeSampler { .collect(); if good_values.is_empty() || bad_values.is_empty() { - return Self::sample_uniform_single(distribution, rng); + return crate::sampler::common::sample_random(rng, distribution); } let value = self.sample_tpe_int( @@ -1384,7 +1339,7 @@ impl MultivariateTpeSampler { .collect(); if good_indices.is_empty() || bad_indices.is_empty() { - return Self::sample_uniform_single(distribution, rng); + return crate::sampler::common::sample_random(rng, distribution); } let idx = @@ -1765,7 +1720,7 @@ impl Sampler for MultivariateTpeSampler { Self::find_matching_param(distribution, &joint_sample).unwrap_or_else(|| { // Fallback to uniform sampling if no match found let mut rng = self.rng.lock(); - Self::sample_uniform_single(distribution, &mut rng) + crate::sampler::common::sample_random(&mut rng, distribution) }) } } diff --git a/src/sampler/tpe/sampler.rs b/src/sampler/tpe/sampler.rs index fb9efe6..5c9af1d 100644 --- a/src/sampler/tpe/sampler.rs +++ b/src/sampler/tpe/sampler.rs @@ -61,12 +61,14 @@ use std::sync::Arc; use crate::distribution::Distribution; use crate::error::{Error, Result}; -use crate::kde::KernelDensityEstimator; 
use crate::param::ParamValue; use crate::rng_util; +use crate::sampler::common; use crate::sampler::tpe::gamma::{FixedGamma, GammaStrategy}; use crate::sampler::{CompletedTrial, Sampler}; +use super::common as tpe_common; + // ============================================================================ // Gamma Strategy Trait and Implementations // ============================================================================ @@ -315,261 +317,6 @@ impl TpeSampler { (good, bad) } - - /// Samples uniformly from a distribution (used during startup phase). - #[allow( - clippy::cast_possible_truncation, - clippy::cast_precision_loss, - clippy::unused_self - )] - fn sample_uniform(&self, distribution: &Distribution, rng: &mut fastrand::Rng) -> ParamValue { - match distribution { - Distribution::Float(d) => { - let value = if d.log_scale { - let log_low = d.low.ln(); - let log_high = d.high.ln(); - rng_util::f64_range(rng, log_low, log_high).exp() - } else if let Some(step) = d.step { - let n_steps = ((d.high - d.low) / step).floor() as i64; - let k = rng.i64(0..=n_steps); - d.low + (k as f64) * step - } else { - rng_util::f64_range(rng, d.low, d.high) - }; - ParamValue::Float(value) - } - Distribution::Int(d) => { - let value = if d.log_scale { - let log_low = (d.low as f64).ln(); - let log_high = (d.high as f64).ln(); - let raw = rng_util::f64_range(rng, log_low, log_high).exp().round() as i64; - raw.clamp(d.low, d.high) - } else if let Some(step) = d.step { - let n_steps = (d.high - d.low) / step; - let k = rng.i64(0..=n_steps); - d.low + k * step - } else { - rng.i64(d.low..=d.high) - }; - ParamValue::Int(value) - } - Distribution::Categorical(d) => ParamValue::Categorical(rng.usize(0..d.n_choices)), - } - } - - /// Samples using TPE for float distributions. 
- #[allow(clippy::too_many_arguments)] - fn sample_tpe_float( - &self, - low: f64, - high: f64, - log_scale: bool, - step: Option<f64>, - good_values: Vec<f64>, - bad_values: Vec<f64>, - rng: &mut fastrand::Rng, - ) -> f64 { - // Transform to internal space (log space if needed) - let (internal_low, internal_high, good_internal, bad_internal) = if log_scale { - let i_low = low.ln(); - let i_high = high.ln(); - let g = { - let mut v = good_values; - for x in &mut v { - *x = x.ln(); - } - v - }; - let b = { - let mut v = bad_values; - for x in &mut v { - *x = x.ln(); - } - v - }; - (i_low, i_high, g, b) - } else { - (low, high, good_values, bad_values) - }; - - // Fit KDEs to good and bad groups - let l_kde = match self.kde_bandwidth { - Some(bw) => KernelDensityEstimator::with_bandwidth(good_internal, bw), - None => KernelDensityEstimator::new(good_internal), - }; - let g_kde = match self.kde_bandwidth { - Some(bw) => KernelDensityEstimator::with_bandwidth(bad_internal, bw), - None => KernelDensityEstimator::new(bad_internal), - }; - - // If KDE construction fails, fall back to uniform sampling - let (Ok(l_kde), Ok(g_kde)) = (l_kde, g_kde) else { - return rng_util::f64_range(rng, low, high); - }; - - // Generate candidates from l(x) and select the one with best l(x)/g(x) ratio - let mut best_candidate = internal_low; - let mut best_ratio = f64::NEG_INFINITY; - - for _ in 0..self.n_ei_candidates { - let candidate = l_kde.sample(rng); - - // Clamp to bounds - let candidate = candidate.clamp(internal_low, internal_high); - - let l_density = l_kde.pdf(candidate); - let g_density = g_kde.pdf(candidate); - - // Compute l(x)/g(x) ratio, handling zero density - let ratio = if g_density < f64::EPSILON { - if l_density > f64::EPSILON { - f64::INFINITY - } else { - 0.0 - } - } else { - l_density / g_density - }; - - if ratio > best_ratio { - best_ratio = ratio; - best_candidate = candidate; - } - } - - // Transform back from internal space - let mut value = if log_scale { - 
best_candidate.exp() - } else { - best_candidate - }; - - // Apply step constraint if present - if let Some(step) = step { - let k = ((value - low) / step).round(); - value = low + k * step; - } - - // Ensure value is within bounds - value.clamp(low, high) - } - - /// Samples using TPE for integer distributions. - #[allow( - clippy::too_many_arguments, - clippy::cast_precision_loss, - clippy::cast_possible_truncation - )] - fn sample_tpe_int( - &self, - low: i64, - high: i64, - log_scale: bool, - step: Option<i64>, - good_values: Vec<i64>, - bad_values: Vec<i64>, - rng: &mut fastrand::Rng, - ) -> i64 { - // Convert to floats for KDE - let good_floats: Vec<f64> = good_values.into_iter().map(|v| v as f64).collect(); - let bad_floats: Vec<f64> = bad_values.into_iter().map(|v| v as f64).collect(); - - // Use float TPE sampling - let float_value = self.sample_tpe_float( - low as f64, - high as f64, - log_scale, - step.map(|s| s as f64), - good_floats, - bad_floats, - rng, - ); - - // Round to nearest integer - let int_value = float_value.round() as i64; - - // Apply step constraint if present - let int_value = if let Some(step) = step { - let k = ((int_value - low) as f64 / step as f64).round() as i64; - low + k * step - } else { - int_value - }; - - // Ensure value is within bounds - int_value.clamp(low, high) - } - - /// Samples using TPE for categorical distributions. 
- #[allow(clippy::cast_precision_loss, clippy::unused_self)] - fn sample_tpe_categorical( - &self, - n_choices: usize, - good_indices: &[usize], - bad_indices: &[usize], - rng: &mut fastrand::Rng, - ) -> usize { - // Stack-allocate for the common case (<=32 choices), heap for rare large cases - let mut good_buf = [0usize; 32]; - let mut bad_buf = [0usize; 32]; - let mut weight_buf = [0.0f64; 32]; - - let mut good_vec; - let mut bad_vec; - let mut weight_vec; - - let (good_counts, bad_counts, weights): (&mut [usize], &mut [usize], &mut [f64]) = - if n_choices <= 32 { - ( - &mut good_buf[..n_choices], - &mut bad_buf[..n_choices], - &mut weight_buf[..n_choices], - ) - } else { - good_vec = vec![0usize; n_choices]; - bad_vec = vec![0usize; n_choices]; - weight_vec = vec![0.0f64; n_choices]; - (&mut good_vec, &mut bad_vec, &mut weight_vec) - }; - - // Count occurrences in good and bad groups - for &idx in good_indices { - if idx < n_choices { - good_counts[idx] += 1; - } - } - for &idx in bad_indices { - if idx < n_choices { - bad_counts[idx] += 1; - } - } - - // Add smoothing (Laplace smoothing) to avoid zero probabilities - let good_total = good_indices.len() as f64 + n_choices as f64; - let bad_total = bad_indices.len() as f64 + n_choices as f64; - - // Calculate l(x)/g(x) ratio for each category - for i in 0..n_choices { - let l_prob = (good_counts[i] as f64 + 1.0) / good_total; - let g_prob = (bad_counts[i] as f64 + 1.0) / bad_total; - weights[i] = l_prob / g_prob; - } - - // Sample proportionally to weights - let total_weight: f64 = weights.iter().sum(); - let threshold = rng.f64() * total_weight; - - let mut cumulative = 0.0; - for (i, &w) in weights.iter().enumerate() { - cumulative += w; - if cumulative >= threshold { - return i; - } - } - - // Fallback to last index (shouldn't happen) - n_choices - 1 - } } impl Default for TpeSampler { @@ -912,7 +659,7 @@ impl Sampler for TpeSampler { // Fall back to random sampling during startup phase if history.len() < 
self.n_startup_trials { - return self.sample_uniform(distribution, &mut rng); + return common::sample_random(&mut rng, distribution); } // Split trials into good and bad groups @@ -920,7 +667,7 @@ impl Sampler for TpeSampler { // Need at least 1 trial in each group for TPE if good_trials.is_empty() || bad_trials.is_empty() { - return self.sample_uniform(distribution, &mut rng); + return common::sample_random(&mut rng, distribution); } // Extract parameter values for this distribution @@ -954,16 +701,18 @@ impl Sampler for TpeSampler { // Need values in both groups for TPE if good_values.is_empty() || bad_values.is_empty() { - return self.sample_uniform(distribution, &mut rng); + return common::sample_random(&mut rng, distribution); } - let value = self.sample_tpe_float( + let value = tpe_common::sample_tpe_float( d.low, d.high, d.log_scale, d.step, good_values, bad_values, + self.n_ei_candidates, + self.kde_bandwidth, &mut rng, ); ParamValue::Float(value) @@ -990,16 +739,18 @@ impl Sampler for TpeSampler { .collect(); if good_values.is_empty() || bad_values.is_empty() { - return self.sample_uniform(distribution, &mut rng); + return common::sample_random(&mut rng, distribution); } - let value = self.sample_tpe_int( + let value = tpe_common::sample_tpe_int( d.low, d.high, d.log_scale, d.step, good_values, bad_values, + self.n_ei_candidates, + self.kde_bandwidth, &mut rng, ); ParamValue::Int(value) @@ -1026,11 +777,15 @@ impl Sampler for TpeSampler { .collect(); if good_indices.is_empty() || bad_indices.is_empty() { - return self.sample_uniform(distribution, &mut rng); + return common::sample_random(&mut rng, distribution); } - let index = - self.sample_tpe_categorical(d.n_choices, &good_indices, &bad_indices, &mut rng); + let index = tpe_common::sample_tpe_categorical( + d.n_choices, + &good_indices, + &bad_indices, + &mut rng, + ); ParamValue::Categorical(index) } } From a502d7e1b844a2dd3aaae063b6f45e11820334d4 Mon Sep 17 00:00:00 2001 From: Manuel Raimann 
<raimannma@outlook.de> Date: Thu, 12 Feb 2026 15:38:23 +0100 Subject: [PATCH 30/48] refactor(tpe): split multivariate sampler into focused submodules - Extract engine.rs (core sampling logic) and trials.rs (trial processing) - Deduplicate TPE sampling functions by delegating to tpe::common - Gate persistence.rs imports with #[cfg(feature = "serde")] --- src/sampler/tpe/multivariate/engine.rs | 607 ++++++++ .../{multivariate.rs => multivariate/mod.rs} | 1256 +---------------- src/sampler/tpe/multivariate/trials.rs | 220 +++ src/study/persistence.rs | 4 + 4 files changed, 839 insertions(+), 1248 deletions(-) create mode 100644 src/sampler/tpe/multivariate/engine.rs rename src/sampler/tpe/{multivariate.rs => multivariate/mod.rs} (80%) create mode 100644 src/sampler/tpe/multivariate/trials.rs diff --git a/src/sampler/tpe/multivariate/engine.rs b/src/sampler/tpe/multivariate/engine.rs new file mode 100644 index 0000000..1002b29 --- /dev/null +++ b/src/sampler/tpe/multivariate/engine.rs @@ -0,0 +1,607 @@ +//! Core multivariate TPE sampling logic. +//! +//! Contains the main sampling engine: group decomposition, single-group multivariate +//! TPE, candidate selection, independent fallbacks, and value conversion. + +use std::collections::HashMap; + +use crate::distribution::Distribution; +use crate::param::ParamValue; +use crate::parameter::ParamId; +use crate::sampler::CompletedTrial; + +use super::MultivariateTpeSampler; + +impl MultivariateTpeSampler { + /// Samples parameters by decomposing the search space into independent groups. + /// + /// When `group=true`, this method analyzes the trial history to identify groups of + /// parameters that always appear together, then samples each group independently + /// using multivariate TPE. This is more efficient when parameters naturally partition + /// into independent subsets (e.g., due to conditional search spaces). 
+ /// + /// # Arguments + /// + /// * `search_space` - The full search space containing all parameters to sample. + /// * `history` - Completed trials from the optimization history. + /// + /// # Returns + /// + /// A `HashMap` mapping parameter names to their sampled values. + pub(crate) fn sample_with_groups( + &self, + search_space: &HashMap<ParamId, Distribution>, + history: &[CompletedTrial], + ) -> HashMap<ParamId, ParamValue> { + use std::collections::HashSet; + + use crate::sampler::tpe::GroupDecomposedSearchSpace; + + // Decompose the search space into independent parameter groups + let groups = GroupDecomposedSearchSpace::calculate(history); + + let mut result: HashMap<ParamId, ParamValue> = HashMap::new(); + + // Sample each group independently + for group in &groups { + // Build a sub-search space for this group + let group_search_space: HashMap<ParamId, Distribution> = search_space + .iter() + .filter(|(id, _)| group.contains(id)) + .map(|(id, dist)| (*id, dist.clone())) + .collect(); + + if group_search_space.is_empty() { + continue; + } + + // Filter history to trials that have at least one parameter in this group + let group_history: Vec<&CompletedTrial> = history + .iter() + .filter(|trial| { + trial + .distributions + .keys() + .any(|param_id| group.contains(param_id)) + }) + .collect(); + + // Build completed trials from references for the group + // We need to create a temporary slice for sample_group_internal + let group_history_owned: Vec<CompletedTrial> = + group_history.iter().map(|t| (*t).clone()).collect(); + + // Sample this group using multivariate TPE + let mut rng = self.rng.lock(); + let group_result = + self.sample_single_group(&group_search_space, &group_history_owned, &mut rng); + drop(rng); + + // Merge group results into the main result + for (id, value) in group_result { + result.insert(id, value); + } + } + + // Handle parameters not in any group (sample independently) + let grouped_params: HashSet<ParamId> = 
groups.iter().flatten().copied().collect(); + let ungrouped_params: HashMap<ParamId, Distribution> = search_space + .iter() + .filter(|(id, _)| !grouped_params.contains(id) && !result.contains_key(id)) + .map(|(id, dist)| (*id, dist.clone())) + .collect(); + + if !ungrouped_params.is_empty() { + // Sample ungrouped parameters uniformly (no history for them) + let mut rng = self.rng.lock(); + for (id, dist) in &ungrouped_params { + let value = crate::sampler::common::sample_random(&mut rng, dist); + result.insert(*id, value); + } + } + + result + } + + /// Samples parameters as a single group using multivariate TPE. + /// + /// This is the core multivariate TPE sampling logic, used both in non-grouped mode + /// and for sampling individual groups in grouped mode. + /// + /// # Arguments + /// + /// * `search_space` - The search space for this group of parameters. + /// * `history` - Completed trials to use for model fitting. + /// * `rng` - Random number generator (caller must hold lock). + /// + /// # Returns + /// + /// A `HashMap` mapping parameter names to their sampled values. 
+ #[allow(clippy::too_many_lines)] + pub(crate) fn sample_single_group( + &self, + search_space: &HashMap<ParamId, Distribution>, + history: &[CompletedTrial], + rng: &mut fastrand::Rng, + ) -> HashMap<ParamId, ParamValue> { + use crate::kde::MultivariateKDE; + use crate::sampler::tpe::IntersectionSearchSpace; + use crate::sampler::tpe::common; + + // Early returns for cases requiring random sampling + if history.len() < self.n_startup_trials { + return self.sample_all_uniform(search_space, rng); + } + + let intersection = IntersectionSearchSpace::calculate(history); + if intersection.is_empty() { + return self.sample_all_independent_with_rng(search_space, history, rng); + } + + let filtered = self.filter_trials(history, &intersection); + if filtered.len() < 2 { + return self.sample_all_independent_with_rng(search_space, history, rng); + } + + let (good, bad) = self.split_trials(&filtered); + + // Sample categorical parameters using TPE with l(x)/g(x) ratio + let mut result: HashMap<ParamId, ParamValue> = HashMap::new(); + for (param_id, dist) in &intersection { + if let Distribution::Categorical(d) = dist { + let good_indices = Self::extract_categorical_indices(&good, *param_id); + let bad_indices = Self::extract_categorical_indices(&bad, *param_id); + let idx = + common::sample_tpe_categorical(d.n_choices, &good_indices, &bad_indices, rng); + result.insert(*param_id, ParamValue::Categorical(idx)); + } + } + + // Collect continuous parameters + let mut param_order: Vec<ParamId> = intersection + .iter() + .filter(|(_, dist)| !matches!(dist, Distribution::Categorical(_))) + .map(|(id, _)| *id) + .collect(); + + if param_order.is_empty() { + // Only categorical parameters in intersection - fill remaining with independent TPE + self.fill_remaining_independent_with_rng( + search_space, + &intersection, + history, + &mut result, + rng, + ); + return result; + } + + param_order.sort_by_key(|id| format!("{id}")); + + // Extract observations and validate + let good_obs = 
self.extract_observations(&good, ¶m_order); + let bad_obs = self.extract_observations(&bad, ¶m_order); + let expected_dims = param_order.len(); + + let valid = !good_obs.is_empty() + && !bad_obs.is_empty() + && good_obs.iter().all(|obs| obs.len() == expected_dims) + && bad_obs.iter().all(|obs| obs.len() == expected_dims); + + if !valid { + // Observations invalid - fill remaining with independent TPE + self.fill_remaining_independent_with_rng( + search_space, + &intersection, + history, + &mut result, + rng, + ); + return result; + } + + // Fit KDEs using let...else pattern + let Ok(good_kde) = MultivariateKDE::new(good_obs) else { + // KDE construction failed - fill remaining with independent TPE + self.fill_remaining_independent_with_rng( + search_space, + &intersection, + history, + &mut result, + rng, + ); + return result; + }; + + let Ok(bad_kde) = MultivariateKDE::new(bad_obs) else { + // KDE construction failed - fill remaining with independent TPE + self.fill_remaining_independent_with_rng( + search_space, + &intersection, + history, + &mut result, + rng, + ); + return result; + }; + + let selected = self.select_candidate_with_rng(&good_kde, &bad_kde, rng); + + // Map selected values to parameter ids + for (idx, param_id) in param_order.iter().enumerate() { + if let Some(dist) = intersection.get(param_id) { + let value = selected[idx]; + let param_value = self.convert_to_param_value(value, dist); + if let Some(pv) = param_value { + result.insert(*param_id, pv); + } + } + } + + // Fill remaining parameters using independent TPE sampling + self.fill_remaining_independent_with_rng( + search_space, + &intersection, + history, + &mut result, + rng, + ); + result + } + + /// Converts a raw f64 value to a `ParamValue` based on the distribution. 
+ #[allow(clippy::unused_self)] + pub(crate) fn convert_to_param_value( + &self, + value: f64, + dist: &Distribution, + ) -> Option<ParamValue> { + match dist { + Distribution::Float(d) => { + let clamped = value.clamp(d.low, d.high); + let stepped = if let Some(step) = d.step { + let steps = ((clamped - d.low) / step).round(); + (d.low + steps * step).clamp(d.low, d.high) + } else { + clamped + }; + Some(ParamValue::Float(stepped)) + } + Distribution::Int(d) => { + #[allow(clippy::cast_possible_truncation)] + let int_value = value.round() as i64; + let clamped = int_value.clamp(d.low, d.high); + let stepped = if let Some(step) = d.step { + let steps = (clamped - d.low) / step; + (d.low + steps * step).clamp(d.low, d.high) + } else { + clamped + }; + Some(ParamValue::Int(stepped)) + } + Distribution::Categorical(_) => None, + } + } + + /// Selects the best candidate from a set of samples using the joint acquisition function. + /// + /// This method implements the core TPE selection criterion: it generates candidates + /// from the "good" KDE (l(x)) and selects the one that maximizes the ratio l(x)/g(x), + /// which is equivalent to maximizing `log(l(x)) - log(g(x))`. 
+ #[must_use] + #[allow(dead_code)] // Used by tests + pub(crate) fn select_candidate( + &self, + good_kde: &crate::kde::MultivariateKDE, + bad_kde: &crate::kde::MultivariateKDE, + ) -> Vec<f64> { + let mut rng = self.rng.lock(); + + // Generate candidates from the good distribution + let candidates: Vec<Vec<f64>> = (0..self.n_ei_candidates) + .map(|_| good_kde.sample(&mut rng)) + .collect(); + + // Compute log(l(x)) - log(g(x)) for each candidate + // This is equivalent to log(l(x)/g(x)) which we want to maximize + let log_ratios: Vec<f64> = candidates + .iter() + .map(|candidate| { + let log_l = good_kde.log_pdf(candidate); + let log_g = bad_kde.log_pdf(candidate); + log_l - log_g + }) + .collect(); + + // Find the candidate with the maximum log ratio + let mut best_idx = 0; + let mut best_ratio = f64::NEG_INFINITY; + + for (idx, &ratio) in log_ratios.iter().enumerate() { + // Handle NaN by treating it as worse than any finite value + if ratio > best_ratio || (best_ratio.is_nan() && !ratio.is_nan()) { + best_ratio = ratio; + best_idx = idx; + } + } + + candidates.into_iter().nth(best_idx).unwrap_or_default() + } + + /// Selects the best candidate using an external RNG. + /// + /// This variant accepts an external RNG, used when the caller already holds the lock. 
+ pub(crate) fn select_candidate_with_rng( + &self, + good_kde: &crate::kde::MultivariateKDE, + bad_kde: &crate::kde::MultivariateKDE, + rng: &mut fastrand::Rng, + ) -> Vec<f64> { + // Generate candidates from the good distribution + let candidates: Vec<Vec<f64>> = (0..self.n_ei_candidates) + .map(|_| good_kde.sample(rng)) + .collect(); + + // Compute log(l(x)) - log(g(x)) for each candidate + let log_ratios: Vec<f64> = candidates + .iter() + .map(|candidate| { + let log_l = good_kde.log_pdf(candidate); + let log_g = bad_kde.log_pdf(candidate); + log_l - log_g + }) + .collect(); + + // Find the candidate with the maximum log ratio + let mut best_idx = 0; + let mut best_ratio = f64::NEG_INFINITY; + + for (idx, &ratio) in log_ratios.iter().enumerate() { + if ratio > best_ratio || (best_ratio.is_nan() && !ratio.is_nan()) { + best_ratio = ratio; + best_idx = idx; + } + } + + candidates.into_iter().nth(best_idx).unwrap_or_default() + } + + /// Fills remaining parameters in result using independent TPE sampling. + /// + /// This method is used to sample parameters that are not in the intersection + /// search space. It uses independent univariate TPE sampling for each parameter, + /// similar to the standard [`TpeSampler`]. + /// + /// When there isn't enough history for a parameter, falls back to uniform sampling. 
+ #[allow(dead_code)] + pub(crate) fn fill_remaining_independent( + &self, + search_space: &HashMap<ParamId, Distribution>, + _intersection: &HashMap<ParamId, Distribution>, + history: &[CompletedTrial], + result: &mut HashMap<ParamId, ParamValue>, + ) { + // Identify parameters not in result (and not in intersection) + let missing_params: Vec<(&ParamId, &Distribution)> = search_space + .iter() + .filter(|(id, _)| !result.contains_key(id)) + .collect(); + + if missing_params.is_empty() { + return; + } + + // Split trials for independent sampling + let (good_trials, bad_trials) = self.split_trials(&history.iter().collect::<Vec<_>>()); + + let mut rng = self.rng.lock(); + + for (param_id, dist) in missing_params { + let value = + self.sample_independent_tpe(*param_id, dist, &good_trials, &bad_trials, &mut rng); + result.insert(*param_id, value); + } + } + + /// Fills remaining parameters using independent TPE sampling with an external RNG. + /// + /// This variant accepts an external RNG, used when the caller already holds the lock. + pub(crate) fn fill_remaining_independent_with_rng( + &self, + search_space: &HashMap<ParamId, Distribution>, + _intersection: &HashMap<ParamId, Distribution>, + history: &[CompletedTrial], + result: &mut HashMap<ParamId, ParamValue>, + rng: &mut fastrand::Rng, + ) { + // Identify parameters not in result (and not in intersection) + let missing_params: Vec<(&ParamId, &Distribution)> = search_space + .iter() + .filter(|(id, _)| !result.contains_key(id)) + .collect(); + + if missing_params.is_empty() { + return; + } + + // Split trials for independent sampling + let (good_trials, bad_trials) = self.split_trials(&history.iter().collect::<Vec<_>>()); + + for (param_id, dist) in missing_params { + let value = + self.sample_independent_tpe(*param_id, dist, &good_trials, &bad_trials, rng); + result.insert(*param_id, value); + } + } + + /// Samples all parameters using independent TPE sampling. 
+ /// + /// This is used as a complete fallback when no intersection search space exists. + #[allow(dead_code)] // Used by tests + pub(crate) fn sample_all_independent( + &self, + search_space: &HashMap<ParamId, Distribution>, + history: &[CompletedTrial], + ) -> HashMap<ParamId, ParamValue> { + // Split trials for independent sampling + let (good_trials, bad_trials) = self.split_trials(&history.iter().collect::<Vec<_>>()); + + let mut rng = self.rng.lock(); + let mut result = HashMap::new(); + + for (param_id, dist) in search_space { + let value = + self.sample_independent_tpe(*param_id, dist, &good_trials, &bad_trials, &mut rng); + result.insert(*param_id, value); + } + + result + } + + /// Samples all parameters using independent TPE sampling with an external RNG. + /// + /// This variant accepts an external RNG, used when the caller already holds the lock. + pub(crate) fn sample_all_independent_with_rng( + &self, + search_space: &HashMap<ParamId, Distribution>, + history: &[CompletedTrial], + rng: &mut fastrand::Rng, + ) -> HashMap<ParamId, ParamValue> { + // Split trials for independent sampling + let (good_trials, bad_trials) = self.split_trials(&history.iter().collect::<Vec<_>>()); + + let mut result = HashMap::new(); + + for (param_id, dist) in search_space { + let value = + self.sample_independent_tpe(*param_id, dist, &good_trials, &bad_trials, rng); + result.insert(*param_id, value); + } + + result + } + + /// Samples a single parameter using independent TPE. + /// + /// This method extracts values for the given parameter from good and bad trials, + /// fits univariate KDEs, and samples using the TPE acquisition function. 
+ #[allow(clippy::too_many_lines)] + pub(crate) fn sample_independent_tpe( + &self, + param_id: ParamId, + distribution: &Distribution, + good_trials: &[&CompletedTrial], + bad_trials: &[&CompletedTrial], + rng: &mut fastrand::Rng, + ) -> ParamValue { + use crate::sampler::tpe::common; + + match distribution { + Distribution::Float(d) => { + let good_values: Vec<f64> = good_trials + .iter() + .filter_map(|t| t.params.get(¶m_id)) + .filter_map(|v| match v { + ParamValue::Float(f) => Some(*f), + _ => None, + }) + .filter(|&v| v >= d.low && v <= d.high) + .collect(); + + let bad_values: Vec<f64> = bad_trials + .iter() + .filter_map(|t| t.params.get(¶m_id)) + .filter_map(|v| match v { + ParamValue::Float(f) => Some(*f), + _ => None, + }) + .filter(|&v| v >= d.low && v <= d.high) + .collect(); + + if good_values.is_empty() || bad_values.is_empty() { + return crate::sampler::common::sample_random(rng, distribution); + } + + let value = common::sample_tpe_float( + d.low, + d.high, + d.log_scale, + d.step, + good_values, + bad_values, + self.n_ei_candidates, + None, + rng, + ); + ParamValue::Float(value) + } + Distribution::Int(d) => { + let good_values: Vec<i64> = good_trials + .iter() + .filter_map(|t| t.params.get(¶m_id)) + .filter_map(|v| match v { + ParamValue::Int(i) => Some(*i), + _ => None, + }) + .filter(|&v| v >= d.low && v <= d.high) + .collect(); + + let bad_values: Vec<i64> = bad_trials + .iter() + .filter_map(|t| t.params.get(¶m_id)) + .filter_map(|v| match v { + ParamValue::Int(i) => Some(*i), + _ => None, + }) + .filter(|&v| v >= d.low && v <= d.high) + .collect(); + + if good_values.is_empty() || bad_values.is_empty() { + return crate::sampler::common::sample_random(rng, distribution); + } + + let value = common::sample_tpe_int( + d.low, + d.high, + d.log_scale, + d.step, + good_values, + bad_values, + self.n_ei_candidates, + None, + rng, + ); + ParamValue::Int(value) + } + Distribution::Categorical(d) => { + let good_indices: Vec<usize> = good_trials + 
.iter() + .filter_map(|t| t.params.get(¶m_id)) + .filter_map(|v| match v { + ParamValue::Categorical(i) => Some(*i), + _ => None, + }) + .filter(|&i| i < d.n_choices) + .collect(); + + let bad_indices: Vec<usize> = bad_trials + .iter() + .filter_map(|t| t.params.get(¶m_id)) + .filter_map(|v| match v { + ParamValue::Categorical(i) => Some(*i), + _ => None, + }) + .filter(|&i| i < d.n_choices) + .collect(); + + if good_indices.is_empty() || bad_indices.is_empty() { + return crate::sampler::common::sample_random(rng, distribution); + } + + let idx = + common::sample_tpe_categorical(d.n_choices, &good_indices, &bad_indices, rng); + ParamValue::Categorical(idx) + } + } + } +} diff --git a/src/sampler/tpe/multivariate.rs b/src/sampler/tpe/multivariate/mod.rs similarity index 80% rename from src/sampler/tpe/multivariate.rs rename to src/sampler/tpe/multivariate/mod.rs index de8b37e..4900e13 100644 --- a/src/sampler/tpe/multivariate.rs +++ b/src/sampler/tpe/multivariate/mod.rs @@ -112,6 +112,9 @@ //! let sampler = MultivariateTpeSampler::builder().build().unwrap(); //! ``` +mod engine; +mod trials; + use std::collections::HashMap; use std::sync::Arc; @@ -122,8 +125,7 @@ use crate::distribution::Distribution; use crate::error::Result; use crate::param::ParamValue; use crate::parameter::ParamId; -use crate::rng_util; -use crate::sampler::{CompletedTrial, PendingTrial, Sampler}; +use crate::sampler::{CompletedTrial, Sampler}; /// Strategy for imputing objective values for pending/running trials during parallel optimization. /// @@ -311,494 +313,6 @@ impl MultivariateTpeSampler { &self.constant_liar } - /// Imputes objective values for pending trials based on the constant liar strategy. - /// - /// In parallel optimization, multiple trials may be running simultaneously. This method - /// assigns "lie" values to pending trials so they can be included in the model fitting, - /// which helps avoid redundant exploration of the same region. 
- /// - /// # Arguments - /// - /// * `pending_trials` - Trials that are currently running and have no objective value yet. - /// * `completed_trials` - Trials that have completed and have objective values. - /// - /// # Returns - /// - /// A vector of `CompletedTrial` objects containing both the original completed trials - /// and the pending trials with imputed values. If the strategy is `None`, returns - /// only the completed trials (pending trials are ignored). - /// - /// # Imputation Strategies - /// - /// - `None`: Pending trials are ignored (returns only completed trials) - /// - `Mean`: Pending trials get the mean of completed objective values - /// - `Best`: Pending trials get the minimum completed objective value (for minimization) - /// - `Worst`: Pending trials get the maximum completed objective value (for minimization) - /// - `Custom(v)`: Pending trials get the specified value `v` - /// - /// # Examples - /// - /// ```ignore - /// use std::collections::HashMap; - /// use optimizer::sampler::{ - /// ConstantLiarStrategy, MultivariateTpeSampler, CompletedTrial, PendingTrial, - /// }; - /// use optimizer::param::ParamValue; - /// use optimizer::parameter::ParamId; - /// use optimizer::distribution::{Distribution, FloatDistribution}; - /// - /// // Create a sampler with mean imputation - /// let sampler = MultivariateTpeSampler::builder() - /// .constant_liar(ConstantLiarStrategy::Mean) - /// .build() - /// .unwrap(); - /// - /// // Create some completed trials - /// let dist = Distribution::Float(FloatDistribution { - /// low: 0.0, high: 1.0, log_scale: false, step: None, - /// }); - /// let x_id = ParamId::new(); - /// let completed = vec![ - /// CompletedTrial::new( - /// 0, - /// [(x_id, ParamValue::Float(0.2))].into_iter().collect(), - /// [(x_id, dist.clone())].into_iter().collect(), - /// HashMap::new(), - /// 1.0, - /// ), - /// CompletedTrial::new( - /// 1, - /// [(x_id, ParamValue::Float(0.8))].into_iter().collect(), - /// [(x_id, 
dist.clone())].into_iter().collect(), - /// HashMap::new(), - /// 3.0, - /// ), - /// ]; - /// - /// // Create a pending trial - /// let pending = vec![ - /// PendingTrial::new( - /// 2, - /// [(x_id, ParamValue::Float(0.5))].into_iter().collect(), - /// [(x_id, dist.clone())].into_iter().collect(), - /// HashMap::new(), - /// ), - /// ]; - /// - /// // Impute values - /// let augmented = sampler.impute_pending_trials(&pending, &completed); - /// - /// // Should have 3 trials total - /// assert_eq!(augmented.len(), 3); - /// - /// // The pending trial should have the mean value (1.0 + 3.0) / 2 = 2.0 - /// let imputed = augmented.iter().find(|t| t.id == 2).unwrap(); - /// assert!((imputed.value - 2.0).abs() < f64::EPSILON); - /// ``` - #[must_use] - pub fn impute_pending_trials( - &self, - pending_trials: &[PendingTrial], - completed_trials: &[CompletedTrial], - ) -> Vec<CompletedTrial> { - // Start with a copy of completed trials - let mut result: Vec<CompletedTrial> = completed_trials.to_vec(); - - // If strategy is None or no pending trials, just return completed trials - if matches!(self.constant_liar, ConstantLiarStrategy::None) || pending_trials.is_empty() { - return result; - } - - // Compute the imputation value based on strategy - let imputed_value = self.compute_imputation_value(completed_trials); - - // Convert pending trials to completed trials with imputed values - for pending in pending_trials { - result.push(CompletedTrial::new( - pending.id, - pending.params.clone(), - pending.distributions.clone(), - HashMap::new(), - imputed_value, - )); - } - - result - } - - /// Computes the imputation value based on the constant liar strategy. - /// - /// This is a helper method used by [`impute_pending_trials`](Self::impute_pending_trials). - /// - /// # Arguments - /// - /// * `completed_trials` - The completed trials to compute the imputation value from. - /// - /// # Returns - /// - /// The imputed value based on the strategy. 
Returns 0.0 if there are no completed - /// trials (except for `Custom` strategy which returns its specified value). - #[allow(clippy::cast_precision_loss)] - fn compute_imputation_value(&self, completed_trials: &[CompletedTrial]) -> f64 { - match self.constant_liar { - ConstantLiarStrategy::None => 0.0, // This case is handled before calling this method - ConstantLiarStrategy::Mean => { - if completed_trials.is_empty() { - 0.0 - } else { - let sum: f64 = completed_trials.iter().map(|t| t.value).sum(); - sum / completed_trials.len() as f64 - } - } - ConstantLiarStrategy::Best => { - // Best means minimum for minimization problems - completed_trials - .iter() - .map(|t| t.value) - .fold(f64::INFINITY, f64::min) - } - ConstantLiarStrategy::Worst => { - // Worst means maximum for minimization problems - completed_trials - .iter() - .map(|t| t.value) - .fold(f64::NEG_INFINITY, f64::max) - } - ConstantLiarStrategy::Custom(v) => v, - } - } - - /// Filters trials to those containing all parameters in the search space. - /// - /// This method is used to identify trials that can be used for multivariate - /// KDE fitting. Only trials that contain ALL parameters in the search space - /// are included, ensuring we can model the joint distribution over all - /// parameters. - /// - /// # Arguments - /// - /// * `history` - All completed trials from the optimization history. - /// * `search_space` - The intersection search space containing parameters that - /// appear in all trials. - /// - /// # Returns - /// - /// A vector of references to trials that contain all parameters in the search space. - /// - /// # Examples - /// - /// ```ignore - /// use std::collections::HashMap; - /// use optimizer::sampler::tpe::MultivariateTpeSampler; - /// use optimizer::sampler::tpe::IntersectionSearchSpace; - /// - /// let sampler = MultivariateTpeSampler::new(); - /// let trials = vec![/* ... completed trials ... 
*/]; - /// let search_space = IntersectionSearchSpace::calculate(&trials); - /// let filtered = sampler.filter_trials(&trials, &search_space); - /// ``` - #[must_use] - pub fn filter_trials<'a>( - &self, - history: &'a [CompletedTrial], - search_space: &HashMap<ParamId, Distribution>, - ) -> Vec<&'a CompletedTrial> { - history - .iter() - .filter(|trial| { - // Include trial only if it has ALL parameters in the search space - search_space - .keys() - .all(|param_id| trial.params.contains_key(param_id)) - }) - .collect() - } - - /// Splits filtered trials into good and bad groups based on the gamma quantile. - /// - /// The gamma value is computed dynamically using the configured [`GammaStrategy`]. - /// Trials are sorted by objective value (ascending for minimization), and the - /// gamma quantile determines the split point. - /// - /// # Arguments - /// - /// * `trials` - Filtered trials to split (typically from [`filter_trials`](Self::filter_trials)). - /// - /// # Returns - /// - /// A tuple `(good_trials, bad_trials)` where: - /// - `good_trials` contains trials with values below the gamma quantile - /// - `bad_trials` contains trials with values at or above the gamma quantile - /// - /// Both vectors are guaranteed to be non-empty when the input has at least 2 trials. - /// If the input has fewer than 2 trials, both vectors may be empty or one may - /// contain the single trial. - /// - /// # Examples - /// - /// ```ignore - /// use std::collections::HashMap; - /// use optimizer::sampler::tpe::MultivariateTpeSampler; - /// use optimizer::sampler::tpe::IntersectionSearchSpace; - /// - /// let sampler = MultivariateTpeSampler::new(); - /// let trials = vec![/* ... completed trials ... 
*/]; - /// let search_space = IntersectionSearchSpace::calculate(&trials); - /// let filtered = sampler.filter_trials(&trials, &search_space); - /// let (good, bad) = sampler.split_trials(&filtered); - /// - /// // good contains trials with lowest objective values - /// // bad contains trials with higher objective values - /// ``` - #[allow( - clippy::cast_precision_loss, - clippy::cast_possible_truncation, - clippy::cast_sign_loss - )] - #[must_use] - pub fn split_trials<'a>( - &self, - trials: &[&'a CompletedTrial], - ) -> (Vec<&'a CompletedTrial>, Vec<&'a CompletedTrial>) { - if trials.is_empty() { - return (vec![], vec![]); - } - - // Sort trials by objective value (ascending for minimization) - let mut sorted_indices: Vec<usize> = (0..trials.len()).collect(); - sorted_indices.sort_by(|&a, &b| { - trials[a] - .value - .partial_cmp(&trials[b].value) - .unwrap_or(core::cmp::Ordering::Equal) - }); - - // Compute gamma using the strategy and clamp to valid range - let gamma = self - .gamma_strategy - .gamma(trials.len()) - .clamp(f64::EPSILON, 1.0 - f64::EPSILON); - - // Calculate the split point (gamma quantile) - // Ensure at least 1 trial in each group if possible - let n_good = ((trials.len() as f64 * gamma).ceil() as usize) - .max(1) - .min(trials.len().saturating_sub(1)); - - // Handle edge case: if we have only 1 trial, put it in good - if trials.len() == 1 { - return (vec![trials[0]], vec![]); - } - - let good: Vec<_> = sorted_indices[..n_good] - .iter() - .map(|&i| trials[i]) - .collect(); - let bad: Vec<_> = sorted_indices[n_good..] - .iter() - .map(|&i| trials[i]) - .collect(); - - (good, bad) - } - - /// Extracts parameter values from trials as a numeric observation matrix. - /// - /// This method converts trial parameter values into a matrix format suitable - /// for multivariate KDE fitting. Each row in the output represents one trial's - /// parameter values in the specified order. 
- /// - /// # Arguments - /// - /// * `trials` - Trials to extract observations from. - /// * `param_order` - The order of parameters in the output vectors. This ensures - /// consistent column ordering across the observation matrix. - /// - /// # Returns - /// - /// A `Vec<Vec<f64>>` where: - /// - Outer vec has one entry per trial - /// - Inner vec has one entry per parameter (in `param_order` order) - /// - Float values are used directly - /// - Int values are converted to f64 - /// - Categorical parameters are skipped (not included in output) - /// - /// # Panics - /// - /// This method does not panic. If a parameter is missing from a trial or has - /// an unsupported type (Categorical), it is simply skipped. - /// - /// # Examples - /// - /// ```ignore - /// use std::collections::HashMap; - /// use optimizer::sampler::tpe::MultivariateTpeSampler; - /// use optimizer::parameter::ParamId; - /// - /// let sampler = MultivariateTpeSampler::new(); - /// let trials = vec![/* ... completed trials ... 
*/]; - /// let filtered = sampler.filter_trials(&trials, &search_space); - /// - /// // Extract observations for x and y in that order - /// let x_id = ParamId::new(); - /// let y_id = ParamId::new(); - /// let param_order = vec![x_id, y_id]; - /// let observations = sampler.extract_observations(&filtered, ¶m_order); - /// - /// // observations[i][0] is the x value for trial i - /// // observations[i][1] is the y value for trial i - /// ``` - #[must_use] - #[allow(clippy::cast_precision_loss)] - pub fn extract_observations( - &self, - trials: &[&CompletedTrial], - param_order: &[ParamId], - ) -> Vec<Vec<f64>> { - trials - .iter() - .map(|trial| { - param_order - .iter() - .filter_map(|param_id| { - trial.params.get(param_id).and_then(|value| match value { - crate::param::ParamValue::Float(f) => Some(*f), - crate::param::ParamValue::Int(i) => Some(*i as f64), - crate::param::ParamValue::Categorical(_) => None, // Skip categorical - }) - }) - .collect() - }) - .collect() - } - - /// Selects the best candidate from a set of samples using the joint acquisition function. - /// - /// This method implements the core TPE selection criterion: it generates candidates - /// from the "good" KDE (l(x)) and selects the one that maximizes the ratio l(x)/g(x), - /// which is equivalent to maximizing `log(l(x)) - log(g(x))`. - /// - /// # Arguments - /// - /// * `good_kde` - The KDE fitted on the "good" trials (low objective values). - /// * `bad_kde` - The KDE fitted on the "bad" trials (high objective values). - /// - /// # Returns - /// - /// A `Vec<f64>` representing the selected candidate point in the parameter space. - /// The point is chosen to maximize the expected improvement proxy l(x)/g(x). - /// - /// # Algorithm - /// - /// 1. Generate `n_ei_candidates` samples from `good_kde` - /// 2. For each candidate, compute `log(l(x)) - log(g(x))` - /// 3. 
Select the candidate with the highest ratio - /// - /// # Edge Cases - /// - /// - If `g(x)` is very small (near zero), the log-space computation handles this - /// gracefully without division issues. - /// - If all candidates have `-inf` log ratios, the first candidate is returned. - /// - /// # Examples - /// - /// ```ignore - /// use optimizer::sampler::tpe::MultivariateTpeSampler; - /// use optimizer::kde::MultivariateKDE; - /// - /// let sampler = MultivariateTpeSampler::builder() - /// .n_ei_candidates(24) - /// .seed(42) - /// .build() - /// .unwrap(); - /// - /// let good_obs = vec![vec![0.1, 0.2], vec![0.2, 0.3], vec![0.15, 0.25]]; - /// let bad_obs = vec![vec![0.8, 0.9], vec![0.7, 0.85], vec![0.9, 0.95]]; - /// - /// let good_kde = MultivariateKDE::new(good_obs).unwrap(); - /// let bad_kde = MultivariateKDE::new(bad_obs).unwrap(); - /// - /// let selected = sampler.select_candidate(&good_kde, &bad_kde); - /// // selected should be a point likely in the "good" region - /// ``` - #[must_use] - #[allow(dead_code)] // Used by tests - pub(crate) fn select_candidate( - &self, - good_kde: &crate::kde::MultivariateKDE, - bad_kde: &crate::kde::MultivariateKDE, - ) -> Vec<f64> { - let mut rng = self.rng.lock(); - - // Generate candidates from the good distribution - let candidates: Vec<Vec<f64>> = (0..self.n_ei_candidates) - .map(|_| good_kde.sample(&mut rng)) - .collect(); - - // Compute log(l(x)) - log(g(x)) for each candidate - // This is equivalent to log(l(x)/g(x)) which we want to maximize - let log_ratios: Vec<f64> = candidates - .iter() - .map(|candidate| { - let log_l = good_kde.log_pdf(candidate); - let log_g = bad_kde.log_pdf(candidate); - log_l - log_g - }) - .collect(); - - // Find the candidate with the maximum log ratio - let mut best_idx = 0; - let mut best_ratio = f64::NEG_INFINITY; - - for (idx, &ratio) in log_ratios.iter().enumerate() { - // Handle NaN by treating it as worse than any finite value - if ratio > best_ratio || (best_ratio.is_nan() 
&& !ratio.is_nan()) { - best_ratio = ratio; - best_idx = idx; - } - } - - candidates.into_iter().nth(best_idx).unwrap_or_default() - } - - /// Selects the best candidate using an external RNG. - /// - /// This variant accepts an external RNG, used when the caller already holds the lock. - fn select_candidate_with_rng( - &self, - good_kde: &crate::kde::MultivariateKDE, - bad_kde: &crate::kde::MultivariateKDE, - rng: &mut fastrand::Rng, - ) -> Vec<f64> { - // Generate candidates from the good distribution - let candidates: Vec<Vec<f64>> = (0..self.n_ei_candidates) - .map(|_| good_kde.sample(rng)) - .collect(); - - // Compute log(l(x)) - log(g(x)) for each candidate - let log_ratios: Vec<f64> = candidates - .iter() - .map(|candidate| { - let log_l = good_kde.log_pdf(candidate); - let log_g = bad_kde.log_pdf(candidate); - log_l - log_g - }) - .collect(); - - // Find the candidate with the maximum log ratio - let mut best_idx = 0; - let mut best_ratio = f64::NEG_INFINITY; - - for (idx, &ratio) in log_ratios.iter().enumerate() { - if ratio > best_ratio || (best_ratio.is_nan() && !ratio.is_nan()) { - best_ratio = ratio; - best_idx = idx; - } - } - - candidates.into_iter().nth(best_idx).unwrap_or_default() - } - /// Samples all parameters uniformly at random. /// /// This is a fallback method used when multivariate TPE cannot be applied. @@ -897,760 +411,6 @@ impl MultivariateTpeSampler { // Non-grouped mode: use the original single-group logic self.sample_single_group(search_space, history, &mut rng) } - - /// Samples parameters by decomposing the search space into independent groups. - /// - /// When `group=true`, this method analyzes the trial history to identify groups of - /// parameters that always appear together, then samples each group independently - /// using multivariate TPE. This is more efficient when parameters naturally partition - /// into independent subsets (e.g., due to conditional search spaces). 
- /// - /// # Arguments - /// - /// * `search_space` - The full search space containing all parameters to sample. - /// * `history` - Completed trials from the optimization history. - /// - /// # Returns - /// - /// A `HashMap` mapping parameter names to their sampled values. - fn sample_with_groups( - &self, - search_space: &HashMap<ParamId, Distribution>, - history: &[CompletedTrial], - ) -> HashMap<ParamId, ParamValue> { - use std::collections::HashSet; - - use super::GroupDecomposedSearchSpace; - - // Decompose the search space into independent parameter groups - let groups = GroupDecomposedSearchSpace::calculate(history); - - let mut result: HashMap<ParamId, ParamValue> = HashMap::new(); - - // Sample each group independently - for group in &groups { - // Build a sub-search space for this group - let group_search_space: HashMap<ParamId, Distribution> = search_space - .iter() - .filter(|(id, _)| group.contains(id)) - .map(|(id, dist)| (*id, dist.clone())) - .collect(); - - if group_search_space.is_empty() { - continue; - } - - // Filter history to trials that have at least one parameter in this group - let group_history: Vec<&CompletedTrial> = history - .iter() - .filter(|trial| { - trial - .distributions - .keys() - .any(|param_id| group.contains(param_id)) - }) - .collect(); - - // Build completed trials from references for the group - // We need to create a temporary slice for sample_group_internal - let group_history_owned: Vec<CompletedTrial> = - group_history.iter().map(|t| (*t).clone()).collect(); - - // Sample this group using multivariate TPE - let mut rng = self.rng.lock(); - let group_result = - self.sample_single_group(&group_search_space, &group_history_owned, &mut rng); - drop(rng); - - // Merge group results into the main result - for (id, value) in group_result { - result.insert(id, value); - } - } - - // Handle parameters not in any group (sample independently) - let grouped_params: HashSet<ParamId> = groups.iter().flatten().copied().collect(); 
- let ungrouped_params: HashMap<ParamId, Distribution> = search_space - .iter() - .filter(|(id, _)| !grouped_params.contains(id) && !result.contains_key(id)) - .map(|(id, dist)| (*id, dist.clone())) - .collect(); - - if !ungrouped_params.is_empty() { - // Sample ungrouped parameters uniformly (no history for them) - let mut rng = self.rng.lock(); - for (id, dist) in &ungrouped_params { - let value = crate::sampler::common::sample_random(&mut rng, dist); - result.insert(*id, value); - } - } - - result - } - - /// Samples parameters as a single group using multivariate TPE. - /// - /// This is the core multivariate TPE sampling logic, used both in non-grouped mode - /// and for sampling individual groups in grouped mode. - /// - /// # Arguments - /// - /// * `search_space` - The search space for this group of parameters. - /// * `history` - Completed trials to use for model fitting. - /// * `rng` - Random number generator (caller must hold lock). - /// - /// # Returns - /// - /// A `HashMap` mapping parameter names to their sampled values. 
- #[allow(clippy::too_many_lines)] - fn sample_single_group( - &self, - search_space: &HashMap<ParamId, Distribution>, - history: &[CompletedTrial], - rng: &mut fastrand::Rng, - ) -> HashMap<ParamId, ParamValue> { - use super::IntersectionSearchSpace; - use crate::kde::MultivariateKDE; - - // Early returns for cases requiring random sampling - if history.len() < self.n_startup_trials { - return self.sample_all_uniform(search_space, rng); - } - - let intersection = IntersectionSearchSpace::calculate(history); - if intersection.is_empty() { - return self.sample_all_independent_with_rng(search_space, history, rng); - } - - let filtered = self.filter_trials(history, &intersection); - if filtered.len() < 2 { - return self.sample_all_independent_with_rng(search_space, history, rng); - } - - let (good, bad) = self.split_trials(&filtered); - - // Sample categorical parameters using TPE with l(x)/g(x) ratio - let mut result: HashMap<ParamId, ParamValue> = HashMap::new(); - for (param_id, dist) in &intersection { - if let Distribution::Categorical(d) = dist { - let good_indices = Self::extract_categorical_indices(&good, *param_id); - let bad_indices = Self::extract_categorical_indices(&bad, *param_id); - let idx = - Self::sample_tpe_categorical(d.n_choices, &good_indices, &bad_indices, rng); - result.insert(*param_id, ParamValue::Categorical(idx)); - } - } - - // Collect continuous parameters - let mut param_order: Vec<ParamId> = intersection - .iter() - .filter(|(_, dist)| !matches!(dist, Distribution::Categorical(_))) - .map(|(id, _)| *id) - .collect(); - - if param_order.is_empty() { - // Only categorical parameters in intersection - fill remaining with independent TPE - self.fill_remaining_independent_with_rng( - search_space, - &intersection, - history, - &mut result, - rng, - ); - return result; - } - - param_order.sort_by_key(|id| format!("{id}")); - - // Extract observations and validate - let good_obs = self.extract_observations(&good, ¶m_order); - let bad_obs = 
self.extract_observations(&bad, ¶m_order); - let expected_dims = param_order.len(); - - let valid = !good_obs.is_empty() - && !bad_obs.is_empty() - && good_obs.iter().all(|obs| obs.len() == expected_dims) - && bad_obs.iter().all(|obs| obs.len() == expected_dims); - - if !valid { - // Observations invalid - fill remaining with independent TPE - self.fill_remaining_independent_with_rng( - search_space, - &intersection, - history, - &mut result, - rng, - ); - return result; - } - - // Fit KDEs using let...else pattern - let Ok(good_kde) = MultivariateKDE::new(good_obs) else { - // KDE construction failed - fill remaining with independent TPE - self.fill_remaining_independent_with_rng( - search_space, - &intersection, - history, - &mut result, - rng, - ); - return result; - }; - - let Ok(bad_kde) = MultivariateKDE::new(bad_obs) else { - // KDE construction failed - fill remaining with independent TPE - self.fill_remaining_independent_with_rng( - search_space, - &intersection, - history, - &mut result, - rng, - ); - return result; - }; - - let selected = self.select_candidate_with_rng(&good_kde, &bad_kde, rng); - - // Map selected values to parameter ids - for (idx, param_id) in param_order.iter().enumerate() { - if let Some(dist) = intersection.get(param_id) { - let value = selected[idx]; - let param_value = self.convert_to_param_value(value, dist); - if let Some(pv) = param_value { - result.insert(*param_id, pv); - } - } - } - - // Fill remaining parameters using independent TPE sampling - self.fill_remaining_independent_with_rng( - search_space, - &intersection, - history, - &mut result, - rng, - ); - result - } - - /// Converts a raw f64 value to a `ParamValue` based on the distribution. 
- #[allow(clippy::unused_self)] - fn convert_to_param_value(&self, value: f64, dist: &Distribution) -> Option<ParamValue> { - match dist { - Distribution::Float(d) => { - let clamped = value.clamp(d.low, d.high); - let stepped = if let Some(step) = d.step { - let steps = ((clamped - d.low) / step).round(); - (d.low + steps * step).clamp(d.low, d.high) - } else { - clamped - }; - Some(ParamValue::Float(stepped)) - } - Distribution::Int(d) => { - #[allow(clippy::cast_possible_truncation)] - let int_value = value.round() as i64; - let clamped = int_value.clamp(d.low, d.high); - let stepped = if let Some(step) = d.step { - let steps = (clamped - d.low) / step; - (d.low + steps * step).clamp(d.low, d.high) - } else { - clamped - }; - Some(ParamValue::Int(stepped)) - } - Distribution::Categorical(_) => None, - } - } - - /// Fills remaining parameters in result using independent TPE sampling. - /// - /// This method is used to sample parameters that are not in the intersection - /// search space. It uses independent univariate TPE sampling for each parameter, - /// similar to the standard [`TpeSampler`]. - /// - /// When there isn't enough history for a parameter, falls back to uniform sampling. 
- #[allow(dead_code)] - fn fill_remaining_independent( - &self, - search_space: &HashMap<ParamId, Distribution>, - _intersection: &HashMap<ParamId, Distribution>, - history: &[CompletedTrial], - result: &mut HashMap<ParamId, ParamValue>, - ) { - // Identify parameters not in result (and not in intersection) - let missing_params: Vec<(&ParamId, &Distribution)> = search_space - .iter() - .filter(|(id, _)| !result.contains_key(id)) - .collect(); - - if missing_params.is_empty() { - return; - } - - // Split trials for independent sampling - let (good_trials, bad_trials) = self.split_trials(&history.iter().collect::<Vec<_>>()); - - let mut rng = self.rng.lock(); - - for (param_id, dist) in missing_params { - let value = - self.sample_independent_tpe(*param_id, dist, &good_trials, &bad_trials, &mut rng); - result.insert(*param_id, value); - } - } - - /// Fills remaining parameters using independent TPE sampling with an external RNG. - /// - /// This variant accepts an external RNG, used when the caller already holds the lock. - fn fill_remaining_independent_with_rng( - &self, - search_space: &HashMap<ParamId, Distribution>, - _intersection: &HashMap<ParamId, Distribution>, - history: &[CompletedTrial], - result: &mut HashMap<ParamId, ParamValue>, - rng: &mut fastrand::Rng, - ) { - // Identify parameters not in result (and not in intersection) - let missing_params: Vec<(&ParamId, &Distribution)> = search_space - .iter() - .filter(|(id, _)| !result.contains_key(id)) - .collect(); - - if missing_params.is_empty() { - return; - } - - // Split trials for independent sampling - let (good_trials, bad_trials) = self.split_trials(&history.iter().collect::<Vec<_>>()); - - for (param_id, dist) in missing_params { - let value = - self.sample_independent_tpe(*param_id, dist, &good_trials, &bad_trials, rng); - result.insert(*param_id, value); - } - } - - /// Samples a single parameter using independent TPE. 
- /// - /// This method extracts values for the given parameter from good and bad trials, - /// fits univariate KDEs, and samples using the TPE acquisition function. - #[allow(clippy::too_many_lines)] - fn sample_independent_tpe( - &self, - param_id: ParamId, - distribution: &Distribution, - good_trials: &[&CompletedTrial], - bad_trials: &[&CompletedTrial], - rng: &mut fastrand::Rng, - ) -> ParamValue { - match distribution { - Distribution::Float(d) => { - let good_values: Vec<f64> = good_trials - .iter() - .filter_map(|t| t.params.get(¶m_id)) - .filter_map(|v| match v { - ParamValue::Float(f) => Some(*f), - _ => None, - }) - .filter(|&v| v >= d.low && v <= d.high) - .collect(); - - let bad_values: Vec<f64> = bad_trials - .iter() - .filter_map(|t| t.params.get(¶m_id)) - .filter_map(|v| match v { - ParamValue::Float(f) => Some(*f), - _ => None, - }) - .filter(|&v| v >= d.low && v <= d.high) - .collect(); - - if good_values.is_empty() || bad_values.is_empty() { - return crate::sampler::common::sample_random(rng, distribution); - } - - let value = self.sample_tpe_float( - d.low, - d.high, - d.log_scale, - d.step, - good_values, - bad_values, - rng, - ); - ParamValue::Float(value) - } - Distribution::Int(d) => { - let good_values: Vec<i64> = good_trials - .iter() - .filter_map(|t| t.params.get(¶m_id)) - .filter_map(|v| match v { - ParamValue::Int(i) => Some(*i), - _ => None, - }) - .filter(|&v| v >= d.low && v <= d.high) - .collect(); - - let bad_values: Vec<i64> = bad_trials - .iter() - .filter_map(|t| t.params.get(¶m_id)) - .filter_map(|v| match v { - ParamValue::Int(i) => Some(*i), - _ => None, - }) - .filter(|&v| v >= d.low && v <= d.high) - .collect(); - - if good_values.is_empty() || bad_values.is_empty() { - return crate::sampler::common::sample_random(rng, distribution); - } - - let value = self.sample_tpe_int( - d.low, - d.high, - d.log_scale, - d.step, - good_values, - bad_values, - rng, - ); - ParamValue::Int(value) - } - Distribution::Categorical(d) => { - 
let good_indices: Vec<usize> = good_trials - .iter() - .filter_map(|t| t.params.get(¶m_id)) - .filter_map(|v| match v { - ParamValue::Categorical(i) => Some(*i), - _ => None, - }) - .filter(|&i| i < d.n_choices) - .collect(); - - let bad_indices: Vec<usize> = bad_trials - .iter() - .filter_map(|t| t.params.get(¶m_id)) - .filter_map(|v| match v { - ParamValue::Categorical(i) => Some(*i), - _ => None, - }) - .filter(|&i| i < d.n_choices) - .collect(); - - if good_indices.is_empty() || bad_indices.is_empty() { - return crate::sampler::common::sample_random(rng, distribution); - } - - let idx = - Self::sample_tpe_categorical(d.n_choices, &good_indices, &bad_indices, rng); - ParamValue::Categorical(idx) - } - } - } - - /// Samples using TPE for float distributions. - #[allow(clippy::too_many_arguments)] - fn sample_tpe_float( - &self, - low: f64, - high: f64, - log_scale: bool, - step: Option<f64>, - good_values: Vec<f64>, - bad_values: Vec<f64>, - rng: &mut fastrand::Rng, - ) -> f64 { - use crate::kde::KernelDensityEstimator; - - // Transform to internal space (log space if needed) - let (internal_low, internal_high, good_internal, bad_internal) = if log_scale { - let i_low = low.ln(); - let i_high = high.ln(); - let g = { - let mut v = good_values; - for x in &mut v { - *x = x.ln(); - } - v - }; - let b = { - let mut v = bad_values; - for x in &mut v { - *x = x.ln(); - } - v - }; - (i_low, i_high, g, b) - } else { - (low, high, good_values, bad_values) - }; - - // Fit KDEs to good and bad groups - let l_kde = KernelDensityEstimator::new(good_internal); - let g_kde = KernelDensityEstimator::new(bad_internal); - - // If KDE construction fails, fall back to uniform sampling - let (Ok(l_kde), Ok(g_kde)) = (l_kde, g_kde) else { - return rng_util::f64_range(rng, low, high); - }; - - // Generate candidates from l(x) and select the one with best l(x)/g(x) ratio - let mut best_candidate = internal_low; - let mut best_ratio = f64::NEG_INFINITY; - - for _ in 
0..self.n_ei_candidates { - let candidate = l_kde.sample(rng); - - // Clamp to bounds - let candidate = candidate.clamp(internal_low, internal_high); - - let l_density = l_kde.pdf(candidate); - let g_density = g_kde.pdf(candidate); - - // Compute l(x)/g(x) ratio, handling zero density - let ratio = if g_density < f64::EPSILON { - if l_density > f64::EPSILON { - f64::INFINITY - } else { - 0.0 - } - } else { - l_density / g_density - }; - - if ratio > best_ratio { - best_ratio = ratio; - best_candidate = candidate; - } - } - - // Transform back from internal space - let mut value = if log_scale { - best_candidate.exp() - } else { - best_candidate - }; - - // Apply step constraint if present - if let Some(step) = step { - let k = ((value - low) / step).round(); - value = low + k * step; - } - - // Ensure value is within bounds - value.clamp(low, high) - } - - /// Samples using TPE for integer distributions. - #[allow( - clippy::too_many_arguments, - clippy::cast_precision_loss, - clippy::cast_possible_truncation - )] - fn sample_tpe_int( - &self, - low: i64, - high: i64, - log_scale: bool, - step: Option<i64>, - good_values: Vec<i64>, - bad_values: Vec<i64>, - rng: &mut fastrand::Rng, - ) -> i64 { - // Convert to floats for KDE - let good_floats: Vec<f64> = good_values.into_iter().map(|v| v as f64).collect(); - let bad_floats: Vec<f64> = bad_values.into_iter().map(|v| v as f64).collect(); - - // Use float TPE sampling - let float_value = self.sample_tpe_float( - low as f64, - high as f64, - log_scale, - step.map(|s| s as f64), - good_floats, - bad_floats, - rng, - ); - - // Round to nearest integer - let int_value = float_value.round() as i64; - - // Apply step constraint if present - let int_value = if let Some(step) = step { - let k = ((int_value - low) as f64 / step as f64).round() as i64; - low + k * step - } else { - int_value - }; - - // Ensure value is within bounds - int_value.clamp(low, high) - } - - /// Samples all parameters using independent TPE sampling. 
- /// - /// This is used as a complete fallback when no intersection search space exists. - #[allow(dead_code)] // Used by tests - fn sample_all_independent( - &self, - search_space: &HashMap<ParamId, Distribution>, - history: &[CompletedTrial], - ) -> HashMap<ParamId, ParamValue> { - // Split trials for independent sampling - let (good_trials, bad_trials) = self.split_trials(&history.iter().collect::<Vec<_>>()); - - let mut rng = self.rng.lock(); - let mut result = HashMap::new(); - - for (param_id, dist) in search_space { - let value = - self.sample_independent_tpe(*param_id, dist, &good_trials, &bad_trials, &mut rng); - result.insert(*param_id, value); - } - - result - } - - /// Samples all parameters using independent TPE sampling with an external RNG. - /// - /// This variant accepts an external RNG, used when the caller already holds the lock. - fn sample_all_independent_with_rng( - &self, - search_space: &HashMap<ParamId, Distribution>, - history: &[CompletedTrial], - rng: &mut fastrand::Rng, - ) -> HashMap<ParamId, ParamValue> { - // Split trials for independent sampling - let (good_trials, bad_trials) = self.split_trials(&history.iter().collect::<Vec<_>>()); - - let mut result = HashMap::new(); - - for (param_id, dist) in search_space { - let value = - self.sample_independent_tpe(*param_id, dist, &good_trials, &bad_trials, rng); - result.insert(*param_id, value); - } - - result - } - - /// Samples using TPE for categorical distributions. - /// - /// This method computes kernel-weighted category probabilities based on - /// the good and bad trial groups, then samples proportionally to the - /// `l(x)/g(x)` ratio for each category. - /// - /// # Arguments - /// - /// * `n_choices` - The number of categories in the categorical distribution. - /// * `good_indices` - Category indices from the "good" trials. - /// * `bad_indices` - Category indices from the "bad" trials. - /// * `rng` - Random number generator for sampling. 
- /// - /// # Returns - /// - /// The selected category index. - /// - /// # Algorithm - /// - /// 1. Count occurrences of each category in good and bad groups - /// 2. Apply Laplace smoothing (add 1 to each count) to avoid zero probabilities - /// 3. Compute `l(x)/g(x)` ratio for each category - /// 4. Sample proportionally to the computed weights - #[allow(clippy::cast_precision_loss)] - fn sample_tpe_categorical( - n_choices: usize, - good_indices: &[usize], - bad_indices: &[usize], - rng: &mut fastrand::Rng, - ) -> usize { - // Stack-allocate for the common case (<=32 choices), heap for rare large cases - let mut good_buf = [0usize; 32]; - let mut bad_buf = [0usize; 32]; - let mut weight_buf = [0.0f64; 32]; - - let mut good_vec; - let mut bad_vec; - let mut weight_vec; - - let (good_counts, bad_counts, weights): (&mut [usize], &mut [usize], &mut [f64]) = - if n_choices <= 32 { - ( - &mut good_buf[..n_choices], - &mut bad_buf[..n_choices], - &mut weight_buf[..n_choices], - ) - } else { - good_vec = vec![0usize; n_choices]; - bad_vec = vec![0usize; n_choices]; - weight_vec = vec![0.0f64; n_choices]; - (&mut good_vec, &mut bad_vec, &mut weight_vec) - }; - - // Count occurrences in good and bad groups - for &idx in good_indices { - if idx < n_choices { - good_counts[idx] += 1; - } - } - for &idx in bad_indices { - if idx < n_choices { - bad_counts[idx] += 1; - } - } - - // Add Laplace smoothing to avoid zero probabilities - let good_total = good_indices.len() as f64 + n_choices as f64; - let bad_total = bad_indices.len() as f64 + n_choices as f64; - - // Calculate l(x)/g(x) ratio for each category - for i in 0..n_choices { - let l_prob = (good_counts[i] as f64 + 1.0) / good_total; - let g_prob = (bad_counts[i] as f64 + 1.0) / bad_total; - weights[i] = l_prob / g_prob; - } - - // Sample proportionally to weights - let total_weight: f64 = weights.iter().sum(); - let threshold = rng.f64() * total_weight; - - let mut cumulative = 0.0; - for (i, &w) in 
weights.iter().enumerate() { - cumulative += w; - if cumulative >= threshold { - return i; - } - } - - // Fallback to last index (shouldn't happen) - n_choices - 1 - } - - /// Extracts categorical indices from trials for a specific parameter. - /// - /// # Arguments - /// - /// * `trials` - Trials to extract from. - /// * `param_name` - The name of the categorical parameter. - /// - /// # Returns - /// - /// A vector of category indices from the trials. - fn extract_categorical_indices(trials: &[&CompletedTrial], param_id: ParamId) -> Vec<usize> { - trials - .iter() - .filter_map(|trial| { - trial.params.get(¶m_id).and_then(|value| { - if let ParamValue::Categorical(idx) = value { - Some(*idx) - } else { - None - } - }) - }) - .collect() - } } impl Default for MultivariateTpeSampler { @@ -4883,7 +3643,7 @@ mod tests { // Sample many times and check bias toward good category let mut counts = [0usize; 3]; for _ in 0..1000 { - let idx = MultivariateTpeSampler::sample_tpe_categorical( + let idx = crate::sampler::tpe::common::sample_tpe_categorical( 3, &good_indices, &bad_indices, @@ -4917,7 +3677,7 @@ mod tests { let mut sampled_two = false; for _ in 0..1000 { - let idx = MultivariateTpeSampler::sample_tpe_categorical( + let idx = crate::sampler::tpe::common::sample_tpe_categorical( 3, &good_indices, &bad_indices, @@ -4945,7 +3705,7 @@ mod tests { let mut counts = [0usize; 3]; for _ in 0..1000 { - let idx = MultivariateTpeSampler::sample_tpe_categorical( + let idx = crate::sampler::tpe::common::sample_tpe_categorical( 3, &good_indices, &bad_indices, @@ -4970,7 +3730,7 @@ mod tests { // All samples should be valid indices for _ in 0..100 { - let idx = MultivariateTpeSampler::sample_tpe_categorical( + let idx = crate::sampler::tpe::common::sample_tpe_categorical( n_choices, &good_indices, &bad_indices, diff --git a/src/sampler/tpe/multivariate/trials.rs b/src/sampler/tpe/multivariate/trials.rs new file mode 100644 index 0000000..e1d31c9 --- /dev/null +++ 
b/src/sampler/tpe/multivariate/trials.rs @@ -0,0 +1,220 @@ +//! Trial processing for the multivariate TPE sampler. +//! +//! Contains constant-liar imputation, trial filtering, good/bad splitting, +//! observation extraction, and categorical index extraction. + +use std::collections::HashMap; + +use crate::distribution::Distribution; +use crate::param::ParamValue; +use crate::parameter::ParamId; +use crate::sampler::{CompletedTrial, PendingTrial}; + +use super::{ConstantLiarStrategy, MultivariateTpeSampler}; + +impl MultivariateTpeSampler { + /// Imputes objective values for pending trials based on the constant liar strategy. + /// + /// In parallel optimization, multiple trials may be running simultaneously. This method + /// assigns "lie" values to pending trials so they can be included in the model fitting, + /// which helps avoid redundant exploration of the same region. + /// + /// # Arguments + /// + /// * `pending_trials` - Trials that are currently running and have no objective value yet. + /// * `completed_trials` - Trials that have completed and have objective values. + /// + /// # Returns + /// + /// A vector of `CompletedTrial` objects containing both the original completed trials + /// and the pending trials with imputed values. If the strategy is `None`, returns + /// only the completed trials (pending trials are ignored). 
+ #[must_use] + pub fn impute_pending_trials( + &self, + pending_trials: &[PendingTrial], + completed_trials: &[CompletedTrial], + ) -> Vec<CompletedTrial> { + // Start with a copy of completed trials + let mut result: Vec<CompletedTrial> = completed_trials.to_vec(); + + // If strategy is None or no pending trials, just return completed trials + if matches!(self.constant_liar, ConstantLiarStrategy::None) || pending_trials.is_empty() { + return result; + } + + // Compute the imputation value based on strategy + let imputed_value = self.compute_imputation_value(completed_trials); + + // Convert pending trials to completed trials with imputed values + for pending in pending_trials { + result.push(CompletedTrial::new( + pending.id, + pending.params.clone(), + pending.distributions.clone(), + HashMap::new(), + imputed_value, + )); + } + + result + } + + /// Computes the imputation value based on the constant liar strategy. + /// + /// This is a helper method used by [`impute_pending_trials`](Self::impute_pending_trials). + #[allow(clippy::cast_precision_loss)] + pub(crate) fn compute_imputation_value(&self, completed_trials: &[CompletedTrial]) -> f64 { + match self.constant_liar { + ConstantLiarStrategy::None => 0.0, // This case is handled before calling this method + ConstantLiarStrategy::Mean => { + if completed_trials.is_empty() { + 0.0 + } else { + let sum: f64 = completed_trials.iter().map(|t| t.value).sum(); + sum / completed_trials.len() as f64 + } + } + ConstantLiarStrategy::Best => { + // Best means minimum for minimization problems + completed_trials + .iter() + .map(|t| t.value) + .fold(f64::INFINITY, f64::min) + } + ConstantLiarStrategy::Worst => { + // Worst means maximum for minimization problems + completed_trials + .iter() + .map(|t| t.value) + .fold(f64::NEG_INFINITY, f64::max) + } + ConstantLiarStrategy::Custom(v) => v, + } + } + + /// Filters trials to those containing all parameters in the search space. 
+ /// + /// Only trials that contain ALL parameters in the search space are included, + /// ensuring we can model the joint distribution over all parameters. + #[must_use] + pub fn filter_trials<'a>( + &self, + history: &'a [CompletedTrial], + search_space: &HashMap<ParamId, Distribution>, + ) -> Vec<&'a CompletedTrial> { + history + .iter() + .filter(|trial| { + // Include trial only if it has ALL parameters in the search space + search_space + .keys() + .all(|param_id| trial.params.contains_key(param_id)) + }) + .collect() + } + + /// Splits filtered trials into good and bad groups based on the gamma quantile. + /// + /// The gamma value is computed dynamically using the configured [`GammaStrategy`]. + /// Trials are sorted by objective value (ascending for minimization), and the + /// gamma quantile determines the split point. + #[allow( + clippy::cast_precision_loss, + clippy::cast_possible_truncation, + clippy::cast_sign_loss + )] + #[must_use] + pub fn split_trials<'a>( + &self, + trials: &[&'a CompletedTrial], + ) -> (Vec<&'a CompletedTrial>, Vec<&'a CompletedTrial>) { + if trials.is_empty() { + return (vec![], vec![]); + } + + // Sort trials by objective value (ascending for minimization) + let mut sorted_indices: Vec<usize> = (0..trials.len()).collect(); + sorted_indices.sort_by(|&a, &b| { + trials[a] + .value + .partial_cmp(&trials[b].value) + .unwrap_or(core::cmp::Ordering::Equal) + }); + + // Compute gamma using the strategy and clamp to valid range + let gamma = self + .gamma_strategy + .gamma(trials.len()) + .clamp(f64::EPSILON, 1.0 - f64::EPSILON); + + // Calculate the split point (gamma quantile) + // Ensure at least 1 trial in each group if possible + let n_good = ((trials.len() as f64 * gamma).ceil() as usize) + .max(1) + .min(trials.len().saturating_sub(1)); + + // Handle edge case: if we have only 1 trial, put it in good + if trials.len() == 1 { + return (vec![trials[0]], vec![]); + } + + let good: Vec<_> = sorted_indices[..n_good] + .iter() + 
.map(|&i| trials[i]) + .collect(); + let bad: Vec<_> = sorted_indices[n_good..] + .iter() + .map(|&i| trials[i]) + .collect(); + + (good, bad) + } + + /// Extracts parameter values from trials as a numeric observation matrix. + /// + /// Each row in the output represents one trial's parameter values in the specified order. + /// Categorical parameters are skipped. + #[must_use] + #[allow(clippy::cast_precision_loss)] + pub fn extract_observations( + &self, + trials: &[&CompletedTrial], + param_order: &[ParamId], + ) -> Vec<Vec<f64>> { + trials + .iter() + .map(|trial| { + param_order + .iter() + .filter_map(|param_id| { + trial.params.get(param_id).and_then(|value| match value { + crate::param::ParamValue::Float(f) => Some(*f), + crate::param::ParamValue::Int(i) => Some(*i as f64), + crate::param::ParamValue::Categorical(_) => None, // Skip categorical + }) + }) + .collect() + }) + .collect() + } + + /// Extracts categorical indices from trials for a specific parameter. + pub(crate) fn extract_categorical_indices( + trials: &[&CompletedTrial], + param_id: ParamId, + ) -> Vec<usize> { + trials + .iter() + .filter_map(|trial| { + trial.params.get(¶m_id).and_then(|value| { + if let ParamValue::Categorical(idx) = value { + Some(*idx) + } else { + None + } + }) + }) + .collect() + } +} diff --git a/src/study/persistence.rs b/src/study/persistence.rs index 78b94cf..6e6ffc9 100644 --- a/src/study/persistence.rs +++ b/src/study/persistence.rs @@ -1,8 +1,12 @@ +#[cfg(feature = "serde")] use std::collections::HashMap; +#[cfg(feature = "serde")] use crate::sampler::CompletedTrial; +#[cfg(feature = "serde")] use crate::types::Direction; +#[cfg(feature = "serde")] use super::Study; /// A serializable snapshot of a study's state. 
From f0798ca58a1baba00f8f96acc4a68b4a1f1a8be9 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 15:42:34 +0100 Subject: [PATCH 31/48] refactor(study): move export_json() from persistence to export module --- src/study/export.rs | 17 +++++++++++++++++ src/study/persistence.rs | 14 -------------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/study/export.rs b/src/study/export.rs index 913ab2b..62c6368 100644 --- a/src/study/export.rs +++ b/src/study/export.rs @@ -220,6 +220,23 @@ where } } +#[cfg(feature = "serde")] +impl<V: PartialOrd + Clone + serde::Serialize> Study<V> { + /// Export trials as a pretty-printed JSON array to a file. + /// + /// Each element in the array is a serialized [`CompletedTrial`]. + /// Requires the `serde` feature. + /// + /// # Errors + /// + /// Returns an I/O error if the file cannot be created or written. + pub fn export_json(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> { + let file = std::fs::File::create(path)?; + let trials = self.trials(); + serde_json::to_writer_pretty(file, &trials).map_err(std::io::Error::other) + } +} + impl Study<f64> { /// Generate an HTML report with interactive Plotly.js charts. /// diff --git a/src/study/persistence.rs b/src/study/persistence.rs index 6e6ffc9..c9174d6 100644 --- a/src/study/persistence.rs +++ b/src/study/persistence.rs @@ -41,20 +41,6 @@ pub struct StudySnapshot<V> { #[cfg(feature = "serde")] impl<V: PartialOrd + Clone + serde::Serialize> Study<V> { - /// Export trials as a pretty-printed JSON array to a file. - /// - /// Each element in the array is a serialized [`CompletedTrial`]. - /// Requires the `serde` feature. - /// - /// # Errors - /// - /// Returns an I/O error if the file cannot be created or written. 
- pub fn export_json(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> { - let file = std::fs::File::create(path)?; - let trials = self.trials(); - serde_json::to_writer_pretty(file, &trials).map_err(std::io::Error::other) - } - /// Save the study state to a JSON file. /// /// # Errors From 38a20bbc42a345ec95e9c1a474ea8ad7196d7fd2 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 15:47:53 +0100 Subject: [PATCH 32/48] fix(docs): resolve all broken rustdoc intra-doc links Fix direct path references for in-scope items (GammaStrategy, CompletedTrial) and convert feature-gated items to plain backtick formatting so docs build cleanly without --all-features. --- src/lib.rs | 18 +++++++++--------- src/sampler/de.rs | 4 ++-- src/sampler/random.rs | 2 +- src/sampler/tpe/multivariate/trials.rs | 2 +- src/storage/mod.rs | 6 +++--- src/study/export.rs | 2 +- src/study/mod.rs | 2 +- 7 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a7c507e..9ce87ad 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -55,9 +55,9 @@ //! | [`RandomSampler`](sampler::RandomSampler) | Uniform random | Baselines, high-dimensional | — | //! | [`TpeSampler`](sampler::TpeSampler) | Tree-Parzen Estimator | General-purpose Bayesian | — | //! | [`GridSearchSampler`](sampler::GridSampler) | Exhaustive grid | Small, discrete spaces | — | -//! | [`SobolSampler`](sampler::SobolSampler) | Sobol quasi-random sequence | Space-filling, low dimensions | `sobol` | -//! | [`CmaEsSampler`](sampler::CmaEsSampler) | CMA-ES | Continuous, moderate dimensions | `cma-es` | -//! | [`GpSampler`](sampler::GpSampler) | Gaussian Process + EI | Expensive objectives, few trials | `gp` | +//! | `SobolSampler` | Sobol quasi-random sequence | Space-filling, low dimensions | `sobol` | +//! | `CmaEsSampler` | CMA-ES | Continuous, moderate dimensions | `cma-es` | +//! 
| `GpSampler` | Gaussian Process + EI | Expensive objectives, few trials | `gp` | //! | [`DESampler`](sampler::DESampler) | Differential Evolution | Non-convex, population-based | — | //! | [`BohbSampler`](sampler::BohbSampler) | BOHB (TPE + `HyperBand`) | Budget-aware early stopping | — | //! @@ -74,13 +74,13 @@ //! //! | Flag | What it enables | Default | //! |------|----------------|---------| -//! | `async` | Async/parallel optimization via tokio ([`Study::optimize_async`], [`Study::optimize_parallel`]) | off | +//! | `async` | Async/parallel optimization via tokio (`Study::optimize_async`, `Study::optimize_parallel`) | off | //! | `derive` | `#[derive(Categorical)]` for enum parameters | off | -//! | `serde` | `Serialize`/`Deserialize` on public types, [`Study::save`]/[`Study::load`] | off | -//! | `journal` | [`JournalStorage`](storage::JournalStorage) — JSONL persistence with file locking (enables `serde`) | off | -//! | `sobol` | [`SobolSampler`](sampler::SobolSampler) — quasi-random low-discrepancy sequences | off | -//! | `cma-es` | [`CmaEsSampler`](sampler::CmaEsSampler) — Covariance Matrix Adaptation Evolution Strategy | off | -//! | `gp` | [`GpSampler`](sampler::GpSampler) — Gaussian Process surrogate with Expected Improvement | off | +//! | `serde` | `Serialize`/`Deserialize` on public types, `Study::save`/`Study::load` | off | +//! | `journal` | `JournalStorage` — JSONL persistence with file locking (enables `serde`) | off | +//! | `sobol` | `SobolSampler` — quasi-random low-discrepancy sequences | off | +//! | `cma-es` | `CmaEsSampler` — Covariance Matrix Adaptation Evolution Strategy | off | +//! | `gp` | `GpSampler` — Gaussian Process surrogate with Expected Improvement | off | //! | `tracing` | Structured log events via [`tracing`](https://docs.rs/tracing) at key optimization points | off | /// Emit a `tracing::info!` event when the `tracing` feature is enabled. 
diff --git a/src/sampler/de.rs b/src/sampler/de.rs index 274078a..57ded49 100644 --- a/src/sampler/de.rs +++ b/src/sampler/de.rs @@ -29,9 +29,9 @@ //! - **No feature flags required** — DE is available with default features. //! //! For non-separable problems in moderate dimensions, consider -//! [`CmaEsSampler`](super::cma_es::CmaEsSampler) which learns parameter +//! `CmaEsSampler` which learns parameter //! correlations. For expensive functions with few dimensions, consider -//! [`GpSampler`](super::gp::GpSampler). +//! `GpSampler`. //! //! # Configuration //! diff --git a/src/sampler/random.rs b/src/sampler/random.rs index 2850376..22036b7 100644 --- a/src/sampler/random.rs +++ b/src/sampler/random.rs @@ -15,7 +15,7 @@ //! surprisingly competitive. //! //! For better uniform coverage without model fitting, consider -//! [`SobolSampler`](super::sobol::SobolSampler) (requires the `sobol` +//! `SobolSampler` (requires the `sobol` //! feature flag). //! //! # Example diff --git a/src/sampler/tpe/multivariate/trials.rs b/src/sampler/tpe/multivariate/trials.rs index e1d31c9..6e1f00d 100644 --- a/src/sampler/tpe/multivariate/trials.rs +++ b/src/sampler/tpe/multivariate/trials.rs @@ -116,7 +116,7 @@ impl MultivariateTpeSampler { /// Splits filtered trials into good and bad groups based on the gamma quantile. /// - /// The gamma value is computed dynamically using the configured [`GammaStrategy`]. + /// The gamma value is computed dynamically using the configured [`super::GammaStrategy`]. /// Trials are sorted by objective value (ascending for minimization), and the /// gamma quantile determines the split point. #[allow( diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 1ed6357..d0e32a0 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -9,12 +9,12 @@ //! | Backend | Description | Feature flag | //! |---------|-------------|-------------| //! | [`MemoryStorage`] | In-memory `Vec` behind a read-write lock (the default) | — | -//! 
| [`JournalStorage`] | JSONL file with `fs2` file locking for multi-process sharing | `journal` | +//! | `JournalStorage` | JSONL file with `fs2` file locking for multi-process sharing | `journal` | //! //! # When to swap backends //! //! The default [`MemoryStorage`] is sufficient for single-process studies -//! where persistence is not needed. Switch to [`JournalStorage`] when you +//! where persistence is not needed. Switch to `JournalStorage` when you //! want to: //! //! - **Resume** a study after a process restart. @@ -59,7 +59,7 @@ use crate::sampler::CompletedTrial; /// a plain `Vec` behind a read-write lock. /// /// Implementations must be `Send + Sync` because a study may be shared -/// across threads (e.g. via [`optimize_parallel`](crate::Study::optimize_parallel)). +/// across threads (e.g. via `Study::optimize_parallel`). pub trait Storage<V>: Send + Sync { /// Append a completed trial to the store. fn push(&self, trial: CompletedTrial<V>); diff --git a/src/study/export.rs b/src/study/export.rs index 62c6368..0db9ef9 100644 --- a/src/study/export.rs +++ b/src/study/export.rs @@ -224,7 +224,7 @@ where impl<V: PartialOrd + Clone + serde::Serialize> Study<V> { /// Export trials as a pretty-printed JSON array to a file. /// - /// Each element in the array is a serialized [`CompletedTrial`]. + /// Each element in the array is a serialized [`crate::CompletedTrial`]. /// Requires the `serde` feature. /// /// # Errors diff --git a/src/study/mod.rs b/src/study/mod.rs index f2981c5..00948c1 100644 --- a/src/study/mod.rs +++ b/src/study/mod.rs @@ -225,7 +225,7 @@ where /// /// This is the most general constructor — all other constructors /// delegate to this one. Use it when you need a non-default storage - /// backend (e.g., [`JournalStorage`](crate::storage::JournalStorage)). + /// backend (e.g., `JournalStorage`). 
/// /// # Arguments /// From cc50e1ad4b7cb39688ea3bd4f91e61fed35ebf9a Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 15:52:59 +0100 Subject: [PATCH 33/48] refactor: reduce #[allow] attributes and extract per-distribution helpers - Remove stale #[allow(dead_code)] and #[allow(too_many_lines)] attributes - Replace #[allow(dead_code)] with #[cfg(test)] for test-only methods - Delete unused fill_remaining_independent() and ParamMeta.dist field - Extract sample_float/int/categorical helpers from TpeSampler, MotpeSampler, and SamplingEngine to shorten match-heavy sample() methods - Extract try_fit_kdes() helper to consolidate repeated fallback blocks - Refactor sample_tpe_float/int to take &FloatDistribution/&IntDistribution instead of destructured args, removing #[allow(too_many_arguments)] --- src/kde/multivariate.rs | 9 +- src/sampler/motpe.rs | 234 ++++++++++-------- src/sampler/tpe/common.rs | 34 ++- src/sampler/tpe/multivariate/engine.rs | 330 ++++++++++++------------- src/sampler/tpe/multivariate/mod.rs | 1 - src/sampler/tpe/sampler.rs | 242 +++++++++--------- src/study/export.rs | 2 +- src/visualization.rs | 6 +- 8 files changed, 427 insertions(+), 431 deletions(-) diff --git a/src/kde/multivariate.rs b/src/kde/multivariate.rs index 00bd843..ef1c02e 100644 --- a/src/kde/multivariate.rs +++ b/src/kde/multivariate.rs @@ -30,7 +30,6 @@ use crate::error::{Error, Result}; /// // Get dimensionality /// assert_eq!(kde.n_dims(), 2); /// ``` -#[allow(dead_code)] // Fields and methods will be used in subsequent stories (US-003, US-004) #[derive(Clone, Debug)] pub(crate) struct MultivariateKDE { /// The sample points used to construct the KDE. @@ -43,7 +42,6 @@ pub(crate) struct MultivariateKDE { n_dims: usize, } -#[allow(dead_code)] // Methods will be used in subsequent stories (US-003, US-004) impl MultivariateKDE { /// Creates a new multivariate KDE with automatic bandwidth selection using Scott's rule. 
/// @@ -100,6 +98,7 @@ impl MultivariateKDE { /// Returns `Error::ZeroDimensions` if samples have zero dimensions. /// Returns `Error::BandwidthDimensionMismatch` if bandwidths length doesn't match dimensions. /// Returns `Error::InvalidBandwidth` if any bandwidth is not positive. + #[cfg(test)] pub(crate) fn with_bandwidths(samples: Vec<Vec<f64>>, bandwidths: Vec<f64>) -> Result<Self> { if samples.is_empty() { return Err(Error::EmptySamples); @@ -180,12 +179,13 @@ impl MultivariateKDE { } /// Returns the number of dimensions. + #[cfg(test)] pub(crate) fn n_dims(&self) -> usize { self.n_dims } /// Returns the number of samples. - #[allow(dead_code)] // Will be used in subsequent stories + #[cfg(test)] pub(crate) fn n_samples(&self) -> usize { self.samples.len() } @@ -197,7 +197,7 @@ impl MultivariateKDE { } /// Returns a reference to the samples. - #[allow(dead_code)] // Will be used in subsequent stories + #[cfg(test)] pub(crate) fn samples(&self) -> &[Vec<f64>] { &self.samples } @@ -288,6 +288,7 @@ impl MultivariateKDE { /// # Panics /// /// Panics if `x.len() != self.n_dims`. 
+ #[cfg(test)] pub(crate) fn pdf(&self, x: &[f64]) -> f64 { self.log_pdf(x).exp() } diff --git a/src/sampler/motpe.rs b/src/sampler/motpe.rs index a84580d..8049dee 100644 --- a/src/sampler/motpe.rs +++ b/src/sampler/motpe.rs @@ -214,8 +214,130 @@ impl Default for MotpeSampler { } } +impl MotpeSampler { + fn sample_float( + &self, + d: &crate::distribution::FloatDistribution, + good_trials: &[&MultiObjectiveTrial], + bad_trials: &[&MultiObjectiveTrial], + rng: &mut fastrand::Rng, + ) -> ParamValue { + let good_values: Vec<f64> = good_trials + .iter() + .flat_map(|t| t.params.values()) + .filter_map(|v| match v { + ParamValue::Float(f) => Some(*f), + _ => None, + }) + .filter(|&v| v >= d.low && v <= d.high) + .collect(); + + let bad_values: Vec<f64> = bad_trials + .iter() + .flat_map(|t| t.params.values()) + .filter_map(|v| match v { + ParamValue::Float(f) => Some(*f), + _ => None, + }) + .filter(|&v| v >= d.low && v <= d.high) + .collect(); + + if good_values.is_empty() || bad_values.is_empty() { + return ParamValue::Float(rng_util::f64_range(rng, d.low, d.high)); + } + + let value = tpe_common::sample_tpe_float( + d, + good_values, + bad_values, + self.n_ei_candidates, + self.kde_bandwidth, + rng, + ); + ParamValue::Float(value) + } + + fn sample_int( + &self, + d: &crate::distribution::IntDistribution, + good_trials: &[&MultiObjectiveTrial], + bad_trials: &[&MultiObjectiveTrial], + rng: &mut fastrand::Rng, + ) -> ParamValue { + let good_values: Vec<i64> = good_trials + .iter() + .flat_map(|t| t.params.values()) + .filter_map(|v| match v { + ParamValue::Int(i) => Some(*i), + _ => None, + }) + .filter(|&v| v >= d.low && v <= d.high) + .collect(); + + let bad_values: Vec<i64> = bad_trials + .iter() + .flat_map(|t| t.params.values()) + .filter_map(|v| match v { + ParamValue::Int(i) => Some(*i), + _ => None, + }) + .filter(|&v| v >= d.low && v <= d.high) + .collect(); + + if good_values.is_empty() || bad_values.is_empty() { + return common::sample_random(rng, 
&Distribution::Int(d.clone())); + } + + let value = tpe_common::sample_tpe_int( + d, + good_values, + bad_values, + self.n_ei_candidates, + self.kde_bandwidth, + rng, + ); + ParamValue::Int(value) + } + + #[allow(clippy::unused_self)] + fn sample_categorical( + &self, + d: &crate::distribution::CategoricalDistribution, + good_trials: &[&MultiObjectiveTrial], + bad_trials: &[&MultiObjectiveTrial], + rng: &mut fastrand::Rng, + ) -> ParamValue { + let good_indices: Vec<usize> = good_trials + .iter() + .flat_map(|t| t.params.values()) + .filter_map(|v| match v { + ParamValue::Categorical(i) => Some(*i), + _ => None, + }) + .filter(|&i| i < d.n_choices) + .collect(); + + let bad_indices: Vec<usize> = bad_trials + .iter() + .flat_map(|t| t.params.values()) + .filter_map(|v| match v { + ParamValue::Categorical(i) => Some(*i), + _ => None, + }) + .filter(|&i| i < d.n_choices) + .collect(); + + if good_indices.is_empty() || bad_indices.is_empty() { + return common::sample_random(rng, &Distribution::Categorical(d.clone())); + } + + let index = + tpe_common::sample_tpe_categorical(d.n_choices, &good_indices, &bad_indices, rng); + ParamValue::Categorical(index) + } +} + impl MultiObjectiveSampler for MotpeSampler { - #[allow(clippy::too_many_lines)] fn sample( &self, distribution: &Distribution, @@ -247,114 +369,10 @@ impl MultiObjectiveSampler for MotpeSampler { } match distribution { - Distribution::Float(d) => { - let good_values: Vec<f64> = good_trials - .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Float(f) => Some(*f), - _ => None, - }) - .filter(|&v| v >= d.low && v <= d.high) - .collect(); - - let bad_values: Vec<f64> = bad_trials - .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Float(f) => Some(*f), - _ => None, - }) - .filter(|&v| v >= d.low && v <= d.high) - .collect(); - - if good_values.is_empty() || bad_values.is_empty() { - return common::sample_random(&mut rng, distribution); - } - - 
let value = tpe_common::sample_tpe_float( - d.low, - d.high, - d.log_scale, - d.step, - good_values, - bad_values, - self.n_ei_candidates, - self.kde_bandwidth, - &mut rng, - ); - ParamValue::Float(value) - } - Distribution::Int(d) => { - let good_values: Vec<i64> = good_trials - .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Int(i) => Some(*i), - _ => None, - }) - .filter(|&v| v >= d.low && v <= d.high) - .collect(); - - let bad_values: Vec<i64> = bad_trials - .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Int(i) => Some(*i), - _ => None, - }) - .filter(|&v| v >= d.low && v <= d.high) - .collect(); - - if good_values.is_empty() || bad_values.is_empty() { - return common::sample_random(&mut rng, distribution); - } - - let value = tpe_common::sample_tpe_int( - d.low, - d.high, - d.log_scale, - d.step, - good_values, - bad_values, - self.n_ei_candidates, - self.kde_bandwidth, - &mut rng, - ); - ParamValue::Int(value) - } + Distribution::Float(d) => self.sample_float(d, &good_trials, &bad_trials, &mut rng), + Distribution::Int(d) => self.sample_int(d, &good_trials, &bad_trials, &mut rng), Distribution::Categorical(d) => { - let good_indices: Vec<usize> = good_trials - .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Categorical(i) => Some(*i), - _ => None, - }) - .filter(|&i| i < d.n_choices) - .collect(); - - let bad_indices: Vec<usize> = bad_trials - .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Categorical(i) => Some(*i), - _ => None, - }) - .filter(|&i| i < d.n_choices) - .collect(); - - if good_indices.is_empty() || bad_indices.is_empty() { - return common::sample_random(&mut rng, distribution); - } - - let index = tpe_common::sample_tpe_categorical( - d.n_choices, - &good_indices, - &bad_indices, - &mut rng, - ); - ParamValue::Categorical(index) + self.sample_categorical(d, &good_trials, &bad_trials, &mut rng) } } 
} diff --git a/src/sampler/tpe/common.rs b/src/sampler/tpe/common.rs index c566ae9..02b41b8 100644 --- a/src/sampler/tpe/common.rs +++ b/src/sampler/tpe/common.rs @@ -1,21 +1,20 @@ //! Shared TPE sampling functions used by both `TpeSampler` and `MotpeSampler`. +use crate::distribution::{FloatDistribution, IntDistribution}; use crate::kde::KernelDensityEstimator; use crate::rng_util; /// Samples using TPE for float distributions. -#[allow(clippy::too_many_arguments)] pub(crate) fn sample_tpe_float( - low: f64, - high: f64, - log_scale: bool, - step: Option<f64>, + dist: &FloatDistribution, good_values: Vec<f64>, bad_values: Vec<f64>, n_ei_candidates: usize, kde_bandwidth: Option<f64>, rng: &mut fastrand::Rng, ) -> f64 { + let (low, high, log_scale, step) = (dist.low, dist.high, dist.log_scale, dist.step); + // Transform to internal space (log space if needed) let (internal_low, internal_high, good_internal, bad_internal) = if log_scale { let i_low = low.ln(); @@ -99,32 +98,31 @@ pub(crate) fn sample_tpe_float( } /// Samples using TPE for integer distributions. 
-#[allow( - clippy::too_many_arguments, - clippy::cast_precision_loss, - clippy::cast_possible_truncation -)] +#[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)] pub(crate) fn sample_tpe_int( - low: i64, - high: i64, - log_scale: bool, - step: Option<i64>, + dist: &IntDistribution, good_values: Vec<i64>, bad_values: Vec<i64>, n_ei_candidates: usize, kde_bandwidth: Option<f64>, rng: &mut fastrand::Rng, ) -> i64 { + let (low, high, log_scale, step) = (dist.low, dist.high, dist.log_scale, dist.step); + // Convert to floats for KDE let good_floats: Vec<f64> = good_values.into_iter().map(|v| v as f64).collect(); let bad_floats: Vec<f64> = bad_values.into_iter().map(|v| v as f64).collect(); + let float_dist = FloatDistribution { + low: low as f64, + high: high as f64, + log_scale, + step: step.map(|s| s as f64), + }; + // Use float TPE sampling let float_value = sample_tpe_float( - low as f64, - high as f64, - log_scale, - step.map(|s| s as f64), + &float_dist, good_floats, bad_floats, n_ei_candidates, diff --git a/src/sampler/tpe/multivariate/engine.rs b/src/sampler/tpe/multivariate/engine.rs index 1002b29..52f5f53 100644 --- a/src/sampler/tpe/multivariate/engine.rs +++ b/src/sampler/tpe/multivariate/engine.rs @@ -103,6 +103,30 @@ impl MultivariateTpeSampler { result } + /// Validates observations and fits multivariate KDEs for good and bad groups. + /// + /// Returns `None` if observations are invalid or KDE construction fails. 
+ fn try_fit_kdes( + good_obs: Vec<Vec<f64>>, + bad_obs: Vec<Vec<f64>>, + expected_dims: usize, + ) -> Option<(crate::kde::MultivariateKDE, crate::kde::MultivariateKDE)> { + use crate::kde::MultivariateKDE; + + let valid = !good_obs.is_empty() + && !bad_obs.is_empty() + && good_obs.iter().all(|obs| obs.len() == expected_dims) + && bad_obs.iter().all(|obs| obs.len() == expected_dims); + + if !valid { + return None; + } + + let good_kde = MultivariateKDE::new(good_obs).ok()?; + let bad_kde = MultivariateKDE::new(bad_obs).ok()?; + Some((good_kde, bad_kde)) + } + /// Samples parameters as a single group using multivariate TPE. /// /// This is the core multivariate TPE sampling logic, used both in non-grouped mode @@ -117,14 +141,12 @@ impl MultivariateTpeSampler { /// # Returns /// /// A `HashMap` mapping parameter names to their sampled values. - #[allow(clippy::too_many_lines)] pub(crate) fn sample_single_group( &self, search_space: &HashMap<ParamId, Distribution>, history: &[CompletedTrial], rng: &mut fastrand::Rng, ) -> HashMap<ParamId, ParamValue> { - use crate::kde::MultivariateKDE; use crate::sampler::tpe::IntersectionSearchSpace; use crate::sampler::tpe::common; @@ -165,7 +187,6 @@ impl MultivariateTpeSampler { .collect(); if param_order.is_empty() { - // Only categorical parameters in intersection - fill remaining with independent TPE self.fill_remaining_independent_with_rng( search_space, &intersection, @@ -178,43 +199,12 @@ impl MultivariateTpeSampler { param_order.sort_by_key(|id| format!("{id}")); - // Extract observations and validate + // Extract observations, validate, and fit KDEs let good_obs = self.extract_observations(&good, ¶m_order); let bad_obs = self.extract_observations(&bad, ¶m_order); - let expected_dims = param_order.len(); - - let valid = !good_obs.is_empty() - && !bad_obs.is_empty() - && good_obs.iter().all(|obs| obs.len() == expected_dims) - && bad_obs.iter().all(|obs| obs.len() == expected_dims); - - if !valid { - // Observations invalid 
- fill remaining with independent TPE - self.fill_remaining_independent_with_rng( - search_space, - &intersection, - history, - &mut result, - rng, - ); - return result; - } - - // Fit KDEs using let...else pattern - let Ok(good_kde) = MultivariateKDE::new(good_obs) else { - // KDE construction failed - fill remaining with independent TPE - self.fill_remaining_independent_with_rng( - search_space, - &intersection, - history, - &mut result, - rng, - ); - return result; - }; - let Ok(bad_kde) = MultivariateKDE::new(bad_obs) else { - // KDE construction failed - fill remaining with independent TPE + let Some((good_kde, bad_kde)) = Self::try_fit_kdes(good_obs, bad_obs, param_order.len()) + else { self.fill_remaining_independent_with_rng( search_space, &intersection, @@ -289,7 +279,7 @@ impl MultivariateTpeSampler { /// from the "good" KDE (l(x)) and selects the one that maximizes the ratio l(x)/g(x), /// which is equivalent to maximizing `log(l(x)) - log(g(x))`. #[must_use] - #[allow(dead_code)] // Used by tests + #[cfg(test)] pub(crate) fn select_candidate( &self, good_kde: &crate::kde::MultivariateKDE, @@ -366,43 +356,6 @@ impl MultivariateTpeSampler { candidates.into_iter().nth(best_idx).unwrap_or_default() } - /// Fills remaining parameters in result using independent TPE sampling. - /// - /// This method is used to sample parameters that are not in the intersection - /// search space. It uses independent univariate TPE sampling for each parameter, - /// similar to the standard [`TpeSampler`]. - /// - /// When there isn't enough history for a parameter, falls back to uniform sampling. 
- #[allow(dead_code)] - pub(crate) fn fill_remaining_independent( - &self, - search_space: &HashMap<ParamId, Distribution>, - _intersection: &HashMap<ParamId, Distribution>, - history: &[CompletedTrial], - result: &mut HashMap<ParamId, ParamValue>, - ) { - // Identify parameters not in result (and not in intersection) - let missing_params: Vec<(&ParamId, &Distribution)> = search_space - .iter() - .filter(|(id, _)| !result.contains_key(id)) - .collect(); - - if missing_params.is_empty() { - return; - } - - // Split trials for independent sampling - let (good_trials, bad_trials) = self.split_trials(&history.iter().collect::<Vec<_>>()); - - let mut rng = self.rng.lock(); - - for (param_id, dist) in missing_params { - let value = - self.sample_independent_tpe(*param_id, dist, &good_trials, &bad_trials, &mut rng); - result.insert(*param_id, value); - } - } - /// Fills remaining parameters using independent TPE sampling with an external RNG. /// /// This variant accepts an external RNG, used when the caller already holds the lock. @@ -437,7 +390,7 @@ impl MultivariateTpeSampler { /// Samples all parameters using independent TPE sampling. /// /// This is used as a complete fallback when no intersection search space exists. - #[allow(dead_code)] // Used by tests + #[cfg(test)] pub(crate) fn sample_all_independent( &self, search_space: &HashMap<ParamId, Distribution>, @@ -485,7 +438,6 @@ impl MultivariateTpeSampler { /// /// This method extracts values for the given parameter from good and bad trials, /// fits univariate KDEs, and samples using the TPE acquisition function. 
- #[allow(clippy::too_many_lines)] pub(crate) fn sample_independent_tpe( &self, param_id: ParamId, @@ -494,114 +446,136 @@ impl MultivariateTpeSampler { bad_trials: &[&CompletedTrial], rng: &mut fastrand::Rng, ) -> ParamValue { - use crate::sampler::tpe::common; - match distribution { Distribution::Float(d) => { - let good_values: Vec<f64> = good_trials - .iter() - .filter_map(|t| t.params.get(¶m_id)) - .filter_map(|v| match v { - ParamValue::Float(f) => Some(*f), - _ => None, - }) - .filter(|&v| v >= d.low && v <= d.high) - .collect(); - - let bad_values: Vec<f64> = bad_trials - .iter() - .filter_map(|t| t.params.get(¶m_id)) - .filter_map(|v| match v { - ParamValue::Float(f) => Some(*f), - _ => None, - }) - .filter(|&v| v >= d.low && v <= d.high) - .collect(); - - if good_values.is_empty() || bad_values.is_empty() { - return crate::sampler::common::sample_random(rng, distribution); - } - - let value = common::sample_tpe_float( - d.low, - d.high, - d.log_scale, - d.step, - good_values, - bad_values, - self.n_ei_candidates, - None, - rng, - ); - ParamValue::Float(value) + self.sample_independent_float(param_id, d, good_trials, bad_trials, rng) } Distribution::Int(d) => { - let good_values: Vec<i64> = good_trials - .iter() - .filter_map(|t| t.params.get(¶m_id)) - .filter_map(|v| match v { - ParamValue::Int(i) => Some(*i), - _ => None, - }) - .filter(|&v| v >= d.low && v <= d.high) - .collect(); - - let bad_values: Vec<i64> = bad_trials - .iter() - .filter_map(|t| t.params.get(¶m_id)) - .filter_map(|v| match v { - ParamValue::Int(i) => Some(*i), - _ => None, - }) - .filter(|&v| v >= d.low && v <= d.high) - .collect(); - - if good_values.is_empty() || bad_values.is_empty() { - return crate::sampler::common::sample_random(rng, distribution); - } - - let value = common::sample_tpe_int( - d.low, - d.high, - d.log_scale, - d.step, - good_values, - bad_values, - self.n_ei_candidates, - None, - rng, - ); - ParamValue::Int(value) + self.sample_independent_int(param_id, d, 
good_trials, bad_trials, rng) } Distribution::Categorical(d) => { - let good_indices: Vec<usize> = good_trials - .iter() - .filter_map(|t| t.params.get(¶m_id)) - .filter_map(|v| match v { - ParamValue::Categorical(i) => Some(*i), - _ => None, - }) - .filter(|&i| i < d.n_choices) - .collect(); - - let bad_indices: Vec<usize> = bad_trials - .iter() - .filter_map(|t| t.params.get(¶m_id)) - .filter_map(|v| match v { - ParamValue::Categorical(i) => Some(*i), - _ => None, - }) - .filter(|&i| i < d.n_choices) - .collect(); - - if good_indices.is_empty() || bad_indices.is_empty() { - return crate::sampler::common::sample_random(rng, distribution); - } - - let idx = - common::sample_tpe_categorical(d.n_choices, &good_indices, &bad_indices, rng); - ParamValue::Categorical(idx) + self.sample_independent_categorical(param_id, d, good_trials, bad_trials, rng) } } } + + fn sample_independent_float( + &self, + param_id: ParamId, + d: &crate::distribution::FloatDistribution, + good_trials: &[&CompletedTrial], + bad_trials: &[&CompletedTrial], + rng: &mut fastrand::Rng, + ) -> ParamValue { + use crate::sampler::tpe::common; + + let good_values: Vec<f64> = good_trials + .iter() + .filter_map(|t| t.params.get(¶m_id)) + .filter_map(|v| match v { + ParamValue::Float(f) => Some(*f), + _ => None, + }) + .filter(|&v| v >= d.low && v <= d.high) + .collect(); + + let bad_values: Vec<f64> = bad_trials + .iter() + .filter_map(|t| t.params.get(¶m_id)) + .filter_map(|v| match v { + ParamValue::Float(f) => Some(*f), + _ => None, + }) + .filter(|&v| v >= d.low && v <= d.high) + .collect(); + + if good_values.is_empty() || bad_values.is_empty() { + return crate::sampler::common::sample_random(rng, &Distribution::Float(d.clone())); + } + + let value = + common::sample_tpe_float(d, good_values, bad_values, self.n_ei_candidates, None, rng); + ParamValue::Float(value) + } + + fn sample_independent_int( + &self, + param_id: ParamId, + d: &crate::distribution::IntDistribution, + good_trials: 
&[&CompletedTrial], + bad_trials: &[&CompletedTrial], + rng: &mut fastrand::Rng, + ) -> ParamValue { + use crate::sampler::tpe::common; + + let good_values: Vec<i64> = good_trials + .iter() + .filter_map(|t| t.params.get(¶m_id)) + .filter_map(|v| match v { + ParamValue::Int(i) => Some(*i), + _ => None, + }) + .filter(|&v| v >= d.low && v <= d.high) + .collect(); + + let bad_values: Vec<i64> = bad_trials + .iter() + .filter_map(|t| t.params.get(¶m_id)) + .filter_map(|v| match v { + ParamValue::Int(i) => Some(*i), + _ => None, + }) + .filter(|&v| v >= d.low && v <= d.high) + .collect(); + + if good_values.is_empty() || bad_values.is_empty() { + return crate::sampler::common::sample_random(rng, &Distribution::Int(d.clone())); + } + + let value = + common::sample_tpe_int(d, good_values, bad_values, self.n_ei_candidates, None, rng); + ParamValue::Int(value) + } + + #[allow(clippy::unused_self)] + fn sample_independent_categorical( + &self, + param_id: ParamId, + d: &crate::distribution::CategoricalDistribution, + good_trials: &[&CompletedTrial], + bad_trials: &[&CompletedTrial], + rng: &mut fastrand::Rng, + ) -> ParamValue { + use crate::sampler::tpe::common; + + let good_indices: Vec<usize> = good_trials + .iter() + .filter_map(|t| t.params.get(¶m_id)) + .filter_map(|v| match v { + ParamValue::Categorical(i) => Some(*i), + _ => None, + }) + .filter(|&i| i < d.n_choices) + .collect(); + + let bad_indices: Vec<usize> = bad_trials + .iter() + .filter_map(|t| t.params.get(¶m_id)) + .filter_map(|v| match v { + ParamValue::Categorical(i) => Some(*i), + _ => None, + }) + .filter(|&i| i < d.n_choices) + .collect(); + + if good_indices.is_empty() || bad_indices.is_empty() { + return crate::sampler::common::sample_random( + rng, + &Distribution::Categorical(d.clone()), + ); + } + + let idx = common::sample_tpe_categorical(d.n_choices, &good_indices, &bad_indices, rng); + ParamValue::Categorical(idx) + } } diff --git a/src/sampler/tpe/multivariate/mod.rs 
b/src/sampler/tpe/multivariate/mod.rs index 4900e13..f3738f9 100644 --- a/src/sampler/tpe/multivariate/mod.rs +++ b/src/sampler/tpe/multivariate/mod.rs @@ -389,7 +389,6 @@ impl MultivariateTpeSampler { /// let params = sampler.sample_joint(&search_space, &history); /// ``` #[must_use] - #[allow(clippy::too_many_lines)] pub fn sample_joint( &self, search_space: &HashMap<ParamId, Distribution>, diff --git a/src/sampler/tpe/sampler.rs b/src/sampler/tpe/sampler.rs index 5c9af1d..32f3f1a 100644 --- a/src/sampler/tpe/sampler.rs +++ b/src/sampler/tpe/sampler.rs @@ -642,8 +642,130 @@ impl Default for TpeSamplerBuilder { } } +impl TpeSampler { + fn sample_float( + &self, + d: &crate::distribution::FloatDistribution, + good_trials: &[&CompletedTrial], + bad_trials: &[&CompletedTrial], + rng: &mut fastrand::Rng, + ) -> ParamValue { + let good_values: Vec<f64> = good_trials + .iter() + .flat_map(|t| t.params.values()) + .filter_map(|v| match v { + ParamValue::Float(f) => Some(*f), + _ => None, + }) + .filter(|&v| v >= d.low && v <= d.high) + .collect(); + + let bad_values: Vec<f64> = bad_trials + .iter() + .flat_map(|t| t.params.values()) + .filter_map(|v| match v { + ParamValue::Float(f) => Some(*f), + _ => None, + }) + .filter(|&v| v >= d.low && v <= d.high) + .collect(); + + if good_values.is_empty() || bad_values.is_empty() { + return ParamValue::Float(rng_util::f64_range(rng, d.low, d.high)); + } + + let value = tpe_common::sample_tpe_float( + d, + good_values, + bad_values, + self.n_ei_candidates, + self.kde_bandwidth, + rng, + ); + ParamValue::Float(value) + } + + fn sample_int( + &self, + d: &crate::distribution::IntDistribution, + good_trials: &[&CompletedTrial], + bad_trials: &[&CompletedTrial], + rng: &mut fastrand::Rng, + ) -> ParamValue { + let good_values: Vec<i64> = good_trials + .iter() + .flat_map(|t| t.params.values()) + .filter_map(|v| match v { + ParamValue::Int(i) => Some(*i), + _ => None, + }) + .filter(|&v| v >= d.low && v <= d.high) + .collect(); + + 
let bad_values: Vec<i64> = bad_trials + .iter() + .flat_map(|t| t.params.values()) + .filter_map(|v| match v { + ParamValue::Int(i) => Some(*i), + _ => None, + }) + .filter(|&v| v >= d.low && v <= d.high) + .collect(); + + if good_values.is_empty() || bad_values.is_empty() { + return common::sample_random(rng, &Distribution::Int(d.clone())); + } + + let value = tpe_common::sample_tpe_int( + d, + good_values, + bad_values, + self.n_ei_candidates, + self.kde_bandwidth, + rng, + ); + ParamValue::Int(value) + } + + #[allow(clippy::unused_self)] + fn sample_categorical( + &self, + d: &crate::distribution::CategoricalDistribution, + good_trials: &[&CompletedTrial], + bad_trials: &[&CompletedTrial], + rng: &mut fastrand::Rng, + ) -> ParamValue { + let good_indices: Vec<usize> = good_trials + .iter() + .flat_map(|t| t.params.values()) + .filter_map(|v| match v { + ParamValue::Categorical(i) => Some(*i), + _ => None, + }) + .filter(|&i| i < d.n_choices) + .collect(); + + let bad_indices: Vec<usize> = bad_trials + .iter() + .flat_map(|t| t.params.values()) + .filter_map(|v| match v { + ParamValue::Categorical(i) => Some(*i), + _ => None, + }) + .filter(|&i| i < d.n_choices) + .collect(); + + if good_indices.is_empty() || bad_indices.is_empty() { + return common::sample_random(rng, &Distribution::Categorical(d.clone())); + } + + let index = + tpe_common::sample_tpe_categorical(d.n_choices, &good_indices, &bad_indices, rng); + ParamValue::Categorical(index) + } +} + impl Sampler for TpeSampler { - #[allow(clippy::too_many_lines)] fn sample( &self, distribution: &Distribution, @@ -670,123 +792,11 @@ impl Sampler for TpeSampler { return common::sample_random(&mut rng, distribution); } - // Extract parameter values for this distribution - // Since we don't have the parameter name here, we need to look at all - // trials and find matching distributions - // Note: This is a simplification - in practice, we'd need the param name - // For now, we'll collect values from trials that 
have this exact distribution type - match distribution { - Distribution::Float(d) => { - // Collect float values from trials - let good_values: Vec<f64> = good_trials - .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Float(f) => Some(*f), - _ => None, - }) - .filter(|&v| v >= d.low && v <= d.high) - .collect(); - - let bad_values: Vec<f64> = bad_trials - .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Float(f) => Some(*f), - _ => None, - }) - .filter(|&v| v >= d.low && v <= d.high) - .collect(); - - // Need values in both groups for TPE - if good_values.is_empty() || bad_values.is_empty() { - return common::sample_random(&mut rng, distribution); - } - - let value = tpe_common::sample_tpe_float( - d.low, - d.high, - d.log_scale, - d.step, - good_values, - bad_values, - self.n_ei_candidates, - self.kde_bandwidth, - &mut rng, - ); - ParamValue::Float(value) - } - Distribution::Int(d) => { - let good_values: Vec<i64> = good_trials - .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Int(i) => Some(*i), - _ => None, - }) - .filter(|&v| v >= d.low && v <= d.high) - .collect(); - - let bad_values: Vec<i64> = bad_trials - .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Int(i) => Some(*i), - _ => None, - }) - .filter(|&v| v >= d.low && v <= d.high) - .collect(); - - if good_values.is_empty() || bad_values.is_empty() { - return common::sample_random(&mut rng, distribution); - } - - let value = tpe_common::sample_tpe_int( - d.low, - d.high, - d.log_scale, - d.step, - good_values, - bad_values, - self.n_ei_candidates, - self.kde_bandwidth, - &mut rng, - ); - ParamValue::Int(value) - } + Distribution::Float(d) => self.sample_float(d, &good_trials, &bad_trials, &mut rng), + Distribution::Int(d) => self.sample_int(d, &good_trials, &bad_trials, &mut rng), Distribution::Categorical(d) => { - let good_indices: Vec<usize> = good_trials - 
.iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Categorical(i) => Some(*i), - _ => None, - }) - .filter(|&i| i < d.n_choices) - .collect(); - - let bad_indices: Vec<usize> = bad_trials - .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Categorical(i) => Some(*i), - _ => None, - }) - .filter(|&i| i < d.n_choices) - .collect(); - - if good_indices.is_empty() || bad_indices.is_empty() { - return common::sample_random(&mut rng, distribution); - } - - let index = tpe_common::sample_tpe_categorical( - d.n_choices, - &good_indices, - &bad_indices, - &mut rng, - ); - ParamValue::Categorical(index) + self.sample_categorical(d, &good_trials, &bad_trials, &mut rng) } } } diff --git a/src/study/export.rs b/src/study/export.rs index 0db9ef9..beff14b 100644 --- a/src/study/export.rs +++ b/src/study/export.rs @@ -224,7 +224,7 @@ where impl<V: PartialOrd + Clone + serde::Serialize> Study<V> { /// Export trials as a pretty-printed JSON array to a file. /// - /// Each element in the array is a serialized [`crate::CompletedTrial`]. + /// Each element in the array is a serialized [`CompletedTrial`](crate::sampler::CompletedTrial). /// Requires the `serde` feature. /// /// # Errors diff --git a/src/visualization.rs b/src/visualization.rs index 3c9558d..00eac89 100644 --- a/src/visualization.rs +++ b/src/visualization.rs @@ -41,7 +41,6 @@ use core::fmt::Write as _; use std::collections::BTreeMap; use std::path::Path; -use crate::distribution::Distribution; use crate::param::ParamValue; use crate::parameter::ParamId; use crate::sampler::CompletedTrial; @@ -156,8 +155,6 @@ fn build_html( /// Metadata about each parameter seen across trials. struct ParamMeta { label: String, - #[allow(dead_code)] - dist: Option<Distribution>, } /// Collect parameter labels and distributions across all trials. 
@@ -171,8 +168,7 @@ fn collect_param_info(trials: &[CompletedTrial<f64>]) -> BTreeMap<ParamId, Param .get(&id) .cloned() .unwrap_or_else(|| id.to_string()); - let dist = trial.distributions.get(&id).cloned(); - ParamMeta { label, dist } + ParamMeta { label } }); } } From 3d3dcb4c2689e892e7aaf93df0b5db0b3cad6c3b Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 15:56:16 +0100 Subject: [PATCH 34/48] feat(error): mark Error enum as #[non_exhaustive] --- src/error.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/error.rs b/src/error.rs index 70b2d98..d34db2f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,6 +11,7 @@ /// management. The [`TrialPruned`](Error::TrialPruned) variant has special /// significance — it signals early stopping and is typically raised via /// the [`TrialPruned`](super::TrialPruned) convenience type. +#[non_exhaustive] #[derive(Debug, thiserror::Error)] pub enum Error { /// The lower bound exceeds the upper bound in a From 12efbaabab9f6371a920147be7f5513679003d56 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 16:00:05 +0100 Subject: [PATCH 35/48] fix(sampler): return None instead of panicking on parameter type mismatch in CompletedTrial::get() --- src/sampler/mod.rs | 18 ++++++------------ tests/study/workflow.rs | 22 ++++++++++++++++++++++ 2 files changed, 28 insertions(+), 12 deletions(-) diff --git a/src/sampler/mod.rs b/src/sampler/mod.rs index a660d11..ff83bc6 100644 --- a/src/sampler/mod.rs +++ b/src/sampler/mod.rs @@ -121,13 +121,9 @@ impl<V> CompletedTrial<V> { /// Looks up the parameter by its unique id and casts the stored /// [`ParamValue`] to the parameter's typed value. /// - /// Returns `None` if the parameter was not used in this trial. - /// - /// # Panics - /// - /// Panics if the stored value is incompatible with the parameter type - /// (e.g., a `Float` value stored for an `IntParam`). 
This indicates - /// a bug in the program, not a runtime error. + /// Returns `None` if the parameter was not used in this trial or if + /// the stored value is incompatible with the parameter type (e.g., a + /// `Float` value stored for an `IntParam`). /// /// # Examples /// @@ -150,11 +146,9 @@ impl<V> CompletedTrial<V> { /// assert!((-10.0..=10.0).contains(&x_val)); /// ``` pub fn get<P: Parameter>(&self, param: &P) -> Option<P::Value> { - self.params.get(¶m.id()).map(|v| { - param - .cast_param_value(v) - .expect("parameter type mismatch: stored value incompatible with parameter") - }) + self.params + .get(¶m.id()) + .and_then(|v| param.cast_param_value(v).ok()) } /// Returns `true` if all constraints are satisfied (values <= 0.0). diff --git a/tests/study/workflow.rs b/tests/study/workflow.rs index a14bf11..98f1154 100644 --- a/tests/study/workflow.rs +++ b/tests/study/workflow.rs @@ -258,6 +258,28 @@ fn test_completed_trial_get() { assert!((1..=10).contains(&n_val)); } +#[test] +fn test_completed_trial_get_type_mismatch_returns_none() { + let study: Study<f64> = Study::new(Direction::Minimize); + let int_param = IntParam::new(1, 10).name("x"); + + study + .optimize(1, |trial: &mut optimizer::Trial| { + let n = int_param.suggest(trial)?; + Ok::<_, Error>(n as f64) + }) + .unwrap(); + + let best = study.best_trial().unwrap(); + + // The stored value is ParamValue::Int, but we query with a FloatParam. 
+ let wrong_type = FloatParam::new(0.0, 100.0).name("x"); + assert!( + best.get(&wrong_type).is_none(), + "type mismatch should return None, not panic" + ); +} + #[test] fn test_single_value_int_range() { let param = IntParam::new(5, 5); From 0e0e9ff38f9140fb1f91a4313ca0dc19e6601d63 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 16:02:46 +0100 Subject: [PATCH 36/48] test(sobol): add integration tests for SobolSampler via Study API --- tests/sampler/main.rs | 2 + tests/sampler/sobol.rs | 314 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 316 insertions(+) create mode 100644 tests/sampler/sobol.rs diff --git a/tests/sampler/main.rs b/tests/sampler/main.rs index 9f0f825..82f2749 100644 --- a/tests/sampler/main.rs +++ b/tests/sampler/main.rs @@ -12,4 +12,6 @@ mod differential_evolution; mod gp; mod multivariate_tpe; mod random; +#[cfg(feature = "sobol")] +mod sobol; mod tpe; diff --git a/tests/sampler/sobol.rs b/tests/sampler/sobol.rs new file mode 100644 index 0000000..c41b742 --- /dev/null +++ b/tests/sampler/sobol.rs @@ -0,0 +1,314 @@ +use optimizer::prelude::*; +use optimizer::sampler::random::RandomSampler; +use optimizer::sampler::sobol::SobolSampler; + +#[test] +fn sphere_function() { + let sampler = SobolSampler::with_seed(42); + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + + let x = FloatParam::new(-5.0, 5.0).name("x"); + let y = FloatParam::new(-5.0, 5.0).name("y"); + + study + .optimize(100, |trial: &mut optimizer::Trial| { + let xv = x.suggest(trial)?; + let yv = y.suggest(trial)?; + Ok::<_, Error>(xv * xv + yv * yv) + }) + .unwrap(); + + let best = study.best_trial().unwrap(); + assert!( + best.value < 100.0, + "sphere best value should be < 100.0, got {}", + best.value + ); +} + +#[test] +fn bounds_respected() { + let sampler = SobolSampler::with_seed(123); + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + + let x = FloatParam::new(-2.0, 
3.0).name("x"); + let y = FloatParam::new(0.0, 10.0).name("y"); + + study + .optimize(100, |trial: &mut optimizer::Trial| { + let xv = x.suggest(trial)?; + let yv = y.suggest(trial)?; + Ok::<_, Error>(xv + yv) + }) + .unwrap(); + + for trial in study.trials() { + let xv: f64 = trial.get(&x).unwrap(); + let yv: f64 = trial.get(&y).unwrap(); + assert!((-2.0..=3.0).contains(&xv), "x = {xv} out of bounds [-2, 3]"); + assert!((0.0..=10.0).contains(&yv), "y = {yv} out of bounds [0, 10]"); + } +} + +#[test] +fn integer_params() { + let sampler = SobolSampler::with_seed(42); + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + + let n = IntParam::new(1, 20).name("n"); + + study + .optimize(100, |trial: &mut optimizer::Trial| { + let nv = n.suggest(trial)?; + Ok::<_, Error>(((nv - 10) * (nv - 10)) as f64) + }) + .unwrap(); + + let best = study.best_trial().unwrap(); + let best_n: i64 = best.get(&n).unwrap(); + assert!( + (1..=20).contains(&best_n), + "integer value {best_n} out of bounds" + ); + assert!( + best.value < 10.0, + "integer optimization should find a good value, got {}", + best.value + ); +} + +#[test] +fn log_scale_params() { + let sampler = SobolSampler::with_seed(42); + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + + let lr = FloatParam::new(1e-5, 1.0).log_scale().name("lr"); + + study + .optimize(100, |trial: &mut optimizer::Trial| { + let lrv = lr.suggest(trial)?; + Ok::<_, Error>((lrv.ln() - 0.01_f64.ln()).powi(2)) + }) + .unwrap(); + + for trial in study.trials() { + let lrv: f64 = trial.get(&lr).unwrap(); + assert!( + (1e-5..=1.0).contains(&lrv), + "log-scale value {lrv} out of bounds" + ); + } +} + +#[test] +fn categorical_params() { + let sampler = SobolSampler::with_seed(42); + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + + let cat = CategoricalParam::new(vec!["a", "b", "c"]).name("cat"); + + study + .optimize(50, |trial: &mut optimizer::Trial| { + let cv = 
cat.suggest(trial)?; + let val = match cv { + "a" => 0.0, + "b" => 1.0, + _ => 2.0, + }; + Ok::<_, Error>(val) + }) + .unwrap(); + + let best = study.best_trial().unwrap(); + assert!( + best.value < 2.0, + "categorical optimization should find a good value, got {}", + best.value + ); +} + +#[test] +fn mixed_params() { + let sampler = SobolSampler::with_seed(42); + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + + let x = FloatParam::new(-5.0, 5.0).name("x"); + let n = IntParam::new(1, 10).name("n"); + let cat = CategoricalParam::new(vec!["a", "b", "c"]).name("cat"); + + study + .optimize(100, |trial: &mut optimizer::Trial| { + let xv = x.suggest(trial)?; + let nv = n.suggest(trial)?; + let cv = cat.suggest(trial)?; + let penalty = match cv { + "a" => 0.0, + "b" => 1.0, + _ => 2.0, + }; + Ok::<_, Error>(xv * xv + nv as f64 + penalty) + }) + .unwrap(); + + let best = study.best_trial().unwrap(); + assert!( + best.value < 20.0, + "mixed-param optimization should find a reasonable value, got {}", + best.value + ); +} + +#[test] +fn single_dimension() { + let sampler = SobolSampler::with_seed(42); + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + + let x = FloatParam::new(-10.0, 10.0).name("x"); + + study + .optimize(100, |trial: &mut optimizer::Trial| { + let xv = x.suggest(trial)?; + Ok::<_, Error>((xv - 3.0).powi(2)) + }) + .unwrap(); + + let best = study.best_trial().unwrap(); + assert!( + best.value < 5.0, + "1-D optimization should find a decent value, got {}", + best.value + ); +} + +#[test] +fn many_dimensions() { + let sampler = SobolSampler::with_seed(42); + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + + let params: Vec<FloatParam> = (0..8) + .map(|i| FloatParam::new(-5.0, 5.0).name(format!("x{i}"))) + .collect(); + + study + .optimize(200, |trial: &mut optimizer::Trial| { + let mut sum = 0.0; + for p in &params { + let v = p.suggest(trial)?; + sum += v * v; + } + Ok::<_, 
Error>(sum) + }) + .unwrap(); + + let best = study.best_trial().unwrap(); + // With 8 dimensions and 200 quasi-random trials the best won't be amazing, + // but it should be noticeably below the worst case (8 * 25 = 200). + assert!( + best.value < 150.0, + "8-D optimization should find something reasonable, got {}", + best.value + ); +} + +#[test] +fn seeded_reproducibility() { + let x = FloatParam::new(-5.0, 5.0).name("x"); + let y = FloatParam::new(-5.0, 5.0).name("y"); + + let run = |seed: u64| { + let sampler = SobolSampler::with_seed(seed); + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + study + .optimize(50, |trial: &mut optimizer::Trial| { + let xv = x.suggest(trial)?; + let yv = y.suggest(trial)?; + Ok::<_, Error>(xv * xv + yv * yv) + }) + .unwrap(); + study.trials().iter().map(|t| t.value).collect::<Vec<_>>() + }; + + let results1 = run(42); + let results2 = run(42); + assert_eq!(results1, results2, "same seed should produce same results"); +} + +#[test] +fn different_seeds_different_results() { + let x = FloatParam::new(-5.0, 5.0).name("x"); + let y = FloatParam::new(-5.0, 5.0).name("y"); + + let run = |seed: u64| { + let sampler = SobolSampler::with_seed(seed); + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + study + .optimize(20, |trial: &mut optimizer::Trial| { + let xv = x.suggest(trial)?; + let yv = y.suggest(trial)?; + Ok::<_, Error>(xv * xv + yv * yv) + }) + .unwrap(); + study.trials().iter().map(|t| t.value).collect::<Vec<_>>() + }; + + let results1 = run(42); + let results2 = run(99); + assert_ne!( + results1, results2, + "different seeds should produce different results" + ); +} + +#[test] +fn better_coverage_than_random() { + let n_trials = 30; + let n_bins = 10; + + let x = FloatParam::new(0.0, 1.0).name("x"); + + // Count bins filled by Sobol. 
+ let sobol_study: Study<f64> = + Study::with_sampler(Direction::Minimize, SobolSampler::with_seed(0)); + sobol_study + .optimize(n_trials, |trial: &mut optimizer::Trial| { + let v = x.suggest(trial)?; + Ok::<_, Error>(v) + }) + .unwrap(); + + let mut sobol_bins = vec![0u32; n_bins]; + for trial in sobol_study.trials() { + let v: f64 = trial.get(&x).unwrap(); + let bin = ((v * n_bins as f64).floor() as usize).min(n_bins - 1); + sobol_bins[bin] += 1; + } + let sobol_filled = sobol_bins.iter().filter(|&&c| c > 0).count(); + + // Count bins filled by Random. + let random_study: Study<f64> = + Study::with_sampler(Direction::Minimize, RandomSampler::with_seed(0)); + random_study + .optimize(n_trials, |trial: &mut optimizer::Trial| { + let v = x.suggest(trial)?; + Ok::<_, Error>(v) + }) + .unwrap(); + + let mut random_bins = vec![0u32; n_bins]; + for trial in random_study.trials() { + let v: f64 = trial.get(&x).unwrap(); + let bin = ((v * n_bins as f64).floor() as usize).min(n_bins - 1); + random_bins[bin] += 1; + } + let random_filled = random_bins.iter().filter(|&&c| c > 0).count(); + + assert!( + sobol_filled >= random_filled, + "Sobol should fill at least as many bins as random: sobol={sobol_filled}, random={random_filled}" + ); + // Sobol with 30 samples in 10 bins should fill all or nearly all bins. 
+ assert!( + sobol_filled >= 9, + "Sobol should fill at least 9/10 bins, got {sobol_filled}: {sobol_bins:?}" + ); +} From 8394a747a5d74ff4c396ed38fa77fb8a0efb0360 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 16:07:31 +0100 Subject: [PATCH 37/48] test(journal): add corrupted and malicious JSONL file tests --- tests/journal_tests.rs | 165 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 165 insertions(+) diff --git a/tests/journal_tests.rs b/tests/journal_tests.rs index 4c25e0e..e40e68f 100644 --- a/tests/journal_tests.rs +++ b/tests/journal_tests.rs @@ -373,3 +373,168 @@ fn refresh_picks_up_external_writes() { std::fs::remove_file(&path).ok(); } + +// ── Corrupted / malicious journal file tests ──────────────────────── + +fn valid_trial_line_with_id(id: u64) -> String { + format!( + r#"{{"id":{id},"params":{{}},"distributions":{{}},"param_labels":{{}},"value":1.0,"intermediate_values":[],"state":"Complete","user_attrs":{{}},"constraints":[]}}"# + ) +} + +#[test] +fn empty_file_loads_as_empty_storage() { + let path = temp_path(); + std::fs::write(&path, "").unwrap(); + + let storage = JournalStorage::<f64>::open(&path).unwrap(); + assert_eq!(storage.trials_arc().read().len(), 0); + + std::fs::remove_file(&path).ok(); +} + +#[test] +fn whitespace_only_lines_are_skipped() { + let path = temp_path(); + std::fs::write(&path, " \n\t\n\n").unwrap(); + + let storage = JournalStorage::<f64>::open(&path).unwrap(); + assert_eq!(storage.trials_arc().read().len(), 0); + + std::fs::remove_file(&path).ok(); +} + +#[test] +fn truncated_json_line_returns_error() { + let path = temp_path(); + std::fs::write(&path, r#"{"id":0,"params":{"#).unwrap(); + + assert!(JournalStorage::<f64>::open(&path).is_err()); + + std::fs::remove_file(&path).ok(); +} + +#[test] +fn invalid_json_syntax_returns_error() { + let path = temp_path(); + std::fs::write(&path, "not valid json\n").unwrap(); + + 
assert!(JournalStorage::<f64>::open(&path).is_err()); + + std::fs::remove_file(&path).ok(); +} + +#[test] +fn missing_required_field_returns_error() { + let path = temp_path(); + // Missing params, distributions, param_labels, etc. + std::fs::write(&path, r#"{"id":0,"value":1.0}"#).unwrap(); + + assert!(JournalStorage::<f64>::open(&path).is_err()); + + std::fs::remove_file(&path).ok(); +} + +#[test] +fn extra_fields_are_ignored() { + let path = temp_path(); + let line = r#"{"id":0,"params":{},"distributions":{},"param_labels":{},"value":0.5,"intermediate_values":[],"state":"Complete","user_attrs":{},"constraints":[],"foo":"bar","extra_number":42}"#; + std::fs::write(&path, format!("{line}\n")).unwrap(); + + let storage = JournalStorage::<f64>::open(&path).unwrap(); + let loaded = storage.trials_arc().read().clone(); + assert_eq!(loaded.len(), 1); + assert_eq!(loaded[0].id, 0); + assert_eq!(loaded[0].value, 0.5); + + std::fs::remove_file(&path).ok(); +} + +#[test] +fn out_of_bounds_categorical_index_loads() { + let path = temp_path(); + // Categorical param with index 999, but distribution only has 3 choices. + // validate() does not check categorical bounds, so this should load. 
+ let line = r#"{"id":0,"params":{"0":{"Categorical":999}},"distributions":{"0":{"Categorical":{"n_choices":3}}},"param_labels":{},"value":1.0,"intermediate_values":[],"state":"Complete","user_attrs":{},"constraints":[]}"#; + std::fs::write(&path, format!("{line}\n")).unwrap(); + + let storage = JournalStorage::<f64>::open(&path).unwrap(); + let loaded = storage.trials_arc().read().clone(); + assert_eq!(loaded.len(), 1); + + std::fs::remove_file(&path).ok(); +} + +#[test] +fn valid_lines_before_corruption_are_not_loaded() { + let path = temp_path(); + let content = format!( + "{}\n{}\n{}\n", + valid_trial_line_with_id(0), + valid_trial_line_with_id(1), + "CORRUPTED LINE" + ); + std::fs::write(&path, content).unwrap(); + + // load_trials_from_file is all-or-nothing: the corrupted third line + // makes the entire open() fail. + assert!(JournalStorage::<f64>::open(&path).is_err()); + + std::fs::remove_file(&path).ok(); +} + +#[test] +fn refresh_rejects_corrupted_external_append() { + let path = temp_path(); + let storage = JournalStorage::new(&path); + + // Push 2 valid trials through the storage API. + storage.push(sample_trial(0, 1.0)); + storage.push(sample_trial(1, 2.0)); + assert_eq!(storage.trials_arc().read().len(), 2); + + // Simulate an external process appending corrupted JSON. + { + let mut file = std::fs::OpenOptions::new() + .append(true) + .open(&path) + .unwrap(); + writeln!(file, "CORRUPTED LINE").unwrap(); + file.sync_all().unwrap(); + } + + // refresh() should reject the corrupted data and return false. + assert!(!storage.refresh()); + // Memory still has only the original 2 trials. + assert_eq!(storage.trials_arc().read().len(), 2); + + std::fs::remove_file(&path).ok(); +} + +#[test] +fn refresh_rejects_truncated_external_append() { + let path = temp_path(); + let storage = JournalStorage::new(&path); + + // Push 2 valid trials through the storage API. 
+ storage.push(sample_trial(0, 1.0)); + storage.push(sample_trial(1, 2.0)); + assert_eq!(storage.trials_arc().read().len(), 2); + + // Simulate an external process appending truncated JSON. + { + let mut file = std::fs::OpenOptions::new() + .append(true) + .open(&path) + .unwrap(); + writeln!(file, r#"{{"id":2,"params":{{"#).unwrap(); + file.sync_all().unwrap(); + } + + // refresh() should reject the truncated data and return false. + assert!(!storage.refresh()); + // Memory still has only the original 2 trials. + assert_eq!(storage.trials_arc().read().len(), 2); + + std::fs::remove_file(&path).ok(); +} From 7c0cce160a4cc4c3fcf605fae76c91db55254c26 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 16:18:55 +0100 Subject: [PATCH 38/48] test(stress): add ignored stress tests for large-scale scenarios - 10k sequential trials with RandomSampler - 128 parameters with TPE sampler - 128 concurrent workers with optimize_parallel - 5k trials with TPE and 32 parallel workers --- tests/stress_tests.rs | 129 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) create mode 100644 tests/stress_tests.rs diff --git a/tests/stress_tests.rs b/tests/stress_tests.rs new file mode 100644 index 0000000..b652b90 --- /dev/null +++ b/tests/stress_tests.rs @@ -0,0 +1,129 @@ +//! Stress and large-scale tests for the optimizer library. +//! +//! All tests are `#[ignore]`-gated so they don't run in normal CI. +//! 
Run with: `cargo test --features async -- --ignored` + +use std::collections::HashSet; + +use optimizer::parameter::{FloatParam, Parameter}; +use optimizer::sampler::random::RandomSampler; +use optimizer::sampler::tpe::TpeSampler; +use optimizer::{Direction, Error, Study}; + +fn make_float_params(n: usize) -> Vec<FloatParam> { + (0..n) + .map(|i| FloatParam::new(-5.0, 5.0).name(format!("x{i}"))) + .collect() +} + +fn sphere(trial: &mut optimizer::Trial, params: &[FloatParam]) -> Result<f64, Error> { + let mut sum = 0.0; + for p in params { + let v = p.suggest(trial)?; + sum += v * v; + } + Ok(sum) +} + +#[test] +#[ignore] +fn stress_many_trials_random() { + let sampler = RandomSampler::with_seed(42); + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + let params = make_float_params(5); + + study + .optimize(10_000, |trial: &mut optimizer::Trial| { + sphere(trial, &params) + }) + .expect("10k trials should complete"); + + assert_eq!(study.n_trials(), 10_000); + let best = study.best_value().expect("should have a best value"); + assert!(best.is_finite(), "best value should be finite"); + assert!(best >= 0.0, "sphere function is non-negative"); +} + +#[test] +#[ignore] +fn stress_many_params_tpe() { + let sampler = TpeSampler::builder() + .seed(42) + .n_startup_trials(10) + .build() + .unwrap(); + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + let params = make_float_params(128); + + study + .optimize(200, |trial: &mut optimizer::Trial| sphere(trial, &params)) + .expect("200 trials with 128 params should complete"); + + assert_eq!(study.n_trials(), 200); + + let best = study.best_trial().expect("should have a best trial"); + assert_eq!(best.params.len(), 128, "best trial should have 128 params"); + assert!(best.value.is_finite(), "best value should be finite"); + for v in best.params.values() { + let f = match v { + optimizer::param::ParamValue::Float(f) => *f, + other => panic!("expected Float param, got {other}"), + 
}; + assert!(f.is_finite(), "all param values should be finite"); + } +} + +#[cfg(feature = "async")] +#[tokio::test] +#[ignore] +async fn stress_high_concurrency_parallel() { + let sampler = RandomSampler::with_seed(42); + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + let params = make_float_params(10); + + study + .optimize_parallel(1_000, 128, move |trial: &mut optimizer::Trial| { + sphere(trial, &params) + }) + .await + .expect("1k trials with 128 workers should complete"); + + assert_eq!(study.n_trials(), 1_000); + + let trials = study.trials(); + let ids: HashSet<u64> = trials.iter().map(|t| t.id).collect(); + assert_eq!(ids.len(), 1_000, "all trial IDs should be unique"); + + let best = study.best_value().expect("should have a best value"); + assert!(best.is_finite(), "best value should be finite"); +} + +#[cfg(feature = "async")] +#[tokio::test] +#[ignore] +async fn stress_long_running_tpe_parallel() { + let sampler = TpeSampler::builder() + .seed(42) + .n_startup_trials(20) + .build() + .unwrap(); + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + let params = make_float_params(20); + + study + .optimize_parallel(5_000, 32, move |trial: &mut optimizer::Trial| { + sphere(trial, &params) + }) + .await + .expect("5k trials with TPE and 32 workers should complete"); + + assert_eq!(study.n_trials(), 5_000); + + let trials = study.trials(); + let ids: HashSet<u64> = trials.iter().map(|t| t.id).collect(); + assert_eq!(ids.len(), 5_000, "all trial IDs should be unique"); + + let best = study.best_trial().expect("should have a best trial"); + assert!(best.value.is_finite(), "best value should be finite"); + assert_eq!(best.params.len(), 20, "best trial should have 20 params"); +} From fd5c5453e5295748f137ffde6bc0599fe6e97e5e Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 16:22:38 +0100 Subject: [PATCH 39/48] test(async): add concurrency verification tests for 
optimize_parallel - Timing-based proof that trials run in parallel, not sequentially - Atomic counter proof that multiple workers overlap simultaneously - Panic propagation returns TaskError with panic message - Partial failures with parallel path store only successful trials --- tests/async_tests.rs | 117 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) diff --git a/tests/async_tests.rs b/tests/async_tests.rs index ba9754f..a5c271f 100644 --- a/tests/async_tests.rs +++ b/tests/async_tests.rs @@ -4,6 +4,8 @@ #![cfg(feature = "async")] +use std::sync::Arc; + use optimizer::parameter::{FloatParam, Parameter}; use optimizer::sampler::random::RandomSampler; use optimizer::sampler::tpe::TpeSampler; @@ -192,3 +194,118 @@ async fn test_optimize_parallel_single_concurrency() { assert_eq!(study.n_trials(), 10); } + +#[tokio::test] +async fn test_parallel_executes_concurrently() { + let sampler = RandomSampler::with_seed(42); + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + + let x_param = FloatParam::new(0.0, 10.0); + + let start = tokio::time::Instant::now(); + study + .optimize_parallel(4, 4, move |trial: &mut optimizer::Trial| { + let x = x_param.suggest(trial)?; + std::thread::sleep(std::time::Duration::from_millis(100)); + Ok::<_, Error>(x) + }) + .await + .expect("parallel optimization should succeed"); + + let elapsed = start.elapsed(); + assert_eq!(study.n_trials(), 4); + // Sequential would take ~400ms; parallel with concurrency=4 should be ~100ms + assert!( + elapsed < std::time::Duration::from_millis(350), + "expected parallel execution under 350ms, took {elapsed:?}" + ); +} + +#[tokio::test] +async fn test_parallel_max_concurrency_reached() { + use std::sync::atomic::{AtomicUsize, Ordering}; + + let sampler = RandomSampler::with_seed(42); + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + + let x_param = FloatParam::new(0.0, 10.0); + let active = Arc::new(AtomicUsize::new(0)); 
+ let max_active = Arc::new(AtomicUsize::new(0)); + + let active_c = Arc::clone(&active); + let max_active_c = Arc::clone(&max_active); + + study + .optimize_parallel(8, 4, move |trial: &mut optimizer::Trial| { + let x = x_param.suggest(trial)?; + + let current = active_c.fetch_add(1, Ordering::SeqCst) + 1; + // Update max_active if this is the highest seen so far + max_active_c.fetch_max(current, Ordering::SeqCst); + + std::thread::sleep(std::time::Duration::from_millis(50)); + + active_c.fetch_sub(1, Ordering::SeqCst); + Ok::<_, Error>(x) + }) + .await + .expect("parallel optimization should succeed"); + + assert_eq!(study.n_trials(), 8); + let max = max_active.load(Ordering::SeqCst); + assert!( + max >= 2, + "expected at least 2 concurrent workers, but max was {max}" + ); +} + +#[tokio::test] +async fn test_parallel_panic_returns_task_error() { + let study: Study<f64> = Study::new(Direction::Minimize); + + let result = study + .optimize_parallel(3, 2, |_trial: &mut optimizer::Trial| { + panic!("boom"); + #[allow(unreachable_code)] + Ok::<_, Error>(0.0) + }) + .await; + + match result { + Err(Error::TaskError(msg)) => { + assert!( + msg.contains("panic"), + "expected error message to contain 'panic', got: {msg}" + ); + } + other => panic!("expected TaskError, got {other:?}"), + } +} + +#[tokio::test] +async fn test_parallel_partial_failures_trial_count() { + use std::sync::atomic::{AtomicUsize, Ordering}; + + let sampler = RandomSampler::with_seed(42); + let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); + + let x_param = FloatParam::new(0.0, 10.0); + let counter = Arc::new(AtomicUsize::new(0)); + + let counter_c = Arc::clone(&counter); + + study + .optimize_parallel(10, 3, move |trial: &mut optimizer::Trial| { + let idx = counter_c.fetch_add(1, Ordering::SeqCst); + if idx % 2 == 1 { + return Err(Error::NoCompletedTrials); + } + let x = x_param.suggest(trial)?; + Ok::<_, Error>(x) + }) + .await + .expect("should succeed with partial 
failures"); + + // Even indices succeed (0, 2, 4, 6, 8), odd indices fail + assert_eq!(study.n_trials(), 5); +} From 479554ffb2df990bcd9d3ac863900265d415130b Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 16:24:26 +0100 Subject: [PATCH 40/48] fix(docs): convert JournalStorage intra-doc link to plain backtick text --- src/storage/memory.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/storage/memory.rs b/src/storage/memory.rs index 676fc41..eeea613 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -13,8 +13,7 @@ //! (no disk I/O). //! //! For persistent storage that survives process restarts, see -//! [`JournalStorage`](super::JournalStorage) (requires the `journal` -//! feature). +//! `JournalStorage` (requires the `journal` feature). //! //! # Example //! From 8d06d213dbfbce93acbfb48f3297117dc6291e2a Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 16:26:41 +0100 Subject: [PATCH 41/48] docs(journal): document file-size DoS potential in JournalStorage --- src/storage/journal.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/storage/journal.rs b/src/storage/journal.rs index 6833bd6..61275e6 100644 --- a/src/storage/journal.rs +++ b/src/storage/journal.rs @@ -67,6 +67,20 @@ //! //! For pure in-memory usage without disk I/O, use //! [`MemoryStorage`](super::MemoryStorage) instead (the default). +//! +//! # Security considerations +//! +//! Both [`JournalStorage::open`] and [`refresh`](super::Storage::refresh) +//! read the entire JSONL file into memory. A very large file will +//! consume memory proportional to its size, which could lead to +//! out-of-memory conditions. +//! +//! If your application accepts externally-provided file paths, consider: +//! +//! - Checking the file size before calling `open`. +//! - Validating or sanitizing untrusted JSONL content. +//! 
- Imposing an upper bound on the number of trials or file size you +//! are willing to load. use core::marker::PhantomData; use core::sync::atomic::{AtomicU64, Ordering}; @@ -108,6 +122,11 @@ use crate::sampler::CompletedTrial; /// let storage = JournalStorage::<f64>::new("trials.jsonl"); /// let mut study = Study::builder().minimize().storage(storage).build(); /// ``` +/// +/// # Security considerations +/// +/// File contents are loaded into memory in full; see the +/// [module-level docs](self) for details and mitigations. pub struct JournalStorage<V = f64> { memory: MemoryStorage<V>, path: PathBuf, From cfa8426312eefac5e6c152df19e93a8530481819 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 16:31:25 +0100 Subject: [PATCH 42/48] docs(sampler,pruner): add implementation guides for custom samplers and pruners - Expand sampler module docs with available samplers tables, custom sampler walkthrough, stateless/stateful patterns, cold start handling, history reading, thread safety, and testing guidance - Expand pruner module docs with stateful/stateless classification, warmup parameters, decorator composition, thread safety, and testing - Add code examples to both trait docs (NoisySampler, StalePruner) --- src/pruner/mod.rs | 103 +++++++++++++++++++++ src/sampler/mod.rs | 219 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 322 insertions(+) diff --git a/src/pruner/mod.rs b/src/pruner/mod.rs index 03d7f01..02fb055 100644 --- a/src/pruner/mod.rs +++ b/src/pruner/mod.rs @@ -41,6 +41,61 @@ //! Start with [`MedianPruner`] for most use cases. Switch to [`WilcoxonPruner`] //! if your intermediate values are noisy, or to [`HyperbandPruner`] if you want //! automatic budget scheduling. +//! +//! # Stateful vs stateless pruners +//! +//! **Stateless** pruners make their decision purely from the arguments passed +//! to [`Pruner::should_prune`] — they hold no mutable per-trial state. +//! 
[`MedianPruner`], [`PercentilePruner`], [`ThresholdPruner`], +//! [`WilcoxonPruner`], and [`NopPruner`] are all stateless. +//! +//! **Stateful** pruners track information across calls. [`PatientPruner`] +//! uses `Mutex<HashMap<u64, u64>>` to count consecutive prune signals per +//! trial. [`HyperbandPruner`] uses `Mutex` and `AtomicU64` for bracket +//! assignment state. When writing a stateful pruner, wrap mutable state in a +//! `Mutex` and key it by `trial_id` to keep trials independent. +//! +//! # Cold start and warmup +//! +//! Two builder parameters control when pruning begins: +//! +//! - **`n_warmup_steps`** — skip pruning before step N *within a trial*, +//! giving the objective time to stabilize. +//! - **`n_min_trials`** — require N completed trials before pruning any trial, +//! ensuring a meaningful comparison baseline. +//! +//! See [`MedianPruner`] for the canonical implementation of both parameters. +//! Custom pruners should expose similar knobs when applicable. +//! +//! # Composing pruners +//! +//! [`PatientPruner`] demonstrates the decorator pattern: it wraps any +//! `Box<dyn Pruner>` and adds patience logic on top. Custom pruners can use +//! the same pattern to layer multiple pruning conditions — for example, +//! combining a statistical test with a hard threshold. +//! +//! # Thread safety +//! +//! The [`Pruner`] trait requires `Send + Sync`. +//! [`Study`](crate::Study) stores the pruner as `Arc<dyn Pruner>`, so +//! multiple threads may call [`Pruner::should_prune`] concurrently. +//! +//! - **Stateless pruners** satisfy `Send + Sync` automatically. +//! - **Stateful pruners** should use `std::sync::Mutex` or +//! `parking_lot::Mutex` to protect mutable state, keyed by `trial_id`. +//! +//! # Testing custom pruners +//! +//! Recommended test categories: +//! +//! 1. **Never-prune baseline** — empty history and early steps should not +//! prune. +//! 2. **Known-prune scenario** — a clearly worse trial should be pruned. +//! 3. 
**Known-keep scenario** — a well-performing trial should survive. +//! 4. **Warmup respected** — pruning must be suppressed during warmup steps +//! and while the minimum trial count has not been reached. +//! 5. **Per-trial independence** — stateful pruners must not leak state +//! between different `trial_id` values. mod hyperband; mod median; @@ -93,6 +148,54 @@ use crate::sampler::CompletedTrial; /// } /// } /// ``` +/// +/// A stateful pruner that tracks per-trial state with a `Mutex`: +/// +/// ``` +/// use std::collections::HashMap; +/// use std::sync::Mutex; +/// use optimizer::pruner::Pruner; +/// use optimizer::sampler::CompletedTrial; +/// +/// /// Prune after the value worsens for `max_stale` consecutive steps. +/// struct StalePruner { +/// max_stale: u64, +/// // Per-trial: (previous_value, consecutive_stale_count) +/// state: Mutex<HashMap<u64, (f64, u64)>>, +/// } +/// +/// impl StalePruner { +/// fn new(max_stale: u64) -> Self { +/// Self { max_stale, state: Mutex::new(HashMap::new()) } +/// } +/// } +/// +/// impl Pruner for StalePruner { +/// fn should_prune( +/// &self, +/// trial_id: u64, +/// _step: u64, +/// intermediate_values: &[(u64, f64)], +/// _completed_trials: &[CompletedTrial], +/// ) -> bool { +/// let Some(&(_, current)) = intermediate_values.last() else { +/// return false; +/// }; +/// let mut state = self.state.lock().unwrap(); +/// let entry = state.entry(trial_id).or_insert((current, 0)); +/// if current >= entry.0 { +/// entry.1 += 1; +/// } else { +/// entry.1 = 0; +/// } +/// entry.0 = current; +/// entry.1 >= self.max_stale +/// } +/// } +/// ``` +/// +/// See the [module-level documentation](self) for a comprehensive guide +/// covering warmup, composition, thread safety, and testing. pub trait Pruner: Send + Sync { /// Decide whether to prune a trial at the given step. 
/// diff --git a/src/sampler/mod.rs b/src/sampler/mod.rs index ff83bc6..6f5ba17 100644 --- a/src/sampler/mod.rs +++ b/src/sampler/mod.rs @@ -1,4 +1,179 @@ //! Sampler trait and implementations for parameter sampling. +//! +//! A sampler generates parameter values for each trial. It receives a +//! [`Distribution`] describing the parameter space, a monotonically increasing +//! `trial_id`, and the list of all [`CompletedTrial`]s so far, and returns a +//! [`ParamValue`] that matches the distribution variant. +//! +//! # Available samplers +//! +//! ## Single-objective +//! +//! | Sampler | Algorithm | Best for | +//! |---------|-----------|----------| +//! | [`RandomSampler`] | Uniform independent sampling | Baselines, startup phases | +//! | [`TpeSampler`] | Tree-Parzen Estimator | General-purpose Bayesian optimization | +//! | [`TpeSampler`] (multivariate) | Multivariate TPE with tree-structured Parzen | Correlated parameters | +//! | [`GridSampler`] | Exhaustive grid evaluation | Small discrete spaces | +//! | [`SobolSampler`]\* | Quasi-random Sobol sequences | Uniform coverage without model | +//! | [`CmaEsSampler`]\* | Covariance Matrix Adaptation | Continuous, non-separable problems | +//! | [`GpSampler`]\* | Gaussian Process with EI | Expensive, low-dimensional functions | +//! | [`DESampler`] | Differential Evolution | Population-based, multi-modal landscapes | +//! | [`BohbSampler`] | Bayesian Optimization + `HyperBand` | Combined sampling and pruning | +//! +//! \*Requires a feature flag (`sobol`, `cma-es`, or `gp`). +//! +//! ## Multi-objective +//! +//! | Sampler | Algorithm | Best for | +//! |---------|-----------|----------| +//! | [`Nsga2Sampler`] | NSGA-II | General multi-objective with 2-3 objectives | +//! | [`Nsga3Sampler`] | NSGA-III | Many-objective (4+ objectives) | +//! | [`MoeadSampler`] | MOEA/D with decomposition | Structured Pareto front exploration | +//! | [`MotpeSampler`] | Multi-objective TPE | Bayesian multi-objective | +//! +//! 
# Implementing a custom sampler +//! +//! Implement the [`Sampler`] trait with its single method: +//! +//! ```rust +//! use optimizer::sampler::{Sampler, CompletedTrial}; +//! use optimizer::distribution::Distribution; +//! use optimizer::param::ParamValue; +//! +//! /// A sampler that always picks the midpoint of each distribution. +//! struct MidpointSampler; +//! +//! impl Sampler for MidpointSampler { +//! fn sample( +//! &self, +//! distribution: &Distribution, +//! _trial_id: u64, +//! _history: &[CompletedTrial], +//! ) -> ParamValue { +//! match distribution { +//! Distribution::Float(fd) => { +//! ParamValue::Float((fd.low + fd.high) / 2.0) +//! } +//! Distribution::Int(id) => { +//! ParamValue::Int((id.low + id.high) / 2) +//! } +//! Distribution::Categorical(cd) => { +//! ParamValue::Categorical(cd.n_choices / 2) +//! } +//! } +//! } +//! } +//! ``` +//! +//! The arguments to [`Sampler::sample`]: +//! +//! - **`distribution`** — a [`Distribution::Float`], [`Distribution::Int`], or +//! [`Distribution::Categorical`] that describes the parameter bounds, +//! log-scale flag, and optional step size. +//! - **`trial_id`** — a monotonically increasing identifier. Useful for +//! deterministic RNG seeding (see [Stateless vs stateful samplers]). +//! - **`history`** — all completed trials so far. May be empty on the first +//! trial. Model-based samplers use this to guide future sampling. +//! - **Return value** — the [`ParamValue`] variant *must* match the +//! distribution variant (`Float` → `ParamValue::Float`, etc.). +//! +//! [Stateless vs stateful samplers]: #stateless-vs-stateful-samplers +//! +//! # Stateless vs stateful samplers +//! +//! **Stateless** samplers derive all randomness from a deterministic function +//! of `seed + trial_id + distribution`. They use an [`AtomicU64`] call-sequence +//! counter to disambiguate multiple calls within the same trial, but need no +//! `Mutex`. See [`RandomSampler`] and [`TpeSampler`] for this pattern. +//! +//! 
**Stateful** samplers maintain mutable state (e.g. a population pool) +//! across calls. Wrap mutable state in `parking_lot::Mutex<State>` and lock +//! for the duration of [`Sampler::sample`]. See [`DESampler`] and +//! [`GridSampler`] for this pattern. +//! +//! [`AtomicU64`]: core::sync::atomic::AtomicU64 +//! +//! # Cold start handling +//! +//! Model-based samplers need completed trials before their surrogate model is +//! useful. The standard pattern is to check `history.len() < n_startup_trials` +//! and fall back to random sampling during the startup phase. Expose this as a +//! builder parameter so users can tune the trade-off between exploration and +//! exploitation. See [`TpeSampler`] for a reference implementation. +//! +//! # Reading trial history +//! +//! The `history` slice contains only completed trials (never pending ones). +//! Common operations: +//! +//! - **Extract a parameter value:** +//! `trial.params.get(¶m_id)` returns `Option<&ParamValue>`. +//! - **Find the best trial:** +//! `history.iter().min_by(|a, b| a.value.partial_cmp(&b.value).unwrap())`. +//! - **Filter by state:** +//! `history.iter().filter(|t| t.state == TrialState::Complete)`. +//! - **Check feasibility:** +//! `trial.is_feasible()` returns `true` when all constraints are ≤ 0. +//! +//! # Thread safety +//! +//! The [`Sampler`] trait requires `Send + Sync`. [`Study`](crate::Study) stores +//! the sampler as `Arc<dyn Sampler>`, so multiple threads may call +//! [`Sampler::sample`] concurrently. +//! +//! - **Stateless:** `AtomicU64` counters satisfy `Send + Sync` without locking. +//! - **Stateful:** use `parking_lot::Mutex` (the crate convention) or +//! `std::sync::Mutex` to protect mutable state. +//! +//! # Testing custom samplers +//! +//! Recommended test categories: +//! +//! 1. **Bounds compliance** — sample many values and assert they fall within +//! the distribution range. +//! 2. **Step / log-scale correctness** — verify that discretized and +//! 
log-scaled distributions produce valid values. +//! 3. **Reproducibility** — the same seed must produce the same output. +//! 4. **History sensitivity** — model-based samplers should produce different +//! (better) samples as history grows. +//! 5. **Empty history** — `sample()` must not panic when `history` is empty. +//! +//! # Using a custom sampler with Study +//! +//! ```rust +//! use optimizer::{Direction, Study}; +//! use optimizer::sampler::{Sampler, CompletedTrial}; +//! use optimizer::distribution::Distribution; +//! use optimizer::param::ParamValue; +//! +//! struct MySampler; +//! impl Sampler for MySampler { +//! fn sample( +//! &self, +//! distribution: &Distribution, +//! _trial_id: u64, +//! _history: &[CompletedTrial], +//! ) -> ParamValue { +//! match distribution { +//! Distribution::Float(fd) => ParamValue::Float(fd.low), +//! Distribution::Int(id) => ParamValue::Int(id.low), +//! Distribution::Categorical(_) => ParamValue::Categorical(0), +//! } +//! } +//! } +//! +//! let study: Study<f64> = Study::with_sampler(Direction::Minimize, MySampler); +//! ``` +//! +//! The sampler is wrapped in `Arc<dyn Sampler>` internally. +//! +//! # Reference implementations +//! +//! - [`RandomSampler`] — simplest sampler; stateless, ignores history. +//! - [`TpeSampler`] — model-based with cold start fallback. +//! - [`DESampler`] — stateful, population-based. +//! - [`GridSampler`] — deterministic, exhaustive search. pub mod bohb; #[cfg(feature = "cma-es")] @@ -280,6 +455,50 @@ impl PendingTrial { /// Samplers are responsible for generating parameter values based on /// the distribution and historical trial data. The trait requires /// `Send + Sync` to support concurrent and async optimization. 
+/// +/// # Implementing a custom sampler +/// +/// ``` +/// use optimizer::sampler::{Sampler, CompletedTrial}; +/// use optimizer::distribution::Distribution; +/// use optimizer::param::ParamValue; +/// +/// struct NoisySampler { +/// noise_scale: f64, +/// seed: u64, +/// } +/// +/// impl Sampler for NoisySampler { +/// fn sample( +/// &self, +/// distribution: &Distribution, +/// trial_id: u64, +/// history: &[CompletedTrial], +/// ) -> ParamValue { +/// // Find the best value seen so far, or fall back to the midpoint +/// match distribution { +/// Distribution::Float(fd) => { +/// let center = if history.is_empty() { +/// (fd.low + fd.high) / 2.0 +/// } else { +/// history.iter() +/// .filter_map(|t| t.params.values().next()) +/// .filter_map(|v| if let ParamValue::Float(f) = v { Some(*f) } else { None }) +/// .next() +/// .unwrap_or((fd.low + fd.high) / 2.0) +/// }; +/// let noise = (trial_id as f64 * 0.1).sin() * self.noise_scale; +/// ParamValue::Float(center + noise) +/// } +/// Distribution::Int(id) => ParamValue::Int((id.low + id.high) / 2), +/// Distribution::Categorical(cd) => ParamValue::Categorical(trial_id as usize % cd.n_choices), +/// } +/// } +/// } +/// ``` +/// +/// See the [module-level documentation](self) for a comprehensive guide +/// covering cold start handling, thread safety patterns, and testing. pub trait Sampler: Send + Sync { /// Samples a parameter value from the given distribution. /// From 8c025b452565510f9d1d26471a9b6d05aca4a325 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 16:34:27 +0100 Subject: [PATCH 43/48] fix(journal): canonicalize file paths in JournalStorage constructors Resolve symlinks and `../` traversals best-effort in `new()` and `open()`, falling back to the original path if canonicalization fails (e.g. file doesn't exist yet). 
--- src/storage/journal.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/storage/journal.rs b/src/storage/journal.rs index 61275e6..cb08ccb 100644 --- a/src/storage/journal.rs +++ b/src/storage/journal.rs @@ -149,9 +149,13 @@ impl<V: Serialize + DeserializeOwned + Send + Sync> JournalStorage<V> { /// [`JournalStorage::open`] instead. #[must_use] pub fn new(path: impl AsRef<Path>) -> Self { + let path = path + .as_ref() + .canonicalize() + .unwrap_or_else(|_| path.as_ref().to_path_buf()); Self { memory: MemoryStorage::new(), - path: path.as_ref().to_path_buf(), + path, write_lock: Mutex::new(()), file_offset: AtomicU64::new(0), _marker: PhantomData, @@ -168,7 +172,10 @@ impl<V: Serialize + DeserializeOwned + Send + Sync> JournalStorage<V> { /// Return a [`Storage`](crate::Error::Storage) error if the file /// exists but cannot be read or parsed. pub fn open(path: impl AsRef<Path>) -> crate::Result<Self> { - let path = path.as_ref().to_path_buf(); + let path = path + .as_ref() + .canonicalize() + .unwrap_or_else(|_| path.as_ref().to_path_buf()); let (trials, offset) = load_trials_from_file(&path)?; Ok(Self { memory: MemoryStorage::with_trials(trials), From 9f2ce2a4823f9abd1167fa35acf0d982e3f34df0 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 17:05:12 +0100 Subject: [PATCH 44/48] fix(docs): remove private intra-doc link from JournalStorage doc comment --- src/storage/journal.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/storage/journal.rs b/src/storage/journal.rs index cb08ccb..01b7ed8 100644 --- a/src/storage/journal.rs +++ b/src/storage/journal.rs @@ -126,7 +126,7 @@ use crate::sampler::CompletedTrial; /// # Security considerations /// /// File contents are loaded into memory in full; see the -/// [module-level docs](self) for details and mitigations. +/// module-level docs for details and mitigations. 
pub struct JournalStorage<V = f64> { memory: MemoryStorage<V>, path: PathBuf, From 1cd18c16f915422b5342811cb86ff7d10bb2fb16 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Thu, 12 Feb 2026 17:57:29 +0100 Subject: [PATCH 45/48] fix: address review findings in grid sampler, parallel optimization, and CSV export - Rename stale "GridSearchSampler" panic message to "GridSampler" - Assert concurrency > 0 in optimize_parallel to prevent deadlock - Fix inaccurate comment in CSV export (empty for all non-complete trials, not just pruned) --- src/sampler/grid.rs | 4 ++-- src/study/async_impl.rs | 2 ++ src/study/export.rs | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/sampler/grid.rs b/src/sampler/grid.rs index b991f35..f6aea7b 100644 --- a/src/sampler/grid.rs +++ b/src/sampler/grid.rs @@ -601,7 +601,7 @@ impl Sampler for GridSampler { // Check if all points have been exhausted assert!( cached.current_index < cached.points.len(), - "GridSearchSampler: all grid points exhausted" + "GridSampler: all grid points exhausted" ); // Get the current grid point and advance the index @@ -908,7 +908,7 @@ mod tests { } #[test] - #[should_panic(expected = "GridSearchSampler: all grid points exhausted")] + #[should_panic(expected = "GridSampler: all grid points exhausted")] fn test_sampler_panics_after_exhaustion() { let sampler = GridSampler::new(); let dist = Distribution::Categorical(CategoricalDistribution { n_choices: 2 }); diff --git a/src/study/async_impl.rs b/src/study/async_impl.rs index ca9e0a5..2d0a1a6 100644 --- a/src/study/async_impl.rs +++ b/src/study/async_impl.rs @@ -171,6 +171,8 @@ where use tokio::sync::Semaphore; use tokio::task::JoinSet; + assert!(concurrency > 0, "concurrency must be at least 1"); + #[cfg(feature = "tracing")] let _span = tracing::info_span!("optimize_parallel", n_trials, concurrency, direction = ?self.direction).entered(); diff --git a/src/study/export.rs b/src/study/export.rs index beff14b..4933982 
100644 --- a/src/study/export.rs +++ b/src/study/export.rs @@ -91,7 +91,7 @@ where for trial in trials.iter() { write!(writer, "{}", trial.id)?; - // Value: empty for pruned trials. + // Value: empty for non-complete trials. if trial.state == TrialState::Complete { write!(writer, ",{}", trial.value)?; } else { From 11a8534b38d4bd1601877701ce5e14b5de4e37ad Mon Sep 17 00:00:00 2001 From: Manuel <raimannma@outlook.de> Date: Fri, 13 Feb 2026 10:03:37 +0100 Subject: [PATCH 46/48] fix: address 18 bugs found during codebase audit (#8) * fix: address 18 bugs found during codebase audit High severity: - TPE/MOTPE: match parameters by exact distribution equality instead of flat-mapping over all param values, preventing cross-parameter mixing - MultivariateTpeSampler: find_matching_param now uses search space distributions for exact matching instead of type+range heuristic - JournalStorage: write_to_file no longer advances file_offset (left to refresh), both operations serialized under single io_lock mutex, refresh uses fetch_max and deduplicates by trial ID Medium severity: - NSGA-III: use actual Pareto front ranks for tournament selection instead of artificial cyclic indices - sample_random: apply step quantization after log-scale sampling - internal_bounds: return None for non-positive log-scale bounds - SobolSampler: use per-trial dimension HashMap for concurrent safety - JournalStorage refresh: protect with io_lock mutex, use fetch_max - n_trials(): filter by TrialState::Complete as documented - FloatParam: reject NaN/Infinity in validate() - Pruners: assert n_min_trials >= 1, guard compute_percentile on empty - Visualization: escape_js for importance chart parameter names Low severity: - save(): use peek_next_trial_id() from Storage trait - csv_escape: handle carriage return per RFC 4180 - from_internal: use saturating arithmetic for stepped Int distributions - BoolParam: bounds-check categorical index < 2 - min_max: skip NaN values with safe fallback * ci: trigger CI 
on pull requests targeting any branch --- .github/workflows/ci.yml | 1 - src/parameter.rs | 42 +++++++++++- src/pruner/median.rs | 5 ++ src/pruner/percentile.rs | 6 ++ src/sampler/common.rs | 24 ++++++- src/sampler/motpe.rs | 99 ++++++++++++++++++++--------- src/sampler/nsga3.rs | 31 +++++---- src/sampler/sobol.rs | 26 +++----- src/sampler/tpe/multivariate/mod.rs | 92 +++++++++++++++------------ src/sampler/tpe/sampler.rs | 99 ++++++++++++++++++++--------- src/storage/journal.rs | 44 ++++++++----- src/storage/memory.rs | 4 ++ src/storage/mod.rs | 7 ++ src/study/export.rs | 2 +- src/study/mod.rs | 10 ++- src/study/persistence.rs | 2 +- src/visualization.rs | 12 +++- 17 files changed, 350 insertions(+), 156 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1f44f41..268bc25 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,7 +4,6 @@ on: push: branches: [main, master] pull_request: - branches: [main, master] permissions: contents: read diff --git a/src/parameter.rs b/src/parameter.rs index 74782f8..b22a1c8 100644 --- a/src/parameter.rs +++ b/src/parameter.rs @@ -255,6 +255,12 @@ impl Parameter for FloatParam { } fn validate(&self) -> Result<()> { + if !self.low.is_finite() || !self.high.is_finite() { + return Err(Error::InvalidBounds { + low: self.low, + high: self.high, + }); + } if self.low > self.high { return Err(Error::InvalidBounds { low: self.low, @@ -265,7 +271,7 @@ impl Parameter for FloatParam { return Err(Error::InvalidLogBounds); } if let Some(step) = self.step - && step <= 0.0 + && (!step.is_finite() || step <= 0.0) { return Err(Error::InvalidStep); } @@ -550,7 +556,8 @@ impl Parameter for BoolParam { fn cast_param_value(&self, param_value: &ParamValue) -> Result<bool> { match param_value { - ParamValue::Categorical(index) => Ok(*index != 0), + ParamValue::Categorical(index) if *index < 2 => Ok(*index != 0), + ParamValue::Categorical(_) => Err(Error::Internal("bool index out of bounds")), _ => 
Err(Error::Internal( "Categorical distribution should return Categorical value", )), @@ -789,6 +796,30 @@ mod tests { assert!(param.validate().is_err()); } + #[test] + fn float_param_validate_nan() { + assert!(FloatParam::new(f64::NAN, 1.0).validate().is_err()); + assert!(FloatParam::new(0.0, f64::NAN).validate().is_err()); + assert!(FloatParam::new(f64::NAN, f64::NAN).validate().is_err()); + } + + #[test] + fn float_param_validate_infinity() { + assert!(FloatParam::new(f64::INFINITY, 1.0).validate().is_err()); + assert!(FloatParam::new(0.0, f64::NEG_INFINITY).validate().is_err()); + } + + #[test] + fn float_param_validate_nan_step() { + assert!(FloatParam::new(0.0, 1.0).step(f64::NAN).validate().is_err()); + assert!( + FloatParam::new(0.0, 1.0) + .step(f64::INFINITY) + .validate() + .is_err() + ); + } + #[test] #[allow(clippy::float_cmp)] fn float_param_cast_param_value() { @@ -920,6 +951,13 @@ mod tests { assert!(param.cast_param_value(&ParamValue::Float(1.0)).is_err()); } + #[test] + fn bool_param_cast_out_of_bounds() { + let param = BoolParam::new(); + assert!(param.cast_param_value(&ParamValue::Categorical(2)).is_err()); + assert!(param.cast_param_value(&ParamValue::Categorical(5)).is_err()); + } + #[derive(Clone, Debug, PartialEq)] enum TestEnum { A, diff --git a/src/pruner/median.rs b/src/pruner/median.rs index 0e97c8f..5451e66 100644 --- a/src/pruner/median.rs +++ b/src/pruner/median.rs @@ -89,8 +89,13 @@ impl MedianPruner { } /// Set the minimum number of completed trials required before pruning. + /// + /// # Panics + /// + /// Panics if `n` is 0. 
#[must_use] pub fn n_min_trials(mut self, n: usize) -> Self { + assert!(n >= 1, "n_min_trials must be >= 1, got {n}"); self.n_min_trials = n; self } diff --git a/src/pruner/percentile.rs b/src/pruner/percentile.rs index f12df01..5d6e574 100644 --- a/src/pruner/percentile.rs +++ b/src/pruner/percentile.rs @@ -95,8 +95,13 @@ impl PercentilePruner { } /// Set the minimum number of completed trials required before pruning. + /// + /// # Panics + /// + /// Panics if `n` is 0. #[must_use] pub fn n_min_trials(mut self, n: usize) -> Self { + assert!(n >= 1, "n_min_trials must be >= 1, got {n}"); self.n_min_trials = n; self } @@ -157,6 +162,7 @@ impl Pruner for PercentilePruner { clippy::cast_sign_loss )] pub(crate) fn compute_percentile(values: &mut [f64], percentile: f64) -> f64 { + assert!(!values.is_empty(), "compute_percentile: empty input"); values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal)); let len = values.len(); if len == 1 { diff --git a/src/sampler/common.rs b/src/sampler/common.rs index a2640bd..293b96b 100644 --- a/src/sampler/common.rs +++ b/src/sampler/common.rs @@ -10,6 +10,9 @@ pub(crate) fn internal_bounds(distribution: &Distribution) -> Option<(f64, f64)> match distribution { Distribution::Float(d) => { if d.log_scale { + if d.low <= 0.0 || d.high <= 0.0 { + return None; + } Some((d.low.ln(), d.high.ln())) } else { Some((d.low, d.high)) @@ -17,6 +20,9 @@ pub(crate) fn internal_bounds(distribution: &Distribution) -> Option<(f64, f64)> } Distribution::Int(d) => { if d.log_scale { + if d.low < 1 { + return None; + } Some(((d.low as f64).ln(), (d.high as f64).ln())) } else { Some((d.low as f64, d.high as f64)) @@ -44,7 +50,7 @@ pub(crate) fn from_internal(value: f64, distribution: &Distribution) -> ParamVal let v = if d.log_scale { value.exp() } else { value }; let v = if let Some(step) = d.step { let k = ((v - d.low as f64) / step as f64).round() as i64; - d.low + k * step + d.low.saturating_add(k.saturating_mul(step)) } 
else { v.round() as i64 }; @@ -86,7 +92,13 @@ pub(crate) fn sample_random(rng: &mut fastrand::Rng, distribution: &Distribution let value = if d.log_scale { let log_low = d.low.ln(); let log_high = d.high.ln(); - rng_util::f64_range(rng, log_low, log_high).exp() + let v = rng_util::f64_range(rng, log_low, log_high).exp(); + if let Some(step) = d.step { + let k = ((v - d.low) / step).round(); + (d.low + k * step).clamp(d.low, d.high) + } else { + v + } } else if let Some(step) = d.step { let n_steps = ((d.high - d.low) / step).floor() as i64; let k = rng.i64(0..=n_steps); @@ -100,7 +112,13 @@ pub(crate) fn sample_random(rng: &mut fastrand::Rng, distribution: &Distribution let value = if d.log_scale { let log_low = (d.low as f64).ln(); let log_high = (d.high as f64).ln(); - let raw = rng_util::f64_range(rng, log_low, log_high).exp().round() as i64; + let v = rng_util::f64_range(rng, log_low, log_high).exp(); + let raw = if let Some(step) = d.step { + let k = ((v - d.low as f64) / step as f64).round() as i64; + d.low.saturating_add(k.saturating_mul(step)) + } else { + v.round() as i64 + }; raw.clamp(d.low, d.high) } else if let Some(step) = d.step { let n_steps = (d.high - d.low) / step; diff --git a/src/sampler/motpe.rs b/src/sampler/motpe.rs index 8049dee..749e576 100644 --- a/src/sampler/motpe.rs +++ b/src/sampler/motpe.rs @@ -222,24 +222,37 @@ impl MotpeSampler { bad_trials: &[&MultiObjectiveTrial], rng: &mut fastrand::Rng, ) -> ParamValue { + let target_dist = Distribution::Float(d.clone()); let good_values: Vec<f64> = good_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Float(f) => Some(*f), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Float(f) => Some(*f), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&v| v >= d.low && v <= d.high) .collect(); let bad_values: Vec<f64> = bad_trials 
.iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Float(f) => Some(*f), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Float(f) => Some(*f), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&v| v >= d.low && v <= d.high) .collect(); if good_values.is_empty() || bad_values.is_empty() { @@ -264,24 +277,37 @@ impl MotpeSampler { bad_trials: &[&MultiObjectiveTrial], rng: &mut fastrand::Rng, ) -> ParamValue { + let target_dist = Distribution::Int(d.clone()); let good_values: Vec<i64> = good_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Int(i) => Some(*i), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Int(i) => Some(*i), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&v| v >= d.low && v <= d.high) .collect(); let bad_values: Vec<i64> = bad_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Int(i) => Some(*i), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Int(i) => Some(*i), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&v| v >= d.low && v <= d.high) .collect(); if good_values.is_empty() || bad_values.is_empty() { @@ -307,24 +333,37 @@ impl MotpeSampler { bad_trials: &[&MultiObjectiveTrial], rng: &mut fastrand::Rng, ) -> ParamValue { + let target_dist = Distribution::Categorical(d.clone()); let good_indices: Vec<usize> = good_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Categorical(i) => Some(*i), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + 
t.params.get(id).and_then(|v| match v { + ParamValue::Categorical(i) => Some(*i), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&i| i < d.n_choices) .collect(); let bad_indices: Vec<usize> = bad_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Categorical(i) => Some(*i), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Categorical(i) => Some(*i), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&i| i < d.n_choices) .collect(); if good_indices.is_empty() || bad_indices.is_empty() { diff --git a/src/sampler/nsga3.rs b/src/sampler/nsga3.rs index 7ac840c..80c9117 100644 --- a/src/sampler/nsga3.rs +++ b/src/sampler/nsga3.rs @@ -561,7 +561,7 @@ fn nsga3_select( state: &mut Nsga3State, population: &[&MultiObjectiveTrial], directions: &[Direction], -) -> Vec<Vec<ParamValue>> { +) -> (Vec<Vec<ParamValue>>, Vec<usize>) { let pop_size = state.evo.population_size; let n_obj = directions.len(); @@ -633,12 +633,13 @@ fn nsga3_select( selected.push(state.evo.rng.usize(0..n)); } - selected + let params = selected .iter() .map(|&idx| { extract_trial_params(population[idx], &state.evo.dimensions, &mut state.evo.rng) }) - .collect() + .collect(); + (params, selected) } /// Tournament selection based on rank only (no crowding distance in NSGA-III). 
@@ -675,26 +676,34 @@ fn nsga3_generate_offspring( initialize_nsga3(state, directions); } - let parents = nsga3_select(state, population, directions); + let (parents, selected_indices) = nsga3_select(state, population, directions); - // Assign ranks for tournament selection + // Assign Pareto front ranks for tournament selection let n_obj = directions.len(); let min_values: Vec<Vec<f64>> = population .iter() .map(|t| to_minimize_space(&t.values, directions)) .collect(); let fronts = pareto::fast_non_dominated_sort(&min_values, &vec![Direction::Minimize; n_obj]); - let mut rank = vec![0_usize; parents.len()]; + // Build rank lookup for population indices + let mut pop_rank = vec![0_usize; population.len()]; for (front_rank, front) in fronts.iter().enumerate() { for &idx in front { - if idx < rank.len() { - rank[idx] = front_rank; + if idx < pop_rank.len() { + pop_rank[idx] = front_rank; } } } - // Ranks for selected parents (simplified: use index order) - let parent_ranks: Vec<usize> = (0..parents.len()) - .map(|i| i % (fronts.len().max(1))) + // Map population ranks to selected parent indices + let parent_ranks: Vec<usize> = selected_indices + .iter() + .map(|&idx| { + if idx < pop_rank.len() { + pop_rank[idx] + } else { + 0 + } + }) .collect(); let mut offspring = Vec::with_capacity(pop_size); diff --git a/src/sampler/sobol.rs b/src/sampler/sobol.rs index 2dad0b8..4b15ce4 100644 --- a/src/sampler/sobol.rs +++ b/src/sampler/sobol.rs @@ -44,6 +44,8 @@ //! let study: Study<f64> = Study::with_sampler(Direction::Minimize, SobolSampler::with_seed(42)); //! ``` +use std::collections::HashMap; + use parking_lot::Mutex; use sobol_burley::sample; @@ -51,12 +53,10 @@ use crate::distribution::Distribution; use crate::param::ParamValue; use crate::sampler::{CompletedTrial, Sampler}; -/// Internal state for tracking the dimension counter within a trial. +/// Internal state for tracking per-trial dimension counters. 
struct SobolState { - /// The `trial_id` of the current trial (used to reset dimension counter). - current_trial: u64, - /// Next Sobol dimension to use for the current trial. - next_dimension: u32, + /// Next Sobol dimension for each in-flight trial. + dimensions: HashMap<u64, u32>, } /// Quasi-random sampler using Sobol low-discrepancy sequences. @@ -107,8 +107,7 @@ impl SobolSampler { Self { seed: seed as u32, state: Mutex::new(SobolState { - current_trial: u64::MAX, - next_dimension: 0, + dimensions: HashMap::new(), }), } } @@ -130,20 +129,15 @@ impl Sampler for SobolSampler { ) -> ParamValue { let mut state = self.state.lock(); - // Reset dimension counter when a new trial starts. - if state.current_trial != trial_id { - state.current_trial = trial_id; - state.next_dimension = 0; - } - - let dimension = state.next_dimension; - state.next_dimension = dimension + 1; + let dimension = state.dimensions.entry(trial_id).or_insert(0); + let dim = *dimension; + *dimension = dim + 1; // Use trial_id as the Sobol sequence index. let index = trial_id as u32; // Generate a quasi-random point in [0, 1). - let point = f64::from(sample(index, dimension, self.seed)); + let point = f64::from(sample(index, dim, self.seed)); map_point_to_distribution(point, distribution) } diff --git a/src/sampler/tpe/multivariate/mod.rs b/src/sampler/tpe/multivariate/mod.rs index f3738f9..4c452bc 100644 --- a/src/sampler/tpe/multivariate/mod.rs +++ b/src/sampler/tpe/multivariate/mod.rs @@ -216,6 +216,13 @@ pub enum ConstantLiarStrategy { /// /// assert!(study.best_value().unwrap() < 1.0); /// ``` +/// Cached joint sample for a specific trial. +struct JointSampleCache { + trial_id: u64, + search_space: HashMap<ParamId, Distribution>, + sample: HashMap<ParamId, ParamValue>, +} + pub struct MultivariateTpeSampler { /// Strategy for computing the gamma quantile. gamma_strategy: Arc<dyn GammaStrategy>, @@ -230,8 +237,7 @@ pub struct MultivariateTpeSampler { /// Thread-safe RNG for sampling. 
rng: Mutex<fastrand::Rng>, /// Cache for joint samples to maintain consistency across parameters within the same trial. - /// The tuple contains (`trial_id`, cached joint sample). - joint_sample_cache: Mutex<Option<(u64, HashMap<ParamId, ParamValue>)>>, + joint_sample_cache: Mutex<Option<JointSampleCache>>, } impl MultivariateTpeSampler { @@ -453,11 +459,13 @@ impl Sampler for MultivariateTpeSampler { // Check if we have a cached joint sample for this trial { let cache = self.joint_sample_cache.lock(); - if let Some((cached_trial_id, ref cached_sample)) = *cache - && cached_trial_id == trial_id + if let Some(ref c) = *cache + && c.trial_id == trial_id { // Try to find a matching parameter from the cached sample - if let Some(value) = Self::find_matching_param(distribution, cached_sample) { + if let Some(value) = + Self::find_matching_param(distribution, &c.search_space, &c.sample) + { return value; } } @@ -470,13 +478,18 @@ impl Sampler for MultivariateTpeSampler { let joint_sample = self.sample_joint(&search_space, history); // Cache the joint sample for this trial + let result = Self::find_matching_param(distribution, &search_space, &joint_sample); { let mut cache = self.joint_sample_cache.lock(); - *cache = Some((trial_id, joint_sample.clone())); + *cache = Some(JointSampleCache { + trial_id, + search_space, + sample: joint_sample, + }); } // Find and return the value for the requested distribution - Self::find_matching_param(distribution, &joint_sample).unwrap_or_else(|| { + result.unwrap_or_else(|| { // Fallback to uniform sampling if no match found let mut rng = self.rng.lock(); crate::sampler::common::sample_random(&mut rng, distribution) @@ -485,33 +498,18 @@ impl Sampler for MultivariateTpeSampler { } impl MultivariateTpeSampler { - /// Finds a matching parameter value from the cached sample based on distribution. - /// - /// This is an associated function that matches parameters by comparing - /// distribution bounds and types. 
+ /// Finds a matching parameter value from the cached sample based on exact + /// distribution equality. fn find_matching_param( distribution: &Distribution, + search_space: &HashMap<ParamId, Distribution>, cached_sample: &HashMap<ParamId, ParamValue>, ) -> Option<ParamValue> { - // Match by distribution type and value compatibility - for value in cached_sample.values() { - match (distribution, value) { - (Distribution::Float(d), ParamValue::Float(v)) => { - if *v >= d.low && *v <= d.high { - return Some(value.clone()); - } - } - (Distribution::Int(d), ParamValue::Int(v)) => { - if *v >= d.low && *v <= d.high { - return Some(value.clone()); - } - } - (Distribution::Categorical(d), ParamValue::Categorical(v)) => { - if *v < d.n_choices { - return Some(value.clone()); - } - } - _ => {} + for (id, dist) in search_space { + if dist == distribution + && let Some(value) = cached_sample.get(id) + { + return Some(value.clone()); } } None @@ -4213,12 +4211,15 @@ mod tests { fn test_find_matching_param_float() { let x_id = ParamId::new(); let y_id = ParamId::new(); + let dist = float_dist(0.0, 1.0); + let mut space = HashMap::new(); + space.insert(x_id, dist.clone()); + space.insert(y_id, float_dist(2.0, 3.0)); let mut cached = HashMap::new(); cached.insert(x_id, ParamValue::Float(0.5)); - cached.insert(y_id, ParamValue::Float(0.8)); + cached.insert(y_id, ParamValue::Float(2.8)); - let dist = float_dist(0.0, 1.0); - let result = MultivariateTpeSampler::find_matching_param(&dist, &cached); + let result = MultivariateTpeSampler::find_matching_param(&dist, &space, &cached); assert!(result.is_some()); if let Some(ParamValue::Float(v)) = result { @@ -4229,11 +4230,13 @@ mod tests { #[test] fn test_find_matching_param_int() { let n_id = ParamId::new(); + let dist = int_dist(0, 10); + let mut space = HashMap::new(); + space.insert(n_id, dist.clone()); let mut cached = HashMap::new(); cached.insert(n_id, ParamValue::Int(5)); - let dist = int_dist(0, 10); - let result = 
MultivariateTpeSampler::find_matching_param(&dist, &cached); + let result = MultivariateTpeSampler::find_matching_param(&dist, &space, &cached); assert!(result.is_some()); if let Some(ParamValue::Int(v)) = result { @@ -4244,11 +4247,13 @@ mod tests { #[test] fn test_find_matching_param_categorical() { let choice_id = ParamId::new(); + let dist = categorical_dist(3); + let mut space = HashMap::new(); + space.insert(choice_id, dist.clone()); let mut cached = HashMap::new(); cached.insert(choice_id, ParamValue::Categorical(1)); - let dist = categorical_dist(3); - let result = MultivariateTpeSampler::find_matching_param(&dist, &cached); + let result = MultivariateTpeSampler::find_matching_param(&dist, &space, &cached); assert!(result.is_some()); if let Some(ParamValue::Categorical(v)) = result { @@ -4259,12 +4264,14 @@ mod tests { #[test] fn test_find_matching_param_no_match() { let x_id = ParamId::new(); + let mut space = HashMap::new(); + space.insert(x_id, float_dist(0.0, 1.0)); let mut cached = HashMap::new(); cached.insert(x_id, ParamValue::Float(0.5)); - // Looking for Int, but only Float in cache + // Looking for Int, but only Float in search space let dist = int_dist(0, 10); - let result = MultivariateTpeSampler::find_matching_param(&dist, &cached); + let result = MultivariateTpeSampler::find_matching_param(&dist, &space, &cached); assert!(result.is_none()); } @@ -4272,11 +4279,14 @@ mod tests { #[test] fn test_find_matching_param_out_of_bounds() { let x_id = ParamId::new(); + // Search space has a different distribution than what we're looking for + let mut space = HashMap::new(); + space.insert(x_id, float_dist(0.0, 10.0)); let mut cached = HashMap::new(); - cached.insert(x_id, ParamValue::Float(5.0)); // Out of bounds + cached.insert(x_id, ParamValue::Float(5.0)); let dist = float_dist(0.0, 1.0); - let result = MultivariateTpeSampler::find_matching_param(&dist, &cached); + let result = MultivariateTpeSampler::find_matching_param(&dist, &space, &cached); 
assert!(result.is_none()); } diff --git a/src/sampler/tpe/sampler.rs b/src/sampler/tpe/sampler.rs index 32f3f1a..b4c1c69 100644 --- a/src/sampler/tpe/sampler.rs +++ b/src/sampler/tpe/sampler.rs @@ -650,24 +650,37 @@ impl TpeSampler { bad_trials: &[&CompletedTrial], rng: &mut fastrand::Rng, ) -> ParamValue { + let target_dist = Distribution::Float(d.clone()); let good_values: Vec<f64> = good_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Float(f) => Some(*f), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Float(f) => Some(*f), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&v| v >= d.low && v <= d.high) .collect(); let bad_values: Vec<f64> = bad_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Float(f) => Some(*f), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Float(f) => Some(*f), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&v| v >= d.low && v <= d.high) .collect(); if good_values.is_empty() || bad_values.is_empty() { @@ -692,24 +705,37 @@ impl TpeSampler { bad_trials: &[&CompletedTrial], rng: &mut fastrand::Rng, ) -> ParamValue { + let target_dist = Distribution::Int(d.clone()); let good_values: Vec<i64> = good_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Int(i) => Some(*i), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Int(i) => Some(*i), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&v| v >= d.low && v <= d.high) .collect(); let bad_values: Vec<i64> = bad_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - 
ParamValue::Int(i) => Some(*i), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Int(i) => Some(*i), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&v| v >= d.low && v <= d.high) .collect(); if good_values.is_empty() || bad_values.is_empty() { @@ -735,24 +761,37 @@ impl TpeSampler { bad_trials: &[&CompletedTrial], rng: &mut fastrand::Rng, ) -> ParamValue { + let target_dist = Distribution::Categorical(d.clone()); let good_indices: Vec<usize> = good_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Categorical(i) => Some(*i), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Categorical(i) => Some(*i), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&i| i < d.n_choices) .collect(); let bad_indices: Vec<usize> = bad_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Categorical(i) => Some(*i), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Categorical(i) => Some(*i), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&i| i < d.n_choices) .collect(); if good_indices.is_empty() || bad_indices.is_empty() { diff --git a/src/storage/journal.rs b/src/storage/journal.rs index 01b7ed8..8802972 100644 --- a/src/storage/journal.rs +++ b/src/storage/journal.rs @@ -130,8 +130,8 @@ use crate::sampler::CompletedTrial; pub struct JournalStorage<V = f64> { memory: MemoryStorage<V>, path: PathBuf, - /// Serialise in-process writes so we only hold the file lock briefly. - write_lock: Mutex<()>, + /// Serialise in-process writes and refreshes so they don't race. 
+ io_lock: Mutex<()>, /// Byte offset of last-read position for incremental refresh. file_offset: AtomicU64, _marker: PhantomData<V>, @@ -156,7 +156,7 @@ impl<V: Serialize + DeserializeOwned + Send + Sync> JournalStorage<V> { Self { memory: MemoryStorage::new(), path, - write_lock: Mutex::new(()), + io_lock: Mutex::new(()), file_offset: AtomicU64::new(0), _marker: PhantomData, } @@ -180,15 +180,19 @@ impl<V: Serialize + DeserializeOwned + Send + Sync> JournalStorage<V> { Ok(Self { memory: MemoryStorage::with_trials(trials), path, - write_lock: Mutex::new(()), + io_lock: Mutex::new(()), file_offset: AtomicU64::new(offset), _marker: PhantomData, }) } /// Append a single trial to the JSONL file (best-effort). + /// + /// Does **not** advance `file_offset` — that is left to `refresh` + /// so that externally-written data between the old offset and our + /// write is never skipped. fn write_to_file(&self, trial: &CompletedTrial<V>) -> crate::Result<()> { - let _guard = self.write_lock.lock(); + let _guard = self.io_lock.lock(); let mut file = OpenOptions::new() .create(true) @@ -211,11 +215,6 @@ impl<V: Serialize + DeserializeOwned + Send + Sync> JournalStorage<V> { file.sync_data() .map_err(|e| crate::Error::Storage(e.to_string()))?; - let pos = file - .stream_position() - .map_err(|e| crate::Error::Storage(e.to_string()))?; - self.file_offset.store(pos, Ordering::SeqCst); - file.unlock() .map_err(|e| crate::Error::Storage(e.to_string()))?; @@ -238,7 +237,13 @@ impl<V: Serialize + DeserializeOwned + Send + Sync> Storage<V> for JournalStorag self.memory.next_trial_id() } + fn peek_next_trial_id(&self) -> u64 { + self.memory.peek_next_trial_id() + } + fn refresh(&self) -> bool { + let _guard = self.io_lock.lock(); + let Ok(file) = File::open(&self.path) else { return false; }; @@ -275,6 +280,7 @@ impl<V: Serialize + DeserializeOwned + Send + Sync> Storage<V> for JournalStorag let _ = file.unlock(); let bytes_read = buf.len() as u64; + let new_offset = offset + 
bytes_read; let mut new_trials = Vec::new(); for line in buf.lines() { @@ -293,19 +299,23 @@ impl<V: Serialize + DeserializeOwned + Send + Sync> Storage<V> for JournalStorag } if new_trials.is_empty() { - self.file_offset - .store(offset + bytes_read, Ordering::SeqCst); + self.file_offset.fetch_max(new_offset, Ordering::SeqCst); return false; } - let mut guard = self.memory.trials_arc().write(); + let mut mem_guard = self.memory.trials_arc().write(); + + // Deduplicate: only add trials whose IDs are not already in memory. + let existing_ids: std::collections::HashSet<u64> = mem_guard.iter().map(|t| t.id).collect(); + new_trials.retain(|t| !existing_ids.contains(&t.id)); + if let Some(max_id) = new_trials.iter().map(|t| t.id).max() { self.memory.bump_next_id(max_id + 1); } - guard.extend(new_trials); - self.file_offset - .store(offset + bytes_read, Ordering::SeqCst); - true + let added = !new_trials.is_empty(); + mem_guard.extend(new_trials); + self.file_offset.fetch_max(new_offset, Ordering::SeqCst); + added } } diff --git a/src/storage/memory.rs b/src/storage/memory.rs index eeea613..8f91adb 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -104,4 +104,8 @@ impl<V: Send + Sync> Storage<V> for MemoryStorage<V> { fn next_trial_id(&self) -> u64 { self.next_id.fetch_add(1, Ordering::SeqCst) } + + fn peek_next_trial_id(&self) -> u64 { + self.next_id.load(Ordering::SeqCst) + } } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index d0e32a0..5b683c7 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -77,6 +77,13 @@ pub trait Storage<V>: Send + Sync { /// calls always produce distinct IDs. fn next_trial_id(&self) -> u64; + /// Return the current value of the next-trial-ID counter without incrementing. + /// + /// This is used for persistence (e.g. `Study::save`) to capture the + /// counter's exact position, including IDs assigned to failed trials + /// that are not stored. 
+ fn peek_next_trial_id(&self) -> u64; + /// Reload from an external source (e.g. a file written by another /// process). Return `true` if the in-memory buffer was updated. /// diff --git a/src/study/export.rs b/src/study/export.rs index 4933982..b02afb8 100644 --- a/src/study/export.rs +++ b/src/study/export.rs @@ -255,7 +255,7 @@ impl Study<f64> { /// Escape a string for CSV output. If the value contains a comma, quote, or /// newline, wrap it in double-quotes and double any embedded quotes. fn csv_escape(s: &str) -> String { - if s.contains(',') || s.contains('"') || s.contains('\n') { + if s.contains(',') || s.contains('"') || s.contains('\n') || s.contains('\r') { format!("\"{}\"", s.replace('"', "\"\"")) } else { s.to_string() diff --git a/src/study/mod.rs b/src/study/mod.rs index 00948c1..4906698 100644 --- a/src/study/mod.rs +++ b/src/study/mod.rs @@ -639,7 +639,8 @@ where /// Return the number of completed trials. /// - /// Failed trials are not counted. + /// Pruned and failed trials are not counted. Use + /// [`n_pruned_trials()`](Self::n_pruned_trials) for the pruned count. /// /// # Examples /// @@ -658,7 +659,12 @@ where /// ``` #[must_use] pub fn n_trials(&self) -> usize { - self.storage.trials_arc().read().len() + self.storage + .trials_arc() + .read() + .iter() + .filter(|t| t.state == TrialState::Complete) + .count() } /// Return the number of pruned trials. 
diff --git a/src/study/persistence.rs b/src/study/persistence.rs index c9174d6..95b07bc 100644 --- a/src/study/persistence.rs +++ b/src/study/persistence.rs @@ -49,7 +49,7 @@ impl<V: PartialOrd + Clone + serde::Serialize> Study<V> { pub fn save(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> { let path = path.as_ref(); let trials = self.trials(); - let next_trial_id = trials.iter().map(|t| t.id).max().map_or(0, |id| id + 1); + let next_trial_id = self.storage.peek_next_trial_id(); let snapshot = StudySnapshot { version: 1, direction: self.direction, diff --git a/src/visualization.rs b/src/visualization.rs index 00eac89..776479b 100644 --- a/src/visualization.rs +++ b/src/visualization.rs @@ -341,7 +341,10 @@ Plotly.newPlot("parcoords", [{{ } fn write_importance_chart(html: &mut String, importance: &[(String, f64)]) { - let names: Vec<_> = importance.iter().map(|(n, _)| format!("\"{n}\"")).collect(); + let names: Vec<_> = importance + .iter() + .map(|(n, _)| format!("\"{}\"", escape_js(n))) + .collect(); let values: Vec<f64> = importance.iter().map(|(_, v)| *v).collect(); let _ = write!( @@ -457,6 +460,9 @@ fn min_max(vals: &[f64]) -> (f64, f64) { let mut mn = f64::INFINITY; let mut mx = f64::NEG_INFINITY; for &v in vals { + if v.is_nan() { + continue; + } if v < mn { mn = v; } @@ -464,6 +470,10 @@ fn min_max(vals: &[f64]) -> (f64, f64) { mx = v; } } + // If all values were NaN, return 0.0..1.0 as a safe fallback. 
+ if mn > mx { + return (0.0, 1.0); + } (mn, mx) } From 86db6361d639039cada0635fe69ade533c3babc1 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Fri, 13 Feb 2026 10:20:32 +0100 Subject: [PATCH 47/48] fix(sampler): clamp multivariate TPE candidates to parameter bounds - Clamp KDE candidates to parameter bounds before evaluating l(x)/g(x), matching the univariate TPE behavior; without this, candidates scored well at out-of-bounds locations but became suboptimal when clamped - Sort HashMap iterations by ParamId before consuming the seeded RNG to eliminate non-deterministic sampling caused by global ParamId counter - Replace wall-clock timing assertion in async concurrency test with atomic max-active counter to avoid CI flakiness - Gate unused HashSet import behind cfg(feature = "async") --- src/sampler/tpe/multivariate/engine.rs | 62 +++++++++++++++++++++----- src/sampler/tpe/multivariate/mod.rs | 37 ++++++++++----- tests/async_tests.rs | 24 +++++++--- tests/sampler/multivariate_tpe.rs | 4 +- tests/stress_tests.rs | 1 + 5 files changed, 98 insertions(+), 30 deletions(-) diff --git a/src/sampler/tpe/multivariate/engine.rs b/src/sampler/tpe/multivariate/engine.rs index 52f5f53..4095329 100644 --- a/src/sampler/tpe/multivariate/engine.rs +++ b/src/sampler/tpe/multivariate/engine.rs @@ -197,7 +197,7 @@ impl MultivariateTpeSampler { return result; } - param_order.sort_by_key(|id| format!("{id}")); + param_order.sort(); // Extract observations, validate, and fit KDEs let good_obs = self.extract_observations(&good, ¶m_order); @@ -215,7 +215,20 @@ impl MultivariateTpeSampler { return result; }; - let selected = self.select_candidate_with_rng(&good_kde, &bad_kde, rng); + // Compute parameter bounds for each dimension so candidates are clamped + let bounds: Vec<(f64, f64)> = param_order + .iter() + .filter_map(|id| { + intersection.get(id).and_then(|dist| match dist { + Distribution::Float(d) => Some((d.low, d.high)), + 
#[allow(clippy::cast_precision_loss)] + Distribution::Int(d) => Some((d.low as f64, d.high as f64)), + Distribution::Categorical(_) => None, + }) + }) + .collect(); + + let selected = self.select_candidate_with_rng(&good_kde, &bad_kde, &bounds, rng); // Map selected values to parameter ids for (idx, param_id) in param_order.iter().enumerate() { @@ -273,23 +286,38 @@ impl MultivariateTpeSampler { } } + /// Clamps each dimension of a candidate to the corresponding parameter bounds. + fn clamp_candidate(candidate: &mut [f64], bounds: &[(f64, f64)]) { + for (val, &(lo, hi)) in candidate.iter_mut().zip(bounds.iter()) { + *val = val.clamp(lo, hi); + } + } + /// Selects the best candidate from a set of samples using the joint acquisition function. /// /// This method implements the core TPE selection criterion: it generates candidates /// from the "good" KDE (l(x)) and selects the one that maximizes the ratio l(x)/g(x), /// which is equivalent to maximizing `log(l(x)) - log(g(x))`. + /// + /// Candidates are clamped to parameter bounds before evaluation so the acquisition + /// function scores the values that will actually be used. #[must_use] #[cfg(test)] pub(crate) fn select_candidate( &self, good_kde: &crate::kde::MultivariateKDE, bad_kde: &crate::kde::MultivariateKDE, + bounds: &[(f64, f64)], ) -> Vec<f64> { let mut rng = self.rng.lock(); - // Generate candidates from the good distribution + // Generate candidates from the good distribution, clamped to bounds let candidates: Vec<Vec<f64>> = (0..self.n_ei_candidates) - .map(|_| good_kde.sample(&mut rng)) + .map(|_| { + let mut c = good_kde.sample(&mut rng); + Self::clamp_candidate(&mut c, bounds); + c + }) .collect(); // Compute log(l(x)) - log(g(x)) for each candidate @@ -321,15 +349,21 @@ impl MultivariateTpeSampler { /// Selects the best candidate using an external RNG. /// /// This variant accepts an external RNG, used when the caller already holds the lock. 
+ /// Candidates are clamped to parameter bounds before evaluation. pub(crate) fn select_candidate_with_rng( &self, good_kde: &crate::kde::MultivariateKDE, bad_kde: &crate::kde::MultivariateKDE, + bounds: &[(f64, f64)], rng: &mut fastrand::Rng, ) -> Vec<f64> { - // Generate candidates from the good distribution + // Generate candidates from the good distribution, clamped to bounds let candidates: Vec<Vec<f64>> = (0..self.n_ei_candidates) - .map(|_| good_kde.sample(rng)) + .map(|_| { + let mut c = good_kde.sample(rng); + Self::clamp_candidate(&mut c, bounds); + c + }) .collect(); // Compute log(l(x)) - log(g(x)) for each candidate @@ -367,8 +401,8 @@ impl MultivariateTpeSampler { result: &mut HashMap<ParamId, ParamValue>, rng: &mut fastrand::Rng, ) { - // Identify parameters not in result (and not in intersection) - let missing_params: Vec<(&ParamId, &Distribution)> = search_space + // Identify parameters not in result, sorted for deterministic RNG consumption + let mut missing_params: Vec<(&ParamId, &Distribution)> = search_space .iter() .filter(|(id, _)| !result.contains_key(id)) .collect(); @@ -377,6 +411,8 @@ impl MultivariateTpeSampler { return; } + missing_params.sort_by_key(|(id, _)| *id); + // Split trials for independent sampling let (good_trials, bad_trials) = self.split_trials(&history.iter().collect::<Vec<_>>()); @@ -390,6 +426,7 @@ impl MultivariateTpeSampler { /// Samples all parameters using independent TPE sampling. /// /// This is used as a complete fallback when no intersection search space exists. + /// Parameters are sorted by `ParamId` for deterministic RNG consumption order. 
#[cfg(test)] pub(crate) fn sample_all_independent( &self, @@ -402,7 +439,9 @@ impl MultivariateTpeSampler { let mut rng = self.rng.lock(); let mut result = HashMap::new(); - for (param_id, dist) in search_space { + let mut sorted: Vec<_> = search_space.iter().collect(); + sorted.sort_by_key(|(id, _)| *id); + for (param_id, dist) in sorted { let value = self.sample_independent_tpe(*param_id, dist, &good_trials, &bad_trials, &mut rng); result.insert(*param_id, value); @@ -414,6 +453,7 @@ impl MultivariateTpeSampler { /// Samples all parameters using independent TPE sampling with an external RNG. /// /// This variant accepts an external RNG, used when the caller already holds the lock. + /// Parameters are sorted by `ParamId` for deterministic RNG consumption order. pub(crate) fn sample_all_independent_with_rng( &self, search_space: &HashMap<ParamId, Distribution>, @@ -425,7 +465,9 @@ impl MultivariateTpeSampler { let mut result = HashMap::new(); - for (param_id, dist) in search_space { + let mut sorted: Vec<_> = search_space.iter().collect(); + sorted.sort_by_key(|(id, _)| *id); + for (param_id, dist) in sorted { let value = self.sample_independent_tpe(*param_id, dist, &good_trials, &bad_trials, rng); result.insert(*param_id, value); diff --git a/src/sampler/tpe/multivariate/mod.rs b/src/sampler/tpe/multivariate/mod.rs index 4c452bc..05b2cd0 100644 --- a/src/sampler/tpe/multivariate/mod.rs +++ b/src/sampler/tpe/multivariate/mod.rs @@ -322,14 +322,18 @@ impl MultivariateTpeSampler { /// Samples all parameters uniformly at random. /// /// This is a fallback method used when multivariate TPE cannot be applied. + /// Parameters are sorted by `ParamId` to ensure deterministic RNG consumption + /// order when using a seeded sampler. 
#[allow(clippy::unused_self)] fn sample_all_uniform( &self, search_space: &HashMap<ParamId, Distribution>, rng: &mut fastrand::Rng, ) -> HashMap<ParamId, ParamValue> { - search_space - .iter() + let mut sorted: Vec<_> = search_space.iter().collect(); + sorted.sort_by_key(|(id, _)| *id); + sorted + .into_iter() .map(|(id, dist)| (*id, crate::sampler::common::sample_random(rng, dist))) .collect() } @@ -2781,8 +2785,9 @@ mod tests { let good_kde = MultivariateKDE::new(good_samples).unwrap(); let bad_kde = MultivariateKDE::new(bad_samples).unwrap(); + let bounds = &[(0.0, 1.0), (0.0, 1.0)]; - let selected = sampler.select_candidate(&good_kde, &bad_kde); + let selected = sampler.select_candidate(&good_kde, &bad_kde, bounds); // The selected candidate should have 2 dimensions assert_eq!(selected.len(), 2); @@ -2822,8 +2827,9 @@ mod tests { let good_kde = MultivariateKDE::new(good_samples).unwrap(); let bad_kde = MultivariateKDE::new(bad_samples).unwrap(); + let bounds = &[(-1.0, 2.0), (-1.0, 2.0), (-1.0, 2.0)]; - let selected = sampler.select_candidate(&good_kde, &bad_kde); + let selected = sampler.select_candidate(&good_kde, &bad_kde, bounds); assert_eq!(selected.len(), 3); } @@ -2841,8 +2847,9 @@ mod tests { let good_kde = MultivariateKDE::new(good_samples).unwrap(); let bad_kde = MultivariateKDE::new(bad_samples).unwrap(); + let bounds = &[(0.0, 10.0)]; - let selected = sampler.select_candidate(&good_kde, &bad_kde); + let selected = sampler.select_candidate(&good_kde, &bad_kde, bounds); assert_eq!(selected.len(), 1); // Selected value should be closer to the good region @@ -2873,14 +2880,15 @@ mod tests { let good_kde = MultivariateKDE::new(good_samples.clone()).unwrap(); let bad_kde = MultivariateKDE::new(bad_samples.clone()).unwrap(); + let bounds = &[(0.0, 10.0), (0.0, 10.0)]; - let selected1 = sampler1.select_candidate(&good_kde, &bad_kde); + let selected1 = sampler1.select_candidate(&good_kde, &bad_kde, bounds); // Need to recreate KDEs for second sampler since 
we consumed them let good_kde2 = MultivariateKDE::new(good_samples).unwrap(); let bad_kde2 = MultivariateKDE::new(bad_samples).unwrap(); - let selected2 = sampler2.select_candidate(&good_kde2, &bad_kde2); + let selected2 = sampler2.select_candidate(&good_kde2, &bad_kde2, bounds); // With same seed, should get same result assert!( @@ -2911,8 +2919,9 @@ mod tests { let good_kde = MultivariateKDE::new(good_samples).unwrap(); let bad_kde = MultivariateKDE::new(bad_samples).unwrap(); + let bounds = &[(-10.0, 10.0), (-10.0, 10.0)]; - let selected = sampler.select_candidate(&good_kde, &bad_kde); + let selected = sampler.select_candidate(&good_kde, &bad_kde, bounds); assert_eq!(selected.len(), 2); } @@ -2942,8 +2951,9 @@ mod tests { let good_kde = MultivariateKDE::new(good_samples).unwrap(); let bad_kde = MultivariateKDE::new(bad_samples).unwrap(); + let bounds = &[(-5.0, 15.0), (-5.0, 15.0)]; - let selected = sampler.select_candidate(&good_kde, &bad_kde); + let selected = sampler.select_candidate(&good_kde, &bad_kde, bounds); // With more candidates, should definitely find a point in the good region assert!( @@ -2981,8 +2991,9 @@ mod tests { let good_kde = MultivariateKDE::new(good_samples).unwrap(); let bad_kde = MultivariateKDE::new(bad_samples).unwrap(); + let bounds = &[(-5.0, 5.0), (-5.0, 5.0)]; - let selected = sampler.select_candidate(&good_kde, &bad_kde); + let selected = sampler.select_candidate(&good_kde, &bad_kde, bounds); // Should still return a valid point assert_eq!(selected.len(), 2); @@ -3017,8 +3028,9 @@ mod tests { let good_kde = MultivariateKDE::new(good_samples).unwrap(); let bad_kde = MultivariateKDE::new(bad_samples).unwrap(); + let bounds = &[(-5.0, 15.0); 5]; - let selected = sampler.select_candidate(&good_kde, &bad_kde); + let selected = sampler.select_candidate(&good_kde, &bad_kde, bounds); assert_eq!(selected.len(), 5); @@ -3091,7 +3103,8 @@ mod tests { let bad_kde = MultivariateKDE::new(bad_obs).unwrap(); // Select candidate - let selected = 
sampler.select_candidate(&good_kde, &bad_kde); + let bounds = &[(0.0, 10.0), (0.0, 10.0)]; + let selected = sampler.select_candidate(&good_kde, &bad_kde, bounds); assert_eq!(selected.len(), 2); diff --git a/tests/async_tests.rs b/tests/async_tests.rs index a5c271f..f82be17 100644 --- a/tests/async_tests.rs +++ b/tests/async_tests.rs @@ -197,27 +197,39 @@ async fn test_optimize_parallel_single_concurrency() { #[tokio::test] async fn test_parallel_executes_concurrently() { + use std::sync::atomic::{AtomicUsize, Ordering}; + let sampler = RandomSampler::with_seed(42); let study: Study<f64> = Study::with_sampler(Direction::Minimize, sampler); let x_param = FloatParam::new(0.0, 10.0); + let active = Arc::new(AtomicUsize::new(0)); + let max_active = Arc::new(AtomicUsize::new(0)); + + let active_c = Arc::clone(&active); + let max_active_c = Arc::clone(&max_active); - let start = tokio::time::Instant::now(); study .optimize_parallel(4, 4, move |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; - std::thread::sleep(std::time::Duration::from_millis(100)); + + let current = active_c.fetch_add(1, Ordering::SeqCst) + 1; + max_active_c.fetch_max(current, Ordering::SeqCst); + + std::thread::sleep(std::time::Duration::from_millis(50)); + + active_c.fetch_sub(1, Ordering::SeqCst); Ok::<_, Error>(x) }) .await .expect("parallel optimization should succeed"); - let elapsed = start.elapsed(); assert_eq!(study.n_trials(), 4); - // Sequential would take ~400ms; parallel with concurrency=4 should be ~100ms + let max = max_active.load(Ordering::SeqCst); + // With 4 trials and concurrency=4, all should run concurrently assert!( - elapsed < std::time::Duration::from_millis(350), - "expected parallel execution under 350ms, took {elapsed:?}" + max >= 2, + "expected at least 2 concurrent workers, but max was {max}" ); } diff --git a/tests/sampler/multivariate_tpe.rs b/tests/sampler/multivariate_tpe.rs index a84ca34..78d1718 100644 --- a/tests/sampler/multivariate_tpe.rs +++ 
b/tests/sampler/multivariate_tpe.rs @@ -45,7 +45,7 @@ fn test_multivariate_tpe_rosenbrock_finds_good_solution() { let sampler = MultivariateTpeSampler::builder() .seed(42) .n_startup_trials(10) - .n_ei_candidates(24) + .n_ei_candidates(48) .build() .unwrap(); @@ -55,7 +55,7 @@ fn test_multivariate_tpe_rosenbrock_finds_good_solution() { let y_param = FloatParam::new(-2.0, 4.0); study - .optimize(100, |trial: &mut optimizer::Trial| { + .optimize(200, |trial: &mut optimizer::Trial| { let x = x_param.suggest(trial)?; let y = y_param.suggest(trial)?; Ok::<_, Error>(rosenbrock(x, y)) diff --git a/tests/stress_tests.rs b/tests/stress_tests.rs index b652b90..48df70d 100644 --- a/tests/stress_tests.rs +++ b/tests/stress_tests.rs @@ -3,6 +3,7 @@ //! All tests are `#[ignore]`-gated so they don't run in normal CI. //! Run with: `cargo test --features async -- --ignored` +#[cfg(feature = "async")] use std::collections::HashSet; use optimizer::parameter::{FloatParam, Parameter}; From 93d7a63057533d456488475f6bc53ba4c8d518e5 Mon Sep 17 00:00:00 2001 From: Manuel Raimann <raimannma@outlook.de> Date: Fri, 13 Feb 2026 10:04:11 +0100 Subject: [PATCH 48/48] chore: release v1.0.0 --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 0d0d453..8ca9a69 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = ["optimizer-derive"] [package] name = "optimizer" -version = "0.9.1" +version = "1.0.0" edition = "2024" rust-version = "1.89" license = "MIT"