|
19 | 19 | */ |
20 | 20 | package org.jetbrains.kotlinx.spark.api |
21 | 21 |
|
| 22 | +import io.kotest.assertions.throwables.shouldThrowAny |
22 | 23 | import io.kotest.core.spec.style.ShouldSpec |
| 24 | +import io.kotest.matchers.collections.shouldBeIn |
23 | 25 | import io.kotest.matchers.nulls.shouldNotBeNull |
| 26 | +import io.kotest.matchers.shouldBe |
24 | 27 | import io.kotest.matchers.shouldNotBe |
25 | 28 | import io.kotest.matchers.string.shouldContain |
26 | 29 | import io.kotest.matchers.types.shouldBeInstanceOf |
27 | 30 | import jupyter.kotlin.DependsOn |
28 | | -import kotlinx.serialization.decodeFromString |
29 | | -import kotlinx.serialization.encodeToString |
30 | | -import kotlinx.serialization.json.Json |
31 | | -import org.apache.spark.SparkConf |
32 | 31 | import org.apache.spark.api.java.JavaSparkContext |
| 32 | +import org.apache.spark.streaming.Duration |
33 | 33 | import org.intellij.lang.annotations.Language |
34 | 34 | import org.jetbrains.kotlinx.jupyter.EvalRequestData |
35 | 35 | import org.jetbrains.kotlinx.jupyter.ReplForJupyter |
36 | 36 | import org.jetbrains.kotlinx.jupyter.ReplForJupyterImpl |
37 | 37 | import org.jetbrains.kotlinx.jupyter.api.Code |
38 | | -import org.jetbrains.kotlinx.jupyter.api.KotlinKernelHost |
39 | 38 | import org.jetbrains.kotlinx.jupyter.api.MimeTypedResult |
40 | | -import org.jetbrains.kotlinx.jupyter.api.libraries.* |
41 | | -import org.jetbrains.kotlinx.jupyter.dependencies.ResolverConfig |
42 | 39 | import org.jetbrains.kotlinx.jupyter.libraries.EmptyResolutionInfoProvider |
43 | | -import org.jetbrains.kotlinx.jupyter.libraries.LibrariesScanner |
44 | | -import org.jetbrains.kotlinx.jupyter.libraries.LibraryResolver |
45 | | -import org.jetbrains.kotlinx.jupyter.libraries.buildDependenciesInitCode |
46 | 40 | import org.jetbrains.kotlinx.jupyter.repl.EvalResultEx |
47 | 41 | import org.jetbrains.kotlinx.jupyter.testkit.ReplProvider |
48 | | -import org.jetbrains.kotlinx.jupyter.util.NameAcceptanceRule |
49 | 42 | import org.jetbrains.kotlinx.jupyter.util.PatternNameAcceptanceRule |
| 43 | +import org.jetbrains.kotlinx.spark.api.tuples.X |
| 44 | +import org.jetbrains.kotlinx.spark.api.tuples.component1 |
| 45 | +import org.jetbrains.kotlinx.spark.api.tuples.component2 |
| 46 | +import java.util.* |
50 | 47 | import kotlin.script.experimental.jvm.util.classpathFromClassloader |
51 | 48 |
|
52 | 49 | class JupyterTests : ShouldSpec({ |
@@ -235,6 +232,82 @@ class JupyterTests : ShouldSpec({ |
235 | 232 | } |
236 | 233 | }) |
237 | 234 |
|
/**
 * Integration tests for the Spark *Streaming* Jupyter integration.
 *
 * Unlike [JupyterTests] above, the REPL here is configured so that ONLY
 * `SparkStreamingIntegration` is loaded: all other integrations under
 * `org.jetbrains.kotlinx.spark.api.jupyter` are explicitly rejected.
 * As a consequence no `spark` / `sc` instances are predefined in the REPL.
 */
class JupyterStreamingTests : ShouldSpec({
    // REPL factory: embedded kernel with only the streaming integration enabled.
    val replProvider = ReplProvider { classpath ->
        ReplForJupyterImpl(
            resolutionInfoProvider = EmptyResolutionInfoProvider,
            scriptClasspath = classpath,
            isEmbedded = true,
        ).apply {
            eval {
                librariesScanner.addLibrariesFromClassLoader(
                    classLoader = currentClassLoader,
                    host = this,
                    // Deny-all wildcard first, then allow exactly the streaming integration.
                    integrationTypeNameRules = listOf(
                        PatternNameAcceptanceRule(false, "org.jetbrains.kotlinx.spark.api.jupyter.**"),
                        PatternNameAcceptanceRule(true,
                            "org.jetbrains.kotlinx.spark.api.jupyter.SparkStreamingIntegration"),
                    ),
                )
            }
        }
    }

    // Classpath of the current test JVM, handed to the embedded REPL so it can
    // resolve the same dependencies as the surrounding test.
    val currentClassLoader = DependsOn::class.java.classLoader
    val scriptClasspath = classpathFromClassloader(currentClassLoader).orEmpty()

    fun createRepl(): ReplForJupyter = replProvider(scriptClasspath)
    // Runs [action] against a freshly created REPL instance.
    suspend fun withRepl(action: suspend ReplForJupyter.() -> Unit): Unit = createRepl().action()

    context("Jupyter") {
        withRepl {

            // The streaming integration must NOT predefine a `spark` session.
            should("Not have spark instance") {
                shouldThrowAny {
                    @Language("kts")
                    val spark = exec("""spark""")
                    Unit
                }
            }

            // ...nor a `sc` (JavaSparkContext) instance.
            should("Not have sc instance") {
                shouldThrowAny {
                    @Language("kts")
                    val sc = exec("""sc""")
                    Unit
                }
            }

            should("stream") {
                val input = listOf("aaa", "bbb", "aaa", "ccc")
                // Mutable counter shared with executors via a broadcast below.
                val counter = Counter(0)

                // 10 ms batch interval, hard stop after 1 s.
                // NOTE(review): a 1 s timeout with wall-clock batching could be
                // flaky on slow CI machines — confirm this margin is sufficient.
                withSparkStreaming(Duration(10), timeout = 1000) {

                    // Build the broadcast and the RDD queue inside a spark session
                    // attached to this streaming context; `X` creates a Tuple2.
                    val (counterBroadcast, queue) = withSpark(ssc) {
                        spark.broadcast(counter) X LinkedList(listOf(sc.parallelize(input)))
                    }

                    val inputStream = ssc.queueStream(queue)

                    inputStream.foreachRDD { rdd, _ ->
                        withSpark(rdd) {
                            rdd.toDS().forEach {
                                // Every streamed element must come from the input...
                                it shouldBeIn input
                                // ...and is counted via the broadcast counter.
                                // NOTE(review): mutating a broadcast value only works
                                // reliably in local mode — verify if this ever runs
                                // against a real cluster.
                                counterBroadcast.value.value++
                            }
                        }
                    }
                }

                // After the streaming context shut down, every element was seen once.
                counter.value shouldBe input.size
            }

        }
    }
})
| 309 | + |
| 310 | + |
/** Evaluates [code] in this REPL and returns the full (extended) evaluation result. */
private fun ReplForJupyter.execEx(code: Code): EvalResultEx {
    val request = EvalRequestData(code)
    return evalEx(request)
}
239 | 312 |
|
/** Evaluates [code] in this REPL and returns only the rendered value, or `null` if none. */
private fun ReplForJupyter.exec(code: Code): Any? {
    val result = execEx(code)
    return result.renderedValue
}
|
0 commit comments