|
| 1 | +"""Toolkit: wrap existing tools with graph-based filtering. |
| 2 | +
|
| 3 | +Provides :func:`filter_tools` for one-shot filtering and |
| 4 | +:class:`GraphToolkit` for reusable tool management with retrieval. |
| 5 | +
|
| 6 | +Accepts any tool format: |
| 7 | +- LangChain ``BaseTool`` (``@tool``, ``StructuredTool``, etc.) |
| 8 | +- OpenAI function dict (``{"type": "function", "function": {"name": ...}}``) |
| 9 | +- Anthropic tool dict (``{"name": ..., "input_schema": ...}``) |
| 10 | +- MCP tool dict (``{"name": ..., "inputSchema": ...}``) |
| 11 | +- Python callable with type hints |
| 12 | +
|
| 13 | +Usage:: |
| 14 | +
|
| 15 | + from graph_tool_call.langchain import filter_tools, GraphToolkit |
| 16 | +
|
| 17 | + # One-shot: filter tools by query |
| 18 | + filtered = filter_tools(all_tools, "cancel order", top_k=5) |
| 19 | +
|
| 20 | + # Reusable: wrap once, filter many times |
| 21 | + toolkit = GraphToolkit(tools=all_tools, top_k=5) |
| 22 | + filtered = toolkit.get_tools("cancel order") |
| 23 | +""" |
| 24 | + |
| 25 | +from __future__ import annotations |
| 26 | + |
| 27 | +import logging |
| 28 | +from typing import Any |
| 29 | + |
| 30 | +logger = logging.getLogger("graph-tool-call.langchain") |
| 31 | + |
| 32 | + |
| 33 | +def _extract_name(tool: Any) -> str: |
| 34 | + """Extract tool name from any supported format.""" |
| 35 | + # Object with .name attribute (LangChain BaseTool, ToolSchema, etc.) |
| 36 | + if hasattr(tool, "name"): |
| 37 | + return tool.name |
| 38 | + |
| 39 | + # Dict formats |
| 40 | + if isinstance(tool, dict): |
| 41 | + # OpenAI: {"type": "function", "function": {"name": ...}} |
| 42 | + if "function" in tool: |
| 43 | + return tool["function"].get("name", "") |
| 44 | + # MCP / Anthropic: {"name": ...} |
| 45 | + if "name" in tool: |
| 46 | + return tool["name"] |
| 47 | + |
| 48 | + # Callable (Python function) |
| 49 | + if callable(tool): |
| 50 | + return getattr(tool, "__name__", "") |
| 51 | + |
| 52 | + return "" |
| 53 | + |
| 54 | + |
def _ingest_tools(graph: Any, tools: list[Any]) -> None:
    """Register every entry of *tools* on *graph*, choosing the path per entry.

    Entries that are callable, expose no ``name`` attribute and are not dicts
    are treated as plain Python functions: they are collected and handed to
    ``graph.ingest_functions`` in a single call after the loop. Every other
    entry is converted with ``parse_tool`` and added through
    ``graph.add_tool`` immediately, in input order.
    """
    from graph_tool_call.core.tool import parse_tool

    plain_functions = []
    for candidate in tools:
        is_plain_function = (
            callable(candidate)
            and not hasattr(candidate, "name")
            and not isinstance(candidate, dict)
        )
        if not is_plain_function:
            graph.add_tool(parse_tool(candidate))
            continue
        plain_functions.append(candidate)

    # Only touch ingest_functions when there is something to ingest.
    if plain_functions:
        graph.ingest_functions(plain_functions)
| 68 | + |
| 69 | + |
def filter_tools(
    tools: list[Any],
    query: str,
    *,
    top_k: int = 5,
    graph: Any | None = None,
) -> list[Any]:
    """Filter tools by relevance to *query*.

    Parameters
    ----------
    tools:
        List of tools in any format — LangChain ``BaseTool``, OpenAI function
        dicts, MCP tool dicts, Anthropic tool dicts, or Python callables.
    query:
        Natural-language query to match tools against.
    top_k:
        Maximum number of tools to return (default: 5).
    graph:
        Optional pre-built ``ToolGraph``. If *None*, a temporary graph is
        built from *tools* on the fly. A supplied graph is updated in place:
        if it shares no names with *tools*, all of *tools* are ingested;
        if it already holds some of them, only the named tools it is missing
        are ingested.

    Returns
    -------
    list
        Subset of *tools* ranked by relevance. Original tool objects are
        preserved (not copies), so they remain callable by the agent. When
        retrieval yields no tool from *tools*, a shallow copy of the full
        list is returned instead of an empty one.
    """
    if graph is None:
        # Imported lazily, and only when a graph actually has to be built.
        from graph_tool_call import ToolGraph

        graph = ToolGraph()

    # Index by name for fast lookup; tools without a name cannot be mapped
    # back from retrieval results and are therefore left out of the index.
    tool_map: dict[str, Any] = {}
    for t in tools:
        name = _extract_name(t)
        if name:
            tool_map[name] = t

    existing = set(graph.tools.keys())
    if existing.isdisjoint(tool_map):
        # Nothing from this call is in the graph yet: ingest everything.
        _ingest_tools(graph, tools)
    else:
        # Partial overlap: previously the whole ingest was skipped, leaving
        # the missing tools unretrievable. Add just the ones not yet present.
        missing = [t for name, t in tool_map.items() if name not in existing]
        if missing:
            _ingest_tools(graph, missing)

    results = graph.retrieve(query, top_k=top_k)
    # Keep retrieval order; drop hits that do not belong to the caller's tools.
    filtered = [tool_map[r.name] for r in results if r.name in tool_map]

    if filtered:
        logger.debug(
            "Filtered %d → %d tools for query: %s",
            len(tools),
            len(filtered),
            query[:50],
        )
        return filtered

    logger.debug("Retrieval returned no matches, returning all %d tools", len(tools))
    return list(tools)
| 131 | + |
| 132 | + |
class GraphToolkit:
    """Wraps a list of tools with graph-based retrieval.

    Build once from existing tools, then call :meth:`get_tools` per query.

    Parameters
    ----------
    tools:
        List of tools in any format — LangChain ``BaseTool``, OpenAI function
        dicts, MCP tool dicts, Anthropic tool dicts, or Python callables.
        Tools for which no name can be extracted are not registered and are
        never returned by :meth:`get_tools`.
    top_k:
        Default number of tools to return per query.
    graph:
        Optional pre-built ``ToolGraph``. If *None*, one is built from *tools*.
        A supplied graph is updated in place: if it shares no names with
        *tools*, all of *tools* are ingested; if it already holds some of
        them, only the named tools it is missing are ingested.
    """

    def __init__(
        self,
        tools: list[Any],
        *,
        top_k: int = 5,
        graph: Any | None = None,
    ) -> None:
        # Name -> original tool object; later duplicates of a name win.
        self._tools: dict[str, Any] = {}
        for t in tools:
            name = _extract_name(t)
            if name:
                self._tools[name] = t

        self._top_k = top_k

        if graph is None:
            # Imported lazily, and only when a graph actually has to be built.
            from graph_tool_call import ToolGraph

            graph = ToolGraph()
        self._graph = graph  # a ``ToolGraph`` (or compatible object)

        existing = set(self._graph.tools.keys())
        if existing.isdisjoint(self._tools):
            # Nothing from this toolkit is in the graph yet: ingest everything.
            _ingest_tools(self._graph, tools)
        else:
            # Partial overlap: previously the whole ingest was skipped, leaving
            # the missing tools unretrievable. Add just the ones not yet present.
            missing = [t for name, t in self._tools.items() if name not in existing]
            if missing:
                _ingest_tools(self._graph, missing)

    @property
    def graph(self) -> Any:
        """Underlying ``ToolGraph`` instance."""
        return self._graph

    @property
    def all_tools(self) -> list[Any]:
        """All registered (named) tools, as a new list."""
        return list(self._tools.values())

    def get_tools(self, query: str, *, top_k: int | None = None) -> list[Any]:
        """Return tools relevant to *query*.

        Parameters
        ----------
        query:
            Natural-language query.
        top_k:
            Override the default top_k for this call.

        Returns
        -------
        list
            Filtered tools, ordered by relevance. Original objects preserved.
            If retrieval yields no tool registered in this toolkit, all
            registered tools are returned instead of an empty list.
        """
        # ``is not None`` so that an explicit 0 is honoured rather than
        # replaced by the default.
        k = top_k if top_k is not None else self._top_k
        results = self._graph.retrieve(query, top_k=k)

        # Keep retrieval order; drop hits that are not part of this toolkit.
        filtered = [self._tools[r.name] for r in results if r.name in self._tools]
        return filtered if filtered else self.all_tools
0 commit comments