Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions providers/databricks/docs/connections/databricks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,23 @@ Extra (optional)

* ``token``: Specify the PAT to use. Consider specifying the PAT in the Password field instead, as it is more secure.

The following optional parameter can be used when Airflow workers need to access Databricks or Azure
token endpoints through an HTTP proxy:

* ``proxies``: JSON object with optional ``http`` and ``https`` keys, using the same shape as the
``requests`` and Azure SDK ``proxies`` argument. Only these two keys are accepted. The configured proxy
is applied to Databricks REST API calls, Databricks OAuth token exchanges, and Azure Identity token
acquisition for AAD and default Azure credential authentication.

.. code-block:: json

{
"proxies": {
"http": "http://proxy.example.com:8080",
"https": "http://proxy.example.com:8443"
}
}

The following parameters are necessary when authenticating with an OAuth token for a Databricks-managed Service Principal:

* ``service_principal_oauth``: required boolean flag. If specified as ``true``, use the Client ID and Client Secret as the Username and Password. See `Authentication using OAuth for service principals <https://docs.databricks.com/en/dev-tools/authentication-oauth.html>`_.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@
}


class DatabricksProxyConfigurationError(AirflowException):
    """Signal that the proxy configuration of a Databricks connection is not valid."""


class BaseDatabricksHook(BaseHook):
"""
Base for interaction with Databricks.
Expand Down Expand Up @@ -115,6 +119,7 @@ class BaseDatabricksHook(BaseHook):
"azure_ad_endpoint",
"azure_resource_id",
"azure_tenant_id",
"proxies",
"service_principal_oauth",
"federated_k8s",
"k8s_token_path",
Expand Down Expand Up @@ -233,6 +238,47 @@ def _get_connection_attr(self, attr_name: str) -> str:
raise ValueError(f"`{attr_name}` must be present in Connection")
return attr

@cached_property
def proxies(self) -> dict[str, str] | None:
"""Return validated proxy configuration from connection extras."""
extra_dejson = self.databricks_conn.extra_dejson
Comment thread
Vamsi-klu marked this conversation as resolved.
if not isinstance(extra_dejson, dict):
return None

proxies = extra_dejson.get("proxies")
if proxies is None:
return None
if not isinstance(proxies, dict):
raise DatabricksProxyConfigurationError("Connection extra 'proxies' must be a JSON object.")

invalid_keys = set(proxies) - {"http", "https"}
if invalid_keys:
invalid_keys_str = ", ".join(sorted(invalid_keys))
raise DatabricksProxyConfigurationError(
f"Connection extra 'proxies' only supports 'http' and 'https' keys. Got: {invalid_keys_str}."
)

for proxy_scheme, proxy_url in proxies.items():
Comment thread
Vamsi-klu marked this conversation as resolved.
if not isinstance(proxy_url, str) or not proxy_url:
raise DatabricksProxyConfigurationError(
"Connection extra 'proxies' values must be non-empty strings. "
f"Invalid value for '{proxy_scheme}'."
)

return proxies or None

def _get_requests_kwargs(self) -> dict[str, Any]:
    """
    Build the extra keyword arguments for a ``requests`` call.

    :return: ``{"proxies": <copy of self.proxies>}`` when a proxy mapping is configured,
        otherwise an empty dict so call sites can always unpack the result with ``**``.
    """
    configured = self.proxies
    if not configured:
        return {}
    # ``self.proxies`` is cached, so the same dict object would otherwise be handed to every
    # consumer. Give ``requests`` its own shallow copy: a client that adds entries to the mapping
    # it receives (e.g. when merging environment proxy settings) must not be able to alter the
    # validated configuration seen by later calls.
    return {"proxies": dict(configured)}

def _get_aiohttp_kwargs(self, url: str) -> dict[str, str]:
    """
    Build the extra keyword arguments for an ``aiohttp`` request to ``url``.

    :param url: target URL; its scheme selects which entry of ``self.proxies`` is used.
    :return: ``{"proxy": <proxy URL>}`` when a proxy is configured for that scheme,
        otherwise an empty dict.
    """
    configured = self.proxies
    if configured:
        selected = configured.get(urlsplit(url).scheme)
        if selected:
            return {"proxy": selected}
    return {}

def _get_azure_credential_kwargs(self) -> dict[str, dict[str, str]]:
    """
    Build the extra keyword arguments for constructing an Azure Identity credential.

    :return: ``{"proxies": <copy of self.proxies>}`` when a proxy mapping is configured,
        otherwise an empty dict so call sites can always unpack the result with ``**``.
    """
    configured = self.proxies
    if not configured:
        return {}
    # ``self.proxies`` is cached, so pass a shallow copy rather than the cached dict itself:
    # whatever the credential object does with the mapping it is given, the validated
    # configuration used by later calls stays untouched.
    return {"proxies": dict(configured)}

def _get_retry_object(self) -> Retrying:
"""
Instantiate a retry object.
Expand Down Expand Up @@ -268,6 +314,7 @@ def _get_sp_token(self, resource: str) -> str:
"Content-Type": "application/x-www-form-urlencoded",
},
timeout=self.token_timeout_seconds,
**self._get_requests_kwargs(),
)

resp.raise_for_status()
Expand Down Expand Up @@ -306,6 +353,7 @@ async def _a_get_sp_token(self, resource: str) -> str:
"Content-Type": "application/x-www-form-urlencoded",
},
timeout=self.token_timeout_seconds,
**self._get_aiohttp_kwargs(resource),
) as resp:
resp.raise_for_status()
jsn = await resp.json()
Expand Down Expand Up @@ -352,6 +400,7 @@ def _get_aad_token(self, resource: str) -> str:
client_id=self._get_connection_attr("login"),
client_secret=self.databricks_conn.password,
tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"],
**self._get_azure_credential_kwargs(),
)
token = credential.get_token(f"{resource}/.default")
jsn = {
Expand Down Expand Up @@ -403,6 +452,7 @@ async def _a_get_aad_token(self, resource: str) -> str:
client_id=self._get_connection_attr("login"),
client_secret=self.databricks_conn.password,
tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"],
**self._get_azure_credential_kwargs(),
) as credential:
token = await credential.get_token(f"{resource}/.default")
jsn = {
Expand Down Expand Up @@ -445,7 +495,9 @@ def _get_aad_token_for_default_az_credential(self, resource: str) -> str:
#
# While there is a WorkloadIdentityCredential class, the below class is advised by Microsoft
# https://learn.microsoft.com/en-us/azure/aks/workload-identity-overview
token = DefaultAzureCredential().get_token(f"{resource}/.default")
token = DefaultAzureCredential(**self._get_azure_credential_kwargs()).get_token(
Comment thread
Vamsi-klu marked this conversation as resolved.
f"{resource}/.default"
)

jsn = {
"access_token": token.token,
Expand Down Expand Up @@ -490,7 +542,9 @@ async def _a_get_aad_token_for_default_az_credential(self, resource: str) -> str
#
# While there is a WorkloadIdentityCredential class, the below class is advised by Microsoft
# https://learn.microsoft.com/en-us/azure/aks/workload-identity-overview
token = await AsyncDefaultAzureCredential().get_token(f"{resource}/.default")
token = await AsyncDefaultAzureCredential(
**self._get_azure_credential_kwargs()
).get_token(f"{resource}/.default")

jsn = {
"access_token": token.token,
Expand Down Expand Up @@ -865,6 +919,7 @@ def _get_federated_databricks_token(self, resource: str) -> str:
"Content-Type": "application/x-www-form-urlencoded",
},
timeout=self.token_timeout_seconds,
**self._get_requests_kwargs(),
)
resp.raise_for_status()
jsn = resp.json()
Expand Down Expand Up @@ -912,6 +967,7 @@ async def _a_get_federated_databricks_token(self, resource: str) -> str:
"Content-Type": "application/x-www-form-urlencoded",
},
timeout=self.token_timeout_seconds,
**self._get_aiohttp_kwargs(token_exchange_url),
) as resp:
resp.raise_for_status()
jsn = await resp.json()
Expand Down Expand Up @@ -1140,6 +1196,7 @@ def _do_api_call(
auth=auth,
headers=headers,
timeout=self.timeout_seconds,
**self._get_requests_kwargs(),
)
self.log.debug("Response Status Code: %s", response.status_code)
self.log.debug("Response text: %s", response.text)
Expand Down Expand Up @@ -1203,6 +1260,7 @@ async def _a_do_api_call(self, endpoint_info: tuple[str, str], json: dict[str, A
auth=auth,
headers={**headers, **self.user_agent_header},
timeout=self.timeout_seconds,
**self._get_aiohttp_kwargs(url),
) as response:
self.log.debug("Response Status Code: %s", response.status)
self.log.debug("Response text: %s", response.text)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
LOGIN = "login"
PASSWORD = "password"
TOKEN = "token"
PROXIES = {"http": "http://proxy.example.com:8080", "https": "http://proxy.example.com:8443"}
AZURE_DEFAULT_AD_ENDPOINT = "https://login.microsoftonline.com"
AZURE_TOKEN_SERVICE_URL = "{}/{}/oauth2/token"
RUN_PAGE_URL = "https://XX.cloud.databricks.com/#jobs/1/runs/1"
Expand Down Expand Up @@ -460,6 +461,23 @@ def test_do_api_call_patch(self, mock_requests):
timeout=self.hook.timeout_seconds,
)

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_do_api_call_uses_proxies_from_connection_extra(self, mock_requests):
    """The ``proxies`` connection extra is passed as ``proxies=`` to the mocked ``requests.post``."""
    hook = DatabricksHook(retry_delay=0)
    extra_with_proxies = json.dumps({"proxies": PROXIES})
    hook.databricks_conn = Connection(
        conn_id=DEFAULT_CONN_ID,
        conn_type="databricks",
        host=HOST,
        login=LOGIN,
        password=PASSWORD,
        extra=extra_with_proxies,
    )
    mock_requests.post.return_value.json.return_value = {"run_id": "1"}

    run_id = hook.submit_run({"notebook_task": NOTEBOOK_TASK, "new_cluster": NEW_CLUSTER})

    assert run_id == "1"
    # Inspect the most recent ``post`` call made through the patched ``requests`` module.
    post_kwargs = mock_requests.post.call_args.kwargs
    assert post_kwargs["proxies"] == PROXIES

@mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
def test_create(self, mock_requests):
mock_requests.codes.ok = 200
Expand Down Expand Up @@ -1610,6 +1628,7 @@ def setup_connections(self, create_connection_without_db):
extra=json.dumps(
{
"azure_tenant_id": "3ff810a6-5504-4ab8-85cb-cd0e6f879c1d",
"proxies": PROXIES,
}
),
)
Expand All @@ -1629,9 +1648,11 @@ def test_submit_run(self, mock_azure_identity, mock_requests):
run_id = self.hook.submit_run(data)

assert run_id == "1"
assert mock_azure_identity.call_args.kwargs["proxies"] == PROXIES
args = mock_requests.post.call_args
kwargs = args[1]
assert kwargs["auth"].token == TOKEN
assert kwargs["proxies"] == PROXIES


@pytest.mark.db_test
Expand All @@ -1657,6 +1678,7 @@ def setup_connections(self, create_connection_without_db):
{
"azure_tenant_id": self.tenant_id,
"azure_ad_endpoint": self.ad_endpoint,
"proxies": PROXIES,
}
),
)
Expand All @@ -1679,13 +1701,15 @@ def test_submit_run(self, mock_azure_identity, mock_requests):
azure_identity_args = mock_azure_identity.call_args.kwargs
assert azure_identity_args["tenant_id"] == self.tenant_id
assert azure_identity_args["client_id"] == self.client_id
assert azure_identity_args["proxies"] == PROXIES
get_token_args = mock_azure_identity.return_value.get_token.call_args_list
assert get_token_args == [mock.call(f"{DEFAULT_DATABRICKS_SCOPE}/.default")]

assert run_id == "1"
args = mock_requests.post.call_args
kwargs = args[1]
assert kwargs["auth"].token == TOKEN
assert kwargs["proxies"] == PROXIES


@pytest.mark.db_test
Expand Down Expand Up @@ -1913,6 +1937,25 @@ async def test_do_api_call_patch(self, mock_patch):
timeout=self.hook.timeout_seconds,
)

@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
async def test_do_api_call_uses_proxies_from_connection_extra(self, mock_get):
    """The ``https`` entry of the ``proxies`` extra is passed as ``proxy=`` to the mocked ``get``."""
    extra_with_proxies = json.dumps({"proxies": PROXIES})
    self.hook.databricks_conn = Connection(
        conn_id=DEFAULT_CONN_ID,
        conn_type="databricks",
        host=HOST,
        login=LOGIN,
        password=PASSWORD,
        extra=extra_with_proxies,
    )
    # The response is what the async context manager returned by ``get`` yields.
    mocked_response = mock_get.return_value.__aenter__.return_value
    mocked_response.json = AsyncMock(return_value=GET_RUN_RESPONSE)

    async with self.hook:
        run_state = await self.hook.a_get_run_state(RUN_ID)

    expected_state = RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE)
    assert run_state == expected_state
    get_kwargs = mock_get.call_args.kwargs
    assert get_kwargs["proxy"] == PROXIES["https"]

@pytest.mark.asyncio
@mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get")
async def test_get_run_page_url(self, mock_get):
Expand Down Expand Up @@ -2050,6 +2093,7 @@ def setup_connections(self, create_connection_without_db):
{
"azure_tenant_id": self.tenant_id,
"azure_ad_endpoint": self.ad_endpoint,
"proxies": PROXIES,
}
),
)
Expand Down Expand Up @@ -2079,6 +2123,7 @@ async def test_get_run_state(self, mock_get, mock_client_secret_credential_class
credential_call_kwargs = mock_client_secret_credential_class.call_args.kwargs
assert credential_call_kwargs["tenant_id"] == self.tenant_id
assert credential_call_kwargs["client_id"] == self.client_id
assert credential_call_kwargs["proxies"] == PROXIES

mock_credential.get_token.assert_called_once_with(f"{DEFAULT_DATABRICKS_SCOPE}/.default")

Expand All @@ -2088,6 +2133,7 @@ async def test_get_run_state(self, mock_get, mock_client_secret_credential_class
auth=BearerAuth(TOKEN),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
proxy=PROXIES["https"],
)


Expand Down Expand Up @@ -2235,7 +2281,7 @@ def setup_connections(self, create_connection_without_db):
host=HOST,
login="c64f6d12-f6e4-45a4-846e-032b42b27758",
password="secret",
extra=json.dumps({"service_principal_oauth": True}),
extra=json.dumps({"service_principal_oauth": True, "proxies": PROXIES}),
)
)
self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
Expand All @@ -2255,11 +2301,13 @@ def test_submit_run(self, mock_requests):
ad_call_args = mock_requests.method_calls[0]
assert ad_call_args[1][0] == OIDC_TOKEN_SERVICE_URL.format(f"https://{HOST}")
assert ad_call_args[2]["data"] == "grant_type=client_credentials&scope=all-apis"
assert ad_call_args[2]["proxies"] == PROXIES

assert run_id == "1"
args = mock_requests.post.call_args
kwargs = args[1]
assert kwargs["auth"].token == TOKEN
assert kwargs["proxies"] == PROXIES


@pytest.mark.db_test
Expand All @@ -2278,7 +2326,7 @@ def setup_connections(self, create_connection_without_db):
host=HOST,
login="c64f6d12-f6e4-45a4-846e-032b42b27758",
password="secret",
extra=json.dumps({"service_principal_oauth": True}),
extra=json.dumps({"service_principal_oauth": True, "proxies": PROXIES}),
)
)
self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS)
Expand All @@ -2296,12 +2344,14 @@ async def test_get_run_state(self, mock_post, mock_get):
run_state = await self.hook.a_get_run_state(RUN_ID)

assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE)
assert mock_post.call_args.kwargs["proxy"] == PROXIES["https"]
mock_get.assert_called_once_with(
get_run_endpoint(HOST),
json={"run_id": RUN_ID},
auth=BearerAuth(TOKEN),
headers=self.hook.user_agent_header,
timeout=self.hook.timeout_seconds,
proxy=PROXIES["https"],
)


Expand Down
Loading
Loading