Skip to content

Commit 2ce2122

Browse files
authored
fix split param (#480)
1 parent e9a912e commit 2ce2122

5 files changed

Lines changed: 95 additions & 4 deletions

File tree

roboflow/cli/handlers/_aliases.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@ def upload_alias(
6363
project: Annotated[str, typer.Option("-p", "--project", help="Project ID")],
6464
annotation: Annotated[Optional[str], typer.Option("-a", "--annotation", help="Annotation file")] = None,
6565
labelmap: Annotated[Optional[str], typer.Option("-m", "--labelmap", help="Labelmap file")] = None,
66-
split: Annotated[str, typer.Option("-s", "--split", help="Split (train/valid/test)")] = "train",
66+
split: Annotated[
67+
Optional[str],
68+
typer.Option("-s", "--split", help="Override split for all uploaded images (default: infer from folder)"),
69+
] = None,
6770
num_retries: Annotated[int, typer.Option("-r", "--retries", help="Retry count")] = 0,
6871
batch: Annotated[Optional[str], typer.Option("-b", "--batch", help="Batch name")] = None,
6972
tag_names: Annotated[Optional[str], typer.Option("-t", "--tag", help="Tag names")] = None,

roboflow/cli/handlers/image.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,14 @@ def upload_image(
2121
annotation: Annotated[
2222
Optional[str], typer.Option("-a", "--annotation", help="Path to annotation file (single upload)")
2323
] = None,
24-
split: Annotated[str, typer.Option("-s", "--split", help="Dataset split")] = "train",
24+
split: Annotated[
25+
Optional[str],
26+
typer.Option(
27+
"-s",
28+
"--split",
29+
help="Override split for all images (default: infer from folder for dirs, 'train' for files)",
30+
),
31+
] = None,
2532
batch: Annotated[Optional[str], typer.Option("-b", "--batch", help="Batch name")] = None,
2633
tag: Annotated[Optional[str], typer.Option("-t", "--tag", help="Comma-separated tag names")] = None,
2734
metadata: Annotated[Optional[str], typer.Option(help="JSON string of key-value metadata")] = None,
@@ -237,7 +244,7 @@ def _handle_upload_single(args, api_key: str, path: str) -> None: # noqa: ANN00
237244
image_path=path,
238245
annotation_path=args.annotation,
239246
annotation_labelmap=getattr(args, "labelmap", None),
240-
split=args.split,
247+
split=args.split or "train",
241248
num_retry_uploads=retries,
242249
batch_name=args.batch,
243250
tag_names=tag_names,

roboflow/core/workspace.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,8 @@ def upload_dataset(
427427
is_prediction (bool, optional): whether the annotations provided in the dataset are predictions and not ground truth. Defaults to False.
428428
use_zip_upload (bool, optional): opt-in to the zip flow for a directory input (the SDK zips it client-side). Ignored when dataset_path is already a `.zip`.
429429
tags (list[str], optional): zip flow only — tags to apply to the uploaded batch.
430-
split (str, optional): zip flow only — dataset split for the uploaded batch.
430+
split (str, optional): dataset split for the uploaded batch. In per-image directory
431+
uploads, this overrides inferred splits for every image.
431432
wait (bool, optional): zip flow only — poll for processing completion. Defaults to True.
432433
poll_interval (float, optional): zip flow only — seconds between status polls.
433434
poll_timeout (float, optional): zip flow only — total seconds to wait before timing out.
@@ -489,6 +490,9 @@ def upload_dataset(
489490
is_classification = project.type == "classification"
490491
parsed_dataset = folderparser.parsefolder(dataset_path, is_classification=is_classification)
491492
images = parsed_dataset["images"]
493+
if split is not None:
494+
for image in images:
495+
image["split"] = split
492496

493497
location = parsed_dataset["location"]
494498

tests/cli/test_image_handler.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,74 @@ def test_zip_upload_flag_defaults_false(self, mock_rf_cls):
345345
_, kwargs = mock_ws.upload_dataset.call_args
346346
self.assertEqual(kwargs.get("use_zip_upload"), False)
347347

348+
@patch("roboflow.Roboflow")
349+
def test_upload_directory_omits_default_split_when_not_explicit(self, mock_rf_cls):
350+
from roboflow.cli.handlers.image import _handle_upload
351+
352+
with tempfile.TemporaryDirectory() as tmpdir:
353+
mock_ws = MagicMock()
354+
mock_rf_cls.return_value.workspace.return_value = mock_ws
355+
356+
args = _make_args(
357+
json=True,
358+
path=tmpdir,
359+
project="proj",
360+
annotation=None,
361+
split=None,
362+
batch=None,
363+
tag=None,
364+
metadata=None,
365+
concurrency=10,
366+
retries=0,
367+
labelmap=None,
368+
is_prediction=False,
369+
)
370+
371+
buf = io.StringIO()
372+
old = sys.stdout
373+
sys.stdout = buf
374+
try:
375+
_handle_upload(args)
376+
finally:
377+
sys.stdout = old
378+
379+
_, kwargs = mock_ws.upload_dataset.call_args
380+
self.assertIsNone(kwargs.get("split"))
381+
382+
@patch("roboflow.Roboflow")
383+
def test_upload_directory_forwards_explicit_split(self, mock_rf_cls):
384+
from roboflow.cli.handlers.image import _handle_upload
385+
386+
with tempfile.TemporaryDirectory() as tmpdir:
387+
mock_ws = MagicMock()
388+
mock_rf_cls.return_value.workspace.return_value = mock_ws
389+
390+
args = _make_args(
391+
json=True,
392+
path=tmpdir,
393+
project="proj",
394+
annotation=None,
395+
split="valid",
396+
batch=None,
397+
tag=None,
398+
metadata=None,
399+
concurrency=10,
400+
retries=0,
401+
labelmap=None,
402+
is_prediction=False,
403+
)
404+
405+
buf = io.StringIO()
406+
old = sys.stdout
407+
sys.stdout = buf
408+
try:
409+
_handle_upload(args)
410+
finally:
411+
sys.stdout = old
412+
413+
_, kwargs = mock_ws.upload_dataset.call_args
414+
self.assertEqual(kwargs.get("split"), "valid")
415+
348416

349417
class TestImageDelete(unittest.TestCase):
350418
"""Test the delete handler."""

tests/test_project.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,15 @@ def test_project_upload_dataset(self):
327327
},
328328
"assertions": {"upload": {"count": 1, "kwargs": {"batch_name": "test-batch", "num_retry_uploads": 3}}},
329329
},
330+
{
331+
"name": "explicit_split_overrides_parsed_directory_splits",
332+
"dataset": [
333+
{"file": "image1.jpg", "split": "train"},
334+
{"file": "image2.jpg", "split": "test"},
335+
],
336+
"params": {"split": "valid", "num_workers": 1},
337+
"assertions": {"upload": {"count": 2, "kwargs": {"split": "valid"}}},
338+
},
330339
{
331340
"name": "project_creation",
332341
"dataset": None,

0 commit comments

Comments
 (0)