diff --git a/src/DataverseAnalyzer/EnumAssignmentAnalyzer.cs b/src/DataverseAnalyzer/EnumAssignmentAnalyzer.cs index 10a57ec..9c3bb35 100644 --- a/src/DataverseAnalyzer/EnumAssignmentAnalyzer.cs +++ b/src/DataverseAnalyzer/EnumAssignmentAnalyzer.cs @@ -52,8 +52,20 @@ private static void AnalyzePropertyDeclaration(SyntaxNodeAnalysisContext context private static void AnalyzeEnumAssignmentForProperty(SyntaxNodeAnalysisContext context, PropertyDeclarationSyntax property, ExpressionSyntax right) { - // Check if the right side is a numeric literal - if (right is not LiteralExpressionSyntax literal || !IsNumericLiteral(literal)) + LiteralExpressionSyntax literal; + + // Check if the right side is a numeric literal or a cast expression with a numeric literal + if (right is LiteralExpressionSyntax directLiteral && IsNumericLiteral(directLiteral)) + { + literal = directLiteral; + } + else if (right is CastExpressionSyntax castExpr && + castExpr.Expression is LiteralExpressionSyntax castLiteral && + IsNumericLiteral(castLiteral)) + { + literal = castLiteral; + } + else { return; } @@ -88,8 +100,20 @@ private static void AnalyzeEnumAssignmentForProperty(SyntaxNodeAnalysisContext c private static void AnalyzeEnumAssignment(SyntaxNodeAnalysisContext context, SyntaxNode left, ExpressionSyntax right) { - // Check if the right side is a numeric literal - if (right is not LiteralExpressionSyntax literal || !IsNumericLiteral(literal)) + LiteralExpressionSyntax literal; + + // Check if the right side is a numeric literal or a cast expression with a numeric literal + if (right is LiteralExpressionSyntax directLiteral && IsNumericLiteral(directLiteral)) + { + literal = directLiteral; + } + else if (right is CastExpressionSyntax castExpr && + castExpr.Expression is LiteralExpressionSyntax castLiteral && + IsNumericLiteral(castLiteral)) + { + literal = castLiteral; + } + else { return; } diff --git a/tests/DataverseAnalyzer.Tests/EnumAssignmentAnalyzerTests.cs b/tests/DataverseAnalyzer.Tests/EnumAssignmentAnalyzerTests.cs index 7b12c90..e5637c8 100644 --- a/tests/DataverseAnalyzer.Tests/EnumAssignmentAnalyzerTests.cs +++ b/tests/DataverseAnalyzer.Tests/EnumAssignmentAnalyzerTests.cs @@ -197,7 +197,7 @@ public void TestMethod() } [Fact] - public async Task EnumPropertyAssignedCastShouldNotTrigger() + public async Task EnumPropertyAssignedCastShouldTrigger() { var source = """ public enum Status @@ -217,6 +217,110 @@ public void TestMethod() } """; + var diagnostics = await GetDiagnosticsAsync(source); + Assert.Single(diagnostics); + Assert.Equal("CT0002", diagnostics[0].Id); + } + + [Fact] + public async Task EnumCastWithDataverseNamingPatternShouldTrigger() + { + var source = """ + public enum demo_Entity_statuscode + { + ValidStatus = 1, + InvalidStatus = 2 + } + + class TestClass + { + public demo_Entity_statuscode statuscode { get; set; } + + public void TestMethod() + { + statuscode = (demo_Entity_statuscode)3; + } + } + """; + + var diagnostics = await GetDiagnosticsAsync(source); + Assert.Single(diagnostics); + Assert.Equal("CT0002", diagnostics[0].Id); + } + + [Fact] + public async Task PropertyInitializerWithCastShouldTrigger() + { + var source = """ + public enum Priority + { + Low = 1, + High = 2 + } + + class TestClass + { + public Priority Priority { get; set; } = (Priority)1; + } + """; + + var diagnostics = await GetDiagnosticsAsync(source); + Assert.Single(diagnostics); + Assert.Equal("CT0002", diagnostics[0].Id); + } + + [Fact] + public async Task MemberAccessEnumAssignedCastShouldTrigger() + { + var source = """ + public enum AccountCategoryCode + { + Standard = 1, + Preferred = 2 + } + + class Account + { + public AccountCategoryCode? AccountCategoryCode { get; set; } + } + + class TestClass + { + public void TestMethod() + { + var account = new Account(); + account.AccountCategoryCode = (AccountCategoryCode)2; + } + } + """; + + var diagnostics = await GetDiagnosticsAsync(source); + Assert.Single(diagnostics); + Assert.Equal("CT0002", diagnostics[0].Id); + } + + [Fact] + public async Task EnumCastFromVariableShouldNotTrigger() + { + var source = """ + public enum Status + { + Active = 1, + Inactive = 2 + } + + class TestClass + { + public Status Status { get; set; } + + public void TestMethod() + { + var value = 1; + Status = (Status)value; + } + } + """; + var diagnostics = await GetDiagnosticsAsync(source); Assert.Empty(diagnostics); }