diff --git a/docker/Dockerfile.gr00t_server b/docker/Dockerfile.gr00t_server deleted file mode 100644 index 227e30d66..000000000 --- a/docker/Dockerfile.gr00t_server +++ /dev/null @@ -1,38 +0,0 @@ -FROM nvcr.io/nvidia/pytorch:24.07-py3 - -ARG WORKDIR="/workspace" -ARG GROOT_DEPS_GROUP="base" -ENV WORKDIR=${WORKDIR} -ENV GROOT_DEPS_GROUP=${GROOT_DEPS_GROUP} -WORKDIR "${WORKDIR}" - -RUN apt-get update && apt-get install -y \ - git \ - git-lfs \ - cmake \ - && rm -rf /var/lib/apt/lists/* - -RUN pip install --upgrade pip - -COPY ./submodules/Isaac-GR00T ${WORKDIR}/submodules/Isaac-GR00T - -COPY docker/setup/install_gr00t_deps.sh /tmp/install_gr00t_deps.sh -RUN chmod +x /tmp/install_gr00t_deps.sh && \ - /tmp/install_gr00t_deps.sh --server && \ - rm -f /tmp/install_gr00t_deps.sh - -RUN pip install --no-cache-dir --upgrade "opencv-python-headless==4.8.0.74" -# GN1.6 uses termcolor 3.2.0 -RUN pip install --no-cache-dir termcolor==3.2.0 - -COPY isaaclab_arena/remote_policy ${WORKDIR}/isaaclab_arena/remote_policy -COPY isaaclab_arena_gr00t ${WORKDIR}/isaaclab_arena_gr00t -COPY isaaclab_arena_g1 ${WORKDIR}/isaaclab_arena_g1 - -RUN pip install --no-cache-dir pyzmq msgpack - -ENV PYTHONPATH=${WORKDIR} -# So gr00t_remote_policy loads transformers/tokenizers from GR00T deps (e.g. tokenizers 0.21.x) not system site-packages. -ENV GROOT_DEPS_DIR=/opt/groot_deps - -ENTRYPOINT ["python", "-u", "-m", "isaaclab_arena.remote_policy.remote_policy_server_runner"] diff --git a/docker/run_gr00t_server.sh b/docker/run_gr00t_server.sh deleted file mode 100755 index fca3d0d1a..000000000 --- a/docker/run_gr00t_server.sh +++ /dev/null @@ -1,248 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -# ------------------------- -# User-configurable defaults -# ------------------------- - -# Default mount directories on the host machine -DATASETS_DIR="${DATASETS_DIR:-$HOME/datasets}" -MODELS_DIR="${MODELS_DIR:-$HOME/models}" -EVAL_DIR="${EVAL_DIR:-$HOME/eval}" - -# Docker image name and tag for the GR00T policy server -DOCKER_IMAGE_NAME="${DOCKER_IMAGE_NAME:-gr00t_policy_server}" -DOCKER_VERSION_TAG="${DOCKER_VERSION_TAG:-latest}" - -# Rebuild controls -FORCE_REBUILD="${FORCE_REBUILD:-false}" -NO_CACHE="" - -# Server parameters (can also be overridden via environment variables) -HOST="${HOST:-0.0.0.0}" -PORT="${PORT:-5555}" -API_TOKEN="${API_TOKEN:-}" -TIMEOUT_MS="${TIMEOUT_MS:-5000}" -POLICY_TYPE="${POLICY_TYPE:-gr00t_closedloop}" -POLICY_CONFIG_YAML_PATH="${POLICY_CONFIG_YAML_PATH:-/workspace/isaaclab_arena_gr00t/gr1_manip_gr00t_closedloop_config.yaml}" - -# GPU selection for docker --gpus (can also be overridden via environment variables) -# Examples: -# all -> use all GPUs -# 1 -> use 1 GPU (count) -# "device=0" -> use GPU 0 -# "device=0,1" -> use GPU 0 and 1 -GPUS="${GPUS:-all}" - -# ------------------------- -# Help message -# ------------------------- -usage() { - script_name=$(basename "$0") - cat < Path to datasets on the host. Default: "$DATASETS_DIR". - -m Path to models on the host. Default: "$MODELS_DIR". - -e Path to evaluation data on the host. Default: "$EVAL_DIR". - -n Docker image name. Default: "$DOCKER_IMAGE_NAME". - -g GPU selection for docker --gpus. Default: "all". - Examples: "all", "1", "device=0", "device=0,1". - -r Force rebuilding of the Docker image. - -R Force rebuilding of the Docker image, without cache. - -Server-specific options (passed through to the policy server entrypoint): - --host HOST - --port PORT - --api_token TOKEN - --timeout_ms MS - --policy_type TYPE - --policy_config_yaml_path PATH - -Examples: - # Minimal: use defaults, just build & run server - bash $script_name - - # Custom models directory, port and single GPU (GPU 0) - bash $script_name -m /data/models -g "device=0" --port 6000 --api_token MY_TOKEN - - # Custom image name, force rebuild, datasets/eval mounts, and multiple GPUs - bash $script_name -n gr00t_server -r \\ - -d /data/datasets -m /data/models -e /data/eval \\ - -g "device=0,1" \\ - --policy_type isaaclab_arena_gr00t.policy.gr00t_remote_policy.Gr00tRemoteServerSidePolicy \\ - --policy_config_yaml_path isaaclab_arena_gr00t/policy/config/gr1_manip_gr00t_closedloop_config.yaml -EOF -} - -# ------------------------- -# Parse docker/path options (short flags, like run_docker.sh) -# ------------------------- -DOCKER_ARGS_DONE=false -SERVER_ARGS=() - -while [[ $# -gt 0 ]]; do - if [ "$DOCKER_ARGS_DONE" = false ]; then - case "$1" in - -v) - # Enable verbose mode for debugging - set -x - shift 1 - ;; - -d) - # Set host datasets directory - DATASETS_DIR="$2" - shift 2 - ;; - -m) - # Set host models directory - MODELS_DIR="$2" - shift 2 - ;; - -e) - # Set host eval directory - EVAL_DIR="$2" - shift 2 - ;; - -n) - # Set Docker image name - DOCKER_IMAGE_NAME="$2" - shift 2 - ;; - -g) - # Set GPU selection for docker --gpus - GPUS="$2" - shift 2 - ;; - -r) - # Force rebuild of Docker image - FORCE_REBUILD="true" - shift 1 - ;; - -R) - # Force rebuild of Docker image, without cache - FORCE_REBUILD="true" - NO_CACHE="--no-cache" - shift 1 - ;; - -h|--help) - usage - exit 0 - ;; - --host|--port|--api_token|--timeout_ms|--policy_type|--policy_config_yaml_path) - # From here on, treat everything as server args and stop parsing docker flags - DOCKER_ARGS_DONE=true - SERVER_ARGS+=("$1") - shift 1 - ;; - --*) - # Unknown long option at docker level -> treat as server arg - DOCKER_ARGS_DONE=true - SERVER_ARGS+=("$1") - shift 1 - ;; - *) - # Anything else -> treat as server arg - DOCKER_ARGS_DONE=true - SERVER_ARGS+=("$1") - shift 1 - ;; - esac - else - # Additional server arguments after docker/path args - SERVER_ARGS+=("$1") - shift 1 - fi -done - -# If no server args were passed, use defaults -if [ ${#SERVER_ARGS[@]} -eq 0 ]; then - SERVER_ARGS=( - --host "${HOST}" - --port "${PORT}" - --api_token "${API_TOKEN}" - --timeout_ms "${TIMEOUT_MS}" - --policy_type "${POLICY_TYPE}" - --policy_config_yaml_path "${POLICY_CONFIG_YAML_PATH}" - ) -fi - -echo "Host paths:" -echo " DATASETS_DIR = ${DATASETS_DIR}" -echo " MODELS_DIR = ${MODELS_DIR}" -echo " EVAL_DIR = ${EVAL_DIR}" -echo "Docker image:" -echo " ${DOCKER_IMAGE_NAME}:${DOCKER_VERSION_TAG}" -echo "GPU:" -echo " --gpus ${GPUS}" -echo "Rebuild:" -echo " FORCE_REBUILD = ${FORCE_REBUILD}, NO_CACHE = '${NO_CACHE}'" -echo "Server args:" -printf ' %q ' "${SERVER_ARGS[@]}"; echo - -# ------------------------- -# 1) Build the Docker image -# ------------------------- - -IMAGE_TAG_FULL="${DOCKER_IMAGE_NAME}:${DOCKER_VERSION_TAG}" - -# 1) Decide whether to build -SHOULD_BUILD=false - -if [ "${FORCE_REBUILD}" = "true" ]; then - # -r or -R: force rebuild - SHOULD_BUILD=true -else - # Without force flag: only build if the image does not exist locally - if [ -z "$(docker images -q "${IMAGE_TAG_FULL}")" ]; then - SHOULD_BUILD=true - fi -fi - -# 2) Build or skip -if [ "${SHOULD_BUILD}" = "true" ]; then - echo "Building Docker image ${IMAGE_TAG_FULL}..." - docker build \ - ${NO_CACHE} \ - -f docker/Dockerfile.gr00t_server \ - -t "${IMAGE_TAG_FULL}" \ - . -else - echo "Docker image ${IMAGE_TAG_FULL} already exists. Skipping rebuild." - echo "Use -r or -R to force rebuilding the image." -fi - -# ------------------------- -# 2) Run the container -# ------------------------- - -DOCKER_RUN_ARGS=( - --rm - --gpus "${GPUS}" - --net host - --name gr00t_policy_server_container - -v "${MODELS_DIR}":/models -) - -# Only mount datasets / eval if the directories exist on host -if [ -d "${DATASETS_DIR}" ]; then - DOCKER_RUN_ARGS+=(-v "${DATASETS_DIR}":/datasets) -fi - -if [ -d "${EVAL_DIR}" ]; then - DOCKER_RUN_ARGS+=(-v "${EVAL_DIR}":/eval) -fi - -# Pass through so gr00t_remote_policy can print path/debug info (e.g. GROOT_DEBUG_PATH=1). -if [ -n "${GROOT_DEBUG_PATH:-}" ]; then - DOCKER_RUN_ARGS+=(-e "GROOT_DEBUG_PATH=${GROOT_DEBUG_PATH}") -fi - -docker run "${DOCKER_RUN_ARGS[@]}" \ - "${IMAGE_TAG_FULL}" \ - "${SERVER_ARGS[@]}" diff --git a/isaaclab_arena/evaluation/policy_runner.py b/isaaclab_arena/evaluation/policy_runner.py index 6859e65dd..be2d24ce3 100644 --- a/isaaclab_arena/evaluation/policy_runner.py +++ b/isaaclab_arena/evaluation/policy_runner.py @@ -204,18 +204,10 @@ def main(): # Each rank prints its own metrics as it can be different due to random seed print(f"[Rank {local_rank}/{world_size}] Metrics: {metrics}") - # NOTE(huikang, 2025-12-30)Explicitly clean up the remote policy client / server. - # Do NOT rely on a __del__ destructor in policy for this, since destructors are - # triggered implicitly and their execution time (or even whether they run) - # is not guaranteed, which makes resource cleanup unreliable. - if policy.is_remote: - policy.shutdown_remote(kill_server=args_cli.remote_kill_on_exit) - # Close the environment. env.close() if __name__ == "__main__": - # TODO(xinjie.yao, 2026.03.31): Remove it after policy sever-client is implemented properly in v0.3. ensure_groot_deps_in_path() main() diff --git a/isaaclab_arena/policy/__init__.py b/isaaclab_arena/policy/__init__.py index 66caa51a1..9290be08a 100644 --- a/isaaclab_arena/policy/__init__.py +++ b/isaaclab_arena/policy/__init__.py @@ -3,8 +3,8 @@ # # SPDX-License-Identifier: Apache-2.0 -from .action_chunking import ActionChunkingState -from .action_chunking_client import * +from .action_scheduler import ActionScheduler +from .action_chunking import ActionChunkScheduler, ActionChunkingState from .replay_action_policy import * from .rsl_rl_action_policy import * from .zero_action_policy import * diff --git a/isaaclab_arena/policy/action_chunking.py b/isaaclab_arena/policy/action_chunking.py index fbfb32755..e367e17d5 100644 --- a/isaaclab_arena/policy/action_chunking.py +++ b/isaaclab_arena/policy/action_chunking.py @@ -3,19 +3,21 @@ # # SPDX-License-Identifier: Apache-2.0 -"""Shared action chunking state and logic for local and remote policies.""" +"""ActionChunkScheduler: buffer a model chunk and step through it sequentially.""" from __future__ import annotations import torch from collections.abc import Callable +from isaaclab_arena.policy.action_scheduler import ActionScheduler -class ActionChunkingState: - """Holds chunk buffer, per-env index, and refill flag; provides get_action(fetch_chunk_fn). - Used by both Gr00tClosedloopPolicy (local) and ActionChunkingClientSidePolicy (remote) - so chunking behavior is identical. +class ActionChunkScheduler(ActionScheduler): + """Buffers one action chunk and replays it one step at a time. + + Fetches a new chunk from the model only when the current one is exhausted. + Per-env tracking allows environments to refetch independently. """ def __init__( @@ -79,10 +81,14 @@ def get_action(self, fetch_chunk_fn: Callable[[], torch.Tensor]) -> torch.Tensor return action - def reset(self, env_ids: torch.Tensor | None = None) -> None: + def reset(self, env_ids: torch.Tensor | slice | None = None) -> None: """Reset chunking state for the given envs (all if None).""" if env_ids is None: env_ids = slice(None) self.current_action_chunk[env_ids] = 0.0 self.current_action_index[env_ids] = -1 self.env_requires_new_chunk[env_ids] = True + + +# Backwards-compatibility alias +ActionChunkingState = ActionChunkScheduler diff --git a/isaaclab_arena/policy/action_chunking_client.py b/isaaclab_arena/policy/action_chunking_client.py deleted file mode 100644 index 3cdc5856a..000000000 --- a/isaaclab_arena/policy/action_chunking_client.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) 2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import argparse -import gymnasium as gym -import torch -from typing import Any - -from isaaclab_arena.policy.action_chunking import ActionChunkingState -from isaaclab_arena.policy.client_side_policy import ClientSidePolicy -from isaaclab_arena.remote_policy.action_protocol import ChunkingActionProtocol -from isaaclab_arena.remote_policy.remote_policy_config import RemotePolicyConfig - - -class ActionChunkingClientSidePolicy(ClientSidePolicy): - """Client-side policy that consumes fixed-length action chunks sequentially.""" - - def __init__( - self, - config: Any, - num_envs: int, - device: str, - remote_config: RemotePolicyConfig, - ) -> None: - super().__init__(config=config, remote_config=remote_config, protocol_cls=ChunkingActionProtocol) - - self._num_envs = num_envs - self._device = device - - assert self.protocol.action_chunk_length <= self.protocol.action_horizon, ( - f"protocol.action_chunk_length ({self.protocol.action_chunk_length}) " - f"must be <= protocol.action_horizon ({self.protocol.action_horizon})" - ) - # Shared chunking state (unified with local Gr00tClosedloopPolicy) - self._chunking_state = ActionChunkingState( - num_envs=self._num_envs, - action_chunk_length=self.protocol.action_chunk_length, - action_horizon=self.protocol.action_horizon, - action_dim=self.protocol.action_dim, - device=self._device, - dtype=torch.float32, - ) - - self.task_description: str | None = None - - # ---------------------- CLI ---------------------------------------- - - @staticmethod - def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: - """Add CLI arguments for ActionChunkingClientSidePolicy.""" - # Shared remote policy args. - parser = ClientSidePolicy.add_remote_args_to_parser(parser) - - # Policy-specific args. - group = parser.add_argument_group( - "Action Chunking Client Policy", - "Arguments for client-side action chunking policy.", - ) - group.add_argument( - "--policy_device", - type=str, - default="cuda", - help="Device to use for the policy-related operations.", - ) - return parser - - @staticmethod - def from_args(args: argparse.Namespace) -> ActionChunkingClientSidePolicy: - """Create an ActionChunkingClientSidePolicy from CLI arguments.""" - remote_config = ClientSidePolicy.build_remote_config_from_args(args) - return ActionChunkingClientSidePolicy( - config=None, - num_envs=args.num_envs, - device=args.policy_device, - remote_config=remote_config, - ) - - # ---------------------- Task description ---------------------------- - - def set_task_description(self, task_description: str | None) -> str: - """Set the task description on both client-side and remote policy.""" - self.task_description = task_description - # Always notify the server so it can set _task_description (server uses config default when None) - self.remote_client.call_endpoint( - "set_task_description", - data={"task_description": task_description}, - requires_input=True, - ) - return self.task_description or "" - - # ---------------------- Chunking logic ------------------------------ - - def _request_new_chunk( - self, - observation: dict[str, Any], - ) -> torch.Tensor: - """Request a new action chunk from the remote policy and validate it.""" - protocol = self.protocol - packed_obs = self.pack_observation_for_server(observation) - - resp = self.remote_client.get_action(packed_obs) - if not isinstance(resp, dict): - raise TypeError(f"Expected dict from get_action, got {type(resp)!r}") - if "action" not in resp: - raise KeyError("Remote response does not contain key 'action' for ActionChunkingClientSidePolicy.") - - raw_chunk = resp["action"] - if not isinstance(raw_chunk, torch.Tensor): - raw_chunk = torch.tensor(raw_chunk, dtype=torch.float32, device=self._device) - else: - raw_chunk = raw_chunk.to(self._device, dtype=torch.float32) - - if raw_chunk.shape[0] != self._num_envs: - raise ValueError(f"Expected batch size {self._num_envs}, got {raw_chunk.shape[0]}") - if raw_chunk.shape[1] < protocol.action_chunk_length: - raise ValueError( - f"Expected at least {protocol.action_chunk_length} actions per chunk, got {raw_chunk.shape[1]}" - ) - if raw_chunk.shape[2] != protocol.action_dim: - raise ValueError(f"Expected action_dim {protocol.action_dim}, got {raw_chunk.shape[2]}") - - return raw_chunk - - def get_action( - self, - env: gym.Env, - observation: gym.spaces.Dict, - ) -> torch.Tensor: - """Return one action per env step, consuming action chunks sequentially.""" - - def fetch_chunk() -> torch.Tensor: - return self._request_new_chunk(observation) - - return self._chunking_state.get_action(fetch_chunk) - - def reset(self, env_ids: torch.Tensor | None = None) -> None: - """Reset client-side chunking state and remote policy state.""" - if env_ids is None: - env_ids = torch.arange( - self._num_envs, - device=self._device, - dtype=torch.long, - ) - - self._chunking_state.reset(env_ids) - - # Reset remote state via ClientSidePolicy. - super().reset(env_ids=env_ids) diff --git a/isaaclab_arena/policy/action_scheduler.py b/isaaclab_arena/policy/action_scheduler.py new file mode 100644 index 000000000..18f91d70b --- /dev/null +++ b/isaaclab_arena/policy/action_scheduler.py @@ -0,0 +1,47 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""Abstract base class for action scheduling strategies.""" + +from __future__ import annotations + +import torch +from abc import ABC, abstractmethod +from collections.abc import Callable + + +class ActionScheduler(ABC): + """Translates raw model chunk outputs into per-step actions. + + The policy calls ``get_action(fetch_chunk_fn)`` at every environment step. + The scheduler controls when to query the model and how to derive a single + action from one or more model outputs. + + Concrete implementations include: + - ``ActionChunkScheduler``: buffer one chunk, step through it sequentially, + refetch when exhausted. + - ``TemporalEnsemblingScheduler``: always query the model, blend overlapping + chunks with exponential decay weights (ACT-style). + - ``PassThroughScheduler``: always query the model, return the first action + in the chunk. + """ + + @abstractmethod + def get_action(self, fetch_chunk_fn: Callable[[], torch.Tensor]) -> torch.Tensor: + """Return one action per env for the current timestep. + + Args: + fetch_chunk_fn: Callable that queries the model and returns a chunk + tensor of shape ``(num_envs, horizon, action_dim)``. + + Returns: + Action tensor of shape ``(num_envs, action_dim)``. + """ + ... + + @abstractmethod + def reset(self, env_ids: torch.Tensor | slice | None = None) -> None: + """Reset scheduler state for the given envs (all envs if None).""" + ... diff --git a/isaaclab_arena/policy/client_side_policy.py b/isaaclab_arena/policy/client_side_policy.py deleted file mode 100644 index 44068dc8d..000000000 --- a/isaaclab_arena/policy/client_side_policy.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import argparse -import torch -from typing import Any - -from isaaclab_arena.policy.policy_base import PolicyBase -from isaaclab_arena.remote_policy.action_protocol import ActionMode, ActionProtocol -from isaaclab_arena.remote_policy.policy_client import PolicyClient -from isaaclab_arena.remote_policy.remote_policy_config import RemotePolicyConfig - - -class ClientSidePolicy(PolicyBase): - """Base class for policies that query a remote policy server. - - Responsibilities: - - Manage RemotePolicyConfig and PolicyClient. - - Handshake with the server via get_init_info(). - - Provide observation packing based on observation_keys. - - Provide shared CLI helpers for remote-related arguments. - - Subclasses: - - Must implement get_action(). - """ - - def __init__(self, config: Any, remote_config: RemotePolicyConfig, protocol_cls: type[ActionProtocol]) -> None: - super().__init__(config=config) - - if protocol_cls.MODE is None: - raise ValueError(f"{protocol_cls.__name__}.MODE must be defined as an ActionMode.") - - self.protocol_cls = protocol_cls - requested_action_mode: ActionMode = protocol_cls.MODE - - self._remote_config = remote_config - self._client = PolicyClient(config=self._remote_config) - - # 1) Ping server to ensure connectivity. - if not self._client.ping(): - raise RuntimeError( - f"Failed to connect to remote policy server at {self._remote_config.host}:{self._remote_config.port}." - ) - - # 2) Handshake: send requested_action_mode, parse response. - init_resp = self._client.get_init_info(requested_action_mode=requested_action_mode.value) - - if not isinstance(init_resp, dict): - raise TypeError(f"Expected dict from get_init_info, got {type(init_resp)!r}") - - status = init_resp.get("status", "error") - if status != "success": - message = init_resp.get("message", "no message") - raise RuntimeError(f"Remote policy get_init_info failed with status='{status}': {message}") - - cfg_dict = init_resp.get("config") - if not isinstance(cfg_dict, dict): - raise TypeError( - f"Remote policy get_init_info must return a 'config' dict inside the response, got {type(cfg_dict)!r}" - ) - - self._protocol: ActionProtocol = self.protocol_cls.from_dict(cfg_dict) - - # ---------------------- properties ---------------------------------- - @property - def protocol(self) -> ActionProtocol: - return self._protocol - - @property - def action_mode(self) -> ActionMode: - return self._protocol.mode - - @property - def action_dim(self) -> int: - return self._protocol.action_dim - - @property - def observation_keys(self) -> list[str]: - return list(self._protocol.observation_keys) - - @property - def remote_config(self) -> RemotePolicyConfig: - return self._remote_config - - @property - def remote_client(self) -> PolicyClient: - return self._client - - @property - def is_remote(self) -> bool: - return True - - # ---------------------- observation packing ------------------------- - @staticmethod - def _get_nested_observation(observation: dict[str, Any], key_path: str) -> Any: - """Get a nested value from a dict using 'a.b.c' path.""" - cur: Any = observation - - for k in key_path.split("."): - cur = cur[k] - return cur - - def pack_observation_for_server( - self, - observation: dict[str, Any], - ) -> dict[str, Any]: - """Pack selected observation entries into a flat CPU dict for the server. - - Uses `self.observation_keys` from ClientSidePolicyConfig and: - - Extracts values using nested key paths. - - Moves torch.Tensor values to CPU numpy arrays. - """ - packed: dict[str, Any] = {} - for key_path in self.observation_keys: - value = self._get_nested_observation(observation, key_path) - if isinstance(value, torch.Tensor): - value = value.detach().cpu().numpy() - packed[key_path] = value - return packed - - def reset(self, env_ids: torch.Tensor | None = None) -> None: - """Optionally reset remote policy state. - - Client-side state should be reset in subclasses. - """ - env_ids_list = None - if env_ids is not None: - env_ids_list = env_ids.detach().cpu().tolist() - self._client.reset(env_ids=env_ids_list, options=None) - - def shutdown_remote(self, kill_server: bool = False) -> None: - """Clean up the remote client and optionally stop the remote server.""" - if kill_server: - try: - self._client.call_endpoint("kill", requires_input=False) - except Exception as exc: - print(f"[ClientSidePolicy] Failed to send kill to remote server: {exc}") - self._client.close() - - # ---------------------- shared CLI helpers -------------------------- - - @staticmethod - def add_remote_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: - """Add shared remote-policy arguments to the parser. - - This should be called from subclass.add_args_to_parser(). - """ - group = parser.add_argument_group( - "Remote Policy", - "Arguments for connecting to a remote policy server.", - ) - group.add_argument( - "--remote_host", - type=str, - default=None, - required=True, - help="Remote policy server host.", - ) - group.add_argument( - "--remote_port", - type=int, - default=5555, - help="Remote policy server port.", - ) - group.add_argument( - "--remote_api_token", - type=str, - default=None, - help="API token for the remote policy server.", - ) - group.add_argument( - "--remote_timeout_ms", - type=int, - default=15000, - help="Timeout (ms) for remote policy requests.", - ) - group.add_argument( - "--remote_kill_on_exit", - action="store_true", - help="If set, send a 'kill' request to the remote policy server when the run finishes.", - ) - return parser - - @staticmethod - def build_remote_config_from_args(args: argparse.Namespace) -> RemotePolicyConfig: - """Construct RemotePolicyConfig from CLI arguments. - - Assumes add_remote_args_to_parser() has been called on the parser. - """ - - return RemotePolicyConfig( - host=args.remote_host, - port=args.remote_port, - api_token=args.remote_api_token, - timeout_ms=args.remote_timeout_ms, - ) diff --git a/isaaclab_arena/policy/policy_base.py b/isaaclab_arena/policy/policy_base.py index bf594ea56..d927c8d72 100644 --- a/isaaclab_arena/policy/policy_base.py +++ b/isaaclab_arena/policy/policy_base.py @@ -86,11 +86,6 @@ def length(self) -> int | None: """Get the length of the policy (for dataset-driven policies).""" pass - @property - def is_remote(self) -> bool: - """Check if policy is run remotely.""" - return False - @staticmethod @abstractmethod def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: diff --git a/isaaclab_arena/remote_policy/__init__.py b/isaaclab_arena/remote_policy/__init__.py deleted file mode 100644 index 6b2258a35..000000000 --- a/isaaclab_arena/remote_policy/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from .action_protocol import ActionMode, ActionProtocol, ChunkingActionProtocol -from .message_serializer import MessageSerializer -from .policy_client import PolicyClient -from .policy_server import PolicyServer -from .remote_policy_config import RemotePolicyConfig -from .server_side_policy import ServerSidePolicy - -__all__ = [ - "RemotePolicyConfig", - "ServerSidePolicy", - "MessageSerializer", - "PolicyClient", - "PolicyServer", - "ActionMode", - "ActionProtocol", - "ChunkingActionProtocol", -] diff --git a/isaaclab_arena/remote_policy/action_protocol.py b/isaaclab_arena/remote_policy/action_protocol.py deleted file mode 100644 index 5d4df2203..000000000 --- a/isaaclab_arena/remote_policy/action_protocol.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from enum import Enum -from typing import Any, ClassVar - - -class ActionMode(str, Enum): - """Action output mode of a policy. - - Currently only CHUNK is used. - Other modes can be added later if needed. - """ - - CHUNK = "chunk" - - -@dataclass -class ActionProtocol(ABC): - """Base handshake/config for a policy's action output. - - - Encapsulates the ActionMode. - - Holds common fields (action_dim, observation_keys). - - Subclasses add mode-specific fields (e.g. chunk_length). - """ - - # Subclasses must override this. - MODE: ClassVar[ActionMode | None] = None - - # Common fields for all modes. - action_dim: int - observation_keys: list[str] - - def __post_init__(self) -> None: - """Validate that subclasses configured MODE properly.""" - mode = type(self).MODE - if mode is None: - raise NotImplementedError(f"{type(self).__name__} must define MODE as an ActionMode.") - - @classmethod - @abstractmethod - def from_dict(cls, data: dict[str, Any]) -> ActionProtocol: - """Build protocol config from server-side config dict.""" - - @abstractmethod - def to_dict(self) -> dict[str, Any]: - """Serialize protocol config to a dict for RPC.""" - - @property - def mode(self) -> ActionMode: - return self.MODE - - -@dataclass -class ChunkingActionProtocol(ActionProtocol): - """ActionProtocol for CHUNK mode. - - action_chunk_length: - Number of actions that the client-side policy consumes from each - chunk at a time during post-processing. - action_horizon: - Total length of the action sequence produced by the model for a - single query. - """ - - MODE: ClassVar[ActionMode] = ActionMode.CHUNK - - # Mode-specific field. - action_chunk_length: int - action_horizon: int - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> ChunkingActionProtocol: - return cls( - action_dim=int(data["action_dim"]), - observation_keys=list(data["observation_keys"]), - action_chunk_length=int(data["action_chunk_length"]), - action_horizon=int(data["action_horizon"]), - ) - - def to_dict(self) -> dict[str, Any]: - return { - "action_mode": self.mode.value, - "action_dim": self.action_dim, - "observation_keys": self.observation_keys, - "action_chunk_length": self.action_chunk_length, - "action_horizon": self.action_horizon, - } diff --git a/isaaclab_arena/remote_policy/message_serializer.py b/isaaclab_arena/remote_policy/message_serializer.py deleted file mode 100644 index 94167cd2d..000000000 --- a/isaaclab_arena/remote_policy/message_serializer.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import io -import numpy as np -from dataclasses import asdict, is_dataclass -from enum import Enum -from typing import Any - -import msgpack - - -class MessageSerializer: - """Msgpack-based serializer for dict-based policy messages. - - Supports: - - standard Python types, - - dataclasses (via to_json_serializable), - - numpy.ndarray (tagged as __ndarray_class__), - - generic binary blobs (tagged as __blob_class__). - """ - - @staticmethod - def to_bytes(data: Any) -> bytes: - """Serialize a Python object to bytes using msgpack.""" - return msgpack.packb(data, default=MessageSerializer._encode_custom) - - @staticmethod - def from_bytes(data: bytes) -> Any: - """Deserialize bytes into Python objects, decoding custom tags.""" - return msgpack.unpackb(data, object_hook=MessageSerializer._decode_custom) - - # ------------------------------------------------------------------ # - # Custom encode / decode - # ------------------------------------------------------------------ # - - @staticmethod - def _decode_custom(obj: Any) -> Any: - """Decode tagged structures created in _encode_custom. - - This function is registered as the `object_hook` for msgpack.unpackb, - so it is called once for every decoded map/dict. - - - If the dict contains a special tag (e.g. '__ndarray_class__' or - '__blob_class__'), it is converted back into the corresponding - high-level type (numpy array, blob, etc.). - - If the dict has no special tag, it is returned unchanged. In that - case the object stays as whatever type msgpack's default decoder - produced (dict, list, int, str, ...). - - Untagged values and non-dict types are therefore handled entirely - by msgpack's built-in decoder. - """ - if not isinstance(obj, dict): - return obj - - # numpy array - if "__ndarray_class__" in obj: - return np.load(io.BytesIO(obj["as_npy"]), allow_pickle=False) - - # generic binary blob - if "__blob_class__" in obj: - return { - "mime": obj.get("mime"), - "data": obj.get("as_bytes"), - } - - # other tagged types can be added here - return obj - - @staticmethod - def _encode_custom(obj: Any) -> Any: - """Encode special Python objects into msgpack-friendly structures.""" - - # numpy array -> npy bytes - if isinstance(obj, np.ndarray): - output = io.BytesIO() - np.save(output, obj, allow_pickle=False) - return {"__ndarray_class__": True, "as_npy": output.getvalue()} - - # generic binary blob: bytes / bytearray - if isinstance(obj, (bytes, bytearray)): - return { - "__blob_class__": True, - "mime": None, - "as_bytes": bytes(obj), - } - - # optional: custom Image/Frame types with to_bytes() and mime attribute - if hasattr(obj, "to_bytes") and hasattr(obj, "mime"): - return { - "__blob_class__": True, - "mime": getattr(obj, "mime"), - "as_bytes": obj.to_bytes(), - } - - # fall back to JSON-serializable representation - return to_json_serializable(obj) - - -def to_json_serializable(obj: Any) -> Any: - """Recursively convert dataclasses and numpy arrays to JSON-serializable format. - - This is useful when encoding configuration objects or metadata. - """ - if is_dataclass(obj) and not isinstance(obj, type): - return to_json_serializable(asdict(obj)) - elif isinstance(obj, np.ndarray): - return obj.tolist() - elif isinstance(obj, np.integer): - return int(obj) - elif isinstance(obj, np.floating): - return float(obj) - elif isinstance(obj, np.bool_): - return bool(obj) - elif isinstance(obj, dict): - return {key: to_json_serializable(value) for key, value in obj.items()} - elif isinstance(obj, (list, tuple, set)): - return [to_json_serializable(item) for item in obj] - elif isinstance(obj, (str, int, float, bool, type(None))): - return obj - elif isinstance(obj, Enum): - return obj.name - else: - # Fallback: convert to string - return str(obj) diff --git a/isaaclab_arena/remote_policy/policy_client.py b/isaaclab_arena/remote_policy/policy_client.py deleted file mode 100644 index 04e25b2e0..000000000 --- a/isaaclab_arena/remote_policy/policy_client.py +++ /dev/null @@ -1,131 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import warnings -from typing import Any - -import zmq - -from isaaclab_arena.remote_policy.message_serializer import MessageSerializer -from isaaclab_arena.remote_policy.remote_policy_config import RemotePolicyConfig - - -class PolicyClient: - """Synchronous client for talking to a PolicyServer over ZeroMQ.""" - - def __init__(self, config: RemotePolicyConfig) -> None: - self._config = config - self._context = zmq.Context() - self._socket = self._context.socket(zmq.REQ) - self._socket.setsockopt(zmq.RCVTIMEO, self._config.timeout_ms) - self._socket.connect(f"tcp://{self._config.host}:{self._config.port}") - - # ------------------------------------------------------------------ # - # Public API - # ------------------------------------------------------------------ # - - def ping(self) -> bool: - """Check if the server is reachable.""" - try: - self.call_endpoint("ping", requires_input=False) - return True - except Exception as exc: - warnings.warn( - f"[PolicyClient] Failed to ping remote policy server at {self._config.host}:{self._config.port}: {exc}" - ) - return False - - def reset(self, env_ids=None, options: dict[str, Any] | None = None) -> Any: - """Reset remote policy state.""" - resp = self.call_endpoint( - endpoint="reset", - data={"env_ids": env_ids, "options": options}, - requires_input=True, - ) - if isinstance(resp, dict): - status = resp.get("status") - if status not in ("reset_success", "ok", "reset_ok", None): - raise RuntimeError(f"Remote reset failed with status={status}, resp={resp}") - return resp - - def kill(self) -> Any: - """Ask remote server to stop main loop.""" - return self.call_endpoint("kill", requires_input=False) - - def get_action( - self, - observation: dict[str, Any], - ) -> dict[str, Any]: - """Send policy_observations and get back policy action dict.""" - payload: dict[str, Any] = {"observation": observation} - - resp = self.call_endpoint( - endpoint="get_action", - data=payload, - requires_input=True, - ) - return resp - - def get_init_info(self, requested_action_mode: str) -> dict[str, Any]: - """Call get_init_info on the server with a requested_action_mode. - - Args: - requested_action_mode: ActionMode value (e.g. "chunk"). - - Returns: - A dict returned by the server, expected to contain: - - "status" - - "message" (optional) - - "config" (on success) - """ - payload = {"requested_action_mode": requested_action_mode} - resp = self.call_endpoint( - "get_init_info", - data=payload, - requires_input=True, - ) - if not isinstance(resp, dict): - raise TypeError(f"Expected dict from get_init_info, got {type(resp)!r}") - return resp - - def set_task_description(self, task_description: str | None) -> dict[str, Any]: - """Send task description to the remote policy.""" - payload: dict[str, Any] = {"task_description": task_description} - resp = self.call_endpoint( - endpoint="set_task_description", - data=payload, - requires_input=True, - ) - if not isinstance(resp, dict): - raise TypeError(f"Expected dict from set_task_description, got {type(resp)!r}") - return resp - - def call_endpoint( - self, - endpoint: str, - data: dict[str, Any] | None = None, - requires_input: bool = True, - ) -> Any: - """Generic RPC helper.""" - request: dict[str, Any] = {"endpoint": endpoint} - if requires_input: - request["data"] = data or {} - if self._config.api_token: - request["api_token"] = self._config.api_token - - self._socket.send(MessageSerializer.to_bytes(request)) - message = self._socket.recv() - response = MessageSerializer.from_bytes(message) - - if isinstance(response, dict) and "error" in response: - raise RuntimeError(f"Server error: {response['error']}") - return response - - def close(self) -> None: - """Close the underlying ZeroMQ socket and context.""" - self._socket.close() - self._context.term() diff --git a/isaaclab_arena/remote_policy/policy_server.py b/isaaclab_arena/remote_policy/policy_server.py deleted file mode 100644 index 27a1d8498..000000000 --- a/isaaclab_arena/remote_policy/policy_server.py +++ /dev/null @@ -1,216 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from collections.abc import Callable -from dataclasses import dataclass -from typing import Any - -import zmq - -from isaaclab_arena.remote_policy.message_serializer import MessageSerializer -from isaaclab_arena.remote_policy.server_side_policy import ServerSidePolicy - - -@dataclass -class EndpointHandler: - handler: Callable[..., Any] - requires_input: bool = True - - -class PolicyServer: - def __init__( - self, - policy: ServerSidePolicy, - host: str = "*", - port: int = 5555, - api_token: str | None = None, - timeout_ms: int = 15000, - ) -> None: - self._policy = policy - self._running = True - self._context = zmq.Context() - self._socket = self._context.socket(zmq.REP) - self._socket.setsockopt(zmq.RCVTIMEO, timeout_ms) - bind_addr = f"tcp://{host}:{port}" - print(f"[PolicyServer] binding on {bind_addr}") - self._socket.bind(bind_addr) - self._api_token = api_token - - self._endpoints: dict[str, EndpointHandler] = {} - self._register_default_endpoints() - - def _register_default_endpoints(self) -> None: - self.register_endpoint("ping", self._handle_ping, requires_input=False) - self.register_endpoint("kill", self._handle_kill, requires_input=False) - self.register_endpoint("get_action", self._handle_get_action, requires_input=True) - self.register_endpoint("reset", self._handle_reset, requires_input=True) - self.register_endpoint("get_init_info", self._handle_get_init_info, requires_input=True) - self.register_endpoint("set_task_description", self._handle_set_task_description, requires_input=True) - print(f"[PolicyServer] registered endpoints: {list(self._endpoints.keys())}") - - def register_endpoint( - self, - name: str, - handler: Callable[..., Any], - requires_input: bool = True, - ) -> None: - self._endpoints[name] = EndpointHandler(handler=handler, requires_input=requires_input) - - def _handle_get_init_info( - self, - requested_action_mode: str, - ) -> dict[str, Any]: - print(f"[PolicyServer] handle get_init_info: requested_action_mode={requested_action_mode!r}") - resp = self._policy.get_init_info(requested_action_mode=requested_action_mode) - if not isinstance(resp, dict): - raise TypeError(f"Policy.get_init_info() must return dict, got {type(resp)!r}") - return resp - - def _handle_set_task_description( - self, - task_description: str | None = None, - **_: Any, - ) -> dict[str, Any]: - print(f"[PolicyServer] handle set_task_description: {task_description!r}") - resp = self._policy.set_task_description(task_description) - if not isinstance(resp, dict): - raise TypeError(f"Policy.set_task_description() must return dict, got {type(resp)!r}") - return resp - - def _handle_ping(self) -> dict[str, Any]: - print("[PolicyServer] handle ping") - return {"status": "ok"} - - def _handle_kill(self) -> dict[str, Any]: - print("[PolicyServer] handle kill -> stopping") - self._running = False - return {"status": "stopping"} - - def _handle_get_action( - self, - observation: dict[str, Any], - options: dict[str, Any] | None = None, - **_: Any, - ) -> dict[str, Any]: - print("[PolicyServer] handle get_action") - if options is not None: - print(f" options keys: {list(options.keys())}") - action, info = self._policy.get_action( - observation=observation, - options=options, - ) - - if not isinstance(action, dict): - raise TypeError(f"Policy.get_action() must return (dict, dict), got action type={type(action)!r}") - if not isinstance(info, dict): - raise TypeError(f"Policy.get_action() must return (dict, dict), got info type={type(info)!r}") - - merged: dict[str, Any] = {} - merged.update(action) - if any(k in merged for k in info.keys()): - raise ValueError(f"Policy info keys conflict with action keys: {set(merged.keys()) & set(info.keys())}") - merged.update(info) - - return merged - - def _handle_reset(self, env_ids=None, options=None, **_: Any) -> dict[str, Any]: - print(f"[PolicyServer] handle reset: env_ids={env_ids}, options={options}") - status: dict[str, Any] = {"status": "reset_success"} - if hasattr(self._policy, "reset"): - resp = self._policy.reset(env_ids=env_ids, reset_options=options) - if isinstance(resp, dict): - status.update(resp) - return status - - def _validate_token(self, request: dict[str, Any]) -> bool: - if self._api_token is None: - return True - ok = request.get("api_token") == self._api_token - if not ok: - print("[PolicyServer] invalid api_token in request") - return ok - - def run(self) -> None: - addr = self._socket.getsockopt_string(zmq.LAST_ENDPOINT) - print(f"[PolicyServer] listening on {addr}, api_token={self._api_token!r}") - while self._running: - try: - raw = self._socket.recv() - print(f"[PolicyServer] received {len(raw)} bytes") - request = MessageSerializer.from_bytes(raw) - - if not isinstance(request, dict): - raise TypeError(f"Expected dict request, got {type(request)!r}") - - print(f"[PolicyServer] request keys: {list(request.keys())}") - - if not self._validate_token(request): - self._socket.send(MessageSerializer.to_bytes({"error": "Unauthorized: invalid api_token"})) - continue - - endpoint = request.get("endpoint", "get_action") - if "endpoint" not in request: - self._socket.send(MessageSerializer.to_bytes({"error": "Missing 'endpoint' in request"})) - continue - - endpoint = request["endpoint"] - - handler = self._endpoints.get(endpoint) - if handler is None: - raise ValueError(f"Unknown endpoint: {endpoint}") - print(f"[PolicyServer] dispatch endpoint='{endpoint}'") - - data = request.get("data", {}) or {} - if not isinstance(data, dict): - raise TypeError(f"Expected dict data, got {type(data)!r}") - - if handler.requires_input: - result = handler.handler(**data) - else: - result = handler.handler() - - resp_bytes = MessageSerializer.to_bytes(result) - print(f"[PolicyServer] sending response ({len(resp_bytes)} bytes)") - self._socket.send(resp_bytes) - except zmq.Again: - # timeout, loop again - continue - except Exception as exc: - import traceback - - print(f"[PolicyServer] Error: {exc}") - print(traceback.format_exc()) - self._socket.send(MessageSerializer.to_bytes({"error": str(exc)})) - - def close(self) -> None: - """Stop the main loop and close ZMQ resources.""" - self._running = False - try: - self._socket.close(0) - except Exception as exc: - print(f"[PolicyServer] socket.close() error: {exc}") - try: - self._context.term() - except Exception as exc: - print(f"[PolicyServer] context.term() error: {exc}") - - @staticmethod - def start( - policy: ServerSidePolicy, - host: str = "*", - port: int = 5555, - api_token: str | None = None, - timeout_ms: int = 15000, - ) -> None: - server = PolicyServer( - policy=policy, - host=host, - port=port, - api_token=api_token, - timeout_ms=timeout_ms, - ) - server.run() diff --git a/isaaclab_arena/remote_policy/remote_policy_config.py b/isaaclab_arena/remote_policy/remote_policy_config.py deleted file mode 100644 index a256f14cb..000000000 --- a/isaaclab_arena/remote_policy/remote_policy_config.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from dataclasses import dataclass - - -@dataclass -class RemotePolicyConfig: - """Configuration for using a remote PolicyServer.""" - - host: str - port: int - api_token: str | None = None - timeout_ms: int = 15000 diff --git a/isaaclab_arena/remote_policy/remote_policy_server_runner.py b/isaaclab_arena/remote_policy/remote_policy_server_runner.py deleted file mode 100644 index c96cd00b6..000000000 --- a/isaaclab_arena/remote_policy/remote_policy_server_runner.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - - -from __future__ import annotations - -import argparse -from importlib import import_module - -from isaaclab_arena.remote_policy.policy_server import PolicyServer -from isaaclab_arena.remote_policy.server_side_policy import ServerSidePolicy - - -def get_policy_cls(policy_type: str) -> type[ServerSidePolicy]: - """Dynamically import and return a ServerSidePolicy subclass. - - The policy_type argument must be a fully qualified Python path of the form: - "package.subpackage.module.ClassName" - """ - print(f"[remote_policy_server_runner] Importing server-side policy from: {policy_type}") - if "." not in policy_type: - raise ValueError( - "policy_type must be a dotted Python import path of the form " - "'module.submodule.ClassName', " - f"got: {policy_type!r}" - ) - module_path, class_name = policy_type.rsplit(".", 1) - module = import_module(module_path) - policy_cls = getattr(module, class_name) - return policy_cls - - -def build_base_parser() -> argparse.ArgumentParser: - """Build the base CLI parser for the remote policy server. - - This parser only contains arguments that are common to all server-side policies. - Policy-specific arguments are added later by the selected ServerSidePolicy subclass. - """ - parser = argparse.ArgumentParser("IsaacLab Arena Remote Policy Server") - - # Generic server options. - parser.add_argument("--host", type=str, default="0.0.0.0") - parser.add_argument("--port", type=int, default=5555) - parser.add_argument("--api_token", type=str, default=None) - parser.add_argument("--timeout_ms", type=int, default=5000) - - # Which ServerSidePolicy implementation to run. - parser.add_argument( - "--policy_type", - type=str, - required=True, - help=( - "Dotted Python path of the server-side policy to run, e.g. " - "'isaaclab_arena_gr00t.policy.gr00t_remote_policy.Gr00tRemoteServerSidePolicy'." - ), - ) - return parser - - -def parse_args() -> argparse.Namespace: - """Parse CLI arguments in two stages. - - 1) Parse only the base arguments to discover which policy class to use. - 2) Let that class extend the parser with its own arguments, then parse again. - """ - # Stage 1: parse base args to get policy_type. - base_parser = build_base_parser() - base_args, _ = base_parser.parse_known_args() - - policy_cls = get_policy_cls(base_args.policy_type) - print(f"[remote_policy_server_runner] Requested server-side policy: {base_args.policy_type} -> {policy_cls}") - - # Stage 2: build a fresh parser, extend it with policy-specific arguments, then parse fully. - full_parser = build_base_parser() - if not hasattr(policy_cls, "add_args_to_parser"): - raise TypeError( - f"Server-side policy class {policy_cls} must define a static 'add_args_to_parser(parser)' method." - ) - full_parser = policy_cls.add_args_to_parser(full_parser) # type: ignore[assignment] - - args = full_parser.parse_args() - return args - - -def main() -> None: - """Entry point for running a remote policy server. - - The script: - 1) Parses CLI arguments in two stages. - 2) Instantiates the requested ServerSidePolicy via its from_args() helper. - 3) Wraps it in a PolicyServer and starts the RPC loop. - """ - args = parse_args() - - policy_cls = get_policy_cls(args.policy_type) - if not hasattr(policy_cls, "from_args"): - raise TypeError(f"Server-side policy class {policy_cls} must define a static 'from_args(args)' method.") - - # Construct the server-side policy from CLI arguments. - policy = policy_cls.from_args(args) # type: ignore[call-arg] - - # Start the RPC server. - server = PolicyServer( - policy=policy, - host=args.host, - port=args.port, - api_token=args.api_token, - timeout_ms=args.timeout_ms, - ) - server.run() - - -if __name__ == "__main__": - main() diff --git a/isaaclab_arena/remote_policy/server_side_policy.py b/isaaclab_arena/remote_policy/server_side_policy.py deleted file mode 100644 index 8b96f3bfb..000000000 --- a/isaaclab_arena/remote_policy/server_side_policy.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright (c) 2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -# Copyright (c) 2025-2026, -# The Isaac Lab Arena Project Developers -# (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import argparse -from abc import ABC, abstractmethod -from typing import Any - -from isaaclab_arena.remote_policy.action_protocol import ActionMode, ActionProtocol - - -class ServerSidePolicy(ABC): - """Base class for server-side remote policies. - - This class defines: - * The protocol- and handshake-related API that the PolicyServer relies on. - * A minimal configuration hook via ``config_class`` and ``from_dict``. - * A CLI construction pattern via ``add_args_to_parser`` and ``from_args``, - mirroring the design of :class:`isaaclab_arena.policy.policy_base.PolicyBase` - on the client side. - - Concrete server-side policies (e.g. GR00T-based ones) should: - * Implement ``_build_protocol()`` and the core RPC methods. - * Optionally define a dataclass as ``config_class``. - * Implement ``add_args_to_parser(parser)`` and ``from_args(args)`` - so they can be instantiated directly from command-line arguments. - """ - - # Optional: subclasses can define this to enable from_dict() - config_class: type | None = None - - def __init__(self, config: Any | None = None) -> None: - """Base constructor for server-side policies. - - Args: - config: Optional configuration object (for example, a dataclass - instance). Subclasses are free to interpret this as needed. - """ - self.config = config - self._protocol: ActionProtocol | None = None - self._task_description: str | None = None - - # ------------------------------------------------------------------ - # Config helpers (mirroring PolicyBase.from_dict) - # ------------------------------------------------------------------ - - @classmethod - def from_dict(cls, config_dict: dict[str, Any]) -> ServerSidePolicy: - """Create a policy instance from a configuration dictionary. - - Path: dict -> ConfigDataclass -> Policy instance - - This mirrors :meth:`PolicyBase.from_dict` on the client side. - """ - if cls.config_class is None: - raise NotImplementedError(f"{cls.__name__} must define 'config_class' to use from_dict().") - - config = cls.config_class(**config_dict) # type: ignore[misc] - return cls(config) # type: ignore[call-arg] - - # ------------------------------------------------------------------ - # Protocol / handshake API - # ------------------------------------------------------------------ - - @abstractmethod - def _build_protocol(self) -> ActionProtocol: - """Subclasses must build and return an ActionProtocol instance.""" - raise NotImplementedError - - @property - def protocol(self) -> ActionProtocol: - """Return the ActionProtocol associated with this policy. - - The protocol is lazily constructed on first access via ``_build_protocol()``. - """ - if self._protocol is None: - self._protocol = self._build_protocol() - if self._protocol.mode is None: - raise ValueError(f"{self.__class__.__name__} has an ActionProtocol with mode=None, which is not allowed.") - return self._protocol - - def get_init_info(self, requested_action_mode: str) -> dict[str, Any]: - """Handle the initial handshake with the client. - - Checks that the requested action mode is valid and supported by - this policy's ActionProtocol, and returns either an error status - or the protocol configuration as a plain dictionary. - """ - proto = self.protocol - - try: - requested_mode_enum = ActionMode(requested_action_mode) - except ValueError: - return { - "status": "invalid_action_mode", - "message": f"Requested action_mode={requested_action_mode!r} is invalid.", - } - - if requested_mode_enum is not proto.mode: - return { - "status": "unsupported_action_mode", - "message": ( - f"Requested action_mode={requested_mode_enum.value!r} " - "is not supported by this policy. " - f"Supported: {proto.mode.value!r}." - ), - } - - return { - "status": "success", - "config": proto.to_dict(), - } - - # ------------------------------------------------------------------ - # Core RPC methods (to be used by PolicyServer) - # ------------------------------------------------------------------ - - @abstractmethod - def get_action( - self, - observation: dict[str, Any], - ) -> dict[str, Any]: - """Compute one or more actions given an observation payload. - - Args: - observation: Flat observation dictionary received from the client. - - Returns: - A dictionary that must contain at least an ``"action"`` entry - whose structure is compatible with the negotiated ActionProtocol. - """ - raise NotImplementedError - - def reset(self) -> None: - """Reset the policy state. - - Subclasses may override this if they maintain per-environment or - global state that needs to be cleared between episodes. - """ - ... - - def set_task_description( - self, - task_description: str | None, - ) -> dict[str, Any]: - """Set the task description and return a small status/config payload. - - The default implementation stores the description locally and - echoes it back. Subclasses can override this to perform additional - updates or validation. - """ - self._task_description = task_description - return {"task_description": self._task_description or ""} - - # ------------------------------------------------------------------ - # Shared helpers - # ------------------------------------------------------------------ - - def unpack_observation(self, flat_obs: dict[str, Any]) -> dict[str, Any]: - """Convert a flat dotted-key observation dict into a nested dict. - - For example, a key ``"camera_obs.pov.rgb"`` becomes - ``nested["camera_obs"]["pov"]["rgb"]``. - """ - nested: dict[str, Any] = {} - for key_path, value in flat_obs.items(): - cur = nested - parts = key_path.split(".") - for k in parts[:-1]: - cur = cur.setdefault(k, {}) - cur[parts[-1]] = value - return nested - - # ------------------------------------------------------------------ - # CLI helpers (to mirror PolicyBase.add_args_to_parser / from_args) - # ------------------------------------------------------------------ - - @staticmethod - @abstractmethod - def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: - """Add policy-specific CLI arguments to the parser. - - Server-side policies are expected to implement this so that - :mod:`remote_policy_server_runner` can delegate CLI argument - definitions to the selected policy class. - """ - raise NotImplementedError("ServerSidePolicy subclasses must implement add_args_to_parser().") - - @staticmethod - @abstractmethod - def from_args(args: argparse.Namespace) -> ServerSidePolicy: - """Construct a server-side policy instance from CLI arguments. - - This mirrors the ``from_args(args)`` pattern used by client-side - policies deriving from :class:`PolicyBase`. - """ - raise NotImplementedError("ServerSidePolicy subclasses must implement from_args(args).") diff --git a/isaaclab_arena/tests/test_action_chunking_client.py b/isaaclab_arena/tests/test_action_chunking_client.py deleted file mode 100644 index aab81f67f..000000000 --- a/isaaclab_arena/tests/test_action_chunking_client.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright (c) 2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -# Copyright (c) 2025-2026, -# The Isaac Lab Arena Project Developers -# (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import numpy as np -import threading -import time -from typing import Any - -import pytest - -from isaaclab_arena.remote_policy.action_protocol import ActionProtocol, ChunkingActionProtocol -from isaaclab_arena.remote_policy.policy_server import PolicyServer -from isaaclab_arena.remote_policy.server_side_policy import ServerSidePolicy -from isaaclab_arena.tests.utils.constants import TestConstants -from isaaclab_arena.tests.utils.subprocess import run_subprocess - -HEADLESS = True -NUM_STEPS = 2 -HOST = "127.0.0.1" -PORT = 5563 # test-only port, avoid conflicts - - -# ====================================================================================== -# Dummy server-side policy using the real ChunkingActionProtocol -# ====================================================================================== - - -class _DummyChunkingServerPolicy(ServerSidePolicy): - """Server-side policy that uses ChunkingActionProtocol and returns fixed chunks.""" - - def __init__(self, action_dim: int = 50, chunk_length: int = 4) -> None: - super().__init__(config=None) - self._action_dim = action_dim - self._chunk_length = chunk_length - self._counter = 0 - - def _build_protocol(self) -> ActionProtocol: - return ChunkingActionProtocol( - action_dim=self._action_dim, - observation_keys=["policy.robot_joint_pos"], - action_chunk_length=self._chunk_length, - action_horizon=self._chunk_length, - ) - - def get_action( - self, - observation: dict[str, Any], - options: dict[str, Any] | None = None, - ) -> dict[str, Any]: - """Return (batch, chunk_length, action_dim) array with a simple pattern. - - The options argument is accepted to match PolicyServer._handle_get_action, - but is not used in this dummy implementation. - """ - first_key = next(iter(observation.keys())) - batch = int(np.shape(observation[first_key])[0]) - - base_value = float(self._counter) - self._counter += 1 - - chunk = np.full( - (batch, self._chunk_length, self._action_dim), - fill_value=base_value, - dtype=np.float32, - ) - # IMPORTANT: return a dict containing "action" and "info" - return {"action": chunk}, {} - - # NEW: match what PolicyServer._handle_reset expects - def reset(self, env_ids: list[int] | None = None, reset_options: dict[str, Any] | None = None) -> dict[str, Any]: - """Reset policy state for the given environment ids. - - The implementation here is trivial; it just returns an OK status - and does not keep any per-env state. - """ - return {"status": "ok"} - - @staticmethod - def add_args_to_parser(parser: Any) -> Any: - return parser - - @staticmethod - def from_args(args: Any) -> _DummyChunkingServerPolicy: - return _DummyChunkingServerPolicy() - - -# ====================================================================================== -# Helper to start/stop a PolicyServer in background -# ====================================================================================== - - -@pytest.fixture -def running_dummy_chunking_server() -> PolicyServer: - """Start a PolicyServer with _DummyChunkingServerPolicy on localhost.""" - policy = _DummyChunkingServerPolicy(chunk_length=4) - server = PolicyServer( - policy=policy, - host=HOST, - port=PORT, - api_token=None, - timeout_ms=2_000, - ) - - thread = threading.Thread(target=server.run, daemon=True) - thread.start() - # Give the server a short time to bind and start. - time.sleep(0.2) - - try: - yield server - finally: - # Ask the server to stop and wait for the thread. - server.running = False - thread.join(timeout=5.0) - - if hasattr(server, "close"): - server.close() - assert not thread.is_alive() - - -# ====================================================================================== -# Helper to call policy_runner (same style as existing tests) -# ====================================================================================== - - -def _run_policy_runner_with_action_chunking_client() -> None: - """Run policy_runner.py with ActionChunkingClientSidePolicy in remote mode. - - The remote host/port are set to the dummy server started by the fixture. - """ - args: list[str] = [ - TestConstants.python_path, - f"{TestConstants.evaluation_dir}/policy_runner.py", - ] - - args.extend([ - "--policy_type", - "isaaclab_arena.policy.action_chunking_client.ActionChunkingClientSidePolicy", - ]) - - args.extend([ - "--remote_host", - HOST, - "--remote_port", - str(PORT), - "--remote_kill_on_exit", - ]) - - args.extend(["--num_steps", str(NUM_STEPS)]) - if HEADLESS: - args.append("--headless") - - args.append("galileo_g1_locomanip_pick_and_place") - args.extend(["--embodiment", "g1_wbc_joint"]) - args.extend(["--object", "brown_box"]) - - run_subprocess(args) - - -# ====================================================================================== -# Test -# ====================================================================================== - - -@pytest.mark.with_subprocess -def test_action_chunking_client_end_to_end_with_dummy_chunking_server( - running_dummy_chunking_server: PolicyServer, -) -> None: - """End-to-end test: dummy chunking server + ActionChunkingClientSidePolicy + policy_runner. - - This verifies that: - - The dummy PolicyServer using ChunkingActionProtocol can be reached on HOST:PORT. - - ActionChunkingClientSidePolicy can connect to it via policy_runner.py. - - The process exits successfully for a short rollout. - """ - _run_policy_runner_with_action_chunking_client() diff --git a/isaaclab_arena/tests/test_policy_client.py b/isaaclab_arena/tests/test_policy_client.py deleted file mode 100644 index 42fc18b17..000000000 --- a/isaaclab_arena/tests/test_policy_client.py +++ /dev/null @@ -1,142 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -import threading -import time -from typing import Any - -import pytest -import zmq - -from isaaclab_arena.remote_policy.message_serializer import MessageSerializer -from isaaclab_arena.remote_policy.policy_client import PolicyClient -from isaaclab_arena.remote_policy.remote_policy_config import RemotePolicyConfig - - -class _DummyServer: - """Minimal test server that emulates a subset of PolicyServer behavior. - - It only understands the endpoints used by PolicyClient tests and always - responds with well-formed msgpack-encoded dictionaries. - """ - - def __init__(self, host: str = "127.0.0.1", port: int = 5557, api_token: str | None = None) -> None: - self._host = host - self._port = port - self._api_token = api_token - self._context = zmq.Context() - self._socket = self._context.socket(zmq.REP) - self._running = False - self._thread: threading.Thread | None = None - - def start(self) -> None: - """Start the server loop in a background thread.""" - bind_addr = f"tcp://{self._host}:{self._port}" - self._socket.bind(bind_addr) - self._running = True - self._thread = threading.Thread(target=self._loop, daemon=True) - self._thread.start() - - def stop(self) -> None: - """Stop the server loop and close the socket.""" - self._running = False - if self._thread is not None: - self._thread.join(timeout=5.0) - self._socket.close(0) - self._context.term() - - def _loop(self) -> None: - """Event loop that receives one request and sends one response.""" - while self._running: - try: - message = self._socket.recv(flags=zmq.NOBLOCK) - except zmq.Again: - time.sleep(0.01) - continue - - request: dict[str, Any] = MessageSerializer.from_bytes(message) - - # Real code uses "api_token" on the wire. - if self._api_token is not None: - if request.get("api_token") != self._api_token: - response: dict[str, Any] = {"error": "invalid apitoken"} - self._socket.send(MessageSerializer.to_bytes(response)) - continue - - endpoint = request.get("endpoint", "") - data = request.get("data", {}) or {} - - if endpoint == "get_action": - # Return a minimal valid action payload; client expects a dict. - resp = {"action": [[0.0, 1.0, 2.0]], "info": {"dummy": True}} - elif endpoint == "get_init_info": - resp = {"obs_keys": ["rgb", "depth"], "action_dim": 3} - elif endpoint == "set_task_description": - desc = data.get("task_description", "") - resp = {"status": "ok", "echo": desc} - elif endpoint == "ping": - resp = {"status": "alive"} - else: - resp = {"error": f"unknown endpoint {endpoint!r}"} - - self._socket.send(MessageSerializer.to_bytes(resp)) - - -@pytest.fixture -def dummy_server() -> _DummyServer: - """Fixture that starts a dummy server and tears it down after the test.""" - server = _DummyServer(host="127.0.0.1", port=5557, api_token="SECRET") - server.start() - # Give the background thread a short time to bind the socket. - time.sleep(0.1) - try: - yield server - finally: - server.stop() - - -def test_policy_client_call_endpoint_and_get_action(dummy_server: _DummyServer) -> None: - """PolicyClient should be able to call endpoints and parse responses.""" - config = RemotePolicyConfig(host="127.0.0.1", port=5557, api_token="SECRET", timeout_ms=2000) - client = PolicyClient(config=config) - - # Test ping endpoint without input. - resp = client.call_endpoint(endpoint="ping", data=None, requires_input=False) - assert isinstance(resp, dict) - assert resp.get("status") == "alive" - - # Test get_action endpoint with dummy observation. - action_resp = client.get_action({ - "rgb": "dummy", # Content does not matter for this dummy server. - }) - assert isinstance(action_resp, dict) - assert "action" in action_resp - assert "info" in action_resp - - action = action_resp["action"] - assert isinstance(action, list) - assert len(action) == 1 - assert len(action[0]) == 3 - - client.close() - - -def test_policy_client_get_init_info_and_set_task_description(dummy_server: _DummyServer) -> None: - """get_init_info and set_task_description should return dictionaries.""" - config = RemotePolicyConfig(host="127.0.0.1", port=5557, api_token="SECRET", timeout_ms=2000) - client = PolicyClient(config=config) - - init_info = client.get_init_info({"dummy": True}) - assert isinstance(init_info, dict) - assert "obs_keys" in init_info - assert "action_dim" in init_info - - desc = "open the microwave door" - status = client.set_task_description(desc) - assert isinstance(status, dict) - assert status.get("status") == "ok" - assert status.get("echo") == desc - - client.close() diff --git a/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py b/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py index 8186ff709..32af9f56d 100644 --- a/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py +++ b/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py @@ -14,9 +14,15 @@ from typing import Any from gr00t.data.embodiment_tags import EmbodimentTag +# NOTE: Gr00tPolicy is a heavy import (transformers, model loading). This local policy +# loads the model in-process and requires the full GR00T ML stack. For production +# evaluation, use Gr00tRemoteClosedloopPolicy which delegates inference to a remote +# GR00T server and has no heavy dependencies. This local policy may be removed in a +# future release if all workflows move to the remote path. from gr00t.policy.gr00t_policy import Gr00tPolicy -from isaaclab_arena.policy.action_chunking import ActionChunkingState +from isaaclab_arena.policy.action_chunking import ActionChunkScheduler +from isaaclab_arena.policy.action_scheduler import ActionScheduler from isaaclab_arena.policy.policy_base import PolicyBase from isaaclab_arena.utils.multiprocess import get_local_rank, get_world_size from isaaclab_arena_gr00t.policy.config.gr00t_closedloop_policy_config import Gr00tClosedloopPolicyConfig, TaskMode @@ -27,7 +33,6 @@ compute_action_dim, extract_obs_numpy_from_torch, load_gr00t_joint_configs, - load_gr00t_policy_from_config, ) from isaaclab_arena_gr00t.utils.eagle_config_compat import apply_eagle_config_compat from isaaclab_arena_gr00t.utils.io_utils import ( @@ -68,7 +73,7 @@ class Gr00tClosedloopPolicy(PolicyBase): name = "gr00t_closedloop" config_class = Gr00tClosedloopPolicyArgs - def __init__(self, config: Gr00tClosedloopPolicyArgs): + def __init__(self, config: Gr00tClosedloopPolicyArgs, action_scheduler: ActionScheduler | None = None): """Initialize Gr00tClosedloopPolicy from a configuration dataclass.""" super().__init__(config) @@ -76,7 +81,7 @@ def __init__(self, config: Gr00tClosedloopPolicyArgs): self.policy_config: Gr00tClosedloopPolicyConfig = create_config_from_yaml( config.policy_config_yaml_path, Gr00tClosedloopPolicyConfig ) - self.policy: Gr00tPolicy = load_gr00t_policy_from_config(self.policy_config) + self.policy: Gr00tPolicy = self._load_policy() # Basic attributes self.num_envs = config.num_envs @@ -105,15 +110,16 @@ def __init__(self, config: Gr00tClosedloopPolicyArgs): self.action_dim = compute_action_dim(self.task_mode, self.robot_action_joints_config) self.action_chunk_length = self.policy_config.action_chunk_length - # Shared chunking state (unified with remote ActionChunkingClientSidePolicy) - self._chunking_state = ActionChunkingState( - num_envs=self.num_envs, - action_chunk_length=self.action_chunk_length, - action_horizon=self.policy_config.action_horizon, - action_dim=self.action_dim, - device=self.device, - dtype=torch.float, - ) + if action_scheduler is None: + action_scheduler = ActionChunkScheduler( + num_envs=self.num_envs, + action_chunk_length=self.action_chunk_length, + action_horizon=self.policy_config.action_horizon, + action_dim=self.action_dim, + device=self.device, + dtype=torch.float, + ) + self._action_scheduler = action_scheduler # task description of task being evaluated. It will be set by the task being evaluated. self.task_description: str | None = None @@ -156,16 +162,18 @@ def load_sim_action_joints_config(self, action_config_path: Path) -> dict[str, A """Load the simulation action joint config from the data config.""" return load_robot_joints_config_from_yaml(action_config_path) - def load_policy(self) -> Gr00tPolicy: - """Load the dataset, whose iterator will be used as the policy.""" - assert Path( - self.policy_config.model_path - ).exists(), f"Dataset path {self.policy_config.dataset_path} does not exist" + def _load_policy(self) -> Gr00tPolicy: + """Load the GR00T policy model in-process.""" + model_path = self.policy_config.model_path + is_hf_id = bool(model_path and "/" in model_path and not model_path.startswith(("/", "."))) + assert ( + Path(model_path).exists() or is_hf_id + ), f"Model path {model_path} does not exist and is not a HuggingFace model id" apply_eagle_config_compat() return Gr00tPolicy( - model_path=self.policy_config.model_path, + model_path=model_path, embodiment_tag=EmbodimentTag[self.policy_config.embodiment_tag], device=self.device, strict=True, @@ -216,7 +224,7 @@ def get_action(self, env: gym.Env, observation: dict[str, Any]) -> torch.Tensor: def fetch_chunk() -> torch.Tensor: return self.get_action_chunk(observation, self.policy_config.pov_cam_name_sim) - return self._chunking_state.get_action(fetch_chunk) + return self._action_scheduler.get_action(fetch_chunk) def get_action_chunk( self, observation: dict[str, Any], camera_names: list[str] | str = "robot_head_cam_rgb" @@ -249,4 +257,4 @@ def reset(self, env_ids: torch.Tensor | None = None): env_ids = slice(None) # placeholder for future reset options from GR00T repo self.policy.reset() - self._chunking_state.reset(env_ids) + self._action_scheduler.reset(env_ids) diff --git a/isaaclab_arena_gr00t/policy/gr00t_core.py b/isaaclab_arena_gr00t/policy/gr00t_core.py index c53241b09..abbf6e8f5 100644 --- a/isaaclab_arena_gr00t/policy/gr00t_core.py +++ b/isaaclab_arena_gr00t/policy/gr00t_core.py @@ -35,7 +35,6 @@ from typing import Any from gr00t.data.embodiment_tags import EmbodimentTag -from gr00t.policy.gr00t_policy import Gr00tPolicy from isaaclab_arena_g1.g1_whole_body_controller.wbc_policy.policy.policy_constants import ( NUM_BASE_HEIGHT_CMD, @@ -82,32 +81,6 @@ class Gr00tBasePolicyArgs: # --------------------------------------------------------------------------- # -def load_gr00t_policy_from_config(policy_config: Gr00tClosedloopPolicyConfig) -> Gr00tPolicy: - """Instantiate a GR00T policy from the closed-loop config. - - Args: - policy_config: Loaded closed-loop config (model path, embodiment, device). - - Returns: - Loaded ``Gr00tPolicy`` on the configured device. - - Raises: - AssertionError: If ``policy_config.model_path`` does not exist. - """ - model_path = policy_config.model_path - # HuggingFace Hub repo IDs use "owner/repo" format (e.g. "nvidia/GR00T-N1.6-DROID"). - is_hf_id = bool(model_path and "/" in model_path and not model_path.startswith(("/", "."))) - assert ( - Path(model_path).exists() or is_hf_id - ), f"Model path {model_path} does not exist and is not a HuggingFace model id" - return Gr00tPolicy( - model_path=policy_config.model_path, - embodiment_tag=EmbodimentTag[policy_config.embodiment_tag], - device=policy_config.policy_device, - strict=True, - ) - - def load_gr00t_joint_configs( policy_config: Gr00tClosedloopPolicyConfig, ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: diff --git a/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py b/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py new file mode 100644 index 000000000..fdea97fdc --- /dev/null +++ b/isaaclab_arena_gr00t/policy/gr00t_remote_closedloop_policy.py @@ -0,0 +1,240 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +"""GR00T remote closed-loop policy using GR00T's native PolicyClient. + +This policy connects to a GR00T policy server (launched via +``gr00t/eval/run_gr00t_server.py``) and reuses the same observation/action +translation pipeline as the local ``Gr00tClosedloopPolicy``. +""" + +from __future__ import annotations + +import argparse +import gymnasium as gym +import torch +from dataclasses import dataclass, field +from typing import Any + +from gr00t.policy.server_client import PolicyClient as Gr00tPolicyClient + +from isaaclab_arena.policy.action_chunking import ActionChunkScheduler +from isaaclab_arena.policy.action_scheduler import ActionScheduler +from isaaclab_arena.policy.policy_base import PolicyBase +from isaaclab_arena.utils.multiprocess import get_local_rank, get_world_size +from isaaclab_arena_gr00t.policy.config.gr00t_closedloop_policy_config import Gr00tClosedloopPolicyConfig, TaskMode +from isaaclab_arena_gr00t.policy.gr00t_core import ( + Gr00tBasePolicyArgs, + build_gr00t_action_tensor, + build_gr00t_policy_observations, + compute_action_dim, + extract_obs_numpy_from_torch, + load_gr00t_joint_configs, +) +from isaaclab_arena_gr00t.utils.io_utils import ( + create_config_from_yaml, + load_gr00t_modality_config_from_file, +) + + +@dataclass +class Gr00tRemoteClosedloopPolicyArgs(Gr00tBasePolicyArgs): + """Configuration for Gr00tRemoteClosedloopPolicy. + + Inherits policy_config_yaml_path and policy_device from Gr00tBasePolicyArgs, + and adds remote server connection parameters and num_envs. + """ + + num_envs: int = field(default=1, metadata={"help": "Number of environments to simulate"}) + remote_host: str = field(default="localhost", metadata={"help": "GR00T policy server hostname"}) + remote_port: int = field(default=5555, metadata={"help": "GR00T policy server port"}) + remote_api_token: str | None = field(default=None, metadata={"help": "API token for the policy server"}) + + @classmethod + def from_cli_args(cls, args: argparse.Namespace) -> Gr00tRemoteClosedloopPolicyArgs: + """Create configuration from parsed CLI arguments.""" + return cls( + policy_config_yaml_path=args.policy_config_yaml_path, + policy_device=args.policy_device, + num_envs=args.num_envs, + remote_host=args.remote_host, + remote_port=args.remote_port, + remote_api_token=getattr(args, "remote_api_token", None), + ) + + +class Gr00tRemoteClosedloopPolicy(PolicyBase): + """GR00T closed-loop policy that delegates inference to a remote GR00T server. + + Uses GR00T's native ``PolicyClient`` (from ``gr00t.policy.server_client``) + to communicate with a GR00T policy server. The observation/action translation + pipeline is identical to the local ``Gr00tClosedloopPolicy``. + + Server side (run independently): + python gr00t/eval/run_gr00t_server.py \\ + --model_path nvidia/GR00T-N1.6-DROID \\ + --embodiment_tag OXE_DROID --device cuda --host 0.0.0.0 --port 5555 + + Client side (Arena evaluation): + python policy_runner.py \\ + --policy_type isaaclab_arena_gr00t.policy.gr00t_remote_closedloop_policy.Gr00tRemoteClosedloopPolicy \\ + --policy_config_yaml_path isaaclab_arena_gr00t/policy/config/droid_manip_gr00t_closedloop_config.yaml \\ + --remote_host 10.0.0.1 --remote_port 5555 \\ + --enable_cameras --num_episodes 5 \\ + pick_and_place_maple_table --embodiment droid_abs_joint_pos + """ + + name = "gr00t_remote_closedloop" + config_class = Gr00tRemoteClosedloopPolicyArgs + + def __init__(self, config: Gr00tRemoteClosedloopPolicyArgs, action_scheduler: ActionScheduler | None = None): + super().__init__(config) + + # Policy config (for obs/action translation — no model loading) + self.policy_config: Gr00tClosedloopPolicyConfig = create_config_from_yaml( + config.policy_config_yaml_path, Gr00tClosedloopPolicyConfig + ) + self.num_envs = config.num_envs + self.device = config.policy_device + if get_world_size() > 1 and "cuda" in self.device: + self.device = f"cuda:{get_local_rank()}" + self.task_mode = TaskMode(self.policy_config.task_mode_name) + + # Joint configs (for sim↔policy joint remapping) + ( + self.policy_joints_config, + self.robot_action_joints_config, + self.robot_state_joints_config, + ) = load_gr00t_joint_configs(self.policy_config) + + # Modality config (for building GR00T observation dicts) + self.modality_configs = load_gr00t_modality_config_from_file( + self.policy_config.modality_config_path, + self.policy_config.embodiment_tag, + ) + + # Action / chunk shapes + self.action_dim = compute_action_dim(self.task_mode, self.robot_action_joints_config) + self.action_chunk_length = self.policy_config.action_chunk_length + + if action_scheduler is None: + action_scheduler = ActionChunkScheduler( + num_envs=self.num_envs, + action_chunk_length=self.action_chunk_length, + action_horizon=self.policy_config.action_horizon, + action_dim=self.action_dim, + device=self.device, + dtype=torch.float, + ) + self._action_scheduler = action_scheduler + + # Connect to GR00T's native policy server + self._client = Gr00tPolicyClient( + host=config.remote_host, + port=config.remote_port, + api_token=config.remote_api_token, + strict=False, + ) + if not self._client.ping(): + raise ConnectionError( + f"Cannot reach GR00T policy server at {config.remote_host}:{config.remote_port}" + ) + + self.task_description: str | None = None + + # ---------------------- CLI helpers ------------------- + + @staticmethod + def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + group = parser.add_argument_group( + "Gr00t Remote Closedloop Policy", + "Arguments for GR00T remote closed-loop policy evaluation.", + ) + group.add_argument( + "--policy_config_yaml_path", + type=str, + required=True, + help="Path to the Gr00t closedloop policy config YAML file", + ) + group.add_argument( + "--policy_device", + type=str, + default="cuda", + help="Device for Arena-side tensor operations (default: cuda)", + ) + group.add_argument("--remote_host", type=str, default="localhost", help="GR00T policy server hostname") + group.add_argument("--remote_port", type=int, default=5555, help="GR00T policy server port") + group.add_argument("--remote_api_token", type=str, default=None, help="API token for the policy server") + return parser + + @staticmethod + def from_args(args: argparse.Namespace) -> Gr00tRemoteClosedloopPolicy: + config = Gr00tRemoteClosedloopPolicyArgs.from_cli_args(args) + return Gr00tRemoteClosedloopPolicy(config) + + # ---------------------- Policy interface ------------------- + + def set_task_description(self, task_description: str | None) -> str: + if task_description is None: + task_description = self.policy_config.language_instruction + if not task_description: + raise ValueError( + "No language instruction provided. Set 'language_instruction' in the job config, " + "pass --language_instruction on the CLI, or define 'task_description' on the task class." + ) + self.task_description = task_description + return self.task_description + + def get_action(self, env: gym.Env, observation: dict[str, Any]) -> torch.Tensor: + def fetch_chunk() -> torch.Tensor: + return self._get_action_chunk(observation, self.policy_config.pov_cam_name_sim) + + return self._action_scheduler.get_action(fetch_chunk) + + def _get_action_chunk( + self, observation: dict[str, Any], camera_names: list[str] | str = "robot_head_cam_rgb" + ) -> torch.Tensor: + """Get an action chunk from the remote GR00T server. + + Same pipeline as Gr00tClosedloopPolicy.get_action_chunk(), but calls + GR00T's PolicyClient instead of a local Gr00tPolicy. + """ + if isinstance(camera_names, str): + camera_names = [camera_names] + + # 1. Reuse the same obs translation as local policy + assert self.task_description is not None, "Task description is not set" + rgb_list_np, joint_pos_sim_np = extract_obs_numpy_from_torch(nested_obs=observation, camera_names=camera_names) + policy_observations = build_gr00t_policy_observations( + rgb_list_np=rgb_list_np, + joint_pos_sim_np=joint_pos_sim_np, + task_description=self.task_description, + policy_config=self.policy_config, + robot_state_joints_config=self.robot_state_joints_config, + policy_joints_config=self.policy_joints_config, + modality_configs=self.modality_configs, + ) + + # 2. Call GR00T's own client + robot_action_policy, _ = self._client.get_action(policy_observations) + + # 3. Reuse the same action translation as local policy + action_tensor = build_gr00t_action_tensor( + robot_action_policy=robot_action_policy, + task_mode=self.task_mode, + policy_joints_config=self.policy_joints_config, + robot_action_joints_config=self.robot_action_joints_config, + device=self.device, + embodiment_tag=self.policy_config.embodiment_tag, + ) + + assert action_tensor.shape[0] == self.num_envs and action_tensor.shape[1] >= self.action_chunk_length + return action_tensor + + def reset(self, env_ids: torch.Tensor | None = None): + if env_ids is None: + env_ids = slice(None) + self._client.reset() + self._action_scheduler.reset(env_ids) diff --git a/isaaclab_arena_gr00t/policy/gr00t_remote_policy.py b/isaaclab_arena_gr00t/policy/gr00t_remote_policy.py deleted file mode 100644 index 3dbbed49b..000000000 --- a/isaaclab_arena_gr00t/policy/gr00t_remote_policy.py +++ /dev/null @@ -1,227 +0,0 @@ -# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import argparse -import os -import sys -from dataclasses import dataclass -from typing import Any - -# Same as local (Isaac Sim): GR00T deps appended via .pth; do not prepend so system packages (numpy, cv2, tokenizers 0.21) are used first. -_GROOT_DEPS_DIR = os.environ.get("GROOT_DEPS_DIR") -if _GROOT_DEPS_DIR and _GROOT_DEPS_DIR not in sys.path: - sys.path.append(_GROOT_DEPS_DIR) - -from gr00t.policy.gr00t_policy import Gr00tPolicy - -from isaaclab_arena.remote_policy.action_protocol import ChunkingActionProtocol -from isaaclab_arena.remote_policy.server_side_policy import ServerSidePolicy -from isaaclab_arena_gr00t.policy.config.gr00t_closedloop_policy_config import Gr00tClosedloopPolicyConfig, TaskMode -from isaaclab_arena_gr00t.policy.gr00t_core import ( - Gr00tBasePolicyArgs, - build_gr00t_action_tensor, - build_gr00t_policy_observations, - compute_action_dim, - extract_obs_numpy_from_packed, - load_gr00t_joint_configs, - load_gr00t_policy_from_config, -) -from isaaclab_arena_gr00t.utils.io_utils import create_config_from_yaml, load_gr00t_modality_config_from_file, to_numpy - - -@dataclass -class Gr00tRemotePolicyArgs(Gr00tBasePolicyArgs): - """Configuration for Gr00tRemoteServerSidePolicy. - - Reuses policy_config_yaml_path and policy_device from the base. - """ - - @classmethod - def from_cli_args(cls, args: argparse.Namespace) -> Gr00tRemotePolicyArgs: - return cls( - policy_config_yaml_path=args.policy_config_yaml_path, - policy_device=args.policy_device, - ) - - -class Gr00tRemoteServerSidePolicy(ServerSidePolicy): - """Server-side wrapper around Gr00tPolicy.""" - - config_class = Gr00tRemotePolicyArgs - - def __init__(self, config: Gr00tRemotePolicyArgs) -> None: - super().__init__(config) - - print(f"[Gr00tRemoteServerSidePolicy] loading config from: {config.policy_config_yaml_path}") - self.policy_config: Gr00tClosedloopPolicyConfig = create_config_from_yaml( - config.policy_config_yaml_path, Gr00tClosedloopPolicyConfig - ) - print( - "[Gr00tRemoteServerSidePolicy] config:\n" - f" model_path = {self.policy_config.model_path}\n" - f" embodiment_tag = {self.policy_config.embodiment_tag}\n" - f" task_mode_name = {self.policy_config.task_mode_name}\n" - f" action_horizon = {self.policy_config.action_horizon}\n" - f" action_chunk_len = {self.policy_config.action_chunk_length}\n" - f" pov_cam_name_sim = {self.policy_config.pov_cam_name_sim}\n" - f" policy_device = {self.policy_config.policy_device}\n" - ) - - self.device = config.policy_device - self.task_mode = TaskMode(self.policy_config.task_mode_name) - - # Joint configurations - ( - self.policy_joints_config, - self.robot_action_joints_config, - self.robot_state_joints_config, - ) = load_gr00t_joint_configs(self.policy_config) - - # Modality config - self.modality_configs = load_gr00t_modality_config_from_file( - self.policy_config.modality_config_path, - self.policy_config.embodiment_tag, - ) - - # Action dimensions - self.action_dim = compute_action_dim(self.task_mode, self.robot_action_joints_config) - self.action_chunk_length = self.policy_config.action_chunk_length - self.action_horizon = self.policy_config.action_horizon - - # Underlying GR00T policy - self.policy: Gr00tPolicy = load_gr00t_policy_from_config(self.policy_config) - print("[Gr00tRemoteServerSidePolicy] Gr00tPolicy loaded successfully") - - # Required observation keys for protocol (one key per camera) - self.camera_names: list[str] = self.policy_config.pov_cam_name_sim - self.required_observation_keys: list[str] = [f"camera_obs.{cam}" for cam in self.camera_names] + [ - "policy.robot_joint_pos" - ] - - # Task description will be set via set_task_description RPC - self._task_description: str | None = None - - # ---------------------- CLI helpers (server-side) ------------------- - - @staticmethod - def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: - """Add server-side GR00T remote policy arguments.""" - group = parser.add_argument_group( - "Gr00t Remote Server Policy", - "Arguments for GR00T-based server-side remote policy.", - ) - group.add_argument( - "--policy_config_yaml_path", - type=str, - required=True, - help="Path to the GR00T closedloop policy config YAML file.", - ) - group.add_argument( - "--policy_device", - type=str, - default="cuda", - help="Device to use for server-side GR00T inference (default: cuda).", - ) - return parser - - @staticmethod - def from_args(args: argparse.Namespace) -> Gr00tRemoteServerSidePolicy: - """Create a Gr00tRemoteServerSidePolicy from CLI arguments.""" - config = Gr00tRemotePolicyArgs.from_cli_args(args) - return Gr00tRemoteServerSidePolicy(config) - - # ------------ protocol ------------ - - def _build_protocol(self) -> ChunkingActionProtocol: - proto = ChunkingActionProtocol( - action_dim=self.action_dim, - observation_keys=self.required_observation_keys, - action_chunk_length=self.action_chunk_length, - action_horizon=self.action_horizon, - ) - print(f"[Gr00tRemoteServerSidePolicy] protocol mode = {proto.mode.value}") - return proto - - # ------------------------------------------------------------------ # - # Helper methods - # ------------------------------------------------------------------ # - - def _build_policy_observations( - self, - observation: dict[str, Any], - camera_names: list[str], - ) -> dict[str, Any]: - """Convert packed numpy observation into numpy GR00T policy inputs. - - Uses ``extract_obs_numpy_from_packed`` as the single explicit - data-extraction boundary for the remote pipeline, then delegates - to the shared core preprocessing. - """ - assert self._task_description is not None, "Task description is not set" - - rgb_list_np, joint_pos_sim_np = extract_obs_numpy_from_packed( - observation, camera_names, self.unpack_observation - ) - - return build_gr00t_policy_observations( - rgb_list_np=rgb_list_np, - joint_pos_sim_np=joint_pos_sim_np, - task_description=self._task_description, - policy_config=self.policy_config, - robot_state_joints_config=self.robot_state_joints_config, - policy_joints_config=self.policy_joints_config, - modality_configs=self.modality_configs, - ) - - # ------------------------------------------------------------------ # - # ServerSidePolicy interface - # ------------------------------------------------------------------ # - - def set_task_description(self, task_description: str | None) -> dict[str, Any]: - if task_description is None: - task_description = self.policy_config.language_instruction - if not task_description: - raise ValueError( - "No language instruction provided. Set 'language_instruction' in the job config, " - "pass --language_instruction on the CLI, or define 'task_description' on the task class." - ) - self._task_description = task_description - return {"status": "ok"} - - def get_action( - self, observation: dict[str, Any], options: dict[str, Any] | None = None - ) -> tuple[dict[str, Any], dict[str, Any]]: - # 1) Shared numpy-based preprocessing - policy_observations = self._build_policy_observations(observation, self.camera_names) - - # 2) GR00T forward pass - robot_action_policy, _ = self.policy.get_action(policy_observations) - - # 3) Postprocessing (shared with closedloop) - action_tensor = build_gr00t_action_tensor( - robot_action_policy=robot_action_policy, - task_mode=self.task_mode, - policy_joints_config=self.policy_joints_config, - robot_action_joints_config=self.robot_action_joints_config, - device=self.device, - embodiment_tag=self.policy_config.embodiment_tag, - ) - - assert action_tensor.shape[1] >= self.action_chunk_length - - action_chunk = to_numpy(action_tensor) - # NOTE(huikang, 2026-02-06): Currently, it seems that the output action length is action_horizon, - # but the action chunk post-process actually handles a length of action_chunk_length. - # It looks like we can transmit a tensor of length action_chunk_length. At the moment, action_chunk_length and action_horizon are the same. - action: dict[str, Any] = {"action": action_chunk} - info: dict[str, Any] = {} - return action, info - - def reset(self, env_ids: list[int] | None = None, reset_options: dict[str, Any] | None = None) -> dict[str, Any]: - # placeholder for future reset options from GR00T repo - self.policy.reset() - return {"status": "reset_success"} diff --git a/isaaclab_arena_gr00t/utils/groot_path.py b/isaaclab_arena_gr00t/utils/groot_path.py index 8c797e5fa..6893b692e 100644 --- a/isaaclab_arena_gr00t/utils/groot_path.py +++ b/isaaclab_arena_gr00t/utils/groot_path.py @@ -5,23 +5,40 @@ import os import sys +from pathlib import Path -# TODO(xinjie.yao, 2026.03.31): Remove it after policy sever-client is implemented properly in v0.3. -def ensure_groot_deps_in_path(reexec_argv: list[str] | None = None) -> None: - """Prepend ``GROOT_DEPS_DIR`` to ``PYTHONPATH`` and re-exec the process so - GR00T dependencies are importable before Isaac Sim's bundled packages. +def _find_groot_submodule() -> str | None: + """Locate the Isaac-GR00T submodule relative to the repo root.""" + # Walk up from this file to find the repo root (where submodules/ lives) + current = Path(__file__).resolve() + for parent in current.parents: + candidate = parent / "submodules" / "Isaac-GR00T" + if candidate.is_dir(): + return str(candidate) + return None - The function is guarded by the ``_GROOT_PYTHONPATH_APPLIED`` env-var so it - only re-execs once. If ``GROOT_DEPS_DIR`` is not set the call is a no-op. - Args: - reexec_argv: The argv list to pass to ``os.execv`` after the Python - interpreter. Defaults to ``sys.argv`` (i.e. re-run the current - script with the same arguments). Pass - ``["-m", "pytest"] + sys.argv[1:]`` when bootstrapping from a - pytest conftest so the test runner is invoked correctly. +def ensure_groot_in_path() -> None: + """Ensure the Isaac-GR00T submodule is importable. + + Adds the submodule to sys.path if ``gr00t`` is not already importable. + This allows the lightweight client imports (PolicyClient, MsgSerializer) + without requiring a full ``pip install`` of the GR00T package. + + Also prepends ``GROOT_DEPS_DIR`` to ``PYTHONPATH`` and re-execs the + process if set, so GR00T's pip dependencies are importable before + Isaac Sim's bundled packages. """ + # 1. Add submodule to sys.path if gr00t is not already importable + try: + import gr00t # noqa: F401 + except ModuleNotFoundError: + submodule_path = _find_groot_submodule() + if submodule_path and submodule_path not in sys.path: + sys.path.insert(0, submodule_path) + + # 2. Handle GROOT_DEPS_DIR re-exec (for heavy deps like transformers) deps_dir = os.environ.get("GROOT_DEPS_DIR") if not deps_dir or os.environ.get("_GROOT_PYTHONPATH_APPLIED") == "1": return @@ -29,6 +46,8 @@ def ensure_groot_deps_in_path(reexec_argv: list[str] | None = None) -> None: os.environ["PYTHONPATH"] = deps_dir + os.pathsep + os.environ.get("PYTHONPATH", "") os.environ["_GROOT_PYTHONPATH_APPLIED"] = "1" - if reexec_argv is None: - reexec_argv = sys.argv - os.execv(sys.executable, [sys.executable] + reexec_argv) + os.execv(sys.executable, [sys.executable] + sys.argv) + + +# Keep old name as alias for backward compatibility +ensure_groot_deps_in_path = ensure_groot_in_path diff --git a/isaaclab_arena_gr00t/utils/io_utils.py b/isaaclab_arena_gr00t/utils/io_utils.py index 15f25a6e5..3f3d87d97 100644 --- a/isaaclab_arena_gr00t/utils/io_utils.py +++ b/isaaclab_arena_gr00t/utils/io_utils.py @@ -204,13 +204,22 @@ def load_gr00t_modality_config_from_file(modality_config_path: str | Path, embod Returns: modality_configs: Modality configurations """ + import importlib + import sys + from gr00t.configs.data.embodiment_configs import MODALITY_CONFIGS from gr00t.data.embodiment_tags import EmbodimentTag - from gr00t.experiment.launch_finetune import load_modality_config if modality_config_path: - # Import module for side-effect registration - load_modality_config(modality_config_path) + # Import the modality config module for side-effect registration. + # Inlined from gr00t.experiment.launch_finetune.load_modality_config() + # to avoid pulling in the full training stack (tyro, FinetuneConfig, etc.). + path = Path(modality_config_path) + if path.exists() and path.suffix == ".py": + sys.path.append(str(path.parent)) + importlib.import_module(path.stem) + else: + raise FileNotFoundError(f"Modality config path does not exist: {modality_config_path}") # Get the embodiment tag from policy config and convert to EmbodimentTag enum # Handle case-insensitive lookup (e.g., "NEW_EMBODIMENT" or "new_embodiment" both work)