Skip to content

Commit 3c9612a

Browse files
committed
Add on exit callback
1 parent 0a43cf8 commit 3c9612a

File tree

8 files changed

+145
-3
lines changed

8 files changed

+145
-3
lines changed

js/src/messaging.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,8 @@ export async function parseOutput(
311311
line: string,
312312
onStdout?: (output: OutputMessage) => Promise<any> | any,
313313
onStderr?: (output: OutputMessage) => Promise<any> | any,
314-
onResult?: (data: Result) => Promise<any> | any
314+
onResult?: (data: Result) => Promise<any> | any,
315+
onError?: (error: ExecutionError) => Promise<any> | any
315316
) {
316317
const msg = JSON.parse(line)
317318

@@ -348,6 +349,9 @@ export async function parseOutput(
348349
break
349350
case 'error':
350351
execution.error = new ExecutionError(msg.name, msg.value, msg.traceback)
352+
if (onError) {
353+
await onError(execution.error)
354+
}
351355
break
352356
case 'number_of_executions':
353357
execution.executionCount = msg.execution_count

js/src/sandbox.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import { Sandbox as BaseSandbox, InvalidArgumentError } from 'e2b'
22

3-
import { Result, Execution, OutputMessage, parseOutput, extractError } from './messaging'
3+
import {Result, Execution, OutputMessage, parseOutput, extractError, ExecutionError} from './messaging'
44
import { formatExecutionTimeoutError, formatRequestTimeoutError, readLines } from "./utils";
55
import { JUPYTER_PORT, DEFAULT_TIMEOUT_MS } from './consts'
66
export type Context = {
@@ -25,6 +25,7 @@ export class Sandbox extends BaseSandbox {
2525
* @param opts.onStdout Callback for handling stdout messages
2626
* @param opts.onStderr Callback for handling stderr messages
2727
* @param opts.onResult Callback for handling the final result
28+
* @param opts.onError Callback for handling the `ExecutionError` object
2829
* @param opts.envs Environment variables to set for the execution
2930
* @param opts.timeoutMs Max time to wait for the execution to finish
3031
* @param opts.requestTimeoutMs Max time to wait for the request to finish
@@ -37,6 +38,7 @@ async runCode(
3738
onStdout?: (output: OutputMessage) => (Promise<any> | any),
3839
onStderr?: (output: OutputMessage) => (Promise<any> | any),
3940
onResult?: (data: Result) => (Promise<any> | any),
41+
onError?: (error: ExecutionError) => (Promise<any> | any),
4042
envs?: Record<string, string>,
4143
timeoutMs?: number,
4244
requestTimeoutMs?: number,
@@ -52,6 +54,7 @@ async runCode(
5254
* @param opts.onStdout Callback for handling stdout messages
5355
* @param opts.onStderr Callback for handling stderr messages
5456
* @param opts.onResult Callback for handling the final result
57+
* @param opts.onError Callback for handling the `ExecutionError` object
5558
* @param opts.envs Environment variables to set for the execution
5659
* @param opts.timeoutMs Max time to wait for the execution to finish
5760
* @param opts.requestTimeoutMs Max time to wait for the request to finish
@@ -64,6 +67,7 @@ async runCode(
6467
onStdout?: (output: OutputMessage) => (Promise<any> | any),
6568
onStderr?: (output: OutputMessage) => (Promise<any> | any),
6669
onResult?: (data: Result) => (Promise<any> | any),
70+
onError?: (error: ExecutionError) => (Promise<any> | any),
6771
envs?: Record<string, string>,
6872
timeoutMs?: number,
6973
requestTimeoutMs?: number,
@@ -77,6 +81,7 @@ async runCode(
7781
onStdout?: (output: OutputMessage) => (Promise<any> | any),
7882
onStderr?: (output: OutputMessage) => (Promise<any> | any),
7983
onResult?: (data: Result) => (Promise<any> | any),
84+
onError?: (error: ExecutionError) => (Promise<any> | any),
8085
envs?: Record<string, string>,
8186
timeoutMs?: number,
8287
requestTimeoutMs?: number,
@@ -135,7 +140,7 @@ async runCode(
135140

136141
try {
137142
for await (const chunk of readLines(res.body)) {
138-
await parseOutput(execution, chunk, opts?.onStdout, opts?.onStderr, opts?.onResult)
143+
await parseOutput(execution, chunk, opts?.onStdout, opts?.onStderr, opts?.onResult, opts?.onError)
139144
}
140145
} catch (error) {
141146
throw formatExecutionTimeoutError(error)

js/tests/callbacks.test.ts

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import { expect } from 'vitest'
2+
3+
import { sandboxTest } from './setup'
4+
5+
sandboxTest('callback results', async ({ sandbox }) => {
6+
const results = []
7+
const result = await sandbox.runCode('x =1; x', {
8+
onResult: (result) => results.push(result),
9+
})
10+
11+
expect(results.length).toBe(1)
12+
expect(result.results[0].text).toBe('1')
13+
})
14+
15+
16+
sandboxTest('callback error', async ({ sandbox }) => {
17+
const errors = []
18+
const result = await sandbox.runCode('xyz', {
19+
onError: (error) => errors.push(error),
20+
})
21+
22+
expect(errors.length).toBe(1)
23+
expect(result.error.name).toBe('NameError')
24+
})
25+
26+
sandboxTest('callback stdout', async ({ sandbox }) => {
27+
const stdout = []
28+
const result = await sandbox.runCode('print("hello")', {
29+
onStdout: (out) => stdout.push(out),
30+
})
31+
32+
expect(stdout.length).toBe(1)
33+
expect(result.logs.stdout).toEqual(['hello\n'])
34+
})
35+
36+
sandboxTest('callback stderr', async ({ sandbox }) => {
37+
const stderr = []
38+
const result = await sandbox.runCode('import sys;print("This is an error message", file=sys.stderr)', {
39+
onStderr: (err) => stderr.push(err),
40+
})
41+
42+
expect(stderr.length).toBe(1)
43+
expect(result.logs.stderr).toEqual(['This is an error message\n'])
44+
})

python/e2b_code_interpreter/code_interpreter_async.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818
from e2b_code_interpreter.models import (
1919
Execution,
20+
ExecutionError,
2021
Context,
2122
Result,
2223
aextract_exception,
@@ -54,6 +55,7 @@ async def run_code(
5455
on_stdout: Optional[OutputHandler[OutputMessage]] = None,
5556
on_stderr: Optional[OutputHandler[OutputMessage]] = None,
5657
on_result: Optional[OutputHandler[Result]] = None,
58+
on_error: Optional[OutputHandler[ExecutionError]] = None,
5759
envs: Optional[Dict[str, str]] = None,
5860
timeout: Optional[float] = None,
5961
request_timeout: Optional[float] = None,
@@ -67,6 +69,7 @@ async def run_code(
6769
on_stdout: Optional[OutputHandler[OutputMessage]] = None,
6870
on_stderr: Optional[OutputHandler[OutputMessage]] = None,
6971
on_result: Optional[OutputHandler[Result]] = None,
72+
on_error: Optional[OutputHandler[ExecutionError]] = None,
7073
envs: Optional[Dict[str, str]] = None,
7174
timeout: Optional[float] = None,
7275
request_timeout: Optional[float] = None,
@@ -80,6 +83,7 @@ async def run_code(
8083
on_stdout: Optional[OutputHandler[OutputMessage]] = None,
8184
on_stderr: Optional[OutputHandler[OutputMessage]] = None,
8285
on_result: Optional[OutputHandler[Result]] = None,
86+
on_error: Optional[OutputHandler[ExecutionError]] = None,
8387
envs: Optional[Dict[str, str]] = None,
8488
timeout: Optional[float] = None,
8589
request_timeout: Optional[float] = None,
@@ -94,6 +98,7 @@ async def run_code(
9498
:param on_stdout: Callback for stdout messages
9599
:param on_stderr: Callback for stderr messages
96100
:param on_result: Callback for the `Result` object
101+
:param on_error: Callback for the `ExecutionError` object
97102
:param envs: Environment variables
98103
:param timeout: Max time to wait for the execution to finish
99104
:param request_timeout: Max time to wait for the request to finish
@@ -136,6 +141,7 @@ async def run_code(
136141
on_stdout=on_stdout,
137142
on_stderr=on_stderr,
138143
on_result=on_result,
144+
on_error=on_error,
139145
)
140146

141147
return execution

python/e2b_code_interpreter/code_interpreter_sync.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
DEFAULT_TIMEOUT,
1212
)
1313
from e2b_code_interpreter.models import (
14+
ExecutionError,
1415
Execution,
1516
Context,
1617
Result,
@@ -46,6 +47,7 @@ def run_code(
4647
on_stdout: Optional[OutputHandler[OutputMessage]] = None,
4748
on_stderr: Optional[OutputHandler[OutputMessage]] = None,
4849
on_result: Optional[OutputHandler[Result]] = None,
50+
on_error: Optional[OutputHandler[ExecutionError]] = None,
4951
envs: Optional[Dict[str, str]] = None,
5052
timeout: Optional[float] = None,
5153
request_timeout: Optional[float] = None,
@@ -59,6 +61,7 @@ def run_code(
5961
on_stdout: Optional[OutputHandler[OutputMessage]] = None,
6062
on_stderr: Optional[OutputHandler[OutputMessage]] = None,
6163
on_result: Optional[OutputHandler[Result]] = None,
64+
on_error: Optional[OutputHandler[ExecutionError]] = None,
6265
envs: Optional[Dict[str, str]] = None,
6366
timeout: Optional[float] = None,
6467
request_timeout: Optional[float] = None,
@@ -72,6 +75,7 @@ def run_code(
7275
on_stdout: Optional[OutputHandler[OutputMessage]] = None,
7376
on_stderr: Optional[OutputHandler[OutputMessage]] = None,
7477
on_result: Optional[OutputHandler[Result]] = None,
78+
on_error: Optional[OutputHandler[ExecutionError]] = None,
7579
envs: Optional[Dict[str, str]] = None,
7680
timeout: Optional[float] = None,
7781
request_timeout: Optional[float] = None,
@@ -86,6 +90,7 @@ def run_code(
8690
:param on_stdout: Callback for stdout messages
8791
:param on_stderr: Callback for stderr messages
8892
:param on_result: Callback for the `Result` object
93+
:param on_error: Callback for the `ExecutionError` object
8994
:param envs: Environment variables
9095
:param timeout: Max time to wait for the execution to finish
9196
:param request_timeout: Max time to wait for the request to finish
@@ -128,6 +133,7 @@ def run_code(
128133
on_stdout=on_stdout,
129134
on_stderr=on_stderr,
130135
on_result=on_result,
136+
on_error=on_error,
131137
)
132138

133139
return execution

python/e2b_code_interpreter/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,7 @@ def parse_output(
394394
on_stdout: Optional[OutputHandler[OutputMessage]] = None,
395395
on_stderr: Optional[OutputHandler[OutputMessage]] = None,
396396
on_result: Optional[OutputHandler[Result]] = None,
397+
on_error: Optional[OutputHandler[ExecutionError]] = None,
397398
):
398399
data = json.loads(output)
399400
data_type = data.pop("type")
@@ -413,6 +414,8 @@ def parse_output(
413414
on_stderr(OutputMessage(data["text"], data["timestamp"], True))
414415
elif data_type == "error":
415416
execution.error = ExecutionError(data["name"], data["value"], data["traceback"])
417+
if on_error:
418+
on_error(execution.error)
416419
elif data_type == "number_of_executions":
417420
execution.execution_count = data["execution_count"]
418421

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from e2b_code_interpreter.code_interpreter_async import AsyncSandbox
2+
3+
4+
async def test_resuls(async_sandbox: AsyncSandbox):
5+
results = []
6+
execution = await async_sandbox.run_code(
7+
"x = 1;x", on_result=lambda result: results.append(result)
8+
)
9+
assert len(results) == 1
10+
assert execution.results[0].text == "1"
11+
12+
13+
async def test_error(async_sandbox: AsyncSandbox):
14+
errors = []
15+
execution = await async_sandbox.run_code(
16+
"xyz", on_error=lambda error: errors.append(error)
17+
)
18+
assert len(errors) == 1
19+
assert execution.error.name == "NameError"
20+
21+
22+
async def test_stdout(async_sandbox: AsyncSandbox):
23+
stdout = []
24+
execution = await async_sandbox.run_code(
25+
"print('Hello from e2b')", on_stdout=lambda out: stdout.append(out)
26+
)
27+
assert len(stdout) == 1
28+
assert execution.logs.stdout == ["Hello from e2b\n"]
29+
30+
31+
async def test_stderr(async_sandbox: AsyncSandbox):
32+
stderr = []
33+
execution = await async_sandbox.run_code(
34+
'import sys;print("This is an error message", file=sys.stderr)',
35+
on_stderr=lambda err: stderr.append(err),
36+
)
37+
assert len(stderr) == 1
38+
assert execution.logs.stderr == ["This is an error message\n"]
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from e2b_code_interpreter.code_interpreter_sync import Sandbox
2+
3+
4+
def test_resuls(sandbox: Sandbox):
5+
results = []
6+
execution = sandbox.run_code(
7+
"x = 1;x", on_result=lambda result: results.append(result)
8+
)
9+
assert len(results) == 1
10+
assert execution.results[0].text == "1"
11+
12+
13+
def test_error(sandbox: Sandbox):
14+
errors = []
15+
execution = sandbox.run_code("xyz", on_error=lambda error: errors.append(error))
16+
assert len(errors) == 1
17+
assert execution.error.name == "NameError"
18+
19+
20+
def test_stdout(sandbox: Sandbox):
21+
stdout = []
22+
execution = sandbox.run_code(
23+
"print('Hello from e2b')", on_stdout=lambda out: stdout.append(out)
24+
)
25+
assert len(stdout) == 1
26+
assert execution.logs.stdout == ["Hello from e2b\n"]
27+
28+
29+
def test_stderr(sandbox: Sandbox):
30+
stderr = []
31+
execution = sandbox.run_code(
32+
'import sys;print("This is an error message", file=sys.stderr)',
33+
on_stderr=lambda err: stderr.append(err),
34+
)
35+
assert len(stderr) == 1
36+
assert execution.logs.stderr == ["This is an error message\n"]

0 commit comments

Comments
 (0)