Skip to content

Commit 8b0d937

Browse files
committed
feat: implement ignore_index and resolve conflicts
Signed-off-by: Rusheel Sharma <rusheelhere@gmail.com>
1 parent 886fc98 commit 8b0d937

File tree

1 file changed

+0
-110
lines changed

1 file changed

+0
-110
lines changed
Lines changed: 0 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -1,110 +0,0 @@
1-
# Copyright (c) MONAI Consortium
2-
# Licensed under the Apache License, Version 2.0 (the "License");
3-
# you may not use this file except in compliance with the License.
4-
# You may obtain a copy of the License at
5-
# http://www.apache.org/licenses/LICENSE-2.0
6-
# Unless required by applicable law or agreed to in writing, software
7-
# distributed under the License is distributed on an "AS IS" BASIS,
8-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9-
# See the License for the specific language governing permissions and
10-
# limitations under the License.
11-
"""
12-
Tests for pad_nd dtype support and backend selection.
13-
Validates PyTorch padding preference and NumPy fallback behavior.
14-
"""
15-
16-
from __future__ import annotations
17-
18-
import unittest
19-
from unittest.mock import Mock, patch
20-
21-
import torch
22-
from parameterized.parameterized import parameterized
23-
24-
import monai.transforms.croppad.functional as F
25-
from monai.transforms.croppad.functional import pad_nd
26-
27-
DTYPES = [torch.bool, torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8, torch.float32]
28-
MODES_DTYPES = [
29-
("constant", torch.bool),
30-
("constant", torch.int8),
31-
("constant", torch.float32),
32-
("reflect", torch.bool),
33-
("reflect", torch.int8),
34-
("reflect", torch.float32),
35-
("replicate", torch.bool),
36-
("replicate", torch.int8),
37-
("replicate", torch.float32),
38-
]
39-
40-
41-
class TestPadNdDtypes(unittest.TestCase):
42-
def test_pad_uses_pt_for_bool(self):
43-
"""Test that pad_nd uses PyTorch backend for bool dtype in constant mode."""
44-
img = torch.ones((1, 4, 4), dtype=torch.bool)
45-
to_pad = [(0, 0), (1, 1), (2, 2)]
46-
with (
47-
patch.object(F, "_pt_pad", wraps=F._pt_pad) as mock_pt,
48-
patch.object(F, "_np_pad", wraps=F._np_pad) as mock_np,
49-
):
50-
out = pad_nd(img, to_pad, mode="constant", value=0)
51-
52-
self.assertTrue(mock_pt.called)
53-
self.assertFalse(mock_np.called)
54-
self.assertEqual(out.dtype, img.dtype)
55-
self.assertEqual(out.shape, (1, 6, 8))
56-
57-
def test_pad_falls_back_to_np_if_pt_raises(self):
58-
"""Test that pad_nd falls back to NumPy when PyTorch raises NotImplementedError."""
59-
img = torch.ones((1, 4, 4), dtype=torch.bool)
60-
to_pad = [(0, 0), (1, 1), (2, 2)]
61-
with (
62-
patch.object(F, "_pt_pad", new=Mock(side_effect=NotImplementedError("no"))) as mock_pt,
63-
patch.object(F, "_np_pad", wraps=F._np_pad) as mock_np,
64-
):
65-
out = pad_nd(img, to_pad, mode="constant", value=0)
66-
67-
self.assertTrue(mock_pt.called)
68-
self.assertTrue(mock_np.called)
69-
self.assertEqual(out.dtype, img.dtype)
70-
self.assertEqual(out.shape, (1, 6, 8))
71-
72-
@parameterized.expand(DTYPES)
73-
def test_pad_dtype_no_error_and_dtype_preserved(self, dtype):
74-
"""Test that pad_nd handles various dtypes without error and preserves dtype.
75-
Args:
76-
dtype: Input dtype under test.
77-
"""
78-
img = torch.ones((1, 4, 4), dtype=dtype)
79-
to_pad = [(0, 0), (1, 1), (2, 2)]
80-
out = pad_nd(img, to_pad, mode="constant", value=0)
81-
82-
self.assertEqual(out.shape, (1, 6, 8))
83-
self.assertEqual(out.dtype, img.dtype)
84-
85-
@parameterized.expand(MODES_DTYPES)
86-
def test_pad_multiple_modes_dtype_preserved(self, mode, dtype):
87-
"""Test that pad_nd preserves dtype across multiple padding modes.
88-
Args:
89-
mode: Padding mode under test.
90-
dtype: Input dtype under test.
91-
"""
92-
img = torch.ones((1, 4, 4), dtype=dtype)
93-
to_pad = [(0, 0), (1, 1), (2, 2)]
94-
95-
kwargs = {"value": 0} if mode == "constant" else {}
96-
out = pad_nd(img, to_pad, mode=mode, **kwargs)
97-
98-
self.assertEqual(out.shape, (1, 6, 8))
99-
self.assertEqual(out.dtype, img.dtype)
100-
101-
def test_value_with_non_constant_mode_raises(self):
102-
"""Test that pad_nd raises ValueError when 'value' is provided with non-constant mode."""
103-
img = torch.ones((1, 4, 4))
104-
to_pad = [(0, 0), (1, 1), (2, 2)]
105-
with self.assertRaises(ValueError):
106-
pad_nd(img, to_pad, mode="reflect", value=0)
107-
108-
109-
if __name__ == "__main__":
110-
unittest.main()

0 commit comments

Comments
 (0)