Skip to content

Commit 8564980

Browse files
committed
enh: use predict_proba for roc_auc_score
1 parent cbd7fee commit 8564980

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

pydra_ml/tasks.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def read_file(filename, x_indices=None, target_vars=None, group=None):
1919
X = data[x_indices]
2020
else:
2121
raise ValueError(f"{x_indices} is not a list of string or ints")
22-
Y = data[target_vars]
22+
Y = data[list(target_vars)]
2323
if group is None:
2424
groups = list(range(X.shape[0]))
2525
else:
@@ -100,7 +100,11 @@ def to_instance(clf_info):
100100
else:
101101
pipe.fit(X[train_index], y[train_index])
102102
predicted = pipe.predict(X[test_index])
103-
return (y[test_index], predicted), (pipe, train_index, test_index)
103+
try:
104+
predicted_proba = pipe.predict_proba(X[test_index])
105+
except AttributeError:
106+
predicted_proba = None
107+
return (y[test_index], predicted, predicted_proba), (pipe, train_index, test_index)
104108

105109

106110
def calc_metric(output, metrics):
@@ -114,7 +118,11 @@ def calc_metric(output, metrics):
114118
for metric in metrics:
115119
metric_mod = __import__("sklearn.metrics", fromlist=[metric])
116120
metric_func = getattr(metric_mod, metric)
117-
score.append(metric_func(output[0], output[1]))
121+
if metric == 'roc_auc_score' and output[2] is not None:
122+
# For roc_auc_score, we need to pass the probability of the positive class
123+
score.append(metric_func(output[0], output[2][:, 1]))
124+
else:
125+
score.append(metric_func(output[0], output[1]))
118126
return score, output
119127

120128

0 commit comments

Comments
 (0)