12 changes: 12 additions & 0 deletions evaluation.py
@@ -253,6 +253,18 @@ def main(args):
parser.add_argument(
"--redis_batch_size", type=int, default=256, help="Batch size for Redis vector operations (default: 256)"
)
parser.add_argument(
"--cross_encoder_model",
type=str,
default=None,
help="Name of the cross-encoder model to use for reranking (default: None)",
)
parser.add_argument(
"--rerank_k",
type=int,
default=10,
help="Number of candidates to rerank (default: 10)",
)
args = parser.parse_args()

main(args)
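The evaluation.py hunk above only adds the two CLI flags; the reranking logic itself is not part of this excerpt. As a rough sketch of what such a step typically looks like (an assumption, not the repository's implementation), the top bi-encoder candidates can be rescored with the sentence-transformers CrossEncoder API. The function name and the (text, score) candidate format below are hypothetical:

    # Sketch only -- not the PR's implementation. Assumes sentence-transformers is installed.
    from sentence_transformers import CrossEncoder

    def rerank_candidates(query, candidates, cross_encoder_model=None, rerank_k=10):
        # `candidates`: list of (text, bi_encoder_score) pairs, sorted best-first (hypothetical format).
        if cross_encoder_model is None:
            return candidates  # bi-encoder only, mirroring the default of None
        ce = CrossEncoder(cross_encoder_model)
        top = candidates[:rerank_k]
        scores = ce.predict([(query, text) for text, _ in top])
        reranked = sorted(zip([t for t, _ in top], scores), key=lambda p: p[1], reverse=True)
        return reranked + candidates[rerank_k:]  # tail beyond rerank_k keeps its original order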
160 changes: 94 additions & 66 deletions run_benchmark.py
@@ -56,6 +56,14 @@ def main():
parser.add_argument("--redis_index_name", type=str, default="idx_cache_match")
parser.add_argument("--redis_doc_prefix", type=str, default="cache:")
parser.add_argument("--redis_batch_size", type=int, default=256)
parser.add_argument(
"--cross_encoder_models",
type=str,
nargs="*",
default=None,
help="List of cross-encoder models (optional). If not provided, only bi-encoder is used.",
)
parser.add_argument("--rerank_k", type=int, default=10, help="Number of candidates to rerank.")

args = parser.parse_args()

@@ -101,76 +109,96 @@ def main():
for model_name in args.models:
print(f"\n Model: {model_name}")

# Sanitize model name for directory structure
safe_model_name = model_name.replace("/", "_")
# Prepare list of cross-encoders to iterate over (None = no reranking)
ce_models = args.cross_encoder_models if args.cross_encoder_models else [None]

for ce_model_name in ce_models:
if ce_model_name:
print(f" Cross-Encoder: {ce_model_name}")
else:
print(f" Cross-Encoder: None (Bi-Encoder only)")

# Sanitize model name for directory structure
safe_model_name = model_name.replace("/", "_")

for run_i in range(1, args.n_runs + 1):
print(f" Run {run_i}/{args.n_runs}...")

# 1. Bootstrapping Logic
# Sample 80% of the universe
run_universe = full_df.sample(
frac=args.sample_ratio, random_state=run_i
) # Use run_i as seed for reproducibility per run

for run_i in range(1, args.n_runs + 1):
print(f" Run {run_i}/{args.n_runs}...")
# Split into Queries (n_samples) and Cache (remainder)
if len(run_universe) <= args.n_samples:
print(
f" Warning: Dataset size ({len(run_universe)}) <= n_samples ({args.n_samples}). Skipping."
)
continue

# 1. Bootstrapping Logic
# Sample 80% of the universe
run_universe = full_df.sample(
frac=args.sample_ratio, random_state=run_i
) # Use run_i as seed for reproducibility per run
queries = run_universe.sample(n=args.n_samples, random_state=run_i + 1000)
cache = run_universe.drop(queries.index)

# Split into Queries (n_samples) and Cache (remainder)
if len(run_universe) <= args.n_samples:
print(
f" Warning: Dataset size ({len(run_universe)}) <= n_samples ({args.n_samples}). Skipping."
# Shuffle cache
cache = cache.sample(frac=1, random_state=run_i + 2000).reset_index(drop=True)
queries = queries.reset_index(drop=True)

# 2. Construct Output Path
timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

# Include cross-encoder in output path if used
model_dir_name = safe_model_name
if ce_model_name:
safe_cross_encoder_name = ce_model_name.replace("/", "_")
model_dir_name = f"{safe_model_name}_rerank_{safe_cross_encoder_name}"

run_output_dir = os.path.join(
args.output_dir, dataset_name, model_dir_name, f"run_{run_i}", timestamp
)
continue

queries = run_universe.sample(n=args.n_samples, random_state=run_i + 1000)
cache = run_universe.drop(queries.index)

# Shuffle cache
cache = cache.sample(frac=1, random_state=run_i + 2000).reset_index(drop=True)
queries = queries.reset_index(drop=True)

# 2. Construct Output Path
timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
run_output_dir = os.path.join(args.output_dir, dataset_name, safe_model_name, f"run_{run_i}", timestamp)
os.makedirs(run_output_dir, exist_ok=True)

# 3. Prepare Args for Evaluation
eval_args = BenchmarkArgs(
query_log_path=dataset_path, # Not strictly used by logic below but good for reference
sentence_column=args.sentence_column,
output_dir=run_output_dir,
n_samples=args.n_samples,
model_name=model_name,
cache_path=None,
full=args.full,
llm_name=args.llm_name,
llm_model=llm_classifier,
sweep_steps=200, # Default
use_redis=args.use_redis,
redis_url=args.redis_url,
redis_index_name=args.redis_index_name,
redis_doc_prefix=args.redis_doc_prefix,
redis_batch_size=args.redis_batch_size,
# device defaults to code logic
)

# 4. Run Evaluation
try:
print(" Matching...")
if args.use_redis:
queries_matched = run_matching_redis(queries.copy(), cache.copy(), eval_args)
else:
queries_matched = run_matching(queries.copy(), cache.copy(), eval_args)

print(" Evaluating...")
if args.full:
run_full_evaluation(queries_matched, eval_args)
else:
run_chr_analysis(queries_matched, eval_args)

except Exception as e:
print(f" Error in run {run_i}: {e}")
import traceback

traceback.print_exc()
os.makedirs(run_output_dir, exist_ok=True)

# 3. Prepare Args for Evaluation
eval_args = BenchmarkArgs(
query_log_path=dataset_path, # Not strictly used by logic below but good for reference
sentence_column=args.sentence_column,
output_dir=run_output_dir,
n_samples=args.n_samples,
model_name=model_name,
cache_path=None,
full=args.full,
llm_name=args.llm_name,
llm_model=llm_classifier,
sweep_steps=200, # Default
use_redis=args.use_redis,
redis_url=args.redis_url,
redis_index_name=args.redis_index_name,
redis_doc_prefix=args.redis_doc_prefix,
redis_batch_size=args.redis_batch_size,
cross_encoder_model=ce_model_name,
rerank_k=args.rerank_k,
# device defaults to code logic
)

# 4. Run Evaluation
try:
print(" Matching...")
if args.use_redis:
queries_matched = run_matching_redis(queries.copy(), cache.copy(), eval_args)
else:
queries_matched = run_matching(queries.copy(), cache.copy(), eval_args)

print(" Evaluating...")
if args.full:
run_full_evaluation(queries_matched, eval_args)
else:
run_chr_analysis(queries_matched, eval_args)

except Exception as e:
print(f" Error in run {run_i}: {e}")
import traceback

traceback.print_exc()

print("\nBenchmark completed.")

15 changes: 15 additions & 0 deletions run_benchmark.sh
@@ -0,0 +1,15 @@
# Example usage:
uv run run_benchmark.py \
--dataset_dir "dataset" \
--output_dir "cross_encoder_results" \
--models "Alibaba-NLP/gte-modernbert-base" "redis/langcache-embed-v1" "redis/langcache-embed-v3-small" \
--dataset_names "vizio_unique_medium.csv" "axis_bank_unique_sentences.csv" \
--sentence_column "sentence" \
--n_runs 10 \
--n_samples 10000 \
--sample_ratio 0.8 \
--llm_name "tensoropera/Fox-1-1.6B" \
--full \
--use_redis \
# --cross_encoder_models "redis/langcache-reranker-v1-softmnrl-triplet" "Alibaba-NLP/gte-reranker-modernbert-base" \
# --rerank_k 5
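The final two flags are left commented out; to benchmark with reranking, uncomment them and adjust the model list. Per the argument definitions above, --rerank_k sets how many bi-encoder candidates each cross-encoder rescores per query (default 10).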