1 change: 1 addition & 0 deletions .gitignore
@@ -15,6 +15,7 @@ htmlcov
# common virtual environment names
venv*
env
.venv

# editors
.idea
23 changes: 1 addition & 22 deletions pori_python/graphkb/match.py
@@ -31,12 +31,7 @@
looks_like_rid,
stringifyVariant,
)
from .vocab import (
get_equivalent_terms,
get_term_by_name,
get_term_tree,
get_terms_set,
)
from .vocab import get_equivalent_terms, get_term_by_name, get_term_tree, get_terms_set

FEATURES_CACHE: Set[str] = set()

@@ -111,22 +106,6 @@ def get_equivalent_features(
)


def cache_missing_features(conn: GraphKBConnection) -> None:
"""
Create a cache of features that exist to avoid repeatedly querying
for missing features
"""
genes = cast(
List[Ontology],
conn.query({'target': 'Feature', 'returnProperties': ['name', 'sourceId'], 'neighbors': 0}),
)
for gene in genes:
if gene['name']:
FEATURES_CACHE.add(gene['name'].lower())
if gene['sourceId']:
FEATURES_CACHE.add(gene['sourceId'].lower())


def match_category_variant(
conn: GraphKBConnection,
reference_name: str,
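With query results now cached at the HTTP layer (see the util.py changes below), the explicit FEATURES_CACHE priming helper is no longer needed. A minimal before/after sketch, assuming get_equivalent_features keeps the signature shown above; the import path and call are illustrative, not taken from this diff:

```python
# Hypothetical usage: repeated feature lookups are deduplicated by the cached
# requests session on the connection, not by a pre-primed module-level set.
from pori_python.graphkb import GraphKBConnection  # import path assumed
from pori_python.graphkb import match

conn = GraphKBConnection()  # cached session; see the util.py changes below
# match.cache_missing_features(conn)  # removed in this change; no longer required
features = match.get_equivalent_features(conn, 'KRAS')  # first call hits the API
features = match.get_equivalent_features(conn, 'KRAS')  # repeat served from the HTTP cache
```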
86 changes: 58 additions & 28 deletions pori_python/graphkb/util.py
@@ -8,6 +8,7 @@
import re
import time
from datetime import datetime
from requests_cache import CacheMixin
from typing import Any, Dict, Iterable, List, Optional, Union, cast
from urllib3.util.retry import Retry
from urllib.parse import urlsplit
@@ -16,8 +17,6 @@

from .constants import DEFAULT_LIMIT, TYPES_TO_NOTATION, AA_3to1_MAPPING

QUERY_CACHE: Dict[Any, Any] = {}

# name the logger after the package to make it simple to disable for packages using this one as a dependency
# https://stackoverflow.com/questions/11029717/how-do-i-disable-log-messages-from-the-requests-library

@@ -88,11 +87,8 @@ def millis_interval(start: datetime, end: datetime) -> int:
return millis


def cache_key(request_body) -> str:
"""Create a cache key for a query request to GraphKB."""
body = json.dumps(request_body, sort_keys=True)
hash_code = hashlib.md5(f'/query{body}'.encode('utf-8')).hexdigest()
return hash_code
class CustomSession(CacheMixin, requests.Session):
pass
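For context, this is the mixin pattern recommended by requests-cache: composing CacheMixin with requests.Session produces the equivalent of requests_cache.CachedSession while keeping the project's own base class. A minimal behavioural sketch, assuming requests-cache 0.6+; the URL is a placeholder:

```python
import requests
from requests_cache import CacheMixin

class CachedClient(CacheMixin, requests.Session):
    pass

session = CachedClient(
    'demo_cache',  # SQLite cache name; pass backend='memory' for an in-process cache
    cache_control=True,  # honour HTTP Cache-Control headers
    allowable_methods=['GET', 'POST'],  # cache POST too (GraphKB queries are POSTs)
    ignored_parameters=['Authorization'],  # keep auth credentials out of the cache key
)
resp = session.get('https://httpbin.org/get')  # first call goes over the network
resp = session.get('https://httpbin.org/get')  # answered from the cache
print(resp.from_cache)  # True for a cached response
```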


class GraphKBConnection:
@@ -102,8 +98,35 @@ def __init__(
username: str = '',
password: str = '',
use_global_cache: bool = True,
cache_name: str = '',
**session_kwargs,
):
self.http = requests.Session()
"""
Docstring for __init__

Args:
- use_global_cache: cache requests across all requests to GKB
- cache_name: Path or connection URL to the database which stors the requests cache. see https://requests-cache.readthedocs.io/en/v0.6.4/user_guide.html#cache-name
"""
if use_global_cache:
if not cache_name:
self.http = CustomSession(
backend='memory',
cache_control=True,
allowable_methods=['GET', 'POST'],
ignored_parameters=['Authorization'],
**session_kwargs,
)
else:
self.http = CustomSession(
cache_name,
cache_control=True,
allowable_methods=['GET', 'POST'],
ignored_parameters=['Authorization'],
**session_kwargs,
)
else:
self.http = requests.Session(**session_kwargs)
retries = Retry(
total=100,
connect=5,
@@ -117,8 +140,10 @@ def __init__(
self.url = url
self.username = username
self.password = password
self.headers = {'Accept': 'application/json', 'Content-Type': 'application/json'}
self.cache: Dict[Any, Any] = {} if not use_global_cache else QUERY_CACHE
self.headers = {
'Accept': 'application/json',
'Content-Type': 'application/json',
}
self.request_count = 0
self.first_request: Optional[datetime] = None
self.last_request: Optional[datetime] = None
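A hypothetical construction of the connection under the new options; the URL is a placeholder, and only the keyword names shown in this diff are taken from the code:

```python
# Default: in-memory cache shared across requests made through this connection
conn = GraphKBConnection(url='https://graphkb-api.example.org/api')

# Persistent SQLite-backed cache that survives across processes
conn = GraphKBConnection(
    url='https://graphkb-api.example.org/api',
    cache_name='graphkb_cache.sqlite',
)

# Opt out of caching entirely: a plain requests.Session
conn = GraphKBConnection(
    url='https://graphkb-api.example.org/api',
    use_global_cache=False,
)
```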
@@ -137,7 +162,13 @@ def load(self) -> Optional[float]:
return self.request_count * 1000 / msec
return None

def request(self, endpoint: str, method: str = 'GET', **kwargs) -> Dict:
def request(
self,
endpoint: str,
method: str = 'GET',
headers: Optional[dict[str, str]] = None,
**kwargs,
) -> Dict:
"""Request wrapper to handle adding common headers and logging.

Args:
@@ -158,6 +189,11 @@ def request(self, endpoint: str, method: str = 'GET', **kwargs) -> Dict:
if endpoint in ['query', 'parse']:
timeout = (connect_timeout, read_timeout)

request_headers = {}
request_headers.update(self.headers)
if headers is not None:
request_headers.update(headers)

start_time = datetime.now()

if not self.first_request:
@@ -179,8 +215,8 @@ def request(self, endpoint: str, method: str = 'GET', **kwargs) -> Dict:
need_refresh_login = False

self.request_count += 1
resp = requests.request(
method, url, headers=self.headers, timeout=timeout, **kwargs
resp = self.http.request(
method, url, headers=request_headers, timeout=timeout, **kwargs
)
if resp.status_code == 401 or resp.status_code == 403:
logger.debug(f'/{endpoint} - {resp.status_code} - retrying')
Expand Down Expand Up @@ -293,11 +329,6 @@ def login(self, username: str, password: str, pori_demo: bool = False) -> None:
def refresh_login(self) -> None:
self.login(self.username, self.password)

def set_cache_data(self, request_body: Dict, result: List[Record]) -> None:
"""Explicitly add a query to the cache."""
hash_code = cache_key(request_body)
self.cache[hash_code] = result

def query(
self,
request_body: Dict = {},
@@ -309,23 +340,22 @@ def query(
"""
Query GraphKB
"""
result: List[Record] = []
hash_code = ''

if not ignore_cache and paginate:
hash_code = cache_key(request_body)
if hash_code in self.cache and not force_refresh:
return self.cache[hash_code]
headers = {}
if ignore_cache or force_refresh:
headers = {'Cache-Control': 'no-cache'}

result: List[Record] = []
while True:
content = self.post('query', data={**request_body, 'limit': limit, 'skip': len(result)})
content = self.post(
'query',
data={**request_body, 'limit': limit, 'skip': len(result)},
headers=headers,
)
records = content['result']
result.extend(records)
if len(records) < limit or not paginate:
break

if not ignore_cache and paginate:
self.cache[hash_code] = result
return result

def parse(self, hgvs_string: str, requireFeatures: bool = False) -> ParsedVariant:
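Cache bypass is now expressed at the HTTP layer: because the session is built with cache_control=True, a Cache-Control: no-cache request header tells requests-cache to skip reading the stored response. A sketch of the resulting behaviour; the query body is illustrative:

```python
records = conn.query({'target': 'Disease'})  # repeat calls are answered from the cache
fresh = conn.query({'target': 'Disease'}, force_refresh=True)  # sends Cache-Control: no-cache
bypass = conn.query({'target': 'Disease'}, ignore_cache=True)  # same bypass header
```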
5 changes: 0 additions & 5 deletions pori_python/graphkb/vocab.py
@@ -181,9 +181,6 @@ def get_terms_set(
) -> Set[str]:
"""Get a set of vocabulary rids given some base/parent term names."""
base_terms = [base_terms] if isinstance(base_terms, str) else base_terms
cache_key = tuple(sorted(base_terms))
if graphkb_conn.cache.get(cache_key, None) and not ignore_cache:
return graphkb_conn.cache[cache_key]
terms = set()
for base_term in base_terms:
terms.update(
@@ -193,6 +190,4 @@
)
)
)
if not ignore_cache:
graphkb_conn.cache[cache_key] = terms
return terms
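With the per-connection dict removed, get_terms_set recomputes its result on every call and relies on the cached session underneath to absorb the repeated /query requests. A usage sketch; the base term and import path are illustrative:

```python
from pori_python.graphkb.vocab import get_terms_set

rids = get_terms_set(graphkb_conn, ['sensitivity'])  # runs the vocabulary queries
rids_again = get_terms_set(graphkb_conn, ['sensitivity'])  # same queries, served by the HTTP cache
```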
1 change: 1 addition & 0 deletions setup.cfg
@@ -36,6 +36,7 @@ install_requires =
requests
tqdm
typing_extensions>=3.7.4.2,<5
requests-cache[sqlite]

[options.extras_require]
deploy = twine; wheel; m2r
16 changes: 0 additions & 16 deletions tests/test_graphkb/test_match.py
@@ -563,22 +563,6 @@ def test_structural_variants(self, conn):
assert type not in expected.get('does_not_matches', {}).get('type', [])


class TestCacheMissingFeatures:
def test_filling_cache(self):
mock_conn = MagicMock(
query=MagicMock(
return_value=[
{'name': 'bob', 'sourceId': 'alice'},
{'name': 'KRAS', 'sourceId': '1234'},
]
)
)
match.cache_missing_features(mock_conn)
assert 'kras' in match.FEATURES_CACHE
assert 'alice' in match.FEATURES_CACHE
match.FEATURES_CACHE = None


class TestTypeScreening:
# Types as class variables
default_type = DEFAULT_NON_STRUCTURAL_VARIANT_TYPE
17 changes: 14 additions & 3 deletions tests/test_graphkb/test_statement.py
@@ -17,11 +17,21 @@ def conn() -> GraphKBConnection:

@pytest.fixture()
def graphkb_conn():
"""
Mock the query functionality required by the calls categorize_relevance makes.

categorize_relevance calls query twice for each term in the values of the category_base_terms argument:

- get_terms_set ([term, term, term...])
- get_term_tree(term)
- query(query(term))
"""

def make_rid_list(*values):
return [{'@rid': v} for v in values]

def term_tree_calls(*final_values):
# this function makes 2 calls to conn.query here
# this function makes 2 calls to conn.query here because get_terms_set will always call query twice
sets = [['fake'], final_values]
return [make_rid_list(*s) for s in sets]

@@ -41,7 +51,7 @@ def term_tree_calls(*final_values):

query_mock = Mock()
query_mock.side_effect = return_values
return Mock(query=query_mock, cache={})
return Mock(query=query_mock)


class TestCategorizeRelevance:
@@ -77,12 +87,13 @@ def test_no_match(self, graphkb_conn):
category = statement.categorize_relevance(graphkb_conn, 'x')
assert category == ''

def test_custom_categories(self, graphkb_conn):
def test_custom_categories_not_found(self, graphkb_conn):
category = statement.categorize_relevance(
graphkb_conn, 'x', [('blargh', ['some', 'blargh'])]
)
assert category == ''

def test_custom_categories_match(self, graphkb_conn):
category = statement.categorize_relevance(
graphkb_conn, '1', [('blargh', ['some', 'blargh'])]
)