feat: Add two-step generation API (forward_pass + relevance_pass) and forward-only-generation task #13

Open
neerajaryaai wants to merge 18 commits into dlb_v2 from forward_only_generation

Conversation


neerajaryaai (Collaborator) commented on Feb 4, 2026

Summary

This PR introduces a new forward-only-generation task type to run_task() that enables fast token generation without computing relevance scores. This is useful when users only need generated tokens and want to minimize memory usage and latency.

It also adds a two-step API (forward_pass() + relevance_pass()) that decouples multi-token generation from relevance computation, giving users full control over when and which tokens to explain.

Changes

New Feature: forward-only-generation Task

  • Added new task type "forward-only-generation" to run_task() method
  • Implements autoregressive token generation loop with:
    • Greedy decoding (default)
    • Temperature scaling for controlling output distribution
    • Top-k sampling for limiting token candidates
    • Top-p (nucleus) sampling for dynamic vocabulary truncation
  • Supports early stopping via eos_token_id parameter
  • Clears node_io between generation steps to reduce memory footprint
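
The decoding options listed above can be illustrated with a small, framework-free sketch. Note that `sample_next_token` is a hypothetical helper written for this description, not the library's API; the real generation loop operates on model logits tensors rather than Python lists:

```python
import math
import random

def sample_next_token(logits, temperature=1.0, top_k=None, top_p=None):
    """Pick a token id from raw logits: greedy when temperature == 0,
    otherwise softmax sampling with optional top-k / top-p truncation."""
    if temperature == 0:
        return max(range(len(logits)), key=lambda i: logits[i])  # greedy decoding
    scaled = [l / temperature for l in logits]          # temperature scaling
    m = max(scaled)                                     # shift for numerical stability
    probs = [math.exp(l - m) for l in scaled]
    total = sum(probs)
    probs = [p / total for p in probs]
    ranked = sorted(enumerate(probs), key=lambda t: t[1], reverse=True)
    if top_k is not None:
        ranked = ranked[:top_k]                         # keep the k most likely tokens
    if top_p is not None:
        kept, cum = [], 0.0
        for tok, p in ranked:                           # smallest set with mass >= top_p
            kept.append((tok, p))
            cum += p
            if cum >= top_p:
                break
        ranked = kept
    # renormalize over the surviving candidates and sample
    total = sum(p for _, p in ranked)
    r = random.random() * total
    for tok, p in ranked:
        r -= p
        if r <= 0:
            return tok
    return ranked[-1][0]
```

With `temperature=0` the helper reduces to argmax; `top_k` and `top_p` compose, with top-k applied first, mirroring the order described in the bullet list.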

Improved Visualization API

  • Added show and inline_format parameters to visualize_dlbacktrace() method
  • Made visualization output path consistent between small and large graphs

New Feature: Two-Step forward_pass() + relevance_pass() API

  • forward_pass() — Runs autoregressive generation for N tokens, storing per-step node_io snapshots efficiently:
    • Clones only tensor data (input_values, output_values, layer_hyperparams); shallow-copies immutable graph metadata
    • Moves snapshot tensors to CPU by default to free GPU VRAM
    • Clears GPU memory + runs gc.collect() between steps
    • Supports greedy, temperature, top-k, and top-p sampling
    • Returns generated_token_ids, complete_sequence, and num_steps
  • relevance_pass() — Computes relevance for selected tokens from a prior forward_pass():
    • Reads from self.node_io_trace (populated by forward_pass())
    • Accepts token_indices to explain specific steps (e.g., [0, 4, 9]) or all steps (None)
    • Returns per-step {'step_index', 'token_id', 'relevance'} dicts
  • clear_traces() — Frees all stored snapshots, relevance data, and GPU cache
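
The contract between the two passes can be sketched as follows. This is a minimal stand-in, assuming only what the bullets above state (one snapshot per generated step stored in `self.node_io_trace`, consumed later by `relevance_pass()`); the `TwoStepTracer` name and the placeholder relevance scoring are illustrative, not the PR's implementation:

```python
import gc

class TwoStepTracer:
    """Sketch of the forward_pass / relevance_pass split."""

    def __init__(self):
        self.node_io_trace = []  # one snapshot dict per generated token

    def forward_pass(self, steps):
        # steps: iterable of (token_id, node_io) pairs standing in for generation
        generated = []
        for step_index, (token_id, node_io) in enumerate(steps):
            # copy the snapshot so later steps cannot mutate stored state
            self.node_io_trace.append({
                'step_index': step_index,
                'token_id': token_id,
                'node_io': dict(node_io),
            })
            generated.append(token_id)
        return {'generated_token_ids': generated, 'num_steps': len(generated)}

    def relevance_pass(self, token_indices=None):
        # None means "explain every stored step"
        indices = range(len(self.node_io_trace)) if token_indices is None else token_indices
        results = []
        for i in indices:
            snap = self.node_io_trace[i]
            relevance = {k: abs(v) for k, v in snap['node_io'].items()}  # placeholder scoring
            results.append({'step_index': snap['step_index'],
                            'token_id': snap['token_id'],
                            'relevance': relevance})
        return results

    def clear_traces(self):
        self.node_io_trace.clear()
        gc.collect()  # mirrors the cleanup the real API performs
```

The key design point this illustrates: generation never computes relevance, and relevance only reads stored snapshots, so the two costs are paid independently.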

Usage Examples

```python
# Fast forward-only generation (no relevance, no tokenizer needed)
results = dlb.run_task(
    task="forward-only-generation",
    inputs={'input_ids': input_ids, 'attention_mask': attention_mask},
    max_new_tokens=50,
    temperature=0.8,
    top_k=50,
    top_p=0.9,
    eos_token_id=tokenizer.eos_token_id,
)
generated_tokens = results['generated_token_ids']  # List[int]

# Two-step API: generate first, explain later
# Step 1 — Forward pass (generate 10 tokens, store node I/O)
result = dlb.forward_pass(
    inputs={'input_ids': input_ids, 'attention_mask': attention_mask},
    max_new_tokens=10,
    temperature=0.8,
    top_k=50,
    debug=True,
)
print(result['generated_token_ids'])  # [tok1, tok2, ..., tok10]

# Step 2 — Relevance pass (explain selected tokens)
relevance_results = dlb.relevance_pass(
    token_indices=[0, 4, 9],  # explain 1st, 5th, and 10th tokens
    multiplier=100.0,
    debug=True,
)
for r in relevance_results:
    print(f"Token {r['token_id']} (step {r['step_index']})")

# Cleanup
dlb.clear_traces()
```

neerajaryaai changed the title from "feat: Add 'forward-only-generation' task for fast token generation without relevance tracing" to "feat: Add 'forward-only-generation' task for fast token generation without relevance tracing and two-step forward_pass() + relevance_pass() API for decoupled multi-token generation and explanation" on Feb 9, 2026
neerajaryaai changed the title to "feat: Add two-step generation API (forward_pass + relevance_pass) and forward-only-generation task" on Feb 9, 2026
neerajaryaai and others added 4 commits on February 9, 2026 at 10:33
- Add SEMANTIC_LAYER_TYPES constant for paper-ready compact visualizations
- Add layer_types parameter to visualize_relevance, visualize_relevance_fast, and visualize_relevance_auto
- Filter nodes by semantic layer types (MLP, Attention, Normalization, etc.)
- Always include Placeholder, Model_Input, and Output nodes for graph connectivity
- Export SEMANTIC_LAYER_TYPES from package __init__.py
- Add SEMANTIC_LAYER_TYPES constant for paper-ready compact graphs
- Add layer_types parameter to visualize_relevance() and visualize_relevance_fast()
- Add compact=True shortcut to visualize_dlbacktrace() API
- Implement transitive edge computation to maintain DAG connectivity
  when intermediate nodes are filtered out
- Export SEMANTIC_LAYER_TYPES from package __init__.py
- Add _get_node_category() helper to check both layer_name and layer_type
  (graph stores semantic categories in layer_name, not layer_type)
- Fix color map lookup to use semantic categories
- Allows proper filtering of DL_Layer, MLP_Layer, Activation, etc. nodes
neerajaryaai force-pushed the forward_only_generation branch from 26d4f70 to f8f9cf6 on March 25, 2026 at 11:24
Shows total nodes → filtered nodes count for each filtering mode:
- Layer-type filtering (visualize_relevance and visualize_relevance_fast)
- Top-k filtering
- Threshold filtering
- No filtering
neerajaryaai force-pushed the forward_only_generation branch from d042875 to 75b8a65 on March 26, 2026 at 05:42
- Add EXCLUDED_NODE_PREFIXES constant and _is_excluded_node() helper
- Filter out parameter/bias weight nodes in all filtering modes
- Cleaner visualization by hiding internal weight matrices
- use sum() instead of mean() for single tensor relevance (batch=1)
- add clarifying comments for mean/max/min computation
- consistent handling between visualize_relevance and visualize_relevance_fast
- add visualize_relevance_paginated() for splitting large graphs into pages
- add rankdir parameter to visualize_relevance(), visualize_relevance_fast(),
  visualize_relevance_auto() for TB/LR graph direction
- add paginated and max_nodes_per_page params to visualize_dlbacktrace()
- pass rankdir through visualize_dlbacktrace() to underlying functions
- add pages_per_row param to visualize_relevance_paginated() and visualize_dlbacktrace()
- when pages_per_row > 1, creates combined SVG with pages arranged horizontally
- add rankdir param to visualize_relevance_paginated() (default TB for vertical)
- pass rankdir through to paginated visualization from visualize_dlbacktrace()
- split pages into rows (pages_per_row pages each)
- create separate combined SVG for each row
- proper cluster subgraph layout with inter-page edges
- show page connections across consecutive pages in same row
- rename row-based to column-based layout (pages_per_row stacks vertically)
- use TB rankdir for main graph (vertical cluster stacking)
- use LR rankdir inside each page cluster (wider node layout)
- better visual flow for sequential page reading
- organize pages into columns (page % num_columns) for proper grid
- use orthogonal splines for cleaner inter-page edges
- add invisible edges to enforce row ordering across columns
- create nested clusters: column clusters containing page clusters
- pages flow: col0[p1,p4,p7], col1[p2,p5,p8], col2[p3,p6,p9] for Nx3
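
The round-robin column assignment this commit describes (`page % num_columns`) can be sketched as follows; `pages_to_columns` is an illustrative helper, not the visualization module's API:

```python
def pages_to_columns(num_pages, num_columns):
    """Distribute 1-based page labels into columns round-robin,
    so page p lands in column (p - 1) % num_columns."""
    cols = [[] for _ in range(num_columns)]
    for page in range(num_pages):
        cols[page % num_columns].append(page + 1)  # 1-based page labels
    return cols

# For 9 pages in 3 columns this reproduces the layout in the commit message:
# col0 -> pages 1, 4, 7; col1 -> pages 2, 5, 8; col2 -> pages 3, 6, 9
```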
neerajaryaai force-pushed the forward_only_generation branch from d70f67c to 45bf165 on March 30, 2026 at 13:03