Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 52 additions & 22 deletions beanquery/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,40 +506,70 @@ def _and(self, node: ast.And):
@_compile.register(ast.All)
@_compile.register(ast.Any)
def _all(self, node):
# This parses a node of the form
# All(left, op, right, side), which arises from the syntax:
#
# ALL( ... ) <op> <val> (side == 'lhs')
# <val> <op> ALL( ... ) (side == 'rhs')
#
# Example: ANY(accounts) = "Assets:Checking"
#
left = self._compile(node.left)
right = self._compile(node.right)

if node.side == 'lhs':
collection = left
collection_node = node.left
value = right
elif node.side == 'rhs':
collection = right
collection_node = node.right
value = left

if isinstance(collection, EvalQuery):
if len(collection.columns) != 1:
raise CompilationError('subquery has too many columns', collection_node)
collection = EvalConstantSubquery1D(collection)

collection_dtype = typing.get_origin(collection.dtype) or collection.dtype
value_dtype = typing.get_origin(value.dtype) or value.dtype

if collection_dtype not in {list, set, EvalConstantSubquery1D}:
raise CompilationError(
f'ANY/ALL requires a collection (list, set, or subquery), got {types.name(collection.dtype)}',
node)

if isinstance(right, EvalQuery):
if len(right.columns) != 1:
raise CompilationError('subquery has too many columns', node.right)
right = EvalConstantSubquery1D(right)

right_dtype = typing.get_origin(right.dtype) or right.dtype
if right_dtype not in {list, set}:
raise CompilationError(f'not a list or set but {right_dtype}', node.right)
args = typing.get_args(right.dtype)
collection_dtype = typing.get_origin(collection.dtype) or collection.dtype
if collection_dtype not in {list, set}:
raise CompilationError(f'not a list or set but {collection_dtype}', collection_node)
args = typing.get_args(collection.dtype)
if args:
assert len(args) == 1
right_element_dtype = args[0]
collection_element_dtype = args[0]
else:
right_element_dtype = object
collection_element_dtype = object

left = self._compile(node.left)

# lookup operator implementaton and check typing
# Lookup operator implementation and check typing.
op = self._OPERATORS[node.op]
for func in OPERATORS[op]:
if func.__intypes__ == [right_element_dtype, left.dtype]:
break
else:
if node.side == 'rhs':
# value op ANY(collection) -> operator(value, element)
func = types.operator_lookup(OPERATORS[op], [value.dtype, collection_element_dtype])
left_type, right_type = value.dtype, collection_element_dtype
else: # node.side == 'lhs'
# ANY(collection) op value -> operator(element, value)
func = types.operator_lookup(OPERATORS[op], [collection_element_dtype, value.dtype])
left_type, right_type = collection_element_dtype, value.dtype

if func is None:
raise CompilationError(
f'operator "{op.__name__.lower()}('
f'{left.dtype.__name__}, {right_element_dtype.__name__})" not supported', node)
f'{types.name(left_type)}, {types.name(right_type)})" not supported', node)

# need to instantiate the operaotr implementation to get to the underlying function
operator = func(None, None).operator

cls = EvalAll if type(node) is ast.All else EvalAny
return cls(operator, left, right)
return cls(operator, collection, value, node.side)

@_compile.register
def _function(self, node: ast.Function):
Expand Down Expand Up @@ -583,9 +613,9 @@ def _function(self, node: ast.Function):
ast.Attribute(ast.Column('entry', parseinfo=node.parseinfo), 'meta'), key])])
return self._compile(node)

# Replace ``has_account(regexp)`` with ``('(?i)' + regexp) ~? any (accounts)``.
# Replace ``has_account(regexp)`` with ``any (accounts) ~ ('(?i)' + regexp) ``.
if node.fname == 'has_account':
node = ast.Any(ast.Add(ast.Constant('(?i)'), node.operands[0]), '?~', ast.Column('accounts'))
node = ast.Any(ast.Column('accounts'), '~', ast.Add(ast.Constant('(?i)'), node.operands[0]), side = 'lhs')
return self._compile(node)

function = function(self.context, operands)
Expand Down
4 changes: 2 additions & 2 deletions beanquery/parser/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,8 @@ class Sub(BinaryOp):
__slots__ = ()


Any = node('Any', 'left op right')
All = node('All', 'left op right')
Any = node('Any', 'left op right side')
All = node('All', 'left op right side')


CreateTable = node('CreateTable', 'name columns using query')
Expand Down
8 changes: 6 additions & 2 deletions beanquery/parser/bql.ebnf
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,15 @@ comparison
;

any::Any
= left:sum op:op 'any' '(' right:expression ')'
=
| left:sum op:op 'any' '(' right:expression ')' side:`rhs`
| 'any' '(' left:expression ')' op:op right:sum side:`lhs`
;

all::All
= left:sum op:op 'all' '(' right:expression ')'
=
| left:sum op:op 'all' '(' right:expression ')' side:`rhs`
| 'all' '(' left:expression ')' op:op right:sum side:`lhs`
;

op
Expand Down
84 changes: 63 additions & 21 deletions beanquery/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ def _comparison_(self):
self._sum_()
self._error(
'expecting one of: '
'<add> <all> <any> <between> <eq> <gt>'
"'all' 'any' <add> <between> <eq> <gt>"
'<gte> <in> <isnotnull> <isnull> <lt>'
'<lte> <match> <matches> <neq> <notin>'
'<notmatch> <sub> <sum> <term>'
Expand All @@ -586,29 +586,71 @@ def _comparison_(self):
@tatsumasu('Any')
@nomemo
def _any_(self):
self._sum_()
self.name_last_node('left')
self._op_()
self.name_last_node('op')
self._token('any')
self._token('(')
self._expression_()
self.name_last_node('right')
self._token(')')
self._define(['left', 'op', 'right'], [])
with self._choice():
with self._option():
self._sum_()
self.name_last_node('left')
self._op_()
self.name_last_node('op')
self._token('any')
self._token('(')
self._expression_()
self.name_last_node('right')
self._token(')')
self._constant('rhs')
self.name_last_node('side')
self._define(['left', 'op', 'right', 'side'], [])
with self._option():
self._token('any')
self._token('(')
self._expression_()
self.name_last_node('left')
self._token(')')
self._op_()
self.name_last_node('op')
self._sum_()
self.name_last_node('right')
self._constant('lhs')
self.name_last_node('side')
self._define(['left', 'op', 'right', 'side'], [])
self._error(
'expecting one of: '
"'any' <add> <sub> <sum> <term>"
)

@tatsumasu('All')
def _all_(self):
self._sum_()
self.name_last_node('left')
self._op_()
self.name_last_node('op')
self._token('all')
self._token('(')
self._expression_()
self.name_last_node('right')
self._token(')')
self._define(['left', 'op', 'right'], [])
with self._choice():
with self._option():
self._sum_()
self.name_last_node('left')
self._op_()
self.name_last_node('op')
self._token('all')
self._token('(')
self._expression_()
self.name_last_node('right')
self._token(')')
self._constant('rhs')
self.name_last_node('side')
self._define(['left', 'op', 'right', 'side'], [])
with self._option():
self._token('all')
self._token('(')
self._expression_()
self.name_last_node('left')
self._token(')')
self._op_()
self.name_last_node('op')
self._sum_()
self.name_last_node('right')
self._constant('lhs')
self.name_last_node('side')
self._define(['left', 'op', 'right', 'side'], [])
self._error(
'expecting one of: '
"'all' <add> <sub> <sum> <term>"
)

@tatsumasu()
def _op_(self):
Expand Down
44 changes: 32 additions & 12 deletions beanquery/query_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,13 +448,14 @@ def __call__(self, context):


class EvalAny(EvalNode):
__slots__ = ('op', 'left', 'right')
__slots__ = ('op', 'left', 'right', 'collection_side')

def __init__(self, op, left, right):
def __init__(self, op, left, right, collection_side):
super().__init__(bool)
self.op = op
self.left = left
self.right = right
self.collection_side = collection_side

def __call__(self, row):
left = self.left(row)
Expand All @@ -463,26 +464,45 @@ def __call__(self, row):
right = self.right(row)
if right is None:
return None
return any(self.op(left, x) for x in right)

if self.collection_side == 'right':
# Original form: value op ANY(collection)
return any(self.op(left, x) for x in right)
else: # collection_side == 'left'
# New form: ANY(collection) op value
return any(self.op(x, right) for x in left)


class EvalAll(EvalNode):
__slots__ = ('op', 'left', 'right')
__slots__ = ('op', 'collection', 'value', 'side')

def __init__(self, op, left, right):
def __init__(self, op, collection, value, side):
"""
Either: ALL(<collection>) <op> <value> (side == 'lhs')
Or: <value> <op> ALL(<collection>) (side == 'rhs')
"""
super().__init__(bool)
self.op = op
self.left = left
self.right = right
self.collection = collection
self.value = value
if side not in ['lhs', 'rhs']:
raise ValueError('EvalAll: Parameter "side" must be one of "lhs", "rhs"')
self.side = side

def __call__(self, row):
left = self.left(row)
if left is None:
collection = self.collection(row)
if collection is None:
return None
right = self.right(row)
if right is None:
value = self.value(row)
if value is None:
return None
return all(self.op(left, x) for x in right)

if self.side == 'rhs':
# Syntax form: value op ALL(collection)
return all(self.op(value, x) for x in collection)
else: # side == 'lhs'
# Syntax form: ALL(collection) op value
return all(self.op(x, value) for x in collection)


class EvalRow(EvalNode):
Expand Down
8 changes: 4 additions & 4 deletions beanquery/query_execute_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1667,23 +1667,23 @@ def test_in_accounts_transactions(self):
]
)

def test_mathes_any_accounts_transactions(self):
def test_matches_any_accounts_transactions(self):
self.check_query(self.data, """
SELECT date, narration
FROM #transactions
WHERE ':Two' ?~ ANY(accounts)
WHERE ANY(accounts) ~ ':Two'
""",
(('date', datetime.date), ('narration', str)),
[
(datetime.date(2025, 1, 2), 'Two'),
]
)

def test_mathes_all_accounts_transactions(self):
def test_matches_all_accounts_transactions(self):
self.check_query(self.data, """
SELECT date, narration
FROM #transactions
WHERE '(?i):two|:cash' ?~ ALL(accounts)
WHERE ALL(accounts) ~ '(?i):two|:cash'
""",
(('date', datetime.date), ('narration', str)),
[
Expand Down
10 changes: 5 additions & 5 deletions beanquery/sources/beancount.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,12 @@ def description(entry):
return None
return ' | '.join(filter(None, [entry.payee, entry.narration]))

@columns.register(set)
@columns.register(typing.Set[str])
def tags(entry):
"""The set of tags of the transaction."""
return getattr(entry, 'tags', None)

@columns.register(set)
@columns.register(typing.Set[str])
def links(entry):
"""The set of links of the transaction."""
return getattr(entry, 'links', None)
Expand Down Expand Up @@ -493,12 +493,12 @@ def description(context):
"""A combination of the payee + narration for the transaction of this posting."""
return ' | '.join(filter(None, [context.entry.payee, context.entry.narration]))

@columns.register(set)
@columns.register(typing.Set[str])
def tags(context):
"""The set of tags of the parent transaction for this posting."""
return context.entry.tags

@columns.register(set)
@columns.register(typing.Set[str])
def links(context):
"""The set of links of the parent transaction for this posting."""
return context.entry.links
Expand All @@ -513,7 +513,7 @@ def account(context):
"""The account of the posting."""
return context.posting.account

@columns.register(set)
@columns.register(typing.Set[str])
def other_accounts(context):
"""The list of other accounts in the transaction, excluding that of this posting."""
return sorted({posting.account for posting in context.entry.postings if posting is not context.posting})
Expand Down
Loading