diff --git a/specparam/metrics/definitions.py b/specparam/metrics/definitions.py index b2458d30..c03a5dfb 100644 --- a/specparam/metrics/definitions.py +++ b/specparam/metrics/definitions.py @@ -14,13 +14,15 @@ measure='mae', description='Mean absolute error of the model fit to the data.', func=compute_mean_abs_error, + space='log', ) error_mse = Metric( category='error', measure='mse', description='Mean squared error of the model fit to the data.', - func=compute_mean_squared_error + func=compute_mean_squared_error, + space='log', ) error_rmse = Metric( @@ -28,6 +30,7 @@ measure='rmse', description='Root mean squared error of the model fit to the data.', func=compute_root_mean_squared_error, + space='log', ) error_medae = Metric( @@ -35,14 +38,28 @@ measure='medae', description='Median absolute error of the model fit to the data.', func=compute_median_abs_error, + space='log', +) + +error_maelin = Metric( + category='error', + measure='maelin', + description='Mean absolute error of the model fit to the data, in linear space.', + func=compute_mean_abs_error, + space='linear', ) # Collect available error metrics ERROR_METRICS = { + + # log spacing 'mae' : error_mae, 'mse' : error_mse, 'rmse' : error_rmse, 'medae' : error_medae, + + # linear spacing + 'maelin' : error_maelin, } ################################################################################################### @@ -53,6 +70,15 @@ measure='rsquared', description='R-squared between the model fit and the data.', func=compute_r_squared, + space='log', +) + +gof_rsquaredlin = Metric( + category='gof', + measure='rsquaredlin', + description='R-squared between the model fit and the data, in linear space.', + func=compute_r_squared, + space='linear', ) gof_adjrsquared = Metric( @@ -62,12 +88,18 @@ func=compute_adj_r_squared, kwargs={'n_params' : lambda data, results: \ results.params.periodic.params.size + results.params.aperiodic.params.size}, + space='log', ) # Collect available error metrics GOF_METRICS = { + + # log spacing 'rsquared' : gof_rsquared, 'adjrsquared' : gof_adjrsquared, + + # linear spacing + 'rsquaredlin' : gof_rsquaredlin, } ################################################################################################### diff --git a/specparam/metrics/metric.py b/specparam/metrics/metric.py index 122b4091..c3a8e1ea 100644 --- a/specparam/metrics/metric.py +++ b/specparam/metrics/metric.py @@ -18,6 +18,8 @@ class Metric(): Description of the metric. func : callable The function that computes the metric. + space : {'log', 'linear'} + Spacing of the data & model to use for metric evaluation. kwargs : dictionary Additional keyword argument to compute the metric. Each key should be the name of the additional argument. @@ -25,13 +27,14 @@ class Metric(): and returns the desired parameter / computed value. """ - def __init__(self, category, measure, description, func, kwargs=None): + def __init__(self, category, measure, description, func, space='log', kwargs=None): """Initialize metric.""" self.category = category self.measure = measure self.description = description self.func = func + self.space = space self.result = np.nan self.kwargs = {} if not kwargs else kwargs @@ -76,7 +79,10 @@ def compute_metric(self, data, results): for key, lfunc in self.kwargs.items(): kwargs[key] = lfunc(data, results) - self.result = self.func(data.power_spectrum, results.model.modeled_spectrum, **kwargs) + self.result = self.func( + data.get_data('full', space=self.space), + results.model.get_component('full', space=self.space), + **kwargs) def reset(self): diff --git a/specparam/tests/metrics/test_metric.py b/specparam/tests/metrics/test_metric.py index 03d2865e..7e96a394 100644 --- a/specparam/tests/metrics/test_metric.py +++ b/specparam/tests/metrics/test_metric.py @@ -20,7 +20,7 @@ def test_metric(tfm): def test_metric_kwargs(tfm): metric = Metric('gof', 'ar2', 'Description.', compute_adj_r_squared, - {'n_params' : lambda data, results: \ + kwargs={'n_params' : lambda data, results: \ results.params.periodic.params.size + results.params.aperiodic.params.size}) assert isinstance(metric, Metric)