Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions ml_peg/app/base_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from copy import deepcopy
import json
from pathlib import Path
import warnings

from dash.dcc import Store
from dash.development.base_component import Component
Expand Down Expand Up @@ -33,6 +36,8 @@ class BaseApp(ABC):
framework_id
Framework identifier used for benchmark attribution tags. Default is
`"ml_peg"`.
info_path
Path to json file containing additional info for filtering. Default is None.
"""

def __init__(
Expand All @@ -43,6 +48,7 @@ def __init__(
extra_components: list[Component],
docs_url: str | None = None,
framework_id: str = "ml_peg",
info_path: Path | None = None,
):
"""
Initiaise class.
Expand All @@ -62,6 +68,8 @@ def __init__(
framework_id
Framework identifier used for benchmark attribution tags.
Default is `"ml_peg"`.
info_path
Path to json file containing additional info for filtering. Default is None.
"""
self.name = name
self.description = description
Expand All @@ -73,7 +81,36 @@ def __init__(
self.table = rebuild_table(
self.table_path, id=self.table_id, description=description
)
self.metrics = [
col["id"]
for col in self.table.columns
if col["id"] not in ("MLIP", "Score", "id")
]
self.original_table = deepcopy(self.table)
self.layout = self.build_layout()
if info_path:
self.load_info(info_path)
else:
self.info = None
warnings.warn("No info_path provided.", stacklevel=2)
if hasattr(self, "set_elements"):
self.set_elements()
else:
self.elements = None

def load_info(self, info_path: Path) -> None:
"""
Load additional info for app.

Parameters
----------
info_path
Path to json file containing additional info for filtering.
"""
if not info_path.exists():
warnings.warn(f"{info_path} does not exist, skipping.", stacklevel=2)
with open(info_path) as f:
self.info = json.load(f)

def build_layout(self) -> Div:
"""
Expand Down Expand Up @@ -102,6 +139,40 @@ def register_callbacks(self):
"""Register callbacks with app."""
pass

def filter_table(self, filter_elements: list[str] | None) -> dict[str, dict]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think this funciton mutating self.table.data is dangerous for the hosted app with multiple users. i think we should build from deepcopy(self.original_table.data)?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm yes I think you're right, thanks, I'll fix this soon

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm yes I think you're right, thanks, I'll fix this soon

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should also be careful if theres anything else that could have a similar effect, as its less strightforward to spot when we test locally

"""
Filter data by elements.

Parameters
----------
filter_elements
List of elements to filter out of data.

Returns
-------
dict[str, dict]
Updated benchmark table.
"""
if self.elements is None:
warnings.warn("No elements info available, skipping filter.", stacklevel=2)
return self.table.data

filter_elements = set(filter_elements) if filter_elements else set()

# Get overlap of deselected elements with each system's elements
if bool(self.elements & filter_elements):
for row in self.table.data:
for metric in self.metrics:
row[metric] = None
else:
for current_row, original_row in zip(
self.table.data, self.original_table.data, strict=True
):
for metric in self.metrics:
current_row[metric] = original_row[metric]

return self.table.data

@property
def stores(self) -> list[Store]:
"""
Expand Down
48 changes: 9 additions & 39 deletions ml_peg/app/build_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from ml_peg.analysis.utils.utils import calc_table_scores, get_table_style
from ml_peg.app import APP_ROOT
from ml_peg.app.filters import get_element_filter, get_model_filter
from ml_peg.app.utils.build_components import (
build_download_controls,
build_faqs,
Expand All @@ -25,7 +26,10 @@
build_tutorial_button,
register_onboarding_callbacks,
)
from ml_peg.app.utils.register_callbacks import register_benchmark_to_category_callback
from ml_peg.app.utils.register_callbacks import (
register_benchmark_to_category_callback,
register_filter_tables_callback,
)
from ml_peg.app.utils.utils import (
build_level_of_theory_warnings,
get_framework_config,
Expand Down Expand Up @@ -870,43 +874,6 @@ def build_nav(
framework_id: framework_views[framework_id]["label"]
for framework_id in framework_order
}
model_options = [{"label": m, "value": m} for m in MODELS]

model_filter = Details(
[
Summary(
"Visible models",
style={
"cursor": "pointer",
"fontWeight": "600",
"fontSize": "11px",
"textTransform": "uppercase",
"letterSpacing": "0.07em",
"color": "#6c757d",
"padding": "5px",
},
),
Div(
[
Dropdown(
id="model-filter-checklist",
options=model_options,
value=MODELS,
multi=True,
maxHeight=600,
optionHeight=10,
placeholder="Select visible models",
closeOnSelect=False,
style={"fontSize": "12px"},
),
],
style={"padding": "8px 12px"},
),
],
id="model-filter-details",
open=True,
style={"marginBottom": "8px", "fontSize": "13px"},
)

_summary_label_style = {
"cursor": "pointer",
Expand Down Expand Up @@ -1041,8 +1008,9 @@ def build_nav(
sidebar,
Div(
[
model_filter,
get_model_filter(MODELS),
cmap_selector,
get_element_filter(),
Store(
id="selected-models-store",
storage_type="session",
Expand Down Expand Up @@ -1273,6 +1241,8 @@ def build_full_app(full_app: Dash, category: str = "*") -> None:
if not all_layouts:
raise ValueError("No tests were built successfully")

register_filter_tables_callback(all_apps)

# Combine tests into categories and create category summary
cat_views, cat_tables, cat_weights, framework_ids = build_category(
all_layouts, all_tables, all_frameworks
Expand Down
108 changes: 108 additions & 0 deletions ml_peg/app/filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Build data components."""

from __future__ import annotations

from ase.data import chemical_symbols
from dash.dcc import Dropdown
from dash.html import Details, Div, Summary


def get_model_filter(models) -> Details:
"""
Get model filter component.

Parameters
----------
models
List of model names to include in filter options.

Returns
-------
Details
Model filter component.
"""
model_options = [{"label": m, "value": m} for m in models]

return Details(
[
Summary(
"Visible models",
style={
"cursor": "pointer",
"fontWeight": "600",
"fontSize": "11px",
"textTransform": "uppercase",
"letterSpacing": "0.07em",
"color": "#6c757d",
"padding": "5px",
},
),
Div(
[
Dropdown(
id="model-filter-checklist",
options=model_options,
value=models,
multi=True,
maxHeight=600,
optionHeight=10,
placeholder="Select visible models",
closeOnSelect=False,
style={"fontSize": "12px"},
),
],
style={"padding": "8px 12px"},
),
],
id="model-filter-details",
open=True,
style={"marginBottom": "8px", "fontSize": "13px"},
)


def get_element_filter() -> Details:
"""
Get element filter component.

Returns
-------
Details
Element filter component.
"""
# Exclude placeholder symbol for index 0
elements = chemical_symbols[1:]

return Details(
[
Summary(
"Filter by elements",
style={
"cursor": "pointer",
"fontWeight": "600",
"fontSize": "11px",
"textTransform": "uppercase",
"letterSpacing": "0.07em",
"color": "#6c757d",
"padding": "5px",
},
),
Div(
[
Dropdown(
id="element-filter",
options=elements,
value=None,
multi=True,
placeholder="Filter elements",
closeOnSelect=False,
style={"fontSize": "13px"},
debounce=True,
),
],
style={"padding": "8px 12px"},
),
],
id="element-filter-details",
open=True,
style={"marginBottom": "8px", "fontSize": "13px"},
)
Loading
Loading