diff --git a/MSTest.slnf b/MSTest.slnf
index 3c3f4c624a..f8858c6031 100644
--- a/MSTest.slnf
+++ b/MSTest.slnf
@@ -11,6 +11,7 @@
"src\\Analyzers\\MSTest.Analyzers.CodeFixes\\MSTest.Analyzers.CodeFixes.csproj",
"src\\Analyzers\\MSTest.Analyzers.Package\\MSTest.Analyzers.Package.csproj",
"src\\Analyzers\\MSTest.Analyzers\\MSTest.Analyzers.csproj",
+ "src\\Analyzers\\MSTest.AotReflection.SourceGeneration\\MSTest.AotReflection.SourceGeneration.csproj",
"src\\Analyzers\\MSTest.GlobalConfigsGenerator\\MSTest.GlobalConfigsGenerator.csproj",
"src\\Analyzers\\MSTest.SourceGeneration\\MSTest.SourceGeneration.csproj",
"src\\Package\\MSTest.Sdk\\MSTest.Sdk.csproj",
diff --git a/TestFx.slnx b/TestFx.slnx
index 991e9dc8db..4f9fc147da 100644
--- a/TestFx.slnx
+++ b/TestFx.slnx
@@ -69,6 +69,7 @@
+
@@ -133,6 +134,7 @@
+
diff --git a/src/Analyzers/MSTest.AotReflection.SourceGeneration/Generators/TestClassModelBuilder.cs b/src/Analyzers/MSTest.AotReflection.SourceGeneration/Generators/TestClassModelBuilder.cs
index 9188cf64c0..3801417f2c 100644
--- a/src/Analyzers/MSTest.AotReflection.SourceGeneration/Generators/TestClassModelBuilder.cs
+++ b/src/Analyzers/MSTest.AotReflection.SourceGeneration/Generators/TestClassModelBuilder.cs
@@ -1,9 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
+using System.Text;
using Microsoft.CodeAnalysis;
@@ -19,31 +21,61 @@ internal static class TestClassModelBuilder
{
private static readonly SymbolDisplayFormat FullyQualifiedFormat =
SymbolDisplayFormat.FullyQualifiedFormat.WithMiscellaneousOptions(
- SymbolDisplayMiscellaneousOptions.IncludeNullableReferenceTypeModifier
- | SymbolDisplayMiscellaneousOptions.UseSpecialTypes);
+ SymbolDisplayMiscellaneousOptions.UseSpecialTypes);
public static TestClassModel Build(INamedTypeSymbol typeSymbol)
{
+ // Methods / properties are walked across the full inheritance chain (excluding
+ // System.Object) so that MSTest members declared on a base class —
+ // [ClassInitialize], [ClassCleanup], [TestInitialize], [TestCleanup],
+ // [TestMethod], the [TestContext] setter, … — are visible to the consumer
+ // without runtime reflection.
+ //
+ // Iteration order is derived-first so that an override or `new`-shadowed member
+ // on the derived type wins over the base declaration with the same signature.
+ // Constructors are NEVER inherited and are taken only from the leaf type.
+ var methodsByKey = new Dictionary(StringComparer.Ordinal);
+ var propertiesByName = new Dictionary(StringComparer.Ordinal);
ImmutableArray.Builder methods = ImmutableArray.CreateBuilder();
ImmutableArray.Builder properties = ImmutableArray.CreateBuilder();
ImmutableArray.Builder ctors = ImmutableArray.CreateBuilder();
- foreach (ISymbol member in typeSymbol.GetMembers())
+ for (INamedTypeSymbol? current = typeSymbol;
+ current is not null && current.SpecialType != SpecialType.System_Object;
+ current = current.BaseType)
{
- switch (member)
+ bool isLeaf = SymbolEqualityComparer.Default.Equals(current, typeSymbol);
+
+ foreach (ISymbol member in current.GetMembers())
{
- case IMethodSymbol { MethodKind: MethodKind.Ordinary } method
- when method.DeclaredAccessibility is Accessibility.Public or Accessibility.Internal:
- methods.Add(BuildMethod(method));
- break;
- case IPropertySymbol property
- when property.DeclaredAccessibility is Accessibility.Public or Accessibility.Internal:
- properties.Add(BuildProperty(property));
- break;
- case IMethodSymbol { MethodKind: MethodKind.Constructor, IsStatic: false } ctor
- when ctor.DeclaredAccessibility is Accessibility.Public or Accessibility.Internal:
- ctors.Add(new TestConstructorModel(BuildParameters(ctor)));
- break;
+ switch (member)
+ {
+ case IMethodSymbol { MethodKind: MethodKind.Ordinary } method
+ when IsAccessibleFromConsumer(method):
+ string key = BuildMethodSignatureKey(method);
+ if (!methodsByKey.ContainsKey(key))
+ {
+ TestMethodModel model = BuildMethod(method);
+ methodsByKey[key] = model;
+ methods.Add(model);
+ }
+
+ break;
+ case IPropertySymbol property
+ when IsAccessibleFromConsumer(property):
+ if (!propertiesByName.ContainsKey(property.Name))
+ {
+ TestPropertyModel model = BuildProperty(property);
+ propertiesByName[property.Name] = model;
+ properties.Add(model);
+ }
+
+ break;
+ case IMethodSymbol { MethodKind: MethodKind.Constructor, IsStatic: false } ctor
+ when isLeaf && ctor.DeclaredAccessibility is Accessibility.Public or Accessibility.Internal:
+ ctors.Add(new TestConstructorModel(BuildParameters(ctor)));
+ break;
+ }
}
}
@@ -61,6 +93,52 @@ public static TestClassModel Build(INamedTypeSymbol typeSymbol)
Attributes: BuildAttributes(typeSymbol.GetAttributes()));
}
+ // Restricted to accessibilities the emitted helper class (a separate static type
+ // declared in MSTest.SourceGenerated, not a derived type) can legally call.
+ // 'protected' and 'private protected' members require the caller to be a derived
+ // type, so they are excluded; 'protected internal' is included because the internal
+ // half is satisfied (the generated helper lives in the same assembly).
+ private static bool IsAccessibleFromConsumer(ISymbol symbol)
+ => symbol.DeclaredAccessibility is
+ Accessibility.Public
+ or Accessibility.Internal
+ or Accessibility.ProtectedOrInternal;
+
+ private static string BuildMethodSignatureKey(IMethodSymbol method)
+ {
+ var sb = new StringBuilder();
+ sb.Append(method.IsStatic ? "S:" : "I:");
+ sb.Append(method.Name);
+ sb.Append('(');
+ bool first = true;
+ foreach (IParameterSymbol p in method.Parameters)
+ {
+ if (!first)
+ {
+ sb.Append(',');
+ }
+
+ first = false;
+ switch (p.RefKind)
+ {
+ case RefKind.Ref:
+ sb.Append("ref ");
+ break;
+ case RefKind.Out:
+ sb.Append("out ");
+ break;
+ case RefKind.In:
+ sb.Append("in ");
+ break;
+ }
+
+ sb.Append(p.Type.ToDisplayString(FullyQualifiedFormat));
+ }
+
+ sb.Append(')');
+ return sb.ToString();
+ }
+
private static TestMethodModel BuildMethod(IMethodSymbol method)
{
ITypeSymbol returnType = method.ReturnType;
@@ -82,7 +160,7 @@ private static TestMethodModel BuildMethod(IMethodSymbol method)
ReturnsValueTask: returnsValueTask,
ReturnsVoid: returnsVoid,
Parameters: BuildParameters(method),
- Attributes: BuildAttributes(method.GetAttributes()));
+ Attributes: BuildAttributes(CollectInheritedAttributes(method)));
}
private static TestPropertyModel BuildProperty(IPropertySymbol property)
@@ -90,7 +168,105 @@ private static TestPropertyModel BuildProperty(IPropertySymbol property)
Name: property.Name,
FullyQualifiedType: property.Type.ToDisplayString(FullyQualifiedFormat),
HasPublicSetter: property.SetMethod is { DeclaredAccessibility: Accessibility.Public },
- Attributes: BuildAttributes(property.GetAttributes()));
+ Attributes: BuildAttributes(CollectInheritedAttributes(property)));
+
+ // Mirror the runtime behavior of MemberInfo.GetCustomAttributes(inherit: true): walk the
+ // overridden-method chain and union attributes, keeping the most-derived application when
+ // the same attribute type appears on multiple levels.
+ private static ImmutableArray CollectInheritedAttributes(IMethodSymbol method)
+ {
+ ImmutableArray own = method.GetAttributes();
+ if (method.OverriddenMethod is null)
+ {
+ return own;
+ }
+
+ var seen = new HashSet(StringComparer.Ordinal);
+ ImmutableArray.Builder builder = ImmutableArray.CreateBuilder();
+ AppendUnique(builder, seen, own);
+ for (IMethodSymbol? baseMethod = method.OverriddenMethod; baseMethod is not null; baseMethod = baseMethod.OverriddenMethod)
+ {
+ AppendUnique(builder, seen, baseMethod.GetAttributes());
+ }
+
+ return builder.ToImmutable();
+ }
+
+ private static ImmutableArray CollectInheritedAttributes(IPropertySymbol property)
+ {
+ ImmutableArray own = property.GetAttributes();
+ if (property.OverriddenProperty is null)
+ {
+ return own;
+ }
+
+ var seen = new HashSet(StringComparer.Ordinal);
+ ImmutableArray.Builder builder = ImmutableArray.CreateBuilder();
+ AppendUnique(builder, seen, own);
+ for (IPropertySymbol? baseProperty = property.OverriddenProperty; baseProperty is not null; baseProperty = baseProperty.OverriddenProperty)
+ {
+ AppendUnique(builder, seen, baseProperty.GetAttributes());
+ }
+
+ return builder.ToImmutable();
+ }
+
+ private static void AppendUnique(
+ ImmutableArray.Builder builder,
+ HashSet seen,
+ ImmutableArray attributes)
+ {
+ foreach (AttributeData attribute in attributes)
+ {
+ if (attribute.AttributeClass is not { } attributeClass)
+ {
+ continue;
+ }
+
+ // Attributes declared with AttributeUsage(AllowMultiple = true) may legitimately
+ // appear several times across the override chain (e.g. [TestCategory]) — keep every
+ // instance instead of collapsing them to one.
+ if (AllowsMultiple(attributeClass))
+ {
+ builder.Add(attribute);
+ continue;
+ }
+
+ string key = attributeClass.ToDisplayString(FullyQualifiedFormat);
+ if (seen.Add(key))
+ {
+ builder.Add(attribute);
+ }
+ }
+ }
+
+ private static bool AllowsMultiple(INamedTypeSymbol attributeClass)
+ {
+ for (INamedTypeSymbol? current = attributeClass; current is not null; current = current.BaseType)
+ {
+ foreach (AttributeData attribute in current.GetAttributes())
+ {
+ if (attribute.AttributeClass?.ToDisplayString(FullyQualifiedFormat) != "global::System.AttributeUsageAttribute")
+ {
+ continue;
+ }
+
+ foreach (KeyValuePair named in attribute.NamedArguments)
+ {
+ if (named.Key == "AllowMultiple" && named.Value.Value is bool allowMultiple)
+ {
+ return allowMultiple;
+ }
+ }
+
+ // AttributeUsage was found on this level but did not set AllowMultiple — default is false
+ // and base-level [AttributeUsage] is shadowed by the derived application per CLI rules.
+ return false;
+ }
+ }
+
+ return false;
+ }
private static EquatableArray BuildParameters(IMethodSymbol method)
{
diff --git a/src/Analyzers/MSTest.AotReflection.SourceGeneration/MSTest.AotReflection.SourceGeneration.csproj b/src/Analyzers/MSTest.AotReflection.SourceGeneration/MSTest.AotReflection.SourceGeneration.csproj
index 1369428997..dc9bbade05 100644
--- a/src/Analyzers/MSTest.AotReflection.SourceGeneration/MSTest.AotReflection.SourceGeneration.csproj
+++ b/src/Analyzers/MSTest.AotReflection.SourceGeneration/MSTest.AotReflection.SourceGeneration.csproj
@@ -29,4 +29,8 @@
+
+
+
+
diff --git a/test/UnitTests/MSTest.AotReflection.SourceGeneration.UnitTests/MSTest.AotReflection.SourceGeneration.UnitTests.csproj b/test/UnitTests/MSTest.AotReflection.SourceGeneration.UnitTests/MSTest.AotReflection.SourceGeneration.UnitTests.csproj
new file mode 100644
index 0000000000..b91def2e3a
--- /dev/null
+++ b/test/UnitTests/MSTest.AotReflection.SourceGeneration.UnitTests/MSTest.AotReflection.SourceGeneration.UnitTests.csproj
@@ -0,0 +1,24 @@
+
+
+
+ net8.0
+ MSTest.AotReflection.SourceGeneration.UnitTests
+ true
+ true
+ Exe
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/test/UnitTests/MSTest.AotReflection.SourceGeneration.UnitTests/MSTestReflectionMetadataGeneratorTests.cs b/test/UnitTests/MSTest.AotReflection.SourceGeneration.UnitTests/MSTestReflectionMetadataGeneratorTests.cs
new file mode 100644
index 0000000000..936bc5fb62
--- /dev/null
+++ b/test/UnitTests/MSTest.AotReflection.SourceGeneration.UnitTests/MSTestReflectionMetadataGeneratorTests.cs
@@ -0,0 +1,887 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using AwesomeAssertions;
+
+using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp;
+
+using MSTest.AotReflection.SourceGeneration.Generators;
+
+namespace MSTest.AotReflection.SourceGeneration.UnitTests;
+
+///
+/// Behavior tests for .
+/// These pin the current PoC output so the upcoming follow-up PRs (#1837) can extend it safely.
+///
+[TestClass]
+public sealed class MSTestReflectionMetadataGeneratorTests
+{
+ ///
+ /// Minimal MSTest attribute stubs so the generator can locate [TestClass] /
+ /// [TestMethod] in test fixtures without dragging the real TestFramework
+ /// assemblies into the Roslyn compilation.
+ ///
+ private const string MinimalMSTestStub = """
+ namespace Microsoft.VisualStudio.TestTools.UnitTesting
+ {
+ [System.AttributeUsage(System.AttributeTargets.Class)]
+ public class TestClassAttribute : System.Attribute { }
+
+ [System.AttributeUsage(System.AttributeTargets.Method)]
+ public class TestMethodAttribute : System.Attribute
+ {
+ public TestMethodAttribute() { }
+ public TestMethodAttribute(string displayName) { DisplayName = displayName; }
+ public string? DisplayName { get; set; }
+ }
+
+ [System.AttributeUsage(System.AttributeTargets.Class | System.AttributeTargets.Method, AllowMultiple = true)]
+ public class TestCategoryAttribute : System.Attribute
+ {
+ public TestCategoryAttribute(string category) { Category = category; }
+ public string Category { get; }
+ }
+
+ [System.AttributeUsage(System.AttributeTargets.Property)]
+ public class TestContextAttribute : System.Attribute { }
+
+ [System.AttributeUsage(System.AttributeTargets.Method)]
+ public class TestInitializeAttribute : System.Attribute { }
+
+ [System.AttributeUsage(System.AttributeTargets.Method)]
+ public class TestCleanupAttribute : System.Attribute { }
+ }
+ """;
+
+ [TestMethod]
+ public void Generator_EmitsSupportTypes_OnAnyCompilation()
+ {
+ const string userCode = """
+ // Intentionally empty — no [TestClass] in the consumer.
+ """;
+
+ GeneratorRunResult result = RunGenerator(MinimalMSTestStub, userCode);
+
+ result.Diagnostics.Should().BeEmpty();
+ // Support types are emitted via RegisterPostInitializationOutput → always present.
+ string support = result.GeneratedSources
+ .Single(s => s.HintName == "MSTestReflectionMetadata.SupportTypes.g.cs")
+ .SourceText.ToString();
+
+ support.Should().Contain("namespace MSTest.SourceGenerated");
+ support.Should().Contain("internal sealed class TestClassReflectionInfo");
+ support.Should().Contain("internal sealed class TestMethodReflectionInfo");
+ support.Should().Contain("internal sealed class TestPropertyReflectionInfo");
+ support.Should().Contain("internal sealed class TestConstructorReflectionInfo");
+ }
+
+ [TestMethod]
+ public void Generator_EmitsRegistry_WithDiscoveredTestClass()
+ {
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ [TestClass]
+ public class MyTests
+ {
+ [TestMethod]
+ public void Test1() { }
+ }
+ }
+ """;
+
+ GeneratorRunResult result = RunGenerator(MinimalMSTestStub, userCode);
+
+ result.Diagnostics.Should().BeEmpty();
+ string registry = GetRegistry(result);
+
+ registry.Should().Contain("internal static class MSTestReflectionMetadata");
+ registry.Should().Contain("public const string AssemblyName = \"TestSample\";");
+ registry.Should().Contain("Type = typeof(global::Sample.MyTests)");
+ registry.Should().Contain("Name = \"Test1\"");
+ registry.Should().Contain("Invoke = static (instance, args) => { ((global::Sample.MyTests)instance!).Test1(); return null; },");
+ }
+
+ [TestMethod]
+ public void Generator_EmitsEmptyRegistry_WhenNoTestClasses()
+ {
+ const string userCode = """
+ namespace Sample
+ {
+ // No [TestClass] anywhere.
+ public class NotATest { public void Foo() { } }
+ }
+ """;
+
+ GeneratorRunResult result = RunGenerator(MinimalMSTestStub, userCode);
+
+ result.Diagnostics.Should().BeEmpty();
+ string registry = GetRegistry(result);
+ registry.Should().Contain("public static IReadOnlyList TestClasses { get; } = new TestClassReflectionInfo[]");
+ // No concrete TestClassReflectionInfo instance is emitted (note the open paren).
+ registry.Should().NotContain("new TestClassReflectionInfo(");
+ }
+
+ [TestMethod]
+ public void Generator_SkipsStaticTestClass()
+ {
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ [TestClass]
+ public static class StaticTests
+ {
+ [TestMethod]
+ public static void Test1() { }
+ }
+ }
+ """;
+
+ GeneratorRunResult result = RunGenerator(MinimalMSTestStub, userCode);
+
+ result.Diagnostics.Should().BeEmpty();
+ string registry = GetRegistry(result);
+ // Static classes are excluded by the predicate in the generator (cannot be instantiated).
+ registry.Should().NotContain("StaticTests");
+ }
+
+ [TestMethod]
+ public void Generator_SkipsAbstractTestClass()
+ {
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ [TestClass]
+ public abstract class AbstractTests
+ {
+ [TestMethod]
+ public void Test1() { }
+ }
+
+ [TestClass]
+ public class ConcreteTests
+ {
+ [TestMethod]
+ public void Test2() { }
+ }
+ }
+ """;
+
+ GeneratorRunResult result = RunGenerator(MinimalMSTestStub, userCode);
+
+ result.Diagnostics.Should().BeEmpty();
+ string registry = GetRegistry(result);
+ // Abstract classes are filtered in BuildModel — they cannot be instantiated.
+ registry.Should().NotContain("AbstractTests");
+ registry.Should().Contain("typeof(global::Sample.ConcreteTests)");
+ }
+
+ [TestMethod]
+ public void Generator_SkipsGenericTestClass()
+ {
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ [TestClass]
+ public class GenericTests
+ {
+ [TestMethod]
+ public void Test1() { }
+ }
+ }
+ """;
+
+ GeneratorRunResult result = RunGenerator(MinimalMSTestStub, userCode);
+
+ result.Diagnostics.Should().BeEmpty();
+ string registry = GetRegistry(result);
+ // Open-generic test classes are out of scope for this PoC.
+ registry.Should().NotContain("GenericTests");
+ }
+
+ [TestMethod]
+ public void Generator_EmitsConstructorInvoker()
+ {
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ [TestClass]
+ public class CtorTests
+ {
+ public CtorTests() { }
+
+ [TestMethod]
+ public void Test1() { }
+ }
+ }
+ """;
+
+ GeneratorRunResult result = RunGenerator(MinimalMSTestStub, userCode);
+
+ result.Diagnostics.Should().BeEmpty();
+ string registry = GetRegistry(result);
+ registry.Should().Contain("Constructors = new TestConstructorReflectionInfo[]");
+ registry.Should().Contain("Invoke = static args => new global::Sample.CtorTests(),");
+ }
+
+ [TestMethod]
+ public void Generator_EmitsParameterTypes_ForMethodWithParameters()
+ {
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ [TestClass]
+ public class ParamTests
+ {
+ [TestMethod]
+ public void Test1(int x, string y) { }
+ }
+ }
+ """;
+
+ GeneratorRunResult result = RunGenerator(MinimalMSTestStub, userCode);
+
+ result.Diagnostics.Should().BeEmpty();
+ string registry = GetRegistry(result);
+ registry.Should().Contain("ParameterTypes = new Type[] { typeof(int), typeof(string) }");
+ registry.Should().Contain("ParameterNames = new string[] { \"x\", \"y\" }");
+ }
+
+ [TestMethod]
+ public void Generator_FlagsAsyncReturnTypes()
+ {
+ const string userCode = """
+ using System.Threading.Tasks;
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ [TestClass]
+ public class AsyncTests
+ {
+ [TestMethod]
+ public Task Test1() => Task.CompletedTask;
+
+ [TestMethod]
+ public ValueTask Test2() => default;
+ }
+ }
+ """;
+
+ GeneratorRunResult result = RunGenerator(MinimalMSTestStub, userCode);
+
+ result.Diagnostics.Should().BeEmpty();
+ string registry = GetRegistry(result);
+ registry.Should().Contain("Name = \"Test1\"");
+ registry.Should().Contain("ReturnsTask = true");
+ registry.Should().Contain("Name = \"Test2\"");
+ registry.Should().Contain("ReturnsValueTask = true");
+ }
+
+ [TestMethod]
+ public void Generator_CapturesClassLevelAttributes()
+ {
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ [TestClass]
+ [TestCategory("Smoke")]
+ public class TaggedTests
+ {
+ [TestMethod]
+ public void Test1() { }
+ }
+ }
+ """;
+
+ GeneratorRunResult result = RunGenerator(MinimalMSTestStub, userCode);
+
+ result.Diagnostics.Should().BeEmpty();
+ string registry = GetRegistry(result);
+ registry.Should().Contain("global::Microsoft.VisualStudio.TestTools.UnitTesting.TestClassAttribute");
+ registry.Should().Contain("global::Microsoft.VisualStudio.TestTools.UnitTesting.TestCategoryAttribute");
+ registry.Should().Contain("\"Smoke\"");
+ }
+
+ [TestMethod]
+ public void Generator_EmitsPropertyGetterAndSetter()
+ {
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ public class TestContext { }
+
+ [TestClass]
+ public class PropTests
+ {
+ [TestContext]
+ public TestContext? Context { get; set; }
+
+ [TestMethod]
+ public void Test1() { }
+ }
+ }
+ """;
+
+ GeneratorRunResult result = RunGenerator(MinimalMSTestStub, userCode);
+
+ result.Diagnostics.Should().BeEmpty();
+ string registry = GetRegistry(result);
+ registry.Should().Contain("Name = \"Context\"");
+ registry.Should().Contain("HasPublicSetter = true");
+ registry.Should().Contain("Get = static instance => instance is null ? null : (object?)((global::Sample.PropTests)instance).Context,");
+ registry.Should().Contain("Set = static (instance, value) => ((global::Sample.PropTests)instance!).Context = (global::Sample.TestContext)value!,");
+ }
+
+ [TestMethod]
+ public void Generator_EmittedSource_CompilesCleanly()
+ {
+ const string userCode = """
+ using System.Threading.Tasks;
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ public class TestContext { }
+
+ [TestClass]
+ [TestCategory("Smoke")]
+ public class FullShape
+ {
+ [TestContext]
+ public TestContext? Context { get; set; }
+
+ public FullShape() { }
+
+ [TestMethod("alias")]
+ public void Sync(int x) { }
+
+ [TestMethod]
+ public Task Asynchronous() => Task.CompletedTask;
+ }
+ }
+ """;
+
+ Compilation outputCompilation = RunGeneratorAndGetCompilation(MinimalMSTestStub, userCode);
+
+ IEnumerable diagnostics = outputCompilation
+ .GetDiagnostics()
+ .Where(d => d.Severity == DiagnosticSeverity.Error);
+
+ diagnostics.Should().BeEmpty(
+ "the generated source MUST compile cleanly when consumed in the same compilation as the user code");
+ }
+
+ [TestMethod]
+ public void Generator_StripsNullableAnnotation_FromTypeofExpressions()
+ {
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ public class TestContext { }
+
+ [TestClass]
+ public class NullableShapes
+ {
+ [TestContext]
+ public TestContext? Context { get; set; }
+
+ [TestMethod]
+ public void TakesNullableRef(string? value) { }
+
+ [TestMethod]
+ public void TakesNullableValueType(int? n) { }
+ }
+ }
+ """;
+
+ Compilation outputCompilation = RunGeneratorAndGetCompilation(MinimalMSTestStub, userCode);
+ string registry = outputCompilation
+ .SyntaxTrees
+ .Single(t => t.FilePath.EndsWith("MSTestReflectionMetadata.Registry.g.cs", System.StringComparison.Ordinal))
+ .ToString();
+
+ // typeof(...) MUST NOT carry nullable reference type annotation (CS8639).
+ registry.Should().NotContain("typeof(global::Sample.TestContext?)");
+ registry.Should().NotContain("typeof(string?)");
+ // Reference types in typeof drop the annotation entirely.
+ registry.Should().Contain("typeof(global::Sample.TestContext)");
+ registry.Should().Contain("typeof(string)");
+ // Nullable value types are still distinct from their underlying type and must be preserved as Nullable.
+ registry.Should().Contain("typeof(int?)");
+
+ // The whole compilation must be free of CS errors.
+ IEnumerable errors = outputCompilation
+ .GetDiagnostics()
+ .Where(d => d.Severity == DiagnosticSeverity.Error);
+ errors.Should().BeEmpty("typeof(T?) on a reference type is invalid C# (CS8639)");
+ }
+
+ [TestMethod]
+ public void Generator_IsIncremental_SupportTypesAreCached_WhenInputUnchanged()
+ {
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ [TestClass]
+ public class IncTests
+ {
+ [TestMethod]
+ public void Test1() { }
+ }
+ }
+ """;
+
+ CSharpCompilation compilation = CreateCompilation(MinimalMSTestStub, userCode);
+ GeneratorDriver driver = CSharpGeneratorDriver
+ .Create(new MSTestReflectionMetadataGenerator())
+ .WithUpdatedParseOptions((CSharpParseOptions)compilation.SyntaxTrees.First().Options);
+
+ // Track step output cache reasons.
+ driver = driver.RunGenerators(compilation);
+ driver = driver.RunGenerators(compilation);
+
+ GeneratorDriverRunResult result = driver.GetRunResult();
+ result.Diagnostics.Should().BeEmpty();
+ result.Results.Should().ContainSingle();
+ // Two passes against the same compilation must produce identical sources.
+ result.Results[0].GeneratedSources.Should().HaveCount(2);
+ }
+
+ [TestMethod]
+ public void Generator_IncludesMethodsFromBaseType()
+ {
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ public class BaseTests
+ {
+ [TestInitialize]
+ public void Setup() { }
+
+ [TestMethod]
+ public void InheritedTest() { }
+ }
+
+ [TestClass]
+ public class DerivedTests : BaseTests
+ {
+ [TestMethod]
+ public void DerivedTest() { }
+ }
+ }
+ """;
+
+ string registry = GetRegistry(RunGenerator(MinimalMSTestStub, userCode));
+
+ registry.Should().Contain("Name = \"InheritedTest\"");
+ registry.Should().Contain("Name = \"Setup\"");
+ registry.Should().Contain("Name = \"DerivedTest\"");
+ // The TestInitialize attribute applied on the base method must propagate too.
+ registry.Should().Contain("global::Microsoft.VisualStudio.TestTools.UnitTesting.TestInitializeAttribute");
+ }
+
+ [TestMethod]
+ public void Generator_IncludesMethodsFromMultiLevelInheritance()
+ {
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ public class GrandparentTests
+ {
+ [TestMethod]
+ public void GrandparentTest() { }
+ }
+
+ public class ParentTests : GrandparentTests
+ {
+ [TestMethod]
+ public void ParentTest() { }
+ }
+
+ [TestClass]
+ public class LeafTests : ParentTests
+ {
+ [TestMethod]
+ public void LeafTest() { }
+ }
+ }
+ """;
+
+ string registry = GetRegistry(RunGenerator(MinimalMSTestStub, userCode));
+
+ registry.Should().Contain("Name = \"GrandparentTest\"");
+ registry.Should().Contain("Name = \"ParentTest\"");
+ registry.Should().Contain("Name = \"LeafTest\"");
+ }
+
+ [TestMethod]
+ public void Generator_OverriddenVirtualMethod_KeepsOnlyDerivedImplementation()
+ {
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ public class BaseTests
+ {
+ [TestMethod]
+ public virtual void Run() { }
+ }
+
+ [TestClass]
+ public class DerivedTests : BaseTests
+ {
+ public override void Run() { }
+ }
+ }
+ """;
+
+ string registry = GetRegistry(RunGenerator(MinimalMSTestStub, userCode));
+
+ // Only one entry for Run should be emitted, and the invoker must dispatch on the derived type.
+ int runEntries = registry.Split(["Name = \"Run\""], System.StringSplitOptions.None).Length - 1;
+ runEntries.Should().Be(1, "the derived override must replace the base entry (not duplicate it)");
+ registry.Should().Contain("((global::Sample.DerivedTests)instance!).Run();");
+ registry.Should().NotContain("((global::Sample.BaseTests)instance!).Run();");
+
+ // The override does NOT re-apply [TestMethod] but the attribute must still be visible
+ // via the override chain — matching the runtime semantics of GetCustomAttributes(inherit: true).
+ registry.Should().Contain("global::Microsoft.VisualStudio.TestTools.UnitTesting.TestMethodAttribute");
+ }
+
+ [TestMethod]
+ public void Generator_NewKeywordHiddenMethod_DedupsBySignature()
+ {
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ public class BaseTests
+ {
+ [TestMethod]
+ public void Hidden() { }
+ }
+
+ [TestClass]
+ public class DerivedTests : BaseTests
+ {
+ [TestMethod]
+ public new void Hidden() { }
+ }
+ }
+ """;
+
+ string registry = GetRegistry(RunGenerator(MinimalMSTestStub, userCode));
+
+ int hiddenEntries = registry.Split(["Name = \"Hidden\""], System.StringSplitOptions.None).Length - 1;
+ hiddenEntries.Should().Be(1, "members with the same name and signature must be de-duplicated; derived wins");
+ registry.Should().Contain("((global::Sample.DerivedTests)instance!).Hidden();");
+ }
+
+ [TestMethod]
+ public void Generator_OverloadsWithDifferentSignatures_AreAllPreserved()
+ {
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ public class BaseTests
+ {
+ [TestMethod]
+ public void Op(int x) { }
+ }
+
+ [TestClass]
+ public class DerivedTests : BaseTests
+ {
+ [TestMethod]
+ public void Op(string x) { }
+ }
+ }
+ """;
+
+ string registry = GetRegistry(RunGenerator(MinimalMSTestStub, userCode));
+
+ // Both overloads survive — they have different signatures.
+ int opEntries = registry.Split(["Name = \"Op\""], System.StringSplitOptions.None).Length - 1;
+ opEntries.Should().Be(2);
+ registry.Should().Contain("typeof(int)");
+ registry.Should().Contain("typeof(string)");
+ }
+
+ [TestMethod]
+ public void Generator_IncludesPropertiesFromBaseType()
+ {
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ public class TestContext { }
+
+ public class BaseTests
+ {
+ [TestContext]
+ public virtual TestContext Context { get; set; } = new();
+ }
+
+ [TestClass]
+ public class DerivedTests : BaseTests
+ {
+ [TestMethod]
+ public void Test() { }
+ }
+ }
+ """;
+
+ string registry = GetRegistry(RunGenerator(MinimalMSTestStub, userCode));
+
+ registry.Should().Contain("Name = \"Context\"");
+ registry.Should().Contain("global::Microsoft.VisualStudio.TestTools.UnitTesting.TestContextAttribute");
+ }
+
+ [TestMethod]
+ public void Generator_DoesNotInheritConstructors()
+ {
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ public class BaseTests
+ {
+ public BaseTests(int x) { }
+ }
+
+ [TestClass]
+ public class DerivedTests : BaseTests
+ {
+ public DerivedTests() : base(1) { }
+
+ [TestMethod]
+ public void Test() { }
+ }
+ }
+ """;
+
+ string registry = GetRegistry(RunGenerator(MinimalMSTestStub, userCode));
+
+ // Only the derived ctor (parameterless) should be emitted — base ctor is never inherited.
+ registry.Should().Contain("Invoke = static args => new global::Sample.DerivedTests(),");
+ registry.Should().NotContain("Invoke = static args => new global::Sample.BaseTests(");
+ // No int parameter from the base constructor leaks into the constructor list.
+ registry.Should().NotContain("ParameterTypes = new Type[] { typeof(int) },");
+ }
+
+ [TestMethod]
+ public void Generator_AbstractBaseWithConcreteDerived_FoldsBaseMembers()
+ {
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ public class TestContext { }
+
+ public abstract class AbstractBase
+ {
+ [TestInitialize]
+ public void Setup() { }
+
+ [TestContext]
+ public TestContext Ctx { get; set; } = new();
+ }
+
+ [TestClass]
+ public class Concrete : AbstractBase
+ {
+ [TestMethod]
+ public void Test() { }
+ }
+ }
+ """;
+
+ string registry = GetRegistry(RunGenerator(MinimalMSTestStub, userCode));
+
+ registry.Should().Contain("Name = \"Setup\"");
+ registry.Should().Contain("Name = \"Ctx\"");
+ registry.Should().Contain("Name = \"Test\"");
+ // The base class was abstract but the concrete derived type is the one emitted.
+ registry.Should().Contain("Type = typeof(global::Sample.Concrete)");
+ }
+
+ [TestMethod]
+ public void Generator_DoesNotWalkPastSystemObject()
+ {
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ [TestClass]
+ public class SimpleTests
+ {
+ [TestMethod]
+ public void Test() { }
+ }
+ }
+ """;
+
+ string registry = GetRegistry(RunGenerator(MinimalMSTestStub, userCode));
+
+ // Members of System.Object (ToString, Equals, GetHashCode, GetType) must NOT be emitted.
+ registry.Should().NotContain("Name = \"ToString\"");
+ registry.Should().NotContain("Name = \"Equals\"");
+ registry.Should().NotContain("Name = \"GetHashCode\"");
+ registry.Should().NotContain("Name = \"GetType\"");
+ }
+
+ [TestMethod]
+ public void Generator_SkipsProtectedAndPrivateProtectedMembers()
+ {
+ // Generated invokers live in a separate static helper class (not a derived type),
+ // so 'protected' and 'private protected' members are not callable from the emitted
+ // code. They MUST be excluded from the registry to keep the generated source compiling.
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ public class BaseTests
+ {
+ [TestMethod]
+ protected void ProtectedTest() { }
+
+ [TestMethod]
+ private protected void PrivateProtectedTest() { }
+ }
+
+ [TestClass]
+ public class DerivedTests : BaseTests
+ {
+ [TestMethod]
+ public void PublicTest() { }
+ }
+ }
+ """;
+
+ Compilation outputCompilation = RunGeneratorAndGetCompilation(MinimalMSTestStub, userCode);
+ IEnumerable errors = outputCompilation
+ .GetDiagnostics()
+ .Where(d => d.Severity == DiagnosticSeverity.Error);
+ errors.Should().BeEmpty("the generated registry MUST NOT reference protected / private protected members");
+
+ string registry = outputCompilation
+ .SyntaxTrees
+ .Single(t => t.FilePath.EndsWith("MSTestReflectionMetadata.Registry.g.cs", System.StringComparison.Ordinal))
+ .ToString();
+
+ registry.Should().Contain("Name = \"PublicTest\"");
+ registry.Should().NotContain("Name = \"ProtectedTest\"");
+ registry.Should().NotContain("Name = \"PrivateProtectedTest\"");
+ }
+
+ [TestMethod]
+ public void Generator_KeepsAllowMultipleAttributes_AcrossOverrideChain()
+ {
+ // [TestCategory] is AllowMultiple=true. Collecting attributes across the override chain
+ // MUST keep every instance instead of collapsing them by type, otherwise the inherited
+ // categories disappear from the registry.
+ const string userCode = """
+ using Microsoft.VisualStudio.TestTools.UnitTesting;
+
+ namespace Sample
+ {
+ public class BaseTests
+ {
+ [TestMethod]
+ [TestCategory("BaseCat")]
+ public virtual void Run() { }
+ }
+
+ [TestClass]
+ public class DerivedTests : BaseTests
+ {
+ [TestCategory("DerivedCat")]
+ public override void Run() { }
+ }
+ }
+ """;
+
+ string registry = GetRegistry(RunGenerator(MinimalMSTestStub, userCode));
+
+ registry.Should().Contain("\"BaseCat\"");
+ registry.Should().Contain("\"DerivedCat\"");
+ }
+
+ private static string GetRegistry(GeneratorRunResult result)
+ => result.GeneratedSources
+ .Single(s => s.HintName == "MSTestReflectionMetadata.Registry.g.cs")
+ .SourceText.ToString()
+ .Replace("\r\n", "\n");
+
+ private static GeneratorRunResult RunGenerator(params string[] sources)
+ {
+ CSharpCompilation compilation = CreateCompilation(sources);
+ GeneratorDriver driver = CSharpGeneratorDriver.Create(new MSTestReflectionMetadataGenerator());
+ driver = driver.RunGeneratorsAndUpdateCompilation(compilation, out _, out _);
+ return driver.GetRunResult().Results[0];
+ }
+
+ private static Compilation RunGeneratorAndGetCompilation(params string[] sources)
+ {
+ CSharpCompilation compilation = CreateCompilation(sources);
+ GeneratorDriver driver = CSharpGeneratorDriver.Create(new MSTestReflectionMetadataGenerator());
+ driver.RunGeneratorsAndUpdateCompilation(compilation, out Compilation outputCompilation, out _);
+ return outputCompilation;
+ }
+
+ private static CSharpCompilation CreateCompilation(params string[] sources)
+ {
+ IEnumerable trees = sources.Select(s => CSharpSyntaxTree.ParseText(s));
+ MetadataReference[] references = new[]
+ {
+ MetadataReference.CreateFromFile(typeof(object).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(System.Runtime.CompilerServices.ModuleInitializerAttribute).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(System.Reflection.Assembly).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(System.Collections.Generic.Dictionary<,>).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(System.Reflection.MethodInfo).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(System.Linq.Enumerable).Assembly.Location),
+ MetadataReference.CreateFromFile(typeof(System.Threading.Tasks.Task).Assembly.Location),
+ };
+
+ return CSharpCompilation.Create(
+ "TestSample",
+ trees,
+ references,
+ new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary));
+ }
+}
diff --git a/test/UnitTests/MSTest.AotReflection.SourceGeneration.UnitTests/Program.cs b/test/UnitTests/MSTest.AotReflection.SourceGeneration.UnitTests/Program.cs
new file mode 100644
index 0000000000..9fbdc73473
--- /dev/null
+++ b/test/UnitTests/MSTest.AotReflection.SourceGeneration.UnitTests/Program.cs
@@ -0,0 +1,19 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT license. See LICENSE file in the project root for full license information.
+
+using Microsoft.Testing.Extensions;
+
+ITestApplicationBuilder builder = await TestApplication.CreateBuilderAsync(args);
+builder.AddMSTest(() => [Assembly.GetEntryAssembly()!]);
+
+#if ENABLE_CODECOVERAGE
+builder.AddCodeCoverageProvider();
+#endif
+builder.AddHangDumpProvider();
+builder.AddCrashDumpProvider(ignoreIfNotSupported: true);
+builder.AddTrxReportProvider();
+builder.AddAppInsightsTelemetryProvider();
+builder.AddAzureDevOpsProvider();
+
+using ITestApplication app = await builder.BuildAsync();
+return await app.RunAsync();