Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 90 additions & 3 deletions cli/src/bench/haystack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,30 @@ pub struct HaystackArgs {
pub measurement_time: u64,

/// Comma-separated document counts for basic benchmark
#[arg(long, default_value = "2,4,8", value_parser = super::parse_csv_usize, env = "BENCH_NUM_DOCS")]
#[arg(
long,
default_value = "2,4,8",
value_delimiter = ',',
env = "BENCH_NUM_DOCS"
)]
pub num_docs: Vec<usize>,

/// Comma-separated document lengths (words of filler per doc)
#[arg(long, default_value = "0,10,100,200,400,600,800,1000", value_parser = super::parse_csv_usize, env = "BENCH_DOC_LENGTH")]
#[arg(
long,
default_value = "0,10,100,200,400,600,800,1000",
value_delimiter = ',',
env = "BENCH_DOC_LENGTH"
)]
pub doc_lengths: Vec<usize>,

/// Comma-separated chunk sizes for map-reduce benchmark
#[arg(long, default_value = "2,4", value_parser = super::parse_csv_usize, env = "BENCH_CHUNK_SIZES")]
#[arg(
long,
default_value = "2,4",
value_delimiter = ',',
env = "BENCH_CHUNK_SIZES"
)]
pub chunk_sizes: Vec<usize>,

/// Number of documents for the map-reduce benchmark
Expand Down Expand Up @@ -479,3 +494,75 @@ pub fn run(args: HaystackArgs) -> Result<(), SpnlError> {
criterion.final_summary();
Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned `Vec<String>` fixture from string literals.
    fn names(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Asserts that `actual` is strictly within `tolerance` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!((actual - expected).abs() < tolerance);
    }

    // ---- ratio ----

    /// 3 out of 4 is 0.75.
    #[test]
    fn ratio_basic() {
        assert_close(ratio(3, 4), 0.75, f64::EPSILON);
    }

    // ---- score ----

    /// Identical expected/actual lists give perfect precision and recall.
    #[test]
    fn score_all_correct() {
        let expected = names(&["a", "b", "c"]);
        let actual = names(&["a", "b", "c"]);

        let (precision, recall) = score(&expected, &actual);

        assert_close(precision, 1.0, f64::EPSILON);
        assert_close(recall, 1.0, f64::EPSILON);
    }

    /// One extra, unexpected answer lowers precision but not recall.
    #[test]
    fn score_with_false_positives() {
        let expected = names(&["a", "b"]);
        let actual = names(&["a", "b", "x"]);

        let (precision, recall) = score(&expected, &actual);

        // Two of the three returned names were expected.
        assert_close(precision, 2.0 / 3.0, 1e-10);
        // Both expected names were returned.
        assert_close(recall, 1.0, f64::EPSILON);
    }

    /// Missing answers lower recall but not precision.
    #[test]
    fn score_with_false_negatives() {
        let expected = names(&["a", "b", "c"]);
        let actual = names(&["a"]);

        let (precision, recall) = score(&expected, &actual);

        assert_close(precision, 1.0, f64::EPSILON);
        assert_close(recall, 1.0 / 3.0, 1e-10);
    }

    /// With nothing returned at all, precision is reported as zero.
    #[test]
    fn score_empty_actual_precision_zero() {
        let expected = names(&["a"]);
        let actual = names(&[]);

        let (precision, _recall) = score(&expected, &actual);

        assert_close(precision, 0.0, f64::EPSILON);
    }

    // ---- score_chain ----

    /// No later answer repeats the name to be filtered, so the score is perfect.
    #[test]
    fn score_chain_filters_first_name() {
        // expected[0] carries the name that must not show up again.
        let expected = names(&["bad-name", "", ""]);
        // actual[0] is skipped; only actual[1..] is inspected.
        let actual = names(&["ignored", "good", "also-good"]);

        let (chain_score, _) = score_chain(&expected, &actual);

        assert_close(chain_score, 1.0, f64::EPSILON);
    }

    /// A later answer that repeats the filtered name is penalized.
    #[test]
    fn score_chain_penalizes_leaked_name() {
        let expected = names(&["bad-name", "", ""]);
        let actual = names(&["ignored", "bad-name", "ok"]);

        let (chain_score, _) = score_chain(&expected, &actual);

        // One of the two inspected answers leaks the name: 1 - 1/2 = 0.5.
        assert_close(chain_score, 0.5, f64::EPSILON);
    }
}
137 changes: 125 additions & 12 deletions cli/src/bench/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,15 @@ pub enum BenchCommands {

/// Dispatches a parsed `bench` subcommand to its implementation.
///
/// The `Haystack`, `Niah` and `Ruler` entry points are called from a plain
/// closure (no `.await`), so each one is handed to
/// [`tokio::task::spawn_blocking`] rather than being invoked directly on the
/// async executor. `Ragcsv` exposes an async entry point and is awaited in
/// place.
///
/// Each variant has exactly one arm: an earlier arm for the same variant
/// would shadow the `spawn_blocking` one and leave it unreachable.
///
/// # Errors
///
/// Returns the error produced by the selected benchmark. For the blocking
/// benchmarks, a failure to join the spawned task is also propagated, which
/// relies on the join error being convertible into `SpnlError` via `?`.
pub async fn run(command: BenchCommands) -> Result<(), SpnlError> {
    match command {
        // In the `??` below, the first `?` handles the join result of the
        // blocking task and the second handles the benchmark's own `Result`.
        BenchCommands::Haystack(args) => {
            Ok(tokio::task::spawn_blocking(move || haystack::run(args)).await??)
        }
        BenchCommands::Niah(args) => {
            Ok(tokio::task::spawn_blocking(move || niah::run(args)).await??)
        }
        BenchCommands::Ruler(args) => {
            Ok(tokio::task::spawn_blocking(move || ruler::run(args)).await??)
        }
        BenchCommands::Ragcsv(args) => ragcsv::run(args).await,
    }
}
Expand Down Expand Up @@ -179,13 +185,120 @@ pub fn encode_and_trim(
}
}

/// Parse a comma-separated list of integers from a string
pub fn parse_csv_usize(s: &str) -> Result<Vec<usize>, String> {
s.split(',')
.map(|n| {
n.trim()
.parse()
.map_err(|e| format!("invalid number '{}': {}", n.trim(), e))
})
.collect()
#[cfg(test)]
mod tests {
    use super::*;

    // ---- compute_quantiles ----
    //
    // The result tuple is ordered (min, p25, p50, p75, p90, p99, max).

    /// A single sample is reported for every position of the summary.
    #[test]
    fn quantiles_single_value() {
        let summary = compute_quantiles(&[42.0]);
        assert_eq!(summary, (42.0, 42.0, 42.0, 42.0, 42.0, 42.0, 42.0));
    }

    /// Pins the summary of the samples 1.0, 2.0, ..., 100.0.
    #[test]
    fn quantiles_multiple_values() {
        let samples: Vec<f64> = (1_u32..=100).map(f64::from).collect();
        let summary = compute_quantiles(&samples);
        assert_eq!(summary, (1.0, 26.0, 51.0, 76.0, 91.0, 100.0, 100.0));
    }

    /// Input order must not matter: min and max come from the sorted data.
    #[test]
    fn quantiles_unsorted_input_produces_sorted_output() {
        let summary = compute_quantiles(&[5.0, 1.0, 3.0, 4.0, 2.0]);
        assert_eq!(summary.0, 1.0);
        assert_eq!(summary.6, 5.0);
    }

    /// An empty sample set yields an all-zero summary.
    #[test]
    fn quantiles_empty_returns_zeros() {
        assert_eq!(
            compute_quantiles(&[]),
            (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
        );
    }

    // ---- compute_quantiles_with_avg ----
    //
    // Same layout as above with the average appended as the eighth element.

    /// The trailing element is the arithmetic mean of the samples.
    #[test]
    fn quantiles_with_avg_computes_average() {
        let summary = compute_quantiles_with_avg(&[10.0, 20.0, 30.0]);
        let avg = summary.7;
        assert!((avg - 20.0).abs() < f64::EPSILON);
    }

    /// An empty sample set yields zeros everywhere, including the average.
    #[test]
    fn quantiles_with_avg_empty_returns_zeros() {
        assert_eq!(
            compute_quantiles_with_avg(&[]),
            (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
        );
    }

    // ---- clap arg-parsing smoke tests ----

    use crate::args::FullArgs;
    use clap::Parser;

    /// Parses `argv` as a complete command line, panicking on any clap error.
    #[track_caller]
    fn parse_cli(argv: &[&str]) -> FullArgs {
        FullArgs::try_parse_from(argv.iter().copied()).unwrap()
    }

    /// `bench ruler` must parse with nothing but its defaults.
    #[test]
    fn parse_bench_ruler_defaults() {
        parse_cli(&["spnl", "bench", "ruler"]);
    }

    /// `bench niah` must parse with nothing but its defaults.
    #[test]
    fn parse_bench_niah_defaults() {
        parse_cli(&["spnl", "bench", "niah"]);
    }

    /// `bench haystack` must parse with nothing but its defaults.
    #[test]
    fn parse_bench_haystack_defaults() {
        parse_cli(&["spnl", "bench", "haystack"]);
    }

    /// A comma-separated `--context-lengths` value is split into one entry per number.
    #[test]
    fn parse_bench_ruler_csv_context_lengths() {
        let parsed = parse_cli(&[
            "spnl",
            "bench",
            "ruler",
            "--context-lengths",
            "1000,2000,4000",
        ]);

        let context_lengths = match parsed.command {
            crate::args::Commands::Bench {
                command: BenchCommands::Ruler(ruler_args),
            } => ruler_args.context_lengths,
            other => panic!("unexpected command: {:?}", other),
        };

        assert_eq!(context_lengths, vec![1000, 2000, 4000]);
    }

    /// A comma-separated `--depth-percentages` value is split into one entry per number.
    #[test]
    fn parse_bench_niah_csv_depth_percentages() {
        let parsed = parse_cli(&["spnl", "bench", "niah", "--depth-percentages", "0,50,100"]);

        let depth_percentages = match parsed.command {
            crate::args::Commands::Bench {
                command: BenchCommands::Niah(niah_args),
            } => niah_args.depth_percentages,
            other => panic!("unexpected command: {:?}", other),
        };

        assert_eq!(depth_percentages, vec![0, 50, 100]);
    }
}
57 changes: 54 additions & 3 deletions cli/src/bench/niah.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,30 @@ pub struct NiahArgs {
pub measurement_time: u64,

/// Comma-separated context lengths in tokens
#[arg(long, default_value = "1000,2000,4000,8000", value_parser = super::parse_csv_usize, env = "BENCH_CONTEXT_LENGTHS")]
#[arg(
long,
default_value = "1000,2000,4000,8000",
value_delimiter = ',',
env = "BENCH_CONTEXT_LENGTHS"
)]
pub context_lengths: Vec<usize>,

/// Comma-separated depth percentages (0-100)
#[arg(long, default_value = "0,25,50,75,100", value_parser = super::parse_csv_usize, env = "BENCH_DEPTH_PERCENTAGES")]
#[arg(
long,
default_value = "0,25,50,75,100",
value_delimiter = ',',
env = "BENCH_DEPTH_PERCENTAGES"
)]
pub depth_percentages: Vec<usize>,

/// Comma-separated chunk sizes (0 = no chunking)
#[arg(long, default_value = "0,2,4", value_parser = super::parse_csv_usize, env = "BENCH_CHUNK_SIZES")]
#[arg(
long,
default_value = "0,2,4",
value_delimiter = ',',
env = "BENCH_CHUNK_SIZES"
)]
pub chunk_sizes: Vec<usize>,

/// Token buffer for system/question/response
Expand Down Expand Up @@ -518,3 +533,39 @@ pub fn run(args: NiahArgs) -> Result<(), SpnlError> {
criterion.final_summary();
Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A response that literally contains the needle scores 1.0.
    #[test]
    fn evaluate_exact_substring_match() {
        let retrieval = evaluate_needle_retrieval("The answer is 73.", "73", false);
        assert!((retrieval - 1.0).abs() < f64::EPSILON);
    }

    /// Letter case is ignored when looking for the needle.
    #[test]
    fn evaluate_case_insensitive_match() {
        let retrieval = evaluate_needle_retrieval("HELLO world", "hello", false);
        assert!((retrieval - 1.0).abs() < f64::EPSILON);
    }

    /// The needle's digits appear wrapped in punctuation inside one token.
    ///
    /// NOTE(review): "73" is still a literal substring of this response, so
    /// this case may be satisfied by plain substring matching before any
    /// token-level numeric parsing runs. Confirm against
    /// `evaluate_needle_retrieval` and pick an input that only the numeric
    /// path can match if that path is meant to be covered here.
    #[test]
    fn evaluate_numeric_match_in_tokens() {
        let retrieval = evaluate_needle_retrieval("The number is (73).", "73", false);
        assert!((retrieval - 1.0).abs() < f64::EPSILON);
    }

    /// A response with no trace of the needle scores 0.0.
    #[test]
    fn evaluate_no_match_returns_zero() {
        let retrieval = evaluate_needle_retrieval("Nothing relevant here", "73", false);
        assert!(retrieval.abs() < f64::EPSILON);
    }
}
Loading
Loading