Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 76 additions & 1 deletion client/joinly_client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,81 @@ async def client_streams(self) -> Never: # type: ignore[override]
raise RuntimeError


def sanitize_tool_schema(schema: dict[str, Any]) -> dict[str, Any]: # noqa: C901
"""Sanitize a tool schema.

This function removes unsupported JSON schema features and ensures the schema
is compatible with OpenAI's requirements.

Args:
schema (dict[str, Any]): The original JSON schema.

Returns:
dict[str, Any]: The sanitized JSON schema.
"""
unsupported = {
"allOf",
"anyOf",
"oneOf",
"not",
"if",
"then",
"else",
"$schema",
"$id",
"$ref",
"definitions",
"$defs",
"patternProperties",
}

def default_object() -> dict[str, Any]:
return {"type": "object", "properties": {}, "additionalProperties": True}

def choose_type(t: Any) -> str: # noqa: ANN401
if isinstance(t, list):
return t[0] if t and isinstance(t[0], str) else "object"
return t if isinstance(t, str) else "object"

def walk(node: Any) -> dict[str, Any]: # noqa: ANN401
if not isinstance(node, dict):
return default_object()
out = {k: v for k, v in node.items() if k not in unsupported}
t = choose_type(out.get("type", "object"))

if t == "object":
props = out.get("properties")
props = props if isinstance(props, dict) else {}
out["properties"] = {k: walk(v) for k, v in props.items()}
ap = out.get("additionalProperties", True)
out["additionalProperties"] = ap if isinstance(ap, bool) else True
req = out.get("required")
if isinstance(req, list):
req = [k for k in req if isinstance(k, str) and k in out["properties"]]
if req:
out["required"] = req
else:
out.pop("required", None)
out["type"] = "object"
return out

if t == "array":
items = out.get("items")
if isinstance(items, list):
out["items"] = walk(items[0]) if items else default_object()
elif isinstance(items, dict):
out["items"] = walk(items)
else:
out["items"] = default_object()
out["type"] = "array"
return out

out["type"] = t
return out

return walk(schema)


async def load_tools(
clients: McpClientConfig | dict[str, McpClientConfig],
) -> tuple[list[ToolDefinition], ToolExecutor]:
Expand All @@ -134,7 +209,7 @@ async def load_tools(
ToolDefinition(
name=f"{prefix}_{tool.name}" if prefix is not None else tool.name,
description=tool.description,
parameters_json_schema=tool.inputSchema,
parameters_json_schema=sanitize_tool_schema(tool.inputSchema),
)
for tool in await config.client.list_tools()
if tool.name not in config.exclude
Expand Down