Skip to content

Commit d3ebaad

Browse files
committed
Simplify kernel reconnect
1 parent 5de2a23 commit d3ebaad

File tree

1 file changed

+16
-25
lines changed

1 file changed

+16
-25
lines changed

js/src/code-interpreter.ts

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -37,22 +37,19 @@ export class CodeInterpreter extends Sandbox {
3737
}
3838

3939
export class JupyterExtension {
40-
private readonly defaultKernelID: Promise<string>
41-
private readonly setDefaultKernelID: (kernelID: string) => void
42-
private connectedKernels: Kernels = {}
43-
private sandbox: CodeInterpreter
44-
45-
constructor(sandbox: CodeInterpreter) {
46-
this.sandbox = sandbox
47-
const { promise, resolve } = createDeferredPromise<string>()
48-
this.defaultKernelID = promise
49-
this.setDefaultKernelID = resolve
40+
private readonly connectedKernels: Kernels = {}
41+
42+
private readonly kernelIDPromise = createDeferredPromise<string>()
43+
private readonly setDefaultKernelID = this.kernelIDPromise.resolve
44+
45+
private get defaultKernelID() {
46+
return this.kernelIDPromise.promise
5047
}
5148

49+
constructor(private sandbox: CodeInterpreter) { }
50+
5251
async connect(timeout?: number) {
53-
return this.startConnectingToDefaultKernel(this.setDefaultKernelID, {
54-
timeout,
55-
})
52+
return this.startConnectingToDefaultKernel(this.setDefaultKernelID, { timeout })
5653
}
5754

5855
/**
@@ -72,24 +69,15 @@ export class JupyterExtension {
7269
onStdout?: (msg: ProcessMessage) => any,
7370
onStderr?: (msg: ProcessMessage) => any
7471
): Promise<Result> {
75-
kernelID = kernelID || (await this.defaultKernelID)
76-
let ws = this.connectedKernels[kernelID]
77-
78-
if (!ws) {
79-
const url = `${this.sandbox.getProtocol(
80-
'ws'
81-
)}://${this.sandbox.getHostname(8888)}/api/kernels/${kernelID}/channels`
82-
ws = new JupyterKernelWebSocket(url)
83-
await ws.connect()
84-
this.connectedKernels[kernelID] = ws
85-
}
72+
kernelID = kernelID || await this.defaultKernelID
73+
const ws = this.connectedKernels[kernelID] || await this.connectToKernelWS(kernelID)
8674

8775
return await ws.sendExecutionMessage(code, onStdout, onStderr)
8876
}
8977

9078
private async startConnectingToDefaultKernel(
9179
resolve: (value: string) => void,
92-
opts?: { timeout?: number }
80+
opts?: { timeout?: number },
9381
) {
9482
const kernelID = (
9583
await this.sandbox.filesystem.read('/root/.jupyter/kernel_id', opts)
@@ -116,6 +104,8 @@ export class JupyterExtension {
116104
const ws = new JupyterKernelWebSocket(url)
117105
await ws.connect()
118106
this.connectedKernels[kernelID] = ws
107+
108+
return ws
119109
}
120110

121111
/**
@@ -241,6 +231,7 @@ export class JupyterExtension {
241231
* Close all the websocket connections to the kernels. It doesn't shutdown the kernels.
242232
*/
243233
async close() {
234+
// TODO: For in check
244235
for (const kernelID in this.connectedKernels) {
245236
this.connectedKernels[kernelID].close()
246237
}

0 commit comments

Comments
 (0)