Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions hindsight-api-slim/hindsight_api/api/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,25 @@ class ReflectRequest(BaseModel):
description="Compound tag filter using boolean groups. Groups in the list are AND-ed. "
"Each group is a leaf {tags, match} or compound {and: [...]}, {or: [...]}, {not: ...}.",
)
fact_types: list[Literal["world", "experience", "observation"]] | None = Field(
default=None,
description="Filter which fact types are retrieved during reflect. None means all types (world, experience, observation).",
)
exclude_mental_models: bool = Field(
default=False,
description="If true, exclude all mental models from the reflect loop (skip search_mental_models tool).",
)
exclude_mental_model_ids: list[str] | None = Field(
default=None,
description="Exclude specific mental models by ID from the reflect loop.",
)

@field_validator("fact_types")
@classmethod
def validate_reflect_fact_types(cls, v: list[str] | None) -> list[str] | None:
    """Disallow an explicitly empty fact_types list.

    None is the sentinel for "all fact types" and passes through untouched;
    a non-empty list passes through as-is. Only `[]` is rejected, since it
    would silently filter out every fact type.
    """
    if v is None or v:
        return v
    raise ValueError("fact_types must not be empty. Use null to include all fact types.")

@model_validator(mode="after")
def validate_tags_exclusive(self) -> "ReflectRequest":
Expand Down Expand Up @@ -1435,6 +1454,25 @@ class MentalModelTrigger(BaseModel):
default=False,
description="If true, refresh this mental model after observations consolidation (real-time mode)",
)
fact_types: list[Literal["world", "experience", "observation"]] | None = Field(
default=None,
description="Filter which fact types are retrieved during reflect. None means all types (world, experience, observation).",
)
exclude_mental_models: bool = Field(
default=False,
description="If true, exclude all mental models from the reflect loop (skip search_mental_models tool).",
)
exclude_mental_model_ids: list[str] | None = Field(
default=None,
description="Exclude specific mental models by ID from the reflect loop.",
)

@field_validator("fact_types")
@classmethod
def validate_fact_types(cls, v: list[str] | None) -> list[str] | None:
    """Disallow an explicitly empty fact_types list.

    None is the sentinel for "all fact types" and passes through untouched;
    a non-empty list passes through as-is. Only `[]` is rejected, since it
    would silently filter out every fact type.
    """
    if v is None or v:
        return v
    raise ValueError("fact_types must not be empty. Use null to include all fact types.")


class MentalModelResponse(BaseModel):
Expand Down Expand Up @@ -2505,6 +2543,9 @@ async def api_reflect(
tags=request.tags,
tags_match=request.tags_match,
tag_groups=request.tag_groups,
fact_types=request.fact_types,
exclude_mental_models=request.exclude_mental_models,
exclude_mental_model_ids=request.exclude_mental_model_ids,
)

# Build based_on (memories + mental_models + directives) if facts are requested
Expand Down
51 changes: 40 additions & 11 deletions hindsight-api-slim/hindsight_api/engine/memory_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,14 +868,23 @@ async def _handle_refresh_mental_model(self, task_dict: dict[str, Any]):
tags = mental_model.get("tags")
tags_match = "all_strict" if tags else "any"

# Read reflect options from trigger (if stored)
trigger_data = mental_model.get("trigger") or {}
fact_types = trigger_data.get("fact_types")
exclude_mental_models = trigger_data.get("exclude_mental_models", False)
stored_exclude_ids: list[str] = trigger_data.get("exclude_mental_model_ids") or []

# Run reflect to generate new content, excluding the mental model being refreshed
# Always add self to excluded IDs to prevent circular reference
reflect_result = await self.reflect_async(
bank_id=bank_id,
query=source_query,
request_context=internal_context,
tags=tags,
tags_match=tags_match,
exclude_mental_model_ids=[mental_model_id],
fact_types=fact_types,
exclude_mental_models=exclude_mental_models,
exclude_mental_model_ids=list({*stored_exclude_ids, mental_model_id}),
)

generated_content = reflect_result.text or "No content generated"
Expand Down Expand Up @@ -5113,6 +5122,8 @@ async def reflect_async(
tags_match: TagsMatch = "any",
tag_groups: list[TagGroup] | None = None,
exclude_mental_model_ids: list[str] | None = None,
fact_types: list[str] | None = None,
exclude_mental_models: bool = False,
_skip_span: bool = False,
) -> ReflectResult:
"""
Expand Down Expand Up @@ -5233,6 +5244,11 @@ async def search_observations_fn(q: str, max_tokens: int = 5000) -> dict[str, An
pending_consolidation=pending_consolidation,
)

# Determine which tools to enable based on fact_types and exclude_mental_models
include_observations = fact_types is None or "observation" in fact_types
recall_fact_types = [ft for ft in (fact_types or ["world", "experience"]) if ft in ("world", "experience")]
include_recall = bool(recall_fact_types)

async def recall_fn(q: str, max_tokens: int = 4096, max_chunk_tokens: int = 1000) -> dict[str, Any]:
return await tool_recall(
self,
Expand All @@ -5244,6 +5260,7 @@ async def recall_fn(q: str, max_tokens: int = 4096, max_chunk_tokens: int = 1000
tags_match=tags_match,
tag_groups=tag_groups,
max_chunk_tokens=max_chunk_tokens,
fact_types=recall_fact_types if fact_types is not None else None,
)

async def expand_fn(memory_ids: list[str], depth: str) -> dict[str, Any]:
Expand All @@ -5266,15 +5283,17 @@ async def expand_fn(memory_ids: list[str], depth: str) -> dict[str, Any]:
if directives:
logger.info(f"[REFLECT {reflect_id}] Loaded {len(directives)} directives")

# Check if the bank has any mental models
async with pool.acquire() as conn:
mental_model_count = await conn.fetchval(
f"SELECT COUNT(*) FROM {fq_table('mental_models')} WHERE bank_id = $1",
bank_id,
)
has_mental_models = mental_model_count > 0
if has_mental_models:
logger.info(f"[REFLECT {reflect_id}] Bank has {mental_model_count} mental models")
# Check if the bank has any mental models (skip check if all mental models are excluded)
has_mental_models = False
if not exclude_mental_models:
async with pool.acquire() as conn:
mental_model_count = await conn.fetchval(
f"SELECT COUNT(*) FROM {fq_table('mental_models')} WHERE bank_id = $1",
bank_id,
)
has_mental_models = mental_model_count > 0
if has_mental_models:
logger.info(f"[REFLECT {reflect_id}] Bank has {mental_model_count} mental models")

# Run the agent with parent span for reflect operation (skip if called from another operation)
if not _skip_span:
Expand All @@ -5299,6 +5318,8 @@ async def expand_fn(memory_ids: list[str], depth: str) -> dict[str, Any]:
response_schema=response_schema,
directives=directives,
has_mental_models=has_mental_models,
include_observations=include_observations,
include_recall=include_recall,
budget=effective_budget,
max_context_tokens=max_context_tokens,
)
Expand Down Expand Up @@ -6430,6 +6451,12 @@ async def refresh_mental_model(
tags = mental_model.get("tags")
tags_match = "all_strict" if tags else "any"

# Read reflect options from trigger (if stored)
trigger_data = mental_model.get("trigger") or {}
fact_types = trigger_data.get("fact_types")
exclude_mental_models = trigger_data.get("exclude_mental_models", False)
stored_exclude_ids: list[str] = trigger_data.get("exclude_mental_model_ids") or []

# Run reflect with the source query, excluding the mental model being refreshed
# Skip creating a nested "hindsight.reflect" span since we already have "hindsight.mental_model_refresh"
reflect_result = await self.reflect_async(
Expand All @@ -6438,7 +6465,9 @@ async def refresh_mental_model(
request_context=request_context,
tags=tags,
tags_match=tags_match,
exclude_mental_model_ids=[mental_model_id],
fact_types=fact_types,
exclude_mental_models=exclude_mental_models,
exclude_mental_model_ids=list({*stored_exclude_ids, mental_model_id}),
_skip_span=True,
)

Expand Down
73 changes: 58 additions & 15 deletions hindsight-api-slim/hindsight_api/engine/reflect/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,8 @@ async def run_reflect_agent(
response_schema: dict | None = None,
directives: list[dict[str, Any]] | None = None,
has_mental_models: bool = False,
include_observations: bool = True,
include_recall: bool = True,
budget: str | None = None,
max_context_tokens: int = 100_000,
) -> ReflectAgentResult:
Expand Down Expand Up @@ -355,7 +357,14 @@ async def run_reflect_agent(
directive_rules = _extract_directive_rules(directives) if directives else None

# Get tools for this agent (with directive compliance field if directives exist)
tools = get_reflect_tools(directive_rules=directive_rules)
tools = get_reflect_tools(
directive_rules=directive_rules,
include_mental_models=has_mental_models,
include_observations=include_observations,
include_recall=include_recall,
)
# Build set of enabled tool names to guard against LLM hallucinating disabled tool calls
enabled_tools: frozenset[str] = frozenset(t["function"]["name"] for t in tools if t.get("type") == "function")

# Build initial messages (directives are injected into system prompt at START and END)
system_prompt = build_system_prompt_for_tools(
Expand Down Expand Up @@ -538,19 +547,18 @@ def _log_completion(answer: str, iterations: int, forced: bool = False):
llm_start = time.time()

# Determine tool_choice for this iteration.
# Force the full hierarchical retrieval path before allowing auto:
# With mental models:
# 0 → search_mental_models, 1 → search_observations, 2 → recall, 3+ → auto
# Without mental models:
# 0 → search_observations, 1 → recall, 2+ → auto
if iteration == 0 and has_mental_models:
iter_tool_choice: str | dict = {"type": "function", "function": {"name": "search_mental_models"}}
elif iteration == 0:
iter_tool_choice = {"type": "function", "function": {"name": "search_observations"}}
elif iteration == 1 and has_mental_models:
iter_tool_choice = {"type": "function", "function": {"name": "search_observations"}}
elif iteration == 1 or (iteration == 2 and has_mental_models):
iter_tool_choice = {"type": "function", "function": {"name": "recall"}}
# Force the full hierarchical retrieval path (only for enabled tools) before allowing auto.
# Build the forced sequence from the tools that are actually enabled.
forced_sequence = []
if has_mental_models:
forced_sequence.append("search_mental_models")
if include_observations:
forced_sequence.append("search_observations")
if include_recall:
forced_sequence.append("recall")

if iteration < len(forced_sequence):
iter_tool_choice: str | dict = {"type": "function", "function": {"name": forced_sequence[iteration]}}
else:
iter_tool_choice = "auto"

Expand Down Expand Up @@ -769,14 +777,41 @@ def _log_completion(answer: str, iterations: int, forced: bool = False):
# Execute other tools in parallel (exclude done tool in all its format variants)
other_tools = [tc for tc in result.tool_calls if not _is_done_tool(tc.name)]
if other_tools:
# Add assistant message with tool calls
# Partition into enabled vs hallucinated (not in enabled_tools set)
allowed_tools = []
hallucinated_tools = []
for tc in other_tools:
norm = _normalize_tool_name(tc.name)
if enabled_tools is not None and norm not in enabled_tools and norm not in ("done", "expand"):
hallucinated_tools.append(tc)
else:
allowed_tools.append(tc)

# Build assistant message with all tool calls (LLM requires them for history)
messages.append(
{
"role": "assistant",
"tool_calls": [_tool_call_to_dict(tc) for tc in other_tools],
}
)

# Immediately reject hallucinated tool calls without adding to trace
for tc in hallucinated_tools:
messages.append(
{
"role": "tool",
"tool_call_id": tc.id,
"name": tc.name,
"content": json.dumps(
{
"error": f"Tool '{_normalize_tool_name(tc.name)}' is not available. Use only the tools provided to you."
}
),
}
)

other_tools = allowed_tools

# Execute tools in parallel
tool_tasks = [
_execute_tool_with_timing(
Expand All @@ -785,6 +820,7 @@ def _log_completion(answer: str, iterations: int, forced: bool = False):
search_observations_fn,
recall_fn,
expand_fn,
enabled_tools=enabled_tools,
)
for tc in other_tools
]
Expand Down Expand Up @@ -974,6 +1010,7 @@ async def _execute_tool_with_timing(
search_observations_fn: Callable[[str, int], Awaitable[dict[str, Any]]],
recall_fn: Callable[[str, int, int], Awaitable[dict[str, Any]]],
expand_fn: Callable[[list[str], str], Awaitable[dict[str, Any]]],
enabled_tools: frozenset[str] | None = None,
) -> tuple[dict[str, Any], int]:
"""Execute a tool call and return result with timing."""
from hindsight_api.tracing import get_tracer
Expand Down Expand Up @@ -1007,6 +1044,7 @@ async def _execute_tool_with_timing(
search_observations_fn,
recall_fn,
expand_fn,
enabled_tools=enabled_tools,
)

# Set success attributes
Expand Down Expand Up @@ -1046,11 +1084,16 @@ async def _execute_tool(
search_observations_fn: Callable[[str, int], Awaitable[dict[str, Any]]],
recall_fn: Callable[[str, int, int], Awaitable[dict[str, Any]]],
expand_fn: Callable[[list[str], str], Awaitable[dict[str, Any]]],
enabled_tools: frozenset[str] | None = None,
) -> dict[str, Any]:
"""Execute a single tool by name."""
# Normalize tool name for various LLM output formats
tool_name = _normalize_tool_name(tool_name)

# Guard against LLMs hallucinating calls to tools that were not provided
if enabled_tools is not None and tool_name not in enabled_tools and tool_name not in ("done", "expand"):
return {"error": f"Tool '{tool_name}' is not available. Use only the tools provided to you."}

if tool_name == "search_mental_models":
query = args.get("query")
if not query:
Expand Down
6 changes: 5 additions & 1 deletion hindsight-api-slim/hindsight_api/engine/reflect/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ async def tool_recall(
tag_groups: "list | None" = None,
connection_budget: int = 1,
max_chunk_tokens: int = 1000,
fact_types: list[str] | None = None,
) -> dict[str, Any]:
"""
Search memories using TEMPR retrieval.
Expand All @@ -217,15 +218,18 @@ async def tool_recall(
tags_match: How to match tags - "any" (OR), "all" (AND), or "exact"
connection_budget: Max DB connections for this recall (default 1 for internal ops)
max_chunk_tokens: Maximum tokens for raw source chunk text (default 1000, always included)
fact_types: Optional filter for fact types to retrieve. Defaults to ["experience", "world"].

Returns:
Dict with list of matching memories including raw chunk text
"""
# Only world/experience are valid for raw recall (observation is handled by search_observations)
recall_fact_type = [ft for ft in (fact_types or ["experience", "world"]) if ft in ("world", "experience")]
include_chunks = True
result = await memory_engine.recall_async(
bank_id=bank_id,
query=query,
fact_type=["experience", "world"],
fact_type=recall_fact_type,
max_tokens=max_tokens,
enable_trace=False,
request_context=request_context,
Expand Down
26 changes: 19 additions & 7 deletions hindsight-api-slim/hindsight_api/engine/reflect/tools_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,12 @@ def _build_done_tool_with_directives(directive_rules: list[str]) -> dict:
}


def get_reflect_tools(directive_rules: list[str] | None = None) -> list[dict]:
def get_reflect_tools(
directive_rules: list[str] | None = None,
include_mental_models: bool = True,
include_observations: bool = True,
include_recall: bool = True,
) -> list[dict]:
"""
Get the list of tools for the reflect agent.

Expand All @@ -239,16 +244,23 @@ def get_reflect_tools(directive_rules: list[str] | None = None) -> list[dict]:
Args:
directive_rules: Optional list of directive rule strings. If provided,
the done() tool will require directive compliance confirmation.
include_mental_models: Whether to include the search_mental_models tool.
include_observations: Whether to include the search_observations tool.
include_recall: Whether to include the recall tool.

Returns:
List of tool definitions in OpenAI format
"""
tools = [
TOOL_SEARCH_MENTAL_MODELS,
TOOL_SEARCH_OBSERVATIONS,
TOOL_RECALL,
TOOL_EXPAND,
]
tools = []

if include_mental_models:
tools.append(TOOL_SEARCH_MENTAL_MODELS)
if include_observations:
tools.append(TOOL_SEARCH_OBSERVATIONS)
if include_recall:
tools.append(TOOL_RECALL)

tools.append(TOOL_EXPAND)

# Use directive-aware done tool if directives are present
if directive_rules:
Expand Down
Loading
Loading