diff --git a/.claude/skills/test-ml/SKILL.md b/.claude/skills/test-ml/SKILL.md new file mode 100644 index 00000000..0627b57a --- /dev/null +++ b/.claude/skills/test-ml/SKILL.md @@ -0,0 +1,59 @@ +--- +name: test-ml +description: Train ML models (GP, NN, ensemble_NN), populate MLflow, and verify they load and evaluate correctly. Use when testing the ML training and inference pipeline end-to-end. +--- + +# Test ML Models + +Train models with `train_model.py`, then verify they load and evaluate via the `check_model.py` script bundled with this skill (`.claude/skills/test-ml/check_model.py`). + +## Parse user intent + +- If `$ARGUMENTS` specifies a model type (GP, NN, or ensemble_NN), only test that model type. +- If `$ARGUMENTS` specifies an experiment name, use `experiments/synapse-/config.yaml`. +- If no model type is given, test all three: GP, NN, ensemble_NN. +- If no experiment is given, auto-detect available experiments by globbing `experiments/synapse-*/config.yaml`. Only use experiments whose config contains an `mlflow` section with a `tracking_uri`. + +## Pre-flight checks (do these BEFORE any training) + +Run all checks and **stop immediately with a clear warning** if any fail: + +1. **Conda environment**: Verify `synapse-ml` conda environment exists: + ``` + source ~/miniconda3/etc/profile.d/conda.sh && conda activate synapse-ml + ``` + +2. **Database credentials**: Source `~/db.profile`, then read the config YAML and check that the environment variable named in `database.password_ro_env` is set. + +3. **MLflow server**: Read `mlflow.tracking_uri` from the config YAML. Test reachability with a quick socket connection (e.g. `python -c "import socket; socket.create_connection(('', ), timeout=5)"`). If unreachable, tell the user to start the MLflow server. + +4. **MongoDB**: Test connectivity with a quick pymongo ping: + ``` + python -c "import pymongo; pymongo.MongoClient(host='', port=, serverSelectionTimeoutMS=5000).admin.command('ping')" + ``` + Note: for the MongoDB check, source `~/db.profile` first, then read the database credentials from the config file (host, port, auth, username_ro, password_ro_env). + +## Train models + +Run training from the `ml/` directory. Each command: +``` +source ~/db.profile && source ~/miniconda3/etc/profile.d/conda.sh && conda activate synapse-ml && cd /ml && python train_model.py --config_file --model +``` + +- When testing **multiple model types**, run the training commands **in parallel** (use background bash tasks). +- When testing a **single model type**, run it in the foreground. +- Wait for all training to complete before proceeding. + +## Validate models + +After training, run the bundled `check_model.py` for each model type: +``` +source ~/miniconda3/etc/profile.d/conda.sh && conda activate synapse-ml && python /.claude/skills/test-ml/check_model.py --config_file --model +``` + +## Report results + +Provide a summary table: +- Model type | Train status | Validation status +- Show `[PASS]` or `[FAIL]` for each +- For failures, include the error message diff --git a/.claude/skills/test-ml/check_model.py b/.claude/skills/test-ml/check_model.py new file mode 100644 index 00000000..6cbf8c2b --- /dev/null +++ b/.claude/skills/test-ml/check_model.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python +""" +Check that a model stored in MLflow loads and evaluates correctly, +using the same download logic as the dashboard. + +Usage: + python check_model.py --config_file --model +""" + +import argparse +import os +import socket +import sys +from pathlib import Path +from urllib.parse import urlparse +import pandas as pd +import torch +import yaml +import mlflow + +# Import DB connection helper from train_model.py +_ML_DIR = Path(__file__).resolve().parents[3] / "ml" +sys.path.insert(0, str(_ML_DIR)) +from train_model import connect_to_db + + +MODEL_TYPES = ["GP", "NN", "ensemble_NN"] + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description="Verify that an MLflow model loads and predicts with default parameters." + ) + parser.add_argument( + "--config_file", + help="Path to the configuration file", + type=str, + required=True, + ) + parser.add_argument( + "--model", + help="Model type: GP, NN, or ensemble_NN", + choices=MODEL_TYPES, + required=True, + ) + args = parser.parse_args() + print(f"Config file: {args.config_file}, Model type: {args.model}") + return args.config_file, args.model + + +def load_config(config_file): + if not os.path.exists(config_file): + raise RuntimeError(f"Configuration file not found: {config_file}") + with open(config_file) as f: + return yaml.safe_load(f.read()) + + +def enable_amsc_x_api_key(config_dict): + """Inject AmSC X-Api-Key header into all MLflow requests (mirrors train_model.py).""" + import mlflow.utils.rest_utils as rest_utils + + mlflow_cfg = config_dict.get("mlflow") or {} + api_key_env = mlflow_cfg.get("api_key_env") + if not api_key_env: + raise KeyError( + "Missing 'api_key_env' in 'mlflow' configuration for AmSC authentication." + ) + api_key = os.getenv(api_key_env) + if api_key is None: + raise KeyError( + f"Environment variable '{api_key_env}' (from mlflow.api_key_env) is not set." + ) + + _orig = rest_utils.http_request + + def patched(host_creds, endpoint, method, *args, **kwargs): + if "headers" in kwargs and kwargs["headers"] is not None: + h = dict(kwargs["headers"]) + h["X-Api-Key"] = api_key + kwargs["headers"] = h + else: + h = dict(kwargs.get("extra_headers") or {}) + h["X-Api-Key"] = api_key + kwargs["extra_headers"] = h + return _orig(host_creds, endpoint, method, *args, **kwargs) + + rest_utils.http_request = patched + + +def check_server_reachable(tracking_uri, timeout=5): + """Quick socket check to fail fast if the MLflow server is unreachable.""" + parsed = urlparse(tracking_uri) + host = parsed.hostname + port = parsed.port or (443 if parsed.scheme == "https" else 80) + try: + with socket.create_connection((host, port), timeout=timeout): + pass + print(f"MLflow server reachable at {host}:{port}") + except OSError as e: + raise RuntimeError( + f"MLflow server at {tracking_uri} is not reachable: {e}" + ) from e + + +def download_model(config_dict, model_type): + """Download the model from MLflow, exactly as the dashboard does.""" + if "mlflow" not in config_dict or not config_dict["mlflow"].get("tracking_uri"): + raise RuntimeError( + "No mlflow.tracking_uri found in config file; cannot load model from MLflow." + ) + + tracking_uri = config_dict["mlflow"]["tracking_uri"] + check_server_reachable(tracking_uri) + mlflow.set_tracking_uri(tracking_uri) + print(f"MLflow tracking URI: {tracking_uri}") + + # Mirror dashboard authentication logic + if tracking_uri == "https://mlflow.american-science-cloud.org": + enable_amsc_x_api_key(config_dict) + + experiment = config_dict["experiment"] + model_name = f"{experiment}_{model_type}" + model_uri = f"models:/{model_name}/latest" + print(f"Downloading model '{model_uri}' ...") + + # Same download command as in the dashboard (model_manager.py) + model = mlflow.pyfunc.load_model(model_uri).unwrap_python_model().model + print(f"Model downloaded successfully: {type(model).__name__}") + return model + + +def load_experimental_inputs(config_dict): + """Fetch all experimental points from the database and return as a batch input dict.""" + experiment = config_dict["experiment"] + input_variables = config_dict["inputs"] + input_names = [v["name"] for v in input_variables.values()] + + db = connect_to_db(config_dict) + date_filter = config_dict.get("date_filter", {}) + df_exp = pd.DataFrame(db[experiment].find({"experiment_flag": 1, **date_filter})) + + if df_exp.empty: + raise RuntimeError("No experimental points found in the database.") + + missing = [name for name in input_names if name not in df_exp.columns] + if missing: + raise RuntimeError(f"Missing input columns in experimental data: {missing}") + + print(f"Fetched {len(df_exp)} experimental points from the database.") + return { + name: torch.tensor(df_exp[name].values, dtype=torch.float64) + for name in input_names + } + + +def check_evaluate(model, config_dict): + """Call evaluate() with experimental data fetched from the database.""" + inputs = load_experimental_inputs(config_dict) + print( + f"Calling model.evaluate() with {len(next(iter(inputs.values())))} experimental points..." + ) + result = model.evaluate(inputs) + print("evaluate() succeeded.") + print(f"Output keys: {list(result.keys())}") + return result + + +if __name__ == "__main__": + config_file, model_type = parse_arguments() + + # Load configuration + config_dict = load_config(config_file) + print(f"Experiment: {config_dict['experiment']}") + + # Download model from MLflow + try: + model = download_model(config_dict, model_type) + except Exception as e: + print(f"[FAIL] Could not download model: {e}") + sys.exit(1) + + # Evaluate with default parameters + try: + check_evaluate(model, config_dict) + except Exception as e: + print(f"[FAIL] evaluate() raised an error: {e}") + sys.exit(1) + + print("[PASS] Model loaded and evaluated successfully.") diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..33e88d7e --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,67 @@ +# CLAUDE.md + +This file provides guidance to AI agents when working with code in this repository. + +## Project Overview + +Synapse is a framework that couples experimental data, simulations, and machine learning models through a web-based dashboard deployed at NERSC. + +## Architecture + +- **`dashboard/`** — Trame (Vue.js/Flask) web application. `app.py` is the main entry point. Functionality is split into manager modules: `model_manager.py` (ML models), `parameters_manager.py` (parameter handling), `calibration_manager.py` (simulation calibration), `sfapi_manager.py` (NERSC Superfacility API), `optimization_manager.py` (Bayesian optimization), `state_manager.py`, `error_manager.py`, `outputs_manager.py`. `utils.py` handles data loading. +- **`ml/`** — PyTorch training pipeline. `train_model.py` is the main script, `Neural_Net_Classes.py` defines network architectures. Supports GP, NN, and ensemble models. +- **`experiments/`** — YAML config files for each BELLA experiment (inputs, outputs, calibration variables). +- **`dashboard.Dockerfile`** / **`ml.Dockerfile`** — Container images for GUI and ML training respectively. +- **`publish_container.py`** — Builds and pushes containers to NERSC registry. + +## Key Commands + +### Linting +```bash +pre-commit run --all-files # Ruff linting + formatting +``` + +### Running the Dashboard Locally +```bash +conda-lock install --name synapse-gui dashboard/environment-lock.yml +conda activate synapse-gui +python -u dashboard/app.py --port 8080 +``` +Requires MongoDB access via `SF_DB_HOST` and `SF_DB_READONLY_PASSWORD` environment variables. + +### Running ML Training Locally +```bash +conda-lock install --name synapse-ml ml/environment-lock.yml +conda activate synapse-ml +python ml/train_model.py --test --model NN --config_file config.yaml +``` + +### Testing ML Models End-to-End +Use the `/test-ml` skill (defined in `.claude/skills/test-ml/`) to train models, populate MLflow, and verify they load and evaluate correctly. It handles pre-flight checks (conda env, DB, MLflow server), parallel training, and validation. + +### Docker +```bash +# Build +docker build --platform linux/amd64 --output type=image,oci-mediatypes=true -t synapse-gui -f dashboard.Dockerfile . +docker build --platform linux/amd64 --output type=image,oci-mediatypes=true -t synapse-ml -f ml.Dockerfile . + +# Publish to NERSC +python publish_container.py --gui --yes +python publish_container.py --ml --yes +``` + +## Key Dependencies + +- **Dashboard**: Trame, Flask, PyTorch, PyMongo, Plotly, BoTorch +- **ML Training**: PyTorch, GPyTorch, BoTorch, LUME-Model, PyMongo +- **Environments**: Managed via conda-lock (`environment-lock.yml` in each directory) + +## Database + +MongoDB stores experimental/simulation data and ML models. Dashboard uses read-only access (`SF_DB_READONLY_PASSWORD`), training uses admin access (`SF_DB_ADMIN_PASSWORD`). + +## Deployment + +- Dashboard deployed at NERSC Spin (bellasuperfacility.lbl.gov) +- ML training runs on Perlmutter HPC, triggered via Superfacility API from the dashboard +- Container registry: `registry.nersc.gov/m558/superfacility/` diff --git a/CLAUDE.md b/CLAUDE.md new file mode 120000 index 00000000..47dc3e3d --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +AGENTS.md \ No newline at end of file