|
 
 from SARIAD.config import DATASETS_PATH
 
-def fetch_blob(path, link="", drive_file_id="", kaggle="", ext="zip"):
+def fetch_blob(path, link="", drive_file_id="", kaggle="", is_archive=True, ext="zip"):
     """
     Fetches the dataset blob from a direct link, Google Drive, or Kaggle,
     and extracts it directly to the specified path.
 
     Parameters:
-    - path: str, The full path to the directory where the extracted blob should reside.
-    - link: str, optional, direct HTTP(s) link to an archive.
-    - drive_file_id: str, optional, ID for Google Drive file (archive).
+    - path: str, The full path to the directory where the extracted blob should reside (if is_archive=True)
+      or the full path to the file itself (if is_archive=False).
+    - link: str, optional, direct HTTP(s) link to an archive or single file.
+    - drive_file_id: str, optional, ID for Google Drive file (archive or single file).
     - kaggle: str, optional, KaggleHub dataset slug.
-    - ext: str, archive type (zip, tar.gz, rar, tar), used for link and drive_file_id.
-      This parameter is ignored if 'kaggle' is provided.
+    - is_archive: bool, set to True if the fetched item is an archive that needs extraction.
+      Set to False for single files like .pth.
+    - ext: str, archive type (zip, tar.gz, rar, tar) or file extension (e.g., "pth"),
+      used for link and drive_file_id. This parameter is ignored if 'kaggle' is provided.
     """
-    if os.path.exists(path) and os.path.isdir(path) and len(os.listdir(path)) > 0:
-        print(f"Dataset found locally at: {path}")
-        return
+    if is_archive:
+        if os.path.exists(path) and os.path.isdir(path) and len(os.listdir(path)) > 0:
+            print(f"Dataset found locally at: {path}")
+            return
+    else:
+        if os.path.exists(path) and os.path.isfile(path):
+            print(f"File found locally at: {path}")
+            return
 
     print(f"Dataset not found locally at {path}. Downloading...")
-    os.makedirs(path, exist_ok=True)
+    if is_archive:
+        os.makedirs(path, exist_ok=True)
+    else:
+        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
 
     if link:
-        temp_archive_name = f"{os.path.basename(path)}_archive.{ext}"
-        temp_archive_path = os.path.join(os.path.dirname(path) or '.', temp_archive_name)
+        if is_archive:
+            temp_target_path = os.path.join(os.path.dirname(path) or '.', f"{os.path.basename(path)}_archive.{ext}")
+        else:
+            temp_target_path = path
 
         response = requests.get(link, stream=True)
         if response.status_code != 200:
             raise RuntimeError(f"Failed to download file from {link}: HTTP {response.status_code}")
 
         total_size_in_bytes = int(response.headers.get('content-length', 0))
         block_size = 8192
-        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, desc=f"Downloading {os.path.basename(temp_archive_path)}")
+        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True, desc=f"Downloading {os.path.basename(temp_target_path)}")
 
-        with open(temp_archive_path, 'wb') as f:
+        with open(temp_target_path, 'wb') as f:
             for chunk in response.iter_content(chunk_size=block_size):
                 progress_bar.update(len(chunk))
                 f.write(chunk)
         progress_bar.close()
 
-        print(f"Extracting the {ext} archive...")
-        _extract_archive(temp_archive_path, path, ext)
-        os.remove(temp_archive_path)
-        print(f"Downloaded and extracted to {path}.")
+        if is_archive:
+            print(f"Extracting the {ext} archive...")
+            _extract_archive(temp_target_path, path, ext)
+            os.remove(temp_target_path)
+            print(f"Downloaded and extracted to {path}.")
+        else:
+            print(f"Downloaded file to {path}.")
 
     elif drive_file_id:
-        temp_archive_name = f"{os.path.basename(path)}_archive.{ext}"
-        temp_archive_path = os.path.join(os.path.dirname(path) or '.', temp_archive_name)
+        if is_archive:
+            temp_target_path = os.path.join(os.path.dirname(path) or '.', f"{os.path.basename(path)}_archive.{ext}")
+        else:
+            temp_target_path = path  # For single files, download directly to the final path
 
         print(f"Downloading from Google Drive ID: {drive_file_id}")
-        gdown.download(f"https://drive.google.com/uc?id={drive_file_id}", temp_archive_path, quiet=False)
+        gdown.download(f"https://drive.google.com/uc?id={drive_file_id}", temp_target_path, quiet=False)
 
-        print(f"Extracting the {ext} archive...")
-        _extract_archive(temp_archive_path, path, ext)
-        os.remove(temp_archive_path)
-        print(f"Downloaded and extracted to {path}.")
+        if is_archive:
+            print(f"Extracting the {ext} archive...")
+            _extract_archive(temp_target_path, path, ext)
+            os.remove(temp_target_path)
+            print(f"Downloaded and extracted to {path}.")
+        else:
+            print(f"Downloaded file to {path}.")
 
     elif kaggle:
         downloaded_kaggle_path = kagglehub.dataset_download(kaggle)
         print(f"KaggleHub {kaggle} dataset downloaded to: {downloaded_kaggle_path}")
 
-        os.makedirs(path, exist_ok=True)
+        os.makedirs(path, exist_ok=True)  # Always treat kaggle as an archive/dataset for now
 
         for item in os.listdir(downloaded_kaggle_path):
             s = os.path.join(downloaded_kaggle_path, item)
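
Taken together, the fetch_blob changes above amount to two calling modes: the default archive mode and the new single-file mode. A minimal usage sketch follows; the module path, URLs, Drive ID, and local paths are hypothetical illustrations, not taken from this commit:

    from SARIAD.utils.blob import fetch_blob  # assumed module path

    # Archive mode (default): the archive is downloaded next to `path`,
    # extracted into it, and the temporary archive is removed.
    fetch_blob(
        path="/data/datasets/SIRST",           # hypothetical target directory
        link="https://example.com/sirst.zip",  # hypothetical archive URL
        ext="zip",
    )

    # Single-file mode: the file is downloaded straight to `path`, no extraction.
    # Note that `ext` is effectively unused here, since the file is saved
    # directly to its final path.
    fetch_blob(
        path="/data/weights/detector.pth",       # hypothetical checkpoint path
        drive_file_id="0B0hypotheticalDriveId",  # hypothetical Drive file ID
        is_archive=False,
    )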
@@ -132,25 +153,29 @@ def _extract_archive(archive_path, extract_to, ext):
         shutil.move(os.path.join(temp_extract_dir, item), extract_to)
     shutil.rmtree(temp_extract_dir)
 
-def fetch_dataset(dataset_name, datasets_dir=DATASETS_PATH, link="", drive_file_id="", kaggle="", ext="zip"):
+def fetch_dataset(dataset_name, datasets_dir=DATASETS_PATH, link="", drive_file_id="", kaggle="", is_archive=True, ext="zip"):
     """
     Fetches a dataset blob from a direct link, Google Drive, or Kaggle,
     maintaining backward compatibility with the original fetch_blob signature.
 
     Parameters:
-    - dataset_name: str, The name of the dataset. This will be the directory name inside datasets_dir.
+    - dataset_name: str, The name of the dataset. This will be the directory name inside datasets_dir
+      for archives, or the file name if is_archive is False.
     - datasets_dir: str, The root directory where datasets are stored.
-    - link: str, optional, direct HTTP(s) link to an archive.
-    - drive_file_id: str, optional, ID for Google Drive file (archive).
+    - link: str, optional, direct HTTP(s) link to an archive or file.
+    - drive_file_id: str, optional, ID for Google Drive file (archive or file).
     - kaggle: str, optional, KaggleHub dataset slug.
-    - ext: str, archive type (zip, tar.gz, rar, tar), used for link and drive_file_id.
-      This parameter is ignored if 'kaggle' is provided.
+    - ext: str, archive type (zip, tar.gz, rar, tar) or file extension (e.g., "pth"),
+      used for link and drive_file_id. This parameter is ignored if 'kaggle' is provided.
+    - is_archive: bool, set to True if the fetched item is an archive that needs extraction.
+      Set to False for single files like .pth.
     """
     full_dataset_path = os.path.join(datasets_dir, dataset_name)
    fetch_blob(
         path=full_dataset_path,
         link=link,
         drive_file_id=drive_file_id,
         kaggle=kaggle,
-        ext=ext
+        ext=ext,
+        is_archive=is_archive
     )
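
A corresponding sketch for the fetch_dataset wrapper, which resolves dataset_name against datasets_dir before delegating to fetch_blob. The module path, dataset names, Kaggle slug, and Drive ID below are hypothetical:

    from SARIAD.config import DATASETS_PATH
    from SARIAD.utils.blob import fetch_dataset  # assumed module path

    # Archive: downloaded and extracted into {DATASETS_PATH}/NUDT-SIRST/.
    fetch_dataset("NUDT-SIRST", drive_file_id="0B0hypotheticalDriveId", ext="zip")

    # KaggleHub dataset: materialized from the KaggleHub cache into
    # {DATASETS_PATH}/sirst-aug/ (kaggle is always treated as an archive).
    fetch_dataset("sirst-aug", kaggle="someuser/sirst-aug")

    # Single weights file: saved as {DATASETS_PATH}/pretrained.pth, no extraction.
    fetch_dataset("pretrained.pth",
                  link="https://example.com/pretrained.pth",
                  is_archive=False)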