Skip to content

Commit 75b8a65

Browse files
committed
add print statements for before/after node counts in filtering
Prints the node count before and after filtering (total nodes → filtered nodes) for each filtering mode:
- Layer-type filtering (visualize_relevance and visualize_relevance_fast)
- Top-k filtering
- Threshold filtering
- No filtering
1 parent f8f9cf6 commit 75b8a65

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def visualize_relevance(graph, all_wt, output_path="backtrace_graph",
133133

134134
# --- Filter based on top_k, threshold, or layer_types ---
135135
flat_scores = {k: v[0] for k, v in relevance_data.items()}
136+
total_nodes = len(graph.nodes)
136137

137138
force_include = {
138139
node.replace("/", " ").replace(":", " ")
@@ -148,13 +149,17 @@ def visualize_relevance(graph, all_wt, output_path="backtrace_graph",
148149
for node in graph.nodes
149150
if _get_node_category(graph.nodes[node]) in layer_types_set
150151
} | force_include
152+
print(f"📊 Layer-type filtering: {total_nodes} nodes → {len(top_node_names)} nodes (filter: {list(layer_types_set)[:5]}{'...' if len(layer_types_set) > 5 else ''})")
151153
elif top_k:
152154
top_keys = sorted(flat_scores.items(), key=lambda x: abs(x[1]), reverse=True)[:top_k]
153155
top_node_names = {k for k, _ in top_keys} | force_include
156+
print(f"📊 Top-k filtering: {total_nodes} nodes → {len(top_node_names)} nodes (top_k={top_k})")
154157
elif relevance_threshold is not None:
155158
top_node_names = {k for k, v in flat_scores.items() if abs(v) >= relevance_threshold} | force_include
159+
print(f"📊 Threshold filtering: {total_nodes} nodes → {len(top_node_names)} nodes (threshold={relevance_threshold})")
156160
else:
157161
top_node_names = set(relevance_data.keys()) | force_include
162+
print(f"📊 No filtering: {total_nodes} nodes")
158163

159164
# --- Build raw->normalized name mapping for ancestor lookup ---
160165
raw_to_norm = {node: node.replace("/", " ").replace(":", " ") for node in graph.nodes}
@@ -382,6 +387,7 @@ def _norm(s):
382387

383388
# present nodes - keep all for transitive edge computation
384389
all_raw = list(graph.nodes.keys())
390+
total_nodes = len(all_raw)
385391

386392
# Determine filtered set using _get_node_category for proper semantic matching
387393
if layer_types is not None:
@@ -390,8 +396,10 @@ def _norm(s):
390396
raw for raw in all_raw
391397
if _get_node_category(graph.nodes[raw]) in layer_types_set
392398
]
399+
print(f"📊 Layer-type filtering (fast): {total_nodes} nodes → {len(present_raw)} nodes (filter: {list(layer_types)[:5]}{'...' if len(layer_types) > 5 else ''})")
393400
else:
394401
present_raw = all_raw
402+
print(f"📊 No filtering (fast): {total_nodes} nodes")
395403

396404
norm_by_raw = {raw: _norm(raw) for raw in all_raw} # All nodes for lookup
397405
present_norm = {_norm(raw) for raw in present_raw} # Filtered set

0 commit comments

Comments
 (0)