Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,12 @@ class AddMolecule(Extension):

vis = ZnDraw()
vis.register_job(AddMolecule) # room-scoped (default)
vis.wait()
try:
vis.wait()
except KeyboardInterrupt:
pass
finally:
vis.disconnect()
```

Extensions can be registered as room-scoped (default, visible only in the current room) or global (`vis.register_job(cls, room="@global")`, admin-only, visible in all rooms).
Expand Down
57 changes: 48 additions & 9 deletions docs/source/python-api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,12 @@ Use ``register_job()`` to make an extension available in the UI.

vis = ZnDraw()
vis.register_job(ScaleAtoms) # visible only in vis.room
vis.wait()
try:
vis.wait()
except KeyboardInterrupt:
pass
finally:
vis.disconnect()

**Global (admin-only):**

Expand All @@ -1031,7 +1036,12 @@ non-admin users receive a ``PermissionError`` (HTTP 403).

vis = ZnDraw(url="http://localhost:4567", user="admin@example.com", password="...")
vis.register_job(ScaleAtoms, room=GLOBAL_ROOM) # visible in all rooms
vis.wait()
try:
vis.wait()
except KeyboardInterrupt:
pass
finally:
vis.disconnect()

**Explicit room:**

Expand Down Expand Up @@ -1066,7 +1076,12 @@ without being serialized:

vis = ZnDraw()
vis.register_job(Predict, run_kwargs={"model": model})
vis.wait()
try:
vis.wait()
except KeyboardInterrupt:
pass
finally:
vis.disconnect()

The ``run_kwargs`` dict is stored in the worker process and never sent to the
server. This means values can be non-serializable (torch models, open file
Expand Down Expand Up @@ -1131,18 +1146,37 @@ Worker Lifecycle
^^^^^^^^^^^^^^^^

When you call ``register_job()``, the client connects via Socket.IO and starts
a background worker that claims and executes tasks. Call ``vis.wait()`` to block
until the process is interrupted:
a background worker that claims and executes tasks. Call ``vis.wait()`` to
block the main thread while the worker runs in the background.

``vis.wait()`` delegates to ``socketio.Client.wait()`` and blocks until the
Socket.IO transport disconnects. ``KeyboardInterrupt`` propagates naturally
from the underlying select call, so a ``Ctrl+C``-aware worker script wraps
the call in ``try/except/finally``:

.. code:: python

vis = ZnDraw()
vis.register_job(ExtensionA)
vis.register_job(ExtensionB)
vis.wait() # blocks until Ctrl+C

The worker sends heartbeats to the server. On disconnect, all registered jobs
are cleaned up automatically.
try:
vis.wait()
except KeyboardInterrupt:
pass
finally:
vis.disconnect()

``vis.disconnect()`` sends an HTTP ``DELETE`` that fails any claimed tasks,
removes the worker's job links, and soft-deletes jobs with no remaining
workers. It is idempotent, and ``ZnDraw`` also registers it via ``atexit``,
so the explicit ``finally`` block above is not strictly required: it makes
the shutdown explicit and is the natural place to release any additional
resources held alongside the ``ZnDraw`` client.

If the worker process is killed with ``SIGKILL`` (or crashes before
``disconnect()`` runs), the server's background sweeper will soft-delete
the orphaned ``@global`` jobs after the worker's heartbeat becomes stale
(configurable via ``ZNDRAW_JOBLIB_WORKER_TIMEOUT_SECONDS``).


Providers
Expand All @@ -1162,7 +1196,12 @@ filesystem and the built-in ``LoadFile`` extension in one call:

vis = ZnDraw()
vis.register_fs(fsspec.filesystem("file"), name="local")
vis.wait()
try:
vis.wait()
except KeyboardInterrupt:
pass
finally:
vis.disconnect()

Users can then load files from the UI via the ``LoadFile`` modifier.

Expand Down
19 changes: 18 additions & 1 deletion src/zndraw/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,24 @@ def disconnect(self) -> None:
self.api.close()

def wait(self) -> None:
"""Block until disconnected."""
"""Block until the Socket.IO transport disconnects.

Delegates to :meth:`socketio.Client.wait` via the underlying
``zndraw_socketio`` wrapper. ``KeyboardInterrupt`` propagates
naturally from the underlying ``select`` call, so a
``Ctrl+C``-aware worker script wraps this in ``try/except``::

vis = ZnDraw()
vis.register_job(MyExtension)
try:
vis.wait()
except KeyboardInterrupt:
pass
finally:
vis.disconnect()

Connects the Socket.IO transport lazily if needed.
"""
self._ensure_socket_connected()
self.socket.wait()

Expand Down
7 changes: 6 additions & 1 deletion src/zndraw/client/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,12 @@ def disconnect(self) -> None:
log.debug("Disconnected")

def wait(self) -> None:
"""Block until disconnected."""
"""Block until the Socket.IO transport disconnects.

Thin passthrough to ``socketio.Client.wait`` via the
``zndraw_socketio`` wrapper. Users should prefer
:meth:`zndraw.ZnDraw.wait`, which is the documented public API.
"""
self.tsio.wait()

def _on_connect(self) -> None:
Expand Down
47 changes: 24 additions & 23 deletions src/zndraw_joblib/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import json
import logging
import random
import signal
import threading
import time
import traceback
Expand Down Expand Up @@ -267,31 +266,33 @@ def disconnect(self) -> None:
self._registry.clear()

def wait(self) -> None:
"""Block until ``disconnect()`` is called or a signal is received.
"""Block until the worker has stopped.

When called from the main thread, installs SIGINT/SIGTERM handlers
that trigger ``disconnect()`` and restores originals on exit.
From other threads, simply blocks on the stop event.
"""
is_main = threading.current_thread() is threading.main_thread()

if is_main:
original_sigint = signal.getsignal(signal.SIGINT)
original_sigterm = signal.getsignal(signal.SIGTERM)

def _shutdown(signum: int, _frame: Any) -> None:
logger.info("Received signal %s, shutting down...", signum)
self.disconnect()
Returns when ``disconnect()`` is called explicitly, when the
background heartbeat loop hits a fatal error (e.g. the worker
was deleted on the server), or when the heartbeat loop has
considered the server unreachable for longer than
``max_unreachable_seconds``.

signal.signal(signal.SIGINT, _shutdown)
signal.signal(signal.SIGTERM, _shutdown)
Mirrors the convention of :meth:`socketio.Client.wait`:
``KeyboardInterrupt`` propagates naturally, so a ``Ctrl+C``-aware
worker script wraps this in ``try/except``::

try:
self._stop.wait()
finally:
if is_main:
signal.signal(signal.SIGINT, original_sigint)
signal.signal(signal.SIGTERM, original_sigterm)
manager = JobManager(api)
manager.register(MyExtension)
try:
manager.wait()
except KeyboardInterrupt:
pass
finally:
manager.disconnect()

This method is the only way to block in HTTP-only mode (when
``JobManager`` is constructed without a Socket.IO transport).
When ``ZnDraw`` is the surrounding wrapper, prefer
:meth:`zndraw.ZnDraw.wait` instead.
"""
self._stop.wait()

@property
def worker_id(self) -> UUID | None:
Expand Down
40 changes: 40 additions & 0 deletions tests/zndraw/worker/_sigkill_worker_child.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Subprocess helper for ``test_global_cleanup_e2e``.

Registers a ``@global`` extension named ``SigkillCleanup`` and then blocks
in ``ZnDraw.wait()`` indefinitely. Prints ``READY <worker_id>`` to stdout
once registration completes; the parent test then sends ``SIGKILL`` to
exercise the sweeper cleanup path.

The leading underscore in the filename prevents pytest from collecting
this module as a test. It is only ever executed as a subprocess via
``python _sigkill_worker_child.py <server_url>``.
"""

from __future__ import annotations

import sys
from typing import ClassVar

from zndraw import ZnDraw
from zndraw_joblib.client import Category, Extension


class SigkillCleanup(Extension):
    """No-op modifier used to register a ``@global`` job."""

    # Registered under the modifier category; the script below only
    # registers this class and then blocks until the process is killed.
    category: ClassVar[Category] = Category.MODIFIER

    def run(self, vis, **_kwargs):  # pragma: no cover - never executed
        """Do nothing; this extension exists only to be registered."""
        return None


if __name__ == "__main__":
    # The server URL is taken from the first command-line argument.
    client = ZnDraw(url=sys.argv[1])
    client.jobs.register(SigkillCleanup)

    # stdout is the IPC channel back to the parent test process, so the
    # READY line is flushed immediately rather than left in the buffer.
    ready_line = f"READY {client.jobs.worker_id}"
    print(ready_line, flush=True)  # noqa: T201

    # wait() blocks on the Socket.IO transport. The parent test will
    # SIGKILL this process, which bypasses Python entirely, so no
    # try/except, atexit hook, or explicit disconnect is needed here.
    client.wait()
Loading