Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions dashboard/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,9 @@ def __init__(self, config_dict, model_type):
if model_type not in ("NN", "ensemble_NN", "GP"):
raise ValueError(f"Unsupported model type: {model_type}")
# Populate inferred calibration in physics units for GUI
# (only meaningful inside the dashboard where state.simulation_calibration is set)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We now also set `state.simulation_calibration` in `check_model.py`, so this check is no longer meaningful.

if state.simulation_calibration is not None:
self.populate_inferred_calibration(
config_dict["inputs"], config_dict["outputs"]
)
self.populate_inferred_calibration(
config_dict["inputs"], config_dict["outputs"]
)
except Exception as e:
title = f"Unable to load model {model_type}"
msg = f"Error occurred when loading model from MLflow: {e}"
Expand Down Expand Up @@ -152,7 +150,11 @@ def populate_inferred_calibration(self, input_variables, output_variables):
value.pop("beta_inferred", None)

# Input calibration
input_transformers = self.__model.input_transformers
# For ensemble_NN, transformers live on each inner TorchModel (not on NNEnsemble itself)
if self.__model_type == "ensemble_NN":
input_transformers = self.__model.models[0].input_transformers
else:
input_transformers = self.__model.input_transformers
assert len(input_transformers) == 2, (
f"Expected exactly 2 input transformers (calibration + normalization), "
f"but got {len(input_transformers)}."
Expand All @@ -167,7 +169,11 @@ def populate_inferred_calibration(self, input_variables, output_variables):
state.simulation_calibration[key]["beta_inferred"] = float(beta_inferred[i])

# Output calibration
output_transformers = self.__model.output_transformers
# For ensemble_NN, transformers live on each inner TorchModel (not on NNEnsemble itself)
if self.__model_type == "ensemble_NN":
output_transformers = self.__model.models[0].output_transformers
else:
output_transformers = self.__model.output_transformers
assert len(output_transformers) == 2, (
f"Expected exactly 2 output transformers (normalization + calibration), "
f"but got {len(output_transformers)}."
Expand Down
66 changes: 42 additions & 24 deletions tests/check_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,46 +17,38 @@
os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "dashboard"
) # similar to "cd ../dashboard"
sys.path.insert(0, DASHBOARD_DIR)
from calibration_manager import SimulationCalibrationManager # noqa: E402
from model_manager import ModelManager # noqa: E402
from utils import load_database, load_data # noqa: E402
from state_manager import state # noqa: E402
from utils import load_database, load_data as _load_data # noqa: E402


MODEL_TYPES = ["GP", "NN", "ensemble_NN"]
ACCURACY_TOLERANCE = 0.80


def load_experimental_data(config_dict):
"""Fetch all experimental points from the database."""
def load_data(config_dict):
"""Fetch all experimental and simulation points from the database."""
input_names = [v["name"] for v in config_dict["inputs"].values()]
output_names = [v["name"] for v in config_dict["outputs"].values()]

db = load_database(config_dict)
exp_data, _ = load_data(db, config_dict["experiment"])
exp_data, sim_data = _load_data(db, config_dict["experiment"])

return exp_data, input_names, output_names
return exp_data, sim_data, input_names, output_names


def check_evaluate(config_dict, model_type):
"""Load model and evaluate with experimental data; verify accuracy (relative RMSE <= threshold)."""
# Load model
mm = ModelManager(config_dict=config_dict, model_type=model_type)
# Load experimental data
df_exp, input_names, output_names = load_experimental_data(config_dict)
def check_accuracy(mm, df, input_names, output_names, label):
"""Evaluate model on *df* and return True if all outputs pass the accuracy threshold."""
if len(df) == 0:
print(f"[SKIP] No {label} data available; skipping accuracy check.")
return True

# Skip accuracy check if no experimental data available
if len(df_exp) == 0:
print(
f"[SKIP] No experimental data available for {config_dict['experiment']}; skipping accuracy check."
)
return
inputs = {n: torch.tensor(df[n].values) for n in input_names}

# Convert input to the format expected by the model manager
inputs = {n: torch.tensor(df_exp[n].values) for n in input_names}

# Check accuracy
all_passed = True
for output_name in output_names:
actual = torch.tensor(df_exp[output_name].values)
actual = torch.tensor(df[output_name].values)
if actual.isnan().all():
print(
f" [SKIP] Output '{output_name}': all actual values are NaN; skipping."
Expand All @@ -75,8 +67,34 @@ def check_evaluate(config_dict, model_type):
print(
f" [{status}] Output '{output_name}': relative RMSE = {rmse:.1%} (tolerance {ACCURACY_TOLERANCE:.0%})"
)
return all_passed


def check_evaluate(config_dict, model_type):
"""Load model and evaluate with experimental and simulation data; verify accuracy."""
# Set up calibration so ModelManager can populate inferred values
simulation_calibration = config_dict.get("simulation_calibration", {})
cal_manager = SimulationCalibrationManager(simulation_calibration)
state.use_inferred_calibration = True

# Load model (populates inferred calibration in state.simulation_calibration)
mm = ModelManager(config_dict=config_dict, model_type=model_type)

# Load experimental and simulation data
df_exp, df_sim, input_names, output_names = load_data(config_dict)

# Check accuracy on experimental data
print("Checking experimental data...")
exp_passed = check_accuracy(mm, df_exp, input_names, output_names, "experimental")

# Convert simulation data to experimental units using inferred calibration
cal_manager.convert_sim_to_exp(df_sim)

# Check accuracy on simulation data
print("Checking simulation data...")
sim_passed = check_accuracy(mm, df_sim, input_names, output_names, "simulation")

if not all_passed:
if not (exp_passed and sim_passed):
raise RuntimeError(
f"Accuracy check failed: relative RMSE exceeded {ACCURACY_TOLERANCE:.0%} for one or more outputs."
)
Expand Down Expand Up @@ -105,7 +123,7 @@ def check_evaluate(config_dict, model_type):
config_dict = yaml.safe_load(f)
print(f"Experiment: {config_dict['experiment']}")

# Load model and evaluate with experimental data
# Load model and evaluate with experimental and simulation data
try:
check_evaluate(config_dict, args.model)
except Exception as e:
Expand Down
Loading