diff --git a/dimos/agents_deprecated/memory/image_embedding.py b/dimos/agents_deprecated/memory/image_embedding.py index 27e16f1aa8..d6b0967642 100644 --- a/dimos/agents_deprecated/memory/image_embedding.py +++ b/dimos/agents_deprecated/memory/image_embedding.py @@ -63,7 +63,7 @@ def __init__(self, model_name: str = "clip", dimensions: int = 512) -> None: def _initialize_model(self): # type: ignore[no-untyped-def] """Initialize the specified embedding model.""" try: - import onnxruntime as ort # type: ignore[import-untyped] + import onnxruntime as ort # type: ignore[import-untyped,import-not-found] import torch # noqa: F401 from transformers import ( # type: ignore[import-untyped] AutoFeatureExtractor, diff --git a/dimos/core/docker_build.py b/dimos/core/docker_build.py index 7ee90fc5c3..24fd2b3e44 100644 --- a/dimos/core/docker_build.py +++ b/dimos/core/docker_build.py @@ -20,6 +20,7 @@ from __future__ import annotations +import hashlib import subprocess from typing import TYPE_CHECKING @@ -32,10 +33,11 @@ logger = setup_logger() -# Timeout for quick Docker commands +_BUILD_HASH_LABEL = "dimos.build.hash" + DOCKER_CMD_TIMEOUT = 20 -# Sentinel value to detect already-converted Dockerfiles (UUID ensures uniqueness) +# the way of detecting already-converted Dockerfiles (UUID ensures uniqueness) DIMOS_SENTINEL = "DIMOS-MODULE-CONVERSION-427593ae-c6e8-4cf1-9b2d-ee81a420a5dc" # Footer appended to Dockerfiles for DimOS module conversion @@ -53,28 +55,6 @@ """ -def _run(cmd: list[str], *, timeout: float | None = None) -> subprocess.CompletedProcess[str]: - """Run a command and return the result.""" - return subprocess.run(cmd, capture_output=True, text=True, timeout=timeout, check=False) - - -def _run_streaming(cmd: list[str]) -> int: - """Run command and stream output to terminal. 
Returns exit code.""" - result = subprocess.run(cmd, text=True) - return result.returncode - - -def _docker_bin(cfg: DockerModuleConfig) -> str: - """Get docker binary path.""" - return cfg.docker_bin or "docker" - - -def _image_exists(docker_bin: str, image_name: str) -> bool: - """Check if a Docker image exists locally.""" - r = _run([docker_bin, "image", "inspect", image_name], timeout=DOCKER_CMD_TIMEOUT) - return r.returncode == 0 - - def _convert_dockerfile(dockerfile: Path) -> Path: """Append DimOS footer to Dockerfile. Returns path to converted file.""" content = dockerfile.read_text() @@ -85,32 +65,82 @@ def _convert_dockerfile(dockerfile: Path) -> Path: logger.info(f"Converting {dockerfile.name} to DimOS format") - converted = dockerfile.parent / f".{dockerfile.name}.dimos" + converted = dockerfile.parent / f".{dockerfile.name}.ignore" converted.write_text(content.rstrip() + "\n" + DIMOS_FOOTER.lstrip("\n")) return converted +def _compute_build_hash(cfg: DockerModuleConfig) -> str: + """Hash Dockerfile contents and build args.""" + assert cfg.docker_file is not None + digest = hashlib.sha256() + digest.update(cfg.docker_file.read_bytes()) + for key, val in sorted(cfg.docker_build_args.items()): + digest.update(f"{key}={val}".encode()) + for arg in cfg.docker_build_extra_args: + digest.update(arg.encode()) + return digest.hexdigest() + + +def _get_image_build_hash(cfg: DockerModuleConfig) -> str | None: + """Read the build hash label from an existing Docker image.""" + r = subprocess.run( + [ + cfg.docker_bin, + "image", + "inspect", + "-f", + '{{index .Config.Labels "' + _BUILD_HASH_LABEL + '"}}', + cfg.docker_image, + ], + capture_output=True, + text=True, + timeout=DOCKER_CMD_TIMEOUT, + check=False, + ) + if r.returncode != 0: + return None + value = r.stdout.strip() + # docker prints "" when the label is missing + return value if value and value != "" else None + + def build_image(cfg: DockerModuleConfig) -> None: """Build Docker image using footer mode 
conversion.""" if cfg.docker_file is None: raise ValueError("docker_file is required for building Docker images") + + build_hash = _compute_build_hash(cfg) dockerfile = _convert_dockerfile(cfg.docker_file) context = cfg.docker_build_context or cfg.docker_file.parent - cmd = [_docker_bin(cfg), "build", "-t", cfg.docker_image, "-f", str(dockerfile)] + cmd = [cfg.docker_bin, "build", "-t", cfg.docker_image, "-f", str(dockerfile)] + cmd.extend(["--label", f"{_BUILD_HASH_LABEL}={build_hash}"]) for k, v in cfg.docker_build_args.items(): cmd.extend(["--build-arg", f"{k}={v}"]) + cmd.extend(cfg.docker_build_extra_args) cmd.append(str(context)) logger.info(f"Building Docker image: {cfg.docker_image}") - exit_code = _run_streaming(cmd) - if exit_code != 0: - raise RuntimeError(f"Docker build failed with exit code {exit_code}") + # Stream stdout to terminal so the user sees build progress, but capture + # stderr separately so we can include it in the error message on failure. + result = subprocess.run(cmd, text=True, stderr=subprocess.PIPE) + if result.returncode != 0: + raise RuntimeError( + f"Docker build failed with exit code {result.returncode}\nSTDERR:\n{result.stderr}" + ) def image_exists(cfg: DockerModuleConfig) -> bool: """Check if the configured Docker image exists locally.""" - return _image_exists(_docker_bin(cfg), cfg.docker_image) + r = subprocess.run( + [cfg.docker_bin, "image", "inspect", cfg.docker_image], + capture_output=True, + text=True, + timeout=DOCKER_CMD_TIMEOUT, + check=False, + ) + return r.returncode == 0 __all__ = [ diff --git a/dimos/core/docker_runner.py b/dimos/core/docker_runner.py index dcb75fbdee..fb5770325b 100644 --- a/dimos/core/docker_runner.py +++ b/dimos/core/docker_runner.py @@ -18,16 +18,14 @@ from dataclasses import field import importlib import json -import os import signal import subprocess import threading import time from typing import TYPE_CHECKING, Any -from dimos.core.docker_build import build_image, image_exists -from 
dimos.core.module import Module, ModuleConfig -from dimos.core.rpc_client import RpcCall +from dimos.core.module import ModuleConfig +from dimos.core.rpc_client import ModuleProxyProtocol, RpcCall from dimos.protocol.rpc.pubsubrpc import LCMRPC from dimos.utils.logging_config import setup_logger from dimos.visualization.rerun.bridge import RERUN_GRPC_PORT, RERUN_WEB_PORT @@ -36,13 +34,15 @@ from collections.abc import Callable from pathlib import Path + from dimos.core.module import Module + logger = setup_logger() DOCKER_RUN_TIMEOUT = 120 # Timeout for `docker run` command execution +DOCKER_PULL_TIMEOUT_DEFAULT = None # No timeout for `docker pull` (images can be large) DOCKER_CMD_TIMEOUT = 20 # Timeout for quick Docker commands (inspect, rm, logs) DOCKER_STATUS_TIMEOUT = 10 # Timeout for container status checks DOCKER_STOP_TIMEOUT = 30 # Timeout for `docker stop` command (graceful shutdown) -RPC_READY_TIMEOUT = 3.0 # Timeout for RPC readiness probe during container startup LOG_TAIL_LINES = 200 # Number of log lines to include in error messages @@ -52,6 +52,8 @@ class DockerModuleConfig(ModuleConfig): For advanced Docker options not listed here, use docker_extra_args. 
Example: docker_extra_args=["--cap-add=SYS_ADMIN", "--read-only"] + + NOTE: a DockerModule will rebuild automatically if the Dockerfile or build args change """ # Build / image @@ -59,6 +61,7 @@ class DockerModuleConfig(ModuleConfig): docker_file: Path | None = None # Required on host for building, not needed in container docker_build_context: Path | None = None docker_build_args: dict[str, str] = field(default_factory=dict) + docker_build_extra_args: list[str] = field(default_factory=list) # Extra args for docker build # Identity docker_container_name: str | None = None @@ -72,9 +75,9 @@ class DockerModuleConfig(ModuleConfig): ) # (host, container, proto) # Runtime resources - docker_gpus: str | None = "all" - docker_shm_size: str = "2g" - docker_restart_policy: str = "on-failure:3" + docker_gpus: str | None = None + docker_shm_size: str = "4g" + docker_restart_policy: str = "no" # Env + volumes + devices docker_env_files: list[str] = field(default_factory=list) @@ -93,10 +96,14 @@ class DockerModuleConfig(ModuleConfig): docker_command: list[str] | None = None docker_extra_args: list[str] = field(default_factory=list) - # Startup readiness + # Timeouts + docker_pull_timeout: float | None = DOCKER_PULL_TIMEOUT_DEFAULT docker_startup_timeout: float = 120.0 docker_poll_interval: float = 1.0 + # Reconnect to a running container instead of restarting it + docker_reconnect_container: bool = False + # Advanced docker_bin: str = "docker" @@ -104,7 +111,11 @@ class DockerModuleConfig(ModuleConfig): def is_docker_module(module_class: type) -> bool: """Check if a module class should run in Docker based on its default_config.""" default_config = getattr(module_class, "default_config", None) - return default_config is not None and issubclass(default_config, DockerModuleConfig) + return ( + default_config is not None + and isinstance(default_config, type) + and issubclass(default_config, DockerModuleConfig) + ) # Docker helpers @@ -115,25 +126,20 @@ def _run(cmd: list[str], *, 
timeout: float | None = None) -> subprocess.Complete return subprocess.run(cmd, capture_output=True, text=True, timeout=timeout, check=False) -def _docker_bin(cfg: DockerModuleConfig) -> str: - """Get docker binary path, defaulting to 'docker' if empty/None.""" - return cfg.docker_bin or "docker" - - def _remove_container(cfg: DockerModuleConfig, name: str) -> None: - _run([_docker_bin(cfg), "rm", "-f", name], timeout=DOCKER_CMD_TIMEOUT) + _run([cfg.docker_bin, "rm", "-f", name], timeout=DOCKER_CMD_TIMEOUT) def _is_container_running(cfg: DockerModuleConfig, name: str) -> bool: r = _run( - [_docker_bin(cfg), "inspect", "-f", "{{.State.Running}}", name], + [cfg.docker_bin, "inspect", "-f", "{{.State.Running}}", name], timeout=DOCKER_STATUS_TIMEOUT, ) return r.returncode == 0 and r.stdout.strip() == "true" def _tail_logs(cfg: DockerModuleConfig, name: str, n: int = LOG_TAIL_LINES) -> str: - r = _run([_docker_bin(cfg), "logs", "--tail", str(n), name], timeout=DOCKER_CMD_TIMEOUT) + r = _run([cfg.docker_bin, "logs", "--tail", str(n), name], timeout=DOCKER_CMD_TIMEOUT) out = (r.stdout or "").rstrip() err = (r.stderr or "").rstrip() return out + ("\n" + err if err else "") @@ -156,143 +162,211 @@ def _extract_module_config(cfg: DockerModuleConfig) -> dict[str, Any]: # Host-side Docker-backed Module handle -class DockerModule: +class DockerModule(ModuleProxyProtocol): """ Host-side handle for a module running inside Docker. Lifecycle: - - start(): launches container, waits for module ready via RPC - - stop(): stops container - - __getattr__: exposes RpcCall for @rpc methods on remote module + - start(): builds the image if needed, launches the container, waits for readiness, calls the remote module's start() RPC (after streams are wired) + - stop(): stops the container and cleans up Communication: All RPC happens via LCM multicast (requires --network=host). 
""" + config: DockerModuleConfig + def __init__(self, module_class: type[Module], *args: Any, **kwargs: Any) -> None: - # Config + from dimos.core.docker_build import ( + _compute_build_hash, + _get_image_build_hash, + build_image, + image_exists, + ) + + # g (GlobalConfig) is passed by deploy pipeline but handled by the base config + kwargs.pop("g", None) + config_class = getattr(module_class, "default_config", DockerModuleConfig) + if not issubclass(config_class, DockerModuleConfig): + raise TypeError( + f"{module_class.__name__}.default_config must be a DockerModuleConfig subclass, " + f"got {config_class.__name__}" + ) config = config_class(**kwargs) - # Module info self._module_class = module_class - self._config = config + self.config = config self._args = args self._kwargs = kwargs - self._running = False + self._running = threading.Event() self.remote_name = module_class.__name__ + # Derive container name from image + class name: "my-registry/foo:v2" → "dimos_myclass_foo_v2" + image_ref = config.docker_image.rsplit("/", 1)[-1] self._container_name = ( config.docker_container_name - or f"dimos_{module_class.__name__.lower()}_{os.getpid()}_{int(time.time())}" + or f"dimos_{module_class.__name__.lower()}_{image_ref.replace(':', '_')}" ) - # RPC setup - self.rpc = LCMRPC() + self.rpc = LCMRPC( + rpc_timeouts=self.config.rpc_timeouts, + default_rpc_timeout=self.config.default_rpc_timeout, + ) self.rpcs = set(module_class.rpcs.keys()) # type: ignore[attr-defined] self.rpc_calls: list[str] = getattr(module_class, "rpc_calls", []) self._unsub_fns: list[Callable[[], None]] = [] self._bound_rpc_calls: dict[str, RpcCall] = {} - # Build image if needed (but don't start - caller must call start() explicitly) - if not image_exists(config): - logger.info(f"Building {config.docker_image}") - build_image(config) + # Build or pull image, launch container, wait for RPC server + try: + if config.docker_file is not None: + current_hash = _compute_build_hash(config) + 
stored_hash = _get_image_build_hash(config) + if current_hash != stored_hash: + logger.info(f"Building {config.docker_image}") + build_image(config) + elif not image_exists(config): + logger.info(f"Pulling {config.docker_image}") + r = subprocess.run( + [config.docker_bin, "pull", config.docker_image], + text=True, + stderr=subprocess.PIPE, + timeout=config.docker_pull_timeout, + ) + if r.returncode != 0: + raise RuntimeError( + f"Failed to pull image '{config.docker_image}'.\nSTDERR:\n{r.stderr}" + ) + + reconnect = False + if _is_container_running(config, self._container_name): + if config.docker_reconnect_container: + logger.info(f"Reconnecting to running container: {self._container_name}") + reconnect = True + else: + logger.info(f"Stopping existing container: {self._container_name}") + _run( + [config.docker_bin, "stop", self._container_name], + timeout=DOCKER_STOP_TIMEOUT, + ) + + if not reconnect: + _remove_container(config, self._container_name) + cmd = self._build_docker_run_command() + logger.info(f"Starting docker container: {self._container_name}") + r = _run(cmd, timeout=DOCKER_RUN_TIMEOUT) + if r.returncode != 0: + raise RuntimeError( + f"Failed to start container.\nSTDOUT:\n{r.stdout}\nSTDERR:\n{r.stderr}" + ) + self.rpc.start() + self._running.set() + # docker run -d returns before Module.__init__ finishes in the container, + # so we poll until the RPC server is reachable before returning. 
+ self._wait_for_rpc() + except Exception: + with suppress(Exception): + self._cleanup() + raise + + def get_rpc_method_names(self) -> list[str]: + return self.rpc_calls def set_rpc_method(self, method: str, callable: RpcCall) -> None: callable.set_rpc(self.rpc) self._bound_rpc_calls[method] = callable + # Forward to container — Module.set_rpc_method unpickles the RpcCall + # and wires it with the container's own LCMRPC + self.rpc.call_sync( + f"{self.remote_name}/set_rpc_method", + ([method, callable], {}), + ) def get_rpc_calls(self, *methods: str) -> RpcCall | tuple[RpcCall, ...]: - # Check all requested methods exist missing = set(methods) - self._bound_rpc_calls.keys() if missing: raise ValueError(f"RPC methods not found: {missing}") - # Return single RpcCall or tuple calls = tuple(self._bound_rpc_calls[m] for m in methods) return calls[0] if len(calls) == 1 else calls def start(self) -> None: - if self._running: - return - - cfg = self._config - - # Prevent accidental kill of running container with same name - if _is_container_running(cfg, self._container_name): - raise RuntimeError( - f"Container '{self._container_name}' already running. " - "Choose a different container_name or stop the existing container." 
- ) - _remove_container(cfg, self._container_name) - - cmd = self._build_docker_run_command() - logger.info(f"Starting docker container: {self._container_name}") - r = _run(cmd, timeout=DOCKER_RUN_TIMEOUT) - if r.returncode != 0: - raise RuntimeError( - f"Failed to start container.\nSTDOUT:\n{r.stdout}\nSTDERR:\n{r.stderr}" - ) - - self.rpc.start() - self._running = True - self._wait_for_ready() + """Invoke the remote module's start() RPC.""" + try: + self.rpc.call_sync(f"{self.remote_name}/start", ([], {})) + except Exception: + with suppress(Exception): + self.stop() + raise def stop(self) -> None: """Gracefully stop the Docker container and clean up resources.""" - # Signal remote module, stop RPC, unsubscribe handlers (ignore failures) + if not self._running.is_set(): + return + self._running.clear() # claim shutdown before any side-effects with suppress(Exception): - if self._running: - self.rpc.call_nowait(f"{self.remote_name}/stop", ([], {})) + self.rpc.call_nowait(f"{self.remote_name}/stop", ([], {})) + self._cleanup() + + def _cleanup(self) -> None: + """Release all resources. 
Idempotent — safe to call from partial init or after stop().""" with suppress(Exception): self.rpc.stop() for unsub in self._unsub_fns: with suppress(Exception): unsub() self._unsub_fns.clear() - - # Stop and remove container - _run([_docker_bin(self._config), "stop", self._container_name], timeout=DOCKER_STOP_TIMEOUT) - _remove_container(self._config, self._container_name) - self._running = False - logger.info(f"Stopped container: {self._container_name}") + if not getattr(getattr(self, "config", None), "docker_reconnect_container", False): + with suppress(Exception): + _run( + [self.config.docker_bin, "stop", self._container_name], + timeout=DOCKER_STOP_TIMEOUT, + ) + with suppress(Exception): + _remove_container(self.config, self._container_name) + self._running.clear() + logger.info(f"Cleaned up container handle: {self._container_name}") def status(self) -> dict[str, Any]: - cfg = self._config + cfg = self.config return { "module": self.remote_name, "container_name": self._container_name, "image": cfg.docker_image, - "running": bool(self._running and _is_container_running(cfg, self._container_name)), + "running": self._running.is_set() and _is_container_running(cfg, self._container_name), } def tail_logs(self, n: int = 200) -> str: - return _tail_logs(self._config, self._container_name, n=n) + return _tail_logs(self.config, self._container_name, n=n) def set_transport(self, stream_name: str, transport: Any) -> bool: - """Configure stream transport in container. 
Mirrors Module.set_transport() for autoconnect().""" - topic = getattr(transport, "topic", None) - if topic is None: - return False - if hasattr(topic, "topic"): - topic = topic.topic + """Forward to the container's Module.set_transport RPC.""" result, _ = self.rpc.call_sync( - f"{self.remote_name}/configure_stream", ([stream_name, str(topic)], {}) + f"{self.remote_name}/set_transport", + ([stream_name, transport], {}), ) return bool(result) def __getattr__(self, name: str) -> Any: - if name in self.rpcs: + rpcs = self.__dict__.get("rpcs") + if rpcs is not None and name in rpcs: original_method = getattr(self._module_class, name, None) - return RpcCall(original_method, self.rpc, name, self.remote_name, self._unsub_fns, None) - raise AttributeError(f"{name} not found on {self._module_class.__name__}") + return RpcCall( + original_method, + self.rpc, + name, + self.remote_name, + self._unsub_fns, + None, + ) + raise AttributeError(f"{name} not found on {type(self).__name__}") # Docker command building (split into focused helpers for readability) def _build_docker_run_command(self) -> list[str]: """Build the complete `docker run` command.""" - cfg = self._config + cfg = self.config self._validate_config(cfg) - cmd = [_docker_bin(cfg), "run", "-d"] + cmd = [cfg.docker_bin, "run", "-d"] self._add_lifecycle_args(cmd, cfg) self._add_network_args(cmd, cfg) self._add_port_args(cmd, cfg) @@ -399,16 +473,53 @@ def _build_container_command(self, cfg: DockerModuleConfig) -> list[str]: if cfg.docker_command: return list(cfg.docker_command) - module_path = f"{self._module_class.__module__}.{self._module_class.__name__}" + module_name = self._module_class.__module__ + if module_name == "__main__": + # When run as `python script.py`, __module__ is "__main__". + # Resolve to the actual dotted module path so the container can import it. 
+ import __main__ + + spec = getattr(__main__, "__spec__", None) + if spec and spec.name: + module_name = spec.name + else: + # Fallback: derive from file path relative to cwd + main_file = getattr(__main__, "__file__", None) + if main_file: + import pathlib + + try: + rel = pathlib.Path(main_file).resolve().relative_to(pathlib.Path.cwd()) + except ValueError: + raise RuntimeError( + f"Cannot derive module path: '{main_file}' is not under cwd " + f"'{pathlib.Path.cwd()}'. " + "Run with `python -m` or set docker_command explicitly." + ) from None + module_name = str(rel.with_suffix("")).replace("/", ".") + else: + raise RuntimeError( + "Cannot determine module path for __main__. " + "Run with `python -m` or set docker_command explicitly." + ) + module_path = f"{module_name}.{self._module_class.__name__}" # Filter out docker-specific kwargs (paths, etc.) - only pass module config kwargs = {"config": _extract_module_config(cfg)} payload = {"module_path": module_path, "args": list(self._args), "kwargs": kwargs} # DimOS base image entrypoint already runs "dimos.core.docker_runner run" - return ["--payload", json.dumps(payload, separators=(",", ":"))] - - def _wait_for_ready(self) -> None: - """Poll the module's RPC endpoint until ready, crashed, or timeout.""" - cfg = self._config + try: + payload_json = json.dumps(payload, separators=(",", ":")) + except TypeError as e: + raise TypeError( + f"Cannot serialize DockerModule payload to JSON: {e}\n" + f"Ensure all constructor args/kwargs for {self._module_class.__name__} are " + f"JSON-serializable, or use docker_command to bypass automatic payload generation." 
+ ) from e + return ["--payload", payload_json] + + def _wait_for_rpc(self) -> None: + """Poll until the container's RPC server is reachable.""" + cfg = self.config start_time = time.time() logger.info(f"Waiting for {self.remote_name} to be ready...") @@ -420,13 +531,14 @@ def _wait_for_ready(self) -> None: try: self.rpc.call_sync( - f"{self.remote_name}/start", ([], {}), rpc_timeout=RPC_READY_TIMEOUT + f"{self.remote_name}/get_rpc_method_names", + ([], {}), + rpc_timeout=3.0, # short timeout for polling readiness ) elapsed = time.time() - start_time logger.info(f"{self.remote_name} ready ({elapsed:.1f}s)") return except (TimeoutError, ConnectionError, OSError): - # Module not ready yet - retry after poll interval time.sleep(cfg.docker_poll_interval) logs = _tail_logs(cfg, self._container_name) diff --git a/dimos/core/docker_worker_manager.py b/dimos/core/docker_worker_manager.py new file mode 100644 index 0000000000..94a5793c3d --- /dev/null +++ b/dimos/core/docker_worker_manager.py @@ -0,0 +1,52 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +from contextlib import suppress +from typing import TYPE_CHECKING, Any + +from dimos.core.module import ModuleSpec +from dimos.utils.safe_thread_map import ExceptionGroup, safe_thread_map + +if TYPE_CHECKING: + from dimos.core.docker_runner import DockerModule + + +class DockerWorkerManager: + """Parallel deployment of Docker-backed modules.""" + + @staticmethod + def deploy_parallel( + specs: list[ModuleSpec], + ) -> list[DockerModule]: + """Deploy multiple DockerModules in parallel. + + If any deployment fails, all successfully-started containers are + stopped before an ExceptionGroup is raised. + """ + from dimos.core.docker_runner import DockerModule + + def _on_errors( + _outcomes: list[Any], successes: list[DockerModule], errors: list[Exception] + ) -> None: + for mod in successes: + with suppress(Exception): + mod.stop() + raise ExceptionGroup("docker deploy_parallel failed", errors) + + return safe_thread_map( + specs, + lambda spec: DockerModule(spec[0], g=spec[1], **spec[2]), # type: ignore[arg-type] + _on_errors, + ) diff --git a/dimos/core/module.py b/dimos/core/module.py index 1c5b311883..59c8833ea8 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -30,6 +30,7 @@ ) from langchain_core.tools import tool +from pydantic import Field from reactivex.disposable import CompositeDisposable from dimos.core.core import T, rpc @@ -40,7 +41,7 @@ from dimos.core.rpc_client import RpcCall from dimos.core.stream import In, Out, RemoteOut, Transport from dimos.protocol.rpc.pubsubrpc import LCMRPC -from dimos.protocol.rpc.spec import RPCSpec +from dimos.protocol.rpc.spec import DEFAULT_RPC_TIMEOUT, DEFAULT_RPC_TIMEOUTS, RPCSpec from dimos.protocol.service.spec import BaseConfig, Configurable from dimos.protocol.tf.tf import LCMTF, TFSpec from dimos.utils import colors @@ -79,6 +80,8 @@ def get_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread | None]: class ModuleConfig(BaseConfig): rpc_transport: 
type[RPCSpec] = LCMRPC + default_rpc_timeout: float = DEFAULT_RPC_TIMEOUT + rpc_timeouts: dict[str, float] = Field(default_factory=lambda: dict(DEFAULT_RPC_TIMEOUTS)) tf_transport: type[TFSpec] = LCMTF # type: ignore[type-arg] frame_id_prefix: str | None = None frame_id: str | None = None @@ -104,6 +107,7 @@ class ModuleBase(Configurable[ModuleConfigT], Resource): _bound_rpc_calls: dict[str, RpcCall] = {} _module_closed: bool = False _module_closed_lock: threading.Lock + _loop_thread_timeout: float = 2.0 rpc_calls: list[str] = [] @@ -113,7 +117,10 @@ def __init__(self, config_args: dict[str, Any]): self._loop, self._loop_thread = get_loop() self._disposables = CompositeDisposable() try: - self.rpc = self.config.rpc_transport() + self.rpc = self.config.rpc_transport( # type: ignore[call-arg] + rpc_timeouts=self.config.rpc_timeouts, + default_rpc_timeout=self.config.default_rpc_timeout, + ) self.rpc.serve_module_rpc(self) self.rpc.start() # type: ignore[attr-defined] except ValueError: @@ -151,7 +158,7 @@ def _close_module(self) -> None: if loop_thread.is_alive(): if loop: loop.call_soon_threadsafe(loop.stop) - loop_thread.join(timeout=2) + loop_thread.join(timeout=self._loop_thread_timeout) self._loop = None self._loop_thread = None @@ -456,17 +463,6 @@ def set_transport(self, stream_name: str, transport: Transport) -> bool: # type stream._transport = transport return True - @rpc - def configure_stream(self, stream_name: str, topic: str) -> bool: - """Configure a stream's transport by topic. 
Called by DockerModule for stream wiring.""" - from dimos.core.transport import pLCMTransport - - stream = getattr(self, stream_name, None) - if not isinstance(stream, (Out, In)): - return False - stream._transport = pLCMTransport(topic) - return True - # called from remote def connect_stream(self, input_name: str, remote_stream: RemoteOut[T]): # type: ignore[no-untyped-def] input_stream = getattr(self, input_name, None) diff --git a/dimos/core/module_coordinator.py b/dimos/core/module_coordinator.py index 10227eae93..d2d1db67be 100644 --- a/dimos/core/module_coordinator.py +++ b/dimos/core/module_coordinator.py @@ -14,19 +14,20 @@ from __future__ import annotations -from concurrent.futures import ThreadPoolExecutor import threading from typing import TYPE_CHECKING, Any +from dimos.core.docker_worker_manager import DockerWorkerManager from dimos.core.global_config import GlobalConfig, global_config from dimos.core.module import ModuleBase, ModuleSpec from dimos.core.resource import Resource from dimos.core.worker_manager import WorkerManager from dimos.utils.logging_config import setup_logger +from dimos.utils.safe_thread_map import ExceptionGroup, safe_thread_map if TYPE_CHECKING: from dimos.core.resource_monitor.monitor import StatsMonitor - from dimos.core.rpc_client import ModuleProxy + from dimos.core.rpc_client import ModuleProxy, ModuleProxyProtocol from dimos.core.worker import Worker logger = setup_logger() @@ -37,7 +38,7 @@ class ModuleCoordinator(Resource): # type: ignore[misc] _global_config: GlobalConfig _n: int | None = None _memory_limit: str = "auto" - _deployed_modules: dict[type[ModuleBase], ModuleProxy] + _deployed_modules: dict[type[ModuleBase], ModuleProxyProtocol] _stats_monitor: StatsMonitor | None = None def __init__( @@ -113,7 +114,8 @@ def stop(self) -> None: logger.error("Error stopping module", module=module_class.__name__, exc_info=True) logger.info("Module stopped.", module=module_class.__name__) - self._client.close_all() # type: 
ignore[union-attr] + if self._client is not None: + self._client.close_all() def deploy( self, @@ -121,35 +123,90 @@ def deploy( global_config: GlobalConfig = global_config, **kwargs: Any, ) -> ModuleProxy: + # Inline to avoid circular import: module_coordinator → docker_runner → module → blueprints → module_coordinator + from dimos.core.docker_runner import DockerModule, is_docker_module + if not self._client: raise ValueError("Trying to dimos.deploy before the client has started") - module = self._client.deploy(module_class, global_config, kwargs) - self._deployed_modules[module_class] = module # type: ignore[assignment] - return module # type: ignore[return-value] + deployed_module: ModuleProxyProtocol + if is_docker_module(module_class): + deployed_module = DockerModule(module_class, g=global_config, **kwargs) # type: ignore[arg-type] + else: + deployed_module = self._client.deploy(module_class, global_config, kwargs) + self._deployed_modules[module_class] = deployed_module # type: ignore[assignment] + return deployed_module # type: ignore[return-value] def deploy_parallel(self, module_specs: list[ModuleSpec]) -> list[ModuleProxy]: + # Inline to avoid circular import: module_coordinator → docker_runner → module → blueprints → module_coordinator + from dimos.core.docker_runner import is_docker_module + if not self._client: raise ValueError("Not started") - modules = self._client.deploy_parallel(module_specs) - for (module_class, _, _), module in zip(module_specs, modules, strict=True): - self._deployed_modules[module_class] = module # type: ignore[assignment] - return modules # type: ignore[return-value] + # Split by type, tracking original indices for reassembly + docker_indices: list[int] = [] + worker_indices: list[int] = [] + docker_specs: list[ModuleSpec] = [] + worker_specs: list[ModuleSpec] = [] + for i, spec in enumerate(module_specs): + if is_docker_module(spec[0]): + docker_indices.append(i) + docker_specs.append(spec) + else: + 
worker_indices.append(i) + worker_specs.append(spec) + + # Deploy worker and docker modules in parallel. + results: list[Any] = [None] * len(module_specs) + + def _deploy_workers() -> None: + if not worker_specs: + return + assert self._client is not None + for index, module in zip( + worker_indices, self._client.deploy_parallel(worker_specs), strict=True + ): + results[index] = module + + def _deploy_docker() -> None: + if not docker_specs: + return + for index, module in zip( + docker_indices, DockerWorkerManager.deploy_parallel(docker_specs), strict=True + ): + results[index] = module + + def _register() -> None: + for (module_class, _, _), module in zip(module_specs, results, strict=True): + if module is not None: + self._deployed_modules[module_class] = module + + def _on_errors( + _outcomes: list[Any], _successes: list[Any], errors: list[Exception] + ) -> None: + _register() + raise ExceptionGroup("deploy_parallel failed", errors) + + safe_thread_map([_deploy_workers, _deploy_docker], lambda fn: fn(), _on_errors) + _register() + return results def start_all_modules(self) -> None: modules = list(self._deployed_modules.values()) - if isinstance(self._client, WorkerManager): - with ThreadPoolExecutor(max_workers=len(modules)) as executor: - list(executor.map(lambda m: m.start(), modules)) - else: - for module in modules: - module.start() + if not modules: + raise ValueError("No modules deployed. 
Call deploy() before start_all_modules().") + + def _on_start_errors( + _outcomes: list[Any], _successes: list[Any], errors: list[Exception] + ) -> None: + raise ExceptionGroup("start_all_modules failed", errors) + + safe_thread_map(modules, lambda m: m.start(), _on_start_errors) - module_list = list(self._deployed_modules.values()) for module in modules: if hasattr(module, "on_system_modules"): - module.on_system_modules(module_list) + module.on_system_modules(modules) def get_instance(self, module: type[ModuleBase]) -> ModuleProxy: return self._deployed_modules.get(module) # type: ignore[return-value, no-any-return] diff --git a/dimos/core/resource.py b/dimos/core/resource.py index 63b1eec4f0..a4c008b806 100644 --- a/dimos/core/resource.py +++ b/dimos/core/resource.py @@ -15,7 +15,13 @@ from __future__ import annotations from abc import abstractmethod -from typing import TYPE_CHECKING, Self +import sys +from typing import TYPE_CHECKING + +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self if TYPE_CHECKING: from types import TracebackType diff --git a/dimos/core/rpc_client.py b/dimos/core/rpc_client.py index 84de18d671..7ac34bb645 100644 --- a/dimos/core/rpc_client.py +++ b/dimos/core/rpc_client.py @@ -13,12 +13,12 @@ # limitations under the License. 
from collections.abc import Callable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Protocol from dimos.core.stream import RemoteStream from dimos.core.worker import MethodCallProxy from dimos.protocol.rpc.pubsubrpc import LCMRPC -from dimos.protocol.rpc.spec import RPCSpec +from dimos.protocol.rpc.spec import DEFAULT_RPC_TIMEOUT, DEFAULT_RPC_TIMEOUTS, RPCSpec from dimos.utils.logging_config import setup_logger logger = setup_logger() @@ -67,7 +67,10 @@ def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] self._stop_rpc_client() return None - result, unsub_fn = self._rpc.call_sync(f"{self._remote_name}/{self._name}", (args, kwargs)) # type: ignore[arg-type] + result, unsub_fn = self._rpc.call_sync( + f"{self._remote_name}/{self._name}", + (args, kwargs), # type: ignore[arg-type] + ) self._unsub_fns.append(unsub_fn) return result @@ -75,15 +78,34 @@ def __getstate__(self): # type: ignore[no-untyped-def] return (self._name, self._remote_name) def __setstate__(self, state) -> None: # type: ignore[no-untyped-def] - self._name, self._remote_name = state + # Support both old 2-tuple and new 3-tuple (legacy) state for pickle compat. + if len(state) == 3: + self._name, self._remote_name, _ = state + else: + self._name, self._remote_name = state self._unsub_fns = [] self._rpc = None self._stop_rpc_client = None +class ModuleProxyProtocol(Protocol): + """Protocol for host-side handles to remote modules (worker or Docker).""" + + def start(self) -> None: ... + def stop(self) -> None: ... + def set_transport(self, stream_name: str, transport: Any) -> bool: ... + def get_rpc_method_names(self) -> list[str]: ... + def set_rpc_method(self, method: str, callable: RpcCall) -> None: ... + def get_rpc_calls(self, *methods: str) -> RpcCall | tuple[RpcCall, ...]: ... 
+ + class RPCClient: def __init__(self, actor_instance, actor_class) -> None: # type: ignore[no-untyped-def] - self.rpc = LCMRPC() + default_config = getattr(actor_class, "default_config", None) + self.rpc = LCMRPC( + rpc_timeouts=getattr(default_config, "rpc_timeouts", dict(DEFAULT_RPC_TIMEOUTS)), + default_rpc_timeout=getattr(default_config, "default_rpc_timeout", DEFAULT_RPC_TIMEOUT), + ) self.actor_class = actor_class self.remote_name = actor_class.__name__ self.actor_instance = actor_instance diff --git a/dimos/core/run_registry.py b/dimos/core/run_registry.py index 617872011c..a3807194f6 100644 --- a/dimos/core/run_registry.py +++ b/dimos/core/run_registry.py @@ -21,6 +21,7 @@ import os from pathlib import Path import re +import signal import time from dimos.utils.logging_config import setup_logger @@ -143,9 +144,6 @@ def get_most_recent(alive_only: bool = True) -> RunEntry | None: return runs[-1] if runs else None -import signal - - def stop_entry(entry: RunEntry, force: bool = False) -> tuple[str, bool]: """Stop a DimOS instance by registry entry. diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index f9a89829d5..7cd0f89b36 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -77,7 +77,7 @@ def test_classmethods() -> None: # Check that we have the expected RPC methods assert "navigate_to" in class_rpcs, "navigate_to should be in rpcs" assert "start" in class_rpcs, "start should be in rpcs" - assert len(class_rpcs) == 9 + assert len(class_rpcs) == 8 # Check that the values are callable assert callable(class_rpcs["navigate_to"]), "navigate_to should be callable" diff --git a/dimos/core/tests/test_docker_deployment.py b/dimos/core/tests/test_docker_deployment.py new file mode 100644 index 0000000000..d8eb9448ff --- /dev/null +++ b/dimos/core/tests/test_docker_deployment.py @@ -0,0 +1,283 @@ +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Smoke tests for Docker module deployment routing. + +These tests verify that the ModuleCoordinator correctly detects and routes +docker modules to DockerModule WITHOUT actually running Docker. +""" + +from __future__ import annotations + +from pathlib import Path +import threading +from unittest.mock import MagicMock, patch + +import pytest + +from dimos.core.docker_runner import DockerModule, DockerModuleConfig, is_docker_module +from dimos.core.global_config import global_config +from dimos.core.module import Module +from dimos.core.module_coordinator import ModuleCoordinator +from dimos.core.rpc_client import RpcCall +from dimos.core.stream import Out + +# -- Fixtures: fake module classes ------------------------------------------- + + +class FakeDockerConfig(DockerModuleConfig): + docker_image: str = "fake:latest" + docker_file: Path | None = None + docker_gpus: str | None = None + docker_rm: bool = True + docker_restart_policy: str = "no" + + +class FakeDockerModule(Module["FakeDockerConfig"]): + default_config = FakeDockerConfig + output: Out[str] + + +class FakeRegularModule(Module): + output: Out[str] + + +# -- Tests ------------------------------------------------------------------- + + +class TestIsDockerModule: + def test_docker_module_detected(self): + assert is_docker_module(FakeDockerModule) is True + + def test_regular_module_not_detected(self): + assert is_docker_module(FakeRegularModule) is False 
+ + def test_plain_class_not_detected(self): + assert is_docker_module(str) is False + + def test_no_default_config(self): + class Bare(Module): + pass + + # Module has default_config = ModuleConfig, which is not DockerModuleConfig + assert is_docker_module(Bare) is False + + +class TestModuleCoordinatorDockerRouting: + @patch("dimos.core.docker_runner.DockerModule") + @patch("dimos.core.module_coordinator.WorkerManager") + def test_deploy_routes_docker_module(self, mock_worker_manager_cls, mock_docker_module_cls): + mock_worker_mgr = MagicMock() + mock_worker_manager_cls.return_value = mock_worker_mgr + + mock_dm = MagicMock() + mock_docker_module_cls.return_value = mock_dm + + coordinator = ModuleCoordinator() + coordinator.start() + try: + result = coordinator.deploy(FakeDockerModule) + + # Should NOT go through worker manager + mock_worker_mgr.deploy.assert_not_called() + # Should construct a DockerModule (container launch happens inside __init__) + mock_docker_module_cls.assert_called_once_with(FakeDockerModule, g=global_config) + # start() is NOT called during deploy — it's called in start_all_modules + mock_dm.start.assert_not_called() + assert result is mock_dm + assert coordinator.get_instance(FakeDockerModule) is mock_dm + finally: + coordinator.stop() + + @patch("dimos.core.docker_runner.DockerModule") + @patch("dimos.core.module_coordinator.WorkerManager") + def test_deploy_docker_propagates_constructor_failure( + self, mock_worker_manager_cls, mock_docker_module_cls + ): + mock_worker_mgr = MagicMock() + mock_worker_manager_cls.return_value = mock_worker_mgr + + # Container launch fails inside __init__; DockerModule handles its own cleanup + mock_docker_module_cls.side_effect = RuntimeError("launch failed") + + coordinator = ModuleCoordinator() + coordinator.start() + try: + with pytest.raises(RuntimeError, match="launch failed"): + coordinator.deploy(FakeDockerModule) + finally: + coordinator.stop() + + 
@patch("dimos.core.module_coordinator.WorkerManager") + def test_deploy_routes_regular_module_to_worker_manager(self, mock_worker_manager_cls): + mock_worker_mgr = MagicMock() + mock_worker_manager_cls.return_value = mock_worker_mgr + mock_proxy = MagicMock() + mock_worker_mgr.deploy.return_value = mock_proxy + + coordinator = ModuleCoordinator() + coordinator.start() + try: + result = coordinator.deploy(FakeRegularModule) + + mock_worker_mgr.deploy.assert_called_once_with(FakeRegularModule, global_config, {}) + assert result is mock_proxy + finally: + coordinator.stop() + + @patch("dimos.core.docker_worker_manager.DockerWorkerManager.deploy_parallel") + @patch("dimos.core.module_coordinator.WorkerManager") + def test_deploy_parallel_separates_docker_and_regular( + self, mock_worker_manager_cls, mock_docker_deploy + ): + mock_worker_mgr = MagicMock() + mock_worker_manager_cls.return_value = mock_worker_mgr + + regular_proxy = MagicMock() + mock_worker_mgr.deploy_parallel.return_value = [regular_proxy] + + mock_dm = MagicMock() + mock_docker_deploy.return_value = [mock_dm] + + coordinator = ModuleCoordinator() + coordinator.start() + try: + specs = [ + (FakeRegularModule, (), {}), + (FakeDockerModule, (), {}), + ] + results = coordinator.deploy_parallel(specs) + + # Regular module goes through worker manager + mock_worker_mgr.deploy_parallel.assert_called_once_with([(FakeRegularModule, (), {})]) + # Docker specs go through DockerWorkerManager + mock_docker_deploy.assert_called_once_with([(FakeDockerModule, (), {})]) + # start() is NOT called during deploy — it's called in start_all_modules + mock_dm.start.assert_not_called() + + # Results preserve input order + assert results[0] is regular_proxy + assert results[1] is mock_dm + finally: + coordinator.stop() + + @patch("dimos.core.docker_runner.DockerModule") + @patch("dimos.core.module_coordinator.WorkerManager") + def test_stop_cleans_up_docker_modules(self, mock_worker_manager_cls, mock_docker_module_cls): + 
mock_worker_mgr = MagicMock() + mock_worker_manager_cls.return_value = mock_worker_mgr + + mock_dm = MagicMock() + mock_docker_module_cls.return_value = mock_dm + + coordinator = ModuleCoordinator() + coordinator.start() + try: + coordinator.deploy(FakeDockerModule) + finally: + coordinator.stop() + + # stop() called exactly once (no double cleanup) + assert mock_dm.stop.call_count == 1 + # Worker manager also closed + mock_worker_mgr.close_all.assert_called_once() + + +class TestDockerModuleGetattr: + """Tests for DockerModule.__getattr__ avoiding infinite recursion.""" + + def test_getattr_no_recursion_when_rpcs_not_set(self): + """If __init__ fails before self.rpcs is assigned, __getattr__ must not recurse.""" + + dm = DockerModule.__new__(DockerModule) + # Don't set rpcs, _module_class, or any instance attrs — simulates early __init__ failure + with pytest.raises(AttributeError): + _ = dm.some_method + + def test_getattr_no_recursion_on_cleanup_attrs(self): + """Accessing cleanup-related attrs before they exist must raise, not recurse.""" + + dm = DockerModule.__new__(DockerModule) + # These are accessed during _cleanup() — if rpcs isn't set, they must not recurse + for attr in ("rpc", "config", "_container_name", "_unsub_fns"): + with pytest.raises(AttributeError): + getattr(dm, attr) + + def test_getattr_delegates_to_rpc_when_rpcs_set(self): + dm = DockerModule.__new__(DockerModule) + dm.rpcs = {"do_thing"} + + # _module_class needs a real method with __name__ for RpcCall + class FakeMod: + def do_thing(self) -> None: ... 
+ + dm._module_class = FakeMod + dm.rpc = MagicMock() + dm.remote_name = "FakeMod" + dm._unsub_fns = [] + + result = dm.do_thing + assert isinstance(result, RpcCall) + + def test_getattr_raises_for_unknown_method(self): + dm = DockerModule.__new__(DockerModule) + dm.rpcs = {"do_thing"} + + with pytest.raises(AttributeError, match="not found"): + _ = dm.nonexistent + + +class TestDockerModuleCleanupReconnect: + """Tests for DockerModule._cleanup with docker_reconnect_container.""" + + def test_cleanup_skips_stop_when_reconnect(self): + with patch.object(DockerModule, "__init__", lambda self: None): + dm = DockerModule.__new__(DockerModule) + dm._running = threading.Event() + dm._running.set() + dm._container_name = "test_container" + dm._unsub_fns = [] + dm.rpc = MagicMock() + dm.remote_name = "TestModule" + + # reconnect mode: should NOT stop/rm the container + dm.config = FakeDockerConfig(docker_reconnect_container=True) + with ( + patch("dimos.core.docker_runner._run") as mock_run, + patch("dimos.core.docker_runner._remove_container") as mock_rm, + ): + dm._cleanup() + mock_run.assert_not_called() + mock_rm.assert_not_called() + + def test_cleanup_stops_container_when_not_reconnect(self): + with patch.object(DockerModule, "__init__", lambda self: None): + dm = DockerModule.__new__(DockerModule) + dm._running = threading.Event() + dm._running.set() + dm._container_name = "test_container" + dm._unsub_fns = [] + dm.rpc = MagicMock() + dm.remote_name = "TestModule" + + # normal mode: should stop and rm the container + dm.config = FakeDockerConfig(docker_reconnect_container=False) + with ( + patch("dimos.core.docker_runner._run") as mock_run, + patch("dimos.core.docker_runner._remove_container") as mock_rm, + ): + dm._cleanup() + mock_run.assert_called_once() # docker stop + mock_rm.assert_called_once() # docker rm -f diff --git a/dimos/core/tests/test_parallel_deploy_cleanup.py b/dimos/core/tests/test_parallel_deploy_cleanup.py new file mode 100644 index 
0000000000..1987fa4be7 --- /dev/null +++ b/dimos/core/tests/test_parallel_deploy_cleanup.py @@ -0,0 +1,219 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests that deploy_parallel cleans up successfully-started modules when a +sibling deployment fails ("middle module throws" scenario). +""" + +from __future__ import annotations + +import threading +from unittest.mock import MagicMock, patch + +import pytest + + +class TestDockerWorkerManagerPartialFailure: + """DockerWorkerManager.deploy_parallel must stop successful containers when one fails.""" + + @patch("dimos.core.docker_runner.DockerModule") + def test_middle_module_fails_stops_siblings(self, mock_docker_module_cls): + """Deploy 3 modules where the middle one fails. 
The other two must be stopped.""" + from dimos.core.docker_worker_manager import DockerWorkerManager + + mod_a = MagicMock(name="ModuleA") + mod_c = MagicMock(name="ModuleC") + + barrier = threading.Barrier(3, timeout=5) + + def fake_constructor(cls, *args, **kwargs): + label = cls.__name__ + barrier.wait() + if label == "B": + raise RuntimeError("B failed to start") + return mod_a if label == "A" else mod_c + + mock_docker_module_cls.side_effect = fake_constructor + + FakeA = type("A", (), {}) + FakeB = type("B", (), {}) + FakeC = type("C", (), {}) + + with pytest.raises(ExceptionGroup, match="docker deploy_parallel failed") as exc_info: + DockerWorkerManager.deploy_parallel( + [ + (FakeA, (), {}), + (FakeB, (), {}), + (FakeC, (), {}), + ] + ) + + assert len(exc_info.value.exceptions) == 1 + assert "B failed to start" in str(exc_info.value.exceptions[0]) + + # Both successful modules must have been stopped exactly once + mod_a.stop.assert_called_once() + mod_c.stop.assert_called_once() + + @patch("dimos.core.docker_runner.DockerModule") + def test_multiple_failures_raises_exception_group(self, mock_docker_module_cls): + """Deploy 3 modules where two fail. 
Should raise ExceptionGroup with both errors.""" + from dimos.core.docker_worker_manager import DockerWorkerManager + + mod_a = MagicMock(name="ModuleA") + + barrier = threading.Barrier(3, timeout=5) + + def fake_constructor(cls, *args, **kwargs): + label = cls.__name__ + barrier.wait() + if label == "B": + raise RuntimeError("B failed") + if label == "C": + raise ValueError("C failed") + return mod_a + + mock_docker_module_cls.side_effect = fake_constructor + + FakeA = type("A", (), {}) + FakeB = type("B", (), {}) + FakeC = type("C", (), {}) + + with pytest.raises(ExceptionGroup, match="docker deploy_parallel failed") as exc_info: + DockerWorkerManager.deploy_parallel( + [ + (FakeA, (), {}), + (FakeB, (), {}), + (FakeC, (), {}), + ] + ) + + assert len(exc_info.value.exceptions) == 2 + messages = {str(e) for e in exc_info.value.exceptions} + assert "B failed" in messages + assert "C failed" in messages + + # The one successful module must have been stopped + mod_a.stop.assert_called_once() + + @patch("dimos.core.docker_runner.DockerModule") + def test_all_succeed_no_stops(self, mock_docker_module_cls): + """When all deployments succeed, no modules should be stopped.""" + from dimos.core.docker_worker_manager import DockerWorkerManager + + mocks = [MagicMock(name=f"Mod{i}") for i in range(3)] + + def fake_constructor(cls, *args, **kwargs): + return mocks[["A", "B", "C"].index(cls.__name__)] + + mock_docker_module_cls.side_effect = fake_constructor + + FakeA = type("A", (), {}) + FakeB = type("B", (), {}) + FakeC = type("C", (), {}) + + results = DockerWorkerManager.deploy_parallel( + [ + (FakeA, (), {}), + (FakeB, (), {}), + (FakeC, (), {}), + ] + ) + + assert len(results) == 3 + for m in mocks: + m.stop.assert_not_called() + + @patch("dimos.core.docker_runner.DockerModule") + def test_stop_failure_does_not_mask_deploy_error(self, mock_docker_module_cls): + """If stop() itself raises during cleanup, the original deploy error still propagates.""" + from 
dimos.core.docker_worker_manager import DockerWorkerManager + + mod_a = MagicMock(name="ModuleA") + mod_a.stop.side_effect = OSError("stop failed") + + barrier = threading.Barrier(2, timeout=5) + + def fake_constructor(cls, *args, **kwargs): + barrier.wait() + if cls.__name__ == "B": + raise RuntimeError("B exploded") + return mod_a + + mock_docker_module_cls.side_effect = fake_constructor + + FakeA = type("A", (), {}) + FakeB = type("B", (), {}) + + with pytest.raises(ExceptionGroup, match="docker deploy_parallel failed"): + DockerWorkerManager.deploy_parallel([(FakeA, (), {}), (FakeB, (), {})]) + + # stop was attempted despite it raising + mod_a.stop.assert_called_once() + + +class TestWorkerManagerPartialFailure: + """WorkerManager.deploy_parallel must clean up successful RPCClients when one fails.""" + + def test_middle_module_fails_cleans_up_siblings(self): + from dimos.core.worker_manager import WorkerManager + + manager = WorkerManager(n_workers=2) + + mock_workers = [MagicMock(name=f"Worker{i}") for i in range(2)] + for w in mock_workers: + w.module_count = 0 + w.reserve_slot = MagicMock( + side_effect=lambda w=w: setattr(w, "module_count", w.module_count + 1) + ) + + manager._workers = mock_workers + manager._started = True + + def fake_deploy_module(module_class, args=(), kwargs=None): + if module_class.__name__ == "B": + raise RuntimeError("B failed to deploy") + return MagicMock(name=f"actor_{module_class.__name__}") + + for w in mock_workers: + w.deploy_module = fake_deploy_module + + FakeA = type("A", (), {}) + FakeB = type("B", (), {}) + FakeC = type("C", (), {}) + + rpc_clients_created: list[MagicMock] = [] + + with patch("dimos.core.worker_manager.RPCClient") as mock_rpc_cls: + + def make_rpc(actor, cls): + client = MagicMock(name=f"rpc_{cls.__name__}") + rpc_clients_created.append(client) + return client + + mock_rpc_cls.side_effect = make_rpc + + with pytest.raises(ExceptionGroup, match="worker deploy_parallel failed"): + manager.deploy_parallel( 
+ [ + (FakeA, (), {}), + (FakeB, (), {}), + (FakeC, (), {}), + ] + ) + + # Every successfully-created RPC client must have been cleaned up exactly once + for client in rpc_clients_created: + client.stop_rpc_client.assert_called_once() diff --git a/dimos/core/worker.py b/dimos/core/worker.py index 8f3beee7ec..19110a8305 100644 --- a/dimos/core/worker.py +++ b/dimos/core/worker.py @@ -215,25 +215,27 @@ def deploy_module( "module_class": module_class, "kwargs": kwargs, } - with self._lock: - self._conn.send(request) - response = self._conn.recv() + try: + with self._lock: + self._conn.send(request) + response = self._conn.recv() - if response.get("error"): - raise RuntimeError(f"Failed to deploy module: {response['error']}") - - actor = Actor(self._conn, module_class, self._worker_id, module_id, self._lock) - actor.set_ref(actor).result() - - self._modules[module_id] = actor - self._reserved = max(0, self._reserved - 1) - logger.info( - "Deployed module.", - module=module_class.__name__, - worker_id=self._worker_id, - module_id=module_id, - ) - return actor + if response.get("error"): + raise RuntimeError(f"Failed to deploy module: {response['error']}") + + actor = Actor(self._conn, module_class, self._worker_id, module_id, self._lock) + actor.set_ref(actor).result() + + self._modules[module_id] = actor + logger.info( + "Deployed module.", + module=module_class.__name__, + worker_id=self._worker_id, + module_id=module_id, + ) + return actor + finally: + self._reserved = max(0, self._reserved - 1) def suppress_console(self) -> None: if self._conn is None: diff --git a/dimos/core/worker_manager.py b/dimos/core/worker_manager.py index 4cd5eec8d7..3cd836b3ed 100644 --- a/dimos/core/worker_manager.py +++ b/dimos/core/worker_manager.py @@ -15,7 +15,7 @@ from __future__ import annotations from collections.abc import Iterable -from concurrent.futures import ThreadPoolExecutor +from contextlib import suppress from typing import Any from dimos.core.global_config import 
GlobalConfig @@ -23,6 +23,7 @@ from dimos.core.rpc_client import RPCClient from dimos.core.worker import Worker from dimos.utils.logging_config import setup_logger +from dimos.utils.safe_thread_map import ExceptionGroup, safe_thread_map logger = setup_logger() @@ -65,6 +66,10 @@ def deploy_parallel(self, module_specs: Iterable[ModuleSpec]) -> list[RPCClient] if self._closed: raise RuntimeError("WorkerManager is closed") + module_specs = list(module_specs) + if len(module_specs) == 0: + return [] + # Auto-start for backward compatibility if not self._started: self.start() @@ -78,17 +83,20 @@ def deploy_parallel(self, module_specs: Iterable[ModuleSpec]) -> list[RPCClient] worker.reserve_slot() assignments.append((worker, module_class, global_config, kwargs)) - def _deploy( - item: tuple[Worker, type[ModuleBase], GlobalConfig, dict[str, Any]], - ) -> RPCClient: - worker, module_class, global_config, kwargs = item - actor = worker.deploy_module(module_class, global_config=global_config, kwargs=kwargs) - return RPCClient(actor, module_class) - - with ThreadPoolExecutor(max_workers=len(assignments)) as pool: - results = list(pool.map(_deploy, assignments)) - - return results + def _on_errors( + _outcomes: list[Any], successes: list[RPCClient], errors: list[Exception] + ) -> None: + for rpc_client in successes: + with suppress(Exception): + rpc_client.stop_rpc_client() + raise ExceptionGroup("worker deploy_parallel failed", errors) + + return safe_thread_map( + assignments, + # item = [worker, module_class, global_config, kwargs] + lambda item: RPCClient(item[0].deploy_module(item[1], item[2], item[3]), item[1]), + _on_errors, + ) def suppress_console(self) -> None: """Tell all workers to redirect stdout/stderr to /dev/null.""" diff --git a/dimos/models/vl/create.py b/dimos/models/vl/create.py index 7fe5a0dcb2..df6a16305d 100644 --- a/dimos/models/vl/create.py +++ b/dimos/models/vl/create.py @@ -3,6 +3,7 @@ from dimos.models.vl.types import VlModelName from 
dimos.models.vl.base import VlModel +__all__ = ["VlModelName", "create"] def create(name: VlModelName) -> VlModel[Any]: # This uses inline imports to only import what's needed. diff --git a/dimos/protocol/rpc/pubsubrpc.py b/dimos/protocol/rpc/pubsubrpc.py index 3b77227218..52cb89a199 100644 --- a/dimos/protocol/rpc/pubsubrpc.py +++ b/dimos/protocol/rpc/pubsubrpc.py @@ -62,8 +62,12 @@ class RPCRes(TypedDict, total=False): class PubSubRPCMixin(RPCSpec, PubSub[TopicT, MsgT], Generic[TopicT, MsgT]): - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__( + self, *args: Any, rpc_timeouts: dict[str, float], default_rpc_timeout: float, **kwargs: Any + ) -> None: super().__init__(*args, **kwargs) + self.rpc_timeouts = dict(rpc_timeouts) + self.default_rpc_timeout = default_rpc_timeout # Thread pool for RPC handler execution (prevents deadlock in nested calls) self._call_thread_pool: ThreadPoolExecutor | None = None self._call_thread_pool_lock = threading.RLock() @@ -290,12 +294,13 @@ def execute_and_respond() -> None: class LCMRPC(PubSubRPCMixin[Topic, Any], PickleLCM): - def __init__(self, **kwargs: Any) -> None: - # Need to ensure PickleLCM gets initialized properly - # This is due to the diamond inheritance pattern with multiple base classes + def __init__( + self, rpc_timeouts: dict[str, float], default_rpc_timeout: float, **kwargs: Any + ) -> None: PickleLCM.__init__(self, **kwargs) - # Initialize PubSubRPCMixin's thread pool - PubSubRPCMixin.__init__(self, **kwargs) + PubSubRPCMixin.__init__( + self, rpc_timeouts=rpc_timeouts, default_rpc_timeout=default_rpc_timeout, **kwargs + ) def topicgen(self, name: str, req_or_res: bool) -> Topic: suffix = "res" if req_or_res else "req" @@ -306,12 +311,17 @@ def topicgen(self, name: str, req_or_res: bool) -> Topic: class ShmRPC(PubSubRPCMixin[str, Any], PickleSharedMemory): - def __init__(self, prefer: str = "cpu", **kwargs: Any) -> None: - # Need to ensure SharedMemory gets initialized properly - # This is due 
to the diamond inheritance pattern with multiple base classes + def __init__( + self, + rpc_timeouts: dict[str, float], + default_rpc_timeout: float, + prefer: str = "cpu", + **kwargs: Any, + ) -> None: PickleSharedMemory.__init__(self, prefer=prefer, **kwargs) - # Initialize PubSubRPCMixin's thread pool - PubSubRPCMixin.__init__(self, **kwargs) + PubSubRPCMixin.__init__( + self, rpc_timeouts=rpc_timeouts, default_rpc_timeout=default_rpc_timeout, **kwargs + ) def topicgen(self, name: str, req_or_res: bool) -> str: suffix = "res" if req_or_res else "req" diff --git a/dimos/protocol/rpc/spec.py b/dimos/protocol/rpc/spec.py index 47ad77e825..993f6044bb 100644 --- a/dimos/protocol/rpc/spec.py +++ b/dimos/protocol/rpc/spec.py @@ -15,6 +15,7 @@ import asyncio from collections.abc import Callable import threading +from types import MappingProxyType from typing import Any, Protocol, overload @@ -30,7 +31,19 @@ class RPCInspectable(Protocol): def rpcs(self) -> dict[str, Callable]: ... # type: ignore[type-arg] +# module.py and other places imports these constants and choose what to give RPCClient +# the RPCClient below does not use these constants directly (by design) +DEFAULT_RPC_TIMEOUT: float = 120.0 +DEFAULT_RPC_TIMEOUTS: MappingProxyType[str, float] = MappingProxyType({"start": 1200.0}) + + class RPCClient(Protocol): + # call_sync resolves per-method overrides from rpc_timeouts, + # falling back to default_rpc_timeout. These are set by + # PubSubRPCMixin.__init__ at runtime. + rpc_timeouts: dict[str, float] + default_rpc_timeout: float + # if we don't provide callback, we don't get a return unsub f @overload def call(self, name: str, arguments: Args, cb: None) -> None: ... @@ -43,13 +56,18 @@ def call(self, name: str, arguments: Args, cb: Callable | None) -> Callable[[], def call_nowait(self, name: str, arguments: Args) -> None: ... 
- # we expect to crash if we don't get a return value after 10 seconds - # but callers can override this timeout for extra long functions def call_sync( - self, name: str, arguments: Args, rpc_timeout: float | None = 120.0 + self, name: str, arguments: Args, rpc_timeout: float | None = None ) -> tuple[Any, Callable[[], None]]: - if name == "start": - rpc_timeout = 1200.0 # starting modules can take longer + if rpc_timeout is None: + # Try full topic name first, then bare method name (after last "/"). + rpc_timeout = self.rpc_timeouts.get(name) + if rpc_timeout is None: + method = name.rsplit("/", 1)[-1] + if method != name: + rpc_timeout = self.rpc_timeouts.get(method, self.default_rpc_timeout) + else: + rpc_timeout = self.default_rpc_timeout event = threading.Event() def receive_value(val) -> None: # type: ignore[no-untyped-def] @@ -101,4 +119,5 @@ def override_f(*args, fname=fname, **kwargs): # type: ignore[no-untyped-def] self.serve_rpc(override_f, topic) -class RPCSpec(RPCServer, RPCClient): ... 
+class RPCSpec(RPCServer, RPCClient): + pass diff --git a/dimos/protocol/rpc/test_lcmrpc.py b/dimos/protocol/rpc/test_lcmrpc.py index 5baa5ac40c..3c2b87761d 100644 --- a/dimos/protocol/rpc/test_lcmrpc.py +++ b/dimos/protocol/rpc/test_lcmrpc.py @@ -18,11 +18,12 @@ from dimos.constants import LCM_MAX_CHANNEL_NAME_LENGTH from dimos.protocol.rpc.pubsubrpc import LCMRPC +from dimos.protocol.rpc.spec import DEFAULT_RPC_TIMEOUT @pytest.fixture def lcmrpc() -> Generator[LCMRPC, None, None]: - ret = LCMRPC() + ret = LCMRPC(rpc_timeouts={}, default_rpc_timeout=DEFAULT_RPC_TIMEOUT) ret.start() yield ret ret.stop() diff --git a/dimos/protocol/rpc/test_spec.py b/dimos/protocol/rpc/test_spec.py index cfee044548..0b374f7d6c 100644 --- a/dimos/protocol/rpc/test_spec.py +++ b/dimos/protocol/rpc/test_spec.py @@ -27,6 +27,7 @@ from dimos.protocol.rpc.pubsubrpc import LCMRPC, ShmRPC from dimos.protocol.rpc.rpc_utils import RemoteError +from dimos.protocol.rpc.spec import DEFAULT_RPC_TIMEOUT class CustomTestError(Exception): @@ -46,8 +47,8 @@ def lcm_rpc_context(): from dimos.protocol.service.lcmservice import autoconf autoconf() - server = LCMRPC() - client = LCMRPC() + server = LCMRPC(rpc_timeouts={}, default_rpc_timeout=DEFAULT_RPC_TIMEOUT) + client = LCMRPC(rpc_timeouts={}, default_rpc_timeout=DEFAULT_RPC_TIMEOUT) server.start() client.start() @@ -65,8 +66,8 @@ def lcm_rpc_context(): def shm_rpc_context(): """Context manager for Shared Memory RPC implementation.""" # Create two separate instances that communicate through shared memory segments - server = ShmRPC(prefer="cpu") - client = ShmRPC(prefer="cpu") + server = ShmRPC(rpc_timeouts={}, default_rpc_timeout=DEFAULT_RPC_TIMEOUT, prefer="cpu") + client = ShmRPC(rpc_timeouts={}, default_rpc_timeout=DEFAULT_RPC_TIMEOUT, prefer="cpu") server.start() client.start() diff --git a/dimos/simulation/manipulators/test_sim_module.py b/dimos/simulation/manipulators/test_sim_module.py index 951d4790e3..54d8f21da3 100644 --- 
# Copyright 2025-2026 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from concurrent.futures import Future, ThreadPoolExecutor, as_completed
import sys
from typing import TYPE_CHECKING, Any, TypeVar

if sys.version_info < (3, 11):

    class ExceptionGroup(Exception):  # type: ignore[no-redef] # noqa: N818
        """Minimal ExceptionGroup polyfill for Python 3.10."""

        exceptions: tuple[BaseException, ...]

        def __init__(self, message: str, exceptions: Sequence[BaseException]) -> None:
            super().__init__(message)
            self.exceptions = tuple(exceptions)
else:
    import builtins

    ExceptionGroup = builtins.ExceptionGroup  # type: ignore[misc]

if TYPE_CHECKING:
    from collections.abc import Callable, Sequence

T = TypeVar("T")
R = TypeVar("R")


def safe_thread_map(
    items: Sequence[T],
    fn: Callable[[T], R],
    on_errors: Callable[[list[tuple[T, R | Exception]], list[R], list[Exception]], Any]
    | None = None,
) -> list[R]:
    """Thread-pool map that waits for *all* items to finish before raising,
    with an optional cleanup handler.

    - Empty *items* → returns ``[]`` immediately.
    - All succeed → returns results in input order.
    - Any fail → calls ``on_errors(outcomes, successes, errors)`` where
      *outcomes* is a list of ``(input, result_or_exception)`` pairs in input
      order, *successes* is the list of successful results, and *errors* is
      the list of exceptions. If *on_errors* raises, that exception propagates.
      If *on_errors* returns normally, its return value is returned from
      ``safe_thread_map``. If *on_errors* is ``None``, raises an
      ``ExceptionGroup``.
    - A worker may legitimately *return* an ``Exception`` instance; only
      exceptions actually *raised* by ``fn`` count as failures.

    Example::

        def start_service(name: str) -> Connection:
            return connect(name)

        def cleanup(
            outcomes: list[tuple[str, Connection | Exception]],
            successes: list[Connection],
            errors: list[Exception],
        ) -> None:
            for conn in successes:
                conn.close()
            raise ExceptionGroup("failed to start services", errors)

        connections = safe_thread_map(
            ["db", "cache", "queue"],
            start_service,
            cleanup,  # called only if any start_service() raises
        )
    """
    if not items:
        return []

    # Track (succeeded, value) explicitly instead of storing bare values and
    # probing isinstance(value, Exception) afterwards: a worker that *returns*
    # an Exception instance as its result must not be misclassified as a
    # failure.
    outcomes: dict[int, tuple[bool, R | Exception]] = {}

    # One worker per item so every task starts immediately; the `with` block
    # guarantees we wait for all of them before deciding success or failure.
    with ThreadPoolExecutor(max_workers=len(items)) as pool:
        futures: dict[Future[R], int] = {pool.submit(fn, item): i for i, item in enumerate(items)}
        for fut in as_completed(futures):
            idx = futures[fut]
            try:
                outcomes[idx] = (True, fut.result())
            except Exception as e:  # deliberately collect every failure
                outcomes[idx] = (False, e)

    # Note: successes/errors are in completion order, not input order.
    # This is fine — on_errors only needs them for cleanup, not ordering.
    successes: list[R] = [v for ok, v in outcomes.values() if ok]  # type: ignore[misc]
    errors: list[Exception] = [v for ok, v in outcomes.values() if not ok]  # type: ignore[misc]

    if errors:
        if on_errors is not None:
            # Pairs are rebuilt in input order so on_errors can correlate
            # each input with its result or exception.
            zipped = [(items[i], outcomes[i][1]) for i in range(len(items))]
            return on_errors(zipped, successes, errors)  # type: ignore[return-value, no-any-return]
        raise ExceptionGroup("safe_thread_map failed", errors)

    return [outcomes[i][1] for i in range(len(items))]  # type: ignore[misc]
examples/docker_hello_world/hello_docker.py /dimos/source/examples/docker_hello_world/hello_docker.py +RUN touch /dimos/source/examples/__init__.py /dimos/source/examples/docker_hello_world/__init__.py + +WORKDIR /app diff --git a/examples/docker_hello_world/hello_docker.py b/examples/docker_hello_world/hello_docker.py new file mode 100644 index 0000000000..a9913d770b --- /dev/null +++ b/examples/docker_hello_world/hello_docker.py @@ -0,0 +1,138 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Hello World Docker Module +========================== + +Minimal example showing a DimOS module running inside Docker. + +The module receives a string on its ``prompt`` input stream, runs it through +cowsay inside the container, and publishes the ASCII art on its ``greeting`` +output stream. + +NOTE: Requires Linux. Docker Desktop on macOS does not support host networking, +which is needed for LCM multicast between host and container. 
+ +Usage: + python examples/docker_hello_world/hello_docker.py +""" + +from __future__ import annotations + +from dataclasses import field +from pathlib import Path +import subprocess +import time + +from reactivex.disposable import Disposable + +from dimos.core.blueprints import autoconnect +from dimos.core.core import rpc +from dimos.core.docker_runner import DockerModuleConfig +from dimos.core.module import Module +from dimos.core.stream import In, Out + + +class HelloDockerConfig(DockerModuleConfig): + docker_image: str = "dimos-hello-docker:latest" + docker_file: Path | None = Path(__file__).parent / "Dockerfile" + docker_build_context: Path | None = Path(__file__).parents[2] # repo root + docker_gpus: str | None = None # no GPU needed + docker_rm: bool = True + docker_restart_policy: str = "no" + docker_env: dict[str, str] = field(default_factory=lambda: {"CI": "1"}) + + # Custom (non-docker) config field — passed to the container via JSON + greeting_prefix: str = "Hello" + + +class HelloDockerModule(Module["HelloDockerConfig"]): + """A trivial module that runs inside Docker and echoes greetings.""" + + default_config = HelloDockerConfig + + prompt: In[str] + greeting: Out[str] + + @rpc + def start(self) -> None: + super().start() + self._disposables.add(Disposable(self.prompt.subscribe(self._on_prompt))) + + def _cowsay(self, text: str) -> str: + """Run cowsay inside the container and return the ASCII art.""" + return subprocess.check_output(["cowsay", text], text=True) + + def _on_prompt(self, text: str) -> None: + art = self._cowsay(text) + print(f"[HelloDockerModule]\n{art}") + self.greeting.publish(art) + + @rpc + def greet(self, name: str) -> str: + """RPC method that can be called directly.""" + prefix = self.config.greeting_prefix + return self._cowsay(f"{prefix}, {name}!") + + @rpc + def get_greeting_prefix(self) -> str: + """Return the config value to verify it was passed to the container.""" + return self.config.greeting_prefix + + +class 
PromptModule(Module): + """Publishes prompts and listens to greetings.""" + + prompt: Out[str] + greeting: In[str] + + @rpc + def start(self) -> None: + super().start() + self._disposables.add(Disposable(self.greeting.subscribe(self._on_greeting))) + + @rpc + def send(self, text: str) -> None: + """Publish a prompt message onto the stream.""" + self.prompt.publish(text) + + def _on_greeting(self, text: str) -> None: + print(f"[PromptModule] Received: {text}") + + +if __name__ == "__main__": + coordinator = autoconnect( + PromptModule.blueprint(), + HelloDockerModule.blueprint(greeting_prefix="Howdy"), + ).build() + + # Get module proxies + prompt_mod = coordinator.get_instance(PromptModule) + docker_mod = coordinator.get_instance(HelloDockerModule) + + # Test that custom config was passed to the container + prefix = docker_mod.get_greeting_prefix() + assert prefix == "Howdy", f"Expected 'Howdy', got {prefix!r}" + print(f"Config passed to container: greeting_prefix={prefix!r}") + + # Test RPC (should use the custom prefix) + print(docker_mod.greet("World")) + + # Test stream + prompt_mod.send("stream test") + time.sleep(2) + + coordinator.stop() + print("Done!") diff --git a/pyproject.toml b/pyproject.toml index 1535885edf..9b33009bd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -326,8 +326,12 @@ docker = [ "sortedcontainers", "PyTurboJPEG", "rerun-sdk", + "typing_extensions", "open3d-unofficial-arm; platform_system == 'Linux' and platform_machine == 'aarch64'", "open3d>=0.18.0; platform_system != 'Linux' or platform_machine != 'aarch64'", + # these below should be removed later, right now they are needed even for running `dimos --help` (seperate non-docker issue) + "langchain-core", + "matplotlib", ] base = [ diff --git a/uv.lock b/uv.lock index b940d88ff8..9a05b2fa69 100644 --- a/uv.lock +++ b/uv.lock @@ -1858,7 +1858,9 @@ dev = [ ] docker = [ { name = "dimos-lcm" }, + { name = "langchain-core" }, { name = "lcm" }, + { name = "matplotlib" }, { name = 
"numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "open3d", marker = "platform_machine != 'aarch64' or sys_platform != 'linux'" }, @@ -1876,6 +1878,7 @@ docker = [ { name = "sortedcontainers" }, { name = "structlog" }, { name = "typer" }, + { name = "typing-extensions" }, ] drone = [ { name = "pymavlink" }, @@ -2019,6 +2022,7 @@ requires-dist = [ { name = "langchain", marker = "extra == 'agents'", specifier = "==1.2.3" }, { name = "langchain-chroma", marker = "extra == 'agents'", specifier = ">=1,<2" }, { name = "langchain-core", marker = "extra == 'agents'", specifier = "==1.2.3" }, + { name = "langchain-core", marker = "extra == 'docker'" }, { name = "langchain-huggingface", marker = "extra == 'agents'", specifier = ">=1,<2" }, { name = "langchain-ollama", marker = "extra == 'agents'", specifier = ">=1,<2" }, { name = "langchain-openai", marker = "extra == 'agents'", specifier = ">=1,<2" }, @@ -2030,6 +2034,7 @@ requires-dist = [ { name = "llvmlite", specifier = ">=0.42.0" }, { name = "lxml-stubs", marker = "extra == 'dev'", specifier = ">=0.5.1,<1" }, { name = "lz4", specifier = ">=4.4.5" }, + { name = "matplotlib", marker = "extra == 'docker'" }, { name = "matplotlib", marker = "extra == 'manipulation'", specifier = ">=3.7.1" }, { name = "md-babel-py", marker = "extra == 'dev'", specifier = "==1.1.1" }, { name = "moondream", marker = "extra == 'perception'" }, @@ -2140,6 +2145,7 @@ requires-dist = [ { name = "types-tabulate", marker = "extra == 'dev'", specifier = ">=0.9.0.20241207,<1" }, { name = "types-tensorflow", marker = "extra == 'dev'", specifier = ">=2.18.0.20251008,<3" }, { name = "types-tqdm", marker = "extra == 'dev'", specifier = ">=4.67.0.20250809,<5" }, + { name = "typing-extensions", marker = "extra == 'docker'" }, { name = 
"ultralytics", marker = "extra == 'perception'", specifier = ">=8.3.70" }, { name = "unitree-webrtc-connect-leshy", marker = "extra == 'unitree'", specifier = ">=2.0.7" }, { name = "uvicorn", marker = "extra == 'web'", specifier = ">=0.34.0" },