@@ -19,7 +19,7 @@ def read_file(filename, x_indices=None, target_vars=None, group=None):
         X = data[x_indices]
     else:
         raise ValueError(f"{x_indices} is not a list of string or ints")
-    Y = data[target_vars]
+    Y = data[list(target_vars)]
     if group is None:
         groups = list(range(X.shape[0]))
     else:
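The list() cast matters when target_vars arrives as a tuple. Assuming data is a
pandas DataFrame (as the string/int column indexing above suggests), a tuple key
is looked up as a single MultiIndex-style column name, while a list selects a
set of columns. A minimal sketch of the difference:

    import pandas as pd

    data = pd.DataFrame({"a": [1, 2], "b": [3, 4], "y": [0, 1]})
    target_vars = ("a", "y")        # e.g. passed through as a tuple

    # data[target_vars] raises KeyError on a flat column index, because
    # pandas treats the tuple as one hierarchical key.
    Y = data[list(target_vars)]     # selects columns "a" and "y"
    print(Y.shape)                  # (2, 2)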
@@ -100,7 +100,11 @@ def to_instance(clf_info):
     else:
         pipe.fit(X[train_index], y[train_index])
     predicted = pipe.predict(X[test_index])
-    return (y[test_index], predicted), (pipe, train_index, test_index)
+    try:
+        predicted_proba = pipe.predict_proba(X[test_index])
+    except AttributeError:
+        predicted_proba = None
+    return (y[test_index], predicted, predicted_proba), (pipe, train_index, test_index)


 def calc_metric(output, metrics):
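The try/except guards against estimators that do not implement predict_proba.
A minimal sketch with scikit-learn (both classifiers here are stand-ins, not
part of this codebase): LinearSVC has no predict_proba, so attribute access
raises AttributeError and the fallback keeps predicted_proba as None, while
LogisticRegression returns an (n_samples, n_classes) array whose columns line
up with the fitted estimator's classes_.

    from sklearn.linear_model import LogisticRegression
    from sklearn.svm import LinearSVC

    X = [[0.0], [1.0], [2.0], [3.0]]
    y = [0, 0, 1, 1]

    for clf in (LogisticRegression().fit(X, y), LinearSVC().fit(X, y)):
        try:
            proba = clf.predict_proba(X)
        except AttributeError:
            proba = None
        print(type(clf).__name__, None if proba is None else proba.shape)
    # LogisticRegression (4, 2)
    # LinearSVC None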
@@ -114,7 +118,11 @@ def calc_metric(output, metrics):
     for metric in metrics:
         metric_mod = __import__("sklearn.metrics", fromlist=[metric])
         metric_func = getattr(metric_mod, metric)
-        score.append(metric_func(output[0], output[1]))
+        if metric == 'roc_auc_score' and output[2] is not None:
+            # For roc_auc_score, we need to pass the probability of the positive class
+            score.append(metric_func(output[0], output[2][:, 1]))
+        else:
+            score.append(metric_func(output[0], output[1]))
     return score, output


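The branch exists because roc_auc_score ranks continuous scores rather than
hard labels; output[2][:, 1] is the predicted probability of the positive
class, which assumes binary classification (for multiclass, roc_auc_score
would need the full probability matrix and a multi_class argument). A small
illustration of why the two inputs differ:

    from sklearn.metrics import accuracy_score, roc_auc_score

    y_true = [0, 0, 1, 1]
    y_pred = [0, 1, 1, 1]                       # hard labels (output[1])
    y_proba = [[0.9, 0.1], [0.4, 0.6],          # predict_proba output (output[2])
               [0.35, 0.65], [0.2, 0.8]]

    print(accuracy_score(y_true, y_pred))                   # 0.75
    print(roc_auc_score(y_true, [p[1] for p in y_proba]))   # 1.0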