Skip to content

Commit 8daed3f

Browse files
authored
Retry downloads from zenodo when failing because of error 423 (#30)
* Retry downloads from zenodo when failing because of error 423 * Lint * more linting
1 parent d4bd3b3 commit 8daed3f

2 files changed

Lines changed: 26 additions & 14 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ dependencies = [
1212
"metatensor-torch >=0.7.6,<0.9",
1313
"metatomic-torch >=0.1.2,<0.2",
1414
"vesin",
15+
"requests",
1516
]
1617

1718
readme = "README.md"

src/shiftml/ase/calculator.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import logging
22
import os
3-
import urllib.request
43

54
import numpy as np
5+
import requests
66
from metatomic.torch import ModelOutput
77
from metatomic.torch.ase_calculator import MetatomicCalculator
88
from platformdirs import user_cache_path
9+
from requests.adapters import HTTPAdapter
10+
from urllib3.util.retry import Retry
911

1012
from shiftml.utils.tensorial import T_sym_np_inv, symmetrize
1113

@@ -57,6 +59,25 @@ def is_fitted_on(atoms, fitted_species):
5759
)
5860

5961

62+
def download_with_retry(url, destination):
63+
"""Helper function to download data with retries on errors."""
64+
65+
# Retry strategy: wait 1s, 2s, 4s, 8s, 16s on 429/5xx errors
66+
retry_strategy = Retry(
67+
total=5, backoff_factor=1, status_forcelist=[429, 500, 502, 503, 504]
68+
)
69+
session = requests.Session()
70+
session.mount("https://", HTTPAdapter(max_retries=retry_strategy))
71+
72+
# Fetch with automatic retry and error raising
73+
response = session.get(url, stream=True)
74+
response.raise_for_status()
75+
76+
with open(destination, "wb") as file:
77+
for chunk in response.iter_content(chunk_size=8192):
78+
file.write(chunk)
79+
80+
6081
def ShiftML(model_version, force_download=False, device=None):
6182
"""
6283
Initialize the ShiftML calculator
@@ -247,24 +268,14 @@ def __init__(self, model_version, force_download=False, device=None):
247268
download = True
248269

249270
if download:
250-
urllib.request.urlretrieve(url, model_file)
271+
download_with_retry(url, model_file)
251272
logging.info(
252273
"Downloaded {} and saved to {}".format(model_version, cachedir)
253274
)
254275

255-
except urllib.error.URLError as e:
276+
except requests.exceptions.RequestException as e:
256277
logging.error(
257-
"Failed to download {} from {}. URL Error: {}".format(
258-
model_version, url, e.reason
259-
)
260-
)
261-
raise e
262-
except urllib.error.HTTPError as e:
263-
logging.error(
264-
"Failed to download {} from {}.\
265-
HTTP Error: {} - {}".format(
266-
model_version, url, e.code, e.reason
267-
)
278+
"Failed to download {} from {}. Error: {}".format(model_version, url, e)
268279
)
269280
raise e
270281
except Exception as e:

0 commit comments

Comments
 (0)