From dc37c5e47c5c612ef698128aa3043931346f14fd Mon Sep 17 00:00:00 2001 From: Edoardo Zoni Date: Thu, 8 Jan 2026 14:31:40 -0800 Subject: [PATCH 1/6] Fix alignment of components in parameters control panel --- dashboard/parameters_manager.py | 37 +++++++++++++++++---------------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/dashboard/parameters_manager.py b/dashboard/parameters_manager.py index 1fafc8ba..8ea13a46 100644 --- a/dashboard/parameters_manager.py +++ b/dashboard/parameters_manager.py @@ -227,24 +227,25 @@ def panel(self): change="flushState('parameters_show_all')", label="Show all", ) - with vuetify.VRow(): - with vuetify.VCol(): - vuetify.VBtn( - "Reset", - click=self.reset, - style="text-transform: none", - ) - with vuetify.VRow(): - with vuetify.VCol(): - vuetify.VBtn( - "Simulate", - click=self.simulation_trigger, - disabled=( - "simulation_running || perlmutter_status != 'active' || !simulatable", - ), - style="text-transform: none;", - ) - with vuetify.VCol(): + with vuetify.VRow(align="center"): + with vuetify.VCol(cols=6): + with vuetify.VRow(): + with vuetify.VCol(): + vuetify.VBtn( + "Reset", + click=self.reset, + style="text-transform: none", + ) + with vuetify.VCol(): + vuetify.VBtn( + "Simulate", + click=self.simulation_trigger, + disabled=( + "simulation_running || perlmutter_status != 'active' || !simulatable", + ), + style="text-transform: none;", + ) + with vuetify.VCol(cols=6): vuetify.VTextField( v_model_number=("simulation_running_status",), label="Simulation status", From 919ad6e754986d5cf5b214b038521909fe7b3f75 Mon Sep 17 00:00:00 2001 From: Edoardo Zoni Date: Fri, 9 Jan 2026 09:27:18 -0800 Subject: [PATCH 2/6] Add displayed inputs type selector, move output selector --- dashboard/app.py | 4 ---- dashboard/outputs_manager.py | 17 ----------------- dashboard/parameters_manager.py | 16 ++++++++++++++++ 3 files changed, 16 insertions(+), 21 deletions(-) diff --git a/dashboard/app.py b/dashboard/app.py index 7cb6fe86..f554588b 100644 --- a/dashboard/app.py +++ b/dashboard/app.py @@ -285,10 +285,6 @@ def home_route(): vuetify.VTab("ML", value="ml_tab") with vuetify.VWindow(v_model=("active_tab",), mandatory=True): with vuetify.VWindowItem(value="parameters_tab"): - # output control panel - with vuetify.VRow(): - with vuetify.VCol(): - out_manager.panel() # parameters control panel with vuetify.VRow(): with vuetify.VCol(): diff --git a/dashboard/outputs_manager.py b/dashboard/outputs_manager.py index a25a8b25..f8499fda 100644 --- a/dashboard/outputs_manager.py +++ b/dashboard/outputs_manager.py @@ -1,5 +1,3 @@ -from trame.widgets import vuetify3 as vuetify - from state_manager import state @@ -9,18 +7,3 @@ def __init__(self, output_variables): # define state variables state.output_variables = [v["name"] for v in output_variables.values()] state.displayed_output = state.output_variables[0] - - def panel(self): - print("Setting output card...") - with vuetify.VExpansionPanels(v_model=("expand_panel_control_output", 0)): - with vuetify.VExpansionPanel( - title="Control: Displayed Output", - style="font-size: 20px; font-weight: 500;", - ): - with vuetify.VExpansionPanelText(): - with vuetify.VRow(): - vuetify.VSelect( - v_model=("displayed_output",), - items=(state.output_variables,), - dense=True, - ) diff --git a/dashboard/parameters_manager.py b/dashboard/parameters_manager.py index 8ea13a46..96334fa3 100644 --- a/dashboard/parameters_manager.py +++ b/dashboard/parameters_manager.py @@ -39,6 +39,8 @@ def __init__(self, model, input_variables): state.parameters_show_all[key] = False self.parameters_step[key] = (pmax - pmin) / 100 state.parameters_init = copy.deepcopy(state.parameters) + # define default dislpayed inputs + state.displayed_inputs = "Experiment" @property def model(self): @@ -155,6 +157,20 @@ def panel(self): style="font-size: 20px; font-weight: 500;", ): with vuetify.VExpansionPanelText(): + with vuetify.VRow(): + vuetify.VSelect( + v_model=("displayed_output",), + items=(state.output_variables,), + dense=True, + label="Displayed output", + ) + with vuetify.VRow(): + vuetify.VSelect( + v_model=("displayed_inputs",), + items=(["Experiment", "Simulation"],), + dense=True, + label="Displayed inputs", + ) with client.DeepReactive("parameters"): for count, key in enumerate(state.parameters.keys()): # create a row for the parameter label From e0bfe197a0e2b2befcff5c811d20e250e7607bb3 Mon Sep 17 00:00:00 2001 From: Edoardo Zoni Date: Fri, 9 Jan 2026 10:34:47 -0800 Subject: [PATCH 3/6] Move selectors on the same row --- dashboard/parameters_manager.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/dashboard/parameters_manager.py b/dashboard/parameters_manager.py index 96334fa3..a629d89f 100644 --- a/dashboard/parameters_manager.py +++ b/dashboard/parameters_manager.py @@ -39,7 +39,7 @@ def __init__(self, model, input_variables): state.parameters_show_all[key] = False self.parameters_step[key] = (pmax - pmin) / 100 state.parameters_init = copy.deepcopy(state.parameters) - # define default dislpayed inputs + # define other state variables state.displayed_inputs = "Experiment" @property @@ -158,19 +158,20 @@ def panel(self): ): with vuetify.VExpansionPanelText(): with vuetify.VRow(): - vuetify.VSelect( - v_model=("displayed_output",), - items=(state.output_variables,), - dense=True, - label="Displayed output", - ) - with vuetify.VRow(): - vuetify.VSelect( - v_model=("displayed_inputs",), - items=(["Experiment", "Simulation"],), - dense=True, - label="Displayed inputs", - ) + with vuetify.VCol(): + vuetify.VSelect( + v_model=("displayed_inputs",), + items=(["Experiment", "Simulation"],), + dense=True, + label="Displayed inputs", + ) + with vuetify.VCol(): + vuetify.VSelect( + v_model=("displayed_output",), + items=(state.output_variables,), + dense=True, + label="Displayed output", + ) with client.DeepReactive("parameters"): for count, key in enumerate(state.parameters.keys()): # create a row for the parameter label From 40c2a053e37a9085e20402ba8c66feb9cf9fe4b7 Mon Sep 17 00:00:00 2001 From: Edoardo Zoni Date: Fri, 9 Jan 2026 16:44:21 -0800 Subject: [PATCH 4/6] Add reactivity, move default initialization --- dashboard/app.py | 8 ++++---- dashboard/parameters_manager.py | 2 -- dashboard/state_manager.py | 4 +++- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/dashboard/app.py b/dashboard/app.py index f554588b..f37ce2f2 100644 --- a/dashboard/app.py +++ b/dashboard/app.py @@ -105,11 +105,11 @@ def update( ctrl.figure_update(fig) -@state.change("experiment") +@state.change("experiment", "displayed_inputs") def update_on_change_experiment(**kwargs): # skip if triggered on server ready (all state variables marked as modified) if len(state.modified_keys) == 1: - print("Experiment changed...") + print("Reacting on state change...") update( reset_model=True, reset_output=True, @@ -127,7 +127,7 @@ def update_on_change_experiment(**kwargs): def update_on_change_model(**kwargs): # skip if triggered on server ready (all state variables marked as modified) if len(state.modified_keys) == 1: - print("Model type changed...") + print("Reacting on state change...") update( reset_model=True, reset_output=False, @@ -154,7 +154,7 @@ def update_on_change_model(**kwargs): def update_on_change_others(**kwargs): # skip if triggered on server ready (all state variables marked as modified) if len(state.modified_keys) == 1: - print("Parameters, opacity changed...") + print("Reacting on state change...") update( reset_model=False, reset_output=False, diff --git a/dashboard/parameters_manager.py b/dashboard/parameters_manager.py index a629d89f..037987c8 100644 --- a/dashboard/parameters_manager.py +++ b/dashboard/parameters_manager.py @@ -39,8 +39,6 @@ def __init__(self, model, input_variables): state.parameters_show_all[key] = False self.parameters_step[key] = (pmax - pmin) / 100 state.parameters_init = copy.deepcopy(state.parameters) - # define other state variables - state.displayed_inputs = "Experiment" @property def model(self): diff --git a/dashboard/state_manager.py b/dashboard/state_manager.py index 1ece270b..2c3a54f3 100644 --- a/dashboard/state_manager.py +++ b/dashboard/state_manager.py @@ -52,5 +52,7 @@ def initialize_state(): # Errors management state.errors = [] state.error_counter = 0 - # Calibration toggles + # Calibration option state.use_inferred_calibration = False + # Displayed inputs + state.displayed_inputs = "Experiment" From 44b0bd0580cc2ce4ddcaec3d0b73b0a00981612a Mon Sep 17 00:00:00 2001 From: Edoardo Zoni Date: Mon, 12 Jan 2026 17:08:30 -0800 Subject: [PATCH 5/6] New data structure to store both families of parameters --- dashboard/app.py | 11 ++-- dashboard/calibration_manager.py | 21 ++++---- dashboard/optimization_manager.py | 10 ++-- dashboard/parameters_manager.py | 85 ++++++++++++++++++++++--------- dashboard/utils.py | 11 ++-- 5 files changed, 91 insertions(+), 47 deletions(-) diff --git a/dashboard/app.py b/dashboard/app.py index f37ce2f2..1e113ec3 100644 --- a/dashboard/app.py +++ b/dashboard/app.py @@ -75,7 +75,9 @@ def update( opt_manager = OptimizationManager(mod_manager) # reset parameters if reset_parameters: - par_manager = ParametersManager(mod_manager, input_variables) + par_manager = ParametersManager( + mod_manager, input_variables, simulation_calibration + ) elif reset_model: # if resetting only model, model attribute must be updated par_manager.model = mod_manager @@ -177,7 +179,7 @@ def find_simulation(event, db): if len(documents) == 1: this_point_parameters = { parameter: documents[0][parameter] - for parameter in state.parameters.keys() + for parameter in state.parameters["exp"].keys() if parameter in documents[0] } print(f"Clicked on data point ({this_point_parameters})") @@ -311,8 +313,11 @@ def home_route(): with vuetify.VCol(cols=8): with vuetify.VCard(): with vuetify.VCardTitle("Plots"): + param_family = ( + "exp" if state.displayed_inputs == "Experiment" else "sim" + ) with vuetify.VContainer( - style=f"height: {400 * len(state.parameters)}px;" + style=f"height: {400 * len(state.parameters[param_family])}px;" ): figure = plotly.Figure( display_mode_bar="true", diff --git a/dashboard/calibration_manager.py b/dashboard/calibration_manager.py index 74016391..04f62810 100644 --- a/dashboard/calibration_manager.py +++ b/dashboard/calibration_manager.py @@ -52,16 +52,17 @@ def convert(value, alpha, beta): for _, value in state.simulation_calibration.items(): sim_name = value["name"] exp_name = value["depends_on"] - # strip characters after '[' parenthesis to remove units, strip - # leading/trailing white spaces, replace white spaces and '-' with '_', - # and convert to lower case - sim_name = ( - sim_name.split("[")[0] - .strip() - .replace(" ", "_") - .replace("-", "_") - .lower() - ) + # FIXME + ## strip characters after '[' parenthesis to remove units, strip + ## leading/trailing white spaces, replace white spaces and '-' with '_', + ## and convert to lower case + # sim_name = ( + # sim_name.split("[")[0] + # .strip() + # .replace(" ", "_") + # .replace("-", "_") + # .lower() + # ) # fill the dictionary if exp_name in exp_dict: sim_dict[sim_name] = convert( diff --git a/dashboard/optimization_manager.py b/dashboard/optimization_manager.py index 237c4ae2..af2242ee 100644 --- a/dashboard/optimization_manager.py +++ b/dashboard/optimization_manager.py @@ -15,7 +15,7 @@ def __init__(self, model): def model_wrapper(self, parameters_array): print("Wrapping model...") # convert array of parameters to dictionary - parameters_dict = dict(zip(state.parameters.keys(), parameters_array)) + parameters_dict = dict(zip(state.parameters["exp"].keys(), parameters_array)) # change sign to the result in order to maximize when optimizing mean, lower, upper = self.__model.evaluate( parameters_dict, state.optimization_target @@ -28,12 +28,12 @@ def optimize(self): # info print statement skipped to avoid redundancy if self.__model is not None: # get array of current parameters from state - parameters_values = np.array(list(state.parameters.values())) + parameters_values = np.array(list(state.parameters["exp"].values())) # define parameters bounds for optimization parameters_bounds = [] - for key in state.parameters.keys(): + for key in state.parameters["exp"].keys(): parameters_bounds.append( - (state.parameters_min[key], state.parameters_max[key]) + (state.parameters_min["exp"][key], state.parameters_max["exp"][key]) ) # optimize model (maximize output value) res = minimize( @@ -44,7 +44,7 @@ def optimize(self): ) print(f"Optimization result:\n{res}") # update parameters in state with optimal values - state.parameters = dict(zip(state.parameters.keys(), res.x)) + state.parameters["exp"] = dict(zip(state.parameters["exp"].keys(), res.x)) # push again at flush time state.dirty("parameters") # Force flush now (TODO fix state change listeners, remove workaround) diff --git a/dashboard/parameters_manager.py b/dashboard/parameters_manager.py index 037987c8..8b29e858 100644 --- a/dashboard/parameters_manager.py +++ b/dashboard/parameters_manager.py @@ -15,16 +15,16 @@ class ParametersManager: - def __init__(self, model, input_variables): + def __init__(self, model, input_variables, simulation_calibration): print("Initializing parameters manager...") # save model self.__model = model # define state variables - state.parameters = dict() - state.parameters_min = dict() - state.parameters_max = dict() - state.parameters_show_all = dict() - self.parameters_step = dict() + state.parameters = {"exp": {}, "sim": {}} + state.parameters_min = {"exp": {}, "sim": {}} + state.parameters_max = {"exp": {}, "sim": {}} + state.parameters_show_all = {"exp": {}, "sim": {}} + self.parameters_step = {"exp": {}, "sim": {}} state.simulatable = ( self.simulation_scripts_base_path / "submission_script_single" ).is_file() @@ -33,11 +33,30 @@ def __init__(self, model, input_variables): pmin = float(parameter_dict["value_range"][0]) pmax = float(parameter_dict["value_range"][1]) pval = float(parameter_dict["default"]) - state.parameters[key] = pval - state.parameters_min[key] = pmin - state.parameters_max[key] = pmax - state.parameters_show_all[key] = False - self.parameters_step[key] = (pmax - pmin) / 100 + state.parameters["exp"][key] = pval + state.parameters_min["exp"][key] = pmin + state.parameters_max["exp"][key] = pmax + state.parameters_show_all["exp"][key] = False + self.parameters_step["exp"][key] = ( + state.parameters_max["exp"][key] - state.parameters_min["exp"][key] + ) / 100 + # store simulation parameters converted from experimental ones + sim_cal = SimulationCalibrationManager(simulation_calibration) + state.parameters["sim"] = sim_cal.convert_exp_to_sim(state.parameters["exp"]) + state.parameters_min["sim"] = sim_cal.convert_exp_to_sim( + state.parameters_min["exp"] + ) + state.parameters_max["sim"] = sim_cal.convert_exp_to_sim( + state.parameters_max["exp"] + ) + state.parameters_show_all["sim"] = copy.deepcopy( + state.parameters_show_all["exp"] + ) + for key in state.parameters["sim"].keys(): + self.parameters_step["sim"][key] = ( + state.parameters_max["sim"][key] - state.parameters_min["sim"][key] + ) / 100 + # save initial parameters for reset state.parameters_init = copy.deepcopy(state.parameters) @property @@ -77,7 +96,7 @@ async def simulation_kernel(self): ) _, _, simulation_calibration = load_variables(state.experiment) sim_cal = SimulationCalibrationManager(simulation_calibration) - sim_dict = sim_cal.convert_exp_to_sim(state.parameters) + sim_dict = sim_cal.convert_exp_to_sim(state.parameters["exp"]) with open(temp_file_path, "w") as temp_file: yaml.dump(sim_dict, temp_file) temp_file.flush() @@ -171,7 +190,12 @@ def panel(self): label="Displayed output", ) with client.DeepReactive("parameters"): - for count, key in enumerate(state.parameters.keys()): + param_family = ( + "exp" if state.displayed_inputs == "Experiment" else "sim" + ) + for count, key in enumerate( + state.parameters[param_family].keys() + ): # create a row for the parameter label with vuetify.VRow(): vuetify.VListSubheader( @@ -184,19 +208,23 @@ def panel(self): ) with vuetify.VRow(no_gutters=True): with vuetify.VSlider( - v_model_number=(f"parameters['{key}']",), + v_model_number=( + f"parameters['{param_family}']['{key}']", + ), change="flushState('parameters')", hide_details=True, - min=(f"parameters_min['{key}']",), - max=(f"parameters_max['{key}']",), + min=(f"parameters_min['{param_family}']['{key}']",), + max=(f"parameters_max['{param_family}']['{key}']",), step=( - f"(parameters_max['{key}'] - parameters_min['{key}']) / 100", + f"(parameters_max['{param_family}']['{key}'] - parameters_min['{param_family}']['{key}']) / 100", ), style="align-items: center;", ): with vuetify.Template(v_slot_append=True): vuetify.VTextField( - v_model_number=(f"parameters['{key}']",), + v_model_number=( + f"parameters['{param_family}']['{key}']", + ), density="compact", hide_details=True, readonly=True, @@ -204,15 +232,20 @@ def panel(self): style="margin-top: 0px; padding-top: 0px; width: 100px;", type="number", ) - step = self.parameters_step[key] + print(self.parameters_step) + step = self.parameters_step[param_family][key] with vuetify.VRow(no_gutters=True): with vuetify.VCol(): vuetify.VTextField( - v_model_number=(f"parameters_min['{key}']",), + v_model_number=( + f"parameters_min['{param_family}']['{key}']", + ), change="flushState('parameters_min')", density="compact", hide_details=True, - disabled=(f"parameters_show_all['{key}']",), + disabled=( + f"parameters_show_all['{param_family}']['{key}']", + ), step=step, __properties=["step"], style="width: 100px;", @@ -221,11 +254,15 @@ def panel(self): ) with vuetify.VCol(): vuetify.VTextField( - v_model_number=(f"parameters_max['{key}']",), + v_model_number=( + f"parameters_max['{param_family}']['{key}']", + ), change="flushState('parameters_max')", density="compact", hide_details=True, - disabled=(f"parameters_show_all['{key}']",), + disabled=( + f"parameters_show_all['{param_family}']['{key}']", + ), step=step, __properties=["step"], style="width: 100px;", @@ -235,7 +272,7 @@ def panel(self): with vuetify.VCol(style="min-width: 100px;"): vuetify.VCheckbox( v_model=( - f"parameters_show_all['{key}']", + f"parameters_show_all['{param_family}']['{key}']", False, ), density="compact", diff --git a/dashboard/utils.py b/dashboard/utils.py index 5c2325c5..b7c3c645 100644 --- a/dashboard/utils.py +++ b/dashboard/utils.py @@ -139,10 +139,11 @@ def plot(exp_data, sim_data, model_manager, cal_manager): # convert simulation data to experimental data cal_manager.convert_sim_to_exp(sim_data) # local aliases - parameters = state.parameters - parameters_min = state.parameters_min - parameters_max = state.parameters_max - parameters_show_all = state.parameters_show_all + param_family = "exp" # FIXME if state.displayed_inputs == "Experiment" else "sim" + parameters = state.parameters[param_family] + parameters_min = state.parameters_min[param_family] + parameters_max = state.parameters_max[param_family] + parameters_show_all = state.parameters_show_all[param_family] try: objective_name = state.displayed_output except Exception as e: @@ -214,7 +215,7 @@ def hover_section(title, cols, hover_data): return section # Determine which data is shown when hovering over the plot - hover_parameters = list(state.parameters.keys()) + hover_parameters = list(state.parameters[param_family].keys()) hover_output_variables = state.output_variables hover_customdata = ["_id"] + hover_parameters + hover_output_variables From dbb86ee2864f2fb3c2ef473bd05ef6154e552ebd Mon Sep 17 00:00:00 2001 From: Edoardo Zoni Date: Tue, 13 Jan 2026 11:57:59 -0800 Subject: [PATCH 6/6] Continue implementation of selector --- dashboard/parameters_manager.py | 22 ++++++------- dashboard/utils.py | 56 +++++++++++++++++++-------------- 2 files changed, 42 insertions(+), 36 deletions(-) diff --git a/dashboard/parameters_manager.py b/dashboard/parameters_manager.py index 8b29e858..22c1e57e 100644 --- a/dashboard/parameters_manager.py +++ b/dashboard/parameters_manager.py @@ -28,7 +28,7 @@ def __init__(self, model, input_variables, simulation_calibration): state.simulatable = ( self.simulation_scripts_base_path / "submission_script_single" ).is_file() - for _, parameter_dict in input_variables.items(): + for parameter_dict in input_variables.values(): key = parameter_dict["name"] pmin = float(parameter_dict["value_range"][0]) pmax = float(parameter_dict["value_range"][1]) @@ -37,9 +37,7 @@ def __init__(self, model, input_variables, simulation_calibration): state.parameters_min["exp"][key] = pmin state.parameters_max["exp"][key] = pmax state.parameters_show_all["exp"][key] = False - self.parameters_step["exp"][key] = ( - state.parameters_max["exp"][key] - state.parameters_min["exp"][key] - ) / 100 + self.parameters_step["exp"][key] = (pmax - pmin) / 100 # store simulation parameters converted from experimental ones sim_cal = SimulationCalibrationManager(simulation_calibration) state.parameters["sim"] = sim_cal.convert_exp_to_sim(state.parameters["exp"]) @@ -49,13 +47,14 @@ def __init__(self, model, input_variables, simulation_calibration): state.parameters_max["sim"] = sim_cal.convert_exp_to_sim( state.parameters_max["exp"] ) - state.parameters_show_all["sim"] = copy.deepcopy( - state.parameters_show_all["exp"] - ) - for key in state.parameters["sim"].keys(): - self.parameters_step["sim"][key] = ( - state.parameters_max["sim"][key] - state.parameters_min["sim"][key] - ) / 100 + state.parameters_show_all["sim"] = { + key: False for key in state.parameters["sim"].keys() + } + self.parameters_step["sim"] = { + key: (state.parameters_max["sim"][key] - state.parameters_min["sim"][key]) + / 100 + for key in state.parameters["sim"].keys() + } # save initial parameters for reset state.parameters_init = copy.deepcopy(state.parameters) @@ -232,7 +231,6 @@ def panel(self): style="margin-top: 0px; padding-top: 0px; width: 100px;", type="number", ) - print(self.parameters_step) step = self.parameters_step[param_family][key] with vuetify.VRow(no_gutters=True): with vuetify.VCol(): diff --git a/dashboard/utils.py b/dashboard/utils.py index b7c3c645..0b4118de 100644 --- a/dashboard/utils.py +++ b/dashboard/utils.py @@ -138,12 +138,8 @@ def plot(exp_data, sim_data, model_manager, cal_manager): print("Plotting...") # convert simulation data to experimental data cal_manager.convert_sim_to_exp(sim_data) - # local aliases - param_family = "exp" # FIXME if state.displayed_inputs == "Experiment" else "sim" - parameters = state.parameters[param_family] - parameters_min = state.parameters_min[param_family] - parameters_max = state.parameters_max[param_family] - parameters_show_all = state.parameters_show_all[param_family] + # displayed inputs type + param_family = "exp" if state.displayed_inputs == "Experiment" else "sim" try: objective_name = state.displayed_output except Exception as e: @@ -156,10 +152,10 @@ def plot(exp_data, sim_data, model_manager, cal_manager): df_cds = ["blue", "red"] df_leg = ["Experiment", "Simulation"] # plot - fig = make_subplots(rows=len(parameters), cols=1) + fig = make_subplots(rows=len(state.parameters[param_family]), cols=1) global_ymin = float("inf") global_ymax = float("-inf") - for i, key in enumerate(parameters.keys()): + for i, key in enumerate(state.parameters[param_family].keys()): # NOTE row count starts from 1, enumerate count starts from 0 this_row = i + 1 this_col = 1 @@ -177,13 +173,13 @@ def plot(exp_data, sim_data, model_manager, cal_manager): # loop over all inputs except the current one for subkey in [ subkey - for subkey in parameters.keys() + for subkey in state.parameters[param_family].keys() if (subkey != key and subkey in df_copy.columns) ]: pname_loc = subkey - pval_loc = parameters[subkey] - pmin_loc = parameters_min[subkey] - pmax_loc = parameters_max[subkey] + pval_loc = state.parameters[param_family][subkey] + pmin_loc = state.parameters_min[param_family][subkey] + pmax_loc = state.parameters_max[param_family][subkey] df_copy["distance"] += ( (df_copy[f"{pname_loc}"] - pval_loc) / (pmax_loc - pmin_loc) ) ** 2 @@ -287,15 +283,24 @@ def hover_section(title, cols, hover_data): # ---------------------------------------------------------------------- # figure trace from model data if model_manager.avail(): + exp_key = ( + key + if param_family == "exp" + else list(state.parameters["exp"].keys())[i] + ) input_dict_loc = dict() steps = 1000 - input_dict_loc[key] = torch.linspace( - start=parameters_min[key], - end=parameters_max[key], + input_dict_loc[exp_key] = torch.linspace( + start=state.parameters_min["exp"][exp_key], + end=state.parameters_max["exp"][exp_key], steps=steps, ) - for subkey in [subkey for subkey in parameters.keys() if subkey != key]: - input_dict_loc[subkey] = parameters[subkey] * torch.ones(steps) + for subkey in [ + subkey for subkey in state.parameters["exp"].keys() if subkey != exp_key + ]: + input_dict_loc[subkey] = state.parameters["exp"][subkey] * torch.ones( + steps + ) # get mean and lower/upper bounds for uncertainty prediction # (when lower/upper bounds are not predicted by the model, # their values are set to zero to collapse the error range) @@ -308,7 +313,7 @@ def hover_section(title, cols, hover_data): # upper bound upper_bound = go.Scatter( - x=input_dict_loc[key], + x=input_dict_loc[exp_key], y=upper, line=dict(color="orange", width=0.3), showlegend=False, @@ -321,7 +326,7 @@ def hover_section(title, cols, hover_data): ) # lower bound lower_bound = go.Scatter( - x=input_dict_loc[key], + x=input_dict_loc[exp_key], y=lower, fill="tonexty", # fill area between this trace and the next one fillcolor="rgba(255,165,0,0.25)", # orange with alpha @@ -336,7 +341,7 @@ def hover_section(title, cols, hover_data): ) # scatter plot mod_trace = go.Scatter( - x=input_dict_loc[key], + x=input_dict_loc[exp_key], y=mean, line=dict(color="orange"), name="ML Model", @@ -351,14 +356,14 @@ def hover_section(title, cols, hover_data): # ---------------------------------------------------------------------- # add reference input line fig.add_vline( - x=parameters[key], + x=state.parameters[param_family][key], line_dash="dash", row=this_row, col=this_col, ) # ---------------------------------------------------------------------- # figures style - if parameters_show_all[key]: + if state.parameters_show_all[param_family][key]: fig.update_xaxes( exponentformat="e", title_text=key, @@ -367,7 +372,10 @@ def hover_section(title, cols, hover_data): ) else: fig.update_xaxes( - range=(parameters_min[key], parameters_max[key]), + range=( + state.parameters_min[param_family][key], + state.parameters_max[param_family][key], + ), exponentformat="e", title_text=key, row=this_row, @@ -376,7 +384,7 @@ def hover_section(title, cols, hover_data): # A bit of padding on either end of the y range so we can see all the data. padding = 0.05 * (global_ymax - global_ymin) - for i, key in enumerate(parameters.keys()): + for i, key in enumerate(state.parameters[param_family].keys()): this_row = i + 1 this_col = 1 fig.update_yaxes(