Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions agent/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies = [
"langchain-openai",
"greenlet>=3.5.1",
"mcp>=1.12.4",
"networkx>=3.3",
]

[tool.uv.sources]
Expand All @@ -32,6 +33,9 @@ build-backend = "uv_build"
[dependency-groups]
dev = [
"pytest>=9.0.3",
"pytest-asyncio>=1.4.0",
"pytest-cov>=7.1.0",
"pytest-mock>=3.15.1",
"ruff>=0.3.0",
]

Expand Down
25 changes: 25 additions & 0 deletions agent/src/agent/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import Literal


class AgentSettings(BaseSettings):
Expand Down Expand Up @@ -32,5 +33,29 @@ class AgentSettings(BaseSettings):
)
LANGFUSE_PROMPT_REJECTION_ROUTER: str = "text2sql/rejection_router"

# ── G2-01: Table Scoping ──────────────────────────────────────────────────
TABLE_SCOPING_MODE: Literal["strict", "hybrid"] = "hybrid"

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be set at the request level rather than the global level, so you can pass the configuration as an argument when invoking the agent.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Realized it acts as the default — it should be called the default table scoping mode then.


# ── G2-03: Advanced Schema Explorer phases ────────────────────────────────
ENABLE_SEMANTIC_TYPING: bool = False
ENABLE_JOIN_GRAPH: bool = False
ENABLE_SCHEMA_SUMMARIZATION: bool = False
ENABLE_AMBIGUITY_DETECT: bool = True

# ── G2-04: Satisfaction Check ─────────────────────────────────────────────
SATISFACTION_CHECK_ENABLED: bool = True
SATISFACTION_CHECK_EXECUTION: bool = True
SATISFACTION_CHECK_PLAUSIBILITY: bool = True
SATISFACTION_CHECK_COLUMNS: bool = True
SATISFACTION_CHECK_SEMANTIC: bool = False # LLM-heavy, off by default
SATISFACTION_MIN_ROWS: int = 1
SATISFACTION_MAX_ROWS: int = 50_000
SATISFACTION_SEMANTIC_THRESHOLD: float = 0.75
SATISFACTION_MAX_FAILURES: int = 2 # escalate to HITL after this many check failures

# ── G2-05: Redis Schema Cache ─────────────────────────────────────────────
SCHEMA_CACHE_TTL: int = 600 # seconds — DDL content
PROFILE_CACHE_TTL: int = 1800 # seconds — table profile statistics
Comment on lines +51 to +58

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

cat -n agent/src/agent/config.py | head -100

Repository: StavPonte11/text2sql-onboarding

Length of output: 3067


🏁 Script executed:

rg "SATISFACTION_MIN_ROWS|SATISFACTION_MAX_ROWS|SATISFACTION_SEMANTIC_THRESHOLD|SATISFACTION_MAX_FAILURES|SCHEMA_CACHE_TTL|PROFILE_CACHE_TTL" --type py -B 2 -A 2

Repository: StavPonte11/text2sql-onboarding

Length of output: 5927


🏁 Script executed:

rg "validator|field_validator|model_validator" agent/src/agent/config.py -A 3

Repository: StavPonte11/text2sql-onboarding

Length of output: 57


🏁 Script executed:

cat -n agent/src/agent/nodes/satisfaction_check.py | head -150

Repository: StavPonte11/text2sql-onboarding

Length of output: 6850


Add bounds validation for satisfaction thresholds and cache TTLs.

These env-driven values currently accept invalid ranges (e.g., negative or zero TTLs, semantic threshold outside [0,1], or inverted row bounds), which silently degrade gate decisions and cache behavior:

  • SATISFACTION_MIN_ROWS ≤ 0 causes the plausibility check to never fail
  • SATISFACTION_MAX_ROWS ≤ 0 or < MIN_ROWS inverts the check logic
  • SATISFACTION_SEMANTIC_THRESHOLD outside [0,1] makes the check always pass or always fail
  • SATISFACTION_MAX_FAILURES ≤ 0 breaks escalation logic
  • SCHEMA_CACHE_TTL and PROFILE_CACHE_TTL ≤ 0 fail on Redis setex operations

Apply Pydantic Field constraints (already used elsewhere in this config class):

Proposed fix
-    SATISFACTION_MIN_ROWS: int = 1
-    SATISFACTION_MAX_ROWS: int = 50_000
-    SATISFACTION_SEMANTIC_THRESHOLD: float = 0.75
-    SATISFACTION_MAX_FAILURES: int = 2  # escalate to HITL after this many check failures
+    SATISFACTION_MIN_ROWS: int = Field(default=1, ge=0)
+    SATISFACTION_MAX_ROWS: int = Field(default=50_000, ge=1)
+    SATISFACTION_SEMANTIC_THRESHOLD: float = Field(default=0.75, ge=0.0, le=1.0)
+    SATISFACTION_MAX_FAILURES: int = Field(default=2, ge=1)  # escalate to HITL after this many check failures
@@
-    SCHEMA_CACHE_TTL: int = 600    # seconds — DDL content
-    PROFILE_CACHE_TTL: int = 1800  # seconds — table profile statistics
+    SCHEMA_CACHE_TTL: int = Field(default=600, gt=0)    # seconds — DDL content
+    PROFILE_CACHE_TTL: int = Field(default=1800, gt=0)  # seconds — table profile statistics
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@agent/src/agent/config.py` around lines 51 - 58, Add Pydantic Field
constraints to validate the satisfaction and cache configuration values in the
config class. For SATISFACTION_MIN_ROWS and SATISFACTION_MAX_ROWS, ensure they
are positive integers with MIN_ROWS less than MAX_ROWS. For
SATISFACTION_SEMANTIC_THRESHOLD, constrain it to the range [0, 1]. For
SATISFACTION_MAX_FAILURES, ensure it is a positive integer. For SCHEMA_CACHE_TTL
and PROFILE_CACHE_TTL, ensure both are positive integers. Use Pydantic
validators like gt (greater than), ge (greater than or equal), le (less than or
equal), and Field constraints that match the patterns already established
elsewhere in this config class.



settings = AgentSettings()
181 changes: 158 additions & 23 deletions agent/src/agent/graph.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,30 @@
"""
LangGraph agent graph — Group 2 hardened topology.

Node order:
START → validate_config → extractor → schema_explorer → ...
(G2-01 fail-fast)

HITL escalation compiles with interrupt_before=["hitl_escalation"] so
LangGraph pauses before executing that node. After a human injects
corrected state the graph resumes from hitl_escalation which immediately
routes to extractor (full state reset path).

Satisfaction check sits between refiner success path and finalizer (G2-04).
"""

from agent.nodes.refiner import MAX_REFINER_ITERATIONS
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver

from langchain_core.prompts import ChatPromptTemplate
from agent.state import AgentState
from agent.nodes.extractor import extractor_node
from agent.nodes.schema_explorer import schema_explorer_node
from agent.nodes.schema_explorer import schema_explorer_node, MAX_SCHEMA_RETRIES
from agent.nodes.query_builder import query_builder_node
from agent.nodes.refiner import refiner_node
from agent.nodes.finalizer import finalizer_node
from agent.nodes.satisfaction_check import satisfaction_check_node
from agent.config import settings
from agent.langfuse_client import langfuse_client
from pydantic import BaseModel, Field
Expand All @@ -20,6 +36,71 @@
llm = get_llm("routing")


# ── G2-01: Custom exception ───────────────────────────────────────────────────


class InvalidConfigurationException(ValueError):
    """Raised when agent state contains an invalid or unsafe configuration.

    Subclasses ``ValueError``; raised by ``validate_config_node`` before any
    other node runs.
    """


# ── G2-01: Config validator node ──────────────────────────────────────────────

# Scoping modes the agent understands. Mirrors the Literal declared on
# AgentSettings.TABLE_SCOPING_MODE so request-level overrides are held to the
# same set of values as the env default.
_VALID_SCOPING_MODES = ("strict", "hybrid")


def validate_config_node(state: AgentState) -> dict:
    """
    First node after START. Resolves scoping_mode from state (or falls back
    to the env default) and enforces strict-mode preconditions.

    Args:
        state: Current agent state. Reads ``scoping_mode`` and, in strict
            mode, ``allowed_tables``.

    Returns:
        A state update containing the resolved ``scoping_mode``.

    Raises:
        InvalidConfigurationException: if the resolved scoping_mode is not a
            recognised mode, or if scoping_mode='strict' and allowed_tables
            is null or empty.
    """
    # A falsy value in state (missing key, None, "") means "use the default".
    mode = state.get("scoping_mode") or settings.TABLE_SCOPING_MODE

    # Reject unknown modes up front: a misspelled "strict" must not silently
    # bypass the allowed_tables guard below and run unrestricted.
    if mode not in _VALID_SCOPING_MODES:
        raise InvalidConfigurationException(
            f"Unknown scoping_mode={mode!r}; expected one of {_VALID_SCOPING_MODES}."
        )

    if mode == "strict" and not state.get("allowed_tables"):
        raise InvalidConfigurationException(
            "scoping_mode='strict' requires allowed_tables to be a non-empty list. "
            "Execution aborted to prevent unrestricted table access."
        )

    return {"scoping_mode": mode}


# ── G2-02: HITL escalation node ───────────────────────────────────────────────


def hitl_escalation_node(state: AgentState) -> dict:
"""
Execution pauses HERE via LangGraph interrupt_before before this node runs.
The human then calls graph.update_state() to inject corrected state and
clears sql_query / last_error / trino_error / escalated / escalation_reason.
Comment on lines +77 to +78

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure that's what happens?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I'm not sure clearing the state is the right call.

After update_state the graph resumes from this node, which immediately
routes to extractor via its direct edge.

This node body only performs observability work — it does NOT call interrupt()
itself (interrupt_before handles the pause at compile time).
"""
reason = state.get("escalation_reason", "Maximum retries exhausted.")

try:
trace_id = langfuse_client.get_current_trace_id()
if trace_id:
langfuse_client.trace(
id=trace_id,
tags=["escalated=True"],
metadata={"escalation_reason": reason},
)
except Exception:
pass

return {"escalated": True}


# ── Rejection router ──────────────────────────────────────────────────────────


class RejectionRoute(BaseModel):
route: Literal["extractor", "schema_explorer", "query_builder"] = Field(
description="The phase to route the execution back to based on the user feedback."
Expand Down Expand Up @@ -55,60 +136,92 @@ def rejection_router_node(state: AgentState):
}


def route_refiner(state: AgentState):
if (
state.get("trino_error")
and state.get("refinement_count", 0) < MAX_REFINER_ITERATIONS
):
return "refiner"
return "finalizer"
# ── Conditional edge functions ────────────────────────────────────────────────


def route_schema_explorer(state: AgentState) -> str:
    """G2-02: pick the next node after schema_explorer.

    Returns "query_builder" when no hallucinated tables are recorded.
    Otherwise returns "hitl_escalation" once schema_explorer_retry_count has
    reached MAX_SCHEMA_RETRIES, and "schema_explorer" (retry) before that.
    """
    if not state.get("hallucinated_tables"):
        return "query_builder"

    # A missing or None retry counter counts as zero retries so far.
    retries_used = state.get("schema_explorer_retry_count") or 0
    return (
        "hitl_escalation"
        if retries_used >= MAX_SCHEMA_RETRIES
        else "schema_explorer"
    )

def route_query_builder(state: AgentState):

def route_refiner(state: AgentState) -> str:
    """G2-02: pick the next node after refiner.

    Returns "hitl_escalation" only when a trino_error is present and
    refinement_count is no longer below MAX_REFINER_ITERATIONS; every other
    case goes to "satisfaction_check".
    """
    if not state.get("trino_error"):
        return "satisfaction_check"

    attempts = state.get("refinement_count", 0)
    if attempts < MAX_REFINER_ITERATIONS:
        # Run the check even on the error path so Check A can flag it.
        return "satisfaction_check"
    return "hitl_escalation"


def route_satisfaction(state: AgentState) -> str:
    """
    G2-04: pick the next node after satisfaction_check.
    - satisfaction_failures missing or empty          → finalizer
    - failures, fail count below the configured max   → refiner
    - failures, fail count at or above the max        → hitl_escalation
    """
    if not state.get("satisfaction_failures"):
        return "finalizer"

    # A missing or None counter counts as zero failures so far.
    failed_so_far = state.get("satisfaction_fail_count") or 0
    limit_reached = failed_so_far >= settings.SATISFACTION_MAX_FAILURES
    return "hitl_escalation" if limit_reached else "refiner"


def route_query_builder(state: AgentState) -> str:
    """Send the run to rejection_router when feedback is present, else to refiner."""
    return "rejection_router" if state.get("feedback") else "refiner"


def route_rejection(state: AgentState):
def route_rejection(state: AgentState) -> str:
    """Return the stored feedback_route when it names a known phase.

    Anything else (missing, None, or an unrecognised value) falls back to
    "extractor".
    """
    target = state.get("feedback_route")
    known_phases = ("extractor", "schema_explorer", "query_builder")
    return target if target in known_phases else "extractor"


# ── Build graph ───────────────────────────────────────────────────────────────

# Graph builder typed on AgentState; nodes and edges are attached to it below.
workflow = StateGraph(AgentState)

# Register every node under a string name. These names are the same strings
# the route_* functions return and the edge definitions refer to.
workflow.add_node("validate_config", validate_config_node)  # G2-01 fail-fast check
workflow.add_node("extractor", extractor_node)
workflow.add_node("schema_explorer", schema_explorer_node)
workflow.add_node("query_builder", query_builder_node)
workflow.add_node("rejection_router", rejection_router_node)
workflow.add_node("refiner", refiner_node)
workflow.add_node("satisfaction_check", satisfaction_check_node)  # G2-04 gate
workflow.add_node("hitl_escalation", hitl_escalation_node)  # G2-02 escalation
workflow.add_node("finalizer", finalizer_node)

workflow.add_edge(START, "extractor")
# Entry: validate config before anything else (G2-01 fail-fast)
workflow.add_edge(START, "validate_config")
workflow.add_edge("validate_config", "extractor")
workflow.add_edge("extractor", "schema_explorer")


def route_schema_explorer(state: AgentState):
if state.get("hallucinated_tables"):
if state.get("schema_explorer_retry_count", 0) >= 3:
return "query_builder"
return "schema_explorer"
return "query_builder"


workflow.add_conditional_edges(
"schema_explorer",
route_schema_explorer,
{"schema_explorer": "schema_explorer", "query_builder": "query_builder"},
{
"schema_explorer": "schema_explorer",
"query_builder": "query_builder",
"hitl_escalation": "hitl_escalation", # G2-02
},
)

workflow.add_conditional_edges(
"query_builder",
route_query_builder,
{"rejection_router": "rejection_router", "refiner": "refiner"},
)

workflow.add_conditional_edges(
"rejection_router",
route_rejection,
Expand All @@ -120,9 +233,31 @@ def route_schema_explorer(state: AgentState):
)

workflow.add_conditional_edges(
"refiner", route_refiner, {"refiner": "refiner", "finalizer": "finalizer"}
"refiner",
route_refiner,
{
"satisfaction_check": "satisfaction_check", # G2-04 replaces direct → finalizer
"hitl_escalation": "hitl_escalation", # G2-02
},
)

# G2-04: satisfaction gate
workflow.add_conditional_edges(
"satisfaction_check",
route_satisfaction,
{
"finalizer": "finalizer",
"refiner": "refiner",
"hitl_escalation": "hitl_escalation",
},
)

# G2-02: HITL resume path → restart from extractor (full state reset by human)
workflow.add_edge("hitl_escalation", "extractor")
workflow.add_edge("finalizer", END)

memory = MemorySaver()
agent_graph = workflow.compile(checkpointer=memory)
agent_graph = workflow.compile(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be really nice to add a diagram of the graph to the README (so I can also see visually what it looks like).

checkpointer=memory,
interrupt_before=["hitl_escalation"], # G2-02: pause before HITL node
)
4 changes: 4 additions & 0 deletions agent/src/agent/nodes/refiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ async def refiner_node(state: AgentState):
"last_error": trino_error,
"refinement_count": count + 1,
"error_history": error_history,
"escalation_reason": (
f"Refiner exhausted {MAX_REFINER_ITERATIONS} iterations. "
f"Last Trino error: {trino_error}"
),
}

langfuse_prompt = langfuse_client.get_prompt(settings.LANGFUSE_PROMPT_REFINER)
Expand Down
Loading