diff --git a/src/ecooptimizer/refactorers/concrete/member_ignoring_method.py b/src/ecooptimizer/refactorers/concrete/member_ignoring_method.py index 25c02456..0d1fda6c 100644 --- a/src/ecooptimizer/refactorers/concrete/member_ignoring_method.py +++ b/src/ecooptimizer/refactorers/concrete/member_ignoring_method.py @@ -10,13 +10,15 @@ from ..multi_file_refactorer import MultiFileRefactorer from ...data_types.smell import MIMSmell +logger = CONFIG["refactorLogger"] + class CallTransformer(cst.CSTTransformer): METADATA_DEPENDENCIES = (PositionProvider,) def __init__(self, class_name: str): self.method_calls: list[tuple[str, int, str, str]] = None # type: ignore - self.class_name = class_name # Class name to replace instance calls + self.class_name = class_name # Class nme to replace instance calls self.transformed = False def set_calls(self, valid_calls: list[tuple[str, int, str, str]]): @@ -34,15 +36,13 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal # Check if this call matches one from astroid (by caller, method name, and line number) for call_caller, line, call_method, cls in self.method_calls: - CONFIG["refactorLogger"].debug( - f"cst caller: {call_caller} at line {position.start.line}" - ) + logger.debug(f"cst caller: {call_caller} at line {position.start.line}") if ( method == call_method and position.start.line == line and caller.deep_equals(cst.parse_expression(call_caller)) ): - CONFIG["refactorLogger"].debug("transforming") + logger.debug("transforming") # Transform `obj.method(args)` -> `ClassName.method(args)` new_func = cst.Attribute( value=cst.Name(cls), # Replace `obj` with class name @@ -65,12 +65,12 @@ def find_valid_method_calls( """ valid_calls = [] - CONFIG["refactorLogger"].info("Finding valid method calls") + logger.debug("Finding valid method calls") for node in tree.body: for descendant in node.nodes_of_class(nodes.Call): if isinstance(descendant.func, nodes.Attribute): - CONFIG["refactorLogger"].debug(f"caller: {descendant.func.expr.as_string()}") + logger.debug(f"caller: {descendant.func.expr.as_string()}") caller = descendant.func.expr # The object calling the method method_name = descendant.func.attrname @@ -78,28 +78,32 @@ def find_valid_method_calls( continue inferred_types: list[str] = [] - inferrences = caller.infer() - - for inferred in inferrences: - CONFIG["refactorLogger"].debug(f"inferred: {inferred.repr_name()}") - if isinstance(inferred, util.UninferableBase): - hint = check_for_annotations(caller, descendant.scope()) - inits = check_for_initializations(caller, descendant.scope()) - if hint: - inferred_types.append(hint.as_string()) - elif inits: - inferred_types.extend(inits) + try: + inferrences = caller.infer() + + for inferred in inferrences: + logger.debug(f"inferred: {inferred.repr_name()}") + if isinstance(inferred, util.UninferableBase): + hint = check_for_annotations(caller, descendant.scope()) + inits = check_for_initializations(caller, descendant.scope()) + if hint: + inferred_types.append(hint.as_string()) + elif inits: + inferred_types.extend(inits) + else: + continue else: - continue - else: - inferred_types.append(inferred.repr_name()) + inferred_types.append(inferred.repr_name()) + except astroid.InferenceError as e: + print(e) + continue - CONFIG["refactorLogger"].debug(f"Inferred types: {inferred_types}") + logger.debug(f"Inferred types: {inferred_types}") # Check if any inferred type matches a valid class for cls in inferred_types: if cls in valid_classes: - CONFIG["refactorLogger"].debug( + logger.debug( f"Foud valid call: {caller.as_string()} at line {descendant.lineno}" ) valid_calls.append( @@ -127,7 +131,7 @@ def check_for_annotations(caller: nodes.NodeNG, scope: nodes.NodeNG): return None hint = None - CONFIG["refactorLogger"].debug(f"annotations: {scope.args}") + logger.debug(f"annotations: {scope.args}") args = scope.args.args anns = scope.args.annotations @@ -162,6 +166,8 @@ def refactor( self.target_line = smell.occurences[0].line self.target_file = target_file + print("smell:", smell) + if not smell.obj: raise TypeError("No method object found") @@ -194,12 +200,12 @@ def get_subclasses(tree: nodes.Module): subclasses.add(klass.name) return subclasses - CONFIG["refactorLogger"].debug("find all subclasses") + logger.debug("find all subclasses") self.traverse(directory) for file in self.py_files: tree = astroid.parse(file.read_text()) self.valid_classes = self.valid_classes.union(get_subclasses(tree)) - CONFIG["refactorLogger"].debug(f"valid classes: {self.valid_classes}") + logger.debug(f"valid classes: {self.valid_classes}") def _process_file(self, file: Path): processed = False @@ -228,7 +234,7 @@ def leave_FunctionDef( if func_name and updated_node.deep_equals(original_node): position = self.get_metadata(PositionProvider, original_node).start # type: ignore if position.line == self.target_line and func_name == self.mim_method: - CONFIG["refactorLogger"].debug("Modifying MIM method") + logger.debug("Modifying MIM method") decorators = [ *list(original_node.decorators), cst.Decorator(cst.Name("staticmethod")), diff --git a/src/ecooptimizer/refactorers/multi_file_refactorer.py b/src/ecooptimizer/refactorers/multi_file_refactorer.py index f5ee57e0..77d8dc4f 100644 --- a/src/ecooptimizer/refactorers/multi_file_refactorer.py +++ b/src/ecooptimizer/refactorers/multi_file_refactorer.py @@ -60,7 +60,7 @@ def traverse(self, directory: Path): continue CONFIG["refactorLogger"].debug(f"Entering directory: {item!s}") - self.traverse_and_process(item) + self.traverse(item) elif item.is_file() and item.suffix == ".py": self.py_files.append(item) diff --git a/tests/refactorers/test_member_ignoring_method.py b/tests/refactorers/test_member_ignoring_method.py index 1531049b..2b930a57 100644 --- a/tests/refactorers/test_member_ignoring_method.py +++ b/tests/refactorers/test_member_ignoring_method.py @@ -101,7 +101,6 @@ def mim_method(x): result = Example.mim_method(5) """) - # Check if the refactoring worked assert file1.read_text().strip() == expected_file1.strip() assert file2.read_text().strip() == expected_file2.strip() @@ -169,7 +168,6 @@ class SubExample(Example): result = SubExample.mim_method(5) """) - # Check if the refactoring worked assert file1.read_text().strip() == expected_file1.strip() assert file2.read_text().strip() == expected_file2.strip() @@ -239,7 +237,6 @@ class SubExample(Example): result = SubExample.mim_method(5) """) - # Check if the refactoring worked assert file1.read_text().strip() == expected_file1.strip() assert file2.read_text().strip() == expected_file2.strip() @@ -309,7 +306,6 @@ def mim_method(self, x): result = example.mim_method(5) """) - # Check if the refactoring worked assert file1.read_text().strip() == expected_file1.strip() assert file2.read_text().strip() == expected_file2.strip() @@ -360,5 +356,257 @@ def test(example: Example): num = Example.mim_method(5) """) - # Check if the refactoring worked + assert file1.read_text().strip() == expected_file1.strip() + + +def test_mim_multiple_classes_same_method_name(source_files, refactorer): + """ + Tests that only the method call from the correct class instance is refactored + when there are multiple method calls with the same method name but from + instances of different classes. + """ + + # --- File 1: Defines the methods in different classes --- + test_dir = Path(source_files, "temp_multiple_classes_mim") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "class_def.py" + file1.write_text( + textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + def mim_method(self, x): + return x * 2 + + class AnotherExample: + def mim_method(self, x): + return x + 3 + + example = Example() + another_example = AnotherExample() + num1 = example.mim_method(5) + num2 = another_example.mim_method(5) + """) + ) + + smell = create_smell(occurences=[4], obj="Example.mim_method")() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # --- Expected Result for File 1 --- + expected_file1 = textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + @staticmethod + def mim_method(x): + return x * 2 + + class AnotherExample: + def mim_method(self, x): + return x + 3 + + example = Example() + another_example = AnotherExample() + num1 = Example.mim_method(5) + num2 = another_example.mim_method(5) + """) + + assert file1.read_text().strip() == expected_file1.strip() + + +def test_mim_ignores_wrong_method_call(source_files, refactorer): + """ + Tests that a different method call from the same class is not refactored. + """ + + # --- File 1: Defines the method --- + test_dir = Path(source_files, "temp_mim_type_hint_mim") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "class_def.py" + file1.write_text( + textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + + def mim_method(self, x): + return x * 2 + + def other_method(self): + print(self.attr) + + example = Example() + example.other_method() + """) + ) + + smell = create_smell(occurences=[5], obj="Example.mim_method")() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # --- Expected Result for File 1 --- + expected_file1 = textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + + @staticmethod + def mim_method(x): + return x * 2 + + def other_method(self): + print(self.attr) + + example = Example() + example.other_method() + """) + + assert file1.read_text().strip() == expected_file1.strip() + + +def test_mim_method_in_class_with_decorator(source_files, refactorer): + """ + Tests that methods in classes with decorators (e.g., @dataclass) are correctly refactored. + """ + + # --- File 1: Defines the method --- + test_dir = Path(source_files, "temp_decorated_class_mim") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "class_def.py" + file1.write_text( + textwrap.dedent("""\ + from dataclasses import dataclass + @dataclass + class Example: + attr: str + + def mim_method(self, x): + return x * 2 + + example = Example(attr="something") + num = example.mim_method(5) + """) + ) + + smell = create_smell(occurences=[6], obj="Example.mim_method")() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # --- Expected Result for File 1 --- + expected_file1 = textwrap.dedent("""\ + from dataclasses import dataclass + @dataclass + class Example: + attr: str + + @staticmethod + def mim_method(x): + return x * 2 + + example = Example(attr="something") + num = Example.mim_method(5) + """) + + assert file1.read_text().strip() == expected_file1.strip() + + +def test_mim_method_with_existing_decorator(source_files, refactorer): + """ + Tests that methods with existing decorators retain those decorators + when the @staticmethod decorator is added. + """ + + # --- File 1: Defines the method --- + test_dir = Path(source_files, "temp_existing_decorator_mim") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "class_def.py" + file1.write_text( + textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + + @custom_decorator + def mim_method(self, x): + return x * 2 + + example = Example() + num = example.mim_method(5) + """) + ) + + smell = create_smell(occurences=[6], obj="Example.mim_method")() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # --- Expected Result for File 1 --- + expected_file1 = textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + + @custom_decorator + @staticmethod + def mim_method(x): + return x * 2 + + example = Example() + num = Example.mim_method(5) + """) + + assert file1.read_text().strip() == expected_file1.strip() + + +def test_mim_method_with_multiple_decorators(source_files, refactorer): + """ + Tests that methods with multiple existing decorators retain all of them + when the @staticmethod decorator is added. + """ + + # --- File 1: Defines the method --- + test_dir = Path(source_files, "temp_multiple_decorators_mim") + test_dir.mkdir(exist_ok=True) + + file1 = test_dir / "class_def.py" + file1.write_text( + textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + + @decorator_one + @decorator_two + def mim_method(self, x): + return x * 2 + + example = Example() + num = example.mim_method(5) + """) + ) + + smell = create_smell(occurences=[7], obj="Example.mim_method")() + + refactorer.refactor(file1, test_dir, smell, Path("fake.py")) + + # --- Expected Result for File 1 --- + expected_file1 = textwrap.dedent("""\ + class Example: + def __init__(self): + self.attr = "something" + + @decorator_one + @decorator_two + @staticmethod + def mim_method(x): + return x * 2 + + example = Example() + num = Example.mim_method(5) + """) + assert file1.read_text().strip() == expected_file1.strip()