Commit aa11744

Jolanrensen authored and asm0dey committed
feat: adds support for Tuple encoding
Up to this point, there was no way to work with `Tuple`s in the Kotlin API for Apache Spark, which prevented us from 1. mixing Scala and Kotlin code in one project, and 2. calling operations like `select` that return typed tuples. It could also introduce unavoidable performance hits, because we were forcing users to perform explicit Tuple → data class conversions. The cost should be negligible, but we can't really measure it and should therefore give users the choice between the idiomatic Kotlin way and the potentially more performant one. We thank @Jolanrensen for their commitment to the project and the huge effort that went into fixing this issue. Thank you very much!
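For illustration only (an editor's sketch, not part of the commit): assuming the `withSpark` and `dsOf` helpers from this API and `scala.Tuple2` on the classpath, the change enables code like the following.

    import org.apache.spark.sql.Dataset
    import org.jetbrains.kotlinx.spark.api.*
    import scala.Tuple2

    fun main() = withSpark {
        // Tuples now get encoders generated for them directly...
        val ds: Dataset<Tuple2<Int, String>> = dsOf(Tuple2(1, "a"), Tuple2(2, "b"))

        // ...so Scala-style results can be consumed without first converting
        // them into Kotlin data classes.
        ds.show()
    }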
1 parent b18f889 commit aa11744

File tree

9 files changed: +312 −47 lines changed

core/2.4/src/main/scala/org/apache/spark/sql/catalyst/KotlinReflection.scala

Lines changed: 56 additions & 3 deletions
@@ -22,7 +22,6 @@ import java.lang.reflect.Type
 import java.lang.{Iterable => JIterable}
 import java.time.LocalDate
 import java.util.{Iterator => JIterator, List => JList, Map => JMap}
-
 import com.google.common.reflect.TypeToken
 import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
@@ -399,10 +398,38 @@ object KotlinReflection {
           getPath,
           customCollectionCls = Some(predefinedDt.get.cls))

+        case StructType(elementType: Array[StructField]) =>
+          val cls = t.cls
+
+          val arguments = elementType.map { field =>
+            val dataType = field.dataType.asInstanceOf[DataTypeWithClass]
+            val nullable = dataType.nullable
+            val clsName = dataType.cls.getName
+            val fieldName = field.asInstanceOf[KStructField].delegate.name
+            val newPath = addToPath(fieldName)
+
+            deserializerFor(
+              TypeToken.of(dataType.cls),
+              Some(newPath),
+              Some(dataType).filter(_.isInstanceOf[ComplexWrapper])
+            )
+          }
+          val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
+
+          if (path.nonEmpty) {
+            expressions.If(
+              IsNull(getPath),
+              expressions.Literal.create(null, ObjectType(cls)),
+              newInstance
+            )
+          } else {
+            newInstance
+          }

         case _ =>
           throw new UnsupportedOperationException(
-            s"No Encoder found for $typeToken")
+            s"No Encoder found for $typeToken in deserializerFor\n" + path)
       }
     }

@@ -608,8 +635,34 @@ object KotlinReflection {
           case ArrayType(elementType, _) =>
             toCatalystArray(inputObject, TypeToken.of(elementType.asInstanceOf[DataTypeWithClass].cls), Some(elementType.asInstanceOf[DataTypeWithClass]))

+          case StructType(elementType: Array[StructField]) =>
+            val cls = otherTypeWrapper.cls
+            val names = elementType.map(_.name)
+
+            val beanInfo = Introspector.getBeanInfo(cls)
+            val methods = beanInfo.getMethodDescriptors.filter(it => names.contains(it.getName))
+
+            val fields = elementType.map { structField =>
+              val maybeProp = methods.find(it => it.getName == structField.name)
+              if (maybeProp.isEmpty) throw new IllegalArgumentException(s"Field ${structField.name} is not found among available props, which are: ${methods.map(_.getName).mkString(", ")}")
+              val fieldName = structField.name
+              val propClass = structField.dataType.asInstanceOf[DataTypeWithClass].cls
+              val propDt = structField.dataType.asInstanceOf[DataTypeWithClass]
+              val fieldValue = Invoke(
+                inputObject,
+                maybeProp.get.getName,
+                inferExternalType(propClass),
+                returnNullable = propDt.nullable
+              )
+              expressions.Literal(fieldName) :: serializerFor(fieldValue, TypeToken.of(propClass), propDt match { case c: ComplexWrapper => Some(c) case _ => None }) :: Nil
+            }
+            val nonNullOutput = CreateNamedStruct(fields.flatten.seq)
+            val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+            expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+
           case _ =>
-            throw new UnsupportedOperationException(s"No Encoder found for $typeToken.")
+            throw new UnsupportedOperationException(s"No Encoder found for $typeToken in serializerFor. $otherTypeWrapper")

         }
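Editor's note: the serializer branch above encodes a tuple as a named struct whose columns follow Scala's `_1`, `_2`, … convention. A minimal sketch of the observable effect (helper names as in this API; the nullability flags follow the Kotlin type's nullability, per the `schema()` branch later in this commit):

    import org.jetbrains.kotlinx.spark.api.*
    import scala.Tuple2

    fun main() = withSpark {
        // The encoded tuple's schema is a struct with ordinal field names, roughly:
        // root
        //  |-- _1: integer (nullable = false)
        //  |-- _2: string (nullable = false)
        dsOf(Tuple2(1, "a")).printSchema()
    }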

core/3.0/src/main/scala/org/apache/spark/sql/KotlinReflection.scala

Lines changed: 125 additions & 29 deletions
@@ -20,8 +20,6 @@

 package org.apache.spark.sql

-import java.beans.{Introspector, PropertyDescriptor}
-
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
 import org.apache.spark.sql.catalyst.SerializerBuildHelper._
@@ -33,6 +31,8 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection, WalkedTypePath}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

+import java.beans.{Introspector, PropertyDescriptor}
+

 /**
  * A helper trait to create [[org.apache.spark.sql.catalyst.encoders.ExpressionEncoder]]s
@@ -440,11 +440,79 @@ object KotlinReflection extends KotlinReflection {

             UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(t.cls))

+          case StructType(elementType: Array[StructField]) =>
+            val cls = t.cls
+
+            val arguments = elementType.map { field =>
+              val dataType = field.dataType.asInstanceOf[DataTypeWithClass]
+              val nullable = dataType.nullable
+              val clsName = getClassNameFromType(getType(dataType.cls))
+              val newTypePath = walkedTypePath.recordField(clsName, field.name)
+
+              // For tuples, we grab the inner fields by ordinal instead of name.
+              val newPath = deserializerFor(
+                getType(dataType.cls),
+                addToPath(path, field.name, dataType.dt, newTypePath),
+                newTypePath,
+                Some(dataType).filter(_.isInstanceOf[ComplexWrapper])
+              )
+              expressionWithNullSafety(
+                newPath,
+                nullable = nullable,
+                newTypePath
+              )
+            }
+            val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
+
+            org.apache.spark.sql.catalyst.expressions.If(
+              IsNull(path),
+              org.apache.spark.sql.catalyst.expressions.Literal.create(null, ObjectType(cls)),
+              newInstance
+            )
+
           case _ =>
             throw new UnsupportedOperationException(
               s"No Encoder found for $tpe\n" + walkedTypePath)
         }
       }
+
+      case t if definedByConstructorParams(t) =>
+        val params = getConstructorParameters(t)
+
+        val cls = getClassFromType(tpe)
+
+        val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
+          val Schema(dataType, nullable) = schemaFor(fieldType)
+          val clsName = getClassNameFromType(fieldType)
+          val newTypePath = walkedTypePath.recordField(clsName, fieldName)
+
+          // For tuples, we grab the inner fields by ordinal instead of name.
+          val newPath = if (cls.getName startsWith "scala.Tuple") {
+            deserializerFor(
+              fieldType,
+              addToPathOrdinal(path, i, dataType, newTypePath),
+              newTypePath)
+          } else {
+            deserializerFor(
+              fieldType,
+              addToPath(path, fieldName, dataType, newTypePath),
+              newTypePath)
+          }
+          expressionWithNullSafety(
+            newPath,
+            nullable = nullable,
+            newTypePath)
+        }
+
+        val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)
+
+        org.apache.spark.sql.catalyst.expressions.If(
+          IsNull(path),
+          org.apache.spark.sql.catalyst.expressions.Literal.create(null, ObjectType(cls)),
+          newInstance
+        )
+
       case _ =>
         throw new UnsupportedOperationException(
           s"No Encoder found for $tpe\n" + walkedTypePath)
@@ -519,7 +587,7 @@ object KotlinReflection extends KotlinReflection {

   def toCatalystArray(input: Expression, elementType: `Type`, predefinedDt: Option[DataTypeWithClass] = None): Expression = {
     predefinedDt.map(_.dt).getOrElse(dataTypeFor(elementType)) match {
-      case dt:StructType =>
+      case dt: StructType =>
         val clsName = getClassNameFromType(elementType)
         val newPath = walkedTypePath.recordArray(clsName)
         createSerializerForMapObjects(input, ObjectType(predefinedDt.get.cls),
@@ -662,32 +730,6 @@ object KotlinReflection extends KotlinReflection {
             createSerializerForUserDefinedType(inputObject, udt, udtClass)
           //</editor-fold>

-
-      case t if definedByConstructorParams(t) =>
-        if (seenTypeSet.contains(t)) {
-          throw new UnsupportedOperationException(
-            s"cannot have circular references in class, but got the circular reference of class $t")
-        }
-
-        val params = getConstructorParameters(t)
-        val fields = params.map { case (fieldName, fieldType) =>
-          if (javaKeywords.contains(fieldName)) {
-            throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
-              "cannot be used as field name\n" + walkedTypePath)
-          }
-
-          // SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNull
-          // is necessary here. Because for a nullable nested inputObject with struct data
-          // type, e.g. StructType(IntegerType, StringType), it will return nullable=true
-          // for IntegerType without KnownNotNull. And that's what we do not expect.
-          val fieldValue = Invoke(KnownNotNull(inputObject), fieldName, dataTypeFor(fieldType),
-            returnNullable = !fieldType.typeSymbol.asClass.isPrimitive)
-          val clsName = getClassNameFromType(fieldType)
-          val newPath = walkedTypePath.recordField(clsName, fieldName)
-          (fieldName, serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t))
-        }
-        createSerializerForObject(inputObject, fields)
-
       case _ if predefinedDt.isDefined =>
         predefinedDt.get match {
           case dataType: KDataTypeWrapper =>
@@ -735,12 +777,66 @@ object KotlinReflection extends KotlinReflection {
               )
           case ArrayType(elementType, _) =>
             toCatalystArray(inputObject, getType(elementType.asInstanceOf[DataTypeWithClass].cls), Some(elementType.asInstanceOf[DataTypeWithClass]))
+
+          case StructType(elementType: Array[StructField]) =>
+            val cls = otherTypeWrapper.cls
+            val names = elementType.map(_.name)
+
+            val beanInfo = Introspector.getBeanInfo(cls)
+            val methods = beanInfo.getMethodDescriptors.filter(it => names.contains(it.getName))
+
+            val fields = elementType.map { structField =>
+              val maybeProp = methods.find(it => it.getName == structField.name)
+              if (maybeProp.isEmpty) throw new IllegalArgumentException(s"Field ${structField.name} is not found among available props, which are: ${methods.map(_.getName).mkString(", ")}")
+              val fieldName = structField.name
+              val propClass = structField.dataType.asInstanceOf[DataTypeWithClass].cls
+              val propDt = structField.dataType.asInstanceOf[DataTypeWithClass]
+              val fieldValue = Invoke(
+                inputObject,
+                maybeProp.get.getName,
+                inferExternalType(propClass),
+                returnNullable = propDt.nullable
+              )
+              val newPath = walkedTypePath.recordField(propClass.getName, fieldName)
+              (fieldName, serializerFor(fieldValue, getType(propClass), newPath, seenTypeSet, if (propDt.isInstanceOf[ComplexWrapper]) Some(propDt) else None))
+            }
+            createSerializerForObject(inputObject, fields)
+
           case _ =>
             throw new UnsupportedOperationException(
               s"No Encoder found for $tpe\n" + walkedTypePath)

         }
       }
+
+      case t if definedByConstructorParams(t) =>
+        if (seenTypeSet.contains(t)) {
+          throw new UnsupportedOperationException(
+            s"cannot have circular references in class, but got the circular reference of class $t")
+        }
+
+        val params = getConstructorParameters(t)
+        val fields = params.map { case (fieldName, fieldType) =>
+          if (javaKeywords.contains(fieldName)) {
+            throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
+              "cannot be used as field name\n" + walkedTypePath)
+          }
+
+          // SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNull
+          // is necessary here. Because for a nullable nested inputObject with struct data
+          // type, e.g. StructType(IntegerType, StringType), it will return nullable=true
+          // for IntegerType without KnownNotNull. And that's what we do not expect.
+          val fieldValue = Invoke(KnownNotNull(inputObject), fieldName, dataTypeFor(fieldType),
+            returnNullable = !fieldType.typeSymbol.asClass.isPrimitive)
+          val clsName = getClassNameFromType(fieldType)
+          val newPath = walkedTypePath.recordField(clsName, fieldName)
+          (fieldName, serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t))
+        }
+        createSerializerForObject(inputObject, fields)
+
       case _ =>
         throw new UnsupportedOperationException(
           s"No Encoder found for $tpe\n" + walkedTypePath)

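Editor's note: in the `definedByConstructorParams` branch above, fields of classes whose name starts with `scala.Tuple` are resolved via `addToPathOrdinal`, i.e. by position rather than by column name. A hedged sketch of why that matters (the `Person` class is a hypothetical example; `encoder<T>()` is assumed to be this API's encoder helper):

    import org.apache.spark.sql.Dataset
    import org.jetbrains.kotlinx.spark.api.*
    import scala.Tuple2

    data class Person(val name: String, val age: Int)

    fun main() = withSpark {
        val people = dsOf(Person("Alice", 30), Person("Bob", 25))

        // select() produces columns named "name" and "age", yet the rows can be
        // decoded into a Tuple2 because tuple fields match by ordinal, not name.
        val pairs: Dataset<Tuple2<String, Int>> =
            people.select("name", "age").`as`(encoder<Tuple2<String, Int>>())
        pairs.show()
    }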
kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt

Lines changed: 23 additions & 7 deletions
@@ -36,6 +36,7 @@ import org.apache.spark.sql.streaming.GroupStateTimeout
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.*
 import org.jetbrains.kotlinx.spark.extensions.KSparkExtensions
+import scala.*
 import scala.collection.Seq
 import scala.reflect.`ClassTag$`
 import java.beans.PropertyDescriptor
@@ -122,8 +123,6 @@ inline fun <reified T> List<T>.toDS(spark: SparkSession): Dataset<T> =
  * It creates encoder for any given supported type T
  *
  * Supported types are data classes, primitives, and Lists, Maps and Arrays containing them
- * are you here?
- * Pavel??
  * @param T type, supported by Spark
  * @return generated encoder
  */
@@ -141,6 +140,7 @@ fun <T> generateEncoder(type: KType, cls: KClass<*>): Encoder<T> {
 private fun isSupportedClass(cls: KClass<*>): Boolean = cls.isData
         || cls.isSubclassOf(Map::class)
         || cls.isSubclassOf(Iterable::class)
+        || cls.isSubclassOf(Product::class)
         || cls.java.isArray

 @Suppress("UNCHECKED_CAST")
@@ -418,6 +418,20 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
             )
             KDataTypeWrapper(structType, klass.java, true)
         }
+        klass.isSubclassOf(Product::class) -> {
+            val params = type.arguments.mapIndexed { i, it ->
+                "_${i + 1}" to it.type!!
+            }
+
+            val structType = DataTypes.createStructType(
+                params.map { (fieldName, fieldType) ->
+                    val dataType = schema(fieldType, types)
+                    KStructField(fieldName, StructField(fieldName, dataType, fieldType.isMarkedNullable, Metadata.empty()))
+                }.toTypedArray()
+            )
+
+            KComplexTypeWrapper(structType, klass.java, true)
+        }
         else -> throw IllegalArgumentException("$type is unsupported")
     }
 }
@@ -430,6 +444,8 @@ enum class SparkLogLevel {
     ALL, DEBUG, ERROR, FATAL, INFO, OFF, TRACE, WARN
 }

+val timestampDt = `TimestampType$`.`MODULE$`
+val dateDt = `DateType$`.`MODULE$`
 private val knownDataTypes = mapOf(
     Byte::class to DataTypes.ByteType,
     Short::class to DataTypes.ShortType,
@@ -439,10 +455,10 @@ private val knownDataTypes = mapOf(
     Float::class to DataTypes.FloatType,
     Double::class to DataTypes.DoubleType,
     String::class to DataTypes.StringType,
-    LocalDate::class to `DateType$`.`MODULE$`,
-    Date::class to `DateType$`.`MODULE$`,
-    Timestamp::class to `TimestampType$`.`MODULE$`,
-    Instant::class to `TimestampType$`.`MODULE$`
+    LocalDate::class to dateDt,
+    Date::class to dateDt,
+    Timestamp::class to timestampDt,
+    Instant::class to timestampDt
 )
@@ -459,4 +475,4 @@ class Memoize1<in T, out R>(val f: (T) -> R) : (T) -> R {

 private fun <T, R> ((T) -> R).memoize(): (T) -> R = Memoize1(this)

-private val memoizedSchema = { x: KType -> schema(x) }.memoize()
\ No newline at end of file
+private val memoizedSchema = { x: KType -> schema(x) }.memoize()
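Editor's note: the new `Product` branch in `schema()` above derives the field names `_1` … `_N` from the tuple's type arguments. A rough sketch of the resulting mapping (using the public `schema` function from the diff; the printed form is approximate):

    import org.jetbrains.kotlinx.spark.api.*
    import scala.Tuple2
    import kotlin.reflect.typeOf

    @OptIn(ExperimentalStdlibApi::class) // typeOf() was still experimental then
    fun main() {
        // Tuple2<Int, String?> becomes a struct with a non-nullable _1 and a
        // nullable _2, wrapped so the encoder still knows the JVM class is Tuple2.
        val dt = schema(typeOf<Tuple2<Int, String?>>())
        println(dt)
        // approximately: StructType(StructField(_1,IntegerType,false),
        //                           StructField(_2,StringType,true))
    }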

kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Conversions.kt

Lines changed: 20 additions & 1 deletion
@@ -1,4 +1,23 @@
-@file:Suppress("NOTHING_TO_INLINE", "RemoveExplicitTypeArguments")
+/*-
+ * =LICENSE=
+ * Kotlin Spark API: API for Spark 2.4+ (Scala 2.12)
+ * ----------
+ * Copyright (C) 2019 - 2021 JetBrains
+ * ----------
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =LICENSEEND=
+ */
+@file:Suppress("NOTHING_TO_INLINE", "RemoveExplicitTypeArguments", "unused")

 package org.jetbrains.kotlinx.spark.api

