Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Tests for EngineX consume argument processing."""

import pytest

from execution_testing.cli.pytest_commands.processors import (
HiveEnvironmentProcessor,
)


@pytest.mark.parametrize(
"parallelism_args",
[
["-n", "6"],
["-n=6"],
["-n6"],
["--numprocesses", "6"],
["--numprocesses=6"],
],
)
def test_enginex_parallelism_uses_loadgroup(
monkeypatch: pytest.MonkeyPatch, parallelism_args: list[str]
) -> None:
"""EngineX must use xdist loadgroup for every supported -n spelling."""
monkeypatch.delenv("HIVE_PARALLELISM", raising=False)

args = HiveEnvironmentProcessor("enginex").process_args(
[*parallelism_args]
)

assert "--dist" in args
assert args[args.index("--dist") + 1] == "loadgroup"


@pytest.mark.parametrize(
"dist_args",
[
["--dist", "load"],
["--dist=load"],
],
)
def test_enginex_parallelism_overrides_non_loadgroup_dist(
monkeypatch: pytest.MonkeyPatch, dist_args: list[str]
) -> None:
"""EngineX overrides incompatible xdist distribution modes."""
monkeypatch.delenv("HIVE_PARALLELISM", raising=False)

with pytest.warns(UserWarning, match="requires `--dist=loadgroup`"):
args = HiveEnvironmentProcessor("enginex").process_args(
["-n=6", *dist_args]
)

assert "--dist=load" not in args
if "--dist" in args:
assert args[args.index("--dist") + 1] == "loadgroup"
else:
assert "--dist=loadgroup" in args


def test_consume_engine_parallelism_does_not_force_loadgroup(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""The loadgroup override is scoped to consume enginex."""
monkeypatch.delenv("HIVE_PARALLELISM", raising=False)

args = HiveEnvironmentProcessor("engine").process_args(["-n=6"])

assert "--dist" not in args
assert "--dist=loadgroup" not in args
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,7 @@ def process_args(self, args: List[str]) -> List[str]:
if self.command_name == "enginex" and self._has_parallelism_flag(
modified_args
):
if "--dist" not in modified_args:
modified_args.extend(["--dist", "loadgroup"])
modified_args = self._ensure_loadgroup_dist(modified_args)

if os.getenv("HIVE_RANDOM_SEED") is not None:
warnings.warn(
Expand Down Expand Up @@ -160,7 +159,58 @@ def _has_regex_or_sim_limit(self, args: List[str]) -> bool:

def _has_parallelism_flag(self, args: List[str]) -> bool:
"""Check if args already contain parallelism flag."""
return "-n" in args
return any(
arg == "-n"
or arg.startswith("-n=")
or (arg.startswith("-n") and len(arg) > 2)
or arg == "--numprocesses"
or arg.startswith("--numprocesses=")
for arg in args
)

def _ensure_loadgroup_dist(self, args: List[str]) -> List[str]:
"""
Ensure EngineX xdist runs keep pre-alloc groups on one worker.

EngineX client cleanup depends on each worker seeing every test in a
group. Any xdist distribution mode other than loadgroup can split a
pre-alloc group across workers, causing each worker to start its own
group client and defer cleanup until session teardown.
"""
modified_args = args[:]
found_dist = False
changed_dist = False
index = 0

while index < len(modified_args):
arg = modified_args[index]
if arg == "--dist":
found_dist = True
if index + 1 < len(modified_args):
if modified_args[index + 1] != "loadgroup":
modified_args[index + 1] = "loadgroup"
changed_dist = True
index += 2
continue
modified_args.append("loadgroup")
changed_dist = True
elif arg.startswith("--dist="):
found_dist = True
if arg != "--dist=loadgroup":
modified_args[index] = "--dist=loadgroup"
changed_dist = True
index += 1

if not found_dist:
modified_args.extend(["--dist", "loadgroup"])
elif changed_dist:
warnings.warn(
"`consume enginex` requires `--dist=loadgroup`; overriding "
"the provided xdist distribution mode.",
stacklevel=2,
)

return modified_args


class WatchFlagsProcessor(ArgumentProcessor):
Expand Down
Loading