diff --git a/dashboard/app.py b/dashboard/app.py index cb5108f..1278cee 100644 --- a/dashboard/app.py +++ b/dashboard/app.py @@ -1,3 +1,4 @@ +import asyncio from bson.objectid import ObjectId import os import re @@ -6,7 +7,12 @@ from trame.ui.vuetify3 import SinglePageWithDrawerLayout from trame.widgets import plotly, router, vuetify3 as vuetify, html -from model_manager import ModelManager, model_type_dict +from model_manager import ( + ModelManager, + is_model_available_on_mlflow, + load_model_from_mlflow_with_progress, + model_type_dict, +) from outputs_manager import OutputManager from optimization_manager import OptimizationManager from parameters_manager import ParametersManager @@ -52,6 +58,7 @@ def update( reset_gui_route_nersc=True, reset_gui_route_chat=True, reset_gui_layout=True, + preloaded_model_manager=None, **kwargs, ): print("Updating...") @@ -69,6 +76,9 @@ def update( # derive execution mode from execution_mode in the experiment configuration file execution_mode = config_dict.get("execution_mode") or {} state.model_training_mode = execution_mode.get("ml_training", "local") + state.model_mlflow_tracking_uri = (config_dict.get("mlflow") or {}).get( + "tracking_uri" + ) db = load_database(config_dict) exp_data, sim_data = load_data(db, state.experiment, state.experiment_date_range) # reset output @@ -79,10 +89,15 @@ def update( cal_manager = SimulationCalibrationManager(simulation_calibration) # reset model if reset_model: - mod_manager = ModelManager( - config_dict=config_dict, - model_type=model_type_dict[state.model_type_verbose], - ) + state.model_available = False + if preloaded_model_manager is None: + mod_manager = ModelManager( + config_dict=config_dict, + model_type=model_type_dict[state.model_type_verbose], + ) + else: + mod_manager = preloaded_model_manager + state.model_available = mod_manager.avail() opt_manager = OptimizationManager(mod_manager) # reset parameters if reset_parameters: @@ -113,6 +128,67 @@ def update( ctrl.figure_update(fig) +async def update_with_model_download_indicator(**update_kwargs): + """Run a dashboard update with visible download feedback for large MLflow models.""" + load_error = None + experiment = state.experiment + model_type_verbose = state.model_type_verbose + config_dict = load_config_dict(experiment) + model_type = model_type_dict[model_type_verbose] + state.model_available = False + state.model_mlflow_tracking_uri = (config_dict.get("mlflow") or {}).get( + "tracking_uri" + ) + state.model_downloading = True + state.model_download_status = "Downloading model from MLflow..." + state.model_download_progress = None + state.flush() + await asyncio.sleep(0.05) + try: + loaded_model = await asyncio.to_thread( + load_model_from_mlflow_with_progress, + config_dict, + model_type, + asyncio.get_running_loop(), + ) + except Exception as e: + loaded_model = None + load_error = e + if state.experiment != experiment or state.model_type_verbose != model_type_verbose: + return + if load_error is not None: + title = f"Unable to load model {model_type}" + msg = f"Error occurred when loading model from MLflow: {load_error}" + add_error(title, msg) + print(msg) + update_kwargs["preloaded_model_manager"] = ModelManager( + config_dict=config_dict, + model_type=model_type, + loaded_model=loaded_model, + ) + try: + update(**update_kwargs) + finally: + state.model_downloading = False + state.model_download_status = None + state.model_download_progress = None + state.flush() + + +def update_model_selection(**update_kwargs): + config_dict = load_config_dict(state.experiment) + model_type = model_type_dict[state.model_type_verbose] + if update_kwargs.get("reset_model", True) and is_model_available_on_mlflow( + config_dict, model_type + ): + asyncio.create_task(update_with_model_download_indicator(**update_kwargs)) + else: + state.model_downloading = False + state.model_download_status = None + state.model_download_progress = None + update(**update_kwargs) + + @state.change( "experiment", "experiment_date_range", @@ -128,17 +204,18 @@ def update( "use_inferred_calibration", ) def reset(**kwargs): + modified_keys = set(state.modified_keys) # skip if triggered on server ready (all state variables marked as modified) - if len(state.modified_keys) == 1: - print(f"Reacting to state change in {state.modified_keys}...") + if len(modified_keys) == 1: + print(f"Reacting to state change in {modified_keys}...") if any( - key in state.modified_keys + key in modified_keys for key in [ "experiment", "experiment_date_range", ] ): - update( + update_model_selection( reset_model=True, reset_output=True, reset_parameters=True, @@ -150,13 +227,13 @@ def reset(**kwargs): reset_gui_layout=False, ) elif any( - key in state.modified_keys + key in modified_keys for key in [ "model_type_verbose", "model_training_time", ] ): - update( + update_model_selection( reset_model=True, reset_output=False, reset_parameters=False, @@ -168,7 +245,7 @@ def reset(**kwargs): reset_gui_layout=False, ) elif any( - key in state.modified_keys + key in modified_keys for key in [ "displayed_output", "parameters", diff --git a/dashboard/logos/AmSC_300px.png b/dashboard/logos/AmSC_300px.png new file mode 100644 index 0000000..a002c3a Binary files /dev/null and b/dashboard/logos/AmSC_300px.png differ diff --git a/dashboard/model_manager.py b/dashboard/model_manager.py index 6f8095a..b7fc08f 100644 --- a/dashboard/model_manager.py +++ b/dashboard/model_manager.py @@ -1,4 +1,5 @@ import asyncio +from contextlib import contextmanager from datetime import datetime from pathlib import Path import tempfile @@ -6,22 +7,166 @@ import yaml import re import mlflow +import mlflow.store.artifact.artifact_repo as mlflow_artifact_repo +import mlflow.store.artifact.cloud_artifact_repo as mlflow_cloud_artifact_repo +import mlflow.utils.file_utils as mlflow_file_utils +from mlflow.exceptions import MlflowException +from trame.assets.local import LocalFileManager from sfapi_client import AsyncClient from sfapi_client.compute import Machine -from trame.widgets import vuetify3 as vuetify +from trame.widgets import vuetify3 as vuetify, html from utils import timer, load_config_dict, create_date_filter from calibration_manager import build_inferred_calibration from error_manager import add_error from sfapi_manager import monitor_sfapi_job from state_manager import state +LOGO_DIR = Path(__file__).parent / "logos" +AMSC_MLFLOW_URL = "https://mlflow.american-science-cloud.org" +MODEL_TYPE_GP = "Gaussian Process" +MODEL_TYPE_NN_SINGLE = "Neural Network (single)" +MODEL_TYPE_NN_ENSEMBLE = "Neural Network (ensemble)" +AMSC_LOGO_PATH = LOGO_DIR / "AmSC_300px.png" +AMSC_LOGO_URL = ( + LocalFileManager(LOGO_DIR).url("amsc_logo", AMSC_LOGO_PATH) + if AMSC_LOGO_PATH.is_file() + else None +) +MODEL_DOWNLOAD_ACTIVE_EXPR = "model_downloading" +AMSC_MLFLOW_LINK_ACTIVE_EXPR = ( + f"model_available && model_mlflow_tracking_uri === '{AMSC_MLFLOW_URL}'" +) +AMSC_MLFLOW_MODEL_URL_EXPR = ( + f"'{AMSC_MLFLOW_URL}/#/models/synapse-' + experiment + '_' + " + f"(model_type_verbose === '{MODEL_TYPE_GP}' ? 'GP' : " + f"model_type_verbose === '{MODEL_TYPE_NN_SINGLE}' ? 'NN' : " + f"model_type_verbose === '{MODEL_TYPE_NN_ENSEMBLE}' ? 'ensemble_NN' : " + "model_type_verbose)" +) + model_type_dict = { - "Gaussian Process": "GP", - "Neural Network (single)": "NN", - "Neural Network (ensemble)": "ensemble_NN", + MODEL_TYPE_GP: "GP", + MODEL_TYPE_NN_SINGLE: "NN", + MODEL_TYPE_NN_ENSEMBLE: "ensemble_NN", } +_NO_PRELOADED_MODEL = object() + + +def build_mlflow_model_name(config_dict, model_type): + """Return the registered MLflow model name for an experiment and model type.""" + return f"synapse-{config_dict['experiment']}_{model_type}" + + +def configure_mlflow_tracking(config_dict): + """Configure MLflow tracking for an experiment when MLflow is available.""" + mlflow_cfg = config_dict.get("mlflow") or {} + tracking_uri = mlflow_cfg.get("tracking_uri") + if not tracking_uri: + msg = ( + "No mlflow.tracking_uri in configuration file for " + f"{config_dict['experiment']}; cannot load model from MLflow." + ) + print(msg) + return False + + mlflow.set_tracking_uri(tracking_uri) + # When using the AmSC MLflow, inject the X-Api-Key to authenticate. + # (See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch) + if tracking_uri == AMSC_MLFLOW_URL: + enable_amsc_x_api_key(config_dict) + return True + + +def load_model_from_mlflow(config_dict, model_type): + """Load the latest registered MLflow model for an experiment configuration.""" + if not configure_mlflow_tracking(config_dict): + return None + + model_name = build_mlflow_model_name(config_dict, model_type) + return ( + mlflow.pyfunc.load_model(f"models:/{model_name}/latest") + .unwrap_python_model() + .model + ) + + +def is_model_available_on_mlflow(config_dict, model_type): + """Return whether MLflow has a registered model version to download.""" + if not configure_mlflow_tracking(config_dict): + return False + + model_name = build_mlflow_model_name(config_dict, model_type) + try: + versions = mlflow.MlflowClient().search_model_versions( + f"name='{model_name}'", + max_results=1, + ) + except MlflowException as e: + if e.error_code == "RESOURCE_DOES_NOT_EXIST": + return False + print(f"Unable to check MLflow model availability for {model_name}: {e}") + return False + return bool(versions) + + +@contextmanager +def mlflow_artifact_progress_to_state(loop): + """Expose MLflow artifact download progress through dashboard state.""" + progress_bar_modules = [ + mlflow_file_utils, + mlflow_artifact_repo, + mlflow_cloud_artifact_repo, + ] + original_progress_bars = { + module: module.ArtifactProgressBar for module in progress_bar_modules + } + original_progress_bar = mlflow_file_utils.ArtifactProgressBar + + def set_download_progress(progress, total): + """Publish the current download completion percentage to the GUI.""" + + def update_progress_state(): + if total: + state.model_download_progress = min(100, progress / total * 100) + else: + state.model_download_progress = None + state.flush() + + loop.call_soon_threadsafe(update_progress_state) + + class TrameArtifactProgressBar(original_progress_bar): + def __init__(self, desc, total, step, **kwargs): + super().__init__(desc, total, step, **kwargs) + self.trame_progress = 0 + if desc.startswith("Downloading"): + set_download_progress(self.trame_progress, self.total) + + def update(self): + super().update() + self.trame_progress = min( + self.total, + self.trame_progress + self.step, + ) + if self.desc.startswith("Downloading"): + set_download_progress(self.trame_progress, self.total) + + for module in progress_bar_modules: + module.ArtifactProgressBar = TrameArtifactProgressBar + try: + yield + finally: + for module, progress_bar in original_progress_bars.items(): + module.ArtifactProgressBar = progress_bar + + +def load_model_from_mlflow_with_progress(config_dict, model_type, loop): + """Load an MLflow model while reporting artifact download progress.""" + with mlflow_artifact_progress_to_state(loop): + return load_model_from_mlflow(config_dict, model_type) + + def enable_amsc_x_api_key(config_dict): """ MLflow authentication helper for the AmSC MLflow server. @@ -49,7 +194,10 @@ def enable_amsc_x_api_key(config_dict): add_error(title, msg) print(msg) return - _orig = rest_utils.http_request + if getattr(rest_utils.http_request, "_synapse_amsc_api_key", None) == api_key: + return + + _orig = getattr(rest_utils, "_synapse_http_request", rest_utils.http_request) def patched(host_creds, endpoint, method, *args, **kwargs): if "headers" in kwargs and kwargs["headers"] is not None: @@ -62,40 +210,25 @@ def patched(host_creds, endpoint, method, *args, **kwargs): kwargs["extra_headers"] = h return _orig(host_creds, endpoint, method, *args, **kwargs) + patched._synapse_amsc_api_key = api_key + rest_utils._synapse_http_request = _orig rest_utils.http_request = patched class ModelManager: - def __init__(self, config_dict, model_type): + def __init__(self, config_dict, model_type, loaded_model=_NO_PRELOADED_MODEL): print("Initializing model manager...") self.__model = None self.__model_type = model_type - if "mlflow" not in config_dict or not config_dict["mlflow"].get("tracking_uri"): - print( - f"No mlflow.tracking_uri in configuration file for {config_dict['experiment']}; cannot load model from MLflow." - ) - return - - mlflow.set_tracking_uri(config_dict["mlflow"]["tracking_uri"]) - # When using the AmSC MLflow: inject the X-Api-Key into the requests to authenticate with the MLflow server - # (See https://gitlab.com/amsc2/ai-services/model-services/intro-to-mlflow-pytorch) - if ( - config_dict["mlflow"]["tracking_uri"] - == "https://mlflow.american-science-cloud.org" - ): - enable_amsc_x_api_key(config_dict) - - experiment = config_dict["experiment"] - model_name = f"synapse-{experiment}_{model_type}" - try: - # Download model from MLflow server self.__model = ( - mlflow.pyfunc.load_model(f"models:/{model_name}/latest") - .unwrap_python_model() - .model + load_model_from_mlflow(config_dict, model_type) + if loaded_model is _NO_PRELOADED_MODEL + else loaded_model ) + if self.__model is None: + return if model_type not in ("NN", "ensemble_NN", "GP"): raise ValueError(f"Unsupported model type: {model_type}") # Populate inferred calibration in physics units for GUI @@ -353,37 +486,113 @@ def panel(self): print("Setting model card...") # list of available model types model_type_list = [ - "Gaussian Process", - "Neural Network (single)", - "Neural Network (ensemble)", + MODEL_TYPE_GP, + MODEL_TYPE_NN_SINGLE, + MODEL_TYPE_NN_ENSEMBLE, ] + model_type_cols = 8 if AMSC_LOGO_URL else 12 with vuetify.VExpansionPanels(v_model=("expand_panel_control_model", 0)): with vuetify.VExpansionPanel( title="Control: Models", style="font-size: 20px; font-weight: 500;", ): with vuetify.VExpansionPanelText(): - with vuetify.VRow(): - with vuetify.VCol(): + with vuetify.VRow(align="center"): + with vuetify.VCol(cols=model_type_cols): vuetify.VSelect( v_model=("model_type_verbose",), label="Model type", items=(model_type_list,), dense=True, ) - with vuetify.VCol(): - vuetify.VTextField( - v_model_number=("model_training_status",), - label="Training status", - readonly=True, + if AMSC_LOGO_URL: + with vuetify.VCol( + cols=4, + classes="d-flex align-center justify-end", + ): + with html.A( + v_if=(AMSC_MLFLOW_LINK_ACTIVE_EXPR,), + href=(AMSC_MLFLOW_MODEL_URL_EXPR,), + target="_blank", + rel="noopener noreferrer", + title="Open selected model in AmSC MLflow", + style=( + "display: block; width: 100%; " + "max-width: 300px; margin-left: auto; " + "cursor: pointer;" + ), + ): + vuetify.VImg( + src=AMSC_LOGO_URL, + alt="AmSC", + max_width=300, + max_height=72, + contain=True, + style="width: 100%;", + ) + vuetify.VImg( + v_if=(f"!({AMSC_MLFLOW_LINK_ACTIVE_EXPR})",), + src=AMSC_LOGO_URL, + alt="AmSC", + max_width=300, + max_height=72, + contain=True, + title=( + "Selected model is not available in AmSC MLflow" + ), + style=( + "width: 100%; max-width: 300px; " + "margin-left: auto;" + ), + ) + with vuetify.VRow( + v_if=(MODEL_DOWNLOAD_ACTIVE_EXPR,), + no_gutters=True, + align="center", + style="margin-top: -8px; margin-bottom: 8px;", + ): + with vuetify.VCol(cols=model_type_cols): + with html.Div( + classes=( + "d-flex align-center text-caption " + "text-medium-emphasis mb-1" + ) + ): + vuetify.VIcon( + "mdi-cloud-download-outline", + size=16, + classes="mr-1", + ) + html.Span(v_text=("model_download_status",)) + vuetify.VSpacer() + html.Span( + v_if=("model_download_progress !== null",), + v_text=( + "`${Math.round(model_download_progress)}%`", + ), + ) + vuetify.VProgressLinear( + indeterminate=("model_download_progress === null",), + model_value=("model_download_progress",), + color="primary", + height=4, + rounded=True, ) - with vuetify.VRow(): - with vuetify.VCol(): + with vuetify.VRow(align="center"): + with vuetify.VCol(cols="auto"): vuetify.VBtn( "Train", click=self.training_trigger, disabled=( - "model_training || (model_training_mode === 'sfapi' && perlmutter_status !== 'active')", + "model_training || " + "(model_training_mode === 'sfapi' && " + "perlmutter_status !== 'active')", ), style="text-transform: none", ) + with vuetify.VCol(cols=6, style="margin-left: auto;"): + vuetify.VTextField( + v_model_number=("model_training_status",), + label="Training status", + readonly=True, + ) diff --git a/dashboard/state_manager.py b/dashboard/state_manager.py index 8d8d31e..1068096 100644 --- a/dashboard/state_manager.py +++ b/dashboard/state_manager.py @@ -32,6 +32,11 @@ def initialize_state(): state.model_training_mode = "local" state.model_training_status = None state.model_training_time = None + state.model_available = False + state.model_downloading = False + state.model_download_status = None + state.model_download_progress = None + state.model_mlflow_tracking_uri = None # Optimization state.optimization_type = "Maximize" state.optimization_status = None