Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added agent/.coverage
Binary file not shown.
4 changes: 4 additions & 0 deletions agent/src/agent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ class AgentSettings(BaseSettings):
JEEN_API_KEY: str = "" # If empty, agent gracefully skips fetching
SKILLS_HOT_RELOAD: bool = False # If true, bypass Redis cache for skills

# ── G4: Feature Flags & Execution Modes ──────────────────────────────────
BACKEND_URL: str = "" # Studio backend URL for flag reads (e.g. http://backend:8000)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why add a default empty string? this will disable verification. please make this non optional

# If empty, FlagBridge falls back to env-var defaults


# Langfuse prompt names
LANGFUSE_PROMPT_EXTRACTOR: str = "text2sql/extractor"
Expand Down
7 changes: 5 additions & 2 deletions agent/src/agent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from langchain_core.prompts import ChatPromptTemplate
from agent.state import AgentState
from agent.nodes.extractor import extractor_node
from agent.nodes.init_flags import init_flags_node
from agent.nodes.init_skills import init_skills_node
from agent.nodes.schema_explorer import schema_explorer_node, MAX_SCHEMA_RETRIES
from agent.nodes.query_builder import query_builder_node
Expand Down Expand Up @@ -193,6 +194,7 @@ def route_rejection(state: AgentState) -> str:
workflow = StateGraph(AgentState)

workflow.add_node("validate_config", validate_config_node)
workflow.add_node("init_flags", init_flags_node)
workflow.add_node("init_skills", init_skills_node)
workflow.add_node("extractor", extractor_node)
workflow.add_node("schema_explorer", schema_explorer_node)
Expand All @@ -203,9 +205,10 @@ def route_rejection(state: AgentState) -> str:
workflow.add_node("hitl_escalation", hitl_escalation_node)
workflow.add_node("finalizer", finalizer_node)

# Entry: validate config before anything else (G2-01 fail-fast)
# Entry: validate config → resolve flags → load skills → start reasoning
workflow.add_edge(START, "validate_config")
workflow.add_edge("validate_config", "init_skills")
workflow.add_edge("validate_config", "init_flags")
workflow.add_edge("init_flags", "init_skills")
workflow.add_edge("init_skills", "extractor")
workflow.add_edge("extractor", "schema_explorer")

Expand Down
65 changes: 58 additions & 7 deletions agent/src/agent/llm.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,67 @@
import logging
from typing import Optional

from langchain_openai import ChatOpenAI

from agent.config import settings
from typing import Optional

logger = logging.getLogger(__name__)

# Flag name → LLM_MODEL env-var fallback for each node
_NODE_MODEL_FLAGS: dict[str, str] = {
"extractor": "EXTRACTOR_MODEL",
"schema_explorer": "SCHEMA_SUMMARY_MODEL",
"query_builder": "QUERY_BUILDER_MODEL",
"refiner": "REFINER_MODEL",
"satisfaction_check": "SATISFACTION_JUDGE_MODEL",
"routing": "QUERY_BUILDER_MODEL", # rejection router reuses QB model
"default": "QUERY_BUILDER_MODEL",
}

_NODE_TEMP_FLAGS: dict[str, str] = {
"extractor": "EXTRACTOR_TEMPERATURE",
"query_builder": "QUERY_BUILDER_TEMPERATURE",
}
Comment on lines +11 to +24

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we not using the same name? this might be confusing



def get_llm(
node: str = "default",
temperature: Optional[float] = None,
runtime_flags: Optional[dict] = None,
) -> ChatOpenAI:
"""
Factory for per-node LLM instances.

Priority for model/temperature selection:
1. runtime_flags (resolved by init_flags_node from DB + execution mode)
2. AgentSettings env-var defaults

Args:
node: Name of the calling graph node (used to pick the right flag).
temperature: Optional hard override — bypasses flag resolution.
runtime_flags: The state["runtime_flags"] dict from the current invocation.
Pass None when initialising at module level (will use env defaults).
"""
flags = runtime_flags or {}

# Resolve model
model_flag = _NODE_MODEL_FLAGS.get(node, "QUERY_BUILDER_MODEL")
model = flags.get(model_flag) or settings.LLM_MODEL

# Resolve temperature
if temperature is None:
temp_flag = _NODE_TEMP_FLAGS.get(node)
temperature = float(flags.get(temp_flag, 0.0)) if temp_flag else 0.0

Comment on lines +52 to +55

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Handle malformed temperature flag values defensively.

Line 54 casts a dynamic flag directly with float(...). A bad DB/admin override (e.g., empty string or non-numeric text) will raise and break request execution.

Suggested fix
     if temperature is None:
         temp_flag = _NODE_TEMP_FLAGS.get(node)
-        temperature = float(flags.get(temp_flag, 0.0)) if temp_flag else 0.0
+        if temp_flag:
+            raw_temp = flags.get(temp_flag, 0.0)
+            try:
+                temperature = float(raw_temp)
+            except (TypeError, ValueError):
+                logger.warning(
+                    "Invalid temperature override for %s: %r; falling back to 0.0",
+                    temp_flag,
+                    raw_temp,
+                )
+                temperature = 0.0
+        else:
+            temperature = 0.0
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if temperature is None:
temp_flag = _NODE_TEMP_FLAGS.get(node)
temperature = float(flags.get(temp_flag, 0.0)) if temp_flag else 0.0
if temperature is None:
temp_flag = _NODE_TEMP_FLAGS.get(node)
if temp_flag:
raw_temp = flags.get(temp_flag, 0.0)
try:
temperature = float(raw_temp)
except (TypeError, ValueError):
logger.warning(
"Invalid temperature override for %s: %r; falling back to 0.0",
temp_flag,
raw_temp,
)
temperature = 0.0
else:
temperature = 0.0
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@agent/src/agent/llm.py` around lines 52 - 55, The temperature flag value is
being directly cast to float without error handling, which will raise a
ValueError if the flag contains a malformed value like an empty string or
non-numeric text. Wrap the float conversion in a try-except block to defensively
handle ValueError exceptions, defaulting to 0.0 when conversion fails. Apply
this error handling to the temperature assignment where
float(flags.get(temp_flag, 0.0)) is called, optionally logging a warning when an
invalid value is encountered.

logger.debug(
"Instantiating LLM for node='%s': model='%s' temperature=%.2f",
node,
model,
temperature,
)

# We could dynamically configure settings based on the node name.
# For now, it delegates to agent.config settings.
def get_llm(node: str = "default", temperature: Optional[float] = 0.0) -> ChatOpenAI:
"""Factory function for instantiating the unified LLM."""
logging.debug(f"Instantiating LLM for node: {node}")
return ChatOpenAI(
model=settings.LLM_MODEL,
model=model,
base_url=settings.LLM_BASE_URL,
api_key=settings.LLM_API_KEY,
temperature=temperature,
Expand Down
18 changes: 17 additions & 1 deletion agent/src/agent/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,24 @@ async def chat_with_agent(
allowed_statuses: list[str] | None = None,
extractors: list[str] | None = None,
active_skills: list[str] | None = None,
execution_mode: str | None = None,
hitl_enabled: bool = True,
) -> str:
"""Run the Text2SQL agent to answer database queries."""
"""Run the Text2SQL agent to answer database queries.

Args:
query: The natural language question to answer.
thread_id: Optional thread ID for session continuity.
resume_value: HITL resume payload (pass after receiving an interrupt).
allowed_tables: Restrict the agent to specific tables.
allowed_statuses: Filter tables by status.
extractors: List of extractor names/IDs to use.
active_skills: List of Jeen skill UUIDs to inject.
execution_mode: Named configuration preset (e.g. 'cost_saving',
'high_quality', 'benchmark'). Overrides flag defaults
for this invocation only.
hitl_enabled: If False, skip all human-in-the-loop interrupts.
"""
thread_id = thread_id or str(uuid.uuid4())
config = {
"configurable": {"thread_id": thread_id},
Expand Down Expand Up @@ -91,6 +106,7 @@ async def chat_with_agent(
"allowed_statuses": allowed_statuses,
"active_extractors": active_extractors,
"active_skills": active_skills,
"execution_mode": execution_mode,
"non_interactive": not hitl_enabled,
},
config=config,
Expand Down
7 changes: 4 additions & 3 deletions agent/src/agent/nodes/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def extract(self, query: str) -> List[ContextEntry]:


class LLMExtractor(BaseExtractor):
def __init__(self):
self.llm = get_llm("extractor")
def __init__(self, runtime_flags: dict | None = None):
self.llm = get_llm("extractor", runtime_flags=runtime_flags)

langfuse_prompt = langfuse_client.get_prompt(settings.LANGFUSE_PROMPT_EXTRACTOR)
self.prompt = ChatPromptTemplate.from_messages(
Expand Down Expand Up @@ -102,10 +102,11 @@ def extractor_node(state: AgentState):
"""Enrich the user query with additional context to help downstream phases."""
user_query = state["user_query"]
active_extractors = state.get("active_extractors") or []
runtime_flags = state.get("runtime_flags") or {}

import concurrent.futures

extractors: List[BaseExtractor] = [TimeExtractor(), LLMExtractor()]
extractors: List[BaseExtractor] = [TimeExtractor(), LLMExtractor(runtime_flags=runtime_flags)]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why pass the runtime flags only to one extractor? the time extractor and the http extractors don't need it?


for ext_info in active_extractors:
extractors.append(HTTPExtractor(ext_info["url"], ext_info["name"]))
Expand Down
66 changes: 66 additions & 0 deletions agent/src/agent/nodes/init_flags.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""
init_flags_node (G4)
====================
Runs immediately after validate_config, before init_skills.

Responsibilities:
1. Call FlagBridge.resolve_flags(execution_mode) to merge:
mode overrides → DB flags → env-var defaults
2. Write the resolved dict to state["runtime_flags"]
3. Log runtime_flags to Langfuse trace metadata for full observability
"""

import logging

from agent.langfuse_client import langfuse_client
from agent.state import AgentState
from agent.utils.flag_bridge import FlagBridge

logger = logging.getLogger(__name__)

_flag_bridge = FlagBridge()


async def init_flags_node(state: AgentState) -> dict:
"""
Resolve all runtime configuration flags for this invocation.

The resolved dict is stored in state["runtime_flags"] and read by every
downstream node instead of directly accessing AgentSettings env vars.
This guarantees that:
- DS team changes in the Studio UI take effect within the cache TTL (30s).
- Execution mode overrides are applied consistently to all nodes.
- Every Langfuse trace carries the exact config used for that query.
"""
execution_mode: str | None = state.get("execution_mode")

try:
runtime_flags = await _flag_bridge.resolve_flags(execution_mode)
except Exception as exc:
logger.warning("init_flags_node: FlagBridge failed (%s), using env defaults", exc)
# FlagBridge already handles its own fallback internally, so this is a
# safety net for any unexpected error in the bridge itself.
from agent.utils.flag_bridge import _ENV_DEFAULTS
runtime_flags = dict(_ENV_DEFAULTS)

# Emit to Langfuse for observability
try:
trace_id = langfuse_client.get_current_trace_id()
if trace_id:
langfuse_client.trace(
id=trace_id,
metadata={
"runtime_flags": runtime_flags,
"execution_mode": execution_mode or "default",
},
)
except Exception as exc:
logger.warning("init_flags_node: Langfuse trace failed: %s", exc)

logger.info(
"init_flags_node: resolved %d flags (mode=%s)",
len(runtime_flags),
execution_mode or "default",
)

return {"runtime_flags": runtime_flags}
14 changes: 12 additions & 2 deletions agent/src/agent/nodes/init_skills.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,23 @@ async def init_skills_node(state: AgentState) -> dict:
keeping reasoning nodes pure and state reproducible.
"""
active_skills = state.get("active_skills")
runtime_flags = state.get("runtime_flags") or {}

if not active_skills:
from agent.config import settings

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import to top of file

skills_enabled = bool(runtime_flags.get("SKILLS_ENABLED", settings.SKILLS_ENABLED))
hot_reload = bool(runtime_flags.get("SKILLS_HOT_RELOAD", settings.SKILLS_HOT_RELOAD))
cache_ttl = int(runtime_flags.get("SKILLS_CACHE_TTL", settings.SKILLS_CACHE_TTL))

Comment on lines +20 to +23

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Harden runtime-flag parsing for bool/int values.

Direct bool(...)/int(...) on dynamic flag payloads can misinterpret values ("false"True) or raise (None, bad strings) before the node’s error handling.

Suggested fix
+    def _to_bool(v, default: bool) -> bool:
+        if isinstance(v, bool):
+            return v
+        if isinstance(v, str):
+            s = v.strip().lower()
+            if s in {"1", "true", "yes", "on"}:
+                return True
+            if s in {"0", "false", "no", "off"}:
+                return False
+        return default
+
+    def _to_int(v, default: int) -> int:
+        try:
+            return int(v)
+        except (TypeError, ValueError):
+            return default
+
-    skills_enabled = bool(runtime_flags.get("SKILLS_ENABLED", settings.SKILLS_ENABLED))
-    hot_reload = bool(runtime_flags.get("SKILLS_HOT_RELOAD", settings.SKILLS_HOT_RELOAD))
-    cache_ttl = int(runtime_flags.get("SKILLS_CACHE_TTL", settings.SKILLS_CACHE_TTL))
+    skills_enabled = _to_bool(runtime_flags.get("SKILLS_ENABLED"), settings.SKILLS_ENABLED)
+    hot_reload = _to_bool(runtime_flags.get("SKILLS_HOT_RELOAD"), settings.SKILLS_HOT_RELOAD)
+    cache_ttl = _to_int(runtime_flags.get("SKILLS_CACHE_TTL"), settings.SKILLS_CACHE_TTL)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@agent/src/agent/nodes/init_skills.py` around lines 20 - 23, The issue is in
how skills_enabled, hot_reload, and cache_ttl values are being parsed from
runtime_flags using direct bool() and int() conversions. Direct bool() on string
values like "false" incorrectly converts to True, and int() can raise exceptions
on None or invalid strings before proper error handling occurs. Fix this by
implementing safe parsing: for boolean flags, explicitly check if the string
value equals "true" or "false" (case-insensitive) rather than using bool()
directly, and for the cache_ttl integer, wrap the int() conversion in a
try-except block to gracefully fall back to the settings default when parsing
fails, ensuring all three flag assignments handle invalid or missing values
safely before they reach the rest of the node logic.

if not skills_enabled or not active_skills:
return {"loaded_skills": None}

try:
_skill_registry.redis = get_redis_client()
loaded_skills = await _skill_registry.get_skills(active_skills)
loaded_skills = await _skill_registry.get_skills(
active_skills,
hot_reload=hot_reload,
cache_ttl=cache_ttl,
)

if loaded_skills:
try:
Expand Down
8 changes: 3 additions & 5 deletions agent/src/agent/nodes/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@
from agent.config import settings
from agent.langfuse_client import langfuse_client
from langgraph.types import interrupt

llm = get_llm("query_builder")


def query_builder_node(state: AgentState):
"""Build SQL from plan and pause for user approval."""
runtime_flags = state.get("runtime_flags") or {}
feedback = state.get("feedback")
feedback_str = f"\nUser Feedback to apply: {feedback}" if feedback else ""

Expand All @@ -23,7 +20,8 @@ def query_builder_node(state: AgentState):

langfuse_prompt = langfuse_client.get_prompt(settings.LANGFUSE_PROMPT_QUERY_BUILDER)
prompt = ChatPromptTemplate.from_messages(langfuse_prompt.get_langchain_prompt())
chain = prompt | llm
_llm = get_llm("query_builder", runtime_flags=runtime_flags)
chain = prompt | _llm
response = chain.invoke(
{
"schema_plan": state.get("schema_plan"),
Expand Down
12 changes: 8 additions & 4 deletions agent/src/agent/nodes/refiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from core import execute_query_sync
from agent.config import settings
from agent.langfuse_client import langfuse_client
from agent.langfuse_client import langfuse_client
from langchain_core.prompts import ChatPromptTemplate
from agent.llm import get_llm
from agent.utils.sql import clean_sql
Expand All @@ -33,6 +32,10 @@ async def refiner_node(state: AgentState):
sql = state.get("sql_query")
count = state.get("refinement_count", 0)
error_history = state.get("error_history") or []
runtime_flags = state.get("runtime_flags") or {}

# Resolve per-invocation limit (DS-tunable via flags)
max_iterations = int(runtime_flags.get("MAX_REFINER_ITERATIONS", MAX_REFINER_ITERATIONS))

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate MAX_REFINER_ITERATIONS before casting.

A malformed/nullable runtime flag will raise here and fail the node before refinement logic runs.

Suggested fix
-    max_iterations = int(runtime_flags.get("MAX_REFINER_ITERATIONS", MAX_REFINER_ITERATIONS))
+    try:
+        max_iterations = int(runtime_flags.get("MAX_REFINER_ITERATIONS", MAX_REFINER_ITERATIONS))
+    except (TypeError, ValueError):
+        max_iterations = MAX_REFINER_ITERATIONS
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
max_iterations = int(runtime_flags.get("MAX_REFINER_ITERATIONS", MAX_REFINER_ITERATIONS))
try:
max_iterations = int(runtime_flags.get("MAX_REFINER_ITERATIONS", MAX_REFINER_ITERATIONS))
except (TypeError, ValueError):
max_iterations = MAX_REFINER_ITERATIONS
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@agent/src/agent/nodes/refiner.py` at line 38, The code directly casts the
result of runtime_flags.get() to an integer without validating the retrieved
value, which will raise an exception if the runtime flag is malformed, nullable,
or not a valid integer. Before the int() cast in the max_iterations assignment,
validate that the retrieved value from
runtime_flags.get("MAX_REFINER_ITERATIONS", MAX_REFINER_ITERATIONS) is not None
and is a valid integer string or numeric value, and provide appropriate fallback
handling (such as using the default MAX_REFINER_ITERATIONS constant) if
validation fails.


# Execute against Trino
try:
Expand All @@ -49,14 +52,14 @@ async def refiner_node(state: AgentState):

if not success:
# If we reached the refinement limit, just stop and don't prompt LLM
if count >= MAX_REFINER_ITERATIONS:
if count >= max_iterations:
return {
"trino_error": trino_error,
"last_error": trino_error,
"refinement_count": count + 1,
"error_history": error_history,
Comment on lines +55 to +60

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Refinement stop condition is off by one.

With count >= max_iterations and returning count + 1, the node reports/executed attempts beyond the configured maximum.

Suggested fix
-        if count >= max_iterations:
+        if count + 1 >= max_iterations:
             return {
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@agent/src/agent/nodes/refiner.py` around lines 55 - 60, The stop condition
for the refiner node has an off-by-one error that causes it to exceed the
configured maximum iterations. In the refiner.py file where the refinement_count
condition is checked, change the comparison operator from `if count >=
max_iterations:` to `if count > max_iterations:` to ensure the node stops at
exactly the maximum number of iterations rather than one iteration beyond it.
This will prevent the refinement_count from exceeding the intended limit.

"escalation_reason": (
f"Refiner exhausted {MAX_REFINER_ITERATIONS} iterations. "
f"Refiner exhausted {max_iterations} iterations. "
f"Last Trino error: {trino_error}"
),
}
Expand All @@ -65,7 +68,8 @@ async def refiner_node(state: AgentState):
prompt = ChatPromptTemplate.from_messages(
langfuse_prompt.get_langchain_prompt()
)
chain = prompt | llm
_llm = get_llm("refiner", runtime_flags=runtime_flags)
chain = prompt | _llm

schema_context = build_refiner_schema_context(state)

Expand Down
Loading