From 2dd392fd23ba018bfee7a05ab089bdc2a5a70391 Mon Sep 17 00:00:00 2001 From: Sherry Date: Wed, 18 Feb 2026 18:04:34 -0800 Subject: [PATCH 1/3] Add Stream-Bench support --- .gitignore | 8 +- ace/ace.py | 244 +++--- ace/core/reflector.py | 70 +- ace/prompts/reflector.py | 6 + eval/finance/data_processor.py | 6 +- eval/finance/run.py | 79 +- eval/stream-bench/README.md | 499 +++++++++++++ eval/stream-bench/analyze_logs.py | 610 +++++++++++++++ eval/stream-bench/data/bird_config.json | 112 +++ eval/stream-bench/data/cosql_config.json | 64 ++ eval/stream-bench/data/spider_config.json | 66 ++ eval/stream-bench/data_processor.py | 697 ++++++++++++++++++ eval/stream-bench/dataset_stats.py | 135 ++++ eval/stream-bench/download_text2sql_data.py | 160 ++++ eval/stream-bench/plot.py | 575 +++++++++++++++ .../preprocess_streambench_bird.py | 274 +++++++ .../preprocess_streambench_cosql.py | 251 +++++++ .../preprocess_streambench_spider.py | 246 +++++++ eval/stream-bench/run.py | 475 ++++++++++++ eval/stream-bench/run_playbook.py | 525 +++++++++++++ llm.py | 64 +- utils.py | 86 ++- 22 files changed, 5102 insertions(+), 150 deletions(-) create mode 100644 eval/stream-bench/README.md create mode 100644 eval/stream-bench/analyze_logs.py create mode 100644 eval/stream-bench/data/bird_config.json create mode 100644 eval/stream-bench/data/cosql_config.json create mode 100644 eval/stream-bench/data/spider_config.json create mode 100644 eval/stream-bench/data_processor.py create mode 100755 eval/stream-bench/dataset_stats.py create mode 100644 eval/stream-bench/download_text2sql_data.py create mode 100644 eval/stream-bench/plot.py create mode 100644 eval/stream-bench/preprocess_streambench_bird.py create mode 100644 eval/stream-bench/preprocess_streambench_cosql.py create mode 100644 eval/stream-bench/preprocess_streambench_spider.py create mode 100644 eval/stream-bench/run.py create mode 100755 eval/stream-bench/run_playbook.py diff --git a/.gitignore b/.gitignore index 
cdb39abd..45799b71 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,10 @@ .env __pycache__/ */__pycache__/ -results/ \ No newline at end of file +results/ + +# OS files +.DS_Store + +# data files +eval/stream-bench/data diff --git a/ace/ace.py b/ace/ace.py index 2d662adc..9316b23e 100644 --- a/ace/ace.py +++ b/ace/ace.py @@ -10,6 +10,7 @@ import os import json +import time from datetime import datetime from typing import Dict, List, Tuple, Optional, Any @@ -111,10 +112,10 @@ def _initialize_empty_playbook(self) -> str: def _extract_config_params(self, config: Dict[str, Any]) -> Dict[str, Any]: """ Extract common configuration parameters. - + Args: config: Configuration dictionary - + Returns: Dictionary with extracted parameters """ @@ -131,24 +132,36 @@ def _extract_config_params(self, config: Dict[str, Any]) -> Dict[str, Any]: 'save_dir': config.get('save_dir', './results'), 'test_workers': config.get('test_workers', 20), 'use_bulletpoint_analyzer': config.get('use_bulletpoint_analyzer', False), - 'bulletpoint_analyzer_threshold': config.get('bulletpoint_analyzer_threshold', 0.90) + 'bulletpoint_analyzer_threshold': config.get('bulletpoint_analyzer_threshold', 0.90), + 'pass_sql_eval_results': config.get('pass_sql_eval_results', False) } - def _setup_paths(self, save_dir: str, task_name: str, mode: str) -> Tuple[str, str]: + def _setup_paths(self, save_dir: str, task_name: str, mode: str, db_name: str = None, curriculum: str = None) -> Tuple[str, str]: """ Setup logging paths and directories. 
- + Args: save_dir: Base path for saving results task_name: task name mode: 'offline', 'online', or 'eval_only' - + db_name: Optional database name to include in folder name + curriculum: Optional curriculum level to include in folder name + Returns: Tuple of (usage_log_path, playbook_dir) """ # Create timestamped run folder timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - run_folder = f"ace_run_{timestamp}_{task_name}_{mode}" + + # Build run folder name with optional db_name and curriculum + run_folder_parts = ["ace_run", timestamp, task_name] + if db_name: + run_folder_parts.append(db_name) + if curriculum: + run_folder_parts.append(curriculum) + run_folder_parts.append(mode) + run_folder = "_".join(run_folder_parts) + save_path = os.path.join(save_dir, run_folder) os.makedirs(save_path, exist_ok=True) log_dir = os.path.join(save_path, "detailed_llm_logs") @@ -160,7 +173,7 @@ def _setup_paths(self, save_dir: str, task_name: str, mode: str) -> Tuple[str, s usage_log_path = os.path.join(save_path, "bullet_usage_log.jsonl") playbook_dir = os.path.join(save_path, "intermediate_playbooks") os.makedirs(playbook_dir, exist_ok=True) - + return save_path, usage_log_path, playbook_dir, log_dir def run( @@ -169,48 +182,59 @@ def run( train_samples: Optional[List[Dict[str, Any]]] = None, val_samples: Optional[List[Dict[str, Any]]] = None, test_samples: Optional[List[Dict[str, Any]]] = None, - data_processor = None, + train_processor = None, + val_processor = None, + test_processor = None, + data_processor = None, # Kept for backward compatibility config: Dict[str, Any] = None ) -> Dict[str, Any]: """ Main entrypoint for running ACE system in different modes. 
- + Args: mode: Run mode - 'offline', 'online', or 'eval_only' train_samples: Training samples (required for offline mode) val_samples: Validation samples (required for offline mode) test_samples: Test samples (required for online and eval_only modes) - data_processor: Data processor instance for the task + train_processor: Data processor for training samples + val_processor: Data processor for validation samples + test_processor: Data processor for test samples + data_processor: (Deprecated) Single processor for backward compatibility config: Configuration dictionary - + Returns: Dictionary with results depending on the mode """ + # Handle backward compatibility: if data_processor provided but no split processors, use it for all + if data_processor is not None and train_processor is None and val_processor is None and test_processor is None: + train_processor = val_processor = test_processor = data_processor # Validate inputs if mode not in ['offline', 'online', 'eval_only']: raise ValueError(f"Invalid mode: {mode}. 
Must be 'offline', 'online', or 'eval_only'") - + if mode == 'offline' and (train_samples is None or val_samples is None): raise ValueError("Offline mode requires train_samples and val_samples") - + if mode == 'online' and test_samples is None: raise ValueError("Online mode requires test_samples") - + if mode == 'eval_only' and test_samples is None: raise ValueError("eval_only mode requires test_samples") - + # Extract configuration config_params = self._extract_config_params(config) task_name = config_params['task_name'] save_dir = config_params['save_dir'] - + db_name = config.get('db_name', None) if config else None + curriculum = config.get('curriculum', None) if config else None + # Setup paths based on mode if mode == 'eval_only': - save_path, log_dir = self._setup_paths(save_dir, task_name, mode) + save_path, log_dir = self._setup_paths(save_dir, task_name, mode, db_name, curriculum) usage_log_path = None playbook_dir = None else: - save_path, usage_log_path, playbook_dir, log_dir = self._setup_paths(save_dir, task_name, mode) + save_path, usage_log_path, playbook_dir, log_dir = self._setup_paths(save_dir, task_name, mode, db_name, curriculum) # Save configuration config_path = os.path.join(save_path, "run_config.json") @@ -252,7 +276,7 @@ def run( print(f"{'='*60}\n") initial_test_results = self._run_test( test_samples=test_samples, - data_processor=data_processor, + data_processor=test_processor, playbook=self.playbook, config=config, log_dir=log_dir, @@ -260,8 +284,8 @@ def run( prefix="initial" ) results['initial_test_results'] = initial_test_results - print(f"Initial Test Accuracy: {initial_test_results['accuracy']:.3f}\n") - + print(f"Initial Test Accuracy: {initial_test_results['accuracy']:.3f} ({initial_test_results['correct']}/{initial_test_results['total']})\n") + # 2. 
Run offline training print(f"\n{'='*60}") print(f"STARTING OFFLINE TRAINING") @@ -269,7 +293,8 @@ def run( training_results = self._offline_train( train_samples=train_samples, val_samples=val_samples, - data_processor=data_processor, + train_processor=train_processor, + val_processor=val_processor, config=config, save_path=save_path, usage_log_path=usage_log_path, @@ -285,7 +310,7 @@ def run( print(f"{'='*60}\n") final_test_results = self._run_test( test_samples=test_samples, - data_processor=data_processor, + data_processor=test_processor, playbook=self.best_playbook, config=config, log_dir=log_dir, @@ -293,7 +318,7 @@ def run( prefix="final" ) results['final_test_results'] = final_test_results - print(f"Final Test Accuracy: {final_test_results['accuracy']:.3f}\n") + print(f"Final Test Accuracy: {final_test_results['accuracy']:.3f} ({final_test_results['correct']}/{final_test_results['total']})\n") elif mode == 'online': # ONLINE MODE WORKFLOW @@ -303,7 +328,7 @@ def run( print(f"{'='*60}\n") initial_test_results = self._run_test( test_samples=test_samples, - data_processor=data_processor, + data_processor=test_processor, playbook=self.playbook, config=config, log_dir=log_dir, @@ -311,15 +336,15 @@ def run( prefix="initial" ) results['initial_test_results'] = initial_test_results - print(f"Initial Test Accuracy: {initial_test_results['accuracy']:.3f}\n") - + print(f"Initial Test Accuracy: {initial_test_results['accuracy']:.3f} ({initial_test_results['correct']}/{initial_test_results['total']})\n") + # 2. 
Run online training and testing print(f"\n{'='*60}") print(f"STARTING ONLINE TRAIN AND TEST") print(f"{'='*60}\n") online_results = self._online_train_and_test( test_samples=test_samples, - data_processor=data_processor, + data_processor=test_processor, config=config, save_path=save_path, usage_log_path=usage_log_path, @@ -327,7 +352,7 @@ def run( log_dir=log_dir ) results['online_test_results'] = online_results - + else: # eval_only # EVAL ONLY MODE WORKFLOW print(f"\n{'='*60}") @@ -335,7 +360,7 @@ def run( print(f"{'='*60}\n") test_results = self._run_test( test_samples=test_samples, - data_processor=data_processor, + data_processor=test_processor, playbook=self.playbook, config=config, log_dir=log_dir, @@ -357,16 +382,18 @@ def run( if mode == 'offline': print(f"Best Validation Accuracy: {results['training_results']['best_validation_accuracy']:.3f}") if test_samples: - print(f"Initial Test Accuracy: {results['initial_test_results']['accuracy']:.3f}") - print(f"Final Test Accuracy: {results['final_test_results']['accuracy']:.3f}") + print(f"Initial Test Accuracy: {results['initial_test_results']['accuracy']:.3f} ({results['initial_test_results']['correct']}/{results['initial_test_results']['total']})") + print(f"Final Test Accuracy: {results['final_test_results']['accuracy']:.3f} ({results['final_test_results']['correct']}/{results['final_test_results']['total']})") elif mode == 'online': - print(f"Initial Test Accuracy: {results['initial_test_results']['accuracy']:.3f}") - print(f"Final Test Accuracy: {results['online_test_results']['accuracy']:.3f}") + print(f"Initial Test Accuracy: {results['initial_test_results']['accuracy']:.3f} ({results['initial_test_results']['correct']}/{results['initial_test_results']['total']})") + print(f"Final Test Accuracy: {results['online_test_results']['accuracy']:.3f} ({results['online_test_results']['correct']}/{results['online_test_results']['total']})") else: # eval_only - print(f"Test Accuracy: 
{results['test_results']['accuracy']:.3f}") + print(f"Test Accuracy: {results['test_results']['accuracy']:.3f} ({results['test_results']['correct']}/{results['test_results']['total']})") print(f"Results saved to: {save_path}") print(f"{'='*60}\n") - + + # Add save_path to results for external use + results['save_path'] = save_path return results def _run_test( @@ -474,15 +501,26 @@ def _train_single_sample( # Extract answer and check correctness final_answer = extract_answer(gen_response) - is_correct = data_processor.answer_is_correct(final_answer, target) + sample_metadata = task_dict.get("others", None) + + # Get SQL evaluation results if flag is enabled + pass_sql_eval_results = config_params.get('pass_sql_eval_results', False) + if pass_sql_eval_results: + is_correct, sql_exec_results = data_processor.answer_is_correct( + final_answer, target, sample_metadata, return_exec_results=True + ) + else: + is_correct = data_processor.answer_is_correct(final_answer, target, sample_metadata) + sql_exec_results = None + pre_train_answer = final_answer - + print(f"Correct: {is_correct}") - + # Log bullet usage log_bullet_usage(usage_log_path, epoch, step, task_dict, bullet_ids, playbook=self.playbook, is_correct=is_correct) - + # Track pre-train result tracking_dict = { "pre_train_result": { @@ -492,20 +530,20 @@ def _train_single_sample( "playbook_length": len(self.playbook) } } - + reflection_content = "(empty)" - + # STEP 2: Reflection and regeneration if not is_correct: # For incorrect answers - iterate reflection rounds for round_num in range(max_num_rounds): print(f"Reflection round {round_num + 1}/{max_num_rounds}") - + # Get bullets for reflector playbook_bullets = extract_playbook_bullets( self.playbook, bullet_ids ) - + # Reflect on error reflection_content, bullet_tags, _ = self.reflector.reflect( question=question, @@ -517,7 +555,8 @@ def _train_single_sample( use_ground_truth=not no_ground_truth, use_json_mode=use_json_mode, 
call_id=f"{step_id}_round_{round_num}", - log_dir=log_dir + log_dir=log_dir, + sql_exec_results=sql_exec_results if pass_sql_eval_results else None ) # Update bullet counts @@ -538,8 +577,8 @@ def _train_single_sample( ) final_answer = extract_answer(gen_response) - - if data_processor.answer_is_correct(final_answer, target): + + if data_processor.answer_is_correct(final_answer, target, sample_metadata): print(f"Corrected after reflection round {round_num + 1}!") is_correct = True break @@ -549,7 +588,7 @@ def _train_single_sample( playbook_bullets = extract_playbook_bullets( self.playbook, bullet_ids ) - + reflection_content, bullet_tags, _ = self.reflector.reflect( question=question, reasoning_trace=gen_response, @@ -560,7 +599,8 @@ def _train_single_sample( use_ground_truth=not no_ground_truth, use_json_mode=use_json_mode, call_id=f"{step_id}_reflect_on_correct", - log_dir=log_dir + log_dir=log_dir, + sql_exec_results=sql_exec_results if pass_sql_eval_results else None ) # Update bullet counts @@ -618,8 +658,8 @@ def _train_single_sample( final_answer = extract_answer(gen_response) post_train_answer = final_answer - - post_train_is_correct = data_processor.answer_is_correct(final_answer, target) + + post_train_is_correct = data_processor.answer_is_correct(final_answer, target, sample_metadata) tracking_dict["post_train_result"] = { "final_answer": final_answer, "is_correct": post_train_is_correct, @@ -633,7 +673,8 @@ def _offline_train( self, train_samples: List[Dict[str, Any]], val_samples: List[Dict[str, Any]], - data_processor, + train_processor, + val_processor, config: Dict[str, Any], save_path: str, usage_log_path: str, @@ -642,11 +683,12 @@ def _offline_train( ) -> Dict[str, Any]: """ Run offline training - + Args: train_samples: List of training samples val_samples: List of validation samples - data_processor: Data processor instance for the task + train_processor: Data processor for training samples + val_processor: Data processor for validation samples 
config: Configuration dictionary save_path: Path to save results usage_log_path: Path for bullet usage logging @@ -670,6 +712,7 @@ def _offline_train( results = [] pre_train_post_train_results = [] error_logs = [] + step_timings = [] best_accuracy = 0.0 self.best_playbook = self.playbook @@ -678,28 +721,29 @@ def _offline_train( print(f"Val samples: {len(val_samples)}") print(f"Curator frequency: every {curator_frequency} steps") print(f"Evaluation frequency: every {eval_steps} steps\n") - + # Training loop for epoch in range(1, num_epochs + 1): print(f"\n{'='*60}") print(f"EPOCH {epoch}/{num_epochs}") print(f"{'='*60}") - + epoch_answers_pre_train = [] epoch_targets_pre_train = [] epoch_answers_post_train = [] epoch_targets_post_train = [] - + for step, task_dict in enumerate(train_samples): step += 1 + step_start_time = time.time() print(f"\n--- Step {step}/{len(train_samples)} ---") - + target = task_dict.get("target", "") - + # Use helper method for training single sample pre_train_answer, post_train_answer, tracking_dict = self._train_single_sample( task_dict=task_dict, - data_processor=data_processor, + data_processor=train_processor, step_id=f"train_e_{epoch}_s_{step}", epoch=epoch, step=step, @@ -708,18 +752,29 @@ def _offline_train( config_params=config_params, total_samples=len(train_samples) ) - + + step_elapsed_time = time.time() - step_start_time + print(f"Step {step} completed in {step_elapsed_time:.2f} seconds") + + # Track step timing + step_timings.append({ + "epoch": epoch, + "step": step, + "time_seconds": step_elapsed_time + }) + # Collect answers for accuracy calculation epoch_answers_pre_train.append(pre_train_answer) epoch_targets_pre_train.append(target) epoch_answers_post_train.append(post_train_answer) epoch_targets_post_train.append(target) - + # Track pre-train and post-train results pre_train_post_train_result = { "epoch": epoch, "step": step, "target": target, + "step_time_seconds": step_elapsed_time, **tracking_dict } 
pre_train_post_train_results.append(pre_train_post_train_result) @@ -739,18 +794,18 @@ def _offline_train( print(f"{'='*40}") # Compute training accuracies - pre_train_accuracy = data_processor.evaluate_accuracy( + pre_train_accuracy = train_processor.evaluate_accuracy( epoch_answers_pre_train, epoch_targets_pre_train ) - post_train_accuracy = data_processor.evaluate_accuracy( + post_train_accuracy = train_processor.evaluate_accuracy( epoch_answers_post_train, epoch_targets_post_train ) - + # Validation evaluation val_results = {} if val_samples: val_results, val_error_log = evaluate_test_set( - data_processor, self.generator, self.playbook, + val_processor, self.generator, self.playbook, val_samples, self.max_tokens, log_dir, max_workers=test_workers, use_json_mode=use_json_mode ) @@ -809,25 +864,31 @@ def _offline_train( "best_accuracy": best_accuracy, "results": results, }, f, indent=2) - + pre_train_post_train_results_path = os.path.join(save_path, "pre_train_post_train_results.json") with open(pre_train_post_train_results_path, "w") as f: json.dump(pre_train_post_train_results, f, indent=2) - + + # Calculate timing statistics + total_training_time = sum(t["time_seconds"] for t in step_timings) + avg_step_time = total_training_time / len(step_timings) if step_timings else 0 + # Save final playbook final_playbook_path = os.path.join(save_path, f"final_playbook.txt") with open(final_playbook_path, "w") as f: f.write(self.playbook) - + # Save best playbook best_playbook_path = os.path.join(save_path, f"best_playbook.txt") with open(best_playbook_path, "w") as f: f.write(self.best_playbook) - + print(f"\n{'='*60}") print(f"OFFLINE TRAINING COMPLETE") print(f"{'='*60}") print(f"Best Validation Accuracy: {best_accuracy:.3f}") + print(f"Total Training Time: {total_training_time/60:.2f} minutes ({total_training_time:.2f} seconds)") + print(f"Average Step Time: {avg_step_time:.2f} seconds") print(f"{'='*60}\n") return {"best_validation_accuracy": best_accuracy} @@ 
-918,7 +979,8 @@ def _online_train_and_test( # Initialize tracking train_results = [] pre_train_post_train_results = [] - + step_timings = [] + # Test tracking - accumulate across all windows correct_count_sample_based = 0 correct_count = 0 @@ -1009,12 +1071,13 @@ def _online_train_and_test( for local_step, task_dict in enumerate(window_samples): global_step += 1 local_step += 1 - + step_start_time = time.time() + print(f"\n--- Window {window_idx + 1}, Step {local_step}/{len(window_samples)} " f"(Global step {global_step}) ---") - + target = task_dict.get("target", "") - + # Use helper method for training single sample pre_train_answer, post_train_answer, tracking_dict = self._train_single_sample( task_dict=task_dict, @@ -1027,18 +1090,29 @@ def _online_train_and_test( config_params=config_params, total_samples=len(test_samples) ) - + + step_elapsed_time = time.time() - step_start_time + print(f"Step {global_step} completed in {step_elapsed_time:.2f} seconds") + + # Track step timing + step_timings.append({ + "window": window_idx + 1, + "global_step": global_step, + "time_seconds": step_elapsed_time + }) + # Collect answers for accuracy calculation epoch_answers_pre_train.append(pre_train_answer) epoch_targets_pre_train.append(target) epoch_answers_post_train.append(post_train_answer) epoch_targets_post_train.append(target) - + # Track pre-train and post-train results pre_train_post_train_result = { "window": window_idx + 1, "global_step": global_step, "target": target, + "step_time_seconds": step_elapsed_time, **tracking_dict } pre_train_post_train_results.append(pre_train_post_train_result) @@ -1053,10 +1127,10 @@ def _online_train_and_test( # End of window - compute training accuracies for this window pre_train_accuracy = data_processor.evaluate_accuracy( - epoch_answers_pre_train, epoch_targets_pre_train + epoch_answers_pre_train, epoch_targets_pre_train, window_samples ) post_train_accuracy = data_processor.evaluate_accuracy( - epoch_answers_post_train, 
epoch_targets_post_train + epoch_answers_post_train, epoch_targets_post_train, window_samples ) window_train_result = { @@ -1123,16 +1197,22 @@ def _online_train_and_test( pre_train_post_train_results_path = os.path.join(save_path, "pre_train_post_train_results.json") with open(pre_train_post_train_results_path, "w") as f: json.dump(pre_train_post_train_results, f, indent=2) - + + # Calculate timing statistics + total_training_time = sum(t["time_seconds"] for t in step_timings) + avg_step_time = total_training_time / len(step_timings) if step_timings else 0 + # Save final playbook final_playbook_path = os.path.join(save_path, f"final_playbook.txt") with open(final_playbook_path, "w") as f: f.write(self.playbook) - + print(f"\n{'='*60}") print(f"ONLINE TRAINING AND TESTING COMPLETE") print(f"{'='*60}") - print(f"Final Test Accuracy: {final_test_accuracy:.3f}") + print(f"Final Test Accuracy: {final_test_accuracy:.3f} ({correct_count}/{total_count})") + print(f"Total Training Time: {total_training_time/60:.2f} minutes ({total_training_time:.2f} seconds)") + print(f"Average Step Time: {avg_step_time:.2f} seconds") print(f"{'='*60}\n") return { diff --git a/ace/core/reflector.py b/ace/core/reflector.py index 134ea4c3..8c9be73a 100644 --- a/ace/core/reflector.py +++ b/ace/core/reflector.py @@ -41,11 +41,12 @@ def reflect( use_ground_truth: bool = True, use_json_mode: bool = False, call_id: str = "reflect", - log_dir: Optional[str] = None + log_dir: Optional[str] = None, + sql_exec_results: Optional[Dict[str, Any]] = None ) -> Tuple[str, List[Dict[str, str]], Dict[str, Any]]: """ Analyze the generator's output and tag bullets. 
- + Args: question: The original question reasoning_trace: The generator's reasoning @@ -57,10 +58,18 @@ def reflect( use_json_mode: Whether to use JSON mode call_id: Unique identifier for this call log_dir: Directory for logging - + sql_exec_results: Optional dict containing SQL execution results with keys: + - 'predicted_result': List of tuples from predicted SQL execution + - 'ground_truth_result': List of tuples from ground truth SQL execution + - 'db_name': Database name used for evaluation + - 'error': Error message if execution failed + Returns: Tuple of (reflection_content, bullet_tags, call_info) """ + # Format SQL execution results for the prompt + sql_exec_text = self._format_sql_exec_results(sql_exec_results) + # Select the appropriate prompt if use_ground_truth and ground_truth: prompt = REFLECTOR_PROMPT.format( @@ -69,7 +78,8 @@ def reflect( predicted_answer, ground_truth, environment_feedback, - bullets_used + bullets_used, + sql_exec_text ) else: prompt = REFLECTOR_PROMPT_NO_GT.format( @@ -77,7 +87,8 @@ def reflect( reasoning_trace, predicted_answer, environment_feedback, - bullets_used + bullets_used, + sql_exec_text ) response, call_info = timed_llm_call( @@ -96,7 +107,54 @@ def reflect( bullet_tags = self._extract_bullet_tags(response, use_json_mode) return response, bullet_tags, call_info - + + def _format_sql_exec_results(self, sql_exec_results: Optional[Dict[str, Any]]) -> str: + """ + Format SQL execution results for display in the prompt. + + Args: + sql_exec_results: Dict containing execution results or None + + Returns: + Formatted string describing the execution results + """ + if not sql_exec_results: + return "No SQL execution results available." 
+ + if "error" in sql_exec_results: + return f"Error during SQL execution: {sql_exec_results['error']}" + + db_name = sql_exec_results.get("db_name", "unknown") + pred_result = sql_exec_results.get("predicted_result", []) + gt_result = sql_exec_results.get("ground_truth_result", []) + + # Format the results + lines = [f"Database: {db_name}\n"] + + # Predicted SQL results + lines.append(f"Predicted SQL Execution Result ({len(pred_result)} rows):") + if pred_result: + for i, row in enumerate(pred_result[:20]): # Show first 20 rows + lines.append(f" Row {i+1}: {row}") + if len(pred_result) > 20: + lines.append(f" ... ({len(pred_result) - 20} more rows)") + else: + lines.append(" (Empty result set)") + + lines.append("") + + # Ground truth SQL results + lines.append(f"Ground Truth SQL Execution Result ({len(gt_result)} rows):") + if gt_result: + for i, row in enumerate(gt_result[:20]): # Show first 20 rows + lines.append(f" Row {i+1}: {row}") + if len(gt_result) > 20: + lines.append(f" ... ({len(gt_result) - 20} more rows)") + else: + lines.append(" (Empty result set)") + + return "\n".join(lines) + def _extract_bullet_tags( self, response: str, diff --git a/ace/prompts/reflector.py b/ace/prompts/reflector.py index 7b9f841b..b459e493 100644 --- a/ace/prompts/reflector.py +++ b/ace/prompts/reflector.py @@ -44,6 +44,9 @@ **Part of Playbook that's used by the generator to answer the question:** {} +**SQL Execution Results (if available):** +{} + **Answer in this exact JSON format:** {{ "reasoning": "[Your chain of thought / reasoning / thinking process, detailed analysis and calculations]", @@ -98,6 +101,9 @@ **Part of Playbook that's used by the generator to answer the question:** {} +**SQL Execution Results (if available):** +{} + **Answer in this exact JSON format:** {{ "reasoning": "[Your chain of thought / reasoning / thinking process, detailed analysis and calculations]", diff --git a/eval/finance/data_processor.py b/eval/finance/data_processor.py index 
aa110bb0..2f63c4e0 100644 --- a/eval/finance/data_processor.py +++ b/eval/finance/data_processor.py @@ -162,13 +162,14 @@ def _formula_answer_is_correct(self, predicted: str, ground_truth: str) -> bool: return predicted == ground_truth - def answer_is_correct(self, predicted: str, ground_truth: str) -> bool: + def answer_is_correct(self, predicted: str, ground_truth: str, sample_metadata=None) -> bool: """ Dataset-specific answer correctness check. Args: predicted: Model's answer ground_truth: Ground truth answer + sample_metadata: Optional dict containing sample metadata (unused for finance tasks) Returns: bool: True if answer is correct, False otherwise @@ -220,13 +221,14 @@ def _evaluate_formula_accuracy(self, out: List[str], target: List[str]) -> tuple return accuracy - def evaluate_accuracy(self, out: List[str], target: List[str]) -> tuple: + def evaluate_accuracy(self, out: List[str], target: List[str], samples=None) -> tuple: """ Dataset-specific accuracy evaluation. Args: out: List of model predictions target: List of ground truth targets + samples: Optional list of sample dicts (unused for finance tasks) Returns: tuple: (accuracy, response_list) diff --git a/eval/finance/run.py b/eval/finance/run.py index 86535536..543894e1 100644 --- a/eval/finance/run.py +++ b/eval/finance/run.py @@ -13,10 +13,10 @@ from ace import ACE from utils import initialize_clients -def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser(description='ACE System - Refactored') - +def get_base_parser(description='ACE System'): + """Get base argument parser with common arguments.""" + parser = argparse.ArgumentParser(description=description) + # Task configuration parser.add_argument("--task_name", type=str, required=True, help="Name of the task (e.g., 'finer', 'formula')") @@ -27,11 +27,11 @@ def parse_args(): help="Run mode: 'offline' for offline training with validation, " "'online' for online training and testing on test split, " "'eval_only' 
for testing only with provided playbook") - + # Model configuration parser.add_argument("--api_provider", type=str, default="sambanova", choices=["sambanova", "together", "openai"], help="API provider") - parser.add_argument("--generator_model", type=str, + parser.add_argument("--generator_model", type=str, default="DeepSeek-V3.1", help="Model for generator") parser.add_argument("--reflector_model", type=str, @@ -40,7 +40,7 @@ def parse_args(): parser.add_argument("--curator_model", type=str, default="DeepSeek-V3.1", help="Model for curator") - + # Training configuration parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") @@ -54,7 +54,7 @@ def parse_args(): help="Update playbook every N samples for evaluation in online mode") parser.add_argument("--save_steps", type=int, default=50, help="Save intermediate playbooks every N steps") - + # System configuration parser.add_argument("--max_tokens", type=int, default=4096, help="Max tokens for LLM responses") @@ -62,57 +62,67 @@ def parse_args(): help="Total token budget for playbook") parser.add_argument("--test_workers", type=int, default=20, help="Number of parallel workers for testing") - + # Prompt configuration parser.add_argument("--json_mode", action="store_true", help="Enable JSON mode for LLM calls") parser.add_argument("--no_ground_truth", action="store_true", help="Don't use ground truth in reflection") - + # Bulletpoint analyzer configuration parser.add_argument("--use_bulletpoint_analyzer", action="store_true", help="Enable bulletpoint analyzer for deduplication and merging") parser.add_argument("--bulletpoint_analyzer_threshold", type=float, default=0.90, help="Similarity threshold for bulletpoint analyzer (0-1, default: 0.90)") - + + # SQL evaluation configuration + parser.add_argument("--pass_sql_eval_results", action="store_true", + help="Pass SQL execution results to reflector for better error analysis") + # Output configuration parser.add_argument("--save_path", 
type=str, required=True, help="Directory to save results") - + + return parser + + +def parse_args(): + """Parse command line arguments for finance tasks.""" + parser = get_base_parser(description='ACE System - Finance') return parser.parse_args() def load_data(data_path: str): """ Load and process data from a JSONL file. - + Args: data_path: Path to the JSONL file - + Returns: List of dictionaries containing the data """ if not os.path.exists(data_path): raise FileNotFoundError(f"Data file not found: {data_path}") - + data = [] with open(data_path, 'r', encoding='utf-8') as f: for line in f: line = line.strip() if line: # Skip empty lines data.append(json.loads(line)) - + print(f"Loaded {len(data)} samples from {data_path}") return data def preprocess_data(task_name, config, mode): """ Load training and test data for the specified task. - + Args: task_name: Name of the task config: Configuration dictionary with data paths mode: Run mode ('offline', 'online', or 'eval_only') - + Returns: Tuple of (train_samples, val_samples, test_samples, data_processor) - For offline mode: all three are loaded @@ -120,39 +130,39 @@ def preprocess_data(task_name, config, mode): - For eval_only mode: only test_samples """ processor = DataProcessor(task_name=task_name) - + # For online and eval_only modes, only load test data if mode in ["online", "eval_only"]: train_samples = None val_samples = None - + if "test_data" in config: test_samples = load_data(config["test_data"]) test_samples = processor.process_task_data(test_samples) else: raise ValueError(f"{mode} mode requires test data in config.") - + if mode == "online": print(f"Online mode: Training and testing on {len(test_samples)} examples") else: print(f"Eval only mode: Testing on {len(test_samples)} examples") - + # For offline mode, load train, val, and optionally test data else: train_samples = load_data(config["train_data"]) val_samples = load_data(config["val_data"]) train_samples = 
processor.process_task_data(train_samples) val_samples = processor.process_task_data(val_samples) - + if "test_data" in config: test_samples = load_data(config["test_data"]) test_samples = processor.process_task_data(test_samples) else: test_samples = [] - + print(f"Offline mode: Training on {len(train_samples)} examples, " f"validating on {len(val_samples)}, testing on {len(test_samples)}") - + return train_samples, val_samples, test_samples, processor @@ -167,7 +177,7 @@ def load_initial_playbook(path): def main(): """Main execution function.""" args = parse_args() - + print(f"\n{'='*60}") print(f"ACE SYSTEM") print(f"{'='*60}") @@ -175,24 +185,24 @@ def main(): print(f"Mode: {args.mode.upper().replace('_', ' ')}") print(f"Generator Model: {args.generator_model}") print(f"{'='*60}\n") - + # Load data with open("./eval/finance/data/sample_config.json", 'r') as f: task_config = json.load(f) train_samples, val_samples, test_samples, data_processor = preprocess_data( - args.task_name, + args.task_name, task_config[args.task_name], args.mode ) - + # Load initial playbook (or use empty if None provided) initial_playbook = load_initial_playbook(args.initial_playbook_path) if initial_playbook: print(f"Loaded initial playbook from {args.initial_playbook_path}\n") else: print("Using empty playbook as initial playbook\n") - + # Create ACE system ace_system = ACE( api_provider=args.api_provider, @@ -204,7 +214,7 @@ def main(): use_bulletpoint_analyzer=args.use_bulletpoint_analyzer, bulletpoint_analyzer_threshold=args.bulletpoint_analyzer_threshold ) - + # Prepare configuration config = { 'num_epochs': args.num_epochs, @@ -223,9 +233,10 @@ def main(): 'initial_playbook_path': args.initial_playbook_path, 'use_bulletpoint_analyzer': args.use_bulletpoint_analyzer, 'bulletpoint_analyzer_threshold': args.bulletpoint_analyzer_threshold, + 'pass_sql_eval_results': args.pass_sql_eval_results, 'api_provider': args.api_provider } - + # Execute using the unified run method results = 
ace_system.run( mode=args.mode, @@ -235,7 +246,7 @@ def main(): data_processor=data_processor, config=config ) - + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/eval/stream-bench/README.md b/eval/stream-bench/README.md new file mode 100644 index 00000000..6ea9ece2 --- /dev/null +++ b/eval/stream-bench/README.md @@ -0,0 +1,499 @@ +# stream-bench + +Benchmarking framework for text-to-SQL evaluation using the ACE (Agentic Context Engineering) system. Supports three datasets: **BIRD**, **CoSQL**, and **Spider**. + + +## Overview + +Stream-bench evaluates the ACE system on text-to-SQL tasks using execution-based accuracy (result-set matching, not string matching). The typical workflow is: + +``` +Download DBs → Preprocess data → Configure task → Run ACE → Evaluate playbook → Plot results +``` + +ACE trains a playbook (a set of instructions) over a stream of training examples, then evaluates it on a held-out test set. + + +## Directory Structure + +``` +eval/stream-bench/ +├── run.py # Main ACE training/evaluation runner +├── run_playbook.py # Evaluate a saved playbook on test data +├── plot.py # Generate performance plots +├── analyze_logs.py # Analyze terminal output logs +├── data_processor.py # Core data loading and SQL evaluation +├── download_text2sql_data.py # Download raw databases (BIRD/CoSQL/Spider) +├── preprocess_streambench_bird.py # Create .jsonl from HuggingFace BIRD data +├── preprocess_streambench_cosql.py # Create .jsonl from HuggingFace CoSQL data +├── preprocess_streambench_spider.py # Create .jsonl from HuggingFace Spider data +├── dataset_stats.py # Print dataset statistics +└── data/ + ├── bird_config.json # Pre-defined BIRD task configurations + ├── cosql_config.json # Pre-defined CoSQL task configurations + ├── spider_config.json # Pre-defined Spider task configurations + ├── bird/dev_databases/ # BIRD test databases (after download) + ├── bird_train/train_databases/ # BIRD training databases (after 
download) + ├── cosql/ # CoSQL databases (after download) + ├── spider/database/ # Spider databases (after download) + ├── streambench_bird_train.jsonl # Preprocessed BIRD train split + ├── streambench_bird_val.jsonl # Preprocessed BIRD val split + ├── streambench_bird_test.jsonl # Preprocessed BIRD test split + ├── streambench_cosql_*.jsonl # Preprocessed CoSQL splits + └── streambench_spider_*.jsonl # Preprocessed Spider splits +``` + +Results are written to a top-level `results/` directory at the repo root. + + +## Setup + +Install dependencies from the repo root: + +```bash +uv sync +# or: pip install -e . +``` + +Set up your API key in `.env` (copy from `.env.example`): + +```bash +cp .env.example .env +# Edit .env and add your API key +``` + +All commands below should be run from the **repo root** (`ace/`), not from inside `eval/stream-bench/`. + + +## Step 1: Download Databases + +The raw SQLite databases are required for execution-based evaluation. The preprocessed `.jsonl` data files (Step 2) can be downloaded from HuggingFace independently, but you still need the databases to run SQL evaluation. 
+ +### BIRD + +```bash +# Test/dev databases (~350 MB) +python eval/stream-bench/download_text2sql_data.py \ + --dataset bird --split dev \ + --save_dir eval/stream-bench/data + +# Training databases (~33 GB, only needed for offline mode training) +python eval/stream-bench/download_text2sql_data.py \ + --dataset bird --split train \ + --save_dir eval/stream-bench/data +``` + +After extraction: +``` +eval/stream-bench/data/bird/dev_databases/{db_name}/{db_name}.sqlite +eval/stream-bench/data/bird_train/train_databases/{db_name}/{db_name}.sqlite +``` + +### CoSQL + +```bash +python eval/stream-bench/download_text2sql_data.py \ + --dataset cosql \ + --save_dir eval/stream-bench/data +``` + +After extraction: +``` +eval/stream-bench/data/cosql/{db_name}/{db_name}.sqlite +eval/stream-bench/data/cosql/tables.json +``` + +### Spider + +```bash +python eval/stream-bench/download_text2sql_data.py \ + --dataset spider \ + --save_dir eval/stream-bench/data +``` + +After extraction: +``` +eval/stream-bench/data/spider/database/{db_name}/{db_name}.sqlite +eval/stream-bench/data/spider/tables.json +``` + +> **Note:** CoSQL and Spider downloads use `gdown` and require Google Drive access. Install with `pip install gdown` if missing. + + +## Step 2: Preprocess Data + +This step downloads the question/SQL pairs from HuggingFace (`appier-ai-research/StreamBench`) and combines them with database schema information to create `.jsonl` files used during training. + +**Skip this step** if the `.jsonl` files already exist in `eval/stream-bench/data/`. 
+ +### BIRD + +```bash +# Test split (uses dev databases) +python eval/stream-bench/preprocess_streambench_bird.py \ + --split test \ + --bird_root eval/stream-bench/data/bird \ + --tables_json eval/stream-bench/data/dev_20240627/dev_tables.json \ + --out eval/stream-bench/data/streambench_bird_test.jsonl + +# Train split (uses train databases) +python eval/stream-bench/preprocess_streambench_bird.py \ + --split train \ + --bird_root eval/stream-bench/data/bird_train/train_databases \ + --out eval/stream-bench/data/streambench_bird_train.jsonl + +# Val split +python eval/stream-bench/preprocess_streambench_bird.py \ + --split validation \ + --bird_root eval/stream-bench/data/bird_train/train_databases \ + --out eval/stream-bench/data/streambench_bird_val.jsonl +``` + +### CoSQL + +```bash +python eval/stream-bench/preprocess_streambench_cosql.py \ + --split test \ + --cosql_root eval/stream-bench/data/cosql \ + --tables_json eval/stream-bench/data/cosql/tables.json \ + --out eval/stream-bench/data/streambench_cosql_test.jsonl + +python eval/stream-bench/preprocess_streambench_cosql.py \ + --split train \ + --cosql_root eval/stream-bench/data/cosql \ + --tables_json eval/stream-bench/data/cosql/tables.json \ + --out eval/stream-bench/data/streambench_cosql_train.jsonl + +python eval/stream-bench/preprocess_streambench_cosql.py \ + --split validation \ + --cosql_root eval/stream-bench/data/cosql \ + --tables_json eval/stream-bench/data/cosql/tables.json \ + --out eval/stream-bench/data/streambench_cosql_val.jsonl +``` + +### Spider + +```bash +python eval/stream-bench/preprocess_streambench_spider.py \ + --split test \ + --spider_root eval/stream-bench/data/spider \ + --tables_json eval/stream-bench/data/spider/tables.json \ + --out eval/stream-bench/data/streambench_spider_test.jsonl + +python eval/stream-bench/preprocess_streambench_spider.py \ + --split train \ + --spider_root eval/stream-bench/data/spider \ + --tables_json 
eval/stream-bench/data/spider/tables.json \ + --out eval/stream-bench/data/streambench_spider_train.jsonl + +python eval/stream-bench/preprocess_streambench_spider.py \ + --split validation \ + --spider_root eval/stream-bench/data/spider \ + --tables_json eval/stream-bench/data/spider/tables.json \ + --out eval/stream-bench/data/streambench_spider_val.jsonl +``` + +Each `.jsonl` record has the format: +```json +{ + "question_id": "0", + "question": "What is the highest eligible free rate for K-12 students in Alameda County?", + "sql": "SELECT ...", + "difficulty": "simple", + "db_name": "california_schools", + "db_schema": { "db_id": "...", "tables": [...], "primary_keys": [...], "foreign_keys": [...] } +} +``` + + +## Step 3: Configure a Run + +Runs are configured via a JSON file that specifies data paths and task parameters. Pre-built configs are in `eval/stream-bench/data/`. + +### Pre-built Configurations + +**BIRD** (`eval/stream-bench/data/bird_config.json`): + +| Task name | Samples | Notes | +|---|---|---| +| `bird_all` | all | Full dataset | +| `bird_150` | 150 | Random subset | +| `bird_150_balanced` | 150 | Equal per difficulty | +| `bird_300_balanced` | 300 | Equal per difficulty | +| `bird_432_balanced` | 432 | Equal per difficulty | +| `bird_1000_quasi_balanced` | 1000 | Balanced with fallback | + +**CoSQL** (`eval/stream-bench/data/cosql_config.json`): + +| Task name | Samples | Notes | +|---|---|---| +| `cosql_all` | all | Full dataset | +| `cosql_150` | 150 | Random subset | +| `cosql_150_balanced` | 150 | Equal per difficulty | +| `cosql_36_balanced` | 36 | Small balanced subset | + +**Spider** (`eval/stream-bench/data/spider_config.json`): + +| Task name | Samples | Notes | +|---|---|---| +| `spider_all` | all | Full dataset | +| `spider_150` | 150 | Random subset | +| `spider_150_balanced` | 150 | Equal per difficulty | +| `spider_150_quasi_balanced` | 150 | Balanced with fallback | + +### Custom Configuration + +Create a JSON file with one or 
more task entries: + +```json +{ + "my_task": { + "train_data": "eval/stream-bench/data/streambench_bird_train.jsonl", + "val_data": "eval/stream-bench/data/streambench_bird_val.jsonl", + "test_data": "eval/stream-bench/data/streambench_bird_test.jsonl", + + "bird_train_db_root": "eval/stream-bench/data/bird_train/train_databases", + "bird_val_db_root": "eval/stream-bench/data/bird_train/train_databases", + "bird_test_db_root": "eval/stream-bench/data/bird/dev_databases", + + "max_samples": 150, + "difficulty_filter": "quasi_balanced" + } +} +``` + +For CoSQL, replace the `bird_*` keys with `cosql_db_root`: +```json +{ + "my_cosql_task": { + "train_data": "eval/stream-bench/data/streambench_cosql_train.jsonl", + "val_data": "eval/stream-bench/data/streambench_cosql_val.jsonl", + "test_data": "eval/stream-bench/data/streambench_cosql_test.jsonl", + "cosql_db_root": "eval/stream-bench/data/cosql", + "max_samples": 100 + } +} +``` + +**Config field reference:** + +| Field | Type | Description | +|---|---|---| +| `train_data` | string | Path to training `.jsonl` | +| `val_data` | string | Path to validation `.jsonl` | +| `test_data` | string | Path to test `.jsonl` | +| `bird_db_root` | string | Default database root (BIRD/Spider) | +| `bird_train_db_root` | string | Override DB root for train split | +| `bird_val_db_root` | string | Override DB root for val split | +| `bird_test_db_root` | string | Override DB root for test split | +| `cosql_db_root` | string | Database root for CoSQL tasks | +| `max_samples` | int | Cap on samples (applies to all splits) | +| `max_train_samples` | int | Cap for train split only | +| `max_val_samples` | int | Cap for val split only | +| `max_test_samples` | int | Cap for test split only | +| `difficulty_filter` | string | `simple-only`, `moderate-only`, `challenging-only`, `balanced`, `quasi_balanced` | + + +## Step 4: Run ACE Training + +```bash +python eval/stream-bench/run.py \ + --data_config 
eval/stream-bench/data/bird_config.json \ + --task_name bird_150_balanced \ + --mode online \ + --api_provider sambanova \ + --generator_model DeepSeek-V3-0324 \ + --curator_model DeepSeek-V3-0324 \ + --reflector_model DeepSeek-V3-0324 +``` + +### Run modes + +| Mode | Description | +|---|---| +| `online` | Stream training data one-by-one, updating the playbook after each window | +| `offline` | Train on the full training set, then evaluate on test | +| `eval_only` | Skip training, evaluate test set with an existing or empty playbook | + +### Commonly used options + +| Flag | Default | Description | +|---|---|---| +| `--mode` | required | `online`, `offline`, or `eval_only` | +| `--task_name` | required | Key from the config JSON | +| `--data_config` | required | Path to config JSON | +| `--api_provider` | required | `sambanova`, `together`, or `openai` | +| `--generator_model` | required | Model name for SQL generation | +| `--curriculum` | none | `easy_to_hard`, `hard_to_easy`, or `random` | +| `--num_epochs` | 1 | Number of passes over training data | +| `--max_num_rounds` | — | Maximum curator/reflector rounds | +| `--eval_steps` | — | Evaluate every N training steps | +| `--playbook_token_budget` | — | Token budget for playbook size | +| `--test_workers` | 1 | Parallel workers for test evaluation | +| `--initial_playbook_path` | none | Warm-start from an existing playbook | +| `--plot` | false | Auto-generate plots after run completes | + + +## Step 5: Evaluate a Playbook + +After a run completes, evaluate a specific saved playbook on any data split: + +```bash +python eval/stream-bench/run_playbook.py \ + --results_dir results/ace_run_20260119_234301_bird_150_balanced_online \ + --playbook_file intermediate_playbooks/window_4_final_playbook.txt \ + --dataset test +``` + +To run a baseline evaluation with an empty playbook: + +```bash +python eval/stream-bench/run_playbook.py \ + --results_dir results/ace_run_20260119_234301_bird_150_balanced_online \ + 
--dataset test +``` + +Save detailed per-sample results to a JSON file: + +```bash +python eval/stream-bench/run_playbook.py \ + --results_dir results/ace_run_20260119_234301_bird_150_balanced_online \ + --playbook_file intermediate_playbooks/window_4_final_playbook.txt \ + --output_file playbook_eval_results.json \ + --dataset test +``` + +> **Note:** `run_playbook.py` reads samples from the `processed_data/` subfolder that `run.py` writes during its setup phase. Run `run.py` at least once first so that folder exists. + + +## Results Structure + +Each run creates a timestamped directory under `results/`: + +``` +results/ +└── ace_run_YYYYMMDD_HHMMSS_{task_name}_{mode}/ + ├── run_config.json # Full run configuration snapshot + ├── final_playbook.txt # Playbook after all training + ├── intermediate_playbooks/ + │ ├── window_1_final_playbook.txt + │ ├── window_2_final_playbook.txt + │ └── ... + ├── processed_data/ + │ ├── train_samples.json # Preprocessed training data + │ ├── val_samples.json # Preprocessed validation data + │ └── test_samples.json # Preprocessed test data + └── terminal_output.log # Full terminal log +``` + +Each `*_samples.json` is a list of objects with: +```json +{ + "context": "", + "question": "natural language question", + "target": "ground truth SQL", + "others": { + "question_id": "0", + "difficulty": "moderate", + "db_name": "california_schools", + "task": "bird", + "data_source": "streambench_bird" + } +} +``` + + +## Plotting + +Generate accuracy-over-time plots for a completed run: + +```bash +# Online run +python eval/stream-bench/plot.py \ + --run_dir results/ace_run_20260116_103642_bird_150_balanced_online \ + --mode online + +# Offline run +python eval/stream-bench/plot.py \ + --run_dir results/ace_run_20260122_203526_bird_150_balanced_offline \ + --mode offline +``` + +Or pass `--plot` to `run.py` to generate plots automatically at the end of a run. 
+ + +## Reference: All CLI Arguments + +### run.py + +``` +Required: + --data_config PATH Config JSON file + --task_name TEXT Task key within the config JSON + --mode {online|offline|eval_only} + --api_provider {sambanova|together|openai} + --generator_model TEXT + +Optional (model): + --reflector_model TEXT + --curator_model TEXT + --max_tokens INT + +Optional (training): + --curriculum {easy_to_hard|hard_to_easy|random} + --num_epochs INT + --max_num_rounds INT + --curator_frequency INT + --eval_steps INT + --online_eval_frequency INT + --save_steps INT + --playbook_token_budget INT + --initial_playbook_path PATH + --pass_sql_eval_results / --no_pass_sql_eval_results + --json_mode / --no_json_mode + --no_ground_truth + --use_bulletpoint_analyzer / --no_bulletpoint_analyzer + --bulletpoint_analyzer_threshold FLOAT + +Optional (output): + --save_path PATH + --test_workers INT + --plot +``` + +### run_playbook.py + +``` +Required: + --results_dir PATH ACE results directory + +Optional: + --playbook_file PATH Playbook file relative to results_dir + (omit for empty-playbook baseline) + --dataset {train|val|test} Split to evaluate (default: test) + --bird_db_root PATH BIRD database root (auto-detected by dataset) + --api_provider {sambanova|together|openai} + --generator_model TEXT + --num_workers INT Parallel generation workers (default: 4) + --output_file PATH Save detailed per-sample results as JSON +``` + +### download_text2sql_data.py + +``` + --dataset {bird|cosql|spider} + --split {dev|train} BIRD only (default: dev) + --save_dir PATH Where to save files (default: ./data) +``` + +### preprocess_streambench_{bird,cosql,spider}.py + +``` + --split {train|validation|test} + --out PATH Output .jsonl file + --tables_json PATH Path to tables.json (recommended for schema) + --bird_root / --cosql_root / --spider_root PATH DB root (fallback schema source) + --schema_format {json|string} (default: json) + --train_ratio FLOAT Train/val split ratio (default: 0.8, BIRD/Spider 
only) + --seed INT Random seed (default: 42) +``` diff --git a/eval/stream-bench/analyze_logs.py b/eval/stream-bench/analyze_logs.py new file mode 100644 index 00000000..2a463a07 --- /dev/null +++ b/eval/stream-bench/analyze_logs.py @@ -0,0 +1,610 @@ +#!/usr/bin/env python3 +""" +Script to analyze errors from ACE terminal log files. +Counts errors during initial and final test accuracy calculations and classifies error types. +""" + +import re +import sys +from pathlib import Path +from collections import defaultdict +from typing import Dict, List, Tuple + + +class ErrorAnalyzer: + def __init__(self, log_file_path: str): + self.log_file_path = Path(log_file_path) + self.initial_errors = defaultdict(list) + self.final_errors = defaultdict(list) + self.between_errors = defaultdict(list) # Errors between initial and final + self.current_phase = None # 'initial', 'final', or 'between' + self.current_window = None # Track current window number + self.window_accuracies = {} # Track window accuracies {window_num: accuracy} + self.cumulative_accuracies = {} # Track cumulative test accuracies {window_num: cumulative_acc} + self.cumulative_samples = {} # Track cumulative sample counts {window_num: total_samples} + self.window_sizes = {} # Track window sizes {window_num: num_samples} + self.default_window_size = None # Default window size if not specified per window + + # Track component-specific errors (generator, reflector, curator) + self.component_errors = defaultdict(lambda: defaultdict(list)) # {phase: {component: [errors]}} + + # Track which problems/samples have errors + self.problem_errors = defaultdict(set) # {phase: set of (window, sample_idx/call_id)} + + # Track current sample being processed + self.current_sample_idx = None + self.current_call_id = None + + def classify_error(self, error_line: str, details: str = "") -> str: + """Classify the type of error based on the error message.""" + combined_text = error_line + " " + details + + # Context length errors + 
if "context_length_exceeded" in combined_text or "tokens exceed" in combined_text: + return "context_length_exceeded" + + # Rate limit errors + if "rate_limit" in combined_text.lower() or "429" in combined_text: + return "rate_limit" + + # Invalid request errors (excluding context length) + if "invalid_request_error" in combined_text or "invalid prompt" in combined_text: + if "context_length" not in combined_text: + return "invalid_request" + + # Authentication errors + if "authentication" in combined_text.lower() or "401" in combined_text: + return "authentication_error" + + # Server errors + if "500" in combined_text or "502" in combined_text or "503" in combined_text: + return "server_error" + + # Timeout errors + if "timeout" in combined_text.lower(): + return "timeout_error" + + # Connection errors + if "connection" in combined_text.lower(): + return "connection_error" + + # Client errors (general) + if "client error" in combined_text.lower(): + return "client_error_other" + + # Unknown errors + return "unknown_error" + + def extract_component(self, line: str) -> str: + """Extract the component name (GENERATOR, REFLECTOR, CURATOR) from the log line.""" + if '[GENERATOR]' in line: + return 'GENERATOR' + elif '[REFLECTOR]' in line: + return 'REFLECTOR' + elif '[CURATOR]' in line: + return 'CURATOR' + return 'UNKNOWN' + + def extract_base_sample_id(self, call_id: str) -> str: + """Extract base sample ID from call ID. 
+ + Examples: + 'online_train_s_1455_round_2' -> '1455' + 'online_train_s_1320_post_curate' -> '1320' + 'gen_call_1' -> 'gen_call_1' + '123' -> '123' + """ + # Try to match patterns like "s_1455_round_2" or "s_1320_post_curate" + match = re.search(r's_(\d+)(?:_round_\d+|_post_curate|_pre_curate)?', call_id) + if match: + return match.group(1) + + # If no pattern match, return the original (might be sample_idx) + return call_id + + def detect_phase(self, line: str) -> None: + """Detect whether we're in the initial or final test accuracy calculation phase.""" + # Track window size (appears early in logs as "Window size: N") + window_size_match = re.search(r'Window size:\s*(\d+)', line, re.IGNORECASE) + if window_size_match: + self.default_window_size = int(window_size_match.group(1)) + + # Track window numbers - these set the context for what window we're in + window_match = re.search(r'WINDOW (\d+)', line, re.IGNORECASE) + if window_match: + self.current_window = int(window_match.group(1)) + + # Also check for "Testing window X" pattern + testing_window_match = re.search(r'Testing window (\d+)', line, re.IGNORECASE) + if testing_window_match: + self.current_window = int(testing_window_match.group(1)) + + # Track window step counts like "Window 1, Step 5/15" + window_step_match = re.search(r'Window\s+(\d+),\s+Step\s+\d+/(\d+)', line, re.IGNORECASE) + if window_step_match: + window_num = int(window_step_match.group(1)) + window_size = int(window_step_match.group(2)) + self.window_sizes[window_num] = window_size + + # Track sample indices and call IDs + sample_match = re.search(r'sample[_ ]?(?:idx|index|#)?[:\s]+(\d+)', line, re.IGNORECASE) + if sample_match: + self.current_sample_idx = int(sample_match.group(1)) + + # Track call IDs from component logs like "[GENERATOR] Call XXX" + call_match = re.search(r'\[(GENERATOR|REFLECTOR|CURATOR)\]\s+Call\s+([A-Za-z0-9_-]+)', line) + if call_match: + self.current_call_id = call_match.group(2) + + # Track window accuracy: 
"Window X test accuracy: Y.YYY" + # This pattern appears after testing each window and tells us which window completed + accuracy_match = re.search(r'Window\s+(\d+)\s+test\s+accuracy:\s+([\d.]+)', line, re.IGNORECASE) + if accuracy_match: + window_num = int(accuracy_match.group(1)) + accuracy = float(accuracy_match.group(2)) + self.window_accuracies[window_num] = accuracy + # Update current window to match the window that just reported accuracy + self.current_window = window_num + + # Track cumulative test accuracy: "Cumulative test accuracy so far: Y.YYY (N samples)" + # This appears right after the window accuracy line + # IMPORTANT: This cumulative accuracy should be associated with the window that just completed + cumulative_match = re.search(r'Cumulative\s+test\s+accuracy\s+so\s+far:\s+([\d.]+)\s+\((\d+)\s+samples\)', line, re.IGNORECASE) + if cumulative_match: + cumulative_acc = float(cumulative_match.group(1)) + total_samples = int(cumulative_match.group(2)) + # Associate with current window (which should have been set by the previous "Window X test accuracy" line) + if self.current_window: + self.cumulative_accuracies[self.current_window] = cumulative_acc + self.cumulative_samples[self.current_window] = total_samples + + # Look for phase indicators in the log + if "initial test acc" in line.lower() or "calculating initial" in line.lower(): + self.current_phase = "initial" + elif "final test acc" in line.lower() or "calculating final" in line.lower(): + self.current_phase = "final" + elif "test acc" in line.lower() and self.current_phase is None: + # If we haven't seen initial yet, assume we're in initial phase + self.current_phase = "initial" + elif self.current_phase == "initial" and "done" in line.lower(): + # Initial phase completed, now in between phase + self.current_phase = "between" + + def analyze_log(self) -> None: + """Parse the log file and categorize errors.""" + if not self.log_file_path.exists(): + print(f"Error: Log file not found at 
{self.log_file_path}") + sys.exit(1) + + with open(self.log_file_path, 'r', encoding='utf-8', errors='ignore') as f: + lines = f.readlines() + + i = 0 + while i < len(lines): + line = lines[i].strip() + + # Detect phase changes + self.detect_phase(line) + + # Look for error indicators + # Match both ⚠️ errors and "failed after" errors + is_warning_error = "⚠️" in line and ("error" in line.lower() or "Error" in line) + is_failed_error = "failed after" in line and ("Error" in line or "error" in line) + + if is_warning_error or is_failed_error: + # Extract error details from the current line + error_msg = line + + # Extract component that caused the error + component = self.extract_component(error_msg) + + # Look ahead for [GENERATOR] Error details on the next line(s) + details = "" + j = i + 1 + while j < len(lines) and j < i + 5: # Look ahead up to 5 lines + next_line = lines[j].strip() + if "[GENERATOR] Error details:" in next_line or "[REFLECTOR] Error details:" in next_line or "[CURATOR] Error details:" in next_line or "Error code:" in next_line: + details += " " + next_line + j += 1 + # Also extract component from error details if not found yet + if component == 'UNKNOWN': + component = self.extract_component(next_line) + elif next_line and not next_line.startswith("["): + # Continuation of error details + details += " " + next_line + j += 1 + else: + break + + # Classify the error + error_type = self.classify_error(error_msg, details) + + # Determine phase for tracking + phase = self.current_phase if self.current_phase else "initial" + + # Store in appropriate phase + error_entry = { + 'line_num': i + 1, + 'message': error_msg, + 'details': details.strip(), + 'window': self.current_window, + 'component': component, + 'sample_idx': self.current_sample_idx, + 'call_id': self.current_call_id + } + + # Track which problems have errors + problem_key = (self.current_window, self.current_sample_idx or self.current_call_id) + if problem_key[1] is not None: # Only track 
if we have a sample/call identifier + self.problem_errors[phase].add(problem_key) + + # Store in component-specific tracking + self.component_errors[phase][component].append(error_entry) + + if self.current_phase == "final": + self.final_errors[error_type].append(error_entry) + elif self.current_phase == "between": + self.between_errors[error_type].append(error_entry) + else: + # Default to initial if phase is unknown + self.initial_errors[error_type].append(error_entry) + + i += 1 + + def print_component_analysis(self, phase: str, phase_name: str) -> None: + """Print component-specific error analysis for a given phase.""" + component_stats = self.component_errors.get(phase, {}) + if not component_stats: + return + + print(f"\nComponent breakdown:") + total_component_errors = sum(len(errors) for errors in component_stats.values()) + + for component in ['GENERATOR', 'REFLECTOR', 'CURATOR', 'UNKNOWN']: + errors = component_stats.get(component, []) + if errors: + percentage = (len(errors) / total_component_errors * 100) if total_component_errors > 0 else 0 + print(f" • {component}: {len(errors)} errors ({percentage:.1f}%)") + + # Count unique problems for this component + unique_problems = set() + for error in errors: + problem_key = (error['window'], error['sample_idx'] or error['call_id']) + if problem_key[1] is not None: + unique_problems.add(problem_key) + + if unique_problems: + avg_errors_per_problem = len(errors) / len(unique_problems) + print(f" - Unique problems affected: {len(unique_problems)}") + print(f" - Average errors per problem: {avg_errors_per_problem:.1f}") + + def print_problem_analysis(self, phase: str, phase_name: str) -> None: + """Print analysis of which problems were affected by errors.""" + problems = self.problem_errors.get(phase, set()) + if not problems: + return + + print(f"\nProblems with errors:") + print(f" • Total unique problems affected: {len(problems)}") + + # Count errors per problem + problem_error_counts = defaultdict(int) + 
def print_report(self, debug=False) -> None:
    """Print a formatted report of the error analysis.

    The report has a header, one section per phase (initial test, final
    test, and errors between the two), one sample error per error type,
    a summary of totals, and a callout for a fixed set of key windows.

    Args:
        debug: When True, also dump the per-window accuracies captured
            from the log before the phase sections.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    def count_errors(errors_by_type):
        # Total number of error records across all error types.
        return sum(len(errors) for errors in errors_by_type.values())

    def print_window_distribution(errors_by_type):
        # Per-window error counts plus unique call / base-sample counts.
        window_counts = defaultdict(int)
        window_problems = defaultdict(set)  # unique API calls per window
        window_base_samples = defaultdict(set)  # unique base sample IDs per window
        for errors in errors_by_type.values():
            for error in errors:
                window = error['window']
                if not window:
                    continue
                window_counts[window] += 1
                problem_id = error['sample_idx'] or error['call_id']
                if problem_id is None:
                    continue
                window_problems[window].add(problem_id)
                # e.g. "1455" from "online_train_s_1455_round_2"
                base_sample = self.extract_base_sample_id(str(problem_id))
                if base_sample:
                    window_base_samples[window].add(base_sample)

        if not window_counts:
            return

        print("\nErrors by window:")
        for window in sorted(window_counts):
            acc_info = ""
            if window in self.cumulative_accuracies:
                samples_info = (
                    f", {self.cumulative_samples[window]} samples"
                    if window in self.cumulative_samples
                    else ""
                )
                acc_info = f" (cumulative acc: {self.cumulative_accuracies[window]:.3f}{samples_info})"
            elif window in self.window_accuracies:
                acc_info = f" (window acc: {self.window_accuracies[window]:.3f})"

            unique_calls = len(window_problems[window])
            unique_samples = len(window_base_samples[window])
            # Total problems in the window, when known.
            window_size = self.window_sizes.get(window, self.default_window_size)
            if window_size and unique_samples > 0:
                problems_info = f", {unique_samples}/{window_size} problems ({unique_calls} unique API calls)"
            elif unique_samples > 0:
                problems_info = f", {unique_samples} problems ({unique_calls} unique API calls)"
            else:
                problems_info = ""
            print(f" • Window {window}: {window_counts[window]} errors{problems_info}{acc_info}")

    def print_phase(phase, banner, title, errors_by_type, always_show):
        # One phase section; returns the phase's error total.
        total = count_errors(errors_by_type)
        if total == 0 and not always_show:
            return total
        print(light_rule)
        print(banner)
        print(light_rule)
        print(f"Total errors: {total}")
        if total > 0:
            print("\nError breakdown by type:")
            for error_type, errors in sorted(errors_by_type.items()):
                print(f" • {error_type}: {len(errors)}")
            self.print_component_analysis(phase, title)
            self.print_problem_analysis(phase, title)
            print_window_distribution(errors_by_type)
        print()
        return total

    def print_samples(label, errors_by_type):
        # First recorded error of every error type in one phase.
        print(f"\n[{label} - Sample errors]")
        for error_type, errors in sorted(errors_by_type.items()):
            if not errors:
                # An error type may be registered without any record;
                # indexing errors[0] would raise IndexError.
                continue
            first = errors[0]
            window_info = f"Window {first['window']}" if first['window'] else "Window unknown"
            print(f"\n {error_type} (showing first of {len(errors)}):")
            print(f" {window_info}, Line {first['line_num']}: {first['message'][:100]}...")
            if first['details']:
                print(f" Details: {first['details'][:200]}...")

    print(heavy_rule)
    print("ERROR ANALYSIS REPORT")
    print(heavy_rule)
    print(f"Log file: {self.log_file_path}")

    # Debug: show what accuracies were captured
    if debug:
        print("\n[DEBUG] Captured accuracies:")
        for window in sorted(self.window_accuracies.keys()):
            win_acc = self.window_accuracies.get(window, "N/A")
            cum_acc = self.cumulative_accuracies.get(window, "N/A")
            samples = self.cumulative_samples.get(window, "N/A")
            print(f" Window {window}: window_acc={win_acc}, cumulative_acc={cum_acc}, samples={samples}")
        print()

    # The between-phase section is only shown when it has errors.
    total_initial = print_phase(
        "initial", "INITIAL TEST ACC CALCULATION", "INITIAL TEST ACC",
        self.initial_errors, True,
    )
    total_final = print_phase(
        "final", "FINAL TEST ACC CALCULATION", "FINAL TEST ACC",
        self.final_errors, True,
    )
    total_between = print_phase(
        "between", "ERRORS BETWEEN INITIAL AND FINAL TEST ACC", "BETWEEN PHASES",
        self.between_errors, False,
    )

    # Detailed error samples
    if total_initial > 0 or total_final > 0 or total_between > 0:
        print(heavy_rule)
        print("DETAILED ERROR SAMPLES")
        print(heavy_rule)
        if total_initial > 0:
            print_samples("INITIAL TEST ACC", self.initial_errors)
        if total_final > 0:
            print_samples("FINAL TEST ACC", self.final_errors)
        if total_between > 0:
            print_samples("BETWEEN PHASES", self.between_errors)

    print("\n" + heavy_rule)
    print("SUMMARY")
    print(heavy_rule)
    print(f"Initial test acc errors: {total_initial}")
    print(f"Final test acc errors: {total_final}")
    print(f"Between phases errors: {total_between}")
    print(f"Total errors: {total_initial + total_final + total_between}")

    # Special callout for specific windows. A window qualifies when either
    # kind of accuracy was captured for it, since cumulative is preferred.
    special_windows = [30, 50, 70, 80, 90, 100]
    windows_found = [
        w for w in special_windows
        if w in self.cumulative_accuracies or w in self.window_accuracies
    ]

    if windows_found:
        print("\nKey Window Statistics:")
        all_phases = (self.initial_errors, self.final_errors, self.between_errors)
        for window_num in windows_found:
            window_errors = sum(
                1
                for errors_by_type in all_phases
                for errors in errors_by_type.values()
                for e in errors
                if e['window'] == window_num
            )

            # Prefer cumulative accuracy over window accuracy
            if window_num in self.cumulative_accuracies:
                acc_value = self.cumulative_accuracies[window_num]
                acc_label = "Cumulative Acc"
                samples = self.cumulative_samples.get(window_num)
                samples_str = f" ({samples} samples)" if samples else ""
            else:
                acc_value = self.window_accuracies[window_num]
                acc_label = "Window Acc"
                samples_str = ""

            if acc_value is not None:
                print(f" Window {window_num:3d} - {acc_label}: {acc_value:.3f}{samples_str}, Errors: {window_errors}")
            else:
                print(f" Window {window_num:3d} - Errors: {window_errors}")

    print(heavy_rule)
def main():
    """Parse the command line, analyze the given log and print the report.

    Usage: ``python analyze_logs.py <log_file> [--debug]``. The ``--debug``
    flag may appear before or after the log path. Exits with status 1 and a
    usage message when no log path is given.
    """
    args = sys.argv[1:]
    debug = "--debug" in args
    # Anything that is not the --debug flag is a positional argument, so
    # "--debug <log>" is not mistaken for a log file named "--debug".
    positional = [arg for arg in args if arg != "--debug"]

    if not positional:
        print("Usage: python analyze_logs.py <log_file> [--debug]")
        print("\nExample:")
        print(" python analyze_logs.py results/ace_run/terminal_log.txt")
        print(" python analyze_logs.py results/ace_run_20240115/terminal_log.txt")
        print(" python analyze_logs.py results/ace_run/terminal_log.txt --debug")
        sys.exit(1)

    log_file = positional[0]

    analyzer = ErrorAnalyzer(log_file)
    analyzer.analyze_log()
    analyzer.print_report(debug=debug)
(train, val, test)" + }, + "bird_custom_separate": { + "train_data": "stream-bench/data/streambench_bird_train.jsonl", + "val_data": "stream-bench/data/streambench_bird_val.jsonl", + "test_data": "stream-bench/data/streambench_bird_test.jsonl", + "max_train_samples": 100, + "max_val_samples": 50, + "max_test_samples": 10, + "bird_train_db_root": "stream-bench/data/bird_train/train_databases", + "bird_val_db_root": "stream-bench/data/bird_train/train_databases", + "bird_test_db_root": "stream-bench/data/bird/dev_databases", + "_comment": "Separate limits: train=100, val=50, test=10 (no default)" + }, + "bird_custom_with_default": { + "train_data": "stream-bench/data/streambench_bird_train.jsonl", + "val_data": "stream-bench/data/streambench_bird_val.jsonl", + "test_data": "stream-bench/data/streambench_bird_test.jsonl", + "max_samples": 50, + "max_train_samples": 200, + "bird_train_db_root": "stream-bench/data/bird_train/train_databases", + "bird_val_db_root": "stream-bench/data/bird_train/train_databases", + "bird_test_db_root": "stream-bench/data/bird/dev_databases", + "_comment": "Default=50, but train overridden to 200. 
Result: train=200, val=50, test=50" + }, + "bird_30": { + "train_data": "stream-bench/data/streambench_bird_train.jsonl", + "val_data": "stream-bench/data/streambench_bird_val.jsonl", + "test_data": "stream-bench/data/streambench_bird_test.jsonl", + "max_samples": 30, + "bird_train_db_root": "stream-bench/data/bird_train/train_databases", + "bird_val_db_root": "stream-bench/data/bird_train/train_databases", + "bird_test_db_root": "stream-bench/data/bird/dev_databases" + }, + "bird_150": { + "train_data": "stream-bench/data/streambench_bird_train.jsonl", + "val_data": "stream-bench/data/streambench_bird_val.jsonl", + "test_data": "stream-bench/data/streambench_bird_test.jsonl", + "max_samples": 150, + "bird_train_db_root": "stream-bench/data/bird_train/train_databases", + "bird_val_db_root": "stream-bench/data/bird_train/train_databases", + "bird_test_db_root": "stream-bench/data/bird/dev_databases" + }, + "bird_150_balanced": { + "train_data": "stream-bench/data/streambench_bird_train.jsonl", + "val_data": "stream-bench/data/streambench_bird_val.jsonl", + "test_data": "stream-bench/data/streambench_bird_test.jsonl", + "max_samples": 150, + "max_val_samples": 30, + "difficulty_filter": "balanced", + "bird_train_db_root": "stream-bench/data/bird_train/train_databases", + "bird_val_db_root": "stream-bench/data/bird_train/train_databases", + "bird_test_db_root": "stream-bench/data/bird/dev_databases" + }, + "bird_300_balanced": { + "train_data": "stream-bench/data/streambench_bird_train.jsonl", + "val_data": "stream-bench/data/streambench_bird_val.jsonl", + "test_data": "stream-bench/data/streambench_bird_test.jsonl", + "max_samples": 300, + "difficulty_filter": "balanced", + "bird_train_db_root": "stream-bench/data/bird_train/train_databases", + "bird_val_db_root": "stream-bench/data/bird_train/train_databases", + "bird_test_db_root": "stream-bench/data/bird/dev_databases" + }, + "bird_60_balanced": { + "train_data": "stream-bench/data/streambench_bird_train.jsonl", + 
"val_data": "stream-bench/data/streambench_bird_val.jsonl", + "test_data": "stream-bench/data/streambench_bird_test.jsonl", + "max_samples": 60, + "difficulty_filter": "balanced", + "bird_train_db_root": "stream-bench/data/bird_train/train_databases", + "bird_val_db_root": "stream-bench/data/bird_train/train_databases", + "bird_test_db_root": "stream-bench/data/bird/dev_databases" + }, + "bird_432_balanced": { + "train_data": "stream-bench/data/streambench_bird_train.jsonl", + "val_data": "stream-bench/data/streambench_bird_val.jsonl", + "test_data": "stream-bench/data/streambench_bird_test.jsonl", + "max_samples": 432, + "difficulty_filter": "balanced", + "bird_train_db_root": "stream-bench/data/bird_train/train_databases", + "bird_val_db_root": "stream-bench/data/bird_train/train_databases", + "bird_test_db_root": "stream-bench/data/bird/dev_databases" + }, + "bird_1000_quasi_balanced": { + "train_data": "stream-bench/data/streambench_bird_train.jsonl", + "val_data": "stream-bench/data/streambench_bird_val.jsonl", + "test_data": "stream-bench/data/streambench_bird_test.jsonl", + "max_samples": 1000, + "difficulty_filter": "quasi_balanced", + "bird_train_db_root": "stream-bench/data/bird_train/train_databases", + "bird_val_db_root": "stream-bench/data/bird_train/train_databases", + "bird_test_db_root": "stream-bench/data/bird/dev_databases" + } +} diff --git a/eval/stream-bench/data/cosql_config.json b/eval/stream-bench/data/cosql_config.json new file mode 100644 index 00000000..688b2454 --- /dev/null +++ b/eval/stream-bench/data/cosql_config.json @@ -0,0 +1,64 @@ +{ + "cosql_all": { + "train_data": "stream-bench/data/streambench_cosql_train.jsonl", + "val_data": "stream-bench/data/streambench_cosql_val.jsonl", + "test_data": "stream-bench/data/streambench_cosql_test.jsonl", + "cosql_db_root": "stream-bench/data/cosql" + }, + "cosql_10": { + "train_data": "stream-bench/data/streambench_cosql_train.jsonl", + "val_data": 
"stream-bench/data/streambench_cosql_val.jsonl", + "test_data": "stream-bench/data/streambench_cosql_test.jsonl", + "max_samples": 10, + "cosql_db_root": "stream-bench/data/cosql", + "_comment": "Old format: max_samples applies to all splits (train, val, test)" + }, + "cosql_36_balanced": { + "train_data": "stream-bench/data/streambench_cosql_train.jsonl", + "val_data": "stream-bench/data/streambench_cosql_val.jsonl", + "test_data": "stream-bench/data/streambench_cosql_test.jsonl", + "max_samples": 36, + "cosql_db_root": "stream-bench/data/cosql" + }, + "cosql_50": { + "train_data": "stream-bench/data/streambench_cosql_train.jsonl", + "val_data": "stream-bench/data/streambench_cosql_val.jsonl", + "test_data": "stream-bench/data/streambench_cosql_test.jsonl", + "max_samples": 50, + "cosql_db_root": "stream-bench/data/cosql" + }, + "cosql_150": { + "train_data": "stream-bench/data/streambench_cosql_train.jsonl", + "val_data": "stream-bench/data/streambench_cosql_val.jsonl", + "test_data": "stream-bench/data/streambench_cosql_test.jsonl", + "max_samples": 150, + "cosql_db_root": "stream-bench/data/cosql" + }, + "cosql_150_balanced": { + "train_data": "stream-bench/data/streambench_cosql_train.jsonl", + "val_data": "stream-bench/data/streambench_cosql_val.jsonl", + "test_data": "stream-bench/data/streambench_cosql_test.jsonl", + "max_samples": 150, + "max_val_samples": 90, + "difficulty_filter": "balanced", + "cosql_db_root": "stream-bench/data/cosql" + }, + "cosql_300_balanced": { + "train_data": "stream-bench/data/streambench_cosql_train.jsonl", + "val_data": "stream-bench/data/streambench_cosql_val.jsonl", + "test_data": "stream-bench/data/streambench_cosql_test.jsonl", + "max_samples": 300, + "difficulty_filter": "balanced", + "cosql_db_root": "stream-bench/data/cosql" + }, + "cosql_custom": { + "train_data": "stream-bench/data/streambench_cosql_train.jsonl", + "val_data": "stream-bench/data/streambench_cosql_val.jsonl", + "test_data": 
"stream-bench/data/streambench_cosql_test.jsonl", + "max_train_samples": 100, + "max_val_samples": 50, + "max_test_samples": 10, + "cosql_db_root": "stream-bench/data/cosql", + "_comment": "Separate limits: train=100, val=50, test=10" + } +} diff --git a/eval/stream-bench/data/spider_config.json b/eval/stream-bench/data/spider_config.json new file mode 100644 index 00000000..39ddd986 --- /dev/null +++ b/eval/stream-bench/data/spider_config.json @@ -0,0 +1,66 @@ +{ + "spider_all": { + "train_data": "stream-bench/data/streambench_spider_train.jsonl", + "val_data": "stream-bench/data/streambench_spider_val.jsonl", + "test_data": "stream-bench/data/streambench_spider_test.jsonl", + "bird_db_root": "stream-bench/data/spider" + }, + "spider_10": { + "train_data": "stream-bench/data/streambench_spider_train.jsonl", + "val_data": "stream-bench/data/streambench_spider_val.jsonl", + "test_data": "stream-bench/data/streambench_spider_test.jsonl", + "max_samples": 10, + "bird_db_root": "stream-bench/data/spider", + "_comment": "Old format: max_samples applies to all splits (train, val, test)" + }, + "spider_30": { + "train_data": "stream-bench/data/streambench_spider_train.jsonl", + "val_data": "stream-bench/data/streambench_spider_val.jsonl", + "test_data": "stream-bench/data/streambench_spider_test.jsonl", + "max_samples": 30, + "bird_db_root": "stream-bench/data/spider" + }, + "spider_150": { + "train_data": "stream-bench/data/streambench_spider_train.jsonl", + "val_data": "stream-bench/data/streambench_spider_val.jsonl", + "test_data": "stream-bench/data/streambench_spider_test.jsonl", + "max_samples": 150, + "bird_db_root": "stream-bench/data/spider" + }, + "spider_150_balanced": { + "train_data": "stream-bench/data/streambench_spider_train.jsonl", + "val_data": "stream-bench/data/streambench_spider_val.jsonl", + "test_data": "stream-bench/data/streambench_spider_test.jsonl", + "max_samples": 150, + "max_val_samples": 30, + "difficulty_filter": "balanced", + 
"bird_db_root": "stream-bench/data/spider" + }, + "spider_15_balanced": { + "train_data": "stream-bench/data/streambench_spider_train.jsonl", + "val_data": "stream-bench/data/streambench_spider_val.jsonl", + "test_data": "stream-bench/data/streambench_spider_test.jsonl", + "max_samples": 15, + "difficulty_filter": "balanced", + "bird_db_root": "stream-bench/data/spider" + }, + "spider_15_quasi_balanced": { + "train_data": "stream-bench/data/streambench_spider_train.jsonl", + "val_data": "stream-bench/data/streambench_spider_val.jsonl", + "test_data": "stream-bench/data/streambench_spider_test.jsonl", + "max_samples": 15, + "difficulty_filter": "quasi_balanced", + "bird_db_root": "stream-bench/data/spider", + "_comment": "Uses fallback: simple->moderate, challenging->moderate, moderate->half simple + half challenging" + }, + "spider_150_quasi_balanced": { + "train_data": "stream-bench/data/streambench_spider_train.jsonl", + "val_data": "stream-bench/data/streambench_spider_val.jsonl", + "test_data": "stream-bench/data/streambench_spider_test.jsonl", + "max_samples": 150, + "max_val_samples": 30, + "difficulty_filter": "quasi_balanced", + "bird_db_root": "stream-bench/data/spider", + "_comment": "Quasi-balanced with separate val sample limit" + } +} diff --git a/eval/stream-bench/data_processor.py b/eval/stream-bench/data_processor.py new file mode 100644 index 00000000..92aa4588 --- /dev/null +++ b/eval/stream-bench/data_processor.py @@ -0,0 +1,697 @@ +import os +import re +import sqlite3 +import time +from typing import List, Dict, Any, Optional, Tuple + +class DataProcessor: + """ + DataProcessor for BIRD and CoSQL with thread-safe evaluation. + + Evaluation mode: execute predicted & gold on sqlite DB and compare result sets + + Thread-safety: db_name metadata flows through sample['others'] dict to enable + safe parallel evaluation across multiple workers. 
def __init__(
    self,
    bird_db_root: Optional[str] = None,
    cosql_db_root: Optional[str] = None,
    exec_timeout_ms: int = 20000,
    exec_max_rows: int = 20000,
    max_samples: Optional[int] = None,  # None = use all samples
    db_name: Optional[str] = None,  # None = use mixed databases (no filter)
    difficulty_filter: Optional[str] = None,  # None = no difficulty filtering
    curriculum: Optional[str] = None,  # None = no curriculum ordering
    task: str = "bird",  # "bird" or "cosql"
):
    """
    Store the processor configuration and validate the strategy names.

    Args:
        bird_db_root: Root directory for BIRD databases
        cosql_db_root: Root directory for CoSQL databases
        exec_timeout_ms: Stored as-is on the instance
        exec_max_rows: Stored as-is on the instance
        max_samples: Upper bound on the number of samples (None = all)
        db_name: Keep only samples of this database (None = no filter)
        difficulty_filter: Dataset-level selection by difficulty. One of
            None, "simple-only", "moderate-only", "challenging-only",
            "balanced" (1/3 from each difficulty) or "quasi_balanced"
            (1/3 each, falling back to the closest difficulty when one
            level runs short)
        curriculum: Run-level ordering. One of None (dataset order),
            "easy_to_hard", "hard_to_easy" or "random"
        task: Task type ("bird" or "cosql")

    Raises:
        ValueError: If difficulty_filter or curriculum is not a known option.
    """
    # Database locations and task selection
    self.bird_db_root = bird_db_root
    self.cosql_db_root = cosql_db_root
    self.task = task

    # Execution limits
    self.exec_timeout_ms = exec_timeout_ms
    self.exec_max_rows = exec_max_rows

    # Sample selection and ordering
    self.max_samples = max_samples
    self.db_name = db_name
    self.difficulty_filter = difficulty_filter
    self.curriculum = curriculum

    # None is always accepted; anything else must be a known strategy name.
    known_filters = (
        "simple-only", "moderate-only", "challenging-only", "balanced", "quasi_balanced",
    )
    if self.difficulty_filter is not None and self.difficulty_filter not in known_filters:
        raise ValueError(
            f"Invalid difficulty_filter '{self.difficulty_filter}'. "
            f"Valid options: {list(known_filters)}"
        )

    known_curricula = ("easy_to_hard", "hard_to_easy", "random")
    if self.curriculum is not None and self.curriculum not in known_curricula:
        raise ValueError(
            f"Invalid curriculum '{self.curriculum}'. "
            f"Valid options: {list(known_curricula)}"
        )
def process_task_data(self, raw_data):
    """
    Turn raw JSONL rows into the standardized sample format:
    [{"context": ..., "question": ..., "target": ..., "others": {...}}]

    Rows are filtered by db_name, filtered by difficulty, reordered by
    curriculum and finally truncated to max_samples, in that order.
    db_name is stored in each sample's 'others' dict so evaluation can find
    the right database per sample.
    """
    # Step 1: keep only rows of the requested database
    wanted_db = self.db_name
    if wanted_db is not None:
        raw_data = [
            row for row in raw_data
            if (row.get("db_name") or row.get("db_id") or "") == wanted_db
        ]
        print(f"After db_name filter ('{self.db_name}'): {len(raw_data)} samples")

    # Step 2: dataset-level selection by difficulty
    if self.difficulty_filter is not None:
        raw_data = self._apply_difficulty_filter(raw_data)

    # Step 3: run-level ordering
    if self.curriculum is not None:
        raw_data = self._apply_curriculum_ordering(raw_data)

    # Step 4: cap the sample count (the "balanced" filter already applied it)
    limit = self.max_samples
    needs_cap = limit is not None and self.difficulty_filter != "balanced"
    if needs_cap and len(raw_data) > limit:
        raw_data = raw_data[:limit]
        print(f"Applied max_samples limit: {self.max_samples} samples")

    # Print summary of processed data
    self._print_data_summary(raw_data)

    def to_sample(row):
        # Build one standardized sample from a raw row.
        db_schema = row.get("db_schema", "")
        context = f"""You are given a database schema and a question.

INSTRUCTIONS:
- Output ONLY a valid SQL query.
- Do NOT include explanations, comments, markdown, or any extra text.
- You may ONLY reference tables and columns that appear in the schema below.
- Do NOT hallucinate tables, columns, or relationships.

DATABASE SCHEMA:
{db_schema}
"""
        return {
            "context": context,
            "question": row.get("question", ""),
            "target": row.get("sql", "") or row.get("SQL", ""),
            "others": {
                "question_id": row.get("question_id"),
                "difficulty": row.get("difficulty"),
                "db_name": row.get("db_name") or row.get("db_id") or "",
                "task": self.task,
                "data_source": f"streambench_{self.task}",
                "turn_id": row.get("turn_id"),  # For CoSQL conversational turns
            },
        }

    return [to_sample(row) for row in raw_data]
def answer_is_correct(self, predicted, ground_truth, sample_metadata=None, return_exec_results=False):
    """
    Compare predicted vs ground_truth using exec mode.

    Args:
        predicted: Predicted SQL query
        ground_truth: Ground truth SQL query
        sample_metadata: Optional dict containing 'db_name' and other metadata
        return_exec_results: If True, return tuple of (is_correct, exec_results_dict)

    Returns:
        bool: True if execution results match, False otherwise
        OR
        tuple: (bool, dict) if return_exec_results=True. On failure the dict
        holds 'error' and 'db_name'; on success it is whatever
        self._exec_match returns as its second element.

    Raises:
        FileNotFoundError: Re-raised from self._exec_match so a missing
            database stops the run instead of being scored as wrong.
    """
    def failure(message, db):
        # Uniform "incorrect" result in whichever shape the caller asked for.
        if return_exec_results:
            return False, {"error": message, "db_name": db}
        return False

    # Extract db_name from metadata
    db_name = ""
    if sample_metadata:
        db_name = sample_metadata.get("db_name", "")

    # Without a database there is nothing to execute against.
    if not db_name:
        print("Warning: No db_name available in sample metadata")
        return failure("No db_name available in sample metadata", "")

    # A missing or non-string query (e.g. None from a failed generation)
    # cannot be executed; score it as incorrect instead of raising a
    # TypeError while building the log preview below.
    if not isinstance(predicted, str) or not isinstance(ground_truth, str):
        message = "Predicted and ground truth SQL must both be strings"
        print(f"[EVAL ERROR] Exception during evaluation: {message}")
        return failure(message, db_name)

    print(f"\n[EVAL START] Evaluating on DB: {db_name}")
    print(f"[EVAL START] Predicted SQL: {predicted[:200]}...")  # First 200 chars
    print(f"[EVAL START] Ground truth SQL: {ground_truth[:200]}...")

    try:
        result, exec_results = self._exec_match(
            predicted, ground_truth, db_name, return_exec_results=return_exec_results
        )
    except FileNotFoundError:
        # Re-raise database not found errors to stop execution
        raise
    except Exception as e:
        # Other execution errors: print and return False
        print(f"[EVAL ERROR] Exception during evaluation: {e}")
        return failure(str(e), db_name)

    print(f"[EVAL DONE] Result: {result}")
    if return_exec_results:
        return result, exec_results
    return result
def evaluate_accuracy(self, predictions, ground_truths, samples=None):
    """
    Return the fraction of predictions judged correct.

    A prediction that has a matching entry in `samples` is scored with
    answer_is_correct() using that entry's 'others' metadata. All other
    predictions are scored by case-insensitive, whitespace-stripped string
    equality.

    Args:
        predictions: List of predicted SQL queries
        ground_truths: List of ground truth SQL queries
        samples: Optional list of sample dicts with 'others' metadata

    Returns:
        float: Accuracy score (0.0 to 1.0); 0.0 for empty input

    Raises:
        ValueError: If the two lists differ in length.
    """
    total = len(predictions)
    if total != len(ground_truths):
        raise ValueError("predictions and ground_truths must have the same length")
    if total == 0:
        return 0.0

    hits = 0
    for index, (pred, gold) in enumerate(zip(predictions, ground_truths)):
        if samples and index < len(samples):
            # Execution-based check, driven by the sample's metadata
            entry = samples[index]
            metadata = entry.get("others", None) if isinstance(entry, dict) else None
            matched = self.answer_is_correct(pred, gold, metadata)
        else:
            # No metadata available: fall back to normalized string equality
            matched = pred.strip().lower() == gold.strip().lower()
        if matched:
            hits += 1

    return hits / total
+ + Args: + raw_data: List of raw data items with 'difficulty' field + + Returns: + Filtered list based on difficulty_filter strategy + """ + if not self.difficulty_filter: + return raw_data + + # Categorize samples by difficulty + simple_samples = [] + moderate_samples = [] + challenging_samples = [] + unknown_samples = [] + + for item in raw_data: + difficulty = (item.get("difficulty") or "").lower() + if difficulty == "simple": + simple_samples.append(item) + elif difficulty == "moderate": + moderate_samples.append(item) + elif difficulty == "challenging": + challenging_samples.append(item) + else: + unknown_samples.append(item) + + # Apply difficulty filter strategy + if self.difficulty_filter == "simple-only": + result = simple_samples + print(f"Difficulty filter 'simple-only': Selected {len(result)} simple samples") + + elif self.difficulty_filter == "moderate-only": + result = moderate_samples + print(f"Difficulty filter 'moderate-only': Selected {len(result)} moderate samples") + + elif self.difficulty_filter == "challenging-only": + result = challenging_samples + print(f"Difficulty filter 'challenging-only': Selected {len(result)} challenging samples") + + elif self.difficulty_filter == "balanced": + # Equal distribution from each difficulty (1/3 each) + # If max_samples is set, distribute it equally across difficulties + if self.max_samples is not None: + target_per_difficulty = self.max_samples // 3 + min_count = min( + len(simple_samples), + len(moderate_samples), + len(challenging_samples), + target_per_difficulty + ) + else: + min_count = min(len(simple_samples), len(moderate_samples), len(challenging_samples)) + + # Check if we have samples from all difficulty levels + if min_count == 0: + missing = [] + if len(simple_samples) == 0: + missing.append("simple") + if len(moderate_samples) == 0: + missing.append("moderate") + if len(challenging_samples) == 0: + missing.append("challenging") + raise ValueError( + f"Difficulty filter 'balanced' requires 
samples from all difficulty levels. " + f"Missing difficulty levels: {', '.join(missing)}. " + f"Available: simple={len(simple_samples)}, moderate={len(moderate_samples)}, " + f"challenging={len(challenging_samples)}" + ) + + # Combine samples without reordering (order depends on curriculum) + result = ( + simple_samples[:min_count] + + moderate_samples[:min_count] + + challenging_samples[:min_count] + ) + + # Check if we can meet max_samples requirement + if self.max_samples is not None and len(result) < self.max_samples: + raise ValueError( + f"Cannot meet max_samples={self.max_samples} with difficulty filter 'balanced'. " + f"Need {self.max_samples // 3} samples per difficulty, but only have: " + f"simple={len(simple_samples)}, moderate={len(moderate_samples)}, " + f"challenging={len(challenging_samples)}. " + f"Can only provide {len(result)} samples ({min_count} of each difficulty)." + ) + + print(f"Difficulty filter 'balanced': Selected {min_count} from each difficulty " + f"(total: {len(result)} samples)") + + elif self.difficulty_filter == "quasi_balanced": + # Quasi-balanced: Try to get equal distribution (1/3 each), + # but if a difficulty doesn't have enough samples, use closest difficulty: + # - For challenging: use moderate if not enough challenging + # - For simple: use moderate if not enough simple + # - For moderate: use half simple and half challenging if not enough moderate + + if self.max_samples is not None: + target_per_difficulty = self.max_samples // 3 + else: + # Try to get the maximum possible balanced distribution + target_per_difficulty = max( + len(simple_samples), + len(moderate_samples), + len(challenging_samples) + ) + + # Collect samples for each difficulty category with fallback + selected_simple = [] + selected_moderate = [] + selected_challenging = [] + + # Simple samples: use simple first, then moderate + if len(simple_samples) >= target_per_difficulty: + selected_simple = simple_samples[:target_per_difficulty] + else: + 
selected_simple = simple_samples[:] + needed = target_per_difficulty - len(selected_simple) + # Use moderate as fallback + selected_simple.extend(moderate_samples[:needed]) + + # Challenging samples: use challenging first, then moderate + if len(challenging_samples) >= target_per_difficulty: + selected_challenging = challenging_samples[:target_per_difficulty] + else: + selected_challenging = challenging_samples[:] + needed = target_per_difficulty - len(selected_challenging) + # Use moderate as fallback + selected_challenging.extend(moderate_samples[:needed]) + + # Moderate samples: use moderate first, then half simple and half challenging + if len(moderate_samples) >= target_per_difficulty: + selected_moderate = moderate_samples[:target_per_difficulty] + else: + selected_moderate = moderate_samples[:] + needed = target_per_difficulty - len(selected_moderate) + # Split needed samples between simple and challenging + half_needed = needed // 2 + remainder = needed % 2 + + # Use simple and challenging (not already used in other categories) + # To avoid reusing samples, we need to track what we've already taken + simple_used = len(selected_simple) if len(simple_samples) < target_per_difficulty else 0 + challenging_used = len(selected_challenging) if len(challenging_samples) < target_per_difficulty else 0 + + from_simple = simple_samples[simple_used:simple_used + half_needed + remainder] + from_challenging = challenging_samples[challenging_used:challenging_used + half_needed] + + selected_moderate.extend(from_simple) + selected_moderate.extend(from_challenging) + + # Combine samples + result = selected_simple + selected_moderate + selected_challenging + + # Print detailed selection info + simple_from_moderate = max(0, target_per_difficulty - len(simple_samples)) + challenging_from_moderate = max(0, target_per_difficulty - len(challenging_samples)) + moderate_from_others = max(0, target_per_difficulty - len(moderate_samples)) + + print(f"Difficulty filter 
def _apply_curriculum_ordering(self, raw_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Apply curriculum-based ordering to samples (run-level).

    Args:
        raw_data: List of raw data items; each item's 'difficulty' value is
            compared case-insensitively against 'simple', 'moderate' and
            'challenging'. Anything else (including a missing value) is
            treated as unknown.

    Returns:
        The list reordered according to ``self.curriculum``:
        - falsy curriculum or an unrecognised value: ``raw_data`` itself, untouched
        - 'easy_to_hard': simple, moderate, challenging, then unknown items
        - 'hard_to_easy': challenging, moderate, simple, then unknown items
        - 'random': a shuffled copy (fixed seed 42); ``raw_data`` is not mutated
    """
    if not self.curriculum:
        return raw_data

    if self.curriculum == "random":
        import random

        result = raw_data.copy()
        # A private generator yields the same permutation as
        # random.seed(42) + random.shuffle(), without resetting the
        # process-wide RNG state that other code may depend on.
        random.Random(42).shuffle(result)
        print(f"Curriculum 'random': Randomly shuffled {len(result)} samples (seed=42)")
        return result

    if self.curriculum not in ("easy_to_hard", "hard_to_easy"):
        # Should not reach here due to validation in __init__
        return raw_data

    # Categorize samples by difficulty, preserving input order inside each bucket
    buckets = {"simple": [], "moderate": [], "challenging": []}
    unknown_samples = []
    for item in raw_data:
        difficulty = (item.get("difficulty") or "").lower()
        buckets.get(difficulty, unknown_samples).append(item)

    simple_samples = buckets["simple"]
    moderate_samples = buckets["moderate"]
    challenging_samples = buckets["challenging"]

    if self.curriculum == "easy_to_hard":
        # Easy -> Medium -> Challenging (unknown difficulties go last)
        result = simple_samples + moderate_samples + challenging_samples + unknown_samples
        print(f"Curriculum 'easy_to_hard': Ordered {len(simple_samples)} simple -> "
              f"{len(moderate_samples)} moderate -> {len(challenging_samples)} challenging")
    else:
        # Challenging -> Medium -> Easy (unknown difficulties go last)
        result = challenging_samples + moderate_samples + simple_samples + unknown_samples
        print(f"Curriculum 'hard_to_easy': Ordered {len(challenging_samples)} challenging -> "
              f"{len(moderate_samples)} moderate -> {len(simple_samples)} simple")

    return result
def _exec_match(self, predicted_sql: str, gold_sql: str, db_name: str, return_exec_results: bool = False):
    """
    Run the predicted and the gold SQL on the same database and compare results.

    Args:
        predicted_sql: SQL produced by the model.
        gold_sql: Reference SQL.
        db_name: Database name, resolved to a file via ``self._find_sqlite_path``.
        return_exec_results: When True, the second tuple element carries the raw
            results (or the error text); otherwise it is an empty dict.

    Returns:
        ``(match, info)``. ``match`` is True when the normalized result sets are
        equal; any exception raised while executing or comparing yields False.

    Raises:
        FileNotFoundError: If no SQLite file can be located for ``db_name``.
    """
    sqlite_path = self._find_sqlite_path(db_name)
    if not sqlite_path:
        # DB not found -> report and stop execution
        if self.task == "cosql":
            db_root = self.cosql_db_root
        else:
            db_root = self.bird_db_root
        error_msg = f"SQLite DB for {db_name} not found under {db_root}. Please check database configuration."
        report = (
            f"\n--- FATAL ERROR: Database Not Found ---",
            f"DB: {db_name}",
            f"Task: {self.task}",
            f"Expected location: {db_root}/{db_name}/{db_name}.sqlite",
            f"Error: {error_msg}",
            "-" * 50,
        )
        for line in report:
            print(line)
        raise FileNotFoundError(error_msg)

    def show(title, sql_text, rows):
        # Echo the query followed by a preview of at most 10 result rows
        print(f"\n{title} SQL:\n{sql_text}")
        print(f"\n{title} Result ({len(rows)} rows):")
        for row in rows[:10]:
            print(f" {row}")
        if len(rows) > 10:
            print(f" ... ({len(rows) - 10} more rows)")

    try:
        print(f"[EXEC] Running PREDICTED SQL on {db_name}")
        pred_res = self._run_sql(sqlite_path, predicted_sql)

        print(f"[EXEC] Running GROUND TRUTH SQL on {db_name}")
        gold_res = self._run_sql(sqlite_path, gold_sql)

        print(f"\n--- Execution Results ---")
        print(f"DB: {db_name}")
        show("Predicted", predicted_sql, pred_res)
        show("Ground Truth", gold_sql, gold_res)
        print("-" * 50)

        print(f"[EXEC] Normalizing and comparing results...")
        match = self._normalize_result(pred_res) == self._normalize_result(gold_res)
        print(f"[EXEC] Match result: {match}")

        info = {}
        if return_exec_results:
            info = {
                "predicted_result": pred_res,
                "ground_truth_result": gold_res,
                "db_name": db_name
            }
        return match, info

    except Exception as e:
        # Any failure while executing or comparing counts as a mismatch
        print(f"\n--- Execution Error ---")
        print(f"DB: {db_name}")
        print(f"Error: {e}")
        print("-" * 50)
        failure = {"error": str(e), "db_name": db_name} if return_exec_results else {}
        return False, failure
def _run_sql(self, sqlite_path: str, sql: str) -> List[Tuple[Any, ...]]:
    """
    Execute a single read-only SQL query against a SQLite file.

    Args:
        sqlite_path: Path of the SQLite database file to open.
        sql: Query text; surrounding whitespace and trailing ';' are removed.

    Returns:
        At most ``self.exec_max_rows`` rows, each converted to a tuple.

    Raises:
        ValueError: If the query is empty or contains a write/DDL keyword.
        TimeoutError: If execution exceeds ``self.exec_timeout_ms`` milliseconds.
        sqlite3.Error: For any other SQLite failure.
    """
    import re

    sql = (sql or "").strip().rstrip(";")
    if not sql:
        raise ValueError("Empty SQL")

    # Guard against writes/DDL in evaluation. The keyword must start at a word
    # boundary (so a column such as `last_update` is not mistaken for UPDATE)
    # and may be followed by any whitespace (so "DROP\nTABLE" is still caught).
    unsafe = re.search(
        r"\b(?:insert|update|delete|drop|alter|create|pragma|attach)\s",
        sql,
        re.IGNORECASE,
    )
    if unsafe:
        raise ValueError("Unsafe SQL in evaluation")

    timeout_seconds = self.exec_timeout_ms / 1000.0

    print(f"[DEBUG] Connecting to database: {sqlite_path}")
    conn = sqlite3.connect(sqlite_path, timeout=timeout_seconds)

    # Monotonic clock: elapsed time must not be affected by wall-clock changes
    start_time = time.monotonic()

    def progress_handler():
        if time.monotonic() - start_time > timeout_seconds:
            print(f"[TIMEOUT] Query exceeded {timeout_seconds}s timeout")
            return 1  # Non-zero return aborts the operation
        return 0

    try:
        # Call progress handler every 1000 VM instructions
        conn.set_progress_handler(progress_handler, 1000)

        print(f"[DEBUG] Executing SQL query...")
        cur = conn.cursor()
        cur.execute(sql)

        print(f"[DEBUG] Fetching results (max {self.exec_max_rows} rows)...")
        rows = cur.fetchmany(self.exec_max_rows)

        elapsed = time.monotonic() - start_time
        print(f"[DEBUG] Query completed in {elapsed:.2f}s, returned {len(rows)} rows")

        return [tuple(r) for r in rows]
    except sqlite3.OperationalError as e:
        elapsed = time.monotonic() - start_time
        # SQLite reports an abort requested by the progress handler as "interrupted"
        if "interrupted" in str(e).lower():
            raise TimeoutError(f"SQL query timed out after {elapsed:.2f}s (limit: {timeout_seconds}s)")
        raise
    finally:
        conn.close()
def load_jsonl(file_path):
    """
    Load data from a JSONL file.

    Each non-blank line is parsed as one JSON document. Blank or
    whitespace-only lines (for example a trailing newline at the end of the
    file) are skipped instead of raising a decode error.

    Args:
        file_path: Path of the JSONL file, read as UTF-8.

    Returns:
        List of the parsed objects, in file order.
    """
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                data.append(json.loads(line))
    return data


def get_difficulty_stats(data):
    """
    Calculate difficulty distribution from dataset.

    Items whose 'difficulty' is missing, null or empty are counted under
    'unknown', so the resulting keys never include None.

    Args:
        data: Iterable of dict-like items.

    Returns:
        Plain dict mapping difficulty label -> number of items.
    """
    difficulty_counts = defaultdict(int)
    for item in data:
        difficulty_counts[item.get('difficulty') or 'unknown'] += 1
    return dict(difficulty_counts)


def print_split_stats(split_name, data):
    """
    Print statistics for a single data split.

    Args:
        split_name: Name of the split; printed upper-cased as the heading.
        data: List of items belonging to the split.
    """
    total = len(data)
    difficulty_stats = get_difficulty_stats(data)

    print(f"\n{split_name.upper()}:")
    print(f" Total samples: {total}")

    if difficulty_stats:
        print(f" By difficulty:")
        # Sort difficulties for consistent output; key=str keeps the sort
        # well-defined even if labels are not all strings.
        for difficulty in sorted(difficulty_stats, key=str):
            count = difficulty_stats[difficulty]
            percentage = (count / total * 100) if total > 0 else 0
            print(f" {difficulty}: {count} ({percentage:.1f}%)")
def print_dataset_stats(dataset_name):
    """
    Print per-split and overall statistics for one dataset.

    For each of the train/val/test files returned by ``get_dataset_paths``,
    prints the split statistics (or a notice when the file is missing or
    cannot be loaded), then prints totals accumulated over the splits that
    loaded successfully.

    Args:
        dataset_name: Dataset identifier used to locate the split files.
    """
    print(f"=" * 60)
    print(f"Dataset Statistics: {dataset_name.upper()}")
    print(f"=" * 60)

    split_files = get_dataset_paths(dataset_name)

    # Running totals across all splits
    grand_total = 0
    combined = defaultdict(int)

    for split in ('train', 'val', 'test'):
        split_file = split_files[split]

        if split_file.exists():
            try:
                records = load_jsonl(split_file)
                print_split_stats(split, records)

                grand_total += len(records)
                for level, n in get_difficulty_stats(records).items():
                    combined[level] += n
            except Exception as e:
                print(f"\n{split.upper()}: Error loading file - {e}")
        else:
            print(f"\n{split.upper()}: File not found at {split_file}")

    # Overall statistics
    print(f"\n{'-' * 60}")
    print(f"OVERALL STATISTICS:")
    print(f" Total samples across all splits: {grand_total}")

    if combined:
        print(f" Overall difficulty distribution:")
        for level in sorted(combined):
            n = combined[level]
            share = (n / grand_total * 100) if grand_total > 0 else 0
            print(f" {level}: {n} ({share:.1f}%)")

    print(f"=" * 60)
def download_file(url: str, save_path: str, chunk_size: int = 1024 * 1024, timeout: float = 60.0) -> None:
    """
    Stream a URL to ``save_path`` with a progress bar.

    The payload is written to ``<save_path>.part`` and only moved to
    ``save_path`` once the whole body has been received, so a failed or
    interrupted download never leaves a truncated file that callers would
    mistake for an already-downloaded archive.

    Args:
        url: Address to download.
        save_path: Destination file path.
        chunk_size: Number of bytes requested per read from the response.
        timeout: Value passed to ``requests.get`` as its ``timeout``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    partial_path = f"{save_path}.part"
    try:
        # The context manager releases the connection even when an error occurs
        with requests.get(url, stream=True, timeout=timeout) as r:
            # Fail early instead of saving an HTML error page as the archive
            r.raise_for_status()
            total_size = int(r.headers.get('content-length', 0))

            with open(partial_path, "wb") as f, tqdm(
                desc=save_path,
                total=total_size,
                unit='B',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
                        bar.update(len(chunk))

        os.replace(partial_path, save_path)
    finally:
        # After a successful replace the partial file no longer exists;
        # after a failure this removes the incomplete data.
        if os.path.exists(partial_path):
            os.remove(partial_path)
def _download_gdrive_dataset(save_dir: str, label: str, dir_name: str, file_id: str) -> None:
    """
    Download ``<dir_name>.zip`` from Google Drive into ``save_dir`` and unzip it.

    Both steps are skipped when their output already exists.

    Args:
        save_dir: Directory holding the archive and the extracted folder.
        label: Human-readable dataset name used in progress messages.
        dir_name: Base name of the archive and of the extracted directory.
        file_id: Google Drive file id of the archive.

    Raises:
        RuntimeError: If gdown reports a failed download by returning None.
    """
    print(Fore.CYAN + f"Downloading and unzipping {label}..." + Style.RESET_ALL)
    # Create the target directory up front, as download_bird does; otherwise
    # the download fails on a fresh checkout where save_dir does not exist.
    Path(save_dir).mkdir(parents=True, exist_ok=True)

    # Download
    zip_path = f"{save_dir}/{dir_name}.zip"
    if os.path.exists(zip_path):
        print(f"{label} already exists at {zip_path}")
    else:
        print(f"Downloading {label} to {zip_path}")
        url = f"https://drive.google.com/uc?export=download&id={file_id}"
        if gdown.download(url, zip_path, quiet=False) is None:
            raise RuntimeError(f"Failed to download {label} from {url}")

    # Unzip
    extract_to_dir = os.path.join(save_dir, dir_name)
    if os.path.exists(extract_to_dir):
        print(f"{label} already unzipped to {extract_to_dir}")
    else:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(save_dir)
        print(f"Unzipped to {extract_to_dir}")


def download_cosql(save_dir: str) -> None:
    """Download the CoSQL archive into ``save_dir`` and unzip it."""
    _download_gdrive_dataset(
        save_dir,
        label="CoSQL",
        dir_name="cosql",
        file_id="1QQPkUVUN2Leu_ykVchae0FzURZGHPwdJ",
    )


def download_spider(save_dir: str) -> None:
    """Download the Spider archive into ``save_dir`` and unzip it."""
    _download_gdrive_dataset(
        save_dir,
        label="Spider",
        dir_name="spider",
        file_id="1nkYLYDSGKICePTQnnPl9TgLfdmHI9ctM",
    )
def plot_offline_training_progress(save_path):
    """
    Generate training progress plots for offline mode showing pre/post train accuracy per step.

    Reads ``pre_train_post_train_results.json`` (required) plus
    ``train_results.json`` and ``final_results.json`` (both optional) from
    ``save_path``. Writes ``plots/offline_training_progress.png`` and
    ``plots/offline_training_data.csv`` below the same directory. When the
    required file is missing or holds no steps, a warning is printed and
    nothing is written.

    Args:
        save_path: Path where results are saved and where plot will be saved
    """
    from itertools import accumulate

    # Load pre_train_post_train_results
    pre_post_path = os.path.join(save_path, 'pre_train_post_train_results.json')
    if not os.path.exists(pre_post_path):
        print(f"Warning: pre_train_post_train_results.json not found at {pre_post_path}. Skipping offline plot generation.")
        return

    with open(pre_post_path, 'r', encoding='utf-8') as f:
        step_results = json.load(f)

    if not step_results:
        print("Warning: Empty pre_train_post_train_results. Skipping offline plot generation.")
        return

    # Load validation results from train_results.json
    train_results_path = os.path.join(save_path, 'train_results.json')
    val_steps = []
    val_accuracies = []
    val_by_difficulty = {}  # Dictionary to track accuracy by difficulty over steps

    if os.path.exists(train_results_path):
        with open(train_results_path, 'r', encoding='utf-8') as f:
            train_data = json.load(f)
        for result in train_data.get('results') or []:
            val_result = result.get('val_result')
            if not val_result:
                continue
            val_steps.append(result['step'])
            val_accuracies.append(val_result['accuracy'])

            # Extract difficulty-level accuracies
            for difficulty, diff_data in val_result.get('by_difficulty', {}).items():
                series = val_by_difficulty.setdefault(difficulty, {'steps': [], 'accuracies': []})
                series['steps'].append(result['step'])
                series['accuracies'].append(diff_data['accuracy'])

    # Load final results for initial and final test accuracy
    final_results_path = os.path.join(save_path, 'final_results.json')
    initial_test_acc = None
    final_test_acc = None

    if os.path.exists(final_results_path):
        with open(final_results_path, 'r', encoding='utf-8') as f:
            final_data = json.load(f)
        if 'initial_test_results' in final_data:
            initial_test_acc = final_data['initial_test_results']['accuracy']
        if 'final_test_results' in final_data:
            final_test_acc = final_data['final_test_results']['accuracy']

    # Extract data
    steps = [r['step'] for r in step_results]
    epochs = [r['epoch'] for r in step_results]
    pre_train_correct = [r['pre_train_result']['is_correct'] for r in step_results]
    post_train_correct = [r['post_train_result']['is_correct'] for r in step_results]
    playbook_tokens = [r['post_train_result']['playbook_num_tokens'] for r in step_results]
    playbook_length = [r['post_train_result']['playbook_length'] for r in step_results]
    step_times = [r.get('step_time_seconds', 0) for r in step_results]

    # Cumulative accuracies from a single running sum (linear in the number
    # of steps, instead of re-summing a growing slice at every step)
    cumulative_pre = [
        total / count
        for count, total in enumerate(accumulate(int(x) for x in pre_train_correct), start=1)
    ]
    cumulative_post = [
        total / count
        for count, total in enumerate(accumulate(int(x) for x in post_train_correct), start=1)
    ]

    # Calculate per-step improvement (1 if improved, 0 if same, -1 if worse)
    improvement = [int(post) - int(pre) for pre, post in zip(pre_train_correct, post_train_correct)]

    # step_results is non-empty here, so the division is safe
    avg_step_time = sum(step_times) / len(step_times)

    # Create plots subfolder
    plots_dir = os.path.join(save_path, 'plots')
    os.makedirs(plots_dir, exist_ok=True)
    plot_path = os.path.join(plots_dir, 'offline_training_progress.png')

    # Create figure with multiple subplots
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))
    try:
        # Plot 1: Cumulative Accuracy - Pre-train vs Post-train
        ax1.plot(steps, cumulative_pre, 'r-o', linewidth=2, markersize=6, label='Pre-train (cumulative)', alpha=0.7)
        ax1.plot(steps, cumulative_post, 'g-s', linewidth=2, markersize=6, label='Post-train (cumulative)', alpha=0.7)

        # Add validation accuracy line if available
        if val_steps and val_accuracies:
            ax1.plot(val_steps, val_accuracies, 'b-^', linewidth=2, markersize=8, label='Validation Accuracy', alpha=0.8, zorder=5)

        # Add difficulty-level validation accuracy lines
        difficulty_colors = {'simple': 'lightgreen', 'moderate': 'orange', 'challenging': 'darkred'}
        difficulty_markers = {'simple': 'v', 'moderate': 'D', 'challenging': 'X'}
        for difficulty in sorted(val_by_difficulty.keys()):
            diff_data = val_by_difficulty[difficulty]
            color = difficulty_colors.get(difficulty, 'gray')
            marker = difficulty_markers.get(difficulty, 'o')
            ax1.plot(diff_data['steps'], diff_data['accuracies'],
                     linestyle='--', linewidth=1.5, marker=marker, markersize=6,
                     color=color, label=f'Val: {difficulty}', alpha=0.7, zorder=4)

        # Add initial and final test accuracy if available
        if initial_test_acc is not None:
            ax1.axhline(y=initial_test_acc, color='cyan', linestyle='--', linewidth=1.5, label=f'Initial Test Acc: {initial_test_acc:.3f}', alpha=0.6)
        if final_test_acc is not None:
            ax1.axhline(y=final_test_acc, color='purple', linestyle='--', linewidth=1.5, label=f'Final Test Acc: {final_test_acc:.3f}', alpha=0.6)

        ax1.set_xlabel('Training Step', fontsize=12)
        ax1.set_ylabel('Cumulative Accuracy', fontsize=12)
        ax1.set_title('Offline Mode: Training Progress (Cumulative)', fontsize=14, fontweight='bold')
        ax1.grid(True, alpha=0.3)
        ax1.legend(fontsize=8, loc='best')
        ax1.set_ylim([0, 1.0])

        # Plot 2: Per-Step Correctness (1 = improved, 0 = same, -1 = worse)
        colors = ['green' if x > 0 else 'gray' if x == 0 else 'red' for x in improvement]
        ax2.bar(steps, improvement, color=colors, alpha=0.6, edgecolor='black', width=0.8)
        ax2.axhline(y=0, color='black', linestyle='-', linewidth=1)
        ax2.set_xlabel('Training Step', fontsize=12)
        ax2.set_ylabel('Improvement (Post - Pre)', fontsize=12)
        ax2.set_title('Offline Mode: Per-Step Improvement', fontsize=14, fontweight='bold')
        ax2.set_yticks([-1, 0, 1])
        ax2.set_yticklabels(['Worse', 'Same', 'Better'])
        ax2.grid(True, alpha=0.3, axis='y')

        # Add summary statistics
        improved = sum(1 for x in improvement if x > 0)
        same = sum(1 for x in improvement if x == 0)
        worse = sum(1 for x in improvement if x < 0)
        ax2.text(0.02, 0.98, f'Improved: {improved}\nSame: {same}\nWorse: {worse}',
                 transform=ax2.transAxes, fontsize=10, verticalalignment='top',
                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

        # Plot 3: Playbook Token Growth
        ax3.plot(steps, playbook_tokens, 'purple', marker='D', linewidth=2, markersize=6, label='Playbook Tokens')
        ax3.set_xlabel('Training Step', fontsize=12)
        ax3.set_ylabel('Number of Tokens', fontsize=12)
        ax3.set_title('Offline Mode: Playbook Growth (Tokens)', fontsize=14, fontweight='bold')
        ax3.grid(True, alpha=0.3)
        ax3.legend(fontsize=10)

        # Plot 4: Step Time Distribution
        ax4.plot(steps, step_times, 'orange', marker='o', linewidth=2, markersize=6, label='Step Time')
        ax4.axhline(y=avg_step_time, color='red', linestyle='--', linewidth=2,
                    label=f'Avg: {avg_step_time:.1f}s')
        ax4.set_xlabel('Training Step', fontsize=12)
        ax4.set_ylabel('Time (seconds)', fontsize=12)
        ax4.set_title('Offline Mode: Training Time per Step', fontsize=14, fontweight='bold')
        ax4.grid(True, alpha=0.3)
        ax4.legend(fontsize=10)

        plt.tight_layout()

        # Save plot
        plt.savefig(plot_path, dpi=300, bbox_inches='tight')
        print(f"\nOffline training progress plot saved to: {plot_path}")
    finally:
        # Release the figure even when plotting or saving fails
        plt.close(fig)

    # Save data as CSV
    csv_path = os.path.join(plots_dir, 'offline_training_data.csv')

    # Map steps to validation accuracies for easy lookup
    val_acc_by_step = dict(zip(val_steps, val_accuracies))

    # Map steps to difficulty-level accuracies
    all_difficulties = sorted(val_by_difficulty.keys())
    difficulty_acc_by_step = {
        difficulty: dict(zip(val_by_difficulty[difficulty]['steps'],
                             val_by_difficulty[difficulty]['accuracies']))
        for difficulty in all_difficulties
    }

    # Build CSV header with difficulty columns
    header_parts = ["step", "epoch", "pre_train_correct", "post_train_correct",
                    "cumulative_pre_acc", "cumulative_post_acc", "improvement",
                    "playbook_tokens", "playbook_length", "step_time_seconds", "val_accuracy"]
    header_parts.extend(f"val_acc_{difficulty}" for difficulty in all_difficulties)

    with open(csv_path, 'w', encoding='utf-8') as f:
        f.write(",".join(header_parts) + "\n")
        for i, step in enumerate(steps):
            # Validation columns stay empty for steps without a validation run
            val_acc_str = f"{val_acc_by_step[step]:.4f}" if step in val_acc_by_step else ""

            row_parts = [
                str(step), str(epochs[i]),
                str(int(pre_train_correct[i])), str(int(post_train_correct[i])),
                f"{cumulative_pre[i]:.4f}", f"{cumulative_post[i]:.4f}",
                str(improvement[i]),
                str(playbook_tokens[i]), str(playbook_length[i]),
                f"{step_times[i]:.2f}", val_acc_str
            ]

            # Add difficulty-level accuracies
            for difficulty in all_difficulties:
                per_step = difficulty_acc_by_step[difficulty]
                row_parts.append(f"{per_step[step]:.4f}" if step in per_step else "")

            f.write(",".join(row_parts) + "\n")
    print(f"Offline training data saved to: {csv_path}")
Skipping plot generation.") + return + + window_results = test_results['window_results'] + + # Extract window data + window_numbers = [w['window'] for w in window_results] + window_accuracies = [w['window_accuracy'] for w in window_results] + window_end_indices = [w['end_idx'] for w in window_results] + + # Get initial and final test accuracy from final_results.json if available + final_results_path = os.path.join(save_path, 'final_results.json') + initial_test_accuracy = None + final_test_accuracy = None + + if os.path.exists(final_results_path): + with open(final_results_path, 'r') as f: + final_results = json.load(f) + if 'initial_test_results' in final_results: + initial_test_accuracy = final_results['initial_test_results']['accuracy'] + if 'online_test_results' in final_results: + final_test_accuracy = final_results['online_test_results']['accuracy'] + + # Fallback to window data if final_results.json not available + if initial_test_accuracy is None: + initial_test_accuracy = window_accuracies[0] if window_accuracies else 0.0 + if final_test_accuracy is None: + final_test_accuracy = test_results['accuracy'] + + # Create figure with multiple subplots + _, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10)) + + # Plot 1: Accuracy by Window + ax1.plot(window_numbers, window_accuracies, 'b-o', linewidth=2, markersize=8, label='Window Accuracy') + ax1.axhline(y=test_results['accuracy'], color='r', linestyle='--', linewidth=2, label=f'Overall Accuracy: {test_results["accuracy"]:.3f}') + + # Add initial test accuracy dot (before training, shown at window 0.5 to be before first window) + ax1.plot(0.5, initial_test_accuracy, 'go', markersize=14, label=f'Initial Test Accuracy: {initial_test_accuracy:.3f}', zorder=5) + # Add final test accuracy dot (after all training, shown at the last window) + ax1.plot(window_numbers[-1] + 0.5, final_test_accuracy, 'ro', markersize=14, label=f'Final Test Accuracy: {final_test_accuracy:.3f}', zorder=5) + + ax1.set_xlabel('Window Number', 
fontsize=12) + ax1.set_ylabel('Accuracy', fontsize=12) + ax1.set_title('Online Mode: Accuracy by Training Window', fontsize=14, fontweight='bold') + ax1.grid(True, alpha=0.3) + ax1.legend(fontsize=10) + ax1.set_ylim([0, 1.0]) + + # Add value labels on points + for x, y in zip(window_numbers, window_accuracies): + ax1.annotate(f'{y:.3f}', (x, y), textcoords="offset points", xytext=(0,10), ha='center', fontsize=8) + + # Plot 2: Accuracy by Sample Index (cumulative) + ax2.plot(window_end_indices, window_accuracies, 'g-s', linewidth=2, markersize=8, label='Accuracy') + ax2.axhline(y=test_results['accuracy'], color='r', linestyle='--', linewidth=2, label=f'Overall Accuracy: {test_results["accuracy"]:.3f}') + + # Add initial test accuracy dot (at index 0, before training starts) + ax2.plot(0, initial_test_accuracy, 'go', markersize=14, label=f'Initial Test Accuracy: {initial_test_accuracy:.3f}', zorder=5) + # Add final test accuracy dot (at the end of all samples) + ax2.plot(window_end_indices[-1], final_test_accuracy, 'ro', markersize=14, label=f'Final Test Accuracy: {final_test_accuracy:.3f}', zorder=5) + + ax2.set_xlabel('Sample Index (End of Window)', fontsize=12) + ax2.set_ylabel('Accuracy', fontsize=12) + ax2.set_title('Online Mode: Accuracy by Sample Progress', fontsize=14, fontweight='bold') + ax2.grid(True, alpha=0.3) + ax2.legend(fontsize=10) + ax2.set_ylim([0, 1.0]) + + plt.tight_layout() + + # Create plots subfolder + plots_dir = os.path.join(save_path, 'plots') + os.makedirs(plots_dir, exist_ok=True) + + # Save plot + plot_path = os.path.join(plots_dir, 'online_performance_plot.png') + plt.savefig(plot_path, dpi=300, bbox_inches='tight') + print(f"\nPerformance plot saved to: {plot_path}") + plt.close() + + # Also save the data as CSV for external plotting + csv_path = os.path.join(plots_dir, 'online_performance_data.csv') + with open(csv_path, 'w') as f: + f.write("window,window_accuracy,start_idx,end_idx,samples_in_window\n") + for w in window_results: + 
def plot_training_progress(save_path, mode='online'):
    """
    Draw the four-panel training-progress figure for an online run.

    Panels: pre/post/cumulative accuracy per window, per-window improvement
    (post minus pre), playbook token count, and playbook character length.
    Input is ``train_results.json`` in ``save_path``; initial and final test
    accuracy come from ``final_results.json`` when present, otherwise from
    ``test_results.json``. Output is ``plots/training_progress_plot.png`` and
    ``plots/training_progress_data.csv`` under ``save_path``.

    Args:
        save_path: Run directory to read results from and write plots to.
        mode: Run mode; anything other than 'online' skips plotting.
    """
    if mode != 'online':
        print(f"Skipping training plot generation - only available for online mode (current mode: {mode})")
        return

    results_file = os.path.join(save_path, 'train_results.json')
    if not os.path.exists(results_file):
        print(f"Warning: Training results file not found at {results_file}. Skipping training plot generation.")
        return

    with open(results_file, 'r') as fh:
        payload = json.load(fh)

    if 'train_results' not in payload:
        print("Warning: No train results found in train_results.json. Skipping training plot generation.")
        return

    records = payload['train_results']
    if not records:
        print("Warning: Empty train results. Skipping training plot generation.")
        return

    # First choice for the start/end accuracies: final_results.json
    start_acc = None
    end_acc = None
    summary_file = os.path.join(save_path, 'final_results.json')
    if os.path.exists(summary_file):
        with open(summary_file, 'r') as fh:
            summary = json.load(fh)
        if 'initial_test_results' in summary:
            start_acc = summary['initial_test_results']['accuracy']
        if 'online_test_results' in summary:
            end_acc = summary['online_test_results']['accuracy']

    # Second choice: derive whatever is still missing from test_results.json
    if start_acc is None or end_acc is None:
        fallback_file = os.path.join(save_path, 'test_results.json')
        if os.path.exists(fallback_file):
            with open(fallback_file, 'r') as fh:
                fallback = json.load(fh)
            if 'test_results' in fallback and 'window_results' in fallback['test_results']:
                per_window = fallback['test_results']['window_results']
                if per_window:
                    if start_acc is None:
                        start_acc = per_window[0]['window_accuracy']
                    if end_acc is None:
                        end_acc = fallback['test_results']['accuracy']

    # Column-wise series, one entry per training window
    windows = []
    pre_acc = []
    post_acc = []
    cumulative_acc = []
    token_counts = []
    char_counts = []
    for rec in records:
        windows.append(rec['window'])
        pre_acc.append(rec['train_result']['pre_train_accuracy'])
        post_acc.append(rec['train_result']['post_train_accuracy'])
        cumulative_acc.append(rec['cumulative_test_accuracy'])
        token_counts.append(rec['playbook_num_tokens'])
        char_counts.append(rec['playbook_length'])

    gains = [after - before for before, after in zip(pre_acc, post_acc)]

    _, ((acc_ax, gain_ax), (tok_ax, len_ax)) = plt.subplots(2, 2, figsize=(16, 12))

    # Panel 1: pre-train vs post-train accuracy
    acc_ax.plot(windows, pre_acc, 'r-o', linewidth=2, markersize=8, label='Pre-train Accuracy')
    acc_ax.plot(windows, post_acc, 'g-s', linewidth=2, markersize=8, label='Post-train Accuracy')
    acc_ax.plot(windows, cumulative_acc, 'b--^', linewidth=2, markersize=8, label='Cumulative Test Accuracy')

    # Start/end markers are drawn only when both values are known
    if not (start_acc is None or end_acc is None):
        acc_ax.plot(0.5, start_acc, 'go', markersize=14, label=f'Initial Test Accuracy: {start_acc:.3f}', zorder=5)
        acc_ax.plot(windows[-1] + 0.5, end_acc, 'ro', markersize=14, label=f'Final Test Accuracy: {end_acc:.3f}', zorder=5)

    acc_ax.set_xlabel('Window Number', fontsize=12)
    acc_ax.set_ylabel('Accuracy', fontsize=12)
    acc_ax.set_title('Training Progress: Pre-train vs Post-train Accuracy', fontsize=14, fontweight='bold')
    acc_ax.grid(True, alpha=0.3)
    acc_ax.legend(fontsize=10)
    acc_ax.set_ylim([0, 1.0])

    # Panel 2: improvement bars, green for gains and red for losses
    bar_colors = []
    for delta in gains:
        bar_colors.append('green' if delta >= 0 else 'red')
    gain_ax.bar(windows, gains, color=bar_colors, alpha=0.6, edgecolor='black')
    gain_ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
    gain_ax.set_xlabel('Window Number', fontsize=12)
    gain_ax.set_ylabel('Accuracy Improvement', fontsize=12)
    gain_ax.set_title('Training Improvement per Window (Post - Pre)', fontsize=14, fontweight='bold')
    gain_ax.grid(True, alpha=0.3, axis='y')

    for win, delta in zip(windows, gains):
        anchor = 'bottom' if delta >= 0 else 'top'
        gain_ax.text(win, delta, f'{delta:+.3f}', ha='center', va=anchor, fontsize=9)

    # Panel 3: playbook size in tokens
    tok_ax.plot(windows, token_counts, 'purple', marker='D', linewidth=2, markersize=8, label='Playbook Tokens')
    tok_ax.set_xlabel('Window Number', fontsize=12)
    tok_ax.set_ylabel('Number of Tokens', fontsize=12)
    tok_ax.set_title('Playbook Growth: Token Count', fontsize=14, fontweight='bold')
    tok_ax.grid(True, alpha=0.3)
    tok_ax.legend(fontsize=10)

    for win, count in zip(windows, token_counts):
        tok_ax.annotate(f'{count}', (win, count), textcoords="offset points", xytext=(0, 10), ha='center', fontsize=8)

    # Panel 4: playbook size in characters
    len_ax.plot(windows, char_counts, 'orange', marker='D', linewidth=2, markersize=8, label='Playbook Length (chars)')
    len_ax.set_xlabel('Window Number', fontsize=12)
    len_ax.set_ylabel('Character Count', fontsize=12)
    len_ax.set_title('Playbook Growth: Character Length', fontsize=14, fontweight='bold')
    len_ax.grid(True, alpha=0.3)
    len_ax.legend(fontsize=10)

    for win, count in zip(windows, char_counts):
        len_ax.annotate(f'{count}', (win, count), textcoords="offset points", xytext=(0, 10), ha='center', fontsize=8)

    plt.tight_layout()

    out_dir = os.path.join(save_path, 'plots')
    os.makedirs(out_dir, exist_ok=True)

    figure_file = os.path.join(out_dir, 'training_progress_plot.png')
    plt.savefig(figure_file, dpi=300, bbox_inches='tight')
    print(f"\nTraining progress plot saved to: {figure_file}")
    plt.close()

    # Same series as CSV, for plotting outside this script
    table_file = os.path.join(out_dir, 'training_progress_data.csv')
    with open(table_file, 'w') as fh:
        fh.write("window,pre_train_accuracy,post_train_accuracy,cumulative_test_accuracy,improvement,playbook_tokens,playbook_length\n")
        for fields in zip(windows, pre_acc, post_acc, cumulative_acc, gains, token_counts, char_counts):
            fh.write(",".join(f"{value}" for value in fields) + "\n")
    print(f"Training progress data saved to: {table_file}")
def main():
    """
    Command-line entry point: generate the plots for one ACE run directory.

    ``--mode online`` produces the test-performance and training-progress
    plots, ``--mode offline`` produces the offline training-progress plot,
    and ``--mode eval_only`` only reports that plotting is not supported.
    Exits with status 1 when the run directory is missing, or when
    ``test_results.json`` is missing in online mode.
    """
    parser = argparse.ArgumentParser(
        description='Generate performance plots for ACE online and offline training runs',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Plot results from a specific run directory
  python stream-bench/plot.py --run_dir results/ace_run_20260113_170913_finer_online

  # Specify mode explicitly (default is 'online')
  python stream-bench/plot.py --run_dir results/ace_run_20260113_170913_finer_online --mode online

  # Plot an offline training run
  python stream-bench/plot.py --run_dir <offline_run_dir> --mode offline
        """
    )

    parser.add_argument(
        '--run_dir',
        type=str,
        required=True,
        help='Path to the ACE run directory containing the saved result files '
             '(test_results.json is required for online mode)'
    )

    # The help text used to claim that only online mode can be plotted, which
    # contradicted the offline branch below.
    parser.add_argument(
        '--mode',
        type=str,
        default='online',
        choices=['online', 'offline', 'eval_only'],
        help='Run mode (default: online). Plots are generated for online and '
             'offline runs; eval_only produces no plots.'
    )

    args = parser.parse_args()

    # Validate run directory exists
    if not os.path.exists(args.run_dir):
        print(f"Error: Run directory not found: {args.run_dir}")
        sys.exit(1)

    # test_results.json is only required for online mode
    if args.mode == 'online':
        test_results_path = os.path.join(args.run_dir, 'test_results.json')
        if not os.path.exists(test_results_path):
            print(f"Error: test_results.json not found in {args.run_dir}")
            print(f"Expected path: {test_results_path}")
            sys.exit(1)

    banner = '=' * 60
    print(banner)
    print("GENERATING PLOTS FOR ACE RUN")
    print(banner)
    print(f"Run directory: {args.run_dir}")
    print(f"Mode: {args.mode}")
    print(f"{banner}\n")

    # Generate plots based on mode
    if args.mode == 'online':
        print("Generating test performance plots...")
        plot_online_performance(args.run_dir, args.mode)

        print("\nGenerating training progress plots...")
        plot_training_progress(args.run_dir, args.mode)
    elif args.mode == 'offline':
        print("Generating offline training progress plots...")
        plot_offline_training_progress(args.run_dir)
    else:
        print(f"Plot generation not supported for mode: {args.mode}")

    print(f"\n{banner}")
    print("PLOTTING COMPLETE")
    print(f"{banner}\n")
def load_streambench_bird(split: str):
    """
    Load one split of the StreamBench "bird" subset via Hugging Face ``datasets``.

    ``datasets`` is imported inside the function, so the import only happens
    when the loader is actually called.

    Args:
        split: Split name passed straight to ``load_dataset``
            (train / validation / test).

    Returns:
        Whatever ``load_dataset`` returns for that split.
    """
    from datasets import load_dataset

    return load_dataset("appier-ai-research/StreamBench", "bird", split=split)
def read_tables_json(tables_json_path: str) -> Dict[str, Dict[str, Any]]:
    """
    Load a Spider/BIRD-style ``tables.json`` into a per-database schema map.

    Each entry of the file is read for ``db_id``, ``table_names_original``,
    ``column_names_original`` (``[table_idx, column_name]`` pairs, with
    ``table_idx == -1`` for the ``*`` pseudo-column), ``column_types``,
    ``primary_keys`` and ``foreign_keys``.

    Args:
        tables_json_path: Path to the tables.json file (read as UTF-8).

    Returns:
        Mapping ``db_id -> {"db_id", "tables", "primary_keys",
        "foreign_keys"}``. ``tables`` is a list of
        ``{"name": ..., "columns": [{"name": ..., "type": ...}]}`` in file
        order; the key lists are passed through as found. Entries without a
        ``db_id`` are skipped.
    """
    with open(tables_json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    dbid_to_schema: Dict[str, Dict[str, Any]] = {}
    for db in data:
        db_id = db.get("db_id")
        if not db_id:
            continue

        table_names = db.get("table_names_original", [])
        column_names = db.get("column_names_original", [])
        column_types = db.get("column_types", [])

        tables: List[Dict[str, Any]] = [{"name": t, "columns": []} for t in table_names]

        for i, (tbl_idx, col_name) in enumerate(column_names):
            # Skip "*" (-1) and any index that does not name a table: other
            # negative values would silently wrap around to the wrong table,
            # and too-large ones would abort the whole load with IndexError.
            if not 0 <= tbl_idx < len(tables):
                continue
            # column_types is parallel to column_names; tolerate a short list.
            col_type = column_types[i] if i < len(column_types) else None
            tables[tbl_idx]["columns"].append({"name": col_name, "type": col_type})

        dbid_to_schema[db_id] = {
            "db_id": db_id,
            "tables": tables,
            "primary_keys": db.get("primary_keys", []),
            "foreign_keys": db.get("foreign_keys", []),
        }

    return dbid_to_schema
def find_sqlite_path(bird_root: str, db_id: str) -> Optional[str]:
    """
    Locate the SQLite file for ``db_id`` anywhere below ``bird_root``.

    Candidates are tried in this order and the first pattern with any match
    wins:

    1. ``<root>/**/<db_id>/<db_id>.sqlite`` (exact name inside its folder)
    2. ``<root>/**/<db_id>/*.sqlite``       (any database in that folder)
    3. ``<root>/**/<db_id>.sqlite``         (flat layout)

    Within one pattern the shortest path (ties broken alphabetically) is
    returned so the choice is deterministic.

    Args:
        bird_root: Directory to search recursively.
        db_id: Database identifier; matched literally.

    Returns:
        Path of the chosen ``.sqlite`` file, or ``None`` when nothing matches
        or ``db_id`` is empty.
    """
    if not db_id:
        return None

    # Escape both parts so characters such as "[" or "*" in a directory or
    # database name are matched literally instead of being read as wildcards.
    root = glob.escape(bird_root)
    name = glob.escape(db_id)

    # The exact-name pattern goes first: with the wildcard first, a shorter
    # sibling file in the same folder would win over <db_id>.sqlite.
    patterns = [
        os.path.join(root, "**", name, f"{name}.sqlite"),
        os.path.join(root, "**", name, "*.sqlite"),
        os.path.join(root, "**", f"{name}.sqlite"),
    ]
    for pat in patterns:
        hits = glob.glob(pat, recursive=True)
        if hits:
            # prefer shortest path / deterministic choice
            return min(hits, key=lambda p: (len(p), p))
    return None
def sqlite_introspect_schema(sqlite_path: str) -> Dict[str, Any]:
    """
    Build a minimal schema object by introspecting a SQLite database.

    Columns are numbered globally in the order they are discovered (table by
    table, column by column), roughly emulating the column indices used by
    tables.json; primary and foreign keys are expressed with those indices.

    Args:
        sqlite_path: Path to the SQLite database file.

    Returns:
        ``{"sqlite_path", "tables", "primary_keys", "foreign_keys"}`` where
        ``tables`` is a list of ``{"name", "columns": [{"name", "type"}]}``,
        ``primary_keys`` is a list of global column indices and
        ``foreign_keys`` is a list of ``[src_idx, dst_idx]`` pairs. Foreign
        keys whose endpoints cannot be resolved to a known column are
        dropped (best effort).
    """
    def quote_ident(name: str) -> str:
        # Double-quoted identifier with embedded quotes doubled, so table
        # names containing quote characters cannot break the PRAGMA text.
        return '"' + name.replace('"', '""') + '"'

    tables: List[Dict[str, Any]] = []
    foreign_keys: List[List[int]] = []   # [[src_col_global, dst_col_global], ...]
    primary_keys: List[int] = []         # global column indices
    global_col_index: Dict[Tuple[str, str], int] = {}
    next_index = 0

    conn = sqlite3.connect(sqlite_path)
    try:
        conn.row_factory = sqlite3.Row
        cur = conn.cursor()

        cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';")
        table_names = [r["name"] for r in cur.fetchall()]

        for t in table_names:
            cur.execute(f"PRAGMA table_info({quote_ident(t)});")
            cols = []
            for row in cur.fetchall():
                col_name = row["name"]
                cols.append({"name": col_name, "type": row["type"]})

                global_col_index[(t, col_name)] = next_index
                # "pk" is 0 for non-key columns and the 1-based position
                # inside the key otherwise, so a composite key yields 1, 2,
                # ...; testing "== 1" would keep only its first column.
                if int(row["pk"]) > 0:
                    primary_keys.append(next_index)
                next_index += 1

            tables.append({"name": t, "columns": cols})

        # Foreign keys (best-effort mapping to global column indices)
        for t in table_names:
            cur.execute(f"PRAGMA foreign_key_list({quote_ident(t)});")
            for fk in cur.fetchall():
                src = global_col_index.get((t, fk["from"]))
                dst = global_col_index.get((fk["table"], fk["to"]))
                if src is not None and dst is not None:
                    foreign_keys.append([src, dst])
    finally:
        # Release the file handle even when a query fails.
        conn.close()

    return {
        "sqlite_path": sqlite_path,
        "tables": tables,
        "primary_keys": primary_keys,
        "foreign_keys": foreign_keys,
    }
def coerce_row(example: Dict[str, Any]) -> Dict[str, Any]:
    """
    Project a StreamBench BIRD example onto the fields written to the output.

    Reads ``question_id``, ``question``, ``SQL`` (upper-case in the source
    data), ``difficulty`` and ``db_id``; a missing field becomes ``None``.
    Any other field of the example (e.g. ``evidence``) is not copied.

    Args:
        example: One dataset record.

    Returns:
        Dict with keys ``question_id``, ``question``, ``sql``,
        ``difficulty`` and ``db_name``, in that order.
    """
    # (output key, source key) pairs; the order fixes the output key order.
    field_map = (
        ("question_id", "question_id"),
        ("question", "question"),
        ("sql", "SQL"),
        ("difficulty", "difficulty"),
        ("db_name", "db_id"),
    )
    return {target: example.get(source) for target, source in field_map}
def main():
    """
    Preprocess the StreamBench BIRD subset into a JSONL file with schemas.

    For ``--split train`` / ``validation`` the Hugging Face ``train`` split
    is shuffled with ``--seed`` and cut at ``--train_ratio``; ``--split test``
    uses the Hugging Face split as is. Each output row carries the fields
    from ``coerce_row`` plus ``db_schema``, taken from ``--tables_json`` when
    the database is listed there and otherwise introspected from the SQLite
    file found under ``--bird_root``. Rows whose schema cannot be resolved
    get an ``{"error": "schema_not_found"}`` placeholder and are counted in
    a final warning.

    Raises:
        FileNotFoundError: ``--tables_json`` was given but does not exist.
        SystemExit: invalid arguments (including ``--train_ratio`` outside
            ``[0, 1]``).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--split", default="test", choices=["train", "validation", "test"])
    ap.add_argument("--out", default="streambench_bird_test.jsonl")
    ap.add_argument("--tables_json", default="", help="Path to BIRD/Spider-style tables.json (recommended).")
    ap.add_argument("--bird_root", default="", help="Root dir where BIRD sqlite databases live (fallback if tables.json missing).")
    ap.add_argument("--schema_format", default="json", choices=["json", "string"],
                    help="Store db_schema as structured JSON (json) or as a compact string (string).")
    ap.add_argument("--train_ratio", type=float, default=0.8,
                    help="Ratio of data to use for training when splitting train/val (default: 0.8)")
    ap.add_argument("--seed", type=int, default=42,
                    help="Random seed for train/val split (default: 42)")
    args = ap.parse_args()

    # A ratio outside [0, 1] would turn into a negative or oversized slice
    # index and silently produce a nonsensical split.
    if not 0.0 <= args.train_ratio <= 1.0:
        ap.error(f"--train_ratio must be between 0 and 1 (got {args.train_ratio})")

    # Note: HuggingFace train and validation splits are identical, so we create our own split
    if args.split in ["train", "validation"]:
        print(f"Loading HuggingFace 'train' split for custom {args.split} split...")
        ds_list = list(load_streambench_bird("train"))
        total_samples = len(ds_list)

        # A private generator gives the same permutation as seeding the
        # module-level one, without disturbing global random state.
        import random
        random.Random(args.seed).shuffle(ds_list)

        split_idx = int(total_samples * args.train_ratio)

        if args.split == "train":
            ds = ds_list[:split_idx]
            print(f"Created train split: {len(ds)} samples ({args.train_ratio*100:.0f}% of {total_samples})")
        else:  # validation
            ds = ds_list[split_idx:]
            print(f"Created validation split: {len(ds)} samples ({(1-args.train_ratio)*100:.0f}% of {total_samples})")
    else:
        # For test split, use the original HF split
        ds = load_streambench_bird(args.split)
        print(f"Loaded test split: {len(ds)} samples")

    # Load schema map from tables.json if provided
    dbid_to_schema = {}
    if args.tables_json:
        if not os.path.exists(args.tables_json):
            raise FileNotFoundError(f"--tables_json not found: {args.tables_json}")
        dbid_to_schema = read_tables_json(args.tables_json)

    # db_id -> resolved schema (or None). Many rows share one database, so
    # the recursive file search and SQLite introspection run once per
    # database instead of once per row.
    schema_cache = {}

    n_missing_schema = 0
    with open(args.out, "w", encoding="utf-8") as f:
        for ex in ds:
            row = coerce_row(ex)
            db_id = row["db_name"]

            if db_id not in schema_cache:
                resolved = dbid_to_schema.get(db_id)
                # Fallback to sqlite introspection if needed
                if resolved is None and args.bird_root and db_id:
                    sqlite_path = find_sqlite_path(args.bird_root, db_id)
                    if sqlite_path:
                        resolved = sqlite_introspect_schema(sqlite_path)
                schema_cache[db_id] = resolved
            schema_obj = schema_cache[db_id]

            if schema_obj is None:
                n_missing_schema += 1
                schema_obj = {"error": "schema_not_found", "db_id": db_id}

            if args.schema_format == "string":
                # Compact, readable schema string: one "table(col, col)" per line
                if "tables" in schema_obj and isinstance(schema_obj["tables"], list):
                    parts = []
                    for t in schema_obj["tables"]:
                        tname = t.get("name", "")
                        cols = t.get("columns", [])
                        col_str = ", ".join(c.get("name", "") for c in cols if isinstance(c, dict))
                        parts.append(f"{tname}({col_str})")
                    row["db_schema"] = "\n".join(parts)
                else:
                    row["db_schema"] = json.dumps(schema_obj, ensure_ascii=False)
            else:
                row["db_schema"] = schema_obj

            f.write(json.dumps(row, ensure_ascii=False) + "\n")

    print(f"Wrote {len(ds)} samples to: {args.out}")
    if n_missing_schema:
        print(f"WARNING: {n_missing_schema} rows had missing schema (schema_not_found). Provide --tables_json and/or --bird_root.")
def load_streambench_cosql(split: str):
    """
    Load one split of the StreamBench "cosql" subset via Hugging Face ``datasets``.

    ``datasets`` is imported inside the function, so the import only happens
    when the loader is actually called.

    Args:
        split: Split name passed straight to ``load_dataset``
            (train / validation / test).

    Returns:
        Whatever ``load_dataset`` returns for that split.
    """
    from datasets import load_dataset

    return load_dataset("appier-ai-research/StreamBench", "cosql", split=split)
def read_tables_json(tables_json_path: str) -> Dict[str, Dict[str, Any]]:
    """
    Load a Spider/CoSQL-style ``tables.json`` into a per-database schema map.

    Each entry of the file is read for ``db_id``, ``table_names_original``,
    ``column_names_original`` (``[table_idx, column_name]`` pairs, with
    ``table_idx == -1`` for the ``*`` pseudo-column), ``column_types``,
    ``primary_keys`` and ``foreign_keys``.

    Args:
        tables_json_path: Path to the tables.json file (read as UTF-8).

    Returns:
        Mapping ``db_id -> {"db_id", "tables", "primary_keys",
        "foreign_keys"}``. ``tables`` is a list of
        ``{"name": ..., "columns": [{"name": ..., "type": ...}]}`` in file
        order; the key lists are passed through as found. Entries without a
        ``db_id`` are skipped.
    """
    with open(tables_json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    dbid_to_schema: Dict[str, Dict[str, Any]] = {}
    for db in data:
        db_id = db.get("db_id")
        if not db_id:
            continue

        table_names = db.get("table_names_original", [])
        column_names = db.get("column_names_original", [])
        column_types = db.get("column_types", [])

        tables: List[Dict[str, Any]] = [{"name": t, "columns": []} for t in table_names]

        for i, (tbl_idx, col_name) in enumerate(column_names):
            # Skip "*" (-1) and any index that does not name a table: other
            # negative values would silently wrap around to the wrong table,
            # and too-large ones would abort the whole load with IndexError.
            if not 0 <= tbl_idx < len(tables):
                continue
            # column_types is parallel to column_names; tolerate a short list.
            col_type = column_types[i] if i < len(column_types) else None
            tables[tbl_idx]["columns"].append({"name": col_name, "type": col_type})

        dbid_to_schema[db_id] = {
            "db_id": db_id,
            "tables": tables,
            "primary_keys": db.get("primary_keys", []),
            "foreign_keys": db.get("foreign_keys", []),
        }

    return dbid_to_schema
def find_sqlite_path(cosql_root: str, db_id: str) -> Optional[str]:
    """
    Locate the SQLite file for ``db_id`` below ``cosql_root``.

    The ``<root>/database`` sub-directory is tried first, then a recursive
    search of the whole root. Within each location the candidates are, in
    order: ``<db_id>/<db_id>.sqlite``, ``<db_id>/*.sqlite`` and
    ``<db_id>.sqlite``. The first pattern with any match wins, and within it
    the shortest path (ties broken alphabetically) is returned.

    Args:
        cosql_root: Directory to search.
        db_id: Database identifier; matched literally.

    Returns:
        Path of the chosen ``.sqlite`` file, or ``None`` when nothing matches
        or ``db_id`` is empty.
    """
    if not db_id:
        return None

    # Escape both parts so characters such as "[" or "*" in a directory or
    # database name are matched literally instead of being read as wildcards.
    root = glob.escape(cosql_root)
    name = glob.escape(db_id)

    patterns = []
    for base in (os.path.join(root, "database"), os.path.join(root, "**")):
        # Exact name first: with the wildcard first, a shorter sibling file
        # in the same folder would win over <db_id>.sqlite.
        patterns.append(os.path.join(base, name, f"{name}.sqlite"))
        patterns.append(os.path.join(base, name, "*.sqlite"))
        patterns.append(os.path.join(base, f"{name}.sqlite"))

    for pat in patterns:
        hits = glob.glob(pat, recursive=True)
        if hits:
            # prefer shortest path / deterministic choice
            return min(hits, key=lambda p: (len(p), p))
    return None

def sqlite_introspect_schema(sqlite_path: str) -> Dict[str, Any]:
    """
    Build a minimal schema object by introspecting a SQLite database.

    Columns are numbered globally in the order they are discovered (table by
    table, column by column), roughly emulating the column indices used by
    tables.json; primary and foreign keys are expressed with those indices.

    Args:
        sqlite_path: Path to the SQLite database file.

    Returns:
        ``{"sqlite_path", "tables", "primary_keys", "foreign_keys"}`` where
        ``tables`` is a list of ``{"name", "columns": [{"name", "type"}]}``,
        ``primary_keys`` is a list of global column indices and
        ``foreign_keys`` is a list of ``[src_idx, dst_idx]`` pairs. Foreign
        keys whose endpoints cannot be resolved to a known column are
        dropped (best effort).
    """
    def quote_ident(name: str) -> str:
        # Double-quoted identifier with embedded quotes doubled, so table
        # names containing quote characters cannot break the PRAGMA text.
        return '"' + name.replace('"', '""') + '"'

    tables: List[Dict[str, Any]] = []
    foreign_keys: List[List[int]] = []
    primary_keys: List[int] = []
    global_col_index: Dict[Tuple[str, str], int] = {}
    next_index = 0

    conn = sqlite3.connect(sqlite_path)
    try:
        conn.row_factory = sqlite3.Row
        cur = conn.cursor()

        cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%';")
        table_names = [r["name"] for r in cur.fetchall()]

        for t in table_names:
            cur.execute(f"PRAGMA table_info({quote_ident(t)});")
            cols = []
            for row in cur.fetchall():
                col_name = row["name"]
                cols.append({"name": col_name, "type": row["type"]})

                global_col_index[(t, col_name)] = next_index
                # "pk" is 0 for non-key columns and the 1-based position
                # inside the key otherwise, so a composite key yields 1, 2,
                # ...; testing "== 1" would keep only its first column.
                if int(row["pk"]) > 0:
                    primary_keys.append(next_index)
                next_index += 1

            tables.append({"name": t, "columns": cols})

        # Foreign keys (best-effort mapping to global column indices)
        for t in table_names:
            cur.execute(f"PRAGMA foreign_key_list({quote_ident(t)});")
            for fk in cur.fetchall():
                src = global_col_index.get((t, fk["from"]))
                dst = global_col_index.get((fk["table"], fk["to"]))
                if src is not None and dst is not None:
                    foreign_keys.append([src, dst])
    finally:
        # Release the file handle even when a query fails.
        conn.close()

    return {
        "sqlite_path": sqlite_path,
        "tables": tables,
        "primary_keys": primary_keys,
        "foreign_keys": foreign_keys,
    }
def coerce_row(example: Dict[str, Any]) -> Dict[str, Any]:
    """
    Project a StreamBench CoSQL example onto the fields written to the output.

    CoSQL is conversational, so a record may carry a ``turn_id``. When it
    does, the output ``question_id`` becomes ``"<question_id>_turn<turn_id>"``;
    otherwise it is ``str(question_id)``. ``difficulty`` defaults to an empty
    string; other missing fields become ``None``.

    Args:
        example: One dataset record.

    Returns:
        Dict with keys ``question_id``, ``question``, ``sql``,
        ``difficulty``, ``db_name`` and ``turn_id``, in that order.
    """
    turn = example.get("turn_id")
    raw_id = example.get("question_id")

    # Combine question and turn so each turn of a dialogue has its own id.
    if turn is None:
        row_id = str(raw_id)
    else:
        row_id = f"{raw_id}_turn{turn}"

    return {
        "question_id": row_id,
        "question": example.get("question"),
        "sql": example.get("SQL"),
        "difficulty": example.get("difficulty", ""),
        "db_name": example.get("db_id"),
        "turn_id": turn,
    }
def main():
    """
    Preprocess the StreamBench CoSQL subset into a JSONL file with schemas.

    Each output row carries the fields from ``coerce_row`` plus
    ``db_schema``, taken from ``--tables_json`` when the database is listed
    there and otherwise introspected from the SQLite file found under
    ``--cosql_root``. A missing ``--tables_json`` file only produces a
    warning. Rows whose schema cannot be resolved get an
    ``{"error": "schema_not_found"}`` placeholder and are counted in a final
    warning.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--split", default="test", choices=["train", "validation", "test"])
    ap.add_argument("--out", default="streambench_cosql_test.jsonl")
    ap.add_argument("--tables_json", default="", help="Path to CoSQL-style tables.json (recommended).")
    ap.add_argument("--cosql_root", default="", help="Root dir where CoSQL sqlite databases live (fallback if tables.json missing).")
    ap.add_argument("--schema_format", default="json", choices=["json", "string"],
                    help="Store db_schema as structured JSON (json) or as a compact string (string).")
    args = ap.parse_args()

    # Load dataset
    ds = load_streambench_cosql(args.split)

    # Load schema map from tables.json if provided
    dbid_to_schema = {}
    if args.tables_json:
        if os.path.exists(args.tables_json):
            dbid_to_schema = read_tables_json(args.tables_json)
            print(f"Loaded schema metadata from {args.tables_json}")
        else:
            print(f"Warning: --tables_json not found at {args.tables_json}, will use SQLite introspection")
    else:
        print("No tables.json provided, will use SQLite introspection for schema")

    # db_id -> resolved schema (or None). Many rows share one database, so
    # the recursive file search and SQLite introspection run once per
    # database instead of once per row.
    schema_cache = {}

    n_missing_schema = 0
    with open(args.out, "w", encoding="utf-8") as f:
        for ex in ds:
            row = coerce_row(ex)
            db_id = row["db_name"]

            if db_id not in schema_cache:
                resolved = dbid_to_schema.get(db_id)
                # Fallback to sqlite introspection if needed
                if resolved is None and args.cosql_root and db_id:
                    sqlite_path = find_sqlite_path(args.cosql_root, db_id)
                    if sqlite_path:
                        resolved = sqlite_introspect_schema(sqlite_path)
                schema_cache[db_id] = resolved
            schema_obj = schema_cache[db_id]

            if schema_obj is None:
                n_missing_schema += 1
                schema_obj = {"error": "schema_not_found", "db_id": db_id}

            if args.schema_format == "string":
                # Compact, readable schema string: one "table(col, col)" per line
                if "tables" in schema_obj and isinstance(schema_obj["tables"], list):
                    parts = []
                    for t in schema_obj["tables"]:
                        tname = t.get("name", "")
                        cols = t.get("columns", [])
                        col_str = ", ".join(c.get("name", "") for c in cols if isinstance(c, dict))
                        parts.append(f"{tname}({col_str})")
                    row["db_schema"] = "\n".join(parts)
                else:
                    row["db_schema"] = json.dumps(schema_obj, ensure_ascii=False)
            else:
                row["db_schema"] = schema_obj

            f.write(json.dumps(row, ensure_ascii=False) + "\n")

    print(f"Wrote: {args.out}")
    if n_missing_schema:
        print(f"WARNING: {n_missing_schema} rows had missing schema (schema_not_found). Provide --tables_json and/or --cosql_root.")
def load_streambench_spider(split: str):
    """
    Load one split of the StreamBench "spider" subset from Hugging Face.

    Args:
        split: Split name passed straight to `load_dataset`
            (the CLI offers train / validation / test).

    Returns:
        Whatever `datasets.load_dataset` returns for that split.
    """
    # Imported inside the function, so `datasets` is only needed when this is called.
    from datasets import load_dataset

    return load_dataset("appier-ai-research/StreamBench", "spider", split=split)


def read_tables_json(tables_json_path: str) -> Dict[str, Dict[str, Any]]:
    """
    Read a Spider-style tables.json and index its schemas by db_id.

    Each entry is expected to carry db_id, table_names_original,
    column_names_original ([[table_idx, "colname"], ...]), column_types,
    primary_keys and foreign_keys. Entries without a db_id are skipped.

    Args:
        tables_json_path: Path to the tables.json file (read as UTF-8).

    Returns:
        Mapping db_id -> {"db_id", "tables": [{"name", "columns": [{"name",
        "type"}]}], "primary_keys", "foreign_keys"}. primary_keys and
        foreign_keys are passed through unchanged.
    """
    with open(tables_json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    dbid_to_schema: Dict[str, Dict[str, Any]] = {}
    for db in data:
        db_id = db.get("db_id")
        if not db_id:
            continue

        column_types = db.get("column_types", [])
        tables: List[Dict[str, Any]] = [
            {"name": t, "columns": []} for t in db.get("table_names_original", [])
        ]

        for i, (tbl_idx, col_name) in enumerate(db.get("column_names_original", [])):
            # table_idx is -1 for the "*" pseudo-column. Any other index outside the
            # table list is skipped too, rather than wrapping around (negative
            # indexing) or raising IndexError on malformed metadata.
            if not 0 <= tbl_idx < len(tables):
                continue
            col_type = column_types[i] if i < len(column_types) else None
            tables[tbl_idx]["columns"].append({"name": col_name, "type": col_type})

        dbid_to_schema[db_id] = {
            "db_id": db_id,
            "tables": tables,
            "primary_keys": db.get("primary_keys", []),
            "foreign_keys": db.get("foreign_keys", []),
        }

    return dbid_to_schema


def find_sqlite_path(spider_root: str, db_id: str) -> Optional[str]:
    """
    Locate the .sqlite file for `db_id` under `spider_root`.

    Tries the common Spider layouts first
    (spider/database/{db_id}/{db_id}.sqlite and close variants), then falls
    back to a recursive search. The first pattern that matches anything wins;
    among its hits the shortest path (ties broken alphabetically) is returned
    so the choice is deterministic.

    Args:
        spider_root: Root directory to search.
        db_id: Database identifier; matched literally.

    Returns:
        Path of the chosen file, or None when nothing matches.
    """
    # Escape glob metacharacters so "[", "?" or "*" in the inputs are taken literally.
    root = glob.escape(spider_root)
    name = glob.escape(db_id)

    patterns = [
        os.path.join(root, "database", name, "*.sqlite"),
        os.path.join(root, "database", f"{name}.sqlite"),
        os.path.join(root, "database", name, f"{name}.sqlite"),
        os.path.join(root, "**", name, "*.sqlite"),
        os.path.join(root, "**", f"{name}.sqlite"),
        os.path.join(root, "**", name, f"{name}.sqlite"),
    ]
    for pat in patterns:
        hits = glob.glob(pat, recursive=True)
        if hits:
            return min(hits, key=lambda p: (len(p), p))
    return None


def sqlite_introspect_schema(sqlite_path: str) -> Dict[str, Any]:
    """
    Build a minimal schema object by introspecting a SQLite database.

    Columns are numbered globally in table order, then column order, to
    roughly emulate the column indices used by tables.json. This is a
    best-effort fallback for when tables.json has no entry for a database.

    Args:
        sqlite_path: Path to the SQLite file.

    Returns:
        Dict with:
          - "sqlite_path": the path that was read
          - "tables": [{"name", "columns": [{"name", "type"}]}]
          - "primary_keys": global indices of every primary-key column,
            including all members of a composite key
          - "foreign_keys": [[src_col_global, dst_col_global], ...] for each
            foreign key whose two ends could be resolved
    """
    def quote_identifier(name: str) -> str:
        # Double-quoted identifier with embedded quotes doubled, so any table name is safe.
        return '"' + name.replace('"', '""') + '"'

    tables: List[Dict[str, Any]] = []
    foreign_keys: List[List[int]] = []
    primary_keys: List[int] = []
    global_col_index: Dict[Tuple[str, str], int] = {}
    # Per table: primary-key column names ordered by their position in the key.
    pk_columns: Dict[str, List[str]] = {}

    conn = sqlite3.connect(sqlite_path)
    try:
        conn.row_factory = sqlite3.Row
        cur = conn.cursor()

        # "_" is a LIKE wildcard; escape it so only the reserved "sqlite_" prefix is excluded.
        cur.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type='table' AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\';"
        )
        table_names = [r["name"] for r in cur.fetchall()]

        next_index = 0
        for t in table_names:
            cur.execute(f"PRAGMA table_info({quote_identifier(t)});")
            cols = []
            key_parts = []
            for row in cur.fetchall():
                col_name = row["name"]
                cols.append({"name": col_name, "type": row["type"]})
                global_col_index[(t, col_name)] = next_index

                # "pk" is 0 for non-key columns and the 1-based position within the
                # key otherwise, so every value > 0 marks a primary-key column.
                pk_pos = int(row["pk"])
                if pk_pos > 0:
                    primary_keys.append(next_index)
                    key_parts.append((pk_pos, col_name))
                next_index += 1

            pk_columns[t] = [name for _, name in sorted(key_parts)]
            tables.append({"name": t, "columns": cols})

        # Foreign keys (best-effort mapping to global column indices)
        for t in table_names:
            cur.execute(f"PRAGMA foreign_key_list({quote_identifier(t)});")
            for fk in cur.fetchall():
                dst_table = fk["table"]
                dst_col = fk["to"]
                if dst_col is None:
                    # "REFERENCES parent" without a column list targets the parent's
                    # primary key; "seq" is this column's position within the key.
                    parent_key = pk_columns.get(dst_table, [])
                    seq = int(fk["seq"])
                    if seq < len(parent_key):
                        dst_col = parent_key[seq]

                src_idx = global_col_index.get((t, fk["from"]))
                dst_idx = global_col_index.get((dst_table, dst_col))
                if src_idx is not None and dst_idx is not None:
                    foreign_keys.append([src_idx, dst_idx])
    finally:
        # Release the file handle even if a query fails.
        conn.close()

    return {
        "sqlite_path": sqlite_path,
        "tables": tables,
        "primary_keys": primary_keys,
        "foreign_keys": foreign_keys,
    }
def main():
    """
    CLI entry point: export one StreamBench Spider split as JSONL with a db_schema per row.

    Schemas come from --tables_json when given (a path that does not exist
    raises FileNotFoundError). For db_ids not found there, and when
    --spider_root is set, the schema is introspected from the matching SQLite
    file. Rows whose schema cannot be resolved get an
    {"error": "schema_not_found", ...} placeholder and are counted in a final
    warning.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--split", default="test", choices=["train", "validation", "test"])
    cli.add_argument("--out", default="streambench_spider_test.jsonl")
    cli.add_argument("--tables_json", default="", help="Path to Spider-style tables.json (recommended).")
    cli.add_argument("--spider_root", default="", help="Root dir where Spider sqlite databases live (fallback if tables.json missing).")
    cli.add_argument("--schema_format", default="json", choices=["json", "string"],
                     help="Store db_schema as structured JSON (json) or as a compact string (string).")
    args = cli.parse_args()

    # Load dataset
    examples = load_streambench_spider(args.split)

    # Load schema map from tables.json if provided
    schemas = {}
    if args.tables_json:
        if not os.path.exists(args.tables_json):
            raise FileNotFoundError(f"--tables_json not found: {args.tables_json}")
        schemas = read_tables_json(args.tables_json)

    def resolve_schema(db_id):
        """Return the schema for db_id from tables.json, else from SQLite, else None."""
        found = schemas.get(db_id)
        if found is None and args.spider_root:
            db_file = find_sqlite_path(args.spider_root, db_id)
            if db_file:
                found = sqlite_introspect_schema(db_file)
        return found

    def render_schema(schema):
        """Return the value stored under db_schema for the chosen --schema_format."""
        if args.schema_format != "string":
            return schema
        if "tables" in schema and isinstance(schema["tables"], list):
            # Compact, readable form: one "table(col1, col2, ...)" line per table.
            lines = []
            for table in schema["tables"]:
                names = [c.get("name", "") for c in table.get("columns", []) if isinstance(c, dict)]
                lines.append(f"{table.get('name', '')}({', '.join(names)})")
            return "\n".join(lines)
        return json.dumps(schema, ensure_ascii=False)

    # Process and write
    unresolved = 0
    with open(args.out, "w", encoding="utf-8") as sink:
        for example in examples:
            row = coerce_row(example)
            db_id = row["db_name"]

            schema = resolve_schema(db_id)
            if schema is None:
                unresolved += 1
                schema = {"error": "schema_not_found", "db_id": db_id}

            row["db_schema"] = render_schema(schema)
            sink.write(json.dumps(row, ensure_ascii=False) + "\n")

    print(f"Wrote: {args.out}")
    if unresolved:
        print(f"WARNING: {unresolved} rows had missing schema (schema_not_found). Provide --tables_json and/or --spider_root.")


if __name__ == "__main__":
    main()
class TeeLogger:
    """
    Logger that writes to both terminal and file simultaneously with auto-flush.

    An instance is meant to be assigned to `sys.stdout`. It remembers the
    stream that was `sys.stdout` at construction time, mirrors every write to
    it and to the log file, and hands control back to it on `close()`.
    Attributes it does not define itself (`isatty`, `encoding`, `fileno`, ...)
    are delegated to that stream, so code that probes `sys.stdout` keeps
    working while the tee is installed.
    """

    def __init__(self, log_file_path, mode='w'):
        self.terminal = sys.stdout
        # Line buffering; explicit UTF-8 so non-ASCII output cannot fail on the locale encoding.
        self.log_file = open(log_file_path, mode, buffering=1, encoding='utf-8')
        self.log_file_path = log_file_path

    def __getattr__(self, name):
        # Only reached for attributes not found normally. Guard the two instance
        # attributes so a half-initialised object cannot recurse forever.
        if name in ('terminal', 'log_file'):
            raise AttributeError(name)
        return getattr(self.terminal, name)

    def write(self, message):
        self.terminal.write(message)
        # Force flush to ensure immediate write
        self.terminal.flush()
        # After close() the file is gone; keep echoing to the terminal instead of raising.
        if not self.log_file.closed:
            self.log_file.write(message)
            self.log_file.flush()
        return len(message)

    def flush(self):
        self.terminal.flush()
        if not self.log_file.closed:
            self.log_file.flush()

    def close(self):
        # Restore stdout only if this tee is still installed, so closing an old
        # tee cannot unhook a newer one. Safe to call more than once.
        if sys.stdout is self:
            sys.stdout = self.terminal
        self.log_file.close()


def preprocess_data(task_name, config, mode, db_name=None, curriculum=None):
    """
    Load training and test data for the specified task.

    Args:
        task_name: Name of the task (not used by this function; kept for the caller's interface)
        config: Configuration dictionary with data paths and settings
        mode: Run mode ('offline', 'online', or 'eval_only')
        db_name: Database name from command line args
        curriculum: Curriculum ordering from command line args

    Returns:
        Tuple of (train_samples, val_samples, test_samples, train_processor, val_processor, test_processor)
        - For offline mode: all three sample sets and all three processors are returned
          (test_samples is [] when the config has no test_data)
        - For online/eval_only mode: only test_samples and test_processor (train/val entries are None)

    Raises:
        ValueError: If the config lacks the data files the mode needs
            (test_data for online/eval_only; train_data and val_data otherwise).
    """
    splits = ("train", "val", "test")

    # max_samples serves as default for all splits; max_<split>_samples overrides it.
    default_limit = config.get("max_samples", None)
    limits = {split: config.get(f"max_{split}_samples", default_limit) for split in splits}

    # Detect task type based on config keys:
    # if cosql_db_root is present it's a CoSQL task, otherwise BIRD.
    if "cosql_db_root" in config:
        task = "cosql"
        cosql_db_root = config.get("cosql_db_root", "eval/stream-bench/data/cosql")
        # CoSQL uses one database root for every split.
        cosql_roots = dict.fromkeys(splits, cosql_db_root)
        bird_roots = dict.fromkeys(splits, None)

        print(f"[CONFIG] Task: CoSQL")
        print(f"[CONFIG] Database path: {cosql_db_root}")
    else:
        task = "bird"
        # bird_db_root serves as default for all splits; bird_<split>_db_root overrides it.
        bird_db_root = config.get("bird_db_root", "eval/stream-bench/data/bird/dev_databases")
        bird_roots = {split: config.get(f"bird_{split}_db_root", bird_db_root) for split in splits}
        cosql_roots = dict.fromkeys(splits, None)

        print(f"[CONFIG] Task: BIRD")
        print(f"[CONFIG] Database paths:")
        print(f" bird_train_db_root: {bird_roots['train']}")
        print(f" bird_val_db_root: {bird_roots['val']}")
        print(f" bird_test_db_root: {bird_roots['test']}")

    # Get difficulty_filter from config (dataset-level selection)
    difficulty_filter = config.get("difficulty_filter", None)

    def build_processor(split):
        """Create the DataProcessor for one split, using that split's db root and sample limit."""
        return DataProcessor(
            bird_db_root=bird_roots[split],
            cosql_db_root=cosql_roots[split],
            task=task,
            max_samples=limits[split],
            db_name=db_name,
            difficulty_filter=difficulty_filter,
            curriculum=curriculum
        )

    # For online and eval_only modes, only load test data
    if mode in ["online", "eval_only"]:
        test_processor = build_processor("test")

        if "test_data" not in config:
            raise ValueError(f"{mode} mode requires test data in config.")
        test_samples = test_processor.process_task_data(load_data(config["test_data"]))

        if mode == "online":
            print(f"Online mode: Training and testing on {len(test_samples)} examples")
        else:
            print(f"Eval only mode: Testing on {len(test_samples)} examples")

        return None, None, test_samples, None, None, test_processor

    # For offline mode, load train, val, and optionally test data.
    # Fail with a descriptive error, matching the online branch, rather than a bare KeyError.
    missing = [key for key in ("train_data", "val_data") if key not in config]
    if missing:
        raise ValueError(f"{mode} mode requires {' and '.join(missing)} in config.")

    # Separate processors per split so each applies its own max_samples / db root.
    train_processor = build_processor("train")
    val_processor = build_processor("val")
    test_processor = build_processor("test")

    train_samples = load_data(config["train_data"])
    val_samples = load_data(config["val_data"])
    train_samples = train_processor.process_task_data(train_samples)
    val_samples = val_processor.process_task_data(val_samples)

    if "test_data" in config:
        test_samples = test_processor.process_task_data(load_data(config["test_data"]))
    else:
        test_samples = []

    print(f"Offline mode: Training on {len(train_samples)} examples, "
          f"validating on {len(val_samples)}, testing on {len(test_samples)}")

    # Return all three processors for proper evaluation of each split
    return train_samples, val_samples, test_samples, train_processor, val_processor, test_processor
test_val = task_config.get('max_test_samples', max_samples_default or 'No limit') + print(f" - Train: {train_val}") + print(f" - Validation: {val_val}") + print(f" - Test: {test_val}") + elif max_samples_default is not None: + # Only default specified + print(f"Max samples (from config): {max_samples_default} (applies to all splits)") + else: + print(f"Max samples: No limit") + + if args.db_name: + print(f"Database filter: {args.db_name}") + else: + print(f"Database filter: None (using mixed databases)") + + if "difficulty_filter" in task_config: + print(f"Difficulty filter (from config): {task_config['difficulty_filter']}") + else: + print(f"Difficulty filter: None (no filtering)") + + if args.curriculum: + print(f"Curriculum ordering: {args.curriculum}") + else: + print(f"Curriculum ordering: None (original order)") + + print() # blank line + + train_samples, val_samples, test_samples, train_processor, val_processor, test_processor = preprocess_data( + args.task_name, + task_config, + args.mode, + db_name=args.db_name, + curriculum=args.curriculum + ) + + # Load initial playbook (or use empty if None provided) + initial_playbook = load_initial_playbook(args.initial_playbook_path) + if initial_playbook: + print(f"Loaded initial playbook from {args.initial_playbook_path}\n") + else: + print("Using empty playbook as initial playbook\n") + + # Create ACE system with a custom wrapper to intercept path creation + ace_system = ACE( + api_provider=args.api_provider, + generator_model=args.generator_model, + reflector_model=args.reflector_model, + curator_model=args.curator_model, + max_tokens=args.max_tokens, + initial_playbook=initial_playbook, + use_bulletpoint_analyzer=args.use_bulletpoint_analyzer, + bulletpoint_analyzer_threshold=args.bulletpoint_analyzer_threshold + ) + + # Extract config filename (without extension) from config path + config_filename = os.path.splitext(os.path.basename(args.data_config))[0] + + # Prepare configuration + config = { + 'num_epochs': 
args.num_epochs, + 'max_num_rounds': args.max_num_rounds, + 'curator_frequency': args.curator_frequency, + 'eval_steps': args.eval_steps, + 'online_eval_frequency': args.online_eval_frequency, + 'save_steps': args.save_steps, + 'playbook_token_budget': args.playbook_token_budget, + 'task_name': args.task_name, + 'mode': args.mode, + 'json_mode': args.json_mode, + 'no_ground_truth': args.no_ground_truth, + 'save_dir': args.save_path, # Pass parent directory + 'test_workers': args.test_workers, + 'initial_playbook_path': args.initial_playbook_path, + 'use_bulletpoint_analyzer': args.use_bulletpoint_analyzer, + 'bulletpoint_analyzer_threshold': args.bulletpoint_analyzer_threshold, + 'pass_sql_eval_results': args.pass_sql_eval_results, + 'api_provider': args.api_provider, + 'config_name': config_filename, + 'db_name': args.db_name, + 'curriculum': args.curriculum + } + + # Create a save hook to intercept when ACE creates the save path + original_setup_paths = ace_system._setup_paths + run_save_path_container = {'path': None} + + def setup_paths_with_data_save(*args, **kwargs): + """Wrapper that saves processed data right after path creation.""" + result = original_setup_paths(*args, **kwargs) + # Extract save_path from result (first element of tuple) + save_path = result[0] if isinstance(result, tuple) else result + run_save_path_container['path'] = save_path + + # Save processed data immediately after folder creation + print(f"\nSaving preprocessed data to: {save_path}") + processed_data_dir = os.path.join(save_path, "processed_data") + os.makedirs(processed_data_dir, exist_ok=True) + + if train_samples is not None: + train_path = os.path.join(processed_data_dir, "train_samples.json") + with open(train_path, 'w') as f: + json.dump(train_samples, f, indent=2) + print(f" - Saved train samples ({len(train_samples)} samples)") + + if val_samples is not None: + val_path = os.path.join(processed_data_dir, "val_samples.json") + with open(val_path, 'w') as f: + 
json.dump(val_samples, f, indent=2) + print(f" - Saved val samples ({len(val_samples)} samples)") + + if test_samples is not None: + test_path = os.path.join(processed_data_dir, "test_samples.json") + with open(test_path, 'w') as f: + json.dump(test_samples, f, indent=2) + print(f" - Saved test samples ({len(test_samples)} samples)") + + print() # blank line + return result + + # Replace the method temporarily + ace_system._setup_paths = setup_paths_with_data_save + + # Execute using the unified run method + print(f"Starting ACE run at {time.strftime('%Y-%m-%d %H:%M:%S')}\n") + run_start_time = time.time() + + results = ace_system.run( + mode=args.mode, + train_samples=train_samples, + val_samples=val_samples, + test_samples=test_samples, + train_processor=train_processor, + val_processor=val_processor, + test_processor=test_processor, + config=config + ) + + run_elapsed_time = time.time() - run_start_time + print(f"\nACE run completed in {run_elapsed_time/60:.2f} minutes ({run_elapsed_time:.2f} seconds)") + + # Get the actual save path that was created + run_save_path = run_save_path_container['path'] + if not run_save_path: + # Fallback to results if something went wrong with the hook + run_save_path = results.get('save_path', args.save_path) + + # Extract timestamp from the ace_run folder name to match it exactly + # Folder format: ace_run_YYYYMMDD_HHMMSS_task_name_... 
+ folder_name = os.path.basename(run_save_path) + timestamp_match = re.search(r'ace_run_(\d{8}_\d{6})', folder_name) + if timestamp_match: + ace_run_timestamp = timestamp_match.group(1) + else: + # Fallback to original timestamp if extraction fails + ace_run_timestamp = log_timestamp + + # Move the log file from temp location to final location + final_log_path = os.path.join(run_save_path, f"terminal_output_{ace_run_timestamp}.txt") + + # Close current logger before moving file + logger.close() + + # Move the log file to final location + import shutil + shutil.move(temp_log_path, final_log_path) + + # Reopen logger with final path in APPEND mode to continue logging + logger = TeeLogger(final_log_path, mode='a') + sys.stdout = logger + + print(f"\nMoved terminal output log to: {final_log_path}") + + # Calculate and display total timing + total_elapsed_time = time.time() - total_start_time + + print(f"\n{'='*60}") + print(f"TOTAL EXECUTION TIME") + print(f"{'='*60}") + print(f"Start time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(total_start_time))}") + print(f"End time: {time.strftime('%Y-%m-%d %H:%M:%S')}") + print(f"Total time: {total_elapsed_time/60:.2f} minutes ({total_elapsed_time:.2f} seconds)") + print(f"ACE run time: {run_elapsed_time/60:.2f} minutes ({run_elapsed_time:.2f} seconds)") + print(f"{'='*60}\n") + + # Generate performance plots if requested + if args.plot: + print(f"\n{'='*60}") + print(f"GENERATING PERFORMANCE PLOTS") + print(f"{'='*60}\n") + + if args.mode == 'online': + plot_online_performance(run_save_path, args.mode) + plot_training_progress(run_save_path, args.mode) + elif args.mode == 'offline': + plot_offline_training_progress(run_save_path) + else: + print(f"Skipping plot generation - not available for {args.mode} mode") + + # Close the logger + if logger: + # Print before closing since we're using the logger + final_message = f"Terminal output saved to {final_log_path}" + logger.close() + # Print to terminal after logger is 
def load_playbook(playbook_path: str) -> str:
    """Load playbook content from file and return it as a single string."""
    with open(playbook_path, 'r') as f:
        return f.read()


def extract_sql_from_response(response: str) -> str:
    """
    Extract SQL query from generator response.

    The response may be:
    1. JSON format with reasoning and final_answer (optionally wrapped in a
       markdown code fence such as ```json ... ```)
    2. Plain SQL query (optionally wrapped in a ```sql ... ``` fence)

    Args:
        response: Generator response

    Returns:
        Extracted SQL query with markdown fences removed and surrounding
        whitespace stripped. A JSON object without a final_answer field is
        returned as its (stripped) text. A non-string response or a null
        final_answer yields "".
    """
    # Local import: this file's visible top-level imports do not include `re`.
    import re

    def strip_fences(text: str) -> str:
        # Drop ``` markers together with an optional language tag, in any letter case.
        return re.sub(r"```(?:sqlite|sql|json)?", "", text, flags=re.IGNORECASE).strip()

    if not isinstance(response, str):
        return ""

    # Try the raw text first, then the text with fences removed, so JSON that
    # the model wrapped in a ```json block is still recognised.
    for candidate in (response, strip_fences(response)):
        try:
            parsed = json.loads(candidate)
        except (json.JSONDecodeError, TypeError):
            continue
        if not isinstance(parsed, dict):
            break
        if 'final_answer' not in parsed:
            # If no final_answer, hand back the whole (JSON) response
            return candidate.strip()
        answer = parsed['final_answer']
        if answer is None:
            return ""
        # final_answer is normally a string; coerce anything else instead of crashing.
        return strip_fences(answer if isinstance(answer, str) else str(answer))

    # Not a JSON object: assume it's plain SQL
    return strip_fences(response)
extract_sql_from_response(response) + + # Check if this was an error response from timed_llm_call + error_type = None + if "INCORRECT_DUE_TO_EMPTY_RESPONSE" in predicted_sql: + error_type = "empty_response" + elif "INCORRECT_DUE_TO_INVALID_PROMPT" in predicted_sql: + error_type = "invalid_prompt" + elif call_info.get('error'): + # Check for context length exceeded + error_msg = call_info.get('error', '') + if 'context_length_exceeded' in error_msg or 'tokens exceed' in error_msg: + error_type = "context_length_exceeded" + else: + error_type = "api_error" + + return idx, predicted_sql, error_type + except Exception as e: + error_str = str(e) + print(f"Error processing sample {idx}: {error_str}") + + # Classify error type + error_type = "unknown_error" + if 'context_length_exceeded' in error_str or 'tokens exceed' in error_str: + error_type = "context_length_exceeded" + elif 'timeout' in error_str.lower() or 'timed out' in error_str.lower(): + error_type = "timeout" + elif 'rate limit' in error_str.lower() or '429' in error_str: + error_type = "rate_limit" + elif '400' in error_str or 'invalid_prompt' in error_str.lower(): + error_type = "client_error" + elif '500' in error_str or 'server error' in error_str.lower(): + error_type = "server_error" + + # Return a placeholder SQL that will fail evaluation + return idx, "SELECT 1", error_type + + print(f"\nGenerating predictions with {num_workers} workers...") + + with ThreadPoolExecutor(max_workers=num_workers) as executor: + # Submit all tasks + futures = { + executor.submit(process_sample, i, sample): i + for i, sample in enumerate(samples) + } + + # Collect results as they complete + completed = 0 + for future in as_completed(futures): + idx, prediction, error_type = future.result() + predictions[idx] = prediction + + if error_type: + error_info.append({ + 'sample_idx': idx, + 'error_type': error_type, + 'question': samples[idx].get('question', '')[:100] # First 100 chars + }) + + completed += 1 + + if completed % 10 
def evaluate_test_samples(
    predictions: List[str],
    test_samples: List[Dict[str, Any]],
    bird_db_root: str
) -> Dict[str, Any]:
    """
    Evaluate predictions using test_samples.json (which already has ground truth).

    Each prediction is checked against its sample's 'target' via
    DataProcessor.answer_is_correct. A sample whose check raises is recorded
    as incorrect, with the exception text under 'error', and evaluation
    continues.

    Args:
        predictions: List of predicted SQL queries
        test_samples: List of test samples with ground truth SQL
        bird_db_root: Path to BIRD database root

    Returns:
        Dictionary with keys 'accuracy', 'total_samples', 'correct',
        'difficulty_breakdown' (only the simple / moderate / challenging
        buckets that received samples) and 'results' (one entry per sample).
        When the two lists differ in length nothing is evaluated: the counts
        are zero, the breakdown and results are empty, and an 'error' key
        explains why.
    """
    print(f"\n" + "="*70)
    print("EVALUATING ON TEST SAMPLES")
    print("="*70)
    print(f"Total test samples: {len(test_samples)}")
    print("="*70)

    if len(predictions) != len(test_samples):
        print(f"\nError: Mismatch between predictions ({len(predictions)}) and samples ({len(test_samples)})")
        # Same keys as the normal result, so callers can index it without special-casing.
        return {
            'accuracy': 0.0,
            'total_samples': 0,
            'correct': 0,
            'difficulty_breakdown': {},
            'results': [],
            'error': 'Prediction count mismatch'
        }

    # Initialize DataProcessor for evaluation
    data_processor = DataProcessor(bird_db_root=bird_db_root)

    correct = 0
    results = []

    # Track results by difficulty; labels outside these buckets are not broken out.
    difficulty_stats = {
        'simple': {'correct': 0, 'total': 0},
        'moderate': {'correct': 0, 'total': 0},
        'challenging': {'correct': 0, 'total': 0}
    }

    print("\nEvaluating predictions...")
    for i, (pred, sample) in enumerate(zip(predictions, test_samples)):
        meta = sample.get('others') or {}
        difficulty = meta.get('difficulty', 'unknown')

        # Build the record up front, so a failure cannot leave the counters and
        # the recorded outcome disagreeing with each other.
        record = {
            'question': sample.get('question', ''),
            'db_name': meta.get('db_name', ''),
            'difficulty': difficulty,
            'predicted_sql': pred,
            'ground_truth_sql': sample.get('target', ''),
            'is_correct': False
        }

        try:
            record['is_correct'] = data_processor.answer_is_correct(pred, sample['target'], meta)
        except Exception as e:
            # Best-effort: one failing sample must not abort the whole evaluation.
            print(f" Error evaluating sample {i}: {e}")
            record['error'] = str(e)

        if record['is_correct']:
            correct += 1

        # Each sample is counted exactly once in its difficulty bucket.
        if difficulty in difficulty_stats:
            bucket = difficulty_stats[difficulty]
            bucket['total'] += 1
            if record['is_correct']:
                bucket['correct'] += 1

        results.append(record)

        if (i + 1) % 10 == 0 or (i + 1) == len(test_samples):
            print(f" Progress: {i + 1}/{len(test_samples)} samples evaluated (correct: {correct})")

    accuracy = correct / len(test_samples) if len(test_samples) > 0 else 0.0

    # Calculate difficulty-specific accuracies (buckets with no samples are omitted)
    difficulty_accuracies = {
        diff: {
            'accuracy': stats['correct'] / stats['total'],
            'correct': stats['correct'],
            'total': stats['total']
        }
        for diff, stats in difficulty_stats.items()
        if stats['total'] > 0
    }

    return {
        'accuracy': accuracy,
        'total_samples': len(test_samples),
        'correct': correct,
        'difficulty_breakdown': difficulty_accuracies,
        'results': results
    }
os.path.join(results_dir, 'run_config.json') + if os.path.exists(config_path): + with open(config_path, 'r') as f: + return json.load(f) + return {} + + +def main(): + parser = argparse.ArgumentParser( + description='Run a playbook on processed data and evaluate accuracy' + ) + parser.add_argument( + '--results_dir', + type=str, + required=True, + help='Path to results directory (e.g., results/ace_run_20260119_234301_bird_all_hard_to_easy_online)' + ) + parser.add_argument( + '--playbook_file', + type=str, + default=None, + help='Playbook file path relative to results_dir (e.g., intermediate_playbooks/window_4_final_playbook.txt). If not provided, runs initial evaluation with empty playbook.' + ) + parser.add_argument( + '--bird_db_root', + type=str, + default='eval/stream-bench/data/bird/dev_databases', + help='Path to BIRD database root directory (for SQL execution during evaluation)' + ) + parser.add_argument( + '--api_provider', + type=str, + default=None, + choices=['sambanova', 'together', 'openai'], + help='API provider for LLM calls (defaults to value from run_config.json)' + ) + parser.add_argument( + '--generator_model', + type=str, + default=None, + help='Model name for generator (defaults to generator_model from run_config.json)' + ) + parser.add_argument( + '--num_workers', + type=int, + default=4, + help='Number of parallel workers for generation' + ) + parser.add_argument( + '--output_file', + type=str, + default=None, + help='Optional output file to save detailed results (JSON)' + ) + parser.add_argument( + '--dataset', + type=str, + default='test', + choices=['train', 'val', 'test'], + help='Which dataset to evaluate on: train_samples.json, val_samples.json, or test_samples.json (default: test)' + ) + + args = parser.parse_args() + + # Validate paths + if not os.path.exists(args.results_dir): + print(f"Error: Results directory not found: {args.results_dir}") + return 1 + + # Handle playbook file (optional) + playbook_path = None + if 
args.playbook_file: + # Join playbook_file with results_dir + playbook_path = os.path.join(args.results_dir, args.playbook_file) + if not os.path.exists(playbook_path): + print(f"Error: Playbook file not found: {playbook_path}") + return 1 + + # Load run config to get default model and API provider + run_config = load_run_config(args.results_dir) + + # Use config values if args not provided + if args.api_provider is None: + args.api_provider = run_config.get('config', {}).get('api_provider') or run_config.get('api_provider', 'sambanova') + print(f"Using API provider from run_config.json: {args.api_provider}") + + if args.generator_model is None: + args.generator_model = run_config.get('generator_model', 'DeepSeek-V3.1') + print(f"Using generator_model from run_config.json: {args.generator_model}") + + # Get bird_db_root from config if not provided via CLI + # Hardcode database paths: train/val use train_databases, test uses dev_databases + if args.bird_db_root == 'eval/stream-bench/data/bird/dev_databases': # Using default + if args.dataset in ['train', 'val']: + args.bird_db_root = 'eval/stream-bench/data/bird_train/train_databases' + else: # test + args.bird_db_root = 'eval/stream-bench/data/bird/dev_databases' + print(f"Using bird_db_root for {args.dataset} dataset: {args.bird_db_root}") + + # Load playbook (or use empty for initial evaluation) + if playbook_path: + print(f"\nLoading playbook from: {playbook_path}") + playbook = load_playbook(playbook_path) + print(f"Playbook loaded ({len(playbook)} characters)") + else: + print(f"\nNo playbook provided - running INITIAL EVALUATION with empty playbook") + playbook = "" + + # Load samples from processed_data (has everything we need) + samples_filename = f'{args.dataset}_samples.json' + samples_path = os.path.join(args.results_dir, 'processed_data', samples_filename) + + if not os.path.exists(samples_path): + print(f"\nError: {args.dataset.capitalize()} samples file not found: {samples_path}") + print("This file 
should be created during the ACE training run.") + return 1 + + print(f"\nLoading {args.dataset} samples from: {samples_path}") + print(f" (This file contains the {args.dataset} data with ground truth SQL)") + with open(samples_path, 'r') as f: + samples = json.load(f) + print(f" Loaded {len(samples)} {args.dataset} samples") + + # Initialize generator + print(f"\nInitializing generator with {args.api_provider} API...") + generator_client, _, _ = initialize_clients(args.api_provider) + generator = Generator(generator_client, args.api_provider, args.generator_model, max_tokens=4096) + + # Generate predictions + predictions, error_stats = generate_predictions_parallel( + generator, samples, playbook, num_workers=args.num_workers + ) + + # Evaluate + print(f"\nEvaluating predictions using execution-based evaluation...") + eval_results = evaluate_test_samples(predictions, samples, args.bird_db_root) + + # Print results + print("\n" + "="*70) + if args.playbook_file: + print(f"EVALUATION RESULTS - {args.dataset.upper()} DATASET") + print("="*70) + print(f"Playbook: {args.playbook_file}") + else: + print(f"INITIAL EVALUATION RESULTS - {args.dataset.upper()} DATASET (empty playbook)") + print("="*70) + print(f"Playbook: ") + print(f"Dataset: {args.dataset}_samples.json") + print(f"\nOverall Performance:") + print(f" Total samples evaluated: {eval_results['total_samples']}") + print(f" Correct: {eval_results['correct']}") + print(f" Accuracy: {eval_results['accuracy']:.2%}") + + # Print difficulty breakdown if available + if 'difficulty_breakdown' in eval_results and eval_results['difficulty_breakdown']: + print(f"\nPerformance by Difficulty:") + for difficulty in ['simple', 'moderate', 'challenging']: + if difficulty in eval_results['difficulty_breakdown']: + stats = eval_results['difficulty_breakdown'][difficulty] + print(f" {difficulty.capitalize():12s}: {stats['accuracy']:.2%} ({stats['correct']}/{stats['total']})") + + if error_stats['total_errors'] > 0: + 
print(f"\nAPI Errors:") + print(f" Total errors during generation: {error_stats['total_errors']}") + print(f" (These samples were marked as incorrect)") + print("="*70) + + # Save detailed results if requested + if args.output_file: + # Save output file under the results_dir directory + output_path = os.path.join(args.results_dir, args.output_file) + + # Warn if file already exists + if os.path.exists(output_path): + print(f"\nWarning: Output file already exists and will be overwritten: {output_path}") + + try: + # Create directory if it doesn't exist + output_dir = os.path.dirname(output_path) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + + with open(output_path, 'w') as f: + json.dump({ + 'dataset': args.dataset, + 'dataset_file': f'{args.dataset}_samples.json', + 'playbook_file': args.playbook_file if args.playbook_file else '', + 'playbook_path': playbook_path if playbook_path else None, + 'is_initial_evaluation': args.playbook_file is None, + 'bird_db_root': args.bird_db_root, + 'accuracy': eval_results['accuracy'], + 'total_samples': eval_results['total_samples'], + 'correct': eval_results['correct'], + 'difficulty_breakdown': eval_results.get('difficulty_breakdown', {}), + 'api_errors': error_stats['total_errors'], + 'error_breakdown': error_stats['error_breakdown'], + 'error_details': error_stats['error_details'], + 'results': eval_results['results'] + }, f, indent=2) + print(f"\nDetailed results saved to: {output_path}") + except Exception as e: + print(f"\nError saving results to {output_path}: {e}") + return 1 + + return 0 + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/llm.py b/llm.py index b6483f6a..649337d9 100644 --- a/llm.py +++ b/llm.py @@ -117,7 +117,31 @@ def timed_llm_call(client, api_provider, model, prompt, role, call_id, max_token is_timeout = any(k in str(e).lower() for k in ["timeout", "timed out", "connection"]) is_rate_limit = any(k in str(e).lower() for k in ["rate limit", "429", "rate_limit_exceeded"]) 
is_empty_response = "empty response" in str(e).lower() or "api returned none content" in str(e).lower() - + + # Check for client errors (400) including invalid prompts, policy violations + is_client_error = False + is_invalid_prompt = False + if hasattr(e, 'response'): + try: + status_code = getattr(e.response, 'status_code', None) + if status_code and status_code == 400: + is_client_error = True + print(f"[{role.upper()}] Client error detected: HTTP {status_code}") + except: + pass + + # Check for invalid prompt / policy violations + if any(k in str(e).lower() for k in ["invalid_prompt", "usage policy", "invalid prompt", "error code: 400"]): + is_client_error = True + is_invalid_prompt = True + print(f"[{role.upper()}] Invalid prompt error detected: {str(e)[:200]}...") + + # Also check for specific OpenAI BadRequestError + if hasattr(openai, 'BadRequestError') and isinstance(e, openai.BadRequestError): + is_client_error = True + if "invalid_prompt" in str(e).lower() or "usage policy" in str(e).lower(): + is_invalid_prompt = True + # Check for server errors (500, 502, 503, etc.) 
that should be retried is_server_error = False if hasattr(e, 'response'): @@ -128,16 +152,16 @@ def timed_llm_call(client, api_provider, model, prompt, role, call_id, max_token print(f"[{role.upper()}] Server error detected: HTTP {status_code}") except: pass - + # Also check for 500 errors in the error message itself if any(k in str(e).lower() for k in ["500 internal server error", "internal server error", "502 bad gateway", "503 service unavailable"]): is_server_error = True print(f"[{role.upper()}] Server error detected in message: {str(e)[:100]}...") - + # Also check for specific OpenAI exceptions if hasattr(openai, 'RateLimitError') and isinstance(e, openai.RateLimitError): is_rate_limit = True - + # Check for OpenAI InternalServerError if hasattr(openai, 'InternalServerError') and isinstance(e, openai.InternalServerError): is_server_error = True @@ -222,7 +246,37 @@ def timed_llm_call(client, api_provider, model, prompt, role, call_id, max_token # For the 4-question format, we return 4 wrong answers incorrect_response = "INCORRECT_DUE_TO_EMPTY_RESPONSE, INCORRECT_DUE_TO_EMPTY_RESPONSE, INCORRECT_DUE_TO_EMPTY_RESPONSE, INCORRECT_DUE_TO_EMPTY_RESPONSE" return incorrect_response, call_info - + + # Handle client errors (400) - these should NOT be retried + # Instead, we skip the sample and continue execution + if is_client_error or is_invalid_prompt: + print(f"[{role.upper()}] ⚠️ Client error (400 / invalid prompt) - skipping sample and continuing execution") + print(f"[{role.upper()}] Error details: {str(e)[:300]}...") + + error_time = time.time() + call_info = { + "role": role, + "call_id": call_id, + "model": model, + "prompt": prompt, + "error": "CLIENT_ERROR_SKIPPED: " + str(e), + "total_time": error_time - start_time, + "prompt_length": len(prompt), + "response_length": 0, + "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S_%f")[:-3], + "datetime": datetime.now().isoformat(), + "skipped_due_to_invalid_prompt": is_invalid_prompt, + 
"skipped_due_to_client_error": True + } + + if log_dir: + log_llm_call(log_dir, call_info) + + # Return a response that will be marked as incorrect so execution continues + # This allows the training/testing loop to proceed with remaining samples + incorrect_response = "INCORRECT_DUE_TO_INVALID_PROMPT, INCORRECT_DUE_TO_INVALID_PROMPT, INCORRECT_DUE_TO_INVALID_PROMPT, INCORRECT_DUE_TO_INVALID_PROMPT" + return incorrect_response, call_info + # Retry logic for timeouts, rate limits, and server errors if (is_timeout or is_rate_limit or is_server_error) and attempt < retries_on_timeout: attempt += 1 diff --git a/utils.py b/utils.py index 01ef7bdc..5c127fcf 100644 --- a/utils.py +++ b/utils.py @@ -37,8 +37,8 @@ def initialize_clients(api_provider): generator_client = openai.OpenAI(api_key=api_key, base_url=base_url) reflector_client = openai.OpenAI(api_key=api_key, base_url=base_url) curator_client = openai.OpenAI(api_key=api_key, base_url=base_url) - - print("Using Together API for all models") + + print(f"Using {api_provider.capitalize()} API for all models") return generator_client, reflector_client, curator_client def get_section_slug(section_name): @@ -177,7 +177,11 @@ def evaluate_single_test_sample(args_tuple, data_processor) -> Tuple[Dict, str]: ) final_answer = extract_answer(gen_response) - is_correct = data_processor.answer_is_correct(final_answer, target) + # print("============= calling data_processor answer_is_correct ===========") + + # Pass sample metadata for thread-safe evaluation (e.g., db_name for SQL tasks) + sample_metadata = task_dict.get("others", None) + is_correct = data_processor.answer_is_correct(final_answer, target, sample_metadata) return { "index": i, @@ -192,11 +196,11 @@ def evaluate_single_test_sample(args_tuple, data_processor) -> Tuple[Dict, str]: def evaluate_test_set(data_processor, generator, playbook, test_samples, - max_tokens=4096, log_dir=None, max_workers=20, + max_tokens=4096, log_dir=None, max_workers=20, 
use_json_mode=False) -> Tuple[Dict, Dict]: """ Parallel evaluation of test set - task-agnostic implementation. - + Args: data_processor: DataProcessor instance with answer_is_correct and evaluate_accuracy methods generator: Generator instance @@ -206,7 +210,7 @@ def evaluate_test_set(data_processor, generator, playbook, test_samples, log_dir: Directory for logs max_workers: Number of parallel workers use_json_mode: Whether to use JSON mode - + Returns: Tuple of (results_dict, error_logs_dict) """ @@ -224,62 +228,104 @@ def evaluate_test_set(data_processor, generator, playbook, test_samples, "answers": [], "targets": [], "errors": [] } + # NEW: Track results by difficulty level + difficulty_results = {} + + # Store results indexed by original sample position to preserve order + indexed_results = {} + # Use a wrapper to pass data_processor to the evaluation function def eval_wrapper(args_tuple): return evaluate_single_test_sample(args_tuple, data_processor) with ThreadPoolExecutor(max_workers=max_workers) as executor: future_to_args = { - executor.submit(eval_wrapper, args): args + executor.submit(eval_wrapper, args): args for args in args_list } for i, future in enumerate(as_completed(future_to_args), 1): result, error = future.result() - + if error: print(error) continue if result and result["success"]: + # Store result by its original index to preserve order + indexed_results[result["index"]] = result + results["correct"] += (1 if result["is_correct"] else 0) results["total"] += 1 - results["answers"].append(result["final_answer"]) - results["targets"].append(result["target"]) - + + # NEW: Track by difficulty level + sample = test_samples[result["index"]] + difficulty = sample.get("others", {}).get("difficulty", "unknown") + if difficulty not in difficulty_results: + difficulty_results[difficulty] = {"correct": 0, "total": 0} + difficulty_results[difficulty]["total"] += 1 + if result["is_correct"]: + difficulty_results[difficulty]["correct"] += 1 + if not 
result["is_correct"]: results["errors"].append({ "index": result["index"], "prediction": result["final_answer"], "ground_truth": result["target"] }) - + if result["final_answer"] == "No final answer found": results["no_answer"] += 1 if i % 50 == 0: curr_acc = results["correct"] / results["total"] if results["total"] > 0 else 0 print(f"Progress: {i}/{len(args_list)}, Accuracy: {curr_acc:.3f}") - + + # Reconstruct answers and targets in original order + for idx in sorted(indexed_results.keys()): + result = indexed_results[idx] + results["answers"].append(result["final_answer"]) + results["targets"].append(result["target"]) + # NEW if results["answers"] and results["targets"]: - accuracy = data_processor.evaluate_accuracy(results["answers"], results["targets"]) - + # Calculate overall accuracy from worker thread results (no re-evaluation) + accuracy = results["correct"] / results["total"] if results["total"] > 0 else 0.0 + + # Calculate accuracy by difficulty + accuracy_by_difficulty = {} + for difficulty, diff_results in difficulty_results.items(): + if diff_results["total"] > 0: + diff_accuracy = diff_results["correct"] / diff_results["total"] + accuracy_by_difficulty[difficulty] = { + "accuracy": diff_accuracy, + "correct": diff_results["correct"], + "total": diff_results["total"] + } + final_results = { "accuracy": accuracy, "correct": results["correct"], "total": results["total"], - "no_answer": results["no_answer"] + "no_answer": results["no_answer"], + "by_difficulty": accuracy_by_difficulty } - + error_logs = { "accuracy": accuracy, "errors": results["errors"] } - + print(f"\n📊 Final Accuracy: {accuracy:.3f} ({results['correct']}/{results['total']})") + + # Print accuracy by difficulty level + if accuracy_by_difficulty: + print(f"\n📈 Accuracy by Difficulty Level:") + for difficulty in sorted(accuracy_by_difficulty.keys()): + diff_data = accuracy_by_difficulty[difficulty] + print(f" {difficulty}: {diff_data['accuracy']:.3f} 
({diff_data['correct']}/{diff_data['total']})") else: - results = {"accuracy": 0.0, "correct": 0, "total": 0} + final_results = {"accuracy": 0.0, "correct": 0, "total": 0, "by_difficulty": {}} error_logs = {} print(f"\n📊 No valid results!") - + return final_results, error_logs \ No newline at end of file From e6e447b318a4af682e6f79bafe25c1395ab96fba Mon Sep 17 00:00:00 2001 From: Sherry Date: Wed, 18 Feb 2026 18:14:21 -0800 Subject: [PATCH 2/3] remove stream-bench dependency on finance --- eval/finance/data_processor.py | 6 +- eval/finance/run.py | 79 +++++++++++------------- eval/stream-bench/run.py | 107 +++++++++++++++++++++++++++++++-- 3 files changed, 138 insertions(+), 54 deletions(-) diff --git a/eval/finance/data_processor.py b/eval/finance/data_processor.py index 2f63c4e0..aa110bb0 100644 --- a/eval/finance/data_processor.py +++ b/eval/finance/data_processor.py @@ -162,14 +162,13 @@ def _formula_answer_is_correct(self, predicted: str, ground_truth: str) -> bool: return predicted == ground_truth - def answer_is_correct(self, predicted: str, ground_truth: str, sample_metadata=None) -> bool: + def answer_is_correct(self, predicted: str, ground_truth: str) -> bool: """ Dataset-specific answer correctness check. Args: predicted: Model's answer ground_truth: Ground truth answer - sample_metadata: Optional dict containing sample metadata (unused for finance tasks) Returns: bool: True if answer is correct, False otherwise @@ -221,14 +220,13 @@ def _evaluate_formula_accuracy(self, out: List[str], target: List[str]) -> tuple return accuracy - def evaluate_accuracy(self, out: List[str], target: List[str], samples=None) -> tuple: + def evaluate_accuracy(self, out: List[str], target: List[str]) -> tuple: """ Dataset-specific accuracy evaluation. 
Args: out: List of model predictions target: List of ground truth targets - samples: Optional list of sample dicts (unused for finance tasks) Returns: tuple: (accuracy, response_list) diff --git a/eval/finance/run.py b/eval/finance/run.py index 543894e1..86535536 100644 --- a/eval/finance/run.py +++ b/eval/finance/run.py @@ -13,10 +13,10 @@ from ace import ACE from utils import initialize_clients -def get_base_parser(description='ACE System'): - """Get base argument parser with common arguments.""" - parser = argparse.ArgumentParser(description=description) - +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description='ACE System - Refactored') + # Task configuration parser.add_argument("--task_name", type=str, required=True, help="Name of the task (e.g., 'finer', 'formula')") @@ -27,11 +27,11 @@ def get_base_parser(description='ACE System'): help="Run mode: 'offline' for offline training with validation, " "'online' for online training and testing on test split, " "'eval_only' for testing only with provided playbook") - + # Model configuration parser.add_argument("--api_provider", type=str, default="sambanova", choices=["sambanova", "together", "openai"], help="API provider") - parser.add_argument("--generator_model", type=str, + parser.add_argument("--generator_model", type=str, default="DeepSeek-V3.1", help="Model for generator") parser.add_argument("--reflector_model", type=str, @@ -40,7 +40,7 @@ def get_base_parser(description='ACE System'): parser.add_argument("--curator_model", type=str, default="DeepSeek-V3.1", help="Model for curator") - + # Training configuration parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") @@ -54,7 +54,7 @@ def get_base_parser(description='ACE System'): help="Update playbook every N samples for evaluation in online mode") parser.add_argument("--save_steps", type=int, default=50, help="Save intermediate playbooks every N steps") - + # System 
configuration parser.add_argument("--max_tokens", type=int, default=4096, help="Max tokens for LLM responses") @@ -62,67 +62,57 @@ def get_base_parser(description='ACE System'): help="Total token budget for playbook") parser.add_argument("--test_workers", type=int, default=20, help="Number of parallel workers for testing") - + # Prompt configuration parser.add_argument("--json_mode", action="store_true", help="Enable JSON mode for LLM calls") parser.add_argument("--no_ground_truth", action="store_true", help="Don't use ground truth in reflection") - + # Bulletpoint analyzer configuration parser.add_argument("--use_bulletpoint_analyzer", action="store_true", help="Enable bulletpoint analyzer for deduplication and merging") parser.add_argument("--bulletpoint_analyzer_threshold", type=float, default=0.90, help="Similarity threshold for bulletpoint analyzer (0-1, default: 0.90)") - - # SQL evaluation configuration - parser.add_argument("--pass_sql_eval_results", action="store_true", - help="Pass SQL execution results to reflector for better error analysis") - + # Output configuration parser.add_argument("--save_path", type=str, required=True, help="Directory to save results") - - return parser - - -def parse_args(): - """Parse command line arguments for finance tasks.""" - parser = get_base_parser(description='ACE System - Finance') + return parser.parse_args() def load_data(data_path: str): """ Load and process data from a JSONL file. - + Args: data_path: Path to the JSONL file - + Returns: List of dictionaries containing the data """ if not os.path.exists(data_path): raise FileNotFoundError(f"Data file not found: {data_path}") - + data = [] with open(data_path, 'r', encoding='utf-8') as f: for line in f: line = line.strip() if line: # Skip empty lines data.append(json.loads(line)) - + print(f"Loaded {len(data)} samples from {data_path}") return data def preprocess_data(task_name, config, mode): """ Load training and test data for the specified task. 
- + Args: task_name: Name of the task config: Configuration dictionary with data paths mode: Run mode ('offline', 'online', or 'eval_only') - + Returns: Tuple of (train_samples, val_samples, test_samples, data_processor) - For offline mode: all three are loaded @@ -130,39 +120,39 @@ def preprocess_data(task_name, config, mode): - For eval_only mode: only test_samples """ processor = DataProcessor(task_name=task_name) - + # For online and eval_only modes, only load test data if mode in ["online", "eval_only"]: train_samples = None val_samples = None - + if "test_data" in config: test_samples = load_data(config["test_data"]) test_samples = processor.process_task_data(test_samples) else: raise ValueError(f"{mode} mode requires test data in config.") - + if mode == "online": print(f"Online mode: Training and testing on {len(test_samples)} examples") else: print(f"Eval only mode: Testing on {len(test_samples)} examples") - + # For offline mode, load train, val, and optionally test data else: train_samples = load_data(config["train_data"]) val_samples = load_data(config["val_data"]) train_samples = processor.process_task_data(train_samples) val_samples = processor.process_task_data(val_samples) - + if "test_data" in config: test_samples = load_data(config["test_data"]) test_samples = processor.process_task_data(test_samples) else: test_samples = [] - + print(f"Offline mode: Training on {len(train_samples)} examples, " f"validating on {len(val_samples)}, testing on {len(test_samples)}") - + return train_samples, val_samples, test_samples, processor @@ -177,7 +167,7 @@ def load_initial_playbook(path): def main(): """Main execution function.""" args = parse_args() - + print(f"\n{'='*60}") print(f"ACE SYSTEM") print(f"{'='*60}") @@ -185,24 +175,24 @@ def main(): print(f"Mode: {args.mode.upper().replace('_', ' ')}") print(f"Generator Model: {args.generator_model}") print(f"{'='*60}\n") - + # Load data with open("./eval/finance/data/sample_config.json", 'r') as f: task_config 
= json.load(f) train_samples, val_samples, test_samples, data_processor = preprocess_data( - args.task_name, + args.task_name, task_config[args.task_name], args.mode ) - + # Load initial playbook (or use empty if None provided) initial_playbook = load_initial_playbook(args.initial_playbook_path) if initial_playbook: print(f"Loaded initial playbook from {args.initial_playbook_path}\n") else: print("Using empty playbook as initial playbook\n") - + # Create ACE system ace_system = ACE( api_provider=args.api_provider, @@ -214,7 +204,7 @@ def main(): use_bulletpoint_analyzer=args.use_bulletpoint_analyzer, bulletpoint_analyzer_threshold=args.bulletpoint_analyzer_threshold ) - + # Prepare configuration config = { 'num_epochs': args.num_epochs, @@ -233,10 +223,9 @@ def main(): 'initial_playbook_path': args.initial_playbook_path, 'use_bulletpoint_analyzer': args.use_bulletpoint_analyzer, 'bulletpoint_analyzer_threshold': args.bulletpoint_analyzer_threshold, - 'pass_sql_eval_results': args.pass_sql_eval_results, 'api_provider': args.api_provider } - + # Execute using the unified run method results = ace_system.run( mode=args.mode, @@ -246,7 +235,7 @@ def main(): data_processor=data_processor, config=config ) - + if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/eval/stream-bench/run.py b/eval/stream-bench/run.py index 295cabac..f126fa25 100644 --- a/eval/stream-bench/run.py +++ b/eval/stream-bench/run.py @@ -7,20 +7,117 @@ import json import re import time +import argparse import traceback from ace import ACE from .data_processor import DataProcessor from .plot import plot_online_performance, plot_training_progress, plot_offline_training_progress -from finance.run import get_base_parser, load_initial_playbook, load_data + + +def load_data(data_path: str): + """ + Load and process data from a JSONL file. 
+ + Args: + data_path: Path to the JSONL file + + Returns: + List of dictionaries containing the data + """ + if not os.path.exists(data_path): + raise FileNotFoundError(f"Data file not found: {data_path}") + + data = [] + with open(data_path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line: # Skip empty lines + data.append(json.loads(line)) + + print(f"Loaded {len(data)} samples from {data_path}") + return data + + +def load_initial_playbook(path): + """Load initial playbook if provided.""" + if path and os.path.exists(path): + with open(path, 'r') as f: + return f.read() + return None def parse_args(): """Parse command line arguments for stream-bench.""" - # Get base parser with all common arguments - parser = get_base_parser(description='ACE System - Stream Bench') - - # Add stream-bench specific arguments + parser = argparse.ArgumentParser(description='ACE System - Stream Bench') + + # Task configuration + parser.add_argument("--task_name", type=str, required=True, + help="Name of the task (e.g., 'finer', 'formula')") + parser.add_argument("--initial_playbook_path", type=str, default=None, + help="Path to initial playbook (optional)") + parser.add_argument("--mode", type=str, default="offline", + choices=["offline", "online", "eval_only"], + help="Run mode: 'offline' for offline training with validation, " + "'online' for online training and testing on test split, " + "'eval_only' for testing only with provided playbook") + + # Model configuration + parser.add_argument("--api_provider", type=str, default="sambanova", + choices=["sambanova", "together", "openai"], help="API provider") + parser.add_argument("--generator_model", type=str, + default="DeepSeek-V3.1", + help="Model for generator") + parser.add_argument("--reflector_model", type=str, + default="DeepSeek-V3.1", + help="Model for reflector") + parser.add_argument("--curator_model", type=str, + default="DeepSeek-V3.1", + help="Model for curator") + + # Training configuration 
+ parser.add_argument("--num_epochs", type=int, default=1, + help="Number of training epochs") + parser.add_argument("--max_num_rounds", type=int, default=3, + help="Max reflection rounds for incorrect answers") + parser.add_argument("--curator_frequency", type=int, default=1, + help="Run curator every N steps") + parser.add_argument("--eval_steps", type=int, default=100, + help="Evaluate every N steps") + parser.add_argument("--online_eval_frequency", type=int, default=15, + help="Update playbook every N samples for evaluation in online mode") + parser.add_argument("--save_steps", type=int, default=50, + help="Save intermediate playbooks every N steps") + + # System configuration + parser.add_argument("--max_tokens", type=int, default=4096, + help="Max tokens for LLM responses") + parser.add_argument("--playbook_token_budget", type=int, default=80000, + help="Total token budget for playbook") + parser.add_argument("--test_workers", type=int, default=20, + help="Number of parallel workers for testing") + + # Prompt configuration + parser.add_argument("--json_mode", action="store_true", + help="Enable JSON mode for LLM calls") + parser.add_argument("--no_ground_truth", action="store_true", + help="Don't use ground truth in reflection") + + # Bulletpoint analyzer configuration + parser.add_argument("--use_bulletpoint_analyzer", action="store_true", + help="Enable bulletpoint analyzer for deduplication and merging") + parser.add_argument("--bulletpoint_analyzer_threshold", type=float, default=0.90, + help="Similarity threshold for bulletpoint analyzer (0-1, default: 0.90)") + + # SQL evaluation configuration + parser.add_argument("--pass_sql_eval_results", action="store_true", + help="Pass SQL execution results to reflector for better error analysis") + + # Output configuration + parser.add_argument("--save_path", type=str, required=True, + help="Directory to save results") + + # Stream-bench specific arguments parser.add_argument("--data_config", type=str, 
required=True, help="Path to data configuration JSON file") parser.add_argument("--plot", action="store_true", From 7ffedabd24a092d01742665b4d5ee815eaa9511d Mon Sep 17 00:00:00 2001 From: Sherry Date: Wed, 18 Feb 2026 18:21:18 -0800 Subject: [PATCH 3/3] move extra args in _setup_paths to metadata --- ace/ace.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/ace/ace.py b/ace/ace.py index 9316b23e..bc89c1dc 100644 --- a/ace/ace.py +++ b/ace/ace.py @@ -136,7 +136,7 @@ def _extract_config_params(self, config: Dict[str, Any]) -> Dict[str, Any]: 'pass_sql_eval_results': config.get('pass_sql_eval_results', False) } - def _setup_paths(self, save_dir: str, task_name: str, mode: str, db_name: str = None, curriculum: str = None) -> Tuple[str, str]: + def _setup_paths(self, save_dir: str, task_name: str, mode: str, metadata: Dict[str, Any] = None) -> Tuple[str, str]: """ Setup logging paths and directories. @@ -144,12 +144,16 @@ def _setup_paths(self, save_dir: str, task_name: str, mode: str, db_name: str = save_dir: Base path for saving results task_name: task name mode: 'offline', 'online', or 'eval_only' - db_name: Optional database name to include in folder name - curriculum: Optional curriculum level to include in folder name + metadata: Optional dict with extra fields to include in the folder name + (e.g. 
{'db_name': 'my_db', 'curriculum': 'easy_to_hard'}) Returns: Tuple of (usage_log_path, playbook_dir) """ + metadata = metadata or {} + db_name = metadata.get('db_name') + curriculum = metadata.get('curriculum') + # Create timestamped run folder timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") @@ -225,16 +229,18 @@ def run( config_params = self._extract_config_params(config) task_name = config_params['task_name'] save_dir = config_params['save_dir'] - db_name = config.get('db_name', None) if config else None - curriculum = config.get('curriculum', None) if config else None + metadata = { + 'db_name': config.get('db_name') if config else None, + 'curriculum': config.get('curriculum') if config else None, + } # Setup paths based on mode if mode == 'eval_only': - save_path, log_dir = self._setup_paths(save_dir, task_name, mode, db_name, curriculum) + save_path, log_dir = self._setup_paths(save_dir, task_name, mode, metadata) usage_log_path = None playbook_dir = None else: - save_path, usage_log_path, playbook_dir, log_dir = self._setup_paths(save_dir, task_name, mode, db_name, curriculum) + save_path, usage_log_path, playbook_dir, log_dir = self._setup_paths(save_dir, task_name, mode, metadata) # Save configuration config_path = os.path.join(save_path, "run_config.json")