diff --git a/CHANGELOG.md b/CHANGELOG.md index cc6c4588..d29d96df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ All notable changes to this project will be documented in this file. -## Unreleased +## 1.3.9 ### Added — Model evaluations SDK & CLI @@ -45,6 +45,12 @@ The endpoints require the `model-eval:read` scope. The base URL is configurable via `API_URL` (set to `https://localapi.roboflow.one` to test against a local API server). +### Fixed +- rf-detr model upload: accept checkpoints whose `args` is a plain dict (e.g. EMA checkpoints) when extracting class names, instead of raising `TypeError` from `vars()`. + +### Changed +- Pin `typer<0.26` and declare `click` explicitly: typer 0.26 vendors its own click and drops the external dependency, which broke the CLI and its type checks. + ## 1.3.7 ### Added — Soft-delete / Trash support diff --git a/requirements-slim.txt b/requirements-slim.txt index 9709e296..9c4d021e 100644 --- a/requirements-slim.txt +++ b/requirements-slim.txt @@ -6,7 +6,8 @@ tqdm>=4.41.0 PyYAML>=5.3.1 requests_toolbelt filetype -typer>=0.12.0 +typer>=0.12.0,<0.26 # 0.26 vendors click, dropping the external dep the CLI imports +click>=8.0 python-dateutil python-dotenv six diff --git a/requirements.txt b/requirements.txt index 79cd4bb2..3984d070 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,4 +19,5 @@ tqdm>=4.41.0 PyYAML>=5.3.1 requests_toolbelt filetype -typer>=0.12.0 +typer>=0.12.0,<0.26 # 0.26 vendors click, dropping the external dep the CLI imports +click>=8.0 diff --git a/roboflow/util/model_processor.py b/roboflow/util/model_processor.py index 674170ae..4ff0ce1a 100644 --- a/roboflow/util/model_processor.py +++ b/roboflow/util/model_processor.py @@ -414,7 +414,9 @@ def get_classnames_txt_for_rfdetr(model_path: str, pt_file: str, checkpoint=None import torch checkpoint = torch.load(os.path.join(model_path, pt_file), map_location="cpu", weights_only=False) - args = vars(checkpoint["args"]) + raw_args = checkpoint["args"] + # args may be a plain dict in some checkpoints + args = raw_args if isinstance(raw_args, dict) else vars(raw_args) if "class_names" in args: with open(class_names_path, "w") as f: for class_name in args["class_names"]: diff --git a/tests/util/test_model_processor.py b/tests/util/test_model_processor.py index 951abe29..37ecb186 100644 --- a/tests/util/test_model_processor.py +++ b/tests/util/test_model_processor.py @@ -1,3 +1,5 @@ +import os +import tempfile import unittest from types import SimpleNamespace @@ -5,6 +7,7 @@ from roboflow.util.model_processor import ( _detect_rfdetr_task, _detect_yolo_task, + get_classnames_txt_for_rfdetr, task_of_model_type, ) @@ -84,5 +87,22 @@ def test_unrecognized_returns_none(self): self.assertIsNone(_detect_rfdetr_task({"args": SimpleNamespace(other=1)})) +class GetClassnamesTxtForRfdetrTest(unittest.TestCase): + def _classnames(self, args): + with tempfile.TemporaryDirectory() as model_path: + get_classnames_txt_for_rfdetr(model_path, "weights.pt", checkpoint={"args": args}) + with open(os.path.join(model_path, "class_names.txt")) as f: + return f.read().splitlines() + + def test_dict_args(self): + self.assertEqual(self._classnames({"class_names": ["cat", "dog"]}), ["background_class83422", "cat", "dog"]) + + def test_namespace_args(self): + self.assertEqual( + self._classnames(SimpleNamespace(class_names=["cat", "dog"])), + ["background_class83422", "cat", "dog"], + ) + + if __name__ == "__main__": unittest.main()