-
Notifications
You must be signed in to change notification settings - Fork 0
feat: implement feature flags & execution modes builder #18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,16 +1,67 @@ | ||||||||||||||||||||||||||||||||||||||
| import logging | ||||||||||||||||||||||||||||||||||||||
| from typing import Optional | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| from langchain_openai import ChatOpenAI | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| from agent.config import settings | ||||||||||||||||||||||||||||||||||||||
| from typing import Optional | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| logger = logging.getLogger(__name__) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Flag name → LLM_MODEL env-var fallback for each node | ||||||||||||||||||||||||||||||||||||||
| _NODE_MODEL_FLAGS: dict[str, str] = { | ||||||||||||||||||||||||||||||||||||||
| "extractor": "EXTRACTOR_MODEL", | ||||||||||||||||||||||||||||||||||||||
| "schema_explorer": "SCHEMA_SUMMARY_MODEL", | ||||||||||||||||||||||||||||||||||||||
| "query_builder": "QUERY_BUILDER_MODEL", | ||||||||||||||||||||||||||||||||||||||
| "refiner": "REFINER_MODEL", | ||||||||||||||||||||||||||||||||||||||
| "satisfaction_check": "SATISFACTION_JUDGE_MODEL", | ||||||||||||||||||||||||||||||||||||||
| "routing": "QUERY_BUILDER_MODEL", # rejection router reuses QB model | ||||||||||||||||||||||||||||||||||||||
| "default": "QUERY_BUILDER_MODEL", | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| _NODE_TEMP_FLAGS: dict[str, str] = { | ||||||||||||||||||||||||||||||||||||||
| "extractor": "EXTRACTOR_TEMPERATURE", | ||||||||||||||||||||||||||||||||||||||
| "query_builder": "QUERY_BUILDER_TEMPERATURE", | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+11
to
+24
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why are we not using the same name? this might be confusing |
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def get_llm( | ||||||||||||||||||||||||||||||||||||||
| node: str = "default", | ||||||||||||||||||||||||||||||||||||||
| temperature: Optional[float] = None, | ||||||||||||||||||||||||||||||||||||||
| runtime_flags: Optional[dict] = None, | ||||||||||||||||||||||||||||||||||||||
| ) -> ChatOpenAI: | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| Factory for per-node LLM instances. | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Priority for model/temperature selection: | ||||||||||||||||||||||||||||||||||||||
| 1. runtime_flags (resolved by init_flags_node from DB + execution mode) | ||||||||||||||||||||||||||||||||||||||
| 2. AgentSettings env-var defaults | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||
| node: Name of the calling graph node (used to pick the right flag). | ||||||||||||||||||||||||||||||||||||||
| temperature: Optional hard override — bypasses flag resolution. | ||||||||||||||||||||||||||||||||||||||
| runtime_flags: The state["runtime_flags"] dict from the current invocation. | ||||||||||||||||||||||||||||||||||||||
| Pass None when initialising at module level (will use env defaults). | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| flags = runtime_flags or {} | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Resolve model | ||||||||||||||||||||||||||||||||||||||
| model_flag = _NODE_MODEL_FLAGS.get(node, "QUERY_BUILDER_MODEL") | ||||||||||||||||||||||||||||||||||||||
| model = flags.get(model_flag) or settings.LLM_MODEL | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # Resolve temperature | ||||||||||||||||||||||||||||||||||||||
| if temperature is None: | ||||||||||||||||||||||||||||||||||||||
| temp_flag = _NODE_TEMP_FLAGS.get(node) | ||||||||||||||||||||||||||||||||||||||
| temperature = float(flags.get(temp_flag, 0.0)) if temp_flag else 0.0 | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+52
to
+55
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handle malformed temperature flag values defensively. Line 54 casts a dynamic flag directly with Suggested fix if temperature is None:
temp_flag = _NODE_TEMP_FLAGS.get(node)
- temperature = float(flags.get(temp_flag, 0.0)) if temp_flag else 0.0
+ if temp_flag:
+ raw_temp = flags.get(temp_flag, 0.0)
+ try:
+ temperature = float(raw_temp)
+ except (TypeError, ValueError):
+ logger.warning(
+ "Invalid temperature override for %s: %r; falling back to 0.0",
+ temp_flag,
+ raw_temp,
+ )
+ temperature = 0.0
+ else:
+ temperature = 0.0📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||
| logger.debug( | ||||||||||||||||||||||||||||||||||||||
| "Instantiating LLM for node='%s': model='%s' temperature=%.2f", | ||||||||||||||||||||||||||||||||||||||
| node, | ||||||||||||||||||||||||||||||||||||||
| model, | ||||||||||||||||||||||||||||||||||||||
| temperature, | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| # We could dynamically configure settings based on the node name. | ||||||||||||||||||||||||||||||||||||||
| # For now, it delegates to agent.config settings. | ||||||||||||||||||||||||||||||||||||||
| def get_llm(node: str = "default", temperature: Optional[float] = 0.0) -> ChatOpenAI: | ||||||||||||||||||||||||||||||||||||||
| """Factory function for instantiating the unified LLM.""" | ||||||||||||||||||||||||||||||||||||||
| logging.debug(f"Instantiating LLM for node: {node}") | ||||||||||||||||||||||||||||||||||||||
| return ChatOpenAI( | ||||||||||||||||||||||||||||||||||||||
| model=settings.LLM_MODEL, | ||||||||||||||||||||||||||||||||||||||
| model=model, | ||||||||||||||||||||||||||||||||||||||
| base_url=settings.LLM_BASE_URL, | ||||||||||||||||||||||||||||||||||||||
| api_key=settings.LLM_API_KEY, | ||||||||||||||||||||||||||||||||||||||
| temperature=temperature, | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,8 +45,8 @@ def extract(self, query: str) -> List[ContextEntry]: | |
|
|
||
|
|
||
| class LLMExtractor(BaseExtractor): | ||
| def __init__(self): | ||
| self.llm = get_llm("extractor") | ||
| def __init__(self, runtime_flags: dict | None = None): | ||
| self.llm = get_llm("extractor", runtime_flags=runtime_flags) | ||
|
|
||
| langfuse_prompt = langfuse_client.get_prompt(settings.LANGFUSE_PROMPT_EXTRACTOR) | ||
| self.prompt = ChatPromptTemplate.from_messages( | ||
|
|
@@ -102,10 +102,11 @@ def extractor_node(state: AgentState): | |
| """Enrich the user query with additional context to help downstream phases.""" | ||
| user_query = state["user_query"] | ||
| active_extractors = state.get("active_extractors") or [] | ||
| runtime_flags = state.get("runtime_flags") or {} | ||
|
|
||
| import concurrent.futures | ||
|
|
||
| extractors: List[BaseExtractor] = [TimeExtractor(), LLMExtractor()] | ||
| extractors: List[BaseExtractor] = [TimeExtractor(), LLMExtractor(runtime_flags=runtime_flags)] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why pass the runtime flags only to one extractor? the time extractor and the http extractors don't need it? |
||
|
|
||
| for ext_info in active_extractors: | ||
| extractors.append(HTTPExtractor(ext_info["url"], ext_info["name"])) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| """ | ||
| init_flags_node (G4) | ||
| ==================== | ||
| Runs immediately after validate_config, before init_skills. | ||
|
|
||
| Responsibilities: | ||
| 1. Call FlagBridge.resolve_flags(execution_mode) to merge: | ||
| mode overrides → DB flags → env-var defaults | ||
| 2. Write the resolved dict to state["runtime_flags"] | ||
| 3. Log runtime_flags to Langfuse trace metadata for full observability | ||
| """ | ||
|
|
||
| import logging | ||
|
|
||
| from agent.langfuse_client import langfuse_client | ||
| from agent.state import AgentState | ||
| from agent.utils.flag_bridge import FlagBridge | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| _flag_bridge = FlagBridge() | ||
|
|
||
|
|
||
| async def init_flags_node(state: AgentState) -> dict: | ||
| """ | ||
| Resolve all runtime configuration flags for this invocation. | ||
|
|
||
| The resolved dict is stored in state["runtime_flags"] and read by every | ||
| downstream node instead of directly accessing AgentSettings env vars. | ||
| This guarantees that: | ||
| - DS team changes in the Studio UI take effect within the cache TTL (30s). | ||
| - Execution mode overrides are applied consistently to all nodes. | ||
| - Every Langfuse trace carries the exact config used for that query. | ||
| """ | ||
| execution_mode: str | None = state.get("execution_mode") | ||
|
|
||
| try: | ||
| runtime_flags = await _flag_bridge.resolve_flags(execution_mode) | ||
| except Exception as exc: | ||
| logger.warning("init_flags_node: FlagBridge failed (%s), using env defaults", exc) | ||
| # FlagBridge already handles its own fallback internally, so this is a | ||
| # safety net for any unexpected error in the bridge itself. | ||
| from agent.utils.flag_bridge import _ENV_DEFAULTS | ||
| runtime_flags = dict(_ENV_DEFAULTS) | ||
|
|
||
| # Emit to Langfuse for observability | ||
| try: | ||
| trace_id = langfuse_client.get_current_trace_id() | ||
| if trace_id: | ||
| langfuse_client.trace( | ||
| id=trace_id, | ||
| metadata={ | ||
| "runtime_flags": runtime_flags, | ||
| "execution_mode": execution_mode or "default", | ||
| }, | ||
| ) | ||
| except Exception as exc: | ||
| logger.warning("init_flags_node: Langfuse trace failed: %s", exc) | ||
|
|
||
| logger.info( | ||
| "init_flags_node: resolved %d flags (mode=%s)", | ||
| len(runtime_flags), | ||
| execution_mode or "default", | ||
| ) | ||
|
|
||
| return {"runtime_flags": runtime_flags} |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,13 +14,23 @@ async def init_skills_node(state: AgentState) -> dict: | |
| keeping reasoning nodes pure and state reproducible. | ||
| """ | ||
| active_skills = state.get("active_skills") | ||
| runtime_flags = state.get("runtime_flags") or {} | ||
|
|
||
| if not active_skills: | ||
| from agent.config import settings | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. import to top of file |
||
| skills_enabled = bool(runtime_flags.get("SKILLS_ENABLED", settings.SKILLS_ENABLED)) | ||
| hot_reload = bool(runtime_flags.get("SKILLS_HOT_RELOAD", settings.SKILLS_HOT_RELOAD)) | ||
| cache_ttl = int(runtime_flags.get("SKILLS_CACHE_TTL", settings.SKILLS_CACHE_TTL)) | ||
|
|
||
|
Comment on lines
+20
to
+23
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Harden runtime-flag parsing for bool/int values. Direct Suggested fix+ def _to_bool(v, default: bool) -> bool:
+ if isinstance(v, bool):
+ return v
+ if isinstance(v, str):
+ s = v.strip().lower()
+ if s in {"1", "true", "yes", "on"}:
+ return True
+ if s in {"0", "false", "no", "off"}:
+ return False
+ return default
+
+ def _to_int(v, default: int) -> int:
+ try:
+ return int(v)
+ except (TypeError, ValueError):
+ return default
+
- skills_enabled = bool(runtime_flags.get("SKILLS_ENABLED", settings.SKILLS_ENABLED))
- hot_reload = bool(runtime_flags.get("SKILLS_HOT_RELOAD", settings.SKILLS_HOT_RELOAD))
- cache_ttl = int(runtime_flags.get("SKILLS_CACHE_TTL", settings.SKILLS_CACHE_TTL))
+ skills_enabled = _to_bool(runtime_flags.get("SKILLS_ENABLED"), settings.SKILLS_ENABLED)
+ hot_reload = _to_bool(runtime_flags.get("SKILLS_HOT_RELOAD"), settings.SKILLS_HOT_RELOAD)
+ cache_ttl = _to_int(runtime_flags.get("SKILLS_CACHE_TTL"), settings.SKILLS_CACHE_TTL)🤖 Prompt for AI Agents |
||
| if not skills_enabled or not active_skills: | ||
| return {"loaded_skills": None} | ||
|
|
||
| try: | ||
| _skill_registry.redis = get_redis_client() | ||
| loaded_skills = await _skill_registry.get_skills(active_skills) | ||
| loaded_skills = await _skill_registry.get_skills( | ||
| active_skills, | ||
| hot_reload=hot_reload, | ||
| cache_ttl=cache_ttl, | ||
| ) | ||
|
|
||
| if loaded_skills: | ||
| try: | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -6,7 +6,6 @@ | |||||||||||
| from core import execute_query_sync | ||||||||||||
| from agent.config import settings | ||||||||||||
| from agent.langfuse_client import langfuse_client | ||||||||||||
| from agent.langfuse_client import langfuse_client | ||||||||||||
| from langchain_core.prompts import ChatPromptTemplate | ||||||||||||
| from agent.llm import get_llm | ||||||||||||
| from agent.utils.sql import clean_sql | ||||||||||||
|
|
@@ -33,6 +32,10 @@ async def refiner_node(state: AgentState): | |||||||||||
| sql = state.get("sql_query") | ||||||||||||
| count = state.get("refinement_count", 0) | ||||||||||||
| error_history = state.get("error_history") or [] | ||||||||||||
| runtime_flags = state.get("runtime_flags") or {} | ||||||||||||
|
|
||||||||||||
| # Resolve per-invocation limit (DS-tunable via flags) | ||||||||||||
| max_iterations = int(runtime_flags.get("MAX_REFINER_ITERATIONS", MAX_REFINER_ITERATIONS)) | ||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Validate A malformed/nullable runtime flag will raise here and fail the node before refinement logic runs. Suggested fix- max_iterations = int(runtime_flags.get("MAX_REFINER_ITERATIONS", MAX_REFINER_ITERATIONS))
+ try:
+ max_iterations = int(runtime_flags.get("MAX_REFINER_ITERATIONS", MAX_REFINER_ITERATIONS))
+ except (TypeError, ValueError):
+ max_iterations = MAX_REFINER_ITERATIONS📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||
|
|
||||||||||||
| # Execute against Trino | ||||||||||||
| try: | ||||||||||||
|
|
@@ -49,14 +52,14 @@ async def refiner_node(state: AgentState): | |||||||||||
|
|
||||||||||||
| if not success: | ||||||||||||
| # If we reached the refinement limit, just stop and don't prompt LLM | ||||||||||||
| if count >= MAX_REFINER_ITERATIONS: | ||||||||||||
| if count >= max_iterations: | ||||||||||||
| return { | ||||||||||||
| "trino_error": trino_error, | ||||||||||||
| "last_error": trino_error, | ||||||||||||
| "refinement_count": count + 1, | ||||||||||||
| "error_history": error_history, | ||||||||||||
|
Comment on lines
+55
to
+60
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Refinement stop condition is off by one. With Suggested fix- if count >= max_iterations:
+ if count + 1 >= max_iterations:
return {🤖 Prompt for AI Agents |
||||||||||||
| "escalation_reason": ( | ||||||||||||
| f"Refiner exhausted {MAX_REFINER_ITERATIONS} iterations. " | ||||||||||||
| f"Refiner exhausted {max_iterations} iterations. " | ||||||||||||
| f"Last Trino error: {trino_error}" | ||||||||||||
| ), | ||||||||||||
| } | ||||||||||||
|
|
@@ -65,7 +68,8 @@ async def refiner_node(state: AgentState): | |||||||||||
| prompt = ChatPromptTemplate.from_messages( | ||||||||||||
| langfuse_prompt.get_langchain_prompt() | ||||||||||||
| ) | ||||||||||||
| chain = prompt | llm | ||||||||||||
| _llm = get_llm("refiner", runtime_flags=runtime_flags) | ||||||||||||
| chain = prompt | _llm | ||||||||||||
|
|
||||||||||||
| schema_context = build_refiner_schema_context(state) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why add a default empty string? this will disable verification. please make this non optional