@@ -47,7 +47,8 @@ def get_prep_lookup(self):
4747 self .check_rhs_is_tuple_or_list ()
4848 self .check_rhs_length_equals_lhs_length ()
4949 else :
50- self .check_rhs_is_outer_ref ()
50+ self .check_rhs_is_supported_expression ()
51+ super ().get_prep_lookup ()
5152 return self .rhs
5253
5354 def check_rhs_is_tuple_or_list (self ):
@@ -65,13 +66,13 @@ def check_rhs_length_equals_lhs_length(self):
6566 f"{ self .lookup_name !r} lookup of { lhs_str } must have { len_lhs } elements"
6667 )
6768
68- def check_rhs_is_outer_ref (self ):
69- if not isinstance (self .rhs , ResolvedOuterRef ):
69+ def check_rhs_is_supported_expression (self ):
70+ if not isinstance (self .rhs , ( ResolvedOuterRef , Query ) ):
7071 lhs_str = self .get_lhs_str ()
7172 rhs_cls = self .rhs .__class__ .__name__
7273 raise ValueError (
7374 f"{ self .lookup_name !r} subquery lookup of { lhs_str } "
74- f"only supports OuterRef objects (received { rhs_cls !r} )"
75+ f"only supports OuterRef and QuerySet objects (received { rhs_cls !r} )"
7576 )
7677
7778 def get_lhs_str (self ):
@@ -101,11 +102,14 @@ def process_rhs(self, compiler, connection):
101102 return compiler .compile (Tuple (* args ))
102103 else :
103104 sql , params = compiler .compile (self .rhs )
104- if not isinstance (self .rhs , ColPairs ):
105+ if isinstance (self .rhs , ColPairs ):
106+ return "(%s)" % sql , params
107+ elif isinstance (self .rhs , Query ):
108+ return super ().process_rhs (compiler , connection )
109+ else :
105110 raise ValueError (
106111 "Composite field lookups only work with composite expressions."
107112 )
108- return "(%s)" % sql , params
109113
110114 def get_fallback_sql (self , compiler , connection ):
111115 raise NotImplementedError (
@@ -121,6 +125,8 @@ def as_sql(self, compiler, connection):
121125
122126class TupleExact (TupleLookupMixin , Exact ):
123127 def get_fallback_sql (self , compiler , connection ):
128+ if isinstance (self .rhs , Query ):
129+ return super (TupleLookupMixin , self ).as_sql (compiler , connection )
124130 # Process right-hand-side to trigger sanitization.
125131 self .process_rhs (compiler , connection )
126132 # e.g.: (a, b, c) == (x, y, z) as SQL:
@@ -273,7 +279,7 @@ def get_prep_lookup(self):
273279 self .check_rhs_elements_length_equals_lhs_length ()
274280 else :
275281 self .check_rhs_is_query ()
276- self . check_rhs_select_length_equals_lhs_length ()
282+ super ( TupleLookupMixin , self ). get_prep_lookup ()
277283
278284 return self .rhs # skip checks from mixin
279285
@@ -303,19 +309,10 @@ def check_rhs_is_query(self):
303309 f"must be a Query object (received { rhs_cls !r} )"
304310 )
305311
306- def check_rhs_select_length_equals_lhs_length (self ):
307- len_rhs = len (self .rhs .select )
308- if len_rhs == 1 and isinstance (self .rhs .select [0 ], ColPairs ):
309- len_rhs = len (self .rhs .select [0 ])
310- len_lhs = len (self .lhs )
311- if len_rhs != len_lhs :
312- lhs_str = self .get_lhs_str ()
313- raise ValueError (
314- f"{ self .lookup_name !r} subquery lookup of { lhs_str } "
315- f"must have { len_lhs } fields (received { len_rhs } )"
316- )
317-
318312 def process_rhs (self , compiler , connection ):
313+ if not self .rhs_is_direct_value ():
314+ return super (TupleLookupMixin , self ).process_rhs (compiler , connection )
315+
319316 rhs = self .rhs
320317 if not rhs :
321318 raise EmptyResultSet
@@ -337,19 +334,12 @@ def process_rhs(self, compiler, connection):
337334
338335 return compiler .compile (Tuple (* result ))
339336
340- def as_subquery_sql (self , compiler , connection ):
341- lhs = self .lhs
342- rhs = self .rhs
343- if isinstance (lhs , ColPairs ):
344- rhs = rhs .clone ()
345- rhs .set_values ([source .name for source in lhs .sources ])
346- lhs = Tuple (lhs )
347- return compiler .compile (In (lhs , rhs ))
348-
349337 def get_fallback_sql (self , compiler , connection ):
350338 rhs = self .rhs
351339 if not rhs :
352340 raise EmptyResultSet
341+ if not self .rhs_is_direct_value ():
342+ return super (TupleLookupMixin , self ).as_sql (compiler , connection )
353343
354344 # e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
355345 # WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
@@ -362,11 +352,6 @@ def get_fallback_sql(self, compiler, connection):
362352
363353 return root .as_sql (compiler , connection )
364354
365- def as_sql (self , compiler , connection ):
366- if not self .rhs_is_direct_value ():
367- return self .as_subquery_sql (compiler , connection )
368- return super ().as_sql (compiler , connection )
369-
370355
371356tuple_lookups = {
372357 "exact" : TupleExact ,
0 commit comments