diff --git a/datajunction-server/alembic/versions/2023_08_07_1432-789f91d2d69e_add_query_ast_to_noderevision.py b/datajunction-server/alembic/versions/2023_08_07_1432-789f91d2d69e_add_query_ast_to_noderevision.py
new file mode 100644
index 000000000..b31013bb2
--- /dev/null
+++ b/datajunction-server/alembic/versions/2023_08_07_1432-789f91d2d69e_add_query_ast_to_noderevision.py
@@ -0,0 +1,27 @@
+"""Add query ast to noderevision
+
+Revision ID: 789f91d2d69e
+Revises: ccc77abcf899
+Create Date: 2023-08-07 14:32:54.290688+00:00
+
+"""
+# pylint: disable=no-member, invalid-name, missing-function-docstring, unused-import, no-name-in-module
+
+import sqlalchemy as sa
+import sqlmodel
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "789f91d2d69e"
+down_revision = "ccc77abcf899"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.add_column("noderevision", sa.Column("query_ast", sa.JSON(), nullable=True))
+
+
+def downgrade():
+    op.drop_column("noderevision", "query_ast")
diff --git a/datajunction-server/datajunction_server/api/client.py b/datajunction-server/datajunction_server/api/client.py
index 4308d69f6..b2370cfbb 100644
--- a/datajunction-server/datajunction_server/api/client.py
+++ b/datajunction-server/datajunction_server/api/client.py
@@ -38,6 +38,7 @@ def client_code_for_creating_node(
             "node_id",
             "updated_at",
             "query" if node.type == NodeType.CUBE else "",
+            "query_ast",
         },
         exclude_none=True,
     )
diff --git a/datajunction-server/datajunction_server/api/helpers.py b/datajunction-server/datajunction_server/api/helpers.py
index bb4ff28f9..2809d6ca8 100644
--- a/datajunction-server/datajunction_server/api/helpers.py
+++ b/datajunction-server/datajunction_server/api/helpers.py
@@ -58,6 +58,7 @@
 from datajunction_server.service_clients import QueryServiceClient
 from datajunction_server.sql.dag import get_nodes_with_dimension
 from datajunction_server.sql.parsing import ast
+from datajunction_server.sql.parsing.ast_json_encoder import ASTEncoder
 from datajunction_server.sql.parsing.backends.antlr4 import SqlSyntaxError, parse
 from datajunction_server.sql.parsing.backends.exceptions import DJParseException
 from datajunction_server.typing import END_JOB_STATES, UTCDatetime
@@ -415,7 +416,7 @@ def validate_node_data(  # pylint: disable=too-many-locals
         dependencies_map,
     )
     validated_node.required_dimensions = matched_bound_columns
-
+    validated_node.query_ast = json.loads(json.dumps(query_ast, cls=ASTEncoder))
     errors = []
     if missing_parents_map or type_inference_failures or invalid_required_dimensions:
         # update status (if needed)
diff --git a/datajunction-server/datajunction_server/construction/build.py b/datajunction-server/datajunction_server/construction/build.py
index 7a8cfe015..0f4e246e0 100755
--- a/datajunction-server/datajunction_server/construction/build.py
+++ b/datajunction-server/datajunction_server/construction/build.py
@@ -1,5 +1,6 @@
 """Functions to add to an ast DJ node queries"""
 import collections
+import json
 import logging
 import time
 
@@ -16,6 +17,7 @@
 from datajunction_server.models.node import BuildCriteria, Node, NodeRevision, NodeType
 from datajunction_server.sql.dag import get_shared_dimensions
 from datajunction_server.sql.parsing.ast import CompileContext
+from datajunction_server.sql.parsing.ast_json_encoder import ast_decoder
 from datajunction_server.sql.parsing.backends.antlr4 import ast, parse
 from datajunction_server.sql.parsing.types import ColumnType
 from datajunction_server.utils import amenable_name
@@ -432,6 +434,8 @@ def add_filters_dimensions_orderby_limit_to_query_ast(
         projection_update += list(projection_addition.values())
         query.select.projection = projection_update
 
+    query.select._is_compiled = False  # pylint: disable=protected-access
+    query._is_compiled = False  # pylint: disable=protected-access
     if limit is not None:
         query.select.limit = ast.Number(limit)
 
@@ -516,7 +520,12 @@ def build_node(  # pylint: disable=too-many-arguments
     ):
         return ast.Query(select=select)  # pragma: no cover
 
-    if node.query:
+    if node.query_ast:
+        query = json.loads(
+            json.dumps(node.query_ast),
+            object_hook=lambda _dict: ast_decoder(session, _dict),
+        )
+    elif node.query:
         query = parse(node.query)
     else:
         query = build_source_node_query(node)
@@ -824,6 +833,8 @@ def build_ast(  # pylint: disable=too-many-arguments
     context = CompileContext(session=session, exception=DJException())
     if hash(query) in memoized_queries:
         query = memoized_queries[hash(query)]  # pragma: no cover
+    elif query.is_compiled():
+        memoized_queries[hash(query)] = query  # pragma: no cover
     else:
         query.compile(context)
         memoized_queries[hash(query)] = query
diff --git a/datajunction-server/datajunction_server/models/node.py b/datajunction-server/datajunction_server/models/node.py
index 013905b28..99dc9b613 100644
--- a/datajunction-server/datajunction_server/models/node.py
+++ b/datajunction-server/datajunction_server/models/node.py
@@ -8,7 +8,7 @@
 from datetime import datetime, timezone
 from functools import partial
 from http import HTTPStatus
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from pydantic import BaseModel, Extra
 from pydantic import Field as PydanticField
@@ -678,6 +678,11 @@ class NodeRevision(NodeRevisionBase, table=True):  # type: ignore
         },
     )
 
+    query_ast: Optional[Dict[str, Any]] = Field(
+        sa_column=SqlaColumn("query_ast", JSON),
+        default={},
+    )
+
     def __hash__(self) -> int:
         return hash(self.id)
 
@@ -830,6 +835,15 @@ def has_available_materialization(self, build_criteria: BuildCriteria) -> bool:
             )
         )
 
+    def __json_encode__(self):
+        """
+        JSON encoder for node revision
+        """
+        return {
+            "name": self.name,
+            "type": self.type,
+        }
+
 
 class ImmutableNodeFields(BaseSQLModel):
     """
@@ -1101,6 +1115,7 @@ class NodeRevisionOutput(SQLModel):
     table: Optional[str]
     description: str = ""
     query: Optional[str] = None
+    query_ast: Optional[Dict] = {}
     availability: Optional[AvailabilityState] = None
     columns: List[ColumnOutput]
     updated_at: UTCDatetime
diff --git a/datajunction-server/datajunction_server/sql/parsing/ast.py b/datajunction-server/datajunction_server/sql/parsing/ast.py
index 32a4217b3..f0e32d280 100644
--- a/datajunction-server/datajunction_server/sql/parsing/ast.py
+++ b/datajunction-server/datajunction_server/sql/parsing/ast.py
@@ -102,6 +102,17 @@ class Node(ABC):
 
     _is_compiled: bool = False
 
+    @property
+    def json_ignore_keys(self):
+        return ["parent", "parent_key", "_is_compiled"]
+
+    def __json_encode__(self):
+        return {
+            key: self.__dict__[key]
+            for key in self.__dict__
+            if key not in self.json_ignore_keys
+        }
+
     def __post_init__(self):
         self.add_self_as_parent()
 
@@ -624,6 +635,10 @@ def identifier(self, quotes: bool = True) -> str:
             f"{namespace}{quote_style}{self.name}{quote_style}"  # pylint: disable=C0301
         )
 
+    @property
+    def json_ignore_keys(self):
+        return ["names", "parent", "parent_key"]
+
 
 TNamed = TypeVar("TNamed", bound="Named")  # pylint: disable=C0103
 
@@ -705,6 +720,10 @@ class Column(Aliasable, Named, Expression):
     _expression: Optional[Expression] = field(repr=False, default=None)
     _is_compiled: bool = False
 
+    @property
+    def json_ignore_keys(self):
+        return ["parent", "parent_key", "columns"]
+
     @property
     def type(self):
         if self._type:
@@ -985,6 +1004,18 @@ class TableExpression(Aliasable, Expression):
     # ref (referenced) columns are columns used elsewhere from this table
     _ref_columns: List[Column] = field(init=False, repr=False, default_factory=list)
 
+    @property
+    def json_ignore_keys(self):
+        return [
+            "parent",
+            "parent_key",
+            # "_is_compiled",
+            "_columns",
+            # "column_list",
+            "_ref_columns",
+            # "columns",
+        ]
+
     @property
     def columns(self) -> List[Expression]:
         """
@@ -1229,6 +1260,11 @@ class BinaryOpKind(DJEnum):
     Minus = "-"
     Modulo = "%"
 
+    def __json_encode__(self):
+        return {
+            "value": self.value,
+        }
+
 
 @dataclass(eq=False)
 class BinaryOp(Operation):
@@ -2003,6 +2039,19 @@ class FunctionTable(FunctionTableExpression):
     Represents a table-valued function used in a statement
     """
 
+    @property
+    def json_ignore_keys(self):
+        return [
+            "parent",
+            "parent_key",
+            "_is_compiled",
+            "_table",
+            "_columns",
+            "column_list",
+            "_ref_columns",
+            "columns",
+        ]
+
     def __str__(self) -> str:
         alias = f" {self.alias}" if self.alias else ""
         as_ = " AS " if self.as_ else ""
diff --git a/datajunction-server/datajunction_server/sql/parsing/ast_json_encoder.py b/datajunction-server/datajunction_server/sql/parsing/ast_json_encoder.py
new file mode 100644
index 000000000..56c59239f
--- /dev/null
+++ b/datajunction-server/datajunction_server/sql/parsing/ast_json_encoder.py
@@ -0,0 +1,86 @@
+"""
+JSON encoder for AST objects
+"""
+from json import JSONEncoder
+
+from sqlmodel import select
+
+from datajunction_server.models import Node
+from datajunction_server.sql.parsing import ast
+
+
+def remove_circular_refs(obj, _seen: set = None):
+    """
+    Short-circuits circular references in AST nodes
+    """
+    if _seen is None:
+        _seen = set()
+    if id(obj) in _seen:
+        return None
+    _seen.add(id(obj))
+    if issubclass(obj.__class__, ast.Node):
+        serializable_keys = [
+            key for key in obj.__dict__.keys() if key not in obj.json_ignore_keys
+        ]
+        for key in serializable_keys:
+            setattr(obj, key, remove_circular_refs(getattr(obj, key), _seen))
+    _seen.remove(id(obj))
+    return obj
+
+
+class ASTEncoder(JSONEncoder):
+    """
+    JSON encoder for AST objects. Disables the original circular check in favor
+    of our own version with _processed so that we can catch and handle circular
+    traversals.
+    """
+
+    def __init__(self, *args, **kwargs):
+        kwargs["check_circular"] = False
+        self.markers = set()
+        super().__init__(*args, **kwargs)
+
+    def default(self, o):
+        o = remove_circular_refs(o)
+        json_dict = {
+            "__class__": o.__class__.__name__,
+        }
+        if hasattr(o, "__json_encode__"):  # pragma: no cover
+            json_dict = {**json_dict, **o.__json_encode__()}
+        return json_dict
+
+
+def ast_decoder(session, json_dict):
+    """
+    Decodes json dict back into an AST entity
+    """
+    class_name = json_dict["__class__"]
+    clazz = getattr(ast, class_name)
+
+    # Instantiate the class
+    instance = clazz(
+        **{
+            k: v
+            for k, v in json_dict.items()
+            if k not in {"__class__", "_type", "laterals", "_is_compiled"}
+        },
+    )
+
+    # Set attributes where possible
+    for key, value in json_dict.items():
+        if key not in {"__class__", "_is_compiled"}:
+            if hasattr(instance, key) and class_name not in {"BinaryOpKind"}:
+                setattr(instance, key, value)
+
+    if class_name == "NodeRevision":
+        # Overwrite with DB object if it's a node revision
+        instance = (
+            session.exec(select(Node).where(Node.name == json_dict["name"]))
+            .one()
+            .current
+        )
+    elif class_name == "Column":
+        # Add in a reference to the table from the column
+        instance._table.parent = instance  # pylint: disable=protected-access
+        instance._table.parent_key = "_table"  # pylint: disable=protected-access
+    return instance
diff --git a/datajunction-server/datajunction_server/sql/parsing/types.py b/datajunction-server/datajunction_server/sql/parsing/types.py
index 754714a39..f10179969 100644
--- a/datajunction-server/datajunction_server/sql/parsing/types.py
+++ b/datajunction-server/datajunction_server/sql/parsing/types.py
@@ -74,6 +74,11 @@ def __str__(self):
     def __deepcopy__(self, memo):
         return self
 
+    def __json_encode__(self):
+        return {
+            "__class__": self.__class__.__name__,
+        }
+
     @classmethod
     def __get_validators__(cls) -> Generator[AnyCallable, None, None]:
         """
diff --git a/datajunction-server/tests/api/nodes_test.py b/datajunction-server/tests/api/nodes_test.py
index 0a733069f..f79f51896 100644
--- a/datajunction-server/tests/api/nodes_test.py
+++ b/datajunction-server/tests/api/nodes_test.py
@@ -1402,6 +1402,282 @@ def test_create_update_transform_node(
             },
         ]
         assert data["parents"] == [{"name": "basic.source.users"}]
+        assert data["query_ast"] == {
+            "__class__": "Query",
+            "_is_compiled": True,
+            "alias": None,
+            "as_": None,
+            "column_list": [],
+            "ctes": [],
+            "name": {
+                "__class__": "DefaultName",
+                "_is_compiled": True,
+                "name": "",
+                "namespace": None,
+                "quote_style": "",
+            },
+            "select": {
+                "__class__": "Select",
+                "alias": None,
+                "as_": None,
+                "from_": {
+                    "__class__": "From",
+                    "laterals": [],
+                    "relations": [
+                        {
+                            "__class__": "Relation",
+                            "extensions": [],
+                            "primary": {
+                                "__class__": "Table",
+                                "_dj_node": {
+                                    "__class__": "NodeRevision",
+                                    "name": "basic.source.users",
+                                    "type": "source",
+                                },
+                                "_is_compiled": True,
+                                "alias": None,
+                                "as_": None,
+                                "column_list": [],
+                                "name": {
+                                    "__class__": "Name",
+                                    "_is_compiled": True,
+                                    "name": "users",
+                                    "namespace": {
+                                        "__class__": "Name",
+                                        "_is_compiled": True,
+                                        "name": "source",
+                                        "namespace": {
+                                            "__class__": "Name",
+                                            "_is_compiled": True,
+                                            "name": "basic",
+                                            "namespace": None,
+                                            "quote_style": "",
+                                        },
+                                        "quote_style": "",
+                                    },
+                                    "quote_style": "",
+                                },
+                            },
+                        },
+                    ],
+                },
+                "group_by": [],
+                "having": None,
+                "lateral_views": [],
+                "limit": None,
+                "organization": {"__class__": "Organization", "order": [], "sort": []},
+                "projection": [
+                    {
+                        "__class__": "Column",
+                        "_expression": {
+                            "__class__": "Column",
+                            "_expression": None,
+                            "_is_compiled": False,
+                            "_table": {
+                                "__class__": "Table",
+                                "_dj_node": {
+                                    "__class__": "NodeRevision",
+                                    "name": "basic.source.users",
+                                    "type": "source",
+                                },
+                                "_is_compiled": True,
+                                "alias": None,
+                                "as_": None,
+                                "column_list": [],
+                                "name": {
+                                    "__class__": "Name",
+                                    "_is_compiled": True,
+                                    "name": "users",
+                                    "namespace": {
+                                        "__class__": "Name",
+                                        "_is_compiled": True,
+                                        "name": "source",
+                                        "namespace": {
+                                            "__class__": "Name",
+                                            "_is_compiled": True,
+                                            "name": "basic",
+                                            "namespace": None,
+                                            "quote_style": "",
+                                        },
+                                        "quote_style": "",
+                                    },
+                                    "quote_style": "",
+                                },
+                            },
+                            "_type": {"__class__": "StringType"},
+                            "alias": None,
+                            "as_": None,
+                            "name": {
+                                "__class__": "Name",
+                                "name": "country",
+                                "namespace": None,
+                                "quote_style": "",
+                            },
+                        },
+                        "_is_compiled": True,
+                        "_table": {
+                            "__class__": "Table",
+                            "_dj_node": {
+                                "__class__": "NodeRevision",
+                                "name": "basic.source.users",
+                                "type": "source",
+                            },
+                            "_is_compiled": True,
+                            "alias": None,
+                            "as_": None,
+                            "column_list": [],
+                            "name": {
+                                "__class__": "Name",
+                                "_is_compiled": True,
+                                "name": "users",
+                                "namespace": {
+                                    "__class__": "Name",
+                                    "_is_compiled": True,
+                                    "name": "source",
+                                    "namespace": {
+                                        "__class__": "Name",
+                                        "_is_compiled": True,
+                                        "name": "basic",
+                                        "namespace": None,
+                                        "quote_style": "",
+                                    },
+                                    "quote_style": "",
+                                },
+                                "quote_style": "",
+                            },
+                        },
+                        "_type": {"__class__": "StringType"},
+                        "alias": None,
+                        "as_": None,
+                        "name": {
+                            "__class__": "Name",
+                            "_is_compiled": True,
+                            "name": "country",
+                            "namespace": None,
+                            "quote_style": "",
+                        },
+                    },
+                    {
+                        "__class__": "Alias",
+                        "alias": {
+                            "__class__": "Name",
+                            "_is_compiled": True,
+                            "name": "num_users",
+                            "namespace": None,
+                            "quote_style": "",
+                        },
+                        "as_": True,
+                        "child": {
+                            "__class__": "Function",
+                            "args": [
+                                {
+                                    "__class__": "Column",
+                                    "_expression": {
+                                        "__class__": "Column",
+                                        "_expression": None,
+                                        "_is_compiled": False,
+                                        "_table": {
+                                            "__class__": "Table",
+                                            "_dj_node": {
+                                                "__class__": "NodeRevision",
+                                                "name": "basic.source.users",
+                                                "type": "source",
+                                            },
+                                            "_is_compiled": True,
+                                            "alias": None,
+                                            "as_": None,
+                                            "column_list": [],
+                                            "name": {
+                                                "__class__": "Name",
+                                                "_is_compiled": True,
+                                                "name": "users",
+                                                "namespace": {
+                                                    "__class__": "Name",
+                                                    "_is_compiled": True,
+                                                    "name": "source",
+                                                    "namespace": {
+                                                        "__class__": "Name",
+                                                        "_is_compiled": True,
+                                                        "name": "basic",
+                                                        "namespace": None,
+                                                        "quote_style": "",
+                                                    },
+                                                    "quote_style": "",
+                                                },
+                                                "quote_style": "",
+                                            },
+                                        },
+                                        "_type": {"__class__": "IntegerType"},
+                                        "alias": None,
+                                        "as_": None,
+                                        "name": {
+                                            "__class__": "Name",
+                                            "name": "id",
+                                            "namespace": None,
+                                            "quote_style": "",
+                                        },
+                                    },
+                                    "_is_compiled": True,
+                                    "_table": {
+                                        "__class__": "Table",
+                                        "_dj_node": {
+                                            "__class__": "NodeRevision",
+                                            "name": "basic.source.users",
+                                            "type": "source",
+                                        },
+                                        "_is_compiled": True,
+                                        "alias": None,
+                                        "as_": None,
+                                        "column_list": [],
+                                        "name": {
+                                            "__class__": "Name",
+                                            "_is_compiled": True,
+                                            "name": "users",
+                                            "namespace": {
+                                                "__class__": "Name",
+                                                "_is_compiled": True,
+                                                "name": "source",
+                                                "namespace": {
+                                                    "__class__": "Name",
+                                                    "_is_compiled": True,
+                                                    "name": "basic",
+                                                    "namespace": None,
+                                                    "quote_style": "",
+                                                },
+                                                "quote_style": "",
+                                            },
+                                            "quote_style": "",
+                                        },
+                                    },
+                                    "_type": {"__class__": "IntegerType"},
+                                    "alias": None,
+                                    "as_": None,
+                                    "name": {
+                                        "__class__": "Name",
+                                        "_is_compiled": True,
+                                        "name": "id",
+                                        "namespace": None,
+                                        "quote_style": "",
+                                    },
+                                },
+                            ],
+                            "name": {
+                                "__class__": "Name",
+                                "_is_compiled": True,
+                                "name": "COUNT",
+                                "namespace": None,
+                                "quote_style": "",
+                            },
+                            "over": None,
+                            "quantifier": "DISTINCT",
+                        },
+                    },
+                ],
+                "quantifier": "",
+                "set_op": None,
+                "where": None,
+            },
+        }
 
         # Update the transform node with two minor changes
         response = client.patch(
@@ -2774,7 +3050,7 @@ def test_validating_with_missing_parents(self, client: TestClient) -> None:
         data = response.json()
 
         assert response.status_code == 422
-        assert data == {
+        assert data == {  # type: ignore
             "message": "Node `foo` is invalid.",
             "status": "invalid",
             "node_revision": {
@@ -2783,6 +3059,80 @@
                 "type": "transform",
                 "description": "This is my foo transform node!",
                 "query": "SELECT 1 FROM node_that_does_not_exist",
+                "query_ast": {
+                    "__class__": "Query",
+                    "_is_compiled": True,
+                    "alias": None,
+                    "as_": None,
+                    "column_list": [],
+                    "ctes": [],
+                    "name": {
+                        "__class__": "DefaultName",
+                        "_is_compiled": True,
+                        "name": "",
+                        "namespace": None,
+                        "quote_style": "",
+                    },
+                    "select": {
+                        "__class__": "Select",
+                        "alias": None,
+                        "as_": None,
+                        "from_": {  # type: ignore
+                            "__class__": "From",
+                            "laterals": [],
+                            "relations": [
+                                {
+                                    "__class__": "Relation",
+                                    "extensions": [],
+                                    "primary": {
+                                        "__class__": "Table",
+                                        "_dj_node": None,
+                                        "_is_compiled": True,
+                                        "alias": None,
+                                        "as_": None,
+                                        "column_list": [],
+                                        "name": {
+                                            "__class__": "Name",
+                                            "_is_compiled": True,
+                                            "name": "node_that_does_not_exist",
+                                            "namespace": None,
+                                            "quote_style": "",
+                                        },
+                                    },
+                                },
+                            ],
+                        },
+                        "group_by": [],  # type: ignore
+                        "having": None,
+                        "lateral_views": [],  # type: ignore
+                        "limit": None,
+                        "organization": {  # type: ignore
+                            "__class__": "Organization",
+                            "order": [],
+                            "sort": [],
+                        },
+                        "projection": [  # type: ignore
+                            {
+                                "__class__": "Alias",
+                                "alias": {
+                                    "__class__": "Name",
+                                    "name": "col0",
+                                    "namespace": None,
+                                    "quote_style": "",
+                                },
+                                "as_": None,
+                                "child": {
+                                    "__class__": "Number",
+                                    "_type": None,
+                                    "value": 1,
+                                },
+                            },
+                        ],
+                        "quantifier": "",
+                        "set_op": None,
+                        "where": None,
+                    },
+                },
                 "mode": "published",
                 "id": None,
                 "version": "v0.1",