-
Notifications
You must be signed in to change notification settings - Fork 0
feat: agent architecture and performance optimizations #16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| from pydantic import Field | ||
| from pydantic_settings import BaseSettings, SettingsConfigDict | ||
| from typing import Literal | ||
|
|
||
|
|
||
| class AgentSettings(BaseSettings): | ||
|
|
@@ -32,5 +33,29 @@ class AgentSettings(BaseSettings): | |
| ) | ||
| LANGFUSE_PROMPT_REJECTION_ROUTER: str = "text2sql/rejection_router" | ||
|
|
||
| # ── G2-01: Table Scoping ────────────────────────────────────────────────── | ||
| TABLE_SCOPING_MODE: Literal["strict", "hybrid"] = "hybrid" | ||
|
|
||
| # ── G2-03: Advanced Schema Explorer phases ──────────────────────────────── | ||
| ENABLE_SEMANTIC_TYPING: bool = False | ||
| ENABLE_JOIN_GRAPH: bool = False | ||
| ENABLE_SCHEMA_SUMMARIZATION: bool = False | ||
| ENABLE_AMBIGUITY_DETECT: bool = True | ||
|
|
||
| # ── G2-04: Satisfaction Check ───────────────────────────────────────────── | ||
| SATISFACTION_CHECK_ENABLED: bool = True | ||
| SATISFACTION_CHECK_EXECUTION: bool = True | ||
| SATISFACTION_CHECK_PLAUSIBILITY: bool = True | ||
| SATISFACTION_CHECK_COLUMNS: bool = True | ||
| SATISFACTION_CHECK_SEMANTIC: bool = False # LLM-heavy, off by default | ||
| SATISFACTION_MIN_ROWS: int = 1 | ||
| SATISFACTION_MAX_ROWS: int = 50_000 | ||
| SATISFACTION_SEMANTIC_THRESHOLD: float = 0.75 | ||
| SATISFACTION_MAX_FAILURES: int = 2 # escalate to HITL after this many check failures | ||
|
|
||
| # ── G2-05: Redis Schema Cache ───────────────────────────────────────────── | ||
| SCHEMA_CACHE_TTL: int = 600 # seconds — DDL content | ||
| PROFILE_CACHE_TTL: int = 1800 # seconds — table profile statistics | ||
|
Comment on lines
+51
to
+58
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain 🏁 Script executed: cat -n agent/src/agent/config.py | head -100 Repository: StavPonte11/text2sql-onboarding Length of output: 3067 🏁 Script executed: rg "SATISFACTION_MIN_ROWS|SATISFACTION_MAX_ROWS|SATISFACTION_SEMANTIC_THRESHOLD|SATISFACTION_MAX_FAILURES|SCHEMA_CACHE_TTL|PROFILE_CACHE_TTL" --type py -B 2 -A 2 Repository: StavPonte11/text2sql-onboarding Length of output: 5927 🏁 Script executed: rg "validator|field_validator|model_validator" agent/src/agent/config.py -A 3 Repository: StavPonte11/text2sql-onboarding Length of output: 57 🏁 Script executed: cat -n agent/src/agent/nodes/satisfaction_check.py | head -150 Repository: StavPonte11/text2sql-onboarding Length of output: 6850 Add bounds validation for satisfaction thresholds and cache TTLs. These env-driven values currently accept invalid ranges (e.g., negative or zero TTLs, semantic threshold outside [0, 1]).
Apply Pydantic Field constraints (already used elsewhere in this config class): Proposed fix- SATISFACTION_MIN_ROWS: int = 1
- SATISFACTION_MAX_ROWS: int = 50_000
- SATISFACTION_SEMANTIC_THRESHOLD: float = 0.75
- SATISFACTION_MAX_FAILURES: int = 2 # escalate to HITL after this many check failures
+ SATISFACTION_MIN_ROWS: int = Field(default=1, ge=0)
+ SATISFACTION_MAX_ROWS: int = Field(default=50_000, ge=1)
+ SATISFACTION_SEMANTIC_THRESHOLD: float = Field(default=0.75, ge=0.0, le=1.0)
+ SATISFACTION_MAX_FAILURES: int = Field(default=2, ge=1) # escalate to HITL after this many check failures
@@
- SCHEMA_CACHE_TTL: int = 600 # seconds — DDL content
- PROFILE_CACHE_TTL: int = 1800 # seconds — table profile statistics
+ SCHEMA_CACHE_TTL: int = Field(default=600, gt=0) # seconds — DDL content
+ PROFILE_CACHE_TTL: int = Field(default=1800, gt=0) # seconds — table profile statistics🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
| settings = AgentSettings() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,14 +1,30 @@ | ||
| """ | ||
| LangGraph agent graph — Group 2 hardened topology. | ||
|
|
||
| Node order: | ||
| START → validate_config → extractor → schema_explorer → ... | ||
| (G2-01 fail-fast) | ||
|
|
||
| HITL escalation compiles with interrupt_before=["hitl_escalation"] so | ||
| LangGraph pauses before executing that node. After a human injects | ||
| corrected state the graph resumes from hitl_escalation which immediately | ||
| routes to extractor (full state reset path). | ||
|
|
||
| Satisfaction check sits between refiner success path and finalizer (G2-04). | ||
| """ | ||
|
|
||
| from agent.nodes.refiner import MAX_REFINER_ITERATIONS | ||
| from langgraph.graph import StateGraph, START, END | ||
| from langgraph.checkpoint.memory import MemorySaver | ||
|
|
||
| from langchain_core.prompts import ChatPromptTemplate | ||
| from agent.state import AgentState | ||
| from agent.nodes.extractor import extractor_node | ||
| from agent.nodes.schema_explorer import schema_explorer_node | ||
| from agent.nodes.schema_explorer import schema_explorer_node, MAX_SCHEMA_RETRIES | ||
| from agent.nodes.query_builder import query_builder_node | ||
| from agent.nodes.refiner import refiner_node | ||
| from agent.nodes.finalizer import finalizer_node | ||
| from agent.nodes.satisfaction_check import satisfaction_check_node | ||
| from agent.config import settings | ||
| from agent.langfuse_client import langfuse_client | ||
| from pydantic import BaseModel, Field | ||
|
|
@@ -20,6 +36,71 @@ | |
| llm = get_llm("routing") | ||
|
|
||
|
|
||
| # ── G2-01: Custom exception ─────────────────────────────────────────────────── | ||
|
|
||
|
|
||
| class InvalidConfigurationException(ValueError): | ||
| """Raised when agent state contains an invalid or unsafe configuration.""" | ||
|
|
||
|
|
||
| # ── G2-01: Config validator node ────────────────────────────────────────────── | ||
|
|
||
|
|
||
| def validate_config_node(state: AgentState) -> dict: | ||
| """ | ||
| First node after START. Resolves scoping_mode from state (or falls back | ||
| to the env default) and enforces strict-mode preconditions. | ||
|
|
||
| Raises: | ||
| InvalidConfigurationException: if scoping_mode='strict' and | ||
| allowed_tables is null or empty. | ||
| """ | ||
| mode: str = state.get("scoping_mode") or settings.TABLE_SCOPING_MODE | ||
|
|
||
| if mode == "strict": | ||
| allowed = state.get("allowed_tables") | ||
| if not allowed: | ||
| raise InvalidConfigurationException( | ||
| "scoping_mode='strict' requires allowed_tables to be a non-empty list. " | ||
| "Execution aborted to prevent unrestricted table access." | ||
| ) | ||
|
|
||
| return {"scoping_mode": mode} | ||
|
|
||
|
|
||
| # ── G2-02: HITL escalation node ─────────────────────────────────────────────── | ||
|
|
||
|
|
||
| def hitl_escalation_node(state: AgentState) -> dict: | ||
| """ | ||
| Execution pauses HERE via LangGraph interrupt_before before this node runs. | ||
| The human then calls graph.update_state() to inject corrected state and | ||
| clears sql_query / last_error / trino_error / escalated / escalation_reason. | ||
|
Comment on lines
+77
to
+78
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not sure that's what happens?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also, I'm not sure clearing the state is the right call. |
||
| After update_state the graph resumes from this node, which immediately | ||
| routes to extractor via its direct edge. | ||
|
|
||
| This node body only performs observability work — it does NOT call interrupt() | ||
| itself (interrupt_before handles the pause at compile time). | ||
| """ | ||
| reason = state.get("escalation_reason", "Maximum retries exhausted.") | ||
|
|
||
| try: | ||
| trace_id = langfuse_client.get_current_trace_id() | ||
| if trace_id: | ||
| langfuse_client.trace( | ||
| id=trace_id, | ||
| tags=["escalated=True"], | ||
| metadata={"escalation_reason": reason}, | ||
| ) | ||
| except Exception: | ||
| pass | ||
|
|
||
| return {"escalated": True} | ||
|
|
||
|
|
||
| # ── Rejection router ────────────────────────────────────────────────────────── | ||
|
|
||
|
|
||
| class RejectionRoute(BaseModel): | ||
| route: Literal["extractor", "schema_explorer", "query_builder"] = Field( | ||
| description="The phase to route the execution back to based on the user feedback." | ||
|
|
@@ -55,60 +136,92 @@ def rejection_router_node(state: AgentState): | |
| } | ||
|
|
||
|
|
||
| def route_refiner(state: AgentState): | ||
| if ( | ||
| state.get("trino_error") | ||
| and state.get("refinement_count", 0) < MAX_REFINER_ITERATIONS | ||
| ): | ||
| return "refiner" | ||
| return "finalizer" | ||
| # ── Conditional edge functions ──────────────────────────────────────────────── | ||
|
|
||
|
|
||
| def route_schema_explorer(state: AgentState) -> str: | ||
| """G2-02: route to hitl_escalation after MAX_SCHEMA_RETRIES.""" | ||
| if state.get("hallucinated_tables"): | ||
| if (state.get("schema_explorer_retry_count") or 0) >= MAX_SCHEMA_RETRIES: | ||
| return "hitl_escalation" | ||
| return "schema_explorer" | ||
| return "query_builder" | ||
|
|
||
| def route_query_builder(state: AgentState): | ||
|
|
||
| def route_refiner(state: AgentState) -> str: | ||
| """G2-02: route to hitl_escalation when refiner limit is hit.""" | ||
| if state.get("trino_error"): | ||
| if state.get("refinement_count", 0) < MAX_REFINER_ITERATIONS: | ||
| return "satisfaction_check" # run check even on error path so Check A can flag it | ||
| return "hitl_escalation" | ||
| return "satisfaction_check" | ||
|
|
||
|
|
||
| def route_satisfaction(state: AgentState) -> str: | ||
| """ | ||
| G2-04: route based on satisfaction check outcome. | ||
| - no module / no failures → finalizer | ||
| - failures, under MAX → refiner | ||
| - failures, over MAX → hitl_escalation | ||
| """ | ||
| failures = state.get("satisfaction_failures") | ||
| if not failures: | ||
| return "finalizer" | ||
|
|
||
| fail_count = state.get("satisfaction_fail_count") or 0 | ||
| if fail_count >= settings.SATISFACTION_MAX_FAILURES: | ||
| return "hitl_escalation" | ||
| return "refiner" | ||
|
|
||
|
|
||
| def route_query_builder(state: AgentState) -> str: | ||
| if state.get("feedback"): | ||
| return "rejection_router" | ||
| return "refiner" | ||
|
|
||
|
|
||
| def route_rejection(state: AgentState): | ||
| def route_rejection(state: AgentState) -> str: | ||
| route = state.get("feedback_route") | ||
| if route in ["extractor", "schema_explorer", "query_builder"]: | ||
| return route | ||
| return "extractor" | ||
|
|
||
|
|
||
| # ── Build graph ─────────────────────────────────────────────────────────────── | ||
|
|
||
| workflow = StateGraph(AgentState) | ||
|
|
||
| workflow.add_node("validate_config", validate_config_node) | ||
| workflow.add_node("extractor", extractor_node) | ||
| workflow.add_node("schema_explorer", schema_explorer_node) | ||
| workflow.add_node("query_builder", query_builder_node) | ||
| workflow.add_node("rejection_router", rejection_router_node) | ||
| workflow.add_node("refiner", refiner_node) | ||
| workflow.add_node("satisfaction_check", satisfaction_check_node) | ||
| workflow.add_node("hitl_escalation", hitl_escalation_node) | ||
| workflow.add_node("finalizer", finalizer_node) | ||
|
|
||
| workflow.add_edge(START, "extractor") | ||
| # Entry: validate config before anything else (G2-01 fail-fast) | ||
| workflow.add_edge(START, "validate_config") | ||
| workflow.add_edge("validate_config", "extractor") | ||
| workflow.add_edge("extractor", "schema_explorer") | ||
|
|
||
|
|
||
| def route_schema_explorer(state: AgentState): | ||
| if state.get("hallucinated_tables"): | ||
| if state.get("schema_explorer_retry_count", 0) >= 3: | ||
| return "query_builder" | ||
| return "schema_explorer" | ||
| return "query_builder" | ||
|
|
||
|
|
||
| workflow.add_conditional_edges( | ||
| "schema_explorer", | ||
| route_schema_explorer, | ||
| {"schema_explorer": "schema_explorer", "query_builder": "query_builder"}, | ||
| { | ||
| "schema_explorer": "schema_explorer", | ||
| "query_builder": "query_builder", | ||
| "hitl_escalation": "hitl_escalation", # G2-02 | ||
| }, | ||
| ) | ||
|
|
||
| workflow.add_conditional_edges( | ||
| "query_builder", | ||
| route_query_builder, | ||
| {"rejection_router": "rejection_router", "refiner": "refiner"}, | ||
| ) | ||
|
|
||
| workflow.add_conditional_edges( | ||
| "rejection_router", | ||
| route_rejection, | ||
|
|
@@ -120,9 +233,31 @@ def route_schema_explorer(state: AgentState): | |
| ) | ||
|
|
||
| workflow.add_conditional_edges( | ||
| "refiner", route_refiner, {"refiner": "refiner", "finalizer": "finalizer"} | ||
| "refiner", | ||
| route_refiner, | ||
| { | ||
| "satisfaction_check": "satisfaction_check", # G2-04 replaces direct → finalizer | ||
| "hitl_escalation": "hitl_escalation", # G2-02 | ||
| }, | ||
| ) | ||
|
|
||
| # G2-04: satisfaction gate | ||
| workflow.add_conditional_edges( | ||
| "satisfaction_check", | ||
| route_satisfaction, | ||
| { | ||
| "finalizer": "finalizer", | ||
| "refiner": "refiner", | ||
| "hitl_escalation": "hitl_escalation", | ||
| }, | ||
| ) | ||
|
|
||
| # G2-02: HITL resume path → restart from extractor (full state reset by human) | ||
| workflow.add_edge("hitl_escalation", "extractor") | ||
| workflow.add_edge("finalizer", END) | ||
|
|
||
| memory = MemorySaver() | ||
| agent_graph = workflow.compile(checkpointer=memory) | ||
| agent_graph = workflow.compile( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It would be really nice to add a diagram of the graph to the README (so I can also see visually what it looks like). |
||
| checkpointer=memory, | ||
| interrupt_before=["hitl_escalation"], # G2-02: pause before HITL node | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should be at the request level and not at the global level, so you can pass the configuration as an argument when invoking the agent.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I realized it acts as the default, so it should be called the default table scoping mode (e.g. DEFAULT_TABLE_SCOPING_MODE).