diff --git a/backend/main.py b/backend/main.py index 819b0e1e..65b22fb1 100644 --- a/backend/main.py +++ b/backend/main.py @@ -27,6 +27,7 @@ # Database from rag_solution.file_management.database import Base, engine, get_db +from rag_solution.router.agent_config_router import collection_agent_router, config_router as agent_config_router from rag_solution.router.agent_router import router as agent_router # Models @@ -260,6 +261,8 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]: app.include_router(voice_router) app.include_router(websocket_router) app.include_router(agent_router) +app.include_router(agent_config_router) +app.include_router(collection_agent_router) # Root endpoint @@ -324,6 +327,22 @@ async def root() -> dict[str, str]: "url": "https://spiffe.io/docs/latest/spire-about/spire-concepts/", }, }, + { + "name": "agent-configs", + "description": ( + "Agent configuration management for the 3-stage search pipeline. " + "Create and manage agent configurations for pre-search, post-search, " + "and response stages. Reference: GitHub Issue #697." + ), + }, + { + "name": "collection-agents", + "description": ( + "Collection-agent associations for the search pipeline. " + "Associate agent configurations with collections and manage " + "execution priorities. Reference: GitHub Issue #697." + ), + }, { "name": "podcast", "description": "AI-powered podcast generation from document collections", diff --git a/backend/rag_solution/models/__init__.py b/backend/rag_solution/models/__init__.py index c3837195..0b63f809 100644 --- a/backend/rag_solution/models/__init__.py +++ b/backend/rag_solution/models/__init__.py @@ -5,6 +5,9 @@ # Agent model for SPIFFE-based workload identity from rag_solution.models.agent import Agent +# Agent configuration models for search pipeline hooks +from rag_solution.models.agent_config import AgentConfig, CollectionAgent + # Then Collection since it's referenced by UserCollection from rag_solution.models.collection import Collection @@ -32,8 +35,10 @@ # Register all models with Base.metadata __all__ = [ "Agent", + "AgentConfig", "Base", "Collection", + "CollectionAgent", "ConversationMessage", "ConversationSession", "ConversationSummary", diff --git a/backend/rag_solution/models/agent_config.py b/backend/rag_solution/models/agent_config.py new file mode 100644 index 00000000..9d96a934 --- /dev/null +++ b/backend/rag_solution/models/agent_config.py @@ -0,0 +1,309 @@ +"""Agent configuration model for search pipeline execution hooks. + +This module defines the AgentConfig SQLAlchemy model for storing agent +configurations that can be attached to collections at different pipeline stages. + +The 3-stage pipeline supports: +- Stage 1: Pre-Search Agents (query expansion, language detection, etc.) +- Stage 2: Post-Search Agents (re-ranking, deduplication, PII redaction) +- Stage 3: Response Agents (PowerPoint, PDF, Chart generation) + +Reference: GitHub Issue #697 +""" + +from __future__ import annotations + +import uuid +from datetime import UTC, datetime +from enum import Enum +from typing import TYPE_CHECKING + +from sqlalchemy import DateTime, ForeignKey, Index, Integer, String, Text +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from core.identity_service import IdentityService +from rag_solution.file_management.database import Base + +if TYPE_CHECKING: + from rag_solution.models.user import User + + +class AgentStage(str, Enum): + """Pipeline stages where agents can execute. 
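+
+    Stages execute in pipeline order: pre_search agents run before retrieval,
+    post_search agents operate on the retrieved results, and response agents
+    run during answer and artifact generation.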
+ + Attributes: + PRE_SEARCH: Before vector search (query enhancement) + POST_SEARCH: After vector search (result enhancement) + RESPONSE: Response generation (artifacts) + """ + + PRE_SEARCH = "pre_search" + POST_SEARCH = "post_search" + RESPONSE = "response" + + +class AgentConfigStatus(str, Enum): + """Agent configuration status. + + Attributes: + ACTIVE: Agent config is active and available + DISABLED: Agent config is temporarily disabled + DEPRECATED: Agent config is deprecated (still works but not recommended) + """ + + ACTIVE = "active" + DISABLED = "disabled" + DEPRECATED = "deprecated" + + +class AgentConfig(Base): + """SQLAlchemy model for agent configurations in the search pipeline. + + AgentConfig defines a specific agent that can be attached to collections + and executed at a specific stage in the search pipeline. + + Attributes: + id: Unique identifier for the agent config (UUID) + name: Human-readable name for the agent config + description: Description of what the agent does + agent_type: Type identifier (e.g., "query_expander", "reranker", "pdf_generator") + stage: Pipeline stage where this agent executes + handler_module: Python module path for the handler + handler_class: Class name within the handler module + default_config: Default configuration JSONB for the agent + timeout_seconds: Maximum execution time before circuit breaker trips + max_retries: Maximum retry attempts on failure + priority: Default execution priority (lower = earlier execution) + is_system: Whether this is a built-in system agent + owner_user_id: User who created this agent config (null for system agents) + status: Current status (active, disabled, deprecated) + created_at: Timestamp of creation + updated_at: Timestamp of last update + """ + + __tablename__ = "agent_configs" + + __table_args__ = ( + Index("ix_agent_configs_stage", "stage"), + Index("ix_agent_configs_agent_type", "agent_type"), + Index("ix_agent_configs_status", "status"), + Index("ix_agent_configs_owner_status", "owner_user_id", "status"), + ) + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + primary_key=True, + default=IdentityService.generate_id, + ) + name: Mapped[str] = mapped_column( + String(255), + nullable=False, + comment="Human-readable agent config name", + ) + description: Mapped[str | None] = mapped_column( + Text, + nullable=True, + comment="Description of agent functionality", + ) + agent_type: Mapped[str] = mapped_column( + String(100), + nullable=False, + index=True, + comment="Agent type identifier (e.g., query_expander, reranker)", + ) + stage: Mapped[str] = mapped_column( + String(50), + nullable=False, + default=AgentStage.PRE_SEARCH.value, + comment="Pipeline stage (pre_search, post_search, response)", + ) + handler_module: Mapped[str] = mapped_column( + String(500), + nullable=False, + comment="Python module path for the handler", + ) + handler_class: Mapped[str] = mapped_column( + String(255), + nullable=False, + comment="Class name within the handler module", + ) + default_config: Mapped[dict] = mapped_column( + JSONB, + nullable=False, + default=dict, + comment="Default configuration for the agent", + ) + timeout_seconds: Mapped[int] = mapped_column( + Integer, + nullable=False, + default=30, + comment="Maximum execution time in seconds", + ) + max_retries: Mapped[int] = mapped_column( + Integer, + nullable=False, + default=2, + comment="Maximum retry attempts on failure", + ) + priority: Mapped[int] = mapped_column( + Integer, + nullable=False, + default=100, + comment="Default execution 
priority (lower = earlier)", + ) + is_system: Mapped[bool] = mapped_column( + nullable=False, + default=False, + comment="Whether this is a built-in system agent", + ) + owner_user_id: Mapped[uuid.UUID | None] = mapped_column( + UUID(as_uuid=True), + ForeignKey("users.id", ondelete="SET NULL"), + nullable=True, + index=True, + comment="User who created this config (null for system)", + ) + status: Mapped[str] = mapped_column( + String(50), + nullable=False, + default=AgentConfigStatus.ACTIVE.value, + comment="Status (active, disabled, deprecated)", + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(UTC), + nullable=False, + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(UTC), + onupdate=lambda: datetime.now(UTC), + nullable=False, + ) + + # Relationships + owner: Mapped[User | None] = relationship( + "User", + back_populates="agent_configs", + foreign_keys=[owner_user_id], + ) + collection_associations: Mapped[list["CollectionAgent"]] = relationship( + "CollectionAgent", + back_populates="agent_config", + cascade="all, delete-orphan", + ) + + def __repr__(self) -> str: + """String representation of the agent config.""" + return ( + f"AgentConfig(id='{self.id}', name='{self.name}', " + f"agent_type='{self.agent_type}', stage='{self.stage}', status='{self.status}')" + ) + + def is_active(self) -> bool: + """Check if the agent config is active.""" + return self.status == AgentConfigStatus.ACTIVE.value + + +class CollectionAgent(Base): + """Junction table for collection-agent associations. + + This table links collections to agent configs with collection-specific + configuration overrides and priority settings. + + Attributes: + id: Unique identifier for the association + collection_id: UUID of the collection + agent_config_id: UUID of the agent configuration + enabled: Whether this agent is enabled for the collection + priority: Execution priority override (lower = earlier) + config_override: Collection-specific configuration overrides + created_at: Timestamp of association creation + updated_at: Timestamp of last update + """ + + __tablename__ = "collection_agents" + + __table_args__ = ( + Index("ix_collection_agents_collection", "collection_id"), + Index("ix_collection_agents_agent", "agent_config_id"), + Index("ix_collection_agents_enabled", "collection_id", "enabled"), + Index("ix_collection_agents_priority", "collection_id", "priority"), + ) + + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + primary_key=True, + default=IdentityService.generate_id, + ) + collection_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("collections.id", ondelete="CASCADE"), + nullable=False, + comment="Collection this agent is attached to", + ) + agent_config_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), + ForeignKey("agent_configs.id", ondelete="CASCADE"), + nullable=False, + comment="Agent configuration to use", + ) + enabled: Mapped[bool] = mapped_column( + nullable=False, + default=True, + comment="Whether agent is enabled for this collection", + ) + priority: Mapped[int] = mapped_column( + Integer, + nullable=False, + default=100, + comment="Execution priority override (lower = earlier)", + ) + config_override: Mapped[dict] = mapped_column( + JSONB, + nullable=False, + default=dict, + comment="Collection-specific configuration overrides", + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: 
datetime.now(UTC), + nullable=False, + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now(UTC), + onupdate=lambda: datetime.now(UTC), + nullable=False, + ) + + # Relationships + agent_config: Mapped[AgentConfig] = relationship( + "AgentConfig", + back_populates="collection_associations", + ) + collection: Mapped["Collection"] = relationship( # noqa: F821 + "Collection", + back_populates="agent_associations", + ) + + def __repr__(self) -> str: + """String representation of the collection-agent association.""" + return ( + f"CollectionAgent(id='{self.id}', collection_id='{self.collection_id}', " + f"agent_config_id='{self.agent_config_id}', enabled={self.enabled}, priority={self.priority})" + ) + + def get_merged_config(self) -> dict: + """Get merged configuration (default + overrides). + + Returns: + Merged configuration dictionary + """ + if not self.agent_config: + return self.config_override + + merged = dict(self.agent_config.default_config) + merged.update(self.config_override) + return merged diff --git a/backend/rag_solution/models/collection.py b/backend/rag_solution/models/collection.py index 6fd1e6f0..b9db53a0 100644 --- a/backend/rag_solution/models/collection.py +++ b/backend/rag_solution/models/collection.py @@ -16,6 +16,7 @@ from rag_solution.schemas.collection_schema import CollectionStatus if TYPE_CHECKING: + from rag_solution.models.agent_config import CollectionAgent from rag_solution.models.conversation import ConversationSession from rag_solution.models.file import File from rag_solution.models.podcast import Podcast @@ -63,6 +64,9 @@ class Collection(Base): # pylint: disable=too-few-public-methods podcasts: Mapped[list["Podcast"]] = relationship( "Podcast", back_populates="collection", cascade="all, delete-orphan" ) + agent_associations: Mapped[list["CollectionAgent"]] = relationship( + "CollectionAgent", back_populates="collection", cascade="all, delete-orphan" + ) def __repr__(self) -> str: return f"Collection(id='{self.id}', name='{self.name}', is_private={self.is_private})" diff --git a/backend/rag_solution/models/user.py b/backend/rag_solution/models/user.py index c01d88b6..30ebe33d 100644 --- a/backend/rag_solution/models/user.py +++ b/backend/rag_solution/models/user.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from rag_solution.models.agent import Agent + from rag_solution.models.agent_config import AgentConfig from rag_solution.models.conversation import ConversationSession from rag_solution.models.file import File from rag_solution.models.llm_parameters import LLMParameters @@ -55,6 +56,9 @@ class User(Base): podcasts: Mapped[list[Podcast]] = relationship("Podcast", back_populates="user", cascade="all, delete-orphan") voices: Mapped[list[Voice]] = relationship("Voice", back_populates="user", cascade="all, delete-orphan") agents: Mapped[list[Agent]] = relationship("Agent", back_populates="owner", cascade="all, delete-orphan") + agent_configs: Mapped[list[AgentConfig]] = relationship( + "AgentConfig", back_populates="owner", cascade="all, delete-orphan" + ) def __repr__(self) -> str: return ( diff --git a/backend/rag_solution/repository/agent_config_repository.py b/backend/rag_solution/repository/agent_config_repository.py new file mode 100644 index 00000000..3060469c --- /dev/null +++ b/backend/rag_solution/repository/agent_config_repository.py @@ -0,0 +1,556 @@ +"""Repository for AgentConfig and CollectionAgent database operations. 
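+
+Typical usage (illustrative sketch; session setup elided)::
+
+    repo = AgentConfigRepository(db)
+    configs, total = repo.list_configs(stage="pre_search", limit=20)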
+
+This module provides data access for agent configurations and collection-agent
+associations used in the 3-stage search pipeline.
+
+Reference: GitHub Issue #697
+"""
+
+from pydantic import UUID4
+from sqlalchemy import func
+from sqlalchemy.exc import IntegrityError
+from sqlalchemy.orm import Session, joinedload
+
+from core.custom_exceptions import RepositoryError
+from core.logging_utils import get_logger
+from rag_solution.core.exceptions import AlreadyExistsError, NotFoundError, ValidationError
+from rag_solution.models.agent_config import AgentConfig, AgentConfigStatus, AgentStage, CollectionAgent
+from rag_solution.schemas.agent_config_schema import (
+    AgentConfigInput,
+    AgentConfigOutput,
+    AgentConfigUpdate,
+    CollectionAgentInput,
+    CollectionAgentOutput,
+    CollectionAgentUpdate,
+)
+
+logger = get_logger(__name__)
+
+
+class AgentConfigRepository:
+    """Repository for handling AgentConfig database operations."""
+
+    def __init__(self, db: Session) -> None:
+        """Initialize with database session.
+
+        Args:
+            db: SQLAlchemy database session
+        """
+        self.db = db
+
+    def create(
+        self,
+        config_input: AgentConfigInput,
+        owner_user_id: UUID4 | None = None,
+        is_system: bool = False,
+    ) -> AgentConfigOutput:
+        """Create a new agent configuration.
+
+        Args:
+            config_input: Agent config creation data
+            owner_user_id: UUID of the owning user (null for system configs)
+            is_system: Whether this is a system-level config
+
+        Returns:
+            Created agent config data
+
+        Raises:
+            ValidationError: If the config violates a database constraint
+            RepositoryError: For other database errors
+        """
+        try:
+            agent_config = AgentConfig(
+                name=config_input.name,
+                description=config_input.description,
+                agent_type=config_input.agent_type,
+                stage=config_input.stage.value,
+                handler_module=config_input.handler_module,
+                handler_class=config_input.handler_class,
+                default_config=config_input.default_config,
+                timeout_seconds=config_input.timeout_seconds,
+                max_retries=config_input.max_retries,
+                priority=config_input.priority,
+                is_system=is_system,
+                owner_user_id=owner_user_id,
+                status=AgentConfigStatus.ACTIVE.value,
+            )
+            self.db.add(agent_config)
+            self.db.commit()
+            self.db.refresh(agent_config)
+            return AgentConfigOutput.model_validate(agent_config)
+        except IntegrityError as e:
+            self.db.rollback()
+            raise ValidationError("Agent config violates a database constraint") from e
+        except Exception as e:
+            self.db.rollback()
+            logger.error("Error creating agent config: %s", e)
+            raise RepositoryError(f"Failed to create agent config: {e!s}") from e
+
+    def get_by_id(self, config_id: UUID4) -> AgentConfigOutput:
+        """Fetch agent config by ID.
+
+        Args:
+            config_id: UUID of the agent config
+
+        Returns:
+            Agent config data
+
+        Raises:
+            NotFoundError: If config not found
+            RepositoryError: For database errors
+        """
+        try:
+            config = (
+                self.db.query(AgentConfig)
+                .filter(AgentConfig.id == config_id)
+                .options(joinedload(AgentConfig.owner))
+                .first()
+            )
+            if not config:
+                raise NotFoundError("AgentConfig", resource_id=str(config_id))
+            return AgentConfigOutput.model_validate(config)
+        except NotFoundError:
+            raise
+        except Exception as e:
+            logger.error("Error getting agent config %s: %s", config_id, e)
+            raise RepositoryError(f"Failed to get agent config by ID: {e!s}") from e
+
+    def update(self, config_id: UUID4, config_update: AgentConfigUpdate) -> AgentConfigOutput:
+        """Update agent config data.
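+
+        Only fields explicitly set on ``config_update`` are applied
+        (``model_dump(exclude_unset=True)``); ``None`` values are skipped,
+        so a field cannot be cleared through this method.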
+ + Args: + config_id: UUID of the agent config + config_update: Update data + + Returns: + Updated agent config data + + Raises: + NotFoundError: If config not found + RepositoryError: For database errors + """ + try: + config = self.db.query(AgentConfig).filter(AgentConfig.id == config_id).first() + if not config: + raise NotFoundError("AgentConfig", resource_id=str(config_id)) + + # Update only provided fields + update_data = config_update.model_dump(exclude_unset=True) + for key, value in update_data.items(): + if value is not None: + if key == "status" and hasattr(value, "value"): + value = value.value + setattr(config, key, value) + + self.db.commit() + self.db.refresh(config) + return AgentConfigOutput.model_validate(config) + except NotFoundError: + raise + except Exception as e: + logger.error("Error updating agent config %s: %s", config_id, e) + self.db.rollback() + raise RepositoryError(f"Failed to update agent config: {e!s}") from e + + def delete(self, config_id: UUID4) -> bool: + """Delete an agent config. + + Args: + config_id: UUID of the agent config + + Returns: + True if deleted, False if not found + """ + try: + result = self.db.query(AgentConfig).filter(AgentConfig.id == config_id).delete() + self.db.commit() + return result > 0 + except Exception as e: + logger.error("Error deleting agent config %s: %s", config_id, e) + self.db.rollback() + raise RepositoryError(f"Failed to delete agent config: {e!s}") from e + + def list_configs( + self, + skip: int = 0, + limit: int = 100, + owner_user_id: UUID4 | None = None, + stage: str | None = None, + agent_type: str | None = None, + status: str | None = None, + include_system: bool = True, + ) -> tuple[list[AgentConfigOutput], int]: + """List agent configs with optional filters and pagination. + + Args: + skip: Number of records to skip + limit: Maximum number of records to return + owner_user_id: Filter by owner user ID + stage: Filter by stage + agent_type: Filter by agent type + status: Filter by status + include_system: Whether to include system configs + + Returns: + Tuple of (list of configs, total count) + """ + try: + query = self.db.query(AgentConfig) + + # Apply filters + if owner_user_id: + query = query.filter(AgentConfig.owner_user_id == owner_user_id) + if stage: + query = query.filter(AgentConfig.stage == stage) + if agent_type: + query = query.filter(AgentConfig.agent_type == agent_type) + if status: + query = query.filter(AgentConfig.status == status) + if not include_system: + query = query.filter(AgentConfig.is_system.is_(False)) + + # Get total count + total = query.count() + + # Apply pagination and fetch + configs = ( + query.options(joinedload(AgentConfig.owner)) + .order_by(AgentConfig.stage, AgentConfig.priority, AgentConfig.name) + .offset(skip) + .limit(limit) + .all() + ) + + return ([AgentConfigOutput.model_validate(c) for c in configs], total) + except Exception as e: + logger.error("Error listing agent configs: %s", e) + raise RepositoryError(f"Failed to list agent configs: {e!s}") from e + + def list_by_stage( + self, + stage: AgentStage, + include_system: bool = True, + ) -> list[AgentConfigOutput]: + """List active agent configs for a specific stage. 
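+
+        Only configs with status ``active`` are returned, ordered by
+        priority (ascending), then name.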
+
+        Args:
+            stage: Pipeline stage
+            include_system: Whether to include system configs
+
+        Returns:
+            List of active agent configs for the stage
+        """
+        try:
+            query = self.db.query(AgentConfig).filter(
+                AgentConfig.stage == stage.value,
+                AgentConfig.status == AgentConfigStatus.ACTIVE.value,
+            )
+
+            if not include_system:
+                query = query.filter(AgentConfig.is_system.is_(False))
+
+            configs = query.order_by(AgentConfig.priority, AgentConfig.name).all()
+            return [AgentConfigOutput.model_validate(c) for c in configs]
+        except Exception as e:
+            logger.error("Error listing configs by stage %s: %s", stage, e)
+            raise RepositoryError(f"Failed to list configs by stage: {e!s}") from e
+
+    def count_by_owner(self, owner_user_id: UUID4) -> int:
+        """Count agent configs owned by a user.
+
+        Args:
+            owner_user_id: UUID of the owner
+
+        Returns:
+            Number of configs
+        """
+        try:
+            return (
+                self.db.query(func.count(AgentConfig.id)).filter(AgentConfig.owner_user_id == owner_user_id).scalar()
+                or 0
+            )
+        except Exception as e:
+            logger.error("Error counting configs for owner %s: %s", owner_user_id, e)
+            return 0
+
+
+class CollectionAgentRepository:
+    """Repository for handling CollectionAgent database operations."""
+
+    def __init__(self, db: Session) -> None:
+        """Initialize with database session.
+
+        Args:
+            db: SQLAlchemy database session
+        """
+        self.db = db
+
+    def create(
+        self,
+        collection_id: UUID4,
+        association_input: CollectionAgentInput,
+    ) -> CollectionAgentOutput:
+        """Create a new collection-agent association.
+
+        Args:
+            collection_id: UUID of the collection
+            association_input: Association creation data
+
+        Returns:
+            Created association data
+
+        Raises:
+            AlreadyExistsError: If association already exists
+            NotFoundError: If collection or agent config not found
+            RepositoryError: For database errors
+        """
+        try:
+            # Check if association already exists
+            existing = (
+                self.db.query(CollectionAgent)
+                .filter(
+                    CollectionAgent.collection_id == collection_id,
+                    CollectionAgent.agent_config_id == association_input.agent_config_id,
+                )
+                .first()
+            )
+            if existing:
+                raise AlreadyExistsError(
+                    "CollectionAgent",
+                    "collection_id+agent_config_id",
+                    f"{collection_id}+{association_input.agent_config_id}",
+                )
+
+            association = CollectionAgent(
+                collection_id=collection_id,
+                agent_config_id=association_input.agent_config_id,
+                enabled=association_input.enabled,
+                priority=association_input.priority,
+                config_override=association_input.config_override,
+            )
+            self.db.add(association)
+            self.db.commit()
+            self.db.refresh(association)
+
+            # Load the agent_config relationship
+            self.db.refresh(association, attribute_names=["agent_config"])
+
+            return CollectionAgentOutput.model_validate(association)
+        except AlreadyExistsError:
+            raise
+        except IntegrityError as e:
+            self.db.rollback()
+            if "collection_id" in str(e):
+                raise NotFoundError("Collection", resource_id=str(collection_id)) from e
+            if "agent_config_id" in str(e):
+                raise NotFoundError("AgentConfig", resource_id=str(association_input.agent_config_id)) from e
+            raise ValidationError("Association violates a database constraint") from e
+        except Exception as e:
+            self.db.rollback()
+            logger.error("Error creating collection-agent association: %s", e)
+            raise RepositoryError(f"Failed to create association: {e!s}") from e
+
+    def get_by_id(self, association_id: UUID4) -> CollectionAgentOutput:
+        """Fetch association by ID.
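+
+        The related ``agent_config`` is eagerly loaded so the output schema
+        can embed it without an extra query.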
+ + Args: + association_id: UUID of the association + + Returns: + Association data + + Raises: + NotFoundError: If association not found + RepositoryError: For database errors + """ + try: + association = ( + self.db.query(CollectionAgent) + .filter(CollectionAgent.id == association_id) + .options(joinedload(CollectionAgent.agent_config)) + .first() + ) + if not association: + raise NotFoundError("CollectionAgent", resource_id=str(association_id)) + return CollectionAgentOutput.model_validate(association) + except NotFoundError: + raise + except Exception as e: + logger.error("Error getting association %s: %s", association_id, e) + raise RepositoryError(f"Failed to get association by ID: {e!s}") from e + + def update( + self, + association_id: UUID4, + update_data: CollectionAgentUpdate, + ) -> CollectionAgentOutput: + """Update a collection-agent association. + + Args: + association_id: UUID of the association + update_data: Update data + + Returns: + Updated association data + + Raises: + NotFoundError: If association not found + RepositoryError: For database errors + """ + try: + association = self.db.query(CollectionAgent).filter(CollectionAgent.id == association_id).first() + if not association: + raise NotFoundError("CollectionAgent", resource_id=str(association_id)) + + # Update only provided fields + data = update_data.model_dump(exclude_unset=True) + for key, value in data.items(): + if value is not None: + setattr(association, key, value) + + self.db.commit() + self.db.refresh(association) + return CollectionAgentOutput.model_validate(association) + except NotFoundError: + raise + except Exception as e: + logger.error("Error updating association %s: %s", association_id, e) + self.db.rollback() + raise RepositoryError(f"Failed to update association: {e!s}") from e + + def delete(self, association_id: UUID4) -> bool: + """Delete an association. + + Args: + association_id: UUID of the association + + Returns: + True if deleted, False if not found + """ + try: + result = self.db.query(CollectionAgent).filter(CollectionAgent.id == association_id).delete() + self.db.commit() + return result > 0 + except Exception as e: + logger.error("Error deleting association %s: %s", association_id, e) + self.db.rollback() + raise RepositoryError(f"Failed to delete association: {e!s}") from e + + def list_by_collection( + self, + collection_id: UUID4, + stage: str | None = None, + enabled_only: bool = False, + ) -> list[CollectionAgentOutput]: + """List associations for a collection. 
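+
+        Joins to ``AgentConfig`` so results can be filtered by pipeline stage;
+        ordering is by association priority, then agent name.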
+ + Args: + collection_id: UUID of the collection + stage: Optional stage filter + enabled_only: Only return enabled associations + + Returns: + List of associations ordered by priority + """ + try: + query = ( + self.db.query(CollectionAgent) + .join(AgentConfig) + .filter(CollectionAgent.collection_id == collection_id) + .options(joinedload(CollectionAgent.agent_config)) + ) + + if stage: + query = query.filter(AgentConfig.stage == stage) + if enabled_only: + query = query.filter(CollectionAgent.enabled.is_(True)) + + associations = query.order_by(CollectionAgent.priority, AgentConfig.name).all() + return [CollectionAgentOutput.model_validate(a) for a in associations] + except Exception as e: + logger.error("Error listing associations for collection %s: %s", collection_id, e) + raise RepositoryError(f"Failed to list associations: {e!s}") from e + + def batch_update_priorities( + self, + collection_id: UUID4, + priorities: dict[UUID4, int], + ) -> list[CollectionAgentOutput]: + """Batch update priorities for multiple associations. + + Args: + collection_id: UUID of the collection + priorities: Mapping of association ID to new priority + + Returns: + List of updated associations + """ + try: + updated = [] + for assoc_id, priority in priorities.items(): + association = ( + self.db.query(CollectionAgent) + .filter( + CollectionAgent.id == assoc_id, + CollectionAgent.collection_id == collection_id, + ) + .first() + ) + if association: + association.priority = priority + updated.append(association) + + self.db.commit() + + # Refresh and return + for assoc in updated: + self.db.refresh(assoc) + + return [CollectionAgentOutput.model_validate(a) for a in updated] + except Exception as e: + logger.error("Error batch updating priorities: %s", e) + self.db.rollback() + raise RepositoryError(f"Failed to batch update priorities: {e!s}") from e + + def count_by_collection(self, collection_id: UUID4, enabled_only: bool = False) -> int: + """Count associations for a collection. + + Args: + collection_id: UUID of the collection + enabled_only: Only count enabled associations + + Returns: + Number of associations + """ + try: + query = self.db.query(func.count(CollectionAgent.id)).filter( + CollectionAgent.collection_id == collection_id + ) + if enabled_only: + query = query.filter(CollectionAgent.enabled.is_(True)) + return query.scalar() or 0 + except Exception as e: + logger.error("Error counting associations for collection %s: %s", collection_id, e) + return 0 + + def delete_by_collection(self, collection_id: UUID4) -> int: + """Delete all associations for a collection. + + Args: + collection_id: UUID of the collection + + Returns: + Number of deleted associations + """ + try: + result = ( + self.db.query(CollectionAgent).filter(CollectionAgent.collection_id == collection_id).delete() + ) + self.db.commit() + return result + except Exception as e: + logger.error("Error deleting associations for collection %s: %s", collection_id, e) + self.db.rollback() + raise RepositoryError(f"Failed to delete associations: {e!s}") from e diff --git a/backend/rag_solution/router/agent_config_router.py b/backend/rag_solution/router/agent_config_router.py new file mode 100644 index 00000000..37eedf3f --- /dev/null +++ b/backend/rag_solution/router/agent_config_router.py @@ -0,0 +1,590 @@ +"""Router for Agent Configuration API endpoints. + +This module provides REST API endpoints for managing agent configurations +and collection-agent associations for the 3-stage search pipeline. 
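+
+Example request body for creating a config (illustrative; the handler module
+path and class name are placeholders, not shipped handlers)::
+
+    POST /api/agent-configs
+    {
+        "name": "Query Expander",
+        "agent_type": "query_expander",
+        "stage": "pre_search",
+        "handler_module": "rag_solution.agents.query_expander",
+        "handler_class": "QueryExpanderAgent"
+    }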
+
+Endpoints:
+    Agent Configurations:
+    - POST /api/agent-configs - Create a new agent config
+    - GET /api/agent-configs - List agent configs
+    - GET /api/agent-configs/{config_id} - Get agent config by ID
+    - PUT /api/agent-configs/{config_id} - Update agent config
+    - DELETE /api/agent-configs/{config_id} - Delete agent config
+    - GET /api/agent-configs/stages/{stage} - List configs by stage
+
+    Collection-Agent Associations:
+    - POST /api/collections/{collection_id}/agents - Add agent to collection
+    - GET /api/collections/{collection_id}/agents - List collection agents
+    - GET /api/collections/{collection_id}/agents/summary - Get agent summary
+    - PUT /api/collections/{collection_id}/agents/{association_id} - Update association
+    - DELETE /api/collections/{collection_id}/agents/{association_id} - Remove association
+    - POST /api/collections/{collection_id}/agents/priorities - Batch update priorities
+
+Reference: GitHub Issue #697
+"""
+
+from uuid import UUID
+
+from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
+from pydantic import UUID4
+from sqlalchemy.orm import Session
+
+from core.logging_utils import get_logger
+from rag_solution.core.exceptions import AlreadyExistsError, NotFoundError, ValidationError
+from rag_solution.file_management.database import get_db
+from rag_solution.schemas.agent_config_schema import (
+    AgentConfigInput,
+    AgentConfigListResponse,
+    AgentConfigOutput,
+    AgentConfigUpdate,
+    BatchPriorityUpdate,
+    CollectionAgentInput,
+    CollectionAgentListResponse,
+    CollectionAgentOutput,
+    CollectionAgentUpdate,
+)
+from rag_solution.services.agent_config_service import AgentConfigService
+
+logger = get_logger(__name__)
+
+# ============================================================================
+# Agent Configuration Router
+# ============================================================================
+
+config_router = APIRouter(
+    prefix="/api/agent-configs",
+    tags=["agent-configs"],
+    responses={
+        401: {"description": "Unauthorized"},
+        403: {"description": "Forbidden"},
+        404: {"description": "Agent config not found"},
+    },
+)
+
+
+def get_current_user_id(request: Request) -> UUID4:
+    """Extract current user ID from request state.
+
+    Args:
+        request: FastAPI request object
+
+    Returns:
+        User UUID
+
+    Raises:
+        HTTPException: If user not authenticated
+    """
+    user = getattr(request.state, "user", None)
+    if not user:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Authentication required",
+        )
+
+    user_id = user.get("uuid") or user.get("id")
+    if not user_id:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="User ID not found in token",
+        )
+
+    # The token payload may carry the ID as a string or a UUID instance;
+    # normalize to a stdlib UUID (pydantic's UUID4 alias is an Annotated
+    # type and cannot be called as a constructor).
+    return UUID(str(user_id))
+
+
+@config_router.post(
+    "",
+    response_model=AgentConfigOutput,
+    status_code=status.HTTP_201_CREATED,
+    summary="Create agent config",
+    description="Create a new agent configuration for the search pipeline.",
+)
+async def create_agent_config(
+    request: Request,
+    config_input: AgentConfigInput,
+    db: Session = Depends(get_db),
+) -> AgentConfigOutput:
+    """Create a new agent configuration.
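+
+    The authenticated caller becomes the config owner; this endpoint always
+    creates user-owned (non-system) configs.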
+ + Args: + request: FastAPI request object + config_input: Agent config creation data + db: Database session + + Returns: + Created agent config + """ + try: + owner_user_id = get_current_user_id(request) + service = AgentConfigService(db) + return service.create_config(config_input, owner_user_id) + except ValidationError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) from e + except Exception as e: + logger.error("Error creating agent config: %s", e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to create agent config", + ) from e + + +@config_router.get( + "", + response_model=AgentConfigListResponse, + summary="List agent configs", + description="List agent configurations with optional filtering.", +) +async def list_agent_configs( + request: Request, + skip: int = Query(0, ge=0, description="Number of records to skip"), + limit: int = Query(100, ge=1, le=1000, description="Maximum records to return"), + stage: str | None = Query(None, description="Filter by pipeline stage"), + agent_type: str | None = Query(None, description="Filter by agent type"), + config_status: str | None = Query(None, alias="status", description="Filter by status"), + mine_only: bool = Query(False, description="Only show configs owned by current user"), + include_system: bool = Query(True, description="Include system configs"), + db: Session = Depends(get_db), +) -> AgentConfigListResponse: + """List agent configurations with filtering. + + Args: + request: FastAPI request object + skip: Pagination offset + limit: Maximum records + stage: Filter by stage + agent_type: Filter by type + config_status: Filter by status + mine_only: Only show owned configs + include_system: Include system configs + db: Database session + + Returns: + Paginated list of configs + """ + try: + owner_user_id = None + if mine_only: + owner_user_id = get_current_user_id(request) + + service = AgentConfigService(db) + return service.list_configs( + skip=skip, + limit=limit, + owner_user_id=owner_user_id, + stage=stage, + agent_type=agent_type, + status=config_status, + include_system=include_system, + ) + except Exception as e: + logger.error("Error listing agent configs: %s", e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to list agent configs", + ) from e + + +@config_router.get( + "/stages/{stage}", + response_model=list[AgentConfigOutput], + summary="List configs by stage", + description="List active agent configurations for a specific pipeline stage.", +) +async def list_configs_by_stage( + stage: str, + include_system: bool = Query(True, description="Include system configs"), + db: Session = Depends(get_db), +) -> list[AgentConfigOutput]: + """List configs for a specific stage. 
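+
+    Valid stages are ``pre_search``, ``post_search``, and ``response``; an
+    unrecognized stage results in HTTP 400.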
+ + Args: + stage: Pipeline stage (pre_search, post_search, response) + include_system: Include system configs + db: Database session + + Returns: + List of configs for the stage + """ + try: + service = AgentConfigService(db) + return service.list_by_stage(stage, include_system) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) from e + except Exception as e: + logger.error("Error listing configs by stage: %s", e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to list configs by stage", + ) from e + + +@config_router.get( + "/{config_id}", + response_model=AgentConfigOutput, + summary="Get agent config", + description="Get a specific agent configuration by ID.", +) +async def get_agent_config( + config_id: UUID4, + db: Session = Depends(get_db), +) -> AgentConfigOutput: + """Get agent config by ID. + + Args: + config_id: UUID of the config + db: Database session + + Returns: + Agent config + """ + try: + service = AgentConfigService(db) + return service.get_config(config_id) + except NotFoundError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e), + ) from e + except Exception as e: + logger.error("Error getting agent config %s: %s", config_id, e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to get agent config", + ) from e + + +@config_router.put( + "/{config_id}", + response_model=AgentConfigOutput, + summary="Update agent config", + description="Update an existing agent configuration.", +) +async def update_agent_config( + config_id: UUID4, + config_update: AgentConfigUpdate, + db: Session = Depends(get_db), +) -> AgentConfigOutput: + """Update agent config. + + Args: + config_id: UUID of the config + config_update: Update data + db: Database session + + Returns: + Updated agent config + """ + try: + service = AgentConfigService(db) + return service.update_config(config_id, config_update) + except NotFoundError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e), + ) from e + except ValidationError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) from e + except Exception as e: + logger.error("Error updating agent config %s: %s", config_id, e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to update agent config", + ) from e + + +@config_router.delete( + "/{config_id}", + status_code=status.HTTP_204_NO_CONTENT, + summary="Delete agent config", + description="Delete an agent configuration.", +) +async def delete_agent_config( + config_id: UUID4, + db: Session = Depends(get_db), +) -> None: + """Delete agent config. 
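+
+    Returns HTTP 204 on success and HTTP 404 if the config does not exist.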
+ + Args: + config_id: UUID of the config + db: Database session + """ + try: + service = AgentConfigService(db) + deleted = service.delete_config(config_id) + if not deleted: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Agent config {config_id} not found", + ) + except HTTPException: + raise + except Exception as e: + logger.error("Error deleting agent config %s: %s", config_id, e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to delete agent config", + ) from e + + +# ============================================================================ +# Collection-Agent Association Router +# ============================================================================ + +collection_agent_router = APIRouter( + prefix="/api/collections/{collection_id}/agents", + tags=["collection-agents"], + responses={ + 401: {"description": "Unauthorized"}, + 403: {"description": "Forbidden"}, + 404: {"description": "Collection or association not found"}, + }, +) + + +@collection_agent_router.post( + "", + response_model=CollectionAgentOutput, + status_code=status.HTTP_201_CREATED, + summary="Add agent to collection", + description="Associate an agent configuration with a collection.", +) +async def add_agent_to_collection( + collection_id: UUID4, + association_input: CollectionAgentInput, + db: Session = Depends(get_db), +) -> CollectionAgentOutput: + """Add agent to collection. + + Args: + collection_id: UUID of the collection + association_input: Association data + db: Database session + + Returns: + Created association + """ + try: + service = AgentConfigService(db) + return service.add_agent_to_collection(collection_id, association_input) + except AlreadyExistsError as e: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=str(e), + ) from e + except NotFoundError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e), + ) from e + except ValidationError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e), + ) from e + except Exception as e: + logger.error("Error adding agent to collection %s: %s", collection_id, e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to add agent to collection", + ) from e + + +@collection_agent_router.get( + "", + response_model=CollectionAgentListResponse, + summary="List collection agents", + description="List all agents associated with a collection.", +) +async def list_collection_agents( + collection_id: UUID4, + stage: str | None = Query(None, description="Filter by pipeline stage"), + enabled_only: bool = Query(False, description="Only show enabled agents"), + db: Session = Depends(get_db), +) -> CollectionAgentListResponse: + """List agents for a collection. 
+ + Args: + collection_id: UUID of the collection + stage: Filter by stage + enabled_only: Only enabled agents + db: Database session + + Returns: + List of associations + """ + try: + service = AgentConfigService(db) + return service.list_collection_agents( + collection_id=collection_id, + stage=stage, + enabled_only=enabled_only, + ) + except Exception as e: + logger.error("Error listing agents for collection %s: %s", collection_id, e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to list collection agents", + ) from e + + +@collection_agent_router.get( + "/summary", + summary="Get agent summary", + description="Get a summary of agents for a collection by stage.", +) +async def get_collection_agent_summary( + collection_id: UUID4, + db: Session = Depends(get_db), +) -> dict: + """Get agent summary for a collection. + + Args: + collection_id: UUID of the collection + db: Database session + + Returns: + Summary with counts per stage + """ + try: + service = AgentConfigService(db) + return service.get_collection_agent_summary(collection_id) + except Exception as e: + logger.error("Error getting agent summary for collection %s: %s", collection_id, e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to get agent summary", + ) from e + + +@collection_agent_router.put( + "/{association_id}", + response_model=CollectionAgentOutput, + summary="Update association", + description="Update a collection-agent association.", +) +async def update_collection_agent( + collection_id: UUID4, + association_id: UUID4, + update_data: CollectionAgentUpdate, + db: Session = Depends(get_db), +) -> CollectionAgentOutput: + """Update collection-agent association. + + Args: + collection_id: UUID of the collection (for validation) + association_id: UUID of the association + update_data: Update data + db: Database session + + Returns: + Updated association + """ + try: + service = AgentConfigService(db) + # Verify association belongs to collection + assoc = service.get_association(association_id) + if assoc.collection_id != collection_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Association {association_id} not found in collection {collection_id}", + ) + return service.update_association(association_id, update_data) + except NotFoundError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e), + ) from e + except HTTPException: + raise + except Exception as e: + logger.error("Error updating association %s: %s", association_id, e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to update association", + ) from e + + +@collection_agent_router.delete( + "/{association_id}", + status_code=status.HTTP_204_NO_CONTENT, + summary="Remove agent from collection", + description="Remove an agent from a collection.", +) +async def remove_agent_from_collection( + collection_id: UUID4, + association_id: UUID4, + db: Session = Depends(get_db), +) -> None: + """Remove agent from collection. 
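+
+    The association must belong to the given collection; otherwise HTTP 404
+    is returned.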
+ + Args: + collection_id: UUID of the collection + association_id: UUID of the association + db: Database session + """ + try: + service = AgentConfigService(db) + # Verify association belongs to collection + assoc = service.get_association(association_id) + if assoc.collection_id != collection_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Association {association_id} not found in collection {collection_id}", + ) + deleted = service.remove_agent_from_collection(association_id) + if not deleted: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Association {association_id} not found", + ) + except NotFoundError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e), + ) from e + except HTTPException: + raise + except Exception as e: + logger.error("Error removing association %s: %s", association_id, e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to remove agent from collection", + ) from e + + +@collection_agent_router.post( + "/priorities", + response_model=list[CollectionAgentOutput], + summary="Batch update priorities", + description="Batch update priorities for multiple collection-agent associations.", +) +async def batch_update_priorities( + collection_id: UUID4, + priority_update: BatchPriorityUpdate, + db: Session = Depends(get_db), +) -> list[CollectionAgentOutput]: + """Batch update priorities. + + Args: + collection_id: UUID of the collection + priority_update: Priority updates + db: Database session + + Returns: + List of updated associations + """ + try: + service = AgentConfigService(db) + return service.batch_update_priorities(collection_id, priority_update.priorities) + except Exception as e: + logger.error("Error batch updating priorities for collection %s: %s", collection_id, e) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to batch update priorities", + ) from e diff --git a/backend/rag_solution/schemas/agent_config_schema.py b/backend/rag_solution/schemas/agent_config_schema.py new file mode 100644 index 00000000..16a53abd --- /dev/null +++ b/backend/rag_solution/schemas/agent_config_schema.py @@ -0,0 +1,342 @@ +"""Pydantic schemas for Agent Configuration and Execution Hooks. + +This module defines the request/response schemas for the Agent Configuration API +and execution hooks for the 3-stage search pipeline. + +The 3-stage pipeline supports: +- Stage 1: Pre-Search Agents (query expansion, language detection, etc.) 
+- Stage 2: Post-Search Agents (re-ranking, deduplication, PII redaction) +- Stage 3: Response Agents (PowerPoint, PDF, Chart generation) + +Reference: GitHub Issue #697 +""" + +from datetime import datetime +from enum import Enum +from typing import Any + +from pydantic import UUID4, BaseModel, ConfigDict, Field + + +class AgentStage(str, Enum): + """Pipeline stages where agents can execute.""" + + PRE_SEARCH = "pre_search" + POST_SEARCH = "post_search" + RESPONSE = "response" + + +class AgentConfigStatus(str, Enum): + """Agent configuration status.""" + + ACTIVE = "active" + DISABLED = "disabled" + DEPRECATED = "deprecated" + + +class AgentExecutionStatus(str, Enum): + """Status of an individual agent execution.""" + + SUCCESS = "success" + FAILED = "failed" + TIMEOUT = "timeout" + SKIPPED = "skipped" + CIRCUIT_OPEN = "circuit_open" + + +# ============================================================================ +# Agent Configuration Schemas +# ============================================================================ + + +class AgentConfigInput(BaseModel): + """Schema for creating a new agent configuration. + + Attributes: + name: Human-readable name for the agent config + description: Description of what the agent does + agent_type: Type identifier (e.g., "query_expander", "reranker") + stage: Pipeline stage where this agent executes + handler_module: Python module path for the handler + handler_class: Class name within the handler module + default_config: Default configuration for the agent + timeout_seconds: Maximum execution time before circuit breaker trips + max_retries: Maximum retry attempts on failure + priority: Default execution priority (lower = earlier execution) + """ + + name: str = Field(..., min_length=1, max_length=255, description="Human-readable agent config name") + description: str | None = Field(default=None, max_length=2000, description="Agent description") + agent_type: str = Field(..., min_length=1, max_length=100, description="Agent type identifier") + stage: AgentStage = Field(default=AgentStage.PRE_SEARCH, description="Pipeline stage") + handler_module: str = Field(..., min_length=1, max_length=500, description="Python module path") + handler_class: str = Field(..., min_length=1, max_length=255, description="Handler class name") + default_config: dict[str, Any] = Field(default_factory=dict, description="Default configuration") + timeout_seconds: int = Field(default=30, ge=1, le=300, description="Max execution time") + max_retries: int = Field(default=2, ge=0, le=5, description="Max retry attempts") + priority: int = Field(default=100, ge=0, le=1000, description="Execution priority") + + +class AgentConfigUpdate(BaseModel): + """Schema for updating an existing agent configuration. + + All fields are optional - only provided fields will be updated. + """ + + name: str | None = Field(default=None, min_length=1, max_length=255) + description: str | None = Field(default=None, max_length=2000) + default_config: dict[str, Any] | None = Field(default=None) + timeout_seconds: int | None = Field(default=None, ge=1, le=300) + max_retries: int | None = Field(default=None, ge=0, le=5) + priority: int | None = Field(default=None, ge=0, le=1000) + status: AgentConfigStatus | None = Field(default=None) + + +class AgentConfigOutput(BaseModel): + """Schema for agent configuration response data. 
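+
+    Instances are built directly from ORM objects via ``model_validate``
+    (``from_attributes=True``).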
+ + Attributes: + id: Unique identifier for the agent config + name: Human-readable name + description: Agent description + agent_type: Type identifier + stage: Pipeline stage + handler_module: Python module path + handler_class: Handler class name + default_config: Default configuration + timeout_seconds: Max execution time + max_retries: Max retry attempts + priority: Execution priority + is_system: Whether this is a built-in system agent + owner_user_id: User who created this config + status: Current status + created_at: Creation timestamp + updated_at: Last update timestamp + """ + + id: UUID4 + name: str + description: str | None + agent_type: str + stage: str + handler_module: str + handler_class: str + default_config: dict[str, Any] + timeout_seconds: int + max_retries: int + priority: int + is_system: bool + owner_user_id: UUID4 | None + status: str + created_at: datetime + updated_at: datetime + + model_config = ConfigDict(from_attributes=True) + + +class AgentConfigListResponse(BaseModel): + """Schema for paginated agent config list response.""" + + configs: list[AgentConfigOutput] + total: int + skip: int + limit: int + + +# ============================================================================ +# Collection-Agent Association Schemas +# ============================================================================ + + +class CollectionAgentInput(BaseModel): + """Schema for associating an agent with a collection. + + Attributes: + agent_config_id: UUID of the agent configuration + enabled: Whether agent is enabled for this collection + priority: Execution priority override + config_override: Collection-specific configuration overrides + """ + + agent_config_id: UUID4 = Field(..., description="Agent configuration ID") + enabled: bool = Field(default=True, description="Whether agent is enabled") + priority: int = Field(default=100, ge=0, le=1000, description="Priority override") + config_override: dict[str, Any] = Field(default_factory=dict, description="Config overrides") + + +class CollectionAgentUpdate(BaseModel): + """Schema for updating a collection-agent association.""" + + enabled: bool | None = Field(default=None) + priority: int | None = Field(default=None, ge=0, le=1000) + config_override: dict[str, Any] | None = Field(default=None) + + +class CollectionAgentOutput(BaseModel): + """Schema for collection-agent association response data.""" + + id: UUID4 + collection_id: UUID4 + agent_config_id: UUID4 + enabled: bool + priority: int + config_override: dict[str, Any] + created_at: datetime + updated_at: datetime + agent_config: AgentConfigOutput | None = None + + model_config = ConfigDict(from_attributes=True) + + +class CollectionAgentListResponse(BaseModel): + """Schema for collection agent associations list.""" + + associations: list[CollectionAgentOutput] + total: int + + +class BatchPriorityUpdate(BaseModel): + """Schema for batch updating agent priorities. + + Attributes: + priorities: Mapping of association ID to new priority + """ + + priorities: dict[UUID4, int] = Field(..., description="Association ID to priority mapping") + + +# ============================================================================ +# Agent Execution Schemas +# ============================================================================ + + +class AgentContext(BaseModel): + """Context passed to agents during pipeline execution. + + This provides all the information an agent needs to execute, + including the search input, retrieved documents, and stage-specific data. 
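+
+    Minimal sketch of a pre-search handler consuming this context
+    (``QueryExpanderAgent`` and its ``run`` signature are hypothetical, and
+    the executor is assumed to inject ``agent_config_id`` into metadata)::
+
+        class QueryExpanderAgent:
+            async def run(self, ctx: AgentContext) -> AgentResult:
+                # Rewrite the query; retrieval will see the modified form.
+                expanded = f"{ctx.query} OR related:{ctx.query}"
+                return AgentResult(
+                    agent_config_id=ctx.metadata["agent_config_id"],
+                    agent_name="query-expander",
+                    agent_type="query_expander",
+                    stage=ctx.stage.value,
+                    status=AgentExecutionStatus.SUCCESS,
+                    execution_time_ms=1.0,
+                    modified_query=expanded,
+                )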
+ + Attributes: + search_input: Original search request data + collection_id: Collection being searched + user_id: User making the request + stage: Current pipeline stage + query: Current query (may be modified by previous agents) + query_results: Retrieved documents (populated after retrieval stage) + previous_results: Results from previously executed agents in this stage + config: Merged configuration for this agent + metadata: Additional context metadata + """ + + search_input: dict[str, Any] = Field(..., description="Original search request") + collection_id: UUID4 = Field(..., description="Collection ID") + user_id: UUID4 = Field(..., description="User ID") + stage: AgentStage = Field(..., description="Current pipeline stage") + query: str = Field(..., description="Current query") + query_results: list[dict[str, Any]] = Field(default_factory=list, description="Retrieved documents") + previous_results: list["AgentResult"] = Field(default_factory=list, description="Previous agent results") + config: dict[str, Any] = Field(default_factory=dict, description="Merged agent config") + metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata") + + +class AgentResult(BaseModel): + """Result from an agent execution. + + Attributes: + agent_config_id: ID of the agent configuration + agent_name: Name of the agent + agent_type: Type of the agent + stage: Pipeline stage where executed + status: Execution status + execution_time_ms: Time taken in milliseconds + modified_query: Modified query (for pre-search agents) + modified_results: Modified results (for post-search agents) + artifacts: Generated artifacts (for response agents) + metadata: Additional result metadata + error_message: Error message if failed + """ + + agent_config_id: UUID4 = Field(..., description="Agent config ID") + agent_name: str = Field(..., description="Agent name") + agent_type: str = Field(..., description="Agent type") + stage: str = Field(..., description="Pipeline stage") + status: AgentExecutionStatus = Field(..., description="Execution status") + execution_time_ms: float = Field(..., description="Execution time in ms") + modified_query: str | None = Field(default=None, description="Modified query") + modified_results: list[dict[str, Any]] | None = Field(default=None, description="Modified results") + artifacts: list["AgentArtifact"] | None = Field(default=None, description="Generated artifacts") + metadata: dict[str, Any] = Field(default_factory=dict, description="Result metadata") + error_message: str | None = Field(default=None, description="Error message") + + +class AgentArtifact(BaseModel): + """Artifact generated by a response agent. + + Attributes: + artifact_type: Type of artifact (e.g., "pdf", "pptx", "chart") + content_type: MIME type of the content + filename: Suggested filename + data_url: Base64 data URL or download URL + metadata: Additional artifact metadata + """ + + artifact_type: str = Field(..., description="Artifact type") + content_type: str = Field(..., description="MIME type") + filename: str = Field(..., description="Suggested filename") + data_url: str | None = Field(default=None, description="Data URL or download URL") + size_bytes: int | None = Field(default=None, description="Size in bytes") + metadata: dict[str, Any] = Field(default_factory=dict, description="Artifact metadata") + + +class AgentExecutionSummary(BaseModel): + """Summary of all agent executions for a search request. 
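+
+    The counters aggregate across all three stages; per-stage details live in
+    the ``*_results`` lists.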
+ + Attributes: + total_agents: Total number of agents executed + successful: Number of successful executions + failed: Number of failed executions + skipped: Number of skipped executions + total_execution_time_ms: Total execution time in milliseconds + pre_search_results: Results from pre-search stage + post_search_results: Results from post-search stage + response_results: Results from response stage + artifacts: All generated artifacts + """ + + total_agents: int = Field(default=0, description="Total agents executed") + successful: int = Field(default=0, description="Successful executions") + failed: int = Field(default=0, description="Failed executions") + skipped: int = Field(default=0, description="Skipped executions") + total_execution_time_ms: float = Field(default=0.0, description="Total execution time") + pre_search_results: list[AgentResult] = Field(default_factory=list) + post_search_results: list[AgentResult] = Field(default_factory=list) + response_results: list[AgentResult] = Field(default_factory=list) + artifacts: list[AgentArtifact] = Field(default_factory=list) + + +# ============================================================================ +# Pipeline Metadata Schema +# ============================================================================ + + +class PipelineMetadata(BaseModel): + """Metadata about pipeline execution including agent timing. + + Attributes: + pipeline_version: Version of the pipeline architecture + stages_executed: List of stages that were executed + total_execution_time_ms: Total pipeline execution time + agent_execution_summary: Summary of agent executions + timings: Detailed timing breakdown by stage + """ + + pipeline_version: str = Field(default="v2_with_agents", description="Pipeline version") + stages_executed: list[str] = Field(default_factory=list, description="Executed stages") + total_execution_time_ms: float = Field(default=0.0, description="Total execution time") + agent_execution_summary: AgentExecutionSummary | None = Field(default=None) + timings: dict[str, float] = Field(default_factory=dict, description="Stage timings") + + +# Update forward references +AgentContext.model_rebuild() +AgentResult.model_rebuild() diff --git a/backend/rag_solution/schemas/search_schema.py b/backend/rag_solution/schemas/search_schema.py index 9761f821..e2c8310f 100644 --- a/backend/rag_solution/schemas/search_schema.py +++ b/backend/rag_solution/schemas/search_schema.py @@ -4,6 +4,7 @@ from pydantic import UUID4, BaseModel, ConfigDict +from rag_solution.schemas.agent_config_schema import AgentArtifact, AgentExecutionSummary from rag_solution.schemas.llm_usage_schema import TokenWarning from rag_solution.schemas.structured_output_schema import StructuredAnswer from vectordbs.data_types import DocumentMetadata, QueryResult @@ -46,6 +47,8 @@ class SearchOutput(BaseModel): rewritten_query: Optional rewritten version of the original query evaluation: Optional evaluation metrics and results structured_answer: Optional structured answer with citations (when requested) + agent_artifacts: Artifacts generated by response agents (PDFs, charts, etc.) 
+ agent_executions: Summary of agent executions at each pipeline stage """ answer: str @@ -58,5 +61,7 @@ class SearchOutput(BaseModel): metadata: dict[str, Any] | None = None # Additional metadata including conversation context token_warning: TokenWarning | None = None # Token usage warning if approaching limits structured_answer: StructuredAnswer | None = None # Structured output with citations when requested + agent_artifacts: list[AgentArtifact] | None = None # Generated artifacts from response agents + agent_executions: AgentExecutionSummary | None = None # Summary of agent executions model_config = ConfigDict(from_attributes=True) diff --git a/backend/rag_solution/services/agent_config_service.py b/backend/rag_solution/services/agent_config_service.py new file mode 100644 index 00000000..ab4f5267 --- /dev/null +++ b/backend/rag_solution/services/agent_config_service.py @@ -0,0 +1,333 @@ +"""Service layer for Agent Configuration management. + +This module provides the AgentConfigService that handles business logic +for managing agent configurations and collection-agent associations. + +Reference: GitHub Issue #697 +""" + +from typing import Any + +from pydantic import UUID4 +from sqlalchemy.orm import Session + +from core.config import Settings +from core.logging_utils import get_logger +from rag_solution.models.agent_config import AgentStage +from rag_solution.repository.agent_config_repository import AgentConfigRepository, CollectionAgentRepository +from rag_solution.schemas.agent_config_schema import ( + AgentConfigInput, + AgentConfigListResponse, + AgentConfigOutput, + AgentConfigUpdate, + CollectionAgentInput, + CollectionAgentListResponse, + CollectionAgentOutput, + CollectionAgentUpdate, +) + +logger = get_logger("services.agent_config") + + +class AgentConfigService: + """Service for managing agent configurations. + + This service provides business logic for creating, updating, and + managing agent configurations that can be attached to collections. + """ + + def __init__(self, db: Session, settings: Settings | None = None) -> None: + """Initialize the service. + + Args: + db: Database session + settings: Application settings + """ + self.db = db + self.settings = settings or Settings() + self._config_repo: AgentConfigRepository | None = None + self._assoc_repo: CollectionAgentRepository | None = None + + @property + def config_repo(self) -> AgentConfigRepository: + """Lazy initialization of config repository.""" + if self._config_repo is None: + self._config_repo = AgentConfigRepository(self.db) + return self._config_repo + + @property + def assoc_repo(self) -> CollectionAgentRepository: + """Lazy initialization of association repository.""" + if self._assoc_repo is None: + self._assoc_repo = CollectionAgentRepository(self.db) + return self._assoc_repo + + # ======================================================================== + # Agent Configuration Methods + # ======================================================================== + + def create_config( + self, + config_input: AgentConfigInput, + owner_user_id: UUID4 | None = None, + is_system: bool = False, + ) -> AgentConfigOutput: + """Create a new agent configuration. 
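A hedged usage sketch for this method; `db_session` is an assumed SQLAlchemy session, and the handler module path and class name are hypothetical:

```python
from rag_solution.schemas.agent_config_schema import AgentConfigInput, AgentStage
from rag_solution.services.agent_config_service import AgentConfigService

service = AgentConfigService(db_session)  # db_session: assumed SQLAlchemy session
reranker = service.create_config(
    AgentConfigInput(
        name="Cross-Encoder Reranker",
        agent_type="reranker",
        stage=AgentStage.POST_SEARCH,
        handler_module="rag_solution.agents.reranker",  # hypothetical module path
        handler_class="CrossEncoderReranker",           # hypothetical class name
        default_config={"top_k": 10},
    )
)
```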
+ + Args: + config_input: Agent config creation data + owner_user_id: UUID of the owning user + is_system: Whether this is a system-level config + + Returns: + Created agent config + """ + logger.info("Creating agent config: %s", config_input.name) + config = self.config_repo.create(config_input, owner_user_id, is_system) + logger.info("Created agent config %s with ID %s", config.name, config.id) + return config + + def get_config(self, config_id: UUID4) -> AgentConfigOutput: + """Get an agent configuration by ID. + + Args: + config_id: UUID of the config + + Returns: + Agent config + """ + return self.config_repo.get_by_id(config_id) + + def update_config( + self, + config_id: UUID4, + config_update: AgentConfigUpdate, + ) -> AgentConfigOutput: + """Update an agent configuration. + + Args: + config_id: UUID of the config + config_update: Update data + + Returns: + Updated agent config + """ + logger.info("Updating agent config %s", config_id) + config = self.config_repo.update(config_id, config_update) + logger.info("Updated agent config %s", config_id) + return config + + def delete_config(self, config_id: UUID4) -> bool: + """Delete an agent configuration. + + Args: + config_id: UUID of the config + + Returns: + True if deleted + """ + logger.info("Deleting agent config %s", config_id) + result = self.config_repo.delete(config_id) + if result: + logger.info("Deleted agent config %s", config_id) + return result + + def list_configs( + self, + skip: int = 0, + limit: int = 100, + owner_user_id: UUID4 | None = None, + stage: str | None = None, + agent_type: str | None = None, + status: str | None = None, + include_system: bool = True, + ) -> AgentConfigListResponse: + """List agent configurations with optional filters. + + Args: + skip: Pagination offset + limit: Maximum results + owner_user_id: Filter by owner + stage: Filter by stage + agent_type: Filter by type + status: Filter by status + include_system: Include system configs + + Returns: + Paginated list of configs + """ + configs, total = self.config_repo.list_configs( + skip=skip, + limit=limit, + owner_user_id=owner_user_id, + stage=stage, + agent_type=agent_type, + status=status, + include_system=include_system, + ) + return AgentConfigListResponse(configs=configs, total=total, skip=skip, limit=limit) + + def list_by_stage( + self, + stage: str, + include_system: bool = True, + ) -> list[AgentConfigOutput]: + """List active configs for a specific stage. + + Args: + stage: Pipeline stage + include_system: Include system configs + + Returns: + List of configs for the stage + """ + try: + agent_stage = AgentStage(stage) + except ValueError as e: + raise ValueError(f"Invalid stage: {stage}") from e + return self.config_repo.list_by_stage(agent_stage, include_system) + + # ======================================================================== + # Collection-Agent Association Methods + # ======================================================================== + + def add_agent_to_collection( + self, + collection_id: UUID4, + association_input: CollectionAgentInput, + ) -> CollectionAgentOutput: + """Add an agent to a collection. 
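Attaching a config to a collection then looks like the sketch below, where `service` is an `AgentConfigService` as above and the UUIDs are placeholders:

```python
from uuid import uuid4

from rag_solution.schemas.agent_config_schema import CollectionAgentInput

assoc = service.add_agent_to_collection(
    uuid4(),  # placeholder: UUID of an existing collection
    CollectionAgentInput(
        agent_config_id=uuid4(),       # placeholder: UUID of an AgentConfig
        priority=10,                   # lower values execute earlier in the stage
        config_override={"top_k": 5},  # layered over the config's default_config
    ),
)
```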
+ + Args: + collection_id: UUID of the collection + association_input: Association data + + Returns: + Created association + """ + logger.info( + "Adding agent %s to collection %s", + association_input.agent_config_id, + collection_id, + ) + assoc = self.assoc_repo.create(collection_id, association_input) + logger.info("Created association %s", assoc.id) + return assoc + + def get_association(self, association_id: UUID4) -> CollectionAgentOutput: + """Get a collection-agent association. + + Args: + association_id: UUID of the association + + Returns: + Association data + """ + return self.assoc_repo.get_by_id(association_id) + + def update_association( + self, + association_id: UUID4, + update_data: CollectionAgentUpdate, + ) -> CollectionAgentOutput: + """Update a collection-agent association. + + Args: + association_id: UUID of the association + update_data: Update data + + Returns: + Updated association + """ + logger.info("Updating association %s", association_id) + assoc = self.assoc_repo.update(association_id, update_data) + logger.info("Updated association %s", association_id) + return assoc + + def remove_agent_from_collection(self, association_id: UUID4) -> bool: + """Remove an agent from a collection. + + Args: + association_id: UUID of the association + + Returns: + True if removed + """ + logger.info("Removing association %s", association_id) + result = self.assoc_repo.delete(association_id) + if result: + logger.info("Removed association %s", association_id) + return result + + def list_collection_agents( + self, + collection_id: UUID4, + stage: str | None = None, + enabled_only: bool = False, + ) -> CollectionAgentListResponse: + """List agents associated with a collection. + + Args: + collection_id: UUID of the collection + stage: Filter by stage + enabled_only: Only enabled associations + + Returns: + List of associations + """ + associations = self.assoc_repo.list_by_collection( + collection_id=collection_id, + stage=stage, + enabled_only=enabled_only, + ) + return CollectionAgentListResponse( + associations=associations, + total=len(associations), + ) + + def batch_update_priorities( + self, + collection_id: UUID4, + priorities: dict[UUID4, int], + ) -> list[CollectionAgentOutput]: + """Batch update priorities for collection agents. + + Args: + collection_id: UUID of the collection + priorities: Mapping of association ID to priority + + Returns: + List of updated associations + """ + logger.info("Batch updating %d priorities for collection %s", len(priorities), collection_id) + return self.assoc_repo.batch_update_priorities(collection_id, priorities) + + def get_collection_agent_summary(self, collection_id: UUID4) -> dict[str, Any]: + """Get a summary of agents for a collection. 
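Reordering several agents for a collection then reduces to a single mapping (placeholder IDs; lower priority runs first):

```python
from uuid import uuid4

service.batch_update_priorities(
    collection_id=uuid4(),                  # placeholder collection UUID
    priorities={uuid4(): 10, uuid4(): 20},  # association ID -> new priority
)
```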
+ + Args: + collection_id: UUID of the collection + + Returns: + Summary with counts per stage + """ + associations = self.assoc_repo.list_by_collection(collection_id, enabled_only=True) + + summary = { + "pre_search": 0, + "post_search": 0, + "response": 0, + "total": 0, + "enabled": 0, + } + + for assoc in associations: + if assoc.agent_config: + stage = assoc.agent_config.stage + if stage in summary: + summary[stage] += 1 + summary["total"] += 1 + if assoc.enabled: + summary["enabled"] += 1 + + return summary diff --git a/backend/rag_solution/services/agent_executor_service.py b/backend/rag_solution/services/agent_executor_service.py new file mode 100644 index 00000000..609ce838 --- /dev/null +++ b/backend/rag_solution/services/agent_executor_service.py @@ -0,0 +1,726 @@ +"""Agent executor service for search pipeline hooks. + +This module provides the AgentExecutorService that orchestrates agent execution +at the three stages of the search pipeline with circuit breaker pattern for +failure isolation. + +The 3-stage pipeline: +- Stage 1: Pre-Search Agents (sequential by priority) - Query enhancement +- Stage 2: Post-Search Agents (sequential by priority) - Result enhancement +- Stage 3: Response Agents (parallel execution) - Artifact generation + +Reference: GitHub Issue #697 +""" + +from __future__ import annotations + +import asyncio +import importlib +import time +from abc import ABC, abstractmethod +from collections import defaultdict +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from pydantic import UUID4 +from sqlalchemy.orm import Session + +from core.config import Settings +from core.logging_utils import get_logger +from rag_solution.models.agent_config import AgentConfig, AgentStage, CollectionAgent +from rag_solution.schemas.agent_config_schema import ( + AgentArtifact, + AgentContext, + AgentExecutionStatus, + AgentExecutionSummary, + AgentResult, + AgentStage as SchemaAgentStage, +) + +if TYPE_CHECKING: + from vectordbs.data_types import QueryResult + +logger = get_logger("services.agent_executor") + + +# ============================================================================ +# Circuit Breaker Implementation +# ============================================================================ + + +@dataclass +class CircuitBreakerState: + """State tracking for a circuit breaker. + + Attributes: + failure_count: Number of consecutive failures + last_failure_time: Timestamp of last failure + state: Current state (closed, open, half_open) + success_count_in_half_open: Successful calls in half-open state + """ + + failure_count: int = 0 + last_failure_time: float = 0.0 + state: str = "closed" # closed, open, half_open + success_count_in_half_open: int = 0 + + +class CircuitBreaker: + """Circuit breaker for isolating agent failures. + + Implements the circuit breaker pattern: + - CLOSED: Normal operation, failures increment counter + - OPEN: Requests fail fast, no execution + - HALF_OPEN: Limited requests allowed to test recovery + + Attributes: + failure_threshold: Number of failures before opening + recovery_timeout: Seconds before attempting recovery + half_open_max_calls: Max calls allowed in half-open state + """ + + def __init__( + self, + failure_threshold: int = 3, + recovery_timeout: float = 60.0, + half_open_max_calls: int = 3, + ) -> None: + """Initialize the circuit breaker. 
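A short walkthrough of the state machine implemented by the methods that follow; after three consecutive failures the circuit opens and calls fail fast until `recovery_timeout` elapses:

```python
from rag_solution.services.agent_executor_service import CircuitBreaker

breaker = CircuitBreaker(failure_threshold=3, recovery_timeout=60.0)
for _ in range(3):
    breaker.record_failure("agent-123")

assert breaker.get_state("agent-123") == "open"
assert breaker.is_open("agent-123")
# Once recovery_timeout has elapsed, the next is_open() check flips the
# circuit to half-open, and half_open_max_calls successes close it again.
```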
+ + Args: + failure_threshold: Failures before circuit opens + recovery_timeout: Seconds before half-open transition + half_open_max_calls: Calls to allow in half-open state + """ + self.failure_threshold = failure_threshold + self.recovery_timeout = recovery_timeout + self.half_open_max_calls = half_open_max_calls + self._states: dict[str, CircuitBreakerState] = defaultdict(CircuitBreakerState) + + def _get_state(self, circuit_id: str) -> CircuitBreakerState: + """Get or create state for a circuit.""" + return self._states[circuit_id] + + def is_open(self, circuit_id: str) -> bool: + """Check if circuit is open (blocking requests). + + Args: + circuit_id: Identifier for the circuit (usually agent_config_id) + + Returns: + True if circuit is open and blocking requests + """ + state = self._get_state(circuit_id) + + if state.state == "closed": + return False + + if state.state == "open": + # Check if recovery timeout has passed + if time.time() - state.last_failure_time >= self.recovery_timeout: + state.state = "half_open" + state.success_count_in_half_open = 0 + logger.info("Circuit %s transitioning to half-open", circuit_id) + return False + return True + + # half_open state - allow limited requests + return False + + def record_success(self, circuit_id: str) -> None: + """Record a successful execution. + + Args: + circuit_id: Identifier for the circuit + """ + state = self._get_state(circuit_id) + + if state.state == "half_open": + state.success_count_in_half_open += 1 + if state.success_count_in_half_open >= self.half_open_max_calls: + # Reset to closed state + state.state = "closed" + state.failure_count = 0 + state.success_count_in_half_open = 0 + logger.info("Circuit %s closed after successful recovery", circuit_id) + elif state.state == "closed": + # Reset failure count on success + state.failure_count = 0 + + def record_failure(self, circuit_id: str) -> None: + """Record a failed execution. + + Args: + circuit_id: Identifier for the circuit + """ + state = self._get_state(circuit_id) + state.failure_count += 1 + state.last_failure_time = time.time() + + if state.state == "half_open": + # Any failure in half-open reopens the circuit + state.state = "open" + logger.warning("Circuit %s reopened after half-open failure", circuit_id) + elif state.failure_count >= self.failure_threshold: + state.state = "open" + logger.warning("Circuit %s opened after %d failures", circuit_id, state.failure_count) + + def get_state(self, circuit_id: str) -> str: + """Get the current state of a circuit. + + Args: + circuit_id: Identifier for the circuit + + Returns: + Current state: "closed", "open", or "half_open" + """ + return self._get_state(circuit_id).state + + +# ============================================================================ +# Base Agent Handler +# ============================================================================ + + +class BaseAgentHandler(ABC): + """Abstract base class for agent handlers. + + All agent implementations must inherit from this class and implement + the execute method. + """ + + @abstractmethod + async def execute(self, context: AgentContext) -> AgentResult: + """Execute the agent with the given context. 
+ + Args: + context: Execution context with search data + + Returns: + Result of the agent execution + """ + + @property + @abstractmethod + def agent_type(self) -> str: + """Get the agent type identifier.""" + + +# ============================================================================ +# Agent Executor Service +# ============================================================================ + + +class AgentExecutorService: + """Service for executing agents at pipeline stages. + + This service manages the execution of agents configured for a collection + at the three pipeline stages: pre-search, post-search, and response. + + Key features: + - Sequential execution for pre-search and post-search (by priority) + - Parallel execution for response agents + - Circuit breaker for failure isolation + - Timeout handling + - Retry logic + + Attributes: + db: Database session + settings: Application settings + circuit_breaker: Circuit breaker instance + """ + + def __init__(self, db: Session, settings: Settings | None = None) -> None: + """Initialize the agent executor service. + + Args: + db: Database session + settings: Optional application settings + """ + self.db = db + self.settings = settings or Settings() + self.circuit_breaker = CircuitBreaker( + failure_threshold=3, + recovery_timeout=60.0, + half_open_max_calls=3, + ) + self._handler_cache: dict[str, type[BaseAgentHandler]] = {} + + def _get_collection_agents( + self, + collection_id: UUID4, + stage: AgentStage, + ) -> list[CollectionAgent]: + """Get enabled agents for a collection at a specific stage. + + Args: + collection_id: Collection UUID + stage: Pipeline stage to filter by + + Returns: + List of CollectionAgent associations ordered by priority + """ + return ( + self.db.query(CollectionAgent) + .join(AgentConfig) + .filter( + CollectionAgent.collection_id == collection_id, + CollectionAgent.enabled.is_(True), + AgentConfig.stage == stage.value, + AgentConfig.status == "active", + ) + .order_by(CollectionAgent.priority) + .all() + ) + + def _load_handler_class(self, agent_config: AgentConfig) -> type[BaseAgentHandler] | None: + """Dynamically load a handler class. + + Args: + agent_config: Agent configuration with handler info + + Returns: + Handler class or None if loading fails + """ + cache_key = f"{agent_config.handler_module}:{agent_config.handler_class}" + + if cache_key in self._handler_cache: + return self._handler_cache[cache_key] + + try: + module = importlib.import_module(agent_config.handler_module) + handler_class = getattr(module, agent_config.handler_class) + + if not issubclass(handler_class, BaseAgentHandler): + logger.error( + "Handler %s is not a subclass of BaseAgentHandler", + cache_key, + ) + return None + + self._handler_cache[cache_key] = handler_class + return handler_class + + except (ImportError, AttributeError) as e: + logger.error("Failed to load handler %s: %s", cache_key, e) + return None + + async def _execute_single_agent( + self, + collection_agent: CollectionAgent, + context: AgentContext, + ) -> AgentResult: + """Execute a single agent with timeout and error handling. 
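A toy handler that the dynamic loader above could import, assuming it lived at a module path registered in an `AgentConfig`. How a handler learns its own `agent_config_id` is not pinned down by the interface, so the placeholder below is an assumption:

```python
from uuid import uuid4

from rag_solution.schemas.agent_config_schema import AgentContext, AgentExecutionStatus, AgentResult


class UppercaseQueryAgent(BaseAgentHandler):
    """Toy pre-search handler that upper-cases the incoming query."""

    @property
    def agent_type(self) -> str:
        return "uppercase_query"

    async def execute(self, context: AgentContext) -> AgentResult:
        return AgentResult(
            agent_config_id=uuid4(),  # placeholder (see note above)
            agent_name="Uppercase Query",
            agent_type=self.agent_type,
            stage=context.stage.value,
            status=AgentExecutionStatus.SUCCESS,
            execution_time_ms=0.0,  # overwritten by the executor on success
            modified_query=context.query.upper(),
        )
```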
+ + Args: + collection_agent: Collection-agent association + context: Execution context + + Returns: + Agent result + """ + agent_config = collection_agent.agent_config + circuit_id = str(agent_config.id) + start_time = time.time() + + # Check circuit breaker + if self.circuit_breaker.is_open(circuit_id): + logger.warning("Circuit open for agent %s, skipping", agent_config.name) + return AgentResult( + agent_config_id=agent_config.id, + agent_name=agent_config.name, + agent_type=agent_config.agent_type, + stage=agent_config.stage, + status=AgentExecutionStatus.CIRCUIT_OPEN, + execution_time_ms=0.0, + error_message="Circuit breaker is open", + ) + + # Merge configuration + merged_config = collection_agent.get_merged_config() + context.config = merged_config + + # Load handler + handler_class = self._load_handler_class(agent_config) + if handler_class is None: + return AgentResult( + agent_config_id=agent_config.id, + agent_name=agent_config.name, + agent_type=agent_config.agent_type, + stage=agent_config.stage, + status=AgentExecutionStatus.FAILED, + execution_time_ms=(time.time() - start_time) * 1000, + error_message="Failed to load handler", + ) + + # Execute with retry and timeout + retries = 0 + max_retries = agent_config.max_retries + last_error = None + + while retries <= max_retries: + try: + handler = handler_class() + result = await asyncio.wait_for( + handler.execute(context), + timeout=agent_config.timeout_seconds, + ) + + # Record success + self.circuit_breaker.record_success(circuit_id) + + execution_time_ms = (time.time() - start_time) * 1000 + result.execution_time_ms = execution_time_ms + + logger.info( + "Agent %s executed successfully in %.2fms", + agent_config.name, + execution_time_ms, + ) + return result + + except asyncio.TimeoutError: + last_error = f"Timeout after {agent_config.timeout_seconds}s" + retries += 1 + logger.warning( + "Agent %s timed out (attempt %d/%d)", + agent_config.name, + retries, + max_retries + 1, + ) + + except Exception as e: # pylint: disable=broad-exception-caught + # Justification: We need to catch all exceptions to prevent pipeline failure + last_error = str(e) + retries += 1 + logger.warning( + "Agent %s failed (attempt %d/%d): %s", + agent_config.name, + retries, + max_retries + 1, + e, + ) + + # All retries exhausted + self.circuit_breaker.record_failure(circuit_id) + execution_time_ms = (time.time() - start_time) * 1000 + + return AgentResult( + agent_config_id=agent_config.id, + agent_name=agent_config.name, + agent_type=agent_config.agent_type, + stage=agent_config.stage, + status=AgentExecutionStatus.TIMEOUT if "Timeout" in str(last_error) else AgentExecutionStatus.FAILED, + execution_time_ms=execution_time_ms, + error_message=last_error, + ) + + async def execute_pre_search_agents( + self, + collection_id: UUID4, + context: AgentContext, + ) -> tuple[str, list[AgentResult]]: + """Execute pre-search agents sequentially by priority. + + Pre-search agents can modify the query before retrieval. + Each agent receives the modified query from the previous agent. 
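The sequential contract described here is a simple fold over the priority-ordered agents; distilled to its shape (toy rewriters stand in for handlers):

```python
query = "  what does RAG stand for  "
for rewrite in (str.strip, str.lower):  # ordered like priority-sorted agents
    query = rewrite(query)
assert query == "what does rag stand for"
```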
+ + Args: + collection_id: Collection UUID + context: Execution context + + Returns: + Tuple of (modified_query, list of results) + """ + logger.info("Executing pre-search agents for collection %s", collection_id) + + agents = self._get_collection_agents(collection_id, AgentStage.PRE_SEARCH) + results: list[AgentResult] = [] + current_query = context.query + + for collection_agent in agents: + # Update context with current query + context.query = current_query + context.previous_results = results.copy() + + result = await self._execute_single_agent(collection_agent, context) + results.append(result) + + # If successful and query was modified, use it + if result.status == AgentExecutionStatus.SUCCESS and result.modified_query: + current_query = result.modified_query + logger.info( + "Query modified by %s: '%s' -> '%s'", + result.agent_name, + context.search_input.get("question", ""), + current_query, + ) + + logger.info("Pre-search stage completed: %d agents executed", len(results)) + return current_query, results + + async def execute_post_search_agents( + self, + collection_id: UUID4, + context: AgentContext, + query_results: list[dict[str, Any]], + ) -> tuple[list[dict[str, Any]], list[AgentResult]]: + """Execute post-search agents sequentially by priority. + + Post-search agents can modify, rerank, or filter the retrieved results. + Each agent receives the modified results from the previous agent. + + Args: + collection_id: Collection UUID + context: Execution context + query_results: Retrieved documents + + Returns: + Tuple of (modified_results, list of results) + """ + logger.info("Executing post-search agents for collection %s", collection_id) + + agents = self._get_collection_agents(collection_id, AgentStage.POST_SEARCH) + results: list[AgentResult] = [] + current_results = query_results + + for collection_agent in agents: + # Update context with current results + context.query_results = current_results + context.previous_results = results.copy() + + result = await self._execute_single_agent(collection_agent, context) + results.append(result) + + # If successful and results were modified, use them + if result.status == AgentExecutionStatus.SUCCESS and result.modified_results: + current_results = result.modified_results + logger.info( + "Results modified by %s: %d -> %d items", + result.agent_name, + len(query_results), + len(current_results), + ) + + logger.info("Post-search stage completed: %d agents executed", len(results)) + return current_results, results + + async def execute_response_agents( + self, + collection_id: UUID4, + context: AgentContext, + ) -> tuple[list[AgentArtifact], list[AgentResult]]: + """Execute response agents in parallel. + + Response agents generate artifacts (PDFs, PowerPoints, charts, etc.) + and run in parallel since they don't depend on each other. 
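The parallel fan-out used in the body that follows relies on `asyncio.gather(..., return_exceptions=True)`, which keeps one failing agent from cancelling its siblings; a self-contained demonstration:

```python
import asyncio


async def make_pdf() -> str:
    return "report.pdf"


async def make_chart() -> str:
    raise RuntimeError("renderer crashed")


async def main() -> None:
    results = await asyncio.gather(make_pdf(), make_chart(), return_exceptions=True)
    assert results[0] == "report.pdf"
    assert isinstance(results[1], RuntimeError)  # sibling agents are unaffected


asyncio.run(main())
```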
+ + Args: + collection_id: Collection UUID + context: Execution context + + Returns: + Tuple of (artifacts, list of results) + """ + logger.info("Executing response agents for collection %s", collection_id) + + agents = self._get_collection_agents(collection_id, AgentStage.RESPONSE) + + if not agents: + return [], [] + + # Execute all response agents in parallel + tasks = [self._execute_single_agent(agent, context) for agent in agents] + results = await asyncio.gather(*tasks, return_exceptions=True) + + agent_results: list[AgentResult] = [] + all_artifacts: list[AgentArtifact] = [] + + for i, result in enumerate(results): + if isinstance(result, Exception): + # Handle exception from gather + agent_config = agents[i].agent_config + agent_results.append( + AgentResult( + agent_config_id=agent_config.id, + agent_name=agent_config.name, + agent_type=agent_config.agent_type, + stage=agent_config.stage, + status=AgentExecutionStatus.FAILED, + execution_time_ms=0.0, + error_message=str(result), + ) + ) + else: + agent_results.append(result) + if result.artifacts: + all_artifacts.extend(result.artifacts) + + logger.info( + "Response stage completed: %d agents executed, %d artifacts generated", + len(results), + len(all_artifacts), + ) + return all_artifacts, agent_results + + async def execute_all_stages( + self, + collection_id: UUID4, + search_input: dict[str, Any], + user_id: UUID4, + initial_query: str, + query_results: list[dict[str, Any]], + ) -> AgentExecutionSummary: + """Execute all agent stages for a search request. + + This is the main entry point for agent execution during search. + It orchestrates the three stages and collects results. + + Args: + collection_id: Collection UUID + search_input: Original search request + user_id: User UUID + initial_query: Initial search query + query_results: Retrieved documents (as dicts) + + Returns: + Summary of all agent executions + """ + logger.info("Starting agent execution for collection %s", collection_id) + start_time = time.time() + + summary = AgentExecutionSummary() + + # Create initial context + context = AgentContext( + search_input=search_input, + collection_id=collection_id, + user_id=user_id, + stage=SchemaAgentStage.PRE_SEARCH, + query=initial_query, + ) + + # Stage 1: Pre-search agents + try: + modified_query, pre_results = await self.execute_pre_search_agents( + collection_id, context.model_copy(deep=True) + ) + summary.pre_search_results = pre_results + context.query = modified_query + + # Update counts + for result in pre_results: + summary.total_agents += 1 + if result.status == AgentExecutionStatus.SUCCESS: + summary.successful += 1 + elif result.status in (AgentExecutionStatus.FAILED, AgentExecutionStatus.TIMEOUT): + summary.failed += 1 + else: + summary.skipped += 1 + + except Exception as e: # pylint: disable=broad-exception-caught + # Justification: Log but don't fail search + logger.exception("Pre-search stage failed: %s", e) + + # Stage 2: Post-search agents + try: + context.stage = SchemaAgentStage.POST_SEARCH + context.query_results = query_results + + modified_results, post_results = await self.execute_post_search_agents( + collection_id, context.model_copy(deep=True), query_results + ) + summary.post_search_results = post_results + + # Update counts + for result in post_results: + summary.total_agents += 1 + if result.status == AgentExecutionStatus.SUCCESS: + summary.successful += 1 + elif result.status in (AgentExecutionStatus.FAILED, AgentExecutionStatus.TIMEOUT): + summary.failed += 1 + else: + summary.skipped += 1 
+ + except Exception as e: # pylint: disable=broad-exception-caught + # Justification: Log but don't fail search + logger.exception("Post-search stage failed: %s", e) + + # Stage 3: Response agents + try: + context.stage = SchemaAgentStage.RESPONSE + + artifacts, response_results = await self.execute_response_agents( + collection_id, context.model_copy(deep=True) + ) + summary.response_results = response_results + summary.artifacts = artifacts + + # Update counts + for result in response_results: + summary.total_agents += 1 + if result.status == AgentExecutionStatus.SUCCESS: + summary.successful += 1 + elif result.status in (AgentExecutionStatus.FAILED, AgentExecutionStatus.TIMEOUT): + summary.failed += 1 + else: + summary.skipped += 1 + + except Exception as e: # pylint: disable=broad-exception-caught + # Justification: Log but don't fail search + logger.exception("Response stage failed: %s", e) + + summary.total_execution_time_ms = (time.time() - start_time) * 1000 + + logger.info( + "Agent execution completed: %d total, %d successful, %d failed, %d skipped in %.2fms", + summary.total_agents, + summary.successful, + summary.failed, + summary.skipped, + summary.total_execution_time_ms, + ) + + return summary + + def has_agents_for_collection(self, collection_id: UUID4) -> bool: + """Check if a collection has any enabled agents. + + Args: + collection_id: Collection UUID + + Returns: + True if collection has enabled agents + """ + count = ( + self.db.query(CollectionAgent) + .filter( + CollectionAgent.collection_id == collection_id, + CollectionAgent.enabled.is_(True), + ) + .count() + ) + return count > 0 + + def get_collection_agent_summary(self, collection_id: UUID4) -> dict[str, int]: + """Get a summary of agents for a collection by stage. + + Args: + collection_id: Collection UUID + + Returns: + Dict with counts per stage + """ + summary = { + "pre_search": 0, + "post_search": 0, + "response": 0, + "total": 0, + } + + for stage in AgentStage: + count = len(self._get_collection_agents(collection_id, stage)) + summary[stage.value] = count + summary["total"] += count + + return summary diff --git a/backend/rag_solution/services/pipeline/search_context.py b/backend/rag_solution/services/pipeline/search_context.py index a3d7f71a..78951671 100644 --- a/backend/rag_solution/services/pipeline/search_context.py +++ b/backend/rag_solution/services/pipeline/search_context.py @@ -11,6 +11,7 @@ from pydantic import UUID4 +from rag_solution.schemas.agent_config_schema import AgentArtifact, AgentExecutionSummary from rag_solution.schemas.chain_of_thought_schema import ChainOfThoughtOutput from rag_solution.schemas.llm_usage_schema import TokenWarning from rag_solution.schemas.search_schema import SearchInput @@ -49,6 +50,10 @@ class SearchContext: # pylint: disable=too-many-instance-attributes token_warning: Token usage warnings structured_answer: Structured answer with citations (when requested) + # Agent Execution (Issue #697) + agent_artifacts: Artifacts generated by response agents + agent_execution_summary: Summary of agent executions at each stage + # Execution Metadata start_time: When search started execution_time: Total search execution time @@ -77,6 +82,10 @@ class SearchContext: # pylint: disable=too-many-instance-attributes token_warning: TokenWarning | None = None structured_answer: StructuredAnswer | None = None + # Agent Execution (Issue #697) + agent_artifacts: list[AgentArtifact] = field(default_factory=list) + agent_execution_summary: AgentExecutionSummary | None = None + # 
Execution Metadata start_time: float = field(default_factory=time.time) execution_time: float = 0.0 diff --git a/backend/rag_solution/services/pipeline/stages/__init__.py b/backend/rag_solution/services/pipeline/stages/__init__.py index b88da082..4a10ee11 100644 --- a/backend/rag_solution/services/pipeline/stages/__init__.py +++ b/backend/rag_solution/services/pipeline/stages/__init__.py @@ -4,6 +4,7 @@ This module contains concrete implementations of pipeline stages. """ +from .agent_execution_stage import PostSearchAgentStage, PreSearchAgentStage, ResponseAgentStage from .generation_stage import GenerationStage from .pipeline_resolution_stage import PipelineResolutionStage from .query_enhancement_stage import QueryEnhancementStage @@ -14,8 +15,11 @@ __all__ = [ "GenerationStage", "PipelineResolutionStage", + "PostSearchAgentStage", + "PreSearchAgentStage", "QueryEnhancementStage", "ReasoningStage", "RerankingStage", + "ResponseAgentStage", "RetrievalStage", ] diff --git a/backend/rag_solution/services/pipeline/stages/agent_execution_stage.py b/backend/rag_solution/services/pipeline/stages/agent_execution_stage.py new file mode 100644 index 00000000..e2e50a93 --- /dev/null +++ b/backend/rag_solution/services/pipeline/stages/agent_execution_stage.py @@ -0,0 +1,416 @@ +""" +Agent execution stage for search pipeline. + +This stage orchestrates the execution of configured agents at the appropriate +pipeline points. It integrates with the AgentExecutorService to run agents +at three stages: +- Pre-Search: Query enhancement agents (run before retrieval) +- Post-Search: Result enhancement agents (run after retrieval, before generation) +- Response: Artifact generation agents (run after generation) + +Reference: GitHub Issue #697 +""" + +from typing import Any + +from sqlalchemy.orm import Session + +from core.config import Settings +from core.logging_utils import get_logger +from rag_solution.schemas.agent_config_schema import AgentStage +from rag_solution.services.agent_executor_service import AgentExecutorService +from rag_solution.services.pipeline.base_stage import BaseStage, StageResult +from rag_solution.services.pipeline.search_context import SearchContext + +logger = get_logger("services.pipeline.stages.agent_execution") + + +class PreSearchAgentStage(BaseStage): # pylint: disable=too-few-public-methods + """ + Execute pre-search agents for query enhancement. + + This stage runs agents that can modify the query before retrieval: + - Query expanders + - Language detectors/translators + - Acronym resolvers + - Intent classifiers + + Agents are executed sequentially by priority. + """ + + def __init__(self, db: Session, settings: Settings | None = None) -> None: + """ + Initialize the pre-search agent stage. + + Args: + db: Database session + settings: Application settings + """ + super().__init__("PreSearchAgents") + self.db = db + self.settings = settings + self._executor: AgentExecutorService | None = None + + @property + def executor(self) -> AgentExecutorService: + """Lazy initialization of agent executor service.""" + if self._executor is None: + self._executor = AgentExecutorService(self.db, self.settings) + return self._executor + + async def execute(self, context: SearchContext) -> StageResult: + """ + Execute pre-search agents. 
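For orientation, the three agent stages all follow the same `BaseStage` contract used elsewhere in the pipeline; a minimal hypothetical stage, built only from pieces visible in this diff:

```python
from core.logging_utils import get_logger
from rag_solution.services.pipeline.base_stage import BaseStage, StageResult
from rag_solution.services.pipeline.search_context import SearchContext

logger = get_logger("services.pipeline.stages.example")


class QueryLoggingStage(BaseStage):  # pylint: disable=too-few-public-methods
    """Hypothetical no-op stage illustrating the contract."""

    def __init__(self) -> None:
        super().__init__("QueryLogging")

    async def execute(self, context: SearchContext) -> StageResult:
        logger.info("Question at this stage: %s", context.search_input.question)
        return StageResult(success=True, context=context)
```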
+ + Args: + context: Current search context + + Returns: + StageResult with potentially modified query + """ + self._log_stage_start(context) + + try: + # Check if collection has any pre-search agents + if not self.executor.has_agents_for_collection(context.collection_id): + logger.debug("No agents configured for collection %s", context.collection_id) + return StageResult(success=True, context=context) + + # Build agent context + agent_context_dict = { + "search_input": context.search_input.model_dump(), + "collection_id": context.collection_id, + "user_id": context.user_id, + "stage": AgentStage.PRE_SEARCH, + "query": context.search_input.question, + "query_results": [], + "previous_results": [], + "config": {}, + "metadata": context.metadata, + } + + from rag_solution.schemas.agent_config_schema import AgentContext as SchemaAgentContext + + agent_context = SchemaAgentContext(**agent_context_dict) + + # Execute pre-search agents + modified_query, results = await self.executor.execute_pre_search_agents( + context.collection_id, + agent_context, + ) + + # Update context with modified query if changed + if modified_query and modified_query != context.search_input.question: + context.rewritten_query = modified_query + logger.info("Query modified by pre-search agents: '%s'", modified_query[:100]) + + # Initialize or update execution summary + if context.agent_execution_summary is None: + from rag_solution.schemas.agent_config_schema import AgentExecutionSummary + + context.agent_execution_summary = AgentExecutionSummary() + + context.agent_execution_summary.pre_search_results = results + + # Update counts + for result in results: + context.agent_execution_summary.total_agents += 1 + if result.status.value == "success": + context.agent_execution_summary.successful += 1 + elif result.status.value in ("failed", "timeout"): + context.agent_execution_summary.failed += 1 + else: + context.agent_execution_summary.skipped += 1 + + context.add_metadata( + "pre_search_agents", + { + "agents_executed": len(results), + "query_modified": modified_query != context.search_input.question, + }, + ) + + result = StageResult(success=True, context=context) + self._log_stage_complete(result) + return result + + except Exception as e: # pylint: disable=broad-exception-caught + # Justification: Log but don't fail pipeline for agent errors + logger.exception("Pre-search agent stage failed: %s", e) + # Return success to allow pipeline to continue + return StageResult(success=True, context=context, error=str(e)) + + +class PostSearchAgentStage(BaseStage): # pylint: disable=too-few-public-methods + """ + Execute post-search agents for result enhancement. + + This stage runs agents that can modify the retrieved results: + - Re-rankers + - Deduplicators + - External enrichers + - PII redactors + + Agents are executed sequentially by priority. + """ + + def __init__(self, db: Session, settings: Settings | None = None) -> None: + """ + Initialize the post-search agent stage. + + Args: + db: Database session + settings: Application settings + """ + super().__init__("PostSearchAgents") + self.db = db + self.settings = settings + self._executor: AgentExecutorService | None = None + + @property + def executor(self) -> AgentExecutorService: + """Lazy initialization of agent executor service.""" + if self._executor is None: + self._executor = AgentExecutorService(self.db, self.settings) + return self._executor + + async def execute(self, context: SearchContext) -> StageResult: + """ + Execute post-search agents. 
+ + Args: + context: Current search context with query_results + + Returns: + StageResult with potentially modified results + """ + self._log_stage_start(context) + + try: + # Check if collection has any post-search agents + if not self.executor.has_agents_for_collection(context.collection_id): + logger.debug("No agents configured for collection %s", context.collection_id) + return StageResult(success=True, context=context) + + # Convert QueryResults to dicts for agent processing + query_results_dicts = [ + { + "chunk_id": r.chunk.id if r.chunk else None, + "document_id": r.document_id, + "score": r.score, + "text": r.chunk.text if r.chunk else "", + "metadata": r.chunk.metadata if r.chunk else {}, + } + for r in context.query_results + ] + + # Build agent context + agent_context_dict = { + "search_input": context.search_input.model_dump(), + "collection_id": context.collection_id, + "user_id": context.user_id, + "stage": AgentStage.POST_SEARCH, + "query": context.rewritten_query or context.search_input.question, + "query_results": query_results_dicts, + "previous_results": [], + "config": {}, + "metadata": context.metadata, + } + + from rag_solution.schemas.agent_config_schema import AgentContext as SchemaAgentContext + + agent_context = SchemaAgentContext(**agent_context_dict) + + # Execute post-search agents + modified_results, results = await self.executor.execute_post_search_agents( + context.collection_id, + agent_context, + query_results_dicts, + ) + + # Note: Modified results would need conversion back to QueryResult objects + # For now, agents can modify scores/order but not the structure + # A future enhancement could allow full result modification + + # Initialize or update execution summary + if context.agent_execution_summary is None: + from rag_solution.schemas.agent_config_schema import AgentExecutionSummary + + context.agent_execution_summary = AgentExecutionSummary() + + context.agent_execution_summary.post_search_results = results + + # Update counts + for result in results: + context.agent_execution_summary.total_agents += 1 + if result.status.value == "success": + context.agent_execution_summary.successful += 1 + elif result.status.value in ("failed", "timeout"): + context.agent_execution_summary.failed += 1 + else: + context.agent_execution_summary.skipped += 1 + + context.add_metadata( + "post_search_agents", + { + "agents_executed": len(results), + "results_modified": len(modified_results) != len(query_results_dicts), + }, + ) + + result = StageResult(success=True, context=context) + self._log_stage_complete(result) + return result + + except Exception as e: # pylint: disable=broad-exception-caught + # Justification: Log but don't fail pipeline for agent errors + logger.exception("Post-search agent stage failed: %s", e) + return StageResult(success=True, context=context, error=str(e)) + + +class ResponseAgentStage(BaseStage): # pylint: disable=too-few-public-methods + """ + Execute response agents for artifact generation. + + This stage runs agents that generate artifacts from the search results: + - PowerPoint generators + - PDF report generators + - Chart generators + - Audio summary generators + + Agents are executed in parallel since they don't depend on each other. + """ + + def __init__(self, db: Session, settings: Settings | None = None) -> None: + """ + Initialize the response agent stage. 
+ + Args: + db: Database session + settings: Application settings + """ + super().__init__("ResponseAgents") + self.db = db + self.settings = settings + self._executor: AgentExecutorService | None = None + + @property + def executor(self) -> AgentExecutorService: + """Lazy initialization of agent executor service.""" + if self._executor is None: + self._executor = AgentExecutorService(self.db, self.settings) + return self._executor + + async def execute(self, context: SearchContext) -> StageResult: + """ + Execute response agents for artifact generation. + + Args: + context: Current search context with generated answer + + Returns: + StageResult with generated artifacts + """ + self._log_stage_start(context) + + try: + # Check if collection has any response agents + if not self.executor.has_agents_for_collection(context.collection_id): + logger.debug("No agents configured for collection %s", context.collection_id) + return StageResult(success=True, context=context) + + # Convert QueryResults to dicts for agent processing + query_results_dicts = [ + { + "chunk_id": r.chunk.id if r.chunk else None, + "document_id": r.document_id, + "score": r.score, + "text": r.chunk.text if r.chunk else "", + "metadata": r.chunk.metadata if r.chunk else {}, + } + for r in context.query_results + ] + + # Build agent context with full search data + agent_context_dict = { + "search_input": context.search_input.model_dump(), + "collection_id": context.collection_id, + "user_id": context.user_id, + "stage": AgentStage.RESPONSE, + "query": context.rewritten_query or context.search_input.question, + "query_results": query_results_dicts, + "previous_results": [], + "config": {}, + "metadata": { + **context.metadata, + "generated_answer": context.generated_answer, + "document_count": len(context.document_metadata), + }, + } + + from rag_solution.schemas.agent_config_schema import AgentContext as SchemaAgentContext + + agent_context = SchemaAgentContext(**agent_context_dict) + + # Execute response agents (in parallel) + artifacts, results = await self.executor.execute_response_agents( + context.collection_id, + agent_context, + ) + + # Update context with artifacts + context.agent_artifacts = artifacts + + # Initialize or update execution summary + if context.agent_execution_summary is None: + from rag_solution.schemas.agent_config_schema import AgentExecutionSummary + + context.agent_execution_summary = AgentExecutionSummary() + + context.agent_execution_summary.response_results = results + context.agent_execution_summary.artifacts = artifacts + + # Update counts + for result in results: + context.agent_execution_summary.total_agents += 1 + if result.status.value == "success": + context.agent_execution_summary.successful += 1 + elif result.status.value in ("failed", "timeout"): + context.agent_execution_summary.failed += 1 + else: + context.agent_execution_summary.skipped += 1 + + # Calculate total execution time + if context.agent_execution_summary: + total_time = sum( + r.execution_time_ms + for r in ( + context.agent_execution_summary.pre_search_results + + context.agent_execution_summary.post_search_results + + context.agent_execution_summary.response_results + ) + ) + context.agent_execution_summary.total_execution_time_ms = total_time + + context.add_metadata( + "response_agents", + { + "agents_executed": len(results), + "artifacts_generated": len(artifacts), + }, + ) + + logger.info( + "Response agents generated %d artifacts from %d agents", + len(artifacts), + len(results), + ) + + result = 
StageResult(success=True, context=context) + self._log_stage_complete(result) + return result + + except Exception as e: # pylint: disable=broad-exception-caught + # Justification: Log but don't fail pipeline for agent errors + logger.exception("Response agent stage failed: %s", e) + return StageResult(success=True, context=context, error=str(e)) diff --git a/backend/rag_solution/services/search_service.py b/backend/rag_solution/services/search_service.py index 26cb88ca..4828b878 100644 --- a/backend/rag_solution/services/search_service.py +++ b/backend/rag_solution/services/search_service.py @@ -27,9 +27,12 @@ from rag_solution.services.pipeline.stages import ( GenerationStage, PipelineResolutionStage, + PostSearchAgentStage, + PreSearchAgentStage, QueryEnhancementStage, ReasoningStage, RerankingStage, + ResponseAgentStage, RetrievalStage, ) from rag_solution.services.pipeline_service import PipelineService @@ -573,15 +576,18 @@ async def search(self, search_input: SearchInput) -> SearchOutput: return await self._search_with_pipeline(search_input) async def _search_with_pipeline(self, search_input: SearchInput) -> SearchOutput: - """New stage-based pipeline architecture (Week 4). + """Stage-based pipeline architecture with agent execution hooks. This method uses the modern pipeline architecture with explicit stages: 1. PipelineResolutionStage - Resolve user's default pipeline 2. QueryEnhancementStage - Enhance/rewrite query - 3. RetrievalStage - Retrieve documents from vector DB - 4. RerankingStage - Rerank results for relevance - 5. ReasoningStage - Apply Chain of Thought if needed - 6. GenerationStage - Generate final answer + 3. PreSearchAgentStage - Execute pre-search agents (Issue #697) + 4. RetrievalStage - Retrieve documents from vector DB + 5. RerankingStage - Rerank results for relevance + 6. PostSearchAgentStage - Execute post-search agents (Issue #697) + 7. ReasoningStage - Apply Chain of Thought if needed + 8. GenerationStage - Generate final answer + 9. ResponseAgentStage - Execute response agents for artifacts (Issue #697) Each stage is independent, testable, and modifiable without affecting others. This enables easier maintenance, testing, and feature addition. 
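Callers can opt a single request out of agent execution via the `agents_disabled` flag introduced in the next hunk. A hedged sketch from inside an async caller; only `question` and `config_metadata` are confirmed `SearchInput` fields in this diff, the others are assumed:

```python
from uuid import uuid4

from rag_solution.schemas.search_schema import SearchInput

search_input = SearchInput(
    question="Summarize the Q3 findings",
    collection_id=uuid4(),  # assumed field, placeholder value
    user_id=uuid4(),        # assumed field, placeholder value
    config_metadata={"agents_disabled": True},
)
output = await search_service.search(search_input)
# With the flag set, no agent stages are added, so agent_executions and
# agent_artifacts on the output remain None.
```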
@@ -590,9 +596,9 @@ async def _search_with_pipeline(self, search_input: SearchInput) -> SearchOutput search_input: The search request Returns: - SearchOutput with answer, documents, and metadata + SearchOutput with answer, documents, metadata, and agent artifacts """ - logger.info("✨ Starting NEW pipeline architecture execution") + logger.info("✨ Starting pipeline architecture execution with agent hooks") logger.info("Question: %s", search_input.question) # Create initial search context @@ -603,8 +609,14 @@ async def _search_with_pipeline(self, search_input: SearchInput) -> SearchOutput # Create pipeline executor (pass empty list, stages will be added below) executor = PipelineExecutor(stages=[]) - # Add stages in execution order (Week 4 implementation uses all stages) - logger.debug("Configuring pipeline with all 6 stages") + # Check if agents are enabled (can be disabled via config_metadata) + agents_enabled = True + if search_input.config_metadata and search_input.config_metadata.get("agents_disabled"): + agents_enabled = False + logger.info("Agent execution disabled by config_metadata") + + # Add stages in execution order + logger.debug("Configuring pipeline with stages (agents_enabled=%s)", agents_enabled) # Stage 1: Pipeline Resolution - Get user's default pipeline configuration executor.add_stage(PipelineResolutionStage(self.pipeline_service)) @@ -612,18 +624,30 @@ async def _search_with_pipeline(self, search_input: SearchInput) -> SearchOutput # Stage 2: Query Enhancement - Rewrite/enhance query for better retrieval executor.add_stage(QueryEnhancementStage(self.pipeline_service)) - # Stage 3: Retrieval - Get documents from vector DB + # Stage 3: Pre-Search Agents - Query enhancement agents (Issue #697) + if agents_enabled: + executor.add_stage(PreSearchAgentStage(self.db, self.settings)) + + # Stage 4: Retrieval - Get documents from vector DB executor.add_stage(RetrievalStage(self.pipeline_service)) - # Stage 4: Reranking - Rerank results for better relevance + # Stage 5: Reranking - Rerank results for better relevance executor.add_stage(RerankingStage(self.pipeline_service)) - # Stage 5: Reasoning - Apply Chain of Thought if needed + # Stage 6: Post-Search Agents - Result enhancement agents (Issue #697) + if agents_enabled: + executor.add_stage(PostSearchAgentStage(self.db, self.settings)) + + # Stage 7: Reasoning - Apply Chain of Thought if needed executor.add_stage(ReasoningStage(self.chain_of_thought_service)) - # Stage 6: Generation - Generate final answer from context + # Stage 8: Generation - Generate final answer from context executor.add_stage(GenerationStage(self.pipeline_service)) + # Stage 9: Response Agents - Artifact generation agents (Issue #697) + if agents_enabled: + executor.add_stage(ResponseAgentStage(self.db, self.settings)) + # Execute pipeline logger.info("Executing pipeline with %d stages", len(executor.get_stage_names())) result_context = await executor.execute(context) @@ -654,6 +678,17 @@ async def _search_with_pipeline(self, search_input: SearchInput) -> SearchOutput else "NO DOCUMENT_NAME", ) + # Log agent execution summary if present + if result_context.agent_execution_summary: + summary = result_context.agent_execution_summary + logger.info( + "🤖 Agent execution: %d total, %d successful, %d failed in %.2fms", + summary.total_agents, + summary.successful, + summary.failed, + summary.total_execution_time_ms, + ) + search_output = SearchOutput( answer=cleaned_answer, documents=result_context.document_metadata, @@ -664,9 +699,12 @@ async def 
_search_with_pipeline(self, search_input: SearchInput) -> SearchOutput cot_output=cot_output_dict, token_warning=result_context.token_warning, structured_answer=result_context.structured_answer, + agent_artifacts=result_context.agent_artifacts if result_context.agent_artifacts else None, + agent_executions=result_context.agent_execution_summary, metadata={ - "pipeline_architecture": "v2_stage_based", + "pipeline_architecture": "v2_with_agents", "stages_executed": executor.get_stage_names(), + "agents_enabled": agents_enabled, **result_context.metadata, }, ) @@ -674,6 +712,8 @@ async def _search_with_pipeline(self, search_input: SearchInput) -> SearchOutput logger.info("✨ Pipeline execution completed successfully in %.2f seconds", result_context.execution_time) logger.info("Generated answer length: %d chars", len(cleaned_answer)) logger.info("Retrieved documents: %d", len(result_context.query_results)) + if result_context.agent_artifacts: + logger.info("Generated artifacts: %d", len(result_context.agent_artifacts)) return search_output diff --git a/tests/unit/schemas/test_agent_config_schema.py b/tests/unit/schemas/test_agent_config_schema.py new file mode 100644 index 00000000..230c1ab4 --- /dev/null +++ b/tests/unit/schemas/test_agent_config_schema.py @@ -0,0 +1,426 @@ +"""Unit tests for agent configuration schemas. + +This module tests the Pydantic schemas for agent configuration +including validation and serialization. + +Reference: GitHub Issue #697 +""" + +import pytest +from pydantic import UUID4, ValidationError + +from rag_solution.schemas.agent_config_schema import ( + AgentArtifact, + AgentConfigInput, + AgentConfigOutput, + AgentConfigStatus, + AgentConfigUpdate, + AgentContext, + AgentExecutionStatus, + AgentExecutionSummary, + AgentResult, + AgentStage, + BatchPriorityUpdate, + CollectionAgentInput, + CollectionAgentOutput, + CollectionAgentUpdate, + PipelineMetadata, +) + + +class TestAgentStageEnum: + """Test AgentStage enum.""" + + def test_pre_search_value(self) -> None: + """Test PRE_SEARCH stage value.""" + assert AgentStage.PRE_SEARCH.value == "pre_search" + + def test_post_search_value(self) -> None: + """Test POST_SEARCH stage value.""" + assert AgentStage.POST_SEARCH.value == "post_search" + + def test_response_value(self) -> None: + """Test RESPONSE stage value.""" + assert AgentStage.RESPONSE.value == "response" + + +class TestAgentConfigStatusEnum: + """Test AgentConfigStatus enum.""" + + def test_active_value(self) -> None: + """Test ACTIVE status value.""" + assert AgentConfigStatus.ACTIVE.value == "active" + + def test_disabled_value(self) -> None: + """Test DISABLED status value.""" + assert AgentConfigStatus.DISABLED.value == "disabled" + + def test_deprecated_value(self) -> None: + """Test DEPRECATED status value.""" + assert AgentConfigStatus.DEPRECATED.value == "deprecated" + + +class TestAgentExecutionStatusEnum: + """Test AgentExecutionStatus enum.""" + + def test_all_statuses(self) -> None: + """Test all execution status values.""" + assert AgentExecutionStatus.SUCCESS.value == "success" + assert AgentExecutionStatus.FAILED.value == "failed" + assert AgentExecutionStatus.TIMEOUT.value == "timeout" + assert AgentExecutionStatus.SKIPPED.value == "skipped" + assert AgentExecutionStatus.CIRCUIT_OPEN.value == "circuit_open" + + +class TestAgentConfigInput: + """Test AgentConfigInput schema.""" + + def test_valid_input(self) -> None: + """Test creating valid agent config input.""" + config = AgentConfigInput( + name="Query Expander", + description="Expands 
queries for better retrieval", + agent_type="query_expander", + stage=AgentStage.PRE_SEARCH, + handler_module="rag_solution.agents.query_expander", + handler_class="QueryExpanderAgent", + ) + assert config.name == "Query Expander" + assert config.stage == AgentStage.PRE_SEARCH + assert config.timeout_seconds == 30 # default + assert config.max_retries == 2 # default + + def test_with_custom_settings(self) -> None: + """Test config with custom timeout and retries.""" + config = AgentConfigInput( + name="Slow Agent", + agent_type="slow_agent", + stage=AgentStage.RESPONSE, + handler_module="test.module", + handler_class="SlowAgent", + timeout_seconds=120, + max_retries=5, + priority=50, + ) + assert config.timeout_seconds == 120 + assert config.max_retries == 5 + assert config.priority == 50 + + def test_with_default_config(self) -> None: + """Test config with default_config parameter.""" + config = AgentConfigInput( + name="Configurable Agent", + agent_type="configurable", + stage=AgentStage.POST_SEARCH, + handler_module="test.module", + handler_class="ConfigurableAgent", + default_config={"threshold": 0.5, "max_results": 10}, + ) + assert config.default_config["threshold"] == 0.5 + assert config.default_config["max_results"] == 10 + + def test_name_validation(self) -> None: + """Test name field validation.""" + with pytest.raises(ValidationError): + AgentConfigInput( + name="", # Empty name should fail + agent_type="test", + stage=AgentStage.PRE_SEARCH, + handler_module="test.module", + handler_class="TestAgent", + ) + + def test_timeout_validation(self) -> None: + """Test timeout_seconds validation bounds.""" + # Too low + with pytest.raises(ValidationError): + AgentConfigInput( + name="Test", + agent_type="test", + stage=AgentStage.PRE_SEARCH, + handler_module="test.module", + handler_class="TestAgent", + timeout_seconds=0, + ) + # Too high + with pytest.raises(ValidationError): + AgentConfigInput( + name="Test", + agent_type="test", + stage=AgentStage.PRE_SEARCH, + handler_module="test.module", + handler_class="TestAgent", + timeout_seconds=500, + ) + + def test_max_retries_validation(self) -> None: + """Test max_retries validation bounds.""" + with pytest.raises(ValidationError): + AgentConfigInput( + name="Test", + agent_type="test", + stage=AgentStage.PRE_SEARCH, + handler_module="test.module", + handler_class="TestAgent", + max_retries=10, # Max is 5 + ) + + +class TestAgentConfigUpdate: + """Test AgentConfigUpdate schema.""" + + def test_partial_update(self) -> None: + """Test partial update with only some fields.""" + update = AgentConfigUpdate(name="Updated Name") + assert update.name == "Updated Name" + assert update.description is None + assert update.status is None + + def test_status_update(self) -> None: + """Test status update.""" + update = AgentConfigUpdate(status=AgentConfigStatus.DISABLED) + assert update.status == AgentConfigStatus.DISABLED + + def test_all_fields_update(self) -> None: + """Test update with all fields.""" + update = AgentConfigUpdate( + name="New Name", + description="New description", + default_config={"new": "config"}, + timeout_seconds=60, + max_retries=3, + priority=200, + status=AgentConfigStatus.ACTIVE, + ) + assert update.name == "New Name" + assert update.timeout_seconds == 60 + + +class TestCollectionAgentInput: + """Test CollectionAgentInput schema.""" + + def test_minimal_input(self) -> None: + """Test minimal collection-agent input.""" + assoc = CollectionAgentInput( + agent_config_id=UUID4("12345678-1234-5678-1234-567812345678"), + ) + 
assert assoc.enabled is True # default + assert assoc.priority == 100 # default + assert assoc.config_override == {} # default + + def test_full_input(self) -> None: + """Test full collection-agent input.""" + assoc = CollectionAgentInput( + agent_config_id=UUID4("12345678-1234-5678-1234-567812345678"), + enabled=False, + priority=50, + config_override={"custom_setting": "value"}, + ) + assert assoc.enabled is False + assert assoc.priority == 50 + assert assoc.config_override["custom_setting"] == "value" + + def test_priority_validation(self) -> None: + """Test priority field validation.""" + with pytest.raises(ValidationError): + CollectionAgentInput( + agent_config_id=UUID4("12345678-1234-5678-1234-567812345678"), + priority=2000, # Max is 1000 + ) + + +class TestCollectionAgentUpdate: + """Test CollectionAgentUpdate schema.""" + + def test_enable_only(self) -> None: + """Test updating only enabled field.""" + update = CollectionAgentUpdate(enabled=False) + assert update.enabled is False + assert update.priority is None + + def test_priority_only(self) -> None: + """Test updating only priority field.""" + update = CollectionAgentUpdate(priority=25) + assert update.priority == 25 + assert update.enabled is None + + +class TestBatchPriorityUpdate: + """Test BatchPriorityUpdate schema.""" + + def test_batch_update(self) -> None: + """Test batch priority update.""" + update = BatchPriorityUpdate( + priorities={ + UUID4("12345678-1234-5678-1234-567812345678"): 10, + UUID4("87654321-4321-8765-4321-876543218765"): 20, + } + ) + assert len(update.priorities) == 2 + + +class TestAgentContext: + """Test AgentContext schema.""" + + def test_minimal_context(self) -> None: + """Test minimal context creation.""" + ctx = AgentContext( + search_input={"question": "test"}, + collection_id=UUID4("12345678-1234-5678-1234-567812345678"), + user_id=UUID4("87654321-4321-8765-4321-876543218765"), + stage=AgentStage.PRE_SEARCH, + query="test query", + ) + assert ctx.query == "test query" + assert ctx.query_results == [] + assert ctx.previous_results == [] + assert ctx.config == {} + + def test_full_context(self) -> None: + """Test full context with all fields.""" + ctx = AgentContext( + search_input={"question": "test", "config": {"top_k": 10}}, + collection_id=UUID4("12345678-1234-5678-1234-567812345678"), + user_id=UUID4("87654321-4321-8765-4321-876543218765"), + stage=AgentStage.POST_SEARCH, + query="test query", + query_results=[{"id": "doc1", "score": 0.9}], + config={"threshold": 0.5}, + metadata={"source": "test"}, + ) + assert len(ctx.query_results) == 1 + assert ctx.config["threshold"] == 0.5 + + +class TestAgentResult: + """Test AgentResult schema.""" + + def test_success_result(self) -> None: + """Test successful result.""" + result = AgentResult( + agent_config_id=UUID4("12345678-1234-5678-1234-567812345678"), + agent_name="Test Agent", + agent_type="test", + stage="pre_search", + status=AgentExecutionStatus.SUCCESS, + execution_time_ms=50.5, + ) + assert result.status == AgentExecutionStatus.SUCCESS + assert result.execution_time_ms == 50.5 + assert result.error_message is None + + def test_failed_result(self) -> None: + """Test failed result.""" + result = AgentResult( + agent_config_id=UUID4("12345678-1234-5678-1234-567812345678"), + agent_name="Test Agent", + agent_type="test", + stage="pre_search", + status=AgentExecutionStatus.FAILED, + execution_time_ms=100.0, + error_message="Connection failed", + ) + assert result.status == AgentExecutionStatus.FAILED + assert "Connection" in 
result.error_message + + def test_result_with_modified_query(self) -> None: + """Test result with modified query.""" + result = AgentResult( + agent_config_id=UUID4("12345678-1234-5678-1234-567812345678"), + agent_name="Query Expander", + agent_type="query_expander", + stage="pre_search", + status=AgentExecutionStatus.SUCCESS, + execution_time_ms=30.0, + modified_query="expanded query with more terms", + ) + assert result.modified_query is not None + assert "expanded" in result.modified_query + + +class TestAgentArtifact: + """Test AgentArtifact schema.""" + + def test_pdf_artifact(self) -> None: + """Test PDF artifact creation.""" + artifact = AgentArtifact( + artifact_type="pdf", + content_type="application/pdf", + filename="report.pdf", + size_bytes=10240, + ) + assert artifact.artifact_type == "pdf" + assert artifact.content_type == "application/pdf" + assert artifact.filename == "report.pdf" + + def test_artifact_with_data_url(self) -> None: + """Test artifact with data URL.""" + artifact = AgentArtifact( + artifact_type="chart", + content_type="image/png", + filename="chart.png", + data_url="data:image/png;base64,ABC123", + metadata={"chart_type": "bar", "title": "Sales"}, + ) + assert artifact.data_url.startswith("data:") + assert artifact.metadata["chart_type"] == "bar" + + +class TestAgentExecutionSummary: + """Test AgentExecutionSummary schema.""" + + def test_empty_summary(self) -> None: + """Test empty summary defaults.""" + summary = AgentExecutionSummary() + assert summary.total_agents == 0 + assert summary.successful == 0 + assert summary.failed == 0 + assert summary.skipped == 0 + assert summary.total_execution_time_ms == 0.0 + assert summary.artifacts == [] + + def test_populated_summary(self) -> None: + """Test populated summary.""" + result = AgentResult( + agent_config_id=UUID4("12345678-1234-5678-1234-567812345678"), + agent_name="Test", + agent_type="test", + stage="pre_search", + status=AgentExecutionStatus.SUCCESS, + execution_time_ms=50.0, + ) + summary = AgentExecutionSummary( + total_agents=3, + successful=2, + failed=1, + total_execution_time_ms=150.0, + pre_search_results=[result], + ) + assert summary.total_agents == 3 + assert summary.successful == 2 + assert len(summary.pre_search_results) == 1 + + +class TestPipelineMetadata: + """Test PipelineMetadata schema.""" + + def test_default_metadata(self) -> None: + """Test default pipeline metadata.""" + meta = PipelineMetadata() + assert meta.pipeline_version == "v2_with_agents" + assert meta.stages_executed == [] + assert meta.total_execution_time_ms == 0.0 + + def test_full_metadata(self) -> None: + """Test full pipeline metadata.""" + summary = AgentExecutionSummary(total_agents=5) + meta = PipelineMetadata( + pipeline_version="v2_with_agents", + stages_executed=["PipelineResolution", "QueryEnhancement", "PreSearchAgents"], + total_execution_time_ms=500.0, + agent_execution_summary=summary, + timings={"PipelineResolution": 10.0, "QueryEnhancement": 50.0}, + ) + assert len(meta.stages_executed) == 3 + assert meta.agent_execution_summary.total_agents == 5 + assert "QueryEnhancement" in meta.timings diff --git a/tests/unit/services/test_agent_executor_service.py b/tests/unit/services/test_agent_executor_service.py new file mode 100644 index 00000000..0771aa2d --- /dev/null +++ b/tests/unit/services/test_agent_executor_service.py @@ -0,0 +1,454 @@ +"""Unit tests for AgentExecutorService and CircuitBreaker. 
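+
+The circuit breaker call pattern exercised by these tests is, roughly
+(a sketch inferred from the tests below, not a documented public contract;
+run_agent() is a hypothetical stand-in for real agent execution):
+
+    cb = CircuitBreaker(failure_threshold=3, recovery_timeout=30.0)
+    if not cb.is_open("my_agent"):         # open circuit -> skip the agent
+        try:
+            run_agent()                    # hypothetical agent call
+            cb.record_success("my_agent")  # resets failures / closes circuit
+        except Exception:
+            cb.record_failure("my_agent")  # may trip the circuit open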
+ +This module tests the agent execution hooks functionality including: +- Circuit breaker pattern +- Agent execution at each pipeline stage +- Error handling and failure isolation + +Reference: GitHub Issue #697 +""" + +import asyncio +import time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic import UUID4 + +from rag_solution.schemas.agent_config_schema import ( + AgentArtifact, + AgentContext, + AgentExecutionStatus, + AgentResult, + AgentStage, +) +from rag_solution.services.agent_executor_service import ( + BaseAgentHandler, + CircuitBreaker, + CircuitBreakerState, +) + + +# ============================================================================ +# Circuit Breaker Tests +# ============================================================================ + + +class TestCircuitBreaker: + """Test suite for CircuitBreaker class.""" + + def test_initial_state_is_closed(self) -> None: + """Test that circuit breaker starts in closed state.""" + cb = CircuitBreaker() + assert cb.get_state("test") == "closed" + assert not cb.is_open("test") + + def test_records_failure(self) -> None: + """Test that failures are recorded correctly.""" + cb = CircuitBreaker(failure_threshold=3) + cb.record_failure("test") + assert cb._states["test"].failure_count == 1 + assert cb.get_state("test") == "closed" + + def test_opens_after_threshold(self) -> None: + """Test that circuit opens after failure threshold.""" + cb = CircuitBreaker(failure_threshold=3) + for _ in range(3): + cb.record_failure("test") + assert cb.get_state("test") == "open" + assert cb.is_open("test") + + def test_success_resets_failure_count(self) -> None: + """Test that success resets failure count in closed state.""" + cb = CircuitBreaker(failure_threshold=3) + cb.record_failure("test") + cb.record_failure("test") + assert cb._states["test"].failure_count == 2 + cb.record_success("test") + assert cb._states["test"].failure_count == 0 + + def test_transitions_to_half_open_after_timeout(self) -> None: + """Test circuit transitions to half-open after recovery timeout.""" + cb = CircuitBreaker(failure_threshold=1, recovery_timeout=0.1) + cb.record_failure("test") + assert cb.get_state("test") == "open" + time.sleep(0.15) + # is_open should trigger transition to half-open + assert not cb.is_open("test") + assert cb.get_state("test") == "half_open" + + def test_closes_after_successful_half_open(self) -> None: + """Test circuit closes after successful calls in half-open state.""" + cb = CircuitBreaker(failure_threshold=1, recovery_timeout=0.1, half_open_max_calls=2) + cb.record_failure("test") + time.sleep(0.15) + cb.is_open("test") # Trigger transition + cb.record_success("test") + cb.record_success("test") + assert cb.get_state("test") == "closed" + + def test_reopens_on_failure_in_half_open(self) -> None: + """Test circuit reopens on failure in half-open state.""" + cb = CircuitBreaker(failure_threshold=1, recovery_timeout=0.1) + cb.record_failure("test") + time.sleep(0.15) + cb.is_open("test") # Trigger transition to half-open + cb.record_failure("test") + assert cb.get_state("test") == "open" + + def test_independent_circuits(self) -> None: + """Test that different circuit IDs are independent.""" + cb = CircuitBreaker(failure_threshold=2) + cb.record_failure("circuit1") + cb.record_failure("circuit1") + cb.record_failure("circuit2") + assert cb.get_state("circuit1") == "open" + assert cb.get_state("circuit2") == "closed" + + +# 
============================================================================ +# Agent Context and Result Schema Tests +# ============================================================================ + + +class TestAgentSchemas: + """Test suite for agent-related Pydantic schemas.""" + + def test_agent_context_creation(self) -> None: + """Test AgentContext creation with required fields.""" + context = AgentContext( + search_input={"question": "test question"}, + collection_id=UUID4("12345678-1234-5678-1234-567812345678"), + user_id=UUID4("87654321-4321-8765-4321-876543218765"), + stage=AgentStage.PRE_SEARCH, + query="test query", + ) + assert context.query == "test query" + assert context.stage == AgentStage.PRE_SEARCH + assert context.query_results == [] + assert context.previous_results == [] + + def test_agent_context_with_all_fields(self) -> None: + """Test AgentContext with all optional fields.""" + context = AgentContext( + search_input={"question": "test"}, + collection_id=UUID4("12345678-1234-5678-1234-567812345678"), + user_id=UUID4("87654321-4321-8765-4321-876543218765"), + stage=AgentStage.POST_SEARCH, + query="test", + query_results=[{"id": "doc1", "score": 0.9}], + config={"threshold": 0.5}, + metadata={"source": "test"}, + ) + assert len(context.query_results) == 1 + assert context.config["threshold"] == 0.5 + + def test_agent_result_success(self) -> None: + """Test AgentResult for successful execution.""" + result = AgentResult( + agent_config_id=UUID4("12345678-1234-5678-1234-567812345678"), + agent_name="test_agent", + agent_type="query_expander", + stage="pre_search", + status=AgentExecutionStatus.SUCCESS, + execution_time_ms=100.5, + modified_query="expanded query", + ) + assert result.status == AgentExecutionStatus.SUCCESS + assert result.modified_query == "expanded query" + assert result.error_message is None + + def test_agent_result_failure(self) -> None: + """Test AgentResult for failed execution.""" + result = AgentResult( + agent_config_id=UUID4("12345678-1234-5678-1234-567812345678"), + agent_name="test_agent", + agent_type="reranker", + stage="post_search", + status=AgentExecutionStatus.FAILED, + execution_time_ms=50.0, + error_message="Connection timeout", + ) + assert result.status == AgentExecutionStatus.FAILED + assert "timeout" in result.error_message.lower() + + def test_agent_result_with_artifacts(self) -> None: + """Test AgentResult with generated artifacts.""" + artifact = AgentArtifact( + artifact_type="pdf", + content_type="application/pdf", + filename="report.pdf", + data_url="data:application/pdf;base64,ABC123", + size_bytes=1024, + ) + result = AgentResult( + agent_config_id=UUID4("12345678-1234-5678-1234-567812345678"), + agent_name="pdf_generator", + agent_type="pdf_generator", + stage="response", + status=AgentExecutionStatus.SUCCESS, + execution_time_ms=500.0, + artifacts=[artifact], + ) + assert len(result.artifacts) == 1 + assert result.artifacts[0].filename == "report.pdf" + + def test_agent_artifact_creation(self) -> None: + """Test AgentArtifact creation.""" + artifact = AgentArtifact( + artifact_type="chart", + content_type="image/png", + filename="chart.png", + metadata={"chart_type": "bar"}, + ) + assert artifact.artifact_type == "chart" + assert artifact.metadata["chart_type"] == "bar" + + +# ============================================================================ +# Base Agent Handler Tests +# ============================================================================ + + +class MockAgentHandler(BaseAgentHandler): + """Mock agent 
handler for testing.""" + + def __init__(self, return_value: AgentResult | None = None, raise_exception: bool = False) -> None: + self._return_value = return_value + self._raise_exception = raise_exception + + async def execute(self, context: AgentContext) -> AgentResult: + if self._raise_exception: + raise RuntimeError("Test exception") + if self._return_value: + return self._return_value + return AgentResult( + agent_config_id=UUID4("12345678-1234-5678-1234-567812345678"), + agent_name="mock_agent", + agent_type="mock", + stage="pre_search", + status=AgentExecutionStatus.SUCCESS, + execution_time_ms=10.0, + ) + + @property + def agent_type(self) -> str: + return "mock" + + +class TestBaseAgentHandler: + """Test suite for BaseAgentHandler abstract class.""" + + @pytest.mark.asyncio + async def test_mock_handler_success(self) -> None: + """Test mock handler returns expected result.""" + handler = MockAgentHandler() + context = AgentContext( + search_input={}, + collection_id=UUID4("12345678-1234-5678-1234-567812345678"), + user_id=UUID4("87654321-4321-8765-4321-876543218765"), + stage=AgentStage.PRE_SEARCH, + query="test", + ) + result = await handler.execute(context) + assert result.status == AgentExecutionStatus.SUCCESS + + @pytest.mark.asyncio + async def test_mock_handler_with_custom_result(self) -> None: + """Test mock handler with custom return value.""" + custom_result = AgentResult( + agent_config_id=UUID4("12345678-1234-5678-1234-567812345678"), + agent_name="custom", + agent_type="custom", + stage="post_search", + status=AgentExecutionStatus.SUCCESS, + execution_time_ms=100.0, + modified_results=[{"id": "doc1"}], + ) + handler = MockAgentHandler(return_value=custom_result) + context = AgentContext( + search_input={}, + collection_id=UUID4("12345678-1234-5678-1234-567812345678"), + user_id=UUID4("87654321-4321-8765-4321-876543218765"), + stage=AgentStage.POST_SEARCH, + query="test", + ) + result = await handler.execute(context) + assert result.modified_results is not None + assert len(result.modified_results) == 1 + + @pytest.mark.asyncio + async def test_mock_handler_raises_exception(self) -> None: + """Test mock handler exception handling.""" + handler = MockAgentHandler(raise_exception=True) + context = AgentContext( + search_input={}, + collection_id=UUID4("12345678-1234-5678-1234-567812345678"), + user_id=UUID4("87654321-4321-8765-4321-876543218765"), + stage=AgentStage.PRE_SEARCH, + query="test", + ) + with pytest.raises(RuntimeError): + await handler.execute(context) + + +# ============================================================================ +# Agent Stage Tests +# ============================================================================ + + +class TestAgentStages: + """Test agent execution at different pipeline stages.""" + + def test_stage_enum_values(self) -> None: + """Test AgentStage enum values.""" + assert AgentStage.PRE_SEARCH.value == "pre_search" + assert AgentStage.POST_SEARCH.value == "post_search" + assert AgentStage.RESPONSE.value == "response" + + def test_execution_status_enum_values(self) -> None: + """Test AgentExecutionStatus enum values.""" + assert AgentExecutionStatus.SUCCESS.value == "success" + assert AgentExecutionStatus.FAILED.value == "failed" + assert AgentExecutionStatus.TIMEOUT.value == "timeout" + assert AgentExecutionStatus.SKIPPED.value == "skipped" + assert AgentExecutionStatus.CIRCUIT_OPEN.value == "circuit_open" + + +# ============================================================================ +# Integration-like Unit Tests +# 
============================================================================ + + +class TestAgentExecutionFlow: + """Test agent execution flow scenarios.""" + + @pytest.mark.asyncio + async def test_pre_search_agent_modifies_query(self) -> None: + """Test pre-search agent that modifies query.""" + modified_result = AgentResult( + agent_config_id=UUID4("12345678-1234-5678-1234-567812345678"), + agent_name="query_expander", + agent_type="query_expander", + stage="pre_search", + status=AgentExecutionStatus.SUCCESS, + execution_time_ms=50.0, + modified_query="expanded: what is machine learning and AI", + ) + handler = MockAgentHandler(return_value=modified_result) + context = AgentContext( + search_input={"question": "what is ML"}, + collection_id=UUID4("12345678-1234-5678-1234-567812345678"), + user_id=UUID4("87654321-4321-8765-4321-876543218765"), + stage=AgentStage.PRE_SEARCH, + query="what is ML", + ) + result = await handler.execute(context) + assert result.modified_query is not None + assert "expanded" in result.modified_query + + @pytest.mark.asyncio + async def test_post_search_agent_modifies_results(self) -> None: + """Test post-search agent that modifies results.""" + modified_result = AgentResult( + agent_config_id=UUID4("12345678-1234-5678-1234-567812345678"), + agent_name="reranker", + agent_type="reranker", + stage="post_search", + status=AgentExecutionStatus.SUCCESS, + execution_time_ms=100.0, + modified_results=[ + {"id": "doc2", "score": 0.95}, + {"id": "doc1", "score": 0.85}, + ], + ) + handler = MockAgentHandler(return_value=modified_result) + context = AgentContext( + search_input={}, + collection_id=UUID4("12345678-1234-5678-1234-567812345678"), + user_id=UUID4("87654321-4321-8765-4321-876543218765"), + stage=AgentStage.POST_SEARCH, + query="test", + query_results=[ + {"id": "doc1", "score": 0.9}, + {"id": "doc2", "score": 0.8}, + ], + ) + result = await handler.execute(context) + assert result.modified_results is not None + # Check reranking - doc2 now first + assert result.modified_results[0]["id"] == "doc2" + + @pytest.mark.asyncio + async def test_response_agent_generates_artifact(self) -> None: + """Test response agent that generates artifact.""" + artifact = AgentArtifact( + artifact_type="pptx", + content_type="application/vnd.openxmlformats-officedocument.presentationml.presentation", + filename="summary.pptx", + size_bytes=5000, + ) + modified_result = AgentResult( + agent_config_id=UUID4("12345678-1234-5678-1234-567812345678"), + agent_name="pptx_generator", + agent_type="pptx_generator", + stage="response", + status=AgentExecutionStatus.SUCCESS, + execution_time_ms=500.0, + artifacts=[artifact], + ) + handler = MockAgentHandler(return_value=modified_result) + context = AgentContext( + search_input={}, + collection_id=UUID4("12345678-1234-5678-1234-567812345678"), + user_id=UUID4("87654321-4321-8765-4321-876543218765"), + stage=AgentStage.RESPONSE, + query="test", + metadata={"generated_answer": "Test answer"}, + ) + result = await handler.execute(context) + assert result.artifacts is not None + assert len(result.artifacts) == 1 + assert result.artifacts[0].artifact_type == "pptx" + + +# ============================================================================ +# Error Handling Tests +# ============================================================================ + + +class TestAgentErrorHandling: + """Test error handling in agent execution.""" + + def test_circuit_breaker_blocks_requests_when_open(self) -> None: + """Test that circuit breaker blocks requests 
when open."""
+        cb = CircuitBreaker(failure_threshold=1)
+        cb.record_failure("test")
+        assert cb.is_open("test")
+
+    @pytest.mark.asyncio
+    async def test_agent_timeout_handling(self) -> None:
+        """Test that a timed-out agent is reported as a TIMEOUT result with an error message."""
+        # Simulate the result the executor would produce on timeout
+        timeout_result = AgentResult(
+            agent_config_id=UUID4("12345678-1234-5678-1234-567812345678"),
+            agent_name="slow_agent",
+            agent_type="slow",
+            stage="pre_search",
+            status=AgentExecutionStatus.TIMEOUT,
+            execution_time_ms=30000.0,
+            error_message="Timeout after 30s",
+        )
+        assert timeout_result.status == AgentExecutionStatus.TIMEOUT
+        assert "Timeout" in timeout_result.error_message
+
+    def test_circuit_breaker_state_isolation(self) -> None:
+        """Test that circuit breaker state is isolated per circuit ID."""
+        cb = CircuitBreaker(failure_threshold=2)
+        # Fail one circuit
+        cb.record_failure("agent1")
+        cb.record_failure("agent1")
+        # Other circuit should be unaffected
+        assert cb.is_open("agent1")
+        assert not cb.is_open("agent2")
+        assert not cb.is_open("agent3")
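+
+
+# ============================================================================
+# Illustrative Custom Handler Sketch
+# ============================================================================
+
+
+class _SketchQueryExpander(BaseAgentHandler):
+    """Minimal sketch of a concrete pre-search handler.
+
+    Illustration only: it assumes nothing beyond the BaseAgentHandler
+    interface exercised above (an async execute() method plus the
+    agent_type property). The synonym map is invented for this sketch,
+    and the UUID literal follows the fixture convention used throughout
+    this module.
+    """
+
+    _SYNONYMS = {"ML": "machine learning", "AI": "artificial intelligence"}
+
+    async def execute(self, context: AgentContext) -> AgentResult:
+        start = time.perf_counter()
+        # Replace known abbreviations token-by-token; a real handler would
+        # typically call an LLM or a lookup service here instead.
+        expanded = " ".join(self._SYNONYMS.get(tok, tok) for tok in context.query.split())
+        return AgentResult(
+            agent_config_id=UUID4("12345678-1234-5678-1234-567812345678"),
+            agent_name="sketch_query_expander",
+            agent_type=self.agent_type,
+            stage="pre_search",
+            status=AgentExecutionStatus.SUCCESS,
+            execution_time_ms=(time.perf_counter() - start) * 1000.0,
+            modified_query=expanded,
+        )
+
+    @property
+    def agent_type(self) -> str:
+        return "sketch_query_expander"
+
+
+class TestSketchQueryExpander:
+    """Smoke test for the illustrative handler sketch."""
+
+    @pytest.mark.asyncio
+    async def test_expands_known_abbreviations(self) -> None:
+        """The sketch handler rewrites known tokens in the query."""
+        handler = _SketchQueryExpander()
+        context = AgentContext(
+            search_input={"question": "what is ML"},
+            collection_id=UUID4("12345678-1234-5678-1234-567812345678"),
+            user_id=UUID4("87654321-4321-8765-4321-876543218765"),
+            stage=AgentStage.PRE_SEARCH,
+            query="what is ML",
+        )
+        result = await handler.execute(context)
+        assert result.modified_query == "what is machine learning"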