Skip to content

Commit 98145bb

Browse files
Adding support for search export (#437)
* initial version
* refactoring to reuse zip and download utilities
* using session, increasing timeout, clearer user message
* ruff cleanup; avoid type error; improve docs
* changing 404 error; fix tests
(pre-commit auto-format commits omitted)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 375504e commit 98145bb

File tree

8 files changed

+321
-30
lines changed

8 files changed

+321
-30
lines changed

roboflow/adapters/rfapi.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,53 @@ def get_version_export(
152152
return payload
153153

154154

155+
def start_search_export(
    api_key: str,
    workspace_url: str,
    query: str,
    format: str,
    session: requests.Session,
    dataset: Optional[str] = None,
    annotation_group: Optional[str] = None,
    name: Optional[str] = None,
) -> str:
    """Kick off an asynchronous search export job on the Roboflow API.

    Returns the export_id string used to poll for completion.

    Raises RoboflowError on non-202 responses.
    """
    endpoint = f"{API_URL}/{workspace_url}/search/export?api_key={api_key}"

    request_body: Dict[str, str] = {"query": query, "format": format}
    # Optional filters are only sent when explicitly provided by the caller.
    optional_fields = {"dataset": dataset, "annotationGroup": annotation_group, "name": name}
    for field, value in optional_fields.items():
        if value is not None:
            request_body[field] = value

    resp = session.post(endpoint, json=request_body)
    if resp.status_code != 202:
        raise RoboflowError(resp.text)

    # NOTE(review): the API appears to return the export id under the "link"
    # key of the 202 payload — callers use it as an id when polling.
    return resp.json()["link"]
186+
187+
188+
def get_search_export(api_key: str, workspace_url: str, export_id: str, session: requests.Session) -> dict:
    """Fetch the current status of a search export job.

    Returns dict with ``ready`` (bool) and ``link`` (str, present when ready).

    Raises RoboflowError on non-200 responses.
    """
    status_endpoint = f"{API_URL}/{workspace_url}/search/export/{export_id}?api_key={api_key}"
    resp = session.get(status_endpoint)
    if resp.status_code == 200:
        return resp.json()
    raise RoboflowError(resp.text)
200+
201+
155202
def upload_image(
156203
api_key,
157204
project_url,

roboflow/core/version.py

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import os
66
import sys
77
import time
8-
import zipfile
98
from typing import TYPE_CHECKING, Optional, Union
109

1110
import requests
@@ -32,7 +31,7 @@
3231
from roboflow.models.object_detection import ObjectDetectionModel
3332
from roboflow.models.semantic_segmentation import SemanticSegmentationModel
3433
from roboflow.util.annotations import amend_data_yaml
35-
from roboflow.util.general import write_line
34+
from roboflow.util.general import extract_zip, write_line
3635
from roboflow.util.model_processor import process
3736
from roboflow.util.versions import get_model_format, get_wrong_dependencies_versions, normalize_yolo_model_type
3837

@@ -239,7 +238,7 @@ def download(self, model_format=None, location=None, overwrite: bool = False):
239238
link = export_info["export"]["link"]
240239

241240
self.__download_zip(link, location, model_format)
242-
self.__extract_zip(location, model_format)
241+
extract_zip(location, desc=f"Extracting Dataset Version Zip to {location} in {model_format}:")
243242
self.__reformat_yaml(location, model_format) # TODO: is roboflow-python a place to be munging yaml files?
244243

245244
return Dataset(self.name, self.version, model_format, os.path.abspath(location))
@@ -577,30 +576,6 @@ def bar_progress(current, total, width=80):
577576
sys.stdout.write("\n")
578577
sys.stdout.flush()
579578

580-
def __extract_zip(self, location, format):
581-
"""
582-
Extracts the contents of a downloaded ZIP file and then deletes the zipped file.
583-
584-
Args:
585-
location (str): filepath of the data directory that contains the ZIP file
586-
format (str): the format identifier string
587-
588-
Raises:
589-
RuntimeError: If there is an error unzipping the file
590-
""" # noqa: E501 // docs
591-
desc = None if TQDM_DISABLE else f"Extracting Dataset Version Zip to {location} in {format}:"
592-
with zipfile.ZipFile(location + "/roboflow.zip", "r") as zip_ref:
593-
for member in tqdm(
594-
zip_ref.infolist(),
595-
desc=desc,
596-
):
597-
try:
598-
zip_ref.extract(member, location)
599-
except zipfile.error:
600-
raise RuntimeError("Error unzipping download")
601-
602-
os.remove(location + "/roboflow.zip")
603-
604579
def __get_download_location(self):
605580
"""
606581
Get the local path to save a downloaded dataset to
@@ -707,4 +682,4 @@ def __str__(self):
707682

708683

709684
def unwrap_version_id(version_id: str) -> str:
    """Return only the trailing version number from a possibly fully-qualified id.

    A fully-qualified id looks like ``workspace/project/3``; a bare id
    (no ``/``) is returned unchanged.
    """
    # str() guards the membership test against non-string inputs.
    if "/" in str(version_id):
        return version_id.rsplit("/", maxsplit=1)[-1]
    return version_id

roboflow/core/workspace.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,21 @@
55
import json
66
import os
77
import sys
8+
import time
89
from typing import Any, Dict, List, Optional
910

1011
import requests
1112
from PIL import Image
13+
from requests.exceptions import HTTPError
14+
from tqdm import tqdm
1215

1316
from roboflow.adapters import rfapi
1417
from roboflow.adapters.rfapi import AnnotationSaveError, ImageUploadError, RoboflowError
1518
from roboflow.config import API_URL, APP_URL, CLIP_FEATURIZE_URL, DEMO_KEYS
1619
from roboflow.core.project import Project
1720
from roboflow.util import folderparser
1821
from roboflow.util.active_learning_utils import check_box_size, clip_encode, count_comparisons
22+
from roboflow.util.general import extract_zip as _extract_zip
1923
from roboflow.util.image_utils import load_labelmap
2024
from roboflow.util.model_processor import process
2125
from roboflow.util.two_stage_utils import ocr_infer
@@ -662,6 +666,115 @@ def _upload_zip(
662666
except Exception as e:
663667
print(f"An error occured when uploading the model: {e}")
664668

669+
def search_export(
    self,
    query: str,
    format: str = "coco",
    location: Optional[str] = None,
    dataset: Optional[str] = None,
    annotation_group: Optional[str] = None,
    name: Optional[str] = None,
    extract_zip: bool = True,
) -> str:
    """Export search results as a downloaded dataset.

    Args:
        query: Search query string (e.g. ``"tag:annotate"`` or ``"class:apple"``).
        format: Annotation format for the export (default ``"coco"``).
        location: Local directory to save the exported dataset.
            Defaults to ``./search-export-{format}``.
        dataset: Limit export to a specific dataset (project) slug.
        annotation_group: Limit export to a specific annotation group.
        name: Optional name for the export.
        extract_zip: If True (default), extract the zip and remove it.
            If False, keep the zip file as-is.

    Returns:
        Absolute path to the extracted directory or the zip file.

    Raises:
        ValueError: If both *dataset* and *annotation_group* are provided.
        RoboflowError: On API errors, download failures, or export timeout.
    """
    if dataset is not None and annotation_group is not None:
        raise ValueError("dataset and annotation_group are mutually exclusive; provide only one")

    if location is None:
        location = f"./search-export-{format}"
    location = os.path.abspath(location)

    # 1. Start the export. A single Session is reused for the start,
    # poll, and download requests (connection pooling).
    session = requests.Session()
    export_id = rfapi.start_search_export(
        api_key=self.__api_key,
        workspace_url=self.url,
        query=query,
        format=format,
        dataset=dataset,
        annotation_group=annotation_group,
        name=name,
        session=session,
    )
    print(f"Export started (id={export_id}). Polling for completion...")

    # The api_key is deliberately redacted in the user-facing URL.
    status_url = f"{API_URL}/{self.url}/search/export/{export_id}?api_key=YOUR_API_KEY"
    print(f"If this takes too long, you can check the export status at: {status_url}")

    # 2. Poll until ready. NOTE: `elapsed` counts only sleep time, so the
    # effective wall-clock timeout is slightly longer than `timeout`.
    timeout = 1800
    poll_interval = 5
    elapsed = 0

    while elapsed < timeout:
        status = rfapi.get_search_export(
            api_key=self.__api_key,
            workspace_url=self.url,
            export_id=export_id,
            session=session,
        )
        if status.get("ready"):
            break
        time.sleep(poll_interval)
        elapsed += poll_interval
    else:
        # while/else: only reached when the loop exhausts without `break`.
        raise RoboflowError(f"Search export timed out after {timeout}s")

    download_url = status["link"]

    # 3. Download the zip.
    os.makedirs(location, exist_ok=True)

    zip_path = os.path.join(location, "roboflow.zip")
    # Context-manage the streamed response so the connection is released
    # even if the download loop raises.
    with session.get(download_url, stream=True) as response:
        try:
            response.raise_for_status()
        except HTTPError as e:
            raise RoboflowError(f"Failed to download search export: {e}") from e

        # content-length may be absent (chunked transfer); tqdm then shows
        # an unbounded progress bar.
        total_length = response.headers.get("content-length")
        try:
            total_kib = int(total_length) // 1024 + 1 if total_length is not None else None
        except (TypeError, ValueError):
            total_kib = None
        with open(zip_path, "wb") as f:
            for chunk in tqdm(
                response.iter_content(chunk_size=1024),
                desc=f"Downloading search export to {location}",
                total=total_kib,
            ):
                if chunk:  # skip keep-alive chunks
                    f.write(chunk)

    if extract_zip:
        _extract_zip(location, desc=f"Extracting search export to {location}")
        print(f"Search export extracted to {location}")
        return location
    else:
        print(f"Search export saved to {zip_path}")
        return zip_path
777+
665778
def __str__(self):
666779
projects = self.projects()
667780
json_value = {"name": self.name, "url": self.url, "projects": projects}

roboflow/roboflowpy.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,21 @@ def infer(args):
202202
print(group)
203203

204204

205+
def search_export(args):
    """CLI handler: run a workspace search export and print the resulting path."""
    rf = roboflow.Roboflow()
    target_workspace = rf.workspace(args.workspace)
    export_path = target_workspace.search_export(
        query=args.query,
        format=args.format,
        location=args.location,
        dataset=args.dataset,
        annotation_group=args.annotation_group,
        name=args.name,
        extract_zip=not args.no_extract,
    )
    print(export_path)
218+
219+
205220
def _argparser():
206221
parser = argparse.ArgumentParser(description="Welcome to the roboflow CLI: computer vision at your fingertips 🪄")
207222
subparsers = parser.add_subparsers(title="subcommands")
@@ -218,6 +233,7 @@ def _argparser():
218233
_add_run_video_inference_api_parser(subparsers)
219234
deployment.add_deployment_parser(subparsers)
220235
_add_whoami_parser(subparsers)
236+
_add_search_export_parser(subparsers)
221237

222238
parser.add_argument("-v", "--version", help="show version info", action="store_true")
223239
parser.set_defaults(func=show_version)
@@ -594,6 +610,19 @@ def _add_get_workspace_project_version_parser(subparsers):
594610
workspace_project_version_parser.set_defaults(func=get_workspace_project_version)
595611

596612

613+
def _add_search_export_parser(subparsers):
    """Register the ``search-export`` subcommand and its arguments."""
    export_parser = subparsers.add_parser("search-export", help="Export search results as a dataset")
    # Positional query plus short-flag options, mirroring the Workspace.search_export signature.
    export_parser.add_argument("query", help="Search query (e.g. 'tag:annotate' or '*')")
    export_parser.add_argument("-f", dest="format", default="coco", help="Annotation format (default: coco)")
    export_parser.add_argument("-w", dest="workspace", help="Workspace url or id (uses default workspace if not specified)")
    export_parser.add_argument("-l", dest="location", help="Local directory to save the export")
    export_parser.add_argument("-d", dest="dataset", help="Limit export to a specific dataset (project slug)")
    export_parser.add_argument("-g", dest="annotation_group", help="Limit export to a specific annotation group")
    export_parser.add_argument("-n", dest="name", help="Optional name for the export")
    export_parser.add_argument("--no-extract", dest="no_extract", action="store_true", help="Skip extraction, keep the zip file")
    export_parser.set_defaults(func=search_export)
624+
625+
597626
def _add_login_parser(subparsers):
598627
login_parser = subparsers.add_parser("login", help="Log in to Roboflow")
599628
login_parser.add_argument(

roboflow/util/general.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
1+
import os
12
import sys
23
import time
4+
import zipfile
35
from random import random
46

7+
from tqdm import tqdm
8+
9+
from roboflow.config import TQDM_DISABLE
10+
511

612
def write_line(line):
713
sys.stdout.write("\r" + line)
@@ -40,3 +46,22 @@ def __call__(self, func, *args, **kwargs):
4046
self.retries += 1
4147
else:
4248
raise
49+
50+
51+
def extract_zip(location: str, desc: str = "Extracting"):
    """Extract ``roboflow.zip`` inside *location* and remove the archive.

    Args:
        location: Directory containing ``roboflow.zip``.
        desc: Description shown in the tqdm progress bar.

    Raises:
        RuntimeError: If a member of the archive cannot be extracted.
    """
    zip_path = os.path.join(location, "roboflow.zip")
    # tqdm treats desc=None as "no description"; honour the global disable flag.
    tqdm_desc = None if TQDM_DISABLE else desc
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        # Extract member-by-member (rather than extractall) so tqdm can
        # report per-file progress.
        for member in tqdm(zip_ref.infolist(), desc=tqdm_desc):
            try:
                zip_ref.extract(member, location)
            except zipfile.error as e:
                # Chain the original zipfile error so the root cause is
                # visible in the traceback.
                raise RuntimeError("Error unzipping download") from e

    # Archive extracted successfully; remove it to save disk space.
    os.remove(zip_path)

0 commit comments

Comments
 (0)