|
9 | 9 | from IPython.display import display, SVG, Image as IPyImage |
10 | 10 | from typing import Optional, Sequence |
11 | 11 |
|
12 | | -# Semantically meaningful layer types for compact visualization |
| 12 | +# Semantically meaningful layer categories for compact visualization |
| 13 | +# These match the ATEN_LAYER_MAP categories in graph_builder.py |
13 | 14 | SEMANTIC_LAYER_TYPES: tuple[str, ...] = ( |
14 | | - "MLP_Layer", # Linear/FC layers |
15 | | - "DL_Layer", # Conv layers (Conv1d, Conv2d, Conv3d) |
| 15 | + "MLP_Layer", # Linear/FC layers (linear, addmm) |
| 16 | + "DL_Layer", # Conv layers (conv2d, max_pool2d, etc.) |
16 | 17 | "Activation", # ReLU, GELU, SiLU, etc. |
17 | 18 | "Normalization", # BatchNorm, LayerNorm, GroupNorm |
18 | 19 | "Attention", # Self/Cross attention (scaled_dot_product_attention) |
19 | | - "Output", # Final output node |
20 | | - "Placeholder", # Input nodes (x, input_ids, etc.) - CRITICAL for graph connectivity |
21 | | - "Model_Input", # Legacy input type (kept for compatibility) |
| 20 | + "Output", # Final output node (layer_type) |
| 21 | + "Placeholder", # Input nodes (layer_type) - CRITICAL for graph connectivity |
| 22 | + "Model_Input", # Legacy input type (layer_type) |
22 | 23 | "NLP_Embedding", # Embedding layers (embedding, embedding_bag) |
23 | 24 | ) |
24 | 25 |
|
25 | 26 | # Default types to always force-include (for graph connectivity) |
26 | 27 | DEFAULT_FORCE_INCLUDE_TYPES: tuple[str, ...] = ("Placeholder", "Model_Input", "Output") |
27 | 28 |
|
28 | 29 |
|
| 30 | +def _get_node_category(node_attrs: dict) -> str: |
| 31 | + """Get the semantic category for a node by checking both layer_type and layer_name. |
| 32 | + |
| 33 | + The graph stores: |
| 34 | + - layer_type: 'ATen_Operation', 'Placeholder', 'Output', 'Operation', etc. |
| 35 | + - layer_name: 'DL_Layer', 'MLP_Layer', 'Activation', 'Normalization', etc. |
| 36 | + |
| 37 | + For filtering, we need to check layer_name first (for ATen ops), then layer_type. |
| 38 | + """ |
| 39 | + layer_name = node_attrs.get("layer_name", "") |
| 40 | + layer_type = node_attrs.get("layer_type", "Unknown") |
| 41 | + |
| 42 | + # For ATen operations, layer_name contains the semantic category |
| 43 | + if layer_name in SEMANTIC_LAYER_TYPES: |
| 44 | + return layer_name |
| 45 | + |
| 46 | + # For Placeholder, Output, etc., layer_type is the category |
| 47 | + if layer_type in ("Placeholder", "Output", "Model_Input"): |
| 48 | + return layer_type |
| 49 | + |
| 50 | + # Return layer_type as fallback |
| 51 | + return layer_type |
| 52 | + |
| 53 | + |
29 | 54 | def visualize_graph(graph, save_path="graph.png", *, show=True, dpi=600): |
30 | 55 | """📊 Visualize forward execution graph with dynamic scaling (shows inline + saves)""" |
31 | 56 | num_nodes = len(graph.nodes) |
@@ -112,16 +137,16 @@ def visualize_relevance(graph, all_wt, output_path="backtrace_graph", |
112 | 137 | force_include = { |
113 | 138 | node.replace("/", " ").replace(":", " ") |
114 | 139 | for node in graph.nodes |
115 | | - if graph.nodes[node].get("layer_type") in DEFAULT_FORCE_INCLUDE_TYPES |
| 140 | + if _get_node_category(graph.nodes[node]) in DEFAULT_FORCE_INCLUDE_TYPES |
116 | 141 | } |
117 | 142 |
|
118 | 143 | if layer_types is not None: |
119 | | - # Layer-type filtering mode |
| 144 | + # Layer-type filtering mode - use _get_node_category for proper semantic matching |
120 | 145 | layer_types_set = set(layer_types) |
121 | 146 | top_node_names = { |
122 | 147 | node.replace("/", " ").replace(":", " ") |
123 | 148 | for node in graph.nodes |
124 | | - if graph.nodes[node].get("layer_type") in layer_types_set |
| 149 | + if _get_node_category(graph.nodes[node]) in layer_types_set |
125 | 150 | } | force_include |
126 | 151 | elif top_k: |
127 | 152 | top_keys = sorted(flat_scores.items(), key=lambda x: abs(x[1]), reverse=True)[:top_k] |
@@ -154,7 +179,7 @@ def find_filtered_ancestors(node_raw, visited=None): |
154 | 179 | ancestors.update(find_filtered_ancestors(parent_raw, visited)) |
155 | 180 | return ancestors |
156 | 181 |
|
157 | | - # --- Color map for node types --- |
| 182 | + # --- Color map for node types (uses layer_name for ATen ops, layer_type for others) --- |
158 | 183 | color_map = { |
159 | 184 | "MLP_Layer": "lightblue", |
160 | 185 | "DL_Layer": "lightgreen", |
@@ -184,7 +209,9 @@ def find_filtered_ancestors(node_raw, visited=None): |
184 | 209 | if name not in top_node_names: |
185 | 210 | continue |
186 | 211 | rel = relevance_data.get(name, (0.0, 0.0, 0.0)) |
187 | | - fill = color_map.get(graph.nodes[node].get("layer_type", "Unknown"), "white") |
| 212 | + # Use layer_name first (for semantic category), then layer_type as fallback |
| 213 | + node_category = _get_node_category(graph.nodes[node]) |
| 214 | + fill = color_map.get(node_category, color_map.get(graph.nodes[node].get("layer_type", "Unknown"), "white")) |
188 | 215 | g.node( |
189 | 216 | name, |
190 | 217 | label=f"{name}\nMean: {rel[0]:.3f}\nMax: {rel[1]:.3f}\nMin: {rel[2]:.3f}", |
@@ -356,12 +383,12 @@ def _norm(s): |
356 | 383 | # present nodes - keep all for transitive edge computation |
357 | 384 | all_raw = list(graph.nodes.keys()) |
358 | 385 |
|
359 | | - # Determine filtered set |
| 386 | + # Determine filtered set using _get_node_category for proper semantic matching |
360 | 387 | if layer_types is not None: |
361 | 388 | layer_types_set = set(layer_types) | set(DEFAULT_FORCE_INCLUDE_TYPES) |
362 | 389 | present_raw = [ |
363 | 390 | raw for raw in all_raw |
364 | | - if graph.nodes[raw].get("layer_type") in layer_types_set |
| 391 | + if _get_node_category(graph.nodes[raw]) in layer_types_set |
365 | 392 | ] |
366 | 393 | else: |
367 | 394 | present_raw = all_raw |
@@ -479,8 +506,9 @@ def _short(s, n=48): |
479 | 506 | for raw in present_raw: |
480 | 507 | nk = norm_by_raw[raw] |
481 | 508 | mean, mx, mn = rel_map.get(nk, (0.0, 0.0, 0.0)) |
482 | | - lt = graph.nodes[raw].get("layer_type", "Unknown") |
483 | | - fill = color_map.get(lt, "white") |
| 509 | + # Use _get_node_category for proper semantic coloring |
| 510 | + node_category = _get_node_category(graph.nodes[raw]) |
| 511 | + fill = color_map.get(node_category, color_map.get(graph.nodes[raw].get("layer_type", "Unknown"), "white")) |
484 | 512 | collapsed = graph.nodes[raw].get("collapsed_count", 0) |
485 | 513 | collapsed_line = f"\n[collapsed {collapsed}]" if collapsed else "" |
486 | 514 |
|
@@ -568,7 +596,7 @@ def visualize_relevance_auto( |
568 | 596 | layer_types_set = set(layer_types) | set(DEFAULT_FORCE_INCLUDE_TYPES) |
569 | 597 | filtered_count = sum( |
570 | 598 | 1 for n in graph.nodes |
571 | | - if graph.nodes[n].get("layer_type") in layer_types_set |
| 599 | + if _get_node_category(graph.nodes[n]) in layer_types_set |
572 | 600 | ) |
573 | 601 | num_nodes = filtered_count |
574 | 602 | print(f"num_nodes after layer_types filter: {num_nodes} (from {len(graph.nodes)} total)") |
|
0 commit comments