Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion crates/cli/src/subcommands/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ mod tests {
use spacetimedb_lib::error::ResultTest;
use spacetimedb_lib::sats::time_duration::TimeDuration;
use spacetimedb_lib::sats::timestamp::Timestamp;
use spacetimedb_lib::sats::{product, GroundSpacetimeType, ProductType};
use spacetimedb_lib::sats::{product, ArrayValue, GroundSpacetimeType, ProductType};
use spacetimedb_lib::{AlgebraicType, AlgebraicValue, ConnectionId, Identity, Uuid};

fn make_row(row: &[AlgebraicValue]) -> Result<Box<RawValue>, serde_json::Error> {
Expand Down Expand Up @@ -512,6 +512,48 @@ Roundtrip time: 1.00ms"#,
assert_eq!(expected, table);
}

#[test]
fn output_arrays() -> ResultTest<()> {
let kind: ProductType = [
("ints", AlgebraicType::array(AlgebraicType::I32)),
("strings", AlgebraicType::array(AlgebraicType::String)),
("nested", AlgebraicType::array(AlgebraicType::array(AlgebraicType::I32))),
("bytes", AlgebraicType::bytes()),
]
.into();

let value = product![
AlgebraicValue::Array(ArrayValue::I32([1, 2, 3].into())),
AlgebraicValue::Array(ArrayValue::String(["one".into(), "two".into()].into())),
AlgebraicValue::Array(ArrayValue::Array(
[ArrayValue::I32([1, 2].into()), ArrayValue::I32([3, 4].into())].into()
)),
AlgebraicValue::Bytes([0xde, 0xad].into()),
];

expect_psql_table(
PsqlClient::SpacetimeDB,
&kind,
vec![value.clone()],
r#"
ints | strings | nested | bytes
-----------+----------------+------------------+--------
[1, 2, 3] | ["one", "two"] | [[1, 2], [3, 4]] | 0xdead"#,
);

expect_psql_table(
PsqlClient::Postgres,
&kind,
vec![value],
r#"
ints | strings | nested | bytes
-----------+----------------+------------------+----------
{1, 2, 3} | {"one", "two"} | {{1, 2}, {3, 4}} | "0xdead""#,
);

Ok(())
}

// Verify the output of `sql` matches the inputs that return true for [`AlgebraicType::is_special()`]
#[test]
fn output_special_types() -> ResultTest<()> {
Expand Down
112 changes: 107 additions & 5 deletions crates/pg/src/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use crate::pg_server::PgError;
use pgwire::api::portal::Format;
use pgwire::api::results::{DataRowEncoder, FieldInfo};
use pgwire::api::Type;
use spacetimedb_lib::sats::satn::{PsqlChars, PsqlPrintFmt, PsqlType, TypedWriter};
use spacetimedb_lib::sats::{satn, ValueWithType};
use spacetimedb_lib::sats::satn::{PsqlChars, PsqlClient, PsqlPrintFmt, PsqlType, TypedWriter};
use spacetimedb_lib::sats::{satn, ArrayValue, ValueWithType};
use spacetimedb_lib::{
ser, AlgebraicType, AlgebraicValue, ProductType, ProductTypeElement, ProductValue, TimeDuration, Timestamp, Uuid,
};
Expand Down Expand Up @@ -54,7 +54,11 @@ pub(crate) fn type_of(schema: &ProductType, ty: &ProductTypeElement) -> Type {
| AlgebraicType::U128
| AlgebraicType::I256
| AlgebraicType::U256 => Type::NUMERIC_ARRAY,
_ => Type::ANYARRAY,
AlgebraicType::F32 => Type::FLOAT4_ARRAY,
AlgebraicType::F64 => Type::FLOAT8_ARRAY,
AlgebraicType::Ref(_) | AlgebraicType::Sum(_) | AlgebraicType::Product(_) | AlgebraicType::Array(_) => {
Type::JSON_ARRAY
}
},
AlgebraicType::Product(_) => match format {
PsqlPrintFmt::Hex => Type::BYTEA,
Expand Down Expand Up @@ -155,7 +159,9 @@ impl TypedWriter for PsqlFormatter<'_> {
return Ok(());
}

let PsqlChars { start, sep, end, quote } = ty.client.format_chars();
let PsqlChars {
start, sep, end, quote, ..
} = ty.client.format_chars();
let name = name.map(Cow::from).unwrap_or_else(|| Cow::from(tag.to_string()));
let json = format!(
"{start}{quote}{name}{quote}{sep} {}{end}",
Expand All @@ -164,6 +170,78 @@ impl TypedWriter for PsqlFormatter<'_> {
self.encoder.encode_field(&json)?;
Ok(())
}

fn write_array(
&mut self,
value: &ValueWithType<'_, ArrayValue>,
psql: &PsqlType,
ty: &AlgebraicType,
) -> Result<bool, Self::Error> {
// `array<u8>` is a byte array in SQL output, so keep the existing bytea path.
if *ty == AlgebraicType::U8 {
return Ok(false);
}

fn collect<I, O, F>(arr: &[I], map: F) -> Vec<O>
where
F: FnMut(&I) -> O,
{
arr.iter().map(map).collect()
}

let complex_value = |elem: AlgebraicValue, elem_ty: &AlgebraicType, client| {
let tuple = ProductType::from([elem_ty.clone()]);
let psql_ty = PsqlType {
client,
tuple: &tuple,
field: &tuple.elements[0],
idx: 0,
};
satn::PsqlWrapper {
ty: psql_ty,
value: value.with(elem_ty, &elem),
}
.to_string()
};

match value.value() {
ArrayValue::Bool(arr) => self.encoder.encode_field(&arr.as_ref())?,
ArrayValue::I8(arr) => self.encoder.encode_field(&arr.as_ref())?,
ArrayValue::U8(arr) => self.encoder.encode_field(&arr.as_ref())?,
ArrayValue::I16(arr) => self.encoder.encode_field(&arr.as_ref())?,
ArrayValue::U16(arr) => self.encoder.encode_field(&collect(arr, |v| i32::from(*v)))?,
ArrayValue::I32(arr) => self.encoder.encode_field(&arr.as_ref())?,
ArrayValue::U32(arr) => self.encoder.encode_field(&collect(arr, |v| i64::from(*v)))?,
ArrayValue::I64(arr) => self.encoder.encode_field(&arr.as_ref())?,
ArrayValue::U64(arr) => self.encoder.encode_field(&collect(arr, |v| v.to_string()))?,
ArrayValue::I128(arr) => self.encoder.encode_field(&collect(arr, |v| v.to_string()))?,
ArrayValue::U128(arr) => self.encoder.encode_field(&collect(arr, |v| v.to_string()))?,
ArrayValue::I256(arr) => self.encoder.encode_field(&collect(arr, |v| v.to_string()))?,
ArrayValue::U256(arr) => self.encoder.encode_field(&collect(arr, |v| v.to_string()))?,
ArrayValue::F32(arr) => self.encoder.encode_field(&collect(arr, |v| *v.as_ref()))?,
ArrayValue::F64(arr) => self.encoder.encode_field(&collect(arr, |v| *v.as_ref()))?,
ArrayValue::String(arr) => self.encoder.encode_field(&collect(arr, |v| v.to_string()))?,
ArrayValue::Array(arr) => {
// Nested arrays are exposed as JSON arrays for the PostgreSQL wire protocol.
let values = collect(arr, |v| {
complex_value(AlgebraicValue::Array(v.clone()), ty, PsqlClient::SpacetimeDB)
});
self.encoder.encode_field(&values)?;
}
ArrayValue::Sum(arr) => {
let values = collect(arr, |v| complex_value(AlgebraicValue::Sum(v.clone()), ty, psql.client));
self.encoder.encode_field(&values)?;
}
ArrayValue::Product(arr) => {
let values = collect(arr, |v| {
complex_value(AlgebraicValue::Product(v.clone()), ty, psql.client)
});
self.encoder.encode_field(&values)?;
}
}

Ok(true)
}
}

#[cfg(test)]
Expand All @@ -173,7 +251,7 @@ mod tests {
use futures::StreamExt;
use spacetimedb_client_api_messages::http::SqlStmtResult;
use spacetimedb_lib::sats::algebraic_value::Packed;
use spacetimedb_lib::sats::{i256, product, u256, AlgebraicType, ProductType, SumTypeVariant};
use spacetimedb_lib::sats::{i256, product, u256, AlgebraicType, ArrayValue, ProductType, SumTypeVariant};
use spacetimedb_lib::{ConnectionId, Identity};

async fn run(schema: ProductType, row: ProductValue) -> String {
Expand Down Expand Up @@ -236,6 +314,30 @@ mod tests {
assert_eq!(row, "\0\0\0\u{1}1\0\0\0\u{2}-1\0\0\0\u{2}-2\0\0\0\u{1}3\0\0\0\u{2}-4\0\0\0\u{1}5\0\0\0\u{2}-6\0\0\0\u{1}7\0\0\0\u{2}-8\0\0\0\u{1}9\0\0\0\u{3}-10\0\0\0\u{2}11\0\0\0\u{5}12.34\0\0\0\u{5}56.78\0\0\0\u{4}test\0\0\0\u{1}t");
}

#[tokio::test]
async fn test_array() {
let schema = ProductType::from([
AlgebraicType::array(AlgebraicType::I32),
AlgebraicType::array(AlgebraicType::String),
AlgebraicType::array(AlgebraicType::array(AlgebraicType::I32)),
AlgebraicType::bytes(),
]);
let value = product![
AlgebraicValue::Array(ArrayValue::I32([1, 2, 3].into())),
AlgebraicValue::Array(ArrayValue::String(["one".into(), "two".into()].into())),
AlgebraicValue::Array(ArrayValue::Array(
[ArrayValue::I32([1, 2].into()), ArrayValue::I32([3, 4].into())].into()
)),
AlgebraicValue::Bytes([0xde, 0xad].into()),
];

let row = run(schema, value).await;
assert_eq!(
row,
"\0\0\0\u{7}{1,2,3}\0\0\0\t{one,two}\0\0\0\u{13}{\"[1, 2]\",\"[3, 4]\"}\0\0\0\u{6}\\xdead"
);
}

#[tokio::test]
async fn test_enum() {
let some = AlgebraicType::option(AlgebraicType::I64);
Expand Down
96 changes: 94 additions & 2 deletions crates/sats/src/satn.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::time_duration::TimeDuration;
use crate::timestamp::Timestamp;
use crate::uuid::Uuid;
use crate::{i256, u256, AlgebraicType, AlgebraicValue, ProductValue, Serialize, SumValue, ValueWithType};
use crate::{i256, u256, AlgebraicType, AlgebraicValue, ArrayValue, ProductValue, Serialize, SumValue, ValueWithType};
use crate::{ser, ProductType, ProductTypeElement};
use core::fmt;
use core::fmt::Write as _;
Expand Down Expand Up @@ -453,8 +453,10 @@ pub enum PsqlClient {

pub struct PsqlChars {
pub start: char,
pub start_array: &'static str,
pub sep: &'static str,
pub end: char,
pub end_array: &'static str,
pub quote: &'static str,
}

Expand All @@ -463,14 +465,18 @@ impl PsqlClient {
match self {
PsqlClient::SpacetimeDB => PsqlChars {
start: '(',
start_array: "[",
sep: " =",
end: ')',
end_array: "]",
quote: "",
},
PsqlClient::Postgres => PsqlChars {
start: '{',
start_array: "{",
sep: ":",
end: '}',
end_array: "}",
quote: "\"",
},
}
Expand Down Expand Up @@ -588,6 +594,17 @@ pub trait TypedWriter {
Ok(false)
}

/// Writes an array as a single value. Returns `false` to use the default
/// typed serialization path instead.
fn write_array(
&mut self,
_value: &ValueWithType<'_, ArrayValue>,
_psql: &PsqlType,
_ty: &AlgebraicType,
) -> Result<bool, Self::Error> {
Ok(false)
}

fn write_record(
&mut self,
fields: Vec<(Cow<str>, PsqlType, ValueWithType<AlgebraicValue>)>,
Expand Down Expand Up @@ -764,6 +781,39 @@ impl<'a, 'f, F: TypedWriter> ser::Serializer for TypedSerializer<'a, 'f, F> {
Ok(TypedArrayFormatter { ty: self.ty, f: self.f })
}

fn serialize_array_raw(self, value: &ValueWithType<'_, ArrayValue>) -> Result<Self::Ok, Self::Error> {
let mut ty = &*value.ty().elem_ty;
while let AlgebraicType::Ref(r) = ty {
ty = &value.typespace()[*r];
}
if self.f.write_array(value, self.ty, ty)? {
return Ok(());
}
match (value.value(), ty) {
(ArrayValue::Sum(v), AlgebraicType::Sum(ty)) => value.with(ty, v).serialize(self),
(ArrayValue::Product(v), AlgebraicType::Product(ty)) => value.with(ty, v).serialize(self),
(ArrayValue::Bool(v), AlgebraicType::Bool) => v.serialize(self),
(ArrayValue::I8(v), AlgebraicType::I8) => v.serialize(self),
(ArrayValue::U8(v), AlgebraicType::U8) => v.serialize(self),
(ArrayValue::I16(v), AlgebraicType::I16) => v.serialize(self),
(ArrayValue::U16(v), AlgebraicType::U16) => v.serialize(self),
(ArrayValue::I32(v), AlgebraicType::I32) => v.serialize(self),
(ArrayValue::U32(v), AlgebraicType::U32) => v.serialize(self),
(ArrayValue::I64(v), AlgebraicType::I64) => v.serialize(self),
(ArrayValue::U64(v), AlgebraicType::U64) => v.serialize(self),
(ArrayValue::I128(v), AlgebraicType::I128) => v.serialize(self),
(ArrayValue::U128(v), AlgebraicType::U128) => v.serialize(self),
(ArrayValue::I256(v), AlgebraicType::I256) => v.serialize(self),
(ArrayValue::U256(v), AlgebraicType::U256) => v.serialize(self),
(ArrayValue::F32(v), AlgebraicType::F32) => v.serialize(self),
(ArrayValue::F64(v), AlgebraicType::F64) => v.serialize(self),
(ArrayValue::String(v), AlgebraicType::String) => v.serialize(self),
(ArrayValue::Array(v), AlgebraicType::Array(ty)) => value.with(ty, v).serialize(self),
(val, _) if val.is_empty() => ser::SerializeArray::end(self.serialize_array(0)?),
(val, ty) => panic!("mismatched value and schema: {val:?} {ty:?}"),
}
}

fn serialize_seq_product(self, _len: usize) -> Result<Self::SerializeSeqProduct, Self::Error> {
Ok(TypedSeqFormatter { ty: self.ty, f: self.f })
}
Expand Down Expand Up @@ -893,11 +943,53 @@ impl TypedWriter for SqlFormatter<'_, '_> {
write!(self.fmt, "\"{value}\"")
}

fn write_array(
&mut self,
value: &ValueWithType<'_, ArrayValue>,
_psql: &PsqlType,
ty: &AlgebraicType,
) -> Result<bool, Self::Error> {
// `array<u8>` is rendered as bytes in SQL output.
if *ty == AlgebraicType::U8 {
return Ok(false);
}

let PsqlChars {
start_array, end_array, ..
} = self.ty.client.format_chars();
write!(self.fmt, "{start_array}")?;
let tuple = ProductType::from([ty.clone()]);
let field = &tuple.elements[0];
for (idx, elem) in value.value().iter_cloned().enumerate() {
if idx > 0 {
write!(self.fmt, ", ")?;
}
let psql_ty = PsqlType {
client: self.ty.client,
tuple: &tuple,
field,
idx: 0,
};
write!(
self.fmt,
"{}",
PsqlWrapper {
ty: psql_ty,
value: value.with(ty, &elem)
}
)?;
}
write!(self.fmt, "{end_array}")?;
Ok(true)
}

fn write_record(
&mut self,
fields: Vec<(Cow<str>, PsqlType<'_>, ValueWithType<AlgebraicValue>)>,
) -> Result<(), Self::Error> {
let PsqlChars { start, sep, end, quote } = self.ty.client.format_chars();
let PsqlChars {
start, sep, end, quote, ..
} = self.ty.client.format_chars();
write!(self.fmt, "{start}")?;
for (idx, (name, ty, value)) in fields.into_iter().enumerate() {
if idx > 0 {
Expand Down
Loading
Loading