59 changes: 59 additions & 0 deletions .claude/skills/test-ml/SKILL.md
@@ -0,0 +1,59 @@
---
name: test-ml
description: Train ML models (GP, NN, ensemble_NN), populate MLflow, and verify they load and evaluate correctly. Use when testing the ML training and inference pipeline end-to-end.
---

# Test ML Models

Train models with `train_model.py`, then verify they load and evaluate via the `check_model.py` script bundled with this skill (`.claude/skills/test-ml/check_model.py`).

## Parse user intent

- If `$ARGUMENTS` specifies a model type (GP, NN, or ensemble_NN), only test that model type.
- If `$ARGUMENTS` specifies an experiment name, use `experiments/synapse-<name>/config.yaml`.
- If no model type is given, test all three: GP, NN, ensemble_NN.
- If no experiment is given, auto-detect available experiments by globbing `experiments/synapse-*/config.yaml`. Only use experiments whose config contains an `mlflow` section with a `tracking_uri`.
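
The auto-detection rule above can be sketched in Python as follows. This is a minimal illustration rather than part of the skill itself: the function names are hypothetical, and it assumes PyYAML is available in the active environment.

```python
from pathlib import Path


def has_mlflow_tracking_uri(config):
    """Return True if the parsed config has a non-empty mlflow.tracking_uri."""
    mlflow_cfg = (config or {}).get("mlflow") or {}
    return bool(mlflow_cfg.get("tracking_uri"))


def find_experiment_configs(repo_root):
    """Glob experiments/synapse-*/config.yaml under repo_root (sorted for a stable order)."""
    return sorted(Path(repo_root).glob("experiments/synapse-*/config.yaml"))


def select_testable_configs(repo_root):
    """Keep only the configs that declare an MLflow tracking URI."""
    import yaml  # PyYAML; imported here so the helpers above work without it

    selected = []
    for path in find_experiment_configs(repo_root):
        with open(path) as f:
            if has_mlflow_tracking_uri(yaml.safe_load(f)):
                selected.append(path)
    return selected
```

Calling `select_testable_configs(".")` from the repository root would return the config paths that qualify for testing; an empty list means no experiment is configured for MLflow.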

## Pre-flight checks (do these BEFORE any training)

Run all checks and **stop immediately with a clear warning** if any fail:

1. **Conda environment**: Verify that the `synapse-ml` conda environment exists by activating it:
```
source ~/miniconda3/etc/profile.d/conda.sh && conda activate synapse-ml
```

2. **Database credentials**: Source `~/db.profile`, then read the config YAML and check that the environment variable named in `database.password_ro_env` is set.

3. **MLflow server**: Read `mlflow.tracking_uri` from the config YAML. Test reachability with a quick socket connection (e.g. `python -c "import socket; socket.create_connection(('<host>', <port>), timeout=5)"`). If unreachable, tell the user to start the MLflow server.

4. **MongoDB**: Test connectivity with a quick pymongo ping:
```
python -c "import pymongo; pymongo.MongoClient(host='<host>', port=<port>, serverSelectionTimeoutMS=5000).admin.command('ping')"
```
Note: for the MongoDB check, source `~/db.profile` first, then read the database credentials from the config file (host, port, auth, username_ro, password_ro_env).
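
As a rough sketch of checks 2 and 3, the helpers below look up the password variable named in the config and derive the host and port for the socket test. The function names are illustrative only; the port defaults (443 for `https`, 80 otherwise) are assumed to match the logic in `check_model.py`.

```python
import os
import socket
from urllib.parse import urlparse


def missing_password_env(config, environ=os.environ):
    """Return the name of the unset read-only password variable, or None if it is set."""
    env_name = config["database"]["password_ro_env"]
    return None if environ.get(env_name) else env_name


def host_and_port(tracking_uri):
    """Split a tracking URI into (host, port), defaulting to 443 for https and 80 otherwise."""
    parsed = urlparse(tracking_uri)
    port = parsed.port or (443 if parsed.scheme == "https" else 80)
    return parsed.hostname, port


def is_reachable(host, port, timeout=5):
    """Attempt a TCP connection; True on success, False on any socket error."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False
```

If `missing_password_env(config)` returns a name, or `is_reachable(*host_and_port(uri))` returns `False`, the pre-flight phase should stop and report which check failed.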

## Train models

Run training from the `ml/` directory, using the following command for each model type:
```
source ~/db.profile && source ~/miniconda3/etc/profile.d/conda.sh && conda activate synapse-ml && cd <repo_root>/ml && python train_model.py --config_file <config_path> --model <model_type>
```

- When testing **multiple model types**, run the training commands **in parallel** (use background bash tasks).
- When testing a **single model type**, run it in the foreground.
- Wait for all training to complete before proceeding.
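
For illustration, the parallel-then-wait pattern described above could look like this in Python. It is a sketch under assumptions: the helper names are hypothetical, and it presumes the `synapse-ml` environment is already active and `~/db.profile` has been sourced in the calling shell.

```python
import subprocess
import sys


def build_train_command(config_path, model_type, python=sys.executable):
    """Argument list for one training run (flags taken from the command above)."""
    return [python, "train_model.py", "--config_file", str(config_path), "--model", model_type]


def run_in_parallel(commands, cwd=None):
    """Start every command at once, wait for all of them, and return their exit codes."""
    processes = [subprocess.Popen(cmd, cwd=cwd) for cmd in commands]
    return [p.wait() for p in processes]
```

A non-zero exit code in the returned list marks a failed training run for the corresponding model type; validation should start only after every process has returned.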

## Validate models

After training, run the bundled `check_model.py` for each model type:
```
source ~/miniconda3/etc/profile.d/conda.sh && conda activate synapse-ml && python <repo_root>/.claude/skills/test-ml/check_model.py --config_file <config_path> --model <model_type>
```

## Report results

Provide a summary table:
- Model type | Train status | Validation status
- Show `[PASS]` or `[FAIL]` for each
- For failures, include the error message
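
One possible way to assemble the summary is sketched below. The `results` structure (keys `model`, `train_ok`, `validate_ok`, and an optional `error`) is a hypothetical convention chosen for this example, not something defined elsewhere in the skill.

```python
def status(ok):
    """Map a boolean outcome to the [PASS]/[FAIL] marker used in the report."""
    return "[PASS]" if ok else "[FAIL]"


def format_report(results):
    """Build the summary table; each result may carry an optional 'error' message."""
    lines = ["Model type | Train status | Validation status"]
    for r in results:
        lines.append(f"{r['model']} | {status(r['train_ok'])} | {status(r['validate_ok'])}")
        if r.get("error"):
            # Failures are followed by an indented line with the error message.
            lines.append(f"  error: {r['error']}")
    return "\n".join(lines)
```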
189 changes: 189 additions & 0 deletions .claude/skills/test-ml/check_model.py
@@ -0,0 +1,189 @@
#!/usr/bin/env python
"""
Check that a model stored in MLflow loads and evaluates correctly,
using the same download logic as the dashboard.

Usage:
python check_model.py --config_file <path/to/config.yaml> --model <GP|NN|ensemble_NN>
"""

import argparse
import os
import socket
import sys
from pathlib import Path
from urllib.parse import urlparse
import pandas as pd
import torch
import yaml
import mlflow

# Import DB connection helper from train_model.py
_ML_DIR = Path(__file__).resolve().parents[3] / "ml"
sys.path.insert(0, str(_ML_DIR))
from train_model import connect_to_db


MODEL_TYPES = ["GP", "NN", "ensemble_NN"]


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Verify that an MLflow model loads and evaluates on experimental data."
    )
    parser.add_argument(
        "--config_file",
        help="Path to the configuration file",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--model",
        help="Model type: GP, NN, or ensemble_NN",
        choices=MODEL_TYPES,
        required=True,
    )
    args = parser.parse_args()
    print(f"Config file: {args.config_file}, Model type: {args.model}")
    return args.config_file, args.model


def load_config(config_file):
    if not os.path.exists(config_file):
        raise RuntimeError(f"Configuration file not found: {config_file}")
    with open(config_file) as f:
        return yaml.safe_load(f.read())


def enable_amsc_x_api_key(config_dict):
    """Inject AmSC X-Api-Key header into all MLflow requests (mirrors train_model.py)."""
    import mlflow.utils.rest_utils as rest_utils

    mlflow_cfg = config_dict.get("mlflow") or {}
    api_key_env = mlflow_cfg.get("api_key_env")
    if not api_key_env:
        raise KeyError(
            "Missing 'api_key_env' in 'mlflow' configuration for AmSC authentication."
        )
    api_key = os.getenv(api_key_env)
    if api_key is None:
        raise KeyError(
            f"Environment variable '{api_key_env}' (from mlflow.api_key_env) is not set."
        )

    _orig = rest_utils.http_request

    def patched(host_creds, endpoint, method, *args, **kwargs):
        if "headers" in kwargs and kwargs["headers"] is not None:
            h = dict(kwargs["headers"])
            h["X-Api-Key"] = api_key
            kwargs["headers"] = h
        else:
            h = dict(kwargs.get("extra_headers") or {})
            h["X-Api-Key"] = api_key
            kwargs["extra_headers"] = h
        return _orig(host_creds, endpoint, method, *args, **kwargs)

    rest_utils.http_request = patched


def check_server_reachable(tracking_uri, timeout=5):
    """Quick socket check to fail fast if the MLflow server is unreachable."""
    parsed = urlparse(tracking_uri)
    host = parsed.hostname
    port = parsed.port or (443 if parsed.scheme == "https" else 80)
    try:
        with socket.create_connection((host, port), timeout=timeout):
            pass
        print(f"MLflow server reachable at {host}:{port}")
    except OSError as e:
        raise RuntimeError(
            f"MLflow server at {tracking_uri} is not reachable: {e}"
        ) from e


def download_model(config_dict, model_type):
    """Download the model from MLflow, exactly as the dashboard does."""
    if "mlflow" not in config_dict or not config_dict["mlflow"].get("tracking_uri"):
        raise RuntimeError(
            "No mlflow.tracking_uri found in config file; cannot load model from MLflow."
        )

    tracking_uri = config_dict["mlflow"]["tracking_uri"]
    check_server_reachable(tracking_uri)
    mlflow.set_tracking_uri(tracking_uri)
    print(f"MLflow tracking URI: {tracking_uri}")

    # Mirror dashboard authentication logic
    if tracking_uri == "https://mlflow.american-science-cloud.org":
        enable_amsc_x_api_key(config_dict)

    experiment = config_dict["experiment"]
    model_name = f"{experiment}_{model_type}"
    model_uri = f"models:/{model_name}/latest"
    print(f"Downloading model '{model_uri}' ...")

    # Same download command as in the dashboard (model_manager.py)
    model = mlflow.pyfunc.load_model(model_uri).unwrap_python_model().model
    print(f"Model downloaded successfully: {type(model).__name__}")
    return model


def load_experimental_inputs(config_dict):
    """Fetch all experimental points from the database and return as a batch input dict."""
    experiment = config_dict["experiment"]
    input_variables = config_dict["inputs"]
    input_names = [v["name"] for v in input_variables.values()]

    db = connect_to_db(config_dict)
    date_filter = config_dict.get("date_filter", {})
    df_exp = pd.DataFrame(db[experiment].find({"experiment_flag": 1, **date_filter}))

    if df_exp.empty:
        raise RuntimeError("No experimental points found in the database.")

    missing = [name for name in input_names if name not in df_exp.columns]
    if missing:
        raise RuntimeError(f"Missing input columns in experimental data: {missing}")

    print(f"Fetched {len(df_exp)} experimental points from the database.")
    return {
        name: torch.tensor(df_exp[name].values, dtype=torch.float64)
        for name in input_names
    }


def check_evaluate(model, config_dict):
    """Call evaluate() with experimental data fetched from the database."""
    inputs = load_experimental_inputs(config_dict)
    print(
        f"Calling model.evaluate() with {len(next(iter(inputs.values())))} experimental points..."
    )
    result = model.evaluate(inputs)
    print("evaluate() succeeded.")
    print(f"Output keys: {list(result.keys())}")
    return result


if __name__ == "__main__":
    config_file, model_type = parse_arguments()

    # Load configuration
    config_dict = load_config(config_file)
    print(f"Experiment: {config_dict['experiment']}")

    # Download model from MLflow
    try:
        model = download_model(config_dict, model_type)
    except Exception as e:
        print(f"[FAIL] Could not download model: {e}")
        sys.exit(1)

    # Evaluate with experimental data fetched from the database
    try:
        check_evaluate(model, config_dict)
    except Exception as e:
        print(f"[FAIL] evaluate() raised an error: {e}")
        sys.exit(1)

    print("[PASS] Model loaded and evaluated successfully.")
67 changes: 67 additions & 0 deletions AGENTS.md
@@ -0,0 +1,67 @@
# AGENTS.md

This file provides guidance to AI agents when working with code in this repository.

## Project Overview

Synapse is a framework that couples experimental data, simulations, and machine learning models through a web-based dashboard deployed at NERSC.

## Architecture

- **`dashboard/`** — Trame (Vue.js/Flask) web application. `app.py` is the main entry point. Functionality is split into manager modules: `model_manager.py` (ML models), `parameters_manager.py` (parameter handling), `calibration_manager.py` (simulation calibration), `sfapi_manager.py` (NERSC Superfacility API), `optimization_manager.py` (Bayesian optimization), `state_manager.py`, `error_manager.py`, `outputs_manager.py`. `utils.py` handles data loading.
- **`ml/`** — PyTorch training pipeline. `train_model.py` is the main script, `Neural_Net_Classes.py` defines network architectures. Supports GP, NN, and ensemble models.
- **`experiments/`** — YAML config files for each BELLA experiment (inputs, outputs, calibration variables).
- **`dashboard.Dockerfile`** / **`ml.Dockerfile`** — Container images for GUI and ML training respectively.
- **`publish_container.py`** — Builds and pushes containers to NERSC registry.

## Key Commands

### Linting
```bash
pre-commit run --all-files # Ruff linting + formatting
```

### Running the Dashboard Locally
```bash
conda-lock install --name synapse-gui dashboard/environment-lock.yml
conda activate synapse-gui
python -u dashboard/app.py --port 8080
```
Requires MongoDB access via `SF_DB_HOST` and `SF_DB_READONLY_PASSWORD` environment variables.

### Running ML Training Locally
```bash
conda-lock install --name synapse-ml ml/environment-lock.yml
conda activate synapse-ml
python ml/train_model.py --test --model NN --config_file config.yaml
```

### Testing ML Models End-to-End
Use the `/test-ml` skill (defined in `.claude/skills/test-ml/`) to train models, populate MLflow, and verify they load and evaluate correctly. It handles pre-flight checks (conda env, DB, MLflow server), parallel training, and validation.

### Docker
```bash
# Build
docker build --platform linux/amd64 --output type=image,oci-mediatypes=true -t synapse-gui -f dashboard.Dockerfile .
docker build --platform linux/amd64 --output type=image,oci-mediatypes=true -t synapse-ml -f ml.Dockerfile .

# Publish to NERSC
python publish_container.py --gui --yes
python publish_container.py --ml --yes
```

## Key Dependencies

- **Dashboard**: Trame, Flask, PyTorch, PyMongo, Plotly, BoTorch
- **ML Training**: PyTorch, GPyTorch, BoTorch, LUME-Model, PyMongo
- **Environments**: Managed via conda-lock (`environment-lock.yml` in each directory)

## Database

MongoDB stores experimental/simulation data and ML models. The dashboard uses read-only access (`SF_DB_READONLY_PASSWORD`); training uses admin access (`SF_DB_ADMIN_PASSWORD`).

## Deployment

- Dashboard deployed at NERSC Spin (bellasuperfacility.lbl.gov)
- ML training runs on Perlmutter HPC, triggered via Superfacility API from the dashboard
- Container registry: `registry.nersc.gov/m558/superfacility/`
1 change: 1 addition & 0 deletions CLAUDE.md