Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 90 additions & 6 deletions backend/crates/kalamdb-api/src/ws/events/subscription.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::sync::Arc;

use actix_ws::Session;
use kalamdb_commons::{
websocket::{BatchControl, SubscriptionRequest, MAX_ROWS_PER_BATCH},
websocket::{BatchControl, SubscriptionOptions, SubscriptionRequest, MAX_ROWS_PER_BATCH},
WebSocketMessage,
};
use kalamdb_core::providers::arrow_json_conversion::row_into_json_map;
Expand Down Expand Up @@ -61,11 +61,19 @@ pub async fn handle_subscribe(
let subscription_id = subscription.id.clone();
let subscription_options = subscription.options.clone();

// Determine batch size for initial data options
let batch_size = subscription_options
.as_ref()
.and_then(|options| options.batch_size)
.unwrap_or(MAX_ROWS_PER_BATCH);
let batch_size = subscription_batch_size(subscription_options.as_ref());

if let Err(message) = validate_subscription_options(subscription_options.as_ref(), batch_size) {
let _ = send_error(
session,
&subscription_id,
WsErrorCode::Unsupported,
&message,
compression_enabled,
)
.await;
return Ok(());
}

// Create initial data options respecting all three options:
// - from: Resume from a specific sequence ID
Expand Down Expand Up @@ -214,3 +222,79 @@ pub async fn handle_subscribe(
},
}
}

fn subscription_batch_size(options: Option<&SubscriptionOptions>) -> usize {
options
.and_then(|options| options.batch_size)
.unwrap_or(MAX_ROWS_PER_BATCH)
}

fn validate_subscription_options(
options: Option<&SubscriptionOptions>,
batch_size: usize,
) -> Result<(), String> {
let Some(options) = options else {
return Ok(());
};

if let Some(last_rows) = options.last_rows {
let last_rows = last_rows as usize;
if last_rows > batch_size {
return Err(format!(
"last_rows ({last_rows}) cannot exceed batch_size ({batch_size}); paginated last_rows replay is not supported"
));
}
}

Ok(())
}

#[cfg(test)]
mod tests {
use kalamdb_commons::websocket::SubscriptionOptions;

use super::{subscription_batch_size, validate_subscription_options, MAX_ROWS_PER_BATCH};

#[test]
fn validate_subscription_options_allows_last_rows_within_batch_size() {
let options = SubscriptionOptions {
batch_size: Some(50),
last_rows: Some(50),
from: None,
};

let batch_size = subscription_batch_size(Some(&options));

assert_eq!(batch_size, 50);
assert!(validate_subscription_options(Some(&options), batch_size).is_ok());
}

#[test]
fn validate_subscription_options_rejects_last_rows_above_batch_size() {
let options = SubscriptionOptions {
batch_size: Some(50),
last_rows: Some(51),
from: None,
};

let batch_size = subscription_batch_size(Some(&options));
let error = validate_subscription_options(Some(&options), batch_size)
.expect_err("last_rows above batch_size should be rejected");

assert!(error.contains("last_rows (51) cannot exceed batch_size (50)"));
}

#[test]
fn validate_subscription_options_uses_default_batch_size_when_unspecified() {
let options = SubscriptionOptions {
batch_size: None,
last_rows: Some(MAX_ROWS_PER_BATCH as u32 + 1),
from: None,
};

let batch_size = subscription_batch_size(Some(&options));

assert_eq!(batch_size, MAX_ROWS_PER_BATCH);
assert!(validate_subscription_options(Some(&options), batch_size).is_err());
}
}
6 changes: 3 additions & 3 deletions backend/crates/kalamdb-dialect/src/ddl/subscribe_commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
//! SUBSCRIBE TO app.messages WHERE user_id = CURRENT_USER() OPTIONS (last_rows=10);
//!
//! -- With multiple options
//! SUBSCRIBE TO app.messages OPTIONS (last_rows=100, batch_size=50);
//! SUBSCRIBE TO app.messages OPTIONS (last_rows=50, batch_size=50);
//!
//! -- Resume from specific sequence ID
//! SUBSCRIBE TO app.messages OPTIONS (from=12345);
Expand Down Expand Up @@ -603,12 +603,12 @@ mod tests {
use kalamdb_commons::ids::SeqId;

let stmt = SubscribeStatement::parse(
"SUBSCRIBE TO app.messages OPTIONS (last_rows=100, batch_size=50, from=999)",
"SUBSCRIBE TO app.messages OPTIONS (last_rows=50, batch_size=50, from=999)",
)
.unwrap();
assert_eq!(stmt.namespace, NamespaceId::from("app"));
assert_eq!(stmt.table_name, TableName::from("messages"));
assert_eq!(stmt.options.last_rows, Some(100));
assert_eq!(stmt.options.last_rows, Some(50));
assert_eq!(stmt.options.batch_size, Some(50));
assert_eq!(stmt.options.from, Some(SeqId::new(999)));
}
Expand Down
Loading
Loading