Skip to content

Commit 28a8496

Browse files
committed
Apparently we can already create encoders from UDTs out of the box, instead of only serializing them. Added tests and updated the examples.
1 parent 9f2018c commit 28a8496

File tree

4 files changed

+22
-3
lines changed

4 files changed

+22
-3
lines changed

examples/src/main/kotlin/org/jetbrains/kotlinx/spark/examples/MLlib.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ private fun KSparkSession.correlation() {
6161
Vectors.dense(4.0, 5.0, 0.0, 3.0),
6262
Vectors.dense(6.0, 7.0, 0.0, 8.0),
6363
Vectors.sparse(4, intArrayOf(0, 3), doubleArrayOf(9.0, 1.0))
64-
).map(::tupleOf)
64+
)
6565

6666
val df = data.toDF("features")
6767

examples/src/main/kotlin/org/jetbrains/kotlinx/spark/examples/UdtRegistration.kt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,12 @@ fun main() = withSpark {
8383
City("Amsterdam", 1),
8484
City("Breda", 2),
8585
City("Oosterhout", 3),
86-
).map(::tupleOf)
86+
)
8787

88-
val ds = items.toDS()
88+
val ds = items.map(::tupleOf).toDS()
8989
ds.showDS()
90+
91+
// Unlike in Scala, you can also encode UDT-registered types directly into a Dataset!
92+
val ds2 = items.toDS()
93+
ds2.showDS()
9094
}

kotlin-spark-api/src/main/kotlin/org/jetbrains/kotlinx/spark/api/Encoding.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ private fun isSupportedByKotlinClassEncoder(cls: KClass<*>): Boolean =
154154
cls.isSubclassOf(Iterable::class) -> true
155155
cls.isSubclassOf(Product::class) -> true
156156
cls.java.isArray -> true
157+
cls.hasAnnotation<SQLUserDefinedType>() -> true
158+
UDTRegistration.exists(cls.jvmName) -> true
157159
else -> false
158160
}
159161

kotlin-spark-api/src/test/kotlin/org/jetbrains/kotlinx/spark/api/UdtTest.kt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,19 @@ class UdtTest : ShouldSpec({
7373

7474
ds.collectAsList().single() shouldBe input
7575
}
76+
77+
should("Be able to create encoder from UDT too") {
78+
79+
val input = listOf(
80+
City("Amsterdam", 1),
81+
City("Breda", 2),
82+
City("Oosterhout", 3),
83+
)
84+
85+
val ds = input.toDS()
86+
87+
ds.collectAsList() shouldBe input
88+
}
7689
}
7790
}
7891
})

0 commit comments

Comments
 (0)