Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions providers/sftp/newsfragments/65475.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add ``deferrable`` parameter to ``SFTPOperator``, allowing ``get`` and ``put`` transfers of individual files to be deferred to a trigger instead of blocking a worker slot.
50 changes: 49 additions & 1 deletion providers/sftp/src/airflow/providers/sftp/operators/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@

import paramiko

from airflow.providers.common.compat.sdk import AirflowException, BaseOperator
from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, conf
from airflow.providers.sftp.hooks.sftp import SFTPHook
from airflow.providers.sftp.triggers.sftp import SFTPOperatorTrigger


class SFTPOperation:
Expand Down Expand Up @@ -77,6 +78,10 @@ class SFTPOperator(BaseOperator):
:param concurrency: Number of threads when transferring directories. Each thread opens a new SFTP connection.
This parameter is used only when transferring directories, not individual files. (Default is 1)
:param prefetch: controls whether prefetch is performed (default: True)
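:param deferrable: If True, the operator defers to a trigger that transfers the file(s) asynchronously,
freeing the worker slot while the transfer is in progress. Only supported for ``get`` and ``put``
operations on individual files with ``concurrency`` left at 1; ``delete`` operations and directory
transfers are not supported in deferrable mode. A connection ID is required, taken from
``ssh_conn_id`` or, failing that, from the supplied ``sftp_hook``.
(Default is the ``[operators] default_deferrable`` configuration value, falling back to False)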

"""

Expand All @@ -95,6 +100,7 @@ def __init__(
create_intermediate_dirs: bool = False,
concurrency: int = 1,
prefetch: bool = True,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -108,6 +114,7 @@ def __init__(
self.remote_filepath = remote_filepath
self.concurrency = concurrency
self.prefetch = prefetch
self.deferrable = deferrable

def execute(self, context: Any) -> str | list[str] | None:
if self.local_filepath is None:
Expand Down Expand Up @@ -142,6 +149,30 @@ def execute(self, context: Any) -> str | list[str] | None:
if self.concurrency < 1:
raise ValueError(f"concurrency should be greater than 0, got {self.concurrency}")

if self.deferrable:
if self.operation.lower() not in (SFTPOperation.GET, SFTPOperation.PUT):
raise ValueError(
f"deferrable=True is only supported for '{SFTPOperation.GET}' and "
f"'{SFTPOperation.PUT}' operations, got '{self.operation}'."
)
if self.concurrency > 1:
raise ValueError("deferrable=True is not supported when concurrency > 1.")

sftp_conn_id = self.ssh_conn_id or (self.sftp_hook.ssh_conn_id if self.sftp_hook else None)
if not sftp_conn_id:
raise ValueError("deferrable=True requires ssh_conn_id to be set.")

self.defer(
trigger=SFTPOperatorTrigger(
local_filepaths=local_filepath_array,
remote_filepaths=remote_filepath_array,
operation=self.operation.lower(),
sftp_conn_id=sftp_conn_id,
create_intermediate_dirs=self.create_intermediate_dirs,
),
method_name="execute_complete",
)

file_msg = None
try:
if self.remote_host is not None:
Expand Down Expand Up @@ -227,6 +258,23 @@ def execute(self, context: Any) -> str | list[str] | None:

return self.local_filepath

def execute_complete(self, context: Any, event: dict[str, Any] | None = None) -> str | list[str] | None:
    """
    Resume the task after the trigger has fired and report the transfer outcome.

    The trigger signals failure through an event whose ``status`` is ``"error"``; this
    callback converts such an event into a task failure. Any other status is treated
    as a successful transfer.

    :param context: Airflow task context (not used by this callback).
    :param event: payload emitted by the trigger; ``status`` and ``message`` keys are read.
    :return: the operator's ``local_filepath``.
    :raises AirflowException: if no event was received or the event reports an error.
    """
    if event is None:
        raise AirflowException("No event received in trigger callback")

    # Read the message defensively: a payload without "message" must not replace the
    # real outcome with an unrelated KeyError.
    message = event.get("message")

    if event.get("status") == "error":
        raise AirflowException(
            f"Error while processing {self.operation.upper()} operation, "
            f"error: {message or 'unknown error'}"
        )

    if message:
        self.log.info(message)
    return self.local_filepath

@staticmethod
def _is_missing_path_error(exc: Exception) -> bool:
if isinstance(exc, FileNotFoundError):
Expand Down
66 changes: 66 additions & 0 deletions providers/sftp/src/airflow/providers/sftp/triggers/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
from __future__ import annotations

import asyncio
import os
from collections.abc import AsyncIterator
from datetime import datetime
from pathlib import Path
from typing import Any

from dateutil.parser import parse as parse_date
Expand Down Expand Up @@ -140,3 +142,67 @@ async def run(self) -> AsyncIterator[TriggerEvent]:

def _get_async_hook(self) -> SFTPHookAsync:
    """Create an asynchronous SFTP hook for the connection ID stored in ``sftp_conn_id``."""
    conn_id = self.sftp_conn_id
    return SFTPHookAsync(sftp_conn_id=conn_id)


class SFTPOperatorTrigger(BaseTrigger):
    """
    Trigger that asynchronously transfers files between local and remote SFTP paths.

    ``local_filepaths`` and ``remote_filepaths`` are paired positionally, so both lists
    must have the same length.

    :param local_filepaths: list of local file paths involved in the transfer
    :param remote_filepaths: list of remote file paths involved in the transfer
    :param operation: SFTP operation to perform, either ``get`` or ``put`` (case-insensitive)
    :param sftp_conn_id: SFTP connection ID to be used for connecting to SFTP server
    :param create_intermediate_dirs: create missing local intermediate directories
        when performing a ``get`` operation; it has no effect on ``put``
    """

    # Operations this trigger knows how to perform; anything else is reported as an error.
    _SUPPORTED_OPERATIONS = ("get", "put")

    def __init__(
        self,
        local_filepaths: list[str],
        remote_filepaths: list[str],
        operation: str = "put",
        sftp_conn_id: str = "sftp_default",
        create_intermediate_dirs: bool = False,
    ) -> None:
        super().__init__()
        self.local_filepaths = local_filepaths
        self.remote_filepaths = remote_filepaths
        self.operation = operation
        self.sftp_conn_id = sftp_conn_id
        self.create_intermediate_dirs = create_intermediate_dirs

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize SFTPOperatorTrigger arguments and classpath."""
        return (
            "airflow.providers.sftp.triggers.sftp.SFTPOperatorTrigger",
            {
                "local_filepaths": self.local_filepaths,
                "remote_filepaths": self.remote_filepaths,
                "operation": self.operation,
                "sftp_conn_id": self.sftp_conn_id,
                "create_intermediate_dirs": self.create_intermediate_dirs,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Transfer each local/remote file pair and yield a single TriggerEvent.

        The event carries ``status`` ``"success"`` once every pair has been transferred,
        or ``"error"`` with a ``message`` as soon as validation or a transfer fails.
        """
        operation = self.operation.lower()
        if operation not in self._SUPPORTED_OPERATIONS:
            # Without this guard every value other than "get" would silently run an upload.
            yield TriggerEvent(
                {
                    "status": "error",
                    "message": f"Unsupported operation '{self.operation}'; expected 'get' or 'put'.",
                }
            )
            return

        if len(self.local_filepaths) != len(self.remote_filepaths):
            # zip() would drop the unpaired tail while the success event still reported
            # the full count, so refuse to start instead.
            yield TriggerEvent(
                {
                    "status": "error",
                    "message": (
                        f"Mismatched path lists: {len(self.local_filepaths)} local path(s) "
                        f"but {len(self.remote_filepaths)} remote path(s)."
                    ),
                }
            )
            return

        hook = SFTPHookAsync(sftp_conn_id=self.sftp_conn_id)
        try:
            for local_filepath, remote_filepath in zip(self.local_filepaths, self.remote_filepaths):
                if operation == "get":
                    if self.create_intermediate_dirs:
                        Path(os.path.dirname(local_filepath)).mkdir(parents=True, exist_ok=True)
                    await hook.retrieve_file(remote_filepath, local_filepath)
                else:
                    await hook.store_file(remote_filepath, local_filepath)
        except Exception as e:
            # Failures are reported through the event so the operator callback can fail the task.
            # Some exceptions stringify to "", so fall back to the exception type name.
            yield TriggerEvent({"status": "error", "message": str(e) or type(e).__name__})
            return

        yield TriggerEvent(
            {
                "status": "success",
                "message": f"Transferred {len(self.local_filepaths)} file(s).",
                "local_filepaths": self.local_filepaths,
            }
        )
138 changes: 137 additions & 1 deletion providers/sftp/tests/unit/sftp/operators/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@

from airflow.models import DAG, Connection
from airflow.providers.common.compat.openlineage.facet import Dataset
from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred
from airflow.providers.sftp.hooks.sftp import SFTPHook
from airflow.providers.sftp.operators.sftp import SFTPOperation, SFTPOperator
from airflow.providers.sftp.triggers.sftp import SFTPOperatorTrigger
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.providers.ssh.operators.ssh import SSHOperator
from airflow.utils import timezone
Expand Down Expand Up @@ -675,3 +676,138 @@ def test_extract_sftp_hook(self, get_connection, get_conn, operation, expected):

assert lineage.inputs == expected[0]
assert lineage.outputs == expected[1]

@pytest.mark.parametrize(
    ("operation", "expected_operation"),
    [(SFTPOperation.GET, "get"), (SFTPOperation.PUT, "put")],
)
def test_deferrable_defers_with_trigger(self, operation, expected_operation):
    """A deferrable ``get``/``put`` must raise ``TaskDeferred`` carrying a fully populated trigger."""
    local_path, remote_path = "/tmp/test", "/tmp/remotetest"
    operator = SFTPOperator(
        task_id="test_deferrable_defers",
        ssh_conn_id=TEST_CONN_ID,
        local_filepath=local_path,
        remote_filepath=remote_path,
        operation=operation,
        create_intermediate_dirs=True,
        deferrable=True,
    )

    with pytest.raises(TaskDeferred) as deferred:
        operator.execute(None)

    # Inspect the trigger attached to the deferral, then the resume method name.
    built_trigger = deferred.value.trigger
    assert isinstance(built_trigger, SFTPOperatorTrigger)
    assert built_trigger.local_filepaths == [local_path]
    assert built_trigger.remote_filepaths == [remote_path]
    assert built_trigger.operation == expected_operation
    assert built_trigger.sftp_conn_id == TEST_CONN_ID
    assert built_trigger.create_intermediate_dirs is True
    assert deferred.value.method_name == "execute_complete"

def test_deferrable_uses_sftp_hook_conn_id_when_no_ssh_conn_id(self):
    """Without ``ssh_conn_id`` the trigger must take its connection ID from the supplied ``sftp_hook``."""
    operator = SFTPOperator(
        task_id="test_deferrable_sftp_hook_conn_id",
        sftp_hook=self.sftp_hook,
        local_filepath="/tmp/test",
        remote_filepath="/tmp/remotetest",
        operation=SFTPOperation.GET,
        deferrable=True,
    )

    with pytest.raises(TaskDeferred) as deferred:
        operator.execute(None)

    trigger_conn_id = deferred.value.trigger.sftp_conn_id
    assert trigger_conn_id == self.sftp_hook.ssh_conn_id

def test_deferrable_raises_for_delete_operation(self):
    """``delete`` combined with ``deferrable=True`` must be rejected with a ``ValueError``."""
    operator = SFTPOperator(
        task_id="test_deferrable_delete",
        ssh_conn_id=TEST_CONN_ID,
        remote_filepath="/tmp/remotetest",
        operation=SFTPOperation.DELETE,
        deferrable=True,
    )

    expected_error = "deferrable=True is only supported for"
    with pytest.raises(ValueError, match=expected_error):
        operator.execute(None)

def test_deferrable_raises_for_concurrency_greater_than_one(self):
    """``concurrency`` above 1 combined with ``deferrable=True`` must be rejected with a ``ValueError``."""
    operator = SFTPOperator(
        task_id="test_deferrable_concurrency",
        ssh_conn_id=TEST_CONN_ID,
        local_filepath="/tmp_local",
        remote_filepath="/tmp_remote",
        operation=SFTPOperation.GET,
        concurrency=2,
        deferrable=True,
    )

    expected_error = "deferrable=True is not supported when concurrency > 1"
    with pytest.raises(ValueError, match=expected_error):
        operator.execute(None)

def test_deferrable_raises_without_conn_id(self):
    """Deferrable mode with neither ``ssh_conn_id`` nor ``sftp_hook`` must be rejected with a ``ValueError``."""
    operator = SFTPOperator(
        task_id="test_deferrable_no_conn_id",
        local_filepath="/tmp/test",
        remote_filepath="/tmp/remotetest",
        operation=SFTPOperation.GET,
        deferrable=True,
    )

    expected_error = "deferrable=True requires ssh_conn_id to be set"
    with pytest.raises(ValueError, match=expected_error):
        operator.execute(None)

def test_execute_complete_success(self):
    """A success event must make ``execute_complete`` return the operator's ``local_filepath``."""
    operator = SFTPOperator(
        task_id="test_execute_complete_success",
        ssh_conn_id=TEST_CONN_ID,
        local_filepath="/tmp/test",
        remote_filepath="/tmp/remotetest",
        operation=SFTPOperation.GET,
        deferrable=True,
    )
    success_event = {
        "status": "success",
        "message": "Transferred 1 file(s).",
        "local_filepaths": ["/tmp/test"],
    }

    returned = operator.execute_complete(context=None, event=success_event)

    assert returned == "/tmp/test"

def test_execute_complete_error(self):
    """An error event must surface as an ``AirflowException`` that includes the trigger's message."""
    operator = SFTPOperator(
        task_id="test_execute_complete_error",
        ssh_conn_id=TEST_CONN_ID,
        local_filepath="/tmp/test",
        remote_filepath="/tmp/remotetest",
        operation=SFTPOperation.GET,
        deferrable=True,
    )
    error_event = {"status": "error", "message": "boom"}

    with pytest.raises(AirflowException, match="boom"):
        operator.execute_complete(context=None, event=error_event)

def test_execute_complete_no_event(self):
    """A missing event must make ``execute_complete`` raise an ``AirflowException``."""
    operator = SFTPOperator(
        task_id="test_execute_complete_no_event",
        ssh_conn_id=TEST_CONN_ID,
        local_filepath="/tmp/test",
        remote_filepath="/tmp/remotetest",
        operation=SFTPOperation.GET,
        deferrable=True,
    )

    expected_error = "No event received in trigger callback"
    with pytest.raises(AirflowException, match=expected_error):
        operator.execute_complete(context=None, event=None)
Loading