2020
2121package org .apache .spark .sql
2222
23- import java .beans .{Introspector , PropertyDescriptor }
24-
2523import org .apache .spark .internal .Logging
2624import org .apache .spark .sql .catalyst .DeserializerBuildHelper ._
2725import org .apache .spark .sql .catalyst .SerializerBuildHelper ._
@@ -33,6 +31,8 @@ import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection, WalkedTypePa
3331import org .apache .spark .sql .types ._
3432import org .apache .spark .unsafe .types .{CalendarInterval , UTF8String }
3533
34+ import java .beans .{Introspector , PropertyDescriptor }
35+
3636
3737/**
3838 * A helper trait to create [[org.apache.spark.sql.catalyst.encoders.ExpressionEncoder ]]s
@@ -440,11 +440,79 @@ object KotlinReflection extends KotlinReflection {
440440
441441 UnresolvedMapObjects (mapFunction, path, customCollectionCls = Some (t.cls))
442442
443+ case StructType (elementType : Array [StructField ]) =>
444+ val cls = t.cls
445+
446+ val arguments = elementType.map { field =>
447+ val dataType = field.dataType.asInstanceOf [DataTypeWithClass ]
448+ val nullable = dataType.nullable
449+ val clsName = getClassNameFromType(getType(dataType.cls))
450+ val newTypePath = walkedTypePath.recordField(clsName, field.name)
451+
452+ // Struct fields are always resolved by name here (no ordinal-based lookup).
453+ val newPath = deserializerFor(
454+ getType(dataType.cls),
455+ addToPath(path, field.name, dataType.dt, newTypePath),
456+ newTypePath,
457+ Some (dataType).filter(_.isInstanceOf [ComplexWrapper ])
458+ )
459+ expressionWithNullSafety(
460+ newPath,
461+ nullable = nullable,
462+ newTypePath
463+ )
464+ }
465+ val newInstance = NewInstance (cls, arguments, ObjectType (cls), propagateNull = false )
466+
467+ org.apache.spark.sql.catalyst.expressions.If (
468+ IsNull (path),
469+ org.apache.spark.sql.catalyst.expressions.Literal .create(null , ObjectType (cls)),
470+ newInstance
471+ )
472+
473+
443474 case _ =>
444475 throw new UnsupportedOperationException (
445476 s " No Encoder found for $tpe\n " + walkedTypePath)
446477 }
447478 }
479+
480+ case t if definedByConstructorParams(t) =>
481+ val params = getConstructorParameters(t)
482+
483+ val cls = getClassFromType(tpe)
484+
485+ val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) =>
486+ val Schema (dataType, nullable) = schemaFor(fieldType)
487+ val clsName = getClassNameFromType(fieldType)
488+ val newTypePath = walkedTypePath.recordField(clsName, fieldName)
489+
490+ // For tuples, we grab the inner fields by ordinal instead of by name.
491+ val newPath = if (cls.getName startsWith " scala.Tuple" ) {
492+ deserializerFor(
493+ fieldType,
494+ addToPathOrdinal(path, i, dataType, newTypePath),
495+ newTypePath)
496+ } else {
497+ deserializerFor(
498+ fieldType,
499+ addToPath(path, fieldName, dataType, newTypePath),
500+ newTypePath)
501+ }
502+ expressionWithNullSafety(
503+ newPath,
504+ nullable = nullable,
505+ newTypePath)
506+ }
507+
508+ val newInstance = NewInstance (cls, arguments, ObjectType (cls), propagateNull = false )
509+
510+ org.apache.spark.sql.catalyst.expressions.If (
511+ IsNull (path),
512+ org.apache.spark.sql.catalyst.expressions.Literal .create(null , ObjectType (cls)),
513+ newInstance
514+ )
515+
448516 case _ =>
449517 throw new UnsupportedOperationException (
450518 s " No Encoder found for $tpe\n " + walkedTypePath)
@@ -519,7 +587,7 @@ object KotlinReflection extends KotlinReflection {
519587
520588 def toCatalystArray (input : Expression , elementType : `Type`, predefinedDt : Option [DataTypeWithClass ] = None ): Expression = {
521589 predefinedDt.map(_.dt).getOrElse(dataTypeFor(elementType)) match {
522- case dt: StructType =>
590+ case dt : StructType =>
523591 val clsName = getClassNameFromType(elementType)
524592 val newPath = walkedTypePath.recordArray(clsName)
525593 createSerializerForMapObjects(input, ObjectType (predefinedDt.get.cls),
@@ -662,32 +730,6 @@ object KotlinReflection extends KotlinReflection {
662730 createSerializerForUserDefinedType(inputObject, udt, udtClass)
663731 // </editor-fold>
664732
665-
666- case t if definedByConstructorParams(t) =>
667- if (seenTypeSet.contains(t)) {
668- throw new UnsupportedOperationException (
669- s " cannot have circular references in class, but got the circular reference of class $t" )
670- }
671-
672- val params = getConstructorParameters(t)
673- val fields = params.map { case (fieldName, fieldType) =>
674- if (javaKeywords.contains(fieldName)) {
675- throw new UnsupportedOperationException (s " ` $fieldName` is a reserved keyword and " +
676- " cannot be used as field name\n " + walkedTypePath)
677- }
678-
679- // SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNul
680- // is necessary here. Because for a nullable nested inputObject with struct data
681- // type, e.g. StructType(IntegerType, StringType), it will return nullable=true
682- // for IntegerType without KnownNotNull. And that's what we do not expect to.
683- val fieldValue = Invoke (KnownNotNull (inputObject), fieldName, dataTypeFor(fieldType),
684- returnNullable = ! fieldType.typeSymbol.asClass.isPrimitive)
685- val clsName = getClassNameFromType(fieldType)
686- val newPath = walkedTypePath.recordField(clsName, fieldName)
687- (fieldName, serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t))
688- }
689- createSerializerForObject(inputObject, fields)
690-
691733 case _ if predefinedDt.isDefined =>
692734 predefinedDt.get match {
693735 case dataType : KDataTypeWrapper =>
@@ -735,12 +777,66 @@ object KotlinReflection extends KotlinReflection {
735777 )
736778 case ArrayType (elementType, _) =>
737779 toCatalystArray(inputObject, getType(elementType.asInstanceOf [DataTypeWithClass ].cls), Some (elementType.asInstanceOf [DataTypeWithClass ]))
780+
781+ case StructType (elementType : Array [StructField ]) =>
782+ val cls = otherTypeWrapper.cls
783+ val names = elementType.map(_.name)
784+
785+ val beanInfo = Introspector .getBeanInfo(cls)
786+ val methods = beanInfo.getMethodDescriptors.filter(it => names.contains(it.getName))
787+
788+
789+ val fields = elementType.map { structField =>
790+
791+ val maybeProp = methods.find(it => it.getName == structField.name)
792+ if (maybeProp.isEmpty) throw new IllegalArgumentException (s " Field ${structField.name} is not found among available props, which are: ${methods.map(_.getName).mkString(" , " )}" )
793+ val fieldName = structField.name
794+ val propClass = structField.dataType.asInstanceOf [DataTypeWithClass ].cls
795+ val propDt = structField.dataType.asInstanceOf [DataTypeWithClass ]
796+ val fieldValue = Invoke (
797+ inputObject,
798+ maybeProp.get.getName,
799+ inferExternalType(propClass),
800+ returnNullable = propDt.nullable
801+ )
802+ val newPath = walkedTypePath.recordField(propClass.getName, fieldName)
803+ (fieldName, serializerFor(fieldValue, getType(propClass), newPath, seenTypeSet, if (propDt.isInstanceOf [ComplexWrapper ]) Some (propDt) else None ))
804+
805+ }
806+ createSerializerForObject(inputObject, fields)
807+
738808 case _ =>
739809 throw new UnsupportedOperationException (
740810 s " No Encoder found for $tpe\n " + walkedTypePath)
741811
742812 }
743813 }
814+
815+ case t if definedByConstructorParams(t) =>
816+ if (seenTypeSet.contains(t)) {
817+ throw new UnsupportedOperationException (
818+ s " cannot have circular references in class, but got the circular reference of class $t" )
819+ }
820+
821+ val params = getConstructorParameters(t)
822+ val fields = params.map { case (fieldName, fieldType) =>
823+ if (javaKeywords.contains(fieldName)) {
824+ throw new UnsupportedOperationException (s " ` $fieldName` is a reserved keyword and " +
825+ " cannot be used as field name\n " + walkedTypePath)
826+ }
827+
828+ // SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNull
829+ // is necessary here. Because for a nullable nested inputObject with struct data
830+ // type, e.g. StructType(IntegerType, StringType), it will return nullable=true
831+ // for IntegerType without KnownNotNull, which is not what we expect.
832+ val fieldValue = Invoke (KnownNotNull (inputObject), fieldName, dataTypeFor(fieldType),
833+ returnNullable = ! fieldType.typeSymbol.asClass.isPrimitive)
834+ val clsName = getClassNameFromType(fieldType)
835+ val newPath = walkedTypePath.recordField(clsName, fieldName)
836+ (fieldName, serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t))
837+ }
838+ createSerializerForObject(inputObject, fields)
839+
744840 case _ =>
745841 throw new UnsupportedOperationException (
746842 s " No Encoder found for $tpe\n " + walkedTypePath)
0 commit comments