Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,27 +69,36 @@
logger = logging.getLogger(__name__)

# Keys that are internal to AG-UI orchestration and should not be passed to chat clients
AG_UI_INTERNAL_METADATA_KEYS = {"ag_ui_thread_id", "ag_ui_run_id", "current_state"}
AG_UI_INTERNAL_METADATA_KEYS = {"ag_ui_thread_id", "ag_ui_run_id", "current_state", "forwarded_props"}


def _build_safe_metadata(thread_metadata: dict[str, Any] | None) -> dict[str, Any]:
Comment thread
moonbox3 marked this conversation as resolved.
"""Build metadata dict with truncated string values for Azure compatibility.
"""Build metadata dict with string values for Azure compatibility.

Azure has a 512 character limit per metadata value.
Azure has a 512 character limit per metadata value. String values that
already fit are kept as-is. Non-string values are JSON-serialized. If the
resulting string exceeds 512 characters the key is **dropped** (with a
warning) instead of truncated, because truncation can produce invalid JSON
that downstream consumers cannot decode.

Args:
thread_metadata: Raw metadata dict

Returns:
Metadata with string values truncated to 512 chars
Metadata with safe string values (each <= 512 chars)
"""
if not thread_metadata:
return {}
safe_metadata: dict[str, Any] = {}
for key, value in thread_metadata.items():
value_str = value if isinstance(value, str) else json.dumps(value)
if len(value_str) > 512:
value_str = value_str[:512]
logger.warning(
"Dropping metadata key %r: serialized value is %d chars (limit 512)",
key,
len(value_str),
)
continue
safe_metadata[key] = value_str
return safe_metadata

Expand Down Expand Up @@ -790,6 +799,10 @@ async def run_agent_stream(
"ag_ui_thread_id": thread_id,
"ag_ui_run_id": run_id,
}
if "forwarded_props" in input_data:
base_metadata["forwarded_props"] = input_data["forwarded_props"]
elif "forwardedProps" in input_data:
base_metadata["forwarded_props"] = input_data["forwardedProps"]
if flow.current_state:
base_metadata["current_state"] = flow.current_state
session.metadata = _build_safe_metadata(base_metadata) # type: ignore[attr-defined]
Expand Down
27 changes: 25 additions & 2 deletions python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

import inspect
import json
import logging
import uuid
Expand Down Expand Up @@ -581,11 +582,33 @@ def _drain_open_message() -> list[TextMessageEndEvent]:
flow.accumulated_text = ""
return [TextMessageEndEvent(message_id=current_message_id)]

fwd_kwargs: dict[str, Any] = {}
if "forwarded_props" in input_data:
forwarded_props = input_data["forwarded_props"]
fwd_kwargs["function_invocation_kwargs"] = {"forwarded_props": forwarded_props}
elif "forwardedProps" in input_data:
forwarded_props = input_data["forwardedProps"]
fwd_kwargs["function_invocation_kwargs"] = {"forwarded_props": forwarded_props}

# Only pass function_invocation_kwargs if the workflow.run signature accepts it
if fwd_kwargs:
try:
sig = inspect.signature(workflow.run)
params = sig.parameters
accepts_fwd = "function_invocation_kwargs" in params or any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
)
except (ValueError, TypeError):
accepts_fwd = False
if not accepts_fwd:
logger.debug("workflow.run() does not accept function_invocation_kwargs; dropping forwarded_props")
fwd_kwargs = {}

try:
if responses:
event_stream = workflow.run(responses=responses, stream=True)
event_stream = workflow.run(responses=responses, stream=True, **fwd_kwargs)
else:
event_stream = workflow.run(message=messages, stream=True)
event_stream = workflow.run(message=messages, stream=True, **fwd_kwargs)
Comment thread
moonbox3 marked this conversation as resolved.

async for event in event_stream:
event_type = getattr(event, "type", None)
Expand Down
Comment thread
moonbox3 marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) Microsoft. All rights reserved.

"""Tests for forwarded_props inclusion in AG-UI session metadata."""

import json
from typing import Any

from agent_framework_ag_ui._agent_run import AG_UI_INTERNAL_METADATA_KEYS, _build_safe_metadata


class TestForwardedPropsInSessionMetadata:
"""Verify that forwarded_props is surfaced in session metadata and filtered from LLM metadata."""

def test_forwarded_props_in_internal_metadata_keys(self):
"""forwarded_props is listed in AG_UI_INTERNAL_METADATA_KEYS to prevent LLM leakage."""
assert "forwarded_props" in AG_UI_INTERNAL_METADATA_KEYS

def test_forwarded_props_filtered_from_client_metadata(self):
"""forwarded_props is filtered out when building LLM-bound client metadata."""
session_metadata: dict[str, Any] = {
"ag_ui_thread_id": "t1",
"ag_ui_run_id": "r1",
"forwarded_props": '{"custom_flag": true}',
}

client_metadata = {k: v for k, v in session_metadata.items() if k not in AG_UI_INTERNAL_METADATA_KEYS}

assert "forwarded_props" not in client_metadata
assert "ag_ui_thread_id" not in client_metadata


class TestBuildSafeMetadata:
"""Verify _build_safe_metadata handles various value types correctly."""

def test_string_value_unchanged(self):
result = _build_safe_metadata({"key": "hello"})
assert result == {"key": "hello"}

def test_dict_value_serialized_to_json(self):
result = _build_safe_metadata({"fp": {"flag": True, "source": "frontend"}})
assert "fp" in result
assert isinstance(result["fp"], str)
# Must be valid, decodable JSON
decoded = json.loads(result["fp"])
assert decoded == {"flag": True, "source": "frontend"}

def test_empty_dict_serialized_to_json(self):
result = _build_safe_metadata({"fp": {}})
assert result["fp"] == "{}"
assert json.loads(result["fp"]) == {}

def test_value_within_limit_kept(self):
value = "x" * 512
result = _build_safe_metadata({"key": value})
assert result["key"] == value

def test_value_exceeding_limit_dropped(self):
"""Values exceeding 512 chars are dropped entirely (not truncated)."""
value = "x" * 513
result = _build_safe_metadata({"key": value})
assert "key" not in result

def test_json_value_exceeding_limit_dropped(self):
"""JSON-serialized dict exceeding 512 chars is dropped, not truncated into invalid JSON."""
big_dict = {f"key_{i}": "v" * 100 for i in range(50)}
result = _build_safe_metadata({"forwarded_props": big_dict})
assert "forwarded_props" not in result

def test_other_keys_preserved_when_one_dropped(self):
"""Dropping one oversized key does not affect other keys."""
result = _build_safe_metadata(
{
"small": "ok",
"big": "x" * 600,
}
)
assert result == {"small": "ok"}

def test_none_input_returns_empty(self):
assert _build_safe_metadata(None) == {}

def test_empty_input_returns_empty(self):
assert _build_safe_metadata({}) == {}
12 changes: 6 additions & 6 deletions python/packages/ag-ui/tests/ag_ui/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ def test_short_string_values(self):
result = _build_safe_metadata(metadata)
assert result == metadata

def test_truncates_long_strings(self):
"""Truncates strings over 512 chars."""
def test_drops_long_strings(self):
"""Drops strings over 512 chars instead of truncating."""
long_value = "x" * 1000
metadata = {"key": long_value}
result = _build_safe_metadata(metadata)
assert len(result["key"]) == 512
assert "key" not in result

def test_serializes_non_strings(self):
"""Serializes non-string values to JSON."""
Expand All @@ -77,12 +77,12 @@ def test_serializes_non_strings(self):
assert result["count"] == "42"
assert result["items"] == "[1, 2, 3]"

def test_truncates_serialized_values(self):
"""Truncates serialized values over 512 chars."""
def test_drops_oversized_serialized_values(self):
"""Drops serialized values over 512 chars instead of truncating."""
long_list = list(range(200))
metadata = {"data": long_list}
result = _build_safe_metadata(metadata)
assert len(result["data"]) == 512
assert "data" not in result


class TestHasOnlyToolCalls:
Expand Down
Loading
Loading