From 701e4ac87e80508ca4458d11b1324c5ea729fde0 Mon Sep 17 00:00:00 2001 From: markbrady Date: Mon, 9 Feb 2026 09:25:40 -0800 Subject: [PATCH] [IfChainToSwitch] refactor common logic into `SwitchUtils` library PiperOrigin-RevId: 867639648 --- .../bugpatterns/IfChainToSwitch.java | 91 ++++++++----------- .../bugpatterns/IfChainToSwitchTest.java | 87 ++++++++++++++---- 2 files changed, 110 insertions(+), 68 deletions(-) diff --git a/core/src/main/java/com/google/errorprone/bugpatterns/IfChainToSwitch.java b/core/src/main/java/com/google/errorprone/bugpatterns/IfChainToSwitch.java index 12fb26e5359..fa5c3b85d81 100644 --- a/core/src/main/java/com/google/errorprone/bugpatterns/IfChainToSwitch.java +++ b/core/src/main/java/com/google/errorprone/bugpatterns/IfChainToSwitch.java @@ -20,6 +20,9 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.errorprone.BugPattern.SeverityLevel.WARNING; +import static com.google.errorprone.bugpatterns.SwitchUtils.COMPILE_TIME_CONSTANT_MATCHER; +import static com.google.errorprone.bugpatterns.SwitchUtils.isEnumValue; +import static com.google.errorprone.bugpatterns.SwitchUtils.renderComments; import static com.google.errorprone.matchers.Description.NO_MATCH; import static com.google.errorprone.util.ASTHelpers.constValue; import static com.google.errorprone.util.ASTHelpers.getStartPosition; @@ -40,11 +43,11 @@ import com.google.errorprone.ErrorProneFlags; import com.google.errorprone.VisitorState; import com.google.errorprone.bugpatterns.BugChecker.IfTreeMatcher; +import com.google.errorprone.bugpatterns.SwitchUtils.Validity; +import com.google.errorprone.bugpatterns.threadsafety.ConstantExpressions; import com.google.errorprone.fixes.SuggestedFix; import com.google.errorprone.fixes.SuggestedFixes; -import com.google.errorprone.matchers.CompileTimeConstantExpressionMatcher; import com.google.errorprone.matchers.Description; -import com.google.errorprone.matchers.Matcher; import com.google.errorprone.suppliers.Suppliers; import com.google.errorprone.util.ASTHelpers; import com.google.errorprone.util.ErrorProneComment; @@ -87,28 +90,18 @@ public final class IfChainToSwitch extends BugChecker implements IfTreeMatcher { // it's either an ExpressionStatement or a Throw. Refer to JLS 14 ยง14.11.1 private static final ImmutableSet KINDS_CONVERTIBLE_WITHOUT_BRACES = ImmutableSet.of(THROW, EXPRESSION_STATEMENT); - private static final Matcher COMPILE_TIME_CONSTANT_MATCHER = - CompileTimeConstantExpressionMatcher.instance(); - - /** - * Tri-state of whether the if-chain is valid, invalid, or possibly valid for conversion to a - * switch. - */ - enum Validity { - MAYBE_VALID, - INVALID, - VALID - } private final boolean enableMain; private final boolean enableSafe; private final int maxChainLength; + private final ConstantExpressions constantExpressions; @Inject - IfChainToSwitch(ErrorProneFlags flags) { + IfChainToSwitch(ErrorProneFlags flags, ConstantExpressions constantExpressions) { enableMain = flags.getBoolean("IfChainToSwitch:EnableMain").orElse(false); enableSafe = flags.getBoolean("IfChainToSwitch:EnableSafe").orElse(false); maxChainLength = flags.getInteger("IfChainToSwitch:MaxChainLength").orElse(50); + this.constantExpressions = constantExpressions; } @Override @@ -346,14 +339,6 @@ private static Range buildCommentRange(ErrorProneComment comment, int i return Range.closedOpen(comment.getPos() + ifTreeStart, comment.getEndPos() + ifTreeStart); } - /** Render the supplied comments, separated by newlines. */ - private static String renderComments(ImmutableList comments) { - return comments.stream() - .map(ErrorProneComment::getText) - .filter(commentText -> !commentText.isEmpty()) - .collect(joining("\n")); - } - /** * Renders Java source code representation of the supplied {@code Type} that is suitable for use * in fixes, where any raw types are replaced with wildcard types. For example, `List` becomes @@ -465,7 +450,7 @@ && isSubtype( boolean hasPattern = cases.stream().anyMatch(x -> x.instanceOfOptional().isPresent()); boolean allEnumValuesPresent = - isEnum(subject, state) + isEnumValue(subject, state) && handledEnumValues.containsAll(ASTHelpers.enumValues(switchType.asElement())); if (hasDefault && hasUnconditional) { @@ -796,7 +781,7 @@ private static boolean hasDominanceViolation( * converting it to a switch statement. Returns the analysis state following the analysis of the * if statement at this level. */ - private static IfChainAnalysisState analyzeIfStatement( + private IfChainAnalysisState analyzeIfStatement( IfChainAnalysisState ifChainAnalysisState, ExpressionTree condition, StatementTree conditionalBlock, @@ -866,10 +851,6 @@ private static IfChainAnalysisState analyzeIfStatement( ImmutableSet.copyOf(handledEnumValues)); } - private static boolean isEnum(ExpressionTree tree, VisitorState state) { - return isSubtype(getType(tree), state.getSymtab().enumSym.type, state); - } - /** Determines whether any yield or break statements are present in the tree. */ private static boolean hasBreakOrYieldInTree(Tree tree) { Boolean result = @@ -900,7 +881,7 @@ public Boolean reduce(@Nullable Boolean left, @Nullable Boolean right) { * so even if this method returns {@code Optional.empty()} the predicate may still be convertible * by some other (unsupported) means. */ - private static Optional validatePredicateForSubject( + private Optional validatePredicateForSubject( ExpressionTree predicate, Optional subject, VisitorState state, @@ -955,7 +936,7 @@ private static Optional validatePredicateForSubject( hasElseIf); } else { // Predicate is a binary tree, but neither side is a constant. - if (isEnum(lhs, state) || isEnum(rhs, state)) { + if (isEnumValue(lhs, state) || isEnumValue(rhs, state)) { return validateEnumPredicateForSubject( lhs, rhs, @@ -1029,7 +1010,23 @@ private static Optional validatePredicateForSubject( return Optional.empty(); } - private static Optional validateInstanceofForSubject( + /** + * Determines whether the {@code subject} expression "matches" the given {@code expression}. If + * {@code enableSafe} is true, then matching means that the subject must be referring to the same + * variable or constant expression. If {@code enableSafe} is false, then we also allow expressions + * that have potential side-effects. + */ + private boolean subjectMatches( + ExpressionTree subject, ExpressionTree expression, VisitorState state) { + + boolean sameVariable = sameVariable(subject, expression); + + return enableSafe + ? sameVariable || constantExpressions.isSame(subject, expression, state) + : sameVariable || expressionSourceMatches(subject, expression, state); + } + + private Optional validateInstanceofForSubject( ExpressionTree at, InstanceOfTree instanceOfTree, Optional subject, @@ -1044,10 +1041,7 @@ private static Optional validateInstanceofForSubject( ExpressionTree expression = at; // Does this expression and the subject (if present) refer to the same thing? - if (subject.isPresent() - && !(sameVariable(subject.get(), expression) - || subject.get().equals(expression) - || expressionSourceMatches(subject, expression, state))) { + if (subject.isPresent() && !subjectMatches(subject.get(), expression, state)) { return Optional.empty(); } @@ -1134,7 +1128,7 @@ private static Optional validateInstanceofForSubject( return Optional.of(expression); } - private static Optional validateCompileTimeConstantForSubject( + private Optional validateCompileTimeConstantForSubject( ExpressionTree lhs, ExpressionTree rhs, Optional subject, @@ -1150,10 +1144,7 @@ private static Optional validateCompileTimeConstantForSubject( ExpressionTree testExpression = compileTimeConstantOnLhs ? rhs : lhs; ExpressionTree compileTimeConstant = compileTimeConstantOnLhs ? lhs : rhs; - if (subject.isPresent() - && !(sameVariable(subject.get(), testExpression) - || subject.get().equals(testExpression) - || expressionSourceMatches(subject, testExpression, state))) { + if (subject.isPresent() && !subjectMatches(subject.get(), testExpression, state)) { // Predicate not compatible with predicate of preceding if statement return Optional.empty(); } @@ -1210,7 +1201,7 @@ private static Optional validateCompileTimeConstantForSubject( return Optional.of(testExpression); } - private static Optional validateEnumPredicateForSubject( + private Optional validateEnumPredicateForSubject( ExpressionTree lhs, ExpressionTree rhs, Optional subject, @@ -1223,8 +1214,8 @@ private static Optional validateEnumPredicateForSubject( int caseEndPosition, boolean hasElse, boolean hasElseIf) { - boolean lhsIsEnumConstant = isEnum(lhs, state) && ASTHelpers.isEnumConstant(lhs); - boolean rhsIsEnumConstant = isEnum(rhs, state) && ASTHelpers.isEnumConstant(rhs); + boolean lhsIsEnumConstant = isEnumValue(lhs, state) && ASTHelpers.isEnumConstant(lhs); + boolean rhsIsEnumConstant = isEnumValue(rhs, state) && ASTHelpers.isEnumConstant(rhs); if (lhsIsEnumConstant && rhsIsEnumConstant) { // Comparing enum const to enum const, cannot convert @@ -1240,10 +1231,7 @@ private static Optional validateEnumPredicateForSubject( ExpressionTree compileTimeConstant = lhsIsEnumConstant ? lhs : rhs; ExpressionTree testExpression = lhsIsEnumConstant ? rhs : lhs; - if (subject.isPresent() - && !(sameVariable(subject.get(), testExpression) - || subject.get().equals(testExpression) - || expressionSourceMatches(subject, testExpression, state))) { + if (subject.isPresent() && !subjectMatches(subject.get(), testExpression, state)) { return Optional.empty(); } @@ -1298,10 +1286,9 @@ private static Optional validateEnumPredicateForSubject( * comments, etc. */ private static boolean expressionSourceMatches( - Optional subject, ExpressionTree expression, VisitorState state) { + ExpressionTree subject, ExpressionTree expression, VisitorState state) { - return subject.isPresent() - && state.getSourceForNode(subject.get()).equals(state.getSourceForNode(expression)); + return state.getSourceForNode(subject).equals(state.getSourceForNode(expression)); } /** Retrieves a list of all statements (if any) following the current path, if any. */ @@ -1540,7 +1527,7 @@ public static boolean isDominatedBy( continue; } } - boolean isEnum = isEnum(constantExpression, state); + boolean isEnum = isEnumValue(constantExpression, state); if (isEnum) { if (lhs.guardOptional().isPresent()) { // Guarded patterns cannot dominate enum values diff --git a/core/src/test/java/com/google/errorprone/bugpatterns/IfChainToSwitchTest.java b/core/src/test/java/com/google/errorprone/bugpatterns/IfChainToSwitchTest.java index 62ec25ff8b1..26ed8481d4e 100644 --- a/core/src/test/java/com/google/errorprone/bugpatterns/IfChainToSwitchTest.java +++ b/core/src/test/java/com/google/errorprone/bugpatterns/IfChainToSwitchTest.java @@ -327,18 +327,16 @@ public void ifChain_dontAlwaysPullUp_error() { import java.lang.Number; class Test { - private Object suit; - public void foo(Suit s) { - this.suit = null; + Object suit = s; System.out.println("yo"); - if (this.suit instanceof String) { + if (suit instanceof String) { System.out.println("It's a string!"); } else if (suit instanceof Number) { System.out.println("It's a number!"); } else if (suit instanceof Suit) { System.out.println("It's a Suit!"); - } else if (this.suit instanceof Object o) { + } else if (suit instanceof Object o) { System.out.println("It's an object!"); } throw new AssertionError(); @@ -351,10 +349,8 @@ public void foo(Suit s) { import java.lang.Number; class Test { - private Object suit; - public void foo(Suit s) { - this.suit = null; + Object suit = s; System.out.println("yo"); switch (suit) { case String unused -> System.out.println("It's a string!"); @@ -381,18 +377,16 @@ public void ifChain_dontAlwaysPullUpSafe_error() { import java.lang.Number; class Test { - private Object suit; - public void foo(Suit s) { - this.suit = null; + Object suit = s; System.out.println("yo"); - if (this.suit instanceof String) { + if (suit instanceof String) { System.out.println("It's a string!"); } else if (suit instanceof Number) { System.out.println("It's a number!"); } else if (suit instanceof Suit) { System.out.println("It's a Suit!"); - } else if (this.suit instanceof Object o) { + } else if (suit instanceof Object o) { System.out.println("It's an object!"); } throw new AssertionError(); @@ -405,10 +399,8 @@ public void foo(Suit s) { import java.lang.Number; class Test { - private Object suit; - public void foo(Suit s) { - this.suit = null; + Object suit = s; System.out.println("yo"); switch (suit) { case String unused -> System.out.println("It's a string!"); @@ -2375,6 +2367,69 @@ public void foo(Suit s) { .doTest(); } + @Test + public void ifChain_methodInvocation_error() { + refactoringHelper + .addInputLines( + "Test.java", + """ + class Test { + public void foo(Suit s) { + Object suit = s; + if (suit.hashCode() == 1) { + System.out.println("Hash code 1"); + } else if (suit.hashCode() == 2) { + System.out.println("Hash code 2"); + } else if (suit.hashCode() == 3) { + System.out.println("Hash code 3"); + } else throw new AssertionError("Some other hash code"); + } + } + """) + .addOutputLines( + "Test.java", + """ + class Test { + public void foo(Suit s) { + Object suit = s; + switch (suit.hashCode()) { + case 1 -> System.out.println("Hash code 1"); + case 2 -> System.out.println("Hash code 2"); + case 3 -> System.out.println("Hash code 3"); + default -> throw new AssertionError("Some other hash code"); + } + } + } + """) + .setArgs("-XepOpt:IfChainToSwitch:EnableMain") + .doTest(); + } + + @Test + public void ifChain_methodInvocationSafe_noError() { + // Same code as ifChain_methodInvocation_error, but should not refactor in safe mode because the + // subject is not the same variable. + helper + .addSourceLines( + "Test.java", + """ + class Test { + public void foo(Suit s) { + Object suit = s; + if (suit.hashCode() == 1) { + System.out.println("Hash code 1"); + } else if (suit.hashCode() == 2) { + System.out.println("Hash code 2"); + } else if (suit.hashCode() == 3) { + System.out.println("Hash code 3"); + } else throw new AssertionError("Some other hash code"); + } + } + """) + .setArgs("-XepOpt:IfChainToSwitch:EnableMain", "-XepOpt:IfChainToSwitch:EnableSafe=true") + .doTest(); + } + @Test public void ifChain_parameterizedTypeSafe_error() { // Raw types are converted to the wildcard type.