diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..542a828 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,13 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.15.9 + hooks: + - id: ruff-check + args: [--fix] + - id: ruff-format + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v6.0.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml diff --git a/pyproject.toml b/pyproject.toml index 21fd049..5b64fd6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,3 +51,15 @@ push-changes = false tag-format = "v{version}" tag-message = "unittest2pytest {version}" tag-signing = true + +[tool.ruff] +target-version = "py39" +extend-exclude = ["tests/fixtures"] + +[tool.ruff.lint] +select = ["E", "F", "I", "W", "UP"] +ignore = ["E501", "E701", "E741", "UP031"] + +[tool.ruff.lint.per-file-ignores] +"unittest2pytest/fixes/fix_remove_class.py" = ["W291"] +"unittest2pytest/fixes/fix_self_assert.py" = ["F841"] diff --git a/tests/test_all_fixes.py b/tests/test_all_fixes.py index f288143..9926b83 100644 --- a/tests/test_all_fixes.py +++ b/tests/test_all_fixes.py @@ -21,45 +21,43 @@ # along with this program. If not, see . # - __author__ = "Hartmut Goebel " __copyright__ = "Copyright 2015-2019 by Hartmut Goebel" __licence__ = "GNU General Public License version 3 or later (GPLv3+)" -import pytest - - +import glob +import logging import os -from os.path import join, abspath import re -import glob -import shutil -from difflib import unified_diff import unittest -import logging +from difflib import unified_diff +from os.path import abspath, join +import pytest from fissix.main import main # make logging less verbose -logging.getLogger('fissix.main').setLevel(logging.WARN) -logging.getLogger('RefactoringTool').setLevel(logging.WARN) +logging.getLogger("fissix.main").setLevel(logging.WARN) +logging.getLogger("RefactoringTool").setLevel(logging.WARN) + +FIXTURE_PATH = os.path.join(os.path.dirname(__file__), "fixtures") -FIXTURE_PATH = os.path.join(os.path.dirname(__file__), 'fixtures') def requiredTestMethod(name): # skip if TestCase does not have this method is_missing = getattr(unittest.TestCase, name, None) is None - return pytest.mark.skipif(is_missing, - reason="unittest does not have TestCase.%s " % name) + return pytest.mark.skipif( + is_missing, reason="unittest does not have TestCase.%s " % name + ) def _collect_in_files_from_directory(directory): - fixture_files = glob.glob(abspath(join(directory, '*_in.py'))) + fixture_files = glob.glob(abspath(join(directory, "*_in.py"))) for fixture_file in fixture_files: with open(fixture_file) as fh: text = fh.read(200) - l = re.findall(r'^# required-method: (\S+)', text) + l = re.findall(r"^# required-method: (\S+)", text) method = l[0] if l else None yield fixture_file, method @@ -70,7 +68,7 @@ def collect_all_test_fixtures(): # subdirectory, only run the fixer of the subdirectory name, else run # all fixers. for in_file, method in _collect_in_files_from_directory(root): - fixer_to_run = root[len(FIXTURE_PATH)+1:] or None + fixer_to_run = root[len(FIXTURE_PATH) + 1 :] or None marks = [] if method: marks.append(requiredTestMethod(method)) @@ -82,20 +80,39 @@ def _get_id(argvalue): return os.path.basename(argvalue).replace("_in.py", "") -@pytest.mark.parametrize("fixer, in_file", - collect_all_test_fixtures(), ids=_get_id) +@pytest.mark.parametrize("fixer, in_file", collect_all_test_fixtures(), ids=_get_id) def test_check_fixture(in_file, fixer, tmpdir): if fixer: - main("unittest2pytest.fixes", - args=['--no-diffs', '--fix', fixer, '-w', in_file, - '--nobackups', '--output-dir', str(tmpdir)]) + main( + "unittest2pytest.fixes", + args=[ + "--no-diffs", + "--fix", + fixer, + "-w", + in_file, + "--nobackups", + "--output-dir", + str(tmpdir), + ], + ) else: - main("unittest2pytest.fixes", - args=['--no-diffs', '--fix', 'all', '-w', in_file, - '--nobackups', '--output-dir', str(tmpdir)]) + main( + "unittest2pytest.fixes", + args=[ + "--no-diffs", + "--fix", + "all", + "-w", + in_file, + "--nobackups", + "--output-dir", + str(tmpdir), + ], + ) result_file_name = tmpdir.join(os.path.basename(in_file)) - assert result_file_name.exists(), '%s is missing' % result_file_name + assert result_file_name.exists(), "%s is missing" % result_file_name result_file_contents = result_file_name.readlines() expected_file = in_file.replace("_in.py", "_out.py") @@ -104,13 +121,15 @@ def test_check_fixture(in_file, fixer, tmpdir): # ensure the expected code is actually correct and compiles try: - compile(''.join(expected_contents), expected_file, 'exec') + compile("".join(expected_contents), expected_file, "exec") except Exception as e: - pytest.fail(f"FATAL: {expected_file} does not compile: {e}", - False) + pytest.fail(f"FATAL: {expected_file} does not compile: {e}", False) if result_file_contents != expected_contents: text = "Refactured code doesn't match expected outcome\n" - text += ''.join(unified_diff(expected_contents, result_file_contents, - 'expected', 'refactured result')) + text += "".join( + unified_diff( + expected_contents, result_file_contents, "expected", "refactured result" + ) + ) pytest.fail(text, False) diff --git a/unittest2pytest/__init__.py b/unittest2pytest/__init__.py index c34a1a1..22e45fb 100644 --- a/unittest2pytest/__init__.py +++ b/unittest2pytest/__init__.py @@ -22,5 +22,5 @@ __licence__ = "GNU General Public License version 3 or later (GPLv3+)" -__title__ = 'unittest2pytest' -__version__ = '0.6.dev0' +__title__ = "unittest2pytest" +__version__ = "0.6.dev0" diff --git a/unittest2pytest/__main__.py b/unittest2pytest/__main__.py index 1feb059..513bcad 100755 --- a/unittest2pytest/__main__.py +++ b/unittest2pytest/__main__.py @@ -23,10 +23,13 @@ import fissix.main + from . import fixes + def main(): raise SystemExit(fissix.main.main(fixes.__name__)) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/unittest2pytest/fixes/fix_remove_class.py b/unittest2pytest/fixes/fix_remove_class.py index aee86b9..e51a2e4 100644 --- a/unittest2pytest/fixes/fix_remove_class.py +++ b/unittest2pytest/fixes/fix_remove_class.py @@ -27,7 +27,7 @@ from fissix.fixer_base import BaseFix -from fissix.fixer_util import token, find_indentation +from fissix.fixer_util import find_indentation, token """ Node(classdef, @@ -47,8 +47,8 @@ Leaf(6, '')])]) """ -class FixRemoveClass(BaseFix): +class FixRemoveClass(BaseFix): PATTERN = """ classdef< 'class' name=any '(' 'TestCase' ')' ':' suite=suite @@ -67,10 +67,9 @@ def dedent(self, suite, dedent): # todo: handle tabs if len(kid.prefix) > len(self.current_indent): kid.prefix = self.current_indent - def transform(self, node, results): - suite = results['suite'].clone() + suite = results["suite"].clone() # todo: handle tabs dedent = len(find_indentation(suite)) - len(find_indentation(node)) self.dedent(suite, dedent) diff --git a/unittest2pytest/fixes/fix_self_assert.py b/unittest2pytest/fixes/fix_self_assert.py index 18c4f9f..e4c9992 100644 --- a/unittest2pytest/fixes/fix_self_assert.py +++ b/unittest2pytest/fixes/fix_self_assert.py @@ -26,21 +26,33 @@ __licence__ = "GNU General Public License version 3 or later (GPLv3+)" -from fissix.fixer_base import BaseFix -from fissix.fixer_util import ( - Comma, Name, Call, Node, Leaf, - Newline, KeywordArg, find_indentation, - ArgList, String, Number, syms, token, - does_tree_import, is_import, parenthesize) - -from functools import partial import re import unittest +from functools import partial + +from fissix.fixer_base import BaseFix +from fissix.fixer_util import ( + Call, + Comma, + KeywordArg, + Leaf, + Name, + Newline, + Node, + Number, + String, + does_tree_import, + find_indentation, + is_import, + parenthesize, + syms, + token, +) from .. import utils +TEMPLATE_PATTERN = re.compile("[\1\2]|[^\1\2]+") -TEMPLATE_PATTERN = re.compile('[\1\2]|[^\1\2]+') def CompOp(op, left, right, kws): op = Name(op, prefix=" ") @@ -48,7 +60,7 @@ def CompOp(op, left, right, kws): right = parenthesize_expression(right) left.prefix = "" - if '\n' not in right.prefix: + if "\n" not in right.prefix: right.prefix = " " return Node(syms.comparison, (left, op, right), prefix=" ") @@ -78,11 +90,12 @@ def UnaryOp(prefix, postfix, value, kws): syms.comparison, ] + def parenthesize_expression(value): if value.type in _NEEDS_PARENTHESIS: parenthesized = parenthesize(value.clone()) parenthesized.prefix = parenthesized.children[1].prefix - parenthesized.children[1].prefix = '' + parenthesized.children[1].prefix = "" value = parenthesized return value @@ -91,43 +104,43 @@ def fill_template(template, *args): parts = TEMPLATE_PATTERN.findall(template) kids = [] for p in parts: - if p == '': + if p == "": continue - elif p in '\1\2\3\4\5': - p = args[ord(p)-1] - p.prefix = '' + elif p in "\1\2\3\4\5": + p = args[ord(p) - 1] + p.prefix = "" else: p = Name(p) kids.append(p.clone()) return kids + def DualOp(template, first, second, kws): kids = fill_template(template, first, second) return Node(syms.test, kids, prefix=" ") def SequenceEqual(left, right, kws): - if 'seq_type' in kws: + if "seq_type" in kws: # :todo: implement `assert isinstance(xx, seq_type)` pass - return CompOp('==', left, right, kws) + return CompOp("==", left, right, kws) def AlmostOp(places_op, delta_op, first, second, kws): - first.prefix = "" + first.prefix = "" second.prefix = "" first = parenthesize_expression(first) second = parenthesize_expression(second) - abs_op = Call(Name('abs'), - [Node(syms.factor, [first, Name('-'), second])]) - if kws.get('delta', None) is not None: + abs_op = Call(Name("abs"), [Node(syms.factor, [first, Name("-"), second])]) + if kws.get("delta", None) is not None: # delta - return CompOp(delta_op, abs_op, kws['delta'], {}) + return CompOp(delta_op, abs_op, kws["delta"], {}) else: # `7` is the default in unittest.TestCase.assertAlmostEqual - places = kws['places'] or Number(7) + places = kws["places"] or Number(7) places.prefix = " " - round_op = Call(Name('round'), (abs_op, Comma(), places)) + round_op = Call(Name("round"), (abs_op, Comma(), places)) return CompOp(places_op, round_op, Number(0), {}) @@ -135,28 +148,27 @@ def RaisesOp(context, exceptionClass, indent, kws, arglist, node): exceptionClass.prefix = "" args = [exceptionClass] # Add match keyword arg to with statement if an expected regex was provided. - if 'expected_regex' in kws: - expected_regex = kws.get('expected_regex').clone() - expected_regex.prefix = '' - args.append(String(', ')) - args.append( - KeywordArg(Name('match'), expected_regex)) + if "expected_regex" in kws: + expected_regex = kws.get("expected_regex").clone() + expected_regex.prefix = "" + args.append(String(", ")) + args.append(KeywordArg(Name("match"), expected_regex)) with_item = Call(Name(context), args) with_item.prefix = " " args = [] arglist = [a.clone() for a in arglist.children[4:]] if arglist: - arglist[0].prefix="" + arglist[0].prefix = "" func = None # :fixme: this uses hardcoded parameter names, which may change - if 'callableObj' in kws: - func = kws['callableObj'] - elif 'callable_obj' in kws: - func = kws['callable_obj'] - elif kws['args']: # any arguments assigned to `*args` - func = kws['args'][0] + if "callableObj" in kws: + func = kws["callableObj"] + elif "callable_obj" in kws: + func = kws["callable_obj"] + elif kws["args"]: # any arguments assigned to `*args` + func = kws["args"][0] else: func = None @@ -173,18 +185,15 @@ def RaisesOp(context, exceptionClass, indent, kws, arglist, node): suite = Call(func, arglist) suite.prefix = indent + (4 * " ") - return Node(syms.with_stmt, - [Name('with'), - with_item, - Name(':'), - Newline(), - suite]) - -def RaisesRegexOp(context, designator, exceptionClass, expected_regex, - indent, kws, arglist, node): + return Node(syms.with_stmt, [Name("with"), with_item, Name(":"), Newline(), suite]) + + +def RaisesRegexOp( + context, designator, exceptionClass, expected_regex, indent, kws, arglist, node +): arglist = [a.clone() for a in arglist.children] pattern = arglist[2] - del arglist[2:4] # remove pattern and comma + del arglist[2:4] # remove pattern and comma arglist = Node(syms.arglist, arglist) with_stmt = RaisesOp(context, exceptionClass, indent, kws, arglist, node) @@ -200,11 +209,13 @@ def RaisesRegexOp(context, designator, exceptionClass, expected_regex, else: return Node(syms.suite, [with_stmt]) + def FailOp(indent, kws, arglist, node): new = node.clone() - new.set_child(0, Name('pytest')) + new.set_child(0, Name("pytest")) return new + def add_import(import_name, node): suite = get_parent_of_type(node, syms.suite) test_case = suite @@ -213,10 +224,13 @@ def add_import(import_name, node): file_input = test_case.parent if not does_tree_import(None, import_name, node): - import_stmt = Node(syms.simple_stmt, - [Node(syms.import_name, [Name('import'), Name(import_name, prefix=' ')]), - Newline(), - ]) + import_stmt = Node( + syms.simple_stmt, + [ + Node(syms.import_name, [Name("import"), Name(import_name, prefix=" ")]), + Newline(), + ], + ) insert_import(import_stmt, test_case, file_input) @@ -239,69 +253,61 @@ def insert_import(import_stmt, test_case, file_input): else: i = file_input.children.index(test_case) import_stmt.prefix = test_case.prefix - test_case.prefix = '' + test_case.prefix = "" file_input.insert_child(i, import_stmt) def get_import_nodes(node): return [ - x for c in node.children + x + for c in node.children for x in c.children - if c.type == syms.simple_stmt - and is_import(x) + if c.type == syms.simple_stmt and is_import(x) ] _method_map = { # simple ones - 'assertEqual': partial(CompOp, '=='), - 'assertNotEqual': partial(CompOp, '!='), - 'assertFalse': partial(UnaryOp, 'not', ''), - 'assertGreater': partial(CompOp, '>'), - 'assertGreaterEqual': partial(CompOp, '>='), - 'assertIn': partial(CompOp, 'in'), - 'assertIs': partial(CompOp, 'is'), - 'assertIsInstance': partial(DualOp, 'isinstance(\1, \2)'), - 'assertIsNone': partial(UnaryOp, '', 'is None'), - 'assertIsNot': partial(CompOp, 'is not'), - 'assertIsNotNone': partial(UnaryOp, '', 'is not None'), - 'assertLess': partial(CompOp, '<'), - 'assertLessEqual': partial(CompOp, '<='), - 'assertNotIn': partial(CompOp, 'not in'), - 'assertNotIsInstance': partial(DualOp, 'not isinstance(\1, \2)'), - 'assertTrue': partial(UnaryOp, '', ''), - + "assertEqual": partial(CompOp, "=="), + "assertNotEqual": partial(CompOp, "!="), + "assertFalse": partial(UnaryOp, "not", ""), + "assertGreater": partial(CompOp, ">"), + "assertGreaterEqual": partial(CompOp, ">="), + "assertIn": partial(CompOp, "in"), + "assertIs": partial(CompOp, "is"), + "assertIsInstance": partial(DualOp, "isinstance(\1, \2)"), + "assertIsNone": partial(UnaryOp, "", "is None"), + "assertIsNot": partial(CompOp, "is not"), + "assertIsNotNone": partial(UnaryOp, "", "is not None"), + "assertLess": partial(CompOp, "<"), + "assertLessEqual": partial(CompOp, "<="), + "assertNotIn": partial(CompOp, "not in"), + "assertNotIsInstance": partial(DualOp, "not isinstance(\1, \2)"), + "assertTrue": partial(UnaryOp, "", ""), # types ones - 'assertDictEqual': partial(CompOp, '=='), - 'assertListEqual': partial(CompOp, '=='), - 'assertMultiLineEqual': partial(CompOp, '=='), - 'assertSetEqual': partial(CompOp, '=='), - 'assertTupleEqual': partial(CompOp, '=='), - 'assertSequenceEqual': SequenceEqual, - - 'assertDictContainsSubset': partial(DualOp, '{**\2, **\1} == \2'), - 'assertItemsEqual': partial(DualOp, 'sorted(\1) == sorted(\2)'), - - 'assertAlmostEqual': partial(AlmostOp, "==", "<"), - 'assertNotAlmostEqual': partial(AlmostOp, "!=", ">"), - - 'assertRaises': partial(RaisesOp, 'pytest.raises'), - 'assertWarns': partial(RaisesOp, 'pytest.warns'), # new Py 3.2 - - 'assertRegex': partial(DualOp, 're.search(\2, \1)'), - 'assertNotRegex': partial(DualOp, 'not re.search(\2, \1)'), # new Py 3.2 - - 'assertRaisesRegex': partial(RaisesRegexOp, 'pytest.raises', 'excinfo'), - 'assertWarnsRegex': partial(RaisesRegexOp, 'pytest.warns', 'record'), - - 'fail': FailOp, - + "assertDictEqual": partial(CompOp, "=="), + "assertListEqual": partial(CompOp, "=="), + "assertMultiLineEqual": partial(CompOp, "=="), + "assertSetEqual": partial(CompOp, "=="), + "assertTupleEqual": partial(CompOp, "=="), + "assertSequenceEqual": SequenceEqual, + "assertDictContainsSubset": partial(DualOp, "{**\2, **\1} == \2"), + "assertItemsEqual": partial(DualOp, "sorted(\1) == sorted(\2)"), + "assertAlmostEqual": partial(AlmostOp, "==", "<"), + "assertNotAlmostEqual": partial(AlmostOp, "!=", ">"), + "assertRaises": partial(RaisesOp, "pytest.raises"), + "assertWarns": partial(RaisesOp, "pytest.warns"), # new Py 3.2 + "assertRegex": partial(DualOp, "re.search(\2, \1)"), + "assertNotRegex": partial(DualOp, "not re.search(\2, \1)"), # new Py 3.2 + "assertRaisesRegex": partial(RaisesRegexOp, "pytest.raises", "excinfo"), + "assertWarnsRegex": partial(RaisesRegexOp, "pytest.warns", "record"), + "fail": FailOp, #'assertLogs': -- not to be handled here, is an context handler only } for newname, oldname in ( - ('assertRaisesRegex', 'assertRaisesRegexp'), - ('assertRegex', 'assertRegexpMatches'), + ("assertRaisesRegex", "assertRaisesRegexp"), + ("assertRegex", "assertRegexpMatches"), ): if not hasattr(unittest.TestCase, newname): # use old name @@ -315,25 +321,24 @@ def get_import_nodes(node): # (Deprecated) Aliases _method_aliases = { - 'assertEquals' : 'assertEqual', - 'assertNotEquals' : 'assertNotEqual', - 'assert_' : 'assertTrue', - 'assertAlmostEquals' : 'assertAlmostEqual', - 'assertNotAlmostEquals': 'assertNotAlmostEqual', - 'assertRegexpMatches' : 'assertRegex', - 'assertRaisesRegexp' : 'assertRaisesRegex', - - 'failUnlessEqual' : 'assertEqual', - 'failIfEqual' : 'assertNotEqual', - 'failUnless' : 'assertTrue', - 'failIf' : 'assertFalse', - 'failUnlessRaises' : 'assertRaises', - 'failUnlessAlmostEqual': 'assertAlmostEqual', - 'failIfAlmostEqual' : 'assertNotAlmostEqual', + "assertEquals": "assertEqual", + "assertNotEquals": "assertNotEqual", + "assert_": "assertTrue", + "assertAlmostEquals": "assertAlmostEqual", + "assertNotAlmostEquals": "assertNotAlmostEqual", + "assertRegexpMatches": "assertRegex", + "assertRaisesRegexp": "assertRaisesRegex", + "failUnlessEqual": "assertEqual", + "failIfEqual": "assertNotEqual", + "failUnless": "assertTrue", + "failIf": "assertFalse", + "failUnlessRaises": "assertRaises", + "failUnlessAlmostEqual": "assertAlmostEqual", + "failIfAlmostEqual": "assertNotAlmostEqual", } for a, o in list(_method_aliases.items()): - if not o in _method_map: + if o not in _method_map: # if the original name is not a TestCase method, remove the alias del _method_aliases[a] @@ -378,49 +383,57 @@ def get_import_nodes(node): class FixSelfAssert(BaseFix): - PATTERN = """ power< 'self' trailer< '.' method=( %s ) > trailer< '(' [arglist=any] ')' > > - """ % ' | '.join(map(repr, - (set(_method_map.keys()) | set(_method_aliases.keys())))) + """ % " | ".join(map(repr, (set(_method_map.keys()) | set(_method_aliases.keys())))) def transform(self, node, results): def process_arg(arg): if isinstance(arg, Leaf) and arg.type == token.COMMA: return - elif (isinstance(arg, Node) and arg.type == syms.argument and - arg.children[1].type == token.EQUAL): + elif ( + isinstance(arg, Node) + and arg.type == syms.argument + and arg.children[1].type == token.EQUAL + ): # keyword argument name, equal, value = arg.children assert name.type == token.NAME assert equal.type == token.EQUAL value = value.clone() kwargs[name.value] = value - if '\n' in arg.prefix: + if "\n" in arg.prefix: value.prefix = arg.prefix else: value.prefix = arg.prefix.strip() + " " else: - if (isinstance(arg, Node) and arg.type == syms.argument and - arg.children[0].type == 36 and arg.children[0].value == '**'): + if ( + isinstance(arg, Node) + and arg.type == syms.argument + and arg.children[0].type == 36 + and arg.children[0].value == "**" + ): return - assert not kwargs, 'all positional args are assumed to come first' - if (isinstance(arg, Node) and arg.type == syms.argument and - arg.children[1].type == syms.comp_for): + assert not kwargs, "all positional args are assumed to come first" + if ( + isinstance(arg, Node) + and arg.type == syms.argument + and arg.children[1].type == syms.comp_for + ): # argument is a generator expression w/o # parenthesis, add parenthesis value = arg.clone() - value.children.insert(0, Leaf(token.LPAR, '(')) - value.children.append(Leaf(token.RPAR, ')')) + value.children.insert(0, Leaf(token.LPAR, "(")) + value.children.append(Leaf(token.RPAR, ")")) posargs.append(value) else: posargs.append(arg.clone()) - method = results['method'][0].value + method = results["method"][0].value # map (deprecated) aliases to original to avoid analysing # the decorator function method = _method_aliases.get(method, method) @@ -429,52 +442,57 @@ def process_arg(arg): kwargs = {} # This is either empty, an "arglist", or a single argument - if 'arglist' not in results: + if "arglist" not in results: pass - elif results['arglist'].type == syms.arglist: - for arg in results['arglist'].children: + elif results["arglist"].type == syms.arglist: + for arg in results["arglist"].children: process_arg(arg) else: - process_arg(results['arglist']) + process_arg(results["arglist"]) try: test_func = getattr(unittest.TestCase, method) except AttributeError: - raise RuntimeError("Your unittest package does not support '%s'. " - "consider updating the package" % method) + raise RuntimeError( + "Your unittest package does not support '%s'. " + "consider updating the package" % method + ) required_args, argsdict = utils.resolve_func_args(test_func, posargs, kwargs) - if method.startswith(('assertRaises', 'assertWarns')) or method == 'fail': - n_stmt = _method_map[method](*required_args, - indent=find_indentation(node), - kws=argsdict, - arglist=results.get('arglist'), - node=node) + if method.startswith(("assertRaises", "assertWarns")) or method == "fail": + n_stmt = _method_map[method]( + *required_args, + indent=find_indentation(node), + kws=argsdict, + arglist=results.get("arglist"), + node=node, + ) else: - n_stmt = Node(syms.assert_stmt, - [Name('assert'), - _method_map[method](*required_args, kws=argsdict)]) - if argsdict.get('msg', None) is not None and method != 'fail': - n_stmt.children.extend((Name(','), argsdict['msg'])) + n_stmt = Node( + syms.assert_stmt, + [Name("assert"), _method_map[method](*required_args, kws=argsdict)], + ) + if argsdict.get("msg", None) is not None and method != "fail": + n_stmt.children.extend((Name(","), argsdict["msg"])) def fix_line_wrapping(x): for c in x.children: # no need to worry about wrapping of "[", "{" and "(" if c.type in [token.LSQB, token.LBRACE, token.LPAR]: break - if c.prefix.startswith('\n'): - c.prefix = c.prefix.replace('\n', ' \\\n') + if c.prefix.startswith("\n"): + c.prefix = c.prefix.replace("\n", " \\\n") fix_line_wrapping(c) + fix_line_wrapping(n_stmt) # the prefix should be set only after fixing line wrapping because it can contain a '\n' n_stmt.prefix = node.prefix # add necessary imports - if 'Raises' in method or 'Warns' in method or method == 'fail': - add_import('pytest', node) - if ('Regex' in method and not 'Raises' in method and - not 'Warns' in method): - add_import('re', node) + if "Raises" in method or "Warns" in method or method == "fail": + add_import("pytest", node) + if "Regex" in method and "Raises" not in method and "Warns" not in method: + add_import("re", node) return n_stmt diff --git a/unittest2pytest/utils.py b/unittest2pytest/utils.py index c223830..753d58a 100644 --- a/unittest2pytest/utils.py +++ b/unittest2pytest/utils.py @@ -29,22 +29,28 @@ from inspect import Parameter -class SelfMarker: pass +class SelfMarker: + pass def resolve_func_args(test_func, posargs, kwargs): sig = inspect.signature(test_func) - assert (list(iter(sig.parameters))[0] == 'self') + assert list(iter(sig.parameters))[0] == "self" posargs.insert(0, SelfMarker) ba = sig.bind(*posargs, **kwargs) ba.apply_defaults() args = ba.arguments - required_args = [n for n,v in sig.parameters.items() - if (v.default is Parameter.empty and - v.kind not in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD))] - assert args['self'] == SelfMarker - assert required_args[0] == 'self' - del required_args[0], args['self'] + required_args = [ + n + for n, v in sig.parameters.items() + if ( + v.default is Parameter.empty + and v.kind not in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD) + ) + ] + assert args["self"] == SelfMarker + assert required_args[0] == "self" + del required_args[0], args["self"] required_args = [args[n] for n in required_args] return required_args, args