Hello, I have some questions regarding the implementation details of the Fourier transform in the conv2d_fft_batchwise function. As I understand it, this function performs a convolution operation between signal and kernel, where signal represents f_map and kernel represents the rotated templates of f_bev. According to the paper, this matching process is more similar to cross-correlation, but the code here computes convolution instead. Is there any specific reason for this?
def conv2d_fft_batchwise(signal, kernel, padding="same", padding_mode="constant"):
    if padding == "same":
        padding = [i // 2 for i in kernel.shape[-2:]]
    padding_signal = [p for p in padding[::-1] for _ in range(2)]
    signal = pad(signal, padding_signal, mode=padding_mode)
    assert signal.size(-1) % 2 == 0

    padding_kernel = [
        pad for i in [1, 2] for pad in [0, signal.size(-i) - kernel.size(-i)]
    ]
    kernel_padded = pad(kernel, padding_kernel)

    signal_fr = rfftn(signal, dim=(-1, -2))
    kernel_fr = rfftn(kernel_padded, dim=(-1, -2))

    kernel_fr.imag *= -1  # flip the kernel
    output_fr = torch.einsum("bc...,bdc...->bd...", signal_fr, kernel_fr)
    output = irfftn(output_fr, dim=(-1, -2))

    crop_slices = [slice(0, output.size(0)), slice(0, output.size(1))] + [
        slice(0, (signal.size(i) - kernel.size(i) + 1)) for i in [-2, -1]
    ]
    output = output[crop_slices].contiguous()

    return output
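
A minimal, self-contained sketch for checking what the `kernel_fr.imag *= -1` line does, assuming a real-valued kernel. It uses plain Python on a 1-D circular signal rather than the repository's `rfftn`/`irfftn` calls, and the helper names (`dft`, `idft`, `circular_xcorr`, `circular_conv`, `fft_product`, `close`) are illustrative only and not part of the original code. Negating the imaginary part of a spectrum is complex conjugation, so the sketch compares the frequency-domain product with and without conjugation against directly computed circular cross-correlation and convolution.

```python
import cmath


def dft(x):
    """Naive discrete Fourier transform of a 1-D sequence."""
    n = len(x)
    return [
        sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]


def idft(spectrum):
    """Naive inverse discrete Fourier transform."""
    n = len(spectrum)
    return [
        sum(spectrum[k] * cmath.exp(2j * cmath.pi * k * t / n) for k in range(n)) / n
        for t in range(n)
    ]


def circular_xcorr(signal, kernel):
    """Direct circular cross-correlation: out[t] = sum_s signal[t + s] * kernel[s]."""
    n = len(signal)
    return [sum(signal[(t + s) % n] * kernel[s] for s in range(n)) for t in range(n)]


def circular_conv(signal, kernel):
    """Direct circular convolution: out[t] = sum_s signal[t - s] * kernel[s]."""
    n = len(signal)
    return [sum(signal[(t - s) % n] * kernel[s] for s in range(n)) for t in range(n)]


def fft_product(signal, kernel, conjugate_kernel):
    """Multiply spectra, optionally conjugating the kernel spectrum first.

    Conjugation here plays the role of `kernel_fr.imag *= -1` in the quoted code.
    """
    signal_fr = dft(signal)
    kernel_fr = dft(kernel)
    if conjugate_kernel:
        kernel_fr = [k.conjugate() for k in kernel_fr]
    product = [s * k for s, k in zip(signal_fr, kernel_fr)]
    return [value.real for value in idft(product)]


def close(a, b, tol=1e-9):
    """Element-wise comparison with an absolute tolerance."""
    return len(a) == len(b) and all(abs(x - y) < tol for x, y in zip(a, b))


# Toy data (made up for illustration); the kernel is zero-padded to the signal length,
# mirroring the `kernel_padded` step, and is deliberately asymmetric.
signal = [1.0, 2.0, 0.0, -1.0, 3.0, 0.5]
kernel = [0.5, -1.0, 2.0, 0.0, 0.0, 0.0]

with_conj = fft_product(signal, kernel, conjugate_kernel=True)
without_conj = fft_product(signal, kernel, conjugate_kernel=False)

print("conjugated product == cross-correlation:", close(with_conj, circular_xcorr(signal, kernel)))
print("plain product == convolution:", close(without_conj, circular_conv(signal, kernel)))
```

In this toy case the conjugated product matches circular cross-correlation and the unconjugated product matches circular convolution. Whether the same holds for the padded, batched 2-D version would still need to be confirmed against the actual code.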