|
53 | 53 | SqueezeDimd, |
54 | 54 | ToDeviced, |
55 | 55 | ) |
56 | | -from monai.utils.enums import DataStatsKeys |
| 56 | +from monai.utils.enums import DataStatsKeys, LabelStatsKeys |
57 | 57 | from tests.test_utils import skip_if_no_cuda |
58 | 58 |
|
59 | 59 | device = "cpu" |
|
78 | 78 |
|
79 | 79 | SIM_GPU_TEST_CASES = [[{"sim_dim": (32, 32, 32), "label_key": "label"}], [{"sim_dim": (32, 32, 32), "label_key": None}]] |
80 | 80 |
|
| 81 | +LABEL_STATS_DEVICE_TEST_CASES = [ |
| 82 | + [{"image_device": "cpu", "label_device": "cpu", "image_meta": False}], |
| 83 | + [{"image_device": "cuda", "label_device": "cuda", "image_meta": True}], |
| 84 | + [{"image_device": "cpu", "label_device": "cuda", "image_meta": True}], |
| 85 | + [{"image_device": "cuda", "label_device": "cpu", "image_meta": False}], |
| 86 | +] |
| 87 | + |
81 | 88 |
|
82 | 89 | def create_sim_data(dataroot: str, sim_datalist: dict, sim_dim: tuple, image_only: bool = False, **kwargs) -> None: |
83 | 90 | """ |
@@ -360,6 +367,50 @@ def test_label_stats_case_analyzer(self): |
360 | 367 | report_format = analyzer.get_report_format() |
361 | 368 | assert verify_report_format(d["label_stats"], report_format) |
362 | 369 |
|
| 370 | + @parameterized.expand(LABEL_STATS_DEVICE_TEST_CASES) |
| 371 | + def test_label_stats_mixed_device_analyzer(self, input_params): |
| 372 | + image_device = torch.device(input_params["image_device"]) |
| 373 | + label_device = torch.device(input_params["label_device"]) |
| 374 | + |
| 375 | + if (image_device.type == "cuda" or label_device.type == "cuda") and not torch.cuda.is_available(): |
| 376 | + self.skipTest("CUDA is not available for mixed-device LabelStats tests.") |
| 377 | + |
| 378 | + analyzer = LabelStats(image_key="image", label_key="label") |
| 379 | + |
| 380 | + image_tensor = torch.tensor( |
| 381 | + [ |
| 382 | + [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], |
| 383 | + [[[11.0, 12.0], [13.0, 14.0]], [[15.0, 16.0], [17.0, 18.0]]], |
| 384 | + ], |
| 385 | + dtype=torch.float32, |
| 386 | + ).to(image_device) |
| 387 | + label_tensor = torch.tensor([[[0, 1], [1, 0]], [[0, 1], [0, 1]]], dtype=torch.int64).to(label_device) |
| 388 | + |
| 389 | + if input_params["image_meta"]: |
| 390 | + image_tensor = MetaTensor(image_tensor) |
| 391 | + label_tensor = MetaTensor(label_tensor) |
| 392 | + |
| 393 | + result = analyzer({"image": image_tensor, "label": label_tensor}) |
| 394 | + report = result["label_stats"] |
| 395 | + |
| 396 | + assert verify_report_format(report, analyzer.get_report_format()) |
| 397 | + assert report[LabelStatsKeys.LABEL_UID] == [0, 1] |
| 398 | + |
| 399 | + label_stats = report[LabelStatsKeys.LABEL] |
| 400 | + self.assertAlmostEqual(label_stats[0][LabelStatsKeys.PIXEL_PCT], 0.5) |
| 401 | + self.assertAlmostEqual(label_stats[1][LabelStatsKeys.PIXEL_PCT], 0.5) |
| 402 | + |
| 403 | + label0_intensity = label_stats[0][LabelStatsKeys.IMAGE_INTST] |
| 404 | + label1_intensity = label_stats[1][LabelStatsKeys.IMAGE_INTST] |
| 405 | + self.assertAlmostEqual(label0_intensity[0]["mean"], 4.25) |
| 406 | + self.assertAlmostEqual(label1_intensity[0]["mean"], 4.75) |
| 407 | + self.assertAlmostEqual(label0_intensity[1]["mean"], 14.25) |
| 408 | + self.assertAlmostEqual(label1_intensity[1]["mean"], 14.75) |
| 409 | + |
| 410 | + foreground_stats = report[LabelStatsKeys.IMAGE_INTST] |
| 411 | + self.assertAlmostEqual(foreground_stats[0]["mean"], 4.75) |
| 412 | + self.assertAlmostEqual(foreground_stats[1]["mean"], 14.75) |
| 413 | + |
363 | 414 | def test_filename_case_analyzer(self): |
364 | 415 | analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH) |
365 | 416 | analyzer_label = FilenameStats("label", DataStatsKeys.BY_CASE_IMAGE_PATH) |
|
0 commit comments