Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions array_api_tests/test_special_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
UnaryCheck = Callable[[float], bool]
BinaryCheck = Callable[[float, float], bool]

# Global registry for recording special cases
special_cases_registry: List[str] = []


def make_strict_eq(v: float) -> UnaryCheck:
if math.isnan(v):
Expand Down Expand Up @@ -679,6 +682,11 @@ def parse_unary_case_block(case_block: str, func_name: str) -> List[UnaryCase]:
cases = []
for case_m in r_case.finditer(case_block):
case_str = case_m.group(1)
# Record the special case line in the global registry.
# Per requirements, we record ALL case lines that match r_case pattern,
# including those that fail to parse or trigger warnings.
special_cases_registry.append(f"{func_name}: {case_str}")

if r_already_int_case.search(case_str):
cases.append(already_int_case)
elif r_even_round_halves_case.search(case_str):
Expand Down Expand Up @@ -1145,6 +1153,11 @@ def parse_binary_case_block(case_block: str, func_name: str) -> List[BinaryCase]
cases = []
for case_m in r_case.finditer(case_block):
case_str = case_m.group(1)
# Record the special case line in the global registry.
# Per requirements, we record ALL case lines that match r_case pattern,
# including those that fail to parse or trigger warnings.
special_cases_registry.append(f"{func_name}: {case_str}")

if r_redundant_case.search(case_str):
continue
if r_binary_case.match(case_str):
Expand Down Expand Up @@ -1351,3 +1364,31 @@ def test_nan_propagation(func_name, x, data):

ph.assert_shape(func_name, out_shape=out.shape, expected=()) # sanity check
assert xp.isnan(out), f"{out=!r}, but should be NaN"


def test_print_special_cases_registry():
"""
Test function to emit all recorded special cases.

This test prints the complete registry of special cases that were parsed
from docstrings during module load time. The registry is populated when
the module is imported, as parse_unary_case_block and parse_binary_case_block
are called to set up test parameters.

This test always passes - it's purely for informational/debugging purposes
to view all special cases that were extracted from the Array API specification.
"""
print("\n" + "=" * 80)
print("SPECIAL CASES REGISTRY")
print("=" * 80)
if special_cases_registry:
for case_record in special_cases_registry:
print(case_record)
print("=" * 80)
print(f"Total special cases recorded: {len(special_cases_registry)}")
else:
print("No special cases recorded")
print("=" * 80)
# Verify the registry is accessible
assert isinstance(special_cases_registry, list)