From 23298c2efdaa1801eae50888152a8e56ab64599c Mon Sep 17 00:00:00 2001 From: Chengbiao Jin Date: Wed, 11 Mar 2026 15:58:20 -0700 Subject: [PATCH 1/6] GML-2041 Performance improvement for parallel requests --- README.md | 231 +++++++++++++--- build.sh | 61 +++- pyTigerGraph/__init__.py | 2 +- pyTigerGraph/common/base.py | 111 +++++--- pyTigerGraph/common/edge.py | 44 ++- pyTigerGraph/common/loading.py | 18 +- pyTigerGraph/common/query.py | 30 +- pyTigerGraph/mcp/MCP_README.md | 196 +++++++------ pyTigerGraph/mcp/connection_manager.py | 228 ++++++++++----- pyTigerGraph/mcp/main.py | 10 +- pyTigerGraph/mcp/server.py | 16 +- pyTigerGraph/mcp/tool_names.py | 4 + pyTigerGraph/mcp/tools/__init__.py | 11 + pyTigerGraph/mcp/tools/connection_tools.py | 104 +++++++ pyTigerGraph/mcp/tools/data_tools.py | 24 +- pyTigerGraph/mcp/tools/datasource_tools.py | 33 ++- pyTigerGraph/mcp/tools/edge_tools.py | 28 +- pyTigerGraph/mcp/tools/gsql_tools.py | 16 +- pyTigerGraph/mcp/tools/node_tools.py | 32 ++- pyTigerGraph/mcp/tools/query_tools.py | 32 ++- pyTigerGraph/mcp/tools/schema_tools.py | 41 ++- pyTigerGraph/mcp/tools/statistics_tools.py | 16 +- pyTigerGraph/mcp/tools/tool_registry.py | 7 + pyTigerGraph/mcp/tools/vector_tools.py | 36 ++- pyTigerGraph/pyTigerGraphBase.py | 154 ++++++++--- pyTigerGraph/pyTigerGraphEdge.py | 86 +----- pyTigerGraph/pyTigerGraphQuery.py | 30 +- pyTigerGraph/pytgasync/pyTigerGraphAuth.py | 2 - pyTigerGraph/pytgasync/pyTigerGraphBase.py | 177 +++++++++--- pyTigerGraph/pytgasync/pyTigerGraphEdge.py | 19 +- pyTigerGraph/pytgasync/pyTigerGraphGSQL.py | 6 +- pyTigerGraph/pytgasync/pyTigerGraphQuery.py | 30 +- pytigergraph-recipe/recipe/meta.yaml | 30 ++ setup.py | 9 +- tests/mcp/README.md | 20 +- tests/mcp/test_connection_manager.py | 292 ++++++++++++++++++++ tests/mcp/test_connection_tools.py | 119 ++++++++ tests/mcp/test_data_tools.py | 35 +++ tests/mcp/test_datasource_tools.py | 26 ++ tests/mcp/test_edge_tools.py | 37 +++ tests/mcp/test_gsql_tools.py | 20 +- 
tests/mcp/test_node_tools.py | 38 +++ tests/mcp/test_query_tools.py | 40 +++ tests/mcp/test_schema_tools.py | 58 ++++ tests/mcp/test_statistics_tools.py | 22 ++ tests/mcp/test_vector_tools.py | 55 ++++ tests/test_common_edge_query.py | 185 +++++++++++++ tests/test_pyTigerGraphEdgeAsync.py | 68 ++++- 48 files changed, 2316 insertions(+), 573 deletions(-) create mode 100644 pyTigerGraph/mcp/tools/connection_tools.py create mode 100644 pytigergraph-recipe/recipe/meta.yaml create mode 100644 tests/mcp/test_connection_manager.py create mode 100644 tests/mcp/test_connection_tools.py create mode 100644 tests/test_common_edge_query.py diff --git a/README.md b/README.md index aadd9897..2d08d08d 100644 --- a/README.md +++ b/README.md @@ -1,66 +1,221 @@ # pyTigerGraph -pyTigerGraph is a Python package for connecting to TigerGraph databases. Check out the documentation [here](https://docs.tigergraph.com/pytigergraph/current/intro/). +pyTigerGraph is a Python client for [TigerGraph](https://www.tigergraph.com/) databases. It wraps the REST++ and GSQL APIs and provides both a synchronous and an asynchronous interface. -[![Downloads](https://static.pepy.tech/badge/pyTigergraph)](https://pepy.tech/project/pyTigergraph) -[![Downloads](https://static.pepy.tech/badge/pyTigergraph/month)](https://pepy.tech/project/pyTigergraph) -[![Downloads](https://static.pepy.tech/badge/pyTigergraph/week)](https://pepy.tech/project/pyTigergraph) +Full documentation: <https://docs.tigergraph.com/pytigergraph/current/intro/> -## Quickstart +Downloads: [![Total Downloads](https://static.pepy.tech/badge/pyTigergraph)](https://pepy.tech/project/pyTigergraph) | [![Monthly Downloads](https://static.pepy.tech/badge/pyTigergraph/month)](https://pepy.tech/project/pyTigergraph) | [![Weekly Downloads](https://static.pepy.tech/badge/pyTigergraph/week)](https://pepy.tech/project/pyTigergraph) + +--- + +## Installation + +### Base package -### Installing pyTigerGraph -This section walks you through installing pyTigerGraph on your machine. 
+```sh +pip install pyTigerGraph +``` -#### Prerequisites -* Python 3+ -* If you wish to use the GDS functionality, install `torch` ahead of time. +### Optional extras -#### Install _pyTigerGraph_ +| Extra | What it adds | Install command | +|-------|-------------|-----------------| +| `gds` | Graph Data Science — data loaders for PyTorch Geometric, DGL, and Pandas | `pip install 'pyTigerGraph[gds]'` | +| `mcp` | Model Context Protocol server — exposes TigerGraph as tools for AI agents | `pip install 'pyTigerGraph[mcp]'` | +| `fast` | [orjson](https://github.com/ijl/orjson) JSON backend — 2–10× faster parsing, releases the GIL under concurrent load | `pip install 'pyTigerGraph[fast]'` | -To download _pyTigerGraph_, run the following command in the command line or use the appropriate tool of your development environment (anaconda, PyCharm, etc.).: +Extras can be combined: ```sh -pip3 install pyTigerGraph +pip install 'pyTigerGraph[fast,gds,mcp]' ``` -#### Install _pyTigerGraph[gds]_ +#### `[gds]` prerequisites -To utilize the Graph Data Science Functionality, there are a few options: -* To use the GDS functions with **PyTorch Geometric**, install `torch` and `PyTorch Geometric` according to their instructions: +Install `torch` before installing the `gds` extra: - 1) [Install Torch](https://pytorch.org/get-started/locally/) +1. [Install Torch](https://pytorch.org/get-started/locally/) +2. Optionally [Install PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html) or [Install DGL](https://www.dgl.ai/pages/start.html) +3. `pip install 'pyTigerGraph[gds]'` - 2) [Install PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html) +#### `[fast]` — orjson JSON backend - 3) Install pyTigerGraph with: - ```sh - pip3 install 'pyTigerGraph[gds]' - ``` +`orjson` is a Rust-backed JSON library that is detected and used automatically when installed. No code changes are required. 
It improves throughput in two ways: -* To use the GDS functions with **DGL**, install `torch` and `dgl` according to their instructions: +- **Faster parsing** — 2–10× vs stdlib `json` +- **GIL release** — threads parse responses concurrently instead of serialising on the GIL - 1) [Install Torch](https://pytorch.org/get-started/locally/) +If `orjson` is not installed the library falls back to stdlib `json` transparently. - 2) [Install DGL](https://www.dgl.ai/pages/start.html) +--- - 3) Install pyTigerGraph with: - ```sh - pip3 install 'pyTigerGraph[gds]' - ``` +## Quickstart -* To use the GDS functions without needing to produce output in the format supported by PyTorch Geometric or DGL. -This makes the data loaders output *Pandas dataframes*: -```sh -pip3 install 'pyTigerGraph[gds]' +### Synchronous connection + +```python +from pyTigerGraph import TigerGraphConnection + +conn = TigerGraphConnection( + host="http://localhost", + graphname="my_graph", + username="tigergraph", + password="tigergraph", +) + +print(conn.echo()) +``` + +Use as a context manager to ensure the underlying HTTP session is closed: + +```python +with TigerGraphConnection(host="http://localhost", graphname="my_graph") as conn: + result = conn.runInstalledQuery("my_query", {"param": "value"}) +``` + +### Asynchronous connection + +`AsyncTigerGraphConnection` exposes the same API as `TigerGraphConnection` but with `async`/`await` syntax. It uses [aiohttp](https://docs.aiohttp.org/) internally and shares a single connection pool across all concurrent tasks, making it significantly more efficient than threaded sync code at high concurrency. 
+ +```python +import asyncio +from pyTigerGraph import AsyncTigerGraphConnection + +async def main(): + async with AsyncTigerGraphConnection( + host="http://localhost", + graphname="my_graph", + username="tigergraph", + password="tigergraph", + ) as conn: + result = await conn.runInstalledQuery("my_query", {"param": "value"}) + print(result) + +asyncio.run(main()) +``` + +### Token-based authentication + +```python +conn = TigerGraphConnection( + host="http://localhost", + graphname="my_graph", + gsqlSecret="my_secret", # generates a session token automatically +) ``` -Once the package is installed, you can import it like any other Python package: +### HTTPS / TigerGraph Cloud -```py -import pyTigerGraph as tg +```python +conn = TigerGraphConnection( + host="https://my-instance.i.tgcloud.io", + graphname="my_graph", + username="tigergraph", + password="tigergraph", + tgCloud=True, +) ``` -### Getting Started with Core Functions + +--- + +## Connection parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `host` | `str` | `"http://127.0.0.1"` | Server URL including scheme (`http://` or `https://`) | +| `graphname` | `str` | `""` | Target graph name | +| `username` | `str` | `"tigergraph"` | Database username | +| `password` | `str` | `"tigergraph"` | Database password | +| `gsqlSecret` | `str` | `""` | GSQL secret for token-based auth (preferred over username/password) | +| `apiToken` | `str` | `""` | Pre-obtained REST++ API token | +| `jwtToken` | `str` | `""` | JWT token for customer-managed authentication | +| `restppPort` | `int\|str` | `"9000"` | REST++ port (auto-fails over to `14240/restpp` for TigerGraph 4.x) | +| `gsPort` | `int\|str` | `"14240"` | GSQL server port | +| `certPath` | `str` | `None` | Path to CA certificate for HTTPS | +| `tgCloud` | `bool` | `False` | Set to `True` for TigerGraph Cloud instances | + +--- + +## Performance notes + +### Synchronous mode (`TigerGraphConnection`) + +- Each 
thread gets its own `requests.Session` backed by a private connection pool. This eliminates the `_cookies_lock` contention that a shared session causes under concurrent load. +- Install `pyTigerGraph[fast]` to activate the `orjson` backend and significantly reduce GIL contention between threads during JSON parsing. +- Use `ThreadPoolExecutor` to run queries in parallel: + +```python +from concurrent.futures import ThreadPoolExecutor, as_completed + +with TigerGraphConnection(...) as conn: + with ThreadPoolExecutor(max_workers=16) as executor: + futures = {executor.submit(conn.runInstalledQuery, "q", {"p": v}): v for v in values} + for f in as_completed(futures): + print(f.result()) +``` + +### Asynchronous mode (`AsyncTigerGraphConnection`) + +- Uses a single `aiohttp.ClientSession` with an unbounded connection pool shared across all concurrent coroutines — no GIL, no thread-scheduling overhead. +- Typically achieves higher QPS and lower tail latency than the threaded sync mode for I/O-bound workloads. + +```python +import asyncio +from pyTigerGraph import AsyncTigerGraphConnection + +async def main(): + async with AsyncTigerGraphConnection(...) as conn: + tasks = [conn.runInstalledQuery("q", {"p": v}) for v in values] + results = await asyncio.gather(*tasks) + +asyncio.run(main()) +``` + +--- + +## Graph Data Science (GDS) + +The `gds` sub-module provides data loaders that stream vertex and edge data from TigerGraph directly into PyTorch Geometric, DGL, or Pandas DataFrames for machine learning workflows. + +Install requirements, then access via `conn.gds`: + +```python +conn = TigerGraphConnection(host="...", graphname="...") +loader = conn.gds.vertexLoader(attributes=["feat", "label"], batch_size=1024) +for batch in loader: + train(batch) +``` + +See the [GDS documentation](https://docs.tigergraph.com/pytigergraph/current/gds/) for full details. 
+ +--- + +## MCP Server + +pyTigerGraph includes a built-in [Model Context Protocol (MCP)](https://modelcontextprotocol.io/) server that exposes TigerGraph operations as tools for AI agents and LLM applications (Claude Desktop, Cursor, Copilot, etc.). All MCP tools use the async API internally for optimal performance. + +```sh +pip install 'pyTigerGraph[mcp]' + +# Start the server (reads connection config from environment variables) +tigergraph-mcp +``` + +For full setup instructions, available tools, and configuration examples, see the **[MCP Server README](pyTigerGraph/mcp/MCP_README.md)**. + +--- + +## Getting started video [![pyTigerGraph 101](https://img.youtube.com/vi/2BcC3C-qfX4/hqdefault.jpg)](https://www.youtube.com/watch?v=2BcC3C-qfX4) -The video above is a good starting place for learning the core functions of pyTigerGraph. [This Google Colab notebook](https://colab.research.google.com/drive/1JhYcnGVWT51KswcXZzyPzKqCoPP5htcC) is the companion notebook to the video. +Companion notebook: [Google Colab](https://colab.research.google.com/drive/1JhYcnGVWT51KswcXZzyPzKqCoPP5htcC) + +--- + +## Links + +- [Documentation](https://docs.tigergraph.com/pytigergraph/current/intro/) +- [PyPI](https://pypi.org/project/pyTigerGraph/) +- [GitHub Issues](https://github.com/tigergraph/pyTigerGraph/issues) +- [Source](https://github.com/tigergraph/pyTigerGraph) diff --git a/build.sh b/build.sh index 2c067641..5e173613 100755 --- a/build.sh +++ b/build.sh @@ -1,10 +1,57 @@ -#! /bin/bash +#!/usr/bin/env bash +set -euo pipefail -echo ---- Removing old dist ---- -rm -rf dist +usage() { + cat <&2; usage >&2; exit 1 ;; + esac + shift +done + +if $DO_BUILD; then + echo "---- Removing old dist ----" + rm -rf dist + + echo "---- Building new package ----" + python3 -m build +fi + +if $DO_UPLOAD; then + if [[ ! -d dist ]] || [[ -z "$(ls dist/)" ]]; then + echo "Error: dist/ is empty or missing. Run with --build first." 
>&2 + exit 1 + fi + + echo "---- Uploading to PyPI ----" + python3 -m twine upload dist/* +fi diff --git a/pyTigerGraph/__init__.py b/pyTigerGraph/__init__.py index 844db559..18d4e3bd 100644 --- a/pyTigerGraph/__init__.py +++ b/pyTigerGraph/__init__.py @@ -2,7 +2,7 @@ from pyTigerGraph.pytgasync.pyTigerGraph import AsyncTigerGraphConnection from pyTigerGraph.common.exception import TigerGraphException -__version__ = "2.0.0" +__version__ = "2.0.1" __license__ = "Apache 2" diff --git a/pyTigerGraph/common/base.py b/pyTigerGraph/common/base.py index 53e96422..ef78b5f0 100644 --- a/pyTigerGraph/common/base.py +++ b/pyTigerGraph/common/base.py @@ -15,6 +15,18 @@ from typing import Union from urllib.parse import urlparse +# orjson is an optional Rust-backed JSON library that: +# - parses/serialises 2–10× faster than stdlib json +# - releases the GIL during parsing, eliminating inter-thread contention on +# multi-threaded workloads where all threads parse responses simultaneously +# Fall back to stdlib transparently when orjson is not installed. +try: + import orjson as _orjson + _HAS_ORJSON = True +except ImportError: + _orjson = None + _HAS_ORJSON = False + from pyTigerGraph.common.exception import TigerGraphException @@ -112,6 +124,7 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "", # Detect auth mode automatically by checking if jwtToken or apiToken is provided self.authHeader = self._set_auth_header() + self.authMode = "token" if (self.jwtToken or self.apiToken) else "pwd" # TODO Eliminate version and use gsqlVersion only, meaning TigerGraph server version if gsqlVersion: @@ -152,6 +165,12 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "", self.certPath = certPath self.sslPort = str(sslPort) + # SSL verify value — depends only on useCert/certPath which are fixed after init, + # so we compute it once here rather than on every request in _prep_req. 
+ # Note: for http, certPath="" (not None) so the condition is True → False; for + # https, useCert=True → False. The verify=True branch in _prep_req is unreachable. + self.verify = False if (self.useCert or self.certPath) else True + # TODO Remove gcp parameter if gcp: warnings.warn("The `gcp` parameter is deprecated.", @@ -212,6 +231,10 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "", self.asynchronous = False + # Pre-build per-authMode header dicts so _prep_req avoids repeating + # the isinstance/string-comparison chain on every request. + self._refresh_auth_headers() + logger.debug("exit: __init__") def _set_auth_header(self): @@ -223,6 +246,34 @@ def _set_auth_header(self): else: return {"Authorization": "Basic {0}".format(self.base64_credential)} + def _refresh_auth_headers(self) -> None: + """Pre-build per-authMode header dicts used by every request. + + Called once at __init__ and again after getToken() updates the + credentials. Eliminates per-request isinstance checks and string + formatting in _prep_req's hot path. + + Two dicts are kept because authMode can be either "token" or "pwd": + - "token": JWT > apiToken (tuple or str) > Basic + - "pwd": JWT > Basic + The "X-User-Agent" header is baked in so _prep_req skips that update too. 
+ """ + # ---- token mode ---- + if isinstance(self.jwtToken, str) and self.jwtToken.strip(): + token_val = "Bearer " + self.jwtToken + elif isinstance(self.apiToken, tuple): + token_val = "Bearer " + self.apiToken[0] + elif isinstance(self.apiToken, str) and self.apiToken.strip(): + token_val = "Bearer " + self.apiToken + else: + token_val = "Basic " + self.base64_credential + + # ---- pwd mode ---- + pwd_val = ("Bearer " + self.jwtToken) if self.jwtToken else ("Basic " + self.base64_credential) + + self._cached_token_auth = {"Authorization": token_val, "X-User-Agent": "pyTigerGraph"} + self._cached_pwd_auth = {"Authorization": pwd_val, "X-User-Agent": "pyTigerGraph"} + def _verify_jwt_token_support(self): try: # Check JWT support for RestPP server @@ -281,34 +332,11 @@ def _prep_req(self, authMode, headers, url, method, data): if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) - _headers = {} - - # If JWT token is provided, always use jwtToken as token - if authMode == "token": - if isinstance(self.jwtToken, str) and self.jwtToken.strip() != "": - token = self.jwtToken - elif isinstance(self.apiToken, tuple): - token = self.apiToken[0] - elif isinstance(self.apiToken, str) and self.apiToken.strip() != "": - token = self.apiToken - else: - token = None - - if token: - self.authHeader = {'Authorization': "Bearer " + token} - _headers = self.authHeader - else: - self.authHeader = { - 'Authorization': 'Basic {0}'.format(self.base64_credential)} - _headers = self.authHeader - self.authMode = "pwd" - else: - if self.jwtToken: - _headers = {'Authorization': "Bearer " + self.jwtToken} - else: - _headers = {'Authorization': 'Basic {0}'.format( - self.base64_credential)} - self.authMode = "pwd" + # Shallow-copy the pre-built header dict (auth + X-User-Agent already included). + # _refresh_auth_headers() keeps these current after every getToken() call. 
+ _headers = dict( + self._cached_token_auth if authMode == "token" else self._cached_pwd_auth + ) if headers: _headers.update(headers) @@ -323,25 +351,28 @@ def _prep_req(self, authMode, headers, url, method, data): else: _data = None - if self.useCert is True or self.certPath is not None: - verify = False - else: - verify = True - - _headers.update({"X-User-Agent": "pyTigerGraph"}) logger.debug("exit: _prep_req") - return _headers, _data, verify + return _headers, _data, self.verify - def _parse_req(self, res, jsonResponse, strictJson, skipCheck, resKey): + def _parse_req(self, data: Union[bytes, str], jsonResponse, strictJson, skipCheck, resKey): logger.debug("entry: _parse_req") if jsonResponse: try: - res = json.loads(res.text, strict=strictJson) - except: - raise TigerGraphException("Cannot parse json: " + res.text) + if _HAS_ORJSON and strictJson: + # orjson accepts bytes directly (no decode step), parses 2–10× faster + # than stdlib, and releases the GIL — eliminating inter-thread contention + # when multiple threads parse responses simultaneously. + res = _orjson.loads(data) + else: + # strictJson=False allows control characters; orjson is always strict, + # so fall back to stdlib for that case. + res = json.loads(data, strict=strictJson) + except Exception: + text = data.decode("utf-8", errors="replace") if isinstance(data, bytes) else data + raise TigerGraphException("Cannot parse json: " + text) else: - res = res.text + res = data.decode("utf-8", errors="replace") if isinstance(data, bytes) else data if not skipCheck: self._error_check(res) diff --git a/pyTigerGraph/common/edge.py b/pyTigerGraph/common/edge.py index 71e3ec10..879763b5 100644 --- a/pyTigerGraph/common/edge.py +++ b/pyTigerGraph/common/edge.py @@ -177,30 +177,22 @@ def _dumps(data) -> str: Returns: The JSON to be sent to the endpoint. 
""" - ret = "" - if isinstance(data, dict): - c1 = 0 - for k1, v1 in data.items(): - if c1 > 0: - ret += "," - if k1 == ___trgvtxids: - # Dealing with the (possibly multiple instances of) edge details - # v1 should be a dict of lists - c2 = 0 - for k2, v2 in v1.items(): - if c2 > 0: - ret += "," - c3 = 0 - for v3 in v2: - if c3 > 0: - ret += "," - ret += json.dumps(k2) + ':' + json.dumps(v3) - c3 += 1 - c2 += 1 - else: - ret += json.dumps(k1) + ':' + _dumps(data[k1]) - c1 += 1 - return "{" + ret + "}" + if not isinstance(data, dict): + return json.dumps(data) + parts = [] + for k1, v1 in data.items(): + if k1 == ___trgvtxids: + # Dealing with the (possibly multiple instances of) edge details. + # v1 is a dict mapping target vertex ID -> list of attribute dicts. + # Each list entry becomes a separate JSON key:value pair (same key repeated + # for MultiEdge), so we cannot use json.dumps on v1 directly. + for k2, v2 in v1.items(): + k2_encoded = json.dumps(k2) + for v3 in v2: + parts.append(k2_encoded + ":" + json.dumps(v3)) + else: + parts.append(json.dumps(k1) + ":" + _dumps(v1)) + return "{" + ",".join(parts) + "}" def _prep_upsert_edges(sourceVertexType, edgeType, @@ -248,8 +240,8 @@ def _prep_upsert_edge_dataframe(df, from_id, to_id, attributes): for index in df.index: json_up.append(json.loads(df.loc[index].to_json())) json_up[-1] = ( - index if from_id is None else json_up[-1][from_id], - index if to_id is None else json_up[-1][to_id], + index if not from_id else json_up[-1][from_id], + index if not to_id else json_up[-1][to_id], json_up[-1] if attributes is None else {target: json_up[-1][source] for target, source in attributes.items()} ) diff --git a/pyTigerGraph/common/loading.py b/pyTigerGraph/common/loading.py index c12c6f56..5872bb3a 100644 --- a/pyTigerGraph/common/loading.py +++ b/pyTigerGraph/common/loading.py @@ -70,9 +70,10 @@ def _prep_run_loading_job(gsUrl: str, def _prep_abort_loading_jobs(gsUrl: str, graphname: str, jobIds: list[str], pauseJob: 
bool): '''url builder for abortLoadingJob()''' + job_params = "&".join("jobId=" + jobId for jobId in jobIds) url = gsUrl + "/gsql/v1/loading-jobs/abort?graph=" + graphname - for jobId in jobIds: - url += "&jobId=" + jobId + if job_params: + url += "&" + job_params if pauseJob: url += "&isPause=true" return url @@ -91,13 +92,12 @@ def _prep_resume_loading_job(gsUrl: str, jobId: str): url = gsUrl + "/gsql/v1/loading-jobs/resume/" + jobId return url -def _prep_get_loading_jobs_status(gsUrl: str, jobIds: list[str]): - '''url builder for getLoadingJobStatus() - TODO: verify that this is correct - ''' - url = gsUrl + "/gsql/v1/loading-jobs/status/jobId" - for jobId in jobIds: - url += "&jobId=" + jobId +def _prep_get_loading_jobs_status(gsUrl: str, graphname: str, jobIds: list[str]): + '''url builder for getLoadingJobsStatus()''' + job_params = "&".join("jobId=" + jobId for jobId in jobIds) + url = gsUrl + "/gsql/v1/loading-jobs/status?graph=" + graphname + if job_params: + url += "&" + job_params return url def _prep_get_loading_job_status(gsUrl: str, jobId: str): diff --git a/pyTigerGraph/common/query.py b/pyTigerGraph/common/query.py index 7b5b3655..efaed2b8 100644 --- a/pyTigerGraph/common/query.py +++ b/pyTigerGraph/common/query.py @@ -20,7 +20,11 @@ logger = logging.getLogger(__name__) # TODO getQueries() # List _all_ query names -def _parse_get_installed_queries(fmt, ret): +def _parse_get_installed_queries(fmt, ret, graphname: str = ""): + prefix = f"GET /query/{graphname}/" if graphname else "GET /query/" + if fmt == "list": + return [ep[len(prefix):] for ep in ret if ep.startswith(prefix)] + ret = {ep: v for ep, v in ret.items() if ep.startswith(prefix)} if fmt == "json": ret = json.dumps(ret) if fmt == "df": @@ -57,37 +61,33 @@ def _parse_query_parameters(params: dict) -> str: logger.debug("entry: _parseQueryParameters") logger.debug("params: " + str(params)) - ret = "" + parts = [] for k, v in params.items(): if isinstance(v, tuple): if len(v) == 2 and 
isinstance(v[1], str): - ret += k + "=" + str(v[0]) + "&" + k + \ - ".type=" + _safe_char(v[1]) + "&" + parts.append(k + "=" + str(v[0])) + parts.append(k + ".type=" + _safe_char(v[1])) else: raise TigerGraphException( "Invalid parameter value: (vertex_primary_id, vertex_type)" " was expected.") elif isinstance(v, list): - i = 0 - for vv in v: + for i, vv in enumerate(v): if isinstance(vv, tuple): if len(vv) == 2 and isinstance(vv[1], str): - ret += k + "[" + str(i) + "]=" + _safe_char(vv[0]) + "&" + \ - k + "[" + str(i) + "].type=" + vv[1] + "&" + parts.append(k + "[" + str(i) + "]=" + _safe_char(vv[0])) + parts.append(k + "[" + str(i) + "].type=" + vv[1]) else: raise TigerGraphException( "Invalid parameter value: (vertex_primary_id, vertex_type)" " was expected.") else: - ret += k + "=" + _safe_char(vv) + "&" - i += 1 + parts.append(k + "=" + _safe_char(vv)) elif isinstance(v, datetime): - ret += k + "=" + \ - _safe_char(v.strftime("%Y-%m-%d %H:%M:%S")) + "&" + parts.append(k + "=" + _safe_char(v.strftime("%Y-%m-%d %H:%M:%S"))) else: - ret += k + "=" + _safe_char(v) + "&" - if ret: - ret = ret[:-1] + parts.append(k + "=" + _safe_char(v)) + ret = "&".join(parts) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) diff --git a/pyTigerGraph/mcp/MCP_README.md b/pyTigerGraph/mcp/MCP_README.md index a111ee58..c8c728bd 100644 --- a/pyTigerGraph/mcp/MCP_README.md +++ b/pyTigerGraph/mcp/MCP_README.md @@ -2,6 +2,25 @@ pyTigerGraph now includes Model Context Protocol (MCP) support, allowing AI agents to interact with TigerGraph through the MCP standard. All MCP tools use pyTigerGraph's async APIs for optimal performance. 
+## Table of Contents + +- [Installation](#installation) +- [Usage](#usage) + - [Running the MCP Server](#running-the-mcp-server) + - [Configuration](#configuration) + - [Using with Existing Connection](#using-with-existing-connection) +- [Client Examples](#client-examples) + - [Using MultiServerMCPClient](#using-multiservermcpclient) + - [Using MCP Client SDK Directly](#using-mcp-client-sdk-directly) +- [Available Tools](#available-tools) +- [LLM-Friendly Features](#llm-friendly-features) + - [Structured Responses](#structured-responses) + - [Rich Tool Descriptions](#rich-tool-descriptions) + - [Token Optimization](#token-optimization) + - [Tool Discovery](#tool-discovery) +- [Notes](#notes) +- [Backward Compatibility](#backward-compatibility) + ## Installation To use MCP functionality, install pyTigerGraph with the `mcp` extra: @@ -64,16 +83,11 @@ TG_USERNAME=tigergraph TG_PASSWORD=tigergraph TG_RESTPP_PORT=9000 TG_GS_PORT=14240 +TG_CONN_LIMIT=10 # Optional - increase for parallel tool calls (e.g. 32) ``` The server will automatically load the `.env` file if it exists. Environment variables take precedence over `.env` file values. -You can also specify a custom path to the `.env` file: - -```bash -tigergraph-mcp --env-file /path/to/custom/.env -``` - #### Environment Variables The following environment variables are supported: @@ -90,6 +104,7 @@ The following environment variables are supported: - `TG_SSL_PORT` - SSL port (default: 443) - `TG_TGCLOUD` - Whether using TigerGraph Cloud (default: False) - `TG_CERT_PATH` - Path to certificate (optional) +- `TG_CONN_LIMIT` - Max keep-alive HTTP connections in the async client pool (default: 10). Should be ≥ the number of concurrent MCP tool calls you expect. Named profiles use `<PROFILE>_TG_CONN_LIMIT`. 
### Using with Existing Connection @@ -116,18 +131,88 @@ conn.start_mcp_server() from pyTigerGraph import AsyncTigerGraphConnection from pyTigerGraph.mcp import ConnectionManager -conn = AsyncTigerGraphConnection( +async with AsyncTigerGraphConnection( host="http://localhost", graphname="MyGraph", username="tigergraph", - password="tigergraph" + password="tigergraph", + connLimit=10, # set >= number of concurrent MCP tool calls (default: 10) +) as conn: + # Set as default for MCP tools + ConnectionManager.set_default_connection(conn) + # ... run MCP tools ... +# HTTP connection pool is released on exit +``` + +This sets the connection as the default for MCP tools. Note that MCP tools use async APIs internally, so using `AsyncTigerGraphConnection` directly is more efficient. For long-lived connections without `async with`, call `await conn.aclose()` explicitly when finished. + +## Client Examples + +### Using MultiServerMCPClient + +```python +from langchain_mcp_adapters import MultiServerMCPClient +from pathlib import Path +from dotenv import dotenv_values +import asyncio + +# Load environment variables +env_dict = dotenv_values(dotenv_path=Path(".env").expanduser().resolve()) + +# Configure the client +client = MultiServerMCPClient( + { + "tigergraph-mcp": { + "transport": "stdio", + "command": "tigergraph-mcp", + "args": ["-vv"], # Enable debug logging + "env": env_dict, + }, + } ) -# Set as default for MCP tools -ConnectionManager.set_default_connection(conn) +# Get tools and use them +tools = asyncio.run(client.get_tools()) +# Tools are now available for use +``` + +### Using MCP Client SDK Directly + +```python +import asyncio +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client + +async def call_tool(): + # Configure server parameters + server_params = StdioServerParameters( + command="tigergraph-mcp", + args=["-vv"], # Enable debug logging + env=None, # Uses .env file or environment variables + ) + + async with 
stdio_client(server_params) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + + # List available tools + tools = await session.list_tools() + print(f"Available tools: {[t.name for t in tools.tools]}") + + # Call a tool + result = await session.call_tool( + "tigergraph__list_graphs", + arguments={} + ) + + # Print result + for content in result.content: + print(content.text) + +asyncio.run(call_tool()) ``` -This sets the connection as the default for MCP tools. Note that MCP tools use async APIs internally, so using `AsyncTigerGraphConnection` directly is more efficient. +**Note:** When using `MultiServerMCPClient` or similar MCP clients with stdio transport, the `args` parameter is required. For the `tigergraph-mcp` command (which is a standalone entry point), set `args` to an empty list `[]`. If you need to pass arguments to the command, include them in the list (e.g., `["-v"]` for verbose mode, `["-vv"]` for debug mode). ## Available Tools @@ -228,81 +313,6 @@ These operations work with the schema and objects of a specific graph. - `tigergraph__get_workflow` - Get step-by-step workflow templates for common tasks (e.g., `data_loading`, `schema_creation`, `graph_exploration`) - `tigergraph__get_tool_info` - Get detailed information about a specific tool (parameters, examples, related tools) -## Backward Compatibility - -All existing pyTigerGraph APIs continue to work as before. MCP support is completely optional and does not affect existing code. The MCP functionality is only available when: - -1. The `mcp` extra is installed -2. 
You explicitly use MCP-related imports or methods - -## Example: Using with MCP Clients - -### Using MultiServerMCPClient - -```python -from langchain_mcp_adapters import MultiServerMCPClient -from pathlib import Path -from dotenv import dotenv_values -import asyncio - -# Load environment variables -env_dict = dotenv_values(dotenv_path=Path(".env").expanduser().resolve()) - -# Configure the client -client = MultiServerMCPClient( - { - "tigergraph-mcp": { - "transport": "stdio", - "command": "tigergraph-mcp", - "args": ["-vv"], # Enable debug logging - "env": env_dict, - }, - } -) - -# Get tools and use them -tools = asyncio.run(client.get_tools()) -# Tools are now available for use -``` - -### Using MCP Client SDK Directly - -```python -import asyncio -from mcp import ClientSession, StdioServerParameters -from mcp.client.stdio import stdio_client - -async def call_tool(): - # Configure server parameters - server_params = StdioServerParameters( - command="tigergraph-mcp", - args=["-vv"], # Enable debug logging - env=None, # Uses .env file or environment variables - ) - - async with stdio_client(server_params) as (read, write): - async with ClientSession(read, write) as session: - await session.initialize() - - # List available tools - tools = await session.list_tools() - print(f"Available tools: {[t.name for t in tools.tools]}") - - # Call a tool - result = await session.call_tool( - "tigergraph__list_graphs", - arguments={} - ) - - # Print result - for content in result.content: - print(content.text) - -asyncio.run(call_tool()) -``` - -**Note:** When using `MultiServerMCPClient` or similar MCP clients with stdio transport, the `args` parameter is required. For the `tigergraph-mcp` command (which is a standalone entry point), set `args` to an empty list `[]`. If you need to pass arguments to the command, include them in the list (e.g., `["-v"]` for verbose mode, `["-vv"]` for debug mode). 
-
## LLM-Friendly Features

The MCP server is designed to help AI agents work effectively with TigerGraph.

@@ -355,7 +365,7 @@ Responses are designed for efficient LLM token usage:
- Only returns new information (results, counts, boolean answers)
- Clean text output with no decorative formatting

-## Tool Discovery Workflow
+### Tool Discovery

The MCP server includes discovery tools to help AI agents find the right tool for a task:

@@ -384,10 +394,14 @@ result = await session.call_tool(

## Notes

-- **Async APIs**: All MCP tools use pyTigerGraph's async APIs (`AsyncTigerGraphConnection`) for optimal performance
- **Transport**: The MCP server uses stdio transport by default
-- **Structured Responses**: All tools return structured JSON responses with `success`, `operation`, `summary`, `data`, `suggestions`, and `metadata` fields. Error responses include recovery hints and contextual suggestions
- **Error Detection**: GSQL operations include error detection for syntax and semantic errors (since `conn.gsql()` does not raise Python exceptions for GSQL failures)
-- **Connection Management**: The connection manager automatically creates async connections from environment variables
-- **Performance**: Async APIs for non-blocking I/O; `v.outdegree()` for O(1) degree counting; batch operations for multiple vertices/edges
+- **Connection Management**: Connections are pooled by profile — each profile's `AsyncTigerGraphConnection` holds a persistent HTTP connection pool (sized by `TG_CONN_LIMIT`, default 10). The pool is automatically released at server shutdown via `ConnectionManager.close_all()`. To adjust pool size per profile, set `<PROFILE>_TG_CONN_LIMIT`.
+- **Performance**: Persistent HTTP connection pool per profile (no TCP handshake per request); async non-blocking I/O; `v.outdegree()` for O(1) degree counting; batch operations for multiple vertices/edges
+
+## Backward Compatibility
+All existing pyTigerGraph APIs continue to work as before.
MCP support is completely optional and does not affect existing code. The MCP functionality is only available when: + +1. The `mcp` extra is installed +2. You explicitly use MCP-related imports or methods diff --git a/pyTigerGraph/mcp/connection_manager.py b/pyTigerGraph/mcp/connection_manager.py index f87c02dd..b428ccba 100644 --- a/pyTigerGraph/mcp/connection_manager.py +++ b/pyTigerGraph/mcp/connection_manager.py @@ -8,12 +8,17 @@ """Connection manager for MCP server. Manages AsyncTigerGraphConnection instances for MCP tools. +Supports named connection profiles via environment variables: + + - Default profile uses unprefixed ``TG_*`` vars (backward compatible). + - Named profiles use ``_TG_*`` vars (e.g. ``STAGING_TG_HOST``). + - ``TG_PROFILE`` selects the active profile (default: ``"default"``). """ import os import logging from pathlib import Path -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, List from pyTigerGraph import AsyncTigerGraphConnection from pyTigerGraph.common.exception import TigerGraphException @@ -59,63 +64,125 @@ def _load_env_file(env_path: Optional[str] = None) -> None: logger.warning(f"Specified .env file not found: {env_path}") +def _get_env_for_profile(profile: str, key: str, default: str = "") -> str: + """Resolve a config value for a profile. + + Default profile uses unprefixed ``TG_*`` vars. + Named profiles use ``_TG_*`` vars, falling back to + the unprefixed ``TG_*`` var, then the built-in *default*. + """ + if profile == "default": + return os.getenv(f"TG_{key}", default) + return os.getenv( + f"{profile.upper()}_TG_{key}", + os.getenv(f"TG_{key}", default), + ) + + class ConnectionManager: - """Manages TigerGraph connections for MCP tools.""" + """Manages TigerGraph connections for MCP tools. + + Connections are pooled by ``profile:graph_name`` key so that + repeated calls with the same profile reuse the same connection. 
+ Call ``await ConnectionManager.close_all()`` at server shutdown to release + the persistent HTTP connection pools held by each ``AsyncTigerGraphConnection``. + """ + + _connection_pool: Dict[str, AsyncTigerGraphConnection] = {} + _profiles: set = set() + + # Keep legacy single-connection reference for backward compat _default_connection: Optional[AsyncTigerGraphConnection] = None + @classmethod + def load_profiles(cls, env_path: Optional[str] = None) -> None: + """Discover available profiles from environment variables. + + Profiles are detected by scanning for ``_TG_HOST`` env vars. + The ``"default"`` profile always exists and uses unprefixed ``TG_*`` + vars. Called once at server startup. + """ + _load_env_file(env_path) + + for key in os.environ: + if key.endswith("_TG_HOST") and not key.startswith("TG_"): + profile = key.rsplit("_TG_HOST", 1)[0].lower() + cls._profiles.add(profile) + + cls._profiles.add("default") + logger.info(f"Discovered connection profiles: {sorted(cls._profiles)}") + + @classmethod + def list_profiles(cls) -> List[str]: + """Return sorted list of discovered profile names.""" + if not cls._profiles: + cls._profiles.add("default") + return sorted(cls._profiles) + @classmethod def get_default_connection(cls) -> Optional[AsyncTigerGraphConnection]: - """Get the default connection instance.""" + """Get the default connection instance (backward compat).""" return cls._default_connection @classmethod def set_default_connection(cls, conn: AsyncTigerGraphConnection) -> None: - """Set the default connection instance.""" + """Set the default connection instance (backward compat).""" cls._default_connection = conn @classmethod - def create_connection_from_env(cls, env_path: Optional[str] = None) -> AsyncTigerGraphConnection: - """Create a connection from environment variables. - - Automatically loads variables from a .env file if it exists (requires python-dotenv). - Environment variables take precedence over .env file values. 
- - Reads the following environment variables: - - TG_HOST: TigerGraph host (default: http://127.0.0.1) - - TG_GRAPHNAME: Graph name (optional - can be set later or use list_graphs tool) - - TG_USERNAME: Username (default: tigergraph) - - TG_PASSWORD: Password (default: tigergraph) - - TG_SECRET: GSQL secret (optional) - - TG_API_TOKEN: API token (optional) - - TG_JWT_TOKEN: JWT token (optional) - - TG_RESTPP_PORT: REST++ port (default: 9000) - - TG_GS_PORT: GSQL port (default: 14240) - - TG_SSL_PORT: SSL port (default: 443) - - TG_TGCLOUD: Whether using TigerGraph Cloud (default: False) - - TG_CERT_PATH: Path to certificate (optional) - - Args: - env_path: Optional path to .env file. If not provided, searches for .env in current and parent directories. + async def close_all(cls) -> None: + """Close all pooled connections and release their HTTP sockets. + + Call this at server/application shutdown to drain keep-alive connections + gracefully. Connections are removed from the pool after closing so that + subsequent calls to get_connection_for_profile() create fresh sessions. + + Example: + ```python + # In an MCP server lifespan or FastAPI shutdown event: + await ConnectionManager.close_all() + ``` """ - # Load .env file if available - _load_env_file(env_path) + for conn in list(cls._connection_pool.values()): + await conn.aclose() + cls._connection_pool.clear() + cls._profiles.clear() + cls._default_connection = None + + @classmethod + def get_connection_for_profile( + cls, + profile: str = "default", + graph_name: Optional[str] = None, + ) -> AsyncTigerGraphConnection: + """Get or create a connection for the given profile and optional graph. + + Connections are cached by ``profile`` (or ``profile:graph_name`` when + a graph_name override is given). If a cached connection exists but the + caller passes a different ``graph_name``, the graphname attribute on + the cached connection is updated in place. 
+ """ + cache_key = profile - host = os.getenv("TG_HOST", "http://127.0.0.1") - graphname = os.getenv("TG_GRAPHNAME", "") # Optional - can be empty - username = os.getenv("TG_USERNAME", "tigergraph") - password = os.getenv("TG_PASSWORD", "tigergraph") - gsql_secret = os.getenv("TG_SECRET", "") - api_token = os.getenv("TG_API_TOKEN", "") - jwt_token = os.getenv("TG_JWT_TOKEN", "") - restpp_port = os.getenv("TG_RESTPP_PORT", "9000") - gs_port = os.getenv("TG_GS_PORT", "14240") - ssl_port = os.getenv("TG_SSL_PORT", "443") - tg_cloud = os.getenv("TG_TGCLOUD", "false").lower() == "true" - cert_path = os.getenv("TG_CERT_PATH", None) - - # TG_GRAPHNAME is now optional - can be set later or use list_graphs tool + if cache_key in cls._connection_pool: + conn = cls._connection_pool[cache_key] + if graph_name and conn.graphname != graph_name: + conn.graphname = graph_name + return conn + host = _get_env_for_profile(profile, "HOST", "http://127.0.0.1") + graphname = graph_name or _get_env_for_profile(profile, "GRAPHNAME", "") + username = _get_env_for_profile(profile, "USERNAME", "tigergraph") + password = _get_env_for_profile(profile, "PASSWORD", "tigergraph") + gsql_secret = _get_env_for_profile(profile, "SECRET", "") + api_token = _get_env_for_profile(profile, "API_TOKEN", "") + jwt_token = _get_env_for_profile(profile, "JWT_TOKEN", "") + restpp_port = _get_env_for_profile(profile, "RESTPP_PORT", "9000") + gs_port = _get_env_for_profile(profile, "GS_PORT", "14240") + ssl_port = _get_env_for_profile(profile, "SSL_PORT", "443") + tg_cloud = _get_env_for_profile(profile, "TGCLOUD", "false").lower() == "true" + cert_path = _get_env_for_profile(profile, "CERT_PATH", "") or None conn = AsyncTigerGraphConnection( host=host, graphname=graphname, @@ -131,24 +198,74 @@ def create_connection_from_env(cls, env_path: Optional[str] = None) -> AsyncTige certPath=cert_path, ) - cls._default_connection = conn + cls._connection_pool[cache_key] = conn + + if profile == "default": + 
cls._default_connection = conn + + logger.info(f"Created connection for profile '{profile}' -> {host}") return conn + @classmethod + def get_profile_info(cls, profile: str = "default") -> Dict[str, str]: + """Return non-sensitive connection info for a profile. + + Never includes password, secret, or tokens. + """ + return { + "profile": profile, + "host": _get_env_for_profile(profile, "HOST", "http://127.0.0.1"), + "graphname": _get_env_for_profile(profile, "GRAPHNAME", ""), + "username": _get_env_for_profile(profile, "USERNAME", "tigergraph"), + "restpp_port": _get_env_for_profile(profile, "RESTPP_PORT", "9000"), + "gs_port": _get_env_for_profile(profile, "GS_PORT", "14240"), + "tgcloud": _get_env_for_profile(profile, "TGCLOUD", "false"), + } + + @classmethod + def create_connection_from_env(cls, env_path: Optional[str] = None) -> AsyncTigerGraphConnection: + """Create a connection from environment variables (backward compat). + + Equivalent to ``get_connection_for_profile("default")``. + """ + _load_env_file(env_path) + return cls.get_connection_for_profile("default") + + @classmethod + async def close_all(cls) -> None: + """Close all pooled connections and release their HTTP connection pools. + + Call at server shutdown to cleanly drain open sockets held by the + persistent ``aiohttp.ClientSession`` inside each connection. + """ + for key, conn in list(cls._connection_pool.items()): + try: + await conn.aclose() + logger.debug(f"Closed connection for profile '{key}'") + except Exception as e: + logger.warning(f"Error closing connection '{key}': {e}") + cls._connection_pool.clear() + cls._default_connection = None + def get_connection( + profile: Optional[str] = None, graph_name: Optional[str] = None, connection_config: Optional[Dict[str, Any]] = None, ) -> AsyncTigerGraphConnection: """Get or create an async TigerGraph connection. Args: - graph_name: Name of the graph. If provided, will create a new connection. - connection_config: Connection configuration dict. 
If provided, will create a new connection. + profile: Connection profile name. Falls back to ``TG_PROFILE`` env var, + then ``"default"``. + graph_name: Graph name override. If provided, updates the connection's + active graph. + connection_config: Explicit connection config dict. If provided, creates + a one-off connection (not pooled). Returns: AsyncTigerGraphConnection instance. """ - # If connection config is provided, create a new connection if connection_config: return AsyncTigerGraphConnection( host=connection_config.get("host", "http://127.0.0.1"), @@ -165,20 +282,5 @@ def get_connection( certPath=connection_config.get("certPath", None), ) - # If graph_name is provided, try to get/create connection for that graph - if graph_name: - # For now, use default connection but set graphname - conn = ConnectionManager.get_default_connection() - if conn is None: - conn = ConnectionManager.create_connection_from_env() - # Update graphname if different - if conn.graphname != graph_name: - conn.graphname = graph_name - return conn - - # Return default connection or create from env - conn = ConnectionManager.get_default_connection() - if conn is None: - conn = ConnectionManager.create_connection_from_env() - return conn - + effective_profile = profile or os.getenv("TG_PROFILE", "default") + return ConnectionManager.get_connection_for_profile(effective_profile, graph_name) diff --git a/pyTigerGraph/mcp/main.py b/pyTigerGraph/mcp/main.py index 75056af4..45cb54db 100644 --- a/pyTigerGraph/mcp/main.py +++ b/pyTigerGraph/mcp/main.py @@ -38,13 +38,9 @@ def main(verbose: bool, env_file: Path = None) -> None: # Ensure mcp.server.lowlevel.server respects the WARNING level logging.getLogger('mcp.server.lowlevel.server').setLevel(logging.WARNING) - # Load .env file (automatically searches if not specified) - from .connection_manager import _load_env_file - if env_file: - _load_env_file(str(env_file)) - else: - # Automatically search for .env file - _load_env_file() + # Load .env 
file and discover connection profiles + from .connection_manager import ConnectionManager + ConnectionManager.load_profiles(env_path=str(env_file) if env_file else None) asyncio.run(serve()) diff --git a/pyTigerGraph/mcp/server.py b/pyTigerGraph/mcp/server.py index 21ed936c..8f9745d5 100644 --- a/pyTigerGraph/mcp/server.py +++ b/pyTigerGraph/mcp/server.py @@ -17,6 +17,9 @@ from pyTigerGraph.common.exception import TigerGraphException from .tools import ( get_all_tools, + # Connection profile operations + list_connections, + show_connection, # Global schema operations (database level) get_global_schema, # Graph operations (database level) @@ -117,6 +120,11 @@ async def call_tool(name: str, arguments: Dict) -> List[TextContent]: """Handle tool calls.""" try: match name: + # Connection profile operations + case TigerGraphToolName.LIST_CONNECTIONS: + return await list_connections(**arguments) + case TigerGraphToolName.SHOW_CONNECTION: + return await show_connection(**arguments) # Global schema operations (database level) case TigerGraphToolName.GET_GLOBAL_SCHEMA: return await get_global_schema(**arguments) @@ -266,8 +274,12 @@ async def call_tool(name: str, arguments: Dict) -> List[TextContent]: async def serve() -> None: """Serve the MCP server.""" + from .connection_manager import ConnectionManager server = MCPServer() options = server.server.create_initialization_options() - async with stdio_server() as (read_stream, write_stream): - await server.server.run(read_stream, write_stream, options, raise_exceptions=True) + try: + async with stdio_server() as (read_stream, write_stream): + await server.server.run(read_stream, write_stream, options, raise_exceptions=True) + finally: + await ConnectionManager.close_all() diff --git a/pyTigerGraph/mcp/tool_names.py b/pyTigerGraph/mcp/tool_names.py index 2ca7ee81..506296a6 100644 --- a/pyTigerGraph/mcp/tool_names.py +++ b/pyTigerGraph/mcp/tool_names.py @@ -95,6 +95,10 @@ class TigerGraphToolName(str, Enum): 
DROP_ALL_DATA_SOURCES = "tigergraph__drop_all_data_sources" PREVIEW_SAMPLE_DATA = "tigergraph__preview_sample_data" + # Connection Profile Operations + LIST_CONNECTIONS = "tigergraph__list_connections" + SHOW_CONNECTION = "tigergraph__show_connection" + # Discovery and Navigation Operations DISCOVER_TOOLS = "tigergraph__discover_tools" GET_WORKFLOW = "tigergraph__get_workflow" diff --git a/pyTigerGraph/mcp/tools/__init__.py b/pyTigerGraph/mcp/tools/__init__.py index 3922142e..ae3c4ec9 100644 --- a/pyTigerGraph/mcp/tools/__init__.py +++ b/pyTigerGraph/mcp/tools/__init__.py @@ -7,6 +7,12 @@ """MCP tools for TigerGraph.""" +from .connection_tools import ( + list_connections_tool, + show_connection_tool, + list_connections, + show_connection, +) from .schema_tools import ( # Global schema operations (database level) get_global_schema_tool, @@ -157,6 +163,11 @@ from .tool_registry import get_all_tools __all__ = [ + # Connection profile operations + "list_connections_tool", + "show_connection_tool", + "list_connections", + "show_connection", # Global schema operations (database level) "get_global_schema_tool", "get_global_schema", diff --git a/pyTigerGraph/mcp/tools/connection_tools.py b/pyTigerGraph/mcp/tools/connection_tools.py new file mode 100644 index 00000000..6ada7d82 --- /dev/null +++ b/pyTigerGraph/mcp/tools/connection_tools.py @@ -0,0 +1,104 @@ +# Copyright 2025 TigerGraph Inc. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 +# +# Permission is granted to use, copy, modify, and distribute this software +# under the License. The software is provided "AS IS", without warranty. + +"""Connection profile tools for MCP. + +Allows agents to list available connection profiles and inspect +non-sensitive connection details for a given profile. 
+""" + +from typing import List, Optional +from pydantic import BaseModel, Field +from mcp.types import Tool, TextContent + +from ..tool_names import TigerGraphToolName +from ..connection_manager import ConnectionManager +from ..response_formatter import format_success, format_error + + +class ListConnectionsToolInput(BaseModel): + """Input schema for listing available connection profiles.""" + + +class ShowConnectionToolInput(BaseModel): + """Input schema for showing connection details.""" + profile: Optional[str] = Field( + None, + description=( + "Connection profile name to inspect. " + "If not provided, shows the active profile (from TG_PROFILE env var or 'default')." + ), + ) + + +list_connections_tool = Tool( + name=TigerGraphToolName.LIST_CONNECTIONS, + description=( + "List all available TigerGraph connection profiles. " + "Profiles are configured via environment variables: " + "the default profile uses TG_HOST, TG_USERNAME, etc., " + "while named profiles use _TG_HOST, _TG_USERNAME, etc." + ), + inputSchema=ListConnectionsToolInput.model_json_schema(), +) + +show_connection_tool = Tool( + name=TigerGraphToolName.SHOW_CONNECTION, + description=( + "Show non-sensitive connection details for a specific profile " + "(host, username, graph name, ports). Never reveals passwords or tokens." 
+ ), + inputSchema=ShowConnectionToolInput.model_json_schema(), +) + + +async def list_connections() -> List[TextContent]: + """List all available connection profiles.""" + try: + profiles = ConnectionManager.list_profiles() + profile_details = [] + for p in profiles: + info = ConnectionManager.get_profile_info(p) + profile_details.append(info) + + return format_success( + operation="list_connections", + summary=f"Found {len(profiles)} connection profile(s): {', '.join(profiles)}", + data={"profiles": profile_details, "count": len(profiles)}, + suggestions=[ + "Show details: show_connection(profile='')", + "Use a profile: pass profile='' to any tool", + ], + ) + except Exception as e: + return format_error( + operation="list_connections", + error=str(e), + ) + + +async def show_connection(profile: Optional[str] = None) -> List[TextContent]: + """Show non-sensitive connection details for a profile.""" + try: + import os + effective = profile or os.getenv("TG_PROFILE", "default") + info = ConnectionManager.get_profile_info(effective) + + return format_success( + operation="show_connection", + summary=f"Connection profile '{effective}': {info['host']}", + data=info, + suggestions=[ + "List all profiles: list_connections()", + f"Use this profile: pass profile='{effective}' to any tool", + ], + ) + except Exception as e: + return format_error( + operation="show_connection", + error=str(e), + ) diff --git a/pyTigerGraph/mcp/tools/data_tools.py b/pyTigerGraph/mcp/tools/data_tools.py index 77b589c1..cd812e9f 100644 --- a/pyTigerGraph/mcp/tools/data_tools.py +++ b/pyTigerGraph/mcp/tools/data_tools.py @@ -71,6 +71,7 @@ class FileConfig(BaseModel): class CreateLoadingJobToolInput(BaseModel): """Input schema for creating a loading job.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. 
Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") job_name: str = Field(..., description="Name for the loading job.") files: List[FileConfig] = Field( @@ -87,6 +88,7 @@ class CreateLoadingJobToolInput(BaseModel): class RunLoadingJobWithFileToolInput(BaseModel): """Input schema for running a loading job with a file.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") file_path: str = Field(..., description="Absolute path to the data file to load. Example: '/home/user/data/persons.csv'") file_tag: str = Field(..., description="The name of file variable in the loading job (DEFINE FILENAME ).") @@ -99,6 +101,7 @@ class RunLoadingJobWithFileToolInput(BaseModel): class RunLoadingJobWithDataToolInput(BaseModel): """Input schema for running a loading job with inline data.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") data: str = Field(..., description="The data string to load (CSV, JSON, etc.). Example: 'user1,Alice\\nuser2,Bob'") file_tag: str = Field(..., description="The name of file variable in the loading job (DEFINE FILENAME ).") @@ -111,17 +114,20 @@ class RunLoadingJobWithDataToolInput(BaseModel): class GetLoadingJobsToolInput(BaseModel): """Input schema for listing loading jobs.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. 
Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") class GetLoadingJobStatusToolInput(BaseModel): """Input schema for getting loading job status.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") job_id: str = Field(..., description="The ID of the loading job to check status.") class DropLoadingJobToolInput(BaseModel): """Input schema for dropping a loading job.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") job_name: str = Field(..., description="The name of the loading job to drop.") @@ -284,11 +290,12 @@ async def create_loading_job( files: List[Dict[str, Any]], run_job: bool = False, drop_after_run: bool = False, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Create a loading job from structured configuration.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) # Generate the GSQL script gsql_script = _generate_loading_job_gsql( @@ -374,11 +381,12 @@ async def run_loading_job_with_file( eol: Optional[str] = None, timeout: int = 16000, size_limit: int = 128000000, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Execute a loading job with a data file.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) result = await 
conn.runLoadingJobWithFile( filePath=file_path, fileTag=file_tag, @@ -441,11 +449,12 @@ async def run_loading_job_with_data( eol: Optional[str] = None, timeout: int = 16000, size_limit: int = 128000000, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Execute a loading job with inline data string.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) result = await conn.runLoadingJobWithData( data=data, fileTag=file_tag, @@ -503,11 +512,12 @@ async def run_loading_job_with_data( async def get_loading_jobs( + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Get a list of all loading jobs for the current graph.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) result = await conn.getLoadingJobs() if result: job_count = len(result) if isinstance(result, list) else 1 @@ -545,11 +555,12 @@ async def get_loading_jobs( async def get_loading_job_status( job_id: str, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Get the status of a specific loading job.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) result = await conn.getLoadingJobStatus(jobId=job_id) if result: return format_success( @@ -591,11 +602,12 @@ async def get_loading_job_status( async def drop_loading_job( job_name: str, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Drop a loading job from the graph.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) result = await conn.dropLoadingJob(jobName=job_name) return format_success( diff --git a/pyTigerGraph/mcp/tools/datasource_tools.py b/pyTigerGraph/mcp/tools/datasource_tools.py index c8d0b46a..c546e649 100644 --- 
a/pyTigerGraph/mcp/tools/datasource_tools.py +++ b/pyTigerGraph/mcp/tools/datasource_tools.py @@ -17,6 +17,7 @@ class CreateDataSourceToolInput(BaseModel): """Input schema for creating a data source.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") data_source_name: str = Field(..., description="Name of the data source.") data_source_type: str = Field(..., description="Type of data source: 's3', 'gcs', 'azure_blob', or 'local'.") config: Dict[str, Any] = Field(..., description="Configuration for the data source (e.g., bucket, credentials).") @@ -24,32 +25,37 @@ class CreateDataSourceToolInput(BaseModel): class UpdateDataSourceToolInput(BaseModel): """Input schema for updating a data source.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") data_source_name: str = Field(..., description="Name of the data source to update.") config: Dict[str, Any] = Field(..., description="Updated configuration for the data source.") class GetDataSourceToolInput(BaseModel): """Input schema for getting a data source.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") data_source_name: str = Field(..., description="Name of the data source.") class DropDataSourceToolInput(BaseModel): """Input schema for dropping a data source.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. 
Use 'list_connections' to see available profiles.") data_source_name: str = Field(..., description="Name of the data source to drop.") class GetAllDataSourcesToolInput(BaseModel): """Input schema for getting all data sources.""" - # No parameters needed - returns all data sources + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") class DropAllDataSourcesToolInput(BaseModel): """Input schema for dropping all data sources.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") confirm: bool = Field(False, description="Must be True to confirm dropping all data sources.") class PreviewSampleDataToolInput(BaseModel): """Input schema for previewing sample data.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. 
Use 'list_connections' to see available profiles.") data_source_name: str = Field(..., description="Name of the data source.") file_path: str = Field(..., description="Path to the file within the data source.") num_rows: int = Field(10, description="Number of sample rows to preview.") @@ -103,12 +109,13 @@ async def create_data_source( data_source_name: str, data_source_type: str, config: Dict[str, Any], + profile: Optional[str] = None, ) -> List[TextContent]: """Create a new data source.""" from ..response_formatter import format_success, format_error, gsql_has_error try: - conn = get_connection() + conn = get_connection(profile=profile) config_str = ", ".join([f'{k}="{v}"' for k, v in config.items()]) @@ -146,12 +153,13 @@ async def create_data_source( async def update_data_source( data_source_name: str, config: Dict[str, Any], + profile: Optional[str] = None, ) -> List[TextContent]: """Update an existing data source.""" from ..response_formatter import format_success, format_error, gsql_has_error try: - conn = get_connection() + conn = get_connection(profile=profile) config_str = ", ".join([f'{k}="{v}"' for k, v in config.items()]) gsql_cmd = f"ALTER DATA_SOURCE {data_source_name} = ({config_str})" @@ -181,12 +189,13 @@ async def update_data_source( async def get_data_source( data_source_name: str, + profile: Optional[str] = None, ) -> List[TextContent]: """Get information about a data source.""" from ..response_formatter import format_success, format_error, gsql_has_error try: - conn = get_connection() + conn = get_connection(profile=profile) result = await conn.gsql(f"SHOW DATA_SOURCE {data_source_name}") result_str = str(result) if result else "" @@ -213,12 +222,13 @@ async def get_data_source( async def drop_data_source( data_source_name: str, + profile: Optional[str] = None, ) -> List[TextContent]: """Drop a data source.""" from ..response_formatter import format_success, format_error, gsql_has_error try: - conn = get_connection() + conn = 
get_connection(profile=profile) result = await conn.gsql(f"DROP DATA_SOURCE {data_source_name}") result_str = str(result) if result else "" @@ -245,12 +255,15 @@ async def drop_data_source( ) -async def get_all_data_sources(**kwargs) -> List[TextContent]: +async def get_all_data_sources( + profile: Optional[str] = None, + **kwargs, +) -> List[TextContent]: """Get all data sources.""" from ..response_formatter import format_success, format_error, gsql_has_error try: - conn = get_connection() + conn = get_connection(profile=profile) result = await conn.gsql("SHOW DATA_SOURCE *") result_str = str(result) if result else "" @@ -277,6 +290,7 @@ async def get_all_data_sources(**kwargs) -> List[TextContent]: async def drop_all_data_sources( + profile: Optional[str] = None, confirm: bool = False, ) -> List[TextContent]: """Drop all data sources.""" @@ -294,7 +308,7 @@ async def drop_all_data_sources( ) try: - conn = get_connection() + conn = get_connection(profile=profile) result = await conn.gsql("DROP DATA_SOURCE *") result_str = str(result) if result else "" @@ -324,13 +338,14 @@ async def preview_sample_data( data_source_name: str, file_path: str, num_rows: int = 10, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Preview sample data from a file.""" from ..response_formatter import format_success, format_error, gsql_has_error try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) gsql_cmd = ( f"USE GRAPH {conn.graphname}\n" diff --git a/pyTigerGraph/mcp/tools/edge_tools.py b/pyTigerGraph/mcp/tools/edge_tools.py index d2e0f2f3..947e5855 100644 --- a/pyTigerGraph/mcp/tools/edge_tools.py +++ b/pyTigerGraph/mcp/tools/edge_tools.py @@ -20,6 +20,7 @@ class AddEdgeToolInput(BaseModel): """Input schema for adding an edge.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. 
Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") source_vertex_type: str = Field(..., description="Type of the source vertex.") source_vertex_id: Union[str, int] = Field(..., description="ID of the source vertex.") @@ -31,6 +32,7 @@ class AddEdgeToolInput(BaseModel): class AddEdgesToolInput(BaseModel): """Input schema for adding multiple edges.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") edge_type: str = Field(..., description="Type of the edges.") edges: List[Dict[str, Any]] = Field( @@ -47,6 +49,7 @@ class AddEdgesToolInput(BaseModel): class GetEdgeToolInput(BaseModel): """Input schema for getting an edge.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") source_vertex_type: str = Field(..., description="Type of the source vertex.") source_vertex_id: Union[str, int] = Field(..., description="ID of the source vertex.") @@ -57,6 +60,7 @@ class GetEdgeToolInput(BaseModel): class GetEdgesToolInput(BaseModel): """Input schema for getting multiple edges.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") source_vertex_type: Optional[str] = Field(None, description="Type of the source vertex. 
If not provided, gets all types.") source_vertex_id: Optional[Union[str, int]] = Field(None, description="ID of the source vertex. If not provided, gets all edges.") @@ -66,6 +70,7 @@ class GetEdgesToolInput(BaseModel): class DeleteEdgeToolInput(BaseModel): """Input schema for deleting an edge.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") source_vertex_type: str = Field(..., description="Type of the source vertex.") source_vertex_id: Union[str, int] = Field(..., description="ID of the source vertex.") @@ -76,6 +81,7 @@ class DeleteEdgeToolInput(BaseModel): class DeleteEdgesToolInput(BaseModel): """Input schema for deleting multiple edges.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") edge_type: str = Field(..., description="Type of the edges.") edges: List[Dict[str, Any]] = Field(..., description="List of edges with source and target vertex IDs.") @@ -83,6 +89,7 @@ class DeleteEdgesToolInput(BaseModel): class HasEdgeToolInput(BaseModel): """Input schema for checking if an edge exists.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. 
If not provided, uses default connection.") source_vertex_type: str = Field(..., description="Type of the source vertex.") source_vertex_id: Union[str, int] = Field(..., description="ID of the source vertex.") @@ -329,11 +336,12 @@ async def add_edge( target_vertex_type: str, target_vertex_id: Union[str, int], attributes: Optional[Dict[str, Any]] = None, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Add an edge to the graph.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) await conn.upsertEdge( source_vertex_type, str(source_vertex_id), @@ -372,11 +380,12 @@ async def add_edge( async def add_edges( edge_type: str, edges: List[Dict[str, Any]], + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Add multiple edges to the graph.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) # Convert edges list to format expected by upsertEdges: [(source_id, target_id, {attributes}), ...] 
# Note: upsertEdges requires all edges to have the same source/target vertex types if not edges: @@ -428,11 +437,12 @@ async def get_edge( edge_type: str, target_vertex_type: str, target_vertex_id: Union[str, int], + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Get an edge from the graph.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) result = await conn.getEdges( source_vertex_type, str(source_vertex_id), @@ -486,11 +496,12 @@ async def get_edges( source_vertex_id: Optional[Union[str, int]] = None, edge_type: Optional[str] = None, limit: Optional[int] = None, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Get multiple edges from the graph.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) if source_vertex_id and source_vertex_type: result = await conn.getEdges(source_vertex_type, str(source_vertex_id), edge_type, limit=limit) else: @@ -537,11 +548,12 @@ async def delete_edge( edge_type: str, target_vertex_type: str, target_vertex_id: Union[str, int], + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Delete an edge from the graph.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) await conn.delEdges( sourceVertexType=source_vertex_type, sourceVertexId=str(source_vertex_id), @@ -579,11 +591,12 @@ async def delete_edge( async def delete_edges( edge_type: str, edges: List[Dict[str, Any]], + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Delete multiple edges from the graph.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) deleted_count = 0 # Delete edges one by one for e in edges: @@ -635,11 +648,12 @@ async def 
has_edge( edge_type: str, target_vertex_type: str, target_vertex_id: Union[str, int], + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Check if an edge exists in the graph.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) try: result = await conn.getEdges( diff --git a/pyTigerGraph/mcp/tools/gsql_tools.py b/pyTigerGraph/mcp/tools/gsql_tools.py index 9d6e3771..c918b8a9 100644 --- a/pyTigerGraph/mcp/tools/gsql_tools.py +++ b/pyTigerGraph/mcp/tools/gsql_tools.py @@ -76,12 +76,14 @@ def get_llm_config() -> Tuple[str, str]: class GSQLToolInput(BaseModel): """Input schema for running GSQL command.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") command: str = Field(..., description="GSQL command to execute.") class GenerateGSQLToolInput(BaseModel): """Input schema for generating GSQL from natural language.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") query_description: str = Field( ..., description="A natural language description of what data you want to retrieve. Examples: 'Find all users who purchased more than 5 items', 'Count vertices by type', 'Find shortest path between two nodes'" @@ -94,6 +96,7 @@ class GenerateGSQLToolInput(BaseModel): class GenerateCypherToolInput(BaseModel): """Input schema for generating Cypher from natural language.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. 
Use 'list_connections' to see available profiles.") query_description: str = Field( ..., description="A natural language description of what data you want to retrieve. Examples: 'Find all users who purchased more than 5 items', 'Find friends of friends', 'Match patterns in the graph'" @@ -331,13 +334,14 @@ class GenerateCypherToolInput(BaseModel): async def gsql( command: str, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Execute a GSQL command.""" from ..response_formatter import format_success, format_error, gsql_has_error try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) result = await conn.gsql(command) result_str = str(result) if result else "" @@ -367,6 +371,7 @@ async def gsql( async def generate_gsql( query_description: str, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Generate a GSQL query from natural language description using LangChain's init_chat_model. @@ -377,6 +382,7 @@ async def generate_gsql( Args: query_description: Natural language description of the query to generate. + profile: Optional connection profile name. graph_name: Optional graph name to fetch schema for better query generation. Returns: @@ -416,7 +422,7 @@ async def generate_gsql( schema_section = "## Graph Schema\n\nNo schema information available. Generate a generic GSQL query based on the request." if graph_name: try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) schema = await conn.getSchema() if schema: schema_section = f"## Graph Schema\n\n{schema}" @@ -456,7 +462,8 @@ async def generate_gsql( async def generate_cypher( query_description: str, - graph_name: str, + profile: Optional[str] = None, + graph_name: str = None, ) -> List[TextContent]: """Generate an openCypher query from natural language description using LangChain's init_chat_model. 
@@ -468,6 +475,7 @@ async def generate_cypher( Args: query_description: Natural language description of the query to generate. + profile: Optional connection profile name. graph_name: Name of the graph (required for INTERPRET OPENCYPHER QUERY wrapper). Returns: @@ -506,7 +514,7 @@ async def generate_cypher( # Get schema for the graph schema_section = "## Graph Schema\n\nNo schema information available. Generate a generic Cypher query based on the request." try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) schema = await conn.getSchema() if schema: schema_section = f"## Graph Schema\n\n{schema}" diff --git a/pyTigerGraph/mcp/tools/node_tools.py b/pyTigerGraph/mcp/tools/node_tools.py index c65cc1eb..a8c7df75 100644 --- a/pyTigerGraph/mcp/tools/node_tools.py +++ b/pyTigerGraph/mcp/tools/node_tools.py @@ -25,6 +25,7 @@ class AddNodeToolInput(BaseModel): """Input schema for adding a node.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field( None, description=( @@ -131,12 +132,13 @@ async def add_node( vertex_type: str, vertex_id: Union[str, int], attributes: Optional[Dict[str, Any]] = None, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Add a node to the graph with enhanced error handling and suggestions.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) # Perform the upsert await conn.upsertVertex(vertex_type, str(vertex_id), attributes or {}) @@ -184,6 +186,7 @@ async def add_node( class AddNodesToolInput(BaseModel): """Input schema for adding multiple nodes.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. 
Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field( None, description="Name of the graph. If not provided, uses default connection." @@ -283,12 +286,13 @@ async def add_nodes( vertex_type: str, vertices: List[Dict[str, Any]], vertex_id: str = "id", + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Add multiple nodes to the graph with progress tracking.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) # Convert vertices list to format expected by upsertVertices vertex_data = [] @@ -378,6 +382,7 @@ async def add_nodes( class GetNodeToolInput(BaseModel): """Input schema for getting a node.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph.") vertex_type: str = Field(..., description="Type of the vertex to retrieve.") vertex_id: Union[str, int] = Field(..., description="ID of the vertex to retrieve.") @@ -414,11 +419,12 @@ class GetNodeToolInput(BaseModel): async def get_node( vertex_type: str, vertex_id: Union[str, int], + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Get a node from the graph.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) # Use getVerticesById instead of getVertices with WHERE clause result = await conn.getVerticesById(vertex_type, vertex_id) @@ -486,6 +492,7 @@ class GetNodesToolInput(BaseModel): default=None, description="Sort results by attribute. Use '-' prefix for descending (e.g., '-age')" ) + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. 
Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field( default=None, description="Name of the graph to query (uses default if not specified)" @@ -529,11 +536,12 @@ async def get_nodes( where: Optional[str] = None, limit: Optional[int] = 100, sort: Optional[str] = None, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Get multiple nodes from the graph with optional filtering.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) # Build query parameters kwargs = { @@ -590,6 +598,7 @@ class DeleteNodeToolInput(BaseModel): vertex_id: Union[str, int] = Field( description="The unique identifier of the vertex to delete" ) + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field( default=None, description="Name of the graph (uses default if not specified)" @@ -628,11 +637,12 @@ class DeleteNodeToolInput(BaseModel): async def delete_node( vertex_type: str, vertex_id: Union[str, int], + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Delete a single node from the graph.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) # Delete using delVerticesById result = await conn.delVerticesById(vertex_type, vertex_id) @@ -696,6 +706,7 @@ class DeleteNodesToolInput(BaseModel): default=None, description="Optional list of specific vertex IDs to delete (alternative to WHERE clause)" ) + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. 
Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field( default=None, description="Name of the graph (uses default if not specified)" @@ -737,11 +748,12 @@ async def delete_nodes( vertex_type: str, where: Optional[str] = None, vertex_ids: Optional[List[Union[str, int]]] = None, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Delete multiple nodes from the graph.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) if vertex_ids: # Delete by specific IDs @@ -794,6 +806,7 @@ class HasNodeToolInput(BaseModel): vertex_id: Union[str, int] = Field( description="The unique identifier of the vertex" ) + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field( default=None, description="Name of the graph (uses default if not specified)" @@ -833,11 +846,12 @@ class HasNodeToolInput(BaseModel): async def has_node( vertex_type: str, vertex_id: Union[str, int], + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Check if a node exists in the graph.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) # Use getVerticesById for existence check result = await conn.getVerticesById(vertex_type, vertex_id) @@ -887,6 +901,7 @@ class GetNodeEdgesToolInput(BaseModel): default=100, description="Maximum number of edges to return (default: 100)" ) + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. 
Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field( default=None, description="Name of the graph (uses default if not specified)" @@ -932,11 +947,12 @@ async def get_node_edges( vertex_id: Union[str, int], edge_type: Optional[str] = None, limit: Optional[int] = 100, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Get all edges connected to a specific node.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) # Get outgoing edges from this vertex edges = await conn.getEdges( diff --git a/pyTigerGraph/mcp/tools/query_tools.py b/pyTigerGraph/mcp/tools/query_tools.py index 16eecc43..cc82e278 100644 --- a/pyTigerGraph/mcp/tools/query_tools.py +++ b/pyTigerGraph/mcp/tools/query_tools.py @@ -20,6 +20,7 @@ class RunQueryToolInput(BaseModel): """Input schema for running an interpreted query.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") query_text: str = Field( ..., @@ -36,6 +37,7 @@ class RunQueryToolInput(BaseModel): class RunInstalledQueryToolInput(BaseModel): """Input schema for running an installed query.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. 
If not provided, uses default connection.") query_name: str = Field(..., description="Name of the installed query.") params: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Query parameters.") @@ -43,36 +45,42 @@ class RunInstalledQueryToolInput(BaseModel): class InstallQueryToolInput(BaseModel): """Input schema for installing a query.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") query_text: str = Field(..., description="GSQL query text to install.") class ShowQueryToolInput(BaseModel): """Input schema for showing a query.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") query_name: str = Field(..., description="Name of the query to show.") class GetQueryMetadataToolInput(BaseModel): """Input schema for getting query metadata.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") query_name: str = Field(..., description="Name of the query.") class DropQueryToolInput(BaseModel): """Input schema for dropping a query.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. 
If not provided, uses default connection.") query_name: str = Field(..., description="Name of the query to drop.") class IsQueryInstalledToolInput(BaseModel): """Input schema for checking if a query is installed.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") query_name: str = Field(..., description="Name of the query to check.") class GetNeighborsToolInput(BaseModel): """Input schema for getting neighbors of a node.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") vertex_type: str = Field(..., description="Type of the source vertex (e.g., 'Person', 'Product').") vertex_id: str = Field(..., description="ID of the source vertex.") @@ -358,6 +366,7 @@ class GetNeighborsToolInput(BaseModel): async def run_query( query_text: str, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Run an interpreted query. @@ -372,7 +381,7 @@ async def run_query( graph_name: Optional graph name. 
""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) # Auto-detect query type from the query text query_upper = query_text.strip().upper() @@ -425,11 +434,12 @@ async def run_query( async def run_installed_query( query_name: str, params: Optional[Dict[str, Any]] = None, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Run an installed query.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) result = await conn.runInstalledQuery(query_name, params or {}) return format_success( @@ -464,11 +474,12 @@ async def run_installed_query( async def install_query( query_text: str, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Install a query.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) result = await conn.gsql(query_text) result_str = str(result) if result else "" @@ -521,11 +532,12 @@ async def install_query( async def show_query( query_name: str, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Show a query.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) result = await conn.showQuery(query_name) return format_success( @@ -555,11 +567,12 @@ async def show_query( async def get_query_metadata( query_name: str, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Get query metadata.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) result = await conn.getQueryMetadata(query_name) return format_success( @@ -589,11 +602,12 @@ async def get_query_metadata( async def drop_query( query_name: str, + profile: Optional[str] = None, graph_name: Optional[str] = None, 
) -> List[TextContent]: """Drop (delete) an installed query.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) result = await conn.gsql(f"DROP QUERY {query_name}") result_str = str(result) if result else "" @@ -641,11 +655,12 @@ async def drop_query( async def is_query_installed( query_name: str, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Check if a query is installed.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) # Try to get query metadata - if it succeeds, the query exists try: result = await conn.getQueryMetadata(query_name) @@ -699,11 +714,12 @@ async def get_neighbors( edge_type: Optional[str] = None, target_vertex_type: Optional[str] = None, limit: Optional[int] = None, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Get neighbor vertices connected to a source vertex.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) # Build the edge pattern edge_pattern = f"(({edge_type}):e)" if edge_type else "(ANY:e)" diff --git a/pyTigerGraph/mcp/tools/schema_tools.py b/pyTigerGraph/mcp/tools/schema_tools.py index d754e84f..e9a24968 100644 --- a/pyTigerGraph/mcp/tools/schema_tools.py +++ b/pyTigerGraph/mcp/tools/schema_tools.py @@ -24,7 +24,7 @@ class GetGlobalSchemaToolInput(BaseModel): """Input schema for getting the global schema (all global vertex/edge types, graphs, etc.).""" - # No parameters needed - returns full global schema via GSQL LS command + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. 
Use 'list_connections' to see available profiles.") # ============================================================================= @@ -33,11 +33,12 @@ class GetGlobalSchemaToolInput(BaseModel): class ListGraphsToolInput(BaseModel): """Input schema for listing all graph names in the database.""" - # No parameters needed - lists all graph names in the database + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") class CreateGraphToolInput(BaseModel): """Input schema for creating a new graph with its schema.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: str = Field(..., description="Name of the new graph to create.") vertex_types: List[Dict[str, Any]] = Field(..., description="List of vertex type definitions for this graph.") edge_types: List[Dict[str, Any]] = Field(default_factory=list, description="List of edge type definitions for this graph.") @@ -45,11 +46,13 @@ class CreateGraphToolInput(BaseModel): class DropGraphToolInput(BaseModel): """Input schema for dropping a graph from the database.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: str = Field(..., description="Name of the graph to drop.") class ClearGraphDataToolInput(BaseModel): """Input schema for clearing all data from a graph (keeps schema structure).""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. 
If not provided, uses default connection.") vertex_type: Optional[str] = Field(None, description="Type of vertices to clear. If not provided, clears all data.") confirm: bool = Field(False, description="Must be True to confirm the deletion. This is a destructive operation.") @@ -61,11 +64,13 @@ class ClearGraphDataToolInput(BaseModel): class GetGraphSchemaToolInput(BaseModel): """Input schema for getting a specific graph's schema (raw JSON).""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") class ShowGraphDetailsToolInput(BaseModel): """Input schema for showing details of a graph (schema, queries, jobs).""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. 
If not provided, uses default connection.") detail_type: Optional[str] = Field( None, @@ -347,10 +352,10 @@ class ShowGraphDetailsToolInput(BaseModel): ) -async def get_graph_schema(graph_name: Optional[str] = None) -> List[TextContent]: +async def get_graph_schema(profile: Optional[str] = None, graph_name: Optional[str] = None) -> List[TextContent]: """Get the schema of a specific graph.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) schema = await conn.getSchema() vertex_count = len(schema.get("VertexTypes", [])) @@ -540,8 +545,9 @@ def _build_edge_stmt(etype: Dict[str, Any], keyword: str = "ADD") -> tuple: async def create_graph( - graph_name: str, - vertex_types: List[Dict[str, Any]], + profile: Optional[str] = None, + graph_name: str = None, + vertex_types: List[Dict[str, Any]] = None, edge_types: List[Dict[str, Any]] = None, ) -> List[TextContent]: """Create a new graph with local vertex/edge types via a schema change job. 
@@ -557,7 +563,7 @@ async def create_graph( See: https://docs.tigergraph.com/gsql-ref/4.2/ddl-and-loading/modifying-a-graph-schema """ try: - conn = get_connection() + conn = get_connection(profile=profile) vertex_names: list[str] = [] edge_names: list[str] = [] @@ -685,10 +691,10 @@ async def create_graph( ) -async def drop_graph(graph_name: str) -> List[TextContent]: +async def drop_graph(profile: Optional[str] = None, graph_name: str = None) -> List[TextContent]: """Drop a graph.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) result = await conn.gsql(f"DROP GRAPH {graph_name}") result_str = str(result) if result else "" @@ -725,17 +731,18 @@ async def drop_graph(graph_name: str) -> List[TextContent]: ) -async def get_global_schema(**kwargs) -> List[TextContent]: +async def get_global_schema(profile: Optional[str] = None, **kwargs) -> List[TextContent]: """Get the complete global schema via GSQL LS command. Args: + profile: Connection profile name. **kwargs: No parameters required Returns: Full global schema including global vertex types, edge types, graphs, and their members. """ try: - conn = get_connection() + conn = get_connection(profile=profile) result = await conn.gsql("LS") result_str = str(result) if result else "" @@ -765,17 +772,18 @@ async def get_global_schema(**kwargs) -> List[TextContent]: ) -async def list_graphs(**kwargs) -> List[TextContent]: +async def list_graphs(profile: Optional[str] = None, **kwargs) -> List[TextContent]: """List all graph names in the TigerGraph database. Args: + profile: Connection profile name. **kwargs: No parameters required - lists all graph names in the database Returns: List of graph names only (without detailed schema information). 
""" try: - conn = get_connection() + conn = get_connection(profile=profile) result = await conn.gsql("SHOW GRAPH *") result_str = str(result) if result else "" @@ -838,6 +846,7 @@ async def list_graphs(**kwargs) -> List[TextContent]: async def clear_graph_data( + profile: Optional[str] = None, graph_name: Optional[str] = None, vertex_type: Optional[str] = None, confirm: bool = False, @@ -856,7 +865,7 @@ async def clear_graph_data( ) try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) if vertex_type: # Clear specific vertex type and its connected edges @@ -921,18 +930,20 @@ async def clear_graph_data( async def show_graph_details( + profile: Optional[str] = None, graph_name: Optional[str] = None, detail_type: Optional[str] = None, ) -> List[TextContent]: """Show details of a graph, optionally filtered by category. Args: + profile: Connection profile name. graph_name: Graph to inspect. Uses default if omitted. detail_type: One of 'schema', 'query', 'loading_job', 'data_source'. If omitted, runs ``LS`` to show everything. """ try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) gname = conn.graphname if detail_type and detail_type in _DETAIL_TYPE_COMMANDS: diff --git a/pyTigerGraph/mcp/tools/statistics_tools.py b/pyTigerGraph/mcp/tools/statistics_tools.py index 92017f63..51551fd0 100644 --- a/pyTigerGraph/mcp/tools/statistics_tools.py +++ b/pyTigerGraph/mcp/tools/statistics_tools.py @@ -19,18 +19,21 @@ class GetVertexCountToolInput(BaseModel): """Input schema for getting vertex count.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. 
If not provided, uses default connection.") vertex_type: Optional[str] = Field(None, description="Type of vertices to count. If not provided, counts all types.") class GetEdgeCountToolInput(BaseModel): """Input schema for getting edge count.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") edge_type: Optional[str] = Field(None, description="Type of edges to count. If not provided, counts all types.") class GetNodeDegreeToolInput(BaseModel): """Input schema for getting node degree.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") vertex_type: str = Field(..., description="Type of the vertex.") vertex_id: str = Field(..., description="ID of the vertex.") @@ -58,12 +61,13 @@ class GetNodeDegreeToolInput(BaseModel): async def get_vertex_count( + profile: Optional[str] = None, vertex_type: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Get vertex count.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) if vertex_type: count = await conn.getVertexCount(vertex_type) @@ -125,12 +129,13 @@ async def get_vertex_count( async def get_edge_count( + profile: Optional[str] = None, edge_type: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Get edge count.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) if edge_type: count = await conn.getEdgeCount(edge_type) @@ -189,15 
+194,16 @@ async def get_edge_count( async def get_node_degree( - vertex_type: str, - vertex_id: str, + profile: Optional[str] = None, + vertex_type: str = None, + vertex_id: str = None, edge_type: Optional[str] = None, direction: Optional[str] = "both", graph_name: Optional[str] = None, ) -> List[TextContent]: """Get the degree (number of connected edges) of a node.""" try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) # Build edge type parameter for v.outdegree() # If edge_type contains multiple types separated by |, convert to SET format diff --git a/pyTigerGraph/mcp/tools/tool_registry.py b/pyTigerGraph/mcp/tools/tool_registry.py index 6494395a..9cf67e6b 100644 --- a/pyTigerGraph/mcp/tools/tool_registry.py +++ b/pyTigerGraph/mcp/tools/tool_registry.py @@ -10,6 +10,10 @@ from typing import List from mcp.types import Tool +from .connection_tools import ( + list_connections_tool, + show_connection_tool, +) from .schema_tools import ( # Global schema operations (database level) get_global_schema_tool, @@ -97,6 +101,9 @@ def get_all_tools() -> List[Tool]: List of all MCP tools. """ return [ + # Connection profile operations + list_connections_tool, + show_connection_tool, # Global schema operations (database level) get_global_schema_tool, # Graph operations (database level) diff --git a/pyTigerGraph/mcp/tools/vector_tools.py b/pyTigerGraph/mcp/tools/vector_tools.py index 5ae8c28b..fe707be1 100644 --- a/pyTigerGraph/mcp/tools/vector_tools.py +++ b/pyTigerGraph/mcp/tools/vector_tools.py @@ -30,6 +30,7 @@ class VectorAddAttributeToolInput(BaseModel): """Input schema for adding a vector attribute to a vertex type.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. 
If not provided, uses default connection.") vertex_type: str = Field(..., description="Name of the vertex type to add the vector attribute to.") vector_name: str = Field(..., description="Name of the vector attribute.") @@ -39,6 +40,7 @@ class VectorAddAttributeToolInput(BaseModel): class VectorDropAttributeToolInput(BaseModel): """Input schema for dropping a vector attribute from a vertex type.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") vertex_type: str = Field(..., description="Name of the vertex type.") vector_name: str = Field(..., description="Name of the vector attribute to drop.") @@ -46,12 +48,14 @@ class VectorDropAttributeToolInput(BaseModel): class VectorListAttributesToolInput(BaseModel): """Input schema for listing vector attributes in a graph.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") vertex_type: Optional[str] = Field(None, description="Filter by vertex type. If not provided, returns vector attributes for all vertex types.") class VectorIndexStatusToolInput(BaseModel): """Input schema for checking vector index status.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") vertex_type: Optional[str] = Field(None, description="Vertex type to check. 
If not provided, checks all.") vector_name: Optional[str] = Field(None, description="Vector attribute name. If not provided, checks all.") @@ -63,6 +67,7 @@ class VectorIndexStatusToolInput(BaseModel): class VectorLoadFromCsvToolInput(BaseModel): """Input schema for bulk-loading vectors from a CSV/delimited file via a GSQL loading job.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") vertex_type: str = Field(..., description="Target vertex type that has the vector attribute.") vector_attribute: str = Field(..., description="Name of the vector attribute to load into.") @@ -76,6 +81,7 @@ class VectorLoadFromCsvToolInput(BaseModel): class VectorLoadFromJsonToolInput(BaseModel): """Input schema for bulk-loading vectors from a JSON Lines file via a GSQL loading job.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") vertex_type: str = Field(..., description="Target vertex type that has the vector attribute.") vector_attribute: str = Field(..., description="Name of the vector attribute to load into.") @@ -98,6 +104,7 @@ class VectorData(BaseModel): class VectorUpsertToolInput(BaseModel): """Input schema for upserting multiple vectors via REST API.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. 
If not provided, uses default connection.") vertex_type: str = Field(..., description="Type of the vertices.") vector_attribute: str = Field(..., description="Name of the vector attribute.") @@ -106,6 +113,7 @@ class VectorUpsertToolInput(BaseModel): class VectorSearchToolInput(BaseModel): """Input schema for vector similarity search using vectorSearch() function.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") vertex_type: str = Field(..., description="Type of vertices to search.") vector_attribute: str = Field(..., description="Name of the vector attribute to search.") @@ -117,6 +125,7 @@ class VectorSearchToolInput(BaseModel): class VectorFetchToolInput(BaseModel): """Input schema for fetching vertices with their vector data using GSQL.""" + profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") vertex_type: str = Field(..., description="Type of the vertex.") vertex_ids: List[Union[str, int]] = Field(..., description="List of vertex IDs to fetch.") @@ -294,6 +303,7 @@ async def add_vector_attribute( vector_name: str, dimension: int, metric: str = "COSINE", + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Add a vector attribute to a vertex type. 
@@ -304,7 +314,7 @@ async def add_vector_attribute( from ..response_formatter import format_success, format_error, gsql_has_error try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) gname = conn.graphname metric = metric.upper() @@ -360,6 +370,7 @@ async def add_vector_attribute( async def drop_vector_attribute( vertex_type: str, vector_name: str, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Drop a vector attribute from a vertex type. @@ -370,7 +381,7 @@ async def drop_vector_attribute( from ..response_formatter import format_success, format_error, gsql_has_error try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) gname = conn.graphname is_global = await _is_global_vertex_type(conn, vertex_type) @@ -413,6 +424,7 @@ async def drop_vector_attribute( async def list_vector_attributes( + profile: Optional[str] = None, graph_name: Optional[str] = None, vertex_type: Optional[str] = None, ) -> List[TextContent]: @@ -431,7 +443,7 @@ async def list_vector_attributes( from ..response_formatter import format_success, format_error, gsql_has_error try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) gname = conn.graphname result = await conn.gsql(f"USE GRAPH {gname}\nLS") @@ -532,6 +544,7 @@ async def list_vector_attributes( async def get_vector_index_status( + profile: Optional[str] = None, graph_name: Optional[str] = None, vertex_type: Optional[str] = None, vector_name: Optional[str] = None, @@ -540,7 +553,7 @@ async def get_vector_index_status( from ..response_formatter import format_success, format_error try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) path = f"/vector/status/{conn.graphname}" if vertex_type: @@ -600,13 +613,14 @@ async def upsert_vectors( vertex_type: str, 
vector_attribute: str, vectors: List[Dict[str, Any]], + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Upsert multiple vertices with vector data using REST Upsert API.""" from ..response_formatter import format_success, format_error try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) success_count = 0 failed_ids = [] @@ -669,6 +683,7 @@ async def search_top_k_similarity( top_k: int = 10, ef: Optional[int] = None, return_vectors: bool = False, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Perform vector similarity search using vectorSearch() function. @@ -685,7 +700,7 @@ async def search_top_k_similarity( gname = None try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) gname = conn.graphname # Pre-flight: check query_vector dimension against the attribute definition @@ -806,6 +821,7 @@ async def fetch_vector( vertex_type: str, vertex_ids: List[Union[str, int]], vector_attribute: Optional[str] = None, + profile: Optional[str] = None, graph_name: Optional[str] = None, **kwargs, ) -> List[TextContent]: @@ -827,7 +843,7 @@ async def fetch_vector( gname = None try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) gname = conn.graphname query_name = f"_fetch_vec_{uuid.uuid4().hex[:8]}" @@ -916,6 +932,7 @@ async def load_vectors_from_csv( element_separator: str = ",", field_separator: str = "|", header: bool = False, + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Bulk-load vectors from a local CSV/delimited file using a GSQL loading job. 
@@ -935,7 +952,7 @@ async def load_vectors_from_csv( gname = None try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) gname = conn.graphname file_tag = "vec_file" @@ -1024,6 +1041,7 @@ async def load_vectors_from_json( id_key: str = "id", vector_key: str = "vector", element_separator: str = ",", + profile: Optional[str] = None, graph_name: Optional[str] = None, ) -> List[TextContent]: """Bulk-load vectors from a JSON Lines file using a GSQL loading job with JSON_FILE="true". @@ -1049,7 +1067,7 @@ async def load_vectors_from_json( gname = None try: - conn = get_connection(graph_name=graph_name) + conn = get_connection(profile=profile, graph_name=graph_name) gname = conn.graphname file_tag = "vec_file" diff --git a/pyTigerGraph/pyTigerGraphBase.py b/pyTigerGraph/pyTigerGraphBase.py index 0aa68a74..5c4925bc 100644 --- a/pyTigerGraph/pyTigerGraphBase.py +++ b/pyTigerGraph/pyTigerGraphBase.py @@ -25,8 +25,10 @@ import logging import sys import re +import threading import warnings import requests +from requests.adapters import HTTPAdapter from typing import Union from urllib.parse import urlparse @@ -108,6 +110,18 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "MyGraph", version=version, apiToken=apiToken, useCert=useCert, certPath=certPath, debug=debug, sslPort=sslPort, gcp=gcp, jwtToken=jwtToken) + # Thread-local sessions — each thread gets its own requests.Session and connection pool. + # A single shared Session serializes all threads via its internal cookie-jar RLock + # (_cookies_lock), which is acquired on every response even when no cookies are set. + # Thread-local sessions eliminate that contention while still benefiting from HTTP + # keep-alive within each thread's sequential request stream. + self._local = threading.local() + + # Lock for the one-time port failover (TG 3.x port 9000 → 4.x port 14240/restpp). 
+ # Without a lock, all parallel threads simultaneously fail and all enter the failover + # block, doubling requests and racing to overwrite self.restppUrl / self.restppPort. + self._restpp_failover_lock = threading.Lock() + if graphname == "MyGraph": warnings.warn( "The default graphname 'MyGraph' is deprecated. Please explicitly specify your graph name.", @@ -144,7 +158,26 @@ def _locals(self, _locals: dict) -> str: del _locals["self"] return str(_locals) - logger.debug("exit: __init__") + def _do_request( + self, + method: str, + url: str, + _headers: dict, + _data, + jsonData: bool, + params, + timeout, + ) -> requests.Response: + """Execute one HTTP request and return the response. + + Centralises the jsonData/data branching so _req doesn't duplicate it + for both the primary request and the failover retry. + """ + if jsonData: + return self._session.request( + method, url, headers=_headers, json=_data, params=params, timeout=timeout) + return self._session.request( + method, url, headers=_headers, data=_data, params=params, timeout=timeout) def _req(self, method: str, url: str, authMode: str = "token", headers: dict = None, data: Union[dict, list, str] = None, resKey: str = "results", skipCheck: bool = False, @@ -178,26 +211,21 @@ def _req(self, method: str, url: str, authMode: str = "token", headers: dict = N Returns: The (relevant part of the) response from the request (as a dictionary). 
""" - _headers, _data, verify = self._prep_req(authMode, headers, url, method, data) + _headers, _data, _ = self._prep_req(authMode, headers, url, method, data) if "GSQL-TIMEOUT" in _headers: http_timeout = (30, int(int(_headers["GSQL-TIMEOUT"])/1000) + 30) else: http_timeout = (30, None) - if jsonData: - res = requests.request( - method, url, headers=_headers, json=_data, params=params, verify=verify, timeout=http_timeout) - else: - res = requests.request( - method, url, headers=_headers, data=_data, params=params, verify=verify, timeout=http_timeout) + res = self._do_request(method, url, _headers, _data, jsonData, params, http_timeout) try: - if not skipCheck and not (200 <= res.status_code < 300): + if not skipCheck and not (200 <= res.status_code < 300) and res.status_code != 404: try: - self._error_check(json.loads(res.text)) + self._error_check(json.loads(res.content)) except json.decoder.JSONDecodeError: - # could not parse the res text (probably returned an html response) + # could not parse the response body (probably returned an html response) pass res.raise_for_status() except Exception as e: @@ -207,36 +235,35 @@ def _req(self, method: str, url: str, authMode: str = "token", headers: dict = N # ---- # Changes port to gsql port, adds /restpp to end to url, tries again, saves changes if successful if self.restppPort in url and "/gsql" not in url and ("/restpp" not in url or self.tgCloud): - newRestppUrl = self.host + ":"+self.gsPort+"/restpp" - # In tgcloud /restpp can already be in the restpp url. 
We want to extract everything after the port or /restpp - if self.tgCloud: - url = newRestppUrl + '/' + '/'.join(url.split(':')[2].split('/')[2:]) - else: - url = newRestppUrl + '/' + \ - '/'.join(url.split(':')[2].split('/')[1:]) - if jsonData: - res = requests.request( - method, url, headers=_headers, json=_data, params=params, verify=verify) - else: - res = requests.request( - method, url, headers=_headers, data=_data, params=params, verify=verify) - - # Run error check if there might be an error before raising for status - # raising for status gives less descriptive error message - if not skipCheck and not (200 <= res.status_code < 300) and res.status_code != 404: - try: - self._error_check(json.loads(res.text)) - except json.decoder.JSONDecodeError: - # could not parse the res text (probably returned an html response) - pass - res.raise_for_status() - self.restppUrl = newRestppUrl - self.restppPort = self.gsPort + with self._restpp_failover_lock: + # Re-check inside lock: another thread may have already completed the failover + if self.restppPort in url: + newRestppUrl = self.host + ":" + self.gsPort + "/restpp" + # If /restpp is already in the URL (e.g. 
tgCloud), skip that path segment + if "/restpp" in url: + url = newRestppUrl + "/" + "/".join(url.split(":")[2].split("/")[2:]) + else: + url = newRestppUrl + "/" + "/".join(url.split(":")[2].split("/")[1:]) + res = self._do_request(method, url, _headers, _data, jsonData, params, None) + # Run error check if there might be an error before raising for status + # raising for status gives less descriptive error message + if not skipCheck and not (200 <= res.status_code < 300) and res.status_code != 404: + try: + self._error_check(json.loads(res.content)) + except json.decoder.JSONDecodeError: + # could not parse the response body (probably returned an html response) + pass + res.raise_for_status() + self.restppUrl = newRestppUrl + self.restppPort = self.gsPort else: e.add_note(f"headers: {_headers}") raise e - return self._parse_req(res, jsonResponse, strictJson, skipCheck, resKey) + # Pass raw bytes to _parse_req — avoids chardet encoding detection that res.text triggers + # when Content-Type does not explicitly declare a charset. json.loads and orjson both + # accept bytes natively, so no decode step is needed for JSON responses. + return self._parse_req(res.content, jsonResponse, strictJson, skipCheck, resKey) def _get(self, url: str, authMode: str = "token", headers: dict = None, resKey: str = "results", skipCheck: bool = False, params: Union[dict, list, str] = None, strictJson: bool = True) -> Union[dict, list]: @@ -361,6 +388,59 @@ def _delete(self, url: str, authMode: str = "token", data: dict = None, resKey=" return res + @property + def _session(self) -> requests.Session: + """Return the calling thread's dedicated Session, creating it on first use. + + Uses EAFP (try/except) rather than LBYL (hasattr) so the common case — + session already exists — hits the fast return path without paying the cost + of hasattr's internal getattr+AttributeError machinery. Caching self._local + in a local variable avoids a redundant Python attribute lookup on self. 
+ SSL verify is baked into the session once rather than passed on every request. + """ + local = self._local + try: + return local.session + except AttributeError: + s = requests.Session() + s.verify = self.verify # fixed after __init__; no need to pass per-request + _adapter = HTTPAdapter( + pool_connections=1, # one pool per host (we talk to one TG server) + pool_maxsize=1, # each thread is sequential; only 1 socket needed at a time + max_retries=0, + ) + s.mount("http://", _adapter) + s.mount("https://", _adapter) + local.session = s + return s + + def close(self) -> None: + """Close this thread's HTTP session and release its sockets. + + Each thread maintains its own session, so only the calling thread's + session is affected. Sessions in other threads are closed when those + threads exit and are garbage-collected. + + Prefer using the connection as a context manager so cleanup is automatic: + + ```python + with TigerGraphConnection(...) as conn: + conn.runInstalledQuery(...) + ``` + """ + local = self._local + try: + local.session.close() + del local.session + except AttributeError: + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + def getVersion(self, raw: bool = False) -> Union[str, list]: """Retrieves the git versions of all components of the system. 
diff --git a/pyTigerGraph/pyTigerGraphEdge.py b/pyTigerGraph/pyTigerGraphEdge.py index be5b3039..efe9b38a 100644 --- a/pyTigerGraph/pyTigerGraphEdge.py +++ b/pyTigerGraph/pyTigerGraphEdge.py @@ -18,7 +18,6 @@ _prep_get_edge_count_from, _parse_get_edge_count_from, _prep_upsert_edge, - _dumps, _prep_upsert_edges, _prep_upsert_edge_dataframe, _prep_get_edges, @@ -30,8 +29,7 @@ from pyTigerGraph.common.edge import edgeSetToDataFrame as _eS2DF from pyTigerGraph.common.schema import ( - _get_attr_type, - _upsert_attrs + _get_attr_type ) from pyTigerGraph.pyTigerGraphQuery import pyTigerGraphQuery @@ -461,24 +459,6 @@ def upsertEdge(self, sourceVertexType: str, sourceVertexId: str, edgeType: str, attributes ) - ret = self._req("POST", self.restppUrl + "/graph/" + self.graphname, data=data)[0][ - "accepted_edges"] - - vals = _upsert_attrs(attributes) - data = json.dumps( - { - "edges": { - sourceVertexType: { - sourceVertexId: { - edgeType: {targetVertexType: { - targetVertexId: vals}} - - } - } - } - } - ) - params = {"vertex_must_exist": vertexMustExist} ret = self._post( self.restppUrl + "/graph/" + self.graphname, @@ -557,48 +537,15 @@ def upsertEdges(self, sourceVertexType: str, edgeType: str, targetVertexType: st edgeType=edgeType, targetVertexType=targetVertexType, edges=edges) - header = {} + headers = {} if atomic: - header = {"gsql-atomic-level": "atomic"} - ret = self._req("POST", self.restppUrl + "/graph/" + self.graphname, data=data, headers=header)[0][ - "accepted_edges"] - - data = {sourceVertexType: {}} - l1 = data[sourceVertexType] - for e in edges: - if len(e) > 2: - vals = _upsert_attrs(e[2]) - else: - vals = {} - # sourceVertexId - # Converted to string as the key in the JSON payload must be a string - sourceVertexId = str(e[0]) - if sourceVertexId not in l1: - l1[sourceVertexId] = {} - l2 = l1[sourceVertexId] - # edgeType - if edgeType not in l2: - l2[edgeType] = {} - l3 = l2[edgeType] - # targetVertexType - if targetVertexType not in l3: - 
l3[targetVertexType] = {} - l4 = l3[targetVertexType] - if self.___trgvtxids not in l4: - l4[self.___trgvtxids] = {} - l4 = l4[self.___trgvtxids] - # targetVertexId - # Converted to string as the key in the JSON payload must be a string - targetVertexId = str(e[1]) - if targetVertexId not in l4: - l4[targetVertexId] = [] - l4[targetVertexId].append(vals) - - data = _dumps({"edges": data}) - + headers = {"gsql-atomic-level": "atomic"} params = {"vertex_must_exist": vertexMustExist} ret = self._post( - self.restppUrl + "/graph/" + self.graphname, data=data, params=params + self.restppUrl + "/graph/" + self.graphname, + data=data, + params=params, + headers=headers, )[0]["accepted_edges"] if logger.level == logging.DEBUG: @@ -647,30 +594,13 @@ def upsertEdgeDataFrame(self, df: 'pd.DataFrame', sourceVertexType: str, edgeTyp logger.debug("params: " + self._locals(locals())) json_up = _prep_upsert_edge_dataframe(df, from_id, to_id, attributes) - ret = self.upsertEdges(sourceVertexType, edgeType, - targetVertexType, json_up) - - json_up = [] - - for index in df.index: - json_up.append(json.loads(df.loc[index].to_json())) - json_up[-1] = ( - index if from_id is None else json_up[-1][from_id], - index if to_id is None else json_up[-1][to_id], - json_up[-1] - if attributes is None - else { - target: json_up[-1][source] for target, source in attributes.items() - }, - ) - ret = self.upsertEdges( sourceVertexType, edgeType, targetVertexType, json_up, vertexMustExist=vertexMustExist, - atomic=atomic + atomic=atomic, ) if logger.level == logging.DEBUG: diff --git a/pyTigerGraph/pyTigerGraphQuery.py b/pyTigerGraph/pyTigerGraphQuery.py index 6e0c1699..8e0fa0ce 100644 --- a/pyTigerGraph/pyTigerGraphQuery.py +++ b/pyTigerGraph/pyTigerGraphQuery.py @@ -268,29 +268,28 @@ def listQueryNames(self) -> list: return res - def getInstalledQueries(self, fmt: str = "py") -> Union[dict, str, 'pd.DataFrame']: - """Returns a list of installed queries. 
+ def getInstalledQueries(self, fmt: str = "py") -> Union[dict, str, list, 'pd.DataFrame']: + """Returns installed queries for the graph. + + Only queries that have been installed (i.e., have an active REST endpoint) are returned. Args: fmt: Format of the results: - - "py": Python objects (default) + - "py": Python dict keyed by REST endpoint string (default) - "json": JSON document - "df": pandas DataFrame + - "list": list of query name strings Returns: - The names of the installed queries. - - TODO This function returns all (installed and non-installed) queries - Modify to return only installed ones - TODO Return with query name as key rather than REST endpoint as key? + The installed queries in the requested format. """ logger.debug("entry: getInstalledQueries") if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) ret = self.getEndpoints(dynamic=True) - ret = _parse_get_installed_queries(fmt, ret) + ret = _parse_get_installed_queries(fmt, ret, self.graphname) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -381,7 +380,7 @@ def getQueryInstallationStatus(self, requestId: str) -> dict: return ret def runInstalledQuery(self, queryName: str, params: Union[str, dict] = None, - timeout: int = None, sizeLimit: int = None, usePost: bool = False, runAsync: bool = False, + timeout: int = None, sizeLimit: int = None, usePost: bool = True, runAsync: bool = False, replica: int = None, threadLimit: int = None, memoryLimit: int = None) -> list: """Runs an installed query. @@ -402,8 +401,9 @@ def runInstalledQuery(self, queryName: str, params: Union[str, dict] = None, Maximum size of response (in bytes). See xref:tigergraph-server:API:index.adoc#_response_size[Response size] usePost: - Defaults to False. The RESTPP accepts a maximum URL length of 8192 characters. 
Use POST if additional parameters cause - you to exceed this limit, or if you choose to pass an empty set into a query for database versions >= 3.8 + Defaults to True. Sends query parameters in the POST request body instead of as URL query parameters. + POST is significantly faster than GET when params contain list values (e.g. vectors), avoids the + 8192-character URL length limit, and is required for passing empty sets in database versions >= 3.8. runAsync: Run the query in asynchronous mode. See xref:gsql-ref:querying:query-operations#_detached_mode_async_option[Async operation] @@ -556,6 +556,12 @@ def runInterpretedQuery(self, queryText: str, params: Union[str, dict] = None) - queryText = queryText.replace("$graphname", self.graphname) queryText = queryText.replace("@graphname@", self.graphname) + # Per the TigerGraph API spec, interpreted query params always go in the + # URL query string (the body is reserved for the GSQL query text). + # _parse_query_parameters handles TigerGraph-specific encoding: + # - SET/BAG: repeated keys k=v1&k=v2 + # - VERTEX (no type): k=id&k.type=vtype + # - SET (no type): k[0]=id&k[0].type=vtype&k[1]=... 
if isinstance(params, dict): params = _parse_query_parameters(params) diff --git a/pyTigerGraph/pytgasync/pyTigerGraphAuth.py b/pyTigerGraph/pytgasync/pyTigerGraphAuth.py index 97b66b42..7c454355 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphAuth.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphAuth.py @@ -7,8 +7,6 @@ import logging from typing import Union, Dict import warnings -import httpx - from pyTigerGraph.common.exception import TigerGraphException from pyTigerGraph.common.auth import ( _parse_get_secrets, diff --git a/pyTigerGraph/pytgasync/pyTigerGraphBase.py b/pyTigerGraph/pytgasync/pyTigerGraphBase.py index 6880156f..869c4c09 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphBase.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphBase.py @@ -26,11 +26,12 @@ ``` """ +import asyncio import json import logging -import httpx +import aiohttp -from typing import Union +from typing import Optional, Union from urllib.parse import urlparse from pyTigerGraph.common.base import PyTigerGraphCore @@ -99,11 +100,20 @@ def __init__(self, host: str = "http://127.0.0.1", graphname: str = "", version=version, apiToken=apiToken, useCert=useCert, certPath=certPath, debug=debug, sslPort=sslPort, gcp=gcp, jwtToken=jwtToken) + # Lazily initialized on first request (inside an async context) to avoid + # creating aiohttp.ClientSession outside an event loop in __init__. + self._async_client: Optional[aiohttp.ClientSession] = None + + # asyncio.Lock for the one-time port failover (TG 3.x port 9000 → 4.x port 14240). + # Without a lock all concurrent tasks simultaneously fail and all enter the failover + # block, doubling requests and racing to overwrite self.restppUrl/self.restppPort. 
+ self._restpp_failover_lock = asyncio.Lock() + async def _req(self, method: str, url: str, authMode: str = "token", headers: dict = None, data: Union[dict, list, str] = None, resKey: str = "results", skipCheck: bool = False, params: Union[dict, list, str] = None, strictJson: bool = True, jsonData: bool = False, jsonResponse: bool = True, func=None) -> Union[dict, list]: - """Generic REST++ API request. Copied from synchronous version, changing requests to httpx with async functionality. + """Generic REST++ API request. Copied from synchronous version, using aiohttp for direct asyncio integration. Args: method: @@ -131,59 +141,63 @@ async def _req(self, method: str, url: str, authMode: str = "token", headers: di Returns: The (relevant part of the) response from the request (as a dictionary). """ - _headers, _data, verify = self._prep_req(authMode, headers, url, method, data) + # Lazy init: session must be created inside an async context (event loop running). + if self._async_client is None or self._async_client.closed: + self._async_client = self._make_async_client() + + _headers, _data, _ = self._prep_req(authMode, headers, url, method, data) if "GSQL-TIMEOUT" in _headers: - http_timeout = (30, int(int(_headers["GSQL-TIMEOUT"])/1000) + 30) + http_timeout = aiohttp.ClientTimeout( + sock_connect=30, + total=int(int(_headers["GSQL-TIMEOUT"]) / 1000) + 30, + ) else: - http_timeout = (30, None) + http_timeout = aiohttp.ClientTimeout(sock_connect=30, total=None) - async with httpx.AsyncClient(timeout=None) as client: - if jsonData: - res = await client.request(method, url, headers=_headers, json=_data, params=params, timeout=http_timeout) - else: - res = await client.request(method, url, headers=_headers, data=_data, params=params, timeout=http_timeout) + status, body, resp = await self._do_request( + method, url, _headers, _data, jsonData, params, http_timeout) try: - if not skipCheck and not (200 <= res.status_code < 300) and res.status_code != 404: + if not skipCheck 
and not (200 <= status < 300) and status != 404: try: - self._error_check(json.loads(res.text)) + self._error_check(json.loads(body)) except json.decoder.JSONDecodeError: - # could not parse the res text (probably returned an html response) + # could not parse the response body (probably returned an html response) pass - res.raise_for_status() + resp.raise_for_status() except Exception as e: # In TG 4.x the port for restpp has changed from 9000 to 14240. # This block should only be called once. When using 4.x, using port 9000 should fail so self.restppurl will change to host:14240/restpp # ---- # Changes port to 14240, adds /restpp to end to url, tries again, saves changes if successful if self.restppPort == "9000" and "9000" in url: - newRestppUrl = self.host + ":14240/restpp" - # In tgcloud /restpp can already be in the restpp url. We want to extract everything after the port or /restpp - if '/restpp' in url: - url = newRestppUrl + '/' + \ - '/'.join(url.split(':')[2].split('/')[2:]) - else: - url = newRestppUrl + '/' + \ - '/'.join(url.split(':')[2].split('/')[1:]) - async with httpx.AsyncClient(timeout=None) as client: - if jsonData: - res = await client.request(method, url, headers=_headers, json=_data, params=params) - else: - res = await client.request(method, url, headers=_headers, data=_data, params=params) - if not skipCheck and not (200 <= res.status_code < 300) and res.status_code != 404: - try: - self._error_check(json.loads(res.text)) - except json.decoder.JSONDecodeError: - # could not parse the res text (probably returned an html response) - pass - res.raise_for_status() - self.restppUrl = newRestppUrl - self.restppPort = "14240" + async with self._restpp_failover_lock: + # Re-check inside lock: another task may have already completed the failover + if self.restppPort == "9000" and "9000" in url: + newRestppUrl = self.host + ":14240/restpp" + # In tgcloud /restpp can already be in the restpp url. 
We want to extract everything after the port or /restpp + if '/restpp' in url: + url = newRestppUrl + '/' + \ + '/'.join(url.split(':')[2].split('/')[2:]) + else: + url = newRestppUrl + '/' + \ + '/'.join(url.split(':')[2].split('/')[1:]) + status, body, resp = await self._do_request( + method, url, _headers, _data, jsonData, params, None) + if not skipCheck and not (200 <= status < 300) and status != 404: + try: + self._error_check(json.loads(body)) + except json.decoder.JSONDecodeError: + # could not parse the response body (probably returned an html response) + pass + resp.raise_for_status() + self.restppUrl = newRestppUrl + self.restppPort = "14240" else: raise e - return self._parse_req(res, jsonResponse, strictJson, skipCheck, resKey) + return self._parse_req(body, jsonResponse, strictJson, skipCheck, resKey) async def _get(self, url: str, authMode: str = "token", headers: dict = None, resKey: str = "results", skipCheck: bool = False, params: Union[dict, list, str] = None, strictJson: bool = True) -> Union[dict, list]: @@ -304,6 +318,93 @@ async def _delete(self, url: str, authMode: str = "token", data: dict = None, re return res + def _make_async_client(self) -> aiohttp.ClientSession: + """Create a persistent aiohttp.ClientSession. + + aiohttp integrates directly with asyncio (no anyio abstraction layer), + giving lower per-request overhead than httpx for high-concurrency workloads. + The connection pool grows to match demand automatically (limit=0). + SSL verify is taken from self.verify, computed once at __init__ time. + + Must be called from within an async context (event loop running) to avoid + aiohttp deprecation warnings about session creation outside a coroutine. 
+ """ + connector = aiohttp.TCPConnector( + limit=0, # unbounded pool, grows with demand + ssl=None if self.verify else False, # None = default SSL context (verify on) + ) + return aiohttp.ClientSession(connector=connector) + + async def _do_request( + self, + method: str, + url: str, + _headers: dict, + _data, + jsonData: bool, + params, + timeout: Optional[aiohttp.ClientTimeout], + ): + """Execute one HTTP request and return (status_code, response_text, response). + + Wraps aiohttp's per-request context manager so _req can treat the response + as a plain (status, text, resp) triple. The response object remains usable + for raise_for_status() after the context manager exits because aiohttp caches + status, headers, and request_info on the response object at header-receive time. + """ + kwargs = {"headers": _headers, "params": params} + if timeout is not None: + kwargs["timeout"] = timeout + if jsonData: + kwargs["json"] = _data + else: + kwargs["data"] = _data + async with self._async_client.request(method, url, **kwargs) as resp: + # read() returns raw bytes — avoids charset detection overhead and lets + # orjson/json.loads consume bytes directly without a decode step. + body = await resp.read() + return resp.status, body, resp + + async def aclose(self) -> None: + """Close the underlying HTTP connection pool. + + Call this when done with the connection to release open sockets. + Alternatively, use the connection as an async context manager: + + ```python + async with AsyncTigerGraphConnection(...) as conn: + await conn.runInstalledQuery(...) + ``` + """ + if self._async_client is not None and not self._async_client.closed: + await self._async_client.close() + self._async_client = None + + def __del__(self) -> None: + """Best-effort cleanup when the object is garbage-collected. + + If the event loop is still running at GC time (e.g. during asyncio.run() + shutdown), schedules aclose() as a task so sockets are drained gracefully. 
+ If the loop has already stopped, the OS reclaims the sockets and there is + nothing more we can do — this is not an error. + + This does NOT replace explicit aclose() / async-with usage: GC timing is + unpredictable and create_task() is fire-and-forget with no error handling. + Use `async with AsyncTigerGraphConnection(...) as conn:` for reliable cleanup. + """ + if self._async_client is not None and not self._async_client.closed: + try: + loop = asyncio.get_running_loop() + loop.create_task(self._async_client.close()) + except RuntimeError: + pass # no running loop; OS reclaims sockets on process exit + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.aclose() + async def getVersion(self, raw: bool = False) -> Union[str, list]: """Retrieves the git versions of all components of the system. diff --git a/pyTigerGraph/pytgasync/pyTigerGraphEdge.py b/pyTigerGraph/pytgasync/pyTigerGraphEdge.py index 3655fdc0..a7938e67 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphEdge.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphEdge.py @@ -404,7 +404,8 @@ async def getEdgeCount(self, edgeType: str = "*", sourceVertexType: str = "", return ret async def upsertEdge(self, sourceVertexType: str, sourceVertexId: str, edgeType: str, - targetVertexType: str, targetVertexId: str, attributes: dict = None) -> int: + targetVertexType: str, targetVertexId: str, attributes: dict = None, + vertexMustExist: bool = False) -> int: """Upserts an edge. 
Data is upserted: @@ -457,7 +458,9 @@ async def upsertEdge(self, sourceVertexType: str, sourceVertexId: str, edgeType: targetVertexType, targetVertexId, attributes) - ret = await self._req("POST", self.restppUrl + "/graph/" + self.graphname, data=data) + params = {"vertex_must_exist": vertexMustExist} + ret = await self._req("POST", self.restppUrl + "/graph/" + self.graphname, data=data, + params=params) ret = ret[0]["accepted_edges"] if logger.level == logging.DEBUG: @@ -467,7 +470,7 @@ async def upsertEdge(self, sourceVertexType: str, sourceVertexId: str, edgeType: return ret async def upsertEdges(self, sourceVertexType: str, edgeType: str, targetVertexType: str, - edges: list, atomic: bool = False) -> int: + edges: list, vertexMustExist: bool = False, atomic: bool = False) -> int: """Upserts multiple edges (of the same type). Args: @@ -534,7 +537,9 @@ async def upsertEdges(self, sourceVertexType: str, edgeType: str, targetVertexTy if atomic: headers["gsql-atomic-level"] = "atomic" - ret = await self._req("POST", self.restppUrl + "/graph/" + self.graphname, data=data) + params = {"vertex_must_exist": vertexMustExist} + ret = await self._req("POST", self.restppUrl + "/graph/" + self.graphname, data=data, + params=params, headers=headers) ret = ret[0]["accepted_edges"] if logger.level == logging.DEBUG: @@ -545,7 +550,8 @@ async def upsertEdges(self, sourceVertexType: str, edgeType: str, targetVertexTy async def upsertEdgeDataFrame(self, df: 'pd.DataFrame', sourceVertexType: str, edgeType: str, targetVertexType: str, from_id: str = "", to_id: str = "", - attributes: dict = None, atomic: bool = False) -> int: + attributes: dict = None, vertexMustExist: bool = False, + atomic: bool = False) -> int: """Upserts edges from a Pandas DataFrame. 
Args: @@ -582,7 +588,8 @@ async def upsertEdgeDataFrame(self, df: 'pd.DataFrame', sourceVertexType: str, e logger.debug("params: " + self._locals(locals())) json_up = _prep_upsert_edge_dataframe(df, from_id, to_id, attributes) - ret = await self.upsertEdges(sourceVertexType, edgeType, targetVertexType, json_up, atomic=atomic) + ret = await self.upsertEdges(sourceVertexType, edgeType, targetVertexType, json_up, + vertexMustExist=vertexMustExist, atomic=atomic) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) diff --git a/pyTigerGraph/pytgasync/pyTigerGraphGSQL.py b/pyTigerGraph/pytgasync/pyTigerGraphGSQL.py index a4397e1f..ce5a54d7 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphGSQL.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphGSQL.py @@ -5,7 +5,7 @@ """ import logging import re -import httpx +import aiohttp from typing import Union, Tuple, Dict from urllib.parse import urlparse, quote_plus @@ -55,8 +55,8 @@ async def gsql(self, query: str, graphname: str = None, options=None) -> Union[s authMode="pwd", resKey=None, skipCheck=True, jsonResponse=False, headers={"Content-Type": "text/plain"}) - except httpx.HTTPStatusError as e: - if e.response.status_code == 404: + except aiohttp.ClientResponseError as e: + if e.status == 404: res = await self._req("POST", self.gsUrl + "/gsqlserver/gsql/file", data=quote_plus(query.encode("utf-8")), diff --git a/pyTigerGraph/pytgasync/pyTigerGraphQuery.py b/pyTigerGraph/pytgasync/pyTigerGraphQuery.py index cf4bf8c6..c8c3a6d4 100644 --- a/pyTigerGraph/pytgasync/pyTigerGraphQuery.py +++ b/pyTigerGraph/pytgasync/pyTigerGraphQuery.py @@ -264,29 +264,28 @@ async def listQueryNames(self) -> list: return res - async def getInstalledQueries(self, fmt: str = "py") -> Union[dict, str, 'pd.DataFrame']: - """Returns a list of installed queries. + async def getInstalledQueries(self, fmt: str = "py") -> Union[dict, str, list, 'pd.DataFrame']: + """Returns installed queries for the graph. 
+ + Only queries that have been installed (i.e., have an active REST endpoint) are returned. Args: fmt: Format of the results: - - "py": Python objects (default) + - "py": Python dict keyed by REST endpoint string (default) - "json": JSON document - "df": pandas DataFrame + - "list": list of query name strings Returns: - The names of the installed queries. - - TODO This function returns all (installed and non-installed) queries - Modify to return only installed ones - TODO Return with query name as key rather than REST endpoint as key? + The installed queries in the requested format. """ logger.debug("entry: getInstalledQueries") if logger.level == logging.DEBUG: logger.debug("params: " + self._locals(locals())) ret = await self.getEndpoints(dynamic=True) - ret = _parse_get_installed_queries(fmt, ret) + ret = _parse_get_installed_queries(fmt, ret, self.graphname) if logger.level == logging.DEBUG: logger.debug("return: " + str(ret)) @@ -386,7 +385,7 @@ async def getQueryInstallationStatus(self, requestId: str) -> dict: return ret async def runInstalledQuery(self, queryName: str, params: Union[str, dict] = None, - timeout: int = None, sizeLimit: int = None, usePost: bool = False, runAsync: bool = False, + timeout: int = None, sizeLimit: int = None, usePost: bool = True, runAsync: bool = False, replica: int = None, threadLimit: int = None, memoryLimit: int = None) -> list: """Runs an installed query. @@ -407,8 +406,9 @@ async def runInstalledQuery(self, queryName: str, params: Union[str, dict] = Non Maximum size of response (in bytes). See xref:tigergraph-server:API:index.adoc#_response_size[Response size] usePost: - Defaults to False. The RESTPP accepts a maximum URL length of 8192 characters. Use POST if additional parameters cause - you to exceed this limit, or if you choose to pass an empty set into a query for database versions >= 3.8 + Defaults to True. Sends query parameters in the POST request body instead of as URL query parameters. 
+ POST is significantly faster than GET when params contain list values (e.g. vectors), avoids the + 8192-character URL length limit, and is required for passing empty sets in database versions >= 3.8. runAsync: Run the query in asynchronous mode. See xref:gsql-ref:querying:query-operations#_detached_mode_async_option[Async operation] @@ -566,6 +566,12 @@ async def runInterpretedQuery(self, queryText: str, params: Union[str, dict] = N queryText = queryText.replace("$graphname", self.graphname) queryText = queryText.replace("@graphname@", self.graphname) + # Per the TigerGraph API spec, interpreted query params always go in the + # URL query string (the body is reserved for the GSQL query text). + # _parse_query_parameters handles TigerGraph-specific encoding: + # - SET/BAG: repeated keys k=v1&k=v2 + # - VERTEX (no type): k=id&k.type=vtype + # - SET (no type): k[0]=id&k[0].type=vtype&k[1]=... if isinstance(params, dict): params = _parse_query_parameters(params) diff --git a/pytigergraph-recipe/recipe/meta.yaml b/pytigergraph-recipe/recipe/meta.yaml new file mode 100644 index 00000000..210cd983 --- /dev/null +++ b/pytigergraph-recipe/recipe/meta.yaml @@ -0,0 +1,30 @@ +package: + name: pytigergraph + version: "2.0.0" + +source: + url: https://pypi.io/packages/source/p/pytigergraph/pytigergraph-2.0.0.tar.gz + +build: + noarch: python + script: {{ PYTHON }} -m pip install . 
--no-deps -vv + +requirements: + host: + - python >=3.9 + - pip + - setuptools + - wheel + run: + - python >=3.9 + - requests + +about: + home: https://github.com/tigergraph/pyTigerGraph + license: Apache-2.0 + license_family: Apache + summary: Python client for TigerGraph + +extra: + recipe-maintainers: + - chengbiao-jin diff --git a/setup.py b/setup.py index 51f8e298..e911d5ad 100644 --- a/setup.py +++ b/setup.py @@ -50,10 +50,17 @@ def get_data_files(directory): install_requires=[ 'validators', 'requests', - 'httpx'], + 'aiohttp', + 'httpx'], # httpx retained for pytgasync/datasets.py streaming downloads extras_require={ "gds": ["pandas", "kafka-python", "numpy", "tqdm"], "mcp": ["mcp>=1.0.0", "pydantic>=2.0.0", "click", "python-dotenv>=1.0.0"], + # pip install pyTigerGraph[fast] + # orjson is a Rust-backed JSON library that is 2-10x faster than stdlib json + # and releases the GIL during parsing, reducing inter-thread contention under + # concurrent load. It is used automatically when present; the library falls + # back to stdlib json transparently if it is not installed. 
+ "fast": ["orjson"], }, entry_points={ "console_scripts": [ diff --git a/tests/mcp/README.md b/tests/mcp/README.md index 2be92d47..82ddea0a 100644 --- a/tests/mcp/README.md +++ b/tests/mcp/README.md @@ -35,15 +35,17 @@ python -m unittest tests.mcp.test_vector_tools.TestSearchTopKSimilarity.test_suc |------|--------------|---------------| | `__init__.py` | — | `MCPToolTestBase` class, `parse_response` / `assert_success` / `assert_error` helpers | | `test_response_formatter.py` | `mcp.response_formatter` | `gsql_has_error`, `format_success`, `format_error`, `format_list_response` | -| `test_schema_tools.py` | `mcp.tools.schema_tools` | `create_graph`, `drop_graph`, `list_graphs`, `get_graph_schema`, `_build_vertex_stmt`, `_build_edge_stmt`, `clear_graph_data`, `show_graph_details` | -| `test_node_tools.py` | `mcp.tools.node_tools` | `add_node`, `add_nodes`, `get_node`, `get_nodes`, `delete_node`, `delete_nodes`, `has_node`, `get_node_edges` | -| `test_edge_tools.py` | `mcp.tools.edge_tools` | `add_edge`, `add_edges`, `get_edge`, `get_edges`, `delete_edge`, `delete_edges`, `has_edge` | -| `test_query_tools.py` | `mcp.tools.query_tools` | `run_query`, `run_installed_query`, `install_query`, `drop_query`, `show_query`, `get_query_metadata`, `is_query_installed`, `get_neighbors` | -| `test_statistics_tools.py` | `mcp.tools.statistics_tools` | `get_vertex_count`, `get_edge_count`, `get_node_degree` | -| `test_gsql_tools.py` | `mcp.tools.gsql_tools` | `gsql`, `get_llm_config` | -| `test_vector_tools.py` | `mcp.tools.vector_tools` | `add_vector_attribute`, `drop_vector_attribute`, `list_vector_attributes`, `get_vector_index_status`, `upsert_vectors`, `search_top_k_similarity`, `fetch_vector`, `load_vectors_from_csv`, `load_vectors_from_json` | -| `test_datasource_tools.py` | `mcp.tools.datasource_tools` | `create_data_source`, `update_data_source`, `get_data_source`, `drop_data_source`, `get_all_data_sources`, `drop_all_data_sources`, `preview_sample_data` | -| 
`test_data_tools.py` | `mcp.tools.data_tools` | `_generate_loading_job_gsql`, `create_loading_job`, `run_loading_job_with_file`, `run_loading_job_with_data`, `drop_loading_job` | +| `test_connection_manager.py` | `mcp.connection_manager` | Multi-profile env resolution (`_get_env_for_profile`), profile discovery (`load_profiles`), connection pooling, `get_connection` with profile param, backward compatibility | +| `test_connection_tools.py` | `mcp.tools.connection_tools` | `list_connections`, `show_connection` | +| `test_schema_tools.py` | `mcp.tools.schema_tools` | `create_graph`, `drop_graph`, `list_graphs`, `get_graph_schema`, `_build_vertex_stmt`, `_build_edge_stmt`, `clear_graph_data`, `show_graph_details`, profile propagation | +| `test_node_tools.py` | `mcp.tools.node_tools` | `add_node`, `add_nodes`, `get_node`, `get_nodes`, `delete_node`, `delete_nodes`, `has_node`, `get_node_edges`, profile propagation | +| `test_edge_tools.py` | `mcp.tools.edge_tools` | `add_edge`, `add_edges`, `get_edge`, `get_edges`, `delete_edge`, `delete_edges`, `has_edge`, profile propagation | +| `test_query_tools.py` | `mcp.tools.query_tools` | `run_query`, `run_installed_query`, `install_query`, `drop_query`, `show_query`, `get_query_metadata`, `is_query_installed`, `get_neighbors`, profile propagation | +| `test_statistics_tools.py` | `mcp.tools.statistics_tools` | `get_vertex_count`, `get_edge_count`, `get_node_degree`, profile propagation | +| `test_gsql_tools.py` | `mcp.tools.gsql_tools` | `gsql`, `get_llm_config`, profile propagation | +| `test_vector_tools.py` | `mcp.tools.vector_tools` | `add_vector_attribute`, `drop_vector_attribute`, `list_vector_attributes`, `get_vector_index_status`, `upsert_vectors`, `search_top_k_similarity`, `fetch_vector`, `load_vectors_from_csv`, `load_vectors_from_json`, profile propagation | +| `test_datasource_tools.py` | `mcp.tools.datasource_tools` | `create_data_source`, `update_data_source`, `get_data_source`, `drop_data_source`, 
`get_all_data_sources`, `drop_all_data_sources`, `preview_sample_data`, profile propagation | +| `test_data_tools.py` | `mcp.tools.data_tools` | `_generate_loading_job_gsql`, `create_loading_job`, `run_loading_job_with_file`, `run_loading_job_with_data`, `drop_loading_job`, profile propagation | ## How Mocking Works diff --git a/tests/mcp/test_connection_manager.py b/tests/mcp/test_connection_manager.py new file mode 100644 index 00000000..71825f34 --- /dev/null +++ b/tests/mcp/test_connection_manager.py @@ -0,0 +1,292 @@ +"""Tests for pyTigerGraph.mcp.connection_manager multi-profile support.""" + +import os +import unittest +from unittest.mock import patch, MagicMock + +from pyTigerGraph.mcp.connection_manager import ( + ConnectionManager, + _get_env_for_profile, + get_connection, +) + + +class TestGetEnvForProfile(unittest.TestCase): + """Verify env-var resolution for default and named profiles.""" + + @patch.dict(os.environ, {"TG_HOST": "http://default-host"}, clear=False) + def test_default_profile_reads_unprefixed(self): + self.assertEqual( + _get_env_for_profile("default", "HOST"), + "http://default-host", + ) + + @patch.dict(os.environ, {"TG_HOST": "http://default-host"}, clear=False) + def test_default_profile_returns_builtin_default(self): + val = _get_env_for_profile("default", "NONEXISTENT_KEY", "fallback") + self.assertEqual(val, "fallback") + + @patch.dict( + os.environ, + {"STAGING_TG_HOST": "http://staging-host", "TG_HOST": "http://default-host"}, + clear=False, + ) + def test_named_profile_reads_prefixed(self): + self.assertEqual( + _get_env_for_profile("staging", "HOST"), + "http://staging-host", + ) + + @patch.dict( + os.environ, + {"TG_USERNAME": "shared_user"}, + clear=False, + ) + def test_named_profile_falls_back_to_unprefixed(self): + val = _get_env_for_profile("staging", "USERNAME") + self.assertEqual(val, "shared_user") + + @patch.dict(os.environ, {}, clear=True) + def test_named_profile_falls_back_to_builtin(self): + val = 
_get_env_for_profile("staging", "HOST", "http://127.0.0.1") + self.assertEqual(val, "http://127.0.0.1") + + @patch.dict( + os.environ, + { + "PROD_US_TG_HOST": "http://prod-us", + "TG_HOST": "http://default", + }, + clear=False, + ) + def test_multi_word_profile(self): + self.assertEqual( + _get_env_for_profile("prod_us", "HOST"), + "http://prod-us", + ) + + +class TestConnectionManagerLoadProfiles(unittest.TestCase): + + def setUp(self): + ConnectionManager._profiles = set() + ConnectionManager._connection_pool = {} + ConnectionManager._default_connection = None + + @patch.dict( + os.environ, + { + "TG_HOST": "http://default", + "STAGING_TG_HOST": "http://staging", + "ANALYTICS_TG_HOST": "http://analytics", + }, + clear=False, + ) + @patch("pyTigerGraph.mcp.connection_manager._load_env_file") + def test_discovers_named_profiles(self, mock_load_env): + ConnectionManager.load_profiles() + profiles = ConnectionManager.list_profiles() + self.assertIn("default", profiles) + self.assertIn("staging", profiles) + self.assertIn("analytics", profiles) + self.assertEqual(len(profiles), 3) + + @patch.dict(os.environ, {"TG_HOST": "http://default"}, clear=True) + @patch("pyTigerGraph.mcp.connection_manager._load_env_file") + def test_always_includes_default(self, mock_load_env): + ConnectionManager.load_profiles() + self.assertIn("default", ConnectionManager.list_profiles()) + + @patch.dict(os.environ, {}, clear=True) + def test_list_profiles_without_load(self): + """list_profiles should return at least 'default' even without load.""" + profiles = ConnectionManager.list_profiles() + self.assertIn("default", profiles) + + +class TestConnectionManagerPool(unittest.TestCase): + + def setUp(self): + ConnectionManager._profiles = set() + ConnectionManager._connection_pool = {} + ConnectionManager._default_connection = None + + @patch.dict( + os.environ, + {"TG_HOST": "http://default-host", "TG_USERNAME": "admin"}, + clear=True, + ) + def test_creates_connection_for_default(self): + 
conn = ConnectionManager.get_connection_for_profile("default") + self.assertEqual(conn.host, "http://default-host") + self.assertEqual(conn.username, "admin") + + @patch.dict( + os.environ, + { + "TG_HOST": "http://default-host", + "STAGING_TG_HOST": "http://staging-host", + "STAGING_TG_USERNAME": "stg_user", + }, + clear=True, + ) + def test_creates_separate_connections_per_profile(self): + conn_default = ConnectionManager.get_connection_for_profile("default") + conn_staging = ConnectionManager.get_connection_for_profile("staging") + self.assertIsNot(conn_default, conn_staging) + self.assertEqual(conn_default.host, "http://default-host") + self.assertEqual(conn_staging.host, "http://staging-host") + self.assertEqual(conn_staging.username, "stg_user") + + @patch.dict(os.environ, {"TG_HOST": "http://default-host"}, clear=True) + def test_pool_caches_by_profile(self): + conn1 = ConnectionManager.get_connection_for_profile("default") + conn2 = ConnectionManager.get_connection_for_profile("default") + self.assertIs(conn1, conn2) + + @patch.dict(os.environ, {"TG_HOST": "http://default-host"}, clear=True) + def test_graph_name_override(self): + conn = ConnectionManager.get_connection_for_profile("default", graph_name="MyGraph") + self.assertEqual(conn.graphname, "MyGraph") + + @patch.dict(os.environ, {"TG_HOST": "http://default-host"}, clear=True) + def test_graph_name_updated_on_cached_conn(self): + conn1 = ConnectionManager.get_connection_for_profile("default", graph_name="Graph1") + conn2 = ConnectionManager.get_connection_for_profile("default", graph_name="Graph2") + self.assertIs(conn1, conn2) + self.assertEqual(conn2.graphname, "Graph2") + + @patch.dict(os.environ, {"TG_HOST": "http://default-host"}, clear=True) + def test_default_sets_legacy_ref(self): + conn = ConnectionManager.get_connection_for_profile("default") + self.assertIs(ConnectionManager.get_default_connection(), conn) + + @patch.dict( + os.environ, + {"STAGING_TG_HOST": "http://staging-host"}, + 
clear=True, + ) + def test_named_profile_does_not_set_legacy_ref(self): + ConnectionManager.get_connection_for_profile("staging") + self.assertIsNone(ConnectionManager.get_default_connection()) + + +class TestConnectionManagerProfileInfo(unittest.TestCase): + + @patch.dict( + os.environ, + { + "TG_HOST": "http://my-host", + "TG_USERNAME": "admin", + "TG_PASSWORD": "supersecret", + "TG_SECRET": "mysecret", + "TG_GRAPHNAME": "ProdGraph", + }, + clear=True, + ) + def test_info_excludes_secrets(self): + info = ConnectionManager.get_profile_info("default") + self.assertEqual(info["host"], "http://my-host") + self.assertEqual(info["username"], "admin") + self.assertEqual(info["graphname"], "ProdGraph") + self.assertNotIn("password", info) + self.assertNotIn("secret", info) + self.assertNotIn("api_token", info) + self.assertNotIn("jwt_token", info) + + +class TestGetConnectionFunction(unittest.TestCase): + + def setUp(self): + ConnectionManager._profiles = set() + ConnectionManager._connection_pool = {} + ConnectionManager._default_connection = None + + @patch.dict( + os.environ, + {"TG_HOST": "http://default", "TG_PROFILE": "default"}, + clear=True, + ) + def test_no_args_uses_tg_profile_env(self): + conn = get_connection() + self.assertEqual(conn.host, "http://default") + + @patch.dict( + os.environ, + { + "TG_HOST": "http://default", + "STAGING_TG_HOST": "http://staging", + "TG_PROFILE": "staging", + }, + clear=True, + ) + def test_tg_profile_env_selects_staging(self): + conn = get_connection() + self.assertEqual(conn.host, "http://staging") + + @patch.dict( + os.environ, + { + "TG_HOST": "http://default", + "STAGING_TG_HOST": "http://staging", + "TG_PROFILE": "default", + }, + clear=True, + ) + def test_explicit_profile_overrides_env(self): + conn = get_connection(profile="staging") + self.assertEqual(conn.host, "http://staging") + + @patch.dict(os.environ, {"TG_HOST": "http://default"}, clear=True) + def test_graph_name_passthrough(self): + conn = 
get_connection(graph_name="TestG") + self.assertEqual(conn.graphname, "TestG") + + def test_connection_config_creates_oneoff(self): + conn = get_connection(connection_config={ + "host": "http://adhoc", + "graphname": "AdHocGraph", + "username": "user1", + "password": "pass1", + }) + self.assertEqual(conn.host, "http://adhoc") + self.assertEqual(conn.graphname, "AdHocGraph") + self.assertNotIn("adhoc", ConnectionManager._connection_pool) + + +class TestBackwardCompatibility(unittest.TestCase): + """Existing single-connection usage (no profile param) should keep working.""" + + def setUp(self): + ConnectionManager._profiles = set() + ConnectionManager._connection_pool = {} + ConnectionManager._default_connection = None + + @patch.dict( + os.environ, + { + "TG_HOST": "http://legacy-host", + "TG_GRAPHNAME": "LegacyGraph", + "TG_USERNAME": "tigergraph", + "TG_PASSWORD": "tigergraph", + }, + clear=True, + ) + def test_get_connection_without_profile(self): + conn = get_connection() + self.assertEqual(conn.host, "http://legacy-host") + self.assertEqual(conn.graphname, "LegacyGraph") + + @patch.dict( + os.environ, + {"TG_HOST": "http://legacy-host"}, + clear=True, + ) + def test_create_connection_from_env_backward_compat(self): + conn = ConnectionManager.create_connection_from_env() + self.assertEqual(conn.host, "http://legacy-host") + self.assertIs(ConnectionManager.get_default_connection(), conn) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/mcp/test_connection_tools.py b/tests/mcp/test_connection_tools.py new file mode 100644 index 00000000..a2619aec --- /dev/null +++ b/tests/mcp/test_connection_tools.py @@ -0,0 +1,119 @@ +"""Tests for pyTigerGraph.mcp.tools.connection_tools.""" + +import os +import unittest +from unittest.mock import patch + +from tests.mcp import MCPToolTestBase +from pyTigerGraph.mcp.connection_manager import ConnectionManager +from pyTigerGraph.mcp.tools.connection_tools import ( + list_connections, + show_connection, +) + + 
+class TestListConnections(MCPToolTestBase): + + def setUp(self): + super().setUp() + ConnectionManager._profiles = set() + ConnectionManager._connection_pool = {} + ConnectionManager._default_connection = None + + @patch.dict( + os.environ, + { + "TG_HOST": "http://default", + "STAGING_TG_HOST": "http://staging", + }, + clear=True, + ) + async def test_lists_discovered_profiles(self): + ConnectionManager.load_profiles() + + result = await list_connections() + resp = self.assert_success(result) + self.assertEqual(resp["data"]["count"], 2) + profile_names = [p["profile"] for p in resp["data"]["profiles"]] + self.assertIn("default", profile_names) + self.assertIn("staging", profile_names) + + @patch.dict(os.environ, {}, clear=True) + async def test_lists_default_only(self): + ConnectionManager._profiles = set() + result = await list_connections() + resp = self.assert_success(result) + self.assertEqual(resp["data"]["count"], 1) + self.assertEqual(resp["data"]["profiles"][0]["profile"], "default") + + @patch.dict( + os.environ, + { + "TG_HOST": "http://default", + "STAGING_TG_HOST": "http://staging", + "ANALYTICS_TG_HOST": "http://analytics", + }, + clear=True, + ) + async def test_three_profiles(self): + ConnectionManager.load_profiles() + result = await list_connections() + resp = self.assert_success(result) + self.assertEqual(resp["data"]["count"], 3) + + +class TestShowConnection(MCPToolTestBase): + + def setUp(self): + super().setUp() + ConnectionManager._profiles = set() + ConnectionManager._connection_pool = {} + ConnectionManager._default_connection = None + + @patch.dict( + os.environ, + { + "TG_HOST": "http://my-host", + "TG_USERNAME": "admin", + "TG_GRAPHNAME": "MyGraph", + "TG_PASSWORD": "secret123", + }, + clear=True, + ) + async def test_shows_default_profile(self): + result = await show_connection() + resp = self.assert_success(result) + self.assertEqual(resp["data"]["host"], "http://my-host") + self.assertEqual(resp["data"]["username"], "admin") + 
self.assertEqual(resp["data"]["graphname"], "MyGraph") + self.assertNotIn("password", resp["data"]) + + @patch.dict( + os.environ, + { + "STAGING_TG_HOST": "http://staging-host", + "STAGING_TG_USERNAME": "stg_admin", + }, + clear=True, + ) + async def test_shows_named_profile(self): + result = await show_connection(profile="staging") + resp = self.assert_success(result) + self.assertEqual(resp["data"]["host"], "http://staging-host") + self.assertEqual(resp["data"]["username"], "stg_admin") + self.assertEqual(resp["data"]["profile"], "staging") + + @patch.dict( + os.environ, + {"TG_HOST": "http://default", "TG_PROFILE": "default"}, + clear=True, + ) + async def test_falls_back_to_tg_profile_env(self): + result = await show_connection(profile=None) + resp = self.assert_success(result) + self.assertEqual(resp["data"]["profile"], "default") + self.assertEqual(resp["data"]["host"], "http://default") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/mcp/test_data_tools.py b/tests/mcp/test_data_tools.py index 77c43c8b..7c8b3929 100644 --- a/tests/mcp/test_data_tools.py +++ b/tests/mcp/test_data_tools.py @@ -283,5 +283,40 @@ async def test_not_found(self, mock_gc): self.assert_error(result) +class TestProfilePropagation(MCPToolTestBase): + """Verify profile is forwarded to get_connection for data tools.""" + + @patch(PATCH_TARGET) + async def test_create_loading_job_with_profile(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Successfully created loading job" + + result = await create_loading_job( + job_name="load_people", + files=[{ + "file_alias": "f1", + "file_path": "/data/people.csv", + "node_mappings": [{"vertex_type": "Person", "attribute_mappings": {"id": 0}}], + }], + profile="staging", + graph_name="StgGraph", + ) + self.assert_success(result) + mock_gc.assert_called_with(profile="staging", graph_name="StgGraph") + + @patch(PATCH_TARGET) + async def test_run_loading_job_with_profile(self, mock_gc): 
+ mock_gc.return_value = self.mock_conn + self.mock_conn.runLoadingJobWithFile.return_value = {"statistics": {}} + + result = await run_loading_job_with_file( + job_name="load_people", + file_path="/data/people.csv", + file_tag="f1", + profile="analytics", + ) + mock_gc.assert_called_with(profile="analytics", graph_name=None) + + if __name__ == "__main__": unittest.main() diff --git a/tests/mcp/test_datasource_tools.py b/tests/mcp/test_datasource_tools.py index 24104600..e531b70f 100644 --- a/tests/mcp/test_datasource_tools.py +++ b/tests/mcp/test_datasource_tools.py @@ -142,5 +142,31 @@ async def test_file_not_found(self, mock_gc): self.assert_error(result) +class TestProfilePropagation(MCPToolTestBase): + """Verify profile is forwarded to get_connection for datasource tools.""" + + @patch(PATCH_TARGET) + async def test_create_data_source_with_profile(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Successfully created data source" + + result = await create_data_source( + data_source_name="my_s3", + data_source_type="s3", + config={"bucket": "test"}, + profile="staging", + ) + self.assert_success(result) + mock_gc.assert_called_with(profile="staging") + + @patch(PATCH_TARGET) + async def test_get_all_data_sources_with_profile(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "data sources: none" + + result = await get_all_data_sources(profile="analytics") + mock_gc.assert_called_with(profile="analytics") + + if __name__ == "__main__": unittest.main() diff --git a/tests/mcp/test_edge_tools.py b/tests/mcp/test_edge_tools.py index 03d95f7e..8e3e0fca 100644 --- a/tests/mcp/test_edge_tools.py +++ b/tests/mcp/test_edge_tools.py @@ -228,5 +228,42 @@ async def test_source_missing_returns_false(self, mock_gc): self.assertFalse(resp["data"]["exists"]) +class TestProfilePropagation(MCPToolTestBase): + """Verify profile is forwarded to get_connection for edge tools.""" + + 
@patch(PATCH_TARGET) + async def test_add_edge_with_profile(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.upsertEdge.return_value = None + + result = await add_edge( + source_vertex_type="Person", + source_vertex_id="u1", + edge_type="FOLLOWS", + target_vertex_type="Person", + target_vertex_id="u2", + profile="staging", + ) + self.assert_success(result) + mock_gc.assert_called_with(profile="staging", graph_name=None) + + @patch(PATCH_TARGET) + async def test_get_edge_with_profile_and_graph(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getEdges.return_value = [{"e_type": "FOLLOWS", "from_id": "u1", "to_id": "u2"}] + + result = await get_edge( + source_vertex_type="Person", + source_vertex_id="u1", + edge_type="FOLLOWS", + target_vertex_type="Person", + target_vertex_id="u2", + profile="analytics", + graph_name="FinGraph", + ) + self.assert_success(result) + mock_gc.assert_called_with(profile="analytics", graph_name="FinGraph") + + if __name__ == "__main__": unittest.main() diff --git a/tests/mcp/test_gsql_tools.py b/tests/mcp/test_gsql_tools.py index d0b0d793..f33d28b5 100644 --- a/tests/mcp/test_gsql_tools.py +++ b/tests/mcp/test_gsql_tools.py @@ -44,7 +44,25 @@ async def test_with_graph_name(self, mock_gc): result = await gsql(command="LS", graph_name="MyGraph") self.assert_success(result) - mock_gc.assert_called_with(graph_name="MyGraph") + mock_gc.assert_called_with(profile=None, graph_name="MyGraph") + + @patch(PATCH_TARGET) + async def test_with_profile(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "OK" + + result = await gsql(command="LS", profile="staging") + self.assert_success(result) + mock_gc.assert_called_with(profile="staging", graph_name=None) + + @patch(PATCH_TARGET) + async def test_with_profile_and_graph(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "OK" + + result = await gsql(command="LS", 
profile="analytics", graph_name="FinGraph") + self.assert_success(result) + mock_gc.assert_called_with(profile="analytics", graph_name="FinGraph") class TestGetLlmConfig(unittest.TestCase): diff --git a/tests/mcp/test_node_tools.py b/tests/mcp/test_node_tools.py index 1fe632db..9082df3c 100644 --- a/tests/mcp/test_node_tools.py +++ b/tests/mcp/test_node_tools.py @@ -249,5 +249,43 @@ async def test_with_edge_type_filter(self, mock_gc): ) +class TestProfilePropagation(MCPToolTestBase): + """Verify profile is forwarded to get_connection for node tools.""" + + @patch(PATCH_TARGET) + async def test_add_node_with_profile(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.upsertVertex.return_value = None + + result = await add_node( + vertex_type="Person", vertex_id="u1", profile="staging" + ) + self.assert_success(result) + mock_gc.assert_called_with(profile="staging", graph_name=None) + + @patch(PATCH_TARGET) + async def test_get_node_with_profile_and_graph(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getVerticesById.return_value = [{"v_id": "u1"}] + + result = await get_node( + vertex_type="Person", + vertex_id="u1", + profile="analytics", + graph_name="FinGraph", + ) + self.assert_success(result) + mock_gc.assert_called_with(profile="analytics", graph_name="FinGraph") + + @patch(PATCH_TARGET) + async def test_delete_node_none_profile(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.delVerticesById.return_value = 1 + + result = await delete_node(vertex_type="Person", vertex_id="u1") + self.assert_success(result) + mock_gc.assert_called_with(profile=None, graph_name=None) + + if __name__ == "__main__": unittest.main() diff --git a/tests/mcp/test_query_tools.py b/tests/mcp/test_query_tools.py index c48103f1..77821d35 100644 --- a/tests/mcp/test_query_tools.py +++ b/tests/mcp/test_query_tools.py @@ -220,5 +220,45 @@ async def test_empty_result(self, mock_gc): self.assertEqual(resp["data"]["count"], 
0) +class TestProfilePropagation(MCPToolTestBase): + """Verify profile is forwarded to get_connection for query tools.""" + + @patch(PATCH_TARGET) + async def test_run_query_with_profile(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.runInterpretedQuery.return_value = [{"v": []}] + + result = await run_query( + query_text="INTERPRET QUERY () FOR GRAPH G { PRINT 1; }", + profile="staging", + ) + self.assert_success(result) + mock_gc.assert_called_with(profile="staging", graph_name=None) + + @patch(PATCH_TARGET) + async def test_run_installed_query_with_profile(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.runInstalledQuery.return_value = [{"result": "ok"}] + + result = await run_installed_query( + query_name="myQuery", + profile="analytics", + graph_name="FinGraph", + ) + self.assert_success(result) + mock_gc.assert_called_with(profile="analytics", graph_name="FinGraph") + + @patch(PATCH_TARGET) + async def test_install_query_with_profile(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Successfully created query" + + result = await install_query( + query_text="CREATE QUERY foo() { PRINT 1; }", + profile="prod", + ) + mock_gc.assert_called_with(profile="prod", graph_name=None) + + if __name__ == "__main__": unittest.main() diff --git a/tests/mcp/test_schema_tools.py b/tests/mcp/test_schema_tools.py index f15140d4..b3e8a2bb 100644 --- a/tests/mcp/test_schema_tools.py +++ b/tests/mcp/test_schema_tools.py @@ -291,5 +291,63 @@ async def test_query_detail(self, mock_gc): self.assertEqual(resp["data"]["detail_type"], "query") +class TestProfilePropagation(MCPToolTestBase): + """Verify that the profile parameter is forwarded to get_connection.""" + + @patch(PATCH_TARGET) + async def test_list_graphs_with_profile(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "- Graph G1(V:v)" + + result = await list_graphs(profile="staging") + 
self.assert_success(result) + mock_gc.assert_called_with(profile="staging") + + @patch(PATCH_TARGET) + async def test_get_graph_schema_with_profile(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getSchema.return_value = {"VertexTypes": [], "EdgeTypes": []} + + result = await get_graph_schema(profile="analytics", graph_name="FinGraph") + self.assert_success(result) + mock_gc.assert_called_with(profile="analytics", graph_name="FinGraph") + + @patch(PATCH_TARGET) + async def test_get_global_schema_with_profile(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Global vertex types: ..." + + result = await get_global_schema(profile="prod") + self.assert_success(result) + mock_gc.assert_called_with(profile="prod") + + @patch(PATCH_TARGET) + async def test_show_graph_details_with_profile(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Vertex types: V" + + result = await show_graph_details(profile="staging", graph_name="StgGraph") + self.assert_success(result) + mock_gc.assert_called_with(profile="staging", graph_name="StgGraph") + + @patch(PATCH_TARGET) + async def test_drop_graph_with_profile(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "Successfully dropped graph" + + result = await drop_graph(profile="staging", graph_name="OldGraph") + self.assert_success(result) + mock_gc.assert_called_with(profile="staging", graph_name="OldGraph") + + @patch(PATCH_TARGET) + async def test_none_profile_is_default(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = "- Graph G1(V:v)" + + result = await list_graphs() + self.assert_success(result) + mock_gc.assert_called_with(profile=None) + + if __name__ == "__main__": unittest.main() diff --git a/tests/mcp/test_statistics_tools.py b/tests/mcp/test_statistics_tools.py index c5aa5670..6f556204 100644 --- 
a/tests/mcp/test_statistics_tools.py +++ b/tests/mcp/test_statistics_tools.py @@ -121,5 +121,27 @@ async def test_with_edge_type(self, mock_gc): self.assertIn("FOLLOWS", query_arg) +class TestProfilePropagation(MCPToolTestBase): + """Verify profile is forwarded to get_connection for statistics tools.""" + + @patch(PATCH_TARGET) + async def test_get_vertex_count_with_profile(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getVertexCount.return_value = 10 + + result = await get_vertex_count(vertex_type="Person", profile="staging") + self.assert_success(result) + mock_gc.assert_called_with(profile="staging", graph_name=None) + + @patch(PATCH_TARGET) + async def test_get_edge_count_with_profile(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.getEdgeCount.return_value = 50 + + result = await get_edge_count(edge_type="FOLLOWS", profile="analytics") + self.assert_success(result) + mock_gc.assert_called_with(profile="analytics", graph_name=None) + + if __name__ == "__main__": unittest.main() diff --git a/tests/mcp/test_vector_tools.py b/tests/mcp/test_vector_tools.py index 19bfd106..04c0be23 100644 --- a/tests/mcp/test_vector_tools.py +++ b/tests/mcp/test_vector_tools.py @@ -493,5 +493,60 @@ async def test_json_file_clause_present(self, mock_gc): self.assertIn('JSON_FILE="true"', create_call) +class TestProfilePropagation(MCPToolTestBase): + """Verify profile is forwarded to get_connection for vector tools.""" + + @patch(PATCH_TARGET) + async def test_add_vector_attribute_with_profile(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.side_effect = [ + "", + "Successfully created schema change job", + ] + + result = await add_vector_attribute( + vertex_type="Person", + vector_name="emb", + dimension=128, + profile="staging", + ) + self.assert_success(result) + mock_gc.assert_called_with(profile="staging", graph_name=None) + + @patch(PATCH_TARGET) + async def 
test_list_vector_attributes_with_profile(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.return_value = ( + "---- Graph TestGraph\n" + "Vector Embeddings:\n" + " - Person:\n" + ' - emb(Dimension=128, IndexType="HNSW", DataType="FLOAT", Metric="COSINE")\n' + ) + + result = await list_vector_attributes(profile="analytics") + self.assert_success(result) + mock_gc.assert_called_with(profile="analytics", graph_name=None) + + @patch(PATCH_TARGET) + async def test_search_with_profile_and_graph(self, mock_gc): + mock_gc.return_value = self.mock_conn + self.mock_conn.gsql.side_effect = [ + '- emb(Dimension=3, IndexType="HNSW", DataType="FLOAT", Metric="COSINE")', + "OK", + "OK", + ] + self.mock_conn.runInstalledQuery.return_value = [{"results": []}] + + result = await search_top_k_similarity( + vertex_type="Person", + vector_attribute="emb", + query_vector=[0.1, 0.2, 0.3], + top_k=5, + profile="prod", + graph_name="ProdGraph", + ) + mock_gc.assert_called_with(profile="prod", graph_name="ProdGraph") + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_common_edge_query.py b/tests/test_common_edge_query.py new file mode 100644 index 00000000..f3392f40 --- /dev/null +++ b/tests/test_common_edge_query.py @@ -0,0 +1,185 @@ +"""Unit tests for pyTigerGraph.common.edge and pyTigerGraph.common.query helpers. + +These tests exercise the helper functions in isolation — no live TigerGraph +server is required. +""" +import json +import sys +import os +import unittest +from datetime import datetime + +# Make sure the project source is on the path when running from the repo root. 
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +from pyTigerGraph.common.edge import _dumps, _prep_upsert_edge_dataframe +from pyTigerGraph.common.query import _parse_query_parameters + + +# --------------------------------------------------------------------------- +# _dumps +# --------------------------------------------------------------------------- + +class TestDumps(unittest.TestCase): + """Tests for common/edge._dumps()""" + + def test_non_dict_string(self): + """Non-dict leaf values must be serialized with json.dumps, not '{}'.""" + self.assertEqual('"hello"', _dumps("hello")) + + def test_non_dict_int(self): + self.assertEqual("42", _dumps(42)) + + def test_non_dict_float(self): + self.assertEqual("3.14", _dumps(3.14)) + + def test_non_dict_bool(self): + self.assertEqual("true", _dumps(True)) + + def test_non_dict_none(self): + self.assertEqual("null", _dumps(None)) + + def test_non_dict_list(self): + self.assertEqual("[1, 2, 3]", _dumps([1, 2, 3])) + + def test_empty_dict(self): + self.assertEqual("{}", _dumps({})) + + def test_flat_dict_scalar_values(self): + """Flat dicts with scalar values round-trip through JSON correctly.""" + result = _dumps({"a": 1, "b": "two"}) + parsed = json.loads(result) + self.assertEqual({"a": 1, "b": "two"}, parsed) + + def test_nested_dict(self): + """Nested dicts produce valid JSON.""" + result = _dumps({"outer": {"inner": 99}}) + parsed = json.loads(result) + self.assertEqual({"outer": {"inner": 99}}, parsed) + + +# --------------------------------------------------------------------------- +# _parse_query_parameters +# --------------------------------------------------------------------------- + +class TestParseQueryParameters(unittest.TestCase): + """Tests for common/query._parse_query_parameters()""" + + def test_simple_scalar(self): + result = _parse_query_parameters({"k": "v"}) + self.assertEqual("k=v", result) + + def test_multiple_scalars(self): + result = _parse_query_parameters({"a": "1", 
"b": "2"}) + parts = sorted(result.split("&")) + self.assertEqual(["a=1", "b=2"], parts) + + def test_list_values_repeated_key(self): + """Lists must be encoded as repeated keys: k=v1&k=v2.""" + result = _parse_query_parameters({"colors": ["red", "green", "blue"]}) + parts = result.split("&") + self.assertEqual(["colors=red", "colors=green", "colors=blue"], parts) + + def test_no_trailing_ampersand(self): + """Result must never end with '&'.""" + result = _parse_query_parameters({"x": ["a", "b", "c"]}) + self.assertFalse(result.endswith("&")) + + def test_vertex_tuple_single(self): + """VERTEX parameters: (id, type) → k=id&k.type=type.""" + result = _parse_query_parameters({"v": ("vid1", "Person")}) + parts = result.split("&") + self.assertIn("v=vid1", parts) + self.assertIn("v.type=Person", parts) + + def test_set_vertex_list_of_tuples(self): + """SET: list of (id, type) tuples → k[i]=id&k[i].type=type.""" + result = _parse_query_parameters( + {"vs": [("id0", "TypeA"), ("id1", "TypeB")]} + ) + parts = result.split("&") + self.assertIn("vs[0]=id0", parts) + self.assertIn("vs[0].type=TypeA", parts) + self.assertIn("vs[1]=id1", parts) + self.assertIn("vs[1].type=TypeB", parts) + + def test_datetime_value(self): + dt = datetime(2024, 1, 15, 10, 30, 0) + result = _parse_query_parameters({"ts": dt}) + self.assertEqual("ts=2024-01-15%2010%3A30%3A00", result) + + def test_large_list_no_on2_regression(self): + """Performance sanity: 10 000-element list should complete quickly.""" + params = {"big": list(range(10_000))} + result = _parse_query_parameters(params) + parts = result.split("&") + self.assertEqual(10_000, len(parts)) + self.assertTrue(all(p.startswith("big=") for p in parts)) + + def test_empty_params(self): + self.assertEqual("", _parse_query_parameters({})) + + +# --------------------------------------------------------------------------- +# _prep_upsert_edge_dataframe +# --------------------------------------------------------------------------- + +class 
TestPrepUpsertEdgeDataframe(unittest.TestCase): + """Tests for common/edge._prep_upsert_edge_dataframe()""" + + def _make_df(self): + try: + import pandas as pd + except ImportError: + self.skipTest("pandas not installed") + return pd.DataFrame( + { + "src": ["A", "B"], + "dst": ["X", "Y"], + "weight": [1.0, 2.0], + } + ) + + def test_explicit_from_to_ids(self): + """When from_id and to_id columns are given, they are used as vertex IDs.""" + df = self._make_df() + result = _prep_upsert_edge_dataframe(df, "src", "dst", None) + self.assertEqual(2, len(result)) + src_id, dst_id, attrs = result[0] + self.assertEqual("A", src_id) + self.assertEqual("X", dst_id) + self.assertIn("weight", attrs) + + def test_default_from_to_uses_index(self): + """Empty from_id / to_id defaults must fall back to DataFrame index.""" + df = self._make_df() + # With default "" arguments the dataframe index (0, 1) should be used. + result = _prep_upsert_edge_dataframe(df, "", "", None) + src_id_0, dst_id_0, _ = result[0] + self.assertEqual(0, src_id_0) + self.assertEqual(0, dst_id_0) + src_id_1, dst_id_1, _ = result[1] + self.assertEqual(1, src_id_1) + self.assertEqual(1, dst_id_1) + + def test_attribute_projection(self): + """Only mapped attributes should appear when attributes dict is given.""" + df = self._make_df() + result = _prep_upsert_edge_dataframe(df, "src", "dst", {"w": "weight"}) + _, _, attrs = result[0] + self.assertIn("w", attrs) + self.assertNotIn("weight", attrs) + self.assertNotIn("src", attrs) + + def test_no_attributes_returns_all_columns(self): + """When attributes is None, all columns are included.""" + df = self._make_df() + result = _prep_upsert_edge_dataframe(df, "src", "dst", None) + _, _, attrs = result[0] + self.assertIn("src", attrs) + self.assertIn("dst", attrs) + self.assertIn("weight", attrs) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_pyTigerGraphEdgeAsync.py b/tests/test_pyTigerGraphEdgeAsync.py index ac3b5d44..31de4605 100644 --- 
a/tests/test_pyTigerGraphEdgeAsync.py +++ b/tests/test_pyTigerGraphEdgeAsync.py @@ -201,10 +201,31 @@ async def test_09_upsertEdge(self): self.assertIsInstance(res, int) self.assertEqual(1, res) - # TODO Tests with ack, new_vertex_only, vertex_must_exist, update_vertex_only and - # atomic_level parameters; when they will be added to pyTigerGraphEdge.upsertEdge() # TODO Add MultiEdge edge to schema and add test cases + async def test_09_upsertEdge_mustExist(self): + res = await self.conn.upsertEdge( + "vertex6", + int(1e6), + "edge4_many_to_many", + "vertex7", + 1, + vertexMustExist=True, + ) + self.assertIsInstance(res, int) + self.assertEqual(0, res) + + res = await self.conn.upsertEdge( + "vertex6", + 6, + "edge4_many_to_many", + "vertex7", + int(2e6), + vertexMustExist=True, + ) + self.assertIsInstance(res, int) + self.assertEqual(0, res) + async def test_10_upsertEdges(self): es = [ (2, 1), @@ -220,9 +241,48 @@ async def test_10_upsertEdges(self): self.assertIsInstance(res, int) self.assertEqual(14, res) + async def test_10_upsertEdges_mustExist(self): + es = [(2, int(1e6)), (int(1e6), 2)] + res = await self.conn.upsertEdges( + "vertex6", "edge4_many_to_many", "vertex7", es, vertexMustExist=True + ) + self.assertIsInstance(res, int) + self.assertEqual(0, res) + async def test_11_upsertEdgeDataFrame(self): - # TODO Implement - pass + edges = [ + { + "e_type": "edge1_undirected", + "directed": False, + "from_id": 1, + "from_type": "vertex4", + "to_id": 4, + "to_type": "vertex5", + "attributes": {"a01": -100}, + }, + { + "e_type": "edge1_undirected", + "directed": False, + "from_id": 1, + "from_type": "vertex4", + "to_id": 5, + "to_type": "vertex5", + "attributes": {"a01": -100}, + }, + ] + df = await self.conn.edgeSetToDataFrame(edges) + res = await self.conn.upsertEdgeDataFrame( + df=df, + sourceVertexType="vertex4", + edgeType="edge1_undirected", + targetVertexType="vertex5", + from_id="from_id", + to_id="to_id", + attributes={"a01": "a01"}, + 
vertexMustExist=True, + ) + self.assertIsInstance(res, int) + self.assertEqual(2, res) async def test_12_getEdges(self): res = await self.conn.getEdges("vertex4", 1) From 7da61824729e81cdc8696eccb6b87cc39769aeab Mon Sep 17 00:00:00 2001 From: Chengbiao Jin Date: Wed, 11 Mar 2026 16:09:10 -0700 Subject: [PATCH 2/6] GML-2042 MCP support multiple profiles --- README.md | 6 ++-- pyTigerGraph/mcp/MCP_README.md | 54 ++++++++++++++++++++++++++++++---- 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 2d08d08d..2716c63f 100644 --- a/README.md +++ b/README.md @@ -140,8 +140,8 @@ conn = TigerGraphConnection( ### Synchronous mode (`TigerGraphConnection`) -- Each thread gets its own `requests.Session` backed by a private connection pool. This eliminates the `_cookies_lock` contention that a shared session causes under concurrent load. -- Install `pyTigerGraph[fast]` to activate the `orjson` backend and significantly reduce GIL contention between threads during JSON parsing. +- Each thread gets its own dedicated HTTP session and connection pool, so concurrent threads never block each other. +- Install `pyTigerGraph[fast]` to activate the `orjson` backend and reduce JSON parsing overhead under concurrent load. - Use `ThreadPoolExecutor` to run queries in parallel: ```python @@ -149,7 +149,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed with TigerGraphConnection(...) 
as conn: with ThreadPoolExecutor(max_workers=16) as executor: - futures = {executor.submit(conn.runInstalledQuery, "q", {"p": v}): v for v in values} + futures = [executor.submit(conn.runInstalledQuery, "q", {"p": v}) for v in values] for f in as_completed(futures): print(f.result()) ``` diff --git a/pyTigerGraph/mcp/MCP_README.md b/pyTigerGraph/mcp/MCP_README.md index c8c728bd..db73de02 100644 --- a/pyTigerGraph/mcp/MCP_README.md +++ b/pyTigerGraph/mcp/MCP_README.md @@ -8,6 +8,7 @@ pyTigerGraph now includes Model Context Protocol (MCP) support, allowing AI agen - [Usage](#usage) - [Running the MCP Server](#running-the-mcp-server) - [Configuration](#configuration) + - [Multiple Connection Profiles](#multiple-connection-profiles) - [Using with Existing Connection](#using-with-existing-connection) - [Client Examples](#client-examples) - [Using MultiServerMCPClient](#using-multiserverMCPclient) @@ -83,7 +84,6 @@ TG_USERNAME=tigergraph TG_PASSWORD=tigergraph TG_RESTPP_PORT=9000 TG_GS_PORT=14240 -TG_CONN_LIMIT=10 # Optional - increase for parallel tool calls (e.g. 32) ``` The server will automatically load the `.env` file if it exists. Environment variables take precedence over `.env` file values. @@ -104,7 +104,52 @@ The following environment variables are supported: - `TG_SSL_PORT` - SSL port (default: 443) - `TG_TGCLOUD` - Whether using TigerGraph Cloud (default: False) - `TG_CERT_PATH` - Path to certificate (optional) -- `TG_CONN_LIMIT` - Max keep-alive HTTP connections in the async client pool (default: 10). Should be ≥ the number of concurrent MCP tool calls you expect. Named profiles use `_TG_CONN_LIMIT`. + +### Multiple Connection Profiles + +If you work with more than one TigerGraph environment — for example, development, staging, and production — you can define named profiles in your `.env` file and switch between them without changing any code. + +#### Defining profiles + +Each named profile uses a `_` prefix on the standard `TG_*` variables. 
Only the variables that differ from the default need to be set. + +```bash +# .env + +# Default profile (no prefix) — used when TG_PROFILE is not set +TG_HOST=http://localhost +TG_USERNAME=tigergraph +TG_PASSWORD=tigergraph +TG_GRAPHNAME=MyGraph + +# Staging profile +STAGING_TG_HOST=https://staging.example.com +STAGING_TG_PASSWORD=staging_secret +STAGING_TG_TGCLOUD=true + +# Production profile +PROD_TG_HOST=https://prod.example.com +PROD_TG_USERNAME=admin +PROD_TG_PASSWORD=prod_secret +PROD_TG_GRAPHNAME=ProdGraph +PROD_TG_TGCLOUD=true +``` + +Profiles are discovered automatically at startup. Any variable matching `_TG_HOST` registers a new profile. Values not set for a named profile fall back to the default profile's values. + +#### Selecting the active profile + +Pass `TG_PROFILE` as an environment variable or add it to your `.env`: + +```bash +# Switch to the staging profile for this run +TG_PROFILE=staging tigergraph-mcp + +# Or set it permanently in .env +TG_PROFILE=prod +``` + +If `TG_PROFILE` is not set, the default profile (unprefixed `TG_*` variables) is used. ### Using with Existing Connection @@ -136,7 +181,6 @@ async with AsyncTigerGraphConnection( graphname="MyGraph", username="tigergraph", password="tigergraph", - connLimit=10, # set >= number of concurrent MCP tool calls (default: 10) ) as conn: # Set as default for MCP tools ConnectionManager.set_default_connection(conn) @@ -396,8 +440,8 @@ result = await session.call_tool( - **Transport**: The MCP server uses stdio transport by default - **Error Detection**: GSQL operations include error detection for syntax and semantic errors (since `conn.gsql()` does not raise Python exceptions for GSQL failures) -- **Connection Management**: Connections are pooled by profile — each profile's `AsyncTigerGraphConnection` holds a persistent HTTP connection pool (sized by `TG_CONN_LIMIT`, default 10). The pool is automatically released at server shutdown via `ConnectionManager.close_all()`. 
To adjust pool size per profile, set `_TG_CONN_LIMIT`. -- **Performance**: Persistent HTTP connection pool per profile (no TCP handshake per request); async non-blocking I/O; `v.outdegree()` for O(1) degree counting; batch operations for multiple vertices/edges +- **Connection Management**: Connections are pooled by profile and reused across requests — no TCP handshake overhead per tool call. The pool is released automatically at server shutdown. +- **Performance**: Persistent HTTP connection pool per profile; async non-blocking I/O; `v.outdegree()` for O(1) degree counting; batch operations for multiple vertices/edges ## Backward Compatibility From 561376f2789c904e535fceb09629a924ddd19ec5 Mon Sep 17 00:00:00 2001 From: Chengbiao Jin Date: Wed, 18 Mar 2026 15:32:24 -0700 Subject: [PATCH 3/6] GML-2041 performance improvement, and MCP separation --- LICENSE | 2 +- README.md | 18 +- pyTigerGraph/__init__.py | 7 +- pyTigerGraph/common/gsql.py | 63 + pyTigerGraph/common/loading.py | 37 +- pyTigerGraph/mcp/MCP_README.md | 451 ------- pyTigerGraph/mcp/__init__.py | 43 +- pyTigerGraph/mcp/connection_manager.py | 286 ----- pyTigerGraph/mcp/main.py | 50 - pyTigerGraph/mcp/response_formatter.py | 315 ----- pyTigerGraph/mcp/server.py | 285 ---- pyTigerGraph/mcp/tool_metadata.py | 528 -------- pyTigerGraph/mcp/tool_names.py | 114 -- pyTigerGraph/mcp/tools/__init__.py | 309 ----- pyTigerGraph/mcp/tools/connection_tools.py | 104 -- pyTigerGraph/mcp/tools/data_tools.py | 638 --------- pyTigerGraph/mcp/tools/datasource_tools.py | 377 ------ pyTigerGraph/mcp/tools/discovery_tools.py | 611 --------- pyTigerGraph/mcp/tools/edge_tools.py | 706 ---------- pyTigerGraph/mcp/tools/gsql_tools.py | 560 -------- pyTigerGraph/mcp/tools/node_tools.py | 1003 --------------- pyTigerGraph/mcp/tools/query_tools.py | 787 ------------ pyTigerGraph/mcp/tools/schema_tools.py | 988 -------------- pyTigerGraph/mcp/tools/statistics_tools.py | 332 ----- pyTigerGraph/mcp/tools/tool_registry.py | 182 --- 
pyTigerGraph/mcp/tools/vector_tools.py | 1143 ----------------- pyTigerGraph/pyTigerGraphLoading.py | 205 ++- pyTigerGraph/pyTigerGraphQuery.py | 45 +- pyTigerGraph/pyTigerGraphSchema.py | 135 ++ pyTigerGraph/pytgasync/pyTigerGraphLoading.py | 211 ++- pyTigerGraph/pytgasync/pyTigerGraphQuery.py | 45 +- pyTigerGraph/pytgasync/pyTigerGraphSchema.py | 135 ++ pyproject.toml | 49 +- setup.cfg | 3 - setup.py | 89 -- tests/mcp/README.md | 77 -- tests/mcp/__init__.py | 51 - tests/mcp/test_connection_manager.py | 292 ----- tests/mcp/test_connection_tools.py | 119 -- tests/mcp/test_data_tools.py | 322 ----- tests/mcp/test_datasource_tools.py | 172 --- tests/mcp/test_edge_tools.py | 269 ---- tests/mcp/test_gsql_tools.py | 117 -- tests/mcp/test_node_tools.py | 291 ----- tests/mcp/test_query_tools.py | 264 ---- tests/mcp/test_response_formatter.py | 148 --- tests/mcp/test_schema_tools.py | 353 ----- tests/mcp/test_statistics_tools.py | 147 --- tests/mcp/test_vector_tools.py | 552 -------- tests/test_pyTigerGraphLoading.py | 9 +- 50 files changed, 922 insertions(+), 13117 deletions(-) delete mode 100644 pyTigerGraph/mcp/MCP_README.md delete mode 100644 pyTigerGraph/mcp/connection_manager.py delete mode 100644 pyTigerGraph/mcp/main.py delete mode 100644 pyTigerGraph/mcp/response_formatter.py delete mode 100644 pyTigerGraph/mcp/server.py delete mode 100644 pyTigerGraph/mcp/tool_metadata.py delete mode 100644 pyTigerGraph/mcp/tool_names.py delete mode 100644 pyTigerGraph/mcp/tools/__init__.py delete mode 100644 pyTigerGraph/mcp/tools/connection_tools.py delete mode 100644 pyTigerGraph/mcp/tools/data_tools.py delete mode 100644 pyTigerGraph/mcp/tools/datasource_tools.py delete mode 100644 pyTigerGraph/mcp/tools/discovery_tools.py delete mode 100644 pyTigerGraph/mcp/tools/edge_tools.py delete mode 100644 pyTigerGraph/mcp/tools/gsql_tools.py delete mode 100644 pyTigerGraph/mcp/tools/node_tools.py delete mode 100644 pyTigerGraph/mcp/tools/query_tools.py delete mode 100644 
pyTigerGraph/mcp/tools/schema_tools.py delete mode 100644 pyTigerGraph/mcp/tools/statistics_tools.py delete mode 100644 pyTigerGraph/mcp/tools/tool_registry.py delete mode 100644 pyTigerGraph/mcp/tools/vector_tools.py delete mode 100644 setup.cfg delete mode 100644 setup.py delete mode 100644 tests/mcp/README.md delete mode 100644 tests/mcp/__init__.py delete mode 100644 tests/mcp/test_connection_manager.py delete mode 100644 tests/mcp/test_connection_tools.py delete mode 100644 tests/mcp/test_data_tools.py delete mode 100644 tests/mcp/test_datasource_tools.py delete mode 100644 tests/mcp/test_edge_tools.py delete mode 100644 tests/mcp/test_gsql_tools.py delete mode 100644 tests/mcp/test_node_tools.py delete mode 100644 tests/mcp/test_query_tools.py delete mode 100644 tests/mcp/test_response_formatter.py delete mode 100644 tests/mcp/test_schema_tools.py delete mode 100644 tests/mcp/test_statistics_tools.py delete mode 100644 tests/mcp/test_vector_tools.py diff --git a/LICENSE b/LICENSE index d9f59d42..0a62d25a 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2022 TigerGraph Inc. + Copyright 2022-2026 TigerGraph Inc. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/README.md b/README.md index 2716c63f..13711947 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ pip install pyTigerGraph | Extra | What it adds | Install command | |-------|-------------|-----------------| | `gds` | Graph Data Science — data loaders for PyTorch Geometric, DGL, and Pandas | `pip install 'pyTigerGraph[gds]'` | -| `mcp` | Model Context Protocol server — exposes TigerGraph as tools for AI agents | `pip install 'pyTigerGraph[mcp]'` | +| `mcp` | Model Context Protocol server — installs [`tigergraph-mcp`](https://github.com/tigergraph/tigergraph-mcp) (convenience alias) | `pip install 'pyTigerGraph[mcp]'` | | `fast` | [orjson](https://github.com/ijl/orjson) JSON backend — 2–10× faster parsing, releases the GIL under concurrent load | `pip install 'pyTigerGraph[fast]'` | Extras can be combined: @@ -192,16 +192,28 @@ See the [GDS documentation](https://docs.tigergraph.com/pytigergraph/current/gds ## MCP Server -pyTigerGraph includes a built-in [Model Context Protocol (MCP)](https://modelcontextprotocol.io/) server that exposes TigerGraph operations as tools for AI agents and LLM applications (Claude Desktop, Cursor, Copilot, etc.). All MCP tools use the async API internally for optimal performance. +The TigerGraph MCP server is now a standalone package: **[tigergraph-mcp](https://github.com/tigergraph/tigergraph-mcp)**. It exposes TigerGraph operations as tools for AI agents and LLM applications (Claude Desktop, Cursor, Copilot, etc.). ```sh +# Recommended — install the standalone package directly +pip install tigergraph-mcp + +# Or via the pyTigerGraph convenience alias (installs tigergraph-mcp automatically) pip install 'pyTigerGraph[mcp]' # Start the server (reads connection config from environment variables) tigergraph-mcp ``` -For full setup instructions, available tools, and configuration examples, see the **[MCP Server README](pyTigerGraph/mcp/MCP_README.md)**. 
+For full setup instructions, available tools, configuration examples, and multi-profile support, see the **[tigergraph-mcp README](https://github.com/tigergraph/tigergraph-mcp#readme)**. + +> **Migrating from `pyTigerGraph.mcp`?** Update your imports: +> ```python +> # Old +> from pyTigerGraph.mcp import serve, ConnectionManager +> # New +> from tigergraph_mcp import serve, ConnectionManager +> ``` --- diff --git a/pyTigerGraph/__init__.py b/pyTigerGraph/__init__.py index 18d4e3bd..ca018e34 100644 --- a/pyTigerGraph/__init__.py +++ b/pyTigerGraph/__init__.py @@ -1,8 +1,13 @@ +from importlib.metadata import version as _pkg_version, PackageNotFoundError + from pyTigerGraph.pyTigerGraph import TigerGraphConnection from pyTigerGraph.pytgasync.pyTigerGraph import AsyncTigerGraphConnection from pyTigerGraph.common.exception import TigerGraphException -__version__ = "2.0.1" +try: + __version__ = _pkg_version("pyTigerGraph") +except PackageNotFoundError: + __version__ = "2.0.1" __license__ = "Apache 2" diff --git a/pyTigerGraph/common/gsql.py b/pyTigerGraph/common/gsql.py index 2b179e2e..917a11cf 100644 --- a/pyTigerGraph/common/gsql.py +++ b/pyTigerGraph/common/gsql.py @@ -60,6 +60,69 @@ def clean_res(resp: list) -> str: return string_without_ansi +_GSQL_ERROR_PATTERNS = [ + "Encountered \"", + "SEMANTIC ERROR", + "Syntax Error", + "Failed to create", + "does not exist", + "is not a valid", + "already exists", + "Invalid syntax", +] + + +def _wrap_gsql_result(result, skipCheck: bool = False): + """Wrap a gsql() string result into a dict matching 4.x REST response format. + + Args: + result: The raw string returned by ``gsql()``. + skipCheck: If ``False`` (default), raises ``TigerGraphException`` when + an error pattern is detected — consistent with ``_error_check`` + on the 4.x REST path. If ``True``, returns the dict with + ``"error": True`` without raising. 
+ """ + msg = str(result) if result else "" + has_error = any(p in msg for p in _GSQL_ERROR_PATTERNS) + if has_error and not skipCheck: + raise TigerGraphException(msg) + return { + "error": has_error, + "message": msg, + } + + +def _parse_graph_list(gsql_output): + """Parse ``SHOW GRAPH *`` output into a list of dicts matching 4.x REST format.""" + output = str(gsql_output) if gsql_output else "" + graphs = [] + for line in output.splitlines(): + stripped = line.strip().lstrip("- ").strip() + if not stripped.startswith("Graph "): + continue + paren_start = stripped.find("(") + name = stripped[6:paren_start].strip() if paren_start > 6 else stripped[6:].strip() + if not name or name == "*": + continue + vertices = [] + edges = [] + if paren_start != -1: + paren_end = stripped.rfind(")") + inner = stripped[paren_start + 1:paren_end] if paren_end > paren_start else "" + for token in inner.split(","): + token = token.strip() + if token.endswith(":v"): + vertices.append(token[:-2]) + elif token.endswith(":e"): + edges.append(token[:-2]) + graphs.append({ + "GraphName": name, + "VertexTypes": vertices, + "EdgeTypes": edges, + }) + return graphs + + def _prep_get_udf(ExprFunctions: bool = True, ExprUtil: bool = True): urls = {} # urls when using TG 4.x alt_urls = {} # urls when using TG 3.x diff --git a/pyTigerGraph/common/loading.py b/pyTigerGraph/common/loading.py index 5872bb3a..137e70e4 100644 --- a/pyTigerGraph/common/loading.py +++ b/pyTigerGraph/common/loading.py @@ -100,7 +100,38 @@ def _prep_get_loading_jobs_status(gsUrl: str, graphname: str, jobIds: list[str]) url += "&" + job_params return url -def _prep_get_loading_job_status(gsUrl: str, jobId: str): +def _prep_get_loading_job_status(gsUrl: str, graphname: str, jobId: str): '''url builder for getLoadingJobStatus()''' - url = gsUrl + "/gsql/v1/loading-jobs/status/" + jobId - return url \ No newline at end of file + url = gsUrl + "/gsql/v1/loading-jobs/status/" + jobId + "?graph=" + graphname + return url + + +# 
---- Data Source helpers ---- + +def _prep_data_source_url(gsUrl: str, graphname: str = None): + '''url builder for getDataSources() and createDataSource()''' + url = gsUrl + "/gsql/v1/data-sources" + if graphname: + url += "?graph=" + graphname + return url + + +def _prep_data_source_by_name(gsUrl: str, dsName: str, graphname: str = None): + '''url builder for getDataSource(), dropDataSource(), updateDataSource()''' + url = gsUrl + "/gsql/v1/data-sources/" + dsName + if graphname: + url += "?graph=" + graphname + return url + + +def _prep_drop_all_data_sources(gsUrl: str, graphname: str = None): + '''url builder for dropAllDataSources()''' + url = gsUrl + "/gsql/v1/data-sources/dropAll" + if graphname: + url += "?graph=" + graphname + return url + + +def _prep_sample_data_url(gsUrl: str): + '''url builder for previewSampleData()''' + return gsUrl + "/gsql/v1/sample-data" \ No newline at end of file diff --git a/pyTigerGraph/mcp/MCP_README.md b/pyTigerGraph/mcp/MCP_README.md deleted file mode 100644 index db73de02..00000000 --- a/pyTigerGraph/mcp/MCP_README.md +++ /dev/null @@ -1,451 +0,0 @@ -# pyTigerGraph MCP Support - -pyTigerGraph now includes Model Context Protocol (MCP) support, allowing AI agents to interact with TigerGraph through the MCP standard. All MCP tools use pyTigerGraph's async APIs for optimal performance. 
- -## Table of Contents - -- [Installation](#installation) -- [Usage](#usage) - - [Running the MCP Server](#running-the-mcp-server) - - [Configuration](#configuration) - - [Multiple Connection Profiles](#multiple-connection-profiles) - - [Using with Existing Connection](#using-with-existing-connection) -- [Client Examples](#client-examples) - - [Using MultiServerMCPClient](#using-multiserverMCPclient) - - [Using MCP Client SDK Directly](#using-mcp-client-sdk-directly) -- [Available Tools](#available-tools) -- [LLM-Friendly Features](#llm-friendly-features) - - [Structured Responses](#structured-responses) - - [Rich Tool Descriptions](#rich-tool-descriptions) - - [Token Optimization](#token-optimization) - - [Tool Discovery](#tool-discovery) -- [Notes](#notes) -- [Backward Compatibility](#backward-compatibility) - -## Installation - -To use MCP functionality, install pyTigerGraph with the `mcp` extra: - -```bash -pip install pyTigerGraph[mcp] -``` - -This will install: -- `mcp>=1.0.0` - The MCP SDK -- `pydantic>=2.0.0` - For data validation -- `click` - For the CLI entry point -- `python-dotenv>=1.0.0` - For loading .env files - -## Usage - -### Running the MCP Server - -You can run the MCP server as a standalone process: - -```bash -tigergraph-mcp -``` - -With a custom .env file: - -```bash -tigergraph-mcp --env-file /path/to/.env -``` - -With verbose logging: - -```bash -tigergraph-mcp -v # INFO level -tigergraph-mcp -vv # DEBUG level -``` - -Or programmatically: - -```python -from pyTigerGraph.mcp import serve -import asyncio - -asyncio.run(serve()) -``` - -### Configuration - -The MCP server reads connection configuration from environment variables. You can set these either directly as environment variables or in a `.env` file. 
- -#### Using a .env File (Recommended) - -Create a `.env` file in your project directory: - -```bash -# .env -TG_HOST=http://localhost -TG_GRAPHNAME=MyGraph # Optional - can be omitted if database has multiple graphs -TG_USERNAME=tigergraph -TG_PASSWORD=tigergraph -TG_RESTPP_PORT=9000 -TG_GS_PORT=14240 -``` - -The server will automatically load the `.env` file if it exists. Environment variables take precedence over `.env` file values. - -#### Environment Variables - -The following environment variables are supported: - -- `TG_HOST` - TigerGraph host (default: http://127.0.0.1) -- `TG_GRAPHNAME` - Graph name (optional - can be omitted if database has multiple graphs. Use `tigergraph__list_graphs` tool to see available graphs) -- `TG_USERNAME` - Username (default: tigergraph) -- `TG_PASSWORD` - Password (default: tigergraph) -- `TG_SECRET` - GSQL secret (optional) -- `TG_API_TOKEN` - API token (optional) -- `TG_JWT_TOKEN` - JWT token (optional) -- `TG_RESTPP_PORT` - REST++ port (default: 9000) -- `TG_GS_PORT` - GSQL port (default: 14240) -- `TG_SSL_PORT` - SSL port (default: 443) -- `TG_TGCLOUD` - Whether using TigerGraph Cloud (default: False) -- `TG_CERT_PATH` - Path to certificate (optional) - -### Multiple Connection Profiles - -If you work with more than one TigerGraph environment — for example, development, staging, and production — you can define named profiles in your `.env` file and switch between them without changing any code. - -#### Defining profiles - -Each named profile uses a `_` prefix on the standard `TG_*` variables. Only the variables that differ from the default need to be set. 
- -```bash -# .env - -# Default profile (no prefix) — used when TG_PROFILE is not set -TG_HOST=http://localhost -TG_USERNAME=tigergraph -TG_PASSWORD=tigergraph -TG_GRAPHNAME=MyGraph - -# Staging profile -STAGING_TG_HOST=https://staging.example.com -STAGING_TG_PASSWORD=staging_secret -STAGING_TG_TGCLOUD=true - -# Production profile -PROD_TG_HOST=https://prod.example.com -PROD_TG_USERNAME=admin -PROD_TG_PASSWORD=prod_secret -PROD_TG_GRAPHNAME=ProdGraph -PROD_TG_TGCLOUD=true -``` - -Profiles are discovered automatically at startup. Any variable matching `_TG_HOST` registers a new profile. Values not set for a named profile fall back to the default profile's values. - -#### Selecting the active profile - -Pass `TG_PROFILE` as an environment variable or add it to your `.env`: - -```bash -# Switch to the staging profile for this run -TG_PROFILE=staging tigergraph-mcp - -# Or set it permanently in .env -TG_PROFILE=prod -``` - -If `TG_PROFILE` is not set, the default profile (unprefixed `TG_*` variables) is used. - -### Using with Existing Connection - -You can also use MCP with an existing `TigerGraphConnection` (sync) or `AsyncTigerGraphConnection`: - -**With Sync Connection:** -```python -from pyTigerGraph import TigerGraphConnection - -conn = TigerGraphConnection( - host="http://localhost", - graphname="MyGraph", - username="tigergraph", - password="tigergraph" -) - -# Enable MCP support for this connection -# This creates an async connection internally for MCP tools -conn.start_mcp_server() -``` - -**With Async Connection (Recommended):** -```python -from pyTigerGraph import AsyncTigerGraphConnection -from pyTigerGraph.mcp import ConnectionManager - -async with AsyncTigerGraphConnection( - host="http://localhost", - graphname="MyGraph", - username="tigergraph", - password="tigergraph", -) as conn: - # Set as default for MCP tools - ConnectionManager.set_default_connection(conn) - # ... run MCP tools ... 
-# HTTP connection pool is released on exit -``` - -This sets the connection as the default for MCP tools. Note that MCP tools use async APIs internally, so using `AsyncTigerGraphConnection` directly is more efficient. For long-lived connections without `async with`, call `await conn.aclose()` explicitly when finished. - -## Client Examples - -### Using MultiServerMCPClient - -```python -from langchain_mcp_adapters import MultiServerMCPClient -from pathlib import Path -from dotenv import dotenv_values -import asyncio - -# Load environment variables -env_dict = dotenv_values(dotenv_path=Path(".env").expanduser().resolve()) - -# Configure the client -client = MultiServerMCPClient( - { - "tigergraph-mcp": { - "transport": "stdio", - "command": "tigergraph-mcp", - "args": ["-vv"], # Enable debug logging - "env": env_dict, - }, - } -) - -# Get tools and use them -tools = asyncio.run(client.get_tools()) -# Tools are now available for use -``` - -### Using MCP Client SDK Directly - -```python -import asyncio -from mcp import ClientSession, StdioServerParameters -from mcp.client.stdio import stdio_client - -async def call_tool(): - # Configure server parameters - server_params = StdioServerParameters( - command="tigergraph-mcp", - args=["-vv"], # Enable debug logging - env=None, # Uses .env file or environment variables - ) - - async with stdio_client(server_params) as (read, write): - async with ClientSession(read, write) as session: - await session.initialize() - - # List available tools - tools = await session.list_tools() - print(f"Available tools: {[t.name for t in tools.tools]}") - - # Call a tool - result = await session.call_tool( - "tigergraph__list_graphs", - arguments={} - ) - - # Print result - for content in result.content: - print(content.text) - -asyncio.run(call_tool()) -``` - -**Note:** When using `MultiServerMCPClient` or similar MCP clients with stdio transport, the `args` parameter is required. 
For the `tigergraph-mcp` command (which is a standalone entry point), set `args` to an empty list `[]`. If you need to pass arguments to the command, include them in the list (e.g., `["-v"]` for verbose mode, `["-vv"]` for debug mode). - -## Available Tools - -The MCP server provides the following tools: - -### Global Schema Operations (Database Level) -These operations work with the global schema that spans across the entire TigerGraph database. - -- `tigergraph__get_global_schema` - Get the complete global schema (all global vertex/edge types, graphs, and members) via GSQL 'LS' command - -### Graph Operations (Database Level) -These operations manage individual graphs within the TigerGraph database. A database can contain multiple graphs. - -- `tigergraph__list_graphs` - List all graph names in the database (names only, no details) -- `tigergraph__create_graph` - Create a new graph with its schema (vertex types, edge types) -- `tigergraph__drop_graph` - Drop (delete) a graph and its schema -- `tigergraph__clear_graph_data` - Clear all data from a graph (keeps schema structure) - -### Schema Operations (Graph Level) -These operations work with the schema and objects of a specific graph. - -- `tigergraph__get_graph_schema` - Get the schema of a specific graph as structured JSON (vertex/edge types and attributes only) -- `tigergraph__show_graph_details` - Show details of a graph: schema, queries, loading jobs, data sources. 
Use `detail_type` to filter (`schema`, `query`, `loading_job`, `data_source`) or omit for all - -### Node Operations -- `tigergraph__add_node` - Add a single node -- `tigergraph__add_nodes` - Add multiple nodes -- `tigergraph__get_node` - Get a single node -- `tigergraph__get_nodes` - Get multiple nodes -- `tigergraph__delete_node` - Delete a single node -- `tigergraph__delete_nodes` - Delete multiple nodes -- `tigergraph__has_node` - Check if a node exists -- `tigergraph__get_node_edges` - Get all edges connected to a node - -### Edge Operations -- `tigergraph__add_edge` - Add a single edge -- `tigergraph__add_edges` - Add multiple edges -- `tigergraph__get_edge` - Get a single edge -- `tigergraph__get_edges` - Get multiple edges -- `tigergraph__delete_edge` - Delete a single edge -- `tigergraph__delete_edges` - Delete multiple edges -- `tigergraph__has_edge` - Check if an edge exists - -### Query Operations -- `tigergraph__run_query` - Run an interpreted query -- `tigergraph__run_installed_query` - Run an installed query -- `tigergraph__install_query` - Install a query -- `tigergraph__drop_query` - Drop (delete) an installed query -- `tigergraph__show_query` - Show query text -- `tigergraph__get_query_metadata` - Get query metadata -- `tigergraph__is_query_installed` - Check if a query is installed -- `tigergraph__get_neighbors` - Get neighbor vertices of a node - -### Loading Job Operations -- `tigergraph__create_loading_job` - Create a loading job from structured config (file mappings, node/edge mappings) -- `tigergraph__run_loading_job_with_file` - Execute a loading job with a data file -- `tigergraph__run_loading_job_with_data` - Execute a loading job with inline data string -- `tigergraph__get_loading_jobs` - Get all loading jobs for the graph -- `tigergraph__get_loading_job_status` - Get status of a specific loading job -- `tigergraph__drop_loading_job` - Drop a loading job - -### Statistics Operations -- `tigergraph__get_vertex_count` - Get vertex count -- 
`tigergraph__get_edge_count` - Get edge count -- `tigergraph__get_node_degree` - Get the degree (number of edges) of a node - -### GSQL Operations -- `tigergraph__gsql` - Execute raw GSQL command -- `tigergraph__generate_gsql` - Generate a GSQL query from a natural language description (requires LLM configuration) -- `tigergraph__generate_cypher` - Generate an openCypher query from a natural language description (requires LLM configuration) - -### Vector Schema Operations -- `tigergraph__add_vector_attribute` - Add a vector attribute to a vertex type (DIMENSION, METRIC: COSINE/L2/IP) -- `tigergraph__drop_vector_attribute` - Drop a vector attribute from a vertex type -- `tigergraph__list_vector_attributes` - List vector attributes (name, dimension, index type, data type, metric) by parsing `LS` output; optionally filter by vertex type -- `tigergraph__get_vector_index_status` - Check vector index rebuild status (Ready_for_query/Rebuild_processing) - -### Vector Data Operations -- `tigergraph__upsert_vectors` - Upsert multiple vertices with vector data using REST API (batch support) -- `tigergraph__load_vectors_from_csv` - Bulk-load vectors from a CSV/delimited file via a GSQL loading job (creates job, runs with file, drops job) -- `tigergraph__load_vectors_from_json` - Bulk-load vectors from a JSON Lines (.jsonl) file via a GSQL loading job with `JSON_FILE="true"` (creates job, runs with file, drops job) -- `tigergraph__search_top_k_similarity` - Perform vector similarity search using `vectorSearch()` function -- `tigergraph__fetch_vector` - Fetch vertices with vector data using GSQL `PRINT WITH VECTOR` - -**Note:** Vector attributes can ONLY be fetched via GSQL queries with `PRINT v WITH VECTOR;` - they cannot be retrieved via REST API. 
- -### Data Source Operations -- `tigergraph__create_data_source` - Create a new data source (S3, GCS, Azure Blob, local) -- `tigergraph__update_data_source` - Update an existing data source -- `tigergraph__get_data_source` - Get information about a data source -- `tigergraph__drop_data_source` - Drop a data source -- `tigergraph__get_all_data_sources` - Get all data sources -- `tigergraph__drop_all_data_sources` - Drop all data sources -- `tigergraph__preview_sample_data` - Preview sample data from a file - -### Discovery & Navigation -- `tigergraph__discover_tools` - Search for tools by description, use case, or keywords -- `tigergraph__get_workflow` - Get step-by-step workflow templates for common tasks (e.g., `data_loading`, `schema_creation`, `graph_exploration`) -- `tigergraph__get_tool_info` - Get detailed information about a specific tool (parameters, examples, related tools) - -## LLM-Friendly Features - -The MCP server is designed to help AI agents work effectively with TigerGraph. - -### Structured Responses - -Every tool response follows a consistent JSON structure: - -```json -{ - "success": true, - "operation": "get_node", - "summary": "Found vertex 'p123' of type 'Person'", - "data": { ... 
}, - "suggestions": [ - "View connected edges: get_node_edges(...)", - "Find neighbors: get_neighbors(...)" - ], - "metadata": { "graph_name": "MyGraph" } -} -``` - -Error responses include actionable recovery hints: - -```json -{ - "success": false, - "operation": "get_node", - "error": "Vertex not found", - "suggestions": [ - "Verify the vertex_id is correct", - "Check vertex type with show_graph_details()" - ] -} -``` - -### Rich Tool Descriptions - -Each tool includes detailed descriptions with: -- **Use cases** — when to call this tool -- **Common workflows** — step-by-step patterns -- **Tips** — best practices and gotchas -- **Warnings** — safety notes for destructive operations -- **Related tools** — what to call next - -### Token Optimization - -Responses are designed for efficient LLM token usage: -- No echoing of input parameters (the LLM already knows what it sent) -- Only returns new information (results, counts, boolean answers) -- Clean text output with no decorative formatting - -### Tool Discovery - -The MCP server includes discovery tools to help AI agents find the right tool for a task: - -```python -# 1. Discover tools for a task -result = await session.call_tool( - "tigergraph__discover_tools", - arguments={"query": "how to add data to the graph"} -) -# Returns: ranked list of relevant tools with use cases - -# 2. Get a workflow template -result = await session.call_tool( - "tigergraph__get_workflow", - arguments={"workflow_type": "data_loading"} -) -# Returns: step-by-step guide with tool calls - -# 3. 
Get detailed tool info -result = await session.call_tool( - "tigergraph__get_tool_info", - arguments={"tool_name": "tigergraph__add_node"} -) -# Returns: full documentation, examples, related tools -``` - -## Notes - -- **Transport**: The MCP server uses stdio transport by default -- **Error Detection**: GSQL operations include error detection for syntax and semantic errors (since `conn.gsql()` does not raise Python exceptions for GSQL failures) -- **Connection Management**: Connections are pooled by profile and reused across requests — no TCP handshake overhead per tool call. The pool is released automatically at server shutdown. -- **Performance**: Persistent HTTP connection pool per profile; async non-blocking I/O; `v.outdegree()` for O(1) degree counting; batch operations for multiple vertices/edges - -## Backward Compatibility - -All existing pyTigerGraph APIs continue to work as before. MCP support is completely optional and does not affect existing code. The MCP functionality is only available when: - -1. The `mcp` extra is installed -2. You explicitly use MCP-related imports or methods diff --git a/pyTigerGraph/mcp/__init__.py b/pyTigerGraph/mcp/__init__.py index b5b4dbd8..da6dd70b 100644 --- a/pyTigerGraph/mcp/__init__.py +++ b/pyTigerGraph/mcp/__init__.py @@ -1,18 +1,44 @@ # Copyright 2025 TigerGraph Inc. # Licensed under the Apache License, Version 2.0. # See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 -# -# Permission is granted to use, copy, modify, and distribute this software -# under the License. The software is provided "AS IS", without warranty. -"""Model Context Protocol (MCP) support for TigerGraph. +"""Deprecated MCP shim — the MCP server has moved to the `tigergraph-mcp` package. -This module provides MCP server capabilities for TigerGraph, allowing -AI agents to interact with TigerGraph through the Model Context Protocol. 
+Install the standalone package:: + + pip install tigergraph-mcp + +Or continue using the convenience alias (which installs `tigergraph-mcp` automatically):: + + pip install pyTigerGraph[mcp] + +Update your imports:: + + # Old + from pyTigerGraph.mcp import serve, MCPServer, ConnectionManager + + # New + from tigergraph_mcp import serve, MCPServer, ConnectionManager """ -from .server import serve, MCPServer -from .connection_manager import get_connection, ConnectionManager +import warnings + +warnings.warn( + "pyTigerGraph.mcp is deprecated and will be removed in a future release. " + "The MCP server now lives in the 'tigergraph-mcp' package. " + "Install it with: pip install tigergraph-mcp " + "Update imports from 'pyTigerGraph.mcp' to 'tigergraph_mcp'.", + DeprecationWarning, + stacklevel=2, +) + +try: + from tigergraph_mcp import serve, MCPServer, get_connection, ConnectionManager # noqa: F401 +except ImportError as e: + raise ImportError( + "Could not import 'tigergraph_mcp'. " + "Install it with: pip install tigergraph-mcp" + ) from e __all__ = [ "serve", @@ -20,4 +46,3 @@ "get_connection", "ConnectionManager", ] - diff --git a/pyTigerGraph/mcp/connection_manager.py b/pyTigerGraph/mcp/connection_manager.py deleted file mode 100644 index b428ccba..00000000 --- a/pyTigerGraph/mcp/connection_manager.py +++ /dev/null @@ -1,286 +0,0 @@ -# Copyright 2025 TigerGraph Inc. -# Licensed under the Apache License, Version 2.0. -# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 -# -# Permission is granted to use, copy, modify, and distribute this software -# under the License. The software is provided "AS IS", without warranty. - -"""Connection manager for MCP server. - -Manages AsyncTigerGraphConnection instances for MCP tools. -Supports named connection profiles via environment variables: - - - Default profile uses unprefixed ``TG_*`` vars (backward compatible). - - Named profiles use ``_TG_*`` vars (e.g. ``STAGING_TG_HOST``). 
- - ``TG_PROFILE`` selects the active profile (default: ``"default"``). -""" - -import os -import logging -from pathlib import Path -from typing import Optional, Dict, Any, List -from pyTigerGraph import AsyncTigerGraphConnection -from pyTigerGraph.common.exception import TigerGraphException - -logger = logging.getLogger(__name__) - -# Try to load dotenv if available -try: - from dotenv import load_dotenv - _dotenv_available = True -except ImportError: - _dotenv_available = False - - -def _load_env_file(env_path: Optional[str] = None) -> None: - """Load environment variables from .env file if available. - - Args: - env_path: Optional path to .env file. If not provided, looks for .env in current directory. - """ - if not _dotenv_available: - return - - if env_path: - env_file = Path(env_path).expanduser().resolve() - else: - # Look for .env in current directory and parent directories - current_dir = Path.cwd() - env_file = None - for directory in [current_dir] + list(current_dir.parents): - potential_env = directory / ".env" - if potential_env.exists(): - env_file = potential_env - break - - if env_file is None: - # Also check in the directory where the script is running - env_file = Path(".env") - - if env_file and env_file.exists(): - load_dotenv(env_file, override=False) # Don't override existing env vars - logger.debug(f"Loaded environment variables from {env_file}") - elif env_path: - logger.warning(f"Specified .env file not found: {env_path}") - - -def _get_env_for_profile(profile: str, key: str, default: str = "") -> str: - """Resolve a config value for a profile. - - Default profile uses unprefixed ``TG_*`` vars. - Named profiles use ``_TG_*`` vars, falling back to - the unprefixed ``TG_*`` var, then the built-in *default*. 
- """ - if profile == "default": - return os.getenv(f"TG_{key}", default) - return os.getenv( - f"{profile.upper()}_TG_{key}", - os.getenv(f"TG_{key}", default), - ) - - -class ConnectionManager: - """Manages TigerGraph connections for MCP tools. - - Connections are pooled by ``profile:graph_name`` key so that - repeated calls with the same profile reuse the same connection. - - Call ``await ConnectionManager.close_all()`` at server shutdown to release - the persistent HTTP connection pools held by each ``AsyncTigerGraphConnection``. - """ - - _connection_pool: Dict[str, AsyncTigerGraphConnection] = {} - _profiles: set = set() - - # Keep legacy single-connection reference for backward compat - _default_connection: Optional[AsyncTigerGraphConnection] = None - - @classmethod - def load_profiles(cls, env_path: Optional[str] = None) -> None: - """Discover available profiles from environment variables. - - Profiles are detected by scanning for ``_TG_HOST`` env vars. - The ``"default"`` profile always exists and uses unprefixed ``TG_*`` - vars. Called once at server startup. 
- """ - _load_env_file(env_path) - - for key in os.environ: - if key.endswith("_TG_HOST") and not key.startswith("TG_"): - profile = key.rsplit("_TG_HOST", 1)[0].lower() - cls._profiles.add(profile) - - cls._profiles.add("default") - logger.info(f"Discovered connection profiles: {sorted(cls._profiles)}") - - @classmethod - def list_profiles(cls) -> List[str]: - """Return sorted list of discovered profile names.""" - if not cls._profiles: - cls._profiles.add("default") - return sorted(cls._profiles) - - @classmethod - def get_default_connection(cls) -> Optional[AsyncTigerGraphConnection]: - """Get the default connection instance (backward compat).""" - return cls._default_connection - - @classmethod - def set_default_connection(cls, conn: AsyncTigerGraphConnection) -> None: - """Set the default connection instance (backward compat).""" - cls._default_connection = conn - - @classmethod - async def close_all(cls) -> None: - """Close all pooled connections and release their HTTP sockets. - - Call this at server/application shutdown to drain keep-alive connections - gracefully. Connections are removed from the pool after closing so that - subsequent calls to get_connection_for_profile() create fresh sessions. - - Example: - ```python - # In an MCP server lifespan or FastAPI shutdown event: - await ConnectionManager.close_all() - ``` - """ - for conn in list(cls._connection_pool.values()): - await conn.aclose() - cls._connection_pool.clear() - cls._profiles.clear() - cls._default_connection = None - - @classmethod - def get_connection_for_profile( - cls, - profile: str = "default", - graph_name: Optional[str] = None, - ) -> AsyncTigerGraphConnection: - """Get or create a connection for the given profile and optional graph. - - Connections are cached by ``profile`` (or ``profile:graph_name`` when - a graph_name override is given). 
If a cached connection exists but the - caller passes a different ``graph_name``, the graphname attribute on - the cached connection is updated in place. - """ - cache_key = profile - - if cache_key in cls._connection_pool: - conn = cls._connection_pool[cache_key] - if graph_name and conn.graphname != graph_name: - conn.graphname = graph_name - return conn - - host = _get_env_for_profile(profile, "HOST", "http://127.0.0.1") - graphname = graph_name or _get_env_for_profile(profile, "GRAPHNAME", "") - username = _get_env_for_profile(profile, "USERNAME", "tigergraph") - password = _get_env_for_profile(profile, "PASSWORD", "tigergraph") - gsql_secret = _get_env_for_profile(profile, "SECRET", "") - api_token = _get_env_for_profile(profile, "API_TOKEN", "") - jwt_token = _get_env_for_profile(profile, "JWT_TOKEN", "") - restpp_port = _get_env_for_profile(profile, "RESTPP_PORT", "9000") - gs_port = _get_env_for_profile(profile, "GS_PORT", "14240") - ssl_port = _get_env_for_profile(profile, "SSL_PORT", "443") - tg_cloud = _get_env_for_profile(profile, "TGCLOUD", "false").lower() == "true" - cert_path = _get_env_for_profile(profile, "CERT_PATH", "") or None - conn = AsyncTigerGraphConnection( - host=host, - graphname=graphname, - username=username, - password=password, - gsqlSecret=gsql_secret if gsql_secret else "", - apiToken=api_token if api_token else "", - jwtToken=jwt_token if jwt_token else "", - restppPort=restpp_port, - gsPort=gs_port, - sslPort=ssl_port, - tgCloud=tg_cloud, - certPath=cert_path, - ) - - cls._connection_pool[cache_key] = conn - - if profile == "default": - cls._default_connection = conn - - logger.info(f"Created connection for profile '{profile}' -> {host}") - return conn - - @classmethod - def get_profile_info(cls, profile: str = "default") -> Dict[str, str]: - """Return non-sensitive connection info for a profile. - - Never includes password, secret, or tokens. 
- """ - return { - "profile": profile, - "host": _get_env_for_profile(profile, "HOST", "http://127.0.0.1"), - "graphname": _get_env_for_profile(profile, "GRAPHNAME", ""), - "username": _get_env_for_profile(profile, "USERNAME", "tigergraph"), - "restpp_port": _get_env_for_profile(profile, "RESTPP_PORT", "9000"), - "gs_port": _get_env_for_profile(profile, "GS_PORT", "14240"), - "tgcloud": _get_env_for_profile(profile, "TGCLOUD", "false"), - } - - @classmethod - def create_connection_from_env(cls, env_path: Optional[str] = None) -> AsyncTigerGraphConnection: - """Create a connection from environment variables (backward compat). - - Equivalent to ``get_connection_for_profile("default")``. - """ - _load_env_file(env_path) - return cls.get_connection_for_profile("default") - - @classmethod - async def close_all(cls) -> None: - """Close all pooled connections and release their HTTP connection pools. - - Call at server shutdown to cleanly drain open sockets held by the - persistent ``aiohttp.ClientSession`` inside each connection. - """ - for key, conn in list(cls._connection_pool.items()): - try: - await conn.aclose() - logger.debug(f"Closed connection for profile '{key}'") - except Exception as e: - logger.warning(f"Error closing connection '{key}': {e}") - cls._connection_pool.clear() - cls._default_connection = None - - -def get_connection( - profile: Optional[str] = None, - graph_name: Optional[str] = None, - connection_config: Optional[Dict[str, Any]] = None, -) -> AsyncTigerGraphConnection: - """Get or create an async TigerGraph connection. - - Args: - profile: Connection profile name. Falls back to ``TG_PROFILE`` env var, - then ``"default"``. - graph_name: Graph name override. If provided, updates the connection's - active graph. - connection_config: Explicit connection config dict. If provided, creates - a one-off connection (not pooled). - - Returns: - AsyncTigerGraphConnection instance. 
- """ - if connection_config: - return AsyncTigerGraphConnection( - host=connection_config.get("host", "http://127.0.0.1"), - graphname=connection_config.get("graphname", graph_name or ""), - username=connection_config.get("username", "tigergraph"), - password=connection_config.get("password", "tigergraph"), - gsqlSecret=connection_config.get("gsqlSecret", ""), - apiToken=connection_config.get("apiToken", ""), - jwtToken=connection_config.get("jwtToken", ""), - restppPort=connection_config.get("restppPort", "9000"), - gsPort=connection_config.get("gsPort", "14240"), - sslPort=connection_config.get("sslPort", "443"), - tgCloud=connection_config.get("tgCloud", False), - certPath=connection_config.get("certPath", None), - ) - - effective_profile = profile or os.getenv("TG_PROFILE", "default") - return ConnectionManager.get_connection_for_profile(effective_profile, graph_name) diff --git a/pyTigerGraph/mcp/main.py b/pyTigerGraph/mcp/main.py deleted file mode 100644 index 45cb54db..00000000 --- a/pyTigerGraph/mcp/main.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2025 TigerGraph Inc. -# Licensed under the Apache License, Version 2.0. -# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 -# -# Permission is granted to use, copy, modify, and distribute this software -# under the License. The software is provided "AS IS", without warranty. 
- -"""Main entry point for TigerGraph MCP server.""" - -import logging -import sys -import click -import asyncio -from pathlib import Path - -from .server import serve - - -@click.command() -@click.option("-v", "--verbose", count=True) -@click.option("--env-file", type=click.Path(exists=True, path_type=Path), default=None, - help="Path to .env file (default: searches for .env in current and parent directories)") -def main(verbose: bool, env_file: Path = None) -> None: - """TigerGraph MCP Server - TigerGraph functionality for MCP - - The server will automatically load environment variables from a .env file - if python-dotenv is installed and a .env file is found. - """ - - logging_level = logging.WARN - if verbose == 1: - logging_level = logging.INFO - elif verbose >= 2: - logging_level = logging.DEBUG - - logging.basicConfig(level=logging_level, stream=sys.stderr) - - # Ensure mcp.server.lowlevel.server respects the WARNING level - logging.getLogger('mcp.server.lowlevel.server').setLevel(logging.WARNING) - - # Load .env file and discover connection profiles - from .connection_manager import ConnectionManager - ConnectionManager.load_profiles(env_path=str(env_file) if env_file else None) - - asyncio.run(serve()) - - -if __name__ == "__main__": - main() - diff --git a/pyTigerGraph/mcp/response_formatter.py b/pyTigerGraph/mcp/response_formatter.py deleted file mode 100644 index 00b99f5e..00000000 --- a/pyTigerGraph/mcp/response_formatter.py +++ /dev/null @@ -1,315 +0,0 @@ -# Copyright 2025 TigerGraph Inc. -# Licensed under the Apache License, Version 2.0. -# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 -# -# Permission is granted to use, copy, modify, and distribute this software -# under the License. The software is provided "AS IS", without warranty. - -"""Structured response formatting for MCP tools. - -This module provides utilities for creating consistent, LLM-friendly responses -from MCP tools. 
It ensures responses are both machine-readable and human-friendly. -""" - -import json -from typing import Any, Dict, List, Optional -from datetime import datetime -from pydantic import BaseModel -from mcp.types import TextContent - - -class ToolResponse(BaseModel): - """Structured response format for all MCP tools. - - This format provides: - - Clear success/failure indication - - Structured data for parsing - - Human-readable summary - - Contextual suggestions for next steps - - Rich metadata - """ - success: bool - operation: str - data: Optional[Dict[str, Any]] = None - summary: str - metadata: Optional[Dict[str, Any]] = None - suggestions: Optional[List[str]] = None - error: Optional[str] = None - error_code: Optional[str] = None - timestamp: str = None - - def __init__(self, **data): - if 'timestamp' not in data: - data['timestamp'] = datetime.utcnow().isoformat() + 'Z' - super().__init__(**data) - - -def format_response( - success: bool, - operation: str, - summary: str, - data: Optional[Dict[str, Any]] = None, - suggestions: Optional[List[str]] = None, - error: Optional[str] = None, - error_code: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, -) -> List[TextContent]: - """Create a structured response for MCP tools. - - Args: - success: Whether the operation succeeded - operation: Name of the operation (tool name without prefix) - summary: Human-readable summary message - data: Structured result data - suggestions: List of suggested next steps or actions - error: Error message if success=False - error_code: Optional error code for categorization - metadata: Additional context (graph_name, timing, etc.) - - Returns: - List of TextContent with both JSON and formatted text - - Example: - >>> format_response( - ... success=True, - ... operation="add_node", - ... summary="Node added successfully", - ... data={"vertex_id": "user1", "vertex_type": "Person"}, - ... suggestions=["Use 'get_node' to verify", "Use 'add_edge' to connect"] - ... 
) - """ - - response = ToolResponse( - success=success, - operation=operation, - summary=summary, - data=data, - suggestions=suggestions, - error=error, - error_code=error_code, - metadata=metadata - ) - - # Create structured JSON output - json_output = response.model_dump_json(indent=2, exclude_none=True) - - # Create human-readable format - text_parts = [f"**{summary}**"] - - # Add data section - if data: - text_parts.append(f"\n**Data:**\n```json\n{json.dumps(data, indent=2, default=str)}\n```") - - # Add suggestions - if suggestions and len(suggestions) > 0: - text_parts.append("\n**💡 Suggestions:**") - for i, suggestion in enumerate(suggestions, 1): - text_parts.append(f"{i}. {suggestion}") - - # Add error details - if error: - text_parts.append(f"\n**❌ Error Details:**\n{error}") - if error_code: - text_parts.append(f"\n**Error Code:** {error_code}") - - # Add metadata footer - if metadata: - text_parts.append(f"\n**Metadata:** {json.dumps(metadata, default=str)}") - - text_output = "\n".join(text_parts) - - # Combine both formats - full_output = f"```json\n{json_output}\n```\n\n{text_output}" - - return [TextContent(type="text", text=full_output)] - - -def format_success( - operation: str, - summary: str, - data: Optional[Dict[str, Any]] = None, - suggestions: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, -) -> List[TextContent]: - """Convenience method for successful operations.""" - return format_response( - success=True, - operation=operation, - summary=summary, - data=data, - suggestions=suggestions, - metadata=metadata - ) - - -def format_error( - operation: str, - error: Exception, - context: Optional[Dict[str, Any]] = None, - suggestions: Optional[List[str]] = None, -) -> List[TextContent]: - """Format an error response with contextual recovery hints. - - Args: - operation: Name of the failed operation - error: The exception that occurred - context: Context information (parameters, state, etc.) 
- suggestions: Optional manual suggestions (auto-generated if not provided) - - Returns: - Formatted error response with recovery hints - """ - - error_str = str(error) - error_lower = error_str.lower() - - # Auto-generate suggestions based on error type if not provided - if suggestions is None: - suggestions = [] - - # Schema/type errors - if any(term in error_lower for term in ["vertex type", "edge type", "type not found"]): - suggestions.extend([ - "The specified type may not exist in the schema", - "Call 'show_graph_details' to see available vertex and edge types", - "Call 'list_graphs' to ensure you're using the correct graph" - ]) - - # Attribute errors - elif any(term in error_lower for term in ["attribute", "column", "field"]): - suggestions.extend([ - "One or more attributes may not match the schema definition", - "Call 'show_graph_details' to see required attributes and their types", - "Check that attribute names are spelled correctly" - ]) - - # Connection errors - elif any(term in error_lower for term in ["connection", "timeout", "unreachable"]): - suggestions.extend([ - "Unable to connect to TigerGraph server", - "Verify TG_HOST environment variable is correct", - "Check network connectivity and firewall settings", - "Ensure TigerGraph server is running" - ]) - - # Authentication errors - elif any(term in error_lower for term in ["auth", "token", "permission", "forbidden"]): - suggestions.extend([ - "Authentication failed - check credentials", - "Verify TG_USERNAME and TG_PASSWORD environment variables", - "For TigerGraph Cloud, ensure TG_API_TOKEN is set", - "Check if user has required permissions for this operation" - ]) - - # Query errors - elif any(term in error_lower for term in ["syntax", "parse", "query"]): - suggestions.extend([ - "Query syntax error detected", - "For GSQL: Use 'INTERPRET QUERY () FOR GRAPH { ... }'", - "For Cypher: Use 'INTERPRET OPENCYPHER QUERY () FOR GRAPH { ... 
}'", - "Call 'show_graph_details' to understand the schema before writing queries" - ]) - - # Vector errors - elif any(term in error_lower for term in ["vector", "dimension", "embedding"]): - suggestions.extend([ - "Vector operation error", - "Ensure vector dimensions match the attribute definition", - "Call 'get_vector_index_status' to check if index is ready", - "Verify vector attribute exists with 'show_graph_details'" - ]) - - # Generic suggestions - if len(suggestions) == 0: - suggestions.extend([ - "Check the error message for specific details", - "Call 'show_graph_details' to understand the current graph structure", - "Verify all required parameters are provided correctly" - ]) - - # Determine error code - error_code = None - if "connection" in error_lower or "timeout" in error_lower: - error_code = "CONNECTION_ERROR" - elif "auth" in error_lower or "permission" in error_lower: - error_code = "AUTHENTICATION_ERROR" - elif "type" in error_lower: - error_code = "SCHEMA_ERROR" - elif "attribute" in error_lower: - error_code = "ATTRIBUTE_ERROR" - elif "syntax" in error_lower or "parse" in error_lower: - error_code = "SYNTAX_ERROR" - else: - error_code = "OPERATION_ERROR" - - return format_response( - success=False, - operation=operation, - summary=f"❌ Failed to {operation.replace('_', ' ')}", - error=error_str, - error_code=error_code, - metadata=context, - suggestions=suggestions - ) - - -def gsql_has_error(result_str: str) -> bool: - """Check whether a GSQL result string indicates a failure. - - ``conn.gsql()`` does **not** raise an exception when a GSQL command fails; - instead, the error message is returned as a plain string. This helper - inspects the result for well-known error patterns so callers can - distinguish success from failure. 
- """ - error_patterns = [ - "Encountered \"", - "SEMANTIC ERROR", - "Syntax Error", - "Failed to create", - "does not exist", - "is not a valid", - "already exists", - "Invalid syntax", - ] - return any(p in result_str for p in error_patterns) - - -def format_list_response( - operation: str, - items: List[Any], - item_type: str = "items", - summary_template: Optional[str] = None, - suggestions: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, -) -> List[TextContent]: - """Format a response containing a list of items. - - Args: - operation: Name of the operation - items: List of items to return - item_type: Type of items (for summary message) - summary_template: Optional custom summary (use {count} and {type} placeholders) - suggestions: Optional suggestions - metadata: Optional metadata - - Returns: - Formatted response - """ - - count = len(items) - - if summary_template: - summary = summary_template.format(count=count, type=item_type) - else: - summary = f"✅ Found {count} {item_type}" - - return format_success( - operation=operation, - summary=summary, - data={ - "count": count, - item_type: items - }, - suggestions=suggestions, - metadata=metadata - ) diff --git a/pyTigerGraph/mcp/server.py b/pyTigerGraph/mcp/server.py deleted file mode 100644 index 8f9745d5..00000000 --- a/pyTigerGraph/mcp/server.py +++ /dev/null @@ -1,285 +0,0 @@ -# Copyright 2025 TigerGraph Inc. -# Licensed under the Apache License, Version 2.0. -# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 -# -# Permission is granted to use, copy, modify, and distribute this software -# under the License. The software is provided "AS IS", without warranty. 
- -"""MCP Server implementation for TigerGraph.""" - -import logging -from typing import Dict, List -from mcp.server import Server -from mcp.server.stdio import stdio_server -from mcp.types import Tool, TextContent - -from .tool_names import TigerGraphToolName -from pyTigerGraph.common.exception import TigerGraphException -from .tools import ( - get_all_tools, - # Connection profile operations - list_connections, - show_connection, - # Global schema operations (database level) - get_global_schema, - # Graph operations (database level) - list_graphs, - create_graph, - drop_graph, - clear_graph_data, - # Schema operations (graph level) - get_graph_schema, - show_graph_details, - # Node tools - add_node, - add_nodes, - get_node, - get_nodes, - delete_node, - delete_nodes, - has_node, - get_node_edges, - # Edge tools - add_edge, - add_edges, - get_edge, - get_edges, - delete_edge, - delete_edges, - has_edge, - # Query tools - run_query, - run_installed_query, - install_query, - drop_query, - show_query, - get_query_metadata, - is_query_installed, - get_neighbors, - # Loading job tools - create_loading_job, - run_loading_job_with_file, - run_loading_job_with_data, - get_loading_jobs, - get_loading_job_status, - drop_loading_job, - # Statistics tools - get_vertex_count, - get_edge_count, - get_node_degree, - # GSQL tools - gsql, - generate_gsql, - generate_cypher, - # Vector schema tools - add_vector_attribute, - drop_vector_attribute, - list_vector_attributes, - get_vector_index_status, - # Vector data tools - upsert_vectors, - load_vectors_from_csv, - load_vectors_from_json, - search_top_k_similarity, - fetch_vector, - # Data Source tools - create_data_source, - update_data_source, - get_data_source, - drop_data_source, - get_all_data_sources, - drop_all_data_sources, - preview_sample_data, - # Discovery tools - discover_tools, - get_workflow, - get_tool_info, -) - -logger = logging.getLogger(__name__) - - -class MCPServer: - """MCP Server for TigerGraph.""" - - def 
__init__(self, name: str = "TigerGraph-MCP"): - """Initialize the MCP server.""" - self.server = Server(name) - self._setup_handlers() - - def _setup_handlers(self): - """Setup MCP server handlers.""" - - @self.server.list_tools() - async def list_tools() -> List[Tool]: - """List all available tools.""" - return get_all_tools() - - @self.server.call_tool() - async def call_tool(name: str, arguments: Dict) -> List[TextContent]: - """Handle tool calls.""" - try: - match name: - # Connection profile operations - case TigerGraphToolName.LIST_CONNECTIONS: - return await list_connections(**arguments) - case TigerGraphToolName.SHOW_CONNECTION: - return await show_connection(**arguments) - # Global schema operations (database level) - case TigerGraphToolName.GET_GLOBAL_SCHEMA: - return await get_global_schema(**arguments) - # Graph operations (database level) - case TigerGraphToolName.LIST_GRAPHS: - return await list_graphs(**arguments) - case TigerGraphToolName.CREATE_GRAPH: - return await create_graph(**arguments) - case TigerGraphToolName.DROP_GRAPH: - return await drop_graph(**arguments) - case TigerGraphToolName.CLEAR_GRAPH_DATA: - return await clear_graph_data(**arguments) - # Schema operations (graph level) - case TigerGraphToolName.GET_GRAPH_SCHEMA: - return await get_graph_schema(**arguments) - case TigerGraphToolName.SHOW_GRAPH_DETAILS: - return await show_graph_details(**arguments) - # Node operations - case TigerGraphToolName.ADD_NODE: - return await add_node(**arguments) - case TigerGraphToolName.ADD_NODES: - return await add_nodes(**arguments) - case TigerGraphToolName.GET_NODE: - return await get_node(**arguments) - case TigerGraphToolName.GET_NODES: - return await get_nodes(**arguments) - case TigerGraphToolName.DELETE_NODE: - return await delete_node(**arguments) - case TigerGraphToolName.DELETE_NODES: - return await delete_nodes(**arguments) - case TigerGraphToolName.HAS_NODE: - return await has_node(**arguments) - case TigerGraphToolName.GET_NODE_EDGES: 
- return await get_node_edges(**arguments) - # Edge operations - case TigerGraphToolName.ADD_EDGE: - return await add_edge(**arguments) - case TigerGraphToolName.ADD_EDGES: - return await add_edges(**arguments) - case TigerGraphToolName.GET_EDGE: - return await get_edge(**arguments) - case TigerGraphToolName.GET_EDGES: - return await get_edges(**arguments) - case TigerGraphToolName.DELETE_EDGE: - return await delete_edge(**arguments) - case TigerGraphToolName.DELETE_EDGES: - return await delete_edges(**arguments) - case TigerGraphToolName.HAS_EDGE: - return await has_edge(**arguments) - # Query operations - case TigerGraphToolName.RUN_QUERY: - return await run_query(**arguments) - case TigerGraphToolName.RUN_INSTALLED_QUERY: - return await run_installed_query(**arguments) - case TigerGraphToolName.INSTALL_QUERY: - return await install_query(**arguments) - case TigerGraphToolName.DROP_QUERY: - return await drop_query(**arguments) - case TigerGraphToolName.SHOW_QUERY: - return await show_query(**arguments) - case TigerGraphToolName.GET_QUERY_METADATA: - return await get_query_metadata(**arguments) - case TigerGraphToolName.IS_QUERY_INSTALLED: - return await is_query_installed(**arguments) - case TigerGraphToolName.GET_NEIGHBORS: - return await get_neighbors(**arguments) - # Loading job operations - case TigerGraphToolName.CREATE_LOADING_JOB: - return await create_loading_job(**arguments) - case TigerGraphToolName.RUN_LOADING_JOB_WITH_FILE: - return await run_loading_job_with_file(**arguments) - case TigerGraphToolName.RUN_LOADING_JOB_WITH_DATA: - return await run_loading_job_with_data(**arguments) - case TigerGraphToolName.GET_LOADING_JOBS: - return await get_loading_jobs(**arguments) - case TigerGraphToolName.GET_LOADING_JOB_STATUS: - return await get_loading_job_status(**arguments) - case TigerGraphToolName.DROP_LOADING_JOB: - return await drop_loading_job(**arguments) - # Statistics operations - case TigerGraphToolName.GET_VERTEX_COUNT: - return await 
get_vertex_count(**arguments) - case TigerGraphToolName.GET_EDGE_COUNT: - return await get_edge_count(**arguments) - case TigerGraphToolName.GET_NODE_DEGREE: - return await get_node_degree(**arguments) - # GSQL operations - case TigerGraphToolName.GSQL: - return await gsql(**arguments) - case TigerGraphToolName.GENERATE_GSQL: - return await generate_gsql(**arguments) - case TigerGraphToolName.GENERATE_CYPHER: - return await generate_cypher(**arguments) - # Vector schema operations - case TigerGraphToolName.ADD_VECTOR_ATTRIBUTE: - return await add_vector_attribute(**arguments) - case TigerGraphToolName.DROP_VECTOR_ATTRIBUTE: - return await drop_vector_attribute(**arguments) - case TigerGraphToolName.LIST_VECTOR_ATTRIBUTES: - return await list_vector_attributes(**arguments) - case TigerGraphToolName.GET_VECTOR_INDEX_STATUS: - return await get_vector_index_status(**arguments) - # Vector data operations - case TigerGraphToolName.UPSERT_VECTORS: - return await upsert_vectors(**arguments) - case TigerGraphToolName.LOAD_VECTORS_FROM_CSV: - return await load_vectors_from_csv(**arguments) - case TigerGraphToolName.LOAD_VECTORS_FROM_JSON: - return await load_vectors_from_json(**arguments) - case TigerGraphToolName.SEARCH_TOP_K_SIMILARITY: - return await search_top_k_similarity(**arguments) - case TigerGraphToolName.FETCH_VECTOR: - return await fetch_vector(**arguments) - # Data Source operations - case TigerGraphToolName.CREATE_DATA_SOURCE: - return await create_data_source(**arguments) - case TigerGraphToolName.UPDATE_DATA_SOURCE: - return await update_data_source(**arguments) - case TigerGraphToolName.GET_DATA_SOURCE: - return await get_data_source(**arguments) - case TigerGraphToolName.DROP_DATA_SOURCE: - return await drop_data_source(**arguments) - case TigerGraphToolName.GET_ALL_DATA_SOURCES: - return await get_all_data_sources(**arguments) - case TigerGraphToolName.DROP_ALL_DATA_SOURCES: - return await drop_all_data_sources(**arguments) - case 
TigerGraphToolName.PREVIEW_SAMPLE_DATA: - return await preview_sample_data(**arguments) - # Discovery operations - case TigerGraphToolName.DISCOVER_TOOLS: - return await discover_tools(**arguments) - case TigerGraphToolName.GET_WORKFLOW: - return await get_workflow(**arguments) - case TigerGraphToolName.GET_TOOL_INFO: - return await get_tool_info(**arguments) - case _: - raise ValueError(f"Unknown tool: {name}") - except TigerGraphException as e: - logger.exception("Error in tool execution") - error_msg = e.message if hasattr(e, 'message') else str(e) - error_code = f" (Code: {e.code})" if hasattr(e, 'code') and e.code else "" - return [TextContent(type="text", text=f"❌ TigerGraph Error{error_code} due to: {error_msg}")] - except Exception as e: - logger.exception("Error in tool execution") - return [TextContent(type="text", text=f"❌ Error due to: {str(e)}")] - - -async def serve() -> None: - """Serve the MCP server.""" - from .connection_manager import ConnectionManager - server = MCPServer() - options = server.server.create_initialization_options() - try: - async with stdio_server() as (read_stream, write_stream): - await server.server.run(read_stream, write_stream, options, raise_exceptions=True) - finally: - await ConnectionManager.close_all() - diff --git a/pyTigerGraph/mcp/tool_metadata.py b/pyTigerGraph/mcp/tool_metadata.py deleted file mode 100644 index 409e5ddb..00000000 --- a/pyTigerGraph/mcp/tool_metadata.py +++ /dev/null @@ -1,528 +0,0 @@ -# Copyright 2025 TigerGraph Inc. -# Licensed under the Apache License, Version 2.0. -# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 -# -# Permission is granted to use, copy, modify, and distribute this software -# under the License. The software is provided "AS IS", without warranty. 
- -"""Tool metadata for enhanced LLM guidance.""" - -from typing import List, Dict, Any, Optional -from pydantic import BaseModel -from enum import Enum - - -class ToolCategory(str, Enum): - """Categories for organizing tools.""" - SCHEMA = "schema" - DATA = "data" - QUERY = "query" - VECTOR = "vector" - LOADING = "loading" - DISCOVERY = "discovery" - UTILITY = "utility" - - -class ToolMetadata(BaseModel): - """Enhanced metadata for tools to help LLMs understand usage patterns.""" - category: ToolCategory - prerequisites: List[str] = [] - related_tools: List[str] = [] - common_next_steps: List[str] = [] - use_cases: List[str] = [] - complexity: str = "basic" # basic, intermediate, advanced - examples: List[Dict[str, Any]] = [] - keywords: List[str] = [] # For discovery - - -# Define metadata for each tool -TOOL_METADATA: Dict[str, ToolMetadata] = { - # Schema Operations - "tigergraph__show_graph_details": ToolMetadata( - category=ToolCategory.SCHEMA, - prerequisites=[], - related_tools=["tigergraph__get_graph_schema"], - common_next_steps=["tigergraph__add_node", "tigergraph__add_edge", "tigergraph__run_query"], - use_cases=[ - "Getting a full listing of a graph (schema, queries, jobs)", - "Understanding the structure of a graph before writing queries", - "Discovering available vertex and edge types", - "First step in any graph interaction workflow" - ], - complexity="basic", - keywords=["schema", "structure", "show", "understand", "explore", "queries", "jobs"], - examples=[ - { - "description": "Show everything under default graph", - "parameters": {} - }, - { - "description": "Show everything under a specific graph", - "parameters": {"graph_name": "SocialGraph"} - } - ] - ), - - "tigergraph__list_graphs": ToolMetadata( - category=ToolCategory.SCHEMA, - prerequisites=[], - related_tools=["tigergraph__show_graph_details", "tigergraph__create_graph"], - common_next_steps=["tigergraph__show_graph_details"], - use_cases=[ - "Discovering what graphs exist in the 
database", - "First step when connecting to a new TigerGraph instance", - "Verifying a graph was created successfully" - ], - complexity="basic", - keywords=["list", "graphs", "discover", "available"], - examples=[{"description": "List all graphs", "parameters": {}}] - ), - - "tigergraph__create_graph": ToolMetadata( - category=ToolCategory.SCHEMA, - prerequisites=[], - related_tools=["tigergraph__list_graphs", "tigergraph__show_graph_details"], - common_next_steps=["tigergraph__show_graph_details", "tigergraph__add_node"], - use_cases=[ - "Creating a new graph from scratch", - "Setting up a graph with specific vertex and edge types", - "Initializing a new project or data model" - ], - complexity="intermediate", - keywords=["create", "new", "graph", "initialize", "setup"], - examples=[ - { - "description": "Create a social network graph", - "parameters": { - "graph_name": "SocialGraph", - "vertex_types": [ - { - "name": "Person", - "attributes": [ - {"name": "name", "type": "STRING"}, - {"name": "age", "type": "INT"} - ] - } - ], - "edge_types": [ - { - "name": "FOLLOWS", - "from_vertex": "Person", - "to_vertex": "Person" - } - ] - } - } - ] - ), - - "tigergraph__get_graph_schema": ToolMetadata( - category=ToolCategory.SCHEMA, - prerequisites=[], - related_tools=["tigergraph__show_graph_details"], - common_next_steps=["tigergraph__add_node", "tigergraph__run_query"], - use_cases=[ - "Getting raw JSON schema for programmatic processing", - "Detailed schema inspection for advanced use cases" - ], - complexity="intermediate", - keywords=["schema", "json", "raw", "detailed"], - examples=[{"description": "Get raw schema", "parameters": {}}] - ), - - # Node Operations - "tigergraph__add_node": ToolMetadata( - category=ToolCategory.DATA, - prerequisites=["tigergraph__show_graph_details"], - related_tools=["tigergraph__add_nodes", "tigergraph__get_node", "tigergraph__delete_node"], - common_next_steps=["tigergraph__get_node", "tigergraph__add_edge", 
"tigergraph__get_node_edges"], - use_cases=[ - "Creating a single vertex in the graph", - "Updating an existing vertex's attributes", - "Adding individual entities (users, products, etc.)" - ], - complexity="basic", - keywords=["add", "create", "insert", "node", "vertex", "single"], - examples=[ - { - "description": "Add a person node", - "parameters": { - "vertex_type": "Person", - "vertex_id": "user123", - "attributes": {"name": "Alice", "age": 30, "city": "San Francisco"} - } - }, - { - "description": "Add a product node", - "parameters": { - "vertex_type": "Product", - "vertex_id": "prod456", - "attributes": {"name": "Laptop", "price": 999.99, "category": "Electronics"} - } - } - ] - ), - - "tigergraph__add_nodes": ToolMetadata( - category=ToolCategory.DATA, - prerequisites=["tigergraph__show_graph_details"], - related_tools=["tigergraph__add_node", "tigergraph__get_nodes"], - common_next_steps=["tigergraph__get_vertex_count", "tigergraph__add_edges"], - use_cases=[ - "Batch loading multiple vertices efficiently", - "Importing data from CSV or JSON", - "Initial data population" - ], - complexity="basic", - keywords=["add", "create", "insert", "batch", "multiple", "bulk", "nodes", "vertices"], - examples=[ - { - "description": "Add multiple person nodes", - "parameters": { - "vertex_type": "Person", - "vertices": [ - {"id": "user1", "name": "Alice", "age": 30}, - {"id": "user2", "name": "Bob", "age": 25}, - {"id": "user3", "name": "Carol", "age": 35} - ] - } - } - ] - ), - - "tigergraph__get_node": ToolMetadata( - category=ToolCategory.DATA, - prerequisites=[], - related_tools=["tigergraph__get_nodes", "tigergraph__has_node"], - common_next_steps=["tigergraph__get_node_edges", "tigergraph__delete_node"], - use_cases=[ - "Retrieving a specific vertex by ID", - "Verifying a vertex was created", - "Checking vertex attributes" - ], - complexity="basic", - keywords=["get", "retrieve", "fetch", "read", "node", "vertex", "single"], - examples=[ - { - "description": 
"Get a person node", - "parameters": { - "vertex_type": "Person", - "vertex_id": "user123" - } - } - ] - ), - - "tigergraph__get_nodes": ToolMetadata( - category=ToolCategory.DATA, - prerequisites=[], - related_tools=["tigergraph__get_node", "tigergraph__get_vertex_count"], - common_next_steps=["tigergraph__get_edges"], - use_cases=[ - "Retrieving multiple vertices of a type", - "Exploring graph data", - "Data export and analysis" - ], - complexity="basic", - keywords=["get", "retrieve", "fetch", "list", "multiple", "nodes", "vertices"], - examples=[ - { - "description": "Get all person nodes (limited)", - "parameters": { - "vertex_type": "Person", - "limit": 100 - } - } - ] - ), - - # Edge Operations - "tigergraph__add_edge": ToolMetadata( - category=ToolCategory.DATA, - prerequisites=["tigergraph__add_node", "tigergraph__show_graph_details"], - related_tools=["tigergraph__add_edges", "tigergraph__get_edge"], - common_next_steps=["tigergraph__get_node_edges", "tigergraph__get_neighbors"], - use_cases=[ - "Creating a relationship between two vertices", - "Connecting entities in the graph", - "Building graph structure" - ], - complexity="basic", - keywords=["add", "create", "connect", "relationship", "edge", "link"], - examples=[ - { - "description": "Create a friendship edge", - "parameters": { - "edge_type": "FOLLOWS", - "from_vertex_type": "Person", - "from_vertex_id": "user1", - "to_vertex_type": "Person", - "to_vertex_id": "user2", - "attributes": {"since": "2024-01-15"} - } - } - ] - ), - - "tigergraph__add_edges": ToolMetadata( - category=ToolCategory.DATA, - prerequisites=["tigergraph__add_nodes", "tigergraph__show_graph_details"], - related_tools=["tigergraph__add_edge"], - common_next_steps=["tigergraph__get_edge_count"], - use_cases=[ - "Batch loading multiple edges", - "Building graph structure efficiently", - "Importing relationship data" - ], - complexity="basic", - keywords=["add", "create", "batch", "multiple", "edges", "relationships", "bulk"], - 
examples=[] - ), - - # Query Operations - "tigergraph__run_query": ToolMetadata( - category=ToolCategory.QUERY, - prerequisites=["tigergraph__show_graph_details"], - related_tools=["tigergraph__run_installed_query", "tigergraph__get_neighbors"], - common_next_steps=[], - use_cases=[ - "Ad-hoc querying without installing", - "Testing queries before installation", - "Simple data retrieval operations", - "Running openCypher or GSQL queries" - ], - complexity="intermediate", - keywords=["query", "search", "find", "select", "interpret", "gsql", "cypher"], - examples=[ - { - "description": "Simple GSQL query", - "parameters": { - "query_text": "INTERPRET QUERY () FOR GRAPH MyGraph { SELECT v FROM Person:v LIMIT 5; PRINT v; }" - } - }, - { - "description": "openCypher query", - "parameters": { - "query_text": "INTERPRET OPENCYPHER QUERY () FOR GRAPH MyGraph { MATCH (n:Person) RETURN n LIMIT 5 }" - } - } - ] - ), - - "tigergraph__get_neighbors": ToolMetadata( - category=ToolCategory.QUERY, - prerequisites=[], - related_tools=["tigergraph__get_node_edges", "tigergraph__run_query"], - common_next_steps=[], - use_cases=[ - "Finding vertices connected to a given vertex", - "1-hop graph traversal", - "Discovering relationships" - ], - complexity="basic", - keywords=["neighbors", "connected", "adjacent", "traverse", "related"], - examples=[ - { - "description": "Get friends of a person", - "parameters": { - "vertex_type": "Person", - "vertex_id": "user1", - "edge_type": "FOLLOWS" - } - } - ] - ), - - # Vector Operations - "tigergraph__add_vector_attribute": ToolMetadata( - category=ToolCategory.VECTOR, - prerequisites=["tigergraph__show_graph_details"], - related_tools=["tigergraph__drop_vector_attribute", "tigergraph__get_vector_index_status"], - common_next_steps=["tigergraph__get_vector_index_status", "tigergraph__upsert_vectors"], - use_cases=[ - "Adding vector/embedding support to existing vertex types", - "Setting up semantic search capabilities", - "Enabling 
similarity-based queries" - ], - complexity="intermediate", - keywords=["vector", "embedding", "add", "attribute", "similarity", "semantic"], - examples=[ - { - "description": "Add embedding attribute for documents", - "parameters": { - "vertex_type": "Document", - "vector_name": "embedding", - "dimension": 384, - "metric": "COSINE" - } - }, - { - "description": "Add embedding for products (higher dimension)", - "parameters": { - "vertex_type": "Product", - "vector_name": "feature_vector", - "dimension": 1536, - "metric": "L2" - } - } - ] - ), - - "tigergraph__upsert_vectors": ToolMetadata( - category=ToolCategory.VECTOR, - prerequisites=["tigergraph__add_vector_attribute", "tigergraph__get_vector_index_status"], - related_tools=["tigergraph__search_top_k_similarity", "tigergraph__fetch_vector"], - common_next_steps=["tigergraph__get_vector_index_status", "tigergraph__search_top_k_similarity"], - use_cases=[ - "Loading embedding vectors into the graph", - "Updating vector data for vertices", - "Populating semantic search index" - ], - complexity="intermediate", - keywords=["vector", "embedding", "upsert", "load", "insert", "update"], - examples=[ - { - "description": "Upsert document embeddings", - "parameters": { - "vertex_type": "Document", - "vector_attribute": "embedding", - "vectors": [ - { - "vertex_id": "doc1", - "vector": [0.1, 0.2, 0.3], - "attributes": {"title": "Document 1"} - } - ] - } - } - ] - ), - - "tigergraph__search_top_k_similarity": ToolMetadata( - category=ToolCategory.VECTOR, - prerequisites=["tigergraph__upsert_vectors", "tigergraph__get_vector_index_status"], - related_tools=["tigergraph__fetch_vector"], - common_next_steps=[], - use_cases=[ - "Finding similar documents or items", - "Semantic search operations", - "Recommendation based on similarity" - ], - complexity="intermediate", - keywords=["vector", "search", "similarity", "nearest", "semantic", "find", "similar"], - examples=[ - { - "description": "Find similar documents", - 
"parameters": { - "vertex_type": "Document", - "vector_attribute": "embedding", - "query_vector": [0.1, 0.2, 0.3], - "top_k": 10 - } - } - ] - ), - - # Loading Operations - "tigergraph__create_loading_job": ToolMetadata( - category=ToolCategory.LOADING, - prerequisites=["tigergraph__show_graph_details"], - related_tools=["tigergraph__run_loading_job_with_file", "tigergraph__run_loading_job_with_data"], - common_next_steps=["tigergraph__run_loading_job_with_file", "tigergraph__get_loading_jobs"], - use_cases=[ - "Setting up data ingestion from CSV/JSON files", - "Defining how file columns map to vertex/edge attributes", - "Preparing for bulk data loading" - ], - complexity="advanced", - keywords=["loading", "job", "create", "define", "ingest", "import"], - examples=[] - ), - - "tigergraph__run_loading_job_with_file": ToolMetadata( - category=ToolCategory.LOADING, - prerequisites=["tigergraph__create_loading_job"], - related_tools=["tigergraph__run_loading_job_with_data", "tigergraph__get_loading_job_status"], - common_next_steps=["tigergraph__get_loading_job_status", "tigergraph__get_vertex_count"], - use_cases=[ - "Loading data from CSV or JSON files", - "Bulk import of graph data", - "ETL operations" - ], - complexity="intermediate", - keywords=["loading", "job", "run", "file", "import", "bulk"], - examples=[] - ), - - # Statistics - "tigergraph__get_vertex_count": ToolMetadata( - category=ToolCategory.UTILITY, - prerequisites=[], - related_tools=["tigergraph__get_edge_count", "tigergraph__get_nodes"], - common_next_steps=[], - use_cases=[ - "Verifying data was loaded", - "Monitoring graph size", - "Data validation" - ], - complexity="basic", - keywords=["count", "statistics", "size", "vertex", "node", "total"], - examples=[ - { - "description": "Count all vertices", - "parameters": {} - }, - { - "description": "Count specific vertex type", - "parameters": {"vertex_type": "Person"} - } - ] - ), - - "tigergraph__get_edge_count": ToolMetadata( - 
category=ToolCategory.UTILITY, - prerequisites=[], - related_tools=["tigergraph__get_vertex_count"], - common_next_steps=[], - use_cases=[ - "Verifying relationships were created", - "Monitoring graph connectivity", - "Data validation" - ], - complexity="basic", - keywords=["count", "statistics", "size", "edge", "relationship", "total"], - examples=[] - ), -} - - -def get_tool_metadata(tool_name: str) -> Optional[ToolMetadata]: - """Get metadata for a specific tool.""" - return TOOL_METADATA.get(tool_name) - - -def get_tools_by_category(category: ToolCategory) -> List[str]: - """Get all tool names in a specific category.""" - return [ - tool_name for tool_name, metadata in TOOL_METADATA.items() - if metadata.category == category - ] - - -def search_tools_by_keywords(keywords: List[str]) -> List[str]: - """Search for tools matching any of the provided keywords.""" - matching_tools = [] - keywords_lower = [k.lower() for k in keywords] - - for tool_name, metadata in TOOL_METADATA.items(): - # Check if any keyword matches - for keyword in keywords_lower: - if any(keyword in mk.lower() for mk in metadata.keywords): - matching_tools.append(tool_name) - break - # Also check in use cases - if any(keyword in uc.lower() for uc in metadata.use_cases): - matching_tools.append(tool_name) - break - - return matching_tools diff --git a/pyTigerGraph/mcp/tool_names.py b/pyTigerGraph/mcp/tool_names.py deleted file mode 100644 index 506296a6..00000000 --- a/pyTigerGraph/mcp/tool_names.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2025 TigerGraph Inc. -# Licensed under the Apache License, Version 2.0. -# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 -# -# Permission is granted to use, copy, modify, and distribute this software -# under the License. The software is provided "AS IS", without warranty. 
- -"""Tool names for TigerGraph MCP tools.""" - -from enum import Enum - - -class TigerGraphToolName(str, Enum): - """Enumeration of all available TigerGraph MCP tool names.""" - - # Global Schema Operations (Database level - operates on global schema) - GET_GLOBAL_SCHEMA = "tigergraph__get_global_schema" - - # Graph Operations (Database level - operates on graphs within the database) - LIST_GRAPHS = "tigergraph__list_graphs" - CREATE_GRAPH = "tigergraph__create_graph" - DROP_GRAPH = "tigergraph__drop_graph" - CLEAR_GRAPH_DATA = "tigergraph__clear_graph_data" - - # Schema Operations (Graph level - operates on schema within a specific graph) - GET_GRAPH_SCHEMA = "tigergraph__get_graph_schema" - SHOW_GRAPH_DETAILS = "tigergraph__show_graph_details" - - # Node Operations - ADD_NODE = "tigergraph__add_node" - ADD_NODES = "tigergraph__add_nodes" - GET_NODE = "tigergraph__get_node" - GET_NODES = "tigergraph__get_nodes" - DELETE_NODE = "tigergraph__delete_node" - DELETE_NODES = "tigergraph__delete_nodes" - HAS_NODE = "tigergraph__has_node" - GET_NODE_EDGES = "tigergraph__get_node_edges" - - # Edge Operations - ADD_EDGE = "tigergraph__add_edge" - ADD_EDGES = "tigergraph__add_edges" - GET_EDGE = "tigergraph__get_edge" - GET_EDGES = "tigergraph__get_edges" - DELETE_EDGE = "tigergraph__delete_edge" - DELETE_EDGES = "tigergraph__delete_edges" - HAS_EDGE = "tigergraph__has_edge" - - # Query Operations - RUN_QUERY = "tigergraph__run_query" - RUN_INSTALLED_QUERY = "tigergraph__run_installed_query" - INSTALL_QUERY = "tigergraph__install_query" - DROP_QUERY = "tigergraph__drop_query" - SHOW_QUERY = "tigergraph__show_query" - GET_QUERY_METADATA = "tigergraph__get_query_metadata" - IS_QUERY_INSTALLED = "tigergraph__is_query_installed" - GET_NEIGHBORS = "tigergraph__get_neighbors" - - # Loading Job Operations - CREATE_LOADING_JOB = "tigergraph__create_loading_job" - RUN_LOADING_JOB_WITH_FILE = "tigergraph__run_loading_job_with_file" - RUN_LOADING_JOB_WITH_DATA = 
"tigergraph__run_loading_job_with_data" - GET_LOADING_JOBS = "tigergraph__get_loading_jobs" - GET_LOADING_JOB_STATUS = "tigergraph__get_loading_job_status" - DROP_LOADING_JOB = "tigergraph__drop_loading_job" - - # Statistics - GET_VERTEX_COUNT = "tigergraph__get_vertex_count" - GET_EDGE_COUNT = "tigergraph__get_edge_count" - GET_NODE_DEGREE = "tigergraph__get_node_degree" - - # GSQL Operations - GSQL = "tigergraph__gsql" - GENERATE_GSQL = "tigergraph__generate_gsql" - GENERATE_CYPHER = "tigergraph__generate_cypher" - - # Vector Schema Operations - ADD_VECTOR_ATTRIBUTE = "tigergraph__add_vector_attribute" - DROP_VECTOR_ATTRIBUTE = "tigergraph__drop_vector_attribute" - LIST_VECTOR_ATTRIBUTES = "tigergraph__list_vector_attributes" - GET_VECTOR_INDEX_STATUS = "tigergraph__get_vector_index_status" - - # Vector Data Operations - UPSERT_VECTORS = "tigergraph__upsert_vectors" - LOAD_VECTORS_FROM_CSV = "tigergraph__load_vectors_from_csv" - LOAD_VECTORS_FROM_JSON = "tigergraph__load_vectors_from_json" - SEARCH_TOP_K_SIMILARITY = "tigergraph__search_top_k_similarity" - FETCH_VECTOR = "tigergraph__fetch_vector" - - # Data Source Operations - CREATE_DATA_SOURCE = "tigergraph__create_data_source" - UPDATE_DATA_SOURCE = "tigergraph__update_data_source" - GET_DATA_SOURCE = "tigergraph__get_data_source" - DROP_DATA_SOURCE = "tigergraph__drop_data_source" - GET_ALL_DATA_SOURCES = "tigergraph__get_all_data_sources" - DROP_ALL_DATA_SOURCES = "tigergraph__drop_all_data_sources" - PREVIEW_SAMPLE_DATA = "tigergraph__preview_sample_data" - - # Connection Profile Operations - LIST_CONNECTIONS = "tigergraph__list_connections" - SHOW_CONNECTION = "tigergraph__show_connection" - - # Discovery and Navigation Operations - DISCOVER_TOOLS = "tigergraph__discover_tools" - GET_WORKFLOW = "tigergraph__get_workflow" - GET_TOOL_INFO = "tigergraph__get_tool_info" - - @classmethod - def from_value(cls, value: str) -> "TigerGraphToolName": - """Get enum from string value.""" - for tool in cls: - if 
tool.value == value: - return tool - raise ValueError(f"Unknown tool name: {value}") - diff --git a/pyTigerGraph/mcp/tools/__init__.py b/pyTigerGraph/mcp/tools/__init__.py deleted file mode 100644 index ae3c4ec9..00000000 --- a/pyTigerGraph/mcp/tools/__init__.py +++ /dev/null @@ -1,309 +0,0 @@ -# Copyright 2025 TigerGraph Inc. -# Licensed under the Apache License, Version 2.0. -# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 -# -# Permission is granted to use, copy, modify, and distribute this software -# under the License. The software is provided "AS IS", without warranty. - -"""MCP tools for TigerGraph.""" - -from .connection_tools import ( - list_connections_tool, - show_connection_tool, - list_connections, - show_connection, -) -from .schema_tools import ( - # Global schema operations (database level) - get_global_schema_tool, - get_global_schema, - # Graph operations (database level) - list_graphs_tool, - create_graph_tool, - drop_graph_tool, - clear_graph_data_tool, - list_graphs, - create_graph, - drop_graph, - clear_graph_data, - # Schema operations (graph level) - get_graph_schema_tool, - show_graph_details_tool, - get_graph_schema, - show_graph_details, -) -from .node_tools import ( - add_node_tool, - add_nodes_tool, - get_node_tool, - get_nodes_tool, - delete_node_tool, - delete_nodes_tool, - has_node_tool, - get_node_edges_tool, - add_node, - add_nodes, - get_node, - get_nodes, - delete_node, - delete_nodes, - has_node, - get_node_edges, -) -from .edge_tools import ( - add_edge_tool, - add_edges_tool, - get_edge_tool, - get_edges_tool, - delete_edge_tool, - delete_edges_tool, - has_edge_tool, - add_edge, - add_edges, - get_edge, - get_edges, - delete_edge, - delete_edges, - has_edge, -) -from .query_tools import ( - run_query_tool, - run_installed_query_tool, - install_query_tool, - drop_query_tool, - show_query_tool, - get_query_metadata_tool, - is_query_installed_tool, - get_neighbors_tool, - run_query, - run_installed_query, - 
install_query, - drop_query, - show_query, - get_query_metadata, - is_query_installed, - get_neighbors, -) -from .data_tools import ( - create_loading_job_tool, - run_loading_job_with_file_tool, - run_loading_job_with_data_tool, - get_loading_jobs_tool, - get_loading_job_status_tool, - drop_loading_job_tool, - create_loading_job, - run_loading_job_with_file, - run_loading_job_with_data, - get_loading_jobs, - get_loading_job_status, - drop_loading_job, -) -from .statistics_tools import ( - get_vertex_count_tool, - get_edge_count_tool, - get_node_degree_tool, - get_vertex_count, - get_edge_count, - get_node_degree, -) -from .gsql_tools import ( - gsql_tool, - gsql, - generate_gsql_query_tool, - generate_gsql, - generate_cypher_query_tool, - generate_cypher, -) -from .vector_tools import ( - # Vector schema tools - add_vector_attribute_tool, - drop_vector_attribute_tool, - list_vector_attributes_tool, - get_vector_index_status_tool, - add_vector_attribute, - drop_vector_attribute, - list_vector_attributes, - get_vector_index_status, - # Vector data tools - upsert_vectors_tool, - load_vectors_from_csv_tool, - load_vectors_from_json_tool, - search_top_k_similarity_tool, - fetch_vector_tool, - upsert_vectors, - load_vectors_from_csv, - load_vectors_from_json, - search_top_k_similarity, - fetch_vector, -) -from .datasource_tools import ( - create_data_source_tool, - update_data_source_tool, - get_data_source_tool, - drop_data_source_tool, - get_all_data_sources_tool, - drop_all_data_sources_tool, - preview_sample_data_tool, - create_data_source, - update_data_source, - get_data_source, - drop_data_source, - get_all_data_sources, - drop_all_data_sources, - preview_sample_data, -) -from .discovery_tools import ( - discover_tools_tool, - get_workflow_tool, - get_tool_info_tool, - discover_tools, - get_workflow, - get_tool_info, -) -from .tool_registry import get_all_tools - -__all__ = [ - # Connection profile operations - "list_connections_tool", - "show_connection_tool", - 
"list_connections", - "show_connection", - # Global schema operations (database level) - "get_global_schema_tool", - "get_global_schema", - # Graph operations (database level) - "list_graphs_tool", - "create_graph_tool", - "drop_graph_tool", - "clear_graph_data_tool", - "list_graphs", - "create_graph", - "drop_graph", - "clear_graph_data", - # Schema operations (graph level) - "get_graph_schema_tool", - "show_graph_details_tool", - "get_graph_schema", - "show_graph_details", - # Node tools - "add_node_tool", - "add_nodes_tool", - "get_node_tool", - "get_nodes_tool", - "delete_node_tool", - "delete_nodes_tool", - "has_node_tool", - "get_node_edges_tool", - "add_node", - "add_nodes", - "get_node", - "get_nodes", - "delete_node", - "delete_nodes", - "has_node", - "get_node_edges", - # Edge tools - "add_edge_tool", - "add_edges_tool", - "get_edge_tool", - "get_edges_tool", - "delete_edge_tool", - "delete_edges_tool", - "has_edge_tool", - "add_edge", - "add_edges", - "get_edge", - "get_edges", - "delete_edge", - "delete_edges", - "has_edge", - # Query tools - "run_query_tool", - "run_installed_query_tool", - "install_query_tool", - "drop_query_tool", - "show_query_tool", - "get_query_metadata_tool", - "is_query_installed_tool", - "get_neighbors_tool", - "run_query", - "run_installed_query", - "install_query", - "drop_query", - "show_query", - "get_query_metadata", - "is_query_installed", - "get_neighbors", - # Loading job tools - "create_loading_job_tool", - "run_loading_job_with_file_tool", - "run_loading_job_with_data_tool", - "get_loading_jobs_tool", - "get_loading_job_status_tool", - "drop_loading_job_tool", - "create_loading_job", - "run_loading_job_with_file", - "run_loading_job_with_data", - "get_loading_jobs", - "get_loading_job_status", - "drop_loading_job", - # Statistics tools - "get_vertex_count_tool", - "get_edge_count_tool", - "get_node_degree_tool", - "get_vertex_count", - "get_edge_count", - "get_node_degree", - # GSQL tools - "gsql_tool", - "gsql", - 
"generate_gsql_query_tool", - "generate_gsql", - "generate_cypher_query_tool", - "generate_cypher", - # Vector schema tools - "add_vector_attribute_tool", - "drop_vector_attribute_tool", - "list_vector_attributes_tool", - "get_vector_index_status_tool", - "add_vector_attribute", - "drop_vector_attribute", - "list_vector_attributes", - "get_vector_index_status", - # Vector data tools - "upsert_vectors_tool", - "load_vectors_from_csv_tool", - "load_vectors_from_json_tool", - "search_top_k_similarity_tool", - "fetch_vector_tool", - "upsert_vectors", - "load_vectors_from_csv", - "load_vectors_from_json", - "search_top_k_similarity", - "fetch_vector", - # Data Source tools - "create_data_source_tool", - "update_data_source_tool", - "get_data_source_tool", - "drop_data_source_tool", - "get_all_data_sources_tool", - "drop_all_data_sources_tool", - "preview_sample_data_tool", - "create_data_source", - "update_data_source", - "get_data_source", - "drop_data_source", - "get_all_data_sources", - "drop_all_data_sources", - "preview_sample_data", - # Discovery tools - "discover_tools_tool", - "get_workflow_tool", - "get_tool_info_tool", - "discover_tools", - "get_workflow", - "get_tool_info", - # Registry - "get_all_tools", -] - diff --git a/pyTigerGraph/mcp/tools/connection_tools.py b/pyTigerGraph/mcp/tools/connection_tools.py deleted file mode 100644 index 6ada7d82..00000000 --- a/pyTigerGraph/mcp/tools/connection_tools.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2025 TigerGraph Inc. -# Licensed under the Apache License, Version 2.0. -# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 -# -# Permission is granted to use, copy, modify, and distribute this software -# under the License. The software is provided "AS IS", without warranty. - -"""Connection profile tools for MCP. - -Allows agents to list available connection profiles and inspect -non-sensitive connection details for a given profile. 
-""" - -from typing import List, Optional -from pydantic import BaseModel, Field -from mcp.types import Tool, TextContent - -from ..tool_names import TigerGraphToolName -from ..connection_manager import ConnectionManager -from ..response_formatter import format_success, format_error - - -class ListConnectionsToolInput(BaseModel): - """Input schema for listing available connection profiles.""" - - -class ShowConnectionToolInput(BaseModel): - """Input schema for showing connection details.""" - profile: Optional[str] = Field( - None, - description=( - "Connection profile name to inspect. " - "If not provided, shows the active profile (from TG_PROFILE env var or 'default')." - ), - ) - - -list_connections_tool = Tool( - name=TigerGraphToolName.LIST_CONNECTIONS, - description=( - "List all available TigerGraph connection profiles. " - "Profiles are configured via environment variables: " - "the default profile uses TG_HOST, TG_USERNAME, etc., " - "while named profiles use _TG_HOST, _TG_USERNAME, etc." - ), - inputSchema=ListConnectionsToolInput.model_json_schema(), -) - -show_connection_tool = Tool( - name=TigerGraphToolName.SHOW_CONNECTION, - description=( - "Show non-sensitive connection details for a specific profile " - "(host, username, graph name, ports). Never reveals passwords or tokens." 
- ), - inputSchema=ShowConnectionToolInput.model_json_schema(), -) - - -async def list_connections() -> List[TextContent]: - """List all available connection profiles.""" - try: - profiles = ConnectionManager.list_profiles() - profile_details = [] - for p in profiles: - info = ConnectionManager.get_profile_info(p) - profile_details.append(info) - - return format_success( - operation="list_connections", - summary=f"Found {len(profiles)} connection profile(s): {', '.join(profiles)}", - data={"profiles": profile_details, "count": len(profiles)}, - suggestions=[ - "Show details: show_connection(profile='')", - "Use a profile: pass profile='' to any tool", - ], - ) - except Exception as e: - return format_error( - operation="list_connections", - error=str(e), - ) - - -async def show_connection(profile: Optional[str] = None) -> List[TextContent]: - """Show non-sensitive connection details for a profile.""" - try: - import os - effective = profile or os.getenv("TG_PROFILE", "default") - info = ConnectionManager.get_profile_info(effective) - - return format_success( - operation="show_connection", - summary=f"Connection profile '{effective}': {info['host']}", - data=info, - suggestions=[ - "List all profiles: list_connections()", - f"Use this profile: pass profile='{effective}' to any tool", - ], - ) - except Exception as e: - return format_error( - operation="show_connection", - error=str(e), - ) diff --git a/pyTigerGraph/mcp/tools/data_tools.py b/pyTigerGraph/mcp/tools/data_tools.py deleted file mode 100644 index cd812e9f..00000000 --- a/pyTigerGraph/mcp/tools/data_tools.py +++ /dev/null @@ -1,638 +0,0 @@ -# Copyright 2025 TigerGraph Inc. -# Licensed under the Apache License, Version 2.0. -# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 -# -# Permission is granted to use, copy, modify, and distribute this software -# under the License. The software is provided "AS IS", without warranty. - -"""Data loading tools for MCP. 
- -These tools use the non-deprecated loading job APIs: -- createLoadingJob - Create a loading job from structured config or GSQL -- runLoadingJobWithFile - Execute loading job with a file -- runLoadingJobWithData - Execute loading job with data string -- getLoadingJobs - List all loading jobs -- getLoadingJobStatus - Get status of a loading job -- dropLoadingJob - Drop a loading job -""" - -import json -from typing import List, Optional, Dict, Any, Union -from pydantic import BaseModel, Field -from mcp.types import Tool, TextContent - -from ..tool_names import TigerGraphToolName -from ..connection_manager import get_connection -from ..response_formatter import format_success, format_error, gsql_has_error -from pyTigerGraph.common.exception import TigerGraphException - - -# ============================================================================= -# Input Models for Loading Job Configuration -# ============================================================================= - -class NodeMapping(BaseModel): - """Mapping configuration for loading vertices.""" - vertex_type: str = Field(..., description="Target vertex type name.") - attribute_mappings: Dict[str, Union[str, int]] = Field( - ..., - description="Map of attribute name to column index (int) or header name (string). Must include the primary key. Example: {'id': 0, 'name': 1} or {'id': 'user_id', 'name': 'user_name'}" - ) - - -class EdgeMapping(BaseModel): - """Mapping configuration for loading edges.""" - edge_type: str = Field(..., description="Target edge type name.") - source_column: Union[str, int] = Field(..., description="Column for source vertex ID (string for header name, int for column index).") - target_column: Union[str, int] = Field(..., description="Column for target vertex ID (string for header name, int for column index).") - attribute_mappings: Optional[Dict[str, Union[str, int]]] = Field( - default_factory=dict, - description="Map of attribute name to column. 
Optional for edges without attributes." - ) - - -class FileConfig(BaseModel): - """Configuration for a single data file in a loading job.""" - file_alias: str = Field(..., description="Alias for the file (used in DEFINE FILENAME).") - file_path: Optional[str] = Field(None, description="Path to the file. If not provided, data will be passed at runtime.") - separator: str = Field(",", description="Field separator character.") - header: str = Field("true", description="Whether the file has a header row ('true' or 'false').") - eol: str = Field("\\n", description="End-of-line character.") - quote: Optional[str] = Field(None, description="Quote character for CSV (e.g., 'DOUBLE' for double quotes).") - node_mappings: List[NodeMapping] = Field( - default_factory=list, - description="List of vertex loading mappings. Example: [{'vertex_type': 'Person', 'attribute_mappings': {'id': 0, 'name': 1}}]" - ) - edge_mappings: List[EdgeMapping] = Field( - default_factory=list, - description="List of edge loading mappings." - ) - - -class CreateLoadingJobToolInput(BaseModel): - """Input schema for creating a loading job.""" - profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") - graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") - job_name: str = Field(..., description="Name for the loading job.") - files: List[FileConfig] = Field( - ..., - description="List of file configurations. Each file must have a 'file_alias' and 'node_mappings' and/or 'edge_mappings'. 
Example: [{'file_alias': 'f1', 'node_mappings': [...]}]" - ) - run_job: bool = Field(False, description="If True, run the loading job immediately after creation.") - drop_after_run: bool = Field(False, description="If True, drop the job after running (only applies if run_job=True).") - - -# ============================================================================= -# Input Models for Other Operations -# ============================================================================= - -class RunLoadingJobWithFileToolInput(BaseModel): - """Input schema for running a loading job with a file.""" - profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") - graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") - file_path: str = Field(..., description="Absolute path to the data file to load. Example: '/home/user/data/persons.csv'") - file_tag: str = Field(..., description="The name of file variable in the loading job (DEFINE FILENAME ).") - job_name: str = Field(..., description="The name of the loading job to run.") - separator: Optional[str] = Field(None, description="Data value separator. Default is comma. For JSON data, don't specify.") - eol: Optional[str] = Field(None, description="End-of-line character. Default is '\\n'. Supports '\\r\\n'.") - timeout: int = Field(16000, description="Timeout in milliseconds. Set to 0 for system-wide timeout.") - size_limit: int = Field(128000000, description="Maximum size for input file in bytes (default 128MB).") - - -class RunLoadingJobWithDataToolInput(BaseModel): - """Input schema for running a loading job with inline data.""" - profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. 
Use 'list_connections' to see available profiles.") - graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") - data: str = Field(..., description="The data string to load (CSV, JSON, etc.). Example: 'user1,Alice\\nuser2,Bob'") - file_tag: str = Field(..., description="The name of file variable in the loading job (DEFINE FILENAME ).") - job_name: str = Field(..., description="The name of the loading job to run.") - separator: Optional[str] = Field(None, description="Data value separator. Default is comma. For JSON data, don't specify.") - eol: Optional[str] = Field(None, description="End-of-line character. Default is '\\n'. Supports '\\r\\n'.") - timeout: int = Field(16000, description="Timeout in milliseconds. Set to 0 for system-wide timeout.") - size_limit: int = Field(128000000, description="Maximum size for input data in bytes (default 128MB).") - - -class GetLoadingJobsToolInput(BaseModel): - """Input schema for listing loading jobs.""" - profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") - graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") - - -class GetLoadingJobStatusToolInput(BaseModel): - """Input schema for getting loading job status.""" - profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") - graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") - job_id: str = Field(..., description="The ID of the loading job to check status.") - - -class DropLoadingJobToolInput(BaseModel): - """Input schema for dropping a loading job.""" - profile: Optional[str] = Field(None, description="Connection profile name. 
If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") - graph_name: Optional[str] = Field(None, description="Name of the graph. If not provided, uses default connection.") - job_name: str = Field(..., description="The name of the loading job to drop.") - - -# ============================================================================= -# Tool Definitions -# ============================================================================= - -create_loading_job_tool = Tool( - name=TigerGraphToolName.CREATE_LOADING_JOB, - description="""Create a loading job from structured configuration. -The job defines how to load data from files into vertices and edges. -Each file config specifies: file alias, separator, header, EOL, and mappings. -Node mappings define which columns map to vertex attributes. -Edge mappings define source/target columns and edge attributes. -Optionally run the job immediately and drop it after execution.""", - inputSchema=CreateLoadingJobToolInput.model_json_schema(), -) - -run_loading_job_with_file_tool = Tool( - name=TigerGraphToolName.RUN_LOADING_JOB_WITH_FILE, - description="Execute a loading job with a data file. The file is uploaded to TigerGraph and loaded according to the specified loading job definition.", - inputSchema=RunLoadingJobWithFileToolInput.model_json_schema(), -) - -run_loading_job_with_data_tool = Tool( - name=TigerGraphToolName.RUN_LOADING_JOB_WITH_DATA, - description="Execute a loading job with inline data string. 
The data is posted to TigerGraph and loaded according to the specified loading job definition.", - inputSchema=RunLoadingJobWithDataToolInput.model_json_schema(), -) - -get_loading_jobs_tool = Tool( - name=TigerGraphToolName.GET_LOADING_JOBS, - description="Get a list of all loading jobs defined for the current graph.", - inputSchema=GetLoadingJobsToolInput.model_json_schema(), -) - -get_loading_job_status_tool = Tool( - name=TigerGraphToolName.GET_LOADING_JOB_STATUS, - description="Get the status of a specific loading job by its job ID.", - inputSchema=GetLoadingJobStatusToolInput.model_json_schema(), -) - -drop_loading_job_tool = Tool( - name=TigerGraphToolName.DROP_LOADING_JOB, - description="Drop (delete) a loading job from the graph.", - inputSchema=DropLoadingJobToolInput.model_json_schema(), -) - - -# ============================================================================= -# Helper Functions -# ============================================================================= - -def _format_column(column: Union[str, int]) -> str: - """Format column reference for GSQL loading job.""" - if isinstance(column, int): - return f"${column}" - return f'$"{column}"' - - -def _generate_loading_job_gsql( - graph_name: str, - job_name: str, - files: List[Dict[str, Any]], -) -> str: - """Generate GSQL script for creating a loading job.""" - - # Build DEFINE FILENAME statements - define_files = [] - for file_config in files: - alias = file_config["file_alias"] - path = file_config.get("file_path") - if path: - define_files.append(f'DEFINE FILENAME {alias} = "{path}";') - else: - define_files.append(f"DEFINE FILENAME {alias};") - - # Build LOAD statements for each file - load_statements = [] - for file_config in files: - alias = file_config["file_alias"] - separator = file_config.get("separator", ",") - header = file_config.get("header", "true") - eol = file_config.get("eol", "\\n") - quote = file_config.get("quote") - - # Build USING clause - using_parts = [ - 
f'SEPARATOR="{separator}"', - f'HEADER="{header}"', - f'EOL="{eol}"' - ] - if quote: - using_parts.append(f'QUOTE="{quote}"') - using_clause = "USING " + ", ".join(using_parts) + ";" - - # Build mapping statements - mapping_statements = [] - - # Node mappings - for node_mapping in file_config.get("node_mappings", []): - vertex_type = node_mapping["vertex_type"] - attr_mappings = node_mapping["attribute_mappings"] - - # Format attribute values - attr_values = ", ".join( - _format_column(col) for col in attr_mappings.values() - ) - mapping_statements.append( - f"TO VERTEX {vertex_type} VALUES({attr_values})" - ) - - # Edge mappings - for edge_mapping in file_config.get("edge_mappings", []): - edge_type = edge_mapping["edge_type"] - source_col = _format_column(edge_mapping["source_column"]) - target_col = _format_column(edge_mapping["target_column"]) - attr_mappings = edge_mapping.get("attribute_mappings", {}) - - # Format attribute values - if attr_mappings: - attr_values = ", ".join( - _format_column(col) for col in attr_mappings.values() - ) - all_values = f"{source_col}, {target_col}, {attr_values}" - else: - all_values = f"{source_col}, {target_col}" - - mapping_statements.append( - f"TO EDGE {edge_type} VALUES({all_values})" - ) - - # Combine into LOAD statement - if mapping_statements: - load_stmt = f"LOAD {alias}\n " + ",\n ".join(mapping_statements) + f"\n {using_clause}" - load_statements.append(load_stmt) - - # Build the complete GSQL script - define_section = " # Define files\n " + "\n ".join(define_files) - load_section = " # Load data\n " + "\n ".join(load_statements) - - gsql_script = f"""USE GRAPH {graph_name} - -CREATE LOADING JOB {job_name} FOR GRAPH {graph_name} {{ -{define_section} - -{load_section} -}}""" - - return gsql_script - - -# ============================================================================= -# Tool Implementations -# ============================================================================= - -async def create_loading_job( 
- job_name: str, - files: List[Dict[str, Any]], - run_job: bool = False, - drop_after_run: bool = False, - profile: Optional[str] = None, - graph_name: Optional[str] = None, -) -> List[TextContent]: - """Create a loading job from structured configuration.""" - try: - conn = get_connection(profile=profile, graph_name=graph_name) - - # Generate the GSQL script - gsql_script = _generate_loading_job_gsql( - graph_name=conn.graphname, - job_name=job_name, - files=files - ) - - # Add RUN and DROP commands if requested - if run_job: - gsql_script += f"\n\nRUN LOADING JOB {job_name}" - if drop_after_run: - gsql_script += f"\n\nDROP JOB {job_name}" - - result = await conn.gsql(gsql_script) - result_str = str(result) if result else "" - - if gsql_has_error(result_str): - return format_error( - operation="create_loading_job", - error=TigerGraphException(result_str), - context={ - "job_name": job_name, - "graph_name": conn.graphname, - "gsql_script": gsql_script, - }, - suggestions=[ - "Check that vertex/edge types referenced in the job exist in the schema", - "Use show_graph_details() to verify the current schema", - "Ensure file paths and column mappings are correct", - ], - ) - - status_parts = [] - if run_job: - if drop_after_run: - status_parts.append("Job created, executed, and dropped (one-time load)") - else: - status_parts.append("Job created and executed") - else: - status_parts.append("Job created successfully") - - return format_success( - operation="create_loading_job", - summary=f"Success: Loading job '{job_name}' " + ", ".join(status_parts), - data={ - "job_name": job_name, - "file_count": len(files), - "executed": run_job, - "dropped": drop_after_run, - "gsql_script": gsql_script, - "result": result_str, - }, - suggestions=[s for s in [ - f"Run the job: run_loading_job_with_file(job_name='{job_name}', ...)" if not run_job else "Job already executed", - "List all jobs: get_loading_jobs()", - f"Get status: get_loading_job_status(job_name='{job_name}')" if not 
drop_after_run else None, - "Tip: Loading jobs are the recommended way to bulk-load data" - ] if s is not None], - metadata={ - "graph_name": conn.graphname, - "operation_type": "DDL" - } - ) - - except Exception as e: - return format_error( - operation="create_loading_job", - error=e, - context={ - "job_name": job_name, - "file_count": len(files), - "graph_name": graph_name or "default" - } - ) - - -async def run_loading_job_with_file( - file_path: str, - file_tag: str, - job_name: str, - separator: Optional[str] = None, - eol: Optional[str] = None, - timeout: int = 16000, - size_limit: int = 128000000, - profile: Optional[str] = None, - graph_name: Optional[str] = None, -) -> List[TextContent]: - """Execute a loading job with a data file.""" - try: - conn = get_connection(profile=profile, graph_name=graph_name) - result = await conn.runLoadingJobWithFile( - filePath=file_path, - fileTag=file_tag, - jobName=job_name, - sep=separator, - eol=eol, - timeout=timeout, - sizeLimit=size_limit - ) - if result: - return format_success( - operation="run_loading_job_with_file", - summary=f"Success: Loading job '{job_name}' executed successfully with file '{file_path}'", - data={ - "job_name": job_name, - "file_path": file_path, - "file_tag": file_tag, - "result": result - }, - suggestions=[ - f"Check status: get_loading_job_status(job_id='')", - "Verify loaded data with: get_vertex_count() or get_edge_count()", - "List all jobs: get_loading_jobs()" - ], - metadata={"graph_name": conn.graphname} - ) - else: - return format_error( - operation="run_loading_job_with_file", - error=ValueError("Loading job returned no result"), - context={ - "job_name": job_name, - "file_path": file_path, - "file_tag": file_tag, - "graph_name": graph_name or "default" - }, - suggestions=[ - "Check if the job name is correct", - "Verify the file_tag matches the loading job definition", - "Ensure the loading job exists: get_loading_jobs()" - ] - ) - except Exception as e: - return format_error( - 
operation="run_loading_job_with_file", - error=e, - context={ - "job_name": job_name, - "file_path": file_path, - "graph_name": graph_name or "default" - } - ) - - -async def run_loading_job_with_data( - data: str, - file_tag: str, - job_name: str, - separator: Optional[str] = None, - eol: Optional[str] = None, - timeout: int = 16000, - size_limit: int = 128000000, - profile: Optional[str] = None, - graph_name: Optional[str] = None, -) -> List[TextContent]: - """Execute a loading job with inline data string.""" - try: - conn = get_connection(profile=profile, graph_name=graph_name) - result = await conn.runLoadingJobWithData( - data=data, - fileTag=file_tag, - jobName=job_name, - sep=separator, - eol=eol, - timeout=timeout, - sizeLimit=size_limit - ) - if result: - data_preview = data[:100] + "..." if len(data) > 100 else data - return format_success( - operation="run_loading_job_with_data", - summary=f"Success: Loading job '{job_name}' executed successfully with inline data", - data={ - "job_name": job_name, - "file_tag": file_tag, - "data_preview": data_preview, - "data_size": len(data), - "result": result - }, - suggestions=[ - "Verify loaded data: get_vertex_count() or get_edge_count()", - "Tip: For large datasets, use 'run_loading_job_with_file' instead", - "List all jobs: get_loading_jobs()" - ], - metadata={"graph_name": conn.graphname} - ) - else: - return format_error( - operation="run_loading_job_with_data", - error=ValueError("Loading job returned no result"), - context={ - "job_name": job_name, - "file_tag": file_tag, - "data_size": len(data), - "graph_name": graph_name or "default" - }, - suggestions=[ - "Check if the job name is correct", - "Verify the file_tag matches the loading job definition", - "Ensure the loading job exists: get_loading_jobs()" - ] - ) - except Exception as e: - return format_error( - operation="run_loading_job_with_data", - error=e, - context={ - "job_name": job_name, - "data_size": len(data), - "graph_name": graph_name or 
"default" - } - ) - - -async def get_loading_jobs( - profile: Optional[str] = None, - graph_name: Optional[str] = None, -) -> List[TextContent]: - """Get a list of all loading jobs for the current graph.""" - try: - conn = get_connection(profile=profile, graph_name=graph_name) - result = await conn.getLoadingJobs() - if result: - job_count = len(result) if isinstance(result, list) else 1 - return format_success( - operation="get_loading_jobs", - summary=f"Found {job_count} loading job(s) for graph '{conn.graphname}'", - data={ - "jobs": result, - "count": job_count - }, - suggestions=[ - "Run a job: run_loading_job_with_file(...) or run_loading_job_with_data(...)", - "Create new job: create_loading_job(...)", - "Check job status: get_loading_job_status(job_id='')" - ], - metadata={"graph_name": conn.graphname} - ) - else: - return format_success( - operation="get_loading_jobs", - summary=f"Success: No loading jobs found for graph '{conn.graphname}'", - suggestions=[ - "Create a loading job: create_loading_job(...)", - "Tip: Loading jobs are used for bulk data ingestion" - ], - metadata={"graph_name": conn.graphname} - ) - except Exception as e: - return format_error( - operation="get_loading_jobs", - error=e, - context={"graph_name": graph_name or "default"} - ) - - -async def get_loading_job_status( - job_id: str, - profile: Optional[str] = None, - graph_name: Optional[str] = None, -) -> List[TextContent]: - """Get the status of a specific loading job.""" - try: - conn = get_connection(profile=profile, graph_name=graph_name) - result = await conn.getLoadingJobStatus(jobId=job_id) - if result: - return format_success( - operation="get_loading_job_status", - summary=f"Success: Loading job status for '{job_id}'", - data={ - "job_id": job_id, - "status": result - }, - suggestions=[ - "List all jobs: get_loading_jobs()", - "Tip: Use this to monitor long-running loading jobs" - ], - metadata={"graph_name": conn.graphname} - ) - else: - return format_error( - 
operation="get_loading_job_status", - error=ValueError("No status found for loading job"), - context={ - "job_id": job_id, - "graph_name": graph_name or "default" - }, - suggestions=[ - "Verify the job_id is correct", - "List all jobs: get_loading_jobs()" - ] - ) - except Exception as e: - return format_error( - operation="get_loading_job_status", - error=e, - context={ - "job_id": job_id, - "graph_name": graph_name or "default" - } - ) - - -async def drop_loading_job( - job_name: str, - profile: Optional[str] = None, - graph_name: Optional[str] = None, -) -> List[TextContent]: - """Drop a loading job from the graph.""" - try: - conn = get_connection(profile=profile, graph_name=graph_name) - result = await conn.dropLoadingJob(jobName=job_name) - - return format_success( - operation="drop_loading_job", - summary=f"Success: Loading job '{job_name}' dropped successfully", - data={ - "job_name": job_name, - "result": result - }, - suggestions=[ - "Warning: This operation is permanent and cannot be undone", - "Verify deletion: get_loading_jobs()", - "Create a new job: create_loading_job(...)" - ], - metadata={ - "graph_name": conn.graphname, - "destructive": True - } - ) - except Exception as e: - return format_error( - operation="drop_loading_job", - error=e, - context={ - "job_name": job_name, - "graph_name": graph_name or "default" - } - ) diff --git a/pyTigerGraph/mcp/tools/datasource_tools.py b/pyTigerGraph/mcp/tools/datasource_tools.py deleted file mode 100644 index c546e649..00000000 --- a/pyTigerGraph/mcp/tools/datasource_tools.py +++ /dev/null @@ -1,377 +0,0 @@ -# Copyright 2025 TigerGraph Inc. -# Licensed under the Apache License, Version 2.0. -# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 -# -# Permission is granted to use, copy, modify, and distribute this software -# under the License. The software is provided "AS IS", without warranty. 
- -"""Data source operation tools for MCP.""" - -from typing import List, Optional, Dict, Any -from pydantic import BaseModel, Field -from mcp.types import Tool, TextContent - -from ..tool_names import TigerGraphToolName -from ..connection_manager import get_connection - - -class CreateDataSourceToolInput(BaseModel): - """Input schema for creating a data source.""" - profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") - data_source_name: str = Field(..., description="Name of the data source.") - data_source_type: str = Field(..., description="Type of data source: 's3', 'gcs', 'azure_blob', or 'local'.") - config: Dict[str, Any] = Field(..., description="Configuration for the data source (e.g., bucket, credentials).") - - -class UpdateDataSourceToolInput(BaseModel): - """Input schema for updating a data source.""" - profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") - data_source_name: str = Field(..., description="Name of the data source to update.") - config: Dict[str, Any] = Field(..., description="Updated configuration for the data source.") - - -class GetDataSourceToolInput(BaseModel): - """Input schema for getting a data source.""" - profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") - data_source_name: str = Field(..., description="Name of the data source.") - - -class DropDataSourceToolInput(BaseModel): - """Input schema for dropping a data source.""" - profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. 
Use 'list_connections' to see available profiles.") - data_source_name: str = Field(..., description="Name of the data source to drop.") - - -class GetAllDataSourcesToolInput(BaseModel): - """Input schema for getting all data sources.""" - profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") - - -class DropAllDataSourcesToolInput(BaseModel): - """Input schema for dropping all data sources.""" - profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") - confirm: bool = Field(False, description="Must be True to confirm dropping all data sources.") - - -class PreviewSampleDataToolInput(BaseModel): - """Input schema for previewing sample data.""" - profile: Optional[str] = Field(None, description="Connection profile name. If not provided, uses TG_PROFILE env var or 'default'. Use 'list_connections' to see available profiles.") - data_source_name: str = Field(..., description="Name of the data source.") - file_path: str = Field(..., description="Path to the file within the data source.") - num_rows: int = Field(10, description="Number of sample rows to preview.") - graph_name: Optional[str] = Field(None, description="Name of the graph context. 
If not provided, uses default connection.") - - -create_data_source_tool = Tool( - name=TigerGraphToolName.CREATE_DATA_SOURCE, - description="Create a new data source for loading data (S3, GCS, Azure Blob, or local).", - inputSchema=CreateDataSourceToolInput.model_json_schema(), -) - -update_data_source_tool = Tool( - name=TigerGraphToolName.UPDATE_DATA_SOURCE, - description="Update an existing data source configuration.", - inputSchema=UpdateDataSourceToolInput.model_json_schema(), -) - -get_data_source_tool = Tool( - name=TigerGraphToolName.GET_DATA_SOURCE, - description="Get information about a specific data source.", - inputSchema=GetDataSourceToolInput.model_json_schema(), -) - -drop_data_source_tool = Tool( - name=TigerGraphToolName.DROP_DATA_SOURCE, - description="Drop (delete) a data source.", - inputSchema=DropDataSourceToolInput.model_json_schema(), -) - -get_all_data_sources_tool = Tool( - name=TigerGraphToolName.GET_ALL_DATA_SOURCES, - description="Get information about all data sources.", - inputSchema=GetAllDataSourcesToolInput.model_json_schema(), -) - -drop_all_data_sources_tool = Tool( - name=TigerGraphToolName.DROP_ALL_DATA_SOURCES, - description="Drop all data sources. 
WARNING: This is a destructive operation.", - inputSchema=DropAllDataSourcesToolInput.model_json_schema(), -) - -preview_sample_data_tool = Tool( - name=TigerGraphToolName.PREVIEW_SAMPLE_DATA, - description="Preview sample data from a file in a data source.", - inputSchema=PreviewSampleDataToolInput.model_json_schema(), -) - - -async def create_data_source( - data_source_name: str, - data_source_type: str, - config: Dict[str, Any], - profile: Optional[str] = None, -) -> List[TextContent]: - """Create a new data source.""" - from ..response_formatter import format_success, format_error, gsql_has_error - - try: - conn = get_connection(profile=profile) - - config_str = ", ".join([f'{k}="{v}"' for k, v in config.items()]) - - gsql_cmd = f"CREATE DATA_SOURCE {data_source_type.upper()} {data_source_name}" - if config_str: - gsql_cmd += f" = ({config_str})" - - result = await conn.gsql(gsql_cmd) - result_str = str(result) if result else "" - - if gsql_has_error(result_str): - return format_error( - operation="create_data_source", - error=Exception(f"Could not create data source:\n{result_str}"), - context={"data_source_name": data_source_name, "data_source_type": data_source_type}, - ) - - return format_success( - operation="create_data_source", - summary=f"Data source '{data_source_name}' of type '{data_source_type}' created successfully", - data={"data_source_name": data_source_name, "result": result_str}, - suggestions=[ - f"View data source: get_data_source(data_source_name='{data_source_name}')", - "List all data sources: get_all_data_sources()", - ], - ) - except Exception as e: - return format_error( - operation="create_data_source", - error=e, - context={"data_source_name": data_source_name}, - ) - - -async def update_data_source( - data_source_name: str, - config: Dict[str, Any], - profile: Optional[str] = None, -) -> List[TextContent]: - """Update an existing data source.""" - from ..response_formatter import format_success, format_error, gsql_has_error - - try: 
- conn = get_connection(profile=profile) - - config_str = ", ".join([f'{k}="{v}"' for k, v in config.items()]) - gsql_cmd = f"ALTER DATA_SOURCE {data_source_name} = ({config_str})" - - result = await conn.gsql(gsql_cmd) - result_str = str(result) if result else "" - - if gsql_has_error(result_str): - return format_error( - operation="update_data_source", - error=Exception(f"Could not update data source:\n{result_str}"), - context={"data_source_name": data_source_name}, - ) - - return format_success( - operation="update_data_source", - summary=f"Data source '{data_source_name}' updated successfully", - data={"data_source_name": data_source_name, "result": result_str}, - ) - except Exception as e: - return format_error( - operation="update_data_source", - error=e, - context={"data_source_name": data_source_name}, - ) - - -async def get_data_source( - data_source_name: str, - profile: Optional[str] = None, -) -> List[TextContent]: - """Get information about a data source.""" - from ..response_formatter import format_success, format_error, gsql_has_error - - try: - conn = get_connection(profile=profile) - - result = await conn.gsql(f"SHOW DATA_SOURCE {data_source_name}") - result_str = str(result) if result else "" - - if gsql_has_error(result_str): - return format_error( - operation="get_data_source", - error=Exception(f"Could not retrieve data source:\n{result_str}"), - context={"data_source_name": data_source_name}, - ) - - return format_success( - operation="get_data_source", - summary=f"Data source '{data_source_name}' details", - data={"data_source_name": data_source_name, "details": result_str}, - ) - except Exception as e: - return format_error( - operation="get_data_source", - error=e, - context={"data_source_name": data_source_name}, - ) - - -async def drop_data_source( - data_source_name: str, - profile: Optional[str] = None, -) -> List[TextContent]: - """Drop a data source.""" - from ..response_formatter import format_success, format_error, gsql_has_error - 
- try: - conn = get_connection(profile=profile) - - result = await conn.gsql(f"DROP DATA_SOURCE {data_source_name}") - result_str = str(result) if result else "" - - if gsql_has_error(result_str): - return format_error( - operation="drop_data_source", - error=Exception(f"Could not drop data source:\n{result_str}"), - context={"data_source_name": data_source_name}, - ) - - return format_success( - operation="drop_data_source", - summary=f"Data source '{data_source_name}' dropped successfully", - data={"data_source_name": data_source_name, "result": result_str}, - suggestions=["List remaining: get_all_data_sources()"], - metadata={"destructive": True}, - ) - except Exception as e: - return format_error( - operation="drop_data_source", - error=e, - context={"data_source_name": data_source_name}, - ) - - -async def get_all_data_sources( - profile: Optional[str] = None, - **kwargs, -) -> List[TextContent]: - """Get all data sources.""" - from ..response_formatter import format_success, format_error, gsql_has_error - - try: - conn = get_connection(profile=profile) - - result = await conn.gsql("SHOW DATA_SOURCE *") - result_str = str(result) if result else "" - - if gsql_has_error(result_str): - return format_error( - operation="get_all_data_sources", - error=Exception(f"Could not retrieve data sources:\n{result_str}"), - context={}, - ) - - return format_success( - operation="get_all_data_sources", - summary="All data sources retrieved", - data={"details": result_str}, - suggestions=["Create a data source: create_data_source(...)"], - ) - except Exception as e: - return format_error( - operation="get_all_data_sources", - error=e, - context={}, - ) - - -async def drop_all_data_sources( - profile: Optional[str] = None, - confirm: bool = False, -) -> List[TextContent]: - """Drop all data sources.""" - from ..response_formatter import format_success, format_error, gsql_has_error - - if not confirm: - return format_error( - operation="drop_all_data_sources", - 
error=ValueError("Confirmation required"), - context={}, - suggestions=[ - "Set confirm=True to proceed with this destructive operation", - "This will drop ALL data sources", - ], - ) - - try: - conn = get_connection(profile=profile) - - result = await conn.gsql("DROP DATA_SOURCE *") - result_str = str(result) if result else "" - - if gsql_has_error(result_str): - return format_error( - operation="drop_all_data_sources", - error=Exception(f"Could not drop all data sources:\n{result_str}"), - context={}, - ) - - return format_success( - operation="drop_all_data_sources", - summary="All data sources dropped successfully", - data={"result": result_str}, - metadata={"destructive": True}, - ) - except Exception as e: - return format_error( - operation="drop_all_data_sources", - error=e, - context={}, - ) - - -async def preview_sample_data( - data_source_name: str, - file_path: str, - num_rows: int = 10, - profile: Optional[str] = None, - graph_name: Optional[str] = None, -) -> List[TextContent]: - """Preview sample data from a file.""" - from ..response_formatter import format_success, format_error, gsql_has_error - - try: - conn = get_connection(profile=profile, graph_name=graph_name) - - gsql_cmd = ( - f"USE GRAPH {conn.graphname}\n" - f'SHOW DATA_SOURCE {data_source_name} FILE "{file_path}" LIMIT {num_rows}' - ) - - result = await conn.gsql(gsql_cmd) - result_str = str(result) if result else "" - - if gsql_has_error(result_str): - return format_error( - operation="preview_sample_data", - error=Exception(f"Could not preview data:\n{result_str}"), - context={"data_source_name": data_source_name, "file_path": file_path}, - ) - - return format_success( - operation="preview_sample_data", - summary=f"Sample data from '{file_path}' (first {num_rows} rows)", - data={"data_source_name": data_source_name, "file_path": file_path, "preview": result_str}, - metadata={"graph_name": conn.graphname}, - ) - except Exception as e: - return format_error( - 
operation="preview_sample_data", - error=e, - context={"data_source_name": data_source_name, "file_path": file_path}, - ) - diff --git a/pyTigerGraph/mcp/tools/discovery_tools.py b/pyTigerGraph/mcp/tools/discovery_tools.py deleted file mode 100644 index a09f6605..00000000 --- a/pyTigerGraph/mcp/tools/discovery_tools.py +++ /dev/null @@ -1,611 +0,0 @@ -# Copyright 2025 TigerGraph Inc. -# Licensed under the Apache License, Version 2.0. -# See the LICENSE file or https://www.apache.org/licenses/LICENSE-2.0 -# -# Permission is granted to use, copy, modify, and distribute this software -# under the License. The software is provided "AS IS", without warranty. - -"""Discovery and navigation tools for LLMs. - -These tools help LLMs discover the right tools for their tasks and understand -common workflows. -""" - -import json -from typing import List, Optional -from pydantic import BaseModel, Field -from mcp.types import Tool, TextContent - -from ..tool_names import TigerGraphToolName -from ..tool_metadata import TOOL_METADATA, ToolCategory, search_tools_by_keywords, get_tools_by_category -from ..response_formatter import format_success, format_list_response - - -class ToolDiscoveryInput(BaseModel): - """Input for discovering relevant tools.""" - task_description: str = Field( - ..., - description=( - "Describe what you want to accomplish in natural language.\n" - "Examples:\n" - " - 'add multiple users to the graph'\n" - " - 'find similar documents using embeddings'\n" - " - 'understand the graph structure'\n" - " - 'load data from a CSV file'" - ) - ) - category: Optional[str] = Field( - None, - description=( - "Filter by category: 'schema', 'data', 'query', 'vector', 'loading', 'utility'.\n" - "Leave empty to search all categories." 
- ) - ) - limit: int = Field( - 5, - description="Maximum number of tools to return (default: 5)" - ) - - -class GetWorkflowInput(BaseModel): - """Input for getting workflow templates.""" - workflow_type: str = Field( - ..., - description=( - "Type of workflow to retrieve:\n" - " - 'create_graph': Set up a new graph with schema\n" - " - 'load_data': Import data into an existing graph\n" - " - 'query_data': Query and analyze graph data\n" - " - 'vector_search': Set up and use vector similarity search\n" - " - 'graph_analysis': Analyze graph structure and statistics\n" - " - 'setup_connection': Initial connection setup and verification" - ) - ) - - -class GetToolInfoInput(BaseModel): - """Input for getting detailed information about a specific tool.""" - tool_name: str = Field( - ..., - description=( - "Name of the tool to get information about.\n" - "Example: 'tigergraph__add_node' or 'tigergraph__search_top_k_similarity'" - ) - ) - - -# Tool definitions -discover_tools_tool = Tool( - name=TigerGraphToolName.DISCOVER_TOOLS, - description=( - "Discover which TigerGraph tools are relevant for your task.\n\n" - "**Use this tool when:**\n" - " - You're unsure which tool to use for your goal\n" - " - You want to explore available capabilities\n" - " - You need suggestions for accomplishing a task\n\n" - "**Returns:**\n" - " - List of recommended tools with descriptions\n" - " - Use cases and complexity ratings\n" - " - Prerequisites and related tools\n" - " - Example parameters\n\n" - "**Example:**\n" - " task_description: 'I want to add multiple users to the graph'" - ), - inputSchema=ToolDiscoveryInput.model_json_schema(), -) - -get_workflow_tool = Tool( - name=TigerGraphToolName.GET_WORKFLOW, - description=( - "Get a step-by-step workflow template for common TigerGraph tasks.\n\n" - "**Use this tool when:**\n" - " - You need to complete a complex multi-step task\n" - " - You want to follow best practices\n" - " - You're new to TigerGraph and need guidance\n\n" - 
"**Returns:**\n" - " - Ordered list of tools to use\n" - " - Example parameters for each step\n" - " - Explanations of what each step accomplishes\n\n" - "**Available workflows:** create_graph, load_data, query_data, vector_search, graph_analysis, setup_connection" - ), - inputSchema=GetWorkflowInput.model_json_schema(), -) - -get_tool_info_tool = Tool( - name=TigerGraphToolName.GET_TOOL_INFO, - description=( - "Get detailed information about a specific TigerGraph tool.\n\n" - "**Use this tool when:**\n" - " - You want to understand a tool's capabilities\n" - " - You need examples of how to use a tool\n" - " - You want to know prerequisites or related tools\n\n" - "**Returns:**\n" - " - Detailed tool description\n" - " - Use cases and examples\n" - " - Prerequisites and related tools\n" - " - Common next steps" - ), - inputSchema=GetToolInfoInput.model_json_schema(), -) - - -# Workflow templates -WORKFLOWS = { - "setup_connection": { - "name": "Setup and Verify Connection", - "description": "Initial setup to verify connection and explore available graphs", - "steps": [ - { - "step": 1, - "tool": "tigergraph__list_graphs", - "description": "List all available graphs to see what exists", - "parameters": {}, - "rationale": "First, discover what graphs are available in your TigerGraph instance" - }, - { - "step": 2, - "tool": "tigergraph__show_graph_details", - "description": "Get detailed schema of a specific graph", - "parameters": {"graph_name": "