
Commit 28ef01c

Merge pull request #717 from InfiniTensor/issue/716
issue/716: Add save feature for existing test cases
2 parents e7e96a2 + a8875c9 commit 28ef01c

5 files changed: +401 -41 lines


test/infinicore/framework/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -2,6 +2,7 @@
 from .test_case import TestCase, TestResult
 from .benchmark import BenchmarkUtils, BenchmarkResult
 from .config import (
+    add_common_test_args,
     get_args,
     get_hardware_args_group,
     get_test_devices,
@@ -36,7 +37,9 @@
     "TestConfig",
     "TestResult",
     "TestRunner",
+    "TestReporter",
     # Core functions
+    "add_common_test_args",
     "compare_results",
     "convert_infinicore_to_torch",
     "create_test_comparator",

test/infinicore/framework/config.py

Lines changed: 39 additions & 18 deletions
@@ -44,6 +44,42 @@ def get_hardware_args_group(parser):
 
     return hardware_group
 
+def add_common_test_args(parser: argparse.ArgumentParser):
+    """
+    Adds common test/execution arguments to the passed parser object.
+    Includes: bench, debug, verbose, save args.
+    """
+    # Create an argument group to make help info clearer
+    group = parser.add_argument_group("Common Execution Options")
+
+    group.add_argument(
+        "--bench",
+        nargs="?",
+        const="both",
+        choices=["host", "device", "both"],
+        help="Enable performance benchmarking mode. "
+        "Options: host (CPU time only), device (GPU time only), both (default)",
+    )
+
+    group.add_argument(
+        "--debug",
+        action="store_true",
+        help="Enable debug mode for detailed tensor comparison",
+    )
+
+    group.add_argument(
+        "--verbose",
+        action="store_true",
+        help="Enable verbose mode to stop on first error with full traceback",
+    )
+
+    group.add_argument(
+        "--save",
+        nargs="?",
+        const="test_report.json",
+        default=None,
+        help="Save test results to a JSON file. Defaults to 'test_report.json' if no filename provided.",
+    )
 
 def get_args():
     """Parse command line arguments for operator testing"""
@@ -77,14 +113,6 @@ def get_args():
     )
 
     # Core testing options
-    parser.add_argument(
-        "--bench",
-        nargs="?",
-        const="both",
-        choices=["host", "device", "both"],
-        help="Enable performance benchmarking mode. "
-        "Options: host (CPU time only), device (GPU time only), both (default)",
-    )
     parser.add_argument(
         "--num_prerun",
         type=lambda x: max(0, int(x)),
@@ -97,16 +125,9 @@
         default=1000,
         help="Number of iterations for benchmarking (default: 1000)",
    )
-    parser.add_argument(
-        "--debug",
-        action="store_true",
-        help="Enable debug mode for detailed tensor comparison",
-    )
-    parser.add_argument(
-        "--verbose",
-        action="store_true",
-        help="Enable verbose mode to stop on first error with full traceback",
-    )
+
+    # Call the common method to add arguments
+    add_common_test_args(parser)
 
     # Device options using shared hardware info
     hardware_group = get_hardware_args_group(parser)
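
With the flags factored out of get_args(), any script that builds its own argparse.ArgumentParser can pick up the shared --bench/--debug/--verbose/--save options in one call. A minimal sketch (standalone; the import path is an assumption, not part of the commit):

import argparse

from framework.config import add_common_test_args  # import path assumed

parser = argparse.ArgumentParser(description="custom operator test")
add_common_test_args(parser)  # attaches --bench, --debug, --verbose, --save

args = parser.parse_args(["--save"])             # bare flag falls back to const
print(args.save)                                 # test_report.json
args = parser.parse_args(["--save", "my.json"])  # explicit filename
print(args.save)                                 # my.json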
Lines changed: 292 additions & 0 deletions
@@ -0,0 +1,292 @@
+import json
+import os
+from datetime import datetime
+from typing import List, Dict, Any, Union
+from dataclasses import is_dataclass
+from .base import TensorSpec
+from .devices import InfiniDeviceEnum
+
+class TestReporter:
+    """
+    Handles report generation and file saving for test results.
+    """
+
+    @staticmethod
+    def prepare_report_entry(
+        op_name: str,
+        test_cases: List[Any],
+        args: Any,
+        op_paths: Dict[str, str],
+        results_list: List[Any]
+    ) -> List[Dict[str, Any]]:
+        """
+        Combines static test case info with dynamic execution results.
+        """
+        # 1. Normalize results
+        results_map = {}
+        if isinstance(results_list, list):
+            results_map = {i: res for i, res in enumerate(results_list)}
+        elif isinstance(results_list, dict):
+            results_map = results_list
+        else:
+            results_map = {0: results_list} if results_list else {}
+
+        # 2. Global Args
+        global_args = {
+            k: getattr(args, k)
+            for k in ["bench", "num_prerun", "num_iterations", "verbose", "debug"]
+            if hasattr(args, k)
+        }
+
+        grouped_entries: Dict[int, Dict[str, Any]] = {}
+
+        # 3. Iterate Test Cases
+        for idx, tc in enumerate(test_cases):
+            res = results_map.get(idx)
+            dev_id = getattr(res, "device", 0) if res else 0
+
+            # --- A. Initialize Group ---
+            if dev_id not in grouped_entries:
+                device_id_map = {v: k for k, v in vars(InfiniDeviceEnum).items() if not k.startswith("_")}
+                dev_str = device_id_map.get(dev_id, str(dev_id))
+
+                grouped_entries[dev_id] = {
+                    "operator": op_name,
+                    "device": dev_str,
+                    "torch_op": op_paths.get("torch") or "unknown",
+                    "infinicore_op": op_paths.get("infinicore") or "unknown",
+                    "args": global_args,
+                    "testcases": []
+                }
+
+            # --- B. Build Kwargs ---
+            display_kwargs = {}
+
+            # B1. Process existing kwargs
+            for k, v in tc.kwargs.items():
+                # Handle Inplace: "out": index -> "out": "input_name"
+                if k == "out" and isinstance(v, int):
+                    if 0 <= v < len(tc.inputs):
+                        display_kwargs[k] = tc.inputs[v].name
+                    else:
+                        display_kwargs[k] = f"Invalid_Index_{v}"
+                else:
+                    display_kwargs[k] = (TestReporter._spec_to_dict(v) if isinstance(v, TensorSpec) else v)
+
+            # B2. Inject Outputs into Kwargs
+            if hasattr(tc, "output_specs") and tc.output_specs:
+                for i, spec in enumerate(tc.output_specs):
+                    display_kwargs[f"out_{i}"] = TestReporter._spec_to_dict(spec)
+            elif tc.output_spec:
+                if "out" not in display_kwargs:
+                    display_kwargs["out"] = TestReporter._spec_to_dict(tc.output_spec)
+
+            # --- C. Build Test Case Dictionary ---
+            case_data = {
+                "description": tc.description,
+                "inputs": [TestReporter._spec_to_dict(i) for i in tc.inputs],
+                "kwargs": display_kwargs,
+                "comparison_target": tc.comparison_target,
+                "tolerance": tc.tolerance,
+            }
+
+            # --- D. Inject Result ---
+            if res:
+                case_data["result"] = TestReporter._fmt_result(res)
+            else:
+                case_data["result"] = {"status": {"success": False, "error": "No result"}}
+
+            grouped_entries[dev_id]["testcases"].append(case_data)
+
+        return list(grouped_entries.values())
+
+    @staticmethod
+    def save_all_results(save_path: str, total_results: List[Dict[str, Any]]):
+        """
+        Saves the report list to a JSON file with specific custom formatting
+        """
+        directory, filename = os.path.split(save_path)
+        name, ext = os.path.splitext(filename)
+        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]
+
+        final_path = os.path.join(directory, f"{name}_{timestamp}{ext}")
+
+        # Define indentation levels for cleaner code
+        indent_4 = ' ' * 4
+        indent_8 = ' ' * 8
+        indent_12 = ' ' * 12
+        indent_16 = ' ' * 16
+        indent_20 = ' ' * 20
+
+        print(f"💾 Saving to: {final_path}")
+        try:
+            with open(final_path, "w", encoding="utf-8") as f:
+                f.write("[\n")
+
+                for i, entry in enumerate(total_results):
+                    f.write(f"{indent_4}{{\n")
+                    keys = list(entry.keys())
+
+                    for j, key in enumerate(keys):
+                        val = entry[key]
+                        comma = "," if j < len(keys) - 1 else ""
+
+                        # -------------------------------------------------
+                        # Special Handling for 'testcases' list formatting
+                        # -------------------------------------------------
+                        if key == "testcases" and isinstance(val, list):
+                            f.write(f'{indent_8}"{key}": [\n')
+
+                            for c_idx, case_item in enumerate(val):
+                                f.write(f"{indent_12}{{\n")
+                                case_keys = list(case_item.keys())
+
+                                for k_idx, c_key in enumerate(case_keys):
+                                    c_val = case_item[c_key]
+
+                                    # [Logic A] Skip fields we merged manually after 'kwargs'
+                                    if c_key in ["comparison_target", "tolerance"]:
+                                        continue
+
+                                    # Check comma for standard logic (might be overridden below)
+                                    c_comma = "," if k_idx < len(case_keys) - 1 else ""
+
+                                    # [Logic B] Handle 'kwargs' + Grouped Fields
+                                    if c_key == "kwargs":
+                                        # 1. Use Helper for kwargs (Fill/Flow logic)
+                                        TestReporter._write_smart_field(
+                                            f, c_key, c_val, indent_16, indent_20, close_comma=","
+                                        )
+
+                                        # 2. Write subsequent comparison_target and tolerance (on a new line)
+                                        cmp_v = json.dumps(case_item.get("comparison_target"), ensure_ascii=False)
+                                        tol_v = json.dumps(case_item.get("tolerance"), ensure_ascii=False)
+
+                                        remaining_keys = [k for k in case_keys[k_idx+1:] if k not in ("comparison_target", "tolerance")]
+                                        line_comma = "," if remaining_keys else ""
+
+                                        f.write(f'{indent_16}"comparison_target": {cmp_v}, "tolerance": {tol_v}{line_comma}\n')
+                                        continue
+
+                                    # [Logic C] Handle 'inputs' (Smart Wrap)
+                                    if c_key == "inputs" and isinstance(c_val, list):
+                                        TestReporter._write_smart_field(
+                                            f, c_key, c_val, indent_16, indent_20, close_comma=c_comma
+                                        )
+                                        continue
+
+                                    # [Logic D] Standard fields (description, result, output_spec, etc.)
+                                    else:
+                                        c_val_str = json.dumps(c_val, ensure_ascii=False)
+                                        f.write(f'{indent_16}"{c_key}": {c_val_str}{c_comma}\n')
+
+                                close_comma = "," if c_idx < len(val) - 1 else ""
+                                f.write(f"{indent_12}}}{close_comma}\n")
+
+                            f.write(f"{indent_8}]{comma}\n")
+
+                        # -------------------------------------------------
+                        # Standard top-level fields (operator, args, etc.)
+                        # -------------------------------------------------
+                        else:
+                            k_str = json.dumps(key, ensure_ascii=False)
+                            v_str = json.dumps(val, ensure_ascii=False)
+                            f.write(f"{indent_8}{k_str}: {v_str}{comma}\n")
+
+                    if i < len(total_results) - 1:
+                        f.write(f"{indent_4}}},\n")
+                    else:
+                        f.write(f"{indent_4}}}\n")
+
+                f.write("]\n")
+            print("✅ Saved (Structure Matched).")
+        except Exception as e:
+            import traceback; traceback.print_exc()
+            print(f"❌ Save failed: {e}")
+
+    # --- Internal Helpers ---
+    @staticmethod
+    def _write_smart_field(f, key, value, indent, sub_indent, close_comma=""):
+        """
+        Helper to write a JSON field (List or Dict) with smart wrapping.
+        - If compact length <= 180: Write on one line.
+        - If > 180: Use 'Fill/Flow' mode (multiple items per line, wrap when line is full).
+        """
+        # 1. Try Compact Mode
+        compact_json = json.dumps(value, ensure_ascii=False)
+        if len(compact_json) <= 180:
+            f.write(f'{indent}"{key}": {compact_json}{close_comma}\n')
+            return
+
+        # 2. Fill/Flow Mode
+        is_dict = isinstance(value, dict)
+        open_char = '{' if is_dict else '['
+        close_char = '}' if is_dict else ']'
+
+        f.write(f'{indent}"{key}": {open_char}')
+
+        # Normalize items for iteration
+        if is_dict:
+            items = list(value.items())
+        else:
+            items = value  # List
+
+        # Initialize current line length tracking
+        # Length includes indent + "key": [
+        current_len = len(indent) + len(f'"{key}": {open_char}')
+
+        for i, item in enumerate(items):
+            # Format individual item string
+            if is_dict:
+                k, v = item
+                val_str = json.dumps(v, ensure_ascii=False)
+                item_str = f'"{k}": {val_str}'
+            else:
+                item_str = json.dumps(item, ensure_ascii=False)
+
+            is_last = (i == len(items) - 1)
+            item_comma = "" if is_last else ", "
+
+            # Predict new length: current + item + comma
+            if current_len + len(item_str) + len(item_comma) > 180:
+                # Wrap to new line
+                f.write(f'\n{sub_indent}')
+                current_len = len(sub_indent)
+
+            f.write(f'{item_str}{item_comma}')
+            current_len += len(item_str) + len(item_comma)
+
+        f.write(f'{close_char}{close_comma}\n')
+
+    @staticmethod
+    def _spec_to_dict(s):
+        return {
+            "name": getattr(s, "name", "unknown"),
+            "shape": list(s.shape) if s.shape else None,
+            "dtype": str(s.dtype).split(".")[-1],
+            "strides": list(s.strides) if s.strides else None,
+        }
+
+    @staticmethod
+    def _fmt_result(res):
+        if not (is_dataclass(res) or hasattr(res, "success")):
+            return str(res)
+
+        get_time = lambda k: round(getattr(res, k, 0.0), 4)
+
+        return {
+            "status": {
+                "success": getattr(res, "success", False),
+                "error": getattr(res, "error_message", ""),
+            },
+            "perf_ms": {
+                "torch": {
+                    "host": get_time("torch_host_time"),
+                    "device": get_time("torch_device_time"),
+                },
+                "infinicore": {
+                    "host": get_time("infini_host_time"),
+                    "device": get_time("infini_device_time"),
+                },
+            },
+        }
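
Taken together with the --save flag above, a runner would hand its collected cases and results to TestReporter once execution finishes. The sketch below is illustrative only: op_name, the op_paths values, and the test_cases/results variables are placeholders shaped after the attributes the reporter reads, not the framework's real runner code.

# Sketch: wiring --save to the reporter (names other than the TestReporter API are placeholders).
if args.save:  # --save stores None unless the flag was passed
    entries = TestReporter.prepare_report_entry(
        op_name="add",                              # example operator name
        test_cases=test_cases,                      # cases executed for this operator
        args=args,                                  # parsed CLI namespace
        op_paths={"torch": "torch.add", "infinicore": "infinicore.add"},  # example paths
        results_list=results,                       # per-case result objects
    )
    # save_all_results appends a timestamp, e.g. test_report_20250101_120000_000.json
    TestReporter.save_all_results(args.save, entries)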
