Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/ember-core/src/keyspace/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ mod hash;
mod list;
#[cfg(feature = "protobuf")]
mod proto;
#[cfg(feature = "protobuf")]
pub use proto::ProtoFindOpts;
mod set;
mod string;
#[cfg(feature = "vector")]
Expand Down
28 changes: 21 additions & 7 deletions crates/ember-core/src/keyspace/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,17 @@ use super::*;
#[cfg(feature = "protobuf")]
use crate::schema::SchemaRegistry;

/// Parameters for [`Keyspace::scan_proto_find`].
#[cfg(feature = "protobuf")]
pub struct ProtoFindOpts<'a> {
pub cursor: u64,
pub count: usize,
pub pattern: Option<&'a str>,
pub type_name: Option<&'a str>,
pub field_path: &'a str,
pub field_value: &'a str,
}

#[cfg(feature = "protobuf")]
impl Keyspace {
/// Stores a protobuf value. No schema validation here — that's the
Expand Down Expand Up @@ -91,19 +102,22 @@ impl Keyspace {
/// `scan_proto_keys`. Skips keys where field decoding fails (e.g. wrong
/// type, nested/repeated field) rather than returning an error.
///
/// `field_value` is compared against the field's string representation:
/// `opts.field_value` is compared against the field's string representation:
/// booleans as `"true"/"false"`, integers and floats as their decimal
/// string, strings verbatim.
pub fn scan_proto_find(
&self,
cursor: u64,
count: usize,
pattern: Option<&str>,
type_name: Option<&str>,
field_path: &str,
field_value: &str,
opts: ProtoFindOpts<'_>,
registry: &SchemaRegistry,
) -> (u64, Vec<String>) {
let ProtoFindOpts {
cursor,
count,
pattern,
type_name,
field_path,
field_value,
} = opts;
let mut keys = Vec::with_capacity(count);
let mut position = 0u64;
let target_count = if count == 0 { 10 } else { count };
Expand Down
22 changes: 10 additions & 12 deletions crates/ember-core/src/shard/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2514,12 +2514,8 @@ fn dispatch(
pattern,
type_name,
} => {
let (next_cursor, keys) = ks.scan_proto_keys(
*cursor,
*count,
pattern.as_deref(),
type_name.as_deref(),
);
let (next_cursor, keys) =
ks.scan_proto_keys(*cursor, *count, pattern.as_deref(), type_name.as_deref());
ShardResponse::Scan {
cursor: next_cursor,
keys,
Expand All @@ -2543,12 +2539,14 @@ fn dispatch(
Err(_) => return ShardResponse::Err("schema registry lock poisoned".into()),
};
let (next_cursor, keys) = ks.scan_proto_find(
*cursor,
*count,
pattern.as_deref(),
type_name.as_deref(),
field_path,
field_value,
crate::keyspace::ProtoFindOpts {
cursor: *cursor,
count: *count,
pattern: pattern.as_deref(),
type_name: type_name.as_deref(),
field_path,
field_value,
},
&reg,
);
ShardResponse::Scan {
Expand Down
10 changes: 2 additions & 8 deletions crates/ember-server/src/concurrent_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1137,10 +1137,7 @@ async fn execute_concurrent(
let sid = (cursor >> 48) as usize;
let pos = cursor & 0xFFFF_FFFF_FFFF;
if sid >= shard_count {
return Frame::Array(vec![
Frame::Bulk(Bytes::from("0")),
Frame::Array(vec![]),
]);
return Frame::Array(vec![Frame::Bulk(Bytes::from("0")), Frame::Array(vec![])]);
}
(sid, pos)
};
Expand Down Expand Up @@ -1210,10 +1207,7 @@ async fn execute_concurrent(
let sid = (cursor >> 48) as usize;
let pos = cursor & 0xFFFF_FFFF_FFFF;
if sid >= shard_count {
return Frame::Array(vec![
Frame::Bulk(Bytes::from("0")),
Frame::Array(vec![]),
]);
return Frame::Array(vec![Frame::Bulk(Bytes::from("0")), Frame::Array(vec![])]);
}
(sid, pos)
};
Expand Down
12 changes: 10 additions & 2 deletions crates/ember-server/src/connection/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -574,8 +574,16 @@ pub(super) async fn execute(
type_name,
count,
} => {
exec::protobuf::proto_find(cursor, field_path, field_value, pattern, type_name, count, &cx)
.await
exec::protobuf::proto_find(
cursor,
field_path,
field_value,
pattern,
type_name,
count,
&cx,
)
.await
}
#[cfg(not(feature = "protobuf"))]
Command::ProtoRegister { .. }
Expand Down
113 changes: 80 additions & 33 deletions tests/integration/src/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1214,8 +1214,13 @@ async fn proto_scan_all_keys() {
// store 5 proto keys
for i in 1..=5u32 {
let data = encode_profile(&desc, &format!("user{i}"), i as i32, true);
c.cmd_raw(&[b"PROTO.SET", format!("user:{i}").as_bytes(), b"test.Profile", &data])
.await;
c.cmd_raw(&[
b"PROTO.SET",
format!("user:{i}").as_bytes(),
b"test.Profile",
&data,
])
.await;
}

// also store a non-proto key to ensure it's excluded
Expand Down Expand Up @@ -1267,9 +1272,7 @@ async fn proto_scan_type_filter() {
.await;

// scan with TYPE=test.Profile — should only return profile:1
let resp = c
.cmd(&["PROTO.SCAN", "0", "TYPE", "test.Profile"])
.await;
let resp = c.cmd(&["PROTO.SCAN", "0", "TYPE", "test.Profile"]).await;
let (_, keys) = decode_scan_response(resp);
assert_eq!(keys.len(), 1);
assert_eq!(keys[0], "profile:1");
Expand All @@ -1292,11 +1295,21 @@ async fn proto_scan_match_pattern() {

for i in 1..=3u32 {
let data = encode_profile(&desc, "x", i as i32, false);
c.cmd_raw(&[b"PROTO.SET", format!("profile:{i}").as_bytes(), b"test.Profile", &data])
.await;
c.cmd_raw(&[
b"PROTO.SET",
format!("profile:{i}").as_bytes(),
b"test.Profile",
&data,
])
.await;
let data = encode_profile(&desc, "y", i as i32, false);
c.cmd_raw(&[b"PROTO.SET", format!("other:{i}").as_bytes(), b"test.Profile", &data])
.await;
c.cmd_raw(&[
b"PROTO.SET",
format!("other:{i}").as_bytes(),
b"test.Profile",
&data,
])
.await;
}

let mut matched = Vec::new();
Expand Down Expand Up @@ -1330,8 +1343,13 @@ async fn proto_scan_cursor_consistency() {

for i in 1..=10u32 {
let data = encode_profile(&desc, "x", i as i32, true);
c.cmd_raw(&[b"PROTO.SET", format!("p:{i}").as_bytes(), b"test.Profile", &data])
.await;
c.cmd_raw(&[
b"PROTO.SET",
format!("p:{i}").as_bytes(),
b"test.Profile",
&data,
])
.await;
}

// first page with COUNT 3
Expand All @@ -1342,13 +1360,20 @@ async fn proto_scan_cursor_consistency() {
// add more keys while iterating
for i in 11..=15u32 {
let data = encode_profile(&desc, "y", i as i32, false);
c.cmd_raw(&[b"PROTO.SET", format!("p:{i}").as_bytes(), b"test.Profile", &data])
.await;
c.cmd_raw(&[
b"PROTO.SET",
format!("p:{i}").as_bytes(),
b"test.Profile",
&data,
])
.await;
}

// continue iterating — must not panic or crash
if cursor != 0 {
let resp = c.cmd(&["PROTO.SCAN", &cursor.to_string(), "COUNT", "3"]).await;
let resp = c
.cmd(&["PROTO.SCAN", &cursor.to_string(), "COUNT", "3"])
.await;
let (_, _) = decode_scan_response(resp);
}
}
Expand All @@ -1366,16 +1391,31 @@ async fn proto_find_scalar_match() {

// store three profiles with different active values
let active_data = encode_profile(&desc, "alice", 25, true);
c.cmd_raw(&[b"PROTO.SET", b"profile:alice", b"test.Profile", &active_data])
.await;
c.cmd_raw(&[
b"PROTO.SET",
b"profile:alice",
b"test.Profile",
&active_data,
])
.await;

let inactive_data = encode_profile(&desc, "bob", 30, false);
c.cmd_raw(&[b"PROTO.SET", b"profile:bob", b"test.Profile", &inactive_data])
.await;
c.cmd_raw(&[
b"PROTO.SET",
b"profile:bob",
b"test.Profile",
&inactive_data,
])
.await;

let active2_data = encode_profile(&desc, "carol", 22, true);
c.cmd_raw(&[b"PROTO.SET", b"profile:carol", b"test.Profile", &active2_data])
.await;
c.cmd_raw(&[
b"PROTO.SET",
b"profile:carol",
b"test.Profile",
&active2_data,
])
.await;

// find by bool field
let mut found = Vec::new();
Expand All @@ -1397,17 +1437,13 @@ async fn proto_find_scalar_match() {
assert!(found.contains(&"profile:carol".to_owned()));

// find by int field
let resp = c
.cmd(&["PROTO.FIND", "0", "age", "30"])
.await;
let resp = c.cmd(&["PROTO.FIND", "0", "age", "30"]).await;
let (_, keys) = decode_scan_response(resp);
assert_eq!(keys.len(), 1);
assert_eq!(keys[0], "profile:bob");

// find by string field
let resp = c
.cmd(&["PROTO.FIND", "0", "name", "alice"])
.await;
let resp = c.cmd(&["PROTO.FIND", "0", "name", "alice"]).await;
let (_, keys) = decode_scan_response(resp);
assert_eq!(keys.len(), 1);
assert_eq!(keys[0], "profile:alice");
Expand All @@ -1416,7 +1452,9 @@ async fn proto_find_scalar_match() {
/// PROTO.FIND with a dot-separated path searches nested message fields.
#[tokio::test]
async fn proto_find_nested_path() {
use prost_reflect::prost_types::{DescriptorProto, FieldDescriptorProto, FileDescriptorProto, FileDescriptorSet};
use prost_reflect::prost_types::{
DescriptorProto, FieldDescriptorProto, FileDescriptorProto, FileDescriptorSet,
};

// build a descriptor with a nested Address.city field
let fds = FileDescriptorSet {
Expand Down Expand Up @@ -1464,10 +1502,14 @@ async fn proto_find_nested_path() {
fds.encode(&mut desc_bytes).expect("encode descriptor");

let pool = DescriptorPool::decode(desc_bytes.as_slice()).expect("decode pool");
let person_desc = pool.get_message_by_name("nested.Person").expect("find message");
let person_desc = pool
.get_message_by_name("nested.Person")
.expect("find message");

let encode_person = |name: &str, city: &str| {
let addr_desc = pool.get_message_by_name("nested.Address").expect("find address");
let addr_desc = pool
.get_message_by_name("nested.Address")
.expect("find address");
let mut addr = DynamicMessage::new(addr_desc);
addr.set_field_by_name("city", prost_reflect::Value::String(city.into()));

Expand Down Expand Up @@ -1562,9 +1604,7 @@ async fn proto_find_no_match() {
c.cmd_raw(&[b"PROTO.SET", b"profile:1", b"test.Profile", &data])
.await;

let resp = c
.cmd(&["PROTO.FIND", "0", "active", "true"])
.await;
let resp = c.cmd(&["PROTO.FIND", "0", "active", "true"]).await;
let (cursor, keys) = decode_scan_response(resp);
assert_eq!(cursor, 0);
assert!(keys.is_empty());
Expand Down Expand Up @@ -1606,7 +1646,14 @@ async fn proto_find_count_pagination() {
let mut cursor = 0u64;
loop {
let resp = c
.cmd(&["PROTO.FIND", &cursor.to_string(), "active", "true", "COUNT", "2"])
.cmd(&[
"PROTO.FIND",
&cursor.to_string(),
"active",
"true",
"COUNT",
"2",
])
.await;
let (next, keys) = decode_scan_response(resp);
all_found.extend(keys);
Expand Down