Skip to content

Commit fdf2d5e

Browse files
committed
initial attempt to work with pydra 0.23+
1 parent b58ad3d commit fdf2d5e

File tree

3 files changed

+34
-5
lines changed

3 files changed

+34
-5
lines changed

pydra_ml/classifier.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def gen_workflow(inputs, cache_dir=None, cache_locations=None):
7171
messengers=FileMessenger(),
7272
messenger_args={"message_dir": os.path.join(os.getcwd(), "messages")},
7373
)
74-
wf.split(["clf_info", "permute"])
74+
wf.split(clf_info=inputs["clf_info"], permute=inputs["permute"])
7575
wf.add(
7676
read_file_pdt(
7777
name="readcsv",
@@ -102,7 +102,7 @@ def gen_workflow(inputs, cache_dir=None, cache_locations=None):
102102
permute=wf.lzin.permute,
103103
)
104104
)
105-
wf.fit_clf.split("split_index")
105+
wf.fit_clf.split(split_index=wf.gensplit.lzout.split_indices)
106106
wf.add(
107107
calc_metric_pdt(
108108
name="metric", output=wf.fit_clf.lzout.output, metrics=wf.lzin.metrics

pydra_ml/tasks.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
#!/usr/bin/env python
22

3+
import cloudpickle as cp
4+
from pydra.utils.hash import Cache, register_serializer
5+
from sklearn.pipeline import Pipeline
6+
7+
8+
@register_serializer
9+
def bytes_repr_Pipeline(obj: Pipeline, cache: Cache):
10+
yield cp.dump(obj)
11+
312

413
def read_file(filename, x_indices=None, target_vars=None, group=None):
514
"""Read a CSV data file
@@ -126,7 +135,27 @@ def calc_metric(output, metrics):
126135
return score, output
127136

128137

129-
def get_feature_importance(permute, model, gen_feature_importance=True):
138+
def get_feature_importance(
139+
*,
140+
permute: bool,
141+
model: tuple[Pipeline, list, list],
142+
gen_feature_importance: bool = True,
143+
):
144+
"""Compute feature importance for the model
145+
146+
Parameters
147+
----------
148+
permute : bool
149+
Whether or not to run the model in permuted mode
150+
model : tuple(sklearn.pipeline.Pipeline, list, list)
151+
The model to compute feature importance for
152+
gen_feature_importance : bool
153+
Whether or not to generate the feature importance
154+
Returns
155+
-------
156+
list
157+
List of feature importance
158+
"""
130159
if permute or not gen_feature_importance:
131160
return []
132161
pipeline, train_index, test_index = model
@@ -172,7 +201,7 @@ def get_feature_importance(permute, model, gen_feature_importance=True):
172201
pipeline_steps.coefs_
173202
pipeline_steps.coef_
174203
175-
Please add correct method in tasks.py or if inexistent,
204+
Please add correct method in tasks.py or if non-existent,
176205
set gen_feature_importance to false in the spec file.
177206
178207
This is the error that was returned by sklearn:\n\t{e}\n

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ classifiers =
2626
[options]
2727
python_requires = >= 3.8
2828
install_requires =
29-
pydra == 0.22.0
29+
pydra >= 0.23.0-alpha
3030
psutil
3131
scikit-learn
3232
seaborn

0 commit comments

Comments
 (0)