From adc3f611e823c8aeb9dae21ee9b1c812983f9e6a Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Fri, 23 Jan 2026 14:53:28 -0800 Subject: [PATCH 01/37] Set and clear DynamoPrefixContext for workflow KV optimization Introduced context management using DynamoPrefixContext to optimize KV cache by setting unique prefix IDs per workflow run. This includes adding lazy imports to avoid circular dependencies and ensuring the context is cleared in `finally` blocks to prevent leaks. Signed-off-by: dnandakumar-nv --- src/nat/runtime/runner.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/nat/runtime/runner.py b/src/nat/runtime/runner.py index 85ea726dcd..fa6f4ac8a5 100644 --- a/src/nat/runtime/runner.py +++ b/src/nat/runtime/runner.py @@ -161,6 +161,12 @@ async def result(self, to_type: type | None = None): token_run_id = self._context_state.workflow_run_id.set(workflow_run_id) token_trace_id = self._context_state.workflow_trace_id.set(workflow_trace_id) + # Set Dynamo prefix context for KV cache optimization + # Each workflow invocation gets a unique prefix ID based on the run ID + # Lazy import to avoid circular dependency + from nat.llm.dynamo_llm import DynamoPrefixContext + DynamoPrefixContext.set(f"nat-workflow-{workflow_run_id}") + # Prepare workflow-level intermediate step identifiers workflow_step_uuid = str(uuid.uuid4()) workflow_name = getattr(self._entry_fn, 'instance_name', None) or "workflow" @@ -211,6 +217,9 @@ async def result(self, to_type: type | None = None): self._state = RunnerState.FAILED raise finally: + # Lazy import to avoid circular dependency + from nat.llm.dynamo_llm import DynamoPrefixContext + DynamoPrefixContext.clear() if token_run_id is not None: self._context_state.workflow_run_id.reset(token_run_id) if token_trace_id is not None: @@ -240,6 +249,12 @@ async def result_stream(self, to_type: type | None = None): token_run_id = self._context_state.workflow_run_id.set(workflow_run_id) token_trace_id = self._context_state.workflow_trace_id.set(workflow_trace_id) + # Set Dynamo prefix context for KV cache optimization + # Each workflow invocation gets a unique prefix ID based on the run ID + # Lazy import to avoid circular dependency + from nat.llm.dynamo_llm import DynamoPrefixContext + DynamoPrefixContext.set(f"nat-workflow-{workflow_run_id}") + # Prepare workflow-level intermediate step identifiers workflow_step_uuid = str(uuid.uuid4()) workflow_name = getattr(self._entry_fn, 'instance_name', None) or "workflow" @@ -296,6 +311,9 @@ async def result_stream(self, to_type: type | None = None): self._state = RunnerState.FAILED raise finally: + # Lazy import to avoid circular dependency + from nat.llm.dynamo_llm import DynamoPrefixContext + DynamoPrefixContext.clear() if token_run_id is not None: self._context_state.workflow_run_id.reset(token_run_id) if token_trace_id is not None: From ffd3d2893c67f6f6c8709564523349361d74a6b6 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Fri, 23 Jan 2026 14:54:33 -0800 Subject: [PATCH 02/37] Add tests for DynamoPrefixContext integration in Runner Introduces comprehensive test cases to ensure `DynamoPrefixContext` is properly set, cleared, and associated with unique workflow run IDs during execution. This includes tests for normal operations, streaming results, error handling, and pre-set workflow run IDs, enhancing reliability and coverage for the Runner class. Signed-off-by: dnandakumar-nv --- .../nat/runtime/test_runner_dynamo_prefix.py | 358 ++++++++++++++++++ 1 file changed, 358 insertions(+) create mode 100644 tests/nat/runtime/test_runner_dynamo_prefix.py diff --git a/tests/nat/runtime/test_runner_dynamo_prefix.py b/tests/nat/runtime/test_runner_dynamo_prefix.py new file mode 100644 index 0000000000..a0c61ca147 --- /dev/null +++ b/tests/nat/runtime/test_runner_dynamo_prefix.py @@ -0,0 +1,358 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for DynamoPrefixContext integration in the Runner class. + +These tests verify that the Runner properly sets and clears the DynamoPrefixContext +for KV cache optimization when using Dynamo LLM backends. +""" + +from collections.abc import AsyncGenerator + +import pytest + +from nat.builder.builder import Builder +from nat.builder.context import ContextState +from nat.builder.workflow_builder import WorkflowBuilder +from nat.cli.register_workflow import register_function +from nat.data_models.function import FunctionBaseConfig +from nat.llm.dynamo_llm import DynamoPrefixContext +from nat.observability.exporter_manager import ExporterManager +from nat.runtime.runner import Runner + + +class SingleOutputConfig(FunctionBaseConfig, name="single_output_dynamo_test"): + pass + + +class StreamOutputConfig(FunctionBaseConfig, name="stream_output_dynamo_test"): + pass + + +class CaptureConfig(FunctionBaseConfig, name="capture_dynamo_prefix"): + pass + + +@pytest.fixture(scope="module", autouse=True) +async def _register_single_output_fn(): + + @register_function(config_type=SingleOutputConfig) + async def register(config: SingleOutputConfig, b: Builder): + + async def _inner(message: str) -> str: + return message + "!" + + yield _inner + + +@pytest.fixture(scope="module", autouse=True) +async def _register_stream_output_fn(): + + @register_function(config_type=StreamOutputConfig) + async def register(config: StreamOutputConfig, b: Builder): + + async def _inner_stream(message: str) -> AsyncGenerator[str]: + yield message + "!" + + yield _inner_stream + + +@pytest.fixture(autouse=True) +def clean_dynamo_context(): + """Ensure DynamoPrefixContext is clean before and after each test.""" + DynamoPrefixContext.clear() + yield + DynamoPrefixContext.clear() + + +async def test_runner_result_sets_dynamo_prefix_context(): + """Test that Runner.result() sets DynamoPrefixContext with unique prefix ID.""" + captured_prefix_ids = [] + + @register_function(config_type=CaptureConfig) + async def _register(config: CaptureConfig, b: Builder): + + async def _capture(message: str) -> str: + # Capture the prefix ID during execution + prefix_id = DynamoPrefixContext.get() + captured_prefix_ids.append(prefix_id) + return message + + yield _capture + + async with WorkflowBuilder() as builder: + entry_fn = await builder.add_function(name="capture_fn", config=CaptureConfig()) + + context_state = ContextState() + exporter_manager = ExporterManager() + + async with Runner(input_message="test", + entry_fn=entry_fn, + context_state=context_state, + exporter_manager=exporter_manager) as runner: + await runner.result() + + # Verify prefix ID was set during execution + assert len(captured_prefix_ids) == 1 + assert captured_prefix_ids[0] is not None + assert captured_prefix_ids[0].startswith("nat-workflow-") + + +async def test_runner_result_clears_dynamo_prefix_context_after_completion(): + """Test that Runner.result() clears DynamoPrefixContext after workflow completes.""" + async with WorkflowBuilder() as builder: + entry_fn = await builder.add_function(name="test_fn", config=SingleOutputConfig()) + + context_state = ContextState() + exporter_manager = ExporterManager() + + async with Runner(input_message="test", + entry_fn=entry_fn, + context_state=context_state, + exporter_manager=exporter_manager) as runner: + await runner.result() + + # Verify prefix ID is cleared after execution + assert DynamoPrefixContext.get() is None + + +async def test_runner_result_clears_dynamo_prefix_context_on_error(): + """Test that Runner.result() clears DynamoPrefixContext even when workflow fails.""" + + class ErrorConfig(FunctionBaseConfig, name="error_dynamo_test"): + pass + + @register_function(config_type=ErrorConfig) + async def _register(config: ErrorConfig, b: Builder): + + async def _error(message: str) -> str: + raise ValueError("Simulated error") + + yield _error + + async with WorkflowBuilder() as builder: + entry_fn = await builder.add_function(name="error_fn", config=ErrorConfig()) + + context_state = ContextState() + exporter_manager = ExporterManager() + + async with Runner(input_message="test", + entry_fn=entry_fn, + context_state=context_state, + exporter_manager=exporter_manager) as runner: + with pytest.raises(ValueError, match="Simulated error"): + await runner.result() + + # Verify prefix ID is cleared even after error + assert DynamoPrefixContext.get() is None + + +async def test_runner_result_different_invocations_get_unique_prefix_ids(): + """Test that different workflow invocations get unique prefix IDs.""" + captured_prefix_ids = [] + + class CaptureConfig2(FunctionBaseConfig, name="capture_dynamo_prefix2"): + pass + + @register_function(config_type=CaptureConfig2) + async def _register(config: CaptureConfig2, b: Builder): + + async def _capture(message: str) -> str: + prefix_id = DynamoPrefixContext.get() + captured_prefix_ids.append(prefix_id) + return message + + yield _capture + + async with WorkflowBuilder() as builder: + entry_fn = await builder.add_function(name="capture_fn", config=CaptureConfig2()) + + context_state = ContextState() + exporter_manager = ExporterManager() + + # Run workflow multiple times + for i in range(3): + async with Runner(input_message=f"test{i}", + entry_fn=entry_fn, + context_state=context_state, + exporter_manager=exporter_manager) as runner: + await runner.result() + + # Each invocation should have a unique prefix ID + assert len(captured_prefix_ids) == 3 + assert len(set(captured_prefix_ids)) == 3 # All unique + + +async def test_runner_result_stream_sets_dynamo_prefix_context(): + """Test that Runner.result_stream() sets DynamoPrefixContext with unique prefix ID.""" + captured_prefix_ids = [] + + class StreamCaptureConfig(FunctionBaseConfig, name="stream_capture_dynamo"): + pass + + @register_function(config_type=StreamCaptureConfig) + async def _register(config: StreamCaptureConfig, b: Builder): + + async def _capture_stream(message: str) -> AsyncGenerator[str]: + prefix_id = DynamoPrefixContext.get() + captured_prefix_ids.append(prefix_id) + yield message + + yield _capture_stream + + async with WorkflowBuilder() as builder: + entry_fn = await builder.add_function(name="stream_capture_fn", config=StreamCaptureConfig()) + + context_state = ContextState() + exporter_manager = ExporterManager() + + async with Runner(input_message="test", + entry_fn=entry_fn, + context_state=context_state, + exporter_manager=exporter_manager) as runner: + async for _ in runner.result_stream(): + pass + + # Verify prefix ID was set during execution + assert len(captured_prefix_ids) == 1 + assert captured_prefix_ids[0] is not None + assert captured_prefix_ids[0].startswith("nat-workflow-") + + +async def test_runner_result_stream_clears_dynamo_prefix_context_after_completion(): + """Test that Runner.result_stream() clears DynamoPrefixContext after workflow completes.""" + async with WorkflowBuilder() as builder: + entry_fn = await builder.add_function(name="stream_fn", config=StreamOutputConfig()) + + context_state = ContextState() + exporter_manager = ExporterManager() + + async with Runner(input_message="test", + entry_fn=entry_fn, + context_state=context_state, + exporter_manager=exporter_manager) as runner: + async for _ in runner.result_stream(): + pass + + # Verify prefix ID is cleared after execution + assert DynamoPrefixContext.get() is None + + +async def test_runner_result_stream_clears_dynamo_prefix_context_on_error(): + """Test that Runner.result_stream() clears DynamoPrefixContext even when workflow fails.""" + + class StreamErrorConfig(FunctionBaseConfig, name="stream_error_dynamo_test"): + pass + + @register_function(config_type=StreamErrorConfig) + async def _register(config: StreamErrorConfig, b: Builder): + + async def _error_stream(message: str) -> AsyncGenerator[str]: + raise ValueError("Simulated stream error") + yield message # Make it a generator + + yield _error_stream + + async with WorkflowBuilder() as builder: + entry_fn = await builder.add_function(name="stream_error_fn", config=StreamErrorConfig()) + + context_state = ContextState() + exporter_manager = ExporterManager() + + async with Runner(input_message="test", + entry_fn=entry_fn, + context_state=context_state, + exporter_manager=exporter_manager) as runner: + with pytest.raises(ValueError, match="Simulated stream error"): + async for _ in runner.result_stream(): + pass + + # Verify prefix ID is cleared even after error + assert DynamoPrefixContext.get() is None + + +async def test_runner_prefix_id_based_on_workflow_run_id(): + """Test that the prefix ID is based on the workflow_run_id.""" + captured_prefix_id = None + + class PrefixCheckConfig(FunctionBaseConfig, name="prefix_check_dynamo"): + pass + + @register_function(config_type=PrefixCheckConfig) + async def _register(config: PrefixCheckConfig, b: Builder): + + async def _check(message: str) -> str: + nonlocal captured_prefix_id + captured_prefix_id = DynamoPrefixContext.get() + return message + + yield _check + + async with WorkflowBuilder() as builder: + entry_fn = await builder.add_function(name="prefix_check_fn", config=PrefixCheckConfig()) + + context_state = ContextState() + exporter_manager = ExporterManager() + + async with Runner(input_message="test", + entry_fn=entry_fn, + context_state=context_state, + exporter_manager=exporter_manager) as runner: + await runner.result() + + # The prefix ID should be in the expected format + assert captured_prefix_id is not None + assert captured_prefix_id.startswith("nat-workflow-") + # Verify the UUID portion is valid (36 chars with hyphens) + uuid_part = captured_prefix_id[len("nat-workflow-"):] + assert len(uuid_part) == 36 + + +async def test_runner_uses_existing_workflow_run_id_for_prefix(): + """Test that Runner uses existing workflow_run_id when set externally.""" + captured_prefix_id = None + preset_run_id = "preset-external-run-id-12345" + + class ExternalIdConfig(FunctionBaseConfig, name="external_id_dynamo"): + pass + + @register_function(config_type=ExternalIdConfig) + async def _register(config: ExternalIdConfig, b: Builder): + + async def _check(message: str) -> str: + nonlocal captured_prefix_id + captured_prefix_id = DynamoPrefixContext.get() + return message + + yield _check + + async with WorkflowBuilder() as builder: + entry_fn = await builder.add_function(name="external_id_fn", config=ExternalIdConfig()) + + context_state = ContextState() + exporter_manager = ExporterManager() + + # Pre-set the workflow_run_id + token = context_state.workflow_run_id.set(preset_run_id) + try: + async with Runner(input_message="test", + entry_fn=entry_fn, + context_state=context_state, + exporter_manager=exporter_manager) as runner: + await runner.result() + finally: + context_state.workflow_run_id.reset(token) + + # The prefix ID should use the pre-set workflow_run_id + assert captured_prefix_id == f"nat-workflow-{preset_run_id}" From 360288dd4c55f9bc8a892a861d1ee6d7c4204da2 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Fri, 23 Jan 2026 15:32:25 -0800 Subject: [PATCH 03/37] Add design doc for prediction trie inference routing Design for a prediction system that provides Dynamo inference server with expected workload characteristics (remaining calls, inter-arrival time, output length) for each LLM call, enabling smarter routing. Key components: - PredictionTrie: hierarchical structure storing metrics at every path granularity - TrieBuilder: processes profiler traces into trie - Runtime lookup with graceful fallback to less specific matches - Header injection in dynamo_langchain LLM client Co-Authored-By: Claude Opus 4.5 Signed-off-by: dnandakumar-nv --- .../2026-01-23-prediction-trie-design.md | 342 ++++++++++++++++++ 1 file changed, 342 insertions(+) create mode 100644 docs/plans/2026-01-23-prediction-trie-design.md diff --git a/docs/plans/2026-01-23-prediction-trie-design.md b/docs/plans/2026-01-23-prediction-trie-design.md new file mode 100644 index 0000000000..ad14bbdcaf --- /dev/null +++ b/docs/plans/2026-01-23-prediction-trie-design.md @@ -0,0 +1,342 @@ +# Prediction Trie for Dynamo Inference Routing + +**Date:** 2026-01-23 +**Status:** Approved +**Author:** Design session with Claude + +## Overview + +A prediction system that provides the Dynamo inference server with expected workload characteristics for each LLM call—remaining calls, inter-arrival time, and expected output length—enabling smarter routing decisions. + +## Problem + +The Dynamo inference server can make better routing decisions if it knows: +- How many more LLM calls are expected in this workflow +- When the next LLM call will arrive +- How long the response will be + +Currently, each LLM request arrives without this context. The server treats each call independently, missing optimization opportunities. + +## Solution + +Build a prediction trie from profiler data that captures LLM call patterns at multiple granularities. At runtime, inject predictions as HTTP headers on inference requests. + +### End-to-End Flow + +``` +┌─────────────────────────────────────────────────────────────────────────────┐ +│ PROFILING PHASE │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ 1. Run profiler on workflow with representative inputs │ +│ 2. Collect IntermediateStep traces with full ancestry │ +│ 3. Build PredictionTrie from LLM_END events │ +│ 4. Serialize to prediction_trie.json │ +└─────────────────────────────────────────────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────────────────────┐ +│ RUNTIME PHASE │ +├─────────────────────────────────────────────────────────────────────────────┤ +│ 1. LLM client loads prediction_trie.json at startup │ +│ 2. On each LLM call: │ +│ a. Get current function path from context │ +│ b. Increment and get call_index from tracker │ +│ c. Lookup prediction in trie │ +│ d. Inject headers into request │ +│ 3. Dynamo server uses headers for routing decisions │ +└─────────────────────────────────────────────────────────────────────────────┘ +``` + +## Data Structures + +### Prediction Metrics + +```python +@dataclass +class PredictionMetrics: + """Stats for a single metric, pre-computed from profiler data.""" + sample_count: int + mean: float + p50: float + p90: float + p95: float +``` + +### LLM Call Prediction + +```python +@dataclass +class LLMCallPrediction: + """What we predict for an LLM call at a given position.""" + remaining_calls: PredictionMetrics # How many more LLM calls expected + interarrival_ms: PredictionMetrics # Time until next LLM call + output_tokens: PredictionMetrics # Expected output length +``` + +### Prediction Trie Node + +```python +@dataclass +class PredictionTrieNode: + """A node in the prediction trie.""" + name: str # Function name at this level + children: dict[str, PredictionTrieNode] # Child nodes by function name + predictions_by_call_index: dict[int, LLMCallPrediction] # Metrics keyed by call index + predictions_any_index: LLMCallPrediction | None # Fallback: aggregated across all indices +``` + +### Trie Structure Example + +``` +root +├── workflow (stats: all LLM calls in any workflow) +│ └── react_agent (stats: all LLM calls under react_agent) +│ ├── search_tool (stats: LLM calls under search_tool) +│ │ └── llm:1 (stats: first LLM call in search_tool) +│ │ └── llm:2 (stats: second LLM call) +│ └── calculator_tool +│ └── llm:1 (stats: first LLM call in calculator_tool) +``` + +## Building the Trie + +### LLM Call Context Extraction + +For each `LLM_END` event in a profiler trace: + +```python +@dataclass +class LLMCallContext: + path: list[str] # ["workflow", "react_agent", "search_tool"] + call_index: int # Nth LLM call within the immediate parent + remaining_calls: int # How many LLM calls left in this workflow run + time_to_next_ms: float # Milliseconds until next LLM_START (or None if last) + output_tokens: int # Actual completion tokens +``` + +### Call Index Scoping + +Call index is scoped to the immediate parent function: + +``` +workflow (run_id=1) + └── react_agent (invocation_id=a1) + ├── LLM call (call_index=1 within react_agent) + ├── search_tool + │ └── LLM call (call_index=1 within search_tool) + └── LLM call (call_index=2 within react_agent) +``` + +### Trie Update Algorithm + +For each LLM call, walk its ancestry path and update every node: + +```python +def update_trie(root: PredictionTrieNode, ctx: LLMCallContext): + node = root + # Walk path, updating aggregates at each level + for func_name in ctx.path: + node.add_sample(ctx.call_index, ctx.remaining_calls, ctx.time_to_next_ms, ctx.output_tokens) + node = node.children.setdefault(func_name, PredictionTrieNode(func_name)) + # Update leaf node too + node.add_sample(ctx.call_index, ctx.remaining_calls, ctx.time_to_next_ms, ctx.output_tokens) +``` + +This means a single LLM call contributes samples to every ancestor node—giving us aggregated stats at every granularity automatically. + +## Runtime Lookup + +### Current Context + +```python +@dataclass +class CurrentContext: + path: list[str] # Current function ancestry + call_index: int # Which LLM call this is within the immediate parent +``` + +### Lookup Algorithm + +```python +def lookup(root: PredictionTrieNode, ctx: CurrentContext) -> LLMCallPrediction | None: + node = root + deepest_match = None + + # Walk the trie as far as we can match + for func_name in ctx.path: + # Capture this node as a potential match before descending + prediction = node.predictions_by_call_index.get(ctx.call_index) + if prediction is None: + prediction = node.predictions_any_index + if prediction is not None: + deepest_match = prediction + + # Try to descend + if func_name not in node.children: + break + node = node.children[func_name] + + # Check the final node we reached + prediction = node.predictions_by_call_index.get(ctx.call_index) + if prediction is None: + prediction = node.predictions_any_index + if prediction is not None: + deepest_match = prediction + + return deepest_match +``` + +### Fallback Behavior + +1. Try exact path + exact call index (most specific) +2. Try exact path + any call index +3. Try partial path + exact call index +4. Try partial path + any call index (most general) + +Novel tool calls automatically get predictions based on agent-level stats. + +## Runtime Call Index Tracking + +```python +from contextvars import ContextVar + +@dataclass +class LLMCallTracker: + """Tracks LLM call counts per function invocation.""" + counts: dict[str, int] = field(default_factory=dict) + + def increment(self, parent_function_id: str) -> int: + """Increment and return the call index for this parent.""" + self.counts[parent_function_id] = self.counts.get(parent_function_id, 0) + 1 + return self.counts[parent_function_id] + + def reset(self, parent_function_id: str): + """Reset when a function invocation completes.""" + self.counts.pop(parent_function_id, None) + +_llm_call_tracker: ContextVar[LLMCallTracker] = ContextVar('llm_call_tracker') +``` + +## Header Injection + +### Headers + +``` +X-NAT-Remaining-LLM-Calls: 3 +X-NAT-Interarrival-Ms: 450 +X-NAT-Expected-Output-Tokens: 256 +X-NAT-Prediction-Confidence: 0.85 +``` + +### Integration Point + +```python +class DynamoLangChainLLM(BaseLLM): + prediction_trie: PredictionTrie | None = None + + def _call(self, prompt: str, **kwargs) -> str: + headers = self._get_base_headers() + + if self.prediction_trie is not None: + ctx = self._get_current_context() + prediction = self.prediction_trie.lookup(ctx) + if prediction: + headers["X-NAT-Remaining-LLM-Calls"] = str(prediction.remaining_calls.mean) + headers["X-NAT-Interarrival-Ms"] = str(prediction.interarrival_ms.mean) + headers["X-NAT-Expected-Output-Tokens"] = str(prediction.output_tokens.p90) + + return self._make_request(prompt, headers=headers, **kwargs) +``` + +### Configuration + +```yaml +llms: + my_llm: + _type: nim + model_name: meta/llama-3.1-70b-instruct + prediction_trie_path: ./profiler_output/prediction_trie.json +``` + +## Serialization + +### JSON Format + +```json +{ + "version": "1.0", + "generated_at": "2026-01-23T10:30:00Z", + "workflow_name": "my_workflow", + "sample_count": 150, + "root": { + "name": "root", + "predictions_by_call_index": { + "1": { + "remaining_calls": {"sample_count": 150, "mean": 4.2, "p50": 4, "p90": 6, "p95": 7}, + "interarrival_ms": {"sample_count": 150, "mean": 520, "p50": 480, "p90": 890, "p95": 1100}, + "output_tokens": {"sample_count": 150, "mean": 185, "p50": 160, "p90": 320, "p95": 410} + } + }, + "predictions_any_index": { ... }, + "children": { + "react_agent": { ... } + } + } +} +``` + +### Output Files + +``` +profiler_output/ +├── all_requests_profiler_traces.json +├── standardized_data_all.csv +├── inference_optimization.json +├── prediction_trie.json # NEW +└── prediction_trie_summary.txt # NEW: human-readable summary +``` + +## File Organization + +``` +src/nat/profiler/ +├── prediction_trie/ +│ ├── __init__.py +│ ├── data_models.py # PredictionTrieNode, LLMCallPrediction, PredictionMetrics +│ ├── trie_builder.py # Build trie from profiler traces +│ ├── trie_lookup.py # Lookup algorithm +│ └── serialization.py # JSON load/save + +src/nat/llm/ +├── prediction_context.py # LLMCallTracker, context variable, path extraction + +packages/nvidia_nat_langchain/src/nat/plugins/langchain/ +├── llm.py # Modify to inject headers +``` + +## Profiler Configuration + +```yaml +profiler: + base_metrics: true + prediction_trie: true + prediction_trie_output: ./prediction_trie.json +``` + +## Implementation Sequence + +1. **Data models** - `PredictionTrieNode`, `LLMCallPrediction`, `PredictionMetrics` +2. **Trie builder** - Parse profiler traces, extract LLM call contexts, build trie +3. **Serialization** - JSON save/load for the trie +4. **Trie lookup** - Walk trie, return deepest match with fallback +5. **Runtime tracking** - `LLMCallTracker` context variable, integrate with existing ancestry tracking +6. **Header injection** - Modify `dynamo_langchain` LLM client to inject headers +7. **Profiler integration** - Add config option, wire trie builder into profiler output +8. **Tests** - Unit tests for trie operations, integration test with sample traces + +## Out of Scope + +- Concurrency/parallelism tracking +- Input token bucketing for lookup +- Real-time trie updates during runtime +- Multiple trie versions/A-B testing From 80df59fb844889544dbf7f31bc764b8294b1ac27 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Fri, 23 Jan 2026 15:41:15 -0800 Subject: [PATCH 04/37] docs: add prediction trie implementation plan Detailed 10-task TDD implementation plan: 1. Data models (PredictionMetrics, LLMCallPrediction, PredictionTrieNode) 2. Metrics accumulator for computing statistics 3. Trie builder from profiler traces 4. Trie lookup with fallback 5. JSON serialization 6. Runtime call tracker (contextvars) 7. Profiler integration (config + generation) 8. Dynamo header injection 9. LangChain integration 10. End-to-end test Co-Authored-By: Claude Opus 4.5 Signed-off-by: dnandakumar-nv --- ...26-01-23-prediction-trie-implementation.md | 1840 +++++++++++++++++ 1 file changed, 1840 insertions(+) create mode 100644 docs/plans/2026-01-23-prediction-trie-implementation.md diff --git a/docs/plans/2026-01-23-prediction-trie-implementation.md b/docs/plans/2026-01-23-prediction-trie-implementation.md new file mode 100644 index 0000000000..d8e11cc5de --- /dev/null +++ b/docs/plans/2026-01-23-prediction-trie-implementation.md @@ -0,0 +1,1840 @@ +# Prediction Trie Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Build a prediction trie that aggregates LLM call patterns from profiler data and injects routing hints as Dynamo headers at runtime. + +**Architecture:** The profiler builds a trie from execution traces where each node stores aggregated metrics (remaining calls, interarrival time, output tokens) by call index. At runtime, the Dynamo LLM client walks the trie to find the best match for the current execution context and injects predictions as HTTP headers. + +**Tech Stack:** Python 3.11+, Pydantic v2, httpx event hooks, contextvars + +--- + +## Task 1: Data Models + +**Files:** +- Create: `src/nat/profiler/prediction_trie/data_models.py` +- Test: `tests/nat/profiler/prediction_trie/test_data_models.py` + +### Step 1: Write the failing test for PredictionMetrics + +```python +# tests/nat/profiler/prediction_trie/test_data_models.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from nat.profiler.prediction_trie.data_models import PredictionMetrics + + +def test_prediction_metrics_creation(): + metrics = PredictionMetrics(sample_count=10, mean=5.0, p50=4.5, p90=8.0, p95=9.0) + assert metrics.sample_count == 10 + assert metrics.mean == 5.0 + assert metrics.p50 == 4.5 + assert metrics.p90 == 8.0 + assert metrics.p95 == 9.0 + + +def test_prediction_metrics_defaults(): + metrics = PredictionMetrics() + assert metrics.sample_count == 0 + assert metrics.mean == 0.0 +``` + +### Step 2: Run test to verify it fails + +Run: `pytest tests/nat/profiler/prediction_trie/test_data_models.py::test_prediction_metrics_creation -v` +Expected: FAIL with "ModuleNotFoundError: No module named 'nat.profiler.prediction_trie'" + +### Step 3: Create the prediction_trie package and data models + +```python +# src/nat/profiler/prediction_trie/__init__.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode + +__all__ = ["PredictionMetrics", "LLMCallPrediction", "PredictionTrieNode"] +``` + +```python +# src/nat/profiler/prediction_trie/data_models.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from pydantic import BaseModel +from pydantic import Field + + +class PredictionMetrics(BaseModel): + """Aggregated statistics for a single metric from profiler data.""" + + sample_count: int = Field(default=0, description="Number of samples") + mean: float = Field(default=0.0, description="Mean value") + p50: float = Field(default=0.0, description="50th percentile (median)") + p90: float = Field(default=0.0, description="90th percentile") + p95: float = Field(default=0.0, description="95th percentile") + + +class LLMCallPrediction(BaseModel): + """Predictions for an LLM call at a given position in the call hierarchy.""" + + remaining_calls: PredictionMetrics = Field( + default_factory=PredictionMetrics, + description="How many more LLM calls are expected after this one", + ) + interarrival_ms: PredictionMetrics = Field( + default_factory=PredictionMetrics, + description="Expected time in milliseconds until the next LLM call", + ) + output_tokens: PredictionMetrics = Field( + default_factory=PredictionMetrics, + description="Expected output token count for this call", + ) + + +class PredictionTrieNode(BaseModel): + """A node in the prediction trie representing a function in the call hierarchy.""" + + name: str = Field(description="Function name at this level in the hierarchy") + children: dict[str, PredictionTrieNode] = Field( + default_factory=dict, + description="Child nodes keyed by function name", + ) + predictions_by_call_index: dict[int, LLMCallPrediction] = Field( + default_factory=dict, + description="Predictions keyed by call index (1-indexed)", + ) + predictions_any_index: LLMCallPrediction | None = Field( + default=None, + description="Fallback predictions aggregated across all call indices", + ) + + +# Rebuild model to handle forward references +PredictionTrieNode.model_rebuild() +``` + +### Step 4: Run test to verify it passes + +Run: `pytest tests/nat/profiler/prediction_trie/test_data_models.py -v` +Expected: PASS + +### Step 5: Add tests for LLMCallPrediction and PredictionTrieNode + +Add to `tests/nat/profiler/prediction_trie/test_data_models.py`: + +```python +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionTrieNode + + +def test_llm_call_prediction_creation(): + prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=5, mean=3.0, p50=3.0, p90=5.0, p95=6.0), + interarrival_ms=PredictionMetrics(sample_count=5, mean=500.0, p50=450.0, p90=800.0, p95=900.0), + output_tokens=PredictionMetrics(sample_count=5, mean=150.0, p50=140.0, p90=250.0, p95=300.0), + ) + assert prediction.remaining_calls.mean == 3.0 + assert prediction.interarrival_ms.mean == 500.0 + assert prediction.output_tokens.mean == 150.0 + + +def test_prediction_trie_node_creation(): + node = PredictionTrieNode(name="root") + assert node.name == "root" + assert node.children == {} + assert node.predictions_by_call_index == {} + assert node.predictions_any_index is None + + +def test_prediction_trie_node_with_children(): + child = PredictionTrieNode(name="react_agent") + root = PredictionTrieNode(name="root", children={"react_agent": child}) + assert "react_agent" in root.children + assert root.children["react_agent"].name == "react_agent" + + +def test_prediction_trie_node_with_predictions(): + prediction = LLMCallPrediction() + node = PredictionTrieNode( + name="agent", + predictions_by_call_index={1: prediction, 2: prediction}, + predictions_any_index=prediction, + ) + assert 1 in node.predictions_by_call_index + assert 2 in node.predictions_by_call_index + assert node.predictions_any_index is not None +``` + +### Step 6: Run all data model tests + +Run: `pytest tests/nat/profiler/prediction_trie/test_data_models.py -v` +Expected: PASS (all tests) + +### Step 7: Commit + +```bash +git add src/nat/profiler/prediction_trie/ tests/nat/profiler/prediction_trie/ +git commit --signoff -m "feat(profiler): add prediction trie data models + +Add Pydantic models for the prediction trie: +- PredictionMetrics: aggregated stats (mean, p50, p90, p95) +- LLMCallPrediction: predictions for remaining calls, interarrival time, output tokens +- PredictionTrieNode: trie node with children and predictions by call index" +``` + +--- + +## Task 2: Metrics Accumulator + +**Files:** +- Create: `src/nat/profiler/prediction_trie/metrics_accumulator.py` +- Test: `tests/nat/profiler/prediction_trie/test_metrics_accumulator.py` + +### Step 1: Write the failing test for MetricsAccumulator + +```python +# tests/nat/profiler/prediction_trie/test_metrics_accumulator.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from nat.profiler.prediction_trie.metrics_accumulator import MetricsAccumulator + + +def test_accumulator_add_single_sample(): + acc = MetricsAccumulator() + acc.add_sample(10.0) + metrics = acc.compute_metrics() + assert metrics.sample_count == 1 + assert metrics.mean == 10.0 + assert metrics.p50 == 10.0 + assert metrics.p90 == 10.0 + assert metrics.p95 == 10.0 + + +def test_accumulator_add_multiple_samples(): + acc = MetricsAccumulator() + for v in [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]: + acc.add_sample(v) + metrics = acc.compute_metrics() + assert metrics.sample_count == 10 + assert metrics.mean == 5.5 + assert metrics.p50 == 5.5 # median of 1-10 + assert metrics.p90 == 9.1 # 90th percentile + assert metrics.p95 == 9.55 # 95th percentile + + +def test_accumulator_empty(): + acc = MetricsAccumulator() + metrics = acc.compute_metrics() + assert metrics.sample_count == 0 + assert metrics.mean == 0.0 +``` + +### Step 2: Run test to verify it fails + +Run: `pytest tests/nat/profiler/prediction_trie/test_metrics_accumulator.py::test_accumulator_add_single_sample -v` +Expected: FAIL with "ModuleNotFoundError" + +### Step 3: Implement MetricsAccumulator + +```python +# src/nat/profiler/prediction_trie/metrics_accumulator.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import math + +from nat.profiler.prediction_trie.data_models import PredictionMetrics + + +class MetricsAccumulator: + """Accumulates samples and computes aggregated statistics.""" + + def __init__(self) -> None: + self._samples: list[float] = [] + + def add_sample(self, value: float) -> None: + """Add a sample value to the accumulator.""" + self._samples.append(value) + + def compute_metrics(self) -> PredictionMetrics: + """Compute aggregated metrics from accumulated samples.""" + if not self._samples: + return PredictionMetrics() + + n = len(self._samples) + mean_val = sum(self._samples) / n + sorted_samples = sorted(self._samples) + + return PredictionMetrics( + sample_count=n, + mean=mean_val, + p50=self._percentile(sorted_samples, 50), + p90=self._percentile(sorted_samples, 90), + p95=self._percentile(sorted_samples, 95), + ) + + @staticmethod + def _percentile(sorted_data: list[float], pct: float) -> float: + """Compute percentile using linear interpolation.""" + if not sorted_data: + return 0.0 + if len(sorted_data) == 1: + return sorted_data[0] + k = (len(sorted_data) - 1) * (pct / 100.0) + f = math.floor(k) + c = math.ceil(k) + if f == c: + return sorted_data[int(k)] + return sorted_data[f] + (sorted_data[c] - sorted_data[f]) * (k - f) +``` + +### Step 4: Run tests to verify they pass + +Run: `pytest tests/nat/profiler/prediction_trie/test_metrics_accumulator.py -v` +Expected: PASS + +### Step 5: Commit + +```bash +git add src/nat/profiler/prediction_trie/metrics_accumulator.py tests/nat/profiler/prediction_trie/test_metrics_accumulator.py +git commit --signoff -m "feat(profiler): add MetricsAccumulator for prediction trie + +Accumulates sample values and computes aggregated statistics +(mean, p50, p90, p95) using linear interpolation for percentiles." +``` + +--- + +## Task 3: Trie Builder + +**Files:** +- Create: `src/nat/profiler/prediction_trie/trie_builder.py` +- Test: `tests/nat/profiler/prediction_trie/test_trie_builder.py` + +### Step 1: Write the failing test for TrieBuilder + +```python +# tests/nat/profiler/prediction_trie/test_trie_builder.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from nat.data_models.intermediate_step import IntermediateStep +from nat.data_models.intermediate_step import IntermediateStepPayload +from nat.data_models.intermediate_step import IntermediateStepType +from nat.data_models.intermediate_step import UsageInfo +from nat.data_models.invocation_node import InvocationNode +from nat.profiler.callbacks.token_usage_base_model import TokenUsageBaseModel +from nat.profiler.prediction_trie.trie_builder import PredictionTrieBuilder + + +@pytest.fixture(name="simple_trace") +def fixture_simple_trace() -> list[IntermediateStep]: + """Create a simple trace with two LLM calls.""" + return [ + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id="workflow-1", + function_name="my_workflow", + parent_id=None, + parent_name=None, + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_START, + event_timestamp=1000.0, + UUID="llm-1", + ), + ), + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id="workflow-1", + function_name="my_workflow", + parent_id=None, + parent_name=None, + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_END, + event_timestamp=1001.0, + span_event_timestamp=1000.0, + UUID="llm-1", + usage_info=UsageInfo( + token_usage=TokenUsageBaseModel(completion_tokens=100), + ), + ), + ), + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id="workflow-1", + function_name="my_workflow", + parent_id=None, + parent_name=None, + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_START, + event_timestamp=1002.0, + UUID="llm-2", + ), + ), + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id="workflow-1", + function_name="my_workflow", + parent_id=None, + parent_name=None, + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_END, + event_timestamp=1003.0, + span_event_timestamp=1002.0, + UUID="llm-2", + usage_info=UsageInfo( + token_usage=TokenUsageBaseModel(completion_tokens=150), + ), + ), + ), + ] + + +def test_trie_builder_builds_from_single_trace(simple_trace): + builder = PredictionTrieBuilder() + builder.add_trace(simple_trace) + trie = builder.build() + + assert trie.name == "root" + assert "my_workflow" in trie.children + + workflow_node = trie.children["my_workflow"] + # First LLM call: call_index=1, remaining=1 + assert 1 in workflow_node.predictions_by_call_index + # Second LLM call: call_index=2, remaining=0 + assert 2 in workflow_node.predictions_by_call_index + + +def test_trie_builder_computes_remaining_calls(simple_trace): + builder = PredictionTrieBuilder() + builder.add_trace(simple_trace) + trie = builder.build() + + workflow_node = trie.children["my_workflow"] + # First call should predict 1 remaining call + assert workflow_node.predictions_by_call_index[1].remaining_calls.mean == 1.0 + # Second call should predict 0 remaining calls + assert workflow_node.predictions_by_call_index[2].remaining_calls.mean == 0.0 + + +def test_trie_builder_computes_output_tokens(simple_trace): + builder = PredictionTrieBuilder() + builder.add_trace(simple_trace) + trie = builder.build() + + workflow_node = trie.children["my_workflow"] + # First call had 100 completion tokens + assert workflow_node.predictions_by_call_index[1].output_tokens.mean == 100.0 + # Second call had 150 completion tokens + assert workflow_node.predictions_by_call_index[2].output_tokens.mean == 150.0 +``` + +### Step 2: Run test to verify it fails + +Run: `pytest tests/nat/profiler/prediction_trie/test_trie_builder.py::test_trie_builder_builds_from_single_trace -v` +Expected: FAIL with "ModuleNotFoundError" + +### Step 3: Implement PredictionTrieBuilder + +```python +# src/nat/profiler/prediction_trie/trie_builder.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass +from dataclasses import field + +from nat.data_models.intermediate_step import IntermediateStep +from nat.data_models.intermediate_step import IntermediateStepType +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionTrieNode +from nat.profiler.prediction_trie.metrics_accumulator import MetricsAccumulator + + +@dataclass +class LLMCallContext: + """Context for a single LLM call extracted from a trace.""" + + path: list[str] + call_index: int + remaining_calls: int + time_to_next_ms: float | None + output_tokens: int + + +@dataclass +class _NodeAccumulators: + """Accumulators for a single trie node.""" + + remaining_calls: dict[int, MetricsAccumulator] = field(default_factory=lambda: defaultdict(MetricsAccumulator)) + interarrival_ms: dict[int, MetricsAccumulator] = field(default_factory=lambda: defaultdict(MetricsAccumulator)) + output_tokens: dict[int, MetricsAccumulator] = field(default_factory=lambda: defaultdict(MetricsAccumulator)) + # For aggregated stats across all call indices + all_remaining_calls: MetricsAccumulator = field(default_factory=MetricsAccumulator) + all_interarrival_ms: MetricsAccumulator = field(default_factory=MetricsAccumulator) + all_output_tokens: MetricsAccumulator = field(default_factory=MetricsAccumulator) + + +class PredictionTrieBuilder: + """Builds a prediction trie from profiler execution traces.""" + + def __init__(self) -> None: + # Map from path tuple to accumulators + self._node_accumulators: dict[tuple[str, ...], _NodeAccumulators] = defaultdict(_NodeAccumulators) + + def add_trace(self, steps: list[IntermediateStep]) -> None: + """Process a single execution trace and update accumulators.""" + contexts = self._extract_llm_contexts(steps) + for ctx in contexts: + self._update_accumulators(ctx) + + def _extract_llm_contexts(self, steps: list[IntermediateStep]) -> list[LLMCallContext]: + """Extract LLM call contexts from a trace.""" + # Sort steps by timestamp + sorted_steps = sorted(steps, key=lambda s: s.event_timestamp) + + # Find all LLM_END events + llm_ends: list[IntermediateStep] = [] + for step in sorted_steps: + if step.event_type == IntermediateStepType.LLM_END: + llm_ends.append(step) + + # Find all LLM_START events for interarrival time calculation + llm_starts: list[IntermediateStep] = [] + for step in sorted_steps: + if step.event_type == IntermediateStepType.LLM_START: + llm_starts.append(step) + + # Track call index per parent function + call_counts: dict[str, int] = defaultdict(int) + contexts: list[LLMCallContext] = [] + + for i, end_step in enumerate(llm_ends): + # Build path from function ancestry + path = self._build_path(end_step) + + # Determine call index within parent + parent_key = end_step.function_ancestry.function_id + call_counts[parent_key] += 1 + call_index = call_counts[parent_key] + + # Remaining calls in this trace + remaining = len(llm_ends) - i - 1 + + # Time to next LLM start (if any) + time_to_next_ms: float | None = None + if i + 1 < len(llm_starts): + next_start_time = llm_starts[i + 1].event_timestamp if i + 1 < len(llm_starts) else None + if next_start_time is not None: + time_to_next_ms = (next_start_time - end_step.event_timestamp) * 1000.0 + + # Output tokens + output_tokens = 0 + if end_step.usage_info and end_step.usage_info.token_usage: + output_tokens = end_step.usage_info.token_usage.completion_tokens or 0 + + contexts.append( + LLMCallContext( + path=path, + call_index=call_index, + remaining_calls=remaining, + time_to_next_ms=time_to_next_ms, + output_tokens=output_tokens, + ) + ) + + return contexts + + def _build_path(self, step: IntermediateStep) -> list[str]: + """Build the function path from ancestry.""" + path: list[str] = [] + ancestry = step.function_ancestry + + # Walk up the ancestry chain + if ancestry.parent_name: + path.append(ancestry.parent_name) + path.append(ancestry.function_name) + + return path + + def _update_accumulators(self, ctx: LLMCallContext) -> None: + """Update accumulators at every node along the path.""" + # Update root node + root_key: tuple[str, ...] = () + self._add_to_accumulators(root_key, ctx) + + # Update each node along the path + for i in range(len(ctx.path)): + path_key = tuple(ctx.path[: i + 1]) + self._add_to_accumulators(path_key, ctx) + + def _add_to_accumulators(self, path_key: tuple[str, ...], ctx: LLMCallContext) -> None: + """Add context data to accumulators for a specific path.""" + accs = self._node_accumulators[path_key] + + # By call index + accs.remaining_calls[ctx.call_index].add_sample(float(ctx.remaining_calls)) + accs.output_tokens[ctx.call_index].add_sample(float(ctx.output_tokens)) + if ctx.time_to_next_ms is not None: + accs.interarrival_ms[ctx.call_index].add_sample(ctx.time_to_next_ms) + + # Aggregated across all indices + accs.all_remaining_calls.add_sample(float(ctx.remaining_calls)) + accs.all_output_tokens.add_sample(float(ctx.output_tokens)) + if ctx.time_to_next_ms is not None: + accs.all_interarrival_ms.add_sample(ctx.time_to_next_ms) + + def build(self) -> PredictionTrieNode: + """Build the final prediction trie from accumulated data.""" + root = PredictionTrieNode(name="root") + + for path_key, accs in self._node_accumulators.items(): + node = self._get_or_create_node(root, path_key) + self._populate_node_predictions(node, accs) + + return root + + def _get_or_create_node(self, root: PredictionTrieNode, path_key: tuple[str, ...]) -> PredictionTrieNode: + """Navigate to or create a node at the given path.""" + if not path_key: + return root + + current = root + for name in path_key: + if name not in current.children: + current.children[name] = PredictionTrieNode(name=name) + current = current.children[name] + return current + + def _populate_node_predictions(self, node: PredictionTrieNode, accs: _NodeAccumulators) -> None: + """Populate a node with computed predictions from accumulators.""" + # Predictions by call index + all_indices = set(accs.remaining_calls.keys()) | set(accs.interarrival_ms.keys()) | set(accs.output_tokens.keys()) + + for idx in all_indices: + prediction = LLMCallPrediction( + remaining_calls=accs.remaining_calls[idx].compute_metrics(), + interarrival_ms=accs.interarrival_ms[idx].compute_metrics(), + output_tokens=accs.output_tokens[idx].compute_metrics(), + ) + node.predictions_by_call_index[idx] = prediction + + # Aggregated predictions + if accs.all_remaining_calls._samples: + node.predictions_any_index = LLMCallPrediction( + remaining_calls=accs.all_remaining_calls.compute_metrics(), + interarrival_ms=accs.all_interarrival_ms.compute_metrics(), + output_tokens=accs.all_output_tokens.compute_metrics(), + ) +``` + +### Step 4: Run tests to verify they pass + +Run: `pytest tests/nat/profiler/prediction_trie/test_trie_builder.py -v` +Expected: PASS + +### Step 5: Add test for interarrival time + +Add to `tests/nat/profiler/prediction_trie/test_trie_builder.py`: + +```python +def test_trie_builder_computes_interarrival_time(simple_trace): + builder = PredictionTrieBuilder() + builder.add_trace(simple_trace) + trie = builder.build() + + workflow_node = trie.children["my_workflow"] + # First call: next LLM starts at 1002.0, this call ends at 1001.0 -> 1000ms + assert workflow_node.predictions_by_call_index[1].interarrival_ms.mean == 1000.0 +``` + +### Step 6: Run all builder tests + +Run: `pytest tests/nat/profiler/prediction_trie/test_trie_builder.py -v` +Expected: PASS + +### Step 7: Commit + +```bash +git add src/nat/profiler/prediction_trie/trie_builder.py tests/nat/profiler/prediction_trie/test_trie_builder.py +git commit --signoff -m "feat(profiler): add PredictionTrieBuilder + +Builds prediction trie from profiler execution traces: +- Extracts LLM call contexts (path, call index, remaining, interarrival, output tokens) +- Aggregates metrics at every node along the path +- Computes stats by call index and aggregated fallback" +``` + +--- + +## Task 4: Trie Lookup + +**Files:** +- Create: `src/nat/profiler/prediction_trie/trie_lookup.py` +- Test: `tests/nat/profiler/prediction_trie/test_trie_lookup.py` + +### Step 1: Write the failing test for lookup + +```python +# tests/nat/profiler/prediction_trie/test_trie_lookup.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode +from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup + + +@pytest.fixture(name="sample_trie") +def fixture_sample_trie() -> PredictionTrieNode: + """Create a sample trie for testing lookups.""" + prediction_1 = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + prediction_2 = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=2.0, p50=2.0, p90=3.0, p95=4.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=400.0, p50=380.0, p90=600.0, p95=700.0), + output_tokens=PredictionMetrics(sample_count=10, mean=200.0, p50=190.0, p90=280.0, p95=320.0), + ) + aggregated = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=20, mean=2.5, p50=2.5, p90=3.5, p95=4.5), + interarrival_ms=PredictionMetrics(sample_count=20, mean=450.0, p50=415.0, p90=650.0, p95=750.0), + output_tokens=PredictionMetrics(sample_count=20, mean=175.0, p50=165.0, p90=240.0, p95=285.0), + ) + + agent_node = PredictionTrieNode( + name="react_agent", + predictions_by_call_index={1: prediction_1, 2: prediction_2}, + predictions_any_index=aggregated, + ) + + workflow_node = PredictionTrieNode( + name="my_workflow", + children={"react_agent": agent_node}, + predictions_any_index=aggregated, + ) + + root = PredictionTrieNode( + name="root", + children={"my_workflow": workflow_node}, + predictions_any_index=aggregated, + ) + + return root + + +def test_lookup_exact_match(sample_trie): + lookup = PredictionTrieLookup(sample_trie) + result = lookup.find(path=["my_workflow", "react_agent"], call_index=1) + + assert result is not None + assert result.remaining_calls.mean == 3.0 + assert result.output_tokens.mean == 150.0 + + +def test_lookup_partial_path_match(sample_trie): + """When exact path doesn't exist, fall back to closest ancestor.""" + lookup = PredictionTrieLookup(sample_trie) + # "unknown_tool" doesn't exist, should fall back to react_agent's aggregated + result = lookup.find(path=["my_workflow", "react_agent", "unknown_tool"], call_index=1) + + assert result is not None + # Should get react_agent's call_index=1 prediction + assert result.remaining_calls.mean == 3.0 + + +def test_lookup_unknown_call_index_fallback(sample_trie): + """When call_index doesn't exist, fall back to aggregated.""" + lookup = PredictionTrieLookup(sample_trie) + result = lookup.find(path=["my_workflow", "react_agent"], call_index=99) + + assert result is not None + # Should fall back to predictions_any_index + assert result.remaining_calls.mean == 2.5 + + +def test_lookup_no_match_returns_root_aggregated(sample_trie): + """When nothing matches, return root's aggregated.""" + lookup = PredictionTrieLookup(sample_trie) + result = lookup.find(path=["completely_unknown"], call_index=1) + + assert result is not None + # Should return root's aggregated prediction + assert result.remaining_calls.mean == 2.5 +``` + +### Step 2: Run test to verify it fails + +Run: `pytest tests/nat/profiler/prediction_trie/test_trie_lookup.py::test_lookup_exact_match -v` +Expected: FAIL with "ModuleNotFoundError" + +### Step 3: Implement PredictionTrieLookup + +```python +# src/nat/profiler/prediction_trie/trie_lookup.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionTrieNode + + +class PredictionTrieLookup: + """Looks up predictions in a prediction trie with graceful fallback.""" + + def __init__(self, root: PredictionTrieNode) -> None: + self._root = root + + def find(self, path: list[str], call_index: int) -> LLMCallPrediction | None: + """ + Find the best matching prediction for the given path and call index. + + Walks the trie as far as possible along the path, then returns the deepest + match. Falls back to aggregated predictions when exact call_index isn't found. + + Args: + path: Function ancestry path (e.g., ["my_workflow", "react_agent"]) + call_index: The Nth LLM call within the current parent function + + Returns: + Best matching prediction, or None if trie is empty + """ + node = self._root + deepest_match: LLMCallPrediction | None = None + + # Check root node first + deepest_match = self._get_prediction(node, call_index) or deepest_match + + # Walk the trie as far as we can match + for func_name in path: + if func_name not in node.children: + break + node = node.children[func_name] + # Update deepest match at each level + match = self._get_prediction(node, call_index) + if match is not None: + deepest_match = match + + return deepest_match + + def _get_prediction(self, node: PredictionTrieNode, call_index: int) -> LLMCallPrediction | None: + """Get prediction from node, preferring exact call_index, falling back to aggregated.""" + if call_index in node.predictions_by_call_index: + return node.predictions_by_call_index[call_index] + return node.predictions_any_index +``` + +### Step 4: Run tests to verify they pass + +Run: `pytest tests/nat/profiler/prediction_trie/test_trie_lookup.py -v` +Expected: PASS + +### Step 5: Commit + +```bash +git add src/nat/profiler/prediction_trie/trie_lookup.py tests/nat/profiler/prediction_trie/test_trie_lookup.py +git commit --signoff -m "feat(profiler): add PredictionTrieLookup + +Walks the trie to find best matching prediction: +- Exact path + exact call_index (most specific) +- Partial path + exact call_index +- Falls back to aggregated predictions when call_index not found" +``` + +--- + +## Task 5: Serialization + +**Files:** +- Create: `src/nat/profiler/prediction_trie/serialization.py` +- Test: `tests/nat/profiler/prediction_trie/test_serialization.py` + +### Step 1: Write the failing test for serialization + +```python +# tests/nat/profiler/prediction_trie/test_serialization.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import json +import tempfile +from pathlib import Path + +import pytest + +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode +from nat.profiler.prediction_trie.serialization import load_prediction_trie +from nat.profiler.prediction_trie.serialization import save_prediction_trie + + +@pytest.fixture(name="sample_trie") +def fixture_sample_trie() -> PredictionTrieNode: + """Create a sample trie for testing serialization.""" + prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + + child = PredictionTrieNode( + name="react_agent", + predictions_by_call_index={1: prediction}, + predictions_any_index=prediction, + ) + + root = PredictionTrieNode( + name="root", + children={"react_agent": child}, + predictions_any_index=prediction, + ) + + return root + + +def test_save_and_load_trie(sample_trie): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "prediction_trie.json" + + save_prediction_trie(sample_trie, path, workflow_name="test_workflow") + + loaded = load_prediction_trie(path) + + assert loaded.name == "root" + assert "react_agent" in loaded.children + assert loaded.children["react_agent"].predictions_by_call_index[1].remaining_calls.mean == 3.0 + + +def test_saved_file_has_metadata(sample_trie): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "prediction_trie.json" + + save_prediction_trie(sample_trie, path, workflow_name="test_workflow") + + with open(path) as f: + data = json.load(f) + + assert data["version"] == "1.0" + assert data["workflow_name"] == "test_workflow" + assert "generated_at" in data + assert "root" in data +``` + +### Step 2: Run test to verify it fails + +Run: `pytest tests/nat/profiler/prediction_trie/test_serialization.py::test_save_and_load_trie -v` +Expected: FAIL with "ModuleNotFoundError" + +### Step 3: Implement serialization functions + +```python +# src/nat/profiler/prediction_trie/serialization.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import json +from datetime import datetime +from datetime import timezone +from pathlib import Path +from typing import Any + +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode + +CURRENT_VERSION = "1.0" + + +def save_prediction_trie( + trie: PredictionTrieNode, + path: Path, + workflow_name: str = "unknown", +) -> None: + """ + Save a prediction trie to a JSON file. + + Args: + trie: The prediction trie root node + path: Path to save the JSON file + workflow_name: Name of the workflow this trie was built from + """ + data = { + "version": CURRENT_VERSION, + "generated_at": datetime.now(timezone.utc).isoformat(), + "workflow_name": workflow_name, + "root": _serialize_node(trie), + } + + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + + +def load_prediction_trie(path: Path) -> PredictionTrieNode: + """ + Load a prediction trie from a JSON file. + + Args: + path: Path to the JSON file + + Returns: + The deserialized prediction trie root node + """ + with open(path, encoding="utf-8") as f: + data = json.load(f) + + return _deserialize_node(data["root"]) + + +def _serialize_node(node: PredictionTrieNode) -> dict[str, Any]: + """Serialize a trie node to a dictionary.""" + result: dict[str, Any] = { + "name": node.name, + "predictions_by_call_index": { + str(k): v.model_dump() for k, v in node.predictions_by_call_index.items() + }, + "predictions_any_index": node.predictions_any_index.model_dump() if node.predictions_any_index else None, + "children": {k: _serialize_node(v) for k, v in node.children.items()}, + } + return result + + +def _deserialize_node(data: dict[str, Any]) -> PredictionTrieNode: + """Deserialize a dictionary to a trie node.""" + predictions_by_call_index: dict[int, LLMCallPrediction] = {} + for k, v in data.get("predictions_by_call_index", {}).items(): + predictions_by_call_index[int(k)] = LLMCallPrediction( + remaining_calls=PredictionMetrics(**v["remaining_calls"]), + interarrival_ms=PredictionMetrics(**v["interarrival_ms"]), + output_tokens=PredictionMetrics(**v["output_tokens"]), + ) + + predictions_any_index = None + if data.get("predictions_any_index"): + v = data["predictions_any_index"] + predictions_any_index = LLMCallPrediction( + remaining_calls=PredictionMetrics(**v["remaining_calls"]), + interarrival_ms=PredictionMetrics(**v["interarrival_ms"]), + output_tokens=PredictionMetrics(**v["output_tokens"]), + ) + + children: dict[str, PredictionTrieNode] = {} + for k, v in data.get("children", {}).items(): + children[k] = _deserialize_node(v) + + return PredictionTrieNode( + name=data["name"], + predictions_by_call_index=predictions_by_call_index, + predictions_any_index=predictions_any_index, + children=children, + ) +``` + +### Step 4: Run tests to verify they pass + +Run: `pytest tests/nat/profiler/prediction_trie/test_serialization.py -v` +Expected: PASS + +### Step 5: Update __init__.py exports + +```python +# src/nat/profiler/prediction_trie/__init__.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode +from nat.profiler.prediction_trie.serialization import load_prediction_trie +from nat.profiler.prediction_trie.serialization import save_prediction_trie +from nat.profiler.prediction_trie.trie_builder import PredictionTrieBuilder +from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup + +__all__ = [ + "LLMCallPrediction", + "PredictionMetrics", + "PredictionTrieBuilder", + "PredictionTrieLookup", + "PredictionTrieNode", + "load_prediction_trie", + "save_prediction_trie", +] +``` + +### Step 6: Commit + +```bash +git add src/nat/profiler/prediction_trie/serialization.py src/nat/profiler/prediction_trie/__init__.py tests/nat/profiler/prediction_trie/test_serialization.py +git commit --signoff -m "feat(profiler): add prediction trie serialization + +JSON serialization with metadata: +- version, generated_at, workflow_name +- Recursive node serialization/deserialization +- Handles predictions_by_call_index int keys" +``` + +--- + +## Task 6: Runtime Call Tracker + +**Files:** +- Create: `src/nat/llm/prediction_context.py` +- Test: `tests/nat/llm/test_prediction_context.py` + +### Step 1: Write the failing test for LLMCallTracker + +```python +# tests/nat/llm/test_prediction_context.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from nat.llm.prediction_context import LLMCallTracker +from nat.llm.prediction_context import get_call_tracker + + +def test_tracker_increment(): + tracker = LLMCallTracker() + assert tracker.increment("func-1") == 1 + assert tracker.increment("func-1") == 2 + assert tracker.increment("func-2") == 1 + assert tracker.increment("func-1") == 3 + + +def test_tracker_reset(): + tracker = LLMCallTracker() + tracker.increment("func-1") + tracker.increment("func-1") + tracker.reset("func-1") + assert tracker.increment("func-1") == 1 + + +def test_tracker_context_variable(): + tracker1 = get_call_tracker() + tracker1.increment("func-a") + + tracker2 = get_call_tracker() + # Should be the same tracker in the same context + assert tracker2.increment("func-a") == 2 +``` + +### Step 2: Run test to verify it fails + +Run: `pytest tests/nat/llm/test_prediction_context.py::test_tracker_increment -v` +Expected: FAIL with "ModuleNotFoundError" + +### Step 3: Implement LLMCallTracker + +```python +# src/nat/llm/prediction_context.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Runtime context management for prediction trie lookups. + +Provides tracking of LLM call indices per function invocation, +enabling accurate lookups in the prediction trie at runtime. +""" + +from contextvars import ContextVar +from dataclasses import dataclass +from dataclasses import field + + +@dataclass +class LLMCallTracker: + """Tracks LLM call counts per function invocation.""" + + counts: dict[str, int] = field(default_factory=dict) + + def increment(self, parent_function_id: str) -> int: + """ + Increment and return the call index for this parent. + + Args: + parent_function_id: Unique ID of the parent function invocation + + Returns: + The call index (1-indexed) for this LLM call within the parent + """ + self.counts[parent_function_id] = self.counts.get(parent_function_id, 0) + 1 + return self.counts[parent_function_id] + + def reset(self, parent_function_id: str) -> None: + """ + Reset call count when a function invocation completes. + + Args: + parent_function_id: Unique ID of the parent function invocation + """ + self.counts.pop(parent_function_id, None) + + +# Thread/async-safe context variable for the call tracker +_llm_call_tracker: ContextVar[LLMCallTracker] = ContextVar("llm_call_tracker") + + +def get_call_tracker() -> LLMCallTracker: + """ + Get the LLMCallTracker for the current context. + + Creates a new tracker if one doesn't exist in the current context. + + Returns: + The LLMCallTracker for this context + """ + try: + return _llm_call_tracker.get() + except LookupError: + tracker = LLMCallTracker() + _llm_call_tracker.set(tracker) + return tracker +``` + +### Step 4: Run tests to verify they pass + +Run: `pytest tests/nat/llm/test_prediction_context.py -v` +Expected: PASS + +### Step 5: Commit + +```bash +git add src/nat/llm/prediction_context.py tests/nat/llm/test_prediction_context.py +git commit --signoff -m "feat(llm): add LLMCallTracker for runtime prediction lookups + +Context variable-based tracking of LLM call indices per function +invocation. Thread/async-safe using contextvars." +``` + +--- + +## Task 7: Profiler Integration + +**Files:** +- Modify: `src/nat/data_models/profiler.py` +- Modify: `src/nat/profiler/profile_runner.py` +- Test: `tests/nat/profiler/test_prediction_trie_integration.py` + +### Step 1: Add prediction_trie config option + +Update `src/nat/data_models/profiler.py`: + +```python +# Add to ProfilerConfig class: +class PredictionTrieConfig(BaseModel): + enable: bool = False + output_filename: str = "prediction_trie.json" + + +class ProfilerConfig(BaseModel): + + base_metrics: bool = False + token_usage_forecast: bool = False + token_uniqueness_forecast: bool = False + workflow_runtime_forecast: bool = False + compute_llm_metrics: bool = False + csv_exclude_io_text: bool = False + prompt_caching_prefixes: PromptCachingConfig = PromptCachingConfig() + bottleneck_analysis: BottleneckConfig = BottleneckConfig() + concurrency_spike_analysis: ConcurrencySpikeConfig = ConcurrencySpikeConfig() + prefix_span_analysis: PrefixSpanConfig = PrefixSpanConfig() + prediction_trie: PredictionTrieConfig = PredictionTrieConfig() # ADD THIS +``` + +### Step 2: Write failing test for profiler integration + +```python +# tests/nat/profiler/test_prediction_trie_integration.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import tempfile +from pathlib import Path + +import pytest + +from nat.data_models.intermediate_step import IntermediateStep +from nat.data_models.intermediate_step import IntermediateStepPayload +from nat.data_models.intermediate_step import IntermediateStepType +from nat.data_models.intermediate_step import UsageInfo +from nat.data_models.invocation_node import InvocationNode +from nat.data_models.profiler import PredictionTrieConfig +from nat.data_models.profiler import ProfilerConfig +from nat.profiler.callbacks.token_usage_base_model import TokenUsageBaseModel +from nat.profiler.prediction_trie import load_prediction_trie +from nat.profiler.profile_runner import ProfilerRunner + + +@pytest.fixture(name="sample_traces") +def fixture_sample_traces() -> list[list[IntermediateStep]]: + """Create sample traces for testing profiler integration.""" + + def make_trace() -> list[IntermediateStep]: + return [ + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id="workflow-1", + function_name="my_workflow", + parent_id=None, + parent_name=None, + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_START, + event_timestamp=1000.0, + UUID="llm-1", + ), + ), + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id="workflow-1", + function_name="my_workflow", + parent_id=None, + parent_name=None, + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_END, + event_timestamp=1001.0, + span_event_timestamp=1000.0, + UUID="llm-1", + usage_info=UsageInfo(token_usage=TokenUsageBaseModel(completion_tokens=100)), + ), + ), + ] + + return [make_trace(), make_trace()] + + +async def test_profiler_generates_prediction_trie(sample_traces): + with tempfile.TemporaryDirectory() as tmpdir: + output_dir = Path(tmpdir) + + config = ProfilerConfig( + base_metrics=True, + prediction_trie=PredictionTrieConfig(enable=True), + ) + + runner = ProfilerRunner(config, output_dir) + await runner.run(sample_traces) + + trie_path = output_dir / "prediction_trie.json" + assert trie_path.exists() + + trie = load_prediction_trie(trie_path) + assert trie.name == "root" + assert "my_workflow" in trie.children +``` + +### Step 3: Run test to verify it fails + +Run: `pytest tests/nat/profiler/test_prediction_trie_integration.py -v` +Expected: FAIL (prediction_trie.json not generated) + +### Step 4: Update ProfilerRunner to generate prediction trie + +Add to `src/nat/profiler/profile_runner.py` in the `run` method, after the existing analysis sections (around line 257): + +```python + # After prefix_span_analysis section, add: + + if self.profile_config.prediction_trie.enable: + # ------------------------------------------------------------ + # Build and save prediction trie + # ------------------------------------------------------------ + from nat.profiler.prediction_trie import PredictionTrieBuilder + from nat.profiler.prediction_trie import save_prediction_trie + + logger.info("Building prediction trie from traces...") + trie_builder = PredictionTrieBuilder() + for trace in all_steps: + trie_builder.add_trace(trace) + + prediction_trie = trie_builder.build() + + if self.write_output: + trie_path = os.path.join(self.output_dir, self.profile_config.prediction_trie.output_filename) + save_prediction_trie(prediction_trie, Path(trie_path), workflow_name="profiled_workflow") + logger.info("Wrote prediction trie to: %s", trie_path) +``` + +### Step 5: Run test to verify it passes + +Run: `pytest tests/nat/profiler/test_prediction_trie_integration.py -v` +Expected: PASS + +### Step 6: Commit + +```bash +git add src/nat/data_models/profiler.py src/nat/profiler/profile_runner.py tests/nat/profiler/test_prediction_trie_integration.py +git commit --signoff -m "feat(profiler): integrate prediction trie generation + +Add PredictionTrieConfig to ProfilerConfig with enable flag. +ProfilerRunner now builds and saves prediction_trie.json when enabled." +``` + +--- + +## Task 8: Dynamo Header Injection + +**Files:** +- Modify: `src/nat/llm/dynamo_llm.py` +- Modify: `packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py` +- Test: `tests/nat/llm/test_dynamo_prediction_headers.py` + +### Step 1: Write failing test for header injection + +```python +# tests/nat/llm/test_dynamo_prediction_headers.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from nat.llm.dynamo_llm import create_httpx_client_with_prediction_headers +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics + + +async def test_prediction_headers_injected(): + """Test that prediction headers are injected into requests.""" + prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + + # Create a mock request to capture headers + captured_headers = {} + + async def capture_hook(request): + captured_headers.update(dict(request.headers)) + + client = create_httpx_client_with_prediction_headers( + prediction=prediction, + prefix_template="test-{uuid}", + total_requests=10, + osl="MEDIUM", + iat="LOW", + ) + + # Add our capture hook + client.event_hooks["request"].append(capture_hook) + + # Make a test request (will fail, but headers will be captured) + try: + await client.post("http://localhost:1/test", json={}) + except Exception: + pass + + assert "x-nat-remaining-llm-calls" in captured_headers + assert captured_headers["x-nat-remaining-llm-calls"] == "3" + assert "x-nat-interarrival-ms" in captured_headers + assert captured_headers["x-nat-interarrival-ms"] == "500" + assert "x-nat-expected-output-tokens" in captured_headers + assert captured_headers["x-nat-expected-output-tokens"] == "200" # p90 value + + await client.aclose() +``` + +### Step 2: Run test to verify it fails + +Run: `pytest tests/nat/llm/test_dynamo_prediction_headers.py -v` +Expected: FAIL with "cannot import name 'create_httpx_client_with_prediction_headers'" + +### Step 3: Add prediction header injection to dynamo_llm.py + +Add to `src/nat/llm/dynamo_llm.py`: + +```python +# Add import at top: +from nat.profiler.prediction_trie.data_models import LLMCallPrediction + + +def _create_prediction_request_hook( + prediction: LLMCallPrediction, +) -> Callable[["httpx.Request"], Coroutine[Any, Any, None]]: + """ + Create an httpx event hook that injects prediction headers. + + Args: + prediction: The prediction data to inject + + Returns: + An async function suitable for use as an httpx event hook. + """ + + async def on_request(request): + """Inject prediction headers before each request.""" + request.headers["x-nat-remaining-llm-calls"] = str(int(prediction.remaining_calls.mean)) + request.headers["x-nat-interarrival-ms"] = str(int(prediction.interarrival_ms.mean)) + request.headers["x-nat-expected-output-tokens"] = str(int(prediction.output_tokens.p90)) + + logger.debug( + "Injected prediction headers: remaining=%d, interarrival=%d, output_tokens=%d", + int(prediction.remaining_calls.mean), + int(prediction.interarrival_ms.mean), + int(prediction.output_tokens.p90), + ) + + return on_request + + +def create_httpx_client_with_prediction_headers( + prediction: LLMCallPrediction, + prefix_template: str | None, + total_requests: int, + osl: str, + iat: str, + timeout: float = 600.0, +) -> "httpx.AsyncClient": + """ + Create an httpx.AsyncClient with both Dynamo prefix and prediction headers. + + Args: + prediction: Prediction data for this LLM call + prefix_template: Template string with {uuid} placeholder + total_requests: Expected number of requests for this prefix + osl: Output sequence length hint (LOW/MEDIUM/HIGH) + iat: Inter-arrival time hint (LOW/MEDIUM/HIGH) + timeout: HTTP request timeout in seconds + + Returns: + An httpx.AsyncClient configured with header injection. + """ + import httpx + + hooks: list[Callable] = [] + + # Add Dynamo prefix hook + prefix_hook = _create_dynamo_request_hook(prefix_template, total_requests, osl, iat) + hooks.append(prefix_hook) + + # Add prediction hook + prediction_hook = _create_prediction_request_hook(prediction) + hooks.append(prediction_hook) + + return httpx.AsyncClient( + event_hooks={"request": hooks}, + timeout=httpx.Timeout(timeout), + ) +``` + +### Step 4: Run test to verify it passes + +Run: `pytest tests/nat/llm/test_dynamo_prediction_headers.py -v` +Expected: PASS + +### Step 5: Commit + +```bash +git add src/nat/llm/dynamo_llm.py tests/nat/llm/test_dynamo_prediction_headers.py +git commit --signoff -m "feat(llm): add prediction header injection to Dynamo client + +Injects x-nat-remaining-llm-calls, x-nat-interarrival-ms, and +x-nat-expected-output-tokens headers for server routing optimization." +``` + +--- + +## Task 9: LangChain Integration with Trie Loading + +**Files:** +- Modify: `src/nat/llm/dynamo_llm.py` (add config field) +- Modify: `packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py` +- Test: `tests/nat/plugins/langchain/test_dynamo_prediction_trie.py` + +### Step 1: Add prediction_trie_path to DynamoModelConfig + +Update `src/nat/llm/dynamo_llm.py`: + +```python +# Add to DynamoModelConfig class: + prediction_trie_path: str | None = Field( + default=None, + description="Path to prediction_trie.json file. When set, predictions are " + "looked up and injected as headers for each LLM call.", + ) + + # Update get_dynamo_field_names(): + @staticmethod + def get_dynamo_field_names() -> frozenset[str]: + return frozenset({ + "prefix_template", + "prefix_total_requests", + "prefix_osl", + "prefix_iat", + "request_timeout", + "prediction_trie_path", # ADD THIS + }) +``` + +### Step 2: Write test for trie-based header injection + +```python +# tests/nat/plugins/langchain/test_dynamo_prediction_trie.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import tempfile +from pathlib import Path + +import pytest + +from nat.llm.dynamo_llm import DynamoModelConfig +from nat.profiler.prediction_trie import PredictionTrieNode +from nat.profiler.prediction_trie import save_prediction_trie +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics + + +@pytest.fixture(name="trie_file") +def fixture_trie_file() -> Path: + """Create a temporary trie file for testing.""" + prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + + root = PredictionTrieNode( + name="root", + predictions_by_call_index={1: prediction}, + predictions_any_index=prediction, + ) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + path = Path(f.name) + + save_prediction_trie(root, path) + yield path + path.unlink(missing_ok=True) + + +def test_dynamo_config_with_trie_path(trie_file): + """Test that DynamoModelConfig accepts prediction_trie_path.""" + config = DynamoModelConfig( + base_url="http://localhost:8000", + model_name="test-model", + api_key="test-key", + prediction_trie_path=str(trie_file), + ) + + assert config.prediction_trie_path == str(trie_file) + assert "prediction_trie_path" in DynamoModelConfig.get_dynamo_field_names() +``` + +### Step 3: Run test to verify config field works + +Run: `pytest tests/nat/plugins/langchain/test_dynamo_prediction_trie.py -v` +Expected: PASS + +### Step 4: Commit + +```bash +git add src/nat/llm/dynamo_llm.py tests/nat/plugins/langchain/test_dynamo_prediction_trie.py +git commit --signoff -m "feat(llm): add prediction_trie_path config to DynamoModelConfig + +Allows specifying a prediction_trie.json file path in workflow config. +When set, predictions are looked up and injected as headers." +``` + +--- + +## Task 10: End-to-End Integration Test + +**Files:** +- Test: `tests/nat/profiler/test_prediction_trie_e2e.py` + +### Step 1: Write end-to-end test + +```python +# tests/nat/profiler/test_prediction_trie_e2e.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""End-to-end test for prediction trie workflow.""" + +import tempfile +from pathlib import Path + +import pytest + +from nat.data_models.intermediate_step import IntermediateStep +from nat.data_models.intermediate_step import IntermediateStepPayload +from nat.data_models.intermediate_step import IntermediateStepType +from nat.data_models.intermediate_step import UsageInfo +from nat.data_models.invocation_node import InvocationNode +from nat.data_models.profiler import PredictionTrieConfig +from nat.data_models.profiler import ProfilerConfig +from nat.profiler.callbacks.token_usage_base_model import TokenUsageBaseModel +from nat.profiler.prediction_trie import PredictionTrieLookup +from nat.profiler.prediction_trie import load_prediction_trie +from nat.profiler.profile_runner import ProfilerRunner + + +def make_agent_trace(agent_name: str, num_llm_calls: int, base_timestamp: float) -> list[IntermediateStep]: + """Create a trace with multiple LLM calls in an agent.""" + steps = [] + ts = base_timestamp + + for i in range(num_llm_calls): + llm_uuid = f"llm-{agent_name}-{i}" + + # LLM_START + steps.append( + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id=f"{agent_name}-1", + function_name=agent_name, + parent_id="workflow-1", + parent_name="my_workflow", + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_START, + event_timestamp=ts, + UUID=llm_uuid, + ), + ) + ) + ts += 0.5 + + # LLM_END + completion_tokens = 100 + (i * 50) # Vary tokens by position + steps.append( + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id=f"{agent_name}-1", + function_name=agent_name, + parent_id="workflow-1", + parent_name="my_workflow", + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_END, + event_timestamp=ts, + span_event_timestamp=ts - 0.5, + UUID=llm_uuid, + usage_info=UsageInfo(token_usage=TokenUsageBaseModel(completion_tokens=completion_tokens)), + ), + ) + ) + ts += 0.5 + + return steps + + +async def test_e2e_prediction_trie_workflow(): + """Test the complete flow: profiler -> trie -> lookup.""" + with tempfile.TemporaryDirectory() as tmpdir: + output_dir = Path(tmpdir) + + # Create multiple traces with different agents + traces = [ + make_agent_trace("react_agent", num_llm_calls=3, base_timestamp=1000.0), + make_agent_trace("react_agent", num_llm_calls=3, base_timestamp=2000.0), + make_agent_trace("tool_agent", num_llm_calls=2, base_timestamp=3000.0), + ] + + # Run profiler + config = ProfilerConfig( + base_metrics=True, + prediction_trie=PredictionTrieConfig(enable=True), + ) + runner = ProfilerRunner(config, output_dir) + await runner.run(traces) + + # Load trie + trie_path = output_dir / "prediction_trie.json" + assert trie_path.exists(), "Trie file should exist" + + trie = load_prediction_trie(trie_path) + lookup = PredictionTrieLookup(trie) + + # Test lookups + # react_agent has 3 LLM calls, so at call 1 there are 2 remaining + result = lookup.find(path=["my_workflow", "react_agent"], call_index=1) + assert result is not None + assert result.remaining_calls.mean == 2.0 # 2 remaining after first call + + # At call 3 there are 0 remaining + result = lookup.find(path=["my_workflow", "react_agent"], call_index=3) + assert result is not None + assert result.remaining_calls.mean == 0.0 + + # tool_agent should have different stats + result = lookup.find(path=["my_workflow", "tool_agent"], call_index=1) + assert result is not None + assert result.remaining_calls.mean == 1.0 # 1 remaining after first call + + # Unknown agent should fall back to aggregated + result = lookup.find(path=["my_workflow", "unknown_agent"], call_index=1) + assert result is not None # Should still get a result from fallback +``` + +### Step 2: Run e2e test + +Run: `pytest tests/nat/profiler/test_prediction_trie_e2e.py -v` +Expected: PASS + +### Step 3: Commit + +```bash +git add tests/nat/profiler/test_prediction_trie_e2e.py +git commit --signoff -m "test(profiler): add end-to-end prediction trie test + +Validates complete flow: profiler traces -> trie generation -> lookup +with different agents and call indices." +``` + +--- + +## Summary + +This plan implements the prediction trie feature in 10 tasks: + +1. **Data Models** - Pydantic models for trie nodes and predictions +2. **Metrics Accumulator** - Helper for computing statistics +3. **Trie Builder** - Builds trie from profiler traces +4. **Trie Lookup** - Finds best matching prediction with fallback +5. **Serialization** - JSON save/load +6. **Runtime Call Tracker** - Context variable for tracking call indices +7. **Profiler Integration** - Config option and trie generation +8. **Dynamo Header Injection** - httpx hooks for prediction headers +9. **LangChain Integration** - Config field for trie path +10. **End-to-End Test** - Validates complete flow + +Each task follows TDD: write failing test, implement, verify, commit. From 64bfc21b7c7683c9db29eeeaea113bb3c6f2cbd0 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Fri, 23 Jan 2026 16:01:40 -0800 Subject: [PATCH 05/37] feat(profiler): add prediction trie data models Add Pydantic models for the prediction trie: - PredictionMetrics: aggregated stats (mean, p50, p90, p95) - LLMCallPrediction: predictions for remaining calls, interarrival time, output tokens - PredictionTrieNode: trie node with children and predictions by call index Signed-off-by: dnandakumar-nv --- src/nat/profiler/prediction_trie/__init__.py | 20 +++++ .../profiler/prediction_trie/data_models.py | 68 ++++++++++++++ .../nat/profiler/prediction_trie/__init__.py | 14 +++ .../prediction_trie/test_data_models.py | 89 +++++++++++++++++++ 4 files changed, 191 insertions(+) create mode 100644 src/nat/profiler/prediction_trie/__init__.py create mode 100644 src/nat/profiler/prediction_trie/data_models.py create mode 100644 tests/nat/profiler/prediction_trie/__init__.py create mode 100644 tests/nat/profiler/prediction_trie/test_data_models.py diff --git a/src/nat/profiler/prediction_trie/__init__.py b/src/nat/profiler/prediction_trie/__init__.py new file mode 100644 index 0000000000..8210bb0452 --- /dev/null +++ b/src/nat/profiler/prediction_trie/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode + +__all__ = ["PredictionMetrics", "LLMCallPrediction", "PredictionTrieNode"] diff --git a/src/nat/profiler/prediction_trie/data_models.py b/src/nat/profiler/prediction_trie/data_models.py new file mode 100644 index 0000000000..bf20c0ac1b --- /dev/null +++ b/src/nat/profiler/prediction_trie/data_models.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from pydantic import BaseModel +from pydantic import Field + + +class PredictionMetrics(BaseModel): + """Aggregated statistics for a single metric from profiler data.""" + + sample_count: int = Field(default=0, description="Number of samples") + mean: float = Field(default=0.0, description="Mean value") + p50: float = Field(default=0.0, description="50th percentile (median)") + p90: float = Field(default=0.0, description="90th percentile") + p95: float = Field(default=0.0, description="95th percentile") + + +class LLMCallPrediction(BaseModel): + """Predictions for an LLM call at a given position in the call hierarchy.""" + + remaining_calls: PredictionMetrics = Field( + default_factory=PredictionMetrics, + description="How many more LLM calls are expected after this one", + ) + interarrival_ms: PredictionMetrics = Field( + default_factory=PredictionMetrics, + description="Expected time in milliseconds until the next LLM call", + ) + output_tokens: PredictionMetrics = Field( + default_factory=PredictionMetrics, + description="Expected output token count for this call", + ) + + +class PredictionTrieNode(BaseModel): + """A node in the prediction trie representing a function in the call hierarchy.""" + + name: str = Field(description="Function name at this level in the hierarchy") + children: dict[str, PredictionTrieNode] = Field( + default_factory=dict, + description="Child nodes keyed by function name", + ) + predictions_by_call_index: dict[int, LLMCallPrediction] = Field( + default_factory=dict, + description="Predictions keyed by call index (1-indexed)", + ) + predictions_any_index: LLMCallPrediction | None = Field( + default=None, + description="Fallback predictions aggregated across all call indices", + ) + + +# Rebuild model to handle forward references +PredictionTrieNode.model_rebuild() diff --git a/tests/nat/profiler/prediction_trie/__init__.py b/tests/nat/profiler/prediction_trie/__init__.py new file mode 100644 index 0000000000..3bcc1c39bb --- /dev/null +++ b/tests/nat/profiler/prediction_trie/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/nat/profiler/prediction_trie/test_data_models.py b/tests/nat/profiler/prediction_trie/test_data_models.py new file mode 100644 index 0000000000..f26157a1cd --- /dev/null +++ b/tests/nat/profiler/prediction_trie/test_data_models.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode + + +def test_prediction_metrics_creation(): + metrics = PredictionMetrics(sample_count=10, mean=5.0, p50=4.5, p90=8.0, p95=9.0) + assert metrics.sample_count == 10 + assert metrics.mean == 5.0 + assert metrics.p50 == 4.5 + assert metrics.p90 == 8.0 + assert metrics.p95 == 9.0 + + +def test_prediction_metrics_defaults(): + metrics = PredictionMetrics() + assert metrics.sample_count == 0 + assert metrics.mean == 0.0 + + +def test_llm_call_prediction_creation(): + prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=5, mean=3.0, p50=3.0, p90=5.0, p95=6.0), + interarrival_ms=PredictionMetrics(sample_count=5, mean=500.0, p50=450.0, p90=800.0, p95=900.0), + output_tokens=PredictionMetrics(sample_count=5, mean=150.0, p50=140.0, p90=250.0, p95=300.0), + ) + assert prediction.remaining_calls.mean == 3.0 + assert prediction.interarrival_ms.mean == 500.0 + assert prediction.output_tokens.mean == 150.0 + + +def test_llm_call_prediction_defaults(): + prediction = LLMCallPrediction() + assert prediction.remaining_calls.sample_count == 0 + assert prediction.interarrival_ms.sample_count == 0 + assert prediction.output_tokens.sample_count == 0 + + +def test_prediction_trie_node_creation(): + node = PredictionTrieNode(name="root") + assert node.name == "root" + assert node.children == {} + assert node.predictions_by_call_index == {} + assert node.predictions_any_index is None + + +def test_prediction_trie_node_with_children(): + child = PredictionTrieNode(name="react_agent") + root = PredictionTrieNode(name="root", children={"react_agent": child}) + assert "react_agent" in root.children + assert root.children["react_agent"].name == "react_agent" + + +def test_prediction_trie_node_with_predictions(): + prediction = LLMCallPrediction() + node = PredictionTrieNode( + name="agent", + predictions_by_call_index={ + 1: prediction, 2: prediction + }, + predictions_any_index=prediction, + ) + assert 1 in node.predictions_by_call_index + assert 2 in node.predictions_by_call_index + assert node.predictions_any_index is not None + + +def test_prediction_trie_node_nested_hierarchy(): + """Test a multi-level trie structure.""" + leaf = PredictionTrieNode(name="tool_call") + middle = PredictionTrieNode(name="react_agent", children={"tool_call": leaf}) + root = PredictionTrieNode(name="workflow", children={"react_agent": middle}) + + assert root.children["react_agent"].children["tool_call"].name == "tool_call" From ee00b6614bd159c5e108339fc40f8a582ed5baf7 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Fri, 23 Jan 2026 16:07:58 -0800 Subject: [PATCH 06/37] feat(profiler): add MetricsAccumulator for prediction trie Accumulates sample values and computes aggregated statistics (mean, p50, p90, p95) using linear interpolation for percentiles. Co-Authored-By: Claude Opus 4.5 Signed-off-by: dnandakumar-nv --- .../prediction_trie/metrics_accumulator.py | 48 +++++++++++++++++++ .../test_metrics_accumulator.py | 36 ++++++++++++++ 2 files changed, 84 insertions(+) create mode 100644 src/nat/profiler/prediction_trie/metrics_accumulator.py create mode 100644 tests/nat/profiler/prediction_trie/test_metrics_accumulator.py diff --git a/src/nat/profiler/prediction_trie/metrics_accumulator.py b/src/nat/profiler/prediction_trie/metrics_accumulator.py new file mode 100644 index 0000000000..2659293c89 --- /dev/null +++ b/src/nat/profiler/prediction_trie/metrics_accumulator.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import math + +from nat.profiler.prediction_trie.data_models import PredictionMetrics + + +class MetricsAccumulator: + """Accumulates samples and computes aggregated statistics.""" + + def __init__(self) -> None: + self._samples: list[float] = [] + + def add_sample(self, value: float) -> None: + """Add a sample value to the accumulator.""" + self._samples.append(value) + + def compute_metrics(self) -> PredictionMetrics: + """Compute aggregated metrics from accumulated samples.""" + if not self._samples: + return PredictionMetrics() + + n = len(self._samples) + mean_val = sum(self._samples) / n + sorted_samples = sorted(self._samples) + + return PredictionMetrics( + sample_count=n, + mean=mean_val, + p50=self._percentile(sorted_samples, 50), + p90=self._percentile(sorted_samples, 90), + p95=self._percentile(sorted_samples, 95), + ) + + @staticmethod + def _percentile(sorted_data: list[float], pct: float) -> float: + """Compute percentile using linear interpolation.""" + if not sorted_data: + return 0.0 + if len(sorted_data) == 1: + return sorted_data[0] + k = (len(sorted_data) - 1) * (pct / 100.0) + f = math.floor(k) + c = math.ceil(k) + if f == c: + return sorted_data[int(k)] + return sorted_data[f] + (sorted_data[c] - sorted_data[f]) * (k - f) diff --git a/tests/nat/profiler/prediction_trie/test_metrics_accumulator.py b/tests/nat/profiler/prediction_trie/test_metrics_accumulator.py new file mode 100644 index 0000000000..5d329db61d --- /dev/null +++ b/tests/nat/profiler/prediction_trie/test_metrics_accumulator.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from nat.profiler.prediction_trie.metrics_accumulator import MetricsAccumulator + + +def test_accumulator_add_single_sample(): + acc = MetricsAccumulator() + acc.add_sample(10.0) + metrics = acc.compute_metrics() + assert metrics.sample_count == 1 + assert metrics.mean == 10.0 + assert metrics.p50 == 10.0 + assert metrics.p90 == 10.0 + assert metrics.p95 == 10.0 + + +def test_accumulator_add_multiple_samples(): + acc = MetricsAccumulator() + for v in [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]: + acc.add_sample(v) + metrics = acc.compute_metrics() + assert metrics.sample_count == 10 + assert metrics.mean == 5.5 + assert metrics.p50 == 5.5 # median of 1-10 + assert metrics.p90 == 9.1 # 90th percentile + assert metrics.p95 == pytest.approx(9.55) # 95th percentile + + +def test_accumulator_empty(): + acc = MetricsAccumulator() + metrics = acc.compute_metrics() + assert metrics.sample_count == 0 + assert metrics.mean == 0.0 From 325afac6d8cf234df07c11071911ce0f6a3638ab Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Fri, 23 Jan 2026 16:14:44 -0800 Subject: [PATCH 07/37] feat(profiler): add PredictionTrieBuilder Builds prediction trie from profiler execution traces: - Extracts LLM call contexts (path, call index, remaining, interarrival, output tokens) - Aggregates metrics at every node along the path - Computes stats by call index and aggregated fallback Co-Authored-By: Claude Opus 4.5 Signed-off-by: dnandakumar-nv --- .../prediction_trie/metrics_accumulator.py | 4 + .../profiler/prediction_trie/trie_builder.py | 187 ++++++++++++++++++ .../prediction_trie/test_trie_builder.py | 128 ++++++++++++ 3 files changed, 319 insertions(+) create mode 100644 src/nat/profiler/prediction_trie/trie_builder.py create mode 100644 tests/nat/profiler/prediction_trie/test_trie_builder.py diff --git a/src/nat/profiler/prediction_trie/metrics_accumulator.py b/src/nat/profiler/prediction_trie/metrics_accumulator.py index 2659293c89..19ba67ead0 100644 --- a/src/nat/profiler/prediction_trie/metrics_accumulator.py +++ b/src/nat/profiler/prediction_trie/metrics_accumulator.py @@ -16,6 +16,10 @@ def add_sample(self, value: float) -> None: """Add a sample value to the accumulator.""" self._samples.append(value) + def has_samples(self) -> bool: + """Return True if any samples have been added.""" + return len(self._samples) > 0 + def compute_metrics(self) -> PredictionMetrics: """Compute aggregated metrics from accumulated samples.""" if not self._samples: diff --git a/src/nat/profiler/prediction_trie/trie_builder.py b/src/nat/profiler/prediction_trie/trie_builder.py new file mode 100644 index 0000000000..836cb3cbc3 --- /dev/null +++ b/src/nat/profiler/prediction_trie/trie_builder.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass +from dataclasses import field + +from nat.data_models.intermediate_step import IntermediateStep +from nat.data_models.intermediate_step import IntermediateStepType +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionTrieNode +from nat.profiler.prediction_trie.metrics_accumulator import MetricsAccumulator + + +@dataclass +class LLMCallContext: + """Context for a single LLM call extracted from a trace.""" + + path: list[str] + call_index: int + remaining_calls: int + time_to_next_ms: float | None + output_tokens: int + + +@dataclass +class _NodeAccumulators: + """Accumulators for a single trie node.""" + + remaining_calls: dict[int, MetricsAccumulator] = field(default_factory=lambda: defaultdict(MetricsAccumulator)) + interarrival_ms: dict[int, MetricsAccumulator] = field(default_factory=lambda: defaultdict(MetricsAccumulator)) + output_tokens: dict[int, MetricsAccumulator] = field(default_factory=lambda: defaultdict(MetricsAccumulator)) + # For aggregated stats across all call indices + all_remaining_calls: MetricsAccumulator = field(default_factory=MetricsAccumulator) + all_interarrival_ms: MetricsAccumulator = field(default_factory=MetricsAccumulator) + all_output_tokens: MetricsAccumulator = field(default_factory=MetricsAccumulator) + + +class PredictionTrieBuilder: + """Builds a prediction trie from profiler execution traces.""" + + def __init__(self) -> None: + # Map from path tuple to accumulators + self._node_accumulators: dict[tuple[str, ...], _NodeAccumulators] = defaultdict(_NodeAccumulators) + + def add_trace(self, steps: list[IntermediateStep]) -> None: + """Process a single execution trace and update accumulators.""" + contexts = self._extract_llm_contexts(steps) + for ctx in contexts: + self._update_accumulators(ctx) + + def _extract_llm_contexts(self, steps: list[IntermediateStep]) -> list[LLMCallContext]: + """Extract LLM call contexts from a trace.""" + # Sort steps by timestamp + sorted_steps = sorted(steps, key=lambda s: s.event_timestamp) + + # Find all LLM_END events + llm_ends = [s for s in sorted_steps if s.event_type == IntermediateStepType.LLM_END] + + # Find all LLM_START events for interarrival time calculation + llm_starts = [s for s in sorted_steps if s.event_type == IntermediateStepType.LLM_START] + + # Track call index per parent function + call_counts: dict[str, int] = defaultdict(int) + contexts: list[LLMCallContext] = [] + + for i, end_step in enumerate(llm_ends): + # Build path from function ancestry + path = self._build_path(end_step) + + # Determine call index within parent + parent_key = end_step.function_ancestry.function_id + call_counts[parent_key] += 1 + call_index = call_counts[parent_key] + + # Remaining calls in this trace + remaining = len(llm_ends) - i - 1 + + # Time to next LLM start (if any) + time_to_next_ms: float | None = None + current_end_time = end_step.event_timestamp + # Find next LLM_START after this LLM_END + for start_step in llm_starts: + if start_step.event_timestamp > current_end_time: + time_to_next_ms = (start_step.event_timestamp - current_end_time) * 1000.0 + break + + # Output tokens + output_tokens = 0 + if end_step.usage_info and end_step.usage_info.token_usage: + output_tokens = end_step.usage_info.token_usage.completion_tokens or 0 + + contexts.append( + LLMCallContext( + path=path, + call_index=call_index, + remaining_calls=remaining, + time_to_next_ms=time_to_next_ms, + output_tokens=output_tokens, + )) + + return contexts + + def _build_path(self, step: IntermediateStep) -> list[str]: + """Build the function path from ancestry.""" + path: list[str] = [] + ancestry = step.function_ancestry + + # Walk up the ancestry chain + if ancestry.parent_name: + path.append(ancestry.parent_name) + path.append(ancestry.function_name) + + return path + + def _update_accumulators(self, ctx: LLMCallContext) -> None: + """Update accumulators at every node along the path.""" + # Update root node + root_key: tuple[str, ...] = () + self._add_to_accumulators(root_key, ctx) + + # Update each node along the path + for i in range(len(ctx.path)): + path_key = tuple(ctx.path[:i + 1]) + self._add_to_accumulators(path_key, ctx) + + def _add_to_accumulators(self, path_key: tuple[str, ...], ctx: LLMCallContext) -> None: + """Add context data to accumulators for a specific path.""" + accs = self._node_accumulators[path_key] + + # By call index + accs.remaining_calls[ctx.call_index].add_sample(float(ctx.remaining_calls)) + accs.output_tokens[ctx.call_index].add_sample(float(ctx.output_tokens)) + if ctx.time_to_next_ms is not None: + accs.interarrival_ms[ctx.call_index].add_sample(ctx.time_to_next_ms) + + # Aggregated across all indices + accs.all_remaining_calls.add_sample(float(ctx.remaining_calls)) + accs.all_output_tokens.add_sample(float(ctx.output_tokens)) + if ctx.time_to_next_ms is not None: + accs.all_interarrival_ms.add_sample(ctx.time_to_next_ms) + + def build(self) -> PredictionTrieNode: + """Build the final prediction trie from accumulated data.""" + root = PredictionTrieNode(name="root") + + for path_key, accs in self._node_accumulators.items(): + node = self._get_or_create_node(root, path_key) + self._populate_node_predictions(node, accs) + + return root + + def _get_or_create_node(self, root: PredictionTrieNode, path_key: tuple[str, ...]) -> PredictionTrieNode: + """Navigate to or create a node at the given path.""" + if not path_key: + return root + + current = root + for name in path_key: + if name not in current.children: + current.children[name] = PredictionTrieNode(name=name) + current = current.children[name] + return current + + def _populate_node_predictions(self, node: PredictionTrieNode, accs: _NodeAccumulators) -> None: + """Populate a node with computed predictions from accumulators.""" + # Predictions by call index + all_indices = set(accs.remaining_calls.keys()) | set(accs.interarrival_ms.keys()) | set( + accs.output_tokens.keys()) + + for idx in all_indices: + prediction = LLMCallPrediction( + remaining_calls=accs.remaining_calls[idx].compute_metrics(), + interarrival_ms=accs.interarrival_ms[idx].compute_metrics(), + output_tokens=accs.output_tokens[idx].compute_metrics(), + ) + node.predictions_by_call_index[idx] = prediction + + # Aggregated predictions + if accs.all_remaining_calls.has_samples(): + node.predictions_any_index = LLMCallPrediction( + remaining_calls=accs.all_remaining_calls.compute_metrics(), + interarrival_ms=accs.all_interarrival_ms.compute_metrics(), + output_tokens=accs.all_output_tokens.compute_metrics(), + ) diff --git a/tests/nat/profiler/prediction_trie/test_trie_builder.py b/tests/nat/profiler/prediction_trie/test_trie_builder.py new file mode 100644 index 0000000000..6b964b835f --- /dev/null +++ b/tests/nat/profiler/prediction_trie/test_trie_builder.py @@ -0,0 +1,128 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from nat.data_models.intermediate_step import IntermediateStep +from nat.data_models.intermediate_step import IntermediateStepPayload +from nat.data_models.intermediate_step import IntermediateStepType +from nat.data_models.intermediate_step import UsageInfo +from nat.data_models.invocation_node import InvocationNode +from nat.profiler.callbacks.token_usage_base_model import TokenUsageBaseModel +from nat.profiler.prediction_trie.trie_builder import PredictionTrieBuilder + + +@pytest.fixture(name="simple_trace") +def fixture_simple_trace() -> list[IntermediateStep]: + """Create a simple trace with two LLM calls.""" + return [ + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id="workflow-1", + function_name="my_workflow", + parent_id=None, + parent_name=None, + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_START, + event_timestamp=1000.0, + UUID="llm-1", + ), + ), + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id="workflow-1", + function_name="my_workflow", + parent_id=None, + parent_name=None, + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_END, + event_timestamp=1001.0, + span_event_timestamp=1000.0, + UUID="llm-1", + usage_info=UsageInfo(token_usage=TokenUsageBaseModel(completion_tokens=100), ), + ), + ), + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id="workflow-1", + function_name="my_workflow", + parent_id=None, + parent_name=None, + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_START, + event_timestamp=1002.0, + UUID="llm-2", + ), + ), + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id="workflow-1", + function_name="my_workflow", + parent_id=None, + parent_name=None, + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_END, + event_timestamp=1003.0, + span_event_timestamp=1002.0, + UUID="llm-2", + usage_info=UsageInfo(token_usage=TokenUsageBaseModel(completion_tokens=150), ), + ), + ), + ] + + +def test_trie_builder_builds_from_single_trace(simple_trace): + builder = PredictionTrieBuilder() + builder.add_trace(simple_trace) + trie = builder.build() + + assert trie.name == "root" + assert "my_workflow" in trie.children + + workflow_node = trie.children["my_workflow"] + # First LLM call: call_index=1, remaining=1 + assert 1 in workflow_node.predictions_by_call_index + # Second LLM call: call_index=2, remaining=0 + assert 2 in workflow_node.predictions_by_call_index + + +def test_trie_builder_computes_remaining_calls(simple_trace): + builder = PredictionTrieBuilder() + builder.add_trace(simple_trace) + trie = builder.build() + + workflow_node = trie.children["my_workflow"] + # First call should predict 1 remaining call + assert workflow_node.predictions_by_call_index[1].remaining_calls.mean == 1.0 + # Second call should predict 0 remaining calls + assert workflow_node.predictions_by_call_index[2].remaining_calls.mean == 0.0 + + +def test_trie_builder_computes_output_tokens(simple_trace): + builder = PredictionTrieBuilder() + builder.add_trace(simple_trace) + trie = builder.build() + + workflow_node = trie.children["my_workflow"] + # First call had 100 completion tokens + assert workflow_node.predictions_by_call_index[1].output_tokens.mean == 100.0 + # Second call had 150 completion tokens + assert workflow_node.predictions_by_call_index[2].output_tokens.mean == 150.0 + + +def test_trie_builder_computes_interarrival_time(simple_trace): + builder = PredictionTrieBuilder() + builder.add_trace(simple_trace) + trie = builder.build() + + workflow_node = trie.children["my_workflow"] + # First call: next LLM starts at 1002.0, this call ends at 1001.0 -> 1000ms + assert workflow_node.predictions_by_call_index[1].interarrival_ms.mean == 1000.0 From 932f6741e56bd09a5154bc82abcbbbec520cb8b1 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Fri, 23 Jan 2026 16:24:25 -0800 Subject: [PATCH 08/37] feat(profiler): add PredictionTrieLookup Walks the trie to find best matching prediction: - Exact path + exact call_index (most specific) - Partial path + exact call_index - Falls back to aggregated predictions when call_index not found Signed-off-by: Claude Signed-off-by: dnandakumar-nv --- src/nat/profiler/prediction_trie/__init__.py | 3 +- .../profiler/prediction_trie/trie_lookup.py | 62 +++++++++++++ .../prediction_trie/test_trie_lookup.py | 91 +++++++++++++++++++ 3 files changed, 155 insertions(+), 1 deletion(-) create mode 100644 src/nat/profiler/prediction_trie/trie_lookup.py create mode 100644 tests/nat/profiler/prediction_trie/test_trie_lookup.py diff --git a/src/nat/profiler/prediction_trie/__init__.py b/src/nat/profiler/prediction_trie/__init__.py index 8210bb0452..d3b23f1075 100644 --- a/src/nat/profiler/prediction_trie/__init__.py +++ b/src/nat/profiler/prediction_trie/__init__.py @@ -16,5 +16,6 @@ from nat.profiler.prediction_trie.data_models import LLMCallPrediction from nat.profiler.prediction_trie.data_models import PredictionMetrics from nat.profiler.prediction_trie.data_models import PredictionTrieNode +from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup -__all__ = ["PredictionMetrics", "LLMCallPrediction", "PredictionTrieNode"] +__all__ = ["PredictionMetrics", "LLMCallPrediction", "PredictionTrieNode", "PredictionTrieLookup"] diff --git a/src/nat/profiler/prediction_trie/trie_lookup.py b/src/nat/profiler/prediction_trie/trie_lookup.py new file mode 100644 index 0000000000..455d645446 --- /dev/null +++ b/src/nat/profiler/prediction_trie/trie_lookup.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionTrieNode + + +class PredictionTrieLookup: + """Looks up predictions in a prediction trie with graceful fallback.""" + + def __init__(self, root: PredictionTrieNode) -> None: + self._root = root + + def find(self, path: list[str], call_index: int) -> LLMCallPrediction | None: + """ + Find the best matching prediction for the given path and call index. + + Walks the trie as far as possible along the path, then returns the deepest + match. Falls back to aggregated predictions when exact call_index isn't found. + + Args: + path: Function ancestry path (e.g., ["my_workflow", "react_agent"]) + call_index: The Nth LLM call within the current parent function + + Returns: + Best matching prediction, or None if trie is empty + """ + node = self._root + deepest_match: LLMCallPrediction | None = None + + # Check root node first + deepest_match = self._get_prediction(node, call_index) or deepest_match + + # Walk the trie as far as we can match + for func_name in path: + if func_name not in node.children: + break + node = node.children[func_name] + # Update deepest match at each level + match = self._get_prediction(node, call_index) + if match is not None: + deepest_match = match + + return deepest_match + + def _get_prediction(self, node: PredictionTrieNode, call_index: int) -> LLMCallPrediction | None: + """Get prediction from node, preferring exact call_index, falling back to aggregated.""" + if call_index in node.predictions_by_call_index: + return node.predictions_by_call_index[call_index] + return node.predictions_any_index diff --git a/tests/nat/profiler/prediction_trie/test_trie_lookup.py b/tests/nat/profiler/prediction_trie/test_trie_lookup.py new file mode 100644 index 0000000000..2fde204e5d --- /dev/null +++ b/tests/nat/profiler/prediction_trie/test_trie_lookup.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode +from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup + + +@pytest.fixture(name="sample_trie") +def fixture_sample_trie() -> PredictionTrieNode: + """Create a sample trie for testing lookups.""" + prediction_1 = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + prediction_2 = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=2.0, p50=2.0, p90=3.0, p95=4.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=400.0, p50=380.0, p90=600.0, p95=700.0), + output_tokens=PredictionMetrics(sample_count=10, mean=200.0, p50=190.0, p90=280.0, p95=320.0), + ) + aggregated = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=20, mean=2.5, p50=2.5, p90=3.5, p95=4.5), + interarrival_ms=PredictionMetrics(sample_count=20, mean=450.0, p50=415.0, p90=650.0, p95=750.0), + output_tokens=PredictionMetrics(sample_count=20, mean=175.0, p50=165.0, p90=240.0, p95=285.0), + ) + + agent_node = PredictionTrieNode( + name="react_agent", + predictions_by_call_index={ + 1: prediction_1, 2: prediction_2 + }, + predictions_any_index=aggregated, + ) + + workflow_node = PredictionTrieNode( + name="my_workflow", + children={"react_agent": agent_node}, + predictions_any_index=aggregated, + ) + + root = PredictionTrieNode( + name="root", + children={"my_workflow": workflow_node}, + predictions_any_index=aggregated, + ) + + return root + + +def test_lookup_exact_match(sample_trie): + lookup = PredictionTrieLookup(sample_trie) + result = lookup.find(path=["my_workflow", "react_agent"], call_index=1) + + assert result is not None + assert result.remaining_calls.mean == 3.0 + assert result.output_tokens.mean == 150.0 + + +def test_lookup_partial_path_match(sample_trie): + """When exact path doesn't exist, fall back to closest ancestor.""" + lookup = PredictionTrieLookup(sample_trie) + # "unknown_tool" doesn't exist, should fall back to react_agent's aggregated + result = lookup.find(path=["my_workflow", "react_agent", "unknown_tool"], call_index=1) + + assert result is not None + # Should get react_agent's call_index=1 prediction + assert result.remaining_calls.mean == 3.0 + + +def test_lookup_unknown_call_index_fallback(sample_trie): + """When call_index doesn't exist, fall back to aggregated.""" + lookup = PredictionTrieLookup(sample_trie) + result = lookup.find(path=["my_workflow", "react_agent"], call_index=99) + + assert result is not None + # Should fall back to predictions_any_index + assert result.remaining_calls.mean == 2.5 + + +def test_lookup_no_match_returns_root_aggregated(sample_trie): + """When nothing matches, return root's aggregated.""" + lookup = PredictionTrieLookup(sample_trie) + result = lookup.find(path=["completely_unknown"], call_index=1) + + assert result is not None + # Should return root's aggregated prediction + assert result.remaining_calls.mean == 2.5 From d62fc1eb595a14ed2e11ae81fbc0d09c7e30c448 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 08:24:03 -0800 Subject: [PATCH 09/37] feat(profiler): add prediction trie serialization JSON serialization with metadata: - version, generated_at, workflow_name - Recursive node serialization/deserialization - Handles predictions_by_call_index int keys Co-Authored-By: Claude Opus 4.5 Signed-off-by: dnandakumar-nv --- src/nat/profiler/prediction_trie/__init__.py | 13 +- .../profiler/prediction_trie/serialization.py | 114 ++++++++++++++++++ .../prediction_trie/test_serialization.py | 66 ++++++++++ 3 files changed, 192 insertions(+), 1 deletion(-) create mode 100644 src/nat/profiler/prediction_trie/serialization.py create mode 100644 tests/nat/profiler/prediction_trie/test_serialization.py diff --git a/src/nat/profiler/prediction_trie/__init__.py b/src/nat/profiler/prediction_trie/__init__.py index d3b23f1075..35d302c854 100644 --- a/src/nat/profiler/prediction_trie/__init__.py +++ b/src/nat/profiler/prediction_trie/__init__.py @@ -16,6 +16,17 @@ from nat.profiler.prediction_trie.data_models import LLMCallPrediction from nat.profiler.prediction_trie.data_models import PredictionMetrics from nat.profiler.prediction_trie.data_models import PredictionTrieNode +from nat.profiler.prediction_trie.serialization import load_prediction_trie +from nat.profiler.prediction_trie.serialization import save_prediction_trie +from nat.profiler.prediction_trie.trie_builder import PredictionTrieBuilder from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup -__all__ = ["PredictionMetrics", "LLMCallPrediction", "PredictionTrieNode", "PredictionTrieLookup"] +__all__ = [ + "LLMCallPrediction", + "PredictionMetrics", + "PredictionTrieBuilder", + "PredictionTrieLookup", + "PredictionTrieNode", + "load_prediction_trie", + "save_prediction_trie", +] diff --git a/src/nat/profiler/prediction_trie/serialization.py b/src/nat/profiler/prediction_trie/serialization.py new file mode 100644 index 0000000000..d0a6a62350 --- /dev/null +++ b/src/nat/profiler/prediction_trie/serialization.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from datetime import UTC +from datetime import datetime +from pathlib import Path +from typing import Any + +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode + +CURRENT_VERSION = "1.0" + + +def save_prediction_trie( + trie: PredictionTrieNode, + path: Path, + workflow_name: str = "unknown", +) -> None: + """ + Save a prediction trie to a JSON file. + + Args: + trie: The prediction trie root node + path: Path to save the JSON file + workflow_name: Name of the workflow this trie was built from + """ + data = { + "version": CURRENT_VERSION, + "generated_at": datetime.now(UTC).isoformat(), + "workflow_name": workflow_name, + "root": _serialize_node(trie), + } + + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + + +def load_prediction_trie(path: Path) -> PredictionTrieNode: + """ + Load a prediction trie from a JSON file. + + Args: + path: Path to the JSON file + + Returns: + The deserialized prediction trie root node + """ + with open(path, encoding="utf-8") as f: + data = json.load(f) + + return _deserialize_node(data["root"]) + + +def _serialize_node(node: PredictionTrieNode) -> dict[str, Any]: + """Serialize a trie node to a dictionary.""" + result: dict[str, Any] = { + "name": node.name, + "predictions_by_call_index": { + str(k): v.model_dump() + for k, v in node.predictions_by_call_index.items() + }, + "predictions_any_index": node.predictions_any_index.model_dump() if node.predictions_any_index else None, + "children": { + k: _serialize_node(v) + for k, v in node.children.items() + }, + } + return result + + +def _deserialize_node(data: dict[str, Any]) -> PredictionTrieNode: + """Deserialize a dictionary to a trie node.""" + predictions_by_call_index: dict[int, LLMCallPrediction] = {} + for k, v in data.get("predictions_by_call_index", {}).items(): + predictions_by_call_index[int(k)] = LLMCallPrediction( + remaining_calls=PredictionMetrics(**v["remaining_calls"]), + interarrival_ms=PredictionMetrics(**v["interarrival_ms"]), + output_tokens=PredictionMetrics(**v["output_tokens"]), + ) + + predictions_any_index = None + if data.get("predictions_any_index"): + v = data["predictions_any_index"] + predictions_any_index = LLMCallPrediction( + remaining_calls=PredictionMetrics(**v["remaining_calls"]), + interarrival_ms=PredictionMetrics(**v["interarrival_ms"]), + output_tokens=PredictionMetrics(**v["output_tokens"]), + ) + + children: dict[str, PredictionTrieNode] = {} + for k, v in data.get("children", {}).items(): + children[k] = _deserialize_node(v) + + return PredictionTrieNode( + name=data["name"], + predictions_by_call_index=predictions_by_call_index, + predictions_any_index=predictions_any_index, + children=children, + ) diff --git a/tests/nat/profiler/prediction_trie/test_serialization.py b/tests/nat/profiler/prediction_trie/test_serialization.py new file mode 100644 index 0000000000..a617d5d48d --- /dev/null +++ b/tests/nat/profiler/prediction_trie/test_serialization.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import json +import tempfile +from pathlib import Path + +import pytest + +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode +from nat.profiler.prediction_trie.serialization import load_prediction_trie +from nat.profiler.prediction_trie.serialization import save_prediction_trie + + +@pytest.fixture(name="sample_trie") +def fixture_sample_trie() -> PredictionTrieNode: + """Create a sample trie for testing serialization.""" + prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + + child = PredictionTrieNode( + name="react_agent", + predictions_by_call_index={1: prediction}, + predictions_any_index=prediction, + ) + + root = PredictionTrieNode( + name="root", + children={"react_agent": child}, + predictions_any_index=prediction, + ) + + return root + + +def test_save_and_load_trie(sample_trie): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "prediction_trie.json" + + save_prediction_trie(sample_trie, path, workflow_name="test_workflow") + + loaded = load_prediction_trie(path) + + assert loaded.name == "root" + assert "react_agent" in loaded.children + assert loaded.children["react_agent"].predictions_by_call_index[1].remaining_calls.mean == 3.0 + + +def test_saved_file_has_metadata(sample_trie): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "prediction_trie.json" + + save_prediction_trie(sample_trie, path, workflow_name="test_workflow") + + with open(path) as f: + data = json.load(f) + + assert data["version"] == "1.0" + assert data["workflow_name"] == "test_workflow" + assert "generated_at" in data + assert "root" in data From f916957320e2f5a201b801ef4d6340f73a5e04cb Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 08:30:32 -0800 Subject: [PATCH 10/37] feat(llm): add LLMCallTracker for runtime prediction lookups Context variable-based tracking of LLM call indices per function invocation. Thread/async-safe using contextvars. Co-Authored-By: Claude Opus 4.5 Signed-off-by: dnandakumar-nv --- src/nat/llm/prediction_context.py | 62 ++++++++++++++++++++++++ tests/nat/llm/test_prediction_context.py | 31 ++++++++++++ 2 files changed, 93 insertions(+) create mode 100644 src/nat/llm/prediction_context.py create mode 100644 tests/nat/llm/test_prediction_context.py diff --git a/src/nat/llm/prediction_context.py b/src/nat/llm/prediction_context.py new file mode 100644 index 0000000000..dd4757dba0 --- /dev/null +++ b/src/nat/llm/prediction_context.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +""" +Runtime context management for prediction trie lookups. + +Provides tracking of LLM call indices per function invocation, +enabling accurate lookups in the prediction trie at runtime. +""" + +from contextvars import ContextVar +from dataclasses import dataclass +from dataclasses import field + + +@dataclass +class LLMCallTracker: + """Tracks LLM call counts per function invocation.""" + + counts: dict[str, int] = field(default_factory=dict) + + def increment(self, parent_function_id: str) -> int: + """ + Increment and return the call index for this parent. + + Args: + parent_function_id: Unique ID of the parent function invocation + + Returns: + The call index (1-indexed) for this LLM call within the parent + """ + self.counts[parent_function_id] = self.counts.get(parent_function_id, 0) + 1 + return self.counts[parent_function_id] + + def reset(self, parent_function_id: str) -> None: + """ + Reset call count when a function invocation completes. + + Args: + parent_function_id: Unique ID of the parent function invocation + """ + self.counts.pop(parent_function_id, None) + + +# Thread/async-safe context variable for the call tracker +_llm_call_tracker: ContextVar[LLMCallTracker] = ContextVar("llm_call_tracker") + + +def get_call_tracker() -> LLMCallTracker: + """ + Get the LLMCallTracker for the current context. + + Creates a new tracker if one doesn't exist in the current context. + + Returns: + The LLMCallTracker for this context + """ + try: + return _llm_call_tracker.get() + except LookupError: + tracker = LLMCallTracker() + _llm_call_tracker.set(tracker) + return tracker diff --git a/tests/nat/llm/test_prediction_context.py b/tests/nat/llm/test_prediction_context.py new file mode 100644 index 0000000000..d46166d7d0 --- /dev/null +++ b/tests/nat/llm/test_prediction_context.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +from nat.llm.prediction_context import LLMCallTracker +from nat.llm.prediction_context import get_call_tracker + + +def test_tracker_increment(): + tracker = LLMCallTracker() + assert tracker.increment("func-1") == 1 + assert tracker.increment("func-1") == 2 + assert tracker.increment("func-2") == 1 + assert tracker.increment("func-1") == 3 + + +def test_tracker_reset(): + tracker = LLMCallTracker() + tracker.increment("func-1") + tracker.increment("func-1") + tracker.reset("func-1") + assert tracker.increment("func-1") == 1 + + +def test_tracker_context_variable(): + tracker1 = get_call_tracker() + tracker1.increment("func-a") + + tracker2 = get_call_tracker() + # Should be the same tracker in the same context + assert tracker2.increment("func-a") == 2 From 381f0afde5248d7bbf9a38916483ceabc7e5f574 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 08:37:07 -0800 Subject: [PATCH 11/37] feat(profiler): integrate prediction trie generation Add PredictionTrieConfig to ProfilerConfig with enable flag. ProfilerRunner now builds and saves prediction_trie.json when enabled. Co-Authored-By: Claude Opus 4.5 Signed-off-by: dnandakumar-nv --- src/nat/data_models/profiler.py | 6 ++ src/nat/profiler/profile_runner.py | 19 +++++ .../test_prediction_trie_integration.py | 79 +++++++++++++++++++ 3 files changed, 104 insertions(+) create mode 100644 tests/nat/profiler/test_prediction_trie_integration.py diff --git a/src/nat/data_models/profiler.py b/src/nat/data_models/profiler.py index cb0ed64544..ed64ed2c18 100644 --- a/src/nat/data_models/profiler.py +++ b/src/nat/data_models/profiler.py @@ -40,6 +40,11 @@ class PrefixSpanConfig(BaseModel): chain_with_common_prefixes: bool = False +class PredictionTrieConfig(BaseModel): + enable: bool = False + output_filename: str = "prediction_trie.json" + + class ProfilerConfig(BaseModel): base_metrics: bool = False @@ -52,3 +57,4 @@ class ProfilerConfig(BaseModel): bottleneck_analysis: BottleneckConfig = BottleneckConfig() concurrency_spike_analysis: ConcurrencySpikeConfig = ConcurrencySpikeConfig() prefix_span_analysis: PrefixSpanConfig = PrefixSpanConfig() + prediction_trie: PredictionTrieConfig = PredictionTrieConfig() diff --git a/src/nat/profiler/profile_runner.py b/src/nat/profiler/profile_runner.py index 0ac72f5deb..d18e7ff99b 100644 --- a/src/nat/profiler/profile_runner.py +++ b/src/nat/profiler/profile_runner.py @@ -270,6 +270,25 @@ async def run(self, all_steps: list[list[IntermediateStep]]) -> ProfilerResults: json.dump(workflow_profiling_metrics, f, indent=2) logger.info("Wrote workflow profiling metrics to: %s", profiling_metrics_path) + if self.profile_config.prediction_trie.enable: + # ------------------------------------------------------------ + # Build and save prediction trie + # ------------------------------------------------------------ + from nat.profiler.prediction_trie import PredictionTrieBuilder + from nat.profiler.prediction_trie import save_prediction_trie + + logger.info("Building prediction trie from traces...") + trie_builder = PredictionTrieBuilder() + for trace in all_steps: + trie_builder.add_trace(trace) + + prediction_trie = trie_builder.build() + + if self.write_output: + trie_path = os.path.join(self.output_dir, self.profile_config.prediction_trie.output_filename) + save_prediction_trie(prediction_trie, Path(trie_path), workflow_name="profiled_workflow") + logger.info("Wrote prediction trie to: %s", trie_path) + if self.profile_config.token_usage_forecast: # ------------------------------------------------------------ # Fit forecasting model and save diff --git a/tests/nat/profiler/test_prediction_trie_integration.py b/tests/nat/profiler/test_prediction_trie_integration.py new file mode 100644 index 0000000000..8484a8e078 --- /dev/null +++ b/tests/nat/profiler/test_prediction_trie_integration.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import tempfile +from pathlib import Path + +import pytest + +from nat.data_models.intermediate_step import IntermediateStep +from nat.data_models.intermediate_step import IntermediateStepPayload +from nat.data_models.intermediate_step import IntermediateStepType +from nat.data_models.intermediate_step import UsageInfo +from nat.data_models.invocation_node import InvocationNode +from nat.data_models.profiler import PredictionTrieConfig +from nat.data_models.profiler import ProfilerConfig +from nat.profiler.callbacks.token_usage_base_model import TokenUsageBaseModel +from nat.profiler.prediction_trie import load_prediction_trie +from nat.profiler.profile_runner import ProfilerRunner + + +@pytest.fixture(name="sample_traces") +def fixture_sample_traces() -> list[list[IntermediateStep]]: + """Create sample traces for testing profiler integration.""" + + def make_trace() -> list[IntermediateStep]: + return [ + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id="workflow-1", + function_name="my_workflow", + parent_id=None, + parent_name=None, + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_START, + event_timestamp=1000.0, + UUID="llm-1", + ), + ), + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id="workflow-1", + function_name="my_workflow", + parent_id=None, + parent_name=None, + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_END, + event_timestamp=1001.0, + span_event_timestamp=1000.0, + UUID="llm-1", + usage_info=UsageInfo(token_usage=TokenUsageBaseModel(completion_tokens=100)), + ), + ), + ] + + return [make_trace(), make_trace()] + + +async def test_profiler_generates_prediction_trie(sample_traces): + with tempfile.TemporaryDirectory() as tmpdir: + output_dir = Path(tmpdir) + + config = ProfilerConfig( + base_metrics=True, + prediction_trie=PredictionTrieConfig(enable=True), + ) + + runner = ProfilerRunner(config, output_dir) + await runner.run(sample_traces) + + trie_path = output_dir / "prediction_trie.json" + assert trie_path.exists() + + trie = load_prediction_trie(trie_path) + assert trie.name == "root" + assert "my_workflow" in trie.children From ee166c8d14613a8a0796aac5bc76c7d3c18f6cb6 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 08:47:15 -0800 Subject: [PATCH 12/37] feat(llm): add prediction header injection to Dynamo client Injects x-nat-remaining-llm-calls, x-nat-interarrival-ms, and x-nat-expected-output-tokens headers for server routing optimization. Co-Authored-By: Claude Opus 4.5 Signed-off-by: dnandakumar-nv --- src/nat/llm/dynamo_llm.py | 69 +++++++++++++++++++ .../nat/llm/test_dynamo_prediction_headers.py | 47 +++++++++++++ 2 files changed, 116 insertions(+) create mode 100644 tests/nat/llm/test_dynamo_prediction_headers.py diff --git a/src/nat/llm/dynamo_llm.py b/src/nat/llm/dynamo_llm.py index 79667e106b..6a82c9bdcf 100644 --- a/src/nat/llm/dynamo_llm.py +++ b/src/nat/llm/dynamo_llm.py @@ -68,6 +68,7 @@ from nat.data_models.optimizable import OptimizableField from nat.data_models.optimizable import SearchSpace from nat.llm.openai_llm import OpenAIModelConfig +from nat.profiler.prediction_trie.data_models import LLMCallPrediction logger = logging.getLogger(__name__) @@ -347,6 +348,74 @@ def create_httpx_client_with_dynamo_hooks( ) +def _create_prediction_request_hook( + prediction: LLMCallPrediction, ) -> Callable[["httpx.Request"], Coroutine[Any, Any, None]]: + """ + Create an httpx event hook that injects prediction headers. + + Args: + prediction: The prediction data to inject + + Returns: + An async function suitable for use as an httpx event hook. + """ + + async def on_request(request): + """Inject prediction headers before each request.""" + request.headers["x-nat-remaining-llm-calls"] = str(int(prediction.remaining_calls.mean)) + request.headers["x-nat-interarrival-ms"] = str(int(prediction.interarrival_ms.mean)) + request.headers["x-nat-expected-output-tokens"] = str(int(prediction.output_tokens.p90)) + + logger.debug( + "Injected prediction headers: remaining=%d, interarrival=%d, output_tokens=%d", + int(prediction.remaining_calls.mean), + int(prediction.interarrival_ms.mean), + int(prediction.output_tokens.p90), + ) + + return on_request + + +def create_httpx_client_with_prediction_headers( + prediction: LLMCallPrediction, + prefix_template: str | None, + total_requests: int, + osl: str, + iat: str, + timeout: float = 600.0, +) -> "httpx.AsyncClient": + """ + Create an httpx.AsyncClient with both Dynamo prefix and prediction headers. + + Args: + prediction: Prediction data for this LLM call + prefix_template: Template string with {uuid} placeholder + total_requests: Expected number of requests for this prefix + osl: Output sequence length hint (LOW/MEDIUM/HIGH) + iat: Inter-arrival time hint (LOW/MEDIUM/HIGH) + timeout: HTTP request timeout in seconds + + Returns: + An httpx.AsyncClient configured with header injection. + """ + import httpx + + hooks: list[Callable] = [] + + # Add Dynamo prefix hook + prefix_hook = _create_dynamo_request_hook(prefix_template, total_requests, osl, iat) + hooks.append(prefix_hook) + + # Add prediction hook + prediction_hook = _create_prediction_request_hook(prediction) + hooks.append(prediction_hook) + + return httpx.AsyncClient( + event_hooks={"request": hooks}, + timeout=httpx.Timeout(timeout), + ) + + # ============================================================================= # PROVIDER REGISTRATION # ============================================================================= diff --git a/tests/nat/llm/test_dynamo_prediction_headers.py b/tests/nat/llm/test_dynamo_prediction_headers.py new file mode 100644 index 0000000000..97b413424b --- /dev/null +++ b/tests/nat/llm/test_dynamo_prediction_headers.py @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from nat.llm.dynamo_llm import create_httpx_client_with_prediction_headers +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics + + +async def test_prediction_headers_injected(): + """Test that prediction headers are injected into requests.""" + prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + + # Create a mock request to capture headers + captured_headers = {} + + async def capture_hook(request): + captured_headers.update(dict(request.headers)) + + client = create_httpx_client_with_prediction_headers( + prediction=prediction, + prefix_template="test-{uuid}", + total_requests=10, + osl="MEDIUM", + iat="LOW", + ) + + # Add our capture hook + client.event_hooks["request"].append(capture_hook) + + # Make a test request (will fail, but headers will be captured) + try: + await client.post("http://localhost:1/test", json={}) + except Exception: + pass + + assert "x-nat-remaining-llm-calls" in captured_headers + assert captured_headers["x-nat-remaining-llm-calls"] == "3" + assert "x-nat-interarrival-ms" in captured_headers + assert captured_headers["x-nat-interarrival-ms"] == "500" + assert "x-nat-expected-output-tokens" in captured_headers + assert captured_headers["x-nat-expected-output-tokens"] == "200" # p90 value + + await client.aclose() From 7b8931c8cfba4841c652ac1667f8bd21e5b28447 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 08:51:59 -0800 Subject: [PATCH 13/37] feat(llm): add prediction_trie_path config to DynamoModelConfig Allows specifying a prediction_trie.json file path in workflow config. When set, predictions are looked up and injected as headers. Co-Authored-By: Claude Opus 4.5 Signed-off-by: dnandakumar-nv --- src/nat/llm/dynamo_llm.py | 7 ++ tests/nat/llm/test_dynamo_prediction_trie.py | 77 ++++++++++++++++++++ 2 files changed, 84 insertions(+) create mode 100644 tests/nat/llm/test_dynamo_prediction_trie.py diff --git a/src/nat/llm/dynamo_llm.py b/src/nat/llm/dynamo_llm.py index 6a82c9bdcf..10ddd98011 100644 --- a/src/nat/llm/dynamo_llm.py +++ b/src/nat/llm/dynamo_llm.py @@ -215,6 +215,12 @@ class DynamoModelConfig(OpenAIModelConfig, name="dynamo"): description="HTTP request timeout in seconds for LLM requests.", ) + prediction_trie_path: str | None = Field( + default=None, + description="Path to prediction_trie.json file. When set, predictions are " + "looked up and injected as headers for each LLM call.", + ) + # ========================================================================= # UTILITY METHODS # ========================================================================= @@ -243,6 +249,7 @@ def get_dynamo_field_names() -> frozenset[str]: "prefix_osl", "prefix_iat", "request_timeout", + "prediction_trie_path", }) diff --git a/tests/nat/llm/test_dynamo_prediction_trie.py b/tests/nat/llm/test_dynamo_prediction_trie.py new file mode 100644 index 0000000000..517013bedc --- /dev/null +++ b/tests/nat/llm/test_dynamo_prediction_trie.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import tempfile +from pathlib import Path + +import pytest + +from nat.llm.dynamo_llm import DynamoModelConfig +from nat.profiler.prediction_trie import PredictionTrieNode +from nat.profiler.prediction_trie import save_prediction_trie +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics + + +@pytest.fixture(name="trie_file") +def fixture_trie_file() -> Path: + """Create a temporary trie file for testing.""" + prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + + root = PredictionTrieNode( + name="root", + predictions_by_call_index={1: prediction}, + predictions_any_index=prediction, + ) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + path = Path(f.name) + + save_prediction_trie(root, path) + yield path + path.unlink(missing_ok=True) + + +def test_dynamo_config_with_trie_path(trie_file): + """Test that DynamoModelConfig accepts prediction_trie_path.""" + config = DynamoModelConfig( + base_url="http://localhost:8000", + model_name="test-model", + api_key="test-key", + prediction_trie_path=str(trie_file), + ) + + assert config.prediction_trie_path == str(trie_file) + assert "prediction_trie_path" in DynamoModelConfig.get_dynamo_field_names() + + +def test_dynamo_config_without_trie_path(): + """Test that DynamoModelConfig works without prediction_trie_path.""" + config = DynamoModelConfig( + base_url="http://localhost:8000", + model_name="test-model", + api_key="test-key", + ) + + assert config.prediction_trie_path is None + + +def test_dynamo_field_names_excludes_trie_path(): + """Test that prediction_trie_path is excluded from OpenAI client kwargs.""" + config = DynamoModelConfig( + base_url="http://localhost:8000", + model_name="test-model", + api_key="test-key", + prediction_trie_path="/path/to/trie.json", + ) + + # Simulate what would be passed to an OpenAI client + exclude_fields = {"type", "thinking", *DynamoModelConfig.get_dynamo_field_names()} + config_dict = config.model_dump(exclude=exclude_fields, exclude_none=True) + + # prediction_trie_path should not be in the config dict + assert "prediction_trie_path" not in config_dict From 52fb243d6ccd52bdfef711baca3affc3885a1a97 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 08:56:23 -0800 Subject: [PATCH 14/37] test(profiler): add end-to-end prediction trie test Validates complete flow: profiler traces -> trie generation -> lookup with different agents and call indices. Co-Authored-By: Claude Opus 4.5 Signed-off-by: dnandakumar-nv --- .../nat/profiler/test_prediction_trie_e2e.py | 116 ++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 tests/nat/profiler/test_prediction_trie_e2e.py diff --git a/tests/nat/profiler/test_prediction_trie_e2e.py b/tests/nat/profiler/test_prediction_trie_e2e.py new file mode 100644 index 0000000000..70ea27b5d1 --- /dev/null +++ b/tests/nat/profiler/test_prediction_trie_e2e.py @@ -0,0 +1,116 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""End-to-end test for prediction trie workflow.""" + +import tempfile +from pathlib import Path + +from nat.data_models.intermediate_step import IntermediateStep +from nat.data_models.intermediate_step import IntermediateStepPayload +from nat.data_models.intermediate_step import IntermediateStepType +from nat.data_models.intermediate_step import UsageInfo +from nat.data_models.invocation_node import InvocationNode +from nat.data_models.profiler import PredictionTrieConfig +from nat.data_models.profiler import ProfilerConfig +from nat.profiler.callbacks.token_usage_base_model import TokenUsageBaseModel +from nat.profiler.prediction_trie import PredictionTrieLookup +from nat.profiler.prediction_trie import load_prediction_trie +from nat.profiler.profile_runner import ProfilerRunner + + +def make_agent_trace(agent_name: str, num_llm_calls: int, base_timestamp: float) -> list[IntermediateStep]: + """Create a trace with multiple LLM calls in an agent.""" + steps = [] + ts = base_timestamp + + for i in range(num_llm_calls): + llm_uuid = f"llm-{agent_name}-{i}" + + # LLM_START + steps.append( + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id=f"{agent_name}-1", + function_name=agent_name, + parent_id="workflow-1", + parent_name="my_workflow", + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_START, + event_timestamp=ts, + UUID=llm_uuid, + ), + )) + ts += 0.5 + + # LLM_END + completion_tokens = 100 + (i * 50) # Vary tokens by position + steps.append( + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id=f"{agent_name}-1", + function_name=agent_name, + parent_id="workflow-1", + parent_name="my_workflow", + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_END, + event_timestamp=ts, + span_event_timestamp=ts - 0.5, + UUID=llm_uuid, + usage_info=UsageInfo(token_usage=TokenUsageBaseModel(completion_tokens=completion_tokens)), + ), + )) + ts += 0.5 + + return steps + + +async def test_e2e_prediction_trie_workflow(): + """Test the complete flow: profiler -> trie -> lookup.""" + with tempfile.TemporaryDirectory() as tmpdir: + output_dir = Path(tmpdir) + + # Create multiple traces with different agents + traces = [ + make_agent_trace("react_agent", num_llm_calls=3, base_timestamp=1000.0), + make_agent_trace("react_agent", num_llm_calls=3, base_timestamp=2000.0), + make_agent_trace("tool_agent", num_llm_calls=2, base_timestamp=3000.0), + ] + + # Run profiler + config = ProfilerConfig( + base_metrics=True, + prediction_trie=PredictionTrieConfig(enable=True), + ) + runner = ProfilerRunner(config, output_dir) + await runner.run(traces) + + # Load trie + trie_path = output_dir / "prediction_trie.json" + assert trie_path.exists(), "Trie file should exist" + + trie = load_prediction_trie(trie_path) + lookup = PredictionTrieLookup(trie) + + # Test lookups + # react_agent has 3 LLM calls, so at call 1 there are 2 remaining + result = lookup.find(path=["my_workflow", "react_agent"], call_index=1) + assert result is not None + assert result.remaining_calls.mean == 2.0 # 2 remaining after first call + + # At call 3 there are 0 remaining + result = lookup.find(path=["my_workflow", "react_agent"], call_index=3) + assert result is not None + assert result.remaining_calls.mean == 0.0 + + # tool_agent should have different stats + result = lookup.find(path=["my_workflow", "tool_agent"], call_index=1) + assert result is not None + assert result.remaining_calls.mean == 1.0 # 1 remaining after first call + + # Unknown agent should fall back to aggregated + result = lookup.find(path=["my_workflow", "unknown_agent"], call_index=1) + assert result is not None # Should still get a result from fallback From 6d36b20f53d8d44ca92188b604c45b684c22da29 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 09:20:59 -0800 Subject: [PATCH 15/37] docs: add runtime prediction trie integration design Design document for integrating the prediction trie with runtime workflow execution: - Add function_path_stack ContextVar for full ancestry tracking - Increment call tracker in IntermediateStepManager on LLM_START - Dynamic httpx hook for per-request prediction lookup - Fallback chain to root aggregates when no match found Co-Authored-By: Claude Opus 4.5 Signed-off-by: dnandakumar-nv --- ...26-01-24-runtime-prediction-trie-design.md | 248 ++++++++++++++++++ 1 file changed, 248 insertions(+) create mode 100644 docs/plans/2026-01-24-runtime-prediction-trie-design.md diff --git a/docs/plans/2026-01-24-runtime-prediction-trie-design.md b/docs/plans/2026-01-24-runtime-prediction-trie-design.md new file mode 100644 index 0000000000..b7d9bd3109 --- /dev/null +++ b/docs/plans/2026-01-24-runtime-prediction-trie-design.md @@ -0,0 +1,248 @@ +# Runtime Prediction Trie Integration Design + +## Overview + +This design addresses the gap between the prediction trie (built by the profiler) and runtime execution. Currently, the trie is built and saved, but never loaded or used during actual workflow execution to inject prediction headers. + +## Problem Statement + +The prediction trie implementation has the following gaps: + +1. **Trie never loaded at runtime** - `prediction_trie_path` config exists but is never used +2. **Function path not tracked for lookups** - `Context.active_function` only stores immediate parent, not full ancestry +3. **Call index never tracked at runtime** - `LLMCallTracker` exists but is never incremented during LLM calls +4. **Headers are static** - httpx client created once with static hooks; predictions need dynamic per-call lookup + +## Design Goals + +- Track full function path ancestry during workflow execution +- Track LLM call indices per parent function +- Look up predictions dynamically on each LLM call +- Inject prediction headers for Dynamo routing optimization +- Work across all LLM frameworks (LangChain, LlamaIndex, etc.) + +## Architecture + +### Separation of Concerns + +| Concern | Component | Scope | +|---------|-----------|-------| +| State tracking | Callback handlers + IntermediateStepManager | All LLM providers | +| Header injection | Dynamo httpx hook | Dynamo LLM only | + +This separation ensures state is tracked universally (even if multiple LLM providers are used in one workflow), while header injection is specific to Dynamo. + +### Data Flow + +``` +1. Workflow starts + └─► function_path_stack = ["my_workflow"] + +2. Agent function called via push_active_function("react_agent") + └─► function_path_stack = ["my_workflow", "react_agent"] + +3. LLM call initiated + └─► Callback fires on_chat_model_start + └─► IntermediateStepManager.push_intermediate_step(LLM_START) + └─► call_tracker.increment(parent_function_id) → 1 + +4. httpx sends request (Dynamo) + └─► Dynamic hook executes: + ├─► Read function_path_stack → ["my_workflow", "react_agent"] + ├─► Read call_tracker count → 1 + ├─► trie_lookup.find(path, call_index) → prediction + │ └─► (fallback to root.predictions_any_index if no match) + └─► Inject headers + +5. Next LLM call → call_index becomes 2, repeat +``` + +## Components to Modify + +### 1. ContextState (src/nat/builder/context.py) + +Add new ContextVar to track full function path: + +```python +class ContextState: + def __init__(self): + # ... existing fields ... + self._function_path_stack: ContextVar[list[str] | None] = ContextVar( + "function_path_stack", default=None + ) + + @property + def function_path_stack(self) -> ContextVar[list[str]]: + if self._function_path_stack.get() is None: + self._function_path_stack.set([]) + return typing.cast(ContextVar[list[str]], self._function_path_stack) +``` + +### 2. Context.push_active_function() (src/nat/builder/context.py) + +Update to push/pop function names on path stack: + +```python +@contextmanager +def push_active_function(self, function_name: str, ...): + # ... existing code ... + + # Push function name onto path stack + current_path = self._context_state.function_path_stack.get() + new_path = current_path + [function_name] + path_token = self._context_state.function_path_stack.set(new_path) + + try: + yield manager + finally: + # ... existing cleanup ... + self._context_state.function_path_stack.reset(path_token) +``` + +### 3. IntermediateStepManager.push_intermediate_step() (src/nat/builder/intermediate_step_manager.py) + +Increment call tracker on LLM_START events: + +```python +from nat.llm.prediction_context import get_call_tracker + +def push_intermediate_step(self, payload: IntermediateStepPayload) -> None: + # ... existing code ... + + # Track LLM call index for prediction lookups + if payload.event_type == IntermediateStepType.LLM_START: + active_function = self._context_state.active_function.get() + if active_function: + tracker = get_call_tracker() + tracker.increment(active_function.function_id) + + # ... rest of existing code ... +``` + +### 4. Context.function_path property (src/nat/builder/context.py) + +Add property to read current function path: + +```python +@property +def function_path(self) -> list[str]: + """Returns the current function path stack (copy).""" + return list(self._context_state.function_path_stack.get()) +``` + +### 5. dynamo_langchain() (packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py) + +Load trie and create dynamic hook: + +```python +from nat.profiler.prediction_trie import load_prediction_trie, PredictionTrieLookup + +@register_llm_client(config_type=DynamoModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN) +async def dynamo_langchain(llm_config: DynamoModelConfig, _builder: Builder): + # Load prediction trie if configured + trie_lookup: PredictionTrieLookup | None = None + if llm_config.prediction_trie_path: + trie = load_prediction_trie(Path(llm_config.prediction_trie_path)) + trie_lookup = PredictionTrieLookup(trie) + logger.info("Loaded prediction trie from %s", llm_config.prediction_trie_path) + + # Create httpx client with dynamic prediction hook + if llm_config.prefix_template is not None: + http_async_client = create_httpx_client_with_dynamo_hooks( + # ... existing params ... + prediction_lookup=trie_lookup, # Pass lookup to hook + ) +``` + +### 6. Dynamic Prediction Hook (src/nat/llm/dynamo_llm.py) + +Create hook that reads context and looks up predictions: + +```python +def _create_dynamic_prediction_hook( + trie_lookup: PredictionTrieLookup, +) -> Callable[["httpx.Request"], Coroutine[Any, Any, None]]: + """Create hook that dynamically looks up predictions per request.""" + + async def on_request(request: "httpx.Request") -> None: + from nat.builder.context import Context + from nat.llm.prediction_context import get_call_tracker + + ctx = Context.get() + path = ctx.function_path + + # Get call index for current parent function + call_index = 1 # default + active_fn = ctx.active_function + if active_fn: + tracker = get_call_tracker() + call_index = tracker.counts.get(active_fn.function_id, 1) + + # Look up prediction + prediction = trie_lookup.find(path, call_index) + + if prediction: + request.headers["x-nat-remaining-llm-calls"] = str(int(prediction.remaining_calls.mean)) + request.headers["x-nat-interarrival-ms"] = str(int(prediction.interarrival_ms.mean)) + request.headers["x-nat-expected-output-tokens"] = str(int(prediction.output_tokens.p90)) + + logger.debug( + "Injected prediction headers: path=%s, call_index=%d, remaining=%d", + path, call_index, int(prediction.remaining_calls.mean) + ) + + return on_request +``` + +## Fallback Chain + +When looking up predictions, the following fallback chain applies: + +1. **Exact match**: path + call_index found in trie +2. **Partial path**: walk trie as far as possible, use deepest match +3. **Any index**: use node's `predictions_any_index` if exact call_index not found +4. **Root fallback**: use root's `predictions_any_index` as final fallback + +This ensures we always have some prediction to inject (root aggregates across all profiled traces). + +## Call Index Tracking + +- Each function invocation has a unique UUID (`function_id`) +- `LLMCallTracker.increment(function_id)` returns 1, 2, 3... for successive LLM calls +- No explicit reset needed - new function invocations get new UUIDs automatically +- Memory is minimal (dict of int counters) and garbage collected with context + +## Headers Injected + +| Header | Value | Description | +|--------|-------|-------------| +| `x-nat-remaining-llm-calls` | `int(prediction.remaining_calls.mean)` | Expected remaining LLM calls | +| `x-nat-interarrival-ms` | `int(prediction.interarrival_ms.mean)` | Expected ms until next call | +| `x-nat-expected-output-tokens` | `int(prediction.output_tokens.p90)` | Expected output tokens (p90) | + +## Testing Strategy + +1. **Unit tests**: Test each component in isolation + - `function_path_stack` push/pop behavior + - Call tracker increment in IntermediateStepManager + - Dynamic hook reads context correctly + +2. **Integration test**: End-to-end flow + - Create trie from sample traces + - Run workflow with Dynamo LLM + - Verify headers injected with correct values + +3. **Fallback test**: Verify fallback chain + - Unknown path falls back to root + - Unknown call_index falls back to any_index + +## Files Changed + +| File | Type | Description | +|------|------|-------------| +| `src/nat/builder/context.py` | Modify | Add function_path_stack ContextVar and property | +| `src/nat/builder/intermediate_step_manager.py` | Modify | Increment call tracker on LLM_START | +| `src/nat/llm/dynamo_llm.py` | Modify | Add dynamic prediction hook | +| `packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py` | Modify | Load trie, wire up hook | +| `tests/nat/builder/test_function_path_stack.py` | New | Test path stack tracking | +| `tests/nat/llm/test_dynamic_prediction_hook.py` | New | Test dynamic lookup and injection | From 1e5d370a6d0f04d0dc6f25184479bee14e409a35 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 09:23:50 -0800 Subject: [PATCH 16/37] docs: add runtime prediction trie implementation plan Detailed TDD implementation plan for integrating prediction trie lookups at runtime: - Task 1: Add function_path_stack ContextVar - Task 2: Track path in push_active_function - Task 3: Increment call tracker in IntermediateStepManager - Task 4: Create dynamic prediction hook - Task 5: Update httpx client creation - Task 6: Load trie in LangChain Dynamo client - Task 7: End-to-end integration test Co-Authored-By: Claude Opus 4.5 Signed-off-by: dnandakumar-nv --- ...-runtime-prediction-trie-implementation.md | 1066 +++++++++++++++++ 1 file changed, 1066 insertions(+) create mode 100644 docs/plans/2026-01-24-runtime-prediction-trie-implementation.md diff --git a/docs/plans/2026-01-24-runtime-prediction-trie-implementation.md b/docs/plans/2026-01-24-runtime-prediction-trie-implementation.md new file mode 100644 index 0000000000..e87f4ee540 --- /dev/null +++ b/docs/plans/2026-01-24-runtime-prediction-trie-implementation.md @@ -0,0 +1,1066 @@ +# Runtime Prediction Trie Integration Implementation Plan + +> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. + +**Goal:** Enable runtime prediction trie lookups to inject Dynamo headers based on current function path and LLM call index. + +**Architecture:** Add a function path stack ContextVar for tracking ancestry, increment call tracker in IntermediateStepManager on LLM_START events, and create a dynamic httpx hook that reads context and looks up predictions from a pre-loaded trie. + +**Tech Stack:** Python 3.11+, contextvars, Pydantic v2, httpx event hooks + +--- + +## Task 1: Add Function Path Stack to ContextState + +**Files:** +- Modify: `src/nat/builder/context.py:67-120` +- Test: `tests/nat/builder/test_function_path_stack.py` + +### Step 1: Write the failing test for function_path_stack + +```python +# tests/nat/builder/test_function_path_stack.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from nat.builder.context import ContextState + + +def test_function_path_stack_default_empty(): + """Test that function_path_stack starts empty.""" + state = ContextState.get() + # Reset to test fresh state + state._function_path_stack.set(None) + + path = state.function_path_stack.get() + assert path == [] + + +def test_function_path_stack_can_be_set(): + """Test that function_path_stack can be set and retrieved.""" + state = ContextState.get() + state.function_path_stack.set(["workflow", "agent"]) + + path = state.function_path_stack.get() + assert path == ["workflow", "agent"] +``` + +### Step 2: Run test to verify it fails + +Run: `pytest tests/nat/builder/test_function_path_stack.py::test_function_path_stack_default_empty -v` +Expected: FAIL with "AttributeError: 'ContextState' object has no attribute '_function_path_stack'" + +### Step 3: Add function_path_stack ContextVar to ContextState + +In `src/nat/builder/context.py`, add to `ContextState.__init__` after line 83: + +```python + self._function_path_stack: ContextVar[list[str] | None] = ContextVar("function_path_stack", default=None) +``` + +And add the property after `active_span_id_stack` property (after line 116): + +```python + @property + def function_path_stack(self) -> ContextVar[list[str]]: + if self._function_path_stack.get() is None: + self._function_path_stack.set([]) + return typing.cast(ContextVar[list[str]], self._function_path_stack) +``` + +### Step 4: Run test to verify it passes + +Run: `pytest tests/nat/builder/test_function_path_stack.py -v` +Expected: PASS + +### Step 5: Commit + +```bash +git add src/nat/builder/context.py tests/nat/builder/test_function_path_stack.py +git commit --signoff -m "feat(context): add function_path_stack ContextVar + +Tracks the full function ancestry path as a list of function names, +enabling prediction trie lookups at runtime." +``` + +--- + +## Task 2: Update push_active_function to Track Path Stack + +**Files:** +- Modify: `src/nat/builder/context.py:235-279` +- Test: `tests/nat/builder/test_function_path_stack.py` + +### Step 1: Write the failing test for push_active_function path tracking + +Add to `tests/nat/builder/test_function_path_stack.py`: + +```python +from nat.builder.context import Context + + +def test_push_active_function_updates_path_stack(): + """Test that push_active_function pushes/pops from path stack.""" + ctx = Context.get() + state = ctx._context_state + + # Reset path stack + state._function_path_stack.set(None) + + # Initially empty + assert state.function_path_stack.get() == [] + + with ctx.push_active_function("my_workflow", input_data=None): + assert state.function_path_stack.get() == ["my_workflow"] + + with ctx.push_active_function("react_agent", input_data=None): + assert state.function_path_stack.get() == ["my_workflow", "react_agent"] + + with ctx.push_active_function("tool_call", input_data=None): + assert state.function_path_stack.get() == ["my_workflow", "react_agent", "tool_call"] + + # After tool_call exits + assert state.function_path_stack.get() == ["my_workflow", "react_agent"] + + # After react_agent exits + assert state.function_path_stack.get() == ["my_workflow"] + + # After workflow exits + assert state.function_path_stack.get() == [] +``` + +### Step 2: Run test to verify it fails + +Run: `pytest tests/nat/builder/test_function_path_stack.py::test_push_active_function_updates_path_stack -v` +Expected: FAIL with assertion error (path stack not being updated) + +### Step 3: Update push_active_function to track path stack + +In `src/nat/builder/context.py`, modify `push_active_function` method. After line 252 (after setting fn_token), add: + +```python + # 1b) Push function name onto path stack + current_path = self._context_state.function_path_stack.get() + new_path = current_path + [function_name] + path_token = self._context_state.function_path_stack.set(new_path) +``` + +And in the finally block, before line 279 (before resetting fn_token), add: + +```python + # 4a) Pop function name from path stack + self._context_state.function_path_stack.reset(path_token) +``` + +### Step 4: Run test to verify it passes + +Run: `pytest tests/nat/builder/test_function_path_stack.py -v` +Expected: PASS + +### Step 5: Add function_path property to Context class + +Add after `active_function` property (around line 289): + +```python + @property + def function_path(self) -> list[str]: + """ + Returns a copy of the current function path stack. + + The function path represents the ancestry of the currently executing + function, from root to the current function. + + Returns: + list[str]: Copy of the function path stack. + """ + return list(self._context_state.function_path_stack.get()) +``` + +### Step 6: Write test for function_path property + +Add to `tests/nat/builder/test_function_path_stack.py`: + +```python +def test_context_function_path_property(): + """Test that Context.function_path returns a copy of the path stack.""" + ctx = Context.get() + state = ctx._context_state + + # Reset path stack + state._function_path_stack.set(None) + + with ctx.push_active_function("workflow", input_data=None): + with ctx.push_active_function("agent", input_data=None): + path = ctx.function_path + assert path == ["workflow", "agent"] + + # Verify it's a copy (modifications don't affect original) + path.append("modified") + assert ctx.function_path == ["workflow", "agent"] +``` + +### Step 7: Run all tests + +Run: `pytest tests/nat/builder/test_function_path_stack.py -v` +Expected: PASS + +### Step 8: Commit + +```bash +git add src/nat/builder/context.py tests/nat/builder/test_function_path_stack.py +git commit --signoff -m "feat(context): track function path in push_active_function + +Push/pop function names onto function_path_stack in push_active_function. +Add Context.function_path property to retrieve the current path." +``` + +--- + +## Task 3: Increment Call Tracker in IntermediateStepManager + +**Files:** +- Modify: `src/nat/builder/intermediate_step_manager.py:64-96` +- Test: `tests/nat/builder/test_call_tracker_integration.py` + +### Step 1: Write the failing test for call tracker integration + +```python +# tests/nat/builder/test_call_tracker_integration.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from nat.builder.context import Context +from nat.data_models.intermediate_step import IntermediateStepPayload +from nat.data_models.intermediate_step import IntermediateStepType +from nat.llm.prediction_context import get_call_tracker + + +def test_llm_start_increments_call_tracker(): + """Test that pushing an LLM_START step increments the call tracker.""" + ctx = Context.get() + step_manager = ctx.intermediate_step_manager + + with ctx.push_active_function("test_agent", input_data=None): + active_fn = ctx.active_function + tracker = get_call_tracker() + + # Initially no count for this function + assert tracker.counts.get(active_fn.function_id, 0) == 0 + + # Push LLM_START + step_manager.push_intermediate_step( + IntermediateStepPayload( + UUID="llm-call-1", + event_type=IntermediateStepType.LLM_START, + name="test-model", + ) + ) + + # Call tracker should be incremented + assert tracker.counts.get(active_fn.function_id) == 1 + + # Push another LLM_START + step_manager.push_intermediate_step( + IntermediateStepPayload( + UUID="llm-call-2", + event_type=IntermediateStepType.LLM_START, + name="test-model", + ) + ) + + # Should be 2 now + assert tracker.counts.get(active_fn.function_id) == 2 + + +def test_non_llm_start_does_not_increment_tracker(): + """Test that non-LLM_START events don't increment the tracker.""" + ctx = Context.get() + step_manager = ctx.intermediate_step_manager + + with ctx.push_active_function("test_agent_2", input_data=None): + active_fn = ctx.active_function + tracker = get_call_tracker() + + initial_count = tracker.counts.get(active_fn.function_id, 0) + + # Push TOOL_START (should not increment) + step_manager.push_intermediate_step( + IntermediateStepPayload( + UUID="tool-call-1", + event_type=IntermediateStepType.TOOL_START, + name="test-tool", + ) + ) + + # Count should be unchanged + assert tracker.counts.get(active_fn.function_id, 0) == initial_count +``` + +### Step 2: Run test to verify it fails + +Run: `pytest tests/nat/builder/test_call_tracker_integration.py::test_llm_start_increments_call_tracker -v` +Expected: FAIL with assertion error (count is 0, not 1) + +### Step 3: Add call tracker increment to IntermediateStepManager + +In `src/nat/builder/intermediate_step_manager.py`, add import at top: + +```python +from nat.data_models.intermediate_step import IntermediateStepType +from nat.llm.prediction_context import get_call_tracker +``` + +Then in `push_intermediate_step` method, after line 96 (after the debug log for START), add: + +```python + # Track LLM call index for prediction trie lookups + if payload.event_type == IntermediateStepType.LLM_START: + active_function = self._context_state.active_function.get() + if active_function and active_function.function_id != "root": + tracker = get_call_tracker() + tracker.increment(active_function.function_id) + logger.debug("Incremented LLM call tracker for %s to %d", + active_function.function_id, + tracker.counts.get(active_function.function_id, 0)) +``` + +### Step 4: Run test to verify it passes + +Run: `pytest tests/nat/builder/test_call_tracker_integration.py -v` +Expected: PASS + +### Step 5: Commit + +```bash +git add src/nat/builder/intermediate_step_manager.py tests/nat/builder/test_call_tracker_integration.py +git commit --signoff -m "feat(step-manager): increment call tracker on LLM_START + +IntermediateStepManager now increments LLMCallTracker when an LLM_START +event is pushed. This enables accurate call index tracking for prediction +trie lookups across all LLM frameworks." +``` + +--- + +## Task 4: Create Dynamic Prediction Hook + +**Files:** +- Modify: `src/nat/llm/dynamo_llm.py` +- Test: `tests/nat/llm/test_dynamic_prediction_hook.py` + +### Step 1: Write the failing test for dynamic prediction hook + +```python +# tests/nat/llm/test_dynamic_prediction_hook.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from nat.builder.context import Context +from nat.llm.dynamo_llm import _create_dynamic_prediction_hook +from nat.llm.prediction_context import get_call_tracker +from nat.profiler.prediction_trie import PredictionTrieLookup +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode + + +@pytest.fixture(name="sample_trie_lookup") +def fixture_sample_trie_lookup() -> PredictionTrieLookup: + """Create a sample trie lookup for testing.""" + prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + + agent_node = PredictionTrieNode( + name="react_agent", + predictions_by_call_index={1: prediction, 2: prediction}, + predictions_any_index=prediction, + ) + + workflow_node = PredictionTrieNode( + name="my_workflow", + children={"react_agent": agent_node}, + predictions_any_index=prediction, + ) + + root = PredictionTrieNode( + name="root", + children={"my_workflow": workflow_node}, + predictions_any_index=prediction, + ) + + return PredictionTrieLookup(root) + + +class MockRequest: + """Mock httpx.Request for testing.""" + + def __init__(self): + self.headers = {} + + +async def test_dynamic_hook_injects_headers(sample_trie_lookup): + """Test that dynamic hook injects prediction headers based on context.""" + ctx = Context.get() + state = ctx._context_state + + # Reset state + state._function_path_stack.set(None) + + hook = _create_dynamic_prediction_hook(sample_trie_lookup) + + with ctx.push_active_function("my_workflow", input_data=None): + with ctx.push_active_function("react_agent", input_data=None): + # Simulate LLM call tracker increment (normally done by step manager) + tracker = get_call_tracker() + tracker.increment(ctx.active_function.function_id) + + request = MockRequest() + await hook(request) + + assert "x-nat-remaining-llm-calls" in request.headers + assert request.headers["x-nat-remaining-llm-calls"] == "3" + assert request.headers["x-nat-interarrival-ms"] == "500" + assert request.headers["x-nat-expected-output-tokens"] == "200" + + +async def test_dynamic_hook_uses_root_fallback(sample_trie_lookup): + """Test that hook falls back to root prediction for unknown paths.""" + ctx = Context.get() + state = ctx._context_state + + # Reset state + state._function_path_stack.set(None) + + hook = _create_dynamic_prediction_hook(sample_trie_lookup) + + with ctx.push_active_function("unknown_workflow", input_data=None): + tracker = get_call_tracker() + tracker.increment(ctx.active_function.function_id) + + request = MockRequest() + await hook(request) + + # Should still inject headers from root fallback + assert "x-nat-remaining-llm-calls" in request.headers +``` + +### Step 2: Run test to verify it fails + +Run: `pytest tests/nat/llm/test_dynamic_prediction_hook.py::test_dynamic_hook_injects_headers -v` +Expected: FAIL with "cannot import name '_create_dynamic_prediction_hook'" + +### Step 3: Implement dynamic prediction hook + +Add to `src/nat/llm/dynamo_llm.py` after the existing `_create_prediction_request_hook` function (around line 383): + +```python +def _create_dynamic_prediction_hook( + trie_lookup: "PredictionTrieLookup", +) -> Callable[["httpx.Request"], Coroutine[Any, Any, None]]: + """ + Create an httpx event hook that dynamically looks up predictions per request. + + This hook reads the current function path and call index from context, + looks up the prediction in the trie, and injects headers. + + Args: + trie_lookup: The PredictionTrieLookup instance to query + + Returns: + An async function suitable for use as an httpx event hook. + """ + # Import here to avoid circular imports + from nat.profiler.prediction_trie import PredictionTrieLookup + + async def on_request(request: "httpx.Request") -> None: + """Look up prediction from context and inject headers.""" + from nat.builder.context import Context + from nat.llm.prediction_context import get_call_tracker + + try: + ctx = Context.get() + path = ctx.function_path + + # Get call index for current parent function + call_index = 1 # default + active_fn = ctx.active_function + if active_fn and active_fn.function_id != "root": + tracker = get_call_tracker() + call_index = tracker.counts.get(active_fn.function_id, 1) + + # Look up prediction + prediction = trie_lookup.find(path, call_index) + + if prediction: + request.headers["x-nat-remaining-llm-calls"] = str(int(prediction.remaining_calls.mean)) + request.headers["x-nat-interarrival-ms"] = str(int(prediction.interarrival_ms.mean)) + request.headers["x-nat-expected-output-tokens"] = str(int(prediction.output_tokens.p90)) + + logger.debug( + "Injected prediction headers: path=%s, call_index=%d, remaining=%d, interarrival=%d, output=%d", + path, + call_index, + int(prediction.remaining_calls.mean), + int(prediction.interarrival_ms.mean), + int(prediction.output_tokens.p90), + ) + else: + logger.debug("No prediction found for path=%s, call_index=%d", path, call_index) + + except Exception as e: + # Don't fail the request if prediction lookup fails + logger.warning("Failed to inject prediction headers: %s", e) + + return on_request +``` + +Also add the import at top of file (after existing TYPE_CHECKING imports): + +```python +if TYPE_CHECKING: + import httpx + from nat.profiler.prediction_trie import PredictionTrieLookup +``` + +### Step 4: Run test to verify it passes + +Run: `pytest tests/nat/llm/test_dynamic_prediction_hook.py -v` +Expected: PASS + +### Step 5: Commit + +```bash +git add src/nat/llm/dynamo_llm.py tests/nat/llm/test_dynamic_prediction_hook.py +git commit --signoff -m "feat(dynamo): add dynamic prediction hook + +Creates httpx hook that reads function path and call index from context, +looks up prediction in trie, and injects headers per-request." +``` + +--- + +## Task 5: Update create_httpx_client_with_dynamo_hooks + +**Files:** +- Modify: `src/nat/llm/dynamo_llm.py:325-355` +- Test: `tests/nat/llm/test_dynamo_prediction_hook.py` + +### Step 1: Write test for updated client creation + +Add to `tests/nat/llm/test_dynamic_prediction_hook.py`: + +```python +from nat.llm.dynamo_llm import create_httpx_client_with_dynamo_hooks + + +async def test_client_includes_prediction_hook_when_lookup_provided(sample_trie_lookup): + """Test that client includes prediction hook when trie_lookup is provided.""" + client = create_httpx_client_with_dynamo_hooks( + prefix_template="test-{uuid}", + total_requests=10, + osl="MEDIUM", + iat="LOW", + prediction_lookup=sample_trie_lookup, + ) + + # Should have 2 hooks: dynamo prefix + prediction + assert len(client.event_hooks["request"]) == 2 + + await client.aclose() + + +async def test_client_works_without_prediction_lookup(): + """Test that client works when prediction_lookup is None.""" + client = create_httpx_client_with_dynamo_hooks( + prefix_template="test-{uuid}", + total_requests=10, + osl="MEDIUM", + iat="LOW", + prediction_lookup=None, + ) + + # Should have 1 hook: dynamo prefix only + assert len(client.event_hooks["request"]) == 1 + + await client.aclose() +``` + +### Step 2: Run test to verify it fails + +Run: `pytest tests/nat/llm/test_dynamic_prediction_hook.py::test_client_includes_prediction_hook_when_lookup_provided -v` +Expected: FAIL with "unexpected keyword argument 'prediction_lookup'" + +### Step 3: Update create_httpx_client_with_dynamo_hooks + +Modify `create_httpx_client_with_dynamo_hooks` in `src/nat/llm/dynamo_llm.py`: + +```python +def create_httpx_client_with_dynamo_hooks( + prefix_template: str | None, + total_requests: int, + osl: str, + iat: str, + timeout: float = 600.0, + prediction_lookup: "PredictionTrieLookup | None" = None, +) -> "httpx.AsyncClient": + """ + Create an httpx.AsyncClient with Dynamo prefix header injection. + + This client can be passed to the OpenAI SDK to inject headers at the HTTP level, + making it framework-agnostic. + + Args: + prefix_template: Template string with {uuid} placeholder + total_requests: Expected number of requests for this prefix + osl: Output sequence length hint (LOW/MEDIUM/HIGH) + iat: Inter-arrival time hint (LOW/MEDIUM/HIGH) + timeout: HTTP request timeout in seconds + prediction_lookup: Optional PredictionTrieLookup for dynamic header injection + + Returns: + An httpx.AsyncClient configured with Dynamo header injection. + """ + import httpx + + hooks: list[Callable] = [] + + # Add Dynamo prefix hook + prefix_hook = _create_dynamo_request_hook(prefix_template, total_requests, osl, iat) + hooks.append(prefix_hook) + + # Add dynamic prediction hook if lookup provided + if prediction_lookup is not None: + prediction_hook = _create_dynamic_prediction_hook(prediction_lookup) + hooks.append(prediction_hook) + + return httpx.AsyncClient( + event_hooks={"request": hooks}, + timeout=httpx.Timeout(timeout), + ) +``` + +### Step 4: Run tests to verify they pass + +Run: `pytest tests/nat/llm/test_dynamic_prediction_hook.py -v` +Expected: PASS + +### Step 5: Commit + +```bash +git add src/nat/llm/dynamo_llm.py tests/nat/llm/test_dynamic_prediction_hook.py +git commit --signoff -m "feat(dynamo): add prediction_lookup param to client creation + +create_httpx_client_with_dynamo_hooks now accepts optional prediction_lookup +parameter. When provided, adds dynamic prediction hook to inject headers." +``` + +--- + +## Task 6: Load Trie in LangChain Dynamo Client + +**Files:** +- Modify: `packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py:202-252` +- Test: `tests/nat/plugins/langchain/test_dynamo_trie_loading.py` + +### Step 1: Write test for trie loading + +```python +# tests/nat/plugins/langchain/test_dynamo_trie_loading.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import tempfile +from pathlib import Path + +import pytest + +from nat.llm.dynamo_llm import DynamoModelConfig +from nat.profiler.prediction_trie import save_prediction_trie +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode + + +@pytest.fixture(name="trie_file") +def fixture_trie_file(): + """Create a temporary trie file.""" + prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + + root = PredictionTrieNode( + name="root", + predictions_by_call_index={1: prediction}, + predictions_any_index=prediction, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "prediction_trie.json" + save_prediction_trie(root, path, workflow_name="test") + yield str(path) + + +def test_dynamo_config_with_valid_trie_path(trie_file): + """Test that DynamoModelConfig can be created with valid trie path.""" + config = DynamoModelConfig( + base_url="http://localhost:8000/v1", + model_name="test-model", + api_key="test-key", + prediction_trie_path=trie_file, + ) + + assert config.prediction_trie_path == trie_file + + +def test_dynamo_config_with_nonexistent_trie_path(): + """Test that DynamoModelConfig accepts nonexistent path (validated at load time).""" + config = DynamoModelConfig( + base_url="http://localhost:8000/v1", + model_name="test-model", + api_key="test-key", + prediction_trie_path="/nonexistent/path/trie.json", + ) + + # Config creation should succeed; error happens at runtime + assert config.prediction_trie_path == "/nonexistent/path/trie.json" +``` + +### Step 2: Run tests + +Run: `pytest tests/nat/plugins/langchain/test_dynamo_trie_loading.py -v` +Expected: PASS (config validation already exists) + +### Step 3: Update dynamo_langchain to load trie + +Modify `packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py`. Add import at top: + +```python +from pathlib import Path + +from nat.profiler.prediction_trie import load_prediction_trie +from nat.profiler.prediction_trie import PredictionTrieLookup +``` + +Then modify the `dynamo_langchain` function (around line 202-252): + +```python +@register_llm_client(config_type=DynamoModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN) +async def dynamo_langchain(llm_config: DynamoModelConfig, _builder: Builder): + """ + Create a LangChain ChatOpenAI client for Dynamo with automatic prefix header injection. + + This client injects Dynamo prefix headers at the HTTP transport level using httpx event hooks, + enabling KV cache optimization and request routing. + """ + from langchain_openai import ChatOpenAI + + # Build config dict excluding Dynamo-specific and NAT-specific fields + config_dict = llm_config.model_dump( + exclude={"type", "thinking", "api_type", *DynamoModelConfig.get_dynamo_field_names()}, + by_alias=True, + exclude_none=True, + exclude_unset=True, + ) + + # Initialize http_async_client to None for proper cleanup + http_async_client = None + + # Load prediction trie if configured + prediction_lookup: PredictionTrieLookup | None = None + if llm_config.prediction_trie_path: + try: + trie_path = Path(llm_config.prediction_trie_path) + trie = load_prediction_trie(trie_path) + prediction_lookup = PredictionTrieLookup(trie) + logger.info("Loaded prediction trie from %s", llm_config.prediction_trie_path) + except FileNotFoundError: + logger.warning("Prediction trie file not found: %s", llm_config.prediction_trie_path) + except Exception as e: + logger.warning("Failed to load prediction trie: %s", e) + + try: + # If prefix_template is set, create a custom httpx client with Dynamo hooks + if llm_config.prefix_template is not None: + http_async_client = create_httpx_client_with_dynamo_hooks( + prefix_template=llm_config.prefix_template, + total_requests=llm_config.prefix_total_requests, + osl=llm_config.prefix_osl, + iat=llm_config.prefix_iat, + timeout=llm_config.request_timeout, + prediction_lookup=prediction_lookup, + ) + config_dict["http_async_client"] = http_async_client + logger.info( + "Dynamo prefix headers enabled: template=%s, total_requests=%d, osl=%s, iat=%s, prediction_trie=%s", + llm_config.prefix_template, + llm_config.prefix_total_requests, + llm_config.prefix_osl, + llm_config.prefix_iat, + "loaded" if prediction_lookup else "disabled", + ) + + # Create the ChatOpenAI client + if llm_config.api_type == APITypeEnum.RESPONSES: + client = ChatOpenAI(stream_usage=True, use_responses_api=True, use_previous_response_id=True, **config_dict) + else: + client = ChatOpenAI(stream_usage=True, **config_dict) + + yield _patch_llm_based_on_config(client, llm_config) + finally: + # Ensure the httpx client is properly closed to avoid resource leaks + if http_async_client is not None: + await http_async_client.aclose() +``` + +### Step 4: Run existing tests to ensure no regressions + +Run: `pytest tests/nat/plugins/langchain/ -v -k dynamo` +Expected: PASS + +### Step 5: Commit + +```bash +git add packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py tests/nat/plugins/langchain/test_dynamo_trie_loading.py +git commit --signoff -m "feat(langchain): load prediction trie in dynamo_langchain + +Loads prediction trie from prediction_trie_path config and passes +PredictionTrieLookup to httpx client for dynamic header injection." +``` + +--- + +## Task 7: End-to-End Integration Test + +**Files:** +- Test: `tests/nat/llm/test_runtime_prediction_e2e.py` + +### Step 1: Write end-to-end integration test + +```python +# tests/nat/llm/test_runtime_prediction_e2e.py +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""End-to-end test for runtime prediction trie integration.""" + +import tempfile +from pathlib import Path + +import pytest + +from nat.builder.context import Context +from nat.data_models.intermediate_step import IntermediateStepPayload +from nat.data_models.intermediate_step import IntermediateStepType +from nat.llm.dynamo_llm import _create_dynamic_prediction_hook +from nat.llm.prediction_context import get_call_tracker +from nat.profiler.prediction_trie import load_prediction_trie +from nat.profiler.prediction_trie import PredictionTrieLookup +from nat.profiler.prediction_trie import save_prediction_trie +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode + + +class MockRequest: + """Mock httpx.Request for testing.""" + + def __init__(self): + self.headers = {} + + +def create_test_trie() -> PredictionTrieNode: + """Create a test trie with known predictions.""" + # Agent at call 1: 2 remaining, 500ms interarrival, 150 tokens + call_1_prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=2.0, p50=2.0, p90=3.0, p95=4.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + + # Agent at call 2: 1 remaining, 300ms interarrival, 100 tokens + call_2_prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=1.0, p50=1.0, p90=2.0, p95=2.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=300.0, p50=280.0, p90=400.0, p95=450.0), + output_tokens=PredictionMetrics(sample_count=10, mean=100.0, p50=90.0, p90=150.0, p95=180.0), + ) + + # Agent at call 3: 0 remaining + call_3_prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=0.0, p50=0.0, p90=0.0, p95=0.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=0.0, p50=0.0, p90=0.0, p95=0.0), + output_tokens=PredictionMetrics(sample_count=10, mean=80.0, p50=75.0, p90=120.0, p95=140.0), + ) + + # Aggregated for fallback + aggregated = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=30, mean=1.0, p50=1.0, p90=2.0, p95=3.0), + interarrival_ms=PredictionMetrics(sample_count=30, mean=400.0, p50=380.0, p90=550.0, p95=600.0), + output_tokens=PredictionMetrics(sample_count=30, mean=110.0, p50=100.0, p90=160.0, p95=190.0), + ) + + agent_node = PredictionTrieNode( + name="react_agent", + predictions_by_call_index={1: call_1_prediction, 2: call_2_prediction, 3: call_3_prediction}, + predictions_any_index=aggregated, + ) + + workflow_node = PredictionTrieNode( + name="my_workflow", + children={"react_agent": agent_node}, + predictions_any_index=aggregated, + ) + + return PredictionTrieNode( + name="root", + children={"my_workflow": workflow_node}, + predictions_any_index=aggregated, + ) + + +async def test_e2e_prediction_headers_injected_correctly(): + """Test complete flow: context tracking -> step manager -> hook -> headers.""" + # Create and save trie + trie = create_test_trie() + + with tempfile.TemporaryDirectory() as tmpdir: + trie_path = Path(tmpdir) / "prediction_trie.json" + save_prediction_trie(trie, trie_path, workflow_name="test") + + # Load trie + loaded_trie = load_prediction_trie(trie_path) + lookup = PredictionTrieLookup(loaded_trie) + + # Create hook + hook = _create_dynamic_prediction_hook(lookup) + + ctx = Context.get() + state = ctx._context_state + step_manager = ctx.intermediate_step_manager + + # Reset state + state._function_path_stack.set(None) + + with ctx.push_active_function("my_workflow", input_data=None): + with ctx.push_active_function("react_agent", input_data=None): + # Simulate first LLM call + step_manager.push_intermediate_step( + IntermediateStepPayload( + UUID="llm-1", + event_type=IntermediateStepType.LLM_START, + name="test-model", + ) + ) + + request1 = MockRequest() + await hook(request1) + + # Should have call 1 predictions: 2 remaining + assert request1.headers["x-nat-remaining-llm-calls"] == "2" + assert request1.headers["x-nat-interarrival-ms"] == "500" + assert request1.headers["x-nat-expected-output-tokens"] == "200" + + # Simulate second LLM call + step_manager.push_intermediate_step( + IntermediateStepPayload( + UUID="llm-2", + event_type=IntermediateStepType.LLM_START, + name="test-model", + ) + ) + + request2 = MockRequest() + await hook(request2) + + # Should have call 2 predictions: 1 remaining + assert request2.headers["x-nat-remaining-llm-calls"] == "1" + assert request2.headers["x-nat-interarrival-ms"] == "300" + assert request2.headers["x-nat-expected-output-tokens"] == "150" + + # Simulate third LLM call + step_manager.push_intermediate_step( + IntermediateStepPayload( + UUID="llm-3", + event_type=IntermediateStepType.LLM_START, + name="test-model", + ) + ) + + request3 = MockRequest() + await hook(request3) + + # Should have call 3 predictions: 0 remaining + assert request3.headers["x-nat-remaining-llm-calls"] == "0" + assert request3.headers["x-nat-expected-output-tokens"] == "120" + + +async def test_e2e_fallback_to_root(): + """Test that unknown paths fall back to root predictions.""" + trie = create_test_trie() + lookup = PredictionTrieLookup(trie) + hook = _create_dynamic_prediction_hook(lookup) + + ctx = Context.get() + state = ctx._context_state + step_manager = ctx.intermediate_step_manager + + # Reset state + state._function_path_stack.set(None) + + with ctx.push_active_function("unknown_workflow", input_data=None): + step_manager.push_intermediate_step( + IntermediateStepPayload( + UUID="llm-unknown", + event_type=IntermediateStepType.LLM_START, + name="test-model", + ) + ) + + request = MockRequest() + await hook(request) + + # Should fall back to root aggregated predictions + assert "x-nat-remaining-llm-calls" in request.headers + assert request.headers["x-nat-remaining-llm-calls"] == "1" # aggregated mean +``` + +### Step 2: Run e2e test + +Run: `pytest tests/nat/llm/test_runtime_prediction_e2e.py -v` +Expected: PASS + +### Step 3: Commit + +```bash +git add tests/nat/llm/test_runtime_prediction_e2e.py +git commit --signoff -m "test: add end-to-end runtime prediction trie test + +Validates complete flow: function path tracking -> call tracker increment +-> dynamic hook lookup -> correct headers injected for each call index." +``` + +--- + +## Summary + +This plan implements runtime prediction trie integration in 7 tasks: + +1. **Function Path Stack** - Add ContextVar to ContextState +2. **Path Tracking** - Update push_active_function to track path +3. **Call Tracker Integration** - Increment tracker in IntermediateStepManager on LLM_START +4. **Dynamic Hook** - Create hook that reads context and looks up predictions +5. **Client Update** - Add prediction_lookup param to client creation +6. **LangChain Integration** - Load trie in dynamo_langchain +7. **E2E Test** - Validate complete flow + +Each task follows TDD: write failing test, implement, verify, commit. From 6137fb7e7c8448210fdbd3948809062c39177b85 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 09:27:06 -0800 Subject: [PATCH 17/37] feat(context): add function_path_stack ContextVar to ContextState Add a new ContextVar to track the full function ancestry path as a list of function names. This will be used by the runtime prediction trie integration to perform prediction lookups using the full path (e.g., ["my_workflow", "react_agent", "tool"]). The implementation follows the existing pattern of active_span_id_stack, using a private ContextVar with None default and a property that lazily initializes to an empty list. Co-Authored-By: Claude Opus 4.5 Signed-off-by: dnandakumar-nv --- src/nat/builder/context.py | 7 ++++++ tests/nat/builder/test_function_path_stack.py | 24 +++++++++++++++++++ 2 files changed, 31 insertions(+) create mode 100644 tests/nat/builder/test_function_path_stack.py diff --git a/src/nat/builder/context.py b/src/nat/builder/context.py index d74551878a..54387c9bea 100644 --- a/src/nat/builder/context.py +++ b/src/nat/builder/context.py @@ -81,6 +81,7 @@ def __init__(self): self._event_stream: ContextVar[Subject[IntermediateStep] | None] = ContextVar("event_stream", default=None) self._active_function: ContextVar[InvocationNode | None] = ContextVar("active_function", default=None) self._active_span_id_stack: ContextVar[list[str] | None] = ContextVar("active_span_id_stack", default=None) + self._function_path_stack: ContextVar[list[str] | None] = ContextVar("function_path_stack", default=None) # Default is a lambda no-op which returns NoneType self.user_input_callback: ContextVar[Callable[[InteractionPrompt], Awaitable[HumanResponse | None]] @@ -115,6 +116,12 @@ def active_span_id_stack(self) -> ContextVar[list[str]]: self._active_span_id_stack.set(["root"]) return typing.cast(ContextVar[list[str]], self._active_span_id_stack) + @property + def function_path_stack(self) -> ContextVar[list[str]]: + if self._function_path_stack.get() is None: + self._function_path_stack.set([]) + return typing.cast(ContextVar[list[str]], self._function_path_stack) + @staticmethod def get() -> "ContextState": return ContextState() diff --git a/tests/nat/builder/test_function_path_stack.py b/tests/nat/builder/test_function_path_stack.py new file mode 100644 index 0000000000..54bfe7cfd3 --- /dev/null +++ b/tests/nat/builder/test_function_path_stack.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +from nat.builder.context import ContextState + + +def test_function_path_stack_default_empty(): + """Test that function_path_stack starts empty.""" + state = ContextState.get() + # Reset to test fresh state + state._function_path_stack.set(None) + + path = state.function_path_stack.get() + assert path == [] + + +def test_function_path_stack_can_be_set(): + """Test that function_path_stack can be set and retrieved.""" + state = ContextState.get() + state.function_path_stack.set(["workflow", "agent"]) + + path = state.function_path_stack.get() + assert path == ["workflow", "agent"] From b91d1f2d260653fb6e750aaa09daf6879e10ff36 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 09:31:29 -0800 Subject: [PATCH 18/37] feat(context): update push_active_function to track function path stack Update the push_active_function context manager to push/pop function names on the function_path_stack ContextVar. This enables tracking the complete ancestry of the currently executing function from root to leaf. Changes: - Push function name onto path stack when entering push_active_function - Pop function name using ContextVar.reset(token) when exiting - Add Context.function_path property that returns a copy of the path stack Co-Authored-By: Claude Opus 4.5 Signed-off-by: dnandakumar-nv --- src/nat/builder/context.py | 23 ++++++++- tests/nat/builder/test_function_path_stack.py | 50 ++++++++++++++++++- 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/src/nat/builder/context.py b/src/nat/builder/context.py index 54387c9bea..e22a1618c9 100644 --- a/src/nat/builder/context.py +++ b/src/nat/builder/context.py @@ -258,6 +258,11 @@ def push_active_function(self, # 1) Set the active function in the contextvar fn_token = self._context_state.active_function.set(current_function_node) + # 1b) Push function name onto path stack + current_path = self._context_state.function_path_stack.get() + new_path = current_path + [function_name] + path_token = self._context_state.function_path_stack.set(new_path) + # 2) Optionally record function start as an intermediate step step_manager = self.intermediate_step_manager step_manager.push_intermediate_step( @@ -282,7 +287,10 @@ def push_active_function(self, name=function_name, data=data)) - # 4) Unset the function contextvar + # 4a) Pop function name from path stack + self._context_state.function_path_stack.reset(path_token) + + # 4b) Unset the function contextvar self._context_state.active_function.reset(fn_token) @property @@ -295,6 +303,19 @@ def active_function(self) -> InvocationNode: """ return self._context_state.active_function.get() + @property + def function_path(self) -> list[str]: + """ + Returns a copy of the current function path stack. + + The function path represents the ancestry of the currently executing + function, from root to the current function. + + Returns: + list[str]: Copy of the function path stack. + """ + return list(self._context_state.function_path_stack.get()) + @property def active_span_id(self) -> str: """ diff --git a/tests/nat/builder/test_function_path_stack.py b/tests/nat/builder/test_function_path_stack.py index 54bfe7cfd3..da65a7fa26 100644 --- a/tests/nat/builder/test_function_path_stack.py +++ b/tests/nat/builder/test_function_path_stack.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 - +from nat.builder.context import Context from nat.builder.context import ContextState @@ -22,3 +22,51 @@ def test_function_path_stack_can_be_set(): path = state.function_path_stack.get() assert path == ["workflow", "agent"] + + +def test_push_active_function_updates_path_stack(): + """Test that push_active_function pushes/pops from path stack.""" + ctx = Context.get() + state = ctx._context_state + + # Reset path stack + state._function_path_stack.set(None) + + # Initially empty + assert state.function_path_stack.get() == [] + + with ctx.push_active_function("my_workflow", input_data=None): + assert state.function_path_stack.get() == ["my_workflow"] + + with ctx.push_active_function("react_agent", input_data=None): + assert state.function_path_stack.get() == ["my_workflow", "react_agent"] + + with ctx.push_active_function("tool_call", input_data=None): + assert state.function_path_stack.get() == ["my_workflow", "react_agent", "tool_call"] + + # After tool_call exits + assert state.function_path_stack.get() == ["my_workflow", "react_agent"] + + # After react_agent exits + assert state.function_path_stack.get() == ["my_workflow"] + + # After workflow exits + assert state.function_path_stack.get() == [] + + +def test_context_function_path_property(): + """Test that Context.function_path returns a copy of the path stack.""" + ctx = Context.get() + state = ctx._context_state + + # Reset path stack + state._function_path_stack.set(None) + + with ctx.push_active_function("workflow", input_data=None): + with ctx.push_active_function("agent", input_data=None): + path = ctx.function_path + assert path == ["workflow", "agent"] + + # Verify it's a copy (modifications don't affect original) + path.append("modified") + assert ctx.function_path == ["workflow", "agent"] From d4f02f22431bc66ed9b70a689131dbdce80a3e04 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 09:35:48 -0800 Subject: [PATCH 19/37] feat(step_manager): increment LLM call tracker on LLM_START events Update IntermediateStepManager.push_intermediate_step() to increment the LLMCallTracker whenever an LLM_START event is pushed. This ensures call indices are tracked for all LLM frameworks (LangChain, LlamaIndex, etc.) since they all push events through this manager. Co-Authored-By: Claude Opus 4.5 Signed-off-by: dnandakumar-nv --- src/nat/builder/intermediate_step_manager.py | 12 ++++ .../builder/test_call_tracker_integration.py | 66 +++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 tests/nat/builder/test_call_tracker_integration.py diff --git a/src/nat/builder/intermediate_step_manager.py b/src/nat/builder/intermediate_step_manager.py index 2f9af79fa5..015a169451 100644 --- a/src/nat/builder/intermediate_step_manager.py +++ b/src/nat/builder/intermediate_step_manager.py @@ -22,6 +22,8 @@ from nat.data_models.intermediate_step import IntermediateStep from nat.data_models.intermediate_step import IntermediateStepPayload from nat.data_models.intermediate_step import IntermediateStepState +from nat.data_models.intermediate_step import IntermediateStepType +from nat.llm.prediction_context import get_call_tracker from nat.utils.reactive.observable import OnComplete from nat.utils.reactive.observable import OnError from nat.utils.reactive.observable import OnNext @@ -95,6 +97,16 @@ def push_intermediate_step(self, payload: IntermediateStepPayload) -> None: parent_step_id, id(active_span_id_stack)) + # Track LLM call index for prediction trie lookups + if payload.event_type == IntermediateStepType.LLM_START: + active_function = self._context_state.active_function.get() + if active_function and active_function.function_id != "root": + tracker = get_call_tracker() + tracker.increment(active_function.function_id) + logger.debug("Incremented LLM call tracker for %s to %d", + active_function.function_id, + tracker.counts.get(active_function.function_id, 0)) + elif (payload.event_state == IntermediateStepState.END): # Remove the current step from the outstanding steps diff --git a/tests/nat/builder/test_call_tracker_integration.py b/tests/nat/builder/test_call_tracker_integration.py new file mode 100644 index 0000000000..ec06ae5521 --- /dev/null +++ b/tests/nat/builder/test_call_tracker_integration.py @@ -0,0 +1,66 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +from nat.builder.context import Context +from nat.data_models.intermediate_step import IntermediateStepPayload +from nat.data_models.intermediate_step import IntermediateStepType +from nat.llm.prediction_context import get_call_tracker + + +def test_llm_start_increments_call_tracker(): + """Test that pushing an LLM_START step increments the call tracker.""" + ctx = Context.get() + step_manager = ctx.intermediate_step_manager + + with ctx.push_active_function("test_agent", input_data=None): + active_fn = ctx.active_function + tracker = get_call_tracker() + + # Initially no count for this function + assert tracker.counts.get(active_fn.function_id, 0) == 0 + + # Push LLM_START + step_manager.push_intermediate_step( + IntermediateStepPayload( + UUID="llm-call-1", + event_type=IntermediateStepType.LLM_START, + name="test-model", + )) + + # Call tracker should be incremented + assert tracker.counts.get(active_fn.function_id) == 1 + + # Push another LLM_START + step_manager.push_intermediate_step( + IntermediateStepPayload( + UUID="llm-call-2", + event_type=IntermediateStepType.LLM_START, + name="test-model", + )) + + # Should be 2 now + assert tracker.counts.get(active_fn.function_id) == 2 + + +def test_non_llm_start_does_not_increment_tracker(): + """Test that non-LLM_START events don't increment the tracker.""" + ctx = Context.get() + step_manager = ctx.intermediate_step_manager + + with ctx.push_active_function("test_agent_2", input_data=None): + active_fn = ctx.active_function + tracker = get_call_tracker() + + initial_count = tracker.counts.get(active_fn.function_id, 0) + + # Push TOOL_START (should not increment) + step_manager.push_intermediate_step( + IntermediateStepPayload( + UUID="tool-call-1", + event_type=IntermediateStepType.TOOL_START, + name="test-tool", + )) + + # Count should be unchanged + assert tracker.counts.get(active_fn.function_id, 0) == initial_count From fa7830d74ee408f1fa7134d4a5d9d016cb624860 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 09:40:01 -0800 Subject: [PATCH 20/37] Add dynamic prediction hook for runtime trie lookups Create _create_dynamic_prediction_hook function that dynamically looks up predictions from the trie based on current context (function path + call index) and injects headers for Dynamo optimization. The hook: - Reads Context.function_path to get current ancestry - Reads LLMCallTracker.counts to get current call index - Looks up prediction in trie using trie_lookup.find(path, call_index) - Injects headers: x-nat-remaining-llm-calls, x-nat-interarrival-ms, x-nat-expected-output-tokens This is part of the dynamic inference headers feature for KV cache optimization with NVIDIA Dynamo. Co-Authored-By: Claude Opus 4.5 Signed-off-by: dnandakumar-nv --- src/nat/llm/dynamo_llm.py | 59 ++++++++ tests/nat/llm/test_dynamic_prediction_hook.py | 138 ++++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100644 tests/nat/llm/test_dynamic_prediction_hook.py diff --git a/src/nat/llm/dynamo_llm.py b/src/nat/llm/dynamo_llm.py index 10ddd98011..89bdf21dc6 100644 --- a/src/nat/llm/dynamo_llm.py +++ b/src/nat/llm/dynamo_llm.py @@ -60,6 +60,8 @@ if TYPE_CHECKING: import httpx + from nat.profiler.prediction_trie import PredictionTrieLookup + from pydantic import Field from nat.builder.builder import Builder @@ -383,6 +385,63 @@ async def on_request(request): return on_request +def _create_dynamic_prediction_hook( + trie_lookup: "PredictionTrieLookup", ) -> Callable[["httpx.Request"], Coroutine[Any, Any, None]]: + """ + Create an httpx event hook that dynamically looks up predictions per request. + + This hook reads the current function path and call index from context, + looks up the prediction in the trie, and injects headers. + + Args: + trie_lookup: The PredictionTrieLookup instance to query + + Returns: + An async function suitable for use as an httpx event hook. + """ + + async def on_request(request: "httpx.Request") -> None: + """Look up prediction from context and inject headers.""" + from nat.builder.context import Context + from nat.llm.prediction_context import get_call_tracker + + try: + ctx = Context.get() + path = ctx.function_path + + # Get call index for current parent function + call_index = 1 # default + active_fn = ctx.active_function + if active_fn and active_fn.function_id != "root": + tracker = get_call_tracker() + call_index = tracker.counts.get(active_fn.function_id, 1) + + # Look up prediction + prediction = trie_lookup.find(path, call_index) + + if prediction: + request.headers["x-nat-remaining-llm-calls"] = str(int(prediction.remaining_calls.mean)) + request.headers["x-nat-interarrival-ms"] = str(int(prediction.interarrival_ms.mean)) + request.headers["x-nat-expected-output-tokens"] = str(int(prediction.output_tokens.p90)) + + logger.debug( + "Injected prediction headers: path=%s, call_index=%d, remaining=%d, interarrival=%d, output=%d", + path, + call_index, + int(prediction.remaining_calls.mean), + int(prediction.interarrival_ms.mean), + int(prediction.output_tokens.p90), + ) + else: + logger.debug("No prediction found for path=%s, call_index=%d", path, call_index) + + except Exception as e: + # Don't fail the request if prediction lookup fails + logger.warning("Failed to inject prediction headers: %s", e) + + return on_request + + def create_httpx_client_with_prediction_headers( prediction: LLMCallPrediction, prefix_template: str | None, diff --git a/tests/nat/llm/test_dynamic_prediction_hook.py b/tests/nat/llm/test_dynamic_prediction_hook.py new file mode 100644 index 0000000000..6abf491c86 --- /dev/null +++ b/tests/nat/llm/test_dynamic_prediction_hook.py @@ -0,0 +1,138 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from nat.builder.context import Context +from nat.llm.dynamo_llm import _create_dynamic_prediction_hook +from nat.llm.prediction_context import get_call_tracker +from nat.profiler.prediction_trie import PredictionTrieLookup +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode + + +@pytest.fixture(name="sample_trie_lookup") +def fixture_sample_trie_lookup() -> PredictionTrieLookup: + """Create a sample trie lookup for testing.""" + prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + + agent_node = PredictionTrieNode( + name="react_agent", + predictions_by_call_index={ + 1: prediction, 2: prediction + }, + predictions_any_index=prediction, + ) + + workflow_node = PredictionTrieNode( + name="my_workflow", + children={"react_agent": agent_node}, + predictions_any_index=prediction, + ) + + root = PredictionTrieNode( + name="root", + children={"my_workflow": workflow_node}, + predictions_any_index=prediction, + ) + + return PredictionTrieLookup(root) + + +class MockRequest: + """Mock httpx.Request for testing.""" + + def __init__(self): + self.headers = {} + + +async def test_dynamic_hook_injects_headers(sample_trie_lookup): + """Test that dynamic hook injects prediction headers based on context.""" + ctx = Context.get() + state = ctx._context_state + + # Reset state + state._function_path_stack.set(None) + + hook = _create_dynamic_prediction_hook(sample_trie_lookup) + + with ctx.push_active_function("my_workflow", input_data=None): + with ctx.push_active_function("react_agent", input_data=None): + # Simulate LLM call tracker increment (normally done by step manager) + tracker = get_call_tracker() + tracker.increment(ctx.active_function.function_id) + + request = MockRequest() + await hook(request) + + assert "x-nat-remaining-llm-calls" in request.headers + assert request.headers["x-nat-remaining-llm-calls"] == "3" + assert request.headers["x-nat-interarrival-ms"] == "500" + assert request.headers["x-nat-expected-output-tokens"] == "200" + + +async def test_dynamic_hook_uses_root_fallback(sample_trie_lookup): + """Test that hook falls back to root prediction for unknown paths.""" + ctx = Context.get() + state = ctx._context_state + + # Reset state + state._function_path_stack.set(None) + + hook = _create_dynamic_prediction_hook(sample_trie_lookup) + + with ctx.push_active_function("unknown_workflow", input_data=None): + tracker = get_call_tracker() + tracker.increment(ctx.active_function.function_id) + + request = MockRequest() + await hook(request) + + # Should fall back to root aggregated predictions + assert "x-nat-remaining-llm-calls" in request.headers + + +async def test_dynamic_hook_handles_empty_context(sample_trie_lookup): + """Test that hook handles missing context gracefully.""" + ctx = Context.get() + state = ctx._context_state + + # Reset state to empty + state._function_path_stack.set(None) + state._active_function.set(None) + + hook = _create_dynamic_prediction_hook(sample_trie_lookup) + + request = MockRequest() + # Should not raise an exception + await hook(request) + + # Should still inject headers from root fallback + assert "x-nat-remaining-llm-calls" in request.headers + + +async def test_dynamic_hook_no_prediction_found(): + """Test that hook handles case where no prediction is found.""" + # Create empty trie with no predictions + empty_root = PredictionTrieNode(name="root") + empty_trie = PredictionTrieLookup(empty_root) + + ctx = Context.get() + state = ctx._context_state + + # Reset state + state._function_path_stack.set(None) + + hook = _create_dynamic_prediction_hook(empty_trie) + + with ctx.push_active_function("some_function", input_data=None): + request = MockRequest() + await hook(request) + + # Headers should not be injected when no prediction found + assert "x-nat-remaining-llm-calls" not in request.headers From bd40ae09d6f5a9d02c909b9d78a9b667e370b54d Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 09:41:40 -0800 Subject: [PATCH 21/37] fix(test): include prediction_trie_path in dynamo field names test Update test expectation to include prediction_trie_path which was added to DynamoModelConfig.get_dynamo_field_names() in a previous task. Co-Authored-By: Claude Opus 4.5 Signed-off-by: dnandakumar-nv --- tests/nat/llm/test_dynamo_llm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/nat/llm/test_dynamo_llm.py b/tests/nat/llm/test_dynamo_llm.py index 2fce9b313b..d73203fdcf 100644 --- a/tests/nat/llm/test_dynamo_llm.py +++ b/tests/nat/llm/test_dynamo_llm.py @@ -134,6 +134,7 @@ def test_get_dynamo_field_names(self): "prefix_osl", "prefix_iat", "request_timeout", + "prediction_trie_path", }) assert field_names == expected From 2ec7d7c7473ff5e8037fa5a2671478879f48d977 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 09:44:54 -0800 Subject: [PATCH 22/37] Add prediction_lookup parameter to create_httpx_client_with_dynamo_hooks Update the function to accept an optional PredictionTrieLookup parameter. When provided, adds the dynamic prediction hook to the list of hooks, enabling runtime header injection based on trie predictions. Co-Authored-By: Claude Opus 4.5 Signed-off-by: dnandakumar-nv --- src/nat/llm/dynamo_llm.py | 15 +++++++-- tests/nat/llm/test_dynamic_prediction_hook.py | 33 +++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/src/nat/llm/dynamo_llm.py b/src/nat/llm/dynamo_llm.py index 89bdf21dc6..6931a958ca 100644 --- a/src/nat/llm/dynamo_llm.py +++ b/src/nat/llm/dynamo_llm.py @@ -330,6 +330,7 @@ def create_httpx_client_with_dynamo_hooks( osl: str, iat: str, timeout: float = 600.0, + prediction_lookup: "PredictionTrieLookup | None" = None, ) -> "httpx.AsyncClient": """ Create an httpx.AsyncClient with Dynamo prefix header injection. @@ -343,16 +344,26 @@ def create_httpx_client_with_dynamo_hooks( osl: Output sequence length hint (LOW/MEDIUM/HIGH) iat: Inter-arrival time hint (LOW/MEDIUM/HIGH) timeout: HTTP request timeout in seconds + prediction_lookup: Optional PredictionTrieLookup for dynamic header injection Returns: An httpx.AsyncClient configured with Dynamo header injection. """ import httpx - request_hook = _create_dynamo_request_hook(prefix_template, total_requests, osl, iat) + hooks: list[Callable] = [] + + # Add Dynamo prefix hook + prefix_hook = _create_dynamo_request_hook(prefix_template, total_requests, osl, iat) + hooks.append(prefix_hook) + + # Add dynamic prediction hook if lookup provided + if prediction_lookup is not None: + prediction_hook = _create_dynamic_prediction_hook(prediction_lookup) + hooks.append(prediction_hook) return httpx.AsyncClient( - event_hooks={"request": [request_hook]}, + event_hooks={"request": hooks}, timeout=httpx.Timeout(timeout), ) diff --git a/tests/nat/llm/test_dynamic_prediction_hook.py b/tests/nat/llm/test_dynamic_prediction_hook.py index 6abf491c86..cbb2196b9b 100644 --- a/tests/nat/llm/test_dynamic_prediction_hook.py +++ b/tests/nat/llm/test_dynamic_prediction_hook.py @@ -5,6 +5,7 @@ from nat.builder.context import Context from nat.llm.dynamo_llm import _create_dynamic_prediction_hook +from nat.llm.dynamo_llm import create_httpx_client_with_dynamo_hooks from nat.llm.prediction_context import get_call_tracker from nat.profiler.prediction_trie import PredictionTrieLookup from nat.profiler.prediction_trie.data_models import LLMCallPrediction @@ -136,3 +137,35 @@ async def test_dynamic_hook_no_prediction_found(): # Headers should not be injected when no prediction found assert "x-nat-remaining-llm-calls" not in request.headers + + +async def test_client_includes_prediction_hook_when_lookup_provided(sample_trie_lookup): + """Test that client includes prediction hook when trie_lookup is provided.""" + client = create_httpx_client_with_dynamo_hooks( + prefix_template="test-{uuid}", + total_requests=10, + osl="MEDIUM", + iat="LOW", + prediction_lookup=sample_trie_lookup, + ) + + # Should have 2 hooks: dynamo prefix + prediction + assert len(client.event_hooks["request"]) == 2 + + await client.aclose() + + +async def test_client_works_without_prediction_lookup(): + """Test that client works when prediction_lookup is None.""" + client = create_httpx_client_with_dynamo_hooks( + prefix_template="test-{uuid}", + total_requests=10, + osl="MEDIUM", + iat="LOW", + prediction_lookup=None, + ) + + # Should have 1 hook: dynamo prefix only + assert len(client.event_hooks["request"]) == 1 + + await client.aclose() From 3a7e42e19cb313763d4122b865a27befc22eb957 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 09:54:54 -0800 Subject: [PATCH 23/37] Add end-to-end integration test for runtime prediction trie Validates the complete flow from function path tracking through header injection: - function_path_stack updates on push_active_function - IntermediateStepManager increments call tracker on LLM_START - Dynamic hook reads context and looks up predictions - Correct headers injected based on call index Co-Authored-By: Claude Opus 4.5 Signed-off-by: dnandakumar-nv --- tests/nat/llm/test_runtime_prediction_e2e.py | 184 +++++++++++++++++++ 1 file changed, 184 insertions(+) create mode 100644 tests/nat/llm/test_runtime_prediction_e2e.py diff --git a/tests/nat/llm/test_runtime_prediction_e2e.py b/tests/nat/llm/test_runtime_prediction_e2e.py new file mode 100644 index 0000000000..7e5fae54f3 --- /dev/null +++ b/tests/nat/llm/test_runtime_prediction_e2e.py @@ -0,0 +1,184 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""End-to-end test for runtime prediction trie integration. + +This test validates that all pieces work together: +1. function_path_stack gets updated when push_active_function is called +2. IntermediateStepManager increments call tracker on LLM_START +3. Dynamic hook reads context and looks up predictions +4. Correct headers are injected based on call index +""" + +import tempfile +from pathlib import Path + +from nat.builder.context import Context +from nat.data_models.intermediate_step import IntermediateStepPayload +from nat.data_models.intermediate_step import IntermediateStepType +from nat.llm.dynamo_llm import _create_dynamic_prediction_hook +from nat.profiler.prediction_trie import PredictionTrieLookup +from nat.profiler.prediction_trie import load_prediction_trie +from nat.profiler.prediction_trie import save_prediction_trie +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode + + +class MockRequest: + """Mock httpx.Request for testing.""" + + def __init__(self): + self.headers = {} + + +def create_test_trie() -> PredictionTrieNode: + """Create a test trie with known predictions.""" + # Agent at call 1: 2 remaining, 500ms interarrival, 150 tokens + call_1_prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=2.0, p50=2.0, p90=3.0, p95=4.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + + # Agent at call 2: 1 remaining, 300ms interarrival, 100 tokens + call_2_prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=1.0, p50=1.0, p90=2.0, p95=2.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=300.0, p50=280.0, p90=400.0, p95=450.0), + output_tokens=PredictionMetrics(sample_count=10, mean=100.0, p50=90.0, p90=150.0, p95=180.0), + ) + + # Agent at call 3: 0 remaining + call_3_prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=0.0, p50=0.0, p90=0.0, p95=0.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=0.0, p50=0.0, p90=0.0, p95=0.0), + output_tokens=PredictionMetrics(sample_count=10, mean=80.0, p50=75.0, p90=120.0, p95=140.0), + ) + + # Aggregated for fallback + aggregated = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=30, mean=1.0, p50=1.0, p90=2.0, p95=3.0), + interarrival_ms=PredictionMetrics(sample_count=30, mean=400.0, p50=380.0, p90=550.0, p95=600.0), + output_tokens=PredictionMetrics(sample_count=30, mean=110.0, p50=100.0, p90=160.0, p95=190.0), + ) + + agent_node = PredictionTrieNode( + name="react_agent", + predictions_by_call_index={ + 1: call_1_prediction, 2: call_2_prediction, 3: call_3_prediction + }, + predictions_any_index=aggregated, + ) + + workflow_node = PredictionTrieNode( + name="my_workflow", + children={"react_agent": agent_node}, + predictions_any_index=aggregated, + ) + + return PredictionTrieNode( + name="root", + children={"my_workflow": workflow_node}, + predictions_any_index=aggregated, + ) + + +async def test_e2e_prediction_headers_injected_correctly(): + """Test complete flow: context tracking -> step manager -> hook -> headers.""" + # Create and save trie + trie = create_test_trie() + + with tempfile.TemporaryDirectory() as tmpdir: + trie_path = Path(tmpdir) / "prediction_trie.json" + save_prediction_trie(trie, trie_path, workflow_name="test") + + # Load trie + loaded_trie = load_prediction_trie(trie_path) + lookup = PredictionTrieLookup(loaded_trie) + + # Create hook + hook = _create_dynamic_prediction_hook(lookup) + + ctx = Context.get() + state = ctx._context_state + step_manager = ctx.intermediate_step_manager + + # Reset state + state._function_path_stack.set(None) + + with ctx.push_active_function("my_workflow", input_data=None): + with ctx.push_active_function("react_agent", input_data=None): + # Simulate first LLM call + step_manager.push_intermediate_step( + IntermediateStepPayload( + UUID="llm-1", + event_type=IntermediateStepType.LLM_START, + name="test-model", + )) + + request1 = MockRequest() + await hook(request1) + + # Should have call 1 predictions: 2 remaining + assert request1.headers["x-nat-remaining-llm-calls"] == "2" + assert request1.headers["x-nat-interarrival-ms"] == "500" + assert request1.headers["x-nat-expected-output-tokens"] == "200" + + # Simulate second LLM call + step_manager.push_intermediate_step( + IntermediateStepPayload( + UUID="llm-2", + event_type=IntermediateStepType.LLM_START, + name="test-model", + )) + + request2 = MockRequest() + await hook(request2) + + # Should have call 2 predictions: 1 remaining + assert request2.headers["x-nat-remaining-llm-calls"] == "1" + assert request2.headers["x-nat-interarrival-ms"] == "300" + assert request2.headers["x-nat-expected-output-tokens"] == "150" + + # Simulate third LLM call + step_manager.push_intermediate_step( + IntermediateStepPayload( + UUID="llm-3", + event_type=IntermediateStepType.LLM_START, + name="test-model", + )) + + request3 = MockRequest() + await hook(request3) + + # Should have call 3 predictions: 0 remaining + assert request3.headers["x-nat-remaining-llm-calls"] == "0" + assert request3.headers["x-nat-expected-output-tokens"] == "120" + + +async def test_e2e_fallback_to_root(): + """Test that unknown paths fall back to root predictions.""" + trie = create_test_trie() + lookup = PredictionTrieLookup(trie) + hook = _create_dynamic_prediction_hook(lookup) + + ctx = Context.get() + state = ctx._context_state + step_manager = ctx.intermediate_step_manager + + # Reset state + state._function_path_stack.set(None) + + with ctx.push_active_function("unknown_workflow", input_data=None): + step_manager.push_intermediate_step( + IntermediateStepPayload( + UUID="llm-unknown", + event_type=IntermediateStepType.LLM_START, + name="test-model", + )) + + request = MockRequest() + await hook(request) + + # Should fall back to root aggregated predictions + assert "x-nat-remaining-llm-calls" in request.headers + assert request.headers["x-nat-remaining-llm-calls"] == "1" # aggregated mean From 912245a1fc5f14fcd3468687b16ae3a7a1798eaf Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 10:21:26 -0800 Subject: [PATCH 24/37] docs: add prediction trie example config design Design for two-phase Dynamo optimization workflow: - Phase 1: Profile with prediction_trie.enable to build trie - Phase 2: Run with prediction_trie_path for dynamic headers Co-Authored-By: Claude Opus 4.5 --- ...4-prediction-trie-example-config-design.md | 127 ++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 docs/plans/2026-01-24-prediction-trie-example-config-design.md diff --git a/docs/plans/2026-01-24-prediction-trie-example-config-design.md b/docs/plans/2026-01-24-prediction-trie-example-config-design.md new file mode 100644 index 0000000000..74922d487b --- /dev/null +++ b/docs/plans/2026-01-24-prediction-trie-example-config-design.md @@ -0,0 +1,127 @@ +# Prediction Trie Example Config Design + +## Overview + +Create example configs and documentation demonstrating the two-phase Dynamo optimization workflow using prediction trie for dynamic header injection. + +## Two-Phase Workflow + +``` +Phase 1: Profiling +┌─────────────────────────────────────────────────────────────┐ +│ nat eval --config_file profile_rethinking_full_test.yml │ +│ │ │ +│ ▼ │ +│ outputs/rethinking_full_test_for_profiling/ │ +│ └── prediction_trie.json │ +└─────────────────────────────────────────────────────────────┘ + +Phase 2: Run with Predictions +┌─────────────────────────────────────────────────────────────┐ +│ nat eval --config_file run_with_prediction_trie.yml │ +│ │ │ +│ Loads prediction_trie.json │ +│ │ │ +│ Injects dynamic headers per LLM call: │ +│ - x-nat-remaining-llm-calls │ +│ - x-nat-interarrival-ms │ +│ - x-nat-expected-output-tokens │ +└─────────────────────────────────────────────────────────────┘ +``` + +**Key difference from static headers:** Instead of guessing `prefix_total_requests=10`, the trie provides accurate per-call predictions based on function path and call index from profiled data. + +## Deliverables + +### 1. Update: profile_rethinking_full_test.yml + +Add `prediction_trie` section to enable trie building: + +```yaml +profiler: + # ... existing config ... + + # NEW: Build prediction trie from profiled traces + prediction_trie: + enable: true + output_filename: prediction_trie.json +``` + +### 2. New: run_with_prediction_trie.yml + +Config that loads the trie and uses dynamic predictions: + +```yaml +llms: + dynamo_llm: + _type: dynamo + model_name: llama-3.3-70b + base_url: http://localhost:8099/v1 + api_key: dummy + temperature: 0.0 + max_tokens: 8192 + stop: ["Observation:", "\nThought:"] + prefix_template: "react-benchmark-{uuid}" + + # Static headers as fallback + prefix_total_requests: 10 + prefix_osl: MEDIUM + prefix_iat: MEDIUM + + # NEW: Load prediction trie for dynamic per-call headers + prediction_trie_path: ./examples/dynamo_integration/react_benchmark_agent/outputs/dynamo_evals/rethinking_full_test_for_profiling//prediction_trie.json + +eval: + general: + output: + dir: ./examples/dynamo_integration/react_benchmark_agent/outputs/dynamo_evals/prediction_trie_eval/ + + profiler: + compute_llm_metrics: true + csv_exclude_io_text: true +``` + +### 3. New: README_PREDICTION_TRIE.md + +Documentation for the two-phase workflow: + +```markdown +# Prediction Trie Optimization for Dynamo + +## Overview +Use profiled execution data to inject accurate per-call prediction headers +instead of static guesses. + +## Quick Start + +### Phase 1: Build the Prediction Trie +nat eval --config_file configs/profile_rethinking_full_test.yml + +Output: outputs/dynamo_evals/rethinking_full_test_for_profiling//prediction_trie.json + +### Phase 2: Run with Predictions +1. Update prediction_trie_path in run_with_prediction_trie.yml +2. Run: nat eval --config_file configs/run_with_prediction_trie.yml + +## How It Works +- Phase 1 profiles the agent and builds a trie mapping (function_path, call_index) → predictions +- Phase 2 loads the trie and injects headers dynamically based on current execution context + +## Headers Injected +| Header | Source | Description | +|--------|--------|-------------| +| x-nat-remaining-llm-calls | prediction.remaining_calls.mean | Expected remaining calls | +| x-nat-interarrival-ms | prediction.interarrival_ms.mean | Expected time to next call | +| x-nat-expected-output-tokens | prediction.output_tokens.p90 | Expected output tokens | + +## Comparing Results +Run both static and prediction-based configs and compare avg_llm_latency metrics. +``` + +## Files Changed + +| File | Type | Description | +|------|------|-------------| +| `examples/dynamo_integration/react_benchmark_agent/src/react_benchmark_agent/configs/profile_rethinking_full_test.yml` | Modify | Add prediction_trie.enable: true | +| `examples/dynamo_integration/react_benchmark_agent/src/react_benchmark_agent/configs/run_with_prediction_trie.yml` | New | Config using prediction_trie_path | +| `examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md` | New | Documentation for two-phase workflow | From 3c6f65dfb7169a46e3da4f75a54e9fbaee9fe6dd Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 10:27:17 -0800 Subject: [PATCH 25/37] feat(examples): add prediction trie example configs and docs Add two-phase workflow for Dynamo optimization using prediction trie: Phase 1 (profiling): - Enable prediction_trie.enable in profile_rethinking_full_test.yml - Builds trie from profiled execution data Phase 2 (runtime): - New run_with_prediction_trie.yml config - Loads trie and injects dynamic headers per LLM call - Headers: x-nat-remaining-llm-calls, x-nat-interarrival-ms, x-nat-expected-output-tokens Documentation: - README_PREDICTION_TRIE.md with quick start, how it works, configuration reference, and troubleshooting Co-Authored-By: Claude Opus 4.5 --- .../README_PREDICTION_TRIE.md | 155 +++++++++++++ .../configs/profile_rethinking_full_test.yml | 5 + .../configs/run_with_prediction_trie.yml | 216 ++++++++++++++++++ 3 files changed, 376 insertions(+) create mode 100644 examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md create mode 100644 examples/dynamo_integration/react_benchmark_agent/src/react_benchmark_agent/configs/run_with_prediction_trie.yml diff --git a/examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md b/examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md new file mode 100644 index 0000000000..332b3cf8c5 --- /dev/null +++ b/examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md @@ -0,0 +1,155 @@ +# Prediction Trie Optimization for Dynamo + +Use profiled execution data to inject accurate per-call prediction headers instead of static guesses. + +## Overview + +The prediction trie enables **dynamic header injection** for Dynamo's KV-aware routing. Instead of using static values like `prefix_total_requests=10` for every call, the trie provides accurate predictions based on: +- **Function path**: Where in the agent hierarchy the call originates (e.g., `["react_workflow", "react_agent"]`) +- **Call index**: Which LLM call this is within the current function (1st, 2nd, 3rd, etc.) + +This allows Dynamo's Thompson Sampling router to make better worker assignment decisions. + +## Quick Start + +### Phase 1: Build the Prediction Trie + +Run profiling to collect execution data and build the trie: + +```bash +nat eval --config_file configs/profile_rethinking_full_test.yml +``` + +**Output location:** +``` +outputs/dynamo_evals/rethinking_full_test_for_profiling//prediction_trie.json +``` + +### Phase 2: Run with Predictions + +1. **Update the trie path** in `configs/run_with_prediction_trie.yml`: + ```yaml + prediction_trie_path: ./examples/dynamo_integration/react_benchmark_agent/outputs/dynamo_evals/rethinking_full_test_for_profiling//prediction_trie.json + ``` + +2. **Run with dynamic predictions:** + ```bash + nat eval --config_file configs/run_with_prediction_trie.yml + ``` + +## How It Works + +### During Profiling (Phase 1) + +The profiler collects data for each LLM call: +- Function path at time of call +- Call index within the parent function +- Output tokens generated +- Time until the next LLM call (interarrival) +- Remaining LLM calls in the workflow + +This data is aggregated into a trie structure with statistical summaries (mean, p50, p90, etc.) at each node. + +### During Execution (Phase 2) + +For each LLM request: +1. Read the current function path from context +2. Read the call index from the LLM call tracker +3. Look up the prediction in the trie +4. Inject headers into the HTTP request + +### Fallback Chain + +If an exact match isn't found, the trie lookup falls back: +1. Exact path + exact call index (most specific) +2. Exact path + any call index +3. Partial path + exact call index +4. Root aggregated stats (most general) + +This ensures predictions are always available, even for novel execution paths. + +## Headers Injected + +| Header | Source | Description | +|--------|--------|-------------| +| `x-nat-remaining-llm-calls` | `prediction.remaining_calls.mean` | Expected remaining LLM calls in workflow | +| `x-nat-interarrival-ms` | `prediction.interarrival_ms.mean` | Expected milliseconds until next call | +| `x-nat-expected-output-tokens` | `prediction.output_tokens.p90` | Expected output tokens (90th percentile) | + +## Comparing Results + +To measure the impact of prediction trie vs static headers: + +1. **Run with static headers** (baseline): + ```bash + nat eval --config_file configs/eval_config_rethinking_full_test.yml + ``` + +2. **Run with prediction trie**: + ```bash + nat eval --config_file configs/run_with_prediction_trie.yml + ``` + +3. **Compare metrics**: + - `avg_llm_latency`: Lower is better + - `avg_workflow_runtime`: Lower is better + - Look for improvements in KV cache hit rates in Dynamo logs + +## Configuration Reference + +### Profiler Config (Phase 1) + +Enable trie building in the profiler section: + +```yaml +profiler: + prediction_trie: + enable: true + output_filename: prediction_trie.json # default +``` + +### LLM Config (Phase 2) + +Add the trie path to your Dynamo LLM config: + +```yaml +llms: + dynamo_llm: + _type: dynamo + prefix_template: "react-benchmark-{uuid}" + + # Static fallbacks (used if trie lookup fails) + prefix_total_requests: 10 + prefix_osl: MEDIUM + prefix_iat: MEDIUM + + # Dynamic predictions from profiled data + prediction_trie_path: /path/to/prediction_trie.json +``` + +## Troubleshooting + +### "Prediction trie file not found" + +The trie file doesn't exist at the configured path. Check: +- Did Phase 1 profiling complete successfully? +- Is the job_id in the path correct? +- Is the path relative to where you're running the command? + +### "No prediction found for path" + +This is normal - it means the trie is using fallback predictions. The trie will fall back to more general predictions when exact matches aren't found. + +### Headers not being injected + +Ensure: +- `prefix_template` is set (required for Dynamo hooks) +- `prediction_trie_path` points to a valid trie file +- You're using the `dynamo` LLM type + +## Files + +| File | Purpose | +|------|---------| +| `configs/profile_rethinking_full_test.yml` | Phase 1: Profile and build trie | +| `configs/run_with_prediction_trie.yml` | Phase 2: Run with dynamic predictions | diff --git a/examples/dynamo_integration/react_benchmark_agent/src/react_benchmark_agent/configs/profile_rethinking_full_test.yml b/examples/dynamo_integration/react_benchmark_agent/src/react_benchmark_agent/configs/profile_rethinking_full_test.yml index 47b7e243fb..5dc1fb40a3 100644 --- a/examples/dynamo_integration/react_benchmark_agent/src/react_benchmark_agent/configs/profile_rethinking_full_test.yml +++ b/examples/dynamo_integration/react_benchmark_agent/src/react_benchmark_agent/configs/profile_rethinking_full_test.yml @@ -223,6 +223,11 @@ eval: concurrency_spike_analysis: enable: true spike_threshold: 24 # Alert when concurrent functions >= 24 + # Build prediction trie for dynamic Dynamo header injection + # Output: prediction_trie.json in the output directory + # Use with run_with_prediction_trie.yml for optimized routing + prediction_trie: + enable: true evaluators: tool_selection_quality: diff --git a/examples/dynamo_integration/react_benchmark_agent/src/react_benchmark_agent/configs/run_with_prediction_trie.yml b/examples/dynamo_integration/react_benchmark_agent/src/react_benchmark_agent/configs/run_with_prediction_trie.yml new file mode 100644 index 0000000000..b22f114692 --- /dev/null +++ b/examples/dynamo_integration/react_benchmark_agent/src/react_benchmark_agent/configs/run_with_prediction_trie.yml @@ -0,0 +1,216 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# ============================================================================= +# RUN WITH PREDICTION TRIE - DYNAMIC HEADER INJECTION +# ============================================================================= +# Purpose: Use profiled prediction trie for dynamic Dynamo header injection +# +# Prerequisites: +# 1. Run profiling first to build the prediction trie: +# nat eval --config_file configs/profile_rethinking_full_test.yml +# +# 2. Update prediction_trie_path below to point to the generated trie: +# outputs/dynamo_evals/rethinking_full_test_for_profiling//prediction_trie.json +# +# What this does: +# - Loads the prediction trie built from profiled execution data +# - For each LLM call, looks up predictions based on: +# * Current function path (e.g., ["react_workflow", "react_agent"]) +# * Call index within the current function +# - Injects dynamic headers per request: +# * x-nat-remaining-llm-calls: Expected remaining calls +# * x-nat-interarrival-ms: Expected time until next call +# * x-nat-expected-output-tokens: Expected output tokens (p90) +# +# Benefits over static headers: +# - Accurate per-call predictions instead of guessing prefix_total_requests=10 +# - Different predictions for different parts of the agent workflow +# - Dynamo router can make better worker assignment decisions +# +# Usage: +# nat eval --config_file configs/run_with_prediction_trie.yml +# ============================================================================= + +functions: + react_benchmark_agent: + _type: react_benchmark_agent + prefix: "Agent:" + decision_only: true + canned_response_template: "Successfully executed {tool_name}. Operation completed." + + # Define the ReAct workflow + react_workflow: + _type: react_agent + llm_name: dynamo_llm + tool_names: [ + banking_tools + ] + verbose: false # Disable verbose for benchmarking + parse_agent_response_max_retries: 3 + max_tool_calls: 25 + max_history: 1000 + pass_tool_call_errors_to_agent: true + recursion_limit: 50 + system_prompt: | + You are a tool-calling agent evaluated on TOOL SELECTION capability. Your goal is to select ALL the correct tools, in the correct order, to COMPLETELY handle real-world use-cases. + + IMPORTANT: This is a tool selection exercise, NOT real execution. + - Focus on selecting the RIGHT TOOL for each step + - Use placeholder or dummy values for required parameters (e.g., "12345", "user@example.com", "2024-01-01") + - Tool responses are simulated - ignore them and focus on selecting the next appropriate tool + - What matters is YOUR INTENT and TOOL CHOICE, not the data quality + + Available tools: + + {tools} + + Use this exact format for EACH response: + + Thought: I need to analyze what the user needs and select the SINGLE NEXT tool to call. + Action: the ONE tool to call right now, must be one of [{tool_names}] + Action Input: valid JSON with required parameters (use placeholder values) + + CRITICAL RULES: + 1. Output ONLY ONE Thought, Action, and Action Input per response + 2. STOP IMMEDIATELY after writing Action Input + 3. DO NOT write the Observation - the system will provide it + 4. DO NOT write multiple Thought/Action/Action Input cycles in one response + + When you have called ALL necessary tools: + Thought: I have now completed all necessary steps + Final Answer: [summary of what was accomplished] + +function_groups: + banking_tools: + _type: banking_tools_group + # tools.json available after running: /examples/dynamo_integration/scripts/download_agent_leaderboard_v2.py + tools_json_path: ./examples/dynamo_integration/data/raw/banking/tools.json + decision_only: true + include: [ + get_account_balance, + get_transaction_history, + transfer_funds, + get_loan_information, + get_credit_card_information, + get_mortgage_details, + get_savings_account_products, + schedule_appointment, + check_loan_application_status, + find_nearby_locations, + get_investment_products, + report_lost_stolen_card, + update_contact_information, + setup_automatic_bill_pay, + initiate_transaction_dispute, + get_exchange_rates, + calculate_loan_payment, + manage_account_alerts, + check_wire_transfer_status, + get_cd_products + ] + +llms: + # ========================================================================= + # DYNAMO LLM WITH PREDICTION TRIE + # ========================================================================= + # Uses prediction_trie_path to load profiled predictions and inject + # dynamic headers per LLM call based on current execution context. + dynamo_llm: + _type: dynamo + model_name: llama-3.3-70b + base_url: http://localhost:8099/v1 + api_key: dummy + temperature: 0.0 + max_tokens: 8192 + stop: ["Observation:", "\nThought:"] + + # Dynamo prefix configuration (required for prefix routing) + prefix_template: "react-benchmark-{uuid}" + + # Static fallback values (used if trie lookup fails) + prefix_total_requests: 10 + prefix_osl: MEDIUM + prefix_iat: MEDIUM + + # ========================================================================= + # PREDICTION TRIE - Dynamic per-call header injection + # ========================================================================= + # UPDATE THIS PATH to point to your profiled prediction trie: + # 1. Run: nat eval --config_file configs/profile_rethinking_full_test.yml + # 2. Find the job output directory (includes job_id) + # 3. Set path to: //prediction_trie.json + prediction_trie_path: ./examples/dynamo_integration/react_benchmark_agent/outputs/dynamo_evals/rethinking_full_test_for_profiling/REPLACE_WITH_JOB_ID/prediction_trie.json + + # Secondary LLM for self-evaluation (no prediction trie needed) + eval_llm: + _type: dynamo + model_name: llama-3.3-70b + base_url: http://localhost:8099/v1 + api_key: dummy + temperature: 0.0 + max_tokens: 1024 + +# Advanced self-evaluating wrapper with feedback +workflow: + _type: self_evaluating_agent_with_feedback + wrapped_agent: react_workflow + evaluator_llm: eval_llm + max_retries: 3 + min_confidence_threshold: 0.80 + pass_feedback_to_agent: true + verbose: false + +eval: + general: + max_concurrency: 36 + + output: + dir: ./examples/dynamo_integration/react_benchmark_agent/outputs/dynamo_evals/prediction_trie_eval/ + cleanup: false + job_management: + append_job_id_to_output_dir: true + + dataset: + _type: json + file_path: ./examples/dynamo_integration/data/agent_leaderboard_v2_banking.json + structure: + disable: true + + # Lighter profiler config - we're consuming predictions, not building them + profiler: + compute_llm_metrics: true + csv_exclude_io_text: true + # No prediction_trie section - we're using the trie, not building it + + # ========================================================================= + # RUNTIME EVALUATORS - Compare against static header baseline + # ========================================================================= + evaluators: + # Primary metric: Average LLM latency per call (seconds) + avg_llm_latency: + _type: avg_llm_latency + max_concurrency: 36 + + # Secondary metric: Average workflow runtime (seconds) + avg_workflow_runtime: + _type: avg_workflow_runtime + max_concurrency: 36 + + # Tertiary metric: Average number of LLM calls + avg_num_llm_calls: + _type: avg_num_llm_calls + max_concurrency: 36 From 4f09ca360cee5a1d881ab820774f55310718cf69 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 11:45:01 -0800 Subject: [PATCH 26/37] Refactor header injection with dynamic prediction logic Updated x-prefix-* headers to use categorical values (LOW/MEDIUM/HIGH) derived from prediction metrics. Introduced support for loading and handling prediction trie files for dynamic header overrides, ensuring consistent and contextually accurate LLM request annotations. Signed-off-by: dnandakumar-nv --- .../src/nat/plugins/langchain/llm.py | 20 +- .../tests/test_llm_langchain.py | 1 + src/nat/llm/dynamo_llm.py | 118 ++++++++++-- .../builder/test_call_tracker_integration.py | 1 - tests/nat/llm/test_dynamic_prediction_hook.py | 23 ++- tests/nat/llm/test_prediction_context.py | 1 - .../langchain/test_dynamo_trie_loading.py | 179 ++++++++++++++++++ 7 files changed, 311 insertions(+), 32 deletions(-) create mode 100644 tests/nat/plugins/langchain/test_dynamo_trie_loading.py diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py index d3cb632f14..88a8b5f2ed 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py @@ -16,6 +16,7 @@ import logging from collections.abc import Sequence +from pathlib import Path from typing import Any from typing import TypeVar @@ -37,6 +38,8 @@ from nat.llm.utils.thinking import BaseThinkingInjector from nat.llm.utils.thinking import FunctionArgumentWrapper from nat.llm.utils.thinking import patch_with_thinking +from nat.profiler.prediction_trie import PredictionTrieLookup +from nat.profiler.prediction_trie import load_prediction_trie from nat.utils.exception_handlers.automatic_retries import patch_with_retry from nat.utils.responses_api import validate_no_responses_api from nat.utils.type_utils import override @@ -220,6 +223,19 @@ async def dynamo_langchain(llm_config: DynamoModelConfig, _builder: Builder): # Initialize http_async_client to None for proper cleanup http_async_client = None + # Load prediction trie if configured + prediction_lookup: PredictionTrieLookup | None = None + if llm_config.prediction_trie_path: + try: + trie_path = Path(llm_config.prediction_trie_path) + trie = load_prediction_trie(trie_path) + prediction_lookup = PredictionTrieLookup(trie) + logger.info("Loaded prediction trie from %s", llm_config.prediction_trie_path) + except FileNotFoundError: + logger.warning("Prediction trie file not found: %s", llm_config.prediction_trie_path) + except Exception as e: + logger.warning("Failed to load prediction trie: %s", e) + try: # If prefix_template is set, create a custom httpx client with Dynamo hooks if llm_config.prefix_template is not None: @@ -229,14 +245,16 @@ async def dynamo_langchain(llm_config: DynamoModelConfig, _builder: Builder): osl=llm_config.prefix_osl, iat=llm_config.prefix_iat, timeout=llm_config.request_timeout, + prediction_lookup=prediction_lookup, ) config_dict["http_async_client"] = http_async_client logger.info( - "Dynamo prefix headers enabled: template=%s, total_requests=%d, osl=%s, iat=%s", + "Dynamo prefix headers enabled: template=%s, total_requests=%d, osl=%s, iat=%s, prediction_trie=%s", llm_config.prefix_template, llm_config.prefix_total_requests, llm_config.prefix_osl, llm_config.prefix_iat, + "loaded" if prediction_lookup else "disabled", ) # Create the ChatOpenAI client diff --git a/packages/nvidia_nat_langchain/tests/test_llm_langchain.py b/packages/nvidia_nat_langchain/tests/test_llm_langchain.py index 98e60932a0..e5a4f18741 100644 --- a/packages/nvidia_nat_langchain/tests/test_llm_langchain.py +++ b/packages/nvidia_nat_langchain/tests/test_llm_langchain.py @@ -241,6 +241,7 @@ async def test_creation_with_prefix_template(self, osl="HIGH", iat="LOW", timeout=300.0, + prediction_lookup=None, ) # Verify ChatOpenAI was called with the custom httpx client diff --git a/src/nat/llm/dynamo_llm.py b/src/nat/llm/dynamo_llm.py index 6931a958ca..50916fcf1c 100644 --- a/src/nat/llm/dynamo_llm.py +++ b/src/nat/llm/dynamo_llm.py @@ -65,6 +65,8 @@ from pydantic import Field from nat.builder.builder import Builder +from nat.builder.context import Context +from nat.builder.context import Singleton from nat.builder.llm import LLMProviderInfo from nat.cli.register_workflow import register_llm_provider from nat.data_models.optimizable import OptimizableField @@ -77,12 +79,49 @@ # Define valid prefix hint values PrefixLevel = Literal["LOW", "MEDIUM", "HIGH"] +# ============================================================================= +# CATEGORY CONVERSION HELPERS +# ============================================================================= + + +def _output_tokens_to_osl(output_tokens: float) -> PrefixLevel: + """ + Convert predicted output tokens to OSL category. + + Thresholds: + - < 256 tokens: LOW (short responses) + - < 1024 tokens: MEDIUM (typical responses) + - >= 1024 tokens: HIGH (long responses) + """ + if output_tokens < 256: + return "LOW" + if output_tokens < 1024: + return "MEDIUM" + return "HIGH" + + +def _interarrival_ms_to_iat(interarrival_ms: float) -> PrefixLevel: + """ + Convert predicted interarrival time to IAT category. + + Thresholds: + - < 100ms: LOW (rapid bursts, high worker stickiness) + - < 500ms: MEDIUM (normal pacing) + - >= 500ms: HIGH (slow requests, more exploration) + """ + if interarrival_ms < 100: + return "LOW" + if interarrival_ms < 500: + return "MEDIUM" + return "HIGH" + + # ============================================================================= # CONTEXT MANAGEMENT FOR DYNAMO PREFIX ID # ============================================================================= -class DynamoPrefixContext: +class DynamoPrefixContext(metaclass=Singleton): """ Singleton class for managing Dynamo prefix IDs across LLM calls. @@ -133,6 +172,11 @@ def get(cls) -> str | None: """Get the current Dynamo prefix ID from context, if any.""" return cls._current_prefix_id.get() + @classmethod + def is_set(cls) -> bool: + """Check if a Dynamo prefix ID is currently set in context.""" + return cls.get() is not None + @classmethod @contextmanager def scope(cls, prefix_id: str) -> Iterator[None]: @@ -299,6 +343,17 @@ def _create_dynamo_request_hook( async def on_request(request): """Inject Dynamo prefix headers before each request.""" # Check context variable first (allows per-question override in batch evaluation) + + if not DynamoPrefixContext.is_set(): + if not Context.workflow_run_id: + logger.warning("No workflow_run_id in context; using unique prefix ID.") + import uuid + prefix = str(uuid.uuid4().hex[:16]) + else: + prefix = Context.workflow_run_id + + DynamoPrefixContext.set(prefix) + context_prefix_id = DynamoPrefixContext.get() if context_prefix_id: @@ -371,7 +426,10 @@ def create_httpx_client_with_dynamo_hooks( def _create_prediction_request_hook( prediction: LLMCallPrediction, ) -> Callable[["httpx.Request"], Coroutine[Any, Any, None]]: """ - Create an httpx event hook that injects prediction headers. + Create an httpx event hook that overrides x-prefix-* headers from static prediction data. + + This hook converts numeric prediction values to categorical values (LOW/MEDIUM/HIGH) + and overrides the x-prefix-* headers set by the Dynamo prefix hook. Args: prediction: The prediction data to inject @@ -379,18 +437,22 @@ def _create_prediction_request_hook( Returns: An async function suitable for use as an httpx event hook. """ + # Pre-compute categorical values from prediction + total_requests = int(prediction.remaining_calls.mean) + osl = _output_tokens_to_osl(prediction.output_tokens.p90) + iat = _interarrival_ms_to_iat(prediction.interarrival_ms.mean) async def on_request(request): - """Inject prediction headers before each request.""" - request.headers["x-nat-remaining-llm-calls"] = str(int(prediction.remaining_calls.mean)) - request.headers["x-nat-interarrival-ms"] = str(int(prediction.interarrival_ms.mean)) - request.headers["x-nat-expected-output-tokens"] = str(int(prediction.output_tokens.p90)) + """Override x-prefix-* headers with prediction-derived values.""" + request.headers["x-prefix-total-requests"] = str(total_requests) + request.headers["x-prefix-osl"] = osl + request.headers["x-prefix-iat"] = iat logger.debug( - "Injected prediction headers: remaining=%d, interarrival=%d, output_tokens=%d", - int(prediction.remaining_calls.mean), - int(prediction.interarrival_ms.mean), - int(prediction.output_tokens.p90), + "Overrode prefix headers from static prediction: total_requests=%d, osl=%s, iat=%s", + total_requests, + osl, + iat, ) return on_request @@ -402,7 +464,15 @@ def _create_dynamic_prediction_hook( Create an httpx event hook that dynamically looks up predictions per request. This hook reads the current function path and call index from context, - looks up the prediction in the trie, and injects headers. + looks up the prediction in the trie, and overrides the x-prefix-* headers + with values derived from the prediction. The numeric prediction values + are converted to categorical values (LOW/MEDIUM/HIGH) for consistency + with static configuration. + + When a prediction is found, this hook overrides: + - x-prefix-total-requests: from remaining_calls.mean + - x-prefix-osl: converted from output_tokens.p90 + - x-prefix-iat: converted from interarrival_ms.mean Args: trie_lookup: The PredictionTrieLookup instance to query @@ -412,7 +482,7 @@ def _create_dynamic_prediction_hook( """ async def on_request(request: "httpx.Request") -> None: - """Look up prediction from context and inject headers.""" + """Look up prediction from context and override x-prefix-* headers.""" from nat.builder.context import Context from nat.llm.prediction_context import get_call_tracker @@ -431,24 +501,32 @@ async def on_request(request: "httpx.Request") -> None: prediction = trie_lookup.find(path, call_index) if prediction: - request.headers["x-nat-remaining-llm-calls"] = str(int(prediction.remaining_calls.mean)) - request.headers["x-nat-interarrival-ms"] = str(int(prediction.interarrival_ms.mean)) - request.headers["x-nat-expected-output-tokens"] = str(int(prediction.output_tokens.p90)) + # Convert numeric predictions to categorical values and override headers + total_requests = int(prediction.remaining_calls.mean) + osl = _output_tokens_to_osl(prediction.output_tokens.p90) + iat = _interarrival_ms_to_iat(prediction.interarrival_ms.mean) + + request.headers["x-prefix-total-requests"] = str(total_requests) + request.headers["x-prefix-osl"] = osl + request.headers["x-prefix-iat"] = iat logger.debug( - "Injected prediction headers: path=%s, call_index=%d, remaining=%d, interarrival=%d, output=%d", + "Overrode prefix headers from prediction: path=%s, call_index=%d, " + "total_requests=%d, osl=%s (tokens=%d), iat=%s (ms=%d)", path, call_index, - int(prediction.remaining_calls.mean), - int(prediction.interarrival_ms.mean), + total_requests, + osl, int(prediction.output_tokens.p90), + iat, + int(prediction.interarrival_ms.mean), ) else: - logger.debug("No prediction found for path=%s, call_index=%d", path, call_index) + logger.debug("No prediction found for path=%s, call_index=%d; using static values", path, call_index) except Exception as e: # Don't fail the request if prediction lookup fails - logger.warning("Failed to inject prediction headers: %s", e) + logger.warning("Failed to override prefix headers from prediction: %s", e) return on_request diff --git a/tests/nat/builder/test_call_tracker_integration.py b/tests/nat/builder/test_call_tracker_integration.py index ec06ae5521..dc1695f91b 100644 --- a/tests/nat/builder/test_call_tracker_integration.py +++ b/tests/nat/builder/test_call_tracker_integration.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 - from nat.builder.context import Context from nat.data_models.intermediate_step import IntermediateStepPayload from nat.data_models.intermediate_step import IntermediateStepType diff --git a/tests/nat/llm/test_dynamic_prediction_hook.py b/tests/nat/llm/test_dynamic_prediction_hook.py index cbb2196b9b..1bab5ab641 100644 --- a/tests/nat/llm/test_dynamic_prediction_hook.py +++ b/tests/nat/llm/test_dynamic_prediction_hook.py @@ -53,7 +53,7 @@ def __init__(self): async def test_dynamic_hook_injects_headers(sample_trie_lookup): - """Test that dynamic hook injects prediction headers based on context.""" + """Test that dynamic hook overrides x-prefix-* headers based on context predictions.""" ctx = Context.get() state = ctx._context_state @@ -71,10 +71,14 @@ async def test_dynamic_hook_injects_headers(sample_trie_lookup): request = MockRequest() await hook(request) - assert "x-nat-remaining-llm-calls" in request.headers - assert request.headers["x-nat-remaining-llm-calls"] == "3" - assert request.headers["x-nat-interarrival-ms"] == "500" - assert request.headers["x-nat-expected-output-tokens"] == "200" + # Prediction values are converted to x-prefix-* headers: + # - remaining_calls.mean=3.0 -> x-prefix-total-requests="3" + # - output_tokens.p90=200.0 -> x-prefix-osl="LOW" (< 256) + # - interarrival_ms.mean=500.0 -> x-prefix-iat="HIGH" (>= 500) + assert "x-prefix-total-requests" in request.headers + assert request.headers["x-prefix-total-requests"] == "3" + assert request.headers["x-prefix-osl"] == "LOW" + assert request.headers["x-prefix-iat"] == "HIGH" async def test_dynamic_hook_uses_root_fallback(sample_trie_lookup): @@ -95,7 +99,7 @@ async def test_dynamic_hook_uses_root_fallback(sample_trie_lookup): await hook(request) # Should fall back to root aggregated predictions - assert "x-nat-remaining-llm-calls" in request.headers + assert "x-prefix-total-requests" in request.headers async def test_dynamic_hook_handles_empty_context(sample_trie_lookup): @@ -114,7 +118,7 @@ async def test_dynamic_hook_handles_empty_context(sample_trie_lookup): await hook(request) # Should still inject headers from root fallback - assert "x-nat-remaining-llm-calls" in request.headers + assert "x-prefix-total-requests" in request.headers async def test_dynamic_hook_no_prediction_found(): @@ -135,8 +139,9 @@ async def test_dynamic_hook_no_prediction_found(): request = MockRequest() await hook(request) - # Headers should not be injected when no prediction found - assert "x-nat-remaining-llm-calls" not in request.headers + # Headers should not be overridden when no prediction found + # (the static Dynamo hook would set them, but this hook runs after) + assert "x-prefix-total-requests" not in request.headers async def test_client_includes_prediction_hook_when_lookup_provided(sample_trie_lookup): diff --git a/tests/nat/llm/test_prediction_context.py b/tests/nat/llm/test_prediction_context.py index d46166d7d0..b0132afc35 100644 --- a/tests/nat/llm/test_prediction_context.py +++ b/tests/nat/llm/test_prediction_context.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 - from nat.llm.prediction_context import LLMCallTracker from nat.llm.prediction_context import get_call_tracker diff --git a/tests/nat/plugins/langchain/test_dynamo_trie_loading.py b/tests/nat/plugins/langchain/test_dynamo_trie_loading.py new file mode 100644 index 0000000000..2156f68d51 --- /dev/null +++ b/tests/nat/plugins/langchain/test_dynamo_trie_loading.py @@ -0,0 +1,179 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import tempfile +from pathlib import Path +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from nat.builder.builder import Builder +from nat.llm.dynamo_llm import DynamoModelConfig +from nat.plugins.langchain.llm import dynamo_langchain +from nat.profiler.prediction_trie import PredictionTrieLookup +from nat.profiler.prediction_trie import save_prediction_trie +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode + + +@pytest.fixture(name="trie_file") +def fixture_trie_file(): + """Create a temporary trie file.""" + prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + + root = PredictionTrieNode( + name="root", + predictions_by_call_index={1: prediction}, + predictions_any_index=prediction, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "prediction_trie.json" + save_prediction_trie(root, path, workflow_name="test") + yield str(path) + + +@pytest.fixture(name="mock_builder") +def fixture_mock_builder(): + """Create a mock builder.""" + return MagicMock(spec=Builder) + + +def test_dynamo_config_with_valid_trie_path(trie_file): + """Test that DynamoModelConfig can be created with valid trie path.""" + config = DynamoModelConfig( + base_url="http://localhost:8000/v1", + model_name="test-model", + api_key="test-key", + prediction_trie_path=trie_file, + ) + + assert config.prediction_trie_path == trie_file + + +def test_dynamo_config_with_nonexistent_trie_path(): + """Test that DynamoModelConfig accepts nonexistent path (validated at load time).""" + config = DynamoModelConfig( + base_url="http://localhost:8000/v1", + model_name="test-model", + api_key="test-key", + prediction_trie_path="/nonexistent/path/trie.json", + ) + + # Config creation should succeed; error happens at runtime + assert config.prediction_trie_path == "/nonexistent/path/trie.json" + + +@patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") +@patch("langchain_openai.ChatOpenAI") +async def test_dynamo_langchain_loads_trie_and_passes_to_client(mock_chat, mock_create_client, trie_file, mock_builder): + """Test that dynamo_langchain loads trie from path and passes PredictionTrieLookup to httpx client.""" + mock_httpx_client = MagicMock() + mock_httpx_client.aclose = AsyncMock() + mock_create_client.return_value = mock_httpx_client + + config = DynamoModelConfig( + base_url="http://localhost:8000/v1", + model_name="test-model", + api_key="test-key", + prefix_template="test-{uuid}", + prediction_trie_path=trie_file, + ) + + async with dynamo_langchain(config, mock_builder): + # Verify httpx client was created with prediction_lookup + mock_create_client.assert_called_once() + call_kwargs = mock_create_client.call_args.kwargs + assert "prediction_lookup" in call_kwargs + assert isinstance(call_kwargs["prediction_lookup"], PredictionTrieLookup) + + mock_httpx_client.aclose.assert_awaited_once() + + +@patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") +@patch("langchain_openai.ChatOpenAI") +async def test_dynamo_langchain_handles_nonexistent_trie_gracefully(mock_chat, mock_create_client, mock_builder): + """Test that dynamo_langchain logs warning and continues when trie file doesn't exist.""" + mock_httpx_client = MagicMock() + mock_httpx_client.aclose = AsyncMock() + mock_create_client.return_value = mock_httpx_client + + config = DynamoModelConfig( + base_url="http://localhost:8000/v1", + model_name="test-model", + api_key="test-key", + prefix_template="test-{uuid}", + prediction_trie_path="/nonexistent/path/trie.json", + ) + + # Should not raise an exception + async with dynamo_langchain(config, mock_builder): + # Verify httpx client was created with prediction_lookup=None + mock_create_client.assert_called_once() + call_kwargs = mock_create_client.call_args.kwargs + assert call_kwargs["prediction_lookup"] is None + + mock_httpx_client.aclose.assert_awaited_once() + + +@patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") +@patch("langchain_openai.ChatOpenAI") +async def test_dynamo_langchain_no_trie_path_means_no_lookup(mock_chat, mock_create_client, mock_builder): + """Test that dynamo_langchain passes None when no trie path is configured.""" + mock_httpx_client = MagicMock() + mock_httpx_client.aclose = AsyncMock() + mock_create_client.return_value = mock_httpx_client + + config = DynamoModelConfig( + base_url="http://localhost:8000/v1", + model_name="test-model", + api_key="test-key", + prefix_template="test-{uuid}", # prediction_trie_path is None by default + ) + + async with dynamo_langchain(config, mock_builder): + mock_create_client.assert_called_once() + call_kwargs = mock_create_client.call_args.kwargs + assert call_kwargs["prediction_lookup"] is None + + mock_httpx_client.aclose.assert_awaited_once() + + +@patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") +@patch("langchain_openai.ChatOpenAI") +async def test_dynamo_langchain_handles_invalid_trie_file_gracefully(mock_chat, mock_create_client, mock_builder): + """Test that dynamo_langchain logs warning and continues when trie file is invalid JSON.""" + mock_httpx_client = MagicMock() + mock_httpx_client.aclose = AsyncMock() + mock_create_client.return_value = mock_httpx_client + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + f.write("not valid json {{{") + invalid_trie_path = f.name + + try: + config = DynamoModelConfig( + base_url="http://localhost:8000/v1", + model_name="test-model", + api_key="test-key", + prefix_template="test-{uuid}", + prediction_trie_path=invalid_trie_path, + ) + + # Should not raise an exception + async with dynamo_langchain(config, mock_builder): + # Verify httpx client was created with prediction_lookup=None + mock_create_client.assert_called_once() + call_kwargs = mock_create_client.call_args.kwargs + assert call_kwargs["prediction_lookup"] is None + + mock_httpx_client.aclose.assert_awaited_once() + finally: + Path(invalid_trie_path).unlink(missing_ok=True) From a1ab80c5b1fde7bb010dd8220e1a1a37abe19892 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 12:14:26 -0800 Subject: [PATCH 27/37] Refactor Dynamo prefix handling to centralize logic. Moved logic for setting Dynamo prefix ID into `DynamoPrefixContext.get` for better reusability and clarity. Removed redundant code from request header injection, ensuring consistent prefix generation and logging behavior. Signed-off-by: dnandakumar-nv --- src/nat/llm/dynamo_llm.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/src/nat/llm/dynamo_llm.py b/src/nat/llm/dynamo_llm.py index 50916fcf1c..3ecacdfb2a 100644 --- a/src/nat/llm/dynamo_llm.py +++ b/src/nat/llm/dynamo_llm.py @@ -170,7 +170,23 @@ def clear(cls) -> None: @classmethod def get(cls) -> str | None: """Get the current Dynamo prefix ID from context, if any.""" - return cls._current_prefix_id.get() + cur_prefix = cls._current_prefix_id.get() + + if not cur_prefix: + + import uuid + + from nat.builder.context import Context + logger.debug("No Dynamo prefix ID set in context") + if not Context.workflow_run_id: + logger.warning("No workflow_run_id in context; using unique prefix ID.") + prefix = str(uuid.uuid4().hex[:16]) + else: + prefix = Context.workflow_run_id + cls.set(prefix) + return prefix + + return cur_prefix @classmethod def is_set(cls) -> bool: @@ -344,16 +360,6 @@ async def on_request(request): """Inject Dynamo prefix headers before each request.""" # Check context variable first (allows per-question override in batch evaluation) - if not DynamoPrefixContext.is_set(): - if not Context.workflow_run_id: - logger.warning("No workflow_run_id in context; using unique prefix ID.") - import uuid - prefix = str(uuid.uuid4().hex[:16]) - else: - prefix = Context.workflow_run_id - - DynamoPrefixContext.set(prefix) - context_prefix_id = DynamoPrefixContext.get() if context_prefix_id: @@ -483,7 +489,6 @@ def _create_dynamic_prediction_hook( async def on_request(request: "httpx.Request") -> None: """Look up prediction from context and override x-prefix-* headers.""" - from nat.builder.context import Context from nat.llm.prediction_context import get_call_tracker try: From 2fc6a09b7bb6a977c93023746c6285ca106dc706 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 13:04:13 -0800 Subject: [PATCH 28/37] Refactor DynamoPrefixContext for depth-aware prefix handling Introduced depth-aware prefix ID generation for more granular control of prefix IDs across nested function calls. Replaced the previous context variable approach with a depth mapping mechanism and added support for override prefixes. Updated relevant tests for clarity and alignment with the new depth-based behavior. Signed-off-by: dnandakumar-nv --- src/nat/llm/dynamo_llm.py | 164 ++++++++++++++++++------------- tests/nat/llm/test_dynamo_llm.py | 114 ++++++++++++--------- 2 files changed, 164 insertions(+), 114 deletions(-) diff --git a/src/nat/llm/dynamo_llm.py b/src/nat/llm/dynamo_llm.py index 3ecacdfb2a..878d03623c 100644 --- a/src/nat/llm/dynamo_llm.py +++ b/src/nat/llm/dynamo_llm.py @@ -125,85 +125,124 @@ class DynamoPrefixContext(metaclass=Singleton): """ Singleton class for managing Dynamo prefix IDs across LLM calls. - This allows evaluation code to set a prefix ID that persists across all LLM - calls for a single evaluation question (multi-turn conversation). + Prefix IDs are unique per depth level in the function call stack, allowing + different caching behavior at different levels of nested function calls. + Each depth level gets its own prefix ID that remains constant within a + single workflow run but changes between runs. + + The prefix ID format is: ``{workflow_run_id}-d{depth}`` Usage:: from nat.llm.dynamo_llm import DynamoPrefixContext - # Set prefix ID at the start of each evaluation question - DynamoPrefixContext.set("eval-q001-abc123") - - # ... perform LLM calls ... - - # Clear when done - DynamoPrefixContext.clear() + # Automatically gets prefix ID based on current call stack depth + prefix_id = DynamoPrefixContext.get() - # Or use as a context manager + # Or use as a context manager for explicit control with DynamoPrefixContext.scope("eval-q001-abc123"): - # ... perform LLM calls ... + # All LLM calls here will use "eval-q001-abc123" prefix + ... """ - _current_prefix_id: ContextVar[str | None] = ContextVar('dynamo_prefix_id', default=None) + # Maps depth -> prefix_id for the current workflow run + _prefix_ids_by_depth: ContextVar[dict[int, str] | None] = ContextVar('dynamo_prefix_ids_by_depth', default=None) + # Optional override that takes precedence over depth-based IDs + _override_prefix_id: ContextVar[str | None] = ContextVar('dynamo_override_prefix_id', default=None) + + @classmethod + def _get_current_depth(cls) -> int: + """Get the current function call stack depth from Context.""" + try: + ctx = Context.get() + return len(ctx.function_path) + except Exception: + return 0 + + @classmethod + def _get_or_create_depth_map(cls) -> dict[int, str]: + """Get or create the depth -> prefix_id mapping for this context.""" + depth_map = cls._prefix_ids_by_depth.get() + if depth_map is None: + depth_map = {} + cls._prefix_ids_by_depth.set(depth_map) + return depth_map @classmethod def set(cls, prefix_id: str) -> None: """ - Set the Dynamo prefix ID for the current context. + Set an override prefix ID that takes precedence over depth-based IDs. - Call this at the start of each evaluation question to ensure all LLM calls - for that question share the same prefix ID (enabling KV cache reuse). + Use this when you need explicit control over the prefix ID, such as + during batch evaluation where each question should have a specific ID. Args: - prefix_id: The unique prefix ID (e.g., "eval-q001-abc123") + prefix_id: The prefix ID to use (overrides depth-based generation) """ - cls._current_prefix_id.set(prefix_id) - logger.debug("Set Dynamo prefix ID: %s", prefix_id) + cls._override_prefix_id.set(prefix_id) + logger.debug("Set override Dynamo prefix ID: %s", prefix_id) @classmethod def clear(cls) -> None: - """Clear the current Dynamo prefix ID context.""" - cls._current_prefix_id.set(None) - logger.debug("Cleared Dynamo prefix ID") + """Clear all prefix ID state (both override and depth-based).""" + cls._override_prefix_id.set(None) + cls._prefix_ids_by_depth.set(None) + logger.debug("Cleared Dynamo prefix ID context") @classmethod - def get(cls) -> str | None: - """Get the current Dynamo prefix ID from context, if any.""" - cur_prefix = cls._current_prefix_id.get() - - if not cur_prefix: + def get(cls) -> str: + """ + Get the Dynamo prefix ID for the current context. - import uuid + Returns the override prefix ID if set, otherwise returns a depth-based + prefix ID that is unique per workflow run and call stack depth. - from nat.builder.context import Context - logger.debug("No Dynamo prefix ID set in context") - if not Context.workflow_run_id: + Returns: + The prefix ID string, never None. + """ + # Check for override first + override = cls._override_prefix_id.get() + if override: + return override + + # Get depth-based prefix ID + depth = cls._get_current_depth() + depth_map = cls._get_or_create_depth_map() + + if depth not in depth_map: + # Generate new prefix ID for this depth + try: + ctx = Context.get() + workflow_id = ctx.workflow_run_id + except Exception: + workflow_id = None + + if not workflow_id: logger.warning("No workflow_run_id in context; using unique prefix ID.") - prefix = str(uuid.uuid4().hex[:16]) - else: - prefix = Context.workflow_run_id - cls.set(prefix) - return prefix + workflow_id = uuid.uuid4().hex[:16] + + prefix_id = f"{workflow_id}-d{depth}" + depth_map[depth] = prefix_id + logger.debug("Generated Dynamo prefix ID for depth %d: %s", depth, prefix_id) - return cur_prefix + return depth_map[depth] @classmethod def is_set(cls) -> bool: - """Check if a Dynamo prefix ID is currently set in context.""" - return cls.get() is not None + """Check if a Dynamo prefix ID is available (always True, IDs are auto-generated).""" + return True @classmethod @contextmanager def scope(cls, prefix_id: str) -> Iterator[None]: """ - Context manager for scoped prefix ID usage. + Context manager for scoped override prefix ID usage. - Automatically sets the prefix ID on entry and clears it on exit, - ensuring proper cleanup even if exceptions occur. + Sets an override prefix ID on entry and restores the previous state on exit, + ensuring proper cleanup even if exceptions occur. Supports nesting. Args: - prefix_id: The unique prefix ID for this scope + prefix_id: The override prefix ID for this scope Yields: None @@ -213,11 +252,12 @@ def scope(cls, prefix_id: str) -> Iterator[None]: # All LLM calls here will use "eval-q001" prefix await llm.ainvoke(...) """ + previous_override = cls._override_prefix_id.get() cls.set(prefix_id) try: yield finally: - cls.clear() + cls._override_prefix_id.set(previous_override) # ============================================================================= @@ -330,15 +370,13 @@ def _create_dynamo_request_hook( Create an httpx event hook that injects Dynamo prefix headers into requests. This hook is called before each HTTP request is sent, allowing us to inject - headers dynamically. The prefix ID is generated ONCE when the hook is created, - ensuring all requests from the same client share the same prefix ID. This enables - Dynamo's KV cache optimization across multi-turn conversations. - - The context variable can override this for scenarios where you need different - prefix IDs (e.g., per-question in batch evaluation). + headers dynamically. The prefix ID is obtained from DynamoPrefixContext which + provides depth-aware prefix IDs - each level in the function call stack gets + its own unique prefix ID that remains constant within a workflow run. Args: - prefix_template: Template string with {uuid} placeholder + prefix_template: Template string with {uuid} placeholder (currently unused, + kept for API compatibility) total_requests: Expected number of requests for this prefix osl: Output sequence length hint (LOW/MEDIUM/HIGH) iat: Inter-arrival time hint (LOW/MEDIUM/HIGH) @@ -346,29 +384,15 @@ def _create_dynamo_request_hook( Returns: An async function suitable for use as an httpx event hook. """ - # Generate the default prefix ID ONCE when the hook is created - # This ensures all requests from this client share the same prefix ID - unique_id = uuid.uuid4().hex[:16] - if prefix_template: - default_prefix_id = prefix_template.format(uuid=unique_id) - else: - default_prefix_id = f"nat-dynamo-{unique_id}" - - logger.debug("Created Dynamo request hook with default prefix ID: %s", default_prefix_id) + # Note: prefix_template is kept for API compatibility but no longer used. + # Prefix IDs are now managed by DynamoPrefixContext with depth-awareness. + _ = prefix_template # Suppress unused parameter warning async def on_request(request): """Inject Dynamo prefix headers before each request.""" - # Check context variable first (allows per-question override in batch evaluation) - - context_prefix_id = DynamoPrefixContext.get() - - if context_prefix_id: - prefix_id = context_prefix_id - logger.debug("Using context prefix ID: %s", prefix_id) - else: - # Use the pre-generated prefix ID (same for all requests from this client) - prefix_id = default_prefix_id - logger.debug("Using default prefix ID: %s", prefix_id) + # Get depth-aware prefix ID from context + prefix_id = DynamoPrefixContext.get() + logger.debug("Using depth-aware prefix ID: %s", prefix_id) # Inject Dynamo headers request.headers["x-prefix-id"] = prefix_id diff --git a/tests/nat/llm/test_dynamo_llm.py b/tests/nat/llm/test_dynamo_llm.py index d73203fdcf..2b6ba50650 100644 --- a/tests/nat/llm/test_dynamo_llm.py +++ b/tests/nat/llm/test_dynamo_llm.py @@ -149,26 +149,37 @@ def test_get_dynamo_field_names(self): class TestDynamoPrefixContext: """Tests for DynamoPrefixContext singleton class.""" - def test_set_and_get_prefix_id(self): - """Test setting and getting prefix ID.""" - # Ensure clean state + def test_auto_generates_depth_based_prefix(self): + """Test that get() auto-generates a depth-based prefix when no override is set.""" DynamoPrefixContext.clear() - assert DynamoPrefixContext.get() is None - # Set and get + # get() always returns a value - auto-generated if no override + prefix = DynamoPrefixContext.get() + assert prefix is not None + assert "-d0" in prefix # Depth 0 at root level + + def test_set_and_get_override_prefix_id(self): + """Test setting and getting an override prefix ID.""" + DynamoPrefixContext.clear() + + # Set override DynamoPrefixContext.set("test-prefix-123") assert DynamoPrefixContext.get() == "test-prefix-123" # Clean up DynamoPrefixContext.clear() - def test_clear_prefix_id(self): - """Test clearing prefix ID.""" + def test_clear_removes_override_but_auto_generates(self): + """Test that clear() removes override but get() still returns auto-generated value.""" DynamoPrefixContext.set("test-prefix-456") assert DynamoPrefixContext.get() == "test-prefix-456" DynamoPrefixContext.clear() - assert DynamoPrefixContext.get() is None + # After clear, get() returns auto-generated depth-based prefix + prefix = DynamoPrefixContext.get() + assert prefix is not None + assert prefix != "test-prefix-456" + assert "-d0" in prefix def test_overwrite_prefix_id(self): """Test that setting a new prefix ID overwrites the old one.""" @@ -183,18 +194,19 @@ def test_overwrite_prefix_id(self): DynamoPrefixContext.clear() def test_scope_context_manager(self): - """Test the scope context manager for automatic cleanup.""" + """Test the scope context manager with override prefix.""" DynamoPrefixContext.clear() - assert DynamoPrefixContext.get() is None with DynamoPrefixContext.scope("scoped-prefix-789"): assert DynamoPrefixContext.get() == "scoped-prefix-789" - # Should be cleared after exiting context - assert DynamoPrefixContext.get() is None + # After exiting scope, returns to auto-generated + prefix = DynamoPrefixContext.get() + assert prefix != "scoped-prefix-789" + assert "-d0" in prefix def test_scope_context_manager_cleanup_on_exception(self): - """Test that scope context manager clears prefix ID even on exception.""" + """Test that scope context manager restores state even on exception.""" DynamoPrefixContext.clear() with pytest.raises(ValueError): @@ -202,22 +214,31 @@ def test_scope_context_manager_cleanup_on_exception(self): assert DynamoPrefixContext.get() == "error-prefix" raise ValueError("Test exception") - # Should still be cleared after exception - assert DynamoPrefixContext.get() is None + # After exception, returns to auto-generated + prefix = DynamoPrefixContext.get() + assert prefix != "error-prefix" + assert "-d0" in prefix - def test_scope_nested_replaces_then_clears(self): - """Test that nested scopes work but outer scope is lost after inner exits.""" + def test_scope_nested_restores_outer(self): + """Test that nested scopes properly restore outer scope value.""" DynamoPrefixContext.clear() with DynamoPrefixContext.scope("outer"): assert DynamoPrefixContext.get() == "outer" with DynamoPrefixContext.scope("inner"): assert DynamoPrefixContext.get() == "inner" - # After inner scope exits, it clears - outer value is lost - assert DynamoPrefixContext.get() is None + # After inner scope exits, outer value is restored + assert DynamoPrefixContext.get() == "outer" - # Still None after outer exits - assert DynamoPrefixContext.get() is None + # After outer scope exits, returns to auto-generated + prefix = DynamoPrefixContext.get() + assert prefix != "outer" + assert "-d0" in prefix + + def test_is_set_always_true(self): + """Test that is_set() always returns True since IDs are auto-generated.""" + DynamoPrefixContext.clear() + assert DynamoPrefixContext.is_set() is True # --------------------------------------------------------------------------- @@ -251,15 +272,16 @@ async def test_hook_injects_headers(self): await hook(mock_request) + # Prefix ID comes from DynamoPrefixContext (depth-based) assert "x-prefix-id" in mock_request.headers - assert mock_request.headers["x-prefix-id"].startswith("test-") + assert "-d0" in mock_request.headers["x-prefix-id"] # Depth 0 assert mock_request.headers["x-prefix-total-requests"] == "15" assert mock_request.headers["x-prefix-osl"] == "HIGH" assert mock_request.headers["x-prefix-iat"] == "LOW" @pytest.mark.asyncio async def test_hook_uses_context_prefix_id(self): - """Test that the hook uses context variable prefix ID when set.""" + """Test that the hook uses context override prefix ID when set.""" hook = _create_dynamo_request_hook( prefix_template="template-{uuid}", total_requests=10, @@ -267,7 +289,7 @@ async def test_hook_uses_context_prefix_id(self): iat="MEDIUM", ) - # Set context prefix ID + # Set context override prefix ID DynamoPrefixContext.set("context-prefix-abc") mock_request = MagicMock() @@ -275,16 +297,16 @@ async def test_hook_uses_context_prefix_id(self): await hook(mock_request) - # Should use context prefix ID, not generate from template + # Should use context override prefix ID assert mock_request.headers["x-prefix-id"] == "context-prefix-abc" @pytest.mark.asyncio - async def test_hook_uses_same_id_for_all_requests(self): - """Test that the hook uses the same prefix ID for all requests from the same client. + async def test_hook_uses_same_id_for_same_depth(self): + """Test that the hook uses the same prefix ID for all requests at the same depth. This ensures Dynamo's KV cache optimization works across multi-turn conversations. - All requests from the same client (created with the same hook) should share - the same prefix ID to enable KV cache reuse. + All requests at the same depth within a workflow run should share the same + prefix ID to enable KV cache reuse. """ hook = _create_dynamo_request_hook( prefix_template="session-{uuid}", @@ -300,14 +322,14 @@ async def test_hook_uses_same_id_for_all_requests(self): await hook(mock_request) prefix_ids.add(mock_request.headers["x-prefix-id"]) - # All requests should share the SAME prefix ID (for KV cache optimization) + # All requests at the same depth should share the SAME prefix ID assert len(prefix_ids) == 1 - # And it should start with our template - assert list(prefix_ids)[0].startswith("session-") + # Should contain depth marker + assert "-d0" in list(prefix_ids)[0] @pytest.mark.asyncio - async def test_different_hooks_have_different_ids(self): - """Test that different hooks (different clients) get different prefix IDs.""" + async def test_hooks_share_id_at_same_depth(self): + """Test that multiple hooks share the same prefix ID at the same depth.""" prefix_ids = set() for _ in range(5): hook = _create_dynamo_request_hook( @@ -321,12 +343,12 @@ async def test_different_hooks_have_different_ids(self): await hook(mock_request) prefix_ids.add(mock_request.headers["x-prefix-id"]) - # Different hooks should have different prefix IDs - assert len(prefix_ids) == 5 + # All hooks at the same depth share the same prefix ID within a context + assert len(prefix_ids) == 1 @pytest.mark.asyncio - async def test_hook_default_prefix_template(self): - """Test that the hook uses default prefix format when template is None.""" + async def test_hook_uses_depth_based_prefix(self): + """Test that the hook uses depth-based prefix format.""" hook = _create_dynamo_request_hook( prefix_template=None, total_requests=10, @@ -339,8 +361,9 @@ async def test_hook_default_prefix_template(self): await hook(mock_request) - # Should use default "nat-dynamo-{id}" format - assert mock_request.headers["x-prefix-id"].startswith("nat-dynamo-") + # Should use depth-based format "{workflow_id}-d{depth}" + prefix_id = mock_request.headers["x-prefix-id"] + assert "-d0" in prefix_id # Depth 0 at root level @pytest.mark.asyncio async def test_hook_normalizes_case(self): @@ -361,8 +384,8 @@ async def test_hook_normalizes_case(self): assert mock_request.headers["x-prefix-iat"] == "HIGH" @pytest.mark.asyncio - async def test_hook_template_without_uuid_placeholder(self): - """Test that a template without {uuid} placeholder uses template as-is.""" + async def test_hook_with_override_ignores_depth(self): + """Test that setting an override prefix uses it instead of depth-based ID.""" hook = _create_dynamo_request_hook( prefix_template="static-prefix-no-uuid", total_requests=10, @@ -370,12 +393,15 @@ async def test_hook_template_without_uuid_placeholder(self): iat="MEDIUM", ) + # Set override + DynamoPrefixContext.set("my-override-prefix") + mock_request = MagicMock() mock_request.headers = {} await hook(mock_request) - # Template used as-is when no {uuid} placeholder - assert mock_request.headers["x-prefix-id"] == "static-prefix-no-uuid" + # Override prefix is used + assert mock_request.headers["x-prefix-id"] == "my-override-prefix" # --------------------------------------------------------------------------- From 37ad4a49620990f3a500bf29c039f1f71338b3af Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 13:04:40 -0800 Subject: [PATCH 29/37] Refactor DynamoPrefixContext for depth-aware prefix handling Introduced depth-aware prefix ID generation for more granular control of prefix IDs across nested function calls. Replaced the previous context variable approach with a depth mapping mechanism and added support for override prefixes. Updated relevant tests for clarity and alignment with the new depth-based behavior. Signed-off-by: dnandakumar-nv --- .../2026-01-23-prediction-trie-design.md | 342 --- ...26-01-23-prediction-trie-implementation.md | 1840 ----------------- ...4-prediction-trie-example-config-design.md | 127 -- ...26-01-24-runtime-prediction-trie-design.md | 248 --- ...-runtime-prediction-trie-implementation.md | 1066 ---------- 5 files changed, 3623 deletions(-) delete mode 100644 docs/plans/2026-01-23-prediction-trie-design.md delete mode 100644 docs/plans/2026-01-23-prediction-trie-implementation.md delete mode 100644 docs/plans/2026-01-24-prediction-trie-example-config-design.md delete mode 100644 docs/plans/2026-01-24-runtime-prediction-trie-design.md delete mode 100644 docs/plans/2026-01-24-runtime-prediction-trie-implementation.md diff --git a/docs/plans/2026-01-23-prediction-trie-design.md b/docs/plans/2026-01-23-prediction-trie-design.md deleted file mode 100644 index ad14bbdcaf..0000000000 --- a/docs/plans/2026-01-23-prediction-trie-design.md +++ /dev/null @@ -1,342 +0,0 @@ -# Prediction Trie for Dynamo Inference Routing - -**Date:** 2026-01-23 -**Status:** Approved -**Author:** Design session with Claude - -## Overview - -A prediction system that provides the Dynamo inference server with expected workload characteristics for each LLM call—remaining calls, inter-arrival time, and expected output length—enabling smarter routing decisions. - -## Problem - -The Dynamo inference server can make better routing decisions if it knows: -- How many more LLM calls are expected in this workflow -- When the next LLM call will arrive -- How long the response will be - -Currently, each LLM request arrives without this context. The server treats each call independently, missing optimization opportunities. - -## Solution - -Build a prediction trie from profiler data that captures LLM call patterns at multiple granularities. At runtime, inject predictions as HTTP headers on inference requests. - -### End-to-End Flow - -``` -┌─────────────────────────────────────────────────────────────────────────────┐ -│ PROFILING PHASE │ -├─────────────────────────────────────────────────────────────────────────────┤ -│ 1. Run profiler on workflow with representative inputs │ -│ 2. Collect IntermediateStep traces with full ancestry │ -│ 3. Build PredictionTrie from LLM_END events │ -│ 4. Serialize to prediction_trie.json │ -└─────────────────────────────────────────────────────────────────────────────┘ - │ - ▼ -┌─────────────────────────────────────────────────────────────────────────────┐ -│ RUNTIME PHASE │ -├─────────────────────────────────────────────────────────────────────────────┤ -│ 1. LLM client loads prediction_trie.json at startup │ -│ 2. On each LLM call: │ -│ a. Get current function path from context │ -│ b. Increment and get call_index from tracker │ -│ c. Lookup prediction in trie │ -│ d. Inject headers into request │ -│ 3. Dynamo server uses headers for routing decisions │ -└─────────────────────────────────────────────────────────────────────────────┘ -``` - -## Data Structures - -### Prediction Metrics - -```python -@dataclass -class PredictionMetrics: - """Stats for a single metric, pre-computed from profiler data.""" - sample_count: int - mean: float - p50: float - p90: float - p95: float -``` - -### LLM Call Prediction - -```python -@dataclass -class LLMCallPrediction: - """What we predict for an LLM call at a given position.""" - remaining_calls: PredictionMetrics # How many more LLM calls expected - interarrival_ms: PredictionMetrics # Time until next LLM call - output_tokens: PredictionMetrics # Expected output length -``` - -### Prediction Trie Node - -```python -@dataclass -class PredictionTrieNode: - """A node in the prediction trie.""" - name: str # Function name at this level - children: dict[str, PredictionTrieNode] # Child nodes by function name - predictions_by_call_index: dict[int, LLMCallPrediction] # Metrics keyed by call index - predictions_any_index: LLMCallPrediction | None # Fallback: aggregated across all indices -``` - -### Trie Structure Example - -``` -root -├── workflow (stats: all LLM calls in any workflow) -│ └── react_agent (stats: all LLM calls under react_agent) -│ ├── search_tool (stats: LLM calls under search_tool) -│ │ └── llm:1 (stats: first LLM call in search_tool) -│ │ └── llm:2 (stats: second LLM call) -│ └── calculator_tool -│ └── llm:1 (stats: first LLM call in calculator_tool) -``` - -## Building the Trie - -### LLM Call Context Extraction - -For each `LLM_END` event in a profiler trace: - -```python -@dataclass -class LLMCallContext: - path: list[str] # ["workflow", "react_agent", "search_tool"] - call_index: int # Nth LLM call within the immediate parent - remaining_calls: int # How many LLM calls left in this workflow run - time_to_next_ms: float # Milliseconds until next LLM_START (or None if last) - output_tokens: int # Actual completion tokens -``` - -### Call Index Scoping - -Call index is scoped to the immediate parent function: - -``` -workflow (run_id=1) - └── react_agent (invocation_id=a1) - ├── LLM call (call_index=1 within react_agent) - ├── search_tool - │ └── LLM call (call_index=1 within search_tool) - └── LLM call (call_index=2 within react_agent) -``` - -### Trie Update Algorithm - -For each LLM call, walk its ancestry path and update every node: - -```python -def update_trie(root: PredictionTrieNode, ctx: LLMCallContext): - node = root - # Walk path, updating aggregates at each level - for func_name in ctx.path: - node.add_sample(ctx.call_index, ctx.remaining_calls, ctx.time_to_next_ms, ctx.output_tokens) - node = node.children.setdefault(func_name, PredictionTrieNode(func_name)) - # Update leaf node too - node.add_sample(ctx.call_index, ctx.remaining_calls, ctx.time_to_next_ms, ctx.output_tokens) -``` - -This means a single LLM call contributes samples to every ancestor node—giving us aggregated stats at every granularity automatically. - -## Runtime Lookup - -### Current Context - -```python -@dataclass -class CurrentContext: - path: list[str] # Current function ancestry - call_index: int # Which LLM call this is within the immediate parent -``` - -### Lookup Algorithm - -```python -def lookup(root: PredictionTrieNode, ctx: CurrentContext) -> LLMCallPrediction | None: - node = root - deepest_match = None - - # Walk the trie as far as we can match - for func_name in ctx.path: - # Capture this node as a potential match before descending - prediction = node.predictions_by_call_index.get(ctx.call_index) - if prediction is None: - prediction = node.predictions_any_index - if prediction is not None: - deepest_match = prediction - - # Try to descend - if func_name not in node.children: - break - node = node.children[func_name] - - # Check the final node we reached - prediction = node.predictions_by_call_index.get(ctx.call_index) - if prediction is None: - prediction = node.predictions_any_index - if prediction is not None: - deepest_match = prediction - - return deepest_match -``` - -### Fallback Behavior - -1. Try exact path + exact call index (most specific) -2. Try exact path + any call index -3. Try partial path + exact call index -4. Try partial path + any call index (most general) - -Novel tool calls automatically get predictions based on agent-level stats. - -## Runtime Call Index Tracking - -```python -from contextvars import ContextVar - -@dataclass -class LLMCallTracker: - """Tracks LLM call counts per function invocation.""" - counts: dict[str, int] = field(default_factory=dict) - - def increment(self, parent_function_id: str) -> int: - """Increment and return the call index for this parent.""" - self.counts[parent_function_id] = self.counts.get(parent_function_id, 0) + 1 - return self.counts[parent_function_id] - - def reset(self, parent_function_id: str): - """Reset when a function invocation completes.""" - self.counts.pop(parent_function_id, None) - -_llm_call_tracker: ContextVar[LLMCallTracker] = ContextVar('llm_call_tracker') -``` - -## Header Injection - -### Headers - -``` -X-NAT-Remaining-LLM-Calls: 3 -X-NAT-Interarrival-Ms: 450 -X-NAT-Expected-Output-Tokens: 256 -X-NAT-Prediction-Confidence: 0.85 -``` - -### Integration Point - -```python -class DynamoLangChainLLM(BaseLLM): - prediction_trie: PredictionTrie | None = None - - def _call(self, prompt: str, **kwargs) -> str: - headers = self._get_base_headers() - - if self.prediction_trie is not None: - ctx = self._get_current_context() - prediction = self.prediction_trie.lookup(ctx) - if prediction: - headers["X-NAT-Remaining-LLM-Calls"] = str(prediction.remaining_calls.mean) - headers["X-NAT-Interarrival-Ms"] = str(prediction.interarrival_ms.mean) - headers["X-NAT-Expected-Output-Tokens"] = str(prediction.output_tokens.p90) - - return self._make_request(prompt, headers=headers, **kwargs) -``` - -### Configuration - -```yaml -llms: - my_llm: - _type: nim - model_name: meta/llama-3.1-70b-instruct - prediction_trie_path: ./profiler_output/prediction_trie.json -``` - -## Serialization - -### JSON Format - -```json -{ - "version": "1.0", - "generated_at": "2026-01-23T10:30:00Z", - "workflow_name": "my_workflow", - "sample_count": 150, - "root": { - "name": "root", - "predictions_by_call_index": { - "1": { - "remaining_calls": {"sample_count": 150, "mean": 4.2, "p50": 4, "p90": 6, "p95": 7}, - "interarrival_ms": {"sample_count": 150, "mean": 520, "p50": 480, "p90": 890, "p95": 1100}, - "output_tokens": {"sample_count": 150, "mean": 185, "p50": 160, "p90": 320, "p95": 410} - } - }, - "predictions_any_index": { ... }, - "children": { - "react_agent": { ... } - } - } -} -``` - -### Output Files - -``` -profiler_output/ -├── all_requests_profiler_traces.json -├── standardized_data_all.csv -├── inference_optimization.json -├── prediction_trie.json # NEW -└── prediction_trie_summary.txt # NEW: human-readable summary -``` - -## File Organization - -``` -src/nat/profiler/ -├── prediction_trie/ -│ ├── __init__.py -│ ├── data_models.py # PredictionTrieNode, LLMCallPrediction, PredictionMetrics -│ ├── trie_builder.py # Build trie from profiler traces -│ ├── trie_lookup.py # Lookup algorithm -│ └── serialization.py # JSON load/save - -src/nat/llm/ -├── prediction_context.py # LLMCallTracker, context variable, path extraction - -packages/nvidia_nat_langchain/src/nat/plugins/langchain/ -├── llm.py # Modify to inject headers -``` - -## Profiler Configuration - -```yaml -profiler: - base_metrics: true - prediction_trie: true - prediction_trie_output: ./prediction_trie.json -``` - -## Implementation Sequence - -1. **Data models** - `PredictionTrieNode`, `LLMCallPrediction`, `PredictionMetrics` -2. **Trie builder** - Parse profiler traces, extract LLM call contexts, build trie -3. **Serialization** - JSON save/load for the trie -4. **Trie lookup** - Walk trie, return deepest match with fallback -5. **Runtime tracking** - `LLMCallTracker` context variable, integrate with existing ancestry tracking -6. **Header injection** - Modify `dynamo_langchain` LLM client to inject headers -7. **Profiler integration** - Add config option, wire trie builder into profiler output -8. **Tests** - Unit tests for trie operations, integration test with sample traces - -## Out of Scope - -- Concurrency/parallelism tracking -- Input token bucketing for lookup -- Real-time trie updates during runtime -- Multiple trie versions/A-B testing diff --git a/docs/plans/2026-01-23-prediction-trie-implementation.md b/docs/plans/2026-01-23-prediction-trie-implementation.md deleted file mode 100644 index d8e11cc5de..0000000000 --- a/docs/plans/2026-01-23-prediction-trie-implementation.md +++ /dev/null @@ -1,1840 +0,0 @@ -# Prediction Trie Implementation Plan - -> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. - -**Goal:** Build a prediction trie that aggregates LLM call patterns from profiler data and injects routing hints as Dynamo headers at runtime. - -**Architecture:** The profiler builds a trie from execution traces where each node stores aggregated metrics (remaining calls, interarrival time, output tokens) by call index. At runtime, the Dynamo LLM client walks the trie to find the best match for the current execution context and injects predictions as HTTP headers. - -**Tech Stack:** Python 3.11+, Pydantic v2, httpx event hooks, contextvars - ---- - -## Task 1: Data Models - -**Files:** -- Create: `src/nat/profiler/prediction_trie/data_models.py` -- Test: `tests/nat/profiler/prediction_trie/test_data_models.py` - -### Step 1: Write the failing test for PredictionMetrics - -```python -# tests/nat/profiler/prediction_trie/test_data_models.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import pytest - -from nat.profiler.prediction_trie.data_models import PredictionMetrics - - -def test_prediction_metrics_creation(): - metrics = PredictionMetrics(sample_count=10, mean=5.0, p50=4.5, p90=8.0, p95=9.0) - assert metrics.sample_count == 10 - assert metrics.mean == 5.0 - assert metrics.p50 == 4.5 - assert metrics.p90 == 8.0 - assert metrics.p95 == 9.0 - - -def test_prediction_metrics_defaults(): - metrics = PredictionMetrics() - assert metrics.sample_count == 0 - assert metrics.mean == 0.0 -``` - -### Step 2: Run test to verify it fails - -Run: `pytest tests/nat/profiler/prediction_trie/test_data_models.py::test_prediction_metrics_creation -v` -Expected: FAIL with "ModuleNotFoundError: No module named 'nat.profiler.prediction_trie'" - -### Step 3: Create the prediction_trie package and data models - -```python -# src/nat/profiler/prediction_trie/__init__.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from nat.profiler.prediction_trie.data_models import LLMCallPrediction -from nat.profiler.prediction_trie.data_models import PredictionMetrics -from nat.profiler.prediction_trie.data_models import PredictionTrieNode - -__all__ = ["PredictionMetrics", "LLMCallPrediction", "PredictionTrieNode"] -``` - -```python -# src/nat/profiler/prediction_trie/data_models.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from pydantic import BaseModel -from pydantic import Field - - -class PredictionMetrics(BaseModel): - """Aggregated statistics for a single metric from profiler data.""" - - sample_count: int = Field(default=0, description="Number of samples") - mean: float = Field(default=0.0, description="Mean value") - p50: float = Field(default=0.0, description="50th percentile (median)") - p90: float = Field(default=0.0, description="90th percentile") - p95: float = Field(default=0.0, description="95th percentile") - - -class LLMCallPrediction(BaseModel): - """Predictions for an LLM call at a given position in the call hierarchy.""" - - remaining_calls: PredictionMetrics = Field( - default_factory=PredictionMetrics, - description="How many more LLM calls are expected after this one", - ) - interarrival_ms: PredictionMetrics = Field( - default_factory=PredictionMetrics, - description="Expected time in milliseconds until the next LLM call", - ) - output_tokens: PredictionMetrics = Field( - default_factory=PredictionMetrics, - description="Expected output token count for this call", - ) - - -class PredictionTrieNode(BaseModel): - """A node in the prediction trie representing a function in the call hierarchy.""" - - name: str = Field(description="Function name at this level in the hierarchy") - children: dict[str, PredictionTrieNode] = Field( - default_factory=dict, - description="Child nodes keyed by function name", - ) - predictions_by_call_index: dict[int, LLMCallPrediction] = Field( - default_factory=dict, - description="Predictions keyed by call index (1-indexed)", - ) - predictions_any_index: LLMCallPrediction | None = Field( - default=None, - description="Fallback predictions aggregated across all call indices", - ) - - -# Rebuild model to handle forward references -PredictionTrieNode.model_rebuild() -``` - -### Step 4: Run test to verify it passes - -Run: `pytest tests/nat/profiler/prediction_trie/test_data_models.py -v` -Expected: PASS - -### Step 5: Add tests for LLMCallPrediction and PredictionTrieNode - -Add to `tests/nat/profiler/prediction_trie/test_data_models.py`: - -```python -from nat.profiler.prediction_trie.data_models import LLMCallPrediction -from nat.profiler.prediction_trie.data_models import PredictionTrieNode - - -def test_llm_call_prediction_creation(): - prediction = LLMCallPrediction( - remaining_calls=PredictionMetrics(sample_count=5, mean=3.0, p50=3.0, p90=5.0, p95=6.0), - interarrival_ms=PredictionMetrics(sample_count=5, mean=500.0, p50=450.0, p90=800.0, p95=900.0), - output_tokens=PredictionMetrics(sample_count=5, mean=150.0, p50=140.0, p90=250.0, p95=300.0), - ) - assert prediction.remaining_calls.mean == 3.0 - assert prediction.interarrival_ms.mean == 500.0 - assert prediction.output_tokens.mean == 150.0 - - -def test_prediction_trie_node_creation(): - node = PredictionTrieNode(name="root") - assert node.name == "root" - assert node.children == {} - assert node.predictions_by_call_index == {} - assert node.predictions_any_index is None - - -def test_prediction_trie_node_with_children(): - child = PredictionTrieNode(name="react_agent") - root = PredictionTrieNode(name="root", children={"react_agent": child}) - assert "react_agent" in root.children - assert root.children["react_agent"].name == "react_agent" - - -def test_prediction_trie_node_with_predictions(): - prediction = LLMCallPrediction() - node = PredictionTrieNode( - name="agent", - predictions_by_call_index={1: prediction, 2: prediction}, - predictions_any_index=prediction, - ) - assert 1 in node.predictions_by_call_index - assert 2 in node.predictions_by_call_index - assert node.predictions_any_index is not None -``` - -### Step 6: Run all data model tests - -Run: `pytest tests/nat/profiler/prediction_trie/test_data_models.py -v` -Expected: PASS (all tests) - -### Step 7: Commit - -```bash -git add src/nat/profiler/prediction_trie/ tests/nat/profiler/prediction_trie/ -git commit --signoff -m "feat(profiler): add prediction trie data models - -Add Pydantic models for the prediction trie: -- PredictionMetrics: aggregated stats (mean, p50, p90, p95) -- LLMCallPrediction: predictions for remaining calls, interarrival time, output tokens -- PredictionTrieNode: trie node with children and predictions by call index" -``` - ---- - -## Task 2: Metrics Accumulator - -**Files:** -- Create: `src/nat/profiler/prediction_trie/metrics_accumulator.py` -- Test: `tests/nat/profiler/prediction_trie/test_metrics_accumulator.py` - -### Step 1: Write the failing test for MetricsAccumulator - -```python -# tests/nat/profiler/prediction_trie/test_metrics_accumulator.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import pytest - -from nat.profiler.prediction_trie.metrics_accumulator import MetricsAccumulator - - -def test_accumulator_add_single_sample(): - acc = MetricsAccumulator() - acc.add_sample(10.0) - metrics = acc.compute_metrics() - assert metrics.sample_count == 1 - assert metrics.mean == 10.0 - assert metrics.p50 == 10.0 - assert metrics.p90 == 10.0 - assert metrics.p95 == 10.0 - - -def test_accumulator_add_multiple_samples(): - acc = MetricsAccumulator() - for v in [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]: - acc.add_sample(v) - metrics = acc.compute_metrics() - assert metrics.sample_count == 10 - assert metrics.mean == 5.5 - assert metrics.p50 == 5.5 # median of 1-10 - assert metrics.p90 == 9.1 # 90th percentile - assert metrics.p95 == 9.55 # 95th percentile - - -def test_accumulator_empty(): - acc = MetricsAccumulator() - metrics = acc.compute_metrics() - assert metrics.sample_count == 0 - assert metrics.mean == 0.0 -``` - -### Step 2: Run test to verify it fails - -Run: `pytest tests/nat/profiler/prediction_trie/test_metrics_accumulator.py::test_accumulator_add_single_sample -v` -Expected: FAIL with "ModuleNotFoundError" - -### Step 3: Implement MetricsAccumulator - -```python -# src/nat/profiler/prediction_trie/metrics_accumulator.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import math - -from nat.profiler.prediction_trie.data_models import PredictionMetrics - - -class MetricsAccumulator: - """Accumulates samples and computes aggregated statistics.""" - - def __init__(self) -> None: - self._samples: list[float] = [] - - def add_sample(self, value: float) -> None: - """Add a sample value to the accumulator.""" - self._samples.append(value) - - def compute_metrics(self) -> PredictionMetrics: - """Compute aggregated metrics from accumulated samples.""" - if not self._samples: - return PredictionMetrics() - - n = len(self._samples) - mean_val = sum(self._samples) / n - sorted_samples = sorted(self._samples) - - return PredictionMetrics( - sample_count=n, - mean=mean_val, - p50=self._percentile(sorted_samples, 50), - p90=self._percentile(sorted_samples, 90), - p95=self._percentile(sorted_samples, 95), - ) - - @staticmethod - def _percentile(sorted_data: list[float], pct: float) -> float: - """Compute percentile using linear interpolation.""" - if not sorted_data: - return 0.0 - if len(sorted_data) == 1: - return sorted_data[0] - k = (len(sorted_data) - 1) * (pct / 100.0) - f = math.floor(k) - c = math.ceil(k) - if f == c: - return sorted_data[int(k)] - return sorted_data[f] + (sorted_data[c] - sorted_data[f]) * (k - f) -``` - -### Step 4: Run tests to verify they pass - -Run: `pytest tests/nat/profiler/prediction_trie/test_metrics_accumulator.py -v` -Expected: PASS - -### Step 5: Commit - -```bash -git add src/nat/profiler/prediction_trie/metrics_accumulator.py tests/nat/profiler/prediction_trie/test_metrics_accumulator.py -git commit --signoff -m "feat(profiler): add MetricsAccumulator for prediction trie - -Accumulates sample values and computes aggregated statistics -(mean, p50, p90, p95) using linear interpolation for percentiles." -``` - ---- - -## Task 3: Trie Builder - -**Files:** -- Create: `src/nat/profiler/prediction_trie/trie_builder.py` -- Test: `tests/nat/profiler/prediction_trie/test_trie_builder.py` - -### Step 1: Write the failing test for TrieBuilder - -```python -# tests/nat/profiler/prediction_trie/test_trie_builder.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import pytest - -from nat.data_models.intermediate_step import IntermediateStep -from nat.data_models.intermediate_step import IntermediateStepPayload -from nat.data_models.intermediate_step import IntermediateStepType -from nat.data_models.intermediate_step import UsageInfo -from nat.data_models.invocation_node import InvocationNode -from nat.profiler.callbacks.token_usage_base_model import TokenUsageBaseModel -from nat.profiler.prediction_trie.trie_builder import PredictionTrieBuilder - - -@pytest.fixture(name="simple_trace") -def fixture_simple_trace() -> list[IntermediateStep]: - """Create a simple trace with two LLM calls.""" - return [ - IntermediateStep( - parent_id="root", - function_ancestry=InvocationNode( - function_id="workflow-1", - function_name="my_workflow", - parent_id=None, - parent_name=None, - ), - payload=IntermediateStepPayload( - event_type=IntermediateStepType.LLM_START, - event_timestamp=1000.0, - UUID="llm-1", - ), - ), - IntermediateStep( - parent_id="root", - function_ancestry=InvocationNode( - function_id="workflow-1", - function_name="my_workflow", - parent_id=None, - parent_name=None, - ), - payload=IntermediateStepPayload( - event_type=IntermediateStepType.LLM_END, - event_timestamp=1001.0, - span_event_timestamp=1000.0, - UUID="llm-1", - usage_info=UsageInfo( - token_usage=TokenUsageBaseModel(completion_tokens=100), - ), - ), - ), - IntermediateStep( - parent_id="root", - function_ancestry=InvocationNode( - function_id="workflow-1", - function_name="my_workflow", - parent_id=None, - parent_name=None, - ), - payload=IntermediateStepPayload( - event_type=IntermediateStepType.LLM_START, - event_timestamp=1002.0, - UUID="llm-2", - ), - ), - IntermediateStep( - parent_id="root", - function_ancestry=InvocationNode( - function_id="workflow-1", - function_name="my_workflow", - parent_id=None, - parent_name=None, - ), - payload=IntermediateStepPayload( - event_type=IntermediateStepType.LLM_END, - event_timestamp=1003.0, - span_event_timestamp=1002.0, - UUID="llm-2", - usage_info=UsageInfo( - token_usage=TokenUsageBaseModel(completion_tokens=150), - ), - ), - ), - ] - - -def test_trie_builder_builds_from_single_trace(simple_trace): - builder = PredictionTrieBuilder() - builder.add_trace(simple_trace) - trie = builder.build() - - assert trie.name == "root" - assert "my_workflow" in trie.children - - workflow_node = trie.children["my_workflow"] - # First LLM call: call_index=1, remaining=1 - assert 1 in workflow_node.predictions_by_call_index - # Second LLM call: call_index=2, remaining=0 - assert 2 in workflow_node.predictions_by_call_index - - -def test_trie_builder_computes_remaining_calls(simple_trace): - builder = PredictionTrieBuilder() - builder.add_trace(simple_trace) - trie = builder.build() - - workflow_node = trie.children["my_workflow"] - # First call should predict 1 remaining call - assert workflow_node.predictions_by_call_index[1].remaining_calls.mean == 1.0 - # Second call should predict 0 remaining calls - assert workflow_node.predictions_by_call_index[2].remaining_calls.mean == 0.0 - - -def test_trie_builder_computes_output_tokens(simple_trace): - builder = PredictionTrieBuilder() - builder.add_trace(simple_trace) - trie = builder.build() - - workflow_node = trie.children["my_workflow"] - # First call had 100 completion tokens - assert workflow_node.predictions_by_call_index[1].output_tokens.mean == 100.0 - # Second call had 150 completion tokens - assert workflow_node.predictions_by_call_index[2].output_tokens.mean == 150.0 -``` - -### Step 2: Run test to verify it fails - -Run: `pytest tests/nat/profiler/prediction_trie/test_trie_builder.py::test_trie_builder_builds_from_single_trace -v` -Expected: FAIL with "ModuleNotFoundError" - -### Step 3: Implement PredictionTrieBuilder - -```python -# src/nat/profiler/prediction_trie/trie_builder.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from collections import defaultdict -from dataclasses import dataclass -from dataclasses import field - -from nat.data_models.intermediate_step import IntermediateStep -from nat.data_models.intermediate_step import IntermediateStepType -from nat.profiler.prediction_trie.data_models import LLMCallPrediction -from nat.profiler.prediction_trie.data_models import PredictionTrieNode -from nat.profiler.prediction_trie.metrics_accumulator import MetricsAccumulator - - -@dataclass -class LLMCallContext: - """Context for a single LLM call extracted from a trace.""" - - path: list[str] - call_index: int - remaining_calls: int - time_to_next_ms: float | None - output_tokens: int - - -@dataclass -class _NodeAccumulators: - """Accumulators for a single trie node.""" - - remaining_calls: dict[int, MetricsAccumulator] = field(default_factory=lambda: defaultdict(MetricsAccumulator)) - interarrival_ms: dict[int, MetricsAccumulator] = field(default_factory=lambda: defaultdict(MetricsAccumulator)) - output_tokens: dict[int, MetricsAccumulator] = field(default_factory=lambda: defaultdict(MetricsAccumulator)) - # For aggregated stats across all call indices - all_remaining_calls: MetricsAccumulator = field(default_factory=MetricsAccumulator) - all_interarrival_ms: MetricsAccumulator = field(default_factory=MetricsAccumulator) - all_output_tokens: MetricsAccumulator = field(default_factory=MetricsAccumulator) - - -class PredictionTrieBuilder: - """Builds a prediction trie from profiler execution traces.""" - - def __init__(self) -> None: - # Map from path tuple to accumulators - self._node_accumulators: dict[tuple[str, ...], _NodeAccumulators] = defaultdict(_NodeAccumulators) - - def add_trace(self, steps: list[IntermediateStep]) -> None: - """Process a single execution trace and update accumulators.""" - contexts = self._extract_llm_contexts(steps) - for ctx in contexts: - self._update_accumulators(ctx) - - def _extract_llm_contexts(self, steps: list[IntermediateStep]) -> list[LLMCallContext]: - """Extract LLM call contexts from a trace.""" - # Sort steps by timestamp - sorted_steps = sorted(steps, key=lambda s: s.event_timestamp) - - # Find all LLM_END events - llm_ends: list[IntermediateStep] = [] - for step in sorted_steps: - if step.event_type == IntermediateStepType.LLM_END: - llm_ends.append(step) - - # Find all LLM_START events for interarrival time calculation - llm_starts: list[IntermediateStep] = [] - for step in sorted_steps: - if step.event_type == IntermediateStepType.LLM_START: - llm_starts.append(step) - - # Track call index per parent function - call_counts: dict[str, int] = defaultdict(int) - contexts: list[LLMCallContext] = [] - - for i, end_step in enumerate(llm_ends): - # Build path from function ancestry - path = self._build_path(end_step) - - # Determine call index within parent - parent_key = end_step.function_ancestry.function_id - call_counts[parent_key] += 1 - call_index = call_counts[parent_key] - - # Remaining calls in this trace - remaining = len(llm_ends) - i - 1 - - # Time to next LLM start (if any) - time_to_next_ms: float | None = None - if i + 1 < len(llm_starts): - next_start_time = llm_starts[i + 1].event_timestamp if i + 1 < len(llm_starts) else None - if next_start_time is not None: - time_to_next_ms = (next_start_time - end_step.event_timestamp) * 1000.0 - - # Output tokens - output_tokens = 0 - if end_step.usage_info and end_step.usage_info.token_usage: - output_tokens = end_step.usage_info.token_usage.completion_tokens or 0 - - contexts.append( - LLMCallContext( - path=path, - call_index=call_index, - remaining_calls=remaining, - time_to_next_ms=time_to_next_ms, - output_tokens=output_tokens, - ) - ) - - return contexts - - def _build_path(self, step: IntermediateStep) -> list[str]: - """Build the function path from ancestry.""" - path: list[str] = [] - ancestry = step.function_ancestry - - # Walk up the ancestry chain - if ancestry.parent_name: - path.append(ancestry.parent_name) - path.append(ancestry.function_name) - - return path - - def _update_accumulators(self, ctx: LLMCallContext) -> None: - """Update accumulators at every node along the path.""" - # Update root node - root_key: tuple[str, ...] = () - self._add_to_accumulators(root_key, ctx) - - # Update each node along the path - for i in range(len(ctx.path)): - path_key = tuple(ctx.path[: i + 1]) - self._add_to_accumulators(path_key, ctx) - - def _add_to_accumulators(self, path_key: tuple[str, ...], ctx: LLMCallContext) -> None: - """Add context data to accumulators for a specific path.""" - accs = self._node_accumulators[path_key] - - # By call index - accs.remaining_calls[ctx.call_index].add_sample(float(ctx.remaining_calls)) - accs.output_tokens[ctx.call_index].add_sample(float(ctx.output_tokens)) - if ctx.time_to_next_ms is not None: - accs.interarrival_ms[ctx.call_index].add_sample(ctx.time_to_next_ms) - - # Aggregated across all indices - accs.all_remaining_calls.add_sample(float(ctx.remaining_calls)) - accs.all_output_tokens.add_sample(float(ctx.output_tokens)) - if ctx.time_to_next_ms is not None: - accs.all_interarrival_ms.add_sample(ctx.time_to_next_ms) - - def build(self) -> PredictionTrieNode: - """Build the final prediction trie from accumulated data.""" - root = PredictionTrieNode(name="root") - - for path_key, accs in self._node_accumulators.items(): - node = self._get_or_create_node(root, path_key) - self._populate_node_predictions(node, accs) - - return root - - def _get_or_create_node(self, root: PredictionTrieNode, path_key: tuple[str, ...]) -> PredictionTrieNode: - """Navigate to or create a node at the given path.""" - if not path_key: - return root - - current = root - for name in path_key: - if name not in current.children: - current.children[name] = PredictionTrieNode(name=name) - current = current.children[name] - return current - - def _populate_node_predictions(self, node: PredictionTrieNode, accs: _NodeAccumulators) -> None: - """Populate a node with computed predictions from accumulators.""" - # Predictions by call index - all_indices = set(accs.remaining_calls.keys()) | set(accs.interarrival_ms.keys()) | set(accs.output_tokens.keys()) - - for idx in all_indices: - prediction = LLMCallPrediction( - remaining_calls=accs.remaining_calls[idx].compute_metrics(), - interarrival_ms=accs.interarrival_ms[idx].compute_metrics(), - output_tokens=accs.output_tokens[idx].compute_metrics(), - ) - node.predictions_by_call_index[idx] = prediction - - # Aggregated predictions - if accs.all_remaining_calls._samples: - node.predictions_any_index = LLMCallPrediction( - remaining_calls=accs.all_remaining_calls.compute_metrics(), - interarrival_ms=accs.all_interarrival_ms.compute_metrics(), - output_tokens=accs.all_output_tokens.compute_metrics(), - ) -``` - -### Step 4: Run tests to verify they pass - -Run: `pytest tests/nat/profiler/prediction_trie/test_trie_builder.py -v` -Expected: PASS - -### Step 5: Add test for interarrival time - -Add to `tests/nat/profiler/prediction_trie/test_trie_builder.py`: - -```python -def test_trie_builder_computes_interarrival_time(simple_trace): - builder = PredictionTrieBuilder() - builder.add_trace(simple_trace) - trie = builder.build() - - workflow_node = trie.children["my_workflow"] - # First call: next LLM starts at 1002.0, this call ends at 1001.0 -> 1000ms - assert workflow_node.predictions_by_call_index[1].interarrival_ms.mean == 1000.0 -``` - -### Step 6: Run all builder tests - -Run: `pytest tests/nat/profiler/prediction_trie/test_trie_builder.py -v` -Expected: PASS - -### Step 7: Commit - -```bash -git add src/nat/profiler/prediction_trie/trie_builder.py tests/nat/profiler/prediction_trie/test_trie_builder.py -git commit --signoff -m "feat(profiler): add PredictionTrieBuilder - -Builds prediction trie from profiler execution traces: -- Extracts LLM call contexts (path, call index, remaining, interarrival, output tokens) -- Aggregates metrics at every node along the path -- Computes stats by call index and aggregated fallback" -``` - ---- - -## Task 4: Trie Lookup - -**Files:** -- Create: `src/nat/profiler/prediction_trie/trie_lookup.py` -- Test: `tests/nat/profiler/prediction_trie/test_trie_lookup.py` - -### Step 1: Write the failing test for lookup - -```python -# tests/nat/profiler/prediction_trie/test_trie_lookup.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import pytest - -from nat.profiler.prediction_trie.data_models import LLMCallPrediction -from nat.profiler.prediction_trie.data_models import PredictionMetrics -from nat.profiler.prediction_trie.data_models import PredictionTrieNode -from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup - - -@pytest.fixture(name="sample_trie") -def fixture_sample_trie() -> PredictionTrieNode: - """Create a sample trie for testing lookups.""" - prediction_1 = LLMCallPrediction( - remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), - interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), - output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), - ) - prediction_2 = LLMCallPrediction( - remaining_calls=PredictionMetrics(sample_count=10, mean=2.0, p50=2.0, p90=3.0, p95=4.0), - interarrival_ms=PredictionMetrics(sample_count=10, mean=400.0, p50=380.0, p90=600.0, p95=700.0), - output_tokens=PredictionMetrics(sample_count=10, mean=200.0, p50=190.0, p90=280.0, p95=320.0), - ) - aggregated = LLMCallPrediction( - remaining_calls=PredictionMetrics(sample_count=20, mean=2.5, p50=2.5, p90=3.5, p95=4.5), - interarrival_ms=PredictionMetrics(sample_count=20, mean=450.0, p50=415.0, p90=650.0, p95=750.0), - output_tokens=PredictionMetrics(sample_count=20, mean=175.0, p50=165.0, p90=240.0, p95=285.0), - ) - - agent_node = PredictionTrieNode( - name="react_agent", - predictions_by_call_index={1: prediction_1, 2: prediction_2}, - predictions_any_index=aggregated, - ) - - workflow_node = PredictionTrieNode( - name="my_workflow", - children={"react_agent": agent_node}, - predictions_any_index=aggregated, - ) - - root = PredictionTrieNode( - name="root", - children={"my_workflow": workflow_node}, - predictions_any_index=aggregated, - ) - - return root - - -def test_lookup_exact_match(sample_trie): - lookup = PredictionTrieLookup(sample_trie) - result = lookup.find(path=["my_workflow", "react_agent"], call_index=1) - - assert result is not None - assert result.remaining_calls.mean == 3.0 - assert result.output_tokens.mean == 150.0 - - -def test_lookup_partial_path_match(sample_trie): - """When exact path doesn't exist, fall back to closest ancestor.""" - lookup = PredictionTrieLookup(sample_trie) - # "unknown_tool" doesn't exist, should fall back to react_agent's aggregated - result = lookup.find(path=["my_workflow", "react_agent", "unknown_tool"], call_index=1) - - assert result is not None - # Should get react_agent's call_index=1 prediction - assert result.remaining_calls.mean == 3.0 - - -def test_lookup_unknown_call_index_fallback(sample_trie): - """When call_index doesn't exist, fall back to aggregated.""" - lookup = PredictionTrieLookup(sample_trie) - result = lookup.find(path=["my_workflow", "react_agent"], call_index=99) - - assert result is not None - # Should fall back to predictions_any_index - assert result.remaining_calls.mean == 2.5 - - -def test_lookup_no_match_returns_root_aggregated(sample_trie): - """When nothing matches, return root's aggregated.""" - lookup = PredictionTrieLookup(sample_trie) - result = lookup.find(path=["completely_unknown"], call_index=1) - - assert result is not None - # Should return root's aggregated prediction - assert result.remaining_calls.mean == 2.5 -``` - -### Step 2: Run test to verify it fails - -Run: `pytest tests/nat/profiler/prediction_trie/test_trie_lookup.py::test_lookup_exact_match -v` -Expected: FAIL with "ModuleNotFoundError" - -### Step 3: Implement PredictionTrieLookup - -```python -# src/nat/profiler/prediction_trie/trie_lookup.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from nat.profiler.prediction_trie.data_models import LLMCallPrediction -from nat.profiler.prediction_trie.data_models import PredictionTrieNode - - -class PredictionTrieLookup: - """Looks up predictions in a prediction trie with graceful fallback.""" - - def __init__(self, root: PredictionTrieNode) -> None: - self._root = root - - def find(self, path: list[str], call_index: int) -> LLMCallPrediction | None: - """ - Find the best matching prediction for the given path and call index. - - Walks the trie as far as possible along the path, then returns the deepest - match. Falls back to aggregated predictions when exact call_index isn't found. - - Args: - path: Function ancestry path (e.g., ["my_workflow", "react_agent"]) - call_index: The Nth LLM call within the current parent function - - Returns: - Best matching prediction, or None if trie is empty - """ - node = self._root - deepest_match: LLMCallPrediction | None = None - - # Check root node first - deepest_match = self._get_prediction(node, call_index) or deepest_match - - # Walk the trie as far as we can match - for func_name in path: - if func_name not in node.children: - break - node = node.children[func_name] - # Update deepest match at each level - match = self._get_prediction(node, call_index) - if match is not None: - deepest_match = match - - return deepest_match - - def _get_prediction(self, node: PredictionTrieNode, call_index: int) -> LLMCallPrediction | None: - """Get prediction from node, preferring exact call_index, falling back to aggregated.""" - if call_index in node.predictions_by_call_index: - return node.predictions_by_call_index[call_index] - return node.predictions_any_index -``` - -### Step 4: Run tests to verify they pass - -Run: `pytest tests/nat/profiler/prediction_trie/test_trie_lookup.py -v` -Expected: PASS - -### Step 5: Commit - -```bash -git add src/nat/profiler/prediction_trie/trie_lookup.py tests/nat/profiler/prediction_trie/test_trie_lookup.py -git commit --signoff -m "feat(profiler): add PredictionTrieLookup - -Walks the trie to find best matching prediction: -- Exact path + exact call_index (most specific) -- Partial path + exact call_index -- Falls back to aggregated predictions when call_index not found" -``` - ---- - -## Task 5: Serialization - -**Files:** -- Create: `src/nat/profiler/prediction_trie/serialization.py` -- Test: `tests/nat/profiler/prediction_trie/test_serialization.py` - -### Step 1: Write the failing test for serialization - -```python -# tests/nat/profiler/prediction_trie/test_serialization.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import json -import tempfile -from pathlib import Path - -import pytest - -from nat.profiler.prediction_trie.data_models import LLMCallPrediction -from nat.profiler.prediction_trie.data_models import PredictionMetrics -from nat.profiler.prediction_trie.data_models import PredictionTrieNode -from nat.profiler.prediction_trie.serialization import load_prediction_trie -from nat.profiler.prediction_trie.serialization import save_prediction_trie - - -@pytest.fixture(name="sample_trie") -def fixture_sample_trie() -> PredictionTrieNode: - """Create a sample trie for testing serialization.""" - prediction = LLMCallPrediction( - remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), - interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), - output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), - ) - - child = PredictionTrieNode( - name="react_agent", - predictions_by_call_index={1: prediction}, - predictions_any_index=prediction, - ) - - root = PredictionTrieNode( - name="root", - children={"react_agent": child}, - predictions_any_index=prediction, - ) - - return root - - -def test_save_and_load_trie(sample_trie): - with tempfile.TemporaryDirectory() as tmpdir: - path = Path(tmpdir) / "prediction_trie.json" - - save_prediction_trie(sample_trie, path, workflow_name="test_workflow") - - loaded = load_prediction_trie(path) - - assert loaded.name == "root" - assert "react_agent" in loaded.children - assert loaded.children["react_agent"].predictions_by_call_index[1].remaining_calls.mean == 3.0 - - -def test_saved_file_has_metadata(sample_trie): - with tempfile.TemporaryDirectory() as tmpdir: - path = Path(tmpdir) / "prediction_trie.json" - - save_prediction_trie(sample_trie, path, workflow_name="test_workflow") - - with open(path) as f: - data = json.load(f) - - assert data["version"] == "1.0" - assert data["workflow_name"] == "test_workflow" - assert "generated_at" in data - assert "root" in data -``` - -### Step 2: Run test to verify it fails - -Run: `pytest tests/nat/profiler/prediction_trie/test_serialization.py::test_save_and_load_trie -v` -Expected: FAIL with "ModuleNotFoundError" - -### Step 3: Implement serialization functions - -```python -# src/nat/profiler/prediction_trie/serialization.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import json -from datetime import datetime -from datetime import timezone -from pathlib import Path -from typing import Any - -from nat.profiler.prediction_trie.data_models import LLMCallPrediction -from nat.profiler.prediction_trie.data_models import PredictionMetrics -from nat.profiler.prediction_trie.data_models import PredictionTrieNode - -CURRENT_VERSION = "1.0" - - -def save_prediction_trie( - trie: PredictionTrieNode, - path: Path, - workflow_name: str = "unknown", -) -> None: - """ - Save a prediction trie to a JSON file. - - Args: - trie: The prediction trie root node - path: Path to save the JSON file - workflow_name: Name of the workflow this trie was built from - """ - data = { - "version": CURRENT_VERSION, - "generated_at": datetime.now(timezone.utc).isoformat(), - "workflow_name": workflow_name, - "root": _serialize_node(trie), - } - - with open(path, "w", encoding="utf-8") as f: - json.dump(data, f, indent=2) - - -def load_prediction_trie(path: Path) -> PredictionTrieNode: - """ - Load a prediction trie from a JSON file. - - Args: - path: Path to the JSON file - - Returns: - The deserialized prediction trie root node - """ - with open(path, encoding="utf-8") as f: - data = json.load(f) - - return _deserialize_node(data["root"]) - - -def _serialize_node(node: PredictionTrieNode) -> dict[str, Any]: - """Serialize a trie node to a dictionary.""" - result: dict[str, Any] = { - "name": node.name, - "predictions_by_call_index": { - str(k): v.model_dump() for k, v in node.predictions_by_call_index.items() - }, - "predictions_any_index": node.predictions_any_index.model_dump() if node.predictions_any_index else None, - "children": {k: _serialize_node(v) for k, v in node.children.items()}, - } - return result - - -def _deserialize_node(data: dict[str, Any]) -> PredictionTrieNode: - """Deserialize a dictionary to a trie node.""" - predictions_by_call_index: dict[int, LLMCallPrediction] = {} - for k, v in data.get("predictions_by_call_index", {}).items(): - predictions_by_call_index[int(k)] = LLMCallPrediction( - remaining_calls=PredictionMetrics(**v["remaining_calls"]), - interarrival_ms=PredictionMetrics(**v["interarrival_ms"]), - output_tokens=PredictionMetrics(**v["output_tokens"]), - ) - - predictions_any_index = None - if data.get("predictions_any_index"): - v = data["predictions_any_index"] - predictions_any_index = LLMCallPrediction( - remaining_calls=PredictionMetrics(**v["remaining_calls"]), - interarrival_ms=PredictionMetrics(**v["interarrival_ms"]), - output_tokens=PredictionMetrics(**v["output_tokens"]), - ) - - children: dict[str, PredictionTrieNode] = {} - for k, v in data.get("children", {}).items(): - children[k] = _deserialize_node(v) - - return PredictionTrieNode( - name=data["name"], - predictions_by_call_index=predictions_by_call_index, - predictions_any_index=predictions_any_index, - children=children, - ) -``` - -### Step 4: Run tests to verify they pass - -Run: `pytest tests/nat/profiler/prediction_trie/test_serialization.py -v` -Expected: PASS - -### Step 5: Update __init__.py exports - -```python -# src/nat/profiler/prediction_trie/__init__.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from nat.profiler.prediction_trie.data_models import LLMCallPrediction -from nat.profiler.prediction_trie.data_models import PredictionMetrics -from nat.profiler.prediction_trie.data_models import PredictionTrieNode -from nat.profiler.prediction_trie.serialization import load_prediction_trie -from nat.profiler.prediction_trie.serialization import save_prediction_trie -from nat.profiler.prediction_trie.trie_builder import PredictionTrieBuilder -from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup - -__all__ = [ - "LLMCallPrediction", - "PredictionMetrics", - "PredictionTrieBuilder", - "PredictionTrieLookup", - "PredictionTrieNode", - "load_prediction_trie", - "save_prediction_trie", -] -``` - -### Step 6: Commit - -```bash -git add src/nat/profiler/prediction_trie/serialization.py src/nat/profiler/prediction_trie/__init__.py tests/nat/profiler/prediction_trie/test_serialization.py -git commit --signoff -m "feat(profiler): add prediction trie serialization - -JSON serialization with metadata: -- version, generated_at, workflow_name -- Recursive node serialization/deserialization -- Handles predictions_by_call_index int keys" -``` - ---- - -## Task 6: Runtime Call Tracker - -**Files:** -- Create: `src/nat/llm/prediction_context.py` -- Test: `tests/nat/llm/test_prediction_context.py` - -### Step 1: Write the failing test for LLMCallTracker - -```python -# tests/nat/llm/test_prediction_context.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import pytest - -from nat.llm.prediction_context import LLMCallTracker -from nat.llm.prediction_context import get_call_tracker - - -def test_tracker_increment(): - tracker = LLMCallTracker() - assert tracker.increment("func-1") == 1 - assert tracker.increment("func-1") == 2 - assert tracker.increment("func-2") == 1 - assert tracker.increment("func-1") == 3 - - -def test_tracker_reset(): - tracker = LLMCallTracker() - tracker.increment("func-1") - tracker.increment("func-1") - tracker.reset("func-1") - assert tracker.increment("func-1") == 1 - - -def test_tracker_context_variable(): - tracker1 = get_call_tracker() - tracker1.increment("func-a") - - tracker2 = get_call_tracker() - # Should be the same tracker in the same context - assert tracker2.increment("func-a") == 2 -``` - -### Step 2: Run test to verify it fails - -Run: `pytest tests/nat/llm/test_prediction_context.py::test_tracker_increment -v` -Expected: FAIL with "ModuleNotFoundError" - -### Step 3: Implement LLMCallTracker - -```python -# src/nat/llm/prediction_context.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -Runtime context management for prediction trie lookups. - -Provides tracking of LLM call indices per function invocation, -enabling accurate lookups in the prediction trie at runtime. -""" - -from contextvars import ContextVar -from dataclasses import dataclass -from dataclasses import field - - -@dataclass -class LLMCallTracker: - """Tracks LLM call counts per function invocation.""" - - counts: dict[str, int] = field(default_factory=dict) - - def increment(self, parent_function_id: str) -> int: - """ - Increment and return the call index for this parent. - - Args: - parent_function_id: Unique ID of the parent function invocation - - Returns: - The call index (1-indexed) for this LLM call within the parent - """ - self.counts[parent_function_id] = self.counts.get(parent_function_id, 0) + 1 - return self.counts[parent_function_id] - - def reset(self, parent_function_id: str) -> None: - """ - Reset call count when a function invocation completes. - - Args: - parent_function_id: Unique ID of the parent function invocation - """ - self.counts.pop(parent_function_id, None) - - -# Thread/async-safe context variable for the call tracker -_llm_call_tracker: ContextVar[LLMCallTracker] = ContextVar("llm_call_tracker") - - -def get_call_tracker() -> LLMCallTracker: - """ - Get the LLMCallTracker for the current context. - - Creates a new tracker if one doesn't exist in the current context. - - Returns: - The LLMCallTracker for this context - """ - try: - return _llm_call_tracker.get() - except LookupError: - tracker = LLMCallTracker() - _llm_call_tracker.set(tracker) - return tracker -``` - -### Step 4: Run tests to verify they pass - -Run: `pytest tests/nat/llm/test_prediction_context.py -v` -Expected: PASS - -### Step 5: Commit - -```bash -git add src/nat/llm/prediction_context.py tests/nat/llm/test_prediction_context.py -git commit --signoff -m "feat(llm): add LLMCallTracker for runtime prediction lookups - -Context variable-based tracking of LLM call indices per function -invocation. Thread/async-safe using contextvars." -``` - ---- - -## Task 7: Profiler Integration - -**Files:** -- Modify: `src/nat/data_models/profiler.py` -- Modify: `src/nat/profiler/profile_runner.py` -- Test: `tests/nat/profiler/test_prediction_trie_integration.py` - -### Step 1: Add prediction_trie config option - -Update `src/nat/data_models/profiler.py`: - -```python -# Add to ProfilerConfig class: -class PredictionTrieConfig(BaseModel): - enable: bool = False - output_filename: str = "prediction_trie.json" - - -class ProfilerConfig(BaseModel): - - base_metrics: bool = False - token_usage_forecast: bool = False - token_uniqueness_forecast: bool = False - workflow_runtime_forecast: bool = False - compute_llm_metrics: bool = False - csv_exclude_io_text: bool = False - prompt_caching_prefixes: PromptCachingConfig = PromptCachingConfig() - bottleneck_analysis: BottleneckConfig = BottleneckConfig() - concurrency_spike_analysis: ConcurrencySpikeConfig = ConcurrencySpikeConfig() - prefix_span_analysis: PrefixSpanConfig = PrefixSpanConfig() - prediction_trie: PredictionTrieConfig = PredictionTrieConfig() # ADD THIS -``` - -### Step 2: Write failing test for profiler integration - -```python -# tests/nat/profiler/test_prediction_trie_integration.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import tempfile -from pathlib import Path - -import pytest - -from nat.data_models.intermediate_step import IntermediateStep -from nat.data_models.intermediate_step import IntermediateStepPayload -from nat.data_models.intermediate_step import IntermediateStepType -from nat.data_models.intermediate_step import UsageInfo -from nat.data_models.invocation_node import InvocationNode -from nat.data_models.profiler import PredictionTrieConfig -from nat.data_models.profiler import ProfilerConfig -from nat.profiler.callbacks.token_usage_base_model import TokenUsageBaseModel -from nat.profiler.prediction_trie import load_prediction_trie -from nat.profiler.profile_runner import ProfilerRunner - - -@pytest.fixture(name="sample_traces") -def fixture_sample_traces() -> list[list[IntermediateStep]]: - """Create sample traces for testing profiler integration.""" - - def make_trace() -> list[IntermediateStep]: - return [ - IntermediateStep( - parent_id="root", - function_ancestry=InvocationNode( - function_id="workflow-1", - function_name="my_workflow", - parent_id=None, - parent_name=None, - ), - payload=IntermediateStepPayload( - event_type=IntermediateStepType.LLM_START, - event_timestamp=1000.0, - UUID="llm-1", - ), - ), - IntermediateStep( - parent_id="root", - function_ancestry=InvocationNode( - function_id="workflow-1", - function_name="my_workflow", - parent_id=None, - parent_name=None, - ), - payload=IntermediateStepPayload( - event_type=IntermediateStepType.LLM_END, - event_timestamp=1001.0, - span_event_timestamp=1000.0, - UUID="llm-1", - usage_info=UsageInfo(token_usage=TokenUsageBaseModel(completion_tokens=100)), - ), - ), - ] - - return [make_trace(), make_trace()] - - -async def test_profiler_generates_prediction_trie(sample_traces): - with tempfile.TemporaryDirectory() as tmpdir: - output_dir = Path(tmpdir) - - config = ProfilerConfig( - base_metrics=True, - prediction_trie=PredictionTrieConfig(enable=True), - ) - - runner = ProfilerRunner(config, output_dir) - await runner.run(sample_traces) - - trie_path = output_dir / "prediction_trie.json" - assert trie_path.exists() - - trie = load_prediction_trie(trie_path) - assert trie.name == "root" - assert "my_workflow" in trie.children -``` - -### Step 3: Run test to verify it fails - -Run: `pytest tests/nat/profiler/test_prediction_trie_integration.py -v` -Expected: FAIL (prediction_trie.json not generated) - -### Step 4: Update ProfilerRunner to generate prediction trie - -Add to `src/nat/profiler/profile_runner.py` in the `run` method, after the existing analysis sections (around line 257): - -```python - # After prefix_span_analysis section, add: - - if self.profile_config.prediction_trie.enable: - # ------------------------------------------------------------ - # Build and save prediction trie - # ------------------------------------------------------------ - from nat.profiler.prediction_trie import PredictionTrieBuilder - from nat.profiler.prediction_trie import save_prediction_trie - - logger.info("Building prediction trie from traces...") - trie_builder = PredictionTrieBuilder() - for trace in all_steps: - trie_builder.add_trace(trace) - - prediction_trie = trie_builder.build() - - if self.write_output: - trie_path = os.path.join(self.output_dir, self.profile_config.prediction_trie.output_filename) - save_prediction_trie(prediction_trie, Path(trie_path), workflow_name="profiled_workflow") - logger.info("Wrote prediction trie to: %s", trie_path) -``` - -### Step 5: Run test to verify it passes - -Run: `pytest tests/nat/profiler/test_prediction_trie_integration.py -v` -Expected: PASS - -### Step 6: Commit - -```bash -git add src/nat/data_models/profiler.py src/nat/profiler/profile_runner.py tests/nat/profiler/test_prediction_trie_integration.py -git commit --signoff -m "feat(profiler): integrate prediction trie generation - -Add PredictionTrieConfig to ProfilerConfig with enable flag. -ProfilerRunner now builds and saves prediction_trie.json when enabled." -``` - ---- - -## Task 8: Dynamo Header Injection - -**Files:** -- Modify: `src/nat/llm/dynamo_llm.py` -- Modify: `packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py` -- Test: `tests/nat/llm/test_dynamo_prediction_headers.py` - -### Step 1: Write failing test for header injection - -```python -# tests/nat/llm/test_dynamo_prediction_headers.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import pytest - -from nat.llm.dynamo_llm import create_httpx_client_with_prediction_headers -from nat.profiler.prediction_trie.data_models import LLMCallPrediction -from nat.profiler.prediction_trie.data_models import PredictionMetrics - - -async def test_prediction_headers_injected(): - """Test that prediction headers are injected into requests.""" - prediction = LLMCallPrediction( - remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), - interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), - output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), - ) - - # Create a mock request to capture headers - captured_headers = {} - - async def capture_hook(request): - captured_headers.update(dict(request.headers)) - - client = create_httpx_client_with_prediction_headers( - prediction=prediction, - prefix_template="test-{uuid}", - total_requests=10, - osl="MEDIUM", - iat="LOW", - ) - - # Add our capture hook - client.event_hooks["request"].append(capture_hook) - - # Make a test request (will fail, but headers will be captured) - try: - await client.post("http://localhost:1/test", json={}) - except Exception: - pass - - assert "x-nat-remaining-llm-calls" in captured_headers - assert captured_headers["x-nat-remaining-llm-calls"] == "3" - assert "x-nat-interarrival-ms" in captured_headers - assert captured_headers["x-nat-interarrival-ms"] == "500" - assert "x-nat-expected-output-tokens" in captured_headers - assert captured_headers["x-nat-expected-output-tokens"] == "200" # p90 value - - await client.aclose() -``` - -### Step 2: Run test to verify it fails - -Run: `pytest tests/nat/llm/test_dynamo_prediction_headers.py -v` -Expected: FAIL with "cannot import name 'create_httpx_client_with_prediction_headers'" - -### Step 3: Add prediction header injection to dynamo_llm.py - -Add to `src/nat/llm/dynamo_llm.py`: - -```python -# Add import at top: -from nat.profiler.prediction_trie.data_models import LLMCallPrediction - - -def _create_prediction_request_hook( - prediction: LLMCallPrediction, -) -> Callable[["httpx.Request"], Coroutine[Any, Any, None]]: - """ - Create an httpx event hook that injects prediction headers. - - Args: - prediction: The prediction data to inject - - Returns: - An async function suitable for use as an httpx event hook. - """ - - async def on_request(request): - """Inject prediction headers before each request.""" - request.headers["x-nat-remaining-llm-calls"] = str(int(prediction.remaining_calls.mean)) - request.headers["x-nat-interarrival-ms"] = str(int(prediction.interarrival_ms.mean)) - request.headers["x-nat-expected-output-tokens"] = str(int(prediction.output_tokens.p90)) - - logger.debug( - "Injected prediction headers: remaining=%d, interarrival=%d, output_tokens=%d", - int(prediction.remaining_calls.mean), - int(prediction.interarrival_ms.mean), - int(prediction.output_tokens.p90), - ) - - return on_request - - -def create_httpx_client_with_prediction_headers( - prediction: LLMCallPrediction, - prefix_template: str | None, - total_requests: int, - osl: str, - iat: str, - timeout: float = 600.0, -) -> "httpx.AsyncClient": - """ - Create an httpx.AsyncClient with both Dynamo prefix and prediction headers. - - Args: - prediction: Prediction data for this LLM call - prefix_template: Template string with {uuid} placeholder - total_requests: Expected number of requests for this prefix - osl: Output sequence length hint (LOW/MEDIUM/HIGH) - iat: Inter-arrival time hint (LOW/MEDIUM/HIGH) - timeout: HTTP request timeout in seconds - - Returns: - An httpx.AsyncClient configured with header injection. - """ - import httpx - - hooks: list[Callable] = [] - - # Add Dynamo prefix hook - prefix_hook = _create_dynamo_request_hook(prefix_template, total_requests, osl, iat) - hooks.append(prefix_hook) - - # Add prediction hook - prediction_hook = _create_prediction_request_hook(prediction) - hooks.append(prediction_hook) - - return httpx.AsyncClient( - event_hooks={"request": hooks}, - timeout=httpx.Timeout(timeout), - ) -``` - -### Step 4: Run test to verify it passes - -Run: `pytest tests/nat/llm/test_dynamo_prediction_headers.py -v` -Expected: PASS - -### Step 5: Commit - -```bash -git add src/nat/llm/dynamo_llm.py tests/nat/llm/test_dynamo_prediction_headers.py -git commit --signoff -m "feat(llm): add prediction header injection to Dynamo client - -Injects x-nat-remaining-llm-calls, x-nat-interarrival-ms, and -x-nat-expected-output-tokens headers for server routing optimization." -``` - ---- - -## Task 9: LangChain Integration with Trie Loading - -**Files:** -- Modify: `src/nat/llm/dynamo_llm.py` (add config field) -- Modify: `packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py` -- Test: `tests/nat/plugins/langchain/test_dynamo_prediction_trie.py` - -### Step 1: Add prediction_trie_path to DynamoModelConfig - -Update `src/nat/llm/dynamo_llm.py`: - -```python -# Add to DynamoModelConfig class: - prediction_trie_path: str | None = Field( - default=None, - description="Path to prediction_trie.json file. When set, predictions are " - "looked up and injected as headers for each LLM call.", - ) - - # Update get_dynamo_field_names(): - @staticmethod - def get_dynamo_field_names() -> frozenset[str]: - return frozenset({ - "prefix_template", - "prefix_total_requests", - "prefix_osl", - "prefix_iat", - "request_timeout", - "prediction_trie_path", # ADD THIS - }) -``` - -### Step 2: Write test for trie-based header injection - -```python -# tests/nat/plugins/langchain/test_dynamo_prediction_trie.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import tempfile -from pathlib import Path - -import pytest - -from nat.llm.dynamo_llm import DynamoModelConfig -from nat.profiler.prediction_trie import PredictionTrieNode -from nat.profiler.prediction_trie import save_prediction_trie -from nat.profiler.prediction_trie.data_models import LLMCallPrediction -from nat.profiler.prediction_trie.data_models import PredictionMetrics - - -@pytest.fixture(name="trie_file") -def fixture_trie_file() -> Path: - """Create a temporary trie file for testing.""" - prediction = LLMCallPrediction( - remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), - interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), - output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), - ) - - root = PredictionTrieNode( - name="root", - predictions_by_call_index={1: prediction}, - predictions_any_index=prediction, - ) - - with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: - path = Path(f.name) - - save_prediction_trie(root, path) - yield path - path.unlink(missing_ok=True) - - -def test_dynamo_config_with_trie_path(trie_file): - """Test that DynamoModelConfig accepts prediction_trie_path.""" - config = DynamoModelConfig( - base_url="http://localhost:8000", - model_name="test-model", - api_key="test-key", - prediction_trie_path=str(trie_file), - ) - - assert config.prediction_trie_path == str(trie_file) - assert "prediction_trie_path" in DynamoModelConfig.get_dynamo_field_names() -``` - -### Step 3: Run test to verify config field works - -Run: `pytest tests/nat/plugins/langchain/test_dynamo_prediction_trie.py -v` -Expected: PASS - -### Step 4: Commit - -```bash -git add src/nat/llm/dynamo_llm.py tests/nat/plugins/langchain/test_dynamo_prediction_trie.py -git commit --signoff -m "feat(llm): add prediction_trie_path config to DynamoModelConfig - -Allows specifying a prediction_trie.json file path in workflow config. -When set, predictions are looked up and injected as headers." -``` - ---- - -## Task 10: End-to-End Integration Test - -**Files:** -- Test: `tests/nat/profiler/test_prediction_trie_e2e.py` - -### Step 1: Write end-to-end test - -```python -# tests/nat/profiler/test_prediction_trie_e2e.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -"""End-to-end test for prediction trie workflow.""" - -import tempfile -from pathlib import Path - -import pytest - -from nat.data_models.intermediate_step import IntermediateStep -from nat.data_models.intermediate_step import IntermediateStepPayload -from nat.data_models.intermediate_step import IntermediateStepType -from nat.data_models.intermediate_step import UsageInfo -from nat.data_models.invocation_node import InvocationNode -from nat.data_models.profiler import PredictionTrieConfig -from nat.data_models.profiler import ProfilerConfig -from nat.profiler.callbacks.token_usage_base_model import TokenUsageBaseModel -from nat.profiler.prediction_trie import PredictionTrieLookup -from nat.profiler.prediction_trie import load_prediction_trie -from nat.profiler.profile_runner import ProfilerRunner - - -def make_agent_trace(agent_name: str, num_llm_calls: int, base_timestamp: float) -> list[IntermediateStep]: - """Create a trace with multiple LLM calls in an agent.""" - steps = [] - ts = base_timestamp - - for i in range(num_llm_calls): - llm_uuid = f"llm-{agent_name}-{i}" - - # LLM_START - steps.append( - IntermediateStep( - parent_id="root", - function_ancestry=InvocationNode( - function_id=f"{agent_name}-1", - function_name=agent_name, - parent_id="workflow-1", - parent_name="my_workflow", - ), - payload=IntermediateStepPayload( - event_type=IntermediateStepType.LLM_START, - event_timestamp=ts, - UUID=llm_uuid, - ), - ) - ) - ts += 0.5 - - # LLM_END - completion_tokens = 100 + (i * 50) # Vary tokens by position - steps.append( - IntermediateStep( - parent_id="root", - function_ancestry=InvocationNode( - function_id=f"{agent_name}-1", - function_name=agent_name, - parent_id="workflow-1", - parent_name="my_workflow", - ), - payload=IntermediateStepPayload( - event_type=IntermediateStepType.LLM_END, - event_timestamp=ts, - span_event_timestamp=ts - 0.5, - UUID=llm_uuid, - usage_info=UsageInfo(token_usage=TokenUsageBaseModel(completion_tokens=completion_tokens)), - ), - ) - ) - ts += 0.5 - - return steps - - -async def test_e2e_prediction_trie_workflow(): - """Test the complete flow: profiler -> trie -> lookup.""" - with tempfile.TemporaryDirectory() as tmpdir: - output_dir = Path(tmpdir) - - # Create multiple traces with different agents - traces = [ - make_agent_trace("react_agent", num_llm_calls=3, base_timestamp=1000.0), - make_agent_trace("react_agent", num_llm_calls=3, base_timestamp=2000.0), - make_agent_trace("tool_agent", num_llm_calls=2, base_timestamp=3000.0), - ] - - # Run profiler - config = ProfilerConfig( - base_metrics=True, - prediction_trie=PredictionTrieConfig(enable=True), - ) - runner = ProfilerRunner(config, output_dir) - await runner.run(traces) - - # Load trie - trie_path = output_dir / "prediction_trie.json" - assert trie_path.exists(), "Trie file should exist" - - trie = load_prediction_trie(trie_path) - lookup = PredictionTrieLookup(trie) - - # Test lookups - # react_agent has 3 LLM calls, so at call 1 there are 2 remaining - result = lookup.find(path=["my_workflow", "react_agent"], call_index=1) - assert result is not None - assert result.remaining_calls.mean == 2.0 # 2 remaining after first call - - # At call 3 there are 0 remaining - result = lookup.find(path=["my_workflow", "react_agent"], call_index=3) - assert result is not None - assert result.remaining_calls.mean == 0.0 - - # tool_agent should have different stats - result = lookup.find(path=["my_workflow", "tool_agent"], call_index=1) - assert result is not None - assert result.remaining_calls.mean == 1.0 # 1 remaining after first call - - # Unknown agent should fall back to aggregated - result = lookup.find(path=["my_workflow", "unknown_agent"], call_index=1) - assert result is not None # Should still get a result from fallback -``` - -### Step 2: Run e2e test - -Run: `pytest tests/nat/profiler/test_prediction_trie_e2e.py -v` -Expected: PASS - -### Step 3: Commit - -```bash -git add tests/nat/profiler/test_prediction_trie_e2e.py -git commit --signoff -m "test(profiler): add end-to-end prediction trie test - -Validates complete flow: profiler traces -> trie generation -> lookup -with different agents and call indices." -``` - ---- - -## Summary - -This plan implements the prediction trie feature in 10 tasks: - -1. **Data Models** - Pydantic models for trie nodes and predictions -2. **Metrics Accumulator** - Helper for computing statistics -3. **Trie Builder** - Builds trie from profiler traces -4. **Trie Lookup** - Finds best matching prediction with fallback -5. **Serialization** - JSON save/load -6. **Runtime Call Tracker** - Context variable for tracking call indices -7. **Profiler Integration** - Config option and trie generation -8. **Dynamo Header Injection** - httpx hooks for prediction headers -9. **LangChain Integration** - Config field for trie path -10. **End-to-End Test** - Validates complete flow - -Each task follows TDD: write failing test, implement, verify, commit. diff --git a/docs/plans/2026-01-24-prediction-trie-example-config-design.md b/docs/plans/2026-01-24-prediction-trie-example-config-design.md deleted file mode 100644 index 74922d487b..0000000000 --- a/docs/plans/2026-01-24-prediction-trie-example-config-design.md +++ /dev/null @@ -1,127 +0,0 @@ -# Prediction Trie Example Config Design - -## Overview - -Create example configs and documentation demonstrating the two-phase Dynamo optimization workflow using prediction trie for dynamic header injection. - -## Two-Phase Workflow - -``` -Phase 1: Profiling -┌─────────────────────────────────────────────────────────────┐ -│ nat eval --config_file profile_rethinking_full_test.yml │ -│ │ │ -│ ▼ │ -│ outputs/rethinking_full_test_for_profiling/ │ -│ └── prediction_trie.json │ -└─────────────────────────────────────────────────────────────┘ - -Phase 2: Run with Predictions -┌─────────────────────────────────────────────────────────────┐ -│ nat eval --config_file run_with_prediction_trie.yml │ -│ │ │ -│ Loads prediction_trie.json │ -│ │ │ -│ Injects dynamic headers per LLM call: │ -│ - x-nat-remaining-llm-calls │ -│ - x-nat-interarrival-ms │ -│ - x-nat-expected-output-tokens │ -└─────────────────────────────────────────────────────────────┘ -``` - -**Key difference from static headers:** Instead of guessing `prefix_total_requests=10`, the trie provides accurate per-call predictions based on function path and call index from profiled data. - -## Deliverables - -### 1. Update: profile_rethinking_full_test.yml - -Add `prediction_trie` section to enable trie building: - -```yaml -profiler: - # ... existing config ... - - # NEW: Build prediction trie from profiled traces - prediction_trie: - enable: true - output_filename: prediction_trie.json -``` - -### 2. New: run_with_prediction_trie.yml - -Config that loads the trie and uses dynamic predictions: - -```yaml -llms: - dynamo_llm: - _type: dynamo - model_name: llama-3.3-70b - base_url: http://localhost:8099/v1 - api_key: dummy - temperature: 0.0 - max_tokens: 8192 - stop: ["Observation:", "\nThought:"] - prefix_template: "react-benchmark-{uuid}" - - # Static headers as fallback - prefix_total_requests: 10 - prefix_osl: MEDIUM - prefix_iat: MEDIUM - - # NEW: Load prediction trie for dynamic per-call headers - prediction_trie_path: ./examples/dynamo_integration/react_benchmark_agent/outputs/dynamo_evals/rethinking_full_test_for_profiling//prediction_trie.json - -eval: - general: - output: - dir: ./examples/dynamo_integration/react_benchmark_agent/outputs/dynamo_evals/prediction_trie_eval/ - - profiler: - compute_llm_metrics: true - csv_exclude_io_text: true -``` - -### 3. New: README_PREDICTION_TRIE.md - -Documentation for the two-phase workflow: - -```markdown -# Prediction Trie Optimization for Dynamo - -## Overview -Use profiled execution data to inject accurate per-call prediction headers -instead of static guesses. - -## Quick Start - -### Phase 1: Build the Prediction Trie -nat eval --config_file configs/profile_rethinking_full_test.yml - -Output: outputs/dynamo_evals/rethinking_full_test_for_profiling//prediction_trie.json - -### Phase 2: Run with Predictions -1. Update prediction_trie_path in run_with_prediction_trie.yml -2. Run: nat eval --config_file configs/run_with_prediction_trie.yml - -## How It Works -- Phase 1 profiles the agent and builds a trie mapping (function_path, call_index) → predictions -- Phase 2 loads the trie and injects headers dynamically based on current execution context - -## Headers Injected -| Header | Source | Description | -|--------|--------|-------------| -| x-nat-remaining-llm-calls | prediction.remaining_calls.mean | Expected remaining calls | -| x-nat-interarrival-ms | prediction.interarrival_ms.mean | Expected time to next call | -| x-nat-expected-output-tokens | prediction.output_tokens.p90 | Expected output tokens | - -## Comparing Results -Run both static and prediction-based configs and compare avg_llm_latency metrics. -``` - -## Files Changed - -| File | Type | Description | -|------|------|-------------| -| `examples/dynamo_integration/react_benchmark_agent/src/react_benchmark_agent/configs/profile_rethinking_full_test.yml` | Modify | Add prediction_trie.enable: true | -| `examples/dynamo_integration/react_benchmark_agent/src/react_benchmark_agent/configs/run_with_prediction_trie.yml` | New | Config using prediction_trie_path | -| `examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md` | New | Documentation for two-phase workflow | diff --git a/docs/plans/2026-01-24-runtime-prediction-trie-design.md b/docs/plans/2026-01-24-runtime-prediction-trie-design.md deleted file mode 100644 index b7d9bd3109..0000000000 --- a/docs/plans/2026-01-24-runtime-prediction-trie-design.md +++ /dev/null @@ -1,248 +0,0 @@ -# Runtime Prediction Trie Integration Design - -## Overview - -This design addresses the gap between the prediction trie (built by the profiler) and runtime execution. Currently, the trie is built and saved, but never loaded or used during actual workflow execution to inject prediction headers. - -## Problem Statement - -The prediction trie implementation has the following gaps: - -1. **Trie never loaded at runtime** - `prediction_trie_path` config exists but is never used -2. **Function path not tracked for lookups** - `Context.active_function` only stores immediate parent, not full ancestry -3. **Call index never tracked at runtime** - `LLMCallTracker` exists but is never incremented during LLM calls -4. **Headers are static** - httpx client created once with static hooks; predictions need dynamic per-call lookup - -## Design Goals - -- Track full function path ancestry during workflow execution -- Track LLM call indices per parent function -- Look up predictions dynamically on each LLM call -- Inject prediction headers for Dynamo routing optimization -- Work across all LLM frameworks (LangChain, LlamaIndex, etc.) - -## Architecture - -### Separation of Concerns - -| Concern | Component | Scope | -|---------|-----------|-------| -| State tracking | Callback handlers + IntermediateStepManager | All LLM providers | -| Header injection | Dynamo httpx hook | Dynamo LLM only | - -This separation ensures state is tracked universally (even if multiple LLM providers are used in one workflow), while header injection is specific to Dynamo. - -### Data Flow - -``` -1. Workflow starts - └─► function_path_stack = ["my_workflow"] - -2. Agent function called via push_active_function("react_agent") - └─► function_path_stack = ["my_workflow", "react_agent"] - -3. LLM call initiated - └─► Callback fires on_chat_model_start - └─► IntermediateStepManager.push_intermediate_step(LLM_START) - └─► call_tracker.increment(parent_function_id) → 1 - -4. httpx sends request (Dynamo) - └─► Dynamic hook executes: - ├─► Read function_path_stack → ["my_workflow", "react_agent"] - ├─► Read call_tracker count → 1 - ├─► trie_lookup.find(path, call_index) → prediction - │ └─► (fallback to root.predictions_any_index if no match) - └─► Inject headers - -5. Next LLM call → call_index becomes 2, repeat -``` - -## Components to Modify - -### 1. ContextState (src/nat/builder/context.py) - -Add new ContextVar to track full function path: - -```python -class ContextState: - def __init__(self): - # ... existing fields ... - self._function_path_stack: ContextVar[list[str] | None] = ContextVar( - "function_path_stack", default=None - ) - - @property - def function_path_stack(self) -> ContextVar[list[str]]: - if self._function_path_stack.get() is None: - self._function_path_stack.set([]) - return typing.cast(ContextVar[list[str]], self._function_path_stack) -``` - -### 2. Context.push_active_function() (src/nat/builder/context.py) - -Update to push/pop function names on path stack: - -```python -@contextmanager -def push_active_function(self, function_name: str, ...): - # ... existing code ... - - # Push function name onto path stack - current_path = self._context_state.function_path_stack.get() - new_path = current_path + [function_name] - path_token = self._context_state.function_path_stack.set(new_path) - - try: - yield manager - finally: - # ... existing cleanup ... - self._context_state.function_path_stack.reset(path_token) -``` - -### 3. IntermediateStepManager.push_intermediate_step() (src/nat/builder/intermediate_step_manager.py) - -Increment call tracker on LLM_START events: - -```python -from nat.llm.prediction_context import get_call_tracker - -def push_intermediate_step(self, payload: IntermediateStepPayload) -> None: - # ... existing code ... - - # Track LLM call index for prediction lookups - if payload.event_type == IntermediateStepType.LLM_START: - active_function = self._context_state.active_function.get() - if active_function: - tracker = get_call_tracker() - tracker.increment(active_function.function_id) - - # ... rest of existing code ... -``` - -### 4. Context.function_path property (src/nat/builder/context.py) - -Add property to read current function path: - -```python -@property -def function_path(self) -> list[str]: - """Returns the current function path stack (copy).""" - return list(self._context_state.function_path_stack.get()) -``` - -### 5. dynamo_langchain() (packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py) - -Load trie and create dynamic hook: - -```python -from nat.profiler.prediction_trie import load_prediction_trie, PredictionTrieLookup - -@register_llm_client(config_type=DynamoModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN) -async def dynamo_langchain(llm_config: DynamoModelConfig, _builder: Builder): - # Load prediction trie if configured - trie_lookup: PredictionTrieLookup | None = None - if llm_config.prediction_trie_path: - trie = load_prediction_trie(Path(llm_config.prediction_trie_path)) - trie_lookup = PredictionTrieLookup(trie) - logger.info("Loaded prediction trie from %s", llm_config.prediction_trie_path) - - # Create httpx client with dynamic prediction hook - if llm_config.prefix_template is not None: - http_async_client = create_httpx_client_with_dynamo_hooks( - # ... existing params ... - prediction_lookup=trie_lookup, # Pass lookup to hook - ) -``` - -### 6. Dynamic Prediction Hook (src/nat/llm/dynamo_llm.py) - -Create hook that reads context and looks up predictions: - -```python -def _create_dynamic_prediction_hook( - trie_lookup: PredictionTrieLookup, -) -> Callable[["httpx.Request"], Coroutine[Any, Any, None]]: - """Create hook that dynamically looks up predictions per request.""" - - async def on_request(request: "httpx.Request") -> None: - from nat.builder.context import Context - from nat.llm.prediction_context import get_call_tracker - - ctx = Context.get() - path = ctx.function_path - - # Get call index for current parent function - call_index = 1 # default - active_fn = ctx.active_function - if active_fn: - tracker = get_call_tracker() - call_index = tracker.counts.get(active_fn.function_id, 1) - - # Look up prediction - prediction = trie_lookup.find(path, call_index) - - if prediction: - request.headers["x-nat-remaining-llm-calls"] = str(int(prediction.remaining_calls.mean)) - request.headers["x-nat-interarrival-ms"] = str(int(prediction.interarrival_ms.mean)) - request.headers["x-nat-expected-output-tokens"] = str(int(prediction.output_tokens.p90)) - - logger.debug( - "Injected prediction headers: path=%s, call_index=%d, remaining=%d", - path, call_index, int(prediction.remaining_calls.mean) - ) - - return on_request -``` - -## Fallback Chain - -When looking up predictions, the following fallback chain applies: - -1. **Exact match**: path + call_index found in trie -2. **Partial path**: walk trie as far as possible, use deepest match -3. **Any index**: use node's `predictions_any_index` if exact call_index not found -4. **Root fallback**: use root's `predictions_any_index` as final fallback - -This ensures we always have some prediction to inject (root aggregates across all profiled traces). - -## Call Index Tracking - -- Each function invocation has a unique UUID (`function_id`) -- `LLMCallTracker.increment(function_id)` returns 1, 2, 3... for successive LLM calls -- No explicit reset needed - new function invocations get new UUIDs automatically -- Memory is minimal (dict of int counters) and garbage collected with context - -## Headers Injected - -| Header | Value | Description | -|--------|-------|-------------| -| `x-nat-remaining-llm-calls` | `int(prediction.remaining_calls.mean)` | Expected remaining LLM calls | -| `x-nat-interarrival-ms` | `int(prediction.interarrival_ms.mean)` | Expected ms until next call | -| `x-nat-expected-output-tokens` | `int(prediction.output_tokens.p90)` | Expected output tokens (p90) | - -## Testing Strategy - -1. **Unit tests**: Test each component in isolation - - `function_path_stack` push/pop behavior - - Call tracker increment in IntermediateStepManager - - Dynamic hook reads context correctly - -2. **Integration test**: End-to-end flow - - Create trie from sample traces - - Run workflow with Dynamo LLM - - Verify headers injected with correct values - -3. **Fallback test**: Verify fallback chain - - Unknown path falls back to root - - Unknown call_index falls back to any_index - -## Files Changed - -| File | Type | Description | -|------|------|-------------| -| `src/nat/builder/context.py` | Modify | Add function_path_stack ContextVar and property | -| `src/nat/builder/intermediate_step_manager.py` | Modify | Increment call tracker on LLM_START | -| `src/nat/llm/dynamo_llm.py` | Modify | Add dynamic prediction hook | -| `packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py` | Modify | Load trie, wire up hook | -| `tests/nat/builder/test_function_path_stack.py` | New | Test path stack tracking | -| `tests/nat/llm/test_dynamic_prediction_hook.py` | New | Test dynamic lookup and injection | diff --git a/docs/plans/2026-01-24-runtime-prediction-trie-implementation.md b/docs/plans/2026-01-24-runtime-prediction-trie-implementation.md deleted file mode 100644 index e87f4ee540..0000000000 --- a/docs/plans/2026-01-24-runtime-prediction-trie-implementation.md +++ /dev/null @@ -1,1066 +0,0 @@ -# Runtime Prediction Trie Integration Implementation Plan - -> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. - -**Goal:** Enable runtime prediction trie lookups to inject Dynamo headers based on current function path and LLM call index. - -**Architecture:** Add a function path stack ContextVar for tracking ancestry, increment call tracker in IntermediateStepManager on LLM_START events, and create a dynamic httpx hook that reads context and looks up predictions from a pre-loaded trie. - -**Tech Stack:** Python 3.11+, contextvars, Pydantic v2, httpx event hooks - ---- - -## Task 1: Add Function Path Stack to ContextState - -**Files:** -- Modify: `src/nat/builder/context.py:67-120` -- Test: `tests/nat/builder/test_function_path_stack.py` - -### Step 1: Write the failing test for function_path_stack - -```python -# tests/nat/builder/test_function_path_stack.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import pytest - -from nat.builder.context import ContextState - - -def test_function_path_stack_default_empty(): - """Test that function_path_stack starts empty.""" - state = ContextState.get() - # Reset to test fresh state - state._function_path_stack.set(None) - - path = state.function_path_stack.get() - assert path == [] - - -def test_function_path_stack_can_be_set(): - """Test that function_path_stack can be set and retrieved.""" - state = ContextState.get() - state.function_path_stack.set(["workflow", "agent"]) - - path = state.function_path_stack.get() - assert path == ["workflow", "agent"] -``` - -### Step 2: Run test to verify it fails - -Run: `pytest tests/nat/builder/test_function_path_stack.py::test_function_path_stack_default_empty -v` -Expected: FAIL with "AttributeError: 'ContextState' object has no attribute '_function_path_stack'" - -### Step 3: Add function_path_stack ContextVar to ContextState - -In `src/nat/builder/context.py`, add to `ContextState.__init__` after line 83: - -```python - self._function_path_stack: ContextVar[list[str] | None] = ContextVar("function_path_stack", default=None) -``` - -And add the property after `active_span_id_stack` property (after line 116): - -```python - @property - def function_path_stack(self) -> ContextVar[list[str]]: - if self._function_path_stack.get() is None: - self._function_path_stack.set([]) - return typing.cast(ContextVar[list[str]], self._function_path_stack) -``` - -### Step 4: Run test to verify it passes - -Run: `pytest tests/nat/builder/test_function_path_stack.py -v` -Expected: PASS - -### Step 5: Commit - -```bash -git add src/nat/builder/context.py tests/nat/builder/test_function_path_stack.py -git commit --signoff -m "feat(context): add function_path_stack ContextVar - -Tracks the full function ancestry path as a list of function names, -enabling prediction trie lookups at runtime." -``` - ---- - -## Task 2: Update push_active_function to Track Path Stack - -**Files:** -- Modify: `src/nat/builder/context.py:235-279` -- Test: `tests/nat/builder/test_function_path_stack.py` - -### Step 1: Write the failing test for push_active_function path tracking - -Add to `tests/nat/builder/test_function_path_stack.py`: - -```python -from nat.builder.context import Context - - -def test_push_active_function_updates_path_stack(): - """Test that push_active_function pushes/pops from path stack.""" - ctx = Context.get() - state = ctx._context_state - - # Reset path stack - state._function_path_stack.set(None) - - # Initially empty - assert state.function_path_stack.get() == [] - - with ctx.push_active_function("my_workflow", input_data=None): - assert state.function_path_stack.get() == ["my_workflow"] - - with ctx.push_active_function("react_agent", input_data=None): - assert state.function_path_stack.get() == ["my_workflow", "react_agent"] - - with ctx.push_active_function("tool_call", input_data=None): - assert state.function_path_stack.get() == ["my_workflow", "react_agent", "tool_call"] - - # After tool_call exits - assert state.function_path_stack.get() == ["my_workflow", "react_agent"] - - # After react_agent exits - assert state.function_path_stack.get() == ["my_workflow"] - - # After workflow exits - assert state.function_path_stack.get() == [] -``` - -### Step 2: Run test to verify it fails - -Run: `pytest tests/nat/builder/test_function_path_stack.py::test_push_active_function_updates_path_stack -v` -Expected: FAIL with assertion error (path stack not being updated) - -### Step 3: Update push_active_function to track path stack - -In `src/nat/builder/context.py`, modify `push_active_function` method. After line 252 (after setting fn_token), add: - -```python - # 1b) Push function name onto path stack - current_path = self._context_state.function_path_stack.get() - new_path = current_path + [function_name] - path_token = self._context_state.function_path_stack.set(new_path) -``` - -And in the finally block, before line 279 (before resetting fn_token), add: - -```python - # 4a) Pop function name from path stack - self._context_state.function_path_stack.reset(path_token) -``` - -### Step 4: Run test to verify it passes - -Run: `pytest tests/nat/builder/test_function_path_stack.py -v` -Expected: PASS - -### Step 5: Add function_path property to Context class - -Add after `active_function` property (around line 289): - -```python - @property - def function_path(self) -> list[str]: - """ - Returns a copy of the current function path stack. - - The function path represents the ancestry of the currently executing - function, from root to the current function. - - Returns: - list[str]: Copy of the function path stack. - """ - return list(self._context_state.function_path_stack.get()) -``` - -### Step 6: Write test for function_path property - -Add to `tests/nat/builder/test_function_path_stack.py`: - -```python -def test_context_function_path_property(): - """Test that Context.function_path returns a copy of the path stack.""" - ctx = Context.get() - state = ctx._context_state - - # Reset path stack - state._function_path_stack.set(None) - - with ctx.push_active_function("workflow", input_data=None): - with ctx.push_active_function("agent", input_data=None): - path = ctx.function_path - assert path == ["workflow", "agent"] - - # Verify it's a copy (modifications don't affect original) - path.append("modified") - assert ctx.function_path == ["workflow", "agent"] -``` - -### Step 7: Run all tests - -Run: `pytest tests/nat/builder/test_function_path_stack.py -v` -Expected: PASS - -### Step 8: Commit - -```bash -git add src/nat/builder/context.py tests/nat/builder/test_function_path_stack.py -git commit --signoff -m "feat(context): track function path in push_active_function - -Push/pop function names onto function_path_stack in push_active_function. -Add Context.function_path property to retrieve the current path." -``` - ---- - -## Task 3: Increment Call Tracker in IntermediateStepManager - -**Files:** -- Modify: `src/nat/builder/intermediate_step_manager.py:64-96` -- Test: `tests/nat/builder/test_call_tracker_integration.py` - -### Step 1: Write the failing test for call tracker integration - -```python -# tests/nat/builder/test_call_tracker_integration.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import pytest - -from nat.builder.context import Context -from nat.data_models.intermediate_step import IntermediateStepPayload -from nat.data_models.intermediate_step import IntermediateStepType -from nat.llm.prediction_context import get_call_tracker - - -def test_llm_start_increments_call_tracker(): - """Test that pushing an LLM_START step increments the call tracker.""" - ctx = Context.get() - step_manager = ctx.intermediate_step_manager - - with ctx.push_active_function("test_agent", input_data=None): - active_fn = ctx.active_function - tracker = get_call_tracker() - - # Initially no count for this function - assert tracker.counts.get(active_fn.function_id, 0) == 0 - - # Push LLM_START - step_manager.push_intermediate_step( - IntermediateStepPayload( - UUID="llm-call-1", - event_type=IntermediateStepType.LLM_START, - name="test-model", - ) - ) - - # Call tracker should be incremented - assert tracker.counts.get(active_fn.function_id) == 1 - - # Push another LLM_START - step_manager.push_intermediate_step( - IntermediateStepPayload( - UUID="llm-call-2", - event_type=IntermediateStepType.LLM_START, - name="test-model", - ) - ) - - # Should be 2 now - assert tracker.counts.get(active_fn.function_id) == 2 - - -def test_non_llm_start_does_not_increment_tracker(): - """Test that non-LLM_START events don't increment the tracker.""" - ctx = Context.get() - step_manager = ctx.intermediate_step_manager - - with ctx.push_active_function("test_agent_2", input_data=None): - active_fn = ctx.active_function - tracker = get_call_tracker() - - initial_count = tracker.counts.get(active_fn.function_id, 0) - - # Push TOOL_START (should not increment) - step_manager.push_intermediate_step( - IntermediateStepPayload( - UUID="tool-call-1", - event_type=IntermediateStepType.TOOL_START, - name="test-tool", - ) - ) - - # Count should be unchanged - assert tracker.counts.get(active_fn.function_id, 0) == initial_count -``` - -### Step 2: Run test to verify it fails - -Run: `pytest tests/nat/builder/test_call_tracker_integration.py::test_llm_start_increments_call_tracker -v` -Expected: FAIL with assertion error (count is 0, not 1) - -### Step 3: Add call tracker increment to IntermediateStepManager - -In `src/nat/builder/intermediate_step_manager.py`, add import at top: - -```python -from nat.data_models.intermediate_step import IntermediateStepType -from nat.llm.prediction_context import get_call_tracker -``` - -Then in `push_intermediate_step` method, after line 96 (after the debug log for START), add: - -```python - # Track LLM call index for prediction trie lookups - if payload.event_type == IntermediateStepType.LLM_START: - active_function = self._context_state.active_function.get() - if active_function and active_function.function_id != "root": - tracker = get_call_tracker() - tracker.increment(active_function.function_id) - logger.debug("Incremented LLM call tracker for %s to %d", - active_function.function_id, - tracker.counts.get(active_function.function_id, 0)) -``` - -### Step 4: Run test to verify it passes - -Run: `pytest tests/nat/builder/test_call_tracker_integration.py -v` -Expected: PASS - -### Step 5: Commit - -```bash -git add src/nat/builder/intermediate_step_manager.py tests/nat/builder/test_call_tracker_integration.py -git commit --signoff -m "feat(step-manager): increment call tracker on LLM_START - -IntermediateStepManager now increments LLMCallTracker when an LLM_START -event is pushed. This enables accurate call index tracking for prediction -trie lookups across all LLM frameworks." -``` - ---- - -## Task 4: Create Dynamic Prediction Hook - -**Files:** -- Modify: `src/nat/llm/dynamo_llm.py` -- Test: `tests/nat/llm/test_dynamic_prediction_hook.py` - -### Step 1: Write the failing test for dynamic prediction hook - -```python -# tests/nat/llm/test_dynamic_prediction_hook.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import pytest - -from nat.builder.context import Context -from nat.llm.dynamo_llm import _create_dynamic_prediction_hook -from nat.llm.prediction_context import get_call_tracker -from nat.profiler.prediction_trie import PredictionTrieLookup -from nat.profiler.prediction_trie.data_models import LLMCallPrediction -from nat.profiler.prediction_trie.data_models import PredictionMetrics -from nat.profiler.prediction_trie.data_models import PredictionTrieNode - - -@pytest.fixture(name="sample_trie_lookup") -def fixture_sample_trie_lookup() -> PredictionTrieLookup: - """Create a sample trie lookup for testing.""" - prediction = LLMCallPrediction( - remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), - interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), - output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), - ) - - agent_node = PredictionTrieNode( - name="react_agent", - predictions_by_call_index={1: prediction, 2: prediction}, - predictions_any_index=prediction, - ) - - workflow_node = PredictionTrieNode( - name="my_workflow", - children={"react_agent": agent_node}, - predictions_any_index=prediction, - ) - - root = PredictionTrieNode( - name="root", - children={"my_workflow": workflow_node}, - predictions_any_index=prediction, - ) - - return PredictionTrieLookup(root) - - -class MockRequest: - """Mock httpx.Request for testing.""" - - def __init__(self): - self.headers = {} - - -async def test_dynamic_hook_injects_headers(sample_trie_lookup): - """Test that dynamic hook injects prediction headers based on context.""" - ctx = Context.get() - state = ctx._context_state - - # Reset state - state._function_path_stack.set(None) - - hook = _create_dynamic_prediction_hook(sample_trie_lookup) - - with ctx.push_active_function("my_workflow", input_data=None): - with ctx.push_active_function("react_agent", input_data=None): - # Simulate LLM call tracker increment (normally done by step manager) - tracker = get_call_tracker() - tracker.increment(ctx.active_function.function_id) - - request = MockRequest() - await hook(request) - - assert "x-nat-remaining-llm-calls" in request.headers - assert request.headers["x-nat-remaining-llm-calls"] == "3" - assert request.headers["x-nat-interarrival-ms"] == "500" - assert request.headers["x-nat-expected-output-tokens"] == "200" - - -async def test_dynamic_hook_uses_root_fallback(sample_trie_lookup): - """Test that hook falls back to root prediction for unknown paths.""" - ctx = Context.get() - state = ctx._context_state - - # Reset state - state._function_path_stack.set(None) - - hook = _create_dynamic_prediction_hook(sample_trie_lookup) - - with ctx.push_active_function("unknown_workflow", input_data=None): - tracker = get_call_tracker() - tracker.increment(ctx.active_function.function_id) - - request = MockRequest() - await hook(request) - - # Should still inject headers from root fallback - assert "x-nat-remaining-llm-calls" in request.headers -``` - -### Step 2: Run test to verify it fails - -Run: `pytest tests/nat/llm/test_dynamic_prediction_hook.py::test_dynamic_hook_injects_headers -v` -Expected: FAIL with "cannot import name '_create_dynamic_prediction_hook'" - -### Step 3: Implement dynamic prediction hook - -Add to `src/nat/llm/dynamo_llm.py` after the existing `_create_prediction_request_hook` function (around line 383): - -```python -def _create_dynamic_prediction_hook( - trie_lookup: "PredictionTrieLookup", -) -> Callable[["httpx.Request"], Coroutine[Any, Any, None]]: - """ - Create an httpx event hook that dynamically looks up predictions per request. - - This hook reads the current function path and call index from context, - looks up the prediction in the trie, and injects headers. - - Args: - trie_lookup: The PredictionTrieLookup instance to query - - Returns: - An async function suitable for use as an httpx event hook. - """ - # Import here to avoid circular imports - from nat.profiler.prediction_trie import PredictionTrieLookup - - async def on_request(request: "httpx.Request") -> None: - """Look up prediction from context and inject headers.""" - from nat.builder.context import Context - from nat.llm.prediction_context import get_call_tracker - - try: - ctx = Context.get() - path = ctx.function_path - - # Get call index for current parent function - call_index = 1 # default - active_fn = ctx.active_function - if active_fn and active_fn.function_id != "root": - tracker = get_call_tracker() - call_index = tracker.counts.get(active_fn.function_id, 1) - - # Look up prediction - prediction = trie_lookup.find(path, call_index) - - if prediction: - request.headers["x-nat-remaining-llm-calls"] = str(int(prediction.remaining_calls.mean)) - request.headers["x-nat-interarrival-ms"] = str(int(prediction.interarrival_ms.mean)) - request.headers["x-nat-expected-output-tokens"] = str(int(prediction.output_tokens.p90)) - - logger.debug( - "Injected prediction headers: path=%s, call_index=%d, remaining=%d, interarrival=%d, output=%d", - path, - call_index, - int(prediction.remaining_calls.mean), - int(prediction.interarrival_ms.mean), - int(prediction.output_tokens.p90), - ) - else: - logger.debug("No prediction found for path=%s, call_index=%d", path, call_index) - - except Exception as e: - # Don't fail the request if prediction lookup fails - logger.warning("Failed to inject prediction headers: %s", e) - - return on_request -``` - -Also add the import at top of file (after existing TYPE_CHECKING imports): - -```python -if TYPE_CHECKING: - import httpx - from nat.profiler.prediction_trie import PredictionTrieLookup -``` - -### Step 4: Run test to verify it passes - -Run: `pytest tests/nat/llm/test_dynamic_prediction_hook.py -v` -Expected: PASS - -### Step 5: Commit - -```bash -git add src/nat/llm/dynamo_llm.py tests/nat/llm/test_dynamic_prediction_hook.py -git commit --signoff -m "feat(dynamo): add dynamic prediction hook - -Creates httpx hook that reads function path and call index from context, -looks up prediction in trie, and injects headers per-request." -``` - ---- - -## Task 5: Update create_httpx_client_with_dynamo_hooks - -**Files:** -- Modify: `src/nat/llm/dynamo_llm.py:325-355` -- Test: `tests/nat/llm/test_dynamo_prediction_hook.py` - -### Step 1: Write test for updated client creation - -Add to `tests/nat/llm/test_dynamic_prediction_hook.py`: - -```python -from nat.llm.dynamo_llm import create_httpx_client_with_dynamo_hooks - - -async def test_client_includes_prediction_hook_when_lookup_provided(sample_trie_lookup): - """Test that client includes prediction hook when trie_lookup is provided.""" - client = create_httpx_client_with_dynamo_hooks( - prefix_template="test-{uuid}", - total_requests=10, - osl="MEDIUM", - iat="LOW", - prediction_lookup=sample_trie_lookup, - ) - - # Should have 2 hooks: dynamo prefix + prediction - assert len(client.event_hooks["request"]) == 2 - - await client.aclose() - - -async def test_client_works_without_prediction_lookup(): - """Test that client works when prediction_lookup is None.""" - client = create_httpx_client_with_dynamo_hooks( - prefix_template="test-{uuid}", - total_requests=10, - osl="MEDIUM", - iat="LOW", - prediction_lookup=None, - ) - - # Should have 1 hook: dynamo prefix only - assert len(client.event_hooks["request"]) == 1 - - await client.aclose() -``` - -### Step 2: Run test to verify it fails - -Run: `pytest tests/nat/llm/test_dynamic_prediction_hook.py::test_client_includes_prediction_hook_when_lookup_provided -v` -Expected: FAIL with "unexpected keyword argument 'prediction_lookup'" - -### Step 3: Update create_httpx_client_with_dynamo_hooks - -Modify `create_httpx_client_with_dynamo_hooks` in `src/nat/llm/dynamo_llm.py`: - -```python -def create_httpx_client_with_dynamo_hooks( - prefix_template: str | None, - total_requests: int, - osl: str, - iat: str, - timeout: float = 600.0, - prediction_lookup: "PredictionTrieLookup | None" = None, -) -> "httpx.AsyncClient": - """ - Create an httpx.AsyncClient with Dynamo prefix header injection. - - This client can be passed to the OpenAI SDK to inject headers at the HTTP level, - making it framework-agnostic. - - Args: - prefix_template: Template string with {uuid} placeholder - total_requests: Expected number of requests for this prefix - osl: Output sequence length hint (LOW/MEDIUM/HIGH) - iat: Inter-arrival time hint (LOW/MEDIUM/HIGH) - timeout: HTTP request timeout in seconds - prediction_lookup: Optional PredictionTrieLookup for dynamic header injection - - Returns: - An httpx.AsyncClient configured with Dynamo header injection. - """ - import httpx - - hooks: list[Callable] = [] - - # Add Dynamo prefix hook - prefix_hook = _create_dynamo_request_hook(prefix_template, total_requests, osl, iat) - hooks.append(prefix_hook) - - # Add dynamic prediction hook if lookup provided - if prediction_lookup is not None: - prediction_hook = _create_dynamic_prediction_hook(prediction_lookup) - hooks.append(prediction_hook) - - return httpx.AsyncClient( - event_hooks={"request": hooks}, - timeout=httpx.Timeout(timeout), - ) -``` - -### Step 4: Run tests to verify they pass - -Run: `pytest tests/nat/llm/test_dynamic_prediction_hook.py -v` -Expected: PASS - -### Step 5: Commit - -```bash -git add src/nat/llm/dynamo_llm.py tests/nat/llm/test_dynamic_prediction_hook.py -git commit --signoff -m "feat(dynamo): add prediction_lookup param to client creation - -create_httpx_client_with_dynamo_hooks now accepts optional prediction_lookup -parameter. When provided, adds dynamic prediction hook to inject headers." -``` - ---- - -## Task 6: Load Trie in LangChain Dynamo Client - -**Files:** -- Modify: `packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py:202-252` -- Test: `tests/nat/plugins/langchain/test_dynamo_trie_loading.py` - -### Step 1: Write test for trie loading - -```python -# tests/nat/plugins/langchain/test_dynamo_trie_loading.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -import tempfile -from pathlib import Path - -import pytest - -from nat.llm.dynamo_llm import DynamoModelConfig -from nat.profiler.prediction_trie import save_prediction_trie -from nat.profiler.prediction_trie.data_models import LLMCallPrediction -from nat.profiler.prediction_trie.data_models import PredictionMetrics -from nat.profiler.prediction_trie.data_models import PredictionTrieNode - - -@pytest.fixture(name="trie_file") -def fixture_trie_file(): - """Create a temporary trie file.""" - prediction = LLMCallPrediction( - remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), - interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), - output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), - ) - - root = PredictionTrieNode( - name="root", - predictions_by_call_index={1: prediction}, - predictions_any_index=prediction, - ) - - with tempfile.TemporaryDirectory() as tmpdir: - path = Path(tmpdir) / "prediction_trie.json" - save_prediction_trie(root, path, workflow_name="test") - yield str(path) - - -def test_dynamo_config_with_valid_trie_path(trie_file): - """Test that DynamoModelConfig can be created with valid trie path.""" - config = DynamoModelConfig( - base_url="http://localhost:8000/v1", - model_name="test-model", - api_key="test-key", - prediction_trie_path=trie_file, - ) - - assert config.prediction_trie_path == trie_file - - -def test_dynamo_config_with_nonexistent_trie_path(): - """Test that DynamoModelConfig accepts nonexistent path (validated at load time).""" - config = DynamoModelConfig( - base_url="http://localhost:8000/v1", - model_name="test-model", - api_key="test-key", - prediction_trie_path="/nonexistent/path/trie.json", - ) - - # Config creation should succeed; error happens at runtime - assert config.prediction_trie_path == "/nonexistent/path/trie.json" -``` - -### Step 2: Run tests - -Run: `pytest tests/nat/plugins/langchain/test_dynamo_trie_loading.py -v` -Expected: PASS (config validation already exists) - -### Step 3: Update dynamo_langchain to load trie - -Modify `packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py`. Add import at top: - -```python -from pathlib import Path - -from nat.profiler.prediction_trie import load_prediction_trie -from nat.profiler.prediction_trie import PredictionTrieLookup -``` - -Then modify the `dynamo_langchain` function (around line 202-252): - -```python -@register_llm_client(config_type=DynamoModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN) -async def dynamo_langchain(llm_config: DynamoModelConfig, _builder: Builder): - """ - Create a LangChain ChatOpenAI client for Dynamo with automatic prefix header injection. - - This client injects Dynamo prefix headers at the HTTP transport level using httpx event hooks, - enabling KV cache optimization and request routing. - """ - from langchain_openai import ChatOpenAI - - # Build config dict excluding Dynamo-specific and NAT-specific fields - config_dict = llm_config.model_dump( - exclude={"type", "thinking", "api_type", *DynamoModelConfig.get_dynamo_field_names()}, - by_alias=True, - exclude_none=True, - exclude_unset=True, - ) - - # Initialize http_async_client to None for proper cleanup - http_async_client = None - - # Load prediction trie if configured - prediction_lookup: PredictionTrieLookup | None = None - if llm_config.prediction_trie_path: - try: - trie_path = Path(llm_config.prediction_trie_path) - trie = load_prediction_trie(trie_path) - prediction_lookup = PredictionTrieLookup(trie) - logger.info("Loaded prediction trie from %s", llm_config.prediction_trie_path) - except FileNotFoundError: - logger.warning("Prediction trie file not found: %s", llm_config.prediction_trie_path) - except Exception as e: - logger.warning("Failed to load prediction trie: %s", e) - - try: - # If prefix_template is set, create a custom httpx client with Dynamo hooks - if llm_config.prefix_template is not None: - http_async_client = create_httpx_client_with_dynamo_hooks( - prefix_template=llm_config.prefix_template, - total_requests=llm_config.prefix_total_requests, - osl=llm_config.prefix_osl, - iat=llm_config.prefix_iat, - timeout=llm_config.request_timeout, - prediction_lookup=prediction_lookup, - ) - config_dict["http_async_client"] = http_async_client - logger.info( - "Dynamo prefix headers enabled: template=%s, total_requests=%d, osl=%s, iat=%s, prediction_trie=%s", - llm_config.prefix_template, - llm_config.prefix_total_requests, - llm_config.prefix_osl, - llm_config.prefix_iat, - "loaded" if prediction_lookup else "disabled", - ) - - # Create the ChatOpenAI client - if llm_config.api_type == APITypeEnum.RESPONSES: - client = ChatOpenAI(stream_usage=True, use_responses_api=True, use_previous_response_id=True, **config_dict) - else: - client = ChatOpenAI(stream_usage=True, **config_dict) - - yield _patch_llm_based_on_config(client, llm_config) - finally: - # Ensure the httpx client is properly closed to avoid resource leaks - if http_async_client is not None: - await http_async_client.aclose() -``` - -### Step 4: Run existing tests to ensure no regressions - -Run: `pytest tests/nat/plugins/langchain/ -v -k dynamo` -Expected: PASS - -### Step 5: Commit - -```bash -git add packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py tests/nat/plugins/langchain/test_dynamo_trie_loading.py -git commit --signoff -m "feat(langchain): load prediction trie in dynamo_langchain - -Loads prediction trie from prediction_trie_path config and passes -PredictionTrieLookup to httpx client for dynamic header injection." -``` - ---- - -## Task 7: End-to-End Integration Test - -**Files:** -- Test: `tests/nat/llm/test_runtime_prediction_e2e.py` - -### Step 1: Write end-to-end integration test - -```python -# tests/nat/llm/test_runtime_prediction_e2e.py -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -"""End-to-end test for runtime prediction trie integration.""" - -import tempfile -from pathlib import Path - -import pytest - -from nat.builder.context import Context -from nat.data_models.intermediate_step import IntermediateStepPayload -from nat.data_models.intermediate_step import IntermediateStepType -from nat.llm.dynamo_llm import _create_dynamic_prediction_hook -from nat.llm.prediction_context import get_call_tracker -from nat.profiler.prediction_trie import load_prediction_trie -from nat.profiler.prediction_trie import PredictionTrieLookup -from nat.profiler.prediction_trie import save_prediction_trie -from nat.profiler.prediction_trie.data_models import LLMCallPrediction -from nat.profiler.prediction_trie.data_models import PredictionMetrics -from nat.profiler.prediction_trie.data_models import PredictionTrieNode - - -class MockRequest: - """Mock httpx.Request for testing.""" - - def __init__(self): - self.headers = {} - - -def create_test_trie() -> PredictionTrieNode: - """Create a test trie with known predictions.""" - # Agent at call 1: 2 remaining, 500ms interarrival, 150 tokens - call_1_prediction = LLMCallPrediction( - remaining_calls=PredictionMetrics(sample_count=10, mean=2.0, p50=2.0, p90=3.0, p95=4.0), - interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), - output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), - ) - - # Agent at call 2: 1 remaining, 300ms interarrival, 100 tokens - call_2_prediction = LLMCallPrediction( - remaining_calls=PredictionMetrics(sample_count=10, mean=1.0, p50=1.0, p90=2.0, p95=2.0), - interarrival_ms=PredictionMetrics(sample_count=10, mean=300.0, p50=280.0, p90=400.0, p95=450.0), - output_tokens=PredictionMetrics(sample_count=10, mean=100.0, p50=90.0, p90=150.0, p95=180.0), - ) - - # Agent at call 3: 0 remaining - call_3_prediction = LLMCallPrediction( - remaining_calls=PredictionMetrics(sample_count=10, mean=0.0, p50=0.0, p90=0.0, p95=0.0), - interarrival_ms=PredictionMetrics(sample_count=10, mean=0.0, p50=0.0, p90=0.0, p95=0.0), - output_tokens=PredictionMetrics(sample_count=10, mean=80.0, p50=75.0, p90=120.0, p95=140.0), - ) - - # Aggregated for fallback - aggregated = LLMCallPrediction( - remaining_calls=PredictionMetrics(sample_count=30, mean=1.0, p50=1.0, p90=2.0, p95=3.0), - interarrival_ms=PredictionMetrics(sample_count=30, mean=400.0, p50=380.0, p90=550.0, p95=600.0), - output_tokens=PredictionMetrics(sample_count=30, mean=110.0, p50=100.0, p90=160.0, p95=190.0), - ) - - agent_node = PredictionTrieNode( - name="react_agent", - predictions_by_call_index={1: call_1_prediction, 2: call_2_prediction, 3: call_3_prediction}, - predictions_any_index=aggregated, - ) - - workflow_node = PredictionTrieNode( - name="my_workflow", - children={"react_agent": agent_node}, - predictions_any_index=aggregated, - ) - - return PredictionTrieNode( - name="root", - children={"my_workflow": workflow_node}, - predictions_any_index=aggregated, - ) - - -async def test_e2e_prediction_headers_injected_correctly(): - """Test complete flow: context tracking -> step manager -> hook -> headers.""" - # Create and save trie - trie = create_test_trie() - - with tempfile.TemporaryDirectory() as tmpdir: - trie_path = Path(tmpdir) / "prediction_trie.json" - save_prediction_trie(trie, trie_path, workflow_name="test") - - # Load trie - loaded_trie = load_prediction_trie(trie_path) - lookup = PredictionTrieLookup(loaded_trie) - - # Create hook - hook = _create_dynamic_prediction_hook(lookup) - - ctx = Context.get() - state = ctx._context_state - step_manager = ctx.intermediate_step_manager - - # Reset state - state._function_path_stack.set(None) - - with ctx.push_active_function("my_workflow", input_data=None): - with ctx.push_active_function("react_agent", input_data=None): - # Simulate first LLM call - step_manager.push_intermediate_step( - IntermediateStepPayload( - UUID="llm-1", - event_type=IntermediateStepType.LLM_START, - name="test-model", - ) - ) - - request1 = MockRequest() - await hook(request1) - - # Should have call 1 predictions: 2 remaining - assert request1.headers["x-nat-remaining-llm-calls"] == "2" - assert request1.headers["x-nat-interarrival-ms"] == "500" - assert request1.headers["x-nat-expected-output-tokens"] == "200" - - # Simulate second LLM call - step_manager.push_intermediate_step( - IntermediateStepPayload( - UUID="llm-2", - event_type=IntermediateStepType.LLM_START, - name="test-model", - ) - ) - - request2 = MockRequest() - await hook(request2) - - # Should have call 2 predictions: 1 remaining - assert request2.headers["x-nat-remaining-llm-calls"] == "1" - assert request2.headers["x-nat-interarrival-ms"] == "300" - assert request2.headers["x-nat-expected-output-tokens"] == "150" - - # Simulate third LLM call - step_manager.push_intermediate_step( - IntermediateStepPayload( - UUID="llm-3", - event_type=IntermediateStepType.LLM_START, - name="test-model", - ) - ) - - request3 = MockRequest() - await hook(request3) - - # Should have call 3 predictions: 0 remaining - assert request3.headers["x-nat-remaining-llm-calls"] == "0" - assert request3.headers["x-nat-expected-output-tokens"] == "120" - - -async def test_e2e_fallback_to_root(): - """Test that unknown paths fall back to root predictions.""" - trie = create_test_trie() - lookup = PredictionTrieLookup(trie) - hook = _create_dynamic_prediction_hook(lookup) - - ctx = Context.get() - state = ctx._context_state - step_manager = ctx.intermediate_step_manager - - # Reset state - state._function_path_stack.set(None) - - with ctx.push_active_function("unknown_workflow", input_data=None): - step_manager.push_intermediate_step( - IntermediateStepPayload( - UUID="llm-unknown", - event_type=IntermediateStepType.LLM_START, - name="test-model", - ) - ) - - request = MockRequest() - await hook(request) - - # Should fall back to root aggregated predictions - assert "x-nat-remaining-llm-calls" in request.headers - assert request.headers["x-nat-remaining-llm-calls"] == "1" # aggregated mean -``` - -### Step 2: Run e2e test - -Run: `pytest tests/nat/llm/test_runtime_prediction_e2e.py -v` -Expected: PASS - -### Step 3: Commit - -```bash -git add tests/nat/llm/test_runtime_prediction_e2e.py -git commit --signoff -m "test: add end-to-end runtime prediction trie test - -Validates complete flow: function path tracking -> call tracker increment --> dynamic hook lookup -> correct headers injected for each call index." -``` - ---- - -## Summary - -This plan implements runtime prediction trie integration in 7 tasks: - -1. **Function Path Stack** - Add ContextVar to ContextState -2. **Path Tracking** - Update push_active_function to track path -3. **Call Tracker Integration** - Increment tracker in IntermediateStepManager on LLM_START -4. **Dynamic Hook** - Create hook that reads context and looks up predictions -5. **Client Update** - Add prediction_lookup param to client creation -6. **LangChain Integration** - Load trie in dynamo_langchain -7. **E2E Test** - Validate complete flow - -Each task follows TDD: write failing test, implement, verify, commit. From 4cbd4e6a16c6dc5282c9fa8da414cc4403607ac0 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 13:30:32 -0800 Subject: [PATCH 30/37] Refactor DynamoPrefixContext for depth-aware prefix handling Introduced depth-aware prefix ID generation for more granular control of prefix IDs across nested function calls. Replaced the previous context variable approach with a depth mapping mechanism and added support for override prefixes. Updated relevant tests for clarity and alignment with the new depth-based behavior. Signed-off-by: dnandakumar-nv --- .../nat/llm/test_dynamo_prediction_headers.py | 16 +++++++---- tests/nat/llm/test_runtime_prediction_e2e.py | 28 +++++++++---------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/tests/nat/llm/test_dynamo_prediction_headers.py b/tests/nat/llm/test_dynamo_prediction_headers.py index 97b413424b..e4b805f4ea 100644 --- a/tests/nat/llm/test_dynamo_prediction_headers.py +++ b/tests/nat/llm/test_dynamo_prediction_headers.py @@ -37,11 +37,15 @@ async def capture_hook(request): except Exception: pass - assert "x-nat-remaining-llm-calls" in captured_headers - assert captured_headers["x-nat-remaining-llm-calls"] == "3" - assert "x-nat-interarrival-ms" in captured_headers - assert captured_headers["x-nat-interarrival-ms"] == "500" - assert "x-nat-expected-output-tokens" in captured_headers - assert captured_headers["x-nat-expected-output-tokens"] == "200" # p90 value + # Prediction hook overrides x-prefix-* headers with prediction-derived values + # remaining_calls.mean=3.0 → x-prefix-total-requests="3" + assert "x-prefix-total-requests" in captured_headers + assert captured_headers["x-prefix-total-requests"] == "3" + # output_tokens.p90=200.0 (< 256) → x-prefix-osl="LOW" + assert "x-prefix-osl" in captured_headers + assert captured_headers["x-prefix-osl"] == "LOW" + # interarrival_ms.mean=500.0 (>= 500) → x-prefix-iat="HIGH" + assert "x-prefix-iat" in captured_headers + assert captured_headers["x-prefix-iat"] == "HIGH" await client.aclose() diff --git a/tests/nat/llm/test_runtime_prediction_e2e.py b/tests/nat/llm/test_runtime_prediction_e2e.py index 7e5fae54f3..66c89cd7e4 100644 --- a/tests/nat/llm/test_runtime_prediction_e2e.py +++ b/tests/nat/llm/test_runtime_prediction_e2e.py @@ -118,10 +118,10 @@ async def test_e2e_prediction_headers_injected_correctly(): request1 = MockRequest() await hook(request1) - # Should have call 1 predictions: 2 remaining - assert request1.headers["x-nat-remaining-llm-calls"] == "2" - assert request1.headers["x-nat-interarrival-ms"] == "500" - assert request1.headers["x-nat-expected-output-tokens"] == "200" + # Should have call 1 predictions: remaining_calls.mean=2.0, output_tokens.p90=200 + assert request1.headers["x-prefix-total-requests"] == "2" + assert request1.headers["x-prefix-osl"] == "LOW" # 200 tokens < 256 + assert request1.headers["x-prefix-iat"] == "HIGH" # 500ms >= 500 # Simulate second LLM call step_manager.push_intermediate_step( @@ -134,10 +134,10 @@ async def test_e2e_prediction_headers_injected_correctly(): request2 = MockRequest() await hook(request2) - # Should have call 2 predictions: 1 remaining - assert request2.headers["x-nat-remaining-llm-calls"] == "1" - assert request2.headers["x-nat-interarrival-ms"] == "300" - assert request2.headers["x-nat-expected-output-tokens"] == "150" + # Should have call 2 predictions: remaining_calls.mean=1.0, output_tokens.p90=150 + assert request2.headers["x-prefix-total-requests"] == "1" + assert request2.headers["x-prefix-osl"] == "LOW" # 150 tokens < 256 + assert request2.headers["x-prefix-iat"] == "MEDIUM" # 300ms is 100-500 # Simulate third LLM call step_manager.push_intermediate_step( @@ -150,9 +150,9 @@ async def test_e2e_prediction_headers_injected_correctly(): request3 = MockRequest() await hook(request3) - # Should have call 3 predictions: 0 remaining - assert request3.headers["x-nat-remaining-llm-calls"] == "0" - assert request3.headers["x-nat-expected-output-tokens"] == "120" + # Should have call 3 predictions: remaining_calls.mean=0.0, output_tokens.p90=120 + assert request3.headers["x-prefix-total-requests"] == "0" + assert request3.headers["x-prefix-osl"] == "LOW" # 120 tokens < 256 async def test_e2e_fallback_to_root(): @@ -179,6 +179,6 @@ async def test_e2e_fallback_to_root(): request = MockRequest() await hook(request) - # Should fall back to root aggregated predictions - assert "x-nat-remaining-llm-calls" in request.headers - assert request.headers["x-nat-remaining-llm-calls"] == "1" # aggregated mean + # Should fall back to root aggregated predictions (remaining_calls.mean=1.0, output_tokens.p90=160 + assert "x-prefix-total-requests" in request.headers + assert request.headers["x-prefix-total-requests"] == "1" # aggregated mean From fa79197d7543ea0281bf52818cfeaeb86afaad7f Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Sat, 24 Jan 2026 13:34:36 -0800 Subject: [PATCH 31/37] Remove DynamoPrefixContext handling in Runner class Eliminated the setup and cleanup of DynamoPrefixContext from the Runner class as it is no longer required for KV cache optimization. This simplifies the workflow logic and reduces dependencies, ensuring cleaner and more maintainable code. Signed-off-by: dnandakumar-nv --- src/nat/runtime/runner.py | 18 - .../nat/runtime/test_runner_dynamo_prefix.py | 358 ------------------ 2 files changed, 376 deletions(-) delete mode 100644 tests/nat/runtime/test_runner_dynamo_prefix.py diff --git a/src/nat/runtime/runner.py b/src/nat/runtime/runner.py index fa6f4ac8a5..85ea726dcd 100644 --- a/src/nat/runtime/runner.py +++ b/src/nat/runtime/runner.py @@ -161,12 +161,6 @@ async def result(self, to_type: type | None = None): token_run_id = self._context_state.workflow_run_id.set(workflow_run_id) token_trace_id = self._context_state.workflow_trace_id.set(workflow_trace_id) - # Set Dynamo prefix context for KV cache optimization - # Each workflow invocation gets a unique prefix ID based on the run ID - # Lazy import to avoid circular dependency - from nat.llm.dynamo_llm import DynamoPrefixContext - DynamoPrefixContext.set(f"nat-workflow-{workflow_run_id}") - # Prepare workflow-level intermediate step identifiers workflow_step_uuid = str(uuid.uuid4()) workflow_name = getattr(self._entry_fn, 'instance_name', None) or "workflow" @@ -217,9 +211,6 @@ async def result(self, to_type: type | None = None): self._state = RunnerState.FAILED raise finally: - # Lazy import to avoid circular dependency - from nat.llm.dynamo_llm import DynamoPrefixContext - DynamoPrefixContext.clear() if token_run_id is not None: self._context_state.workflow_run_id.reset(token_run_id) if token_trace_id is not None: @@ -249,12 +240,6 @@ async def result_stream(self, to_type: type | None = None): token_run_id = self._context_state.workflow_run_id.set(workflow_run_id) token_trace_id = self._context_state.workflow_trace_id.set(workflow_trace_id) - # Set Dynamo prefix context for KV cache optimization - # Each workflow invocation gets a unique prefix ID based on the run ID - # Lazy import to avoid circular dependency - from nat.llm.dynamo_llm import DynamoPrefixContext - DynamoPrefixContext.set(f"nat-workflow-{workflow_run_id}") - # Prepare workflow-level intermediate step identifiers workflow_step_uuid = str(uuid.uuid4()) workflow_name = getattr(self._entry_fn, 'instance_name', None) or "workflow" @@ -311,9 +296,6 @@ async def result_stream(self, to_type: type | None = None): self._state = RunnerState.FAILED raise finally: - # Lazy import to avoid circular dependency - from nat.llm.dynamo_llm import DynamoPrefixContext - DynamoPrefixContext.clear() if token_run_id is not None: self._context_state.workflow_run_id.reset(token_run_id) if token_trace_id is not None: diff --git a/tests/nat/runtime/test_runner_dynamo_prefix.py b/tests/nat/runtime/test_runner_dynamo_prefix.py deleted file mode 100644 index a0c61ca147..0000000000 --- a/tests/nat/runtime/test_runner_dynamo_prefix.py +++ /dev/null @@ -1,358 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for DynamoPrefixContext integration in the Runner class. - -These tests verify that the Runner properly sets and clears the DynamoPrefixContext -for KV cache optimization when using Dynamo LLM backends. -""" - -from collections.abc import AsyncGenerator - -import pytest - -from nat.builder.builder import Builder -from nat.builder.context import ContextState -from nat.builder.workflow_builder import WorkflowBuilder -from nat.cli.register_workflow import register_function -from nat.data_models.function import FunctionBaseConfig -from nat.llm.dynamo_llm import DynamoPrefixContext -from nat.observability.exporter_manager import ExporterManager -from nat.runtime.runner import Runner - - -class SingleOutputConfig(FunctionBaseConfig, name="single_output_dynamo_test"): - pass - - -class StreamOutputConfig(FunctionBaseConfig, name="stream_output_dynamo_test"): - pass - - -class CaptureConfig(FunctionBaseConfig, name="capture_dynamo_prefix"): - pass - - -@pytest.fixture(scope="module", autouse=True) -async def _register_single_output_fn(): - - @register_function(config_type=SingleOutputConfig) - async def register(config: SingleOutputConfig, b: Builder): - - async def _inner(message: str) -> str: - return message + "!" - - yield _inner - - -@pytest.fixture(scope="module", autouse=True) -async def _register_stream_output_fn(): - - @register_function(config_type=StreamOutputConfig) - async def register(config: StreamOutputConfig, b: Builder): - - async def _inner_stream(message: str) -> AsyncGenerator[str]: - yield message + "!" - - yield _inner_stream - - -@pytest.fixture(autouse=True) -def clean_dynamo_context(): - """Ensure DynamoPrefixContext is clean before and after each test.""" - DynamoPrefixContext.clear() - yield - DynamoPrefixContext.clear() - - -async def test_runner_result_sets_dynamo_prefix_context(): - """Test that Runner.result() sets DynamoPrefixContext with unique prefix ID.""" - captured_prefix_ids = [] - - @register_function(config_type=CaptureConfig) - async def _register(config: CaptureConfig, b: Builder): - - async def _capture(message: str) -> str: - # Capture the prefix ID during execution - prefix_id = DynamoPrefixContext.get() - captured_prefix_ids.append(prefix_id) - return message - - yield _capture - - async with WorkflowBuilder() as builder: - entry_fn = await builder.add_function(name="capture_fn", config=CaptureConfig()) - - context_state = ContextState() - exporter_manager = ExporterManager() - - async with Runner(input_message="test", - entry_fn=entry_fn, - context_state=context_state, - exporter_manager=exporter_manager) as runner: - await runner.result() - - # Verify prefix ID was set during execution - assert len(captured_prefix_ids) == 1 - assert captured_prefix_ids[0] is not None - assert captured_prefix_ids[0].startswith("nat-workflow-") - - -async def test_runner_result_clears_dynamo_prefix_context_after_completion(): - """Test that Runner.result() clears DynamoPrefixContext after workflow completes.""" - async with WorkflowBuilder() as builder: - entry_fn = await builder.add_function(name="test_fn", config=SingleOutputConfig()) - - context_state = ContextState() - exporter_manager = ExporterManager() - - async with Runner(input_message="test", - entry_fn=entry_fn, - context_state=context_state, - exporter_manager=exporter_manager) as runner: - await runner.result() - - # Verify prefix ID is cleared after execution - assert DynamoPrefixContext.get() is None - - -async def test_runner_result_clears_dynamo_prefix_context_on_error(): - """Test that Runner.result() clears DynamoPrefixContext even when workflow fails.""" - - class ErrorConfig(FunctionBaseConfig, name="error_dynamo_test"): - pass - - @register_function(config_type=ErrorConfig) - async def _register(config: ErrorConfig, b: Builder): - - async def _error(message: str) -> str: - raise ValueError("Simulated error") - - yield _error - - async with WorkflowBuilder() as builder: - entry_fn = await builder.add_function(name="error_fn", config=ErrorConfig()) - - context_state = ContextState() - exporter_manager = ExporterManager() - - async with Runner(input_message="test", - entry_fn=entry_fn, - context_state=context_state, - exporter_manager=exporter_manager) as runner: - with pytest.raises(ValueError, match="Simulated error"): - await runner.result() - - # Verify prefix ID is cleared even after error - assert DynamoPrefixContext.get() is None - - -async def test_runner_result_different_invocations_get_unique_prefix_ids(): - """Test that different workflow invocations get unique prefix IDs.""" - captured_prefix_ids = [] - - class CaptureConfig2(FunctionBaseConfig, name="capture_dynamo_prefix2"): - pass - - @register_function(config_type=CaptureConfig2) - async def _register(config: CaptureConfig2, b: Builder): - - async def _capture(message: str) -> str: - prefix_id = DynamoPrefixContext.get() - captured_prefix_ids.append(prefix_id) - return message - - yield _capture - - async with WorkflowBuilder() as builder: - entry_fn = await builder.add_function(name="capture_fn", config=CaptureConfig2()) - - context_state = ContextState() - exporter_manager = ExporterManager() - - # Run workflow multiple times - for i in range(3): - async with Runner(input_message=f"test{i}", - entry_fn=entry_fn, - context_state=context_state, - exporter_manager=exporter_manager) as runner: - await runner.result() - - # Each invocation should have a unique prefix ID - assert len(captured_prefix_ids) == 3 - assert len(set(captured_prefix_ids)) == 3 # All unique - - -async def test_runner_result_stream_sets_dynamo_prefix_context(): - """Test that Runner.result_stream() sets DynamoPrefixContext with unique prefix ID.""" - captured_prefix_ids = [] - - class StreamCaptureConfig(FunctionBaseConfig, name="stream_capture_dynamo"): - pass - - @register_function(config_type=StreamCaptureConfig) - async def _register(config: StreamCaptureConfig, b: Builder): - - async def _capture_stream(message: str) -> AsyncGenerator[str]: - prefix_id = DynamoPrefixContext.get() - captured_prefix_ids.append(prefix_id) - yield message - - yield _capture_stream - - async with WorkflowBuilder() as builder: - entry_fn = await builder.add_function(name="stream_capture_fn", config=StreamCaptureConfig()) - - context_state = ContextState() - exporter_manager = ExporterManager() - - async with Runner(input_message="test", - entry_fn=entry_fn, - context_state=context_state, - exporter_manager=exporter_manager) as runner: - async for _ in runner.result_stream(): - pass - - # Verify prefix ID was set during execution - assert len(captured_prefix_ids) == 1 - assert captured_prefix_ids[0] is not None - assert captured_prefix_ids[0].startswith("nat-workflow-") - - -async def test_runner_result_stream_clears_dynamo_prefix_context_after_completion(): - """Test that Runner.result_stream() clears DynamoPrefixContext after workflow completes.""" - async with WorkflowBuilder() as builder: - entry_fn = await builder.add_function(name="stream_fn", config=StreamOutputConfig()) - - context_state = ContextState() - exporter_manager = ExporterManager() - - async with Runner(input_message="test", - entry_fn=entry_fn, - context_state=context_state, - exporter_manager=exporter_manager) as runner: - async for _ in runner.result_stream(): - pass - - # Verify prefix ID is cleared after execution - assert DynamoPrefixContext.get() is None - - -async def test_runner_result_stream_clears_dynamo_prefix_context_on_error(): - """Test that Runner.result_stream() clears DynamoPrefixContext even when workflow fails.""" - - class StreamErrorConfig(FunctionBaseConfig, name="stream_error_dynamo_test"): - pass - - @register_function(config_type=StreamErrorConfig) - async def _register(config: StreamErrorConfig, b: Builder): - - async def _error_stream(message: str) -> AsyncGenerator[str]: - raise ValueError("Simulated stream error") - yield message # Make it a generator - - yield _error_stream - - async with WorkflowBuilder() as builder: - entry_fn = await builder.add_function(name="stream_error_fn", config=StreamErrorConfig()) - - context_state = ContextState() - exporter_manager = ExporterManager() - - async with Runner(input_message="test", - entry_fn=entry_fn, - context_state=context_state, - exporter_manager=exporter_manager) as runner: - with pytest.raises(ValueError, match="Simulated stream error"): - async for _ in runner.result_stream(): - pass - - # Verify prefix ID is cleared even after error - assert DynamoPrefixContext.get() is None - - -async def test_runner_prefix_id_based_on_workflow_run_id(): - """Test that the prefix ID is based on the workflow_run_id.""" - captured_prefix_id = None - - class PrefixCheckConfig(FunctionBaseConfig, name="prefix_check_dynamo"): - pass - - @register_function(config_type=PrefixCheckConfig) - async def _register(config: PrefixCheckConfig, b: Builder): - - async def _check(message: str) -> str: - nonlocal captured_prefix_id - captured_prefix_id = DynamoPrefixContext.get() - return message - - yield _check - - async with WorkflowBuilder() as builder: - entry_fn = await builder.add_function(name="prefix_check_fn", config=PrefixCheckConfig()) - - context_state = ContextState() - exporter_manager = ExporterManager() - - async with Runner(input_message="test", - entry_fn=entry_fn, - context_state=context_state, - exporter_manager=exporter_manager) as runner: - await runner.result() - - # The prefix ID should be in the expected format - assert captured_prefix_id is not None - assert captured_prefix_id.startswith("nat-workflow-") - # Verify the UUID portion is valid (36 chars with hyphens) - uuid_part = captured_prefix_id[len("nat-workflow-"):] - assert len(uuid_part) == 36 - - -async def test_runner_uses_existing_workflow_run_id_for_prefix(): - """Test that Runner uses existing workflow_run_id when set externally.""" - captured_prefix_id = None - preset_run_id = "preset-external-run-id-12345" - - class ExternalIdConfig(FunctionBaseConfig, name="external_id_dynamo"): - pass - - @register_function(config_type=ExternalIdConfig) - async def _register(config: ExternalIdConfig, b: Builder): - - async def _check(message: str) -> str: - nonlocal captured_prefix_id - captured_prefix_id = DynamoPrefixContext.get() - return message - - yield _check - - async with WorkflowBuilder() as builder: - entry_fn = await builder.add_function(name="external_id_fn", config=ExternalIdConfig()) - - context_state = ContextState() - exporter_manager = ExporterManager() - - # Pre-set the workflow_run_id - token = context_state.workflow_run_id.set(preset_run_id) - try: - async with Runner(input_message="test", - entry_fn=entry_fn, - context_state=context_state, - exporter_manager=exporter_manager) as runner: - await runner.result() - finally: - context_state.workflow_run_id.reset(token) - - # The prefix ID should use the pre-set workflow_run_id - assert captured_prefix_id == f"nat-workflow-{preset_run_id}" From 2cba23cfbc12c1d5f4d2763067388ad709c6eba1 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Mon, 26 Jan 2026 11:33:25 -0800 Subject: [PATCH 32/37] Add Apache 2.0 license headers to source and test files Updated files to include full Apache 2.0 license text, ensuring clarity on usage and distribution under the license. This change ensures compliance with legal requirements and improves consistency across the repository. Signed-off-by: dnandakumar-nv --- .../README_PREDICTION_TRIE.md | 17 +++++++++++++++++ src/nat/llm/dynamo_llm.py | 2 +- src/nat/llm/prediction_context.py | 14 +++++++++++++- .../prediction_trie/metrics_accumulator.py | 12 ++++++++++++ .../profiler/prediction_trie/trie_builder.py | 12 ++++++++++++ .../builder/test_call_tracker_integration.py | 12 ++++++++++++ tests/nat/builder/test_function_path_stack.py | 12 ++++++++++++ tests/nat/llm/test_dynamic_prediction_hook.py | 12 ++++++++++++ tests/nat/llm/test_dynamo_prediction_headers.py | 12 ++++++++++++ tests/nat/llm/test_dynamo_prediction_trie.py | 12 ++++++++++++ tests/nat/llm/test_prediction_context.py | 12 ++++++++++++ tests/nat/llm/test_runtime_prediction_e2e.py | 13 +++++++++++++ .../langchain/test_dynamo_trie_loading.py | 12 ++++++++++++ .../prediction_trie/test_metrics_accumulator.py | 12 ++++++++++++ .../prediction_trie/test_serialization.py | 12 ++++++++++++ .../prediction_trie/test_trie_builder.py | 12 ++++++++++++ .../prediction_trie/test_trie_lookup.py | 12 ++++++++++++ 17 files changed, 200 insertions(+), 2 deletions(-) diff --git a/examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md b/examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md index 332b3cf8c5..bd8174c01b 100644 --- a/examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md +++ b/examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md @@ -1,3 +1,20 @@ + + + # Prediction Trie Optimization for Dynamo Use profiled execution data to inject accurate per-call prediction headers instead of static guesses. diff --git a/src/nat/llm/dynamo_llm.py b/src/nat/llm/dynamo_llm.py index f96eb6b314..a55b3e87be 100644 --- a/src/nat/llm/dynamo_llm.py +++ b/src/nat/llm/dynamo_llm.py @@ -72,8 +72,8 @@ from nat.data_models.optimizable import OptimizableField from nat.data_models.optimizable import SearchSpace from nat.llm.openai_llm import OpenAIModelConfig -from nat.profiler.prediction_trie.data_models import LLMCallPrediction from nat.llm.utils.constants import LLMHeaderPrefix +from nat.profiler.prediction_trie.data_models import LLMCallPrediction logger = logging.getLogger(__name__) diff --git a/src/nat/llm/prediction_context.py b/src/nat/llm/prediction_context.py index dd4757dba0..91184baeda 100644 --- a/src/nat/llm/prediction_context.py +++ b/src/nat/llm/prediction_context.py @@ -1,5 +1,17 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Runtime context management for prediction trie lookups. diff --git a/src/nat/profiler/prediction_trie/metrics_accumulator.py b/src/nat/profiler/prediction_trie/metrics_accumulator.py index 19ba67ead0..e2313b17c3 100644 --- a/src/nat/profiler/prediction_trie/metrics_accumulator.py +++ b/src/nat/profiler/prediction_trie/metrics_accumulator.py @@ -1,5 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import math diff --git a/src/nat/profiler/prediction_trie/trie_builder.py b/src/nat/profiler/prediction_trie/trie_builder.py index 836cb3cbc3..3b2f285a53 100644 --- a/src/nat/profiler/prediction_trie/trie_builder.py +++ b/src/nat/profiler/prediction_trie/trie_builder.py @@ -1,5 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from __future__ import annotations diff --git a/tests/nat/builder/test_call_tracker_integration.py b/tests/nat/builder/test_call_tracker_integration.py index dc1695f91b..8b3d74c27d 100644 --- a/tests/nat/builder/test_call_tracker_integration.py +++ b/tests/nat/builder/test_call_tracker_integration.py @@ -1,5 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from nat.builder.context import Context from nat.data_models.intermediate_step import IntermediateStepPayload diff --git a/tests/nat/builder/test_function_path_stack.py b/tests/nat/builder/test_function_path_stack.py index da65a7fa26..508cf27f15 100644 --- a/tests/nat/builder/test_function_path_stack.py +++ b/tests/nat/builder/test_function_path_stack.py @@ -1,5 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from nat.builder.context import Context from nat.builder.context import ContextState diff --git a/tests/nat/llm/test_dynamic_prediction_hook.py b/tests/nat/llm/test_dynamic_prediction_hook.py index 1bab5ab641..a07f8bb110 100644 --- a/tests/nat/llm/test_dynamic_prediction_hook.py +++ b/tests/nat/llm/test_dynamic_prediction_hook.py @@ -1,5 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest diff --git a/tests/nat/llm/test_dynamo_prediction_headers.py b/tests/nat/llm/test_dynamo_prediction_headers.py index e4b805f4ea..9f581a3181 100644 --- a/tests/nat/llm/test_dynamo_prediction_headers.py +++ b/tests/nat/llm/test_dynamo_prediction_headers.py @@ -1,5 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from nat.llm.dynamo_llm import create_httpx_client_with_prediction_headers from nat.profiler.prediction_trie.data_models import LLMCallPrediction diff --git a/tests/nat/llm/test_dynamo_prediction_trie.py b/tests/nat/llm/test_dynamo_prediction_trie.py index 517013bedc..043955192f 100644 --- a/tests/nat/llm/test_dynamo_prediction_trie.py +++ b/tests/nat/llm/test_dynamo_prediction_trie.py @@ -1,5 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import tempfile from pathlib import Path diff --git a/tests/nat/llm/test_prediction_context.py b/tests/nat/llm/test_prediction_context.py index b0132afc35..149bbca26d 100644 --- a/tests/nat/llm/test_prediction_context.py +++ b/tests/nat/llm/test_prediction_context.py @@ -1,5 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from nat.llm.prediction_context import LLMCallTracker from nat.llm.prediction_context import get_call_tracker diff --git a/tests/nat/llm/test_runtime_prediction_e2e.py b/tests/nat/llm/test_runtime_prediction_e2e.py index 66c89cd7e4..7928ba54e2 100644 --- a/tests/nat/llm/test_runtime_prediction_e2e.py +++ b/tests/nat/llm/test_runtime_prediction_e2e.py @@ -1,5 +1,18 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """End-to-end test for runtime prediction trie integration. This test validates that all pieces work together: diff --git a/tests/nat/plugins/langchain/test_dynamo_trie_loading.py b/tests/nat/plugins/langchain/test_dynamo_trie_loading.py index 2156f68d51..95ca687671 100644 --- a/tests/nat/plugins/langchain/test_dynamo_trie_loading.py +++ b/tests/nat/plugins/langchain/test_dynamo_trie_loading.py @@ -1,5 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import tempfile from pathlib import Path diff --git a/tests/nat/profiler/prediction_trie/test_metrics_accumulator.py b/tests/nat/profiler/prediction_trie/test_metrics_accumulator.py index 5d329db61d..46eb4ddd5c 100644 --- a/tests/nat/profiler/prediction_trie/test_metrics_accumulator.py +++ b/tests/nat/profiler/prediction_trie/test_metrics_accumulator.py @@ -1,5 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest diff --git a/tests/nat/profiler/prediction_trie/test_serialization.py b/tests/nat/profiler/prediction_trie/test_serialization.py index a617d5d48d..289f15d2d2 100644 --- a/tests/nat/profiler/prediction_trie/test_serialization.py +++ b/tests/nat/profiler/prediction_trie/test_serialization.py @@ -1,5 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import json import tempfile diff --git a/tests/nat/profiler/prediction_trie/test_trie_builder.py b/tests/nat/profiler/prediction_trie/test_trie_builder.py index 6b964b835f..e68cf2eb78 100644 --- a/tests/nat/profiler/prediction_trie/test_trie_builder.py +++ b/tests/nat/profiler/prediction_trie/test_trie_builder.py @@ -1,5 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest diff --git a/tests/nat/profiler/prediction_trie/test_trie_lookup.py b/tests/nat/profiler/prediction_trie/test_trie_lookup.py index 2fde204e5d..58e07aae89 100644 --- a/tests/nat/profiler/prediction_trie/test_trie_lookup.py +++ b/tests/nat/profiler/prediction_trie/test_trie_lookup.py @@ -1,5 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest From 7d2c087b579a842f0b58b573163d6d7d81104259 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Mon, 26 Jan 2026 12:15:38 -0800 Subject: [PATCH 33/37] Add "Trie(s)" to accepted vocabulary list This update includes "Trie(s)" in the NAT vocabulary file for Vale. It ensures that the term is recognized as valid during linting. Signed-off-by: dnandakumar-nv --- ci/vale/styles/config/vocabularies/nat/accept.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/ci/vale/styles/config/vocabularies/nat/accept.txt b/ci/vale/styles/config/vocabularies/nat/accept.txt index 1de8c93b2d..05c9bb5c95 100644 --- a/ci/vale/styles/config/vocabularies/nat/accept.txt +++ b/ci/vale/styles/config/vocabularies/nat/accept.txt @@ -179,6 +179,7 @@ Tavily [Tt]imestamp(s?) [Tt]okenization [Tt]okenizer(s?) +[Tt]rie(s?) triages [Uu]ncomment(ed)? [Uu]nencrypted From 727f5645f1f293f9c791445122d35810d5204185 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Mon, 26 Jan 2026 12:26:26 -0800 Subject: [PATCH 34/37] Update README and test files for clarity and consistency Updated headers and descriptions in the README to improve terminology clarity. Removed an unnecessary blank line in the test file for better formatting. Signed-off-by: dnandakumar-nv --- .../react_benchmark_agent/README_PREDICTION_TRIE.md | 6 +++--- tests/nat/llm/test_runtime_prediction_e2e.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md b/examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md index bd8174c01b..8469efcad9 100644 --- a/examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md +++ b/examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md @@ -62,7 +62,7 @@ The profiler collects data for each LLM call: - Function path at time of call - Call index within the parent function - Output tokens generated -- Time until the next LLM call (interarrival) +- Time until the next LLM call - Remaining LLM calls in the workflow This data is aggregated into a trie structure with statistical summaries (mean, p50, p90, etc.) at each node. @@ -114,7 +114,7 @@ To measure the impact of prediction trie vs static headers: ## Configuration Reference -### Profiler Config (Phase 1) +### Profiler Configuration (Phase 1) Enable trie building in the profiler section: @@ -125,7 +125,7 @@ profiler: output_filename: prediction_trie.json # default ``` -### LLM Config (Phase 2) +### LLM Configuration (Phase 2) Add the trie path to your Dynamo LLM config: diff --git a/tests/nat/llm/test_runtime_prediction_e2e.py b/tests/nat/llm/test_runtime_prediction_e2e.py index 7928ba54e2..c88d20659d 100644 --- a/tests/nat/llm/test_runtime_prediction_e2e.py +++ b/tests/nat/llm/test_runtime_prediction_e2e.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """End-to-end test for runtime prediction trie integration. This test validates that all pieces work together: From 74c191d74e4066ec83680a0c78d531666a624a52 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Mon, 26 Jan 2026 12:39:51 -0800 Subject: [PATCH 35/37] Fix formatting of `job_id` in README_PREDICTION_TRIE.md Corrected the formatting of `job_id` to use code style for consistency and clarity. This improves readability and aligns with standard documentation practices. Signed-off-by: dnandakumar-nv --- .../react_benchmark_agent/README_PREDICTION_TRIE.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md b/examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md index 8469efcad9..767f0166eb 100644 --- a/examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md +++ b/examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md @@ -150,7 +150,7 @@ llms: The trie file doesn't exist at the configured path. Check: - Did Phase 1 profiling complete successfully? -- Is the job_id in the path correct? +- Is the `job_id` in the path correct? - Is the path relative to where you're running the command? ### "No prediction found for path" From c412dc88d251cb97e775228295b24059ac48cc17 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Mon, 26 Jan 2026 13:03:00 -0800 Subject: [PATCH 36/37] Add Apache 2.0 license headers to test files Included the full Apache 2.0 license header in two test files for compliance. This ensures proper licensing alignment and clarifies usage terms for these files. Signed-off-by: dnandakumar-nv --- tests/nat/profiler/test_prediction_trie_e2e.py | 12 ++++++++++++ .../nat/profiler/test_prediction_trie_integration.py | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/tests/nat/profiler/test_prediction_trie_e2e.py b/tests/nat/profiler/test_prediction_trie_e2e.py index 70ea27b5d1..492add53cf 100644 --- a/tests/nat/profiler/test_prediction_trie_e2e.py +++ b/tests/nat/profiler/test_prediction_trie_e2e.py @@ -1,5 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """End-to-end test for prediction trie workflow.""" import tempfile diff --git a/tests/nat/profiler/test_prediction_trie_integration.py b/tests/nat/profiler/test_prediction_trie_integration.py index 8484a8e078..2d54b7d860 100644 --- a/tests/nat/profiler/test_prediction_trie_integration.py +++ b/tests/nat/profiler/test_prediction_trie_integration.py @@ -1,5 +1,17 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import tempfile from pathlib import Path From 22327b96fe4b898008a4142b99cba61f4b719c07 Mon Sep 17 00:00:00 2001 From: dnandakumar-nv Date: Mon, 26 Jan 2026 13:47:00 -0800 Subject: [PATCH 37/37] Refactor imports for PredictionTrieLookup across modules Updated imports to fetch `PredictionTrieLookup` directly from the `trie_lookup` submodule for better clarity and modularity. Adjusted `__init__.py` to avoid re-exporting `PredictionTrieLookup` to prevent Sphinx cross-reference warnings. Additionally, reformatted and clarified docstrings and field descriptions for improved readability. Signed-off-by: dnandakumar-nv --- .../src/nat/plugins/langchain/llm.py | 2 +- src/nat/llm/dynamo_llm.py | 39 ++++++++++--------- src/nat/profiler/prediction_trie/__init__.py | 6 ++- tests/nat/llm/test_dynamic_prediction_hook.py | 2 +- tests/nat/llm/test_runtime_prediction_e2e.py | 2 +- .../langchain/test_dynamo_trie_loading.py | 2 +- .../nat/profiler/test_prediction_trie_e2e.py | 2 +- 7 files changed, 30 insertions(+), 25 deletions(-) diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py index ea864cbf8f..daf4221675 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py @@ -43,8 +43,8 @@ from nat.llm.utils.thinking import BaseThinkingInjector from nat.llm.utils.thinking import FunctionArgumentWrapper from nat.llm.utils.thinking import patch_with_thinking -from nat.profiler.prediction_trie import PredictionTrieLookup from nat.profiler.prediction_trie import load_prediction_trie +from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup from nat.utils.exception_handlers.automatic_retries import patch_with_retry from nat.utils.responses_api import validate_no_responses_api from nat.utils.type_utils import override diff --git a/src/nat/llm/dynamo_llm.py b/src/nat/llm/dynamo_llm.py index a55b3e87be..95457f20c6 100644 --- a/src/nat/llm/dynamo_llm.py +++ b/src/nat/llm/dynamo_llm.py @@ -60,7 +60,7 @@ if TYPE_CHECKING: import httpx - from nat.profiler.prediction_trie import PredictionTrieLookup + from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup from pydantic import Field @@ -298,19 +298,23 @@ class DynamoModelConfig(OpenAIModelConfig, name="dynamo"): "Lower values allow more load balancing across workers."), space=SearchSpace(low=1, high=20, step=5)) - prefix_osl: PrefixLevel = OptimizableField(default="MEDIUM", - description=("Output Sequence Length hint for the Dynamo router. " - "LOW=short responses (decode_cost=1.0), " - "MEDIUM=typical (decode_cost=2.0), " - "HIGH=long responses (decode_cost=3.0)."), - space=SearchSpace(values=["LOW", "MEDIUM", "HIGH"])) + prefix_osl: PrefixLevel = OptimizableField( + default="MEDIUM", + description="Output Sequence Length hint for the Dynamo router. " + "LOW means short responses (decode_cost=1.0), " + "MEDIUM means typical (decode_cost=2.0), " + "HIGH means long responses (decode_cost=3.0).", + space=SearchSpace(values=["LOW", "MEDIUM", "HIGH"]), + ) - prefix_iat: PrefixLevel = OptimizableField(default="MEDIUM", - description=("Inter-Arrival Time hint for the Dynamo router. " - "LOW=rapid bursts (iat_factor=1.5, high stickiness), " - "MEDIUM=normal (iat_factor=1.0), " - "HIGH=slow requests (iat_factor=0.6, more exploration)."), - space=SearchSpace(values=["LOW", "MEDIUM", "HIGH"])) + prefix_iat: PrefixLevel = OptimizableField( + default="MEDIUM", + description="Inter-Arrival Time hint for the Dynamo router. " + "LOW means rapid bursts (iat_factor=1.5, high stickiness), " + "MEDIUM means normal (iat_factor=1.0), " + "HIGH means slow requests (iat_factor=0.6, more exploration).", + space=SearchSpace(values=["LOW", "MEDIUM", "HIGH"]), + ) request_timeout: float = Field( default=600.0, @@ -376,11 +380,10 @@ def _create_dynamo_request_hook( its own unique prefix ID that remains constant within a workflow run. Args: - prefix_template: Template string with {uuid} placeholder (currently unused, - kept for API compatibility) - total_requests: Expected number of requests for this prefix - osl: Output sequence length hint (LOW/MEDIUM/HIGH) - iat: Inter-arrival time hint (LOW/MEDIUM/HIGH) + prefix_template: Template string with {uuid} placeholder (unused, for API compat). + total_requests: Expected number of requests for this prefix. + osl: Output sequence length hint (LOW/MEDIUM/HIGH). + iat: Inter-arrival time hint (LOW/MEDIUM/HIGH). Returns: An async function suitable for use as an httpx event hook. diff --git a/src/nat/profiler/prediction_trie/__init__.py b/src/nat/profiler/prediction_trie/__init__.py index 35d302c854..ca71b36a68 100644 --- a/src/nat/profiler/prediction_trie/__init__.py +++ b/src/nat/profiler/prediction_trie/__init__.py @@ -19,13 +19,15 @@ from nat.profiler.prediction_trie.serialization import load_prediction_trie from nat.profiler.prediction_trie.serialization import save_prediction_trie from nat.profiler.prediction_trie.trie_builder import PredictionTrieBuilder -from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup + +# Note: PredictionTrieLookup is intentionally not re-exported here to avoid +# Sphinx cross-reference warnings. Import from trie_lookup submodule directly: +# from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup __all__ = [ "LLMCallPrediction", "PredictionMetrics", "PredictionTrieBuilder", - "PredictionTrieLookup", "PredictionTrieNode", "load_prediction_trie", "save_prediction_trie", diff --git a/tests/nat/llm/test_dynamic_prediction_hook.py b/tests/nat/llm/test_dynamic_prediction_hook.py index a07f8bb110..08702b1a04 100644 --- a/tests/nat/llm/test_dynamic_prediction_hook.py +++ b/tests/nat/llm/test_dynamic_prediction_hook.py @@ -19,10 +19,10 @@ from nat.llm.dynamo_llm import _create_dynamic_prediction_hook from nat.llm.dynamo_llm import create_httpx_client_with_dynamo_hooks from nat.llm.prediction_context import get_call_tracker -from nat.profiler.prediction_trie import PredictionTrieLookup from nat.profiler.prediction_trie.data_models import LLMCallPrediction from nat.profiler.prediction_trie.data_models import PredictionMetrics from nat.profiler.prediction_trie.data_models import PredictionTrieNode +from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup @pytest.fixture(name="sample_trie_lookup") diff --git a/tests/nat/llm/test_runtime_prediction_e2e.py b/tests/nat/llm/test_runtime_prediction_e2e.py index c88d20659d..d33d3d57bb 100644 --- a/tests/nat/llm/test_runtime_prediction_e2e.py +++ b/tests/nat/llm/test_runtime_prediction_e2e.py @@ -28,12 +28,12 @@ from nat.data_models.intermediate_step import IntermediateStepPayload from nat.data_models.intermediate_step import IntermediateStepType from nat.llm.dynamo_llm import _create_dynamic_prediction_hook -from nat.profiler.prediction_trie import PredictionTrieLookup from nat.profiler.prediction_trie import load_prediction_trie from nat.profiler.prediction_trie import save_prediction_trie from nat.profiler.prediction_trie.data_models import LLMCallPrediction from nat.profiler.prediction_trie.data_models import PredictionMetrics from nat.profiler.prediction_trie.data_models import PredictionTrieNode +from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup class MockRequest: diff --git a/tests/nat/plugins/langchain/test_dynamo_trie_loading.py b/tests/nat/plugins/langchain/test_dynamo_trie_loading.py index 95ca687671..f06a19d116 100644 --- a/tests/nat/plugins/langchain/test_dynamo_trie_loading.py +++ b/tests/nat/plugins/langchain/test_dynamo_trie_loading.py @@ -24,11 +24,11 @@ from nat.builder.builder import Builder from nat.llm.dynamo_llm import DynamoModelConfig from nat.plugins.langchain.llm import dynamo_langchain -from nat.profiler.prediction_trie import PredictionTrieLookup from nat.profiler.prediction_trie import save_prediction_trie from nat.profiler.prediction_trie.data_models import LLMCallPrediction from nat.profiler.prediction_trie.data_models import PredictionMetrics from nat.profiler.prediction_trie.data_models import PredictionTrieNode +from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup @pytest.fixture(name="trie_file") diff --git a/tests/nat/profiler/test_prediction_trie_e2e.py b/tests/nat/profiler/test_prediction_trie_e2e.py index 492add53cf..6b2ab34853 100644 --- a/tests/nat/profiler/test_prediction_trie_e2e.py +++ b/tests/nat/profiler/test_prediction_trie_e2e.py @@ -25,8 +25,8 @@ from nat.data_models.profiler import PredictionTrieConfig from nat.data_models.profiler import ProfilerConfig from nat.profiler.callbacks.token_usage_base_model import TokenUsageBaseModel -from nat.profiler.prediction_trie import PredictionTrieLookup from nat.profiler.prediction_trie import load_prediction_trie +from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup from nat.profiler.profile_runner import ProfilerRunner