Rotated bounding box NMS implementation for CPU#9450

Open
zy1git wants to merge 25 commits into pytorch:main from zy1git:rotated-NMS

Conversation

Contributor

@zy1git zy1git commented Mar 23, 2026

Summary:
Implemented rotated box NMS (Non-Maximum Suppression) for CPU, adapted from Detectron2's nms_rotated implementation. The NMS algorithm is identical to standard NMS — sort by scores, suppress overlapping boxes — but uses single_box_iou_rotated for IoU computation instead of axis-aligned intersection. The public API follows the existing nms op pattern in TorchVision.

Test Plan:

Added TestNMSRotated test class adapted from Detectron2's test suite:

  • 0° rotation test: rotated NMS with angle=0 should match reference horizontal NMS (IoU thresholds 0.2, 0.5, 0.8)

  • 90° rotation test: rotated NMS with angle=90 and swapped width/height should match reference horizontal NMS

  • 180° rotation test: rotated NMS with angle=180 should match reference horizontal NMS

Run pytest test/test_ops.py::TestNMSRotated -v
All tests pass locally.


pytorch-bot bot commented Mar 23, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/9450

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit aea073c with merge base d7400a3:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@zy1git zy1git marked this pull request as draft March 23, 2026 04:28
@meta-cla meta-cla bot added the cla signed label Mar 23, 2026

auto ovr = single_box_iou_rotated<scalar_t>(
dets[i].data_ptr<scalar_t>(), dets[j].data_ptr<scalar_t>());
if (ovr >= iou_threshold) {

Member

Flagging that this is different from the iou threshold comparison we have in the non-rotated case:

if (ovr > iou_threshold) {

See my other comment about unifying the implementation, which should resolve this as a consequence.

Contributor Author

Yes. It is resolved in the new commit after unifying the implementation.
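
Illustrative sketch (not code from this PR) of why the two comparisons can disagree: they differ only when the IoU lands exactly on the threshold. The boxes and threshold below are made-up values chosen so that the IoU is exactly 0.5.

```python
# Two axis-aligned boxes in (x1, y1, x2, y2) form; values are hypothetical.
top = (0.0, 0.0, 2.0, 2.0)    # higher-scoring box, area 4
other = (0.0, 0.0, 2.0, 1.0)  # lower-scoring box, area 2, fully inside `top`
iou_threshold = 0.5

# Intersection is `other` itself; union is area(top) + area(other) - inter.
inter = 2.0 * 1.0
union = 4.0 + 2.0 - inter
ovr = inter / union  # exactly 0.5

suppressed_strict = ovr > iou_threshold      # comparison in the non-rotated kernel
suppressed_inclusive = ovr >= iou_threshold  # comparison flagged in this thread

print(suppressed_strict, suppressed_inclusive)  # False True
```

With the strict comparison the lower-scoring box survives; with the inclusive one it is suppressed, so the two kernels would return different `keep` indices for this input.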

namespace {

template <typename scalar_t>
at::Tensor nms_rotated_cpu_kernel(

Member

This is exactly the same implementation we already have for the non-rotated case, the only difference being the iou computation:

at::Tensor nms_kernel_impl(

Could we consider fusing the two implementations, perhaps templating over the iou computation function?

Contributor Author

Yes, I had the same thought when I did the implementation. I wanted to stick to Detectron2's implementation in the first version to make sure it works correctly, then refactor. I have fused the two implementations in the new commit.
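
Illustrative sketch (not the PR's C++ code) of the idea discussed here: one greedy NMS loop that takes the IoU computation as a parameter. The names `greedy_nms` and `aligned_iou` are hypothetical; the real kernel templates over an IoU functor in C++, while this Python version passes a callable.

```python
from typing import Callable, List, Sequence

Box = Sequence[float]


def aligned_iou(a: Box, b: Box) -> float:
    """IoU of two axis-aligned (x1, y1, x2, y2) boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union


def greedy_nms(
    boxes: Sequence[Box],
    scores: Sequence[float],
    iou_threshold: float,
    iou_fn: Callable[[Box, Box], float],
) -> List[int]:
    """Greedy NMS: visit boxes by decreasing score, drop heavy overlaps."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    suppressed = [False] * len(boxes)
    keep = []
    for pos, i in enumerate(order):
        if suppressed[i]:
            continue
        keep.append(i)
        # Only lower-scoring boxes can be suppressed by box i.
        for j in order[pos + 1:]:
            if not suppressed[j] and iou_fn(boxes[i], boxes[j]) > iou_threshold:
                suppressed[j] = True
    return keep


boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]
scores = [0.9, 0.8, 0.7]
keep_05 = greedy_nms(boxes, scores, 0.5, aligned_iou)  # IoU(0, 1) = 81/119 > 0.5
keep_07 = greedy_nms(boxes, scores, 0.7, aligned_iou)  # 81/119 < 0.7
print(keep_05, keep_07)  # [0, 2] [0, 1, 2]
```

Swapping in a rotated IoU callable would leave the loop untouched, which is the point of the fusion suggested above.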

return torch.ops.torchvision.nms(boxes, scores, iou_threshold)


def nms_rotated(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:

Member

any reason to expose nms_rotated instead of just handling all this within a single nms function?

For iou, we chose not to expose iou_rotated at the Python layer.

Contributor Author

Thanks for pointing this out. I fixed this in the new commit.



- def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
+ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float, fmt: str = "xyxy") -> Tensor:

Member

Can we just infer the format based on the .shape attribute? If shape[-2] is 4 then the boxes are aligned; if it is 5, they are non-aligned?
I'm not sure; we'd have to verify that this doesn't break any other convention. For example, was shape[-2] == 5 even allowed before? If not, we're clear.

Contributor Author

This is a good point. I think it should be shape[-1] rather than shape[-2]. I checked that shape[-1] == 5 was not allowed before (the kernel enforces size(1) == 4). Since a last dimension of 4 or 5 each maps to exactly one format, this works. I have implemented it in the new commit.
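
Illustrative sketch (not the PR's implementation) of inferring the box format from the last dimension. The helper name `infer_box_format` and the label "cxcywhr" for the 5-column rotated layout are assumptions made for this example; "xyxy" is the default format name quoted above.

```python
from typing import Tuple


def infer_box_format(shape: Tuple[int, ...]) -> str:
    """Map the shape of a 2-D boxes tensor to a format label."""
    if len(shape) != 2:
        raise ValueError(f"boxes should be a 2-D tensor, got shape {shape}")
    if shape[-1] == 4:
        return "xyxy"     # axis-aligned boxes
    if shape[-1] == 5:
        return "cxcywhr"  # rotated boxes (hypothetical label)
    raise ValueError(f"boxes should have 4 or 5 columns, got {shape[-1]}")


print(infer_box_format((1000, 4)))  # xyxy
print(infer_box_format((1000, 5)))  # cxcywhr
```

Because a 5-column input was rejected before, dispatching on it does not change behavior for any input that previously worked.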

@zy1git zy1git marked this pull request as ready for review March 26, 2026 04:15
test/test_ops.py Outdated


class TestNMSRotated:
    def _reference_horizontal_nms(self, boxes, scores, iou_threshold):

Member

Use TestNMS._reference_nms instead. You might have to make it a class method instead of an instance method.

test/test_ops.py Outdated
Comment on lines +2014 to +2015
box_scores (N, 5): boxes in corner-form and probabilities.
(Note here 5 == 4 + 1, i.e., 4-dim horizontal box + 1-dim prob)

Member

This will probably be resolved automatically when we use TestNMS._reference_nms, but flagging that this box_scores parameter doesn't exist.

Contributor Author

Good point. box_scores doesn't exist as a parameter. The docstring is incorrect. Detectron2 has the same issue: https://github.com/facebookresearch/detectron2/blob/main/tests/layers/test_nms_rotated.py#L46-L48

test/test_ops.py Outdated
m, n = len(keep1), len(keep2)

# edit distance with DP
f = [np.arange(n + 1), np.arange(n + 1)]

Member

don't use numpy

Contributor Author

Got it. Will pay attention to this in future.

test/test_ops.py Outdated

keep_ref = self._reference_horizontal_nms(boxes, scores, iou)
keep = ops.nms(rotated_boxes, scores, iou)
assert self._nms_edit_distance(keep, keep_ref) <= 1

Member

Locally, I am able to use torch.testing.assert_close(..., atol=0, rtol=0) on many random seeds; we should consider using this instead of edit distance.

Contributor Author

Fixed in the new commit.

test/test_ops.py Outdated
return boxes, scores

@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
def test_nms_rotated_0_degree(self, iou):

Member

This is comparing our rotated implementation against _reference_horizontal_nms. We should also have a test that uses our non-rotated nms implementation as the reference.

Contributor Author

I agree. Fixed in the new commit.

test/test_ops.py Outdated
return torch.as_tensor(picked)

@staticmethod
def _nms_edit_distance(keep1, keep2):

Member

Contributor Author

Removed in the new commit.

- auto inter = w * h;
- auto ovr = inter / (iarea + areas[j] - inter);
+ auto ovr = iou_func.compare(j);

Member

Nit: call this compute rather than compare

Contributor Author

fixed in the new commit.


scalar_t ix1, iy1, ix2, iy2, iarea;

AABBIoU(const at::Tensor& dets) {

Member

Call this NonRotatedIoU or AlignedIoU

Contributor Author

fixed in the new commit.


RotatedIoU(const at::Tensor& dets) : dets_ptr(&dets) {}

int64_t cached_i;

Member

just call this i

Contributor Author

fixed in the new commit.

test/test_ops.py Outdated
torch.testing.assert_close(iou_cpu, iou_cuda.cpu(), atol=1e-5, rtol=1e-5)


class TestNMSRotated:

Member

Now that we are exposing this through the same nms() method, we should just move these tests within TestNMS.

Contributor Author

Addressed in the new commit.

test/test_ops.py Outdated
  class TestNMS:
-     def _reference_nms(self, boxes, scores, iou_threshold):
+     @classmethod
+     def _reference_nms(cls, boxes, scores, iou_threshold):

Member

now that we're moving the tests, we should rename this to _reference_aligned_nms or something like that

Contributor Author

Addressed in the new commit.

test/test_ops.py Outdated
Comment on lines +2040 to +2047
boxes, scores = self._create_tensors(N)
rotated_boxes = torch.zeros(N, 5)
rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0
rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0
# Swap width and height for 90 degrees so reference horizontal NMS can be used
rotated_boxes[:, 2] = boxes[:, 3] - boxes[:, 1]
rotated_boxes[:, 3] = boxes[:, 2] - boxes[:, 0]
rotated_boxes[:, 4] = 90

Member

Instead of duplicating this logic everywhere, let's create aligned boxes, call a box format conversion function, and reset the angle column as desired. We probably want to do that in the _create_tensors() helper, which should be renamed to _create_rotated_boxes().
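
Illustrative sketch (not code from this PR) of the conversion the snippet above repeats: turning aligned (x1, y1, x2, y2) boxes into the (cx, cy, w, h, angle) layout and setting the angle in one place. The helper name `xyxy_to_rotated` is hypothetical, and plain tuples stand in for tensors.

```python
from typing import List, Sequence, Tuple

RotatedBox = Tuple[float, float, float, float, float]


def xyxy_to_rotated(
    boxes: Sequence[Sequence[float]], angle: float = 0.0
) -> List[RotatedBox]:
    """Convert (x1, y1, x2, y2) boxes to (cx, cy, w, h, angle) boxes."""
    rotated = []
    for x1, y1, x2, y2 in boxes:
        cx = (x1 + x2) / 2.0
        cy = (y1 + y2) / 2.0
        rotated.append((cx, cy, x2 - x1, y2 - y1, angle))
    return rotated


aligned = [(0.0, 0.0, 4.0, 2.0), (1.0, 1.0, 3.0, 5.0)]
print(xyxy_to_rotated(aligned))
# [(2.0, 1.0, 4.0, 2.0, 0.0), (2.0, 3.0, 2.0, 4.0, 0.0)]
```

A test could then call such a helper once and overwrite the angle column, rather than rebuilding the five columns by hand in each test.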

test/test_ops.py Outdated
torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0)

@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
def test_nms_rotated_90_degrees(self, iou):

Member

Parametrize over the angle instead of duplicating the test N times. Also, it doesn't seem necessary for the angles to be a multiple of 90 for this to work. I think it'll work with any angle value: everything works because the angle is the same across boxes, but its actual value doesn't matter (I could be wrong, please double check).

Contributor Author

I parametrized over the angle in the new commit. However, I think we can only parametrize over 0, 90, and 180. For other angles such as 45°, even though every box is rotated by the same angle, each box rotates around its own center (and the centers are at different positions), so the overlap areas change and we can't compare against axis-aligned NMS. For 90°, we swap width and height so that the rotated boxes cover the same regions as the axis-aligned boxes, so the NMS results are the same.
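
Illustrative sketch (not code from this PR) of the argument above, for a single box: it computes the corners of a (cx, cy, w, h, angle) box and checks whether they coincide with the corners of the original aligned box. The helper `rotated_corners` is hypothetical, and the rotation direction is an assumption; it does not affect the 0°, 90° and 180° cases.

```python
import math
from typing import Set, Tuple


def rotated_corners(
    cx: float, cy: float, w: float, h: float, angle_deg: float
) -> Set[Tuple[float, float]]:
    """Corner set of a box rotated by angle_deg around its own center."""
    theta = math.radians(angle_deg)
    c, s = math.cos(theta), math.sin(theta)
    corners = set()
    for dx, dy in ((-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2), (-w / 2, h / 2)):
        x = cx + dx * c - dy * s
        y = cy + dx * s + dy * c
        # Round away floating-point noise; "+ 0.0" normalizes -0.0 to 0.0.
        corners.add((round(x, 9) + 0.0, round(y, 9) + 0.0))
    return corners


# Aligned box (x1, y1, x2, y2) = (0, 0, 4, 2): center (2, 1), w = 4, h = 2.
aligned = {(0.0, 0.0), (4.0, 0.0), (4.0, 2.0), (0.0, 2.0)}

same_region_90 = rotated_corners(2.0, 1.0, 2.0, 4.0, 90.0) == aligned   # w/h swapped
same_region_180 = rotated_corners(2.0, 1.0, 4.0, 2.0, 180.0) == aligned
same_region_45 = rotated_corners(2.0, 1.0, 4.0, 2.0, 45.0) == aligned

print(same_region_90, same_region_180, same_region_45)  # True True False
```

At 90° (with width and height swapped) and 180° the box covers the same region as the aligned box, so pairwise IoUs are unchanged; at 45° the region itself changes, which is why the aligned reference no longer applies.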

test/test_ops.py Outdated
torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0)
keep_non_rotated = ops.nms(boxes, scores, iou)
torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0)

Member

we should add tests for when the boxes have a different rotation angle, currently all tests have the same angle.

Contributor Author

@zy1git zy1git Apr 1, 2026

I added test_nms_rotated_different_angles, which verifies basic NMS properties, since we can't compare against axis-aligned NMS when boxes have different angles. I also added test_nms_rotated_specific_angles to show that NMS works as expected for specific angles. Please let me know if there are other tests for different rotation angles that you'd like me to add.
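
Illustrative sketch (not the PR's tests or kernel) of a hand-checkable case with two different angles. It computes rotated IoU by clipping one box's polygon against the other (Sutherland-Hodgman) and measuring the result with the shoelace formula, which is a swapped-in technique for this example rather than the method used by single_box_iou_rotated. All function names are hypothetical. Two 2x2 squares with the same center, one rotated by 45°, intersect in a regular octagon, giving an IoU of 1/sqrt(2).

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def corners(box: Sequence[float]) -> List[Point]:
    """Counter-clockwise corners of a (cx, cy, w, h, angle_deg) box."""
    cx, cy, w, h, angle = box
    t = math.radians(angle)
    c, s = math.cos(t), math.sin(t)
    offs = ((-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2), (-w / 2, h / 2))
    return [(cx + dx * c - dy * s, cy + dx * s + dy * c) for dx, dy in offs]


def clip(subject: List[Point], clipper: List[Point]) -> List[Point]:
    """Clip a polygon against a convex counter-clockwise polygon."""
    output = subject
    for k in range(len(clipper)):
        a, b = clipper[k], clipper[(k + 1) % len(clipper)]
        inp, output = output, []
        for idx in range(len(inp)):
            p, q = inp[idx - 1], inp[idx]
            # Signed side of each point relative to edge a->b (>= 0 is inside).
            sp = (b[0] - a[0]) * (p[1] - a[1]) - (b[1] - a[1]) * (p[0] - a[0])
            sq = (b[0] - a[0]) * (q[1] - a[1]) - (b[1] - a[1]) * (q[0] - a[0])
            if (sp < 0) != (sq < 0):
                # Segment p->q crosses the edge line: add the crossing point.
                t = sp / (sp - sq)
                output.append((p[0] + t * (q[0] - p[0]), p[1] + t * (q[1] - p[1])))
            if sq >= 0:
                output.append(q)
    return output


def polygon_area(poly: List[Point]) -> float:
    """Shoelace formula; returns 0.0 for an empty polygon."""
    total = sum(
        poly[i - 1][0] * poly[i][1] - poly[i][0] * poly[i - 1][1]
        for i in range(len(poly))
    )
    return abs(total) / 2.0


def rotated_iou(a: Sequence[float], b: Sequence[float]) -> float:
    inter = polygon_area(clip(corners(a), corners(b)))
    return inter / (a[2] * a[3] + b[2] * b[3] - inter)


square = (0.0, 0.0, 2.0, 2.0, 0.0)
diamond = (0.0, 0.0, 2.0, 2.0, 45.0)
iou = rotated_iou(square, diamond)
print(round(iou, 4))  # 0.7071, i.e. 1/sqrt(2)
```

With this pair, a threshold of 0.5 would suppress the lower-scoring box and a threshold of 0.8 would keep both, which gives a test an exact expected `keep` for boxes at different angles.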

test/test_ops.py Outdated
torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0)

@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
def test_batched_nms_rotated_0_degree(self, iou):

Member

Let's also parametrize the existing test_batched_nms_implementations test over the format (one rotated, one non-rotated).

Contributor Author

Addressed in the new commit

test/test_ops.py Outdated
backup = rotated_boxes.clone()
keep_non_rotated = ops.batched_nms(boxes, scores, idxs, iou)
keep = ops.batched_nms(rotated_boxes, scores, idxs, iou)
assert torch.allclose(rotated_boxes, backup)

Member

torch.testing.assert_close

Contributor Author

fixed in the new commit.

test/test_ops.py Outdated
def test_nms_rotated_different_angles(self, iou):
    torch.manual_seed(0)
    N = 1000
    boxes, rotated_boxes, scores = self._create_rotated_boxes(N)

Member

boxes are not used

Contributor Author

Fixed by assigning the unused value to "_" in the new commit.

test/test_ops.py Outdated
torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0)

@pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
def test_batched_nms_rotated_0_degree(self, iou):

Member

sounds like this could be tested on 90 and 180 as well. Consider parametrizing with test_nms_rotated if it doesn't complicate the code too much.

Contributor Author

Addressed in the new commit.

test/test_ops.py Outdated
Comment on lines +960 to +965
if angle == 90:
    rotated_boxes[:, 2] = cxcywh[:, 3]
    rotated_boxes[:, 3] = cxcywh[:, 2]
else:
    rotated_boxes[:, 2] = cxcywh[:, 2]
    rotated_boxes[:, 3] = cxcywh[:, 3]

Member

This trick should be explained. I suspect we'll want to move it out of _create_rotated_boxes, as it's mainly used to make the existing test_nms_rotated work.

Contributor Author

Addressed in the new commit.

Comment on lines +129 to +130
(*dets_ptr)[i].template data_ptr<scalar_t>(),
(*dets_ptr)[j].template data_ptr<scalar_t>());

Member

can you explain the .template syntax?

Contributor Author

I added .template to fix a CI failure we hit on the macOS build. It explicitly indicates that data_ptr<scalar_t> is a template method. Without it, the compiler can't tell whether data_ptr<scalar_t> is a template method call or a less-than comparison.

2 participants