Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 69 additions & 40 deletions ml_peg/analysis/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,17 +386,31 @@ def calc_metric_scores(
normalizer = normalizer if normalizer is not None else normalize_metric
cleaned_thresholds = clean_thresholds(thresholds) if thresholds else None

metrics_scores = [row.copy() for row in metrics_data]
for row in metrics_scores:
for key, value in row.items():
# Value may be ``None`` if missing for a benchmark
if key not in {"MLIP", "Score", "id"} and value is not None:
if cleaned_thresholds is None or key not in cleaned_thresholds:
row[key] = value
continue

entry = cleaned_thresholds[key]
row[key] = normalizer(value, entry["good"], entry["bad"])
if cleaned_thresholds is None or not metrics_data:
return metrics_data

metric_columns = [
Comment thread
ElliottKasoar marked this conversation as resolved.
key for key in metrics_data[0] if key not in {"MLIP", "Score", "id"}
]
threshold_lookup = {
key: (entry["good"], entry["bad"]) for key, entry in cleaned_thresholds.items()
}

metrics_scores = []
for row in metrics_data:
new_row = row.copy()

for key in metric_columns:
if (value := row.get(key)) is None:
continue

if (thresholds_entry := threshold_lookup.get(key)) is None:
continue

good, bad = thresholds_entry
new_row[key] = normalizer(value, good, bad)

metrics_scores.append(new_row)

return metrics_scores

Expand All @@ -407,7 +421,8 @@ def calc_table_scores(
thresholds: Thresholds | None = None,
normalizer: Callable[[float, float, float], float] | None = None,
require_all_metrics: bool = True,
) -> list[MetricRow]:
return_scores: bool = False,
) -> list[MetricRow] | tuple[list[MetricRow], list[MetricRow]]:
"""
Calculate (normalised) score for each model and add to table data.

Expand All @@ -429,50 +444,65 @@ def calc_table_scores(
If True, score is set to None unless all metrics are present (not None).
If False, score is calculated from available metrics only.
Default is True.
return_scores
If True, also return the normalised metric rows used to calculate scores.
Default is False.

Returns
-------
list[MetricRow]
Rows of data with combined score for each model added.
list[MetricRow] | tuple[list[MetricRow], list[MetricRow]]
Rows of data with combined score for each model added. If `return_scores` is
`True`, the normalised metric rows are also returned.
"""
weights = weights if weights else {}

metrics_scores = calc_metric_scores(metrics_data, thresholds, normalizer)

if not metrics_data:
return metrics_data if not return_scores else (metrics_data, metrics_scores)

metric_columns = [
Comment thread
ElliottKasoar marked this conversation as resolved.
key for key in metrics_data[0] if key not in {"MLIP", "Score", "id"}
]
metric_weights = {key: weights.get(key, 1.0) for key in metric_columns}

for metrics_row, scores_row in zip(metrics_data, metrics_scores, strict=True):
scores_list = []
weights_list = []
weighted_sum = 0.0
weight_sum = 0.0

all_metrics_present = True
contains_nan = False

for key, value in metrics_row.items():
if key in {"MLIP", "Score", "id"}:
for key in metric_columns:
if (weight := metric_weights[key]) == 0:
continue

weight = weights.get(key, 1.0)
if weight == 0:
# Weight of zero excludes the metric from scoring requirements
continue
value = metrics_row.get(key)
score = scores_row.get(key)

if value is not None:
scores_list.append(scores_row[key])
weights_list.append(weight)
else:
# Track if any (weighted) metric is missing
if value is None or score is None:
all_metrics_present = False
continue

if isinstance(score, float) and np.isnan(score):
contains_nan = True
break

# Calculate score only if conditions are met
if require_all_metrics and not all_metrics_present:
# Strict mode: require all metrics to be present
weighted_sum += score * weight
weight_sum += weight

if contains_nan:
metrics_row["Score"] = np.nan
elif require_all_metrics and not all_metrics_present:
metrics_row["Score"] = None
elif scores_list:
# Calculate weighted average of available metrics
try:
metrics_row["Score"] = np.average(scores_list, weights=weights_list)
except ZeroDivisionError:
metrics_row["Score"] = np.mean(scores_list)
elif weight_sum > 0:
metrics_row["Score"] = weighted_sum / weight_sum
else:
metrics_row["Score"] = None

if return_scores:
return metrics_data, metrics_scores

return metrics_data


Expand Down Expand Up @@ -687,8 +717,7 @@ def update_score_style(
Updated table rows and style data.
"""
weights = clean_weights(weights)
data = calc_table_scores(data, weights, thresholds)
scored_data = calc_metric_scores(data, thresholds)
data, scored_data = calc_table_scores(data, weights, thresholds, return_scores=True)
style = get_table_style(data, scored_data=scored_data)
return data, style

Expand Down Expand Up @@ -726,9 +755,9 @@ def normalize_metric(
try:
# Handle NaNs robustly
if np.isnan([value, good_threshold, bad_threshold]).any():
return None
return np.nan
except TypeError:
return None
return np.nan

if good_threshold == bad_threshold:
return 1.0 if value == good_threshold else 0.0
Expand Down
39 changes: 37 additions & 2 deletions ml_peg/app/base_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from abc import ABC, abstractmethod
from pathlib import Path

from dash.dcc import Store
from dash.development.base_component import Component
from dash.html import Div

Expand All @@ -31,7 +32,7 @@ class BaseApp(ABC):
URL for online documentation. Default is None.
framework_id
Framework identifier used for benchmark attribution tags. Default is
``"ml_peg"``.
`"ml_peg"`.
"""

def __init__(
Expand Down Expand Up @@ -60,6 +61,7 @@ def __init__(
URL to online documentation. Default is None.
framework_id
Framework identifier used for benchmark attribution tags.
Default is `"ml_peg"`.
"""
self.name = name
self.description = description
Expand Down Expand Up @@ -91,11 +93,44 @@ def build_layout(self) -> Div:
framework_id=self.framework_id,
table=self.table,
column_widths=getattr(self.table, "column_widths", None),
thresholds=getattr(self.table, "thresholds", None),
thresholds=self.table.thresholds,
extra_components=self.extra_components,
)

@abstractmethod
def register_callbacks(self):
    """
    Register callbacks with app.

    Abstract hook: the base implementation does nothing, and concrete
    subclasses must override it.
    """
    pass

@property
def stores(self) -> list[Store]:
    """
    Build the session Stores associated with this app's table.

    Four Stores are created, in order: computed data, raw data, weights and
    thresholds. The computed and raw Stores are both seeded from the table's
    data.

    Returns
    -------
    list[Store]
        List of Stores to be registered with full app.
    """
    table = self.table
    # (store id, initial payload) pairs; order matches the returned list.
    store_specs = (
        (f"{self.table_id}-computed-store", table.data),
        (f"{self.table_id}-raw-data-store", table.data),
        (f"{self.table_id}-weight-store", table.weights),
        (f"{self.table_id}-thresholds-store", table.thresholds),
    )
    return [
        Store(id=store_id, storage_type="session", data=payload)
        for store_id, payload in store_specs
    ]
Loading
Loading