Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,10 @@ config = {
'save_dir': './results',
'test_workers': 20,
'use_bulletpoint_analyzer': False,
'api_provider': api_provider

'api_provider': api_provider,
'use_retriever': False,
'retriever_top_k': 10,
'retriever_model_name': 'intfloat/multilingual-e5-large'
}

# Offline adaptation
Expand Down Expand Up @@ -181,6 +183,14 @@ uv run python -m eval.finance.run \
--initial_playbook_path results/ace_run_TIMESTAMP_finer_offline/best_playbook.txt \
--save_path test_results

# Evaluation with retrieval-based sub-playbooks (top-k bullets per sample)
uv run python -m eval.finance.run \
--task_name finer \
--mode eval_only \
--initial_playbook_path results/ace_run_TIMESTAMP_finer_offline/best_playbook.txt \
--retriever_top_k 10 \
--save_path test_results_topk10

# Training with custom configuration
uv run python -m eval.finance.run \
--task_name finer \
Expand Down Expand Up @@ -218,6 +228,8 @@ uv run python -m eval.finance.run \
| `--no_ground_truth` | Don't use ground truth in reflection | False |
| `--use_bulletpoint_analyzer` | Enable bulletpoint analyzer for playbook deduplication and merging | False |
| `--bulletpoint_analyzer_threshold` | Similarity threshold for bulletpoint analyzer (0-1) | 0.9 |
| `--retriever_top_k` | Number of top bullets to retrieve per sample. Enables retrieval when set. | None (disabled) |
| `--retriever_model_name` | Sentence-transformers model for retrieval embeddings | `intfloat/multilingual-e5-large` |

</details>

Expand Down
38 changes: 28 additions & 10 deletions ace/ace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from datetime import datetime
from typing import Dict, List, Tuple, Optional, Any

from .core import Generator, Reflector, Curator, BulletpointAnalyzer
from .core import Generator, Reflector, Curator, BulletpointAnalyzer, Retriever
from playbook_utils import *
from logger import *
from utils import *
Expand All @@ -39,7 +39,9 @@ def __init__(
max_tokens: int = 4096,
initial_playbook: Optional[str] = None,
use_bulletpoint_analyzer: bool = False,
bulletpoint_analyzer_threshold: float = 0.90
bulletpoint_analyzer_threshold: float = 0.90,
retriever_top_k: int = 5,
retriever_model_name: str = "intfloat/multilingual-e5-large"
):
"""
Initialize the ACE system.
Expand All @@ -53,6 +55,8 @@ def __init__(
initial_playbook: Initial playbook content (optional)
use_bulletpoint_analyzer: Whether to use bulletpoint analyzer for deduplication
bulletpoint_analyzer_threshold: Similarity threshold for bulletpoint analyzer (0-1)
retriever_top_k: Number of top bullets to retrieve per sample
retriever_model_name: Sentence-transformers model for retrieval embeddings
"""
# Initialize API clients
generator_client, reflector_client, curator_client = initialize_clients(api_provider)
Expand All @@ -61,6 +65,7 @@ def __init__(
self.generator = Generator(generator_client, api_provider, generator_model, max_tokens)
self.reflector = Reflector(reflector_client, api_provider, reflector_model, max_tokens)
self.curator = Curator(curator_client, api_provider, curator_model, max_tokens)
self.retriever = Retriever(model_name=retriever_model_name, top_k=retriever_top_k)

# Initialize bulletpoint analyzer if requested and available
self.use_bulletpoint_analyzer = use_bulletpoint_analyzer
Expand Down Expand Up @@ -131,7 +136,10 @@ def _extract_config_params(self, config: Dict[str, Any]) -> Dict[str, Any]:
'save_dir': config.get('save_dir', './results'),
'test_workers': config.get('test_workers', 20),
'use_bulletpoint_analyzer': config.get('use_bulletpoint_analyzer', False),
'bulletpoint_analyzer_threshold': config.get('bulletpoint_analyzer_threshold', 0.90)
'bulletpoint_analyzer_threshold': config.get('bulletpoint_analyzer_threshold', 0.90),
'use_retriever': config.get('use_retriever', False),
'retriever_top_k': config.get('retriever_top_k', 5),
'retriever_model_name': config.get('retriever_model_name', 'intfloat/multilingual-e5-large')
}

def _setup_paths(self, save_dir: str, task_name: str, mode: str) -> Tuple[str, str]:
Expand Down Expand Up @@ -290,7 +298,8 @@ def run(
config=config,
log_dir=log_dir,
save_path=save_path,
prefix="final"
prefix="final",
use_retriever=True
)
results['final_test_results'] = final_test_results
print(f"Final Test Accuracy: {final_test_results['accuracy']:.3f}\n")
Expand Down Expand Up @@ -340,7 +349,8 @@ def run(
config=config,
log_dir=log_dir,
save_path=save_path,
prefix="test"
prefix="test",
use_retriever=config.get('use_retriever', False)
)
results['test_results'] = test_results

Expand Down Expand Up @@ -377,7 +387,8 @@ def _run_test(
config: Dict[str, Any],
log_dir: str,
save_path: str,
prefix: str = "test"
prefix: str = "test",
use_retriever: bool = False
) -> Dict[str, Any]:
"""
Run testing
Expand All @@ -390,13 +401,19 @@ def _run_test(
log_dir: Directory for detailed logs
save_path: Path to save results
prefix: Prefix for saved files (e.g., 'initial', 'final', 'test')
use_retriever: If True, use retriever to build per-sample mini playbooks

Returns:
Dictionary with test results
"""
config_params = self._extract_config_params(config)
use_json_mode = config_params['use_json_mode']
test_workers = config_params['test_workers']

retriever = None
if use_retriever:
self.retriever.index_playbook(playbook)
retriever = self.retriever

test_results, test_error_log = evaluate_test_set(
data_processor,
Expand All @@ -406,7 +423,8 @@ def _run_test(
self.max_tokens,
log_dir,
max_workers=test_workers,
use_json_mode=use_json_mode
use_json_mode=use_json_mode,
retriever=retriever
)

# Save test results
Expand Down Expand Up @@ -459,7 +477,7 @@ def _train_single_sample(
question = task_dict.get("question", "")
context = task_dict.get("context", "")
target = task_dict.get("target", "")

# STEP 1: Initial generation (pre-train)
print("Generating initial answer...")
gen_response, bullet_ids, call_info = self.generator.generate(
Expand Down Expand Up @@ -525,7 +543,7 @@ def _train_single_sample(
self.playbook = update_bullet_counts(
self.playbook, bullet_tags
)

# Regenerate with reflection
gen_response, bullet_ids, _ = self.generator.generate(
question=question,
Expand Down Expand Up @@ -604,7 +622,7 @@ def _train_single_sample(
threshold=self.bulletpoint_analyzer_threshold,
merge=True
)

# STEP 4: Post-curator generation
gen_response, _, _ = self.generator.generate(
question=question,
Expand Down
3 changes: 2 additions & 1 deletion ace/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
from .reflector import Reflector
from .curator import Curator
from .bulletpoint_analyzer import BulletpointAnalyzer, DEDUP_AVAILABLE
from .retriever import Retriever

__all__ = ['Generator', 'Reflector', 'Curator', 'BulletpointAnalyzer', 'DEDUP_AVAILABLE']
__all__ = ['Generator', 'Reflector', 'Curator', 'BulletpointAnalyzer', 'DEDUP_AVAILABLE', 'Retriever']
68 changes: 68 additions & 0 deletions ace/core/retriever.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from playbook_utils import parse_playbook_line, format_playbook_line
from sentence_transformers import SentenceTransformer
import numpy as np


class Retriever:
    """Embedding-based retriever over playbook bullets.

    Bullets are embedded once via ``index_playbook``; ``retrieve`` then scores
    them against a query and returns a mini-playbook containing only the
    top-k most relevant bullets, preserving the original section structure
    and bullet order.

    NOTE: E5-family models expect "query: " / "passage: " prefixes on the
    text they embed (per the model card) — both methods below add them.
    """

    def __init__(self, model_name: str = "intfloat/multilingual-e5-large", top_k: int = 5):
        """
        Args:
            model_name: Sentence-transformers model used for embeddings.
            top_k: Default number of bullets returned by ``retrieve``.
        """
        # CPU device keeps the retriever usable alongside GPU-bound components.
        self.model = SentenceTransformer(model_name, device="cpu")
        self.top_k = top_k
        # Parsed bullets (dicts from parse_playbook_line, plus a "section" key).
        self.bullets_with_sections: list = []
        # One row per bullet, L2-normalized; None until a playbook is indexed.
        self.passage_embeddings: np.ndarray | None = None

    def index_playbook(self, playbook: str) -> None:
        """Parse playbook bullets and pre-compute their embeddings."""
        self.bullets_with_sections = self._extract_bullets_with_sections(playbook)
        if not self.bullets_with_sections:
            # Nothing to index; retrieve() treats this as "return empty".
            self.passage_embeddings = None
            return
        passage_texts = [
            "passage: " + b["content"] for b in self.bullets_with_sections
        ]
        self.passage_embeddings = self.model.encode(passage_texts, normalize_embeddings=True)

    def retrieve(self, question: str, context: str = "", top_k: int | None = None) -> str:
        """Return a mini-playbook string containing the top-k most relevant bullets.

        Args:
            question: The sample's question text.
            context: Optional additional context appended to the query.
            top_k: Per-call override; falls back to ``self.top_k`` when None.

        Returns:
            A formatted mini-playbook (section headings + bullet lines), or
            "" when nothing is indexed or ``top_k`` is non-positive.
        """
        if not self.bullets_with_sections or self.passage_embeddings is None:
            return ""

        top_k = top_k if top_k is not None else self.top_k
        # BUGFIX: a non-positive top_k previously slipped through to
        # np.argsort(...)[-top_k:], and arr[-0:] selects ALL elements —
        # i.e. top_k=0 retrieved every bullet instead of none.
        if top_k <= 0:
            return ""
        top_k = min(top_k, len(self.bullets_with_sections))

        # Only append context when present, so an empty context doesn't leave
        # a trailing space in the text being embedded.
        query_text = "query: " + question
        if context:
            query_text += " " + context
        query_embedding = self.model.encode(query_text, normalize_embeddings=True)

        # Embeddings are L2-normalized, so the dot product IS cosine similarity.
        similarities = np.dot(self.passage_embeddings, query_embedding)
        top_k_indices = set(np.argsort(similarities)[-top_k:])

        # Section headings in first-seen order, so output mirrors the playbook.
        section_order = list(dict.fromkeys(
            b["section"] for b in self.bullets_with_sections
        ))

        section_bullets: dict[str, list[str]] = {s: [] for s in section_order}
        for i, b in enumerate(self.bullets_with_sections):
            if i in top_k_indices:
                line = format_playbook_line(b["id"], b["helpful"], b["harmful"], b["content"])
                section_bullets[b["section"]].append(line)

        # Emit only sections that kept at least one bullet; blank line between
        # sections, with the trailing one stripped off.
        lines = []
        for section in section_order:
            if section_bullets[section]:
                lines.append(section)
                lines.extend(section_bullets[section])
                lines.append("")

        return "\n".join(lines).rstrip()

    def _extract_bullets_with_sections(self, playbook: str) -> list:
        """Parse the playbook into bullet dicts, each tagged with the '##'
        section heading it appears under (empty string before any heading)."""
        results = []
        current_section = ""
        for line in playbook.strip().split("\n"):
            if line.strip().startswith("##"):
                current_section = line.strip()
            else:
                parsed = parse_playbook_line(line)
                if parsed:
                    parsed["section"] = current_section
                    results.append(parsed)
        return results
15 changes: 13 additions & 2 deletions eval/finance/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ def parse_args():
parser.add_argument("--bulletpoint_analyzer_threshold", type=float, default=0.90,
help="Similarity threshold for bulletpoint analyzer (0-1, default: 0.90)")

# Retriever configuration — passing --retriever_top_k enables retrieval
parser.add_argument("--retriever_top_k", type=int, default=None,
help="Number of top bullets to retrieve per question. Enables retrieval when set.")
parser.add_argument("--retriever_model_name", type=str, default="intfloat/multilingual-e5-large",
help="Sentence-transformers model for retrieval embeddings")

# Output configuration
parser.add_argument("--save_path", type=str, required=True,
help="Directory to save results")
Expand Down Expand Up @@ -202,7 +208,9 @@ def main():
max_tokens=args.max_tokens,
initial_playbook=initial_playbook,
use_bulletpoint_analyzer=args.use_bulletpoint_analyzer,
bulletpoint_analyzer_threshold=args.bulletpoint_analyzer_threshold
bulletpoint_analyzer_threshold=args.bulletpoint_analyzer_threshold,
retriever_top_k=args.retriever_top_k or 5,
retriever_model_name=args.retriever_model_name
)

# Prepare configuration
Expand All @@ -223,7 +231,10 @@ def main():
'initial_playbook_path': args.initial_playbook_path,
'use_bulletpoint_analyzer': args.use_bulletpoint_analyzer,
'bulletpoint_analyzer_threshold': args.bulletpoint_analyzer_threshold,
'api_provider': args.api_provider
'api_provider': args.api_provider,
'use_retriever': args.retriever_top_k is not None,
'retriever_top_k': args.retriever_top_k,
'retriever_model_name': args.retriever_model_name
}

# Execute using the unified run method
Expand Down
14 changes: 14 additions & 0 deletions playbook_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,3 +369,17 @@ def extract_playbook_bullets(playbook_text, bullet_ids):
formatted_bullets.append(f"[{bullet['id']}] helpful={bullet['helpful']} harmful={bullet['harmful']} :: {bullet['content']}")

return '\n'.join(formatted_bullets)


def extract_all_bullets(playbook_text):
    """
    Extract all bullet points from playbook.

    Every non-empty line that parse_playbook_line recognizes as a bullet is
    included, in document order; headings and blank lines are skipped.
    """
    return [
        bullet
        for raw_line in playbook_text.strip().split('\n')
        if raw_line.strip() and (bullet := parse_playbook_line(raw_line))
    ]
Loading