diff --git a/dashboard/model_manager.py b/dashboard/model_manager.py index d104fbd5..b9294d36 100644 --- a/dashboard/model_manager.py +++ b/dashboard/model_manager.py @@ -98,11 +98,9 @@ def __init__(self, config_dict, model_type): if model_type not in ("NN", "ensemble_NN", "GP"): raise ValueError(f"Unsupported model type: {model_type}") # Populate inferred calibration in physics units for GUI - # (only meaningful inside the dashboard where state.simulation_calibration is set) - if state.simulation_calibration is not None: - self.populate_inferred_calibration( - config_dict["inputs"], config_dict["outputs"] - ) + self.populate_inferred_calibration( + config_dict["inputs"], config_dict["outputs"] + ) except Exception as e: title = f"Unable to load model {model_type}" msg = f"Error occurred when loading model from MLflow: {e}" @@ -152,7 +150,11 @@ def populate_inferred_calibration(self, input_variables, output_variables): value.pop("beta_inferred", None) # Input calibration - input_transformers = self.__model.input_transformers + # For ensemble_NN, transformers live on each inner TorchModel (not on NNEnsemble itself) + if self.__model_type == "ensemble_NN": + input_transformers = self.__model.models[0].input_transformers + else: + input_transformers = self.__model.input_transformers assert len(input_transformers) == 2, ( f"Expected exactly 2 input transformers (calibration + normalization), " f"but got {len(input_transformers)}." @@ -167,7 +169,11 @@ def populate_inferred_calibration(self, input_variables, output_variables): state.simulation_calibration[key]["beta_inferred"] = float(beta_inferred[i]) # Output calibration - output_transformers = self.__model.output_transformers + # For ensemble_NN, transformers live on each inner TorchModel (not on NNEnsemble itself) + if self.__model_type == "ensemble_NN": + output_transformers = self.__model.models[0].output_transformers + else: + output_transformers = self.__model.output_transformers assert len(output_transformers) == 2, ( f"Expected exactly 2 output transformers (normalization + calibration), " f"but got {len(output_transformers)}." diff --git a/tests/check_model.py b/tests/check_model.py index 4e6487b8..a9733d88 100644 --- a/tests/check_model.py +++ b/tests/check_model.py @@ -17,46 +17,38 @@ os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "dashboard" ) # similar to "cd ../dashboard" sys.path.insert(0, DASHBOARD_DIR) +from calibration_manager import SimulationCalibrationManager # noqa: E402 from model_manager import ModelManager # noqa: E402 -from utils import load_database, load_data # noqa: E402 +from state_manager import state # noqa: E402 +from utils import load_database, load_data as _load_data # noqa: E402 MODEL_TYPES = ["GP", "NN", "ensemble_NN"] ACCURACY_TOLERANCE = 0.80 -def load_experimental_data(config_dict): - """Fetch all experimental points from the database.""" +def load_data(config_dict): + """Fetch all experimental and simulation points from the database.""" input_names = [v["name"] for v in config_dict["inputs"].values()] output_names = [v["name"] for v in config_dict["outputs"].values()] db = load_database(config_dict) - exp_data, _ = load_data(db, config_dict["experiment"]) + exp_data, sim_data = _load_data(db, config_dict["experiment"]) - return exp_data, input_names, output_names + return exp_data, sim_data, input_names, output_names -def check_evaluate(config_dict, model_type): - """Load model and evaluate with experimental data; verify accuracy (relative RMSE <= threshold).""" - # Load model - mm = ModelManager(config_dict=config_dict, model_type=model_type) - # Load experimental data - df_exp, input_names, output_names = load_experimental_data(config_dict) +def check_accuracy(mm, df, input_names, output_names, label): + """Evaluate model on *df* and return True if all outputs pass the accuracy threshold.""" + if len(df) == 0: + print(f"[SKIP] No {label} data available; skipping accuracy check.") + return True - # Skip accuracy check if no experimental data available - if len(df_exp) == 0: - print( - f"[SKIP] No experimental data available for {config_dict['experiment']}; skipping accuracy check." - ) - return + inputs = {n: torch.tensor(df[n].values) for n in input_names} - # Convert input to the format expected by the model manager - inputs = {n: torch.tensor(df_exp[n].values) for n in input_names} - - # Check accuracy all_passed = True for output_name in output_names: - actual = torch.tensor(df_exp[output_name].values) + actual = torch.tensor(df[output_name].values) if actual.isnan().all(): print( f" [SKIP] Output '{output_name}': all actual values are NaN; skipping." @@ -75,8 +67,34 @@ def check_evaluate(config_dict, model_type): print( f" [{status}] Output '{output_name}': relative RMSE = {rmse:.1%} (tolerance {ACCURACY_TOLERANCE:.0%})" ) + return all_passed + + +def check_evaluate(config_dict, model_type): + """Load model and evaluate with experimental and simulation data; verify accuracy.""" + # Set up calibration so ModelManager can populate inferred values + simulation_calibration = config_dict.get("simulation_calibration", {}) + cal_manager = SimulationCalibrationManager(simulation_calibration) + state.use_inferred_calibration = True + + # Load model (populates inferred calibration in state.simulation_calibration) + mm = ModelManager(config_dict=config_dict, model_type=model_type) + + # Load experimental and simulation data + df_exp, df_sim, input_names, output_names = load_data(config_dict) + + # Check accuracy on experimental data + print("Checking experimental data...") + exp_passed = check_accuracy(mm, df_exp, input_names, output_names, "experimental") + + # Convert simulation data to experimental units using inferred calibration + cal_manager.convert_sim_to_exp(df_sim) + + # Check accuracy on simulation data + print("Checking simulation data...") + sim_passed = check_accuracy(mm, df_sim, input_names, output_names, "simulation") - if not all_passed: + if not (exp_passed and sim_passed): raise RuntimeError( f"Accuracy check failed: relative RMSE exceeded {ACCURACY_TOLERANCE:.0%} for one or more outputs." ) @@ -105,7 +123,7 @@ def check_evaluate(config_dict, model_type): config_dict = yaml.safe_load(f) print(f"Experiment: {config_dict['experiment']}") - # Load model and evaluate with experimental data + # Load model and evaluate with experimental and simulation data try: check_evaluate(config_dict, args.model) except Exception as e: