diff --git a/.gitignore b/.gitignore index f634e82..f001284 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ htmlcov # common virtual environment names venv* env +.venv # editors .idea diff --git a/pori_python/graphkb/match.py b/pori_python/graphkb/match.py index 29c8cf3..8fc1692 100644 --- a/pori_python/graphkb/match.py +++ b/pori_python/graphkb/match.py @@ -31,12 +31,7 @@ looks_like_rid, stringifyVariant, ) -from .vocab import ( - get_equivalent_terms, - get_term_by_name, - get_term_tree, - get_terms_set, -) +from .vocab import get_equivalent_terms, get_term_by_name, get_term_tree, get_terms_set FEATURES_CACHE: Set[str] = set() @@ -111,22 +106,6 @@ def get_equivalent_features( ) -def cache_missing_features(conn: GraphKBConnection) -> None: - """ - Create a cache of features that exist to avoid repeatedly querying - for missing features - """ - genes = cast( - List[Ontology], - conn.query({'target': 'Feature', 'returnProperties': ['name', 'sourceId'], 'neighbors': 0}), - ) - for gene in genes: - if gene['name']: - FEATURES_CACHE.add(gene['name'].lower()) - if gene['sourceId']: - FEATURES_CACHE.add(gene['sourceId'].lower()) - - def match_category_variant( conn: GraphKBConnection, reference_name: str, diff --git a/pori_python/graphkb/util.py b/pori_python/graphkb/util.py index 23c2896..382dba6 100644 --- a/pori_python/graphkb/util.py +++ b/pori_python/graphkb/util.py @@ -8,6 +8,7 @@ import re import time from datetime import datetime +from requests_cache import CacheMixin from typing import Any, Dict, Iterable, List, Optional, Union, cast from urllib3.util.retry import Retry from urllib.parse import urlsplit @@ -16,8 +17,6 @@ from .constants import DEFAULT_LIMIT, TYPES_TO_NOTATION, AA_3to1_MAPPING -QUERY_CACHE: Dict[Any, Any] = {} - # name the logger after the package to make it simple to disable for packages using this one as a dependency # https://stackoverflow.com/questions/11029717/how-do-i-disable-log-messages-from-the-requests-library @@ -88,11 +87,8 @@ def millis_interval(start: datetime, end: datetime) -> int: return millis -def cache_key(request_body) -> str: - """Create a cache key for a query request to GraphKB.""" - body = json.dumps(request_body, sort_keys=True) - hash_code = hashlib.md5(f'/query{body}'.encode('utf-8')).hexdigest() - return hash_code +class CustomSession(CacheMixin, requests.Session): + pass class GraphKBConnection: @@ -102,8 +98,35 @@ def __init__( username: str = '', password: str = '', use_global_cache: bool = True, + cache_name: str = '', + **session_kwargs, ): - self.http = requests.Session() + """ + Docstring for __init__ + + Args: + - use_global_cache: cache requests across all requests to GKB + - cache_name: Path or connection URL to the database which stors the requests cache. see https://requests-cache.readthedocs.io/en/v0.6.4/user_guide.html#cache-name + """ + if use_global_cache: + if not cache_name: + self.http = CustomSession( + backend='memory', + cache_control=True, + allowable_methods=['GET', 'POST'], + ignored_parameters=['Authorization'], + **session_kwargs, + ) + else: + self.http = CustomSession( + cache_name, + cache_control=True, + allowable_methods=['GET', 'POST'], + ignored_parameters=['Authorization'], + **session_kwargs, + ) + else: + self.http = requests.Session(**session_kwargs) retries = Retry( total=100, connect=5, @@ -117,8 +140,10 @@ def __init__( self.url = url self.username = username self.password = password - self.headers = {'Accept': 'application/json', 'Content-Type': 'application/json'} - self.cache: Dict[Any, Any] = {} if not use_global_cache else QUERY_CACHE + self.headers = { + 'Accept': 'application/json', + 'Content-Type': 'application/json', + } self.request_count = 0 self.first_request: Optional[datetime] = None self.last_request: Optional[datetime] = None @@ -137,7 +162,13 @@ def load(self) -> Optional[float]: return self.request_count * 1000 / msec return None - def request(self, endpoint: str, method: str = 'GET', **kwargs) -> Dict: + def request( + self, + endpoint: str, + method: str = 'GET', + headers: Optional[dict[str, str]] = None, + **kwargs, + ) -> Dict: """Request wrapper to handle adding common headers and logging. Args: @@ -158,6 +189,11 @@ def request(self, endpoint: str, method: str = 'GET', **kwargs) -> Dict: if endpoint in ['query', 'parse']: timeout = (connect_timeout, read_timeout) + request_headers = {} + request_headers.update(self.headers) + if headers is not None: + request_headers.update(headers) + start_time = datetime.now() if not self.first_request: @@ -179,8 +215,8 @@ def request(self, endpoint: str, method: str = 'GET', **kwargs) -> Dict: need_refresh_login = False self.request_count += 1 - resp = requests.request( - method, url, headers=self.headers, timeout=timeout, **kwargs + resp = self.http.request( + method, url, headers=request_headers, timeout=timeout, **kwargs ) if resp.status_code == 401 or resp.status_code == 403: logger.debug(f'/{endpoint} - {resp.status_code} - retrying') @@ -293,11 +329,6 @@ def login(self, username: str, password: str, pori_demo: bool = False) -> None: def refresh_login(self) -> None: self.login(self.username, self.password) - def set_cache_data(self, request_body: Dict, result: List[Record]) -> None: - """Explicitly add a query to the cache.""" - hash_code = cache_key(request_body) - self.cache[hash_code] = result - def query( self, request_body: Dict = {}, @@ -309,23 +340,22 @@ def query( """ Query GraphKB """ - result: List[Record] = [] - hash_code = '' - - if not ignore_cache and paginate: - hash_code = cache_key(request_body) - if hash_code in self.cache and not force_refresh: - return self.cache[hash_code] + headers = {} + if ignore_cache or force_refresh: + headers = {'Cache-Control': 'no-cache'} + result: List[Record] = [] while True: - content = self.post('query', data={**request_body, 'limit': limit, 'skip': len(result)}) + content = self.post( + 'query', + data={**request_body, 'limit': limit, 'skip': len(result)}, + headers=headers, + ) records = content['result'] result.extend(records) if len(records) < limit or not paginate: break - if not ignore_cache and paginate: - self.cache[hash_code] = result return result def parse(self, hgvs_string: str, requireFeatures: bool = False) -> ParsedVariant: diff --git a/pori_python/graphkb/vocab.py b/pori_python/graphkb/vocab.py index e9242a7..55ed222 100644 --- a/pori_python/graphkb/vocab.py +++ b/pori_python/graphkb/vocab.py @@ -181,9 +181,6 @@ def get_terms_set( ) -> Set[str]: """Get a set of vocabulary rids given some base/parent term names.""" base_terms = [base_terms] if isinstance(base_terms, str) else base_terms - cache_key = tuple(sorted(base_terms)) - if graphkb_conn.cache.get(cache_key, None) and not ignore_cache: - return graphkb_conn.cache[cache_key] terms = set() for base_term in base_terms: terms.update( @@ -193,6 +190,4 @@ def get_terms_set( ) ) ) - if not ignore_cache: - graphkb_conn.cache[cache_key] = terms return terms diff --git a/setup.cfg b/setup.cfg index 1d97cc2..3f2c365 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,6 +36,7 @@ install_requires = requests tqdm typing_extensions>=3.7.4.2,<5 + requests-cache[sqlite] [options.extras_require] deploy = twine; wheel; m2r diff --git a/tests/test_graphkb/test_match.py b/tests/test_graphkb/test_match.py index 7a34b90..1e6b54d 100644 --- a/tests/test_graphkb/test_match.py +++ b/tests/test_graphkb/test_match.py @@ -563,22 +563,6 @@ def test_structural_variants(self, conn): assert type not in expected.get('does_not_matches', {}).get('type', []) -class TestCacheMissingFeatures: - def test_filling_cache(self): - mock_conn = MagicMock( - query=MagicMock( - return_value=[ - {'name': 'bob', 'sourceId': 'alice'}, - {'name': 'KRAS', 'sourceId': '1234'}, - ] - ) - ) - match.cache_missing_features(mock_conn) - assert 'kras' in match.FEATURES_CACHE - assert 'alice' in match.FEATURES_CACHE - match.FEATURES_CACHE = None - - class TestTypeScreening: # Types as class variables default_type = DEFAULT_NON_STRUCTURAL_VARIANT_TYPE diff --git a/tests/test_graphkb/test_statement.py b/tests/test_graphkb/test_statement.py index 89faafd..550191e 100644 --- a/tests/test_graphkb/test_statement.py +++ b/tests/test_graphkb/test_statement.py @@ -17,11 +17,21 @@ def conn() -> GraphKBConnection: @pytest.fixture() def graphkb_conn(): + """ + Mocks the query functionality required by the calls made for the categorize_relevance function + + categorize_relevance calls query twice for each term in the values of the category_base_terms object + + - get_terms_set ([term, term, term...]) + - get_term_tree(term) + - query(query(term)) + """ + def make_rid_list(*values): return [{'@rid': v} for v in values] def term_tree_calls(*final_values): - # this function makes 2 calls to conn.query here + # this function makes 2 calls to conn.query here b/c the get_terms_set function will always call query twice sets = [['fake'], final_values] return [make_rid_list(*s) for s in sets] @@ -41,7 +51,7 @@ def term_tree_calls(*final_values): query_mock = Mock() query_mock.side_effect = return_values - return Mock(query=query_mock, cache={}) + return Mock(query=query_mock) class TestCategorizeRelevance: @@ -77,12 +87,13 @@ def test_no_match(self, graphkb_conn): category = statement.categorize_relevance(graphkb_conn, 'x') assert category == '' - def test_custom_categories(self, graphkb_conn): + def test_custom_categories_not_found(self, graphkb_conn): category = statement.categorize_relevance( graphkb_conn, 'x', [('blargh', ['some', 'blargh'])] ) assert category == '' + def test_custom_categories_match(self, graphkb_conn): category = statement.categorize_relevance( graphkb_conn, '1', [('blargh', ['some', 'blargh'])] )