diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..542a828
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,13 @@
+repos:
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.15.9
+ hooks:
+ - id: ruff-check
+ args: [--fix]
+ - id: ruff-format
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v6.0.0
+ hooks:
+ - id: trailing-whitespace
+ - id: end-of-file-fixer
+ - id: check-yaml
diff --git a/pyproject.toml b/pyproject.toml
index 21fd049..5b64fd6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -51,3 +51,15 @@ push-changes = false
tag-format = "v{version}"
tag-message = "unittest2pytest {version}"
tag-signing = true
+
+[tool.ruff]
+target-version = "py39"
+extend-exclude = ["tests/fixtures"]
+
+[tool.ruff.lint]
+select = ["E", "F", "I", "W", "UP"]
+ignore = ["E501", "E701", "E741", "UP031"]
+
+[tool.ruff.lint.per-file-ignores]
+"unittest2pytest/fixes/fix_remove_class.py" = ["W291"]
+"unittest2pytest/fixes/fix_self_assert.py" = ["F841"]
diff --git a/tests/test_all_fixes.py b/tests/test_all_fixes.py
index f288143..9926b83 100644
--- a/tests/test_all_fixes.py
+++ b/tests/test_all_fixes.py
@@ -21,45 +21,43 @@
# along with this program. If not, see .
#
-
__author__ = "Hartmut Goebel "
__copyright__ = "Copyright 2015-2019 by Hartmut Goebel"
__licence__ = "GNU General Public License version 3 or later (GPLv3+)"
-import pytest
-
-
+import glob
+import logging
import os
-from os.path import join, abspath
import re
-import glob
-import shutil
-from difflib import unified_diff
import unittest
-import logging
+from difflib import unified_diff
+from os.path import abspath, join
+import pytest
from fissix.main import main
# make logging less verbose
-logging.getLogger('fissix.main').setLevel(logging.WARN)
-logging.getLogger('RefactoringTool').setLevel(logging.WARN)
+logging.getLogger("fissix.main").setLevel(logging.WARN)
+logging.getLogger("RefactoringTool").setLevel(logging.WARN)
+
+FIXTURE_PATH = os.path.join(os.path.dirname(__file__), "fixtures")
-FIXTURE_PATH = os.path.join(os.path.dirname(__file__), 'fixtures')
def requiredTestMethod(name):
# skip if TestCase does not have this method
is_missing = getattr(unittest.TestCase, name, None) is None
- return pytest.mark.skipif(is_missing,
- reason="unittest does not have TestCase.%s " % name)
+ return pytest.mark.skipif(
+ is_missing, reason="unittest does not have TestCase.%s " % name
+ )
def _collect_in_files_from_directory(directory):
- fixture_files = glob.glob(abspath(join(directory, '*_in.py')))
+ fixture_files = glob.glob(abspath(join(directory, "*_in.py")))
for fixture_file in fixture_files:
with open(fixture_file) as fh:
text = fh.read(200)
- l = re.findall(r'^# required-method: (\S+)', text)
+ l = re.findall(r"^# required-method: (\S+)", text)
method = l[0] if l else None
yield fixture_file, method
@@ -70,7 +68,7 @@ def collect_all_test_fixtures():
# subdirectory, only run the fixer of the subdirectory name, else run
# all fixers.
for in_file, method in _collect_in_files_from_directory(root):
- fixer_to_run = root[len(FIXTURE_PATH)+1:] or None
+ fixer_to_run = root[len(FIXTURE_PATH) + 1 :] or None
marks = []
if method:
marks.append(requiredTestMethod(method))
@@ -82,20 +80,39 @@ def _get_id(argvalue):
return os.path.basename(argvalue).replace("_in.py", "")
-@pytest.mark.parametrize("fixer, in_file",
- collect_all_test_fixtures(), ids=_get_id)
+@pytest.mark.parametrize("fixer, in_file", collect_all_test_fixtures(), ids=_get_id)
def test_check_fixture(in_file, fixer, tmpdir):
if fixer:
- main("unittest2pytest.fixes",
- args=['--no-diffs', '--fix', fixer, '-w', in_file,
- '--nobackups', '--output-dir', str(tmpdir)])
+ main(
+ "unittest2pytest.fixes",
+ args=[
+ "--no-diffs",
+ "--fix",
+ fixer,
+ "-w",
+ in_file,
+ "--nobackups",
+ "--output-dir",
+ str(tmpdir),
+ ],
+ )
else:
- main("unittest2pytest.fixes",
- args=['--no-diffs', '--fix', 'all', '-w', in_file,
- '--nobackups', '--output-dir', str(tmpdir)])
+ main(
+ "unittest2pytest.fixes",
+ args=[
+ "--no-diffs",
+ "--fix",
+ "all",
+ "-w",
+ in_file,
+ "--nobackups",
+ "--output-dir",
+ str(tmpdir),
+ ],
+ )
result_file_name = tmpdir.join(os.path.basename(in_file))
- assert result_file_name.exists(), '%s is missing' % result_file_name
+ assert result_file_name.exists(), "%s is missing" % result_file_name
result_file_contents = result_file_name.readlines()
expected_file = in_file.replace("_in.py", "_out.py")
@@ -104,13 +121,15 @@ def test_check_fixture(in_file, fixer, tmpdir):
# ensure the expected code is actually correct and compiles
try:
- compile(''.join(expected_contents), expected_file, 'exec')
+ compile("".join(expected_contents), expected_file, "exec")
except Exception as e:
- pytest.fail(f"FATAL: {expected_file} does not compile: {e}",
- False)
+ pytest.fail(f"FATAL: {expected_file} does not compile: {e}", False)
if result_file_contents != expected_contents:
text = "Refactured code doesn't match expected outcome\n"
- text += ''.join(unified_diff(expected_contents, result_file_contents,
- 'expected', 'refactured result'))
+ text += "".join(
+ unified_diff(
+ expected_contents, result_file_contents, "expected", "refactured result"
+ )
+ )
pytest.fail(text, False)
diff --git a/unittest2pytest/__init__.py b/unittest2pytest/__init__.py
index c34a1a1..22e45fb 100644
--- a/unittest2pytest/__init__.py
+++ b/unittest2pytest/__init__.py
@@ -22,5 +22,5 @@
__licence__ = "GNU General Public License version 3 or later (GPLv3+)"
-__title__ = 'unittest2pytest'
-__version__ = '0.6.dev0'
+__title__ = "unittest2pytest"
+__version__ = "0.6.dev0"
diff --git a/unittest2pytest/__main__.py b/unittest2pytest/__main__.py
index 1feb059..513bcad 100755
--- a/unittest2pytest/__main__.py
+++ b/unittest2pytest/__main__.py
@@ -23,10 +23,13 @@
import fissix.main
+
from . import fixes
+
def main():
raise SystemExit(fissix.main.main(fixes.__name__))
-if __name__ == '__main__':
+
+if __name__ == "__main__":
main()
diff --git a/unittest2pytest/fixes/fix_remove_class.py b/unittest2pytest/fixes/fix_remove_class.py
index aee86b9..e51a2e4 100644
--- a/unittest2pytest/fixes/fix_remove_class.py
+++ b/unittest2pytest/fixes/fix_remove_class.py
@@ -27,7 +27,7 @@
from fissix.fixer_base import BaseFix
-from fissix.fixer_util import token, find_indentation
+from fissix.fixer_util import find_indentation, token
"""
Node(classdef,
@@ -47,8 +47,8 @@
Leaf(6, '')])])
"""
-class FixRemoveClass(BaseFix):
+class FixRemoveClass(BaseFix):
PATTERN = """
classdef< 'class' name=any '(' 'TestCase' ')' ':'
suite=suite
@@ -67,10 +67,9 @@ def dedent(self, suite, dedent):
# todo: handle tabs
if len(kid.prefix) > len(self.current_indent):
kid.prefix = self.current_indent
-
def transform(self, node, results):
- suite = results['suite'].clone()
+ suite = results["suite"].clone()
# todo: handle tabs
dedent = len(find_indentation(suite)) - len(find_indentation(node))
self.dedent(suite, dedent)
diff --git a/unittest2pytest/fixes/fix_self_assert.py b/unittest2pytest/fixes/fix_self_assert.py
index 18c4f9f..e4c9992 100644
--- a/unittest2pytest/fixes/fix_self_assert.py
+++ b/unittest2pytest/fixes/fix_self_assert.py
@@ -26,21 +26,33 @@
__licence__ = "GNU General Public License version 3 or later (GPLv3+)"
-from fissix.fixer_base import BaseFix
-from fissix.fixer_util import (
- Comma, Name, Call, Node, Leaf,
- Newline, KeywordArg, find_indentation,
- ArgList, String, Number, syms, token,
- does_tree_import, is_import, parenthesize)
-
-from functools import partial
import re
import unittest
+from functools import partial
+
+from fissix.fixer_base import BaseFix
+from fissix.fixer_util import (
+ Call,
+ Comma,
+ KeywordArg,
+ Leaf,
+ Name,
+ Newline,
+ Node,
+ Number,
+ String,
+ does_tree_import,
+ find_indentation,
+ is_import,
+ parenthesize,
+ syms,
+ token,
+)
from .. import utils
+TEMPLATE_PATTERN = re.compile("[\1\2]|[^\1\2]+")
-TEMPLATE_PATTERN = re.compile('[\1\2]|[^\1\2]+')
def CompOp(op, left, right, kws):
op = Name(op, prefix=" ")
@@ -48,7 +60,7 @@ def CompOp(op, left, right, kws):
right = parenthesize_expression(right)
left.prefix = ""
- if '\n' not in right.prefix:
+ if "\n" not in right.prefix:
right.prefix = " "
return Node(syms.comparison, (left, op, right), prefix=" ")
@@ -78,11 +90,12 @@ def UnaryOp(prefix, postfix, value, kws):
syms.comparison,
]
+
def parenthesize_expression(value):
if value.type in _NEEDS_PARENTHESIS:
parenthesized = parenthesize(value.clone())
parenthesized.prefix = parenthesized.children[1].prefix
- parenthesized.children[1].prefix = ''
+ parenthesized.children[1].prefix = ""
value = parenthesized
return value
@@ -91,43 +104,43 @@ def fill_template(template, *args):
parts = TEMPLATE_PATTERN.findall(template)
kids = []
for p in parts:
- if p == '':
+ if p == "":
continue
- elif p in '\1\2\3\4\5':
- p = args[ord(p)-1]
- p.prefix = ''
+ elif p in "\1\2\3\4\5":
+ p = args[ord(p) - 1]
+ p.prefix = ""
else:
p = Name(p)
kids.append(p.clone())
return kids
+
def DualOp(template, first, second, kws):
kids = fill_template(template, first, second)
return Node(syms.test, kids, prefix=" ")
def SequenceEqual(left, right, kws):
- if 'seq_type' in kws:
+ if "seq_type" in kws:
# :todo: implement `assert isinstance(xx, seq_type)`
pass
- return CompOp('==', left, right, kws)
+ return CompOp("==", left, right, kws)
def AlmostOp(places_op, delta_op, first, second, kws):
- first.prefix = ""
+ first.prefix = ""
second.prefix = ""
first = parenthesize_expression(first)
second = parenthesize_expression(second)
- abs_op = Call(Name('abs'),
- [Node(syms.factor, [first, Name('-'), second])])
- if kws.get('delta', None) is not None:
+ abs_op = Call(Name("abs"), [Node(syms.factor, [first, Name("-"), second])])
+ if kws.get("delta", None) is not None:
# delta
- return CompOp(delta_op, abs_op, kws['delta'], {})
+ return CompOp(delta_op, abs_op, kws["delta"], {})
else:
# `7` is the default in unittest.TestCase.assertAlmostEqual
- places = kws['places'] or Number(7)
+ places = kws["places"] or Number(7)
places.prefix = " "
- round_op = Call(Name('round'), (abs_op, Comma(), places))
+ round_op = Call(Name("round"), (abs_op, Comma(), places))
return CompOp(places_op, round_op, Number(0), {})
@@ -135,28 +148,27 @@ def RaisesOp(context, exceptionClass, indent, kws, arglist, node):
exceptionClass.prefix = ""
args = [exceptionClass]
# Add match keyword arg to with statement if an expected regex was provided.
- if 'expected_regex' in kws:
- expected_regex = kws.get('expected_regex').clone()
- expected_regex.prefix = ''
- args.append(String(', '))
- args.append(
- KeywordArg(Name('match'), expected_regex))
+ if "expected_regex" in kws:
+ expected_regex = kws.get("expected_regex").clone()
+ expected_regex.prefix = ""
+ args.append(String(", "))
+ args.append(KeywordArg(Name("match"), expected_regex))
with_item = Call(Name(context), args)
with_item.prefix = " "
args = []
arglist = [a.clone() for a in arglist.children[4:]]
if arglist:
- arglist[0].prefix=""
+ arglist[0].prefix = ""
func = None
# :fixme: this uses hardcoded parameter names, which may change
- if 'callableObj' in kws:
- func = kws['callableObj']
- elif 'callable_obj' in kws:
- func = kws['callable_obj']
- elif kws['args']: # any arguments assigned to `*args`
- func = kws['args'][0]
+ if "callableObj" in kws:
+ func = kws["callableObj"]
+ elif "callable_obj" in kws:
+ func = kws["callable_obj"]
+ elif kws["args"]: # any arguments assigned to `*args`
+ func = kws["args"][0]
else:
func = None
@@ -173,18 +185,15 @@ def RaisesOp(context, exceptionClass, indent, kws, arglist, node):
suite = Call(func, arglist)
suite.prefix = indent + (4 * " ")
- return Node(syms.with_stmt,
- [Name('with'),
- with_item,
- Name(':'),
- Newline(),
- suite])
-
-def RaisesRegexOp(context, designator, exceptionClass, expected_regex,
- indent, kws, arglist, node):
+ return Node(syms.with_stmt, [Name("with"), with_item, Name(":"), Newline(), suite])
+
+
+def RaisesRegexOp(
+ context, designator, exceptionClass, expected_regex, indent, kws, arglist, node
+):
arglist = [a.clone() for a in arglist.children]
pattern = arglist[2]
- del arglist[2:4] # remove pattern and comma
+ del arglist[2:4] # remove pattern and comma
arglist = Node(syms.arglist, arglist)
with_stmt = RaisesOp(context, exceptionClass, indent, kws, arglist, node)
@@ -200,11 +209,13 @@ def RaisesRegexOp(context, designator, exceptionClass, expected_regex,
else:
return Node(syms.suite, [with_stmt])
+
def FailOp(indent, kws, arglist, node):
new = node.clone()
- new.set_child(0, Name('pytest'))
+ new.set_child(0, Name("pytest"))
return new
+
def add_import(import_name, node):
suite = get_parent_of_type(node, syms.suite)
test_case = suite
@@ -213,10 +224,13 @@ def add_import(import_name, node):
file_input = test_case.parent
if not does_tree_import(None, import_name, node):
- import_stmt = Node(syms.simple_stmt,
- [Node(syms.import_name, [Name('import'), Name(import_name, prefix=' ')]),
- Newline(),
- ])
+ import_stmt = Node(
+ syms.simple_stmt,
+ [
+ Node(syms.import_name, [Name("import"), Name(import_name, prefix=" ")]),
+ Newline(),
+ ],
+ )
insert_import(import_stmt, test_case, file_input)
@@ -239,69 +253,61 @@ def insert_import(import_stmt, test_case, file_input):
else:
i = file_input.children.index(test_case)
import_stmt.prefix = test_case.prefix
- test_case.prefix = ''
+ test_case.prefix = ""
file_input.insert_child(i, import_stmt)
def get_import_nodes(node):
return [
- x for c in node.children
+ x
+ for c in node.children
for x in c.children
- if c.type == syms.simple_stmt
- and is_import(x)
+ if c.type == syms.simple_stmt and is_import(x)
]
_method_map = {
# simple ones
- 'assertEqual': partial(CompOp, '=='),
- 'assertNotEqual': partial(CompOp, '!='),
- 'assertFalse': partial(UnaryOp, 'not', ''),
- 'assertGreater': partial(CompOp, '>'),
- 'assertGreaterEqual': partial(CompOp, '>='),
- 'assertIn': partial(CompOp, 'in'),
- 'assertIs': partial(CompOp, 'is'),
- 'assertIsInstance': partial(DualOp, 'isinstance(\1, \2)'),
- 'assertIsNone': partial(UnaryOp, '', 'is None'),
- 'assertIsNot': partial(CompOp, 'is not'),
- 'assertIsNotNone': partial(UnaryOp, '', 'is not None'),
- 'assertLess': partial(CompOp, '<'),
- 'assertLessEqual': partial(CompOp, '<='),
- 'assertNotIn': partial(CompOp, 'not in'),
- 'assertNotIsInstance': partial(DualOp, 'not isinstance(\1, \2)'),
- 'assertTrue': partial(UnaryOp, '', ''),
-
+ "assertEqual": partial(CompOp, "=="),
+ "assertNotEqual": partial(CompOp, "!="),
+ "assertFalse": partial(UnaryOp, "not", ""),
+ "assertGreater": partial(CompOp, ">"),
+ "assertGreaterEqual": partial(CompOp, ">="),
+ "assertIn": partial(CompOp, "in"),
+ "assertIs": partial(CompOp, "is"),
+ "assertIsInstance": partial(DualOp, "isinstance(\1, \2)"),
+ "assertIsNone": partial(UnaryOp, "", "is None"),
+ "assertIsNot": partial(CompOp, "is not"),
+ "assertIsNotNone": partial(UnaryOp, "", "is not None"),
+ "assertLess": partial(CompOp, "<"),
+ "assertLessEqual": partial(CompOp, "<="),
+ "assertNotIn": partial(CompOp, "not in"),
+ "assertNotIsInstance": partial(DualOp, "not isinstance(\1, \2)"),
+ "assertTrue": partial(UnaryOp, "", ""),
# types ones
- 'assertDictEqual': partial(CompOp, '=='),
- 'assertListEqual': partial(CompOp, '=='),
- 'assertMultiLineEqual': partial(CompOp, '=='),
- 'assertSetEqual': partial(CompOp, '=='),
- 'assertTupleEqual': partial(CompOp, '=='),
- 'assertSequenceEqual': SequenceEqual,
-
- 'assertDictContainsSubset': partial(DualOp, '{**\2, **\1} == \2'),
- 'assertItemsEqual': partial(DualOp, 'sorted(\1) == sorted(\2)'),
-
- 'assertAlmostEqual': partial(AlmostOp, "==", "<"),
- 'assertNotAlmostEqual': partial(AlmostOp, "!=", ">"),
-
- 'assertRaises': partial(RaisesOp, 'pytest.raises'),
- 'assertWarns': partial(RaisesOp, 'pytest.warns'), # new Py 3.2
-
- 'assertRegex': partial(DualOp, 're.search(\2, \1)'),
- 'assertNotRegex': partial(DualOp, 'not re.search(\2, \1)'), # new Py 3.2
-
- 'assertRaisesRegex': partial(RaisesRegexOp, 'pytest.raises', 'excinfo'),
- 'assertWarnsRegex': partial(RaisesRegexOp, 'pytest.warns', 'record'),
-
- 'fail': FailOp,
-
+ "assertDictEqual": partial(CompOp, "=="),
+ "assertListEqual": partial(CompOp, "=="),
+ "assertMultiLineEqual": partial(CompOp, "=="),
+ "assertSetEqual": partial(CompOp, "=="),
+ "assertTupleEqual": partial(CompOp, "=="),
+ "assertSequenceEqual": SequenceEqual,
+ "assertDictContainsSubset": partial(DualOp, "{**\2, **\1} == \2"),
+ "assertItemsEqual": partial(DualOp, "sorted(\1) == sorted(\2)"),
+ "assertAlmostEqual": partial(AlmostOp, "==", "<"),
+ "assertNotAlmostEqual": partial(AlmostOp, "!=", ">"),
+ "assertRaises": partial(RaisesOp, "pytest.raises"),
+ "assertWarns": partial(RaisesOp, "pytest.warns"), # new Py 3.2
+ "assertRegex": partial(DualOp, "re.search(\2, \1)"),
+ "assertNotRegex": partial(DualOp, "not re.search(\2, \1)"), # new Py 3.2
+ "assertRaisesRegex": partial(RaisesRegexOp, "pytest.raises", "excinfo"),
+ "assertWarnsRegex": partial(RaisesRegexOp, "pytest.warns", "record"),
+ "fail": FailOp,
#'assertLogs': -- not to be handled here, is an context handler only
}
for newname, oldname in (
- ('assertRaisesRegex', 'assertRaisesRegexp'),
- ('assertRegex', 'assertRegexpMatches'),
+ ("assertRaisesRegex", "assertRaisesRegexp"),
+ ("assertRegex", "assertRegexpMatches"),
):
if not hasattr(unittest.TestCase, newname):
# use old name
@@ -315,25 +321,24 @@ def get_import_nodes(node):
# (Deprecated) Aliases
_method_aliases = {
- 'assertEquals' : 'assertEqual',
- 'assertNotEquals' : 'assertNotEqual',
- 'assert_' : 'assertTrue',
- 'assertAlmostEquals' : 'assertAlmostEqual',
- 'assertNotAlmostEquals': 'assertNotAlmostEqual',
- 'assertRegexpMatches' : 'assertRegex',
- 'assertRaisesRegexp' : 'assertRaisesRegex',
-
- 'failUnlessEqual' : 'assertEqual',
- 'failIfEqual' : 'assertNotEqual',
- 'failUnless' : 'assertTrue',
- 'failIf' : 'assertFalse',
- 'failUnlessRaises' : 'assertRaises',
- 'failUnlessAlmostEqual': 'assertAlmostEqual',
- 'failIfAlmostEqual' : 'assertNotAlmostEqual',
+ "assertEquals": "assertEqual",
+ "assertNotEquals": "assertNotEqual",
+ "assert_": "assertTrue",
+ "assertAlmostEquals": "assertAlmostEqual",
+ "assertNotAlmostEquals": "assertNotAlmostEqual",
+ "assertRegexpMatches": "assertRegex",
+ "assertRaisesRegexp": "assertRaisesRegex",
+ "failUnlessEqual": "assertEqual",
+ "failIfEqual": "assertNotEqual",
+ "failUnless": "assertTrue",
+ "failIf": "assertFalse",
+ "failUnlessRaises": "assertRaises",
+ "failUnlessAlmostEqual": "assertAlmostEqual",
+ "failIfAlmostEqual": "assertNotAlmostEqual",
}
for a, o in list(_method_aliases.items()):
- if not o in _method_map:
+ if o not in _method_map:
# if the original name is not a TestCase method, remove the alias
del _method_aliases[a]
@@ -378,49 +383,57 @@ def get_import_nodes(node):
class FixSelfAssert(BaseFix):
-
PATTERN = """
power< 'self'
trailer< '.' method=( %s ) >
trailer< '(' [arglist=any] ')' >
>
- """ % ' | '.join(map(repr,
- (set(_method_map.keys()) | set(_method_aliases.keys()))))
+ """ % " | ".join(map(repr, (set(_method_map.keys()) | set(_method_aliases.keys()))))
def transform(self, node, results):
def process_arg(arg):
if isinstance(arg, Leaf) and arg.type == token.COMMA:
return
- elif (isinstance(arg, Node) and arg.type == syms.argument and
- arg.children[1].type == token.EQUAL):
+ elif (
+ isinstance(arg, Node)
+ and arg.type == syms.argument
+ and arg.children[1].type == token.EQUAL
+ ):
# keyword argument
name, equal, value = arg.children
assert name.type == token.NAME
assert equal.type == token.EQUAL
value = value.clone()
kwargs[name.value] = value
- if '\n' in arg.prefix:
+ if "\n" in arg.prefix:
value.prefix = arg.prefix
else:
value.prefix = arg.prefix.strip() + " "
else:
- if (isinstance(arg, Node) and arg.type == syms.argument and
- arg.children[0].type == 36 and arg.children[0].value == '**'):
+ if (
+ isinstance(arg, Node)
+ and arg.type == syms.argument
+ and arg.children[0].type == 36
+ and arg.children[0].value == "**"
+ ):
return
- assert not kwargs, 'all positional args are assumed to come first'
- if (isinstance(arg, Node) and arg.type == syms.argument and
- arg.children[1].type == syms.comp_for):
+ assert not kwargs, "all positional args are assumed to come first"
+ if (
+ isinstance(arg, Node)
+ and arg.type == syms.argument
+ and arg.children[1].type == syms.comp_for
+ ):
# argument is a generator expression w/o
# parenthesis, add parenthesis
value = arg.clone()
- value.children.insert(0, Leaf(token.LPAR, '('))
- value.children.append(Leaf(token.RPAR, ')'))
+ value.children.insert(0, Leaf(token.LPAR, "("))
+ value.children.append(Leaf(token.RPAR, ")"))
posargs.append(value)
else:
posargs.append(arg.clone())
- method = results['method'][0].value
+ method = results["method"][0].value
# map (deprecated) aliases to original to avoid analysing
# the decorator function
method = _method_aliases.get(method, method)
@@ -429,52 +442,57 @@ def process_arg(arg):
kwargs = {}
# This is either empty, an "arglist", or a single argument
- if 'arglist' not in results:
+ if "arglist" not in results:
pass
- elif results['arglist'].type == syms.arglist:
- for arg in results['arglist'].children:
+ elif results["arglist"].type == syms.arglist:
+ for arg in results["arglist"].children:
process_arg(arg)
else:
- process_arg(results['arglist'])
+ process_arg(results["arglist"])
try:
test_func = getattr(unittest.TestCase, method)
except AttributeError:
- raise RuntimeError("Your unittest package does not support '%s'. "
- "consider updating the package" % method)
+ raise RuntimeError(
+ "Your unittest package does not support '%s'. "
+ "consider updating the package" % method
+ )
required_args, argsdict = utils.resolve_func_args(test_func, posargs, kwargs)
- if method.startswith(('assertRaises', 'assertWarns')) or method == 'fail':
- n_stmt = _method_map[method](*required_args,
- indent=find_indentation(node),
- kws=argsdict,
- arglist=results.get('arglist'),
- node=node)
+ if method.startswith(("assertRaises", "assertWarns")) or method == "fail":
+ n_stmt = _method_map[method](
+ *required_args,
+ indent=find_indentation(node),
+ kws=argsdict,
+ arglist=results.get("arglist"),
+ node=node,
+ )
else:
- n_stmt = Node(syms.assert_stmt,
- [Name('assert'),
- _method_map[method](*required_args, kws=argsdict)])
- if argsdict.get('msg', None) is not None and method != 'fail':
- n_stmt.children.extend((Name(','), argsdict['msg']))
+ n_stmt = Node(
+ syms.assert_stmt,
+ [Name("assert"), _method_map[method](*required_args, kws=argsdict)],
+ )
+ if argsdict.get("msg", None) is not None and method != "fail":
+ n_stmt.children.extend((Name(","), argsdict["msg"]))
def fix_line_wrapping(x):
for c in x.children:
# no need to worry about wrapping of "[", "{" and "("
if c.type in [token.LSQB, token.LBRACE, token.LPAR]:
break
- if c.prefix.startswith('\n'):
- c.prefix = c.prefix.replace('\n', ' \\\n')
+ if c.prefix.startswith("\n"):
+ c.prefix = c.prefix.replace("\n", " \\\n")
fix_line_wrapping(c)
+
fix_line_wrapping(n_stmt)
# the prefix should be set only after fixing line wrapping because it can contain a '\n'
n_stmt.prefix = node.prefix
# add necessary imports
- if 'Raises' in method or 'Warns' in method or method == 'fail':
- add_import('pytest', node)
- if ('Regex' in method and not 'Raises' in method and
- not 'Warns' in method):
- add_import('re', node)
+ if "Raises" in method or "Warns" in method or method == "fail":
+ add_import("pytest", node)
+ if "Regex" in method and "Raises" not in method and "Warns" not in method:
+ add_import("re", node)
return n_stmt
diff --git a/unittest2pytest/utils.py b/unittest2pytest/utils.py
index c223830..753d58a 100644
--- a/unittest2pytest/utils.py
+++ b/unittest2pytest/utils.py
@@ -29,22 +29,28 @@
from inspect import Parameter
-class SelfMarker: pass
+class SelfMarker:
+ pass
def resolve_func_args(test_func, posargs, kwargs):
sig = inspect.signature(test_func)
- assert (list(iter(sig.parameters))[0] == 'self')
+ assert list(iter(sig.parameters))[0] == "self"
posargs.insert(0, SelfMarker)
ba = sig.bind(*posargs, **kwargs)
ba.apply_defaults()
args = ba.arguments
- required_args = [n for n,v in sig.parameters.items()
- if (v.default is Parameter.empty and
- v.kind not in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD))]
- assert args['self'] == SelfMarker
- assert required_args[0] == 'self'
- del required_args[0], args['self']
+ required_args = [
+ n
+ for n, v in sig.parameters.items()
+ if (
+ v.default is Parameter.empty
+ and v.kind not in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)
+ )
+ ]
+ assert args["self"] == SelfMarker
+ assert required_args[0] == "self"
+ del required_args[0], args["self"]
required_args = [args[n] for n in required_args]
return required_args, args