From 23dfc87d473780ffdcc0aaaaaaf3ae5b796d7141 Mon Sep 17 00:00:00 2001 From: Manuel Raimann Date: Fri, 13 Feb 2026 09:55:32 +0100 Subject: [PATCH 1/3] fix: address 18 bugs found during codebase audit High severity: - TPE/MOTPE: match parameters by exact distribution equality instead of flat-mapping over all param values, preventing cross-parameter mixing - MultivariateTpeSampler: find_matching_param now uses search space distributions for exact matching instead of type+range heuristic - JournalStorage: write_to_file no longer advances file_offset (left to refresh), both operations serialized under single io_lock mutex, refresh uses fetch_max and deduplicates by trial ID Medium severity: - NSGA-III: use actual Pareto front ranks for tournament selection instead of artificial cyclic indices - sample_random: apply step quantization after log-scale sampling - internal_bounds: return None for non-positive log-scale bounds - SobolSampler: use per-trial dimension HashMap for concurrent safety - JournalStorage refresh: protect with io_lock mutex, use fetch_max - n_trials(): filter by TrialState::Complete as documented - FloatParam: reject NaN/Infinity in validate() - Pruners: assert n_min_trials >= 1, guard compute_percentile on empty - Visualization: escape_js for importance chart parameter names Low severity: - save(): use peek_next_trial_id() from Storage trait - csv_escape: handle carriage return per RFC 4180 - from_internal: use saturating arithmetic for stepped Int distributions - BoolParam: bounds-check categorical index < 2 - min_max: skip NaN values with safe fallback --- src/parameter.rs | 42 +++++++++++- src/pruner/median.rs | 5 ++ src/pruner/percentile.rs | 6 ++ src/sampler/common.rs | 24 ++++++- src/sampler/motpe.rs | 99 ++++++++++++++++++++--------- src/sampler/nsga3.rs | 31 +++++---- src/sampler/sobol.rs | 26 +++----- src/sampler/tpe/multivariate/mod.rs | 92 +++++++++++++++------------ src/sampler/tpe/sampler.rs | 99 ++++++++++++++++++++--------- src/storage/journal.rs | 44 ++++++++----- src/storage/memory.rs | 4 ++ src/storage/mod.rs | 7 ++ src/study/export.rs | 2 +- src/study/mod.rs | 10 ++- src/study/persistence.rs | 2 +- src/visualization.rs | 12 +++- 16 files changed, 350 insertions(+), 155 deletions(-) diff --git a/src/parameter.rs b/src/parameter.rs index 74782f8..b22a1c8 100644 --- a/src/parameter.rs +++ b/src/parameter.rs @@ -255,6 +255,12 @@ impl Parameter for FloatParam { } fn validate(&self) -> Result<()> { + if !self.low.is_finite() || !self.high.is_finite() { + return Err(Error::InvalidBounds { + low: self.low, + high: self.high, + }); + } if self.low > self.high { return Err(Error::InvalidBounds { low: self.low, @@ -265,7 +271,7 @@ impl Parameter for FloatParam { return Err(Error::InvalidLogBounds); } if let Some(step) = self.step - && step <= 0.0 + && (!step.is_finite() || step <= 0.0) { return Err(Error::InvalidStep); } @@ -550,7 +556,8 @@ impl Parameter for BoolParam { fn cast_param_value(&self, param_value: &ParamValue) -> Result { match param_value { - ParamValue::Categorical(index) => Ok(*index != 0), + ParamValue::Categorical(index) if *index < 2 => Ok(*index != 0), + ParamValue::Categorical(_) => Err(Error::Internal("bool index out of bounds")), _ => Err(Error::Internal( "Categorical distribution should return Categorical value", )), @@ -789,6 +796,30 @@ mod tests { assert!(param.validate().is_err()); } + #[test] + fn float_param_validate_nan() { + assert!(FloatParam::new(f64::NAN, 1.0).validate().is_err()); + assert!(FloatParam::new(0.0, 
f64::NAN).validate().is_err()); + assert!(FloatParam::new(f64::NAN, f64::NAN).validate().is_err()); + } + + #[test] + fn float_param_validate_infinity() { + assert!(FloatParam::new(f64::INFINITY, 1.0).validate().is_err()); + assert!(FloatParam::new(0.0, f64::NEG_INFINITY).validate().is_err()); + } + + #[test] + fn float_param_validate_nan_step() { + assert!(FloatParam::new(0.0, 1.0).step(f64::NAN).validate().is_err()); + assert!( + FloatParam::new(0.0, 1.0) + .step(f64::INFINITY) + .validate() + .is_err() + ); + } + #[test] #[allow(clippy::float_cmp)] fn float_param_cast_param_value() { @@ -920,6 +951,13 @@ mod tests { assert!(param.cast_param_value(&ParamValue::Float(1.0)).is_err()); } + #[test] + fn bool_param_cast_out_of_bounds() { + let param = BoolParam::new(); + assert!(param.cast_param_value(&ParamValue::Categorical(2)).is_err()); + assert!(param.cast_param_value(&ParamValue::Categorical(5)).is_err()); + } + #[derive(Clone, Debug, PartialEq)] enum TestEnum { A, diff --git a/src/pruner/median.rs b/src/pruner/median.rs index 0e97c8f..5451e66 100644 --- a/src/pruner/median.rs +++ b/src/pruner/median.rs @@ -89,8 +89,13 @@ impl MedianPruner { } /// Set the minimum number of completed trials required before pruning. + /// + /// # Panics + /// + /// Panics if `n` is 0. #[must_use] pub fn n_min_trials(mut self, n: usize) -> Self { + assert!(n >= 1, "n_min_trials must be >= 1, got {n}"); self.n_min_trials = n; self } diff --git a/src/pruner/percentile.rs b/src/pruner/percentile.rs index f12df01..5d6e574 100644 --- a/src/pruner/percentile.rs +++ b/src/pruner/percentile.rs @@ -95,8 +95,13 @@ impl PercentilePruner { } /// Set the minimum number of completed trials required before pruning. + /// + /// # Panics + /// + /// Panics if `n` is 0. #[must_use] pub fn n_min_trials(mut self, n: usize) -> Self { + assert!(n >= 1, "n_min_trials must be >= 1, got {n}"); self.n_min_trials = n; self } @@ -157,6 +162,7 @@ impl Pruner for PercentilePruner { clippy::cast_sign_loss )] pub(crate) fn compute_percentile(values: &mut [f64], percentile: f64) -> f64 { + assert!(!values.is_empty(), "compute_percentile: empty input"); values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal)); let len = values.len(); if len == 1 { diff --git a/src/sampler/common.rs b/src/sampler/common.rs index a2640bd..293b96b 100644 --- a/src/sampler/common.rs +++ b/src/sampler/common.rs @@ -10,6 +10,9 @@ pub(crate) fn internal_bounds(distribution: &Distribution) -> Option<(f64, f64)> match distribution { Distribution::Float(d) => { if d.log_scale { + if d.low <= 0.0 || d.high <= 0.0 { + return None; + } Some((d.low.ln(), d.high.ln())) } else { Some((d.low, d.high)) @@ -17,6 +20,9 @@ pub(crate) fn internal_bounds(distribution: &Distribution) -> Option<(f64, f64)> } Distribution::Int(d) => { if d.log_scale { + if d.low < 1 { + return None; + } Some(((d.low as f64).ln(), (d.high as f64).ln())) } else { Some((d.low as f64, d.high as f64)) @@ -44,7 +50,7 @@ pub(crate) fn from_internal(value: f64, distribution: &Distribution) -> ParamVal let v = if d.log_scale { value.exp() } else { value }; let v = if let Some(step) = d.step { let k = ((v - d.low as f64) / step as f64).round() as i64; - d.low + k * step + d.low.saturating_add(k.saturating_mul(step)) } else { v.round() as i64 }; @@ -86,7 +92,13 @@ pub(crate) fn sample_random(rng: &mut fastrand::Rng, distribution: &Distribution let value = if d.log_scale { let log_low = d.low.ln(); let log_high = d.high.ln(); - rng_util::f64_range(rng, log_low, log_high).exp() + 
let v = rng_util::f64_range(rng, log_low, log_high).exp(); + if let Some(step) = d.step { + let k = ((v - d.low) / step).round(); + (d.low + k * step).clamp(d.low, d.high) + } else { + v + } } else if let Some(step) = d.step { let n_steps = ((d.high - d.low) / step).floor() as i64; let k = rng.i64(0..=n_steps); @@ -100,7 +112,13 @@ pub(crate) fn sample_random(rng: &mut fastrand::Rng, distribution: &Distribution let value = if d.log_scale { let log_low = (d.low as f64).ln(); let log_high = (d.high as f64).ln(); - let raw = rng_util::f64_range(rng, log_low, log_high).exp().round() as i64; + let v = rng_util::f64_range(rng, log_low, log_high).exp(); + let raw = if let Some(step) = d.step { + let k = ((v - d.low as f64) / step as f64).round() as i64; + d.low.saturating_add(k.saturating_mul(step)) + } else { + v.round() as i64 + }; raw.clamp(d.low, d.high) } else if let Some(step) = d.step { let n_steps = (d.high - d.low) / step; diff --git a/src/sampler/motpe.rs b/src/sampler/motpe.rs index 8049dee..749e576 100644 --- a/src/sampler/motpe.rs +++ b/src/sampler/motpe.rs @@ -222,24 +222,37 @@ impl MotpeSampler { bad_trials: &[&MultiObjectiveTrial], rng: &mut fastrand::Rng, ) -> ParamValue { + let target_dist = Distribution::Float(d.clone()); let good_values: Vec = good_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Float(f) => Some(*f), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Float(f) => Some(*f), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&v| v >= d.low && v <= d.high) .collect(); let bad_values: Vec = bad_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Float(f) => Some(*f), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Float(f) => Some(*f), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&v| v >= d.low && v <= d.high) .collect(); if good_values.is_empty() || bad_values.is_empty() { @@ -264,24 +277,37 @@ impl MotpeSampler { bad_trials: &[&MultiObjectiveTrial], rng: &mut fastrand::Rng, ) -> ParamValue { + let target_dist = Distribution::Int(d.clone()); let good_values: Vec = good_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Int(i) => Some(*i), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Int(i) => Some(*i), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&v| v >= d.low && v <= d.high) .collect(); let bad_values: Vec = bad_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Int(i) => Some(*i), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Int(i) => Some(*i), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&v| v >= d.low && v <= d.high) .collect(); if good_values.is_empty() || bad_values.is_empty() { @@ -307,24 +333,37 @@ impl MotpeSampler { bad_trials: &[&MultiObjectiveTrial], rng: &mut fastrand::Rng, ) -> ParamValue { + let target_dist = Distribution::Categorical(d.clone()); let good_indices: Vec = good_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - 
ParamValue::Categorical(i) => Some(*i), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Categorical(i) => Some(*i), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&i| i < d.n_choices) .collect(); let bad_indices: Vec = bad_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Categorical(i) => Some(*i), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Categorical(i) => Some(*i), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&i| i < d.n_choices) .collect(); if good_indices.is_empty() || bad_indices.is_empty() { diff --git a/src/sampler/nsga3.rs b/src/sampler/nsga3.rs index 7ac840c..80c9117 100644 --- a/src/sampler/nsga3.rs +++ b/src/sampler/nsga3.rs @@ -561,7 +561,7 @@ fn nsga3_select( state: &mut Nsga3State, population: &[&MultiObjectiveTrial], directions: &[Direction], -) -> Vec> { +) -> (Vec>, Vec) { let pop_size = state.evo.population_size; let n_obj = directions.len(); @@ -633,12 +633,13 @@ fn nsga3_select( selected.push(state.evo.rng.usize(0..n)); } - selected + let params = selected .iter() .map(|&idx| { extract_trial_params(population[idx], &state.evo.dimensions, &mut state.evo.rng) }) - .collect() + .collect(); + (params, selected) } /// Tournament selection based on rank only (no crowding distance in NSGA-III). @@ -675,26 +676,34 @@ fn nsga3_generate_offspring( initialize_nsga3(state, directions); } - let parents = nsga3_select(state, population, directions); + let (parents, selected_indices) = nsga3_select(state, population, directions); - // Assign ranks for tournament selection + // Assign Pareto front ranks for tournament selection let n_obj = directions.len(); let min_values: Vec> = population .iter() .map(|t| to_minimize_space(&t.values, directions)) .collect(); let fronts = pareto::fast_non_dominated_sort(&min_values, &vec![Direction::Minimize; n_obj]); - let mut rank = vec![0_usize; parents.len()]; + // Build rank lookup for population indices + let mut pop_rank = vec![0_usize; population.len()]; for (front_rank, front) in fronts.iter().enumerate() { for &idx in front { - if idx < rank.len() { - rank[idx] = front_rank; + if idx < pop_rank.len() { + pop_rank[idx] = front_rank; } } } - // Ranks for selected parents (simplified: use index order) - let parent_ranks: Vec = (0..parents.len()) - .map(|i| i % (fronts.len().max(1))) + // Map population ranks to selected parent indices + let parent_ranks: Vec = selected_indices + .iter() + .map(|&idx| { + if idx < pop_rank.len() { + pop_rank[idx] + } else { + 0 + } + }) .collect(); let mut offspring = Vec::with_capacity(pop_size); diff --git a/src/sampler/sobol.rs b/src/sampler/sobol.rs index 2dad0b8..4b15ce4 100644 --- a/src/sampler/sobol.rs +++ b/src/sampler/sobol.rs @@ -44,6 +44,8 @@ //! let study: Study = Study::with_sampler(Direction::Minimize, SobolSampler::with_seed(42)); //! ``` +use std::collections::HashMap; + use parking_lot::Mutex; use sobol_burley::sample; @@ -51,12 +53,10 @@ use crate::distribution::Distribution; use crate::param::ParamValue; use crate::sampler::{CompletedTrial, Sampler}; -/// Internal state for tracking the dimension counter within a trial. +/// Internal state for tracking per-trial dimension counters. struct SobolState { - /// The `trial_id` of the current trial (used to reset dimension counter). 
-    current_trial: u64,
-    /// Next Sobol dimension to use for the current trial.
-    next_dimension: u32,
+    /// Next Sobol dimension for each in-flight trial.
+    dimensions: HashMap<u64, u32>,
 }
 
 /// Quasi-random sampler using Sobol low-discrepancy sequences.
@@ -107,8 +107,7 @@ impl SobolSampler {
         Self {
             seed: seed as u32,
             state: Mutex::new(SobolState {
-                current_trial: u64::MAX,
-                next_dimension: 0,
+                dimensions: HashMap::new(),
             }),
         }
     }
@@ -130,20 +129,15 @@ impl Sampler for SobolSampler {
     ) -> ParamValue {
         let mut state = self.state.lock();
 
-        // Reset dimension counter when a new trial starts.
-        if state.current_trial != trial_id {
-            state.current_trial = trial_id;
-            state.next_dimension = 0;
-        }
-
-        let dimension = state.next_dimension;
-        state.next_dimension = dimension + 1;
+        let dimension = state.dimensions.entry(trial_id).or_insert(0);
+        let dim = *dimension;
+        *dimension = dim + 1;
 
         // Use trial_id as the Sobol sequence index.
         let index = trial_id as u32;
 
         // Generate a quasi-random point in [0, 1).
-        let point = f64::from(sample(index, dimension, self.seed));
+        let point = f64::from(sample(index, dim, self.seed));
 
         map_point_to_distribution(point, distribution)
     }
diff --git a/src/sampler/tpe/multivariate/mod.rs b/src/sampler/tpe/multivariate/mod.rs
index f3738f9..4c452bc 100644
--- a/src/sampler/tpe/multivariate/mod.rs
+++ b/src/sampler/tpe/multivariate/mod.rs
@@ -216,6 +216,13 @@ pub enum ConstantLiarStrategy {
 ///
 /// assert!(study.best_value().unwrap() < 1.0);
 /// ```
+/// Cached joint sample for a specific trial.
+struct JointSampleCache {
+    trial_id: u64,
+    search_space: HashMap<ParamId, Distribution>,
+    sample: HashMap<ParamId, ParamValue>,
+}
+
 pub struct MultivariateTpeSampler {
     /// Strategy for computing the gamma quantile.
     gamma_strategy: Arc,
@@ -230,8 +237,7 @@ pub struct MultivariateTpeSampler {
     /// Thread-safe RNG for sampling.
     rng: Mutex<fastrand::Rng>,
     /// Cache for joint samples to maintain consistency across parameters within the same trial.
-    /// The tuple contains (`trial_id`, cached joint sample).
-    joint_sample_cache: Mutex<Option<(u64, HashMap<ParamId, ParamValue>)>>,
+    joint_sample_cache: Mutex<Option<JointSampleCache>>,
 }
 
 impl MultivariateTpeSampler {
@@ -453,11 +459,13 @@ impl Sampler for MultivariateTpeSampler {
         // Check if we have a cached joint sample for this trial
         {
             let cache = self.joint_sample_cache.lock();
-            if let Some((cached_trial_id, ref cached_sample)) = *cache
-                && cached_trial_id == trial_id
+            if let Some(ref c) = *cache
+                && c.trial_id == trial_id
             {
                 // Try to find a matching parameter from the cached sample
-                if let Some(value) = Self::find_matching_param(distribution, cached_sample) {
+                if let Some(value) =
+                    Self::find_matching_param(distribution, &c.search_space, &c.sample)
+                {
                     return value;
                 }
             }
@@ -470,13 +478,18 @@ impl Sampler for MultivariateTpeSampler {
         let joint_sample = self.sample_joint(&search_space, history);
 
         // Cache the joint sample for this trial
+        let result = Self::find_matching_param(distribution, &search_space, &joint_sample);
         {
             let mut cache = self.joint_sample_cache.lock();
-            *cache = Some((trial_id, joint_sample.clone()));
+            *cache = Some(JointSampleCache {
+                trial_id,
+                search_space,
+                sample: joint_sample,
+            });
         }
 
         // Find and return the value for the requested distribution
-        Self::find_matching_param(distribution, &joint_sample).unwrap_or_else(|| {
+        result.unwrap_or_else(|| {
            // Fallback to uniform sampling if no match found
             let mut rng = self.rng.lock();
             crate::sampler::common::sample_random(&mut rng, distribution)
@@ -485,33 +498,18 @@ impl MultivariateTpeSampler {
 }
 
 impl MultivariateTpeSampler {
-    /// Finds a matching parameter value from the cached sample based on distribution.
-    ///
-    /// This is an associated function that matches parameters by comparing
-    /// distribution bounds and types.
+    /// Finds a matching parameter value from the cached sample based on exact
+    /// distribution equality.
fn find_matching_param( distribution: &Distribution, + search_space: &HashMap, cached_sample: &HashMap, ) -> Option { - // Match by distribution type and value compatibility - for value in cached_sample.values() { - match (distribution, value) { - (Distribution::Float(d), ParamValue::Float(v)) => { - if *v >= d.low && *v <= d.high { - return Some(value.clone()); - } - } - (Distribution::Int(d), ParamValue::Int(v)) => { - if *v >= d.low && *v <= d.high { - return Some(value.clone()); - } - } - (Distribution::Categorical(d), ParamValue::Categorical(v)) => { - if *v < d.n_choices { - return Some(value.clone()); - } - } - _ => {} + for (id, dist) in search_space { + if dist == distribution + && let Some(value) = cached_sample.get(id) + { + return Some(value.clone()); } } None @@ -4213,12 +4211,15 @@ mod tests { fn test_find_matching_param_float() { let x_id = ParamId::new(); let y_id = ParamId::new(); + let dist = float_dist(0.0, 1.0); + let mut space = HashMap::new(); + space.insert(x_id, dist.clone()); + space.insert(y_id, float_dist(2.0, 3.0)); let mut cached = HashMap::new(); cached.insert(x_id, ParamValue::Float(0.5)); - cached.insert(y_id, ParamValue::Float(0.8)); + cached.insert(y_id, ParamValue::Float(2.8)); - let dist = float_dist(0.0, 1.0); - let result = MultivariateTpeSampler::find_matching_param(&dist, &cached); + let result = MultivariateTpeSampler::find_matching_param(&dist, &space, &cached); assert!(result.is_some()); if let Some(ParamValue::Float(v)) = result { @@ -4229,11 +4230,13 @@ mod tests { #[test] fn test_find_matching_param_int() { let n_id = ParamId::new(); + let dist = int_dist(0, 10); + let mut space = HashMap::new(); + space.insert(n_id, dist.clone()); let mut cached = HashMap::new(); cached.insert(n_id, ParamValue::Int(5)); - let dist = int_dist(0, 10); - let result = MultivariateTpeSampler::find_matching_param(&dist, &cached); + let result = MultivariateTpeSampler::find_matching_param(&dist, &space, &cached); assert!(result.is_some()); if let Some(ParamValue::Int(v)) = result { @@ -4244,11 +4247,13 @@ mod tests { #[test] fn test_find_matching_param_categorical() { let choice_id = ParamId::new(); + let dist = categorical_dist(3); + let mut space = HashMap::new(); + space.insert(choice_id, dist.clone()); let mut cached = HashMap::new(); cached.insert(choice_id, ParamValue::Categorical(1)); - let dist = categorical_dist(3); - let result = MultivariateTpeSampler::find_matching_param(&dist, &cached); + let result = MultivariateTpeSampler::find_matching_param(&dist, &space, &cached); assert!(result.is_some()); if let Some(ParamValue::Categorical(v)) = result { @@ -4259,12 +4264,14 @@ mod tests { #[test] fn test_find_matching_param_no_match() { let x_id = ParamId::new(); + let mut space = HashMap::new(); + space.insert(x_id, float_dist(0.0, 1.0)); let mut cached = HashMap::new(); cached.insert(x_id, ParamValue::Float(0.5)); - // Looking for Int, but only Float in cache + // Looking for Int, but only Float in search space let dist = int_dist(0, 10); - let result = MultivariateTpeSampler::find_matching_param(&dist, &cached); + let result = MultivariateTpeSampler::find_matching_param(&dist, &space, &cached); assert!(result.is_none()); } @@ -4272,11 +4279,14 @@ mod tests { #[test] fn test_find_matching_param_out_of_bounds() { let x_id = ParamId::new(); + // Search space has a different distribution than what we're looking for + let mut space = HashMap::new(); + space.insert(x_id, float_dist(0.0, 10.0)); let mut cached = HashMap::new(); - cached.insert(x_id, 
ParamValue::Float(5.0)); // Out of bounds + cached.insert(x_id, ParamValue::Float(5.0)); let dist = float_dist(0.0, 1.0); - let result = MultivariateTpeSampler::find_matching_param(&dist, &cached); + let result = MultivariateTpeSampler::find_matching_param(&dist, &space, &cached); assert!(result.is_none()); } diff --git a/src/sampler/tpe/sampler.rs b/src/sampler/tpe/sampler.rs index 32f3f1a..b4c1c69 100644 --- a/src/sampler/tpe/sampler.rs +++ b/src/sampler/tpe/sampler.rs @@ -650,24 +650,37 @@ impl TpeSampler { bad_trials: &[&CompletedTrial], rng: &mut fastrand::Rng, ) -> ParamValue { + let target_dist = Distribution::Float(d.clone()); let good_values: Vec = good_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Float(f) => Some(*f), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Float(f) => Some(*f), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&v| v >= d.low && v <= d.high) .collect(); let bad_values: Vec = bad_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Float(f) => Some(*f), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Float(f) => Some(*f), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&v| v >= d.low && v <= d.high) .collect(); if good_values.is_empty() || bad_values.is_empty() { @@ -692,24 +705,37 @@ impl TpeSampler { bad_trials: &[&CompletedTrial], rng: &mut fastrand::Rng, ) -> ParamValue { + let target_dist = Distribution::Int(d.clone()); let good_values: Vec = good_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Int(i) => Some(*i), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Int(i) => Some(*i), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&v| v >= d.low && v <= d.high) .collect(); let bad_values: Vec = bad_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Int(i) => Some(*i), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Int(i) => Some(*i), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&v| v >= d.low && v <= d.high) .collect(); if good_values.is_empty() || bad_values.is_empty() { @@ -735,24 +761,37 @@ impl TpeSampler { bad_trials: &[&CompletedTrial], rng: &mut fastrand::Rng, ) -> ParamValue { + let target_dist = Distribution::Categorical(d.clone()); let good_indices: Vec = good_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Categorical(i) => Some(*i), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + ParamValue::Categorical(i) => Some(*i), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&i| i < d.n_choices) .collect(); let bad_indices: Vec = bad_trials .iter() - .flat_map(|t| t.params.values()) - .filter_map(|v| match v { - ParamValue::Categorical(i) => Some(*i), - _ => None, + .filter_map(|t| { + t.distributions.iter().find_map(|(id, dist)| { + if *dist == target_dist { + t.params.get(id).and_then(|v| match v { + 
ParamValue::Categorical(i) => Some(*i), + _ => None, + }) + } else { + None + } + }) }) - .filter(|&i| i < d.n_choices) .collect(); if good_indices.is_empty() || bad_indices.is_empty() { diff --git a/src/storage/journal.rs b/src/storage/journal.rs index 01b7ed8..8802972 100644 --- a/src/storage/journal.rs +++ b/src/storage/journal.rs @@ -130,8 +130,8 @@ use crate::sampler::CompletedTrial; pub struct JournalStorage { memory: MemoryStorage, path: PathBuf, - /// Serialise in-process writes so we only hold the file lock briefly. - write_lock: Mutex<()>, + /// Serialise in-process writes and refreshes so they don't race. + io_lock: Mutex<()>, /// Byte offset of last-read position for incremental refresh. file_offset: AtomicU64, _marker: PhantomData, @@ -156,7 +156,7 @@ impl JournalStorage { Self { memory: MemoryStorage::new(), path, - write_lock: Mutex::new(()), + io_lock: Mutex::new(()), file_offset: AtomicU64::new(0), _marker: PhantomData, } @@ -180,15 +180,19 @@ impl JournalStorage { Ok(Self { memory: MemoryStorage::with_trials(trials), path, - write_lock: Mutex::new(()), + io_lock: Mutex::new(()), file_offset: AtomicU64::new(offset), _marker: PhantomData, }) } /// Append a single trial to the JSONL file (best-effort). + /// + /// Does **not** advance `file_offset` — that is left to `refresh` + /// so that externally-written data between the old offset and our + /// write is never skipped. fn write_to_file(&self, trial: &CompletedTrial) -> crate::Result<()> { - let _guard = self.write_lock.lock(); + let _guard = self.io_lock.lock(); let mut file = OpenOptions::new() .create(true) @@ -211,11 +215,6 @@ impl JournalStorage { file.sync_data() .map_err(|e| crate::Error::Storage(e.to_string()))?; - let pos = file - .stream_position() - .map_err(|e| crate::Error::Storage(e.to_string()))?; - self.file_offset.store(pos, Ordering::SeqCst); - file.unlock() .map_err(|e| crate::Error::Storage(e.to_string()))?; @@ -238,7 +237,13 @@ impl Storage for JournalStorag self.memory.next_trial_id() } + fn peek_next_trial_id(&self) -> u64 { + self.memory.peek_next_trial_id() + } + fn refresh(&self) -> bool { + let _guard = self.io_lock.lock(); + let Ok(file) = File::open(&self.path) else { return false; }; @@ -275,6 +280,7 @@ impl Storage for JournalStorag let _ = file.unlock(); let bytes_read = buf.len() as u64; + let new_offset = offset + bytes_read; let mut new_trials = Vec::new(); for line in buf.lines() { @@ -293,19 +299,23 @@ impl Storage for JournalStorag } if new_trials.is_empty() { - self.file_offset - .store(offset + bytes_read, Ordering::SeqCst); + self.file_offset.fetch_max(new_offset, Ordering::SeqCst); return false; } - let mut guard = self.memory.trials_arc().write(); + let mut mem_guard = self.memory.trials_arc().write(); + + // Deduplicate: only add trials whose IDs are not already in memory. 
+ let existing_ids: std::collections::HashSet = mem_guard.iter().map(|t| t.id).collect(); + new_trials.retain(|t| !existing_ids.contains(&t.id)); + if let Some(max_id) = new_trials.iter().map(|t| t.id).max() { self.memory.bump_next_id(max_id + 1); } - guard.extend(new_trials); - self.file_offset - .store(offset + bytes_read, Ordering::SeqCst); - true + let added = !new_trials.is_empty(); + mem_guard.extend(new_trials); + self.file_offset.fetch_max(new_offset, Ordering::SeqCst); + added } } diff --git a/src/storage/memory.rs b/src/storage/memory.rs index eeea613..8f91adb 100644 --- a/src/storage/memory.rs +++ b/src/storage/memory.rs @@ -104,4 +104,8 @@ impl Storage for MemoryStorage { fn next_trial_id(&self) -> u64 { self.next_id.fetch_add(1, Ordering::SeqCst) } + + fn peek_next_trial_id(&self) -> u64 { + self.next_id.load(Ordering::SeqCst) + } } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index d0e32a0..5b683c7 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -77,6 +77,13 @@ pub trait Storage: Send + Sync { /// calls always produce distinct IDs. fn next_trial_id(&self) -> u64; + /// Return the current value of the next-trial-ID counter without incrementing. + /// + /// This is used for persistence (e.g. `Study::save`) to capture the + /// counter's exact position, including IDs assigned to failed trials + /// that are not stored. + fn peek_next_trial_id(&self) -> u64; + /// Reload from an external source (e.g. a file written by another /// process). Return `true` if the in-memory buffer was updated. /// diff --git a/src/study/export.rs b/src/study/export.rs index 4933982..b02afb8 100644 --- a/src/study/export.rs +++ b/src/study/export.rs @@ -255,7 +255,7 @@ impl Study { /// Escape a string for CSV output. If the value contains a comma, quote, or /// newline, wrap it in double-quotes and double any embedded quotes. fn csv_escape(s: &str) -> String { - if s.contains(',') || s.contains('"') || s.contains('\n') { + if s.contains(',') || s.contains('"') || s.contains('\n') || s.contains('\r') { format!("\"{}\"", s.replace('"', "\"\"")) } else { s.to_string() diff --git a/src/study/mod.rs b/src/study/mod.rs index 00948c1..4906698 100644 --- a/src/study/mod.rs +++ b/src/study/mod.rs @@ -639,7 +639,8 @@ where /// Return the number of completed trials. /// - /// Failed trials are not counted. + /// Pruned and failed trials are not counted. Use + /// [`n_pruned_trials()`](Self::n_pruned_trials) for the pruned count. /// /// # Examples /// @@ -658,7 +659,12 @@ where /// ``` #[must_use] pub fn n_trials(&self) -> usize { - self.storage.trials_arc().read().len() + self.storage + .trials_arc() + .read() + .iter() + .filter(|t| t.state == TrialState::Complete) + .count() } /// Return the number of pruned trials. 
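
Note on the journal refresh change above: the following is a standalone, illustrative sketch only. `Record` and `apply_refresh` are simplified stand-ins, not the crate's `CompletedTrial` or the real body of `refresh`. The sketch shows the intent of the fix: the read offset only ever moves forward via `fetch_max`, and records whose trial IDs are already in memory are skipped, so a repeated or racing refresh can neither rewind the cursor nor duplicate trials.

    use std::collections::HashSet;
    use std::sync::atomic::{AtomicU64, Ordering};

    // Stand-in for a journal record; the real storage keeps CompletedTrial values.
    struct Record {
        id: u64,
    }

    // One refresh step: drop records already present, then advance the offset
    // monotonically so a slower racer cannot move it backwards.
    fn apply_refresh(
        offset: &AtomicU64,
        in_memory: &mut Vec<Record>,
        bytes_consumed_up_to: u64,
        mut parsed: Vec<Record>,
    ) -> bool {
        let existing: HashSet<u64> = in_memory.iter().map(|r| r.id).collect();
        parsed.retain(|r| !existing.contains(&r.id));
        let added = !parsed.is_empty();
        in_memory.extend(parsed);
        offset.fetch_max(bytes_consumed_up_to, Ordering::SeqCst);
        added
    }

    fn main() {
        let offset = AtomicU64::new(0);
        let mut mem = vec![Record { id: 1 }];
        // Re-reading a region that still contains trial 1 adds only trial 2.
        assert!(apply_refresh(&offset, &mut mem, 128, vec![Record { id: 1 }, Record { id: 2 }]));
        assert_eq!(mem.len(), 2);
        // A stale refresh that computed a smaller offset cannot rewind the cursor.
        assert!(!apply_refresh(&offset, &mut mem, 64, vec![]));
        assert_eq!(offset.load(Ordering::SeqCst), 128);
    }
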
diff --git a/src/study/persistence.rs b/src/study/persistence.rs index c9174d6..95b07bc 100644 --- a/src/study/persistence.rs +++ b/src/study/persistence.rs @@ -49,7 +49,7 @@ impl Study { pub fn save(&self, path: impl AsRef) -> std::io::Result<()> { let path = path.as_ref(); let trials = self.trials(); - let next_trial_id = trials.iter().map(|t| t.id).max().map_or(0, |id| id + 1); + let next_trial_id = self.storage.peek_next_trial_id(); let snapshot = StudySnapshot { version: 1, direction: self.direction, diff --git a/src/visualization.rs b/src/visualization.rs index 00eac89..776479b 100644 --- a/src/visualization.rs +++ b/src/visualization.rs @@ -341,7 +341,10 @@ Plotly.newPlot("parcoords", [{{ } fn write_importance_chart(html: &mut String, importance: &[(String, f64)]) { - let names: Vec<_> = importance.iter().map(|(n, _)| format!("\"{n}\"")).collect(); + let names: Vec<_> = importance + .iter() + .map(|(n, _)| format!("\"{}\"", escape_js(n))) + .collect(); let values: Vec = importance.iter().map(|(_, v)| *v).collect(); let _ = write!( @@ -457,6 +460,9 @@ fn min_max(vals: &[f64]) -> (f64, f64) { let mut mn = f64::INFINITY; let mut mx = f64::NEG_INFINITY; for &v in vals { + if v.is_nan() { + continue; + } if v < mn { mn = v; } @@ -464,6 +470,10 @@ fn min_max(vals: &[f64]) -> (f64, f64) { mx = v; } } + // If all values were NaN, return 0.0..1.0 as a safe fallback. + if mn > mx { + return (0.0, 1.0); + } (mn, mx) } From fd8aacfb133afdc6ad7502cf1a7d6a076cbb38de Mon Sep 17 00:00:00 2001 From: Manuel Raimann Date: Fri, 13 Feb 2026 10:00:07 +0100 Subject: [PATCH 2/3] ci: trigger CI on pull requests targeting refactor branch --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1f44f41..2789db2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,7 +4,7 @@ on: push: branches: [main, master] pull_request: - branches: [main, master] + branches: [main, master, refactor] permissions: contents: read From f45f146ef0451b5726a64db080e96499124530e7 Mon Sep 17 00:00:00 2001 From: Manuel Raimann Date: Fri, 13 Feb 2026 10:00:23 +0100 Subject: [PATCH 3/3] ci: trigger CI on pull requests targeting any branch --- .github/workflows/ci.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2789db2..268bc25 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,7 +4,6 @@ on: push: branches: [main, master] pull_request: - branches: [main, master, refactor] permissions: contents: read
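
Illustrative note on PATCH 1/3 (not part of any diff above): a self-contained sketch of why the TPE/MOTPE samplers now match observations by exact distribution equality. `ParamId`, `Dist` and `Trial` below are simplified stand-ins rather than the crate's `Distribution`, `ParamValue` and `CompletedTrial`; the point is only that flat-mapping over all parameter values and filtering by range lets a different parameter with overlapping bounds leak into the model, while exact matching does not.

    use std::collections::HashMap;

    // Simplified stand-ins for the crate's ParamId, Distribution and ParamValue.
    type ParamId = u32;

    #[derive(Clone, PartialEq)]
    enum Dist {
        Float { low: f64, high: f64 },
    }

    // A recorded trial: the distribution each parameter was drawn from, plus its value.
    struct Trial {
        distributions: HashMap<ParamId, Dist>,
        params: HashMap<ParamId, f64>,
    }

    // Old heuristic: any stored value that merely falls inside the target bounds is
    // accepted, so a value from a different parameter with overlapping bounds leaks in.
    fn values_by_range(trials: &[Trial], target: &Dist) -> Vec<f64> {
        let Dist::Float { low, high } = target;
        trials
            .iter()
            .flat_map(|t| t.params.values().copied())
            .filter(|v| v >= low && v <= high)
            .collect()
    }

    // New behaviour: only values recorded under exactly the target distribution are used.
    fn values_by_distribution(trials: &[Trial], target: &Dist) -> Vec<f64> {
        trials
            .iter()
            .filter_map(|t| {
                t.distributions
                    .iter()
                    .find(|(_, d)| *d == target)
                    .and_then(|(id, _)| t.params.get(id).copied())
            })
            .collect()
    }

    fn main() {
        // Two float parameters with overlapping ranges: x in [0, 1], y in [0, 10].
        let x = Dist::Float { low: 0.0, high: 1.0 };
        let y = Dist::Float { low: 0.0, high: 10.0 };
        let trials = vec![Trial {
            distributions: HashMap::from([(0, x.clone()), (1, y)]),
            params: HashMap::from([(0, 0.5), (1, 0.7)]),
        }];

        // The range heuristic mixes y's observation (0.7) into x's model...
        assert_eq!(values_by_range(&trials, &x).len(), 2);
        // ...while exact matching keeps only the value recorded for x.
        assert_eq!(values_by_distribution(&trials, &x), vec![0.5]);
    }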