From 82ec7d0127a82e354f371fc9627c9828ff6f8292 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 10:37:02 -0600 Subject: [PATCH 1/2] Add callback adapter core for MCP tool generation --- dash/_callback.py | 12 +- dash/mcp/primitives/tools/callback_adapter.py | 461 ++++++++++++++++++ .../tools/callback_adapter_collection.py | 154 ++++++ dash/mcp/primitives/tools/callback_utils.py | 36 ++ .../primitives/tools/descriptions/__init__.py | 7 + .../tools/input_schemas/__init__.py | 5 + .../tools/output_schemas/__init__.py | 5 + dash/mcp/types/__init__.py | 26 + dash/mcp/types/callback_types.py | 33 ++ dash/mcp/types/component_types.py | 20 + dash/mcp/types/exceptions.py | 30 ++ dash/mcp/types/typing_utils.py | 28 ++ requirements/install.txt | 1 + tests/unit/mcp/conftest.py | 6 + tests/unit/mcp/tools/test_callback_adapter.py | 227 +++++++++ .../tools/test_callback_adapter_collection.py | 145 ++++++ 16 files changed, 1193 insertions(+), 3 deletions(-) create mode 100644 dash/mcp/primitives/tools/callback_adapter.py create mode 100644 dash/mcp/primitives/tools/callback_adapter_collection.py create mode 100644 dash/mcp/primitives/tools/callback_utils.py create mode 100644 dash/mcp/primitives/tools/descriptions/__init__.py create mode 100644 dash/mcp/primitives/tools/input_schemas/__init__.py create mode 100644 dash/mcp/primitives/tools/output_schemas/__init__.py create mode 100644 dash/mcp/types/__init__.py create mode 100644 dash/mcp/types/callback_types.py create mode 100644 dash/mcp/types/component_types.py create mode 100644 dash/mcp/types/exceptions.py create mode 100644 dash/mcp/types/typing_utils.py create mode 100644 tests/unit/mcp/conftest.py create mode 100644 tests/unit/mcp/tools/test_callback_adapter.py create mode 100644 tests/unit/mcp/tools/test_callback_adapter_collection.py diff --git a/dash/_callback.py b/dash/_callback.py index 3785df7166..dc73dd0792 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -41,6 +41,7 @@ from . import _validate from .background_callback.managers import BaseBackgroundCallbackManager from ._callback_context import context_value +from .types import CallbackDispatchResponse from ._no_update import NoUpdate @@ -80,6 +81,7 @@ def callback( api_endpoint: Optional[str] = None, optional: Optional[bool] = False, hidden: Optional[bool] = None, + mcp_enabled: bool = True, **_kwargs, ) -> Callable[..., Any]: """ @@ -231,6 +233,7 @@ def callback( api_endpoint=api_endpoint, optional=optional, hidden=hidden, + mcp_enabled=mcp_enabled, ) @@ -278,6 +281,7 @@ def insert_callback( no_output=False, optional=False, hidden=None, + mcp_enabled=True, ): if prevent_initial_call is None: prevent_initial_call = config_prevent_initial_callbacks @@ -318,6 +322,7 @@ def insert_callback( "manager": manager, "allow_dynamic_callbacks": dynamic_creator, "no_output": no_output, + "mcp_enabled": mcp_enabled, } callback_list.append(callback_spec) @@ -523,7 +528,7 @@ def _prepare_response( output_value, output_spec, multi, - response, + response: CallbackDispatchResponse, callback_ctx, app, original_packages, @@ -652,6 +657,7 @@ def register_callback( no_output=not has_output, optional=_kwargs.get("optional", False), hidden=_kwargs.get("hidden", None), + mcp_enabled=_kwargs.get("mcp_enabled", True), ) # pylint: disable=too-many-locals @@ -686,7 +692,7 @@ def add_context(*args, **kwargs): args, kwargs, inputs_state_indices, has_output, insert_output ) - response: dict = {"multi": True} # type: ignore + response: CallbackDispatchResponse = {"multi": True} jsonResponse = None try: @@ -758,7 +764,7 @@ async def async_add_context(*args, **kwargs): args, kwargs, inputs_state_indices, has_output, insert_output ) - response = {"multi": True} + response: CallbackDispatchResponse = {"multi": True} try: if background is not None: diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py new file mode 100644 index 0000000000..0f50d15c03 --- /dev/null +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -0,0 +1,461 @@ +"""Adapter: Dash callback → MCP tool interface. + +Wraps a raw ``callback_map`` entry and exposes MCP-facing +properties (tool name, params, outputs) lazily. +""" + +from __future__ import annotations + +import inspect +import json +import typing +from functools import cached_property +from typing import Any + +from mcp.types import Tool + +from dash import get_app +from dash.layout import ( + _WILDCARD_VALUES, + find_component, + find_matching_components, + parse_wildcard_id, +) +from dash.mcp.types import is_nullable +from dash._grouping import flatten_grouping +from dash._utils import clean_property_name, split_callback_id +from dash.mcp.types import MCPInput, MCPOutput +from .callback_utils import run_callback +from .descriptions import build_tool_description +from .input_schemas import get_input_schema +from .output_schemas import get_output_schema + + +class CallbackAdapter: + """Adapts a single Dash callback_map entry to the MCP tool interface.""" + + def __init__(self, callback_output_id: str): + self._output_id = callback_output_id + + # ------------------------------------------------------------------- + # Projections + # ------------------------------------------------------------------- + + @cached_property + def as_mcp_tool(self) -> Tool: + """Stub — will be implemented in a future PR.""" + raise NotImplementedError("as_mcp_tool will be implemented in a future PR.") + + def as_callback_body(self, kwargs: dict[str, Any]) -> dict[str, Any]: + """Transforms the given kwargs to a dict suitable for calling this callback. + + Mirrors how the Dash renderer assembles the callback payload — + see ``fillVals()`` in ``dash-renderer/src/actions/callbacks.ts``. + + For pattern-matching callbacks, wildcard deps are expanded into + nested arrays with concrete component IDs. + """ + coerced = {k: _coerce_value(v) for k, v in kwargs.items()} + + raw_inputs = self._cb_info.get("inputs", []) + raw_state = self._cb_info.get("state", []) + n_deps = len(raw_inputs) + len(raw_state) + + flat_values = [None] * n_deps + for i, name in enumerate(self._param_names): + if i < n_deps and name in coerced: + flat_values[i] = coerced[name] + + inputs_with_values = [ + _expand_dep(dep, flat_values[i]) for i, dep in enumerate(raw_inputs) + ] + state_with_values = [ + _expand_dep(dep, flat_values[len(raw_inputs) + i]) + for i, dep in enumerate(raw_state) + ] + + outputs_spec = _expand_output_spec( + self._output_id, self._cb_info, inputs_with_values + ) + + # changedPropIds: only inputs with non-None values. + # This determines ctx.triggered_id in the callback. + changed = [] + for entry in inputs_with_values: + if isinstance(entry, dict) and entry.get("value") is not None: + eid = entry.get("id") + if isinstance(eid, dict): + changed.append( + f"{json.dumps(eid, sort_keys=True)}.{entry['property']}" + ) + elif isinstance(eid, str): + changed.append(f"{eid}.{entry['property']}") + + return { + "output": self._output_id, + "outputs": outputs_spec, + "inputs": inputs_with_values, + "state": state_with_values, + "changedPropIds": changed, + } + + # ------------------------------------------------------------------- + # Public identity and metadata + # ------------------------------------------------------------------- + + @cached_property + def is_valid(self) -> bool: + """Whether all input components exist in the layout.""" + all_deps = self._cb_info.get("inputs", []) + self._cb_info.get("state", []) + for dep in all_deps: + dep_id = str(dep.get("id", "")) + if dep_id.startswith("{"): + continue + if find_component(dep_id) is None: + return False + return True + + @property + def output_id(self) -> str: + return self._output_id + + @property + def tool_name(self) -> str: + return get_app().mcp_callback_map._tool_names_map[self._output_id] + + @cached_property + def prevents_initial_call(self) -> bool: + for cb in get_app()._callback_list: + if cb["output"] == self._output_id: + return cb.get("prevent_initial_call", False) + return False + + # ------------------------------------------------------------------- + # Private: computed fields for the MCP Tool + # ------------------------------------------------------------------- + + @cached_property + def _description(self) -> str: + return build_tool_description(self.outputs, self._docstring) + + @cached_property + def _input_schema(self) -> dict[str, Any]: + properties = {p["name"]: get_input_schema(p) for p in self.inputs} + required = [p["name"] for p in self.inputs if p["required"]] + + input_schema: dict[str, Any] = {"type": "object", "properties": properties} + if required: + input_schema["required"] = required + return input_schema + + @cached_property + def _output_schema(self) -> dict[str, Any]: + return get_output_schema() + + # ------------------------------------------------------------------- + # Private: callback metadata + # ------------------------------------------------------------------- + + @cached_property + def _docstring(self) -> str | None: + return getattr(self._original_func, "__doc__", None) + + @cached_property + def _initial_output(self) -> dict[str, dict[str, Any]]: + """Run this callback with initial input values. + + Returns the ``response`` portion of the dispatch result: + ``{component_id: {property: value}}``. + + Skipped for callbacks with ``prevent_initial_call=True``, + matching how the Dash renderer skips them on page load. + """ + if self.prevents_initial_call: + return {} + + callback_map = get_app().mcp_callback_map + kwargs = {} + for p in self.inputs: + upstream = callback_map.find_by_output(p["id_and_prop"]) + if upstream is self: + kwargs[p["name"]] = getattr( + find_component(p["component_id"]), p["property"], None + ) + else: + kwargs[p["name"]] = callback_map.get_initial_value(p["id_and_prop"]) + try: + result = run_callback(self, kwargs) + return result.get("response", {}) + except Exception: + return {} + + def initial_output_value(self, id_and_prop: str) -> Any: + """Return the initial value for a specific output ``"component_id.property"``.""" + component_id, prop = id_and_prop.rsplit(".", 1) + return self._initial_output.get(component_id, {}).get(prop) + + @cached_property + def outputs(self) -> list[MCPOutput]: + if self._cb_info.get("no_output"): + return [] + parsed = split_callback_id(self._output_id) + if isinstance(parsed, dict): + parsed = [parsed] + result: list[MCPOutput] = [] + for p in parsed: + comp_id = p["id"] + prop = clean_property_name(p["property"]) + id_and_prop = f"{comp_id}.{prop}" + comp = find_component(comp_id) + result.append( + { + "id_and_prop": id_and_prop, + "component_id": comp_id, + "property": prop, + "component_type": getattr(comp, "_type", None), + "initial_value": self.initial_output_value(id_and_prop), + "tool_name": self.tool_name, + } + ) + return result + + @cached_property + def inputs(self) -> list[MCPInput]: + all_deps = self._cb_info.get("inputs", []) + self._cb_info.get("state", []) + callback_map = get_app().mcp_callback_map + + result: list[MCPInput] = [] + for dep, name, annotation in zip( + all_deps, self._param_names, self._param_annotations + ): + comp_id = str(dep.get("id", "unknown")) + comp = find_component(comp_id) + prop = dep.get("property", "unknown") + id_and_prop = f"{comp_id}.{prop}" + + upstream_cb = callback_map.find_by_output(id_and_prop) + upstream_output = None + if upstream_cb is not None and upstream_cb is not self: + if not upstream_cb.prevents_initial_call: + for out in upstream_cb.outputs: + if out["id_and_prop"] == id_and_prop: + upstream_output = out + break + + initial_value = ( + upstream_output["initial_value"] + if upstream_output is not None + else getattr(comp, prop, None) + ) + + if annotation is not None: + required = not is_nullable(annotation) + else: + required = initial_value is not None + + result.append( + { + "name": name, + "id_and_prop": id_and_prop, + "component_id": comp_id, + "property": prop, + "annotation": annotation, + "component_type": getattr(comp, "_type", None), + "component": comp, + "required": required, + "initial_value": initial_value, + "upstream_output": upstream_output, + } + ) + return result + + # ------------------------------------------------------------------- + # Helpers + # ------------------------------------------------------------------- + + @cached_property + def _cb_info(self) -> dict[str, Any]: + return get_app().callback_map[self._output_id] + + @cached_property + def _original_func(self) -> Any | None: + func = self._cb_info.get("callback") + return getattr(func, "__wrapped__", func) + + @cached_property + def _func_signature(self) -> inspect.Signature | None: + if self._original_func is None: + return None + try: + return inspect.signature(self._original_func) + except (ValueError, TypeError): + return None + + @cached_property + def _dep_param_map(self) -> list[tuple[str, str]]: + """(func_param_name, mcp_param_name) per dep, in dep order. + + Single source of truth for mapping deps to param names. + All dict-vs-list branching is confined here. + """ + all_deps = self._cb_info.get("inputs", []) + self._cb_info.get("state", []) + n_deps = len(all_deps) + indices = self._cb_info.get("inputs_state_indices") + + if isinstance(indices, dict): + entries: list[tuple[int, str, str]] = [] + for func_name, idx in indices.items(): + positions = flatten_grouping(idx) + if len(positions) == 1: + entries.append((positions[0], func_name, func_name)) + else: + for pos in positions: + dep = all_deps[pos] if pos < n_deps else {} + comp_id = str(dep.get("id", "unknown")).replace("-", "_") + prop = dep.get("property", "unknown") + entries.append( + (pos, func_name, f"{func_name}_{comp_id}__{prop}") + ) + entries.sort(key=lambda e: e[0]) + result = [(f, m) for _, f, m in entries] + elif self._func_signature is not None: + names = list(self._func_signature.parameters.keys()) + result = [(n, n) for n in names] + else: + result = [] + + while len(result) < n_deps: + fallback = f"param_{len(result)}" + result.append((fallback, fallback)) + return result + + @cached_property + def _param_names(self) -> list[str]: + """MCP param name per dep, in dep order.""" + return [mcp for _, mcp in self._dep_param_map] + + @cached_property + def _param_annotations(self) -> list[Any | None]: + """One annotation per dep, in dep order.""" + if self._func_signature is None: + return [None] * len(self._dep_param_map) + try: + hints = typing.get_type_hints(self._original_func) + except Exception: + hints = getattr(self._original_func, "__annotations__", {}) + return [hints.get(func_name) for func_name, _ in self._dep_param_map] + + +def _expand_dep(dep: dict, value: Any) -> Any: + """Expand a dependency into the dispatch format. + + For regular deps, returns ``{id, property, value}``. + For ALL/ALLSMALLER: passes through the list of ``{id, property, value}`` dicts. + For MATCH: passes through the single ``{id, property, value}`` dict. + """ + pattern = parse_wildcard_id(dep.get("id", "")) + if pattern is None: + return {**dep, "value": value} + + # LLM provides browser-like format + if isinstance(value, list): + return value + if isinstance(value, dict) and "id" in value: + return value + return {**dep, "value": value} + + +def _expand_output_spec(output_id: str, cb_info: dict, resolved_inputs: list) -> Any: + """Build the outputs spec, expanding wildcards to concrete IDs. + + For wildcard outputs, derives concrete IDs from the resolved inputs. + The browser does the same: input and output wildcards resolve against + the same set of matching components. + """ + if cb_info.get("no_output"): + return [] + + parsed = split_callback_id(output_id) + if isinstance(parsed, dict): + parsed = [parsed] + + results = [] + for p in parsed: + pid = p["id"] + prop = clean_property_name(p["property"]) + pattern = parse_wildcard_id(pid) + if pattern is not None: + concrete_ids = _derive_output_ids(pattern, resolved_inputs) + if not concrete_ids: + concrete_ids = [comp.id for comp in find_matching_components(pattern)] + expanded = [{"id": cid, "property": prop} for cid in concrete_ids] + # ALL/ALLSMALLER → nested list; MATCH → single dict + if len(expanded) == 1: + results.append(expanded[0]) + else: + results.append(expanded) + else: + results.append({"id": pid, "property": prop}) + + if len(results) == 1: + return results[0] + return results + + +def _derive_output_ids( + output_pattern: dict, resolved_inputs: list +) -> list[dict] | None: + """Derive concrete output IDs from the resolved input entries. + + Extracts the wildcard key values from the LLM-provided concrete + input IDs and substitutes them into the output pattern. + """ + wildcard_keys = [ + k + for k, v in output_pattern.items() + if isinstance(v, list) and len(v) == 1 and v[0] in _WILDCARD_VALUES + ] + if not wildcard_keys: + return None + + def _substitute(item_id: dict) -> dict | None: + if not isinstance(item_id, dict): + return None + output_id = dict(output_pattern) + for wk in wildcard_keys: + if wk in item_id: + output_id[wk] = item_id[wk] + return output_id + + for entry in resolved_inputs: + # ALL/ALLSMALLER: nested array of {id, property, value} dicts + if isinstance(entry, list) and entry: + concrete_ids = [] + for item in entry: + out = _substitute(item.get("id")) + if out: + concrete_ids.append(out) + if concrete_ids: + return concrete_ids + # MATCH: single {id, property, value} dict + elif isinstance(entry, dict) and isinstance(entry.get("id"), dict): + out = _substitute(entry["id"]) + if out: + return [out] + + return None + + +def _coerce_value(value: Any) -> Any: + """Parse JSON strings back to Python objects. + + MCP tool parameters arrive as strings. This recovers the + intended type (list, dict, number, bool, null) via json.loads. + Plain strings that aren't valid JSON pass through unchanged. + """ + if not isinstance(value, str): + return value + try: + return json.loads(value) + except (json.JSONDecodeError, ValueError): + return value diff --git a/dash/mcp/primitives/tools/callback_adapter_collection.py b/dash/mcp/primitives/tools/callback_adapter_collection.py new file mode 100644 index 0000000000..60e9e2efe5 --- /dev/null +++ b/dash/mcp/primitives/tools/callback_adapter_collection.py @@ -0,0 +1,154 @@ +"""Collection of CallbackAdapters with cross-adapter queries. + +Stored as a singleton on ``app.mcp_callback_map``. +""" + +from __future__ import annotations + +import hashlib +import re +from functools import cached_property +from typing import Any + +from mcp.types import Tool + +from dash import get_app +from dash._utils import clean_property_name, split_callback_id +from dash.layout import extract_text, find_component, traverse +from .callback_adapter import CallbackAdapter + + +class CallbackAdapterCollection: + def __init__(self, app): + callback_map = getattr(app, "callback_map", {}) + + raw: list[tuple[str, dict]] = [] + for output_id, cb_info in callback_map.items(): + if cb_info.get("mcp_enabled") is False: + continue + if "callback" not in cb_info: + continue + raw.append((output_id, cb_info)) + + self._tool_names_map = self._build_tool_names(raw) + self._callbacks = [ + CallbackAdapter(callback_output_id=output_id) + for output_id in self._tool_names_map + ] + # TODO: enable_mcp_server() will replace this with a direct assignment on app + app.mcp_callback_map = self + + @staticmethod + def _sanitize_name(name: str) -> str: + + max_len = 64 + sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", name) + sanitized = re.sub(r"_+", "_", sanitized).strip("_") + if sanitized and sanitized[0].isdigit(): + sanitized = "cb_" + sanitized + full = sanitized or "unnamed_callback" + if len(full) <= max_len: + return full + hash_suffix = hashlib.sha256(full.encode()).hexdigest()[:8] + truncated = sanitized[: max_len - 9].rstrip("_") + return f"{truncated}_{hash_suffix}" + + @classmethod + def _build_tool_names(cls, raw: list[tuple[str, dict]]) -> dict[str, str]: + func_name_counts: dict[str, int] = {} + for _output_id, cb_info in raw: + func = cb_info.get("callback") + original = getattr(func, "__wrapped__", func) + fn = getattr(original, "__name__", "") or "" + func_name_counts[fn] = func_name_counts.get(fn, 0) + 1 + + name_map: dict[str, str] = {} + for output_id, cb_info in raw: + func = cb_info.get("callback") + original = getattr(func, "__wrapped__", func) + fn = getattr(original, "__name__", "") or "" + raw_name = fn if fn and func_name_counts[fn] == 1 else output_id + name_map[output_id] = cls._sanitize_name(raw_name) + return name_map + + def __iter__(self): + return iter(self._callbacks) + + def __len__(self): + return len(self._callbacks) + + def __getitem__(self, index): + return self._callbacks[index] + + def find_by_tool_name(self, name: str) -> CallbackAdapter | None: + for cb in self._callbacks: + if cb.tool_name == name: + return cb + return None + + def find_by_output(self, id_and_prop: str) -> CallbackAdapter | None: + """Find the adapter that outputs to ``id_and_prop`` (``"component_id.property"``).""" + for cb in self._callbacks: + try: + parsed = split_callback_id(cb.output_id) + except ValueError: + continue + if isinstance(parsed, dict): + parsed = [parsed] + for p in parsed: + if f"{p['id']}.{clean_property_name(p['property'])}" == id_and_prop: + return cb + return None + + def get_initial_value(self, id_and_prop: str) -> Any: + """Return the initial value for ``id_and_prop`` (``"component_id.property"``). + + If a callback outputs to this property, runs it (recursively + resolving its inputs). Otherwise returns the layout default. + """ + upstream_cb = self.find_by_output(id_and_prop) + if upstream_cb is not None: + return upstream_cb.initial_output_value(id_and_prop) + else: + component_id, prop = id_and_prop.rsplit(".", 1) + layout_component = find_component(component_id) + return getattr(layout_component, prop, None) + + def as_mcp_tools(self) -> list[Tool]: + """Stub — will be implemented in a future PR.""" + raise NotImplementedError("as_mcp_tools will be implemented in a future PR.") + + @property + def tool_names(self) -> set[str]: + return set(self._tool_names_map.values()) + + @cached_property + def component_label_map(self) -> dict[str, list[str]]: + """Map component ID → list of label texts from html.Label containers + and/or `htmlFor` associations. + """ + layout = get_app().get_layout() + if layout is None: + return {} + + labels: dict[str, list[str]] = {} + for comp, ancestors in traverse(layout): + if getattr(comp, "_type", None) == "Label": + html_for = getattr(comp, "htmlFor", None) + if html_for is not None: + text = extract_text(comp) + if text: + labels.setdefault(str(html_for), []).append(text) + + comp_id = getattr(comp, "id", None) + if comp_id is not None: + for ancestor in reversed(ancestors): + if getattr(ancestor, "_type", None) == "Label": + text = extract_text(ancestor) + if text: + sid = str(comp_id) + if text not in labels.get(sid, []): + labels.setdefault(sid, []).append(text) + break + + return labels diff --git a/dash/mcp/primitives/tools/callback_utils.py b/dash/mcp/primitives/tools/callback_utils.py new file mode 100644 index 0000000000..ec157b6037 --- /dev/null +++ b/dash/mcp/primitives/tools/callback_utils.py @@ -0,0 +1,36 @@ +"""Callback introspection utilities for MCP tools.""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any + +from dash import get_app + +if TYPE_CHECKING: + from .callback_adapter import CallbackAdapter + + +def run_callback(callback: CallbackAdapter, kwargs: dict[str, Any]) -> dict[str, Any]: + """Execute a callback via Dash's dispatch pipeline.""" + from dash.mcp.types import CallbackExecutionError + + body = callback.as_callback_body(kwargs) + + app = get_app() + with app.server.test_request_context( + "/_dash-update-component", + method="POST", + data=json.dumps(body, default=str), + content_type="application/json", + ): + response = app.dispatch() + + response_text = response.get_data(as_text=True) + if response.status_code != 200: + raise CallbackExecutionError( + f"Callback {callback.output_id} failed " + f"(HTTP {response.status_code}): {response_text[:500]}" + ) + + return json.loads(response_text) diff --git a/dash/mcp/primitives/tools/descriptions/__init__.py b/dash/mcp/primitives/tools/descriptions/__init__.py new file mode 100644 index 0000000000..67ec78c9ff --- /dev/null +++ b/dash/mcp/primitives/tools/descriptions/__init__.py @@ -0,0 +1,7 @@ +"""Stub — real implementation in a later PR.""" + + +def build_tool_description(outputs, docstring=None): + if docstring: + return docstring.strip() + return "Dash callback" diff --git a/dash/mcp/primitives/tools/input_schemas/__init__.py b/dash/mcp/primitives/tools/input_schemas/__init__.py new file mode 100644 index 0000000000..f306042a0c --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/__init__.py @@ -0,0 +1,5 @@ +"""Stub — real implementation in a later PR.""" + + +def get_input_schema(param): + return {} diff --git a/dash/mcp/primitives/tools/output_schemas/__init__.py b/dash/mcp/primitives/tools/output_schemas/__init__.py new file mode 100644 index 0000000000..d2d70c3552 --- /dev/null +++ b/dash/mcp/primitives/tools/output_schemas/__init__.py @@ -0,0 +1,5 @@ +"""Stub — real implementation in a later PR.""" + + +def get_output_schema(): + return {} diff --git a/dash/mcp/types/__init__.py b/dash/mcp/types/__init__.py new file mode 100644 index 0000000000..af588e0808 --- /dev/null +++ b/dash/mcp/types/__init__.py @@ -0,0 +1,26 @@ +"""MCP types, exceptions, and typing utilities.""" + +from dash.mcp.types.callback_types import MCPInput, MCPOutput +from dash.mcp.types.component_types import ( + ComponentPropertyInfo, + ComponentQueryResult, +) +from dash.mcp.types.exceptions import ( + CallbackExecutionError, + InvalidParamsError, + MCPError, + ToolNotFoundError, +) +from dash.mcp.types.typing_utils import is_nullable + +__all__ = [ + "CallbackExecutionError", + "ComponentPropertyInfo", + "ComponentQueryResult", + "InvalidParamsError", + "MCPError", + "MCPInput", + "MCPOutput", + "ToolNotFoundError", + "is_nullable", +] diff --git a/dash/mcp/types/callback_types.py b/dash/mcp/types/callback_types.py new file mode 100644 index 0000000000..9c65dcb9d8 --- /dev/null +++ b/dash/mcp/types/callback_types.py @@ -0,0 +1,33 @@ +"""Typed dicts for MCP callback adapter data.""" + +from __future__ import annotations + +from typing import Any + +from typing_extensions import TypedDict + + +class MCPOutput(TypedDict): + """A single callback output, with component type and initial value resolved.""" + + id_and_prop: str + component_id: str + property: str + component_type: str | None + initial_value: Any + tool_name: str + + +class MCPInput(TypedDict): + """A single callback parameter (input or state), fully resolved.""" + + name: str + id_and_prop: str + component_id: str + property: str + annotation: Any | None + component_type: str | None + component: Any | None + required: bool + initial_value: Any + upstream_output: MCPOutput | None diff --git a/dash/mcp/types/component_types.py b/dash/mcp/types/component_types.py new file mode 100644 index 0000000000..0cac3ad689 --- /dev/null +++ b/dash/mcp/types/component_types.py @@ -0,0 +1,20 @@ +"""Typed dicts for component data in MCP.""" + +from __future__ import annotations + +from typing import Any + +from typing_extensions import NotRequired, TypedDict + + +class ComponentPropertyInfo(TypedDict): + initial_value: Any + modified_by_tool: list[str] + input_to_tool: list[str] + + +class ComponentQueryResult(TypedDict): + component_id: str + component_type: str + label: NotRequired[list[str] | None] + properties: dict[str, ComponentPropertyInfo] diff --git a/dash/mcp/types/exceptions.py b/dash/mcp/types/exceptions.py new file mode 100644 index 0000000000..7fb962db85 --- /dev/null +++ b/dash/mcp/types/exceptions.py @@ -0,0 +1,30 @@ +"""MCP error types with JSON-RPC error codes.""" + +from __future__ import annotations + + +class MCPError(Exception): + """Base MCP error carrying a JSON-RPC error code.""" + + code = -32603 + + def __init__(self, message: str): + super().__init__(message) + + +class ToolNotFoundError(MCPError): + """Tool name not found in the callback registry.""" + + code = -32601 + + +class InvalidParamsError(MCPError): + """Invalid or missing parameters for a tool call.""" + + code = -32602 + + +class CallbackExecutionError(MCPError): + """Callback raised an exception during execution.""" + + code = -32603 diff --git a/dash/mcp/types/typing_utils.py b/dash/mcp/types/typing_utils.py new file mode 100644 index 0000000000..9a96d4135d --- /dev/null +++ b/dash/mcp/types/typing_utils.py @@ -0,0 +1,28 @@ +"""Shared typing utilities for the MCP layer.""" + +from __future__ import annotations + +import typing +from typing import Any + + +def is_nullable(annotation: Any) -> bool: + """Check if a type annotation includes NoneType (is nullable/Optional).""" + origin = getattr(annotation, "__origin__", None) + args = getattr(annotation, "__args__", ()) + + _is_union = origin is typing.Union + if not _is_union: + try: + import types as _types + + if isinstance(annotation, _types.UnionType): + _is_union = True + args = annotation.__args__ + except AttributeError: + pass + + if _is_union and args: + return type(None) in args + + return False diff --git a/requirements/install.txt b/requirements/install.txt index 89bd8a5595..b813a6ce55 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -8,3 +8,4 @@ retrying nest-asyncio setuptools pydantic>=2.10 +mcp>=1.0.0; python_version>="3.10" diff --git a/tests/unit/mcp/conftest.py b/tests/unit/mcp/conftest.py new file mode 100644 index 0000000000..437a71db5c --- /dev/null +++ b/tests/unit/mcp/conftest.py @@ -0,0 +1,6 @@ +import sys + +collect_ignore_glob = [] + +if sys.version_info < (3, 10): + collect_ignore_glob.append("*") diff --git a/tests/unit/mcp/tools/test_callback_adapter.py b/tests/unit/mcp/tools/test_callback_adapter.py new file mode 100644 index 0000000000..91808d304e --- /dev/null +++ b/tests/unit/mcp/tools/test_callback_adapter.py @@ -0,0 +1,227 @@ +"""Tests for CallbackAdapter.""" + +import pytest +from dash import Dash, Input, Output, dcc, html +from dash._get_app import app_context + +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def simple_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Label("Your Name", htmlFor="inp"), + dcc.Input(id="inp", type="text"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + """Update output.""" + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +@pytest.fixture +def duplicate_names_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="in1"), + html.Div(id="out1"), + html.Div(id="in2"), + html.Div(id="out2"), + ] + ) + + @app.callback(Output("out1", "children"), Input("in1", "children")) + def cb(v): + return v + + @app.callback(Output("out2", "children"), Input("in2", "children")) + def cb(v): # noqa: F811 + return v + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestFromApp: + def test_returns_list(self, simple_app): + assert len(app_context.get().mcp_callback_map) == 1 + + def test_excludes_clientside(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + html.Div(id="cs-out"), + html.Div(id="srv-out"), + ] + ) + app.clientside_callback( + "function(n) { return n; }", + Output("cs-out", "children"), + Input("btn", "n_clicks"), + ) + + @app.callback(Output("srv-out", "children"), Input("btn", "n_clicks")) + def server_cb(n): + return str(n) + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + names = [a.tool_name for a in app.mcp_callback_map] + assert names == ["server_cb"] + + def test_excludes_mcp_disabled(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp"), + html.Div(id="out1"), + html.Div(id="out2"), + ] + ) + + @app.callback(Output("out1", "children"), Input("inp", "value")) + def visible(val): + return val + + @app.callback( + Output("out2", "children"), Input("inp", "value"), mcp_enabled=False + ) + def hidden(val): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + names = [a.tool_name for a in app.mcp_callback_map] + assert "visible" in names + assert "hidden" not in names + + +class TestToolName: + def test_uses_func_name(self, simple_app): + assert app_context.get().mcp_callback_map[0].tool_name == "update" + + def test_duplicates_get_unique_names(self, duplicate_names_app): + names = [a.tool_name for a in app_context.get().mcp_callback_map] + assert len(names) == 2 + assert names[0] != names[1] + + +class TestGetInitialValue: + def test_returns_layout_value(self, simple_app): + callback_map = app_context.get().mcp_callback_map + # Input with no value set — returns None (layout default for dcc.Input) + assert callback_map.get_initial_value("inp.value") is None + + def test_returns_set_value(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b"], value="a"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(selected): + return selected + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + assert app.mcp_callback_map.get_initial_value("dd.value") == "a" + + def test_initial_callback_makes_param_required(self): + """A param with None in layout but set by an initial callback is required.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown( + id="country", options=["France", "Germany"], value="France" + ), + dcc.Dropdown(id="city"), # value=None in layout + html.Div(id="out"), + ] + ) + + @app.callback( + Output("city", "options"), + Output("city", "value"), + Input("country", "value"), + ) + def update_cities(country): + return [{"label": "Paris", "value": "Paris"}], "Paris" + + @app.callback(Output("out", "children"), Input("city", "value")) + def show_city(city): + return f"Selected: {city}" + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + # city.value is None in layout but "Paris" after initial callback + with app.server.test_request_context(): + show_city_cb = app.mcp_callback_map.find_by_tool_name("show_city") + city_param = show_city_cb.inputs[0] + assert city_param["name"] == "city" + assert city_param["required"] is True # not optional despite None in layout + + +class TestIsValid: + def test_valid_when_inputs_in_layout(self, simple_app): + assert app_context.get().mcp_callback_map[0].is_valid + + def test_invalid_when_input_not_in_layout(self): + app = Dash(__name__) + app.layout = html.Div([html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("nonexistent", "value")) + def update(val): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + assert not app.mcp_callback_map[0].is_valid + + def test_pattern_matching_ids_always_valid(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id={"type": "field", "index": 0}, value="a"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input({"type": "field", "index": 0}, "value"), + ) + def update(val): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + assert app.mcp_callback_map[0].is_valid diff --git a/tests/unit/mcp/tools/test_callback_adapter_collection.py b/tests/unit/mcp/tools/test_callback_adapter_collection.py new file mode 100644 index 0000000000..c120a2df8b --- /dev/null +++ b/tests/unit/mcp/tools/test_callback_adapter_collection.py @@ -0,0 +1,145 @@ +"""Tests for CallbackAdapterCollection.""" + +from dash import Dash, Input, Output, dcc, html +from dash._get_app import app_context + +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) + + +def _setup(app): + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + +class TestToolNameCollisions: + @staticmethod + def _make_duplicate_cb_app(n=3): + ids = [f"dd{i + 1}" for i in range(n)] + app = Dash(__name__) + app.layout = html.Div( + [ + item + for i in ids + for item in [ + dcc.Dropdown( + id=i, options=[chr(97 + j) for j in range(1)], value="a" + ), + html.Div(id=f"{i}-output"), + ] + ] + ) + for idx, dd_id in enumerate(ids): + + @app.callback(Output(f"{dd_id}-output", "children"), Input(dd_id, "value")) + def cb(value, _id=dd_id): # noqa: F811 + return f"{_id}: {value}" + + return app + + def test_duplicate_func_names_get_unique_tools(self): + app = self._make_duplicate_cb_app(3) + _setup(app) + tool_names = [a.tool_name for a in app.mcp_callback_map] + assert len(tool_names) == 3 + assert len(set(tool_names)) == 3, f"Tool names are not unique: {tool_names}" + for name in tool_names: + assert "dd" in name, f"Expected output ID in tool name: {name}" + + def test_unique_func_names_use_func_name(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="in1"), + html.Div(id="out1"), + html.Div(id="in2"), + html.Div(id="out2"), + ] + ) + + @app.callback(Output("out1", "children"), Input("in1", "children")) + def alpha_handler(value): + return value + + @app.callback(Output("out2", "children"), Input("in2", "children")) + def beta_handler(value): + return value + + _setup(app) + tool_names = [a.tool_name for a in app.mcp_callback_map] + assert "alpha_handler" in tool_names + assert "beta_handler" in tool_names + + def test_duplicate_func_names_use_output_id(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="out1"), + html.Div(id="out2"), + html.Div(id="out3"), + html.Div(id="in1"), + html.Div(id="in2"), + html.Div(id="in3"), + ] + ) + + @app.callback(Output("out1", "children"), Input("in1", "children")) + def unique_func(v): + return v + + @app.callback(Output("out2", "children"), Input("in2", "children")) + def cb(v): + return v + + @app.callback(Output("out3", "children"), Input("in3", "children")) + def cb(v): # noqa: F811 + return v + + _setup(app) + tool_names = [a.tool_name for a in app.mcp_callback_map] + assert "unique_func" in tool_names + non_unique = [n for n in tool_names if n != "unique_func"] + assert len(non_unique) == 2 + assert non_unique[0] != non_unique[1] + + +class TestAllCallbacksVisibleByDefault: + def test_all_callbacks_visible_by_default(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="in1"), + html.Div(id="out1"), + html.Div(id="in2"), + html.Div(id="out2"), + ] + ) + + @app.callback(Output("out1", "children"), Input("in1", "children")) + def cb_one(value): + return value + + @app.callback(Output("out2", "children"), Input("in2", "children")) + def cb_two(value): + return value + + _setup(app) + tool_names = [a.tool_name for a in app.mcp_callback_map] + assert "cb_one" in tool_names + assert "cb_two" in tool_names + + +class TestAdapterCollection: + def test_adapter_has_expected_properties(self): + app = Dash(__name__) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + return val + + _setup(app) + adapter = app.mcp_callback_map[0] + assert adapter.tool_name == "update" + assert adapter.output_id == "out.children" From c6cbe97edb86541883d408942b056f274120e7ff Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 16:59:57 -0600 Subject: [PATCH 2/2] Fix type errors --- dash/_callback.py | 2 +- dash/types.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/dash/_callback.py b/dash/_callback.py index dc73dd0792..5900bbe0fc 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -540,7 +540,7 @@ def _prepare_response( allow_dynamic_callbacks, ): """Prepare the response object based on the callback output.""" - component_ids = collections.defaultdict(dict) + component_ids: dict = collections.defaultdict(dict) if has_output: if not multi: diff --git a/dash/types.py b/dash/types.py index 43bf16dc30..e392a2d599 100644 --- a/dash/types.py +++ b/dash/types.py @@ -73,3 +73,4 @@ class CallbackDispatchResponse(TypedDict): multi: NotRequired[bool] response: NotRequired[Dict[str, CallbackOutput]] sideUpdate: NotRequired[Dict[str, CallbackSideOutput]] + dist: NotRequired[List[Any]]