Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 77 additions & 29 deletions beanquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from functools import singledispatchmethod
from typing import Optional, Sequence, Mapping

from . import tables
from . import types
from . import parser
from .errors import ProgrammingError
Expand All @@ -15,12 +16,14 @@
EvalCoalesce,
EvalColumn,
EvalConstant,
EvalEnv,
EvalGetItem,
EvalGetter,
EvalOr,
EvalPivot,
EvalPrint,
EvalQuery,
EvalSubquery,
EvalTarget,
FUNCTIONS,
OPERATORS,
Expand All @@ -39,10 +42,35 @@ def __init__(self, message, ast=None):
self.parseinfo = ast.parseinfo if ast is not None else None


class Environment:
def __init__(self):
object.__setattr__(self, 'stack', [{}])

def push(self, table):
self.stack.append({'table': table})

def pop(self):
self.stack.pop()

def get(self, name):
for env in reversed(self.stack):
table = env.get('table')
if table is not None and table.name == name:
return EvalEnv(table)
return None

def __getattr__(self, name):
return self.stack[-1][name]

def __setattr__(self, name, value):
self.stack[-1][name] = value


class Compiler:
def __init__(self, context):
self.context = context
self.table = context.tables.get('postings')
self.env = Environment()
self.default = context.tables.get('postings')

def compile(self, query, parameters=None):
"""Compile an AST into an executable statement."""
Expand All @@ -68,16 +96,29 @@ def compile(self, query, parameters=None):
else:
raise ProgrammingError('positional and named parameters cannot be mixed')

return self._compile(query)
return self._compile_statement(query)

@singledispatchmethod
def _compile(self, node: Optional[ast.Node]):
if node is None:
return None
def _compile_statement(self, node: ast.Node):
raise NotImplementedError

@_compile.register
def _select(self, node: ast.Select):
@_compile_statement.register
def _compile_balances(self, node: ast.Balances):
return self._compile_select(transform_balances(node))

@_compile_statement.register
def _compile_journal(self, node: ast.Journal):
return self._compile_select(transform_journal(node))

@_compile_statement.register
def _compile_print(self, node: ast.Print):
self.env.table = self.context.tables.get('entries')
expr = self._compile_from(node.from_clause)
return EvalPrint(self.env.table, expr)

@_compile_statement.register
def _compile_select(self, node: ast.Select):
self.env.push(self.default)

# Compile the FROM clause.
c_from_expr = self._compile_from(node.from_clause)
Expand Down Expand Up @@ -121,7 +162,7 @@ def _select(self, node: ast.Select):
'all non-aggregates must be covered by GROUP-BY clause in aggregate query: '
'the following targets are missing: {}'.format(','.join(missing_names)))

query = EvalQuery(self.table,
query = EvalQuery(self.env.table,
c_targets,
c_where,
group_indexes,
Expand All @@ -134,6 +175,7 @@ def _select(self, node: ast.Select):
if pivots:
return EvalPivot(query, pivots)

self.env.pop()
return query

def _compile_from(self, node):
Expand All @@ -142,13 +184,14 @@ def _compile_from(self, node):

# Subquery.
if isinstance(node, ast.Select):
self.table = SubqueryTable(self._compile(node))
query = self._compile_statement(node)
self.env.table = SubqueryTable(query)
return None

# Table reference.
if isinstance(node, ast.Table):
self.table = self.context.tables.get(node.name)
if self.table is None:
self.env.table = self.context.tables.get(node.name)
if self.env.table is None:
raise CompilationError(f'table "{node.name}" does not exist', node)
return None

Expand All @@ -164,7 +207,7 @@ def _compile_from(self, node):
raise CompilationError('CLOSE date must follow OPEN date')

# Apply OPEN, CLOSE, and CLEAR clauses.
self.table = self.table.update(open=node.open, close=node.close, clear=node.clear)
self.env.table = self.env.table.update(open=node.open, close=node.close, clear=node.clear)

return c_expression

Expand All @@ -182,7 +225,7 @@ def _compile_targets(self, targets):
if isinstance(targets, ast.Asterisk):
# Insert the full list of available columns.
targets = [ast.Target(ast.Column(name), None)
for name in self.table.wildcard_columns]
for name in self.env.table.wildcard_columns]

# Compile targets.
c_targets = []
Expand Down Expand Up @@ -437,11 +480,20 @@ def _compile_group_by(self, group_by, c_targets):

return new_targets[len(c_targets):], group_indexes, having_index

@singledispatchmethod
def _compile(self, node: Optional[ast.Node]):
if node is None:
return None
raise NotImplementedError(node)

@_compile.register
def _column(self, node: ast.Column):
column = self.table.columns.get(node.name)
column = self.env.table.columns.get(node.name)
if column is not None:
return column
var = self.env.get(node.name)
if var is not None:
return var
raise CompilationError(f'column "{node.name}" does not exist', node)

@_compile.register
Expand Down Expand Up @@ -471,7 +523,7 @@ def _function(self, node: ast.Function):
function = function(operands)
# Constants folding.
if all(isinstance(operand, EvalConstant) for operand in operands) and function.pure:
return EvalConstant(function(None), function.dtype)
return EvalConstant(function(None, None), function.dtype)
return function

@_compile.register
Expand All @@ -490,6 +542,11 @@ def _attribute(self, node: ast.Attribute):
if getter is None:
raise CompilationError(f'structured type has no attribute "{node.name}"', node)
return EvalGetter(operand, getter, getter.dtype)
if issubclass(dtype, tables.Table):
column = operand.columns.get(node.name)
if column is None:
raise CompilationError(f'column "{node.name}" does not exist', node)
return EvalGetter(operand, column, column.dtype)
raise CompilationError('column type is not structured', node)

@_compile.register
Expand All @@ -502,7 +559,7 @@ def _unaryop(self, node: ast.UnaryOp):
function = function(operand)
# Constants folding.
if isinstance(operand, EvalConstant):
return EvalConstant(function(None), function.dtype)
return EvalConstant(function(None, None), function.dtype)
return function

@_compile.register
Expand Down Expand Up @@ -532,7 +589,7 @@ def _binaryop(self, node: ast.BinaryOp):
function = op(left, right)
# Constants folding.
if isinstance(left, EvalConstant) and isinstance(right, EvalConstant):
return EvalConstant(function(None), function.dtype)
return EvalConstant(function(None, None), function.dtype)
return function

# Implement type inference when one of the operands is not strongly typed.
Expand Down Expand Up @@ -581,18 +638,9 @@ def _asterisk(self, node: ast.Asterisk):
return EvalConstant(None, dtype=types.Asterisk)

@_compile.register
def _balances(self, node: ast.Balances):
return self._compile(transform_balances(node))

@_compile.register
def _journal(self, node: ast.Journal):
return self._compile(transform_journal(node))

@_compile.register
def _print(self, node: ast.Print):
self.table = self.context.tables.get('entries')
expr = self._compile_from(node.from_clause)
return EvalPrint(self.table, expr)
def _subquery(self, node: ast.Select):
query = self._compile_statement(node)
return EvalSubquery(query)


def transform_journal(journal):
Expand Down
Loading