Skip to content

Commit da88dbe

Browse files
Add missing channel_wise parameter to RandScaleIntensityFixedMean (#8363)
The channel_wise parameter was documented in the docstring but not actually accepted by RandScaleIntensityFixedMean or its dictionary variant. This adds the parameter with per-channel random factor generation and scaling, matching the existing pattern in RandScaleIntensity. Also fixes docstring indentation. Signed-off-by: Mohamed Salah <eng.mohamed.tawab@gmail.com>
1 parent 2147c11 commit da88dbe

File tree

4 files changed

+91
-9
lines changed

4 files changed

+91
-9
lines changed

monai/transforms/intensity/array.py

Lines changed: 28 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -601,6 +601,7 @@ def __init__(
601601
factors: Sequence[float] | float = 0,
602602
fixed_mean: bool = True,
603603
preserve_range: bool = False,
604+
channel_wise: bool = False,
604605
dtype: DtypeLike = np.float32,
605606
) -> None:
606607
"""
@@ -611,8 +612,8 @@ def __init__(
611612
fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
612613
to ensure that the output has the same mean as the input.
613614
channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
614-
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
615-
channel of the image if True.
615+
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
616+
channel of the image if True.
616617
dtype: output data type, if None, same as input image. defaults to float32.
617618
618619
"""
@@ -626,29 +627,51 @@ def __init__(
626627
self.factor = self.factors[0]
627628
self.fixed_mean = fixed_mean
628629
self.preserve_range = preserve_range
630+
self.channel_wise = channel_wise
629631
self.dtype = dtype
630632

631633
self.scaler = ScaleIntensityFixedMean(
632-
factor=self.factor, fixed_mean=self.fixed_mean, preserve_range=self.preserve_range, dtype=self.dtype
634+
factor=self.factor,
635+
fixed_mean=self.fixed_mean,
636+
preserve_range=self.preserve_range,
637+
channel_wise=self.channel_wise,
638+
dtype=self.dtype,
633639
)
634640

635641
def randomize(self, data: Any | None = None) -> None:
636642
super().randomize(None)
637643
if not self._do_transform:
638644
return None
639-
self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])
645+
if self.channel_wise:
646+
self.factor = [self.R.uniform(low=self.factors[0], high=self.factors[1]) for _ in range(data.shape[0])] # type: ignore
647+
else:
648+
self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])
640649

641650
def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
642651
"""
643652
Apply the transform to `img`.
644653
"""
645654
img = convert_to_tensor(img, track_meta=get_track_meta())
646655
if randomize:
647-
self.randomize()
656+
self.randomize(img)
648657

649658
if not self._do_transform:
650659
return convert_data_type(img, dtype=self.dtype)[0]
651660

661+
if self.channel_wise:
662+
out = []
663+
for i, d in enumerate(img):
664+
out_channel = ScaleIntensityFixedMean(
665+
factor=self.factor[i], # type: ignore
666+
fixed_mean=self.fixed_mean,
667+
preserve_range=self.preserve_range,
668+
dtype=self.dtype,
669+
)(d[None])[0]
670+
out.append(out_channel)
671+
ret: NdarrayOrTensor = torch.stack(out)
672+
ret = convert_to_dst_type(ret, dst=img, dtype=self.dtype or img.dtype)[0]
673+
return ret
674+
652675
return self.scaler(img, self.factor)
653676

654677

monai/transforms/intensity/dictionary.py

Lines changed: 17 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -669,6 +669,7 @@ def __init__(
669669
factors: Sequence[float] | float,
670670
fixed_mean: bool = True,
671671
preserve_range: bool = False,
672+
channel_wise: bool = False,
672673
prob: float = 0.1,
673674
dtype: DtypeLike = np.float32,
674675
allow_missing_keys: bool = False,
@@ -683,8 +684,8 @@ def __init__(
683684
fixed_mean: subtract the mean intensity before scaling with `factor`, then add the same value after scaling
684685
to ensure that the output has the same mean as the input.
685686
channel_wise: if True, scale on each channel separately. `preserve_range` and `fixed_mean` are also applied
686-
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
687-
channel of the image if True.
687+
on each channel separately if `channel_wise` is True. Please ensure that the first dimension represents the
688+
channel of the image if True.
688689
dtype: output data type, if None, same as input image. defaults to float32.
689690
allow_missing_keys: don't raise exception if key is missing.
690691
@@ -694,7 +695,12 @@ def __init__(
694695
self.fixed_mean = fixed_mean
695696
self.preserve_range = preserve_range
696697
self.scaler = RandScaleIntensityFixedMean(
697-
factors=factors, fixed_mean=self.fixed_mean, preserve_range=preserve_range, dtype=dtype, prob=1.0
698+
factors=factors,
699+
fixed_mean=self.fixed_mean,
700+
preserve_range=preserve_range,
701+
channel_wise=channel_wise,
702+
dtype=dtype,
703+
prob=1.0,
698704
)
699705

700706
def set_random_state(
@@ -712,8 +718,15 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
712718
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
713719
return d
714720

721+
# expect all the specified keys have same spatial shape and share same random factors
722+
first_key: Hashable = self.first_key(d)
723+
if first_key == ():
724+
for key in self.key_iterator(d):
725+
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
726+
return d
727+
715728
# all the keys share the same random scale factor
716-
self.scaler.randomize(None)
729+
self.scaler.randomize(d[first_key])
717730
for key in self.key_iterator(d):
718731
d[key] = self.scaler(d[key], randomize=False)
719732
return d

tests/transforms/test_rand_scale_intensity_fixed_mean.py

Lines changed: 29 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -36,6 +36,35 @@ def test_value(self, p):
3636
expected = expected + mn
3737
assert_allclose(result, expected, type_test="tensor", atol=1e-7)
3838

39+
@parameterized.expand([[p] for p in TEST_NDARRAYS])
40+
def test_channel_wise(self, p):
41+
scaler = RandScaleIntensityFixedMean(prob=1.0, factors=0.5, channel_wise=True)
42+
scaler.set_random_state(seed=0)
43+
im = p(self.imt)
44+
result = scaler(im)
45+
np.random.seed(0)
46+
# simulate the randomize() of transform
47+
np.random.random()
48+
channel_num = self.imt.shape[0]
49+
factor = [np.random.uniform(low=-0.5, high=0.5) for _ in range(channel_num)]
50+
expected = np.stack(
51+
[np.asarray((self.imt[i] - self.imt[i].mean()) * (1 + factor[i]) + self.imt[i].mean()) for i in range(channel_num)]
52+
).astype(np.float32)
53+
assert_allclose(result, p(expected), atol=1e-4, rtol=1e-4, type_test=False)
54+
55+
@parameterized.expand([[p] for p in TEST_NDARRAYS])
56+
def test_channel_wise_preserve_range(self, p):
57+
scaler = RandScaleIntensityFixedMean(
58+
prob=1.0, factors=0.5, channel_wise=True, preserve_range=True, fixed_mean=True
59+
)
60+
scaler.set_random_state(seed=0)
61+
im = p(self.imt)
62+
result = scaler(im)
63+
# verify output is within input range per channel
64+
for c in range(self.imt.shape[0]):
65+
assert float(result[c].min()) >= float(im[c].min()) - 1e-6
66+
assert float(result[c].max()) <= float(im[c].max()) + 1e-6
67+
3968

4069
if __name__ == "__main__":
4170
unittest.main()

tests/transforms/test_rand_scale_intensity_fixed_meand.py

Lines changed: 17 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -36,6 +36,23 @@ def test_value(self):
3636
expected = expected + mn
3737
assert_allclose(result[key], p(expected), type_test="tensor", atol=1e-6)
3838

39+
def test_channel_wise(self):
40+
key = "img"
41+
for p in TEST_NDARRAYS:
42+
scaler = RandScaleIntensityFixedMeand(keys=[key], factors=0.5, prob=1.0, channel_wise=True)
43+
scaler.set_random_state(seed=0)
44+
im = p(self.imt)
45+
result = scaler({key: im})
46+
np.random.seed(0)
47+
# simulate the randomize function of transform
48+
np.random.random()
49+
channel_num = self.imt.shape[0]
50+
factor = [np.random.uniform(low=-0.5, high=0.5) for _ in range(channel_num)]
51+
expected = np.stack(
52+
[np.asarray((self.imt[i] - self.imt[i].mean()) * (1 + factor[i]) + self.imt[i].mean()) for i in range(channel_num)]
53+
).astype(np.float32)
54+
assert_allclose(result[key], p(expected), atol=1e-4, rtol=1e-4, type_test=False)
55+
3956

4057
if __name__ == "__main__":
4158
unittest.main()

0 commit comments

Comments (0)