Skip to content

Commit 57ca935

Browse files
authored
Generate orchestrator interface (#3)
* - Implement generating interface file. * Implement generating inhering from generated interface. * Minor refactoring to remove passing SemanticModel where it is not used. * Downgrade main project from .NET Standard 2.1 to .NET Standard 2.0 to avoid warning that targeting other frameworks for source generators has unpredicable behavior. * Remove unnecessary newline.
1 parent c3f30a6 commit 57ca935

6 files changed

Lines changed: 114 additions & 68 deletions

File tree

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// <auto-generated/>
2+
#nullable restore
3+
4+
namespace TestLibrary;
5+
6+
internal interface IOrchestrator
7+
{
8+
public Task<int> Execute();
9+
}

AsyncTaskOrchestratorGenerator.UnitTests/Orchestrator.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
namespace TestLibrary;
88

9-
internal class Orchestrator
9+
internal class Orchestrator : IOrchestrator
1010
{
1111
private readonly TestLibrary.A a;
1212
private readonly TestLibrary.B b;

AsyncTaskOrchestratorGenerator.UnitTests/Tests.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ public void Setup() {
1919
[Test]
2020
public async Task OneInterface() {
2121
var source = await ReadCSharpFile<OrchestratorSpec>(true);
22-
var generated = await ReadCSharpFile<Orchestrator>(false);
22+
var generatedClass = await ReadCSharpFile<Orchestrator>(false);
23+
var generatedInterface = await ReadCSharpFile<IOrchestrator>(false);
2324

2425
await new VerifyCS.Test
2526
{
@@ -35,7 +36,8 @@ public async Task OneInterface() {
3536
Sources = { source },
3637
GeneratedSources =
3738
{
38-
(typeof(Main), "Orchestrator.generated.cs", SourceText.From(generated, Encoding.UTF8, SourceHashAlgorithm.Sha256)),
39+
(typeof(Main), "Orchestrator.generated.cs", SourceText.From(generatedClass, Encoding.UTF8, SourceHashAlgorithm.Sha256)),
40+
(typeof(Main), "IOrchestrator.generated.cs", SourceText.From(generatedInterface, Encoding.UTF8, SourceHashAlgorithm.Sha256)),
3941
},
4042
},
4143
}.RunAsync();

AsyncTaskOrchestratorGenerator/AsyncTaskOrchestratorGenerator.csproj

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
<Project Sdk="Microsoft.NET.Sdk">
1+
<Project Sdk="Microsoft.NET.Sdk">
22

33
<PropertyGroup>
4-
<TargetFramework>netstandard2.1</TargetFramework>
4+
<TargetFramework>netstandard2.0</TargetFramework>
55
<IsRoslynComponent>true</IsRoslynComponent>
66
<GeneratePackageOnBuild>True</GeneratePackageOnBuild>
77
<Title>Async Task Orchestrator Generator</Title>

AsyncTaskOrchestratorGenerator/Main.cs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
using Microsoft.CodeAnalysis;
22
using Microsoft.CodeAnalysis.CSharp.Syntax;
33
using Microsoft.CodeAnalysis.Text;
4-
using System;
5-
using System.Collections.Generic;
6-
using System.Linq;
74
using System.Text;
85
using System.Threading;
96

@@ -26,15 +23,16 @@ private static bool IsSyntaxTargetForGeneration(SyntaxNode syntaxNode, Cancellat
2623
return syntaxNode is TypeDeclarationSyntax;
2724
}
2825

29-
private static (INamedTypeSymbol, SemanticModel) GetSemanticTargetForGeneration(GeneratorAttributeSyntaxContext context, CancellationToken cancellationToken) {
30-
31-
return (context.TargetSymbol as INamedTypeSymbol, context.SemanticModel);
26+
private static INamedTypeSymbol GetSemanticTargetForGeneration(GeneratorAttributeSyntaxContext context, CancellationToken cancellationToken) {
27+
return context.TargetSymbol as INamedTypeSymbol;
3228
}
3329

34-
private static void Execute(SourceProductionContext context, (INamedTypeSymbol typeSymbol, SemanticModel semanticModel) typeInfo) {
35-
var (source, className) = OutputGenerator.GenerateOutputs(typeInfo);
30+
private static void Execute(SourceProductionContext context, INamedTypeSymbol typeSymbol) {
31+
var (classSource, className) = OutputGenerator.GenerateClassOutputs(typeSymbol);
32+
var (interfaceSource, interfaceName) = OutputGenerator.GenerateInterfaceOutputs(typeSymbol);
3633

37-
context.AddSource($"{className}.generated.cs", SourceText.From(source, Encoding.UTF8, SourceHashAlgorithm.Sha256));
34+
context.AddSource($"{className}.generated.cs", SourceText.From(classSource, Encoding.UTF8, SourceHashAlgorithm.Sha256));
35+
context.AddSource($"{interfaceName}.generated.cs", SourceText.From(interfaceSource, Encoding.UTF8, SourceHashAlgorithm.Sha256));
3836
}
3937
}
4038
}

AsyncTaskOrchestratorGenerator/OutputGenerator.cs

Lines changed: 91 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,11 @@ namespace AsyncTaskOrchestratorGenerator
88
{
99
internal static class OutputGenerator
1010
{
11-
public static (string source, string className) GenerateOutputs((INamedTypeSymbol typeSymbol, SemanticModel semanticModel) typeInfo) {
12-
var type = typeInfo.typeSymbol;
13-
var semanticModel = typeInfo.semanticModel;
14-
15-
var constructorArguments = type.GetAttributes().First((a) => a.AttributeClass.Name == nameof(AsyncTaskOrchestratorAttribute)).ConstructorArguments;
11+
public static (string source, string className) GenerateClassOutputs(INamedTypeSymbol type) {
12+
var constructorArguments = GetAttributeConstructorArguments(type);
1613
var className = constructorArguments.First().Value.ToString();
1714
var executeMethodName = constructorArguments.ElementAt(1).Value.ToString();
15+
var interfaceName = $"I{className}";
1816

1917
var accessModifier = type.DeclaredAccessibility.ToString().ToLower();
2018
var typeMembers = type.GetMembers();
@@ -36,7 +34,7 @@ public static (string source, string className) GenerateOutputs((INamedTypeSymbo
3634
3735
namespace {type.ContainingNamespace.ToDisplayString()};
3836
39-
{accessModifier} class {className}
37+
{accessModifier} class {className} : {interfaceName}
4038
{{
4139
{string.Join(@"
4240
", formattedFields)}
@@ -50,37 +48,40 @@ namespace {type.ContainingNamespace.ToDisplayString()};
5048
return (source, className);
5149
}
5250

53-
private static (ExecuteMethodSignatureData, Dictionary<string, TaskData>, TaskData) CreateExecuteMethodData(INamedTypeSymbol type, IEnumerable<IFieldSymbol> fields, string executeMethodName) {
54-
var executeMethod = type
55-
.GetMembers()
56-
.Where(m => m is IMethodSymbol)
57-
.First(m => (m as IMethodSymbol).MethodKind == MethodKind.Ordinary) as IMethodSymbol;
58-
var statements = (executeMethod.DeclaringSyntaxReferences.First().GetSyntax() as MethodDeclarationSyntax).Body.Statements;
59-
var variableStatements = statements.Remove(statements.Last());
51+
public static (string source, string interfaceName) GenerateInterfaceOutputs(INamedTypeSymbol type) {
52+
var constructorArguments = GetAttributeConstructorArguments(type);
53+
var className = constructorArguments.First().Value.ToString();
54+
var executeMethodName = constructorArguments.ElementAt(1).Value.ToString();
55+
var interfaceName = $"I{className}";
56+
57+
var accessModifier = type.DeclaredAccessibility.ToString().ToLower();
58+
var executeMethod = GetExecuteMethod(type);
59+
var executeMethodAccessibility = executeMethod.DeclaredAccessibility.ToString().ToLower();
60+
var formattedExecuteMethod = $"{executeMethodAccessibility} {executeMethod.ReturnType} {executeMethodName}();";
6061

61-
var variableData = variableStatements
62-
.Select(s => s as LocalDeclarationStatementSyntax)
63-
.SelectMany(v => v.Declaration.Variables)
64-
.Select(declarationSyntax => {
65-
var invocation = declarationSyntax.Initializer.Value as InvocationExpressionSyntax;
66-
var methodAccessExpression = invocation.Expression as MemberAccessExpressionSyntax;
67-
var methodCallTypeName = methodAccessExpression.ToString().Split('.').First();
68-
var methodCallType = fields.First(f => f.Name == methodCallTypeName).Type;
69-
var methodCallName = methodAccessExpression.ToString().Split('.').Last();
70-
var methodSymbol = methodCallType.GetMembers(methodCallName).First() as IMethodSymbol;
62+
var source =
63+
$@"// <auto-generated/>
64+
#nullable restore
7165
72-
var arguments = invocation.ArgumentList.Arguments;
73-
var argumentTypeNames = arguments.Select(a => a.ToString().Split('.').First());
66+
namespace {type.ContainingNamespace.ToDisplayString()};
7467
75-
return new TaskData
76-
{
77-
OutputName = declarationSyntax.Identifier.Text,
78-
MethodCallName = methodAccessExpression.ToString(),
79-
MethodCallReturnType = methodSymbol.ReturnType.ToString(),
80-
DependenciesOutputNames = argumentTypeNames,
81-
TaskName = $"{declarationSyntax.Identifier.Text}Task"
82-
};
83-
});
68+
{accessModifier} interface {interfaceName}
69+
{{
70+
{formattedExecuteMethod}
71+
}}
72+
";
73+
return (source, interfaceName);
74+
}
75+
76+
private static System.Collections.Immutable.ImmutableArray<TypedConstant> GetAttributeConstructorArguments(INamedTypeSymbol type) {
77+
return type.GetAttributes().First((a) => a.AttributeClass.Name == nameof(AsyncTaskOrchestratorAttribute)).ConstructorArguments;
78+
}
79+
80+
private static (ExecuteMethodSignatureData, Dictionary<string, TaskData>, TaskData) CreateExecuteMethodData(INamedTypeSymbol type, IEnumerable<IFieldSymbol> fields, string executeMethodName) {
81+
var executeMethod = GetExecuteMethod(type);
82+
var statements = (executeMethod.DeclaringSyntaxReferences.First().GetSyntax() as MethodDeclarationSyntax).Body.Statements;
83+
var variableStatements = statements.Remove(statements.Last());
84+
var variableData = GetVariableData(fields, variableStatements);
8485

8586
var lastStatement = statements.Last() as ReturnStatementSyntax;
8687
var invocation = lastStatement.Expression as InvocationExpressionSyntax;
@@ -110,36 +111,55 @@ private static (ExecuteMethodSignatureData, Dictionary<string, TaskData>, TaskDa
110111
}, variableData.ToDictionary(taskData => taskData.OutputName), finalTaskData);
111112
}
112113

114+
private static IEnumerable<TaskData> GetVariableData(IEnumerable<IFieldSymbol> fields, SyntaxList<StatementSyntax> variableStatements) {
115+
return variableStatements
116+
.Select(s => s as LocalDeclarationStatementSyntax)
117+
.SelectMany(v => v.Declaration.Variables)
118+
.Select(declarationSyntax => {
119+
var invocation = declarationSyntax.Initializer.Value as InvocationExpressionSyntax;
120+
var methodAccessExpression = invocation.Expression as MemberAccessExpressionSyntax;
121+
var methodCallTypeName = methodAccessExpression.ToString().Split('.').First();
122+
var methodCallType = fields.First(f => f.Name == methodCallTypeName).Type;
123+
var methodCallName = methodAccessExpression.ToString().Split('.').Last();
124+
var methodSymbol = methodCallType.GetMembers(methodCallName).First() as IMethodSymbol;
125+
126+
var arguments = invocation.ArgumentList.Arguments;
127+
var argumentTypeNames = arguments.Select(a => a.ToString().Split('.').First());
128+
129+
return new TaskData
130+
{
131+
OutputName = declarationSyntax.Identifier.Text,
132+
MethodCallName = methodAccessExpression.ToString(),
133+
MethodCallReturnType = methodSymbol.ReturnType.ToString(),
134+
DependenciesOutputNames = argumentTypeNames,
135+
TaskName = $"{declarationSyntax.Identifier.Text}Task"
136+
};
137+
});
138+
}
139+
140+
private static IMethodSymbol GetExecuteMethod(INamedTypeSymbol type) {
141+
return type
142+
.GetMembers()
143+
.Where(m => m is IMethodSymbol)
144+
.First(m => (m as IMethodSymbol).MethodKind == MethodKind.Ordinary) as IMethodSymbol;
145+
}
146+
113147
private static string FormatExecuteMethod(ExecuteMethodSignatureData signatureData, Dictionary<string, TaskData> data, TaskData finalTaskData) {
114148
var formattedTaskDeclarations = data.Select(keyValue => {
115149
var item = keyValue.Value;
116150
var hasDependencies = item.DependenciesOutputNames.Any();
117-
return hasDependencies ?
118-
$@"var {item.TaskName} = new {item.MethodCallReturnType}(() => default);":
151+
return hasDependencies ?
152+
$@"var {item.TaskName} = new {item.MethodCallReturnType}(() => default);" :
119153
$@"var {item.TaskName} = {item.MethodCallName}();";
120154
});
121155

122156
var taskNames = data.Where(keyValue => !keyValue.Value.DependenciesOutputNames.Any()).Select(keyValue => keyValue.Value.TaskName);
123157
var formattedTasksList = $@"var tasksToProcess = new List<Task> {{ {string.Join(@", ", taskNames)} }};";
124-
125-
var formattedHandleTaskCompletions = data.Where(keyValue => keyValue.Value.DependenciesOutputNames.Any()).Select(keyValue => {
126-
var item = keyValue.Value;
127-
var dependencyTaskNames = item.DependenciesOutputNames.Select(depName => data[depName].TaskName);
128-
var formattedCompletedDependencyTaskNames = string.Join(" && ", dependencyTaskNames.Select(tn => $"{tn}.IsCompleted"));
129-
var formattedResultDependencyTaskNames = string.Join(", ", dependencyTaskNames.Select(tn => $"{tn}.Result"));
130-
var formattedCallDependencies = $@"{item.TaskName} = {item.MethodCallName}({formattedResultDependencyTaskNames});";
131-
var formattedAddTaskToList = $@"tasksToProcess.Add({item.TaskName});";
132-
133-
return $@"if (!{item.TaskName}.IsCompleted && {formattedCompletedDependencyTaskNames})
134-
{{
135-
{formattedCallDependencies}
136-
{formattedAddTaskToList}
137-
}}";
138-
});
158+
var formattedHandleTaskCompletions = CreateFormattedHandleTaskCompletions(data);
139159

140160
var formattedWhenEach = $@"await foreach (var completed in Task.WhenEach(tasksToProcess))
141161
{{
142-
{ string.Join(@"
162+
{string.Join(@"
143163
144164
", formattedHandleTaskCompletions)}
145165
}}";
@@ -149,11 +169,11 @@ private static string FormatExecuteMethod(ExecuteMethodSignatureData signatureDa
149169
var formattedFinalResult = $@"var finalResult = await {finalTaskData.MethodCallName}({formattedResultDependencyTaskNames});
150170
151171
return finalResult;";
152-
172+
153173
return $@"{signatureData.AccessModifier} async {signatureData.ReturnType} {signatureData.Name}()
154174
{{
155175
{string.Join(@"
156-
", formattedTaskDeclarations) }
176+
", formattedTaskDeclarations)}
157177
158178
{formattedTasksList}
159179
@@ -163,6 +183,23 @@ private static string FormatExecuteMethod(ExecuteMethodSignatureData signatureDa
163183
}}";
164184
}
165185

186+
private static IEnumerable<string> CreateFormattedHandleTaskCompletions(Dictionary<string, TaskData> data) {
187+
return data.Where(keyValue => keyValue.Value.DependenciesOutputNames.Any()).Select(keyValue => {
188+
var item = keyValue.Value;
189+
var dependencyTaskNames = item.DependenciesOutputNames.Select(depName => data[depName].TaskName);
190+
var formattedCompletedDependencyTaskNames = string.Join(" && ", dependencyTaskNames.Select(tn => $"{tn}.IsCompleted"));
191+
var formattedResultDependencyTaskNames = string.Join(", ", dependencyTaskNames.Select(tn => $"{tn}.Result"));
192+
var formattedCallDependencies = $@"{item.TaskName} = {item.MethodCallName}({formattedResultDependencyTaskNames});";
193+
var formattedAddTaskToList = $@"tasksToProcess.Add({item.TaskName});";
194+
195+
return $@"if (!{item.TaskName}.IsCompleted && {formattedCompletedDependencyTaskNames})
196+
{{
197+
{formattedCallDependencies}
198+
{formattedAddTaskToList}
199+
}}";
200+
});
201+
}
202+
166203
private static string FormatConstructor(INamedTypeSymbol type, string className, IEnumerable<ISymbol> typeMembers) {
167204
var constructor = typeMembers
168205
.Where(m => m.Kind == SymbolKind.Method)

0 commit comments

Comments
 (0)