From 9071325950513dbe45084c085954410af6949f33 Mon Sep 17 00:00:00 2001 From: byhsu Date: Tue, 9 May 2023 23:43:19 -0700 Subject: [PATCH 1/2] Fallback to python task if worker is zero for pytorch Signed-off-by: byhsu --- .github/workflows/pythonpublish.yml | 2 +- .../flytekitplugins/kfpytorch/task.py | 6 +++++- .../flytekit-kf-pytorch/tests/test_pytorch_task.py | 11 +++++++++++ 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pythonpublish.yml b/.github/workflows/pythonpublish.yml index 63e888e6a1..1e55d43079 100644 --- a/.github/workflows/pythonpublish.yml +++ b/.github/workflows/pythonpublish.yml @@ -180,7 +180,7 @@ jobs: tags: ${{ steps.external-plugin-service-names.outputs.tags }} build-args: | VERSION=${{ needs.deploy.outputs.version }} - file: ./Dockerfile + file: ./Dockerfile.external-plugin-service cache-from: type=gha cache-to: type=gha,mode=max diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index d79fa785f6..1524d6a8ab 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -65,12 +65,16 @@ class PyTorchFunctionTask(PythonFunctionTask[PyTorch]): """ _PYTORCH_TASK_TYPE = "pytorch" + _PYTORCH_TASK_TYPE_STANDALONE = "python-task" def __init__(self, task_config: PyTorch, task_function: Callable, **kwargs): + + task_type = self._PYTORCH_TASK_TYPE_STANDALONEE if task_config.num_workers == 0 else self._PYTORCH_TASK_TYPE + super().__init__( task_config, task_function, - task_type=self._PYTORCH_TASK_TYPE, + task_type=task_type, **kwargs, ) diff --git a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py index 00eb6c0953..801dcd5422 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py @@ -31,3 +31,14 @@ def my_pytorch_task(x: int, y: str) -> int: assert my_pytorch_task.resources.limits == Resources() assert my_pytorch_task.resources.requests == Resources(cpu="1") assert my_pytorch_task.task_type == "pytorch" + +def test_zero_worker(): + @task( + task_config=PyTorch(num_workers=0), + cache=True, + cache_version="1", + requests=Resources(cpu="1"), + ) + def my_pytorch_task(x: int, y: str) -> int: + return x + assert my_pytorch_task.task_type == "python-task" \ No newline at end of file From 6e0160014ae3f27d43486e197b32fe77e8766449 Mon Sep 17 00:00:00 2001 From: byhsu Date: Tue, 9 May 2023 23:45:38 -0700 Subject: [PATCH 2/2] improve Signed-off-by: byhsu --- .github/workflows/pythonpublish.yml | 2 +- plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pythonpublish.yml b/.github/workflows/pythonpublish.yml index 1e55d43079..63e888e6a1 100644 --- a/.github/workflows/pythonpublish.yml +++ b/.github/workflows/pythonpublish.yml @@ -180,7 +180,7 @@ jobs: tags: ${{ steps.external-plugin-service-names.outputs.tags }} build-args: | VERSION=${{ needs.deploy.outputs.version }} - file: ./Dockerfile.external-plugin-service + file: ./Dockerfile cache-from: type=gha cache-to: type=gha,mode=max diff --git a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py index 801dcd5422..929811cd53 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_pytorch_task.py @@ -41,4 +41,4 @@ def test_zero_worker(): ) def my_pytorch_task(x: int, y: str) -> int: return x - assert my_pytorch_task.task_type == "python-task" \ No newline at end of file + assert my_pytorch_task.task_type == "python-task"