diff --git a/ci/vale/styles/config/vocabularies/nat/accept.txt b/ci/vale/styles/config/vocabularies/nat/accept.txt index 1de8c93b2d..05c9bb5c95 100644 --- a/ci/vale/styles/config/vocabularies/nat/accept.txt +++ b/ci/vale/styles/config/vocabularies/nat/accept.txt @@ -179,6 +179,7 @@ Tavily [Tt]imestamp(s?) [Tt]okenization [Tt]okenizer(s?) +[Tt]rie(s?) triages [Uu]ncomment(ed)? [Uu]nencrypted diff --git a/examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md b/examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md new file mode 100644 index 0000000000..767f0166eb --- /dev/null +++ b/examples/dynamo_integration/react_benchmark_agent/README_PREDICTION_TRIE.md @@ -0,0 +1,172 @@ + + + +# Prediction Trie Optimization for Dynamo + +Use profiled execution data to inject accurate per-call prediction headers instead of static guesses. + +## Overview + +The prediction trie enables **dynamic header injection** for Dynamo's KV-aware routing. Instead of using static values like `prefix_total_requests=10` for every call, the trie provides accurate predictions based on: +- **Function path**: Where in the agent hierarchy the call originates (e.g., `["react_workflow", "react_agent"]`) +- **Call index**: Which LLM call this is within the current function (1st, 2nd, 3rd, etc.) + +This allows Dynamo's Thompson Sampling router to make better worker assignment decisions. + +## Quick Start + +### Phase 1: Build the Prediction Trie + +Run profiling to collect execution data and build the trie: + +```bash +nat eval --config_file configs/profile_rethinking_full_test.yml +``` + +**Output location:** +``` +outputs/dynamo_evals/rethinking_full_test_for_profiling//prediction_trie.json +``` + +### Phase 2: Run with Predictions + +1. **Update the trie path** in `configs/run_with_prediction_trie.yml`: + ```yaml + prediction_trie_path: ./examples/dynamo_integration/react_benchmark_agent/outputs/dynamo_evals/rethinking_full_test_for_profiling//prediction_trie.json + ``` + +2. **Run with dynamic predictions:** + ```bash + nat eval --config_file configs/run_with_prediction_trie.yml + ``` + +## How It Works + +### During Profiling (Phase 1) + +The profiler collects data for each LLM call: +- Function path at time of call +- Call index within the parent function +- Output tokens generated +- Time until the next LLM call +- Remaining LLM calls in the workflow + +This data is aggregated into a trie structure with statistical summaries (mean, p50, p90, etc.) at each node. + +### During Execution (Phase 2) + +For each LLM request: +1. Read the current function path from context +2. Read the call index from the LLM call tracker +3. Look up the prediction in the trie +4. Inject headers into the HTTP request + +### Fallback Chain + +If an exact match isn't found, the trie lookup falls back: +1. Exact path + exact call index (most specific) +2. Exact path + any call index +3. Partial path + exact call index +4. Root aggregated stats (most general) + +This ensures predictions are always available, even for novel execution paths. 
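+The same fallback behavior can be exercised directly from Python. Here is a
+minimal sketch using the lookup API added in this change (the trie path is
+illustrative):
+
+```python
+from pathlib import Path
+
+from nat.profiler.prediction_trie import load_prediction_trie
+from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup
+
+# Load the trie produced by Phase 1 profiling
+trie = load_prediction_trie(Path("/path/to/prediction_trie.json"))
+lookup = PredictionTrieLookup(trie)
+
+# Exact path + call index when available; otherwise the lookup falls back
+# to more general nodes (any call index, shorter path, root stats)
+prediction = lookup.find(["react_workflow", "react_agent"], call_index=2)
+if prediction is not None:
+    print(f"expected output tokens (p90): {prediction.output_tokens.p90:.0f}")
+    print(f"expected remaining calls (mean): {prediction.remaining_calls.mean:.1f}")
+```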
+ +## Headers Injected + +| Header | Source | Description | +|--------|--------|-------------| +| `x-nat-remaining-llm-calls` | `prediction.remaining_calls.mean` | Expected remaining LLM calls in workflow | +| `x-nat-interarrival-ms` | `prediction.interarrival_ms.mean` | Expected milliseconds until next call | +| `x-nat-expected-output-tokens` | `prediction.output_tokens.p90` | Expected output tokens (90th percentile) | + +## Comparing Results + +To measure the impact of prediction trie vs static headers: + +1. **Run with static headers** (baseline): + ```bash + nat eval --config_file configs/eval_config_rethinking_full_test.yml + ``` + +2. **Run with prediction trie**: + ```bash + nat eval --config_file configs/run_with_prediction_trie.yml + ``` + +3. **Compare metrics**: + - `avg_llm_latency`: Lower is better + - `avg_workflow_runtime`: Lower is better + - Look for improvements in KV cache hit rates in Dynamo logs + +## Configuration Reference + +### Profiler Configuration (Phase 1) + +Enable trie building in the profiler section: + +```yaml +profiler: + prediction_trie: + enable: true + output_filename: prediction_trie.json # default +``` + +### LLM Configuration (Phase 2) + +Add the trie path to your Dynamo LLM config: + +```yaml +llms: + dynamo_llm: + _type: dynamo + prefix_template: "react-benchmark-{uuid}" + + # Static fallbacks (used if trie lookup fails) + prefix_total_requests: 10 + prefix_osl: MEDIUM + prefix_iat: MEDIUM + + # Dynamic predictions from profiled data + prediction_trie_path: /path/to/prediction_trie.json +``` + +## Troubleshooting + +### "Prediction trie file not found" + +The trie file doesn't exist at the configured path. Check: +- Did Phase 1 profiling complete successfully? +- Is the `job_id` in the path correct? +- Is the path relative to where you're running the command? + +### "No prediction found for path" + +This is normal - it means the trie is using fallback predictions. The trie will fall back to more general predictions when exact matches aren't found. 
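+To see which paths the trie actually contains, you can open the JSON file
+directly. It has a nested node structure, abbreviated here with illustrative
+numbers:
+
+```json
+{
+  "version": "1.0",
+  "generated_at": "2025-01-01T00:00:00+00:00",
+  "workflow_name": "profiled_workflow",
+  "root": {
+    "name": "root",
+    "predictions_by_call_index": {
+      "1": {
+        "remaining_calls": {"sample_count": 40, "mean": 6.2, "p50": 6.0, "p90": 9.0, "p95": 10.0},
+        "interarrival_ms": {"sample_count": 38, "mean": 210.0, "p50": 180.0, "p90": 420.0, "p95": 510.0},
+        "output_tokens": {"sample_count": 40, "mean": 310.0, "p50": 290.0, "p90": 520.0, "p95": 610.0}
+      }
+    },
+    "predictions_any_index": null,
+    "children": {}
+  }
+}
+```
+
+Real tries nest `children` by function name (each child repeats this node
+shape), and `predictions_any_index` carries the aggregate used for fallback.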
+ +### Headers not being injected + +Ensure: +- `prefix_template` is set (required for Dynamo hooks) +- `prediction_trie_path` points to a valid trie file +- You're using the `dynamo` LLM type + +## Files + +| File | Purpose | +|------|---------| +| `configs/profile_rethinking_full_test.yml` | Phase 1: Profile and build trie | +| `configs/run_with_prediction_trie.yml` | Phase 2: Run with dynamic predictions | diff --git a/examples/dynamo_integration/react_benchmark_agent/src/react_benchmark_agent/configs/profile_rethinking_full_test.yml b/examples/dynamo_integration/react_benchmark_agent/src/react_benchmark_agent/configs/profile_rethinking_full_test.yml index 47b7e243fb..5dc1fb40a3 100644 --- a/examples/dynamo_integration/react_benchmark_agent/src/react_benchmark_agent/configs/profile_rethinking_full_test.yml +++ b/examples/dynamo_integration/react_benchmark_agent/src/react_benchmark_agent/configs/profile_rethinking_full_test.yml @@ -223,6 +223,11 @@ eval: concurrency_spike_analysis: enable: true spike_threshold: 24 # Alert when concurrent functions >= 24 + # Build prediction trie for dynamic Dynamo header injection + # Output: prediction_trie.json in the output directory + # Use with run_with_prediction_trie.yml for optimized routing + prediction_trie: + enable: true evaluators: tool_selection_quality: diff --git a/examples/dynamo_integration/react_benchmark_agent/src/react_benchmark_agent/configs/run_with_prediction_trie.yml b/examples/dynamo_integration/react_benchmark_agent/src/react_benchmark_agent/configs/run_with_prediction_trie.yml new file mode 100644 index 0000000000..b22f114692 --- /dev/null +++ b/examples/dynamo_integration/react_benchmark_agent/src/react_benchmark_agent/configs/run_with_prediction_trie.yml @@ -0,0 +1,216 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# ============================================================================= +# RUN WITH PREDICTION TRIE - DYNAMIC HEADER INJECTION +# ============================================================================= +# Purpose: Use profiled prediction trie for dynamic Dynamo header injection +# +# Prerequisites: +# 1. Run profiling first to build the prediction trie: +# nat eval --config_file configs/profile_rethinking_full_test.yml +# +# 2. 
Update prediction_trie_path below to point to the generated trie: +# outputs/dynamo_evals/rethinking_full_test_for_profiling//prediction_trie.json +# +# What this does: +# - Loads the prediction trie built from profiled execution data +# - For each LLM call, looks up predictions based on: +# * Current function path (e.g., ["react_workflow", "react_agent"]) +# * Call index within the current function +# - Injects dynamic headers per request: +# * x-nat-remaining-llm-calls: Expected remaining calls +# * x-nat-interarrival-ms: Expected time until next call +# * x-nat-expected-output-tokens: Expected output tokens (p90) +# +# Benefits over static headers: +# - Accurate per-call predictions instead of guessing prefix_total_requests=10 +# - Different predictions for different parts of the agent workflow +# - Dynamo router can make better worker assignment decisions +# +# Usage: +# nat eval --config_file configs/run_with_prediction_trie.yml +# ============================================================================= + +functions: + react_benchmark_agent: + _type: react_benchmark_agent + prefix: "Agent:" + decision_only: true + canned_response_template: "Successfully executed {tool_name}. Operation completed." + + # Define the ReAct workflow + react_workflow: + _type: react_agent + llm_name: dynamo_llm + tool_names: [ + banking_tools + ] + verbose: false # Disable verbose for benchmarking + parse_agent_response_max_retries: 3 + max_tool_calls: 25 + max_history: 1000 + pass_tool_call_errors_to_agent: true + recursion_limit: 50 + system_prompt: | + You are a tool-calling agent evaluated on TOOL SELECTION capability. Your goal is to select ALL the correct tools, in the correct order, to COMPLETELY handle real-world use-cases. + + IMPORTANT: This is a tool selection exercise, NOT real execution. + - Focus on selecting the RIGHT TOOL for each step + - Use placeholder or dummy values for required parameters (e.g., "12345", "user@example.com", "2024-01-01") + - Tool responses are simulated - ignore them and focus on selecting the next appropriate tool + - What matters is YOUR INTENT and TOOL CHOICE, not the data quality + + Available tools: + + {tools} + + Use this exact format for EACH response: + + Thought: I need to analyze what the user needs and select the SINGLE NEXT tool to call. + Action: the ONE tool to call right now, must be one of [{tool_names}] + Action Input: valid JSON with required parameters (use placeholder values) + + CRITICAL RULES: + 1. Output ONLY ONE Thought, Action, and Action Input per response + 2. STOP IMMEDIATELY after writing Action Input + 3. DO NOT write the Observation - the system will provide it + 4. 
DO NOT write multiple Thought/Action/Action Input cycles in one response + + When you have called ALL necessary tools: + Thought: I have now completed all necessary steps + Final Answer: [summary of what was accomplished] + +function_groups: + banking_tools: + _type: banking_tools_group + # tools.json available after running: /examples/dynamo_integration/scripts/download_agent_leaderboard_v2.py + tools_json_path: ./examples/dynamo_integration/data/raw/banking/tools.json + decision_only: true + include: [ + get_account_balance, + get_transaction_history, + transfer_funds, + get_loan_information, + get_credit_card_information, + get_mortgage_details, + get_savings_account_products, + schedule_appointment, + check_loan_application_status, + find_nearby_locations, + get_investment_products, + report_lost_stolen_card, + update_contact_information, + setup_automatic_bill_pay, + initiate_transaction_dispute, + get_exchange_rates, + calculate_loan_payment, + manage_account_alerts, + check_wire_transfer_status, + get_cd_products + ] + +llms: + # ========================================================================= + # DYNAMO LLM WITH PREDICTION TRIE + # ========================================================================= + # Uses prediction_trie_path to load profiled predictions and inject + # dynamic headers per LLM call based on current execution context. + dynamo_llm: + _type: dynamo + model_name: llama-3.3-70b + base_url: http://localhost:8099/v1 + api_key: dummy + temperature: 0.0 + max_tokens: 8192 + stop: ["Observation:", "\nThought:"] + + # Dynamo prefix configuration (required for prefix routing) + prefix_template: "react-benchmark-{uuid}" + + # Static fallback values (used if trie lookup fails) + prefix_total_requests: 10 + prefix_osl: MEDIUM + prefix_iat: MEDIUM + + # ========================================================================= + # PREDICTION TRIE - Dynamic per-call header injection + # ========================================================================= + # UPDATE THIS PATH to point to your profiled prediction trie: + # 1. Run: nat eval --config_file configs/profile_rethinking_full_test.yml + # 2. Find the job output directory (includes job_id) + # 3. 
Set path to: //prediction_trie.json + prediction_trie_path: ./examples/dynamo_integration/react_benchmark_agent/outputs/dynamo_evals/rethinking_full_test_for_profiling/REPLACE_WITH_JOB_ID/prediction_trie.json + + # Secondary LLM for self-evaluation (no prediction trie needed) + eval_llm: + _type: dynamo + model_name: llama-3.3-70b + base_url: http://localhost:8099/v1 + api_key: dummy + temperature: 0.0 + max_tokens: 1024 + +# Advanced self-evaluating wrapper with feedback +workflow: + _type: self_evaluating_agent_with_feedback + wrapped_agent: react_workflow + evaluator_llm: eval_llm + max_retries: 3 + min_confidence_threshold: 0.80 + pass_feedback_to_agent: true + verbose: false + +eval: + general: + max_concurrency: 36 + + output: + dir: ./examples/dynamo_integration/react_benchmark_agent/outputs/dynamo_evals/prediction_trie_eval/ + cleanup: false + job_management: + append_job_id_to_output_dir: true + + dataset: + _type: json + file_path: ./examples/dynamo_integration/data/agent_leaderboard_v2_banking.json + structure: + disable: true + + # Lighter profiler config - we're consuming predictions, not building them + profiler: + compute_llm_metrics: true + csv_exclude_io_text: true + # No prediction_trie section - we're using the trie, not building it + + # ========================================================================= + # RUNTIME EVALUATORS - Compare against static header baseline + # ========================================================================= + evaluators: + # Primary metric: Average LLM latency per call (seconds) + avg_llm_latency: + _type: avg_llm_latency + max_concurrency: 36 + + # Secondary metric: Average workflow runtime (seconds) + avg_workflow_runtime: + _type: avg_workflow_runtime + max_concurrency: 36 + + # Tertiary metric: Average number of LLM calls + avg_num_llm_calls: + _type: avg_num_llm_calls + max_concurrency: 36 diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py index 3cca683aaa..daf4221675 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/llm.py @@ -16,6 +16,7 @@ import logging from collections.abc import Sequence +from pathlib import Path from typing import TYPE_CHECKING from typing import Any from typing import TypeVar @@ -42,6 +43,8 @@ from nat.llm.utils.thinking import BaseThinkingInjector from nat.llm.utils.thinking import FunctionArgumentWrapper from nat.llm.utils.thinking import patch_with_thinking +from nat.profiler.prediction_trie import load_prediction_trie +from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup from nat.utils.exception_handlers.automatic_retries import patch_with_retry from nat.utils.responses_api import validate_no_responses_api from nat.utils.type_utils import override @@ -243,6 +246,19 @@ async def dynamo_langchain(llm_config: DynamoModelConfig, _builder: Builder): # Initialize http_async_client to None for proper cleanup http_async_client = None + # Load prediction trie if configured + prediction_lookup: PredictionTrieLookup | None = None + if llm_config.prediction_trie_path: + try: + trie_path = Path(llm_config.prediction_trie_path) + trie = load_prediction_trie(trie_path) + prediction_lookup = PredictionTrieLookup(trie) + logger.info("Loaded prediction trie from %s", llm_config.prediction_trie_path) + except FileNotFoundError: + logger.warning("Prediction trie file not found: %s", llm_config.prediction_trie_path) + 
except Exception as e: + logger.warning("Failed to load prediction trie: %s", e) + try: # If prefix_template is set, create a custom httpx client with Dynamo hooks if llm_config.prefix_template is not None: @@ -252,14 +268,16 @@ async def dynamo_langchain(llm_config: DynamoModelConfig, _builder: Builder): osl=llm_config.prefix_osl, iat=llm_config.prefix_iat, timeout=llm_config.request_timeout, + prediction_lookup=prediction_lookup, ) config_dict["http_async_client"] = http_async_client logger.info( - "Dynamo prefix headers enabled: template=%s, total_requests=%d, osl=%s, iat=%s", + "Dynamo prefix headers enabled: template=%s, total_requests=%d, osl=%s, iat=%s, prediction_trie=%s", llm_config.prefix_template, llm_config.prefix_total_requests, llm_config.prefix_osl, llm_config.prefix_iat, + "loaded" if prediction_lookup else "disabled", ) # Create the ChatOpenAI client diff --git a/packages/nvidia_nat_langchain/tests/test_llm_langchain.py b/packages/nvidia_nat_langchain/tests/test_llm_langchain.py index 98e60932a0..e5a4f18741 100644 --- a/packages/nvidia_nat_langchain/tests/test_llm_langchain.py +++ b/packages/nvidia_nat_langchain/tests/test_llm_langchain.py @@ -241,6 +241,7 @@ async def test_creation_with_prefix_template(self, osl="HIGH", iat="LOW", timeout=300.0, + prediction_lookup=None, ) # Verify ChatOpenAI was called with the custom httpx client diff --git a/src/nat/builder/context.py b/src/nat/builder/context.py index d74551878a..e22a1618c9 100644 --- a/src/nat/builder/context.py +++ b/src/nat/builder/context.py @@ -81,6 +81,7 @@ def __init__(self): self._event_stream: ContextVar[Subject[IntermediateStep] | None] = ContextVar("event_stream", default=None) self._active_function: ContextVar[InvocationNode | None] = ContextVar("active_function", default=None) self._active_span_id_stack: ContextVar[list[str] | None] = ContextVar("active_span_id_stack", default=None) + self._function_path_stack: ContextVar[list[str] | None] = ContextVar("function_path_stack", default=None) # Default is a lambda no-op which returns NoneType self.user_input_callback: ContextVar[Callable[[InteractionPrompt], Awaitable[HumanResponse | None]] @@ -115,6 +116,12 @@ def active_span_id_stack(self) -> ContextVar[list[str]]: self._active_span_id_stack.set(["root"]) return typing.cast(ContextVar[list[str]], self._active_span_id_stack) + @property + def function_path_stack(self) -> ContextVar[list[str]]: + if self._function_path_stack.get() is None: + self._function_path_stack.set([]) + return typing.cast(ContextVar[list[str]], self._function_path_stack) + @staticmethod def get() -> "ContextState": return ContextState() @@ -251,6 +258,11 @@ def push_active_function(self, # 1) Set the active function in the contextvar fn_token = self._context_state.active_function.set(current_function_node) + # 1b) Push function name onto path stack + current_path = self._context_state.function_path_stack.get() + new_path = current_path + [function_name] + path_token = self._context_state.function_path_stack.set(new_path) + # 2) Optionally record function start as an intermediate step step_manager = self.intermediate_step_manager step_manager.push_intermediate_step( @@ -275,7 +287,10 @@ def push_active_function(self, name=function_name, data=data)) - # 4) Unset the function contextvar + # 4a) Pop function name from path stack + self._context_state.function_path_stack.reset(path_token) + + # 4b) Unset the function contextvar self._context_state.active_function.reset(fn_token) @property @@ -288,6 +303,19 @@ def active_function(self) -> 
InvocationNode: """ return self._context_state.active_function.get() + @property + def function_path(self) -> list[str]: + """ + Returns a copy of the current function path stack. + + The function path represents the ancestry of the currently executing + function, from root to the current function. + + Returns: + list[str]: Copy of the function path stack. + """ + return list(self._context_state.function_path_stack.get()) + @property def active_span_id(self) -> str: """ diff --git a/src/nat/builder/intermediate_step_manager.py b/src/nat/builder/intermediate_step_manager.py index 2f9af79fa5..015a169451 100644 --- a/src/nat/builder/intermediate_step_manager.py +++ b/src/nat/builder/intermediate_step_manager.py @@ -22,6 +22,8 @@ from nat.data_models.intermediate_step import IntermediateStep from nat.data_models.intermediate_step import IntermediateStepPayload from nat.data_models.intermediate_step import IntermediateStepState +from nat.data_models.intermediate_step import IntermediateStepType +from nat.llm.prediction_context import get_call_tracker from nat.utils.reactive.observable import OnComplete from nat.utils.reactive.observable import OnError from nat.utils.reactive.observable import OnNext @@ -95,6 +97,16 @@ def push_intermediate_step(self, payload: IntermediateStepPayload) -> None: parent_step_id, id(active_span_id_stack)) + # Track LLM call index for prediction trie lookups + if payload.event_type == IntermediateStepType.LLM_START: + active_function = self._context_state.active_function.get() + if active_function and active_function.function_id != "root": + tracker = get_call_tracker() + tracker.increment(active_function.function_id) + logger.debug("Incremented LLM call tracker for %s to %d", + active_function.function_id, + tracker.counts.get(active_function.function_id, 0)) + elif (payload.event_state == IntermediateStepState.END): # Remove the current step from the outstanding steps diff --git a/src/nat/data_models/profiler.py b/src/nat/data_models/profiler.py index cb0ed64544..ed64ed2c18 100644 --- a/src/nat/data_models/profiler.py +++ b/src/nat/data_models/profiler.py @@ -40,6 +40,11 @@ class PrefixSpanConfig(BaseModel): chain_with_common_prefixes: bool = False +class PredictionTrieConfig(BaseModel): + enable: bool = False + output_filename: str = "prediction_trie.json" + + class ProfilerConfig(BaseModel): base_metrics: bool = False @@ -52,3 +57,4 @@ class ProfilerConfig(BaseModel): bottleneck_analysis: BottleneckConfig = BottleneckConfig() concurrency_spike_analysis: ConcurrencySpikeConfig = ConcurrencySpikeConfig() prefix_span_analysis: PrefixSpanConfig = PrefixSpanConfig() + prediction_trie: PredictionTrieConfig = PredictionTrieConfig() diff --git a/src/nat/llm/dynamo_llm.py b/src/nat/llm/dynamo_llm.py index 003ef9be88..95457f20c6 100644 --- a/src/nat/llm/dynamo_llm.py +++ b/src/nat/llm/dynamo_llm.py @@ -60,88 +60,190 @@ if TYPE_CHECKING: import httpx + from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup + from pydantic import Field from nat.builder.builder import Builder +from nat.builder.context import Context +from nat.builder.context import Singleton from nat.builder.llm import LLMProviderInfo from nat.cli.register_workflow import register_llm_provider from nat.data_models.optimizable import OptimizableField from nat.data_models.optimizable import SearchSpace from nat.llm.openai_llm import OpenAIModelConfig from nat.llm.utils.constants import LLMHeaderPrefix +from nat.profiler.prediction_trie.data_models import LLMCallPrediction logger = 
logging.getLogger(__name__) # Define valid prefix hint values PrefixLevel = Literal["LOW", "MEDIUM", "HIGH"] +# ============================================================================= +# CATEGORY CONVERSION HELPERS +# ============================================================================= + + +def _output_tokens_to_osl(output_tokens: float) -> PrefixLevel: + """ + Convert predicted output tokens to OSL category. + + Thresholds: + - < 256 tokens: LOW (short responses) + - < 1024 tokens: MEDIUM (typical responses) + - >= 1024 tokens: HIGH (long responses) + """ + if output_tokens < 256: + return "LOW" + if output_tokens < 1024: + return "MEDIUM" + return "HIGH" + + +def _interarrival_ms_to_iat(interarrival_ms: float) -> PrefixLevel: + """ + Convert predicted interarrival time to IAT category. + + Thresholds: + - < 100ms: LOW (rapid bursts, high worker stickiness) + - < 500ms: MEDIUM (normal pacing) + - >= 500ms: HIGH (slow requests, more exploration) + """ + if interarrival_ms < 100: + return "LOW" + if interarrival_ms < 500: + return "MEDIUM" + return "HIGH" + + # ============================================================================= # CONTEXT MANAGEMENT FOR DYNAMO PREFIX ID # ============================================================================= -class DynamoPrefixContext: +class DynamoPrefixContext(metaclass=Singleton): """ Singleton class for managing Dynamo prefix IDs across LLM calls. - This allows evaluation code to set a prefix ID that persists across all LLM - calls for a single evaluation question (multi-turn conversation). + Prefix IDs are unique per depth level in the function call stack, allowing + different caching behavior at different levels of nested function calls. + Each depth level gets its own prefix ID that remains constant within a + single workflow run but changes between runs. + + The prefix ID format is: ``{workflow_run_id}-d{depth}`` Usage:: from nat.llm.dynamo_llm import DynamoPrefixContext - # Set prefix ID at the start of each evaluation question - DynamoPrefixContext.set("eval-q001-abc123") - - # ... perform LLM calls ... - - # Clear when done - DynamoPrefixContext.clear() + # Automatically gets prefix ID based on current call stack depth + prefix_id = DynamoPrefixContext.get() - # Or use as a context manager + # Or use as a context manager for explicit control with DynamoPrefixContext.scope("eval-q001-abc123"): - # ... perform LLM calls ... + # All LLM calls here will use "eval-q001-abc123" prefix + ... """ - _current_prefix_id: ContextVar[str | None] = ContextVar('dynamo_prefix_id', default=None) + # Maps depth -> prefix_id for the current workflow run + _prefix_ids_by_depth: ContextVar[dict[int, str] | None] = ContextVar('dynamo_prefix_ids_by_depth', default=None) + # Optional override that takes precedence over depth-based IDs + _override_prefix_id: ContextVar[str | None] = ContextVar('dynamo_override_prefix_id', default=None) + + @classmethod + def _get_current_depth(cls) -> int: + """Get the current function call stack depth from Context.""" + try: + ctx = Context.get() + return len(ctx.function_path) + except Exception: + return 0 + + @classmethod + def _get_or_create_depth_map(cls) -> dict[int, str]: + """Get or create the depth -> prefix_id mapping for this context.""" + depth_map = cls._prefix_ids_by_depth.get() + if depth_map is None: + depth_map = {} + cls._prefix_ids_by_depth.set(depth_map) + return depth_map @classmethod def set(cls, prefix_id: str) -> None: """ - Set the Dynamo prefix ID for the current context. 
+ Set an override prefix ID that takes precedence over depth-based IDs. - Call this at the start of each evaluation question to ensure all LLM calls - for that question share the same prefix ID (enabling KV cache reuse). + Use this when you need explicit control over the prefix ID, such as + during batch evaluation where each question should have a specific ID. Args: - prefix_id: The unique prefix ID (e.g., "eval-q001-abc123") + prefix_id: The prefix ID to use (overrides depth-based generation) """ - cls._current_prefix_id.set(prefix_id) - logger.debug("Set Dynamo prefix ID: %s", prefix_id) + cls._override_prefix_id.set(prefix_id) + logger.debug("Set override Dynamo prefix ID: %s", prefix_id) @classmethod def clear(cls) -> None: - """Clear the current Dynamo prefix ID context.""" - cls._current_prefix_id.set(None) - logger.debug("Cleared Dynamo prefix ID") + """Clear all prefix ID state (both override and depth-based).""" + cls._override_prefix_id.set(None) + cls._prefix_ids_by_depth.set(None) + logger.debug("Cleared Dynamo prefix ID context") + + @classmethod + def get(cls) -> str: + """ + Get the Dynamo prefix ID for the current context. + + Returns the override prefix ID if set, otherwise returns a depth-based + prefix ID that is unique per workflow run and call stack depth. + + Returns: + The prefix ID string, never None. + """ + # Check for override first + override = cls._override_prefix_id.get() + if override: + return override + + # Get depth-based prefix ID + depth = cls._get_current_depth() + depth_map = cls._get_or_create_depth_map() + + if depth not in depth_map: + # Generate new prefix ID for this depth + try: + ctx = Context.get() + workflow_id = ctx.workflow_run_id + except Exception: + workflow_id = None + + if not workflow_id: + logger.warning("No workflow_run_id in context; using unique prefix ID.") + workflow_id = uuid.uuid4().hex[:16] + + prefix_id = f"{workflow_id}-d{depth}" + depth_map[depth] = prefix_id + logger.debug("Generated Dynamo prefix ID for depth %d: %s", depth, prefix_id) + + return depth_map[depth] @classmethod - def get(cls) -> str | None: - """Get the current Dynamo prefix ID from context, if any.""" - return cls._current_prefix_id.get() + def is_set(cls) -> bool: + """Check if a Dynamo prefix ID is available (always True, IDs are auto-generated).""" + return True @classmethod @contextmanager def scope(cls, prefix_id: str) -> Iterator[None]: """ - Context manager for scoped prefix ID usage. + Context manager for scoped override prefix ID usage. - Automatically sets the prefix ID on entry and clears it on exit, - ensuring proper cleanup even if exceptions occur. + Sets an override prefix ID on entry and restores the previous state on exit, + ensuring proper cleanup even if exceptions occur. Supports nesting. Args: - prefix_id: The unique prefix ID for this scope + prefix_id: The override prefix ID for this scope Yields: None @@ -151,11 +253,12 @@ def scope(cls, prefix_id: str) -> Iterator[None]: # All LLM calls here will use "eval-q001" prefix await llm.ainvoke(...) 
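+
+            # Nesting restores the outer override on exit
+            with DynamoPrefixContext.scope("eval-q001-retry"):
+                await llm.ainvoke(...)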
""" + previous_override = cls._override_prefix_id.get() cls.set(prefix_id) try: yield finally: - cls.clear() + cls._override_prefix_id.set(previous_override) # ============================================================================= @@ -195,19 +298,23 @@ class DynamoModelConfig(OpenAIModelConfig, name="dynamo"): "Lower values allow more load balancing across workers."), space=SearchSpace(low=1, high=20, step=5)) - prefix_osl: PrefixLevel = OptimizableField(default="MEDIUM", - description=("Output Sequence Length hint for the Dynamo router. " - "LOW=short responses (decode_cost=1.0), " - "MEDIUM=typical (decode_cost=2.0), " - "HIGH=long responses (decode_cost=3.0)."), - space=SearchSpace(values=["LOW", "MEDIUM", "HIGH"])) + prefix_osl: PrefixLevel = OptimizableField( + default="MEDIUM", + description="Output Sequence Length hint for the Dynamo router. " + "LOW means short responses (decode_cost=1.0), " + "MEDIUM means typical (decode_cost=2.0), " + "HIGH means long responses (decode_cost=3.0).", + space=SearchSpace(values=["LOW", "MEDIUM", "HIGH"]), + ) - prefix_iat: PrefixLevel = OptimizableField(default="MEDIUM", - description=("Inter-Arrival Time hint for the Dynamo router. " - "LOW=rapid bursts (iat_factor=1.5, high stickiness), " - "MEDIUM=normal (iat_factor=1.0), " - "HIGH=slow requests (iat_factor=0.6, more exploration)."), - space=SearchSpace(values=["LOW", "MEDIUM", "HIGH"])) + prefix_iat: PrefixLevel = OptimizableField( + default="MEDIUM", + description="Inter-Arrival Time hint for the Dynamo router. " + "LOW means rapid bursts (iat_factor=1.5, high stickiness), " + "MEDIUM means normal (iat_factor=1.0), " + "HIGH means slow requests (iat_factor=0.6, more exploration).", + space=SearchSpace(values=["LOW", "MEDIUM", "HIGH"]), + ) request_timeout: float = Field( default=600.0, @@ -215,6 +322,12 @@ class DynamoModelConfig(OpenAIModelConfig, name="dynamo"): description="HTTP request timeout in seconds for LLM requests.", ) + prediction_trie_path: str | None = Field( + default=None, + description="Path to prediction_trie.json file. When set, predictions are " + "looked up and injected as headers for each LLM call.", + ) + # ========================================================================= # UTILITY METHODS # ========================================================================= @@ -243,6 +356,7 @@ def get_dynamo_field_names() -> frozenset[str]: "prefix_osl", "prefix_iat", "request_timeout", + "prediction_trie_path", }) @@ -261,44 +375,28 @@ def _create_dynamo_request_hook( Create an httpx event hook that injects Dynamo prefix headers into requests. This hook is called before each HTTP request is sent, allowing us to inject - headers dynamically. The prefix ID is generated ONCE when the hook is created, - ensuring all requests from the same client share the same prefix ID. This enables - Dynamo's KV cache optimization across multi-turn conversations. - - The context variable can override this for scenarios where you need different - prefix IDs (e.g., per-question in batch evaluation). + headers dynamically. The prefix ID is obtained from DynamoPrefixContext which + provides depth-aware prefix IDs - each level in the function call stack gets + its own unique prefix ID that remains constant within a workflow run. 
Args: - prefix_template: Template string with {uuid} placeholder - total_requests: Expected number of requests for this prefix - osl: Output sequence length hint (LOW/MEDIUM/HIGH) - iat: Inter-arrival time hint (LOW/MEDIUM/HIGH) + prefix_template: Template string with {uuid} placeholder (unused, for API compat). + total_requests: Expected number of requests for this prefix. + osl: Output sequence length hint (LOW/MEDIUM/HIGH). + iat: Inter-arrival time hint (LOW/MEDIUM/HIGH). Returns: An async function suitable for use as an httpx event hook. """ - # Generate the default prefix ID ONCE when the hook is created - # This ensures all requests from this client share the same prefix ID - unique_id = uuid.uuid4().hex[:16] - if prefix_template: - default_prefix_id = prefix_template.format(uuid=unique_id) - else: - default_prefix_id = f"nat-dynamo-{unique_id}" - - logger.debug("Created Dynamo request hook with default prefix ID: %s", default_prefix_id) + # Note: prefix_template is kept for API compatibility but no longer used. + # Prefix IDs are now managed by DynamoPrefixContext with depth-awareness. + _ = prefix_template # Suppress unused parameter warning async def on_request(request): """Inject Dynamo prefix headers before each request.""" - # Check context variable first (allows per-question override in batch evaluation) - context_prefix_id = DynamoPrefixContext.get() - - if context_prefix_id: - prefix_id = context_prefix_id - logger.debug("Using context prefix ID: %s", prefix_id) - else: - # Use the pre-generated prefix ID (same for all requests from this client) - prefix_id = default_prefix_id - logger.debug("Using default prefix ID: %s", prefix_id) + # Get depth-aware prefix ID from context + prefix_id = DynamoPrefixContext.get() + logger.debug("Using depth-aware prefix ID: %s", prefix_id) # Inject Dynamo headers request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-id"] = prefix_id @@ -321,6 +419,7 @@ def create_httpx_client_with_dynamo_hooks( osl: str, iat: str, timeout: float = 600.0, + prediction_lookup: "PredictionTrieLookup | None" = None, ) -> "httpx.AsyncClient": """ Create an httpx.AsyncClient with Dynamo prefix header injection. @@ -334,16 +433,173 @@ def create_httpx_client_with_dynamo_hooks( osl: Output sequence length hint (LOW/MEDIUM/HIGH) iat: Inter-arrival time hint (LOW/MEDIUM/HIGH) timeout: HTTP request timeout in seconds + prediction_lookup: Optional PredictionTrieLookup for dynamic header injection Returns: An httpx.AsyncClient configured with Dynamo header injection. """ import httpx - request_hook = _create_dynamo_request_hook(prefix_template, total_requests, osl, iat) + hooks: list[Callable] = [] + + # Add Dynamo prefix hook + prefix_hook = _create_dynamo_request_hook(prefix_template, total_requests, osl, iat) + hooks.append(prefix_hook) + + # Add dynamic prediction hook if lookup provided + if prediction_lookup is not None: + prediction_hook = _create_dynamic_prediction_hook(prediction_lookup) + hooks.append(prediction_hook) + + return httpx.AsyncClient( + event_hooks={"request": hooks}, + timeout=httpx.Timeout(timeout), + ) + + +def _create_prediction_request_hook( + prediction: LLMCallPrediction, ) -> Callable[["httpx.Request"], Coroutine[Any, Any, None]]: + """ + Create an httpx event hook that overrides x-prefix-* headers from static prediction data. + + This hook converts numeric prediction values to categorical values (LOW/MEDIUM/HIGH) + and overrides the x-prefix-* headers set by the Dynamo prefix hook. 
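+
+    Concretely, ``remaining_calls.mean`` becomes ``x-prefix-total-requests``,
+    ``output_tokens.p90`` is bucketed into the OSL category, and
+    ``interarrival_ms.mean`` into the IAT category via the
+    ``_output_tokens_to_osl`` and ``_interarrival_ms_to_iat`` helpers.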
+ + Args: + prediction: The prediction data to inject + + Returns: + An async function suitable for use as an httpx event hook. + """ + # Pre-compute categorical values from prediction + total_requests = int(prediction.remaining_calls.mean) + osl = _output_tokens_to_osl(prediction.output_tokens.p90) + iat = _interarrival_ms_to_iat(prediction.interarrival_ms.mean) + + async def on_request(request): + """Override x-prefix-* headers with prediction-derived values.""" + request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-total-requests"] = str(total_requests) + request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-osl"] = osl + request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-iat"] = iat + + logger.debug( + "Overrode prefix headers from static prediction: total_requests=%d, osl=%s, iat=%s", + total_requests, + osl, + iat, + ) + + return on_request + + +def _create_dynamic_prediction_hook( + trie_lookup: "PredictionTrieLookup", ) -> Callable[["httpx.Request"], Coroutine[Any, Any, None]]: + """ + Create an httpx event hook that dynamically looks up predictions per request. + + This hook reads the current function path and call index from context, + looks up the prediction in the trie, and overrides the x-prefix-* headers + with values derived from the prediction. The numeric prediction values + are converted to categorical values (LOW/MEDIUM/HIGH) for consistency + with static configuration. + + When a prediction is found, this hook overrides: + - x-prefix-total-requests: from remaining_calls.mean + - x-prefix-osl: converted from output_tokens.p90 + - x-prefix-iat: converted from interarrival_ms.mean + + Args: + trie_lookup: The PredictionTrieLookup instance to query + + Returns: + An async function suitable for use as an httpx event hook. + """ + + async def on_request(request: "httpx.Request") -> None: + """Look up prediction from context and override x-prefix-* headers.""" + from nat.llm.prediction_context import get_call_tracker + + try: + ctx = Context.get() + path = ctx.function_path + + # Get call index for current parent function + call_index = 1 # default + active_fn = ctx.active_function + if active_fn and active_fn.function_id != "root": + tracker = get_call_tracker() + call_index = tracker.counts.get(active_fn.function_id, 1) + + # Look up prediction + prediction = trie_lookup.find(path, call_index) + + if prediction: + # Convert numeric predictions to categorical values and override headers + total_requests = int(prediction.remaining_calls.mean) + osl = _output_tokens_to_osl(prediction.output_tokens.p90) + iat = _interarrival_ms_to_iat(prediction.interarrival_ms.mean) + + request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-total-requests"] = str(total_requests) + request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-osl"] = osl + request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-iat"] = iat + + logger.debug( + "Overrode prefix headers from prediction: path=%s, call_index=%d, " + "total_requests=%d, osl=%s (tokens=%d), iat=%s (ms=%d)", + path, + call_index, + total_requests, + osl, + int(prediction.output_tokens.p90), + iat, + int(prediction.interarrival_ms.mean), + ) + else: + logger.debug("No prediction found for path=%s, call_index=%d; using static values", path, call_index) + + except Exception as e: + # Don't fail the request if prediction lookup fails + logger.warning("Failed to override prefix headers from prediction: %s", e) + + return on_request + + +def create_httpx_client_with_prediction_headers( + prediction: LLMCallPrediction, + prefix_template: str | None, + total_requests: int, + osl: 
str, + iat: str, + timeout: float = 600.0, +) -> "httpx.AsyncClient": + """ + Create an httpx.AsyncClient with both Dynamo prefix and prediction headers. + + Args: + prediction: Prediction data for this LLM call + prefix_template: Template string with {uuid} placeholder + total_requests: Expected number of requests for this prefix + osl: Output sequence length hint (LOW/MEDIUM/HIGH) + iat: Inter-arrival time hint (LOW/MEDIUM/HIGH) + timeout: HTTP request timeout in seconds + + Returns: + An httpx.AsyncClient configured with header injection. + """ + import httpx + + hooks: list[Callable] = [] + + # Add Dynamo prefix hook + prefix_hook = _create_dynamo_request_hook(prefix_template, total_requests, osl, iat) + hooks.append(prefix_hook) + + # Add prediction hook + prediction_hook = _create_prediction_request_hook(prediction) + hooks.append(prediction_hook) return httpx.AsyncClient( - event_hooks={"request": [request_hook]}, + event_hooks={"request": hooks}, timeout=httpx.Timeout(timeout), ) diff --git a/src/nat/llm/prediction_context.py b/src/nat/llm/prediction_context.py new file mode 100644 index 0000000000..91184baeda --- /dev/null +++ b/src/nat/llm/prediction_context.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Runtime context management for prediction trie lookups. + +Provides tracking of LLM call indices per function invocation, +enabling accurate lookups in the prediction trie at runtime. +""" + +from contextvars import ContextVar +from dataclasses import dataclass +from dataclasses import field + + +@dataclass +class LLMCallTracker: + """Tracks LLM call counts per function invocation.""" + + counts: dict[str, int] = field(default_factory=dict) + + def increment(self, parent_function_id: str) -> int: + """ + Increment and return the call index for this parent. + + Args: + parent_function_id: Unique ID of the parent function invocation + + Returns: + The call index (1-indexed) for this LLM call within the parent + """ + self.counts[parent_function_id] = self.counts.get(parent_function_id, 0) + 1 + return self.counts[parent_function_id] + + def reset(self, parent_function_id: str) -> None: + """ + Reset call count when a function invocation completes. + + Args: + parent_function_id: Unique ID of the parent function invocation + """ + self.counts.pop(parent_function_id, None) + + +# Thread/async-safe context variable for the call tracker +_llm_call_tracker: ContextVar[LLMCallTracker] = ContextVar("llm_call_tracker") + + +def get_call_tracker() -> LLMCallTracker: + """ + Get the LLMCallTracker for the current context. + + Creates a new tracker if one doesn't exist in the current context. 
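+
+    The tracker is stored in a ``ContextVar``, so separate async contexts
+    (for example, concurrent workflow runs) keep independent call counts.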
+ + Returns: + The LLMCallTracker for this context + """ + try: + return _llm_call_tracker.get() + except LookupError: + tracker = LLMCallTracker() + _llm_call_tracker.set(tracker) + return tracker diff --git a/src/nat/profiler/prediction_trie/__init__.py b/src/nat/profiler/prediction_trie/__init__.py new file mode 100644 index 0000000000..ca71b36a68 --- /dev/null +++ b/src/nat/profiler/prediction_trie/__init__.py @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode +from nat.profiler.prediction_trie.serialization import load_prediction_trie +from nat.profiler.prediction_trie.serialization import save_prediction_trie +from nat.profiler.prediction_trie.trie_builder import PredictionTrieBuilder + +# Note: PredictionTrieLookup is intentionally not re-exported here to avoid +# Sphinx cross-reference warnings. Import from trie_lookup submodule directly: +# from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup + +__all__ = [ + "LLMCallPrediction", + "PredictionMetrics", + "PredictionTrieBuilder", + "PredictionTrieNode", + "load_prediction_trie", + "save_prediction_trie", +] diff --git a/src/nat/profiler/prediction_trie/data_models.py b/src/nat/profiler/prediction_trie/data_models.py new file mode 100644 index 0000000000..bf20c0ac1b --- /dev/null +++ b/src/nat/profiler/prediction_trie/data_models.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
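+"""Pydantic models for the prediction trie.
+
+``PredictionTrieNode`` mirrors the function-call hierarchy: each node keys its
+children by function name and stores predictions per call index plus an
+any-index aggregate used as a fallback. ``LLMCallPrediction`` groups the three
+predicted quantities (remaining calls, interarrival time, output tokens), each
+summarized as ``PredictionMetrics``.
+"""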
+ +from __future__ import annotations + +from pydantic import BaseModel +from pydantic import Field + + +class PredictionMetrics(BaseModel): + """Aggregated statistics for a single metric from profiler data.""" + + sample_count: int = Field(default=0, description="Number of samples") + mean: float = Field(default=0.0, description="Mean value") + p50: float = Field(default=0.0, description="50th percentile (median)") + p90: float = Field(default=0.0, description="90th percentile") + p95: float = Field(default=0.0, description="95th percentile") + + +class LLMCallPrediction(BaseModel): + """Predictions for an LLM call at a given position in the call hierarchy.""" + + remaining_calls: PredictionMetrics = Field( + default_factory=PredictionMetrics, + description="How many more LLM calls are expected after this one", + ) + interarrival_ms: PredictionMetrics = Field( + default_factory=PredictionMetrics, + description="Expected time in milliseconds until the next LLM call", + ) + output_tokens: PredictionMetrics = Field( + default_factory=PredictionMetrics, + description="Expected output token count for this call", + ) + + +class PredictionTrieNode(BaseModel): + """A node in the prediction trie representing a function in the call hierarchy.""" + + name: str = Field(description="Function name at this level in the hierarchy") + children: dict[str, PredictionTrieNode] = Field( + default_factory=dict, + description="Child nodes keyed by function name", + ) + predictions_by_call_index: dict[int, LLMCallPrediction] = Field( + default_factory=dict, + description="Predictions keyed by call index (1-indexed)", + ) + predictions_any_index: LLMCallPrediction | None = Field( + default=None, + description="Fallback predictions aggregated across all call indices", + ) + + +# Rebuild model to handle forward references +PredictionTrieNode.model_rebuild() diff --git a/src/nat/profiler/prediction_trie/metrics_accumulator.py b/src/nat/profiler/prediction_trie/metrics_accumulator.py new file mode 100644 index 0000000000..e2313b17c3 --- /dev/null +++ b/src/nat/profiler/prediction_trie/metrics_accumulator.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
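+"""Streaming accumulator used by the prediction trie builder.
+
+Samples are buffered and reduced to ``PredictionMetrics`` on demand.
+Percentiles use linear interpolation between closest ranks: for ``n`` sorted
+samples, percentile ``p`` maps to rank ``k = (n - 1) * p / 100`` and the
+result interpolates between the neighboring samples.
+"""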
+ +import math + +from nat.profiler.prediction_trie.data_models import PredictionMetrics + + +class MetricsAccumulator: + """Accumulates samples and computes aggregated statistics.""" + + def __init__(self) -> None: + self._samples: list[float] = [] + + def add_sample(self, value: float) -> None: + """Add a sample value to the accumulator.""" + self._samples.append(value) + + def has_samples(self) -> bool: + """Return True if any samples have been added.""" + return len(self._samples) > 0 + + def compute_metrics(self) -> PredictionMetrics: + """Compute aggregated metrics from accumulated samples.""" + if not self._samples: + return PredictionMetrics() + + n = len(self._samples) + mean_val = sum(self._samples) / n + sorted_samples = sorted(self._samples) + + return PredictionMetrics( + sample_count=n, + mean=mean_val, + p50=self._percentile(sorted_samples, 50), + p90=self._percentile(sorted_samples, 90), + p95=self._percentile(sorted_samples, 95), + ) + + @staticmethod + def _percentile(sorted_data: list[float], pct: float) -> float: + """Compute percentile using linear interpolation.""" + if not sorted_data: + return 0.0 + if len(sorted_data) == 1: + return sorted_data[0] + k = (len(sorted_data) - 1) * (pct / 100.0) + f = math.floor(k) + c = math.ceil(k) + if f == c: + return sorted_data[int(k)] + return sorted_data[f] + (sorted_data[c] - sorted_data[f]) * (k - f) diff --git a/src/nat/profiler/prediction_trie/serialization.py b/src/nat/profiler/prediction_trie/serialization.py new file mode 100644 index 0000000000..d0a6a62350 --- /dev/null +++ b/src/nat/profiler/prediction_trie/serialization.py @@ -0,0 +1,114 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from datetime import UTC +from datetime import datetime +from pathlib import Path +from typing import Any + +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode + +CURRENT_VERSION = "1.0" + + +def save_prediction_trie( + trie: PredictionTrieNode, + path: Path, + workflow_name: str = "unknown", +) -> None: + """ + Save a prediction trie to a JSON file. + + Args: + trie: The prediction trie root node + path: Path to save the JSON file + workflow_name: Name of the workflow this trie was built from + """ + data = { + "version": CURRENT_VERSION, + "generated_at": datetime.now(UTC).isoformat(), + "workflow_name": workflow_name, + "root": _serialize_node(trie), + } + + with open(path, "w", encoding="utf-8") as f: + json.dump(data, f, indent=2) + + +def load_prediction_trie(path: Path) -> PredictionTrieNode: + """ + Load a prediction trie from a JSON file. 
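+
+    Only the ``root`` entry is deserialized; the top-level ``version`` and
+    metadata fields are ignored on load.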
+ + Args: + path: Path to the JSON file + + Returns: + The deserialized prediction trie root node + """ + with open(path, encoding="utf-8") as f: + data = json.load(f) + + return _deserialize_node(data["root"]) + + +def _serialize_node(node: PredictionTrieNode) -> dict[str, Any]: + """Serialize a trie node to a dictionary.""" + result: dict[str, Any] = { + "name": node.name, + "predictions_by_call_index": { + str(k): v.model_dump() + for k, v in node.predictions_by_call_index.items() + }, + "predictions_any_index": node.predictions_any_index.model_dump() if node.predictions_any_index else None, + "children": { + k: _serialize_node(v) + for k, v in node.children.items() + }, + } + return result + + +def _deserialize_node(data: dict[str, Any]) -> PredictionTrieNode: + """Deserialize a dictionary to a trie node.""" + predictions_by_call_index: dict[int, LLMCallPrediction] = {} + for k, v in data.get("predictions_by_call_index", {}).items(): + predictions_by_call_index[int(k)] = LLMCallPrediction( + remaining_calls=PredictionMetrics(**v["remaining_calls"]), + interarrival_ms=PredictionMetrics(**v["interarrival_ms"]), + output_tokens=PredictionMetrics(**v["output_tokens"]), + ) + + predictions_any_index = None + if data.get("predictions_any_index"): + v = data["predictions_any_index"] + predictions_any_index = LLMCallPrediction( + remaining_calls=PredictionMetrics(**v["remaining_calls"]), + interarrival_ms=PredictionMetrics(**v["interarrival_ms"]), + output_tokens=PredictionMetrics(**v["output_tokens"]), + ) + + children: dict[str, PredictionTrieNode] = {} + for k, v in data.get("children", {}).items(): + children[k] = _deserialize_node(v) + + return PredictionTrieNode( + name=data["name"], + predictions_by_call_index=predictions_by_call_index, + predictions_any_index=predictions_any_index, + children=children, + ) diff --git a/src/nat/profiler/prediction_trie/trie_builder.py b/src/nat/profiler/prediction_trie/trie_builder.py new file mode 100644 index 0000000000..3b2f285a53 --- /dev/null +++ b/src/nat/profiler/prediction_trie/trie_builder.py @@ -0,0 +1,199 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
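+"""Build a prediction trie from profiler traces.
+
+``add_trace`` extracts one ``LLMCallContext`` per LLM_END event (function
+path, call index within the parent, remaining calls in the trace, time until
+the next LLM_START, output tokens) and feeds it into accumulators keyed by
+every prefix of the path. ``build`` then materializes the accumulated
+statistics into a ``PredictionTrieNode`` tree.
+"""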
+ +from __future__ import annotations + +from collections import defaultdict +from dataclasses import dataclass +from dataclasses import field + +from nat.data_models.intermediate_step import IntermediateStep +from nat.data_models.intermediate_step import IntermediateStepType +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionTrieNode +from nat.profiler.prediction_trie.metrics_accumulator import MetricsAccumulator + + +@dataclass +class LLMCallContext: + """Context for a single LLM call extracted from a trace.""" + + path: list[str] + call_index: int + remaining_calls: int + time_to_next_ms: float | None + output_tokens: int + + +@dataclass +class _NodeAccumulators: + """Accumulators for a single trie node.""" + + remaining_calls: dict[int, MetricsAccumulator] = field(default_factory=lambda: defaultdict(MetricsAccumulator)) + interarrival_ms: dict[int, MetricsAccumulator] = field(default_factory=lambda: defaultdict(MetricsAccumulator)) + output_tokens: dict[int, MetricsAccumulator] = field(default_factory=lambda: defaultdict(MetricsAccumulator)) + # For aggregated stats across all call indices + all_remaining_calls: MetricsAccumulator = field(default_factory=MetricsAccumulator) + all_interarrival_ms: MetricsAccumulator = field(default_factory=MetricsAccumulator) + all_output_tokens: MetricsAccumulator = field(default_factory=MetricsAccumulator) + + +class PredictionTrieBuilder: + """Builds a prediction trie from profiler execution traces.""" + + def __init__(self) -> None: + # Map from path tuple to accumulators + self._node_accumulators: dict[tuple[str, ...], _NodeAccumulators] = defaultdict(_NodeAccumulators) + + def add_trace(self, steps: list[IntermediateStep]) -> None: + """Process a single execution trace and update accumulators.""" + contexts = self._extract_llm_contexts(steps) + for ctx in contexts: + self._update_accumulators(ctx) + + def _extract_llm_contexts(self, steps: list[IntermediateStep]) -> list[LLMCallContext]: + """Extract LLM call contexts from a trace.""" + # Sort steps by timestamp + sorted_steps = sorted(steps, key=lambda s: s.event_timestamp) + + # Find all LLM_END events + llm_ends = [s for s in sorted_steps if s.event_type == IntermediateStepType.LLM_END] + + # Find all LLM_START events for interarrival time calculation + llm_starts = [s for s in sorted_steps if s.event_type == IntermediateStepType.LLM_START] + + # Track call index per parent function + call_counts: dict[str, int] = defaultdict(int) + contexts: list[LLMCallContext] = [] + + for i, end_step in enumerate(llm_ends): + # Build path from function ancestry + path = self._build_path(end_step) + + # Determine call index within parent + parent_key = end_step.function_ancestry.function_id + call_counts[parent_key] += 1 + call_index = call_counts[parent_key] + + # Remaining calls in this trace + remaining = len(llm_ends) - i - 1 + + # Time to next LLM start (if any) + time_to_next_ms: float | None = None + current_end_time = end_step.event_timestamp + # Find next LLM_START after this LLM_END + for start_step in llm_starts: + if start_step.event_timestamp > current_end_time: + time_to_next_ms = (start_step.event_timestamp - current_end_time) * 1000.0 + break + + # Output tokens + output_tokens = 0 + if end_step.usage_info and end_step.usage_info.token_usage: + output_tokens = end_step.usage_info.token_usage.completion_tokens or 0 + + contexts.append( + LLMCallContext( + path=path, + call_index=call_index, + 
remaining_calls=remaining, + time_to_next_ms=time_to_next_ms, + output_tokens=output_tokens, + )) + + return contexts + + def _build_path(self, step: IntermediateStep) -> list[str]: + """Build the function path from ancestry.""" + path: list[str] = [] + ancestry = step.function_ancestry + + # Walk up the ancestry chain + if ancestry.parent_name: + path.append(ancestry.parent_name) + path.append(ancestry.function_name) + + return path + + def _update_accumulators(self, ctx: LLMCallContext) -> None: + """Update accumulators at every node along the path.""" + # Update root node + root_key: tuple[str, ...] = () + self._add_to_accumulators(root_key, ctx) + + # Update each node along the path + for i in range(len(ctx.path)): + path_key = tuple(ctx.path[:i + 1]) + self._add_to_accumulators(path_key, ctx) + + def _add_to_accumulators(self, path_key: tuple[str, ...], ctx: LLMCallContext) -> None: + """Add context data to accumulators for a specific path.""" + accs = self._node_accumulators[path_key] + + # By call index + accs.remaining_calls[ctx.call_index].add_sample(float(ctx.remaining_calls)) + accs.output_tokens[ctx.call_index].add_sample(float(ctx.output_tokens)) + if ctx.time_to_next_ms is not None: + accs.interarrival_ms[ctx.call_index].add_sample(ctx.time_to_next_ms) + + # Aggregated across all indices + accs.all_remaining_calls.add_sample(float(ctx.remaining_calls)) + accs.all_output_tokens.add_sample(float(ctx.output_tokens)) + if ctx.time_to_next_ms is not None: + accs.all_interarrival_ms.add_sample(ctx.time_to_next_ms) + + def build(self) -> PredictionTrieNode: + """Build the final prediction trie from accumulated data.""" + root = PredictionTrieNode(name="root") + + for path_key, accs in self._node_accumulators.items(): + node = self._get_or_create_node(root, path_key) + self._populate_node_predictions(node, accs) + + return root + + def _get_or_create_node(self, root: PredictionTrieNode, path_key: tuple[str, ...]) -> PredictionTrieNode: + """Navigate to or create a node at the given path.""" + if not path_key: + return root + + current = root + for name in path_key: + if name not in current.children: + current.children[name] = PredictionTrieNode(name=name) + current = current.children[name] + return current + + def _populate_node_predictions(self, node: PredictionTrieNode, accs: _NodeAccumulators) -> None: + """Populate a node with computed predictions from accumulators.""" + # Predictions by call index + all_indices = set(accs.remaining_calls.keys()) | set(accs.interarrival_ms.keys()) | set( + accs.output_tokens.keys()) + + for idx in all_indices: + prediction = LLMCallPrediction( + remaining_calls=accs.remaining_calls[idx].compute_metrics(), + interarrival_ms=accs.interarrival_ms[idx].compute_metrics(), + output_tokens=accs.output_tokens[idx].compute_metrics(), + ) + node.predictions_by_call_index[idx] = prediction + + # Aggregated predictions + if accs.all_remaining_calls.has_samples(): + node.predictions_any_index = LLMCallPrediction( + remaining_calls=accs.all_remaining_calls.compute_metrics(), + interarrival_ms=accs.all_interarrival_ms.compute_metrics(), + output_tokens=accs.all_output_tokens.compute_metrics(), + ) diff --git a/src/nat/profiler/prediction_trie/trie_lookup.py b/src/nat/profiler/prediction_trie/trie_lookup.py new file mode 100644 index 0000000000..455d645446 --- /dev/null +++ b/src/nat/profiler/prediction_trie/trie_lookup.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionTrieNode + + +class PredictionTrieLookup: + """Looks up predictions in a prediction trie with graceful fallback.""" + + def __init__(self, root: PredictionTrieNode) -> None: + self._root = root + + def find(self, path: list[str], call_index: int) -> LLMCallPrediction | None: + """ + Find the best matching prediction for the given path and call index. + + Walks the trie as far as possible along the path, then returns the deepest + match. Falls back to aggregated predictions when exact call_index isn't found. + + Args: + path: Function ancestry path (e.g., ["my_workflow", "react_agent"]) + call_index: The Nth LLM call within the current parent function + + Returns: + Best matching prediction, or None if trie is empty + """ + node = self._root + deepest_match: LLMCallPrediction | None = None + + # Check root node first + deepest_match = self._get_prediction(node, call_index) or deepest_match + + # Walk the trie as far as we can match + for func_name in path: + if func_name not in node.children: + break + node = node.children[func_name] + # Update deepest match at each level + match = self._get_prediction(node, call_index) + if match is not None: + deepest_match = match + + return deepest_match + + def _get_prediction(self, node: PredictionTrieNode, call_index: int) -> LLMCallPrediction | None: + """Get prediction from node, preferring exact call_index, falling back to aggregated.""" + if call_index in node.predictions_by_call_index: + return node.predictions_by_call_index[call_index] + return node.predictions_any_index diff --git a/src/nat/profiler/profile_runner.py b/src/nat/profiler/profile_runner.py index 0ac72f5deb..d18e7ff99b 100644 --- a/src/nat/profiler/profile_runner.py +++ b/src/nat/profiler/profile_runner.py @@ -270,6 +270,25 @@ async def run(self, all_steps: list[list[IntermediateStep]]) -> ProfilerResults: json.dump(workflow_profiling_metrics, f, indent=2) logger.info("Wrote workflow profiling metrics to: %s", profiling_metrics_path) + if self.profile_config.prediction_trie.enable: + # ------------------------------------------------------------ + # Build and save prediction trie + # ------------------------------------------------------------ + from nat.profiler.prediction_trie import PredictionTrieBuilder + from nat.profiler.prediction_trie import save_prediction_trie + + logger.info("Building prediction trie from traces...") + trie_builder = PredictionTrieBuilder() + for trace in all_steps: + trie_builder.add_trace(trace) + + prediction_trie = trie_builder.build() + + if self.write_output: + trie_path = os.path.join(self.output_dir, self.profile_config.prediction_trie.output_filename) + save_prediction_trie(prediction_trie, Path(trie_path), workflow_name="profiled_workflow") + logger.info("Wrote prediction trie to: %s", trie_path) + if 
self.profile_config.token_usage_forecast: # ------------------------------------------------------------ # Fit forecasting model and save diff --git a/tests/nat/builder/test_call_tracker_integration.py b/tests/nat/builder/test_call_tracker_integration.py new file mode 100644 index 0000000000..8b3d74c27d --- /dev/null +++ b/tests/nat/builder/test_call_tracker_integration.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nat.builder.context import Context +from nat.data_models.intermediate_step import IntermediateStepPayload +from nat.data_models.intermediate_step import IntermediateStepType +from nat.llm.prediction_context import get_call_tracker + + +def test_llm_start_increments_call_tracker(): + """Test that pushing an LLM_START step increments the call tracker.""" + ctx = Context.get() + step_manager = ctx.intermediate_step_manager + + with ctx.push_active_function("test_agent", input_data=None): + active_fn = ctx.active_function + tracker = get_call_tracker() + + # Initially no count for this function + assert tracker.counts.get(active_fn.function_id, 0) == 0 + + # Push LLM_START + step_manager.push_intermediate_step( + IntermediateStepPayload( + UUID="llm-call-1", + event_type=IntermediateStepType.LLM_START, + name="test-model", + )) + + # Call tracker should be incremented + assert tracker.counts.get(active_fn.function_id) == 1 + + # Push another LLM_START + step_manager.push_intermediate_step( + IntermediateStepPayload( + UUID="llm-call-2", + event_type=IntermediateStepType.LLM_START, + name="test-model", + )) + + # Should be 2 now + assert tracker.counts.get(active_fn.function_id) == 2 + + +def test_non_llm_start_does_not_increment_tracker(): + """Test that non-LLM_START events don't increment the tracker.""" + ctx = Context.get() + step_manager = ctx.intermediate_step_manager + + with ctx.push_active_function("test_agent_2", input_data=None): + active_fn = ctx.active_function + tracker = get_call_tracker() + + initial_count = tracker.counts.get(active_fn.function_id, 0) + + # Push TOOL_START (should not increment) + step_manager.push_intermediate_step( + IntermediateStepPayload( + UUID="tool-call-1", + event_type=IntermediateStepType.TOOL_START, + name="test-tool", + )) + + # Count should be unchanged + assert tracker.counts.get(active_fn.function_id, 0) == initial_count diff --git a/tests/nat/builder/test_function_path_stack.py b/tests/nat/builder/test_function_path_stack.py new file mode 100644 index 0000000000..508cf27f15 --- /dev/null +++ b/tests/nat/builder/test_function_path_stack.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nat.builder.context import Context +from nat.builder.context import ContextState + + +def test_function_path_stack_default_empty(): + """Test that function_path_stack starts empty.""" + state = ContextState.get() + # Reset to test fresh state + state._function_path_stack.set(None) + + path = state.function_path_stack.get() + assert path == [] + + +def test_function_path_stack_can_be_set(): + """Test that function_path_stack can be set and retrieved.""" + state = ContextState.get() + state.function_path_stack.set(["workflow", "agent"]) + + path = state.function_path_stack.get() + assert path == ["workflow", "agent"] + + +def test_push_active_function_updates_path_stack(): + """Test that push_active_function pushes/pops from path stack.""" + ctx = Context.get() + state = ctx._context_state + + # Reset path stack + state._function_path_stack.set(None) + + # Initially empty + assert state.function_path_stack.get() == [] + + with ctx.push_active_function("my_workflow", input_data=None): + assert state.function_path_stack.get() == ["my_workflow"] + + with ctx.push_active_function("react_agent", input_data=None): + assert state.function_path_stack.get() == ["my_workflow", "react_agent"] + + with ctx.push_active_function("tool_call", input_data=None): + assert state.function_path_stack.get() == ["my_workflow", "react_agent", "tool_call"] + + # After tool_call exits + assert state.function_path_stack.get() == ["my_workflow", "react_agent"] + + # After react_agent exits + assert state.function_path_stack.get() == ["my_workflow"] + + # After workflow exits + assert state.function_path_stack.get() == [] + + +def test_context_function_path_property(): + """Test that Context.function_path returns a copy of the path stack.""" + ctx = Context.get() + state = ctx._context_state + + # Reset path stack + state._function_path_stack.set(None) + + with ctx.push_active_function("workflow", input_data=None): + with ctx.push_active_function("agent", input_data=None): + path = ctx.function_path + assert path == ["workflow", "agent"] + + # Verify it's a copy (modifications don't affect original) + path.append("modified") + assert ctx.function_path == ["workflow", "agent"] diff --git a/tests/nat/llm/test_dynamic_prediction_hook.py b/tests/nat/llm/test_dynamic_prediction_hook.py new file mode 100644 index 0000000000..08702b1a04 --- /dev/null +++ b/tests/nat/llm/test_dynamic_prediction_hook.py @@ -0,0 +1,188 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
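The path-stack tests above pin down three behaviors: an unset stack reads as empty, reads return copies, and exiting a function restores the previous stack even on error. A minimal sketch of that ContextVar contract (an assumption for illustration, not the shipped `ContextState`):

```python
from contextlib import contextmanager
from contextvars import ContextVar

_function_path_stack = ContextVar("function_path_stack", default=None)


def current_path() -> list[str]:
    # An unset (None) stack reads as empty; return a copy so callers
    # cannot mutate shared context state.
    return list(_function_path_stack.get() or [])


@contextmanager
def push_function(name: str):
    # Push on entry; reset() restores the previous value on exit,
    # even when the body raises.
    token = _function_path_stack.set(current_path() + [name])
    try:
        yield
    finally:
        _function_path_stack.reset(token)
```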
+ +import pytest + +from nat.builder.context import Context +from nat.llm.dynamo_llm import _create_dynamic_prediction_hook +from nat.llm.dynamo_llm import create_httpx_client_with_dynamo_hooks +from nat.llm.prediction_context import get_call_tracker +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode +from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup + + +@pytest.fixture(name="sample_trie_lookup") +def fixture_sample_trie_lookup() -> PredictionTrieLookup: + """Create a sample trie lookup for testing.""" + prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + + agent_node = PredictionTrieNode( + name="react_agent", + predictions_by_call_index={ + 1: prediction, 2: prediction + }, + predictions_any_index=prediction, + ) + + workflow_node = PredictionTrieNode( + name="my_workflow", + children={"react_agent": agent_node}, + predictions_any_index=prediction, + ) + + root = PredictionTrieNode( + name="root", + children={"my_workflow": workflow_node}, + predictions_any_index=prediction, + ) + + return PredictionTrieLookup(root) + + +class MockRequest: + """Mock httpx.Request for testing.""" + + def __init__(self): + self.headers = {} + + +async def test_dynamic_hook_injects_headers(sample_trie_lookup): + """Test that dynamic hook overrides x-prefix-* headers based on context predictions.""" + ctx = Context.get() + state = ctx._context_state + + # Reset state + state._function_path_stack.set(None) + + hook = _create_dynamic_prediction_hook(sample_trie_lookup) + + with ctx.push_active_function("my_workflow", input_data=None): + with ctx.push_active_function("react_agent", input_data=None): + # Simulate LLM call tracker increment (normally done by step manager) + tracker = get_call_tracker() + tracker.increment(ctx.active_function.function_id) + + request = MockRequest() + await hook(request) + + # Prediction values are converted to x-prefix-* headers: + # - remaining_calls.mean=3.0 -> x-prefix-total-requests="3" + # - output_tokens.p90=200.0 -> x-prefix-osl="LOW" (< 256) + # - interarrival_ms.mean=500.0 -> x-prefix-iat="HIGH" (>= 500) + assert "x-prefix-total-requests" in request.headers + assert request.headers["x-prefix-total-requests"] == "3" + assert request.headers["x-prefix-osl"] == "LOW" + assert request.headers["x-prefix-iat"] == "HIGH" + + +async def test_dynamic_hook_uses_root_fallback(sample_trie_lookup): + """Test that hook falls back to root prediction for unknown paths.""" + ctx = Context.get() + state = ctx._context_state + + # Reset state + state._function_path_stack.set(None) + + hook = _create_dynamic_prediction_hook(sample_trie_lookup) + + with ctx.push_active_function("unknown_workflow", input_data=None): + tracker = get_call_tracker() + tracker.increment(ctx.active_function.function_id) + + request = MockRequest() + await hook(request) + + # Should fall back to root aggregated predictions + assert "x-prefix-total-requests" in request.headers + + +async def test_dynamic_hook_handles_empty_context(sample_trie_lookup): + """Test that hook handles missing context gracefully.""" + ctx = Context.get() + state = ctx._context_state + 
+ # Reset state to empty + state._function_path_stack.set(None) + state._active_function.set(None) + + hook = _create_dynamic_prediction_hook(sample_trie_lookup) + + request = MockRequest() + # Should not raise an exception + await hook(request) + + # Should still inject headers from root fallback + assert "x-prefix-total-requests" in request.headers + + +async def test_dynamic_hook_no_prediction_found(): + """Test that hook handles case where no prediction is found.""" + # Create empty trie with no predictions + empty_root = PredictionTrieNode(name="root") + empty_trie = PredictionTrieLookup(empty_root) + + ctx = Context.get() + state = ctx._context_state + + # Reset state + state._function_path_stack.set(None) + + hook = _create_dynamic_prediction_hook(empty_trie) + + with ctx.push_active_function("some_function", input_data=None): + request = MockRequest() + await hook(request) + + # Headers should not be overridden when no prediction found + # (the static Dynamo hook would set them, but this hook runs after) + assert "x-prefix-total-requests" not in request.headers + + +async def test_client_includes_prediction_hook_when_lookup_provided(sample_trie_lookup): + """Test that client includes prediction hook when trie_lookup is provided.""" + client = create_httpx_client_with_dynamo_hooks( + prefix_template="test-{uuid}", + total_requests=10, + osl="MEDIUM", + iat="LOW", + prediction_lookup=sample_trie_lookup, + ) + + # Should have 2 hooks: dynamo prefix + prediction + assert len(client.event_hooks["request"]) == 2 + + await client.aclose() + + +async def test_client_works_without_prediction_lookup(): + """Test that client works when prediction_lookup is None.""" + client = create_httpx_client_with_dynamo_hooks( + prefix_template="test-{uuid}", + total_requests=10, + osl="MEDIUM", + iat="LOW", + prediction_lookup=None, + ) + + # Should have 1 hook: dynamo prefix only + assert len(client.event_hooks["request"]) == 1 + + await client.aclose() diff --git a/tests/nat/llm/test_dynamo_llm.py b/tests/nat/llm/test_dynamo_llm.py index 6251c234ea..3a141383c6 100644 --- a/tests/nat/llm/test_dynamo_llm.py +++ b/tests/nat/llm/test_dynamo_llm.py @@ -135,6 +135,7 @@ def test_get_dynamo_field_names(self): "prefix_osl", "prefix_iat", "request_timeout", + "prediction_trie_path", }) assert field_names == expected @@ -149,26 +150,37 @@ def test_get_dynamo_field_names(self): class TestDynamoPrefixContext: """Tests for DynamoPrefixContext singleton class.""" - def test_set_and_get_prefix_id(self): - """Test setting and getting prefix ID.""" - # Ensure clean state + def test_auto_generates_depth_based_prefix(self): + """Test that get() auto-generates a depth-based prefix when no override is set.""" DynamoPrefixContext.clear() - assert DynamoPrefixContext.get() is None - # Set and get + # get() always returns a value - auto-generated if no override + prefix = DynamoPrefixContext.get() + assert prefix is not None + assert "-d0" in prefix # Depth 0 at root level + + def test_set_and_get_override_prefix_id(self): + """Test setting and getting an override prefix ID.""" + DynamoPrefixContext.clear() + + # Set override DynamoPrefixContext.set("test-prefix-123") assert DynamoPrefixContext.get() == "test-prefix-123" # Clean up DynamoPrefixContext.clear() - def test_clear_prefix_id(self): - """Test clearing prefix ID.""" + def test_clear_removes_override_but_auto_generates(self): + """Test that clear() removes override but get() still returns auto-generated value.""" DynamoPrefixContext.set("test-prefix-456") assert 
DynamoPrefixContext.get() == "test-prefix-456" DynamoPrefixContext.clear() - assert DynamoPrefixContext.get() is None + # After clear, get() returns auto-generated depth-based prefix + prefix = DynamoPrefixContext.get() + assert prefix is not None + assert prefix != "test-prefix-456" + assert "-d0" in prefix def test_overwrite_prefix_id(self): """Test that setting a new prefix ID overwrites the old one.""" @@ -183,18 +195,19 @@ def test_overwrite_prefix_id(self): DynamoPrefixContext.clear() def test_scope_context_manager(self): - """Test the scope context manager for automatic cleanup.""" + """Test the scope context manager with override prefix.""" DynamoPrefixContext.clear() - assert DynamoPrefixContext.get() is None with DynamoPrefixContext.scope("scoped-prefix-789"): assert DynamoPrefixContext.get() == "scoped-prefix-789" - # Should be cleared after exiting context - assert DynamoPrefixContext.get() is None + # After exiting scope, returns to auto-generated + prefix = DynamoPrefixContext.get() + assert prefix != "scoped-prefix-789" + assert "-d0" in prefix def test_scope_context_manager_cleanup_on_exception(self): - """Test that scope context manager clears prefix ID even on exception.""" + """Test that scope context manager restores state even on exception.""" DynamoPrefixContext.clear() with pytest.raises(ValueError): @@ -202,22 +215,31 @@ def test_scope_context_manager_cleanup_on_exception(self): assert DynamoPrefixContext.get() == "error-prefix" raise ValueError("Test exception") - # Should still be cleared after exception - assert DynamoPrefixContext.get() is None + # After exception, returns to auto-generated + prefix = DynamoPrefixContext.get() + assert prefix != "error-prefix" + assert "-d0" in prefix - def test_scope_nested_replaces_then_clears(self): - """Test that nested scopes work but outer scope is lost after inner exits.""" + def test_scope_nested_restores_outer(self): + """Test that nested scopes properly restore outer scope value.""" DynamoPrefixContext.clear() with DynamoPrefixContext.scope("outer"): assert DynamoPrefixContext.get() == "outer" with DynamoPrefixContext.scope("inner"): assert DynamoPrefixContext.get() == "inner" - # After inner scope exits, it clears - outer value is lost - assert DynamoPrefixContext.get() is None + # After inner scope exits, outer value is restored + assert DynamoPrefixContext.get() == "outer" - # Still None after outer exits - assert DynamoPrefixContext.get() is None + # After outer scope exits, returns to auto-generated + prefix = DynamoPrefixContext.get() + assert prefix != "outer" + assert "-d0" in prefix + + def test_is_set_always_true(self): + """Test that is_set() always returns True since IDs are auto-generated.""" + DynamoPrefixContext.clear() + assert DynamoPrefixContext.is_set() is True # --------------------------------------------------------------------------- @@ -252,14 +274,14 @@ async def test_hook_injects_headers(self): await hook(mock_request) assert f"{LLMHeaderPrefix.DYNAMO.value}-id" in mock_request.headers - assert mock_request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-id"].startswith("test-") + assert "-d0" in mock_request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-id"] assert mock_request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-total-requests"] == "15" assert mock_request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-osl"] == "HIGH" assert mock_request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-iat"] == "LOW" @pytest.mark.asyncio async def test_hook_uses_context_prefix_id(self): - """Test that the hook uses context 
variable prefix ID when set.""" + """Test that the hook uses context override prefix ID when set.""" hook = _create_dynamo_request_hook( prefix_template="template-{uuid}", total_requests=10, @@ -267,7 +289,7 @@ async def test_hook_uses_context_prefix_id(self): iat="MEDIUM", ) - # Set context prefix ID + # Set context override prefix ID DynamoPrefixContext.set("context-prefix-abc") mock_request = MagicMock() @@ -279,12 +301,12 @@ async def test_hook_uses_context_prefix_id(self): assert mock_request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-id"] == "context-prefix-abc" @pytest.mark.asyncio - async def test_hook_uses_same_id_for_all_requests(self): - """Test that the hook uses the same prefix ID for all requests from the same client. + async def test_hook_uses_same_id_for_same_depth(self): + """Test that the hook uses the same prefix ID for all requests at the same depth. This ensures Dynamo's KV cache optimization works across multi-turn conversations. - All requests from the same client (created with the same hook) should share - the same prefix ID to enable KV cache reuse. + All requests at the same depth within a workflow run should share the same + prefix ID to enable KV cache reuse. """ hook = _create_dynamo_request_hook( prefix_template="session-{uuid}", @@ -300,14 +322,14 @@ async def test_hook_uses_same_id_for_all_requests(self): await hook(mock_request) prefix_ids.add(mock_request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-id"]) - # All requests should share the SAME prefix ID (for KV cache optimization) + # All requests at the same depth should share the SAME prefix ID assert len(prefix_ids) == 1 - # And it should start with our template - assert list(prefix_ids)[0].startswith("session-") + # Should contain depth marker + assert "-d0" in list(prefix_ids)[0] @pytest.mark.asyncio - async def test_different_hooks_have_different_ids(self): - """Test that different hooks (different clients) get different prefix IDs.""" + async def test_hooks_share_id_at_same_depth(self): + """Test that multiple hooks share the same prefix ID at the same depth.""" prefix_ids = set() for _ in range(5): hook = _create_dynamo_request_hook( @@ -321,12 +343,12 @@ async def test_different_hooks_have_different_ids(self): await hook(mock_request) prefix_ids.add(mock_request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-id"]) - # Different hooks should have different prefix IDs - assert len(prefix_ids) == 5 + # All hooks at the same depth share the same prefix ID within a context + assert len(prefix_ids) == 1 @pytest.mark.asyncio - async def test_hook_default_prefix_template(self): - """Test that the hook uses default prefix format when template is None.""" + async def test_hook_uses_depth_based_prefix(self): + """Test that the hook uses depth-based prefix format.""" hook = _create_dynamo_request_hook( prefix_template=None, total_requests=10, @@ -339,8 +361,9 @@ async def test_hook_default_prefix_template(self): await hook(mock_request) - # Should use default "nat-dynamo-{id}" format - assert mock_request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-id"].startswith("nat-dynamo-") + # Should use depth-based format "{workflow_id}-d{depth}" + prefix_id = mock_request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-id"] + assert "-d0" in prefix_id # Depth 0 at root level @pytest.mark.asyncio async def test_hook_normalizes_case(self): @@ -361,8 +384,8 @@ async def test_hook_normalizes_case(self): assert mock_request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-iat"] == "HIGH" @pytest.mark.asyncio - async def 
test_hook_template_without_uuid_placeholder(self): - """Test that a template without {uuid} placeholder uses template as-is.""" + async def test_hook_with_override_ignores_depth(self): + """Test that setting an override prefix uses it instead of depth-based ID.""" hook = _create_dynamo_request_hook( prefix_template="static-prefix-no-uuid", total_requests=10, @@ -370,12 +393,15 @@ async def test_hook_template_without_uuid_placeholder(self): iat="MEDIUM", ) + # Set override + DynamoPrefixContext.set("my-override-prefix") + mock_request = MagicMock() mock_request.headers = {} await hook(mock_request) - # Template used as-is when no {uuid} placeholder - assert mock_request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-id"] == "static-prefix-no-uuid" + # Override prefix is used + assert mock_request.headers[f"{LLMHeaderPrefix.DYNAMO.value}-id"] == "my-override-prefix" # --------------------------------------------------------------------------- diff --git a/tests/nat/llm/test_dynamo_prediction_headers.py b/tests/nat/llm/test_dynamo_prediction_headers.py new file mode 100644 index 0000000000..9f581a3181 --- /dev/null +++ b/tests/nat/llm/test_dynamo_prediction_headers.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
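The next module asserts how prediction values are bucketed into `x-prefix-*` headers. A hedged sketch of that mapping: only the boundaries actually asserted in these tests are grounded (p90 output tokens < 256 maps to "LOW"; a mean interarrival of 500 ms maps to "HIGH"; 300 ms maps to "MEDIUM"), and the remaining cut-offs are illustrative guesses:

```python
def bucket_osl(p90_output_tokens: float) -> str:
    # Grounded: the tests assert p90 values of 120-200 tokens map to "LOW",
    # with "< 256" given as the boundary in their comments.
    if p90_output_tokens < 256:
        return "LOW"
    return "MEDIUM" if p90_output_tokens < 1024 else "HIGH"  # assumed cut-off


def bucket_iat(mean_interarrival_ms: float) -> str:
    # Grounded: 500 ms maps to "HIGH" (>= 500) and 300 ms to "MEDIUM" (100-500).
    if mean_interarrival_ms >= 500:
        return "HIGH"
    if mean_interarrival_ms >= 100:
        return "MEDIUM"
    return "LOW"  # assumed: sub-100 ms traffic
```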
+ +from nat.llm.dynamo_llm import create_httpx_client_with_prediction_headers +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics + + +async def test_prediction_headers_injected(): + """Test that prediction headers are injected into requests.""" + prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + + # Create a mock request to capture headers + captured_headers = {} + + async def capture_hook(request): + captured_headers.update(dict(request.headers)) + + client = create_httpx_client_with_prediction_headers( + prediction=prediction, + prefix_template="test-{uuid}", + total_requests=10, + osl="MEDIUM", + iat="LOW", + ) + + # Add our capture hook + client.event_hooks["request"].append(capture_hook) + + # Make a test request (will fail, but headers will be captured) + try: + await client.post("http://localhost:1/test", json={}) + except Exception: + pass + + # Prediction hook overrides x-prefix-* headers with prediction-derived values + # remaining_calls.mean=3.0 → x-prefix-total-requests="3" + assert "x-prefix-total-requests" in captured_headers + assert captured_headers["x-prefix-total-requests"] == "3" + # output_tokens.p90=200.0 (< 256) → x-prefix-osl="LOW" + assert "x-prefix-osl" in captured_headers + assert captured_headers["x-prefix-osl"] == "LOW" + # interarrival_ms.mean=500.0 (>= 500) → x-prefix-iat="HIGH" + assert "x-prefix-iat" in captured_headers + assert captured_headers["x-prefix-iat"] == "HIGH" + + await client.aclose() diff --git a/tests/nat/llm/test_dynamo_prediction_trie.py b/tests/nat/llm/test_dynamo_prediction_trie.py new file mode 100644 index 0000000000..043955192f --- /dev/null +++ b/tests/nat/llm/test_dynamo_prediction_trie.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
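Before the config-level tests, a small worked example of the fallback chain that `PredictionTrieLookup.find` (defined earlier in this patch) implements; the nodes and empty metrics here are placeholders chosen only to show the routing:

```python
from nat.profiler.prediction_trie.data_models import LLMCallPrediction
from nat.profiler.prediction_trie.data_models import PredictionTrieNode
from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup

agg = LLMCallPrediction()  # empty metrics are enough to show the routing
agent = PredictionTrieNode(name="react_agent", predictions_any_index=agg)
workflow = PredictionTrieNode(name="my_workflow", children={"react_agent": agent})
root = PredictionTrieNode(name="root", children={"my_workflow": workflow}, predictions_any_index=agg)

lookup = PredictionTrieLookup(root)
# No call_index 7 anywhere: the deepest node's any-index prediction wins.
assert lookup.find(["my_workflow", "react_agent"], call_index=7) is agg
# Unknown path: falls all the way back to the root's aggregated prediction.
assert lookup.find(["unknown_workflow"], call_index=1) is agg
```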
+ +import tempfile +from pathlib import Path + +import pytest + +from nat.llm.dynamo_llm import DynamoModelConfig +from nat.profiler.prediction_trie import PredictionTrieNode +from nat.profiler.prediction_trie import save_prediction_trie +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics + + +@pytest.fixture(name="trie_file") +def fixture_trie_file() -> Path: + """Create a temporary trie file for testing.""" + prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + + root = PredictionTrieNode( + name="root", + predictions_by_call_index={1: prediction}, + predictions_any_index=prediction, + ) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + path = Path(f.name) + + save_prediction_trie(root, path) + yield path + path.unlink(missing_ok=True) + + +def test_dynamo_config_with_trie_path(trie_file): + """Test that DynamoModelConfig accepts prediction_trie_path.""" + config = DynamoModelConfig( + base_url="http://localhost:8000", + model_name="test-model", + api_key="test-key", + prediction_trie_path=str(trie_file), + ) + + assert config.prediction_trie_path == str(trie_file) + assert "prediction_trie_path" in DynamoModelConfig.get_dynamo_field_names() + + +def test_dynamo_config_without_trie_path(): + """Test that DynamoModelConfig works without prediction_trie_path.""" + config = DynamoModelConfig( + base_url="http://localhost:8000", + model_name="test-model", + api_key="test-key", + ) + + assert config.prediction_trie_path is None + + +def test_dynamo_field_names_excludes_trie_path(): + """Test that prediction_trie_path is excluded from OpenAI client kwargs.""" + config = DynamoModelConfig( + base_url="http://localhost:8000", + model_name="test-model", + api_key="test-key", + prediction_trie_path="/path/to/trie.json", + ) + + # Simulate what would be passed to an OpenAI client + exclude_fields = {"type", "thinking", *DynamoModelConfig.get_dynamo_field_names()} + config_dict = config.model_dump(exclude=exclude_fields, exclude_none=True) + + # prediction_trie_path should not be in the config dict + assert "prediction_trie_path" not in config_dict diff --git a/tests/nat/llm/test_prediction_context.py b/tests/nat/llm/test_prediction_context.py new file mode 100644 index 0000000000..149bbca26d --- /dev/null +++ b/tests/nat/llm/test_prediction_context.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
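The next tests pin down the `LLMCallTracker` counting contract. A minimal sketch of that contract (an assumption for illustration, not the shipped class):

```python
from collections import defaultdict


class SketchCallTracker:
    """Hypothetical stand-in showing the counting contract the tests assert."""

    def __init__(self) -> None:
        # One monotonically increasing counter per function_id.
        self.counts: dict[str, int] = defaultdict(int)

    def increment(self, function_id: str) -> int:
        self.counts[function_id] += 1
        return self.counts[function_id]

    def reset(self, function_id: str) -> None:
        # Forget the counter so the next increment starts at 1 again.
        self.counts.pop(function_id, None)
```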
+ +from nat.llm.prediction_context import LLMCallTracker +from nat.llm.prediction_context import get_call_tracker + + +def test_tracker_increment(): + tracker = LLMCallTracker() + assert tracker.increment("func-1") == 1 + assert tracker.increment("func-1") == 2 + assert tracker.increment("func-2") == 1 + assert tracker.increment("func-1") == 3 + + +def test_tracker_reset(): + tracker = LLMCallTracker() + tracker.increment("func-1") + tracker.increment("func-1") + tracker.reset("func-1") + assert tracker.increment("func-1") == 1 + + +def test_tracker_context_variable(): + tracker1 = get_call_tracker() + tracker1.increment("func-a") + + tracker2 = get_call_tracker() + # Should be the same tracker in the same context + assert tracker2.increment("func-a") == 2 diff --git a/tests/nat/llm/test_runtime_prediction_e2e.py b/tests/nat/llm/test_runtime_prediction_e2e.py new file mode 100644 index 0000000000..d33d3d57bb --- /dev/null +++ b/tests/nat/llm/test_runtime_prediction_e2e.py @@ -0,0 +1,196 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""End-to-end test for runtime prediction trie integration. + +This test validates that all pieces work together: +1. function_path_stack gets updated when push_active_function is called +2. IntermediateStepManager increments call tracker on LLM_START +3. Dynamic hook reads context and looks up predictions +4. 
Correct headers are injected based on call index +""" + +import tempfile +from pathlib import Path + +from nat.builder.context import Context +from nat.data_models.intermediate_step import IntermediateStepPayload +from nat.data_models.intermediate_step import IntermediateStepType +from nat.llm.dynamo_llm import _create_dynamic_prediction_hook +from nat.profiler.prediction_trie import load_prediction_trie +from nat.profiler.prediction_trie import save_prediction_trie +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode +from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup + + +class MockRequest: + """Mock httpx.Request for testing.""" + + def __init__(self): + self.headers = {} + + +def create_test_trie() -> PredictionTrieNode: + """Create a test trie with known predictions.""" + # Agent at call 1: 2 remaining, 500ms interarrival, 150 tokens + call_1_prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=2.0, p50=2.0, p90=3.0, p95=4.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + + # Agent at call 2: 1 remaining, 300ms interarrival, 100 tokens + call_2_prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=1.0, p50=1.0, p90=2.0, p95=2.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=300.0, p50=280.0, p90=400.0, p95=450.0), + output_tokens=PredictionMetrics(sample_count=10, mean=100.0, p50=90.0, p90=150.0, p95=180.0), + ) + + # Agent at call 3: 0 remaining + call_3_prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=0.0, p50=0.0, p90=0.0, p95=0.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=0.0, p50=0.0, p90=0.0, p95=0.0), + output_tokens=PredictionMetrics(sample_count=10, mean=80.0, p50=75.0, p90=120.0, p95=140.0), + ) + + # Aggregated for fallback + aggregated = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=30, mean=1.0, p50=1.0, p90=2.0, p95=3.0), + interarrival_ms=PredictionMetrics(sample_count=30, mean=400.0, p50=380.0, p90=550.0, p95=600.0), + output_tokens=PredictionMetrics(sample_count=30, mean=110.0, p50=100.0, p90=160.0, p95=190.0), + ) + + agent_node = PredictionTrieNode( + name="react_agent", + predictions_by_call_index={ + 1: call_1_prediction, 2: call_2_prediction, 3: call_3_prediction + }, + predictions_any_index=aggregated, + ) + + workflow_node = PredictionTrieNode( + name="my_workflow", + children={"react_agent": agent_node}, + predictions_any_index=aggregated, + ) + + return PredictionTrieNode( + name="root", + children={"my_workflow": workflow_node}, + predictions_any_index=aggregated, + ) + + +async def test_e2e_prediction_headers_injected_correctly(): + """Test complete flow: context tracking -> step manager -> hook -> headers.""" + # Create and save trie + trie = create_test_trie() + + with tempfile.TemporaryDirectory() as tmpdir: + trie_path = Path(tmpdir) / "prediction_trie.json" + save_prediction_trie(trie, trie_path, workflow_name="test") + + # Load trie + loaded_trie = load_prediction_trie(trie_path) + lookup = PredictionTrieLookup(loaded_trie) + + # Create hook + hook = _create_dynamic_prediction_hook(lookup) + + ctx = Context.get() + state = ctx._context_state + 
step_manager = ctx.intermediate_step_manager
+
+        # Reset state
+        state._function_path_stack.set(None)
+
+        with ctx.push_active_function("my_workflow", input_data=None):
+            with ctx.push_active_function("react_agent", input_data=None):
+                # Simulate first LLM call
+                step_manager.push_intermediate_step(
+                    IntermediateStepPayload(
+                        UUID="llm-1",
+                        event_type=IntermediateStepType.LLM_START,
+                        name="test-model",
+                    ))
+
+                request1 = MockRequest()
+                await hook(request1)
+
+                # Should have call 1 predictions: remaining_calls.mean=2.0, output_tokens.p90=200
+                assert request1.headers["x-prefix-total-requests"] == "2"
+                assert request1.headers["x-prefix-osl"] == "LOW"  # 200 tokens < 256
+                assert request1.headers["x-prefix-iat"] == "HIGH"  # 500ms >= 500
+
+                # Simulate second LLM call
+                step_manager.push_intermediate_step(
+                    IntermediateStepPayload(
+                        UUID="llm-2",
+                        event_type=IntermediateStepType.LLM_START,
+                        name="test-model",
+                    ))
+
+                request2 = MockRequest()
+                await hook(request2)
+
+                # Should have call 2 predictions: remaining_calls.mean=1.0, output_tokens.p90=150
+                assert request2.headers["x-prefix-total-requests"] == "1"
+                assert request2.headers["x-prefix-osl"] == "LOW"  # 150 tokens < 256
+                assert request2.headers["x-prefix-iat"] == "MEDIUM"  # 300ms is 100-500
+
+                # Simulate third LLM call
+                step_manager.push_intermediate_step(
+                    IntermediateStepPayload(
+                        UUID="llm-3",
+                        event_type=IntermediateStepType.LLM_START,
+                        name="test-model",
+                    ))
+
+                request3 = MockRequest()
+                await hook(request3)
+
+                # Should have call 3 predictions: remaining_calls.mean=0.0, output_tokens.p90=120
+                assert request3.headers["x-prefix-total-requests"] == "0"
+                assert request3.headers["x-prefix-osl"] == "LOW"  # 120 tokens < 256
+
+
+async def test_e2e_fallback_to_root():
+    """Test that unknown paths fall back to root predictions."""
+    trie = create_test_trie()
+    lookup = PredictionTrieLookup(trie)
+    hook = _create_dynamic_prediction_hook(lookup)
+
+    ctx = Context.get()
+    state = ctx._context_state
+    step_manager = ctx.intermediate_step_manager
+
+    # Reset state
+    state._function_path_stack.set(None)
+
+    with ctx.push_active_function("unknown_workflow", input_data=None):
+        step_manager.push_intermediate_step(
+            IntermediateStepPayload(
+                UUID="llm-unknown",
+                event_type=IntermediateStepType.LLM_START,
+                name="test-model",
+            ))
+
+        request = MockRequest()
+        await hook(request)
+
+        # Should fall back to root aggregated predictions (remaining_calls.mean=1.0, output_tokens.p90=160)
+        assert "x-prefix-total-requests" in request.headers
+        assert request.headers["x-prefix-total-requests"] == "1"  # aggregated mean
diff --git a/tests/nat/plugins/langchain/test_dynamo_trie_loading.py b/tests/nat/plugins/langchain/test_dynamo_trie_loading.py
new file mode 100644
index 0000000000..f06a19d116
--- /dev/null
+++ b/tests/nat/plugins/langchain/test_dynamo_trie_loading.py
@@ -0,0 +1,191 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +from pathlib import Path +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from nat.builder.builder import Builder +from nat.llm.dynamo_llm import DynamoModelConfig +from nat.plugins.langchain.llm import dynamo_langchain +from nat.profiler.prediction_trie import save_prediction_trie +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode +from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup + + +@pytest.fixture(name="trie_file") +def fixture_trie_file(): + """Create a temporary trie file.""" + prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + + root = PredictionTrieNode( + name="root", + predictions_by_call_index={1: prediction}, + predictions_any_index=prediction, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "prediction_trie.json" + save_prediction_trie(root, path, workflow_name="test") + yield str(path) + + +@pytest.fixture(name="mock_builder") +def fixture_mock_builder(): + """Create a mock builder.""" + return MagicMock(spec=Builder) + + +def test_dynamo_config_with_valid_trie_path(trie_file): + """Test that DynamoModelConfig can be created with valid trie path.""" + config = DynamoModelConfig( + base_url="http://localhost:8000/v1", + model_name="test-model", + api_key="test-key", + prediction_trie_path=trie_file, + ) + + assert config.prediction_trie_path == trie_file + + +def test_dynamo_config_with_nonexistent_trie_path(): + """Test that DynamoModelConfig accepts nonexistent path (validated at load time).""" + config = DynamoModelConfig( + base_url="http://localhost:8000/v1", + model_name="test-model", + api_key="test-key", + prediction_trie_path="/nonexistent/path/trie.json", + ) + + # Config creation should succeed; error happens at runtime + assert config.prediction_trie_path == "/nonexistent/path/trie.json" + + +@patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") +@patch("langchain_openai.ChatOpenAI") +async def test_dynamo_langchain_loads_trie_and_passes_to_client(mock_chat, mock_create_client, trie_file, mock_builder): + """Test that dynamo_langchain loads trie from path and passes PredictionTrieLookup to httpx client.""" + mock_httpx_client = MagicMock() + mock_httpx_client.aclose = AsyncMock() + mock_create_client.return_value = mock_httpx_client + + config = DynamoModelConfig( + base_url="http://localhost:8000/v1", + model_name="test-model", + api_key="test-key", + prefix_template="test-{uuid}", + prediction_trie_path=trie_file, + ) + + async with dynamo_langchain(config, mock_builder): + # Verify httpx client was created with prediction_lookup + mock_create_client.assert_called_once() + call_kwargs = mock_create_client.call_args.kwargs + assert "prediction_lookup" in call_kwargs + assert isinstance(call_kwargs["prediction_lookup"], PredictionTrieLookup) + + mock_httpx_client.aclose.assert_awaited_once() + + 
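The surrounding tests all assume that trie-load failures degrade to `prediction_lookup=None` rather than raising. A sketch of that guard, using a hypothetical `try_load_lookup` helper (the real loading lives inside `dynamo_langchain`):

```python
import logging
from pathlib import Path

from nat.profiler.prediction_trie import load_prediction_trie
from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup

logger = logging.getLogger(__name__)


def try_load_lookup(path: str | None) -> PredictionTrieLookup | None:
    """Hypothetical helper: never raise, fall back to None on any load error."""
    if not path:
        return None
    try:
        return PredictionTrieLookup(load_prediction_trie(Path(path)))
    except (OSError, ValueError, KeyError) as exc:  # missing file, bad JSON, bad schema
        logger.warning("Could not load prediction trie from %s: %s", path, exc)
        return None
```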
+@patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") +@patch("langchain_openai.ChatOpenAI") +async def test_dynamo_langchain_handles_nonexistent_trie_gracefully(mock_chat, mock_create_client, mock_builder): + """Test that dynamo_langchain logs warning and continues when trie file doesn't exist.""" + mock_httpx_client = MagicMock() + mock_httpx_client.aclose = AsyncMock() + mock_create_client.return_value = mock_httpx_client + + config = DynamoModelConfig( + base_url="http://localhost:8000/v1", + model_name="test-model", + api_key="test-key", + prefix_template="test-{uuid}", + prediction_trie_path="/nonexistent/path/trie.json", + ) + + # Should not raise an exception + async with dynamo_langchain(config, mock_builder): + # Verify httpx client was created with prediction_lookup=None + mock_create_client.assert_called_once() + call_kwargs = mock_create_client.call_args.kwargs + assert call_kwargs["prediction_lookup"] is None + + mock_httpx_client.aclose.assert_awaited_once() + + +@patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") +@patch("langchain_openai.ChatOpenAI") +async def test_dynamo_langchain_no_trie_path_means_no_lookup(mock_chat, mock_create_client, mock_builder): + """Test that dynamo_langchain passes None when no trie path is configured.""" + mock_httpx_client = MagicMock() + mock_httpx_client.aclose = AsyncMock() + mock_create_client.return_value = mock_httpx_client + + config = DynamoModelConfig( + base_url="http://localhost:8000/v1", + model_name="test-model", + api_key="test-key", + prefix_template="test-{uuid}", # prediction_trie_path is None by default + ) + + async with dynamo_langchain(config, mock_builder): + mock_create_client.assert_called_once() + call_kwargs = mock_create_client.call_args.kwargs + assert call_kwargs["prediction_lookup"] is None + + mock_httpx_client.aclose.assert_awaited_once() + + +@patch("nat.plugins.langchain.llm.create_httpx_client_with_dynamo_hooks") +@patch("langchain_openai.ChatOpenAI") +async def test_dynamo_langchain_handles_invalid_trie_file_gracefully(mock_chat, mock_create_client, mock_builder): + """Test that dynamo_langchain logs warning and continues when trie file is invalid JSON.""" + mock_httpx_client = MagicMock() + mock_httpx_client.aclose = AsyncMock() + mock_create_client.return_value = mock_httpx_client + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + f.write("not valid json {{{") + invalid_trie_path = f.name + + try: + config = DynamoModelConfig( + base_url="http://localhost:8000/v1", + model_name="test-model", + api_key="test-key", + prefix_template="test-{uuid}", + prediction_trie_path=invalid_trie_path, + ) + + # Should not raise an exception + async with dynamo_langchain(config, mock_builder): + # Verify httpx client was created with prediction_lookup=None + mock_create_client.assert_called_once() + call_kwargs = mock_create_client.call_args.kwargs + assert call_kwargs["prediction_lookup"] is None + + mock_httpx_client.aclose.assert_awaited_once() + finally: + Path(invalid_trie_path).unlink(missing_ok=True) diff --git a/tests/nat/profiler/prediction_trie/__init__.py b/tests/nat/profiler/prediction_trie/__init__.py new file mode 100644 index 0000000000..3bcc1c39bb --- /dev/null +++ b/tests/nat/profiler/prediction_trie/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/nat/profiler/prediction_trie/test_data_models.py b/tests/nat/profiler/prediction_trie/test_data_models.py new file mode 100644 index 0000000000..f26157a1cd --- /dev/null +++ b/tests/nat/profiler/prediction_trie/test_data_models.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode + + +def test_prediction_metrics_creation(): + metrics = PredictionMetrics(sample_count=10, mean=5.0, p50=4.5, p90=8.0, p95=9.0) + assert metrics.sample_count == 10 + assert metrics.mean == 5.0 + assert metrics.p50 == 4.5 + assert metrics.p90 == 8.0 + assert metrics.p95 == 9.0 + + +def test_prediction_metrics_defaults(): + metrics = PredictionMetrics() + assert metrics.sample_count == 0 + assert metrics.mean == 0.0 + + +def test_llm_call_prediction_creation(): + prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=5, mean=3.0, p50=3.0, p90=5.0, p95=6.0), + interarrival_ms=PredictionMetrics(sample_count=5, mean=500.0, p50=450.0, p90=800.0, p95=900.0), + output_tokens=PredictionMetrics(sample_count=5, mean=150.0, p50=140.0, p90=250.0, p95=300.0), + ) + assert prediction.remaining_calls.mean == 3.0 + assert prediction.interarrival_ms.mean == 500.0 + assert prediction.output_tokens.mean == 150.0 + + +def test_llm_call_prediction_defaults(): + prediction = LLMCallPrediction() + assert prediction.remaining_calls.sample_count == 0 + assert prediction.interarrival_ms.sample_count == 0 + assert prediction.output_tokens.sample_count == 0 + + +def test_prediction_trie_node_creation(): + node = PredictionTrieNode(name="root") + assert node.name == "root" + assert node.children == {} + assert node.predictions_by_call_index == {} + assert node.predictions_any_index is None + + +def test_prediction_trie_node_with_children(): + child = PredictionTrieNode(name="react_agent") + root = PredictionTrieNode(name="root", children={"react_agent": child}) + assert "react_agent" in root.children + assert root.children["react_agent"].name == "react_agent" + + +def test_prediction_trie_node_with_predictions(): + prediction = LLMCallPrediction() + 
node = PredictionTrieNode( + name="agent", + predictions_by_call_index={ + 1: prediction, 2: prediction + }, + predictions_any_index=prediction, + ) + assert 1 in node.predictions_by_call_index + assert 2 in node.predictions_by_call_index + assert node.predictions_any_index is not None + + +def test_prediction_trie_node_nested_hierarchy(): + """Test a multi-level trie structure.""" + leaf = PredictionTrieNode(name="tool_call") + middle = PredictionTrieNode(name="react_agent", children={"tool_call": leaf}) + root = PredictionTrieNode(name="workflow", children={"react_agent": middle}) + + assert root.children["react_agent"].children["tool_call"].name == "tool_call" diff --git a/tests/nat/profiler/prediction_trie/test_metrics_accumulator.py b/tests/nat/profiler/prediction_trie/test_metrics_accumulator.py new file mode 100644 index 0000000000..46eb4ddd5c --- /dev/null +++ b/tests/nat/profiler/prediction_trie/test_metrics_accumulator.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from nat.profiler.prediction_trie.metrics_accumulator import MetricsAccumulator + + +def test_accumulator_add_single_sample(): + acc = MetricsAccumulator() + acc.add_sample(10.0) + metrics = acc.compute_metrics() + assert metrics.sample_count == 1 + assert metrics.mean == 10.0 + assert metrics.p50 == 10.0 + assert metrics.p90 == 10.0 + assert metrics.p95 == 10.0 + + +def test_accumulator_add_multiple_samples(): + acc = MetricsAccumulator() + for v in [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]: + acc.add_sample(v) + metrics = acc.compute_metrics() + assert metrics.sample_count == 10 + assert metrics.mean == 5.5 + assert metrics.p50 == 5.5 # median of 1-10 + assert metrics.p90 == 9.1 # 90th percentile + assert metrics.p95 == pytest.approx(9.55) # 95th percentile + + +def test_accumulator_empty(): + acc = MetricsAccumulator() + metrics = acc.compute_metrics() + assert metrics.sample_count == 0 + assert metrics.mean == 0.0 diff --git a/tests/nat/profiler/prediction_trie/test_serialization.py b/tests/nat/profiler/prediction_trie/test_serialization.py new file mode 100644 index 0000000000..289f15d2d2 --- /dev/null +++ b/tests/nat/profiler/prediction_trie/test_serialization.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import tempfile +from pathlib import Path + +import pytest + +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode +from nat.profiler.prediction_trie.serialization import load_prediction_trie +from nat.profiler.prediction_trie.serialization import save_prediction_trie + + +@pytest.fixture(name="sample_trie") +def fixture_sample_trie() -> PredictionTrieNode: + """Create a sample trie for testing serialization.""" + prediction = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + + child = PredictionTrieNode( + name="react_agent", + predictions_by_call_index={1: prediction}, + predictions_any_index=prediction, + ) + + root = PredictionTrieNode( + name="root", + children={"react_agent": child}, + predictions_any_index=prediction, + ) + + return root + + +def test_save_and_load_trie(sample_trie): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "prediction_trie.json" + + save_prediction_trie(sample_trie, path, workflow_name="test_workflow") + + loaded = load_prediction_trie(path) + + assert loaded.name == "root" + assert "react_agent" in loaded.children + assert loaded.children["react_agent"].predictions_by_call_index[1].remaining_calls.mean == 3.0 + + +def test_saved_file_has_metadata(sample_trie): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "prediction_trie.json" + + save_prediction_trie(sample_trie, path, workflow_name="test_workflow") + + with open(path) as f: + data = json.load(f) + + assert data["version"] == "1.0" + assert data["workflow_name"] == "test_workflow" + assert "generated_at" in data + assert "root" in data diff --git a/tests/nat/profiler/prediction_trie/test_trie_builder.py b/tests/nat/profiler/prediction_trie/test_trie_builder.py new file mode 100644 index 0000000000..e68cf2eb78 --- /dev/null +++ b/tests/nat/profiler/prediction_trie/test_trie_builder.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
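+
+# Scope note (inferred from the assertions in this file; the builder
+# implementation itself is not part of this diff): each LLM_START/LLM_END
+# pair in a trace becomes one sample, keyed by the function name from
+# `function_ancestry` plus a 1-based call index within that function;
+# remaining_calls counts the LLM calls that start later in the same trace;
+# output_tokens comes from usage_info.token_usage.completion_tokens; and
+# interarrival_ms is (next call's LLM_START - this call's LLM_END) * 1000,
+# so a call ending at t=1001.0 followed by one starting at t=1002.0
+# contributes a 1000 ms sample.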
+ +import pytest + +from nat.data_models.intermediate_step import IntermediateStep +from nat.data_models.intermediate_step import IntermediateStepPayload +from nat.data_models.intermediate_step import IntermediateStepType +from nat.data_models.intermediate_step import UsageInfo +from nat.data_models.invocation_node import InvocationNode +from nat.profiler.callbacks.token_usage_base_model import TokenUsageBaseModel +from nat.profiler.prediction_trie.trie_builder import PredictionTrieBuilder + + +@pytest.fixture(name="simple_trace") +def fixture_simple_trace() -> list[IntermediateStep]: + """Create a simple trace with two LLM calls.""" + return [ + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id="workflow-1", + function_name="my_workflow", + parent_id=None, + parent_name=None, + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_START, + event_timestamp=1000.0, + UUID="llm-1", + ), + ), + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id="workflow-1", + function_name="my_workflow", + parent_id=None, + parent_name=None, + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_END, + event_timestamp=1001.0, + span_event_timestamp=1000.0, + UUID="llm-1", + usage_info=UsageInfo(token_usage=TokenUsageBaseModel(completion_tokens=100), ), + ), + ), + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id="workflow-1", + function_name="my_workflow", + parent_id=None, + parent_name=None, + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_START, + event_timestamp=1002.0, + UUID="llm-2", + ), + ), + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id="workflow-1", + function_name="my_workflow", + parent_id=None, + parent_name=None, + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_END, + event_timestamp=1003.0, + span_event_timestamp=1002.0, + UUID="llm-2", + usage_info=UsageInfo(token_usage=TokenUsageBaseModel(completion_tokens=150), ), + ), + ), + ] + + +def test_trie_builder_builds_from_single_trace(simple_trace): + builder = PredictionTrieBuilder() + builder.add_trace(simple_trace) + trie = builder.build() + + assert trie.name == "root" + assert "my_workflow" in trie.children + + workflow_node = trie.children["my_workflow"] + # First LLM call: call_index=1, remaining=1 + assert 1 in workflow_node.predictions_by_call_index + # Second LLM call: call_index=2, remaining=0 + assert 2 in workflow_node.predictions_by_call_index + + +def test_trie_builder_computes_remaining_calls(simple_trace): + builder = PredictionTrieBuilder() + builder.add_trace(simple_trace) + trie = builder.build() + + workflow_node = trie.children["my_workflow"] + # First call should predict 1 remaining call + assert workflow_node.predictions_by_call_index[1].remaining_calls.mean == 1.0 + # Second call should predict 0 remaining calls + assert workflow_node.predictions_by_call_index[2].remaining_calls.mean == 0.0 + + +def test_trie_builder_computes_output_tokens(simple_trace): + builder = PredictionTrieBuilder() + builder.add_trace(simple_trace) + trie = builder.build() + + workflow_node = trie.children["my_workflow"] + # First call had 100 completion tokens + assert workflow_node.predictions_by_call_index[1].output_tokens.mean == 100.0 + # Second call had 150 completion tokens + assert workflow_node.predictions_by_call_index[2].output_tokens.mean == 150.0 + + +def 
test_trie_builder_computes_interarrival_time(simple_trace): + builder = PredictionTrieBuilder() + builder.add_trace(simple_trace) + trie = builder.build() + + workflow_node = trie.children["my_workflow"] + # First call: next LLM starts at 1002.0, this call ends at 1001.0 -> 1000ms + assert workflow_node.predictions_by_call_index[1].interarrival_ms.mean == 1000.0 diff --git a/tests/nat/profiler/prediction_trie/test_trie_lookup.py b/tests/nat/profiler/prediction_trie/test_trie_lookup.py new file mode 100644 index 0000000000..58e07aae89 --- /dev/null +++ b/tests/nat/profiler/prediction_trie/test_trie_lookup.py @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from nat.profiler.prediction_trie.data_models import LLMCallPrediction +from nat.profiler.prediction_trie.data_models import PredictionMetrics +from nat.profiler.prediction_trie.data_models import PredictionTrieNode +from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup + + +@pytest.fixture(name="sample_trie") +def fixture_sample_trie() -> PredictionTrieNode: + """Create a sample trie for testing lookups.""" + prediction_1 = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=3.0, p50=3.0, p90=4.0, p95=5.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=500.0, p50=450.0, p90=700.0, p95=800.0), + output_tokens=PredictionMetrics(sample_count=10, mean=150.0, p50=140.0, p90=200.0, p95=250.0), + ) + prediction_2 = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=10, mean=2.0, p50=2.0, p90=3.0, p95=4.0), + interarrival_ms=PredictionMetrics(sample_count=10, mean=400.0, p50=380.0, p90=600.0, p95=700.0), + output_tokens=PredictionMetrics(sample_count=10, mean=200.0, p50=190.0, p90=280.0, p95=320.0), + ) + aggregated = LLMCallPrediction( + remaining_calls=PredictionMetrics(sample_count=20, mean=2.5, p50=2.5, p90=3.5, p95=4.5), + interarrival_ms=PredictionMetrics(sample_count=20, mean=450.0, p50=415.0, p90=650.0, p95=750.0), + output_tokens=PredictionMetrics(sample_count=20, mean=175.0, p50=165.0, p90=240.0, p95=285.0), + ) + + agent_node = PredictionTrieNode( + name="react_agent", + predictions_by_call_index={ + 1: prediction_1, 2: prediction_2 + }, + predictions_any_index=aggregated, + ) + + workflow_node = PredictionTrieNode( + name="my_workflow", + children={"react_agent": agent_node}, + predictions_any_index=aggregated, + ) + + root = PredictionTrieNode( + name="root", + children={"my_workflow": workflow_node}, + predictions_any_index=aggregated, + ) + + return root + + +def test_lookup_exact_match(sample_trie): + lookup = PredictionTrieLookup(sample_trie) + result = lookup.find(path=["my_workflow", "react_agent"], call_index=1) + + assert result is not None + assert result.remaining_calls.mean == 3.0 + assert result.output_tokens.mean == 150.0 + + +def 
test_lookup_partial_path_match(sample_trie):
+    """When the exact path doesn't exist, fall back to the closest ancestor."""
+    lookup = PredictionTrieLookup(sample_trie)
+    # "unknown_tool" doesn't exist, so the lookup should fall back to the
+    # nearest known ancestor ("react_agent") while keeping the exact call index
+    result = lookup.find(path=["my_workflow", "react_agent", "unknown_tool"], call_index=1)
+
+    assert result is not None
+    # Should get react_agent's call_index=1 prediction
+    assert result.remaining_calls.mean == 3.0
+
+
+def test_lookup_unknown_call_index_fallback(sample_trie):
+    """When the call_index doesn't exist, fall back to the aggregated prediction."""
+    lookup = PredictionTrieLookup(sample_trie)
+    result = lookup.find(path=["my_workflow", "react_agent"], call_index=99)
+
+    assert result is not None
+    # Should fall back to predictions_any_index
+    assert result.remaining_calls.mean == 2.5
+
+
+def test_lookup_no_match_returns_root_aggregated(sample_trie):
+    """When nothing matches, return the root's aggregated prediction."""
+    lookup = PredictionTrieLookup(sample_trie)
+    result = lookup.find(path=["completely_unknown"], call_index=1)
+
+    assert result is not None
+    # Should return root's aggregated prediction
+    assert result.remaining_calls.mean == 2.5
diff --git a/tests/nat/profiler/test_prediction_trie_e2e.py b/tests/nat/profiler/test_prediction_trie_e2e.py
new file mode 100644
index 0000000000..6b2ab34853
--- /dev/null
+++ b/tests/nat/profiler/test_prediction_trie_e2e.py
@@ -0,0 +1,128 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
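+
+# Flow covered end to end: synthetic traces -> ProfilerRunner (with
+# PredictionTrieConfig(enable=True)) -> prediction_trie.json on disk ->
+# load_prediction_trie -> PredictionTrieLookup. Note that the async tests
+# below carry no explicit asyncio marker, which assumes the suite runs
+# pytest-asyncio (or an equivalent plugin) in auto mode.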
+"""End-to-end test for prediction trie workflow.""" + +import tempfile +from pathlib import Path + +from nat.data_models.intermediate_step import IntermediateStep +from nat.data_models.intermediate_step import IntermediateStepPayload +from nat.data_models.intermediate_step import IntermediateStepType +from nat.data_models.intermediate_step import UsageInfo +from nat.data_models.invocation_node import InvocationNode +from nat.data_models.profiler import PredictionTrieConfig +from nat.data_models.profiler import ProfilerConfig +from nat.profiler.callbacks.token_usage_base_model import TokenUsageBaseModel +from nat.profiler.prediction_trie import load_prediction_trie +from nat.profiler.prediction_trie.trie_lookup import PredictionTrieLookup +from nat.profiler.profile_runner import ProfilerRunner + + +def make_agent_trace(agent_name: str, num_llm_calls: int, base_timestamp: float) -> list[IntermediateStep]: + """Create a trace with multiple LLM calls in an agent.""" + steps = [] + ts = base_timestamp + + for i in range(num_llm_calls): + llm_uuid = f"llm-{agent_name}-{i}" + + # LLM_START + steps.append( + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id=f"{agent_name}-1", + function_name=agent_name, + parent_id="workflow-1", + parent_name="my_workflow", + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_START, + event_timestamp=ts, + UUID=llm_uuid, + ), + )) + ts += 0.5 + + # LLM_END + completion_tokens = 100 + (i * 50) # Vary tokens by position + steps.append( + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id=f"{agent_name}-1", + function_name=agent_name, + parent_id="workflow-1", + parent_name="my_workflow", + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_END, + event_timestamp=ts, + span_event_timestamp=ts - 0.5, + UUID=llm_uuid, + usage_info=UsageInfo(token_usage=TokenUsageBaseModel(completion_tokens=completion_tokens)), + ), + )) + ts += 0.5 + + return steps + + +async def test_e2e_prediction_trie_workflow(): + """Test the complete flow: profiler -> trie -> lookup.""" + with tempfile.TemporaryDirectory() as tmpdir: + output_dir = Path(tmpdir) + + # Create multiple traces with different agents + traces = [ + make_agent_trace("react_agent", num_llm_calls=3, base_timestamp=1000.0), + make_agent_trace("react_agent", num_llm_calls=3, base_timestamp=2000.0), + make_agent_trace("tool_agent", num_llm_calls=2, base_timestamp=3000.0), + ] + + # Run profiler + config = ProfilerConfig( + base_metrics=True, + prediction_trie=PredictionTrieConfig(enable=True), + ) + runner = ProfilerRunner(config, output_dir) + await runner.run(traces) + + # Load trie + trie_path = output_dir / "prediction_trie.json" + assert trie_path.exists(), "Trie file should exist" + + trie = load_prediction_trie(trie_path) + lookup = PredictionTrieLookup(trie) + + # Test lookups + # react_agent has 3 LLM calls, so at call 1 there are 2 remaining + result = lookup.find(path=["my_workflow", "react_agent"], call_index=1) + assert result is not None + assert result.remaining_calls.mean == 2.0 # 2 remaining after first call + + # At call 3 there are 0 remaining + result = lookup.find(path=["my_workflow", "react_agent"], call_index=3) + assert result is not None + assert result.remaining_calls.mean == 0.0 + + # tool_agent should have different stats + result = lookup.find(path=["my_workflow", "tool_agent"], call_index=1) + assert result is not None + assert result.remaining_calls.mean == 1.0 # 1 
remaining after first call + + # Unknown agent should fall back to aggregated + result = lookup.find(path=["my_workflow", "unknown_agent"], call_index=1) + assert result is not None # Should still get a result from fallback diff --git a/tests/nat/profiler/test_prediction_trie_integration.py b/tests/nat/profiler/test_prediction_trie_integration.py new file mode 100644 index 0000000000..2d54b7d860 --- /dev/null +++ b/tests/nat/profiler/test_prediction_trie_integration.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +from pathlib import Path + +import pytest + +from nat.data_models.intermediate_step import IntermediateStep +from nat.data_models.intermediate_step import IntermediateStepPayload +from nat.data_models.intermediate_step import IntermediateStepType +from nat.data_models.intermediate_step import UsageInfo +from nat.data_models.invocation_node import InvocationNode +from nat.data_models.profiler import PredictionTrieConfig +from nat.data_models.profiler import ProfilerConfig +from nat.profiler.callbacks.token_usage_base_model import TokenUsageBaseModel +from nat.profiler.prediction_trie import load_prediction_trie +from nat.profiler.profile_runner import ProfilerRunner + + +@pytest.fixture(name="sample_traces") +def fixture_sample_traces() -> list[list[IntermediateStep]]: + """Create sample traces for testing profiler integration.""" + + def make_trace() -> list[IntermediateStep]: + return [ + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id="workflow-1", + function_name="my_workflow", + parent_id=None, + parent_name=None, + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_START, + event_timestamp=1000.0, + UUID="llm-1", + ), + ), + IntermediateStep( + parent_id="root", + function_ancestry=InvocationNode( + function_id="workflow-1", + function_name="my_workflow", + parent_id=None, + parent_name=None, + ), + payload=IntermediateStepPayload( + event_type=IntermediateStepType.LLM_END, + event_timestamp=1001.0, + span_event_timestamp=1000.0, + UUID="llm-1", + usage_info=UsageInfo(token_usage=TokenUsageBaseModel(completion_tokens=100)), + ), + ), + ] + + return [make_trace(), make_trace()] + + +async def test_profiler_generates_prediction_trie(sample_traces): + with tempfile.TemporaryDirectory() as tmpdir: + output_dir = Path(tmpdir) + + config = ProfilerConfig( + base_metrics=True, + prediction_trie=PredictionTrieConfig(enable=True), + ) + + runner = ProfilerRunner(config, output_dir) + await runner.run(sample_traces) + + trie_path = output_dir / "prediction_trie.json" + assert trie_path.exists() + + trie = load_prediction_trie(trie_path) + assert trie.name == "root" + assert "my_workflow" in trie.children
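+
+
+# Illustrative follow-on check: a sketch built from behavior the other tests
+# in this diff already pin down. Each fixture trace contains a single
+# 100-token LLM call, so call index 1 should report zero remaining calls and
+# a 100-token mean aggregated across the two traces (the sample_count
+# assertion assumes the builder adds one sample per LLM call per trace).
+async def test_prediction_trie_contents_match_traces(sample_traces):
+    with tempfile.TemporaryDirectory() as tmpdir:
+        output_dir = Path(tmpdir)
+
+        config = ProfilerConfig(
+            base_metrics=True,
+            prediction_trie=PredictionTrieConfig(enable=True),
+        )
+
+        runner = ProfilerRunner(config, output_dir)
+        await runner.run(sample_traces)
+
+        trie = load_prediction_trie(output_dir / "prediction_trie.json")
+        prediction = trie.children["my_workflow"].predictions_by_call_index[1]
+
+        assert prediction.remaining_calls.mean == 0.0
+        assert prediction.output_tokens.mean == 100.0
+        assert prediction.output_tokens.sample_count == 2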