Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions datajunction-server/datajunction_server/api/cubes.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
from datajunction_server.naming import from_amenable_name
from datajunction_server.service_clients import QueryServiceClient
from datajunction_server.utils import (
deep_merge,
get_current_user,
get_query_service_client,
get_session,
Expand Down Expand Up @@ -647,6 +648,15 @@ async def materialize_cube(
},
}

# Apply user-provided overrides to the Druid spec
if data.druid_overrides:
druid_spec = deep_merge(druid_spec, data.druid_overrides)
_logger.info(
"Applied druid_overrides to cube=%s: %s",
name,
list(data.druid_overrides.keys()),
)

# Convert columns to ColumnMetadata
output_columns = [
ColumnMetadata(
Expand Down Expand Up @@ -738,6 +748,7 @@ async def materialize_cube(
metrics=metrics_list,
timestamp_column=timestamp_column,
timestamp_format=timestamp_format,
druid_overrides=data.druid_overrides,
workflow_urls=workflow_urls,
workflow_names=workflow_names,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,15 @@ class CubeMaterializeRequest(BaseModel):
default=True,
description="Whether to run an initial backfill",
)
druid_overrides: Optional[Dict[str, Any]] = Field(
default=None,
description=(
"Override any part of the generated Druid ingestion spec. "
"This dict is deep-merged into the generated spec, allowing fine-grained "
"control over tuningConfig, indexSpec, granularitySpec, etc. "
"Example: {'tuningConfig': {'partitionsSpec': {'targetRowsPerSegment': 1000000}}}"
),
)


class PreAggTableInfo(BaseModel):
Expand Down Expand Up @@ -665,6 +674,12 @@ class DruidCubeV3Config(BaseModel):
description="Format of the timestamp column. None when timestamp_column is None.",
)

# User-provided Druid spec overrides (persisted for reference)
druid_overrides: Optional[Dict[str, Any]] = Field(
default=None,
description="User-provided overrides that were applied to the Druid spec.",
)

# Workflow tracking
workflow_urls: List[str] = Field(
default_factory=list,
Expand Down
38 changes: 37 additions & 1 deletion datajunction-server/datajunction_server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import asyncio
import copy
from contextlib import asynccontextmanager
import json
import logging
Expand All @@ -11,7 +12,7 @@
from functools import lru_cache
from http import HTTPStatus

from typing import AsyncIterator, List, Optional
from typing import Any, AsyncIterator, Dict, List, Optional

from dotenv import load_dotenv
from fastapi import Depends
Expand Down Expand Up @@ -590,4 +591,39 @@ async def sync_user_groups(
return group_names


def deep_merge(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
"""
Deep merge two dictionaries.

Values from `overrides` take precedence. Nested dictionaries are merged
recursively. Non-dict values in `overrides` replace values in `base`.

Args:
base: The base dictionary
overrides: Dictionary with values to merge/override

Returns:
A new dictionary with merged values (original dicts are not modified)

Example:
>>> base = {"a": 1, "b": {"c": 2, "d": 3}}
>>> overrides = {"b": {"c": 10, "e": 5}}
>>> deep_merge(base, overrides)
{'a': 1, 'b': {'c': 10, 'd': 3, 'e': 5}}
"""
result = copy.deepcopy(base)

for key, override_value in overrides.items():
if (
key in result
and isinstance(result[key], dict)
and isinstance(override_value, dict)
):
result[key] = deep_merge(result[key], override_value)
else:
result[key] = copy.deepcopy(override_value)

return result


SEPARATOR = "."
123 changes: 123 additions & 0 deletions datajunction-server/tests/api/cubes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4818,6 +4818,129 @@ async def test_materialize_cube_updates_existing_materialization(
assert data["strategy"] == "incremental_time"
assert data["schedule"] == "0 6 * * *"

@pytest.mark.asyncio
async def test_materialize_cube_with_druid_overrides(
self,
client_with_repairs_cube: AsyncClient,
mocker,
):
"""Test that druid_overrides are deep-merged into the generated Druid spec.

This allows users to override tuningConfig, indexSpec, or any other Druid
settings without DJ having to expose every possible knob.
"""
cube_name = "default.test_materialize_druid_overrides_cube"
await make_a_test_cube(
client_with_repairs_cube,
cube_name,
with_materialization=False,
)

mock_columns = [
V3ColumnMetadata(
name="date_id",
type="int",
semantic_name="default.hard_hat.hire_date",
semantic_type="dimension",
),
V3ColumnMetadata(
name="num_repair_orders",
type="bigint",
semantic_name="default.num_repair_orders",
semantic_type="measure",
),
]

mock_combined_result = _create_mock_combined_result(
mocker,
columns=mock_columns,
shared_dimensions=["default.hard_hat.hire_date"],
sql_string="SELECT date_id, COUNT(*) FROM preagg GROUP BY date_id",
)

mock_temporal_info = TemporalPartitionInfo(
column_name="date_id",
format="yyyyMMdd",
granularity="day",
)

mocker.patch(
"datajunction_server.api.cubes.build_combiner_sql_from_preaggs",
return_value=(
mock_combined_result,
[
PreAggSourceInfo(
table_ref="catalog.schema.preagg_table1",
parent_name="default.repair_orders",
strategy=None,
),
],
mock_temporal_info,
),
)

qs_client = client_with_repairs_cube.app.dependency_overrides[
get_query_service_client
]()
mocker.patch.object(
qs_client,
"materialize_cube_v2",
return_value=mocker.MagicMock(
urls=["http://workflow/cube-workflow"],
workflow_names=["cube-workflow"],
),
)

# Request with druid_overrides
druid_overrides = {
"tuningConfig": {
"partitionsSpec": {
"targetRowsPerSegment": 1000000, # Override default 5000000
"type": "single_dim", # Override default "hashed"
"partitionDimension": "date_id",
},
"maxNumConcurrentSubTasks": 20, # Add new field
},
"dataSchema": {
"granularitySpec": {
"queryGranularity": "HOUR", # Add rollup granularity
},
},
}

response = await client_with_repairs_cube.post(
f"/cubes/{cube_name}/materialize",
json={
"strategy": "incremental_time",
"schedule": "0 6 * * *",
"lookback_window": "1 DAY",
"druid_overrides": druid_overrides,
},
)

assert response.status_code == 200, response.json()
data = response.json()

# Verify overrides were applied
druid_spec = data["druid_spec"]

# Check tuningConfig overrides
tuning_config = druid_spec["tuningConfig"]
assert tuning_config["partitionsSpec"]["targetRowsPerSegment"] == 1000000
assert tuning_config["partitionsSpec"]["type"] == "single_dim"
assert tuning_config["partitionsSpec"]["partitionDimension"] == "date_id"
assert tuning_config["maxNumConcurrentSubTasks"] == 20
# Verify existing defaults are preserved
assert tuning_config["useCombiner"] is True
assert tuning_config["type"] == "hadoop"

# Check dataSchema override
granularity_spec = druid_spec["dataSchema"]["granularitySpec"]
assert granularity_spec["queryGranularity"] == "HOUR"
# Verify existing defaults are preserved
assert granularity_spec["type"] == "uniform"
assert granularity_spec["segmentGranularity"] == "DAY"

@pytest.mark.asyncio
async def test_materialize_cube_returns_metric_combiners(
self,
Expand Down
Loading
Loading