diff --git a/docs/docs/spark-writes.md b/docs/docs/spark-writes.md index 959993b20fc8..9b23eed64a12 100644 --- a/docs/docs/spark-writes.md +++ b/docs/docs/spark-writes.md @@ -29,6 +29,7 @@ Iceberg uses Apache Spark's DataSourceV2 API for data source and catalog impleme | Feature support | Spark | Notes | |--------------------------------------------------|---------|-----------------------------------------------------------------------------| | [SQL insert into](#insert-into) | ✔️ | ⚠ Requires `spark.sql.storeAssignmentPolicy=ANSI` (default since Spark 3.0) | +| [SQL scoped replace](#insert-into--replace-using) | ✔️ | ⚠ Requires Iceberg Spark extensions and Spark 4.1 or higher | | [SQL merge into](#merge-into) | ✔️ | ⚠ Requires Iceberg Spark extensions | | [SQL insert overwrite](#insert-overwrite) | ✔️ | ⚠ Requires `spark.sql.storeAssignmentPolicy=ANSI` (default since Spark 3.0) | | [SQL delete from](#delete-from) | ✔️ | ⚠ Row-level delete requires Iceberg Spark extensions | @@ -40,7 +41,7 @@ Iceberg uses Apache Spark's DataSourceV2 API for data source and catalog impleme ## Writing with SQL -Spark supports SQL `INSERT INTO`, `MERGE INTO`, and `INSERT OVERWRITE`, as well as the new `DataFrameWriterV2` API. +Spark supports SQL `INSERT INTO`, `INSERT INTO ... REPLACE USING`, `MERGE INTO`, and `INSERT OVERWRITE`, as well as the new `DataFrameWriterV2` API. ### `INSERT INTO` @@ -53,6 +54,31 @@ INSERT INTO prod.db.table VALUES (1, 'a'), (2, 'b') INSERT INTO prod.db.table SELECT ... ``` +### `INSERT INTO ... REPLACE USING` + +Iceberg supports scoped replacement writes with Iceberg Spark extensions in Spark 4.1 and higher. A scoped replace deletes existing target rows whose replacement scope appears in the source query, then inserts all rows from the source query in the same commit. + +```sql +INSERT INTO prod.db.table +REPLACE USING (scope_col_1, scope_col_2) +SELECT ... +``` + +The columns listed in `REPLACE USING` define the replacement scope. For each distinct tuple of scope values produced by the source query, matching target rows are removed using null-safe equality, and the full source query output is appended. + +For example, this query replaces all existing rows for categories present in `prod.db.staged_rows` and keeps rows for other categories: + +```sql +INSERT INTO prod.db.sample +REPLACE USING (category) +SELECT id, data, category, ts +FROM prod.db.staged_rows +``` + +`REPLACE USING` is useful when incoming data contains complete replacement slices for one or more logical groups, such as tenants, departments, dates, or regions. Unlike `INSERT OVERWRITE`, the replacement scope is based on table columns in the source data and does not depend on the table's partition spec. + +The source query must produce the same number of columns as the target table, using the same assignment rules as `INSERT INTO`. Each `REPLACE USING` column must exist in both the target table and the source query output. + ### `MERGE INTO` Spark supports `MERGE INTO` queries that can express row-level updates. diff --git a/spark/v4.1/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 b/spark/v4.1/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 index 4c2a16d7b19a..ce20930912d0 100644 --- a/spark/v4.1/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 +++ b/spark/v4.1/spark-extensions/src/main/antlr/org.apache.spark.sql.catalyst.parser.extensions/IcebergSqlExtensions.g4 @@ -131,6 +131,13 @@ singleOrder : order EOF ; +// Parses only the command head of `INSERT INTO t REPLACE USING (cols) `. +// The query tail remains Spark SQL and is delegated to Spark's parser until this syntax +// can be represented directly in Spark's `insertInto` grammar. +singleScopedReplaceHead + : INSERT INTO TABLE? multipartIdentifier REPLACE USING '(' fieldList ')' EOF + ; + order : fields+=orderField (',' fields+=orderField)* | '(' fields+=orderField (',' fields+=orderField)* ')' @@ -211,6 +218,7 @@ nonReserved | DISTRIBUTED | LOCALLY | MINUTES | MONTHS | UNORDERED | REPLACE | RETAIN | VERSION | WITH | IDENTIFIER_KW | FIELDS | SET | SNAPSHOT | SNAPSHOTS | TAG | TRUE | FALSE | MAP + | INSERT | INTO | USING ; snapshotId @@ -243,6 +251,8 @@ FIELDS: 'FIELDS'; FIRST: 'FIRST'; HOURS: 'HOURS'; IF : 'IF'; +INSERT: 'INSERT'; +INTO: 'INTO'; LAST: 'LAST'; LOCALLY: 'LOCALLY'; MINUTES: 'MINUTES'; @@ -264,6 +274,7 @@ SNAPSHOTS: 'SNAPSHOTS'; TABLE: 'TABLE'; TAG: 'TAG'; UNORDERED: 'UNORDERED'; +USING: 'USING'; VERSION: 'VERSION'; WITH: 'WITH'; WRITE: 'WRITE'; diff --git a/spark/v4.1/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala b/spark/v4.1/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala index 81824e05e92d..d8e72f0841b5 100644 --- a/spark/v4.1/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala +++ b/spark/v4.1/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.SparkSessionExtensions import org.apache.spark.sql.catalyst.analysis.CheckViews import org.apache.spark.sql.catalyst.analysis.ResolveBranch import org.apache.spark.sql.catalyst.analysis.ResolveViews +import org.apache.spark.sql.catalyst.analysis.RewriteScopedReplace import org.apache.spark.sql.catalyst.optimizer.ReplaceStaticInvoke import org.apache.spark.sql.catalyst.parser.extensions.IcebergSparkSqlExtensionsParser import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Strategy @@ -35,6 +36,7 @@ class IcebergSparkSessionExtensions extends (SparkSessionExtensions => Unit) { // analyzer extensions extensions.injectResolutionRule { spark => ResolveViews(spark) } extensions.injectPostHocResolutionRule { spark => ResolveBranch(spark) } + extensions.injectPostHocResolutionRule { _ => RewriteScopedReplace } extensions.injectCheckRule(_ => CheckViews) // optimizer extensions diff --git a/spark/v4.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteScopedReplace.scala b/spark/v4.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteScopedReplace.scala new file mode 100644 index 000000000000..735e83eea505 --- /dev/null +++ b/spark/v4.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteScopedReplace.scala @@ -0,0 +1,399 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.ProjectingInternalRow +import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.And +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.expressions.EqualNullSafe +import org.apache.spark.sql.catalyst.expressions.Exists +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.expressions.OuterReference +import org.apache.spark.sql.catalyst.plans.LeftAnti +import org.apache.spark.sql.catalyst.plans.LeftSemi +import org.apache.spark.sql.catalyst.plans.logical.CTERelationDef +import org.apache.spark.sql.catalyst.plans.logical.CTERelationRef +import org.apache.spark.sql.catalyst.plans.logical.Filter +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.catalyst.plans.logical.JoinHint +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.ReplaceData +import org.apache.spark.sql.catalyst.plans.logical.ReplaceScopedData +import org.apache.spark.sql.catalyst.plans.logical.Union +import org.apache.spark.sql.catalyst.plans.logical.WithCTE +import org.apache.spark.sql.catalyst.plans.logical.WriteDelta +import org.apache.spark.sql.catalyst.util.ReplaceDataProjections +import org.apache.spark.sql.catalyst.util.RowDeltaUtils.DELETE_OPERATION +import org.apache.spark.sql.catalyst.util.RowDeltaUtils.INSERT_OPERATION +import org.apache.spark.sql.catalyst.util.RowDeltaUtils.OPERATION_COLUMN +import org.apache.spark.sql.catalyst.util.RowDeltaUtils.WRITE_OPERATION +import org.apache.spark.sql.catalyst.util.RowDeltaUtils.WRITE_WITH_METADATA_OPERATION +import org.apache.spark.sql.catalyst.util.WriteDeltaProjections +import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations +import org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE +import org.apache.spark.sql.connector.write.RowLevelOperationTable +import org.apache.spark.sql.connector.write.SupportsDelta +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +/** + * Lowers a [[ReplaceScopedData]] command into Iceberg's row-level write path. + * + * Scoped replace deletes target rows whose scope columns match a source row, then appends the full + * source. Unlike MERGE, every source row must be inserted whether or not it matches a target row, so + * the replacement state is computed from separate carryover and insert branches instead of from + * per-match joined rows. + * + * The source is shared through a CTE so non-deterministic expressions are evaluated consistently by + * the carryover branch and insert branch. Runtime file pruning is only applied for deterministic + * sources because Spark requires row-level operation conditions to be deterministic. + * + * The row-level operation is requested as [[MERGE]] so copy-on-write vs merge-on-read selection + * follows the same table configuration path as MERGE. + */ +object RewriteScopedReplace extends RewriteRowLevelCommand { + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case rsd @ ReplaceScopedData(aliasedTable, scopeColumns, source) + if rsd.resolved && source.resolved => + + EliminateSubqueryAliases(aliasedTable) match { + case r @ DataSourceV2Relation(tbl: SupportsRowLevelOperations, _, _, _, _, _) => + if (source.output.size != r.output.size) { + throw analysisError( + "The source query of a scoped replace must produce the same number of columns as " + + s"the target table ${r.name}: expected ${r.output.size}, got ${source.output.size}") + } + + val operationTable = buildOperationTable(tbl, MERGE, CaseInsensitiveStringMap.empty()) + val alignedSource = alignSourceColumns(r, source) + operationTable.operation match { + case deltaOperation: SupportsDelta => + buildWriteDeltaPlan(r, operationTable, deltaOperation, scopeColumns, alignedSource) + case _ => + buildReplaceDataPlan(r, operationTable, scopeColumns, alignedSource) + } + + case other => + throw analysisError( + s"Scoped replace is only supported on Iceberg tables, found: ${other.simpleString(2)}") + } + } + + private def buildReplaceDataPlan( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + scopeColumns: Seq[Seq[String]], + source: LogicalPlan): LogicalPlan = { + + // Scoped replace uses INSERT-style positional alignment between source and target columns. + val scopeOrdinals = resolveScopeOrdinals(relation, scopeColumns) + + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operationTable.operation) + val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs) + val rowAttrs = relation.output + + val sourceCte = CTERelationDef(source) + val scopeSource = newSourceRef(sourceCte) + val insertSource = newSourceRef(sourceCte) + val filterSource = newSourceRef(sourceCte) + + val antiJoinCondition = scopeEquality(readRelation.output, scopeSource.output, scopeOrdinals) + val carryoverJoin = + Join(readRelation, scopeSource, LeftAnti, Some(antiJoinCondition), JoinHint.NONE) + val carryoverOutput = + operationAlias(WRITE_WITH_METADATA_OPERATION) +: readRelation.output + val carryover = Project(carryoverOutput, carryoverJoin) + + val insertData = alignedSourceData(insertSource, rowAttrs) + val insertMetadata = metadataAttrs.map { attr => + Alias(Literal(null, attr.dataType), attr.name)() + } + val insertOutput = operationAlias(WRITE_OPERATION) +: (insertData ++ insertMetadata) + val inserts = Project(insertOutput, insertSource) + + val replacementQuery = Union(carryover :: inserts :: Nil) + val projections = buildUnionProjections(replacementQuery, rowAttrs, metadataAttrs) + + val groupFilterCond = runtimeCondition(source, rowAttrs, filterSource, scopeOrdinals) + + // ReplaceData's condition (2nd arg) is the planning-time pushdown filter; groupFilterCond + // (6th) is the runtime group filter. The replaced scope set is dynamic: it comes from the + // source, so there is no static target-only predicate to push down, the same as a MERGE whose + // ON condition is purely join keys. That is why the condition is TrueLiteral. File pruning + // instead comes from groupFilterCond, which RowLevelOperationRuntimeGroupFiltering turns into a + // dynamic IN-subquery, so deterministic sources still rewrite only the groups that hold a match. + val writeRelation = relation.copy(table = operationTable) + val replaceData = + ReplaceData( + writeRelation, + TrueLiteral, + replacementQuery, + relation, + projections, + Some(groupFilterCond)) + + WithCTE(replaceData, sourceCte :: Nil) + } + + private def buildWriteDeltaPlan( + relation: DataSourceV2Relation, + operationTable: RowLevelOperationTable, + operation: SupportsDelta, + scopeColumns: Seq[Seq[String]], + source: LogicalPlan): LogicalPlan = { + + val scopeOrdinals = resolveScopeOrdinals(relation, scopeColumns) + val rowAttrs = relation.output + val rowIdAttrs = resolveRowIdAttrs(relation, operation) + val metadataAttrs = resolveRequiredMetadataAttrs(relation, operation) + val readRelation = buildRelationWithAttrs(relation, operationTable, metadataAttrs, rowIdAttrs) + + val sourceCte = CTERelationDef(source) + val scopeSource = newSourceRef(sourceCte) + val insertSource = newSourceRef(sourceCte) + val filterSource = newSourceRef(sourceCte) + + val semiJoinCondition = scopeEquality(readRelation.output, scopeSource.output, scopeOrdinals) + val matchingRows = + Join(readRelation, scopeSource, LeftSemi, Some(semiJoinCondition), JoinHint.NONE) + + val rowIdSet = AttributeSet(rowIdAttrs) + val deleteData = rowAttrs.map { + case attr if rowIdSet.contains(attr) => attr + case attr => Alias(Literal(null, attr.dataType), attr.name)() + } + val deleteMetadata = nullifyMetadataOnDelete(metadataAttrs) + val deletePayload = deleteData ++ rowIdAttrs ++ deleteMetadata + val deleteOutput = operationAlias(DELETE_OPERATION) +: deletePayload + val deletes = Project(deleteOutput, matchingRows) + + val insertData = alignedSourceData(insertSource, rowAttrs) + val insertRowIds = rowIdAttrs.map { attr => + Alias(Literal(null, attr.dataType), attr.name)() + } + val insertMetadata = metadataAttrs.map { attr => + Alias(Literal(null, attr.dataType), attr.name)() + } + val insertOutput = + operationAlias(INSERT_OPERATION) +: (insertData ++ insertRowIds ++ insertMetadata) + val inserts = Project(insertOutput, insertSource) + + val deltaQuery = Union(deletes :: inserts :: Nil) + val projections = buildDeltaUnionProjections( + rowAttrs, + rowIdAttrs, + metadataAttrs, + insertData, + rowIdAttrs, + deleteMetadata) + val condition = runtimeCondition(source, rowAttrs, filterSource, scopeOrdinals) + + val writeRelation = relation.copy(table = operationTable) + val writeDelta = WriteDelta(writeRelation, condition, deltaQuery, relation, projections) + + WithCTE(writeDelta, sourceCte :: Nil) + } + + private def operationAlias(operation: Int): NamedExpression = { + Alias(Literal(operation), OPERATION_COLUMN)() + } + + private def runtimeCondition( + source: LogicalPlan, + rowAttrs: Seq[Attribute], + filterSource: CTERelationRef, + scopeOrdinals: Seq[Int]): Expression = { + if (source.deterministic) { + scopeExists(rowAttrs, filterSource, scopeOrdinals) + } else { + // A non-deterministic source cannot be used as a row-level operation condition: Spark requires + // that condition to be deterministic (it is evaluated during scan planning, separately from the + // write query, so it could disagree with the rows the write actually produces). Falling back to + // an unconditional operation keeps the result correct because the carryover/delete joins remain + // the sole arbiter of which rows survive, but it disables file pruning, so the whole table is + // read and rewritten and the operation's conflict surface widens accordingly. + logWarning( + "Scoped replace source is non-deterministic; skipping runtime file pruning. The entire " + + "target table will be read and rewritten, which may significantly increase write " + + "amplification and the operation's conflict surface.") + TrueLiteral + } + } + + // ReplaceData projections require fixed ordinals, but the replacement query is a Union over two + // sources. Both branches therefore use the same [operation, data..., metadata...] layout. + private def buildUnionProjections( + query: LogicalPlan, + rowAttrs: Seq[Attribute], + metadataAttrs: Seq[Attribute]): ReplaceDataProjections = { + + val output = query.output + val rowOrdinals = rowAttrs.indices.map(_ + 1) + val rowSchema = StructType(rowAttrs.zipWithIndex.map { case (attr, i) => + StructField(attr.name, attr.dataType, output(rowOrdinals(i)).nullable, attr.metadata) + }) + val rowProjection = ProjectingInternalRow(rowSchema, rowOrdinals.toIndexedSeq) + + val metadataProjection = if (metadataAttrs.nonEmpty) { + val metadataBaseOrdinal = 1 + rowAttrs.size + val metadataOrdinals = metadataAttrs.indices.map(_ + metadataBaseOrdinal) + // Insert rows null out metadata; carryover rows preserve the target metadata contract. + val metadataSchema = StructType(metadataAttrs.map { attr => + StructField(attr.name, attr.dataType, attr.nullable, attr.metadata) + }) + Some(ProjectingInternalRow(metadataSchema, metadataOrdinals.toIndexedSeq)) + } else { + None + } + + ReplaceDataProjections(rowProjection, metadataProjection) + } + + private def buildDeltaUnionProjections( + rowAttrs: Seq[Attribute], + rowIdAttrs: Seq[Attribute], + metadataAttrs: Seq[Attribute], + rowOutputs: Seq[NamedExpression], + rowIdOutputs: Seq[Expression], + metadataOutputs: Seq[Expression]): WriteDeltaProjections = { + + val rowProjection = if (rowAttrs.nonEmpty) { + val rowOrdinals = rowAttrs.indices.map(_ + 1) + val rowSchema = StructType(rowAttrs.zipWithIndex.map { case (attr, i) => + StructField(attr.name, attr.dataType, rowOutputs(i).nullable, attr.metadata) + }) + Some(ProjectingInternalRow(rowSchema, rowOrdinals.toIndexedSeq)) + } else { + None + } + + val rowIdBaseOrdinal = 1 + rowAttrs.size + val rowIdOrdinals = rowIdAttrs.indices.map(_ + rowIdBaseOrdinal) + val rowIdSchema = StructType(rowIdAttrs.zipWithIndex.map { case (attr, i) => + StructField(attr.name, attr.dataType, rowIdOutputs(i).nullable, attr.metadata) + }) + val rowIdProjection = ProjectingInternalRow(rowIdSchema, rowIdOrdinals.toIndexedSeq) + + val metadataProjection = if (metadataAttrs.nonEmpty) { + val metadataBaseOrdinal = rowIdBaseOrdinal + rowIdAttrs.size + val metadataOrdinals = metadataAttrs.indices.map(_ + metadataBaseOrdinal) + val metadataSchema = StructType(metadataAttrs.zipWithIndex.map { case (attr, i) => + StructField(attr.name, attr.dataType, metadataOutputs(i).nullable, attr.metadata) + }) + Some(ProjectingInternalRow(metadataSchema, metadataOrdinals.toIndexedSeq)) + } else { + None + } + + WriteDeltaProjections(rowProjection, rowIdProjection, metadataProjection) + } + + // Aligns the source query to the target table positionally, applying the same store-assignment + // policy (ANSI/strict/legacy) Spark uses for INSERT INTO ... SELECT. This rejects or permits casts + // identically to a plain insert, enforces target nullability/char-varchar contracts, and gives the + // aligned columns the target's types so the scope-matching joins compare like-typed values. + private def alignSourceColumns( + relation: DataSourceV2Relation, + source: LogicalPlan): LogicalPlan = { + TableOutputResolver.resolveOutputColumns( + relation.name, + relation.output, + source, + byName = false, + conf) + } + + // The source is aligned to the target schema before the CTE is built, so the insert branch only + // needs to expose the target column names (no further casting happens here). + private def alignedSourceData( + source: CTERelationRef, + rowAttrs: Seq[Attribute]): Seq[NamedExpression] = { + rowAttrs.indices.map { i => + Alias(source.output(i), rowAttrs(i).name)() + } + } + + private def resolveScopeOrdinals( + relation: DataSourceV2Relation, + scopeColumns: Seq[Seq[String]]): Seq[Int] = { + scopeColumns.map { nameParts => + val resolved = relation.resolve(nameParts, conf.resolver).getOrElse { + throw analysisError( + s"Cannot resolve scope column '${nameParts.mkString(".")}' in target table ${relation.name}") + } + val ordinal = relation.output.indexWhere(_.exprId == resolved.toAttribute.exprId) + if (ordinal < 0) { + throw analysisError( + s"Scope column '${nameParts.mkString(".")}' is not a top-level column of ${relation.name}") + } + ordinal + } + } + + private def newSourceRef(sourceDef: CTERelationDef): CTERelationRef = { + val ref = CTERelationRef( + sourceDef.id, + _resolved = true, + sourceDef.child.output, + sourceDef.child.isStreaming) + ref.newInstance().asInstanceOf[CTERelationRef] + } + + private def scopeEquality( + targetOutput: Seq[Attribute], + sourceOutput: Seq[Attribute], + scopeOrdinals: Seq[Int]): Expression = { + scopeOrdinals + .map(ordinal => EqualNullSafe(targetOutput(ordinal), sourceOutput(ordinal)): Expression) + .reduce(And) + } + + private def scopeExists( + targetOutput: Seq[Attribute], + sourceRef: CTERelationRef, + scopeOrdinals: Seq[Int]): Expression = { + val cond = scopeOrdinals + .map { ordinal => + EqualNullSafe(OuterReference(targetOutput(ordinal)), sourceRef.output(ordinal)): Expression + } + .reduce(And) + val outerRefs = scopeOrdinals.map(ordinal => targetOutput(ordinal)) + Exists(Filter(cond, sourceRef), outerRefs) + } + + private def analysisError(message: String): AnalysisException = { + new AnalysisException( + errorClass = "_LEGACY_ERROR_TEMP_ICEBERG_SCOPED_REPLACE", + sqlState = null, + messageTemplate = message, + messageParameters = Map.empty[String, String], + cause = None, + message = Some(message)) + } +} diff --git a/spark/v4.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala b/spark/v4.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala index 7c737f0513ed..ab87ba0ef8af 100644 --- a/spark/v4.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala +++ b/spark/v4.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSparkSqlExtensionsParser.scala @@ -32,12 +32,14 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.RewriteViewCommands +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.parser.ParameterContext import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser.NonReservedContext import org.apache.spark.sql.catalyst.parser.extensions.IcebergSqlExtensionsParser.QuotedIdentifierContext import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.ReplaceScopedData import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.VariableSubstitution @@ -139,7 +141,9 @@ class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) private def parsePlanWithDelegate(sqlText: String)( delegateParse: String => LogicalPlan): LogicalPlan = { val sqlTextAfterSubstitution = substitutor.substitute(sqlText) - if (isIcebergCommand(sqlTextAfterSubstitution)) { + if (isScopedReplaceCommand(sqlTextAfterSubstitution)) { + parseScopedReplace(sqlTextAfterSubstitution)(delegateParse) + } else if (isIcebergCommand(sqlTextAfterSubstitution)) { parse(sqlTextAfterSubstitution) { parser => astBuilder.visit(parser.singleStatement()) } .asInstanceOf[LogicalPlan] } else { @@ -147,6 +151,144 @@ class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) } } + /** + * Parse `INSERT INTO t REPLACE USING (cols) `. + * + * Spark's grammar does not yet accept `REPLACE USING` in `INSERT INTO`, while the Iceberg + * extension grammar cannot own an arbitrary trailing Spark query. Keep the workaround narrow: the + * Iceberg grammar parses only the command head, and the wrapped Spark parser parses the query. + * + * TODO: Propose Spark to add `REPLACE USING` / `REPLACE ON` to `insertInto`, remove this text + * split, and build [[ReplaceScopedData]] from the native Spark parse tree instead. + */ + private def parseScopedReplace(sqlText: String)( + delegateParse: String => LogicalPlan): LogicalPlan = { + val scopeListEnd = scopeListEndIndex(sqlText) + val headText = sqlText.substring(0, scopeListEnd + 1) + val queryText = sqlText.substring(scopeListEnd + 1) + val (tableParts, scopeColumns) = parse(headText) { parser => + astBuilder.visitSingleScopedReplaceHead(parser.singleScopedReplaceHead()) + } + val source = delegateParse(queryText) + ReplaceScopedData(UnresolvedRelation(tableParts), scopeColumns, source) + } + + /** + * Find the index of the `)` that closes the scope column list of `REPLACE USING (...)`. + * + * String literals and SQL comments are masked out first so the keyword search and parenthesis + * matching cannot be fooled by `replace using` text or parentheses appearing inside a literal or + * comment. + */ + private def scopeListEndIndex(sqlText: String): Int = { + val masked = maskLiteralsAndComments(sqlText) + val matcher = ReplaceUsingOpenParen.pattern.matcher(masked) + if (!matcher.find()) { + throw new IcebergParseException( + Option(sqlText), + "Could not locate the REPLACE USING (...) scope column list", + Origin(), + Origin()) + } + var depth = 1 + var idx = matcher.end() + while (idx < masked.length && depth > 0) { + masked.charAt(idx) match { + case '(' => depth += 1 + case ')' => depth -= 1 + case _ => + } + if (depth == 0) { + return idx + } + idx += 1 + } + throw new IcebergParseException( + Option(sqlText), + "Unbalanced parentheses in REPLACE USING (...) scope column list", + Origin(), + Origin()) + } + + /** + * Blanks out literals and comments with spaces, preserving input offsets for later string scans. + * + * By default this also blanks backquoted identifiers, which keeps parentheses inside quoted names + * from affecting scope-list matching. Command-head detection opts out so backquoted table names + * remain visible to the head pattern. + */ + private def maskLiteralsAndComments(sql: String, maskBackquotes: Boolean = true): String = { + val out = sql.toCharArray + val n = out.length + var i = 0 + while (i < n) { + sql.charAt(i) match { + case '`' if !maskBackquotes => + i += 1 + while (i < n && sql.charAt(i) != '`') { + i += 1 + } + if (i < n) { + i += 1 + } + case '\'' | '"' | '`' => + val quote = sql.charAt(i) + out(i) = ' ' + i += 1 + while (i < n && sql.charAt(i) != quote) { + // Backslash escapes apply only inside the SQL string literals, not backquotes. + if (quote != '`' && sql.charAt(i) == '\\' && i + 1 < n) { + out(i) = ' ' + i += 1 + } + out(i) = ' ' + i += 1 + } + if (i < n) { + out(i) = ' ' + i += 1 + } + case '-' if i + 1 < n && sql.charAt(i + 1) == '-' => + while (i < n && sql.charAt(i) != '\n') { + out(i) = ' ' + i += 1 + } + case '/' if i + 1 < n && sql.charAt(i + 1) == '*' => + // Spark's lexer treats block comments as nesting (see SqlBaseLexer's BRACKETED_COMMENT, + // which recurses), so a comment only closes once every opener has a matching `*/`. Track + // the nesting depth here so an inner `*/` does not prematurely unmask trailing text. + var depth = 1 + out(i) = ' ' + out(i + 1) = ' ' + i += 2 + while (i + 1 < n && depth > 0) { + if (sql.charAt(i) == '/' && sql.charAt(i + 1) == '*') { + depth += 1 + out(i) = ' ' + out(i + 1) = ' ' + i += 2 + } else if (sql.charAt(i) == '*' && sql.charAt(i + 1) == '/') { + depth -= 1 + out(i) = ' ' + out(i + 1) = ' ' + i += 2 + } else { + out(i) = ' ' + i += 1 + } + } + // Blank any trailing char of an unterminated comment (depth never reached 0). + if (depth > 0 && i < n) { + out(i) = ' ' + i += 1 + } + case _ => + i += 1 + } + } + new String(out) + } + private def isIcebergCommand(sqlText: String): Boolean = { val normalized = sqlText .toLowerCase(Locale.ROOT) @@ -174,6 +316,18 @@ class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) isSnapshotRefDdl(normalized)) } + /** + * Detect `INSERT INTO [TABLE] REPLACE USING (...)` by anchoring on the statement + * head, so `REPLACE USING (` appearing later in the query body (e.g. a join alias `replace` + * with a `USING (...)` clause) is left for Spark to parse as an ordinary insert. Matching runs + * on text with literals and comments masked so they cannot fabricate a head match; backquoted + * table identifiers are preserved so the head pattern can see the target table. + */ + private def isScopedReplaceCommand(sqlText: String): Boolean = { + val masked = maskLiteralsAndComments(sqlText, maskBackquotes = false) + ScopedReplaceCommandHead.pattern.matcher(masked).find() + } + private def isSnapshotRefDdl(normalized: String): Boolean = { normalized.contains("create branch") || normalized.contains("replace branch") || @@ -231,6 +385,22 @@ class IcebergSparkSqlExtensionsParser(delegate: ParserInterface) } object IcebergSparkSqlExtensionsParser { + + /** Matches the `REPLACE USING (` opener (case-insensitive, whitespace-tolerant). */ + private val ReplaceUsingOpenParen = "(?i)replace\\s+using\\s*\\(".r + + /** + * Anchors scoped-replace detection to the statement head: `INSERT INTO [TABLE] + * REPLACE USING (`. The identifier is a dotted chain of unquoted (`\w+`) + * or backquoted parts; word-character parts separated only by whitespace (i.e. ordinary query + * keywords like `SELECT ... FROM`) cannot match, so a `REPLACE USING (` appearing later in the + * query body does not trigger the scoped-replace path. + */ + private val ScopedReplaceCommandHead = + ("(?i)^\\s*insert\\s+into\\s+(?:table\\s+)?" + + "(?:\\w+|`(?:[^`]|``)*`)(?:\\s*\\.\\s*(?:\\w+|`(?:[^`]|``)*`))*" + + "\\s+replace\\s+using\\s*\\(").r + private val substitutorCtor: DynConstructors.Ctor[VariableSubstitution] = DynConstructors .builder() diff --git a/spark/v4.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala b/spark/v4.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala index 724101cfe11d..51c10577b9e7 100644 --- a/spark/v4.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala +++ b/spark/v4.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/parser/extensions/IcebergSqlExtensionsAstBuilder.scala @@ -328,6 +328,18 @@ class IcebergSqlExtensionsAstBuilder(delegate: ParserInterface) toSeq(ctx.order.fields).map(typedVisit[(Term, SortDirection, NullOrder)]) } + /** + * Parse the head of an `INSERT INTO t REPLACE USING (cols) ` statement into the target + * table name and its replace-scope columns. The trailing `` is split off by the parser + * and delegated to Spark, so it is not part of this rule. + */ + override def visitSingleScopedReplaceHead( + ctx: SingleScopedReplaceHeadContext): (Seq[String], Seq[Seq[String]]) = withOrigin(ctx) { + val table = typedVisit[Seq[String]](ctx.multipartIdentifier) + val scopeColumns = toSeq(ctx.fieldList.fields).map(typedVisit[Seq[String]]) + (table, scopeColumns) + } + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { visit(ctx.statement).asInstanceOf[LogicalPlan] } diff --git a/spark/v4.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplaceScopedData.scala b/spark/v4.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplaceScopedData.scala new file mode 100644 index 000000000000..cf5fbfd41462 --- /dev/null +++ b/spark/v4.1/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ReplaceScopedData.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +/** + * Logical node for the scoped-replace command: + * {{{INSERT INTO t REPLACE USING (c1, .., cn) }}} + * + * Semantics: delete target rows whose scope columns match a source row, then append the full + * source in the same write. This node records the parsed command; the rewrite rule asks the table's + * row-level operation builder to choose the configured implementation, such as group replacement + * for copy-on-write tables or row deltas for merge-on-read tables. + * + * Scope columns are kept as raw multi-part names rather than resolved expressions so that they can + * be resolved explicitly against the (resolved) target relation in the rewrite rule, avoiding + * cross-child ambiguity when a name is present in both the target and the source. + * + * @param table the target relation + * @param scopeColumns the replace-scope columns, as raw multi-part identifiers over the target + * @param query the source plan whose rows are appended + */ +case class ReplaceScopedData(table: LogicalPlan, scopeColumns: Seq[Seq[String]], query: LogicalPlan) + extends BinaryCommand { + + override def left: LogicalPlan = table + + override def right: LogicalPlan = query + + override def output: Seq[Attribute] = Nil + + override protected def withNewChildrenInternal( + newLeft: LogicalPlan, + newRight: LogicalPlan): ReplaceScopedData = + copy(table = newLeft, query = newRight) +} diff --git a/spark/v4.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteReplaceScopedData.java b/spark/v4.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteReplaceScopedData.java new file mode 100644 index 000000000000..af729bb24c76 --- /dev/null +++ b/spark/v4.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteReplaceScopedData.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.SnapshotSummary.ADDED_DELETE_FILES_PROP; +import static org.apache.iceberg.SnapshotSummary.DELETED_FILES_PROP; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Map; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.util.SnapshotUtil; +import org.junit.jupiter.api.TestTemplate; + +public class TestCopyOnWriteReplaceScopedData extends TestReplaceScopedData { + + @Override + protected Map extraTableProperties() { + return ImmutableMap.of( + TableProperties.MERGE_MODE, RowLevelOperationMode.COPY_ON_WRITE.modeName()); + } + + @TestTemplate + public void testScopedReplaceRewritesDataFiles() { + createAndInitTable( + "id INT, dep STRING", + "PARTITIONED BY (dep)", + "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"it\" }"); + createOrReplaceView("source", "id INT, dep STRING", "{ \"id\": 10, \"dep\": \"hr\" }"); + + sql("INSERT INTO %s REPLACE USING (dep) SELECT id, dep FROM source", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snapshot = SnapshotUtil.latestSnapshot(table, branch); + assertThat(snapshot.summary()).containsKey(DELETED_FILES_PROP); + assertThat(snapshot.summary()).doesNotContainKey(ADDED_DELETE_FILES_PROP); + } +} diff --git a/spark/v4.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadReplaceScopedData.java b/spark/v4.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadReplaceScopedData.java new file mode 100644 index 000000000000..f1f9f8692a50 --- /dev/null +++ b/spark/v4.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadReplaceScopedData.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.iceberg.SnapshotSummary.ADDED_DELETE_FILES_PROP; +import static org.apache.iceberg.SnapshotSummary.ADDED_DVS_PROP; +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.Map; +import org.apache.iceberg.RowLevelOperationMode; +import org.apache.iceberg.Snapshot; +import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.iceberg.util.SnapshotUtil; +import org.junit.jupiter.api.TestTemplate; + +public class TestMergeOnReadReplaceScopedData extends TestReplaceScopedData { + + @Override + protected Map extraTableProperties() { + return ImmutableMap.of( + TableProperties.MERGE_MODE, RowLevelOperationMode.MERGE_ON_READ.modeName()); + } + + @TestTemplate + public void testScopedReplaceWritesRowDeltas() { + createAndInitTable( + "id INT, dep STRING", + "PARTITIONED BY (dep)", + "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"it\" }"); + createOrReplaceView("source", "id INT, dep STRING", "{ \"id\": 10, \"dep\": \"hr\" }"); + + sql("INSERT INTO %s REPLACE USING (dep) SELECT id, dep FROM source", commitTarget()); + + Table table = validationCatalog.loadTable(tableIdent); + Snapshot snapshot = SnapshotUtil.latestSnapshot(table, branch); + assertThat(snapshot.summary()).containsKey(ADDED_DELETE_FILES_PROP); + if (formatVersion >= 3) { + assertThat(snapshot.summary()).containsKey(ADDED_DVS_PROP); + } + } +} diff --git a/spark/v4.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestReplaceScopedData.java b/spark/v4.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestReplaceScopedData.java new file mode 100644 index 000000000000..8c1debb1b3d3 --- /dev/null +++ b/spark/v4.1/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestReplaceScopedData.java @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.extensions; + +import static org.apache.spark.sql.functions.udf; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; +import org.apache.spark.sql.AnalysisException; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.types.DataTypes; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.TestTemplate; + +public abstract class TestReplaceScopedData extends SparkRowLevelOperationsTestBase { + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + sql("DROP TABLE IF EXISTS source"); + sql("DROP VIEW IF EXISTS left_src"); + sql("DROP VIEW IF EXISTS right_src"); + } + + @TestTemplate + public void testScopedReplace() { + createAndInitTable( + "id INT, dep STRING", + "PARTITIONED BY (dep)", + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hr\" }\n" + + "{ \"id\": 3, \"dep\": \"it\" }\n" + + "{ \"id\": 4, \"dep\": \"sales\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 10, \"dep\": \"hr\" }\n" + "{ \"id\": 11, \"dep\": \"sales\" }"); + + sql("INSERT INTO %s REPLACE USING (dep) SELECT id, dep FROM source", commitTarget()); + + assertEquals( + "Should replace rows with matching scope values and retain the rest", + ImmutableList.of(row(3, "it"), row(10, "hr"), row(11, "sales")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testDuplicateSourceScopesDeleteTargetRowsOnce() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hr\" }\n" + + "{ \"id\": 3, \"dep\": \"it\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 10, \"dep\": \"hr\" }\n" + "{ \"id\": 11, \"dep\": \"hr\" }"); + + sql("INSERT INTO %s REPLACE USING (dep) SELECT id, dep FROM source", commitTarget()); + + assertEquals( + "Should delete the target scope once and append all source rows", + ImmutableList.of(row(3, "it"), row(10, "hr"), row(11, "hr")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testNullScopeMatchesNullScope() { + createAndInitTable("id INT, dep STRING"); + sql("INSERT INTO %s VALUES (1, NULL), (2, 'hr'), (3, NULL)", tableName); + createBranchIfNeeded(); + + createOrReplaceView("source", "id INT, dep STRING", "{ \"id\": 10, \"dep\": null }"); + + sql("INSERT INTO %s REPLACE USING (dep) SELECT id, dep FROM source", commitTarget()); + + assertEquals( + "Should use null-safe equality for replace-scope values", + ImmutableList.of(row(2, "hr"), row(10, null)), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testNondeterministicSourceIsEvaluatedOnce() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"it\" }"); + + AtomicInteger depIndex = new AtomicInteger(); + spark + .udf() + .register( + "replace_scoped_data_dep", + udf(() -> depIndex.getAndIncrement() == 0 ? "hr" : "it", DataTypes.StringType) + .asNondeterministic() + .asNonNullable()); + + sql( + "INSERT INTO %s REPLACE USING (dep) " + "SELECT 10 AS id, replace_scoped_data_dep() AS dep", + commitTarget()); + + assertEquals( + "Should evaluate the source once for deletes and inserts", + ImmutableList.of(row(2, "it"), row(10, "hr")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testNewSourceScopesAppendWithoutDeletingTargetRows() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"it\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 10, \"dep\": \"sales\" }\n" + "{ \"id\": 11, \"dep\": \"eng\" }"); + + sql("INSERT INTO %s REPLACE USING (dep) SELECT id, dep FROM source", commitTarget()); + + assertEquals( + "Should append new scopes and keep every existing row when no scope matches", + ImmutableList.of(row(1, "hr"), row(2, "it"), row(10, "sales"), row(11, "eng")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMixedMatchingAndNewSourceScopes() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"hr\" }\n" + + "{ \"id\": 2, \"dep\": \"hr\" }\n" + + "{ \"id\": 3, \"dep\": \"it\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING", + "{ \"id\": 10, \"dep\": \"hr\" }\n" + "{ \"id\": 11, \"dep\": \"sales\" }"); + + sql("INSERT INTO %s REPLACE USING (dep) SELECT id, dep FROM source", commitTarget()); + + assertEquals( + "Should replace matching scopes, append new scopes, and retain untouched scopes", + ImmutableList.of(row(3, "it"), row(10, "hr"), row(11, "sales")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testEmptySourceLeavesTargetUnchanged() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"it\" }"); + + createOrReplaceView("source", "id INT, dep STRING", "{ \"id\": 10, \"dep\": \"hr\" }"); + + // An empty source has no scope values to match and no rows to append, so the target must be + // left exactly as it was. + sql( + "INSERT INTO %s REPLACE USING (dep) SELECT id, dep FROM source WHERE id < 0", + commitTarget()); + + assertEquals( + "Empty source should delete nothing and append nothing", + ImmutableList.of(row(1, "hr"), row(2, "it")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testMultiColumnScopeMatchesOnTheFullTuple() { + createAndInitTable( + "id INT, dep STRING, subdep STRING", + "{ \"id\": 1, \"dep\": \"hr\", \"subdep\": \"a\" }\n" + + "{ \"id\": 2, \"dep\": \"hr\", \"subdep\": \"b\" }\n" + + "{ \"id\": 3, \"dep\": \"it\", \"subdep\": \"a\" }"); + + createOrReplaceView( + "source", + "id INT, dep STRING, subdep STRING", + "{ \"id\": 10, \"dep\": \"hr\", \"subdep\": \"a\" }\n" + + "{ \"id\": 11, \"dep\": \"it\", \"subdep\": \"a\" }"); + + sql( + "INSERT INTO %s REPLACE USING (dep, subdep) SELECT id, dep, subdep FROM source", + commitTarget()); + + // Only rows whose (dep, subdep) tuple appears in the source are replaced; (hr, b) is retained + // even though its dep matches a replaced scope. + assertEquals( + "Should match the full scope tuple instead of any single scope column", + ImmutableList.of(row(2, "hr", "b"), row(10, "hr", "a"), row(11, "it", "a")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testReplaceUsingTextInRegularInsertQueryDoesNotTriggerScopedReplaceParser() { + createAndInitTable("id INT, dep STRING"); + + sql("INSERT INTO %s SELECT 1, 'replace using ('", tableName); + + assertThat(sql("SELECT * FROM %s", tableName)).containsExactly(row(1, "replace using (")); + } + + @TestTemplate + public void testJoinAliasNamedReplaceDoesNotTriggerScopedReplaceParser() { + createAndInitTable("id INT, dep STRING"); + createOrReplaceView("left_src", "id INT, dep STRING", "{ \"id\": 1, \"dep\": \"hr\" }"); + createOrReplaceView("right_src", "id INT", "{ \"id\": 1 }"); + + // Scoped-replace detection must stay at the command head. Valid Spark query bodies can contain + // the same token sequence, for example a join alias named `replace` followed by a USING join. + sql( + "INSERT INTO %s SELECT id, left_src.dep FROM left_src JOIN right_src replace USING (id)", + tableName); + + assertThat(sql("SELECT * FROM %s", tableName)).containsExactly(row(1, "hr")); + } + + @TestTemplate + public void testNestedBlockCommentDoesNotTriggerScopedReplaceParser() { + createAndInitTable("id INT, dep STRING"); + + // Spark treats block comments as nesting, so the inner `*/` does not close the comment and the + // whole `REPLACE USING (dep)` text stays commented out. The router must mask it the same way + // and leave this as an ordinary insert for Spark to parse. + sql("INSERT INTO %s /* outer /* inner */ REPLACE USING (dep) */ SELECT 1, 'hr'", tableName); + + assertThat(sql("SELECT * FROM %s", tableName)).containsExactly(row(1, "hr")); + } + + @TestTemplate + public void testCommentInCommandHeadStillRoutesToScopedReplace() { + createAndInitTable( + "id INT, dep STRING", + "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"it\" }"); + + createOrReplaceView("source", "id INT, dep STRING", "{ \"id\": 10, \"dep\": \"hr\" }"); + + // A (terminated) comment in the command head is masked for detection and skipped by the head + // grammar, so this must still be recognized and executed as a scoped replace. + sql( + "INSERT INTO %s /* reload hr */ REPLACE USING (dep) SELECT id, dep FROM source", + commitTarget()); + + assertEquals( + "A comment in the head should not prevent scoped replace detection", + ImmutableList.of(row(2, "it"), row(10, "hr")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @TestTemplate + public void testStoreAssignmentPolicyRejectsUnsafeSourceCast() { + createAndInitTable("id INT, dep STRING", "{ \"id\": 1, \"dep\": \"hr\" }"); + + // Scoped replace uses the same store-assignment rules as INSERT. Under ANSI, unsafe source + // values must fail during analysis instead of slipping through the row-level rewrite. + withSQLConf( + ImmutableMap.of(SQLConf.STORE_ASSIGNMENT_POLICY().key(), "ansi"), + () -> + assertThatThrownBy( + () -> + sql( + "INSERT INTO %s REPLACE USING (dep) SELECT 'x' AS id, 'hr' AS dep", + commitTarget())) + .isInstanceOf(AnalysisException.class) + .hasMessageContaining("INCOMPATIBLE_DATA_FOR_TABLE")); + } + + @TestTemplate + public void testSafeSourceCastIsAligned() { + createAndInitTable("id BIGINT, dep STRING", "{ \"id\": 1, \"dep\": \"hr\" }"); + createOrReplaceView("source", "id INT, dep STRING", "{ \"id\": 10, \"dep\": \"hr\" }"); + + // Valid store-assignment casts still need to be preserved when the rewrite aligns the source to + // the target output. + sql("INSERT INTO %s REPLACE USING (dep) SELECT id, dep FROM source", commitTarget()); + + assertEquals( + "Should align the source column to the target type", + ImmutableList.of(row(10L, "hr")), + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } +}