8888 [torch .as_tensor ([[0.0 , 1.0 , 0.0 ], [0.6667 , 0.0 , 0.4 ]]), torch .as_tensor ([[0.0 , 0.5 , 0.0 ], [0.3333 , 0.0 , 0.4 ]])],
8989]
9090
# 3D test cases: shape (batch=1, 2 channels, 2, 2, 2).
# One channel carries instance ids, the other class ids — same layout as the 2D samples above.
sample_3d_pred = torch.as_tensor(
    [[[[[2, 0], [1, 1]], [[0, 1], [2, 1]]], [[[0, 1], [3, 0]], [[1, 0], [1, 1]]]]],
    device=_device,
)

sample_3d_gt = torch.as_tensor(
    [[[[[2, 0], [0, 0]], [[2, 2], [2, 3]]], [[[3, 3], [3, 2]], [[2, 2], [3, 3]]]]],
    device=_device,
)

# 3D sample with num_classes = 3 and match_iou_threshold = 0.5
TEST_3D_CASE_1 = [{"num_classes": 3, "match_iou_threshold": 0.5}, sample_3d_pred, sample_3d_gt]

# case requesting the raw confusion matrix instead of the computed metrics
TEST_CM_CASE_1 = [
    {"num_classes": 3, "match_iou_threshold": 0.5, "return_confusion_matrix": True},
    sample_3_pred,
    sample_3_gt,
]
91112
92113@SkipIfNoModule ("scipy.optimize" )
93114class TestPanopticQualityMetric (unittest .TestCase ):
@@ -108,6 +129,98 @@ def test_value_class(self, input_params, y_pred, y_gt, expected_value):
108129 else :
109130 np .testing .assert_allclose (outputs .cpu ().numpy (), np .asarray (expected_value ), atol = 1e-4 )
110131
132+ def test_3d_support (self ):
133+ """Test that 3D input is properly supported."""
134+ input_params , y_pred , y_gt = TEST_3D_CASE_1
135+ metric = PanopticQualityMetric (** input_params )
136+ # Should not raise an error for 3D input
137+ metric (y_pred , y_gt )
138+ outputs = metric .aggregate ()
139+ # Check that output is a tensor
140+ self .assertIsInstance (outputs , torch .Tensor )
141+ # Check that output shape is correct (num_classes,)
142+ self .assertEqual (outputs .shape , torch .Size ([3 ]))
143+
144+ def test_confusion_matrix_return (self ):
145+ """Test that confusion matrix can be returned instead of computed metrics."""
146+ input_params , y_pred , y_gt = TEST_CM_CASE_1
147+ metric = PanopticQualityMetric (** input_params )
148+ metric (y_pred , y_gt )
149+ outputs = metric .aggregate ()
150+ # Check that output is a tensor with shape (batch_size, num_classes, 4)
151+ self .assertIsInstance (outputs , torch .Tensor )
152+ self .assertEqual (outputs .shape [- 1 ], 4 )
153+ # Verify that values correspond to [tp, fp, fn, iou_sum]
154+ tp , fp , fn , iou_sum = outputs [..., 0 ], outputs [..., 1 ], outputs [..., 2 ], outputs [..., 3 ]
155+ # tp, fp, fn should be non-negative integers
156+ self .assertTrue (torch .all (tp >= 0 ))
157+ self .assertTrue (torch .all (fp >= 0 ))
158+ self .assertTrue (torch .all (fn >= 0 ))
159+ # iou_sum should be non-negative float
160+ self .assertTrue (torch .all (iou_sum >= 0 ))
161+
162+ def test_compute_mean_iou (self ):
163+ """Test mean IoU computation from confusion matrix."""
164+ from monai .metrics .panoptic_quality import compute_mean_iou
165+
166+ input_params , y_pred , y_gt = TEST_CM_CASE_1
167+ metric = PanopticQualityMetric (** input_params )
168+ metric (y_pred , y_gt )
169+ confusion_matrix = metric .aggregate ()
170+ mean_iou = compute_mean_iou (confusion_matrix )
171+ # Check shape is correct
172+ self .assertEqual (mean_iou .shape , confusion_matrix .shape [:- 1 ])
173+ # Check values are non-negative
174+ self .assertTrue (torch .all (mean_iou >= 0 ))
175+
176+ def test_metric_name_filtering (self ):
177+ """Test that metric_name parameter properly filters output."""
178+ # Test single metric "sq"
179+ metric_sq = PanopticQualityMetric (num_classes = 3 , metric_name = "sq" , match_iou_threshold = 0.5 )
180+ metric_sq (sample_3_pred , sample_3_gt )
181+ result_sq = metric_sq .aggregate ()
182+ self .assertIsInstance (result_sq , torch .Tensor )
183+ self .assertEqual (result_sq .shape , torch .Size ([3 ]))
184+
185+ # Test single metric "rq"
186+ metric_rq = PanopticQualityMetric (num_classes = 3 , metric_name = "rq" , match_iou_threshold = 0.5 )
187+ metric_rq (sample_3_pred , sample_3_gt )
188+ result_rq = metric_rq .aggregate ()
189+ self .assertIsInstance (result_rq , torch .Tensor )
190+ self .assertEqual (result_rq .shape , torch .Size ([3 ]))
191+
192+ # Results should be different for different metrics
193+ self .assertFalse (torch .allclose (result_sq , result_rq , atol = 1e-4 ))
194+
195+ def test_invalid_3d_shape (self ):
196+ """Test that invalid 3D shapes are rejected."""
197+ # Shape with 3 dimensions should fail
198+ invalid_pred = torch .randint (0 , 5 , (2 , 2 , 10 ))
199+ invalid_gt = torch .randint (0 , 5 , (2 , 2 , 10 ))
200+ metric = PanopticQualityMetric (num_classes = 3 )
201+ with self .assertRaises (ValueError ):
202+ metric (invalid_pred , invalid_gt )
203+
204+ # Shape with 6 dimensions should fail
205+ invalid_pred = torch .randint (0 , 5 , (1 , 2 , 8 , 8 , 8 , 8 ))
206+ invalid_gt = torch .randint (0 , 5 , (1 , 2 , 8 , 8 , 8 , 8 ))
207+ with self .assertRaises (ValueError ):
208+ metric (invalid_pred , invalid_gt )
209+
210+ def test_compute_mean_iou_invalid_shape (self ):
211+ """Test that compute_mean_iou raises ValueError for invalid shapes."""
212+ from monai .metrics .panoptic_quality import compute_mean_iou
213+
214+ # Shape (..., 3) instead of (..., 4) should fail
215+ invalid_confusion_matrix = torch .zeros (3 , 3 )
216+ with self .assertRaises (ValueError ):
217+ compute_mean_iou (invalid_confusion_matrix )
218+
219+ # Shape (..., 5) should also fail
220+ invalid_confusion_matrix = torch .zeros (2 , 5 )
221+ with self .assertRaises (ValueError ):
222+ compute_mean_iou (invalid_confusion_matrix )
223+
111224
# standard unittest entry point for running this module directly
if __name__ == "__main__":
    unittest.main()
# (removed web-page artifact: "0 commit comments" — not part of the source file)