Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions dimos/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class Agent(Module[AgentConfig]):
agent: Out[BaseMessage]
human_input: In[str]
agent_idle: Out[bool]
tool_streams: In[dict[str, Any]]

_lock: RLock
_state_graph: CompiledStateGraph[Any, Any, Any, Any] | None
Expand Down Expand Up @@ -78,6 +79,16 @@ def _on_human_input(string: str) -> None:

self._disposables.add(Disposable(self.human_input.subscribe(_on_human_input)))

def _on_tool_stream(msg: dict[str, Any]) -> None:
if msg.get("type") == "update":
tool_name = msg.get("tool_name", "unknown")
text = msg.get("text", "")
self._message_queue.put(
HumanMessage(content=f"[Tool stream update from '{tool_name}']: {text}")
)

self._disposables.add(Disposable(self.tool_streams.subscribe(_on_tool_stream)))

@rpc
def stop(self) -> None:
self._stop_event.set()
Expand Down
41 changes: 41 additions & 0 deletions dimos/agents/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from queue import Empty, Queue
from threading import Event, RLock, Thread
import time
Expand Down Expand Up @@ -59,6 +60,7 @@ class McpClient(Module[McpClientConfig]):
_stop_event: Event
_http_client: httpx.Client
_seq_ids: SequentialIds
_sse_thread: Thread | None

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
Expand All @@ -74,6 +76,7 @@ def __init__(self, **kwargs: Any) -> None:
self._stop_event = Event()
self._http_client = httpx.Client(timeout=120.0)
self._seq_ids = SequentialIds()
self._sse_thread = None

def __reduce__(self) -> Any:
return (self.__class__, (), {})
Expand Down Expand Up @@ -184,11 +187,15 @@ def on_system_modules(self, _modules: list[RPCClient]) -> None:
)
self._thread.start()

self._start_sse_listener()

@rpc
def stop(self) -> None:
    """Shut the client down: stop worker threads and release the HTTP client.

    Signals the stop event, waits briefly for the message-processing
    thread and the optional SSE listener thread, closes the shared
    httpx client, then delegates to the base Module stop.
    """
    self._stop_event.set()
    if self._thread.is_alive():
        # Bounded join so shutdown cannot hang on a stuck worker thread.
        self._thread.join(timeout=2.0)
    if self._sse_thread is not None and self._sse_thread.is_alive():
        # The SSE thread only exists after _start_sse_listener() has run.
        self._sse_thread.join(timeout=2.0)
    self._http_client.close()
    super().stop()

Expand Down Expand Up @@ -226,6 +233,40 @@ def _process_message(
if self._message_queue.empty():
self.agent_idle.publish(True)

def _start_sse_listener(self) -> None:
    """Connect to the MCP server SSE endpoint to receive tool stream updates."""
    # Daemon thread: it must not keep the process alive on exit; stop()
    # additionally joins it with a timeout for an orderly shutdown.
    self._sse_thread = Thread(target=self._sse_loop, name="McpClient-SSE", daemon=True)
    self._sse_thread.start()

def _sse_loop(self) -> None:
    """Listen for Server-Sent Events from the MCP server until stopped.

    Connects to ``<server>/mcp/streams``, parses each ``data:`` line as
    JSON, and enqueues tool "update" messages as HumanMessages for the
    agent. On any connection/stream failure it backs off one second and
    reconnects, so transient server restarts are tolerated.
    """
    # Derive the SSE URL from the configured MCP endpoint by stripping a
    # trailing "/mcp" suffix — assumes mcp_server_url ends in "/mcp";
    # TODO confirm against McpClientConfig.
    base_url = self.config.mcp_server_url.rsplit("/mcp", 1)[0]
    sse_url = f"{base_url}/mcp/streams"

    while not self._stop_event.is_set():
        try:
            # timeout=None: the SSE stream is long-lived by design.
            with httpx.Client(timeout=None) as client:
                with client.stream("GET", sse_url) as response:
                    for line in response.iter_lines():
                        if self._stop_event.is_set():
                            return
                        # SSE frames of interest start with "data: ".
                        if not line.startswith("data: "):
                            continue
                        try:
                            # len("data: ") == 6 — parse the payload.
                            data = json.loads(line[6:])
                        except json.JSONDecodeError:
                            continue
                        # Only "update" frames are forwarded; "close"
                        # frames (if any) are intentionally ignored here.
                        if data.get("type") == "update":
                            tool_name = data.get("tool_name", "unknown")
                            text = data.get("text", "")
                            self._message_queue.put(
                                HumanMessage(
                                    content=(f"[Tool stream update from '{tool_name}']: {text}")
                                )
                            )
        except Exception:
            # Best-effort reconnect loop: errors are deliberately
            # swallowed; sleep only when not shutting down.
            if not self._stop_event.is_set():
                time.sleep(1.0)


def _append_image_to_history(
mcp_client: McpClient, func_name: str, uuid_: str, result: Any
Expand Down
49 changes: 47 additions & 2 deletions dimos/agents/mcp/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from reactivex.disposable import Disposable
from starlette.requests import Request
from starlette.responses import Response
from starlette.responses import Response, StreamingResponse
import uvicorn

from dimos.agents.annotation import skill
from dimos.core.core import rpc
from dimos.core.module import Module
from dimos.core.rpc_client import RpcCall, RPCClient
from dimos.core.stream import In
from dimos.utils.logging_config import setup_logger

if TYPE_CHECKING:
Expand All @@ -43,11 +45,12 @@
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["POST"],
allow_methods=["POST", "GET"],
allow_headers=["*"],
)
app.state.skills = []
app.state.rpc_calls = {}
app.state.sse_queues = [] # list[asyncio.Queue] for SSE clients


def _jsonrpc_result(req_id: Any, result: Any) -> dict[str, Any]:
Expand Down Expand Up @@ -167,7 +170,39 @@ async def mcp_endpoint(request: Request) -> Response:
return JSONResponse(result)


@app.get("/mcp/streams")
async def streams_sse_endpoint() -> StreamingResponse:
    """Server-Sent Events endpoint for tool stream updates.

    Clients subscribe here to receive real-time updates from long-running
    skills that use ``ToolStream``.

    Each connected client gets its own asyncio.Queue registered in
    ``app.state.sse_queues``; a publisher elsewhere fans messages out to
    every registered queue.

    NOTE(review): ``app.state.sse_queues`` is a plain list mutated here
    (append/remove in the event loop) and iterated from another thread by
    the publisher — unsynchronized concurrent access; consider a lock.
    """
    queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
    app.state.sse_queues.append(queue)

    async def event_generator():  # type: ignore[no-untyped-def]
        try:
            while True:
                data = await queue.get()
                # SSE wire format: "data: <json>" followed by a blank line.
                yield f"data: {json.dumps(data)}\n\n"
        except asyncio.CancelledError:
            # Client disconnected; fall through to cleanup.
            pass
        finally:
            try:
                app.state.sse_queues.remove(queue)
            except ValueError:
                # Queue already removed — nothing to do.
                pass

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )


class McpServer(Module):
tool_streams: In[dict[str, Any]]

_uvicorn_server: uvicorn.Server | None = None
_serve_future: concurrent.futures.Future[None] | None = None

Expand All @@ -176,6 +211,16 @@ def start(self) -> None:
super().start()
self._start_server()

loop = self._loop

def _on_tool_stream_message(msg: dict[str, Any]) -> None:
if loop is None:
return
for queue in app.state.sse_queues:
asyncio.run_coroutine_threadsafe(queue.put(msg), loop)
Comment on lines +216 to +220
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Race condition on sse_queues list

app.state.sse_queues is a plain list that is mutated from multiple concurrent contexts without any synchronisation:

  • Async context — streams_sse_endpoint appends a queue, and the event_generator finaliser removes it.
  • Sync thread context — _on_tool_stream_message (called from the RxPY subscriber thread) iterates over the list with for queue in app.state.sse_queues.

Python's GIL protects individual atomic operations, but iterating over a list while another coroutine or thread appends/removes items from it can still raise RuntimeError: list changed size during iteration or silently skip/double-process entries.

Consider protecting mutations and the iteration with a threading.Lock, or replacing the list with a set guarded by a lock, to make concurrent access safe.


self._disposables.add(Disposable(self.tool_streams.subscribe(_on_tool_stream_message)))

@rpc
def stop(self) -> None:
if self._uvicorn_server:
Expand Down
84 changes: 84 additions & 0 deletions dimos/agents/mcp/tool_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2026 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any
import uuid

from dimos.core.transport import pLCMTransport
from dimos.utils.logging_config import setup_logger

logger = setup_logger()

_TOOL_STREAM_TOPIC = "/tool_streams"


class ToolStream:
    """A streaming channel for sending updates from a running skill to the agent.

    Each `ToolStream` publishes messages on a shared LCM topic. The agent
    (or the MCP server SSE endpoint) subscribes once and receives all updates.
    """

    def __init__(self, tool_name: str) -> None:
        """Create a stream for *tool_name*; call :meth:`start` before sending."""
        self.tool_name = tool_name
        # Unique id so consumers can correlate messages from concurrent streams.
        self.id = str(uuid.uuid4())
        self._closed = False
        self._transport: pLCMTransport[dict[str, Any]] | None = None

    def start(self) -> None:
        """Open the underlying LCM transport on the shared tool-stream topic."""
        self._transport = pLCMTransport(_TOOL_STREAM_TOPIC)
        self._transport.start()

    def send(self, message: str) -> None:
        """Broadcast an ``update`` message; logs and no-ops if unusable."""
        if self._closed:
            logger.error("Attempted to send on closed ToolStream", stream_id=self.id)
            return
        if self._transport is None:
            logger.error("ToolStream transport not initialized", stream_id=self.id)
            return
        self._transport.broadcast(
            None,
            {
                "stream_id": self.id,
                "tool_name": self.tool_name,
                "type": "update",
                "text": message,
            },
        )

    def stop(self) -> None:
        """Close the stream, broadcasting a ``close`` message when possible.

        Idempotent, and safe to call even if :meth:`start` was never
        called (or failed): with no transport there is nothing to
        broadcast or tear down, so nothing raises.
        """
        if self._closed:
            return
        self._closed = True
        if self._transport is None:
            # stop() without a successful start(): nothing to tear down.
            return
        try:
            # broadcast(out, msg): the first positional is the (unused)
            # Out channel, the second is the payload — both are required.
            self._transport.broadcast(
                None,
                {
                    "stream_id": self.id,
                    "tool_name": self.tool_name,
                    "type": "close",
                },
            )
        finally:
            # Always release the transport, even if the close broadcast fails.
            self._transport.stop()
            self._transport = None

    @property
    def is_closed(self) -> bool:
        """True once :meth:`stop` has been called."""
        return self._closed


__all__ = ["_TOOL_STREAM_TOPIC", "ToolStream"]
14 changes: 9 additions & 5 deletions dimos/agents/skills/person_follow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
import time
from typing import Any

from langchain_core.messages import HumanMessage
import numpy as np
from reactivex.disposable import Disposable

from dimos.agents.agent import AgentSpec
from dimos.agents.annotation import skill
from dimos.agents.mcp.tool_stream import ToolStream
from dimos.core.core import rpc
from dimos.core.module import Module, ModuleConfig
from dimos.core.stream import In, Out
Expand Down Expand Up @@ -62,7 +61,6 @@ class PersonFollowSkillContainer(Module[Config]):
global_map: In[PointCloud2]
cmd_vel: Out[Twist]

_agent_spec: AgentSpec
_frequency: float = 20.0 # Hz - control loop frequency
_max_lost_frames: int = 15 # number of frames to wait before declaring person lost

Expand All @@ -75,6 +73,7 @@ def __init__(self, **kwargs: Any) -> None:
self._thread: Thread | None = None
self._should_stop: Event = Event()
self._lock = RLock()
self._tool_stream: ToolStream | None = None

# Use MuJoCo camera intrinsics in simulation mode
camera_info = self.config.camera_info
Expand Down Expand Up @@ -196,12 +195,14 @@ def _follow_person(self, query: str, initial_bbox: BBox) -> str:

logger.info(f"EdgeTAM initialized with {len(initial_detections)} detections")

self._tool_stream = ToolStream("follow_person")
self._tool_stream.start()
self._thread = Thread(target=self._follow_loop, args=(tracker, query), daemon=True)
self._thread.start()

return (
"Found the person. Starting to follow. You can stop following by calling "
"the 'stop_following' tool."
"the 'stop_following' tool. You will receive streaming updates."
)

def _follow_loop(self, tracker: "EdgeTAMProcessor", query: str) -> None:
Expand Down Expand Up @@ -263,7 +264,10 @@ def _stop_following(self) -> None:
def _send_stop_reason(self, query: str, reason: str) -> None:
    """Halt motion and report why the follow of *query* ended.

    Publishes a zero velocity command, forwards a human-readable stop
    message over the active tool stream (if any), and tears the stream
    down so a later follow starts fresh.
    """
    # Always command a full stop first, even when no stream is active.
    self.cmd_vel.publish(Twist.zero())
    message = f"Person follow stopped for '{query}'. Reason: {reason}."
    stream = self._tool_stream
    if stream is not None:
        stream.send(message)
        stream.stop()
        self._tool_stream = None
    logger.info("Person follow stopped", query=query, reason=reason)


Expand Down
Loading