From 832635c0b00546d8e57d17ae6a6be9c31d1a10dc Mon Sep 17 00:00:00 2001 From: Shyamsaibethina Date: Tue, 3 Mar 2026 20:55:41 -0800 Subject: [PATCH 1/5] first draft of retriever --- ace/ace.py | 33 +++++++++------------ ace/core/__init__.py | 3 +- ace/core/retriever.py | 69 +++++++++++++++++++++++++++++++++++++++++++ eval/finance/run.py | 7 ++++- playbook_utils.py | 14 +++++++++ run_topk_sweep.sh | 53 +++++++++++++++++++++++++++++++++ 6 files changed, 158 insertions(+), 21 deletions(-) create mode 100644 ace/core/retriever.py create mode 100755 run_topk_sweep.sh diff --git a/ace/ace.py b/ace/ace.py index 2d662adc..77ec5a42 100644 --- a/ace/ace.py +++ b/ace/ace.py @@ -13,7 +13,7 @@ from datetime import datetime from typing import Dict, List, Tuple, Optional, Any -from .core import Generator, Reflector, Curator, BulletpointAnalyzer +from .core import Generator, Reflector, Curator, BulletpointAnalyzer, Retriever from playbook_utils import * from logger import * from utils import * @@ -39,21 +39,9 @@ def __init__( max_tokens: int = 4096, initial_playbook: Optional[str] = None, use_bulletpoint_analyzer: bool = False, - bulletpoint_analyzer_threshold: float = 0.90 + bulletpoint_analyzer_threshold: float = 0.90, + retriever_top_k: int = 5 ): - """ - Initialize the ACE system. - - Args: - api_provider: API provider for LLM calls - generator_model: Model name for generator - reflector_model: Model name for reflector - curator_model: Model name for curator - max_tokens: Maximum tokens for LLM calls - initial_playbook: Initial playbook content (optional) - use_bulletpoint_analyzer: Whether to use bulletpoint analyzer for deduplication - bulletpoint_analyzer_threshold: Similarity threshold for bulletpoint analyzer (0-1) - """ # Initialize API clients generator_client, reflector_client, curator_client = initialize_clients(api_provider) @@ -61,6 +49,7 @@ def __init__( self.generator = Generator(generator_client, api_provider, generator_model, max_tokens) self.reflector = Reflector(reflector_client, api_provider, reflector_model, max_tokens) self.curator = Curator(curator_client, api_provider, curator_model, max_tokens) + self.retriever = Retriever(top_k=retriever_top_k) # Initialize bulletpoint analyzer if requested and available self.use_bulletpoint_analyzer = use_bulletpoint_analyzer @@ -459,12 +448,14 @@ def _train_single_sample( question = task_dict.get("question", "") context = task_dict.get("context", "") target = task_dict.get("target", "") - + + retrieved_playbook = self.retriever.retrieve(self.playbook, question, context) + # STEP 1: Initial generation (pre-train) print("Generating initial answer...") gen_response, bullet_ids, call_info = self.generator.generate( question=question, - playbook=self.playbook, + playbook=retrieved_playbook, context=context, reflection="(empty)", use_json_mode=use_json_mode, @@ -525,11 +516,13 @@ def _train_single_sample( self.playbook = update_bullet_counts( self.playbook, bullet_tags ) + + retrieved_playbook = self.retriever.retrieve(self.playbook, question, context) # Regenerate with reflection gen_response, bullet_ids, _ = self.generator.generate( question=question, - playbook=self.playbook, + playbook=retrieved_playbook, context=context, reflection=reflection_content, use_json_mode=use_json_mode, @@ -604,11 +597,13 @@ def _train_single_sample( threshold=self.bulletpoint_analyzer_threshold, merge=True ) + + retrieved_playbook = self.retriever.retrieve(self.playbook, question, context) # STEP 4: Post-curator generation gen_response, _, _ = self.generator.generate( question=question, - playbook=self.playbook, + playbook=retrieved_playbook, context=context, reflection="(empty)", use_json_mode=use_json_mode, diff --git a/ace/core/__init__.py b/ace/core/__init__.py index 6d4d0926..7f2ece92 100644 --- a/ace/core/__init__.py +++ b/ace/core/__init__.py @@ -7,5 +7,6 @@ from .reflector import Reflector from .curator import Curator from .bulletpoint_analyzer import BulletpointAnalyzer, DEDUP_AVAILABLE +from .retriever import Retriever -__all__ = ['Generator', 'Reflector', 'Curator', 'BulletpointAnalyzer', 'DEDUP_AVAILABLE'] \ No newline at end of file +__all__ = ['Generator', 'Reflector', 'Curator', 'BulletpointAnalyzer', 'DEDUP_AVAILABLE', 'Retriever'] \ No newline at end of file diff --git a/ace/core/retriever.py b/ace/core/retriever.py new file mode 100644 index 00000000..c7e28964 --- /dev/null +++ b/ace/core/retriever.py @@ -0,0 +1,69 @@ +from playbook_utils import parse_playbook_line, format_playbook_line +from sentence_transformers import SentenceTransformer +import numpy as np + + +class Retriever: + def __init__(self, model_name: str = "intfloat/multilingual-e5-large", top_k: int = 5): + self.model = SentenceTransformer(model_name) + self.top_k = top_k + + def retrieve(self, playbook: str, question: str, context: str = "", top_k: int = None) -> str: + top_k = top_k if top_k is not None else self.top_k + bullets_with_sections = self._extract_bullets_with_sections(playbook) + if not bullets_with_sections: + return "" + + top_k = min(top_k, len(bullets_with_sections)) + + passage_texts = [ + "passage: " + b["content"] for b in bullets_with_sections + ] + passage_embeddings = self.model.encode(passage_texts, normalize_embeddings=True) + + query_text = "query: " + question + " " + context + query_embedding = self.model.encode(query_text, normalize_embeddings=True) + + similarities = np.dot(passage_embeddings, query_embedding) + top_k_indices = set(np.argsort(similarities)[-top_k:]) + + section_order = list(dict.fromkeys( + b["section"] for b in bullets_with_sections + )) + + section_bullets: dict[str, list[str]] = {s: [] for s in section_order} + for i, b in enumerate(bullets_with_sections): + if i in top_k_indices: + line = format_playbook_line(b["id"], b["helpful"], b["harmful"], b["content"]) + section_bullets[b["section"]].append(line) + + lines = [] + for section in section_order: + if section_bullets[section]: + lines.append(section) + lines.extend(section_bullets[section]) + lines.append("") + + print("Used bullets:") + for i, b in enumerate(bullets_with_sections): + if i in top_k_indices: + print(f" {b['id']}: {b['content']}") + + print("Out of total bullets:") + print(f" {len(bullets_with_sections)}") + + return "\n".join(lines).rstrip() + + # To make sure we go by sections and we keep the section order + def _extract_bullets_with_sections(self, playbook: str) -> list: + results = [] + current_section = "" + for line in playbook.strip().split("\n"): + if line.strip().startswith("##"): + current_section = line.strip() + else: + parsed = parse_playbook_line(line) + if parsed: + parsed["section"] = current_section + results.append(parsed) + return results diff --git a/eval/finance/run.py b/eval/finance/run.py index 86535536..85852771 100644 --- a/eval/finance/run.py +++ b/eval/finance/run.py @@ -75,6 +75,10 @@ def parse_args(): parser.add_argument("--bulletpoint_analyzer_threshold", type=float, default=0.90, help="Similarity threshold for bulletpoint analyzer (0-1, default: 0.90)") + # Retriever configuration + parser.add_argument("--retriever_top_k", type=int, default=5, + help="Number of top bullets to retrieve per question (default: 5)") + # Output configuration parser.add_argument("--save_path", type=str, required=True, help="Directory to save results") @@ -202,7 +206,8 @@ def main(): max_tokens=args.max_tokens, initial_playbook=initial_playbook, use_bulletpoint_analyzer=args.use_bulletpoint_analyzer, - bulletpoint_analyzer_threshold=args.bulletpoint_analyzer_threshold + bulletpoint_analyzer_threshold=args.bulletpoint_analyzer_threshold, + retriever_top_k=args.retriever_top_k ) # Prepare configuration diff --git a/playbook_utils.py b/playbook_utils.py index a457ba96..b445b0e9 100644 --- a/playbook_utils.py +++ b/playbook_utils.py @@ -369,3 +369,17 @@ def extract_playbook_bullets(playbook_text, bullet_ids): formatted_bullets.append(f"[{bullet['id']}] helpful={bullet['helpful']} harmful={bullet['harmful']} :: {bullet['content']}") return '\n'.join(formatted_bullets) + + +def extract_all_bullets(playbook_text): + """ + Extract all bullet points from playbook. + """ + lines = playbook_text.strip().split('\n') + all_bullets = [] + for line in lines: + if line.strip(): + parsed = parse_playbook_line(line) + if parsed: + all_bullets.append(parsed) + return all_bullets \ No newline at end of file diff --git a/run_topk_sweep.sh b/run_topk_sweep.sh new file mode 100755 index 00000000..3d001f3a --- /dev/null +++ b/run_topk_sweep.sh @@ -0,0 +1,53 @@ +#!/bin/bash +set -e + +TOP_K_VALUES=(5 10 20) +PIDS=() + +for top_k in "${TOP_K_VALUES[@]}"; do + echo "Launching run with retriever_top_k=${top_k}..." + + python -m eval.finance.run \ + --task_name finer \ + --mode offline \ + --save_path "results/topk_${top_k}" \ + --api_provider openai \ + --generator_model gpt-4o \ + --reflector_model gpt-4o \ + --curator_model gpt-4o \ + --retriever_top_k "${top_k}" \ + > "results/topk_${top_k}.log" 2>&1 & + + PIDS+=($!) + echo " PID=$! -> results/topk_${top_k}/ (log: results/topk_${top_k}.log)" +done + +echo "" +echo "All ${#TOP_K_VALUES[@]} runs launched in parallel." +echo "Waiting for completion..." +echo "" + +FAILED=0 +for i in "${!PIDS[@]}"; do + pid=${PIDS[$i]} + top_k=${TOP_K_VALUES[$i]} + if wait "$pid"; then + echo "top_k=${top_k} (PID ${pid}) finished successfully" + else + echo "top_k=${top_k} (PID ${pid}) FAILED (exit code $?)" + FAILED=$((FAILED + 1)) + fi +done + +echo "" +echo "============================================" +if [ "$FAILED" -eq 0 ]; then + echo "All runs complete!" +else + echo "${FAILED} run(s) failed. Check logs." +fi +echo "Results:" +for top_k in "${TOP_K_VALUES[@]}"; do + echo " - results/topk_${top_k}/ (log: results/topk_${top_k}.log)" +done +echo "============================================" From 5f5690fb2b915cb05dd015516cc17df3b77b8e36 Mon Sep 17 00:00:00 2001 From: Shyam Sai Bethina Date: Fri, 6 Mar 2026 05:08:03 +0000 Subject: [PATCH 2/5] first draft --- ace/ace.py | 30 +++++++++++++++------------ ace/core/retriever.py | 47 +++++++++++++++++++++--------------------- eval/finance/run.py | 11 +++++----- utils.py | 48 ++++++++++++++++++++++++++++++++----------- 4 files changed, 82 insertions(+), 54 deletions(-) diff --git a/ace/ace.py b/ace/ace.py index 77ec5a42..e5bd8a9c 100644 --- a/ace/ace.py +++ b/ace/ace.py @@ -279,7 +279,8 @@ def run( config=config, log_dir=log_dir, save_path=save_path, - prefix="final" + prefix="final", + use_retriever=True ) results['final_test_results'] = final_test_results print(f"Final Test Accuracy: {final_test_results['accuracy']:.3f}\n") @@ -329,7 +330,8 @@ def run( config=config, log_dir=log_dir, save_path=save_path, - prefix="test" + prefix="test", + use_retriever=config.get('use_retriever', False) ) results['test_results'] = test_results @@ -366,7 +368,8 @@ def _run_test( config: Dict[str, Any], log_dir: str, save_path: str, - prefix: str = "test" + prefix: str = "test", + use_retriever: bool = False ) -> Dict[str, Any]: """ Run testing @@ -379,6 +382,7 @@ def _run_test( log_dir: Directory for detailed logs save_path: Path to save results prefix: Prefix for saved files (e.g., 'initial', 'final', 'test') + use_retriever: If True, use retriever to build per-sample mini playbooks Returns: Dictionary with test results @@ -386,6 +390,11 @@ def _run_test( config_params = self._extract_config_params(config) use_json_mode = config_params['use_json_mode'] test_workers = config_params['test_workers'] + + retriever = None + if use_retriever: + self.retriever.index_playbook(playbook) + retriever = self.retriever test_results, test_error_log = evaluate_test_set( data_processor, @@ -395,7 +404,8 @@ def _run_test( self.max_tokens, log_dir, max_workers=test_workers, - use_json_mode=use_json_mode + use_json_mode=use_json_mode, + retriever=retriever ) # Save test results @@ -449,13 +459,11 @@ def _train_single_sample( context = task_dict.get("context", "") target = task_dict.get("target", "") - retrieved_playbook = self.retriever.retrieve(self.playbook, question, context) - # STEP 1: Initial generation (pre-train) print("Generating initial answer...") gen_response, bullet_ids, call_info = self.generator.generate( question=question, - playbook=retrieved_playbook, + playbook=self.playbook, context=context, reflection="(empty)", use_json_mode=use_json_mode, @@ -517,12 +525,10 @@ def _train_single_sample( self.playbook, bullet_tags ) - retrieved_playbook = self.retriever.retrieve(self.playbook, question, context) - # Regenerate with reflection gen_response, bullet_ids, _ = self.generator.generate( question=question, - playbook=retrieved_playbook, + playbook=self.playbook, context=context, reflection=reflection_content, use_json_mode=use_json_mode, @@ -598,12 +604,10 @@ def _train_single_sample( merge=True ) - retrieved_playbook = self.retriever.retrieve(self.playbook, question, context) - # STEP 4: Post-curator generation gen_response, _, _ = self.generator.generate( question=question, - playbook=retrieved_playbook, + playbook=self.playbook, context=context, reflection="(empty)", use_json_mode=use_json_mode, diff --git a/ace/core/retriever.py b/ace/core/retriever.py index c7e28964..8fdab759 100644 --- a/ace/core/retriever.py +++ b/ace/core/retriever.py @@ -5,34 +5,42 @@ class Retriever: def __init__(self, model_name: str = "intfloat/multilingual-e5-large", top_k: int = 5): - self.model = SentenceTransformer(model_name) + self.model = SentenceTransformer(model_name, device="cpu") self.top_k = top_k + self.bullets_with_sections: list = [] + self.passage_embeddings: np.ndarray | None = None - def retrieve(self, playbook: str, question: str, context: str = "", top_k: int = None) -> str: - top_k = top_k if top_k is not None else self.top_k - bullets_with_sections = self._extract_bullets_with_sections(playbook) - if not bullets_with_sections: - return "" - - top_k = min(top_k, len(bullets_with_sections)) - + def index_playbook(self, playbook: str) -> None: + """Parse playbook bullets and pre-compute their embeddings.""" + self.bullets_with_sections = self._extract_bullets_with_sections(playbook) + if not self.bullets_with_sections: + self.passage_embeddings = None + return passage_texts = [ - "passage: " + b["content"] for b in bullets_with_sections + "passage: " + b["content"] for b in self.bullets_with_sections ] - passage_embeddings = self.model.encode(passage_texts, normalize_embeddings=True) + self.passage_embeddings = self.model.encode(passage_texts, normalize_embeddings=True) + + def retrieve(self, question: str, context: str = "", top_k: int | None = None) -> str: + """Return a mini-playbook string containing the top-k most relevant bullets.""" + if not self.bullets_with_sections or self.passage_embeddings is None: + return "" + + top_k = top_k if top_k is not None else self.top_k + top_k = min(top_k, len(self.bullets_with_sections)) query_text = "query: " + question + " " + context query_embedding = self.model.encode(query_text, normalize_embeddings=True) - similarities = np.dot(passage_embeddings, query_embedding) + similarities = np.dot(self.passage_embeddings, query_embedding) top_k_indices = set(np.argsort(similarities)[-top_k:]) section_order = list(dict.fromkeys( - b["section"] for b in bullets_with_sections + b["section"] for b in self.bullets_with_sections )) section_bullets: dict[str, list[str]] = {s: [] for s in section_order} - for i, b in enumerate(bullets_with_sections): + for i, b in enumerate(self.bullets_with_sections): if i in top_k_indices: line = format_playbook_line(b["id"], b["helpful"], b["harmful"], b["content"]) section_bullets[b["section"]].append(line) @@ -44,17 +52,8 @@ def retrieve(self, playbook: str, question: str, context: str = "", top_k: int = lines.extend(section_bullets[section]) lines.append("") - print("Used bullets:") - for i, b in enumerate(bullets_with_sections): - if i in top_k_indices: - print(f" {b['id']}: {b['content']}") - - print("Out of total bullets:") - print(f" {len(bullets_with_sections)}") - return "\n".join(lines).rstrip() - - # To make sure we go by sections and we keep the section order + def _extract_bullets_with_sections(self, playbook: str) -> list: results = [] current_section = "" diff --git a/eval/finance/run.py b/eval/finance/run.py index 85852771..aa1c5f45 100644 --- a/eval/finance/run.py +++ b/eval/finance/run.py @@ -75,9 +75,9 @@ def parse_args(): parser.add_argument("--bulletpoint_analyzer_threshold", type=float, default=0.90, help="Similarity threshold for bulletpoint analyzer (0-1, default: 0.90)") - # Retriever configuration - parser.add_argument("--retriever_top_k", type=int, default=5, - help="Number of top bullets to retrieve per question (default: 5)") + # Retriever configuration — passing --retriever_top_k enables retrieval + parser.add_argument("--retriever_top_k", type=int, default=None, + help="Number of top bullets to retrieve per question. Enables retrieval when set.") # Output configuration parser.add_argument("--save_path", type=str, required=True, @@ -207,7 +207,7 @@ def main(): initial_playbook=initial_playbook, use_bulletpoint_analyzer=args.use_bulletpoint_analyzer, bulletpoint_analyzer_threshold=args.bulletpoint_analyzer_threshold, - retriever_top_k=args.retriever_top_k + retriever_top_k=args.retriever_top_k or 5 ) # Prepare configuration @@ -228,7 +228,8 @@ def main(): 'initial_playbook_path': args.initial_playbook_path, 'use_bulletpoint_analyzer': args.use_bulletpoint_analyzer, 'bulletpoint_analyzer_threshold': args.bulletpoint_analyzer_threshold, - 'api_provider': args.api_provider + 'api_provider': args.api_provider, + 'use_retriever': args.retriever_top_k is not None } # Execute using the unified run method diff --git a/utils.py b/utils.py index 01ef7bdc..f8a3b60c 100644 --- a/utils.py +++ b/utils.py @@ -38,7 +38,7 @@ def initialize_clients(api_provider): reflector_client = openai.OpenAI(api_key=api_key, base_url=base_url) curator_client = openai.OpenAI(api_key=api_key, base_url=base_url) - print("Using Together API for all models") + print(f"Using {api_provider} API for all models") return generator_client, reflector_client, curator_client def get_section_slug(section_name): @@ -152,13 +152,15 @@ def count_tokens(prompt: str) -> int: return len(enc.encode(prompt)) -def evaluate_single_test_sample(args_tuple, data_processor) -> Tuple[Dict, str]: +def evaluate_single_test_sample(args_tuple, data_processor, retriever=None) -> Tuple[Dict, str]: """ Evaluate a single test sample - task-agnostic implementation. Args: args_tuple: Tuple of (index, task_dict, generator, playbook, max_tokens, log_dir, use_json_mode) data_processor: DataProcessor instance with answer_is_correct method + retriever: Optional Retriever instance. When provided, retrieves a per-sample + mini playbook instead of using the full playbook. """ (i, task_dict, generator, playbook, max_tokens, log_dir, use_json_mode) = args_tuple try: @@ -166,6 +168,11 @@ def evaluate_single_test_sample(args_tuple, data_processor) -> Tuple[Dict, str]: question = task_dict["question"] target = task_dict["target"] + sub_playbook_tokens = None + if retriever is not None: + playbook = retriever.retrieve(question, context) + sub_playbook_tokens = count_tokens(playbook) + gen_response, bullet_ids, call_info = generator.generate( question=question, playbook=playbook, @@ -179,13 +186,17 @@ def evaluate_single_test_sample(args_tuple, data_processor) -> Tuple[Dict, str]: final_answer = extract_answer(gen_response) is_correct = data_processor.answer_is_correct(final_answer, target) - return { + result = { "index": i, "final_answer": final_answer, "target": target, "is_correct": is_correct, "success": True - }, None + } + if sub_playbook_tokens is not None: + result["sub_playbook_tokens"] = sub_playbook_tokens + + return result, None except Exception as e: return None, f"Error evaluating sample {i}: {type(e).__name__}: {str(e)}" @@ -193,25 +204,29 @@ def evaluate_single_test_sample(args_tuple, data_processor) -> Tuple[Dict, str]: def evaluate_test_set(data_processor, generator, playbook, test_samples, max_tokens=4096, log_dir=None, max_workers=20, - use_json_mode=False) -> Tuple[Dict, Dict]: + use_json_mode=False, retriever=None) -> Tuple[Dict, Dict]: """ Parallel evaluation of test set - task-agnostic implementation. Args: data_processor: DataProcessor instance with answer_is_correct and evaluate_accuracy methods generator: Generator instance - playbook: Current playbook string + playbook: Current playbook string (used as fallback when retriever is None) test_samples: List of test samples max_tokens: Max tokens for generation log_dir: Directory for logs max_workers: Number of parallel workers use_json_mode: Whether to use JSON mode + retriever: Optional Retriever instance (already indexed). When provided, + each sample gets a per-sample mini playbook via retrieval. Returns: Tuple of (results_dict, error_logs_dict) """ print(f"\n{'='*40}") print(f"EVALUATING TEST SET - {len(test_samples)} samples, {max_workers} workers") + if retriever is not None: + print(f" Using retriever (top_k={retriever.top_k})") print(f"{'='*40}") args_list = [ @@ -221,12 +236,12 @@ def evaluate_test_set(data_processor, generator, playbook, test_samples, results = { "correct": 0, "total": 0, "no_answer": 0, - "answers": [], "targets": [], "errors": [] + "answers": [], "targets": [], "errors": [], + "sub_playbook_token_counts": [] } - # Use a wrapper to pass data_processor to the evaluation function def eval_wrapper(args_tuple): - return evaluate_single_test_sample(args_tuple, data_processor) + return evaluate_single_test_sample(args_tuple, data_processor, retriever=retriever) with ThreadPoolExecutor(max_workers=max_workers) as executor: future_to_args = { @@ -246,6 +261,9 @@ def eval_wrapper(args_tuple): results["total"] += 1 results["answers"].append(result["final_answer"]) results["targets"].append(result["target"]) + + if "sub_playbook_tokens" in result: + results["sub_playbook_token_counts"].append(result["sub_playbook_tokens"]) if not result["is_correct"]: results["errors"].append({ @@ -270,16 +288,22 @@ def eval_wrapper(args_tuple): "total": results["total"], "no_answer": results["no_answer"] } + + token_counts = results["sub_playbook_token_counts"] + if token_counts: + avg_tokens = round(sum(token_counts) / len(token_counts), 1) + final_results["avg_sub_playbook_tokens"] = avg_tokens + print(f"\n Avg sub-playbook tokens: {avg_tokens}") error_logs = { "accuracy": accuracy, "errors": results["errors"] } - print(f"\nšŸ“Š Final Accuracy: {accuracy:.3f} ({results['correct']}/{results['total']})") + print(f"\n Final Accuracy: {accuracy:.3f} ({results['correct']}/{results['total']})") else: - results = {"accuracy": 0.0, "correct": 0, "total": 0} + final_results = {"accuracy": 0.0, "correct": 0, "total": 0} error_logs = {} - print(f"\nšŸ“Š No valid results!") + print(f"\n No valid results!") return final_results, error_logs \ No newline at end of file From 800e73fd933c7c5875a07698591d6290bb041bcf Mon Sep 17 00:00:00 2001 From: Shyam Sai Bethina Date: Thu, 12 Mar 2026 01:06:27 +0000 Subject: [PATCH 3/5] added more arguments to config --- README.md | 16 ++++++++++++++-- ace/ace.py | 10 +++++++--- eval/finance/run.py | 9 +++++++-- 3 files changed, 28 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index c9697711..d5ffdea0 100644 --- a/README.md +++ b/README.md @@ -120,8 +120,10 @@ config = { 'save_dir': './results', 'test_workers': 20, 'use_bulletpoint_analyzer': false, - 'api_provider': api_provider - + 'api_provider': api_provider, + 'use_retriever': False, + 'retriever_top_k': 10, + 'retriever_model_name': 'intfloat/multilingual-e5-large' } # Offline adaptation @@ -177,6 +179,14 @@ python -m eval.finance.run \ --initial_playbook_path results/ace_run_TIMESTAMP_finer_offline/best_playbook.txt \ --save_path test_results +# Evaluation with retrieval-based sub-playbooks (top-k bullets per sample) +python -m eval.finance.run \ + --task_name finer \ + --mode eval_only \ + --initial_playbook_path results/ace_run_TIMESTAMP_finer_offline/best_playbook.txt \ + --retriever_top_k 10 \ + --save_path test_results_topk10 + # Training with custom configuration python -m eval.finance.run \ --task_name finer \ @@ -214,6 +224,8 @@ python -m eval.finance.run \ | `--no_ground_truth` | Don't use ground truth in reflection | False | | `--use_bulletpoint_analyzer` | Enable bulletpoint analyzer for playbook deduplication and merging | False | | `--bulletpoint_analyzer_threshold` | Similarity threshold for bulletpoint analyzer (0-1) | 0.9 | +| `--retriever_top_k` | Number of top bullets to retrieve per sample. Enables retrieval when set. | None (disabled) | +| `--retriever_model_name` | Sentence-transformers model for retrieval embeddings | `intfloat/multilingual-e5-large` | diff --git a/ace/ace.py b/ace/ace.py index e5bd8a9c..4f64bf94 100644 --- a/ace/ace.py +++ b/ace/ace.py @@ -40,7 +40,8 @@ def __init__( initial_playbook: Optional[str] = None, use_bulletpoint_analyzer: bool = False, bulletpoint_analyzer_threshold: float = 0.90, - retriever_top_k: int = 5 + retriever_top_k: int = 5, + retriever_model_name: str = "intfloat/multilingual-e5-large" ): # Initialize API clients generator_client, reflector_client, curator_client = initialize_clients(api_provider) @@ -49,7 +50,7 @@ def __init__( self.generator = Generator(generator_client, api_provider, generator_model, max_tokens) self.reflector = Reflector(reflector_client, api_provider, reflector_model, max_tokens) self.curator = Curator(curator_client, api_provider, curator_model, max_tokens) - self.retriever = Retriever(top_k=retriever_top_k) + self.retriever = Retriever(model_name=retriever_model_name, top_k=retriever_top_k) # Initialize bulletpoint analyzer if requested and available self.use_bulletpoint_analyzer = use_bulletpoint_analyzer @@ -120,7 +121,10 @@ def _extract_config_params(self, config: Dict[str, Any]) -> Dict[str, Any]: 'save_dir': config.get('save_dir', './results'), 'test_workers': config.get('test_workers', 20), 'use_bulletpoint_analyzer': config.get('use_bulletpoint_analyzer', False), - 'bulletpoint_analyzer_threshold': config.get('bulletpoint_analyzer_threshold', 0.90) + 'bulletpoint_analyzer_threshold': config.get('bulletpoint_analyzer_threshold', 0.90), + 'use_retriever': config.get('use_retriever', False), + 'retriever_top_k': config.get('retriever_top_k', 5), + 'retriever_model_name': config.get('retriever_model_name', 'intfloat/multilingual-e5-large') } def _setup_paths(self, save_dir: str, task_name: str, mode: str) -> Tuple[str, str]: diff --git a/eval/finance/run.py b/eval/finance/run.py index aa1c5f45..65af4d6d 100644 --- a/eval/finance/run.py +++ b/eval/finance/run.py @@ -78,6 +78,8 @@ def parse_args(): # Retriever configuration — passing --retriever_top_k enables retrieval parser.add_argument("--retriever_top_k", type=int, default=None, help="Number of top bullets to retrieve per question. Enables retrieval when set.") + parser.add_argument("--retriever_model_name", type=str, default="intfloat/multilingual-e5-large", + help="Sentence-transformers model for retrieval embeddings") # Output configuration parser.add_argument("--save_path", type=str, required=True, @@ -207,7 +209,8 @@ def main(): initial_playbook=initial_playbook, use_bulletpoint_analyzer=args.use_bulletpoint_analyzer, bulletpoint_analyzer_threshold=args.bulletpoint_analyzer_threshold, - retriever_top_k=args.retriever_top_k or 5 + retriever_top_k=args.retriever_top_k or 5, + retriever_model_name=args.retriever_model_name ) # Prepare configuration @@ -229,7 +232,9 @@ def main(): 'use_bulletpoint_analyzer': args.use_bulletpoint_analyzer, 'bulletpoint_analyzer_threshold': args.bulletpoint_analyzer_threshold, 'api_provider': args.api_provider, - 'use_retriever': args.retriever_top_k is not None + 'use_retriever': args.retriever_top_k is not None, + 'retriever_top_k': args.retriever_top_k, + 'retriever_model_name': args.retriever_model_name } # Execute using the unified run method From 9a6a5074e570402bf5f8458be1bdbc188e3ff30d Mon Sep 17 00:00:00 2001 From: Shyam Sai Bethina Date: Thu, 12 Mar 2026 01:14:36 +0000 Subject: [PATCH 4/5] removing the script file I made --- run_topk_sweep.sh | 53 ----------------------------------------------- 1 file changed, 53 deletions(-) delete mode 100755 run_topk_sweep.sh diff --git a/run_topk_sweep.sh b/run_topk_sweep.sh deleted file mode 100755 index 3d001f3a..00000000 --- a/run_topk_sweep.sh +++ /dev/null @@ -1,53 +0,0 @@ -#!/bin/bash -set -e - -TOP_K_VALUES=(5 10 20) -PIDS=() - -for top_k in "${TOP_K_VALUES[@]}"; do - echo "Launching run with retriever_top_k=${top_k}..." - - python -m eval.finance.run \ - --task_name finer \ - --mode offline \ - --save_path "results/topk_${top_k}" \ - --api_provider openai \ - --generator_model gpt-4o \ - --reflector_model gpt-4o \ - --curator_model gpt-4o \ - --retriever_top_k "${top_k}" \ - > "results/topk_${top_k}.log" 2>&1 & - - PIDS+=($!) - echo " PID=$! -> results/topk_${top_k}/ (log: results/topk_${top_k}.log)" -done - -echo "" -echo "All ${#TOP_K_VALUES[@]} runs launched in parallel." -echo "Waiting for completion..." -echo "" - -FAILED=0 -for i in "${!PIDS[@]}"; do - pid=${PIDS[$i]} - top_k=${TOP_K_VALUES[$i]} - if wait "$pid"; then - echo "top_k=${top_k} (PID ${pid}) finished successfully" - else - echo "top_k=${top_k} (PID ${pid}) FAILED (exit code $?)" - FAILED=$((FAILED + 1)) - fi -done - -echo "" -echo "============================================" -if [ "$FAILED" -eq 0 ]; then - echo "All runs complete!" -else - echo "${FAILED} run(s) failed. Check logs." -fi -echo "Results:" -for top_k in "${TOP_K_VALUES[@]}"; do - echo " - results/topk_${top_k}/ (log: results/topk_${top_k}.log)" -done -echo "============================================" From 8ff1290d5cccba6b957f6138d037d2103cf97a28 Mon Sep 17 00:00:00 2001 From: Shyam Sai Bethina Date: Thu, 12 Mar 2026 01:17:30 +0000 Subject: [PATCH 5/5] Added retriever args to the comment --- ace/ace.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/ace/ace.py b/ace/ace.py index 4f64bf94..640e8861 100644 --- a/ace/ace.py +++ b/ace/ace.py @@ -43,6 +43,21 @@ def __init__( retriever_top_k: int = 5, retriever_model_name: str = "intfloat/multilingual-e5-large" ): + """ + Initialize the ACE system. + + Args: + api_provider: API provider for LLM calls + generator_model: Model name for generator + reflector_model: Model name for reflector + curator_model: Model name for curator + max_tokens: Maximum tokens for LLM calls + initial_playbook: Initial playbook content (optional) + use_bulletpoint_analyzer: Whether to use bulletpoint analyzer for deduplication + bulletpoint_analyzer_threshold: Similarity threshold for bulletpoint analyzer (0-1) + retriever_top_k: Number of top bullets to retrieve per sample + retriever_model_name: Sentence-transformers model for retrieval embeddings + """ # Initialize API clients generator_client, reflector_client, curator_client = initialize_clients(api_provider)