|
1 | 1 | import logging |
2 | 2 | import os |
3 | | -import urllib.request |
4 | 3 |
|
5 | 4 | import numpy as np |
| 5 | +import requests |
6 | 6 | from metatomic.torch import ModelOutput |
7 | 7 | from metatomic.torch.ase_calculator import MetatomicCalculator |
8 | 8 | from platformdirs import user_cache_path |
| 9 | +from requests.adapters import HTTPAdapter |
| 10 | +from urllib3.util.retry import Retry |
9 | 11 |
|
10 | 12 | from shiftml.utils.tensorial import T_sym_np_inv, symmetrize |
11 | 13 |
|
@@ -57,6 +59,25 @@ def is_fitted_on(atoms, fitted_species): |
57 | 59 | ) |
58 | 60 |
|
59 | 61 |
|
| 62 | +def download_with_retry(url, destination): |
| 63 | + """Helper function to download data with retries on errors.""" |
| 64 | + |
| 65 | + # Retry strategy: wait 1s, 2s, 4s, 8s, 16s on 429/5xx errors |
| 66 | + retry_strategy = Retry( |
| 67 | + total=5, backoff_factor=1, status_forcelist=[429, 500, 502, 503, 504] |
| 68 | + ) |
| 69 | + session = requests.Session() |
| 70 | + session.mount("https://", HTTPAdapter(max_retries=retry_strategy)) |
| 71 | + |
| 72 | + # Fetch with automatic retry and error raising |
| 73 | + response = session.get(url, stream=True) |
| 74 | + response.raise_for_status() |
| 75 | + |
| 76 | + with open(destination, "wb") as file: |
| 77 | + for chunk in response.iter_content(chunk_size=8192): |
| 78 | + file.write(chunk) |
| 79 | + |
| 80 | + |
60 | 81 | def ShiftML(model_version, force_download=False, device=None): |
61 | 82 | """ |
62 | 83 | Initialize the ShiftML calculator |
@@ -247,24 +268,14 @@ def __init__(self, model_version, force_download=False, device=None): |
247 | 268 | download = True |
248 | 269 |
|
249 | 270 | if download: |
250 | | - urllib.request.urlretrieve(url, model_file) |
| 271 | + download_with_retry(url, model_file) |
251 | 272 | logging.info( |
252 | 273 | "Downloaded {} and saved to {}".format(model_version, cachedir) |
253 | 274 | ) |
254 | 275 |
|
255 | | - except urllib.error.URLError as e: |
| 276 | + except requests.exceptions.RequestException as e: |
256 | 277 | logging.error( |
257 | | - "Failed to download {} from {}. URL Error: {}".format( |
258 | | - model_version, url, e.reason |
259 | | - ) |
260 | | - ) |
261 | | - raise e |
262 | | - except urllib.error.HTTPError as e: |
263 | | - logging.error( |
264 | | - "Failed to download {} from {}.\ |
265 | | - HTTP Error: {} - {}".format( |
266 | | - model_version, url, e.code, e.reason |
267 | | - ) |
| 278 | + "Failed to download {} from {}. Error: {}".format(model_version, url, e) |
268 | 279 | ) |
269 | 280 | raise e |
270 | 281 | except Exception as e: |
|
0 commit comments