Skip to content

Commit db013ab

Browse files
committed
ncia_hb300spxx10 app
1 parent fb6b110 commit db013ab

3 files changed

Lines changed: 94 additions & 41 deletions

File tree

ml_peg/analysis/non_covalent_interactions/ncia_hb375x10/analyse_ncia_hb375x10.py

Lines changed: 3 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,7 @@
99
import pytest
1010

1111
from ml_peg.analysis.utils.decorators import build_table, plot_parity
12-
from ml_peg.analysis.utils.utils import (
13-
build_d3_name_map,
14-
load_metrics_config,
15-
mae,
16-
rmse,
17-
)
12+
from ml_peg.analysis.utils.utils import build_d3_name_map, load_metrics_config, mae
1813
from ml_peg.app import APP_ROOT
1914
from ml_peg.calcs import CALCS_ROOT
2015
from ml_peg.models.get_models import load_models
@@ -44,7 +39,7 @@ def labels() -> list:
4439
List of all system names.
4540
"""
4641
for model in MODELS:
47-
labels_list = [path.stem for path in (CALC_PATH / model).glob("*.xyz")]
42+
labels_list = sorted([path.stem for path in (CALC_PATH / model).glob("*.xyz")])
4843
break
4944
return labels_list
5045

@@ -112,37 +107,14 @@ def get_mae(interaction_energies) -> dict[str, float]:
112107
return results
113108

114109

115-
@pytest.fixture
116-
def get_rmse(interaction_energies) -> dict[str, float]:
117-
"""
118-
Get root mean square error for energies.
119-
120-
Parameters
121-
----------
122-
interaction_energies
123-
Dictionary of reference and predicted energies.
124-
125-
Returns
126-
-------
127-
dict[str, float]
128-
Dictionary of predicted energy errors for all models.
129-
"""
130-
results = {}
131-
for model_name in MODELS:
132-
results[model_name] = rmse(
133-
interaction_energies["ref"], interaction_energies[model_name]
134-
)
135-
return results
136-
137-
138110
@pytest.fixture
139111
@build_table(
140112
filename=OUT_PATH / "ncia_hb375x10_metrics_table.json",
141113
metric_tooltips=DEFAULT_TOOLTIPS,
142114
thresholds=DEFAULT_THRESHOLDS,
143115
mlip_name_map=D3_MODEL_NAMES,
144116
)
145-
def metrics(get_mae: dict[str, float], get_rmse: dict[str, float]) -> dict[str, dict]:
117+
def metrics(get_mae: dict[str, float]) -> dict[str, dict]:
146118
"""
147119
Get all metrics.
148120
@@ -151,17 +123,13 @@ def metrics(get_mae: dict[str, float], get_rmse: dict[str, float]) -> dict[str,
151123
get_mae
152124
Mean absolute errors for all models.
153125
154-
get_rmse
155-
Root Mean Square Error for all models.
156-
157126
Returns
158127
-------
159128
dict[str, dict]
160129
Metric names and values for all models.
161130
"""
162131
return {
163132
"MAE": get_mae,
164-
"RMSE": get_rmse,
165133
}
166134

167135

ml_peg/analysis/non_covalent_interactions/ncia_hb375x10/metrics.yml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,3 @@ metrics:
55
unit: eV
66
tooltip: Mean Absolute Error for all systems
77
level_of_theory: CCSD(T)
8-
RMSE:
9-
good: 0.0
10-
bad: 0.5
11-
unit: eV
12-
tooltip: Root Mean Square Error for all systems
13-
level_of_theory: CCSD(T)
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""Run NCIA_HB375x10 app."""
2+
3+
from __future__ import annotations
4+
5+
from dash import Dash
6+
from dash.html import Div
7+
8+
from ml_peg.app import APP_ROOT
9+
from ml_peg.app.base_app import BaseApp
10+
from ml_peg.app.utils.build_callbacks import (
11+
plot_from_table_column,
12+
struct_from_scatter,
13+
)
14+
from ml_peg.app.utils.load import read_plot
15+
from ml_peg.models.get_models import get_model_names
16+
from ml_peg.models.models import current_models
17+
18+
MODELS = get_model_names(current_models)
19+
BENCHMARK_NAME = "NCIA_HB375x10"
20+
DOCS_URL = (
21+
"https://ddmms.github.io/ml-peg/user_guide/benchmarks/"
22+
"non_covalent_interactions.html#ncia-hb375x10"
23+
)
24+
DATA_PATH = APP_ROOT / "data" / "non_covalent_interactions" / "ncia_hb375x10"
25+
26+
27+
class NCIANHB375x10App(BaseApp):
28+
"""NCIA_HB375x10 benchmark app layout and callbacks."""
29+
30+
def register_callbacks(self) -> None:
31+
"""Register callbacks to app."""
32+
scatter = read_plot(
33+
DATA_PATH / "figure_ncia_hb375x10.json",
34+
id=f"{BENCHMARK_NAME}-figure",
35+
)
36+
37+
model_dir = DATA_PATH / MODELS[0]
38+
if model_dir.exists():
39+
labels = sorted([f.stem for f in model_dir.glob("*.xyz")])
40+
structs = [
41+
f"assets/non_covalent_interactions/ncia_hb375x10/{MODELS[0]}/{label}.xyz"
42+
for label in labels
43+
]
44+
else:
45+
structs = []
46+
47+
plot_from_table_column(
48+
table_id=self.table_id,
49+
plot_id=f"{BENCHMARK_NAME}-figure-placeholder",
50+
column_to_plot={"MAE": scatter},
51+
)
52+
53+
struct_from_scatter(
54+
scatter_id=f"{BENCHMARK_NAME}-figure",
55+
struct_id=f"{BENCHMARK_NAME}-struct-placeholder",
56+
structs=structs,
57+
mode="struct",
58+
)
59+
60+
61+
def get_app() -> NCIANHB375x10App:
62+
"""
63+
Get NCIA_HB375x10 benchmark app layout and callback registration.
64+
65+
Returns
66+
-------
67+
NCIANHB375x10App
68+
Benchmark layout and callback registration.
69+
"""
70+
return NCIANHB375x10App(
71+
name=BENCHMARK_NAME,
72+
description=(
73+
"Performance in predicting hydrogen-bonded interaction energies "
74+
"for the NCIA HB375x10 dataset (neutral dimers from HB375). "
75+
"Reference data from CCSD(T) calculations."
76+
),
77+
docs_url=DOCS_URL,
78+
table_path=DATA_PATH / "ncia_hb375x10_metrics_table.json",
79+
extra_components=[
80+
Div(id=f"{BENCHMARK_NAME}-figure-placeholder"),
81+
Div(id=f"{BENCHMARK_NAME}-struct-placeholder"),
82+
],
83+
)
84+
85+
86+
if __name__ == "__main__":
87+
full_app = Dash(__name__, assets_folder=DATA_PATH.parent.parent)
88+
benchmark_app = get_app()
89+
full_app.layout = benchmark_app.layout
90+
benchmark_app.register_callbacks()
91+
full_app.run(port=8058, debug=True)

0 commit comments

Comments
 (0)