From 50ed0c38f81210fa7667228c878b9f5fb42d2f47 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Sun, 3 May 2026 07:03:29 -0700 Subject: [PATCH 1/2] Add support for overriding Druid config --- .../datajunction_server/api/cubes.py | 12 + .../models/cube_materialization.py | 21 + .../datajunction_server/utils.py | 38 +- datajunction-server/tests/api/cubes_test.py | 123 ++ datajunction-server/tests/utils_test.py | 1094 +++-------------- 5 files changed, 370 insertions(+), 918 deletions(-) diff --git a/datajunction-server/datajunction_server/api/cubes.py b/datajunction-server/datajunction_server/api/cubes.py index 824e43e3e..25beef3e7 100644 --- a/datajunction-server/datajunction_server/api/cubes.py +++ b/datajunction-server/datajunction_server/api/cubes.py @@ -74,6 +74,7 @@ from datajunction_server.naming import from_amenable_name from datajunction_server.service_clients import QueryServiceClient from datajunction_server.utils import ( + deep_merge, get_current_user, get_query_service_client, get_session, @@ -647,6 +648,15 @@ async def materialize_cube( }, } + # Apply user-provided overrides to the Druid spec + if data.druid_overrides: + druid_spec = deep_merge(druid_spec, data.druid_overrides) + _logger.info( + "Applied druid_overrides to cube=%s: %s", + name, + list(data.druid_overrides.keys()), + ) + # Convert columns to ColumnMetadata output_columns = [ ColumnMetadata( @@ -680,6 +690,7 @@ async def materialize_cube( strategy=data.strategy, schedule=data.schedule, lookback_window=effective_lookback, + druid_overrides=data.druid_overrides, ) # Call the query service to create the workflow. @@ -738,6 +749,7 @@ async def materialize_cube( metrics=metrics_list, timestamp_column=timestamp_column, timestamp_format=timestamp_format, + druid_overrides=data.druid_overrides, workflow_urls=workflow_urls, workflow_names=workflow_names, ) diff --git a/datajunction-server/datajunction_server/models/cube_materialization.py b/datajunction-server/datajunction_server/models/cube_materialization.py index 987f00ef4..185697d44 100644 --- a/datajunction-server/datajunction_server/models/cube_materialization.py +++ b/datajunction-server/datajunction_server/models/cube_materialization.py @@ -457,6 +457,15 @@ class CubeMaterializeRequest(BaseModel): default=True, description="Whether to run an initial backfill", ) + druid_overrides: Optional[Dict[str, Any]] = Field( + default=None, + description=( + "Override any part of the generated Druid ingestion spec. " + "This dict is deep-merged into the generated spec, allowing fine-grained " + "control over tuningConfig, indexSpec, granularitySpec, etc. " + "Example: {'tuningConfig': {'partitionsSpec': {'targetRowsPerSegment': 1000000}}}" + ), + ) class PreAggTableInfo(BaseModel): @@ -587,6 +596,12 @@ class CubeMaterializationV2Input(BaseModel): description="Lookback window for incremental. None for FULL strategy.", ) + # Druid spec overrides (passed through from request) + druid_overrides: Optional[Dict[str, Any]] = Field( + default=None, + description="User-provided overrides that were deep-merged into druid_spec.", + ) + class DruidCubeV3Config(BaseModel): """ @@ -665,6 +680,12 @@ class DruidCubeV3Config(BaseModel): description="Format of the timestamp column. None when timestamp_column is None.", ) + # User-provided Druid spec overrides (persisted for reference) + druid_overrides: Optional[Dict[str, Any]] = Field( + default=None, + description="User-provided overrides that were applied to the Druid spec.", + ) + # Workflow tracking workflow_urls: List[str] = Field( default_factory=list, diff --git a/datajunction-server/datajunction_server/utils.py b/datajunction-server/datajunction_server/utils.py index 4eb06e3df..cebaba70f 100644 --- a/datajunction-server/datajunction_server/utils.py +++ b/datajunction-server/datajunction_server/utils.py @@ -3,6 +3,7 @@ """ import asyncio +import copy from contextlib import asynccontextmanager import json import logging @@ -11,7 +12,7 @@ from functools import lru_cache from http import HTTPStatus -from typing import AsyncIterator, List, Optional +from typing import Any, AsyncIterator, Dict, List, Optional from dotenv import load_dotenv from fastapi import Depends @@ -590,4 +591,39 @@ async def sync_user_groups( return group_names +def deep_merge(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]: + """ + Deep merge two dictionaries. + + Values from `overrides` take precedence. Nested dictionaries are merged + recursively. Non-dict values in `overrides` replace values in `base`. + + Args: + base: The base dictionary + overrides: Dictionary with values to merge/override + + Returns: + A new dictionary with merged values (original dicts are not modified) + + Example: + >>> base = {"a": 1, "b": {"c": 2, "d": 3}} + >>> overrides = {"b": {"c": 10, "e": 5}} + >>> deep_merge(base, overrides) + {"a": 1, "b": {"c": 10, "d": 3, "e": 5}} + """ + result = copy.deepcopy(base) + + for key, override_value in overrides.items(): + if ( + key in result + and isinstance(result[key], dict) + and isinstance(override_value, dict) + ): + result[key] = deep_merge(result[key], override_value) + else: + result[key] = copy.deepcopy(override_value) + + return result + + SEPARATOR = "." diff --git a/datajunction-server/tests/api/cubes_test.py b/datajunction-server/tests/api/cubes_test.py index 610ada94c..a6bf2fb4d 100644 --- a/datajunction-server/tests/api/cubes_test.py +++ b/datajunction-server/tests/api/cubes_test.py @@ -4818,6 +4818,129 @@ async def test_materialize_cube_updates_existing_materialization( assert data["strategy"] == "incremental_time" assert data["schedule"] == "0 6 * * *" + @pytest.mark.asyncio + async def test_materialize_cube_with_druid_overrides( + self, + client_with_repairs_cube: AsyncClient, + mocker, + ): + """Test that druid_overrides are deep-merged into the generated Druid spec. + + This allows users to override tuningConfig, indexSpec, or any other Druid + settings without DJ having to expose every possible knob. + """ + cube_name = "default.test_materialize_druid_overrides_cube" + await make_a_test_cube( + client_with_repairs_cube, + cube_name, + with_materialization=False, + ) + + mock_columns = [ + V3ColumnMetadata( + name="date_id", + type="int", + semantic_name="default.hard_hat.hire_date", + semantic_type="dimension", + ), + V3ColumnMetadata( + name="num_repair_orders", + type="bigint", + semantic_name="default.num_repair_orders", + semantic_type="measure", + ), + ] + + mock_combined_result = _create_mock_combined_result( + mocker, + columns=mock_columns, + shared_dimensions=["default.hard_hat.hire_date"], + sql_string="SELECT date_id, COUNT(*) FROM preagg GROUP BY date_id", + ) + + mock_temporal_info = TemporalPartitionInfo( + column_name="date_id", + format="yyyyMMdd", + granularity="day", + ) + + mocker.patch( + "datajunction_server.api.cubes.build_combiner_sql_from_preaggs", + return_value=( + mock_combined_result, + [ + PreAggSourceInfo( + table_ref="catalog.schema.preagg_table1", + parent_name="default.repair_orders", + strategy=None, + ), + ], + mock_temporal_info, + ), + ) + + qs_client = client_with_repairs_cube.app.dependency_overrides[ + get_query_service_client + ]() + mocker.patch.object( + qs_client, + "materialize_cube_v2", + return_value=mocker.MagicMock( + urls=["http://workflow/cube-workflow"], + workflow_names=["cube-workflow"], + ), + ) + + # Request with druid_overrides + druid_overrides = { + "tuningConfig": { + "partitionsSpec": { + "targetRowsPerSegment": 1000000, # Override default 5000000 + "type": "single_dim", # Override default "hashed" + "partitionDimension": "date_id", + }, + "maxNumConcurrentSubTasks": 20, # Add new field + }, + "dataSchema": { + "granularitySpec": { + "queryGranularity": "HOUR", # Add rollup granularity + }, + }, + } + + response = await client_with_repairs_cube.post( + f"/cubes/{cube_name}/materialize", + json={ + "strategy": "incremental_time", + "schedule": "0 6 * * *", + "lookback_window": "1 DAY", + "druid_overrides": druid_overrides, + }, + ) + + assert response.status_code == 200, response.json() + data = response.json() + + # Verify overrides were applied + druid_spec = data["druid_spec"] + + # Check tuningConfig overrides + tuning_config = druid_spec["tuningConfig"] + assert tuning_config["partitionsSpec"]["targetRowsPerSegment"] == 1000000 + assert tuning_config["partitionsSpec"]["type"] == "single_dim" + assert tuning_config["partitionsSpec"]["partitionDimension"] == "date_id" + assert tuning_config["maxNumConcurrentSubTasks"] == 20 + # Verify existing defaults are preserved + assert tuning_config["useCombiner"] is True + assert tuning_config["type"] == "hadoop" + + # Check dataSchema override + granularity_spec = druid_spec["dataSchema"]["granularitySpec"] + assert granularity_spec["queryGranularity"] == "HOUR" + # Verify existing defaults are preserved + assert granularity_spec["type"] == "uniform" + assert granularity_spec["segmentGranularity"] == "DAY" + @pytest.mark.asyncio async def test_materialize_cube_returns_metric_combiners( self, diff --git a/datajunction-server/tests/utils_test.py b/datajunction-server/tests/utils_test.py index 517d43637..a174e3bf0 100644 --- a/datajunction-server/tests/utils_test.py +++ b/datajunction-server/tests/utils_test.py @@ -1,934 +1,194 @@ """ -Tests for ``datajunction_server.utils``. +Tests for utility functions. """ -from typing import cast -import logging -from unittest.mock import AsyncMock, MagicMock, patch -import json -import pytest -from starlette.requests import Request -from starlette.datastructures import Headers -from starlette.types import Scope +from datajunction_server.utils import deep_merge -import pytest -from pytest_mock import MockerFixture -from sqlalchemy.exc import OperationalError -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker -from starlette.background import BackgroundTasks -from testcontainers.postgres import PostgresContainer -from yarl import URL -from datajunction_server.config import DatabaseConfig, Settings -from datajunction_server.database.user import OAuthProvider, User -from datajunction_server.errors import ( - DJDatabaseException, - DJException, - DJUninitializedResourceException, -) -from datajunction_server.utils import ( - DatabaseSessionManager, - Version, - execute_with_retry, - get_issue_url, - get_query_service_client, - get_legacy_query_service_client, - get_session, - get_session_manager, - get_settings, - setup_logging, - is_graphql_query, - sync_user_groups, - _create_configured_query_client, -) -from datajunction_server.database.user import PrincipalKind +class TestDeepMerge: + """Tests for the deep_merge utility function.""" + def test_deep_merge_simple(self): + """Test basic deep merge of two dicts.""" + base = {"a": 1, "b": 2} + overrides = {"b": 3, "c": 4} + result = deep_merge(base, overrides) + assert result == {"a": 1, "b": 3, "c": 4} -def test_setup_logging() -> None: - """ - Test ``setup_logging``. - """ - setup_logging("debug") - assert logging.root.level == logging.DEBUG - - with pytest.raises(ValueError) as excinfo: - setup_logging("invalid") - assert str(excinfo.value) == "Invalid log level: invalid" + def test_deep_merge_nested(self): + """Test deep merge with nested dictionaries.""" + base = { + "a": 1, + "b": { + "c": 2, + "d": 3, + }, + } + overrides = { + "b": { + "c": 10, + "e": 5, + }, + } + result = deep_merge(base, overrides) + assert result == { + "a": 1, + "b": { + "c": 10, + "d": 3, + "e": 5, + }, + } + def test_deep_merge_deeply_nested(self): + """Test deep merge with multiple levels of nesting.""" + base = { + "level1": { + "level2": { + "level3": { + "a": 1, + "b": 2, + }, + }, + }, + } + overrides = { + "level1": { + "level2": { + "level3": { + "b": 20, + "c": 30, + }, + }, + }, + } + result = deep_merge(base, overrides) + assert result == { + "level1": { + "level2": { + "level3": { + "a": 1, + "b": 20, + "c": 30, + }, + }, + }, + } -@pytest.mark.asyncio -async def test_get_session(mocker: MockerFixture) -> None: - """ - Test ``get_session``. - """ - with patch( - "fastapi.BackgroundTasks", - mocker.MagicMock(autospec=BackgroundTasks), - ) as background_tasks: - background_tasks.side_effect = lambda x, y: None - session = await anext(get_session(request=mocker.MagicMock())) - assert isinstance(session, AsyncSession) + def test_deep_merge_override_replaces_non_dict(self): + """Test that override replaces non-dict value with dict.""" + base = {"a": 1} + overrides = {"a": {"nested": "value"}} + result = deep_merge(base, overrides) + assert result == {"a": {"nested": "value"}} + + def test_deep_merge_non_dict_replaces_dict(self): + """Test that non-dict override replaces dict value.""" + base = {"a": {"nested": "value"}} + overrides = {"a": 1} + result = deep_merge(base, overrides) + assert result == {"a": 1} + + def test_deep_merge_empty_base(self): + """Test deep merge with empty base dict.""" + base = {} + overrides = {"a": 1, "b": {"c": 2}} + result = deep_merge(base, overrides) + assert result == {"a": 1, "b": {"c": 2}} + + def test_deep_merge_empty_overrides(self): + """Test deep merge with empty overrides dict.""" + base = {"a": 1, "b": {"c": 2}} + overrides = {} + result = deep_merge(base, overrides) + assert result == {"a": 1, "b": {"c": 2}} + + def test_deep_merge_does_not_mutate_inputs(self): + """Test that deep_merge does not modify the input dicts.""" + base = {"a": 1, "b": {"c": 2}} + overrides = {"b": {"c": 10, "d": 3}} + + # Store original values + base_original = {"a": 1, "b": {"c": 2}} + overrides_original = {"b": {"c": 10, "d": 3}} + + result = deep_merge(base, overrides) + + # Verify inputs are unchanged + assert base == base_original + assert overrides == overrides_original + + # Verify result is independent + result["b"]["c"] = 999 + assert base["b"]["c"] == 2 + assert overrides["b"]["c"] == 10 + + def test_deep_merge_with_lists(self): + """Test that lists are replaced, not merged.""" + base = {"a": [1, 2, 3]} + overrides = {"a": [4, 5]} + result = deep_merge(base, overrides) + assert result == {"a": [4, 5]} + + def test_deep_merge_druid_spec_example(self): + """Test deep_merge with a realistic Druid spec override scenario.""" + base = { + "dataSchema": { + "dataSource": "my_cube", + "granularitySpec": { + "type": "uniform", + "segmentGranularity": "DAY", + "intervals": [], + }, + }, + "tuningConfig": { + "partitionsSpec": { + "targetPartitionSize": 5000000, + "type": "hashed", + }, + "useCombiner": True, + "type": "hadoop", + }, + } + overrides = { + "tuningConfig": { + "partitionsSpec": { + "targetRowsPerSegment": 1000000, + "type": "single_dim", + "partitionDimension": "date_id", + }, + "maxNumConcurrentSubTasks": 20, + }, + "dataSchema": { + "granularitySpec": { + "queryGranularity": "HOUR", + }, + }, + } -@pytest.mark.asyncio -@pytest.mark.parametrize( - "method,expected_session_attr", - [ - ("GET", "reader_session"), - ("POST", "writer_session"), - ], -) -async def test_get_session_uses_correct_session(method, expected_session_attr): - """ - Ensure get_session uses reader_session for GET and writer_session for others. - """ - get_session_manager.cache_clear() - mock_session_manager = get_session_manager() - request = MagicMock() - request.method = method - assert mock_session_manager.reader_engine is not None - assert mock_session_manager.writer_engine is not None - assert mock_session_manager.reader_sessionmaker is not None - assert mock_session_manager.writer_sessionmaker is not None + result = deep_merge(base, overrides) - agen = get_session(request) - session = await anext(agen) - if expected_session_attr == "reader_session": + # Check tuningConfig assert ( - str(session.bind.url) - == "postgresql+psycopg://readonly_user:***@postgres_metadata:5432/dj" + result["tuningConfig"]["partitionsSpec"]["targetRowsPerSegment"] == 1000000 ) - else: + assert result["tuningConfig"]["partitionsSpec"]["type"] == "single_dim" assert ( - str(session.bind.url) - == "postgresql+psycopg://dj:***@postgres_metadata:5432/dj" + result["tuningConfig"]["partitionsSpec"]["partitionDimension"] == "date_id" + ) + assert result["tuningConfig"]["maxNumConcurrentSubTasks"] == 20 + # Original values preserved + assert result["tuningConfig"]["useCombiner"] is True + assert result["tuningConfig"]["type"] == "hadoop" + # Note: targetPartitionSize was in base but not in overrides' partitionsSpec, + # and since partitionsSpec is a dict, it gets deep-merged + assert ( + result["tuningConfig"]["partitionsSpec"]["targetPartitionSize"] == 5000000 ) - await agen.aclose() - - -def test_get_settings(mocker: MockerFixture) -> None: - """ - Test ``get_settings``. - """ - mocker.patch("datajunction_server.utils.load_dotenv") - Settings = mocker.patch( - "datajunction_server.utils.Settings", - ) - - # should be already cached, since it's called by the Celery app - get_settings() - Settings.assert_not_called() - - -def test_get_issue_url() -> None: - """ - Test ``get_issue_url``. - """ - assert get_issue_url() == URL( - "https://github.com/DataJunction/dj/issues/new", - ) - assert get_issue_url( - baseurl=URL("https://example.org/"), - title="Title with spaces", - body="This is the body", - labels=["help", "troubleshoot"], - ) == URL( - "https://example.org/?title=Title+with+spaces&" - "body=This+is+the+body&labels=help,troubleshoot", - ) - - -def test_database_session_manager( - mocker: MockerFixture, - settings: Settings, - postgres_container: PostgresContainer, -) -> None: - """ - Test DatabaseSessionManager. - """ - connection_url = postgres_container.get_connection_url() - settings.writer_db = DatabaseConfig(uri=connection_url) - mocker.patch("datajunction_server.utils.get_settings", return_value=settings) - - session_manager = DatabaseSessionManager() - with pytest.raises(DJUninitializedResourceException): - session_manager.reader_engine - with pytest.raises(DJUninitializedResourceException): - session_manager.writer_engine - with pytest.raises(DJUninitializedResourceException): - session_manager.reader_sessionmaker - with pytest.raises(DJUninitializedResourceException): - session_manager.writer_sessionmaker - - session_manager.init_db() - - writer_engine = session_manager.writer_engine - writer_engine.pool.size() == settings.writer_db.pool_size # type: ignore - writer_engine.pool.timeout() == settings.writer_db.pool_timeout # type: ignore - writer_engine.pool.overflow() == settings.writer_db.max_overflow # type: ignore - - reader_engine = session_manager.reader_engine - reader_engine.pool.size() == settings.reader_db.pool_size # type: ignore - reader_engine.pool.timeout() == settings.reader_db.pool_timeout # type: ignore - reader_engine.pool.overflow() == settings.reader_db.max_overflow # type: ignore - - assert session_manager.reader_engine != session_manager.writer_engine - assert isinstance(session_manager.reader_sessionmaker, async_sessionmaker) - assert isinstance(session_manager.writer_sessionmaker, async_sessionmaker) - assert session_manager.sessionmaker == session_manager.writer_sessionmaker - - -def test_get_query_service_client(mocker: MockerFixture, settings: Settings) -> None: - """ - Test ``get_query_service_client``. - """ - settings.query_service = "http://query_service:8001" - query_service_client = get_query_service_client(settings=settings) - assert query_service_client.uri == "http://query_service:8001" # type: ignore - - -def test_version_parse() -> None: - """ - Test version parsing - """ - ver = Version.parse("v1.0") - assert ver.major == 1 - assert ver.minor == 0 - assert str(ver.next_major_version()) == "v2.0" - assert str(ver.next_minor_version()) == "v1.1" - assert str(ver.next_minor_version().next_minor_version()) == "v1.2" - - ver = Version.parse("v21.12") - assert ver.major == 21 - assert ver.minor == 12 - assert str(ver.next_major_version()) == "v22.0" - assert str(ver.next_minor_version()) == "v21.13" - assert str(ver.next_minor_version().next_minor_version()) == "v21.14" - assert str(ver.next_major_version().next_minor_version()) == "v22.1" - - with pytest.raises(DJException) as excinfo: - Version.parse("0") - assert str(excinfo.value) == "Unparseable version 0!" - - -@pytest.mark.asyncio -async def test_execute_with_retry_success_after_flaky_connection(): - """ - Test that execute_with_retry succeeds after a flaky connection. - """ - session = AsyncMock(spec=AsyncSession) - statement = MagicMock() - - # Simulate flaky DB: first 2 calls raise OperationalError, 3rd returns success - mock_result = MagicMock() - mock_result.unique.return_value.scalars.return_value.all.return_value = [ - "node1", - "node2", - ] - session.execute.side_effect = [ - OperationalError("flaky", None, None), # type: ignore - OperationalError("still flaky", None, None), # type: ignore - mock_result, - ] - - result = await execute_with_retry(session, statement, retries=5, base_delay=0.01) - values = result.unique().scalars().all() - assert values == ["node1", "node2"] - assert session.execute.call_count == 3 - - -@pytest.mark.asyncio -async def test_execute_with_retry_exhausts_retries(): - """ - Test that execute_with_retry exhausts retries and fails. - """ - session = AsyncMock(spec=AsyncSession) - statement = MagicMock() - - # Always fail - session.execute.side_effect = OperationalError("permanent fail", None, None) # type: ignore - - with pytest.raises(DJDatabaseException): - await execute_with_retry(session, statement, retries=3, base_delay=0.01) - - assert session.execute.call_count == 4 # initial try + 3 retries - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "path, body, expected", - [ - # Not /graphql - ("/not-graphql", json.dumps({"query": "query { users }"}), False), - # /graphql with query - ("/graphql", json.dumps({"query": "query { users }"}), True), - # /graphql with mutation - ( - "/graphql", - json.dumps({"query": 'mutation { addUser(name: "Hi") { id } }'}), - False, - ), - # /graphql with invalid JSON - ("/graphql", "not json", False), - # /graphql with no query key - ("/graphql", json.dumps({"foo": "bar"}), False), - # /graphql with empty body - ("/graphql", "", False), - ], -) -async def test_is_graphql_query(path, body, expected): - """ - Test the `is_graphql_query` utility function. - This function checks if the request is a GraphQL query based on the path and body. - """ - # Build a fake ASGI scope - scope: Scope = { - "type": "http", - "method": "POST", - "path": path, - "headers": Headers({"content-type": "application/json"}).raw, - } - - # Create a receive function that yields the body - async def receive() -> dict: - return { - "type": "http.request", - "body": body.encode(), - "more_body": False, - } - - request = Request(scope, receive) - result = await is_graphql_query(request) - assert result is expected - - -def test_get_query_service_client_with_configured_client( - mocker: MockerFixture, - settings: Settings, -) -> None: - """ - Test get_query_service_client with configured client (non-HTTP). - """ - from datajunction_server.config import QueryClientConfig - - # Configure Snowflake client - settings.query_client = QueryClientConfig( - type="snowflake", - connection={"account": "test_account", "user": "test_user"}, - ) - settings.query_service = None - - # Mock the SnowflakeClient import to avoid dependency issues - mock_snowflake_client = mocker.MagicMock() - mocker.patch( - "datajunction_server.query_clients.SnowflakeClient", - mock_snowflake_client, - ) - - client = get_query_service_client(settings=settings) - assert client is not None - mock_snowflake_client.assert_called_once_with( - account="test_account", - user="test_user", - ) - - -def test_get_query_service_client_returns_none( - mocker: MockerFixture, - settings: Settings, -) -> None: - """ - Test get_query_service_client returns None when no configuration is provided. - """ - settings.query_service = None - from datajunction_server.config import QueryClientConfig - - settings.query_client = QueryClientConfig(type="http", connection={}) - - client = get_query_service_client(settings=settings) - assert client is None - - -def test_create_configured_query_client_http_success(mocker: MockerFixture) -> None: - """ - Test _create_configured_query_client creates HTTP client successfully. - """ - from datajunction_server.config import QueryClientConfig - from datajunction_server.query_clients import HttpQueryServiceClient - - config = QueryClientConfig(type="http", connection={"uri": "http://test:8001"}) - - client = _create_configured_query_client(config) - assert isinstance(client, HttpQueryServiceClient) - assert client.uri == "http://test:8001" - - -def test_create_configured_query_client_http_missing_uri(mocker: MockerFixture) -> None: - """ - Test _create_configured_query_client raises error for HTTP client without URI. - """ - from datajunction_server.config import QueryClientConfig - - config = QueryClientConfig(type="http", connection={}) - - with pytest.raises(ValueError) as exc_info: - _create_configured_query_client(config) - assert "HTTP client requires 'uri' in connection parameters" in str(exc_info.value) - - -def test_create_configured_query_client_snowflake_missing_params( - mocker: MockerFixture, -) -> None: - """ - Test _create_configured_query_client raises error for Snowflake client without required params. - """ - from datajunction_server.config import QueryClientConfig - - # Missing 'user' parameter - config = QueryClientConfig(type="snowflake", connection={"account": "test_account"}) - - with pytest.raises(ValueError) as exc_info: - _create_configured_query_client(config) - assert "Snowflake client requires 'user' in connection parameters" in str( - exc_info.value, - ) - - # Missing 'account' parameter - config = QueryClientConfig(type="snowflake", connection={"user": "test_user"}) - - with pytest.raises(ValueError) as exc_info: - _create_configured_query_client(config) - assert "Snowflake client requires 'account' in connection parameters" in str( - exc_info.value, - ) - - -def test_create_configured_query_client_snowflake_import_error( - mocker: MockerFixture, -) -> None: - """ - Test _create_configured_query_client handles ImportError for Snowflake client. - """ - from datajunction_server.config import QueryClientConfig - - config = QueryClientConfig( - type="snowflake", - connection={"account": "test_account", "user": "test_user"}, - ) - - # Mock the import to fail - mocker.patch( - "datajunction_server.query_clients.SnowflakeClient", - side_effect=ImportError("No module named 'snowflake'"), - ) - - with pytest.raises(ValueError) as exc_info: - _create_configured_query_client(config) - assert "Snowflake client dependencies not installed" in str(exc_info.value) - assert "pip install 'datajunction-server[snowflake]'" in str(exc_info.value) - - -def test_create_configured_query_client_bigquery_import_error( - mocker: MockerFixture, -) -> None: - """ - Test _create_configured_query_client handles ImportError for BigQuery client. - """ - from datajunction_server.config import QueryClientConfig - - config = QueryClientConfig( - type="bigquery", - connection={"project": "my-project"}, - ) - - mocker.patch( - "datajunction_server.query_clients.BigQueryClient", - side_effect=ImportError("No module named 'google.cloud.bigquery'"), - ) - - with pytest.raises(ValueError) as exc_info: - _create_configured_query_client(config) - assert "BigQuery client dependencies not installed" in str(exc_info.value) - assert "pip install 'datajunction-server[bigquery]'" in str(exc_info.value) - - -def test_create_configured_query_client_unsupported_type(mocker: MockerFixture) -> None: - """ - Test _create_configured_query_client raises error for unsupported client type. - """ - from datajunction_server.config import QueryClientConfig - - config = QueryClientConfig(type="unsupported", connection={}) - - with pytest.raises(ValueError) as exc_info: - _create_configured_query_client(config) - assert "Unsupported query client type: unsupported" in str(exc_info.value) - - -def test_get_legacy_query_service_client( - mocker: MockerFixture, - settings: Settings, -) -> None: - """ - Test get_legacy_query_service_client returns QueryServiceClient. - """ - settings.query_service = "http://query_service:8001" - - mock_query_service_client_cls = mocker.MagicMock() - mock_query_service_client_instance = mocker.MagicMock() - mock_query_service_client_cls.return_value = mock_query_service_client_instance - mocker.patch( - "datajunction_server.service_clients.QueryServiceClient", - mock_query_service_client_cls, - ) - - client = get_legacy_query_service_client(settings=settings) - mock_query_service_client_cls.assert_called_once_with("http://query_service:8001") - assert client == mock_query_service_client_instance - - -def test_http_query_service_client_wrapper(mocker: MockerFixture) -> None: - """ - Test HttpQueryServiceClient properly wraps QueryServiceClient. - """ - from datajunction_server.query_clients import HttpQueryServiceClient - from datajunction_server.models.query import QueryCreate - from datajunction_server.models.node_type import NodeType - - # Mock the underlying QueryServiceClient - mock_client = mocker.MagicMock() - mocker.patch( - "datajunction_server.query_clients.http.QueryServiceClient", - return_value=mock_client, - ) - - # Create HTTP client - client = HttpQueryServiceClient("http://test:8001", retries=3) - assert client.uri == "http://test:8001" - - # Test get_columns_for_table - mock_client.get_columns_for_table.return_value = [] - get_columns_for_table_result = client.get_columns_for_table("cat", "sch", "tbl") - assert get_columns_for_table_result == [] - mock_client.get_columns_for_table.assert_called_once() - - # Test create_view - mock_client.create_view.return_value = "view_created" - query = QueryCreate( - submitted_query="SELECT 1", - catalog_name="test", - engine_name="test", - engine_version="v1", - ) - create_view_result = client.create_view("test_view", query) - assert create_view_result == "view_created" - - # Test submit_query - mock_result = mocker.MagicMock() - mock_client.submit_query.return_value = mock_result - submit_query_result = client.submit_query(query) - assert submit_query_result == mock_result - - # Test get_query - mock_client.get_query.return_value = mock_result - get_query_result = client.get_query("query_id_123") - assert get_query_result == mock_result - - # Test materialize - mock_mat_result = mocker.MagicMock() - mock_client.materialize.return_value = mock_mat_result - materialize_result = client.materialize(mocker.MagicMock()) - assert materialize_result == mock_mat_result - - # Test materialize_cube - mock_client.materialize_cube.return_value = mock_mat_result - materialize_cube_result = client.materialize_cube(mocker.MagicMock()) - assert materialize_cube_result == mock_mat_result - - # Test deactivate_materialization - mock_client.deactivate_materialization.return_value = mock_mat_result - deactivate_materialization_result = client.deactivate_materialization("node", "mat") - assert deactivate_materialization_result == mock_mat_result - - # Test get_materialization_info - mock_client.get_materialization_info.return_value = mock_mat_result - get_materialization_info_result = client.get_materialization_info( - "node", - "v1", - NodeType.SOURCE, - "mat", - ) - assert get_materialization_info_result == mock_mat_result - - # Test run_backfill - mock_client.run_backfill.return_value = mock_mat_result - run_backfill_result = client.run_backfill("node", "v1", NodeType.SOURCE, "mat", []) - assert run_backfill_result == mock_mat_result - - # Test materialize_preagg - mock_preagg_result = {"workflow_url": "http://test/workflow", "status": "SCHEDULED"} - mock_client.materialize_preagg.return_value = mock_preagg_result - mock_preagg_input = mocker.MagicMock() - materialize_preagg_result = client.materialize_preagg(mock_preagg_input) - assert materialize_preagg_result == mock_preagg_result - mock_client.materialize_preagg.assert_called_once_with( - materialization_input=mock_preagg_input, - request_headers=None, - ) - - # Test deactivate_preagg_workflow - mock_deactivate_result = {"status": "DEACTIVATED"} - mock_client.deactivate_preagg_workflow.return_value = mock_deactivate_result - deactivate_preagg_result = client.deactivate_preagg_workflow( - output_table="test_preagg_table", - ) - assert deactivate_preagg_result == mock_deactivate_result - mock_client.deactivate_preagg_workflow.assert_called_once_with( - output_table="test_preagg_table", - request_headers=None, - ) - - # Test run_preagg_backfill - mock_backfill_result = {"job_url": "http://test/backfill/123"} - mock_client.run_preagg_backfill.return_value = mock_backfill_result - mock_backfill_input = mocker.MagicMock() - run_preagg_backfill_result = client.run_preagg_backfill(mock_backfill_input) - assert run_preagg_backfill_result == mock_backfill_result - mock_client.run_preagg_backfill.assert_called_once_with( - backfill_input=mock_backfill_input, - request_headers=None, - ) - - -def test_snowflake_client_initialization_with_mock(mocker: MockerFixture) -> None: - """ - Test SnowflakeClient initialization when snowflake package is available. - """ - # Mock snowflake being available - mocker.patch( - "datajunction_server.query_clients.snowflake.SNOWFLAKE_AVAILABLE", - True, - ) - mocker.patch( - "datajunction_server.query_clients.snowflake.SnowflakeDatabaseError", - Exception, - ) - - # Mock the snowflake connector - mock_snowflake = mocker.MagicMock() - mock_conn = mocker.MagicMock() - mock_cursor = mocker.MagicMock() - mock_cursor.fetchall.return_value = [ - { - "COLUMN_NAME": "id", - "DATA_TYPE": "NUMBER", - "IS_NULLABLE": "NO", - "ORDINAL_POSITION": 1, - }, - ] - mock_cursor.fetchone.return_value = (1,) - mock_conn.cursor.return_value.__enter__.return_value = mock_cursor - mock_snowflake.connector.connect.return_value = mock_conn - mock_snowflake.connector.DatabaseError = Exception - - mocker.patch( - "datajunction_server.query_clients.snowflake.snowflake", - mock_snowflake, - ) - - from datajunction_server.query_clients.snowflake import SnowflakeClient - - # Create client with password auth - client = SnowflakeClient( - account="test_account", - user="test_user", - password="test_pass", - warehouse="TEST_WH", - database="TEST_DB", - ) - - assert client.connection_params["account"] == "test_account" - assert client.connection_params["user"] == "test_user" - assert client.connection_params["password"] == "test_pass" - assert client.connection_params["warehouse"] == "TEST_WH" - assert client.connection_params["database"] == "TEST_DB" - - # Test get_columns_for_table - result = client.get_columns_for_table("catalog", "schema", "table") - assert len(result) == 1 - assert result[0].name == "id" - - # Test connection test - assert client.test_connection() is True - - # Test with private key and role - mock_open_func = mocker.mock_open(read_data=b"private_key_data") - mocker.patch("builtins.open", mock_open_func) - - client2 = SnowflakeClient( - account="test_account", - user="test_user", - private_key_path="/path/to/key.pem", - warehouse="TEST_WH", - role="TEST_ROLE", - ) - assert "private_key" in client2.connection_params - assert client2.connection_params["private_key"] == b"private_key_data" - assert client2.connection_params["role"] == "TEST_ROLE" - - # Test _get_database_from_engine with engine URI - mock_engine = mocker.MagicMock() - mock_engine.uri = "snowflake://user:pass@account/DATABASE_FROM_URI?warehouse=WH" - db_name = client._get_database_from_engine(mock_engine, "fallback") - assert db_name == "DATABASE_FROM_URI" - - # Test with database in query params - mock_engine.uri = ( - "snowflake://user:pass@account/?database=DB_FROM_QUERY&warehouse=WH" - ) - db_name = client._get_database_from_engine(mock_engine, "fallback") - assert db_name == "DB_FROM_QUERY" - - # Test with no database in URI (falls back to connection params) - mock_engine.uri = "snowflake://user:pass@account/?warehouse=WH" - db_name = client._get_database_from_engine(mock_engine, "fallback") - assert db_name == "TEST_DB" # From client.connection_params - - # Test with empty path (no database, no query params - falls back) - mock_engine.uri = "snowflake://user:pass@account" - db_name = client._get_database_from_engine(mock_engine, "fallback") - assert db_name == "TEST_DB" # From client.connection_params - - # Test with empty database name in path (just slash, no query) - mock_engine.uri = "snowflake://user:pass@account/" - db_name = client._get_database_from_engine(mock_engine, "fallback") - assert db_name == "TEST_DB" # From client.connection_params - - # Test with path that becomes empty after processing (double slash case) - mock_engine.uri = "snowflake://user:pass@account//" - db_name = client._get_database_from_engine(mock_engine, "fallback") - assert db_name == "TEST_DB" # From client.connection_params - - # Test error handling in get_columns_for_table - from datajunction_server.errors import DJDoesNotExistException - - mock_cursor.fetchall.return_value = [] - with pytest.raises(DJDoesNotExistException): - client.get_columns_for_table("catalog", "schema", "nonexistent") - - # Test connection failure - mock_snowflake.connector.connect.side_effect = Exception("Connection failed") - assert client.test_connection() is False - - # Reset side effect for next tests - mock_snowflake.connector.connect.side_effect = None - mock_snowflake.connector.connect.return_value = mock_conn - - # Test database error handling - from datajunction_server.errors import DJQueryServiceClientException - - mock_cursor.execute.side_effect = mock_snowflake.connector.DatabaseError( - "Table does not exist", - ) - with pytest.raises(DJDoesNotExistException): - client.get_columns_for_table("catalog", "schema", "missing_table") - - # Test other database error - mock_cursor.execute.side_effect = mock_snowflake.connector.DatabaseError( - "Connection timeout", - ) - with pytest.raises(DJQueryServiceClientException): - client.get_columns_for_table("catalog", "schema", "table") - - # Test type mapping with decimal parameters - assert client._map_snowflake_type_to_dj("NUMBER(10,2)") - assert client._map_snowflake_type_to_dj("DECIMAL(20,5)") - assert client._map_snowflake_type_to_dj("NUMERIC(15)") # No scale parameter - assert client._map_snowflake_type_to_dj("NUMBER(invalid)") # Invalid params - - -@pytest.mark.asyncio -async def test_sync_user_groups_no_groups(session: AsyncSession, mocker: MockerFixture): - """ - Test sync_user_groups when user has no groups. - """ - # Mock the group membership service to return no groups - mock_service = mocker.MagicMock() - mock_service.get_user_groups = mocker.AsyncMock(return_value=[]) - mocker.patch( - "datajunction_server.utils.get_group_membership_service", - return_value=mock_service, - ) - - result = await sync_user_groups(session, "testuser") - - assert result == [] - mock_service.get_user_groups.assert_called_once_with(session, "testuser") - - -@pytest.mark.asyncio -async def test_sync_user_groups_creates_new_groups( - session: AsyncSession, - mocker: MockerFixture, -): - """ - Test sync_user_groups creates group principals that don't exist. - """ - # Mock the group membership service to return groups - mock_service = mocker.MagicMock() - mock_service.get_user_groups = mocker.AsyncMock( - return_value=["eng-team", "data-team"], - ) - mocker.patch( - "datajunction_server.utils.get_group_membership_service", - return_value=mock_service, - ) - - result = await sync_user_groups(session, "testuser") - - assert result == ["eng-team", "data-team"] - - # Verify groups were created - eng_group = await User.get_by_username(session, "eng-team", options=[]) - data_group = await User.get_by_username(session, "data-team", options=[]) - - assert eng_group is not None - assert eng_group.kind == PrincipalKind.GROUP - assert eng_group.name == "eng-team" - - assert data_group is not None - assert data_group.kind == PrincipalKind.GROUP - assert data_group.name == "data-team" - - -@pytest.mark.asyncio -async def test_sync_user_groups_skips_existing_groups( - session: AsyncSession, - mocker: MockerFixture, -): - """ - Test sync_user_groups skips groups that already exist. - """ - # Create an existing group - existing_group = User( - username="existing-group", - password=None, - email="existing@group.com", - name="Existing Group", - oauth_provider=OAuthProvider.BASIC, - is_admin=False, - kind=PrincipalKind.GROUP, - ) - session.add(existing_group) - await session.commit() - original_id = existing_group.id - - # Mock the group membership service to return the existing group - mock_service = mocker.MagicMock() - mock_service.get_user_groups = mocker.AsyncMock(return_value=["existing-group"]) - mocker.patch( - "datajunction_server.utils.get_group_membership_service", - return_value=mock_service, - ) - - result = await sync_user_groups(session, "testuser") - - assert result == ["existing-group"] - - # Verify the group still exists with same ID (wasn't recreated) - group = cast( - User, - await User.get_by_username( - session, - "existing-group", - options=[], - ), - ) - assert group.id == original_id - assert group.email == "existing@group.com" # Original email preserved - - -@pytest.mark.asyncio -async def test_sync_user_groups_warns_on_non_group_principal( - session: AsyncSession, - mocker: MockerFixture, - caplog, -): - """ - Test sync_user_groups logs warning when a principal exists but is not a group. - """ - # Create an existing user (not a group) with a name that matches a group - existing_user = User( - username="alice", - password=None, - email="alice@example.com", - name="Alice", - oauth_provider=OAuthProvider.BASIC, - is_admin=False, - kind=PrincipalKind.USER, - ) - session.add(existing_user) - await session.commit() - - # Mock the group membership service to return "alice" as a group - mock_service = mocker.MagicMock() - mock_service.get_user_groups = mocker.AsyncMock(return_value=["alice"]) - mocker.patch( - "datajunction_server.utils.get_group_membership_service", - return_value=mock_service, - ) - - with caplog.at_level(logging.WARNING): - result = await sync_user_groups(session, "testuser") - - assert result == ["alice"] - assert "Principal alice exists but is not a group (kind=user), skipping" in ( - caplog.text - ) - - -@pytest.mark.asyncio -async def test_sync_user_groups_mixed_existing_and_new( - session: AsyncSession, - mocker: MockerFixture, -): - """ - Test sync_user_groups handles mix of existing and new groups. - """ - # Create one existing group - existing_group = User( - username="existing-team", - password=None, - email=None, - name="Existing Team", - oauth_provider=OAuthProvider.BASIC, - is_admin=False, - kind=PrincipalKind.GROUP, - ) - session.add(existing_group) - await session.commit() - - # Mock the group membership service to return both existing and new groups - mock_service = mocker.MagicMock() - mock_service.get_user_groups = mocker.AsyncMock( - return_value=["existing-team", "new-team"], - ) - mocker.patch( - "datajunction_server.utils.get_group_membership_service", - return_value=mock_service, - ) - - result = await sync_user_groups(session, "testuser") - - assert result == ["existing-team", "new-team"] - - # Verify both groups exist - existing = await User.get_by_username(session, "existing-team", options=[]) - new = await User.get_by_username(session, "new-team", options=[]) - - assert existing is not None - assert existing.kind == PrincipalKind.GROUP - assert new is not None - assert new.kind == PrincipalKind.GROUP - assert new.name == "new-team" + # Check dataSchema + assert result["dataSchema"]["dataSource"] == "my_cube" + assert result["dataSchema"]["granularitySpec"]["queryGranularity"] == "HOUR" + assert result["dataSchema"]["granularitySpec"]["type"] == "uniform" + assert result["dataSchema"]["granularitySpec"]["segmentGranularity"] == "DAY" From 9c15b94d6b46a68b8a3a6a6d925e778107d88e58 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Mon, 4 May 2026 08:59:42 -0700 Subject: [PATCH 2/2] Fix tests --- datajunction-server/datajunction_server/api/cubes.py | 1 - .../datajunction_server/models/cube_materialization.py | 6 ------ datajunction-server/datajunction_server/utils.py | 2 +- 3 files changed, 1 insertion(+), 8 deletions(-) diff --git a/datajunction-server/datajunction_server/api/cubes.py b/datajunction-server/datajunction_server/api/cubes.py index 25beef3e7..36258d993 100644 --- a/datajunction-server/datajunction_server/api/cubes.py +++ b/datajunction-server/datajunction_server/api/cubes.py @@ -690,7 +690,6 @@ async def materialize_cube( strategy=data.strategy, schedule=data.schedule, lookback_window=effective_lookback, - druid_overrides=data.druid_overrides, ) # Call the query service to create the workflow. diff --git a/datajunction-server/datajunction_server/models/cube_materialization.py b/datajunction-server/datajunction_server/models/cube_materialization.py index 185697d44..6573beec5 100644 --- a/datajunction-server/datajunction_server/models/cube_materialization.py +++ b/datajunction-server/datajunction_server/models/cube_materialization.py @@ -596,12 +596,6 @@ class CubeMaterializationV2Input(BaseModel): description="Lookback window for incremental. None for FULL strategy.", ) - # Druid spec overrides (passed through from request) - druid_overrides: Optional[Dict[str, Any]] = Field( - default=None, - description="User-provided overrides that were deep-merged into druid_spec.", - ) - class DruidCubeV3Config(BaseModel): """ diff --git a/datajunction-server/datajunction_server/utils.py b/datajunction-server/datajunction_server/utils.py index cebaba70f..988af10e1 100644 --- a/datajunction-server/datajunction_server/utils.py +++ b/datajunction-server/datajunction_server/utils.py @@ -609,7 +609,7 @@ def deep_merge(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any >>> base = {"a": 1, "b": {"c": 2, "d": 3}} >>> overrides = {"b": {"c": 10, "e": 5}} >>> deep_merge(base, overrides) - {"a": 1, "b": {"c": 10, "d": 3, "e": 5}} + {'a': 1, 'b': {'c': 10, 'd': 3, 'e': 5}} """ result = copy.deepcopy(base)