@@ -27,10 +27,7 @@ import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types.Decimal
 import org.apache.spark.unsafe.types.CalendarInterval
 import org.jetbrains.kotlinx.spark.api.tuples.*
-import scala.Product
-import scala.Tuple1
-import scala.Tuple2
-import scala.Tuple3
+import scala.*
 import java.math.BigDecimal
 import java.sql.Date
 import java.sql.Timestamp
@@ -180,6 +177,42 @@ class EncodingTest : ShouldSpec({
     context("schema") {
         withSpark(props = mapOf("spark.sql.codegen.comments" to true)) {
 
+            should("handle Scala case class datasets") {
+                val caseClasses = listOf(Some(1), Some(2), Some(3))
+                val dataset = caseClasses.toDS()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
+            should("handle Scala case class case class datasets") {
+                val caseClasses = listOf(
+                    Some(Some(1)),
+                    Some(Some(2)),
+                    Some(Some(3)),
+                )
+                val dataset = caseClasses.toDS()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
+            should("handle data class Scala case class datasets") {
+                val caseClasses = listOf(
+                    Some(1) to Some(2),
+                    Some(3) to Some(4),
+                    Some(5) to Some(6),
+                )
+                val dataset = caseClasses.toDS()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
+            should("handle Scala case class data class datasets") {
+                val caseClasses = listOf(
+                    Some(1 to 2),
+                    Some(3 to 4),
+                    Some(5 to 6),
+                )
+                val dataset = caseClasses.toDS()
+                dataset.collectAsList() shouldBe caseClasses
+            }
+
             should("collect data classes with doubles correctly") {
                 val ll1 = LonLat(1.0, 2.0)
                 val ll2 = LonLat(3.0, 4.0)
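For context, a minimal usage sketch (not part of the diff) of the Scala `Some` encoding these tests cover. It assumes only the kotlin-spark-api helpers already used in the test above (`withSpark`, `toDS`, `scala.Some`); the `main` entry point is just for illustration.

```kotlin
import org.jetbrains.kotlinx.spark.api.*  // withSpark, toDS, as used in the test above
import scala.Some

fun main() = withSpark {
    // Scala's Some is a case class; with the encoder support under test,
    // a list of Some values should round-trip through a Dataset unchanged.
    val options = listOf(Some(1), Some(2), Some(3))
    val ds = options.toDS()
    ds.show()                             // inspect the encoded rows
    check(ds.collectAsList() == options)  // expected to hold per the tests above
}
```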