Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ on:
push:
branches: [main, master]
pull_request:
branches: [main, master]

permissions:
contents: read
Expand Down
42 changes: 40 additions & 2 deletions src/parameter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,12 @@ impl Parameter for FloatParam {
}

fn validate(&self) -> Result<()> {
if !self.low.is_finite() || !self.high.is_finite() {
return Err(Error::InvalidBounds {
low: self.low,
high: self.high,
});
}
if self.low > self.high {
return Err(Error::InvalidBounds {
low: self.low,
Expand All @@ -265,7 +271,7 @@ impl Parameter for FloatParam {
return Err(Error::InvalidLogBounds);
}
if let Some(step) = self.step
&& step <= 0.0
&& (!step.is_finite() || step <= 0.0)
{
return Err(Error::InvalidStep);
}
Expand Down Expand Up @@ -550,7 +556,8 @@ impl Parameter for BoolParam {

fn cast_param_value(&self, param_value: &ParamValue) -> Result<bool> {
match param_value {
ParamValue::Categorical(index) => Ok(*index != 0),
ParamValue::Categorical(index) if *index < 2 => Ok(*index != 0),
ParamValue::Categorical(_) => Err(Error::Internal("bool index out of bounds")),
_ => Err(Error::Internal(
"Categorical distribution should return Categorical value",
)),
Expand Down Expand Up @@ -789,6 +796,30 @@ mod tests {
assert!(param.validate().is_err());
}

#[test]
fn float_param_validate_nan() {
assert!(FloatParam::new(f64::NAN, 1.0).validate().is_err());
assert!(FloatParam::new(0.0, f64::NAN).validate().is_err());
assert!(FloatParam::new(f64::NAN, f64::NAN).validate().is_err());
}

#[test]
fn float_param_validate_infinity() {
assert!(FloatParam::new(f64::INFINITY, 1.0).validate().is_err());
assert!(FloatParam::new(0.0, f64::NEG_INFINITY).validate().is_err());
}

#[test]
fn float_param_validate_nan_step() {
assert!(FloatParam::new(0.0, 1.0).step(f64::NAN).validate().is_err());
assert!(
FloatParam::new(0.0, 1.0)
.step(f64::INFINITY)
.validate()
.is_err()
);
}

#[test]
#[allow(clippy::float_cmp)]
fn float_param_cast_param_value() {
Expand Down Expand Up @@ -920,6 +951,13 @@ mod tests {
assert!(param.cast_param_value(&ParamValue::Float(1.0)).is_err());
}

#[test]
fn bool_param_cast_out_of_bounds() {
let param = BoolParam::new();
assert!(param.cast_param_value(&ParamValue::Categorical(2)).is_err());
assert!(param.cast_param_value(&ParamValue::Categorical(5)).is_err());
}

#[derive(Clone, Debug, PartialEq)]
enum TestEnum {
A,
Expand Down
5 changes: 5 additions & 0 deletions src/pruner/median.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,13 @@ impl MedianPruner {
}

/// Set the minimum number of completed trials required before pruning.
///
/// # Panics
///
/// Panics if `n` is 0.
#[must_use]
pub fn n_min_trials(mut self, n: usize) -> Self {
assert!(n >= 1, "n_min_trials must be >= 1, got {n}");
self.n_min_trials = n;
self
}
Expand Down
6 changes: 6 additions & 0 deletions src/pruner/percentile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,13 @@ impl PercentilePruner {
}

/// Set the minimum number of completed trials required before pruning.
///
/// # Panics
///
/// Panics if `n` is 0.
#[must_use]
pub fn n_min_trials(mut self, n: usize) -> Self {
assert!(n >= 1, "n_min_trials must be >= 1, got {n}");
self.n_min_trials = n;
self
}
Expand Down Expand Up @@ -157,6 +162,7 @@ impl Pruner for PercentilePruner {
clippy::cast_sign_loss
)]
pub(crate) fn compute_percentile(values: &mut [f64], percentile: f64) -> f64 {
assert!(!values.is_empty(), "compute_percentile: empty input");
values.sort_unstable_by(|a, b| a.partial_cmp(b).unwrap_or(core::cmp::Ordering::Equal));
let len = values.len();
if len == 1 {
Expand Down
24 changes: 21 additions & 3 deletions src/sampler/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,19 @@ pub(crate) fn internal_bounds(distribution: &Distribution) -> Option<(f64, f64)>
match distribution {
Distribution::Float(d) => {
if d.log_scale {
if d.low <= 0.0 || d.high <= 0.0 {
return None;
}
Some((d.low.ln(), d.high.ln()))
} else {
Some((d.low, d.high))
}
}
Distribution::Int(d) => {
if d.log_scale {
if d.low < 1 {
return None;
}
Some(((d.low as f64).ln(), (d.high as f64).ln()))
} else {
Some((d.low as f64, d.high as f64))
Expand Down Expand Up @@ -44,7 +50,7 @@ pub(crate) fn from_internal(value: f64, distribution: &Distribution) -> ParamVal
let v = if d.log_scale { value.exp() } else { value };
let v = if let Some(step) = d.step {
let k = ((v - d.low as f64) / step as f64).round() as i64;
d.low + k * step
d.low.saturating_add(k.saturating_mul(step))
} else {
v.round() as i64
};
Expand Down Expand Up @@ -86,7 +92,13 @@ pub(crate) fn sample_random(rng: &mut fastrand::Rng, distribution: &Distribution
let value = if d.log_scale {
let log_low = d.low.ln();
let log_high = d.high.ln();
rng_util::f64_range(rng, log_low, log_high).exp()
let v = rng_util::f64_range(rng, log_low, log_high).exp();
if let Some(step) = d.step {
let k = ((v - d.low) / step).round();
(d.low + k * step).clamp(d.low, d.high)
} else {
v
}
} else if let Some(step) = d.step {
let n_steps = ((d.high - d.low) / step).floor() as i64;
let k = rng.i64(0..=n_steps);
Expand All @@ -100,7 +112,13 @@ pub(crate) fn sample_random(rng: &mut fastrand::Rng, distribution: &Distribution
let value = if d.log_scale {
let log_low = (d.low as f64).ln();
let log_high = (d.high as f64).ln();
let raw = rng_util::f64_range(rng, log_low, log_high).exp().round() as i64;
let v = rng_util::f64_range(rng, log_low, log_high).exp();
let raw = if let Some(step) = d.step {
let k = ((v - d.low as f64) / step as f64).round() as i64;
d.low.saturating_add(k.saturating_mul(step))
} else {
v.round() as i64
};
raw.clamp(d.low, d.high)
} else if let Some(step) = d.step {
let n_steps = (d.high - d.low) / step;
Expand Down
99 changes: 69 additions & 30 deletions src/sampler/motpe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,24 +222,37 @@ impl MotpeSampler {
bad_trials: &[&MultiObjectiveTrial],
rng: &mut fastrand::Rng,
) -> ParamValue {
let target_dist = Distribution::Float(d.clone());
let good_values: Vec<f64> = good_trials
.iter()
.flat_map(|t| t.params.values())
.filter_map(|v| match v {
ParamValue::Float(f) => Some(*f),
_ => None,
.filter_map(|t| {
t.distributions.iter().find_map(|(id, dist)| {
if *dist == target_dist {
t.params.get(id).and_then(|v| match v {
ParamValue::Float(f) => Some(*f),
_ => None,
})
} else {
None
}
})
})
.filter(|&v| v >= d.low && v <= d.high)
.collect();

let bad_values: Vec<f64> = bad_trials
.iter()
.flat_map(|t| t.params.values())
.filter_map(|v| match v {
ParamValue::Float(f) => Some(*f),
_ => None,
.filter_map(|t| {
t.distributions.iter().find_map(|(id, dist)| {
if *dist == target_dist {
t.params.get(id).and_then(|v| match v {
ParamValue::Float(f) => Some(*f),
_ => None,
})
} else {
None
}
})
})
.filter(|&v| v >= d.low && v <= d.high)
.collect();

if good_values.is_empty() || bad_values.is_empty() {
Expand All @@ -264,24 +277,37 @@ impl MotpeSampler {
bad_trials: &[&MultiObjectiveTrial],
rng: &mut fastrand::Rng,
) -> ParamValue {
let target_dist = Distribution::Int(d.clone());
let good_values: Vec<i64> = good_trials
.iter()
.flat_map(|t| t.params.values())
.filter_map(|v| match v {
ParamValue::Int(i) => Some(*i),
_ => None,
.filter_map(|t| {
t.distributions.iter().find_map(|(id, dist)| {
if *dist == target_dist {
t.params.get(id).and_then(|v| match v {
ParamValue::Int(i) => Some(*i),
_ => None,
})
} else {
None
}
})
})
.filter(|&v| v >= d.low && v <= d.high)
.collect();

let bad_values: Vec<i64> = bad_trials
.iter()
.flat_map(|t| t.params.values())
.filter_map(|v| match v {
ParamValue::Int(i) => Some(*i),
_ => None,
.filter_map(|t| {
t.distributions.iter().find_map(|(id, dist)| {
if *dist == target_dist {
t.params.get(id).and_then(|v| match v {
ParamValue::Int(i) => Some(*i),
_ => None,
})
} else {
None
}
})
})
.filter(|&v| v >= d.low && v <= d.high)
.collect();

if good_values.is_empty() || bad_values.is_empty() {
Expand All @@ -307,24 +333,37 @@ impl MotpeSampler {
bad_trials: &[&MultiObjectiveTrial],
rng: &mut fastrand::Rng,
) -> ParamValue {
let target_dist = Distribution::Categorical(d.clone());
let good_indices: Vec<usize> = good_trials
.iter()
.flat_map(|t| t.params.values())
.filter_map(|v| match v {
ParamValue::Categorical(i) => Some(*i),
_ => None,
.filter_map(|t| {
t.distributions.iter().find_map(|(id, dist)| {
if *dist == target_dist {
t.params.get(id).and_then(|v| match v {
ParamValue::Categorical(i) => Some(*i),
_ => None,
})
} else {
None
}
})
})
.filter(|&i| i < d.n_choices)
.collect();

let bad_indices: Vec<usize> = bad_trials
.iter()
.flat_map(|t| t.params.values())
.filter_map(|v| match v {
ParamValue::Categorical(i) => Some(*i),
_ => None,
.filter_map(|t| {
t.distributions.iter().find_map(|(id, dist)| {
if *dist == target_dist {
t.params.get(id).and_then(|v| match v {
ParamValue::Categorical(i) => Some(*i),
_ => None,
})
} else {
None
}
})
})
.filter(|&i| i < d.n_choices)
.collect();

if good_indices.is_empty() || bad_indices.is_empty() {
Expand Down
31 changes: 20 additions & 11 deletions src/sampler/nsga3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ fn nsga3_select(
state: &mut Nsga3State,
population: &[&MultiObjectiveTrial],
directions: &[Direction],
) -> Vec<Vec<ParamValue>> {
) -> (Vec<Vec<ParamValue>>, Vec<usize>) {
let pop_size = state.evo.population_size;
let n_obj = directions.len();

Expand Down Expand Up @@ -633,12 +633,13 @@ fn nsga3_select(
selected.push(state.evo.rng.usize(0..n));
}

selected
let params = selected
.iter()
.map(|&idx| {
extract_trial_params(population[idx], &state.evo.dimensions, &mut state.evo.rng)
})
.collect()
.collect();
(params, selected)
}

/// Tournament selection based on rank only (no crowding distance in NSGA-III).
Expand Down Expand Up @@ -675,26 +676,34 @@ fn nsga3_generate_offspring(
initialize_nsga3(state, directions);
}

let parents = nsga3_select(state, population, directions);
let (parents, selected_indices) = nsga3_select(state, population, directions);

// Assign ranks for tournament selection
// Assign Pareto front ranks for tournament selection
let n_obj = directions.len();
let min_values: Vec<Vec<f64>> = population
.iter()
.map(|t| to_minimize_space(&t.values, directions))
.collect();
let fronts = pareto::fast_non_dominated_sort(&min_values, &vec![Direction::Minimize; n_obj]);
let mut rank = vec![0_usize; parents.len()];
// Build rank lookup for population indices
let mut pop_rank = vec![0_usize; population.len()];
for (front_rank, front) in fronts.iter().enumerate() {
for &idx in front {
if idx < rank.len() {
rank[idx] = front_rank;
if idx < pop_rank.len() {
pop_rank[idx] = front_rank;
}
}
}
// Ranks for selected parents (simplified: use index order)
let parent_ranks: Vec<usize> = (0..parents.len())
.map(|i| i % (fronts.len().max(1)))
// Map population ranks to selected parent indices
let parent_ranks: Vec<usize> = selected_indices
.iter()
.map(|&idx| {
if idx < pop_rank.len() {
pop_rank[idx]
} else {
0
}
})
.collect();

let mut offspring = Vec::with_capacity(pop_size);
Expand Down
Loading