diff --git a/README.md b/README.md
index 95b1c04b..f968d5db 100644
--- a/README.md
+++ b/README.md
@@ -124,8 +124,10 @@ config = {
     'save_dir': './results',
     'test_workers': 20,
     'use_bulletpoint_analyzer': False,
-    'api_provider': api_provider
-
+    'api_provider': api_provider,
+    'use_retriever': False,
+    'retriever_top_k': 10,
+    'retriever_model_name': 'intfloat/multilingual-e5-large'
 }
 
 # Offline adaptation
@@ -181,6 +183,14 @@ uv run python -m eval.finance.run \
     --initial_playbook_path results/ace_run_TIMESTAMP_finer_offline/best_playbook.txt \
     --save_path test_results
 
+# Evaluation with retrieval-based sub-playbooks (top-k bullets per sample)
+uv run python -m eval.finance.run \
+    --task_name finer \
+    --mode eval_only \
+    --initial_playbook_path results/ace_run_TIMESTAMP_finer_offline/best_playbook.txt \
+    --retriever_top_k 10 \
+    --save_path test_results_topk10
+
 # Training with custom configuration
 uv run python -m eval.finance.run \
     --task_name finer \
@@ -218,6 +228,8 @@ uv run python -m eval.finance.run \
 | `--no_ground_truth` | Don't use ground truth in reflection | False |
 | `--use_bulletpoint_analyzer` | Enable bulletpoint analyzer for playbook deduplication and merging | False |
 | `--bulletpoint_analyzer_threshold` | Similarity threshold for bulletpoint analyzer (0-1) | 0.9 |
+| `--retriever_top_k` | Number of top bullets to retrieve per sample. Enables retrieval when set. | None (disabled) |
+| `--retriever_model_name` | Sentence-transformers model for retrieval embeddings | `intfloat/multilingual-e5-large` |
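The README rows and config keys above encode an opt-in contract: retrieval stays disabled unless `--retriever_top_k` is passed. A minimal sketch of that mapping, mirroring the `use_retriever` derivation that `eval/finance/run.py` performs further down in this diff (`retriever_config` is a hypothetical helper, not part of the codebase; the real code builds the dict inline in `main()`):

```python
# Sketch: how the CLI flag maps onto the new config keys.
def retriever_config(retriever_top_k, model_name="intfloat/multilingual-e5-large"):
    return {
        "use_retriever": retriever_top_k is not None,  # flag set => retrieval on
        "retriever_top_k": retriever_top_k,
        "retriever_model_name": model_name,
    }

assert retriever_config(None)["use_retriever"] is False  # flag omitted
assert retriever_config(10)["use_retriever"] is True     # --retriever_top_k 10
```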
diff --git a/ace/ace.py b/ace/ace.py
index 2d662adc..640e8861 100644
--- a/ace/ace.py
+++ b/ace/ace.py
@@ -13,7 +13,7 @@
 from datetime import datetime
 from typing import Dict, List, Tuple, Optional, Any
 
-from .core import Generator, Reflector, Curator, BulletpointAnalyzer
+from .core import Generator, Reflector, Curator, BulletpointAnalyzer, Retriever
 from playbook_utils import *
 from logger import *
 from utils import *
@@ -39,7 +39,9 @@ def __init__(
         max_tokens: int = 4096,
         initial_playbook: Optional[str] = None,
         use_bulletpoint_analyzer: bool = False,
-        bulletpoint_analyzer_threshold: float = 0.90
+        bulletpoint_analyzer_threshold: float = 0.90,
+        retriever_top_k: int = 5,
+        retriever_model_name: str = "intfloat/multilingual-e5-large"
     ):
         """
         Initialize the ACE system.
@@ -53,6 +55,8 @@ def __init__(
             initial_playbook: Initial playbook content (optional)
             use_bulletpoint_analyzer: Whether to use bulletpoint analyzer for deduplication
             bulletpoint_analyzer_threshold: Similarity threshold for bulletpoint analyzer (0-1)
+            retriever_top_k: Number of top bullets to retrieve per sample
+            retriever_model_name: Sentence-transformers model for retrieval embeddings
         """
         # Initialize API clients
         generator_client, reflector_client, curator_client = initialize_clients(api_provider)
@@ -61,6 +65,7 @@ def __init__(
         self.generator = Generator(generator_client, api_provider, generator_model, max_tokens)
         self.reflector = Reflector(reflector_client, api_provider, reflector_model, max_tokens)
         self.curator = Curator(curator_client, api_provider, curator_model, max_tokens)
+        self.retriever = Retriever(model_name=retriever_model_name, top_k=retriever_top_k)
 
         # Initialize bulletpoint analyzer if requested and available
         self.use_bulletpoint_analyzer = use_bulletpoint_analyzer
@@ -131,7 +136,10 @@ def _extract_config_params(self, config: Dict[str, Any]) -> Dict[str, Any]:
             'save_dir': config.get('save_dir', './results'),
             'test_workers': config.get('test_workers', 20),
             'use_bulletpoint_analyzer': config.get('use_bulletpoint_analyzer', False),
-            'bulletpoint_analyzer_threshold': config.get('bulletpoint_analyzer_threshold', 0.90)
+            'bulletpoint_analyzer_threshold': config.get('bulletpoint_analyzer_threshold', 0.90),
+            'use_retriever': config.get('use_retriever', False),
+            'retriever_top_k': config.get('retriever_top_k', 5),
+            'retriever_model_name': config.get('retriever_model_name', 'intfloat/multilingual-e5-large')
         }
 
     def _setup_paths(self, save_dir: str, task_name: str, mode: str) -> Tuple[str, str]:
@@ -290,7 +298,8 @@ def run(
                 config=config,
                 log_dir=log_dir,
                 save_path=save_path,
-                prefix="final"
+                prefix="final",
+                use_retriever=config.get('use_retriever', False)
             )
             results['final_test_results'] = final_test_results
             print(f"Final Test Accuracy: {final_test_results['accuracy']:.3f}\n")
@@ -340,7 +349,8 @@ def run(
                 config=config,
                 log_dir=log_dir,
                 save_path=save_path,
-                prefix="test"
+                prefix="test",
+                use_retriever=config.get('use_retriever', False)
             )
             results['test_results'] = test_results
 
@@ -377,7 +387,8 @@ def _run_test(
         config: Dict[str, Any],
         log_dir: str,
         save_path: str,
-        prefix: str = "test"
+        prefix: str = "test",
+        use_retriever: bool = False
     ) -> Dict[str, Any]:
         """
         Run testing
@@ -390,6 +401,7 @@ def _run_test(
             log_dir: Directory for detailed logs
             save_path: Path to save results
             prefix: Prefix for saved files (e.g., 'initial', 'final', 'test')
+            use_retriever: If True, use retriever to build per-sample mini playbooks
 
         Returns:
             Dictionary with test results
@@ -397,6 +409,11 @@ def _run_test(
         config_params = self._extract_config_params(config)
         use_json_mode = config_params['use_json_mode']
         test_workers = config_params['test_workers']
+
+        retriever = None
+        if use_retriever:
+            self.retriever.index_playbook(playbook)
+            retriever = self.retriever
 
         test_results, test_error_log = evaluate_test_set(
             data_processor,
@@ -406,7 +423,8 @@ def _run_test(
             self.max_tokens,
             log_dir,
             max_workers=test_workers,
-            use_json_mode=use_json_mode
+            use_json_mode=use_json_mode,
+            retriever=retriever
         )
 
         # Save test results
@@ -459,7 +477,7 @@ def _train_single_sample(
         question = task_dict.get("question", "")
         context = task_dict.get("context", "")
         target = task_dict.get("target", "")
-        
+
         # STEP 1: Initial generation (pre-train)
         print("Generating initial answer...")
         gen_response, bullet_ids, call_info = self.generator.generate(
@@ -525,7 +543,7 @@ def _train_single_sample(
             self.playbook = update_bullet_counts(
                 self.playbook, bullet_tags
             )
-            
+
             # Regenerate with reflection
             gen_response, bullet_ids, _ = self.generator.generate(
                 question=question,
@@ -604,7 +622,7 @@ def _train_single_sample(
                 threshold=self.bulletpoint_analyzer_threshold,
                 merge=True
             )
-            
+
             # STEP 4: Post-curator generation
             gen_response, _, _ = self.generator.generate(
                 question=question,
diff --git a/ace/core/__init__.py b/ace/core/__init__.py
index 6d4d0926..7f2ece92 100644
--- a/ace/core/__init__.py
+++ b/ace/core/__init__.py
@@ -7,5 +7,6 @@
 from .reflector import Reflector
 from .curator import Curator
 from .bulletpoint_analyzer import BulletpointAnalyzer, DEDUP_AVAILABLE
+from .retriever import Retriever
 
-__all__ = ['Generator', 'Reflector', 'Curator', 'BulletpointAnalyzer', 'DEDUP_AVAILABLE']
\ No newline at end of file
+__all__ = ['Generator', 'Reflector', 'Curator', 'BulletpointAnalyzer', 'DEDUP_AVAILABLE', 'Retriever']
\ No newline at end of file
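The new `ace/core/retriever.py` below ranks bullets by a plain dot product between the query embedding and the pre-computed bullet embeddings. That is sound because `encode(..., normalize_embeddings=True)` returns unit-length vectors, so the dot product equals cosine similarity; the `"query: "` / `"passage: "` prefixes follow the asymmetric-retrieval convention the E5 model family expects. A quick self-contained check of the identity:

```python
import numpy as np

# For unit-length vectors, dot product == cosine similarity.
a = np.array([3.0, 4.0])
b = np.array([4.0, 3.0])
a /= np.linalg.norm(a)  # (0.6, 0.8)
b /= np.linalg.norm(b)  # (0.8, 0.6)
assert abs(float(a @ b) - 0.96) < 1e-12  # cos(angle) = 24/25
```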
diff --git a/ace/core/retriever.py b/ace/core/retriever.py
new file mode 100644
index 00000000..8fdab759
--- /dev/null
+++ b/ace/core/retriever.py
@@ -0,0 +1,68 @@
+from playbook_utils import parse_playbook_line, format_playbook_line
+from sentence_transformers import SentenceTransformer
+import numpy as np
+
+
+class Retriever:
+    def __init__(self, model_name: str = "intfloat/multilingual-e5-large", top_k: int = 5):
+        self.model = SentenceTransformer(model_name, device="cpu")
+        self.top_k = top_k
+        self.bullets_with_sections: list = []
+        self.passage_embeddings: np.ndarray | None = None
+
+    def index_playbook(self, playbook: str) -> None:
+        """Parse playbook bullets and pre-compute their embeddings."""
+        self.bullets_with_sections = self._extract_bullets_with_sections(playbook)
+        if not self.bullets_with_sections:
+            self.passage_embeddings = None
+            return
+        passage_texts = [
+            "passage: " + b["content"] for b in self.bullets_with_sections
+        ]
+        self.passage_embeddings = self.model.encode(passage_texts, normalize_embeddings=True)
+
+    def retrieve(self, question: str, context: str = "", top_k: int | None = None) -> str:
+        """Return a mini-playbook string containing the top-k most relevant bullets."""
+        if not self.bullets_with_sections or self.passage_embeddings is None:
+            return ""
+
+        top_k = top_k if top_k is not None else self.top_k
+        top_k = min(top_k, len(self.bullets_with_sections))
+
+        query_text = "query: " + question + " " + context
+        query_embedding = self.model.encode(query_text, normalize_embeddings=True)
+
+        similarities = np.dot(self.passage_embeddings, query_embedding)
+        top_k_indices = set(np.argsort(similarities)[-top_k:])
+
+        section_order = list(dict.fromkeys(
+            b["section"] for b in self.bullets_with_sections
+        ))
+
+        section_bullets: dict[str, list[str]] = {s: [] for s in section_order}
+        for i, b in enumerate(self.bullets_with_sections):
+            if i in top_k_indices:
+                line = format_playbook_line(b["id"], b["helpful"], b["harmful"], b["content"])
+                section_bullets[b["section"]].append(line)
+
+        lines = []
+        for section in section_order:
+            if section_bullets[section]:
+                lines.append(section)
+                lines.extend(section_bullets[section])
+                lines.append("")
+
+        return "\n".join(lines).rstrip()
+
+    def _extract_bullets_with_sections(self, playbook: str) -> list:
+        results = []
+        current_section = ""
+        for line in playbook.strip().split("\n"):
+            if line.strip().startswith("##"):
+                current_section = line.strip()
+            else:
+                parsed = parse_playbook_line(line)
+                if parsed:
+                    parsed["section"] = current_section
+                    results.append(parsed)
+        return results
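For reference, the intended call pattern, mirroring how `_run_test` in `ace/ace.py` wires the retriever in (`playbook_text`, `question`, and `context` are placeholder inputs, not repo fixtures):

```python
from ace.core import Retriever

retriever = Retriever(model_name="intfloat/multilingual-e5-large", top_k=5)
retriever.index_playbook(playbook_text)       # parse bullets + embed once per test run
mini = retriever.retrieve(question, context)  # per-sample top-k sub-playbook
```

Worth noting as a design choice: `retrieve()` emits the same section-header-plus-bullet format as the full playbook (via `format_playbook_line`), so downstream prompt construction needs no changes.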
diff --git a/eval/finance/run.py b/eval/finance/run.py
index 86535536..65af4d6d 100644
--- a/eval/finance/run.py
+++ b/eval/finance/run.py
@@ -75,6 +75,12 @@ def parse_args():
     parser.add_argument("--bulletpoint_analyzer_threshold", type=float, default=0.90,
                         help="Similarity threshold for bulletpoint analyzer (0-1, default: 0.90)")
 
+    # Retriever configuration: passing --retriever_top_k enables retrieval
+    parser.add_argument("--retriever_top_k", type=int, default=None,
+                        help="Number of top bullets to retrieve per question. Enables retrieval when set.")
+    parser.add_argument("--retriever_model_name", type=str, default="intfloat/multilingual-e5-large",
+                        help="Sentence-transformers model for retrieval embeddings")
+
     # Output configuration
     parser.add_argument("--save_path", type=str, required=True,
                         help="Directory to save results")
@@ -202,7 +208,9 @@ def main():
         max_tokens=args.max_tokens,
         initial_playbook=initial_playbook,
         use_bulletpoint_analyzer=args.use_bulletpoint_analyzer,
-        bulletpoint_analyzer_threshold=args.bulletpoint_analyzer_threshold
+        bulletpoint_analyzer_threshold=args.bulletpoint_analyzer_threshold,
+        retriever_top_k=args.retriever_top_k or 5,
+        retriever_model_name=args.retriever_model_name
     )
 
     # Prepare configuration
@@ -223,7 +231,10 @@ def main():
         'initial_playbook_path': args.initial_playbook_path,
         'use_bulletpoint_analyzer': args.use_bulletpoint_analyzer,
         'bulletpoint_analyzer_threshold': args.bulletpoint_analyzer_threshold,
-        'api_provider': args.api_provider
+        'api_provider': args.api_provider,
+        'use_retriever': args.retriever_top_k is not None,
+        'retriever_top_k': args.retriever_top_k,
+        'retriever_model_name': args.retriever_model_name
     }
 
     # Execute using the unified run method
diff --git a/playbook_utils.py b/playbook_utils.py
index a457ba96..b445b0e9 100644
--- a/playbook_utils.py
+++ b/playbook_utils.py
@@ -369,3 +369,17 @@ def extract_playbook_bullets(playbook_text, bullet_ids):
         formatted_bullets.append(f"[{bullet['id']}] helpful={bullet['helpful']} harmful={bullet['harmful']} :: {bullet['content']}")
 
     return '\n'.join(formatted_bullets)
+
+
+def extract_all_bullets(playbook_text):
+    """
+    Extract all bullet points from playbook.
+    """
+    lines = playbook_text.strip().split('\n')
+    all_bullets = []
+    for line in lines:
+        if line.strip():
+            parsed = parse_playbook_line(line)
+            if parsed:
+                all_bullets.append(parsed)
+    return all_bullets
\ No newline at end of file
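A usage sketch for the new `extract_all_bullets` helper (`playbook_text` is an illustrative input; like `Retriever._extract_bullets_with_sections`, it assumes `parse_playbook_line` returns a falsy value for headers and other non-bullet lines):

```python
from playbook_utils import extract_all_bullets

# Each bullet line yields one parsed dict; section headers and
# blank lines are skipped because parse_playbook_line rejects them.
bullets = extract_all_bullets(playbook_text)
print(len(bullets), [b["id"] for b in bullets][:3])
```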
diff --git a/utils.py b/utils.py
index 01ef7bdc..f8a3b60c 100644
--- a/utils.py
+++ b/utils.py
@@ -38,7 +38,7 @@ def initialize_clients(api_provider):
     reflector_client = openai.OpenAI(api_key=api_key, base_url=base_url)
     curator_client = openai.OpenAI(api_key=api_key, base_url=base_url)
 
-    print("Using Together API for all models")
+    print(f"Using {api_provider} API for all models")
     return generator_client, reflector_client, curator_client
 
 def get_section_slug(section_name):
@@ -152,13 +152,15 @@ def count_tokens(prompt: str) -> int:
     return len(enc.encode(prompt))
 
 
-def evaluate_single_test_sample(args_tuple, data_processor) -> Tuple[Dict, str]:
+def evaluate_single_test_sample(args_tuple, data_processor, retriever=None) -> Tuple[Dict, str]:
     """
     Evaluate a single test sample - task-agnostic implementation.
 
     Args:
         args_tuple: Tuple of (index, task_dict, generator, playbook, max_tokens, log_dir, use_json_mode)
        data_processor: DataProcessor instance with answer_is_correct method
+        retriever: Optional Retriever instance. When provided, retrieves a per-sample
+            mini playbook instead of using the full playbook.
     """
     (i, task_dict, generator, playbook, max_tokens, log_dir, use_json_mode) = args_tuple
     try:
@@ -166,6 +168,11 @@ def evaluate_single_test_sample(args_tuple, data_processor) -> Tuple[Dict, str]:
         question = task_dict["question"]
         target = task_dict["target"]
 
+        sub_playbook_tokens = None
+        if retriever is not None:
+            playbook = retriever.retrieve(question, context)
+            sub_playbook_tokens = count_tokens(playbook)
+
         gen_response, bullet_ids, call_info = generator.generate(
             question=question,
             playbook=playbook,
@@ -179,13 +186,17 @@ def evaluate_single_test_sample(args_tuple, data_processor) -> Tuple[Dict, str]:
         final_answer = extract_answer(gen_response)
         is_correct = data_processor.answer_is_correct(final_answer, target)
 
-        return {
+        result = {
             "index": i,
             "final_answer": final_answer,
             "target": target,
             "is_correct": is_correct,
             "success": True
-        }, None
+        }
+        if sub_playbook_tokens is not None:
+            result["sub_playbook_tokens"] = sub_playbook_tokens
+
+        return result, None
 
     except Exception as e:
         return None, f"Error evaluating sample {i}: {type(e).__name__}: {str(e)}"
@@ -193,25 +204,29 @@ def evaluate_single_test_sample(args_tuple, data_processor) -> Tuple[Dict, str]:
 
 def evaluate_test_set(data_processor, generator, playbook, test_samples,
                       max_tokens=4096, log_dir=None, max_workers=20,
-                      use_json_mode=False) -> Tuple[Dict, Dict]:
+                      use_json_mode=False, retriever=None) -> Tuple[Dict, Dict]:
     """
     Parallel evaluation of test set - task-agnostic implementation.
 
     Args:
         data_processor: DataProcessor instance with answer_is_correct and evaluate_accuracy methods
         generator: Generator instance
-        playbook: Current playbook string
+        playbook: Current playbook string (used as fallback when retriever is None)
         test_samples: List of test samples
         max_tokens: Max tokens for generation
         log_dir: Directory for logs
         max_workers: Number of parallel workers
         use_json_mode: Whether to use JSON mode
+        retriever: Optional Retriever instance (already indexed). When provided,
+            each sample gets a per-sample mini playbook via retrieval.
 
     Returns:
         Tuple of (results_dict, error_logs_dict)
     """
     print(f"\n{'='*40}")
     print(f"EVALUATING TEST SET - {len(test_samples)} samples, {max_workers} workers")
+    if retriever is not None:
+        print(f" Using retriever (top_k={retriever.top_k})")
     print(f"{'='*40}")
 
     args_list = [
@@ -221,12 +236,12 @@ def evaluate_test_set(data_processor, generator, playbook, test_samples,
     results = {
         "correct": 0, "total": 0, "no_answer": 0,
-        "answers": [], "targets": [], "errors": []
+        "answers": [], "targets": [], "errors": [],
+        "sub_playbook_token_counts": []
     }
 
-    # Use a wrapper to pass data_processor to the evaluation function
     def eval_wrapper(args_tuple):
-        return evaluate_single_test_sample(args_tuple, data_processor)
+        return evaluate_single_test_sample(args_tuple, data_processor, retriever=retriever)
 
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
         future_to_args = {
@@ -246,6 +261,9 @@ def eval_wrapper(args_tuple):
                     results["total"] += 1
                     results["answers"].append(result["final_answer"])
                     results["targets"].append(result["target"])
+
+                    if "sub_playbook_tokens" in result:
+                        results["sub_playbook_token_counts"].append(result["sub_playbook_tokens"])
 
                     if not result["is_correct"]:
                         results["errors"].append({
@@ -270,16 +288,22 @@ def eval_wrapper(args_tuple):
             "total": results["total"],
             "no_answer": results["no_answer"]
         }
+
+        token_counts = results["sub_playbook_token_counts"]
+        if token_counts:
+            avg_tokens = round(sum(token_counts) / len(token_counts), 1)
+            final_results["avg_sub_playbook_tokens"] = avg_tokens
+            print(f"\n Avg sub-playbook tokens: {avg_tokens}")
 
         error_logs = {
             "accuracy": accuracy,
             "errors": results["errors"]
         }
 
-        print(f"\nšŸ“Š Final Accuracy: {accuracy:.3f} ({results['correct']}/{results['total']})")
+        print(f"\n Final Accuracy: {accuracy:.3f} ({results['correct']}/{results['total']})")
     else:
-        results = {"accuracy": 0.0, "correct": 0, "total": 0}
+        final_results = {"accuracy": 0.0, "correct": 0, "total": 0}
         error_logs = {}
-        print(f"\nšŸ“Š No valid results!")
+        print(f"\n No valid results!")
 
     return final_results, error_logs
\ No newline at end of file
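Taken together, the retrieval path added in this diff composes as below; a sketch rather than a runnable test, since `data_processor`, `generator`, `playbook`, and `test_samples` come from the existing harness:

```python
from ace.core import Retriever
from utils import evaluate_test_set

retriever = Retriever(top_k=10)
retriever.index_playbook(playbook)  # evaluate_test_set expects an already-indexed retriever

results, error_logs = evaluate_test_set(
    data_processor, generator, playbook, test_samples,
    max_workers=20,
    retriever=retriever,  # per-sample mini playbooks + sub-playbook token accounting
)
print(results["accuracy"], results.get("avg_sub_playbook_tokens"))
```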