Skip to content

Commit 787cd68

Browse files
committed
Refactor aggregation methods
1 parent 9b78e9e commit 787cd68

File tree

3 files changed

+21
-31
lines changed

3 files changed

+21
-31
lines changed

django_mongodb_backend/aggregates.py

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
from django.db.models.aggregates import Aggregate, Count, StdDev, Variance
22
from django.db.models.expressions import Case, Value, When
33
from django.db.models.lookups import IsNull
4+
from django.db.models.sql.where import WhereNode
45

5-
from .query_utils import process_lhs
6+
from django_mongodb_backend.expressions import Remove
67

78
# Aggregates whose MongoDB aggregation name differ from Aggregate.function.lower().
89
MONGO_AGGREGATIONS = {Count: "sum"}
910

1011

1112
def aggregate(self, compiler, connection, operator=None, resolve_inner_expression=False):
12-
if self.filter:
13-
node = self.copy()
14-
node.filter = None
15-
source_expressions = node.get_source_expressions()
16-
condition = When(self.filter, then=source_expressions[0])
17-
node.set_source_expressions([Case(condition), *source_expressions[1:]])
18-
else:
19-
node = self
20-
lhs_mql = process_lhs(node, compiler, connection, as_expr=True)
13+
source_expressions = self.get_source_expressions()
14+
condition = (
15+
Case(When(self.filter, then=source_expressions[0]), default=Remove())
16+
if self.filter
17+
else source_expressions[0]
18+
)
19+
lhs_mql = condition.as_mql(compiler, connection, as_expr=True)
2120
if resolve_inner_expression:
2221
return lhs_mql
2322
operator = operator or MONGO_AGGREGATIONS.get(self.__class__, self.function.lower())
@@ -30,31 +29,23 @@ def count(self, compiler, connection, resolve_inner_expression=False):
3029
value. This is used to count different elements, so the inner values are
3130
returned to be pushed into a set.
3231
"""
32+
source_expressions = self.get_source_expressions()
3333
if not self.distinct or resolve_inner_expression:
34+
conditions = [IsNull(source_expressions[0], False)]
3435
if self.filter:
35-
node = self.copy()
36-
node.filter = None
37-
source_expressions = node.get_source_expressions()
38-
condition = When(
39-
self.filter, then=Case(When(IsNull(source_expressions[0], False), then=Value(1)))
40-
)
41-
node.set_source_expressions([Case(condition), *source_expressions[1:]])
42-
inner_expression = process_lhs(node, compiler, connection, as_expr=True)
43-
else:
44-
lhs_mql = process_lhs(self, compiler, connection, as_expr=True)
45-
null_cond = {"$in": [{"$type": lhs_mql}, ["missing", "null"]]}
46-
inner_expression = {
47-
"$cond": {"if": null_cond, "then": None, "else": lhs_mql if self.distinct else 1}
48-
}
36+
conditions.append(self.filter)
37+
inner_expression = Case(
38+
When(WhereNode(conditions), then=source_expressions[0] if self.distinct else Value(1)),
39+
# Skip the rows that does not met the criteria.
40+
default=Remove(),
41+
)
42+
inner_expression = inner_expression.as_mql(compiler, connection, as_expr=True)
4943
if resolve_inner_expression:
5044
return inner_expression
5145
return {"$sum": inner_expression}
5246
# If distinct=True or resolve_inner_expression=False, sum the size of the
5347
# set.
54-
lhs_mql = process_lhs(self, compiler, connection, as_expr=True)
55-
# None shouldn't be counted, so subtract 1 if it's present.
56-
exits_null = {"$cond": {"if": {"$in": [{"$literal": None}, lhs_mql]}, "then": -1, "else": 0}}
57-
return {"$add": [{"$size": lhs_mql}, exits_null]}
48+
return {"$size": source_expressions[0].as_mql(compiler, connection, as_expr=True)}
5849

5950

6051
def stddev_variance(self, compiler, connection):

django_mongodb_backend/expressions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from .expressions import Remove
12
from .search import (
23
CombinedSearchExpression,
34
CompoundExpression,
@@ -21,6 +22,7 @@
2122
__all__ = [
2223
"CombinedSearchExpression",
2324
"CompoundExpression",
25+
"Remove",
2426
"SearchAutocomplete",
2527
"SearchEquals",
2628
"SearchExists",

django_mongodb_backend/query_utils.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from django.core.exceptions import FullResultSet
22
from django.db.models import F
3-
from django.db.models.aggregates import Aggregate
43
from django.db.models.expressions import CombinedExpression, Func, Value
54
from django.db.models.sql.query import Query
65

@@ -20,8 +19,6 @@ def process_lhs(node, compiler, connection, as_expr=False):
2019
result.append(expr.as_mql(compiler, connection, as_expr=as_expr))
2120
except FullResultSet:
2221
result.append(Value(True).as_mql(compiler, connection, as_expr=as_expr))
23-
if isinstance(node, Aggregate):
24-
return result[0]
2522
return result
2623
# node is a Transform with just one source expression, aliased as "lhs".
2724
if is_direct_value(node.lhs):

0 commit comments

Comments
 (0)