From 48e12fbc92a3ca00804adcf816b140857f80287d Mon Sep 17 00:00:00 2001 From: Colin Harrison Date: Thu, 20 Nov 2025 13:32:00 -0800 Subject: [PATCH 1/4] basic formatter --- core/query_parser.py | 215 +++++++++++++++++++++++++++++++++++-- tests/test_query_parser.py | 77 ++++++++++++- 2 files changed, 285 insertions(+), 7 deletions(-) diff --git a/core/query_parser.py b/core/query_parser.py index 1ac3796..9f92ae1 100644 --- a/core/query_parser.py +++ b/core/query_parser.py @@ -1,4 +1,12 @@ +import mo_sql_parsing as mosql from core.ast.node import QueryNode +from core.ast.node import ( + QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, + LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, + OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode +) +from core.ast.enums import NodeType, JoinType, SortOrder +from core.ast.node import Node class QueryParser: @@ -13,11 +21,206 @@ def parse(self, query: str) -> QueryNode: # Any (JSON) -> AST (QueryNode) def format(self, query: QueryNode) -> str: - # Implement formatting logic to convert AST back to SQL string - pass + # [1] AST (QueryNode) -> JSON + json_query = ast_to_json(query) + + # [2] Any (JSON) -> str + sql = mosql.format(json_query) + + return sql + +def ast_to_json(node: QueryNode) -> dict: + """Convert QueryNode AST to JSON dictionary for mosql""" + result = {} + + # process each clause in the query + for child in node.children: + if child.type == NodeType.SELECT: + result['select'] = format_select(child) + elif child.type == NodeType.FROM: + result['from'] = format_from(child) + elif child.type == NodeType.WHERE: + result['where'] = format_where(child) + elif child.type == NodeType.GROUP_BY: + result['groupby'] = format_group_by(child) + elif child.type == NodeType.HAVING: + result['having'] = format_having(child) + elif child.type == NodeType.ORDER_BY: + result['orderby'] = format_order_by(child) + elif child.type == NodeType.LIMIT: + result['limit'] = child.limit + elif child.type == NodeType.OFFSET: + result['offset'] = child.offset + + return result + + +def format_select(select_node: SelectNode) -> list: + """Format SELECT clause""" + items = [] + + for child in select_node.children: + if child.type == NodeType.COLUMN: + if child.alias: + items.append({'name': child.alias, 'value': format_expression(child)}) + else: + items.append({'value': format_expression(child)}) + elif child.type == NodeType.FUNCTION: + func_expr = format_expression(child) + if hasattr(child, 'alias') and child.alias: + items.append({'name': child.alias, 'value': func_expr}) + else: + items.append({'value': func_expr}) + else: + items.append({'value': format_expression(child)}) + + return items + + +def format_from(from_node: FromNode) -> list: + """Format FROM clause""" + sources = [] + tables = list(from_node.children) + + if tables: + main_table = tables[0] + sources.append(format_table(main_table)) + + # additional tables become JOINs + # TODO: add other join type support beyond implicit + for table in tables[1:]: + join_item = { + 'join': format_table(table), + 'on': infer_join_condition(tables[0], table) + } + sources.append(join_item) + + return sources + + +def format_table(table_node: TableNode) -> dict: + """Format a table reference""" + result = {'value': table_node.name} + if table_node.alias: + result['name'] = table_node.alias + return result + + +def infer_join_condition(table1: TableNode, table2: TableNode) -> dict: + """Infer JOIN condition between tables""" + # assume foreign key pattern like table1.table2_id = table2.id + alias1 = table1.alias or table1.name + alias2 = table2.alias or table2.name + + return {'eq': [f'{alias1}.{table2.name[:-1]}_id', f'{alias2}.id']} + + +def format_where(where_node: WhereNode) -> dict: + """Format WHERE clause""" + predicates = list(where_node.children) + if len(predicates) == 1: + return format_expression(predicates[0]) + else: + return {'and': [format_expression(p) for p in predicates]} + + +def format_group_by(group_by_node: GroupByNode) -> list: + """Format GROUP BY clause""" + return [{'value': format_expression(child)} + for child in group_by_node.children] + + +def format_having(having_node: HavingNode) -> dict: + """Format HAVING clause""" + predicates = list(having_node.children) + if len(predicates) == 1: + return format_expression(predicates[0]) + else: + return {'and': [format_expression(p) for p in predicates]} + + +def format_order_by(order_by_node: OrderByNode) -> list: + """Format ORDER BY clause items.""" + items = [] + + # get all items and their sort orders + sort_orders = [] + for child in order_by_node.children: + if child.type == NodeType.ORDER_BY_ITEM: + column = list(child.children)[0] + item = {'value': format_expression(column)} + sort_order = child.sort + sort_orders.append(sort_order) + else: + item = {'value': format_expression(child)} + sort_order = SortOrder.ASC + sort_orders.append(sort_order) + + items.append((item, sort_order)) + + # check if all sort orders are the same + all_same = len(set(sort_orders)) == 1 + common_sort = sort_orders[0] if all_same else None + + # reformat into single sort operator if all items have same sort operator + # ex. ORDER BY dept_name DESC, emp_count DESC -> ORDER BY dept_name, emp_count DESC + result = [] + for i, (item, sort_order) in enumerate(items): + if all_same and i == len(items) - 1: + if common_sort != SortOrder.ASC: + item['sort'] = common_sort.value.lower() + elif not all_same: + if sort_order != SortOrder.ASC: + item['sort'] = sort_order.value.lower() + + result.append(item) + + return result - # [1] Our new code - # AST (QueryNode) -> JSON - # [2] Call mo_sql_format - # Any (JSON) -> str \ No newline at end of file +def format_expression(node: Node): + """Format an expression node""" + if node.type == NodeType.COLUMN: + if node.parent_alias: + return f"{node.parent_alias}.{node.name}" + return node.name + + elif node.type == NodeType.LITERAL: + return node.value + + elif node.type == NodeType.FUNCTION: + # format: {'function_name': args} + func_name = node.name.lower() + args = [format_expression(arg) for arg in node.children] + return {func_name: args[0] if len(args) == 1 else args} + + elif node.type == NodeType.OPERATOR: + # format: {'operator': [left, right]} + op_map = { + '>': 'gt', + '<': 'lt', + '>=': 'gte', + '<=': 'lte', + '=': 'eq', + '!=': 'ne', + 'AND': 'and', + 'OR': 'or', + } + + op_name = op_map.get(node.name.upper(), node.name.lower()) + children = list(node.children) + + left = format_expression(children[0]) + + if len(children) == 2: + right = format_expression(children[1]) + return {op_name: [left, right]} + else: + # unary operator + return {op_name: left} + + elif node.type == NodeType.TABLE: + return format_table(node) + + else: + raise ValueError(f"Unsupported node type in expression: {node.type}") \ No newline at end of file diff --git a/tests/test_query_parser.py b/tests/test_query_parser.py index c3e7b61..17b8783 100644 --- a/tests/test_query_parser.py +++ b/tests/test_query_parser.py @@ -1,15 +1,90 @@ import mo_sql_parsing as mosql from core.query_parser import QueryParser from core.ast.node import ( - QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, + OrderByItemNode, QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode ) from core.ast.enums import NodeType, JoinType, SortOrder from data.queries import get_query +from re import sub parser = QueryParser() +def normalize_sql(s): + """Remove extra whitespace and normalize SQL string to be used in comparisons""" + s = s.strip() + s = sub(r'\s+', ' ', s) + + return s + +def test_basic_format(): + # Construct input AST + # Tables + emp_table = TableNode("employees", "e") + dept_table = TableNode("departments", "d") + # Columns + emp_name = ColumnNode("name", _parent_alias="e") + dept_name = ColumnNode("name", "dept_name", "d") + emp_salary = ColumnNode("salary", _parent_alias="e") + emp_age = ColumnNode("age", _parent_alias="e") + emp_dept_id = ColumnNode("department_id", _parent_alias="e") + dept_id = ColumnNode("id", _parent_alias="d") + count_star = FunctionNode("COUNT", {ColumnNode("*")}, 'emp_count') + count_alias = ColumnNode("emp_count") # This would be the alias for COUNT(*) + # SELECT clause + select_clause = SelectNode([emp_name, dept_name, count_star]) + # FROM clause (with implicit JOIN logic) + from_clause = FromNode([emp_table, dept_table]) + # WHERE clause + salary_condition = OperatorNode(emp_salary, ">", LiteralNode(40000)) + age_condition = OperatorNode(emp_age, "<", LiteralNode(60)) + where_condition = OperatorNode(salary_condition, "AND", age_condition) + where_clause = WhereNode([where_condition]) + # GROUP BY clause + group_by_clause = GroupByNode([dept_id, dept_name]) + # HAVING clause + having_condition = OperatorNode(count_star, ">", LiteralNode(2)) + having_clause = HavingNode([having_condition]) + # ORDER BY clause + order_by_clause = OrderByNode([ + OrderByItemNode(ColumnNode("dept_name"), _sort=SortOrder.DESC), + OrderByItemNode(ColumnNode("emp_count"), _sort=SortOrder.DESC) + ]) + # LIMIT and OFFSET + limit_clause = LimitNode(10) + offset_clause = OffsetNode(5) + # Complete query + ast = QueryNode( + _select=select_clause, + _from=from_clause, + _where=where_clause, + _group_by=group_by_clause, + _having=having_clause, + _order_by=order_by_clause, + _limit=limit_clause, + _offset=offset_clause + ) + + # Construct expected query text + expected_sql = """ + SELECT e.name, d.name AS dept_name, COUNT(*) AS emp_count + FROM employees AS e JOIN departments AS d ON e.department_id = d.id + WHERE e.salary > 40000 AND e.age < 60 + GROUP BY d.id, d.name + HAVING COUNT(*) > 2 + ORDER BY dept_name, emp_count DESC + LIMIT 10 OFFSET 5 + """ + expected_sql = expected_sql.strip() + print(mosql.parse(expected_sql)) + print(ast) + + sql = parser.format(ast) + sql = sql.strip() + + assert normalize_sql(sql) == normalize_sql(expected_sql) + def test_parse_1(): query = get_query(1) sql = query['pattern'] From 4fdaa97d30471da5baf261a6922c1a7ef4b439b7 Mon Sep 17 00:00:00 2001 From: Colin Harrison Date: Thu, 20 Nov 2025 13:36:50 -0800 Subject: [PATCH 2/4] update order by in test --- tests/test_query_parser.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_query_parser.py b/tests/test_query_parser.py index 17b8783..cac4485 100644 --- a/tests/test_query_parser.py +++ b/tests/test_query_parser.py @@ -31,7 +31,9 @@ def test_basic_format(): emp_dept_id = ColumnNode("department_id", _parent_alias="e") dept_id = ColumnNode("id", _parent_alias="d") count_star = FunctionNode("COUNT", {ColumnNode("*")}, 'emp_count') - count_alias = ColumnNode("emp_count") # This would be the alias for COUNT(*) + count_alias = ColumnNode("emp_count") + dept_alias = ColumnNode("dept_name") + # SELECT clause select_clause = SelectNode([emp_name, dept_name, count_star]) # FROM clause (with implicit JOIN logic) @@ -48,8 +50,8 @@ def test_basic_format(): having_clause = HavingNode([having_condition]) # ORDER BY clause order_by_clause = OrderByNode([ - OrderByItemNode(ColumnNode("dept_name"), _sort=SortOrder.DESC), - OrderByItemNode(ColumnNode("emp_count"), _sort=SortOrder.DESC) + OrderByItemNode(dept_alias, SortOrder.DESC), + OrderByItemNode(count_alias, SortOrder.DESC) ]) # LIMIT and OFFSET limit_clause = LimitNode(10) From 50b3a406d1805d7da9b178302e69ee6af8c314d6 Mon Sep 17 00:00:00 2001 From: Colin Harrison Date: Mon, 1 Dec 2025 12:04:24 -0800 Subject: [PATCH 3/4] create new files for formatter --- core/query_formatter.py | 215 ++++++++++++++++++++++++++++++++++ core/query_parser.py | 215 +--------------------------------- tests/test_query_formatter.py | 88 ++++++++++++++ tests/test_query_parser.py | 79 +------------ 4 files changed, 310 insertions(+), 287 deletions(-) create mode 100644 core/query_formatter.py create mode 100644 tests/test_query_formatter.py diff --git a/core/query_formatter.py b/core/query_formatter.py new file mode 100644 index 0000000..88a250e --- /dev/null +++ b/core/query_formatter.py @@ -0,0 +1,215 @@ +import mo_sql_parsing as mosql +from core.ast.node import QueryNode +from core.ast.node import ( + QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, + LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, + OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode +) +from core.ast.enums import NodeType, JoinType, SortOrder +from core.ast.node import Node + +class QueryFormatter: + def format(self, query: QueryNode) -> str: + # [1] AST (QueryNode) -> JSON + json_query = ast_to_json(query) + + # [2] Any (JSON) -> str + sql = mosql.format(json_query) + + return sql + +def ast_to_json(node: QueryNode) -> dict: + """Convert QueryNode AST to JSON dictionary for mosql""" + result = {} + + # process each clause in the query + for child in node.children: + if child.type == NodeType.SELECT: + result['select'] = format_select(child) + elif child.type == NodeType.FROM: + result['from'] = format_from(child) + elif child.type == NodeType.WHERE: + result['where'] = format_where(child) + elif child.type == NodeType.GROUP_BY: + result['groupby'] = format_group_by(child) + elif child.type == NodeType.HAVING: + result['having'] = format_having(child) + elif child.type == NodeType.ORDER_BY: + result['orderby'] = format_order_by(child) + elif child.type == NodeType.LIMIT: + result['limit'] = child.limit + elif child.type == NodeType.OFFSET: + result['offset'] = child.offset + + return result + + +def format_select(select_node: SelectNode) -> list: + """Format SELECT clause""" + items = [] + + for child in select_node.children: + if child.type == NodeType.COLUMN: + if child.alias: + items.append({'name': child.alias, 'value': format_expression(child)}) + else: + items.append({'value': format_expression(child)}) + elif child.type == NodeType.FUNCTION: + func_expr = format_expression(child) + if hasattr(child, 'alias') and child.alias: + items.append({'name': child.alias, 'value': func_expr}) + else: + items.append({'value': func_expr}) + else: + items.append({'value': format_expression(child)}) + + return items + + +def format_from(from_node: FromNode) -> list: + """Format FROM clause""" + sources = [] + tables = list(from_node.children) + + if tables: + main_table = tables[0] + sources.append(format_table(main_table)) + + # additional tables become JOINs + # TODO: add other join type support beyond implicit + for table in tables[1:]: + join_item = { + 'join': format_table(table), + 'on': infer_join_condition(tables[0], table) + } + sources.append(join_item) + + return sources + + +def format_table(table_node: TableNode) -> dict: + """Format a table reference""" + result = {'value': table_node.name} + if table_node.alias: + result['name'] = table_node.alias + return result + + +def infer_join_condition(table1: TableNode, table2: TableNode) -> dict: + """Infer JOIN condition between tables""" + # assume foreign key pattern like table1.table2_id = table2.id + alias1 = table1.alias or table1.name + alias2 = table2.alias or table2.name + + return {'eq': [f'{alias1}.{table2.name[:-1]}_id', f'{alias2}.id']} + + +def format_where(where_node: WhereNode) -> dict: + """Format WHERE clause""" + predicates = list(where_node.children) + if len(predicates) == 1: + return format_expression(predicates[0]) + else: + return {'and': [format_expression(p) for p in predicates]} + + +def format_group_by(group_by_node: GroupByNode) -> list: + """Format GROUP BY clause""" + return [{'value': format_expression(child)} + for child in group_by_node.children] + + +def format_having(having_node: HavingNode) -> dict: + """Format HAVING clause""" + predicates = list(having_node.children) + if len(predicates) == 1: + return format_expression(predicates[0]) + else: + return {'and': [format_expression(p) for p in predicates]} + + +def format_order_by(order_by_node: OrderByNode) -> list: + """Format ORDER BY clause items.""" + items = [] + + # get all items and their sort orders + sort_orders = [] + for child in order_by_node.children: + if child.type == NodeType.ORDER_BY_ITEM: + column = list(child.children)[0] + item = {'value': format_expression(column)} + sort_order = child.sort + sort_orders.append(sort_order) + else: + item = {'value': format_expression(child)} + sort_order = SortOrder.ASC + sort_orders.append(sort_order) + + items.append((item, sort_order)) + + # check if all sort orders are the same + all_same = len(set(sort_orders)) == 1 + common_sort = sort_orders[0] if all_same else None + + # reformat into single sort operator if all items have same sort operator + # ex. ORDER BY dept_name DESC, emp_count DESC -> ORDER BY dept_name, emp_count DESC + result = [] + for i, (item, sort_order) in enumerate(items): + if all_same and i == len(items) - 1: + if common_sort != SortOrder.ASC: + item['sort'] = common_sort.value.lower() + elif not all_same: + if sort_order != SortOrder.ASC: + item['sort'] = sort_order.value.lower() + + result.append(item) + + return result + + +def format_expression(node: Node): + """Format an expression node""" + if node.type == NodeType.COLUMN: + if node.parent_alias: + return f"{node.parent_alias}.{node.name}" + return node.name + + elif node.type == NodeType.LITERAL: + return node.value + + elif node.type == NodeType.FUNCTION: + # format: {'function_name': args} + func_name = node.name.lower() + args = [format_expression(arg) for arg in node.children] + return {func_name: args[0] if len(args) == 1 else args} + + elif node.type == NodeType.OPERATOR: + # format: {'operator': [left, right]} + op_map = { + '>': 'gt', + '<': 'lt', + '>=': 'gte', + '<=': 'lte', + '=': 'eq', + '!=': 'ne', + 'AND': 'and', + 'OR': 'or', + } + + op_name = op_map.get(node.name.upper(), node.name.lower()) + children = list(node.children) + + left = format_expression(children[0]) + + if len(children) == 2: + right = format_expression(children[1]) + return {op_name: [left, right]} + else: + # unary operator + return {op_name: left} + + elif node.type == NodeType.TABLE: + return format_table(node) + + else: + raise ValueError(f"Unsupported node type in expression: {node.type}") \ No newline at end of file diff --git a/core/query_parser.py b/core/query_parser.py index 9f92ae1..1ac3796 100644 --- a/core/query_parser.py +++ b/core/query_parser.py @@ -1,12 +1,4 @@ -import mo_sql_parsing as mosql from core.ast.node import QueryNode -from core.ast.node import ( - QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, - LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, - OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode -) -from core.ast.enums import NodeType, JoinType, SortOrder -from core.ast.node import Node class QueryParser: @@ -21,206 +13,11 @@ def parse(self, query: str) -> QueryNode: # Any (JSON) -> AST (QueryNode) def format(self, query: QueryNode) -> str: - # [1] AST (QueryNode) -> JSON - json_query = ast_to_json(query) - - # [2] Any (JSON) -> str - sql = mosql.format(json_query) - - return sql - -def ast_to_json(node: QueryNode) -> dict: - """Convert QueryNode AST to JSON dictionary for mosql""" - result = {} - - # process each clause in the query - for child in node.children: - if child.type == NodeType.SELECT: - result['select'] = format_select(child) - elif child.type == NodeType.FROM: - result['from'] = format_from(child) - elif child.type == NodeType.WHERE: - result['where'] = format_where(child) - elif child.type == NodeType.GROUP_BY: - result['groupby'] = format_group_by(child) - elif child.type == NodeType.HAVING: - result['having'] = format_having(child) - elif child.type == NodeType.ORDER_BY: - result['orderby'] = format_order_by(child) - elif child.type == NodeType.LIMIT: - result['limit'] = child.limit - elif child.type == NodeType.OFFSET: - result['offset'] = child.offset - - return result - - -def format_select(select_node: SelectNode) -> list: - """Format SELECT clause""" - items = [] - - for child in select_node.children: - if child.type == NodeType.COLUMN: - if child.alias: - items.append({'name': child.alias, 'value': format_expression(child)}) - else: - items.append({'value': format_expression(child)}) - elif child.type == NodeType.FUNCTION: - func_expr = format_expression(child) - if hasattr(child, 'alias') and child.alias: - items.append({'name': child.alias, 'value': func_expr}) - else: - items.append({'value': func_expr}) - else: - items.append({'value': format_expression(child)}) - - return items - - -def format_from(from_node: FromNode) -> list: - """Format FROM clause""" - sources = [] - tables = list(from_node.children) - - if tables: - main_table = tables[0] - sources.append(format_table(main_table)) - - # additional tables become JOINs - # TODO: add other join type support beyond implicit - for table in tables[1:]: - join_item = { - 'join': format_table(table), - 'on': infer_join_condition(tables[0], table) - } - sources.append(join_item) - - return sources - - -def format_table(table_node: TableNode) -> dict: - """Format a table reference""" - result = {'value': table_node.name} - if table_node.alias: - result['name'] = table_node.alias - return result - - -def infer_join_condition(table1: TableNode, table2: TableNode) -> dict: - """Infer JOIN condition between tables""" - # assume foreign key pattern like table1.table2_id = table2.id - alias1 = table1.alias or table1.name - alias2 = table2.alias or table2.name - - return {'eq': [f'{alias1}.{table2.name[:-1]}_id', f'{alias2}.id']} - - -def format_where(where_node: WhereNode) -> dict: - """Format WHERE clause""" - predicates = list(where_node.children) - if len(predicates) == 1: - return format_expression(predicates[0]) - else: - return {'and': [format_expression(p) for p in predicates]} - - -def format_group_by(group_by_node: GroupByNode) -> list: - """Format GROUP BY clause""" - return [{'value': format_expression(child)} - for child in group_by_node.children] - - -def format_having(having_node: HavingNode) -> dict: - """Format HAVING clause""" - predicates = list(having_node.children) - if len(predicates) == 1: - return format_expression(predicates[0]) - else: - return {'and': [format_expression(p) for p in predicates]} - - -def format_order_by(order_by_node: OrderByNode) -> list: - """Format ORDER BY clause items.""" - items = [] - - # get all items and their sort orders - sort_orders = [] - for child in order_by_node.children: - if child.type == NodeType.ORDER_BY_ITEM: - column = list(child.children)[0] - item = {'value': format_expression(column)} - sort_order = child.sort - sort_orders.append(sort_order) - else: - item = {'value': format_expression(child)} - sort_order = SortOrder.ASC - sort_orders.append(sort_order) - - items.append((item, sort_order)) - - # check if all sort orders are the same - all_same = len(set(sort_orders)) == 1 - common_sort = sort_orders[0] if all_same else None - - # reformat into single sort operator if all items have same sort operator - # ex. ORDER BY dept_name DESC, emp_count DESC -> ORDER BY dept_name, emp_count DESC - result = [] - for i, (item, sort_order) in enumerate(items): - if all_same and i == len(items) - 1: - if common_sort != SortOrder.ASC: - item['sort'] = common_sort.value.lower() - elif not all_same: - if sort_order != SortOrder.ASC: - item['sort'] = sort_order.value.lower() - - result.append(item) - - return result + # Implement formatting logic to convert AST back to SQL string + pass + # [1] Our new code + # AST (QueryNode) -> JSON -def format_expression(node: Node): - """Format an expression node""" - if node.type == NodeType.COLUMN: - if node.parent_alias: - return f"{node.parent_alias}.{node.name}" - return node.name - - elif node.type == NodeType.LITERAL: - return node.value - - elif node.type == NodeType.FUNCTION: - # format: {'function_name': args} - func_name = node.name.lower() - args = [format_expression(arg) for arg in node.children] - return {func_name: args[0] if len(args) == 1 else args} - - elif node.type == NodeType.OPERATOR: - # format: {'operator': [left, right]} - op_map = { - '>': 'gt', - '<': 'lt', - '>=': 'gte', - '<=': 'lte', - '=': 'eq', - '!=': 'ne', - 'AND': 'and', - 'OR': 'or', - } - - op_name = op_map.get(node.name.upper(), node.name.lower()) - children = list(node.children) - - left = format_expression(children[0]) - - if len(children) == 2: - right = format_expression(children[1]) - return {op_name: [left, right]} - else: - # unary operator - return {op_name: left} - - elif node.type == NodeType.TABLE: - return format_table(node) - - else: - raise ValueError(f"Unsupported node type in expression: {node.type}") \ No newline at end of file + # [2] Call mo_sql_format + # Any (JSON) -> str \ No newline at end of file diff --git a/tests/test_query_formatter.py b/tests/test_query_formatter.py new file mode 100644 index 0000000..732e121 --- /dev/null +++ b/tests/test_query_formatter.py @@ -0,0 +1,88 @@ +import mo_sql_parsing as mosql +from core.query_formatter import QueryFormatter +from core.ast.node import ( + OrderByItemNode, QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, + LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, + OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode +) +from core.ast.enums import NodeType, JoinType, SortOrder +from data.queries import get_query +from re import sub + +formatter = QueryFormatter() + +def normalize_sql(s): + """Remove extra whitespace and normalize SQL string to be used in comparisons""" + s = s.strip() + s = sub(r'\s+', ' ', s) + + return s + +def test_basic_format(): + # Construct input AST + # Tables + emp_table = TableNode("employees", "e") + dept_table = TableNode("departments", "d") + # Columns + emp_name = ColumnNode("name", _parent_alias="e") + dept_name = ColumnNode("name", "dept_name", "d") + emp_salary = ColumnNode("salary", _parent_alias="e") + emp_age = ColumnNode("age", _parent_alias="e") + emp_dept_id = ColumnNode("department_id", _parent_alias="e") + dept_id = ColumnNode("id", _parent_alias="d") + count_star = FunctionNode("COUNT", {ColumnNode("*")}, 'emp_count') + count_alias = ColumnNode("emp_count") + dept_alias = ColumnNode("dept_name") + + # SELECT clause + select_clause = SelectNode([emp_name, dept_name, count_star]) + # FROM clause (with implicit JOIN logic) + from_clause = FromNode([emp_table, dept_table]) + # WHERE clause + salary_condition = OperatorNode(emp_salary, ">", LiteralNode(40000)) + age_condition = OperatorNode(emp_age, "<", LiteralNode(60)) + where_condition = OperatorNode(salary_condition, "AND", age_condition) + where_clause = WhereNode([where_condition]) + # GROUP BY clause + group_by_clause = GroupByNode([dept_id, dept_name]) + # HAVING clause + having_condition = OperatorNode(count_star, ">", LiteralNode(2)) + having_clause = HavingNode([having_condition]) + # ORDER BY clause + order_by_clause = OrderByNode([ + OrderByItemNode(dept_alias, SortOrder.DESC), + OrderByItemNode(count_alias, SortOrder.DESC) + ]) + # LIMIT and OFFSET + limit_clause = LimitNode(10) + offset_clause = OffsetNode(5) + # Complete query + ast = QueryNode( + _select=select_clause, + _from=from_clause, + _where=where_clause, + _group_by=group_by_clause, + _having=having_clause, + _order_by=order_by_clause, + _limit=limit_clause, + _offset=offset_clause + ) + + # Construct expected query text + expected_sql = """ + SELECT e.name, d.name AS dept_name, COUNT(*) AS emp_count + FROM employees AS e JOIN departments AS d ON e.department_id = d.id + WHERE e.salary > 40000 AND e.age < 60 + GROUP BY d.id, d.name + HAVING COUNT(*) > 2 + ORDER BY dept_name, emp_count DESC + LIMIT 10 OFFSET 5 + """ + expected_sql = expected_sql.strip() + print(mosql.parse(expected_sql)) + print(ast) + + sql = formatter.format(ast) + sql = sql.strip() + + assert normalize_sql(sql) == normalize_sql(expected_sql) \ No newline at end of file diff --git a/tests/test_query_parser.py b/tests/test_query_parser.py index cac4485..c3e7b61 100644 --- a/tests/test_query_parser.py +++ b/tests/test_query_parser.py @@ -1,92 +1,15 @@ import mo_sql_parsing as mosql from core.query_parser import QueryParser from core.ast.node import ( - OrderByItemNode, QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, + QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode ) from core.ast.enums import NodeType, JoinType, SortOrder from data.queries import get_query -from re import sub parser = QueryParser() -def normalize_sql(s): - """Remove extra whitespace and normalize SQL string to be used in comparisons""" - s = s.strip() - s = sub(r'\s+', ' ', s) - - return s - -def test_basic_format(): - # Construct input AST - # Tables - emp_table = TableNode("employees", "e") - dept_table = TableNode("departments", "d") - # Columns - emp_name = ColumnNode("name", _parent_alias="e") - dept_name = ColumnNode("name", "dept_name", "d") - emp_salary = ColumnNode("salary", _parent_alias="e") - emp_age = ColumnNode("age", _parent_alias="e") - emp_dept_id = ColumnNode("department_id", _parent_alias="e") - dept_id = ColumnNode("id", _parent_alias="d") - count_star = FunctionNode("COUNT", {ColumnNode("*")}, 'emp_count') - count_alias = ColumnNode("emp_count") - dept_alias = ColumnNode("dept_name") - - # SELECT clause - select_clause = SelectNode([emp_name, dept_name, count_star]) - # FROM clause (with implicit JOIN logic) - from_clause = FromNode([emp_table, dept_table]) - # WHERE clause - salary_condition = OperatorNode(emp_salary, ">", LiteralNode(40000)) - age_condition = OperatorNode(emp_age, "<", LiteralNode(60)) - where_condition = OperatorNode(salary_condition, "AND", age_condition) - where_clause = WhereNode([where_condition]) - # GROUP BY clause - group_by_clause = GroupByNode([dept_id, dept_name]) - # HAVING clause - having_condition = OperatorNode(count_star, ">", LiteralNode(2)) - having_clause = HavingNode([having_condition]) - # ORDER BY clause - order_by_clause = OrderByNode([ - OrderByItemNode(dept_alias, SortOrder.DESC), - OrderByItemNode(count_alias, SortOrder.DESC) - ]) - # LIMIT and OFFSET - limit_clause = LimitNode(10) - offset_clause = OffsetNode(5) - # Complete query - ast = QueryNode( - _select=select_clause, - _from=from_clause, - _where=where_clause, - _group_by=group_by_clause, - _having=having_clause, - _order_by=order_by_clause, - _limit=limit_clause, - _offset=offset_clause - ) - - # Construct expected query text - expected_sql = """ - SELECT e.name, d.name AS dept_name, COUNT(*) AS emp_count - FROM employees AS e JOIN departments AS d ON e.department_id = d.id - WHERE e.salary > 40000 AND e.age < 60 - GROUP BY d.id, d.name - HAVING COUNT(*) > 2 - ORDER BY dept_name, emp_count DESC - LIMIT 10 OFFSET 5 - """ - expected_sql = expected_sql.strip() - print(mosql.parse(expected_sql)) - print(ast) - - sql = parser.format(ast) - sql = sql.strip() - - assert normalize_sql(sql) == normalize_sql(expected_sql) - def test_parse_1(): query = get_query(1) sql = query['pattern'] From 15482c7c87f7cfcc6e5a135ea67ebd587553e201 Mon Sep 17 00:00:00 2001 From: Colin Harrison Date: Mon, 1 Dec 2025 12:17:39 -0800 Subject: [PATCH 4/4] add join node support --- core/query_formatter.py | 102 +++++++++++++++++++++++++--------- tests/test_query_formatter.py | 25 +++++---- 2 files changed, 89 insertions(+), 38 deletions(-) diff --git a/core/query_formatter.py b/core/query_formatter.py index 88a250e..6b242e5 100644 --- a/core/query_formatter.py +++ b/core/query_formatter.py @@ -3,7 +3,8 @@ from core.ast.node import ( QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, - OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode + OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode, + JoinNode ) from core.ast.enums import NodeType, JoinType, SortOrder from core.ast.node import Node @@ -67,26 +68,73 @@ def format_select(select_node: SelectNode) -> list: def format_from(from_node: FromNode) -> list: - """Format FROM clause""" + """Format FROM clause with explicit JOIN support""" sources = [] - tables = list(from_node.children) + children = list(from_node.children) - if tables: - main_table = tables[0] - sources.append(format_table(main_table)) - - # additional tables become JOINs - # TODO: add other join type support beyond implicit - for table in tables[1:]: - join_item = { - 'join': format_table(table), - 'on': infer_join_condition(tables[0], table) - } - sources.append(join_item) + if not children: + return sources + + # Process JoinNode structure + for child in children: + if child.type == NodeType.JOIN: + join_sources = format_join(child) + # format_join returns a list, extend sources with it + if isinstance(join_sources, list): + sources.extend(join_sources) + else: + sources.append(join_sources) + elif child.type == NodeType.TABLE: + sources.append(format_table(child)) return sources +def format_join(join_node: JoinNode) -> list: + """Format a JOIN node""" + children = list(join_node.children) + + if len(children) < 2: + raise ValueError("JoinNode must have at least 2 children (left and right tables)") + + left_node = children[0] + right_node = children[1] + join_condition = children[2] if len(children) > 2 else None + + result = [] + + # Format left side (could be a table or nested join) + if left_node.type == NodeType.JOIN: + # Nested join - recursively format + result.extend(format_join(left_node)) + elif left_node.type == NodeType.TABLE: + # Simple table - this becomes the FROM table + result.append(format_table(left_node)) + + # Format the join itself + join_dict = {} + + # Map join types to mosql format + join_type_map = { + JoinType.INNER: 'join', + JoinType.LEFT: 'left join', + JoinType.RIGHT: 'right join', + JoinType.FULL: 'full join', + JoinType.CROSS: 'cross join', + } + + join_key = join_type_map.get(join_node.join_type, 'join') + join_dict[join_key] = format_table(right_node) + + # Add join condition if it exists + if join_condition: + join_dict['on'] = format_expression(join_condition) + + result.append(join_dict) + + return result + + def format_table(table_node: TableNode) -> dict: """Format a table reference""" result = {'value': table_node.name} @@ -95,15 +143,6 @@ def format_table(table_node: TableNode) -> dict: return result -def infer_join_condition(table1: TableNode, table2: TableNode) -> dict: - """Infer JOIN condition between tables""" - # assume foreign key pattern like table1.table2_id = table2.id - alias1 = table1.alias or table1.name - alias2 = table2.alias or table2.name - - return {'eq': [f'{alias1}.{table2.name[:-1]}_id', f'{alias2}.id']} - - def format_where(where_node: WhereNode) -> dict: """Format WHERE clause""" predicates = list(where_node.children) @@ -137,11 +176,22 @@ def format_order_by(order_by_node: OrderByNode) -> list: for child in order_by_node.children: if child.type == NodeType.ORDER_BY_ITEM: column = list(child.children)[0] - item = {'value': format_expression(column)} + + # Check if the column has an alias + if hasattr(column, 'alias') and column.alias: + item = {'value': column.alias} + else: + item = {'value': format_expression(column)} + sort_order = child.sort sort_orders.append(sort_order) else: - item = {'value': format_expression(child)} + # Direct column reference (no OrderByItemNode wrapper) + if hasattr(child, 'alias') and child.alias: + item = {'value': child.alias} + else: + item = {'value': format_expression(child)} + sort_order = SortOrder.ASC sort_orders.append(sort_order) diff --git a/tests/test_query_formatter.py b/tests/test_query_formatter.py index 732e121..a20b270 100644 --- a/tests/test_query_formatter.py +++ b/tests/test_query_formatter.py @@ -3,7 +3,7 @@ from core.ast.node import ( OrderByItemNode, QueryNode, SelectNode, FromNode, WhereNode, TableNode, ColumnNode, LiteralNode, OperatorNode, FunctionNode, GroupByNode, HavingNode, - OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode + OrderByNode, LimitNode, OffsetNode, SubqueryNode, VarNode, VarSetNode, JoinNode ) from core.ast.enums import NodeType, JoinType, SortOrder from data.queries import get_query @@ -19,25 +19,27 @@ def normalize_sql(s): return s def test_basic_format(): - # Construct input AST + # Construct expected AST # Tables emp_table = TableNode("employees", "e") dept_table = TableNode("departments", "d") # Columns emp_name = ColumnNode("name", _parent_alias="e") - dept_name = ColumnNode("name", "dept_name", "d") emp_salary = ColumnNode("salary", _parent_alias="e") emp_age = ColumnNode("age", _parent_alias="e") emp_dept_id = ColumnNode("department_id", _parent_alias="e") + + dept_name = ColumnNode("name", _alias="dept_name", _parent_alias="d") dept_id = ColumnNode("id", _parent_alias="d") - count_star = FunctionNode("COUNT", {ColumnNode("*")}, 'emp_count') - count_alias = ColumnNode("emp_count") - dept_alias = ColumnNode("dept_name") + + count_star = FunctionNode("COUNT", _alias="emp_count", _args=[ColumnNode("*")]) # SELECT clause select_clause = SelectNode([emp_name, dept_name, count_star]) - # FROM clause (with implicit JOIN logic) - from_clause = FromNode([emp_table, dept_table]) + # FROM clause with JOIN + join_condition = OperatorNode(emp_dept_id, "=", dept_id) + join_node = JoinNode(emp_table, dept_table, JoinType.INNER, join_condition) + from_clause = FromNode([join_node]) # WHERE clause salary_condition = OperatorNode(emp_salary, ">", LiteralNode(40000)) age_condition = OperatorNode(emp_age, "<", LiteralNode(60)) @@ -49,10 +51,9 @@ def test_basic_format(): having_condition = OperatorNode(count_star, ">", LiteralNode(2)) having_clause = HavingNode([having_condition]) # ORDER BY clause - order_by_clause = OrderByNode([ - OrderByItemNode(dept_alias, SortOrder.DESC), - OrderByItemNode(count_alias, SortOrder.DESC) - ]) + order_by_item1 = OrderByItemNode(dept_name, SortOrder.ASC) + order_by_item2 = OrderByItemNode(count_star, SortOrder.DESC) + order_by_clause = OrderByNode([order_by_item1, order_by_item2]) # LIMIT and OFFSET limit_clause = LimitNode(10) offset_clause = OffsetNode(5)