Skip to content
24 changes: 24 additions & 0 deletions src/Pipelines.Generator/Extensions/AttributeDataExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
namespace Microsoft.CodeAnalysis;

/// <summary>
/// Provides extension methods for <see cref="AttributeData"/> to identify pipeline-related attributes.
/// </summary>
internal static class AttributeDataExtensions
{
/// <summary>
/// Extension methods for an <see cref="AttributeData"/> instance.
/// </summary>
/// <param name="attribute">The attribute data to inspect.</param>
extension(AttributeData? attribute)
{
/// <summary>
/// Determines whether the attribute represents a <c>Pipelines.IgnoreAttribute</c> with no constructor arguments.
/// </summary>
public bool IsIgnoreAttribute() => attribute is
{
AttributeClass.ContainingAssembly.Name: "Pipelines",
AttributeClass.Name: "IgnoreAttribute",
ConstructorArguments.IsDefaultOrEmpty: true,
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public void NullableDisable()
}

/// <summary>
/// Writes a <see cref="System.CodeDom.Compiler.GeneratedCodeAttribute"/> line with the generator name and version.
/// Writes a using statement for the <see cref="global::System.CodeDom.Compiler"/> namespace.
/// </summary>
public void UsingSystemCodeDomCompiler()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ public static void Initialize(PipelinesGeneratorContext ctx)
.SyntaxProvider
.CreateSyntaxProvider(Predicate, (context, ct) => Transform(context.Node, context.SemanticModel, ct))
.SelectMany((x, ct) => x)
.Collect()
.Combine(features);
.Collect();

context.RegisterSourceOutput(handlers, static (spc, ctx) =>
var toGenerate = features
.Combine(handlers);

context.RegisterSourceOutput(toGenerate, static (spc, ctx) =>
{
var (handlers, features) = ctx;
var (features, handlers) = ctx;

if (features is not { HasDependencyInjection: true, DisableHandlerRegistration: false })
return;
Expand Down Expand Up @@ -50,10 +52,9 @@ private static IEnumerable<HandlerRegistration> Transform(SyntaxNode node, Seman
if (IRequestHandler(interfaceSymbol) ||
IStreamRequestHandler(interfaceSymbol))
{
yield return new HandlerRegistration(node, symbol, interfaceSymbol);
yield return new HandlerRegistration(symbol, interfaceSymbol);
}
}
//return null;

static bool ValidHandler([NotNullWhen(true)] INamedTypeSymbol? handler) => handler is
{
Expand All @@ -66,7 +67,7 @@ static bool ValidHandler([NotNullWhen(true)] INamedTypeSymbol? handler) => handl

static bool ValidTypeArgument(ITypeSymbol typeSymbol) => typeSymbol switch
{
// todo: expand failure criteria
// todo: expand failure criteria and add diagnostics
INamedTypeSymbol named => named is
{
IsAbstract: false,
Expand Down Expand Up @@ -101,7 +102,9 @@ static bool IStreamRequestHandler(INamedTypeSymbol handler)
}
}

private static void GenerateSourceOutput(SourceProductionContext spc, ImmutableArray<HandlerRegistration> handlersToGenerate)
private static void GenerateSourceOutput(
in SourceProductionContext spc,
in ImmutableArray<HandlerRegistration> handlers)
{
// todo: diagnostics for duplicate handlers
//var descriptor = new DiagnosticDescriptor(
Expand Down Expand Up @@ -130,9 +133,13 @@ private static void GenerateSourceOutput(SourceProductionContext spc, ImmutableA
{
using (sb.Block(sb, $"public static {sb.IServiceCollection} AddHandlers(this {sb.IServiceCollection} services)"))
{
foreach (var handlerToGenerate in handlersToGenerate)
if (!handlers.IsDefaultOrEmpty)
{
handlerToGenerate.ServiceRegistration(sb);
foreach (var handler in handlers)
{
handler.ServiceRegistration(sb);
}
sb.Line();
}
sb.Line("return services;");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,36 @@

internal readonly record struct HandlerRegistration
{
private readonly bool Disabled;

public readonly string Handler;

public readonly string Interface;
public readonly SyntaxNode Node;

public HandlerRegistration(SyntaxNode node, INamedTypeSymbol classSymbol, INamedTypeSymbol interfaceSymbol)
public HandlerRegistration(INamedTypeSymbol classSymbol, INamedTypeSymbol interfaceSymbol)
{
var attributes = classSymbol.GetAttributes();

Disabled = attributes.Any(x => x.IsIgnoreAttribute());
Handler = classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
Interface = interfaceSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
Node = node;
}

public void ServiceRegistration(SourceBuilder sb)
{
if (Disabled) return;

sb.Line(sb, $"services.AddScoped<{Interface}, {Handler}>();");
}

public bool Equals(HandlerRegistration other)
{
return other.Handler == Handler
&& other.Interface == Interface;
}

public override int GetHashCode()
{
return HashCode.Combine(Handler, Interface);
}
}
134 changes: 134 additions & 0 deletions src/Pipelines.Generator/Primitives/EquatableArray.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
using System.Runtime.CompilerServices;

namespace System.Collections.Generic;

/// <summary>
/// Provides extension methods for creating <see cref="EquatableArray{T}"/> instances.
/// </summary>
internal static class EquatableArray
{
/// <summary>
/// Converts an <see cref="IEnumerable{T}"/> to an <see cref="EquatableArray{T}"/>.
/// </summary>
/// <typeparam name="T">The type of elements in the collection.</typeparam>
/// <param name="values">The collection of values to convert.</param>
/// <returns>An <see cref="EquatableArray{T}"/> containing the values from the input collection.</returns>
public static EquatableArray<T> ToEquatableArray<T>(this IEnumerable<T> values) where T : IEquatable<T>
{
return new EquatableArray<T>(values);
}
}

/// <summary>
/// An immutable, equatable array. This is equivalent to <see cref="Array{T}"/> but with value equality support.
/// </summary>
/// <typeparam name="T">The type of values in the array.</typeparam>
internal readonly struct EquatableArray<T> : IEquatable<EquatableArray<T>>
where T : IEquatable<T>
{
public static EquatableArray<T> Empty { get; } = new();

/// <summary>
/// The underlying <typeparamref name="T"/> array.
/// </summary>
private readonly ImmutableArray<T> _array;

public int Length => _array.Length;

/// <summary>
/// Creates a new <see cref="EquatableArray{T}"/> instance.
/// </summary>
public EquatableArray()
{
_array = [];
}

/// <summary>
/// Creates a new <see cref="EquatableArray{T}"/> instance.
/// </summary>
/// <param name="array">The input <see cref="ImmutableArray"/> to wrap.</param>
public EquatableArray(T[] array)
{
_array = [.. array];
}

/// <summary>
/// Creates a new <see cref="EquatableArray{T}"/> instance.
/// </summary>
/// <param name="array">The input <see cref="ImmutableArray"/> to wrap.</param>
public EquatableArray(IEnumerable<T> values)
{
_array = [.. values];
}

/// <sinheritdoc/>
public bool Equals(EquatableArray<T> array)
{
return AsSpan().SequenceEqual(array.AsSpan());
}

/// <sinheritdoc/>
public override bool Equals(object? obj)
{
return obj is EquatableArray<T> array && this.Equals(array);
}

/// <sinheritdoc/>
public override int GetHashCode()
{
if (_array.IsDefaultOrEmpty)
{
return 0;
}

HashCode hashCode = default;

foreach (T item in _array)
{
hashCode.Add(item);
}

return hashCode.ToHashCode();
}

/// <summary>
/// Returns a <see cref="ReadOnlySpan{T}"/> wrapping the current items.
/// </summary>
/// <returns>A <see cref="ReadOnlySpan{T}"/> wrapping the current items.</returns>
public ReadOnlySpan<T> AsSpan()
{
return _array.AsSpan();
}

/// <summary>
/// Returns an enumerator for the contents of the array.
/// </summary>
/// <returns>An enumerator.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ImmutableArray<T>.Enumerator GetEnumerator()
{
return _array.GetEnumerator();
}

/// <summary>
/// Checks whether two <see cref="EquatableArray{T}"/> values are the same.
/// </summary>
/// <param name="left">The first <see cref="EquatableArray{T}"/> value.</param>
/// <param name="right">The second <see cref="EquatableArray{T}"/> value.</param>
/// <returns>Whether <paramref name="left"/> and <paramref name="right"/> are equal.</returns>
public static bool operator ==(EquatableArray<T> left, EquatableArray<T> right)
{
return left.Equals(right);
}

/// <summary>
/// Checks whether two <see cref="EquatableArray{T}"/> values are not the same.
/// </summary>
/// <param name="left">The first <see cref="EquatableArray{T}"/> value.</param>
/// <param name="right">The second <see cref="EquatableArray{T}"/> value.</param>
/// <returns>Whether <paramref name="left"/> and <paramref name="right"/> are not equal.</returns>
public static bool operator !=(EquatableArray<T> left, EquatableArray<T> right)
{
return !left.Equals(right);
}
}
Loading