11from __future__ import annotations
22
3- import calendar
43import contextlib
5- import hashlib
64import io
5+ import logging
76import os
87import platform
9- import shutil
10- import tempfile
11- import time
128import typing as t
139
10+ import cachecontrol
1411import requests
12+ from cachecontrol .caches .file_cache import FileCache
1513
16- _LASTMOD_FMT = "%a, %d %b %Y %H:%M:%S %Z"
14+ log = logging . getLogger ( __name__ )
1715
1816
1917def _base_cache_dir () -> str | None :
@@ -42,26 +40,23 @@ def _resolve_cache_dir(dirname: str) -> str | None:
4240 return cache_dir
4341
4442
45- def _lastmod_from_response (response : requests .Response ) -> float :
46- try :
47- return calendar .timegm (
48- time .strptime (response .headers ["last-modified" ], _LASTMOD_FMT )
49- )
50- # OverflowError: time outside of platform-specific bounds
51- # ValueError: malformed/unparseable
52- # LookupError: no such header
53- except (OverflowError , ValueError , LookupError ):
54- return 0.0
55-
56-
5743def _get_request (
58- file_url : str , * , response_ok : t .Callable [[requests .Response ], bool ]
44+ session : requests .Session ,
45+ file_url : str ,
46+ * ,
47+ response_ok : t .Callable [[requests .Response ], bool ],
5948) -> requests .Response :
6049 num_retries = 2
6150 r : requests .Response | None = None
6251 for _attempt in range (num_retries + 1 ):
6352 try :
64- r = requests .get (file_url , stream = True )
53+ # On retries, bypass CacheControl's local cache to avoid
54+ # re-serving a cached bad response. Ideally we'd delete the
55+ # cache entry directly, but CacheControl doesn't expose a public
56+ # API for this (see https://github.com/psf/cachecontrol/issues/219).
57+ # The no-cache header forces revalidation with the origin server.
58+ headers = {"Cache-Control" : "no-cache" } if _attempt > 0 else {}
59+ r = session .get (file_url , headers = headers )
6560 except requests .RequestException as e :
6661 if _attempt == num_retries :
6762 raise FailedDownloadError ("encountered error during download" ) from e
@@ -74,48 +69,6 @@ def _get_request(
7469 )
7570
7671
77- def _atomic_write (dest : str , content : bytes ) -> None :
78- # download to a temp file and then move to the dest
79- # this makes the download safe if run in parallel (parallel runs
80- # won't create a new empty file for writing and cause failures)
81- fp = tempfile .NamedTemporaryFile (mode = "wb" , delete = False )
82- fp .write (content )
83- fp .close ()
84- shutil .copy (fp .name , dest )
85- os .remove (fp .name )
86-
87-
88- def _cache_hit (cachefile : str , response : requests .Response ) -> bool :
89- # no file? miss
90- if not os .path .exists (cachefile ):
91- return False
92-
93- # compare mtime on any cached file against the remote last-modified time
94- # it is considered a hit if the local file is at least as new as the remote file
95- local_mtime = os .path .getmtime (cachefile )
96- remote_mtime = _lastmod_from_response (response )
97- return local_mtime >= remote_mtime
98-
99-
100- def url_to_cache_filename (ref_url : str ) -> str :
101- """
102- Given a schema URL, convert it to a filename for caching in a cache dir.
103-
104- Rules are as follows:
105- - the base filename is an sha256 hash of the URL
106- - if the filename ends in an extension (.json, .yaml, etc) that extension
107- is appended to the hash
108-
109- Preserving file extensions preserves the extension-based logic used for parsing, and
110- it also helps a local editor (browsing the cache) identify filetypes.
111- """
112- filename = hashlib .sha256 (ref_url .encode ()).hexdigest ()
113- if "." in (last_part := ref_url .rpartition ("/" )[- 1 ]):
114- _ , _ , extension = last_part .rpartition ("." )
115- filename = f"{ filename } .{ extension } "
116- return filename
117-
118-
class FailedDownloadError(Exception):
    """Raised when an HTTP download could not be completed.

    `_get_request` raises this after exhausting its retries, either because
    every attempt hit a `requests.RequestException` or because no attempt
    produced a response accepted by the caller's `response_ok` check.
    """
12174
@@ -124,59 +77,42 @@ class CacheDownloader:
def __init__(self, cache_dir: str, *, disable_cache: bool = False) -> None:
    """Set up a downloader whose cache lives under `cache_dir`.

    :param cache_dir: cache subdirectory name, resolved against the
        platform base cache dir by `_resolve_cache_dir` (may be None)
    :param disable_cache: when True, no on-disk caching is performed
    """
    self._cache_dir = _resolve_cache_dir(cache_dir)
    self._disable_cache = disable_cache
    # built lazily by the `_session` property on first use
    self._cached_session: requests.Session | None = None
81+
@property
def _session(self) -> requests.Session:
    """The HTTP session, constructed on first access and then memoized."""
    session = self._cached_session
    if session is None:
        session = self._build_session()
        self._cached_session = session
    return session
87+
def _build_session(self) -> requests.Session:
    """Create the session used for downloads.

    When a cache dir is available and caching is enabled, the plain
    `requests.Session` is wrapped with CacheControl backed by an on-disk
    `FileCache`; otherwise the bare session is returned.
    """
    session = requests.Session()
    # guard clause: bare session when caching can't or shouldn't be used
    if self._disable_cache or not self._cache_dir:
        log.debug("caching disabled")
        return session
    log.debug("using cache dir: %s", self._cache_dir)
    os.makedirs(self._cache_dir, exist_ok=True)
    return cachecontrol.CacheControl(session, cache=FileCache(self._cache_dir))
15499
@contextlib.contextmanager
def open(
    self,
    file_url: str,
    validate_response: t.Callable[[requests.Response], bool],
) -> t.Iterator[t.IO[bytes]]:
    """Download `file_url` (with retries and response validation) and
    yield its body as an in-memory binary stream."""
    content = _get_request(
        self._session, file_url, response_ok=validate_response
    ).content
    yield io.BytesIO(content)
171108
def bind(
    self,
    file_url: str,
    validation_callback: t.Callable[[bytes], t.Any] | None = None,
) -> BoundCacheDownloader:
    """Tie this downloader to a single URL, returning the bound wrapper.

    :param file_url: the URL the returned object will fetch
    :param validation_callback: optional callable applied to downloaded
        bytes; any exception it raises marks the response invalid
    """
    return BoundCacheDownloader(
        file_url,
        self,
        validation_callback=validation_callback,
    )
181117
182118
@@ -186,27 +122,28 @@ def __init__(
186122 file_url : str ,
187123 downloader : CacheDownloader ,
188124 * ,
189- filename : str | None = None ,
190125 validation_callback : t .Callable [[bytes ], t .Any ] | None = None ,
191126 ) -> None :
192127 self ._file_url = file_url
193- self ._filename = filename or url_to_cache_filename (file_url )
194128 self ._downloader = downloader
195129 self ._validation_callback = validation_callback
196130
@contextlib.contextmanager
def open(self) -> t.Iterator[t.IO[bytes]]:
    """Open the bound URL through the underlying downloader, using this
    binding's validation hook to vet the response."""
    downloader_cm = self._downloader.open(
        self._file_url, validate_response=self._validate_response
    )
    with downloader_cm as fp:
        yield fp
205138
206139 def _validate_response (self , response : requests .Response ) -> bool :
207140 if not self ._validation_callback :
208141 return True
209-
142+ # CacheControl sets from_cache=True on cache hits; skip re-validation.
143+ # Plain requests.Session (used when disable_cache=True) doesn't set this
144+ # attribute at all, so we use getattr with a default.
145+ if getattr (response , "from_cache" , False ):
146+ return True
210147 try :
211148 self ._validation_callback (response .content )
212149 return True