Skip to content

Commit d8f8f84

Browse files
Video support: Adding models for video and object detection track (#51)
* Adding models: video and object detection track
* Apply suggestions from code review

Co-authored-by: michal-lightly <105644579+michal-lightly@users.noreply.github.com>

* fix naming

---------

Co-authored-by: michal-lightly <105644579+michal-lightly@users.noreply.github.com>
1 parent 9d85385 commit d8f8f84

3 files changed

Lines changed: 133 additions & 0 deletions

File tree

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from argparse import ArgumentParser
5+
from dataclasses import dataclass
6+
from typing import Iterable, List
7+
8+
from labelformat.model.bounding_box import BoundingBox
9+
from labelformat.model.category import Category
10+
from labelformat.model.video import Video
11+
12+
13+
@dataclass(frozen=True)
class SingleObjectDetectionTrack:
    """Track of a single object across the frames of a video.

    The class only stores its fields; it performs no validation of its own.
    Because the dataclass is frozen, the fields cannot be reassigned after
    construction (the ``boxes`` list itself is still a mutable object).
    """

    # Category of the tracked object.
    category: Category
    # One entry per video frame; None marks a frame on which the object is
    # not present.
    boxes: list[BoundingBox | None]
    # TODO (Jonas, 01/2026): Add confidence
18+
19+
20+
@dataclass(frozen=True)
class VideoObjectDetectionTrack:
    """
    The base class for a video together with its object detection track annotations.

    A video consists of N frames and M objects. Each object is defined by N boxes -
    one for each frame. If an object is not present on a frame, the corresponding
    entry is set to None.

    Raises:
        ValueError: On construction, if the number of boxes of any object differs
            from ``video.number_of_frames``.
    """

    video: Video
    objects: List[SingleObjectDetectionTrack]

    def __post_init__(self) -> None:
        """Check that every object track has exactly one entry per video frame."""
        number_of_frames = self.video.number_of_frames

        # The first mismatching track aborts construction; an empty ``objects``
        # list passes trivially.
        for obj in self.objects:
            if len(obj.boxes) != number_of_frames:
                raise ValueError(
                    "Length of object detection track does not match the number of frames in the video."
                )
39+
40+
41+
class ObjectDetectionTrackInput(ABC):
    """Abstract interface for a source of video object detection track labels.

    All four methods are abstract, so a concrete subclass must override each
    of them before it can be instantiated.
    """

    @staticmethod
    @abstractmethod
    def add_cli_arguments(parser: ArgumentParser) -> None:
        """Hook for implementations to add their command-line arguments to ``parser``."""
        raise NotImplementedError()

    @abstractmethod
    def get_categories(self) -> Iterable[Category]:
        """Return the categories provided by this input."""
        raise NotImplementedError()

    @abstractmethod
    def get_videos(self) -> Iterable[Video]:
        """Return the videos provided by this input."""
        raise NotImplementedError()

    @abstractmethod
    def get_labels(self) -> Iterable[VideoObjectDetectionTrack]:
        """Return the videos together with their object detection tracks."""
        raise NotImplementedError()
58+
59+
60+
class ObjectDetectionTrackOutput(ABC):
    """Abstract interface for a writer of video object detection track labels.

    Only ``add_cli_arguments`` is marked abstract. ``save`` is a regular method,
    so a subclass that does not override it can still be instantiated and fails
    with ``NotImplementedError`` only when ``save`` is called.
    """

    @staticmethod
    @abstractmethod
    def add_cli_arguments(parser: ArgumentParser) -> None:
        """Hook for implementations to add their command-line arguments to ``parser``."""
        raise NotImplementedError()

    def save(self, label_input: ObjectDetectionTrackInput) -> None:
        """Write the labels obtained from ``label_input``.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError()

src/labelformat/model/video.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from dataclasses import dataclass
2+
3+
4+
@dataclass(frozen=True)
class Video:
    """Immutable description of a video.

    The class only stores its fields; none of the values are validated.
    """

    # Identifier of the video.
    id: int
    # File name of the video.
    filename: str
    # Frame dimensions (unit not stated here - presumably pixels; TODO confirm).
    width: int
    height: int
    # Total number of frames in the video.
    number_of_frames: int
    # TODO (Jonas, 01/2026): Add list of frames
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from __future__ import annotations
2+
3+
import pytest
4+
5+
from labelformat.model.bounding_box import BoundingBox
6+
from labelformat.model.category import Category
7+
from labelformat.model.object_detection_track import (
8+
SingleObjectDetectionTrack,
9+
VideoObjectDetectionTrack,
10+
)
11+
from labelformat.model.video import Video
12+
13+
14+
class TestVideoObjectDetectionTrack:
    """Tests for the frame-count validation in VideoObjectDetectionTrack.__post_init__."""

    def test_post_init__frames_equal_boxes_length__valid(self) -> None:
        # Both tracks carry exactly as many boxes as the video has frames.
        track_a = SingleObjectDetectionTrack(
            category=Category(id=0, name="cat"),
            boxes=[BoundingBox(xmin=0, ymin=0, xmax=1, ymax=1) for _ in range(2)],
        )

        track_b = SingleObjectDetectionTrack(
            category=Category(id=1, name="dog"),
            boxes=[BoundingBox(xmin=0, ymin=0, xmax=1, ymax=1) for _ in range(2)],
        )

        video = Video(id=0, filename="test.mov", width=1, height=1, number_of_frames=2)

        detections = VideoObjectDetectionTrack(
            video=video,
            objects=[track_a, track_b],
        )
        assert len(detections.objects) == 2
        assert len(detections.objects[0].boxes) == 2

    def test_post_init__frames_equal_boxes_length__invalid(self) -> None:
        track_a = SingleObjectDetectionTrack(
            category=Category(id=0, name="cat"),
            boxes=[BoundingBox(xmin=0, ymin=0, xmax=1, ymax=1) for _ in range(2)],
        )

        # Three boxes for a two-frame video: this track must be rejected.
        track_b = SingleObjectDetectionTrack(
            category=Category(id=1, name="dog"),
            boxes=[BoundingBox(xmin=0, ymin=0, xmax=1, ymax=1) for _ in range(3)],
        )

        video = Video(id=0, filename="test.mov", width=1, height=1, number_of_frames=2)

        # `match` is a regular expression, so the final full stop is escaped to
        # match literally instead of acting as a wildcard.
        with pytest.raises(
            ValueError,
            match=r"Length of object detection track does not match the number of frames in the video\.",
        ):
            VideoObjectDetectionTrack(
                video=video,
                objects=[track_a, track_b],
            )

0 commit comments

Comments
 (0)