diff --git a/evaluation.py b/evaluation.py
index 58acec6..e3d124c 100644
--- a/evaluation.py
+++ b/evaluation.py
@@ -253,6 +253,18 @@ def main(args):
     parser.add_argument(
         "--redis_batch_size", type=int, default=256, help="Batch size for Redis vector operations (default: 256)"
     )
+    parser.add_argument(
+        "--cross_encoder_model",
+        type=str,
+        default=None,
+        help="Name of the cross-encoder model to use for reranking (default: None)",
+    )
+    parser.add_argument(
+        "--rerank_k",
+        type=int,
+        default=10,
+        help="Number of candidates to rerank (default: 10)",
+    )
 
     args = parser.parse_args()
     main(args)
diff --git a/run_benchmark.py b/run_benchmark.py
index e693231..577d61f 100644
--- a/run_benchmark.py
+++ b/run_benchmark.py
@@ -56,6 +56,14 @@ def main():
     parser.add_argument("--redis_index_name", type=str, default="idx_cache_match")
     parser.add_argument("--redis_doc_prefix", type=str, default="cache:")
     parser.add_argument("--redis_batch_size", type=int, default=256)
+    parser.add_argument(
+        "--cross_encoder_models",
+        type=str,
+        nargs="*",
+        default=None,
+        help="List of cross-encoder models (optional). If not provided, only bi-encoder is used.",
+    )
+    parser.add_argument("--rerank_k", type=int, default=10, help="Number of candidates to rerank.")
 
     args = parser.parse_args()
 
@@ -101,76 +109,96 @@ def main():
     for model_name in args.models:
         print(f"\n Model: {model_name}")
 
-        # Sanitize model name for directory structure
-        safe_model_name = model_name.replace("/", "_")
+        # Prepare list of cross-encoders to iterate over (None = no reranking)
+        ce_models = args.cross_encoder_models if args.cross_encoder_models else [None]
+
+        for ce_model_name in ce_models:
+            if ce_model_name:
+                print(f" Cross-Encoder: {ce_model_name}")
+            else:
+                print(f" Cross-Encoder: None (Bi-Encoder only)")
+
+            # Sanitize model name for directory structure
+            safe_model_name = model_name.replace("/", "_")
+
+            for run_i in range(1, args.n_runs + 1):
+                print(f" Run {run_i}/{args.n_runs}...")
+
+                # 1. Bootstrapping Logic
+                # Sample 80% of the universe
+                run_universe = full_df.sample(
+                    frac=args.sample_ratio, random_state=run_i
+                )  # Use run_i as seed for reproducibility per run
 
-        for run_i in range(1, args.n_runs + 1):
-            print(f" Run {run_i}/{args.n_runs}...")
+                # Split into Queries (n_samples) and Cache (remainder)
+                if len(run_universe) <= args.n_samples:
+                    print(
+                        f" Warning: Dataset size ({len(run_universe)}) <= n_samples ({args.n_samples}). Skipping."
+                    )
+                    continue
 
-            # 1. Bootstrapping Logic
-            # Sample 80% of the universe
-            run_universe = full_df.sample(
-                frac=args.sample_ratio, random_state=run_i
-            )  # Use run_i as seed for reproducibility per run
+                queries = run_universe.sample(n=args.n_samples, random_state=run_i + 1000)
+                cache = run_universe.drop(queries.index)
 
-            # Split into Queries (n_samples) and Cache (remainder)
-            if len(run_universe) <= args.n_samples:
-                print(
-                    f" Warning: Dataset size ({len(run_universe)}) <= n_samples ({args.n_samples}). Skipping."
+                # Shuffle cache
+                cache = cache.sample(frac=1, random_state=run_i + 2000).reset_index(drop=True)
+                queries = queries.reset_index(drop=True)
+
+                # 2. Construct Output Path
+                timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
+
+                # Include cross-encoder in output path if used
+                model_dir_name = safe_model_name
+                if ce_model_name:
+                    safe_cross_encoder_name = ce_model_name.replace("/", "_")
+                    model_dir_name = f"{safe_model_name}_rerank_{safe_cross_encoder_name}"
+
+                run_output_dir = os.path.join(
+                    args.output_dir, dataset_name, model_dir_name, f"run_{run_i}", timestamp
+                )
-                )
-                continue
-
-            queries = run_universe.sample(n=args.n_samples, random_state=run_i + 1000)
-            cache = run_universe.drop(queries.index)
-
-            # Shuffle cache
-            cache = cache.sample(frac=1, random_state=run_i + 2000).reset_index(drop=True)
-            queries = queries.reset_index(drop=True)
-
-            # 2. Construct Output Path
-            timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
-            run_output_dir = os.path.join(args.output_dir, dataset_name, safe_model_name, f"run_{run_i}", timestamp)
-            os.makedirs(run_output_dir, exist_ok=True)
-
-            # 3. Prepare Args for Evaluation
-            eval_args = BenchmarkArgs(
-                query_log_path=dataset_path,  # Not strictly used by logic below but good for reference
-                sentence_column=args.sentence_column,
-                output_dir=run_output_dir,
-                n_samples=args.n_samples,
-                model_name=model_name,
-                cache_path=None,
-                full=args.full,
-                llm_name=args.llm_name,
-                llm_model=llm_classifier,
-                sweep_steps=200,  # Default
-                use_redis=args.use_redis,
-                redis_url=args.redis_url,
-                redis_index_name=args.redis_index_name,
-                redis_doc_prefix=args.redis_doc_prefix,
-                redis_batch_size=args.redis_batch_size,
-                # device defaults to code logic
-            )
-
-            # 4. Run Evaluation
-            try:
-                print(" Matching...")
-                if args.use_redis:
-                    queries_matched = run_matching_redis(queries.copy(), cache.copy(), eval_args)
-                else:
-                    queries_matched = run_matching(queries.copy(), cache.copy(), eval_args)
-
-                print(" Evaluating...")
-                if args.full:
-                    run_full_evaluation(queries_matched, eval_args)
-                else:
-                    run_chr_analysis(queries_matched, eval_args)
-
-            except Exception as e:
-                print(f" Error in run {run_i}: {e}")
-                import traceback
-
-                traceback.print_exc()
+                os.makedirs(run_output_dir, exist_ok=True)
+
+                # 3. Prepare Args for Evaluation
+                eval_args = BenchmarkArgs(
+                    query_log_path=dataset_path,  # Not strictly used by logic below but good for reference
+                    sentence_column=args.sentence_column,
+                    output_dir=run_output_dir,
+                    n_samples=args.n_samples,
+                    model_name=model_name,
+                    cache_path=None,
+                    full=args.full,
+                    llm_name=args.llm_name,
+                    llm_model=llm_classifier,
+                    sweep_steps=200,  # Default
+                    use_redis=args.use_redis,
+                    redis_url=args.redis_url,
+                    redis_index_name=args.redis_index_name,
+                    redis_doc_prefix=args.redis_doc_prefix,
+                    redis_batch_size=args.redis_batch_size,
+                    cross_encoder_model=ce_model_name,
+                    rerank_k=args.rerank_k,
+                    # device defaults to code logic
+                )
+
+                # 4. Run Evaluation
+                try:
+                    print(" Matching...")
+                    if args.use_redis:
+                        queries_matched = run_matching_redis(queries.copy(), cache.copy(), eval_args)
+                    else:
+                        queries_matched = run_matching(queries.copy(), cache.copy(), eval_args)
+
+                    print(" Evaluating...")
+                    if args.full:
+                        run_full_evaluation(queries_matched, eval_args)
+                    else:
+                        run_chr_analysis(queries_matched, eval_args)
+
+                except Exception as e:
+                    print(f" Error in run {run_i}: {e}")
+                    import traceback
+
+                    traceback.print_exc()
 
     print("\nBenchmark completed.")
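With the cross-encoder loop added above, each retriever/reranker pair writes into its own results directory. An illustrative layout (model and dataset names taken from run_benchmark.sh below, timestamp in the "%Y_%m_%d_%H_%M_%S" format used by the code; the exact dataset component depends on how dataset_name is derived from --dataset_names):

    cross_encoder_results/vizio_unique_medium.csv/redis_langcache-embed-v1_rerank_Alibaba-NLP_gte-reranker-modernbert-base/run_1/2025_01_01_12_00_00/

Runs without a cross-encoder keep the previous <output_dir>/<dataset>/<safe_model_name>/run_<i>/<timestamp> layout.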
diff --git a/run_benchmark.sh b/run_benchmark.sh
new file mode 100644
index 0000000..085fd80
--- /dev/null
+++ b/run_benchmark.sh
@@ -0,0 +1,15 @@
+# Example usage:
+uv run run_benchmark.py \
+    --dataset_dir "dataset" \
+    --output_dir "cross_encoder_results" \
+    --models "Alibaba-NLP/gte-modernbert-base" "redis/langcache-embed-v1" "redis/langcache-embed-v3-small" \
+    --dataset_names "vizio_unique_medium.csv" "axis_bank_unique_sentences.csv"\
+    --sentence_column "sentence" \
+    --n_runs 10 \
+    --n_samples 10000 \
+    --sample_ratio 0.8 \
+    --llm_name "tensoropera/Fox-1-1.6B" \
+    --full \
+    --use_redis \
+    # --cross_encoder_models "redis/langcache-reranker-v1-softmnrl-triplet" "Alibaba-NLP/gte-reranker-modernbert-base" \
+    # --rerank_k 5
\ No newline at end of file
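To benchmark the rerankers, the two commented flags in the script above are uncommented. One possible direct invocation (flags, values, and model names exactly as they appear in the script, trimmed to a single retriever and dataset):

    uv run run_benchmark.py \
        --dataset_dir "dataset" \
        --output_dir "cross_encoder_results" \
        --models "redis/langcache-embed-v1" \
        --dataset_names "vizio_unique_medium.csv" \
        --sentence_column "sentence" \
        --n_runs 10 \
        --n_samples 10000 \
        --sample_ratio 0.8 \
        --llm_name "tensoropera/Fox-1-1.6B" \
        --full \
        --use_redis \
        --cross_encoder_models "redis/langcache-reranker-v1-softmnrl-triplet" "Alibaba-NLP/gte-reranker-modernbert-base" \
        --rerank_k 5

Note that when --cross_encoder_models is given, run_benchmark.py iterates only over those rerankers; a bi-encoder-only baseline requires a separate run without the flag.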
diff --git a/scripts/plot_multiple_precision_vs_cache_hit_ratio.py b/scripts/plot_multiple_precision_vs_cache_hit_ratio.py
index 0a693db..ecfc1e8 100644
--- a/scripts/plot_multiple_precision_vs_cache_hit_ratio.py
+++ b/scripts/plot_multiple_precision_vs_cache_hit_ratio.py
@@ -1,12 +1,101 @@
 import argparse
 import os
+from collections import defaultdict
 
 import matplotlib.pyplot as plt
+import matplotlib.colors as mcolors
 import numpy as np
 import pandas as pd
 
 from utils import crawl_results
 
 
+def extract_retriever_name(model_name):
+    """Extract retriever name from model name (part before '_rerank_')."""
+    if '_rerank_' in model_name:
+        return model_name.split('_rerank_')[0]
+    return model_name
+
+
+def extract_reranker_name(model_name):
+    """Extract reranker name from model name (part after '_rerank_'), or None if no reranker."""
+    if '_rerank_' in model_name:
+        return model_name.split('_rerank_')[1]
+    return None
+
+
+def darken_color(color, factor):
+    """
+    Darken a color by a given factor (0 = original, 1 = black).
+    factor should be between 0 and 1.
+    """
+    rgb = mcolors.to_rgb(color)
+    darkened = tuple(c * (1 - factor) for c in rgb)
+    return darkened
+
+
+def get_retriever_color_map(model_names):
+    """
+    Create a color mapping for retrievers and their cross-encoder variants.
+
+    Returns: dict mapping model_name -> color
+    """
+    # Group models by retriever
+    retriever_groups = defaultdict(list)
+    for model_name in model_names:
+        retriever = extract_retriever_name(model_name)
+        retriever_groups[retriever].append(model_name)
+
+    # Sort retrievers for consistent ordering
+    sorted_retrievers = sorted(retriever_groups.keys())
+
+    # Use a colorful palette with good distinction
+    base_colors = [
+        '#e6194B',  # Red
+        '#3cb44b',  # Green
+        '#4363d8',  # Blue
+        '#f58231',  # Orange
+        '#911eb4',  # Purple
+        '#42d4f4',  # Cyan
+        '#f032e6',  # Magenta
+        '#bfef45',  # Lime
+        '#fabed4',  # Pink
+        '#469990',  # Teal
+        '#dcbeff',  # Lavender
+        '#9A6324',  # Brown
+        '#fffac8',  # Beige
+        '#800000',  # Maroon
+        '#aaffc3',  # Mint
+    ]
+
+    color_map = {}
+
+    for i, retriever in enumerate(sorted_retrievers):
+        base_color = base_colors[i % len(base_colors)]
+        models_in_group = retriever_groups[retriever]
+
+        # Sort models within group: base retriever first, then rerankers alphabetically
+        def sort_key(m):
+            reranker = extract_reranker_name(m)
+            if reranker is None:
+                return (0, '')  # Base retriever comes first
+            return (1, reranker)
+
+        models_in_group.sort(key=sort_key)
+
+        # Assign colors with increasing darkness
+        n_models = len(models_in_group)
+        for j, model_name in enumerate(models_in_group):
+            if n_models == 1:
+                # Only base retriever, use base color
+                color_map[model_name] = base_color
+            else:
+                # Darken progressively: base is brightest, last reranker is darkest
+                # factor ranges from 0 (base) to ~0.6 (darkest reranker)
+                darken_factor = j * 0.5 / (n_models - 1) if n_models > 1 else 0
+                color_map[model_name] = darken_color(base_color, darken_factor)
+
+    return color_map, sorted_retrievers
+
+
 def main():
     parser = argparse.ArgumentParser(
         "Usage: python plot_multiple_precision_vs_cache_hit_ratio.py --base_dir "
     )
@@ -26,8 +115,13 @@ def main():
         dataset_full_path = os.path.join(base_dir, dataset_name)
         if not os.path.exists(dataset_full_path):
            continue
-        fig, ax = plt.subplots(figsize=(10, 7))
-        colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
+
+        # CHANGED: Create two subplots: one for curves, one for the AUC bar chart
+        fig, (ax_main, ax_bar) = plt.subplots(1, 2, figsize=(30, 12), gridspec_kw={'width_ratios': [2, 1]})
+
+        # Build color map based on retriever grouping
+        all_model_names = list(model_data.keys())
+        color_map, sorted_retrievers = get_retriever_color_map(all_model_names)
 
         # Get base rate from first valid run to compute theoretical curves
         base_rate = None
@@ -48,6 +142,9 @@ def main():
                 if base_rate is not None:
                     break
 
+        # Theoretical AUCs storage
+        theory_aucs = {}
+
         # Plot theoretical curves
         if base_rate is not None:
             # Theoretical Perfect (Uniform Negatives)
@@ -56,16 +153,30 @@
             x_uniform = np.concatenate(([0], x_uniform))
             y_uniform = np.concatenate(([1], y_uniform))
             auc_uniform = base_rate * (1 - np.log(base_rate))
-            ax.plot(x_uniform, y_uniform, '--', color='black', label=f"Perfect (Uniform Negs), AUC: {auc_uniform:.3f}")
+            ax_main.plot(x_uniform, y_uniform, '--', color='black', label=f"Perfect (Uniform Negs), AUC: {auc_uniform:.3f}")
+            theory_aucs['Uniform'] = auc_uniform
 
             # Theoretical Perfect (Zero Negatives)
             x_zeros = [0, base_rate, 1.0]
             y_zeros = [1.0, 1.0, base_rate]
             auc_zeros = base_rate + 0.5 * (1 - base_rate**2)
-            ax.plot(x_zeros, y_zeros, ':', color='black', label=f"Perfect (Zero Negs), AUC: {auc_zeros:.3f}")
+            ax_main.plot(x_zeros, y_zeros, ':', color='black', label=f"Perfect (Zero Negs), AUC: {auc_zeros:.3f}")
+            theory_aucs['ZeroNegs'] = auc_zeros
+
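+            # Where these reference AUCs come from, with r = base_rate (and assuming the uniform-negatives
+            # curve is precision = r / CHR once CHR exceeds r, which matches the AUC printed above):
+            # integrating that curve gives AUC = r + integral_r^1 (r / x) dx = r * (1 - ln r), while the
+            # zero-negatives polyline ([0, r, 1] vs [1, 1, r]) integrates to AUC = r + (1 - r**2) / 2.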
+        # Sort models by retriever group, then by reranker
+        def model_sort_key(m):
+            retriever = extract_retriever_name(m)
+            reranker = extract_reranker_name(m)
+            if reranker is None:
+                return (retriever, 0, '')
+            return (retriever, 1, reranker)
+
+        sorted_models = sorted(model_data.keys(), key=model_sort_key)
+
+        # Store data for the bar plot
+        auc_records = []
 
-        sorted_models = sorted(model_data.keys())
-        for i, model_name in enumerate(sorted_models):
+        for model_name in sorted_models:
             run_paths = model_data[model_name]
             precisions_interp = []
             aucs_pchr = []
@@ -81,7 +192,7 @@
                 try:
                     df = pd.read_csv(csv_path)
 
-                    # Remove the last row because it's it's always precision = 1.0
+                    # Remove the last row because it's always precision = 1.0
                     df = df.iloc[:-1]
 
                     x_chr = df["cache_hit_ratio"].values
@@ -95,7 +206,13 @@
                     p_interp = np.interp(common_chr, x_chr, y_prec)
                     precisions_interp.append(p_interp)
 
-                    aucs_pchr.append(np.trapezoid(p_interp, common_chr))
+                    # Use numpy.trapezoid (NumPy 2.0) or numpy.trapz (older)
+                    try:
+                        auc_val = np.trapezoid(p_interp, common_chr)
+                    except AttributeError:
+                        auc_val = np.trapz(p_interp, common_chr)
+
+                    aucs_pchr.append(auc_val)
 
                     valid_runs += 1
 
                 except Exception as e:
@@ -109,17 +226,57 @@
             mean_auc_pchr = np.mean(aucs_pchr)
             std_auc_pchr = np.std(aucs_pchr) if valid_runs > 1 else 0.0
 
-            color = colors[i % len(colors)]
+            # Get color from the retriever-based color map
+            color = color_map[model_name]
 
             label_chr = f"{model_name}, AUC: {mean_auc_pchr:.3f} ± {std_auc_pchr:.3f}"
-            ax.plot(common_chr, mean_p_chr, label=label_chr, color=color)
+            ax_main.plot(common_chr, mean_p_chr, label=label_chr, color=color)
             if valid_runs > 1:
-                ax.fill_between(common_chr, mean_p_chr - std_p_chr, mean_p_chr + std_p_chr, color=color, alpha=0.2)
-        ax.set_xlabel("Cache Hit Ratio")
-        ax.set_ylabel("Precision")
-        ax.set_title("Precision vs Cache Hit Ratio")
-        ax.grid(True)
-        ax.legend()
+                ax_main.fill_between(common_chr, mean_p_chr - std_p_chr, mean_p_chr + std_p_chr, color=color, alpha=0.2)
+
+            # Save data for bar chart
+            auc_records.append({
+                'name': model_name,
+                'mean': mean_auc_pchr,
+                'std': std_auc_pchr,
+                'color': color,
+                'retriever': extract_retriever_name(model_name)
+            })
+
+        # --- Configure Main Curve Plot ---
+        ax_main.set_xlabel("Cache Hit Ratio")
+        ax_main.set_ylabel("Precision")
+        ax_main.set_title("Precision vs Cache Hit Ratio")
+        ax_main.grid(True)
+        ax_main.legend()
+
+        # --- Configure Bar Chart ---
+        if auc_records:
+            # Sort by mean AUC (ascending so best is at top)
+            auc_records.sort(key=lambda x: x['mean'], reverse=False)
+
+            names = [r['name'] for r in auc_records]
+            means = [r['mean'] for r in auc_records]
+            stds = [r['std'] for r in auc_records]
+            bar_colors = [r['color'] for r in auc_records]
+            y_pos = np.arange(len(names))
+
+            ax_bar.barh(y_pos, means, xerr=stds, color=bar_colors, align='center', capsize=5, alpha=0.8)
+            ax_bar.set_yticks(y_pos)
+            ax_bar.set_yticklabels(names)
+            ax_bar.set_xlabel("AUC")
+            ax_bar.set_title("AUC Comparison")
+            ax_bar.grid(axis='x', linestyle='--', alpha=0.7)
+
+            # Add theoretical lines to bar chart
+            if 'Uniform' in theory_aucs:
+                ax_bar.axvline(theory_aucs['Uniform'], color='black', linestyle='--', alpha=0.7)
+            if 'ZeroNegs' in theory_aucs:
+                ax_bar.axvline(theory_aucs['ZeroNegs'], color='black', linestyle=':', alpha=0.7)
+
+            # Set x-limits to focus on relevant area if needed, or 0-1
+            # ax_bar.set_xlim(0, 1.05)
+
+        fig.suptitle(f"Performance on {dataset_name.split('_')[0]}")
         plt.tight_layout()
         output_path = os.path.join(dataset_full_path, "precision_vs_cache_hit_ratio.png")
@@ -129,4 +286,4 @@
 
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file
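The data_processing.py changes below L2-normalize every embedding before it is loaded into or queried against Redis, so the cosine distance returned by the index is computed on unit vectors. A minimal, self-contained sketch of that normalization step (toy data, not the project's API):

    import numpy as np

    vecs = np.random.rand(4, 768).astype(np.float32)   # stand-in for model embeddings
    norms = np.linalg.norm(vecs, axis=1, keepdims=True)
    norms[norms == 0] = 1e-9                           # guard against zero-length vectors
    unit_vecs = (vecs / norms).astype(np.float32)
    # on unit vectors, cosine similarity is just the dot product
    cos_sim = unit_vecs @ unit_vecs.T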
diff --git a/src/customer_analysis/data_processing.py b/src/customer_analysis/data_processing.py
index 450b01d..a1764ab 100644
--- a/src/customer_analysis/data_processing.py
+++ b/src/customer_analysis/data_processing.py
@@ -1,4 +1,5 @@
 import pandas as pd
+import numpy as np
 import torch
 
 from src.customer_analysis.embedding_interface import NeuralEmbedding
@@ -23,6 +24,21 @@ def run_matching_redis(queries: pd.DataFrame, cache: pd.DataFrame, args):
     """
     text_col = args.sentence_column
 
+    # Determine k for retrieval
+    k = 1
+    cross_encoder = None
+    if getattr(args, "cross_encoder_model", None):
+        try:
+            from sentence_transformers import CrossEncoder
+            cross_encoder = CrossEncoder(
+                args.cross_encoder_model,
+                device=getattr(args, "device", None) or ("cuda" if torch.cuda.is_available() else "cpu")
+            )
+            k = getattr(args, "rerank_k", 10)
+            print(f"Using Cross-Encoder reranking: {args.cross_encoder_model} (top-{k})")
+        except ImportError:
+            print("Warning: sentence_transformers not found or CrossEncoder import failed. Skipping reranking.")
+
     rindex = RedisVectorIndex(
         col_query=text_col,
         index_name=getattr(args, "redis_index_name", "idx_cache_match"),
@@ -38,27 +54,83 @@ def run_matching_redis(queries: pd.DataFrame, cache: pd.DataFrame, args):
     # 2) embed + load cache
     cache_texts = cache[text_col].tolist()
     cache_vecs = rindex._embed_batch(cache_texts)  # (M, D)
+
+    # Fix: Ensure vectors are normalized and float32
+    norms = np.linalg.norm(cache_vecs, axis=1, keepdims=True)
+    norms[norms == 0] = 1e-9
+    cache_vecs = (cache_vecs / norms).astype(np.float32)
+
     rindex.load_texts_and_vecs(cache_texts, cache_vecs)
 
-    # 3) embed queries and search top-1
+    # 3) embed queries and search top-k
     query_texts = queries[text_col].tolist()
     query_vecs = rindex._embed_batch(query_texts)
+
+    # Normalize queries too and ensure float32
+    norms = np.linalg.norm(query_vecs, axis=1, keepdims=True)
+    norms[norms == 0] = 1e-9
+    query_vecs = (query_vecs / norms).astype(np.float32)
 
     best_scores: list[float] = []
     matches: list[str] = []
 
-    for qv in query_vecs:
-        resp = rindex.query_vector_topk(qv, k=1)
-        if not resp:
-            best_scores.append(0.0)
-            matches.append("")
-            continue
-
-        hit = resp[0]
-        cosine_sim = 1.0 - float(hit["vector_distance"])  # convert to similarity
-
-        best_scores.append(cosine_sim)
-        matches.append(hit[text_col])
+    if cross_encoder and k > 1:
+        all_pairs = []
+        query_candidate_counts = []
+        candidates_list = []  # store candidates for each query to retrieve text later
+
+        print("Retrieving candidates from Redis...")
+        for i, qv in enumerate(query_vecs):
+            resp = rindex.query_vector_topk(qv, k=k)
+            if not resp:
+                query_candidate_counts.append(0)
+                candidates_list.append([])
+                continue
+
+            q_text = query_texts[i]
+            cands = [r[text_col] for r in resp]
+            candidates_list.append(cands)
+
+            for c_text in cands:
+                all_pairs.append([q_text, c_text])
+
+            query_candidate_counts.append(len(cands))
+
+        if all_pairs:
+            print(f"Reranking {len(all_pairs)} pairs with Cross-Encoder...")
+            all_scores = cross_encoder.predict(all_pairs, batch_size=32, show_progress_bar=True)
+
+            # Reassemble
+            score_idx = 0
+            for i, count in enumerate(query_candidate_counts):
+                if count == 0:
+                    best_scores.append(0.0)
+                    matches.append("")
+                    continue
+
+                # Get scores for this query
+                q_scores = all_scores[score_idx : score_idx + count]
+                score_idx += count
+
+                best_idx = np.argmax(q_scores)
+                best_scores.append(float(q_scores[best_idx]))
+                matches.append(candidates_list[i][best_idx])
+        else:
+            best_scores = [0.0] * len(queries)
+            matches = [""] * len(queries)
+
+    else:
+        for qv in query_vecs:
+            resp = rindex.query_vector_topk(qv, k=1)
+            if not resp:
+                best_scores.append(0.0)
+                matches.append("")
+                continue
+
+            hit = resp[0]
+            cosine_sim = 1.0 - float(hit["vector_distance"])  # convert to similarity
+            best_scores.append(cosine_sim)
+            matches.append(hit[text_col])
 
     # 4) attach outputs
     out = queries.copy()
@@ -74,17 +146,62 @@
 
 def run_matching(queries, cache, args):
     embedding_model = NeuralEmbedding(args.model_name, device="cuda" if torch.cuda.is_available() else "cpu")
+
+    # Determine k for retrieval
+    k = 1
+    cross_encoder = None
+    if getattr(args, "cross_encoder_model", None):
+        try:
+            from sentence_transformers import CrossEncoder
+            cross_encoder = CrossEncoder(
+                args.cross_encoder_model,
+                device="cuda" if torch.cuda.is_available() else "cpu"
+            )
+            k = getattr(args, "rerank_k", 10)
+            print(f"Using Cross-Encoder reranking: {args.cross_encoder_model} (top-{k})")
+        except ImportError:
+            print("Warning: sentence_transformers not found or CrossEncoder import failed. Skipping reranking.")
 
     queries["best_scores"] = 0
+    query_list = queries[args.sentence_column].to_list()
+    cache_list = cache[args.sentence_column].to_list()
+
     best_indices, best_scores, decision_methods = embedding_model.calculate_best_matches_with_cache_large_dataset(
-        queries=queries[args.sentence_column].to_list(),
-        cache=cache[args.sentence_column].to_list(),
+        queries=query_list,
+        cache=cache_list,
         batch_size=512,
+        k=k
     )
 
-    queries["best_scores"] = best_scores
-    queries["matches"] = cache.iloc[best_indices][args.sentence_column].to_list()
+    if cross_encoder and k > 1:
+        print("Reranking results with Cross-Encoder...")
+        # best_indices is (N, k)
+        all_pairs = []
+        N = len(query_list)
+
+        for i in range(N):
+            q_text = query_list[i]
+            for idx in best_indices[i]:
+                all_pairs.append([q_text, cache_list[idx]])
+
+        if all_pairs:
+            all_scores = cross_encoder.predict(all_pairs, batch_size=128, show_progress_bar=True)
+            all_scores = all_scores.reshape(N, k)
+
+            best_idx_in_k = np.argmax(all_scores, axis=1)  # (N,)
+
+            final_scores = all_scores[np.arange(N), best_idx_in_k]
+            final_cache_indices = best_indices[np.arange(N), best_idx_in_k]
+
+            queries["best_scores"] = final_scores
+            queries["matches"] = [cache_list[i] for i in final_cache_indices]
+        else:
+            queries["best_scores"] = 0.0
+            queries["matches"] = ""
+    else:
+        queries["best_scores"] = best_scores
+        queries["matches"] = cache.iloc[best_indices][args.sentence_column].to_list()
 
     del embedding_model
     torch.cuda.empty_cache()
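Both matching paths above follow the same retrieve-then-rerank pattern: the bi-encoder supplies top-k candidates, every (query, candidate) pair is scored with a sentence-transformers CrossEncoder, and the highest-scoring candidate wins. A condensed sketch of just that pattern (the model name and texts are placeholders, not values used by this repo):

    from sentence_transformers import CrossEncoder

    reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")  # placeholder model
    query = "reset my password"
    candidates = ["how do i reset my password", "update billing address", "change password"]

    pairs = [[query, c] for c in candidates]
    scores = reranker.predict(pairs)      # one relevance score per pair
    best = int(scores.argmax())
    print(candidates[best], float(scores[best]))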
diff --git a/src/customer_analysis/embedding_interface.py b/src/customer_analysis/embedding_interface.py
index 68e9d9d..052d698 100644
--- a/src/customer_analysis/embedding_interface.py
+++ b/src/customer_analysis/embedding_interface.py
@@ -114,11 +114,13 @@ def calculate_best_matches_from_embeddings(
     # ------------------------------
     def _infer_embedding_dim(self, sentences: list[str]) -> int:
         """Return the embedding dimension for the current model."""
+        # Use probing to get the actual dimension as some models report incorrect config dimensions
         try:
-            return int(self.model.get_sentence_embedding_dimension())
+            probe = self.model.encode([sentences[0] if sentences else "test"])
+            return int(probe.shape[1])
         except Exception:
-            probe = self.model.encode([sentences[0]])
-            return int(probe.shape[1])
+            # Fallback to config if probing fails
+            return int(self.model.get_sentence_embedding_dimension())
 
     def _prepare_memmap_dir(self, memmap_dir: Optional[str]) -> tuple[bool, str, str]:
         """Ensure a directory exists for memmap files and return path components.
@@ -236,6 +238,7 @@ def _compute_blockwise_best_matches_two_sets(
         mask_self_similarity: bool = False,
         sentence_offset: int = 0,
         early_stop: int = 0,
+        k: int = 1,
     ) -> tuple[np.ndarray, np.ndarray]:
         """Blockwise nearest-neighbour where rows and columns come from two sets.
 
@@ -244,8 +247,13 @@ def _compute_blockwise_best_matches_two_sets(
         diagonal entries for that alignment will be masked to -inf.
         """
         n_rows = early_stop if early_stop > 0 else num_rows
-        best_scores = np.full(n_rows, -np.inf, dtype=np.float32)
-        best_indices = np.zeros(n_rows, dtype=np.int32)
+
+        if k == 1:
+            best_scores = np.full(n_rows, -np.inf, dtype=np.float32)
+            best_indices = np.zeros(n_rows, dtype=np.int32)
+        else:
+            best_scores = np.full((n_rows, k), -np.inf, dtype=np.float32)
+            best_indices = np.zeros((n_rows, k), dtype=np.int32)
 
         rows_mm = np.memmap(row_emb_path, mode="r", dtype=dtype, shape=(n_rows, embedding_dim))
         cols_mm = np.memmap(col_emb_path, mode="r", dtype=dtype, shape=(num_cols, embedding_dim))
@@ -254,8 +262,12 @@ def _compute_blockwise_best_matches_two_sets(
             row_end = min(row_start + row_block, n_rows)
             row_emb = np.asarray(rows_mm[row_start:row_end])
 
-            chunk_best_scores = np.full(row_end - row_start, -np.inf, dtype=np.float32)
-            chunk_best_indices = np.zeros(row_end - row_start, dtype=np.int32)
+            if k == 1:
+                chunk_best_scores = np.full(row_end - row_start, -np.inf, dtype=np.float32)
+                chunk_best_indices = np.zeros(row_end - row_start, dtype=np.int32)
+            else:
+                chunk_best_scores = np.full((row_end - row_start, k), -np.inf, dtype=np.float32)
+                chunk_best_indices = np.zeros((row_end - row_start, k), dtype=np.int32)
 
             for col_start in range(0, num_cols, col_block):
                 col_end = min(col_start + col_block, num_cols)
@@ -277,13 +289,49 @@ def _compute_blockwise_best_matches_two_sets(
                     col_local_indices = np.arange(overlap_start - col_start, overlap_end - col_start)
                     sim[row_local_indices, col_local_indices] = -np.inf
 
-                block_idx = np.argmax(sim, axis=1)
-                block_val = sim[np.arange(sim.shape[0]), block_idx].astype(np.float32, copy=False)
+                if k == 1:
+                    block_idx = np.argmax(sim, axis=1)
+                    block_val = sim[np.arange(sim.shape[0]), block_idx].astype(np.float32, copy=False)
+
+                    for i in range(len(block_val)):
+                        if block_val[i] > chunk_best_scores[i]:
+                            chunk_best_scores[i] = block_val[i]
+                            chunk_best_indices[i] = col_start + block_idx[i]
+                else:
+                    # Top-k logic
+                    # If columns in this block < k, take all valid
+                    curr_block_size = col_end - col_start
+                    if curr_block_size <= k:
+                        top_k_in_block_idx = np.argsort(-sim, axis=1)  # Sort all
+                        top_k_in_block_val = np.take_along_axis(sim, top_k_in_block_idx, axis=1)
+                        # Might have fewer than k if block is small
+                    else:
+                        # Use argpartition for top k
+                        # We want largest k
+                        part_idx = np.argpartition(-sim, k, axis=1)[:, :k]
+                        top_k_in_block_val = np.take_along_axis(sim, part_idx, axis=1)
+
+                        # Sort them to have ordered top-k (optional but good for merging)
+                        sorted_sub_idx = np.argsort(-top_k_in_block_val, axis=1)
+                        top_k_in_block_val = np.take_along_axis(top_k_in_block_val, sorted_sub_idx, axis=1)
+                        top_k_in_block_idx = np.take_along_axis(part_idx, sorted_sub_idx, axis=1)
+
+                    # Merge with accumulated bests
+                    # chunk_best_scores: (batch, k)
+                    # top_k_in_block_val: (batch, min(block, k))
+
+                    # Adjust indices to global column indices
+                    top_k_in_block_idx_global = top_k_in_block_idx + col_start
+
+                    combined_vals = np.concatenate([chunk_best_scores, top_k_in_block_val], axis=1)
+                    combined_idxs = np.concatenate([chunk_best_indices, top_k_in_block_idx_global], axis=1)
+
+                    # Find top k in combined
+                    best_combined_args = np.argsort(-combined_vals, axis=1)[:, :k]
+
+                    chunk_best_scores = np.take_along_axis(combined_vals, best_combined_args, axis=1)
+                    chunk_best_indices = np.take_along_axis(combined_idxs, best_combined_args, axis=1)
 
-                for i in range(len(block_val)):
-                    if block_val[i] > chunk_best_scores[i]:
-                        chunk_best_scores[i] = block_val[i]
-                        chunk_best_indices[i] = col_start + block_idx[i]
 
             best_scores[row_start:row_end] = chunk_best_scores
             best_indices[row_start:row_end] = chunk_best_indices
@@ -356,6 +404,7 @@ def calculate_best_matches_with_cache_large_dataset(
         dtype: np.dtype = np.float32,
         sentence_offset: int = 0,
         early_stop: int = 0,
+        k: int = 1,
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """Large-dataset variant: find best cache match for each sentence using memmaps.
 
@@ -415,6 +464,7 @@ def calculate_best_matches_with_cache_large_dataset(
             mask_self_similarity=(queries is cache or queries == cache),
             sentence_offset=sentence_offset,
            early_stop=early_stop,
+            k=k,
         )
 
         decision_methods = np.full(num_sentences, "neural", dtype=object)
@@ -432,7 +482,7 @@ def calculate_best_matches_with_cache_large_dataset(
         return best_indices, best_scores, decision_methods
 
     def calculate_best_matches_with_cache(
-        self, sentences: list[str], cache: list[str], batch_size: int = 1024
+        self, sentences: list[str], cache: list[str], batch_size: int = 1024, k: int = 1
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """
         Calculate the best similarity match for each sentence against all other
@@ -448,6 +498,7 @@ def calculate_best_matches_with_cache(
             cache=cache,
             batch_size=batch_size,
            sentence_offset=0,
+            k=k,
         )
 
         best_indices, best_scores, decision_methods = out
@@ -463,6 +514,7 @@ def calculate_best_matches_from_embeddings_with_cache(
         batch_size: int = 1024,
         sentence_offset: int = 0,
         mask_self_similarity: bool = False,
+        k: int = 1,
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """
         Calculate the best similarity match for each sentence against all other
@@ -479,8 +531,13 @@ def calculate_best_matches_from_embeddings_with_cache(
         norms[norms == 0] = 1e-9
         cache_embeddings_matrix /= norms
 
-        best_indices = np.zeros(len(sentences), dtype=np.int32)
-        best_scores = np.zeros(len(sentences), dtype=np.float32)
+        if k == 1:
+            best_indices = np.zeros(len(sentences), dtype=np.int32)
+            best_scores = np.zeros(len(sentences), dtype=np.float32)
+        else:
+            best_indices = np.zeros((len(sentences), k), dtype=np.int32)
+            best_scores = np.zeros((len(sentences), k), dtype=np.float32)
+
         decision_methods = np.full(len(sentences), "neural", dtype=object)
 
         for start in tqdm(
@@ -503,12 +560,25 @@ def calculate_best_matches_from_embeddings_with_cache(
             if mask_self_similarity:
                 batch_sims[row_indices, col_indices] = -np.inf
 
-            best_indices_batch = np.argmax(
-                batch_sims, axis=1
-            )  # we want to find the best match for each sentence in the batch (batch_size)
-            best_scores_batch = batch_sims[
-                row_indices, best_indices_batch
-            ]  # we want to find the best score for each sentence in the batch (batch_size)
+            if k == 1:
+                best_indices_batch = np.argmax(
+                    batch_sims, axis=1
+                )  # we want to find the best match for each sentence in the batch (batch_size)
+                best_scores_batch = batch_sims[
+                    row_indices, best_indices_batch
+                ]  # we want to find the best score for each sentence in the batch (batch_size)
+            else:
+                # Top k
+                if batch_sims.shape[1] <= k:
+                    # Less candidates than k
+                    best_indices_batch = np.argsort(-batch_sims, axis=1)
+                    best_scores_batch = np.take_along_axis(batch_sims, best_indices_batch, axis=1)
+                else:
+                    part_idx = np.argpartition(-batch_sims, k, axis=1)[:, :k]
+                    top_k_val = np.take_along_axis(batch_sims, part_idx, axis=1)
+                    sorted_sub_idx = np.argsort(-top_k_val, axis=1)
+                    best_scores_batch = np.take_along_axis(top_k_val, sorted_sub_idx, axis=1)
+                    best_indices_batch = np.take_along_axis(part_idx, sorted_sub_idx, axis=1)
 
             best_indices[start:end] = best_indices_batch
             best_scores[start:end] = best_scores_batch
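The blockwise top-k code above keeps a running (rows, k) buffer of best scores and indices, and for every column block concatenates the block's own top-k with that buffer and re-selects the k largest. The same merge step in isolation, on toy numbers:

    import numpy as np

    k = 3
    running_vals = np.array([[0.9, 0.7, 0.5]])   # current best scores for one row
    running_idxs = np.array([[10, 42, 7]])       # their global column indices
    block_vals = np.array([[0.8, 0.6]])          # this block's top scores
    block_idxs = np.array([[105, 117]])          # already offset by col_start

    vals = np.concatenate([running_vals, block_vals], axis=1)
    idxs = np.concatenate([running_idxs, block_idxs], axis=1)
    order = np.argsort(-vals, axis=1)[:, :k]
    new_vals = np.take_along_axis(vals, order, axis=1)   # [[0.9, 0.8, 0.7]]
    new_idxs = np.take_along_axis(idxs, order, axis=1)   # [[10, 105, 42]]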
best_indices_batch + ] # we want to find the best score for each sentence in the batch (batch_size) + else: + # Top k + if batch_sims.shape[1] <= k: + # Less candidates than k + best_indices_batch = np.argsort(-batch_sims, axis=1) + best_scores_batch = np.take_along_axis(batch_sims, best_indices_batch, axis=1) + else: + part_idx = np.argpartition(-batch_sims, k, axis=1)[:, :k] + top_k_val = np.take_along_axis(batch_sims, part_idx, axis=1) + sorted_sub_idx = np.argsort(-top_k_val, axis=1) + best_scores_batch = np.take_along_axis(top_k_val, sorted_sub_idx, axis=1) + best_indices_batch = np.take_along_axis(part_idx, sorted_sub_idx, axis=1) best_indices[start:end] = best_indices_batch best_scores[start:end] = best_scores_batch diff --git a/src/customer_analysis/metrics_util.py b/src/customer_analysis/metrics_util.py index 64ea008..c603007 100644 --- a/src/customer_analysis/metrics_util.py +++ b/src/customer_analysis/metrics_util.py @@ -64,8 +64,9 @@ def sweep_thresholds_on_results(results_df: pd.DataFrame) -> pd.DataFrame: """Perform threshold sweep and return results.""" print("\nPerforming threshold sweep") min_score = results_df["similarity_score"].min() - steps = max(min(200, len(results_df)), 1) # At least 1 step, cannot be 0, and max length of 200 - thresholds = np.linspace(min_score, 1.0, steps) + max_score = results_df["similarity_score"].max() + steps = 200 + thresholds = np.linspace(min_score, 1.0 if max_score < 1.0 else max_score, steps) results = [] for i, threshold in enumerate(thresholds): diff --git a/src/customer_analysis/query_engine.py b/src/customer_analysis/query_engine.py index c44b2f6..1d7cbe7 100644 --- a/src/customer_analysis/query_engine.py +++ b/src/customer_analysis/query_engine.py @@ -56,8 +56,12 @@ class RedisVectorIndex: def __post_init__(self): # 0) init local embedding model device = self.device or ("cuda" if _HAS_TORCH and torch.cuda.is_available() else "cpu") - self.model = SentenceTransformer(self.model_name, device=device) - self.embed_dim = int(self.model.get_sentence_embedding_dimension()) + self.model = SentenceTransformer(self.model_name, device=device, local_files_only=False, trust_remote_code=True) + + # Probe the model to get the actual output dimension + # (Some models report incorrect dimension in config) + probe = self.model.encode(["test"], convert_to_numpy=True) + self.embed_dim = int(probe.shape[1]) # 1) ensure Redis index exists (schema dims come from the model) schema_dict = { @@ -79,8 +83,8 @@ def __post_init__(self): } schema = IndexSchema.from_dict(schema_dict) self.index: SearchIndex = SearchIndex(schema, redis_url=self.redis_url) - if not self.index.exists(): - self.index.create(overwrite=False) + # Always overwrite to ensure schema matches the current model dimensions + self.index.create(overwrite=True) def _embed_batch(self, texts: List[str]) -> np.ndarray: vecs = self.model.encode(