77import requests
88import os
99import tempfile
10+ from huggingface_hub import HfApi , login , hf_hub_download
11+ import pkg_resources
1012
1113
1214def atomic_write (file : Path , content : bytes ) -> None :
@@ -53,6 +55,8 @@ class Dataset:
5355 """The time period of the dataset. This is used to automatically enter the values in the correct time period if the data type is `Dataset.ARRAYS`."""
5456 url : str = None
5557 """The URL to download the dataset from. This is used to download the dataset if it does not exist."""
58+ huggingface_url : str = None
59+ """The HuggingFace URL to download the dataset from. This is used to download the dataset if it does not exist."""
5660
5761 # Data formats
5862 TABLES = "tables"
@@ -306,15 +310,15 @@ def store_file(self, file_path: str):
306310 raise FileNotFoundError (f"File { file_path } does not exist." )
307311 shutil .move (file_path , self .file_path )
308312
309- def download (self , url : str = None ) -> None :
313+ def download (self , url : str = None , version : str = None ) -> None :
310314 """Downloads a file to the dataset's file path.
311315
312316 Args:
313317 url (str): The url to download.
314318 """
315319
316320 if url is None :
317- url = self .url
321+ url = self .huggingface_url or self . url
318322
319323 if "POLICYENGINE_GITHUB_MICRODATA_AUTH_TOKEN" not in os .environ :
320324 auth_headers = {}
@@ -345,6 +349,10 @@ def download(self, url: str = None) -> None:
345349 raise ValueError (
346350 f"File { file_path } not found in release { release_tag } of { org } /{ repo } ."
347351 )
352+ elif url .startswith ("hf://" ):
353+ owner_name , model_name = url .split ("/" )[2 :]
354+ self .download_from_huggingface (owner_name , model_name , version )
355+ return
348356 else :
349357 url = url
350358
@@ -363,6 +371,19 @@ def download(self, url: str = None) -> None:
363371
364372 atomic_write (self .file_path , response .content )
365373
def upload(self, url: str = None):
    """Uploads the dataset to a URL.

    Only Hugging Face URLs of the form ``hf://<owner>/<model>`` are acted
    on. Any other URL is ignored and nothing is uploaded.

    Args:
        url (str): The url to upload to. Defaults to
            ``self.huggingface_url``, falling back to ``self.url``.

    Raises:
        ValueError: If no URL was passed and none is configured on the
            dataset, or if an ``hf://`` URL does not consist of exactly an
            owner and a model name.
    """
    if url is None:
        url = self.huggingface_url or self.url

    # Fail with a clear message instead of an AttributeError on None.
    if url is None:
        raise ValueError(
            "No upload URL given and neither huggingface_url nor url is set."
        )

    hf_prefix = "hf://"
    if url.startswith(hf_prefix):
        parts = url[len(hf_prefix):].split("/")
        if len(parts) != 2 or not all(parts):
            raise ValueError(
                f"Invalid Hugging Face URL {url!r}: expected "
                "'hf://<owner>/<model>'."
            )
        owner_name, model_name = parts
        self.upload_to_huggingface(owner_name, model_name)
386+
366387 def remove (self ):
367388 """Removes the dataset from disk."""
368389 if self .exists :
@@ -414,3 +435,59 @@ def from_dataframe(dataframe: pd.DataFrame, time_period: str = None):
414435 )()
415436
416437 return dataset
438+
def upload_to_huggingface(self, owner_name: str, model_name: str):
    """Uploads the dataset file to a Hugging Face model repository.

    Before uploading, the installed versions of ``policyengine-uk-data`` and
    ``policyengine-uk`` are written into the HDF5 file's attributes, so the
    local file at ``self.file_path`` is modified in place.

    Args:
        owner_name (str): The owner name.
        model_name (str): The model name.
    """
    # SECURITY: never ship an access token in source. The token is read
    # from the environment only; when it is unset, huggingface_hub is left
    # to use whatever credentials are already stored locally.
    token = os.environ.get("HUGGING_FACE_TOKEN")
    api = HfApi(token=token)

    # Add the policyengine-uk-data version and policyengine-uk version to the h5 metadata.
    uk_data_version = get_package_version("policyengine-uk-data")
    uk_version = get_package_version("policyengine-uk")
    with h5py.File(self.file_path, "a") as f:
        f.attrs["policyengine-uk-data"] = uk_data_version
        f.attrs["policyengine-uk"] = uk_version

    api.upload_file(
        path_or_fileobj=self.file_path,
        path_in_repo=self.file_path.name,
        repo_id=f"{owner_name}/{model_name}",
        repo_type="model",
    )
465+
def download_from_huggingface(
    self, owner_name: str, model_name: str, version: str = None
):
    """Downloads the dataset file from a Hugging Face model repository.

    The file named ``self.file_path.name`` is fetched from the repository
    and placed in ``self.file_path.parent``, matching the name used by
    ``upload_to_huggingface``.

    Args:
        owner_name (str): The owner name.
        model_name (str): The model name.
        version (str): The revision to download, passed to
            ``hf_hub_download`` as ``revision``. ``None`` leaves the choice
            to huggingface_hub.
    """
    # SECURITY: never ship an access token in source. The token is read
    # from the environment only; when it is unset, huggingface_hub is left
    # to use whatever credentials are already stored locally.
    token = os.environ.get("HUGGING_FACE_TOKEN")

    # hf_hub_download has no `path` argument: the file is identified by
    # `filename` and the destination directory by `local_dir`.
    hf_hub_download(
        repo_id=f"{owner_name}/{model_name}",
        filename=self.file_path.name,
        repo_type="model",
        revision=version,
        local_dir=self.file_path.parent,
        token=token,
    )
486+
487+
def get_package_version(package_name: str) -> str:
    """Get the installed version of a package.

    Args:
        package_name (str): The distribution name, e.g. ``"policyengine-uk"``.

    Returns:
        str: The installed version string, or ``"not installed"`` if no
        distribution with that name is found.
    """
    # importlib.metadata is the stdlib replacement for the deprecated
    # pkg_resources API; imported locally so the block is self-contained.
    from importlib import metadata

    try:
        return metadata.version(package_name)
    except metadata.PackageNotFoundError:
        return "not installed"
0 commit comments