Skip to content

Commit c8073a5

Browse files
authored
Merge pull request #197 from EFNext/fix-generic-extension-methods
Fix generic extension methods
2 parents 9b4e459 + 894ff2e commit c8073a5

14 files changed

+358
-35
lines changed

src/EntityFrameworkCore.Projectables.Generator/Interpretation/ProjectableInterpreter.BodyProcessors.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,15 @@ private static bool TryApplyMethodBody(
6464
ApplyParameterList(methodDeclarationSyntax.ParameterList, declarationSyntaxRewriter, descriptor);
6565
ApplyTypeParameters(methodDeclarationSyntax, declarationSyntaxRewriter, descriptor);
6666

67+
// For C# 14 generic extension blocks (e.g. extension<T>(Wrapper<T> w)), the block-level
68+
// type parameter T is on the extension type, not on the method declaration syntax.
69+
// ApplyTypeParameters() therefore finds nothing; promote the extension-block type
70+
// parameters to method-level type parameters when no syntax-level ones were found.
71+
if (descriptor.TypeParameterList is null)
72+
{
73+
ApplyExtensionBlockTypeParameters(memberSymbol, descriptor);
74+
}
75+
6776
return true;
6877
}
6978

src/EntityFrameworkCore.Projectables.Generator/Interpretation/ProjectableInterpreter.Helpers.cs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,90 @@ private static void ApplyTypeParameters(
5050
}
5151
}
5252

53+
/// <summary>
54+
/// For C# 14 generic extension blocks (e.g. <c>extension&lt;T&gt;(Wrapper&lt;T&gt; w)</c>),
55+
/// the block-level type parameter <c>T</c> is owned by the extension type, not by the
56+
/// method declaration syntax. <see cref="ApplyTypeParameters"/> therefore finds no
57+
/// <c>TypeParameterList</c> on the method and produces nothing.
58+
/// <para>
59+
/// This helper promotes those extension-block type parameters to method-level type
60+
/// parameters on <paramref name="descriptor"/> so the generated
61+
/// <c>Expression&lt;T&gt;()</c> factory method is correctly generic.
62+
/// It is a no-op when the containing type is not a generic extension block.
63+
/// </para>
64+
/// </summary>
65+
private static void ApplyExtensionBlockTypeParameters(
66+
ISymbol memberSymbol,
67+
ProjectableDescriptor descriptor)
68+
{
69+
if (memberSymbol.ContainingType is not { IsExtension: true } extensionType
70+
|| extensionType.TypeParameters.IsDefaultOrEmpty)
71+
{
72+
return;
73+
}
74+
75+
descriptor.TypeParameterList = SyntaxFactory.TypeParameterList();
76+
77+
foreach (var tp in extensionType.TypeParameters)
78+
{
79+
descriptor.TypeParameterList = descriptor.TypeParameterList.AddParameters(
80+
SyntaxFactory.TypeParameter(tp.Name));
81+
82+
// Build the constraint clause when any constraint is present.
83+
var hasAnyConstraint =
84+
tp.HasReferenceTypeConstraint
85+
|| tp.HasValueTypeConstraint
86+
|| tp.HasNotNullConstraint
87+
|| !tp.ConstraintTypes.IsDefaultOrEmpty
88+
|| tp.HasConstructorConstraint;
89+
90+
if (!hasAnyConstraint)
91+
{
92+
continue;
93+
}
94+
95+
descriptor.ConstraintClauses ??= SyntaxFactory.List<TypeParameterConstraintClauseSyntax>();
96+
descriptor.ConstraintClauses = descriptor.ConstraintClauses.Value.Add(BuildConstraintClause(tp));
97+
}
98+
}
99+
100+
/// <summary>
101+
/// Builds a <see cref="TypeParameterConstraintClauseSyntax"/> for <paramref name="tp"/>
102+
/// by collecting all of its constraints in canonical order:
103+
/// <c>class</c> / <c>struct</c> / <c>notnull</c>, explicit type constraints, then <c>new()</c>.
104+
/// </summary>
105+
private static TypeParameterConstraintClauseSyntax BuildConstraintClause(ITypeParameterSymbol tp)
106+
{
107+
var constraints = new List<TypeConstraintSyntax>();
108+
109+
if (tp.HasReferenceTypeConstraint)
110+
{
111+
constraints.Add(MakeTypeConstraint("class"));
112+
}
113+
114+
if (tp.HasValueTypeConstraint)
115+
{
116+
constraints.Add(MakeTypeConstraint("struct"));
117+
}
118+
119+
if (tp.HasNotNullConstraint)
120+
{
121+
constraints.Add(MakeTypeConstraint("notnull"));
122+
}
123+
124+
constraints.AddRange(tp.ConstraintTypes
125+
.Select(c => MakeTypeConstraint(c.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))));
126+
127+
if (tp.HasConstructorConstraint)
128+
{
129+
constraints.Add(MakeTypeConstraint("new()"));
130+
}
131+
132+
return SyntaxFactory.TypeParameterConstraintClause(
133+
SyntaxFactory.IdentifierName(tp.Name),
134+
SyntaxFactory.SeparatedList<TypeParameterConstraintSyntax>(constraints));
135+
}
136+
53137
/// <summary>
54138
/// Returns the readable getter expression from a property declaration, trying in order:
55139
/// the property-level expression-body, the getter's expression-body, then the first

src/EntityFrameworkCore.Projectables.Generator/Interpretation/ProjectableInterpreter.cs

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -261,36 +261,7 @@ private static void SetupGenericTypeParameters(ProjectableDescriptor descriptor,
261261
}
262262

263263
descriptor.ClassConstraintClauses ??= SyntaxFactory.List<TypeParameterConstraintClauseSyntax>();
264-
265-
var constraints = new List<TypeConstraintSyntax>();
266-
267-
if (tp.HasReferenceTypeConstraint)
268-
{
269-
constraints.Add(MakeTypeConstraint("class"));
270-
}
271-
272-
if (tp.HasValueTypeConstraint)
273-
{
274-
constraints.Add(MakeTypeConstraint("struct"));
275-
}
276-
277-
if (tp.HasNotNullConstraint)
278-
{
279-
constraints.Add(MakeTypeConstraint("notnull"));
280-
}
281-
282-
constraints.AddRange(tp.ConstraintTypes
283-
.Select(c => MakeTypeConstraint(c.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))));
284-
285-
if (tp.HasConstructorConstraint)
286-
{
287-
constraints.Add(MakeTypeConstraint("new()"));
288-
}
289-
290-
descriptor.ClassConstraintClauses = descriptor.ClassConstraintClauses.Value.Add(
291-
SyntaxFactory.TypeParameterConstraintClause(
292-
SyntaxFactory.IdentifierName(tp.Name),
293-
SyntaxFactory.SeparatedList<TypeParameterConstraintSyntax>(constraints)));
264+
descriptor.ClassConstraintClauses = descriptor.ClassConstraintClauses.Value.Add(BuildConstraintClause(tp));
294265
}
295266
}
296267

src/EntityFrameworkCore.Projectables/Services/ProjectionExpressionClassNameGenerator.cs

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,19 +112,37 @@ static string GenerateNameImpl(StringBuilder stringBuilder, string? namespaceNam
112112
}
113113

114114
/// <summary>
115-
/// Appends <paramref name="typeName"/> to <paramref name="sb"/>, stripping the
116-
/// <c>global::</c> prefix and replacing every character that is invalid in a C# identifier
117-
/// with <c>'_'</c> — all in a single pass with no intermediate string allocations.
115+
/// Appends <paramref name="typeName"/> to <paramref name="sb"/>, stripping every
116+
/// <c>global::</c> occurrence (leading and those inside generic type argument lists)
117+
/// and replacing every character that is invalid in a C# identifier with <c>'_'</c>.
118+
/// <para>
119+
/// The multi-occurrence stripping is necessary so that fully-qualified generic types
120+
/// such as <c>global::Foo.Wrapper&lt;global::Foo.Entity&gt;</c> — produced by Roslyn's
121+
/// <c>FullyQualifiedFormat</c> — yield the same sanitised name as the runtime resolver,
122+
/// which never includes <c>global::</c>.
123+
/// </para>
118124
/// </summary>
119125
private static void AppendSanitizedTypeName(StringBuilder sb, string typeName)
120126
{
121127
const string GlobalPrefix = "global::";
122-
var start = typeName.StartsWith(GlobalPrefix, StringComparison.Ordinal) ? GlobalPrefix.Length : 0;
128+
const int PrefixLength = 8; // "global::".Length
123129

124-
for (var i = start; i < typeName.Length; i++)
130+
var i = 0;
131+
while (i < typeName.Length)
125132
{
133+
// Skip every "global::" occurrence — both the leading prefix and any that
134+
// appear inside generic type argument lists (e.g. "Wrapper<global::Inner>").
135+
if (typeName[i] == 'g'
136+
&& i + PrefixLength <= typeName.Length
137+
&& string.CompareOrdinal(typeName, i, GlobalPrefix, 0, PrefixLength) == 0)
138+
{
139+
i += PrefixLength;
140+
continue;
141+
}
142+
126143
var c = typeName[i];
127144
sb.Append(IsInvalidIdentifierChar(c) ? '_' : c);
145+
i++;
128146
}
129147
}
130148

tests/EntityFrameworkCore.Projectables.FunctionalTests/ExtensionMembers/EntityExtensions.cs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,35 @@ public static class EntityExtensions
2525
}
2626
}
2727

28+
/// <summary>
29+
/// Extension on a closed generic receiver type: <c>extension(GenericWrapper&lt;Entity&gt; w)</c>.
30+
/// Tests the fix for the bug where <c>global::</c> inside generic type arguments caused a
31+
/// name mismatch between the generated class and the runtime resolver.
32+
/// </summary>
33+
public static class ClosedGenericWrapperExtensions
34+
{
35+
extension(GenericWrapper<Entity> w)
36+
{
37+
[Projectable]
38+
public int DoubleId() => w.Id * 2;
39+
}
40+
}
41+
42+
/// <summary>
43+
/// Extension on an open generic receiver type: <c>extension&lt;T&gt;(GenericWrapper&lt;T&gt; w)</c>.
44+
/// The block-level type parameter <c>T</c> becomes a method-level type parameter on the
45+
/// generated <c>Expression&lt;T&gt;()</c> factory, resolved at runtime via generic method
46+
/// reflection.
47+
/// </summary>
48+
public static class OpenGenericWrapperExtensions
49+
{
50+
extension<T>(GenericWrapper<T> w) where T : class
51+
{
52+
[Projectable]
53+
public int TripleId() => w.Id * 3;
54+
}
55+
}
56+
2857
public static class IntExtensions
2958
{
3059
extension(int i)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
SELECT [e].[Id] * 2
2+
FROM [Entity] AS [e]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
SELECT [e].[Id] * 3
2+
FROM [Entity] AS [e]

tests/EntityFrameworkCore.Projectables.FunctionalTests/ExtensionMembers/ExtensionMemberTests.cs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,43 @@ public Task ExtensionMemberMethodWithParameterOnEntity()
3737

3838
return Verifier.Verify(query.ToQueryString());
3939
}
40+
41+
/// <summary>
42+
/// Regression test: extension member on a <em>closed</em> generic receiver type
43+
/// (e.g. <c>extension(GenericWrapper&lt;Entity&gt; w)</c>) previously threw
44+
/// "Unable to resolve generated expression" because <c>global::</c> inside generic
45+
/// type arguments caused a naming mismatch between the generator and the resolver.
46+
/// </summary>
47+
[Fact]
48+
public Task ExtensionMemberMethodOnClosedGenericReceiverType()
49+
{
50+
using var dbContext = new SampleDbContext<Entity>();
51+
52+
var query = dbContext.Set<Entity>()
53+
.Select(x => new GenericWrapper<Entity> { Id = x.Id })
54+
.Select(x => x.DoubleId());
55+
56+
return Verifier.Verify(query.ToQueryString());
57+
}
58+
59+
/// <summary>
60+
/// Exercises support for extension members on an <em>open</em> generic receiver type
61+
/// (e.g. <c>extension&lt;T&gt;(GenericWrapper&lt;T&gt; w)</c>).
62+
/// The block-level type parameter <c>T</c> must be promoted to a method-level type
63+
/// parameter on the generated <c>Expression&lt;T&gt;()</c> factory so the runtime
64+
/// resolver can construct the correct closed-generic expression.
65+
/// </summary>
66+
[Fact]
67+
public Task ExtensionMemberMethodOnOpenGenericReceiverType()
68+
{
69+
using var dbContext = new SampleDbContext<Entity>();
70+
71+
var query = dbContext.Set<Entity>()
72+
.Select(x => new GenericWrapper<Entity> { Id = x.Id })
73+
.Select(x => x.TripleId());
74+
75+
return Verifier.Verify(query.ToQueryString());
76+
}
4077
}
4178
}
4279
#endif
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#if NET10_0_OR_GREATER
2+
namespace EntityFrameworkCore.Projectables.FunctionalTests.ExtensionMembers
3+
{
4+
public class GenericWrapper<T>
5+
{
6+
public int Id { get; set; }
7+
}
8+
}
9+
#endif
10+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// <auto-generated/>
2+
#nullable disable
3+
using System;
4+
using EntityFrameworkCore.Projectables;
5+
using Foo;
6+
7+
namespace EntityFrameworkCore.Projectables.Generated
8+
{
9+
[global::System.ComponentModel.EditorBrowsable(global::System.ComponentModel.EditorBrowsableState.Never)]
10+
static class Foo_WrapperExtensions_DoubleId_P0_Foo_Wrapper_Foo_Entity_
11+
{
12+
static global::System.Linq.Expressions.Expression<global::System.Func<global::Foo.Wrapper<global::Foo.Entity>, int>> Expression()
13+
{
14+
return (global::Foo.Wrapper<global::Foo.Entity> @this) => @this.Value.Id * 2;
15+
}
16+
}
17+
}

0 commit comments

Comments
 (0)