@@ -8,13 +8,11 @@ namespace AsyncTaskOrchestratorGenerator
88{
99 internal static class OutputGenerator
1010 {
11- public static ( string source , string className ) GenerateOutputs ( ( INamedTypeSymbol typeSymbol , SemanticModel semanticModel ) typeInfo ) {
12- var type = typeInfo . typeSymbol ;
13- var semanticModel = typeInfo . semanticModel ;
14-
15- var constructorArguments = type . GetAttributes ( ) . First ( ( a ) => a . AttributeClass . Name == nameof ( AsyncTaskOrchestratorAttribute ) ) . ConstructorArguments ;
11+ public static ( string source , string className ) GenerateClassOutputs ( INamedTypeSymbol type ) {
12+ var constructorArguments = GetAttributeConstructorArguments ( type ) ;
1613 var className = constructorArguments . First ( ) . Value . ToString ( ) ;
1714 var executeMethodName = constructorArguments . ElementAt ( 1 ) . Value . ToString ( ) ;
15+ var interfaceName = $ "I{ className } ";
1816
1917 var accessModifier = type . DeclaredAccessibility . ToString ( ) . ToLower ( ) ;
2018 var typeMembers = type . GetMembers ( ) ;
@@ -36,7 +34,7 @@ public static (string source, string className) GenerateOutputs((INamedTypeSymbo
3634
3735namespace { type . ContainingNamespace . ToDisplayString ( ) } ;
3836
39- { accessModifier } class { className }
37+ { accessModifier } class { className } : { interfaceName }
4038{{
4139{ string . Join ( @"
4240" , formattedFields ) }
@@ -50,37 +48,40 @@ namespace {type.ContainingNamespace.ToDisplayString()};
5048 return ( source , className ) ;
5149 }
5250
53- private static ( ExecuteMethodSignatureData , Dictionary < string , TaskData > , TaskData ) CreateExecuteMethodData ( INamedTypeSymbol type , IEnumerable < IFieldSymbol > fields , string executeMethodName ) {
54- var executeMethod = type
55- . GetMembers ( )
56- . Where ( m => m is IMethodSymbol )
57- . First ( m => ( m as IMethodSymbol ) . MethodKind == MethodKind . Ordinary ) as IMethodSymbol ;
58- var statements = ( executeMethod . DeclaringSyntaxReferences . First ( ) . GetSyntax ( ) as MethodDeclarationSyntax ) . Body . Statements ;
59- var variableStatements = statements . Remove ( statements . Last ( ) ) ;
51+ public static ( string source , string interfaceName ) GenerateInterfaceOutputs ( INamedTypeSymbol type ) {
52+ var constructorArguments = GetAttributeConstructorArguments ( type ) ;
53+ var className = constructorArguments . First ( ) . Value . ToString ( ) ;
54+ var executeMethodName = constructorArguments . ElementAt ( 1 ) . Value . ToString ( ) ;
55+ var interfaceName = $ "I{ className } ";
56+
57+ var accessModifier = type . DeclaredAccessibility . ToString ( ) . ToLower ( ) ;
58+ var executeMethod = GetExecuteMethod ( type ) ;
59+ var executeMethodAccessibility = executeMethod . DeclaredAccessibility . ToString ( ) . ToLower ( ) ;
60+ var formattedExecuteMethod = $ "{ executeMethodAccessibility } { executeMethod . ReturnType } { executeMethodName } ();";
6061
61- var variableData = variableStatements
62- . Select ( s => s as LocalDeclarationStatementSyntax )
63- . SelectMany ( v => v . Declaration . Variables )
64- . Select ( declarationSyntax => {
65- var invocation = declarationSyntax . Initializer . Value as InvocationExpressionSyntax ;
66- var methodAccessExpression = invocation . Expression as MemberAccessExpressionSyntax ;
67- var methodCallTypeName = methodAccessExpression . ToString ( ) . Split ( '.' ) . First ( ) ;
68- var methodCallType = fields . First ( f => f . Name == methodCallTypeName ) . Type ;
69- var methodCallName = methodAccessExpression . ToString ( ) . Split ( '.' ) . Last ( ) ;
70- var methodSymbol = methodCallType . GetMembers ( methodCallName ) . First ( ) as IMethodSymbol ;
62+ var source =
63+ $@ "// <auto-generated/>
64+ #nullable restore
7165
72- var arguments = invocation . ArgumentList . Arguments ;
73- var argumentTypeNames = arguments . Select ( a => a . ToString ( ) . Split ( '.' ) . First ( ) ) ;
66+ namespace { type . ContainingNamespace . ToDisplayString ( ) } ;
7467
75- return new TaskData
76- {
77- OutputName = declarationSyntax . Identifier . Text ,
78- MethodCallName = methodAccessExpression . ToString ( ) ,
79- MethodCallReturnType = methodSymbol . ReturnType . ToString ( ) ,
80- DependenciesOutputNames = argumentTypeNames ,
81- TaskName = $ "{ declarationSyntax . Identifier . Text } Task"
82- } ;
83- } ) ;
68+ { accessModifier } interface { interfaceName }
69+ {{
70+ { formattedExecuteMethod }
71+ }}
72+ " ;
73+ return ( source , interfaceName ) ;
74+ }
75+
76+ private static System . Collections . Immutable . ImmutableArray < TypedConstant > GetAttributeConstructorArguments ( INamedTypeSymbol type ) {
77+ return type . GetAttributes ( ) . First ( ( a ) => a . AttributeClass . Name == nameof ( AsyncTaskOrchestratorAttribute ) ) . ConstructorArguments ;
78+ }
79+
80+ private static ( ExecuteMethodSignatureData , Dictionary < string , TaskData > , TaskData ) CreateExecuteMethodData ( INamedTypeSymbol type , IEnumerable < IFieldSymbol > fields , string executeMethodName ) {
81+ var executeMethod = GetExecuteMethod ( type ) ;
82+ var statements = ( executeMethod . DeclaringSyntaxReferences . First ( ) . GetSyntax ( ) as MethodDeclarationSyntax ) . Body . Statements ;
83+ var variableStatements = statements . Remove ( statements . Last ( ) ) ;
84+ var variableData = GetVariableData ( fields , variableStatements ) ;
8485
8586 var lastStatement = statements . Last ( ) as ReturnStatementSyntax ;
8687 var invocation = lastStatement . Expression as InvocationExpressionSyntax ;
@@ -110,36 +111,55 @@ private static (ExecuteMethodSignatureData, Dictionary<string, TaskData>, TaskDa
110111 } , variableData . ToDictionary ( taskData => taskData . OutputName ) , finalTaskData ) ;
111112 }
112113
114+ private static IEnumerable < TaskData > GetVariableData ( IEnumerable < IFieldSymbol > fields , SyntaxList < StatementSyntax > variableStatements ) {
115+ return variableStatements
116+ . Select ( s => s as LocalDeclarationStatementSyntax )
117+ . SelectMany ( v => v . Declaration . Variables )
118+ . Select ( declarationSyntax => {
119+ var invocation = declarationSyntax . Initializer . Value as InvocationExpressionSyntax ;
120+ var methodAccessExpression = invocation . Expression as MemberAccessExpressionSyntax ;
121+ var methodCallTypeName = methodAccessExpression . ToString ( ) . Split ( '.' ) . First ( ) ;
122+ var methodCallType = fields . First ( f => f . Name == methodCallTypeName ) . Type ;
123+ var methodCallName = methodAccessExpression . ToString ( ) . Split ( '.' ) . Last ( ) ;
124+ var methodSymbol = methodCallType . GetMembers ( methodCallName ) . First ( ) as IMethodSymbol ;
125+
126+ var arguments = invocation . ArgumentList . Arguments ;
127+ var argumentTypeNames = arguments . Select ( a => a . ToString ( ) . Split ( '.' ) . First ( ) ) ;
128+
129+ return new TaskData
130+ {
131+ OutputName = declarationSyntax . Identifier . Text ,
132+ MethodCallName = methodAccessExpression . ToString ( ) ,
133+ MethodCallReturnType = methodSymbol . ReturnType . ToString ( ) ,
134+ DependenciesOutputNames = argumentTypeNames ,
135+ TaskName = $ "{ declarationSyntax . Identifier . Text } Task"
136+ } ;
137+ } ) ;
138+ }
139+
140+ private static IMethodSymbol GetExecuteMethod ( INamedTypeSymbol type ) {
141+ return type
142+ . GetMembers ( )
143+ . Where ( m => m is IMethodSymbol )
144+ . First ( m => ( m as IMethodSymbol ) . MethodKind == MethodKind . Ordinary ) as IMethodSymbol ;
145+ }
146+
113147 private static string FormatExecuteMethod ( ExecuteMethodSignatureData signatureData , Dictionary < string , TaskData > data , TaskData finalTaskData ) {
114148 var formattedTaskDeclarations = data . Select ( keyValue => {
115149 var item = keyValue . Value ;
116150 var hasDependencies = item . DependenciesOutputNames . Any ( ) ;
117- return hasDependencies ?
118- $@ "var { item . TaskName } = new { item . MethodCallReturnType } (() => default);":
151+ return hasDependencies ?
152+ $@ "var { item . TaskName } = new { item . MethodCallReturnType } (() => default);" :
119153 $@ "var { item . TaskName } = { item . MethodCallName } ();";
120154 } ) ;
121155
122156 var taskNames = data . Where ( keyValue => ! keyValue . Value . DependenciesOutputNames . Any ( ) ) . Select ( keyValue => keyValue . Value . TaskName ) ;
123157 var formattedTasksList = $@ "var tasksToProcess = new List<Task> {{ { string . Join ( @", " , taskNames ) } }};";
124-
125- var formattedHandleTaskCompletions = data . Where ( keyValue => keyValue . Value . DependenciesOutputNames . Any ( ) ) . Select ( keyValue => {
126- var item = keyValue . Value ;
127- var dependencyTaskNames = item . DependenciesOutputNames . Select ( depName => data [ depName ] . TaskName ) ;
128- var formattedCompletedDependencyTaskNames = string . Join ( " && " , dependencyTaskNames . Select ( tn => $ "{ tn } .IsCompleted") ) ;
129- var formattedResultDependencyTaskNames = string . Join ( ", " , dependencyTaskNames . Select ( tn => $ "{ tn } .Result") ) ;
130- var formattedCallDependencies = $@ "{ item . TaskName } = { item . MethodCallName } ({ formattedResultDependencyTaskNames } );";
131- var formattedAddTaskToList = $@ "tasksToProcess.Add({ item . TaskName } );";
132-
133- return $@ "if (!{ item . TaskName } .IsCompleted && { formattedCompletedDependencyTaskNames } )
134- {{
135- { formattedCallDependencies }
136- { formattedAddTaskToList }
137- }}" ;
138- } ) ;
158+ var formattedHandleTaskCompletions = CreateFormattedHandleTaskCompletions ( data ) ;
139159
140160 var formattedWhenEach = $@ "await foreach (var completed in Task.WhenEach(tasksToProcess))
141161 {{
142- { string . Join ( @"
162+ { string . Join ( @"
143163
144164 " , formattedHandleTaskCompletions ) }
145165 }}" ;
@@ -149,11 +169,11 @@ private static string FormatExecuteMethod(ExecuteMethodSignatureData signatureDa
149169 var formattedFinalResult = $@ "var finalResult = await { finalTaskData . MethodCallName } ({ formattedResultDependencyTaskNames } );
150170
151171 return finalResult;" ;
152-
172+
153173 return $@ "{ signatureData . AccessModifier } async { signatureData . ReturnType } { signatureData . Name } ()
154174 {{
155175 { string . Join ( @"
156- " , formattedTaskDeclarations ) }
176+ " , formattedTaskDeclarations ) }
157177
158178 { formattedTasksList }
159179
@@ -163,6 +183,23 @@ private static string FormatExecuteMethod(ExecuteMethodSignatureData signatureDa
163183 }}" ;
164184 }
165185
186+ private static IEnumerable < string > CreateFormattedHandleTaskCompletions ( Dictionary < string , TaskData > data ) {
187+ return data . Where ( keyValue => keyValue . Value . DependenciesOutputNames . Any ( ) ) . Select ( keyValue => {
188+ var item = keyValue . Value ;
189+ var dependencyTaskNames = item . DependenciesOutputNames . Select ( depName => data [ depName ] . TaskName ) ;
190+ var formattedCompletedDependencyTaskNames = string . Join ( " && " , dependencyTaskNames . Select ( tn => $ "{ tn } .IsCompleted") ) ;
191+ var formattedResultDependencyTaskNames = string . Join ( ", " , dependencyTaskNames . Select ( tn => $ "{ tn } .Result") ) ;
192+ var formattedCallDependencies = $@ "{ item . TaskName } = { item . MethodCallName } ({ formattedResultDependencyTaskNames } );";
193+ var formattedAddTaskToList = $@ "tasksToProcess.Add({ item . TaskName } );";
194+
195+ return $@ "if (!{ item . TaskName } .IsCompleted && { formattedCompletedDependencyTaskNames } )
196+ {{
197+ { formattedCallDependencies }
198+ { formattedAddTaskToList }
199+ }}" ;
200+ } ) ;
201+ }
202+
166203 private static string FormatConstructor ( INamedTypeSymbol type , string className , IEnumerable < ISymbol > typeMembers ) {
167204 var constructor = typeMembers
168205 . Where ( m => m . Kind == SymbolKind . Method )
0 commit comments