Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions crates/expr/src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use spacetimedb_sql_parser::{
};

use super::{
errors::{DuplicateName, TypingError, Unresolved, Unsupported},
errors::{DuplicateName, FunctionCall, TypingError, Unresolved, Unsupported},
expr::RelExpr,
type_expr, type_proj, type_select,
};
Expand Down Expand Up @@ -78,12 +78,8 @@ pub trait TypeChecker {
delta: None,
});

for SqlJoin {
var: SqlIdent(name),
alias: SqlIdent(alias),
on,
} in joins
{
for SqlJoin { from, on } in joins {
let (SqlIdent(name), SqlIdent(alias)) = from.into_name_alias();
// Check for duplicate aliases
if vars.contains_key(&alias) {
return Err(DuplicateName(alias.into_string()).into());
Expand Down Expand Up @@ -113,6 +109,8 @@ pub trait TypeChecker {

Ok(join)
}
// TODO: support function calls in FROM clause
SqlFrom::FuncCall(_, _) => Err(FunctionCall.into()),
}
}

Expand Down
6 changes: 6 additions & 0 deletions crates/expr/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ pub struct DmlOnView {
pub view_name: Box<str>,
}

#[derive(Debug, Error)]
#[error("Function calls are not supported")]
pub struct FunctionCall;

#[derive(Error, Debug)]
pub enum TypingError {
#[error(transparent)]
Expand Down Expand Up @@ -157,4 +161,6 @@ pub enum TypingError {
DuplicateName(#[from] DuplicateName),
#[error(transparent)]
FilterReturnType(#[from] FilterReturnType),
#[error(transparent)]
FunctionCall(#[from] FunctionCall),
}
29 changes: 26 additions & 3 deletions crates/sql-parser/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ use sqlparser::ast::Ident;
pub mod sql;
pub mod sub;

/// The FROM clause is either a relvar or a JOIN
/// The FROM clause is either a relvar, a JOIN, or a function call
#[derive(Debug)]
pub enum SqlFrom {
Expr(SqlIdent, SqlIdent),
Join(SqlIdent, SqlIdent, Vec<SqlJoin>),
FuncCall(SqlFuncCall, SqlIdent),
}

impl SqlFrom {
Expand All @@ -22,11 +23,26 @@ impl SqlFrom {
}
}

/// A source in a FROM clause, restricted to a single relvar or function call
#[derive(Debug)]
pub enum SqlFromSource {
Expr(SqlIdent, SqlIdent),
FuncCall(SqlFuncCall, SqlIdent),
}

impl SqlFromSource {
pub fn into_name_alias(self) -> (SqlIdent, SqlIdent) {
match self {
Self::Expr(name, alias) => (name, alias),
Self::FuncCall(func, alias) => (func.name, alias),
}
}
}

/// An inner join in a FROM clause
#[derive(Debug)]
pub struct SqlJoin {
pub var: SqlIdent,
pub alias: SqlIdent,
pub from: SqlFromSource,
pub on: Option<SqlExpr>,
}

Expand Down Expand Up @@ -247,3 +263,10 @@ impl Display for LogOp {
}
}
}

/// A SQL function call
#[derive(Debug)]
pub struct SqlFuncCall {
pub name: SqlIdent,
pub args: Vec<SqlLiteral>,
}
1 change: 1 addition & 0 deletions crates/sql-parser/src/ast/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ impl SqlSelect {
..self
},
SqlFrom::Join(..) => self,
SqlFrom::FuncCall(..) => self,
}
}

Expand Down
1 change: 1 addition & 0 deletions crates/sql-parser/src/ast/sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ impl SqlSelect {
from: self.from,
},
SqlFrom::Join(..) => self,
SqlFrom::FuncCall(..) => self,
}
}

Expand Down
7 changes: 7 additions & 0 deletions crates/sql-parser/src/parser/errors.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::fmt::Display;

use sqlparser::ast::FunctionArg;
use sqlparser::{
ast::{
BinaryOperator, Expr, Function, ObjectName, Query, Select, SelectItem, SetExpr, TableFactor, TableWithJoins,
Expand Down Expand Up @@ -77,6 +78,12 @@ pub enum SqlUnsupported {
Empty,
#[error("Names must be qualified when using joins")]
UnqualifiedNames,
#[error("Unsupported function argument: {0}")]
FuncArg(FunctionArg),
#[error("Unsupported call to table-valued function with empty params. Use `select * from table_function` syntax instead: {0}")]
TableFunctionNoParams(String),
#[error("Unsupported JOIN with table-valued function: {0}")]
JoinTableFunction(String),
}

impl SqlUnsupported {
Expand Down
66 changes: 51 additions & 15 deletions crates/sql-parser/src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ use sqlparser::ast::{
};

use crate::ast::{
BinOp, LogOp, Parameter, Project, ProjectElem, ProjectExpr, SqlExpr, SqlFrom, SqlIdent, SqlJoin, SqlLiteral,
BinOp, LogOp, Parameter, Project, ProjectElem, ProjectExpr, SqlExpr, SqlFrom, SqlFromSource, SqlFuncCall, SqlIdent,
SqlJoin, SqlLiteral,
};

pub mod errors;
Expand Down Expand Up @@ -34,11 +35,15 @@ trait RelParser {
return Err(SqlUnsupported::ImplicitJoins.into());
}
let TableWithJoins { relation, joins } = tables.swap_remove(0);
let (name, alias) = Self::parse_relvar(relation)?;
if joins.is_empty() {
return Ok(SqlFrom::Expr(name, alias));
match Self::parse_relvar(relation)? {
SqlFromSource::Expr(name, alias) => {
if joins.is_empty() {
return Ok(SqlFrom::Expr(name, alias));
}
Ok(SqlFrom::Join(name, alias, Self::parse_joins(joins)?))
}
SqlFromSource::FuncCall(func_call, alias) => Ok(SqlFrom::FuncCall(func_call, alias)),
}
Ok(SqlFrom::Join(name, alias, Self::parse_joins(joins)?))
}

/// Parse a sequence of JOIN clauses
Expand All @@ -48,10 +53,11 @@ trait RelParser {

/// Parse a single JOIN clause
fn parse_join(join: Join) -> SqlParseResult<SqlJoin> {
let (var, alias) = Self::parse_relvar(join.relation)?;
let from = Self::parse_relvar(join.relation)?;

match join.join_operator {
JoinOperator::CrossJoin => Ok(SqlJoin { var, alias, on: None }),
JoinOperator::Inner(JoinConstraint::None) => Ok(SqlJoin { var, alias, on: None }),
JoinOperator::CrossJoin => Ok(SqlJoin { from, on: None }),
JoinOperator::Inner(JoinConstraint::None) => Ok(SqlJoin { from, on: None }),
JoinOperator::Inner(JoinConstraint::On(Expr::BinaryOp {
left,
op: BinaryOperator::Eq,
Expand All @@ -60,8 +66,7 @@ trait RelParser {
&& matches!(*right, Expr::Identifier(..) | Expr::CompoundIdentifier(..)) =>
{
Ok(SqlJoin {
var,
alias,
from,
on: Some(parse_expr(
Expr::BinaryOp {
left,
Expand All @@ -76,32 +81,63 @@ trait RelParser {
}
}

/// Parse a function call
fn parse_func_call(name: SqlIdent, args: Vec<FunctionArg>) -> SqlParseResult<SqlFuncCall> {
if args.is_empty() {
return Err(SqlUnsupported::TableFunctionNoParams(name.0.into()).into());
}
let args = args
.into_iter()
.map(|arg| match arg.clone() {
FunctionArg::Unnamed(FunctionArgExpr::Expr(expr)) => match parse_expr(expr, 0) {
Ok(SqlExpr::Lit(lit)) => Ok(lit),
_ => Err(SqlUnsupported::FuncArg(arg).into()),
},
_ => Err(SqlUnsupported::FuncArg(arg.clone()).into()),
})
.collect::<SqlParseResult<_>>()?;
Ok(SqlFuncCall { name, args })
}

/// Parse a table reference in a FROM clause
fn parse_relvar(expr: TableFactor) -> SqlParseResult<(SqlIdent, SqlIdent)> {
fn parse_relvar(expr: TableFactor) -> SqlParseResult<SqlFromSource> {
match expr {
// Relvar no alias
TableFactor::Table {
name,
alias: None,
args: None,
args,
with_hints,
version: None,
partitions,
} if with_hints.is_empty() && partitions.is_empty() => {
let name = parse_ident(name)?;
let alias = name.clone();
Ok((name, alias))

if let Some(args) = args {
Ok(SqlFromSource::FuncCall(Self::parse_func_call(name, args)?, alias))
} else {
Ok(SqlFromSource::Expr(name, alias))
}
}
// Relvar with alias
TableFactor::Table {
name,
alias: Some(TableAlias { name: alias, columns }),
args: None,
args,
with_hints,
version: None,
partitions,
} if with_hints.is_empty() && partitions.is_empty() && columns.is_empty() => {
Ok((parse_ident(name)?, alias.into()))
let args = args.filter(|v| !v.is_empty());
if let Some(args) = args {
Ok(SqlFromSource::FuncCall(
Self::parse_func_call(parse_ident(name)?, args)?,
alias.into(),
))
} else {
Ok(SqlFromSource::Expr(parse_ident(name)?, alias.into()))
}
}
_ => Err(SqlUnsupported::From(expr).into()),
}
Expand Down
22 changes: 22 additions & 0 deletions crates/sql-parser/src/parser/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,18 @@
//! | SUM '(' columnExpr ')' AS ident
//! ;
//!
//! paramExpr
//! = literal
//! ;
//!
//! functionCall
//! = ident '(' [ paramExpr { ',' paramExpr } ] ')'
//! ;
//!
//! relation
//! = table
//! | '(' query ')'
//! | functionCall
//! | relation [ [AS] ident ] { [INNER] JOIN relation [ [AS] ident ] ON predicate }
//! ;
//!
Expand Down Expand Up @@ -442,6 +451,11 @@ mod tests {
"select a from t where x = :sender",
"select count(*) as n from t",
"select count(*) as n from t join s on t.id = s.id where s.x = 1",
"select * from sample as s",
"select * from sample(1, 'abc', true, 0xFF, 0.1)",
"select * from sample(1, 'abc', true, 0xFF, 0.1) as s",
"select * from t join sample(1) on t.id = sample.id",
"select * from t join sample(1) as s on t.id = s.id",
"insert into t values (1, 2)",
"delete from t",
"delete from t where a = 1",
Expand All @@ -463,6 +477,14 @@ mod tests {
"select a from where b = 1",
// Empty WHERE
"select a from t where",
// Function call params are not literals
"select * from sample(a, b)",
// Function call without params
"select * from sample()",
// Nested function call
"select * from sample(sample(1))",
// Function call in JOIN ON
"select * from t join sample(1) on t.id = sample(1).id",
// Empty GROUP BY
"select a, count(*) from t group by",
// Aggregate without alias
Expand Down
22 changes: 22 additions & 0 deletions crates/sql-parser/src/parser/sub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,18 @@
//! | ident '.' STAR
//! ;
//!
//! paramExpr
//! = literal
//! ;
//!
//! functionCall
//! = ident '(' [ paramExpr { ',' paramExpr } ] ')'
//! ;
//!
//! relation
//! = table
//! | '(' query ')'
//! | functionCall
//! | relation [ [AS] ident ] { [INNER] JOIN relation [ [AS] ident ] ON predicate }
//! ;
//!
Expand Down Expand Up @@ -162,6 +171,14 @@ mod tests {
"",
"select distinct a from t",
"select * from (select * from t) join (select * from s) on a = b",
// Function call params are not literals
"select * from sample(a, b)",
// Function call without params
"select * from sample()",
// Nested function call
"select * from sample(sample(1))",
// Function call in JOIN ON
"select * from t join sample(1) on t.id = sample(1).id",
] {
assert!(parse_subscription(sql).is_err());
}
Expand All @@ -178,6 +195,11 @@ mod tests {
"select t.* from t join s on t.c = s.d",
"select a.* from t as a join s as b on a.c = b.d",
"select * from t where x = :sender",
"select * from sample as s",
"select * from sample(1, 'abc', true, 0xFF, 0.1)",
"select * from sample(1, 'abc', true, 0xFF, 0.1) as s",
"select * from t join sample(1) on t.id = sample.id",
"select * from t join sample(1) as s on t.id = s.id",
] {
assert!(parse_subscription(sql).is_ok());
}
Expand Down
Loading