Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions fastlabel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from concurrent.futures import ThreadPoolExecutor, wait
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union
from xml.dom import minidom
from xml.etree import ElementTree as ET

import cv2
import numpy as np
Expand Down Expand Up @@ -3548,6 +3550,26 @@ def export_semantic_segmentation(
start_index=start_index,
)

def export_cvat(
    self,
    tasks: list,
    output_dir: str = os.path.join("output", "cvat"),
    pretty_print: bool = True,
) -> None:
    """Convert tasks to CVAT XML and write them to ``annotations.xml``.

    Args:
        tasks: FastLabel tasks, passed unchanged to
            ``converters.CvatConverter.tasks_to_cvat``.
        output_dir: Directory that receives ``annotations.xml``; it is
            created when it does not exist yet.
        pretty_print: When true, the tree is re-parsed with ``minidom`` and
            written with indentation and newlines. Otherwise ElementTree
            serialises it directly, preceded by an XML declaration.
    """
    root = converters.CvatConverter(logger).tasks_to_cvat(tasks)

    os.makedirs(output_dir, exist_ok=True)
    destination = os.path.join(output_dir, "annotations.xml")
    with open(destination, "w", encoding="utf-8") as xml_file:
        if not pretty_print:
            # "unicode" makes ElementTree emit str, matching the text-mode file.
            ET.ElementTree(root).write(
                xml_file, encoding="unicode", xml_declaration=True
            )
            return

        # minidom does the indentation; ET.tostring yields the bytes it parses.
        document = minidom.parseString(ET.tostring(root))
        document.writexml(
            xml_file, indent=" ", addindent=" ", newl="\n", encoding="utf-8"
        )

def __export_index_color_image(
self,
task: list,
Expand Down
202 changes: 201 additions & 1 deletion fastlabel/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from contextlib import contextmanager
from datetime import datetime
from decimal import Decimal
from logging import Logger
from operator import itemgetter
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Dict, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Union
from xml.etree import ElementTree as ET

import cv2
import geojson
Expand Down Expand Up @@ -980,6 +982,204 @@ def __remove_duplicated_coordinates(points: List[int]) -> List[int]:
return new_points


class CvatConverter:
    """Build a CVAT-style ``<annotations>`` XML tree from FastLabel tasks.

    Each supported FastLabel annotation type is handled by a dedicated
    ``_add_*`` builder method. Annotations that cannot be converted are
    logged through the injected logger and skipped, so one bad annotation
    does not abort the whole export.
    """

    # Written verbatim into the <version> tag of the generated document.
    CVAT_VERSION = "1.1"
    # FastLabel annotation type -> name of the builder method on this class.
    _ANNOTATION_BUILDERS = {
        "bbox": "_add_bbox",
        "polygon": "_add_polygon",
        "polyline": "_add_polyline",
        "keypoint": "_add_keypoints",
        "line": "_add_line",
        "segmentation": "_add_segmentation",
    }

    def __init__(
        self,
        logger: Logger,
    ) -> None:
        """Store the logger used to report annotations that fail to convert."""
        self.logger = logger

    def tasks_to_cvat(self, tasks: Iterable[dict]) -> ET.Element:
        """Convert FastLabel tasks to CVAT XML format.

        tasks schema (dict):
        - task: {
            "name": str,
            "width": int,
            "height": int,
            "annotations": [annotation, ...]
          }
        - annotation: {
            "id": str,
            "title": str | None,
            "type": str,  # bbox / polygon / polyline / keypoint / line / segmentation
            "value": str,
            "points": Any,
            "attributes": [{"name": str, "value": Any}, ...] (missing/None = none),
            "rotation": float | None (optional, bbox only)
          }

        returns:
        - root: <annotations>...</annotations>

        XML outline:
        <annotations>
          <version>1.1</version>
          <image ...>
            <box ...>
              <attribute name="...">...</attribute>
            </box>
          </image>
        </annotations>

        A failing annotation (unsupported type, malformed points, ...) is
        logged with ``logger.error`` and left out of the tree; the remaining
        annotations and tasks are still converted.
        """
        root = ET.Element("annotations")
        self._make_tag(root, "version", self.CVAT_VERSION)

        for index, task in enumerate(tasks):
            image = self._make_tag(
                root,
                "image",
                attrib={
                    "id": index,
                    "name": task["name"].replace("/", "_"),
                    "width": task["width"],
                    "height": task["height"],
                },
            )
            for annotation in task["annotations"]:
                # Elements already attached for this annotation; kept outside
                # the try so a late failure can detach them again.
                elems: List[ET.Element] = []
                try:
                    fl_type = annotation["type"]
                    builder_name = self._ANNOTATION_BUILDERS.get(fl_type)
                    if builder_name is None:
                        raise ValueError(
                            f"Unsupported fastLabel annotation type: {fl_type}"
                        )
                    elems = getattr(self, builder_name)(image, annotation)

                    # An annotation without attributes is still a valid shape.
                    attributes = annotation.get("attributes") or []
                    for elem in elems:
                        for attr in attributes:
                            self._make_tag(
                                elem,
                                "attribute",
                                attr["value"],
                                attrib={"name": attr["name"]},
                            )
                except Exception as e:
                    # Best effort: drop whatever this annotation produced,
                    # report it, and carry on with the next annotation.
                    for elem in elems:
                        if elem in image:
                            image.remove(elem)

                    self.logger.error(
                        "task_name=%s annotation_id=%s annotation_title=%s annotation_type=%s error=%s",
                        task.get("name"),
                        annotation.get("id"),
                        annotation.get("title"),
                        annotation.get("type"),
                        e,
                    )

        return root

    @staticmethod
    def _make_tag(
        root: ET.Element,
        tag: str,
        value: Any = None,
        attrib: Optional[dict] = None,
    ) -> ET.Element:
        """Append a ``tag`` child to ``root`` and return it.

        Attribute values are converted with ``str``; ``None`` becomes an
        empty string. ``value`` becomes the element text unless it is None.
        """
        safe_attrib = {
            k: "" if v is None else str(v) for k, v in (attrib or {}).items()
        }
        elem = ET.SubElement(root, tag, attrib=safe_attrib)
        if value is not None:
            elem.text = str(value)
        return elem

    @staticmethod
    def _format_points(points: Iterable[Any], label: Any) -> str:
        """Turn a flat ``[x0, y0, x1, y1, ...]`` sequence into ``"x, y;x, y"``.

        Raises:
            ValueError: if the number of coordinates is odd. ``label`` is
                only used to identify the annotation in that message.
        """
        coords = list(points)
        if len(coords) % 2 != 0:
            raise ValueError(
                f"ポイントが偶数ではありません。Annotation({label}): {len(coords)}"
            )
        # Pair even-indexed x with odd-indexed y in one linear pass.
        return ";".join(f"{x}, {y}" for x, y in zip(coords[::2], coords[1::2]))

    def _add_points_shape(
        self, image_elem: ET.Element, annotation: dict, tag: str
    ) -> List[ET.Element]:
        """Append one ``tag`` element carrying the annotation's point list.

        Raises:
            ValueError: if ``annotation["points"]`` has an odd length.
        """
        point_text = self._format_points(annotation["points"], annotation["value"])
        elem = self._make_tag(
            image_elem,
            tag,
            attrib={"label": annotation["value"], "points": point_text},
        )
        return [elem]

    def _add_bbox(self, image_elem: ET.Element, annotation: dict) -> List[ET.Element]:
        """Append a ``<box>`` element built from four corner coordinates.

        ``points`` is read as ``[xtl, ytl, xbr, ybr]``. A missing or ``None``
        rotation is written as ``0``.

        Raises:
            ValueError: if ``points`` does not contain exactly four values.
        """
        points = annotation["points"]
        if len(points) != 4:
            raise ValueError("矩形のポイントが 4 つではありません。")
        # `or 0` also covers an explicit None, which would otherwise be
        # serialised as an empty rotation attribute.
        rotation = annotation.get("rotation") or 0
        elem = self._make_tag(
            image_elem,
            "box",
            attrib={
                "label": annotation["value"],
                "xtl": points[0],
                "ytl": points[1],
                "xbr": points[2],
                "ybr": points[3],
                "rotation": rotation,
            },
        )
        return [elem]

    def _add_polygon(
        self, image_elem: ET.Element, annotation: dict
    ) -> List[ET.Element]:
        """Append the annotation as a ``<polygon>`` element."""
        return self._add_points_shape(image_elem, annotation, "polygon")

    def _add_polyline(
        self, image_elem: ET.Element, annotation: dict
    ) -> List[ET.Element]:
        """Append the annotation as a ``<polyline>`` element."""
        return self._add_points_shape(image_elem, annotation, "polyline")

    def _add_keypoints(
        self, image_elem: ET.Element, annotation: dict
    ) -> List[ET.Element]:
        """Append the annotation as a ``<points>`` element."""
        return self._add_points_shape(image_elem, annotation, "points")

    def _add_line(self, image_elem: ET.Element, annotation: dict) -> List[ET.Element]:
        """Append the annotation as a ``<polyline>`` element."""
        return self._add_points_shape(image_elem, annotation, "polyline")

    def _add_segmentation(
        self, image_elem: ET.Element, annotation: dict
    ) -> List[ET.Element]:
        """Append one ``<polygon>`` per region of a segmentation annotation.

        ``points`` must be a three-level nesting: segments -> regions -> flat
        coordinates. Every region is validated and formatted before any tag
        is created, so a malformed region never leaves a partially exported
        annotation attached to ``image_elem``.

        NOTE(review): every region is emitted as its own filled polygon;
        presumably regions after the first in a segment are holes in
        FastLabel's format - confirm whether that is the intended export.

        Raises:
            ValueError: if ``points`` is empty, is not three levels deep, or
                a region has an odd number of coordinates.
        """
        points = annotation["points"]
        if not points:
            raise ValueError("セグメンテーションのポイント数が 0 です。")
        try:
            # Probe only: checks the nesting depth of the first region.
            points[0][0][0]
        except (IndexError, TypeError) as exc:
            raise ValueError("セグメンテーションが3次元配列ではありません。") from exc

        label = annotation["value"]
        # Format everything first; tags are only created once all regions
        # are known to be valid.
        region_texts = [
            self._format_points(region, label)
            for segment in points
            for region in segment
        ]
        return [
            self._make_tag(
                image_elem,
                "polygon",
                attrib={"label": label, "points": text},
            )
            for text in region_texts
        ]


def get_pixel_coordinates(points: List[Union[int, float]]) -> List[int]:
"""
Remove diagonal coordinates and return pixel outline coordinates.
Expand Down
Loading