Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions packages/python-sdk/e2b/sandbox_async/commands/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,17 +318,28 @@ async def connect(
)

try:
pre_start_events: list[process_pb2.ConnectResponse] = []
start_event = await events.__anext__()

if not start_event.HasField("event"):
raise SandboxException(
f"Failed to connect to process: expected start event, got {start_event}"
)
while not start_event.event.HasField("start"):
if not start_event.HasField("event"):
raise SandboxException(
f"Failed to connect to process: expected start event, got {start_event}"
)
pre_start_events.append(start_event)
start_event = await events.__anext__()

async def iter_replayed_events():
for pre_start_event in pre_start_events:
yield pre_start_event
async for event in events:
yield event

event_source = iter_replayed_events() if pre_start_events else events

return AsyncCommandHandle(
pid=start_event.event.start.pid,
handle_kill=lambda: self.kill(start_event.event.start.pid),
events=events,
events=event_source,
on_stdout=on_stdout,
on_stderr=on_stderr,
)
Expand Down
20 changes: 15 additions & 5 deletions packages/python-sdk/e2b/sandbox_sync/commands/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,17 +312,27 @@ def connect(
)

try:
pre_start_events: list[process_pb2.ConnectResponse] = []
start_event = events.__next__()
while not start_event.event.HasField("start"):
if not start_event.HasField("event"):
raise SandboxException(
f"Failed to connect to process: expected start event, got {start_event}"
)
pre_start_events.append(start_event)
start_event = events.__next__()

if not start_event.HasField("event"):
raise SandboxException(
f"Failed to connect to process: expected start event, got {start_event}"
)
def iter_replayed_events():
for pre_start_event in pre_start_events:
yield pre_start_event
yield from events

event_source = iter_replayed_events() if pre_start_events else events

return CommandHandle(
pid=start_event.event.start.pid,
handle_kill=lambda: self.kill(start_event.event.start.pid),
events=events,
events=event_source,
)
except Exception as e:
raise handle_rpc_exception(e)
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from typing import Iterator

from packaging.version import Version

from e2b.connection_config import ConnectionConfig
from e2b.sandbox_async.commands.command import Commands
from e2b.envd.process import process_pb2


def make_start_event(pid: int) -> process_pb2.ConnectResponse:
return process_pb2.ConnectResponse(
event=process_pb2.ProcessEvent(
start=process_pb2.ProcessEvent.StartEvent(pid=pid),
)
)


def make_data_event(data: str) -> process_pb2.ConnectResponse:
return process_pb2.ConnectResponse(
event=process_pb2.ProcessEvent(
data=process_pb2.ProcessEvent.DataEvent(stdout=data.encode()),
)
)


def make_end_event() -> process_pb2.ConnectResponse:
return process_pb2.ConnectResponse(
event=process_pb2.ProcessEvent(
end=process_pb2.ProcessEvent.EndEvent(exit_code=0),
)
)


class FakeAsyncConnectStream:
def __init__(self, events: list[process_pb2.ConnectResponse]):
self._events: Iterator[process_pb2.ConnectResponse] = iter(events)

def __aiter__(self):
return self

async def __anext__(self):
try:
return next(self._events)
except StopIteration as exc:
raise StopAsyncIteration from exc

async def aclose(self):
return None


class FakeProcessClient:
def __init__(self, events: list[process_pb2.ConnectResponse]):
self._events = events

def aconnect(self, *_args, **_kwargs):
return FakeAsyncConnectStream(self._events)


async def test_connect_replays_non_start_stdout_events():
big_chunk = "X" * 8192

events = [
make_data_event(big_chunk),
make_start_event(123),
make_data_event("small\n"),
make_end_event(),
]

commands = Commands(
envd_api_url="https://envd.example",
connection_config=ConnectionConfig(api_key="test-key"),
pool=None,
envd_version=Version("0.0.0"),
)
commands._rpc = FakeProcessClient(events)

stdout: list[str] = []
handle = await commands.connect(123, on_stdout=stdout.append)

result = await handle.wait()

assert handle.pid == 123
assert result.stdout == f"{big_chunk}small\n"
assert "".join(stdout) == f"{big_chunk}small\n"
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from packaging.version import Version

from e2b.connection_config import ConnectionConfig
from e2b.sandbox_sync.commands.command import Commands
from e2b.envd.process import process_pb2


def make_start_event(pid: int) -> process_pb2.ConnectResponse:
return process_pb2.ConnectResponse(
event=process_pb2.ProcessEvent(
start=process_pb2.ProcessEvent.StartEvent(pid=pid),
)
)


def make_data_event(data: str) -> process_pb2.ConnectResponse:
return process_pb2.ConnectResponse(
event=process_pb2.ProcessEvent(
data=process_pb2.ProcessEvent.DataEvent(stdout=data.encode()),
)
)


def make_end_event() -> process_pb2.ConnectResponse:
return process_pb2.ConnectResponse(
event=process_pb2.ProcessEvent(
end=process_pb2.ProcessEvent.EndEvent(exit_code=0),
)
)


class FakeConnectStream:
def __init__(self, events):
self._events = iter(events)

def __iter__(self):
return self

def __next__(self):
return next(self._events)

def close(self):
return None


class FakeProcessClient:
def __init__(self, events):
self._events = events

def connect(self, *_args, **_kwargs):
return FakeConnectStream(self._events)


def test_connect_replays_non_start_stdout_events():
big_chunk = "X" * 8192

events = [
make_data_event(big_chunk),
make_start_event(123),
make_data_event("small\n"),
make_end_event(),
]

commands = Commands(
envd_api_url="https://envd.example",
connection_config=ConnectionConfig(api_key="test-key"),
pool=None,
envd_version=Version("0.0.0"),
)
commands._rpc = FakeProcessClient(events)

handle = commands.connect(123)
result = handle.wait()

assert handle.pid == 123
assert result.stdout == f"{big_chunk}small\n"