diff --git a/changelog/13963.bugfix.rst b/changelog/13963.bugfix.rst new file mode 100644 index 00000000000..a5f7ebe5c03 --- /dev/null +++ b/changelog/13963.bugfix.rst @@ -0,0 +1,3 @@ +Fixed subtests running with `pytest-xdist `__ when their contexts contain objects that are not JSON-serializable. + +Fixes `pytest-dev/pytest-xdist#1273 `__. diff --git a/src/_pytest/subtests.py b/src/_pytest/subtests.py index e0ceb27f4b1..04ac36bdf23 100644 --- a/src/_pytest/subtests.py +++ b/src/_pytest/subtests.py @@ -11,6 +11,7 @@ from contextlib import ExitStack from contextlib import nullcontext import dataclasses +import pickle import time from types import TracebackType from typing import Any @@ -62,11 +63,16 @@ class SubtestContext: kwargs: Mapping[str, Any] def _to_json(self) -> dict[str, Any]: - return dataclasses.asdict(self) + result = dataclasses.asdict(self) + # Use protocol 0 because it is human-readable and guaranteed to be not-binary. + protocol = 0 + data = pickle.dumps(result["kwargs"], protocol=protocol) + result["kwargs"] = data.decode("UTF-8") + return result @classmethod def _from_json(cls, d: dict[str, Any]) -> Self: - return cls(msg=d["msg"], kwargs=d["kwargs"]) + return cls(msg=d["msg"], kwargs=pickle.loads(d["kwargs"].encode("UTF-8"))) @dataclasses.dataclass(init=False) diff --git a/testing/test_subtests.py b/testing/test_subtests.py index 6849df53622..67d567afd1c 100644 --- a/testing/test_subtests.py +++ b/testing/test_subtests.py @@ -1,5 +1,7 @@ from __future__ import annotations +from enum import Enum +import json import sys from typing import Literal @@ -957,7 +959,14 @@ def test(subtests): ) +class MyEnum(Enum): + """Used in test_serialization, needs to be declared at the module level to be pickled.""" + + A = "A" + + def test_serialization() -> None: + """Ensure subtest's kwargs are serialized using `saferepr` (pytest-dev/pytest-xdist#1273).""" from _pytest.subtests import pytest_report_from_serializable from _pytest.subtests import pytest_report_to_serializable @@ -968,10 +977,41 @@ def test_serialization() -> None: outcome="passed", when="call", longrepr=None, - context=SubtestContext(msg="custom message", kwargs=dict(i=10)), + context=SubtestContext(msg="custom message", kwargs=dict(i=10, a=MyEnum.A)), ) data = pytest_report_to_serializable(report) assert data is not None + # Ensure the report is actually serializable to JSON. + _ = json.dumps(data) new_report = pytest_report_from_serializable(data) assert new_report is not None - assert new_report.context == SubtestContext(msg="custom message", kwargs=dict(i=10)) + assert new_report.context == SubtestContext( + msg="custom message", kwargs=dict(i=10, a=MyEnum.A) + ) + + +def test_serialization_xdist(pytester: pytest.Pytester) -> None: # pragma: no cover + """Regression test for pytest-dev/pytest-xdist#1273.""" + pytest.importorskip("xdist") + pytester.makepyfile( + """ + from enum import Enum + import unittest + + class MyEnum(Enum): + A = "A" + + def test(subtests): + with subtests.test(a=MyEnum.A): + pass + + class T(unittest.TestCase): + + def test(self): + with self.subTest(a=MyEnum.A): + pass + """ + ) + pytester.syspathinsert() + result = pytester.runpytest("-n1", "-pxdist.plugin") + result.assert_outcomes(passed=2)