diff --git a/.travis.yml b/.travis.yml index 959be50..0650f47 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,9 +9,9 @@ cache: pip install: - pip install -e .[test] script: -- py.test --cov=snapshottest tests examples/pytest +- py.test --cov=snapshottest --snapshot-strict tests examples/pytest # Run Pytest Example -- py.test examples/pytest +- py.test --snapshot-strict examples/pytest # Run Unittest Example - python examples/unittest/test_demo.py # Run nose diff --git a/examples/pytest/test_demo.py b/examples/pytest/test_demo.py index 1addd6c..1974585 100644 --- a/examples/pytest/test_demo.py +++ b/examples/pytest/test_demo.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +import pytest + from snapshottest.file import FileSnapshot @@ -58,3 +60,8 @@ def test_multiple_files(snapshot, tmpdir): temp_file1 = tmpdir.join('example2.txt') temp_file1.write('Hello, world 2!') snapshot.assert_match(FileSnapshot(str(temp_file1))) + + +def test_unused_snapshot_should_be_reject_in_strict_mode(snapshot): + with pytest.raises(AssertionError, match="Saving snapshots not allowed in strict mode."): + snapshot.assert_match(False) diff --git a/setup.cfg b/setup.cfg index 40bdda5..5a97e99 100644 --- a/setup.cfg +++ b/setup.cfg @@ -6,5 +6,5 @@ exclude = snapshots .tox max-line-length = 120 [tool:pytest] -addopts = --cov snapshottest +addopts = --cov snapshottest --snapshot-strict testpaths = tests diff --git a/snapshottest/module.py b/snapshottest/module.py index 494ae79..4ef965e 100644 --- a/snapshottest/module.py +++ b/snapshottest/module.py @@ -4,6 +4,7 @@ import imp from collections import defaultdict import logging +from typing import Set, Iterable from .snapshot import Snapshot from .formatter import Formatter @@ -48,7 +49,7 @@ def delete_unvisited(self): del self.snapshots[unvisited] @property - def unvisited_snapshots(self): + def unvisited_snapshots(self) -> Set[str]: return set(self.snapshots.keys()) - self.visited_snapshots @classmethod @@ -63,7 +64,7 @@ def total_unvisited_snapshots(cls): return unvisited_snapshots, unvisited_modules @classmethod - def get_modules(cls): + def get_modules(cls) -> Iterable["SnapshotModule"]: return SnapshotModule._snapshot_modules.values() @classmethod @@ -185,13 +186,22 @@ def get_module_for_testpath(cls, test_filepath): return cls._snapshot_modules[test_filepath] + @classmethod + def all_unvisited_snapshots(cls) -> Set[str]: + unvisited: Set[str] = set() + for module in cls.get_modules(): + unvisited |= {f"{module.filepath}::{snapshot}" for snapshot in module.unvisited_snapshots} + + return unvisited + class SnapshotTest(object): _current_tester = None - def __init__(self): + def __init__(self, strict=False): self.curr_snapshot = '' self.snapshot_counter = 1 + self.strict = strict @property def module(self): @@ -220,6 +230,7 @@ def fail(self): self.module.mark_failed(self.test_name) def store(self, data): + assert not self.strict, "Saving snapshots not allowed in strict mode." formatter = Formatter.get_formatter(data) data = formatter.store(self, data) self.module[self.test_name] = data diff --git a/snapshottest/pytest.py b/snapshottest/pytest.py index bc0fba3..fc6ff1f 100644 --- a/snapshottest/pytest.py +++ b/snapshottest/pytest.py @@ -21,13 +21,19 @@ def pytest_addoption(parser): default=False, help='Dump diagnostic and progress information.' ) + group.addoption( + '--snapshot-strict', + action='store_true', + default=False, + help='Fails test if new snapshot is created or some snapshot is unused.' + ) class PyTestSnapshotTest(SnapshotTest): - def __init__(self, request=None): + def __init__(self, request=None, strict=False): self.request = request - super(PyTestSnapshotTest, self).__init__() + super(PyTestSnapshotTest, self).__init__(strict=strict) @property def module(self): @@ -50,6 +56,7 @@ def test_name(self): class SnapshotSession(object): def __init__(self, config): self.verbose = config.getoption("snapshot_verbose") + self.strict = config.option.snapshot_strict self.config = config def display(self, tr): @@ -58,7 +65,7 @@ def display(self, tr): tr.write_sep("=", "SnapshotTest summary") - for line in reporting_lines('pytest'): + for line in reporting_lines('pytest', strict=self.strict): tr.write_line(line) @@ -69,7 +76,8 @@ def pytest_assertrepr_compare(op, left, right): @pytest.fixture def snapshot(request): - with PyTestSnapshotTest(request) as snapshot_test: + strict = request.config.option.snapshot_strict + with PyTestSnapshotTest(request, strict=strict) as snapshot_test: yield snapshot_test @@ -82,6 +90,20 @@ def pytest_terminal_summary(terminalreporter): terminalreporter.config._snapshotsession.display(terminalreporter) +@pytest.hookimpl(trylast=True) +def pytest_sessionfinish(session, exitstatus): + from _pytest.main import EXIT_NOTESTSCOLLECTED, EXIT_OK, EXIT_TESTSFAILED + + unvisited_snapshots = SnapshotModule.all_unvisited_snapshots() + if not session.config.option.snapshot_strict or not unvisited_snapshots: + return + + if exitstatus not in (EXIT_NOTESTSCOLLECTED, EXIT_OK, EXIT_TESTSFAILED): + return + + session.exitstatus = EXIT_TESTSFAILED + + @pytest.mark.trylast # force the other plugins to initialise, fixes issue with capture not being properly initialised def pytest_configure(config): config._snapshotsession = SnapshotSession(config) diff --git a/snapshottest/reporting.py b/snapshottest/reporting.py index c0c88f5..5fbe421 100644 --- a/snapshottest/reporting.py +++ b/snapshottest/reporting.py @@ -4,7 +4,7 @@ from .module import SnapshotModule -def reporting_lines(testing_cli): +def reporting_lines(testing_cli, strict=False): successful_snapshots = SnapshotModule.stats_successful_snapshots() bold = ['bold'] if successful_snapshots: @@ -26,12 +26,20 @@ def reporting_lines(testing_cli): colored('{} snapshots failed', 'red', attrs=bold) + ' in {} test suites. ' + inspect_str ).format(*failed_snapshots) - unvisited_snapshots = SnapshotModule.stats_unvisited_snapshots() - if unvisited_snapshots[0]: - yield ( - colored('{} snapshots deprecated', 'yellow', attrs=bold) + ' in {} test suites. ' - + inspect_str - ).format(*unvisited_snapshots) + unvisited_snapshots = SnapshotModule.all_unvisited_snapshots() + if unvisited_snapshots: + if strict: + error_msg = colored(f'ERROR: {len(unvisited_snapshots)} snapshots unused.', 'red', attrs=bold) + comment = colored("In strict mode all snapshots must be used.",attrs=['dark']) + yield f"{error_msg} {comment}" + yield colored("Unused snapshots:", "red", attrs=bold) + yield from unvisited_snapshots + + else: + yield ( + colored('{} snapshots deprecated', 'yellow', attrs=bold) + ' in {} test suites. ' + + inspect_str + ).format(*unvisited_snapshots) def diff_report(left, right): diff --git a/tests/test_module.py b/tests/test_module.py index cef2207..7460df8 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -3,7 +3,7 @@ import pytest from snapshottest import Snapshot -from snapshottest.module import SnapshotModule +from snapshottest.module import SnapshotModule, SnapshotTest class TestSnapshotModuleLoading(object): @@ -27,3 +27,29 @@ def test_load_corrupted_snapshot(self, tmpdir): module = SnapshotModule("tests.snapshots.snap_error", str(filepath)) with pytest.raises(SyntaxError): module.load_snapshots() + + + +class MockSnapshotTest(SnapshotTest): + def __init__(self, *, strict=False, module: SnapshotModule, test_name: str) -> None: + super().__init__(strict=strict) + self._module = module + self._test_name = test_name + + @property + def module(self) -> SnapshotModule: + return self._module + + @property + def test_name(self) -> str: + return self._test_name + + +class TestStrictSnapshotModule: + def test_should_fail_if_new_snapshots_appear(self, tmpdir): + file_with_snapshots = tmpdir.join("snap_new.py") + module = SnapshotModule(module="tests.snapshots.snap_new", filepath=str(file_with_snapshots)) + test = MockSnapshotTest(strict=True, module=module, test_name="Some test") + + with pytest.raises(AssertionError, match="Saving snapshots not allowed in strict mode."): + test.assert_match(False) diff --git a/tests/test_pytest_unused_snapshot.py b/tests/test_pytest_unused_snapshot.py new file mode 100644 index 0000000..45236b5 --- /dev/null +++ b/tests/test_pytest_unused_snapshot.py @@ -0,0 +1,14 @@ +import subprocess + + +def test_suite_should_fail_for_unused_snapshots(snapshot): + completed_process = subprocess.run( + ["pytest", "--snapshot-strict", "-v", "tests/unused_snapshot/command.py"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8" + ) + + assert completed_process.returncode == 1 + assert "test_should_not_use_any_snapshot PASSED" in completed_process.stdout + assert "ERROR: 1 snapshots unused" in completed_process.stdout diff --git a/tests/unused_snapshot/command.py b/tests/unused_snapshot/command.py new file mode 100644 index 0000000..9ab62ef --- /dev/null +++ b/tests/unused_snapshot/command.py @@ -0,0 +1,5 @@ +import pytest + + +def test_should_not_use_any_snapshot(snapshot): + snapshot.assert_match(True) diff --git a/tests/unused_snapshot/snapshots/__init__.py b/tests/unused_snapshot/snapshots/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unused_snapshot/snapshots/snap_command.py b/tests/unused_snapshot/snapshots/snap_command.py new file mode 100644 index 0000000..9d1999d --- /dev/null +++ b/tests/unused_snapshot/snapshots/snap_command.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +# snapshottest: v1 - https://goo.gl/zC4yUc +from __future__ import unicode_literals + +from snapshottest import GenericRepr, Snapshot +from snapshottest.file import FileSnapshot + + +snapshots = Snapshot() + +snapshots['not used'] = {} +snapshots['test_should_not_use_any_snapshot 1'] = True