diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCatalogOptions.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCatalogOptions.java index a84f17948d69f..afb058ccacf88 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCatalogOptions.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsCatalogOptions.java @@ -55,6 +55,31 @@ default String extractCatalog(CaseInsensitiveStringMap options) { return CatalogManager.SESSION_CATALOG_NAME(); } + /** + * Whether this interface should be used for table existence checks or creation. + * A source may override it to dynamically enable the behavior provided by + * SupportsCatalogOptions as they migrate from regular file-based data source behavior. + * + * @param options the user-specified options that can identify a table, e.g. file path, Kafka + * topic name, etc. It's an immutable case-insensitive string-to-string map. + */ + default boolean useCatalogResolution(CaseInsensitiveStringMap options) { + return true; + } + + /** + * Whether a {@code DataFrameWriter.save()} should fail when the table does not exist. When this + * returns {@code false}, Spark instead creates the table from the written query (as + * {@code DataFrameWriter.saveAsTable} already does), preserving create-on-write semantics for + * file-based {@code save(path)} calls. + * + * @param options the user-specified options that can identify a table, e.g. file path, Kafka + * topic name, etc. It's an immutable case-insensitive string-to-string map. + */ + default boolean failWriteIfTableDoesNotExist(CaseInsensitiveStringMap options) { + return true; + } + /** * Extracts the timestamp string for time travel from the given options. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala index a9f16ffa87be1..8abb6881f1768 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala @@ -168,41 +168,81 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ val catalogManager = df.sparkSession.sessionState.catalogManager + + def createTableAsSelectCommand( + catalog: TableCatalog, ident: Identifier, ignoreIfExists: Boolean): LogicalPlan = { + val tableSpec = UnresolvedTableSpec( + properties = Map.empty, + provider = Some(source), + optionExpression = OptionList(Seq.empty), + location = extraOptions.get("path"), + comment = extraOptions.get(TableCatalog.PROP_COMMENT), + collation = extraOptions.get(TableCatalog.PROP_COLLATION), + serde = None, + external = false, + constraints = Seq.empty) + CreateTableAsSelect( + UnresolvedIdentifier( + catalog.name +: ident.namespace.toImmutableArraySeq :+ ident.name), + partitioningAsV2, + df.queryExecution.analyzed, + tableSpec, + finalOptions, + ignoreIfExists = ignoreIfExists) + } + + def appendOrOverwriteCommand( + table: Table, + catalog: Option[CatalogPlugin], + ident: Option[Identifier]): LogicalPlan = { + checkPartitioningMatchesV2Table(table) + val relation = DataSourceV2Relation.create(table, catalog, ident, dsOptions) + if (curmode == SaveMode.Append) { + AppendData.byName(relation, df.logicalPlan, finalOptions, _withSchemaEvolution) + } else { + // Truncate the table. TableCapabilityCheck will throw a nice exception if this + // isn't supported + OverwriteByExpression.byName( + relation, df.logicalPlan, Literal(true), finalOptions, _withSchemaEvolution) + } + } + curmode match { case SaveMode.Append | SaveMode.Overwrite => - val (table, catalog, ident) = provider match { - case supportsExtract: SupportsCatalogOptions => + provider match { + case supportsExtract: SupportsCatalogOptions + if supportsExtract.useCatalogResolution(dsOptions) => val ident = supportsExtract.extractIdentifier(dsOptions) val catalog = CatalogV2Util.getTableProviderCatalog( supportsExtract, catalogManager, dsOptions) - - (catalog.loadTable(ident), Some(catalog), Some(ident)) + val tableOpt = + try Some(catalog.loadTable(ident)) + catch { + // The table does not exist: create it from the query (create-on-write, + // consistent with saveAsTable) unless the provider asks to fail. + case _: NoSuchTableException + if !supportsExtract.failWriteIfTableDoesNotExist(dsOptions) => None + } + tableOpt match { + case Some(table) => appendOrOverwriteCommand(table, Some(catalog), Some(ident)) + case None => createTableAsSelectCommand(catalog, ident, ignoreIfExists = false) + } case _: TableProvider => val t = getTable if (t.supports(BATCH_WRITE)) { - (t, None, None) + appendOrOverwriteCommand(t, None, None) } else { // Streaming also uses the data source V2 API. So it may be that the data source // implements v2, but has no v2 implementation for batch writes. In that case, we // fall back to saving as though it's a V1 source. - return saveToV1SourceCommand(path) + saveToV1SourceCommand(path) } } - val relation = DataSourceV2Relation.create(table, catalog, ident, dsOptions) - checkPartitioningMatchesV2Table(table) - if (curmode == SaveMode.Append) { - AppendData.byName(relation, df.logicalPlan, finalOptions, _withSchemaEvolution) - } else { - // Truncate the table. TableCapabilityCheck will throw a nice exception if this - // isn't supported - OverwriteByExpression.byName( - relation, df.logicalPlan, Literal(true), finalOptions, _withSchemaEvolution) - } - case createMode => provider match { - case supportsExtract: SupportsCatalogOptions => + case supportsExtract: SupportsCatalogOptions + if supportsExtract.useCatalogResolution(dsOptions) => if (_withSchemaEvolution) { throw QueryCompilationErrors.schemaEvolutionNotSupportedForCreateTableWriteError() } @@ -210,24 +250,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram val catalog = CatalogV2Util.getTableProviderCatalog( supportsExtract, catalogManager, dsOptions) - val tableSpec = UnresolvedTableSpec( - properties = Map.empty, - provider = Some(source), - optionExpression = OptionList(Seq.empty), - location = extraOptions.get("path"), - comment = extraOptions.get(TableCatalog.PROP_COMMENT), - collation = extraOptions.get(TableCatalog.PROP_COLLATION), - serde = None, - external = false, - constraints = Seq.empty) - CreateTableAsSelect( - UnresolvedIdentifier( - catalog.name +: ident.namespace.toImmutableArraySeq :+ ident.name), - partitioningAsV2, - df.queryExecution.analyzed, - tableSpec, - finalOptions, - ignoreIfExists = createMode == SaveMode.Ignore) + createTableAsSelectCommand( + catalog, ident, ignoreIfExists = createMode == SaveMode.Ignore) case _: TableProvider => if (getTable.supports(BATCH_WRITE)) { throw QueryCompilationErrors.writeWithSaveModeUnsupportedBySourceError( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index a3b5c5aeb7995..f280ee68cb1ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -116,10 +116,11 @@ private[sql] object DataSourceV2Utils extends Logging { optionsWithPath.originalMap val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) val (table, catalog, ident, timeTravelSpec) = provider match { - case _: SupportsCatalogOptions if userSpecifiedSchema.nonEmpty => + case c: SupportsCatalogOptions + if c.useCatalogResolution(dsOptions) && userSpecifiedSchema.nonEmpty => throw new IllegalArgumentException( s"$source does not support user specified schema. Please don't specify the schema.") - case hasCatalog: SupportsCatalogOptions => + case hasCatalog: SupportsCatalogOptions if hasCatalog.useCatalogResolution(dsOptions) => val ident = hasCatalog.extractIdentifier(dsOptions) val catalog = CatalogV2Util.getTableProviderCatalog( hasCatalog, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala index 98904e6976074..4204394d3d4b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala @@ -47,6 +47,8 @@ class SupportsCatalogOptionsSuite extends SharedSparkSession with BeforeAndAfter private val catalogName = "testcat" private val format = classOf[CatalogSupportingInMemoryTableProvider].getName + private val optOutFormat = classOf[CatalogResolutionOptOutProvider].getName + private val createOnWriteFormat = classOf[CreateOnWriteProvider].getName private def catalog(name: String): TableCatalog = { spark.sessionState.catalogManager.catalog(name).asInstanceOf[TableCatalog] @@ -370,6 +372,55 @@ class SupportsCatalogOptionsSuite extends SharedSparkSession with BeforeAndAfter .contains("Cannot specify both version and timestamp when time travelling the table.")) } + test("useCatalogResolution=false: read is resolved via the TableProvider path, not the catalog") { + // The provider opts out of catalog resolution, so load() goes through getTable instead of + // extractIdentifier/extractCatalog. The resulting relation therefore has no catalog/identifier. + val df = spark.read.format(optOutFormat).option("name", "t1").load() + val relation = df.logicalPlan.collectFirst { + case r: DataSourceV2Relation => r + }.getOrElse(fail("Expected a DataSourceV2Relation")) + assert(relation.catalog.isEmpty && relation.identifier.isEmpty, + "Opting out of catalog resolution should bypass the catalog") + } + + test("useCatalogResolution=false: a user-specified schema is allowed (no catalog check)") { + // The schema check only fires on the catalog-resolution path; opting out skips it. + val df = spark.read.format(optOutFormat).option("name", "t1").schema("i int, j int").load() + assert(df.schema.fieldNames === Array("i", "j")) + } + + test("failWriteIfTableDoesNotExist=false: append creates a missing table (create-on-write)") { + val df = spark.range(10) + // t1 does not exist yet; append should create it from the query instead of failing. + df.write.format(createOnWriteFormat).option("name", "t1").option("catalog", catalogName) + .mode(SaveMode.Append).save() + assert(catalog(catalogName).tableExists("t1"), "append should have created the table") + checkAnswer(load("t1", Some(catalogName)), df.toDF()) + } + + test("failWriteIfTableDoesNotExist=false: overwrite creates a missing table (create-on-write)") { + val df = spark.range(10, 20) + df.write.format(createOnWriteFormat).option("name", "t1").option("catalog", catalogName) + .mode(SaveMode.Overwrite).save() + assert(catalog(catalogName).tableExists("t1"), "overwrite should have created the table") + checkAnswer(load("t1", Some(catalogName)), df.toDF()) + } + + test("failWriteIfTableDoesNotExist=false: append to an existing table still appends") { + sql(s"create table $catalogName.t1 (id bigint) using $createOnWriteFormat") + spark.range(10).write.format(createOnWriteFormat).option("name", "t1") + .option("catalog", catalogName).mode(SaveMode.Append).save() + checkAnswer(load("t1", Some(catalogName)), spark.range(10).toDF()) + } + + test("append to a missing table fails by default (failWriteIfTableDoesNotExist=true)") { + // The default provider keeps the prior behavior: append/overwrite to a missing table fails. + intercept[NoSuchTableException] { + spark.range(10).write.format(format).option("name", "t1").option("catalog", catalogName) + .mode(SaveMode.Append).save() + } + } + private def checkV2Identifiers( plan: LogicalPlan, identifier: String = "t1", @@ -443,3 +494,13 @@ class CatalogSupportingInMemoryTableProvider } } } + +/** Opts out of catalog resolution, so load/save fall back to the plain TableProvider path. */ +class CatalogResolutionOptOutProvider extends CatalogSupportingInMemoryTableProvider { + override def useCatalogResolution(options: CaseInsensitiveStringMap): Boolean = false +} + +/** Opts out of failing on a missing table, enabling create-on-write for save(). */ +class CreateOnWriteProvider extends CatalogSupportingInMemoryTableProvider { + override def failWriteIfTableDoesNotExist(options: CaseInsensitiveStringMap): Boolean = false +}