From d7dc4e6e5c8d0e3ff2e58f4108cc58c835adb89d Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Fri, 29 May 2026 13:31:25 +0100 Subject: [PATCH 1/7] Update apps for element filtering --- ml_peg/app/base_app.py | 49 ++++++++++ ml_peg/app/build_app.py | 48 ++-------- .../DMC_ICE13/app_DMC_ICE13.py | 89 +++++++++++++++++-- ml_peg/app/molecular_crystal/X23/app_X23.py | 89 +++++++++++++++++-- .../app/nebs/li_diffusion/app_li_diffusion.py | 54 +++++++++-- ml_peg/app/utils/register_callbacks.py | 86 +++++++++++++++++- 6 files changed, 349 insertions(+), 66 deletions(-) diff --git a/ml_peg/app/base_app.py b/ml_peg/app/base_app.py index 5f9f0691e..bd7581eb9 100644 --- a/ml_peg/app/base_app.py +++ b/ml_peg/app/base_app.py @@ -3,7 +3,10 @@ from __future__ import annotations from abc import ABC, abstractmethod +from copy import deepcopy +import json from pathlib import Path +import warnings from dash.dcc import Store from dash.development.base_component import Component @@ -33,6 +36,8 @@ class BaseApp(ABC): framework_id Framework identifier used for benchmark attribution tags. Default is `"ml_peg"`. + info_path + Path to json file containing additional info for filtering. Default is None. """ def __init__( @@ -43,6 +48,7 @@ def __init__( extra_components: list[Component], docs_url: str | None = None, framework_id: str = "ml_peg", + info_path: Path | None = None, ): """ Initiaise class. @@ -62,6 +68,8 @@ def __init__( framework_id Framework identifier used for benchmark attribution tags. Default is `"ml_peg"`. + info_path + Path to json file containing additional info for filtering. Default is None. """ self.name = name self.description = description @@ -73,7 +81,31 @@ def __init__( self.table = rebuild_table( self.table_path, id=self.table_id, description=description ) + self.original_table = deepcopy(self.table) self.layout = self.build_layout() + if info_path: + self.load_info(info_path) + else: + self.info = None + warnings.warn("No info_path provided.", stacklevel=2) + if hasattr(self, "set_elements"): + self.set_elements() + else: + self.elements = None + + def load_info(self, info_path: Path) -> None: + """ + Load additional info for app. + + Parameters + ---------- + info_path + Path to json file containing additional info for filtering. + """ + if not info_path.exists(): + warnings.warn(f"{info_path} does not exist, skipping.", stacklevel=2) + with open(info_path) as f: + self.info = json.load(f) def build_layout(self) -> Div: """ @@ -102,6 +134,23 @@ def register_callbacks(self): """Register callbacks with app.""" pass + def filter_table(self, filter_elements: list[str] | None) -> None: + """ + Filter data by elements. + + Parameters + ---------- + filter_elements + List of elements to filter out of data. + + Returns + ------- + dict[str, dict] + Updated benchmark table. + """ + print(f"No filter_data method defined for {self.name}, skipping.") + return self.table.data + @property def stores(self) -> list[Store]: """ diff --git a/ml_peg/app/build_app.py b/ml_peg/app/build_app.py index b26f42b5c..ed7b08612 100644 --- a/ml_peg/app/build_app.py +++ b/ml_peg/app/build_app.py @@ -14,6 +14,7 @@ from ml_peg.analysis.utils.utils import calc_table_scores, get_table_style from ml_peg.app import APP_ROOT +from ml_peg.app.filter import get_element_filter, get_model_filter from ml_peg.app.utils.build_components import ( build_download_controls, build_faqs, @@ -25,7 +26,10 @@ build_tutorial_button, register_onboarding_callbacks, ) -from ml_peg.app.utils.register_callbacks import register_benchmark_to_category_callback +from ml_peg.app.utils.register_callbacks import ( + register_benchmark_to_category_callback, + register_filter_tables_callback, +) from ml_peg.app.utils.utils import ( build_level_of_theory_warnings, get_framework_config, @@ -870,43 +874,6 @@ def build_nav( framework_id: framework_views[framework_id]["label"] for framework_id in framework_order } - model_options = [{"label": m, "value": m} for m in MODELS] - - model_filter = Details( - [ - Summary( - "Visible models", - style={ - "cursor": "pointer", - "fontWeight": "600", - "fontSize": "11px", - "textTransform": "uppercase", - "letterSpacing": "0.07em", - "color": "#6c757d", - "padding": "5px", - }, - ), - Div( - [ - Dropdown( - id="model-filter-checklist", - options=model_options, - value=MODELS, - multi=True, - maxHeight=600, - optionHeight=10, - placeholder="Select visible models", - closeOnSelect=False, - style={"fontSize": "12px"}, - ), - ], - style={"padding": "8px 12px"}, - ), - ], - id="model-filter-details", - open=True, - style={"marginBottom": "8px", "fontSize": "13px"}, - ) _summary_label_style = { "cursor": "pointer", @@ -1041,8 +1008,9 @@ def build_nav( sidebar, Div( [ - model_filter, + get_model_filter(MODELS), cmap_selector, + get_element_filter(), Store( id="selected-models-store", storage_type="session", @@ -1273,6 +1241,8 @@ def build_full_app(full_app: Dash, category: str = "*") -> None: if not all_layouts: raise ValueError("No tests were built successfully") + register_filter_tables_callback(all_apps) + # Combine tests into categories and create category summary cat_views, cat_tables, cat_weights, framework_ids = build_category( all_layouts, all_tables, all_frameworks diff --git a/ml_peg/app/molecular_crystal/DMC_ICE13/app_DMC_ICE13.py b/ml_peg/app/molecular_crystal/DMC_ICE13/app_DMC_ICE13.py index f76e7625c..f09f76a40 100644 --- a/ml_peg/app/molecular_crystal/DMC_ICE13/app_DMC_ICE13.py +++ b/ml_peg/app/molecular_crystal/DMC_ICE13/app_DMC_ICE13.py @@ -2,47 +2,58 @@ from __future__ import annotations +import warnings + from dash import Dash +from dash.dcc import Graph from dash.html import Div +import numpy as np +from ml_peg.analysis.molecular_crystal.DMC_ICE13.analyse_DMC_ICE13 import get_metrics from ml_peg.app import APP_ROOT from ml_peg.app.base_app import BaseApp from ml_peg.app.utils.build_callbacks import ( + filter_table, plot_from_table_column, struct_from_scatter, ) from ml_peg.app.utils.load import read_plot -from ml_peg.models import current_models -from ml_peg.models.get_models import get_model_names # Get all models -MODELS = get_model_names(current_models) BENCHMARK_NAME = "DMC-ICE13 Lattice Energies" DOCS_URL = "https://ddmms.github.io/ml-peg/user_guide/benchmarks/molecular_crystal.html#dmc-ice13" DATA_PATH = APP_ROOT / "data" / "molecular_crystal" / "DMC_ICE13" +INFO_PATH = DATA_PATH / "info.json" class DMCICE13App(BaseApp): """DMC-ICE13 benchmark app layout and callbacks.""" - def register_callbacks(self) -> None: - """Register callbacks to app.""" - scatter = read_plot( + def load_data(self) -> None: + """Load data required for filtering.""" + self.data = read_plot( DATA_PATH / "figure_lattice_energies.json", id=f"{BENCHMARK_NAME}-figure", ) + def register_callbacks(self) -> None: + """Register callbacks to app.""" + if not hasattr(self, "data"): + self.load_data() + # Assets dir will be parent directory - individual files for each polymorph - structs_dir = DATA_PATH / MODELS[0] + structs_dir = DATA_PATH / "mock" + if not structs_dir.exists(): + warnings.warn(f"Structures directory {structs_dir} not found", stacklevel=2) structs = [ - f"/assets/molecular_crystal/DMC_ICE13/{MODELS[0]}/{struct_file.stem}.xyz" + f"/assets/molecular_crystal/DMC_ICE13/mock/{struct_file.stem}.xyz" for struct_file in sorted(structs_dir.glob("*.xyz")) ] plot_from_table_column( table_id=self.table_id, plot_id=f"{BENCHMARK_NAME}-figure-placeholder", - column_to_plot={"MAE": scatter}, + column_to_plot={"MAE": self.data}, ) struct_from_scatter( @@ -52,6 +63,65 @@ def register_callbacks(self) -> None: mode="struct", ) + # Ensure data and elements are loaded + if not hasattr(self, "data"): + self.load_data() + if not hasattr(self, "elements"): + self.get_elements() + + filter_table( + table_id=self.table_id, + filter_func=self.filter_data, + filter_kwargs={"data": self.data, "test_elements": self.elements}, + ) + + def get_elements(self) -> None: + """Get element sets for filtering.""" + try: + self.elements = [set(entry) for entry in self.info["elements"]] + except (AttributeError, KeyError, TypeError): + self.elements = [] + warnings.warn("Unable to read elements lists.", stacklevel=2) + + @staticmethod + def filter_data( + filter_elements: set[str], data: Graph, test_elements: list[set[str]] + ) -> dict[str, dict]: + """ + Apply elements filter to data. + + Parameters + ---------- + filter_elements + Set of elements to filter out of data. + data + Scatter plot to filter. + test_elements + List of element for each system. + + Returns + ------- + dict[str, dict] + Metric names and values for all models. + """ + # Get overlap of deselected elements with each system's elements + filtered_indices = [ + not bool(elements & filter_elements) for elements in test_elements + ] + + results = {} + ref_filtered = False + + for plot in data.figure.data: + # Ignore unamed (parity) line + if plot.name: + results[plot.name] = np.array(plot.x)[filtered_indices].tolist() + if not ref_filtered: + results["ref"] = np.array(plot.y)[filtered_indices].tolist() + ref_filtered = True + + return get_metrics(results) + def get_app() -> DMCICE13App: """ @@ -71,6 +141,7 @@ def get_app() -> DMCICE13App: Div(id=f"{BENCHMARK_NAME}-figure-placeholder"), Div(id=f"{BENCHMARK_NAME}-struct-placeholder"), ], + info_path=INFO_PATH, ) diff --git a/ml_peg/app/molecular_crystal/X23/app_X23.py b/ml_peg/app/molecular_crystal/X23/app_X23.py index 437c64f24..3e5bffb5c 100644 --- a/ml_peg/app/molecular_crystal/X23/app_X23.py +++ b/ml_peg/app/molecular_crystal/X23/app_X23.py @@ -2,49 +2,60 @@ from __future__ import annotations +import warnings + from dash import Dash +from dash.dcc import Graph from dash.html import Div +import numpy as np +from ml_peg.analysis.molecular_crystal.X23.analyse_X23 import get_metrics from ml_peg.app import APP_ROOT from ml_peg.app.base_app import BaseApp from ml_peg.app.utils.build_callbacks import ( + filter_table, plot_from_table_column, struct_from_scatter, ) from ml_peg.app.utils.load import read_plot -from ml_peg.models import current_models -from ml_peg.models.get_models import get_model_names # Get all models -MODELS = get_model_names(current_models) BENCHMARK_NAME = "X23 Lattice Energies" DOCS_URL = ( "https://ddmms.github.io/ml-peg/user_guide/benchmarks/molecular_crystal.html#x23" ) DATA_PATH = APP_ROOT / "data" / "molecular_crystal" / "X23" +INFO_PATH = DATA_PATH / "info.json" class X23App(BaseApp): """X23 benchmark app layout and callbacks.""" - def register_callbacks(self) -> None: - """Register callbacks to app.""" - scatter = read_plot( + def load_data(self) -> None: + """Load data required for filtering.""" + self.data = read_plot( DATA_PATH / "figure_lattice_energies.json", id=f"{BENCHMARK_NAME}-figure", ) + def register_callbacks(self) -> None: + """Register callbacks to app.""" + if not hasattr(self, "data"): + self.load_data() + # Assets dir will be parent directory - individual files for each system - structs_dir = DATA_PATH / MODELS[0] + structs_dir = DATA_PATH / "mock" + if not structs_dir.exists(): + warnings.warn(f"Structures directory {structs_dir} not found", stacklevel=2) structs = [ - f"/assets/molecular_crystal/X23/{MODELS[0]}/{struct_file.stem}.xyz" + f"/assets/molecular_crystal/X23/mock/{struct_file.stem}.xyz" for struct_file in sorted(structs_dir.glob("*.xyz")) ] plot_from_table_column( table_id=self.table_id, plot_id=f"{BENCHMARK_NAME}-figure-placeholder", - column_to_plot={"MAE": scatter}, + column_to_plot={"MAE": self.data}, ) struct_from_scatter( @@ -54,6 +65,65 @@ def register_callbacks(self) -> None: mode="struct", ) + # Ensure data and elements are loaded + if not hasattr(self, "data"): + self.load_data() + if not hasattr(self, "elements"): + self.get_elements() + + filter_table( + table_id=self.table_id, + filter_func=self.filter_data, + filter_kwargs={"data": self.data, "test_elements": self.elements}, + ) + + def get_elements(self) -> None: + """Get element sets for filtering from loaded info.""" + try: + self.elements = [set(entry) for entry in self.info["elements"]] + except (AttributeError, KeyError, TypeError): + self.elements = [] + warnings.warn("Unable to read elements lists.", stacklevel=2) + + @staticmethod + def filter_data( + filter_elements: set[str], data: Graph, test_elements: list[set[str]] + ) -> dict[str, dict]: + """ + Apply elements filter to data. + + Parameters + ---------- + filter_elements + Set of elements to filter out of data. + data + Scatter plot to filter. + test_elements + List of element for each system. + + Returns + ------- + dict[str, dict] + Metric names and values for all models. + """ + # Get overlap of deselected elements with each system's elements + filtered_indices = [ + not bool(elements & filter_elements) for elements in test_elements + ] + + results = {} + ref_filtered = False + + for plot in data.figure.data: + # Ignore unamed (parity) line + if plot.name: + results[plot.name] = np.array(plot.x)[filtered_indices].tolist() + if not ref_filtered: + results["ref"] = np.array(plot.y)[filtered_indices].tolist() + ref_filtered = True + + return get_metrics(results) + def get_app() -> X23App: """ @@ -73,6 +143,7 @@ def get_app() -> X23App: Div(id=f"{BENCHMARK_NAME}-figure-placeholder"), Div(id=f"{BENCHMARK_NAME}-struct-placeholder"), ], + info_path=INFO_PATH, ) diff --git a/ml_peg/app/nebs/li_diffusion/app_li_diffusion.py b/ml_peg/app/nebs/li_diffusion/app_li_diffusion.py index 99576f4af..08f6ea08d 100644 --- a/ml_peg/app/nebs/li_diffusion/app_li_diffusion.py +++ b/ml_peg/app/nebs/li_diffusion/app_li_diffusion.py @@ -2,15 +2,14 @@ from __future__ import annotations +import warnings + from dash import Dash from dash.html import Div from ml_peg.app import APP_ROOT from ml_peg.app.base_app import BaseApp -from ml_peg.app.utils.build_callbacks import ( - plot_from_table_cell, - struct_from_scatter, -) +from ml_peg.app.utils.build_callbacks import plot_from_table_cell, struct_from_scatter from ml_peg.app.utils.load import read_plot from ml_peg.models import current_models from ml_peg.models.get_models import get_model_names @@ -30,11 +29,11 @@ def register_callbacks(self) -> None: scatter_plots = { model: { "Path B error": read_plot( - DATA_PATH / f"figure_{model}_neb_b.json", + DATA_PATH / model / "figure_neb_b.json", id=f"{BENCHMARK_NAME}-{model}-figure-b", ), "Path C error": read_plot( - DATA_PATH / f"figure_{model}_neb_c.json", + DATA_PATH / model / "figure_neb_c.json", id=f"{BENCHMARK_NAME}-{model}-figure-c", ), } @@ -45,8 +44,8 @@ def register_callbacks(self) -> None: assets_dir = "/assets/nebs/li_diffusion" structs = { model: { - "Path B error": f"{assets_dir}/{model}/{model}-b-neb-band.extxyz", - "Path C error": f"{assets_dir}/{model}/{model}-c-neb-band.extxyz", + "Path B error": f"{assets_dir}/{model}/b-neb-band.extxyz", + "Path C error": f"{assets_dir}/{model}/c-neb-band.extxyz", } for model in MODELS } @@ -66,6 +65,44 @@ def register_callbacks(self) -> None: mode="traj", ) + def set_elements(self) -> None: + """Get element sets for filtering.""" + try: + self.elements = set(self.info["elements"][0]) + except (AttributeError, KeyError, TypeError): + self.elements = set() + warnings.warn("Unable to read elements lists.", stacklevel=2) + + def filter_table(self, filter_elements: list[str] | None) -> dict[str, dict]: + """ + Apply elements filter to data. + + Parameters + ---------- + filter_elements + List of elements to filter out of data. + + Returns + ------- + dict[str, dict] + Metric names and values for all models. + """ + filter_elements = set(filter_elements) if filter_elements else set() + + # Get overlap of deselected elements with each system's elements + if bool(self.elements & filter_elements): + for row in self.table.data: + row["Path B error"] = None + row["Path C error"] = None + else: + for current_row, original_row in zip( + self.table.data, self.original_table.data, strict=True + ): + current_row["Path B error"] = original_row["Path B error"] + current_row["Path C error"] = original_row["Path C error"] + + return self.table.data + def get_app() -> LiDiffusionApp: """ @@ -85,6 +122,7 @@ def get_app() -> LiDiffusionApp: Div(id=f"{BENCHMARK_NAME}-figure-placeholder"), Div(id=f"{BENCHMARK_NAME}-struct-placeholder"), ], + info_path=DATA_PATH / "info.json", ) diff --git a/ml_peg/app/utils/register_callbacks.py b/ml_peg/app/utils/register_callbacks.py index 8d2c0d4db..8e0c6cc50 100644 --- a/ml_peg/app/utils/register_callbacks.py +++ b/ml_peg/app/utils/register_callbacks.py @@ -10,6 +10,7 @@ from dash import ( MATCH, ClientsideFunction, + Dash, Input, Output, Patch, @@ -1300,5 +1301,88 @@ def download_table( Output(f"{table_id}-download", "data", allow_duplicate=True), Input(f"{table_id}-download-request", "data"), prevent_initial_call=True, - optional=True, ) + + +def register_filter_tables_callback(apps: dict[str, Dash]) -> None: + """ + Update all tables when filter dropdown value changes. + + Parameters + ---------- + apps + Dictionary of test apps to register callbacks for. + """ + app_entries = [] + for app in apps.values(): + app_entries.append( + { + "app": app, + "weight_state": State(f"{app.table_id}-weight-store", "data"), + "threshold_state": State(f"{app.table_id}-thresholds-store", "data"), + } + ) + + outputs = [] + for entry in sorted(app_entries): + app = entry["app"] + outputs.extend( + [ + Output(f"{app.table_id}-computed-store", "data"), + Output(f"{app.table_id}-raw-data-store", "data", allow_duplicate=True), + ] + ) + + states = [] + for entry in app_entries: + states.extend([entry["weight_state"], entry["threshold_state"]]) + + @callback( + outputs, Input("element-filter", "value"), states, prevent_initial_call=True + ) + def recompute_tables(elements, *args): + """ + Recompute all benchmark tables when element filter is applied. + + Parameters + ---------- + elements + List of selected elements to filter by. + *args + Weight and threshold states for each app. + + Returns + ------- + list[list[dict]] + Updated rows for each app's computed store and raw data stores. + """ + # Rebuild inputs for each app + per_app_state = {} + iterator = iter(args) + + for entry in sorted(app_entries): + app = entry["app"] + per_app_state[app.table_id] = { + "weights": next(iterator), + "thresholds": next(iterator), + } + + results = [] + + for entry in sorted(app_entries): + app = entry["app"] + state = per_app_state[app.table_id] + weights = state["weights"] + thresholds = state["thresholds"] + + updated_data = app.filter_table(elements) + + # Update overall table score for new weights and thresholds + metrics_data = calc_table_scores(updated_data, weights, thresholds) + + # Update stored scores per metric + scored_rows = calc_metric_scores(updated_data, thresholds) + + results.extend([scored_rows, metrics_data]) + + return results From 3ff0f19c108266aad60301876655581af50ea6a5 Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Fri, 29 May 2026 13:37:57 +0100 Subject: [PATCH 2/7] Add missing filters file and module name --- ml_peg/app/build_app.py | 2 +- ml_peg/app/filters.py | 108 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 1 deletion(-) create mode 100644 ml_peg/app/filters.py diff --git a/ml_peg/app/build_app.py b/ml_peg/app/build_app.py index ed7b08612..e43fab370 100644 --- a/ml_peg/app/build_app.py +++ b/ml_peg/app/build_app.py @@ -14,7 +14,7 @@ from ml_peg.analysis.utils.utils import calc_table_scores, get_table_style from ml_peg.app import APP_ROOT -from ml_peg.app.filter import get_element_filter, get_model_filter +from ml_peg.app.filters import get_element_filter, get_model_filter from ml_peg.app.utils.build_components import ( build_download_controls, build_faqs, diff --git a/ml_peg/app/filters.py b/ml_peg/app/filters.py new file mode 100644 index 000000000..256977cb3 --- /dev/null +++ b/ml_peg/app/filters.py @@ -0,0 +1,108 @@ +"""Build data components.""" + +from __future__ import annotations + +from ase.data import chemical_symbols +from dash.dcc import Dropdown +from dash.html import Details, Div, Summary + + +def get_model_filter(models) -> Details: + """ + Get model filter component. + + Parameters + ---------- + models + List of model names to include in filter options. + + Returns + ------- + Details + Model filter component. + """ + model_options = [{"label": m, "value": m} for m in models] + + return Details( + [ + Summary( + "Visible models", + style={ + "cursor": "pointer", + "fontWeight": "600", + "fontSize": "11px", + "textTransform": "uppercase", + "letterSpacing": "0.07em", + "color": "#6c757d", + "padding": "5px", + }, + ), + Div( + [ + Dropdown( + id="model-filter-checklist", + options=model_options, + value=models, + multi=True, + maxHeight=600, + optionHeight=10, + placeholder="Select visible models", + closeOnSelect=False, + style={"fontSize": "12px"}, + ), + ], + style={"padding": "8px 12px"}, + ), + ], + id="model-filter-details", + open=True, + style={"marginBottom": "8px", "fontSize": "13px"}, + ) + + +def get_element_filter() -> Details: + """ + Get element filter component. + + Returns + ------- + Details + Element filter component. + """ + # Exclude placeholder symbol for index 0 + elements = chemical_symbols[1:] + + return Details( + [ + Summary( + "Filter by elements", + style={ + "cursor": "pointer", + "fontWeight": "600", + "fontSize": "11px", + "textTransform": "uppercase", + "letterSpacing": "0.07em", + "color": "#6c757d", + "padding": "5px", + }, + ), + Div( + [ + Dropdown( + id="element-filter", + options=elements, + value=None, + multi=True, + placeholder="Filter elements", + closeOnSelect=False, + style={"fontSize": "13px"}, + debounce=True, + ), + ], + style={"padding": "8px 12px"}, + ), + ], + id="element-filter-details", + open=True, + style={"marginBottom": "8px", "fontSize": "13px"}, + ) From a03f4fc8732d25646b84c386bf90cc56eb0adc74 Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Fri, 29 May 2026 13:41:03 +0100 Subject: [PATCH 3/7] Fix sorting of apps --- ml_peg/app/utils/register_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml_peg/app/utils/register_callbacks.py b/ml_peg/app/utils/register_callbacks.py index 8e0c6cc50..45c788406 100644 --- a/ml_peg/app/utils/register_callbacks.py +++ b/ml_peg/app/utils/register_callbacks.py @@ -1334,7 +1334,7 @@ def register_filter_tables_callback(apps: dict[str, Dash]) -> None: ) states = [] - for entry in app_entries: + for entry in sorted(app_entries): states.extend([entry["weight_state"], entry["threshold_state"]]) @callback( From 973c4f12222b46a4c923ccf3709eb8022e8e4bd2 Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Fri, 29 May 2026 16:31:48 +0100 Subject: [PATCH 4/7] Refactor filter --- ml_peg/app/base_app.py | 24 +++++++++++++-- .../app/nebs/li_diffusion/app_li_diffusion.py | 30 ------------------- 2 files changed, 22 insertions(+), 32 deletions(-) diff --git a/ml_peg/app/base_app.py b/ml_peg/app/base_app.py index bd7581eb9..b5e0839c8 100644 --- a/ml_peg/app/base_app.py +++ b/ml_peg/app/base_app.py @@ -81,6 +81,9 @@ def __init__( self.table = rebuild_table( self.table_path, id=self.table_id, description=description ) + self.metrics = [ + col for col in self.table.columns if col not in ("MLIP", "Score", "id") + ] self.original_table = deepcopy(self.table) self.layout = self.build_layout() if info_path: @@ -134,7 +137,7 @@ def register_callbacks(self): """Register callbacks with app.""" pass - def filter_table(self, filter_elements: list[str] | None) -> None: + def filter_table(self, filter_elements: list[str] | None) -> dict[str, dict]: """ Filter data by elements. @@ -148,7 +151,24 @@ def filter_table(self, filter_elements: list[str] | None) -> None: dict[str, dict] Updated benchmark table. """ - print(f"No filter_data method defined for {self.name}, skipping.") + if self.elements is None: + warnings.warn("No elements info available, skipping filter.", stacklevel=2) + return self.table.data + + filter_elements = set(filter_elements) if filter_elements else set() + + # Get overlap of deselected elements with each system's elements + if bool(self.elements & filter_elements): + for row in self.table.data: + for metric in self.metrics: + row[metric] = None + else: + for current_row, original_row in zip( + self.table.data, self.original_table.data, strict=True + ): + for metric in self.metrics: + current_row[metric] = original_row[metric] + return self.table.data @property diff --git a/ml_peg/app/nebs/li_diffusion/app_li_diffusion.py b/ml_peg/app/nebs/li_diffusion/app_li_diffusion.py index 08f6ea08d..2cc6fbdb8 100644 --- a/ml_peg/app/nebs/li_diffusion/app_li_diffusion.py +++ b/ml_peg/app/nebs/li_diffusion/app_li_diffusion.py @@ -73,36 +73,6 @@ def set_elements(self) -> None: self.elements = set() warnings.warn("Unable to read elements lists.", stacklevel=2) - def filter_table(self, filter_elements: list[str] | None) -> dict[str, dict]: - """ - Apply elements filter to data. - - Parameters - ---------- - filter_elements - List of elements to filter out of data. - - Returns - ------- - dict[str, dict] - Metric names and values for all models. - """ - filter_elements = set(filter_elements) if filter_elements else set() - - # Get overlap of deselected elements with each system's elements - if bool(self.elements & filter_elements): - for row in self.table.data: - row["Path B error"] = None - row["Path C error"] = None - else: - for current_row, original_row in zip( - self.table.data, self.original_table.data, strict=True - ): - current_row["Path B error"] = original_row["Path B error"] - current_row["Path C error"] = original_row["Path C error"] - - return self.table.data - def get_app() -> LiDiffusionApp: """ From b7e9743d739c6a10db1ad94ef64a1229d8b16db8 Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Fri, 29 May 2026 17:36:29 +0100 Subject: [PATCH 5/7] Fix broken sorting --- ml_peg/app/utils/register_callbacks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ml_peg/app/utils/register_callbacks.py b/ml_peg/app/utils/register_callbacks.py index 45c788406..ca95c945b 100644 --- a/ml_peg/app/utils/register_callbacks.py +++ b/ml_peg/app/utils/register_callbacks.py @@ -1324,7 +1324,7 @@ def register_filter_tables_callback(apps: dict[str, Dash]) -> None: ) outputs = [] - for entry in sorted(app_entries): + for entry in app_entries: app = entry["app"] outputs.extend( [ @@ -1334,7 +1334,7 @@ def register_filter_tables_callback(apps: dict[str, Dash]) -> None: ) states = [] - for entry in sorted(app_entries): + for entry in app_entries: states.extend([entry["weight_state"], entry["threshold_state"]]) @callback( @@ -1360,7 +1360,7 @@ def recompute_tables(elements, *args): per_app_state = {} iterator = iter(args) - for entry in sorted(app_entries): + for entry in app_entries: app = entry["app"] per_app_state[app.table_id] = { "weights": next(iterator), @@ -1369,7 +1369,7 @@ def recompute_tables(elements, *args): results = [] - for entry in sorted(app_entries): + for entry in app_entries: app = entry["app"] state = per_app_state[app.table_id] weights = state["weights"] From 6d522ca0c264c69d7fb4742ff6c15b6a6c3b72e5 Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Fri, 29 May 2026 17:43:10 +0100 Subject: [PATCH 6/7] Fix setting metrics --- ml_peg/app/base_app.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ml_peg/app/base_app.py b/ml_peg/app/base_app.py index b5e0839c8..194965ecb 100644 --- a/ml_peg/app/base_app.py +++ b/ml_peg/app/base_app.py @@ -82,7 +82,9 @@ def __init__( self.table_path, id=self.table_id, description=description ) self.metrics = [ - col for col in self.table.columns if col not in ("MLIP", "Score", "id") + col["id"] + for col in self.table.columns + if col["id"] not in ("MLIP", "Score", "id") ] self.original_table = deepcopy(self.table) self.layout = self.build_layout() From 2ec7797ea0819656f3fc5ef30a1c07655c6a683b Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Fri, 29 May 2026 17:43:25 +0100 Subject: [PATCH 7/7] Set Si defects elements --- ml_peg/app/nebs/si_defects/app_si_defects.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/ml_peg/app/nebs/si_defects/app_si_defects.py b/ml_peg/app/nebs/si_defects/app_si_defects.py index 04da41cef..b55160e3c 100644 --- a/ml_peg/app/nebs/si_defects/app_si_defects.py +++ b/ml_peg/app/nebs/si_defects/app_si_defects.py @@ -3,6 +3,7 @@ from __future__ import annotations from dataclasses import dataclass +import warnings from dash import Dash from dash.html import Div @@ -90,6 +91,14 @@ def register_callbacks(self) -> None: mode="traj", ) + def set_elements(self) -> None: + """Get element sets for filtering.""" + try: + self.elements = set(self.info["elements"]) + except (AttributeError, KeyError, TypeError): + self.elements = set() + warnings.warn("Unable to read elements lists.", stacklevel=2) + def get_app() -> SiDefectNebSinglepointsApp: """ @@ -112,6 +121,7 @@ def get_app() -> SiDefectNebSinglepointsApp: Div(id=f"{BENCHMARK_NAME}-figure-placeholder"), Div(id=f"{BENCHMARK_NAME}-struct-placeholder"), ], + info_path=DATA_PATH / "info.json", )