diff --git a/fluss-spark/PROCEDURES.md b/fluss-spark/PROCEDURES.md
new file mode 100644
index 0000000000..ea01e79c6c
--- /dev/null
+++ b/fluss-spark/PROCEDURES.md
@@ -0,0 +1,96 @@
+# Fluss Spark Procedures
+
+This document describes the stored procedures available in Fluss for Spark.
+
+## Overview
+
+Fluss provides stored procedures to perform administrative and management operations through Spark SQL. All procedures live in the `sys` namespace and are invoked with the `CALL` statement.
+
+## Configuration
+
+To enable Fluss procedures, register the session extensions before the Spark session is created (`spark.sql.extensions` is a static configuration and cannot be changed on a running session), for example when building the session:
+
+```scala
+val spark = SparkSession.builder()
+  .config("spark.sql.extensions", "org.apache.fluss.spark.extensions.FlussSparkSessionExtensions")
+  .getOrCreate()
+```
+
+Or in `spark-defaults.conf`:
+
+```properties
+spark.sql.extensions=org.apache.fluss.spark.extensions.FlussSparkSessionExtensions
+```
+
+## Syntax
+
+The general syntax for calling a procedure is:
+
+```sql
+CALL [catalog_name.]sys.procedure_name(
+    parameter_name => 'value',
+    another_parameter => 'value'
+)
+```
+
+### Argument Passing
+
+Procedures support two ways to pass arguments:
+
+1. **Named Arguments** (recommended):
+   ```sql
+   CALL catalog.sys.procedure_name(parameter => 'value')
+   ```
+
+2. **Positional Arguments**:
+   ```sql
+   CALL catalog.sys.procedure_name('value')
+   ```
+
+Note: You cannot mix named and positional arguments in a single procedure call.
+
+## Available Procedures
+
+### `compact`
+
+Triggers a compaction on a Fluss table, e.g. `CALL catalog.sys.compact(table => 'db.table')`.
+
+- **Parameters**: `table` (STRING, required) — the table identifier.
+- **Output**: a single `result` STRING column containing a status message.
+
+In this PR the procedure is wired through the parser, resolver, and physical execution, but it only validates the table and returns a status message; the server-side compaction trigger will follow in a later change.
+
+## Error Handling
+
+Procedures throw exceptions in the following cases:
+
+- **Missing Required Parameters**: a required parameter is not provided
+- **Invalid Table Name**: the specified table does not exist
+- **Type Mismatch**: a parameter value cannot be converted to the expected type
+- **Permission Denied**: the user does not have permission to perform the operation
+
+## Examples
+
+### Basic Usage
+
+```scala
+// Start Spark with Fluss extensions
+val spark = SparkSession.builder()
+  .config("spark.sql.extensions", "org.apache.fluss.spark.extensions.FlussSparkSessionExtensions")
+  .config("spark.sql.catalog.fluss_catalog", "org.apache.fluss.spark.SparkCatalog")
+  .config("spark.sql.catalog.fluss_catalog.bootstrap.servers", "localhost:9123")
+  .getOrCreate()
+
+// Create a table
+spark.sql("""
+  CREATE TABLE fluss_catalog.my_db.my_table (
+    id INT,
+    name STRING,
+    age INT
+  ) USING fluss
+""")
+
+// Call the compact procedure
+spark.sql("CALL fluss_catalog.sys.compact(table => 'my_db.my_table')").show()
+```
+
+## Implementation Notes
+
+- Procedures are executed synchronously and return results immediately
+- The `sys` namespace is reserved for system procedures
+- Custom procedures can be added by implementing the `Procedure` trait
+
+## See Also
+
+- [Fluss Spark Connector Documentation](../spark-connector.md)
+- [Fluss Admin API](../admin-api.md)
diff --git a/fluss-spark/fluss-spark-common/pom.xml b/fluss-spark/fluss-spark-common/pom.xml
index e285b8bbc2..95b7c3bfbe 100644
--- a/fluss-spark/fluss-spark-common/pom.xml
+++ b/fluss-spark/fluss-spark-common/pom.xml
@@ -41,10 +41,32 @@
             <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
             <version>${spark.version}</version>
         </dependency>
+        <dependency>
+            <groupId>org.antlr</groupId>
+            <artifactId>antlr4-runtime</artifactId>
+            <version>4.9.3</version>
+        </dependency>
     </dependencies>

     <build>
         <plugins>
+            <plugin>
+                <groupId>org.antlr</groupId>
+                <artifactId>antlr4-maven-plugin</artifactId>
+                <version>4.9.3</version>
+                <executions>
+                    <execution>
+                        <goals>
+                            <goal>antlr4</goal>
+                        </goals>
+                    </execution>
+                </executions>
+                <configuration>
+                    <visitor>true</visitor>
+                </configuration>
+            </plugin>
             <plugin>
                 <groupId>org.apache.maven.plugins</groupId>
                 <artifactId>maven-shade-plugin</artifactId>
diff --git a/fluss-spark/fluss-spark-common/src/main/antlr4/org.apache.spark.sql.catalyst.parser.extensions/FlussSqlExtensions.g4 
b/fluss-spark/fluss-spark-common/src/main/antlr4/org.apache.spark.sql.catalyst.parser.extensions/FlussSqlExtensions.g4 new file mode 100644 index 0000000000..a693427caa --- /dev/null +++ b/fluss-spark/fluss-spark-common/src/main/antlr4/org.apache.spark.sql.catalyst.parser.extensions/FlussSqlExtensions.g4 @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +grammar FlussSqlExtensions; + +@lexer::members { + public boolean isValidDecimal() { + int nextChar = _input.LA(1); + if (nextChar >= 'A' && nextChar <= 'Z' || nextChar >= '0' && nextChar <= '9' || + nextChar == '_') { + return false; + } else { + return true; + } + } + + public boolean isHint() { + int nextChar = _input.LA(1); + if (nextChar == '+') { + return true; + } else { + return false; + } + } +} + +singleStatement + : statement ';'* EOF + ; + +statement + : CALL multipartIdentifier '(' (callArgument (',' callArgument)*)? ')' #call + ; + +callArgument + : expression #positionalArgument + | identifier '=>' expression #namedArgument + ; + +expression + : constant + ; + +constant + : number #numericLiteral + | booleanValue #booleanLiteral + | STRING+ #stringLiteral + | identifier STRING #typeConstructor + ; + +booleanValue + : TRUE | FALSE + ; + +number + : MINUS? EXPONENT_VALUE #exponentLiteral + | MINUS? DECIMAL_VALUE #decimalLiteral + | MINUS? INTEGER_VALUE #integerLiteral + | MINUS? BIGINT_LITERAL #bigIntLiteral + | MINUS? SMALLINT_LITERAL #smallIntLiteral + | MINUS? TINYINT_LITERAL #tinyIntLiteral + | MINUS? DOUBLE_LITERAL #doubleLiteral + | MINUS? FLOAT_LITERAL #floatLiteral + | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral + ; + +multipartIdentifier + : parts+=identifier ('.' parts+=identifier)* + ; + +identifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | nonReserved #unquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +nonReserved + : CALL | TRUE | FALSE + ; + +// Keywords +CALL: 'CALL'; +TRUE: 'TRUE'; +FALSE: 'FALSE'; + +// Operators +MINUS: '-'; + +// Literals +STRING + : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | '"' ( ~('"'|'\\') | ('\\' .) )* '"' + ; + +BIGINT_LITERAL + : INTEGER_VALUE 'L' + ; + +SMALLINT_LITERAL + : INTEGER_VALUE 'S' + ; + +TINYINT_LITERAL + : INTEGER_VALUE 'Y' + ; + +INTEGER_VALUE + : DIGIT+ + ; + +EXPONENT_VALUE + : DIGIT+ EXPONENT + | DECIMAL_DIGITS EXPONENT {isValidDecimal()}? + ; + +DECIMAL_VALUE + : DECIMAL_DIGITS {isValidDecimal()}? + ; + +FLOAT_LITERAL + : DIGIT+ EXPONENT? 'F' + | DECIMAL_DIGITS EXPONENT? 'F' {isValidDecimal()}? + ; + +DOUBLE_LITERAL + : DIGIT+ EXPONENT? 'D' + | DECIMAL_DIGITS EXPONENT? 'D' {isValidDecimal()}? + ; + +BIGDECIMAL_LITERAL + : DIGIT+ EXPONENT? 'BD' + | DECIMAL_DIGITS EXPONENT? 'BD' {isValidDecimal()}? 
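+    // The {isValidDecimal()}? predicate rejects a match when the next input character is a
+    // letter, digit, or underscore, i.e. when the literal would otherwise run into an identifier.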
+ ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment DECIMAL_DIGITS + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + +fragment EXPONENT + : 'E' [+-]? DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +// Whitespace and comments +SIMPLE_COMMENT + : '--' ('\\\n' | ~[\r\n])* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' {!isHint()}? (BRACKETED_COMMENT|.)*? '*/' -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for any characters we didn't match +UNRECOGNIZED + : . + ; diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkCatalog.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkCatalog.scala index 842ef9b395..4632a5038d 100644 --- a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkCatalog.scala +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkCatalog.scala @@ -19,7 +19,9 @@ package org.apache.fluss.spark import org.apache.fluss.exception.{DatabaseNotExistException, TableAlreadyExistException, TableNotExistException} import org.apache.fluss.metadata.TablePath -import org.apache.fluss.spark.catalog.{SupportsFlussNamespaces, WithFlussAdmin} +import org.apache.fluss.spark.analysis.NoSuchProcedureException +import org.apache.fluss.spark.catalog.{ProcedureCatalog, SupportsFlussNamespaces, WithFlussAdmin} +import org.apache.fluss.spark.procedure.{Procedure, ProcedureBuilder} import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog, TableChange} @@ -32,9 +34,14 @@ import java.util.concurrent.ExecutionException import scala.collection.JavaConverters._ -class SparkCatalog extends TableCatalog with SupportsFlussNamespaces with WithFlussAdmin { +class SparkCatalog + extends TableCatalog + with SupportsFlussNamespaces + with WithFlussAdmin + with ProcedureCatalog { private var catalogName: String = "fluss" + private val SYSTEM_NAMESPACE = "sys" override def listTables(namespace: Array[String]): Array[Identifier] = { doNamespaceOperator(namespace) { @@ -104,6 +111,20 @@ class SparkCatalog extends TableCatalog with SupportsFlussNamespaces with WithFl override def name(): String = catalogName + override def loadProcedure(identifier: Identifier): Procedure = { + if (isSystemNamespace(identifier.namespace)) { + val builder: ProcedureBuilder = SparkProcedures.newBuilder(identifier.name) + if (builder != null) { + return builder.withTableCatalog(this).build() + } + } + throw new NoSuchProcedureException(identifier) + } + + private def isSystemNamespace(namespace: Array[String]): Boolean = { + namespace.length == 1 && namespace(0).equalsIgnoreCase(SYSTEM_NAMESPACE) + } + private def toTablePath(ident: Identifier): TablePath = { assert(ident.namespace().length == 1, "Only single namespace is supported") TablePath.of(ident.namespace().head, ident.name) diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkProcedures.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkProcedures.scala new file mode 100644 index 0000000000..12d75b2806 --- /dev/null +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/SparkProcedures.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.fluss.spark + +import org.apache.fluss.spark.procedure.{CompactProcedure, ProcedureBuilder} + +import java.util.Locale + +object SparkProcedures { + + private val BUILDERS: Map[String, () => ProcedureBuilder] = initProcedureBuilders() + + def newBuilder(name: String): ProcedureBuilder = { + val builderSupplier = BUILDERS.get(name.toLowerCase(Locale.ROOT)) + builderSupplier.map(_()).orNull + } + + def names(): Set[String] = BUILDERS.keySet + + private def initProcedureBuilders(): Map[String, () => ProcedureBuilder] = { + Map( + "compact" -> (() => CompactProcedure.builder()) + ) + } +} diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/analysis/NoSuchProcedureException.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/analysis/NoSuchProcedureException.scala new file mode 100644 index 0000000000..984ebd8f1f --- /dev/null +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/analysis/NoSuchProcedureException.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.fluss.spark.analysis + +import org.apache.spark.sql.connector.catalog.Identifier + +class NoSuchProcedureException(message: String) extends Exception(message) { + + def this(identifier: Identifier) = { + this(s"Procedure not found: $identifier") + } +} diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/catalog/ProcedureCatalog.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/catalog/ProcedureCatalog.scala new file mode 100644 index 0000000000..471f978ca3 --- /dev/null +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/catalog/ProcedureCatalog.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.fluss.spark.catalog + +import org.apache.fluss.spark.analysis.NoSuchProcedureException +import org.apache.fluss.spark.procedure.Procedure + +import org.apache.spark.sql.connector.catalog.Identifier + +trait ProcedureCatalog { + + @throws[NoSuchProcedureException] + def loadProcedure(identifier: Identifier): Procedure +} diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/catalyst/analysis/FlussProcedureResolver.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/catalyst/analysis/FlussProcedureResolver.scala new file mode 100644 index 0000000000..fb0f197259 --- /dev/null +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/catalyst/analysis/FlussProcedureResolver.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.fluss.spark.catalyst.analysis + +import org.apache.fluss.spark.catalog.ProcedureCatalog +import org.apache.fluss.spark.catalyst.plans.logical._ +import org.apache.fluss.spark.procedure.ProcedureParameter + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, Literal} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier} + +import java.util.Locale + +/** Resolution rule for Fluss stored procedures. 
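+ *
+ * Looks up the procedure referenced by a [[FlussCallStatement]] through the target catalog's
+ * [[ProcedureCatalog]], normalizes parameter and argument names to lower case, binds named or
+ * positional arguments to the declared parameters (filling absent optional parameters with
+ * null literals), and inserts casts where an argument type can safely be up-cast to the
+ * parameter type. For example, `CALL fluss.sys.compact(table => 'db.t')` resolves `compact`
+ * in the `fluss` catalog and binds `'db.t'` to its required `table` parameter.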
*/ +case class FlussProcedureResolver(sparkSession: SparkSession) extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case FlussCallStatement(nameParts, arguments) if nameParts.nonEmpty => + val (catalog, identifier) = resolveCatalogAndIdentifier(nameParts) + if (catalog == null || !catalog.isInstanceOf[ProcedureCatalog]) { + throw new RuntimeException(s"Catalog ${nameParts.head} is not a ProcedureCatalog") + } + + val procedureCatalog = catalog.asInstanceOf[ProcedureCatalog] + val procedure = procedureCatalog.loadProcedure(identifier) + val parameters = procedure.parameters + val normalizedParameters = normalizeParameters(parameters) + validateParameters(normalizedParameters) + val normalizedArguments = normalizeArguments(arguments) + FlussCallCommand( + procedure, + args = buildArgumentExpressions(normalizedParameters, normalizedArguments)) + + case call @ FlussCallCommand(procedure, arguments) if call.resolved => + val parameters = procedure.parameters + val newArguments = arguments.zipWithIndex.map { + case (argument, index) => + val parameter = parameters(index) + val parameterType = parameter.dataType + val argumentType = argument.dataType + if (parameterType != argumentType && !Cast.canUpCast(argumentType, parameterType)) { + throw new RuntimeException( + s"Cannot cast $argumentType to $parameterType of ${parameter.name}.") + } + if (parameterType != argumentType) { + Cast(argument, parameterType) + } else { + argument + } + } + + if (newArguments != arguments) { + call.copy(args = newArguments) + } else { + call + } + } + + private def resolveCatalogAndIdentifier(nameParts: Seq[String]): (CatalogPlugin, Identifier) = { + val catalogManager = sparkSession.sessionState.catalogManager + if (nameParts.length == 2) { + val catalogName = nameParts.head + val procedureName = nameParts(1) + val catalog = catalogManager.catalog(catalogName) + (catalog, Identifier.of(Array("sys"), procedureName)) + } else if (nameParts.length == 3) { + val catalogName = nameParts.head + val namespace = nameParts(1) + val procedureName = nameParts(2) + val catalog = catalogManager.catalog(catalogName) + (catalog, Identifier.of(Array(namespace), procedureName)) + } else { + throw new RuntimeException(s"Invalid procedure name: ${nameParts.mkString(".")}") + } + } + + private def normalizeParameters(parameters: Seq[ProcedureParameter]): Seq[ProcedureParameter] = { + parameters.map { + parameter => + val normalizedName = parameter.name.toLowerCase(Locale.ROOT) + if (parameter.required) { + ProcedureParameter.required(normalizedName, parameter.dataType) + } else { + ProcedureParameter.optional(normalizedName, parameter.dataType) + } + } + } + + private def validateParameters(parameters: Seq[ProcedureParameter]): Unit = { + val duplicateParamNames = parameters.groupBy(_.name).collect { + case (name, matchingParams) if matchingParams.length > 1 => name + } + if (duplicateParamNames.nonEmpty) { + throw new RuntimeException( + s"Parameter names ${duplicateParamNames.mkString("[", ",", "]")} are duplicated.") + } + parameters.sliding(2).foreach { + case Seq(previousParam, currentParam) if !previousParam.required && currentParam.required => + throw new RuntimeException( + s"Optional parameters should be after required ones but $currentParam is after $previousParam.") + case _ => + } + } + + private def normalizeArguments(arguments: Seq[FlussCallArgument]): Seq[FlussCallArgument] = { + arguments.map { + case a @ FlussNamedArgument(name, _) => a.copy(name = 
name.toLowerCase(Locale.ROOT)) + case other => other + } + } + + private def buildArgumentExpressions( + parameters: Seq[ProcedureParameter], + arguments: Seq[FlussCallArgument]): Seq[Expression] = { + val nameToPositionMap = parameters.map(_.name).zipWithIndex.toMap + val nameToArgumentMap = buildNameToArgumentMap(parameters, arguments, nameToPositionMap) + val missingParamNames = parameters.filter(_.required).collect { + case parameter if !nameToArgumentMap.contains(parameter.name) => parameter.name + } + if (missingParamNames.nonEmpty) { + throw new RuntimeException( + s"Required parameters ${missingParamNames.mkString("[", ",", "]")} are missing.") + } + val argumentExpressions = new Array[Expression](parameters.size) + nameToArgumentMap.foreach { + case (name, argument) => argumentExpressions(nameToPositionMap(name)) = argument.expr + } + parameters.foreach { + case parameter if !parameter.required && !nameToArgumentMap.contains(parameter.name) => + argumentExpressions(nameToPositionMap(parameter.name)) = + Literal.create(null, parameter.dataType) + case _ => + } + argumentExpressions.toSeq + } + + private def buildNameToArgumentMap( + parameters: Seq[ProcedureParameter], + arguments: Seq[FlussCallArgument], + nameToPositionMap: Map[String, Int]): Map[String, FlussCallArgument] = { + val isNamedArgument = arguments.exists(_.isInstanceOf[FlussNamedArgument]) + val isPositionalArgument = arguments.exists(_.isInstanceOf[FlussPositionalArgument]) + + if (isNamedArgument && isPositionalArgument) { + throw new RuntimeException("Cannot mix named and positional arguments.") + } + + if (isNamedArgument) { + val namedArguments = arguments.asInstanceOf[Seq[FlussNamedArgument]] + val validationErrors = namedArguments.groupBy(_.name).collect { + case (name, procedureArguments) if procedureArguments.size > 1 => + s"Procedure argument $name is duplicated." + case (name, _) if !nameToPositionMap.contains(name) => s"Argument $name is unknown." + } + if (validationErrors.nonEmpty) { + throw new RuntimeException(s"Invalid arguments: ${validationErrors.mkString(", ")}") + } + namedArguments.map(arg => arg.name -> arg).toMap + } else { + if (arguments.size > parameters.size) { + throw new RuntimeException("Too many arguments for procedure") + } + arguments.zipWithIndex.map { + case (argument, position) => + val param = parameters(position) + param.name -> argument + }.toMap + } + } +} diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/catalyst/plans/logical/FlussCallCommand.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/catalyst/plans/logical/FlussCallCommand.scala new file mode 100644 index 0000000000..f0ea0265d6 --- /dev/null +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/catalyst/plans/logical/FlussCallCommand.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.fluss.spark.catalyst.plans.logical + +import org.apache.fluss.spark.procedure.Procedure + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} +import org.apache.spark.sql.catalyst.plans.logical.LeafCommand +import org.apache.spark.sql.catalyst.util.truncatedString + +/** A CALL statement that needs to be resolved to a procedure. */ +case class FlussCallStatement(name: Seq[String], args: Seq[FlussCallArgument]) extends LeafCommand { + override def output: Seq[Attribute] = Seq.empty +} + +/** Base trait for CALL statement arguments. */ +sealed trait FlussCallArgument { + def expr: Expression +} + +/** A positional argument in a stored procedure call. */ +case class FlussPositionalArgument(expr: Expression) extends FlussCallArgument + +/** A named argument in a stored procedure call. */ +case class FlussNamedArgument(name: String, expr: Expression) extends FlussCallArgument + +/** A CALL command that has been resolved to a specific procedure. */ +case class FlussCallCommand(procedure: Procedure, args: Seq[Expression]) extends LeafCommand { + + override lazy val output: Seq[Attribute] = + procedure.outputType.map( + field => AttributeReference(field.name, field.dataType, field.nullable, field.metadata)()) + + override def simpleString(maxFields: Int): String = { + s"Call${truncatedString(output, "[", ", ", "]", maxFields)} ${procedure.description}" + } +} diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/execution/CallProcedureExec.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/execution/CallProcedureExec.scala new file mode 100644 index 0000000000..1ed8949a21 --- /dev/null +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/execution/CallProcedureExec.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.fluss.spark.execution + +import org.apache.fluss.spark.procedure.Procedure + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.SparkPlan + +/** Physical plan node for executing a stored procedure. 
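+ *
+ * Argument expressions arrive from the resolver already cast to the declared parameter types;
+ * they are evaluated eagerly on the driver, the procedure is invoked once with the resulting
+ * row, and its result rows are returned as a locally parallelized RDD.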
*/ +case class CallProcedureExec(output: Seq[Attribute], procedure: Procedure, args: Seq[Expression]) + extends SparkPlan { + + override def children: Seq[SparkPlan] = Seq.empty + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan = { + copy() + } + + override protected def doExecute(): org.apache.spark.rdd.RDD[InternalRow] = { + val argumentValues = new Array[Any](args.length) + args.zipWithIndex.foreach { + case (arg, index) => + argumentValues(index) = arg.eval(null) + } + + val argRow = new GenericInternalRow(argumentValues) + val resultRows = procedure.call(argRow) + + sparkContext.parallelize(resultRows) + } +} diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/execution/FlussStrategy.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/execution/FlussStrategy.scala new file mode 100644 index 0000000000..e4d3d4117a --- /dev/null +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/execution/FlussStrategy.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.fluss.spark.execution + +import org.apache.fluss.spark.catalyst.plans.logical.FlussCallCommand + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} + +/** Execution strategy for Fluss procedure calls. */ +case class FlussStrategy(spark: SparkSession) extends SparkStrategy { + + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case call: FlussCallCommand => + CallProcedureExec(call.output, call.procedure, call.args) :: Nil + case _ => Nil + } +} diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/extensions/FlussSparkSessionExtensions.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/extensions/FlussSparkSessionExtensions.scala new file mode 100644 index 0000000000..a041e92943 --- /dev/null +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/extensions/FlussSparkSessionExtensions.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.fluss.spark.extensions + +import org.apache.fluss.spark.catalyst.analysis.FlussProcedureResolver +import org.apache.fluss.spark.execution.FlussStrategy + +import org.apache.spark.sql.SparkSessionExtensions +import org.apache.spark.sql.catalyst.parser.extensions.FlussSparkSqlParser + +/** Spark session extensions for Fluss. */ +class FlussSparkSessionExtensions extends (SparkSessionExtensions => Unit) { + + override def apply(extensions: SparkSessionExtensions): Unit = { + // parser extensions + extensions.injectParser { case (_, parser) => new FlussSparkSqlParser(parser) } + + // analyzer extensions + extensions.injectResolutionRule(spark => FlussProcedureResolver(spark)) + + // planner extensions + extensions.injectPlannerStrategy(spark => FlussStrategy(spark)) + } +} diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/procedure/BaseProcedure.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/procedure/BaseProcedure.scala new file mode 100644 index 0000000000..d9f9efc9f0 --- /dev/null +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/procedure/BaseProcedure.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.fluss.spark.procedure + +import org.apache.fluss.client.admin.Admin +import org.apache.fluss.metadata.TablePath +import org.apache.fluss.spark.SparkTable +import org.apache.fluss.spark.catalog.AbstractSparkTable + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog} + +abstract class BaseProcedure(tableCatalog: TableCatalog) extends Procedure { + + protected def toIdentifier(identifierAsString: String, argName: String): Identifier = { + if (identifierAsString == null || identifierAsString.isEmpty) { + throw new IllegalArgumentException(s"Cannot handle an empty identifier for argument $argName") + } + + val spark = SparkSession.active + val multipartIdentifier = identifierAsString.split("\\.") + + if (multipartIdentifier.length == 1) { + val defaultNamespace = spark.sessionState.catalogManager.currentNamespace + Identifier.of(defaultNamespace, multipartIdentifier(0)) + } else if (multipartIdentifier.length == 2) { + Identifier.of(Array(multipartIdentifier(0)), multipartIdentifier(1)) + } else { + throw new IllegalArgumentException( + s"Invalid identifier format for argument $argName: $identifierAsString") + } + } + + protected def loadSparkTable(ident: Identifier): SparkTable = { + try { + val table = tableCatalog.loadTable(ident) + table match { + case sparkTable: SparkTable => sparkTable + case _ => + throw new IllegalArgumentException( + s"$ident is not a Fluss table: ${table.getClass.getName}") + } + } catch { + case e: Exception => + val errMsg = s"Couldn't load table '$ident' in catalog '${tableCatalog.name()}'" + throw new RuntimeException(errMsg, e) + } + } + + protected def getAdmin(table: SparkTable): Admin = { + table match { + case abstractTable: AbstractSparkTable => abstractTable.admin + case _ => + throw new IllegalArgumentException( + s"Table is not an AbstractSparkTable: ${table.getClass.getName}") + } + } + + protected def newInternalRow(values: Any*): InternalRow = { + new GenericInternalRow(values.toArray) + } + + protected def toTablePath(ident: Identifier): TablePath = { + if (ident.namespace().length != 1) { + throw new IllegalArgumentException("Only single namespace is supported") + } + TablePath.of(ident.namespace()(0), ident.name()) + } +} + +object BaseProcedure { + + abstract class Builder[T <: BaseProcedure] extends ProcedureBuilder { + private var tableCatalog: TableCatalog = _ + + override def withTableCatalog(newTableCatalog: TableCatalog): Builder[T] = { + this.tableCatalog = newTableCatalog + this + } + + override def build(): T = doBuild() + + protected def doBuild(): T + + protected def getTableCatalog: TableCatalog = tableCatalog + } +} diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/procedure/CompactProcedure.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/procedure/CompactProcedure.scala new file mode 100644 index 0000000000..d70653ab42 --- /dev/null +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/procedure/CompactProcedure.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.fluss.spark.procedure
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.catalog.TableCatalog
+import org.apache.spark.sql.types.{DataTypes, Metadata, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+class CompactProcedure(tableCatalog: TableCatalog) extends BaseProcedure(tableCatalog) {
+
+  override def parameters(): Array[ProcedureParameter] = {
+    CompactProcedure.PARAMETERS
+  }
+
+  override def outputType(): StructType = {
+    CompactProcedure.OUTPUT_TYPE
+  }
+
+  override def call(args: InternalRow): Array[InternalRow] = {
+    val tableIdent = toIdentifier(args.getString(0), CompactProcedure.PARAMETERS(0).name())
+    val sparkTable = loadSparkTable(tableIdent)
+
+    try {
+      val tablePath = toTablePath(tableIdent)
+      val admin = getAdmin(sparkTable)
+
+      // TODO: trigger the compaction through `admin` once the admin API exposes it; for now
+      // this only validates that the table can be loaded and reports a status message.
+      val message = s"Compact operation queued for table $tablePath"
+
+      Array(newInternalRow(UTF8String.fromString(message)))
+    } catch {
+      case e: Exception =>
+        throw new RuntimeException(s"Failed to compact table: ${e.getMessage}", e)
+    }
+  }
+
+  override def description(): String = {
+    "This procedure triggers a compact action on a Fluss table."
+  }
+}
+
+object CompactProcedure {
+
+  private val PARAMETERS: Array[ProcedureParameter] = Array(
+    ProcedureParameter.required("table", DataTypes.StringType)
+  )
+
+  private val OUTPUT_TYPE: StructType = new StructType(
+    Array(
+      new StructField("result", DataTypes.StringType, nullable = true, Metadata.empty)
+    )
+  )
+
+  def builder(): ProcedureBuilder = {
+    new BaseProcedure.Builder[CompactProcedure]() {
+      override protected def doBuild(): CompactProcedure = {
+        new CompactProcedure(getTableCatalog)
+      }
+    }
+  }
+}
diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/procedure/Procedure.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/procedure/Procedure.scala
new file mode 100644
index 0000000000..d2f40fe658
--- /dev/null
+++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/procedure/Procedure.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.fluss.spark.procedure + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.StructType + +trait Procedure { + + def parameters(): Array[ProcedureParameter] + + def outputType(): StructType + + def call(args: InternalRow): Array[InternalRow] + + def description(): String = getClass.toString +} diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/procedure/ProcedureBuilder.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/procedure/ProcedureBuilder.scala new file mode 100644 index 0000000000..cfa4ad823d --- /dev/null +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/procedure/ProcedureBuilder.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.fluss.spark.procedure + +import org.apache.spark.sql.connector.catalog.TableCatalog + +trait ProcedureBuilder { + + def withTableCatalog(tableCatalog: TableCatalog): ProcedureBuilder + + def build(): Procedure +} diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/procedure/ProcedureParameter.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/procedure/ProcedureParameter.scala new file mode 100644 index 0000000000..5125ecbed7 --- /dev/null +++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/fluss/spark/procedure/ProcedureParameter.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.fluss.spark.procedure
+
+import org.apache.spark.sql.types.DataType
+
+trait ProcedureParameter {
+
+  def name(): String
+
+  def dataType(): DataType
+
+  def required(): Boolean
+}
+
+object ProcedureParameter {
+
+  def required(name: String, dataType: DataType): ProcedureParameter =
+    ProcedureParameterImpl(name, dataType, isRequired = true)
+
+  def optional(name: String, dataType: DataType): ProcedureParameter =
+    ProcedureParameterImpl(name, dataType, isRequired = false)
+}
+
+private case class ProcedureParameterImpl(
+    paramName: String,
+    paramDataType: DataType,
+    isRequired: Boolean)
+  extends ProcedureParameter {
+
+  override def name(): String = paramName
+
+  override def dataType(): DataType = paramDataType
+
+  override def required(): Boolean = isRequired
+}
diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/FlussSqlExtensionsAstBuilder.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/FlussSqlExtensionsAstBuilder.scala
new file mode 100644
index 0000000000..f1327a5e49
--- /dev/null
+++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/FlussSqlExtensionsAstBuilder.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.parser.extensions
+
+import org.apache.fluss.spark.catalyst.plans.logical._
+
+import org.antlr.v4.runtime._
+import org.antlr.v4.runtime.tree.{ParseTree, TerminalNode}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser.ParserInterface
+import org.apache.spark.sql.catalyst.parser.extensions.FlussSqlExtensionsParser._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
+
+import scala.collection.JavaConverters._
+
+/**
+ * The AST builder for Fluss SQL extensions.
+ *
+ * @param delegate
+ *   The underlying Spark SQL parser, used to parse argument expressions and as the fall-back
+ *   for statements that are not Fluss extensions.
+ */
+class FlussSqlExtensionsAstBuilder(delegate: ParserInterface)
+  extends FlussSqlExtensionsBaseVisitor[AnyRef]
+  with Logging {
+
+  /** Visits the root rule and returns the parsed extension statement. */
+  override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) {
+    visit(ctx.statement).asInstanceOf[LogicalPlan]
+  }
+
+  /** Creates a [[FlussCallStatement]] for a stored procedure call.
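+   * The multipart name (for example `catalog.sys.compact`) is collected from the identifier
+   * parts, and each argument is visited as either a positional or a named argument; the
+   * argument expression text is re-parsed by the delegate Spark parser.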
+   */
+  override def visitCall(ctx: CallContext): FlussCallStatement = withOrigin(ctx) {
+    val name =
+      ctx.multipartIdentifier.parts.asScala.map(part => cleanIdentifier(part.getText)).toSeq
+    val args = ctx.callArgument.asScala.map(typedVisit[FlussCallArgument]).toSeq
+    FlussCallStatement(name, args)
+  }
+
+  /** Creates a positional argument in a stored procedure call. */
+  override def visitPositionalArgument(ctx: PositionalArgumentContext): FlussCallArgument =
+    withOrigin(ctx) {
+      val expression = typedVisit[Expression](ctx.expression)
+      FlussPositionalArgument(expression)
+    }
+
+  /** Creates a named argument in a stored procedure call. */
+  override def visitNamedArgument(ctx: NamedArgumentContext): FlussCallArgument = withOrigin(ctx) {
+    val name = cleanIdentifier(ctx.identifier.getText)
+    val expression = typedVisit[Expression](ctx.expression)
+    FlussNamedArgument(name, expression)
+  }
+
+  /** Creates an [[Expression]] from a positional or named argument. */
+  override def visitExpression(ctx: ExpressionContext): Expression = {
+    val sqlString = reconstructSqlString(ctx)
+    delegate.parseExpression(sqlString)
+  }
+
+  /** Returns a multi-part identifier as Seq[String]. */
+  override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] =
+    withOrigin(ctx) {
+      ctx.parts.asScala.map(part => cleanIdentifier(part.getText)).toSeq
+    }
+
+  /** Removes backticks from an identifier. */
+  private def cleanIdentifier(ident: String): String = {
+    if (ident.startsWith("`") && ident.endsWith("`")) {
+      ident.substring(1, ident.length - 1)
+    } else {
+      ident
+    }
+  }
+
+  private def reconstructSqlString(ctx: ParserRuleContext): String = {
+    ctx.children.asScala
+      .map {
+        case c: ParserRuleContext => reconstructSqlString(c)
+        case t: TerminalNode => t.getText
+      }
+      .mkString(" ")
+  }
+
+  private def typedVisit[T](ctx: ParseTree): T =
+    ctx.accept(this).asInstanceOf[T]
+
+  // Uses Spark's own CurrentOrigin/Origin (imported above) instead of local copies, so the
+  // parse positions actually propagate to the resulting plan nodes.
+  private def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = {
+    val current = CurrentOrigin.get
+    CurrentOrigin.set(position(ctx.getStart))
+    try {
+      f
+    } finally {
+      CurrentOrigin.set(current)
+    }
+  }
+
+  private def position(token: Token): Origin = {
+    val opt = Option(token)
+    Origin(opt.map(_.getLine), opt.map(_.getCharPositionInLine))
+  }
+}
diff --git a/fluss-spark/fluss-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/FlussSqlExtensionsParser.scala b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/FlussSqlExtensionsParser.scala
new file mode 100644
index 0000000000..57d8d0f346
--- /dev/null
+++ b/fluss-spark/fluss-spark-common/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/FlussSqlExtensionsParser.scala
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.parser.extensions + +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.{Interval, ParseCancellationException} +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types.{DataType, StructType} + +/** + * Parser extension for Fluss SQL extensions. + * + * @param delegate + * The main Spark SQL parser. + */ +class FlussSparkSqlParser(delegate: ParserInterface) extends ParserInterface { + + private lazy val astBuilder = new FlussSqlExtensionsAstBuilder(delegate) + + override def parsePlan(sqlText: String): LogicalPlan = { + try { + parse(sqlText)(parser => astBuilder.visitSingleStatement(parser.singleStatement())) + } catch { + case _: ParseException | _: ParseCancellationException => + delegate.parsePlan(sqlText) + } + } + + override def parseQuery(sqlText: String): LogicalPlan = parsePlan(sqlText) + + override def parseExpression(sqlText: String): Expression = { + delegate.parseExpression(sqlText) + } + + override def parseTableIdentifier(sqlText: String): TableIdentifier = { + delegate.parseTableIdentifier(sqlText) + } + + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = { + delegate.parseFunctionIdentifier(sqlText) + } + + override def parseMultipartIdentifier(sqlText: String): Seq[String] = { + delegate.parseMultipartIdentifier(sqlText) + } + + override def parseTableSchema(sqlText: String): StructType = { + delegate.parseTableSchema(sqlText) + } + + override def parseDataType(sqlText: String): DataType = { + delegate.parseDataType(sqlText) + } + + private def parse[T](sqlText: String)( + toResult: org.apache.spark.sql.catalyst.parser.extensions.FlussSqlExtensionsParser => T) + : T = { + val lexer = new FlussSqlExtensionsLexer( + new UpperCaseCharStream(CharStreams.fromString(sqlText))) + lexer.removeErrorListeners() + lexer.addErrorListener(FlussParseErrorListener) + + val tokenStream = new CommonTokenStream(lexer) + val parser = + new org.apache.spark.sql.catalyst.parser.extensions.FlussSqlExtensionsParser(tokenStream) + parser.removeErrorListeners() + parser.addErrorListener(FlussParseErrorListener) + + try { + try { + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + toResult(parser) + } catch { + case _: ParseCancellationException => + tokenStream.seek(0) + parser.reset() + parser.getInterpreter.setPredictionMode(PredictionMode.LL) + toResult(parser) + } + } catch { + case e: ParseException if e.command.isDefined => + throw e + case e: ParseException => + 
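+          // Re-throw with the original SQL text attached so the error reports the statement.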
throw e.withCommand(sqlText) + case e: AnalysisException => + val position = org.apache.spark.sql.catalyst.trees.Origin(e.line, e.startPosition) + throw new ParseException(Option(sqlText), e.message, position, position) + } + } +} + +class UpperCaseCharStream(wrapped: CodePointCharStream) extends CharStream { + override def consume(): Unit = wrapped.consume() + override def getSourceName: String = wrapped.getSourceName + override def index(): Int = wrapped.index() + override def mark(): Int = wrapped.mark() + override def release(marker: Int): Unit = wrapped.release(marker) + override def seek(where: Int): Unit = wrapped.seek(where) + override def size(): Int = wrapped.size() + + override def getText(interval: Interval): String = { + wrapped.getText(interval) + } + + override def LA(i: Int): Int = { + val la = wrapped.LA(i) + if (la == 0 || la == IntStream.EOF) la + else Character.toUpperCase(la) + } +} + +object FlussParseErrorListener extends BaseErrorListener { + override def syntaxError( + recognizer: Recognizer[_, _], + offendingSymbol: scala.Any, + line: Int, + charPositionInLine: Int, + msg: String, + e: RecognitionException): Unit = { + val position = org.apache.spark.sql.catalyst.trees.Origin(Some(line), Some(charPositionInLine)) + throw new ParseException(None, msg, position, position) + } +} diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/FlussSparkTestBase.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/FlussSparkTestBase.scala index 2de158b7db..23b62f7824 100644 --- a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/FlussSparkTestBase.scala +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/FlussSparkTestBase.scala @@ -25,6 +25,7 @@ import org.apache.fluss.config.{ConfigOptions, Configuration} import org.apache.fluss.metadata.{TableDescriptor, TablePath} import org.apache.fluss.row.InternalRow import org.apache.fluss.server.testutils.FlussClusterExtension +import org.apache.fluss.spark.extensions.FlussSparkSessionExtensions import org.apache.spark.SparkConf import org.apache.spark.sql.QueryTest @@ -50,6 +51,7 @@ class FlussSparkTestBase extends QueryTest with SharedSparkSession { .set(s"spark.sql.catalog.$DEFAULT_CATALOG", classOf[SparkCatalog].getName) .set(s"spark.sql.catalog.$DEFAULT_CATALOG.bootstrap.servers", bootstrapServers) .set("spark.sql.defaultCatalog", DEFAULT_CATALOG) + .set("spark.sql.extensions", classOf[FlussSparkSessionExtensions].getName) } override protected def beforeAll(): Unit = { diff --git a/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/extensions/CallStatementParserTest.scala b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/extensions/CallStatementParserTest.scala new file mode 100644 index 0000000000..9503e2b7bb --- /dev/null +++ b/fluss-spark/fluss-spark-ut/src/test/scala/org/apache/fluss/spark/extensions/CallStatementParserTest.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.fluss.spark.extensions + +import org.apache.fluss.spark.catalyst.plans.logical.{FlussCallArgument, FlussCallStatement, FlussNamedArgument, FlussPositionalArgument} + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.types.DataTypes +import org.scalatest.{BeforeAndAfterEach, FunSuite} + +import java.math.BigDecimal +import java.sql.Timestamp +import java.time.Instant + +class CallStatementParserTest extends FunSuite with BeforeAndAfterEach { + + private var spark: SparkSession = _ + private var parser: ParserInterface = _ + + override def beforeEach(): Unit = { + super.beforeEach() + val optionalSession = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) + optionalSession.foreach(_.stop()) + SparkSession.clearActiveSession() + + spark = SparkSession + .builder() + .master("local[2]") + .config("spark.sql.extensions", classOf[FlussSparkSessionExtensions].getName) + .getOrCreate() + + parser = spark.sessionState.sqlParser + } + + override def afterEach(): Unit = { + if (spark != null) { + spark.stop() + spark = null + parser = null + } + super.afterEach() + } + + test("testCallWithBackticks") { + val call = + parser.parsePlan("CALL cat.`system`.`no_args_func`()").asInstanceOf[FlussCallStatement] + assert(call.name.toList == List("cat", "system", "no_args_func")) + assert(call.args.size == 0) + } + + test("testCallWithNamedArguments") { + val callStatement = parser + .parsePlan("CALL catalog.system.named_args_func(arg1 => 1, arg2 => 'test', arg3 => true)") + .asInstanceOf[FlussCallStatement] + + assert(callStatement.name.toList == List("catalog", "system", "named_args_func")) + assert(callStatement.args.size == 3) + assertArgument(callStatement, 0, Some("arg1"), 1, DataTypes.IntegerType) + assertArgument(callStatement, 1, Some("arg2"), "test", DataTypes.StringType) + assertArgument(callStatement, 2, Some("arg3"), true, DataTypes.BooleanType) + } + + test("testCallWithPositionalArguments") { + val callStatement = parser + .parsePlan( + "CALL catalog.system.positional_args_func(1, '${spark.sql.extensions}', 2L, true, 3.0D, 4.0e1, 500e-1BD, TIMESTAMP '2017-02-03T10:37:30.00Z')") + .asInstanceOf[FlussCallStatement] + + assert(callStatement.name.toList == List("catalog", "system", "positional_args_func")) + assert(callStatement.args.size == 8) + assertArgument(callStatement, 0, None, 1, DataTypes.IntegerType) + assertArgument( + callStatement, + 1, + None, + classOf[FlussSparkSessionExtensions].getName, + DataTypes.StringType) + assertArgument(callStatement, 2, None, 2L, DataTypes.LongType) + assertArgument(callStatement, 3, None, true, DataTypes.BooleanType) + assertArgument(callStatement, 4, None, 3.0, DataTypes.DoubleType) + assertArgument(callStatement, 5, None, 4.0e1, DataTypes.DoubleType) + assertArgument( + callStatement, + 6, + None, + new BigDecimal("500e-1"), + DataTypes.createDecimalType(3, 1)) + assertArgument( + callStatement, + 7, + None, + 
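+      // The TIMESTAMP literal is parsed by the delegate Spark parser; its trailing 'Z' offset
+      // pins the value to UTC, so it folds to exactly this instant.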
Timestamp.from(Instant.parse("2017-02-03T10:37:30.00Z")), + DataTypes.TimestampType) + } + + test("testCallWithMixedArguments") { + val callStatement = parser + .parsePlan("CALL catalog.system.mixed_args_func(arg1 => 1, 'test')") + .asInstanceOf[FlussCallStatement] + + assert(callStatement.name.toList == List("catalog", "system", "mixed_args_func")) + assert(callStatement.args.size == 2) + assertArgument(callStatement, 0, Some("arg1"), 1, DataTypes.IntegerType) + assertArgument(callStatement, 1, None, "test", DataTypes.StringType) + } + + test("testCallSimpleProcedure") { + val callStatement = parser + .parsePlan("CALL system.simple_procedure(table => 'db.table')") + .asInstanceOf[FlussCallStatement] + + assert(callStatement.name.toList == List("system", "simple_procedure")) + assert(callStatement.args.size == 1) + assertArgument(callStatement, 0, Some("table"), "db.table", DataTypes.StringType) + } + + private def assertArgument( + callStatement: FlussCallStatement, + index: Int, + expectedName: Option[String], + expectedValue: Any, + expectedType: org.apache.spark.sql.types.DataType): Unit = { + + val callArgument = callStatement.args(index) + + expectedName match { + case None => + assert(callArgument.isInstanceOf[FlussPositionalArgument]) + case Some(name) => + val namedArgument = callArgument.asInstanceOf[FlussNamedArgument] + assert(namedArgument.name == name) + } + + assert(callStatement.args(index).expr == Literal.create(expectedValue, expectedType)) + } +}