Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/publish-client-python.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Publish PyPI Package
name: Publish Client Python

on:
workflow_dispatch:
Expand Down Expand Up @@ -29,8 +29,10 @@ jobs:

# Build the package
- name: Build package
working-directory: sdks/python
run: uv build

# Publish to PyPI
- name: Publish to PyPI
working-directory: sdks/python
run: uv publish
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.secret
mcp.yaml
.gabber/
.DS_Store
.DS_Store
2 changes: 1 addition & 1 deletion engine/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dependencies = [
"elevenlabs==1.12.1",
"pyvips~=2.0.2",
"av~=14.4.0",
"openai~=1.93.0",
"openai~=1.107.1",
"pillow~=11.1.0",
"msgpack~=1.1.1",
"python-dotenv~=1.1.1",
Expand Down
11 changes: 8 additions & 3 deletions engine/src/core/graph/runtime_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from core.editor import serialize
from core.node import Node
from nodes.core.media.publish import Publish
from nodes.core.tool import MCP

PING_BYTES = "ping".encode("utf-8")


class RuntimeApi:
Expand Down Expand Up @@ -80,10 +81,14 @@ def on_pad(p: pad.Pad, value: Any):
p._add_update_handler(on_pad)

def on_data(packet: rtc.DataPacket):
if not packet.topic or not packet.topic.startswith("runtime_api"):
if not packet.topic or packet.topic != "runtime_api":
return

request = RuntimeRequest.model_validate_json(packet.data)
try:
request = RuntimeRequest.model_validate_json(packet.data)
except Exception as e:
logging.error(f"Invalid runtime_api request: {e}", exc_info=e)
return
req_id = request.req_id
ack_resp = RuntimeRequestAck(req_id=req_id, type="ack")
complete_resp = RuntimeResponse(req_id=req_id, type="complete")
Expand Down
2 changes: 2 additions & 0 deletions engine/src/core/mcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
MCPServerConfig,
)
from .mcp_server_provider import MCPServerProvider
from .datachannel_transport import datachannel_host

__all__ = [
"MCPTransport",
Expand All @@ -14,4 +15,5 @@
"MCPServer",
"MCPServerProvider",
"MCPServerConfig",
"datachannel_host",
]
50 changes: 8 additions & 42 deletions engine/src/core/mcp/datachannel_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@

@asynccontextmanager
async def datachannel_host(
room: rtc.Room, participant: str
room: rtc.Room, participant: str, mcp_name: str
) -> AsyncGenerator[
tuple[
MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage],
],
None,
]:
topic = f"__mcp__:{mcp_name}"
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
write_stream: MemoryObjectSendStream[SessionMessage]
Expand All @@ -38,7 +39,7 @@ async def datachannel_host(
packet_q = asyncio.Queue[rtc.DataPacket | None]()

def on_message(packet: rtc.DataPacket):
if packet.topic != "__mcp__":
if packet.topic != topic:
return

if not packet.participant:
Expand All @@ -60,8 +61,12 @@ async def on_message_loop():
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except ValidationError as exc:
logging.error(f"DC message validation error: {exc}")
# If JSON parse or model validation fails, send the exception
await read_stream_writer.send(exc)
except Exception as exc:
logging.error(f"DC unexpected error: {exc}")
await read_stream_writer.send(exc)

async def dc_writer():
"""
Expand All @@ -75,7 +80,7 @@ async def dc_writer():
by_alias=True, mode="json", exclude_none=True
)
await room.local_participant.publish_data(
json.dumps(msg_dict), topic="__mcp__"
json.dumps(msg_dict), topic=topic
)

room.on("data_received", on_message)
Expand All @@ -91,42 +96,3 @@ async def dc_writer():
tg.cancel_scope.cancel()

room.off("data_received", on_message)


async def datachannel_client_proxy(
url: str,
token: str,
other_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
other_write_stream: MemoryObjectSendStream[SessionMessage],
) -> rtc.Room:
"""
Connects to a LiveKit room and returns the Room object.
"""
room = rtc.Room()

def on_message(packet: rtc.DataPacket):
if packet.topic != "__mcp__":
return
logger.debug(f"Received data packet: {packet.data}")
json_msg = types.JSONRPCMessage.model_validate_json(packet.data)
sm = SessionMessage(json_msg)
other_write_stream.send_nowait(sm)

async def read_loop():
async with other_read_stream:
async for session_message in other_read_stream:
if isinstance(session_message, Exception):
logger.error(f"Error in received message: {session_message}")
continue
msg_dict = session_message.message.model_dump(
by_alias=True, mode="json", exclude_none=True
)
await room.local_participant.publish_data(
json.dumps(msg_dict), topic="__mcp__"
)

room.on("data_received", on_message)
await room.connect(url, token)
await read_loop()
room.off("data_received", on_message)
return room
18 changes: 16 additions & 2 deletions engine/src/core/mcp/mcp_server_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,26 @@ class MCPTransportSSE(BaseModel):
url: str


class MCPTransportSTDIO(BaseModel):
type: Literal["stdio"] = "stdio"
command: str
args: list[str]


MCPLocalTransport = Annotated[
MCPTransportSTDIO | MCPTransportSSE, Field(discriminator="type")
]


class MCPTransportDatachannelProxy(BaseModel):
type: Literal["datachannel_proxy"] = "datachannel_proxy"
local_transport: MCPTransportSSE
local_transport: MCPLocalTransport


MCPTransport = Annotated[MCPTransportDatachannelProxy, Field(discriminator="type")]
MCPTransport = Annotated[
MCPTransportDatachannelProxy | MCPTransportSTDIO | MCPTransportSTDIO,
Field(discriminator="type"),
]


class MCPServer(BaseModel):
Expand Down
58 changes: 58 additions & 0 deletions engine/src/core/pad/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def intersect(self, other: "BasePadType") -> "BasePadType | None":
return self
return None

def to_json_schema(self) -> dict[str, Any]:
raise NotImplementedError()


class String(BasePadType):
type: Literal["string"] = "string"
Expand All @@ -43,6 +46,14 @@ def intersect(self, other: "BasePadType"):
),
)

def to_json_schema(self) -> dict[str, Any]:
schema: dict[str, Any] = {"type": "string"}
if self.max_length is not None:
schema["maxLength"] = self.max_length
if self.min_length is not None:
schema["minLength"] = self.min_length
return schema


class Enum(BasePadType):
type: Literal["enum"] = "enum"
Expand Down Expand Up @@ -75,6 +86,12 @@ def intersect(self, other: "BasePadType"):

raise ValueError("Unexpected state.")

def to_json_schema(self) -> dict[str, Any]:
schema: dict[str, Any] = {"type": "string"}
if self.options is not None:
schema["enum"] = self.options
return schema


class Secret(BasePadType):
type: Literal["secret"] = "secret"
Expand All @@ -96,6 +113,9 @@ def intersect(self, other: "BasePadType"):
options=intersected_options,
)

def to_json_schema(self) -> dict[str, Any]:
raise NotImplementedError()


class Integer(BasePadType):
type: Literal["integer"] = "integer"
Expand All @@ -119,6 +139,14 @@ def intersect(self, other: "BasePadType"):
),
)

def to_json_schema(self) -> dict[str, Any]:
schema: dict[str, Any] = {"type": "integer"}
if self.maximum is not None:
schema["maximum"] = self.maximum
if self.minimum is not None:
schema["minimum"] = self.minimum
return schema


class Float(BasePadType):
type: Literal["float"] = "float"
Expand All @@ -142,6 +170,14 @@ def intersect(self, other: "BasePadType"):
),
)

def to_json_schema(self) -> dict[str, Any]:
schema: dict[str, Any] = {"type": "number"}
if self.maximum is not None:
schema["maximum"] = self.maximum
if self.minimum is not None:
schema["minimum"] = self.minimum
return schema


class BoundingBox(BasePadType):
type: Literal["bounding_box"] = "bounding_box"
Expand All @@ -154,6 +190,9 @@ class Point(BasePadType):
class Boolean(BasePadType):
type: Literal["boolean"] = "boolean"

def to_json_schema(self) -> dict[str, Any]:
return {"type": "boolean"}


class Audio(BasePadType):
type: Literal["audio"] = "audio"
Expand Down Expand Up @@ -211,6 +250,17 @@ def intersect(self, other: "BasePadType"):
),
)

def to_json_schema(self) -> dict[str, Any]:
schema: dict[str, Any] = {"type": "array"}
if self.max_length is not None:
schema["maxItems"] = self.max_length
if (
self.item_type_constraints is not None
and len(self.item_type_constraints) == 1
):
schema["items"] = self.item_type_constraints[0].to_json_schema()
return schema


class Schema(BasePadType):
type: Literal["schema"] = "schema"
Expand Down Expand Up @@ -243,6 +293,14 @@ def intersect(self, other: "BasePadType"):

return Object(object_schema=new_schema) if new_schema else None

def to_json_schema(self) -> dict[str, Any]:
if self.object_schema is not None:
return {
"type": "object",
"properties": self.object_schema,
}
return {"type": "object"}


class NodeReference(BasePadType):
type: Literal["node_reference"] = "node_reference"
Expand Down
Loading