
Commit a841611

Fixed README examples, fixed Scala printlns not showing up the second time in Jupyter cells, and started work on letting streams be interrupted in Jupyter.

1 parent 204ac2b, commit a841611

File tree: 6 files changed, +197 −68 lines

README.md
Lines changed: 1 addition & 1 deletion

@@ -271,7 +271,7 @@ For more information, check the [wiki](https://github.com/JetBrains/kotlin-spark
 
 ## Examples
 
-For more, check out [examples](https://github.com/JetBrains/kotlin-spark-api/tree/master/examples/src/main/kotlin/org/jetbrains/kotlinx/spark/examples) module.
+For more, check out [examples](examples/src/main/kotlin/org/jetbrains/kotlinx/spark/examples) module.
 To get up and running quickly, check out this [tutorial](https://github.com/JetBrains/kotlin-spark-api/wiki/Quick-Start-Guide).
 
 ## Reporting issues/Support

jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/Integration.kt
Lines changed: 51 additions & 35 deletions

@@ -22,6 +22,7 @@ package org.jetbrains.kotlinx.spark.api.jupyter
 import org.apache.spark.api.java.JavaRDDLike
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
+import org.jetbrains.kotlinx.jupyter.api.FieldValue
 import org.jetbrains.kotlinx.jupyter.api.HTML
 import org.jetbrains.kotlinx.jupyter.api.KotlinKernelHost
 import org.jetbrains.kotlinx.jupyter.api.libraries.JupyterIntegration
@@ -33,50 +34,65 @@ abstract class Integration : JupyterIntegration() {
     private val scalaVersion = "2.12.15"
     private val spark3Version = "3.2.1"
 
+    /**
+     * Will be run after importing all dependencies
+     */
     abstract fun KotlinKernelHost.onLoaded()
 
-    override fun Builder.onLoaded() {
+    abstract fun KotlinKernelHost.afterCellExecution(snippetInstance: Any, result: FieldValue)
+
+    open val dependencies: Array<String> = arrayOf(
+        "org.apache.spark:spark-repl_$scalaCompatVersion:$spark3Version",
+        "org.jetbrains.kotlin:kotlin-stdlib-jdk8:$kotlinVersion",
+        "org.jetbrains.kotlin:kotlin-reflect:$kotlinVersion",
+        "org.apache.spark:spark-sql_$scalaCompatVersion:$spark3Version",
+        "org.apache.spark:spark-streaming_$scalaCompatVersion:$spark3Version",
+        "org.apache.spark:spark-mllib_$scalaCompatVersion:$spark3Version",
+        "org.apache.spark:spark-sql_$scalaCompatVersion:$spark3Version",
+        "org.apache.spark:spark-graphx_$scalaCompatVersion:$spark3Version",
+        "org.apache.spark:spark-launcher_$scalaCompatVersion:$spark3Version",
+        "org.apache.spark:spark-catalyst_$scalaCompatVersion:$spark3Version",
+        "org.apache.spark:spark-streaming_$scalaCompatVersion:$spark3Version",
+        "org.apache.spark:spark-core_$scalaCompatVersion:$spark3Version",
+        "org.scala-lang:scala-library:$scalaVersion",
+        "org.scala-lang.modules:scala-xml_$scalaCompatVersion:2.0.1",
+        "org.scala-lang:scala-reflect:$scalaVersion",
+        "org.scala-lang:scala-compiler:$scalaVersion",
+        "commons-io:commons-io:2.11.0",
+    )
 
-        dependencies(
-            "org.apache.spark:spark-repl_$scalaCompatVersion:$spark3Version",
-            "org.jetbrains.kotlin:kotlin-stdlib-jdk8:$kotlinVersion",
-            "org.jetbrains.kotlin:kotlin-reflect:$kotlinVersion",
-            "org.apache.spark:spark-sql_$scalaCompatVersion:$spark3Version",
-            "org.apache.spark:spark-streaming_$scalaCompatVersion:$spark3Version",
-            "org.apache.spark:spark-mllib_$scalaCompatVersion:$spark3Version",
-            "org.apache.spark:spark-sql_$scalaCompatVersion:$spark3Version",
-            "org.apache.spark:spark-graphx_$scalaCompatVersion:$spark3Version",
-            "org.apache.spark:spark-launcher_$scalaCompatVersion:$spark3Version",
-            "org.apache.spark:spark-catalyst_$scalaCompatVersion:$spark3Version",
-            "org.apache.spark:spark-streaming_$scalaCompatVersion:$spark3Version",
-            "org.apache.spark:spark-core_$scalaCompatVersion:$spark3Version",
-            "org.scala-lang:scala-library:$scalaVersion",
-            "org.scala-lang.modules:scala-xml_$scalaCompatVersion:2.0.1",
-            "org.scala-lang:scala-reflect:$scalaVersion",
-            "org.scala-lang:scala-compiler:$scalaVersion",
-            "commons-io:commons-io:2.11.0",
-        )
+    open val imports: Array<String> = arrayOf(
+        "org.jetbrains.kotlinx.spark.api.*",
+        "org.jetbrains.kotlinx.spark.api.tuples.*",
+        *(1..22).map { "scala.Tuple$it" }.toTypedArray(),
+        "org.apache.spark.sql.functions.*",
+        "org.apache.spark.*",
+        "org.apache.spark.sql.*",
+        "org.apache.spark.api.java.*",
+        "scala.collection.Seq",
+        "org.apache.spark.rdd.*",
+        "java.io.Serializable",
+        "org.apache.spark.streaming.api.java.*",
+        "org.apache.spark.streaming.api.*",
+        "org.apache.spark.streaming.*",
+    )
 
-        import(
-            "org.jetbrains.kotlinx.spark.api.*",
-            "org.jetbrains.kotlinx.spark.api.tuples.*",
-            *(1..22).map { "scala.Tuple$it" }.toTypedArray(),
-            "org.apache.spark.sql.functions.*",
-            "org.apache.spark.*",
-            "org.apache.spark.sql.*",
-            "org.apache.spark.api.java.*",
-            "scala.collection.Seq",
-            "org.apache.spark.rdd.*",
-            "java.io.Serializable",
-            "org.apache.spark.streaming.api.java.*",
-            "org.apache.spark.streaming.api.*",
-            "org.apache.spark.streaming.*",
-        )
+    override fun Builder.onLoaded() {
+        dependencies(*dependencies)
+        import(*imports)
 
         onLoaded {
             onLoaded()
         }
 
+        beforeCellExecution {
+            execute("""scala.Console.setOut(System.out)""")
+        }
+
+        afterCellExecution { snippetInstance, result ->
+            afterCellExecution(snippetInstance, result)
+        }
+
         // Render Dataset
         render<Dataset<*>> {
             HTML(it.toHtml())
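
Two things to note in this refactor. First, the dependency and import lists are now open vals, so a notebook integration can extend them instead of re-declaring the whole onLoaded builder; SparkStreamingIntegration below does exactly that. Second, beforeCellExecution re-points scala.Console at the current System.out before every cell, which is what fixes Scala-side printlns disappearing after the first execution. A minimal hypothetical subclass, for illustration only (the class name and the extra import are made up):

import org.jetbrains.kotlinx.jupyter.api.FieldValue
import org.jetbrains.kotlinx.jupyter.api.KotlinKernelHost

// Hypothetical subclass, for illustration: appends one import to the
// inherited list, the same way SparkStreamingIntegration does below.
internal class MyIntegration : Integration() {

    override val imports: Array<String> = super.imports + arrayOf(
        "org.apache.spark.ml.*", // illustrative extra import
    )

    override fun KotlinKernelHost.onLoaded() = Unit

    override fun KotlinKernelHost.afterCellExecution(snippetInstance: Any, result: FieldValue) = Unit
}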

jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkIntegration.kt
Lines changed: 3 additions & 0 deletions

@@ -21,6 +21,7 @@ package org.jetbrains.kotlinx.spark.api.jupyter
 
 
 import org.intellij.lang.annotations.Language
+import org.jetbrains.kotlinx.jupyter.api.FieldValue
 import org.jetbrains.kotlinx.jupyter.api.KotlinKernelHost
 
 /**
@@ -68,4 +69,6 @@ internal class SparkIntegration : Integration() {
             val udf: UDFRegistration get() = spark.udf()""".trimIndent(),
         ).map(::execute)
     }
+
+    override fun KotlinKernelHost.afterCellExecution(snippetInstance: Any, result: FieldValue) = Unit
 }

jupyter/src/main/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/SparkStreamingIntegration.kt
Lines changed: 89 additions & 0 deletions

@@ -33,6 +33,8 @@ import java.io.InputStreamReader
 
 
 import org.apache.spark.*
+import org.apache.spark.streaming.api.java.JavaStreamingContext
+import org.jetbrains.kotlinx.jupyter.api.FieldValue
 import org.jetbrains.kotlinx.jupyter.api.KotlinKernelHost
 import scala.collection.*
 import org.jetbrains.kotlinx.spark.api.SparkSession
@@ -48,13 +50,100 @@ import scala.collection.Iterator as ScalaIterator
 @OptIn(ExperimentalStdlibApi::class)
 internal class SparkStreamingIntegration : Integration() {
 
+    override val imports: Array<String> = super.imports + arrayOf(
+        "org.apache.spark.deploy.SparkHadoopUtil",
+        "org.apache.hadoop.conf.Configuration",
+    )
+
     override fun KotlinKernelHost.onLoaded() {
         val _0 = execute("""%dumpClassesForSpark""")
 
         @Language("kts")
         val _1 = listOf(
+            """
+            val sscCollection = mutableSetOf<JavaStreamingContext>()
+            """.trimIndent(),
+            """
+            @JvmOverloads
+            fun withSparkStreaming(
+                batchDuration: Duration = Durations.seconds(1L),
+                checkpointPath: String? = null,
+                hadoopConf: Configuration = SparkHadoopUtil.get().conf(),
+                createOnError: Boolean = false,
+                props: Map<String, Any> = emptyMap(),
+                master: String = SparkConf().get("spark.master", "local[*]"),
+                appName: String = "Kotlin Spark Sample",
+                timeout: Long = -1L,
+                startStreamingContext: Boolean = true,
+                func: KSparkStreamingSession.() -> Unit,
+            ) {
+                var ssc: JavaStreamingContext? = null
+                try {
+
+                    // will only be set when a new context is created
+                    var kSparkStreamingSession: KSparkStreamingSession? = null
+
+                    val creatingFunc = {
+                        val sc = SparkConf()
+                            .setAppName(appName)
+                            .setMaster(master)
+                            .setAll(
+                                props
+                                    .map { (key, value) -> key X value.toString() }
+                                    .asScalaIterable()
+                            )
+
+                        val ssc1 = JavaStreamingContext(sc, batchDuration)
+                        ssc1.checkpoint(checkpointPath)
+
+                        kSparkStreamingSession = KSparkStreamingSession(ssc1)
+                        func(kSparkStreamingSession!!)
+
+                        ssc1
+                    }
+
+                    ssc = when {
+                        checkpointPath != null ->
+                            JavaStreamingContext.getOrCreate(checkpointPath, creatingFunc, hadoopConf, createOnError)
+
+                        else -> creatingFunc()
+                    }
+
+                    sscCollection += ssc!!
+
+                    if (startStreamingContext) {
+                        ssc!!.start()
+                        kSparkStreamingSession?.invokeRunAfterStart()
+                    }
+                    ssc!!.awaitTerminationOrTimeout(timeout)
+                } finally {
+                    ssc?.stop()
+                    println("stopping ssc")
+                    ssc?.awaitTermination()
+                    println("ssc stopped")
+                    ssc?.let(sscCollection::remove)
+                }
+            }
+            """.trimIndent(),
             """
             println("To start a spark streaming session, simply use `withSparkStreaming { }` inside a cell. To use Spark normally, use `withSpark { }` in a cell, or use `%use spark` to start a Spark session for the whole notebook.")""".trimIndent(),
         ).map(::execute)
     }
+
+    override fun KotlinKernelHost.afterCellExecution(snippetInstance: Any, result: FieldValue) {
+
+        @Language("kts")
+        val _1 = listOf(
+            """
+            while (sscCollection.isNotEmpty())
+                sscCollection.first().let {
+                    it.stop()
+                    sscCollection.remove(it)
+                }
+            """.trimIndent(),
+            """
+            println("afterCellExecution cleanup!")
+            """.trimIndent()
+        ).map(::execute)
+    }
 }
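
The notebook-side helper above mirrors withSparkStreaming from SparkSession.kt but additionally registers every created JavaStreamingContext in sscCollection, and the new afterCellExecution override drains that set, so a finished or interrupted cell no longer leaves a streaming context running. A hypothetical notebook cell using it could look like this (the socket source, host, and port are illustrative assumptions, not part of this commit):

// Hypothetical notebook cell, for illustration only: a small word count over a
// socket stream that gives up after ~10 s, at which point the finally block
// and the afterCellExecution cleanup stop and deregister the context.
withSparkStreaming(batchDuration = Durations.seconds(1L), timeout = 10_000L) {
    val lines = ssc.socketTextStream("localhost", 9999) // illustrative source
    lines.flatMap { it.split(" ").iterator() } // JavaDStream<String> of words
        .countByValue()                        // JavaPairDStream<String, Long>
        .print()                               // dump each batch to stdout
}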

jupyter/src/test/kotlin/org/jetbrains/kotlinx/spark/api/jupyter/JupyterTests.kt
Lines changed: 18 additions & 5 deletions

@@ -30,6 +30,7 @@ import io.kotest.matchers.types.shouldBeInstanceOf
 import jupyter.kotlin.DependsOn
 import org.apache.spark.api.java.JavaSparkContext
 import org.apache.spark.streaming.Duration
+import org.apache.spark.streaming.api.java.JavaStreamingContext
 import org.intellij.lang.annotations.Language
 import org.jetbrains.kotlinx.jupyter.EvalRequestData
 import org.jetbrains.kotlinx.jupyter.ReplForJupyter
@@ -155,16 +156,19 @@ class JupyterTests : ShouldSpec({
         should("render JavaRDDs with custom class") {
 
             @Language("kts")
-            val klass = exec("""
+            val klass = exec(
+                """
                 data class Test(
                     val longFirstName: String,
                     val second: LongArray,
                     val somethingSpecial: Map<Int, String>,
                 ): Serializable
-            """.trimIndent())
+                """.trimIndent()
+            )
 
             @Language("kts")
-            val html = execHtml("""
+            val html = execHtml(
+                """
                 val rdd = sc.parallelize(
                     listOf(
                         Test("aaaaaaaaa", longArrayOf(1L, 100000L, 24L), mapOf(1 to "one", 2 to "two")),
@@ -246,8 +250,10 @@ class JupyterStreamingTests : ShouldSpec({
             host = this,
             integrationTypeNameRules = listOf(
                 PatternNameAcceptanceRule(false, "org.jetbrains.kotlinx.spark.api.jupyter.**"),
-                PatternNameAcceptanceRule(true,
-                    "org.jetbrains.kotlinx.spark.api.jupyter.SparkStreamingIntegration"),
+                PatternNameAcceptanceRule(
+                    true,
+                    "org.jetbrains.kotlinx.spark.api.jupyter.SparkStreamingIntegration"
+                ),
             ),
         )
     }
@@ -263,6 +269,13 @@ class JupyterStreamingTests : ShouldSpec({
     context("Jupyter") {
         withRepl {
 
+            should("Have sscCollection instance") {
+
+                @Language("kts")
+                val sscCollection = exec("""sscCollection""")
+                sscCollection as? MutableSet<JavaStreamingContext> shouldNotBe null
+            }
+
             should("Not have spark instance") {
                 shouldThrowAny {
                     @Language("kts")
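
One caveat in the new sscCollection test: `as? MutableSet<JavaStreamingContext>` is an unchecked cast, so with JVM type erasure it only verifies the raw MutableSet type at runtime, never the element type. A small self-contained illustration (not part of the commit):

import org.apache.spark.streaming.api.java.JavaStreamingContext

// Demonstrates the erasure caveat: the generic argument is erased at runtime,
// so the safe cast succeeds for any MutableSet, whatever its elements are.
fun main() {
    val notContexts: Any = mutableSetOf("definitely not a streaming context")

    @Suppress("UNCHECKED_CAST")
    val cast = notContexts as? MutableSet<JavaStreamingContext>

    println(cast != null) // true: only the raw Set type was checked
}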

kotlin-spark-api/3.2/src/main/kotlin/org/jetbrains/kotlinx/spark/api/SparkSession.kt
Lines changed: 35 additions & 27 deletions

@@ -308,42 +308,50 @@ fun withSparkStreaming(
     startStreamingContext: Boolean = true,
     func: KSparkStreamingSession.() -> Unit,
 ) {
+    var ssc: JavaStreamingContext? = null
+    try {
 
-    // will only be set when a new context is created
-    var kSparkStreamingSession: KSparkStreamingSession? = null
+        // will only be set when a new context is created
+        var kSparkStreamingSession: KSparkStreamingSession? = null
 
-    val creatingFunc = {
-        val sc = SparkConf()
-            .setAppName(appName)
-            .setMaster(master)
-            .setAll(
-                props
-                    .map { (key, value) -> key X value.toString() }
-                    .asScalaIterable()
-            )
+        val creatingFunc = {
+            val sc = SparkConf()
+                .setAppName(appName)
+                .setMaster(master)
+                .setAll(
+                    props
+                        .map { (key, value) -> key X value.toString() }
+                        .asScalaIterable()
+                )
 
-        val ssc = JavaStreamingContext(sc, batchDuration)
-        ssc.checkpoint(checkpointPath)
+            val ssc = JavaStreamingContext(sc, batchDuration)
+            ssc.checkpoint(checkpointPath)
 
-        kSparkStreamingSession = KSparkStreamingSession(ssc)
-        func(kSparkStreamingSession!!)
+            kSparkStreamingSession = KSparkStreamingSession(ssc)
+            func(kSparkStreamingSession!!)
 
-        ssc
-    }
+            ssc
+        }
 
-    val ssc = when {
-        checkpointPath != null ->
-            JavaStreamingContext.getOrCreate(checkpointPath, creatingFunc, hadoopConf, createOnError)
+        ssc = when {
+            checkpointPath != null ->
+                JavaStreamingContext.getOrCreate(checkpointPath, creatingFunc, hadoopConf, createOnError)
 
-        else -> creatingFunc()
-    }
+            else -> creatingFunc()
+        }
 
-    if (startStreamingContext) {
-        ssc.start()
-        kSparkStreamingSession?.invokeRunAfterStart()
+        if (startStreamingContext) {
+            ssc!!.start()
+            kSparkStreamingSession?.invokeRunAfterStart()
+        }
+        ssc!!.awaitTerminationOrTimeout(timeout)
+    } finally {
+        // TODO remove printlns
+        ssc?.stop()
+        println("stopping ssc")
+        ssc?.awaitTermination()
+        println("ssc stopped")
     }
-    ssc.awaitTerminationOrTimeout(timeout)
-    ssc.stop()
 }
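Moving shutdown into a finally block means the context is stopped even when the user block throws or the waiting thread is interrupted, which is the groundwork for the interruptible-streams work named in the commit message. A rough, hypothetical sketch of the intended effect (whether an interrupt actually unblocks awaitTerminationOrTimeout depends on Spark's internals, so treat this as intent, not a guarantee):

// Hypothetical sketch, not part of the commit: interrupting the thread that is
// blocked in withSparkStreaming should end up in the finally block, which
// stops the JavaStreamingContext instead of leaving it running.
val worker = Thread {
    withSparkStreaming(timeout = -1L) { // -1L: wait for termination indefinitely
        // streaming setup would go here
    }
}
worker.start()
Thread.sleep(2_000L) // illustrative delay while the context starts
worker.interrupt()   // finally { ssc?.stop(); ... } then cleans up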
