Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions ace/ace.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,11 @@ def _train_single_sample(
token_budget = config_params['token_budget']
use_json_mode = config_params['use_json_mode']
no_ground_truth = config_params['no_ground_truth']
# if algorithmic task, use code prompt style
if hasattr(data_processor, "get_generator_prompt_style"):
prompt_style = data_processor.get_generator_prompt_style()
else:
prompt_style = "json"

# Extract sample data
question = task_dict.get("question", "")
Expand All @@ -468,6 +473,7 @@ def _train_single_sample(
context=context,
reflection="(empty)",
use_json_mode=use_json_mode,
prompt_style=prompt_style,
call_id=f"{step_id}_gen_initial",
log_dir=log_dir
)
Expand Down Expand Up @@ -533,6 +539,7 @@ def _train_single_sample(
context=context,
reflection=reflection_content,
use_json_mode=use_json_mode,
prompt_style=prompt_style,
call_id=f"{step_id}_post_reflect_round_{round_num}",
log_dir=log_dir
)
Expand Down Expand Up @@ -612,6 +619,7 @@ def _train_single_sample(
context=context,
reflection="(empty)",
use_json_mode=use_json_mode,
prompt_style=prompt_style,
call_id=f"{step_id}_post_curate",
log_dir=log_dir
)
Expand Down Expand Up @@ -672,6 +680,17 @@ def _offline_train(
error_logs = []
best_accuracy = 0.0
self.best_playbook = self.playbook
metrics_eval_path = os.path.join(save_path, "metrics_eval.jsonl")
checkpoints_index_path = os.path.join(save_path, "checkpoints_index.jsonl")

# Start each offline run with fresh metric/checkpoint index logs.
for p in (metrics_eval_path, checkpoints_index_path):
with open(p, "w", encoding="utf-8") as _f:
_f.write("")

def append_jsonl(path: str, payload: Dict[str, Any]) -> None:
    """Append *payload* to the JSONL file at *path* as one JSON-encoded line."""
    # Serialize first so a failed dumps never leaves a partial line behind.
    record = json.dumps(payload, ensure_ascii=False)
    with open(path, "a", encoding="utf-8") as sink:
        sink.write(record + "\n")

print(f"Total epochs: {num_epochs}")
print(f"Train samples per epoch: {len(train_samples)}")
Expand Down Expand Up @@ -731,6 +750,16 @@ def _offline_train(
)
with open(intermediate_path, "w") as f:
f.write(self.playbook)
append_jsonl(checkpoints_index_path, {
"timestamp": datetime.now().isoformat(),
"checkpoint_type": "intermediate",
"epoch": epoch,
"step": step,
"global_step": (epoch - 1) * len(train_samples) + step,
"path": intermediate_path,
"playbook_num_tokens": count_tokens(self.playbook),
"playbook_length": len(self.playbook),
})

# Periodic evaluation
if step % eval_steps == 0:
Expand All @@ -748,6 +777,7 @@ def _offline_train(

# Validation evaluation
val_results = {}
val_error_log = {}
if val_samples:
val_results, val_error_log = evaluate_test_set(
data_processor, self.generator, self.playbook,
Expand All @@ -774,6 +804,21 @@ def _offline_train(
"val_results": val_results,
"error_log": val_error_log
})
append_jsonl(metrics_eval_path, {
"timestamp": datetime.now().isoformat(),
"epoch": epoch,
"step": step,
"global_step": (epoch - 1) * len(train_samples) + step,
"train_pre_accuracy": pre_train_accuracy,
"train_post_accuracy": post_train_accuracy,
"val_mean_score": val_results.get("mean_score"),
"val_accuracy": val_results.get("accuracy"),
"val_format_valid_count": val_results.get("format_valid_count"),
"val_evaluated_count": val_results.get("evaluated_count"),
"val_failed_count": val_error_log.get("failed_count", 0),
"playbook_num_tokens": count_tokens(self.playbook),
"playbook_length": len(self.playbook),
})

# Track best playbook
if val_results:
Expand Down Expand Up @@ -801,6 +846,16 @@ def _offline_train(
)
with open(epoch_playbook_path, "w") as f:
f.write(self.playbook)
append_jsonl(checkpoints_index_path, {
"timestamp": datetime.now().isoformat(),
"checkpoint_type": "epoch_final",
"epoch": epoch,
"step": len(train_samples),
"global_step": epoch * len(train_samples),
"path": epoch_playbook_path,
"playbook_num_tokens": count_tokens(self.playbook),
"playbook_length": len(self.playbook),
})

# Save training results
results_path = os.path.join(save_path, "train_results.json")
Expand All @@ -818,11 +873,32 @@ def _offline_train(
final_playbook_path = os.path.join(save_path, f"final_playbook.txt")
with open(final_playbook_path, "w") as f:
f.write(self.playbook)
append_jsonl(checkpoints_index_path, {
"timestamp": datetime.now().isoformat(),
"checkpoint_type": "final_playbook",
"epoch": num_epochs,
"step": len(train_samples),
"global_step": num_epochs * len(train_samples),
"path": final_playbook_path,
"playbook_num_tokens": count_tokens(self.playbook),
"playbook_length": len(self.playbook),
})

# Save best playbook
best_playbook_path = os.path.join(save_path, f"best_playbook.txt")
with open(best_playbook_path, "w") as f:
f.write(self.best_playbook)
append_jsonl(checkpoints_index_path, {
"timestamp": datetime.now().isoformat(),
"checkpoint_type": "best_playbook",
"epoch": num_epochs,
"step": len(train_samples),
"global_step": num_epochs * len(train_samples),
"path": best_playbook_path,
"best_validation_accuracy": best_accuracy,
"playbook_num_tokens": count_tokens(self.best_playbook),
"playbook_length": len(self.best_playbook),
})

print(f"\n{'='*60}")
print(f"OFFLINE TRAINING COMPLETE")
Expand Down
12 changes: 9 additions & 3 deletions ace/core/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import json
import re
from typing import Dict, List, Tuple, Optional, Any
from ..prompts.generator import GENERATOR_PROMPT
from ..prompts.generator import GENERATOR_PROMPT_JSON, GENERATOR_PROMPT_CODE
from llm import timed_llm_call

class Generator:
Expand Down Expand Up @@ -37,6 +37,7 @@ def generate(
context: str = "",
reflection: str = "(empty)",
use_json_mode: bool = False,
prompt_style: str = "json",
call_id: str = "gen",
log_dir: Optional[str] = None
) -> Tuple[str, List[str], Dict[str, Any]]:
Expand All @@ -56,8 +57,13 @@ def generate(
Tuple of (full_response, bullet_ids_used, call_info)
"""
# Format the prompt
prompt = GENERATOR_PROMPT.format(playbook, reflection, question, context)
if prompt_style == "code":
prompt = GENERATOR_PROMPT_CODE.format(playbook, reflection, question, context)
else:
prompt = GENERATOR_PROMPT_JSON.format(playbook, reflection, question, context)

use_json_mode_call = use_json_mode and prompt_style != "code"

response, call_info = timed_llm_call(
self.api_client,
self.api_provider,
Expand All @@ -67,7 +73,7 @@ def generate(
call_id=call_id,
max_tokens=self.max_tokens,
log_dir=log_dir,
use_json_mode=use_json_mode
use_json_mode=use_json_mode_call
)

# Extract bullet IDs if using retrieval and reason mode
Expand Down
34 changes: 32 additions & 2 deletions ace/prompts/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

# Retrieval and Reason Generator prompt that outputs bullet IDs
GENERATOR_PROMPT = """You are an analysis expert tasked with answering questions using your knowledge, a curated playbook of strategies and insights and a reflection that goes over the diagnosis of all previous mistakes made while answering the question.
GENERATOR_PROMPT_JSON = """You are an analysis expert tasked with answering questions using your knowledge, a curated playbook of strategies and insights and a reflection that goes over the diagnosis of all previous mistakes made while answering the question.

**Instructions:**
- Read the playbook carefully and apply relevant strategies, formulas, and insights
Expand Down Expand Up @@ -39,4 +39,34 @@
}}

---
"""
"""

# Code-only generator prompt for programming tasks. Unlike the JSON-style
# prompt, it instructs the model to return a bare runnable solution with no
# JSON/markdown wrapper. The four {} placeholders are filled positionally via
# .format(playbook, reflection, question, context) — keep that order in sync
# with callers.
GENERATOR_PROMPT_CODE = """You are a coding expert tasked with solving programming problems using your knowledge, a curated playbook of strategies and insights, and a reflection that summarizes previous mistakes.

**Instructions:**
- Use the playbook and reflection when helpful
- Write a complete, runnable solution that follows the problem's input/output format
- Return only the final code (no explanations, no markdown, no JSON)
- Prefer C++17 when the question asks for it

**Playbook:**
{}

**Reflection:**
{}

**Question:**
{}

**Context:**
{}

**Output:**
Return only the code.

---
"""

# Backward-compatible alias: existing imports of GENERATOR_PROMPT continue to
# resolve to the JSON-style prompt now that the constant has been renamed.
GENERATOR_PROMPT = GENERATOR_PROMPT_JSON
172 changes: 172 additions & 0 deletions eval/frontier-cs/data/algorithmic_all.jsonl

Large diffs are not rendered by default.

Loading