diff --git a/backend/crates/kalamdb-api/src/ws/events/subscription.rs b/backend/crates/kalamdb-api/src/ws/events/subscription.rs index 6117ffdf..1b0c4af1 100644 --- a/backend/crates/kalamdb-api/src/ws/events/subscription.rs +++ b/backend/crates/kalamdb-api/src/ws/events/subscription.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use actix_ws::Session; use kalamdb_commons::{ - websocket::{BatchControl, SubscriptionRequest, MAX_ROWS_PER_BATCH}, + websocket::{BatchControl, SubscriptionOptions, SubscriptionRequest, MAX_ROWS_PER_BATCH}, WebSocketMessage, }; use kalamdb_core::providers::arrow_json_conversion::row_into_json_map; @@ -61,11 +61,19 @@ pub async fn handle_subscribe( let subscription_id = subscription.id.clone(); let subscription_options = subscription.options.clone(); - // Determine batch size for initial data options - let batch_size = subscription_options - .as_ref() - .and_then(|options| options.batch_size) - .unwrap_or(MAX_ROWS_PER_BATCH); + let batch_size = subscription_batch_size(subscription_options.as_ref()); + + if let Err(message) = validate_subscription_options(subscription_options.as_ref(), batch_size) { + let _ = send_error( + session, + &subscription_id, + WsErrorCode::Unsupported, + &message, + compression_enabled, + ) + .await; + return Ok(()); + } // Create initial data options respecting all three options: // - from: Resume from a specific sequence ID @@ -214,3 +222,79 @@ pub async fn handle_subscribe( }, } } + +fn subscription_batch_size(options: Option<&SubscriptionOptions>) -> usize { + options + .and_then(|options| options.batch_size) + .unwrap_or(MAX_ROWS_PER_BATCH) +} + +fn validate_subscription_options( + options: Option<&SubscriptionOptions>, + batch_size: usize, +) -> Result<(), String> { + let Some(options) = options else { + return Ok(()); + }; + + if let Some(last_rows) = options.last_rows { + let last_rows = last_rows as usize; + if last_rows > batch_size { + return Err(format!( + "last_rows ({last_rows}) cannot exceed batch_size ({batch_size}); paginated last_rows replay is not supported" + )); + } + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use kalamdb_commons::websocket::SubscriptionOptions; + + use super::{subscription_batch_size, validate_subscription_options, MAX_ROWS_PER_BATCH}; + + #[test] + fn validate_subscription_options_allows_last_rows_within_batch_size() { + let options = SubscriptionOptions { + batch_size: Some(50), + last_rows: Some(50), + from: None, + }; + + let batch_size = subscription_batch_size(Some(&options)); + + assert_eq!(batch_size, 50); + assert!(validate_subscription_options(Some(&options), batch_size).is_ok()); + } + + #[test] + fn validate_subscription_options_rejects_last_rows_above_batch_size() { + let options = SubscriptionOptions { + batch_size: Some(50), + last_rows: Some(51), + from: None, + }; + + let batch_size = subscription_batch_size(Some(&options)); + let error = validate_subscription_options(Some(&options), batch_size) + .expect_err("last_rows above batch_size should be rejected"); + + assert!(error.contains("last_rows (51) cannot exceed batch_size (50)")); + } + + #[test] + fn validate_subscription_options_uses_default_batch_size_when_unspecified() { + let options = SubscriptionOptions { + batch_size: None, + last_rows: Some(MAX_ROWS_PER_BATCH as u32 + 1), + from: None, + }; + + let batch_size = subscription_batch_size(Some(&options)); + + assert_eq!(batch_size, MAX_ROWS_PER_BATCH); + assert!(validate_subscription_options(Some(&options), batch_size).is_err()); + } +} diff --git a/backend/crates/kalamdb-dialect/src/ddl/subscribe_commands.rs b/backend/crates/kalamdb-dialect/src/ddl/subscribe_commands.rs index c5c70afe..3909a831 100644 --- a/backend/crates/kalamdb-dialect/src/ddl/subscribe_commands.rs +++ b/backend/crates/kalamdb-dialect/src/ddl/subscribe_commands.rs @@ -24,7 +24,7 @@ //! SUBSCRIBE TO app.messages WHERE user_id = CURRENT_USER() OPTIONS (last_rows=10); //! //! -- With multiple options -//! SUBSCRIBE TO app.messages OPTIONS (last_rows=100, batch_size=50); +//! SUBSCRIBE TO app.messages OPTIONS (last_rows=50, batch_size=50); //! //! -- Resume from specific sequence ID //! SUBSCRIBE TO app.messages OPTIONS (from=12345); @@ -603,12 +603,12 @@ mod tests { use kalamdb_commons::ids::SeqId; let stmt = SubscribeStatement::parse( - "SUBSCRIBE TO app.messages OPTIONS (last_rows=100, batch_size=50, from=999)", + "SUBSCRIBE TO app.messages OPTIONS (last_rows=50, batch_size=50, from=999)", ) .unwrap(); assert_eq!(stmt.namespace, NamespaceId::from("app")); assert_eq!(stmt.table_name, TableName::from("messages")); - assert_eq!(stmt.options.last_rows, Some(100)); + assert_eq!(stmt.options.last_rows, Some(50)); assert_eq!(stmt.options.batch_size, Some(50)); assert_eq!(stmt.options.from, Some(SeqId::new(999))); } diff --git a/backend/crates/kalamdb-live/src/helpers/initial_data.rs b/backend/crates/kalamdb-live/src/helpers/initial_data.rs index a7efe8b3..b7c92be2 100644 --- a/backend/crates/kalamdb-live/src/helpers/initial_data.rs +++ b/backend/crates/kalamdb-live/src/helpers/initial_data.rs @@ -6,7 +6,10 @@ use std::{collections::BTreeMap, fmt::Write, sync::Arc}; -use datafusion::arrow::array::{Array, Int64Array}; +use datafusion::arrow::{ + array::{Array, Int64Array}, + record_batch::RecordBatch, +}; use datafusion_common::ScalarValue; use kalamdb_commons::{ constants::SystemColumnNames, @@ -159,6 +162,14 @@ pub struct InitialDataFetcher { sql_executor: Arc>>, } +#[derive(Debug, Clone, Copy)] +struct TableCapabilities { + has_commit_seq: bool, + has_deleted: bool, +} + +const BLOCKING_MATERIALIZATION_ROW_THRESHOLD: usize = 4_096; + impl InitialDataFetcher { /// Create a new initial data fetcher. /// @@ -225,7 +236,8 @@ impl InitialDataFetcher { // Build SELECT clause: either specific columns or * // Always include _seq column for pagination, even if not in projections - let has_commit_seq = self.table_has_column(table_id, SystemColumnNames::COMMIT_SEQ)?; + let table_capabilities = self.table_capabilities(table_id)?; + let has_commit_seq = table_capabilities.has_commit_seq; let select_clause = if let Some(cols) = projections { // Ensure system resume columns are always included for pagination tracking. let mut columns = cols.to_vec(); @@ -243,7 +255,7 @@ impl InitialDataFetcher { let mut sql = format!("SELECT {} FROM {}", select_clause, table_name); let where_clauses = - self.build_where_clauses(table_id, table_type, &options, where_clause)?; + self.build_where_clauses(table_type, &options, where_clause, table_capabilities); if !where_clauses.is_empty() { sql.push_str(" WHERE "); @@ -275,83 +287,15 @@ impl InitialDataFetcher { .execute_for_batches(&sql, user_id, role, ReadContext::Internal) .await?; - // Convert batches to Rows - // Pre-allocate with limit+1 since that's the max we'll fetch - let mut rows_with_seq: Vec<(SeqId, Option, Row)> = Vec::with_capacity(limit + 1); - - for batch in batches { - let schema = batch.schema(); - let seq_col_idx = schema.index_of(SystemColumnNames::SEQ).map_err(|_| { - LiveError::Other(format!("Result missing {} column", SystemColumnNames::SEQ)) - })?; - - let seq_col = batch.column(seq_col_idx); - let seq_array = seq_col - .as_any() - .downcast_ref::() - .ok_or_else(|| { - LiveError::Other(format!("{} column is not Int64", SystemColumnNames::SEQ)) - })?; - let commit_seq_array = if has_commit_seq { - let commit_idx = schema.index_of(SystemColumnNames::COMMIT_SEQ).map_err(|_| { - LiveError::Other(format!( - "Result missing {} column", - SystemColumnNames::COMMIT_SEQ - )) - })?; - Some(batch.column(commit_idx)) - } else { - None - }; - - let num_rows = batch.num_rows(); - let num_cols = batch.num_columns(); - - for row_idx in 0..num_rows { - let mut row_map = BTreeMap::new(); - for col_idx in 0..num_cols { - let col_name = schema.field(col_idx).name(); - if col_name == SystemColumnNames::COMMIT_SEQ { - continue; - } - let col_array = batch.column(col_idx); - let value = ScalarValue::try_from_array(col_array, row_idx) - .into_serialization_error("Failed to convert to ScalarValue")?; - row_map.insert(col_name.clone(), value); - } - - let seq_val = seq_array.value(row_idx); - let seq_id = SeqId::from(seq_val); - let commit_seq = commit_seq_array - .as_ref() - .and_then(|array| ScalarValue::try_from_array(array, row_idx).ok()) - .and_then(|value| match value { - ScalarValue::UInt64(Some(commit_seq)) => Some(commit_seq), - ScalarValue::Int64(Some(commit_seq)) if commit_seq >= 0 => { - Some(commit_seq as u64) - }, - _ => None, - }); - rows_with_seq.push((seq_id, commit_seq, Row::new(row_map))); - } - } - - if has_commit_seq && options.since_commit_seq.is_some() { - rows_with_seq - .sort_unstable_by_key(|(seq_id, commit_seq, _)| (commit_seq.unwrap_or(0), *seq_id)); - } else { - rows_with_seq.sort_unstable_by_key(|(seq_id, _, _)| *seq_id); - } - if options.fetch_last { - rows_with_seq.reverse(); - } + let mut rows_with_seq = materialize_initial_rows(batches, has_commit_seq, limit + 1).await?; // Determine has_more and slice to limit let total_fetched = rows_with_seq.len(); - let has_more = total_fetched > limit; + let over_limit = total_fetched > limit; + let has_more = !options.fetch_last && over_limit; // Truncate in-place instead of collecting into a new Vec - if has_more { + if over_limit { rows_with_seq.truncate(limit); } @@ -396,6 +340,11 @@ impl InitialDataFetcher { options: &InitialDataOptions, where_clause: Option<&str>, ) -> Result, LiveError> { + let table_capabilities = self.table_capabilities(table_id)?; + if table_capabilities.has_commit_seq { + return Ok(None); + } + self.compute_snapshot_end_seq_sql_fallback( live_id, role, @@ -403,6 +352,7 @@ impl InitialDataFetcher { table_type, options, where_clause, + table_capabilities, ) .await } @@ -418,7 +368,8 @@ impl InitialDataFetcher { options: &InitialDataOptions, where_clause: Option<&str>, ) -> Result, LiveError> { - if !self.table_has_column(table_id, SystemColumnNames::COMMIT_SEQ)? { + let table_capabilities = self.table_capabilities(table_id)?; + if !table_capabilities.has_commit_seq { return Ok(None); } @@ -431,7 +382,7 @@ impl InitialDataFetcher { ); let where_clauses = - self.build_where_clauses(table_id, table_type, options, where_clause)?; + self.build_where_clauses(table_type, options, where_clause, table_capabilities); if !where_clauses.is_empty() { sql.push_str(" WHERE "); sql.push_str(&where_clauses.join(" AND ")); @@ -466,6 +417,7 @@ impl InitialDataFetcher { table_type: TableType, options: &InitialDataOptions, where_clause: Option<&str>, + table_capabilities: TableCapabilities, ) -> Result, LiveError> { let user_id = live_id.user_id().clone(); @@ -474,7 +426,7 @@ impl InitialDataFetcher { format!("SELECT MAX({}) AS max_seq FROM {}", SystemColumnNames::SEQ, table_name); let where_clauses = - self.build_where_clauses(table_id, table_type, options, where_clause)?; + self.build_where_clauses(table_type, options, where_clause, table_capabilities); if !where_clauses.is_empty() { sql.push_str(" WHERE "); sql.push_str(&where_clauses.join(" AND ")); @@ -505,16 +457,14 @@ impl InitialDataFetcher { fn build_where_clauses( &self, - table_id: &TableId, table_type: TableType, options: &InitialDataOptions, where_clause: Option<&str>, - ) -> Result, LiveError> { + table_capabilities: TableCapabilities, + ) -> Vec { let mut where_clauses = Vec::new(); - let has_commit_seq = self.table_has_column(table_id, SystemColumnNames::COMMIT_SEQ)?; - - if has_commit_seq { + if table_capabilities.has_commit_seq { match (options.since_commit_seq, options.since_seq) { (Some(since_commit), Some(since_seq)) => where_clauses.push(format!( "({commit_col} > {since_commit} OR ({commit_col} = {since_commit} AND \ @@ -556,7 +506,7 @@ impl InitialDataFetcher { if !options.include_deleted && matches!(table_type, TableType::User | TableType::Shared) - && self.table_has_column(table_id, SystemColumnNames::DELETED)? + && table_capabilities.has_deleted { where_clauses.push(format!("{} = false", SystemColumnNames::DELETED)); } @@ -565,19 +515,107 @@ impl InitialDataFetcher { where_clauses.push(where_sql.to_string()); } - Ok(where_clauses) + where_clauses } - fn table_has_column(&self, table_id: &TableId, column_name: &str) -> Result { + fn table_capabilities(&self, table_id: &TableId) -> Result { let schema = self.schema_lookup.get_arrow_schema(table_id)?; - Ok(schema.field_with_name(column_name).is_ok()) + Ok(TableCapabilities { + has_commit_seq: schema.field_with_name(SystemColumnNames::COMMIT_SEQ).is_ok(), + has_deleted: schema.field_with_name(SystemColumnNames::DELETED).is_ok(), + }) + } +} + +async fn materialize_initial_rows( + batches: Vec, + has_commit_seq: bool, + capacity_hint: usize, +) -> Result, Row)>, LiveError> { + let row_count = batches.iter().map(|batch| batch.num_rows()).sum::(); + if row_count <= BLOCKING_MATERIALIZATION_ROW_THRESHOLD { + return materialize_initial_rows_sync(batches, has_commit_seq, capacity_hint); + } + + tokio::task::spawn_blocking(move || { + materialize_initial_rows_sync(batches, has_commit_seq, capacity_hint) + }) + .await + .map_err(|err| LiveError::Other(format!("initial data materialization task failed: {err}")))? +} + +fn materialize_initial_rows_sync( + batches: Vec, + has_commit_seq: bool, + capacity_hint: usize, +) -> Result, Row)>, LiveError> { + let mut rows_with_seq = Vec::with_capacity(capacity_hint); + + for batch in batches { + let schema = batch.schema(); + let seq_col_idx = schema.index_of(SystemColumnNames::SEQ).map_err(|_| { + LiveError::Other(format!("Result missing {} column", SystemColumnNames::SEQ)) + })?; + + let seq_col = batch.column(seq_col_idx); + let seq_array = seq_col.as_any().downcast_ref::().ok_or_else(|| { + LiveError::Other(format!("{} column is not Int64", SystemColumnNames::SEQ)) + })?; + let commit_seq_array = if has_commit_seq { + let commit_idx = schema.index_of(SystemColumnNames::COMMIT_SEQ).map_err(|_| { + LiveError::Other(format!( + "Result missing {} column", + SystemColumnNames::COMMIT_SEQ + )) + })?; + Some(batch.column(commit_idx)) + } else { + None + }; + + let num_rows = batch.num_rows(); + let num_cols = batch.num_columns(); + + for row_idx in 0..num_rows { + let mut row_map = BTreeMap::new(); + for col_idx in 0..num_cols { + let col_name = schema.field(col_idx).name(); + if col_name == SystemColumnNames::COMMIT_SEQ { + continue; + } + let col_array = batch.column(col_idx); + let value = ScalarValue::try_from_array(col_array, row_idx) + .into_serialization_error("Failed to convert to ScalarValue")?; + row_map.insert(col_name.clone(), value); + } + + let seq_id = SeqId::from(seq_array.value(row_idx)); + let commit_seq = commit_seq_array + .as_ref() + .and_then(|array| ScalarValue::try_from_array(array, row_idx).ok()) + .and_then(|value| match value { + ScalarValue::UInt64(Some(commit_seq)) => Some(commit_seq), + ScalarValue::Int64(Some(commit_seq)) if commit_seq >= 0 => { + Some(commit_seq as u64) + }, + _ => None, + }); + rows_with_seq.push((seq_id, commit_seq, Row::new(row_map))); + } } + + Ok(rows_with_seq) } #[cfg(test)] mod tests { + use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc as StdArc, + }; + use arrow::{ - array::Int64Array, + array::{Int64Array, UInt64Array}, datatypes::{DataType, Field, Schema}, }; use async_trait::async_trait; @@ -602,6 +640,21 @@ mod tests { } } + struct CommitSeqSchemaLookup; + + impl LiveSchemaLookup for CommitSeqSchemaLookup { + fn get_table_definition(&self, _table_id: &TableId) -> Option> { + None + } + + fn get_arrow_schema(&self, _table_id: &TableId) -> Result, LiveError> { + Ok(Arc::new(Schema::new(vec![ + Field::new(SystemColumnNames::SEQ, DataType::Int64, false), + Field::new(SystemColumnNames::COMMIT_SEQ, DataType::UInt64, false), + ]))) + } + } + struct MaxSeqExecutor { seen_sql: Mutex>, } @@ -623,6 +676,76 @@ mod tests { } } + struct CountingExecutor { + calls: AtomicUsize, + } + + #[async_trait] + impl LiveSqlExecutor for CountingExecutor { + async fn execute_for_batches( + &self, + _sql: &str, + _user_id: kalamdb_commons::models::UserId, + _role: Role, + _read_context: ReadContext, + ) -> Result, LiveError> { + self.calls.fetch_add(1, Ordering::Relaxed); + Ok(Vec::new()) + } + } + + struct CountingSchemaLookup { + schema: Arc, + calls: StdArc, + } + + impl LiveSchemaLookup for CountingSchemaLookup { + fn get_table_definition(&self, _table_id: &TableId) -> Option> { + None + } + + fn get_arrow_schema(&self, _table_id: &TableId) -> Result, LiveError> { + self.calls.fetch_add(1, Ordering::Relaxed); + Ok(Arc::clone(&self.schema)) + } + } + + struct StaticBatchExecutor { + batches: Vec, + } + + #[async_trait] + impl LiveSqlExecutor for StaticBatchExecutor { + async fn execute_for_batches( + &self, + _sql: &str, + _user_id: kalamdb_commons::models::UserId, + _role: Role, + _read_context: ReadContext, + ) -> Result, LiveError> { + Ok(self.batches.clone()) + } + } + + struct CaptureFetchExecutor { + seen_sql: Mutex>, + batches: Vec, + } + + #[async_trait] + impl LiveSqlExecutor for CaptureFetchExecutor { + async fn execute_for_batches( + &self, + sql: &str, + _user_id: kalamdb_commons::models::UserId, + _role: Role, + _read_context: ReadContext, + ) -> Result, LiveError> { + self.seen_sql.lock().push(sql.to_string()); + Ok(self.batches.clone()) + } + } + #[test] fn test_initial_data_options_default() { let options = InitialDataOptions::default(); @@ -695,4 +818,481 @@ mod tests { Some("SELECT MAX(_seq) AS max_seq FROM app.items") ); } + + #[tokio::test] + async fn snapshot_seq_boundary_skips_sql_when_commit_seq_is_available() { + let fetcher = InitialDataFetcher::new(Arc::new(CommitSeqSchemaLookup)); + let executor = Arc::new(CountingExecutor { + calls: AtomicUsize::new(0), + }); + fetcher.set_sql_executor(executor.clone()); + + let table_id = TableId::new(NamespaceId::from("app"), TableName::from("items")); + let live_id = LiveQueryId::new( + UserId::new("u1"), + kalamdb_commons::models::ConnectionId::new("c1"), + "sub1".to_string(), + ); + + let boundary = fetcher + .compute_snapshot_end_seq( + &live_id, + Role::User, + &table_id, + TableType::User, + &InitialDataOptions::default(), + None, + ) + .await + .expect("snapshot boundary check"); + + assert_eq!(boundary, None); + assert_eq!(executor.calls.load(Ordering::Relaxed), 0); + } + + #[tokio::test] + async fn fetch_initial_data_reads_schema_once_for_system_column_capabilities() { + let schema_calls = StdArc::new(AtomicUsize::new(0)); + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new(SystemColumnNames::SEQ, DataType::Int64, false), + Field::new(SystemColumnNames::COMMIT_SEQ, DataType::UInt64, false), + Field::new(SystemColumnNames::DELETED, DataType::Boolean, false), + ])); + let fetcher = InitialDataFetcher::new(Arc::new(CountingSchemaLookup { + schema, + calls: Arc::clone(&schema_calls), + })); + fetcher.set_sql_executor(Arc::new(CountingExecutor { + calls: AtomicUsize::new(0), + })); + + let table_id = TableId::new(NamespaceId::from("app"), TableName::from("items")); + let live_id = LiveQueryId::new( + UserId::new("u1"), + kalamdb_commons::models::ConnectionId::new("c1"), + "sub1".to_string(), + ); + + let result = fetcher + .fetch_initial_data( + &live_id, + Role::User, + &table_id, + TableType::User, + InitialDataOptions::batch(None, Some(SeqId::from(10)), 100), + None, + Some(&["id".to_string()]), + ) + .await + .expect("initial data fetch"); + + assert!(result.rows.is_empty()); + assert_eq!(schema_calls.load(Ordering::Relaxed), 1); + } + + #[tokio::test] + async fn fetch_initial_data_builds_seq_window_sql() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new(SystemColumnNames::SEQ, DataType::Int64, false), + ])); + let executor = Arc::new(CaptureFetchExecutor { + seen_sql: Mutex::new(Vec::new()), + batches: Vec::new(), + }); + let fetcher = InitialDataFetcher::new(Arc::new(CountingSchemaLookup { + schema, + calls: StdArc::new(AtomicUsize::new(0)), + })); + fetcher.set_sql_executor(executor.clone()); + + let table_id = TableId::new(NamespaceId::from("app"), TableName::from("items")); + let live_id = LiveQueryId::new( + UserId::new("u1"), + kalamdb_commons::models::ConnectionId::new("c1"), + "sub1".to_string(), + ); + + let result = fetcher + .fetch_initial_data( + &live_id, + Role::User, + &table_id, + TableType::User, + InitialDataOptions::batch(Some(SeqId::from(10)), Some(SeqId::from(40)), 2), + Some("id > 0"), + Some(&["id".to_string()]), + ) + .await + .expect("initial data fetch"); + + assert!(result.rows.is_empty()); + assert_eq!( + executor.seen_sql.lock().as_slice(), + ["SELECT id, _seq FROM app.items WHERE _seq > 10 AND _seq <= 40 AND id > 0 ORDER BY _seq ASC LIMIT 3"], + ); + } + + #[tokio::test] + async fn fetch_initial_data_builds_commit_seq_resume_sql() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new(SystemColumnNames::SEQ, DataType::Int64, false), + Field::new(SystemColumnNames::COMMIT_SEQ, DataType::UInt64, false), + ])); + let executor = Arc::new(CaptureFetchExecutor { + seen_sql: Mutex::new(Vec::new()), + batches: Vec::new(), + }); + let fetcher = InitialDataFetcher::new(Arc::new(CountingSchemaLookup { + schema, + calls: StdArc::new(AtomicUsize::new(0)), + })); + fetcher.set_sql_executor(executor.clone()); + + let table_id = TableId::new(NamespaceId::from("app"), TableName::from("items")); + let live_id = LiveQueryId::new( + UserId::new("u1"), + kalamdb_commons::models::ConnectionId::new("c1"), + "sub1".to_string(), + ); + + let result = fetcher + .fetch_initial_data( + &live_id, + Role::User, + &table_id, + TableType::User, + InitialDataOptions::batch(Some(SeqId::from(10)), None, 2) + .with_commit_range(Some(7), Some(9)), + None, + Some(&["id".to_string()]), + ) + .await + .expect("initial data fetch"); + + assert!(result.rows.is_empty()); + assert_eq!( + executor.seen_sql.lock().as_slice(), + ["SELECT id, _seq, _commit_seq FROM app.items WHERE (_commit_seq > 7 OR (_commit_seq = 7 AND _seq > 10)) AND _commit_seq <= 9 ORDER BY _commit_seq ASC, _seq ASC LIMIT 3"], + ); + } + + #[tokio::test] + async fn materialize_initial_rows_extracts_resume_columns_and_rows() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new(SystemColumnNames::SEQ, DataType::Int64, false), + Field::new(SystemColumnNames::COMMIT_SEQ, DataType::UInt64, false), + ])); + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(Int64Array::from(vec![10, 20])), + Arc::new(Int64Array::from(vec![100, 200])), + Arc::new(UInt64Array::from(vec![7, 8])), + ], + ) + .expect("record batch"); + + let rows = materialize_initial_rows(vec![batch], true, 3) + .await + .expect("materialized rows"); + + assert_eq!(rows.len(), 2); + assert_eq!(rows[0].0, SeqId::from(100)); + assert_eq!(rows[0].1, Some(7)); + assert_eq!(rows[1].0, SeqId::from(200)); + assert_eq!(rows[1].1, Some(8)); + assert!(rows[0].2.values.contains_key("id")); + assert!(!rows[0].2.values.contains_key(SystemColumnNames::COMMIT_SEQ)); + } + + #[tokio::test] + async fn fetch_initial_data_preserves_executor_row_order() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new(SystemColumnNames::SEQ, DataType::Int64, false), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(vec![2, 1])), + Arc::new(Int64Array::from(vec![20, 10])), + ], + ) + .expect("record batch"); + + let fetcher = InitialDataFetcher::new(Arc::new(CountingSchemaLookup { + schema, + calls: StdArc::new(AtomicUsize::new(0)), + })); + fetcher.set_sql_executor(Arc::new(StaticBatchExecutor { + batches: vec![batch], + })); + + let table_id = TableId::new(NamespaceId::from("app"), TableName::from("items")); + let live_id = LiveQueryId::new( + UserId::new("u1"), + kalamdb_commons::models::ConnectionId::new("c1"), + "sub1".to_string(), + ); + + let result = fetcher + .fetch_initial_data( + &live_id, + Role::User, + &table_id, + TableType::User, + InitialDataOptions::batch(None, Some(SeqId::from(20)), 10), + None, + Some(&["id".to_string()]), + ) + .await + .expect("initial data fetch"); + + let ids: Vec = result + .rows + .iter() + .map(|row| match row.values.get("id") { + Some(ScalarValue::Int64(Some(value))) => *value, + other => panic!("unexpected id value: {other:?}"), + }) + .collect(); + + assert_eq!(ids, vec![2, 1]); + } + + #[tokio::test] + async fn fetch_initial_data_exact_batch_boundary_is_ready() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new(SystemColumnNames::SEQ, DataType::Int64, false), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(vec![1, 2])), + Arc::new(Int64Array::from(vec![10, 20])), + ], + ) + .expect("record batch"); + + let fetcher = InitialDataFetcher::new(Arc::new(CountingSchemaLookup { + schema, + calls: StdArc::new(AtomicUsize::new(0)), + })); + fetcher.set_sql_executor(Arc::new(StaticBatchExecutor { + batches: vec![batch], + })); + + let table_id = TableId::new(NamespaceId::from("app"), TableName::from("items")); + let live_id = LiveQueryId::new( + UserId::new("u1"), + kalamdb_commons::models::ConnectionId::new("c1"), + "sub1".to_string(), + ); + + let result = fetcher + .fetch_initial_data( + &live_id, + Role::User, + &table_id, + TableType::User, + InitialDataOptions::batch(None, Some(SeqId::from(20)), 2), + None, + Some(&["id".to_string()]), + ) + .await + .expect("initial data fetch"); + + let ids: Vec = result + .rows + .iter() + .map(|row| match row.values.get("id") { + Some(ScalarValue::Int64(Some(value))) => *value, + other => panic!("unexpected id value: {other:?}"), + }) + .collect(); + + assert_eq!(ids, vec![1, 2]); + assert!(!result.has_more); + assert_eq!(result.last_seq, Some(SeqId::from(20))); + assert_eq!(result.snapshot_end_seq, Some(SeqId::from(20))); + } + + #[tokio::test] + async fn fetch_initial_data_resume_batch_preserves_cursor_state() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new(SystemColumnNames::SEQ, DataType::Int64, false), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(vec![2, 3, 4])), + Arc::new(Int64Array::from(vec![20, 30, 40])), + ], + ) + .expect("record batch"); + + let fetcher = InitialDataFetcher::new(Arc::new(CountingSchemaLookup { + schema, + calls: StdArc::new(AtomicUsize::new(0)), + })); + fetcher.set_sql_executor(Arc::new(StaticBatchExecutor { + batches: vec![batch], + })); + + let table_id = TableId::new(NamespaceId::from("app"), TableName::from("items")); + let live_id = LiveQueryId::new( + UserId::new("u1"), + kalamdb_commons::models::ConnectionId::new("c1"), + "sub1".to_string(), + ); + + let result = fetcher + .fetch_initial_data( + &live_id, + Role::User, + &table_id, + TableType::User, + InitialDataOptions::batch(Some(SeqId::from(10)), Some(SeqId::from(40)), 2), + None, + Some(&["id".to_string()]), + ) + .await + .expect("initial data fetch"); + + let ids: Vec = result + .rows + .iter() + .map(|row| match row.values.get("id") { + Some(ScalarValue::Int64(Some(value))) => *value, + other => panic!("unexpected id value: {other:?}"), + }) + .collect(); + + assert_eq!(ids, vec![2, 3]); + assert!(result.has_more); + assert_eq!(result.last_seq, Some(SeqId::from(30))); + assert_eq!(result.snapshot_end_seq, Some(SeqId::from(40))); + } + + #[tokio::test] + async fn fetch_last_rows_does_not_paginate_older_history() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new(SystemColumnNames::SEQ, DataType::Int64, false), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(vec![6, 5, 4, 3, 2, 1])), + Arc::new(Int64Array::from(vec![60, 50, 40, 30, 20, 10])), + ], + ) + .expect("record batch"); + + let fetcher = InitialDataFetcher::new(Arc::new(CountingSchemaLookup { + schema, + calls: StdArc::new(AtomicUsize::new(0)), + })); + fetcher.set_sql_executor(Arc::new(StaticBatchExecutor { + batches: vec![batch], + })); + + let table_id = TableId::new(NamespaceId::from("app"), TableName::from("items")); + let live_id = LiveQueryId::new( + UserId::new("u1"), + kalamdb_commons::models::ConnectionId::new("c1"), + "sub1".to_string(), + ); + + let result = fetcher + .fetch_initial_data( + &live_id, + Role::User, + &table_id, + TableType::User, + InitialDataOptions::last(5), + None, + Some(&["id".to_string()]), + ) + .await + .expect("initial data fetch"); + + let ids: Vec = result + .rows + .iter() + .map(|row| match row.values.get("id") { + Some(ScalarValue::Int64(Some(value))) => *value, + other => panic!("unexpected id value: {other:?}"), + }) + .collect(); + + assert_eq!(ids, vec![2, 3, 4, 5, 6]); + assert!(!result.has_more); + assert_eq!(result.last_seq, Some(SeqId::from(60))); + assert_eq!(result.snapshot_end_seq, Some(SeqId::from(60))); + } + + #[tokio::test] + async fn fetch_last_rows_returns_all_when_row_count_is_below_limit() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new(SystemColumnNames::SEQ, DataType::Int64, false), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int64Array::from(vec![2, 1])), + Arc::new(Int64Array::from(vec![20, 10])), + ], + ) + .expect("record batch"); + + let fetcher = InitialDataFetcher::new(Arc::new(CountingSchemaLookup { + schema, + calls: StdArc::new(AtomicUsize::new(0)), + })); + fetcher.set_sql_executor(Arc::new(StaticBatchExecutor { + batches: vec![batch], + })); + + let table_id = TableId::new(NamespaceId::from("app"), TableName::from("items")); + let live_id = LiveQueryId::new( + UserId::new("u1"), + kalamdb_commons::models::ConnectionId::new("c1"), + "sub1".to_string(), + ); + + let result = fetcher + .fetch_initial_data( + &live_id, + Role::User, + &table_id, + TableType::User, + InitialDataOptions::last(5), + None, + Some(&["id".to_string()]), + ) + .await + .expect("initial data fetch"); + + let ids: Vec = result + .rows + .iter() + .map(|row| match row.values.get("id") { + Some(ScalarValue::Int64(Some(value))) => *value, + other => panic!("unexpected id value: {other:?}"), + }) + .collect(); + + assert_eq!(ids, vec![1, 2]); + assert!(!result.has_more); + assert_eq!(result.last_seq, Some(SeqId::from(20))); + assert_eq!(result.snapshot_end_seq, Some(SeqId::from(20))); + } } diff --git a/backend/crates/kalamdb-live/src/models/connection.rs b/backend/crates/kalamdb-live/src/models/connection.rs index 48a0b4fb..0e8cd802 100644 --- a/backend/crates/kalamdb-live/src/models/connection.rs +++ b/backend/crates/kalamdb-live/src/models/connection.rs @@ -4,10 +4,9 @@ use std::{ collections::{HashMap, VecDeque}, - hash::{DefaultHasher, Hash, Hasher}, sync::{ atomic::{AtomicBool, AtomicI64, AtomicU64, Ordering}, - Arc, OnceLock, Weak, + Arc, OnceLock, }, time::{Instant, SystemTime, UNIX_EPOCH}, }; @@ -28,54 +27,6 @@ pub(crate) fn epoch_millis() -> u64 { SystemTime::now().duration_since(UNIX_EPOCH).unwrap_or_default().as_millis() as u64 } -#[derive(Default)] -struct SubscriptionStringPool { - buckets: HashMap>>, -} - -impl SubscriptionStringPool { - fn intern(&mut self, value: &str) -> Arc { - let hash = hash_subscription_str(value); - let bucket = self.buckets.entry(hash).or_default(); - let mut dead_entries = 0usize; - - for weak in bucket.iter() { - match weak.upgrade() { - Some(existing) if existing.as_ref() == value => return existing, - Some(_) => {}, - None => dead_entries += 1, - } - } - - if dead_entries > 0 { - bucket.retain(|weak| weak.strong_count() > 0); - } - - let interned: Arc = Arc::from(value); - bucket.push(Arc::downgrade(&interned)); - interned - } -} - -fn hash_subscription_str(value: &str) -> u64 { - let mut hasher = DefaultHasher::new(); - value.hash(&mut hasher); - hasher.finish() -} - -fn subscription_string_pool() -> &'static Mutex { - static STRING_POOL: OnceLock> = OnceLock::new(); - STRING_POOL.get_or_init(|| Mutex::new(SubscriptionStringPool::default())) -} - -fn intern_subscription_str(value: &str) -> Arc { - if value.is_empty() { - return Arc::from(""); - } - - subscription_string_pool().lock().intern(value) -} - /// Maximum live-query subscriptions allowed on a single WebSocket connection. pub const MAX_SUBSCRIPTIONS_PER_CONNECTION: usize = 100; @@ -159,8 +110,8 @@ pub struct SubscriptionRuntimeMetadata { impl SubscriptionRuntimeMetadata { pub fn new(query: &str, options_json: Option<&str>, created_at_ms: i64) -> Self { Self { - query: intern_subscription_str(query), - options_json: options_json.map(intern_subscription_str), + query: Arc::from(query), + options_json: options_json.map(Arc::from), created_at_ms, last_update_ms: AtomicI64::new(created_at_ms), changes: AtomicI64::new(0), @@ -766,7 +717,7 @@ mod tests { } #[test] - fn test_subscription_runtime_metadata_interns_query_and_options() { + fn test_subscription_runtime_metadata_owns_query_and_options() { let first = SubscriptionRuntimeMetadata::new( "SELECT * FROM shared.events", Some(r#"{"batch_size":100}"#), @@ -778,8 +729,10 @@ mod tests { 2, ); - assert!(Arc::ptr_eq(&first.query, &second.query)); - assert!(Arc::ptr_eq( + assert_eq!(first.query(), second.query()); + assert_eq!(first.options_json(), second.options_json()); + assert!(!Arc::ptr_eq(&first.query, &second.query)); + assert!(!Arc::ptr_eq( first.options_json.as_ref().expect("options should exist"), second.options_json.as_ref().expect("options should exist"), )); diff --git a/backend/crates/kalamdb-live/src/notification.rs b/backend/crates/kalamdb-live/src/notification.rs index 4806e87e..d854de76 100644 --- a/backend/crates/kalamdb-live/src/notification.rs +++ b/backend/crates/kalamdb-live/src/notification.rs @@ -468,6 +468,11 @@ impl NotificationService { change_notification: ChangeNotification, all_handles: Arc>, ) -> Result { + let handle_count = all_handles.len(); + if handle_count == 0 { + return Ok(0); + } + let seq_value = extract_seq(&change_notification); let commit_seq = extract_commit_seq(&change_notification); let delivery_timestamp_ms = epoch_millis(); @@ -476,14 +481,29 @@ impl NotificationService { let new_row = Arc::new(change_notification.row_data); let old_row = change_notification.old_data.map(Arc::new); - let handle_count = all_handles.len(); - if handle_count == 0 { - return Ok(0); + if handle_count == 1 { + let Some(handle) = all_handles.iter().next().map(|entry| entry.value().clone()) else { + return Ok(0); + }; + + return dispatch_one( + handle, + &new_row, + old_row.as_deref(), + &change_type, + &pk_columns, + seq_value, + commit_seq, + delivery_timestamp_ms, + ); } // Small fan-out: inline dispatch directly from DashMap refs (no clone/spawn overhead) if handle_count <= SHARED_NOTIFY_CHUNK_SIZE { - let chunk_handles = all_handles.iter().map(|entry| entry.value().clone()).collect(); + let chunk_handles = all_handles + .iter() + .map(|entry| entry.value().clone()) + .collect::>(); return dispatch_chunk( chunk_handles, &new_row, @@ -666,6 +686,52 @@ fn dispatch_chunk( Ok(count) } +fn dispatch_one( + handle: SubscriptionHandle, + new_row: &Row, + old_row: Option<&Row>, + change_type: &ChangeType, + pk_columns: &[String], + seq_value: Option, + commit_seq: Option, + delivery_timestamp_ms: u64, +) -> Result { + if let Some(ref filter_expr) = handle.filter_expr { + match filter_matches(filter_expr, new_row) { + Ok(true) => {}, + Ok(false) => return Ok(0), + Err(e) => { + log::error!( + "Filter error for subscription_id={}: {}", + handle.subscription_id, + e + ); + return Ok(0); + }, + } + } + + let payload = Arc::new(build_shared_payload( + change_type, + new_row, + old_row, + pk_columns, + &handle.projections, + )?); + let notification = Arc::new(WireNotification { + subscription_id: Arc::clone(&handle.subscription_id), + payload, + }); + + Ok(usize::from(try_deliver( + &handle, + notification, + seq_value, + commit_seq, + delivery_timestamp_ms, + ))) +} + impl NotificationServiceTrait for NotificationService { type Notification = ChangeNotification; diff --git a/benchv2/Cargo.lock b/benchv2/Cargo.lock index 9595facf..5ea937f3 100644 --- a/benchv2/Cargo.lock +++ b/benchv2/Cargo.lock @@ -1363,9 +1363,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.13.2" +version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab3f43e3283ab1488b624b44b0e988d0acea0b3214e694730a055cb6b2efa801" +checksum = "62e0021ea2c22aed41653bc7e1419abb2c97e038ff2c33d0e1309e49a97deec0" dependencies = [ "base64", "bytes", @@ -1445,9 +1445,9 @@ checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustls" -version = "0.23.37" +version = "0.23.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" +checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b" dependencies = [ "aws-lc-rs", "once_cell", diff --git a/benchv2/run-benchmarks.sh b/benchv2/run-benchmarks.sh index f212a4ec..b8e2986c 100755 --- a/benchv2/run-benchmarks.sh +++ b/benchv2/run-benchmarks.sh @@ -20,6 +20,7 @@ URLS="${KALAMDB_URLS:-${KALAMDB_URL:-http://127.0.0.1:8080}}" USER="${KALAMDB_USER:-admin}" PASSWORD="${KALAMDB_PASSWORD:-kalamdb123}" MAX_SUBSCRIBERS="${KALAMDB_MAX_SUBSCRIBERS:-}" +OUTPUT_DIR="results" EXTRA_ARGS=() BENCH_SERVER_PID="" @@ -190,6 +191,7 @@ while [[ $# -gt 0 ]]; do --user) USER="$2"; shift 2;; --password) PASSWORD="$2"; shift 2;; --max-subscribers) MAX_SUBSCRIBERS="$2"; shift 2;; + --output-dir) OUTPUT_DIR="$2"; EXTRA_ARGS+=("$1" "$2"); shift 2;; *) EXTRA_ARGS+=("$1"); shift;; esac done @@ -240,4 +242,4 @@ if [[ "$MANAGED_SERVER" == "yes" ]]; then fi echo "" -echo "📊 Results directory: $SCRIPT_DIR/results" +echo "📊 Results directory: $SCRIPT_DIR/$OUTPUT_DIR" diff --git a/benchv2/run-chat-realtime.sh b/benchv2/run-chat-realtime.sh index 2388f221..b7ccbbe2 100755 --- a/benchv2/run-chat-realtime.sh +++ b/benchv2/run-chat-realtime.sh @@ -74,7 +74,7 @@ else echo "▸ Conversation message rate: $MESSAGES_PER_MINUTE messages/min" fi -CMD=(./run-benchmarks.sh --urls "$URL" --bench chat_realtime --iterations 1 --warmup 0) +CMD=(./run-benchmarks.sh --urls "$URL" --suite chat-runtime --output-dir results/chat-runtime --bench chat_realtime --iterations 1 --warmup 0) if [[ -n "$BENCH_USER" ]]; then CMD+=(--user "$BENCH_USER") @@ -88,4 +88,4 @@ if (( ${#EXTRA_ARGS[@]} > 0 )); then CMD+=("${EXTRA_ARGS[@]}") fi -exec "${CMD[@]}" \ No newline at end of file +exec "${CMD[@]}" diff --git a/benchv2/src/benchmarks/mod.rs b/benchv2/src/benchmarks/mod.rs index c059649a..73e2a109 100644 --- a/benchv2/src/benchmarks/mod.rs +++ b/benchv2/src/benchmarks/mod.rs @@ -1,5 +1,4 @@ pub mod bulk_insert_bench; -pub mod chat_realtime_bench; pub mod concurrent_bench; pub mod connection_scale_bench; pub mod ddl_bench; @@ -98,16 +97,8 @@ pub trait Benchmark: Send + Sync { } } -/// Benchmarks included when no explicit `--bench` or `--filter` selection is provided. -/// -/// `connection_scale` remains opt-in because single-host runs often need extra -/// loopback aliases or explicit override flags to avoid macOS ephemeral-port limits. -pub fn enabled_in_default_suite(name: &str) -> bool { - !matches!(name, "connection_scale" | "chat_realtime") -} - -/// Returns all registered benchmarks. Add new benchmarks here. -pub fn all_benchmarks() -> Vec> { +/// Returns all standard benchmarks. Add default benchmark-suite cases here. +pub fn standard_benchmarks() -> Vec> { vec![ // --- Core operation benchmarks --- Box::new(ddl_bench::CreateTableBench { @@ -161,9 +152,16 @@ pub fn all_benchmarks() -> Vec> { Box::new(load_connection_storm_bench::ConnectionStormBench), Box::new(load_mixed_rw_bench::MixedReadWriteBench), Box::new(load_wide_fanout_bench::WideFanoutQueryBench), - Box::new(chat_realtime_bench::ChatRealtimeBench), // --- Scale tests (run with --iterations 1 --warmup 0 --filter subscriber_scale) --- Box::new(connection_scale_bench::ConnectionScaleBench), Box::new(subscriber_scale_bench::SubscriberScaleBench::default()), ] } + +/// Benchmarks included when no explicit `--bench` or `--filter` selection is provided. +/// +/// `connection_scale` remains opt-in because single-host runs often need extra +/// loopback aliases or explicit override flags to avoid macOS ephemeral-port limits. +pub fn enabled_in_default_suite(name: &str) -> bool { + !matches!(name, "connection_scale") +} diff --git a/benchv2/src/benchmarks/chat_realtime_bench.rs b/benchv2/src/chat_runtime/chat_realtime_bench.rs similarity index 94% rename from benchv2/src/benchmarks/chat_realtime_bench.rs rename to benchv2/src/chat_runtime/chat_realtime_bench.rs index fff27830..ef101537 100644 --- a/benchv2/src/benchmarks/chat_realtime_bench.rs +++ b/benchv2/src/chat_runtime/chat_realtime_bench.rs @@ -11,9 +11,9 @@ use kalam_client::{ }; use serde_json::Value as JsonValue; use sysinfo::{MemoryRefreshKind, Pid, ProcessRefreshKind, ProcessesToUpdate, RefreshKind, System}; -use tokio::sync::{Mutex as AsyncMutex, Semaphore, watch}; -use tokio::task::JoinSet; +use tokio::sync::{watch, Mutex as AsyncMutex, Semaphore}; use tokio::task::JoinHandle; +use tokio::task::JoinSet; use tokio::time::{sleep, timeout, Instant}; use crate::benchmarks::Benchmark; @@ -65,10 +65,8 @@ impl ChatWorkloadSettings { fn from_env() -> Result { let minutes = parse_u64_env("KALAMDB_BENCH_CHAT_MINUTES", DEFAULT_CHAT_MINUTES)?; let user_count = parse_u32_env("KALAMDB_BENCH_CHAT_USERS", DEFAULT_CHAT_USERS)?; - let realtime_conversations = parse_u32_env( - "KALAMDB_BENCH_CHAT_REALTIME_CONVS", - DEFAULT_CHAT_REALTIME_CONVS, - )?; + let realtime_conversations = + parse_u32_env("KALAMDB_BENCH_CHAT_REALTIME_CONVS", DEFAULT_CHAT_REALTIME_CONVS)?; let messages_per_minute = parse_u32_env( "KALAMDB_BENCH_CHAT_MESSAGES_PER_MINUTE", DEFAULT_CHAT_MESSAGES_PER_MINUTE, @@ -81,9 +79,7 @@ impl ChatWorkloadSettings { return Err("KALAMDB_BENCH_CHAT_USERS must be at least 2".to_string()); } if realtime_conversations == 0 { - return Err( - "KALAMDB_BENCH_CHAT_REALTIME_CONVS must be greater than zero".to_string(), - ); + return Err("KALAMDB_BENCH_CHAT_REALTIME_CONVS must be greater than zero".to_string()); } Ok(Self { @@ -248,12 +244,8 @@ impl Benchmark for ChatRealtimeBench { let conversation_topic = conversation_topic_name(&config.namespace); let message_topic = message_topic_name(&config.namespace); - let _ = client - .sql(&format!("DROP TOPIC IF EXISTS {}", message_topic)) - .await; - let _ = client - .sql(&format!("DROP TOPIC IF EXISTS {}", conversation_topic)) - .await; + let _ = client.sql(&format!("DROP TOPIC IF EXISTS {}", message_topic)).await; + let _ = client.sql(&format!("DROP TOPIC IF EXISTS {}", conversation_topic)).await; let _ = client .sql(&format!("DROP TOPIC IF EXISTS {}", typing_topic_name(&config.namespace))) .await; @@ -268,9 +260,7 @@ impl Benchmark for ChatRealtimeBench { .await; for username in &usernames { - let _ = client - .sql(&format!("DROP USER IF EXISTS {}", sql_literal(username))) - .await; + let _ = client.sql(&format!("DROP USER IF EXISTS {}", sql_literal(username))).await; } run_sql_with_retry( @@ -379,15 +369,12 @@ impl Benchmark for ChatRealtimeBench { global_stop.clone(), iteration, ); - let conversation_ids = Arc::new(AtomicU64::new( - 40_000_000_000 + u64::from(iteration) * 1_000_000, - )); - let message_ids = Arc::new(AtomicU64::new( - 50_000_000_000 + u64::from(iteration) * 10_000_000, - )); - let typing_ids = Arc::new(AtomicU64::new( - 60_000_000_000 + u64::from(iteration) * 10_000_000, - )); + let conversation_ids = + Arc::new(AtomicU64::new(40_000_000_000 + u64::from(iteration) * 1_000_000)); + let message_ids = + Arc::new(AtomicU64::new(50_000_000_000 + u64::from(iteration) * 10_000_000)); + let typing_ids = + Arc::new(AtomicU64::new(60_000_000_000 + u64::from(iteration) * 10_000_000)); let prewarmed_active_users = prewarm_user_clients( user_pool.clone(), @@ -424,17 +411,6 @@ impl Benchmark for ChatRealtimeBench { CHAT_TYPING_INTERVAL.as_secs(), ); - println!( - " Chat workload settings: duration={}m, regular_users={}, target_active_chat_users={}, active_conversations={}, message_rate={}, typing_burst={}x{}s", - settings.minutes, - settings.user_count, - target_active_user_count, - settings.realtime_conversations, - settings.message_rate_label(), - CHAT_TYPING_BURSTS, - CHAT_TYPING_INTERVAL.as_secs(), - ); - for worker_id in 0..settings.realtime_conversations { let namespace = config.namespace.clone(); let worker_stats = stats.clone(); @@ -444,9 +420,7 @@ impl Benchmark for ChatRealtimeBench { let worker_conversations = conversation_ids.clone(); let worker_messages = message_ids.clone(); let worker_typing = typing_ids.clone(); - let worker_start_delay = chat_worker_start_delay( - worker_id, - ); + let worker_start_delay = chat_worker_start_delay(worker_id); let worker_deadline = run_deadline + worker_start_delay; handles.push(tokio::spawn(async move { @@ -475,7 +449,7 @@ impl Benchmark for ChatRealtimeBench { let mut errors = Vec::new(); for handle in handles { match handle.await { - Ok(Ok(())) => {} + Ok(Ok(())) => {}, Ok(Err(error)) => errors.push(error), Err(error) => errors.push(format!("worker join error: {}", error)), } @@ -517,10 +491,7 @@ impl Benchmark for ChatRealtimeBench { )) .await; let _ = client - .sql(&format!( - "DROP TOPIC IF EXISTS {}", - typing_topic_name(&config.namespace) - )) + .sql(&format!("DROP TOPIC IF EXISTS {}", typing_topic_name(&config.namespace))) .await; let _ = client .sql(&format!("DROP STREAM TABLE IF EXISTS {}.typing_events", config.namespace)) @@ -533,9 +504,8 @@ impl Benchmark for ChatRealtimeBench { .await; for username in users { - let _ = client - .sql(&format!("DROP USER IF EXISTS {}", sql_literal(&username))) - .await; + let _ = + client.sql(&format!("DROP USER IF EXISTS {}", sql_literal(&username))).await; } Ok(()) @@ -761,9 +731,7 @@ async fn emit_typing_event( ) .await?; - stats - .insert_typing_event_timings - .record(insert_typing_started.elapsed()); + stats.insert_typing_event_timings.record(insert_typing_started.elapsed()); stats.typing_events_sent.fetch_add(1, Ordering::Relaxed); Ok(()) } @@ -799,9 +767,7 @@ async fn emit_message( ) .await?; - stats - .insert_message_timings - .record(insert_message_started.elapsed()); + stats.insert_message_timings.record(insert_message_started.elapsed()); stats.messages_sent.fetch_add(1, Ordering::Relaxed); Ok(message_id) } @@ -880,10 +846,7 @@ async fn create_subscription( let mut attempts = 0_u32; loop { - let config = SubscriptionConfig::without_initial_data( - subscription_id.clone(), - sql.clone(), - ); + let config = SubscriptionConfig::without_initial_data(subscription_id.clone(), sql.clone()); match timeout(SUBSCRIBE_TIMEOUT, link.subscribe_with_config(config)).await { Ok(Ok(subscription)) => return Ok(subscription), @@ -894,13 +857,13 @@ async fn create_subscription( attempts += 1; sleep(delay).await; delay = (delay * 2).min(Duration::from_secs(5)); - } + }, Ok(Err(error)) => return Err(format!("subscribe error: {}", error)), Err(_) if attempts < CHAT_SUBSCRIBE_RETRY_ATTEMPTS => { attempts += 1; sleep(delay).await; delay = (delay * 2).min(Duration::from_secs(5)); - } + }, Err(_) => return Err("subscription timed out before becoming ready".to_string()), } } @@ -978,13 +941,13 @@ async fn shutdown_subscription_tasks( for mut handle in handles { match timeout(SUBSCRIPTION_SHUTDOWN_TIMEOUT, &mut handle).await { - Ok(Ok(Ok(()))) => {} + Ok(Ok(Ok(()))) => {}, Ok(Ok(Err(error))) => errors.push(error), Ok(Err(error)) => errors.push(format!("subscription join error: {}", error)), Err(_) => { handle.abort(); errors.push("timed out waiting for subscription shutdown".to_string()); - } + }, } } @@ -1039,13 +1002,13 @@ impl ChatTopicForwarder { for mut handle in self.handles { match timeout(SUBSCRIPTION_SHUTDOWN_TIMEOUT, &mut handle).await { - Ok(Ok(Ok(()))) => {} + Ok(Ok(Ok(()))) => {}, Ok(Ok(Err(error))) => errors.push(error), Ok(Err(error)) => errors.push(format!("forwarder join error: {}", error)), Err(_) => { handle.abort(); errors.push("timed out waiting for topic forwarder shutdown".to_string()); - } + }, } } @@ -1099,15 +1062,15 @@ async fn run_conversation_forwarder( let forwarded = match parse_conversation_forward_record(&record.payload) { Ok(row) if row.needs_forward => { forward_conversation(&admin_client, &namespace, &row, &stats).await? - } + }, Ok(_) => { stats.skipped_topic_records.fetch_add(1, Ordering::Relaxed); false - } + }, Err(error) => { let _ = consumer.close().await; return Err(format!("conversation topic payload decode: {}", error)); - } + }, }; if forwarded { @@ -1175,29 +1138,22 @@ async fn run_message_forwarder( Ok(row) if row.needs_forward => { rows_to_forward.push(row); records_to_commit.push(record); - } + }, Ok(_) => { stats.skipped_topic_records.fetch_add(1, Ordering::Relaxed); consumer.mark_processed(&record); - } + }, Err(error) => { let _ = consumer.close().await; return Err(format!("message topic payload decode: {}", error)); - } + }, } } if !rows_to_forward.is_empty() { - let forwarded_count = forward_message_rows( - &admin_client, - &namespace, - rows_to_forward, - &stats, - ) - .await?; - stats - .messages_forwarded - .fetch_add(forwarded_count, Ordering::Relaxed); + let forwarded_count = + forward_message_rows(&admin_client, &namespace, rows_to_forward, &stats).await?; + stats.messages_forwarded.fetch_add(forwarded_count, Ordering::Relaxed); for record in &records_to_commit { consumer.mark_processed(record); @@ -1262,29 +1218,22 @@ async fn run_typing_forwarder( Ok(row) if row.needs_forward => { rows_to_forward.push(row); records_to_commit.push(record); - } + }, Ok(_) => { stats.skipped_topic_records.fetch_add(1, Ordering::Relaxed); consumer.mark_processed(&record); - } + }, Err(error) => { let _ = consumer.close().await; return Err(format!("typing topic payload decode: {}", error)); - } + }, } } if !rows_to_forward.is_empty() { - let forwarded_count = forward_typing_rows( - &admin_client, - &namespace, - rows_to_forward, - &stats, - ) - .await?; - stats - .typing_events_forwarded - .fetch_add(forwarded_count, Ordering::Relaxed); + let forwarded_count = + forward_typing_rows(&admin_client, &namespace, rows_to_forward, &stats).await?; + stats.typing_events_forwarded.fetch_add(forwarded_count, Ordering::Relaxed); for record in &records_to_commit { consumer.mark_processed(record); @@ -1431,17 +1380,17 @@ async fn forward_message_rows( if forwarded { forwarded_count += 1; } - } + }, Some(Ok(Err(error))) => { join_set.abort_all(); while join_set.join_next().await.is_some() {} return Err(error); - } + }, Some(Err(error)) => { join_set.abort_all(); while join_set.join_next().await.is_some() {} return Err(format!("message forward task join error: {}", error)); - } + }, None => break, } @@ -1514,17 +1463,17 @@ async fn forward_typing_rows( if forwarded { forwarded_count += 1; } - } + }, Some(Ok(Err(error))) => { join_set.abort_all(); while join_set.join_next().await.is_some() {} return Err(error); - } + }, Some(Err(error)) => { join_set.abort_all(); while join_set.join_next().await.is_some() {} return Err(format!("typing forward task join error: {}", error)); - } + }, None => break, } @@ -1534,7 +1483,9 @@ async fn forward_typing_rows( Ok(forwarded_count) } -fn parse_conversation_forward_record(payload_bytes: &[u8]) -> Result { +fn parse_conversation_forward_record( + payload_bytes: &[u8], +) -> Result { let payload: JsonValue = serde_json::from_slice(payload_bytes) .map_err(|error| format!("invalid topic payload json: {}", error))?; let row = payload.get("row").unwrap_or(&payload); @@ -1622,13 +1573,10 @@ impl UserClientPool { } } - let fresh_client = login_user_with_retry( - self.urls.as_ref(), - username, - self.password.as_ref(), - ) - .await - .map_err(|error| format!("login failed for {}: {}", username, error))?; + let fresh_client = + login_user_with_retry(self.urls.as_ref(), username, self.password.as_ref()) + .await + .map_err(|error| format!("login failed for {}: {}", username, error))?; let mut clients = self.clients.lock().await; if let Some(client) = clients.get(username) { @@ -1653,13 +1601,12 @@ async fn login_user_with_retry( match KalamClient::login_steady_state(urls, username, password).await { Ok(client) => return Ok(client), Err(error) - if attempts < CHAT_LOGIN_RETRY_ATTEMPTS - && is_transient_chat_error(&error) => + if attempts < CHAT_LOGIN_RETRY_ATTEMPTS && is_transient_chat_error(&error) => { attempts += 1; sleep(delay).await; delay = (delay * 2).min(Duration::from_secs(5)); - } + }, Err(error) => return Err(error), } } @@ -1693,11 +1640,10 @@ impl SessionDeliveryTracker { self.peer_messages.fetch_add(1, Ordering::Relaxed); } } - } + }, SubscriptionKind::Typing => { - self.typing_events - .fetch_add(rows.len() as u64, Ordering::Relaxed); - } + self.typing_events.fetch_add(rows.len() as u64, Ordering::Relaxed); + }, } } } @@ -1769,9 +1715,7 @@ async fn prewarm_user_clients( break; }; let pool = user_pool.clone(); - join_set.spawn(async move { - pool.client_for(&username).await.map(|_| username) - }); + join_set.spawn(async move { pool.client_for(&username).await.map(|_| username) }); in_flight += 1; } @@ -1788,13 +1732,13 @@ async fn prewarm_user_clients( if failure_samples.len() < 5 { failure_samples.push(error); } - } + }, Err(error) => { failed_attempts += 1; if failure_samples.len() < 5 { failure_samples.push(format!("active-user prewarm join error: {}", error)); } - } + }, } while in_flight < CHAT_LOGIN_MAX_IN_FLIGHT @@ -1804,9 +1748,7 @@ async fn prewarm_user_clients( break; }; let pool = user_pool.clone(); - join_set.spawn(async move { - pool.client_for(&username).await.map(|_| username) - }); + join_set.spawn(async move { pool.client_for(&username).await.map(|_| username) }); in_flight += 1; } } @@ -1905,9 +1847,9 @@ fn json_string_field(row: &JsonValue, key: &str) -> Result { fn json_u64_field(row: &JsonValue, key: &str) -> Result { match row.get(key) { - Some(JsonValue::Number(value)) => value - .as_u64() - .ok_or_else(|| format!("field {} is not a u64", key)), + Some(JsonValue::Number(value)) => { + value.as_u64().ok_or_else(|| format!("field {} is not a u64", key)) + }, Some(JsonValue::String(value)) => value .parse::() .map_err(|error| format!("field {} is not a valid u64: {}", key, error)), @@ -1929,11 +1871,7 @@ fn json_bool_field(row: &JsonValue, key: &str) -> Result { fn sql_response_has_rows(response: &SqlResponse) -> bool { response.results.iter().any(|result| { - result - .rows - .as_ref() - .map(|rows| !rows.is_empty()) - .unwrap_or(false) + result.rows.as_ref().map(|rows| !rows.is_empty()).unwrap_or(false) || result.row_count.unwrap_or(0) > 0 }) } @@ -1993,15 +1931,13 @@ fn record_sql_metric(stats: &ChatWorkloadStats, kind: ChatSqlKind, elapsed: Dura stats.inserts.fetch_add(1, Ordering::Relaxed); stats.insert_latency_us.fetch_add(elapsed_us, Ordering::Relaxed); update_peak(&stats.insert_max_us, elapsed_us); - } + }, } } fn record_subscription_open(stats: &ChatWorkloadStats, elapsed: Duration) { let elapsed_us = duration_to_us(elapsed); - stats - .subscription_open_latency_us - .fetch_add(elapsed_us, Ordering::Relaxed); + stats.subscription_open_latency_us.fetch_add(elapsed_us, Ordering::Relaxed); update_peak(&stats.subscription_open_max_us, elapsed_us); stats.subscription_connect_timings.record(elapsed); } @@ -2306,7 +2242,11 @@ fn print_chat_summary( println!("{}", insert_typing_stats.display_ms("insert_typing_event")); println!( " Managed server RSS: {}", - format_memory_summary(memory, logged_in_users.max(u64::from(settings.realtime_conversations)), peak_active_subscriptions), + format_memory_summary( + memory, + logged_in_users.max(u64::from(settings.realtime_conversations)), + peak_active_subscriptions + ), ); println!( " Backend note: conversations and messages are USER tables, typing_events is a STREAM table, and Rust topic consumers mirror conversation, message, and typing INSERTs back into recipient user scopes with EXECUTE AS USER.", @@ -2317,9 +2257,7 @@ fn print_chat_summary( ); println!( " Coverage: configured_users={} | sessions_started={} | sessions_completed={}", - settings.user_count, - sessions_started, - sessions_completed, + settings.user_count, sessions_started, sessions_completed, ); } @@ -2332,7 +2270,7 @@ async fn run_sql_with_retry(client: &KalamClient, sql: &str) -> Result { sleep(delay).await; delay = (delay * 2).min(CHAT_SQL_RETRY_MAX_DELAY); - } + }, Err(error) => return Err(error), } } @@ -2344,12 +2282,12 @@ fn is_transient_chat_error(error: &str) -> bool { let lower = error.to_ascii_lowercase(); lower.contains("timeout") || lower.contains("temporarily unavailable") - || lower.contains("connection failed") + || lower.contains("connection failed") || lower.contains("connection reset") || lower.contains("connection refused") || lower.contains("broken pipe") - || lower.contains("error sending request") - || lower.contains("transport") + || lower.contains("error sending request") + || lower.contains("transport") || lower.contains("too many open files") || lower.contains("503") } @@ -2408,8 +2346,7 @@ fn percentile_us(sorted_samples: &[u64], percentile: f64) -> f64 { let upper = idx.ceil() as usize; let fraction = idx - lower as f64; - sorted_samples[lower] as f64 * (1.0 - fraction) - + sorted_samples[upper] as f64 * fraction + sorted_samples[lower] as f64 * (1.0 - fraction) + sorted_samples[upper] as f64 * fraction } fn lock_unpoisoned(mutex: &Mutex) -> std::sync::MutexGuard<'_, T> { @@ -2422,12 +2359,7 @@ fn lock_unpoisoned(mutex: &Mutex) -> std::sync::MutexGuard<'_, T> { fn update_peak(peak: &AtomicU64, candidate: u64) { let mut current = peak.load(Ordering::Relaxed); while candidate > current { - match peak.compare_exchange_weak( - current, - candidate, - Ordering::Relaxed, - Ordering::Relaxed, - ) { + match peak.compare_exchange_weak(current, candidate, Ordering::Relaxed, Ordering::Relaxed) { Ok(_) => break, Err(observed) => current = observed, } @@ -2486,10 +2418,10 @@ fn format_memory(memory_bytes: Option) -> String { match memory_bytes { Some(bytes) if bytes >= 1024 * 1024 * 1024 => { format!("{:.2} GiB", bytes as f64 / (1024.0 * 1024.0 * 1024.0)) - } + }, Some(bytes) if bytes >= 1024 * 1024 => { format!("{:.1} MiB", bytes as f64 / (1024.0 * 1024.0)) - } + }, Some(bytes) if bytes >= 1024 => format!("{:.1} KiB", bytes as f64 / 1024.0), Some(bytes) => format!("{} B", bytes), None => "n/a".to_string(), diff --git a/benchv2/src/chat_runtime/mod.rs b/benchv2/src/chat_runtime/mod.rs new file mode 100644 index 00000000..ae3eca2d --- /dev/null +++ b/benchv2/src/chat_runtime/mod.rs @@ -0,0 +1,7 @@ +pub mod chat_realtime_bench; + +use crate::benchmarks::Benchmark; + +pub fn benchmarks() -> Vec> { + vec![Box::new(chat_realtime_bench::ChatRealtimeBench)] +} diff --git a/benchv2/src/config.rs b/benchv2/src/config.rs index eff7b654..bb0e6981 100644 --- a/benchv2/src/config.rs +++ b/benchv2/src/config.rs @@ -1,4 +1,19 @@ -use clap::Parser; +use clap::{Parser, ValueEnum}; + +#[derive(ValueEnum, Debug, Clone, Copy, PartialEq, Eq)] +pub enum BenchmarkSuite { + Standard, + ChatRuntime, +} + +impl BenchmarkSuite { + pub fn as_str(self) -> &'static str { + match self { + Self::Standard => "standard", + Self::ChatRuntime => "chat-runtime", + } + } +} /// CLI configuration for the benchmark tool. #[derive(Parser, Debug, Clone)] @@ -38,6 +53,11 @@ pub struct Config { #[arg(long, default_value = "results")] pub output_dir: String, + /// Benchmark suite to run. The standard suite writes to results/ by default; + /// named suites should pass their own output directory from their runner script. + #[arg(long, value_enum, default_value_t = BenchmarkSuite::Standard)] + pub suite: BenchmarkSuite, + /// Run only benchmarks matching this filter (substring match) #[arg(long)] pub filter: Option, diff --git a/benchv2/src/lib.rs b/benchv2/src/lib.rs index 0a080104..e98a32b6 100644 --- a/benchv2/src/lib.rs +++ b/benchv2/src/lib.rs @@ -1,4 +1,5 @@ pub mod benchmarks; +pub mod chat_runtime; pub mod client; pub mod comparison; pub mod config; @@ -8,3 +9,12 @@ pub mod reporter; pub mod runner; pub mod system_info; pub mod verdict; + +use config::BenchmarkSuite; + +pub fn selected_benchmarks(suite: BenchmarkSuite) -> Vec> { + match suite { + BenchmarkSuite::Standard => benchmarks::standard_benchmarks(), + BenchmarkSuite::ChatRuntime => chat_runtime::benchmarks(), + } +} diff --git a/benchv2/src/main.rs b/benchv2/src/main.rs index 802f5b8d..dbc9d2e6 100644 --- a/benchv2/src/main.rs +++ b/benchv2/src/main.rs @@ -5,6 +5,7 @@ use std::time::Instant; use clap::Parser; mod benchmarks; +mod chat_runtime; mod client; mod comparison; mod config; @@ -16,7 +17,7 @@ mod system_info; mod verdict; use client::KalamClient; -use config::Config; +use config::{BenchmarkSuite, Config}; use reporter::html_reporter; use reporter::json_reporter; use system_info::collect_system_info; @@ -29,7 +30,7 @@ async fn main() { if config.list_benches { println!("Available benchmarks:"); - for bench in benchmarks::all_benchmarks() { + for bench in selected_benchmarks(config.suite) { println!(" {:<28} {}", bench.name(), bench.description()); } return; @@ -50,6 +51,8 @@ async fn main() { println!(" Warmup: {}", config.warmup); println!(" Concurrency: {}", config.concurrency); println!(" Max Subs: {}", config.max_subscribers); + println!(" Suite: {}", config.suite.as_str()); + println!(" Output Dir: {}", config.output_dir); if let Some(ref f) = config.filter { println!(" Filter: {}", f); } @@ -191,6 +194,13 @@ async fn main() { } } +pub(crate) fn selected_benchmarks(suite: BenchmarkSuite) -> Vec> { + match suite { + BenchmarkSuite::Standard => benchmarks::standard_benchmarks(), + BenchmarkSuite::ChatRuntime => chat_runtime::benchmarks(), + } +} + fn load_kalamdb_version() -> String { let fallback = env!("CARGO_PKG_VERSION").to_string(); let root_manifest = Path::new(env!("CARGO_MANIFEST_DIR")).join("../Cargo.toml"); diff --git a/benchv2/src/runner.rs b/benchv2/src/runner.rs index 6be2dfe5..0d54064e 100644 --- a/benchv2/src/runner.rs +++ b/benchv2/src/runner.rs @@ -1,10 +1,11 @@ use std::time::Instant; -use crate::benchmarks::{all_benchmarks, enabled_in_default_suite, Benchmark}; +use crate::benchmarks::{enabled_in_default_suite, Benchmark}; use crate::client::KalamClient; use crate::comparison::{self, PreviousRun}; use crate::config::Config; use crate::metrics::BenchmarkResult; +use crate::selected_benchmarks; use crate::verdict; /// Runs all registered benchmarks according to config, returning results. @@ -13,7 +14,7 @@ pub async fn run_all( config: &Config, previous: Option<&PreviousRun>, ) -> Vec { - let benchmarks = all_benchmarks(); + let benchmarks = selected_benchmarks(config.suite); let mut results = Vec::new(); let mut selected = Vec::new(); diff --git a/link/link-common/src/subscription/models/subscription_options.rs b/link/link-common/src/subscription/models/subscription_options.rs index 5aa7e1c5..9a141742 100644 --- a/link/link-common/src/subscription/models/subscription_options.rs +++ b/link/link-common/src/subscription/models/subscription_options.rs @@ -15,8 +15,8 @@ use crate::seq_id::SeqId; /// ```rust /// use kalam_client::{SeqId, SubscriptionOptions}; /// -/// // Fetch last 100 rows with batch size of 50 -/// let options = SubscriptionOptions::default().with_batch_size(50).with_last_rows(100); +/// // Fetch last 50 rows in a single initial batch +/// let options = SubscriptionOptions::default().with_batch_size(50).with_last_rows(50); /// /// // Resume from a specific sequence ID after reconnection /// let some_seq_id = SeqId::new(123);