From cdfb6135a54ab4f6d5ef41dc7b3a31fca1959cad Mon Sep 17 00:00:00 2001 From: deepinsight coder Date: Sun, 14 Jun 2026 02:51:24 +0000 Subject: [PATCH] Add Databricks connection proxy support --- .../docs/connections/databricks.rst | 17 +++++ .../databricks/hooks/databricks_base.py | 62 ++++++++++++++++++- .../unit/databricks/hooks/test_databricks.py | 54 +++++++++++++++- .../databricks/hooks/test_databricks_base.py | 55 ++++++++++++++++ .../databricks/hooks/test_databricks_sql.py | 24 +++++++ 5 files changed, 208 insertions(+), 4 deletions(-) diff --git a/providers/databricks/docs/connections/databricks.rst b/providers/databricks/docs/connections/databricks.rst index 630526266be8a..b5c4aefb63504 100644 --- a/providers/databricks/docs/connections/databricks.rst +++ b/providers/databricks/docs/connections/databricks.rst @@ -81,6 +81,23 @@ Extra (optional) * ``token``: Specify PAT to use. Consider to switch to specification of PAT in the Password field as it's more secure. + The following optional parameter can be used when Airflow workers need to access Databricks or Azure + token endpoints through an HTTP proxy: + + * ``proxies``: JSON object with optional ``http`` and ``https`` keys, using the same shape as the + ``requests`` and Azure SDK ``proxies`` argument. Only these two keys are accepted. The configured proxy + is applied to Databricks REST API calls, Databricks OAuth token exchanges, and Azure Identity token + acquisition for AAD and default Azure credential authentication. + + .. code-block:: json + + { + "proxies": { + "http": "http://proxy.example.com:8080", + "https": "http://proxy.example.com:8443" + } + } + Following parameters are necessary if using authentication with OAuth token for Databricks-managed Service Principal: * ``service_principal_oauth``: required boolean flag. If specified as ``true``, use the Client ID and Client Secret as the Username and Password. See `Authentication using OAuth for service principals `_. diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py index 7ab51fc7dfa12..272b65c6120d5 100644 --- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py +++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_base.py @@ -87,6 +87,10 @@ } +class DatabricksProxyConfigurationError(AirflowException): + """Raised when Databricks connection proxy configuration is invalid.""" + + class BaseDatabricksHook(BaseHook): """ Base for interaction with Databricks. @@ -115,6 +119,7 @@ class BaseDatabricksHook(BaseHook): "azure_ad_endpoint", "azure_resource_id", "azure_tenant_id", + "proxies", "service_principal_oauth", "federated_k8s", "k8s_token_path", @@ -233,6 +238,47 @@ def _get_connection_attr(self, attr_name: str) -> str: raise ValueError(f"`{attr_name}` must be present in Connection") return attr + @cached_property + def proxies(self) -> dict[str, str] | None: + """Return validated proxy configuration from connection extras.""" + extra_dejson = self.databricks_conn.extra_dejson + if not isinstance(extra_dejson, dict): + return None + + proxies = extra_dejson.get("proxies") + if proxies is None: + return None + if not isinstance(proxies, dict): + raise DatabricksProxyConfigurationError("Connection extra 'proxies' must be a JSON object.") + + invalid_keys = set(proxies) - {"http", "https"} + if invalid_keys: + invalid_keys_str = ", ".join(sorted(invalid_keys)) + raise DatabricksProxyConfigurationError( + f"Connection extra 'proxies' only supports 'http' and 'https' keys. Got: {invalid_keys_str}." + ) + + for proxy_scheme, proxy_url in proxies.items(): + if not isinstance(proxy_url, str) or not proxy_url: + raise DatabricksProxyConfigurationError( + "Connection extra 'proxies' values must be non-empty strings. " + f"Invalid value for '{proxy_scheme}'." + ) + + return proxies or None + + def _get_requests_kwargs(self) -> dict[str, Any]: + return {"proxies": self.proxies} if self.proxies else {} + + def _get_aiohttp_kwargs(self, url: str) -> dict[str, str]: + if not self.proxies: + return {} + proxy = self.proxies.get(urlsplit(url).scheme) + return {"proxy": proxy} if proxy else {} + + def _get_azure_credential_kwargs(self) -> dict[str, dict[str, str]]: + return {"proxies": self.proxies} if self.proxies else {} + def _get_retry_object(self) -> Retrying: """ Instantiate a retry object. @@ -268,6 +314,7 @@ def _get_sp_token(self, resource: str) -> str: "Content-Type": "application/x-www-form-urlencoded", }, timeout=self.token_timeout_seconds, + **self._get_requests_kwargs(), ) resp.raise_for_status() @@ -306,6 +353,7 @@ async def _a_get_sp_token(self, resource: str) -> str: "Content-Type": "application/x-www-form-urlencoded", }, timeout=self.token_timeout_seconds, + **self._get_aiohttp_kwargs(resource), ) as resp: resp.raise_for_status() jsn = await resp.json() @@ -352,6 +400,7 @@ def _get_aad_token(self, resource: str) -> str: client_id=self._get_connection_attr("login"), client_secret=self.databricks_conn.password, tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"], + **self._get_azure_credential_kwargs(), ) token = credential.get_token(f"{resource}/.default") jsn = { @@ -403,6 +452,7 @@ async def _a_get_aad_token(self, resource: str) -> str: client_id=self._get_connection_attr("login"), client_secret=self.databricks_conn.password, tenant_id=self.databricks_conn.extra_dejson["azure_tenant_id"], + **self._get_azure_credential_kwargs(), ) as credential: token = await credential.get_token(f"{resource}/.default") jsn = { @@ -445,7 +495,9 @@ def _get_aad_token_for_default_az_credential(self, resource: str) -> str: # # While there is a WorkloadIdentityCredential class, the below class is advised by Microsoft # https://learn.microsoft.com/en-us/azure/aks/workload-identity-overview - token = DefaultAzureCredential().get_token(f"{resource}/.default") + token = DefaultAzureCredential(**self._get_azure_credential_kwargs()).get_token( + f"{resource}/.default" + ) jsn = { "access_token": token.token, @@ -490,7 +542,9 @@ async def _a_get_aad_token_for_default_az_credential(self, resource: str) -> str # # While there is a WorkloadIdentityCredential class, the below class is advised by Microsoft # https://learn.microsoft.com/en-us/azure/aks/workload-identity-overview - token = await AsyncDefaultAzureCredential().get_token(f"{resource}/.default") + token = await AsyncDefaultAzureCredential( + **self._get_azure_credential_kwargs() + ).get_token(f"{resource}/.default") jsn = { "access_token": token.token, @@ -865,6 +919,7 @@ def _get_federated_databricks_token(self, resource: str) -> str: "Content-Type": "application/x-www-form-urlencoded", }, timeout=self.token_timeout_seconds, + **self._get_requests_kwargs(), ) resp.raise_for_status() jsn = resp.json() @@ -912,6 +967,7 @@ async def _a_get_federated_databricks_token(self, resource: str) -> str: "Content-Type": "application/x-www-form-urlencoded", }, timeout=self.token_timeout_seconds, + **self._get_aiohttp_kwargs(token_exchange_url), ) as resp: resp.raise_for_status() jsn = await resp.json() @@ -1140,6 +1196,7 @@ def _do_api_call( auth=auth, headers=headers, timeout=self.timeout_seconds, + **self._get_requests_kwargs(), ) self.log.debug("Response Status Code: %s", response.status_code) self.log.debug("Response text: %s", response.text) @@ -1203,6 +1260,7 @@ async def _a_do_api_call(self, endpoint_info: tuple[str, str], json: dict[str, A auth=auth, headers={**headers, **self.user_agent_header}, timeout=self.timeout_seconds, + **self._get_aiohttp_kwargs(url), ) as response: self.log.debug("Response Status Code: %s", response.status) self.log.debug("Response text: %s", response.text) diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks.py index 558e39a212987..af2f028b44cf8 100644 --- a/providers/databricks/tests/unit/databricks/hooks/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks.py @@ -78,6 +78,7 @@ LOGIN = "login" PASSWORD = "password" TOKEN = "token" +PROXIES = {"http": "http://proxy.example.com:8080", "https": "http://proxy.example.com:8443"} AZURE_DEFAULT_AD_ENDPOINT = "https://login.microsoftonline.com" AZURE_TOKEN_SERVICE_URL = "{}/{}/oauth2/token" RUN_PAGE_URL = "https://XX.cloud.databricks.com/#jobs/1/runs/1" @@ -460,6 +461,23 @@ def test_do_api_call_patch(self, mock_requests): timeout=self.hook.timeout_seconds, ) + @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") + def test_do_api_call_uses_proxies_from_connection_extra(self, mock_requests): + hook = DatabricksHook(retry_delay=0) + hook.databricks_conn = Connection( + conn_id=DEFAULT_CONN_ID, + conn_type="databricks", + host=HOST, + login=LOGIN, + password=PASSWORD, + extra=json.dumps({"proxies": PROXIES}), + ) + mock_requests.post.return_value.json.return_value = {"run_id": "1"} + + assert hook.submit_run({"notebook_task": NOTEBOOK_TASK, "new_cluster": NEW_CLUSTER}) == "1" + + assert mock_requests.post.call_args.kwargs["proxies"] == PROXIES + @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests") def test_create(self, mock_requests): mock_requests.codes.ok = 200 @@ -1610,6 +1628,7 @@ def setup_connections(self, create_connection_without_db): extra=json.dumps( { "azure_tenant_id": "3ff810a6-5504-4ab8-85cb-cd0e6f879c1d", + "proxies": PROXIES, } ), ) @@ -1629,9 +1648,11 @@ def test_submit_run(self, mock_azure_identity, mock_requests): run_id = self.hook.submit_run(data) assert run_id == "1" + assert mock_azure_identity.call_args.kwargs["proxies"] == PROXIES args = mock_requests.post.call_args kwargs = args[1] assert kwargs["auth"].token == TOKEN + assert kwargs["proxies"] == PROXIES @pytest.mark.db_test @@ -1657,6 +1678,7 @@ def setup_connections(self, create_connection_without_db): { "azure_tenant_id": self.tenant_id, "azure_ad_endpoint": self.ad_endpoint, + "proxies": PROXIES, } ), ) @@ -1679,6 +1701,7 @@ def test_submit_run(self, mock_azure_identity, mock_requests): azure_identity_args = mock_azure_identity.call_args.kwargs assert azure_identity_args["tenant_id"] == self.tenant_id assert azure_identity_args["client_id"] == self.client_id + assert azure_identity_args["proxies"] == PROXIES get_token_args = mock_azure_identity.return_value.get_token.call_args_list assert get_token_args == [mock.call(f"{DEFAULT_DATABRICKS_SCOPE}/.default")] @@ -1686,6 +1709,7 @@ def test_submit_run(self, mock_azure_identity, mock_requests): args = mock_requests.post.call_args kwargs = args[1] assert kwargs["auth"].token == TOKEN + assert kwargs["proxies"] == PROXIES @pytest.mark.db_test @@ -1913,6 +1937,25 @@ async def test_do_api_call_patch(self, mock_patch): timeout=self.hook.timeout_seconds, ) + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get") + async def test_do_api_call_uses_proxies_from_connection_extra(self, mock_get): + self.hook.databricks_conn = Connection( + conn_id=DEFAULT_CONN_ID, + conn_type="databricks", + host=HOST, + login=LOGIN, + password=PASSWORD, + extra=json.dumps({"proxies": PROXIES}), + ) + mock_get.return_value.__aenter__.return_value.json = AsyncMock(return_value=GET_RUN_RESPONSE) + + async with self.hook: + run_state = await self.hook.a_get_run_state(RUN_ID) + + assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE) + assert mock_get.call_args.kwargs["proxy"] == PROXIES["https"] + @pytest.mark.asyncio @mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get") async def test_get_run_page_url(self, mock_get): @@ -2050,6 +2093,7 @@ def setup_connections(self, create_connection_without_db): { "azure_tenant_id": self.tenant_id, "azure_ad_endpoint": self.ad_endpoint, + "proxies": PROXIES, } ), ) @@ -2079,6 +2123,7 @@ async def test_get_run_state(self, mock_get, mock_client_secret_credential_class credential_call_kwargs = mock_client_secret_credential_class.call_args.kwargs assert credential_call_kwargs["tenant_id"] == self.tenant_id assert credential_call_kwargs["client_id"] == self.client_id + assert credential_call_kwargs["proxies"] == PROXIES mock_credential.get_token.assert_called_once_with(f"{DEFAULT_DATABRICKS_SCOPE}/.default") @@ -2088,6 +2133,7 @@ async def test_get_run_state(self, mock_get, mock_client_secret_credential_class auth=BearerAuth(TOKEN), headers=self.hook.user_agent_header, timeout=self.hook.timeout_seconds, + proxy=PROXIES["https"], ) @@ -2235,7 +2281,7 @@ def setup_connections(self, create_connection_without_db): host=HOST, login="c64f6d12-f6e4-45a4-846e-032b42b27758", password="secret", - extra=json.dumps({"service_principal_oauth": True}), + extra=json.dumps({"service_principal_oauth": True, "proxies": PROXIES}), ) ) self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) @@ -2255,11 +2301,13 @@ def test_submit_run(self, mock_requests): ad_call_args = mock_requests.method_calls[0] assert ad_call_args[1][0] == OIDC_TOKEN_SERVICE_URL.format(f"https://{HOST}") assert ad_call_args[2]["data"] == "grant_type=client_credentials&scope=all-apis" + assert ad_call_args[2]["proxies"] == PROXIES assert run_id == "1" args = mock_requests.post.call_args kwargs = args[1] assert kwargs["auth"].token == TOKEN + assert kwargs["proxies"] == PROXIES @pytest.mark.db_test @@ -2278,7 +2326,7 @@ def setup_connections(self, create_connection_without_db): host=HOST, login="c64f6d12-f6e4-45a4-846e-032b42b27758", password="secret", - extra=json.dumps({"service_principal_oauth": True}), + extra=json.dumps({"service_principal_oauth": True, "proxies": PROXIES}), ) ) self.hook = DatabricksHook(retry_args=DEFAULT_RETRY_ARGS) @@ -2296,12 +2344,14 @@ async def test_get_run_state(self, mock_post, mock_get): run_state = await self.hook.a_get_run_state(RUN_ID) assert run_state == RunState(LIFE_CYCLE_STATE, RESULT_STATE, STATE_MESSAGE) + assert mock_post.call_args.kwargs["proxy"] == PROXIES["https"] mock_get.assert_called_once_with( get_run_endpoint(HOST), json={"run_id": RUN_ID}, auth=BearerAuth(TOKEN), headers=self.hook.user_agent_header, timeout=self.hook.timeout_seconds, + proxy=PROXIES["https"], ) diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py index 090a3e34c7814..393107497d711 100644 --- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py +++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_base.py @@ -37,9 +37,11 @@ K8S_CA_CERT_PATH, TOKEN_REFRESH_LEAD_TIME, BaseDatabricksHook, + DatabricksProxyConfigurationError, ) DEFAULT_CONN_ID = "databricks_default" +PROXIES = {"http": "http://proxy.example.com:8080", "https": "http://proxy.example.com:8443"} class TestBaseDatabricksHook: @@ -110,6 +112,59 @@ def test_init_with_default_params(self): def test_parse_host(self, input_url, expected_host): assert BaseDatabricksHook._parse_host(input_url) == expected_host + @mock.patch( + "airflow.providers.databricks.hooks.databricks_base.BaseDatabricksHook.databricks_conn", + new_callable=mock.PropertyMock, + ) + def test_proxies_from_extra(self, mock_conn): + mock_conn.return_value = Connection(extra={"proxies": PROXIES}) + hook = BaseDatabricksHook() + + assert hook.proxies == PROXIES + assert hook._get_requests_kwargs() == {"proxies": PROXIES} + assert hook._get_azure_credential_kwargs() == {"proxies": PROXIES} + + @mock.patch( + "airflow.providers.databricks.hooks.databricks_base.BaseDatabricksHook.databricks_conn", + new_callable=mock.PropertyMock, + ) + def test_aiohttp_proxy_uses_request_scheme(self, mock_conn): + mock_conn.return_value = Connection(extra={"proxies": PROXIES}) + hook = BaseDatabricksHook() + + assert hook._get_aiohttp_kwargs("https://example.databricks.com/api") == {"proxy": PROXIES["https"]} + assert hook._get_aiohttp_kwargs("http://example.databricks.com/api") == {"proxy": PROXIES["http"]} + + @mock.patch( + "airflow.providers.databricks.hooks.databricks_base.BaseDatabricksHook.databricks_conn", + new_callable=mock.PropertyMock, + ) + def test_aiohttp_proxy_returns_empty_kwargs_without_matching_scheme(self, mock_conn): + mock_conn.return_value = Connection(extra={"proxies": {"http": PROXIES["http"]}}) + hook = BaseDatabricksHook() + + assert hook._get_aiohttp_kwargs("https://example.databricks.com/api") == {} + + @mock.patch( + "airflow.providers.databricks.hooks.databricks_base.BaseDatabricksHook.databricks_conn", + new_callable=mock.PropertyMock, + ) + @pytest.mark.parametrize( + ("proxies", "message"), + [ + ("http://proxy.example.com:8080", "must be a JSON object"), + ({"ftp": "http://proxy.example.com:8080"}, "only supports 'http' and 'https' keys"), + ({"https": ""}, "values must be non-empty strings"), + ({"https": 8080}, "values must be non-empty strings"), + ], + ) + def test_proxies_invalid_extra_raises(self, mock_conn, proxies, message): + mock_conn.return_value = Connection(extra={"proxies": proxies}) + hook = BaseDatabricksHook() + + with pytest.raises(DatabricksProxyConfigurationError, match=message): + hook.proxies + @mock.patch("requests.post") @time_machine.travel("2025-07-12 12:00:00", tick=False) def test_get_sp_token(self, mock_post): diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py index cd3c00e2839b1..6fd835b787287 100644 --- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py +++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py @@ -854,6 +854,30 @@ def test_get_conn_no_query_tags(mock_connect, mock_get_requests): assert session_cfg is None or "QUERY_TAGS" not in session_cfg +@mock.patch("airflow.providers.databricks.hooks.databricks_sql.sql.connect") +def test_get_conn_does_not_leak_proxies_into_connector(mock_connect, mock_get_requests): + """A ``proxies`` connection extra must not be forwarded to ``sql.connect()``. + + ``proxies`` configures the REST/token HTTP paths only; the + databricks-sql-connector does not accept it and raises ``TypeError`` on + unexpected keyword arguments (>=4.0.0). It is listed in + ``extra_parameters`` so ``_get_extra_config`` strips it from connect kwargs. + """ + hook = DatabricksSqlHook(databricks_conn_id=DEFAULT_CONN_ID, http_path=HTTP_PATH) + hook.databricks_conn = Connection( + conn_id=DEFAULT_CONN_ID, + conn_type="databricks", + host=HOST, + password=TOKEN, + extra={"proxies": {"https": "http://proxy.example.com:8443"}}, + ) + + hook.get_conn() + + mock_connect.assert_called_once() + assert "proxies" not in mock_connect.call_args.kwargs + + class TestFormatQueryTags: def test_simple_values(self): result = _format_query_tags({"dag_id": "my_dag", "task_id": "my_task"})