Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 89 additions & 12 deletions dashboard/app.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from bson.objectid import ObjectId
import os
import re
Expand All @@ -6,7 +7,12 @@
from trame.ui.vuetify3 import SinglePageWithDrawerLayout
from trame.widgets import plotly, router, vuetify3 as vuetify, html

from model_manager import ModelManager, model_type_dict
from model_manager import (
ModelManager,
is_model_available_on_mlflow,
load_model_from_mlflow_with_progress,
model_type_dict,
)
from outputs_manager import OutputManager
from optimization_manager import OptimizationManager
from parameters_manager import ParametersManager
Expand Down Expand Up @@ -52,6 +58,7 @@ def update(
reset_gui_route_nersc=True,
reset_gui_route_chat=True,
reset_gui_layout=True,
preloaded_model_manager=None,
**kwargs,
):
print("Updating...")
Expand All @@ -69,6 +76,9 @@ def update(
# derive execution mode from execution_mode in the experiment configuration file
execution_mode = config_dict.get("execution_mode") or {}
state.model_training_mode = execution_mode.get("ml_training", "local")
state.model_mlflow_tracking_uri = (config_dict.get("mlflow") or {}).get(
"tracking_uri"
)
db = load_database(config_dict)
exp_data, sim_data = load_data(db, state.experiment, state.experiment_date_range)
# reset output
Expand All @@ -79,10 +89,15 @@ def update(
cal_manager = SimulationCalibrationManager(simulation_calibration)
# reset model
if reset_model:
mod_manager = ModelManager(
config_dict=config_dict,
model_type=model_type_dict[state.model_type_verbose],
)
state.model_available = False
if preloaded_model_manager is None:
mod_manager = ModelManager(
config_dict=config_dict,
model_type=model_type_dict[state.model_type_verbose],
)
else:
mod_manager = preloaded_model_manager
state.model_available = mod_manager.avail()
opt_manager = OptimizationManager(mod_manager)
# reset parameters
if reset_parameters:
Expand Down Expand Up @@ -113,6 +128,67 @@ def update(
ctrl.figure_update(fig)


async def update_with_model_download_indicator(**update_kwargs):
    """Run a dashboard update with visible download feedback for large MLflow models.

    The model is loaded through ``load_model_from_mlflow_with_progress`` in a
    worker thread (``asyncio.to_thread``), so the coroutine awaits the result
    instead of blocking on it. Whatever was loaded (``None`` on failure) is
    wrapped in a ``ModelManager`` and handed to ``update`` through the
    ``preloaded_model_manager`` keyword.

    Args:
        **update_kwargs: Keyword arguments forwarded to ``update``. Any
            ``preloaded_model_manager`` entry is overwritten here.
    """
    # Snapshot the current selection so that a change made while the model is
    # being loaded can be detected afterwards.
    experiment = state.experiment
    model_type_verbose = state.model_type_verbose
    config_dict = load_config_dict(experiment)
    model_type = model_type_dict[model_type_verbose]

    # Publish the "downloading" state before the load starts.
    state.model_available = False
    state.model_mlflow_tracking_uri = (config_dict.get("mlflow") or {}).get(
        "tracking_uri"
    )
    state.model_downloading = True
    state.model_download_status = "Downloading model from MLflow..."
    state.model_download_progress = None
    state.flush()
    # Yield to the event loop briefly before starting the load
    # (presumably so the client can render the indicator -- TODO confirm).
    await asyncio.sleep(0.05)

    load_error = None
    try:
        loaded_model = await asyncio.to_thread(
            load_model_from_mlflow_with_progress,
            config_dict,
            model_type,
            # NOTE(review): the running loop is passed to the loader,
            # presumably so it can report progress back from the worker
            # thread -- confirm against model_manager.
            asyncio.get_running_loop(),
        )
    except Exception as e:
        # Best effort: report the failure below and continue without a model.
        loaded_model = None
        load_error = e

    # The selection changed while loading: drop this stale result. The
    # indicator state is intentionally left untouched on this path
    # (presumably the handler for the newer selection owns it -- verify).
    if state.experiment != experiment or state.model_type_verbose != model_type_verbose:
        return

    if load_error is not None:
        title = f"Unable to load model {model_type}"
        msg = f"Error occurred when loading model from MLflow: {load_error}"
        add_error(title, msg)
        print(msg)

    try:
        # Built inside the try block so that a failure while constructing the
        # manager still clears the download indicator in ``finally``.
        update_kwargs["preloaded_model_manager"] = ModelManager(
            config_dict=config_dict,
            model_type=model_type,
            loaded_model=loaded_model,
        )
        update(**update_kwargs)
    finally:
        state.model_downloading = False
        state.model_download_status = None
        state.model_download_progress = None
        state.flush()


# Strong references to in-flight model-update tasks. The event loop only keeps
# weak references to tasks, so a task nobody else references may be garbage
# collected before it finishes.
_background_tasks = set()


def update_model_selection(**update_kwargs):
    """Dispatch a dashboard update, loading the model asynchronously if needed.

    When the model is to be reset (``reset_model``, default ``True``) and
    ``is_model_available_on_mlflow`` reports it as available, the update is
    scheduled as a background task through
    ``update_with_model_download_indicator``. Otherwise the download indicator
    state is cleared and ``update`` runs synchronously.

    Args:
        **update_kwargs: Keyword arguments forwarded to ``update``.
    """
    config_dict = load_config_dict(state.experiment)
    model_type = model_type_dict[state.model_type_verbose]

    # ``and`` short-circuits: MLflow is only queried when a model reset is requested.
    load_from_mlflow = update_kwargs.get(
        "reset_model", True
    ) and is_model_available_on_mlflow(config_dict, model_type)

    if load_from_mlflow:
        task = asyncio.create_task(
            update_with_model_download_indicator(**update_kwargs)
        )
        # Hold a reference until the task completes, then release it.
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)
        return

    state.model_downloading = False
    state.model_download_status = None
    state.model_download_progress = None
    update(**update_kwargs)


@state.change(
"experiment",
"experiment_date_range",
Expand All @@ -128,17 +204,18 @@ def update(
"use_inferred_calibration",
)
def reset(**kwargs):
modified_keys = set(state.modified_keys)
# skip if triggered on server ready (all state variables marked as modified)
if len(state.modified_keys) == 1:
print(f"Reacting to state change in {state.modified_keys}...")
if len(modified_keys) == 1:
print(f"Reacting to state change in {modified_keys}...")
if any(
key in state.modified_keys
key in modified_keys
for key in [
"experiment",
"experiment_date_range",
]
):
update(
update_model_selection(
reset_model=True,
reset_output=True,
reset_parameters=True,
Expand All @@ -150,13 +227,13 @@ def reset(**kwargs):
reset_gui_layout=False,
)
elif any(
key in state.modified_keys
key in modified_keys
for key in [
"model_type_verbose",
"model_training_time",
]
):
update(
update_model_selection(
reset_model=True,
reset_output=False,
reset_parameters=False,
Expand All @@ -168,7 +245,7 @@ def reset(**kwargs):
reset_gui_layout=False,
)
elif any(
key in state.modified_keys
key in modified_keys
for key in [
"displayed_output",
"parameters",
Expand Down
Binary file added dashboard/logos/AmSC_300px.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading