diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala index e898253be1168..2d809486ab391 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroTable.scala @@ -22,7 +22,7 @@ import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.SparkSession import org.apache.spark.sql.avro.AvroUtils -import org.apache.spark.sql.connector.write.{LogicalWriteInfo, Write, WriteBuilder} +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.v2.FileTable import org.apache.spark.sql.types.{DataType, StructType} @@ -43,13 +43,14 @@ case class AvroTable( AvroUtils.inferSchema(sparkSession, options.asScala.toMap, files) override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { - new WriteBuilder { - override def build(): Write = - AvroWrite(paths, formatName, supportsDataType, mergedWriteInfo(info)) + createFileWriteBuilder(info) { + (mergedInfo, partSchema, customLocs, dynamicOverwrite, truncate) => + AvroWrite(paths, formatName, supportsDataType, mergedInfo, partSchema, customLocs, + dynamicOverwrite, truncate) } } override def supportsDataType(dataType: DataType): Boolean = AvroUtils.supportsDataType(dataType) - override def formatName: String = "AVRO" + override def formatName: String = "Avro" } diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroWrite.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroWrite.scala index 3a91fd0c73d1a..c594e7a956889 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroWrite.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroWrite.scala @@ -29,7 +29,11 @@ case class AvroWrite( paths: Seq[String], 
formatName: String, supportsDataType: DataType => Boolean, - info: LogicalWriteInfo) extends FileWrite { + info: LogicalWriteInfo, + partitionSchema: StructType, + override val customPartitionLocations: Map[Map[String, String], String] = Map.empty, + override val dynamicPartitionOverwrite: Boolean, + override val isTruncate: Boolean) extends FileWrite { override def prepareWrite( sqlConf: SQLConf, job: Job, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala index f0359b33f431d..f67c7ba91b49d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala @@ -194,10 +194,17 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram if (curmode == SaveMode.Append) { AppendData.byName(relation, df.logicalPlan, finalOptions) } else { - // Truncate the table. TableCapabilityCheck will throw a nice exception if this - // isn't supported - OverwriteByExpression.byName( - relation, df.logicalPlan, Literal(true), finalOptions) + val dynamicOverwrite = + df.sparkSession.sessionState.conf.partitionOverwriteMode == + PartitionOverwriteMode.DYNAMIC && + partitioningColumns.exists(_.nonEmpty) + if (dynamicOverwrite) { + OverwritePartitionsDynamic.byName( + relation, df.logicalPlan, finalOptions) + } else { + OverwriteByExpression.byName( + relation, df.logicalPlan, Literal(true), finalOptions) + } } case createMode => @@ -318,7 +325,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram } val session = df.sparkSession - val canUseV2 = lookupV2Provider().isDefined + // TODO(SPARK-56175): File source V2 does not support + // insertInto for catalog tables yet. 
+ val canUseV2 = lookupV2Provider() match { + case Some(_: FileDataSourceV2) => false + case Some(_) => true + case None => false + } session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match { case NonSessionCatalogAndIdentifier(catalog, ident) => @@ -438,9 +451,16 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ val session = df.sparkSession - val v2ProviderOpt = lookupV2Provider() - val canUseV2 = v2ProviderOpt.isDefined || (hasCustomSessionCatalog && - !df.sparkSession.sessionState.catalogManager.catalog(CatalogManager.SESSION_CATALOG_NAME) + // TODO(SPARK-56230): File source V2 does not support + // saveAsTable yet. Always use V1 for file sources. + val v2ProviderOpt = lookupV2Provider().flatMap { + case _: FileDataSourceV2 => None + case other => Some(other) + } + val canUseV2 = v2ProviderOpt.isDefined || + (hasCustomSessionCatalog && + !df.sparkSession.sessionState.catalogManager + .catalog(CatalogManager.SESSION_CATALOG_NAME) .isInstanceOf[CatalogExtension]) session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match { @@ -595,8 +615,16 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram private def lookupV2Provider(): Option[TableProvider] = { DataSource.lookupDataSourceV2(source, df.sparkSession.sessionState.conf) match { - // TODO(SPARK-28396): File source v2 write path is currently broken. - case Some(_: FileDataSourceV2) => None + // File source V2 supports non-partitioned Append and + // Overwrite via DataFrame API (df.write.save(path)). 
+ // Fall back to V1 for: + // - ErrorIfExists/Ignore (TODO: SPARK-56174) + // - Partitioned writes (TODO: SPARK-56174) + case Some(_: FileDataSourceV2) + if (curmode != SaveMode.Append + && curmode != SaveMode.Overwrite) + || partitioningColumns.exists(_.nonEmpty) => + None case other => other } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallBackFileSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallBackFileSourceV2.scala deleted file mode 100644 index e03d6e6772fa1..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallBackFileSourceV2.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources - -import scala.jdk.CollectionConverters._ - -import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.classic.SparkSession -import org.apache.spark.sql.execution.datasources.v2.{ExtractV2Table, FileTable} - -/** - * Replace the File source V2 table in [[InsertIntoStatement]] to V1 [[FileFormat]]. 
- * E.g, with temporary view `t` using - * [[org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2]], inserting into view `t` fails - * since there is no corresponding physical plan. - * This is a temporary hack for making current data source V2 work. It should be - * removed when Catalog support of file data source v2 is finished. - */ -class FallBackFileSourceV2(sparkSession: SparkSession) extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case i @ InsertIntoStatement( - d @ ExtractV2Table(table: FileTable), _, _, _, _, _, _, _) => - val v1FileFormat = table.fallbackFileFormat.getDeclaredConstructor().newInstance() - val relation = HadoopFsRelation( - table.fileIndex, - table.fileIndex.partitionSchema, - table.schema, - None, - v1FileFormat, - d.options.asScala.toMap)(sparkSession) - i.copy(table = LogicalRelation(relation)) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index e11c2b15e0541..6b5e04f5e27ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources import scala.collection.mutable +import scala.jdk.CollectionConverters._ import org.apache.hadoop.fs.{FileAlreadyExistsException, Path} import org.apache.hadoop.mapreduce.TaskAttemptContext @@ -104,6 +105,14 @@ abstract class FileFormatDataWriter( } } + /** + * Override writeAll to ensure V2 DataWriter.writeAll path also wraps + * errors with TASK_WRITE_FAILED, matching V1 behavior. + */ + override def writeAll(records: java.util.Iterator[InternalRow]): Unit = { + writeWithIterator(records.asScala) + } + /** Write an iterator of records. 
*/ def writeWithIterator(iterator: Iterator[InternalRow]): Unit = { var count = 0L diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 7932a0aa53bac..8717a5154fa69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -70,7 +70,15 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat val nameParts = ident.toQualifiedNameParts(catalog) cacheManager.recacheTableOrView(session, nameParts, includeTimeTravel = false) case _ => - cacheManager.recacheByPlan(session, r) + r.table match { + case ft: FileTable => + ft.fileIndex.refresh() + val path = new Path(ft.fileIndex.rootPaths.head.toUri) + val fs = path.getFileSystem(hadoopConf) + cacheManager.recacheByPath(session, path, fs) + case _ => + cacheManager.recacheByPlan(session, r) + } } private def recacheTable(r: ResolvedTable, includeTimeTravel: Boolean)(): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index a3b5c5aeb7995..946ab0f250194 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -164,8 +164,11 @@ private[sql] object DataSourceV2Utils extends Logging { // `HiveFileFormat`, when running tests in sql/core. if (DDLUtils.isHiveTable(Some(provider))) return None DataSource.lookupDataSourceV2(provider, conf) match { - // TODO(SPARK-28396): Currently file source v2 can't work with tables. 
- case Some(p) if !p.isInstanceOf[FileDataSourceV2] => Some(p) + // TODO(SPARK-56175): File source V2 catalog table loading + // is not yet fully supported (stats, partition management, + // data type validation gaps). + case Some(_: FileDataSourceV2) => None + case Some(p) => Some(p) case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala index 0af728c1958d4..072e4bbf9a182 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala @@ -26,7 +26,9 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl} +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, + LogicalWriteInfoImpl, SupportsDynamicOverwrite, + SupportsTruncate, Write, WriteBuilder} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.streaming.runtime.MetadataLogFileIndex @@ -49,18 +51,27 @@ abstract class FileTable( val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) - if (FileStreamSink.hasMetadata(paths, hadoopConf, sparkSession.sessionState.conf)) { - // We are reading from the results of a streaming query. We will load files from - // the metadata log instead of listing them using HDFS APIs. 
+ // When userSpecifiedSchema is provided (e.g., write path via DataFrame API), the path + // may not exist yet. Skip streaming metadata check and file existence checks. + val isStreamingMetadata = userSpecifiedSchema.isEmpty && + FileStreamSink.hasMetadata(paths, hadoopConf, sparkSession.sessionState.conf) + if (isStreamingMetadata) { new MetadataLogFileIndex(sparkSession, new Path(paths.head), options.asScala.toMap, userSpecifiedSchema) } else { - // This is a non-streaming file based datasource. - val rootPathsSpecified = DataSource.checkAndGlobPathIfNecessary(paths, hadoopConf, - checkEmptyGlobPath = true, checkFilesExist = true, enableGlobbing = globPaths) - val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) + val checkFilesExist = userSpecifiedSchema.isEmpty + val rootPathsSpecified = + DataSource.checkAndGlobPathIfNecessary( + paths, hadoopConf, + checkEmptyGlobPath = checkFilesExist, + checkFilesExist = checkFilesExist, + enableGlobbing = globPaths) + val fileStatusCache = + FileStatusCache.getOrCreate(sparkSession) new InMemoryFileIndex( - sparkSession, rootPathsSpecified, caseSensitiveMap, userSpecifiedSchema, fileStatusCache) + sparkSession, rootPathsSpecified, + caseSensitiveMap, userSpecifiedSchema, + fileStatusCache) } } @@ -174,8 +185,43 @@ abstract class FileTable( writeInfo.rowIdSchema(), writeInfo.metadataSchema()) } + + /** + * Creates a [[WriteBuilder]] that supports truncate and + * dynamic partition overwrite for file-based tables. 
+ */ + protected def createFileWriteBuilder( + info: LogicalWriteInfo)( + buildWrite: (LogicalWriteInfo, StructType, + Map[Map[String, String], String], + Boolean, Boolean) => Write + ): WriteBuilder = { + new WriteBuilder with SupportsDynamicOverwrite with SupportsTruncate { + private var isDynamicOverwrite = false + private var isTruncate = false + + override def overwriteDynamicPartitions(): WriteBuilder = { + isDynamicOverwrite = true + this + } + + override def truncate(): WriteBuilder = { + isTruncate = true + this + } + + override def build(): Write = { + val merged = mergedWriteInfo(info) + val partSchema = fileIndex.partitionSchema + buildWrite(merged, partSchema, + Map.empty, isDynamicOverwrite, isTruncate) + } + } + } + } object FileTable { - private val CAPABILITIES = util.EnumSet.of(BATCH_READ, BATCH_WRITE) + private val CAPABILITIES = util.EnumSet.of( + BATCH_READ, BATCH_WRITE, TRUNCATE, OVERWRITE_DYNAMIC) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala index 77e1ade44780f..be81f4afa0245 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala @@ -37,7 +37,6 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.SchemaUtils -import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.SerializableConfiguration trait FileWrite extends Write { @@ -46,6 +45,10 @@ trait FileWrite extends Write { def supportsDataType: DataType => Boolean def allowDuplicatedColumnNames: Boolean = false def info: LogicalWriteInfo + def partitionSchema: StructType + def customPartitionLocations: Map[Map[String, String], String] = Map.empty + def 
dynamicPartitionOverwrite: Boolean = false + def isTruncate: Boolean = false private val schema = info.schema() private val queryId = info.queryId() @@ -60,11 +63,32 @@ trait FileWrite extends Write { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + + // Ensure the output path exists. For new writes (Append to a new path, Overwrite on a new + // path), the path may not exist yet. + val fs = path.getFileSystem(hadoopConf) + val qualifiedPath = path.makeQualified(fs.getUri, fs.getWorkingDirectory) + if (!fs.exists(qualifiedPath)) { + fs.mkdirs(qualifiedPath) + } + + // For truncate (full overwrite), delete existing data before writing. + if (isTruncate && fs.exists(qualifiedPath)) { + fs.listStatus(qualifiedPath).foreach { status => + // Preserve hidden files/dirs (e.g., _SUCCESS, .spark-staging-*) + if (!status.getPath.getName.startsWith("_") && + !status.getPath.getName.startsWith(".")) { + fs.delete(status.getPath, true) + } + } + } + val job = getJobInstance(hadoopConf, path) val committer = FileCommitProtocol.instantiate( sparkSession.sessionState.conf.fileCommitProtocolClass, jobId = java.util.UUID.randomUUID().toString, - outputPath = paths.head) + outputPath = paths.head, + dynamicPartitionOverwrite = dynamicPartitionOverwrite) lazy val description = createWriteJobDescription(sparkSession, hadoopConf, job, paths.head, options.asScala.toMap) @@ -93,12 +117,14 @@ trait FileWrite extends Write { s"got: ${paths.mkString(", ")}") } if (!allowDuplicatedColumnNames) { - SchemaUtils.checkColumnNameDuplication( - schema.fields.map(_.name).toImmutableArraySeq, caseSensitiveAnalysis) + SchemaUtils.checkSchemaColumnNameDuplication( + schema, caseSensitiveAnalysis) + } + if (!sqlConf.allowCollationsInMapKeys) { + SchemaUtils.checkNoCollationsInMapKeys(schema) } DataSource.validateSchema(formatName, schema, sqlConf) - // TODO: 
[SPARK-36340] Unify check schema filed of DataSource V2 Insert. schema.foreach { field => if (!supportsDataType(field.dataType)) { throw QueryCompilationErrors.dataTypeUnsupportedByDataSourceError(formatName, field) @@ -121,26 +147,38 @@ trait FileWrite extends Write { pathName: String, options: Map[String, String]): WriteJobDescription = { val caseInsensitiveOptions = CaseInsensitiveMap(options) + val allColumns = toAttributes(schema) + val partitionColumnNames = partitionSchema.fields.map(_.name).toSet + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val partitionColumns = if (partitionColumnNames.nonEmpty) { + allColumns.filter { col => + if (caseSensitive) { + partitionColumnNames.contains(col.name) + } else { + partitionColumnNames.exists(_.equalsIgnoreCase(col.name)) + } + } + } else { + Seq.empty + } + val dataColumns = allColumns.filterNot(partitionColumns.contains) // Note: prepareWrite has side effect. It sets "job". + val dataSchema = StructType(dataColumns.map(col => schema(col.name))) val outputWriterFactory = - prepareWrite(sparkSession.sessionState.conf, job, caseInsensitiveOptions, schema) - val allColumns = toAttributes(schema) + prepareWrite(sparkSession.sessionState.conf, job, caseInsensitiveOptions, dataSchema) val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics val serializableHadoopConf = new SerializableConfiguration(hadoopConf) val statsTracker = new BasicWriteJobStatsTracker(serializableHadoopConf, metrics) - // TODO: after partitioning is supported in V2: - // 1. filter out partition columns in `dataColumns`. - // 2. Don't use Seq.empty for `partitionColumns`. 
new WriteJobDescription( uuid = UUID.randomUUID().toString, serializableHadoopConf = new SerializableConfiguration(job.getConfiguration), outputWriterFactory = outputWriterFactory, allColumns = allColumns, - dataColumns = allColumns, - partitionColumns = Seq.empty, + dataColumns = dataColumns, + partitionColumns = partitionColumns, bucketSpec = None, path = pathName, - customPartitionLocations = Map.empty, + customPartitionLocations = customPartitionLocations, maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong) .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile), timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index d21b5c730f0ca..be6c60394145a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import scala.jdk.CollectionConverters._ import org.apache.spark.SparkUnsupportedOperationException +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{QualifiedTableName, SQLConfHelper} import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogStorageFormat, CatalogTable, CatalogTableType, CatalogUtils, ClusterBySpec, SessionCatalog} @@ -33,7 +34,7 @@ import org.apache.spark.sql.connector.catalog.NamespaceChange.RemoveProperty import org.apache.spark.sql.connector.catalog.functions.UnboundFunction import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import 
org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.datasources.{DataSource, FileFormat} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -225,6 +226,26 @@ class V2SessionCatalog(catalog: SessionCatalog) case _ => // The provider is not a V2 provider so we return the schema and partitions as is. + // Validate data types using the V1 FileFormat, matching V1 CreateDataSourceTableCommand + // behavior (which validates via DataSource.resolveRelation). + if (schema.nonEmpty) { + val ds = DataSource( + SparkSession.active, + userSpecifiedSchema = Some(schema), + className = provider) + ds.providingInstance() match { + case format: FileFormat => + schema.foreach { field => + if (!format.supportDataType(field.dataType)) { + throw QueryCompilationErrors + .dataTypeUnsupportedByDataSourceError( + format.toString, field) + } + } + case _ => + } + } + DataSource.validateSchema(provider, schema, conf) (schema, partitions) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala index 4938df795cb1a..c6b15c0ce1e20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVTable.scala @@ -22,11 +22,12 @@ import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.csv.CSVOptions -import org.apache.spark.sql.connector.write.{LogicalWriteInfo, Write, WriteBuilder} +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.csv.CSVDataSource import 
org.apache.spark.sql.execution.datasources.v2.FileTable -import org.apache.spark.sql.types.{AtomicType, DataType, GeographyType, GeometryType, StructType, UserDefinedType} +import org.apache.spark.sql.types.{AtomicType, DataType, GeographyType, + GeometryType, StructType, UserDefinedType} import org.apache.spark.sql.util.CaseInsensitiveStringMap case class CSVTable( @@ -50,9 +51,10 @@ case class CSVTable( } override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { - new WriteBuilder { - override def build(): Write = - CSVWrite(paths, formatName, supportsDataType, mergedWriteInfo(info)) + createFileWriteBuilder(info) { + (mergedInfo, partSchema, customLocs, dynamicOverwrite, truncate) => + CSVWrite(paths, formatName, supportsDataType, mergedInfo, partSchema, customLocs, + dynamicOverwrite, truncate) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWrite.scala index 7011fea77d888..617c404e8b7c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVWrite.scala @@ -31,7 +31,11 @@ case class CSVWrite( paths: Seq[String], formatName: String, supportsDataType: DataType => Boolean, - info: LogicalWriteInfo) extends FileWrite { + info: LogicalWriteInfo, + partitionSchema: StructType, + override val customPartitionLocations: Map[Map[String, String], String] = Map.empty, + override val dynamicPartitionOverwrite: Boolean, + override val isTruncate: Boolean) extends FileWrite { override def allowDuplicatedColumnNames: Boolean = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala index cf3c1e11803c0..e10c4cf959129 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonTable.scala @@ -22,7 +22,7 @@ import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.json.JSONOptionsInRead -import org.apache.spark.sql.connector.write.{LogicalWriteInfo, Write, WriteBuilder} +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.json.JsonDataSource import org.apache.spark.sql.execution.datasources.v2.FileTable @@ -50,9 +50,10 @@ case class JsonTable( } override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { - new WriteBuilder { - override def build(): Write = - JsonWrite(paths, formatName, supportsDataType, mergedWriteInfo(info)) + createFileWriteBuilder(info) { + (mergedInfo, partSchema, customLocs, dynamicOverwrite, truncate) => + JsonWrite(paths, formatName, supportsDataType, mergedInfo, partSchema, customLocs, + dynamicOverwrite, truncate) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonWrite.scala index ea1f6793cb9ca..0da659a68eae0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonWrite.scala @@ -31,7 +31,11 @@ case class JsonWrite( paths: Seq[String], formatName: String, supportsDataType: DataType => Boolean, - info: LogicalWriteInfo) extends FileWrite { + info: LogicalWriteInfo, + partitionSchema: StructType, + override val customPartitionLocations: Map[Map[String, String], String] = Map.empty, + override val dynamicPartitionOverwrite: Boolean, + override val isTruncate: Boolean) 
extends FileWrite { override def prepareWrite( sqlConf: SQLConf, job: Job, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala index 08cd89fdacc61..99484526004e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala @@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters._ import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.write.{LogicalWriteInfo, Write, WriteBuilder} +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.orc.OrcUtils import org.apache.spark.sql.execution.datasources.v2.FileTable @@ -44,9 +44,10 @@ case class OrcTable( OrcUtils.inferSchema(sparkSession, files, options.asScala.toMap) override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { - new WriteBuilder { - override def build(): Write = - OrcWrite(paths, formatName, supportsDataType, mergedWriteInfo(info)) + createFileWriteBuilder(info) { + (mergedInfo, partSchema, customLocs, dynamicOverwrite, truncate) => + OrcWrite(paths, formatName, supportsDataType, mergedInfo, partSchema, customLocs, + dynamicOverwrite, truncate) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWrite.scala index 12dff269a468e..2de2a197bf766 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWrite.scala @@ -32,7 +32,11 @@ case class OrcWrite( paths: Seq[String], 
formatName: String, supportsDataType: DataType => Boolean, - info: LogicalWriteInfo) extends FileWrite { + info: LogicalWriteInfo, + partitionSchema: StructType, + override val customPartitionLocations: Map[Map[String, String], String] = Map.empty, + override val dynamicPartitionOverwrite: Boolean, + override val isTruncate: Boolean) extends FileWrite { override def prepareWrite( sqlConf: SQLConf, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala index 67052c201a9df..0a21ca3344a88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetTable.scala @@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters._ import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.write.{LogicalWriteInfo, Write, WriteBuilder} +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetUtils import org.apache.spark.sql.execution.datasources.v2.FileTable @@ -44,9 +44,10 @@ case class ParquetTable( ParquetUtils.inferSchema(sparkSession, options.asScala.toMap, files) override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { - new WriteBuilder { - override def build(): Write = - ParquetWrite(paths, formatName, supportsDataType, mergedWriteInfo(info)) + createFileWriteBuilder(info) { + (mergedInfo, partSchema, customLocs, dynamicOverwrite, truncate) => + ParquetWrite(paths, formatName, supportsDataType, mergedInfo, partSchema, customLocs, + dynamicOverwrite, truncate) } } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala index e37b1fce7c37e..120d462660eb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetWrite.scala @@ -30,7 +30,11 @@ case class ParquetWrite( paths: Seq[String], formatName: String, supportsDataType: DataType => Boolean, - info: LogicalWriteInfo) extends FileWrite with Logging { + info: LogicalWriteInfo, + partitionSchema: StructType, + override val customPartitionLocations: Map[Map[String, String], String] = Map.empty, + override val dynamicPartitionOverwrite: Boolean, + override val isTruncate: Boolean) extends FileWrite with Logging { override def prepareWrite( sqlConf: SQLConf, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala index d8880b84c6211..5e14ccf0dfba9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextTable.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2.text import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.write.{LogicalWriteInfo, Write, WriteBuilder} +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.v2.FileTable import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType} @@ -40,9 +40,10 @@ case class TextTable( Some(StructType(Array(StructField("value", StringType)))) 
override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { - new WriteBuilder { - override def build(): Write = - TextWrite(paths, formatName, supportsDataType, mergedWriteInfo(info)) + createFileWriteBuilder(info) { + (mergedInfo, partSchema, customLocs, dynamicOverwrite, truncate) => + TextWrite(paths, formatName, supportsDataType, mergedInfo, partSchema, customLocs, + dynamicOverwrite, truncate) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextWrite.scala index 7bee49f05cbcd..f3de9daa44f42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextWrite.scala @@ -31,7 +31,11 @@ case class TextWrite( paths: Seq[String], formatName: String, supportsDataType: DataType => Boolean, - info: LogicalWriteInfo) extends FileWrite { + info: LogicalWriteInfo, + partitionSchema: StructType, + override val customPartitionLocations: Map[Map[String, String], String] = Map.empty, + override val dynamicPartitionOverwrite: Boolean, + override val isTruncate: Boolean) extends FileWrite { private def verifySchema(schema: StructType): Unit = { if (schema.size != 1) { throw QueryCompilationErrors.textDataSourceWithMultiColumnsError(schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 08dd212060762..527bee2ca980d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -225,7 +225,6 @@ abstract class BaseSessionStateBuilder( new ResolveDataSource(session) +: new FindDataSourceTable(session) +: new ResolveSQLOnFile(session) +: - new 
FallBackFileSourceV2(session) +: ResolveEncodersInScalaAgg +: new ResolveSessionCatalog(this.catalogManager) +: ResolveWriteToStream +: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala deleted file mode 100644 index 2a0ab21ddb09c..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala +++ /dev/null @@ -1,199 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.connector - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability} -import org.apache.spark.sql.connector.read.ScanBuilder -import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} -import org.apache.spark.sql.execution.{FileSourceScanExec, QueryExecution} -import org.apache.spark.sql.execution.datasources.{FileFormat, InsertIntoHadoopFsRelationCommand} -import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat -import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 -import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetDataSourceV2 -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.{CaseInsensitiveStringMap, QueryExecutionListener} - -class DummyReadOnlyFileDataSourceV2 extends FileDataSourceV2 { - - override def fallbackFileFormat: Class[_ <: FileFormat] = classOf[ParquetFileFormat] - - override def shortName(): String = "parquet" - - override def getTable(options: CaseInsensitiveStringMap): Table = { - new DummyReadOnlyFileTable - } -} - -class DummyReadOnlyFileTable extends Table with SupportsRead { - override def name(): String = "dummy" - - override def schema(): StructType = StructType(Nil) - - override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { - throw SparkException.internalError("Dummy file reader") - } - - override def capabilities(): java.util.Set[TableCapability] = - java.util.EnumSet.of(TableCapability.BATCH_READ, TableCapability.ACCEPT_ANY_SCHEMA) -} - -class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 { - - override def fallbackFileFormat: 
Class[_ <: FileFormat] = classOf[ParquetFileFormat] - - override def shortName(): String = "parquet" - - override def getTable(options: CaseInsensitiveStringMap): Table = { - new DummyWriteOnlyFileTable - } -} - -class DummyWriteOnlyFileTable extends Table with SupportsWrite { - override def name(): String = "dummy" - - override def schema(): StructType = StructType(Nil) - - override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = - throw SparkException.internalError("Dummy file writer") - - override def capabilities(): java.util.Set[TableCapability] = - java.util.EnumSet.of(TableCapability.BATCH_WRITE, TableCapability.ACCEPT_ANY_SCHEMA) -} - -class FileDataSourceV2FallBackSuite extends QueryTest with SharedSparkSession { - - private val dummyReadOnlyFileSourceV2 = classOf[DummyReadOnlyFileDataSourceV2].getName - private val dummyWriteOnlyFileSourceV2 = classOf[DummyWriteOnlyFileDataSourceV2].getName - - override protected def sparkConf: SparkConf = super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "") - - test("Fall back to v1 when writing to file with read only FileDataSourceV2") { - val df = spark.range(10).toDF() - withTempPath { file => - val path = file.getCanonicalPath - // Writing file should fall back to v1 and succeed. - df.write.format(dummyReadOnlyFileSourceV2).save(path) - - // Validate write result with [[ParquetFileFormat]]. - checkAnswer(spark.read.parquet(path), df) - - // Dummy File reader should fail as expected. 
- checkError( - exception = intercept[SparkException] { - spark.read.format(dummyReadOnlyFileSourceV2).load(path).collect() - }, - condition = "INTERNAL_ERROR", - parameters = Map("message" -> "Dummy file reader")) - } - } - - test("Fall back read path to v1 with configuration USE_V1_SOURCE_LIST") { - val df = spark.range(10).toDF() - withTempPath { file => - val path = file.getCanonicalPath - df.write.parquet(path) - Seq( - "foo,parquet,bar", - "ParQuet,bar,foo", - s"foobar,$dummyReadOnlyFileSourceV2" - ).foreach { fallbackReaders => - withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> fallbackReaders) { - // Reading file should fall back to v1 and succeed. - checkAnswer(spark.read.format(dummyReadOnlyFileSourceV2).load(path), df) - checkAnswer(sql(s"SELECT * FROM parquet.`$path`"), df) - } - } - - withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "foo,bar") { - // Dummy File reader should fail as DISABLED_V2_FILE_DATA_SOURCE_READERS doesn't include it. - checkError( - exception = intercept[SparkException] { - spark.read.format(dummyReadOnlyFileSourceV2).load(path).collect() - }, - condition = "INTERNAL_ERROR", - parameters = Map("message" -> "Dummy file reader")) - } - } - } - - test("Fall back to v1 when reading file with write only FileDataSourceV2") { - val df = spark.range(10).toDF() - withTempPath { file => - val path = file.getCanonicalPath - df.write.parquet(path) - // Fallback reads to V1 - checkAnswer(spark.read.format(dummyWriteOnlyFileSourceV2).load(path), df) - } - } - - test("Always fall back write path to v1") { - val df = spark.range(10).toDF() - withTempPath { path => - // Writes should fall back to v1 and succeed. 
- df.write.format(dummyWriteOnlyFileSourceV2).save(path.getCanonicalPath) - checkAnswer(spark.read.parquet(path.getCanonicalPath), df) - } - } - - test("Fallback Parquet V2 to V1") { - Seq("parquet", classOf[ParquetDataSourceV2].getCanonicalName).foreach { format => - withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> format) { - val commands = ArrayBuffer.empty[(String, LogicalPlan)] - val exceptions = ArrayBuffer.empty[(String, Exception)] - val listener = new QueryExecutionListener { - override def onFailure( - funcName: String, - qe: QueryExecution, - exception: Exception): Unit = { - exceptions += funcName -> exception - } - - override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { - commands += funcName -> qe.logical - } - } - spark.listenerManager.register(listener) - - try { - withTempPath { path => - val inputData = spark.range(10) - inputData.write.format(format).save(path.getCanonicalPath) - sparkContext.listenerBus.waitUntilEmpty() - assert(commands.length == 1) - assert(commands.head._1 == "command") - assert(commands.head._2.isInstanceOf[InsertIntoHadoopFsRelationCommand]) - assert(commands.head._2.asInstanceOf[InsertIntoHadoopFsRelationCommand] - .fileFormat.isInstanceOf[ParquetFileFormat]) - val df = spark.read.format(format).load(path.getCanonicalPath) - checkAnswer(df, inputData.toDF()) - assert( - df.queryExecution.executedPlan.exists(_.isInstanceOf[FileSourceScanExec])) - } - } finally { - spark.listenerManager.unregister(listener) - } - } - } - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2WriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2WriteSuite.scala new file mode 100644 index 0000000000000..b60cf9995b0d1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2WriteSuite.scala @@ -0,0 +1,553 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connector + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.connector.read.ScanBuilder +import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} +import org.apache.spark.sql.execution.{FileSourceScanExec, QueryExecution} +import org.apache.spark.sql.execution.datasources.{FileFormat, InsertIntoHadoopFsRelationCommand} +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat +import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 +import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetDataSourceV2 +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.{CaseInsensitiveStringMap, QueryExecutionListener} + +class DummyReadOnlyFileDataSourceV2 extends FileDataSourceV2 { + + override def fallbackFileFormat: Class[_ <: FileFormat] = 
classOf[ParquetFileFormat] + + override def shortName(): String = "parquet" + + override def getTable(options: CaseInsensitiveStringMap): Table = { + new DummyReadOnlyFileTable + } +} + +class DummyReadOnlyFileTable extends Table with SupportsRead { + override def name(): String = "dummy" + + override def schema(): StructType = StructType(Nil) + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + throw SparkException.internalError("Dummy file reader") + } + + override def capabilities(): java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ, TableCapability.ACCEPT_ANY_SCHEMA) +} + +class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 { + + override def fallbackFileFormat: Class[_ <: FileFormat] = classOf[ParquetFileFormat] + + override def shortName(): String = "parquet" + + override def getTable(options: CaseInsensitiveStringMap): Table = { + new DummyWriteOnlyFileTable + } +} + +class DummyWriteOnlyFileTable extends Table with SupportsWrite { + override def name(): String = "dummy" + + override def schema(): StructType = StructType(Nil) + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = + throw SparkException.internalError("Dummy file writer") + + override def capabilities(): java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_WRITE, TableCapability.ACCEPT_ANY_SCHEMA) +} + +class FileDataSourceV2WriteSuite extends QueryTest with SharedSparkSession { + + private val dummyReadOnlyFileSourceV2 = classOf[DummyReadOnlyFileDataSourceV2].getName + private val dummyWriteOnlyFileSourceV2 = classOf[DummyWriteOnlyFileDataSourceV2].getName + + // Built-in file formats for write testing. Text is excluded + // because it only supports a single string column. 
+ private val fileFormats = Seq("parquet", "orc", "json", "csv") + + override protected def sparkConf: SparkConf = super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "") + + test("Fall back to v1 when writing to file with read only FileDataSourceV2") { + val df = spark.range(10).toDF() + withTempPath { file => + val path = file.getCanonicalPath + // Writing file should fall back to v1 and succeed. + df.write.format(dummyReadOnlyFileSourceV2).save(path) + + // Validate write result with [[ParquetFileFormat]]. + checkAnswer(spark.read.parquet(path), df) + + // Dummy File reader should fail as expected. + checkError( + exception = intercept[SparkException] { + spark.read.format(dummyReadOnlyFileSourceV2).load(path).collect() + }, + condition = "INTERNAL_ERROR", + parameters = Map("message" -> "Dummy file reader")) + } + } + + test("Fall back read path to v1 with configuration USE_V1_SOURCE_LIST") { + val df = spark.range(10).toDF() + withTempPath { file => + val path = file.getCanonicalPath + df.write.parquet(path) + Seq( + "foo,parquet,bar", + "ParQuet,bar,foo", + s"foobar,$dummyReadOnlyFileSourceV2" + ).foreach { fallbackReaders => + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> fallbackReaders) { + // Reading file should fall back to v1 and succeed. + checkAnswer(spark.read.format(dummyReadOnlyFileSourceV2).load(path), df) + checkAnswer(sql(s"SELECT * FROM parquet.`$path`"), df) + } + } + + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "foo,bar") { + // Dummy File reader should fail as USE_V1_SOURCE_LIST doesn't include it.
+ checkError( + exception = intercept[SparkException] { + spark.read.format(dummyReadOnlyFileSourceV2).load(path).collect() + }, + condition = "INTERNAL_ERROR", + parameters = Map("message" -> "Dummy file reader")) + } + } + } + + test("Fall back to v1 when reading file with write only FileDataSourceV2") { + val df = spark.range(10).toDF() + withTempPath { file => + val path = file.getCanonicalPath + df.write.parquet(path) + // Fallback reads to V1 + checkAnswer(spark.read.format(dummyWriteOnlyFileSourceV2).load(path), df) + } + } + + test("Fall back write path to v1 for default save mode") { + val df = spark.range(10).toDF() + withTempPath { path => + // Default mode is ErrorIfExists, which falls back to V1. + df.write.format(dummyWriteOnlyFileSourceV2).save(path.getCanonicalPath) + checkAnswer(spark.read.parquet(path.getCanonicalPath), df) + } + } + + test("Fallback Parquet V2 to V1") { + Seq("parquet", classOf[ParquetDataSourceV2].getCanonicalName).foreach { format => + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> format) { + val commands = ArrayBuffer.empty[(String, LogicalPlan)] + val exceptions = ArrayBuffer.empty[(String, Exception)] + val listener = new QueryExecutionListener { + override def onFailure( + funcName: String, + qe: QueryExecution, + exception: Exception): Unit = { + exceptions += funcName -> exception + } + + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + commands += funcName -> qe.logical + } + } + spark.listenerManager.register(listener) + + try { + withTempPath { path => + val inputData = spark.range(10) + inputData.write.format(format).save(path.getCanonicalPath) + sparkContext.listenerBus.waitUntilEmpty() + assert(commands.length == 1) + assert(commands.head._1 == "command") + assert(commands.head._2.isInstanceOf[InsertIntoHadoopFsRelationCommand]) + assert(commands.head._2.asInstanceOf[InsertIntoHadoopFsRelationCommand] + .fileFormat.isInstanceOf[ParquetFileFormat]) + val df = 
spark.read.format(format).load(path.getCanonicalPath) + checkAnswer(df, inputData.toDF()) + assert( + df.queryExecution.executedPlan.exists(_.isInstanceOf[FileSourceScanExec])) + } + } finally { + spark.listenerManager.unregister(listener) + } + } + } + } + + test("File write for multiple formats") { + fileFormats.foreach { format => + withTempPath { path => + val inputData = spark.range(10).toDF() + inputData.write.option("header", "true").format(format).save(path.getCanonicalPath) + val readBack = spark.read.option("header", "true").schema(inputData.schema) + .format(format).load(path.getCanonicalPath) + checkAnswer(readBack, inputData) + } + } + } + + test("File write produces same results with V1 and V2 reads") { + withTempPath { v1Path => + withTempPath { v2Path => + val inputData = spark.range(100).selectExpr("id", "id * 2 as value") + + // Write via V1 path + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "parquet") { + inputData.write.parquet(v1Path.getCanonicalPath) + } + + // Write via V2 path (default) + inputData.write.parquet(v2Path.getCanonicalPath) + + // Both should produce the same results + val v1Result = spark.read.parquet(v1Path.getCanonicalPath) + val v2Result = spark.read.parquet(v2Path.getCanonicalPath) + checkAnswer(v1Result, v2Result) + } + } + } + + test("Partitioned file write") { + fileFormats.foreach { format => + withTempPath { path => + val inputData = spark.range(20).selectExpr( + "id", "id % 5 as part") + inputData.write.option("header", "true") + .partitionBy("part").format(format).save(path.getCanonicalPath) + val readBack = spark.read.option("header", "true").schema(inputData.schema) + .format(format).load(path.getCanonicalPath) + checkAnswer(readBack, inputData) + + // Verify partition directory structure exists + val partDirs = path.listFiles().filter(_.isDirectory).map(_.getName).sorted + assert(partDirs.exists(_.startsWith("part=")), + s"Expected partition directories for format $format, got: ${partDirs.mkString(", ")}") + } + 
} + } + + test("Partitioned write produces same results with V1 and V2 reads") { + fileFormats.foreach { format => + withTempPath { v1Path => + withTempPath { v2Path => + val inputData = spark.range(50).selectExpr( + "id", "id % 3 as category", "id * 10 as value") + + // Write via V1 path + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> format) { + inputData.write.option("header", "true") + .partitionBy("category").format(format).save(v1Path.getCanonicalPath) + } + + // Write via V2 path (default) + inputData.write.option("header", "true") + .partitionBy("category").format(format).save(v2Path.getCanonicalPath) + + val v1Result = spark.read.option("header", "true").schema(inputData.schema) + .format(format).load(v1Path.getCanonicalPath) + val v2Result = spark.read.option("header", "true").schema(inputData.schema) + .format(format).load(v2Path.getCanonicalPath) + checkAnswer(v1Result, v2Result) + } + } + } + } + + test("Multi-level partitioned write") { + fileFormats.foreach { format => + withTempPath { path => + val schema = "id LONG, year LONG, month LONG" + val inputData = spark.range(30).selectExpr( + "id", "id % 3 as year", "id % 2 as month") + inputData.write.option("header", "true") + .partitionBy("year", "month") + .format(format).save(path.getCanonicalPath) + checkAnswer( + spark.read.option("header", "true") + .schema(schema).format(format) + .load(path.getCanonicalPath), + inputData) + + val yearDirs = path.listFiles() + .filter(_.isDirectory).map(_.getName).sorted + assert(yearDirs.exists(_.startsWith("year=")), + s"Expected year partition dirs for $format") + val firstYearDir = path.listFiles() + .filter(_.isDirectory).head + val monthDirs = firstYearDir.listFiles() + .filter(_.isDirectory).map(_.getName).sorted + assert(monthDirs.exists(_.startsWith("month=")), + s"Expected month partition dirs for $format") + } + } + } + + test("Dynamic partition overwrite") { + fileFormats.foreach { format => + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> format, + 
SQLConf.PARTITION_OVERWRITE_MODE.key -> "dynamic") { + withTempPath { path => + val schema = "id LONG, part LONG" + val initialData = spark.range(9).selectExpr( + "id", "id % 3 as part") + initialData.write.option("header", "true") + .partitionBy("part") + .format(format).save(path.getCanonicalPath) + + val overwriteData = spark.createDataFrame( + Seq((100L, 0L), (101L, 0L))).toDF("id", "part") + overwriteData.write.option("header", "true") + .mode("overwrite").partitionBy("part") + .format(format).save(path.getCanonicalPath) + + val result = spark.read.option("header", "true") + .schema(schema).format(format) + .load(path.getCanonicalPath) + val expected = initialData.filter("part != 0") + .union(overwriteData) + checkAnswer(result, expected) + } + } + } + } + + test("Dynamic partition overwrite produces same results") { + fileFormats.foreach { format => + withTempPath { v1Path => + withTempPath { v2Path => + val schema = "id LONG, part LONG" + val initialData = spark.range(12).selectExpr( + "id", "id % 4 as part") + val overwriteData = spark.createDataFrame( + Seq((200L, 1L), (201L, 1L))).toDF("id", "part") + + Seq(v1Path, v2Path).foreach { p => + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> format, + SQLConf.PARTITION_OVERWRITE_MODE.key -> + "dynamic") { + initialData.write.option("header", "true") + .partitionBy("part").format(format) + .save(p.getCanonicalPath) + overwriteData.write.option("header", "true") + .mode("overwrite").partitionBy("part") + .format(format).save(p.getCanonicalPath) + } + } + + val v1Result = spark.read + .option("header", "true").schema(schema) + .format(format).load(v1Path.getCanonicalPath) + val v2Result = spark.read + .option("header", "true").schema(schema) + .format(format).load(v2Path.getCanonicalPath) + checkAnswer(v1Result, v2Result) + } + } + } + } + + test("DataFrame API write uses V2 path") { + fileFormats.foreach { format => + val writeOpts = if (format == "csv") { + Map("header" -> "true") + } else { + 
Map.empty[String, String] + } + def readBack(p: String): DataFrame = { + val r = spark.read.format(format) + val configured = if (format == "csv") { + r.option("header", "true").schema("id LONG") + } else r + configured.load(p) + } + + // SaveMode.Append to existing path goes via V2 + withTempPath { path => + val data1 = spark.range(5).toDF() + data1.write.options(writeOpts).format(format).save(path.getCanonicalPath) + val data2 = spark.range(5, 10).toDF() + data2.write.options(writeOpts).mode("append") + .format(format).save(path.getCanonicalPath) + checkAnswer(readBack(path.getCanonicalPath), + data1.union(data2)) + } + + // SaveMode.Overwrite goes via V2 + withTempPath { path => + val data1 = spark.range(5).toDF() + data1.write.options(writeOpts).format(format) + .save(path.getCanonicalPath) + val data2 = spark.range(10, 15).toDF() + data2.write.options(writeOpts).mode("overwrite") + .format(format).save(path.getCanonicalPath) + checkAnswer(readBack(path.getCanonicalPath), data2) + } + } + } + + test("DataFrame API partitioned write") { + withTempPath { path => + val data = spark.range(20).selectExpr("id", "id % 4 as part") + data.write.partitionBy("part").parquet(path.getCanonicalPath) + val result = spark.read.parquet(path.getCanonicalPath) + checkAnswer(result, data) + + val partDirs = path.listFiles().filter(_.isDirectory).map(_.getName) + assert(partDirs.exists(_.startsWith("part="))) + } + } + + test("DataFrame API write with compression option") { + withTempPath { path => + val data = spark.range(10).toDF() + data.write.option("compression", "snappy").parquet(path.getCanonicalPath) + checkAnswer(spark.read.parquet(path.getCanonicalPath), data) + } + } + + test("Catalog table INSERT INTO") { + withTable("t") { + sql("CREATE TABLE t (id BIGINT, value BIGINT) USING parquet") + sql("INSERT INTO t VALUES (1, 10), (2, 20), (3, 30)") + checkAnswer(sql("SELECT * FROM t"), + Seq((1L, 10L), (2L, 20L), (3L, 30L)).map(Row.fromTuple)) + } + } + + test("Catalog table 
partitioned INSERT INTO") { + withTable("t") { + sql("CREATE TABLE t (id BIGINT, part BIGINT) USING parquet PARTITIONED BY (part)") + sql("INSERT INTO t VALUES (1, 1), (2, 1), (3, 2), (4, 2)") + checkAnswer(sql("SELECT * FROM t ORDER BY id"), + Seq((1L, 1L), (2L, 1L), (3L, 2L), (4L, 2L)).map(Row.fromTuple)) + } + } + + test("V2 cache invalidation on overwrite") { + fileFormats.foreach { format => + withTempPath { path => + val p = path.getCanonicalPath + spark.range(1000).toDF("id").write.format(format).save(p) + val df = spark.read.format(format).load(p).cache() + assert(df.count() == 1000) + // Overwrite via V2 path should invalidate cache + spark.range(10).toDF("id").write.mode("append").format(format).save(p) + spark.range(10).toDF("id").write + .mode("overwrite").format(format).save(p) + assert(df.count() == 10, + s"Cache should be invalidated after V2 overwrite for $format") + df.unpersist() + } + } + } + + test("V2 cache invalidation on append") { + fileFormats.foreach { format => + withTempPath { path => + val p = path.getCanonicalPath + spark.range(1000).toDF("id").write.format(format).save(p) + val df = spark.read.format(format).load(p).cache() + assert(df.count() == 1000) + // Append via V2 path should invalidate cache + spark.range(10).toDF("id").write.mode("append").format(format).save(p) + assert(df.count() == 1010, + s"Cache should be invalidated after V2 append for $format") + df.unpersist() + } + } + } + + test("Cache invalidation on catalog table overwrite") { + withTable("t") { + sql("CREATE TABLE t (id BIGINT) USING parquet") + sql("INSERT INTO t SELECT id FROM range(100)") + spark.table("t").cache() + assert(spark.table("t").count() == 100) + sql("INSERT OVERWRITE TABLE t SELECT id FROM range(10)") + assert(spark.table("t").count() == 10, + "Cache should be invalidated after catalog table overwrite") + spark.catalog.uncacheTable("t") + } + } + + // SQL path INSERT INTO parquet.`path` requires SupportsCatalogOptions + + test("CTAS") { + 
withTable("t") { + sql("CREATE TABLE t USING parquet AS SELECT id, id * 2 as value FROM range(10)") + checkAnswer( + sql("SELECT count(*) FROM t"), + Seq(Row(10L))) + } + } + + test("Partitioned write to empty directory succeeds") { + fileFormats.foreach { format => + withTempDir { dir => + val schema = "id LONG, k LONG" + val data = spark.range(20).selectExpr( + "id", "id % 4 as k") + data.write.option("header", "true") + .partitionBy("k").mode("overwrite") + .format(format).save(dir.toString) + checkAnswer( + spark.read.option("header", "true") + .schema(schema).format(format) + .load(dir.toString), + data) + } + } + } + + test("Partitioned overwrite to existing directory succeeds") { + fileFormats.foreach { format => + withTempDir { dir => + val schema = "id LONG, k LONG" + val data1 = spark.range(10).selectExpr( + "id", "id % 3 as k") + data1.write.option("header", "true") + .partitionBy("k").mode("overwrite") + .format(format).save(dir.toString) + val data2 = spark.range(10, 20).selectExpr( + "id", "id % 3 as k") + data2.write.option("header", "true") + .partitionBy("k").mode("overwrite") + .format(format).save(dir.toString) + checkAnswer( + spark.read.option("header", "true") + .schema(schema).format(format) + .load(dir.toString), + data2) + } + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 9f5566407e386..1ebe63bdbcfed 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -125,7 +125,6 @@ class HiveSessionStateBuilder( new ResolveDataSource(session) +: new FindDataSourceTable(session) +: new ResolveSQLOnFile(session) +: - new FallBackFileSourceV2(session) +: ResolveEncodersInScalaAgg +: new ResolveSessionCatalog(catalogManager) +: ResolveWriteToStream +: