diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/__init__.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/__init__.py index 94a7edb..ebf39be 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/__init__.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/__init__.py @@ -3,11 +3,12 @@ from .utils import * from .aten_operations import * from .core.relevance_saver import save_relevance, load_relevance, Precision +from .core.visualization import SEMANTIC_LAYER_TYPES # Export pipeline modules (optional, can be imported separately) try: from .pipeline import DLBacktracePipeline, ModelRegistry, get_model_info, PipelineConfig - __all__ = ['DLBacktrace', 'activation_master', 'DLBacktracePipeline', 'ModelRegistry', 'get_model_info', 'PipelineConfig'] + __all__ = ['DLBacktrace', 'activation_master', 'DLBacktracePipeline', 'ModelRegistry', 'get_model_info', 'PipelineConfig', 'SEMANTIC_LAYER_TYPES'] except ImportError: - __all__ = ['DLBacktrace', 'activation_master'] + __all__ = ['DLBacktrace', 'activation_master', 'SEMANTIC_LAYER_TYPES'] \ No newline at end of file diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py index e90ff02..cb6b953 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py @@ -7,6 +7,58 @@ from networkx.drawing.nx_pydot import graphviz_layout from collections import defaultdict from IPython.display import display, SVG, Image as IPyImage +from typing import Optional, Sequence + +# Semantically meaningful layer categories for compact visualization +# These match the ATEN_LAYER_MAP categories in graph_builder.py +SEMANTIC_LAYER_TYPES: tuple[str, ...] = ( + "MLP_Layer", # Linear/FC layers (linear, addmm) + "DL_Layer", # Conv layers (conv2d, max_pool2d, etc.) + "Activation", # ReLU, GELU, SiLU, etc. + "Normalization", # BatchNorm, LayerNorm, GroupNorm + "Attention", # Self/Cross attention (scaled_dot_product_attention) + "Output", # Final output node (layer_type) + "Placeholder", # Input nodes (layer_type) - CRITICAL for graph connectivity + "Model_Input", # Legacy input type (layer_type) + "NLP_Embedding", # Embedding layers (embedding, embedding_bag) +) + +# Default types to always force-include (for graph connectivity) +# Note: Placeholder excluded to keep compact graphs clean +DEFAULT_FORCE_INCLUDE_TYPES: tuple[str, ...] = ("Model_Input", "Output") + +# Prefixes for parameter/bias nodes to exclude from compact graphs +EXCLUDED_NODE_PREFIXES: tuple[str, ...] = ("p_model_", "b_model_") + + +def _is_excluded_node(node_name: str) -> bool: + """Check if node should be excluded (parameter/bias weights).""" + return any(node_name.startswith(prefix) for prefix in EXCLUDED_NODE_PREFIXES) + + +def _get_node_category(node_attrs: dict) -> str: + """Get the semantic category for a node by checking both layer_type and layer_name. + + The graph stores: + - layer_type: 'ATen_Operation', 'Placeholder', 'Output', 'Operation', etc. + - layer_name: 'DL_Layer', 'MLP_Layer', 'Activation', 'Normalization', etc. + + For filtering, we need to check layer_name first (for ATen ops), then layer_type. + """ + layer_name = node_attrs.get("layer_name", "") + layer_type = node_attrs.get("layer_type", "Unknown") + + # For ATen operations, layer_name contains the semantic category + if layer_name in SEMANTIC_LAYER_TYPES: + return layer_name + + # For Output, Model_Input, etc., layer_type is the category + # Note: Placeholder excluded to keep compact graphs clean + if layer_type in ("Output", "Model_Input"): + return layer_type + + # Return layer_type as fallback + return layer_type def visualize_graph(graph, save_path="graph.png", *, show=True, dpi=600): @@ -47,18 +99,54 @@ def visualize_graph(graph, save_path="graph.png", *, show=True, dpi=600): def visualize_relevance(graph, all_wt, output_path="backtrace_graph", *, top_k=None, relevance_threshold=None, + layer_types: Optional[Sequence[str]] = None, + rankdir: str = "LR", show=True, inline_format="svg"): - """šŸŽÆ Visualize relevance backtrace using Graphviz (shows inline + saves)""" + """šŸŽÆ Visualize relevance backtrace using Graphviz (shows inline + saves) + + Parameters + ---------- + graph : networkx.DiGraph + The computation graph with layer_type attributes on nodes + all_wt : dict + Relevance weights for each node + output_path : str + Output file path (without extension) + top_k : int, optional + Show only top-k nodes by relevance + relevance_threshold : float, optional + Show nodes with |relevance| >= threshold + layer_types : list[str], optional + Filter to only these layer types. If None, shows all nodes. + Use SEMANTIC_LAYER_TYPES for a compact paper-ready graph. + rankdir : str + Graph direction: "LR" (left-to-right, wide), "TB" (top-to-bottom, tall), + "RL" (right-to-left), "BT" (bottom-to-top). Default "LR". + Use "TB" for LaTeX/paper-friendly vertical layout. + show : bool + Whether to display inline in Jupyter/Colab + inline_format : str + Format for inline display ("svg" or "png") + """ relevance_data = {} # --- Extract relevance stats from all_wt --- + # Mean: sum of all entries (for batch=1) or average of sums across batches + # Max/Min: max/min of batch sums (for batched) or max/min element (for single) for node_name, rel in all_wt.items(): node_key = node_name.replace("/", " ").replace(":", " ") if isinstance(rel, (list, tuple)): - flat = [float(r.sum()) for r in rel if hasattr(r, "sum")] - stats = (float(sum(flat) / len(flat)), max(flat), min(flat)) if flat else (0.0, 0.0, 0.0) + # Batched data: list of tensors + batch_sums = [float(r.sum()) for r in rel if hasattr(r, "sum")] + if batch_sums: + mean_val = sum(batch_sums) / len(batch_sums) + stats = (mean_val, max(batch_sums), min(batch_sums)) + else: + stats = (0.0, 0.0, 0.0) elif hasattr(rel, "sum"): - stats = (float(rel.mean()), float(rel.max()), float(rel.min())) + # Single tensor (batch size = 1) + # Mean = sum of all entries in relevance vector + stats = (float(rel.sum()), float(rel.max()), float(rel.min())) else: try: val = float(rel) @@ -67,24 +155,62 @@ def visualize_relevance(graph, all_wt, output_path="backtrace_graph", stats = (0.0, 0.0, 0.0) relevance_data[node_key] = stats - # --- Filter based on top_k or threshold --- + # --- Filter based on top_k, threshold, or layer_types --- flat_scores = {k: v[0] for k, v in relevance_data.items()} + total_nodes = len(graph.nodes) force_include = { node.replace("/", " ").replace(":", " ") for node in graph.nodes - if graph.nodes[node].get("layer_type") in ("Placeholder", "Model_Input") + if _get_node_category(graph.nodes[node]) in DEFAULT_FORCE_INCLUDE_TYPES + and not _is_excluded_node(node) } - if top_k: + if layer_types is not None: + # Layer-type filtering mode - use _get_node_category for proper semantic matching + layer_types_set = set(layer_types) + top_node_names = { + node.replace("/", " ").replace(":", " ") + for node in graph.nodes + if _get_node_category(graph.nodes[node]) in layer_types_set + and not _is_excluded_node(node) + } | force_include + print(f"šŸ“Š Layer-type filtering: {total_nodes} nodes → {len(top_node_names)} nodes (filter: {list(layer_types_set)[:5]}{'...' if len(layer_types_set) > 5 else ''})") + elif top_k: top_keys = sorted(flat_scores.items(), key=lambda x: abs(x[1]), reverse=True)[:top_k] - top_node_names = {k for k, _ in top_keys} | force_include + top_node_names = {k for k, _ in top_keys if not _is_excluded_node(k)} | force_include + print(f"šŸ“Š Top-k filtering: {total_nodes} nodes → {len(top_node_names)} nodes (top_k={top_k})") elif relevance_threshold is not None: - top_node_names = {k for k, v in flat_scores.items() if abs(v) >= relevance_threshold} | force_include + top_node_names = {k for k, v in flat_scores.items() if abs(v) >= relevance_threshold and not _is_excluded_node(k)} | force_include + print(f"šŸ“Š Threshold filtering: {total_nodes} nodes → {len(top_node_names)} nodes (threshold={relevance_threshold})") else: - top_node_names = set(relevance_data.keys()) | force_include - - # --- Color map for node types --- + top_node_names = {k for k in relevance_data.keys() if not _is_excluded_node(k)} | force_include + print(f"šŸ“Š No filtering: {total_nodes} nodes") + + # --- Build raw->normalized name mapping for ancestor lookup --- + raw_to_norm = {node: node.replace("/", " ").replace(":", " ") for node in graph.nodes} + norm_to_raw = {v: k for k, v in raw_to_norm.items()} + + # --- Helper to find transitive ancestors in filtered set --- + def find_filtered_ancestors(node_raw, visited=None): + """BFS to find all ancestors that are in the filtered set.""" + if visited is None: + visited = set() + ancestors = set() + parents = graph.nodes[node_raw].get("parents", []) + for parent_raw in parents: + if parent_raw in visited: + continue + visited.add(parent_raw) + parent_norm = raw_to_norm.get(parent_raw, parent_raw.replace("/", " ").replace(":", " ")) + if parent_norm in top_node_names: + ancestors.add(parent_norm) + elif parent_raw in graph.nodes: + # Recursively search this parent's ancestors + ancestors.update(find_filtered_ancestors(parent_raw, visited)) + return ancestors + + # --- Color map for node types (uses layer_name for ATen ops, layer_type for others) --- color_map = { "MLP_Layer": "lightblue", "DL_Layer": "lightgreen", @@ -104,7 +230,7 @@ def visualize_relevance(graph, all_wt, output_path="backtrace_graph", g = graphviz.Digraph( "DLBacktrace", format="svg", - graph_attr={"rankdir": "LR", "splines": "spline"}, + graph_attr={"rankdir": rankdir, "splines": "spline"}, node_attr={"fontname": "Helvetica", "fontsize": "10"} ) @@ -114,7 +240,9 @@ def visualize_relevance(graph, all_wt, output_path="backtrace_graph", if name not in top_node_names: continue rel = relevance_data.get(name, (0.0, 0.0, 0.0)) - fill = color_map.get(graph.nodes[node].get("layer_type", "Unknown"), "white") + # Use layer_name first (for semantic category), then layer_type as fallback + node_category = _get_node_category(graph.nodes[node]) + fill = color_map.get(node_category, color_map.get(graph.nodes[node].get("layer_type", "Unknown"), "white")) g.node( name, label=f"{name}\nMean: {rel[0]:.3f}\nMax: {rel[1]:.3f}\nMin: {rel[2]:.3f}", @@ -122,15 +250,30 @@ def visualize_relevance(graph, all_wt, output_path="backtrace_graph", fillcolor=fill, ) - # --- Add edges --- + # --- Add edges (with transitive connections when layer_types filtering) --- + added_edges = set() for node in graph.nodes: name = node.replace("/", " ").replace(":", " ") if name not in top_node_names: continue - for parent in graph.nodes[node].get("parents", []): - parent_fmt = parent.replace("/", " ").replace(":", " ") - if parent_fmt in top_node_names: - g.edge(parent_fmt, name) + + if layer_types is not None: + # Use transitive ancestor search for filtered graphs + ancestors = find_filtered_ancestors(node) + for ancestor in ancestors: + edge = (ancestor, name) + if edge not in added_edges: + added_edges.add(edge) + g.edge(ancestor, name) + else: + # Original direct parent logic + for parent in graph.nodes[node].get("parents", []): + parent_fmt = parent.replace("/", " ").replace(":", " ") + if parent_fmt in top_node_names: + edge = (parent_fmt, name) + if edge not in added_edges: + added_edges.add(edge) + g.edge(parent_fmt, name) out = g.render(output_path, format="svg", cleanup=True) @@ -254,32 +397,83 @@ def visualize_relevance_fast( max_parents_per_node=None, engine_auto_threshold=1200, disable_concentrate_for_sfdp=True, + layer_types: Optional[Sequence[str]] = None, + rankdir: str = "LR", show=True, inline_format="svg", ): + """Fast visualization for large/collapsed graphs. + + Parameters + ---------- + layer_types : list[str], optional + Filter to only these layer types. If None, shows all nodes. + rankdir : str + Graph direction: "LR" (left-to-right), "TB" (top-to-bottom). + Default "LR". Use "TB" for LaTeX-friendly vertical layout. + """ def _norm(s): return s.replace("/", " ").replace(":", " ") - # present nodes - present_raw = list(graph.nodes.keys()) - norm_by_raw = {raw: _norm(raw) for raw in present_raw} - present_norm = set(norm_by_raw.values()) + # present nodes - keep all for transitive edge computation + all_raw = list(graph.nodes.keys()) + total_nodes = len(all_raw) + + # Determine filtered set using _get_node_category for proper semantic matching + if layer_types is not None: + layer_types_set = set(layer_types) | set(DEFAULT_FORCE_INCLUDE_TYPES) + present_raw = [ + raw for raw in all_raw + if _get_node_category(graph.nodes[raw]) in layer_types_set + and not _is_excluded_node(raw) + ] + print(f"šŸ“Š Layer-type filtering (fast): {total_nodes} nodes → {len(present_raw)} nodes (filter: {list(layer_types)[:5]}{'...' if len(layer_types) > 5 else ''})") + else: + present_raw = [raw for raw in all_raw if not _is_excluded_node(raw)] + print(f"šŸ“Š No filtering (fast): {total_nodes} nodes → {len(present_raw)} nodes (excluded p_model_*/b_model_*)") + + norm_by_raw = {raw: _norm(raw) for raw in all_raw} # All nodes for lookup + present_norm = {_norm(raw) for raw in present_raw} # Filtered set + + # Helper to find transitive ancestors in filtered set + def find_filtered_ancestors(node_raw, visited=None): + """BFS to find all ancestors that are in the filtered set.""" + if visited is None: + visited = set() + ancestors = set() + parents = graph.nodes[node_raw].get("parents", []) or [] + for parent_raw in parents: + if parent_raw in visited: + continue + visited.add(parent_raw) + parent_norm = norm_by_raw.get(parent_raw, _norm(parent_raw)) + if parent_norm in present_norm: + ancestors.add(parent_norm) + elif parent_raw in graph.nodes: + # Recursively search this parent's ancestors + ancestors.update(find_filtered_ancestors(parent_raw, visited)) + return ancestors # relevance only for present + # Mean: sum of all entries (for batch=1) or average of sums across batches + # Max: maximum individual value in the relevance tensor + # Min: minimum individual value in the relevance tensor rel_map = {} for k, v in all_wt.items(): nk = _norm(k) if nk not in present_norm: continue if isinstance(v, (list, tuple)): - flat = [float(t.sum()) for t in v if hasattr(t, "sum")] - if flat: - mean = float(sum(flat) / len(flat)) - rel_map[nk] = (mean, max(flat), min(flat)) + # Batched data + batch_sums = [float(t.sum()) for t in v if hasattr(t, "sum")] + if batch_sums: + mean_val = sum(batch_sums) / len(batch_sums) + rel_map[nk] = (mean_val, max(batch_sums), min(batch_sums)) else: rel_map[nk] = (0.0, 0.0, 0.0) elif hasattr(v, "sum"): - rel_map[nk] = (float(v.mean()), float(v.max()), float(v.min())) + # Single tensor (batch size = 1): Mean = sum of all entries + rel_map[nk] = (float(v.sum()), float(v.max()), float(v.min())) else: try: x = float(v) @@ -317,7 +511,7 @@ def _norm(s): "outputorder": "edgesfirst", } if engine == "dot": - graph_attr["rankdir"] = "LR" + graph_attr["rankdir"] = rankdir graph_attr["splines"] = "spline" graph_attr["concentrate"] = "true" else: @@ -356,8 +550,9 @@ def _short(s, n=48): for raw in present_raw: nk = norm_by_raw[raw] mean, mx, mn = rel_map.get(nk, (0.0, 0.0, 0.0)) - lt = graph.nodes[raw].get("layer_type", "Unknown") - fill = color_map.get(lt, "white") + # Use _get_node_category for proper semantic coloring + node_category = _get_node_category(graph.nodes[raw]) + fill = color_map.get(node_category, color_map.get(graph.nodes[raw].get("layer_type", "Unknown"), "white")) collapsed = graph.nodes[raw].get("collapsed_count", 0) collapsed_line = f"\n[collapsed {collapsed}]" if collapsed else "" @@ -374,25 +569,35 @@ def _short(s, n=48): fillcolor=fill, ) - # edges + # edges - use transitive ancestors when layer_types filtering added = set() for raw in present_raw: child = norm_by_raw[raw] - parents = graph.nodes[raw].get("parents", []) or [] - if max_parents_per_node is not None and len(parents) > max_parents_per_node: - parents = sorted( - parents, - key=lambda p: abs(rel_map.get(_norm(p), (0.0, 0.0, 0.0))[0]), - reverse=True, - )[:max_parents_per_node] - - for p_raw in parents: - pn = norm_by_raw.get(p_raw, _norm(p_raw)) - e = (pn, child) - if e in added: - continue - added.add(e) - g.edge(pn, child) + + if layer_types is not None: + # Use transitive ancestor search for filtered graphs + ancestors = find_filtered_ancestors(raw) + for ancestor in ancestors: + e = (ancestor, child) + if e not in added: + added.add(e) + g.edge(ancestor, child) + else: + # Original direct parent logic + parents = graph.nodes[raw].get("parents", []) or [] + if max_parents_per_node is not None and len(parents) > max_parents_per_node: + parents = sorted( + parents, + key=lambda p: abs(rel_map.get(_norm(p), (0.0, 0.0, 0.0))[0]), + reverse=True, + )[:max_parents_per_node] + + for p_raw in parents: + pn = norm_by_raw.get(p_raw, _norm(p_raw)) + e = (pn, child) + if e not in added: + added.add(e) + g.edge(pn, child) out = g.render(output_path, cleanup=True) @@ -417,12 +622,36 @@ def visualize_relevance_auto( node_threshold=500, engine_auto_threshold=1500, fast_output_path="backtrace_collapsed_fast", + layer_types: Optional[Sequence[str]] = None, + rankdir: str = "LR", show=True, inline_format="svg", ): - """Auto-choose pretty vs fast; always show inline and save.""" - num_nodes = len(graph.nodes) - print(f"num_nodes: {num_nodes}") + """Auto-choose pretty vs fast visualization; always show inline and save. + + Parameters + ---------- + layer_types : list[str], optional + Filter to only these layer types. If specified, layer_types filtering + takes precedence over automatic collapsing for large graphs. + Use SEMANTIC_LAYER_TYPES for a compact paper-ready graph. + rankdir : str + Graph direction: "LR" (left-to-right, wide), "TB" (top-to-bottom, tall). + Default "LR". Use "TB" for LaTeX/paper-friendly vertical layout. + """ + # If layer_types specified, count only matching nodes for threshold decision + if layer_types is not None: + layer_types_set = set(layer_types) | set(DEFAULT_FORCE_INCLUDE_TYPES) + filtered_count = sum( + 1 for n in graph.nodes + if _get_node_category(graph.nodes[n]) in layer_types_set + and not _is_excluded_node(n) + ) + num_nodes = filtered_count + print(f"num_nodes after layer_types filter: {num_nodes} (from {len(graph.nodes)} total)") + else: + num_nodes = sum(1 for n in graph.nodes if not _is_excluded_node(n)) + print(f"num_nodes: {num_nodes} (from {len(graph.nodes)} total, excluded p_model_*/b_model_*)") if num_nodes < node_threshold: # small graph → original pretty version @@ -430,6 +659,8 @@ def visualize_relevance_auto( graph, all_wt, output_path=output_path, + layer_types=layer_types, + rankdir=rankdir, show=show, inline_format=inline_format, ) @@ -448,6 +679,404 @@ def visualize_relevance_auto( collapsed_map=collapsed_map, max_parents_per_node=2, engine_auto_threshold=engine_auto_threshold, + layer_types=layer_types, + rankdir=rankdir, show=show, inline_format=inline_format, ) + + +def visualize_relevance_paginated( + graph, + all_wt, + output_path="backtrace_graph", + *, + max_nodes_per_page: int = 30, + layer_types: Optional[Sequence[str]] = None, + rankdir: str = "TB", + pages_per_row: int = 1, + show=True, + inline_format="svg", +): + """šŸŽÆ Visualize relevance backtrace as multiple sub-graphs (pages) for long DAGs. + + Splits the graph into topologically-ordered pages, each containing at most + `max_nodes_per_page` nodes. Each page is saved as a separate file and + displayed inline sequentially. + + Parameters + ---------- + graph : networkx.DiGraph + The computation graph with layer_type attributes on nodes + all_wt : dict + Relevance weights for each node + output_path : str + Base output file path (without extension). Pages will be named + {output_path}_page1.svg, {output_path}_page2.svg, etc. + When pages_per_row > 1, also creates {output_path}_combined.svg + max_nodes_per_page : int + Maximum number of nodes per page/sub-graph (default: 30) + layer_types : list[str], optional + Filter to only these layer types. Use SEMANTIC_LAYER_TYPES for compact graphs. + rankdir : str + Graph direction: "TB" (top-to-bottom, default for paginated), + "LR" (left-to-right). Default "TB" for vertical flow. + pages_per_row : int + Number of pages to display side-by-side in a combined view (default: 1). + When > 1, creates an additional combined SVG with pages arranged + left-to-right. Useful for fitting multiple pages on one LaTeX page. + show : bool + Whether to display inline in Jupyter/Colab + inline_format : str + Format for inline display ("svg" or "png") + + Returns + ------- + list[tuple[graphviz.Digraph, str]] + List of (graph, output_path) tuples for each page + """ + # --- Extract relevance stats --- + def _norm(s): + return s.replace("/", " ").replace(":", " ") + + relevance_data = {} + for node_name, rel in all_wt.items(): + node_key = _norm(node_name) + if isinstance(rel, (list, tuple)): + batch_sums = [float(r.sum()) for r in rel if hasattr(r, "sum")] + if batch_sums: + mean_val = sum(batch_sums) / len(batch_sums) + relevance_data[node_key] = (mean_val, max(batch_sums), min(batch_sums)) + else: + relevance_data[node_key] = (0.0, 0.0, 0.0) + elif hasattr(rel, "sum"): + relevance_data[node_key] = (float(rel.sum()), float(rel.max()), float(rel.min())) + else: + try: + val = float(rel) + relevance_data[node_key] = (val, val, val) + except Exception: + relevance_data[node_key] = (0.0, 0.0, 0.0) + + # --- Filter nodes --- + total_nodes = len(graph.nodes) + + if layer_types is not None: + layer_types_set = set(layer_types) | set(DEFAULT_FORCE_INCLUDE_TYPES) + filtered_nodes = [ + node for node in graph.nodes + if _get_node_category(graph.nodes[node]) in layer_types_set + and not _is_excluded_node(node) + ] + else: + filtered_nodes = [node for node in graph.nodes if not _is_excluded_node(node)] + + filtered_norm = {_norm(n) for n in filtered_nodes} + print(f"šŸ“Š Paginated: {total_nodes} nodes → {len(filtered_nodes)} nodes") + + # --- Build parent mapping for filtered nodes --- + raw_to_norm = {node: _norm(node) for node in graph.nodes} + + def find_filtered_ancestors(node_raw, visited=None): + """BFS to find ancestors in filtered set (for transitive edges).""" + if visited is None: + visited = set() + ancestors = set() + parents = graph.nodes[node_raw].get("parents", []) or [] + for parent_raw in parents: + if parent_raw in visited: + continue + visited.add(parent_raw) + parent_norm = raw_to_norm.get(parent_raw, _norm(parent_raw)) + if parent_norm in filtered_norm: + ancestors.add(parent_norm) + elif parent_raw in graph.nodes: + ancestors.update(find_filtered_ancestors(parent_raw, visited)) + return ancestors + + # --- Topological sort of filtered nodes --- + # Build edges for filtered subgraph + filtered_edges = {} # node_norm -> set of parent_norms (in filtered set) + for node in filtered_nodes: + node_norm = _norm(node) + if layer_types is not None: + filtered_edges[node_norm] = find_filtered_ancestors(node) + else: + parents = graph.nodes[node].get("parents", []) or [] + filtered_edges[node_norm] = {_norm(p) for p in parents if _norm(p) in filtered_norm} + + # Kahn's algorithm for topological sort + in_degree = {n: 0 for n in filtered_norm} + for node, parents in filtered_edges.items(): + for p in parents: + if p in in_degree: + in_degree[node] = in_degree.get(node, 0) + 1 + + # Start with nodes that have no filtered parents (in_degree == 0) + queue = [n for n, d in in_degree.items() if d == 0] + topo_order = [] + + while queue: + node = queue.pop(0) + topo_order.append(node) + # Find children of this node + for child, parents in filtered_edges.items(): + if node in parents: + in_degree[child] -= 1 + if in_degree[child] == 0 and child not in topo_order: + queue.append(child) + + # Add any remaining nodes (handles cycles gracefully) + for n in filtered_norm: + if n not in topo_order: + topo_order.append(n) + + print(f"šŸ“Š Topological order: {len(topo_order)} nodes") + + # --- Split into pages --- + pages = [] + for i in range(0, len(topo_order), max_nodes_per_page): + pages.append(topo_order[i:i + max_nodes_per_page]) + + print(f"šŸ“Š Split into {len(pages)} pages (max {max_nodes_per_page} nodes/page)") + + # --- Color scale setup --- + all_means = [relevance_data.get(n, (0.0, 0.0, 0.0))[0] for n in topo_order] + max_abs = max(abs(v) for v in all_means) if all_means else 1.0 + if max_abs == 0: + max_abs = 1.0 + + def get_color(mean_val): + norm = mean_val / max_abs + if norm >= 0: + r = int(255 * (1 - norm)) + return f"#{r:02x}ff{r:02x}" # green + else: + g = int(255 * (1 + norm)) + return f"#ff{g:02x}{g:02x}" # red + + # --- Render each page --- + results = [] + for page_idx, page_nodes in enumerate(pages): + page_num = page_idx + 1 + page_node_set = set(page_nodes) + + # Include connector nodes from previous page for context + connector_nodes = set() + if page_idx > 0: + prev_page_nodes = set(pages[page_idx - 1]) + for node in page_nodes: + for parent in filtered_edges.get(node, set()): + if parent in prev_page_nodes: + connector_nodes.add(parent) + + g = graphviz.Digraph( + name=f"DLBacktrace_Page{page_num}", + format="svg", + graph_attr={ + "rankdir": rankdir, + "label": f"DLBacktrace Graph - Page {page_num}/{len(pages)}", + "labelloc": "t", + "fontsize": "14", + "nodesep": "0.3", + "ranksep": "0.5", + }, + node_attr={ + "shape": "box", + "style": "filled,rounded", + "fontsize": "10", + }, + edge_attr={ + "fontsize": "8", + }, + ) + + # Add connector nodes (grayed out, from previous page) + for node in connector_nodes: + stats = relevance_data.get(node, (0.0, 0.0, 0.0)) + label = f"{node}\n(from prev page)" + g.node(node, label=label, fillcolor="lightgray", style="filled,rounded,dashed") + + # Add page nodes + for node in page_nodes: + stats = relevance_data.get(node, (0.0, 0.0, 0.0)) + mean, mx, mn = stats + label = f"{node}\nMean={mean:.4f}\nMax={mx:.4f} Min={mn:.4f}" + color = get_color(mean) + g.node(node, label=label, fillcolor=color) + + # Add edges within page and from connectors + added_edges = set() + for node in page_nodes: + for parent in filtered_edges.get(node, set()): + if parent in page_node_set or parent in connector_nodes: + edge = (parent, node) + if edge not in added_edges: + added_edges.add(edge) + # Get edge weight (child's mean relevance) + child_mean = relevance_data.get(node, (0.0, 0.0, 0.0))[0] + g.edge(parent, node, label=f"{child_mean:.3f}") + + # Render + page_output = f"{output_path}_page{page_num}" + out = g.render(page_output, cleanup=True) + results.append((g, out)) + + # Show inline + if show and pages_per_row == 1: + print(f"\n{'='*50}") + print(f"šŸ“„ Page {page_num}/{len(pages)} ({len(page_nodes)} nodes)") + print(f"{'='*50}") + if inline_format.lower() == "svg": + svg_bytes = g.pipe(format="svg") + display(SVG(svg_bytes)) + else: + png_bytes = g.pipe(format="png") + display(IPyImage(data=png_bytes)) + + print(f"āœ… Page {page_num} saved → {out}") + + # --- Create combined grid/matrix views when pages_per_row > 1 --- + # Layout: pages arranged in columns, flowing top-to-bottom within each column + # Example: 3 pages with pages_per_row=3 creates a Nx3 matrix + if pages_per_row > 1 and len(pages) > 1: + num_columns = min(pages_per_row, len(pages)) + + # Organize pages into columns + # Column 0: pages[0], pages[num_columns], pages[2*num_columns], ... + # Column 1: pages[1], pages[num_columns+1], ... + columns = [[] for _ in range(num_columns)] + for page_idx, page_nodes in enumerate(pages): + col_idx = page_idx % num_columns + columns[col_idx].append((page_idx, page_nodes)) + + num_rows = max(len(col) for col in columns) + + print(f"\nšŸ“Š Creating grid layout: {num_rows} rows Ɨ {num_columns} columns...") + + # Create a master graph + combined = graphviz.Digraph( + name="DLBacktrace_Combined_Grid", + format="svg", + engine="dot", + graph_attr={ + "rankdir": "TB", # Top-to-bottom for rows + "label": f"DLBacktrace Graph - Grid View ({len(pages)} pages in {num_rows}Ɨ{num_columns} layout)", + "labelloc": "t", + "fontsize": "14", + "splines": "ortho", # Orthogonal edges for cleaner look + "nodesep": "0.4", + "ranksep": "0.6", + "newrank": "true", + }, + ) + + # Create nodes for each page in its column cluster + for col_idx, col_pages in enumerate(columns): + with combined.subgraph(name=f"cluster_col{col_idx}") as col_subg: + col_subg.attr( + label=f"Column {col_idx + 1}", + style="rounded,dashed", + color="gray", + fontsize="10", + margin="15", + ) + + for row_idx, (global_page_idx, page_nodes) in enumerate(col_pages): + page_num = global_page_idx + 1 + page_node_set = set(page_nodes) + + # Determine connector nodes from previous page + connector_nodes = set() + if global_page_idx > 0: + prev_page_nodes = set(pages[global_page_idx - 1]) + for node in page_nodes: + for parent in filtered_edges.get(node, set()): + if parent in prev_page_nodes: + connector_nodes.add(parent) + + # Create a sub-cluster for this page + with col_subg.subgraph(name=f"cluster_page{page_num}") as page_subg: + page_subg.attr( + label=f"Page {page_num}", + style="rounded,filled", + color="black", + fillcolor="white", + fontsize="9", + margin="8", + ) + + prefix = f"p{page_num}_" + + # Add connector nodes (grayed out) + for node in connector_nodes: + stats = relevance_data.get(node, (0.0, 0.0, 0.0)) + short_node = node[:20] + "..." if len(node) > 20 else node + label = f"{short_node}\n(prev)" + page_subg.node(prefix + node, label=label, fillcolor="lightgray", + style="filled,rounded,dashed", shape="box", fontsize="7") + + # Add page nodes + for node in page_nodes: + stats = relevance_data.get(node, (0.0, 0.0, 0.0)) + mean, mx, mn = stats + short_node = node[:20] + "..." if len(node) > 20 else node + label = f"{short_node}\nM={mean:.3f}\n↑{mx:.3f}↓{mn:.3f}" + color = get_color(mean) + page_subg.node(prefix + node, label=label, fillcolor=color, + style="filled,rounded", shape="box", fontsize="7") + + # Add edges within page + for node in page_nodes: + for parent in filtered_edges.get(node, set()): + if parent in page_node_set or parent in connector_nodes: + child_mean = relevance_data.get(node, (0.0, 0.0, 0.0))[0] + page_subg.edge(prefix + parent, prefix + node, + label=f"{child_mean:.2f}", fontsize="6") + + # Add invisible edge to next page in this column + if row_idx < len(col_pages) - 1: + next_page_idx, next_page_nodes = col_pages[row_idx + 1] + if page_nodes and next_page_nodes: + src_prefix = f"p{page_num}_" + dst_prefix = f"p{next_page_idx + 1}_" + combined.edge(src_prefix + page_nodes[-1], dst_prefix + next_page_nodes[0], + style="invis", constraint="true") + + # Add invisible edges between columns to align them horizontally + # Connect first node of each column's first page + for col_idx in range(num_columns - 1): + if columns[col_idx] and columns[col_idx + 1]: + _, left_nodes = columns[col_idx][0] + _, right_nodes = columns[col_idx + 1][0] + if left_nodes and right_nodes: + left_page_idx = columns[col_idx][0][0] + right_page_idx = columns[col_idx + 1][0][0] + src = f"p{left_page_idx + 1}_{left_nodes[0]}" + dst = f"p{right_page_idx + 1}_{right_nodes[0]}" + # Use rank=same to align horizontally + with combined.subgraph() as s: + s.attr(rank="same") + s.node(src) + s.node(dst) + + # Render combined grid + combined_output = f"{output_path}_combined" + combined_out = combined.render(combined_output, cleanup=True) + results.append((combined, combined_out)) + + if show: + print(f"\n{'='*60}") + print(f"šŸ“„ Combined Grid View ({num_rows} rows Ɨ {num_columns} columns)") + print(f"{'='*60}") + if inline_format.lower() == "svg": + svg_bytes = combined.pipe(format="svg") + display(SVG(svg_bytes)) + else: + png_bytes = combined.pipe(format="png") + display(IPyImage(data=png_bytes)) + + print(f"āœ… Combined grid saved → {combined_out}") + + print(f"\nšŸ“Š Total: {len(pages)} pages saved with base path '{output_path}_pageN.svg'") + return results diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py index 1f26ffd..8c4a14a 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py @@ -11,7 +11,13 @@ from .core.dlb_auto_sampler import DLBAutoSampler from .core.relevance_propagation import RelevancePropagator from .core.compiled_propagation import PropagationSchedule -from .core.visualization import visualize_graph, visualize_relevance, visualize_relevance_auto +from .core.visualization import ( + visualize_graph, + visualize_relevance, + visualize_relevance_auto, + visualize_relevance_paginated, + SEMANTIC_LAYER_TYPES, +) from .core.token_relevance_visuals import ( plot_tokenwise_relevance_map_swapped, plot_input_heatmap_for_token, @@ -19,8 +25,11 @@ from .core.visualization_module_aware import visualize_relevance_with_module_labels from .core.relevance_saver import save_relevance as _save_relevance, Precision +import gc +import copy import numpy as np import torch +import torch.nn.functional as F import inspect class DLBacktrace: @@ -535,6 +544,350 @@ def evaluation(self, mode="default", start_wt=[], multiplier=100.0, scaler=1.0, ) return self.all_wt + # ------------------------------------------------------------------ # + # Two-step API: forward_pass() + relevance_pass() # + # ------------------------------------------------------------------ # + + @staticmethod + def _snapshot_node_io(node_io, move_to_cpu=True): + """ + Create a memory-efficient deep copy of node_io for one generation step. + + Only the large tensor fields (input_values, output_values) are cloned + and optionally moved to CPU. All other metadata (layer_name, func_name, + input_sources, output_children, layer_hyperparams, …) is shallow-copied + because it is immutable across generation steps. + + Args: + node_io (dict): The node I/O dict produced by predict(). + move_to_cpu (bool): Move cloned tensors to CPU to free GPU memory. + + Returns: + dict: A snapshot suitable for later relevance computation. + """ + def _clone_val(v): + if isinstance(v, torch.Tensor): + t = v.detach().clone() + return t.cpu() if move_to_cpu else t + if isinstance(v, (list, tuple)): + return type(v)(_clone_val(x) for x in v) + return v # scalars, None, etc. + + snapshot = {} + for name, info in node_io.items(): + entry = dict(info) # shallow copy of metadata + entry["input_values"] = _clone_val(info.get("input_values")) + entry["output_values"] = _clone_val(info.get("output_values")) + # layer_hyperparams may contain weight tensors; clone them too + hp = info.get("layer_hyperparams") + if isinstance(hp, dict): + hp_copy = {} + for k, v in hp.items(): + hp_copy[k] = _clone_val(v) if isinstance(v, torch.Tensor) else v + entry["layer_hyperparams"] = hp_copy + snapshot[name] = entry + return snapshot + + def forward_pass( + self, + inputs, + max_new_tokens=50, + temperature=None, + top_k=None, + top_p=None, + eos_token_id=None, + store_node_io=True, + move_snapshots_to_cpu=True, + debug=False, + ): + """ + Run autoregressive generation for N tokens, storing node I/O per step. + + This is **Step 1** of the two-step API. It performs only the forward + pass — no relevance is computed. The per-step node I/O snapshots are + saved in ``self.node_io_trace`` so that ``relevance_pass()`` can + consume them later. + + Args: + inputs: dict with 'input_ids' (and optionally 'attention_mask'), + or a tuple/list of (input_ids, attention_mask). + max_new_tokens (int): Number of tokens to generate. + temperature (float | None): Sampling temperature (None ≔ greedy). + top_k (int | None): Top-k sampling. + top_p (float | None): Nucleus sampling threshold. + eos_token_id (int | list[int] | None): Stop on these token(s). + store_node_io (bool): If True (default), store a snapshot of + node_io for every generation step so that relevance_pass() + can use it. Set to False if you only need generated tokens. + move_snapshots_to_cpu (bool): Move snapshot tensors to CPU to + free GPU VRAM (default True). + debug (bool): Print per-step diagnostics. + + Returns: + dict with keys: + - 'generated_token_ids': List[int] — newly generated tokens + - 'complete_sequence': Tensor [1, T] — full input + generated + - 'num_steps': int — actual number of generation steps run + """ + # ---- parse inputs ---- + if isinstance(inputs, dict): + input_ids = inputs.get("input_ids") + attention_mask = inputs.get("attention_mask") + elif isinstance(inputs, (tuple, list)): + input_ids = inputs[0] + attention_mask = inputs[1] if len(inputs) > 1 else None + else: + raise ValueError( + "inputs must be a dict with 'input_ids' or a tuple/list " + "of (input_ids, attention_mask)" + ) + + if input_ids is None: + raise ValueError("input_ids is required") + + if not isinstance(input_ids, torch.Tensor): + input_ids = torch.tensor(input_ids) + if input_ids.dim() == 1: + input_ids = input_ids.unsqueeze(0) + input_ids = input_ids.long() + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + elif not isinstance(attention_mask, torch.Tensor): + attention_mask = torch.tensor(attention_mask) + if attention_mask.dim() == 1: + attention_mask = attention_mask.unsqueeze(0) + attention_mask = attention_mask.long() + + # ---- sampling config ---- + temp = temperature + do_sample = ( + (temp is not None and temp != 1.0) + or (top_k is not None and top_k > 0) + or (top_p is not None and 0.0 < top_p < 1.0) + ) + + if debug: + print(f"šŸš€ forward_pass: max_new_tokens={max_new_tokens}, " + f"temperature={temp}, top_k={top_k}, top_p={top_p}, " + f"do_sample={do_sample}, store_node_io={store_node_io}") + + # ---- state ---- + generated_tokens: list[int] = [] + generated = input_ids.clone() + attn = attention_mask.clone() + self.node_io_trace: list[dict] = [] # snapshots for relevance_pass + self._forward_pass_token_ids: list[int] = [] # for relevance_pass reference + + for step in range(max_new_tokens): + # --- forward through DLB engine --- + node_io = self.predict(generated, attn, debug=False, temperature=1.0) + + # --- extract logits --- + if "output" in node_io and "output_values" in node_io["output"]: + logits = node_io["output"]["output_values"] + else: + last_node = list(node_io.keys())[-1] + logits = node_io[last_node].get("output_values") + + if logits is None: + raise RuntimeError("Could not extract logits from model output") + + next_logits = logits[:, -1, :].float() + + # --- sampling / greedy --- + if do_sample: + if temp is not None and temp != 1.0 and temp > 0: + next_logits = next_logits / temp + + if top_k is not None and top_k > 0: + top_k_val = min(top_k, next_logits.size(-1)) + kth_vals = torch.topk(next_logits, top_k_val)[0][..., -1, None] + next_logits[next_logits < kth_vals] = float('-inf') + + if top_p is not None and 0.0 < top_p < 1.0: + sorted_logits, sorted_indices = torch.sort( + next_logits, descending=True + ) + cum_probs = torch.cumsum( + F.softmax(sorted_logits, dim=-1), dim=-1 + ) + remove_mask = cum_probs > top_p + remove_mask[..., 1:] = remove_mask[..., :-1].clone() + remove_mask[..., 0] = False + indices_to_remove = remove_mask.scatter( + 1, sorted_indices, remove_mask + ) + next_logits[indices_to_remove] = float('-inf') + + probs = F.softmax(next_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(next_logits, dim=-1, keepdim=True) + + next_token = next_token.long() + token_id = int(next_token.view(-1)[0].item()) + generated_tokens.append(token_id) + + if debug: + print(f" Step {step + 1}: token_id={token_id}") + + # --- snapshot node_io BEFORE clearing --- + if store_node_io: + snapshot = self._snapshot_node_io( + node_io, move_to_cpu=move_snapshots_to_cpu + ) + self.node_io_trace.append(snapshot) + self._forward_pass_token_ids.append(token_id) + + # --- advance sequence --- + generated = torch.cat([generated, next_token], dim=-1) + attn = torch.cat( + [attn, torch.ones_like(next_token)], dim=-1 + ) + + # --- free GPU memory --- + self.node_io = {} + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + # --- EOS check --- + if eos_token_id is not None: + if isinstance(eos_token_id, (list, tuple)): + if token_id in eos_token_id: + if debug: + print(f" EOS reached at step {step + 1}") + break + elif token_id == eos_token_id: + if debug: + print(f" EOS reached at step {step + 1}") + break + + if debug: + print(f"āœ… forward_pass complete: {len(generated_tokens)} tokens generated, " + f"{len(self.node_io_trace)} node_io snapshots stored") + + return { + "generated_token_ids": generated_tokens, + "complete_sequence": generated, + "num_steps": len(generated_tokens), + } + + def relevance_pass( + self, + token_indices=None, + mode="default", + multiplier=100.0, + scaler=1.0, + thresholding=0.5, + debug=False, + ): + """ + Compute relevance for tokens generated by a prior forward_pass(). + + This is **Step 2** of the two-step API. It reads from + ``self.node_io_trace`` (populated by ``forward_pass()``) and runs + relevance propagation for the requested generation steps. + + Args: + token_indices (list[int] | None): Which generation steps to + explain (0-based). ``None`` means *all* steps. + Example: ``[0, 4]`` → explain the 1st and 5th generated + tokens. + mode (str): Relevance propagation mode (default: "default"). + multiplier (float): Starting relevance value (default: 100.0). + scaler (float): Relevance scaler (default: 1.0). + thresholding (float): Thresholding for seed (default: 0.5). + debug (bool): Print diagnostics. + + Returns: + list[dict]: One entry per requested step, each containing: + - 'step_index': int — generation step (0-based) + - 'token_id': int — the token that was generated + - 'relevance': dict — node-name → relevance array (same + format as ``evaluation()`` output / ``self.all_wt``) + """ + if not hasattr(self, "node_io_trace") or not self.node_io_trace: + raise RuntimeError( + "No node_io_trace found. Run forward_pass(store_node_io=True) first." + ) + + total_steps = len(self.node_io_trace) + + # resolve indices + if token_indices is None: + indices = list(range(total_steps)) + else: + indices = list(token_indices) + for idx in indices: + if idx < 0 or idx >= total_steps: + raise IndexError( + f"token_index {idx} out of range — forward_pass " + f"stored {total_steps} steps (0..{total_steps - 1})" + ) + + if debug: + print(f"šŸ” relevance_pass: computing relevance for " + f"{len(indices)}/{total_steps} steps") + + results: list[dict] = [] + + for idx in indices: + snapshot = self.node_io_trace[idx] + token_id = self._forward_pass_token_ids[idx] + + if debug: + print(f" Step {idx}: token_id={token_id} …", end=" ") + + # Temporarily set self.node_io so evaluation() can use it + self.node_io = snapshot + + rel = self.evaluation( + mode=mode, + start_wt=[], + multiplier=multiplier, + scaler=scaler, + thresholding=thresholding, + task="generation", + target_token_ids=[token_id], + debug=False, + ) + + results.append({ + "step_index": idx, + "token_id": token_id, + "relevance": copy.deepcopy(rel), + }) + + # Free memory between steps + self.node_io = {} + gc.collect() + + if debug: + print(f"āœ… relevance_pass complete: {len(results)} step(s) explained") + + return results + + def clear_traces(self): + """ + Free all stored node_io snapshots and relevance results. + + Call this after you are done with relevance_pass() to reclaim memory. + """ + if hasattr(self, "node_io_trace"): + del self.node_io_trace + self.node_io_trace = [] + if hasattr(self, "_forward_pass_token_ids"): + del self._forward_pass_token_ids + self._forward_pass_token_ids = [] + self.node_io = {} + if hasattr(self, "all_wt"): + self.all_wt = {} + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + def run_task( self, task="auto", @@ -566,20 +919,22 @@ def run_task( Unified method for running DL-Backtrace on different tasks. Args: - task (str): Task type - "auto", "image-classification", "text-classification", or "generation" + task (str): Task type - "auto", "image-classification", "text-classification", "generation", or "forward-only-generation" - "auto": Automatically detect task based on inputs - "tabular-classification": For PyTorch Tabular Classification models - "image-classification": For image classification models (e.g., MobileNet, ResNet) - "text-classification": For text classification models (e.g., BERT sentiment) - - "generation": For text generation models (e.g., GPT, LLaMA) + - "generation": For text generation models with relevance tracing (e.g., GPT, LLaMA) + - "forward-only-generation": Fast token generation without relevance computation inputs: Input data for the model - For tabular-classification: torch.Tensor of shape (B, F) - For image-classification: torch.Tensor of shape (B, C, H, W) - For text-classification: dict with 'input_ids' and 'attention_mask' or tuple of tensors - For generation: dict with 'input_ids' and 'attention_mask' or tuple of tensors + - For forward-only-generation: dict with 'input_ids' and 'attention_mask' or tuple of tensors - tokenizer: Required for generation tasks (HuggingFace tokenizer) + tokenizer: Required for "generation" task (not needed for "forward-only-generation") mode (str): Relevance propagation mode (default: "default") multiplier (float): Starting relevance value (default: 100.0) @@ -596,6 +951,13 @@ def run_task( relevance_move_to_cpu (bool): Move cached relevance tensors to CPU memory (default: True). debug (bool): Enable debug logging (default: False) + **generation_kwargs: Additional kwargs for generation tasks + - max_new_tokens: Maximum tokens to generate (default: 50) + - temperature: Sampling temperature (default: None for greedy) + - top_k: Top-k sampling (default: None) + - top_p: Nucleus sampling threshold (default: None) + - eos_token_id: Stop generation when this token is produced + - num_beams: Beam search (only for "generation" task) save_relevance (bool): Save relevance trace to disk (default: False) save_path (str): Output directory for saved files (default: "./relevance_output") save_format (str): Quantization format - "fp16", "fp8", or "fp4" (default: "fp8") @@ -612,7 +974,9 @@ def run_task( - 'node_io': Layer-wise outputs from predict() - 'relevance': Relevance scores from evaluation() - 'predictions': Model predictions (logits for classification) - - 'generated_ids': (generation only) Generated token IDs + - 'generated_ids': (generation task) Generated token IDs as tensor [1, T] + - 'generated_token_ids': (forward-only-generation) List[int] of newly generated tokens + - 'complete_sequence': (forward-only-generation) Full sequence tensor [1, T] (input + generated) - 'scores_trace': (if return_scores=True) Scores trace - 'relevance_trace': (if return_relevance=True) Relevance trace - 'layerwise_output_trace': (if return_layerwise_output=True) Layer-wise output trace @@ -644,6 +1008,16 @@ def run_task( return_scores=True ) + # Fast forward-only generation (no relevance, no tokenizer needed) + results = dlb.run_task( + task="forward-only-generation", + inputs={'input_ids': input_ids, 'attention_mask': attention_mask}, + max_new_tokens=10, + temperature=0.8, + top_k=50, + top_p=0.9, + ) + generated_tokens = results['generated_token_ids'] # List[int] # Text generation with saving relevance to disk results = dlb.run_task( task="generation", @@ -663,7 +1037,8 @@ def run_task( "tabular-classification", "image-classification", "text-classification", - "generation" + "generation", + "forward-only-generation" ] # Auto-detect task if needed @@ -754,6 +1129,117 @@ def run_task( return result + elif task == "forward-only-generation": + if isinstance(inputs, dict): + input_ids = inputs.get("input_ids") + attention_mask = inputs.get("attention_mask") + elif isinstance(inputs, (tuple, list)): + input_ids = inputs[0] + attention_mask = inputs[1] if len(inputs) > 1 else None + else: + raise ValueError("For forward-only-generation, inputs must be dict or tuple of (input_ids, attention_mask)") + + if input_ids is None: + raise ValueError("input_ids is required for forward-only-generation") + + if not isinstance(input_ids, torch.Tensor): + input_ids = torch.tensor(input_ids) + if input_ids.dim() == 1: + input_ids = input_ids.unsqueeze(0) + input_ids = input_ids.long() + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + elif not isinstance(attention_mask, torch.Tensor): + attention_mask = torch.tensor(attention_mask) + if attention_mask.dim() == 1: + attention_mask = attention_mask.unsqueeze(0) + attention_mask = attention_mask.long() + + max_new_tokens = generation_kwargs.get("max_new_tokens", 50) + temp = generation_kwargs.get("temperature", None) + top_k = generation_kwargs.get("top_k", None) + top_p = generation_kwargs.get("top_p", None) + eos_token_id = generation_kwargs.get("eos_token_id", None) + + do_sample = (temp is not None and temp != 1.0) or (top_k is not None and top_k > 0) or (top_p is not None and 0.0 < top_p < 1.0) + + if debug: + print(f"šŸš€ Running forward-only-generation task...") + print(f" max_new_tokens: {max_new_tokens}") + print(f" temperature: {temp}, top_k: {top_k}, top_p: {top_p}") + print(f" do_sample: {do_sample}") + + generated_tokens = [] + generated = input_ids.clone() + attn = attention_mask.clone() + + for step in range(max_new_tokens): + node_io = self.predict(generated, attn, debug=False, temperature=1.0) + + if "output" in node_io and "output_values" in node_io["output"]: + logits = node_io["output"]["output_values"] + else: + last_node = list(node_io.keys())[-1] + logits = node_io[last_node].get("output_values") + + if logits is None: + raise RuntimeError("Could not extract logits from model output") + + next_logits = logits[:, -1, :].float() + + if do_sample: + if temp is not None and temp != 1.0 and temp > 0: + next_logits = next_logits / temp + + if top_k is not None and top_k > 0: + top_k_val = min(top_k, next_logits.size(-1)) + indices_to_remove = next_logits < torch.topk(next_logits, top_k_val)[0][..., -1, None] + next_logits[indices_to_remove] = float('-inf') + + if top_p is not None and 0.0 < top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(next_logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = False + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + next_logits[indices_to_remove] = float('-inf') + + probs = F.softmax(next_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(next_logits, dim=-1, keepdim=True) + + next_token = next_token.long() + token_id = int(next_token.view(-1)[0].item()) + generated_tokens.append(token_id) + + if debug: + print(f" Step {step+1}: token_id={token_id}") + + generated = torch.cat([generated, next_token], dim=-1) + attn = torch.cat([attn, torch.ones_like(next_token)], dim=-1) + + self.node_io = {} + + if eos_token_id is not None: + if isinstance(eos_token_id, (list, tuple)): + if token_id in eos_token_id: + if debug: + print(f" EOS token reached at step {step+1}") + break + elif token_id == eos_token_id: + if debug: + print(f" EOS token reached at step {step+1}") + break + + return { + 'task': task, + 'generated_token_ids': generated_tokens, + 'complete_sequence': generated, + } + else: # Classification tasks (image or text) if debug: @@ -1059,17 +1545,94 @@ def print_all_relevance_info(self): def visualize(self, save_path="graph.png"): visualize_graph(self.graph, save_path) - def visualize_dlbacktrace(self, output_path="backtrace_graph", top_k=None, relevance_threshold=None, engine_auto_threshold=1500): - visualize_relevance_auto( - self.graph, - self.all_wt, - output_path=output_path, # pretty path for small graphs - node_threshold=500, - engine_auto_threshold=engine_auto_threshold, - fast_output_path="backtrace_collapsed_fast", # path for large graphs - show=True, # ā¬…ļø show in Colab - inline_format="svg", # or "png" if SVG too heavy - ) + def visualize_dlbacktrace( + self, + output_path="backtrace_graph", + top_k=None, + relevance_threshold=None, + engine_auto_threshold=1500, + layer_types=None, + compact=False, + paginated=False, + max_nodes_per_page=30, + pages_per_row=1, + rankdir="LR", + show=True, + inline_format="svg" + ): + """Visualize DL-Backtrace relevance graph. + + Parameters + ---------- + output_path : str + Output file path (without extension) + top_k : int, optional + Show only top-k nodes by relevance + relevance_threshold : float, optional + Show nodes with |relevance| >= threshold + engine_auto_threshold : int + Node count threshold for switching rendering engines + layer_types : list[str], optional + List of layer types to include. Options: + - "MLP_Layer" (Linear/FC) + - "DL_Layer" (Conv) + - "Activation" (ReLU, GELU, etc.) + - "Normalization" (BatchNorm, LayerNorm) + - "Attention" + - "Output" + - "Placeholder" / "Model_Input" + - "NLP_Embedding" + compact : bool + If True, uses SEMANTIC_LAYER_TYPES for a paper-ready compact graph. + Equivalent to layer_types=SEMANTIC_LAYER_TYPES. + paginated : bool + If True, splits the graph into multiple pages for long DAGs. + Each page contains max_nodes_per_page nodes in topological order. + max_nodes_per_page : int + Maximum nodes per page when paginated=True (default: 30) + pages_per_row : int + Number of pages to display side-by-side when paginated=True (default: 1). + When > 1, creates an additional combined SVG with pages arranged + left-to-right. E.g., pages_per_row=3 puts 3 pages side-by-side. + rankdir : str + Graph direction: "LR" (left-to-right, wide), "TB" (top-to-bottom, tall). + Default "LR". Use "TB" for LaTeX/paper-friendly vertical layout. + For paginated mode, "TB" is recommended for each page. + show : bool + Whether to display inline in Jupyter/Colab + inline_format : str + Format for inline display ("svg" or "png") + """ + # compact=True is a shortcut for semantic layer types + if compact and layer_types is None: + layer_types = list(SEMANTIC_LAYER_TYPES) + + if paginated: + # Use paginated visualization for long graphs + visualize_relevance_paginated( + self.graph, + self.all_wt, + output_path=output_path, + max_nodes_per_page=max_nodes_per_page, + layer_types=layer_types, + rankdir=rankdir, + pages_per_row=pages_per_row, + show=show, + inline_format=inline_format, + ) + else: + visualize_relevance_auto( + self.graph, + self.all_wt, + output_path=output_path, + node_threshold=500, + engine_auto_threshold=engine_auto_threshold, + fast_output_path=output_path, + layer_types=layer_types, + rankdir=rankdir, + show=show, + inline_format=inline_format, + ) def visualize_dlbacktrace_with_modules( self,