diff --git a/packages/python-sdk/e2b/sandbox_async/commands/command.py b/packages/python-sdk/e2b/sandbox_async/commands/command.py index 32b75fd26b..7096e6fec2 100644 --- a/packages/python-sdk/e2b/sandbox_async/commands/command.py +++ b/packages/python-sdk/e2b/sandbox_async/commands/command.py @@ -318,17 +318,28 @@ async def connect( ) try: + pre_start_events: list[process_pb2.ConnectResponse] = [] start_event = await events.__anext__() - - if not start_event.HasField("event"): - raise SandboxException( - f"Failed to connect to process: expected start event, got {start_event}" - ) + while not start_event.event.HasField("start"): + if not start_event.HasField("event"): + raise SandboxException( + f"Failed to connect to process: expected start event, got {start_event}" + ) + pre_start_events.append(start_event) + start_event = await events.__anext__() + + async def iter_replayed_events(): + for pre_start_event in pre_start_events: + yield pre_start_event + async for event in events: + yield event + + event_source = iter_replayed_events() if pre_start_events else events return AsyncCommandHandle( pid=start_event.event.start.pid, handle_kill=lambda: self.kill(start_event.event.start.pid), - events=events, + events=event_source, on_stdout=on_stdout, on_stderr=on_stderr, ) diff --git a/packages/python-sdk/e2b/sandbox_sync/commands/command.py b/packages/python-sdk/e2b/sandbox_sync/commands/command.py index 512b7d9923..2bea57c2a3 100644 --- a/packages/python-sdk/e2b/sandbox_sync/commands/command.py +++ b/packages/python-sdk/e2b/sandbox_sync/commands/command.py @@ -312,17 +312,27 @@ def connect( ) try: + pre_start_events: list[process_pb2.ConnectResponse] = [] start_event = events.__next__() + while not start_event.event.HasField("start"): + if not start_event.HasField("event"): + raise SandboxException( + f"Failed to connect to process: expected start event, got {start_event}" + ) + pre_start_events.append(start_event) + start_event = events.__next__() - if not start_event.HasField("event"): - raise SandboxException( - f"Failed to connect to process: expected start event, got {start_event}" - ) + def iter_replayed_events(): + for pre_start_event in pre_start_events: + yield pre_start_event + yield from events + + event_source = iter_replayed_events() if pre_start_events else events return CommandHandle( pid=start_event.event.start.pid, handle_kill=lambda: self.kill(start_event.event.start.pid), - events=events, + events=event_source, ) except Exception as e: raise handle_rpc_exception(e) diff --git a/packages/python-sdk/tests/async/sandbox_async/commands/test_cmd_connect_stream.py b/packages/python-sdk/tests/async/sandbox_async/commands/test_cmd_connect_stream.py new file mode 100644 index 0000000000..86e866a199 --- /dev/null +++ b/packages/python-sdk/tests/async/sandbox_async/commands/test_cmd_connect_stream.py @@ -0,0 +1,84 @@ +from typing import Iterator + +from packaging.version import Version + +from e2b.connection_config import ConnectionConfig +from e2b.sandbox_async.commands.command import Commands +from e2b.envd.process import process_pb2 + + +def make_start_event(pid: int) -> process_pb2.ConnectResponse: + return process_pb2.ConnectResponse( + event=process_pb2.ProcessEvent( + start=process_pb2.ProcessEvent.StartEvent(pid=pid), + ) + ) + + +def make_data_event(data: str) -> process_pb2.ConnectResponse: + return process_pb2.ConnectResponse( + event=process_pb2.ProcessEvent( + data=process_pb2.ProcessEvent.DataEvent(stdout=data.encode()), + ) + ) + + +def make_end_event() -> process_pb2.ConnectResponse: + return process_pb2.ConnectResponse( + event=process_pb2.ProcessEvent( + end=process_pb2.ProcessEvent.EndEvent(exit_code=0), + ) + ) + + +class FakeAsyncConnectStream: + def __init__(self, events: list[process_pb2.ConnectResponse]): + self._events: Iterator[process_pb2.ConnectResponse] = iter(events) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self._events) + except StopIteration as exc: + raise StopAsyncIteration from exc + + async def aclose(self): + return None + + +class FakeProcessClient: + def __init__(self, events: list[process_pb2.ConnectResponse]): + self._events = events + + def aconnect(self, *_args, **_kwargs): + return FakeAsyncConnectStream(self._events) + + +async def test_connect_replays_non_start_stdout_events(): + big_chunk = "X" * 8192 + + events = [ + make_data_event(big_chunk), + make_start_event(123), + make_data_event("small\n"), + make_end_event(), + ] + + commands = Commands( + envd_api_url="https://envd.example", + connection_config=ConnectionConfig(api_key="test-key"), + pool=None, + envd_version=Version("0.0.0"), + ) + commands._rpc = FakeProcessClient(events) + + stdout: list[str] = [] + handle = await commands.connect(123, on_stdout=stdout.append) + + result = await handle.wait() + + assert handle.pid == 123 + assert result.stdout == f"{big_chunk}small\n" + assert "".join(stdout) == f"{big_chunk}small\n" diff --git a/packages/python-sdk/tests/sync/sandbox_sync/commands/test_cmd_connect_stream.py b/packages/python-sdk/tests/sync/sandbox_sync/commands/test_cmd_connect_stream.py new file mode 100644 index 0000000000..c7198eeed7 --- /dev/null +++ b/packages/python-sdk/tests/sync/sandbox_sync/commands/test_cmd_connect_stream.py @@ -0,0 +1,76 @@ +from packaging.version import Version + +from e2b.connection_config import ConnectionConfig +from e2b.sandbox_sync.commands.command import Commands +from e2b.envd.process import process_pb2 + + +def make_start_event(pid: int) -> process_pb2.ConnectResponse: + return process_pb2.ConnectResponse( + event=process_pb2.ProcessEvent( + start=process_pb2.ProcessEvent.StartEvent(pid=pid), + ) + ) + + +def make_data_event(data: str) -> process_pb2.ConnectResponse: + return process_pb2.ConnectResponse( + event=process_pb2.ProcessEvent( + data=process_pb2.ProcessEvent.DataEvent(stdout=data.encode()), + ) + ) + + +def make_end_event() -> process_pb2.ConnectResponse: + return process_pb2.ConnectResponse( + event=process_pb2.ProcessEvent( + end=process_pb2.ProcessEvent.EndEvent(exit_code=0), + ) + ) + + +class FakeConnectStream: + def __init__(self, events): + self._events = iter(events) + + def __iter__(self): + return self + + def __next__(self): + return next(self._events) + + def close(self): + return None + + +class FakeProcessClient: + def __init__(self, events): + self._events = events + + def connect(self, *_args, **_kwargs): + return FakeConnectStream(self._events) + + +def test_connect_replays_non_start_stdout_events(): + big_chunk = "X" * 8192 + + events = [ + make_data_event(big_chunk), + make_start_event(123), + make_data_event("small\n"), + make_end_event(), + ] + + commands = Commands( + envd_api_url="https://envd.example", + connection_config=ConnectionConfig(api_key="test-key"), + pool=None, + envd_version=Version("0.0.0"), + ) + commands._rpc = FakeProcessClient(events) + + handle = commands.connect(123) + result = handle.wait() + + assert handle.pid == 123 + assert result.stdout == f"{big_chunk}small\n"