Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .github/actions/setup-uv/action.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
name: Setup UV
description: Installs UV from Astral.sh
runs:
using: "composite"
steps:
- name: Install curl and UV
shell: bash
run: |
sudo apt update
sudo apt install -y curl
curl -LsSf https://astral.sh/uv/install.sh | sh
echo "$HOME/.local/bin" >> $GITHUB_PATH
18 changes: 18 additions & 0 deletions .github/workflows/ruff.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
name: Format
on: [push]
jobs:
ruff:
runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v4

- name: Setup UV
uses: ./.github/actions/setup-uv

- name: Install ruff
run: uv add ruff
- name: Run ruff check
run: uv run ruff check
- name: Run ruff format check
run: uv run ruff format --check

14 changes: 14 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
name: Test
on: [push]
jobs:
pytest:
runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v4

- name: Setup UV
uses: ./.github/actions/setup-uv

- name: test
run: uv run pytest

7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ boxes, landmarks, scores = detector.infer(image)

```

## Formatting
ALl code should be formatted with ruff:
```
uv run ruff format
uv run ruff check
```

## Citation
If you find this code useful, remember to cite the original authors:
```
Expand Down
8 changes: 3 additions & 5 deletions benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@
num = 1000

for detector in face_detection.available_detectors:
detector = face_detection.build_detector(
detector,
fp16_inference=True
)
detector = face_detection.build_detector(detector, fp16_inference=True)
im = "images/0_Parade_Parade_0_873.jpg"
im = cv2.imread(im)[:, :, ::-1]
t = time.time()
Expand All @@ -23,4 +20,5 @@
ms = avg_time * 1000
print(
f"Detector: {detector.__class__.__name__}. Average inference time over image shape: {im.shape} is:",
f"{ms:.2f} ms, fps: {fps:.2f}")
f"{ms:.2f} ms, fps: {fps:.2f}",
)
10 changes: 9 additions & 1 deletion face_detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
from .build import build_detector, available_detectors
from .dsfd import DSFDDetector
from .retinaface import RetinaNetMobileNetV1, RetinaNetResNet50
from .retinaface import RetinaNetMobileNetV1, RetinaNetResNet50

__all__ = [
"build_detector",
"available_detectors",
"RetinaNetMobileNetV1",
"RetinaNetResNet50",
"DSFDDetector",
]
41 changes: 19 additions & 22 deletions face_detection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,23 @@


def check_image(im: np.ndarray):
assert im.dtype == np.uint8,\
f"Expect image to have dtype np.uint8. Was: {im.dtype}"
assert len(im.shape) == 4,\
f"Expected image to have 4 dimensions. got: {im.shape}"
assert im.shape[-1] == 3,\
assert im.dtype == np.uint8, f"Expect image to have dtype np.uint8. Was: {im.dtype}"
assert len(im.shape) == 4, f"Expected image to have 4 dimensions. got: {im.shape}"
assert im.shape[-1] == 3, (
f"Expected image to be RGB, got: {im.shape[-1]} color channels"
)


class Detector(ABC):

def __init__(
self,
confidence_threshold: float,
nms_iou_threshold: float,
device: torch.device,
max_resolution: int,
fp16_inference: bool,
clip_boxes: bool):
self,
confidence_threshold: float,
nms_iou_threshold: float,
device: torch.device,
max_resolution: int,
fp16_inference: bool,
clip_boxes: bool,
):
"""
Args:
confidence_threshold (float): Threshold to filter out bounding boxes
Expand All @@ -40,11 +39,9 @@ def __init__(
self.max_resolution = max_resolution
self.fp16_inference = fp16_inference
self.clip_boxes = clip_boxes
self.mean = np.array(
[123, 117, 104], dtype=np.float32).reshape(1, 1, 1, 3)
self.mean = np.array([123, 117, 104], dtype=np.float32).reshape(1, 1, 1, 3)

def detect(
self, image: np.ndarray, shrink=1.0) -> np.ndarray:
def detect(self, image: np.ndarray, shrink=1.0) -> np.ndarray:
"""Takes an RGB image and performs and returns a set of bounding boxes as
detections
Args:
Expand Down Expand Up @@ -77,7 +74,7 @@ def filter_boxes(self, boxes: torch.Tensor) -> typing.List[np.ndarray]:
"""
final_output = []
for i in range(len(boxes)):
scores = boxes[i, :, 4]
scores = boxes[i, :, 4]
keep_idx = scores >= self.confidence_threshold
boxes_ = boxes[i, keep_idx, :-1]
scores = scores[keep_idx]
Expand All @@ -99,7 +96,7 @@ def resize(self, image, shrink: float):
shrink_factor = self.max_resolution / max((height, width))
if shrink_factor <= shrink:
shrink = shrink_factor
size = (int(height*shrink), int(width*shrink))
size = (int(height * shrink), int(width * shrink))
image = torch.nn.functional.interpolate(image, size=size)
return image

Expand Down Expand Up @@ -130,8 +127,7 @@ def _batched_detect(self, image: np.ndarray) -> typing.List[np.ndarray]:
return boxes

@torch.no_grad()
def batched_detect(
self, image: np.ndarray, shrink=1.0) -> typing.List[np.ndarray]:
def batched_detect(self, image: np.ndarray, shrink=1.0) -> typing.List[np.ndarray]:
"""Takes N RGB image and performs and returns a set of bounding boxes as
detections
Args:
Expand All @@ -150,5 +146,6 @@ def batched_detect(

def validate_detections(self, boxes: typing.List[np.ndarray]):
for box in boxes:
assert np.all(box[:, 4] <= 1) and np.all(box[:, 4] >= 0),\
assert np.all(box[:, 4] <= 1) and np.all(box[:, 4] >= 0), (
f"Confidence values not valid: {box}"
)
11 changes: 7 additions & 4 deletions face_detection/box_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@ def batched_decode(loc, priors, variances, to_XYXY=True):
decoded bounding box predictions
"""
priors = priors[None]
boxes = torch.cat((
priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])),
dim=2)
boxes = torch.cat(
(
priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1]),
),
dim=2,
)
if to_XYXY:
boxes[:, :, :2] -= boxes[:, :, 2:] / 2
boxes[:, :, 2:] += boxes[:, :, :2]
Expand Down
31 changes: 14 additions & 17 deletions face_detection/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,31 @@
from .base import Detector
from .torch_utils import get_device

available_detectors = [
"DSFDDetector",
"RetinaNetResNet50",
"RetinaNetMobileNetV1"
]
available_detectors = ["DSFDDetector", "RetinaNetResNet50", "RetinaNetMobileNetV1"]
DETECTOR_REGISTRY = Registry("DETECTORS")


def build_detector(
name: str = "DSFDDetector",
confidence_threshold: float = 0.5,
nms_iou_threshold: float = 0.3,
device=get_device(),
max_resolution: int = None,
fp16_inference: bool = False,
clip_boxes: bool = False
) -> Detector:
assert name in available_detectors,\
f"Detector not available. Chooce one of the following"+\
",".join(available_detectors)
name: str = "DSFDDetector",
confidence_threshold: float = 0.5,
nms_iou_threshold: float = 0.3,
device=get_device(),
max_resolution: int = None,
fp16_inference: bool = False,
clip_boxes: bool = False,
) -> Detector:
assert name in available_detectors, (
"Detector not available. Chooce one of the following"
+ ",".join(available_detectors)
)
args = dict(
type=name,
confidence_threshold=confidence_threshold,
nms_iou_threshold=nms_iou_threshold,
device=device,
max_resolution=max_resolution,
fp16_inference=fp16_inference,
clip_boxes=clip_boxes
clip_boxes=clip_boxes,
)
detector = build_from_cfg(args, DETECTOR_REGISTRY)
return detector
4 changes: 3 additions & 1 deletion face_detection/dsfd/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .detect import DSFDDetector
from .detect import DSFDDetector

__all__ = ["DSFDDetector"]
64 changes: 41 additions & 23 deletions face_detection/dsfd/config.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,43 @@
resnet152_model_config = {
'num_classes': 2,
'feature_maps': [160, 80, 40, 20, 10, 5],
'min_dim': 640,
'steps': [4, 8, 16, 32, 64, 128], # stride
'variance': [0.1, 0.2],
'clip': True, # make default box in [0,1]
'base': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 512, 512, 512] ,
'extras': [256, 'S', 512, 128, 'S', 256],
'mbox': [1, 1, 1, 1, 1, 1] ,
'min_sizes': [16, 32, 64, 128, 256, 512],
'max_sizes': [],
'aspect_ratios': [ [1.5],[1.5],[1.5],[1.5],[1.5],[1.5] ], # [1,2] default 1
'backbone': 'resnet152' , # vgg, resnet, detnet, resnet50
'feature_pyramid_network':True ,
'bottom_up_path': False ,
'feature_enhance_module': True ,
'max_in_out': True ,
'focal_loss': False ,
'progressive_anchor': True ,
'refinedet': False ,
'max_out': False ,
'anchor_compensation': False ,
'data_anchor_sampling': False ,
"num_classes": 2,
"feature_maps": [160, 80, 40, 20, 10, 5],
"min_dim": 640,
"steps": [4, 8, 16, 32, 64, 128], # stride
"variance": [0.1, 0.2],
"clip": True, # make default box in [0,1]
"base": [
64,
64,
"M",
128,
128,
"M",
256,
256,
256,
"C",
512,
512,
512,
"M",
512,
512,
512,
],
"extras": [256, "S", 512, 128, "S", 256],
"mbox": [1, 1, 1, 1, 1, 1],
"min_sizes": [16, 32, 64, 128, 256, 512],
"max_sizes": [],
"aspect_ratios": [[1.5], [1.5], [1.5], [1.5], [1.5], [1.5]], # [1,2] default 1
"backbone": "resnet152", # vgg, resnet, detnet, resnet50
"feature_pyramid_network": True,
"bottom_up_path": False,
"feature_enhance_module": True,
"max_in_out": True,
"focal_loss": False,
"progressive_anchor": True,
"refinedet": False,
"max_out": False,
"anchor_compensation": False,
"data_anchor_sampling": False,
}
19 changes: 8 additions & 11 deletions face_detection/dsfd/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import typing
from .face_ssd import SSD
from .config import resnet152_model_config
from .. import torch_utils
from torch.hub import load_state_dict_from_url
from ..base import Detector
from ..build import DETECTOR_REGISTRY
Expand All @@ -13,21 +12,21 @@

@DETECTOR_REGISTRY.register_module
class DSFDDetector(Detector):

def __init__(
self, *args, **kwargs):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
state_dict = load_state_dict_from_url(
model_url,
map_location=self.device,
progress=True)
model_url, map_location=self.device, progress=True
)
self.net = SSD(resnet152_model_config)
self.net.load_state_dict(state_dict)
self.net.eval()
self.net = self.net.to(self.device)

@torch.no_grad()
def _detect(self, x: torch.Tensor,) -> typing.List[np.ndarray]:
def _detect(
self,
x: torch.Tensor,
) -> typing.List[np.ndarray]:
"""Batched detect
Args:
image (np.ndarray): shape [N, H, W, 3]
Expand All @@ -37,7 +36,5 @@ def _detect(self, x: torch.Tensor,) -> typing.List[np.ndarray]:
# Expects BGR
x = x[:, [2, 1, 0], :, :]
with torch.cuda.amp.autocast(enabled=self.fp16_inference):
boxes = self.net(
x, self.confidence_threshold, self.nms_iou_threshold
)
boxes = self.net(x, self.confidence_threshold, self.nms_iou_threshold)
return boxes
Loading
Loading