diff --git a/core/src/main/java/com/google/errorprone/bugpatterns/IfChainToSwitch.java b/core/src/main/java/com/google/errorprone/bugpatterns/IfChainToSwitch.java index 12fb26e5359..9c70953097b 100644 --- a/core/src/main/java/com/google/errorprone/bugpatterns/IfChainToSwitch.java +++ b/core/src/main/java/com/google/errorprone/bugpatterns/IfChainToSwitch.java @@ -40,6 +40,7 @@ import com.google.errorprone.ErrorProneFlags; import com.google.errorprone.VisitorState; import com.google.errorprone.bugpatterns.BugChecker.IfTreeMatcher; +import com.google.errorprone.bugpatterns.threadsafety.ConstantExpressions; import com.google.errorprone.fixes.SuggestedFix; import com.google.errorprone.fixes.SuggestedFixes; import com.google.errorprone.matchers.CompileTimeConstantExpressionMatcher; @@ -103,12 +104,14 @@ enum Validity { private final boolean enableMain; private final boolean enableSafe; private final int maxChainLength; + private final ConstantExpressions constantExpressions; @Inject - IfChainToSwitch(ErrorProneFlags flags) { + IfChainToSwitch(ErrorProneFlags flags, ConstantExpressions constantExpressions) { enableMain = flags.getBoolean("IfChainToSwitch:EnableMain").orElse(false); enableSafe = flags.getBoolean("IfChainToSwitch:EnableSafe").orElse(false); maxChainLength = flags.getInteger("IfChainToSwitch:MaxChainLength").orElse(50); + this.constantExpressions = constantExpressions; } @Override @@ -796,7 +799,7 @@ private static boolean hasDominanceViolation( * converting it to a switch statement. Returns the analysis state following the analysis of the * if statement at this level. */ - private static IfChainAnalysisState analyzeIfStatement( + private IfChainAnalysisState analyzeIfStatement( IfChainAnalysisState ifChainAnalysisState, ExpressionTree condition, StatementTree conditionalBlock, @@ -900,7 +903,7 @@ public Boolean reduce(@Nullable Boolean left, @Nullable Boolean right) { * so even if this method returns {@code Optional.empty()} the predicate may still be convertible * by some other (unsupported) means. */ - private static Optional validatePredicateForSubject( + private Optional validatePredicateForSubject( ExpressionTree predicate, Optional subject, VisitorState state, @@ -1029,7 +1032,23 @@ private static Optional validatePredicateForSubject( return Optional.empty(); } - private static Optional validateInstanceofForSubject( + /** + * Determines whether the {@code subject} expression "matches" the given {@code expression}. If + * {@code enableSafe} is true, then matching means that the subject must be referring to the same + * variable or constant expression. If {@code enableSafe} is false, then we also allow expressions + * that have potential side-effects. + */ + private boolean subjectMatches( + ExpressionTree subject, ExpressionTree expression, VisitorState state) { + + boolean sameVariable = sameVariable(subject, expression); + + return enableSafe + ? sameVariable || constantExpressions.isSame(subject, expression, state) + : sameVariable || expressionSourceMatches(subject, expression, state); + } + + private Optional validateInstanceofForSubject( ExpressionTree at, InstanceOfTree instanceOfTree, Optional subject, @@ -1044,10 +1063,7 @@ private static Optional validateInstanceofForSubject( ExpressionTree expression = at; // Does this expression and the subject (if present) refer to the same thing? - if (subject.isPresent() - && !(sameVariable(subject.get(), expression) - || subject.get().equals(expression) - || expressionSourceMatches(subject, expression, state))) { + if (subject.isPresent() && !subjectMatches(subject.get(), expression, state)) { return Optional.empty(); } @@ -1134,7 +1150,7 @@ private static Optional validateInstanceofForSubject( return Optional.of(expression); } - private static Optional validateCompileTimeConstantForSubject( + private Optional validateCompileTimeConstantForSubject( ExpressionTree lhs, ExpressionTree rhs, Optional subject, @@ -1150,10 +1166,7 @@ private static Optional validateCompileTimeConstantForSubject( ExpressionTree testExpression = compileTimeConstantOnLhs ? rhs : lhs; ExpressionTree compileTimeConstant = compileTimeConstantOnLhs ? lhs : rhs; - if (subject.isPresent() - && !(sameVariable(subject.get(), testExpression) - || subject.get().equals(testExpression) - || expressionSourceMatches(subject, testExpression, state))) { + if (subject.isPresent() && !subjectMatches(subject.get(), testExpression, state)) { // Predicate not compatible with predicate of preceding if statement return Optional.empty(); } @@ -1210,7 +1223,7 @@ private static Optional validateCompileTimeConstantForSubject( return Optional.of(testExpression); } - private static Optional validateEnumPredicateForSubject( + private Optional validateEnumPredicateForSubject( ExpressionTree lhs, ExpressionTree rhs, Optional subject, @@ -1240,10 +1253,7 @@ private static Optional validateEnumPredicateForSubject( ExpressionTree compileTimeConstant = lhsIsEnumConstant ? lhs : rhs; ExpressionTree testExpression = lhsIsEnumConstant ? rhs : lhs; - if (subject.isPresent() - && !(sameVariable(subject.get(), testExpression) - || subject.get().equals(testExpression) - || expressionSourceMatches(subject, testExpression, state))) { + if (subject.isPresent() && !subjectMatches(subject.get(), testExpression, state)) { return Optional.empty(); } @@ -1298,10 +1308,9 @@ private static Optional validateEnumPredicateForSubject( * comments, etc. */ private static boolean expressionSourceMatches( - Optional subject, ExpressionTree expression, VisitorState state) { + ExpressionTree subject, ExpressionTree expression, VisitorState state) { - return subject.isPresent() - && state.getSourceForNode(subject.get()).equals(state.getSourceForNode(expression)); + return state.getSourceForNode(subject).equals(state.getSourceForNode(expression)); } /** Retrieves a list of all statements (if any) following the current path, if any. */ diff --git a/core/src/test/java/com/google/errorprone/bugpatterns/IfChainToSwitchTest.java b/core/src/test/java/com/google/errorprone/bugpatterns/IfChainToSwitchTest.java index 62ec25ff8b1..26ed8481d4e 100644 --- a/core/src/test/java/com/google/errorprone/bugpatterns/IfChainToSwitchTest.java +++ b/core/src/test/java/com/google/errorprone/bugpatterns/IfChainToSwitchTest.java @@ -327,18 +327,16 @@ public void ifChain_dontAlwaysPullUp_error() { import java.lang.Number; class Test { - private Object suit; - public void foo(Suit s) { - this.suit = null; + Object suit = s; System.out.println("yo"); - if (this.suit instanceof String) { + if (suit instanceof String) { System.out.println("It's a string!"); } else if (suit instanceof Number) { System.out.println("It's a number!"); } else if (suit instanceof Suit) { System.out.println("It's a Suit!"); - } else if (this.suit instanceof Object o) { + } else if (suit instanceof Object o) { System.out.println("It's an object!"); } throw new AssertionError(); @@ -351,10 +349,8 @@ public void foo(Suit s) { import java.lang.Number; class Test { - private Object suit; - public void foo(Suit s) { - this.suit = null; + Object suit = s; System.out.println("yo"); switch (suit) { case String unused -> System.out.println("It's a string!"); @@ -381,18 +377,16 @@ public void ifChain_dontAlwaysPullUpSafe_error() { import java.lang.Number; class Test { - private Object suit; - public void foo(Suit s) { - this.suit = null; + Object suit = s; System.out.println("yo"); - if (this.suit instanceof String) { + if (suit instanceof String) { System.out.println("It's a string!"); } else if (suit instanceof Number) { System.out.println("It's a number!"); } else if (suit instanceof Suit) { System.out.println("It's a Suit!"); - } else if (this.suit instanceof Object o) { + } else if (suit instanceof Object o) { System.out.println("It's an object!"); } throw new AssertionError(); @@ -405,10 +399,8 @@ public void foo(Suit s) { import java.lang.Number; class Test { - private Object suit; - public void foo(Suit s) { - this.suit = null; + Object suit = s; System.out.println("yo"); switch (suit) { case String unused -> System.out.println("It's a string!"); @@ -2375,6 +2367,69 @@ public void foo(Suit s) { .doTest(); } + @Test + public void ifChain_methodInvocation_error() { + refactoringHelper + .addInputLines( + "Test.java", + """ + class Test { + public void foo(Suit s) { + Object suit = s; + if (suit.hashCode() == 1) { + System.out.println("Hash code 1"); + } else if (suit.hashCode() == 2) { + System.out.println("Hash code 2"); + } else if (suit.hashCode() == 3) { + System.out.println("Hash code 3"); + } else throw new AssertionError("Some other hash code"); + } + } + """) + .addOutputLines( + "Test.java", + """ + class Test { + public void foo(Suit s) { + Object suit = s; + switch (suit.hashCode()) { + case 1 -> System.out.println("Hash code 1"); + case 2 -> System.out.println("Hash code 2"); + case 3 -> System.out.println("Hash code 3"); + default -> throw new AssertionError("Some other hash code"); + } + } + } + """) + .setArgs("-XepOpt:IfChainToSwitch:EnableMain") + .doTest(); + } + + @Test + public void ifChain_methodInvocationSafe_noError() { + // Same code as ifChain_methodInvocation_error, but should not refactor in safe mode because the + // subject is not the same variable. + helper + .addSourceLines( + "Test.java", + """ + class Test { + public void foo(Suit s) { + Object suit = s; + if (suit.hashCode() == 1) { + System.out.println("Hash code 1"); + } else if (suit.hashCode() == 2) { + System.out.println("Hash code 2"); + } else if (suit.hashCode() == 3) { + System.out.println("Hash code 3"); + } else throw new AssertionError("Some other hash code"); + } + } + """) + .setArgs("-XepOpt:IfChainToSwitch:EnableMain", "-XepOpt:IfChainToSwitch:EnableSafe=true") + .doTest(); + } + @Test public void ifChain_parameterizedTypeSafe_error() { // Raw types are converted to the wildcard type.