From 1b70f5943c849bf7a0e7e28734d6ce048ec6fc0c Mon Sep 17 00:00:00 2001 From: Nick Mitchell Date: Wed, 25 Feb 2026 18:35:57 -0500 Subject: [PATCH] fix(bench): resolve runtime panic and clap CSV arg parsing; add tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace custom parse_csv_usize value_parser with clap built-in value_delimiter to fix type mismatch (parser returned Vec<usize> but clap expected usize for Vec<usize> fields). Wrap sync benchmark runners in spawn_blocking to avoid nested tokio runtime panic from criterion and reqwest::blocking running inside an async context. Add 49 unit tests across the bench modules covering pure scoring functions, arg parsing, and string metrics — no model or network calls needed at test time. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Nick Mitchell --- cli/src/bench/haystack.rs | 93 +++++++++++++++++++++- cli/src/bench/mod.rs | 137 ++++++++++++++++++++++++++++++--- cli/src/bench/niah.rs | 57 +++++++++++++- cli/src/bench/ragcsv.rs | 157 ++++++++++++++++++++++++++++++++++++++ cli/src/bench/ruler.rs | 43 ++++++++++- 5 files changed, 467 insertions(+), 20 deletions(-) diff --git a/cli/src/bench/haystack.rs b/cli/src/bench/haystack.rs index 90e1f148..091191e2 100644 --- a/cli/src/bench/haystack.rs +++ b/cli/src/bench/haystack.rs @@ -28,15 +28,30 @@ pub struct HaystackArgs { pub measurement_time: u64, /// Comma-separated document counts for basic benchmark - #[arg(long, default_value = "2,4,8", value_parser = super::parse_csv_usize, env = "BENCH_NUM_DOCS")] + #[arg( + long, + default_value = "2,4,8", + value_delimiter = ',', + env = "BENCH_NUM_DOCS" + )] pub num_docs: Vec, /// Comma-separated document lengths (words of filler per doc) - #[arg(long, default_value = "0,10,100,200,400,600,800,1000", value_parser = super::parse_csv_usize, env = "BENCH_DOC_LENGTH")] + #[arg( + long, + default_value = "0,10,100,200,400,600,800,1000", + value_delimiter = ',', + env = "BENCH_DOC_LENGTH" +
)] pub doc_lengths: Vec, /// Comma-separated chunk sizes for map-reduce benchmark - #[arg(long, default_value = "2,4", value_parser = super::parse_csv_usize, env = "BENCH_CHUNK_SIZES")] + #[arg( + long, + default_value = "2,4", + value_delimiter = ',', + env = "BENCH_CHUNK_SIZES" + )] pub chunk_sizes: Vec, /// Number of documents for the map-reduce benchmark @@ -479,3 +494,75 @@ pub fn run(args: HaystackArgs) -> Result<(), SpnlError> { criterion.final_summary(); Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + // ---- ratio ---- + + #[test] + fn ratio_basic() { + assert!((ratio(3, 4) - 0.75).abs() < f64::EPSILON); + } + + // ---- score ---- + + #[test] + fn score_all_correct() { + let expected = vec!["a".into(), "b".into(), "c".into()]; + let actual = vec!["a".into(), "b".into(), "c".into()]; + let (precision, recall) = score(&expected, &actual); + assert!((precision - 1.0).abs() < f64::EPSILON); + assert!((recall - 1.0).abs() < f64::EPSILON); + } + + #[test] + fn score_with_false_positives() { + let expected: Vec = vec!["a".into(), "b".into()]; + let actual: Vec = vec!["a".into(), "b".into(), "x".into()]; + let (precision, recall) = score(&expected, &actual); + // precision = 2/3 + assert!((precision - 2.0 / 3.0).abs() < 1e-10); + // recall = 2/2 = 1.0 + assert!((recall - 1.0).abs() < f64::EPSILON); + } + + #[test] + fn score_with_false_negatives() { + let expected: Vec = vec!["a".into(), "b".into(), "c".into()]; + let actual: Vec = vec!["a".into()]; + let (precision, recall) = score(&expected, &actual); + assert!((precision - 1.0).abs() < f64::EPSILON); + assert!((recall - 1.0 / 3.0).abs() < 1e-10); + } + + #[test] + fn score_empty_actual_precision_zero() { + let expected: Vec = vec!["a".into()]; + let actual: Vec = vec![]; + let (precision, _recall) = score(&expected, &actual); + assert!((precision - 0.0).abs() < f64::EPSILON); + } + + // ---- score_chain ---- + + #[test] + fn score_chain_filters_first_name() { + // expected[0] is the name to filter; 
subsequent actuals should NOT contain it + let expected: Vec = vec!["bad-name".into(), "".into(), "".into()]; + // actual[0] is ignored; actual[1..] are checked + let actual: Vec = vec!["ignored".into(), "good".into(), "also-good".into()]; + let (score, _) = score_chain(&expected, &actual); + assert!((score - 1.0).abs() < f64::EPSILON); + } + + #[test] + fn score_chain_penalizes_leaked_name() { + let expected: Vec = vec!["bad-name".into(), "".into(), "".into()]; + let actual: Vec = vec!["ignored".into(), "bad-name".into(), "ok".into()]; + let (score, _) = score_chain(&expected, &actual); + // 1 out of 2 subsequent actuals matches do_not_want → 1 - 1/2 = 0.5 + assert!((score - 0.5).abs() < f64::EPSILON); + } +} diff --git a/cli/src/bench/mod.rs b/cli/src/bench/mod.rs index 76a55606..c4eddeed 100644 --- a/cli/src/bench/mod.rs +++ b/cli/src/bench/mod.rs @@ -20,9 +20,15 @@ pub enum BenchCommands { pub async fn run(command: BenchCommands) -> Result<(), SpnlError> { match command { - BenchCommands::Haystack(args) => haystack::run(args), - BenchCommands::Niah(args) => niah::run(args), - BenchCommands::Ruler(args) => ruler::run(args), + BenchCommands::Haystack(args) => { + Ok(tokio::task::spawn_blocking(move || haystack::run(args)).await??) + } + BenchCommands::Niah(args) => { + Ok(tokio::task::spawn_blocking(move || niah::run(args)).await??) + } + BenchCommands::Ruler(args) => { + Ok(tokio::task::spawn_blocking(move || ruler::run(args)).await??) 
+ } BenchCommands::Ragcsv(args) => ragcsv::run(args).await, } } @@ -179,13 +185,120 @@ pub fn encode_and_trim( } } -/// Parse a comma-separated list of integers from a string -pub fn parse_csv_usize(s: &str) -> Result, String> { - s.split(',') - .map(|n| { - n.trim() - .parse() - .map_err(|e| format!("invalid number '{}': {}", n.trim(), e)) - }) - .collect() +#[cfg(test)] +mod tests { + use super::*; + + // ---- compute_quantiles ---- + + #[test] + fn quantiles_single_value() { + let (min, p25, p50, p75, p90, p99, max) = compute_quantiles(&[42.0]); + assert_eq!(min, 42.0); + assert_eq!(p25, 42.0); + assert_eq!(p50, 42.0); + assert_eq!(p75, 42.0); + assert_eq!(p90, 42.0); + assert_eq!(p99, 42.0); + assert_eq!(max, 42.0); + } + + #[test] + fn quantiles_multiple_values() { + let values: Vec = (1..=100).map(|i| i as f64).collect(); + let (min, p25, p50, p75, p90, p99, max) = compute_quantiles(&values); + assert_eq!(min, 1.0); + assert_eq!(p25, 26.0); + assert_eq!(p50, 51.0); + assert_eq!(p75, 76.0); + assert_eq!(p90, 91.0); + assert_eq!(p99, 100.0); + assert_eq!(max, 100.0); + } + + #[test] + fn quantiles_unsorted_input_produces_sorted_output() { + let values = vec![5.0, 1.0, 3.0, 4.0, 2.0]; + let (min, _, _, _, _, _, max) = compute_quantiles(&values); + assert_eq!(min, 1.0); + assert_eq!(max, 5.0); + } + + #[test] + fn quantiles_empty_returns_zeros() { + let (min, p25, p50, p75, p90, p99, max) = compute_quantiles(&[]); + assert_eq!( + (min, p25, p50, p75, p90, p99, max), + (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + ); + } + + // ---- compute_quantiles_with_avg ---- + + #[test] + fn quantiles_with_avg_computes_average() { + let values = vec![10.0, 20.0, 30.0]; + let (_, _, _, _, _, _, _, avg) = compute_quantiles_with_avg(&values); + assert!((avg - 20.0).abs() < f64::EPSILON); + } + + #[test] + fn quantiles_with_avg_empty_returns_zeros() { + let (min, p25, p50, p75, p90, p99, max, avg) = compute_quantiles_with_avg(&[]); + assert_eq!( + (min, p25, p50, p75, p90, p99, max, 
avg), + (0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + ); + } + + // ---- clap arg-parsing smoke tests ---- + + use crate::args::FullArgs; + use clap::Parser; + + #[test] + fn parse_bench_ruler_defaults() { + FullArgs::try_parse_from(["spnl", "bench", "ruler"]).unwrap(); + } + + #[test] + fn parse_bench_niah_defaults() { + FullArgs::try_parse_from(["spnl", "bench", "niah"]).unwrap(); + } + + #[test] + fn parse_bench_haystack_defaults() { + FullArgs::try_parse_from(["spnl", "bench", "haystack"]).unwrap(); + } + + #[test] + fn parse_bench_ruler_csv_context_lengths() { + let args = FullArgs::try_parse_from([ + "spnl", + "bench", + "ruler", + "--context-lengths", + "1000,2000,4000", + ]) + .unwrap(); + match args.command { + crate::args::Commands::Bench { + command: BenchCommands::Ruler(r), + } => assert_eq!(r.context_lengths, vec![1000, 2000, 4000]), + other => panic!("unexpected command: {:?}", other), + } + } + + #[test] + fn parse_bench_niah_csv_depth_percentages() { + let args = + FullArgs::try_parse_from(["spnl", "bench", "niah", "--depth-percentages", "0,50,100"]) + .unwrap(); + match args.command { + crate::args::Commands::Bench { + command: BenchCommands::Niah(n), + } => assert_eq!(n.depth_percentages, vec![0, 50, 100]), + other => panic!("unexpected command: {:?}", other), + } + } } diff --git a/cli/src/bench/niah.rs b/cli/src/bench/niah.rs index f1bfd91a..162dd9fd 100644 --- a/cli/src/bench/niah.rs +++ b/cli/src/bench/niah.rs @@ -39,15 +39,30 @@ pub struct NiahArgs { pub measurement_time: u64, /// Comma-separated context lengths in tokens - #[arg(long, default_value = "1000,2000,4000,8000", value_parser = super::parse_csv_usize, env = "BENCH_CONTEXT_LENGTHS")] + #[arg( + long, + default_value = "1000,2000,4000,8000", + value_delimiter = ',', + env = "BENCH_CONTEXT_LENGTHS" + )] pub context_lengths: Vec, /// Comma-separated depth percentages (0-100) - #[arg(long, default_value = "0,25,50,75,100", value_parser = super::parse_csv_usize, env = 
"BENCH_DEPTH_PERCENTAGES")] + #[arg( + long, + default_value = "0,25,50,75,100", + value_delimiter = ',', + env = "BENCH_DEPTH_PERCENTAGES" + )] pub depth_percentages: Vec, /// Comma-separated chunk sizes (0 = no chunking) - #[arg(long, default_value = "0,2,4", value_parser = super::parse_csv_usize, env = "BENCH_CHUNK_SIZES")] + #[arg( + long, + default_value = "0,2,4", + value_delimiter = ',', + env = "BENCH_CHUNK_SIZES" + )] pub chunk_sizes: Vec, /// Token buffer for system/question/response @@ -518,3 +533,39 @@ pub fn run(args: NiahArgs) -> Result<(), SpnlError> { criterion.final_summary(); Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn evaluate_exact_substring_match() { + assert!( + (evaluate_needle_retrieval("The answer is 73.", "73", false) - 1.0).abs() + < f64::EPSILON + ); + } + + #[test] + fn evaluate_case_insensitive_match() { + assert!( + (evaluate_needle_retrieval("HELLO world", "hello", false) - 1.0).abs() < f64::EPSILON + ); + } + + #[test] + fn evaluate_numeric_match_in_tokens() { + // Number not as substring but parsed from a token + assert!( + (evaluate_needle_retrieval("The number is (73).", "73", false) - 1.0).abs() + < f64::EPSILON + ); + } + + #[test] + fn evaluate_no_match_returns_zero() { + assert!( + (evaluate_needle_retrieval("Nothing relevant here", "73", false)).abs() < f64::EPSILON + ); + } +} diff --git a/cli/src/bench/ragcsv.rs b/cli/src/bench/ragcsv.rs index 7020fd2d..e80f7c75 100644 --- a/cli/src/bench/ragcsv.rs +++ b/cli/src/bench/ragcsv.rs @@ -877,3 +877,160 @@ pub async fn run(args: RagcsvArgs) -> Result<(), SpnlError> { Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + // ---- python_repr_to_json ---- + + #[test] + fn python_repr_single_quoted_strings() { + assert_eq!(python_repr_to_json("{'a': 'b'}"), r#"{"a": "b"}"#); + } + + #[test] + fn python_repr_none_true_false() { + assert_eq!( + python_repr_to_json("{'x': None, 'y': True, 'z': False}"), + r#"{"x": null, "y": true, "z": false}"# + ); + } + 
+ #[test] + fn python_repr_embedded_double_quotes() { + // A single-quoted Python string containing a double quote should escape it + assert_eq!( + python_repr_to_json(r#"{'a': 'he said "hi"'}"#), + r#"{"a": "he said \"hi\""}"# + ); + } + + #[test] + fn python_repr_escaped_single_quotes() { + // Python: 'it\'s' -> JSON: "it's" + assert_eq!(python_repr_to_json(r"{'a': 'it\'s'}"), r#"{"a": "it's"}"#); + } + + // ---- parse_accuracy ---- + + #[test] + fn parse_accuracy_numeric_string() { + assert!((parse_accuracy("85") - 85.0).abs() < f64::EPSILON); + } + + #[test] + fn parse_accuracy_text_with_number() { + assert!((parse_accuracy("The accuracy is 72 percent") - 72.0).abs() < f64::EPSILON); + } + + #[test] + fn parse_accuracy_clamps_above_100() { + assert!((parse_accuracy("150") - 100.0).abs() < f64::EPSILON); + } + + #[test] + fn parse_accuracy_no_number_returns_zero() { + assert!((parse_accuracy("no numbers here")).abs() < f64::EPSILON); + } + + // ---- normalize_tokens ---- + + #[test] + fn normalize_tokens_splits_punctuation() { + let tokens = normalize_tokens("Hello, world!"); + assert_eq!(tokens, vec!["hello", "world"]); + } + + #[test] + fn normalize_tokens_case_folding() { + let tokens = normalize_tokens("ABC def GHI"); + assert_eq!(tokens, vec!["abc", "def", "ghi"]); + } + + #[test] + fn normalize_tokens_empty_filtering() { + let tokens = normalize_tokens(" ... 
"); + assert!(tokens.is_empty()); + } + + // ---- token_f1 ---- + + #[test] + fn token_f1_identical() { + assert!((token_f1("the cat sat", "the cat sat") - 100.0).abs() < f64::EPSILON); + } + + #[test] + fn token_f1_no_overlap() { + assert!(token_f1("alpha beta", "gamma delta").abs() < f64::EPSILON); + } + + #[test] + fn token_f1_partial_overlap() { + let score = token_f1("the cat sat on the mat", "the cat"); + assert!(score > 0.0 && score < 100.0); + } + + #[test] + fn token_f1_both_empty() { + assert!((token_f1("", "") - 100.0).abs() < f64::EPSILON); + } + + // ---- exact_match ---- + + #[test] + fn exact_match_identical_after_normalization() { + assert!((exact_match("Hello, World!", "hello world") - 100.0).abs() < f64::EPSILON); + } + + #[test] + fn exact_match_different() { + assert!(exact_match("foo", "bar").abs() < f64::EPSILON); + } + + // ---- bleu_1 ---- + + #[test] + fn bleu_1_identical() { + assert!((bleu_1("the cat sat", "the cat sat") - 100.0).abs() < f64::EPSILON); + } + + #[test] + fn bleu_1_no_overlap() { + assert!(bleu_1("alpha beta", "gamma delta").abs() < f64::EPSILON); + } + + #[test] + fn bleu_1_brevity_penalty() { + // hyp is shorter than ref, so brevity penalty applies + let long_ref = "the quick brown fox jumps over the lazy dog"; + let short_hyp = "the fox"; + let score = bleu_1(long_ref, short_hyp); + assert!(score > 0.0 && score < 100.0); + } + + // ---- MetricFlags::from_arg ---- + + #[test] + fn metric_flags_all() { + let f = MetricFlags::from_arg("all"); + assert!(f.accuracy && f.faithfulness && f.relevancy); + } + + #[test] + fn metric_flags_single() { + let f = MetricFlags::from_arg("accuracy"); + assert!(f.accuracy); + assert!(!f.faithfulness); + assert!(!f.relevancy); + } + + #[test] + fn metric_flags_comma_separated() { + let f = MetricFlags::from_arg("accuracy,relevancy"); + assert!(f.accuracy); + assert!(!f.faithfulness); + assert!(f.relevancy); + } +} diff --git a/cli/src/bench/ruler.rs b/cli/src/bench/ruler.rs index 
120751b4..cc8e4a74 100644 --- a/cli/src/bench/ruler.rs +++ b/cli/src/bench/ruler.rs @@ -40,7 +40,12 @@ pub struct RulerArgs { pub measurement_time: u64, /// Comma-separated context lengths in tokens - #[arg(long, default_value = "4000,8000", value_parser = super::parse_csv_usize, env = "BENCH_CONTEXT_LENGTHS")] + #[arg( + long, + default_value = "4000,8000", + value_delimiter = ',', + env = "BENCH_CONTEXT_LENGTHS" + )] pub context_lengths: Vec, /// Comma-separated tasks to run (niah, variable_tracking) @@ -69,7 +74,12 @@ pub struct RulerArgs { pub niah_num_needle_q: usize, /// Comma-separated depth percentages for NIAH - #[arg(long, default_value = "50", value_parser = super::parse_csv_usize, env = "BENCH_NIAH_DEPTH_PERCENTAGES")] + #[arg( + long, + default_value = "50", + value_delimiter = ',', + env = "BENCH_NIAH_DEPTH_PERCENTAGES" + )] pub niah_depth_percentages: Vec, // -- Variable Tracking-specific -- @@ -635,3 +645,32 @@ pub fn run(args: RulerArgs) -> Result<(), SpnlError> { criterion.final_summary(); Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn string_match_all_all_found() { + let refs = vec!["alpha".into(), "beta".into()]; + assert!((string_match_all("alpha and beta are here", &refs) - 1.0).abs() < f64::EPSILON); + } + + #[test] + fn string_match_all_partial() { + let refs = vec!["alpha".into(), "beta".into(), "gamma".into()]; + assert!((string_match_all("only alpha and gamma", &refs) - 2.0 / 3.0).abs() < 1e-10); + } + + #[test] + fn string_match_all_none_found() { + let refs = vec!["alpha".into(), "beta".into()]; + assert!(string_match_all("nothing here", &refs).abs() < f64::EPSILON); + } + + #[test] + fn string_match_all_case_insensitive() { + let refs = vec!["HELLO".into()]; + assert!((string_match_all("hello world", &refs) - 1.0).abs() < f64::EPSILON); + } +}