From 4adc6ddd6d762291183a608f0eed7a598a8543e2 Mon Sep 17 00:00:00 2001
From: Wenjing Yu
Date: Sat, 9 Aug 2025 16:58:20 -0700
Subject: [PATCH] implement cache for provider instance

Cache adapter instances in ProviderService._adapters_cache, keyed by
provider name, base URL and a digest of the adapter config. The
list-models path and process_request now go through
_get_or_create_adapter() instead of calling
ProviderAdapterFactory.get_adapter() directly.

The config part of the key is a SHA-256 digest of the JSON-serialized
config rather than hash(frozenset(config.items())), which raises
TypeError as soon as a config value is a list or a nested dict.

_get_adapters() no longer pre-populates the cache through
ProviderAdapterFactory.get_all_adapters(); it now starts from an empty
dict whose keys are the composite cache keys described above.
---
 app/services/provider_service.py | 39 +++++++++++++++++++++++++++++++++++----
 1 file changed, 35 insertions(+), 4 deletions(-)

diff --git a/app/services/provider_service.py b/app/services/provider_service.py
index ca0c256..9945ef3 100644
--- a/app/services/provider_service.py
+++ b/app/services/provider_service.py
@@ -154,9 +154,40 @@ async def cache_models(
     def _get_adapters(self) -> dict[str, ProviderAdapter]:
         """Get adapters from cache or create new ones"""
         if not ProviderService._adapters_cache:
-            ProviderService._adapters_cache = ProviderAdapterFactory.get_all_adapters()
+            ProviderService._adapters_cache = {}
         return ProviderService._adapters_cache
 
+    def _get_or_create_adapter(
+        self,
+        provider_name: str,
+        base_url: str | None = None,
+        config: dict[str, Any] | None = None,
+    ) -> ProviderAdapter:
+        """Return a cached adapter for these settings, creating it on first use.
+
+        The cache key combines the provider name, the base URL ("default" when
+        unset) and a SHA-256 digest of the JSON-serialized config, so adapters
+        built from different settings never share an entry.
+        """
+        import hashlib
+        import json
+
+        # Digest a canonical serialization rather than hash()-ing the items:
+        # config values may be lists or nested dicts, which are unhashable and
+        # would make frozenset(config.items()) raise TypeError.
+        config_json = json.dumps(config or {}, sort_keys=True, default=repr)
+        config_hash = hashlib.sha256(config_json.encode("utf-8")).hexdigest()
+        cache_key = f"{provider_name}:{base_url or 'default'}:{config_hash}"
+
+        if cache_key not in self._adapters_cache:
+            adapter = ProviderAdapterFactory.get_adapter(provider_name, base_url, config)
+            self._adapters_cache[cache_key] = adapter
+            logger.debug("Created new adapter instance for %s", cache_key)
+        else:
+            logger.debug("Using cached adapter instance for %s", cache_key)
+
+        return self._adapters_cache[cache_key]
+
     async def _load_provider_keys(self) -> dict[str, dict[str, Any]]:
         """Load all provider keys for the user synchronously, with lazy loading and caching."""
         if self._keys_loaded:
@@ -437,7 +468,7 @@ async def _list_models_helper(
                 base_url = self.provider_keys[provider_name]["base_url"]
             tasks.append(
                 _list_models_helper(
-                    ProviderAdapterFactory.get_adapter(provider_name, base_url, config),
+                    self._get_or_create_adapter(provider_name, base_url, config),
                     api_key,
                     provider_data,
                 )
@@ -503,8 +534,8 @@ async def process_request(
             serialized_api_key_config
         )
 
-        # Get the appropriate adapter
-        adapter = ProviderAdapterFactory.get_adapter(provider_name, base_url, config)
+        # Get the appropriate adapter (cached)
+        adapter = self._get_or_create_adapter(provider_name, base_url, config)
 
         # Process the request through the adapter
         usage_tracker_id = None