Skip to content

Commit 9fdb138

Browse files
committed
- improve caching mechanism to ensure zero network transfers when content is already cached from mirror
- extend cache tests
1 parent 634607f commit 9fdb138

File tree

8 files changed

+120
-38
lines changed

8 files changed

+120
-38
lines changed

aura/analyzers/python/visitor.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,11 @@
44
from __future__ import annotations
55

66
import os
7-
import copy
87
import time
98
from functools import partial, wraps
109
from collections import deque, OrderedDict
11-
from pathlib import Path
1210
from warnings import warn
13-
from typing import Optional, Tuple, Union
11+
from typing import Optional, Tuple, Union, Dict
1412

1513
import pkg_resources
1614

@@ -45,7 +43,7 @@ def wrapper(*args, **kwargs):
4543
return wrapper
4644

4745

48-
def get_ast_tree(location: Union[ScanLocation, bytes], metadata=None):
46+
def get_ast_tree(location: Union[ScanLocation, bytes], metadata=None) -> dict:
4947
if type(location) == bytes:
5048
kwargs = {
5149
"command": [INSPECTOR_PATH, "-"],
@@ -103,7 +101,7 @@ def __init__(self, *, location: ScanLocation):
103101
self.max_queue_size = int(config.get_settings("aura.max-ast-queue-size", 10000))
104102

105103
@classmethod
106-
def from_visitor(cls, visitor: Visitor):
104+
def from_visitor(cls, visitor: Visitor) -> Visitor:
107105
obj = cls(location=visitor.location)
108106
obj.tree = visitor.tree
109107
obj.hits = visitor.hits
@@ -138,7 +136,7 @@ def run_stages(cls, *, location: ScanLocation, stages: Optional[Tuple[str, ...]]
138136
return v
139137

140138
@classmethod
141-
def get_visitors(cls):
139+
def get_visitors(cls) -> Dict[str, Visitor]:
142140
global VISITORS
143141
if VISITORS is None:
144142
VISITORS = {

aura/cache.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,7 @@ def proxy_url(cls, *, url, fd, cache_id=None):
6666

6767
@classmethod
6868
def proxy_mirror(cls, *, src: Path, cache_id=None):
69-
if not src.exists():
70-
return None
71-
elif cls.get_location() is None:
69+
if cls.get_location() is None: # Caching is disabled
7270
return src
7371

7472
if cache_id is None:
@@ -81,6 +79,9 @@ def proxy_mirror(cls, *, src: Path, cache_id=None):
8179
logger.debug(f"Retrieving mirror file path {cache_id} from cache")
8280
return cache_pth
8381

82+
if not src.exists():
83+
return src
84+
8485
try:
8586
shutil.copyfile(src, cache_pth, follow_symlinks=True)
8687
return cache_pth
@@ -90,18 +91,19 @@ def proxy_mirror(cls, *, src: Path, cache_id=None):
9091

9192
@classmethod
9293
def proxy_mirror_json(cls, *, src: Path):
93-
if not src.exists():
94-
return src
95-
elif cls.get_location() is None:
94+
if cls.get_location() is None: # Caching is disabled
9695
return src
9796

9897
cache_id = f"mirrorjson_{src.name}"
99-
cache_path = cls.get_location()/cache_id
98+
cache_path = cls.get_location() / cache_id
10099

101100
if cache_path.exists():
102101
logger.debug(f"Retrieving package mirror JSON {cache_id} from cache")
103102
return cache_path
104103

104+
if not src.exists():
105+
return src
106+
105107
try:
106108
shutil.copyfile(src=src, dst=cache_path, follow_symlinks=True)
107109
return cache_path

aura/config.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import tqdm
1717
import pkg_resources
1818
import ruamel.yaml
19-
from ruamel.yaml import YAML, composer
19+
from ruamel.yaml import composer
2020
try:
2121
import rapidjson as json
2222
except ImportError:
@@ -33,8 +33,8 @@
3333
# This is used to trigger breakpoint during AST traversing of specific lines
3434
DEBUG_LINES = set()
3535
DEFAULT_AST_STAGES = ("convert", "rewrite", "ast_pattern_matching", "taint_analysis", "readonly")
36-
AST_PATTERNS_CACHE = None
37-
PROGRESSBAR_DISABLED = ("AURA_NO_PROGRESS" in os.environ)
36+
AST_PATTERNS_CACHE: Optional[tuple] = None
37+
PROGRESSBAR_DISABLED: bool = ("AURA_NO_PROGRESS" in os.environ)
3838

3939
DEFAULT_CFG_PATH = "aura.data.aura_config.yaml"
4040
DEFAULT_SIGNATURE_PATH = "aura.data.signatures.yaml"
@@ -185,7 +185,7 @@ def get_file_content(location: str, base_path: Optional[str]=None) -> str:
185185
return fd.read()
186186

187187

188-
def parse_config(pth, default_pth):
188+
def parse_config(pth, default_pth) -> dict:
189189
logger.debug(f"Aura configuration located at {pth}")
190190

191191
content = get_file_content(pth)
@@ -282,7 +282,7 @@ def get_ast_stages() -> typing.Tuple[str,...]:
282282
return [x for x in cfg_value if x]
283283

284284

285-
def get_ast_patterns():
285+
def get_ast_patterns() -> tuple:
286286
global AST_PATTERNS_CACHE
287287
from .pattern_matching import ASTPattern
288288

@@ -295,7 +295,5 @@ def get_ast_patterns():
295295
return AST_PATTERNS_CACHE
296296

297297

298-
299-
300298
CFG_PATH = find_configuration()
301299
load_config()

aura/mirror.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .config import CFG
1313

1414

15-
class LocalMirror(object):
15+
class LocalMirror:
1616
@classmethod
1717
def get_mirror_path(cls) -> Path:
1818
env_path = os.environ.get('AURA_MIRROR_PATH', None)
@@ -24,23 +24,20 @@ def get_mirror_path(cls) -> Path:
2424
def list_packages(self) -> typing.Generator[Path, None, None]:
2525
yield from (self.get_mirror_path() / "json").iterdir()
2626

27-
def get_json(self, package_name):
28-
if package_name is None:
29-
raise NoSuchPackage(f"Could not find package '{package_name}' json at the mirror location")
30-
27+
def get_json(self, package_name) -> dict:
28+
assert package_name
3129
json_path = self.get_mirror_path() / "json" / package_name
30+
target = cache.Cache.proxy_mirror_json(src=json_path)
3231

33-
if not json_path.is_file():
32+
if not target.is_file():
3433
json_path = self.get_mirror_path() / "json" / canonicalize_name(package_name)
35-
if not json_path.exists():
34+
target = cache.Cache.proxy_mirror_json(src=json_path)
35+
if not target.exists():
3636
raise NoSuchPackage(package_name)
3737

38-
target = cache.Cache.proxy_mirror_json(src=json_path)
39-
40-
with open(target, "r") as fd:
41-
return json.loads(fd.read())
38+
return json.loads(target.read_text())
4239

43-
def url2local(self, url):
40+
def url2local(self, url: typing.Union[ParseResult, str]) -> Path:
4441
if not isinstance(url, ParseResult):
4542
url = urlparse(url)
4643

aura/pattern_matching.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def match_node(self, context: nodes.Context) -> bool:
4444
...
4545

4646
@property
47-
def message(self):
47+
def message(self) -> str:
4848
"""
4949
return a message identifying the match
5050
"""
@@ -121,7 +121,7 @@ def match_node(self, node: nodes.NodeType) -> bool:
121121
return False
122122

123123
@abstractmethod
124-
def match_string(self, value: str):
124+
def match_string(self, value: str) -> bool:
125125
...
126126

127127

files/prefetch_mirror.sh

100644100755
File mode changed.

tests/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from click.testing import CliRunner, Result
1414
import responses
15+
import tqdm
1516
import pytest
1617

1718

@@ -370,3 +371,10 @@ def reset_plugins():
370371

371372
ReadOnlyAnalyzer.hooks = []
372373
plugins.PLUGIN_CACHE = {}
374+
375+
376+
@pytest.fixture(scope="module")
377+
def mock_tqdm_log_write():
378+
379+
with mock.patch.object(tqdm.tqdm, "write") as m:
380+
yield m

tests/test_cache.py

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,67 @@
11
import json
2+
import uuid
23
from unittest import mock
34

5+
import pytest
6+
47
from aura import cache
8+
from aura import mirror
9+
from aura import exceptions
10+
11+
12+
13+
@mock.patch("aura.cache.Cache.get_location")
14+
def test_cache_mock_location(cache_mock, tmp_path):
15+
cache_mock.return_value = tmp_path
16+
assert cache.Cache.get_location() == tmp_path
17+
18+
19+
@mock.patch("aura.cache.Cache.get_location")
20+
@pytest.mark.parametrize("filename,content,cache_id,call", (
21+
("testjson_file", "json_content", "mirrorjson_testjson_file", cache.Cache.proxy_mirror_json),
22+
("testpkg_file", "pkg_content", "mirror_testpkg_file", cache.Cache.proxy_mirror)
23+
))
24+
def test_proxy_mirror_json(cache_mock, tmp_path, filename, content, cache_id, call):
25+
f = tmp_path / filename
26+
cache_path = tmp_path / "cache"
27+
cache_path.mkdir()
28+
cache_file = cache_path/cache_id
29+
cache_mock.return_value = cache_path
30+
31+
assert f.exists() is False
32+
out = call(src=f)
33+
assert out == f
34+
assert cache_file.exists() is False
35+
assert len(list(cache_path.iterdir())) == 0
36+
37+
f.write_text(content)
38+
assert f.exists() is True
39+
out = call(src=f)
40+
assert out != f
41+
assert out == cache_file
42+
assert len(list(cache_path.iterdir())) == 1
43+
assert out.read_text() == content
544

45+
# Make sure the cache does not attempt to do any kind of file access if the cache entry exists
46+
m = mock.MagicMock(spec_set=("name",), side_effect=ValueError("Call prohibited"))
47+
m.name = filename
48+
out = call(src=m)
49+
assert out == cache_file
650

7-
def test_mirror_cache(fixtures, simulate_mirror, tmp_path):
51+
# Original path should be returned if cache is disabled
52+
cache_mock.return_value = None
53+
out = call(src=f)
54+
assert out == f
55+
56+
57+
@mock.patch("aura.cache.Cache.get_location")
58+
def test_mirror_cache(cache_mock, fixtures, simulate_mirror, tmp_path):
859
cache_content = list(tmp_path.iterdir())
960
assert len(cache_content) == 0
1061

11-
with mock.patch.object(cache.Cache, 'get_location', return_value=tmp_path) as m:
12-
assert cache.Cache.get_location() == tmp_path
13-
out = fixtures.get_cli_output(['scan', '--download-only', 'mirror://wheel', '-f', 'json'])
62+
cache_mock.return_value = tmp_path
63+
assert cache.Cache.get_location() == tmp_path
64+
out = fixtures.get_cli_output(['scan', '--download-only', 'mirror://wheel', '-f', 'json'])
1465

1566
parsed_output = json.loads(out.stdout)
1667
assert len(parsed_output["detections"]) == 0
@@ -19,3 +70,31 @@ def test_mirror_cache(fixtures, simulate_mirror, tmp_path):
1970
assert len(cache_content) > 0
2071
assert "mirror_wheel-0.34.2.tar.gz" in cache_content, cache_content
2172
assert "mirror_wheel-0.34.2-py2.py3-none-any.whl" in cache_content
73+
74+
75+
@mock.patch("aura.cache.Cache.get_location")
76+
@mock.patch("aura.mirror.LocalMirror.get_mirror_path")
77+
def test_mirror_cache_no_remote_access(mirror_mock, cache_mock, fixtures, tmp_path):
78+
"""
79+
Test that if the content is fully cached, the mirror uri handler does not attempt to access the mirror but rather retrieves **all** content from cache only
80+
This is mainly to test correctness of prefetching the data for global PyPI scan to ensure no further network calls are made
81+
"""
82+
pkg = str(uuid.uuid4())
83+
pkg_content = {"id": pkg}
84+
mirror_path = tmp_path / "mirror"
85+
cache_path = tmp_path / "cache"
86+
cache_path.mkdir()
87+
cache_mock.return_value = cache_path
88+
mirror_mock.return_value = mirror_path
89+
m = mirror.LocalMirror()
90+
91+
assert cache.Cache.get_location() == cache_path
92+
assert m.get_mirror_path() == mirror_path
93+
assert mirror_path.exists() == False
94+
95+
with pytest.raises(exceptions.NoSuchPackage):
96+
m.get_json(pkg)
97+
98+
(cache_path/f"mirrorjson_{pkg}").write_text(json.dumps(pkg_content))
99+
out = m.get_json(pkg)
100+
assert out == pkg_content

0 commit comments

Comments
 (0)