Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 53 additions & 4 deletions src/steam_mcp/client/steam_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Steam API Client with rate limiting, caching, and error handling.

This client provides a robust interface to the Steam Web API with:
- Rate limiting to avoid hitting API limits
- Rate limiting to avoid hitting API limits (global singleton by default)
- TTL-based response caching to reduce API load
- Automatic retry with exponential backoff
- Consistent error handling across all endpoints
Expand All @@ -21,6 +21,12 @@

logger = logging.getLogger(__name__)

# Default rate limit (requests per second)
DEFAULT_RATE_LIMIT = 10.0

# Global rate limiter instance (singleton)
_global_rate_limiter: "RateLimiter | None" = None


class SteamAPIError(Exception):
"""Base exception for Steam API errors."""
Expand Down Expand Up @@ -67,6 +73,38 @@ async def acquire(self) -> None:
self.last_request_time = time.monotonic()


def get_global_rate_limiter() -> RateLimiter:
"""
Get or create the global rate limiter singleton.

The rate limit can be configured via the STEAM_RATE_LIMIT environment variable.
Default is 10 requests per second.

Thread-safe due to GIL; RateLimiter.__init__ is synchronous so no async race.

Returns:
The shared global RateLimiter instance.
"""
global _global_rate_limiter

if _global_rate_limiter is None:
rate_limit = float(os.getenv("STEAM_RATE_LIMIT", str(DEFAULT_RATE_LIMIT)))
_global_rate_limiter = RateLimiter(requests_per_second=rate_limit)
logger.debug(f"Created global rate limiter: {rate_limit} req/s")

return _global_rate_limiter


def reset_global_rate_limiter() -> None:
"""
Reset the global rate limiter.

This is primarily useful for testing to ensure a fresh state.
"""
global _global_rate_limiter
_global_rate_limiter = None


class SteamClient:
"""Async client for the Steam Web API."""

Expand All @@ -92,11 +130,12 @@ def __init__(
self,
api_key: str | None = None,
owner_steam_id: str | None = None,
requests_per_second: float = 10.0,
rate_limiter: RateLimiter | None = None,
max_retries: int = 3,
timeout: float = 30.0,
enable_cache: bool = True,
cache_max_size: int = 1000,
requests_per_second: float | None = None,
):
"""
Initialize Steam API client.
Expand All @@ -105,11 +144,14 @@ def __init__(
api_key: Steam Web API key. If not provided, reads from STEAM_API_KEY env var.
owner_steam_id: SteamID64 of the API key owner. If not provided, reads from
STEAM_USER_ID env var. This enables "get my profile" style queries.
requests_per_second: Rate limit for API requests.
rate_limiter: Optional custom RateLimiter instance. If not provided, uses the
global shared rate limiter (configured via STEAM_RATE_LIMIT env var).
max_retries: Maximum number of retry attempts for failed requests.
timeout: Request timeout in seconds.
enable_cache: Whether to enable response caching (default: True).
cache_max_size: Maximum number of cached entries (default: 1000).
requests_per_second: Deprecated. Use rate_limiter or STEAM_RATE_LIMIT env var.
Creates a dedicated RateLimiter for this client if provided.
"""
self.api_key = api_key or os.getenv("STEAM_API_KEY")
if not self.api_key:
Expand All @@ -120,7 +162,14 @@ def __init__(

self.max_retries = max_retries
self.timeout = timeout
self.rate_limiter = RateLimiter(requests_per_second)

# Rate limiter priority: explicit rate_limiter > requests_per_second > global
if rate_limiter is not None:
self.rate_limiter = rate_limiter
elif requests_per_second is not None:
self.rate_limiter = RateLimiter(requests_per_second=requests_per_second)
else:
self.rate_limiter = get_global_rate_limiter()

# Initialize cache if enabled
self._cache: TTLCache | None = None
Expand Down
211 changes: 211 additions & 0 deletions tests/test_global_rate_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
"""Tests for global rate limiter functionality."""

import asyncio
import pytest
from unittest.mock import patch

from steam_mcp.client.steam_client import (
DEFAULT_RATE_LIMIT,
RateLimiter,
SteamClient,
get_global_rate_limiter,
reset_global_rate_limiter,
)


@pytest.fixture(autouse=True)
def reset_limiter():
"""Reset global rate limiter before each test."""
reset_global_rate_limiter()
yield
reset_global_rate_limiter()


@pytest.fixture
def mock_env():
"""Set up mock environment variables."""
with patch.dict("os.environ", {"STEAM_API_KEY": "test_key"}, clear=False):
yield


class TestGetGlobalRateLimiter:
"""Tests for get_global_rate_limiter function."""

def test_returns_rate_limiter_instance(self):
"""Should return a RateLimiter instance."""
limiter = get_global_rate_limiter()
assert isinstance(limiter, RateLimiter)

def test_returns_same_instance(self):
"""Should return the same instance on subsequent calls."""
limiter1 = get_global_rate_limiter()
limiter2 = get_global_rate_limiter()
assert limiter1 is limiter2

def test_uses_default_rate_limit(self):
"""Should use DEFAULT_RATE_LIMIT when env var not set."""
limiter = get_global_rate_limiter()
assert limiter.requests_per_second == DEFAULT_RATE_LIMIT

def test_uses_env_var_rate_limit(self):
"""Should use STEAM_RATE_LIMIT env var when set."""
with patch.dict("os.environ", {"STEAM_RATE_LIMIT": "5.0"}):
reset_global_rate_limiter()
limiter = get_global_rate_limiter()
assert limiter.requests_per_second == 5.0


class TestResetGlobalRateLimiter:
"""Tests for reset_global_rate_limiter function."""

def test_reset_creates_new_instance(self):
"""After reset, should create a new instance."""
limiter1 = get_global_rate_limiter()
reset_global_rate_limiter()
limiter2 = get_global_rate_limiter()
assert limiter1 is not limiter2


class TestSteamClientRateLimiter:
"""Tests for SteamClient rate limiter integration."""

def test_uses_global_rate_limiter_by_default(self, mock_env):
"""SteamClient should use global rate limiter by default."""
global_limiter = get_global_rate_limiter()
client = SteamClient()
assert client.rate_limiter is global_limiter

def test_multiple_clients_share_rate_limiter(self, mock_env):
"""Multiple SteamClient instances should share the same rate limiter."""
client1 = SteamClient()
client2 = SteamClient()
assert client1.rate_limiter is client2.rate_limiter

def test_can_use_custom_rate_limiter(self, mock_env):
"""SteamClient can use a custom rate limiter."""
custom_limiter = RateLimiter(requests_per_second=5.0)
client = SteamClient(rate_limiter=custom_limiter)
assert client.rate_limiter is custom_limiter
assert client.rate_limiter is not get_global_rate_limiter()

def test_backward_compat_requests_per_second(self, mock_env):
"""SteamClient should support deprecated requests_per_second param."""
client = SteamClient(requests_per_second=5.0)
# Should create a dedicated limiter, not use global
assert client.rate_limiter is not get_global_rate_limiter()
assert client.rate_limiter.requests_per_second == 5.0

def test_rate_limiter_takes_precedence_over_requests_per_second(self, mock_env):
"""Explicit rate_limiter should take precedence over requests_per_second."""
custom_limiter = RateLimiter(requests_per_second=20.0)
client = SteamClient(rate_limiter=custom_limiter, requests_per_second=5.0)
assert client.rate_limiter is custom_limiter
assert client.rate_limiter.requests_per_second == 20.0


class TestConcurrentAccess:
"""Tests for concurrent access to rate limiter."""

@pytest.mark.asyncio
async def test_concurrent_acquire_is_serialized(self):
"""Concurrent acquire calls should be properly serialized."""
limiter = RateLimiter(requests_per_second=100.0) # Fast for testing
acquire_times = []

async def track_acquire():
await limiter.acquire()
acquire_times.append(asyncio.get_event_loop().time())

# Launch multiple concurrent acquires
tasks = [asyncio.create_task(track_acquire()) for _ in range(5)]
await asyncio.gather(*tasks)

# Verify all acquired (should have 5 entries)
assert len(acquire_times) == 5

@pytest.mark.asyncio
async def test_global_limiter_serializes_across_clients(self, mock_env):
"""Global rate limiter should serialize requests across multiple clients."""
reset_global_rate_limiter()

# Create multiple clients sharing the global limiter
client1 = SteamClient()
client2 = SteamClient()

# Verify they share the same rate limiter
assert client1.rate_limiter is client2.rate_limiter

acquire_count = 0
lock = asyncio.Lock()

async def acquire_from_client(client):
nonlocal acquire_count
await client.rate_limiter.acquire()
async with lock:
acquire_count += 1

# Launch concurrent acquires from different clients
tasks = [
asyncio.create_task(acquire_from_client(client1)),
asyncio.create_task(acquire_from_client(client2)),
asyncio.create_task(acquire_from_client(client1)),
asyncio.create_task(acquire_from_client(client2)),
]
await asyncio.gather(*tasks)

# All should have completed
assert acquire_count == 4

@pytest.mark.asyncio
async def test_rate_limiter_enforces_rate(self):
"""Rate limiter should enforce the configured rate limit."""
import time

# Very slow rate to make timing measurable
limiter = RateLimiter(requests_per_second=10.0) # 100ms between requests

start = time.monotonic()

# Make 3 requests
await limiter.acquire()
await limiter.acquire()
await limiter.acquire()

elapsed = time.monotonic() - start

# Should take at least 200ms for 3 requests at 10 req/s
# (first is instant, second waits 100ms, third waits 100ms)
assert elapsed >= 0.18 # Allow small margin for timing variance

@pytest.mark.asyncio
async def test_no_race_conditions_under_load(self, mock_env):
"""Global limiter should handle many concurrent requests without races."""
reset_global_rate_limiter()

# Use a faster rate for this test
with patch.dict("os.environ", {"STEAM_RATE_LIMIT": "1000.0"}):
reset_global_rate_limiter()

# Create multiple clients
clients = [SteamClient() for _ in range(3)]

# Verify all share the same limiter
assert all(c.rate_limiter is clients[0].rate_limiter for c in clients)

results = []
lock = asyncio.Lock()

async def acquire_and_record(client_idx):
client = clients[client_idx % len(clients)]
await client.rate_limiter.acquire()
async with lock:
results.append(client_idx)

# Launch many concurrent requests
tasks = [
asyncio.create_task(acquire_and_record(i)) for i in range(20)
]
await asyncio.gather(*tasks)

# All should complete without errors
assert len(results) == 20