Skip to content

Commit 665bc56

Browse files
author
Shrey Modi
committed
gcs support
1 parent 9e36c13 commit 665bc56

File tree

2 files changed

+303
-81
lines changed

2 files changed

+303
-81
lines changed

eval_protocol/cli_commands/upload.py

Lines changed: 129 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ def _is_eval_protocol_test(obj: Any) -> bool:
8181
return False
8282
# Must have pytest marks from evaluation_test
8383
marks = getattr(obj, "pytestmark", [])
84+
# Handle pytest proxy objects (APIRemovedInV1Proxy)
85+
if not isinstance(marks, (list, tuple)):
86+
try:
87+
marks = list(marks) if marks else []
88+
except (TypeError, AttributeError):
89+
return False
8490
return len(marks) > 0
8591

8692

@@ -131,51 +137,103 @@ def _discover_tests(root: str) -> list[DiscoveredTest]:
131137

132138
discovered: list[DiscoveredTest] = []
133139

134-
# Collect all test functions from Python files
135-
for file_path in _iter_python_files(root):
140+
class CollectionPlugin:
141+
"""Plugin to capture collected items without running code."""
142+
143+
def __init__(self):
144+
self.items = []
145+
146+
def pytest_ignore_collect(self, collection_path, config):
147+
"""Ignore problematic files before pytest tries to import them."""
148+
# Ignore specific files
149+
ignored_files = ["setup.py", "versioneer.py", "conf.py", "__main__.py"]
150+
if collection_path.name in ignored_files:
151+
return True
152+
153+
# Ignore hidden files (starting with .)
154+
if collection_path.name.startswith("."):
155+
return True
156+
157+
# Ignore test_discovery files
158+
if collection_path.name.startswith("test_discovery"):
159+
return True
160+
161+
return None
162+
163+
def pytest_collection_modifyitems(self, items):
164+
"""Hook called after collection is done."""
165+
self.items = items
166+
167+
plugin = CollectionPlugin()
168+
169+
# Run pytest collection only (--collect-only prevents code execution)
170+
# Override python_files to collect from ANY .py file
171+
args = [
172+
abs_root,
173+
"--collect-only",
174+
"-q",
175+
"--pythonwarnings=ignore",
176+
"-o",
177+
"python_files=*.py", # Override to collect all .py files
178+
]
179+
180+
try:
181+
# Suppress pytest output
182+
import io
183+
import contextlib
184+
185+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
186+
pytest.main(args, plugins=[plugin])
187+
except Exception:
188+
# If pytest collection fails, fall back to empty list
189+
return []
190+
191+
# Process collected items
192+
for item in plugin.items:
193+
if not hasattr(item, "obj"):
194+
continue
195+
196+
obj = item.obj
197+
if not _is_eval_protocol_test(obj):
198+
continue
199+
200+
origin = getattr(obj, "_origin_func", obj)
136201
try:
137-
unique_name = "ep_upload_" + re.sub(r"[^a-zA-Z0-9_]", "_", os.path.abspath(file_path))
138-
spec = importlib.util.spec_from_file_location(unique_name, file_path)
139-
if spec and spec.loader:
140-
module = importlib.util.module_from_spec(spec)
141-
sys.modules[spec.name] = module
142-
spec.loader.exec_module(module) # type: ignore[attr-defined]
143-
else:
144-
continue
202+
src_file = inspect.getsourcefile(origin) or str(item.path)
203+
_, lineno = inspect.getsourcelines(origin)
145204
except Exception:
146-
continue
205+
src_file, lineno = str(item.path), None
147206

148-
for name, obj in inspect.getmembers(module):
149-
if _is_eval_protocol_test(obj):
150-
origin = getattr(obj, "_origin_func", obj)
151-
try:
152-
src_file = inspect.getsourcefile(origin) or file_path
153-
_, lineno = inspect.getsourcelines(origin)
154-
except Exception:
155-
src_file, lineno = file_path, None
156-
157-
# Extract parametrization info from marks
158-
has_parametrize, param_count, param_ids = _extract_param_info_from_marks(obj)
159-
160-
# Generate synthetic nodeids for display
161-
base_nodeid = f"{os.path.basename(file_path)}::{name}"
162-
if has_parametrize and param_ids:
163-
nodeids = [f"{base_nodeid}[{pid}]" for pid in param_ids]
164-
else:
165-
nodeids = [base_nodeid]
166-
167-
discovered.append(
168-
DiscoveredTest(
169-
module_path=module.__name__,
170-
module_name=module.__name__,
171-
qualname=f"{module.__name__}.{name}",
172-
file_path=os.path.abspath(src_file),
173-
lineno=lineno,
174-
has_parametrize=has_parametrize,
175-
param_count=param_count,
176-
nodeids=nodeids,
177-
)
178-
)
207+
# Extract parametrization info from marks
208+
has_parametrize, param_count, param_ids = _extract_param_info_from_marks(obj)
209+
210+
# Get module name and function name
211+
module_name = (
212+
item.module.__name__
213+
if hasattr(item, "module")
214+
else item.nodeid.split("::")[0].replace("/", ".").replace(".py", "")
215+
)
216+
func_name = item.name.split("[")[0] if "[" in item.name else item.name
217+
218+
# Generate nodeids
219+
base_nodeid = f"{os.path.basename(src_file)}::{func_name}"
220+
if param_ids:
221+
nodeids = [f"{base_nodeid}[{pid}]" for pid in param_ids]
222+
else:
223+
nodeids = [base_nodeid]
224+
225+
discovered.append(
226+
DiscoveredTest(
227+
module_path=module_name,
228+
module_name=module_name,
229+
qualname=f"{module_name}.{func_name}",
230+
file_path=os.path.abspath(src_file),
231+
lineno=lineno,
232+
has_parametrize=has_parametrize,
233+
param_count=param_count,
234+
nodeids=nodeids,
235+
)
236+
)
179237

180238
# Deduplicate by qualname (in case same test appears multiple times)
181239
by_qual: dict[str, DiscoveredTest] = {}
@@ -519,35 +577,35 @@ def upload_command(args: argparse.Namespace) -> int:
519577
description = getattr(args, "description", None)
520578
force = bool(getattr(args, "force", False))
521579

522-
# # Ensure FIREWORKS_API_KEY is available to the remote by storing it as a Fireworks secret
523-
# try:
524-
# fw_account_id = get_fireworks_account_id()
525-
# fw_api_key_value = get_fireworks_api_key()
526-
# if not fw_account_id and fw_api_key_value:
527-
# # Attempt to verify and resolve account id from server headers
528-
# resolved = verify_api_key_and_get_account_id(api_key=fw_api_key_value, api_base=get_fireworks_api_base())
529-
# if resolved:
530-
# fw_account_id = resolved
531-
# # Propagate to environment so downstream calls use it if needed
532-
# os.environ["FIREWORKS_ACCOUNT_ID"] = fw_account_id
533-
# print(f"Resolved FIREWORKS_ACCOUNT_ID via API verification: {fw_account_id}")
534-
# if fw_account_id and fw_api_key_value:
535-
# print("Ensuring FIREWORKS_API_KEY is registered as a secret on Fireworks for rollout...")
536-
# if create_or_update_fireworks_secret(
537-
# account_id=fw_account_id,
538-
# key_name="FIREWORKS_API_KEY",
539-
# secret_value=fw_api_key_value,
540-
# ):
541-
# print("✓ FIREWORKS_API_KEY secret created/updated on Fireworks.")
542-
# else:
543-
# print("Warning: Failed to create/update FIREWORKS_API_KEY secret on Fireworks.")
544-
# else:
545-
# if not fw_account_id:
546-
# print("Warning: FIREWORKS_ACCOUNT_ID not found; cannot register FIREWORKS_API_KEY secret.")
547-
# if not fw_api_key_value:
548-
# print("Warning: FIREWORKS_API_KEY not found locally; cannot register secret.")
549-
# except Exception as e:
550-
# print(f"Warning: Skipped Fireworks secret registration due to error: {e}")
580+
# Ensure FIREWORKS_API_KEY is available to the remote by storing it as a Fireworks secret
581+
try:
582+
fw_account_id = get_fireworks_account_id()
583+
fw_api_key_value = get_fireworks_api_key()
584+
if not fw_account_id and fw_api_key_value:
585+
# Attempt to verify and resolve account id from server headers
586+
resolved = verify_api_key_and_get_account_id(api_key=fw_api_key_value, api_base=get_fireworks_api_base())
587+
if resolved:
588+
fw_account_id = resolved
589+
# Propagate to environment so downstream calls use it if needed
590+
os.environ["FIREWORKS_ACCOUNT_ID"] = fw_account_id
591+
print(f"Resolved FIREWORKS_ACCOUNT_ID via API verification: {fw_account_id}")
592+
if fw_account_id and fw_api_key_value:
593+
print("Ensuring FIREWORKS_API_KEY is registered as a secret on Fireworks for rollout...")
594+
if create_or_update_fireworks_secret(
595+
account_id=fw_account_id,
596+
key_name="FIREWORKS_API_KEY",
597+
secret_value=fw_api_key_value,
598+
):
599+
print("✓ FIREWORKS_API_KEY secret created/updated on Fireworks.")
600+
else:
601+
print("Warning: Failed to create/update FIREWORKS_API_KEY secret on Fireworks.")
602+
else:
603+
if not fw_account_id:
604+
print("Warning: FIREWORKS_ACCOUNT_ID not found; cannot register FIREWORKS_API_KEY secret.")
605+
if not fw_api_key_value:
606+
print("Warning: FIREWORKS_API_KEY not found locally; cannot register secret.")
607+
except Exception as e:
608+
print(f"Warning: Skipped Fireworks secret registration due to error: {e}")
551609

552610
exit_code = 0
553611
for i, (qualname, source_file_path) in enumerate(selected_specs):

0 commit comments

Comments
 (0)