diff --git a/providers/sftp/newsfragments/65475.feature.rst b/providers/sftp/newsfragments/65475.feature.rst new file mode 100644 index 0000000000000..9306edd4e79a8 --- /dev/null +++ b/providers/sftp/newsfragments/65475.feature.rst @@ -0,0 +1 @@ +Add ``deferrable`` parameter to ``SFTPOperator``, allowing ``get`` and ``put`` transfers of individual files to be deferred to a trigger instead of blocking a worker slot. diff --git a/providers/sftp/src/airflow/providers/sftp/operators/sftp.py b/providers/sftp/src/airflow/providers/sftp/operators/sftp.py index 0b47b5b7d5ddb..cb6ca1c02904e 100644 --- a/providers/sftp/src/airflow/providers/sftp/operators/sftp.py +++ b/providers/sftp/src/airflow/providers/sftp/operators/sftp.py @@ -28,8 +28,9 @@ import paramiko -from airflow.providers.common.compat.sdk import AirflowException, BaseOperator +from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, conf from airflow.providers.sftp.hooks.sftp import SFTPHook +from airflow.providers.sftp.triggers.sftp import SFTPOperatorTrigger class SFTPOperation: @@ -77,6 +78,10 @@ class SFTPOperator(BaseOperator): :param concurrency: Number of threads when transferring directories. Each thread opens a new SFTP connection. This parameter is used only when transferring directories, not individual files. (Default is 1) :param prefetch: controls whether prefetch is performed (default: True) + :param deferrable: If True, the operator will defer to a trigger to transfer the file(s) asynchronously, + freeing the worker slot while the transfer is in progress. Only supported for ``get`` and ``put`` + operations on individual files (i.e. ``concurrency == 1``); ``delete`` operations and directory + transfers are not supported in deferrable mode. """ @@ -95,6 +100,7 @@ def __init__( create_intermediate_dirs: bool = False, concurrency: int = 1, prefetch: bool = True, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) @@ -108,6 +114,7 @@ def __init__( self.remote_filepath = remote_filepath self.concurrency = concurrency self.prefetch = prefetch + self.deferrable = deferrable def execute(self, context: Any) -> str | list[str] | None: if self.local_filepath is None: @@ -142,6 +149,30 @@ def execute(self, context: Any) -> str | list[str] | None: if self.concurrency < 1: raise ValueError(f"concurrency should be greater than 0, got {self.concurrency}") + if self.deferrable: + if self.operation.lower() not in (SFTPOperation.GET, SFTPOperation.PUT): + raise ValueError( + f"deferrable=True is only supported for '{SFTPOperation.GET}' and " + f"'{SFTPOperation.PUT}' operations, got '{self.operation}'." + ) + if self.concurrency > 1: + raise ValueError("deferrable=True is not supported when concurrency > 1.") + + sftp_conn_id = self.ssh_conn_id or (self.sftp_hook.ssh_conn_id if self.sftp_hook else None) + if not sftp_conn_id: + raise ValueError("deferrable=True requires ssh_conn_id to be set.") + + self.defer( + trigger=SFTPOperatorTrigger( + local_filepaths=local_filepath_array, + remote_filepaths=remote_filepath_array, + operation=self.operation.lower(), + sftp_conn_id=sftp_conn_id, + create_intermediate_dirs=self.create_intermediate_dirs, + ), + method_name="execute_complete", + ) + file_msg = None try: if self.remote_host is not None: @@ -227,6 +258,23 @@ def execute(self, context: Any) -> str | list[str] | None: return self.local_filepath + def execute_complete(self, context: Any, event: dict[str, Any] | None = None) -> str | list[str] | None: + """ + Execute callback when the trigger fires; returns immediately. + + Relies on the trigger to throw an exception, otherwise it assumes execution was successful. + """ + if event is None: + raise AirflowException("No event received in trigger callback") + + if event.get("status") == "error": + raise AirflowException( + f"Error while processing {self.operation.upper()} operation, error: {event['message']}" + ) + + self.log.info(event["message"]) + return self.local_filepath + @staticmethod def _is_missing_path_error(exc: Exception) -> bool: if isinstance(exc, FileNotFoundError): diff --git a/providers/sftp/src/airflow/providers/sftp/triggers/sftp.py b/providers/sftp/src/airflow/providers/sftp/triggers/sftp.py index a46f29d5a4abe..09ace23d861d0 100644 --- a/providers/sftp/src/airflow/providers/sftp/triggers/sftp.py +++ b/providers/sftp/src/airflow/providers/sftp/triggers/sftp.py @@ -18,8 +18,10 @@ from __future__ import annotations import asyncio +import os from collections.abc import AsyncIterator from datetime import datetime +from pathlib import Path from typing import Any from dateutil.parser import parse as parse_date @@ -140,3 +142,67 @@ async def run(self) -> AsyncIterator[TriggerEvent]: def _get_async_hook(self) -> SFTPHookAsync: return SFTPHookAsync(sftp_conn_id=self.sftp_conn_id) + + +class SFTPOperatorTrigger(BaseTrigger): + """ + Trigger that asynchronously transfers files between local and remote SFTP paths. + + :param local_filepaths: list of local file paths involved in the transfer + :param remote_filepaths: list of remote file paths involved in the transfer + :param operation: SFTP operation to perform, either ``get`` or ``put`` + :param sftp_conn_id: SFTP connection ID to be used for connecting to SFTP server + :param create_intermediate_dirs: create missing local intermediate directories + when performing a ``get`` operation + """ + + def __init__( + self, + local_filepaths: list[str], + remote_filepaths: list[str], + operation: str = "put", + sftp_conn_id: str = "sftp_default", + create_intermediate_dirs: bool = False, + ) -> None: + super().__init__() + self.local_filepaths = local_filepaths + self.remote_filepaths = remote_filepaths + self.operation = operation + self.sftp_conn_id = sftp_conn_id + self.create_intermediate_dirs = create_intermediate_dirs + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize SFTPOperatorTrigger arguments and classpath.""" + return ( + "airflow.providers.sftp.triggers.sftp.SFTPOperatorTrigger", + { + "local_filepaths": self.local_filepaths, + "remote_filepaths": self.remote_filepaths, + "operation": self.operation, + "sftp_conn_id": self.sftp_conn_id, + "create_intermediate_dirs": self.create_intermediate_dirs, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Transfer each local/remote file pair and yield a TriggerEvent on completion or failure.""" + hook = SFTPHookAsync(sftp_conn_id=self.sftp_conn_id) + try: + for local_filepath, remote_filepath in zip(self.local_filepaths, self.remote_filepaths): + if self.operation == "get": + if self.create_intermediate_dirs: + Path(os.path.dirname(local_filepath)).mkdir(parents=True, exist_ok=True) + await hook.retrieve_file(remote_filepath, local_filepath) + else: + await hook.store_file(remote_filepath, local_filepath) + except Exception as e: + yield TriggerEvent({"status": "error", "message": str(e)}) + return + + yield TriggerEvent( + { + "status": "success", + "message": f"Transferred {len(self.local_filepaths)} file(s).", + "local_filepaths": self.local_filepaths, + } + ) diff --git a/providers/sftp/tests/unit/sftp/operators/test_sftp.py b/providers/sftp/tests/unit/sftp/operators/test_sftp.py index 815d981320107..e3af5282e8d82 100644 --- a/providers/sftp/tests/unit/sftp/operators/test_sftp.py +++ b/providers/sftp/tests/unit/sftp/operators/test_sftp.py @@ -30,9 +30,10 @@ from airflow.models import DAG, Connection from airflow.providers.common.compat.openlineage.facet import Dataset -from airflow.providers.common.compat.sdk import AirflowException +from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred from airflow.providers.sftp.hooks.sftp import SFTPHook from airflow.providers.sftp.operators.sftp import SFTPOperation, SFTPOperator +from airflow.providers.sftp.triggers.sftp import SFTPOperatorTrigger from airflow.providers.ssh.hooks.ssh import SSHHook from airflow.providers.ssh.operators.ssh import SSHOperator from airflow.utils import timezone @@ -675,3 +676,138 @@ def test_extract_sftp_hook(self, get_connection, get_conn, operation, expected): assert lineage.inputs == expected[0] assert lineage.outputs == expected[1] + + @pytest.mark.parametrize( + ("operation", "expected_operation"), + [ + (SFTPOperation.GET, "get"), + (SFTPOperation.PUT, "put"), + ], + ) + def test_deferrable_defers_with_trigger(self, operation, expected_operation): + local_filepath = "/tmp/test" + remote_filepath = "/tmp/remotetest" + task = SFTPOperator( + task_id="test_deferrable_defers", + ssh_conn_id=TEST_CONN_ID, + local_filepath=local_filepath, + remote_filepath=remote_filepath, + operation=operation, + create_intermediate_dirs=True, + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as exc_info: + task.execute(None) + + trigger = exc_info.value.trigger + assert isinstance(trigger, SFTPOperatorTrigger) + assert trigger.local_filepaths == [local_filepath] + assert trigger.remote_filepaths == [remote_filepath] + assert trigger.operation == expected_operation + assert trigger.sftp_conn_id == TEST_CONN_ID + assert trigger.create_intermediate_dirs is True + assert exc_info.value.method_name == "execute_complete" + + def test_deferrable_uses_sftp_hook_conn_id_when_no_ssh_conn_id(self): + local_filepath = "/tmp/test" + remote_filepath = "/tmp/remotetest" + task = SFTPOperator( + task_id="test_deferrable_sftp_hook_conn_id", + sftp_hook=self.sftp_hook, + local_filepath=local_filepath, + remote_filepath=remote_filepath, + operation=SFTPOperation.GET, + deferrable=True, + ) + + with pytest.raises(TaskDeferred) as exc_info: + task.execute(None) + + assert exc_info.value.trigger.sftp_conn_id == self.sftp_hook.ssh_conn_id + + def test_deferrable_raises_for_delete_operation(self): + remote_filepath = "/tmp/remotetest" + task = SFTPOperator( + task_id="test_deferrable_delete", + ssh_conn_id=TEST_CONN_ID, + remote_filepath=remote_filepath, + operation=SFTPOperation.DELETE, + deferrable=True, + ) + + with pytest.raises(ValueError, match="deferrable=True is only supported for"): + task.execute(None) + + def test_deferrable_raises_for_concurrency_greater_than_one(self): + local_filepath = "/tmp_local" + remote_filepath = "/tmp_remote" + task = SFTPOperator( + task_id="test_deferrable_concurrency", + ssh_conn_id=TEST_CONN_ID, + local_filepath=local_filepath, + remote_filepath=remote_filepath, + operation=SFTPOperation.GET, + concurrency=2, + deferrable=True, + ) + + with pytest.raises(ValueError, match="deferrable=True is not supported when concurrency > 1"): + task.execute(None) + + def test_deferrable_raises_without_conn_id(self): + local_filepath = "/tmp/test" + remote_filepath = "/tmp/remotetest" + task = SFTPOperator( + task_id="test_deferrable_no_conn_id", + local_filepath=local_filepath, + remote_filepath=remote_filepath, + operation=SFTPOperation.GET, + deferrable=True, + ) + + with pytest.raises(ValueError, match="deferrable=True requires ssh_conn_id to be set"): + task.execute(None) + + def test_execute_complete_success(self): + task = SFTPOperator( + task_id="test_execute_complete_success", + ssh_conn_id=TEST_CONN_ID, + local_filepath="/tmp/test", + remote_filepath="/tmp/remotetest", + operation=SFTPOperation.GET, + deferrable=True, + ) + + result = task.execute_complete( + context=None, + event={"status": "success", "message": "Transferred 1 file(s).", "local_filepaths": ["/tmp/test"]}, + ) + + assert result == "/tmp/test" + + def test_execute_complete_error(self): + task = SFTPOperator( + task_id="test_execute_complete_error", + ssh_conn_id=TEST_CONN_ID, + local_filepath="/tmp/test", + remote_filepath="/tmp/remotetest", + operation=SFTPOperation.GET, + deferrable=True, + ) + + with pytest.raises(AirflowException, match="boom"): + task.execute_complete(context=None, event={"status": "error", "message": "boom"}) + + def test_execute_complete_no_event(self): + task = SFTPOperator( + task_id="test_execute_complete_no_event", + ssh_conn_id=TEST_CONN_ID, + local_filepath="/tmp/test", + remote_filepath="/tmp/remotetest", + operation=SFTPOperation.GET, + deferrable=True, + ) + + with pytest.raises(AirflowException, match="No event received in trigger callback"): + task.execute_complete(context=None, event=None) diff --git a/providers/sftp/tests/unit/sftp/triggers/test_sftp.py b/providers/sftp/tests/unit/sftp/triggers/test_sftp.py index e6c2502f780d6..f363df90aac60 100644 --- a/providers/sftp/tests/unit/sftp/triggers/test_sftp.py +++ b/providers/sftp/tests/unit/sftp/triggers/test_sftp.py @@ -27,7 +27,7 @@ from asyncssh.sftp import SFTPAttrs, SFTPName from airflow.providers.common.compat.sdk import AirflowException -from airflow.providers.sftp.triggers.sftp import SFTPTrigger +from airflow.providers.sftp.triggers.sftp import SFTPOperatorTrigger, SFTPTrigger from airflow.triggers.base import TriggerEvent WARNING_CATEGORY: type[Warning] @@ -220,3 +220,114 @@ async def test_sftp_trigger_run_airflow_exception(self, mock_get_files_by_patter # TriggerEvent was not returned assert task.done() is False asyncio.get_event_loop().stop() + + +class TestSFTPOperatorTrigger: + def test_sftp_operator_trigger_serialization(self): + """Asserts that the SFTPOperatorTrigger correctly serializes its arguments and classpath.""" + trigger = SFTPOperatorTrigger( + local_filepaths=["/tmp/local"], + remote_filepaths=["/tmp/remote"], + operation="get", + sftp_conn_id="sftp_default", + create_intermediate_dirs=True, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.sftp.triggers.sftp.SFTPOperatorTrigger" + assert kwargs == { + "local_filepaths": ["/tmp/local"], + "remote_filepaths": ["/tmp/remote"], + "operation": "get", + "sftp_conn_id": "sftp_default", + "create_intermediate_dirs": True, + } + + @pytest.mark.asyncio + @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.retrieve_file") + async def test_sftp_operator_trigger_run_get_success(self, mock_retrieve_file): + """Assert that a TriggerEvent with a success status is yielded after a successful get.""" + mock_retrieve_file.return_value = None + + trigger = SFTPOperatorTrigger( + local_filepaths=["/tmp/local"], + remote_filepaths=["/tmp/remote"], + operation="get", + sftp_conn_id="sftp_default", + ) + + generator = trigger.run() + actual_event = await generator.asend(None) + + assert actual_event == TriggerEvent( + { + "status": "success", + "message": "Transferred 1 file(s).", + "local_filepaths": ["/tmp/local"], + } + ) + mock_retrieve_file.assert_awaited_once_with("/tmp/remote", "/tmp/local") + + @pytest.mark.asyncio + @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.store_file") + async def test_sftp_operator_trigger_run_put_success(self, mock_store_file): + """Assert that a TriggerEvent with a success status is yielded after a successful put.""" + mock_store_file.return_value = None + + trigger = SFTPOperatorTrigger( + local_filepaths=["/tmp/local"], + remote_filepaths=["/tmp/remote"], + operation="put", + sftp_conn_id="sftp_default", + ) + + generator = trigger.run() + actual_event = await generator.asend(None) + + assert actual_event == TriggerEvent( + { + "status": "success", + "message": "Transferred 1 file(s).", + "local_filepaths": ["/tmp/local"], + } + ) + mock_store_file.assert_awaited_once_with("/tmp/remote", "/tmp/local") + + @pytest.mark.asyncio + @mock.patch("pathlib.Path.mkdir") + @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.retrieve_file") + async def test_sftp_operator_trigger_run_get_creates_intermediate_dirs( + self, mock_retrieve_file, mock_mkdir + ): + """Assert that intermediate local directories are created for get when requested.""" + mock_retrieve_file.return_value = None + + trigger = SFTPOperatorTrigger( + local_filepaths=["/tmp/some/nested/local"], + remote_filepaths=["/tmp/remote"], + operation="get", + sftp_conn_id="sftp_default", + create_intermediate_dirs=True, + ) + + generator = trigger.run() + await generator.asend(None) + + mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync.retrieve_file") + async def test_sftp_operator_trigger_run_failure_state(self, mock_retrieve_file): + """Assert that a TriggerEvent with an error status is yielded if the transfer fails.""" + mock_retrieve_file.side_effect = Exception("An unexpected exception") + + trigger = SFTPOperatorTrigger( + local_filepaths=["/tmp/local"], + remote_filepaths=["/tmp/remote"], + operation="get", + sftp_conn_id="sftp_default", + ) + + generator = trigger.run() + actual_event = await generator.asend(None) + + assert actual_event == TriggerEvent({"status": "error", "message": "An unexpected exception"})