Skip to content

Commit 26d4f70

Browse files
neerajaryaai and Copilot
committed
Fix layer_types filtering to use layer_name for semantic categories
The graph stores semantic categories (DL_Layer, MLP_Layer, Activation, etc.) in layer_name field, not layer_type. layer_type is always 'ATen_Operation' for PyTorch ops. - Add _get_node_category() helper to check both layer_name and layer_type - Update visualize_relevance() to use _get_node_category for filtering - Update visualize_relevance_fast() to use _get_node_category - Update visualize_relevance_auto() node counting - Fix color map lookup to use semantic categories Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 4ce3f95 commit 26d4f70

1 file changed

Lines changed: 44 additions & 16 deletions

File tree

dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,48 @@
99
from IPython.display import display, SVG, Image as IPyImage
1010
from typing import Optional, Sequence
1111

12-
# Semantically meaningful layer types for compact visualization
12+
# Semantically meaningful layer categories for compact visualization
13+
# These match the ATEN_LAYER_MAP categories in graph_builder.py
1314
SEMANTIC_LAYER_TYPES: tuple[str, ...] = (
14-
"MLP_Layer", # Linear/FC layers
15-
"DL_Layer", # Conv layers (Conv1d, Conv2d, Conv3d)
15+
"MLP_Layer", # Linear/FC layers (linear, addmm)
16+
"DL_Layer", # Conv layers (conv2d, max_pool2d, etc.)
1617
"Activation", # ReLU, GELU, SiLU, etc.
1718
"Normalization", # BatchNorm, LayerNorm, GroupNorm
1819
"Attention", # Self/Cross attention (scaled_dot_product_attention)
19-
"Output", # Final output node
20-
"Placeholder", # Input nodes (x, input_ids, etc.) - CRITICAL for graph connectivity
21-
"Model_Input", # Legacy input type (kept for compatibility)
20+
"Output", # Final output node (layer_type)
21+
"Placeholder", # Input nodes (layer_type) - CRITICAL for graph connectivity
22+
"Model_Input", # Legacy input type (layer_type)
2223
"NLP_Embedding", # Embedding layers (embedding, embedding_bag)
2324
)
2425

2526
# Default types to always force-include (for graph connectivity)
2627
DEFAULT_FORCE_INCLUDE_TYPES: tuple[str, ...] = ("Placeholder", "Model_Input", "Output")
2728

2829

30+
def _get_node_category(node_attrs: dict) -> str:
31+
"""Get the semantic category for a node by checking both layer_type and layer_name.
32+
33+
The graph stores:
34+
- layer_type: 'ATen_Operation', 'Placeholder', 'Output', 'Operation', etc.
35+
- layer_name: 'DL_Layer', 'MLP_Layer', 'Activation', 'Normalization', etc.
36+
37+
For filtering, we need to check layer_name first (for ATen ops), then layer_type.
38+
"""
39+
layer_name = node_attrs.get("layer_name", "")
40+
layer_type = node_attrs.get("layer_type", "Unknown")
41+
42+
# For ATen operations, layer_name contains the semantic category
43+
if layer_name in SEMANTIC_LAYER_TYPES:
44+
return layer_name
45+
46+
# For Placeholder, Output, etc., layer_type is the category
47+
if layer_type in ("Placeholder", "Output", "Model_Input"):
48+
return layer_type
49+
50+
# Return layer_type as fallback
51+
return layer_type
52+
53+
2954
def visualize_graph(graph, save_path="graph.png", *, show=True, dpi=600):
3055
"""📊 Visualize forward execution graph with dynamic scaling (shows inline + saves)"""
3156
num_nodes = len(graph.nodes)
@@ -112,16 +137,16 @@ def visualize_relevance(graph, all_wt, output_path="backtrace_graph",
112137
force_include = {
113138
node.replace("/", " ").replace(":", " ")
114139
for node in graph.nodes
115-
if graph.nodes[node].get("layer_type") in DEFAULT_FORCE_INCLUDE_TYPES
140+
if _get_node_category(graph.nodes[node]) in DEFAULT_FORCE_INCLUDE_TYPES
116141
}
117142

118143
if layer_types is not None:
119-
# Layer-type filtering mode
144+
# Layer-type filtering mode - use _get_node_category for proper semantic matching
120145
layer_types_set = set(layer_types)
121146
top_node_names = {
122147
node.replace("/", " ").replace(":", " ")
123148
for node in graph.nodes
124-
if graph.nodes[node].get("layer_type") in layer_types_set
149+
if _get_node_category(graph.nodes[node]) in layer_types_set
125150
} | force_include
126151
elif top_k:
127152
top_keys = sorted(flat_scores.items(), key=lambda x: abs(x[1]), reverse=True)[:top_k]
@@ -154,7 +179,7 @@ def find_filtered_ancestors(node_raw, visited=None):
154179
ancestors.update(find_filtered_ancestors(parent_raw, visited))
155180
return ancestors
156181

157-
# --- Color map for node types ---
182+
# --- Color map for node types (uses layer_name for ATen ops, layer_type for others) ---
158183
color_map = {
159184
"MLP_Layer": "lightblue",
160185
"DL_Layer": "lightgreen",
@@ -184,7 +209,9 @@ def find_filtered_ancestors(node_raw, visited=None):
184209
if name not in top_node_names:
185210
continue
186211
rel = relevance_data.get(name, (0.0, 0.0, 0.0))
187-
fill = color_map.get(graph.nodes[node].get("layer_type", "Unknown"), "white")
212+
# Use layer_name first (for semantic category), then layer_type as fallback
213+
node_category = _get_node_category(graph.nodes[node])
214+
fill = color_map.get(node_category, color_map.get(graph.nodes[node].get("layer_type", "Unknown"), "white"))
188215
g.node(
189216
name,
190217
label=f"{name}\nMean: {rel[0]:.3f}\nMax: {rel[1]:.3f}\nMin: {rel[2]:.3f}",
@@ -356,12 +383,12 @@ def _norm(s):
356383
# present nodes - keep all for transitive edge computation
357384
all_raw = list(graph.nodes.keys())
358385

359-
# Determine filtered set
386+
# Determine filtered set using _get_node_category for proper semantic matching
360387
if layer_types is not None:
361388
layer_types_set = set(layer_types) | set(DEFAULT_FORCE_INCLUDE_TYPES)
362389
present_raw = [
363390
raw for raw in all_raw
364-
if graph.nodes[raw].get("layer_type") in layer_types_set
391+
if _get_node_category(graph.nodes[raw]) in layer_types_set
365392
]
366393
else:
367394
present_raw = all_raw
@@ -479,8 +506,9 @@ def _short(s, n=48):
479506
for raw in present_raw:
480507
nk = norm_by_raw[raw]
481508
mean, mx, mn = rel_map.get(nk, (0.0, 0.0, 0.0))
482-
lt = graph.nodes[raw].get("layer_type", "Unknown")
483-
fill = color_map.get(lt, "white")
509+
# Use _get_node_category for proper semantic coloring
510+
node_category = _get_node_category(graph.nodes[raw])
511+
fill = color_map.get(node_category, color_map.get(graph.nodes[raw].get("layer_type", "Unknown"), "white"))
484512
collapsed = graph.nodes[raw].get("collapsed_count", 0)
485513
collapsed_line = f"\n[collapsed {collapsed}]" if collapsed else ""
486514

@@ -568,7 +596,7 @@ def visualize_relevance_auto(
568596
layer_types_set = set(layer_types) | set(DEFAULT_FORCE_INCLUDE_TYPES)
569597
filtered_count = sum(
570598
1 for n in graph.nodes
571-
if graph.nodes[n].get("layer_type") in layer_types_set
599+
if _get_node_category(graph.nodes[n]) in layer_types_set
572600
)
573601
num_nodes = filtered_count
574602
print(f"num_nodes after layer_types filter: {num_nodes} (from {len(graph.nodes)} total)")

0 commit comments

Comments
 (0)