diff --git a/datajunction-server/datajunction_server/api/nodes.py b/datajunction-server/datajunction_server/api/nodes.py index 8cbc773a6..d9e9386d2 100644 --- a/datajunction-server/datajunction_server/api/nodes.py +++ b/datajunction-server/datajunction_server/api/nodes.py @@ -971,6 +971,7 @@ async def add_reference_dimension_link( ), ) await session.commit() + await session.refresh(target_column) return JSONResponse( status_code=201, content={ diff --git a/datajunction-server/datajunction_server/construction/build_v2.py b/datajunction-server/datajunction_server/construction/build_v2.py index 2f140c2d1..dc8a38fa6 100644 --- a/datajunction-server/datajunction_server/construction/build_v2.py +++ b/datajunction-server/datajunction_server/construction/build_v2.py @@ -1275,6 +1275,30 @@ async def dimension_join_path( return join_path await refresh_if_needed(session, current_link.dimension, ["current"]) + + # Check the reference links on this dimension node + await refresh_if_needed(session, current_link.dimension.current, ["columns"]) + for col in current_link.dimension.current.columns: + if col.dimension: + # Check if it matches the reference link dimension attribute + if f"{col.dimension.name}.{col.dimension_column}" == dimension: + return join_path + # Check if it matches any of the reference link dimension's linked attributes + await refresh_if_needed(session, col.dimension, ["current"]) + await refresh_if_needed( + session, + col.dimension.current, + ["dimension_links"], + ) + for link in col.dimension.current.dimension_links: + if ( + link.foreign_keys.get( + f"{col.dimension.name}.{col.dimension_column}", + ) + == dimension + ): + return join_path + await refresh_if_needed( session, current_link.dimension.current, @@ -1410,7 +1434,19 @@ def build_dimension_attribute( if dimension_attr.name in link.foreign_keys_reversed else None ) + reference_links = { + col.name: f"{col.dimension.name}.{col.dimension_column}" + for col in link.dimension.current.columns + if col.dimension + } for col in node_query.select.projection: + if reference_links.get(col.alias_or_name.name) == full_column_name: # type: ignore + return ast.Column( + name=ast.Name(col.alias_or_name.name), # type: ignore + alias=ast.Name(alias) if alias else None, + _table=node_query, + _type=col.type, # type: ignore + ) if col.alias_or_name.name == dimension_attr.column_name or ( # type: ignore foreign_key_column_name and col.alias_or_name.identifier() == foreign_key_column_name # type: ignore diff --git a/datajunction-server/datajunction_server/models/node.py b/datajunction-server/datajunction_server/models/node.py index f18c06d35..a5ba4b075 100644 --- a/datajunction-server/datajunction_server/models/node.py +++ b/datajunction-server/datajunction_server/models/node.py @@ -591,8 +591,8 @@ class ColumnOutput(BaseModel): type: str attributes: Optional[List[AttributeOutput]] dimension: Optional[NodeNameOutput] + dimension_column: Optional[str] partition: Optional[PartitionOutput] - # order: Optional[int] class Config: # pylint: disable=missing-class-docstring, too-few-public-methods """ diff --git a/datajunction-server/datajunction_server/sql/dag.py b/datajunction-server/datajunction_server/sql/dag.py index 095b91613..a08f172f7 100644 --- a/datajunction-server/datajunction_server/sql/dag.py +++ b/datajunction-server/datajunction_server/sql/dag.py @@ -352,13 +352,13 @@ async def get_dimensions_dag( # pylint: disable=too-many-locals ) .join( graph_branches, - (current_rev.id == graph_branches.c.node_revision_id) - & (is_(graph_branches.c.dimension_column, None)), + (current_rev.id == graph_branches.c.node_revision_id), + # & (is_(graph_branches.c.dimension_column, None)), ) .join( next_node, (next_node.id == graph_branches.c.dimension_id) - & (is_(graph_branches.c.dimension_column, None)) + # & (is_(graph_branches.c.dimension_column, None)) & (is_(next_node.deactivated_at, None)), ) .join( diff --git a/datajunction-server/tests/api/dimension_links_test.py b/datajunction-server/tests/api/dimension_links_test.py index bff6db519..ce2172305 100644 --- a/datajunction-server/tests/api/dimension_links_test.py +++ b/datajunction-server/tests/api/dimension_links_test.py @@ -101,6 +101,27 @@ async def _link_events_to_users_without_role() -> Response: return _link_events_to_users_without_role +@pytest.fixture +def reference_link_users_date( + dimensions_link_client: AsyncClient, # pylint: disable=redefined-outer-name +): + """ + Create a reference link between users and date + """ + + async def _reference_link_users_date() -> Response: + response = await dimensions_link_client.post( + "/nodes/default.users/columns/snapshot_date/link", + params={ + "dimension_node": "default.date", + "dimension_column": "dateint", + }, + ) + return response + + return _reference_link_users_date + + @pytest.fixture def link_events_to_users_with_role_direct( dimensions_link_client: AsyncClient, # pylint: disable=redefined-outer-name @@ -964,6 +985,59 @@ async def test_measures_sql_with_reference_dimension_links( assert response_data[0]["errors"] == [] +@pytest.mark.asyncio +async def test_measures_sql_with_ref_link_on_dim_node( + dimensions_link_client: AsyncClient, # pylint: disable=redefined-outer-name + link_events_to_users_without_role, # pylint: disable=redefined-outer-name + reference_link_users_date, # pylint: disable=redefined-outer-name +): + """ + Verify that measures SQL can be retrieved for dimension attributes that come from a + reference dimension link from one dim node to another dim node. + """ + await link_events_to_users_without_role() + await reference_link_users_date() + + response = await dimensions_link_client.get( + "/sql/measures/v2", + params={ + "metrics": ["default.elapsed_secs"], + "dimensions": [ + "default.date.dateint", + ], + }, + ) + response_data = response.json() + expected_sql = """ + WITH default_DOT_events AS ( + SELECT + default_DOT_events_table.user_id, + default_DOT_events_table.event_start_date, + default_DOT_events_table.event_end_date, + default_DOT_events_table.elapsed_secs, + default_DOT_events_table.user_registration_country + FROM examples.events AS default_DOT_events_table + ), + default_DOT_users AS ( + SELECT + default_DOT_users_table.user_id, + default_DOT_users_table.snapshot_date, + default_DOT_users_table.registration_country, + default_DOT_users_table.residence_country, + default_DOT_users_table.account_type + FROM examples.users AS default_DOT_users_table + ) + SELECT + default_DOT_events.elapsed_secs default_DOT_events_DOT_elapsed_secs, + default_DOT_users.snapshot_date default_DOT_date_DOT_dateint + FROM default_DOT_events + LEFT JOIN default_DOT_users + ON default_DOT_events.user_id = default_DOT_users.user_id + AND default_DOT_events.event_start_date = default_DOT_users.snapshot_date + """ + assert str(parse(response_data[0]["sql"])) == str(parse(expected_sql)) + + @pytest.mark.asyncio async def test_dimension_link_cross_join( dimensions_link_client: AsyncClient, # pylint: disable=redefined-outer-name diff --git a/datajunction-server/tests/examples.py b/datajunction-server/tests/examples.py index cac6caa5c..0fa709a06 100644 --- a/datajunction-server/tests/examples.py +++ b/datajunction-server/tests/examples.py @@ -2258,6 +2258,16 @@ "primary_key": ["country_code"], }, ), + ( + "/nodes/dimension/", + { + "description": "Date dimension", + "query": """SELECT 1 AS dateint""", + "mode": "published", + "name": "default.date", + "primary_key": ["dateint"], + }, + ), ( "/nodes/metric/", {