diff --git a/bindings/cpp/src/types.rs b/bindings/cpp/src/types.rs index fef73cea..a25a5442 100644 --- a/bindings/cpp/src/types.rs +++ b/bindings/cpp/src/types.rs @@ -478,3 +478,220 @@ pub fn core_lake_snapshot_to_ffi(snapshot: &fcore::metadata::LakeSnapshot) -> ff bucket_offsets, } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ + BinaryArray, BooleanArray, Date32Array, FixedSizeBinaryArray, Float32Array, Float64Array, + Int32Array, Int64Array, LargeBinaryArray, LargeStringArray, RecordBatch, + Time32MillisecondArray, Time64MicrosecondArray, TimestampMillisecondArray, + }; + use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; + use fcore::record::{ChangeType, ScanRecord, ScanRecords}; + use fcore::row::{ColumnarRow, InternalRow}; + use std::collections::HashMap; + use std::sync::Arc; + + fn make_ffi_datum(datum_type: i32) -> ffi::FfiDatum { + ffi::FfiDatum { + datum_type, + bool_val: false, + i32_val: 0, + i64_val: 0, + f32_val: 0.0, + f64_val: 0.0, + string_val: String::new(), + bytes_val: vec![], + } + } + + #[test] + fn ffi_descriptor_to_core_rejects_invalid_type() { + let descriptor = ffi::FfiTableDescriptor { + schema: ffi::FfiSchema { + columns: vec![ffi::FfiColumn { + name: "bad".to_string(), + data_type: 999, + comment: String::new(), + }], + primary_keys: vec![], + }, + partition_keys: vec![], + bucket_count: 0, + bucket_keys: vec![], + properties: vec![], + comment: String::new(), + }; + + let result = ffi_descriptor_to_core(&descriptor); + assert!(result.is_err()); + } + + #[test] + fn ffi_row_to_core_maps_datum_types() { + let mut fields = Vec::new(); + + fields.push(make_ffi_datum(DATUM_TYPE_NULL)); + + let mut bool_datum = make_ffi_datum(DATUM_TYPE_BOOL); + bool_datum.bool_val = true; + fields.push(bool_datum); + + let mut int32_datum = make_ffi_datum(DATUM_TYPE_INT32); + int32_datum.i32_val = 11; + fields.push(int32_datum); + + let mut int64_datum = make_ffi_datum(DATUM_TYPE_INT64); + int64_datum.i64_val = 22; + fields.push(int64_datum); + + let mut f32_datum = make_ffi_datum(DATUM_TYPE_FLOAT32); + f32_datum.f32_val = 1.25; + fields.push(f32_datum); + + let mut f64_datum = make_ffi_datum(DATUM_TYPE_FLOAT64); + f64_datum.f64_val = 2.5; + fields.push(f64_datum); + + let mut str_datum = make_ffi_datum(DATUM_TYPE_STRING); + str_datum.string_val = "hello".to_string(); + fields.push(str_datum); + + let mut bytes_datum = make_ffi_datum(DATUM_TYPE_BYTES); + bytes_datum.bytes_val = vec![1, 2, 3]; + fields.push(bytes_datum); + + let row = ffi::FfiGenericRow { fields }; + let core_row = ffi_row_to_core(&row); + + assert_eq!(core_row.get_field_count(), 8); + assert!(core_row.is_null_at(0)); + assert!(core_row.get_boolean(1)); + assert_eq!(core_row.get_int(2), 11); + assert_eq!(core_row.get_long(3), 22); + assert!((core_row.get_float(4) - 1.25).abs() < f32::EPSILON); + assert!((core_row.get_double(5) - 2.5).abs() < f64::EPSILON); + assert_eq!(core_row.get_string(6), "hello"); + assert_eq!(core_row.get_bytes(7), &[1, 2, 3]); + } + + #[test] + fn core_scan_records_to_ffi_maps_records() { + let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![7]))]) + .expect("record batch"); + let row = ColumnarRow::new(Arc::new(batch)); + let record = ScanRecord::new(row, 10, 20, ChangeType::Insert); + + let bucket = fcore::metadata::TableBucket::new(1, 2); + let records = ScanRecords::new(HashMap::from([(bucket.clone(), vec![record])])); + + let ffi_records = 
core_scan_records_to_ffi(&records); + assert_eq!(ffi_records.records.len(), 1); + let ffi_record = &ffi_records.records[0]; + assert_eq!(ffi_record.bucket_id, 2); + assert_eq!(ffi_record.offset, 10); + assert_eq!(ffi_record.timestamp, 20); + assert_eq!(ffi_record.row.fields.len(), 1); + assert_eq!(ffi_record.row.fields[0].datum_type, DATUM_TYPE_INT32); + assert_eq!(ffi_record.row.fields[0].i32_val, 7); + } + + #[test] + fn core_row_to_ffi_fields_maps_arrow_types() { + let schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Boolean, false), + Field::new("i32", DataType::Int32, false), + Field::new("i64", DataType::Int64, false), + Field::new("f32", DataType::Float32, false), + Field::new("f64", DataType::Float64, false), + Field::new("s", DataType::Utf8, false), + Field::new("bin", DataType::Binary, false), + Field::new("fix", DataType::FixedSizeBinary(2), false), + Field::new("date", DataType::Date32, false), + Field::new( + "ts", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new("t32", DataType::Time32(TimeUnit::Millisecond), false), + Field::new("t64", DataType::Time64(TimeUnit::Microsecond), false), + Field::new("ls", DataType::LargeUtf8, false), + Field::new("lb", DataType::LargeBinary, false), + ])); + + let batch = RecordBatch::try_new( + schema, + vec![ + Arc::new(BooleanArray::from(vec![true])), + Arc::new(Int32Array::from(vec![1])), + Arc::new(Int64Array::from(vec![2])), + Arc::new(Float32Array::from(vec![1.5])), + Arc::new(Float64Array::from(vec![2.5])), + Arc::new(arrow::array::StringArray::from(vec!["text"])), + Arc::new(BinaryArray::from(vec![b"bin".as_slice()])), + Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size( + vec![Some(b"ab".as_slice())].into_iter(), + 2, + ) + .expect("fixed array"), + ), + Arc::new(Date32Array::from(vec![3])), + Arc::new(TimestampMillisecondArray::from(vec![30])), + Arc::new(Time32MillisecondArray::from(vec![10])), + Arc::new(Time64MicrosecondArray::from(vec![20])), + Arc::new(LargeStringArray::from(vec!["large"])), + Arc::new(LargeBinaryArray::from(vec![b"big".as_slice()])), + ], + ) + .expect("record batch"); + + let row = ColumnarRow::new(Arc::new(batch)); + let fields = core_row_to_ffi_fields(&row); + assert_eq!(fields.len(), 14); + + assert_eq!(fields[0].datum_type, DATUM_TYPE_BOOL); + assert!(fields[0].bool_val); + + assert_eq!(fields[1].datum_type, DATUM_TYPE_INT32); + assert_eq!(fields[1].i32_val, 1); + + assert_eq!(fields[2].datum_type, DATUM_TYPE_INT64); + assert_eq!(fields[2].i64_val, 2); + + assert_eq!(fields[3].datum_type, DATUM_TYPE_FLOAT32); + assert!((fields[3].f32_val - 1.5).abs() < f32::EPSILON); + + assert_eq!(fields[4].datum_type, DATUM_TYPE_FLOAT64); + assert!((fields[4].f64_val - 2.5).abs() < f64::EPSILON); + + assert_eq!(fields[5].datum_type, DATUM_TYPE_STRING); + assert_eq!(fields[5].string_val, "text"); + + assert_eq!(fields[6].datum_type, DATUM_TYPE_BYTES); + assert_eq!(fields[6].bytes_val, b"bin"); + + assert_eq!(fields[7].datum_type, DATUM_TYPE_BYTES); + assert_eq!(fields[7].bytes_val, b"ab"); + + assert_eq!(fields[8].datum_type, DATUM_TYPE_INT32); + assert_eq!(fields[8].i32_val, 3); + + assert_eq!(fields[9].datum_type, DATUM_TYPE_INT64); + assert_eq!(fields[9].i64_val, 30); + + assert_eq!(fields[10].datum_type, DATUM_TYPE_INT32); + assert_eq!(fields[10].i32_val, 10); + + assert_eq!(fields[11].datum_type, DATUM_TYPE_INT64); + assert_eq!(fields[11].i64_val, 20); + + assert_eq!(fields[12].datum_type, DATUM_TYPE_STRING); + assert_eq!(fields[12].string_val, "large"); + + 
assert_eq!(fields[13].datum_type, DATUM_TYPE_BYTES); + assert_eq!(fields[13].bytes_val, b"big"); + } +} diff --git a/crates/fluss/src/client/admin.rs b/crates/fluss/src/client/admin.rs index 6646f97c..9e999c91 100644 --- a/crates/fluss/src/client/admin.rs +++ b/crates/fluss/src/client/admin.rs @@ -324,3 +324,225 @@ impl FlussAdmin { Ok(tasks) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::cluster::{ServerNode, ServerType}; + use crate::metadata::{ + DataField, DataTypes, DatabaseDescriptor, JsonSerde, Schema, TableDescriptor, TablePath, + }; + use crate::proto::{ + CreateDatabaseResponse, CreateTableResponse, DatabaseExistsResponse, DropDatabaseResponse, + DropTableResponse, GetDatabaseInfoResponse, GetLatestLakeSnapshotResponse, + GetTableInfoResponse, ListDatabasesResponse, ListOffsetsResponse, ListTablesResponse, + PbLakeSnapshotForBucket, PbListOffsetsRespForBucket, TableExistsResponse, + }; + use crate::test_utils::{build_cluster_with_coordinator_arc, build_mock_connection}; + use prost::Message; + use std::sync::Arc; + + const API_CREATE_DATABASE: i16 = 1001; + const API_DROP_DATABASE: i16 = 1002; + const API_LIST_DATABASES: i16 = 1003; + const API_DATABASE_EXISTS: i16 = 1004; + const API_CREATE_TABLE: i16 = 1005; + const API_DROP_TABLE: i16 = 1006; + const API_GET_TABLE: i16 = 1007; + const API_LIST_TABLES: i16 = 1008; + const API_TABLE_EXISTS: i16 = 1010; + const API_LIST_OFFSETS: i16 = 1021; + const API_GET_LAKE_SNAPSHOT: i16 = 1032; + const API_GET_DATABASE_INFO: i16 = 1035; + + fn build_table_descriptor() -> TableDescriptor { + let row_type = DataTypes::row(vec![ + DataField::new("id".to_string(), DataTypes::int(), None), + DataField::new("name".to_string(), DataTypes::string(), None), + ]); + let mut schema_builder = Schema::builder().with_row_type(&row_type); + let schema = schema_builder.build().expect("schema"); + TableDescriptor::builder() + .schema(schema) + .distributed_by(Some(1), vec![]) + .build() + .expect("descriptor") + } + + #[tokio::test] + async fn admin_requests_round_trip() -> Result<()> { + let table_path = TablePath::new("db".to_string(), "tbl".to_string()); + let table_id = 42; + let table_descriptor = build_table_descriptor(); + let table_json = + serde_json::to_vec(&table_descriptor.serialize_json().expect("table json")).unwrap(); + + let db_descriptor = DatabaseDescriptor::builder() + .comment("test") + .custom_property("k", "v") + .build(); + let db_json = + serde_json::to_vec(&db_descriptor.serialize_json().expect("db json")).unwrap(); + + let (admin_connection, admin_handle) = + build_mock_connection(move |api_key: crate::rpc::ApiKey, _, _| { + match i16::from(api_key) { + API_CREATE_DATABASE => CreateDatabaseResponse::default().encode_to_vec(), + API_CREATE_TABLE => CreateTableResponse::default().encode_to_vec(), + API_DROP_TABLE => DropTableResponse::default().encode_to_vec(), + API_DROP_DATABASE => DropDatabaseResponse::default().encode_to_vec(), + API_LIST_TABLES => ListTablesResponse { + table_name: vec!["tbl".to_string()], + } + .encode_to_vec(), + API_TABLE_EXISTS => TableExistsResponse { exists: true }.encode_to_vec(), + API_LIST_DATABASES => ListDatabasesResponse { + database_name: vec!["db".to_string(), "db2".to_string()], + } + .encode_to_vec(), + API_DATABASE_EXISTS => DatabaseExistsResponse { exists: false }.encode_to_vec(), + API_GET_TABLE => GetTableInfoResponse { + table_id, + schema_id: 1, + table_json: table_json.clone(), + created_time: 10, + modified_time: 20, + } + .encode_to_vec(), + API_GET_DATABASE_INFO => 
GetDatabaseInfoResponse { + database_json: db_json.clone(), + created_time: 5, + modified_time: 6, + } + .encode_to_vec(), + API_GET_LAKE_SNAPSHOT => GetLatestLakeSnapshotResponse { + table_id, + snapshot_id: 99, + bucket_snapshots: vec![PbLakeSnapshotForBucket { + partition_id: None, + bucket_id: 0, + log_offset: Some(123), + }], + } + .encode_to_vec(), + _ => vec![], + } + }) + .await; + + let (tablet_connection, tablet_handle) = build_mock_connection( + |api_key: crate::rpc::ApiKey, _, _| match i16::from(api_key) { + API_LIST_OFFSETS => ListOffsetsResponse { + buckets_resp: vec![PbListOffsetsRespForBucket { + bucket_id: 0, + error_code: None, + error_message: None, + offset: Some(7), + }], + } + .encode_to_vec(), + _ => vec![], + }, + ) + .await; + + let coordinator = ServerNode::new( + 100, + "127.0.0.1".to_string(), + 9999, + ServerType::CoordinatorServer, + ); + let tablet = ServerNode::new(1, "127.0.0.1".to_string(), 9998, ServerType::TabletServer); + let cluster = build_cluster_with_coordinator_arc( + &table_path, + table_id, + coordinator.clone(), + tablet.clone(), + ); + let metadata = Arc::new(Metadata::new_for_test(cluster)); + let rpc_client = Arc::new(RpcClient::new()); + rpc_client.insert_connection_for_test(&coordinator, admin_connection); + rpc_client.insert_connection_for_test(&tablet, tablet_connection); + + let admin = FlussAdmin::new(rpc_client.clone(), metadata.clone()).await?; + + admin + .create_database("db", true, Some(&db_descriptor)) + .await?; + admin + .create_table(&table_path, &table_descriptor, true) + .await?; + admin.drop_table(&table_path, true).await?; + + let tables = admin.list_tables("db").await?; + assert_eq!(tables, vec!["tbl".to_string()]); + + let exists = admin.table_exists(&table_path).await?; + assert!(exists); + + let dbs = admin.list_databases().await?; + assert_eq!(dbs.len(), 2); + + let db_exists = admin.database_exists("db").await?; + assert!(!db_exists); + + let table_info = admin.get_table(&table_path).await?; + assert_eq!(table_info.table_id, table_id); + + let db_info = admin.get_database_info("db").await?; + assert_eq!(db_info.database_name(), "db"); + + let snapshot = admin.get_latest_lake_snapshot(&table_path).await?; + assert_eq!(snapshot.snapshot_id(), 99); + assert_eq!( + snapshot + .table_buckets_offset() + .get(&TableBucket::new(table_id, 0)), + Some(&123) + ); + + let offsets = admin + .list_offsets(&table_path, &[0], OffsetSpec::Earliest) + .await?; + assert_eq!(offsets.get(&0), Some(&7)); + + admin_handle.abort(); + tablet_handle.abort(); + Ok(()) + } + + #[tokio::test] + async fn list_offsets_empty_buckets_error() -> Result<()> { + let table_path = TablePath::new("db".to_string(), "tbl".to_string()); + let (admin_connection, admin_handle) = build_mock_connection( + |api_key: crate::rpc::ApiKey, _, _| match i16::from(api_key) { + API_CREATE_DATABASE => CreateDatabaseResponse::default().encode_to_vec(), + _ => vec![], + }, + ) + .await; + let coordinator = ServerNode::new( + 10, + "127.0.0.1".to_string(), + 9999, + ServerType::CoordinatorServer, + ); + let tablet = ServerNode::new(11, "127.0.0.1".to_string(), 8081, ServerType::TabletServer); + let cluster = build_cluster_with_coordinator_arc(&table_path, 1, coordinator, tablet); + let metadata = Arc::new(Metadata::new_for_test(cluster)); + let rpc_client = Arc::new(RpcClient::new()); + rpc_client.insert_connection_for_test( + metadata.get_cluster().get_coordinator_server().unwrap(), + admin_connection, + ); + + let admin = FlussAdmin::new(rpc_client, metadata).await?; + + 
let result = admin + .list_offsets(&table_path, &[], OffsetSpec::Earliest) + .await; + assert!(matches!(result, Err(Error::UnexpectedError { .. }))); + admin_handle.abort(); + Ok(()) + } +} diff --git a/crates/fluss/src/client/connection.rs b/crates/fluss/src/client/connection.rs index 595daf55..26480ddf 100644 --- a/crates/fluss/src/client/connection.rs +++ b/crates/fluss/src/client/connection.rs @@ -59,6 +59,24 @@ impl FlussConnection { self.network_connects.clone() } + pub fn config(&self) -> &Config { + &self.args + } + + #[cfg(test)] + pub(crate) fn new_for_test( + metadata: Arc, + network_connects: Arc, + args: Config, + ) -> Self { + FlussConnection { + metadata, + network_connects, + args, + writer_client: Default::default(), + } + } + pub async fn get_admin(&self) -> Result { FlussAdmin::new(self.network_connects.clone(), self.metadata.clone()).await } @@ -85,3 +103,129 @@ impl FlussConnection { Ok(FlussTable::new(self, self.metadata.clone(), table_info)) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::cluster::{Cluster, ServerNode, ServerType}; + use crate::metadata::{DataField, DataTypes, Schema, TableDescriptor, TableInfo, TablePath}; + use crate::test_utils::build_mock_connection; + use std::collections::HashMap; + + fn build_cluster() -> Arc { + let coordinator = ServerNode::new( + 1, + "127.0.0.1".to_string(), + 9092, + ServerType::CoordinatorServer, + ); + + let table_path = TablePath::new("db".to_string(), "tbl".to_string()); + let row_type = DataTypes::row(vec![DataField::new( + "id".to_string(), + DataTypes::int(), + None, + )]); + let mut schema_builder = Schema::builder().with_row_type(&row_type); + let schema = schema_builder.build().expect("schema"); + let descriptor = TableDescriptor::builder() + .schema(schema) + .distributed_by(Some(1), vec![]) + .build() + .expect("descriptor"); + let table_info = TableInfo::of(table_path.clone(), 1, 1, descriptor, 0, 0); + + let mut table_id_by_path = HashMap::new(); + table_id_by_path.insert(table_path.clone(), 1); + + let mut table_info_by_path = HashMap::new(); + table_info_by_path.insert(table_path, table_info); + + Arc::new(Cluster::new( + Some(coordinator), + HashMap::new(), + HashMap::new(), + HashMap::new(), + table_id_by_path, + table_info_by_path, + )) + } + + #[tokio::test] + async fn get_or_create_writer_client_is_cached() -> Result<()> { + let metadata = Arc::new(Metadata::new_for_test(build_cluster())); + let rpc_client = Arc::new(RpcClient::new()); + let config = Config::default(); + + let connection = FlussConnection { + metadata, + network_connects: rpc_client, + args: config, + writer_client: Default::default(), + }; + + let first = connection.get_or_create_writer_client()?; + let second = connection.get_or_create_writer_client()?; + assert!(Arc::ptr_eq(&first, &second)); + + drop(first); + drop(second); + + let stored = connection + .writer_client + .write() + .take() + .expect("writer client"); + let client = Arc::try_unwrap(stored).unwrap_or_else(|_| { + panic!("writer client still shared"); + }); + client.close().await?; + Ok(()) + } + + #[test] + fn exposes_config_and_clients() { + let metadata = Arc::new(Metadata::new_for_test(build_cluster())); + let rpc_client = Arc::new(RpcClient::new()); + let config = Config::default(); + + let connection = FlussConnection { + metadata: metadata.clone(), + network_connects: rpc_client.clone(), + args: config.clone(), + writer_client: Default::default(), + }; + + assert_eq!( + connection.config().request_max_size, + config.request_max_size + ); + 
assert!(Arc::ptr_eq(&connection.get_metadata(), &metadata)); + assert!(Arc::ptr_eq(&connection.get_connections(), &rpc_client)); + } + + #[tokio::test] + async fn get_admin_uses_cached_connection() -> Result<()> { + let metadata = Arc::new(Metadata::new_for_test(build_cluster())); + let rpc_client = Arc::new(RpcClient::new()); + let (connection, handle) = + build_mock_connection(|_api_key: crate::rpc::ApiKey, _, _| Vec::new()).await; + let coordinator = metadata + .get_cluster() + .get_coordinator_server() + .expect("coordinator") + .clone(); + rpc_client.insert_connection_for_test(&coordinator, connection); + + let connection = FlussConnection { + metadata, + network_connects: rpc_client, + args: Config::default(), + writer_client: Default::default(), + }; + + let _admin = connection.get_admin().await?; + handle.abort(); + Ok(()) + } +} diff --git a/crates/fluss/src/client/credentials.rs b/crates/fluss/src/client/credentials.rs index c520b441..ef432425 100644 --- a/crates/fluss/src/client/credentials.rs +++ b/crates/fluss/src/client/credentials.rs @@ -161,7 +161,23 @@ impl CredentialsCache { mod tests { use super::*; use crate::client::metadata::Metadata; - use crate::cluster::Cluster; + use crate::cluster::{Cluster, ServerNode, ServerType}; + use crate::proto::{GetFileSystemSecurityTokenResponse, PbKeyValue}; + use crate::test_utils::build_mock_connection; + use prost::Message; + + const API_GET_SECURITY_TOKEN: i16 = 1025; + + fn build_cluster(server: ServerNode) -> Arc { + Arc::new(Cluster::new( + None, + HashMap::from([(server.id(), server)]), + HashMap::new(), + HashMap::new(), + HashMap::new(), + HashMap::new(), + )) + } #[test] fn convert_hadoop_key_to_opendal_maps_known_keys() { @@ -206,4 +222,78 @@ mod tests { ); Ok(()) } + + #[tokio::test] + async fn refresh_from_server_returns_empty_when_token_missing() -> Result<()> { + let (connection, handle) = + build_mock_connection( + |api_key: crate::rpc::ApiKey, _, _| match i16::from(api_key) { + API_GET_SECURITY_TOKEN => GetFileSystemSecurityTokenResponse { + schema: "s3".to_string(), + token: Vec::new(), + expiration_time: None, + addition_info: vec![], + } + .encode_to_vec(), + _ => vec![], + }, + ) + .await; + + let server = ServerNode::new(1, "127.0.0.1".to_string(), 9999, ServerType::TabletServer); + let rpc_client = Arc::new(RpcClient::new()); + rpc_client.insert_connection_for_test(&server, connection); + let cache = CredentialsCache::new( + rpc_client, + Arc::new(Metadata::new_for_test(build_cluster(server))), + ); + + let props = cache.get_or_refresh().await?; + assert!(props.is_empty()); + handle.abort(); + Ok(()) + } + + #[tokio::test] + async fn refresh_from_server_parses_token() -> Result<()> { + let token_json = serde_json::json!({ + "access_key_id": "ak", + "access_key_secret": "sk", + "security_token": "st" + }); + let token_bytes = serde_json::to_vec(&token_json).unwrap(); + let (connection, handle) = + build_mock_connection(move |api_key: crate::rpc::ApiKey, _, _| { + match i16::from(api_key) { + API_GET_SECURITY_TOKEN => GetFileSystemSecurityTokenResponse { + schema: "s3".to_string(), + token: token_bytes.clone(), + expiration_time: Some(100), + addition_info: vec![PbKeyValue { + key: "fs.s3a.endpoint".to_string(), + value: "localhost".to_string(), + }], + } + .encode_to_vec(), + _ => vec![], + } + }) + .await; + + let server = ServerNode::new(1, "127.0.0.1".to_string(), 9999, ServerType::TabletServer); + let rpc_client = Arc::new(RpcClient::new()); + rpc_client.insert_connection_for_test(&server, connection); + let 
cache = CredentialsCache::new( + rpc_client, + Arc::new(Metadata::new_for_test(build_cluster(server))), + ); + + let props = cache.get_or_refresh().await?; + assert_eq!(props.get("access_key_id"), Some(&"ak".to_string())); + assert_eq!(props.get("secret_access_key"), Some(&"sk".to_string())); + assert_eq!(props.get("security_token"), Some(&"st".to_string())); + assert_eq!(props.get("endpoint"), Some(&"localhost".to_string())); + handle.abort(); + Ok(()) + } } diff --git a/crates/fluss/src/client/metadata.rs b/crates/fluss/src/client/metadata.rs index 3c6730b5..f7566753 100644 --- a/crates/fluss/src/client/metadata.rs +++ b/crates/fluss/src/client/metadata.rs @@ -165,13 +165,137 @@ impl Metadata { #[cfg(test)] mod tests { use super::*; - use crate::metadata::{TableBucket, TablePath}; - use crate::test_utils::build_cluster_arc; + use crate::cluster::{BucketLocation, Cluster, ServerNode, ServerType}; + use crate::metadata::{ + DataField, DataTypes, JsonSerde, Schema, TableBucket, TableDescriptor, TableInfo, TablePath, + }; + use crate::proto::{ + MetadataResponse, PbBucketMetadata, PbServerNode, PbTableMetadata, PbTablePath, + }; + use crate::test_utils::build_mock_connection; + use prost::Message; + use std::collections::{HashMap, HashSet}; + use std::sync::Arc; + + const API_UPDATE_METADATA: i16 = 1012; + + fn build_table_info(table_path: TablePath, table_id: i64) -> TableInfo { + let row_type = DataTypes::row(vec![DataField::new( + "id".to_string(), + DataTypes::int(), + None, + )]); + let mut schema_builder = Schema::builder().with_row_type(&row_type); + let schema = schema_builder.build().expect("schema build"); + let table_descriptor = TableDescriptor::builder() + .schema(schema) + .distributed_by(Some(1), vec![]) + .build() + .expect("descriptor build"); + TableInfo::of(table_path, table_id, 1, table_descriptor, 0, 0) + } + + fn build_cluster(table_path: &TablePath, table_id: i64) -> Arc { + let server = ServerNode::new(1, "127.0.0.1".to_string(), 9092, ServerType::TabletServer); + let table_bucket = TableBucket::new(table_id, 0); + let bucket_location = BucketLocation::new( + table_bucket.clone(), + Some(server.clone()), + table_path.clone(), + ); + + let mut servers = HashMap::new(); + servers.insert(server.id(), server); + + let mut locations_by_path = HashMap::new(); + locations_by_path.insert(table_path.clone(), vec![bucket_location.clone()]); + + let mut locations_by_bucket = HashMap::new(); + locations_by_bucket.insert(table_bucket, bucket_location); + + let mut table_id_by_path = HashMap::new(); + table_id_by_path.insert(table_path.clone(), table_id); + + let mut table_info_by_path = HashMap::new(); + table_info_by_path.insert( + table_path.clone(), + build_table_info(table_path.clone(), table_id), + ); + + Arc::new(Cluster::new( + None, + servers, + locations_by_path, + locations_by_bucket, + table_id_by_path, + table_info_by_path, + )) + } + + fn build_cluster_with_server(server: ServerNode) -> Arc { + Arc::new(Cluster::new( + None, + HashMap::from([(server.id(), server)]), + HashMap::new(), + HashMap::new(), + HashMap::new(), + HashMap::new(), + )) + } + + fn build_metadata_response(table_path: &TablePath, table_id: i64) -> MetadataResponse { + let row_type = DataTypes::row(vec![DataField::new( + "id".to_string(), + DataTypes::int(), + None, + )]); + let mut schema_builder = Schema::builder().with_row_type(&row_type); + let schema = schema_builder.build().expect("schema build"); + let table_descriptor = TableDescriptor::builder() + .schema(schema) + .distributed_by(Some(1), 
vec![]) + .build() + .expect("descriptor build"); + let table_json = + serde_json::to_vec(&table_descriptor.serialize_json().expect("table json")).unwrap(); + + MetadataResponse { + coordinator_server: Some(PbServerNode { + node_id: 10, + host: "127.0.0.1".to_string(), + port: 9999, + listeners: None, + }), + tablet_servers: vec![PbServerNode { + node_id: 1, + host: "127.0.0.1".to_string(), + port: 9092, + listeners: None, + }], + table_metadata: vec![PbTableMetadata { + table_path: PbTablePath { + database_name: table_path.database().to_string(), + table_name: table_path.table().to_string(), + }, + table_id, + schema_id: 1, + table_json, + bucket_metadata: vec![PbBucketMetadata { + bucket_id: 0, + leader_id: Some(1), + replica_id: vec![1], + }], + created_time: 0, + modified_time: 0, + }], + partition_metadata: vec![], + } + } #[test] fn leader_for_returns_server() { let table_path = TablePath::new("db".to_string(), "tbl".to_string()); - let cluster = build_cluster_arc(&table_path, 1, 1); + let cluster = build_cluster(&table_path, 1); let metadata = Metadata::new_for_test(cluster); let leader = metadata .leader_for(&TableBucket::new(1, 0)) @@ -182,10 +306,87 @@ mod tests { #[test] fn invalidate_server_removes_leader() { let table_path = TablePath::new("db".to_string(), "tbl".to_string()); - let cluster = build_cluster_arc(&table_path, 1, 1); + let cluster = build_cluster(&table_path, 1); let metadata = Metadata::new_for_test(cluster); metadata.invalidate_server(&1, vec![1]); let cluster = metadata.get_cluster(); assert!(cluster.get_tablet_server(1).is_none()); } + + #[tokio::test] + async fn update_replaces_cluster_state() -> Result<()> { + let table_path = TablePath::new("db".to_string(), "tbl".to_string()); + let metadata = Metadata::new_for_test(Arc::new(Cluster::default())); + + let response = build_metadata_response(&table_path, 1); + metadata.update(response).await?; + + let cluster = metadata.get_cluster(); + assert!(cluster.get_tablet_server(1).is_some()); + assert!(cluster.opt_get_table(&table_path).is_some()); + Ok(()) + } + + #[tokio::test] + async fn check_and_update_table_metadata_noop_when_present() -> Result<()> { + let table_path = TablePath::new("db".to_string(), "tbl".to_string()); + let cluster = build_cluster(&table_path, 1); + let metadata = Metadata::new_for_test(cluster); + metadata + .check_and_update_table_metadata(std::slice::from_ref(&table_path)) + .await?; + let cluster = metadata.get_cluster(); + assert!(cluster.opt_get_table(&table_path).is_some()); + Ok(()) + } + + #[tokio::test] + async fn update_tables_metadata_refreshes_cluster() -> Result<()> { + let table_path = TablePath::new("db".to_string(), "tbl".to_string()); + let server = ServerNode::new(1, "127.0.0.1".to_string(), 9092, ServerType::TabletServer); + let metadata = Metadata::new_for_test(build_cluster_with_server(server.clone())); + let response = build_metadata_response(&table_path, 1); + let response_bytes = response.encode_to_vec(); + let (connection, handle) = + build_mock_connection(move |api_key, _, _| match i16::from(api_key) { + API_UPDATE_METADATA => response_bytes.clone(), + _ => vec![], + }) + .await; + metadata + .connections + .insert_connection_for_test(&server, connection); + + metadata + .update_tables_metadata(&HashSet::from([&table_path])) + .await?; + assert!(metadata.get_cluster().opt_get_table(&table_path).is_some()); + handle.abort(); + Ok(()) + } + + #[tokio::test] + async fn check_and_update_table_metadata_triggers_update() -> Result<()> { + let table_path = 
TablePath::new("db".to_string(), "tbl".to_string()); + let server = ServerNode::new(1, "127.0.0.1".to_string(), 9092, ServerType::TabletServer); + let metadata = Metadata::new_for_test(build_cluster_with_server(server.clone())); + let response = build_metadata_response(&table_path, 1); + let response_bytes = response.encode_to_vec(); + let (connection, handle) = + build_mock_connection(move |api_key, _, _| match i16::from(api_key) { + API_UPDATE_METADATA => response_bytes.clone(), + _ => vec![], + }) + .await; + metadata + .connections + .insert_connection_for_test(&server, connection); + + metadata + .check_and_update_table_metadata(std::slice::from_ref(&table_path)) + .await?; + assert!(metadata.get_cluster().opt_get_table(&table_path).is_some()); + handle.abort(); + Ok(()) + } } diff --git a/crates/fluss/src/client/mod.rs b/crates/fluss/src/client/mod.rs index cff218b3..27798a30 100644 --- a/crates/fluss/src/client/mod.rs +++ b/crates/fluss/src/client/mod.rs @@ -28,3 +28,6 @@ pub use credentials::*; pub use metadata::*; pub use table::*; pub use write::*; + +#[cfg(test)] +pub(crate) use table::log_fetch_buffer::{CompletedFetch, FetchErrorContext}; diff --git a/crates/fluss/src/client/table/log_fetch_buffer.rs b/crates/fluss/src/client/table/log_fetch_buffer.rs index 214a79cd..17be8c5f 100644 --- a/crates/fluss/src/client/table/log_fetch_buffer.rs +++ b/crates/fluss/src/client/table/log_fetch_buffer.rs @@ -652,9 +652,17 @@ mod tests { ArrowCompressionInfo, ArrowCompressionType, DEFAULT_NON_ZSTD_COMPRESSION_LEVEL, }; use crate::metadata::{DataField, DataTypes, RowType, TablePath}; - use crate::record::{MemoryLogRecordsArrowBuilder, ReadContext, to_arrow_schema}; + use crate::record::{ + LENGTH_LENGTH, LENGTH_OFFSET, LOG_OVERHEAD, MemoryLogRecordsArrowBuilder, + RECORDS_COUNT_LENGTH, RECORDS_COUNT_OFFSET, RECORDS_OFFSET, ReadContext, to_arrow_schema, + }; use crate::row::GenericRow; + use crate::test_utils::{ + TestCompletedFetch, build_read_context_for_int32, build_single_int_scan_record, + }; + use std::collections::HashSet; use std::sync::Arc; + use std::sync::atomic::{AtomicBool, Ordering}; use std::time::Duration; fn test_read_context() -> Result { @@ -712,6 +720,83 @@ mod tests { assert!(completed.take_error().is_some()); } + #[test] + fn buffered_buckets_include_pending_and_next_in_line() { + let buffer = LogFetchBuffer::new(build_read_context_for_int32()); + let bucket_pending = TableBucket::new(1, 0); + let bucket_next = TableBucket::new(1, 1); + let bucket_completed = TableBucket::new(1, 2); + + buffer.pend(Box::new(ErrorPendingFetch { + table_bucket: bucket_pending.clone(), + })); + buffer.set_next_in_line_fetch(Some(Box::new(TestCompletedFetch::new(bucket_next.clone())))); + buffer.add(Box::new(TestCompletedFetch::new(bucket_completed.clone()))); + + let buckets: HashSet = buffer.buffered_buckets().into_iter().collect(); + assert!(buckets.contains(&bucket_pending)); + assert!(buckets.contains(&bucket_next)); + assert!(buckets.contains(&bucket_completed)); + } + + #[test] + fn pended_buckets_only_returns_pending() { + let buffer = LogFetchBuffer::new(build_read_context_for_int32()); + let bucket_pending = TableBucket::new(1, 0); + buffer.pend(Box::new(ErrorPendingFetch { + table_bucket: bucket_pending.clone(), + })); + buffer.add(Box::new(TestCompletedFetch::new(TableBucket::new(1, 1)))); + + let pending: HashSet = buffer.pending_fetches.lock().keys().cloned().collect(); + assert_eq!(pending, HashSet::from([bucket_pending])); + } + + #[test] + fn 
add_with_pending_keeps_buffer_empty_until_completed() { + struct PendingGate { + table_bucket: TableBucket, + completed: Arc, + } + + impl PendingFetch for PendingGate { + fn table_bucket(&self) -> &TableBucket { + &self.table_bucket + } + + fn is_completed(&self) -> bool { + self.completed.load(Ordering::Acquire) + } + + fn to_completed_fetch(self: Box) -> Result> { + Ok(Box::new(TestCompletedFetch::new(self.table_bucket.clone()))) + } + } + + let buffer = LogFetchBuffer::new(build_read_context_for_int32()); + let bucket = TableBucket::new(1, 0); + let completed = Arc::new(AtomicBool::new(false)); + let pending = PendingGate { + table_bucket: bucket.clone(), + completed: completed.clone(), + }; + buffer.pend(Box::new(pending)); + + buffer.add(Box::new(TestCompletedFetch::new(bucket.clone()))); + assert!(buffer.is_empty()); + + { + let pending = buffer.pending_fetches.lock(); + let entry = pending.get(&bucket).expect("pending"); + assert_eq!(entry.len(), 2); + } + + completed.store(true, Ordering::Release); + + buffer.try_complete(&bucket); + assert!(!buffer.is_empty()); + } + #[test] fn default_completed_fetch_reads_records() -> Result<()> { let row_type = RowType::new(vec![ @@ -757,4 +842,188 @@ mod tests { Ok(()) } + + #[test] + fn default_completed_fetch_propagates_error_for_records() { + let read_context = build_read_context_for_int32(); + let mut fetch = DefaultCompletedFetch::from_error( + TableBucket::new(1, 0), + Error::UnexpectedError { + message: "fetch failed".to_string(), + source: None, + }, + 0, + read_context, + ); + + let err = match fetch.fetch_records(1) { + Ok(_) => panic!("expected error"), + Err(err) => err, + }; + assert!(matches!(err, Error::UnexpectedError { .. })); + } + + #[test] + fn default_completed_fetch_propagates_error_for_batches() { + let read_context = build_read_context_for_int32(); + let mut fetch = DefaultCompletedFetch::from_error( + TableBucket::new(1, 0), + Error::UnexpectedError { + message: "fetch failed".to_string(), + source: None, + }, + 0, + read_context, + ); + + let err = fetch.fetch_batches(1).expect_err("expected error"); + assert!(matches!(err, Error::UnexpectedError { .. 
})); + } + + #[test] + fn default_completed_fetch_propagates_api_error_for_records() { + let read_context = build_read_context_for_int32(); + let fetch_error_context = FetchErrorContext { + action: FetchErrorAction::Authorization, + log_level: FetchErrorLogLevel::Warn, + log_message: "authorization failed".to_string(), + }; + let mut fetch = DefaultCompletedFetch::from_api_error( + TableBucket::new(1, 0), + ApiError { + code: 7, + message: "auth failed".to_string(), + }, + fetch_error_context, + 0, + read_context, + ); + + let err = match fetch.fetch_records(1) { + Ok(_) => panic!("expected api error"), + Err(err) => err, + }; + match err { + Error::FlussAPIError { api_error } => { + assert_eq!(api_error.code, 7); + assert_eq!(api_error.message, "auth failed"); + } + _ => panic!("unexpected error type"), + } + } + + #[test] + fn default_completed_fetch_propagates_api_error_for_batches() { + let read_context = build_read_context_for_int32(); + let fetch_error_context = FetchErrorContext { + action: FetchErrorAction::Authorization, + log_level: FetchErrorLogLevel::Warn, + log_message: "authorization failed".to_string(), + }; + let mut fetch = DefaultCompletedFetch::from_api_error( + TableBucket::new(1, 0), + ApiError { + code: 7, + message: "auth failed".to_string(), + }, + fetch_error_context, + 0, + read_context, + ); + + let err = fetch.fetch_batches(1).expect_err("expected api error"); + match err { + Error::FlussAPIError { api_error } => { + assert_eq!(api_error.code, 7); + assert_eq!(api_error.message, "auth failed"); + } + _ => panic!("unexpected error type"), + } + } + + #[test] + fn default_completed_fetch_returns_error_on_corrupt_last_record() { + let read_context = build_read_context_for_int32(); + let mut fetch = DefaultCompletedFetch::new( + TableBucket::new(1, 0), + LogRecordsBatches::new(Vec::new()), + 0, + read_context, + 0, + 0, + ); + fetch.corrupt_last_record = true; + + let err = match fetch.fetch_records(1) { + Ok(_) => panic!("expected error"), + Err(err) => err, + }; + assert!(matches!(err, Error::UnexpectedError { .. })); + } + + #[test] + fn default_completed_fetch_returns_error_when_cached_error_without_records() { + let read_context = build_read_context_for_int32(); + let mut fetch = DefaultCompletedFetch::new( + TableBucket::new(1, 0), + LogRecordsBatches::new(Vec::new()), + 0, + read_context, + 0, + 0, + ); + fetch.cached_record_error = Some("decode failure".to_string()); + + let err = match fetch.fetch_records(1) { + Ok(_) => panic!("expected error"), + Err(err) => err, + }; + match err { + Error::UnexpectedError { message, .. 
} => { + assert!(message.contains("decode failure")); + } + _ => panic!("unexpected error type"), + } + } + + #[test] + fn default_completed_fetch_returns_partial_records_when_cached_error_after_record() { + let read_context = build_read_context_for_int32(); + let mut fetch = DefaultCompletedFetch::new( + TableBucket::new(1, 0), + LogRecordsBatches::new(Vec::new()), + 0, + read_context, + 0, + 0, + ); + fetch.cached_record_error = Some("decode failure".to_string()); + fetch.last_record = Some(build_single_int_scan_record()); + + let records = fetch.fetch_records(1).expect("records"); + assert_eq!(records.len(), 1); + } + + #[test] + fn default_completed_fetch_returns_error_on_invalid_batch_payload() { + let read_context = build_read_context_for_int32(); + let total_len = RECORDS_OFFSET; + let batch_len = total_len - LOG_OVERHEAD; + let mut data = vec![0_u8; total_len]; + data[LENGTH_OFFSET..LENGTH_OFFSET + LENGTH_LENGTH] + .copy_from_slice(&(batch_len as i32).to_le_bytes()); + data[RECORDS_COUNT_OFFSET..RECORDS_COUNT_OFFSET + RECORDS_COUNT_LENGTH] + .copy_from_slice(&1_i32.to_le_bytes()); + let mut fetch = DefaultCompletedFetch::new( + TableBucket::new(1, 0), + LogRecordsBatches::new(data), + total_len, + read_context, + 0, + 0, + ); + + let err = fetch.fetch_batches(1).expect_err("expected error"); + assert!(matches!(err, Error::ArrowError { .. })); + } } diff --git a/crates/fluss/src/client/table/mod.rs b/crates/fluss/src/client/table/mod.rs index 2bfa0541..2940cf2b 100644 --- a/crates/fluss/src/client/table/mod.rs +++ b/crates/fluss/src/client/table/mod.rs @@ -26,7 +26,7 @@ pub const EARLIEST_OFFSET: i64 = -2; mod append; mod lookup; -mod log_fetch_buffer; +pub(crate) mod log_fetch_buffer; mod partition_getter; mod remote_log; mod scanner; diff --git a/crates/fluss/src/client/table/remote_log.rs b/crates/fluss/src/client/table/remote_log.rs index 01425157..c977ffbe 100644 --- a/crates/fluss/src/client/table/remote_log.rs +++ b/crates/fluss/src/client/table/remote_log.rs @@ -414,3 +414,252 @@ impl PendingFetch for RemotePendingFetch { Ok(Box::new(completed_fetch)) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::client::WriteRecord; + use crate::compression::{ + ArrowCompressionInfo, ArrowCompressionType, DEFAULT_NON_ZSTD_COMPRESSION_LEVEL, + }; + use crate::metadata::{DataField, DataTypes, RowType, TablePath}; + use crate::record::{MemoryLogRecordsArrowBuilder, to_arrow_schema}; + use crate::row::{Datum, GenericRow}; + use std::collections::HashMap; + use std::sync::Arc; + use std::sync::atomic::{AtomicBool, Ordering}; + use tokio::time::{Duration, timeout}; + + fn build_log_bytes() -> Result> { + let row_type = RowType::new(vec![DataField::new( + "id".to_string(), + DataTypes::int(), + None, + )]); + let table_path = Arc::new(TablePath::new("db".to_string(), "tbl".to_string())); + let mut builder = MemoryLogRecordsArrowBuilder::new( + 1, + &row_type, + false, + ArrowCompressionInfo { + compression_type: ArrowCompressionType::None, + compression_level: DEFAULT_NON_ZSTD_COMPRESSION_LEVEL, + }, + )?; + let record = WriteRecord::for_append( + table_path, + 1, + GenericRow { + values: vec![Datum::Int32(1)], + }, + ); + builder.append(&record)?; + builder.build() + } + + #[test] + fn remote_log_segment_local_file_name() { + let bucket = TableBucket::new(1, 0); + let segment = RemoteLogSegment { + segment_id: "seg".to_string(), + start_offset: 12, + end_offset: 20, + size_in_bytes: 0, + table_bucket: bucket, + }; + assert_eq!(segment.local_file_name(), 
"seg_00000000000000000012.log"); + } + + #[test] + fn remote_log_fetch_info_from_proto_defaults_start_pos() { + let bucket = TableBucket::new(1, 0); + let proto_segment = PbRemoteLogSegment { + remote_log_segment_id: "seg".to_string(), + remote_log_start_offset: 0, + remote_log_end_offset: 10, + segment_size_in_bytes: 1, + }; + let proto_info = PbRemoteLogFetchInfo { + remote_log_tablet_dir: "/tmp/remote".to_string(), + partition_name: None, + remote_log_segments: vec![proto_segment], + first_start_pos: None, + }; + let info = RemoteLogFetchInfo::from_proto(&proto_info, bucket.clone()); + assert_eq!(info.first_start_pos, 0); + assert_eq!(info.remote_log_segments.len(), 1); + assert_eq!(info.remote_log_segments[0].table_bucket, bucket); + } + + #[tokio::test] + async fn download_future_callbacks_fire() -> Result<()> { + let (tx, rx) = oneshot::channel(); + let future = RemoteLogDownloadFuture::new(rx); + let fired = Arc::new(AtomicBool::new(false)); + let fired_clone = fired.clone(); + future.on_complete(move || { + fired_clone.store(true, Ordering::Release); + }); + tx.send(Ok(vec![1, 2, 3])).unwrap(); + + let _ = timeout(Duration::from_millis(50), async { + while !future.is_done() { + tokio::task::yield_now().await; + } + }) + .await; + + assert!(fired.load(Ordering::Acquire)); + Ok(()) + } + + #[tokio::test] + async fn download_future_get_bytes_errors() -> Result<()> { + let (tx, rx) = oneshot::channel(); + let future = RemoteLogDownloadFuture::new(rx); + assert!(matches!( + future.get_remote_log_bytes(), + Err(Error::IoUnexpectedError { .. }) + )); + tx.send(Err(Error::UnexpectedError { + message: "boom".to_string(), + source: None, + })) + .unwrap(); + + let _ = timeout(Duration::from_millis(50), async { + while !future.is_done() { + tokio::task::yield_now().await; + } + }) + .await; + + let err = future.get_remote_log_bytes().unwrap_err(); + assert!(matches!(err, Error::IoUnexpectedError { .. })); + Ok(()) + } + + #[tokio::test] + async fn download_future_canceled_returns_error() -> Result<()> { + let (tx, rx) = oneshot::channel::>>(); + drop(tx); + let future = RemoteLogDownloadFuture::new(rx); + let _ = timeout(Duration::from_millis(50), async { + while !future.is_done() { + tokio::task::yield_now().await; + } + }) + .await; + + let err = future.get_remote_log_bytes().unwrap_err(); + assert!(matches!(err, Error::IoUnexpectedError { .. })); + Ok(()) + } + + #[tokio::test] + async fn download_file_rejects_invalid_url() -> Result<()> { + let temp_dir = TempDir::new()?; + let local_path = temp_dir.path().join("local.log"); + let result = RemoteLogDownloader::download_file( + "://", + "/tmp/missing.log", + &local_path, + &HashMap::new(), + ) + .await; + assert!(matches!(result, Err(Error::IllegalArgument { .. }))); + Ok(()) + } + + #[tokio::test] + async fn download_file_missing_remote_path_returns_error() -> Result<()> { + let temp_dir = TempDir::new()?; + let remote_dir = temp_dir.path(); + let local_path = temp_dir.path().join("local.log"); + let remote_path = remote_dir.join("missing.log"); + let result = RemoteLogDownloader::download_file( + remote_dir.to_str().unwrap(), + remote_path.to_str().unwrap(), + &local_path, + &HashMap::new(), + ) + .await; + assert!(matches!( + result, + Err(Error::RemoteStorageUnexpectedError { .. 
}) + )); + Ok(()) + } + + #[tokio::test] + async fn request_remote_log_reads_local_file() -> Result<()> { + let temp_dir = TempDir::new()?; + let remote_dir = temp_dir.path().join("remote"); + tokio::fs::create_dir_all(&remote_dir).await?; + + let bucket = TableBucket::new(1, 0); + let segment = RemoteLogSegment { + segment_id: "seg".to_string(), + start_offset: 0, + end_offset: 0, + size_in_bytes: 0, + table_bucket: bucket, + }; + + let downloader = RemoteLogDownloader::new(TempDir::new()?)?; + let remote_path = downloader.build_remote_path(remote_dir.to_str().unwrap(), &segment); + let remote_file = PathBuf::from(&remote_path); + if let Some(parent) = remote_file.parent() { + tokio::fs::create_dir_all(parent).await?; + } + tokio::fs::write(&remote_file, b"data").await?; + + let future = downloader.request_remote_log(remote_dir.to_str().unwrap(), &segment); + let _ = timeout(Duration::from_millis(200), async { + while !future.is_done() { + tokio::task::yield_now().await; + } + }) + .await; + + let bytes = future.get_remote_log_bytes()?; + assert_eq!(bytes, b"data"); + Ok(()) + } + + #[tokio::test] + async fn remote_pending_fetch_to_completed_fetch() -> Result<()> { + let bytes = build_log_bytes()?; + let (tx, rx) = oneshot::channel(); + tx.send(Ok(bytes)).unwrap(); + let future = RemoteLogDownloadFuture::new(rx); + let _ = timeout(Duration::from_millis(50), async { + while !future.is_done() { + tokio::task::yield_now().await; + } + }) + .await; + + let row_type = RowType::new(vec![DataField::new( + "id".to_string(), + DataTypes::int(), + None, + )]); + let read_context = ReadContext::new(to_arrow_schema(&row_type)?, false); + + let bucket = TableBucket::new(1, 0); + let segment = RemoteLogSegment { + segment_id: "seg".to_string(), + start_offset: 0, + end_offset: 0, + size_in_bytes: 0, + table_bucket: bucket.clone(), + }; + + let pending = RemotePendingFetch::new(segment, future, 0, 0, 0, read_context); + let mut completed = Box::new(pending).to_completed_fetch()?; + let records = completed.fetch_records(10)?; + assert_eq!(records.len(), 1); + Ok(()) + } +} diff --git a/crates/fluss/src/client/table/scanner.rs b/crates/fluss/src/client/table/scanner.rs index afa44f35..c21ce86f 100644 --- a/crates/fluss/src/client/table/scanner.rs +++ b/crates/fluss/src/client/table/scanner.rs @@ -1466,18 +1466,104 @@ impl BucketScanStatus { #[cfg(test)] mod tests { use super::*; + use crate::client::FlussConnection; use crate::client::WriteRecord; use crate::client::metadata::Metadata; + use crate::cluster::{BucketLocation, Cluster, ServerNode, ServerType}; use crate::compression::{ ArrowCompressionInfo, ArrowCompressionType, DEFAULT_NON_ZSTD_COMPRESSION_LEVEL, }; + use crate::config::Config; use crate::metadata::{TableInfo, TablePath}; + use crate::proto::{FetchLogResponse, PbFetchLogRespForBucket, PbFetchLogRespForTable}; use crate::record::MemoryLogRecordsArrowBuilder; use crate::row::{Datum, GenericRow}; use crate::rpc::FlussError; - use crate::test_utils::{build_cluster_arc, build_table_info}; + use crate::test_utils::TestCompletedFetch; + use crate::test_utils::{build_cluster_arc, build_mock_connection, build_table_info}; + use prost::Message; + use std::collections::HashMap; + use std::time::Duration; + use tokio::io::BufStream; + use tokio::time::timeout; + + const DEFAULT_TABLE_ID: i64 = 1; + const DEFAULT_BUCKETS: i32 = 1; + + fn default_table_path() -> TablePath { + TablePath::new("db".to_string(), "tbl".to_string()) + } + + struct ScannerTestEnv { + table_path: TablePath, + table_info: 
TableInfo, + metadata: Arc, + status: Arc, + rpc_client: Arc, + } + + impl ScannerTestEnv { + fn new() -> Self { + let table_path = default_table_path(); + let table_info = + build_table_info(table_path.clone(), DEFAULT_TABLE_ID, DEFAULT_BUCKETS); + let cluster = build_cluster_arc(&table_path, DEFAULT_TABLE_ID, DEFAULT_BUCKETS); + Self::with_table_info_and_cluster(table_info, cluster) + } + + fn with_table_info_and_cluster(table_info: TableInfo, cluster: Arc) -> Self { + let metadata = Arc::new(Metadata::new_for_test(cluster)); + let status = Arc::new(LogScannerStatus::new()); + let rpc_client = Arc::new(RpcClient::new()); + let table_path = table_info.table_path.clone(); + Self { + table_path, + table_info, + metadata, + status, + rpc_client, + } + } + + fn fetcher(&self, projected_fields: Option>) -> Result { + LogFetcher::new( + self.table_info.clone(), + self.rpc_client.clone(), + self.metadata.clone(), + self.status.clone(), + projected_fields, + ) + } - fn build_records(table_info: &TableInfo, table_path: Arc) -> Result> { + fn inner(&self, projected_fields: Option>) -> Result { + LogScannerInner::new( + &self.table_info, + self.metadata.clone(), + self.rpc_client.clone(), + projected_fields, + ) + } + + fn connection(&self) -> FlussConnection { + FlussConnection::new_for_test( + self.metadata.clone(), + self.rpc_client.clone(), + Config::default(), + ) + } + + fn assign_bucket(&self, bucket_id: i32, offset: i64) -> TableBucket { + let bucket = TableBucket::new(self.table_info.table_id, bucket_id); + self.status.assign_scan_bucket(bucket.clone(), offset); + bucket + } + + fn build_records(&self) -> Result> { + build_records(&self.table_info, &self.table_path) + } + } + + fn build_records(table_info: &TableInfo, table_path: &TablePath) -> Result> { let mut builder = MemoryLogRecordsArrowBuilder::new( 1, table_info.get_row_type(), @@ -1488,8 +1574,8 @@ mod tests { }, )?; let record = WriteRecord::for_append( - table_path, - 1, + Arc::new(table_path.clone()), + table_info.schema_id, GenericRow { values: vec![Datum::Int32(1)], }, @@ -1498,56 +1584,177 @@ mod tests { builder.build() } + fn build_cluster_with_leader( + table_info: &TableInfo, + leader: Option, + include_server: bool, + ) -> Arc { + let table_bucket = TableBucket::new(table_info.table_id, 0); + let mut servers = HashMap::new(); + if include_server { + if let Some(server) = leader.clone() { + servers.insert(server.id(), server); + } + } + let location = + BucketLocation::new(table_bucket.clone(), leader, table_info.table_path.clone()); + let locations_by_path = + HashMap::from([(table_info.table_path.clone(), vec![location.clone()])]); + let locations_by_bucket = HashMap::from([(table_bucket, location)]); + let table_id_by_path = + HashMap::from([(table_info.table_path.clone(), table_info.table_id)]); + let table_info_by_path = + HashMap::from([(table_info.table_path.clone(), table_info.clone())]); + Arc::new(Cluster::new( + None, + servers, + locations_by_path, + locations_by_bucket, + table_id_by_path, + table_info_by_path, + )) + } + + async fn collect_result_for_error( + error: FlussError, + ) -> Result>, Error>> { + let env = ScannerTestEnv::new(); + env.assign_bucket(0, 0); + let fetcher = env.fetcher(None)?; + + let response = FetchLogResponse { + tables_resp: vec![PbFetchLogRespForTable { + table_id: env.table_info.table_id, + buckets_resp: vec![PbFetchLogRespForBucket { + partition_id: None, + bucket_id: 0, + error_code: Some(error.code()), + error_message: Some("err".to_string()), + high_watermark: None, + 
log_start_offset: None, + remote_log_fetch_info: None, + records: None, + }], + }], + }; + + let response_context = FetchResponseContext { + metadata: env.metadata.clone(), + log_fetch_buffer: fetcher.log_fetch_buffer.clone(), + log_scanner_status: fetcher.log_scanner_status.clone(), + read_context: fetcher.read_context.clone(), + remote_read_context: fetcher.remote_read_context.clone(), + remote_log_downloader: fetcher.remote_log_downloader.clone(), + credentials_cache: fetcher.credentials_cache.clone(), + }; + + LogFetcher::handle_fetch_response(response, response_context).await; + + Ok(fetcher.collect_fetches()) + } + + #[test] + fn project_rejects_empty_indices() -> Result<()> { + let env = ScannerTestEnv::new(); + let conn = env.connection(); + + let builder = TableScan::new(&conn, env.table_info.clone(), env.metadata.clone()); + let result = builder.project(&[]); + assert!(matches!(result, Err(Error::IllegalArgument { .. }))); + Ok(()) + } + + #[test] + fn project_rejects_out_of_range_index() -> Result<()> { + let env = ScannerTestEnv::new(); + let conn = env.connection(); + + let builder = TableScan::new(&conn, env.table_info.clone(), env.metadata.clone()); + let result = builder.project(&[1]); + assert!(matches!(result, Err(Error::IllegalArgument { .. }))); + Ok(()) + } + + #[test] + fn project_by_name_rejects_empty() -> Result<()> { + let env = ScannerTestEnv::new(); + let conn = env.connection(); + + let builder = TableScan::new(&conn, env.table_info.clone(), env.metadata.clone()); + let result = builder.project_by_name(&[]); + assert!(matches!(result, Err(Error::IllegalArgument { .. }))); + Ok(()) + } + + #[test] + fn project_by_name_rejects_missing_column() -> Result<()> { + let env = ScannerTestEnv::new(); + let conn = env.connection(); + + let builder = TableScan::new(&conn, env.table_info.clone(), env.metadata.clone()); + let result = builder.project_by_name(&["missing"]); + assert!(matches!(result, Err(Error::IllegalArgument { .. }))); + Ok(()) + } + + #[test] + fn log_fetcher_rejects_invalid_projection() -> Result<()> { + let env = ScannerTestEnv::new(); + + let result = env.fetcher(Some(vec![1])); + assert!(matches!(result, Err(Error::IllegalArgument { .. 
}))); + Ok(()) + } + + async fn wait_for_leader_removal( + metadata: &Metadata, + table_bucket: &TableBucket, + ) -> Result<()> { + timeout(Duration::from_millis(500), async { + loop { + if metadata.get_cluster().leader_for(table_bucket).is_none() { + break; + } + tokio::task::yield_now().await; + } + }) + .await + .map_err(|_| Error::UnexpectedError { + message: "Timeout waiting for leader removal".to_string(), + source: None, + })?; + Ok(()) + } + #[tokio::test] async fn collect_fetches_updates_offset() -> Result<()> { - let table_path = TablePath::new("db".to_string(), "tbl".to_string()); - let table_info = build_table_info(table_path.clone(), 1, 1); - let cluster = build_cluster_arc(&table_path, 1, 1); - let metadata = Arc::new(Metadata::new_for_test(cluster)); - let status = Arc::new(LogScannerStatus::new()); - let fetcher = LogFetcher::new( - table_info.clone(), - Arc::new(RpcClient::new()), - metadata, - status.clone(), - None, - )?; + let env = ScannerTestEnv::new(); + let fetcher = env.fetcher(None)?; - let bucket = TableBucket::new(1, 0); - status.assign_scan_bucket(bucket.clone(), 0); + let bucket = env.assign_bucket(0, 0); - let data = build_records(&table_info, Arc::new(table_path))?; + let data = env.build_records()?; let log_records = LogRecordsBatches::new(data.clone()); - let read_context = ReadContext::new(to_arrow_schema(table_info.get_row_type())?, false); + let read_context = ReadContext::new(to_arrow_schema(env.table_info.get_row_type())?, false); let completed = DefaultCompletedFetch::new(bucket.clone(), log_records, data.len(), read_context, 0, 0); fetcher.log_fetch_buffer.add(Box::new(completed)); let fetched = fetcher.collect_fetches()?; assert_eq!(fetched.get(&bucket).unwrap().len(), 1); - assert_eq!(status.get_bucket_offset(&bucket), Some(1)); + assert_eq!(env.status.get_bucket_offset(&bucket), Some(1)); Ok(()) } #[test] fn fetch_records_from_fetch_drains_unassigned_bucket() -> Result<()> { - let table_path = TablePath::new("db".to_string(), "tbl".to_string()); - let table_info = build_table_info(table_path.clone(), 1, 1); - let cluster = build_cluster_arc(&table_path, 1, 1); - let metadata = Arc::new(Metadata::new_for_test(cluster)); - let status = Arc::new(LogScannerStatus::new()); - let fetcher = LogFetcher::new( - table_info.clone(), - Arc::new(RpcClient::new()), - metadata, - status, - None, - )?; + let env = ScannerTestEnv::new(); + let fetcher = env.fetcher(None)?; - let bucket = TableBucket::new(1, 0); - let data = build_records(&table_info, Arc::new(table_path))?; + let bucket = TableBucket::new(env.table_info.table_id, 0); + let data = env.build_records()?; let log_records = LogRecordsBatches::new(data.clone()); - let read_context = ReadContext::new(to_arrow_schema(table_info.get_row_type())?, false); + let read_context = ReadContext::new(to_arrow_schema(env.table_info.get_row_type())?, false); let mut completed: Box = Box::new(DefaultCompletedFetch::new( bucket, log_records, @@ -1565,19 +1772,9 @@ mod tests { #[tokio::test] async fn prepare_fetch_log_requests_skips_pending() -> Result<()> { - let table_path = TablePath::new("db".to_string(), "tbl".to_string()); - let table_info = build_table_info(table_path.clone(), 1, 1); - let cluster = build_cluster_arc(&table_path, 1, 1); - let metadata = Arc::new(Metadata::new_for_test(cluster)); - let status = Arc::new(LogScannerStatus::new()); - status.assign_scan_bucket(TableBucket::new(1, 0), 0); - let fetcher = LogFetcher::new( - table_info, - Arc::new(RpcClient::new()), - metadata, - status, - None, - )?; + 
let env = ScannerTestEnv::new(); + env.assign_bucket(0, 0); + let fetcher = env.fetcher(None)?; fetcher.nodes_with_pending_fetch_requests.lock().insert(1); @@ -1586,25 +1783,43 @@ mod tests { Ok(()) } + #[tokio::test] + async fn prepare_fetch_log_requests_skips_without_leader() -> Result<()> { + let table_path = default_table_path(); + let table_info = build_table_info(table_path.clone(), DEFAULT_TABLE_ID, DEFAULT_BUCKETS); + let cluster = build_cluster_with_leader(&table_info, None, false); + let env = ScannerTestEnv::with_table_info_and_cluster(table_info, cluster); + env.assign_bucket(0, 0); + let fetcher = env.fetcher(None)?; + + let requests = fetcher.prepare_fetch_log_requests().await; + assert!(requests.is_empty()); + Ok(()) + } + + #[tokio::test] + async fn prepare_fetch_log_requests_sets_projection() -> Result<()> { + let env = ScannerTestEnv::new(); + env.assign_bucket(0, 0); + let fetcher = env.fetcher(Some(vec![0]))?; + + let requests = fetcher.prepare_fetch_log_requests().await; + let request = requests.get(&1).expect("fetch request"); + let table_req = request.tables_req.first().expect("table request"); + assert!(table_req.projection_pushdown_enabled); + assert_eq!(table_req.projected_fields, vec![0]); + Ok(()) + } + #[tokio::test] async fn handle_fetch_response_sets_error() -> Result<()> { - let table_path = TablePath::new("db".to_string(), "tbl".to_string()); - let table_info = build_table_info(table_path.clone(), 1, 1); - let cluster = build_cluster_arc(&table_path, 1, 1); - let metadata = Arc::new(Metadata::new_for_test(cluster)); - let status = Arc::new(LogScannerStatus::new()); - status.assign_scan_bucket(TableBucket::new(1, 0), 5); - let fetcher = LogFetcher::new( - table_info.clone(), - Arc::new(RpcClient::new()), - metadata.clone(), - status.clone(), - None, - )?; + let env = ScannerTestEnv::new(); + env.assign_bucket(0, 5); + let fetcher = env.fetcher(None)?; let response = crate::proto::FetchLogResponse { tables_resp: vec![crate::proto::PbFetchLogRespForTable { - table_id: 1, + table_id: env.table_info.table_id, buckets_resp: vec![crate::proto::PbFetchLogRespForBucket { partition_id: None, bucket_id: 0, @@ -1619,7 +1834,7 @@ mod tests { }; let response_context = FetchResponseContext { - metadata: metadata.clone(), + metadata: env.metadata.clone(), log_fetch_buffer: fetcher.log_fetch_buffer.clone(), log_scanner_status: fetcher.log_scanner_status.clone(), read_context: fetcher.read_context.clone(), @@ -1636,6 +1851,406 @@ mod tests { Ok(()) } + #[tokio::test] + async fn collect_fetches_ignores_retriable_errors() -> Result<()> { + let ignore_errors = [ + FlussError::NotLeaderOrFollower, + FlussError::LogStorageException, + FlussError::KvStorageException, + FlussError::StorageException, + FlussError::FencedLeaderEpochException, + FlussError::UnknownTableOrBucketException, + FlussError::UnknownServerError, + ]; + + for error in ignore_errors { + let result = collect_result_for_error(error).await?; + assert!( + matches!(result, Ok(records) if records.is_empty()), + "unexpected result for {error:?}" + ); + } + Ok(()) + } + + #[tokio::test] + async fn collect_fetches_returns_error_for_corrupt_or_unexpected() -> Result<()> { + let error_cases = [ + FlussError::CorruptMessage, + FlussError::InvalidTableException, + ]; + + for error in error_cases { + let result = collect_result_for_error(error).await?; + assert!( + matches!(result, Err(Error::UnexpectedError { .. 
})), + "unexpected result for {error:?}" + ); + } + Ok(()) + } + + #[tokio::test] + async fn send_fetches_invalidates_missing_server() -> Result<()> { + let table_path = default_table_path(); + let table_info = build_table_info(table_path.clone(), DEFAULT_TABLE_ID, DEFAULT_BUCKETS); + let leader = ServerNode::new(1, "127.0.0.1".to_string(), 9092, ServerType::TabletServer); + let cluster = build_cluster_with_leader(&table_info, Some(leader), false); + let env = ScannerTestEnv::with_table_info_and_cluster(table_info, cluster); + let bucket = env.assign_bucket(0, 0); + let fetcher = env.fetcher(None)?; + + fetcher.send_fetches().await?; + wait_for_leader_removal(&env.metadata, &bucket).await?; + Ok(()) + } + + #[tokio::test] + async fn send_fetches_invalidates_on_request_error() -> Result<()> { + let table_path = default_table_path(); + let table_info = build_table_info(table_path.clone(), DEFAULT_TABLE_ID, DEFAULT_BUCKETS); + let leader = ServerNode::new(1, "127.0.0.1".to_string(), 9092, ServerType::TabletServer); + let cluster = build_cluster_with_leader(&table_info, Some(leader.clone()), true); + let env = ScannerTestEnv::with_table_info_and_cluster(table_info, cluster); + let bucket = env.assign_bucket(0, 0); + let rpc_client = env.rpc_client.clone(); + + let (client, server) = tokio::io::duplex(1024); + drop(server); + let transport = crate::rpc::Transport::Test { inner: client }; + let connection = Arc::new(crate::rpc::ServerConnectionInner::new( + BufStream::new(transport), + usize::MAX, + Arc::from(""), + )); + rpc_client.insert_connection_for_test(&leader, connection); + + let fetcher = env.fetcher(None)?; + fetcher.send_fetches().await?; + wait_for_leader_removal(&env.metadata, &bucket).await?; + Ok(()) + } + + #[tokio::test] + async fn send_fetches_invalidates_on_connection_error() -> Result<()> { + let table_path = default_table_path(); + let table_info = build_table_info(table_path.clone(), DEFAULT_TABLE_ID, DEFAULT_BUCKETS); + let leader = ServerNode::new(1, "127.0.0.1".to_string(), 1, ServerType::TabletServer); + let cluster = build_cluster_with_leader(&table_info, Some(leader.clone()), true); + let env = ScannerTestEnv::with_table_info_and_cluster(table_info, cluster); + let bucket = env.assign_bucket(0, 0); + let fetcher = env.fetcher(None)?; + + fetcher.send_fetches().await?; + wait_for_leader_removal(&env.metadata, &bucket).await?; + Ok(()) + } + + #[tokio::test] + async fn handle_fetch_response_records_are_collected() -> Result<()> { + let env = ScannerTestEnv::new(); + env.assign_bucket(0, 0); + let fetcher = env.fetcher(None)?; + + let records = env.build_records()?; + let response = FetchLogResponse { + tables_resp: vec![PbFetchLogRespForTable { + table_id: env.table_info.table_id, + buckets_resp: vec![PbFetchLogRespForBucket { + partition_id: None, + bucket_id: 0, + error_code: None, + error_message: None, + high_watermark: Some(5), + log_start_offset: Some(0), + remote_log_fetch_info: None, + records: Some(records), + }], + }], + }; + + let response_context = FetchResponseContext { + metadata: env.metadata.clone(), + log_fetch_buffer: fetcher.log_fetch_buffer.clone(), + log_scanner_status: fetcher.log_scanner_status.clone(), + read_context: fetcher.read_context.clone(), + remote_read_context: fetcher.remote_read_context.clone(), + remote_log_downloader: fetcher.remote_log_downloader.clone(), + credentials_cache: fetcher.credentials_cache.clone(), + }; + + LogFetcher::handle_fetch_response(response, response_context).await; + + let fetched = fetcher.collect_fetches()?; 
+ assert_eq!(fetched.get(&TableBucket::new(1, 0)).unwrap().len(), 1); + assert_eq!( + env.status.get_bucket_offset(&TableBucket::new(1, 0)), + Some(1) + ); + Ok(()) + } + + #[tokio::test] + async fn send_fetches_enqueues_completed_fetch() -> Result<()> { + let env = ScannerTestEnv::new(); + env.assign_bucket(0, 0); + let rpc_client = env.rpc_client.clone(); + + let records = env.build_records()?; + let response = FetchLogResponse { + tables_resp: vec![PbFetchLogRespForTable { + table_id: env.table_info.table_id, + buckets_resp: vec![PbFetchLogRespForBucket { + partition_id: None, + bucket_id: 0, + error_code: None, + error_message: None, + high_watermark: Some(1), + log_start_offset: Some(0), + remote_log_fetch_info: None, + records: Some(records.clone()), + }], + }], + }; + + let (connection, handle) = + build_mock_connection(move |_api_key: crate::rpc::ApiKey, _, _| { + response.encode_to_vec() + }) + .await; + let server_node = env + .metadata + .get_cluster() + .get_tablet_server(1) + .expect("server") + .clone(); + rpc_client.insert_connection_for_test(&server_node, connection); + + let fetcher = env.fetcher(None)?; + + fetcher.send_fetches().await?; + let has_data = fetcher + .log_fetch_buffer + .await_not_empty(Duration::from_millis(200)) + .await?; + assert!(has_data); + + let fetched = fetcher.collect_fetches()?; + assert_eq!(fetched.get(&TableBucket::new(1, 0)).unwrap().len(), 1); + handle.abort(); + Ok(()) + } + + #[test] + fn collect_batches_returns_batches_and_updates_offset() -> Result<()> { + let env = ScannerTestEnv::new(); + let fetcher = env.fetcher(None)?; + + let bucket = env.assign_bucket(0, 0); + + let data = env.build_records()?; + let log_records = LogRecordsBatches::new(data.clone()); + let mut completed = DefaultCompletedFetch::new( + bucket.clone(), + log_records, + data.len(), + fetcher.read_context.clone(), + 0, + 0, + ); + completed.set_initialized(); + fetcher + .log_fetch_buffer + .set_next_in_line_fetch(Some(Box::new(completed))); + + let batches = fetcher.collect_batches()?; + assert_eq!(batches.len(), 1); + assert_eq!(env.status.get_bucket_offset(&bucket), Some(1)); + Ok(()) + } + + #[test] + fn collect_batches_returns_partial_on_error() -> Result<()> { + let env = ScannerTestEnv::new(); + let fetcher = env.fetcher(None)?; + + let bucket = env.assign_bucket(0, 0); + + let completed = TestCompletedFetch::batch_ok(bucket.clone()); + fetcher + .log_fetch_buffer + .set_next_in_line_fetch(Some(Box::new(completed))); + + let error_fetch = TestCompletedFetch::batch_err(bucket.clone()); + fetcher.log_fetch_buffer.add(Box::new(error_fetch)); + + let batches = fetcher.collect_batches()?; + assert_eq!(batches.len(), 1); + Ok(()) + } + + #[test] + fn fetch_batches_from_fetch_drains_unassigned_bucket() -> Result<()> { + let env = ScannerTestEnv::new(); + let fetcher = env.fetcher(None)?; + + let bucket = TableBucket::new(env.table_info.table_id, 0); + let data = env.build_records()?; + let log_records = LogRecordsBatches::new(data.clone()); + let mut completed: Box = Box::new(DefaultCompletedFetch::new( + bucket, + log_records, + data.len(), + fetcher.read_context.clone(), + 0, + 0, + )); + + let batches = fetcher.fetch_batches_from_fetch(&mut completed, 10)?; + assert!(batches.is_empty()); + assert!(completed.is_consumed()); + Ok(()) + } + + #[test] + fn fetch_batches_from_fetch_returns_error() -> Result<()> { + let env = ScannerTestEnv::new(); + let fetcher = env.fetcher(None)?; + + let bucket = env.assign_bucket(0, 0); + let mut completed: Box = 
Box::new(DefaultCompletedFetch::from_error( + bucket, + Error::UnexpectedError { + message: "fetch error".to_string(), + source: None, + }, + 0, + fetcher.read_context.clone(), + )); + + let result = fetcher.fetch_batches_from_fetch(&mut completed, 10); + assert!(matches!(result, Err(Error::UnexpectedError { .. }))); + Ok(()) + } + + #[test] + fn fetch_batches_from_fetch_ignores_out_of_order_offset() -> Result<()> { + let env = ScannerTestEnv::new(); + let fetcher = env.fetcher(None)?; + + let bucket = env.assign_bucket(0, 5); + let data = env.build_records()?; + let log_records = LogRecordsBatches::new(data.clone()); + let mut completed: Box<dyn CompletedFetch> = Box::new(DefaultCompletedFetch::new( + bucket, + log_records, + data.len(), + fetcher.read_context.clone(), + 0, + 0, + )); + + let batches = fetcher.fetch_batches_from_fetch(&mut completed, 10)?; + assert!(batches.is_empty()); + assert!(completed.is_consumed()); + Ok(()) + } + + #[test] + fn collect_batches_skips_error_when_empty_and_size_zero() -> Result<()> { + let env = ScannerTestEnv::new(); + let fetcher = env.fetcher(None)?; + + let error_fetch = DefaultCompletedFetch::from_error( + TableBucket::new(1, 0), + Error::UnexpectedError { + message: "fetch error".to_string(), + source: None, + }, + 0, + fetcher.read_context.clone(), + ); + fetcher.log_fetch_buffer.add(Box::new(error_fetch)); + + let batches = fetcher.collect_batches()?; + assert!(batches.is_empty()); + Ok(()) + } + + #[test] + fn fetch_records_from_fetch_ignores_out_of_order_offset() -> Result<()> { + let env = ScannerTestEnv::new(); + let fetcher = env.fetcher(None)?; + + let bucket = env.assign_bucket(0, 5); + let data = env.build_records()?; + let log_records = LogRecordsBatches::new(data.clone()); + let mut completed: Box<dyn CompletedFetch> = Box::new(DefaultCompletedFetch::new( + bucket, + log_records, + data.len(), + fetcher.read_context.clone(), + 0, + 0, + )); + + let records = fetcher.fetch_records_from_fetch(&mut completed, 10)?; + assert!(records.is_empty()); + assert!(completed.is_consumed()); + Ok(()) + } + + #[test] + fn fetch_records_from_fetch_returns_error() -> Result<()> { + let env = ScannerTestEnv::new(); + let fetcher = env.fetcher(None)?; + + let bucket = env.assign_bucket(0, 0); + let mut completed: Box<dyn CompletedFetch> = Box::new(DefaultCompletedFetch::from_error( + bucket, + Error::UnexpectedError { + message: "fetch error".to_string(), + source: None, + }, + 0, + fetcher.read_context.clone(), + )); + + let result = fetcher.fetch_records_from_fetch(&mut completed, 10); + assert!(matches!(result, Err(Error::UnexpectedError { .. }))); + Ok(()) + } + + #[test] + fn collect_fetches_returns_partial_on_error() -> Result<()> { + let env = ScannerTestEnv::new(); + let fetcher = env.fetcher(None)?; + + let bucket = env.assign_bucket(0, 0); + let completed = TestCompletedFetch::record_ok(bucket.clone()); + fetcher + .log_fetch_buffer + .set_next_in_line_fetch(Some(Box::new(completed))); + + let error_fetch = TestCompletedFetch::record_err(bucket.clone()); + fetcher.log_fetch_buffer.add(Box::new(error_fetch)); + + let result = fetcher.collect_fetches()?; + let records = result.get(&bucket).expect("records"); + assert_eq!(records.len(), 1); + Ok(()) + } + + #[tokio::test] + async fn subscribe_batch_rejects_empty() -> Result<()> { + let env = ScannerTestEnv::new(); + let inner = env.inner(None)?; + + let result = inner.subscribe_batch(&HashMap::new()).await; + assert!(matches!(result, Err(Error::UnexpectedError { ..
}))); + Ok(()) + } + #[tokio::test] async fn handle_fetch_response_invalidates_table_meta() -> Result<()> { let table_path = TablePath::new("db".to_string(), "tbl".to_string()); @@ -1686,4 +2301,130 @@ mod tests { assert!(metadata.leader_for(&bucket).is_none()); Ok(()) } + + #[tokio::test] + async fn handle_fetch_response_out_of_range_sets_error() -> Result<()> { + let env = ScannerTestEnv::new(); + env.assign_bucket(0, 5); + let fetcher = env.fetcher(None)?; + + let response = crate::proto::FetchLogResponse { + tables_resp: vec![crate::proto::PbFetchLogRespForTable { + table_id: env.table_info.table_id, + buckets_resp: vec![crate::proto::PbFetchLogRespForBucket { + partition_id: None, + bucket_id: 0, + error_code: Some(FlussError::LogOffsetOutOfRangeException.code()), + error_message: Some("out of range".to_string()), + high_watermark: None, + log_start_offset: None, + remote_log_fetch_info: None, + records: None, + }], + }], + }; + + let response_context = FetchResponseContext { + metadata: env.metadata.clone(), + log_fetch_buffer: fetcher.log_fetch_buffer.clone(), + log_scanner_status: fetcher.log_scanner_status.clone(), + read_context: fetcher.read_context.clone(), + remote_read_context: fetcher.remote_read_context.clone(), + remote_log_downloader: fetcher.remote_log_downloader.clone(), + credentials_cache: fetcher.credentials_cache.clone(), + }; + + LogFetcher::handle_fetch_response(response, response_context).await; + + let completed = fetcher.log_fetch_buffer.poll().expect("completed fetch"); + let result = fetcher.initialize_fetch(completed); + assert!(matches!(result, Err(Error::UnexpectedError { .. }))); + Ok(()) + } + + #[test] + fn initialize_fetch_returns_authorization_error() -> Result<()> { + let env = ScannerTestEnv::new(); + let bucket = env.assign_bucket(0, 0); + let fetcher = env.fetcher(None)?; + + let error_context = LogFetcher::describe_fetch_error( + FlussError::AuthorizationException, + &bucket, + 0, + "denied", + ); + let completed = DefaultCompletedFetch::from_api_error( + bucket, + ApiError { + code: FlussError::AuthorizationException.code(), + message: "denied".to_string(), + }, + error_context, + 0, + fetcher.read_context.clone(), + ); + + let result = fetcher.initialize_fetch(Box::new(completed)); + assert!(matches!(result, Err(Error::FlussAPIError { .. 
}))); + Ok(()) + } + + #[test] + fn initialize_fetch_discards_stale_offset() -> Result<()> { + let env = ScannerTestEnv::new(); + env.assign_bucket(0, 5); + let fetcher = env.fetcher(None)?; + + let data = env.build_records()?; + let log_records = LogRecordsBatches::new(data.clone()); + let read_context = ReadContext::new(to_arrow_schema(env.table_info.get_row_type())?, false); + let completed: Box = Box::new(DefaultCompletedFetch::new( + TableBucket::new(1, 0), + log_records, + data.len(), + read_context, + 0, + 0, + )); + + let result = fetcher.initialize_fetch(completed)?; + assert!(result.is_none()); + Ok(()) + } + + #[tokio::test] + async fn poll_without_subscription_returns_empty() -> Result<()> { + let env = ScannerTestEnv::new(); + let inner = env.inner(None)?; + let scanner = LogScanner { + inner: Arc::new(inner), + }; + + let result = scanner.poll(Duration::from_millis(1)).await?; + assert!(result.is_empty()); + Ok(()) + } + + #[tokio::test] + async fn poll_records_propagates_wakeup_error() -> Result<()> { + let env = ScannerTestEnv::new(); + let inner = env.inner(None)?; + + inner.log_fetcher.log_fetch_buffer.wakeup(); + let result = inner.poll_records(Duration::from_millis(10)).await; + assert!(matches!(result, Err(Error::WakeupError { .. }))); + Ok(()) + } + + #[tokio::test] + async fn poll_batches_propagates_wakeup_error() -> Result<()> { + let env = ScannerTestEnv::new(); + let inner = env.inner(None)?; + + inner.log_fetcher.log_fetch_buffer.wakeup(); + let result = inner.poll_batches(Duration::from_millis(10)).await; + assert!(matches!(result, Err(Error::WakeupError { .. }))); + Ok(()) + } } diff --git a/crates/fluss/src/cluster/cluster.rs b/crates/fluss/src/cluster/cluster.rs index 2484026a..5a1177f9 100644 --- a/crates/fluss/src/cluster/cluster.rs +++ b/crates/fluss/src/cluster/cluster.rs @@ -316,3 +316,150 @@ impl Cluster { self.table_info_by_path.get(table_path) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::metadata::{ + DataField, DataTypes, JsonSerde, Schema, TableBucket, TableDescriptor, TablePath, + }; + use crate::proto::{ + MetadataResponse, PbBucketMetadata, PbServerNode, PbTableMetadata, PbTablePath, + }; + use std::collections::HashMap; + + fn build_table_descriptor() -> TableDescriptor { + let row_type = DataTypes::row(vec![DataField::new( + "id".to_string(), + DataTypes::int(), + None, + )]); + let mut schema_builder = Schema::builder().with_row_type(&row_type); + let schema = schema_builder.build().expect("schema build"); + TableDescriptor::builder() + .schema(schema) + .distributed_by(Some(2), vec![]) + .build() + .expect("descriptor") + } + + fn build_metadata_response() -> MetadataResponse { + let table_path = TablePath::new("db".to_string(), "tbl".to_string()); + let table_descriptor = build_table_descriptor(); + let table_json = + serde_json::to_vec(&table_descriptor.serialize_json().expect("table json")).unwrap(); + + MetadataResponse { + coordinator_server: Some(PbServerNode { + node_id: 10, + host: "127.0.0.1".to_string(), + port: 9999, + listeners: None, + }), + tablet_servers: vec![ + PbServerNode { + node_id: 1, + host: "127.0.0.1".to_string(), + port: 9092, + listeners: None, + }, + PbServerNode { + node_id: 2, + host: "127.0.0.1".to_string(), + port: 9093, + listeners: None, + }, + ], + table_metadata: vec![PbTableMetadata { + table_path: PbTablePath { + database_name: table_path.database().to_string(), + table_name: table_path.table().to_string(), + }, + table_id: 5, + schema_id: 1, + table_json, + bucket_metadata: vec![ + 
PbBucketMetadata { + bucket_id: 0, + leader_id: Some(1), + replica_id: vec![1], + }, + PbBucketMetadata { + bucket_id: 1, + leader_id: Some(2), + replica_id: vec![2], + }, + ], + created_time: 10, + modified_time: 20, + }], + partition_metadata: vec![], + } + } + + #[test] + fn cluster_from_metadata_response_populates_servers_and_locations() -> Result<()> { + let response = build_metadata_response(); + let cluster = Cluster::from_metadata_response(response, None)?; + + assert!(cluster.get_coordinator_server().is_some()); + assert!(cluster.get_tablet_server(1).is_some()); + assert!(cluster.get_tablet_server(2).is_some()); + assert_eq!(cluster.table_id_by_path.len(), 1); + + let table_path = TablePath::new("db".to_string(), "tbl".to_string()); + assert!(cluster.opt_get_table(&table_path).is_some()); + assert_eq!(cluster.get_bucket_count(&table_path), 2); + + let bucket0 = TableBucket::new(5, 0); + let bucket1 = TableBucket::new(5, 1); + assert_eq!(cluster.leader_for(&bucket0).unwrap().id(), 1); + assert_eq!(cluster.leader_for(&bucket1).unwrap().id(), 2); + Ok(()) + } + + #[test] + fn invalidate_server_removes_locations_for_tables() { + let response = build_metadata_response(); + let cluster = Cluster::from_metadata_response(response, None).expect("cluster"); + + let table_path = TablePath::new("db".to_string(), "tbl".to_string()); + let updated = cluster.invalidate_server(&1, vec![5]); + assert!(updated.get_tablet_server(1).is_none()); + assert!(updated.opt_get_table(&table_path).is_some()); + assert!(updated.leader_for(&TableBucket::new(5, 0)).is_none()); + } + + #[test] + fn update_replaces_state() -> Result<()> { + let response = build_metadata_response(); + let mut cluster = Cluster::default(); + let next = Cluster::from_metadata_response(response, None)?; + cluster.update(next); + assert!(cluster.get_tablet_server(1).is_some()); + Ok(()) + } + + #[test] + fn leader_for_respects_available_locations() -> Result<()> { + let table_path = TablePath::new("db".to_string(), "tbl".to_string()); + let mut servers = HashMap::new(); + servers.insert( + 1, + ServerNode::new(1, "127.0.0.1".to_string(), 9092, ServerType::TabletServer), + ); + let bucket = TableBucket::new(1, 0); + let location = BucketLocation::new(bucket.clone(), None, table_path.clone()); + + let cluster = Cluster::new( + None, + servers, + HashMap::from([(table_path.clone(), vec![location])]), + HashMap::new(), + HashMap::from([(table_path, 1)]), + HashMap::new(), + ); + assert!(cluster.leader_for(&bucket).is_none()); + Ok(()) + } +} diff --git a/crates/fluss/src/metadata/table.rs b/crates/fluss/src/metadata/table.rs index f4cf972d..c3ac0ad0 100644 --- a/crates/fluss/src/metadata/table.rs +++ b/crates/fluss/src/metadata/table.rs @@ -1067,3 +1067,208 @@ impl LakeSnapshot { &self.table_buckets_offset } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::metadata::DataTypes; + + #[test] + fn schema_builder_rejects_duplicate_columns() { + let result = Schema::builder() + .column("id", DataTypes::int()) + .column("id", DataTypes::string()) + .build(); + assert!(matches!(result, Err(Error::InvalidTableError { .. 
}))); + } + + #[test] + fn primary_key_columns_become_non_nullable() { + let schema = Schema::builder() + .column("id", DataTypes::int()) + .column("name", DataTypes::string()) + .primary_key(vec!["id".to_string()]) + .build() + .expect("schema"); + + let id_col = schema.columns().iter().find(|c| c.name() == "id").unwrap(); + assert!(!id_col.data_type().is_nullable()); + } + + #[test] + fn table_descriptor_defaults_bucket_keys_for_primary_key() { + let schema = Schema::builder() + .column("id", DataTypes::int()) + .column("p", DataTypes::string()) + .primary_key(vec!["id".to_string()]) + .build() + .expect("schema"); + + let descriptor = TableDescriptor::builder() + .schema(schema) + .partitioned_by(vec!["p".to_string()]) + .build() + .expect("descriptor"); + + assert_eq!(descriptor.bucket_keys(), vec!["id"]); + assert!(descriptor.has_primary_key()); + } + + #[test] + fn table_descriptor_rejects_bucket_keys_with_partition() { + let schema = Schema::builder() + .column("id", DataTypes::int()) + .column("p", DataTypes::string()) + .primary_key(vec!["id".to_string()]) + .build() + .expect("schema"); + + let result = TableDescriptor::builder() + .schema(schema) + .partitioned_by(vec!["p".to_string()]) + .distributed_by(Some(1), vec!["p".to_string()]) + .build(); + assert!(matches!(result, Err(Error::InvalidTableError { .. }))); + } + + #[test] + fn replication_factor_errors_on_missing_or_invalid() { + let schema = Schema::builder() + .column("id", DataTypes::int()) + .build() + .expect("schema"); + let descriptor = TableDescriptor::builder() + .schema(schema) + .distributed_by(Some(1), vec![]) + .build() + .expect("descriptor"); + + assert!(descriptor.replication_factor().is_err()); + + let mut props = HashMap::new(); + props.insert("table.replication.factor".to_string(), "oops".to_string()); + let descriptor = descriptor.with_properties(props); + assert!(descriptor.replication_factor().is_err()); + } + + #[test] + fn table_info_round_trip_descriptor() { + let schema = Schema::builder() + .column("id", DataTypes::int()) + .primary_key(vec!["id".to_string()]) + .build() + .expect("schema"); + let descriptor = TableDescriptor::builder() + .schema(schema) + .distributed_by(Some(3), vec![]) + .comment("tbl") + .build() + .expect("descriptor"); + + let info = TableInfo::of( + TablePath::new("db".to_string(), "tbl".to_string()), + 10, + 1, + descriptor.clone(), + 0, + 0, + ); + let round_trip = info.to_table_descriptor().expect("descriptor"); + assert_eq!( + round_trip.bucket_keys(), + info.bucket_keys + .iter() + .map(|s| s.as_str()) + .collect::>() + ); + assert_eq!(round_trip.comment(), Some("tbl")); + } + + #[test] + fn formats_table_path_and_table_bucket() { + let table_path = TablePath::new("db".to_string(), "tbl".to_string()); + assert_eq!(table_path.database(), "db"); + assert_eq!(table_path.table(), "tbl"); + assert_eq!(format!("{table_path}"), "db.tbl"); + + let bucket = TableBucket::new(10, 2); + assert_eq!(bucket.table_id(), 10); + assert_eq!(bucket.bucket_id(), 2); + assert_eq!(format!("{bucket}"), "TableBucket(table_id=10, bucket=2)"); + } + + #[test] + fn default_bucket_key_detection() { + let schema = Schema::builder() + .column("id", DataTypes::int()) + .primary_key(vec!["id".to_string()]) + .build() + .expect("schema"); + let descriptor = TableDescriptor::builder() + .schema(schema) + .distributed_by(Some(2), vec![]) + .build() + .expect("descriptor"); + + assert!(descriptor.is_default_bucket_key().expect("default")); + } + + #[test] + fn log_format_and_kv_format_parsing() { + 
assert_eq!(LogFormat::parse("arrow").unwrap(), LogFormat::ARROW); + assert!(LogFormat::parse("unknown").is_err()); + + assert_eq!(KvFormat::parse("indexed").unwrap(), KvFormat::INDEXED); + assert!(KvFormat::parse("bad").is_err()); + } + + #[test] + fn table_info_flags_and_replication_factor() { + let schema = Schema::builder() + .column("id", DataTypes::int()) + .column("p", DataTypes::string()) + .primary_key(vec!["id".to_string()]) + .build() + .expect("schema"); + let descriptor = TableDescriptor::builder() + .schema(schema) + .partitioned_by(vec!["p".to_string()]) + .distributed_by(Some(2), vec![]) + .build() + .expect("descriptor") + .with_replication_factor(3); + + assert!(descriptor.is_partitioned()); + assert!(descriptor.has_primary_key()); + assert_eq!(descriptor.replication_factor().unwrap(), 3); + + let info = TableInfo::of( + TablePath::new("db".to_string(), "tbl".to_string()), + 10, + 1, + descriptor, + 0, + 0, + ); + + assert!(info.has_primary_key()); + assert!(info.has_bucket_key()); + assert!(info.is_partitioned()); + assert!(info.is_default_bucket_key()); + assert_eq!(info.get_physical_primary_keys(), &["id".to_string()]); + } + + #[test] + fn schema_primary_key_indexes_and_column_names() { + let schema = Schema::builder() + .column("id", DataTypes::int()) + .column("name", DataTypes::string()) + .primary_key(vec!["id".to_string()]) + .build() + .expect("schema"); + + assert_eq!(schema.primary_key_indexes(), vec![0]); + assert_eq!(schema.primary_key_column_names(), vec!["id"]); + assert_eq!(schema.column_names(), vec!["id", "name"]); + } +} diff --git a/crates/fluss/src/record/arrow.rs b/crates/fluss/src/record/arrow.rs index 39114d32..cd24505b 100644 --- a/crates/fluss/src/record/arrow.rs +++ b/crates/fluss/src/record/arrow.rs @@ -522,11 +522,20 @@ impl LogRecordBatch { } pub fn ensure_valid(&self) -> Result<()> { - // TODO enable validation once checksum handling is corrected. - Ok(()) + if self.is_valid() { + Ok(()) + } else { + Err(Error::UnexpectedError { + message: "Corrupt log record batch checksum.".to_string(), + source: None, + }) + } } pub fn is_valid(&self) -> bool { + if self.data.len() < RECORD_BATCH_HEADER_SIZE { + return false; + } self.size_in_bytes() >= RECORD_BATCH_HEADER_SIZE && self.checksum() == self.compute_checksum() } @@ -1157,7 +1166,11 @@ pub struct MyVec(pub StreamReader); #[cfg(test)] mod tests { use super::*; - use crate::metadata::{DataField, DataTypes}; + use crate::metadata::{DataField, DataTypes, RowType}; + use crate::test_utils::build_log_record_bytes; + use arrow::array::Int32Array; + use bytes::Bytes; + use std::sync::Arc; #[test] fn test_to_array_type() { @@ -1357,6 +1370,100 @@ mod tests { assert!(matches!(result, Err(IllegalArgument { .. }))); } + #[test] + fn project_schema_rejects_invalid_indices() { + let row_type = RowType::new(vec![ + DataField::new("id".to_string(), DataTypes::int(), None), + DataField::new("name".to_string(), DataTypes::string(), None), + ]); + let schema = to_arrow_schema(&row_type).expect("arrow schema"); + let result = ReadContext::project_schema(schema, &[2]); + + assert!(matches!(result, Err(Error::IllegalArgument { .. 
}))); + } + + #[test] + fn record_batch_for_remote_log_rejects_invalid_ipc() -> Result<()> { + let row_type = RowType::new(vec![DataField::new( + "id".to_string(), + DataTypes::int(), + None, + )]); + let schema = to_arrow_schema(&row_type)?; + let read_context = ReadContext::new(schema, true); + + let result = read_context.record_batch_for_remote_log(&[]); + assert!(matches!(result, Err(Error::ArrowError { .. }))); + Ok(()) + } + + #[test] + fn log_records_batches_iterates_over_concatenated_batches() -> Result<()> { + let batch_bytes = build_log_record_bytes(vec![1, 2, 3])?; + let mut data = batch_bytes.clone(); + data.extend_from_slice(&batch_bytes); + + let mut batches = LogRecordsBatches::new(data); + let first = batches.next().expect("first batch"); + assert!(first.is_valid()); + assert!(batches.next().is_some()); + assert!(batches.next().is_none()); + Ok(()) + } + + #[test] + fn log_record_batch_detects_bad_checksum() -> Result<()> { + let mut batch_bytes = build_log_record_bytes(vec![1])?; + batch_bytes[CRC_OFFSET] ^= 0xFF; + let batch = LogRecordBatch::new(Bytes::from(batch_bytes)); + assert!(batch.ensure_valid().is_err()); + Ok(()) + } + + #[test] + fn log_record_batch_projection_reorders_columns() -> Result<()> { + let batch_bytes = build_log_record_bytes(vec![10, 20, 30])?; + let batch = LogRecordBatch::new(Bytes::from(batch_bytes)); + + let row_type = RowType::new(vec![ + DataField::new("c0".to_string(), DataTypes::int(), None), + DataField::new("c1".to_string(), DataTypes::int(), None), + DataField::new("c2".to_string(), DataTypes::int(), None), + ]); + let schema = to_arrow_schema(&row_type)?; + let read_context = ReadContext::with_projection_pushdown(schema, vec![2, 0], false)?; + + let mut records = batch.records(&read_context)?; + let record = records.next().expect("record"); + let row_id = record.row().get_row_id(); + let record_batch = record.row().get_record_batch(); + + assert_eq!(record_batch.num_columns(), 2); + assert_eq!(record_batch.schema().field(0).name(), "c2"); + assert_eq!(record_batch.schema().field(1).name(), "c0"); + let first = record_batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(first.len(), row_id + 1); + Ok(()) + } + + #[test] + fn log_record_batch_header_fields_are_readable() -> Result<()> { + let batch_bytes = build_log_record_bytes(vec![1, 2])?; + let batch = LogRecordBatch::new(Bytes::from(batch_bytes)); + assert_eq!(batch.magic(), CURRENT_LOG_MAGIC_VALUE); + assert_eq!(batch.schema_id(), 1); + assert_eq!(batch.record_count(), 1); + assert_eq!(batch.base_log_offset(), 0); + assert_eq!(batch.last_log_offset(), 0); + assert_eq!(batch.next_log_offset(), 1); + batch.ensure_valid()?; + Ok(()) + } + #[test] fn checksum_and_schema_id_read_minimum_header() { // Header-only batches with record_count == 0 are valid; this covers the minimal bytes diff --git a/crates/fluss/src/rpc/mod.rs b/crates/fluss/src/rpc/mod.rs index 86e13b1c..a9e165fb 100644 --- a/crates/fluss/src/rpc/mod.rs +++ b/crates/fluss/src/rpc/mod.rs @@ -27,5 +27,9 @@ mod server_connection; pub use server_connection::*; mod convert; mod transport; +#[cfg(test)] +pub(crate) use api_key::ApiKey; +#[cfg(test)] +pub(crate) use transport::Transport; pub use convert::*; diff --git a/crates/fluss/src/rpc/server_connection.rs b/crates/fluss/src/rpc/server_connection.rs index 441b175a..422f9668 100644 --- a/crates/fluss/src/rpc/server_connection.rs +++ b/crates/fluss/src/rpc/server_connection.rs @@ -25,6 +25,8 @@ use crate::rpc::message::{ ReadVersionedType, RequestBody, 
RequestHeader, ResponseHeader, WriteVersionedType, }; use crate::rpc::transport::Transport; +#[cfg(test)] +use bytes::Buf; use futures::future::BoxFuture; use parking_lot::{Mutex, RwLock}; use std::collections::HashMap; @@ -105,6 +107,54 @@ impl RpcClient { ); Ok(ServerConnection::new(messenger)) } + + #[cfg(test)] + pub(crate) fn insert_connection_for_test( + &self, + server_node: &ServerNode, + connection: ServerConnection, + ) { + self.connections + .write() + .insert(server_node.uid().clone(), connection); + } +} + +#[cfg(test)] +pub(crate) async fn spawn_mock_server( + mut server: tokio::io::DuplexStream, + mut handler: F, +) -> JoinHandle<()> +where + F: FnMut(crate::rpc::api_key::ApiKey, i32, Vec) -> Vec + Send + 'static, +{ + tokio::spawn(async move { + loop { + let msg = match server.read_message(usize::MAX).await { + Ok(msg) => msg, + Err(_) => break, + }; + let mut cursor = Cursor::new(msg); + let api_key = crate::rpc::api_key::ApiKey::from(cursor.get_i16()); + let _api_version = cursor.get_i16(); + let request_id = cursor.get_i32(); + let remaining = { + let pos = cursor.position() as usize; + cursor.into_inner()[pos..].to_vec() + }; + + let body = handler(api_key, request_id, remaining); + + let mut response = Vec::new(); + response.push(0); + response.extend_from_slice(&request_id.to_be_bytes()); + response.extend_from_slice(&body); + + if server.write_message(&response).await.is_err() { + break; + } + } + }) } #[derive(Debug)] diff --git a/crates/fluss/src/rpc/transport.rs b/crates/fluss/src/rpc/transport.rs index a6f721f6..5a2ae03f 100644 --- a/crates/fluss/src/rpc/transport.rs +++ b/crates/fluss/src/rpc/transport.rs @@ -25,7 +25,13 @@ use tokio::net::TcpStream; #[derive(Debug)] pub enum Transport { - Plain { inner: TcpStream }, + Plain { + inner: TcpStream, + }, + #[cfg(test)] + Test { + inner: tokio::io::DuplexStream, + }, } impl AsyncRead for Transport { @@ -36,6 +42,8 @@ impl AsyncRead for Transport { ) -> Poll> { match self.deref_mut() { Self::Plain { inner } => Pin::new(inner).poll_read(cx, buf), + #[cfg(test)] + Self::Test { inner } => Pin::new(inner).poll_read(cx, buf), } } } @@ -48,18 +56,24 @@ impl AsyncWrite for Transport { ) -> Poll> { match self.deref_mut() { Self::Plain { inner } => Pin::new(inner).poll_write(cx, buf), + #[cfg(test)] + Self::Test { inner } => Pin::new(inner).poll_write(cx, buf), } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.deref_mut() { Self::Plain { inner } => Pin::new(inner).poll_flush(cx), + #[cfg(test)] + Self::Test { inner } => Pin::new(inner).poll_flush(cx), } } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.deref_mut() { Self::Plain { inner } => Pin::new(inner).poll_shutdown(cx), + #[cfg(test)] + Self::Test { inner } => Pin::new(inner).poll_shutdown(cx), } } } diff --git a/crates/fluss/src/test_utils.rs b/crates/fluss/src/test_utils.rs index d1cd3ec7..aa4637d1 100644 --- a/crates/fluss/src/test_utils.rs +++ b/crates/fluss/src/test_utils.rs @@ -15,12 +15,40 @@ // specific language governing permissions and limitations // under the License. 
+use crate::client::{CompletedFetch, FetchErrorContext, WriteRecord}; use crate::cluster::{BucketLocation, Cluster, ServerNode, ServerType}; +use crate::compression::{ + ArrowCompressionInfo, ArrowCompressionType, DEFAULT_NON_ZSTD_COMPRESSION_LEVEL, +}; +use crate::error::{ApiError, Error, Result}; use crate::metadata::{ - DataField, DataTypes, Schema, TableBucket, TableDescriptor, TableInfo, TablePath, + DataField, DataTypes, RowType, Schema, TableBucket, TableDescriptor, TableInfo, TablePath, }; +use crate::record::{MemoryLogRecordsArrowBuilder, ReadContext, ScanRecord, to_arrow_schema}; +use crate::row::{ColumnarRow, Datum, GenericRow}; +use crate::rpc::{ServerConnection, ServerConnectionInner, Transport, spawn_mock_server}; +use arrow::array::{Int32Array, RecordBatch}; +use arrow_schema::{DataType as ArrowDataType, Field, Schema as ArrowSchema}; use std::collections::HashMap; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use tokio::io::BufStream; +use tokio::task::JoinHandle; + +pub(crate) async fn build_mock_connection<F>(handler: F) -> (ServerConnection, JoinHandle<()>) +where + F: FnMut(crate::rpc::ApiKey, i32, Vec<u8>) -> Vec<u8> + Send + 'static, +{ + let (client, server) = tokio::io::duplex(1024); + let handle = spawn_mock_server(server, handler).await; + let transport = Transport::Test { inner: client }; + let connection = Arc::new(ServerConnectionInner::new( + BufStream::new(transport), + usize::MAX, + Arc::from(""), + )); + (connection, handle) +} pub(crate) fn build_table_info(table_path: TablePath, table_id: i64, buckets: i32) -> TableInfo { let row_type = DataTypes::row(vec![DataField::new( @@ -86,3 +114,286 @@ pub(crate) fn build_cluster_arc( ) -> Arc<Cluster> { Arc::new(build_cluster(table_path, table_id, buckets)) } + +pub(crate) fn build_cluster_with_coordinator( + table_path: &TablePath, + table_id: i64, + coordinator: ServerNode, + tablet: ServerNode, +) -> Cluster { + let table_bucket = TableBucket::new(table_id, 0); + let bucket_location = BucketLocation::new( + table_bucket.clone(), + Some(tablet.clone()), + table_path.clone(), + ); + + let mut servers = HashMap::new(); + servers.insert(tablet.id(), tablet); + + let mut locations_by_path = HashMap::new(); + locations_by_path.insert(table_path.clone(), vec![bucket_location.clone()]); + + let mut locations_by_bucket = HashMap::new(); + locations_by_bucket.insert(table_bucket, bucket_location); + + let mut table_id_by_path = HashMap::new(); + table_id_by_path.insert(table_path.clone(), table_id); + + let mut table_info_by_path = HashMap::new(); + table_info_by_path.insert( + table_path.clone(), + build_table_info(table_path.clone(), table_id, 1), + ); + + Cluster::new( + Some(coordinator), + servers, + locations_by_path, + locations_by_bucket, + table_id_by_path, + table_info_by_path, + ) +} + +pub(crate) fn build_cluster_with_coordinator_arc( + table_path: &TablePath, + table_id: i64, + coordinator: ServerNode, + tablet: ServerNode, +) -> Arc<Cluster> { + Arc::new(build_cluster_with_coordinator( + table_path, + table_id, + coordinator, + tablet, + )) +} + +pub(crate) fn build_read_context_for_int32() -> ReadContext { + let row_type = RowType::new(vec![DataField::new( + "id".to_string(), + DataTypes::int(), + None, + )]); + let schema = to_arrow_schema(&row_type).expect("arrow schema"); + ReadContext::new(schema, false) +} + +fn build_single_int_record_batch() -> RecordBatch { + let schema = Arc::new(ArrowSchema::new(vec![Field::new( + "id", + ArrowDataType::Int32, + false, + )])); + RecordBatch::try_new(schema, 
vec![Arc::new(Int32Array::from(vec![1]))]).expect("record batch") +} + +pub(crate) fn build_single_int_scan_record() -> ScanRecord { + let batch = build_single_int_record_batch(); + let row = ColumnarRow::new(Arc::new(batch)); + ScanRecord::new_default(row) +} + +pub(crate) fn build_log_record_bytes(values: Vec<i32>) -> Result<Vec<u8>> { + let fields = values + .iter() + .enumerate() + .map(|(idx, _)| DataField::new(format!("c{idx}"), DataTypes::int(), None)) + .collect::<Vec<_>>(); + let row_type = RowType::new(fields); + let table_path = Arc::new(TablePath::new("db".to_string(), "tbl".to_string())); + let mut builder = MemoryLogRecordsArrowBuilder::new( + 1, + &row_type, + false, + ArrowCompressionInfo { + compression_type: ArrowCompressionType::None, + compression_level: DEFAULT_NON_ZSTD_COMPRESSION_LEVEL, + }, + )?; + let record = WriteRecord::for_append( + table_path, + 1, + GenericRow { + values: values.into_iter().map(Datum::Int32).collect(), + }, + ); + builder.append(&record)?; + builder.build() +} + +pub(crate) struct TestCompletedFetch { + table_bucket: TableBucket, + records: Vec<ScanRecord>, + batches: Vec<RecordBatch>, + error_on_records: bool, + error_on_batches: bool, + consumed: AtomicBool, + initialized: AtomicBool, + next_fetch_offset: i64, + records_read: usize, +} + +impl TestCompletedFetch { + pub(crate) fn new(table_bucket: TableBucket) -> Self { + Self { + table_bucket, + records: Vec::new(), + batches: Vec::new(), + error_on_records: false, + error_on_batches: false, + consumed: AtomicBool::new(false), + initialized: AtomicBool::new(true), + next_fetch_offset: 0, + records_read: 0, + } + } + + pub(crate) fn record_ok(table_bucket: TableBucket) -> Self { + Self { + table_bucket, + records: vec![build_single_int_scan_record()], + batches: Vec::new(), + error_on_records: false, + error_on_batches: false, + consumed: AtomicBool::new(false), + initialized: AtomicBool::new(true), + next_fetch_offset: 0, + records_read: 0, + } + } + + pub(crate) fn record_err(table_bucket: TableBucket) -> Self { + Self { + table_bucket, + records: Vec::new(), + batches: Vec::new(), + error_on_records: true, + error_on_batches: false, + consumed: AtomicBool::new(false), + initialized: AtomicBool::new(true), + next_fetch_offset: 0, + records_read: 0, + } + } + + pub(crate) fn batch_ok(table_bucket: TableBucket) -> Self { + let batch = build_single_int_record_batch(); + Self { + table_bucket, + records: Vec::new(), + batches: vec![batch], + error_on_records: false, + error_on_batches: false, + consumed: AtomicBool::new(false), + initialized: AtomicBool::new(true), + next_fetch_offset: 0, + records_read: 0, + } + } + + pub(crate) fn batch_err(table_bucket: TableBucket) -> Self { + Self { + table_bucket, + records: Vec::new(), + batches: Vec::new(), + error_on_records: false, + error_on_batches: true, + consumed: AtomicBool::new(false), + initialized: AtomicBool::new(true), + next_fetch_offset: 0, + records_read: 0, + } + } +} + +impl CompletedFetch for TestCompletedFetch { + fn table_bucket(&self) -> &TableBucket { + &self.table_bucket + } + + fn api_error(&self) -> Option<&ApiError> { + None + } + + fn fetch_error_context(&self) -> Option<&FetchErrorContext> { + None + } + + fn take_error(&mut self) -> Option<Error> { + None + } + + fn fetch_records(&mut self, _max_records: usize) -> Result<Vec<ScanRecord>> { + if self.error_on_records { + self.consumed.store(true, Ordering::Release); + return Err(Error::UnexpectedError { + message: "fetch error".to_string(), + source: None, + }); + } + if self.consumed.load(Ordering::Acquire) { + return Ok(Vec::new()); + } 
+ let records = std::mem::take(&mut self.records); + if !records.is_empty() { + self.records_read += records.len(); + self.next_fetch_offset += records.len() as i64; + self.consumed.store(true, Ordering::Release); + } + Ok(records) + } + + fn fetch_batches(&mut self, _max_batches: usize) -> Result<Vec<RecordBatch>> { + if self.error_on_batches { + self.consumed.store(true, Ordering::Release); + return Err(Error::UnexpectedError { + message: "fetch error".to_string(), + source: None, + }); + } + if self.consumed.load(Ordering::Acquire) { + return Ok(Vec::new()); + } + let batches = std::mem::take(&mut self.batches); + if !batches.is_empty() { + self.records_read += batches.iter().map(|b| b.num_rows()).sum::<usize>(); + self.next_fetch_offset += 1; + self.consumed.store(true, Ordering::Release); + } + Ok(batches) + } + + fn is_consumed(&self) -> bool { + self.consumed.load(Ordering::Acquire) + } + + fn records_read(&self) -> usize { + self.records_read + } + + fn drain(&mut self) { + self.consumed.store(true, Ordering::Release); + } + + fn size_in_bytes(&self) -> usize { + 0 + } + + fn high_watermark(&self) -> i64 { + 0 + } + + fn is_initialized(&self) -> bool { + self.initialized.load(Ordering::Acquire) + } + + fn set_initialized(&mut self) { + self.initialized.store(true, Ordering::Release); + } + + fn next_fetch_offset(&self) -> i64 { + self.next_fetch_offset + } +}
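// A minimal usage sketch for the TestCompletedFetch stub defined above (illustrative only; it
// uses only items that already appear in this patch: TableBucket, TestCompletedFetch, and the
// CompletedFetch methods fetch_records/is_consumed):
//
//     let bucket = TableBucket::new(1, 0);
//     let mut fetch = TestCompletedFetch::record_ok(bucket);
//     let records = fetch.fetch_records(10)?;
//     assert_eq!(records.len(), 1); // the single stubbed ScanRecord is handed out once
//     assert!(fetch.is_consumed()); // subsequent calls return an empty Vec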