diff --git a/src/fides/api/service/connectors/saas_connector.py b/src/fides/api/service/connectors/saas_connector.py index 8b53545ac30..7eb247d3802 100644 --- a/src/fides/api/service/connectors/saas_connector.py +++ b/src/fides/api/service/connectors/saas_connector.py @@ -22,6 +22,7 @@ from fides.api.models.connectionconfig import ConnectionConfig, ConnectionTestStatus from fides.api.models.policy import Policy from fides.api.models.privacy_request import PrivacyRequest, RequestTask +from fides.api.models.privacy_request.request_task import AsyncTaskType from fides.api.schemas.consentable_item import ( ConsentableItem, build_consent_item_hierarchy, @@ -277,11 +278,20 @@ def retrieve_data( # Delegate async requests with get_db() as db: - # Guard clause to ensure we only run async access requests for access requests - if self.guard_access_request(policy): - if async_dsr_strategy := _get_async_dsr_strategy( - db, request_task, query_config, ActionType.access + if async_dsr_strategy := _get_async_dsr_strategy( + db, request_task, query_config, ActionType.access + ): + check_guard_access_request = self.guard_access_request(policy) + # Guard clause only applies to polling requests + # Callback requests should always proceed + if (async_dsr_strategy.type == AsyncTaskType.polling) and ( + not check_guard_access_request ): + logger.info( + f"Skipping async access request for policy: {policy.name}" + ) + return [] + if check_guard_access_request: return async_dsr_strategy.async_retrieve_data( client=self.create_client(), request_task_id=request_task.id, diff --git a/tests/fixtures/saas/test_data/saas_async_polling_config.yml b/tests/fixtures/saas/test_data/saas_async_polling_config.yml new file mode 100644 index 00000000000..4c26bb4d47d --- /dev/null +++ b/tests/fixtures/saas/test_data/saas_async_polling_config.yml @@ -0,0 +1,66 @@ +saas_config: + fides_key: saas_async_polling_config + name: Async Polling Example Custom Connector + type: async_polling_example + description: Test Async Polling Config + version: 0.0.1 + + connector_params: + - name: domain + - name: api_token + label: API token + + client_config: + protocol: http + host: + authentication: + strategy: bearer + configuration: + token: + + test_request: + method: GET + path: / + + endpoints: + - name: user + requests: + read: + method: GET + path: /api/v1/user + query_params: + - name: query + value: + param_values: + - name: email + identity: email + correlation_id_path: request_id + async_config: + strategy: polling + configuration: + status_request: + method: GET + path: /api/v1/user/status + status_path: status + status_completed_value: completed + result_request: + method: GET + path: /api/v1/user/result + update: + method: DELETE + path: /api/v1/user/ + correlation_id_path: correlation_id + async_config: + strategy: polling + configuration: + status_request: + method: GET + path: /api/v1/user//status + status_path: status + status_completed_value: completed + param_values: + - name: user_id + references: + - dataset: saas_async_polling_config + field: user.id + direction: from diff --git a/tests/fixtures/saas/test_data/saas_async_polling_dataset.yml b/tests/fixtures/saas/test_data/saas_async_polling_dataset.yml new file mode 100644 index 00000000000..f83d391a92b --- /dev/null +++ b/tests/fixtures/saas/test_data/saas_async_polling_dataset.yml @@ -0,0 +1,15 @@ +dataset: + - fides_key: saas_async_polling_config + name: async_polling_example + description: A sample dataset for async polling + collections: + - name: user + fields: + - name: id + data_categories: [user.unique_id] + fidesops_meta: + primary_key: True + - name: system_id + data_categories: [system] + - name: state + data_categories: [user.contact.address.state] diff --git a/tests/ops/service/connectors/test_saas_connector.py b/tests/ops/service/connectors/test_saas_connector.py index 3741cccb36c..eaf70ab794a 100644 --- a/tests/ops/service/connectors/test_saas_connector.py +++ b/tests/ops/service/connectors/test_saas_connector.py @@ -12,6 +12,7 @@ from starlette.status import HTTP_200_OK, HTTP_204_NO_CONTENT, HTTP_404_NOT_FOUND from fides.api.common_exceptions import ( + AwaitingAsyncProcessing, AwaitingAsyncTask, ClientUnsuccessfulException, ConnectionException, @@ -22,7 +23,7 @@ from fides.api.graph.graph import DatasetGraph, Node from fides.api.graph.traversal import Traversal, TraversalNode from fides.api.models.consent_automation import ConsentAutomation -from fides.api.models.policy import Policy +from fides.api.models.policy import ActionType, Policy, Rule, RuleTarget from fides.api.models.privacy_notice import UserConsentPreference from fides.api.models.privacy_request import PrivacyRequest, RequestTask from fides.api.models.worker_task import ExecutionLogStatus @@ -1121,6 +1122,23 @@ def async_graph(self, saas_example_async_dataset_config, db, privacy_request): db, privacy_request, traversal_nodes, end_nodes, graph ) + @pytest.fixture(scope="function") + def async_graph_polling( + self, saas_async_polling_example_dataset_config, db, privacy_request + ): + # Build proper async graph with persisted request tasks for polling tests + async_graph = saas_async_polling_example_dataset_config.get_graph() + graph = DatasetGraph(async_graph) + traversal = Traversal(graph, {"email": "customer-1@example.com"}) + traversal_nodes = {} + end_nodes = traversal.traverse(traversal_nodes, collect_tasks_fn) + persist_new_access_request_tasks( + db, privacy_request, traversal, traversal_nodes, end_nodes, graph + ) + persist_initial_erasure_request_tasks( + db, privacy_request, traversal_nodes, end_nodes, graph + ) + @mock.patch("fides.api.service.connectors.saas_connector.AuthenticatedClient.send") def test_read_request_expects_async_results( self, @@ -1306,3 +1324,195 @@ def test_callback_succeeded_mask_data( ) == 5 ) + + @mock.patch( + "fides.api.service.connectors.saas_connector.SaaSConnector.create_client" + ) + def test_guard_access_request_with_access_policy( + self, + mock_create_client, + privacy_request, + saas_async_polling_example_connection_config, + async_graph_polling, + ): + """ + Test that guard_access_request allows async access requests to run + when the policy has access rules (access request scenario). + """ + connector: SaaSConnector = get_connector( + saas_async_polling_example_connection_config + ) + mock_create_client.return_value = mock.MagicMock() + + # Get access request task + request_task = privacy_request.access_tasks.filter( + RequestTask.collection_name == "user" + ).first() + execution_node = ExecutionNode(request_task) + + # Policy has access rules, so guard should return True and async_retrieve_data should be called + with pytest.raises(AwaitingAsyncProcessing): + connector.retrieve_data( + execution_node, + privacy_request.policy, + privacy_request, + request_task, + {}, + ) + + @mock.patch( + "fides.api.service.connectors.saas_connector.SaaSConnector.create_client" + ) + def test_guard_access_request_with_erasure_only_policy( + self, + mock_create_client, + db, + privacy_request, + saas_async_polling_example_connection_config, + async_graph_polling, + oauth_client, + ): + """ + Test that guard_access_request skips async access requests + when the policy has no access rules (erasure-only request scenario). + This test ensures coverage of the logger.info and return [] lines. + Uses polling async strategy to test the guard clause. + """ + # Create an erasure-only policy (no access rules) + erasure_only_policy = Policy.create( + db=db, + data={ + "name": "Erasure Only Policy", + "key": "erasure_only_policy_test", + "client_id": oauth_client.id, + }, + ) + + erasure_rule = Rule.create( + db=db, + data={ + "action_type": ActionType.erasure, + "name": "Erasure Rule", + "key": "erasure_rule_test", + "policy_id": erasure_only_policy.id, + "masking_strategy": { + "strategy": "null_rewrite", + "configuration": {}, + }, + "client_id": oauth_client.id, + }, + ) + + RuleTarget.create( + db=db, + data={ + "data_category": "user.name", + "rule_id": erasure_rule.id, + "client_id": oauth_client.id, + }, + ) + + connector: SaaSConnector = get_connector( + saas_async_polling_example_connection_config + ) + + # Get access request task + request_task = privacy_request.access_tasks.filter( + RequestTask.collection_name == "user" + ).first() + execution_node = ExecutionNode(request_task) + + # Verify guard_access_request returns False for erasure-only policy + assert connector.guard_access_request(erasure_only_policy) is False + + result = connector.retrieve_data( + execution_node, + erasure_only_policy, + privacy_request, + request_task, + {}, + ) + + # Should return empty list without calling async_retrieve_data + assert result == [] + + @mock.patch( + "fides.api.service.connectors.saas_connector.SaaSConnector.create_client" + ) + def test_callback_requests_ignore_guard_clause( + self, + mock_create_client, + db, + privacy_request, + saas_async_example_connection_config, + async_graph, + oauth_client, + ): + """ + Test that callback requests ignore the guard clause entirely. + Even if guard_access_request returns False (erasure-only policy), + callback requests should still proceed and raise AwaitingAsyncTask. + """ + # Create an erasure-only policy (no access rules) + erasure_only_policy = Policy.create( + db=db, + data={ + "name": "Erasure Only Policy Callback Test", + "key": "erasure_only_policy_callback_test", + "client_id": oauth_client.id, + }, + ) + + erasure_rule = Rule.create( + db=db, + data={ + "action_type": ActionType.erasure, + "name": "Erasure Rule Callback Test", + "key": "erasure_rule_callback_test", + "policy_id": erasure_only_policy.id, + "masking_strategy": { + "strategy": "null_rewrite", + "configuration": {}, + }, + "client_id": oauth_client.id, + }, + ) + + RuleTarget.create( + db=db, + data={ + "data_category": "user.name", + "rule_id": erasure_rule.id, + "client_id": oauth_client.id, + }, + ) + + connector: SaaSConnector = get_connector(saas_async_example_connection_config) + # Mock the client and its send method to allow async callback flow + mock_client = mock.MagicMock() + mock_send_response = mock.MagicMock() + mock_send_response.json.return_value = {"id": "123"} + mock_client.send.return_value = mock_send_response + mock_create_client.return_value = mock_client + + # Get access request task + request_task = privacy_request.access_tasks.filter( + RequestTask.collection_name == "user" + ).first() + execution_node = ExecutionNode(request_task) + + # Verify guard_access_request returns False for erasure-only policy + assert connector.guard_access_request(erasure_only_policy) is False + + # Even though guard_access_request returns False, callback requests + # should ignore the guard and always proceed with common requests. + connector.retrieve_data( + execution_node, + erasure_only_policy, + privacy_request, + request_task, + {}, + ) + + # Verify that the async callback flow was triggered (client.send was called) + assert mock_client.send.called