From af1fa0e543c0d7d9d48e7d08a5f5070e2ec536a6 Mon Sep 17 00:00:00 2001 From: Daniele Nicolodi Date: Tue, 25 Jul 2023 21:50:06 +0200 Subject: [PATCH 1/4] env: Cleanup --- beanquery/query_env.py | 48 ++++++++++++++---------------------------- 1 file changed, 16 insertions(+), 32 deletions(-) diff --git a/beanquery/query_env.py b/beanquery/query_env.py index 4c95cd87..be05cc37 100644 --- a/beanquery/query_env.py +++ b/beanquery/query_env.py @@ -714,35 +714,6 @@ def update(self, store, context): class Row: """A dumb container for information used by a row expression.""" - rowid = None - - # The current posting being evaluated. - posting = None - - # The current transaction of the posting being evaluated. - entry = None - - # The current running balance *after* applying the posting. - balance = None - - # The parser's options_map. - options_map = None - - # An AccountTypes tuple of the account types. - account_types = None - - # A dict of account name strings to (open, close) entries for those accounts. - open_close_map = None - - # A dict of currency name strings to the corresponding Commodity entry. - commodity_map = None - - # A price dict as computed by build_price_map() - price_map = None - - # A storage area for computing aggregate expression. - store = None - # The context hash is used in caching column accessor functions. # Instead than hashing the row context content, use the rowid as # hash. @@ -751,13 +722,26 @@ def __hash__(self): def __init__(self, entries, options): self.rowid = 0 + + # The current transaction of the posting being evaluated. + self.entry = None + + # The current posting being evaluated. + self.posting = None + + # The current running balance. self.balance = inventory.Inventory() - self.balance_update_rowid = -1 - # Global properties used by some of the accessors. - self.options = options + + # An AccountTypes tuple of the account types. self.account_types = opts.get_account_types(options) + + # A dict of account name strings to (open, close) entries for those accounts. self.open_close_map = getters.get_account_open_close(entries) + + # A dict of currency name strings to the corresponding Commodity entry. self.commodity_map = getters.get_commodity_directives(entries) + + # A price dict. self.price_map = prices.build_price_map(entries) From f5ac2a644e2d08386c66d9c768fbc8f78e27903f Mon Sep 17 00:00:00 2001 From: Daniele Nicolodi Date: Mon, 2 Oct 2023 19:20:22 +0200 Subject: [PATCH 2/4] execute: Cleanup Just variable renames. --- beanquery/query_execute.py | 73 +++++++++++++++++--------------------- 1 file changed, 33 insertions(+), 40 deletions(-) diff --git a/beanquery/query_execute.py b/beanquery/query_execute.py index a7f8df02..808b2306 100644 --- a/beanquery/query_execute.py +++ b/beanquery/query_execute.py @@ -183,35 +183,29 @@ def execute_select(query): 'result_types'. """ # Figure out the result types that describe what we return. - result_types = tuple(Column(target.name, target.c_expr.dtype) - for target in query.c_targets - if target.name is not None) + columns = tuple(Column(target.name, target.c_expr.dtype) for target in query.c_targets if target.name is not None) # Pre-compute lists of the expressions to evaluate. - group_indexes = (set(query.group_indexes) - if query.group_indexes is not None - else query.group_indexes) + group_indexes = set(query.group_indexes) if query.group_indexes is not None else None # Indexes of the columns for result rows and order rows. - result_indexes = [index - for index, c_target in enumerate(query.c_targets) - if c_target.name] - order_spec = query.order_spec - - # Dispatch between the non-aggregated queries and aggregated queries. - c_where = query.c_where - rows = [] + result_indexes = [index for index, target in enumerate(query.c_targets) if target.name] # Precompute a list of expressions to be evaluated. - c_target_exprs = [c_target.c_expr for c_target in query.c_targets] + target_exprs = [target.c_expr for target in query.c_targets] + order_spec = query.order_spec + where = query.c_where + rows = [] + + # Dispatch between the non-aggregated queries and aggregated queries. if query.group_indexes is None: # This is a non-aggregated query. - # Iterate over all the postings once. - for context in query.table: - if c_where is None or c_where(context): - values = [c_expr(context) for c_expr in c_target_exprs] + # Iterate over all the table rows. + for row in query.table: + if where is None or where(row): + values = [expr(row) for expr in target_exprs] rows.append(values) else: @@ -220,45 +214,44 @@ def execute_select(query): # Precompute lists of non-aggregate and aggregate expressions to # evaluate. For aggregate targets, we hunt down the aggregate # sub-expressions to evaluate, to avoid recursion during iteration. - c_nonaggregate_exprs = [] - c_aggregate_exprs = [] - for index, c_expr in enumerate(c_target_exprs): + nonaggregate_exprs = [] + aggregate_exprs = [] + for index, expr in enumerate(target_exprs): if index in group_indexes: - c_nonaggregate_exprs.append(c_expr) + nonaggregate_exprs.append(expr) else: - _, aggregate_exprs = compiler.get_columns_and_aggregates(c_expr) - c_aggregate_exprs.extend(aggregate_exprs) + _, aggregates = compiler.get_columns_and_aggregates(expr) + aggregate_exprs.extend(aggregates) # Note: it is possible that there are no aggregates to compute here. You could # have all columns be non-aggregates and group-by the entire list of columns. # Pre-allocate handles in aggregation nodes. allocator = Allocator() - for c_expr in c_aggregate_exprs: - c_expr.allocate(allocator) + for expr in aggregate_exprs: + expr.allocate(allocator) def create(): # Create a new row in the aggregates store. store = allocator.create_store() - for c_expr in c_aggregate_exprs: - c_expr.initialize(store) + for expr in aggregate_exprs: + expr.initialize(store) return store - context = None aggregates = collections.defaultdict(create) # Iterate over all the postings to evaluate the aggregates. - for context in query.table: - if c_where is None or c_where(context): + for row in query.table: + if where is None or where(row): # Compute the non-aggregate expressions. - key = tuple(c_expr(context) for c_expr in c_nonaggregate_exprs) + key = tuple(expr(row) for expr in nonaggregate_exprs) # Get an appropriate store for the unique key of this row. store = aggregates[key] # Update the aggregate expressions. - for c_expr in c_aggregate_exprs: - c_expr.update(store, context) + for expr in aggregate_exprs: + expr.update(store, row) # Iterate over all the aggregations. for key, store in aggregates.items(): @@ -266,14 +259,14 @@ def create(): values = [] # Finalize the store. - for c_expr in c_aggregate_exprs: - c_expr.finalize(store) + for expr in aggregate_exprs: + expr.finalize(store) - for index, c_expr in enumerate(c_target_exprs): + for index, expr in enumerate(target_exprs): if index in group_indexes: value = next(key_iter) else: - value = c_expr(context) + value = expr(None) values.append(value) # Skip row if HAVING clause expression is false. @@ -304,4 +297,4 @@ def create(): if query.limit is not None: rows = itertools.islice(rows, query.limit) - return result_types, list(rows) + return columns, list(rows) From f7cba46bcf61b4dd4c46b8b34fd810c080457d64 Mon Sep 17 00:00:00 2001 From: Daniele Nicolodi Date: Mon, 2 Oct 2023 19:30:41 +0200 Subject: [PATCH 3/4] compiler: Implement support for sub-queries in select targets --- beanquery/compiler.py | 95 ++++++++++++++++++++++++--------- beanquery/query_compile.py | 10 ++++ beanquery/query_compile_test.py | 6 +-- beanquery/query_execute.py | 3 +- beanquery/query_execute_test.py | 5 ++ beanquery/sources/beancount.py | 2 +- 6 files changed, 89 insertions(+), 32 deletions(-) diff --git a/beanquery/compiler.py b/beanquery/compiler.py index d09c75dc..ae2c0a69 100644 --- a/beanquery/compiler.py +++ b/beanquery/compiler.py @@ -21,6 +21,7 @@ EvalPivot, EvalPrint, EvalQuery, + EvalSubquery, EvalTarget, FUNCTIONS, OPERATORS, @@ -39,10 +40,35 @@ def __init__(self, message, ast=None): self.parseinfo = ast.parseinfo if ast is not None else None +class Environment: + def __init__(self): + object.__setattr__(self, 'stack', [{}]) + + def push(self, table): + self.stack.append({'table': table}) + + def pop(self): + self.stack.pop() + + def get(self, name): + for env in reversed(self.stack): + table = env.get('table') + if table is not None and table.name == name: + return EvalEnv(table) + return None + + def __getattr__(self, name): + return self.stack[-1][name] + + def __setattr__(self, name, value): + self.stack[-1][name] = value + + class Compiler: def __init__(self, context): self.context = context - self.table = context.tables.get('postings') + self.env = Environment() + self.default = context.tables.get('postings') def compile(self, query, parameters=None): """Compile an AST into an executable statement.""" @@ -68,16 +94,29 @@ def compile(self, query, parameters=None): else: raise ProgrammingError('positional and named parameters cannot be mixed') - return self._compile(query) + return self._compile_statement(query) @singledispatchmethod - def _compile(self, node: Optional[ast.Node]): - if node is None: - return None + def _compile_statement(self, node: ast.Node): raise NotImplementedError - @_compile.register - def _select(self, node: ast.Select): + @_compile_statement.register + def _compile_balances(self, node: ast.Balances): + return self._compile_select(transform_balances(node)) + + @_compile_statement.register + def _compile_journal(self, node: ast.Journal): + return self._compile_select(transform_journal(node)) + + @_compile_statement.register + def _compile_print(self, node: ast.Print): + self.env.table = self.context.tables.get('entries') + expr = self._compile_from(node.from_clause) + return EvalPrint(self.env.table, expr) + + @_compile_statement.register + def _compile_select(self, node: ast.Select): + self.env.push(self.default) # Compile the FROM clause. c_from_expr = self._compile_from(node.from_clause) @@ -121,7 +160,7 @@ def _select(self, node: ast.Select): 'all non-aggregates must be covered by GROUP-BY clause in aggregate query: ' 'the following targets are missing: {}'.format(','.join(missing_names))) - query = EvalQuery(self.table, + query = EvalQuery(self.env.table, c_targets, c_where, group_indexes, @@ -134,6 +173,7 @@ def _select(self, node: ast.Select): if pivots: return EvalPivot(query, pivots) + self.env.pop() return query def _compile_from(self, node): @@ -142,13 +182,14 @@ def _compile_from(self, node): # Subquery. if isinstance(node, ast.Select): - self.table = SubqueryTable(self._compile(node)) + query = self._compile_statement(node) + self.env.table = SubqueryTable(query) return None # Table reference. if isinstance(node, ast.Table): - self.table = self.context.tables.get(node.name) - if self.table is None: + self.env.table = self.context.tables.get(node.name) + if self.env.table is None: raise CompilationError(f'table "{node.name}" does not exist', node) return None @@ -164,7 +205,7 @@ def _compile_from(self, node): raise CompilationError('CLOSE date must follow OPEN date') # Apply OPEN, CLOSE, and CLEAR clauses. - self.table = self.table.update(open=node.open, close=node.close, clear=node.clear) + self.env.table = self.env.table.update(open=node.open, close=node.close, clear=node.clear) return c_expression @@ -182,7 +223,7 @@ def _compile_targets(self, targets): if isinstance(targets, ast.Asterisk): # Insert the full list of available columns. targets = [ast.Target(ast.Column(name), None) - for name in self.table.wildcard_columns] + for name in self.env.table.wildcard_columns] # Compile targets. c_targets = [] @@ -437,9 +478,15 @@ def _compile_group_by(self, group_by, c_targets): return new_targets[len(c_targets):], group_indexes, having_index + @singledispatchmethod + def _compile(self, node: Optional[ast.Node]): + if node is None: + return None + raise NotImplementedError(node) + @_compile.register def _column(self, node: ast.Column): - column = self.table.columns.get(node.name) + column = self.env.table.columns.get(node.name) if column is not None: return column raise CompilationError(f'column "{node.name}" does not exist', node) @@ -490,6 +537,11 @@ def _attribute(self, node: ast.Attribute): if getter is None: raise CompilationError(f'structured type has no attribute "{node.name}"', node) return EvalGetter(operand, getter, getter.dtype) + if issubclass(dtype, tables.Table): + column = operand.columns.get(node.name) + if column is None: + raise CompilationError(f'column "{node.name}" does not exist', node) + return EvalGetter(operand, column, column.dtype) raise CompilationError('column type is not structured', node) @_compile.register @@ -581,18 +633,9 @@ def _asterisk(self, node: ast.Asterisk): return EvalConstant(None, dtype=types.Asterisk) @_compile.register - def _balances(self, node: ast.Balances): - return self._compile(transform_balances(node)) - - @_compile.register - def _journal(self, node: ast.Journal): - return self._compile(transform_journal(node)) - - @_compile.register - def _print(self, node: ast.Print): - self.table = self.context.tables.get('entries') - expr = self._compile_from(node.from_clause) - return EvalPrint(self.table, expr) + def _subquery(self, node: ast.Select): + query = self._compile_statement(node) + return EvalSubquery(query) def transform_journal(journal): diff --git a/beanquery/query_compile.py b/beanquery/query_compile.py index 081fff53..8fc51e8f 100644 --- a/beanquery/query_compile.py +++ b/beanquery/query_compile.py @@ -540,6 +540,16 @@ def __iter__(self): 'limit distinct')) +class EvalSubquery(EvalNode): + def __init__(self, query): + self.query = query._replace(limit=1) + super().__init__(query.c_targets[0].c_expr.dtype) + + def __call__(self, row): + columns, rows = query_execute.execute_query(self.query) + return rows[0][0] + + # A compiled query with a PIVOT BY clause. # # The PIVOT BY clause causes the structure of the returned table to be diff --git a/beanquery/query_compile_test.py b/beanquery/query_compile_test.py index b15edc8c..192db5d8 100644 --- a/beanquery/query_compile_test.py +++ b/beanquery/query_compile_test.py @@ -30,10 +30,10 @@ class TestCompileExpression(unittest.TestCase): def setUpClass(cls): context = Connection() cls.compiler = compiler.Compiler(context) - cls.compiler.table = qe.PostingsEnvironment() + cls.compiler.env.table = qe.PostingsEnvironment() def compile(self, expr): - return self.compiler.compile(expr) + return self.compiler._compile(expr) def test_expr_invalid(self): with self.assertRaises(CompilationError): @@ -760,7 +760,7 @@ def setUpClass(cls): def compile(self, query, params): c = compiler.Compiler(self.context) - c.table = self.context.tables.get('') + c.default = self.context.tables.get('') return c.compile(parser.parse(query), params) def test_named_parameters(self): diff --git a/beanquery/query_execute.py b/beanquery/query_execute.py index 808b2306..c41c9607 100644 --- a/beanquery/query_execute.py +++ b/beanquery/query_execute.py @@ -165,8 +165,7 @@ def execute_query(query): return columns, pivoted - # Not reached. - raise RuntimeError + raise NotImplementedError(query) def execute_select(query): diff --git a/beanquery/query_execute_test.py b/beanquery/query_execute_test.py index 9832f839..26dcc995 100644 --- a/beanquery/query_execute_test.py +++ b/beanquery/query_execute_test.py @@ -1390,6 +1390,11 @@ def test_subquery(self): self.execute("""SELECT a + 2 AS b FROM (SELECT 3 AS a FROM #)"""), ([('b', int)], [(5, )])) + def test_subquery_target(self): + self.assertEqual( + self.execute("""SELECT 1 + (SELECT 1 FROM #) AS b FROM #"""), + ([('b', int)], [(2, )])) + class SimpleColumn(qc.EvalColumn): def __init__(self, name, func, dtype): diff --git a/beanquery/sources/beancount.py b/beanquery/sources/beancount.py index bbc64e21..a0653168 100644 --- a/beanquery/sources/beancount.py +++ b/beanquery/sources/beancount.py @@ -160,7 +160,7 @@ def __init__(self, key, dtype): super().__init__(dtype) self.key = key - def __call__(self, row): + def __call__(self, row, env): return row[self.key] From 5973a206748ca13f82572444cc52d5ed3ea3507d Mon Sep 17 00:00:00 2001 From: Daniele Nicolodi Date: Mon, 2 Oct 2023 19:44:27 +0200 Subject: [PATCH 4/4] compiler, execute: Allow sub-queries to reference query environment This is an invasive change as the __call__() method of all the query evaluation objects need to be updated to receive a second parameter pointing to the query execution environment. --- beanquery/compiler.py | 11 ++++-- beanquery/query_compile.py | 66 +++++++++++++++++++-------------- beanquery/query_env.py | 50 +++++++++++++------------ beanquery/query_execute.py | 20 +++++----- beanquery/query_execute_test.py | 4 +- beanquery/sources/beancount.py | 2 +- beanquery/tests/tables.py | 2 +- 7 files changed, 86 insertions(+), 69 deletions(-) diff --git a/beanquery/compiler.py b/beanquery/compiler.py index ae2c0a69..7b2de859 100644 --- a/beanquery/compiler.py +++ b/beanquery/compiler.py @@ -4,6 +4,7 @@ from functools import singledispatchmethod from typing import Optional, Sequence, Mapping +from . import tables from . import types from . import parser from .errors import ProgrammingError @@ -15,6 +16,7 @@ EvalCoalesce, EvalColumn, EvalConstant, + EvalEnv, EvalGetItem, EvalGetter, EvalOr, @@ -489,6 +491,9 @@ def _column(self, node: ast.Column): column = self.env.table.columns.get(node.name) if column is not None: return column + var = self.env.get(node.name) + if var is not None: + return var raise CompilationError(f'column "{node.name}" does not exist', node) @_compile.register @@ -518,7 +523,7 @@ def _function(self, node: ast.Function): function = function(operands) # Constants folding. if all(isinstance(operand, EvalConstant) for operand in operands) and function.pure: - return EvalConstant(function(None), function.dtype) + return EvalConstant(function(None, None), function.dtype) return function @_compile.register @@ -554,7 +559,7 @@ def _unaryop(self, node: ast.UnaryOp): function = function(operand) # Constants folding. if isinstance(operand, EvalConstant): - return EvalConstant(function(None), function.dtype) + return EvalConstant(function(None, None), function.dtype) return function @_compile.register @@ -584,7 +589,7 @@ def _binaryop(self, node: ast.BinaryOp): function = op(left, right) # Constants folding. if isinstance(left, EvalConstant) and isinstance(right, EvalConstant): - return EvalConstant(function(None), function.dtype) + return EvalConstant(function(None, None), function.dtype) return function # Implement type inference when one of the operands is not strongly typed. diff --git a/beanquery/query_compile.py b/beanquery/query_compile.py index 8fc51e8f..b45a65f8 100644 --- a/beanquery/query_compile.py +++ b/beanquery/query_compile.py @@ -63,7 +63,7 @@ def childnodes(self): if isinstance(element, EvalNode): yield element - def __call__(self, context): + def __call__(self, context, env): """Evaluate this node. This is designed to recurse on its children. All subclasses must override and implement this method. @@ -85,7 +85,7 @@ def __init__(self, value, dtype=None): super().__init__(type(value) if dtype is None else dtype) self.value = value - def __call__(self, _): + def __call__(self, context, env): return self.value @@ -97,8 +97,8 @@ def __init__(self, operator, operand, dtype): self.operand = operand self.operator = operator - def __call__(self, context): - operand = self.operand(context) + def __call__(self, context, env): + operand = self.operand(context, env) return self.operator(operand) def __repr__(self): @@ -107,8 +107,8 @@ def __repr__(self): class EvalUnaryOpSafe(EvalUnaryOp): - def __call__(self, context): - operand = self.operand(context) + def __call__(self, context, env): + operand = self.operand(context, env) if operand is None: return None return self.operator(operand) @@ -123,11 +123,11 @@ def __init__(self, operator, left, right, dtype): self.left = left self.right = right - def __call__(self, context): - left = self.left(context) + def __call__(self, context, env): + left = self.left(context, env) if left is None: return None - right = self.right(context) + right = self.right(context, env) if right is None: return None return self.operator(left, right) @@ -145,14 +145,14 @@ def __init__(self, operand, lower, upper): self.lower = lower self.upper = upper - def __call__(self, context): - operand = self.operand(context) + def __call__(self, context, env): + operand = self.operand(context, env) if operand is None: return None - lower = self.lower(context) + lower = self.lower(context, env) if lower is None: return None - upper = self.upper(context) + upper = self.upper(context, env) if upper is None: return None return lower <= operand <= upper @@ -344,9 +344,9 @@ def __init__(self, args): super().__init__(bool) self.args = args - def __call__(self, context): + def __call__(self, context, env): for arg in self.args: - value = arg(context) + value = arg(context, env) if value is None: return None if not value: @@ -361,10 +361,10 @@ def __init__(self, args): super().__init__(bool) self.args = args - def __call__(self, context): + def __call__(self, context, env): r = False for arg in self.args: - value = arg(context) + value = arg(context, env) if value is None: r = None if value: @@ -379,9 +379,9 @@ def __init__(self, args): super().__init__(args[0].dtype) self.args = args - def __call__(self, context): + def __call__(self, context, env): for arg in self.args: - value = arg(context) + value = arg(context, env) if value is not None: return value return None @@ -406,8 +406,8 @@ def __init__(self, operand, key): self.operand = operand self.key = key - def __call__(self, context): - operand = self.operand(context) + def __call__(self, context, env): + operand = self.operand(context, env) if operand is None: return None return operand.get(self.key) @@ -421,11 +421,11 @@ def __init__(self, operand, getter, dtype): self.operand = operand self.getter = getter - def __call__(self, context): - operand = self.operand(context) + def __call__(self, context, env): + operand = self.operand(context, env) if operand is None: return None - return self.getter(operand) + return self.getter(operand, env) class EvalColumn(EvalNode): @@ -478,7 +478,7 @@ def finalize(self, store): """ self.value = store[self.handle] - def __call__(self, context): + def __call__(self, context, env): """Return the value on evaluation. Args: @@ -502,7 +502,8 @@ def column(i, name, dtype): class Column(EvalColumn): def __init__(self): super().__init__(dtype) - __call__ = staticmethod(operator.itemgetter(i)) + def __call__(self, row, env): + return row[i] return Column def __iter__(self): @@ -545,11 +546,20 @@ def __init__(self, query): self.query = query._replace(limit=1) super().__init__(query.c_targets[0].c_expr.dtype) - def __call__(self, row): - columns, rows = query_execute.execute_query(self.query) + def __call__(self, row, env): + columns, rows = query_execute.execute_query(self.query, row) return rows[0][0] +class EvalEnv(EvalNode): + def __init__(self, table): + super().__init__(type(table)) + self.columns = table.columns + + def __call__(self, row, env): + return env + + # A compiled query with a PIVOT BY clause. # # The PIVOT BY clause causes the structure of the returned table to be diff --git a/beanquery/query_env.py b/beanquery/query_env.py index be05cc37..68cef631 100644 --- a/beanquery/query_env.py +++ b/beanquery/query_env.py @@ -44,8 +44,8 @@ class Func(query_compile.EvalFunction): pure = not pass_context def __init__(self, operands): super().__init__(operands, outtype) - def __call__(self, context): - args = [operand(context) for operand in self.operands] + def __call__(self, context, env): + args = [operand(context, env) for operand in self.operands] for arg in args: if arg is None: return None @@ -587,7 +587,7 @@ class Count(query_compile.EvalAggregator): def __init__(self, operands): super().__init__(operands, int) - def update(self, store, context): + def update(self, store, context, env): store[self.handle] += 1 @@ -597,8 +597,8 @@ class CountArg(query_compile.EvalAggregator): def __init__(self, operands): super().__init__(operands, int) - def update(self, store, context): - value = self.operands[0](context) + def update(self, store, context, env): + value = self.operands[0](context, env) if value is not None: store[self.handle] += 1 @@ -609,8 +609,8 @@ class SumInt(query_compile.EvalAggregator): def __init__(self, operands): super().__init__(operands, operands[0].dtype) - def update(self, store, context): - value = self.operands[0](context) + def update(self, store, context, env): + value = self.operands[0](context, env) if value is not None: store[self.handle] += value @@ -618,8 +618,8 @@ def update(self, store, context): @aggregator([Decimal], name='sum') class SumDecimal(query_compile.EvalAggregator): """Calculate the sum of the numerical argument.""" - def update(self, store, context): - value = self.operands[0](context) + def update(self, store, context, env): + value = self.operands[0](context, env) if value is not None: store[self.handle] += value @@ -630,8 +630,8 @@ class SumAmount(query_compile.EvalAggregator): def __init__(self, operands): super().__init__(operands, inventory.Inventory) - def update(self, store, context): - value = self.operands[0](context) + def update(self, store, context, env): + value = self.operands[0](context, env) if value is not None: store[self.handle].add_amount(value) @@ -642,8 +642,8 @@ class SumPosition(query_compile.EvalAggregator): def __init__(self, operands): super().__init__(operands, inventory.Inventory) - def update(self, store, context): - value = self.operands[0](context) + def update(self, store, context, env): + value = self.operands[0](context, env) if value is not None: store[self.handle].add_position(value) @@ -654,8 +654,8 @@ class SumInventory(query_compile.EvalAggregator): def __init__(self, operands): super().__init__(operands, inventory.Inventory) - def update(self, store, context): - value = self.operands[0](context) + def update(self, store, context, env): + value = self.operands[0](context, env) if value is not None: store[self.handle].add_inventory(value) @@ -666,9 +666,9 @@ class First(query_compile.EvalAggregator): def initialize(self, store): store[self.handle] = None - def update(self, store, context): + def update(self, store, context, env): if store[self.handle] is None: - value = self.operands[0](context) + value = self.operands[0](context, env) store[self.handle] = value @@ -678,8 +678,8 @@ class Last(query_compile.EvalAggregator): def initialize(self, store): store[self.handle] = None - def update(self, store, context): - value = self.operands[0](context) + def update(self, store, context, env): + value = self.operands[0](context, env) store[self.handle] = value @@ -689,8 +689,8 @@ class Min(query_compile.EvalAggregator): def initialize(self, store): store[self.handle] = None - def update(self, store, context): - value = self.operands[0](context) + def update(self, store, context, env): + value = self.operands[0](context, env) if value is not None: cur = store[self.handle] if cur is None or value < cur: @@ -703,8 +703,8 @@ class Max(query_compile.EvalAggregator): def initialize(self, store): store[self.handle] = None - def update(self, store, context): - value = self.operands[0](context) + def update(self, store, context, env): + value = self.operands[0](context, env) if value is not None: cur = store[self.handle] if cur is None or value > cur: @@ -760,7 +760,9 @@ def decorator(func): class Col(query_compile.EvalColumn): def __init__(self): super().__init__(dtype) - __call__ = staticmethod(func) + self.func = func + def __call__(self, context, env): + return self.func(context) Col.__name__ = name or func.__name__ Col.__doc__ = help or func.__doc__ cls.columns[Col.__name__] = Col() diff --git a/beanquery/query_execute.py b/beanquery/query_execute.py index c41c9607..b4b392f6 100644 --- a/beanquery/query_execute.py +++ b/beanquery/query_execute.py @@ -34,7 +34,7 @@ def execute_print(c_print, file): entries = [] expr = c_print.where for row in c_print.table: - if expr is None or expr(row): + if expr is None or expr(row, None): entries.append(row.entry) # Create a context that renders all numbers with their natural @@ -118,7 +118,7 @@ def func(obj): return func -def execute_query(query): +def execute_query(query, env=None): """Given a compiled select statement, execute the query. Args: @@ -133,7 +133,7 @@ def execute_query(query): """ if isinstance(query, query_compile.EvalQuery): - return execute_select(query) + return execute_select(query, env) if isinstance(query, query_compile.EvalPivot): columns, rows = execute_select(query.query) @@ -168,7 +168,7 @@ def execute_query(query): raise NotImplementedError(query) -def execute_select(query): +def execute_select(query, env=None): """Given a compiled select statement, execute the query. Args: @@ -203,8 +203,8 @@ def execute_select(query): # Iterate over all the table rows. for row in query.table: - if where is None or where(row): - values = [expr(row) for expr in target_exprs] + if where is None or where(row, env): + values = [expr(row, env) for expr in target_exprs] rows.append(values) else: @@ -240,17 +240,17 @@ def create(): # Iterate over all the postings to evaluate the aggregates. for row in query.table: - if where is None or where(row): + if where is None or where(row, env): # Compute the non-aggregate expressions. - key = tuple(expr(row) for expr in nonaggregate_exprs) + key = tuple(expr(row, env) for expr in nonaggregate_exprs) # Get an appropriate store for the unique key of this row. store = aggregates[key] # Update the aggregate expressions. for expr in aggregate_exprs: - expr.update(store, row) + expr.update(store, row, env) # Iterate over all the aggregations. for key, store in aggregates.items(): @@ -265,7 +265,7 @@ def create(): if index in group_indexes: value = next(key_iter) else: - value = expr(None) + value = expr(None, None) values.append(value) # Skip row if HAVING clause expression is false. diff --git a/beanquery/query_execute_test.py b/beanquery/query_execute_test.py index 26dcc995..743ce199 100644 --- a/beanquery/query_execute_test.py +++ b/beanquery/query_execute_test.py @@ -339,7 +339,7 @@ def filter_entries(query): context = qe.Row(entries, query.table.options) for entry in query.table.prepare(): context.entry = entry - if expr is None or expr(context): + if expr is None or expr(context, None): entries.append(entry) return entries @@ -1402,7 +1402,7 @@ def __init__(self, name, func, dtype): self.name = name self.func = func - def __call__(self, row): + def __call__(self, row, env): return self.func(row) diff --git a/beanquery/sources/beancount.py b/beanquery/sources/beancount.py index a0653168..c3f8b0ff 100644 --- a/beanquery/sources/beancount.py +++ b/beanquery/sources/beancount.py @@ -46,7 +46,7 @@ def __init__(self, name, dtype): super().__init__(dtype) self.name = name - def __call__(self, context): + def __call__(self, context, env): return getattr(context, self.name) diff --git a/beanquery/tests/tables.py b/beanquery/tests/tables.py index 0abdcbdd..98ca3a07 100644 --- a/beanquery/tests/tables.py +++ b/beanquery/tests/tables.py @@ -7,7 +7,7 @@ def __init__(self, func, datatype): super().__init__(datatype) self.func = func - def __call__(self, row): + def __call__(self, row, env): return self.func(row)