Skip to content

Commit 4b7a47c

Browse files
authored
Allow user-defined function as first parameter of the any/contains/startsWith/endsWith built-in query string functions; add example to expose function defined in SQL (#1881)
1 parent 1e33bd4 commit 4b7a47c

File tree

20 files changed

+543
-63
lines changed

20 files changed

+543
-63
lines changed

src/Examples/DapperExample/TranslationToSql/Builders/SelectStatementBuilder.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -678,13 +678,13 @@ private CountNode GetCountClause(CountExpression expression, TableAccessorNode o
678678

679679
public override SqlTreeNode VisitMatchText(MatchTextExpression expression, TableAccessorNode tableAccessor)
680680
{
681-
var column = (ColumnNode)Visit(expression.TargetAttribute, tableAccessor);
681+
var column = (ColumnNode)Visit(expression.MatchTarget, tableAccessor);
682682
return new LikeNode(column, expression.MatchKind, (string)expression.TextValue.TypedValue);
683683
}
684684

685685
public override SqlTreeNode VisitAny(AnyExpression expression, TableAccessorNode tableAccessor)
686686
{
687-
var column = (ColumnNode)Visit(expression.TargetAttribute, tableAccessor);
687+
var column = (ColumnNode)Visit(expression.MatchTarget, tableAccessor);
688688

689689
ReadOnlyCollection<ParameterNode> parameters =
690690
VisitSequence<LiteralConstantExpression, ParameterNode>(expression.Constants.OrderBy(constant => constant.TypedValue), tableAccessor);

src/JsonApiDotNetCore/Queries/Expressions/AnyExpression.cs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,21 @@ namespace JsonApiDotNetCore.Queries.Expressions;
1717
public class AnyExpression : FilterExpression
1818
{
1919
/// <summary>
20-
/// The attribute whose value to compare. Chain format: an optional list of to-one relationships, followed by an attribute.
20+
/// The function or attribute whose value to compare. Attribute chain format: an optional list of to-one relationships, followed by an attribute.
2121
/// </summary>
22-
public ResourceFieldChainExpression TargetAttribute { get; }
22+
public QueryExpression MatchTarget { get; }
2323

2424
/// <summary>
2525
/// One or more constants to compare the attribute's value against.
2626
/// </summary>
2727
public IImmutableSet<LiteralConstantExpression> Constants { get; }
2828

29-
public AnyExpression(ResourceFieldChainExpression targetAttribute, IImmutableSet<LiteralConstantExpression> constants)
29+
public AnyExpression(QueryExpression matchTarget, IImmutableSet<LiteralConstantExpression> constants)
3030
{
31-
ArgumentNullException.ThrowIfNull(targetAttribute);
31+
ArgumentNullException.ThrowIfNull(matchTarget);
3232
ArgumentGuard.NotNullNorEmpty(constants);
3333

34-
TargetAttribute = targetAttribute;
34+
MatchTarget = matchTarget;
3535
Constants = constants;
3636
}
3737

@@ -56,7 +56,7 @@ private string InnerToString(bool toFullString)
5656

5757
builder.Append(Keywords.Any);
5858
builder.Append('(');
59-
builder.Append(toFullString ? TargetAttribute.ToFullString() : TargetAttribute.ToString());
59+
builder.Append(toFullString ? MatchTarget.ToFullString() : MatchTarget.ToString());
6060
builder.Append(',');
6161
builder.Append(string.Join(',', Constants.Select(constant => toFullString ? constant.ToFullString() : constant.ToString()).Order()));
6262
builder.Append(')');
@@ -78,13 +78,13 @@ public override bool Equals(object? obj)
7878

7979
var other = (AnyExpression)obj;
8080

81-
return TargetAttribute.Equals(other.TargetAttribute) && Constants.SetEquals(other.Constants);
81+
return MatchTarget.Equals(other.MatchTarget) && Constants.SetEquals(other.Constants);
8282
}
8383

8484
public override int GetHashCode()
8585
{
8686
var hashCode = new HashCode();
87-
hashCode.Add(TargetAttribute);
87+
hashCode.Add(MatchTarget);
8888

8989
foreach (LiteralConstantExpression constant in Constants)
9090
{

src/JsonApiDotNetCore/Queries/Expressions/MatchTextExpression.cs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ namespace JsonApiDotNetCore.Queries.Expressions;
2424
public class MatchTextExpression : FilterExpression
2525
{
2626
/// <summary>
27-
/// The attribute whose value to match. Chain format: an optional list of to-one relationships, followed by an attribute.
27+
/// The function or attribute whose value to match. Attribute chain format: an optional list of to-one relationships, followed by an attribute.
2828
/// </summary>
29-
public ResourceFieldChainExpression TargetAttribute { get; }
29+
public QueryExpression MatchTarget { get; }
3030

3131
/// <summary>
32-
/// The text to match the attribute's value against.
32+
/// The text to match against.
3333
/// </summary>
3434
public LiteralConstantExpression TextValue { get; }
3535

@@ -38,12 +38,12 @@ public class MatchTextExpression : FilterExpression
3838
/// </summary>
3939
public TextMatchKind MatchKind { get; }
4040

41-
public MatchTextExpression(ResourceFieldChainExpression targetAttribute, LiteralConstantExpression textValue, TextMatchKind matchKind)
41+
public MatchTextExpression(QueryExpression matchTarget, LiteralConstantExpression textValue, TextMatchKind matchKind)
4242
{
43-
ArgumentNullException.ThrowIfNull(targetAttribute);
43+
ArgumentNullException.ThrowIfNull(matchTarget);
4444
ArgumentNullException.ThrowIfNull(textValue);
4545

46-
TargetAttribute = targetAttribute;
46+
MatchTarget = matchTarget;
4747
TextValue = textValue;
4848
MatchKind = matchKind;
4949
}
@@ -71,8 +71,8 @@ private string InnerToString(bool toFullString)
7171
builder.Append('(');
7272

7373
builder.Append(toFullString
74-
? string.Join(',', TargetAttribute.ToFullString(), TextValue.ToFullString())
75-
: string.Join(',', TargetAttribute.ToString(), TextValue.ToString()));
74+
? string.Join(',', MatchTarget.ToFullString(), TextValue.ToFullString())
75+
: string.Join(',', MatchTarget.ToString(), TextValue.ToString()));
7676

7777
builder.Append(')');
7878

@@ -93,11 +93,11 @@ public override bool Equals(object? obj)
9393

9494
var other = (MatchTextExpression)obj;
9595

96-
return TargetAttribute.Equals(other.TargetAttribute) && TextValue.Equals(other.TextValue) && MatchKind == other.MatchKind;
96+
return MatchTarget.Equals(other.MatchTarget) && TextValue.Equals(other.TextValue) && MatchKind == other.MatchKind;
9797
}
9898

9999
public override int GetHashCode()
100100
{
101-
return HashCode.Combine(TargetAttribute, TextValue, MatchKind);
101+
return HashCode.Combine(MatchTarget, TextValue, MatchKind);
102102
}
103103
}

src/JsonApiDotNetCore/Queries/Expressions/QueryExpressionRewriter.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,12 +149,12 @@ public override QueryExpression VisitPagination(PaginationExpression expression,
149149

150150
public override QueryExpression? VisitMatchText(MatchTextExpression expression, TArgument argument)
151151
{
152-
var newTargetAttribute = Visit(expression.TargetAttribute, argument) as ResourceFieldChainExpression;
152+
var newMatchTarget = Visit(expression.MatchTarget, argument) as ResourceFieldChainExpression;
153153
var newTextValue = Visit(expression.TextValue, argument) as LiteralConstantExpression;
154154

155-
if (newTargetAttribute != null && newTextValue != null)
155+
if (newMatchTarget != null && newTextValue != null)
156156
{
157-
var newExpression = new MatchTextExpression(newTargetAttribute, newTextValue, expression.MatchKind);
157+
var newExpression = new MatchTextExpression(newMatchTarget, newTextValue, expression.MatchKind);
158158
return newExpression.Equals(expression) ? expression : newExpression;
159159
}
160160

@@ -163,12 +163,12 @@ public override QueryExpression VisitPagination(PaginationExpression expression,
163163

164164
public override QueryExpression? VisitAny(AnyExpression expression, TArgument argument)
165165
{
166-
var newTargetAttribute = Visit(expression.TargetAttribute, argument) as ResourceFieldChainExpression;
166+
var newMatchTarget = Visit(expression.MatchTarget, argument) as ResourceFieldChainExpression;
167167
IImmutableSet<LiteralConstantExpression> newConstants = VisitSet(expression.Constants, argument);
168168

169-
if (newTargetAttribute != null)
169+
if (newMatchTarget != null)
170170
{
171-
var newExpression = new AnyExpression(newTargetAttribute, newConstants);
171+
var newExpression = new AnyExpression(newMatchTarget, newConstants);
172172
return newExpression.Equals(expression) ? expression : newExpression;
173173
}
174174

src/JsonApiDotNetCore/Queries/Parsing/FilterParser.cs

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -320,10 +320,37 @@ protected virtual MatchTextExpression ParseTextMatch(string operatorName)
320320
EatText(operatorName);
321321
EatSingleCharacterToken(TokenKind.OpenParen);
322322

323+
QueryExpression matchTarget = ParseTextMatchLeftTerm();
324+
325+
EatSingleCharacterToken(TokenKind.Comma);
326+
327+
ConstantValueConverter constantValueConverter = GetConstantValueConverterForType(typeof(string));
328+
LiteralConstantExpression constant = ParseConstant(constantValueConverter);
329+
330+
EatSingleCharacterToken(TokenKind.CloseParen);
331+
332+
var matchKind = Enum.Parse<TextMatchKind>(operatorName.Pascalize());
333+
return new MatchTextExpression(matchTarget, constant, matchKind);
334+
}
335+
336+
private QueryExpression ParseTextMatchLeftTerm()
337+
{
338+
if (TokenStack.TryPeek(out Token? nextToken) && nextToken is { Kind: TokenKind.Text } && IsFunction(nextToken.Value!))
339+
{
340+
FunctionExpression targetFunction = ParseFunction();
341+
342+
if (targetFunction.ReturnType != typeof(string))
343+
{
344+
throw new QueryParseException("Function that returns type 'String' expected.", nextToken.Position);
345+
}
346+
347+
return targetFunction;
348+
}
349+
323350
int chainStartPosition = GetNextTokenPositionOrEnd();
324351

325-
ResourceFieldChainExpression targetAttributeChain =
326-
ParseFieldChain(BuiltInPatterns.ToOneChainEndingInAttribute, FieldChainPatternMatchOptions.None, ResourceTypeInScope, null);
352+
ResourceFieldChainExpression targetAttributeChain = ParseFieldChain(BuiltInPatterns.ToOneChainEndingInAttribute, FieldChainPatternMatchOptions.None,
353+
ResourceTypeInScope, null);
327354

328355
var targetAttribute = (AttrAttribute)targetAttributeChain.Fields[^1];
329356

@@ -333,32 +360,21 @@ protected virtual MatchTextExpression ParseTextMatch(string operatorName)
333360
throw new QueryParseException("Attribute of type 'String' expected.", position);
334361
}
335362

336-
EatSingleCharacterToken(TokenKind.Comma);
337-
338-
ConstantValueConverter constantValueConverter = GetConstantValueConverterForAttribute(targetAttribute);
339-
LiteralConstantExpression constant = ParseConstant(constantValueConverter);
340-
341-
EatSingleCharacterToken(TokenKind.CloseParen);
342-
343-
var matchKind = Enum.Parse<TextMatchKind>(operatorName.Pascalize());
344-
return new MatchTextExpression(targetAttributeChain, constant, matchKind);
363+
return targetAttributeChain;
345364
}
346365

347366
protected virtual AnyExpression ParseAny()
348367
{
349368
EatText(Keywords.Any);
350369
EatSingleCharacterToken(TokenKind.OpenParen);
351370

352-
ResourceFieldChainExpression targetAttributeChain =
353-
ParseFieldChain(BuiltInPatterns.ToOneChainEndingInAttribute, FieldChainPatternMatchOptions.None, ResourceTypeInScope, null);
354-
355-
var targetAttribute = (AttrAttribute)targetAttributeChain.Fields[^1];
371+
(QueryExpression matchTarget, Func<ConstantValueConverter> constantValueConverterFactory) = ParseAnyLeftTerm();
356372

357373
EatSingleCharacterToken(TokenKind.Comma);
358374

359375
ImmutableHashSet<LiteralConstantExpression>.Builder constantsBuilder = ImmutableHashSet.CreateBuilder<LiteralConstantExpression>();
360376

361-
ConstantValueConverter constantValueConverter = GetConstantValueConverterForAttribute(targetAttribute);
377+
ConstantValueConverter constantValueConverter = constantValueConverterFactory();
362378
LiteralConstantExpression constant = ParseConstant(constantValueConverter);
363379
constantsBuilder.Add(constant);
364380

@@ -374,7 +390,26 @@ protected virtual AnyExpression ParseAny()
374390

375391
IImmutableSet<LiteralConstantExpression> constantSet = constantsBuilder.ToImmutable();
376392

377-
return new AnyExpression(targetAttributeChain, constantSet);
393+
return new AnyExpression(matchTarget, constantSet);
394+
}
395+
396+
private (QueryExpression matchTarget, Func<ConstantValueConverter> constantValueConverterFactory) ParseAnyLeftTerm()
397+
{
398+
if (TokenStack.TryPeek(out Token? nextToken) && nextToken is { Kind: TokenKind.Text } && IsFunction(nextToken.Value!))
399+
{
400+
FunctionExpression targetFunction = ParseFunction();
401+
402+
Func<ConstantValueConverter> functionConverterFactory = () => GetConstantValueConverterForType(targetFunction.ReturnType);
403+
return (targetFunction, functionConverterFactory);
404+
}
405+
406+
ResourceFieldChainExpression targetAttributeChain =
407+
ParseFieldChain(BuiltInPatterns.ToOneChainEndingInAttribute, FieldChainPatternMatchOptions.None, ResourceTypeInScope, null);
408+
409+
var targetAttribute = (AttrAttribute)targetAttributeChain.Fields[^1];
410+
411+
Func<ConstantValueConverter> attributeConverterFactory = () => GetConstantValueConverterForAttribute(targetAttribute);
412+
return (targetAttributeChain, attributeConverterFactory);
378413
}
379414

380415
protected virtual HasExpression ParseHas()

src/JsonApiDotNetCore/Queries/QueryableBuilding/WhereClauseBuilder.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ public override Expression VisitIsType(IsTypeExpression expression, QueryClauseB
9090

9191
public override Expression VisitMatchText(MatchTextExpression expression, QueryClauseBuilderContext context)
9292
{
93-
Expression property = Visit(expression.TargetAttribute, context);
93+
Expression property = Visit(expression.MatchTarget, context);
9494

9595
if (property.Type != typeof(string))
9696
{
@@ -109,7 +109,7 @@ public override Expression VisitMatchText(MatchTextExpression expression, QueryC
109109

110110
public override Expression VisitAny(AnyExpression expression, QueryClauseBuilderContext context)
111111
{
112-
Expression property = Visit(expression.TargetAttribute, context);
112+
Expression property = Visit(expression.MatchTarget, context);
113113

114114
var valueList = (IList)Activator.CreateInstance(typeof(List<>).MakeGenericType(property.Type))!;
115115

test/JsonApiDotNetCoreTests/IntegrationTests/CompositeKeys/CarExpressionRewriter.cs

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,24 +50,30 @@ public CarExpressionRewriter(IResourceGraph resourceGraph)
5050

5151
public override QueryExpression? VisitAny(AnyExpression expression, object? argument)
5252
{
53-
PropertyInfo property = expression.TargetAttribute.Fields[^1].Property;
54-
55-
if (IsCarId(property))
53+
if (expression.MatchTarget is ResourceFieldChainExpression targetAttributeChain)
5654
{
57-
string[] carStringIds = expression.Constants.Select(constant => (string)constant.TypedValue).ToArray();
58-
return RewriteFilterOnCarStringIds(expression.TargetAttribute, carStringIds);
55+
PropertyInfo property = targetAttributeChain.Fields[^1].Property;
56+
57+
if (IsCarId(property))
58+
{
59+
string[] carStringIds = expression.Constants.Select(constant => (string)constant.TypedValue).ToArray();
60+
return RewriteFilterOnCarStringIds(targetAttributeChain, carStringIds);
61+
}
5962
}
6063

6164
return base.VisitAny(expression, argument);
6265
}
6366

6467
public override QueryExpression? VisitMatchText(MatchTextExpression expression, object? argument)
6568
{
66-
PropertyInfo property = expression.TargetAttribute.Fields[^1].Property;
67-
68-
if (IsCarId(property))
69+
if (expression.MatchTarget is ResourceFieldChainExpression targetAttributeChain)
6970
{
70-
throw new NotSupportedException("Partial text matching on Car IDs is not possible.");
71+
PropertyInfo property = targetAttributeChain.Fields[^1].Property;
72+
73+
if (IsCarId(property))
74+
{
75+
throw new NotSupportedException("Partial text matching on Car IDs is not possible.");
76+
}
7177
}
7278

7379
return base.VisitMatchText(expression, argument);
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using System.Reflection;
2+
3+
#pragma warning disable AV1008 // Class should not be static
4+
5+
namespace JsonApiDotNetCoreTests.IntegrationTests.QueryStrings.CustomFunctions.Decrypt;
6+
7+
internal static class DatabaseFunctionStub
8+
{
9+
public static readonly MethodInfo DecryptMethod = typeof(DatabaseFunctionStub).GetMethod(nameof(Decrypt), [typeof(string)])!;
10+
11+
public static string Decrypt(string text)
12+
{
13+
_ = text;
14+
throw new InvalidOperationException($"The '{nameof(Decrypt)}' user-defined SQL function cannot be called client-side.");
15+
}
16+
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using JetBrains.Annotations;
2+
using Microsoft.EntityFrameworkCore;
3+
using TestBuildingBlocks;
4+
5+
// @formatter:wrap_chained_method_calls chop_always
6+
7+
namespace JsonApiDotNetCoreTests.IntegrationTests.QueryStrings.CustomFunctions.Decrypt;
8+
9+
[UsedImplicitly(ImplicitUseTargetFlags.Members)]
10+
public sealed class DecryptDbContext(DbContextOptions<DecryptDbContext> options)
11+
: TestableDbContext(options)
12+
{
13+
public DbSet<Blog> Blogs => Set<Blog>();
14+
15+
protected override void OnModelCreating(ModelBuilder builder)
16+
{
17+
QueryStringDbContext.ConfigureModel(builder);
18+
19+
builder.HasDbFunction(DatabaseFunctionStub.DecryptMethod)
20+
.HasName("decrypt_column_value");
21+
22+
base.OnModelCreating(builder);
23+
}
24+
25+
internal async Task DeclareDecryptFunctionAsync()
26+
{
27+
// Just for demo purposes, decryption is defined as: base64-decode the incoming value.
28+
await Database.ExecuteSqlRawAsync("""
29+
CREATE OR REPLACE FUNCTION decrypt_column_value(value text)
30+
RETURNS text
31+
RETURN encode(decode(value, 'base64'), 'escape');
32+
""");
33+
}
34+
}

0 commit comments

Comments
 (0)