Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 87 additions & 3 deletions src/google/adk/a2a/converters/to_adk_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,23 @@ def _create_event(
actions: Optional[EventActions] = None,
long_running_function_ids: Optional[set[str]] = None,
partial: bool = False,
custom_metadata: Optional[dict[str, Any]] = None,
error_code: Optional[str] = None,
usage_metadata: Optional[
genai_types.GenerateContentResponseUsageMetadata
] = None,
) -> Optional[Event]:
"""Creates an ADK event from parts and metadata."""
event_actions = actions or EventActions()
if not output_parts and not event_actions.model_dump(
exclude_none=True, exclude_defaults=True
):
has_actions = bool(
event_actions.model_dump(exclude_none=True, exclude_defaults=True)
)
has_event_metadata = (
custom_metadata is not None
or error_code is not None
or usage_metadata is not None
)
if not output_parts and not has_actions and not has_event_metadata:
return None

event = Event(
Expand All @@ -206,6 +217,9 @@ def _create_event(
else None
),
partial=partial,
custom_metadata=custom_metadata,
error_code=error_code,
usage_metadata=usage_metadata,
)

return event
Expand Down Expand Up @@ -248,6 +262,59 @@ def _extract_event_actions(
return EventActions()


def _extract_event_metadata(
metadata: Optional[dict[str, Any]],
) -> dict[str, Any]:
"""Extracts ADK event metadata fields from A2A metadata.

Restores custom_metadata, error_code, and usage_metadata that were
serialized by the outbound converter.

Args:
metadata: The A2A metadata dictionary.

Returns:
A dict of keyword arguments suitable for passing to _create_event().
"""
if not metadata:
return {}

result: dict[str, Any] = {}

raw_custom = metadata.get(_get_adk_metadata_key("custom_metadata"))
if raw_custom is not None:
parsed = _parse_adk_metadata_value(raw_custom)
if isinstance(parsed, dict):
result["custom_metadata"] = parsed
else:
logger.warning(
"Ignoring invalid ADK custom_metadata of type %s",
type(parsed).__name__,
)

raw_error_code = metadata.get(_get_adk_metadata_key("error_code"))
if raw_error_code is not None:
result["error_code"] = str(raw_error_code)

raw_usage = metadata.get(_get_adk_metadata_key("usage_metadata"))
if raw_usage is not None:
parsed = _parse_adk_metadata_value(raw_usage)
if isinstance(parsed, dict):
try:
result["usage_metadata"] = (
genai_types.GenerateContentResponseUsageMetadata(**parsed)
)
except Exception as e:
logger.warning("Ignoring invalid ADK usage_metadata: %s", e)
else:
logger.warning(
"Ignoring invalid ADK usage_metadata of type %s",
type(parsed).__name__,
)

return result


def _merge_top_level_dicts(
base: dict[str, Any], new_values: dict[str, Any]
) -> dict[str, Any]:
Expand Down Expand Up @@ -304,6 +371,7 @@ def convert_a2a_task_to_event(

try:
event_actions = EventActions()
event_metadata: dict[str, Any] = {}
output_parts = []
long_running_function_ids = set()
if a2a_task.artifacts:
Expand All @@ -314,6 +382,7 @@ def convert_a2a_task_to_event(
event_actions = _merge_event_actions(
event_actions, _extract_event_actions(artifact.metadata)
)
event_metadata.update(_extract_event_metadata(artifact.metadata))
output_parts, _ = _convert_a2a_parts_to_adk_parts(
artifact_parts, part_converter
)
Expand All @@ -325,6 +394,9 @@ def convert_a2a_task_to_event(
event_actions,
_extract_event_actions(a2a_task.status.message.metadata),
)
event_metadata.update(
_extract_event_metadata(a2a_task.status.message.metadata)
)
parts, ids = _convert_a2a_parts_to_adk_parts(
a2a_task.status.message.parts, part_converter
)
Expand All @@ -337,6 +409,7 @@ def convert_a2a_task_to_event(
author,
event_actions,
long_running_function_ids,
**event_metadata,
)

except Exception as e:
Expand Down Expand Up @@ -375,11 +448,13 @@ def convert_a2a_message_to_event(
output_parts, _ = _convert_a2a_parts_to_adk_parts(
a2a_message.parts, part_converter
)
event_metadata = _extract_event_metadata(a2a_message.metadata)
return _create_event(
output_parts,
invocation_context,
author,
_extract_event_actions(a2a_message.metadata),
**event_metadata,
)

except Exception as e:
Expand Down Expand Up @@ -412,10 +487,14 @@ def convert_a2a_status_update_to_event(
output_parts = []
long_running_function_ids = set()
event_actions = EventActions()
event_metadata: dict[str, Any] = {}
if a2a_status_update.status.message:
event_actions = _extract_event_actions(
a2a_status_update.status.message.metadata
)
event_metadata = _extract_event_metadata(
a2a_status_update.status.message.metadata
)
parts, ids = _convert_a2a_parts_to_adk_parts(
a2a_status_update.status.message.parts, part_converter
)
Expand All @@ -428,6 +507,7 @@ def convert_a2a_status_update_to_event(
author,
event_actions,
long_running_function_ids,
**event_metadata,
)
except Exception as e:
logger.error("Failed to convert A2A status update to event: %s", e)
Expand Down Expand Up @@ -460,12 +540,16 @@ def convert_a2a_artifact_update_to_event(
output_parts, _ = _convert_a2a_parts_to_adk_parts(
a2a_artifact_update.artifact.parts, part_converter
)
event_metadata = _extract_event_metadata(
a2a_artifact_update.artifact.metadata
)
return _create_event(
output_parts,
invocation_context,
author,
_extract_event_actions(a2a_artifact_update.artifact.metadata),
partial=not a2a_artifact_update.last_chunk,
**event_metadata,
)
except Exception as e:
logger.error("Failed to convert A2A artifact update to event: %s", e)
Expand Down
Loading