diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a7f79d1..8efdd57 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,7 +27,7 @@ jobs: run: | python -m pip install --upgrade pip if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - pip install flake8 pytest 'sxs==2025.0.9' romspline pycbc + pip install flake8 pytest 'sxs==2025.0.9' romspline pycbc requests bs4 #- name: Lint with flake8 # run: | # stop the build if there are Python syntax errors or undefined names diff --git a/PyART/catalogs/gra.py b/PyART/catalogs/gra.py index ab037e1..a7f1f13 100644 --- a/PyART/catalogs/gra.py +++ b/PyART/catalogs/gra.py @@ -2,8 +2,23 @@ import os import h5py from ..waveform import Waveform -import glob as glob import json +import logging +import re +import time + +# libraries for downloading +try: + import requests + from bs4 import BeautifulSoup + from urllib.parse import urljoin + from requests.adapters import HTTPAdapter + from urllib3.util.retry import Retry +except ImportError as e: + raise ImportError( + "To use the GRA catalog, please install the required " + "dependencies: requests, beautifulsoup4, urllib3" + ) from e class Waveform_GRA(Waveform): @@ -17,18 +32,27 @@ class Waveform_GRA(Waveform): def __init__( self, - path, + ID="0001", + path="../dat/GRA", ellmax=8, ext="ext", + res="128", r_ext=None, cut_N=None, cut_U=None, - mtdt_path=None, - rescale=False, + nu_rescale=False, modes=[(2, 2)], + download=False, + downloads=["hlm", "metadata"], ): super().__init__() + # Normalize ID to a 4-digit zero-padded string for consistency + if isinstance(ID, int): + ID = f"{ID:04d}" + elif isinstance(ID, str) and ID.isdigit() and len(ID) < 4: + ID = ID.zfill(4) + self.ID = ID self.path = path self.cut_N = cut_N self.cut_U = cut_U @@ -37,19 +61,81 @@ def __init__( self.extrap = ext self.domain = "Time" self.r_ext = r_ext - self.rescale = rescale + self.nu_rescale = nu_rescale + self.res = res # comment out the following for the moment - self.load_metadata(mtdt_path) + + if download: + self.download_simulation(ID=ID, path=path, downloads=downloads, res=res) + + self.load_metadata() self.load_hlm(extrap=ext, ellmax=ellmax, r_ext=r_ext) pass - def load_metadata(self, path): + def download_simulation( + self, + ID="0001", + path=None, + downloads=["hlm", "metadata"], + res=None, + ): """ - Load the metadata, if path is None assume - that they are in the same dir as the .h5 files + Automatically download and unpack a GRAthena++ + simulation from scholarsphere. """ + if path is None: path = self.path + + session = make_session() + + logging.info("Fetching catalog...") + id_map = get_id_to_item_url(session) + + if ID not in id_map: + raise RuntimeError(f"ID {ID} not found in catalog") + + item_url = id_map[ID] + + soup = get_item_soup(session, item_url) + + if "hlm" in downloads: + logging.info("Downloading hlm data...") + if res is None: + res = "128" + self.res = res + logging.warning("No resolution specified, defaulting to res=128") + + filename, tar_url = find_tar_for_resolution(soup, res) + logging.info(f"Found .tar: {filename}") + logging.info(f"Downloading from: {tar_url}") + download_safe(session, tar_url, filename) + # untar, execute via os.system for the moment + extract_path = os.path.join(path, f"GRA_BHBH_{ID}") + os.makedirs(extract_path, exist_ok=True) + logging.info(f"Extracting to: {extract_path}") + os.system(f"tar -xf {filename} -C {extract_path}") + os.remove(filename) + + if "metadata" in downloads: + logging.info("Downloading metadata...") + filename, meta_url = find_metadata_file(soup) + logging.info(f"Found metadata file: {filename}") + logging.info(f"Downloading from: {meta_url}") + download_safe(session, meta_url, filename) + # move to correct location + extract_path = os.path.join(path, f"GRA_BHBH_{ID}", "metadata.json") + os.makedirs(os.path.dirname(extract_path), exist_ok=True) + os.rename(filename, extract_path) + + # Be polite to the server + time.sleep(3) + + def load_metadata(self): + """ + Load the metadata from the json file and store it in self.metadata + """ + path = os.path.join(self.path, f"GRA_BHBH_{self.ID}", "metadata.json") ometa = json.load(open(path, "r")) m1 = float(ometa["initial-mass1"]) @@ -124,11 +210,23 @@ def load_hlm(self, extrap="ext", ellmax=None, load_m0=False, r_ext=None): r_ext = "100.00" if extrap == "ext": - h5_file = os.path.join(self.path, "rh_Asymptotic_GeometricUnits.h5") + h5_file = os.path.join( + self.path, + f"GRA_BHBH_{self.ID}", + self.res, + "rh_Asymptotic_GeometricUnits.h5", + ) elif extrap == "CCE": - h5_file = os.path.join(self.path, "rh_CCE_GeometricUnits.h5") + h5_file = os.path.join( + self.path, f"GRA_BHBH_{self.ID}", self.res, "rh_CCE_GeometricUnits.h5" + ) elif extrap == "finite": - h5_file = os.path.join(self.path, "rh_FiniteRadii_GeometricUnits.h5") + h5_file = os.path.join( + self.path, + f"GRA_BHBH_{self.ID}", + self.res, + "rh_FiniteRadii_GeometricUnits.h5", + ) else: raise ValueError('extrap should be either "ext", "CCE" or "finite"') @@ -171,7 +269,7 @@ def load_hlm(self, extrap="ext", ellmax=None, load_m0=False, r_ext=None): mode = "Y_l" + str(l) + "_m" + str(m) + ".dat" hlm = nr[r_ext][mode] h = hlm[:, 1] + 1j * hlm[:, 2] - if self.rescale: + if self.nu_rescale: h /= self.metadata["nu"] # amp and phase Alm = abs(h)[self.cut_N :] @@ -222,13 +320,14 @@ def get_indices_dict(self): def load_psi4lm( self, - path=None, - fname=None, ellmax=None, r_ext=None, extrap="ext", load_m0=False, ): + """ + Load the data from the h5 file, but for psi4 instead of h. + """ if ellmax == None: ellmax = self.ellmax @@ -236,11 +335,26 @@ def load_psi4lm( r_ext = "100.00" if extrap == "ext": - h5_file = os.path.join(self.path, "rPsi4_Asymptotic_GeometricUnits.h5") + h5_file = os.path.join( + self.path, + f"GRA_BHBH_{self.ID}", + self.res, + "rPsi4_Asymptotic_GeometricUnits.h5", + ) elif extrap == "CCE": - h5_file = os.path.join(self.path, "rPsi4_CCE_GeometricUnits.h5") + h5_file = os.path.join( + self.path, + f"GRA_BHBH_{self.ID}", + self.res, + "rPsi4_CCE_GeometricUnits.h5", + ) elif extrap == "finite": - h5_file = os.path.join(self.path, "rPsi4_FiniteRadii_GeometricUnits.h5") + h5_file = os.path.join( + self.path, + f"GRA_BHBH_{self.ID}", + self.res, + "rPsi4_FiniteRadii_GeometricUnits.h5", + ) else: raise ValueError('extrap should be either "ext", "CCE" or "finite"') @@ -282,7 +396,7 @@ def load_psi4lm( mode = "Y_l" + str(l) + "_m" + str(m) + ".dat" psi4lm = nr[r_ext][mode] psi4 = psi4lm[:, 1] + 1j * psi4lm[:, 2] - if self.rescale: + if self.nu_rescale: psi4 /= self.metadata["nu"] Alm = abs(psi4)[self.cut_N :] plm = -np.unwrap(np.angle(psi4))[self.cut_N :] @@ -297,3 +411,146 @@ def load_psi4lm( self._psi4lm = dict_psi4lm pass + + +# ---------------------------------------------------------------------- +# Functions needed to download data from GRAthena++ +# ---------------------------------------------------------------------- + +CATALOG_URL = ( + "https://scholarsphere.psu.edu/resources/610744ac-80b9-4689-8119-320dfd2e2b9a" +) +BASE_URL = "https://scholarsphere.psu.edu" + + +def make_session(): + session = requests.Session() + + retries = Retry( + total=5, + backoff_factor=1.5, + status_forcelist=[429, 500, 502, 503, 504], + allowed_methods=["GET"], + ) + + adapter = HTTPAdapter(max_retries=retries) + session.mount("https://", adapter) + session.mount("http://", adapter) + + session.headers.update( + { + "User-Agent": ( + "Mozilla/5.0 (X11; Linux x86_64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/121.0 Safari/537.36" + ), + "Accept": "*/*", + "Accept-Encoding": "identity", # avoids chunked/gzip resets + "Connection": "keep-alive", + "Referer": "https://scholarsphere.psu.edu/", + } + ) + + return session + + +def get_id_to_item_url(session): + r = session.get(CATALOG_URL, timeout=30) + r.raise_for_status() + soup = BeautifulSoup(r.text, "html.parser") + + id_map = {} + + for a in soup.find_all("a", href=True): + text = a.get_text(strip=True) + m = re.search(r"GRAthena:BHBH:(\d{4})", text) + if m: + id_map[m.group(1)] = urljoin(BASE_URL, a["href"]) + + if not id_map: + raise RuntimeError("No GRAthena IDs found on catalog page") + + return id_map + + +def get_item_soup(session, item_url): + r = session.get(item_url, timeout=30) + r.raise_for_status() + return BeautifulSoup(r.text, "html.parser") + + +def find_tar_for_resolution(item_soup, resolution): + resolution = resolution.lower() + + for a in item_soup.find_all("a", href=True): + href = a["href"].lower() + text = a.get_text(strip=True).lower() + if ( + "/downloads/" in href + and text.endswith(".tar") + and resolution in (href + text) + ): + filename = os.path.basename(href) + return filename, urljoin(BASE_URL, a["href"]) + + raise RuntimeError(f"No .tar found for resolution '{resolution}'") + + +def find_metadata_file(item_soup): + for a in item_soup.find_all("a", href=True): + href = a["href"].lower() + text = a.get_text(strip=True).lower() + if "/downloads/" in href and text.endswith(".json"): + filename = os.path.basename(href) + return filename, urljoin(BASE_URL, a["href"]) + + raise RuntimeError(f"No metadata.json file found") + + +def download_safe(session, url, filename, chunk_size=1024 * 1024): + tmp_file = filename + ".part" + downloaded = 0 + + if os.path.exists(tmp_file): + downloaded = os.path.getsize(tmp_file) + logging.info(f"Resuming download from byte {downloaded}") + + headers = {} + if downloaded > 0: + headers["Range"] = f"bytes={downloaded}-" + + with session.get(url, stream=True, headers=headers, timeout=60) as r: + r.raise_for_status() + + # Decide whether we can safely resume or must restart from scratch. + resume_supported = False + if downloaded > 0: + if r.status_code == 206: + content_range = r.headers.get("Content-Range", "") + # Expect the content range to start at our downloaded offset. + expected = f"bytes {downloaded}-" + if content_range.startswith(expected) or expected in content_range: + resume_supported = True + else: + logging.info( + "Server did not honor Range header (status %s); " + "restarting full download", + r.status_code, + ) + + if not resume_supported: + # If we had a partial file, overwrite it rather than append, to avoid + # corrupting the file when the server sends the full content. + if downloaded > 0: + logging.info("Discarding existing partial download and restarting") + downloaded = 0 + mode = "wb" + else: + mode = "ab" + with open(tmp_file, mode) as f: + for chunk in r.iter_content(chunk_size=chunk_size): + if chunk: + f.write(chunk) + + os.rename(tmp_file, filename) + logging.info(f"Download completed") diff --git a/tests/test_gra.py b/tests/test_gra.py new file mode 100644 index 0000000..60363aa --- /dev/null +++ b/tests/test_gra.py @@ -0,0 +1,43 @@ +""" +Tests for the GRA catalog. +""" + +from PyART.catalogs import gra +import os + +mode_keys = ["A", "p", "real", "imag", "z"] + + +def test_gra(): + """ + Test the GRA download function. + """ + wf = gra.Waveform_GRA( + ID="0001", + path="./", + download=True, + res="128", + downloads=["hlm", "metadata"], + ext="CCE", + ) + # check attributes + assert wf.ID == "0001" + + # check that the files were downloaded + assert os.path.exists("GRA_BHBH_0001") + assert os.path.exists(f"GRA_BHBH_0001/metadata.json") + assert os.path.exists(f"GRA_BHBH_0001/128/rh_CCE_GeometricUnits.h5") + # check that the modes loaded make sense + for mode in wf.hlm.keys(): + + # check ell, emm + assert mode[0] >= abs(mode[1]) + # check keys + for key in mode_keys: + assert key in wf.hlm[mode].keys() + # check length + assert len(wf.hlm[mode]["A"]) == len(wf.u) + + +if __name__ == "__main__": + test_gra()