diff --git a/openml/tasks/functions.py b/openml/tasks/functions.py index 25156f2e5..c4bb13617 100644 --- a/openml/tasks/functions.py +++ b/openml/tasks/functions.py @@ -492,6 +492,7 @@ def _create_task_from_xml(xml: str) -> OpenMLTask: "data_set_id": inputs["source_data"]["oml:data_set"]["oml:data_set_id"], "evaluation_measure": evaluation_measures, } + # TODO: add OpenMLClusteringTask? if task_type in ( TaskType.SUPERVISED_CLASSIFICATION, TaskType.SUPERVISED_REGRESSION, @@ -508,6 +509,10 @@ def _create_task_from_xml(xml: str) -> OpenMLTask: common_kwargs["estimation_procedure_type"] = inputs["estimation_procedure"][ "oml:estimation_procedure" ]["oml:type"] + common_kwargs["estimation_procedure_id"] = int( + inputs["estimation_procedure"]["oml:estimation_procedure"]["oml:id"] + ) + common_kwargs["estimation_parameters"] = estimation_parameters common_kwargs["target_name"] = inputs["source_data"]["oml:data_set"]["oml:target_feature"] common_kwargs["data_splits_url"] = inputs["estimation_procedure"][ diff --git a/tests/test_tasks/test_classification_task.py b/tests/test_tasks/test_classification_task.py index bb4545154..d3553262f 100644 --- a/tests/test_tasks/test_classification_task.py +++ b/tests/test_tasks/test_classification_task.py @@ -15,7 +15,7 @@ def setUp(self, n_levels: int = 1): super().setUp() self.task_id = 119 # diabetes self.task_type = TaskType.SUPERVISED_CLASSIFICATION - self.estimation_procedure = 1 + self.estimation_procedure = 5 def test_get_X_and_Y(self): X, Y = super().test_get_X_and_Y() @@ -30,7 +30,8 @@ def test_download_task(self): assert task.task_id == self.task_id assert task.task_type_id == TaskType.SUPERVISED_CLASSIFICATION assert task.dataset_id == 20 + assert task.estimation_procedure_id == self.estimation_procedure def test_class_labels(self): task = get_task(self.task_id) - assert task.class_labels == ["tested_negative", "tested_positive"] + assert task.class_labels == ["tested_negative", "tested_positive"] \ No newline at end of file diff --git a/tests/test_tasks/test_regression_task.py b/tests/test_tasks/test_regression_task.py index 36decc534..14ed59470 100644 --- a/tests/test_tasks/test_regression_task.py +++ b/tests/test_tasks/test_regression_task.py @@ -18,11 +18,11 @@ class OpenMLRegressionTaskTest(OpenMLSupervisedTaskTest): def setUp(self, n_levels: int = 1): super().setUp() - + self.estimation_procedure = 9 task_meta_data = { "task_type": TaskType.SUPERVISED_REGRESSION, "dataset_id": 105, # wisconsin - "estimation_procedure_id": 7, + "estimation_procedure_id": self.estimation_procedure, # non default value to test estimation procedure id "target_name": "time", } _task_id = check_task_existence(**task_meta_data) @@ -46,7 +46,7 @@ def setUp(self, n_levels: int = 1): raise Exception(repr(e)) self.task_id = task_id self.task_type = TaskType.SUPERVISED_REGRESSION - self.estimation_procedure = 7 + def test_get_X_and_Y(self): X, Y = super().test_get_X_and_Y() @@ -61,3 +61,4 @@ def test_download_task(self): assert task.task_id == self.task_id assert task.task_type_id == TaskType.SUPERVISED_REGRESSION assert task.dataset_id == 105 + assert task.estimation_procedure_id == self.estimation_procedure