Skip to content

Commit 1321a64

Browse files
committed
__str__ based bytes_repr for Pipeline objects
1 parent ee7ed6f commit 1321a64

File tree

3 files changed

+4
-6
lines changed

3 files changed

+4
-6
lines changed

pydra_ml/report.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def gen_report_shap_class(results, output_dir="./", plot_top_n_shap=16):
223223
f"""There were no {quadrant.upper()}s, this will output NaNs
224224
in the csv and figure for this split column"""
225225
)
226-
shaps_i_quadrant = shaps_i[
226+
shaps_i_quadrant = np.array(shaps_i)[
227227
indexes.get(quadrant)
228228
] # shape (P, F) P prediction x F feature_names
229229
abs_weighted_shap_values = np.abs(shaps_i_quadrant) * split_performance
@@ -325,7 +325,7 @@ def gen_report_shap_regres(results, output_dir="./", plot_top_n_shap=16):
325325
f"""There were no {quadrant.upper()}s, this will
326326
output NaNs in the csv and figure for this split column"""
327327
)
328-
shaps_i_quadrant = shaps_i[
328+
shaps_i_quadrant = np.array(shaps_i)[
329329
indexes.get(quadrant)
330330
] # shape (P, F) P prediction x F feature_names
331331
abs_weighted_shap_values = np.abs(shaps_i_quadrant) * split_performance

pydra_ml/tasks.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
#!/usr/bin/env python
22

3-
import cloudpickle as cp
43
from pydra.utils.hash import Cache, register_serializer
54
from sklearn.pipeline import Pipeline
65

76

87
@register_serializer
98
def bytes_repr_Pipeline(obj: Pipeline, cache: Cache):
10-
yield cp.dumps(obj)
9+
yield str(obj).encode()
1110

1211

1312
def read_file(filename, x_indices=None, target_vars=None, group=None):

pydra_ml/tests/test_classifier.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os
2-
32
import numpy as np
43

54
from ..classifier import gen_workflow, run_workflow
@@ -35,7 +34,7 @@ def test_classifier(tmpdir):
3534
"metrics": ["roc_auc_score", "accuracy_score"],
3635
}
3736
wf = gen_workflow(inputs, cache_dir=tmpdir)
38-
results = run_workflow(wf, "cf", {"n_procs": 1})
37+
results = run_workflow(wf, "serial", {"n_procs": 1})
3938
assert results[0][0]["ml_wf.clf_info"][1] == "MLPClassifier"
4039
assert results[0][0]["ml_wf.permute"]
4140
assert results[0][1].output.score[0][0] < results[1][1].output.score[0][0]

0 commit comments

Comments
 (0)