From fda55de76ae2c964c8409a825c1fe5f11a9d15a8 Mon Sep 17 00:00:00 2001 From: neerajaryaai Date: Wed, 4 Feb 2026 16:18:20 +0530 Subject: [PATCH 01/16] feat: add 'forward-only-generation' task to run_task() and make visualize_dlbacktrace() configurable --- .../dlbacktrace/dlbacktrace.py | 154 ++++++++++++++++-- 1 file changed, 143 insertions(+), 11 deletions(-) diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py index 1381083..3a57c54 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py @@ -19,6 +19,7 @@ import numpy as np import torch +import torch.nn.functional as F import inspect class DLBacktrace: @@ -548,20 +549,22 @@ def run_task( Unified method for running DL-Backtrace on different tasks. Args: - task (str): Task type - "auto", "image-classification", "text-classification", or "generation" + task (str): Task type - "auto", "image-classification", "text-classification", "generation", or "forward-only-generation" - "auto": Automatically detect task based on inputs - "tabular-classification": For PyTorch Tabular Classification models - "image-classification": For image classification models (e.g., MobileNet, ResNet) - "text-classification": For text classification models (e.g., BERT sentiment) - - "generation": For text generation models (e.g., GPT, LLaMA) + - "generation": For text generation models with relevance tracing (e.g., GPT, LLaMA) + - "forward-only-generation": Fast token generation without relevance computation inputs: Input data for the model - For tabular-classification: torch.Tensor of shape (B, F) - For image-classification: torch.Tensor of shape (B, C, H, W) - For text-classification: dict with 'input_ids' and 'attention_mask' or tuple of tensors - For generation: dict with 'input_ids' and 'attention_mask' or tuple of tensors + - For forward-only-generation: dict with 'input_ids' and 'attention_mask' or tuple of tensors - tokenizer: Required for generation tasks (HuggingFace tokenizer) + tokenizer: Required for "generation" task (not needed for "forward-only-generation") mode (str): Relevance propagation mode (default: "default") multiplier (float): Starting relevance value (default: 100.0) @@ -578,8 +581,13 @@ def run_task( relevance_move_to_cpu (bool): Move cached relevance tensors to CPU memory (default: True). debug (bool): Enable debug logging (default: False) - **generation_kwargs: Additional kwargs for generation task (passed to sample_auto) - - max_new_tokens, top_k, top_p, num_beams, etc. + **generation_kwargs: Additional kwargs for generation tasks + - max_new_tokens: Maximum tokens to generate (default: 50) + - temperature: Sampling temperature (default: None for greedy) + - top_k: Top-k sampling (default: None) + - top_p: Nucleus sampling threshold (default: None) + - eos_token_id: Stop generation when this token is produced + - num_beams: Beam search (only for "generation" task) Returns: dict: Results containing: @@ -587,7 +595,8 @@ def run_task( - 'node_io': Layer-wise outputs from predict() - 'relevance': Relevance scores from evaluation() - 'predictions': Model predictions (logits for classification) - - 'generated_ids': (generation only) Generated token IDs + - 'generated_ids': (generation tasks) Generated token IDs as tensor [1, T] + - 'generated_token_ids': (forward-only-generation) List[int] of generated tokens - 'scores_trace': (if return_scores=True) Scores trace - 'relevance_trace': (if return_relevance=True) Relevance trace - 'layerwise_output_trace': (if return_layerwise_output=True) Layer-wise output trace @@ -617,13 +626,25 @@ def run_task( return_relevance=True, return_scores=True ) + + # Fast forward-only generation (no relevance, no tokenizer needed) + results = dlb.run_task( + task="forward-only-generation", + inputs={'input_ids': input_ids, 'attention_mask': attention_mask}, + max_new_tokens=10, + temperature=0.8, + top_k=50, + top_p=0.9, + ) + generated_tokens = results['generated_token_ids'] # List[int] """ # Validate task type valid_tasks = [ "tabular-classification", "image-classification", "text-classification", - "generation" + "generation", + "forward-only-generation" ] # Auto-detect task if needed @@ -698,6 +719,117 @@ def run_task( return result + elif task == "forward-only-generation": + if isinstance(inputs, dict): + input_ids = inputs.get("input_ids") + attention_mask = inputs.get("attention_mask") + elif isinstance(inputs, (tuple, list)): + input_ids = inputs[0] + attention_mask = inputs[1] if len(inputs) > 1 else None + else: + raise ValueError("For forward-only-generation, inputs must be dict or tuple of (input_ids, attention_mask)") + + if input_ids is None: + raise ValueError("input_ids is required for forward-only-generation") + + if not isinstance(input_ids, torch.Tensor): + input_ids = torch.tensor(input_ids) + if input_ids.dim() == 1: + input_ids = input_ids.unsqueeze(0) + input_ids = input_ids.long() + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + elif not isinstance(attention_mask, torch.Tensor): + attention_mask = torch.tensor(attention_mask) + if attention_mask.dim() == 1: + attention_mask = attention_mask.unsqueeze(0) + attention_mask = attention_mask.long() + + max_new_tokens = generation_kwargs.get("max_new_tokens", 50) + temp = generation_kwargs.get("temperature", None) + top_k = generation_kwargs.get("top_k", None) + top_p = generation_kwargs.get("top_p", None) + eos_token_id = generation_kwargs.get("eos_token_id", None) + + do_sample = (temp is not None and temp != 1.0) or (top_k is not None and top_k > 0) or (top_p is not None and 0.0 < top_p < 1.0) + + if debug: + print(f"šŸš€ Running forward-only-generation task...") + print(f" max_new_tokens: {max_new_tokens}") + print(f" temperature: {temp}, top_k: {top_k}, top_p: {top_p}") + print(f" do_sample: {do_sample}") + + generated_tokens = [] + generated = input_ids.clone() + attn = attention_mask.clone() + + for step in range(max_new_tokens): + node_io = self.predict(generated, attn, debug=False, temperature=1.0) + + if "output" in node_io and "output_values" in node_io["output"]: + logits = node_io["output"]["output_values"] + else: + last_node = list(node_io.keys())[-1] + logits = node_io[last_node].get("output_values") + + if logits is None: + raise RuntimeError("Could not extract logits from model output") + + next_logits = logits[:, -1, :].float() + + if do_sample: + if temp is not None and temp != 1.0 and temp > 0: + next_logits = next_logits / temp + + if top_k is not None and top_k > 0: + top_k_val = min(top_k, next_logits.size(-1)) + indices_to_remove = next_logits < torch.topk(next_logits, top_k_val)[0][..., -1, None] + next_logits[indices_to_remove] = float('-inf') + + if top_p is not None and 0.0 < top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(next_logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = False + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + next_logits[indices_to_remove] = float('-inf') + + probs = F.softmax(next_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(next_logits, dim=-1, keepdim=True) + + next_token = next_token.long() + token_id = int(next_token.view(-1)[0].item()) + generated_tokens.append(token_id) + + if debug: + print(f" Step {step+1}: token_id={token_id}") + + generated = torch.cat([generated, next_token], dim=-1) + attn = torch.cat([attn, torch.ones_like(next_token)], dim=-1) + + self.node_io = {} + + if eos_token_id is not None: + if isinstance(eos_token_id, (list, tuple)): + if token_id in eos_token_id: + if debug: + print(f" EOS token reached at step {step+1}") + break + elif token_id == eos_token_id: + if debug: + print(f" EOS token reached at step {step+1}") + break + + return { + 'task': task, + 'generated_token_ids': generated_tokens, + 'generated_ids': generated, + } + else: # Classification tasks (image or text) if debug: @@ -985,16 +1117,16 @@ def print_all_relevance_info(self): def visualize(self, save_path="graph.png"): visualize_graph(self.graph, save_path) - def visualize_dlbacktrace(self, output_path="backtrace_graph", top_k=None, relevance_threshold=None, engine_auto_threshold=1500): + def visualize_dlbacktrace(self, output_path="backtrace_graph", top_k=None, relevance_threshold=None, engine_auto_threshold=1500, show=True, inline_format="svg"): visualize_relevance_auto( self.graph, self.all_wt, output_path=output_path, # pretty path for small graphs node_threshold=500, engine_auto_threshold=engine_auto_threshold, - fast_output_path="backtrace_collapsed_fast", # path for large graphs - show=True, # ā¬…ļø show in Colab - inline_format="svg", # or "png" if SVG too heavy + fast_output_path=output_path, # path for large graphs + show=show, # ā¬…ļø show in Colab + inline_format=inline_format, # or "png" if SVG too heavy ) def visualize_dlbacktrace_with_modules( From ef73444e5281455faf94e467ffe7987ab442efda Mon Sep 17 00:00:00 2001 From: neerajaryaai Date: Wed, 4 Feb 2026 16:51:41 +0530 Subject: [PATCH 02/16] refactor: rename 'generated_ids' to 'complete_sequence' in forward-only-generation for clarity --- dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py index 3a57c54..a7c8994 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py @@ -595,8 +595,9 @@ def run_task( - 'node_io': Layer-wise outputs from predict() - 'relevance': Relevance scores from evaluation() - 'predictions': Model predictions (logits for classification) - - 'generated_ids': (generation tasks) Generated token IDs as tensor [1, T] - - 'generated_token_ids': (forward-only-generation) List[int] of generated tokens + - 'generated_ids': (generation task) Generated token IDs as tensor [1, T] + - 'generated_token_ids': (forward-only-generation) List[int] of newly generated tokens + - 'complete_sequence': (forward-only-generation) Full sequence tensor [1, T] (input + generated) - 'scores_trace': (if return_scores=True) Scores trace - 'relevance_trace': (if return_relevance=True) Relevance trace - 'layerwise_output_trace': (if return_layerwise_output=True) Layer-wise output trace @@ -827,7 +828,7 @@ def run_task( return { 'task': task, 'generated_token_ids': generated_tokens, - 'generated_ids': generated, + 'complete_sequence': generated, } else: From 86f379dd7494117675f04ee56cb4e451ee7e8feb Mon Sep 17 00:00:00 2001 From: neerajaryaai Date: Mon, 9 Feb 2026 09:43:50 +0530 Subject: [PATCH 03/16] feat: add two-step forward_pass() + relevance_pass() API for decoupled multi-token generation and explanation --- .../dlbacktrace/dlbacktrace.py | 353 ++++++++++++++++++ 1 file changed, 353 insertions(+) diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py index a7c8994..3b41b49 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py @@ -17,6 +17,8 @@ ) from .core.visualization_module_aware import visualize_relevance_with_module_labels +import gc +import copy import numpy as np import torch import torch.nn.functional as F @@ -525,6 +527,357 @@ def evaluation(self, mode="default", start_wt=[], multiplier=100.0, scaler=1.0, ) return self.all_wt + # ------------------------------------------------------------------ # + # Two-step API: forward_pass() + relevance_pass() # + # ------------------------------------------------------------------ # + + @staticmethod + def _snapshot_node_io(node_io, move_to_cpu=True): + """ + Create a memory-efficient deep copy of node_io for one generation step. + + Only the large tensor fields (input_values, output_values) are cloned + and optionally moved to CPU. All other metadata (layer_name, func_name, + input_sources, output_children, layer_hyperparams, …) is shallow-copied + because it is immutable across generation steps. + + Args: + node_io (dict): The node I/O dict produced by predict(). + move_to_cpu (bool): Move cloned tensors to CPU to free GPU memory. + + Returns: + dict: A snapshot suitable for later relevance computation. + """ + def _clone_val(v): + if isinstance(v, torch.Tensor): + t = v.detach().clone() + return t.cpu() if move_to_cpu else t + if isinstance(v, (list, tuple)): + return type(v)(_clone_val(x) for x in v) + return v # scalars, None, etc. + + snapshot = {} + for name, info in node_io.items(): + entry = dict(info) # shallow copy of metadata + entry["input_values"] = _clone_val(info.get("input_values")) + entry["output_values"] = _clone_val(info.get("output_values")) + # layer_hyperparams may contain weight tensors; clone them too + hp = info.get("layer_hyperparams") + if isinstance(hp, dict): + hp_copy = {} + for k, v in hp.items(): + hp_copy[k] = _clone_val(v) if isinstance(v, torch.Tensor) else v + entry["layer_hyperparams"] = hp_copy + snapshot[name] = entry + return snapshot + + def forward_pass( + self, + inputs, + max_new_tokens=50, + temperature=None, + top_k=None, + top_p=None, + eos_token_id=None, + store_node_io=True, + move_snapshots_to_cpu=True, + debug=False, + ): + """ + Run autoregressive generation for N tokens, storing node I/O per step. + + This is **Step 1** of the two-step API. It performs only the forward + pass — no relevance is computed. The per-step node I/O snapshots are + saved in ``self.node_io_trace`` so that ``relevance_pass()`` can + consume them later. + + Args: + inputs: dict with 'input_ids' (and optionally 'attention_mask'), + or a tuple/list of (input_ids, attention_mask). + max_new_tokens (int): Number of tokens to generate. + temperature (float | None): Sampling temperature (None ≔ greedy). + top_k (int | None): Top-k sampling. + top_p (float | None): Nucleus sampling threshold. + eos_token_id (int | list[int] | None): Stop on these token(s). + store_node_io (bool): If True (default), store a snapshot of + node_io for every generation step so that relevance_pass() + can use it. Set to False if you only need generated tokens. + move_snapshots_to_cpu (bool): Move snapshot tensors to CPU to + free GPU VRAM (default True). + debug (bool): Print per-step diagnostics. + + Returns: + dict with keys: + - 'generated_token_ids': List[int] — newly generated tokens + - 'complete_sequence': Tensor [1, T] — full input + generated + - 'num_steps': int — actual number of generation steps run + """ + # ---- parse inputs ---- + if isinstance(inputs, dict): + input_ids = inputs.get("input_ids") + attention_mask = inputs.get("attention_mask") + elif isinstance(inputs, (tuple, list)): + input_ids = inputs[0] + attention_mask = inputs[1] if len(inputs) > 1 else None + else: + raise ValueError( + "inputs must be a dict with 'input_ids' or a tuple/list " + "of (input_ids, attention_mask)" + ) + + if input_ids is None: + raise ValueError("input_ids is required") + + if not isinstance(input_ids, torch.Tensor): + input_ids = torch.tensor(input_ids) + if input_ids.dim() == 1: + input_ids = input_ids.unsqueeze(0) + input_ids = input_ids.long() + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + elif not isinstance(attention_mask, torch.Tensor): + attention_mask = torch.tensor(attention_mask) + if attention_mask.dim() == 1: + attention_mask = attention_mask.unsqueeze(0) + attention_mask = attention_mask.long() + + # ---- sampling config ---- + temp = temperature + do_sample = ( + (temp is not None and temp != 1.0) + or (top_k is not None and top_k > 0) + or (top_p is not None and 0.0 < top_p < 1.0) + ) + + if debug: + print(f"šŸš€ forward_pass: max_new_tokens={max_new_tokens}, " + f"temperature={temp}, top_k={top_k}, top_p={top_p}, " + f"do_sample={do_sample}, store_node_io={store_node_io}") + + # ---- state ---- + generated_tokens: list[int] = [] + generated = input_ids.clone() + attn = attention_mask.clone() + self.node_io_trace: list[dict] = [] # snapshots for relevance_pass + self._forward_pass_token_ids: list[int] = [] # for relevance_pass reference + + for step in range(max_new_tokens): + # --- forward through DLB engine --- + node_io = self.predict(generated, attn, debug=False, temperature=1.0) + + # --- extract logits --- + if "output" in node_io and "output_values" in node_io["output"]: + logits = node_io["output"]["output_values"] + else: + last_node = list(node_io.keys())[-1] + logits = node_io[last_node].get("output_values") + + if logits is None: + raise RuntimeError("Could not extract logits from model output") + + next_logits = logits[:, -1, :].float() + + # --- sampling / greedy --- + if do_sample: + if temp is not None and temp != 1.0 and temp > 0: + next_logits = next_logits / temp + + if top_k is not None and top_k > 0: + top_k_val = min(top_k, next_logits.size(-1)) + kth_vals = torch.topk(next_logits, top_k_val)[0][..., -1, None] + next_logits[next_logits < kth_vals] = float('-inf') + + if top_p is not None and 0.0 < top_p < 1.0: + sorted_logits, sorted_indices = torch.sort( + next_logits, descending=True + ) + cum_probs = torch.cumsum( + F.softmax(sorted_logits, dim=-1), dim=-1 + ) + remove_mask = cum_probs > top_p + remove_mask[..., 1:] = remove_mask[..., :-1].clone() + remove_mask[..., 0] = False + indices_to_remove = remove_mask.scatter( + 1, sorted_indices, remove_mask + ) + next_logits[indices_to_remove] = float('-inf') + + probs = F.softmax(next_logits, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(next_logits, dim=-1, keepdim=True) + + next_token = next_token.long() + token_id = int(next_token.view(-1)[0].item()) + generated_tokens.append(token_id) + + if debug: + print(f" Step {step + 1}: token_id={token_id}") + + # --- snapshot node_io BEFORE clearing --- + if store_node_io: + snapshot = self._snapshot_node_io( + node_io, move_to_cpu=move_snapshots_to_cpu + ) + self.node_io_trace.append(snapshot) + self._forward_pass_token_ids.append(token_id) + + # --- advance sequence --- + generated = torch.cat([generated, next_token], dim=-1) + attn = torch.cat( + [attn, torch.ones_like(next_token)], dim=-1 + ) + + # --- free GPU memory --- + self.node_io = {} + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + # --- EOS check --- + if eos_token_id is not None: + if isinstance(eos_token_id, (list, tuple)): + if token_id in eos_token_id: + if debug: + print(f" EOS reached at step {step + 1}") + break + elif token_id == eos_token_id: + if debug: + print(f" EOS reached at step {step + 1}") + break + + if debug: + print(f"āœ… forward_pass complete: {len(generated_tokens)} tokens generated, " + f"{len(self.node_io_trace)} node_io snapshots stored") + + return { + "generated_token_ids": generated_tokens, + "complete_sequence": generated, + "num_steps": len(generated_tokens), + } + + def relevance_pass( + self, + token_indices=None, + mode="default", + multiplier=100.0, + scaler=1.0, + thresholding=0.5, + debug=False, + ): + """ + Compute relevance for tokens generated by a prior forward_pass(). + + This is **Step 2** of the two-step API. It reads from + ``self.node_io_trace`` (populated by ``forward_pass()``) and runs + relevance propagation for the requested generation steps. + + Args: + token_indices (list[int] | None): Which generation steps to + explain (0-based). ``None`` means *all* steps. + Example: ``[0, 4]`` → explain the 1st and 5th generated + tokens. + mode (str): Relevance propagation mode (default: "default"). + multiplier (float): Starting relevance value (default: 100.0). + scaler (float): Relevance scaler (default: 1.0). + thresholding (float): Thresholding for seed (default: 0.5). + debug (bool): Print diagnostics. + + Returns: + list[dict]: One entry per requested step, each containing: + - 'step_index': int — generation step (0-based) + - 'token_id': int — the token that was generated + - 'relevance': dict — node-name → relevance array (same + format as ``evaluation()`` output / ``self.all_wt``) + """ + if not hasattr(self, "node_io_trace") or not self.node_io_trace: + raise RuntimeError( + "No node_io_trace found. Run forward_pass(store_node_io=True) first." + ) + + total_steps = len(self.node_io_trace) + + # resolve indices + if token_indices is None: + indices = list(range(total_steps)) + else: + indices = list(token_indices) + for idx in indices: + if idx < 0 or idx >= total_steps: + raise IndexError( + f"token_index {idx} out of range — forward_pass " + f"stored {total_steps} steps (0..{total_steps - 1})" + ) + + if debug: + print(f"šŸ” relevance_pass: computing relevance for " + f"{len(indices)}/{total_steps} steps") + + results: list[dict] = [] + + for idx in indices: + snapshot = self.node_io_trace[idx] + token_id = self._forward_pass_token_ids[idx] + + if debug: + print(f" Step {idx}: token_id={token_id} …", end=" ") + + # Temporarily set self.node_io so evaluation() can use it + self.node_io = snapshot + + rel = self.evaluation( + mode=mode, + start_wt=[], + multiplier=multiplier, + scaler=scaler, + thresholding=thresholding, + task="generation", + target_token_ids=[token_id], + debug=False, + ) + + results.append({ + "step_index": idx, + "token_id": token_id, + "relevance": copy.deepcopy(rel), + }) + + if debug: + total_rel = sum( + float(np.sum(v)) if isinstance(v, np.ndarray) else 0.0 + for v in rel.values() + ) + print(f"total_relevance={total_rel:.4f}") + + # Free memory between steps + self.node_io = {} + gc.collect() + + if debug: + print(f"āœ… relevance_pass complete: {len(results)} step(s) explained") + + return results + + def clear_traces(self): + """ + Free all stored node_io snapshots and relevance results. + + Call this after you are done with relevance_pass() to reclaim memory. + """ + if hasattr(self, "node_io_trace"): + del self.node_io_trace + self.node_io_trace = [] + if hasattr(self, "_forward_pass_token_ids"): + del self._forward_pass_token_ids + self._forward_pass_token_ids = [] + self.node_io = {} + if hasattr(self, "all_wt"): + self.all_wt = {} + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + def run_task( self, task="auto", From 2b4cba6cf883a94cd5a9e1a5db856ceb510bdb6a Mon Sep 17 00:00:00 2001 From: neerajaryaai Date: Mon, 9 Feb 2026 10:38:17 +0530 Subject: [PATCH 04/16] fix: redundant debug statement --- dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py index ee8ddea..994ded3 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py @@ -844,13 +844,6 @@ def relevance_pass( "relevance": copy.deepcopy(rel), }) - if debug: - total_rel = sum( - float(np.sum(v)) if isinstance(v, np.ndarray) else 0.0 - for v in rel.values() - ) - print(f"total_relevance={total_rel:.4f}") - # Free memory between steps self.node_io = {} gc.collect() From d911c23d4dc620c92e9c64ba2a02cb0ec7c5c874 Mon Sep 17 00:00:00 2001 From: neerajaryaai Date: Wed, 25 Mar 2026 13:00:50 +0530 Subject: [PATCH 05/16] feat: Add compact graph visualization with layer type filtering - Add SEMANTIC_LAYER_TYPES constant for paper-ready compact visualizations - Add layer_types parameter to visualize_relevance, visualize_relevance_fast, and visualize_relevance_auto - Filter nodes by semantic layer types (MLP, Attention, Normalization, etc.) - Always include Placeholder, Model_Input, and Output nodes for graph connectivity - Export SEMANTIC_LAYER_TYPES from package __init__.py --- .../pytorch_backtrace/dlbacktrace/__init__.py | 5 +- .../dlbacktrace/core/visualization.py | 99 +++++++++++++++++-- .../dlbacktrace/dlbacktrace.py | 62 ++++++++++-- 3 files changed, 151 insertions(+), 15 deletions(-) diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/__init__.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/__init__.py index 94a7edb..ebf39be 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/__init__.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/__init__.py @@ -3,11 +3,12 @@ from .utils import * from .aten_operations import * from .core.relevance_saver import save_relevance, load_relevance, Precision +from .core.visualization import SEMANTIC_LAYER_TYPES # Export pipeline modules (optional, can be imported separately) try: from .pipeline import DLBacktracePipeline, ModelRegistry, get_model_info, PipelineConfig - __all__ = ['DLBacktrace', 'activation_master', 'DLBacktracePipeline', 'ModelRegistry', 'get_model_info', 'PipelineConfig'] + __all__ = ['DLBacktrace', 'activation_master', 'DLBacktracePipeline', 'ModelRegistry', 'get_model_info', 'PipelineConfig', 'SEMANTIC_LAYER_TYPES'] except ImportError: - __all__ = ['DLBacktrace', 'activation_master'] + __all__ = ['DLBacktrace', 'activation_master', 'SEMANTIC_LAYER_TYPES'] \ No newline at end of file diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py index e90ff02..9df787c 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py @@ -7,6 +7,23 @@ from networkx.drawing.nx_pydot import graphviz_layout from collections import defaultdict from IPython.display import display, SVG, Image as IPyImage +from typing import Optional, Sequence + +# Semantically meaningful layer types for compact visualization +SEMANTIC_LAYER_TYPES: tuple[str, ...] = ( + "MLP_Layer", # Linear/FC layers + "DL_Layer", # Conv layers (Conv1d, Conv2d, Conv3d) + "Activation", # ReLU, GELU, SiLU, etc. + "Normalization", # BatchNorm, LayerNorm, GroupNorm + "Attention", # Self/Cross attention (scaled_dot_product_attention) + "Output", # Final output node + "Placeholder", # Input nodes (x, input_ids, etc.) - CRITICAL for graph connectivity + "Model_Input", # Legacy input type (kept for compatibility) + "NLP_Embedding", # Embedding layers (embedding, embedding_bag) +) + +# Default types to always force-include (for graph connectivity) +DEFAULT_FORCE_INCLUDE_TYPES: tuple[str, ...] = ("Placeholder", "Model_Input", "Output") def visualize_graph(graph, save_path="graph.png", *, show=True, dpi=600): @@ -47,8 +64,30 @@ def visualize_graph(graph, save_path="graph.png", *, show=True, dpi=600): def visualize_relevance(graph, all_wt, output_path="backtrace_graph", *, top_k=None, relevance_threshold=None, + layer_types: Optional[Sequence[str]] = None, show=True, inline_format="svg"): - """šŸŽÆ Visualize relevance backtrace using Graphviz (shows inline + saves)""" + """šŸŽÆ Visualize relevance backtrace using Graphviz (shows inline + saves) + + Parameters + ---------- + graph : networkx.DiGraph + The computation graph with layer_type attributes on nodes + all_wt : dict + Relevance weights for each node + output_path : str + Output file path (without extension) + top_k : int, optional + Show only top-k nodes by relevance + relevance_threshold : float, optional + Show nodes with |relevance| >= threshold + layer_types : list[str], optional + Filter to only these layer types. If None, shows all nodes. + Use SEMANTIC_LAYER_TYPES for a compact paper-ready graph. + show : bool + Whether to display inline in Jupyter/Colab + inline_format : str + Format for inline display ("svg" or "png") + """ relevance_data = {} # --- Extract relevance stats from all_wt --- @@ -67,16 +106,24 @@ def visualize_relevance(graph, all_wt, output_path="backtrace_graph", stats = (0.0, 0.0, 0.0) relevance_data[node_key] = stats - # --- Filter based on top_k or threshold --- + # --- Filter based on top_k, threshold, or layer_types --- flat_scores = {k: v[0] for k, v in relevance_data.items()} force_include = { node.replace("/", " ").replace(":", " ") for node in graph.nodes - if graph.nodes[node].get("layer_type") in ("Placeholder", "Model_Input") + if graph.nodes[node].get("layer_type") in DEFAULT_FORCE_INCLUDE_TYPES } - if top_k: + if layer_types is not None: + # Layer-type filtering mode + layer_types_set = set(layer_types) + top_node_names = { + node.replace("/", " ").replace(":", " ") + for node in graph.nodes + if graph.nodes[node].get("layer_type") in layer_types_set + } | force_include + elif top_k: top_keys = sorted(flat_scores.items(), key=lambda x: abs(x[1]), reverse=True)[:top_k] top_node_names = {k for k, _ in top_keys} | force_include elif relevance_threshold is not None: @@ -254,14 +301,31 @@ def visualize_relevance_fast( max_parents_per_node=None, engine_auto_threshold=1200, disable_concentrate_for_sfdp=True, + layer_types: Optional[Sequence[str]] = None, show=True, inline_format="svg", ): + """Fast visualization for large/collapsed graphs. + + Parameters + ---------- + layer_types : list[str], optional + Filter to only these layer types. If None, shows all nodes. + """ def _norm(s): return s.replace("/", " ").replace(":", " ") # present nodes present_raw = list(graph.nodes.keys()) + + # Apply layer_types filter if specified + if layer_types is not None: + layer_types_set = set(layer_types) | set(DEFAULT_FORCE_INCLUDE_TYPES) + present_raw = [ + raw for raw in present_raw + if graph.nodes[raw].get("layer_type") in layer_types_set + ] + norm_by_raw = {raw: _norm(raw) for raw in present_raw} present_norm = set(norm_by_raw.values()) @@ -417,12 +481,31 @@ def visualize_relevance_auto( node_threshold=500, engine_auto_threshold=1500, fast_output_path="backtrace_collapsed_fast", + layer_types: Optional[Sequence[str]] = None, show=True, inline_format="svg", ): - """Auto-choose pretty vs fast; always show inline and save.""" - num_nodes = len(graph.nodes) - print(f"num_nodes: {num_nodes}") + """Auto-choose pretty vs fast visualization; always show inline and save. + + Parameters + ---------- + layer_types : list[str], optional + Filter to only these layer types. If specified, layer_types filtering + takes precedence over automatic collapsing for large graphs. + Use SEMANTIC_LAYER_TYPES for a compact paper-ready graph. + """ + # If layer_types specified, count only matching nodes for threshold decision + if layer_types is not None: + layer_types_set = set(layer_types) | set(DEFAULT_FORCE_INCLUDE_TYPES) + filtered_count = sum( + 1 for n in graph.nodes + if graph.nodes[n].get("layer_type") in layer_types_set + ) + num_nodes = filtered_count + print(f"num_nodes after layer_types filter: {num_nodes} (from {len(graph.nodes)} total)") + else: + num_nodes = len(graph.nodes) + print(f"num_nodes: {num_nodes}") if num_nodes < node_threshold: # small graph → original pretty version @@ -430,6 +513,7 @@ def visualize_relevance_auto( graph, all_wt, output_path=output_path, + layer_types=layer_types, show=show, inline_format=inline_format, ) @@ -448,6 +532,7 @@ def visualize_relevance_auto( collapsed_map=collapsed_map, max_parents_per_node=2, engine_auto_threshold=engine_auto_threshold, + layer_types=layer_types, show=show, inline_format=inline_format, ) diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py index 994ded3..3a9dd5a 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py @@ -10,7 +10,12 @@ from .core.config import activation_master from .core.dlb_auto_sampler import DLBAutoSampler from .core.relevance_propagation import RelevancePropagator -from .core.visualization import visualize_graph, visualize_relevance, visualize_relevance_auto +from .core.visualization import ( + visualize_graph, + visualize_relevance, + visualize_relevance_auto, + SEMANTIC_LAYER_TYPES, +) from .core.token_relevance_visuals import ( plot_tokenwise_relevance_map_swapped, plot_input_heatmap_for_token, @@ -1527,16 +1532,61 @@ def print_all_relevance_info(self): def visualize(self, save_path="graph.png"): visualize_graph(self.graph, save_path) - def visualize_dlbacktrace(self, output_path="backtrace_graph", top_k=None, relevance_threshold=None, engine_auto_threshold=1500, show=True, inline_format="svg"): + def visualize_dlbacktrace( + self, + output_path="backtrace_graph", + top_k=None, + relevance_threshold=None, + engine_auto_threshold=1500, + layer_types=None, + compact=False, + show=True, + inline_format="svg" + ): + """Visualize DL-Backtrace relevance graph. + + Parameters + ---------- + output_path : str + Output file path (without extension) + top_k : int, optional + Show only top-k nodes by relevance + relevance_threshold : float, optional + Show nodes with |relevance| >= threshold + engine_auto_threshold : int + Node count threshold for switching rendering engines + layer_types : list[str], optional + List of layer types to include. Options: + - "MLP_Layer" (Linear/FC) + - "DL_Layer" (Conv) + - "Activation" (ReLU, GELU, etc.) + - "Normalization" (BatchNorm, LayerNorm) + - "Attention" + - "Output" + - "Placeholder" / "Model_Input" + - "NLP_Embedding" + compact : bool + If True, uses SEMANTIC_LAYER_TYPES for a paper-ready compact graph. + Equivalent to layer_types=SEMANTIC_LAYER_TYPES. + show : bool + Whether to display inline in Jupyter/Colab + inline_format : str + Format for inline display ("svg" or "png") + """ + # compact=True is a shortcut for semantic layer types + if compact and layer_types is None: + layer_types = list(SEMANTIC_LAYER_TYPES) + visualize_relevance_auto( self.graph, self.all_wt, - output_path=output_path, # pretty path for small graphs + output_path=output_path, node_threshold=500, engine_auto_threshold=engine_auto_threshold, - fast_output_path=output_path, # path for large graphs - show=show, # ā¬…ļø show in Colab - inline_format=inline_format, # or "png" if SVG too heavy + fast_output_path=output_path, + layer_types=layer_types, + show=show, + inline_format=inline_format, ) def visualize_dlbacktrace_with_modules( From f8f9cf66aa56d2031cb560cf79c31ac96b83c76c Mon Sep 17 00:00:00 2001 From: neerajaryaai Date: Wed, 25 Mar 2026 16:54:28 +0530 Subject: [PATCH 06/16] Add layer_types filtering for compact graph visualization - Add SEMANTIC_LAYER_TYPES constant for paper-ready compact graphs - Add layer_types parameter to visualize_relevance() and visualize_relevance_fast() - Add compact=True shortcut to visualize_dlbacktrace() API - Implement transitive edge computation to maintain DAG connectivity when intermediate nodes are filtered out - Export SEMANTIC_LAYER_TYPES from package __init__.py - Add _get_node_category() helper to check both layer_name and layer_type (graph stores semantic categories in layer_name, not layer_type) - Fix color map lookup to use semantic categories - Allows proper filtering of DL_Layer, MLP_Layer, Activation, etc. nodes --- .../dlbacktrace/core/visualization.py | 181 ++++++++++++++---- 1 file changed, 139 insertions(+), 42 deletions(-) diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py index 9df787c..adb3cd5 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py @@ -9,16 +9,17 @@ from IPython.display import display, SVG, Image as IPyImage from typing import Optional, Sequence -# Semantically meaningful layer types for compact visualization +# Semantically meaningful layer categories for compact visualization +# These match the ATEN_LAYER_MAP categories in graph_builder.py SEMANTIC_LAYER_TYPES: tuple[str, ...] = ( - "MLP_Layer", # Linear/FC layers - "DL_Layer", # Conv layers (Conv1d, Conv2d, Conv3d) + "MLP_Layer", # Linear/FC layers (linear, addmm) + "DL_Layer", # Conv layers (conv2d, max_pool2d, etc.) "Activation", # ReLU, GELU, SiLU, etc. "Normalization", # BatchNorm, LayerNorm, GroupNorm "Attention", # Self/Cross attention (scaled_dot_product_attention) - "Output", # Final output node - "Placeholder", # Input nodes (x, input_ids, etc.) - CRITICAL for graph connectivity - "Model_Input", # Legacy input type (kept for compatibility) + "Output", # Final output node (layer_type) + "Placeholder", # Input nodes (layer_type) - CRITICAL for graph connectivity + "Model_Input", # Legacy input type (layer_type) "NLP_Embedding", # Embedding layers (embedding, embedding_bag) ) @@ -26,6 +27,30 @@ DEFAULT_FORCE_INCLUDE_TYPES: tuple[str, ...] = ("Placeholder", "Model_Input", "Output") +def _get_node_category(node_attrs: dict) -> str: + """Get the semantic category for a node by checking both layer_type and layer_name. + + The graph stores: + - layer_type: 'ATen_Operation', 'Placeholder', 'Output', 'Operation', etc. + - layer_name: 'DL_Layer', 'MLP_Layer', 'Activation', 'Normalization', etc. + + For filtering, we need to check layer_name first (for ATen ops), then layer_type. + """ + layer_name = node_attrs.get("layer_name", "") + layer_type = node_attrs.get("layer_type", "Unknown") + + # For ATen operations, layer_name contains the semantic category + if layer_name in SEMANTIC_LAYER_TYPES: + return layer_name + + # For Placeholder, Output, etc., layer_type is the category + if layer_type in ("Placeholder", "Output", "Model_Input"): + return layer_type + + # Return layer_type as fallback + return layer_type + + def visualize_graph(graph, save_path="graph.png", *, show=True, dpi=600): """šŸ“Š Visualize forward execution graph with dynamic scaling (shows inline + saves)""" num_nodes = len(graph.nodes) @@ -112,16 +137,16 @@ def visualize_relevance(graph, all_wt, output_path="backtrace_graph", force_include = { node.replace("/", " ").replace(":", " ") for node in graph.nodes - if graph.nodes[node].get("layer_type") in DEFAULT_FORCE_INCLUDE_TYPES + if _get_node_category(graph.nodes[node]) in DEFAULT_FORCE_INCLUDE_TYPES } if layer_types is not None: - # Layer-type filtering mode + # Layer-type filtering mode - use _get_node_category for proper semantic matching layer_types_set = set(layer_types) top_node_names = { node.replace("/", " ").replace(":", " ") for node in graph.nodes - if graph.nodes[node].get("layer_type") in layer_types_set + if _get_node_category(graph.nodes[node]) in layer_types_set } | force_include elif top_k: top_keys = sorted(flat_scores.items(), key=lambda x: abs(x[1]), reverse=True)[:top_k] @@ -131,7 +156,30 @@ def visualize_relevance(graph, all_wt, output_path="backtrace_graph", else: top_node_names = set(relevance_data.keys()) | force_include - # --- Color map for node types --- + # --- Build raw->normalized name mapping for ancestor lookup --- + raw_to_norm = {node: node.replace("/", " ").replace(":", " ") for node in graph.nodes} + norm_to_raw = {v: k for k, v in raw_to_norm.items()} + + # --- Helper to find transitive ancestors in filtered set --- + def find_filtered_ancestors(node_raw, visited=None): + """BFS to find all ancestors that are in the filtered set.""" + if visited is None: + visited = set() + ancestors = set() + parents = graph.nodes[node_raw].get("parents", []) + for parent_raw in parents: + if parent_raw in visited: + continue + visited.add(parent_raw) + parent_norm = raw_to_norm.get(parent_raw, parent_raw.replace("/", " ").replace(":", " ")) + if parent_norm in top_node_names: + ancestors.add(parent_norm) + elif parent_raw in graph.nodes: + # Recursively search this parent's ancestors + ancestors.update(find_filtered_ancestors(parent_raw, visited)) + return ancestors + + # --- Color map for node types (uses layer_name for ATen ops, layer_type for others) --- color_map = { "MLP_Layer": "lightblue", "DL_Layer": "lightgreen", @@ -161,7 +209,9 @@ def visualize_relevance(graph, all_wt, output_path="backtrace_graph", if name not in top_node_names: continue rel = relevance_data.get(name, (0.0, 0.0, 0.0)) - fill = color_map.get(graph.nodes[node].get("layer_type", "Unknown"), "white") + # Use layer_name first (for semantic category), then layer_type as fallback + node_category = _get_node_category(graph.nodes[node]) + fill = color_map.get(node_category, color_map.get(graph.nodes[node].get("layer_type", "Unknown"), "white")) g.node( name, label=f"{name}\nMean: {rel[0]:.3f}\nMax: {rel[1]:.3f}\nMin: {rel[2]:.3f}", @@ -169,15 +219,30 @@ def visualize_relevance(graph, all_wt, output_path="backtrace_graph", fillcolor=fill, ) - # --- Add edges --- + # --- Add edges (with transitive connections when layer_types filtering) --- + added_edges = set() for node in graph.nodes: name = node.replace("/", " ").replace(":", " ") if name not in top_node_names: continue - for parent in graph.nodes[node].get("parents", []): - parent_fmt = parent.replace("/", " ").replace(":", " ") - if parent_fmt in top_node_names: - g.edge(parent_fmt, name) + + if layer_types is not None: + # Use transitive ancestor search for filtered graphs + ancestors = find_filtered_ancestors(node) + for ancestor in ancestors: + edge = (ancestor, name) + if edge not in added_edges: + added_edges.add(edge) + g.edge(ancestor, name) + else: + # Original direct parent logic + for parent in graph.nodes[node].get("parents", []): + parent_fmt = parent.replace("/", " ").replace(":", " ") + if parent_fmt in top_node_names: + edge = (parent_fmt, name) + if edge not in added_edges: + added_edges.add(edge) + g.edge(parent_fmt, name) out = g.render(output_path, format="svg", cleanup=True) @@ -315,19 +380,40 @@ def visualize_relevance_fast( def _norm(s): return s.replace("/", " ").replace(":", " ") - # present nodes - present_raw = list(graph.nodes.keys()) + # present nodes - keep all for transitive edge computation + all_raw = list(graph.nodes.keys()) - # Apply layer_types filter if specified + # Determine filtered set using _get_node_category for proper semantic matching if layer_types is not None: layer_types_set = set(layer_types) | set(DEFAULT_FORCE_INCLUDE_TYPES) present_raw = [ - raw for raw in present_raw - if graph.nodes[raw].get("layer_type") in layer_types_set + raw for raw in all_raw + if _get_node_category(graph.nodes[raw]) in layer_types_set ] + else: + present_raw = all_raw - norm_by_raw = {raw: _norm(raw) for raw in present_raw} - present_norm = set(norm_by_raw.values()) + norm_by_raw = {raw: _norm(raw) for raw in all_raw} # All nodes for lookup + present_norm = {_norm(raw) for raw in present_raw} # Filtered set + + # Helper to find transitive ancestors in filtered set + def find_filtered_ancestors(node_raw, visited=None): + """BFS to find all ancestors that are in the filtered set.""" + if visited is None: + visited = set() + ancestors = set() + parents = graph.nodes[node_raw].get("parents", []) or [] + for parent_raw in parents: + if parent_raw in visited: + continue + visited.add(parent_raw) + parent_norm = norm_by_raw.get(parent_raw, _norm(parent_raw)) + if parent_norm in present_norm: + ancestors.add(parent_norm) + elif parent_raw in graph.nodes: + # Recursively search this parent's ancestors + ancestors.update(find_filtered_ancestors(parent_raw, visited)) + return ancestors # relevance only for present rel_map = {} @@ -420,8 +506,9 @@ def _short(s, n=48): for raw in present_raw: nk = norm_by_raw[raw] mean, mx, mn = rel_map.get(nk, (0.0, 0.0, 0.0)) - lt = graph.nodes[raw].get("layer_type", "Unknown") - fill = color_map.get(lt, "white") + # Use _get_node_category for proper semantic coloring + node_category = _get_node_category(graph.nodes[raw]) + fill = color_map.get(node_category, color_map.get(graph.nodes[raw].get("layer_type", "Unknown"), "white")) collapsed = graph.nodes[raw].get("collapsed_count", 0) collapsed_line = f"\n[collapsed {collapsed}]" if collapsed else "" @@ -438,25 +525,35 @@ def _short(s, n=48): fillcolor=fill, ) - # edges + # edges - use transitive ancestors when layer_types filtering added = set() for raw in present_raw: child = norm_by_raw[raw] - parents = graph.nodes[raw].get("parents", []) or [] - if max_parents_per_node is not None and len(parents) > max_parents_per_node: - parents = sorted( - parents, - key=lambda p: abs(rel_map.get(_norm(p), (0.0, 0.0, 0.0))[0]), - reverse=True, - )[:max_parents_per_node] - - for p_raw in parents: - pn = norm_by_raw.get(p_raw, _norm(p_raw)) - e = (pn, child) - if e in added: - continue - added.add(e) - g.edge(pn, child) + + if layer_types is not None: + # Use transitive ancestor search for filtered graphs + ancestors = find_filtered_ancestors(raw) + for ancestor in ancestors: + e = (ancestor, child) + if e not in added: + added.add(e) + g.edge(ancestor, child) + else: + # Original direct parent logic + parents = graph.nodes[raw].get("parents", []) or [] + if max_parents_per_node is not None and len(parents) > max_parents_per_node: + parents = sorted( + parents, + key=lambda p: abs(rel_map.get(_norm(p), (0.0, 0.0, 0.0))[0]), + reverse=True, + )[:max_parents_per_node] + + for p_raw in parents: + pn = norm_by_raw.get(p_raw, _norm(p_raw)) + e = (pn, child) + if e not in added: + added.add(e) + g.edge(pn, child) out = g.render(output_path, cleanup=True) @@ -499,7 +596,7 @@ def visualize_relevance_auto( layer_types_set = set(layer_types) | set(DEFAULT_FORCE_INCLUDE_TYPES) filtered_count = sum( 1 for n in graph.nodes - if graph.nodes[n].get("layer_type") in layer_types_set + if _get_node_category(graph.nodes[n]) in layer_types_set ) num_nodes = filtered_count print(f"num_nodes after layer_types filter: {num_nodes} (from {len(graph.nodes)} total)") From 75b8a6596ae01c3304b65188cd8cd6d63cdb3031 Mon Sep 17 00:00:00 2001 From: neerajaryaai Date: Thu, 26 Mar 2026 11:12:18 +0530 Subject: [PATCH 07/16] add print statements for before/after node counts in filtering MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Shows total nodes → filtered nodes count for each filtering mode: - Layer-type filtering (visualize_relevance and visualize_relevance_fast) - Top-k filtering - Threshold filtering - No filtering --- .../pytorch_backtrace/dlbacktrace/core/visualization.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py index adb3cd5..75912f9 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py @@ -133,6 +133,7 @@ def visualize_relevance(graph, all_wt, output_path="backtrace_graph", # --- Filter based on top_k, threshold, or layer_types --- flat_scores = {k: v[0] for k, v in relevance_data.items()} + total_nodes = len(graph.nodes) force_include = { node.replace("/", " ").replace(":", " ") @@ -148,13 +149,17 @@ def visualize_relevance(graph, all_wt, output_path="backtrace_graph", for node in graph.nodes if _get_node_category(graph.nodes[node]) in layer_types_set } | force_include + print(f"šŸ“Š Layer-type filtering: {total_nodes} nodes → {len(top_node_names)} nodes (filter: {list(layer_types_set)[:5]}{'...' if len(layer_types_set) > 5 else ''})") elif top_k: top_keys = sorted(flat_scores.items(), key=lambda x: abs(x[1]), reverse=True)[:top_k] top_node_names = {k for k, _ in top_keys} | force_include + print(f"šŸ“Š Top-k filtering: {total_nodes} nodes → {len(top_node_names)} nodes (top_k={top_k})") elif relevance_threshold is not None: top_node_names = {k for k, v in flat_scores.items() if abs(v) >= relevance_threshold} | force_include + print(f"šŸ“Š Threshold filtering: {total_nodes} nodes → {len(top_node_names)} nodes (threshold={relevance_threshold})") else: top_node_names = set(relevance_data.keys()) | force_include + print(f"šŸ“Š No filtering: {total_nodes} nodes") # --- Build raw->normalized name mapping for ancestor lookup --- raw_to_norm = {node: node.replace("/", " ").replace(":", " ") for node in graph.nodes} @@ -382,6 +387,7 @@ def _norm(s): # present nodes - keep all for transitive edge computation all_raw = list(graph.nodes.keys()) + total_nodes = len(all_raw) # Determine filtered set using _get_node_category for proper semantic matching if layer_types is not None: @@ -390,8 +396,10 @@ def _norm(s): raw for raw in all_raw if _get_node_category(graph.nodes[raw]) in layer_types_set ] + print(f"šŸ“Š Layer-type filtering (fast): {total_nodes} nodes → {len(present_raw)} nodes (filter: {list(layer_types)[:5]}{'...' if len(layer_types) > 5 else ''})") else: present_raw = all_raw + print(f"šŸ“Š No filtering (fast): {total_nodes} nodes") norm_by_raw = {raw: _norm(raw) for raw in all_raw} # All nodes for lookup present_norm = {_norm(raw) for raw in present_raw} # Filtered set From 9aa59fdadbae9a97a84a6ce6e1c4f68f3b70c69b Mon Sep 17 00:00:00 2001 From: neerajaryaai Date: Thu, 26 Mar 2026 11:33:34 +0530 Subject: [PATCH 08/16] exclude Placeholder from force-include types for cleaner compact graphs --- .../pytorch_backtrace/dlbacktrace/core/visualization.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py index 75912f9..36b2f72 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py @@ -24,7 +24,8 @@ ) # Default types to always force-include (for graph connectivity) -DEFAULT_FORCE_INCLUDE_TYPES: tuple[str, ...] = ("Placeholder", "Model_Input", "Output") +# Note: Placeholder excluded to keep compact graphs clean +DEFAULT_FORCE_INCLUDE_TYPES: tuple[str, ...] = ("Model_Input", "Output") def _get_node_category(node_attrs: dict) -> str: @@ -43,8 +44,9 @@ def _get_node_category(node_attrs: dict) -> str: if layer_name in SEMANTIC_LAYER_TYPES: return layer_name - # For Placeholder, Output, etc., layer_type is the category - if layer_type in ("Placeholder", "Output", "Model_Input"): + # For Output, Model_Input, etc., layer_type is the category + # Note: Placeholder excluded to keep compact graphs clean + if layer_type in ("Output", "Model_Input"): return layer_type # Return layer_type as fallback From ae705531a7a1226919b884d35c1f7aeffbd714ea Mon Sep 17 00:00:00 2001 From: neerajaryaai Date: Thu, 26 Mar 2026 12:11:10 +0530 Subject: [PATCH 09/16] exclude p_model_*/b_model_* parameter nodes from compact graphs - Add EXCLUDED_NODE_PREFIXES constant and _is_excluded_node() helper - Filter out parameter/bias weight nodes in all filtering modes - Cleaner visualization by hiding internal weight matrices --- .../dlbacktrace/core/visualization.py | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py index 36b2f72..0abf58c 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py @@ -27,6 +27,14 @@ # Note: Placeholder excluded to keep compact graphs clean DEFAULT_FORCE_INCLUDE_TYPES: tuple[str, ...] = ("Model_Input", "Output") +# Prefixes for parameter/bias nodes to exclude from compact graphs +EXCLUDED_NODE_PREFIXES: tuple[str, ...] = ("p_model_", "b_model_") + + +def _is_excluded_node(node_name: str) -> bool: + """Check if node should be excluded (parameter/bias weights).""" + return any(node_name.startswith(prefix) for prefix in EXCLUDED_NODE_PREFIXES) + def _get_node_category(node_attrs: dict) -> str: """Get the semantic category for a node by checking both layer_type and layer_name. @@ -141,6 +149,7 @@ def visualize_relevance(graph, all_wt, output_path="backtrace_graph", node.replace("/", " ").replace(":", " ") for node in graph.nodes if _get_node_category(graph.nodes[node]) in DEFAULT_FORCE_INCLUDE_TYPES + and not _is_excluded_node(node) } if layer_types is not None: @@ -150,17 +159,18 @@ def visualize_relevance(graph, all_wt, output_path="backtrace_graph", node.replace("/", " ").replace(":", " ") for node in graph.nodes if _get_node_category(graph.nodes[node]) in layer_types_set + and not _is_excluded_node(node) } | force_include print(f"šŸ“Š Layer-type filtering: {total_nodes} nodes → {len(top_node_names)} nodes (filter: {list(layer_types_set)[:5]}{'...' if len(layer_types_set) > 5 else ''})") elif top_k: top_keys = sorted(flat_scores.items(), key=lambda x: abs(x[1]), reverse=True)[:top_k] - top_node_names = {k for k, _ in top_keys} | force_include + top_node_names = {k for k, _ in top_keys if not _is_excluded_node(k)} | force_include print(f"šŸ“Š Top-k filtering: {total_nodes} nodes → {len(top_node_names)} nodes (top_k={top_k})") elif relevance_threshold is not None: - top_node_names = {k for k, v in flat_scores.items() if abs(v) >= relevance_threshold} | force_include + top_node_names = {k for k, v in flat_scores.items() if abs(v) >= relevance_threshold and not _is_excluded_node(k)} | force_include print(f"šŸ“Š Threshold filtering: {total_nodes} nodes → {len(top_node_names)} nodes (threshold={relevance_threshold})") else: - top_node_names = set(relevance_data.keys()) | force_include + top_node_names = {k for k in relevance_data.keys() if not _is_excluded_node(k)} | force_include print(f"šŸ“Š No filtering: {total_nodes} nodes") # --- Build raw->normalized name mapping for ancestor lookup --- @@ -397,11 +407,12 @@ def _norm(s): present_raw = [ raw for raw in all_raw if _get_node_category(graph.nodes[raw]) in layer_types_set + and not _is_excluded_node(raw) ] print(f"šŸ“Š Layer-type filtering (fast): {total_nodes} nodes → {len(present_raw)} nodes (filter: {list(layer_types)[:5]}{'...' if len(layer_types) > 5 else ''})") else: - present_raw = all_raw - print(f"šŸ“Š No filtering (fast): {total_nodes} nodes") + present_raw = [raw for raw in all_raw if not _is_excluded_node(raw)] + print(f"šŸ“Š No filtering (fast): {total_nodes} nodes → {len(present_raw)} nodes (excluded p_model_*/b_model_*)") norm_by_raw = {raw: _norm(raw) for raw in all_raw} # All nodes for lookup present_norm = {_norm(raw) for raw in present_raw} # Filtered set @@ -607,12 +618,13 @@ def visualize_relevance_auto( filtered_count = sum( 1 for n in graph.nodes if _get_node_category(graph.nodes[n]) in layer_types_set + and not _is_excluded_node(n) ) num_nodes = filtered_count print(f"num_nodes after layer_types filter: {num_nodes} (from {len(graph.nodes)} total)") else: - num_nodes = len(graph.nodes) - print(f"num_nodes: {num_nodes}") + num_nodes = sum(1 for n in graph.nodes if not _is_excluded_node(n)) + print(f"num_nodes: {num_nodes} (from {len(graph.nodes)} total, excluded p_model_*/b_model_*)") if num_nodes < node_threshold: # small graph → original pretty version From 8ae029583536d7bdfb357bbd7349ac016cfac127 Mon Sep 17 00:00:00 2001 From: neerajaryaai Date: Fri, 27 Mar 2026 10:29:57 +0530 Subject: [PATCH 10/16] fix relevance mean calculation to use sum() for single tensors - use sum() instead of mean() for single tensor relevance (batch=1) - add clarifying comments for mean/max/min computation - consistent handling between visualize_relevance and visualize_relevance_fast --- .../dlbacktrace/core/visualization.py | 30 ++++++++++++++----- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py index 0abf58c..802dd58 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py @@ -126,13 +126,22 @@ def visualize_relevance(graph, all_wt, output_path="backtrace_graph", relevance_data = {} # --- Extract relevance stats from all_wt --- + # Mean: sum of all entries (for batch=1) or average of sums across batches + # Max/Min: max/min of batch sums (for batched) or max/min element (for single) for node_name, rel in all_wt.items(): node_key = node_name.replace("/", " ").replace(":", " ") if isinstance(rel, (list, tuple)): - flat = [float(r.sum()) for r in rel if hasattr(r, "sum")] - stats = (float(sum(flat) / len(flat)), max(flat), min(flat)) if flat else (0.0, 0.0, 0.0) + # Batched data: list of tensors + batch_sums = [float(r.sum()) for r in rel if hasattr(r, "sum")] + if batch_sums: + mean_val = sum(batch_sums) / len(batch_sums) + stats = (mean_val, max(batch_sums), min(batch_sums)) + else: + stats = (0.0, 0.0, 0.0) elif hasattr(rel, "sum"): - stats = (float(rel.mean()), float(rel.max()), float(rel.min())) + # Single tensor (batch size = 1) + # Mean = sum of all entries in relevance vector + stats = (float(rel.sum()), float(rel.max()), float(rel.min())) else: try: val = float(rel) @@ -437,20 +446,25 @@ def find_filtered_ancestors(node_raw, visited=None): return ancestors # relevance only for present + # Mean: sum of all entries (for batch=1) or average of sums across batches + # Max: maximum individual value in the relevance tensor + # Min: minimum individual value in the relevance tensor rel_map = {} for k, v in all_wt.items(): nk = _norm(k) if nk not in present_norm: continue if isinstance(v, (list, tuple)): - flat = [float(t.sum()) for t in v if hasattr(t, "sum")] - if flat: - mean = float(sum(flat) / len(flat)) - rel_map[nk] = (mean, max(flat), min(flat)) + # Batched data + batch_sums = [float(t.sum()) for t in v if hasattr(t, "sum")] + if batch_sums: + mean_val = sum(batch_sums) / len(batch_sums) + rel_map[nk] = (mean_val, max(batch_sums), min(batch_sums)) else: rel_map[nk] = (0.0, 0.0, 0.0) elif hasattr(v, "sum"): - rel_map[nk] = (float(v.mean()), float(v.max()), float(v.min())) + # Single tensor (batch size = 1): Mean = sum of all entries + rel_map[nk] = (float(v.sum()), float(v.max()), float(v.min())) else: try: x = float(v) From 1cc938efe5c1049836c84cada41424f669b46c0c Mon Sep 17 00:00:00 2001 From: neerajaryaai Date: Fri, 27 Mar 2026 11:57:28 +0530 Subject: [PATCH 11/16] add paginated visualization and rankdir support - add visualize_relevance_paginated() for splitting large graphs into pages - add rankdir parameter to visualize_relevance(), visualize_relevance_fast(), visualize_relevance_auto() for TB/LR graph direction - add paginated and max_nodes_per_page params to visualize_dlbacktrace() - pass rankdir through visualize_dlbacktrace() to underlying functions --- .../dlbacktrace/core/visualization.py | 263 +++++++++++++++++- .../dlbacktrace/dlbacktrace.py | 48 +++- 2 files changed, 298 insertions(+), 13 deletions(-) diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py index 802dd58..ac4752e 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py @@ -100,6 +100,7 @@ def visualize_graph(graph, save_path="graph.png", *, show=True, dpi=600): def visualize_relevance(graph, all_wt, output_path="backtrace_graph", *, top_k=None, relevance_threshold=None, layer_types: Optional[Sequence[str]] = None, + rankdir: str = "LR", show=True, inline_format="svg"): """šŸŽÆ Visualize relevance backtrace using Graphviz (shows inline + saves) @@ -118,6 +119,10 @@ def visualize_relevance(graph, all_wt, output_path="backtrace_graph", layer_types : list[str], optional Filter to only these layer types. If None, shows all nodes. Use SEMANTIC_LAYER_TYPES for a compact paper-ready graph. + rankdir : str + Graph direction: "LR" (left-to-right, wide), "TB" (top-to-bottom, tall), + "RL" (right-to-left), "BT" (bottom-to-top). Default "LR". + Use "TB" for LaTeX/paper-friendly vertical layout. show : bool Whether to display inline in Jupyter/Colab inline_format : str @@ -225,7 +230,7 @@ def find_filtered_ancestors(node_raw, visited=None): g = graphviz.Digraph( "DLBacktrace", format="svg", - graph_attr={"rankdir": "LR", "splines": "spline"}, + graph_attr={"rankdir": rankdir, "splines": "spline"}, node_attr={"fontname": "Helvetica", "fontsize": "10"} ) @@ -393,6 +398,7 @@ def visualize_relevance_fast( engine_auto_threshold=1200, disable_concentrate_for_sfdp=True, layer_types: Optional[Sequence[str]] = None, + rankdir: str = "LR", show=True, inline_format="svg", ): @@ -402,6 +408,9 @@ def visualize_relevance_fast( ---------- layer_types : list[str], optional Filter to only these layer types. If None, shows all nodes. + rankdir : str + Graph direction: "LR" (left-to-right), "TB" (top-to-bottom). + Default "LR". Use "TB" for LaTeX-friendly vertical layout. """ def _norm(s): return s.replace("/", " ").replace(":", " ") @@ -502,7 +511,7 @@ def find_filtered_ancestors(node_raw, visited=None): "outputorder": "edgesfirst", } if engine == "dot": - graph_attr["rankdir"] = "LR" + graph_attr["rankdir"] = rankdir graph_attr["splines"] = "spline" graph_attr["concentrate"] = "true" else: @@ -614,6 +623,7 @@ def visualize_relevance_auto( engine_auto_threshold=1500, fast_output_path="backtrace_collapsed_fast", layer_types: Optional[Sequence[str]] = None, + rankdir: str = "LR", show=True, inline_format="svg", ): @@ -625,6 +635,9 @@ def visualize_relevance_auto( Filter to only these layer types. If specified, layer_types filtering takes precedence over automatic collapsing for large graphs. Use SEMANTIC_LAYER_TYPES for a compact paper-ready graph. + rankdir : str + Graph direction: "LR" (left-to-right, wide), "TB" (top-to-bottom, tall). + Default "LR". Use "TB" for LaTeX/paper-friendly vertical layout. """ # If layer_types specified, count only matching nodes for threshold decision if layer_types is not None: @@ -647,6 +660,7 @@ def visualize_relevance_auto( all_wt, output_path=output_path, layer_types=layer_types, + rankdir=rankdir, show=show, inline_format=inline_format, ) @@ -666,6 +680,251 @@ def visualize_relevance_auto( max_parents_per_node=2, engine_auto_threshold=engine_auto_threshold, layer_types=layer_types, + rankdir=rankdir, show=show, inline_format=inline_format, ) + + +def visualize_relevance_paginated( + graph, + all_wt, + output_path="backtrace_graph", + *, + max_nodes_per_page: int = 30, + layer_types: Optional[Sequence[str]] = None, + show=True, + inline_format="svg", +): + """šŸŽÆ Visualize relevance backtrace as multiple sub-graphs (pages) for long DAGs. + + Splits the graph into topologically-ordered pages, each containing at most + `max_nodes_per_page` nodes. Each page is saved as a separate file and + displayed inline sequentially. + + Parameters + ---------- + graph : networkx.DiGraph + The computation graph with layer_type attributes on nodes + all_wt : dict + Relevance weights for each node + output_path : str + Base output file path (without extension). Pages will be named + {output_path}_page1.svg, {output_path}_page2.svg, etc. + max_nodes_per_page : int + Maximum number of nodes per page/sub-graph (default: 30) + layer_types : list[str], optional + Filter to only these layer types. Use SEMANTIC_LAYER_TYPES for compact graphs. + show : bool + Whether to display inline in Jupyter/Colab + inline_format : str + Format for inline display ("svg" or "png") + + Returns + ------- + list[tuple[graphviz.Digraph, str]] + List of (graph, output_path) tuples for each page + """ + # --- Extract relevance stats --- + def _norm(s): + return s.replace("/", " ").replace(":", " ") + + relevance_data = {} + for node_name, rel in all_wt.items(): + node_key = _norm(node_name) + if isinstance(rel, (list, tuple)): + batch_sums = [float(r.sum()) for r in rel if hasattr(r, "sum")] + if batch_sums: + mean_val = sum(batch_sums) / len(batch_sums) + relevance_data[node_key] = (mean_val, max(batch_sums), min(batch_sums)) + else: + relevance_data[node_key] = (0.0, 0.0, 0.0) + elif hasattr(rel, "sum"): + relevance_data[node_key] = (float(rel.sum()), float(rel.max()), float(rel.min())) + else: + try: + val = float(rel) + relevance_data[node_key] = (val, val, val) + except Exception: + relevance_data[node_key] = (0.0, 0.0, 0.0) + + # --- Filter nodes --- + total_nodes = len(graph.nodes) + + if layer_types is not None: + layer_types_set = set(layer_types) | set(DEFAULT_FORCE_INCLUDE_TYPES) + filtered_nodes = [ + node for node in graph.nodes + if _get_node_category(graph.nodes[node]) in layer_types_set + and not _is_excluded_node(node) + ] + else: + filtered_nodes = [node for node in graph.nodes if not _is_excluded_node(node)] + + filtered_norm = {_norm(n) for n in filtered_nodes} + print(f"šŸ“Š Paginated: {total_nodes} nodes → {len(filtered_nodes)} nodes") + + # --- Build parent mapping for filtered nodes --- + raw_to_norm = {node: _norm(node) for node in graph.nodes} + + def find_filtered_ancestors(node_raw, visited=None): + """BFS to find ancestors in filtered set (for transitive edges).""" + if visited is None: + visited = set() + ancestors = set() + parents = graph.nodes[node_raw].get("parents", []) or [] + for parent_raw in parents: + if parent_raw in visited: + continue + visited.add(parent_raw) + parent_norm = raw_to_norm.get(parent_raw, _norm(parent_raw)) + if parent_norm in filtered_norm: + ancestors.add(parent_norm) + elif parent_raw in graph.nodes: + ancestors.update(find_filtered_ancestors(parent_raw, visited)) + return ancestors + + # --- Topological sort of filtered nodes --- + # Build edges for filtered subgraph + filtered_edges = {} # node_norm -> set of parent_norms (in filtered set) + for node in filtered_nodes: + node_norm = _norm(node) + if layer_types is not None: + filtered_edges[node_norm] = find_filtered_ancestors(node) + else: + parents = graph.nodes[node].get("parents", []) or [] + filtered_edges[node_norm] = {_norm(p) for p in parents if _norm(p) in filtered_norm} + + # Kahn's algorithm for topological sort + in_degree = {n: 0 for n in filtered_norm} + for node, parents in filtered_edges.items(): + for p in parents: + if p in in_degree: + in_degree[node] = in_degree.get(node, 0) + 1 + + # Start with nodes that have no filtered parents (in_degree == 0) + queue = [n for n, d in in_degree.items() if d == 0] + topo_order = [] + + while queue: + node = queue.pop(0) + topo_order.append(node) + # Find children of this node + for child, parents in filtered_edges.items(): + if node in parents: + in_degree[child] -= 1 + if in_degree[child] == 0 and child not in topo_order: + queue.append(child) + + # Add any remaining nodes (handles cycles gracefully) + for n in filtered_norm: + if n not in topo_order: + topo_order.append(n) + + print(f"šŸ“Š Topological order: {len(topo_order)} nodes") + + # --- Split into pages --- + pages = [] + for i in range(0, len(topo_order), max_nodes_per_page): + pages.append(topo_order[i:i + max_nodes_per_page]) + + print(f"šŸ“Š Split into {len(pages)} pages (max {max_nodes_per_page} nodes/page)") + + # --- Color scale setup --- + all_means = [relevance_data.get(n, (0.0, 0.0, 0.0))[0] for n in topo_order] + max_abs = max(abs(v) for v in all_means) if all_means else 1.0 + if max_abs == 0: + max_abs = 1.0 + + def get_color(mean_val): + norm = mean_val / max_abs + if norm >= 0: + r = int(255 * (1 - norm)) + return f"#{r:02x}ff{r:02x}" # green + else: + g = int(255 * (1 + norm)) + return f"#ff{g:02x}{g:02x}" # red + + # --- Render each page --- + results = [] + for page_idx, page_nodes in enumerate(pages): + page_num = page_idx + 1 + page_node_set = set(page_nodes) + + # Include connector nodes from previous page for context + connector_nodes = set() + if page_idx > 0: + prev_page_nodes = set(pages[page_idx - 1]) + for node in page_nodes: + for parent in filtered_edges.get(node, set()): + if parent in prev_page_nodes: + connector_nodes.add(parent) + + g = graphviz.Digraph( + name=f"DLBacktrace_Page{page_num}", + format="svg", + graph_attr={ + "rankdir": "TB", # Top-to-bottom for vertical flow + "label": f"DLBacktrace Graph - Page {page_num}/{len(pages)}", + "labelloc": "t", + "fontsize": "14", + "nodesep": "0.3", + "ranksep": "0.5", + }, + node_attr={ + "shape": "box", + "style": "filled,rounded", + "fontsize": "10", + }, + edge_attr={ + "fontsize": "8", + }, + ) + + # Add connector nodes (grayed out, from previous page) + for node in connector_nodes: + stats = relevance_data.get(node, (0.0, 0.0, 0.0)) + label = f"{node}\n(from prev page)" + g.node(node, label=label, fillcolor="lightgray", style="filled,rounded,dashed") + + # Add page nodes + for node in page_nodes: + stats = relevance_data.get(node, (0.0, 0.0, 0.0)) + mean, mx, mn = stats + label = f"{node}\nMean={mean:.4f}\nMax={mx:.4f} Min={mn:.4f}" + color = get_color(mean) + g.node(node, label=label, fillcolor=color) + + # Add edges within page and from connectors + added_edges = set() + for node in page_nodes: + for parent in filtered_edges.get(node, set()): + if parent in page_node_set or parent in connector_nodes: + edge = (parent, node) + if edge not in added_edges: + added_edges.add(edge) + # Get edge weight (child's mean relevance) + child_mean = relevance_data.get(node, (0.0, 0.0, 0.0))[0] + g.edge(parent, node, label=f"{child_mean:.3f}") + + # Render + page_output = f"{output_path}_page{page_num}" + out = g.render(page_output, cleanup=True) + results.append((g, out)) + + # Show inline + if show: + print(f"\n{'='*50}") + print(f"šŸ“„ Page {page_num}/{len(pages)} ({len(page_nodes)} nodes)") + print(f"{'='*50}") + if inline_format.lower() == "svg": + svg_bytes = g.pipe(format="svg") + display(SVG(svg_bytes)) + else: + png_bytes = g.pipe(format="png") + display(IPyImage(data=png_bytes)) + + print(f"āœ… Page {page_num} saved → {out}") + + print(f"\nšŸ“Š Total: {len(pages)} pages saved with base path '{output_path}_pageN.svg'") + return results diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py index 3a9dd5a..b90bc7c 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py @@ -14,6 +14,7 @@ visualize_graph, visualize_relevance, visualize_relevance_auto, + visualize_relevance_paginated, SEMANTIC_LAYER_TYPES, ) from .core.token_relevance_visuals import ( @@ -1540,6 +1541,9 @@ def visualize_dlbacktrace( engine_auto_threshold=1500, layer_types=None, compact=False, + paginated=False, + max_nodes_per_page=30, + rankdir="LR", show=True, inline_format="svg" ): @@ -1568,6 +1572,15 @@ def visualize_dlbacktrace( compact : bool If True, uses SEMANTIC_LAYER_TYPES for a paper-ready compact graph. Equivalent to layer_types=SEMANTIC_LAYER_TYPES. + paginated : bool + If True, splits the graph into multiple pages for long DAGs. + Each page contains max_nodes_per_page nodes in topological order. + max_nodes_per_page : int + Maximum nodes per page when paginated=True (default: 30) + rankdir : str + Graph direction: "LR" (left-to-right, wide), "TB" (top-to-bottom, tall). + Default "LR". Use "TB" for LaTeX/paper-friendly vertical layout that + fits on a single page. show : bool Whether to display inline in Jupyter/Colab inline_format : str @@ -1577,17 +1590,30 @@ def visualize_dlbacktrace( if compact and layer_types is None: layer_types = list(SEMANTIC_LAYER_TYPES) - visualize_relevance_auto( - self.graph, - self.all_wt, - output_path=output_path, - node_threshold=500, - engine_auto_threshold=engine_auto_threshold, - fast_output_path=output_path, - layer_types=layer_types, - show=show, - inline_format=inline_format, - ) + if paginated: + # Use paginated visualization for long graphs + visualize_relevance_paginated( + self.graph, + self.all_wt, + output_path=output_path, + max_nodes_per_page=max_nodes_per_page, + layer_types=layer_types, + show=show, + inline_format=inline_format, + ) + else: + visualize_relevance_auto( + self.graph, + self.all_wt, + output_path=output_path, + node_threshold=500, + engine_auto_threshold=engine_auto_threshold, + fast_output_path=output_path, + layer_types=layer_types, + rankdir=rankdir, + show=show, + inline_format=inline_format, + ) def visualize_dlbacktrace_with_modules( self, From 0a9d65b06ab67cb13d95043f084e8140bd2d861d Mon Sep 17 00:00:00 2001 From: neerajaryaai Date: Fri, 27 Mar 2026 12:48:55 +0530 Subject: [PATCH 12/16] add pages_per_row for side-by-side paginated visualization - add pages_per_row param to visualize_relevance_paginated() and visualize_dlbacktrace() - when pages_per_row > 1, creates combined SVG with pages arranged horizontally - add rankdir param to visualize_relevance_paginated() (default TB for vertical) - pass rankdir through to paginated visualization from visualize_dlbacktrace() --- .../dlbacktrace/core/visualization.py | 109 +++++++++++++++++- .../dlbacktrace/dlbacktrace.py | 11 +- 2 files changed, 116 insertions(+), 4 deletions(-) diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py index ac4752e..9a94d63 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py @@ -693,6 +693,8 @@ def visualize_relevance_paginated( *, max_nodes_per_page: int = 30, layer_types: Optional[Sequence[str]] = None, + rankdir: str = "TB", + pages_per_row: int = 1, show=True, inline_format="svg", ): @@ -711,10 +713,18 @@ def visualize_relevance_paginated( output_path : str Base output file path (without extension). Pages will be named {output_path}_page1.svg, {output_path}_page2.svg, etc. + When pages_per_row > 1, also creates {output_path}_combined.svg max_nodes_per_page : int Maximum number of nodes per page/sub-graph (default: 30) layer_types : list[str], optional Filter to only these layer types. Use SEMANTIC_LAYER_TYPES for compact graphs. + rankdir : str + Graph direction: "TB" (top-to-bottom, default for paginated), + "LR" (left-to-right). Default "TB" for vertical flow. + pages_per_row : int + Number of pages to display side-by-side in a combined view (default: 1). + When > 1, creates an additional combined SVG with pages arranged + left-to-right. Useful for fitting multiple pages on one LaTeX page. show : bool Whether to display inline in Jupyter/Colab inline_format : str @@ -864,7 +874,7 @@ def get_color(mean_val): name=f"DLBacktrace_Page{page_num}", format="svg", graph_attr={ - "rankdir": "TB", # Top-to-bottom for vertical flow + "rankdir": rankdir, "label": f"DLBacktrace Graph - Page {page_num}/{len(pages)}", "labelloc": "t", "fontsize": "14", @@ -913,7 +923,7 @@ def get_color(mean_val): results.append((g, out)) # Show inline - if show: + if show and pages_per_row == 1: print(f"\n{'='*50}") print(f"šŸ“„ Page {page_num}/{len(pages)} ({len(page_nodes)} nodes)") print(f"{'='*50}") @@ -926,5 +936,100 @@ def get_color(mean_val): print(f"āœ… Page {page_num} saved → {out}") + # --- Create combined side-by-side view when pages_per_row > 1 --- + if pages_per_row > 1 and len(pages) > 1: + print(f"\nšŸ“Š Creating combined view with {pages_per_row} pages per row...") + + # Create a master graph that uses subgraphs for side-by-side layout + combined = graphviz.Digraph( + name="DLBacktrace_Combined", + format="svg", + graph_attr={ + "rankdir": "LR", # Left-to-right for side-by-side pages + "label": f"DLBacktrace Graph - Combined View ({len(pages)} pages)", + "labelloc": "t", + "fontsize": "16", + "compound": "true", + "newrank": "true", + }, + ) + + # Add each page as a cluster subgraph + for page_idx, page_nodes in enumerate(pages): + page_num = page_idx + 1 + page_node_set = set(page_nodes) + + # Determine connector nodes from previous page + connector_nodes = set() + if page_idx > 0: + prev_page_nodes = set(pages[page_idx - 1]) + for node in page_nodes: + for parent in filtered_edges.get(node, set()): + if parent in prev_page_nodes: + connector_nodes.add(parent) + + with combined.subgraph(name=f"cluster_page{page_num}") as subg: + subg.attr( + label=f"Page {page_num}", + style="rounded,filled", + color="lightgray", + fillcolor="white", + fontsize="12", + ) + subg.attr("graph", rankdir=rankdir) # Each page flows vertically + + # Prefix node names with page number to avoid conflicts + prefix = f"p{page_num}_" + + # Add connector nodes (grayed out) + for node in connector_nodes: + stats = relevance_data.get(node, (0.0, 0.0, 0.0)) + label = f"{node}\n(prev)" + subg.node(prefix + node, label=label, fillcolor="lightgray", + style="filled,rounded,dashed", shape="box", fontsize="9") + + # Add page nodes + for node in page_nodes: + stats = relevance_data.get(node, (0.0, 0.0, 0.0)) + mean, mx, mn = stats + label = f"{node}\nM={mean:.3f}\n↑{mx:.3f} ↓{mn:.3f}" + color = get_color(mean) + subg.node(prefix + node, label=label, fillcolor=color, + style="filled,rounded", shape="box", fontsize="9") + + # Add edges within page + for node in page_nodes: + for parent in filtered_edges.get(node, set()): + if parent in page_node_set or parent in connector_nodes: + child_mean = relevance_data.get(node, (0.0, 0.0, 0.0))[0] + subg.edge(prefix + parent, prefix + node, + label=f"{child_mean:.2f}", fontsize="7") + + # Add invisible edges between clusters to maintain left-to-right order + for i in range(len(pages) - 1): + if pages[i] and pages[i + 1]: + # Connect last node of page i to first node of page i+1 (invisible) + src = f"p{i+1}_{pages[i][-1]}" + dst = f"p{i+2}_{pages[i+1][0]}" + combined.edge(src, dst, style="invis", constraint="true") + + # Render combined + combined_output = f"{output_path}_combined" + combined_out = combined.render(combined_output, cleanup=True) + results.append((combined, combined_out)) + + if show: + print(f"\n{'='*60}") + print(f"šŸ“„ Combined View ({len(pages)} pages side-by-side)") + print(f"{'='*60}") + if inline_format.lower() == "svg": + svg_bytes = combined.pipe(format="svg") + display(SVG(svg_bytes)) + else: + png_bytes = combined.pipe(format="png") + display(IPyImage(data=png_bytes)) + + print(f"āœ… Combined graph saved → {combined_out}") + print(f"\nšŸ“Š Total: {len(pages)} pages saved with base path '{output_path}_pageN.svg'") return results diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py index b90bc7c..36347e0 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py @@ -1543,6 +1543,7 @@ def visualize_dlbacktrace( compact=False, paginated=False, max_nodes_per_page=30, + pages_per_row=1, rankdir="LR", show=True, inline_format="svg" @@ -1577,10 +1578,14 @@ def visualize_dlbacktrace( Each page contains max_nodes_per_page nodes in topological order. max_nodes_per_page : int Maximum nodes per page when paginated=True (default: 30) + pages_per_row : int + Number of pages to display side-by-side when paginated=True (default: 1). + When > 1, creates an additional combined SVG with pages arranged + left-to-right. E.g., pages_per_row=3 puts 3 pages side-by-side. rankdir : str Graph direction: "LR" (left-to-right, wide), "TB" (top-to-bottom, tall). - Default "LR". Use "TB" for LaTeX/paper-friendly vertical layout that - fits on a single page. + Default "LR". Use "TB" for LaTeX/paper-friendly vertical layout. + For paginated mode, "TB" is recommended for each page. show : bool Whether to display inline in Jupyter/Colab inline_format : str @@ -1598,6 +1603,8 @@ def visualize_dlbacktrace( output_path=output_path, max_nodes_per_page=max_nodes_per_page, layer_types=layer_types, + rankdir=rankdir, + pages_per_row=pages_per_row, show=show, inline_format=inline_format, ) From b99de45632af711d31bfe66c558b0d25fd734326 Mon Sep 17 00:00:00 2001 From: neerajaryaai Date: Fri, 27 Mar 2026 13:21:54 +0530 Subject: [PATCH 13/16] improve paginated combined view with row-based SVG files - split pages into rows (pages_per_row pages each) - create separate combined SVG for each row - proper cluster subgraph layout with inter-page edges - show page connections across consecutive pages in same row --- .../dlbacktrace/core/visualization.py | 199 ++++++++++-------- 1 file changed, 114 insertions(+), 85 deletions(-) diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py index 9a94d63..9717bf0 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py @@ -936,100 +936,129 @@ def get_color(mean_val): print(f"āœ… Page {page_num} saved → {out}") - # --- Create combined side-by-side view when pages_per_row > 1 --- + # --- Create combined side-by-side views when pages_per_row > 1 --- if pages_per_row > 1 and len(pages) > 1: - print(f"\nšŸ“Š Creating combined view with {pages_per_row} pages per row...") + # Split pages into rows (each row has at most pages_per_row pages) + page_rows = [] + for i in range(0, len(pages), pages_per_row): + page_rows.append(list(enumerate(pages[i:i + pages_per_row], start=i))) - # Create a master graph that uses subgraphs for side-by-side layout - combined = graphviz.Digraph( - name="DLBacktrace_Combined", - format="svg", - graph_attr={ - "rankdir": "LR", # Left-to-right for side-by-side pages - "label": f"DLBacktrace Graph - Combined View ({len(pages)} pages)", - "labelloc": "t", - "fontsize": "16", - "compound": "true", - "newrank": "true", - }, - ) + print(f"\nšŸ“Š Creating {len(page_rows)} combined view(s) with {pages_per_row} pages per row...") - # Add each page as a cluster subgraph - for page_idx, page_nodes in enumerate(pages): - page_num = page_idx + 1 - page_node_set = set(page_nodes) + for row_idx, row_pages in enumerate(page_rows): + row_num = row_idx + 1 - # Determine connector nodes from previous page - connector_nodes = set() - if page_idx > 0: - prev_page_nodes = set(pages[page_idx - 1]) - for node in page_nodes: - for parent in filtered_edges.get(node, set()): - if parent in prev_page_nodes: - connector_nodes.add(parent) + # Create a master graph using HTML-like table for true side-by-side + combined = graphviz.Digraph( + name=f"DLBacktrace_Combined_Row{row_num}", + format="svg", + engine="dot", + graph_attr={ + "rankdir": "LR", # Main graph is LR for side-by-side clusters + "label": f"DLBacktrace Graph - Row {row_num}/{len(page_rows)} (Pages {row_pages[0][0]+1}-{row_pages[-1][0]+1})", + "labelloc": "t", + "fontsize": "14", + "splines": "spline", + "compound": "true", + "nodesep": "0.2", + "ranksep": "0.3", + }, + ) - with combined.subgraph(name=f"cluster_page{page_num}") as subg: - subg.attr( - label=f"Page {page_num}", - style="rounded,filled", - color="lightgray", - fillcolor="white", - fontsize="12", - ) - subg.attr("graph", rankdir=rankdir) # Each page flows vertically + # Add each page in this row as a cluster subgraph + for local_idx, (global_page_idx, page_nodes) in enumerate(row_pages): + page_num = global_page_idx + 1 + page_node_set = set(page_nodes) - # Prefix node names with page number to avoid conflicts - prefix = f"p{page_num}_" + # Determine connector nodes from previous page (if in this row) + connector_nodes = set() + if global_page_idx > 0 and local_idx > 0: + prev_page_nodes = set(pages[global_page_idx - 1]) + for node in page_nodes: + for parent in filtered_edges.get(node, set()): + if parent in prev_page_nodes: + connector_nodes.add(parent) - # Add connector nodes (grayed out) - for node in connector_nodes: - stats = relevance_data.get(node, (0.0, 0.0, 0.0)) - label = f"{node}\n(prev)" - subg.node(prefix + node, label=label, fillcolor="lightgray", - style="filled,rounded,dashed", shape="box", fontsize="9") - - # Add page nodes - for node in page_nodes: - stats = relevance_data.get(node, (0.0, 0.0, 0.0)) - mean, mx, mn = stats - label = f"{node}\nM={mean:.3f}\n↑{mx:.3f} ↓{mn:.3f}" - color = get_color(mean) - subg.node(prefix + node, label=label, fillcolor=color, - style="filled,rounded", shape="box", fontsize="9") + with combined.subgraph(name=f"cluster_page{page_num}") as subg: + subg.attr( + label=f"Page {page_num}", + style="rounded", + color="black", + fontsize="11", + margin="10", + ) + # Force TB direction inside each cluster + subg.attr("graph", rankdir="TB", nodesep="0.15", ranksep="0.25") + + # Prefix node names with page number to avoid conflicts + prefix = f"p{page_num}_" + + # Add connector nodes (grayed out) + for node in connector_nodes: + stats = relevance_data.get(node, (0.0, 0.0, 0.0)) + short_node = node[:20] + "..." if len(node) > 20 else node + label = f"{short_node}\n(prev)" + subg.node(prefix + node, label=label, fillcolor="lightgray", + style="filled,rounded,dashed", shape="box", fontsize="8") + + # Add page nodes + for node in page_nodes: + stats = relevance_data.get(node, (0.0, 0.0, 0.0)) + mean, mx, mn = stats + short_node = node[:20] + "..." if len(node) > 20 else node + label = f"{short_node}\nM={mean:.3f}\n↑{mx:.3f} ↓{mn:.3f}" + color = get_color(mean) + subg.node(prefix + node, label=label, fillcolor=color, + style="filled,rounded", shape="box", fontsize="8") + + # Add edges within page + for node in page_nodes: + for parent in filtered_edges.get(node, set()): + if parent in page_node_set or parent in connector_nodes: + child_mean = relevance_data.get(node, (0.0, 0.0, 0.0))[0] + subg.edge(prefix + parent, prefix + node, + label=f"{child_mean:.2f}", fontsize="7") + + # Add invisible edges between clusters to force left-to-right ordering + # Use rank=same to align first nodes of each cluster + if len(row_pages) > 1: + with combined.subgraph() as s: + s.attr(rank="same") + for local_idx, (global_page_idx, page_nodes) in enumerate(row_pages): + if page_nodes: + prefix = f"p{global_page_idx + 1}_" + s.node(prefix + page_nodes[0]) - # Add edges within page - for node in page_nodes: - for parent in filtered_edges.get(node, set()): - if parent in page_node_set or parent in connector_nodes: - child_mean = relevance_data.get(node, (0.0, 0.0, 0.0))[0] - subg.edge(prefix + parent, prefix + node, - label=f"{child_mean:.2f}", fontsize="7") - - # Add invisible edges between clusters to maintain left-to-right order - for i in range(len(pages) - 1): - if pages[i] and pages[i + 1]: - # Connect last node of page i to first node of page i+1 (invisible) - src = f"p{i+1}_{pages[i][-1]}" - dst = f"p{i+2}_{pages[i+1][0]}" - combined.edge(src, dst, style="invis", constraint="true") - - # Render combined - combined_output = f"{output_path}_combined" - combined_out = combined.render(combined_output, cleanup=True) - results.append((combined, combined_out)) - - if show: - print(f"\n{'='*60}") - print(f"šŸ“„ Combined View ({len(pages)} pages side-by-side)") - print(f"{'='*60}") - if inline_format.lower() == "svg": - svg_bytes = combined.pipe(format="svg") - display(SVG(svg_bytes)) + # Invisible edges to maintain order + for i in range(len(row_pages) - 1): + _, curr_nodes = row_pages[i] + _, next_nodes = row_pages[i + 1] + if curr_nodes and next_nodes: + src_prefix = f"p{row_pages[i][0] + 1}_" + dst_prefix = f"p{row_pages[i + 1][0] + 1}_" + combined.edge(src_prefix + curr_nodes[0], dst_prefix + next_nodes[0], + style="invis", constraint="true") + + # Render combined row + if len(page_rows) == 1: + combined_output = f"{output_path}_combined" else: - png_bytes = combined.pipe(format="png") - display(IPyImage(data=png_bytes)) - - print(f"āœ… Combined graph saved → {combined_out}") + combined_output = f"{output_path}_combined_row{row_num}" + combined_out = combined.render(combined_output, cleanup=True) + results.append((combined, combined_out)) + + if show: + print(f"\n{'='*60}") + print(f"šŸ“„ Combined View Row {row_num}/{len(page_rows)} ({len(row_pages)} pages side-by-side)") + print(f"{'='*60}") + if inline_format.lower() == "svg": + svg_bytes = combined.pipe(format="svg") + display(SVG(svg_bytes)) + else: + png_bytes = combined.pipe(format="png") + display(IPyImage(data=png_bytes)) + + print(f"āœ… Combined row {row_num} saved → {combined_out}") print(f"\nšŸ“Š Total: {len(pages)} pages saved with base path '{output_path}_pageN.svg'") return results From 1076ab2159b47a62b74888b647232e3ecc812a3a Mon Sep 17 00:00:00 2001 From: neerajaryaai Date: Fri, 27 Mar 2026 15:22:03 +0530 Subject: [PATCH 14/16] change paginated combined view to vertical stacking (TB) - rename row-based to column-based layout (pages_per_row stacks vertically) - use TB rankdir for main graph (vertical cluster stacking) - use LR rankdir inside each page cluster (wider node layout) - better visual flow for sequential page reading --- .../dlbacktrace/core/visualization.py | 80 +++++++++---------- 1 file changed, 37 insertions(+), 43 deletions(-) diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py index 9717bf0..107c931 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py @@ -936,43 +936,45 @@ def get_color(mean_val): print(f"āœ… Page {page_num} saved → {out}") - # --- Create combined side-by-side views when pages_per_row > 1 --- + # --- Create combined column-wise views when pages_per_row > 1 --- + # pages_per_row = number of pages stacked vertically in each combined view if pages_per_row > 1 and len(pages) > 1: - # Split pages into rows (each row has at most pages_per_row pages) - page_rows = [] + # Split pages into columns (each combined view has at most pages_per_row pages stacked vertically) + page_columns = [] for i in range(0, len(pages), pages_per_row): - page_rows.append(list(enumerate(pages[i:i + pages_per_row], start=i))) + page_columns.append(list(enumerate(pages[i:i + pages_per_row], start=i))) - print(f"\nšŸ“Š Creating {len(page_rows)} combined view(s) with {pages_per_row} pages per row...") + print(f"\nšŸ“Š Creating {len(page_columns)} combined view(s) with {pages_per_row} pages per column (stacked vertically)...") - for row_idx, row_pages in enumerate(page_rows): - row_num = row_idx + 1 + for col_idx, col_pages in enumerate(page_columns): + col_num = col_idx + 1 - # Create a master graph using HTML-like table for true side-by-side + # Create a master graph with TB rankdir so clusters stack vertically combined = graphviz.Digraph( - name=f"DLBacktrace_Combined_Row{row_num}", + name=f"DLBacktrace_Combined_Col{col_num}", format="svg", engine="dot", graph_attr={ - "rankdir": "LR", # Main graph is LR for side-by-side clusters - "label": f"DLBacktrace Graph - Row {row_num}/{len(page_rows)} (Pages {row_pages[0][0]+1}-{row_pages[-1][0]+1})", + "rankdir": "TB", # Main graph is TB for vertical stacking of clusters + "label": f"DLBacktrace Graph - Column {col_num}/{len(page_columns)} (Pages {col_pages[0][0]+1}-{col_pages[-1][0]+1})", "labelloc": "t", "fontsize": "14", "splines": "spline", "compound": "true", - "nodesep": "0.2", - "ranksep": "0.3", + "nodesep": "0.3", + "ranksep": "0.5", + "newrank": "true", }, ) - # Add each page in this row as a cluster subgraph - for local_idx, (global_page_idx, page_nodes) in enumerate(row_pages): + # Add each page in this column as a cluster subgraph + for local_idx, (global_page_idx, page_nodes) in enumerate(col_pages): page_num = global_page_idx + 1 page_node_set = set(page_nodes) - # Determine connector nodes from previous page (if in this row) + # Determine connector nodes from previous page connector_nodes = set() - if global_page_idx > 0 and local_idx > 0: + if global_page_idx > 0: prev_page_nodes = set(pages[global_page_idx - 1]) for node in page_nodes: for parent in filtered_edges.get(node, set()): @@ -987,8 +989,8 @@ def get_color(mean_val): fontsize="11", margin="10", ) - # Force TB direction inside each cluster - subg.attr("graph", rankdir="TB", nodesep="0.15", ranksep="0.25") + # Each page cluster flows left-to-right (LR) for wider layout + subg.attr("graph", rankdir="LR", nodesep="0.2", ranksep="0.3") # Prefix node names with page number to avoid conflicts prefix = f"p{page_num}_" @@ -996,7 +998,7 @@ def get_color(mean_val): # Add connector nodes (grayed out) for node in connector_nodes: stats = relevance_data.get(node, (0.0, 0.0, 0.0)) - short_node = node[:20] + "..." if len(node) > 20 else node + short_node = node[:25] + "..." if len(node) > 25 else node label = f"{short_node}\n(prev)" subg.node(prefix + node, label=label, fillcolor="lightgray", style="filled,rounded,dashed", shape="box", fontsize="8") @@ -1005,7 +1007,7 @@ def get_color(mean_val): for node in page_nodes: stats = relevance_data.get(node, (0.0, 0.0, 0.0)) mean, mx, mn = stats - short_node = node[:20] + "..." if len(node) > 20 else node + short_node = node[:25] + "..." if len(node) > 25 else node label = f"{short_node}\nM={mean:.3f}\n↑{mx:.3f} ↓{mn:.3f}" color = get_color(mean) subg.node(prefix + node, label=label, fillcolor=color, @@ -1019,37 +1021,29 @@ def get_color(mean_val): subg.edge(prefix + parent, prefix + node, label=f"{child_mean:.2f}", fontsize="7") - # Add invisible edges between clusters to force left-to-right ordering - # Use rank=same to align first nodes of each cluster - if len(row_pages) > 1: - with combined.subgraph() as s: - s.attr(rank="same") - for local_idx, (global_page_idx, page_nodes) in enumerate(row_pages): - if page_nodes: - prefix = f"p{global_page_idx + 1}_" - s.node(prefix + page_nodes[0]) - - # Invisible edges to maintain order - for i in range(len(row_pages) - 1): - _, curr_nodes = row_pages[i] - _, next_nodes = row_pages[i + 1] + # Add invisible edges between clusters to force top-to-bottom ordering + if len(col_pages) > 1: + for i in range(len(col_pages) - 1): + _, curr_nodes = col_pages[i] + _, next_nodes = col_pages[i + 1] if curr_nodes and next_nodes: - src_prefix = f"p{row_pages[i][0] + 1}_" - dst_prefix = f"p{row_pages[i + 1][0] + 1}_" - combined.edge(src_prefix + curr_nodes[0], dst_prefix + next_nodes[0], + # Connect last node of current page to first node of next page + src_prefix = f"p{col_pages[i][0] + 1}_" + dst_prefix = f"p{col_pages[i + 1][0] + 1}_" + combined.edge(src_prefix + curr_nodes[-1], dst_prefix + next_nodes[0], style="invis", constraint="true") - # Render combined row - if len(page_rows) == 1: + # Render combined column + if len(page_columns) == 1: combined_output = f"{output_path}_combined" else: - combined_output = f"{output_path}_combined_row{row_num}" + combined_output = f"{output_path}_combined_col{col_num}" combined_out = combined.render(combined_output, cleanup=True) results.append((combined, combined_out)) if show: print(f"\n{'='*60}") - print(f"šŸ“„ Combined View Row {row_num}/{len(page_rows)} ({len(row_pages)} pages side-by-side)") + print(f"šŸ“„ Combined View Column {col_num}/{len(page_columns)} ({len(col_pages)} pages stacked)") print(f"{'='*60}") if inline_format.lower() == "svg": svg_bytes = combined.pipe(format="svg") @@ -1058,7 +1052,7 @@ def get_color(mean_val): png_bytes = combined.pipe(format="png") display(IPyImage(data=png_bytes)) - print(f"āœ… Combined row {row_num} saved → {combined_out}") + print(f"āœ… Combined column {col_num} saved → {combined_out}") print(f"\nšŸ“Š Total: {len(pages)} pages saved with base path '{output_path}_pageN.svg'") return results From 45bf16592760f3d21c096df434540d917f70f466 Mon Sep 17 00:00:00 2001 From: neerajaryaai Date: Fri, 27 Mar 2026 16:44:25 +0530 Subject: [PATCH 15/16] fix paginated combined view to use Nx3 matrix grid layout - organize pages into columns (page % num_columns) for proper grid - use orthogonal splines for cleaner inter-page edges - add invisible edges to enforce row ordering across columns - create nested clusters: column clusters containing page clusters - pages flow: col0[p1,p4,p7], col1[p2,p5,p8], col2[p3,p6,p9] for Nx3 --- .../dlbacktrace/core/visualization.py | 242 ++++++++++-------- 1 file changed, 133 insertions(+), 109 deletions(-) diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py index 107c931..cb6b953 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/core/visualization.py @@ -936,123 +936,147 @@ def get_color(mean_val): print(f"āœ… Page {page_num} saved → {out}") - # --- Create combined column-wise views when pages_per_row > 1 --- - # pages_per_row = number of pages stacked vertically in each combined view + # --- Create combined grid/matrix views when pages_per_row > 1 --- + # Layout: pages arranged in columns, flowing top-to-bottom within each column + # Example: 3 pages with pages_per_row=3 creates a Nx3 matrix if pages_per_row > 1 and len(pages) > 1: - # Split pages into columns (each combined view has at most pages_per_row pages stacked vertically) - page_columns = [] - for i in range(0, len(pages), pages_per_row): - page_columns.append(list(enumerate(pages[i:i + pages_per_row], start=i))) + num_columns = min(pages_per_row, len(pages)) - print(f"\nšŸ“Š Creating {len(page_columns)} combined view(s) with {pages_per_row} pages per column (stacked vertically)...") + # Organize pages into columns + # Column 0: pages[0], pages[num_columns], pages[2*num_columns], ... + # Column 1: pages[1], pages[num_columns+1], ... + columns = [[] for _ in range(num_columns)] + for page_idx, page_nodes in enumerate(pages): + col_idx = page_idx % num_columns + columns[col_idx].append((page_idx, page_nodes)) - for col_idx, col_pages in enumerate(page_columns): - col_num = col_idx + 1 - - # Create a master graph with TB rankdir so clusters stack vertically - combined = graphviz.Digraph( - name=f"DLBacktrace_Combined_Col{col_num}", - format="svg", - engine="dot", - graph_attr={ - "rankdir": "TB", # Main graph is TB for vertical stacking of clusters - "label": f"DLBacktrace Graph - Column {col_num}/{len(page_columns)} (Pages {col_pages[0][0]+1}-{col_pages[-1][0]+1})", - "labelloc": "t", - "fontsize": "14", - "splines": "spline", - "compound": "true", - "nodesep": "0.3", - "ranksep": "0.5", - "newrank": "true", - }, - ) - - # Add each page in this column as a cluster subgraph - for local_idx, (global_page_idx, page_nodes) in enumerate(col_pages): - page_num = global_page_idx + 1 - page_node_set = set(page_nodes) - - # Determine connector nodes from previous page - connector_nodes = set() - if global_page_idx > 0: - prev_page_nodes = set(pages[global_page_idx - 1]) - for node in page_nodes: - for parent in filtered_edges.get(node, set()): - if parent in prev_page_nodes: - connector_nodes.add(parent) + num_rows = max(len(col) for col in columns) + + print(f"\nšŸ“Š Creating grid layout: {num_rows} rows Ɨ {num_columns} columns...") + + # Create a master graph + combined = graphviz.Digraph( + name="DLBacktrace_Combined_Grid", + format="svg", + engine="dot", + graph_attr={ + "rankdir": "TB", # Top-to-bottom for rows + "label": f"DLBacktrace Graph - Grid View ({len(pages)} pages in {num_rows}Ɨ{num_columns} layout)", + "labelloc": "t", + "fontsize": "14", + "splines": "ortho", # Orthogonal edges for cleaner look + "nodesep": "0.4", + "ranksep": "0.6", + "newrank": "true", + }, + ) + + # Create nodes for each page in its column cluster + for col_idx, col_pages in enumerate(columns): + with combined.subgraph(name=f"cluster_col{col_idx}") as col_subg: + col_subg.attr( + label=f"Column {col_idx + 1}", + style="rounded,dashed", + color="gray", + fontsize="10", + margin="15", + ) - with combined.subgraph(name=f"cluster_page{page_num}") as subg: - subg.attr( - label=f"Page {page_num}", - style="rounded", - color="black", - fontsize="11", - margin="10", - ) - # Each page cluster flows left-to-right (LR) for wider layout - subg.attr("graph", rankdir="LR", nodesep="0.2", ranksep="0.3") - - # Prefix node names with page number to avoid conflicts - prefix = f"p{page_num}_" + for row_idx, (global_page_idx, page_nodes) in enumerate(col_pages): + page_num = global_page_idx + 1 + page_node_set = set(page_nodes) - # Add connector nodes (grayed out) - for node in connector_nodes: - stats = relevance_data.get(node, (0.0, 0.0, 0.0)) - short_node = node[:25] + "..." if len(node) > 25 else node - label = f"{short_node}\n(prev)" - subg.node(prefix + node, label=label, fillcolor="lightgray", - style="filled,rounded,dashed", shape="box", fontsize="8") + # Determine connector nodes from previous page + connector_nodes = set() + if global_page_idx > 0: + prev_page_nodes = set(pages[global_page_idx - 1]) + for node in page_nodes: + for parent in filtered_edges.get(node, set()): + if parent in prev_page_nodes: + connector_nodes.add(parent) - # Add page nodes - for node in page_nodes: - stats = relevance_data.get(node, (0.0, 0.0, 0.0)) - mean, mx, mn = stats - short_node = node[:25] + "..." if len(node) > 25 else node - label = f"{short_node}\nM={mean:.3f}\n↑{mx:.3f} ↓{mn:.3f}" - color = get_color(mean) - subg.node(prefix + node, label=label, fillcolor=color, - style="filled,rounded", shape="box", fontsize="8") + # Create a sub-cluster for this page + with col_subg.subgraph(name=f"cluster_page{page_num}") as page_subg: + page_subg.attr( + label=f"Page {page_num}", + style="rounded,filled", + color="black", + fillcolor="white", + fontsize="9", + margin="8", + ) + + prefix = f"p{page_num}_" + + # Add connector nodes (grayed out) + for node in connector_nodes: + stats = relevance_data.get(node, (0.0, 0.0, 0.0)) + short_node = node[:20] + "..." if len(node) > 20 else node + label = f"{short_node}\n(prev)" + page_subg.node(prefix + node, label=label, fillcolor="lightgray", + style="filled,rounded,dashed", shape="box", fontsize="7") + + # Add page nodes + for node in page_nodes: + stats = relevance_data.get(node, (0.0, 0.0, 0.0)) + mean, mx, mn = stats + short_node = node[:20] + "..." if len(node) > 20 else node + label = f"{short_node}\nM={mean:.3f}\n↑{mx:.3f}↓{mn:.3f}" + color = get_color(mean) + page_subg.node(prefix + node, label=label, fillcolor=color, + style="filled,rounded", shape="box", fontsize="7") + + # Add edges within page + for node in page_nodes: + for parent in filtered_edges.get(node, set()): + if parent in page_node_set or parent in connector_nodes: + child_mean = relevance_data.get(node, (0.0, 0.0, 0.0))[0] + page_subg.edge(prefix + parent, prefix + node, + label=f"{child_mean:.2f}", fontsize="6") - # Add edges within page - for node in page_nodes: - for parent in filtered_edges.get(node, set()): - if parent in page_node_set or parent in connector_nodes: - child_mean = relevance_data.get(node, (0.0, 0.0, 0.0))[0] - subg.edge(prefix + parent, prefix + node, - label=f"{child_mean:.2f}", fontsize="7") - - # Add invisible edges between clusters to force top-to-bottom ordering - if len(col_pages) > 1: - for i in range(len(col_pages) - 1): - _, curr_nodes = col_pages[i] - _, next_nodes = col_pages[i + 1] - if curr_nodes and next_nodes: - # Connect last node of current page to first node of next page - src_prefix = f"p{col_pages[i][0] + 1}_" - dst_prefix = f"p{col_pages[i + 1][0] + 1}_" - combined.edge(src_prefix + curr_nodes[-1], dst_prefix + next_nodes[0], - style="invis", constraint="true") - - # Render combined column - if len(page_columns) == 1: - combined_output = f"{output_path}_combined" + # Add invisible edge to next page in this column + if row_idx < len(col_pages) - 1: + next_page_idx, next_page_nodes = col_pages[row_idx + 1] + if page_nodes and next_page_nodes: + src_prefix = f"p{page_num}_" + dst_prefix = f"p{next_page_idx + 1}_" + combined.edge(src_prefix + page_nodes[-1], dst_prefix + next_page_nodes[0], + style="invis", constraint="true") + + # Add invisible edges between columns to align them horizontally + # Connect first node of each column's first page + for col_idx in range(num_columns - 1): + if columns[col_idx] and columns[col_idx + 1]: + _, left_nodes = columns[col_idx][0] + _, right_nodes = columns[col_idx + 1][0] + if left_nodes and right_nodes: + left_page_idx = columns[col_idx][0][0] + right_page_idx = columns[col_idx + 1][0][0] + src = f"p{left_page_idx + 1}_{left_nodes[0]}" + dst = f"p{right_page_idx + 1}_{right_nodes[0]}" + # Use rank=same to align horizontally + with combined.subgraph() as s: + s.attr(rank="same") + s.node(src) + s.node(dst) + + # Render combined grid + combined_output = f"{output_path}_combined" + combined_out = combined.render(combined_output, cleanup=True) + results.append((combined, combined_out)) + + if show: + print(f"\n{'='*60}") + print(f"šŸ“„ Combined Grid View ({num_rows} rows Ɨ {num_columns} columns)") + print(f"{'='*60}") + if inline_format.lower() == "svg": + svg_bytes = combined.pipe(format="svg") + display(SVG(svg_bytes)) else: - combined_output = f"{output_path}_combined_col{col_num}" - combined_out = combined.render(combined_output, cleanup=True) - results.append((combined, combined_out)) - - if show: - print(f"\n{'='*60}") - print(f"šŸ“„ Combined View Column {col_num}/{len(page_columns)} ({len(col_pages)} pages stacked)") - print(f"{'='*60}") - if inline_format.lower() == "svg": - svg_bytes = combined.pipe(format="svg") - display(SVG(svg_bytes)) - else: - png_bytes = combined.pipe(format="png") - display(IPyImage(data=png_bytes)) - - print(f"āœ… Combined column {col_num} saved → {combined_out}") + png_bytes = combined.pipe(format="png") + display(IPyImage(data=png_bytes)) + + print(f"āœ… Combined grid saved → {combined_out}") print(f"\nšŸ“Š Total: {len(pages)} pages saved with base path '{output_path}_pageN.svg'") return results From 8cf7e7b9584ebb9b60e67f930529305ab5fd4557 Mon Sep 17 00:00:00 2001 From: neerajaryaai Date: Mon, 30 Mar 2026 19:03:55 +0530 Subject: [PATCH 16/16] fix: Add missing PropagationSchedule import --- dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py b/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py index 7c8b61e..8c4a14a 100755 --- a/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py +++ b/dl_backtrace/pytorch_backtrace/dlbacktrace/dlbacktrace.py @@ -10,6 +10,7 @@ from .core.config import activation_master from .core.dlb_auto_sampler import DLBAutoSampler from .core.relevance_propagation import RelevancePropagator +from .core.compiled_propagation import PropagationSchedule from .core.visualization import ( visualize_graph, visualize_relevance,