Skip to content
44 changes: 25 additions & 19 deletions monai/networks/layers/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,7 @@ def __init__(self, spatial_sigma, color_sigma):
spatial_sigma = [spatial_sigma[0], spatial_sigma[1], spatial_sigma[2]]
self.len_spatial_sigma = 3
else:
raise ValueError(
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}."
)
raise ValueError(f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims (1, 2 or 3).")
Comment thread
getrichthroughcode marked this conversation as resolved.
Outdated

# Register sigmas as trainable parameters.
self.sigma_x = torch.nn.Parameter(torch.tensor(spatial_sigma[0]))
Expand All @@ -231,6 +229,10 @@ def __init__(self, spatial_sigma, color_sigma):
self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma))

def forward(self, input_tensor):
if len(input_tensor.shape) < 3:
raise ValueError(
f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}"
)
if input_tensor.shape[1] != 1:
raise ValueError(
f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. "
Expand All @@ -239,24 +241,25 @@ def forward(self, input_tensor):
)

len_input = len(input_tensor.shape)
spatial_dims = len_input - 2

# C++ extension so far only supports 5-dim inputs.
if len_input == 3:
if spatial_dims == 1:
input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)
elif len_input == 4:
elif spatial_dims == 2:
input_tensor = input_tensor.unsqueeze(4)

if self.len_spatial_sigma != len_input:
raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).")
if self.len_spatial_sigma != spatial_dims:
raise ValueError(f"Spatial dimension ({spatial_dims}) must match initialized len(spatial_sigma).")
Comment thread
getrichthroughcode marked this conversation as resolved.
Outdated

prediction = TrainableBilateralFilterFunction.apply(
input_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color
)

# Make sure to return tensor of the same shape as the input.
if len_input == 3:
if spatial_dims == 1:
prediction = prediction.squeeze(4).squeeze(3)
elif len_input == 4:
elif spatial_dims == 2:
prediction = prediction.squeeze(4)

return prediction
Expand Down Expand Up @@ -388,9 +391,7 @@ def __init__(self, spatial_sigma, color_sigma):
spatial_sigma = [spatial_sigma[0], spatial_sigma[1], spatial_sigma[2]]
self.len_spatial_sigma = 3
else:
raise ValueError(
f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}."
)
raise ValueError(f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims (1, 2, or 3).")
Comment thread
getrichthroughcode marked this conversation as resolved.
Outdated

# Register sigmas as trainable parameters.
self.sigma_x = torch.nn.Parameter(torch.tensor(spatial_sigma[0]))
Expand All @@ -399,9 +400,13 @@ def __init__(self, spatial_sigma, color_sigma):
self.sigma_color = torch.nn.Parameter(torch.tensor(color_sigma))

def forward(self, input_tensor, guidance_tensor):
if len(input_tensor.shape) < 3:
raise ValueError(
f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}"
)
if input_tensor.shape[1] != 1:
raise ValueError(
f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. "
f"Currently channel dimensions > 1 ({input_tensor.shape[1]}) are not supported. "
"Please use multiple parallel filter layers if you want "
"to filter multiple channels."
)
Expand All @@ -412,26 +417,27 @@ def forward(self, input_tensor, guidance_tensor):
)

len_input = len(input_tensor.shape)
spatial_dims = len_input - 2

# C++ extension so far only supports 5-dim inputs.
if len_input == 3:
if spatial_dims == 1:
input_tensor = input_tensor.unsqueeze(3).unsqueeze(4)
guidance_tensor = guidance_tensor.unsqueeze(3).unsqueeze(4)
elif len_input == 4:
elif spatial_dims == 2:
input_tensor = input_tensor.unsqueeze(4)
guidance_tensor = guidance_tensor.unsqueeze(4)

if self.len_spatial_sigma != len_input:
raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).")
if self.len_spatial_sigma != spatial_dims:
raise ValueError(f"Spatial dimension ({spatial_dims}) must match initialized len(spatial_sigma).")
Comment thread
getrichthroughcode marked this conversation as resolved.
Outdated

prediction = TrainableJointBilateralFilterFunction.apply(
input_tensor, guidance_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color
)

# Make sure to return tensor of the same shape as the input.
if len_input == 3:
if spatial_dims == 1:
prediction = prediction.squeeze(4).squeeze(3)
elif len_input == 4:
elif spatial_dims == 2:
prediction = prediction.squeeze(4)

return prediction
Loading