Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions libensemble/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from pathlib import Path

import numpy as np
import pydantic
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

Expand Down Expand Up @@ -354,6 +355,24 @@ def set_fields_from_vocs(self):
if "_id" not in self.persis_in:
self.persis_in.append("_id")

# Set user["lb"]/["ub"] from VOCS continuous variables (for legacy generators
# that read bounds from gen_specs["user"]). Skip variables without a ``.domain``
# attribute (e.g., DiscreteVariable). Do not overwrite user-provided values.
if self.user is None:
self.user = {}
if "lb" not in self.user or "ub" not in self.user:
lbs, ubs = [], []
for _name, var in (getattr(self.vocs, "variables", None) or {}).items():
domain = getattr(var, "domain", None)
if domain is not None and len(domain) == 2:
lbs.append(domain[0])
ubs.append(domain[1])
if lbs:
if "lb" not in self.user:
self.user["lb"] = np.array(lbs, dtype=float)
if "ub" not in self.user:
self.user["ub"] = np.array(ubs, dtype=float)

return self

@model_validator(mode="after")
Expand Down
4 changes: 0 additions & 4 deletions libensemble/tests/regression_tests/test_asktell_gpCAM.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,6 @@
"persis_in": ["x", "f", "sim_id"],
"out": [("x", float, (n,))],
"batch_size": batch_size,
"user": {
"lb": np.array([-3, -2, -1, -1]),
"ub": np.array([3, 2, 1, 1]),
},
}

vocs = VOCS(variables={"x0": [-3, 3], "x1": [-2, 2], "x2": [-1, 1], "x3": [-1, 1]}, objectives={"f": "MINIMIZE"})
Expand Down
97 changes: 97 additions & 0 deletions libensemble/tests/unit_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,97 @@ def test_ready_happy_path():
assert issues == [], f"Issues should be empty but got: {issues}"


def test_gen_specs_vocs_populates_user_bounds():
"""GenSpecs should populate user['lb']/['ub'] from VOCS continuous variables."""
from gest_api.vocs import VOCS

from libensemble.specs import GenSpecs

vocs = VOCS(
variables={"x0": [-3, 3], "x1": [-2, 2], "x2": [-1, 1], "x3": [-1, 1]},
objectives={"f": "MINIMIZE"},
)
gs = GenSpecs(vocs=vocs)
assert "lb" in gs.user, "lb should be populated in user from VOCS"
assert "ub" in gs.user, "ub should be populated in user from VOCS"
assert isinstance(gs.user["lb"], np.ndarray), "lb should be a numpy array"
assert isinstance(gs.user["ub"], np.ndarray), "ub should be a numpy array"
assert np.array_equal(gs.user["lb"], np.array([-3, -2, -1, -1]))
assert np.array_equal(gs.user["ub"], np.array([3, 2, 1, 1]))


def test_gen_specs_vocs_does_not_overwrite_user_bounds():
"""GenSpecs should not overwrite user-provided lb/ub when vocs is also given."""
from gest_api.vocs import VOCS

from libensemble.specs import GenSpecs

vocs = VOCS(variables={"x0": [-3, 3], "x1": [-2, 2]}, objectives={"f": "MINIMIZE"})
explicit_lb = np.array([0.0, 0.0])
explicit_ub = np.array([1.0, 1.0])
gs = GenSpecs(vocs=vocs, user={"lb": explicit_lb, "ub": explicit_ub})
assert np.array_equal(gs.user["lb"], explicit_lb), "Explicit lb should be preserved"
assert np.array_equal(gs.user["ub"], explicit_ub), "Explicit ub should be preserved"


def test_gen_specs_vocs_partial_user_bounds():
"""GenSpecs should fill in only the missing one of lb/ub if user supplies just one."""
from gest_api.vocs import VOCS

from libensemble.specs import GenSpecs

vocs = VOCS(variables={"x0": [-3, 3], "x1": [-2, 2]}, objectives={"f": "MINIMIZE"})
explicit_lb = np.array([0.0, 0.0])
gs = GenSpecs(vocs=vocs, user={"lb": explicit_lb})
assert np.array_equal(gs.user["lb"], explicit_lb), "Explicit lb should be preserved"
assert "ub" in gs.user, "ub should be populated from VOCS"
assert np.array_equal(gs.user["ub"], np.array([3, 2]))


def test_gen_specs_no_vocs_leaves_user_empty():
"""Without VOCS, GenSpecs.user should remain empty by default."""
from libensemble.specs import GenSpecs

gs = GenSpecs(outputs=[("x", float, (1,))])
assert "lb" not in gs.user, "lb should not be auto-populated without VOCS"
assert "ub" not in gs.user, "ub should not be auto-populated without VOCS"


def test_gen_specs_vocs_satisfies_legacy_user_params():
"""VOCS-populated user bounds should satisfy legacy gen_f consumers like
persistent_uniform (which require lb/ub to be numpy arrays and uses len(lb)
for dimension)."""
from gest_api.vocs import VOCS

from libensemble.gen_funcs.persistent_sampling import _get_user_params
from libensemble.specs import GenSpecs

vocs = VOCS(variables={"x0": [-3, 3], "x1": [-2, 2]}, objectives={"f": "MINIMIZE"})
gs = GenSpecs(vocs=vocs, initial_batch_size=10)

# Convert to dict shape that _get_user_params expects
gs_dict = {"initial_batch_size": gs.initial_batch_size, "user": gs.user}
b, n, lb, ub = _get_user_params(gs_dict["user"], gs_dict)
assert b == 10
assert n == 2
assert isinstance(lb, np.ndarray) and lb.dtype == float
assert isinstance(ub, np.ndarray) and ub.dtype == float
assert np.array_equal(lb, np.array([-3.0, -2.0]))
assert np.array_equal(ub, np.array([3.0, 2.0]))


def test_gen_specs_vocs_integer_domain_yields_float_array():
"""Integer-valued VOCS domains should still produce float dtype lb/ub arrays."""
from gest_api.vocs import VOCS

from libensemble.specs import GenSpecs

vocs = VOCS(variables={"x0": [0, 10], "x1": [-5, 5]}, objectives={"f": "MINIMIZE"})
gs = GenSpecs(vocs=vocs)
assert gs.user["lb"].dtype == float, "lb should be float dtype even for integer-domain variables"
assert gs.user["ub"].dtype == float, "ub should be float dtype even for integer-domain variables"


if __name__ == "__main__":
test_ensemble_init()
test_ensemble_parse_args_false()
Expand All @@ -283,3 +374,9 @@ def test_ready_happy_path():
test_ready_missing_nworkers_local()
test_ready_field_mismatch()
test_ready_happy_path()
test_gen_specs_vocs_populates_user_bounds()
test_gen_specs_vocs_does_not_overwrite_user_bounds()
test_gen_specs_vocs_partial_user_bounds()
test_gen_specs_no_vocs_leaves_user_empty()
test_gen_specs_vocs_satisfies_legacy_user_params()
test_gen_specs_vocs_integer_domain_yields_float_array()