diff --git a/CNAME b/CNAME new file mode 100644 index 0000000..9654456 --- /dev/null +++ b/CNAME @@ -0,0 +1 @@ +hypmcp.talkincode.net diff --git a/main.py b/main.py index 719ca59..1c7096f 100644 --- a/main.py +++ b/main.py @@ -6,10 +6,11 @@ from dotenv import load_dotenv from fastmcp import FastMCP -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator +from pydantic import ValidationError as PydanticValidationError from services.hyperliquid_services import HyperliquidServices -from services.validators import ValidationError, validate_order_inputs +from services.validators import ValidationError, validate_coin, validate_order_inputs # Load environment variables load_dotenv() @@ -84,6 +85,53 @@ def initialize_service(): logger.info(f"Service initialized for account: {account_info}") +class CandlesSnapshotParams(BaseModel): + """Bulk candles snapshot request parameters""" + + coins: list[str] = Field(..., min_length=1, description="List of trading pairs") + interval: str = Field( + ..., description="Candlestick interval supported by HyperLiquid" + ) + start_time: int | None = Field( + default=None, + description="Start timestamp in milliseconds", + ) + end_time: int | None = Field( + default=None, + description="End timestamp in milliseconds", + ) + days: int | None = Field( + default=None, + gt=0, + description="Fetch recent N days (mutually exclusive with start/end)", + ) + limit: int | None = Field( + default=None, + gt=0, + le=5000, + description="Maximum number of candles per coin (latest N records)", + ) + + @model_validator(mode="after") + def validate_time_params(self): + if self.days is not None and ( + self.start_time is not None or self.end_time is not None + ): + raise ValueError("days cannot be used together with start_time or end_time") + + if self.days is None and self.start_time is None: + raise ValueError("start_time is required when days is not provided") + + if ( + self.start_time is not None + and self.end_time is not None + and self.start_time >= self.end_time + ): + raise ValueError("start_time must be less than end_time") + + return self + + # Account Management Tools @@ -369,6 +417,103 @@ async def get_orderbook(coin: str, depth: int = 20) -> dict[str, Any]: return await hyperliquid_service.get_orderbook(coin, depth) +@mcp.tool +async def get_candles_snapshot( + coins: list[str], + interval: str, + start_time: int | None = None, + end_time: int | None = None, + days: int | None = None, + limit: int | None = None, +) -> dict[str, Any]: + """ + Fetch candlestick (OHLCV) data for multiple coins in one request + + Args: + coins: List of trading pairs (e.g., ["BTC", "ETH"]) + interval: Candlestick interval supported by HyperLiquid (e.g., "1m", "1h") + start_time: Start timestamp in milliseconds (required when days not provided) + end_time: End timestamp in milliseconds (defaults to now when omitted) + days: Number of recent days to fetch (mutually exclusive with start/end) + limit: Optional max number of candles per coin (latest N samples) + """ + + initialize_service() + + try: + params = CandlesSnapshotParams( + coins=coins, + interval=interval, + start_time=start_time, + end_time=end_time, + days=days, + limit=limit, + ) + except PydanticValidationError as validation_error: + return { + "success": False, + "error": f"Invalid input: {validation_error.errors()}", + "error_code": "VALIDATION_ERROR", + } + except ValueError as validation_error: + return { + "success": False, + "error": f"Invalid input: {str(validation_error)}", + "error_code": "VALIDATION_ERROR", + } + + # Validate each coin using existing validator for consistency + for coin in params.coins: + try: + validate_coin(coin) + except ValidationError as validation_error: + return { + "success": False, + "error": f"Invalid input: {str(validation_error)}", + "error_code": "VALIDATION_ERROR", + } + + service_result = await hyperliquid_service.get_candles_snapshot_bulk( + coins=params.coins, + interval=params.interval, + start_time=params.start_time, + end_time=params.end_time, + days=params.days, + ) + + if not service_result.get("success"): + return service_result + + candles_data = service_result.get("data", {}) + applied_limit = params.limit or None + + if applied_limit is not None: + limited_data = {} + for coin, candles in candles_data.items(): + if not isinstance(candles, list): + limited_data[coin] = candles + continue + limited_data[coin] = candles[-applied_limit:] + candles_data = limited_data + + response: dict[str, Any] = { + "success": True, + "data": candles_data, + "interval": service_result.get("interval"), + "start_time": service_result.get("start_time"), + "end_time": service_result.get("end_time"), + "requested_coins": params.coins, + } + + if applied_limit is not None: + response["limit_per_coin"] = applied_limit + + if service_result.get("coin_errors"): + response["coin_errors"] = service_result["coin_errors"] + + return response + + @mcp.tool async def get_funding_history(coin: str, days: int = 7) -> dict[str, Any]: """ @@ -636,6 +781,23 @@ def start_server(): ) logger.info(f"Logs will be written to: {log_path}") + # Log all registered tools BEFORE starting server + if hasattr(mcp, "_tool_manager") and hasattr(mcp._tool_manager, "_tools"): + tools_dict = mcp._tool_manager._tools + tool_names = sorted(tools_dict.keys()) + + print("\n" + "=" * 60) + print(f"✅ {len(tool_names)} MCP Tools Registered:") + print("=" * 60) + + for i, tool_name in enumerate(tool_names, 1): + marker = "🆕" if tool_name == "get_candles_snapshot" else " " + print(f"{marker} {i:2d}. {tool_name}") + + print("=" * 60 + "\n") + else: + print("\n⚠️ Cannot verify tool registration\n") + asyncio.run(run_as_server()) except Exception as e: logger.error(f"Failed to start server: {e}") diff --git a/services/hyperliquid_services.py b/services/hyperliquid_services.py index c7772fe..f5050de 100644 --- a/services/hyperliquid_services.py +++ b/services/hyperliquid_services.py @@ -603,6 +603,137 @@ async def get_orderbook(self, coin: str, depth: int = 20) -> dict[str, Any]: ) return {"success": False, "error": str(e)} + async def get_candles_snapshot_bulk( + self, + coins: list[str], + interval: str, + start_time: int | None = None, + end_time: int | None = None, + days: int | None = None, + ) -> dict[str, Any]: + """ + Retrieve candlestick data for multiple coins in a single call + + Args: + coins: List of trading pairs (e.g., ["BTC", "ETH"]) + interval: Candlestick interval string accepted by HyperLiquid + start_time: Optional start timestamp (ms). Required when days is None. + end_time: Optional end timestamp (ms). Defaults to current time when omitted. + days: Optional number of recent days to fetch. Mutually exclusive with start/end. + """ + + try: + if not isinstance(coins, list) or not coins: + raise ValueError("coins must be a non-empty list of strings") + + normalized_coins: list[str] = [] + for coin in coins: + if not coin or not isinstance(coin, str): + raise ValueError("each coin must be a non-empty string") + coin_clean = coin.strip() + if not coin_clean: + raise ValueError("each coin must be a non-empty string") + if coin_clean not in normalized_coins: + normalized_coins.append(coin_clean) + + if not interval or not isinstance(interval, str): + raise ValueError("interval must be a non-empty string") + interval = interval.strip() + if not interval: + raise ValueError("interval must be a non-empty string") + + if days is not None and (start_time is not None or end_time is not None): + raise ValueError( + "days cannot be used together with start_time or end_time" + ) + + current_time_ms = int(time.time() * 1000) + + if days is not None: + if not isinstance(days, int) or days <= 0: + raise ValueError("days must be a positive integer") + effective_end = current_time_ms if end_time is None else int(end_time) + effective_start = effective_end - (days * 24 * 60 * 60 * 1000) + else: + if start_time is None: + raise ValueError("start_time is required when days is not provided") + effective_start = int(start_time) + effective_end = ( + int(end_time) if end_time is not None else current_time_ms + ) + + if effective_start >= effective_end: + raise ValueError("start_time must be less than end_time") + + candles_by_coin: dict[str, list[dict[str, Any]]] = {} + coin_errors: dict[str, str] = {} + + for coin in normalized_coins: + try: + raw_candles = self.info.candles_snapshot( + coin, + interval, + effective_start, + effective_end, + ) + + formatted_candles: list[dict[str, Any]] = [] + for candle in raw_candles or []: + timestamp = candle.get("t") or candle.get("T") + if timestamp is None: + # Skip malformed entries without timestamp + continue + + try: + formatted_candles.append( + { + "timestamp": int(timestamp), + "open": float(candle["o"]), + "high": float(candle["h"]), + "low": float(candle["l"]), + "close": float(candle["c"]), + "volume": float(candle["v"]), + "trade_count": int(candle.get("n", 0)), + } + ) + except (KeyError, TypeError, ValueError) as format_error: + self.logger.warning( + "Skipping malformed candle for %s: %s", + coin, + format_error, + ) + + formatted_candles.sort(key=lambda item: item["timestamp"]) + candles_by_coin[coin] = formatted_candles + except Exception as coin_error: + coin_errors[coin] = str(coin_error) + self.logger.error( + "Failed to fetch candles for %s: %s", coin, coin_error + ) + + if not candles_by_coin: + return { + "success": False, + "error": "Failed to fetch candle data for requested coins", + "coin_errors": coin_errors, + } + + response: dict[str, Any] = { + "success": True, + "data": candles_by_coin, + "interval": interval, + "start_time": effective_start, + "end_time": effective_end, + } + + if coin_errors: + response["coin_errors"] = coin_errors + + return response + except Exception as e: + self.logger.error("Failed to get candles snapshot bulk: %s", e) + return {"success": False, "error": str(e)} + async def update_leverage( self, coin: str, leverage: int, is_cross: bool = True ) -> dict[str, Any]: diff --git a/test_http_tools.py b/test_http_tools.py new file mode 100644 index 0000000..b2eae01 --- /dev/null +++ b/test_http_tools.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 +""" +使用 HTTP SSE 流式测试 MCP 服务器的工具列表 +""" + +import json + +import requests + + +def parse_sse_response(text): + """解析 SSE 格式的响应""" + # SSE 格式: event: message\ndata: {...}\n\n + lines = text.strip().split("\n") + for line in lines: + if line.startswith("data: "): + data = line[6:] # 去掉 "data: " 前缀 + try: + return json.loads(data) + except (json.JSONDecodeError, ValueError): + pass + return None + + +def test_mcp_http(): + """通过 HTTP SSE 流式请求测试 MCP 服务器""" + base_url = "http://127.0.0.1:8080/mcp" + + print("\n" + "=" * 60) + print(f"Testing MCP Server (HTTP SSE): {base_url}") + print("=" * 60) + + # 第1步: 初始化会话 + print("\n📡 Step 1: Initialize session...") + init_payload = { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0.0"}, + }, + } + + try: + response = requests.post( + base_url, + json=init_payload, + headers={ + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + }, + ) + + print(f" Status: {response.status_code}") + print(f" Content-Type: {response.headers.get('Content-Type')}") + + if response.status_code == 200: + # 解析 SSE 响应 + result = parse_sse_response(response.text) + if result and "result" in result: + server_info = result["result"].get("serverInfo", {}) + print(f" ✅ Server: {server_info.get('name')}") + print(f" Version: {server_info.get('version')}") + else: + print(f" ❌ Error: {response.text}") + return + + # 获取 session ID(如果有) + session_id = None + # 检查多种可能的 header 名称 + session_headers = [ + "x-mcp-session-id", + "mcp-session-id", + "x-session-id", + "session-id", + "x-mcp-session", + ] + for header in session_headers: + if header in response.headers: + session_id = response.headers[header] + print(f" Session ID ({header}): {session_id}") + break + + if not session_id: + print(" ⚠️ No session ID in response headers") + print(f" Available headers: {list(response.headers.keys())}") + + # 第2步: 请求工具列表 + print("\n📋 Step 2: List tools...") + tools_payload = { + "jsonrpc": "2.0", + "id": 2, + "method": "tools/list", + "params": {}, + } + + headers = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + } + if session_id: + # 尝试所有可能的 session header 名称 + headers["x-mcp-session-id"] = session_id + headers["mcp-session-id"] = session_id + headers["x-session-id"] = session_id + + tools_response = requests.post(base_url, json=tools_payload, headers=headers) + + print(f" Status: {tools_response.status_code}") + + if tools_response.status_code == 200: + # 解析 SSE 响应 + tools_result = parse_sse_response(tools_response.text) + + if ( + tools_result + and "result" in tools_result + and "tools" in tools_result["result"] + ): + tools = tools_result["result"]["tools"] + print(f"\n✅ Discovered {len(tools)} tools:\n") + + # 列出所有工具 + for i, tool in enumerate(tools, 1): + marker = ( + "🆕" if tool.get("name") == "get_candles_snapshot" else " " + ) + print(f"{marker} {i:2d}. {tool.get('name')}") + + # 检查 get_candles_snapshot + print("\n" + "=" * 60) + tool_names = [t.get("name") for t in tools] + + if "get_candles_snapshot" in tool_names: + print("✅✅✅ get_candles_snapshot IS AVAILABLE via HTTP!") + + # 显示详细信息 + snapshot_tool = next( + t for t in tools if t.get("name") == "get_candles_snapshot" + ) + print("\n📝 Tool Details:") + print(f" Name: {snapshot_tool.get('name')}") + + if "description" in snapshot_tool: + desc = snapshot_tool["description"] + # 只显示前150个字符 + print(f" Description: {desc[:150]}...") + + if "inputSchema" in snapshot_tool: + schema = snapshot_tool["inputSchema"] + if "required" in schema: + print(f" Required: {schema['required']}") + if "properties" in schema: + print(f" Parameters: {list(schema['properties'].keys())}") + else: + print("❌ get_candles_snapshot NOT FOUND in HTTP response") + print(f"\nFirst 5 tools: {', '.join(tool_names[:5])}") + + print("=" * 60 + "\n") + else: + print("\n⚠️ Unexpected response format:") + print(json.dumps(tools_result, indent=2)[:500]) + else: + print(f" ❌ Error: {tools_response.text}") + + except Exception as e: + print(f"\n❌ Error: {e}") + print(f" Type: {type(e).__name__}") + import traceback + + traceback.print_exc() + print("\n⚠️ Make sure server is running:") + print(" uv run start") + + +if __name__ == "__main__": + test_mcp_http() diff --git a/test_tool_registration.py b/test_tool_registration.py new file mode 100644 index 0000000..5e60131 --- /dev/null +++ b/test_tool_registration.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +""" +测试MCP工具注册情况 +""" + +import sys + +sys.path.insert(0, "/Volumes/ExtDISK/github/hyperliquid-mcp") + +from main import mcp + +print("\n" + "=" * 60) +print("MCP Tools Registration Test") +print("=" * 60) + +# 访问 _tool_manager._tools +if hasattr(mcp, "_tool_manager") and hasattr(mcp._tool_manager, "_tools"): + tools_dict = mcp._tool_manager._tools + tool_names = sorted(tools_dict.keys()) + + print(f"\n✅ Found {len(tool_names)} registered tools:\n") + + for i, tool_name in enumerate(tool_names, 1): + marker = "🆕" if tool_name == "get_candles_snapshot" else " " + print(f"{marker} {i:2d}. {tool_name}") + + # 检查 get_candles_snapshot + print("\n" + "=" * 60) + if "get_candles_snapshot" in tools_dict: + print("✅✅✅ get_candles_snapshot IS REGISTERED!") + tool_def = tools_dict["get_candles_snapshot"] + print(f" Type: {type(tool_def)}") + if hasattr(tool_def, "description"): + print(f" Description: {tool_def.description[:100]}...") + else: + print("❌ get_candles_snapshot NOT FOUND") +else: + print("❌ Cannot access tool manager") + +print("=" * 60 + "\n") diff --git a/tests/conftest.py b/tests/conftest.py index 0fd5ff4..319abfc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,154 @@ """pytest 配置""" import sys +from collections.abc import Callable from pathlib import Path +from types import ModuleType +from typing import Any, Optional # 添加项目根目录到 Python 路径 project_root = Path(__file__).parent.parent sys.path.insert(0, str(project_root)) + + +def _ensure_module(name: str) -> ModuleType: + module = sys.modules.get(name) + if module is None: + module = ModuleType(name) + sys.modules[name] = module + return module + + +# ---- Hyperliquid SDK Stubs (用于测试环境) ---- +hyperliquid_module = _ensure_module("hyperliquid") +api_module = _ensure_module("hyperliquid.api") +exchange_module = _ensure_module("hyperliquid.exchange") +info_module = _ensure_module("hyperliquid.info") +utils_module = _ensure_module("hyperliquid.utils") +constants_module = _ensure_module("hyperliquid.utils.constants") +signing_module = _ensure_module("hyperliquid.utils.signing") +types_module = _ensure_module("hyperliquid.utils.types") +websocket_module = _ensure_module("hyperliquid.websocket_manager") + + +class _API: # pragma: no cover - 简易桩实现 + def __init__(self, *args, **kwargs): + pass + + +api_module.API = getattr(api_module, "API", _API) + + +class _Exchange: # pragma: no cover - 简易桩实现 + def __init__(self, *args, **kwargs): + pass + + +exchange_module.Exchange = getattr(exchange_module, "Exchange", _Exchange) + + +class _Info: # pragma: no cover - 简易桩实现 + def __init__(self, *args, **kwargs): + pass + + +info_module.Info = getattr(info_module, "Info", _Info) + +constants_module.MAINNET_API_URL = getattr( + constants_module, "MAINNET_API_URL", "https://api.mock.hyperliquid" +) +constants_module.TESTNET_API_URL = getattr( + constants_module, "TESTNET_API_URL", "https://api.mock.hyperliquid" +) + + +def _order_request_to_order_wire(order, asset): # pragma: no cover - 测试桩 + return order + + +def _order_wires_to_order_action(wires, _builder): # pragma: no cover - 测试桩 + return {"orders": wires} + + +def _sign_l1_action(*args, **kwargs): # pragma: no cover - 测试桩 + return "signed" + + +signing_module.order_request_to_order_wire = getattr( + signing_module, "order_request_to_order_wire", _order_request_to_order_wire +) +signing_module.order_wires_to_order_action = getattr( + signing_module, "order_wires_to_order_action", _order_wires_to_order_action +) +signing_module.sign_l1_action = getattr( + signing_module, "sign_l1_action", _sign_l1_action +) + + +class _Cloid(str): # pragma: no cover - 简易桩实现 + pass + + +types_module.Any = getattr(types_module, "Any", Any) +types_module.Callable = getattr(types_module, "Callable", Callable) +types_module.Cloid = getattr(types_module, "Cloid", _Cloid) +types_module.List = getattr(types_module, "List", list) +types_module.Meta = getattr(types_module, "Meta", dict) +types_module.Optional = getattr(types_module, "Optional", Optional) +types_module.SpotMeta = getattr(types_module, "SpotMeta", dict) +types_module.SpotMetaAndAssetCtxs = getattr(types_module, "SpotMetaAndAssetCtxs", list) +types_module.Subscription = getattr(types_module, "Subscription", dict) +types_module.cast = getattr(types_module, "cast", lambda typ, val: val) + + +class _WebsocketManager: # pragma: no cover - 简易桩实现 + def __init__(self, *args, **kwargs): + pass + + def start(self): # pragma: no cover - 简易桩实现 + pass + + def stop(self): # pragma: no cover - 简易桩实现 + pass + + +websocket_module.WebsocketManager = getattr( + websocket_module, "WebsocketManager", _WebsocketManager +) + + +# ---- FastMCP Stub ---- +fastmcp_module = _ensure_module("fastmcp") + + +class _FastMCP: # pragma: no cover - 简易桩实现 + def __init__(self, name: str): + self.name = name + + def tool(self, func): + return func + + def run_async(self, *args, **kwargs): + raise RuntimeError("FastMCP stub does not support run_async in tests") + + def run(self, *args, **kwargs): + raise RuntimeError("FastMCP stub does not support run in tests") + + +fastmcp_module.FastMCP = getattr(fastmcp_module, "FastMCP", _FastMCP) + + +# ---- eth_account Stub ---- +eth_account_module = _ensure_module("eth_account") + + +class _AccountStub: # pragma: no cover - 简易桩实现 + @staticmethod + def from_key(private_key: str): + class _Wallet: # pragma: no cover - 简易桩实现 + address = "0xSTUB_ACCOUNT" + + return _Wallet() + + +eth_account_module.Account = getattr(eth_account_module, "Account", _AccountStub) diff --git a/tests/integration/test_candles_snapshot_tool.py b/tests/integration/test_candles_snapshot_tool.py new file mode 100644 index 0000000..631cf9f --- /dev/null +++ b/tests/integration/test_candles_snapshot_tool.py @@ -0,0 +1,90 @@ +"""批量 K 线工具集成测试""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock + +import main + + +def test_get_candles_snapshot_tool_limit(monkeypatch): + """验证工具能调用服务并应用 limit""" + mock_service = MagicMock() + mock_service.get_candles_snapshot_bulk = AsyncMock( + return_value={ + "success": True, + "data": { + "BTC": [ + { + "timestamp": 1, + "open": 1, + "high": 2, + "low": 0.5, + "close": 1.5, + "volume": 10, + "trade_count": 5, + }, + { + "timestamp": 2, + "open": 1.5, + "high": 2.5, + "low": 1, + "close": 2, + "volume": 12, + "trade_count": 6, + }, + ], + "ETH": [ + { + "timestamp": 3, + "open": 3, + "high": 4, + "low": 2.5, + "close": 3.5, + "volume": 15, + "trade_count": 7, + }, + ], + }, + "interval": "1h", + "start_time": 1, + "end_time": 2, + } + ) + + monkeypatch.setattr(main, "initialize_service", lambda: None) + monkeypatch.setattr(main, "hyperliquid_service", mock_service) + + response = asyncio.run( + main.get_candles_snapshot( + coins=["BTC", "ETH"], + interval="1h", + days=1, + limit=1, + ) + ) + + assert response["success"] is True + assert response["limit_per_coin"] == 1 + assert len(response["data"]["BTC"]) == 1 + assert response["data"]["BTC"][0]["timestamp"] == 2 + + mock_service.get_candles_snapshot_bulk.assert_awaited_once() + + +def test_get_candles_snapshot_tool_validation_error(monkeypatch): + """非法输入参数返回结构化错误""" + monkeypatch.setattr(main, "initialize_service", lambda: None) + monkeypatch.setattr(main, "hyperliquid_service", MagicMock()) + + response = asyncio.run( + main.get_candles_snapshot( + coins=["BTC"], + interval="1h", + start_time=None, + end_time=None, + days=None, + ) + ) + + assert response["success"] is False + assert response.get("error_code") == "VALIDATION_ERROR" diff --git a/tests/unit/test_candles_snapshot.py b/tests/unit/test_candles_snapshot.py new file mode 100644 index 0000000..b48cc71 --- /dev/null +++ b/tests/unit/test_candles_snapshot.py @@ -0,0 +1,123 @@ +"""批量 K 线快照服务测试""" + +import asyncio +from unittest.mock import MagicMock, patch + +import pytest + +from services.hyperliquid_services import HyperliquidServices + + +@pytest.fixture +def service_with_mocks(monkeypatch): + """创建带有 Info/Exchange mock 的服务实例""" + with ( + patch("services.hyperliquid_services.Info") as mock_info_class, + patch("services.hyperliquid_services.Exchange") as mock_exchange_class, + patch("eth_account.Account") as mock_account_class, + ): + mock_wallet = MagicMock() + mock_wallet.address = "0xFAKE_WALLET" + mock_account_class.from_key.return_value = mock_wallet + + info_instance = MagicMock() + mock_info_class.return_value = info_instance + mock_exchange_class.return_value = MagicMock() + + service = HyperliquidServices( + private_key="0x" + "1" * 64, + testnet=True, + account_address="0xACCOUNT", + ) + + # 使用实际构造出的 info instance 以便校验调用 + service.info = info_instance + yield service, info_instance + + +def test_get_candles_snapshot_bulk_days_success(service_with_mocks, monkeypatch): + """验证 days 参数路径可正确获取并整理数据""" + service, info_instance = service_with_mocks + + fixed_time = 1_700_000_000.0 + monkeypatch.setattr("services.hyperliquid_services.time.time", lambda: fixed_time) + + info_instance.candles_snapshot.return_value = [ + { + "t": 1_699_913_600_000, + "o": "1", + "h": "2", + "l": "0.5", + "c": "1.5", + "v": "10", + "n": 5, + }, + { + "t": 1_699_917_600_000, + "o": "1.5", + "h": "2.5", + "l": "1", + "c": "2", + "v": "12", + "n": 6, + }, + ] + + result = asyncio.run( + service.get_candles_snapshot_bulk(["BTC", "BTC"], "1h", days=1) + ) + + assert result["success"] is True + assert list(result["data"].keys()) == ["BTC"] + assert len(result["data"]["BTC"]) == 2 + + expected_end = int(fixed_time * 1000) + expected_start = expected_end - 86_400_000 + + info_instance.candles_snapshot.assert_called_once_with( + "BTC", "1h", expected_start, expected_end + ) + + +def test_get_candles_snapshot_bulk_coin_error(service_with_mocks): + """当部分币种失败时依然返回成功并附带错误信息""" + service, info_instance = service_with_mocks + + def side_effect(coin, interval, start, end): + if coin == "BTC": + return [{"t": 1, "o": "1", "h": "2", "l": "0.5", "c": "1.5", "v": "10"}] + raise RuntimeError("coin not supported") + + info_instance.candles_snapshot.side_effect = side_effect + + result = asyncio.run( + service.get_candles_snapshot_bulk( + ["BTC", "ETH"], + "1h", + start_time=1, + end_time=2, + ) + ) + + assert result["success"] is True + assert "BTC" in result["data"] + assert "coin_errors" in result + assert result["coin_errors"].get("ETH") == "coin not supported" + + +def test_get_candles_snapshot_bulk_invalid_params(service_with_mocks): + """非法参数组合应返回失败""" + service, _ = service_with_mocks + + result = asyncio.run( + service.get_candles_snapshot_bulk( + ["BTC"], + "1h", + start_time=1, + end_time=2, + days=1, + ) + ) + + assert result["success"] is False + assert "days cannot be used" in result["error"]