Commit 5c4f344

feat: Better selectTyped behavior
* updated col().`as`() behavior and added single selectTyped() variant
* reverted col().`as`<>() behavior for simplicity. Updated selectTyped to accept any TypedColumn
* updated readme
* removed unnecessary imports
1 parent b7cf42d commit 5c4f344
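
In short: `selectTyped` now accepts any `TypedColumn<out Any, U>`, so columns built with `col(KProperty1)` and columns built with `col("name").`as`<T>()` can be passed to the same call, and a new single-column overload returns a `Dataset<U1>` directly. Below is a minimal sketch of the resulting usage, not taken verbatim from the repository: the `SomeClass` fixture mirrors the tests in this commit, and `withSpark`/`dsOf` are the library's session helpers; variable names are illustrative only.

```kotlin
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions.*
import org.jetbrains.kotlinx.spark.api.*

data class SomeClass(val a: IntArray, val b: Int)

fun main() = withSpark {
    val dataset: Dataset<SomeClass> = dsOf(
        SomeClass(intArrayOf(1, 2, 3), 4),
        SomeClass(intArrayOf(1, 2, 4), 5),
    )

    // New in this commit: a single-column overload that yields Dataset<U1> directly
    val bs: Dataset<Int> = dataset.selectTyped(col(SomeClass::b))

    // TypedColumn<SomeClass, Int> (from a KProperty) and TypedColumn<Any, Int>
    // (from `as`<Int>()) can now be mixed in one selectTyped call
    val pairs: Dataset<Pair<Int, Int>> = dataset.selectTyped(
        col(SomeClass::b),
        col("b").`as`<Int>(),
    )

    pairs.show()
}
```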

File tree

5 files changed: +131 -53 lines changed
  • kotlin-spark-api
    • 2.4/src
      • main/kotlin/org/jetbrains/kotlinx/spark/api
      • test/kotlin/org/jetbrains/kotlinx/spark/api
    • 3.0/src
      • main/kotlin/org/jetbrains/kotlinx/spark/api
      • test/kotlin/org/jetbrains/kotlinx/spark/api


README.md

Lines changed: 3 additions & 0 deletions
@@ -192,6 +192,9 @@ to create `TypedColumn`s and with those a new Dataset from pieces of another usi
 ```kotlin
 val dataset: Dataset<YourClass> = ...
 val newDataset: Dataset<Pair<TypeA, TypeB>> = dataset.selectTyped(col(YourClass::colA), col(YourClass::colB))
+
+// Alternatively, for instance when working with a Dataset<Row>
+val typedDataset: Dataset<Pair<String, Int>> = otherDataset.selectTyped(col("a").`as`<String>(), col("b").`as`<Int>())
 ```
 
 ### Overload resolution ambiguity

kotlin-spark-api/2.4/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt

Lines changed: 56 additions & 20 deletions
@@ -647,12 +647,19 @@ operator fun Column.get(key: Any): Column = getItem(key)
 fun lit(a: Any) = functions.lit(a)
 
 /**
- * Provides a type hint about the expected return value of this column. This information can
+ * Provides a type hint about the expected return value of this column. This information can
  * be used by operations such as `select` on a [Dataset] to automatically convert the
  * results into the correct JVM types.
+ *
+ * ```
+ * val df: Dataset<Row> = ...
+ * val typedColumn: Dataset<Int> = df.selectTyped( col("a").`as`<Int>() )
+ * ```
  */
+@Suppress("UNCHECKED_CAST")
 inline fun <reified T> Column.`as`(): TypedColumn<Any, T> = `as`(encoder<T>())
 
+
 /**
  * Alias for [Dataset.joinWith] which passes "left" argument
  * and respects the fact that in result of left join right relation is nullable

@@ -809,45 +816,74 @@ fun <T> Dataset<T>.showDS(numRows: Int = 20, truncate: Boolean = true) = apply {
 /**
  * Returns a new Dataset by computing the given [Column] expressions for each element.
  */
+@Suppress("UNCHECKED_CAST")
+inline fun <reified T, reified U1> Dataset<T>.selectTyped(
+    c1: TypedColumn<out Any, U1>,
+): Dataset<U1> = select(c1 as TypedColumn<T, U1>)
+
+/**
+ * Returns a new Dataset by computing the given [Column] expressions for each element.
+ */
+@Suppress("UNCHECKED_CAST")
 inline fun <reified T, reified U1, reified U2> Dataset<T>.selectTyped(
-    c1: TypedColumn<T, U1>,
-    c2: TypedColumn<T, U2>,
+    c1: TypedColumn<out Any, U1>,
+    c2: TypedColumn<out Any, U2>,
 ): Dataset<Pair<U1, U2>> =
-    select(c1, c2).map { Pair(it._1(), it._2()) }
+    select(
+        c1 as TypedColumn<T, U1>,
+        c2 as TypedColumn<T, U2>,
+    ).map { Pair(it._1(), it._2()) }
 
 /**
  * Returns a new Dataset by computing the given [Column] expressions for each element.
  */
+@Suppress("UNCHECKED_CAST")
 inline fun <reified T, reified U1, reified U2, reified U3> Dataset<T>.selectTyped(
-    c1: TypedColumn<T, U1>,
-    c2: TypedColumn<T, U2>,
-    c3: TypedColumn<T, U3>,
+    c1: TypedColumn<out Any, U1>,
+    c2: TypedColumn<out Any, U2>,
+    c3: TypedColumn<out Any, U3>,
 ): Dataset<Triple<U1, U2, U3>> =
-    select(c1, c2, c3).map { Triple(it._1(), it._2(), it._3()) }
+    select(
+        c1 as TypedColumn<T, U1>,
+        c2 as TypedColumn<T, U2>,
+        c3 as TypedColumn<T, U3>,
+    ).map { Triple(it._1(), it._2(), it._3()) }
 
 /**
  * Returns a new Dataset by computing the given [Column] expressions for each element.
  */
+@Suppress("UNCHECKED_CAST")
 inline fun <reified T, reified U1, reified U2, reified U3, reified U4> Dataset<T>.selectTyped(
-    c1: TypedColumn<T, U1>,
-    c2: TypedColumn<T, U2>,
-    c3: TypedColumn<T, U3>,
-    c4: TypedColumn<T, U4>,
+    c1: TypedColumn<out Any, U1>,
+    c2: TypedColumn<out Any, U2>,
+    c3: TypedColumn<out Any, U3>,
+    c4: TypedColumn<out Any, U4>,
 ): Dataset<Arity4<U1, U2, U3, U4>> =
-    select(c1, c2, c3, c4).map { Arity4(it._1(), it._2(), it._3(), it._4()) }
+    select(
+        c1 as TypedColumn<T, U1>,
+        c2 as TypedColumn<T, U2>,
+        c3 as TypedColumn<T, U3>,
+        c4 as TypedColumn<T, U4>,
+    ).map { Arity4(it._1(), it._2(), it._3(), it._4()) }
 
 /**
  * Returns a new Dataset by computing the given [Column] expressions for each element.
  */
+@Suppress("UNCHECKED_CAST")
 inline fun <reified T, reified U1, reified U2, reified U3, reified U4, reified U5> Dataset<T>.selectTyped(
-    c1: TypedColumn<T, U1>,
-    c2: TypedColumn<T, U2>,
-    c3: TypedColumn<T, U3>,
-    c4: TypedColumn<T, U4>,
-    c5: TypedColumn<T, U5>,
+    c1: TypedColumn<out Any, U1>,
+    c2: TypedColumn<out Any, U2>,
+    c3: TypedColumn<out Any, U3>,
+    c4: TypedColumn<out Any, U4>,
+    c5: TypedColumn<out Any, U5>,
 ): Dataset<Arity5<U1, U2, U3, U4, U5>> =
-    select(c1, c2, c3, c4, c5).map { Arity5(it._1(), it._2(), it._3(), it._4(), it._5()) }
-
+    select(
+        c1 as TypedColumn<T, U1>,
+        c2 as TypedColumn<T, U2>,
+        c3 as TypedColumn<T, U3>,
+        c4 as TypedColumn<T, U4>,
+        c5 as TypedColumn<T, U5>,
+    ).map { Arity5(it._1(), it._2(), it._3(), it._4(), it._5()) }
 
 @OptIn(ExperimentalStdlibApi::class)
 inline fun <reified T> schema(map: Map<String, KType> = mapOf()) = schema(typeOf<T>(), map)
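
For context on the casts introduced above (the same pattern appears in the 3.0 module below): `col("name").`as`<U>()` produces a `TypedColumn<Any, U>`, while `Dataset<T>.select` expects a `TypedColumn<T, U>`, so the relaxed `TypedColumn<out Any, U>` parameters are narrowed back with an unchecked cast before delegating to `select`. A small sketch of the two column shapes a caller can now pass; variable names are illustrative and the `SomeClass` fixture is assumed to match the tests:

```kotlin
import org.apache.spark.sql.TypedColumn
import org.apache.spark.sql.functions.*
import org.jetbrains.kotlinx.spark.api.*

data class SomeClass(val a: IntArray, val b: Int)

// Both column shapes now satisfy selectTyped's TypedColumn<out Any, Int> parameters
// when called on a Dataset<SomeClass>:
val byProperty: TypedColumn<SomeClass, Int> = col(SomeClass::b) // statically tied to SomeClass
val byName: TypedColumn<Any, Int> = col("b").`as`<Int>()        // tied to Any, hence the cast inside selectTyped
```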

kotlin-spark-api/2.4/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt

Lines changed: 8 additions & 6 deletions
@@ -22,7 +22,6 @@ import ch.tutteli.atrium.api.verbs.expect
 import io.kotest.core.spec.style.ShouldSpec
 import io.kotest.matchers.shouldBe
 import org.apache.spark.sql.Dataset
-import org.apache.spark.sql.TypedColumn
 import org.apache.spark.sql.functions.*
 import org.apache.spark.sql.streaming.GroupState
 import org.apache.spark.sql.streaming.GroupStateTimeout

@@ -339,31 +338,34 @@ class ApiTest : ShouldSpec({
     SomeClass(intArrayOf(1, 2, 4), 5),
 )
 
-val typedColumnA: TypedColumn<Any, IntArray> = dataset.col("a").`as`(encoder())
+val newDS1WithAs: Dataset<Int> = dataset.selectTyped(
+    col("b").`as`<Int>(),
+)
+newDS1WithAs.show()
 
-val newDS2 = dataset.selectTyped(
+val newDS2: Dataset<Pair<Int, Int>> = dataset.selectTyped(
     // col(SomeClass::a), NOTE that this doesn't work on 2.4, returnting a data class with an array in it
     col(SomeClass::b),
     col(SomeClass::b),
 )
 newDS2.show()
 
-val newDS3 = dataset.selectTyped(
+val newDS3: Dataset<Triple<Int, Int, Int>> = dataset.selectTyped(
     col(SomeClass::b),
     col(SomeClass::b),
     col(SomeClass::b),
 )
 newDS3.show()
 
-val newDS4 = dataset.selectTyped(
+val newDS4: Dataset<Arity4<Int, Int, Int, Int>> = dataset.selectTyped(
     col(SomeClass::b),
     col(SomeClass::b),
     col(SomeClass::b),
     col(SomeClass::b),
 )
 newDS4.show()
 
-val newDS5 = dataset.selectTyped(
+val newDS5: Dataset<Arity5<Int, Int, Int, Int, Int>> = dataset.selectTyped(
     col(SomeClass::b),
     col(SomeClass::b),
     col(SomeClass::b),
kotlin-spark-api/3.0/src/main/kotlin/org/jetbrains/kotlinx/spark/api/ApiV1.kt

Lines changed: 56 additions & 21 deletions
@@ -651,10 +651,16 @@ operator fun Column.get(key: Any): Column = getItem(key)
 fun lit(a: Any) = functions.lit(a)
 
 /**
- * Provides a type hint about the expected return value of this column. This information can
+ * Provides a type hint about the expected return value of this column. This information can
  * be used by operations such as `select` on a [Dataset] to automatically convert the
  * results into the correct JVM types.
+ *
+ * ```
+ * val df: Dataset<Row> = ...
+ * val typedColumn: Dataset<Int> = df.selectTyped( col("a").`as`<Int>() )
+ * ```
  */
+@Suppress("UNCHECKED_CAST")
 inline fun <reified T> Column.`as`(): TypedColumn<Any, T> = `as`(encoder<T>())
 
 /**

@@ -776,9 +782,8 @@ inline fun <reified T, reified U> Dataset<T>.col(column: KProperty1<T, U>): Type
  * Returns a [Column] based on the given class attribute, not connected to a dataset.
  * ```kotlin
  * val dataset: Dataset<YourClass> = ...
- * val new: Dataset<Tuple2<TypeOfA, TypeOfB>> = dataset.select( col(YourClass::a), col(YourClass::b) )
+ * val new: Dataset<Pair<TypeOfA, TypeOfB>> = dataset.select( col(YourClass::a), col(YourClass::b) )
  * ```
- * TODO: change example to [Pair]s when merged
  */
 @Suppress("UNCHECKED_CAST")
 inline fun <reified T, reified U> col(column: KProperty1<T, U>): TypedColumn<T, U> =

@@ -813,44 +818,74 @@ fun <T> Dataset<T>.showDS(numRows: Int = 20, truncate: Boolean = true) = apply {
 /**
  * Returns a new Dataset by computing the given [Column] expressions for each element.
  */
+@Suppress("UNCHECKED_CAST")
+inline fun <reified T, reified U1> Dataset<T>.selectTyped(
+    c1: TypedColumn<out Any, U1>,
+): Dataset<U1> = select(c1 as TypedColumn<T, U1>)
+
+/**
+ * Returns a new Dataset by computing the given [Column] expressions for each element.
+ */
+@Suppress("UNCHECKED_CAST")
 inline fun <reified T, reified U1, reified U2> Dataset<T>.selectTyped(
-    c1: TypedColumn<T, U1>,
-    c2: TypedColumn<T, U2>,
+    c1: TypedColumn<out Any, U1>,
+    c2: TypedColumn<out Any, U2>,
 ): Dataset<Pair<U1, U2>> =
-    select(c1, c2).map { Pair(it._1(), it._2()) }
+    select(
+        c1 as TypedColumn<T, U1>,
+        c2 as TypedColumn<T, U2>,
+    ).map { Pair(it._1(), it._2()) }
 
 /**
  * Returns a new Dataset by computing the given [Column] expressions for each element.
  */
+@Suppress("UNCHECKED_CAST")
 inline fun <reified T, reified U1, reified U2, reified U3> Dataset<T>.selectTyped(
-    c1: TypedColumn<T, U1>,
-    c2: TypedColumn<T, U2>,
-    c3: TypedColumn<T, U3>,
+    c1: TypedColumn<out Any, U1>,
+    c2: TypedColumn<out Any, U2>,
+    c3: TypedColumn<out Any, U3>,
 ): Dataset<Triple<U1, U2, U3>> =
-    select(c1, c2, c3).map { Triple(it._1(), it._2(), it._3()) }
+    select(
+        c1 as TypedColumn<T, U1>,
+        c2 as TypedColumn<T, U2>,
+        c3 as TypedColumn<T, U3>,
+    ).map { Triple(it._1(), it._2(), it._3()) }
 
 /**
  * Returns a new Dataset by computing the given [Column] expressions for each element.
  */
+@Suppress("UNCHECKED_CAST")
 inline fun <reified T, reified U1, reified U2, reified U3, reified U4> Dataset<T>.selectTyped(
-    c1: TypedColumn<T, U1>,
-    c2: TypedColumn<T, U2>,
-    c3: TypedColumn<T, U3>,
-    c4: TypedColumn<T, U4>,
+    c1: TypedColumn<out Any, U1>,
+    c2: TypedColumn<out Any, U2>,
+    c3: TypedColumn<out Any, U3>,
+    c4: TypedColumn<out Any, U4>,
 ): Dataset<Arity4<U1, U2, U3, U4>> =
-    select(c1, c2, c3, c4).map { Arity4(it._1(), it._2(), it._3(), it._4()) }
+    select(
+        c1 as TypedColumn<T, U1>,
+        c2 as TypedColumn<T, U2>,
+        c3 as TypedColumn<T, U3>,
+        c4 as TypedColumn<T, U4>,
+    ).map { Arity4(it._1(), it._2(), it._3(), it._4()) }
 
 /**
  * Returns a new Dataset by computing the given [Column] expressions for each element.
  */
+@Suppress("UNCHECKED_CAST")
 inline fun <reified T, reified U1, reified U2, reified U3, reified U4, reified U5> Dataset<T>.selectTyped(
-    c1: TypedColumn<T, U1>,
-    c2: TypedColumn<T, U2>,
-    c3: TypedColumn<T, U3>,
-    c4: TypedColumn<T, U4>,
-    c5: TypedColumn<T, U5>,
+    c1: TypedColumn<out Any, U1>,
+    c2: TypedColumn<out Any, U2>,
+    c3: TypedColumn<out Any, U3>,
+    c4: TypedColumn<out Any, U4>,
+    c5: TypedColumn<out Any, U5>,
 ): Dataset<Arity5<U1, U2, U3, U4, U5>> =
-    select(c1, c2, c3, c4, c5).map { Arity5(it._1(), it._2(), it._3(), it._4(), it._5()) }
+    select(
+        c1 as TypedColumn<T, U1>,
+        c2 as TypedColumn<T, U2>,
+        c3 as TypedColumn<T, U3>,
+        c4 as TypedColumn<T, U4>,
+        c5 as TypedColumn<T, U5>,
+    ).map { Arity5(it._1(), it._2(), it._3(), it._4(), it._5()) }
 
 
 @OptIn(ExperimentalStdlibApi::class)

kotlin-spark-api/3.0/src/test/kotlin/org/jetbrains/kotlinx/spark/api/ApiTest.kt

Lines changed: 8 additions & 6 deletions
@@ -22,7 +22,6 @@ import ch.tutteli.atrium.api.verbs.expect
 import io.kotest.core.spec.style.ShouldSpec
 import io.kotest.matchers.shouldBe
 import org.apache.spark.sql.Dataset
-import org.apache.spark.sql.TypedColumn
 import org.apache.spark.sql.functions.*
 import org.apache.spark.sql.streaming.GroupState
 import org.apache.spark.sql.streaming.GroupStateTimeout

@@ -364,30 +363,33 @@ class ApiTest : ShouldSpec({
     SomeClass(intArrayOf(1, 2, 4), 5),
 )
 
-val typedColumnA: TypedColumn<Any, IntArray> = dataset.col("a").`as`(encoder())
+val newDS1WithAs: Dataset<IntArray> = dataset.selectTyped(
+    col("a").`as`<IntArray>(),
+)
+newDS1WithAs.show()
 
-val newDS2 = dataset.selectTyped(
+val newDS2: Dataset<Pair<IntArray, Int>> = dataset.selectTyped(
     col(SomeClass::a), // NOTE: this only works on 3.0, returning a data class with an array in it
     col(SomeClass::b),
 )
 newDS2.show()
 
-val newDS3 = dataset.selectTyped(
+val newDS3: Dataset<Triple<IntArray, Int, Int>> = dataset.selectTyped(
     col(SomeClass::a),
     col(SomeClass::b),
     col(SomeClass::b),
 )
 newDS3.show()
 
-val newDS4 = dataset.selectTyped(
+val newDS4: Dataset<Arity4<IntArray, Int, Int, Int>> = dataset.selectTyped(
     col(SomeClass::a),
     col(SomeClass::b),
     col(SomeClass::b),
     col(SomeClass::b),
 )
 newDS4.show()
 
-val newDS5 = dataset.selectTyped(
+val newDS5: Dataset<Arity5<IntArray, Int, Int, Int, Int>> = dataset.selectTyped(
     col(SomeClass::a),
     col(SomeClass::b),
     col(SomeClass::b),