diff --git a/src/bedrock_agentcore/evaluation/__init__.py b/src/bedrock_agentcore/evaluation/__init__.py index 89e163bd..7f49fcfc 100644 --- a/src/bedrock_agentcore/evaluation/__init__.py +++ b/src/bedrock_agentcore/evaluation/__init__.py @@ -1,5 +1,6 @@ -"""AgentCore Evaluation integration for Strands.""" +"""AgentCore Evaluation: EvaluationClient and Strands integration.""" +from bedrock_agentcore.evaluation.client import EvaluationClient from bedrock_agentcore.evaluation.integrations.strands_agents_evals.evaluator import ( StrandsEvalsAgentCoreEvaluator, create_strands_evaluator, @@ -12,8 +13,9 @@ ) __all__ = [ - "create_strands_evaluator", + "EvaluationClient", "StrandsEvalsAgentCoreEvaluator", "convert_strands_to_adot", + "create_strands_evaluator", "fetch_spans_from_cloudwatch", ] diff --git a/src/bedrock_agentcore/evaluation/_agent_span_collector/__init__.py b/src/bedrock_agentcore/evaluation/_agent_span_collector/__init__.py new file mode 100644 index 00000000..6854f2e8 --- /dev/null +++ b/src/bedrock_agentcore/evaluation/_agent_span_collector/__init__.py @@ -0,0 +1,8 @@ +"""Agent span collector: collects OpenTelemetry spans for evaluation.""" + +from .agent_span_collector import AgentSpanCollector, CloudWatchAgentSpanCollector + +__all__ = [ + "AgentSpanCollector", + "CloudWatchAgentSpanCollector", +] diff --git a/src/bedrock_agentcore/evaluation/_agent_span_collector/agent_span_collector.py b/src/bedrock_agentcore/evaluation/_agent_span_collector/agent_span_collector.py new file mode 100644 index 00000000..ac425415 --- /dev/null +++ b/src/bedrock_agentcore/evaluation/_agent_span_collector/agent_span_collector.py @@ -0,0 +1,125 @@ +"""Span collector abstraction for the evaluation runner.""" + +import logging +import time +from abc import ABC, abstractmethod +from datetime import datetime, timedelta +from typing import List + +from bedrock_agentcore._utils.endpoints import DEFAULT_REGION +from bedrock_agentcore.evaluation.utils.cloudwatch_span_helper import CloudWatchSpanHelper + +logger = logging.getLogger(__name__) + +AWS_SPANS_LOG_GROUP = "aws/spans" + + +class AgentSpanCollector(ABC): + """Abstract base class for collecting spans after agent invocation.""" + + @abstractmethod + def collect(self, session_id: str, start_time: datetime, end_time: datetime) -> List[dict]: + """Collect spans for a given session. + + Args: + session_id: The session ID to collect spans for. + start_time: The start time of the session invocation. + end_time: The end time of the session invocation. + + Returns: + List of span dictionaries. + """ + + +class CloudWatchAgentSpanCollector(AgentSpanCollector): + """Collects spans from CloudWatch using precise attributes.session.id filtering.""" + + def __init__( + self, + log_group_name: str, + region: str = DEFAULT_REGION, + max_wait_seconds: int = 300, + poll_interval_seconds: int = 30, + ): + """Initialize the CloudWatch span collector. + + Args: + log_group_name: CloudWatch log group name for event logs. + region: AWS region for CloudWatch client. + max_wait_seconds: Maximum time to poll for spans before giving up (default 300s). + poll_interval_seconds: Time between poll attempts (default 30s). + """ + self.log_group_name = log_group_name + self.region = region + self.max_wait_seconds = max_wait_seconds + self.poll_interval_seconds = poll_interval_seconds + self._helper = CloudWatchSpanHelper(region=region) + + def collect(self, session_id: str, start_time: datetime, end_time: datetime) -> List[dict]: + """Collect spans from CloudWatch, polling until spans appear or timeout. + + Args: + session_id: The session ID to collect spans for. + start_time: The start time of the session invocation. + end_time: The end time of the session invocation. + + Returns: + List of ADOT span dictionaries. + """ + # Widen the query window so spans ingested shortly after the + # invocation ended are not excluded. CloudWatch Logs Insights + # treats endTime as exclusive and ingestion can lag by seconds, + # so a 60-second buffer avoids missing spans on every retry. + query_end_time = end_time + timedelta(seconds=60) + logger.debug( + "Collecting spans for session_id=%s, log_group=%s, time_range=[%s, %s]", + session_id, + self.log_group_name, + start_time, + query_end_time, + ) + deadline = time.monotonic() + self.max_wait_seconds + + while True: + spans = self._fetch_spans(session_id, start_time, query_end_time) + logger.debug("fetch_spans returned %d span(s)", len(spans)) + + if spans: + logger.info("Collected %d span(s) for session %s", len(spans), session_id) + return spans + + if time.monotonic() + self.poll_interval_seconds > deadline: + logger.warning( + "Span collection timed out after %ds for session %s (0 spans found)", + self.max_wait_seconds, + session_id, + ) + return spans + + logger.info("No spans found yet, retrying in %ds...", self.poll_interval_seconds) + time.sleep(self.poll_interval_seconds) + + def _fetch_spans(self, session_id: str, start_time: datetime, end_time: datetime) -> List[dict]: + """Fetch spans from both aws/spans and the configured log group. + + Queries both log groups with a precise attributes.session.id filter, + combines results, and returns only valid ADOT span documents. + """ + query_string = ( + f"fields @timestamp, @message" + f'\n| filter attributes.session.id = "{session_id}"' + f"\n| filter ispresent(scope.name)" + f"\n| sort @timestamp asc" + ) + + aws_spans = self._helper.query_log_group( + AWS_SPANS_LOG_GROUP, session_id, start_time, end_time, query_string=query_string + ) + event_spans = self._helper.query_log_group( + self.log_group_name, session_id, start_time, end_time, query_string=query_string + ) + + all_data = aws_spans + event_spans + + logger.info("Fetched %d span items from CloudWatch", len(all_data)) + return all_data diff --git a/src/bedrock_agentcore/evaluation/client.py b/src/bedrock_agentcore/evaluation/client.py new file mode 100644 index 00000000..3aacc919 --- /dev/null +++ b/src/bedrock_agentcore/evaluation/client.py @@ -0,0 +1,254 @@ +"""EvaluationClient for collecting spans and running evaluations.""" + +import logging +from datetime import datetime, timedelta, timezone +from typing import Any, Dict, Iterator, List, Optional + +import boto3 +from botocore.config import Config + +from bedrock_agentcore._utils.user_agent import build_user_agent_suffix +from bedrock_agentcore.evaluation._agent_span_collector import CloudWatchAgentSpanCollector + +logger = logging.getLogger(__name__) + +MAX_TARGET_IDS_PER_REQUEST = 10 +QUERY_TIMEOUT_SECONDS = 60 +POLL_INTERVAL_SECONDS = 2 + + +class EvaluationClient: + """Client for evaluating agent sessions. + + Collects spans from CloudWatch and calls the evaluation API with + level-aware batching. + + Example:: + + client = EvaluationClient(region_name="us-west-2") + + # Using agent_id (log group derived automatically) + results = client.run( + evaluator_ids=["accuracy", "toxicity"], + session_id="sess-123", + agent_id="my-agent", + ) + + # Using log_group_name directly + results = client.run( + evaluator_ids=["accuracy", "toxicity"], + session_id="sess-123", + log_group_name="/custom/my-log-group", + ) + + for r in results: + print(f"{r['evaluatorId']}: {r.get('value')} - {r.get('explanation')}") + """ + + def __init__( + self, + region_name: Optional[str] = None, + integration_source: Optional[str] = None, + ): + """Initialize the EvaluationClient. + + Args: + region_name: AWS region. Falls back to boto3 session region or us-west-2. + integration_source: Optional integration framework identifier for telemetry. + """ + self.region_name = region_name or boto3.Session().region_name or "us-west-2" + self.integration_source = integration_source + + user_agent_extra = build_user_agent_suffix(integration_source) + client_config = Config(user_agent_extra=user_agent_extra) + + self._dp_client = boto3.client( + "bedrock-agentcore", + region_name=self.region_name, + config=client_config, + ) + self._cp_client = boto3.client( + "bedrock-agentcore-control", + region_name=self.region_name, + config=client_config, + ) + self._evaluator_level_cache: Dict[str, str] = {} + + logger.info("Initialized EvaluationClient in region %s", self.region_name) + + def run( + self, + evaluator_ids: List[str], + session_id: str, + agent_id: Optional[str] = None, + look_back_time: timedelta = timedelta(days=7), + log_group_name: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """Evaluate an agent session end-to-end. + + 1. Collects spans from CloudWatch. + 2. For each evaluator, looks up its level (SESSION/TRACE/TOOL_CALL). + 3. Builds the appropriate evaluationTarget based on level. + 4. Calls evaluate() with auto-batching (max 10 target IDs per request). + 5. Returns combined evaluationResults from all evaluators. + + Either ``agent_id`` or ``log_group_name`` must be provided. + When only ``agent_id`` is given, the log group name is derived as + ``/aws/bedrock-agentcore/runtimes/{agent_id}-DEFAULT``. + + Args: + evaluator_ids: List of evaluator IDs (built-in or custom ARNs). + session_id: The session ID to evaluate. + agent_id: The agent ID. Used to derive the log group when + ``log_group_name`` is not provided. + look_back_time: How far back to search for spans (default: 7 days). + log_group_name: CloudWatch log group name. If provided, ``agent_id`` + is not required. + + Returns: + List of evaluation result dicts from all evaluators. + + Raises: + ValueError: If neither ``agent_id`` nor ``log_group_name`` is provided. + """ + if not agent_id and not log_group_name: + raise ValueError("Provide either agent_id or log_group_name.") + + if not log_group_name: + log_group_name = f"/aws/bedrock-agentcore/runtimes/{agent_id}-DEFAULT" + logger.debug("Derived log_group_name=%s from agent_id=%s", log_group_name, agent_id) + + end_time = datetime.now(timezone.utc) + start_time = end_time - look_back_time + + logger.info( + "Running evaluation for session=%s, log_group=%s, time_range=[%s, %s]", + session_id, + log_group_name, + start_time, + end_time, + ) + + # Step 1: Collect spans + collector = CloudWatchAgentSpanCollector( + log_group_name=log_group_name, + region=self.region_name, + max_wait_seconds=QUERY_TIMEOUT_SECONDS, + poll_interval_seconds=POLL_INTERVAL_SECONDS, + ) + spans = collector.collect( + session_id=session_id, + start_time=start_time, + end_time=end_time, + ) + + if not spans: + logger.warning("No spans found for session %s", session_id) + return [] + + base_input = {"evaluationInput": {"sessionSpans": spans}} + + # Steps 2-4: For each evaluator, look up level, build targets, call API + all_results = [] + for evaluator_id in evaluator_ids: + level = self._get_evaluator_level(evaluator_id) + logger.info("Evaluating with %s (level=%s)", evaluator_id, level) + requests = self._build_requests_for_level(evaluator_id, level, base_input, spans) + if len(requests) > 1: + logger.debug("Split into %d batched request(s) for evaluator %s", len(requests), evaluator_id) + evaluator_result_count = 0 + for request in requests: + try: + response = self._dp_client.evaluate(evaluatorId=evaluator_id, **request) + results = response.get("evaluationResults", []) + evaluator_result_count += len(results) + all_results.extend(results) + except Exception as e: + logger.warning("Evaluator %s failed: %s", evaluator_id, e) + logger.debug("Evaluator %s returned %d result(s)", evaluator_id, evaluator_result_count) + + logger.info( + "Evaluation complete: %d result(s) from %d evaluator(s)", + len(all_results), + len(evaluator_ids), + ) + return all_results + + def _get_evaluator_level(self, evaluator_id: str) -> str: + """Look up evaluator level with caching. Falls back to SESSION.""" + if evaluator_id not in self._evaluator_level_cache: + try: + response = self._cp_client.get_evaluator(evaluatorId=evaluator_id) + self._evaluator_level_cache[evaluator_id] = response["level"] + except Exception as e: + logger.warning( + "Failed to get level for %s, defaulting to SESSION: %s", + evaluator_id, + e, + ) + self._evaluator_level_cache[evaluator_id] = "SESSION" + return self._evaluator_level_cache[evaluator_id] + + def _build_requests_for_level( + self, + evaluator_id: str, + level: str, + base_input: dict, + spans: list, + ) -> List[dict]: + """Build one or more evaluate request payloads based on evaluator level.""" + if level == "SESSION": + return [base_input] + + if level == "TRACE": + trace_ids = self._extract_trace_ids(spans) + logger.debug("Extracted %d unique trace ID(s) for evaluator %s", len(trace_ids), evaluator_id) + if not trace_ids: + raise ValueError(f"No trace IDs found for trace-level evaluator {evaluator_id}") + return [ + {**base_input, "evaluationTarget": {"traceIds": batch}} + for batch in self._batch(trace_ids, MAX_TARGET_IDS_PER_REQUEST) + ] + + if level == "TOOL_CALL": + tool_span_ids = self._extract_tool_span_ids(spans) + logger.debug("Extracted %d tool span ID(s) for evaluator %s", len(tool_span_ids), evaluator_id) + if not tool_span_ids: + raise ValueError(f"No tool span IDs found for tool-level evaluator {evaluator_id}") + return [ + {**base_input, "evaluationTarget": {"spanIds": batch}} + for batch in self._batch(tool_span_ids, MAX_TARGET_IDS_PER_REQUEST) + ] + + raise ValueError(f"Unknown evaluator level: {level}") + + @staticmethod + def _extract_trace_ids(spans: list) -> List[str]: + """Extract unique trace IDs from spans, ordered by appearance.""" + seen: set = set() + trace_ids: List[str] = [] + for span in spans: + trace_id = span.get("traceId") + if trace_id and trace_id not in seen: + trace_ids.append(trace_id) + seen.add(trace_id) + return trace_ids + + @staticmethod + def _extract_tool_span_ids(spans: list) -> List[str]: + """Extract span IDs for tool execution spans.""" + tool_span_ids: List[str] = [] + for span in spans: + name = span.get("name", "") + kind = span.get("kind") + if kind == "SPAN_KIND_INTERNAL" and name.startswith("Tool:"): + span_id = span.get("spanId") + if span_id: + tool_span_ids.append(span_id) + return tool_span_ids + + @staticmethod + def _batch(items: list, size: int) -> Iterator[list]: + """Yield successive chunks of the given size.""" + for i in range(0, len(items), size): + yield items[i : i + size] diff --git a/src/bedrock_agentcore/evaluation/utils/cloudwatch_span_helper.py b/src/bedrock_agentcore/evaluation/utils/cloudwatch_span_helper.py index 0732fc89..dcf9225e 100644 --- a/src/bedrock_agentcore/evaluation/utils/cloudwatch_span_helper.py +++ b/src/bedrock_agentcore/evaluation/utils/cloudwatch_span_helper.py @@ -4,7 +4,7 @@ import logging import time from datetime import datetime -from typing import Any, List +from typing import Any, List, Optional import boto3 @@ -43,6 +43,7 @@ def query_log_group( session_id: str, start_time: datetime, end_time: datetime, + query_string: Optional[str] = None, ) -> List[dict]: """Query a single CloudWatch log group for session data. @@ -51,11 +52,14 @@ def query_log_group( session_id: Session ID to filter by start_time: Query start time end_time: Query end time + query_string: Optional custom query string. When provided, used instead + of the default substring match query. Returns: List of parsed JSON log messages """ - query_string = f"""fields @timestamp, @message + if query_string is None: + query_string = f"""fields @timestamp, @message | filter @message like "{session_id}" | sort @timestamp asc""" @@ -63,6 +67,14 @@ def query_log_group( initial_backoff = 0.5 max_backoff = 5.0 + logger.debug( + "Querying log group %s: start_time=%s, end_time=%s, query=%s", + log_group_name, + start_time, + end_time, + query_string, + ) + try: response = self.logs_client.start_query( logGroupName=log_group_name, @@ -133,6 +145,7 @@ def fetch_spans( session_id: str, event_log_group: str, start_time: datetime, + end_time: datetime, ) -> List[dict]: """Fetch ADOT spans from CloudWatch with configurable event log group. @@ -145,6 +158,7 @@ def fetch_spans( - For Runtime agents: "/aws/bedrock-agentcore/runtimes/{agent_id}-{endpoint}" - For custom agents: Any log group you configured (e.g., "/my-app/agent-events") start_time: Start time for log query + end_time: End time for log query Returns: List of ADOT span and log record dictionaries @@ -153,21 +167,22 @@ def fetch_spans( >>> from datetime import datetime, timedelta, timezone >>> helper = CloudWatchSpanHelper(region="us-west-2") >>> start_time = datetime.now(timezone.utc) - timedelta(minutes=10) - >>> spans = fetcher.fetch_spans( + >>> end_time = datetime.now(timezone.utc) + >>> spans = helper.fetch_spans( ... session_id="abc-123", ... event_log_group="/aws/bedrock-agentcore/runtimes/my-agent-ABC-DEFAULT", - ... start_time=start_time + ... start_time=start_time, + ... end_time=end_time, ... ) Example (Custom agent): - >>> spans = fetcher.fetch_spans( + >>> spans = helper.fetch_spans( ... session_id="abc-123", ... event_log_group="/my-app/agent-events", - ... start_time=start_time + ... start_time=start_time, + ... end_time=end_time, ... ) """ - end_time = datetime.now() - # Query both log groups aws_spans = self.query_log_group("aws/spans", session_id, start_time, end_time) event_logs = self.query_log_group(event_log_group, session_id, start_time, end_time) @@ -184,6 +199,7 @@ def fetch_spans_from_cloudwatch( session_id: str, event_log_group: str, start_time: datetime, + end_time: datetime, region: str = DEFAULT_REGION, ) -> List[dict]: """Fetch ADOT spans from CloudWatch with configurable event log group. @@ -199,6 +215,7 @@ def fetch_spans_from_cloudwatch( - For Runtime agents: "/aws/bedrock-agentcore/runtimes/{agent_id}-{endpoint}" - For custom agents: Any log group you configured (e.g., "/my-app/agent-events") start_time: Start time for log query + end_time: End time for log query region: AWS region (default: from DEFAULT_REGION constant) Returns: @@ -207,18 +224,21 @@ def fetch_spans_from_cloudwatch( Example (Runtime agent): >>> from datetime import datetime, timedelta, timezone >>> start_time = datetime.now(timezone.utc) - timedelta(minutes=10) + >>> end_time = datetime.now(timezone.utc) >>> spans = fetch_spans_from_cloudwatch( ... session_id="abc-123", ... event_log_group="/aws/bedrock-agentcore/runtimes/my-agent-ABC-DEFAULT", - ... start_time=start_time + ... start_time=start_time, + ... end_time=end_time, ... ) Example (Custom agent): >>> spans = fetch_spans_from_cloudwatch( ... session_id="abc-123", ... event_log_group="/my-app/agent-events", - ... start_time=start_time + ... start_time=start_time, + ... end_time=end_time, ... ) """ helper = CloudWatchSpanHelper(region=region) - return helper.fetch_spans(session_id, event_log_group, start_time) + return helper.fetch_spans(session_id, event_log_group, start_time, end_time) diff --git a/tests/bedrock_agentcore/evaluation/test_client.py b/tests/bedrock_agentcore/evaluation/test_client.py new file mode 100644 index 00000000..d8b8e179 --- /dev/null +++ b/tests/bedrock_agentcore/evaluation/test_client.py @@ -0,0 +1,345 @@ +"""Unit tests for EvaluationClient.""" + +from datetime import timedelta +from unittest.mock import MagicMock, patch + +import pytest + +from bedrock_agentcore.evaluation.client import EvaluationClient + +# --- Fixtures --- + +SAMPLE_SPANS = [ + { + "scope": {"name": "agent"}, + "traceId": "trace-1", + "spanId": "span-1", + "name": "Agent.invoke", + "kind": "SPAN_KIND_SERVER", + }, + { + "scope": {"name": "agent"}, + "traceId": "trace-1", + "spanId": "span-2", + "name": "Tool:search", + "kind": "SPAN_KIND_INTERNAL", + }, + { + "scope": {"name": "agent"}, + "traceId": "trace-1", + "spanId": "span-3", + "name": "Tool:calculator", + "kind": "SPAN_KIND_INTERNAL", + }, + { + "scope": {"name": "agent"}, + "traceId": "trace-2", + "spanId": "span-4", + "name": "Agent.invoke", + "kind": "SPAN_KIND_SERVER", + }, + { + "scope": {"name": "agent"}, + "traceId": "trace-2", + "spanId": "span-5", + "name": "Tool:search", + "kind": "SPAN_KIND_INTERNAL", + }, +] + + +@pytest.fixture +def client(): + """Create an EvaluationClient with mocked boto3 clients.""" + with patch("bedrock_agentcore.evaluation.client.boto3") as mock_boto3: + mock_boto3.Session.return_value.region_name = "us-west-2" + mock_dp = MagicMock() + mock_cp = MagicMock() + mock_boto3.client.side_effect = lambda service, **kwargs: ( + mock_dp if service == "bedrock-agentcore" else mock_cp + ) + c = EvaluationClient(region_name="us-west-2") + c._dp_client = mock_dp + c._cp_client = mock_cp + return c + + +# --- __init__ tests --- + + +class TestInit: + def test_creates_both_clients(self): + with patch("bedrock_agentcore.evaluation.client.boto3") as mock_boto3: + mock_boto3.Session.return_value.region_name = "us-east-1" + EvaluationClient(region_name="us-east-1") + calls = mock_boto3.client.call_args_list + service_names = [call[0][0] for call in calls] + assert "bedrock-agentcore" in service_names + assert "bedrock-agentcore-control" in service_names + + def test_region_fallback(self): + with patch("bedrock_agentcore.evaluation.client.boto3") as mock_boto3: + mock_boto3.Session.return_value.region_name = "eu-west-1" + c = EvaluationClient() + assert c.region_name == "eu-west-1" + + def test_region_fallback_to_default(self): + with patch("bedrock_agentcore.evaluation.client.boto3") as mock_boto3: + mock_boto3.Session.return_value.region_name = None + c = EvaluationClient() + assert c.region_name == "us-west-2" + + def test_empty_evaluator_level_cache(self): + with patch("bedrock_agentcore.evaluation.client.boto3") as mock_boto3: + mock_boto3.Session.return_value.region_name = "us-west-2" + c = EvaluationClient() + assert c._evaluator_level_cache == {} + + +# --- run() validation tests --- + + +class TestRunValidation: + def test_raises_without_agent_id_or_log_group(self, client): + with pytest.raises(ValueError, match="Provide either agent_id or log_group_name"): + client.run(evaluator_ids=["accuracy"], session_id="sess-1") + + def test_derives_log_group_from_agent_id(self, client): + with patch("bedrock_agentcore.evaluation.client.CloudWatchAgentSpanCollector") as mock_collector_cls: + mock_collector_cls.return_value.collect.return_value = [] + client.run(evaluator_ids=["accuracy"], session_id="sess-1", agent_id="my-agent") + mock_collector_cls.assert_called_once_with( + log_group_name="/aws/bedrock-agentcore/runtimes/my-agent-DEFAULT", + region="us-west-2", + max_wait_seconds=60, + poll_interval_seconds=2, + ) + + def test_log_group_name_takes_precedence(self, client): + with patch("bedrock_agentcore.evaluation.client.CloudWatchAgentSpanCollector") as mock_collector_cls: + mock_collector_cls.return_value.collect.return_value = [] + client.run( + evaluator_ids=["accuracy"], + session_id="sess-1", + agent_id="my-agent", + log_group_name="/custom/group", + ) + mock_collector_cls.assert_called_once_with( + log_group_name="/custom/group", + region="us-west-2", + max_wait_seconds=60, + poll_interval_seconds=2, + ) + + def test_returns_empty_when_no_spans(self, client): + with patch("bedrock_agentcore.evaluation.client.CloudWatchAgentSpanCollector") as mock_collector_cls: + mock_collector_cls.return_value.collect.return_value = [] + results = client.run(evaluator_ids=["accuracy"], session_id="sess-1", agent_id="my-agent") + assert results == [] + client._dp_client.evaluate.assert_not_called() + + +# --- run() end-to-end tests --- + + +class TestRunEndToEnd: + def test_session_level_evaluator(self, client): + client._cp_client.get_evaluator.return_value = {"level": "SESSION"} + client._dp_client.evaluate.return_value = {"evaluationResults": [{"evaluatorId": "accuracy", "value": 0.9}]} + + with patch("bedrock_agentcore.evaluation.client.CloudWatchAgentSpanCollector") as mock_collector_cls: + mock_collector_cls.return_value.collect.return_value = SAMPLE_SPANS + results = client.run(evaluator_ids=["accuracy"], session_id="sess-1", agent_id="my-agent") + + assert len(results) == 1 + assert results[0]["value"] == 0.9 + # SESSION level: no evaluationTarget + call_kwargs = client._dp_client.evaluate.call_args[1] + assert "evaluationTarget" not in call_kwargs + + def test_trace_level_evaluator(self, client): + client._cp_client.get_evaluator.return_value = {"level": "TRACE"} + client._dp_client.evaluate.return_value = {"evaluationResults": [{"evaluatorId": "trace-eval", "value": 0.8}]} + + with patch("bedrock_agentcore.evaluation.client.CloudWatchAgentSpanCollector") as mock_collector_cls: + mock_collector_cls.return_value.collect.return_value = SAMPLE_SPANS + results = client.run(evaluator_ids=["trace-eval"], session_id="sess-1", agent_id="my-agent") + + assert len(results) == 1 + call_kwargs = client._dp_client.evaluate.call_args[1] + assert call_kwargs["evaluationTarget"]["traceIds"] == ["trace-1", "trace-2"] + + def test_tool_call_level_evaluator(self, client): + client._cp_client.get_evaluator.return_value = {"level": "TOOL_CALL"} + client._dp_client.evaluate.return_value = {"evaluationResults": [{"evaluatorId": "tool-eval", "value": 0.7}]} + + with patch("bedrock_agentcore.evaluation.client.CloudWatchAgentSpanCollector") as mock_collector_cls: + mock_collector_cls.return_value.collect.return_value = SAMPLE_SPANS + results = client.run(evaluator_ids=["tool-eval"], session_id="sess-1", agent_id="my-agent") + + assert len(results) == 1 + call_kwargs = client._dp_client.evaluate.call_args[1] + assert set(call_kwargs["evaluationTarget"]["spanIds"]) == {"span-2", "span-3", "span-5"} + + def test_multiple_evaluators(self, client): + client._cp_client.get_evaluator.side_effect = [ + {"level": "SESSION"}, + {"level": "TRACE"}, + ] + client._dp_client.evaluate.return_value = {"evaluationResults": [{"evaluatorId": "eval", "value": 1.0}]} + + with patch("bedrock_agentcore.evaluation.client.CloudWatchAgentSpanCollector") as mock_collector_cls: + mock_collector_cls.return_value.collect.return_value = SAMPLE_SPANS + results = client.run( + evaluator_ids=["session-eval", "trace-eval"], + session_id="sess-1", + agent_id="my-agent", + ) + + assert len(results) == 2 + assert client._dp_client.evaluate.call_count == 2 + + def test_evaluator_api_error_is_caught(self, client): + client._cp_client.get_evaluator.return_value = {"level": "SESSION"} + client._dp_client.evaluate.side_effect = RuntimeError("API error") + + with patch("bedrock_agentcore.evaluation.client.CloudWatchAgentSpanCollector") as mock_collector_cls: + mock_collector_cls.return_value.collect.return_value = SAMPLE_SPANS + results = client.run(evaluator_ids=["accuracy"], session_id="sess-1", agent_id="my-agent") + + assert results == [] + + def test_custom_look_back_time(self, client): + with patch("bedrock_agentcore.evaluation.client.CloudWatchAgentSpanCollector") as mock_collector_cls: + mock_collector_cls.return_value.collect.return_value = [] + client.run( + evaluator_ids=["accuracy"], + session_id="sess-1", + agent_id="my-agent", + look_back_time=timedelta(hours=2), + ) + call_kwargs = mock_collector_cls.return_value.collect.call_args[1] + duration = call_kwargs["end_time"] - call_kwargs["start_time"] + assert duration == timedelta(hours=2) + + +# --- _get_evaluator_level tests --- + + +class TestGetEvaluatorLevel: + def test_returns_level_from_api(self, client): + client._cp_client.get_evaluator.return_value = {"level": "TRACE"} + assert client._get_evaluator_level("eval-1") == "TRACE" + + def test_caches_level(self, client): + client._cp_client.get_evaluator.return_value = {"level": "TRACE"} + client._get_evaluator_level("eval-1") + client._get_evaluator_level("eval-1") + client._cp_client.get_evaluator.assert_called_once() + + def test_falls_back_to_session_on_error(self, client): + client._cp_client.get_evaluator.side_effect = RuntimeError("not found") + assert client._get_evaluator_level("eval-1") == "SESSION" + + def test_caches_fallback(self, client): + client._cp_client.get_evaluator.side_effect = RuntimeError("not found") + client._get_evaluator_level("eval-1") + client._get_evaluator_level("eval-1") + client._cp_client.get_evaluator.assert_called_once() + + +# --- _build_requests_for_level tests --- + + +class TestBuildRequestsForLevel: + def test_session_level(self, client): + base = {"evaluationInput": {"sessionSpans": SAMPLE_SPANS}} + requests = client._build_requests_for_level("eval", "SESSION", base, SAMPLE_SPANS) + assert len(requests) == 1 + assert requests[0] is base + + def test_trace_level(self, client): + base = {"evaluationInput": {"sessionSpans": SAMPLE_SPANS}} + requests = client._build_requests_for_level("eval", "TRACE", base, SAMPLE_SPANS) + assert len(requests) == 1 + assert requests[0]["evaluationTarget"]["traceIds"] == ["trace-1", "trace-2"] + + def test_tool_call_level(self, client): + base = {"evaluationInput": {"sessionSpans": SAMPLE_SPANS}} + requests = client._build_requests_for_level("eval", "TOOL_CALL", base, SAMPLE_SPANS) + assert len(requests) == 1 + assert set(requests[0]["evaluationTarget"]["spanIds"]) == {"span-2", "span-3", "span-5"} + + def test_trace_level_no_traces_raises(self, client): + base = {"evaluationInput": {"sessionSpans": []}} + with pytest.raises(ValueError, match="No trace IDs found"): + client._build_requests_for_level("eval", "TRACE", base, []) + + def test_tool_call_level_no_tools_raises(self, client): + spans = [{"name": "Agent.invoke", "kind": "SPAN_KIND_SERVER", "spanId": "s1"}] + base = {"evaluationInput": {"sessionSpans": spans}} + with pytest.raises(ValueError, match="No tool span IDs found"): + client._build_requests_for_level("eval", "TOOL_CALL", base, spans) + + def test_unknown_level_raises(self, client): + with pytest.raises(ValueError, match="Unknown evaluator level"): + client._build_requests_for_level("eval", "UNKNOWN", {}, []) + + def test_trace_level_batching(self, client): + # Create spans with 12 unique trace IDs to trigger batching + spans = [{"traceId": f"trace-{i}", "spanId": f"span-{i}"} for i in range(12)] + base = {"evaluationInput": {"sessionSpans": spans}} + requests = client._build_requests_for_level("eval", "TRACE", base, spans) + assert len(requests) == 2 + assert len(requests[0]["evaluationTarget"]["traceIds"]) == 10 + assert len(requests[1]["evaluationTarget"]["traceIds"]) == 2 + + +# --- Static helper tests --- + + +class TestExtractTraceIds: + def test_extracts_unique_ordered(self): + ids = EvaluationClient._extract_trace_ids(SAMPLE_SPANS) + assert ids == ["trace-1", "trace-2"] + + def test_empty_spans(self): + assert EvaluationClient._extract_trace_ids([]) == [] + + def test_skips_missing_trace_id(self): + spans = [{"spanId": "s1"}, {"traceId": "t1", "spanId": "s2"}] + assert EvaluationClient._extract_trace_ids(spans) == ["t1"] + + +class TestExtractToolSpanIds: + def test_extracts_tool_spans(self): + ids = EvaluationClient._extract_tool_span_ids(SAMPLE_SPANS) + assert ids == ["span-2", "span-3", "span-5"] + + def test_ignores_non_tool_spans(self): + spans = [ + {"name": "Agent.invoke", "kind": "SPAN_KIND_SERVER", "spanId": "s1"}, + {"name": "LLM.call", "kind": "SPAN_KIND_INTERNAL", "spanId": "s2"}, + ] + assert EvaluationClient._extract_tool_span_ids(spans) == [] + + def test_empty_spans(self): + assert EvaluationClient._extract_tool_span_ids([]) == [] + + +class TestBatch: + def test_exact_batches(self): + batches = list(EvaluationClient._batch([1, 2, 3, 4], 2)) + assert batches == [[1, 2], [3, 4]] + + def test_remainder(self): + batches = list(EvaluationClient._batch([1, 2, 3], 2)) + assert batches == [[1, 2], [3]] + + def test_single_batch(self): + batches = list(EvaluationClient._batch([1, 2], 10)) + assert batches == [[1, 2]] + + def test_empty(self): + batches = list(EvaluationClient._batch([], 10)) + assert batches == [] diff --git a/tests/bedrock_agentcore/evaluation/utils/test_cloudwatch_span_helper.py b/tests/bedrock_agentcore/evaluation/utils/test_cloudwatch_span_helper.py index b78238a9..fc886a5f 100644 --- a/tests/bedrock_agentcore/evaluation/utils/test_cloudwatch_span_helper.py +++ b/tests/bedrock_agentcore/evaluation/utils/test_cloudwatch_span_helper.py @@ -126,12 +126,14 @@ def test_fetch_spans_from_cloudwatch(self): } start_time = datetime(2024, 1, 1, tzinfo=timezone.utc) + end_time = datetime(2024, 1, 2, tzinfo=timezone.utc) with patch("boto3.client", return_value=mock_client): spans = fetch_spans_from_cloudwatch( session_id="session-123", event_log_group="/aws/bedrock-agentcore/runtimes/my-agent-ABC-DEFAULT", start_time=start_time, + end_time=end_time, ) assert len(spans) == 2 # Called twice (aws/spans + event logs) @@ -159,12 +161,14 @@ def test_fetch_spans_from_cloudwatch_filters_invalid(self): ] start_time = datetime(2024, 1, 1, tzinfo=timezone.utc) + end_time = datetime(2024, 1, 2, tzinfo=timezone.utc) with patch("boto3.client", return_value=mock_client): spans = fetch_spans_from_cloudwatch( session_id="session-123", event_log_group="/aws/bedrock-agentcore/runtimes/my-agent-ABC-DEFAULT", start_time=start_time, + end_time=end_time, ) assert len(spans) == 1 # Only valid document