Skip to content

Commit 41239fe

Browse files
charettessarahboyce
authored andcommitted
Fixed #36149 -- Allowed subquery values against tuple exact and in lookups.
Non-tuple exact and in lookups have specialized logic for subqueries that can be adapted to properly assign select mask if unspecified and ensure the number of involved members are matching on both side of the operator.
1 parent 0597e8a commit 41239fe

12 files changed

Lines changed: 137 additions & 77 deletions

File tree

django/db/models/expressions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,6 +1367,9 @@ def relabeled_clone(self, relabels):
13671367
def resolve_expression(self, *args, **kwargs):
13681368
return self
13691369

1370+
def select_format(self, compiler, sql, params):
1371+
return sql, params
1372+
13701373

13711374
class Ref(Expression):
13721375
"""

django/db/models/fields/related_lookups.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,16 @@ def get_normalized_value(value, lhs):
4040

4141
class RelatedIn(In):
4242
def get_prep_lookup(self):
43-
if not isinstance(self.lhs, ColPairs):
43+
from django.db.models.sql.query import Query # avoid circular import
44+
45+
if isinstance(self.lhs, ColPairs):
46+
if (
47+
isinstance(self.rhs, Query)
48+
and not self.rhs.has_select_fields
49+
and self.lhs.output_field.related_model is self.rhs.model
50+
):
51+
self.rhs.set_values([f.name for f in self.lhs.sources])
52+
else:
4453
if self.rhs_is_direct_value():
4554
# If we get here, we are dealing with single-column relations.
4655
self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]

django/db/models/fields/tuple_lookups.py

Lines changed: 18 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ def get_prep_lookup(self):
4747
self.check_rhs_is_tuple_or_list()
4848
self.check_rhs_length_equals_lhs_length()
4949
else:
50-
self.check_rhs_is_outer_ref()
50+
self.check_rhs_is_supported_expression()
51+
super().get_prep_lookup()
5152
return self.rhs
5253

5354
def check_rhs_is_tuple_or_list(self):
@@ -65,13 +66,13 @@ def check_rhs_length_equals_lhs_length(self):
6566
f"{self.lookup_name!r} lookup of {lhs_str} must have {len_lhs} elements"
6667
)
6768

68-
def check_rhs_is_outer_ref(self):
69-
if not isinstance(self.rhs, ResolvedOuterRef):
69+
def check_rhs_is_supported_expression(self):
70+
if not isinstance(self.rhs, (ResolvedOuterRef, Query)):
7071
lhs_str = self.get_lhs_str()
7172
rhs_cls = self.rhs.__class__.__name__
7273
raise ValueError(
7374
f"{self.lookup_name!r} subquery lookup of {lhs_str} "
74-
f"only supports OuterRef objects (received {rhs_cls!r})"
75+
f"only supports OuterRef and QuerySet objects (received {rhs_cls!r})"
7576
)
7677

7778
def get_lhs_str(self):
@@ -101,11 +102,14 @@ def process_rhs(self, compiler, connection):
101102
return compiler.compile(Tuple(*args))
102103
else:
103104
sql, params = compiler.compile(self.rhs)
104-
if not isinstance(self.rhs, ColPairs):
105+
if isinstance(self.rhs, ColPairs):
106+
return "(%s)" % sql, params
107+
elif isinstance(self.rhs, Query):
108+
return super().process_rhs(compiler, connection)
109+
else:
105110
raise ValueError(
106111
"Composite field lookups only work with composite expressions."
107112
)
108-
return "(%s)" % sql, params
109113

110114
def get_fallback_sql(self, compiler, connection):
111115
raise NotImplementedError(
@@ -121,6 +125,8 @@ def as_sql(self, compiler, connection):
121125

122126
class TupleExact(TupleLookupMixin, Exact):
123127
def get_fallback_sql(self, compiler, connection):
128+
if isinstance(self.rhs, Query):
129+
return super(TupleLookupMixin, self).as_sql(compiler, connection)
124130
# Process right-hand-side to trigger sanitization.
125131
self.process_rhs(compiler, connection)
126132
# e.g.: (a, b, c) == (x, y, z) as SQL:
@@ -273,7 +279,7 @@ def get_prep_lookup(self):
273279
self.check_rhs_elements_length_equals_lhs_length()
274280
else:
275281
self.check_rhs_is_query()
276-
self.check_rhs_select_length_equals_lhs_length()
282+
super(TupleLookupMixin, self).get_prep_lookup()
277283

278284
return self.rhs # skip checks from mixin
279285

@@ -303,19 +309,10 @@ def check_rhs_is_query(self):
303309
f"must be a Query object (received {rhs_cls!r})"
304310
)
305311

306-
def check_rhs_select_length_equals_lhs_length(self):
307-
len_rhs = len(self.rhs.select)
308-
if len_rhs == 1 and isinstance(self.rhs.select[0], ColPairs):
309-
len_rhs = len(self.rhs.select[0])
310-
len_lhs = len(self.lhs)
311-
if len_rhs != len_lhs:
312-
lhs_str = self.get_lhs_str()
313-
raise ValueError(
314-
f"{self.lookup_name!r} subquery lookup of {lhs_str} "
315-
f"must have {len_lhs} fields (received {len_rhs})"
316-
)
317-
318312
def process_rhs(self, compiler, connection):
313+
if not self.rhs_is_direct_value():
314+
return super(TupleLookupMixin, self).process_rhs(compiler, connection)
315+
319316
rhs = self.rhs
320317
if not rhs:
321318
raise EmptyResultSet
@@ -337,19 +334,12 @@ def process_rhs(self, compiler, connection):
337334

338335
return compiler.compile(Tuple(*result))
339336

340-
def as_subquery_sql(self, compiler, connection):
341-
lhs = self.lhs
342-
rhs = self.rhs
343-
if isinstance(lhs, ColPairs):
344-
rhs = rhs.clone()
345-
rhs.set_values([source.name for source in lhs.sources])
346-
lhs = Tuple(lhs)
347-
return compiler.compile(In(lhs, rhs))
348-
349337
def get_fallback_sql(self, compiler, connection):
350338
rhs = self.rhs
351339
if not rhs:
352340
raise EmptyResultSet
341+
if not self.rhs_is_direct_value():
342+
return super(TupleLookupMixin, self).as_sql(compiler, connection)
353343

354344
# e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
355345
# WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
@@ -362,11 +352,6 @@ def get_fallback_sql(self, compiler, connection):
362352

363353
return root.as_sql(compiler, connection)
364354

365-
def as_sql(self, compiler, connection):
366-
if not self.rhs_is_direct_value():
367-
return self.as_subquery_sql(compiler, connection)
368-
return super().as_sql(compiler, connection)
369-
370355

371356
tuple_lookups = {
372357
"exact": TupleExact,

django/db/models/lookups.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -373,16 +373,21 @@ class Exact(FieldGetDbPrepValueMixin, BuiltinLookup):
373373
def get_prep_lookup(self):
374374
from django.db.models.sql.query import Query # avoid circular import
375375

376-
if isinstance(self.rhs, Query):
377-
if self.rhs.has_limit_one():
378-
if not self.rhs.has_select_fields:
379-
self.rhs.clear_select_clause()
380-
self.rhs.add_fields(["pk"])
381-
else:
376+
if isinstance(query := self.rhs, Query):
377+
if not query.has_limit_one():
382378
raise ValueError(
383379
"The QuerySet value for an exact lookup must be limited to "
384380
"one result using slicing."
385381
)
382+
lhs_len = len(self.lhs) if isinstance(self.lhs, (ColPairs, tuple)) else 1
383+
if (rhs_len := query._subquery_fields_len) != lhs_len:
384+
raise ValueError(
385+
f"The QuerySet value for the exact lookup must have {lhs_len} "
386+
f"selected fields (received {rhs_len})"
387+
)
388+
if not query.has_select_fields:
389+
query.clear_select_clause()
390+
query.add_fields(["pk"])
386391
return super().get_prep_lookup()
387392

388393
def as_sql(self, compiler, connection):
@@ -499,6 +504,12 @@ def get_prep_lookup(self):
499504
from django.db.models.sql.query import Query # avoid circular import
500505

501506
if isinstance(self.rhs, Query):
507+
lhs_len = len(self.lhs) if isinstance(self.lhs, (ColPairs, tuple)) else 1
508+
if (rhs_len := self.rhs._subquery_fields_len) != lhs_len:
509+
raise ValueError(
510+
f"The QuerySet value for the 'in' lookup must have {lhs_len} "
511+
f"selected fields (received {rhs_len})"
512+
)
502513
self.rhs.clear_ordering(clear_default=True)
503514
if not self.rhs.has_select_fields:
504515
self.rhs.clear_select_clause()

django/db/models/query.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1958,10 +1958,6 @@ def _merge_known_related_objects(self, other):
19581958
self._known_related_objects.setdefault(field, {}).update(objects)
19591959

19601960
def resolve_expression(self, *args, **kwargs):
1961-
if self._fields and len(self._fields) > 1:
1962-
# values() queryset can only be used as nested queries
1963-
# if they are set up to select only a single field.
1964-
raise TypeError("Cannot use multi-field values as a filter value.")
19651961
query = self.query.resolve_expression(*args, **kwargs)
19661962
query._db = self._db
19671963
return query

django/db/models/sql/query.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,6 +1224,12 @@ def add_annotation(self, annotation, alias, select=True):
12241224
if self.selected:
12251225
self.selected[alias] = alias
12261226

1227+
@property
1228+
def _subquery_fields_len(self):
1229+
if self.has_select_fields:
1230+
return len(self.selected)
1231+
return len(self.model._meta.pk_fields)
1232+
12271233
def resolve_expression(self, query, *args, **kwargs):
12281234
clone = self.clone()
12291235
# Subqueries need to use a different set of aliases than the outer query.

tests/composite_pk/models/tenant.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class Comment(models.Model):
4444
related_name="comments",
4545
)
4646
text = models.TextField(default="", blank=True)
47+
integer = models.IntegerField(default=0)
4748

4849

4950
class Post(models.Model):

tests/composite_pk/test_filter.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
)
1111
from django.db.models.functions import Cast
1212
from django.db.models.lookups import Exact
13-
from django.test import TestCase
13+
from django.test import TestCase, skipUnlessDBFeature
1414

1515
from .models import Comment, Tenant, User
1616

@@ -182,6 +182,30 @@ def test_filter_comments_by_pk_in(self):
182182
Comment.objects.filter(pk__in=pks).order_by("pk"), objs
183183
)
184184

185+
def test_filter_comments_by_pk_in_subquery(self):
186+
self.assertSequenceEqual(
187+
Comment.objects.filter(
188+
pk__in=Comment.objects.filter(pk=self.comment_1.pk),
189+
),
190+
[self.comment_1],
191+
)
192+
self.assertSequenceEqual(
193+
Comment.objects.filter(
194+
pk__in=Comment.objects.filter(pk=self.comment_1.pk).values(
195+
"tenant_id", "id"
196+
),
197+
),
198+
[self.comment_1],
199+
)
200+
self.comment_2.integer = self.comment_1.id
201+
self.comment_2.save()
202+
self.assertSequenceEqual(
203+
Comment.objects.filter(
204+
pk__in=Comment.objects.values("tenant_id", "integer"),
205+
),
206+
[self.comment_1],
207+
)
208+
185209
def test_filter_comments_by_user_and_order_by_pk_asc(self):
186210
self.assertSequenceEqual(
187211
Comment.objects.filter(user=self.user_1).order_by("pk"),
@@ -440,16 +464,40 @@ def test_outer_ref_pk(self):
440464
queryset = Comment.objects.filter(**{f"id{lookup}": subquery})
441465
self.assertEqual(queryset.count(), expected_count)
442466

443-
def test_non_outer_ref_subquery(self):
444-
# If rhs is any non-OuterRef object with an as_sql() function.
467+
def test_unsupported_rhs(self):
445468
pk = Exact(F("tenant_id"), 1)
446469
msg = (
447-
"'exact' subquery lookup of 'pk' only supports OuterRef objects "
448-
"(received 'Exact')"
470+
"'exact' subquery lookup of 'pk' only supports OuterRef "
471+
"and QuerySet objects (received 'Exact')"
449472
)
450473
with self.assertRaisesMessage(ValueError, msg):
451474
Comment.objects.filter(pk=pk)
452475

476+
@skipUnlessDBFeature("allow_sliced_subqueries_with_in")
477+
def test_filter_comments_by_pk_exact_subquery(self):
478+
self.assertSequenceEqual(
479+
Comment.objects.filter(
480+
pk=Comment.objects.filter(pk=self.comment_1.pk)[:1],
481+
),
482+
[self.comment_1],
483+
)
484+
self.assertSequenceEqual(
485+
Comment.objects.filter(
486+
pk__in=Comment.objects.filter(pk=self.comment_1.pk).values(
487+
"tenant_id", "id"
488+
)[:1],
489+
),
490+
[self.comment_1],
491+
)
492+
self.comment_2.integer = self.comment_1.id
493+
self.comment_2.save()
494+
self.assertSequenceEqual(
495+
Comment.objects.filter(
496+
pk__in=Comment.objects.values("tenant_id", "integer"),
497+
)[:1],
498+
[self.comment_1],
499+
)
500+
453501
def test_outer_ref_not_composite_pk(self):
454502
subquery = Comment.objects.filter(pk=OuterRef("id")).values("id")
455503
queryset = Comment.objects.filter(id=Subquery(subquery))

tests/composite_pk/tests.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,10 @@ def test_pk_must_have_2_elements(self):
109109

110110
def test_composite_pk_in_fields(self):
111111
user_fields = {f.name for f in User._meta.get_fields()}
112-
self.assertEqual(user_fields, {"pk", "tenant", "id", "email", "comments"})
112+
self.assertTrue({"pk", "tenant", "id"}.issubset(user_fields))
113113

114114
comment_fields = {f.name for f in Comment._meta.get_fields()}
115-
self.assertEqual(
116-
comment_fields,
117-
{"pk", "tenant", "id", "user_id", "user", "text"},
118-
)
115+
self.assertTrue({"pk", "tenant", "id"}.issubset(comment_fields))
119116

120117
def test_pk_field(self):
121118
pk = User._meta.get_field("pk")
@@ -174,7 +171,7 @@ def test_only(self):
174171
self.assertEqual(user.email, self.user.email)
175172

176173
def test_model_forms(self):
177-
fields = ["tenant", "id", "user_id", "text"]
174+
fields = ["tenant", "id", "user_id", "text", "integer"]
178175
self.assertEqual(list(CommentForm.base_fields), fields)
179176

180177
form = modelform_factory(Comment, fields="__all__")

tests/foreign_object/test_tuple_lookups.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,11 @@ def test_exact(self):
6363
)
6464

6565
def test_exact_subquery(self):
66-
with self.assertRaisesMessage(
67-
ValueError, "'exact' doesn't support multi-column subqueries."
68-
):
66+
msg = (
67+
"The QuerySet value for the exact lookup must have 2 selected "
68+
"fields (received 1)"
69+
)
70+
with self.assertRaisesMessage(ValueError, msg):
6971
subquery = Customer.objects.filter(id=self.customer_1.id)[:1]
7072
self.assertSequenceEqual(
7173
Contact.objects.filter(customer=subquery).order_by("id"), ()
@@ -140,11 +142,11 @@ def test_tuple_in_subquery_must_be_query(self):
140142
def test_tuple_in_subquery_must_have_2_fields(self):
141143
lhs = (F("customer_code"), F("company_code"))
142144
rhs = Customer.objects.values_list("customer_id").query
143-
with self.assertRaisesMessage(
144-
ValueError,
145-
"'in' subquery lookup of ('customer_code', 'company_code') "
146-
"must have 2 fields (received 1)",
147-
):
145+
msg = (
146+
"The QuerySet value for the 'in' lookup must have 2 selected "
147+
"fields (received 1)"
148+
)
149+
with self.assertRaisesMessage(ValueError, msg):
148150
TupleIn(lhs, rhs)
149151

150152
def test_tuple_in_subquery(self):

0 commit comments

Comments
 (0)