import json
import os
from datetime import datetime
from typing import List, Dict, Any
from dataclasses import is_dataclass
from .base import TensorSpec
from .devices import InfiniDeviceEnum

class TestReporter:
    """
    Handles report generation and file saving for test results.
    """
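
    # Typical flow (illustrative sketch only; `cases`, `args`, and `results` are
    # hypothetical stand-ins for whatever the test runner provides, and the
    # op_paths values below are made up):
    #
    #     entries = TestReporter.prepare_report_entry(
    #         "add", cases, args,
    #         {"torch": "torch.add", "infinicore": "infinicore.add"},
    #         results,
    #     )
    #     TestReporter.save_all_results("reports/add.json", entries)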

    @staticmethod
    def prepare_report_entry(
        op_name: str,
        test_cases: List[Any],
        args: Any,
        op_paths: Dict[str, str],
        results_list: List[Any]
    ) -> List[Dict[str, Any]]:
        """
        Combines static test case info with dynamic execution results.
        """
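        # The returned list has one entry per device; each entry looks roughly like
        # (keys taken from the code below, values illustrative):
        #   {"operator": ..., "device": ..., "torch_op": ..., "infinicore_op": ...,
        #    "args": {...}, "testcases": [{"description": ..., "inputs": [...],
        #                                  "kwargs": {...}, "comparison_target": ...,
        #                                  "tolerance": ..., "result": {...}}, ...]}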
        # 1. Normalize results
        results_map = {}
        if isinstance(results_list, list):
            results_map = {i: res for i, res in enumerate(results_list)}
        elif isinstance(results_list, dict):
            results_map = results_list
        else:
            results_map = {0: results_list} if results_list else {}

        # 2. Global Args
        global_args = {
            k: getattr(args, k)
            for k in ["bench", "num_prerun", "num_iterations", "verbose", "debug"]
            if hasattr(args, k)
        }

        grouped_entries: Dict[int, Dict[str, Any]] = {}

        # 3. Iterate Test Cases
        for idx, tc in enumerate(test_cases):
            res = results_map.get(idx)
            dev_id = getattr(res, "device", 0) if res else 0

            # --- A. Initialize Group ---
            if dev_id not in grouped_entries:
                device_id_map = {v: k for k, v in vars(InfiniDeviceEnum).items() if not k.startswith("_")}
                dev_str = device_id_map.get(dev_id, str(dev_id))

                grouped_entries[dev_id] = {
                    "operator": op_name,
                    "device": dev_str,
                    "torch_op": op_paths.get("torch") or "unknown",
                    "infinicore_op": op_paths.get("infinicore") or "unknown",
                    "args": global_args,
                    "testcases": []
                }

            # --- B. Build Kwargs ---
            display_kwargs = {}

            # B1. Process existing kwargs
            for k, v in tc.kwargs.items():
                # Handle Inplace: "out": index -> "out": "input_name"
                if k == "out" and isinstance(v, int):
                    if 0 <= v < len(tc.inputs):
                        display_kwargs[k] = tc.inputs[v].name
                    else:
                        display_kwargs[k] = f"Invalid_Index_{v}"
                else:
                    display_kwargs[k] = (TestReporter._spec_to_dict(v) if isinstance(v, TensorSpec) else v)

            # B2. Inject Outputs into Kwargs
            if hasattr(tc, "output_specs") and tc.output_specs:
                for i, spec in enumerate(tc.output_specs):
                    display_kwargs[f"out_{i}"] = TestReporter._spec_to_dict(spec)
            elif getattr(tc, "output_spec", None):  # guard in case the test case has no single output spec
                if "out" not in display_kwargs:
                    display_kwargs["out"] = TestReporter._spec_to_dict(tc.output_spec)

            # --- C. Build Test Case Dictionary ---
            case_data = {
                "description": tc.description,
                "inputs": [TestReporter._spec_to_dict(i) for i in tc.inputs],
                "kwargs": display_kwargs,
                "comparison_target": tc.comparison_target,
                "tolerance": tc.tolerance,
            }

            # --- D. Inject Result ---
            if res:
                case_data["result"] = TestReporter._fmt_result(res)
            else:
                case_data["result"] = {"status": {"success": False, "error": "No result"}}

            grouped_entries[dev_id]["testcases"].append(case_data)

        return list(grouped_entries.values())

    @staticmethod
    def save_all_results(save_path: str, total_results: List[Dict[str, Any]]):
        """
        Saves the report list to a JSON file with custom formatting.
        """
        directory, filename = os.path.split(save_path)
        name, ext = os.path.splitext(filename)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3]

        final_path = os.path.join(directory, f"{name}_{timestamp}{ext}")
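        # e.g. "reports/add.json" -> "reports/add_20250101_120000_123.json"
        # (the %f microseconds are truncated to milliseconds; the path itself is illustrative)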

        # Define indentation levels for cleaner code
        indent_4 = ' ' * 4
        indent_8 = ' ' * 8
        indent_12 = ' ' * 12
        indent_16 = ' ' * 16
        indent_20 = ' ' * 20

        print(f"💾 Saving to: {final_path}")
        try:
            with open(final_path, "w", encoding="utf-8") as f:
                f.write("[\n")

                for i, entry in enumerate(total_results):
                    f.write(f"{indent_4}{{\n")
                    keys = list(entry.keys())

                    for j, key in enumerate(keys):
                        val = entry[key]
                        comma = "," if j < len(keys) - 1 else ""

                        # -------------------------------------------------
                        # Special Handling for 'testcases' list formatting
                        # -------------------------------------------------
                        if key == "testcases" and isinstance(val, list):
                            f.write(f'{indent_8}"{key}": [\n')

                            for c_idx, case_item in enumerate(val):
                                f.write(f"{indent_12}{{\n")
                                case_keys = list(case_item.keys())

                                for k_idx, c_key in enumerate(case_keys):
                                    c_val = case_item[c_key]

                                    # [Logic A] Skip fields we merge manually after 'kwargs'
                                    if c_key in ["comparison_target", "tolerance"]:
                                        continue

                                    # Check comma for standard logic (might be overridden below)
                                    c_comma = "," if k_idx < len(case_keys) - 1 else ""

                                    # [Logic B] Handle 'kwargs' + Grouped Fields
                                    if c_key == "kwargs":
                                        # 1. Use Helper for kwargs (Fill/Flow logic)
                                        TestReporter._write_smart_field(
                                            f, c_key, c_val, indent_16, indent_20, close_comma=","
                                        )

                                        # 2. Write subsequent comparison_target and tolerance (on a new line)
                                        cmp_v = json.dumps(case_item.get("comparison_target"), ensure_ascii=False)
                                        tol_v = json.dumps(case_item.get("tolerance"), ensure_ascii=False)

                                        remaining_keys = [k for k in case_keys[k_idx + 1:] if k not in ("comparison_target", "tolerance")]
                                        line_comma = "," if remaining_keys else ""

                                        f.write(f'{indent_16}"comparison_target": {cmp_v}, "tolerance": {tol_v}{line_comma}\n')
                                        continue

                                    # [Logic C] Handle 'inputs' (Smart Wrap)
                                    if c_key == "inputs" and isinstance(c_val, list):
                                        TestReporter._write_smart_field(
                                            f, c_key, c_val, indent_16, indent_20, close_comma=c_comma
                                        )
                                        continue

                                    # [Logic D] Standard fields (description, result, output_spec, etc.)
                                    else:
                                        c_val_str = json.dumps(c_val, ensure_ascii=False)
                                        f.write(f'{indent_16}"{c_key}": {c_val_str}{c_comma}\n')

                                close_comma = "," if c_idx < len(val) - 1 else ""
                                f.write(f"{indent_12}}}{close_comma}\n")

                            f.write(f"{indent_8}]{comma}\n")

                        # -------------------------------------------------
                        # Standard top-level fields (operator, args, etc.)
                        # -------------------------------------------------
                        else:
                            k_str = json.dumps(key, ensure_ascii=False)
                            v_str = json.dumps(val, ensure_ascii=False)
                            f.write(f"{indent_8}{k_str}: {v_str}{comma}\n")

                    if i < len(total_results) - 1:
                        f.write(f"{indent_4}}},\n")
                    else:
                        f.write(f"{indent_4}}}\n")

                f.write("]\n")
            print(" ✅ Saved (Structure Matched).")
        except Exception as e:
            import traceback
            traceback.print_exc()
            print(f" ❌ Save failed: {e}")

    # --- Internal Helpers ---
    @staticmethod
    def _write_smart_field(f, key, value, indent, sub_indent, close_comma=""):
        """
        Helper to write a JSON field (List or Dict) with smart wrapping.
        - If compact length <= 180: Write on one line.
        - If > 180: Use 'Fill/Flow' mode (multiple items per line, wrap when line is full).
        """
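        # Illustrative output for an over-long 'inputs' list (values are made up; the
        # 180-character threshold and the caller-supplied indents drive the wrapping):
        #     "inputs": [{"name": "a", ...}, {"name": "b", ...},
        #         {"name": "c", ...}]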
        # 1. Try Compact Mode
        compact_json = json.dumps(value, ensure_ascii=False)
        if len(compact_json) <= 180:
            f.write(f'{indent}"{key}": {compact_json}{close_comma}\n')
            return

        # 2. Fill/Flow Mode
        is_dict = isinstance(value, dict)
        open_char = '{' if is_dict else '['
        close_char = '}' if is_dict else ']'

        f.write(f'{indent}"{key}": {open_char}')

        # Normalize items for iteration
        if is_dict:
            items = list(value.items())
        else:
            items = value  # List

        # Initialize current line length tracking
        # Length includes indent + "key": [
        current_len = len(indent) + len(f'"{key}": {open_char}')

        for i, item in enumerate(items):
            # Format individual item string
            if is_dict:
                k, v = item
                val_str = json.dumps(v, ensure_ascii=False)
                item_str = f'"{k}": {val_str}'
            else:
                item_str = json.dumps(item, ensure_ascii=False)

            is_last = (i == len(items) - 1)
            item_comma = "" if is_last else ", "

            # Predict new length: current + item + comma
            if current_len + len(item_str) + len(item_comma) > 180:
                # Wrap to new line
                f.write(f'\n{sub_indent}')
                current_len = len(sub_indent)

            f.write(f'{item_str}{item_comma}')
            current_len += len(item_str) + len(item_comma)

        f.write(f'{close_char}{close_comma}\n')

    @staticmethod
    def _spec_to_dict(s):
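        # Produces a plain-dict view of a TensorSpec, e.g. (illustrative values for a
        # contiguous 2x3 float32 tensor):
        #   {"name": "x", "shape": [2, 3], "dtype": "float32", "strides": [3, 1]}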
        return {
            "name": getattr(s, "name", "unknown"),
            "shape": list(s.shape) if s.shape else None,
            "dtype": str(s.dtype).split(".")[-1],
            "strides": list(s.strides) if s.strides else None,
        }

    @staticmethod
    def _fmt_result(res):
        if not (is_dataclass(res) or hasattr(res, "success")):
            return str(res)

        get_time = lambda k: round(getattr(res, k, 0.0), 4)

        return {
            "status": {
                "success": getattr(res, "success", False),
                "error": getattr(res, "error_message", ""),
            },
            "perf_ms": {
                "torch": {
                    "host": get_time("torch_host_time"),
                    "device": get_time("torch_device_time"),
                },
                "infinicore": {
                    "host": get_time("infini_host_time"),
                    "device": get_time("infini_device_time"),
                },
            },
        }