Skip to content

Commit 5f2aae5

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent a9412fc commit 5f2aae5

3 files changed

Lines changed: 19 additions & 26 deletions

File tree

kornia/contrib/models/rt_detr/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def _state_dict_proc(state_dict: dict[str, Tensor]) -> dict[str, Tensor]:
213213

214214
model.load_state_dict(_state_dict_proc(state_dict))
215215
return model
216-
216+
217217
@staticmethod
218218
def from_name(model_name: str, num_classes: int = 80) -> RTDETR:
219219
"""Load model without pretrained weights.
@@ -234,7 +234,7 @@ def from_name(model_name: str, num_classes: int = 80) -> RTDETR:
234234
model = RTDETR.from_config(RTDETRConfig(RTDETRModelType.resnet101d, num_classes))
235235
else:
236236
raise ValueError
237-
237+
238238
return model
239239

240240
def forward(self, images: Tensor) -> tuple[Tensor, Tensor]:

kornia/contrib/object_detection.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,22 +168,22 @@ def forward(self, images: list[Tensor]) -> list[Tensor]:
168168
return detections
169169

170170
def draw(self, images: list[Tensor], output_type: str = "torch") -> list[Tensor] | Image.Image: # type: ignore
171-
"""Very simple drawing. Needs to be more fancy later.
171+
"""Very simple drawing.
172+
173+
Needs to be more fancy later.
172174
"""
173175
detections = self.forward(images)
174176
output = []
175177
for image, detection in zip(images, detections):
176178
out_img = image[None].clone()
177179
for out in detection:
178180
out_img = draw_rectangle(
179-
out_img,
180-
torch.Tensor([[[out[-4], out[-3], out[-4] + out[-2], out[-3] + out[-1]]]])
181+
out_img, torch.Tensor([[[out[-4], out[-3], out[-4] + out[-2], out[-3] + out[-1]]]])
181182
)
182183
if output_type == "torch":
183184
output.append(out_img)
184185
elif output_type == "pil":
185-
output.append(Image.fromarray(
186-
(out_img[0] * 255).permute(1, 2, 0).numpy().astype(np.uint8))) # type: ignore
186+
output.append(Image.fromarray((out_img[0] * 255).permute(1, 2, 0).numpy().astype(np.uint8))) # type: ignore
187187
return output
188188

189189
def compile(

kornia/models/detector/rtdetr.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,32 @@
1-
from typing import Optional
21
import warnings
2+
from typing import Optional
33

4-
from kornia.core import Module
5-
from kornia.contrib.models.rt_detr.model import RTDETR, RTDETRConfig
64
from kornia.contrib.models.rt_detr import DETRPostProcessor
7-
from kornia.contrib.object_detection import ResizePreProcessor, ObjectDetector
5+
from kornia.contrib.models.rt_detr.model import RTDETR, RTDETRConfig
6+
from kornia.contrib.object_detection import ObjectDetector, ResizePreProcessor
87

98

109
class RTDETRDetectorBuilder:
11-
1210
@staticmethod
1311
def build(
1412
model_name: Optional[str] = None,
1513
config: Optional[RTDETRConfig] = None,
1614
pretrained: bool = True,
1715
image_size: int = 640,
18-
confidence_threshold: float = 0.5
16+
confidence_threshold: float = 0.5,
1917
) -> ObjectDetector:
20-
if (model_name is not None and config is not None):
18+
if model_name is not None and config is not None:
2119
raise ValueError("Either `model_name` or `config` should be `None`.")
22-
20+
2321
if model_name is None and config is None:
2422
warnings.warn("No `model_name` or `config` found. Will build `rtdetr_r18vd`.")
2523
model_name = "rtdetr_r18vd"
26-
24+
2725
if config is not None:
2826
model = RTDETR.from_config(config)
27+
elif pretrained:
28+
model = RTDETR.from_pretrained(model_name)
2929
else:
30-
if pretrained:
31-
model = RTDETR.from_pretrained(model_name)
32-
else:
33-
model = RTDETR.from_name(model_name)
34-
35-
return ObjectDetector(
36-
model,
37-
ResizePreProcessor(image_size),
38-
DETRPostProcessor(confidence_threshold)
39-
)
30+
model = RTDETR.from_name(model_name)
31+
32+
return ObjectDetector(model, ResizePreProcessor(image_size), DETRPostProcessor(confidence_threshold))

0 commit comments

Comments
 (0)