Skip to content

Commit 8382181

Browse files
author
Shrey Modi
committed
zip support, change upload flow
1 parent 69dbd1b commit 8382181

File tree

2 files changed

+326
-48
lines changed

2 files changed

+326
-48
lines changed

eval_protocol/cli_commands/upload.py

Lines changed: 100 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ def _is_eval_protocol_test(obj: Any) -> bool:
8181
return False
8282
# Must have pytest marks from evaluation_test
8383
marks = getattr(obj, "pytestmark", [])
84+
# Handle pytest proxy objects (APIRemovedInV1Proxy)
85+
if not isinstance(marks, (list, tuple)):
86+
try:
87+
marks = list(marks) if marks else []
88+
except (TypeError, AttributeError):
89+
return False
8490
return len(marks) > 0
8591

8692

@@ -131,51 +137,103 @@ def _discover_tests(root: str) -> list[DiscoveredTest]:
131137

132138
discovered: list[DiscoveredTest] = []
133139

134-
# Collect all test functions from Python files
135-
for file_path in _iter_python_files(root):
140+
class CollectionPlugin:
141+
"""Plugin to capture collected items without running code."""
142+
143+
def __init__(self):
144+
self.items = []
145+
146+
def pytest_ignore_collect(self, collection_path, config):
147+
"""Ignore problematic files before pytest tries to import them."""
148+
# Ignore specific files
149+
ignored_files = ["setup.py", "versioneer.py", "conf.py", "__main__.py"]
150+
if collection_path.name in ignored_files:
151+
return True
152+
153+
# Ignore hidden files (starting with .)
154+
if collection_path.name.startswith("."):
155+
return True
156+
157+
# Ignore test_discovery files
158+
if collection_path.name.startswith("test_discovery"):
159+
return True
160+
161+
return None
162+
163+
def pytest_collection_modifyitems(self, items):
164+
"""Hook called after collection is done."""
165+
self.items = items
166+
167+
plugin = CollectionPlugin()
168+
169+
# Run pytest collection only (--collect-only prevents code execution)
170+
# Override python_files to collect from ANY .py file
171+
args = [
172+
abs_root,
173+
"--collect-only",
174+
"-q",
175+
"--pythonwarnings=ignore",
176+
"-o",
177+
"python_files=*.py", # Override to collect all .py files
178+
]
179+
180+
try:
181+
# Suppress pytest output
182+
import io
183+
import contextlib
184+
185+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
186+
pytest.main(args, plugins=[plugin])
187+
except Exception:
188+
# If pytest collection fails, fall back to empty list
189+
return []
190+
191+
# Process collected items
192+
for item in plugin.items:
193+
if not hasattr(item, "obj"):
194+
continue
195+
196+
obj = item.obj
197+
if not _is_eval_protocol_test(obj):
198+
continue
199+
200+
origin = getattr(obj, "_origin_func", obj)
136201
try:
137-
unique_name = "ep_upload_" + re.sub(r"[^a-zA-Z0-9_]", "_", os.path.abspath(file_path))
138-
spec = importlib.util.spec_from_file_location(unique_name, file_path)
139-
if spec and spec.loader:
140-
module = importlib.util.module_from_spec(spec)
141-
sys.modules[spec.name] = module
142-
spec.loader.exec_module(module) # type: ignore[attr-defined]
143-
else:
144-
continue
202+
src_file = inspect.getsourcefile(origin) or str(item.path)
203+
_, lineno = inspect.getsourcelines(origin)
145204
except Exception:
146-
continue
205+
src_file, lineno = str(item.path), None
147206

148-
for name, obj in inspect.getmembers(module):
149-
if _is_eval_protocol_test(obj):
150-
origin = getattr(obj, "_origin_func", obj)
151-
try:
152-
src_file = inspect.getsourcefile(origin) or file_path
153-
_, lineno = inspect.getsourcelines(origin)
154-
except Exception:
155-
src_file, lineno = file_path, None
156-
157-
# Extract parametrization info from marks
158-
has_parametrize, param_count, param_ids = _extract_param_info_from_marks(obj)
159-
160-
# Generate synthetic nodeids for display
161-
base_nodeid = f"{os.path.basename(file_path)}::{name}"
162-
if has_parametrize and param_ids:
163-
nodeids = [f"{base_nodeid}[{pid}]" for pid in param_ids]
164-
else:
165-
nodeids = [base_nodeid]
166-
167-
discovered.append(
168-
DiscoveredTest(
169-
module_path=module.__name__,
170-
module_name=module.__name__,
171-
qualname=f"{module.__name__}.{name}",
172-
file_path=os.path.abspath(src_file),
173-
lineno=lineno,
174-
has_parametrize=has_parametrize,
175-
param_count=param_count,
176-
nodeids=nodeids,
177-
)
178-
)
207+
# Extract parametrization info from marks
208+
has_parametrize, param_count, param_ids = _extract_param_info_from_marks(obj)
209+
210+
# Get module name and function name
211+
module_name = (
212+
item.module.__name__
213+
if hasattr(item, "module")
214+
else item.nodeid.split("::")[0].replace("/", ".").replace(".py", "")
215+
)
216+
func_name = item.name.split("[")[0] if "[" in item.name else item.name
217+
218+
# Generate nodeids
219+
base_nodeid = f"{os.path.basename(src_file)}::{func_name}"
220+
if param_ids:
221+
nodeids = [f"{base_nodeid}[{pid}]" for pid in param_ids]
222+
else:
223+
nodeids = [base_nodeid]
224+
225+
discovered.append(
226+
DiscoveredTest(
227+
module_path=module_name,
228+
module_name=module_name,
229+
qualname=f"{module_name}.{func_name}",
230+
file_path=os.path.abspath(src_file),
231+
lineno=lineno,
232+
has_parametrize=has_parametrize,
233+
param_count=param_count,
234+
nodeids=nodeids,
235+
)
236+
)
179237

180238
# Deduplicate by qualname (in case same test appears multiple times)
181239
by_qual: dict[str, DiscoveredTest] = {}

0 commit comments

Comments
 (0)