diff --git a/src/steam_mcp/client/steam_client.py b/src/steam_mcp/client/steam_client.py index 995946b..b8085d9 100644 --- a/src/steam_mcp/client/steam_client.py +++ b/src/steam_mcp/client/steam_client.py @@ -1,7 +1,7 @@ """Steam API Client with rate limiting, caching, and error handling. This client provides a robust interface to the Steam Web API with: -- Rate limiting to avoid hitting API limits +- Rate limiting to avoid hitting API limits (global singleton by default) - TTL-based response caching to reduce API load - Automatic retry with exponential backoff - Consistent error handling across all endpoints @@ -21,6 +21,12 @@ logger = logging.getLogger(__name__) +# Default rate limit (requests per second) +DEFAULT_RATE_LIMIT = 10.0 + +# Global rate limiter instance (singleton) +_global_rate_limiter: "RateLimiter | None" = None + class SteamAPIError(Exception): """Base exception for Steam API errors.""" @@ -67,6 +73,38 @@ async def acquire(self) -> None: self.last_request_time = time.monotonic() +def get_global_rate_limiter() -> RateLimiter: + """ + Get or create the global rate limiter singleton. + + The rate limit can be configured via the STEAM_RATE_LIMIT environment variable. + Default is 10 requests per second. + + Thread-safe due to GIL; RateLimiter.__init__ is synchronous so no async race. + + Returns: + The shared global RateLimiter instance. + """ + global _global_rate_limiter + + if _global_rate_limiter is None: + rate_limit = float(os.getenv("STEAM_RATE_LIMIT", str(DEFAULT_RATE_LIMIT))) + _global_rate_limiter = RateLimiter(requests_per_second=rate_limit) + logger.debug(f"Created global rate limiter: {rate_limit} req/s") + + return _global_rate_limiter + + +def reset_global_rate_limiter() -> None: + """ + Reset the global rate limiter. + + This is primarily useful for testing to ensure a fresh state. + """ + global _global_rate_limiter + _global_rate_limiter = None + + class SteamClient: """Async client for the Steam Web API.""" @@ -92,11 +130,12 @@ def __init__( self, api_key: str | None = None, owner_steam_id: str | None = None, - requests_per_second: float = 10.0, + rate_limiter: RateLimiter | None = None, max_retries: int = 3, timeout: float = 30.0, enable_cache: bool = True, cache_max_size: int = 1000, + requests_per_second: float | None = None, ): """ Initialize Steam API client. @@ -105,11 +144,14 @@ def __init__( api_key: Steam Web API key. If not provided, reads from STEAM_API_KEY env var. owner_steam_id: SteamID64 of the API key owner. If not provided, reads from STEAM_USER_ID env var. This enables "get my profile" style queries. - requests_per_second: Rate limit for API requests. + rate_limiter: Optional custom RateLimiter instance. If not provided, uses the + global shared rate limiter (configured via STEAM_RATE_LIMIT env var). max_retries: Maximum number of retry attempts for failed requests. timeout: Request timeout in seconds. enable_cache: Whether to enable response caching (default: True). cache_max_size: Maximum number of cached entries (default: 1000). + requests_per_second: Deprecated. Use rate_limiter or STEAM_RATE_LIMIT env var. + Creates a dedicated RateLimiter for this client if provided. """ self.api_key = api_key or os.getenv("STEAM_API_KEY") if not self.api_key: @@ -120,7 +162,14 @@ def __init__( self.max_retries = max_retries self.timeout = timeout - self.rate_limiter = RateLimiter(requests_per_second) + + # Rate limiter priority: explicit rate_limiter > requests_per_second > global + if rate_limiter is not None: + self.rate_limiter = rate_limiter + elif requests_per_second is not None: + self.rate_limiter = RateLimiter(requests_per_second=requests_per_second) + else: + self.rate_limiter = get_global_rate_limiter() # Initialize cache if enabled self._cache: TTLCache | None = None diff --git a/tests/test_global_rate_limiter.py b/tests/test_global_rate_limiter.py new file mode 100644 index 0000000..848eefd --- /dev/null +++ b/tests/test_global_rate_limiter.py @@ -0,0 +1,211 @@ +"""Tests for global rate limiter functionality.""" + +import asyncio +import pytest +from unittest.mock import patch + +from steam_mcp.client.steam_client import ( + DEFAULT_RATE_LIMIT, + RateLimiter, + SteamClient, + get_global_rate_limiter, + reset_global_rate_limiter, +) + + +@pytest.fixture(autouse=True) +def reset_limiter(): + """Reset global rate limiter before each test.""" + reset_global_rate_limiter() + yield + reset_global_rate_limiter() + + +@pytest.fixture +def mock_env(): + """Set up mock environment variables.""" + with patch.dict("os.environ", {"STEAM_API_KEY": "test_key"}, clear=False): + yield + + +class TestGetGlobalRateLimiter: + """Tests for get_global_rate_limiter function.""" + + def test_returns_rate_limiter_instance(self): + """Should return a RateLimiter instance.""" + limiter = get_global_rate_limiter() + assert isinstance(limiter, RateLimiter) + + def test_returns_same_instance(self): + """Should return the same instance on subsequent calls.""" + limiter1 = get_global_rate_limiter() + limiter2 = get_global_rate_limiter() + assert limiter1 is limiter2 + + def test_uses_default_rate_limit(self): + """Should use DEFAULT_RATE_LIMIT when env var not set.""" + limiter = get_global_rate_limiter() + assert limiter.requests_per_second == DEFAULT_RATE_LIMIT + + def test_uses_env_var_rate_limit(self): + """Should use STEAM_RATE_LIMIT env var when set.""" + with patch.dict("os.environ", {"STEAM_RATE_LIMIT": "5.0"}): + reset_global_rate_limiter() + limiter = get_global_rate_limiter() + assert limiter.requests_per_second == 5.0 + + +class TestResetGlobalRateLimiter: + """Tests for reset_global_rate_limiter function.""" + + def test_reset_creates_new_instance(self): + """After reset, should create a new instance.""" + limiter1 = get_global_rate_limiter() + reset_global_rate_limiter() + limiter2 = get_global_rate_limiter() + assert limiter1 is not limiter2 + + +class TestSteamClientRateLimiter: + """Tests for SteamClient rate limiter integration.""" + + def test_uses_global_rate_limiter_by_default(self, mock_env): + """SteamClient should use global rate limiter by default.""" + global_limiter = get_global_rate_limiter() + client = SteamClient() + assert client.rate_limiter is global_limiter + + def test_multiple_clients_share_rate_limiter(self, mock_env): + """Multiple SteamClient instances should share the same rate limiter.""" + client1 = SteamClient() + client2 = SteamClient() + assert client1.rate_limiter is client2.rate_limiter + + def test_can_use_custom_rate_limiter(self, mock_env): + """SteamClient can use a custom rate limiter.""" + custom_limiter = RateLimiter(requests_per_second=5.0) + client = SteamClient(rate_limiter=custom_limiter) + assert client.rate_limiter is custom_limiter + assert client.rate_limiter is not get_global_rate_limiter() + + def test_backward_compat_requests_per_second(self, mock_env): + """SteamClient should support deprecated requests_per_second param.""" + client = SteamClient(requests_per_second=5.0) + # Should create a dedicated limiter, not use global + assert client.rate_limiter is not get_global_rate_limiter() + assert client.rate_limiter.requests_per_second == 5.0 + + def test_rate_limiter_takes_precedence_over_requests_per_second(self, mock_env): + """Explicit rate_limiter should take precedence over requests_per_second.""" + custom_limiter = RateLimiter(requests_per_second=20.0) + client = SteamClient(rate_limiter=custom_limiter, requests_per_second=5.0) + assert client.rate_limiter is custom_limiter + assert client.rate_limiter.requests_per_second == 20.0 + + +class TestConcurrentAccess: + """Tests for concurrent access to rate limiter.""" + + @pytest.mark.asyncio + async def test_concurrent_acquire_is_serialized(self): + """Concurrent acquire calls should be properly serialized.""" + limiter = RateLimiter(requests_per_second=100.0) # Fast for testing + acquire_times = [] + + async def track_acquire(): + await limiter.acquire() + acquire_times.append(asyncio.get_event_loop().time()) + + # Launch multiple concurrent acquires + tasks = [asyncio.create_task(track_acquire()) for _ in range(5)] + await asyncio.gather(*tasks) + + # Verify all acquired (should have 5 entries) + assert len(acquire_times) == 5 + + @pytest.mark.asyncio + async def test_global_limiter_serializes_across_clients(self, mock_env): + """Global rate limiter should serialize requests across multiple clients.""" + reset_global_rate_limiter() + + # Create multiple clients sharing the global limiter + client1 = SteamClient() + client2 = SteamClient() + + # Verify they share the same rate limiter + assert client1.rate_limiter is client2.rate_limiter + + acquire_count = 0 + lock = asyncio.Lock() + + async def acquire_from_client(client): + nonlocal acquire_count + await client.rate_limiter.acquire() + async with lock: + acquire_count += 1 + + # Launch concurrent acquires from different clients + tasks = [ + asyncio.create_task(acquire_from_client(client1)), + asyncio.create_task(acquire_from_client(client2)), + asyncio.create_task(acquire_from_client(client1)), + asyncio.create_task(acquire_from_client(client2)), + ] + await asyncio.gather(*tasks) + + # All should have completed + assert acquire_count == 4 + + @pytest.mark.asyncio + async def test_rate_limiter_enforces_rate(self): + """Rate limiter should enforce the configured rate limit.""" + import time + + # Very slow rate to make timing measurable + limiter = RateLimiter(requests_per_second=10.0) # 100ms between requests + + start = time.monotonic() + + # Make 3 requests + await limiter.acquire() + await limiter.acquire() + await limiter.acquire() + + elapsed = time.monotonic() - start + + # Should take at least 200ms for 3 requests at 10 req/s + # (first is instant, second waits 100ms, third waits 100ms) + assert elapsed >= 0.18 # Allow small margin for timing variance + + @pytest.mark.asyncio + async def test_no_race_conditions_under_load(self, mock_env): + """Global limiter should handle many concurrent requests without races.""" + reset_global_rate_limiter() + + # Use a faster rate for this test + with patch.dict("os.environ", {"STEAM_RATE_LIMIT": "1000.0"}): + reset_global_rate_limiter() + + # Create multiple clients + clients = [SteamClient() for _ in range(3)] + + # Verify all share the same limiter + assert all(c.rate_limiter is clients[0].rate_limiter for c in clients) + + results = [] + lock = asyncio.Lock() + + async def acquire_and_record(client_idx): + client = clients[client_idx % len(clients)] + await client.rate_limiter.acquire() + async with lock: + results.append(client_idx) + + # Launch many concurrent requests + tasks = [ + asyncio.create_task(acquire_and_record(i)) for i in range(20) + ] + await asyncio.gather(*tasks) + + # All should complete without errors + assert len(results) == 20