Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,10 @@
.env
__pycache__/
*/__pycache__/
results/
results/

# OS files
.DS_Store

# data files
eval/stream-bench/data
250 changes: 168 additions & 82 deletions ace/ace.py

Large diffs are not rendered by default.

70 changes: 64 additions & 6 deletions ace/core/reflector.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,12 @@ def reflect(
use_ground_truth: bool = True,
use_json_mode: bool = False,
call_id: str = "reflect",
log_dir: Optional[str] = None
log_dir: Optional[str] = None,
sql_exec_results: Optional[Dict[str, Any]] = None
) -> Tuple[str, List[Dict[str, str]], Dict[str, Any]]:
"""
Analyze the generator's output and tag bullets.

Args:
question: The original question
reasoning_trace: The generator's reasoning
Expand All @@ -57,10 +58,18 @@ def reflect(
use_json_mode: Whether to use JSON mode
call_id: Unique identifier for this call
log_dir: Directory for logging

sql_exec_results: Optional dict containing SQL execution results with keys:
- 'predicted_result': List of tuples from predicted SQL execution
- 'ground_truth_result': List of tuples from ground truth SQL execution
- 'db_name': Database name used for evaluation
- 'error': Error message if execution failed

Returns:
Tuple of (reflection_content, bullet_tags, call_info)
"""
# Format SQL execution results for the prompt
sql_exec_text = self._format_sql_exec_results(sql_exec_results)

# Select the appropriate prompt
if use_ground_truth and ground_truth:
prompt = REFLECTOR_PROMPT.format(
Expand All @@ -69,15 +78,17 @@ def reflect(
predicted_answer,
ground_truth,
environment_feedback,
bullets_used
bullets_used,
sql_exec_text
)
else:
prompt = REFLECTOR_PROMPT_NO_GT.format(
question,
reasoning_trace,
predicted_answer,
environment_feedback,
bullets_used
bullets_used,
sql_exec_text
)

response, call_info = timed_llm_call(
Expand All @@ -96,7 +107,54 @@ def reflect(
bullet_tags = self._extract_bullet_tags(response, use_json_mode)

return response, bullet_tags, call_info


def _format_sql_exec_results(self, sql_exec_results: Optional[Dict[str, Any]]) -> str:
"""
Format SQL execution results for display in the prompt.

Args:
sql_exec_results: Dict containing execution results or None

Returns:
Formatted string describing the execution results
"""
if not sql_exec_results:
return "No SQL execution results available."

if "error" in sql_exec_results:
return f"Error during SQL execution: {sql_exec_results['error']}"

db_name = sql_exec_results.get("db_name", "unknown")
pred_result = sql_exec_results.get("predicted_result", [])
gt_result = sql_exec_results.get("ground_truth_result", [])

# Format the results
lines = [f"Database: {db_name}\n"]

# Predicted SQL results
lines.append(f"Predicted SQL Execution Result ({len(pred_result)} rows):")
if pred_result:
for i, row in enumerate(pred_result[:20]): # Show first 20 rows
lines.append(f" Row {i+1}: {row}")
if len(pred_result) > 20:
lines.append(f" ... ({len(pred_result) - 20} more rows)")
else:
lines.append(" (Empty result set)")

lines.append("")

# Ground truth SQL results
lines.append(f"Ground Truth SQL Execution Result ({len(gt_result)} rows):")
if gt_result:
for i, row in enumerate(gt_result[:20]): # Show first 20 rows
lines.append(f" Row {i+1}: {row}")
if len(gt_result) > 20:
lines.append(f" ... ({len(gt_result) - 20} more rows)")
else:
lines.append(" (Empty result set)")

return "\n".join(lines)

def _extract_bullet_tags(
self,
response: str,
Expand Down
6 changes: 6 additions & 0 deletions ace/prompts/reflector.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@
**Part of Playbook that's used by the generator to answer the question:**
{}

**SQL Execution Results (if available):**
{}

**Answer in this exact JSON format:**
{{
"reasoning": "[Your chain of thought / reasoning / thinking process, detailed analysis and calculations]",
Expand Down Expand Up @@ -98,6 +101,9 @@
**Part of Playbook that's used by the generator to answer the question:**
{}

**SQL Execution Results (if available):**
{}

**Answer in this exact JSON format:**
{{
"reasoning": "[Your chain of thought / reasoning / thinking process, detailed analysis and calculations]",
Expand Down
Loading