From a91f4599ca4b98c4c1b21eb1c51fb8b65bd4cdf9 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Mon, 7 Aug 2023 09:19:20 -0700 Subject: [PATCH 1/5] Add JSON serializer for query ASTs and store them upon node creation --- ...f91d2d69e_add_query_ast_to_noderevision.py | 27 +++ .../datajunction_server/api/client.py | 1 + .../datajunction_server/api/helpers.py | 3 +- .../datajunction_server/models/node.py | 8 +- .../datajunction_server/sql/parsing/ast.py | 25 +++ .../sql/parsing/ast_json_encoder.py | 37 ++++ datajunction-server/tests/api/nodes_test.py | 182 +++++++++++++++++- 7 files changed, 280 insertions(+), 3 deletions(-) create mode 100644 datajunction-server/alembic/versions/2023_08_07_1432-789f91d2d69e_add_query_ast_to_noderevision.py create mode 100644 datajunction-server/datajunction_server/sql/parsing/ast_json_encoder.py diff --git a/datajunction-server/alembic/versions/2023_08_07_1432-789f91d2d69e_add_query_ast_to_noderevision.py b/datajunction-server/alembic/versions/2023_08_07_1432-789f91d2d69e_add_query_ast_to_noderevision.py new file mode 100644 index 000000000..3e2d698d7 --- /dev/null +++ b/datajunction-server/alembic/versions/2023_08_07_1432-789f91d2d69e_add_query_ast_to_noderevision.py @@ -0,0 +1,27 @@ +"""Add query ast to noderevision + +Revision ID: 789f91d2d69e +Revises: ccc77abcf899 +Create Date: 2023-08-07 14:32:54.290688+00:00 + +""" +# pylint: disable=no-member, invalid-name, missing-function-docstring, unused-import, no-name-in-module + +import sqlalchemy as sa +import sqlmodel +from alembic import op + + +# revision identifiers, used by Alembic. +revision = '789f91d2d69e' +down_revision = 'ccc77abcf899' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('noderevision', sa.Column('query_ast', sa.JSON(), nullable=True)) + + +def downgrade(): + op.drop_column('noderevision', 'query_ast') diff --git a/datajunction-server/datajunction_server/api/client.py b/datajunction-server/datajunction_server/api/client.py index 4308d69f6..b2370cfbb 100644 --- a/datajunction-server/datajunction_server/api/client.py +++ b/datajunction-server/datajunction_server/api/client.py @@ -38,6 +38,7 @@ def client_code_for_creating_node( "node_id", "updated_at", "query" if node.type == NodeType.CUBE else "", + "query_ast", }, exclude_none=True, ) diff --git a/datajunction-server/datajunction_server/api/helpers.py b/datajunction-server/datajunction_server/api/helpers.py index bb4ff28f9..2809d6ca8 100644 --- a/datajunction-server/datajunction_server/api/helpers.py +++ b/datajunction-server/datajunction_server/api/helpers.py @@ -58,6 +58,7 @@ from datajunction_server.service_clients import QueryServiceClient from datajunction_server.sql.dag import get_nodes_with_dimension from datajunction_server.sql.parsing import ast +from datajunction_server.sql.parsing.ast_json_encoder import ASTEncoder from datajunction_server.sql.parsing.backends.antlr4 import SqlSyntaxError, parse from datajunction_server.sql.parsing.backends.exceptions import DJParseException from datajunction_server.typing import END_JOB_STATES, UTCDatetime @@ -415,7 +416,7 @@ def validate_node_data( # pylint: disable=too-many-locals dependencies_map, ) validated_node.required_dimensions = matched_bound_columns - + validated_node.query_ast = json.loads(json.dumps(query_ast, cls=ASTEncoder)) errors = [] if missing_parents_map or type_inference_failures or invalid_required_dimensions: # update status (if needed) diff --git a/datajunction-server/datajunction_server/models/node.py 
b/datajunction-server/datajunction_server/models/node.py index 013905b28..12012fb6b 100644 --- a/datajunction-server/datajunction_server/models/node.py +++ b/datajunction-server/datajunction_server/models/node.py @@ -8,7 +8,7 @@ from datetime import datetime, timezone from functools import partial from http import HTTPStatus -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from pydantic import BaseModel, Extra from pydantic import Field as PydanticField @@ -678,6 +678,11 @@ class NodeRevision(NodeRevisionBase, table=True): # type: ignore }, ) + query_ast: Optional[Dict[str, Any]] = Field( + sa_column=SqlaColumn("query_ast", JSON), + default={}, + ) + def __hash__(self) -> int: return hash(self.id) @@ -1101,6 +1106,7 @@ class NodeRevisionOutput(SQLModel): table: Optional[str] description: str = "" query: Optional[str] = None + query_ast: Optional[Dict] = {} availability: Optional[AvailabilityState] = None columns: List[ColumnOutput] updated_at: UTCDatetime diff --git a/datajunction-server/datajunction_server/sql/parsing/ast.py b/datajunction-server/datajunction_server/sql/parsing/ast.py index 32a4217b3..dffbc3e1e 100644 --- a/datajunction-server/datajunction_server/sql/parsing/ast.py +++ b/datajunction-server/datajunction_server/sql/parsing/ast.py @@ -102,6 +102,10 @@ class Node(ABC): _is_compiled: bool = False + @property + def json_ignore_keys(self): + return ["parent", "parent_key", "_is_compiled"] + def __post_init__(self): self.add_self_as_parent() @@ -705,6 +709,12 @@ class Column(Aliasable, Named, Expression): _expression: Optional[Expression] = field(repr=False, default=None) _is_compiled: bool = False + @property + def json_ignore_keys(self): + if set(self._expression.columns).intersection(self.columns): + return ["parent", "parent_key", "_is_compiled", "_expression", "columns"] + return ["parent", "parent_key", "_is_compiled", "columns"] + @property def type(self): if self._type: @@ -985,6 +995,17 @@ class TableExpression(Aliasable, Expression): # ref (referenced) columns are columns used elsewhere from this table _ref_columns: List[Column] = field(init=False, repr=False, default_factory=list) + @property + def json_ignore_keys(self): + return [ + "parent", + "parent_key", + "_is_compiled", + "_columns", + "column_list", + "_ref_columns", + ] + @property def columns(self) -> List[Expression]: """ @@ -2003,6 +2024,10 @@ class FunctionTable(FunctionTableExpression): Represents a table-valued function used in a statement """ + @property + def json_ignore_keys(self): + return ["parent", "parent_key", "_is_compiled", "_table"] + def __str__(self) -> str: alias = f" {self.alias}" if self.alias else "" as_ = " AS " if self.as_ else "" diff --git a/datajunction-server/datajunction_server/sql/parsing/ast_json_encoder.py b/datajunction-server/datajunction_server/sql/parsing/ast_json_encoder.py new file mode 100644 index 000000000..a50603917 --- /dev/null +++ b/datajunction-server/datajunction_server/sql/parsing/ast_json_encoder.py @@ -0,0 +1,37 @@ +""" +JSON encoder for AST objects +""" +from json import JSONEncoder + + +class ASTEncoder(JSONEncoder): + """ + JSON encoder for AST objects. Disables the original circular check in favor + of our own version with _processed so that we can catch and handle circular + traversals. 
+ """ + + def __init__(self, *args, **kwargs): + kwargs["check_circular"] = False # no need to check anymore + super().__init__(*args, **kwargs) + self._processed = set() + + def default(self, o): + if id(o) in self._processed: + return None + self._processed.add(id(o)) + + if o.__class__.__name__ == "NodeRevision": + return { + "__class__": o.__class__.__name__, + "name": o.name, + "type": o.type, + } + + json_dict = { + k: o.__dict__[k] + for k in o.__dict__ + if hasattr(o, "json_ignore_keys") and k not in o.json_ignore_keys + } + json_dict["__class__"] = o.__class__.__name__ + return json_dict diff --git a/datajunction-server/tests/api/nodes_test.py b/datajunction-server/tests/api/nodes_test.py index 0a733069f..fca04f6dd 100644 --- a/datajunction-server/tests/api/nodes_test.py +++ b/datajunction-server/tests/api/nodes_test.py @@ -1402,6 +1402,118 @@ def test_create_update_transform_node( }, ] assert data["parents"] == [{"name": "basic.source.users"}] + assert data["query_ast"] == { # type: ignore + "__class__": "Query", + "alias": None, + "as_": None, + "ctes": [], + "name": { + "__class__": "DefaultName", + "name": "", + "namespace": None, + "quote_style": "", + }, + "select": { + "__class__": "Select", + "alias": None, + "as_": None, + "from_": { # type: ignore + "__class__": "From", + "laterals": [], + "relations": [ + {"__class__": "Relation", "extensions": [], "primary": None}, + ], + }, + "group_by": [], # type: ignore + "having": None, + "lateral_views": [], # type: ignore + "limit": None, + "organization": { # type: ignore + "__class__": "Organization", + "order": [], + "sort": [], + }, + "projection": [ # type: ignore + { + "__class__": "Column", + "_table": { + "__class__": "Table", + "_dj_node": { + "__class__": "NodeRevision", + "name": "basic.source.users", + "type": "source", + }, + "alias": None, + "as_": None, + "name": { + "__class__": "Name", + "name": "users", + "namespace": { + "__class__": "Name", + "name": "source", + "namespace": { + "__class__": "Name", + "name": "basic", + "namespace": None, + "quote_style": "", + }, + "quote_style": "", + }, + "quote_style": "", + }, + }, + "_type": {"__class__": "StringType"}, + "alias": None, + "as_": None, + "name": { + "__class__": "Name", + "name": "country", + "namespace": None, + "quote_style": "", + }, + }, + { + "__class__": "Alias", + "alias": { + "__class__": "Name", + "name": "num_users", + "namespace": None, + "quote_style": "", + }, + "as_": True, + "child": { + "__class__": "Function", + "args": [ + { + "__class__": "Column", + "_table": None, + "_type": {"__class__": "IntegerType"}, + "alias": None, + "as_": None, + "name": { + "__class__": "Name", + "name": "id", + "namespace": None, + "quote_style": "", + }, + }, + ], + "name": { + "__class__": "Name", + "name": "COUNT", + "namespace": None, + "quote_style": "", + }, + "over": None, + "quantifier": "DISTINCT", + }, + }, + ], + "quantifier": "", + "set_op": None, + "where": None, + }, + } # Update the transform node with two minor changes response = client.patch( @@ -2774,7 +2886,7 @@ def test_validating_with_missing_parents(self, client: TestClient) -> None: data = response.json() assert response.status_code == 422 - assert data == { + assert data == { # type: ignore "message": "Node `foo` is invalid.", "status": "invalid", "node_revision": { @@ -2783,6 +2895,74 @@ def test_validating_with_missing_parents(self, client: TestClient) -> None: "type": "transform", "description": "This is my foo transform node!", "query": "SELECT 1 FROM node_that_does_not_exist", + 
"query_ast": { + "__class__": "Query", + "alias": None, + "as_": None, + "ctes": [], + "name": { + "__class__": "DefaultName", + "name": "", + "namespace": None, + "quote_style": "", + }, + "select": { + "__class__": "Select", + "alias": None, + "as_": None, + "from_": { # type: ignore + "__class__": "From", + "laterals": [], + "relations": [ + { + "__class__": "Relation", + "extensions": [], + "primary": { + "__class__": "Table", + "_dj_node": None, + "alias": None, + "as_": None, + "name": { + "__class__": "Name", + "name": "node_that_does_not_exist", + "namespace": None, + "quote_style": "", + }, + }, + }, + ], + }, + "group_by": [], # type: ignore + "having": None, + "lateral_views": [], # type: ignore + "limit": None, + "organization": { # type: ignore + "__class__": "Organization", + "order": [], + "sort": [], + }, + "projection": [ # type: ignore + { + "__class__": "Alias", + "alias": { + "__class__": "Name", + "name": "col0", + "namespace": None, + "quote_style": "", + }, + "as_": None, + "child": { + "__class__": "Number", + "_type": None, + "value": 1, + }, + }, + ], + "quantifier": "", + "set_op": None, + "where": None, + }, + }, "mode": "published", "id": None, "version": "v0.1", From 96e9bf72aeed883b064237b6ae2bb2a56c4ffccc Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Mon, 7 Aug 2023 09:38:48 -0700 Subject: [PATCH 2/5] Fix lint --- ..._1432-789f91d2d69e_add_query_ast_to_noderevision.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/datajunction-server/alembic/versions/2023_08_07_1432-789f91d2d69e_add_query_ast_to_noderevision.py b/datajunction-server/alembic/versions/2023_08_07_1432-789f91d2d69e_add_query_ast_to_noderevision.py index 3e2d698d7..b31013bb2 100644 --- a/datajunction-server/alembic/versions/2023_08_07_1432-789f91d2d69e_add_query_ast_to_noderevision.py +++ b/datajunction-server/alembic/versions/2023_08_07_1432-789f91d2d69e_add_query_ast_to_noderevision.py @@ -9,19 +9,19 @@ import sqlalchemy as sa import sqlmodel -from alembic import op +from alembic import op # revision identifiers, used by Alembic. 
-revision = '789f91d2d69e' -down_revision = 'ccc77abcf899' +revision = "789f91d2d69e" +down_revision = "ccc77abcf899" branch_labels = None depends_on = None def upgrade(): - op.add_column('noderevision', sa.Column('query_ast', sa.JSON(), nullable=True)) + op.add_column("noderevision", sa.Column("query_ast", sa.JSON(), nullable=True)) def downgrade(): - op.drop_column('noderevision', 'query_ast') + op.drop_column("noderevision", "query_ast") From b7eff1cf5d041c5e394974cc2ceff5efea8651a8 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Wed, 9 Aug 2023 09:07:37 -0700 Subject: [PATCH 3/5] Update json serializer so that we automatically short-circuit circular references and thus can serialize more of the AST --- .../datajunction_server/models/node.py | 9 + .../datajunction_server/sql/parsing/ast.py | 36 +++- .../sql/parsing/ast_json_encoder.py | 82 ++++++-- .../datajunction_server/sql/parsing/types.py | 5 + datajunction-server/tests/api/nodes_test.py | 178 ++++++++++++++++-- datajunction-server/tests/api/sql_test.py | 99 +++++----- 6 files changed, 325 insertions(+), 84 deletions(-) diff --git a/datajunction-server/datajunction_server/models/node.py b/datajunction-server/datajunction_server/models/node.py index 12012fb6b..99dc9b613 100644 --- a/datajunction-server/datajunction_server/models/node.py +++ b/datajunction-server/datajunction_server/models/node.py @@ -835,6 +835,15 @@ def has_available_materialization(self, build_criteria: BuildCriteria) -> bool: ) ) + def __json_encode__(self): + """ + JSON encoder for node revision + """ + return { + "name": self.name, + "type": self.type, + } + class ImmutableNodeFields(BaseSQLModel): """ diff --git a/datajunction-server/datajunction_server/sql/parsing/ast.py b/datajunction-server/datajunction_server/sql/parsing/ast.py index dffbc3e1e..564b5f9b1 100644 --- a/datajunction-server/datajunction_server/sql/parsing/ast.py +++ b/datajunction-server/datajunction_server/sql/parsing/ast.py @@ -106,6 +106,13 @@ class Node(ABC): def json_ignore_keys(self): return ["parent", "parent_key", "_is_compiled"] + def __json_encode__(self): + return { + key: self.__dict__[key] + for key in self.__dict__ + if key not in self.json_ignore_keys + } + def __post_init__(self): self.add_self_as_parent() @@ -628,6 +635,10 @@ def identifier(self, quotes: bool = True) -> str: f"{namespace}{quote_style}{self.name}{quote_style}" # pylint: disable=C0301 ) + @property + def json_ignore_keys(self): + return ["names", "parent", "parent_key"] + TNamed = TypeVar("TNamed", bound="Named") # pylint: disable=C0103 @@ -711,9 +722,7 @@ class Column(Aliasable, Named, Expression): @property def json_ignore_keys(self): - if set(self._expression.columns).intersection(self.columns): - return ["parent", "parent_key", "_is_compiled", "_expression", "columns"] - return ["parent", "parent_key", "_is_compiled", "columns"] + return ["parent", "parent_key", "columns"] @property def type(self): @@ -1000,10 +1009,11 @@ def json_ignore_keys(self): return [ "parent", "parent_key", - "_is_compiled", + # "_is_compiled", "_columns", - "column_list", + # "column_list", "_ref_columns", + "columns", ] @property @@ -1250,6 +1260,11 @@ class BinaryOpKind(DJEnum): Minus = "-" Modulo = "%" + def __json_encode__(self): + return { + "value": self.value, + } + @dataclass(eq=False) class BinaryOp(Operation): @@ -2026,7 +2041,16 @@ class FunctionTable(FunctionTableExpression): @property def json_ignore_keys(self): - return ["parent", "parent_key", "_is_compiled", "_table"] + return [ + "parent", + "parent_key", + 
"_is_compiled", + "_table", + "_columns", + "column_list", + "_ref_columns", + "columns", + ] def __str__(self) -> str: alias = f" {self.alias}" if self.alias else "" diff --git a/datajunction-server/datajunction_server/sql/parsing/ast_json_encoder.py b/datajunction-server/datajunction_server/sql/parsing/ast_json_encoder.py index a50603917..aecb1d41d 100644 --- a/datajunction-server/datajunction_server/sql/parsing/ast_json_encoder.py +++ b/datajunction-server/datajunction_server/sql/parsing/ast_json_encoder.py @@ -3,6 +3,30 @@ """ from json import JSONEncoder +from sqlmodel import select + +from datajunction_server.models import Node +from datajunction_server.sql.parsing import ast + + +def remove_circular_refs(obj, _seen: set = None): + """ + Short-circuits circular references in AST nodes + """ + if _seen is None: + _seen = set() + if id(obj) in _seen: + return None + _seen.add(id(obj)) + if issubclass(obj.__class__, ast.Node): + serializable_keys = [ + key for key in obj.__dict__.keys() if key not in obj.json_ignore_keys + ] + for key in serializable_keys: + setattr(obj, key, remove_circular_refs(getattr(obj, key), _seen)) + _seen.remove(id(obj)) + return obj + class ASTEncoder(JSONEncoder): """ @@ -12,26 +36,50 @@ class ASTEncoder(JSONEncoder): """ def __init__(self, *args, **kwargs): - kwargs["check_circular"] = False # no need to check anymore + kwargs["check_circular"] = False + self.markers = set() super().__init__(*args, **kwargs) - self._processed = set() def default(self, o): - if id(o) in self._processed: - return None - self._processed.add(id(o)) - - if o.__class__.__name__ == "NodeRevision": - return { - "__class__": o.__class__.__name__, - "name": o.name, - "type": o.type, - } - + o = remove_circular_refs(o) json_dict = { - k: o.__dict__[k] - for k in o.__dict__ - if hasattr(o, "json_ignore_keys") and k not in o.json_ignore_keys + "__class__": o.__class__.__name__, } - json_dict["__class__"] = o.__class__.__name__ + if hasattr(o, "__json_encode__"): + json_dict = {**json_dict, **o.__json_encode__()} return json_dict + + +def ast_decoder(session, json_dict): + """Decodes json dict""" + class_name = json_dict["__class__"] + if not class_name or not hasattr(ast, class_name): + return None + clazz = getattr(ast, class_name) + if class_name == "NodeRevision": + instance = ( + session.exec(select(Node).where(Node.name == json_dict["name"])) + .one() + .current + ) + else: + instance = clazz( + **{ + k: v + for k, v in json_dict.items() + if k not in {"__class__", "_type", "laterals", "_is_compiled"} + }, + ) + for key, value in json_dict.items(): + if key not in {"__class__", "_is_compiled"}: + try: + setattr(instance, key, value) + except AttributeError: + pass + + if class_name == "Table": + instance._columns = [ # pylint: disable=protected-access + ast.Column(ast.Name(col.name), _table=instance, _type=col.type) + for col in instance._dj_node.columns # pylint: disable=protected-access + ] + return instance diff --git a/datajunction-server/datajunction_server/sql/parsing/types.py b/datajunction-server/datajunction_server/sql/parsing/types.py index 754714a39..f10179969 100644 --- a/datajunction-server/datajunction_server/sql/parsing/types.py +++ b/datajunction-server/datajunction_server/sql/parsing/types.py @@ -74,6 +74,11 @@ def __str__(self): def __deepcopy__(self, memo): return self + def __json_encode__(self): + return { + "__class__": self.__class__.__name__, + } + @classmethod def __get_validators__(cls) -> Generator[AnyCallable, None, None]: """ diff --git 
a/datajunction-server/tests/api/nodes_test.py b/datajunction-server/tests/api/nodes_test.py index fca04f6dd..256f7262b 100644 --- a/datajunction-server/tests/api/nodes_test.py +++ b/datajunction-server/tests/api/nodes_test.py @@ -1402,13 +1402,14 @@ def test_create_update_transform_node( }, ] assert data["parents"] == [{"name": "basic.source.users"}] - assert data["query_ast"] == { # type: ignore + assert data["query_ast"] == { "__class__": "Query", "alias": None, "as_": None, "ctes": [], "name": { "__class__": "DefaultName", + "_is_compiled": True, "name": "", "namespace": None, "quote_style": "", @@ -1417,25 +1418,95 @@ def test_create_update_transform_node( "__class__": "Select", "alias": None, "as_": None, - "from_": { # type: ignore + "from_": { "__class__": "From", "laterals": [], "relations": [ - {"__class__": "Relation", "extensions": [], "primary": None}, + { + "__class__": "Relation", + "extensions": [], + "primary": { + "__class__": "Table", + "_dj_node": { + "__class__": "NodeRevision", + "name": "basic.source.users", + "type": "source", + }, + "alias": None, + "as_": None, + "name": { + "__class__": "Name", + "_is_compiled": True, + "name": "users", + "namespace": { + "__class__": "Name", + "_is_compiled": True, + "name": "source", + "namespace": { + "__class__": "Name", + "_is_compiled": True, + "name": "basic", + "namespace": None, + "quote_style": "", + }, + "quote_style": "", + }, + "quote_style": "", + }, + }, + }, ], }, - "group_by": [], # type: ignore + "group_by": [], "having": None, - "lateral_views": [], # type: ignore + "lateral_views": [], "limit": None, - "organization": { # type: ignore - "__class__": "Organization", - "order": [], - "sort": [], - }, - "projection": [ # type: ignore + "organization": {"__class__": "Organization", "order": [], "sort": []}, + "projection": [ { "__class__": "Column", + "_expression": { + "__class__": "Column", + "_expression": None, + "_table": { + "__class__": "Table", + "_dj_node": { + "__class__": "NodeRevision", + "name": "basic.source.users", + "type": "source", + }, + "alias": None, + "as_": None, + "name": { + "__class__": "Name", + "_is_compiled": True, + "name": "users", + "namespace": { + "__class__": "Name", + "_is_compiled": True, + "name": "source", + "namespace": { + "__class__": "Name", + "_is_compiled": True, + "name": "basic", + "namespace": None, + "quote_style": "", + }, + "quote_style": "", + }, + "quote_style": "", + }, + }, + "_type": {"__class__": "StringType"}, + "alias": None, + "as_": None, + "name": { + "__class__": "Name", + "name": "country", + "namespace": None, + "quote_style": "", + }, + }, "_table": { "__class__": "Table", "_dj_node": { @@ -1447,12 +1518,15 @@ def test_create_update_transform_node( "as_": None, "name": { "__class__": "Name", + "_is_compiled": True, "name": "users", "namespace": { "__class__": "Name", + "_is_compiled": True, "name": "source", "namespace": { "__class__": "Name", + "_is_compiled": True, "name": "basic", "namespace": None, "quote_style": "", @@ -1467,6 +1541,7 @@ def test_create_update_transform_node( "as_": None, "name": { "__class__": "Name", + "_is_compiled": True, "name": "country", "namespace": None, "quote_style": "", @@ -1476,6 +1551,7 @@ def test_create_update_transform_node( "__class__": "Alias", "alias": { "__class__": "Name", + "_is_compiled": True, "name": "num_users", "namespace": None, "quote_style": "", @@ -1486,12 +1562,83 @@ def test_create_update_transform_node( "args": [ { "__class__": "Column", - "_table": None, + "_expression": { + "__class__": 
"Column", + "_expression": None, + "_table": { + "__class__": "Table", + "_dj_node": { + "__class__": "NodeRevision", + "name": "basic.source.users", + "type": "source", + }, + "alias": None, + "as_": None, + "name": { + "__class__": "Name", + "_is_compiled": True, + "name": "users", + "namespace": { + "__class__": "Name", + "_is_compiled": True, + "name": "source", + "namespace": { + "__class__": "Name", + "_is_compiled": True, + "name": "basic", + "namespace": None, + "quote_style": "", + }, + "quote_style": "", + }, + "quote_style": "", + }, + }, + "_type": {"__class__": "IntegerType"}, + "alias": None, + "as_": None, + "name": { + "__class__": "Name", + "name": "id", + "namespace": None, + "quote_style": "", + }, + }, + "_table": { + "__class__": "Table", + "_dj_node": { + "__class__": "NodeRevision", + "name": "basic.source.users", + "type": "source", + }, + "alias": None, + "as_": None, + "name": { + "__class__": "Name", + "_is_compiled": True, + "name": "users", + "namespace": { + "__class__": "Name", + "_is_compiled": True, + "name": "source", + "namespace": { + "__class__": "Name", + "_is_compiled": True, + "name": "basic", + "namespace": None, + "quote_style": "", + }, + "quote_style": "", + }, + "quote_style": "", + }, + }, "_type": {"__class__": "IntegerType"}, "alias": None, "as_": None, "name": { "__class__": "Name", + "_is_compiled": True, "name": "id", "namespace": None, "quote_style": "", @@ -1500,6 +1647,7 @@ def test_create_update_transform_node( ], "name": { "__class__": "Name", + "_is_compiled": True, "name": "COUNT", "namespace": None, "quote_style": "", @@ -2897,11 +3045,14 @@ def test_validating_with_missing_parents(self, client: TestClient) -> None: "query": "SELECT 1 FROM node_that_does_not_exist", "query_ast": { "__class__": "Query", + "_is_compiled": True, "alias": None, "as_": None, + "column_list": [], "ctes": [], "name": { "__class__": "DefaultName", + "_is_compiled": True, "name": "", "namespace": None, "quote_style": "", @@ -2920,10 +3071,13 @@ def test_validating_with_missing_parents(self, client: TestClient) -> None: "primary": { "__class__": "Table", "_dj_node": None, + "_is_compiled": True, "alias": None, "as_": None, + "column_list": [], "name": { "__class__": "Name", + "_is_compiled": True, "name": "node_that_does_not_exist", "namespace": None, "quote_style": "", diff --git a/datajunction-server/tests/api/sql_test.py b/datajunction-server/tests/api/sql_test.py index 2a6fc8e8a..63eaa39ea 100644 --- a/datajunction-server/tests/api/sql_test.py +++ b/datajunction-server/tests/api/sql_test.py @@ -350,56 +350,56 @@ def test_sql_with_filters( "node_name, dimensions, filters, orderby, sql", [ # querying on source node with filter on joinable dimension - ( - "foo.bar.repair_orders", - [], - ["foo.bar.hard_hat.state='CA'"], - [], - """ - SELECT foo_DOT_bar_DOT_repair_orders.dispatched_date, - foo_DOT_bar_DOT_repair_orders.dispatcher_id, - foo_DOT_bar_DOT_hard_hat.state, - foo_DOT_bar_DOT_repair_orders.hard_hat_id, - foo_DOT_bar_DOT_repair_orders.municipality_id, - foo_DOT_bar_DOT_repair_orders.order_date, - foo_DOT_bar_DOT_repair_orders.repair_order_id, - foo_DOT_bar_DOT_repair_orders.required_date - FROM roads.repair_orders AS foo_DOT_bar_DOT_repair_orders - LEFT OUTER JOIN ( - SELECT foo_DOT_bar_DOT_repair_orders.dispatcher_id, - foo_DOT_bar_DOT_repair_orders.hard_hat_id, - foo_DOT_bar_DOT_repair_orders.municipality_id, - foo_DOT_bar_DOT_repair_orders.repair_order_id - FROM roads.repair_orders AS foo_DOT_bar_DOT_repair_orders - ) AS 
foo_DOT_bar_DOT_repair_order ON foo_DOT_bar_DOT_repair_orders.repair_order_id = foo_DOT_bar_DOT_repair_order.repair_order_id - LEFT OUTER JOIN ( - SELECT foo_DOT_bar_DOT_hard_hats.hard_hat_id, - foo_DOT_bar_DOT_hard_hats.state - FROM roads.hard_hats AS foo_DOT_bar_DOT_hard_hats - ) AS foo_DOT_bar_DOT_hard_hat ON foo_DOT_bar_DOT_repair_order.hard_hat_id = foo_DOT_bar_DOT_hard_hat.hard_hat_id - WHERE foo_DOT_bar_DOT_hard_hat.state = 'CA' - """, - ), + # ( + # "foo.bar.repair_orders", + # [], + # ["foo.bar.hard_hat.state='CA'"], + # [], + # """ + # SELECT foo_DOT_bar_DOT_repair_orders.dispatched_date, + # foo_DOT_bar_DOT_repair_orders.dispatcher_id, + # foo_DOT_bar_DOT_hard_hat.state, + # foo_DOT_bar_DOT_repair_orders.hard_hat_id, + # foo_DOT_bar_DOT_repair_orders.municipality_id, + # foo_DOT_bar_DOT_repair_orders.order_date, + # foo_DOT_bar_DOT_repair_orders.repair_order_id, + # foo_DOT_bar_DOT_repair_orders.required_date + # FROM roads.repair_orders AS foo_DOT_bar_DOT_repair_orders + # LEFT OUTER JOIN ( + # SELECT foo_DOT_bar_DOT_repair_orders.dispatcher_id, + # foo_DOT_bar_DOT_repair_orders.hard_hat_id, + # foo_DOT_bar_DOT_repair_orders.municipality_id, + # foo_DOT_bar_DOT_repair_orders.repair_order_id + # FROM roads.repair_orders AS foo_DOT_bar_DOT_repair_orders + # ) AS foo_DOT_bar_DOT_repair_order ON foo_DOT_bar_DOT_repair_orders.repair_order_id = foo_DOT_bar_DOT_repair_order.repair_order_id + # LEFT OUTER JOIN ( + # SELECT foo_DOT_bar_DOT_hard_hats.hard_hat_id, + # foo_DOT_bar_DOT_hard_hats.state + # FROM roads.hard_hats AS foo_DOT_bar_DOT_hard_hats + # ) AS foo_DOT_bar_DOT_hard_hat ON foo_DOT_bar_DOT_repair_order.hard_hat_id = foo_DOT_bar_DOT_hard_hat.hard_hat_id + # WHERE foo_DOT_bar_DOT_hard_hat.state = 'CA' + # """, + # ), # querying source node with filters directly on the node - ( - "foo.bar.repair_orders", - [], - ["foo.bar.repair_orders.order_date='2009-08-14'"], - [], - """ - SELECT - foo_DOT_bar_DOT_repair_orders.dispatched_date, - foo_DOT_bar_DOT_repair_orders.dispatcher_id, - foo_DOT_bar_DOT_repair_orders.hard_hat_id, - foo_DOT_bar_DOT_repair_orders.municipality_id, - foo_DOT_bar_DOT_repair_orders.order_date, - foo_DOT_bar_DOT_repair_orders.repair_order_id, - foo_DOT_bar_DOT_repair_orders.required_date - FROM roads.repair_orders AS foo_DOT_bar_DOT_repair_orders - WHERE - foo_DOT_bar_DOT_repair_orders.order_date = '2009-08-14' - """, - ), + # ( + # "foo.bar.repair_orders", + # [], + # ["foo.bar.repair_orders.order_date='2009-08-14'"], + # [], + # """ + # SELECT + # foo_DOT_bar_DOT_repair_orders.dispatched_date, + # foo_DOT_bar_DOT_repair_orders.dispatcher_id, + # foo_DOT_bar_DOT_repair_orders.hard_hat_id, + # foo_DOT_bar_DOT_repair_orders.municipality_id, + # foo_DOT_bar_DOT_repair_orders.order_date, + # foo_DOT_bar_DOT_repair_orders.repair_order_id, + # foo_DOT_bar_DOT_repair_orders.required_date + # FROM roads.repair_orders AS foo_DOT_bar_DOT_repair_orders + # WHERE + # foo_DOT_bar_DOT_repair_orders.order_date = '2009-08-14' + # """, + # ), ( "foo.bar.num_repair_orders", [], @@ -546,6 +546,7 @@ def test_sql_with_filters_on_namespaced_nodes( # pylint: disable=R0913 params={"dimensions": dimensions, "filters": filters, "orderby": orderby}, ) data = response.json() + print(data["sql"]) assert compare_query_strings(data["sql"], sql) From a18c3a877803af983d6a1b150c78103235666b41 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Wed, 9 Aug 2023 09:09:15 -0700 Subject: [PATCH 4/5] Undo sql test changes --- datajunction-server/tests/api/sql_test.py | 98 +++++++++++------------ 1 
file changed, 49 insertions(+), 49 deletions(-) diff --git a/datajunction-server/tests/api/sql_test.py b/datajunction-server/tests/api/sql_test.py index 63eaa39ea..e0afcbeb1 100644 --- a/datajunction-server/tests/api/sql_test.py +++ b/datajunction-server/tests/api/sql_test.py @@ -350,56 +350,56 @@ def test_sql_with_filters( "node_name, dimensions, filters, orderby, sql", [ # querying on source node with filter on joinable dimension - # ( - # "foo.bar.repair_orders", - # [], - # ["foo.bar.hard_hat.state='CA'"], - # [], - # """ - # SELECT foo_DOT_bar_DOT_repair_orders.dispatched_date, - # foo_DOT_bar_DOT_repair_orders.dispatcher_id, - # foo_DOT_bar_DOT_hard_hat.state, - # foo_DOT_bar_DOT_repair_orders.hard_hat_id, - # foo_DOT_bar_DOT_repair_orders.municipality_id, - # foo_DOT_bar_DOT_repair_orders.order_date, - # foo_DOT_bar_DOT_repair_orders.repair_order_id, - # foo_DOT_bar_DOT_repair_orders.required_date - # FROM roads.repair_orders AS foo_DOT_bar_DOT_repair_orders - # LEFT OUTER JOIN ( - # SELECT foo_DOT_bar_DOT_repair_orders.dispatcher_id, - # foo_DOT_bar_DOT_repair_orders.hard_hat_id, - # foo_DOT_bar_DOT_repair_orders.municipality_id, - # foo_DOT_bar_DOT_repair_orders.repair_order_id - # FROM roads.repair_orders AS foo_DOT_bar_DOT_repair_orders - # ) AS foo_DOT_bar_DOT_repair_order ON foo_DOT_bar_DOT_repair_orders.repair_order_id = foo_DOT_bar_DOT_repair_order.repair_order_id - # LEFT OUTER JOIN ( - # SELECT foo_DOT_bar_DOT_hard_hats.hard_hat_id, - # foo_DOT_bar_DOT_hard_hats.state - # FROM roads.hard_hats AS foo_DOT_bar_DOT_hard_hats - # ) AS foo_DOT_bar_DOT_hard_hat ON foo_DOT_bar_DOT_repair_order.hard_hat_id = foo_DOT_bar_DOT_hard_hat.hard_hat_id - # WHERE foo_DOT_bar_DOT_hard_hat.state = 'CA' - # """, - # ), + ( + "foo.bar.repair_orders", + [], + ["foo.bar.hard_hat.state='CA'"], + [], + """ + SELECT foo_DOT_bar_DOT_repair_orders.dispatched_date, + foo_DOT_bar_DOT_repair_orders.dispatcher_id, + foo_DOT_bar_DOT_hard_hat.state, + foo_DOT_bar_DOT_repair_orders.hard_hat_id, + foo_DOT_bar_DOT_repair_orders.municipality_id, + foo_DOT_bar_DOT_repair_orders.order_date, + foo_DOT_bar_DOT_repair_orders.repair_order_id, + foo_DOT_bar_DOT_repair_orders.required_date + FROM roads.repair_orders AS foo_DOT_bar_DOT_repair_orders + LEFT OUTER JOIN ( + SELECT foo_DOT_bar_DOT_repair_orders.dispatcher_id, + foo_DOT_bar_DOT_repair_orders.hard_hat_id, + foo_DOT_bar_DOT_repair_orders.municipality_id, + foo_DOT_bar_DOT_repair_orders.repair_order_id + FROM roads.repair_orders AS foo_DOT_bar_DOT_repair_orders + ) AS foo_DOT_bar_DOT_repair_order ON foo_DOT_bar_DOT_repair_orders.repair_order_id = foo_DOT_bar_DOT_repair_order.repair_order_id + LEFT OUTER JOIN ( + SELECT foo_DOT_bar_DOT_hard_hats.hard_hat_id, + foo_DOT_bar_DOT_hard_hats.state + FROM roads.hard_hats AS foo_DOT_bar_DOT_hard_hats + ) AS foo_DOT_bar_DOT_hard_hat ON foo_DOT_bar_DOT_repair_order.hard_hat_id = foo_DOT_bar_DOT_hard_hat.hard_hat_id + WHERE foo_DOT_bar_DOT_hard_hat.state = 'CA' + """, + ), # querying source node with filters directly on the node - # ( - # "foo.bar.repair_orders", - # [], - # ["foo.bar.repair_orders.order_date='2009-08-14'"], - # [], - # """ - # SELECT - # foo_DOT_bar_DOT_repair_orders.dispatched_date, - # foo_DOT_bar_DOT_repair_orders.dispatcher_id, - # foo_DOT_bar_DOT_repair_orders.hard_hat_id, - # foo_DOT_bar_DOT_repair_orders.municipality_id, - # foo_DOT_bar_DOT_repair_orders.order_date, - # foo_DOT_bar_DOT_repair_orders.repair_order_id, - # foo_DOT_bar_DOT_repair_orders.required_date - # FROM roads.repair_orders AS 
foo_DOT_bar_DOT_repair_orders - # WHERE - # foo_DOT_bar_DOT_repair_orders.order_date = '2009-08-14' - # """, - # ), + ( + "foo.bar.repair_orders", + [], + ["foo.bar.repair_orders.order_date='2009-08-14'"], + [], + """ + SELECT + foo_DOT_bar_DOT_repair_orders.dispatched_date, + foo_DOT_bar_DOT_repair_orders.dispatcher_id, + foo_DOT_bar_DOT_repair_orders.hard_hat_id, + foo_DOT_bar_DOT_repair_orders.municipality_id, + foo_DOT_bar_DOT_repair_orders.order_date, + foo_DOT_bar_DOT_repair_orders.repair_order_id, + foo_DOT_bar_DOT_repair_orders.required_date + FROM roads.repair_orders AS foo_DOT_bar_DOT_repair_orders + WHERE + foo_DOT_bar_DOT_repair_orders.order_date = '2009-08-14' + """, + ), ( "foo.bar.num_repair_orders", [], From 780f793806fe5b8bf9144b44bc1a06ad04ee9a52 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Wed, 9 Aug 2023 11:51:34 -0700 Subject: [PATCH 5/5] Add json deserialization and incorporate into query building --- .../datajunction_server/construction/build.py | 13 ++++- .../datajunction_server/sql/parsing/ast.py | 2 +- .../sql/parsing/ast_json_encoder.py | 49 ++++++++++--------- datajunction-server/tests/api/nodes_test.py | 16 ++++++ datajunction-server/tests/api/sql_test.py | 1 - 5 files changed, 54 insertions(+), 27 deletions(-) diff --git a/datajunction-server/datajunction_server/construction/build.py b/datajunction-server/datajunction_server/construction/build.py index 7a8cfe015..0f4e246e0 100755 --- a/datajunction-server/datajunction_server/construction/build.py +++ b/datajunction-server/datajunction_server/construction/build.py @@ -1,5 +1,6 @@ """Functions to add to an ast DJ node queries""" import collections +import json import logging import time @@ -16,6 +17,7 @@ from datajunction_server.models.node import BuildCriteria, Node, NodeRevision, NodeType from datajunction_server.sql.dag import get_shared_dimensions from datajunction_server.sql.parsing.ast import CompileContext +from datajunction_server.sql.parsing.ast_json_encoder import ast_decoder from datajunction_server.sql.parsing.backends.antlr4 import ast, parse from datajunction_server.sql.parsing.types import ColumnType from datajunction_server.utils import amenable_name @@ -432,6 +434,8 @@ def add_filters_dimensions_orderby_limit_to_query_ast( projection_update += list(projection_addition.values()) query.select.projection = projection_update + query.select._is_compiled = False # pylint: disable=protected-access + query._is_compiled = False # pylint: disable=protected-access if limit is not None: query.select.limit = ast.Number(limit) @@ -516,7 +520,12 @@ def build_node( # pylint: disable=too-many-arguments ): return ast.Query(select=select) # pragma: no cover - if node.query: + if node.query_ast: + query = json.loads( + json.dumps(node.query_ast), + object_hook=lambda _dict: ast_decoder(session, _dict), + ) + elif node.query: query = parse(node.query) else: query = build_source_node_query(node) @@ -824,6 +833,8 @@ def build_ast( # pylint: disable=too-many-arguments context = CompileContext(session=session, exception=DJException()) if hash(query) in memoized_queries: query = memoized_queries[hash(query)] # pragma: no cover + elif query.is_compiled(): + memoized_queries[hash(query)] = query # pragma: no cover else: query.compile(context) memoized_queries[hash(query)] = query diff --git a/datajunction-server/datajunction_server/sql/parsing/ast.py b/datajunction-server/datajunction_server/sql/parsing/ast.py index 564b5f9b1..f0e32d280 100644 --- a/datajunction-server/datajunction_server/sql/parsing/ast.py +++ 
b/datajunction-server/datajunction_server/sql/parsing/ast.py @@ -1013,7 +1013,7 @@ def json_ignore_keys(self): "_columns", # "column_list", "_ref_columns", - "columns", + # "columns", ] @property diff --git a/datajunction-server/datajunction_server/sql/parsing/ast_json_encoder.py b/datajunction-server/datajunction_server/sql/parsing/ast_json_encoder.py index aecb1d41d..56c59239f 100644 --- a/datajunction-server/datajunction_server/sql/parsing/ast_json_encoder.py +++ b/datajunction-server/datajunction_server/sql/parsing/ast_json_encoder.py @@ -45,41 +45,42 @@ def default(self, o): json_dict = { "__class__": o.__class__.__name__, } - if hasattr(o, "__json_encode__"): + if hasattr(o, "__json_encode__"): # pragma: no cover json_dict = {**json_dict, **o.__json_encode__()} return json_dict def ast_decoder(session, json_dict): - """Decodes json dict""" + """ + Decodes json dict back into an AST entity + """ class_name = json_dict["__class__"] - if not class_name or not hasattr(ast, class_name): - return None clazz = getattr(ast, class_name) + + # Instantiate the class + instance = clazz( + **{ + k: v + for k, v in json_dict.items() + if k not in {"__class__", "_type", "laterals", "_is_compiled"} + }, + ) + + # Set attributes where possible + for key, value in json_dict.items(): + if key not in {"__class__", "_is_compiled"}: + if hasattr(instance, key) and class_name not in {"BinaryOpKind"}: + setattr(instance, key, value) + if class_name == "NodeRevision": + # Overwrite with DB object if it's a node revision instance = ( session.exec(select(Node).where(Node.name == json_dict["name"])) .one() .current ) - else: - instance = clazz( - **{ - k: v - for k, v in json_dict.items() - if k not in {"__class__", "_type", "laterals", "_is_compiled"} - }, - ) - for key, value in json_dict.items(): - if key not in {"__class__", "_is_compiled"}: - try: - setattr(instance, key, value) - except AttributeError: - pass - - if class_name == "Table": - instance._columns = [ # pylint: disable=protected-access - ast.Column(ast.Name(col.name), _table=instance, _type=col.type) - for col in instance._dj_node.columns # pylint: disable=protected-access - ] + elif class_name == "Column": + # Add in a reference to the table from the column + instance._table.parent = instance # pylint: disable=protected-access + instance._table.parent_key = "_table" # pylint: disable=protected-access return instance diff --git a/datajunction-server/tests/api/nodes_test.py b/datajunction-server/tests/api/nodes_test.py index 256f7262b..f79f51896 100644 --- a/datajunction-server/tests/api/nodes_test.py +++ b/datajunction-server/tests/api/nodes_test.py @@ -1404,8 +1404,10 @@ def test_create_update_transform_node( assert data["parents"] == [{"name": "basic.source.users"}] assert data["query_ast"] == { "__class__": "Query", + "_is_compiled": True, "alias": None, "as_": None, + "column_list": [], "ctes": [], "name": { "__class__": "DefaultName", @@ -1432,8 +1434,10 @@ def test_create_update_transform_node( "name": "basic.source.users", "type": "source", }, + "_is_compiled": True, "alias": None, "as_": None, + "column_list": [], "name": { "__class__": "Name", "_is_compiled": True, @@ -1468,6 +1472,7 @@ def test_create_update_transform_node( "_expression": { "__class__": "Column", "_expression": None, + "_is_compiled": False, "_table": { "__class__": "Table", "_dj_node": { @@ -1475,8 +1480,10 @@ def test_create_update_transform_node( "name": "basic.source.users", "type": "source", }, + "_is_compiled": True, "alias": None, "as_": None, + "column_list": 
[], "name": { "__class__": "Name", "_is_compiled": True, @@ -1507,6 +1514,7 @@ def test_create_update_transform_node( "quote_style": "", }, }, + "_is_compiled": True, "_table": { "__class__": "Table", "_dj_node": { @@ -1514,8 +1522,10 @@ def test_create_update_transform_node( "name": "basic.source.users", "type": "source", }, + "_is_compiled": True, "alias": None, "as_": None, + "column_list": [], "name": { "__class__": "Name", "_is_compiled": True, @@ -1565,6 +1575,7 @@ def test_create_update_transform_node( "_expression": { "__class__": "Column", "_expression": None, + "_is_compiled": False, "_table": { "__class__": "Table", "_dj_node": { @@ -1572,8 +1583,10 @@ def test_create_update_transform_node( "name": "basic.source.users", "type": "source", }, + "_is_compiled": True, "alias": None, "as_": None, + "column_list": [], "name": { "__class__": "Name", "_is_compiled": True, @@ -1604,6 +1617,7 @@ def test_create_update_transform_node( "quote_style": "", }, }, + "_is_compiled": True, "_table": { "__class__": "Table", "_dj_node": { @@ -1611,8 +1625,10 @@ def test_create_update_transform_node( "name": "basic.source.users", "type": "source", }, + "_is_compiled": True, "alias": None, "as_": None, + "column_list": [], "name": { "__class__": "Name", "_is_compiled": True, diff --git a/datajunction-server/tests/api/sql_test.py b/datajunction-server/tests/api/sql_test.py index e0afcbeb1..2a6fc8e8a 100644 --- a/datajunction-server/tests/api/sql_test.py +++ b/datajunction-server/tests/api/sql_test.py @@ -546,7 +546,6 @@ def test_sql_with_filters_on_namespaced_nodes( # pylint: disable=R0913 params={"dimensions": dimensions, "filters": filters, "orderby": orderby}, ) data = response.json() - print(data["sql"]) assert compare_query_strings(data["sql"], sql)
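
---

For reviewers, a minimal sketch (not part of the diff) of how the pieces added across these five patches fit together: `validate_node_data` serializes the compiled query AST with `ASTEncoder` into the new `NodeRevision.query_ast` JSON column, and `build_node` later rebuilds the AST from that column with `ast_decoder`. The `round_trip_query_ast` helper below is hypothetical and only illustrative; it assumes a live SQLModel session and the existing antlr4 `parse` backend, and it elides the compile step that attaches tables and types to columns.

```python
import json

from sqlmodel import Session

from datajunction_server.sql.parsing.ast_json_encoder import ASTEncoder, ast_decoder
from datajunction_server.sql.parsing.backends.antlr4 import parse


def round_trip_query_ast(session: Session, query: str):
    """
    Hypothetical helper: serialize a parsed query AST to a JSON-safe dict
    and rebuild AST objects from that dict, mirroring what validate_node_data
    and build_node do in this patch series.
    """
    query_ast = parse(query)
    # validate_node_data compiles the AST first (so columns carry table/type
    # references); that step is omitted in this sketch.

    # ASTEncoder walks the AST, short-circuits circular references via
    # remove_circular_refs, and tags each dict with "__class__" so the result
    # can be stored in NodeRevision.query_ast.
    serialized = json.loads(json.dumps(query_ast, cls=ASTEncoder))

    # build_node reverses the process: ast_decoder reconstructs AST nodes from
    # the stored dicts and re-attaches NodeRevision objects from the database
    # by name.
    rebuilt = json.loads(
        json.dumps(serialized),
        object_hook=lambda dict_: ast_decoder(session, dict_),
    )
    return serialized, rebuilt
```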