3535import org .springframework .data .jpa .repository .Modifying ;
3636import org .springframework .data .jpa .repository .NativeQuery ;
3737import org .springframework .data .jpa .repository .QueryHints ;
38+ import org .springframework .data .jpa .repository .QueryRewriter ;
3839import org .springframework .data .jpa .repository .query .DeclaredQuery ;
3940import org .springframework .data .jpa .repository .query .JpaQueryMethod ;
4041import org .springframework .data .jpa .repository .query .ParameterBinding ;
@@ -85,6 +86,7 @@ static class QueryBlockBuilder {
8586 private @ Nullable AotEntityGraph entityGraph ;
8687 private @ Nullable String sqlResultSetMapping ;
8788 private @ Nullable Class <?> queryReturnType ;
89+ private @ Nullable Class <?> queryRewriter = QueryRewriter .IdentityQueryRewriter .class ;
8890
8991 private QueryBlockBuilder (AotQueryMethodGenerationContext context , JpaQueryMethod queryMethod ) {
9092 this .context = context ;
@@ -126,6 +128,11 @@ public QueryBlockBuilder queryReturnType(@Nullable Class<?> queryReturnType) {
126128 return this ;
127129 }
128130
131+ public QueryBlockBuilder queryRewriter (@ Nullable Class <?> queryRewriter ) {
132+ this .queryRewriter = queryRewriter == null ? QueryRewriter .IdentityQueryRewriter .class : queryRewriter ;
133+ return this ;
134+ }
135+
129136 /**
130137 * Build the query block.
131138 *
@@ -145,12 +152,20 @@ public CodeBlock build() {
145152 CodeBlock .Builder builder = CodeBlock .builder ();
146153 builder .add ("\n " );
147154
148- String queryStringNameVariableName = null ;
155+ String queryStringVariableName = null ;
156+
157+ String queryRewriterName = null ;
158+
159+ if (queries .result () instanceof StringAotQuery && queryRewriter != QueryRewriter .IdentityQueryRewriter .class ) {
160+
161+ queryRewriterName = "queryRewriter" ;
162+ builder .addStatement ("$T $L = new $T()" , queryRewriter , queryRewriterName , queryRewriter );
163+ }
149164
150165 if (queries != null && queries .result () instanceof StringAotQuery sq ) {
151166
152- queryStringNameVariableName = "%sString" .formatted (queryVariableName );
153- builder .addStatement ( "$T $L = $S" , String . class , queryStringNameVariableName , sq . getQueryString ( ));
167+ queryStringVariableName = "%sString" .formatted (queryVariableName );
168+ builder .add ( buildQueryString ( sq , queryStringVariableName ));
154169 }
155170
156171 String countQueryStringNameVariableName = null ;
@@ -159,7 +174,7 @@ public CodeBlock build() {
159174 if (queryMethod .isPageQuery () && queries .count () instanceof StringAotQuery sq ) {
160175
161176 countQueryStringNameVariableName = "count%sString" .formatted (StringUtils .capitalize (queryVariableName ));
162- builder .addStatement ( "$T $L = $S" , String . class , countQueryStringNameVariableName , sq . getQueryString ( ));
177+ builder .add ( buildQueryString ( sq , countQueryStringNameVariableName ));
163178 }
164179
165180 String sortParameterName = context .getSortParameterName ();
@@ -169,14 +184,14 @@ public CodeBlock build() {
169184
170185 if ((StringUtils .hasText (sortParameterName ) || StringUtils .hasText (dynamicReturnType ))
171186 && queries .result () instanceof StringAotQuery ) {
172- builder .add (applyRewrite (sortParameterName , dynamicReturnType , queryStringNameVariableName , actualReturnType ));
187+ builder .add (applyRewrite (sortParameterName , dynamicReturnType , queryStringVariableName , actualReturnType ));
173188 }
174189
175190 if (queries .result ().hasExpression () || queries .count ().hasExpression ()) {
176191 builder .addStatement ("class ExpressionMarker{}" );
177192 }
178193
179- builder .add (createQuery (false , queryVariableName , queryStringNameVariableName , queries .result (),
194+ builder .add (createQuery (false , queryVariableName , queryStringVariableName , queryRewriterName , queries .result (),
180195 this .sqlResultSetMapping , this .queryHints , this .entityGraph , this .queryReturnType ));
181196
182197 builder .add (applyLimits (queries .result ().isExists ()));
@@ -187,7 +202,8 @@ public CodeBlock build() {
187202
188203 boolean queryHints = this .queryHints .isPresent () && this .queryHints .getBoolean ("forCounting" );
189204
190- builder .add (createQuery (true , countQueryVariableName , countQueryStringNameVariableName , queries .count (), null ,
205+ builder .add (createQuery (true , countQueryVariableName , countQueryStringNameVariableName , queryRewriterName ,
206+ queries .count (), null ,
191207 queryHints ? this .queryHints : MergedAnnotation .missing (), null , Long .class ));
192208 builder .addStatement ("return ($T) $L.getSingleResult()" , Long .class , countQueryVariableName );
193209
@@ -199,6 +215,13 @@ public CodeBlock build() {
199215 return builder .build ();
200216 }
201217
218+ private CodeBlock buildQueryString (StringAotQuery sq , String queryStringVariableName ) {
219+
220+ CodeBlock .Builder builder = CodeBlock .builder ();
221+ builder .addStatement ("$T $L = $S" , String .class , queryStringVariableName , sq .getQueryString ());
222+ return builder .build ();
223+ }
224+
202225 private CodeBlock applyRewrite (@ Nullable String sort , @ Nullable String dynamicReturnType , String queryString ,
203226 Class <?> actualReturnType ) {
204227
@@ -268,12 +291,14 @@ private CodeBlock applyLimits(boolean exists) {
268291 }
269292
270293 private CodeBlock createQuery (boolean count , String queryVariableName , @ Nullable String queryStringNameVariableName ,
271- AotQuery query , @ Nullable String sqlResultSetMapping , MergedAnnotation <QueryHints > queryHints ,
294+ @ Nullable String queryRewriterName , AotQuery query , @ Nullable String sqlResultSetMapping ,
295+ MergedAnnotation <QueryHints > queryHints ,
272296 @ Nullable AotEntityGraph entityGraph , @ Nullable Class <?> queryReturnType ) {
273297
274298 Builder builder = CodeBlock .builder ();
275299
276- builder .add (doCreateQuery (count , queryVariableName , queryStringNameVariableName , query , sqlResultSetMapping ,
300+ builder .add (doCreateQuery (count , queryVariableName , queryStringNameVariableName , queryRewriterName , query ,
301+ sqlResultSetMapping ,
277302 queryReturnType ));
278303
279304 if (entityGraph != null ) {
@@ -306,18 +331,36 @@ private CodeBlock createQuery(boolean count, String queryVariableName, @Nullable
306331 }
307332
308333 private CodeBlock doCreateQuery (boolean count , String queryVariableName ,
309- @ Nullable String queryStringNameVariableName , AotQuery query , @ Nullable String sqlResultSetMapping ,
334+ @ Nullable String queryStringName , @ Nullable String queryRewriterName , AotQuery query ,
335+ @ Nullable String sqlResultSetMapping ,
310336 @ Nullable Class <?> queryReturnType ) {
311337
312338 ReturnedType returnedType = context .getReturnedType ();
313339 Builder builder = CodeBlock .builder ();
340+ String queryStringNameToUse = queryStringName ;
314341
315342 if (query instanceof StringAotQuery sq ) {
316343
344+ if (StringUtils .hasText (queryRewriterName )) {
345+
346+ queryStringNameToUse = queryStringName + "Rewritten" ;
347+
348+ if (StringUtils .hasText (context .getPageableParameterName ())) {
349+ builder .addStatement ("$T $L = $L.rewrite($L, $L)" , String .class , queryStringNameToUse , queryRewriterName ,
350+ queryStringName , context .getPageableParameterName ());
351+ } else if (StringUtils .hasText (context .getSortParameterName ())) {
352+ builder .addStatement ("$T $L = $L.rewrite($L, $L)" , String .class , queryStringNameToUse , queryRewriterName ,
353+ queryStringName , context .getSortParameterName ());
354+ } else {
355+ builder .addStatement ("$T $L = $L.rewrite($L, $T.unsorted())" , String .class , queryStringNameToUse ,
356+ queryRewriterName , queryStringName , Sort .class );
357+ }
358+ }
359+
317360 if (StringUtils .hasText (sqlResultSetMapping )) {
318361
319362 builder .addStatement ("$T $L = this.$L.createNativeQuery($L, $S)" , Query .class , queryVariableName ,
320- context .fieldNameOf (EntityManager .class ), queryStringNameVariableName , sqlResultSetMapping );
363+ context .fieldNameOf (EntityManager .class ), queryStringNameToUse , sqlResultSetMapping );
321364
322365 return builder .build ();
323366 }
@@ -327,10 +370,10 @@ private CodeBlock doCreateQuery(boolean count, String queryVariableName,
327370 if (queryReturnType != null ) {
328371
329372 builder .addStatement ("$T $L = this.$L.createNativeQuery($L, $T.class)" , Query .class , queryVariableName ,
330- context .fieldNameOf (EntityManager .class ), queryStringNameVariableName , queryReturnType );
373+ context .fieldNameOf (EntityManager .class ), queryStringNameToUse , queryReturnType );
331374 } else {
332375 builder .addStatement ("$T $L = this.$L.createNativeQuery($L)" , Query .class , queryVariableName ,
333- context .fieldNameOf (EntityManager .class ), queryStringNameVariableName );
376+ context .fieldNameOf (EntityManager .class ), queryStringNameToUse );
334377 }
335378
336379 return builder .build ();
@@ -339,18 +382,18 @@ private CodeBlock doCreateQuery(boolean count, String queryVariableName,
339382 if (sq .hasConstructorExpressionOrDefaultProjection () && !count && returnedType .isProjecting ()
340383 && returnedType .getReturnedType ().isInterface ()) {
341384 builder .addStatement ("$T $L = this.$L.createQuery($L)" , Query .class , queryVariableName ,
342- context .fieldNameOf (EntityManager .class ), queryStringNameVariableName );
385+ context .fieldNameOf (EntityManager .class ), queryStringNameToUse );
343386 } else {
344387
345388 String createQueryMethod = query .isNative () ? "createNativeQuery" : "createQuery" ;
346389
347390 if (!sq .hasConstructorExpressionOrDefaultProjection () && !count && returnedType .isProjecting ()
348391 && returnedType .getReturnedType ().isInterface ()) {
349392 builder .addStatement ("$T $L = this.$L.$L($L, $T.class)" , Query .class , queryVariableName ,
350- context .fieldNameOf (EntityManager .class ), createQueryMethod , queryStringNameVariableName , Tuple .class );
393+ context .fieldNameOf (EntityManager .class ), createQueryMethod , queryStringNameToUse , Tuple .class );
351394 } else {
352395 builder .addStatement ("$T $L = this.$L.$L($L)" , Query .class , queryVariableName ,
353- context .fieldNameOf (EntityManager .class ), createQueryMethod , queryStringNameVariableName );
396+ context .fieldNameOf (EntityManager .class ), createQueryMethod , queryStringNameToUse );
354397 }
355398 }
356399
0 commit comments