Skip to content

Commit 6e52b49

Browse files
author
Arthur Douillard
committed
[results_utils] Add std of average of average incremental accuracies.
1 parent 6cf32f8 commit 6e52b49

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

inclearn/results_utils.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,20 +93,31 @@ def aggregate(runs_accs):
9393
return means, stds
9494

9595

96-
def compute_unique_score(means, skip_first=False):
96+
def compute_unique_score(runs_accs, skip_first=False):
9797
"""Computes the average of the (average incremental) accuracies to get a
9898
unique score.
9999
100-
:param means: A list of mean accuracies over several runs.
100+
:param runs_accs: A list of runs. Each run is a list of (average
101+
incremental) accuracies.
101102
:param skip_first: Whether to skip the first task accuracy as advised in
102103
End-to-End Incremental Learning.
103104
:return: A unique score being the average of the (average incremental)
104-
accuracies.
105+
accuracies, and a standard deviation.
105106
"""
106107
start = int(skip_first)
107-
sub_means = [means[i] for i in range(start, len(means))]
108108

109-
return round(sum(sub_means) / len(sub_means), 2)
109+
means = []
110+
for run in runs_accs:
111+
means.append(sum(run[start:]) / len(run[start:]))
112+
113+
mean_of_mean = sum(means) / len(means)
114+
if len(runs_accs) == 1: # One run, probably a paper, don't compute std:
115+
std = ""
116+
else:
117+
std = math.sqrt(sum(math.pow(mean_of_mean - i, 2) for i in means) / len(means))
118+
std = " ± " + str(round(std, 2))
119+
120+
return str(round(mean_of_mean, 2)), std
110121

111122

112123
def plot(results, increment, total, title="", path_to_save=None):
@@ -137,9 +148,9 @@ def plot(results, increment, total, title="", path_to_save=None):
137148
runs_accs = extract(path, avg_inc=avg_inc)
138149
means, stds = aggregate(runs_accs)
139150

140-
unique_score = compute_unique_score(means, skip_first=skip_first)
151+
unique_score, unique_std = compute_unique_score(runs_accs, skip_first=skip_first)
141152

142-
plt.errorbar(x, means, stds, label=label + " ({})".format(unique_score),
153+
plt.errorbar(x, means, stds, label=label + " ({})".format(unique_score + unique_std),
143154
marker="o", markersize=3)
144155

145156
plt.legend(loc="upper right")

0 commit comments

Comments
 (0)