Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 108 additions & 42 deletions eval_protocol/cli_commands/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ def _is_eval_protocol_test(obj: Any) -> bool:
return False
# Must have pytest marks from evaluation_test
marks = getattr(obj, "pytestmark", [])
# Handle pytest proxy objects (APIRemovedInV1Proxy)
if not isinstance(marks, (list, tuple)):
try:
marks = list(marks) if marks else []
except (TypeError, AttributeError):
return False
return len(marks) > 0


Expand All @@ -91,6 +97,14 @@ def _extract_param_info_from_marks(obj: Any) -> tuple[bool, int, list[str]]:
(has_parametrize, param_count, param_ids)
"""
marks = getattr(obj, "pytestmark", [])

# Handle pytest proxy objects (APIRemovedInV1Proxy) - same as _is_eval_protocol_test
if not isinstance(marks, (list, tuple)):
try:
marks = list(marks) if marks else []
except (TypeError, AttributeError):
marks = []

has_parametrize = False
total_combinations = 0
all_param_ids: list[str] = []
Expand Down Expand Up @@ -131,51 +145,103 @@ def _discover_tests(root: str) -> list[DiscoveredTest]:

discovered: list[DiscoveredTest] = []

# Collect all test functions from Python files
for file_path in _iter_python_files(root):
class CollectionPlugin:
"""Plugin to capture collected items without running code."""

def __init__(self):
self.items = []

def pytest_ignore_collect(self, collection_path, config):
"""Ignore problematic files before pytest tries to import them."""
# Ignore specific files
ignored_files = ["setup.py", "versioneer.py", "conf.py", "__main__.py"]
if collection_path.name in ignored_files:
return True

# Ignore hidden files (starting with .)
if collection_path.name.startswith("."):
return True

# Ignore test_discovery files
if collection_path.name.startswith("test_discovery"):
return True

return None

def pytest_collection_modifyitems(self, items):
"""Hook called after collection is done."""
self.items = items

plugin = CollectionPlugin()

# Run pytest collection only (--collect-only prevents code execution)
# Override python_files to collect from ANY .py file
args = [
abs_root,
"--collect-only",
"-q",
"--pythonwarnings=ignore",
"-o",
"python_files=*.py", # Override to collect all .py files
]

try:
# Suppress pytest output
import io
import contextlib

with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
pytest.main(args, plugins=[plugin])
except Exception:
# If pytest collection fails, fall back to empty list
return []

# Process collected items
for item in plugin.items:
if not hasattr(item, "obj"):
continue

obj = item.obj
if not _is_eval_protocol_test(obj):
continue

origin = getattr(obj, "_origin_func", obj)
try:
unique_name = "ep_upload_" + re.sub(r"[^a-zA-Z0-9_]", "_", os.path.abspath(file_path))
spec = importlib.util.spec_from_file_location(unique_name, file_path)
if spec and spec.loader:
module = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = module
spec.loader.exec_module(module) # type: ignore[attr-defined]
else:
continue
src_file = inspect.getsourcefile(origin) or str(item.path)
_, lineno = inspect.getsourcelines(origin)
except Exception:
continue
src_file, lineno = str(item.path), None

for name, obj in inspect.getmembers(module):
if _is_eval_protocol_test(obj):
origin = getattr(obj, "_origin_func", obj)
try:
src_file = inspect.getsourcefile(origin) or file_path
_, lineno = inspect.getsourcelines(origin)
except Exception:
src_file, lineno = file_path, None

# Extract parametrization info from marks
has_parametrize, param_count, param_ids = _extract_param_info_from_marks(obj)

# Generate synthetic nodeids for display
base_nodeid = f"{os.path.basename(file_path)}::{name}"
if has_parametrize and param_ids:
nodeids = [f"{base_nodeid}[{pid}]" for pid in param_ids]
else:
nodeids = [base_nodeid]

discovered.append(
DiscoveredTest(
module_path=module.__name__,
module_name=module.__name__,
qualname=f"{module.__name__}.{name}",
file_path=os.path.abspath(src_file),
lineno=lineno,
has_parametrize=has_parametrize,
param_count=param_count,
nodeids=nodeids,
)
)
# Extract parametrization info from marks
has_parametrize, param_count, param_ids = _extract_param_info_from_marks(obj)

# Get module name and function name
module_name = (
item.module.__name__
if hasattr(item, "module")
else item.nodeid.split("::")[0].replace("/", ".").replace(".py", "")
)
func_name = item.name.split("[")[0] if "[" in item.name else item.name

# Generate nodeids
base_nodeid = f"{os.path.basename(src_file)}::{func_name}"
if param_ids:
nodeids = [f"{base_nodeid}[{pid}]" for pid in param_ids]
else:
nodeids = [base_nodeid]

discovered.append(
DiscoveredTest(
module_path=module_name,
module_name=module_name,
qualname=f"{module_name}.{func_name}",
file_path=os.path.abspath(src_file),
lineno=lineno,
has_parametrize=has_parametrize,
param_count=param_count,
nodeids=nodeids,
)
)

# Deduplicate by qualname (in case same test appears multiple times)
by_qual: dict[str, DiscoveredTest] = {}
Expand Down
Loading
Loading